1#pragma once
2
#include <array>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include <spirv/unified1/spirv.hpp>
#include "taichi/util/lang_util.h"
#include "taichi/ir/type.h"
#include "taichi/util/testing.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/rhi/device.h"
#include "taichi/ir/statements.h"
12
13namespace taichi::lang {
14namespace spirv {
15
// Compile-time recursion helper for for_each(). Each level invokes
// `fn(I, head)` on its first argument, then hands the remaining arguments to
// the next level with index I + 1. `stop` becomes true once no arguments
// remain, which selects the terminal specialization below.
template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
  template <typename Head, typename... Tail>
  static void run(const F &fn, Head &&head, Tail &&...tail) {  // NOLINT(*)
    fn(I, std::forward<Head>(head));
    constexpr bool kLast = sizeof...(Tail) == 0;
    for_each_dispatcher<kLast, I + 1, F>::run(fn, std::forward<Tail>(tail)...);
  }
};

// Terminal case: every argument has been visited, so there is nothing to do.
template <std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
  static void run(const F & /*fn*/) {  // NOLINT(*)
  }
};

// Invoke `f(i, arg_i)` for every argument in left-to-right order, where `i` is
// the argument's 0-based position. Arguments are perfectly forwarded to `f`.
template <typename F, typename... Args>
inline void for_each(const F &f, Args &&...args) {  // NOLINT(*)
  constexpr bool kEmpty = sizeof...(Args) == 0;
  for_each_dispatcher<kEmpty, 0, F>::run(f, std::forward<Args>(args)...);
}
37
// Category of an SType. Values are spelled out explicitly and match the
// implicit numbering (0, 1, 2, ...).
enum class TypeKind {
  kPrimitive = 0,
  kSNodeStruct = 1,
  kSNodeArray = 2,  // array components of a kSNodeStruct
  kStruct = 3,
  kPtr = 4,
  kFunc = 5,
  kImage = 6,
};
47
48// Represent the SPIRV Type
49struct SType {
50 // The Id to represent type
51 uint32_t id{0};
52
53 // corresponding Taichi type/Compiled SNode info
54 DataType dt;
55
56 SNodeDescriptor snode_desc; // TODO: dt/snode_desc only need one at a time
57 std::vector<uint32_t> snode_child_type_id;
58
59 TypeKind flag{TypeKind::kPrimitive};
60
61 // Content type id if it is a pointer/struct-array class
62 // TODO: SNODE need a vector to store their childrens' element type id
63 uint32_t element_type_id{0};
64
65 // The storage class, if it is a pointer
66 spv::StorageClass storage_class{spv::StorageClassMax};
67};
68
// Additional information attached to a Value describing how it was produced.
// Values are spelled out explicitly and match the implicit numbering.
enum class ValueKind {
  kNormal = 0,
  kConstant = 1,
  kVectorPtr = 2,
  kStructArrayPtr = 3,
  kVariablePtr = 4,  // assigned by IRBuilder::make_value to pointer results
  kPhysicalPtr = 5,
  kTexture = 6,
  kFunction = 7,  // used by IRBuilder::new_function
  kExtInst = 8,   // used by IRBuilder::ext_inst_import
};
80
81// Represent the SPIRV Value
82struct Value {
83 // The Id to represent type
84 uint32_t id{0};
85 // The data type
86 SType stype;
87 // Additional flags about the value
88 ValueKind flag{ValueKind::kNormal};
89
90 bool operator==(const Value &rhs) const {
91 return id == rhs.id;
92 }
93};
94
95struct ValueHasher {
96 size_t operator()(const spirv::Value &v) const {
97 return std::hash<uint32_t>()(v.id);
98 }
99};
100
// A SPIR-V label, i.e. the <id> of a basic block.
struct Label {
  // Result <id> of the label. IRBuilder hands out ids starting at 1, so 0
  // means the label has not been assigned yet.
  uint32_t id = 0;
};
106
107// A SPIRV instruction,
108// can be used as handle to modify its content later
109class Instr {
110 public:
111 uint32_t word_count() const {
112 return word_count_;
113 }
114
115 uint32_t &operator[](uint32_t idx) {
116 TI_ASSERT(idx < word_count_);
117 return (*data_)[begin_ + idx];
118 }
119
120 private:
121 friend class InstrBuilder;
122
123 std::vector<uint32_t> *data_{nullptr};
124 uint32_t begin_{0};
125 uint32_t word_count_{0};
126};
127
128// Representation of phi value
129struct PhiValue : public Value {
130 Instr instr;
131
132 void set_incoming(uint32_t index, const Value &value, const Label &parent) {
133 TI_ASSERT(this->stype.id == value.stype.id);
134 instr[3 + index * 2] = value.id;
135 instr[3 + index * 2 + 1] = parent.id;
136 }
137};
138
139// Helper class to build SPIRV instruction
140class InstrBuilder {
141 public:
142 InstrBuilder &begin(spv::Op op) {
143 TI_ASSERT(data_.size() == 0U);
144 op_ = op;
145 data_.push_back(0);
146 return *this;
147 }
148
149#define ADD(var, id) \
150 InstrBuilder &add(const var &v) { \
151 data_.push_back(id); \
152 return *this; \
153 }
154
155 ADD(Value, v.id);
156 ADD(SType, v.id);
157 ADD(Label, v.id);
158 ADD(uint32_t, v);
159#undef ADD
160
161 InstrBuilder &add(const std::vector<uint32_t> &v) {
162 for (const auto &v0 : v) {
163 add(v0);
164 }
165 return *this;
166 }
167
168 InstrBuilder &add(const std::string &v) {
169 const uint32_t word_size = sizeof(uint32_t);
170 const auto nwords =
171 (static_cast<uint32_t>(v.length()) + word_size) / word_size;
172 size_t begin = data_.size();
173 data_.resize(begin + nwords, 0U);
174 std::copy(v.begin(), v.end(), reinterpret_cast<char *>(&data_[begin]));
175 return *this;
176 }
177
178 template <typename... Args>
179 InstrBuilder &add_seq(Args &&...args) {
180 AddSeqHelper helper;
181 helper.builder = this;
182 for_each(helper, std::forward<Args>(args)...);
183 return *this;
184 }
185
186 Instr commit(std::vector<uint32_t> *seg) {
187 Instr ret;
188 ret.data_ = seg;
189 ret.begin_ = seg->size();
190 ret.word_count_ = static_cast<uint32_t>(data_.size());
191 data_[0] = op_ | (ret.word_count_ << spv::WordCountShift);
192 seg->insert(seg->end(), data_.begin(), data_.end());
193 data_.clear();
194 return ret;
195 }
196
197 private:
198 // current op code
199 spv::Op op_;
200 // The internal data to store code
201 std::vector<uint32_t> data_;
202 // helper class to support variadic arguments
203 struct AddSeqHelper {
204 // The reference to builder
205 InstrBuilder *builder;
206 // invoke function
207 template <typename T>
208 void operator()(size_t, const T &v) const {
209 builder->add(v);
210 }
211 };
212};
213
214// Builder to build up a single SPIR-V module
215class IRBuilder {
216 public:
217 IRBuilder(Arch arch, const DeviceCapabilityConfig *caps)
218 : arch_(arch), caps_(caps) {
219 }
220
221 template <typename... Args>
222 void debug_name(spv::Op op, Args &&...args) {
223 ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&names_);
224 }
225
226 Value debug_string(std::string str);
227
228 template <typename... Args>
229 void execution_mode(Value func, Args &&...args) {
230 ib_.begin(spv::OpExecutionMode)
231 .add_seq(func, std::forward<Args>(args)...)
232 .commit(&exec_mode_);
233 }
234
235 template <typename... Args>
236 void decorate(spv::Op op, Args &&...args) {
237 ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&decorate_);
238 }
239
240 template <typename... Args>
241 void declare_global(spv::Op op, Args &&...args) {
242 ib_.begin(op).add_seq(std::forward<Args>(args)...).commit(&global_);
243 }
244
245 template <typename... Args>
246 Instr make_inst(spv::Op op, Args &&...args) {
247 return ib_.begin(op)
248 .add_seq(std::forward<Args>(args)...)
249 .commit(&function_);
250 }
251
252 // Initialize header
253 void init_header();
254 // Initialize the predefined contents
255 void init_pre_defs();
256 // Get the final binary built from the builder, return The finalized binary
257 // instruction
258 std::vector<uint32_t> finalize();
259
260 Value ext_inst_import(const std::string &name) {
261 Value val = new_value(SType(), ValueKind::kExtInst);
262 ib_.begin(spv::OpExtInstImport).add_seq(val, name).commit(&header_);
263 return val;
264 }
265
266 Label new_label() {
267 Label label;
268 label.id = id_counter_++;
269 return label;
270 }
271
272 // Start a new block with given label
273 void start_label(Label label) {
274 make_inst(spv::OpLabel, label);
275 curr_label_ = label;
276 }
277
278 Label current_label() const {
279 return curr_label_;
280 }
281
282 // Make a new SSA value
283 template <typename... Args>
284 Value make_value(spv::Op op, const SType &out_type, Args &&...args) {
285 Value val = new_value(out_type, ValueKind::kNormal);
286 make_inst(op, out_type, val, std::forward<Args>(args)...);
287 if (out_type.flag == TypeKind::kPtr) {
288 val.flag = ValueKind::kVariablePtr;
289 }
290 return val;
291 }
292
293 // Make a phi value
294 PhiValue make_phi(const SType &out_type, uint32_t num_incoming);
295
296 // Create Constant Primitive Value
297 // cache: if a variable is named, it should not be cached, or the name may
298 // have conflict.
299 Value int_immediate_number(const SType &dtype,
300 int64_t value,
301 bool cache = true);
302 Value uint_immediate_number(const SType &dtype,
303 uint64_t value,
304 bool cache = true);
305 Value float_immediate_number(const SType &dtype,
306 double value,
307 bool cache = true);
308
309 // Match zero type
310 Value get_zero(const SType &stype) {
311 TI_ASSERT(stype.flag == TypeKind::kPrimitive);
312 if (is_integral(stype.dt)) {
313 if (is_signed(stype.dt)) {
314 return int_immediate_number(stype, 0);
315 } else {
316 return uint_immediate_number(stype, 0);
317 }
318 } else if (is_real(stype.dt)) {
319 return float_immediate_number(stype, 0);
320 } else {
321 TI_NOT_IMPLEMENTED
322 return Value();
323 }
324 }
325
326 // Get null stype
327 SType get_null_type();
328 // Get the spirv type for a given Taichi data type
329 SType get_primitive_type(const DataType &dt) const;
330 // Get the size in bytes of a given Taichi data type
331 size_t get_primitive_type_size(const DataType &dt) const;
332 // Get the spirv uint type with the same size of a given Taichi data type
333 SType get_primitive_uint_type(const DataType &dt) const;
334 // Get the Taichi uint type with the same size of a given Taichi data type
335 DataType get_taichi_uint_type(const DataType &dt) const;
336 // Get the pointer type that points to value_type
337 SType get_storage_pointer_type(const SType &value_type);
338 // Get the pointer type that points to value_type
339 SType get_pointer_type(const SType &value_type,
340 spv::StorageClass storage_class);
341 // Get an image type
342 SType get_sampled_image_type(const SType &primitive_type, int num_dimensions);
343 SType get_underlying_image_type(const SType &primitive_type,
344 int num_dimensions);
345 SType get_storage_image_type(BufferFormat format, int num_dimensions);
346 // Get a value_type[num_elems] type
347 SType get_array_type(const SType &value_type, uint32_t num_elems);
348 // Get a struct{ value_type[num_elems] } type
349 SType get_struct_array_type(const SType &value_type, uint32_t num_elems);
350 // Construct a struct type
351 SType create_struct_type(
352 std::vector<std::tuple<SType, std::string, size_t>> &components);
353
354 // Declare buffer argument of function
355 Value buffer_struct_argument(const SType &struct_type,
356 uint32_t descriptor_set,
357 uint32_t binding,
358 const std::string &name);
359 Value uniform_struct_argument(const SType &struct_type,
360 uint32_t descriptor_set,
361 uint32_t binding,
362 const std::string &name);
363 Value buffer_argument(const SType &value_type,
364 uint32_t descriptor_set,
365 uint32_t binding,
366 const std::string &name);
367 Value struct_array_access(const SType &res_type, Value buffer, Value index);
368
369 Value texture_argument(int num_channels,
370 int num_dimensions,
371 uint32_t descriptor_set,
372 uint32_t binding);
373
374 Value storage_image_argument(int num_channels,
375 int num_dimensions,
376 uint32_t descriptor_set,
377 uint32_t binding,
378 BufferFormat format);
379
380 Value sample_texture(Value texture_var,
381 const std::vector<Value> &args,
382 Value lod);
383
384 Value fetch_texel(Value texture_var,
385 const std::vector<Value> &args,
386 Value lod);
387
388 Value image_load(Value image_var, const std::vector<Value> &args);
389
390 void image_store(Value image_var, const std::vector<Value> &args);
391
392 // Declare a new function
393 // NOTE: only support void kernel function, i.e. main
394 Value new_function() {
395 return new_value(t_void_func_, ValueKind::kFunction);
396 }
397
398 std::vector<Value> global_values;
399
400 // Declare the entry point for a kernel function
401 void commit_kernel_function(const Value &func,
402 const std::string &name,
403 std::vector<Value> args,
404 std::array<int, 3> local_size) {
405 ib_.begin(spv::OpEntryPoint)
406 .add_seq(spv::ExecutionModelGLCompute, func, name);
407 for (const auto &arg : args) {
408 ib_.add(arg);
409 }
410 if (caps_->get(DeviceCapability::spirv_version) >= 0x10400) {
411 for (const auto &v : global_values) {
412 ib_.add(v);
413 }
414 }
415 if (gl_global_invocation_id_.id != 0) {
416 ib_.add(gl_global_invocation_id_);
417 }
418 if (gl_num_work_groups_.id != 0) {
419 ib_.add(gl_num_work_groups_);
420 }
421 ib_.commit(&entry_);
422 ib_.begin(spv::OpExecutionMode)
423 .add_seq(func, spv::ExecutionModeLocalSize, local_size[0],
424 local_size[1], local_size[2])
425 .commit(&entry_);
426 }
427
428 // Start function scope
429 void start_function(const Value &func) {
430 // add function declaration to the header
431 ib_.begin(spv::OpFunction)
432 .add_seq(t_void_, func, 0, t_void_func_)
433 .commit(&func_header_);
434
435 spirv::Label start_label = this->new_label();
436 ib_.begin(spv::OpLabel).add_seq(start_label).commit(&func_header_);
437 curr_label_ = start_label;
438 }
439
440 // Declare gl compute shader related methods
441 void set_work_group_size(const std::array<int, 3> group_size);
442 Value get_work_group_size(uint32_t dim_index);
443 Value get_num_work_groups(uint32_t dim_index);
444 Value get_local_invocation_id(uint32_t dim_index);
445 Value get_global_invocation_id(uint32_t dim_index);
446 Value get_subgroup_invocation_id();
447 Value get_subgroup_size();
448
449 // Expressions
450 Value add(Value a, Value b);
451 Value sub(Value a, Value b);
452 Value mul(Value a, Value b);
453 Value div(Value a, Value b);
454 Value mod(Value a, Value b);
455 Value eq(Value a, Value b);
456 Value ne(Value a, Value b);
457 Value lt(Value a, Value b);
458 Value le(Value a, Value b);
459 Value gt(Value a, Value b);
460 Value ge(Value a, Value b);
461 Value bit_field_extract(Value base, Value offset, Value count);
462 Value select(Value cond, Value a, Value b);
463
464 // Create a cast that cast value to dst_type
465 Value cast(const SType &dst_type, Value value);
466
467 // Create a GLSL450 call
468 template <typename... Args>
469 Value call_glsl450(const SType &ret_type, uint32_t inst_id, Args &&...args) {
470 Value val = new_value(ret_type, ValueKind::kNormal);
471 ib_.begin(spv::OpExtInst)
472 .add_seq(ret_type, val, ext_glsl450_, inst_id)
473 .add_seq(std::forward<Args>(args)...)
474 .commit(&function_);
475 return val;
476 }
477
478 // Create a debugPrintf call
479 void call_debugprintf(std::string formats, const std::vector<Value> &args) {
480 Value format_str = debug_string(formats);
481 Value val = new_value(t_void_, ValueKind::kNormal);
482 ib_.begin(spv::OpExtInst)
483 .add_seq(t_void_, val, debug_printf_, 1, format_str);
484 for (const auto &arg : args) {
485 ib_.add(arg);
486 }
487 ib_.commit(&function_);
488 }
489
490 // Local allocate, load, store methods
491 Value alloca_variable(const SType &type);
492 Value alloca_workgroup_array(const SType &type);
493 Value load_variable(Value pointer, const SType &res_type);
494 void store_variable(Value pointer, Value value);
495
496 // Register name to corresponding Value/VariablePointer
497 void register_value(std::string name, Value value);
498 // Query Value/VariablePointer by name
499 Value query_value(std::string name) const;
500 // Check whether a value has been evaluated
501 bool check_value_existence(const std::string &name) const;
502 // Create a new SSA value
503 Value new_value(const SType &type, ValueKind flag) {
504 Value val;
505 val.id = id_counter_++;
506 val.stype = type;
507 val.flag = flag;
508 return val;
509 }
510
511 // Support easy access to trivial data types
512 SType i64_type() const {
513 return t_int64_;
514 }
515 SType u64_type() const {
516 return t_uint64_;
517 }
518 SType f64_type() const {
519 return t_fp64_;
520 }
521
522 SType i32_type() const {
523 return t_int32_;
524 }
525 SType u32_type() const {
526 return t_uint32_;
527 }
528 SType f32_type() const {
529 return t_fp32_;
530 }
531
532 SType i16_type() const {
533 return t_int16_;
534 }
535 SType u16_type() const {
536 return t_uint16_;
537 }
538 SType f16_type() const {
539 return t_fp16_;
540 }
541
542 SType i8_type() const {
543 return t_int8_;
544 }
545 SType u8_type() const {
546 return t_uint8_;
547 }
548
549 SType bool_type() const {
550 return t_bool_;
551 }
552
553 // quick cache for const zero/one i32
554 Value const_i32_zero_;
555 Value const_i32_one_;
556
557 // Use force-inline float atomic helper function
558 Value float_atomic(AtomicOpType op_type, Value addr_ptr, Value data);
559 Value rand_u32(Value global_tmp_);
560 Value rand_f32(Value global_tmp_);
561 Value rand_i32(Value global_tmp_);
562
563 private:
564 Value get_const(const SType &dtype, const uint64_t *pvalue, bool cache);
565 SType declare_primitive_type(DataType dt);
566
567 void init_random_function(Value global_tmp_);
568
569 Arch arch_;
570 const DeviceCapabilityConfig *caps_;
571
572 // internal instruction builder
573 InstrBuilder ib_;
574 // Current label
575 Label curr_label_;
576 // The current maximum id
577 uint32_t id_counter_{1};
578
579 // glsl 450 extension
580 Value ext_glsl450_;
581
582 // debugprint extension
583 Value debug_printf_;
584
585 SType t_bool_;
586 SType t_int8_;
587 SType t_int16_;
588 SType t_int32_;
589 SType t_int64_;
590 SType t_uint8_;
591 SType t_uint16_;
592 SType t_uint32_;
593 SType t_uint64_;
594 SType t_fp16_;
595 SType t_fp32_;
596 SType t_fp64_;
597 SType t_void_;
598 SType t_void_func_;
599 // gl compute shader related type(s) and variables
600 SType t_v2_int_;
601 SType t_v3_int_;
602 SType t_v3_uint_;
603 SType t_v4_fp32_;
604 SType t_v3_fp32_;
605 SType t_v2_fp32_;
606 Value gl_global_invocation_id_;
607 Value gl_local_invocation_id_;
608 Value gl_num_work_groups_;
609 Value gl_work_group_size_;
610 Value subgroup_local_invocation_id_;
611 Value subgroup_size_;
612
613 // Random function and variables
614 bool init_rand_{false};
615 Value rand_x_;
616 Value rand_y_;
617 Value rand_z_;
618 Value rand_w_; // per-thread local variable
619
620 // map from value to its pointer type
621 std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
622 std::map<std::pair<uint32_t, int>, SType> sampled_image_ptr_tbl_;
623 std::map<std::pair<uint32_t, int>, SType>
624 sampled_image_underlying_image_type_;
625
626 std::map<std::pair<BufferFormat, int>, SType> storage_image_ptr_tbl_;
627
628 // map from constant int to its value
629 std::map<std::pair<uint32_t, uint64_t>, Value> const_tbl_;
630 // map from raw_name(string) to Value
631 std::unordered_map<std::string, Value> value_name_tbl_;
632
633 // Header segment, include import
634 std::vector<uint32_t> header_;
635 // engtry point segment
636 std::vector<uint32_t> entry_;
637 // Header segment
638 std::vector<uint32_t> exec_mode_;
639 // Debug segment
640 // According to SPIR-V spec, the following debug instructions must be
641 // grouped in the order:
642 // - All OpString, OpSourceExtension, OpSource, and OpSourceContinued,
643 // without forward references.
644 // - All OpName and all OpMemberName.
645 // - All OpModuleProcessed instructions.
646
647 // OpString segment
648 std::vector<uint32_t> strings_;
649 // OpName segment
650 std::vector<uint32_t> names_;
651 // Annotation segment
652 std::vector<uint32_t> decorate_;
653 // Global segment: types, variables, types
654 std::vector<uint32_t> global_;
655 // Function header segment
656 std::vector<uint32_t> func_header_;
657 // Main Function segment
658 std::vector<uint32_t> function_;
659};
660} // namespace spirv
661} // namespace taichi::lang
662