#include "taichi/codegen/spirv/spirv_ir_builder.h"
#include "taichi/rhi/dx/dx_device.h"

namespace taichi::lang {

namespace spirv {

using cap = DeviceCapability;

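// Emits the SPIR-V module header: magic number, version word, the
// capabilities and extensions implied by the reported device capabilities,
// and the memory model. Called once before any other instruction is emitted.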
void IRBuilder::init_header() {
  TI_ASSERT(header_.size() == 0U);
  header_.push_back(spv::MagicNumber);

  header_.push_back(caps_->get(cap::spirv_version));

  TI_TRACE("SPIR-V Version {}", caps_->get(cap::spirv_version));

  // generator: set to 0, unknown
  header_.push_back(0U);
  // Bound: set during Finalize
  header_.push_back(0U);
  // Schema: reserved
  header_.push_back(0U);

  // capability
  ib_.begin(spv::OpCapability).add(spv::CapabilityShader).commit(&header_);

  if (caps_->get(cap::spirv_has_atomic_float64_add)) {
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityAtomicFloat64AddEXT)
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_atomic_float_add)) {
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityAtomicFloat32AddEXT)
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_atomic_float_minmax)) {
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityAtomicFloat32MinMaxEXT)
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_variable_ptr)) {
    /*
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityVariablePointers)
        .commit(&header_);
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityVariablePointersStorageBuffer)
        .commit(&header_);
    */
  }

  if (caps_->get(cap::spirv_has_int8)) {
    ib_.begin(spv::OpCapability).add(spv::CapabilityInt8).commit(&header_);
  }
  if (caps_->get(cap::spirv_has_int16)) {
    ib_.begin(spv::OpCapability).add(spv::CapabilityInt16).commit(&header_);
  }
  if (caps_->get(cap::spirv_has_int64)) {
    ib_.begin(spv::OpCapability).add(spv::CapabilityInt64).commit(&header_);
  }
  if (caps_->get(cap::spirv_has_float16)) {
    ib_.begin(spv::OpCapability).add(spv::CapabilityFloat16).commit(&header_);
  }
  if (caps_->get(cap::spirv_has_float64)) {
    ib_.begin(spv::OpCapability).add(spv::CapabilityFloat64).commit(&header_);
  }
  if (caps_->get(cap::spirv_has_physical_storage_buffer)) {
    ib_.begin(spv::OpCapability)
        .add(spv::CapabilityPhysicalStorageBufferAddresses)
        .commit(&header_);
  }

  ib_.begin(spv::OpExtension)
      .add("SPV_KHR_storage_buffer_storage_class")
      .commit(&header_);

  if (caps_->get(cap::spirv_has_no_integer_wrap_decoration)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_KHR_no_integer_wrap_decoration")
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_non_semantic_info)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_KHR_non_semantic_info")
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_variable_ptr)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_KHR_variable_pointers")
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_atomic_float_add)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_EXT_shader_atomic_float_add")
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_atomic_float_minmax)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_EXT_shader_atomic_float_min_max")
        .commit(&header_);
  }

  if (caps_->get(cap::spirv_has_physical_storage_buffer)) {
    ib_.begin(spv::OpExtension)
        .add("SPV_KHR_physical_storage_buffer")
        .commit(&header_);

    // memory model
    ib_.begin(spv::OpMemoryModel)
        .add_seq(spv::AddressingModelPhysicalStorageBuffer64,
                 spv::MemoryModelGLSL450)
        .commit(&entry_);
  } else {
    ib_.begin(spv::OpMemoryModel)
        .add_seq(spv::AddressingModelLogical, spv::MemoryModelGLSL450)
        .commit(&entry_);
  }

  this->init_pre_defs();
}

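// Assembles the final SPIR-V word stream. Sections are concatenated in the
// order required by the SPIR-V logical module layout: header, entry points,
// execution modes, strings, names, decorations, globals, then function code.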
std::vector<uint32_t> IRBuilder::finalize() {
  std::vector<uint32_t> data;
  // set bound
  const int bound_loc = 3;
  header_[bound_loc] = id_counter_;
  data.insert(data.end(), header_.begin(), header_.end());
  data.insert(data.end(), entry_.begin(), entry_.end());
  data.insert(data.end(), exec_mode_.begin(), exec_mode_.end());
  data.insert(data.end(), strings_.begin(), strings_.end());
  data.insert(data.end(), names_.begin(), names_.end());
  data.insert(data.end(), decorate_.begin(), decorate_.end());
  data.insert(data.end(), global_.begin(), global_.end());
  data.insert(data.end(), func_header_.begin(), func_header_.end());
  data.insert(data.end(), function_.begin(), function_.end());
  return data;
}

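// Declares the extended instruction sets, the primitive and vector types, and
// the constants that the rest of the builder relies on. Only types whose
// capability is reported by the device are declared.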
void IRBuilder::init_pre_defs() {
  ext_glsl450_ = ext_inst_import("GLSL.std.450");
  if (caps_->get(cap::spirv_has_non_semantic_info)) {
    debug_printf_ = ext_inst_import("NonSemantic.DebugPrintf");
  }

  t_bool_ = declare_primitive_type(get_data_type<bool>());
  if (caps_->get(cap::spirv_has_int8)) {
    t_int8_ = declare_primitive_type(get_data_type<int8>());
    t_uint8_ = declare_primitive_type(get_data_type<uint8>());
  }
  if (caps_->get(cap::spirv_has_int16)) {
    t_int16_ = declare_primitive_type(get_data_type<int16>());
    t_uint16_ = declare_primitive_type(get_data_type<uint16>());
  }
  t_int32_ = declare_primitive_type(get_data_type<int32>());
  t_uint32_ = declare_primitive_type(get_data_type<uint32>());
  if (caps_->get(cap::spirv_has_int64)) {
    t_int64_ = declare_primitive_type(get_data_type<int64>());
    t_uint64_ = declare_primitive_type(get_data_type<uint64>());
  }
  t_fp32_ = declare_primitive_type(get_data_type<float32>());
  if (caps_->get(cap::spirv_has_float16)) {
    t_fp16_ = declare_primitive_type(PrimitiveType::f16);
  }
  if (caps_->get(cap::spirv_has_float64)) {
    t_fp64_ = declare_primitive_type(get_data_type<float64>());
  }
  // declare void and the void function type
  t_void_.id = id_counter_++;
  ib_.begin(spv::OpTypeVoid).add(t_void_).commit(&global_);
  t_void_func_.id = id_counter_++;
  ib_.begin(spv::OpTypeFunction)
      .add_seq(t_void_func_, t_void_)
      .commit(&global_);

  // compute shader related types
  t_v2_int_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v2_int_)
      .add_seq(t_int32_, 2)
      .commit(&global_);

  t_v3_int_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_int_)
      .add_seq(t_int32_, 3)
      .commit(&global_);

  t_v3_uint_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_uint_)
      .add_seq(t_uint32_, 3)
      .commit(&global_);

  t_v4_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v4_fp32_)
      .add_seq(t_fp32_, 4)
      .commit(&global_);

  t_v2_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v2_fp32_)
      .add_seq(t_fp32_, 2)
      .commit(&global_);

  t_v3_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_fp32_)
      .add_seq(t_fp32_, 3)
      .commit(&global_);

  // pre-defined constants
  const_i32_zero_ = int_immediate_number(t_int32_, 0);
  const_i32_one_ = int_immediate_number(t_int32_, 1);
}

Value IRBuilder::debug_string(std::string s) {
  Value val = new_value(SType(), ValueKind::kNormal);
  ib_.begin(spv::OpString).add_seq(val, s).commit(&strings_);
  return val;
}

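// Creates an OpPhi with `num_incoming` (value, parent-label) operand pairs.
// The pairs are emitted as zero placeholders and are expected to be patched
// later through the returned PhiValue's instruction handle.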
PhiValue IRBuilder::make_phi(const SType &out_type, uint32_t num_incoming) {
  Value val = new_value(out_type, ValueKind::kNormal);
  ib_.begin(spv::OpPhi).add_seq(out_type, val);
  for (uint32_t i = 0; i < 2 * num_incoming; ++i) {
    ib_.add(0);
  }

  PhiValue phi;
  phi.id = val.id;
  phi.stype = out_type;
  phi.flag = ValueKind::kNormal;
  phi.instr = ib_.commit(&function_);
  return phi;
}

Value IRBuilder::int_immediate_number(const SType &dtype,
                                      int64_t value,
                                      bool cache) {
  return get_const(dtype, reinterpret_cast<uint64_t *>(&value), cache);
}

Value IRBuilder::uint_immediate_number(const SType &dtype,
                                       uint64_t value,
                                       bool cache) {
  return get_const(dtype, &value, cache);
}

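// Floating-point immediates are routed through get_const(): the literal's bit
// pattern is packed into a 64-bit word at the width implied by `dtype`.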
Value IRBuilder::float_immediate_number(const SType &dtype,
                                        double value,
                                        bool cache) {
  if (data_type_bits(dtype.dt) == 64) {
    return get_const(dtype, reinterpret_cast<uint64_t *>(&value), cache);
  } else if (data_type_bits(dtype.dt) == 32) {
    float fvalue = static_cast<float>(value);
    uint32_t *ptr = reinterpret_cast<uint32_t *>(&fvalue);
    uint64_t data = ptr[0];
    return get_const(dtype, &data, cache);
  } else if (data_type_bits(dtype.dt) == 16) {
    float fvalue = static_cast<float>(value);
    uint16_t *ptr = reinterpret_cast<uint16_t *>(&fvalue);
    uint64_t data = ptr[0];
    return get_const(dtype, &data, cache);
  } else {
    TI_ERROR("Type {} not supported.", dtype.dt->to_string());
  }
}

SType IRBuilder::get_null_type() {
  SType res;
  res.id = id_counter_++;
  return res;
}

SType IRBuilder::get_primitive_type(const DataType &dt) const {
  if (dt->is_primitive(PrimitiveTypeID::u1)) {
    return t_bool_;
  } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
    if (!caps_->get(cap::spirv_has_float16))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_fp16_;
  } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
    return t_fp32_;
  } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
    if (!caps_->get(cap::spirv_has_float64))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_fp64_;
  } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
    if (!caps_->get(cap::spirv_has_int8))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_int8_;
  } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
    if (!caps_->get(cap::spirv_has_int16))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_int16_;
  } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
    return t_int32_;
  } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
    if (!caps_->get(cap::spirv_has_int64))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_int64_;
  } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
    if (!caps_->get(cap::spirv_has_int8))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_uint8_;
  } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
    if (!caps_->get(cap::spirv_has_int16))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_uint16_;
  } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
    return t_uint32_;
  } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
    if (!caps_->get(cap::spirv_has_int64))
      TI_ERROR("Type {} not supported.", dt->to_string());
    return t_uint64_;
  } else {
    TI_ERROR("Type {} not supported.", dt->to_string());
  }
}

size_t IRBuilder::get_primitive_type_size(const DataType &dt) const {
  if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 ||
      dt == PrimitiveType::f64) {
    return 8;
  } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 ||
             dt == PrimitiveType::f32) {
    return 4;
  } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
             dt == PrimitiveType::f16) {
    return 2;
  } else {
    return 1;
  }
}

SType IRBuilder::get_primitive_uint_type(const DataType &dt) const {
  if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 ||
      dt == PrimitiveType::f64) {
    return t_uint64_;
  } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 ||
             dt == PrimitiveType::f32) {
    return t_uint32_;
  } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
             dt == PrimitiveType::f16) {
    return t_uint16_;
  } else {
    return t_uint8_;
  }
}

DataType IRBuilder::get_taichi_uint_type(const DataType &dt) const {
  if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 ||
      dt == PrimitiveType::f64) {
    return PrimitiveType::u64;
  } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 ||
             dt == PrimitiveType::f32) {
    return PrimitiveType::u32;
  } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
             dt == PrimitiveType::f16) {
    return PrimitiveType::u16;
  } else {
    return PrimitiveType::u8;
  }
}

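// Returns (and caches) the OpTypePointer for the given element type and
// storage class, keyed by (element type id, storage class).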
SType IRBuilder::get_pointer_type(const SType &value_type,
                                  spv::StorageClass storage_class) {
  auto key = std::make_pair(value_type.id, storage_class);
  auto it = pointer_type_tbl_.find(key);
  if (it != pointer_type_tbl_.end()) {
    return it->second;
  }
  SType t;
  t.id = id_counter_++;
  t.flag = TypeKind::kPtr;
  t.element_type_id = value_type.id;
  t.storage_class = storage_class;
  ib_.begin(spv::OpTypePointer)
      .add_seq(t, storage_class, value_type)
      .commit(&global_);
  pointer_type_tbl_[key] = t;
  return t;
}

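// Returns (and caches) the OpTypeImage that underlies a sampled image of the
// given texel type and dimensionality (Sampled=1, i.e. used with a sampler).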
SType IRBuilder::get_underlying_image_type(const SType &primitive_type,
                                           int num_dimensions) {
  auto key = std::make_pair(primitive_type.id, num_dimensions);

  auto it = sampled_image_underlying_image_type_.find(key);
  if (it != sampled_image_underlying_image_type_.end()) {
    return it->second;
  }

  int img_id = id_counter_++;
  spv::Dim dim;
  if (num_dimensions == 1) {
    dim = spv::Dim1D;
  } else if (num_dimensions == 2) {
    dim = spv::Dim2D;
  } else if (num_dimensions == 3) {
    dim = spv::Dim3D;
  } else {
    TI_ERROR("Unsupported number of dimensions: {}", num_dimensions);
  }
  ib_.begin(spv::OpTypeImage)
      .add_seq(img_id, primitive_type, dim,
               /*Depth=*/0, /*Arrayed=*/0, /*MS=*/0, /*Sampled=*/1,
               spv::ImageFormatUnknown)
      .commit(&global_);

  SType image_type;
  image_type.id = img_id;
  image_type.flag = TypeKind::kImage;
  sampled_image_underlying_image_type_[key] = image_type;

  return image_type;
}

SType IRBuilder::get_sampled_image_type(const SType &primitive_type,
                                        int num_dimensions) {
  auto key = std::make_pair(primitive_type.id, num_dimensions);
  auto it = sampled_image_ptr_tbl_.find(key);
  if (it != sampled_image_ptr_tbl_.end()) {
    return it->second;
  }

  SType image_type = get_underlying_image_type(primitive_type, num_dimensions);
  int img_id = image_type.id;

  SType sampled_t;
  sampled_t.id = id_counter_++;
  sampled_t.flag = TypeKind::kImage;
  ib_.begin(spv::OpTypeSampledImage)
      .add_seq(sampled_t, img_id)
      .commit(&global_);
  sampled_image_ptr_tbl_[key] = sampled_t;

  return sampled_t;
}

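// Returns (and caches) a storage image type (Sampled=2) for the given buffer
// format and dimensionality, mapping BufferFormat to the SPIR-V ImageFormat.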
SType IRBuilder::get_storage_image_type(BufferFormat format,
                                        int num_dimensions) {
  auto key = std::make_pair(format, num_dimensions);
  auto it = storage_image_ptr_tbl_.find(key);
  if (it != storage_image_ptr_tbl_.end()) {
    return it->second;
  }
  int img_id = id_counter_++;

  spv::Dim dim;
  if (num_dimensions == 1) {
    dim = spv::Dim1D;
  } else if (num_dimensions == 2) {
    dim = spv::Dim2D;
  } else if (num_dimensions == 3) {
    dim = spv::Dim3D;
  } else {
    TI_ERROR("Unsupported number of dimensions: {}", num_dimensions);
  }

  const std::unordered_map<BufferFormat, spv::ImageFormat> format2spv = {
      {BufferFormat::r8, spv::ImageFormatR8},
      {BufferFormat::rg8, spv::ImageFormatRg8},
      {BufferFormat::rgba8, spv::ImageFormatRgba8},
      {BufferFormat::rgba8srgb, spv::ImageFormatRgba8},
      {BufferFormat::r8u, spv::ImageFormatR8ui},
      {BufferFormat::rg8u, spv::ImageFormatRg8ui},
      {BufferFormat::rgba8u, spv::ImageFormatRgba8ui},
      {BufferFormat::r8i, spv::ImageFormatR8i},
      {BufferFormat::rg8i, spv::ImageFormatRg8i},
      {BufferFormat::rgba8i, spv::ImageFormatRgba8i},
      {BufferFormat::r16, spv::ImageFormatR16},
      {BufferFormat::rg16, spv::ImageFormatRg16},
      {BufferFormat::rgba16, spv::ImageFormatRgba16},
      {BufferFormat::r16u, spv::ImageFormatR16ui},
      {BufferFormat::rg16u, spv::ImageFormatRg16ui},
      {BufferFormat::rgba16u, spv::ImageFormatRgba16ui},
      {BufferFormat::r16i, spv::ImageFormatR16i},
      {BufferFormat::rg16i, spv::ImageFormatRg16i},
      {BufferFormat::rgba16i, spv::ImageFormatRgba16i},
      {BufferFormat::r16f, spv::ImageFormatR16f},
      {BufferFormat::rg16f, spv::ImageFormatRg16f},
      {BufferFormat::rgba16f, spv::ImageFormatRgba16f},
      {BufferFormat::r32u, spv::ImageFormatR32ui},
      {BufferFormat::rg32u, spv::ImageFormatRg32ui},
      {BufferFormat::rgba32u, spv::ImageFormatRgba32ui},
      {BufferFormat::r32i, spv::ImageFormatR32i},
      {BufferFormat::rg32i, spv::ImageFormatRg32i},
      {BufferFormat::rgba32i, spv::ImageFormatRgba32i},
      {BufferFormat::r32f, spv::ImageFormatR32f},
      {BufferFormat::rg32f, spv::ImageFormatRg32f},
      {BufferFormat::rgba32f, spv::ImageFormatRgba32f},
      {BufferFormat::depth16, spv::ImageFormatR16},
      {BufferFormat::depth32f, spv::ImageFormatR32f}};

  if (format2spv.find(format) == format2spv.end()) {
    TI_ERROR("Unsupported image format");
  }
  spv::ImageFormat spv_format = format2spv.at(format);

  // TODO: Add integer type support
  ib_.begin(spv::OpTypeImage)
      .add_seq(img_id, f32_type(), dim,
               /*Depth=*/0, /*Arrayed=*/0, /*MS=*/0, /*Sampled=*/2, spv_format)
      .commit(&global_);
  SType img_t;
  img_t.id = img_id;
  img_t.flag = TypeKind::kImage;
  storage_image_ptr_tbl_[key] = img_t;
  return img_t;
}

SType IRBuilder::get_storage_pointer_type(const SType &value_type) {
  spv::StorageClass storage_class;
  if (caps_->get(cap::spirv_version) < 0x10300) {
    storage_class = spv::StorageClassUniform;
  } else {
    storage_class = spv::StorageClassStorageBuffer;
  }

  return get_pointer_type(value_type, storage_class);
}

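// Builds an OpTypeArray (fixed length) or OpTypeRuntimeArray (num_elems == 0)
// of `value_type` and decorates it with the element stride in bytes.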
SType IRBuilder::get_array_type(const SType &value_type, uint32_t num_elems) {
  SType arr_type;
  arr_type.id = id_counter_++;
  arr_type.flag = TypeKind::kPtr;
  arr_type.element_type_id = value_type.id;

  if (num_elems != 0) {
    Value length = uint_immediate_number(
        get_primitive_type(get_data_type<uint32>()), num_elems);
    ib_.begin(spv::OpTypeArray)
        .add_seq(arr_type, value_type, length)
        .commit(&global_);
  } else {
    ib_.begin(spv::OpTypeRuntimeArray)
        .add_seq(arr_type, value_type)
        .commit(&global_);
  }

  uint32_t nbytes;
  if (value_type.flag == TypeKind::kPrimitive) {
    const auto nbits = data_type_bits(value_type.dt);
    nbytes = static_cast<uint32_t>(nbits) / 8;
  } else if (value_type.flag == TypeKind::kSNodeStruct) {
    nbytes = value_type.snode_desc.container_stride;
  } else {
    TI_ERROR("buffer type must be primitive or snode struct");
  }

  if (nbytes == 0) {
    if (value_type.flag == TypeKind::kPrimitive) {
      TI_WARN("Invalid primitive bit size");
    } else {
      TI_WARN("Invalid container stride");
    }
  }

  // decorate the array type
  this->decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes);

  return arr_type;
}

SType IRBuilder::get_struct_array_type(const SType &value_type,
                                       uint32_t num_elems) {
  SType arr_type = get_array_type(value_type, num_elems);

  // declare struct of array
  SType struct_type;
  struct_type.id = id_counter_++;
  struct_type.flag = TypeKind::kStruct;
  struct_type.element_type_id = value_type.id;
  ib_.begin(spv::OpTypeStruct).add_seq(struct_type, arr_type).commit(&global_);
  // decorate the offset of the struct member (the array)
  ib_.begin(spv::OpMemberDecorate)
      .add_seq(struct_type, 0, spv::DecorationOffset, 0)
      .commit(&decorate_);

  if (caps_->get(cap::spirv_version) < 0x10300) {
    // NOTE: BufferBlock was deprecated in SPIR-V 1.3;
    // use StorageClassStorageBuffer instead.
    // runtime arrays are always decorated as BufferBlock (shader storage
    // buffer)
    if (num_elems == 0) {
      this->decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
    }
  } else {
    this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
  }

  return struct_type;
}

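// Emits an OpTypeStruct from (type, name, offset) tuples, then attaches the
// member offset decorations and member debug names in declaration order.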
SType IRBuilder::create_struct_type(
    std::vector<std::tuple<SType, std::string, size_t>> &components) {
  SType struct_type;
  struct_type.id = id_counter_++;
  struct_type.flag = TypeKind::kStruct;

  auto &builder = ib_.begin(spv::OpTypeStruct).add_seq(struct_type);

  for (auto &[type, name, offset] : components) {
    builder.add_seq(type);
  }

  builder.commit(&global_);

  int i = 0;
  for (auto &[type, name, offset] : components) {
    this->decorate(spv::OpMemberDecorate, struct_type, i, spv::DecorationOffset,
                   offset);
    this->debug_name(spv::OpMemberName, struct_type, i, name);
    i++;
  }

  return struct_type;
}

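// Declares a buffer-backed struct variable bound to (descriptor_set, binding).
// Pre-1.3 modules use the Uniform storage class with the BufferBlock
// decoration; newer modules use StorageBuffer with Block.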
Value IRBuilder::buffer_struct_argument(const SType &struct_type,
                                        uint32_t descriptor_set,
                                        uint32_t binding,
                                        const std::string &name) {
  // NOTE: BufferBlock was deprecated in SPIR-V 1.3;
  // use StorageClassStorageBuffer instead.
  spv::StorageClass storage_class;
  if (caps_->get(cap::spirv_version) < 0x10300) {
    storage_class = spv::StorageClassUniform;
  } else {
    storage_class = spv::StorageClassStorageBuffer;
  }

  this->debug_name(spv::OpName, struct_type, name + "_t");

  if (caps_->get(cap::spirv_version) < 0x10300) {
    // NOTE: BufferBlock was deprecated in SPIR-V 1.3;
    // use StorageClassStorageBuffer instead.
    // runtime arrays are always decorated as BufferBlock (shader storage
    // buffer)
    this->decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
  } else {
    this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
  }

  SType ptr_type = get_pointer_type(struct_type, storage_class);

  this->debug_name(spv::OpName, ptr_type, name + "_ptr");

  Value val = new_value(ptr_type, ValueKind::kStructArrayPtr);
  ib_.begin(spv::OpVariable)
      .add_seq(ptr_type, val, storage_class)
      .commit(&global_);

  this->debug_name(spv::OpName, val, name);

  this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet,
                 descriptor_set);
  this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
  return val;
}

Value IRBuilder::uniform_struct_argument(const SType &struct_type,
                                         uint32_t descriptor_set,
                                         uint32_t binding,
                                         const std::string &name) {
  // Uniform buffers always use the Uniform storage class with the Block
  // decoration.
  spv::StorageClass storage_class = spv::StorageClassUniform;

  this->debug_name(spv::OpName, struct_type, name + "_t");

  this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);

  SType ptr_type = get_pointer_type(struct_type, storage_class);

  this->debug_name(spv::OpName, ptr_type, name + "_ptr");

  Value val = new_value(ptr_type, ValueKind::kStructArrayPtr);
  ib_.begin(spv::OpVariable)
      .add_seq(ptr_type, val, storage_class)
      .commit(&global_);

  this->debug_name(spv::OpName, val, name);

  this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet,
                 descriptor_set);
  this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
  return val;
}

Value IRBuilder::buffer_argument(const SType &value_type,
                                 uint32_t descriptor_set,
                                 uint32_t binding,
                                 const std::string &name) {
  // NOTE: BufferBlock was deprecated in SPIR-V 1.3;
  // use StorageClassStorageBuffer instead.
  spv::StorageClass storage_class;
  if (caps_->get(cap::spirv_version) < 0x10300) {
    storage_class = spv::StorageClassUniform;
  } else {
    storage_class = spv::StorageClassStorageBuffer;
  }

  SType sarr_type = get_struct_array_type(value_type, 0);

  auto typed_name = name + "_" + value_type.dt.to_string();

  this->debug_name(spv::OpName, sarr_type, typed_name + "_struct_array");

  SType ptr_type = get_pointer_type(sarr_type, storage_class);

  this->debug_name(spv::OpName, ptr_type, typed_name + "_ptr");

  Value val = new_value(ptr_type, ValueKind::kStructArrayPtr);
  ib_.begin(spv::OpVariable)
      .add_seq(ptr_type, val, storage_class)
      .commit(&global_);

  this->debug_name(spv::OpName, val, typed_name);

  this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet,
                 descriptor_set);
  this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
  return val;
}

Value IRBuilder::struct_array_access(const SType &res_type,
                                     Value buffer,
                                     Value index) {
  TI_ASSERT(buffer.flag == ValueKind::kStructArrayPtr);
  TI_ASSERT(res_type.flag == TypeKind::kPrimitive);

  spv::StorageClass storage_class;
  if (caps_->get(cap::spirv_version) < 0x10300) {
    storage_class = spv::StorageClassUniform;
  } else {
    storage_class = spv::StorageClassStorageBuffer;
  }

  SType ptr_type = this->get_pointer_type(res_type, storage_class);
  Value ret = new_value(ptr_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpAccessChain)
      .add_seq(ptr_type, ret, buffer, const_i32_zero_, index)
      .commit(&function_);

  return ret;
}

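// Declares a sampled-texture argument (UniformConstant sampled image) bound
// to (descriptor_set, binding) and registers it as a module-level value.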
Value IRBuilder::texture_argument(int num_channels,
                                  int num_dimensions,
                                  uint32_t descriptor_set,
                                  uint32_t binding) {
  auto texture_type = this->get_sampled_image_type(f32_type(), num_dimensions);
  auto texture_ptr_type =
      get_pointer_type(texture_type, spv::StorageClassUniformConstant);

  Value val = new_value(texture_ptr_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpVariable)
      .add_seq(texture_ptr_type, val, spv::StorageClassUniformConstant)
      .commit(&global_);

  this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet,
                 descriptor_set);
  this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);

  this->debug_name(spv::OpName, val, "tex");

  this->global_values.push_back(val);

  return val;
}

Value IRBuilder::storage_image_argument(int num_channels,
                                        int num_dimensions,
                                        uint32_t descriptor_set,
                                        uint32_t binding,
                                        BufferFormat format) {
  auto texture_type = this->get_storage_image_type(format, num_dimensions);
  auto texture_ptr_type =
      get_pointer_type(texture_type, spv::StorageClassUniformConstant);

  Value val = new_value(texture_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpVariable)
      .add_seq(texture_ptr_type, val, spv::StorageClassUniformConstant)
      .commit(&global_);

  this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet,
                 descriptor_set);
  this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);

  this->debug_name(spv::OpName, val, "tex");

  this->global_values.push_back(val);

  return val;
}

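// Samples the texture at floating-point coordinates with an explicit LOD.
// The 0x2 literal is the ImageOperands Lod bit required by
// OpImageSampleExplicitLod (and by OpImageFetch below).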
Value IRBuilder::sample_texture(Value texture_var,
                                const std::vector<Value> &args,
                                Value lod) {
  auto image = this->load_variable(
      texture_var, this->get_sampled_image_type(f32_type(), args.size()));
  Value uv;
  if (args.size() == 1) {
    uv = args[0];
  } else if (args.size() == 2) {
    uv = make_value(spv::OpCompositeConstruct, t_v2_fp32_, args[0], args[1]);
  } else if (args.size() == 3) {
    uv = make_value(spv::OpCompositeConstruct, t_v3_fp32_, args[0], args[1],
                    args[2]);
  } else {
    TI_ERROR("Unsupported number of texture coordinates");
  }
  uint32_t lod_operand = 0x2;
  auto res_vec4 = make_value(spv::OpImageSampleExplicitLod, t_v4_fp32_, image,
                             uv, lod_operand, lod);
  return res_vec4;
}

Value IRBuilder::fetch_texel(Value texture_var,
                             const std::vector<Value> &args,
                             Value lod) {
  auto sampled_image = this->load_variable(
      texture_var, this->get_sampled_image_type(f32_type(), args.size()));

  // OpImageFetch requires an operand of OpTypeImage, so extract the
  // underlying image from the OpSampledImage here.
  SType image_type = get_underlying_image_type(f32_type(), args.size());
  Value image_val = make_value(spv::OpImage, image_type, sampled_image);

  Value uv;
  if (args.size() == 1) {
    uv = args[0];
  } else if (args.size() == 2) {
    uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]);
  } else if (args.size() == 3) {
    uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1],
                    args[2]);
  } else {
    TI_ERROR("Unsupported number of texture coordinates");
  }
  uint32_t lod_operand = 0x2;
  auto res_vec4 = make_value(spv::OpImageFetch, t_v4_fp32_, image_val, uv,
                             lod_operand, lod);
  return res_vec4;
}

Value IRBuilder::image_load(Value image_var, const std::vector<Value> &args) {
  auto image = this->load_variable(image_var, image_var.stype);
  Value uv;
  if (args.size() == 1) {
    uv = args[0];
  } else if (args.size() == 2) {
    uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]);
  } else if (args.size() == 3) {
    uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1],
                    args[2]);
  } else {
    TI_ERROR("Unsupported number of texture coordinates");
  }
  auto res_vec4 = make_value(spv::OpImageRead, t_v4_fp32_, image, uv);
  return res_vec4;
}

void IRBuilder::image_store(Value image_var, const std::vector<Value> &args) {
  auto image = this->load_variable(image_var, image_var.stype);
  Value uv;
  if (args.size() == 1 + 4) {
    uv = args[0];
  } else if (args.size() == 2 + 4) {
    uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]);
  } else if (args.size() == 3 + 4) {
    uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1],
                    args[2]);
  } else {
    TI_ERROR("Unsupported number of image coordinates");
  }
  int base = args.size() - 4;
  Value data = make_value(spv::OpCompositeConstruct, t_v4_fp32_, args[base],
                          args[base + 1], args[base + 2], args[base + 3]);
  make_inst(spv::OpImageWrite, image, uv, data);
}

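// Emits the WorkgroupSize builtin as an OpConstantComposite of the three
// group-size components.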
void IRBuilder::set_work_group_size(const std::array<int, 3> group_size) {
  Value size_x =
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[0]));
  Value size_y =
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[1]));
  Value size_z =
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[2]));

  if (gl_work_group_size_.id == 0) {
    gl_work_group_size_.id = id_counter_++;
  }
  ib_.begin(spv::OpConstantComposite)
      .add_seq(t_v3_uint_, gl_work_group_size_, size_x, size_y, size_z)
      .commit(&global_);
  this->decorate(spv::OpDecorate, gl_work_group_size_, spv::DecorationBuiltIn,
                 spv::BuiltInWorkgroupSize);
}

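// The getters below lazily declare the corresponding Input builtin variable
// (NumWorkgroups, LocalInvocationId, GlobalInvocationId, subgroup builtins)
// on first use and return its value, or the requested component, as a uint32.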
Value IRBuilder::get_num_work_groups(uint32_t dim_index) {
  if (gl_num_work_groups_.id == 0) {
    SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput);
    gl_num_work_groups_ = new_value(ptr_type, ValueKind::kVectorPtr);
    ib_.begin(spv::OpVariable)
        .add_seq(ptr_type, gl_num_work_groups_, spv::StorageClassInput)
        .commit(&global_);
    this->decorate(spv::OpDecorate, gl_num_work_groups_, spv::DecorationBuiltIn,
                   spv::BuiltInNumWorkgroups);
  }
  SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
  Value ptr = this->make_value(
      spv::OpAccessChain, pint_type, gl_num_work_groups_,
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index)));

  return this->make_value(spv::OpLoad, t_uint32_, ptr);
}

Value IRBuilder::get_local_invocation_id(uint32_t dim_index) {
  if (gl_local_invocation_id_.id == 0) {
    SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput);
    gl_local_invocation_id_ = new_value(ptr_type, ValueKind::kVectorPtr);
    ib_.begin(spv::OpVariable)
        .add_seq(ptr_type, gl_local_invocation_id_, spv::StorageClassInput)
        .commit(&global_);
    this->decorate(spv::OpDecorate, gl_local_invocation_id_,
                   spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId);
  }
  SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
  Value ptr = this->make_value(
      spv::OpAccessChain, pint_type, gl_local_invocation_id_,
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index)));

  return this->make_value(spv::OpLoad, t_uint32_, ptr);
}

Value IRBuilder::get_global_invocation_id(uint32_t dim_index) {
  if (gl_global_invocation_id_.id == 0) {
    SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput);
    gl_global_invocation_id_ = new_value(ptr_type, ValueKind::kVectorPtr);
    ib_.begin(spv::OpVariable)
        .add_seq(ptr_type, gl_global_invocation_id_, spv::StorageClassInput)
        .commit(&global_);
    this->decorate(spv::OpDecorate, gl_global_invocation_id_,
                   spv::DecorationBuiltIn, spv::BuiltInGlobalInvocationId);
  }
  SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
  Value ptr = this->make_value(
      spv::OpAccessChain, pint_type, gl_global_invocation_id_,
      uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index)));

  return this->make_value(spv::OpLoad, t_uint32_, ptr);
}

Value IRBuilder::get_subgroup_invocation_id() {
  if (subgroup_local_invocation_id_.id == 0) {
    SType ptr_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
    subgroup_local_invocation_id_ =
        new_value(ptr_type, ValueKind::kVariablePtr);
    ib_.begin(spv::OpVariable)
        .add_seq(ptr_type, subgroup_local_invocation_id_,
                 spv::StorageClassInput)
        .commit(&global_);
    this->decorate(spv::OpDecorate, subgroup_local_invocation_id_,
                   spv::DecorationBuiltIn,
                   spv::BuiltInSubgroupLocalInvocationId);
    global_values.push_back(subgroup_local_invocation_id_);
  }

  return this->make_value(spv::OpLoad, t_uint32_,
                          subgroup_local_invocation_id_);
}

Value IRBuilder::get_subgroup_size() {
  if (subgroup_size_.id == 0) {
    SType ptr_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
    subgroup_size_ = new_value(ptr_type, ValueKind::kVariablePtr);
    ib_.begin(spv::OpVariable)
        .add_seq(ptr_type, subgroup_size_, spv::StorageClassInput)
        .commit(&global_);
    this->decorate(spv::OpDecorate, subgroup_size_, spv::DecorationBuiltIn,
                   spv::BuiltInSubgroupSize);
    global_values.push_back(subgroup_size_);
  }

  return this->make_value(spv::OpLoad, t_uint32_, subgroup_size_);
}

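// Binary arithmetic helpers. The *_USIGN_OP macro covers operations whose
// integer opcode is signedness-agnostic (OpIAdd, OpISub, OpIMul); the
// *_SIGN_OP and CMP macros pick the signed, unsigned, or float opcode based
// on the operand type.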
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op)   \
  Value IRBuilder::_OpName(Value a, Value b) {         \
    TI_ASSERT(a.stype.id == b.stype.id);               \
    if (is_integral(a.stype.dt)) {                     \
      return make_value(spv::OpI##_Op, a.stype, a, b); \
    } else {                                           \
      TI_ASSERT(is_real(a.stype.dt));                  \
      return make_value(spv::OpF##_Op, a.stype, a, b); \
    }                                                  \
  }

#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op)         \
  Value IRBuilder::_OpName(Value a, Value b) {              \
    TI_ASSERT(a.stype.id == b.stype.id);                    \
    if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \
      return make_value(spv::OpS##_Op, a.stype, a, b);      \
    } else if (is_integral(a.stype.dt)) {                   \
      return make_value(spv::OpU##_Op, a.stype, a, b);      \
    } else {                                                \
      TI_ASSERT(is_real(a.stype.dt));                       \
      return make_value(spv::OpF##_Op, a.stype, a, b);      \
    }                                                       \
  }

DEFINE_BUILDER_BINARY_USIGN_OP(add, Add);
DEFINE_BUILDER_BINARY_USIGN_OP(sub, Sub);
DEFINE_BUILDER_BINARY_USIGN_OP(mul, Mul);
DEFINE_BUILDER_BINARY_SIGN_OP(div, Div);

Value IRBuilder::mod(Value a, Value b) {
  TI_ASSERT(a.stype.id == b.stype.id);
  if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) {
    // FIXME: figure out why OpSRem does not work
    return sub(a, mul(b, div(a, b)));
  } else if (is_integral(a.stype.dt)) {
    return make_value(spv::OpUMod, a.stype, a, b);
  } else {
    TI_ASSERT(is_real(a.stype.dt));
    return make_value(spv::OpFRem, a.stype, a, b);
  }
}

#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                \
  Value IRBuilder::_OpName(Value a, Value b) {                             \
    TI_ASSERT(a.stype.id == b.stype.id);                                   \
    const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
    if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) {                \
      return make_value(spv::OpS##_Op, bool_type, a, b);                   \
    } else if (is_integral(a.stype.dt)) {                                  \
      return make_value(spv::OpU##_Op, bool_type, a, b);                   \
    } else {                                                               \
      TI_ASSERT(is_real(a.stype.dt));                                      \
      return make_value(spv::OpFOrd##_Op, bool_type, a, b);                \
    }                                                                      \
  }

DEFINE_BUILDER_CMP_OP(lt, LessThan);
DEFINE_BUILDER_CMP_OP(le, LessThanEqual);
DEFINE_BUILDER_CMP_OP(gt, GreaterThan);
DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);

#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                               \
  Value IRBuilder::_OpName(Value a, Value b) {                             \
    TI_ASSERT(a.stype.id == b.stype.id);                                   \
    const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
    if (is_integral(a.stype.dt)) {                                         \
      return make_value(spv::OpI##_Op, bool_type, a, b);                   \
    } else if (a.stype.id == bool_type.id) {                               \
      return make_value(spv::OpLogical##_Op, bool_type, a, b);             \
    } else {                                                               \
      TI_ASSERT(is_real(a.stype.dt));                                      \
      return make_value(spv::OpFOrd##_Op, bool_type, a, b);                \
    }                                                                      \
  }

DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);

Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) {
  TI_ASSERT(is_integral(base.stype.dt));
  TI_ASSERT(is_integral(offset.stype.dt));
  TI_ASSERT(is_integral(count.stype.dt));
  return make_value(spv::OpBitFieldUExtract, base.stype, base, offset, count);
}

Value IRBuilder::select(Value cond, Value a, Value b) {
  TI_ASSERT(a.stype.id == b.stype.id);
  TI_ASSERT(cond.stype.id == t_bool_.id);
  return make_value(spv::OpSelect, a.stype, cond, a, b);
}

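// Converts `value` to `dst_type`. Handles bool <-> int/float via select and
// comparison, integer width/signedness changes via OpSConvert/OpUConvert plus
// OpBitcast, and the usual float <-> int and float <-> float conversions.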
Value IRBuilder::cast(const SType &dst_type, Value value) {
  TI_ASSERT(value.stype.id > 0U);
  if (value.stype.id == dst_type.id)
    return value;
  const DataType &from = value.stype.dt;
  const DataType &to = dst_type.dt;
  if (from->is_primitive(PrimitiveTypeID::u1)) {  // Bool
    if (is_integral(to) && is_signed(to)) {  // Bool -> Int
      return select(value, int_immediate_number(dst_type, 1),
                    int_immediate_number(dst_type, 0));
    } else if (is_integral(to) && is_unsigned(to)) {  // Bool -> UInt
      return select(value, uint_immediate_number(dst_type, 1),
                    uint_immediate_number(dst_type, 0));
    } else if (is_real(to)) {  // Bool -> Float
      return make_value(spv::OpConvertUToF, dst_type,
                        select(value, uint_immediate_number(t_uint32_, 1),
                               uint_immediate_number(t_uint32_, 0)));
    } else {
      TI_ERROR("do not support type cast from {} to {}", from.to_string(),
               to.to_string());
      return Value();
    }
  } else if (to->is_primitive(PrimitiveTypeID::u1)) {  // Bool
    if (is_integral(from) && is_signed(from)) {  // Int -> Bool
      return ne(value, int_immediate_number(value.stype, 0));
    } else if (is_integral(from) && is_unsigned(from)) {  // UInt -> Bool
      return ne(value, uint_immediate_number(value.stype, 0));
    } else {
      TI_ERROR("do not support type cast from {} to {}", from.to_string(),
               to.to_string());
      return Value();
    }
  } else if (is_integral(from) && is_integral(to)) {
    auto ret = value;

    if (data_type_bits(from) == data_type_bits(to)) {
      // Same width conversion
      ret = make_value(spv::OpBitcast, dst_type, ret);
    } else {
      // Different width
      // Step 1. Sign extend / truncate the value to the width of `to`
      // Step 2. Bitcast to the signedness of `to`
      auto get_signed_type = [](DataType dt) -> DataType {
        // Create an output signed type with the same width as `dt`
        if (data_type_bits(dt) == 8)
          return PrimitiveType::i8;
        else if (data_type_bits(dt) == 16)
          return PrimitiveType::i16;
        else if (data_type_bits(dt) == 32)
          return PrimitiveType::i32;
        else if (data_type_bits(dt) == 64)
          return PrimitiveType::i64;
        else
          return PrimitiveType::unknown;
      };
      auto get_unsigned_type = [](DataType dt) -> DataType {
        // Create an output unsigned type with the same width as `dt`
        if (data_type_bits(dt) == 8)
          return PrimitiveType::u8;
        else if (data_type_bits(dt) == 16)
          return PrimitiveType::u16;
        else if (data_type_bits(dt) == 32)
          return PrimitiveType::u32;
        else if (data_type_bits(dt) == 64)
          return PrimitiveType::u64;
        else
          return PrimitiveType::unknown;
      };

      if (is_signed(from)) {
        ret = make_value(spv::OpSConvert,
                         get_primitive_type(get_signed_type(to)), ret);
      } else {
        ret = make_value(spv::OpUConvert,
                         get_primitive_type(get_unsigned_type(to)), ret);
      }

      ret = make_value(spv::OpBitcast, dst_type, ret);
    }

    return ret;
  } else if (is_real(from) && is_integral(to) &&
             is_signed(to)) {  // Float -> Int
    return make_value(spv::OpConvertFToS, dst_type, value);
  } else if (is_real(from) && is_integral(to) &&
             is_unsigned(to)) {  // Float -> UInt
    return make_value(spv::OpConvertFToU, dst_type, value);
  } else if (is_integral(from) && is_signed(from) &&
             is_real(to)) {  // Int -> Float
    return make_value(spv::OpConvertSToF, dst_type, value);
  } else if (is_integral(from) && is_unsigned(from) &&
             is_real(to)) {  // UInt -> Float
    return make_value(spv::OpConvertUToF, dst_type, value);
  } else if (is_real(from) && is_real(to)) {  // Float -> Float
    return make_value(spv::OpFConvert, dst_type, value);
  } else {
    TI_ERROR("do not support type cast from {} to {}", from.to_string(),
             to.to_string());
    return Value();
  }
}

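// Allocates a function-local variable (Function storage class, emitted into
// the function header) or a workgroup-shared array (Workgroup storage class,
// emitted as a global).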
Value IRBuilder::alloca_variable(const SType &type) {
  SType ptr_type = get_pointer_type(type, spv::StorageClassFunction);
  Value ret = new_value(ptr_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpVariable)
      .add_seq(ptr_type, ret, spv::StorageClassFunction)
      .commit(&func_header_);
  return ret;
}

Value IRBuilder::alloca_workgroup_array(const SType &arr_type) {
  SType ptr_type = get_pointer_type(arr_type, spv::StorageClassWorkgroup);
  Value ret = new_value(ptr_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpVariable)
      .add_seq(ptr_type, ret, spv::StorageClassWorkgroup)
      .commit(&global_);
  return ret;
}

Value IRBuilder::load_variable(Value pointer, const SType &res_type) {
  TI_ASSERT(pointer.flag == ValueKind::kVariablePtr ||
            pointer.flag == ValueKind::kStructArrayPtr ||
            pointer.flag == ValueKind::kPhysicalPtr);
  Value ret = new_value(res_type, ValueKind::kNormal);
  if (pointer.flag == ValueKind::kPhysicalPtr) {
    uint32_t alignment = uint32_t(get_primitive_type_size(res_type.dt));
    ib_.begin(spv::OpLoad)
        .add_seq(res_type, ret, pointer, spv::MemoryAccessAlignedMask,
                 alignment)
        .commit(&function_);
  } else {
    ib_.begin(spv::OpLoad).add_seq(res_type, ret, pointer).commit(&function_);
  }
  return ret;
}

void IRBuilder::store_variable(Value pointer, Value value) {
  TI_ASSERT(pointer.flag == ValueKind::kVariablePtr ||
            pointer.flag == ValueKind::kPhysicalPtr);
  TI_ASSERT(value.stype.id == pointer.stype.element_type_id);
  if (pointer.flag == ValueKind::kPhysicalPtr) {
    uint32_t alignment = uint32_t(get_primitive_type_size(value.stype.dt));
    ib_.begin(spv::OpStore)
        .add_seq(pointer, value, spv::MemoryAccessAlignedMask, alignment)
        .commit(&function_);
  } else {
    ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&function_);
  }
}

void IRBuilder::register_value(std::string name, Value value) {
  auto it = value_name_tbl_.find(name);
  if (it != value_name_tbl_.end() && it->second.flag != ValueKind::kConstant) {
    TI_ERROR("{} already exists.", name);
  }
  this->debug_name(
      spv::OpName, value,
      fmt::format("{}_{}", name, value.stype.dt.to_string()));  // Debug info
  value_name_tbl_[name] = value;
}

Value IRBuilder::query_value(std::string name) const {
  auto it = value_name_tbl_.find(name);
  if (it != value_name_tbl_.end()) {
    return it->second;
  }
  TI_ERROR("Value \"{}\" does not yet exist.", name);
}

bool IRBuilder::check_value_existence(const std::string &name) const {
  return value_name_tbl_.find(name) != value_name_tbl_.end();
}

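// Emulates a floating-point atomic with a compare-and-swap loop: load the old
// bits, apply `op_type` in float, and retry OpAtomicCompareExchange until the
// stored word has not changed underneath us.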
Value IRBuilder::float_atomic(AtomicOpType op_type,
                              Value addr_ptr,
                              Value data) {
  auto atomic_func_ = [&](std::function<Value(Value, Value)> atomic_op) {
    Value ret_val_int = alloca_variable(t_uint32_);

    // do-while
    Label head = new_label();
    Label body = new_label();
    Label branch_true = new_label();
    Label branch_false = new_label();
    Label merge = new_label();
    Label exit = new_label();

    make_inst(spv::OpBranch, head);
    start_label(head);
    make_inst(spv::OpLoopMerge, branch_true, merge, 0);
    make_inst(spv::OpBranch, body);
    make_inst(spv::OpLabel, body);
    // while (true)
    {
      // int old = addr_ptr[0];
      Value old_val = load_variable(addr_ptr, t_uint32_);
      // int new = floatBitsToInt(atomic_op(intBitsToFloat(old), data));
      Value old_float = make_value(spv::OpBitcast, t_fp32_, old_val);
      Value new_float = atomic_op(old_float, data);
      Value new_val = make_value(spv::OpBitcast, t_uint32_, new_float);
      // int loaded = atomicCompSwap(vals[0], old, new);
      /*
       * Don't need this part, theoretically
      auto semantics = uint_immediate_number(
          t_uint32_, spv::MemorySemanticsAcquireReleaseMask |
                         spv::MemorySemanticsUniformMemoryMask);
      make_inst(spv::OpMemoryBarrier, const_i32_one_, semantics);
      */
      Value loaded = make_value(
          spv::OpAtomicCompareExchange, t_uint32_, addr_ptr,
          /*scope=*/const_i32_one_, /*semantics if equal=*/const_i32_zero_,
          /*semantics if unequal=*/const_i32_zero_, new_val, old_val);
      // bool ok = (loaded == old);
      Value ok = make_value(spv::OpIEqual, t_bool_, loaded, old_val);
      // int ret_val_int = loaded;
      store_variable(ret_val_int, loaded);
      // if (ok)
      make_inst(spv::OpSelectionMerge, branch_false, 0);
      make_inst(spv::OpBranchConditional, ok, branch_true, branch_false);
      {
        make_inst(spv::OpLabel, branch_true);
        make_inst(spv::OpBranch, exit);
      }
      // else
      {
        make_inst(spv::OpLabel, branch_false);
        make_inst(spv::OpBranch, merge);
      }
      // continue;
      make_inst(spv::OpLabel, merge);
      make_inst(spv::OpBranch, head);
    }
    start_label(exit);

    return make_value(spv::OpBitcast, t_fp32_,
                      load_variable(ret_val_int, t_uint32_));
  };

  if (op_type == AtomicOpType::add) {
    return atomic_func_([&](Value lhs, Value rhs) { return add(lhs, rhs); });
  } else if (op_type == AtomicOpType::sub) {
    return atomic_func_([&](Value lhs, Value rhs) { return sub(lhs, rhs); });
  } else if (op_type == AtomicOpType::min) {
    return atomic_func_([&](Value lhs, Value rhs) {
      return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs);
    });
  } else if (op_type == AtomicOpType::max) {
    return atomic_func_([&](Value lhs, Value rhs) {
      return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs);
    });
  } else {
    TI_NOT_IMPLEMENTED
  }
}

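// Generates one 32-bit pseudo-random number per call using a xorshift-style
// generator over the four private state words (rand_x_ .. rand_w_), followed
// by a multiplicative scramble.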
Value IRBuilder::rand_u32(Value global_tmp_) {
  if (!init_rand_) {
    init_random_function(global_tmp_);
  }

  Value _11u = uint_immediate_number(t_uint32_, 11u);
  Value _19u = uint_immediate_number(t_uint32_, 19u);
  Value _8u = uint_immediate_number(t_uint32_, 8u);
  Value _1000000007u = uint_immediate_number(t_uint32_, 1000000007u);
  Value tmp0 = load_variable(rand_x_, t_uint32_);
  Value tmp1 = make_value(spv::OpShiftLeftLogical, t_uint32_, tmp0, _11u);
  Value tmp_t = make_value(spv::OpBitwiseXor, t_uint32_, tmp0, tmp1);  // t
  store_variable(rand_x_, load_variable(rand_y_, t_uint32_));
  store_variable(rand_y_, load_variable(rand_z_, t_uint32_));
  Value tmp_w = load_variable(rand_w_, t_uint32_);  // reuse w
  store_variable(rand_z_, tmp_w);
  Value tmp2 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_w, _19u);
  Value tmp3 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_w, tmp2);
  Value tmp4 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_t, _8u);
  Value tmp5 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_t, tmp4);
  Value new_w = make_value(spv::OpBitwiseXor, t_uint32_, tmp3, tmp5);
  store_variable(rand_w_, new_w);
  Value val = make_value(spv::OpIMul, t_uint32_, new_w, _1000000007u);

  return val;
}

Value IRBuilder::rand_f32(Value global_tmp_) {
  if (!init_rand_) {
    init_random_function(global_tmp_);
  }

  Value _1_4294967296f = float_immediate_number(t_fp32_, 1.0f / 4294967296.0f);
  Value tmp0 = rand_u32(global_tmp_);
  Value tmp1 = cast(t_fp32_, tmp0);
  Value val = mul(tmp1, _1_4294967296f);

  return val;
}

Value IRBuilder::rand_i32(Value global_tmp_) {
  if (!init_rand_) {
    init_random_function(global_tmp_);
  }

  Value tmp0 = rand_u32(global_tmp_);
  Value val = cast(t_int32_, tmp0);
  return val;
}

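// Emits (and optionally caches) an OpConstant/OpConstantTrue/OpConstantFalse
// for a primitive type. Constants wider than 32 bits are split into low and
// high words as required by the SPIR-V literal encoding.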
Value IRBuilder::get_const(const SType &dtype,
                           const uint64_t *pvalue,
                           bool cache) {
  auto key = std::make_pair(dtype.id, pvalue[0]);
  if (cache) {
    auto it = const_tbl_.find(key);
    if (it != const_tbl_.end()) {
      return it->second;
    }
  }

  TI_ASSERT(dtype.flag == TypeKind::kPrimitive);
  Value ret = new_value(dtype, ValueKind::kConstant);
  if (dtype.dt->is_primitive(PrimitiveTypeID::u1)) {
    // bool type
    if (*pvalue) {
      ib_.begin(spv::OpConstantTrue).add_seq(dtype, ret);
    } else {
      ib_.begin(spv::OpConstantFalse).add_seq(dtype, ret);
    }
  } else {
    // Integral/floating-point types.
    ib_.begin(spv::OpConstant).add_seq(dtype, ret);
    uint64_t mask = 0xFFFFFFFFUL;
    ib_.add(static_cast<uint32_t>(pvalue[0] & mask));
    if (data_type_bits(dtype.dt) > 32) {
      if (is_integral(dtype.dt)) {
        int64_t sign_mask = 0xFFFFFFFFL;
        const int64_t *sign_ptr = reinterpret_cast<const int64_t *>(pvalue);
        ib_.add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
      } else {
        ib_.add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
      }
    }
  }

  ib_.commit(&global_);
  if (cache) {
    const_tbl_[key] = ret;
  }
  return ret;
}

SType IRBuilder::declare_primitive_type(DataType dt) {
  SType t;
  t.id = id_counter_++;
  t.dt = dt;
  t.flag = TypeKind::kPrimitive;

  dt.set_is_pointer(false);
  if (dt->is_primitive(PrimitiveTypeID::u1))
    ib_.begin(spv::OpTypeBool).add(t).commit(&global_);
  else if (is_real(dt))
    ib_.begin(spv::OpTypeFloat).add_seq(t, data_type_bits(dt)).commit(&global_);
  else if (is_integral(dt))
    ib_.begin(spv::OpTypeInt)
        .add_seq(t, data_type_bits(dt), static_cast<int>(is_signed(dt)))
        .commit(&global_);
  else {
    TI_ERROR("Type {} not supported.", dt->to_string());
  }

  return t;
}

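// Sets up the per-thread random state: declares the four Private state words,
// derives a seed from gl_GlobalInvocationID.x mixed with a counter stored in
// the global temporary buffer, and bumps that counter so that successive
// launches start from different seeds.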
void IRBuilder::init_random_function(Value global_tmp_) {
  // variable declarations
  SType local_type = get_pointer_type(t_uint32_, spv::StorageClassPrivate);
  rand_x_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_y_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_z_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_w_ = new_value(local_type, ValueKind::kVariablePtr);
  global_values.push_back(rand_x_);
  global_values.push_back(rand_y_);
  global_values.push_back(rand_z_);
  global_values.push_back(rand_w_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_x_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_y_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_z_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_w_, spv::StorageClassPrivate)
      .commit(&global_);
  debug_name(spv::OpName, rand_x_, "_rand_x");
  debug_name(spv::OpName, rand_y_, "_rand_y");
  debug_name(spv::OpName, rand_z_, "_rand_z");
  debug_name(spv::OpName, rand_w_, "_rand_w");
  SType gtmp_type = get_pointer_type(t_uint32_, spv::StorageClassStorageBuffer);
  Value rand_gtmp_ = new_value(gtmp_type, ValueKind::kVariablePtr);
  debug_name(spv::OpName, rand_gtmp_, "rand_gtmp");

  auto load_var = [&](Value pointer, const SType &res_type) {
    TI_ASSERT(pointer.flag == ValueKind::kVariablePtr ||
              pointer.flag == ValueKind::kStructArrayPtr);
    Value ret = new_value(res_type, ValueKind::kNormal);
    ib_.begin(spv::OpLoad)
        .add_seq(res_type, ret, pointer)
        .commit(&func_header_);
    return ret;
  };

  auto store_var = [&](Value pointer, Value value) {
    TI_ASSERT(pointer.flag == ValueKind::kVariablePtr);
    TI_ASSERT(value.stype.id == pointer.stype.element_type_id);
    ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&func_header_);
  };

  // Constant Number
  Value _7654321u = uint_immediate_number(t_uint32_, 7654321u);
  Value _1234567u = uint_immediate_number(t_uint32_, 1234567u);
  Value _9723451u = uint_immediate_number(t_uint32_, 9723451u);
  Value _123456789u = uint_immediate_number(t_uint32_, 123456789u);
  Value _1000000007u = uint_immediate_number(t_uint32_, 1000000007u);
  Value _362436069u = uint_immediate_number(t_uint32_, 362436069u);
  Value _521288629u = uint_immediate_number(t_uint32_, 521288629u);
  Value _88675123u = uint_immediate_number(t_uint32_, 88675123u);
  Value _1 = int_immediate_number(t_uint32_, 1);
  Value _1024 = int_immediate_number(t_uint32_, 1024);

  // init_rand_ segment (inlined into main)
  // ad hoc: assume no kernel will use more than 1024 gtmp variables...
  ib_.begin(spv::OpAccessChain)
      .add_seq(gtmp_type, rand_gtmp_, global_tmp_, const_i32_zero_, _1024)
      .commit(&func_header_);
  // Get gl_GlobalInvocationID.x; it must have been declared already
  // (in generate_serial_kernel/generate_range_for_kernel).
  SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
  Value tmp0 = new_value(pint_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpAccessChain)
      .add_seq(pint_type, tmp0, gl_global_invocation_id_,
               uint_immediate_number(t_uint32_, 0))
      .commit(&func_header_);
  Value tmp1 = load_var(tmp0, t_uint32_);
  Value tmp2_ = load_var(rand_gtmp_, t_uint32_);
  Value tmp2 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpBitcast)
      .add_seq(t_uint32_, tmp2, tmp2_)
      .commit(&func_header_);
  Value tmp3 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIAdd)
      .add_seq(t_uint32_, tmp3, _7654321u, tmp1)
      .commit(&func_header_);
  Value tmp4 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp4, _9723451u, tmp2)
      .commit(&func_header_);
  Value tmp5 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIAdd)
      .add_seq(t_uint32_, tmp5, _1234567u, tmp4)
      .commit(&func_header_);
  Value tmp6 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp6, tmp3, tmp5)
      .commit(&func_header_);
  Value tmp7 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp7, _123456789u, tmp6)
      .commit(&func_header_);
  Value tmp8 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp8, _1000000007u, tmp7)
      .commit(&func_header_);
  store_var(rand_x_, tmp8);
  store_var(rand_y_, _362436069u);
  store_var(rand_z_, _521288629u);
  store_var(rand_w_, _88675123u);

  // enum spv::Op add_op = spv::OpIAdd;
  bool use_atomic_increment = false;

// use atomic increment for the DX API to avoid error X3694
#ifdef TI_WITH_DX11
  if (arch_ == Arch::dx11) {
    use_atomic_increment = true;
  }
#endif

  if (use_atomic_increment) {
    Value tmp9 = new_value(t_uint32_, ValueKind::kNormal);
    ib_.begin(spv::Op::OpAtomicIIncrement)
        .add_seq(t_uint32_, tmp9, rand_gtmp_,
                 /*scope_id*/ const_i32_one_,
                 /*semantics*/ const_i32_zero_)
        .commit(&func_header_);
  } else {
    // This is not an atomic operation, but that is fine: however RAND_STATE
    // changes, `gl_GlobalInvocationID.x` still lets us assign different seeds
    // to different threads.
    // Discussion:
    // https://github.com/taichi-dev/taichi/pull/912#discussion_r419021918
    Value tmp9 = load_var(rand_gtmp_, t_uint32_);
    Value tmp10 = new_value(t_uint32_, ValueKind::kNormal);
    ib_.begin(spv::Op::OpIAdd)
        .add_seq(t_uint32_, tmp10, tmp9, _1)
        .commit(&func_header_);
    store_var(rand_gtmp_, tmp10);
  }

  init_rand_ = true;
}

}  // namespace spirv
}  // namespace taichi::lang