1#pragma once
2
3#include <optional>
4#include <string>
5#include <vector>
6
7#include "taichi/ir/offloaded_task_type.h"
8#include "taichi/ir/type.h"
9#include "taichi/ir/transforms.h"
10#include "taichi/rhi/device.h"
11
12namespace taichi::lang {
13
14class Kernel;
15class SNode;
16
17namespace spirv {
18
19/**
20 * Per offloaded task attributes.
21 */
22struct TaskAttributes {
23 enum class BufferType { Root, GlobalTmps, Args, Rets, ListGen, ExtArr };
24
25 struct BufferInfo {
26 BufferType type;
27 int root_id{-1}; // only used if type==Root or type==ExtArr
28
29 BufferInfo() = default;
30
31 // NOLINTNEXTLINE(google-explicit-constructor)
32 BufferInfo(BufferType buffer_type) : type(buffer_type) {
33 }
34
35 BufferInfo(BufferType buffer_type, int root_buffer_id)
36 : type(buffer_type), root_id(root_buffer_id) {
37 }
38
39 bool operator==(const BufferInfo &other) const {
40 if (type != other.type) {
41 return false;
42 }
43 if (type == BufferType::Root || type == BufferType::ExtArr) {
44 return root_id == other.root_id;
45 }
46 return true;
47 }
48
49 TI_IO_DEF(type, root_id);
50 };
51
52 struct BufferInfoHasher {
53 std::size_t operator()(const BufferInfo &buf) const {
54 using std::hash;
55 using std::size_t;
56 using std::string;
57
58 return hash<BufferType>()(buf.type) ^ buf.root_id;
59 }
60 };
61
62 struct BufferBind {
63 BufferInfo buffer;
64 int binding{0};
65
66 std::string debug_string() const;
67
68 TI_IO_DEF(buffer, binding);
69 };
70
71 struct TextureBind {
72 int arg_id{0};
73 int binding{0};
74 bool is_storage{false};
75
76 TI_IO_DEF(arg_id, binding, is_storage);
77 };
78
79 std::string name;
80 std::string source_path;
81 // Total number of threads to launch (i.e. threads per grid). Note that this
82 // is only advisory, because eventually this number is also determined by the
83 // runtime config. This works because grid strided loop is supported.
84 int advisory_total_num_threads{0};
85 int advisory_num_threads_per_group{0};
86
87 OffloadedTaskType task_type;
88
89 struct RangeForAttributes {
90 // |begin| has different meanings depending on |const_begin|:
91 // * true : It is the left boundary of the loop known at compile time.
92 // * false: It is the offset of the begin in the global tmps buffer.
93 //
94 // Same applies to |end|.
95 size_t begin{0};
96 size_t end{0};
97 bool const_begin{true};
98 bool const_end{true};
99
100 inline bool const_range() const {
101 return (const_begin && const_end);
102 }
103
104 TI_IO_DEF(begin, end, const_begin, const_end);
105 };
106 std::vector<BufferBind> buffer_binds;
107 std::vector<TextureBind> texture_binds;
108 // Only valid when |task_type| is range_for.
109 std::optional<RangeForAttributes> range_for_attribs;
110
111 static std::string buffers_name(BufferInfo b);
112
113 std::string debug_string() const;
114
115 TI_IO_DEF(name,
116 advisory_total_num_threads,
117 advisory_num_threads_per_group,
118 task_type,
119 buffer_binds,
120 texture_binds,
121 range_for_attribs);
122};
123
124/**
125 * This class contains the attributes descriptors for both the input args and
126 * the return values of a Taichi kernel.
127 *
128 * Note that all SPIRV tasks (shaders) belonging to the same Taichi kernel will
129 * share the same kernel args (i.e. they use the same device buffer for input
130 * args and return values). This is because kernel arguments is a Taichi-level
131 * concept.
132 *
133 * Memory layout
134 *
135 * /---- input args ----\/---- ret vals -----\/-- extra args --\
136 * +----------+---------+----------+---------+-----------------+
137 * | scalar | array | scalar | array | scalar |
138 * +----------+---------+----------+---------+-----------------+
139 */
140class KernelContextAttributes {
141 private:
142 /**
143 * Attributes that are shared by the input arg and the return value.
144 */
145 struct AttribsBase {
146 // For scalar arg, this is max(stride(dt), 4)
147 // For array arg, this is #elements * max(stride(dt), 4)
148 // Unit: byte
149 size_t stride{0};
150 // Offset in the context buffer
151 size_t offset_in_mem{0};
152 // Index of the input arg or the return value in the host `Context`
153 int index{-1};
154 PrimitiveTypeID dtype{PrimitiveTypeID::unknown};
155 bool is_array{false};
156 std::vector<int> element_shape;
157 std::size_t field_dim{0};
158
159 TI_IO_DEF(stride,
160 offset_in_mem,
161 index,
162 dtype,
163 is_array,
164 element_shape,
165 field_dim);
166 };
167
168 public:
169 /**
170 * This is mostly the same as Kernel::Arg, with device specific attributes.
171 */
172 struct ArgAttributes : public AttribsBase {};
173
174 /**
175 * This is mostly the same as Kernel::Ret, with device specific attributes.
176 */
177 struct RetAttributes : public AttribsBase {};
178
179 KernelContextAttributes() = default;
180 explicit KernelContextAttributes(const Kernel &kernel,
181 const DeviceCapabilityConfig *caps);
182
183 /**
184 * Whether this kernel has any argument
185 */
186 inline bool has_args() const {
187 return !arg_attribs_vec_.empty();
188 }
189
190 inline const std::vector<ArgAttributes> &args() const {
191 return arg_attribs_vec_;
192 }
193
194 /**
195 * Whether this kernel has any return value
196 */
197 inline bool has_rets() const {
198 return !ret_attribs_vec_.empty();
199 }
200
201 inline const std::vector<RetAttributes> &rets() const {
202 return ret_attribs_vec_;
203 }
204
205 /**
206 * Whether this kernel has either arguments or return values.
207 */
208 inline bool empty() const {
209 return !(has_args() || has_rets());
210 }
211
212 /**
213 * Number of bytes needed by all the arguments.
214 */
215 inline size_t args_bytes() const {
216 return args_bytes_;
217 }
218
219 /**
220 * Number of bytes needed by all the return values.
221 */
222 inline size_t rets_bytes() const {
223 return rets_bytes_;
224 }
225
226 /**
227 * Number of bytes needed by the extra arguments.
228 *
229 * Extra argument region is used to store some metadata, like the shape of the
230 * external array.
231 */
232 inline size_t extra_args_bytes() const {
233 return extra_args_bytes_;
234 }
235
236 /**
237 * Offset (in bytes) of the extra arguments in the memory.
238 */
239 inline size_t extra_args_mem_offset() const {
240 return args_bytes();
241 }
242
243 std::vector<irpass::ExternalPtrAccess> arr_access;
244
245 TI_IO_DEF(arg_attribs_vec_,
246 ret_attribs_vec_,
247 args_bytes_,
248 rets_bytes_,
249 extra_args_bytes_,
250 arr_access);
251
252 private:
253 std::vector<ArgAttributes> arg_attribs_vec_;
254 std::vector<RetAttributes> ret_attribs_vec_;
255
256 size_t args_bytes_{0};
257 size_t rets_bytes_{0};
258 size_t extra_args_bytes_{0};
259};
260
261/**
262 * Groups all the device kernels generated from a single ti.kernel.
263 */
264struct TaichiKernelAttributes {
265 // Taichi kernel name
266 std::string name;
267 // Is this kernel for evaluating the constant fold result?
268 bool is_jit_evaluator{false};
269 // Attributes of all the tasks produced from this single Taichi kernel.
270 std::vector<TaskAttributes> tasks_attribs;
271
272 KernelContextAttributes ctx_attribs;
273
274 TI_IO_DEF(name, is_jit_evaluator, tasks_attribs, ctx_attribs);
275};
276
277} // namespace spirv
278} // namespace taichi::lang
279