1 | #include "taichi/codegen/spirv/kernel_utils.h" |
2 | |
3 | #include <unordered_map> |
4 | |
5 | #include "taichi/program/kernel.h" |
6 | #define TI_RUNTIME_HOST |
7 | #include "taichi/program/context.h" |
8 | #undef TI_RUNTIME_HOST |
9 | |
10 | namespace taichi::lang { |
11 | namespace spirv { |
12 | |
13 | // static |
14 | std::string TaskAttributes::buffers_name(BufferInfo b) { |
15 | if (b.type == BufferType::Args) { |
16 | return "Args" ; |
17 | } |
18 | if (b.type == BufferType::Rets) { |
19 | return "Rets" ; |
20 | } |
21 | if (b.type == BufferType::GlobalTmps) { |
22 | return "GlobalTmps" ; |
23 | } |
24 | if (b.type == BufferType::Root) { |
25 | return std::string("Root: " ) + std::to_string(b.root_id); |
26 | } |
27 | TI_ERROR("unrecognized buffer type" ); |
28 | } |
29 | |
30 | std::string TaskAttributes::debug_string() const { |
31 | std::string result; |
32 | result += fmt::format( |
33 | "<TaskAttributes name={} advisory_total_num_threads={} " |
34 | "task_type={} buffers=[ " , |
35 | name, advisory_total_num_threads, offloaded_task_type_name(task_type)); |
36 | for (auto b : buffer_binds) { |
37 | result += buffers_name(b.buffer) + " " ; |
38 | } |
39 | result += "]" ; // closes |buffers| |
40 | // TODO(k-ye): show range_for |
41 | result += ">" ; |
42 | return result; |
43 | } |
44 | |
45 | std::string TaskAttributes::BufferBind::debug_string() const { |
46 | return fmt::format("<type={} binding={}>" , |
47 | TaskAttributes::buffers_name(buffer), binding); |
48 | } |
49 | |
50 | KernelContextAttributes::KernelContextAttributes( |
51 | const Kernel &kernel, |
52 | const DeviceCapabilityConfig *caps) |
53 | : args_bytes_(0), |
54 | rets_bytes_(0), |
55 | extra_args_bytes_(RuntimeContext::extra_args_size) { |
56 | arr_access.resize(kernel.parameter_list.size(), irpass::ExternalPtrAccess(0)); |
57 | arg_attribs_vec_.reserve(kernel.parameter_list.size()); |
58 | // TODO: We should be able to limit Kernel args and rets to be primitive types |
59 | // as well but let's leave that as a followup up PR. |
60 | for (const auto &ka : kernel.parameter_list) { |
61 | ArgAttributes aa; |
62 | aa.dtype = ka.get_element_type()->as<PrimitiveType>()->type; |
63 | const size_t dt_bytes = ka.get_element_size(); |
64 | aa.is_array = ka.is_array; |
65 | if (aa.is_array) { |
66 | aa.field_dim = ka.total_dim - ka.get_element_shape().size(); |
67 | aa.element_shape = ka.get_element_shape(); |
68 | } |
69 | aa.stride = dt_bytes; |
70 | aa.index = arg_attribs_vec_.size(); |
71 | arg_attribs_vec_.push_back(aa); |
72 | } |
73 | for (const auto &kr : kernel.rets) { |
74 | RetAttributes ra; |
75 | size_t dt_bytes{0}; |
76 | if (auto tensor_type = kr.dt->cast<TensorType>()) { |
77 | auto tensor_dtype = tensor_type->get_element_type(); |
78 | TI_ASSERT(tensor_dtype->is<PrimitiveType>()); |
79 | ra.dtype = tensor_dtype->cast<PrimitiveType>()->type; |
80 | dt_bytes = data_type_size(tensor_dtype); |
81 | ra.is_array = true; |
82 | ra.stride = tensor_type->get_num_elements() * dt_bytes; |
83 | } else { |
84 | TI_ASSERT(kr.dt->is<PrimitiveType>()); |
85 | ra.dtype = kr.dt->cast<PrimitiveType>()->type; |
86 | dt_bytes = data_type_size(kr.dt); |
87 | ra.is_array = false; |
88 | ra.stride = dt_bytes; |
89 | } |
90 | ra.index = ret_attribs_vec_.size(); |
91 | ret_attribs_vec_.push_back(ra); |
92 | } |
93 | |
94 | auto arange_args = [](auto *vec, size_t offset, bool is_ret, |
95 | bool has_buffer_ptr) -> size_t { |
96 | size_t bytes = offset; |
97 | for (int i = 0; i < vec->size(); ++i) { |
98 | auto &attribs = (*vec)[i]; |
99 | const size_t dt_bytes = |
100 | (attribs.is_array && !is_ret && has_buffer_ptr) |
101 | ? sizeof(uint64_t) |
102 | : data_type_size(PrimitiveType::get(attribs.dtype)); |
103 | // Align bytes to the nearest multiple of dt_bytes |
104 | bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; |
105 | attribs.offset_in_mem = bytes; |
106 | bytes += is_ret ? attribs.stride : dt_bytes; |
107 | TI_TRACE( |
108 | " at={} {} offset_in_mem={} stride={}" , |
109 | (*vec)[i].is_array ? (is_ret ? "array" : "vector ptr" ) : "scalar" , i, |
110 | attribs.offset_in_mem, attribs.stride); |
111 | } |
112 | return bytes - offset; |
113 | }; |
114 | |
115 | TI_TRACE("args:" ); |
116 | args_bytes_ = arange_args( |
117 | &arg_attribs_vec_, 0, false, |
118 | caps->get(DeviceCapability::spirv_has_physical_storage_buffer)); |
119 | // Align to extra args |
120 | args_bytes_ = (args_bytes_ + 4 - 1) / 4 * 4; |
121 | |
122 | TI_TRACE("rets:" ); |
123 | rets_bytes_ = arange_args(&ret_attribs_vec_, 0, true, false); |
124 | |
125 | TI_TRACE("sizes: args={} rets={}" , args_bytes(), rets_bytes()); |
126 | TI_ASSERT(has_rets() == (rets_bytes_ > 0)); |
127 | } |
128 | |
129 | } // namespace spirv |
130 | } // namespace taichi::lang |
131 | |