1 | #pragma once |
2 | |
3 | #include <array> |
4 | |
5 | #include <spirv/unified1/spirv.hpp> |
6 | #include "taichi/util/lang_util.h" |
7 | #include "taichi/ir/type.h" |
8 | #include "taichi/util/testing.h" |
9 | #include "taichi/codegen/spirv/snode_struct_compiler.h" |
10 | #include "taichi/rhi/device.h" |
11 | #include "taichi/ir/statements.h" |
12 | |
13 | namespace taichi::lang { |
14 | namespace spirv { |
15 | |
// Compile-time recursion helper behind for_each(): calls `f(I, value)` on the
// first remaining argument, then recurses on the rest with index I + 1.
// `stop` becomes true once no arguments are left.
template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
  template <typename T, typename... Args>
  static void run(const F &f, T &&value, Args &&...args) {  // NOLINT(*)
    f(I, std::forward<T>(value));
    using next = for_each_dispatcher<sizeof...(Args) == 0, I + 1, F>;
    next::run(f, std::forward<Args>(args)...);
  }
};

// Terminal case: every argument has been visited, nothing left to do.
template <std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
  static void run(const F & /*f*/) {  // NOLINT(*)
  }
};

// Invoke `f(index, arg)` for each argument in order, with index 0, 1, 2, ...
// Arguments are perfectly forwarded to `f`.
template <typename F, typename... Args>
inline void for_each(const F &f, Args &&...args) {  // NOLINT(*)
  constexpr bool kNoArgs = sizeof...(Args) == 0;
  for_each_dispatcher<kNoArgs, 0, F>::run(f, std::forward<Args>(args)...);
}
37 | |
// Category tag stored in SType::flag.
enum class TypeKind {
  // A type described by a Taichi DataType (see SType::dt).
  kPrimitive,
  kSNodeStruct,
  // Array components of a kSNodeStruct.
  kSNodeArray,
  kStruct,
  // Pointer type; SType::storage_class and SType::element_type_id apply.
  kPtr,
  kFunc,
  kImage,
};
47 | |
48 | // Represent the SPIRV Type |
49 | struct SType { |
50 | // The Id to represent type |
51 | uint32_t id{0}; |
52 | |
53 | // corresponding Taichi type/Compiled SNode info |
54 | DataType dt; |
55 | |
56 | SNodeDescriptor snode_desc; // TODO: dt/snode_desc only need one at a time |
57 | std::vector<uint32_t> snode_child_type_id; |
58 | |
59 | TypeKind flag{TypeKind::kPrimitive}; |
60 | |
61 | // Content type id if it is a pointer/struct-array class |
62 | // TODO: SNODE need a vector to store their childrens' element type id |
63 | uint32_t element_type_id{0}; |
64 | |
65 | // The storage class, if it is a pointer |
66 | spv::StorageClass storage_class{spv::StorageClassMax}; |
67 | }; |
68 | |
// Category tag stored in Value::flag.
enum class ValueKind {
  // Default kind for values from IRBuilder::make_value().
  kNormal,
  kConstant,
  kVectorPtr,
  kStructArrayPtr,
  // Set by IRBuilder::make_value() when the result type is a TypeKind::kPtr.
  kVariablePtr,
  kPhysicalPtr,
  kTexture,
  // Values created by IRBuilder::new_function().
  kFunction,
  // Extended instruction set imports (IRBuilder::ext_inst_import()).
  kExtInst,
};
80 | |
81 | // Represent the SPIRV Value |
82 | struct Value { |
83 | // The Id to represent type |
84 | uint32_t id{0}; |
85 | // The data type |
86 | SType stype; |
87 | // Additional flags about the value |
88 | ValueKind flag{ValueKind::kNormal}; |
89 | |
90 | bool operator==(const Value &rhs) const { |
91 | return id == rhs.id; |
92 | } |
93 | }; |
94 | |
95 | struct ValueHasher { |
96 | size_t operator()(const spirv::Value &v) const { |
97 | return std::hash<uint32_t>()(v.id); |
98 | } |
99 | }; |
100 | |
// A SPIR-V basic-block label.
struct Label {
  // Result id of the label; 0 until assigned by IRBuilder::new_label().
  uint32_t id = 0;
};
106 | |
107 | // A SPIRV instruction, |
108 | // can be used as handle to modify its content later |
109 | class Instr { |
110 | public: |
111 | uint32_t word_count() const { |
112 | return word_count_; |
113 | } |
114 | |
115 | uint32_t &operator[](uint32_t idx) { |
116 | TI_ASSERT(idx < word_count_); |
117 | return (*data_)[begin_ + idx]; |
118 | } |
119 | |
120 | private: |
121 | friend class InstrBuilder; |
122 | |
123 | std::vector<uint32_t> *data_{nullptr}; |
124 | uint32_t begin_{0}; |
125 | uint32_t word_count_{0}; |
126 | }; |
127 | |
128 | // Representation of phi value |
129 | struct PhiValue : public Value { |
130 | Instr instr; |
131 | |
132 | void set_incoming(uint32_t index, const Value &value, const Label &parent) { |
133 | TI_ASSERT(this->stype.id == value.stype.id); |
134 | instr[3 + index * 2] = value.id; |
135 | instr[3 + index * 2 + 1] = parent.id; |
136 | } |
137 | }; |
138 | |
139 | // Helper class to build SPIRV instruction |
140 | class InstrBuilder { |
141 | public: |
142 | InstrBuilder &begin(spv::Op op) { |
143 | TI_ASSERT(data_.size() == 0U); |
144 | op_ = op; |
145 | data_.push_back(0); |
146 | return *this; |
147 | } |
148 | |
149 | #define ADD(var, id) \ |
150 | InstrBuilder &add(const var &v) { \ |
151 | data_.push_back(id); \ |
152 | return *this; \ |
153 | } |
154 | |
155 | ADD(Value, v.id); |
156 | ADD(SType, v.id); |
157 | ADD(Label, v.id); |
158 | ADD(uint32_t, v); |
159 | #undef ADD |
160 | |
161 | InstrBuilder &add(const std::vector<uint32_t> &v) { |
162 | for (const auto &v0 : v) { |
163 | add(v0); |
164 | } |
165 | return *this; |
166 | } |
167 | |
168 | InstrBuilder &add(const std::string &v) { |
169 | const uint32_t word_size = sizeof(uint32_t); |
170 | const auto nwords = |
171 | (static_cast<uint32_t>(v.length()) + word_size) / word_size; |
172 | size_t begin = data_.size(); |
173 | data_.resize(begin + nwords, 0U); |
174 | std::copy(v.begin(), v.end(), reinterpret_cast<char *>(&data_[begin])); |
175 | return *this; |
176 | } |
177 | |
178 | template <typename... Args> |
179 | InstrBuilder &add_seq(Args &&...args) { |
180 | AddSeqHelper helper; |
181 | helper.builder = this; |
182 | for_each(helper, std::forward<Args>(args)...); |
183 | return *this; |
184 | } |
185 | |
186 | Instr commit(std::vector<uint32_t> *seg) { |
187 | Instr ret; |
188 | ret.data_ = seg; |
189 | ret.begin_ = seg->size(); |
190 | ret.word_count_ = static_cast<uint32_t>(data_.size()); |
191 | data_[0] = op_ | (ret.word_count_ << spv::WordCountShift); |
192 | seg->insert(seg->end(), data_.begin(), data_.end()); |
193 | data_.clear(); |
194 | return ret; |
195 | } |
196 | |
197 | private: |
198 | // current op code |
199 | spv::Op op_; |
200 | // The internal data to store code |
201 | std::vector<uint32_t> data_; |
202 | // helper class to support variadic arguments |
203 | struct AddSeqHelper { |
204 | // The reference to builder |
205 | InstrBuilder *builder; |
206 | // invoke function |
207 | template <typename T> |
208 | void operator()(size_t, const T &v) const { |
209 | builder->add(v); |
210 | } |
211 | }; |
212 | }; |
213 | |
214 | // Builder to build up a single SPIR-V module |
215 | class IRBuilder { |
216 | public: |
217 | IRBuilder(Arch arch, const DeviceCapabilityConfig *caps) |
218 | : arch_(arch), caps_(caps) { |
219 | } |
220 | |
221 | template <typename... Args> |
222 | void debug_name(spv::Op op, Args &&...args) { |
223 | ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&names_); |
224 | } |
225 | |
226 | Value debug_string(std::string str); |
227 | |
228 | template <typename... Args> |
229 | void execution_mode(Value func, Args &&...args) { |
230 | ib_.begin(spv::OpExecutionMode) |
231 | .add_seq(func, std::forward<Args>(args)...) |
232 | .commit(&exec_mode_); |
233 | } |
234 | |
235 | template <typename... Args> |
236 | void decorate(spv::Op op, Args &&...args) { |
237 | ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&decorate_); |
238 | } |
239 | |
240 | template <typename... Args> |
241 | void declare_global(spv::Op op, Args &&...args) { |
242 | ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&global_); |
243 | } |
244 | |
245 | template <typename... Args> |
246 | Instr make_inst(spv::Op op, Args &&...args) { |
247 | return ib_.begin(op) |
248 | .add_seq(std::forward<Args>(args)...) |
249 | .commit(&function_); |
250 | } |
251 | |
252 | // Initialize header |
253 | void (); |
254 | // Initialize the predefined contents |
255 | void init_pre_defs(); |
256 | // Get the final binary built from the builder, return The finalized binary |
257 | // instruction |
258 | std::vector<uint32_t> finalize(); |
259 | |
260 | Value ext_inst_import(const std::string &name) { |
261 | Value val = new_value(SType(), ValueKind::kExtInst); |
262 | ib_.begin(spv::OpExtInstImport).add_seq(val, name).commit(&header_); |
263 | return val; |
264 | } |
265 | |
266 | Label new_label() { |
267 | Label label; |
268 | label.id = id_counter_++; |
269 | return label; |
270 | } |
271 | |
272 | // Start a new block with given label |
273 | void start_label(Label label) { |
274 | make_inst(spv::OpLabel, label); |
275 | curr_label_ = label; |
276 | } |
277 | |
278 | Label current_label() const { |
279 | return curr_label_; |
280 | } |
281 | |
282 | // Make a new SSA value |
283 | template <typename... Args> |
284 | Value make_value(spv::Op op, const SType &out_type, Args &&...args) { |
285 | Value val = new_value(out_type, ValueKind::kNormal); |
286 | make_inst(op, out_type, val, std::forward<Args>(args)...); |
287 | if (out_type.flag == TypeKind::kPtr) { |
288 | val.flag = ValueKind::kVariablePtr; |
289 | } |
290 | return val; |
291 | } |
292 | |
293 | // Make a phi value |
294 | PhiValue make_phi(const SType &out_type, uint32_t num_incoming); |
295 | |
296 | // Create Constant Primitive Value |
297 | // cache: if a variable is named, it should not be cached, or the name may |
298 | // have conflict. |
299 | Value int_immediate_number(const SType &dtype, |
300 | int64_t value, |
301 | bool cache = true); |
302 | Value uint_immediate_number(const SType &dtype, |
303 | uint64_t value, |
304 | bool cache = true); |
305 | Value float_immediate_number(const SType &dtype, |
306 | double value, |
307 | bool cache = true); |
308 | |
309 | // Match zero type |
310 | Value get_zero(const SType &stype) { |
311 | TI_ASSERT(stype.flag == TypeKind::kPrimitive); |
312 | if (is_integral(stype.dt)) { |
313 | if (is_signed(stype.dt)) { |
314 | return int_immediate_number(stype, 0); |
315 | } else { |
316 | return uint_immediate_number(stype, 0); |
317 | } |
318 | } else if (is_real(stype.dt)) { |
319 | return float_immediate_number(stype, 0); |
320 | } else { |
321 | TI_NOT_IMPLEMENTED |
322 | return Value(); |
323 | } |
324 | } |
325 | |
326 | // Get null stype |
327 | SType get_null_type(); |
328 | // Get the spirv type for a given Taichi data type |
329 | SType get_primitive_type(const DataType &dt) const; |
330 | // Get the size in bytes of a given Taichi data type |
331 | size_t get_primitive_type_size(const DataType &dt) const; |
332 | // Get the spirv uint type with the same size of a given Taichi data type |
333 | SType get_primitive_uint_type(const DataType &dt) const; |
334 | // Get the Taichi uint type with the same size of a given Taichi data type |
335 | DataType get_taichi_uint_type(const DataType &dt) const; |
336 | // Get the pointer type that points to value_type |
337 | SType get_storage_pointer_type(const SType &value_type); |
338 | // Get the pointer type that points to value_type |
339 | SType get_pointer_type(const SType &value_type, |
340 | spv::StorageClass storage_class); |
341 | // Get an image type |
342 | SType get_sampled_image_type(const SType &primitive_type, int num_dimensions); |
343 | SType get_underlying_image_type(const SType &primitive_type, |
344 | int num_dimensions); |
345 | SType get_storage_image_type(BufferFormat format, int num_dimensions); |
346 | // Get a value_type[num_elems] type |
347 | SType get_array_type(const SType &value_type, uint32_t num_elems); |
348 | // Get a struct{ value_type[num_elems] } type |
349 | SType get_struct_array_type(const SType &value_type, uint32_t num_elems); |
350 | // Construct a struct type |
351 | SType create_struct_type( |
352 | std::vector<std::tuple<SType, std::string, size_t>> &components); |
353 | |
354 | // Declare buffer argument of function |
355 | Value buffer_struct_argument(const SType &struct_type, |
356 | uint32_t descriptor_set, |
357 | uint32_t binding, |
358 | const std::string &name); |
359 | Value uniform_struct_argument(const SType &struct_type, |
360 | uint32_t descriptor_set, |
361 | uint32_t binding, |
362 | const std::string &name); |
363 | Value buffer_argument(const SType &value_type, |
364 | uint32_t descriptor_set, |
365 | uint32_t binding, |
366 | const std::string &name); |
367 | Value struct_array_access(const SType &res_type, Value buffer, Value index); |
368 | |
369 | Value texture_argument(int num_channels, |
370 | int num_dimensions, |
371 | uint32_t descriptor_set, |
372 | uint32_t binding); |
373 | |
374 | Value storage_image_argument(int num_channels, |
375 | int num_dimensions, |
376 | uint32_t descriptor_set, |
377 | uint32_t binding, |
378 | BufferFormat format); |
379 | |
380 | Value sample_texture(Value texture_var, |
381 | const std::vector<Value> &args, |
382 | Value lod); |
383 | |
384 | Value fetch_texel(Value texture_var, |
385 | const std::vector<Value> &args, |
386 | Value lod); |
387 | |
388 | Value image_load(Value image_var, const std::vector<Value> &args); |
389 | |
390 | void image_store(Value image_var, const std::vector<Value> &args); |
391 | |
392 | // Declare a new function |
393 | // NOTE: only support void kernel function, i.e. main |
394 | Value new_function() { |
395 | return new_value(t_void_func_, ValueKind::kFunction); |
396 | } |
397 | |
398 | std::vector<Value> global_values; |
399 | |
400 | // Declare the entry point for a kernel function |
401 | void commit_kernel_function(const Value &func, |
402 | const std::string &name, |
403 | std::vector<Value> args, |
404 | std::array<int, 3> local_size) { |
405 | ib_.begin(spv::OpEntryPoint) |
406 | .add_seq(spv::ExecutionModelGLCompute, func, name); |
407 | for (const auto &arg : args) { |
408 | ib_.add(arg); |
409 | } |
410 | if (caps_->get(DeviceCapability::spirv_version) >= 0x10400) { |
411 | for (const auto &v : global_values) { |
412 | ib_.add(v); |
413 | } |
414 | } |
415 | if (gl_global_invocation_id_.id != 0) { |
416 | ib_.add(gl_global_invocation_id_); |
417 | } |
418 | if (gl_num_work_groups_.id != 0) { |
419 | ib_.add(gl_num_work_groups_); |
420 | } |
421 | ib_.commit(&entry_); |
422 | ib_.begin(spv::OpExecutionMode) |
423 | .add_seq(func, spv::ExecutionModeLocalSize, local_size[0], |
424 | local_size[1], local_size[2]) |
425 | .commit(&entry_); |
426 | } |
427 | |
428 | // Start function scope |
429 | void start_function(const Value &func) { |
430 | // add function declaration to the header |
431 | ib_.begin(spv::OpFunction) |
432 | .add_seq(t_void_, func, 0, t_void_func_) |
433 | .commit(&func_header_); |
434 | |
435 | spirv::Label start_label = this->new_label(); |
436 | ib_.begin(spv::OpLabel).add_seq(start_label).commit(&func_header_); |
437 | curr_label_ = start_label; |
438 | } |
439 | |
440 | // Declare gl compute shader related methods |
441 | void set_work_group_size(const std::array<int, 3> group_size); |
442 | Value get_work_group_size(uint32_t dim_index); |
443 | Value get_num_work_groups(uint32_t dim_index); |
444 | Value get_local_invocation_id(uint32_t dim_index); |
445 | Value get_global_invocation_id(uint32_t dim_index); |
446 | Value get_subgroup_invocation_id(); |
447 | Value get_subgroup_size(); |
448 | |
449 | // Expressions |
450 | Value add(Value a, Value b); |
451 | Value sub(Value a, Value b); |
452 | Value mul(Value a, Value b); |
453 | Value div(Value a, Value b); |
454 | Value mod(Value a, Value b); |
455 | Value eq(Value a, Value b); |
456 | Value ne(Value a, Value b); |
457 | Value lt(Value a, Value b); |
458 | Value le(Value a, Value b); |
459 | Value gt(Value a, Value b); |
460 | Value ge(Value a, Value b); |
461 | Value (Value base, Value offset, Value count); |
462 | Value select(Value cond, Value a, Value b); |
463 | |
464 | // Create a cast that cast value to dst_type |
465 | Value cast(const SType &dst_type, Value value); |
466 | |
467 | // Create a GLSL450 call |
468 | template <typename... Args> |
469 | Value call_glsl450(const SType &ret_type, uint32_t inst_id, Args &&...args) { |
470 | Value val = new_value(ret_type, ValueKind::kNormal); |
471 | ib_.begin(spv::OpExtInst) |
472 | .add_seq(ret_type, val, ext_glsl450_, inst_id) |
473 | .add_seq(std::forward<Args>(args)...) |
474 | .commit(&function_); |
475 | return val; |
476 | } |
477 | |
478 | // Create a debugPrintf call |
479 | void call_debugprintf(std::string formats, const std::vector<Value> &args) { |
480 | Value format_str = debug_string(formats); |
481 | Value val = new_value(t_void_, ValueKind::kNormal); |
482 | ib_.begin(spv::OpExtInst) |
483 | .add_seq(t_void_, val, debug_printf_, 1, format_str); |
484 | for (const auto &arg : args) { |
485 | ib_.add(arg); |
486 | } |
487 | ib_.commit(&function_); |
488 | } |
489 | |
490 | // Local allocate, load, store methods |
491 | Value alloca_variable(const SType &type); |
492 | Value alloca_workgroup_array(const SType &type); |
493 | Value load_variable(Value pointer, const SType &res_type); |
494 | void store_variable(Value pointer, Value value); |
495 | |
496 | // Register name to corresponding Value/VariablePointer |
497 | void register_value(std::string name, Value value); |
498 | // Query Value/VariablePointer by name |
499 | Value query_value(std::string name) const; |
500 | // Check whether a value has been evaluated |
501 | bool check_value_existence(const std::string &name) const; |
502 | // Create a new SSA value |
503 | Value new_value(const SType &type, ValueKind flag) { |
504 | Value val; |
505 | val.id = id_counter_++; |
506 | val.stype = type; |
507 | val.flag = flag; |
508 | return val; |
509 | } |
510 | |
511 | // Support easy access to trivial data types |
512 | SType i64_type() const { |
513 | return t_int64_; |
514 | } |
515 | SType u64_type() const { |
516 | return t_uint64_; |
517 | } |
518 | SType f64_type() const { |
519 | return t_fp64_; |
520 | } |
521 | |
522 | SType i32_type() const { |
523 | return t_int32_; |
524 | } |
525 | SType u32_type() const { |
526 | return t_uint32_; |
527 | } |
528 | SType f32_type() const { |
529 | return t_fp32_; |
530 | } |
531 | |
532 | SType i16_type() const { |
533 | return t_int16_; |
534 | } |
535 | SType u16_type() const { |
536 | return t_uint16_; |
537 | } |
538 | SType f16_type() const { |
539 | return t_fp16_; |
540 | } |
541 | |
542 | SType i8_type() const { |
543 | return t_int8_; |
544 | } |
545 | SType u8_type() const { |
546 | return t_uint8_; |
547 | } |
548 | |
549 | SType bool_type() const { |
550 | return t_bool_; |
551 | } |
552 | |
553 | // quick cache for const zero/one i32 |
554 | Value const_i32_zero_; |
555 | Value const_i32_one_; |
556 | |
557 | // Use force-inline float atomic helper function |
558 | Value float_atomic(AtomicOpType op_type, Value addr_ptr, Value data); |
559 | Value rand_u32(Value global_tmp_); |
560 | Value rand_f32(Value global_tmp_); |
561 | Value rand_i32(Value global_tmp_); |
562 | |
563 | private: |
564 | Value get_const(const SType &dtype, const uint64_t *pvalue, bool cache); |
565 | SType declare_primitive_type(DataType dt); |
566 | |
567 | void init_random_function(Value global_tmp_); |
568 | |
569 | Arch arch_; |
570 | const DeviceCapabilityConfig *caps_; |
571 | |
572 | // internal instruction builder |
573 | InstrBuilder ib_; |
574 | // Current label |
575 | Label curr_label_; |
576 | // The current maximum id |
577 | uint32_t id_counter_{1}; |
578 | |
579 | // glsl 450 extension |
580 | Value ext_glsl450_; |
581 | |
582 | // debugprint extension |
583 | Value debug_printf_; |
584 | |
585 | SType t_bool_; |
586 | SType t_int8_; |
587 | SType t_int16_; |
588 | SType t_int32_; |
589 | SType t_int64_; |
590 | SType t_uint8_; |
591 | SType t_uint16_; |
592 | SType t_uint32_; |
593 | SType t_uint64_; |
594 | SType t_fp16_; |
595 | SType t_fp32_; |
596 | SType t_fp64_; |
597 | SType t_void_; |
598 | SType t_void_func_; |
599 | // gl compute shader related type(s) and variables |
600 | SType t_v2_int_; |
601 | SType t_v3_int_; |
602 | SType t_v3_uint_; |
603 | SType t_v4_fp32_; |
604 | SType t_v3_fp32_; |
605 | SType t_v2_fp32_; |
606 | Value gl_global_invocation_id_; |
607 | Value gl_local_invocation_id_; |
608 | Value gl_num_work_groups_; |
609 | Value gl_work_group_size_; |
610 | Value subgroup_local_invocation_id_; |
611 | Value subgroup_size_; |
612 | |
613 | // Random function and variables |
614 | bool init_rand_{false}; |
615 | Value rand_x_; |
616 | Value rand_y_; |
617 | Value rand_z_; |
618 | Value rand_w_; // per-thread local variable |
619 | |
620 | // map from value to its pointer type |
621 | std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_; |
622 | std::map<std::pair<uint32_t, int>, SType> sampled_image_ptr_tbl_; |
623 | std::map<std::pair<uint32_t, int>, SType> |
624 | sampled_image_underlying_image_type_; |
625 | |
626 | std::map<std::pair<BufferFormat, int>, SType> storage_image_ptr_tbl_; |
627 | |
628 | // map from constant int to its value |
629 | std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_; |
630 | // map from raw_name(string) to Value |
631 | std::unordered_map<std::string, Value> value_name_tbl_; |
632 | |
633 | // Header segment, include import |
634 | std::vector<uint32_t> ; |
635 | // engtry point segment |
636 | std::vector<uint32_t> entry_; |
637 | // Header segment |
638 | std::vector<uint32_t> exec_mode_; |
639 | // Debug segment |
640 | // According to SPIR-V spec, the following debug instructions must be |
641 | // grouped in the order: |
642 | // - All OpString, OpSourceExtension, OpSource, and OpSourceContinued, |
643 | // without forward references. |
644 | // - All OpName and all OpMemberName. |
645 | // - All OpModuleProcessed instructions. |
646 | |
647 | // OpString segment |
648 | std::vector<uint32_t> strings_; |
649 | // OpName segment |
650 | std::vector<uint32_t> names_; |
651 | // Annotation segment |
652 | std::vector<uint32_t> decorate_; |
653 | // Global segment: types, variables, types |
654 | std::vector<uint32_t> global_; |
655 | // Function header segment |
656 | std::vector<uint32_t> ; |
657 | // Main Function segment |
658 | std::vector<uint32_t> function_; |
659 | }; |
660 | } // namespace spirv |
661 | } // namespace taichi::lang |
662 | |