1 | #pragma once |
2 | |
3 | #include <optional> |
4 | #include <string> |
5 | #include <vector> |
6 | |
7 | #include "taichi/ir/offloaded_task_type.h" |
8 | #include "taichi/ir/type.h" |
9 | #include "taichi/ir/transforms.h" |
10 | #include "taichi/rhi/device.h" |
11 | |
12 | namespace taichi::lang { |
13 | |
14 | class Kernel; |
15 | class SNode; |
16 | |
17 | namespace spirv { |
18 | |
19 | /** |
20 | * Per offloaded task attributes. |
21 | */ |
22 | struct TaskAttributes { |
23 | enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr }; |
24 | |
25 | struct BufferInfo { |
26 | BufferType type; |
27 | int root_id{-1}; // only used if type==Root or type==ExtArr |
28 | |
29 | BufferInfo() = default; |
30 | |
31 | // NOLINTNEXTLINE(google-explicit-constructor) |
32 | BufferInfo(BufferType buffer_type) : type(buffer_type) { |
33 | } |
34 | |
35 | BufferInfo(BufferType buffer_type, int root_buffer_id) |
36 | : type(buffer_type), root_id(root_buffer_id) { |
37 | } |
38 | |
39 | bool operator==(const BufferInfo &other) const { |
40 | if (type != other.type) { |
41 | return false; |
42 | } |
43 | if (type == BufferType::Root || type == BufferType::ExtArr) { |
44 | return root_id == other.root_id; |
45 | } |
46 | return true; |
47 | } |
48 | |
49 | TI_IO_DEF(type, root_id); |
50 | }; |
51 | |
52 | struct BufferInfoHasher { |
53 | std::size_t operator()(const BufferInfo &buf) const { |
54 | using std::hash; |
55 | using std::size_t; |
56 | using std::string; |
57 | |
58 | return hash<BufferType>()(buf.type) ^ buf.root_id; |
59 | } |
60 | }; |
61 | |
62 | struct BufferBind { |
63 | BufferInfo buffer; |
64 | int binding{0}; |
65 | |
66 | std::string debug_string() const; |
67 | |
68 | TI_IO_DEF(buffer, binding); |
69 | }; |
70 | |
71 | struct TextureBind { |
72 | int arg_id{0}; |
73 | int binding{0}; |
74 | bool is_storage{false}; |
75 | |
76 | TI_IO_DEF(arg_id, binding, is_storage); |
77 | }; |
78 | |
79 | std::string name; |
80 | std::string source_path; |
81 | // Total number of threads to launch (i.e. threads per grid). Note that this |
82 | // is only advisory, because eventually this number is also determined by the |
83 | // runtime config. This works because grid strided loop is supported. |
84 | int advisory_total_num_threads{0}; |
85 | int advisory_num_threads_per_group{0}; |
86 | |
87 | OffloadedTaskType task_type; |
88 | |
89 | struct RangeForAttributes { |
90 | // |begin| has different meanings depending on |const_begin|: |
91 | // * true : It is the left boundary of the loop known at compile time. |
92 | // * false: It is the offset of the begin in the global tmps buffer. |
93 | // |
94 | // Same applies to |end|. |
95 | size_t begin{0}; |
96 | size_t end{0}; |
97 | bool const_begin{true}; |
98 | bool const_end{true}; |
99 | |
100 | inline bool const_range() const { |
101 | return (const_begin && const_end); |
102 | } |
103 | |
104 | TI_IO_DEF(begin, end, const_begin, const_end); |
105 | }; |
106 | std::vector<BufferBind> buffer_binds; |
107 | std::vector<TextureBind> texture_binds; |
108 | // Only valid when |task_type| is range_for. |
109 | std::optional<RangeForAttributes> range_for_attribs; |
110 | |
111 | static std::string buffers_name(BufferInfo b); |
112 | |
113 | std::string debug_string() const; |
114 | |
115 | TI_IO_DEF(name, |
116 | advisory_total_num_threads, |
117 | advisory_num_threads_per_group, |
118 | task_type, |
119 | buffer_binds, |
120 | texture_binds, |
121 | range_for_attribs); |
122 | }; |
123 | |
124 | /** |
125 | * This class contains the attributes descriptors for both the input args and |
126 | * the return values of a Taichi kernel. |
127 | * |
128 | * Note that all SPIRV tasks (shaders) belonging to the same Taichi kernel will |
129 | * share the same kernel args (i.e. they use the same device buffer for input |
130 | * args and return values). This is because kernel arguments is a Taichi-level |
131 | * concept. |
132 | * |
133 | * Memory layout |
134 | * |
135 | * /---- input args ----\/---- ret vals -----\/-- extra args --\ |
136 | * +----------+---------+----------+---------+-----------------+ |
137 | * | scalar | array | scalar | array | scalar | |
138 | * +----------+---------+----------+---------+-----------------+ |
139 | */ |
140 | class KernelContextAttributes { |
141 | private: |
142 | /** |
143 | * Attributes that are shared by the input arg and the return value. |
144 | */ |
145 | struct AttribsBase { |
146 | // For scalar arg, this is max(stride(dt), 4) |
147 | // For array arg, this is #elements * max(stride(dt), 4) |
148 | // Unit: byte |
149 | size_t stride{0}; |
150 | // Offset in the context buffer |
151 | size_t offset_in_mem{0}; |
152 | // Index of the input arg or the return value in the host `Context` |
153 | int index{-1}; |
154 | PrimitiveTypeID dtype{PrimitiveTypeID::unknown}; |
155 | bool is_array{false}; |
156 | std::vector<int> element_shape; |
157 | std::size_t field_dim{0}; |
158 | |
159 | TI_IO_DEF(stride, |
160 | offset_in_mem, |
161 | index, |
162 | dtype, |
163 | is_array, |
164 | element_shape, |
165 | field_dim); |
166 | }; |
167 | |
168 | public: |
169 | /** |
170 | * This is mostly the same as Kernel::Arg, with device specific attributes. |
171 | */ |
172 | struct ArgAttributes : public AttribsBase {}; |
173 | |
174 | /** |
175 | * This is mostly the same as Kernel::Ret, with device specific attributes. |
176 | */ |
177 | struct RetAttributes : public AttribsBase {}; |
178 | |
179 | KernelContextAttributes() = default; |
180 | explicit KernelContextAttributes(const Kernel &kernel, |
181 | const DeviceCapabilityConfig *caps); |
182 | |
183 | /** |
184 | * Whether this kernel has any argument |
185 | */ |
186 | inline bool has_args() const { |
187 | return !arg_attribs_vec_.empty(); |
188 | } |
189 | |
190 | inline const std::vector<ArgAttributes> &args() const { |
191 | return arg_attribs_vec_; |
192 | } |
193 | |
194 | /** |
195 | * Whether this kernel has any return value |
196 | */ |
197 | inline bool has_rets() const { |
198 | return !ret_attribs_vec_.empty(); |
199 | } |
200 | |
201 | inline const std::vector<RetAttributes> &rets() const { |
202 | return ret_attribs_vec_; |
203 | } |
204 | |
205 | /** |
206 | * Whether this kernel has either arguments or return values. |
207 | */ |
208 | inline bool empty() const { |
209 | return !(has_args() || has_rets()); |
210 | } |
211 | |
212 | /** |
213 | * Number of bytes needed by all the arguments. |
214 | */ |
215 | inline size_t args_bytes() const { |
216 | return args_bytes_; |
217 | } |
218 | |
219 | /** |
220 | * Number of bytes needed by all the return values. |
221 | */ |
222 | inline size_t rets_bytes() const { |
223 | return rets_bytes_; |
224 | } |
225 | |
226 | /** |
227 | * Number of bytes needed by the extra arguments. |
228 | * |
229 | * Extra argument region is used to store some metadata, like the shape of the |
230 | * external array. |
231 | */ |
232 | inline size_t () const { |
233 | return extra_args_bytes_; |
234 | } |
235 | |
236 | /** |
237 | * Offset (in bytes) of the extra arguments in the memory. |
238 | */ |
239 | inline size_t () const { |
240 | return args_bytes(); |
241 | } |
242 | |
243 | std::vector<irpass::ExternalPtrAccess> arr_access; |
244 | |
245 | TI_IO_DEF(arg_attribs_vec_, |
246 | ret_attribs_vec_, |
247 | args_bytes_, |
248 | rets_bytes_, |
249 | extra_args_bytes_, |
250 | arr_access); |
251 | |
252 | private: |
253 | std::vector<ArgAttributes> arg_attribs_vec_; |
254 | std::vector<RetAttributes> ret_attribs_vec_; |
255 | |
256 | size_t args_bytes_{0}; |
257 | size_t rets_bytes_{0}; |
258 | size_t {0}; |
259 | }; |
260 | |
261 | /** |
262 | * Groups all the device kernels generated from a single ti.kernel. |
263 | */ |
264 | struct TaichiKernelAttributes { |
265 | // Taichi kernel name |
266 | std::string name; |
267 | // Is this kernel for evaluating the constant fold result? |
268 | bool is_jit_evaluator{false}; |
269 | // Attributes of all the tasks produced from this single Taichi kernel. |
270 | std::vector<TaskAttributes> tasks_attribs; |
271 | |
272 | KernelContextAttributes ctx_attribs; |
273 | |
274 | TI_IO_DEF(name, is_jit_evaluator, tasks_attribs, ctx_attribs); |
275 | }; |
276 | |
277 | } // namespace spirv |
278 | } // namespace taichi::lang |
279 | |