#include "taichi/codegen/spirv/spirv_ir_builder.h"

#include <cstring>

#include "taichi/rhi/dx/dx_device.h"
3 | |
4 | namespace taichi::lang { |
5 | |
6 | namespace spirv { |
7 | |
8 | using cap = DeviceCapability; |
9 | |
10 | void IRBuilder::() { |
11 | TI_ASSERT(header_.size() == 0U); |
12 | header_.push_back(spv::MagicNumber); |
13 | |
14 | header_.push_back(caps_->get(cap::spirv_version)); |
15 | |
16 | TI_TRACE("SPIR-V Version {}" , caps_->get(cap::spirv_version)); |
17 | |
18 | // generator: set to 0, unknown |
19 | header_.push_back(0U); |
20 | // Bound: set during Finalize |
21 | header_.push_back(0U); |
22 | // Schema: reserved |
23 | header_.push_back(0U); |
24 | |
25 | // capability |
26 | ib_.begin(spv::OpCapability).add(spv::CapabilityShader).commit(&header_); |
27 | |
28 | if (caps_->get(cap::spirv_has_atomic_float64_add)) { |
29 | ib_.begin(spv::OpCapability) |
30 | .add(spv::CapabilityAtomicFloat64AddEXT) |
31 | .commit(&header_); |
32 | } |
33 | |
34 | if (caps_->get(cap::spirv_has_atomic_float_add)) { |
35 | ib_.begin(spv::OpCapability) |
36 | .add(spv::CapabilityAtomicFloat32AddEXT) |
37 | .commit(&header_); |
38 | } |
39 | |
40 | if (caps_->get(cap::spirv_has_atomic_float_minmax)) { |
41 | ib_.begin(spv::OpCapability) |
42 | .add(spv::CapabilityAtomicFloat32MinMaxEXT) |
43 | .commit(&header_); |
44 | } |
45 | |
46 | if (caps_->get(cap::spirv_has_variable_ptr)) { |
47 | /* |
48 | ib_.begin(spv::OpCapability) |
49 | .add(spv::CapabilityVariablePointers) |
50 | .commit(&header_); |
51 | ib_.begin(spv::OpCapability) |
52 | .add(spv::CapabilityVariablePointersStorageBuffer) |
53 | .commit(&header_); |
54 | */ |
55 | } |
56 | |
57 | if (caps_->get(cap::spirv_has_int8)) { |
58 | ib_.begin(spv::OpCapability).add(spv::CapabilityInt8).commit(&header_); |
59 | } |
60 | if (caps_->get(cap::spirv_has_int16)) { |
61 | ib_.begin(spv::OpCapability).add(spv::CapabilityInt16).commit(&header_); |
62 | } |
63 | if (caps_->get(cap::spirv_has_int64)) { |
64 | ib_.begin(spv::OpCapability).add(spv::CapabilityInt64).commit(&header_); |
65 | } |
66 | if (caps_->get(cap::spirv_has_float16)) { |
67 | ib_.begin(spv::OpCapability).add(spv::CapabilityFloat16).commit(&header_); |
68 | } |
69 | if (caps_->get(cap::spirv_has_float64)) { |
70 | ib_.begin(spv::OpCapability).add(spv::CapabilityFloat64).commit(&header_); |
71 | } |
72 | if (caps_->get(cap::spirv_has_physical_storage_buffer)) { |
73 | ib_.begin(spv::OpCapability) |
74 | .add(spv::CapabilityPhysicalStorageBufferAddresses) |
75 | .commit(&header_); |
76 | } |
77 | |
78 | ib_.begin(spv::OpExtension) |
79 | .add("SPV_KHR_storage_buffer_storage_class" ) |
80 | .commit(&header_); |
81 | |
82 | if (caps_->get(cap::spirv_has_no_integer_wrap_decoration)) { |
83 | ib_.begin(spv::OpExtension) |
84 | .add("SPV_KHR_no_integer_wrap_decoration" ) |
85 | .commit(&header_); |
86 | } |
87 | |
88 | if (caps_->get(cap::spirv_has_non_semantic_info)) { |
89 | ib_.begin(spv::OpExtension) |
90 | .add("SPV_KHR_non_semantic_info" ) |
91 | .commit(&header_); |
92 | } |
93 | |
94 | if (caps_->get(cap::spirv_has_variable_ptr)) { |
95 | ib_.begin(spv::OpExtension) |
96 | .add("SPV_KHR_variable_pointers" ) |
97 | .commit(&header_); |
98 | } |
99 | |
100 | if (caps_->get(cap::spirv_has_atomic_float_add)) { |
101 | ib_.begin(spv::OpExtension) |
102 | .add("SPV_EXT_shader_atomic_float_add" ) |
103 | .commit(&header_); |
104 | } |
105 | |
106 | if (caps_->get(cap::spirv_has_atomic_float_minmax)) { |
107 | ib_.begin(spv::OpExtension) |
108 | .add("SPV_EXT_shader_atomic_float_min_max" ) |
109 | .commit(&header_); |
110 | } |
111 | |
112 | if (caps_->get(cap::spirv_has_physical_storage_buffer)) { |
113 | ib_.begin(spv::OpExtension) |
114 | .add("SPV_KHR_physical_storage_buffer" ) |
115 | .commit(&header_); |
116 | |
117 | // memory model |
118 | ib_.begin(spv::OpMemoryModel) |
119 | .add_seq(spv::AddressingModelPhysicalStorageBuffer64, |
120 | spv::MemoryModelGLSL450) |
121 | .commit(&entry_); |
122 | } else { |
123 | ib_.begin(spv::OpMemoryModel) |
124 | .add_seq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) |
125 | .commit(&entry_); |
126 | } |
127 | |
128 | this->init_pre_defs(); |
129 | } |
130 | |
131 | std::vector<uint32_t> IRBuilder::finalize() { |
132 | std::vector<uint32_t> data; |
133 | // set bound |
134 | const int bound_loc = 3; |
135 | header_[bound_loc] = id_counter_; |
136 | data.insert(data.end(), header_.begin(), header_.end()); |
137 | data.insert(data.end(), entry_.begin(), entry_.end()); |
138 | data.insert(data.end(), exec_mode_.begin(), exec_mode_.end()); |
139 | data.insert(data.end(), strings_.begin(), strings_.end()); |
140 | data.insert(data.end(), names_.begin(), names_.end()); |
141 | data.insert(data.end(), decorate_.begin(), decorate_.end()); |
142 | data.insert(data.end(), global_.begin(), global_.end()); |
143 | data.insert(data.end(), func_header_.begin(), func_header_.end()); |
144 | data.insert(data.end(), function_.begin(), function_.end()); |
145 | return data; |
146 | } |
147 | |
// Declares the extended-instruction imports, primitive scalar types, common
// vector types, and shared constants that every generated module uses.
// NOTE: ids are handed out by id_counter_ in declaration order, so the
// sequence of declarations here is significant.
void IRBuilder::init_pre_defs() {
  ext_glsl450_ = ext_inst_import("GLSL.std.450");
  if (caps_->get(cap::spirv_has_non_semantic_info)) {
    // Needed for debug printf support.
    debug_printf_ = ext_inst_import("NonSemantic.DebugPrintf");
  }

  // Primitive scalar types; optional bit widths are gated on capabilities.
  t_bool_ = declare_primitive_type(get_data_type<bool>());
  if (caps_->get(cap::spirv_has_int8)) {
    t_int8_ = declare_primitive_type(get_data_type<int8>());
    t_uint8_ = declare_primitive_type(get_data_type<uint8>());
  }
  if (caps_->get(cap::spirv_has_int16)) {
    t_int16_ = declare_primitive_type(get_data_type<int16>());
    t_uint16_ = declare_primitive_type(get_data_type<uint16>());
  }
  t_int32_ = declare_primitive_type(get_data_type<int32>());
  t_uint32_ = declare_primitive_type(get_data_type<uint32>());
  if (caps_->get(cap::spirv_has_int64)) {
    t_int64_ = declare_primitive_type(get_data_type<int64>());
    t_uint64_ = declare_primitive_type(get_data_type<uint64>());
  }
  t_fp32_ = declare_primitive_type(get_data_type<float32>());
  if (caps_->get(cap::spirv_has_float16)) {
    t_fp16_ = declare_primitive_type(PrimitiveType::f16);
  }
  if (caps_->get(cap::spirv_has_float64)) {
    t_fp64_ = declare_primitive_type(get_data_type<float64>());
  }
  // declare void, and void functions
  t_void_.id = id_counter_++;
  ib_.begin(spv::OpTypeVoid).add(t_void_).commit(&global_);
  t_void_func_.id = id_counter_++;
  ib_.begin(spv::OpTypeFunction)
      .add_seq(t_void_func_, t_void_)
      .commit(&global_);

  // compute shader related types (coordinate and texel vectors)
  t_v2_int_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v2_int_)
      .add_seq(t_int32_, 2)
      .commit(&global_);

  t_v3_int_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_int_)
      .add_seq(t_int32_, 3)
      .commit(&global_);

  // uvec3: used e.g. for built-in work-group dimensions.
  t_v3_uint_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_uint_)
      .add_seq(t_uint32_, 3)
      .commit(&global_);

  t_v4_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v4_fp32_)
      .add_seq(t_fp32_, 4)
      .commit(&global_);

  t_v2_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v2_fp32_)
      .add_seq(t_fp32_, 2)
      .commit(&global_);

  t_v3_fp32_.id = id_counter_++;
  ib_.begin(spv::OpTypeVector)
      .add(t_v3_fp32_)
      .add_seq(t_fp32_, 3)
      .commit(&global_);

  // pre-defined constants
  const_i32_zero_ = int_immediate_number(t_int32_, 0);
  const_i32_one_ = int_immediate_number(t_int32_, 1);
}
225 | |
226 | Value IRBuilder::debug_string(std::string s) { |
227 | Value val = new_value(SType(), ValueKind::kNormal); |
228 | ib_.begin(spv::OpString).add_seq(val, s).commit(&strings_); |
229 | return val; |
230 | } |
231 | |
232 | PhiValue IRBuilder::make_phi(const SType &out_type, uint32_t num_incoming) { |
233 | Value val = new_value(out_type, ValueKind::kNormal); |
234 | ib_.begin(spv::OpPhi).add_seq(out_type, val); |
235 | for (uint32_t i = 0; i < 2 * num_incoming; ++i) { |
236 | ib_.add(0); |
237 | } |
238 | |
239 | PhiValue phi; |
240 | phi.id = val.id; |
241 | phi.stype = out_type; |
242 | phi.flag = ValueKind::kNormal; |
243 | phi.instr = ib_.commit(&function_); |
244 | return phi; |
245 | } |
246 | |
247 | Value IRBuilder::int_immediate_number(const SType &dtype, |
248 | int64_t value, |
249 | bool cache) { |
250 | return get_const(dtype, reinterpret_cast<uint64_t *>(&value), cache); |
251 | } |
252 | |
253 | Value IRBuilder::uint_immediate_number(const SType &dtype, |
254 | uint64_t value, |
255 | bool cache) { |
256 | return get_const(dtype, &value, cache); |
257 | } |
258 | |
259 | Value IRBuilder::float_immediate_number(const SType &dtype, |
260 | double value, |
261 | bool cache) { |
262 | if (data_type_bits(dtype.dt) == 64) { |
263 | return get_const(dtype, reinterpret_cast<uint64_t *>(&value), cache); |
264 | } else if (data_type_bits(dtype.dt) == 32) { |
265 | float fvalue = static_cast<float>(value); |
266 | uint32_t *ptr = reinterpret_cast<uint32_t *>(&fvalue); |
267 | uint64_t data = ptr[0]; |
268 | return get_const(dtype, &data, cache); |
269 | } else if (data_type_bits(dtype.dt) == 16) { |
270 | float fvalue = static_cast<float>(value); |
271 | uint16_t *ptr = reinterpret_cast<uint16_t *>(&fvalue); |
272 | uint64_t data = ptr[0]; |
273 | return get_const(dtype, &data, cache); |
274 | } else { |
275 | TI_ERROR("Type {} not supported." , dtype.dt->to_string()); |
276 | } |
277 | } |
278 | |
279 | SType IRBuilder::get_null_type() { |
280 | SType res; |
281 | res.id = id_counter_++; |
282 | return res; |
283 | } |
284 | |
285 | SType IRBuilder::get_primitive_type(const DataType &dt) const { |
286 | if (dt->is_primitive(PrimitiveTypeID::u1)) { |
287 | return t_bool_; |
288 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
289 | if (!caps_->get(cap::spirv_has_float16)) |
290 | TI_ERROR("Type {} not supported." , dt->to_string()); |
291 | return t_fp16_; |
292 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
293 | return t_fp32_; |
294 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
295 | if (!caps_->get(cap::spirv_has_float64)) |
296 | TI_ERROR("Type {} not supported." , dt->to_string()); |
297 | return t_fp64_; |
298 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
299 | if (!caps_->get(cap::spirv_has_int8)) |
300 | TI_ERROR("Type {} not supported." , dt->to_string()); |
301 | return t_int8_; |
302 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
303 | if (!caps_->get(cap::spirv_has_int16)) |
304 | TI_ERROR("Type {} not supported." , dt->to_string()); |
305 | return t_int16_; |
306 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
307 | return t_int32_; |
308 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
309 | if (!caps_->get(cap::spirv_has_int64)) |
310 | TI_ERROR("Type {} not supported." , dt->to_string()); |
311 | return t_int64_; |
312 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
313 | if (!caps_->get(cap::spirv_has_int8)) |
314 | TI_ERROR("Type {} not supported." , dt->to_string()); |
315 | return t_uint8_; |
316 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
317 | if (!caps_->get(cap::spirv_has_int16)) |
318 | TI_ERROR("Type {} not supported." , dt->to_string()); |
319 | return t_uint16_; |
320 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
321 | return t_uint32_; |
322 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
323 | if (!caps_->get(cap::spirv_has_int64)) |
324 | TI_ERROR("Type {} not supported." , dt->to_string()); |
325 | return t_uint64_; |
326 | } else { |
327 | TI_ERROR("Type {} not supported." , dt->to_string()); |
328 | } |
329 | } |
330 | |
331 | size_t IRBuilder::get_primitive_type_size(const DataType &dt) const { |
332 | if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || |
333 | dt == PrimitiveType::f64) { |
334 | return 8; |
335 | } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || |
336 | dt == PrimitiveType::f32) { |
337 | return 4; |
338 | } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || |
339 | dt == PrimitiveType::f16) { |
340 | return 2; |
341 | } else { |
342 | return 1; |
343 | } |
344 | } |
345 | |
346 | SType IRBuilder::get_primitive_uint_type(const DataType &dt) const { |
347 | if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || |
348 | dt == PrimitiveType::f64) { |
349 | return t_uint64_; |
350 | } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || |
351 | dt == PrimitiveType::f32) { |
352 | return t_uint32_; |
353 | } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || |
354 | dt == PrimitiveType::f16) { |
355 | return t_uint16_; |
356 | } else { |
357 | return t_uint8_; |
358 | } |
359 | } |
360 | |
361 | DataType IRBuilder::get_taichi_uint_type(const DataType &dt) const { |
362 | if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 || |
363 | dt == PrimitiveType::f64) { |
364 | return PrimitiveType::u64; |
365 | } else if (dt == PrimitiveType::i32 || dt == PrimitiveType::u32 || |
366 | dt == PrimitiveType::f32) { |
367 | return PrimitiveType::u32; |
368 | } else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 || |
369 | dt == PrimitiveType::f16) { |
370 | return PrimitiveType::u16; |
371 | } else { |
372 | return PrimitiveType::u8; |
373 | } |
374 | } |
375 | |
376 | SType IRBuilder::get_pointer_type(const SType &value_type, |
377 | spv::StorageClass storage_class) { |
378 | auto key = std::make_pair(value_type.id, storage_class); |
379 | auto it = pointer_type_tbl_.find(key); |
380 | if (it != pointer_type_tbl_.end()) { |
381 | return it->second; |
382 | } |
383 | SType t; |
384 | t.id = id_counter_++; |
385 | t.flag = TypeKind::kPtr; |
386 | t.element_type_id = value_type.id; |
387 | t.storage_class = storage_class; |
388 | ib_.begin(spv::OpTypePointer) |
389 | .add_seq(t, storage_class, value_type) |
390 | .commit(&global_); |
391 | pointer_type_tbl_[key] = t; |
392 | return t; |
393 | } |
394 | |
395 | SType IRBuilder::get_underlying_image_type(const SType &primitive_type, |
396 | int num_dimensions) { |
397 | auto key = std::make_pair(primitive_type.id, num_dimensions); |
398 | |
399 | auto it = sampled_image_underlying_image_type_.find(key); |
400 | if (it != sampled_image_underlying_image_type_.end()) { |
401 | return it->second; |
402 | } |
403 | |
404 | int img_id = id_counter_++; |
405 | spv::Dim dim; |
406 | if (num_dimensions == 1) { |
407 | dim = spv::Dim1D; |
408 | } else if (num_dimensions == 2) { |
409 | dim = spv::Dim2D; |
410 | } else if (num_dimensions == 3) { |
411 | dim = spv::Dim3D; |
412 | } else { |
413 | TI_ERROR("Unsupported number of dimensions: {}" , num_dimensions); |
414 | } |
415 | ib_.begin(spv::OpTypeImage) |
416 | .add_seq(img_id, primitive_type, dim, |
417 | /*Depth=*/0, /*Arrayed=*/0, /*MS=*/0, /*Sampled=*/1, |
418 | spv::ImageFormatUnknown) |
419 | .commit(&global_); |
420 | |
421 | SType image_type; |
422 | image_type.id = img_id; |
423 | image_type.flag = TypeKind::kImage; |
424 | sampled_image_underlying_image_type_[key] = image_type; |
425 | |
426 | return image_type; |
427 | } |
428 | |
429 | SType IRBuilder::get_sampled_image_type(const SType &primitive_type, |
430 | int num_dimensions) { |
431 | auto key = std::make_pair(primitive_type.id, num_dimensions); |
432 | auto it = sampled_image_ptr_tbl_.find(key); |
433 | if (it != sampled_image_ptr_tbl_.end()) { |
434 | return it->second; |
435 | } |
436 | |
437 | SType image_type = get_underlying_image_type(primitive_type, num_dimensions); |
438 | int img_id = image_type.id; |
439 | |
440 | SType sampled_t; |
441 | sampled_t.id = id_counter_++; |
442 | sampled_t.flag = TypeKind::kImage; |
443 | ib_.begin(spv::OpTypeSampledImage) |
444 | .add_seq(sampled_t, img_id) |
445 | .commit(&global_); |
446 | sampled_image_ptr_tbl_[key] = sampled_t; |
447 | |
448 | return sampled_t; |
449 | } |
450 | |
451 | SType IRBuilder::get_storage_image_type(BufferFormat format, |
452 | int num_dimensions) { |
453 | auto key = std::make_pair(format, num_dimensions); |
454 | auto it = storage_image_ptr_tbl_.find(key); |
455 | if (it != storage_image_ptr_tbl_.end()) { |
456 | return it->second; |
457 | } |
458 | int img_id = id_counter_++; |
459 | |
460 | spv::Dim dim; |
461 | if (num_dimensions == 1) { |
462 | dim = spv::Dim1D; |
463 | } else if (num_dimensions == 2) { |
464 | dim = spv::Dim2D; |
465 | } else if (num_dimensions == 3) { |
466 | dim = spv::Dim3D; |
467 | } else { |
468 | TI_ERROR("Unsupported number of dimensions: {}" , num_dimensions); |
469 | } |
470 | |
471 | const std::unordered_map<BufferFormat, spv::ImageFormat> format2spv = { |
472 | {BufferFormat::r8, spv::ImageFormatR8}, |
473 | {BufferFormat::rg8, spv::ImageFormatRg8}, |
474 | {BufferFormat::rgba8, spv::ImageFormatRgba8}, |
475 | {BufferFormat::rgba8srgb, spv::ImageFormatRgba8}, |
476 | {BufferFormat::r8u, spv::ImageFormatR8ui}, |
477 | {BufferFormat::rg8u, spv::ImageFormatRg8ui}, |
478 | {BufferFormat::rgba8u, spv::ImageFormatRgba8ui}, |
479 | {BufferFormat::r8i, spv::ImageFormatR8i}, |
480 | {BufferFormat::rg8i, spv::ImageFormatRg8i}, |
481 | {BufferFormat::rgba8i, spv::ImageFormatRgba8i}, |
482 | {BufferFormat::r16, spv::ImageFormatR16}, |
483 | {BufferFormat::rg16, spv::ImageFormatRg16}, |
484 | {BufferFormat::rgba16, spv::ImageFormatRgba16}, |
485 | {BufferFormat::r16u, spv::ImageFormatR16ui}, |
486 | {BufferFormat::rg16u, spv::ImageFormatRg16ui}, |
487 | {BufferFormat::rgba16u, spv::ImageFormatRgba16ui}, |
488 | {BufferFormat::r16i, spv::ImageFormatR16i}, |
489 | {BufferFormat::rg16i, spv::ImageFormatRg16i}, |
490 | {BufferFormat::rgba16i, spv::ImageFormatRgba16i}, |
491 | {BufferFormat::r16f, spv::ImageFormatR16f}, |
492 | {BufferFormat::rg16f, spv::ImageFormatRg16f}, |
493 | {BufferFormat::rgba16f, spv::ImageFormatRgba16f}, |
494 | {BufferFormat::r32u, spv::ImageFormatR32ui}, |
495 | {BufferFormat::rg32u, spv::ImageFormatRg32ui}, |
496 | {BufferFormat::rgba32u, spv::ImageFormatRgba32ui}, |
497 | {BufferFormat::r32i, spv::ImageFormatR32i}, |
498 | {BufferFormat::rg32i, spv::ImageFormatRg32i}, |
499 | {BufferFormat::rgba32i, spv::ImageFormatRgba32i}, |
500 | {BufferFormat::r32f, spv::ImageFormatR32f}, |
501 | {BufferFormat::rg32f, spv::ImageFormatRg32f}, |
502 | {BufferFormat::rgba32f, spv::ImageFormatRgba32f}, |
503 | {BufferFormat::depth16, spv::ImageFormatR16}, |
504 | {BufferFormat::depth32f, spv::ImageFormatR32f}}; |
505 | |
506 | if (format2spv.find(format) == format2spv.end()) { |
507 | TI_ERROR("Unsupported image format" , num_dimensions); |
508 | } |
509 | spv::ImageFormat spv_format = format2spv.at(format); |
510 | |
511 | // TODO: Add integer type support |
512 | ib_.begin(spv::OpTypeImage) |
513 | .add_seq(img_id, f32_type(), dim, |
514 | /*Depth=*/0, /*Arrayed=*/0, /*MS=*/0, /*Sampled=*/2, spv_format) |
515 | .commit(&global_); |
516 | SType img_t; |
517 | img_t.id = img_id; |
518 | img_t.flag = TypeKind::kImage; |
519 | storage_image_ptr_tbl_[key] = img_t; |
520 | return img_t; |
521 | } |
522 | |
523 | SType IRBuilder::get_storage_pointer_type(const SType &value_type) { |
524 | spv::StorageClass storage_class; |
525 | if (caps_->get(cap::spirv_version) < 0x10300) { |
526 | storage_class = spv::StorageClassUniform; |
527 | } else { |
528 | storage_class = spv::StorageClassStorageBuffer; |
529 | } |
530 | |
531 | return get_pointer_type(value_type, storage_class); |
532 | } |
533 | |
534 | SType IRBuilder::get_array_type(const SType &value_type, uint32_t num_elems) { |
535 | SType arr_type; |
536 | arr_type.id = id_counter_++; |
537 | arr_type.flag = TypeKind::kPtr; |
538 | arr_type.element_type_id = value_type.id; |
539 | |
540 | if (num_elems != 0) { |
541 | Value length = uint_immediate_number( |
542 | get_primitive_type(get_data_type<uint32>()), num_elems); |
543 | ib_.begin(spv::OpTypeArray) |
544 | .add_seq(arr_type, value_type, length) |
545 | .commit(&global_); |
546 | } else { |
547 | ib_.begin(spv::OpTypeRuntimeArray) |
548 | .add_seq(arr_type, value_type) |
549 | .commit(&global_); |
550 | } |
551 | |
552 | uint32_t nbytes; |
553 | if (value_type.flag == TypeKind::kPrimitive) { |
554 | const auto nbits = data_type_bits(value_type.dt); |
555 | nbytes = static_cast<uint32_t>(nbits) / 8; |
556 | } else if (value_type.flag == TypeKind::kSNodeStruct) { |
557 | nbytes = value_type.snode_desc.container_stride; |
558 | } else { |
559 | TI_ERROR("buffer type must be primitive or snode struct" ); |
560 | } |
561 | |
562 | if (nbytes == 0) { |
563 | if (value_type.flag == TypeKind::kPrimitive) { |
564 | TI_WARN("Invalid primitive bit size" ); |
565 | } else { |
566 | TI_WARN("Invalid container stride" ); |
567 | } |
568 | } |
569 | |
570 | // decorate the array type |
571 | this->decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); |
572 | |
573 | return arr_type; |
574 | } |
575 | |
576 | SType IRBuilder::get_struct_array_type(const SType &value_type, |
577 | uint32_t num_elems) { |
578 | SType arr_type = get_array_type(value_type, num_elems); |
579 | |
580 | // declare struct of array |
581 | SType struct_type; |
582 | struct_type.id = id_counter_++; |
583 | struct_type.flag = TypeKind::kStruct; |
584 | struct_type.element_type_id = value_type.id; |
585 | ib_.begin(spv::OpTypeStruct).add_seq(struct_type, arr_type).commit(&global_); |
586 | // decorate the array type. |
587 | ib_.begin(spv::OpMemberDecorate) |
588 | .add_seq(struct_type, 0, spv::DecorationOffset, 0) |
589 | .commit(&decorate_); |
590 | |
591 | if (caps_->get(cap::spirv_version) < 0x10300) { |
592 | // NOTE: BufferBlock was deprecated in SPIRV 1.3 |
593 | // use StorageClassStorageBuffer instead. |
594 | // runtime array are always decorated as BufferBlock(shader storage buffer) |
595 | if (num_elems == 0) { |
596 | this->decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); |
597 | } |
598 | } else { |
599 | this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); |
600 | } |
601 | |
602 | return struct_type; |
603 | } |
604 | |
605 | SType IRBuilder::create_struct_type( |
606 | std::vector<std::tuple<SType, std::string, size_t>> &components) { |
607 | SType struct_type; |
608 | struct_type.id = id_counter_++; |
609 | struct_type.flag = TypeKind::kStruct; |
610 | |
611 | auto &builder = ib_.begin(spv::OpTypeStruct).add_seq(struct_type); |
612 | |
613 | for (auto &[type, name, offset] : components) { |
614 | builder.add_seq(type); |
615 | } |
616 | |
617 | builder.commit(&global_); |
618 | |
619 | int i = 0; |
620 | for (auto &[type, name, offset] : components) { |
621 | this->decorate(spv::OpMemberDecorate, struct_type, i, spv::DecorationOffset, |
622 | offset); |
623 | this->debug_name(spv::OpMemberName, struct_type, i, name); |
624 | i++; |
625 | } |
626 | |
627 | return struct_type; |
628 | } |
629 | |
630 | Value IRBuilder::buffer_struct_argument(const SType &struct_type, |
631 | uint32_t descriptor_set, |
632 | uint32_t binding, |
633 | const std::string &name) { |
634 | // NOTE: BufferBlock was deprecated in SPIRV 1.3 |
635 | // use StorageClassStorageBuffer instead. |
636 | spv::StorageClass storage_class; |
637 | if (caps_->get(cap::spirv_version) < 0x10300) { |
638 | storage_class = spv::StorageClassUniform; |
639 | } else { |
640 | storage_class = spv::StorageClassStorageBuffer; |
641 | } |
642 | |
643 | this->debug_name(spv::OpName, struct_type, name + "_t" ); |
644 | |
645 | if (caps_->get(cap::spirv_version) < 0x10300) { |
646 | // NOTE: BufferBlock was deprecated in SPIRV 1.3 |
647 | // use StorageClassStorageBuffer instead. |
648 | // runtime array are always decorated as BufferBlock(shader storage buffer) |
649 | this->decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); |
650 | } else { |
651 | this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); |
652 | } |
653 | |
654 | SType ptr_type = get_pointer_type(struct_type, storage_class); |
655 | |
656 | this->debug_name(spv::OpName, ptr_type, name + "_ptr" ); |
657 | |
658 | Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); |
659 | ib_.begin(spv::OpVariable) |
660 | .add_seq(ptr_type, val, storage_class) |
661 | .commit(&global_); |
662 | |
663 | this->debug_name(spv::OpName, val, name); |
664 | |
665 | this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, |
666 | descriptor_set); |
667 | this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); |
668 | return val; |
669 | } |
670 | |
671 | Value IRBuilder::uniform_struct_argument(const SType &struct_type, |
672 | uint32_t descriptor_set, |
673 | uint32_t binding, |
674 | const std::string &name) { |
675 | // NOTE: BufferBlock was deprecated in SPIRV 1.3 |
676 | // use StorageClassStorageBuffer instead. |
677 | spv::StorageClass storage_class = spv::StorageClassUniform; |
678 | |
679 | this->debug_name(spv::OpName, struct_type, name + "_t" ); |
680 | |
681 | this->decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); |
682 | |
683 | SType ptr_type = get_pointer_type(struct_type, storage_class); |
684 | |
685 | this->debug_name(spv::OpName, ptr_type, name + "_ptr" ); |
686 | |
687 | Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); |
688 | ib_.begin(spv::OpVariable) |
689 | .add_seq(ptr_type, val, storage_class) |
690 | .commit(&global_); |
691 | |
692 | this->debug_name(spv::OpName, val, name); |
693 | |
694 | this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, |
695 | descriptor_set); |
696 | this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); |
697 | return val; |
698 | } |
699 | |
700 | Value IRBuilder::buffer_argument(const SType &value_type, |
701 | uint32_t descriptor_set, |
702 | uint32_t binding, |
703 | const std::string &name) { |
704 | // NOTE: BufferBlock was deprecated in SPIRV 1.3 |
705 | // use StorageClassStorageBuffer instead. |
706 | spv::StorageClass storage_class; |
707 | if (caps_->get(cap::spirv_version) < 0x10300) { |
708 | storage_class = spv::StorageClassUniform; |
709 | } else { |
710 | storage_class = spv::StorageClassStorageBuffer; |
711 | } |
712 | |
713 | SType sarr_type = get_struct_array_type(value_type, 0); |
714 | |
715 | auto typed_name = name + "_" + value_type.dt.to_string(); |
716 | |
717 | this->debug_name(spv::OpName, sarr_type, typed_name + "_struct_array" ); |
718 | |
719 | SType ptr_type = get_pointer_type(sarr_type, storage_class); |
720 | |
721 | this->debug_name(spv::OpName, sarr_type, typed_name + "_ptr" ); |
722 | |
723 | Value val = new_value(ptr_type, ValueKind::kStructArrayPtr); |
724 | ib_.begin(spv::OpVariable) |
725 | .add_seq(ptr_type, val, storage_class) |
726 | .commit(&global_); |
727 | |
728 | this->debug_name(spv::OpName, val, typed_name); |
729 | |
730 | this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, |
731 | descriptor_set); |
732 | this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); |
733 | return val; |
734 | } |
735 | |
736 | Value IRBuilder::struct_array_access(const SType &res_type, |
737 | Value buffer, |
738 | Value index) { |
739 | TI_ASSERT(buffer.flag == ValueKind::kStructArrayPtr); |
740 | TI_ASSERT(res_type.flag == TypeKind::kPrimitive); |
741 | |
742 | spv::StorageClass storage_class; |
743 | if (caps_->get(cap::spirv_version) < 0x10300) { |
744 | storage_class = spv::StorageClassUniform; |
745 | } else { |
746 | storage_class = spv::StorageClassStorageBuffer; |
747 | } |
748 | |
749 | SType ptr_type = this->get_pointer_type(res_type, storage_class); |
750 | Value ret = new_value(ptr_type, ValueKind::kVariablePtr); |
751 | ib_.begin(spv::OpAccessChain) |
752 | .add_seq(ptr_type, ret, buffer, const_i32_zero_, index) |
753 | .commit(&function_); |
754 | |
755 | return ret; |
756 | } |
757 | |
758 | Value IRBuilder::texture_argument(int num_channels, |
759 | int num_dimensions, |
760 | uint32_t descriptor_set, |
761 | uint32_t binding) { |
762 | auto texture_type = this->get_sampled_image_type(f32_type(), num_dimensions); |
763 | auto texture_ptr_type = |
764 | get_pointer_type(texture_type, spv::StorageClassUniformConstant); |
765 | |
766 | Value val = new_value(texture_ptr_type, ValueKind::kVariablePtr); |
767 | ib_.begin(spv::OpVariable) |
768 | .add_seq(texture_ptr_type, val, spv::StorageClassUniformConstant) |
769 | .commit(&global_); |
770 | |
771 | this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, |
772 | descriptor_set); |
773 | this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); |
774 | |
775 | this->debug_name(spv::OpName, val, "tex" ); |
776 | |
777 | this->global_values.push_back(val); |
778 | |
779 | return val; |
780 | } |
781 | |
782 | Value IRBuilder::storage_image_argument(int num_channels, |
783 | int num_dimensions, |
784 | uint32_t descriptor_set, |
785 | uint32_t binding, |
786 | BufferFormat format) { |
787 | auto texture_type = this->get_storage_image_type(format, num_dimensions); |
788 | auto texture_ptr_type = |
789 | get_pointer_type(texture_type, spv::StorageClassUniformConstant); |
790 | |
791 | Value val = new_value(texture_type, ValueKind::kVariablePtr); |
792 | ib_.begin(spv::OpVariable) |
793 | .add_seq(texture_ptr_type, val, spv::StorageClassUniformConstant) |
794 | .commit(&global_); |
795 | |
796 | this->decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, |
797 | descriptor_set); |
798 | this->decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); |
799 | |
800 | this->debug_name(spv::OpName, val, "tex" ); |
801 | |
802 | this->global_values.push_back(val); |
803 | |
804 | return val; |
805 | } |
806 | |
807 | Value IRBuilder::sample_texture(Value texture_var, |
808 | const std::vector<Value> &args, |
809 | Value lod) { |
810 | auto image = this->load_variable( |
811 | texture_var, this->get_sampled_image_type(f32_type(), args.size())); |
812 | Value uv; |
813 | if (args.size() == 1) { |
814 | uv = args[0]; |
815 | } else if (args.size() == 2) { |
816 | uv = make_value(spv::OpCompositeConstruct, t_v2_fp32_, args[0], args[1]); |
817 | } else if (args.size() == 3) { |
818 | uv = make_value(spv::OpCompositeConstruct, t_v3_fp32_, args[0], args[1], |
819 | args[2]); |
820 | } else { |
821 | TI_ERROR("Unsupported number of texture coordinates" ); |
822 | } |
823 | uint32_t lod_operand = 0x2; |
824 | auto res_vec4 = make_value(spv::OpImageSampleExplicitLod, t_v4_fp32_, image, |
825 | uv, lod_operand, lod); |
826 | return res_vec4; |
827 | } |
828 | |
829 | Value IRBuilder::fetch_texel(Value texture_var, |
830 | const std::vector<Value> &args, |
831 | Value lod) { |
832 | auto sampled_image = this->load_variable( |
833 | texture_var, this->get_sampled_image_type(f32_type(), args.size())); |
834 | |
835 | // OpImageFetch requires operand with OpImageType |
836 | // We have to extract the underlying OpImage from OpSampledImage here |
837 | SType image_type = get_underlying_image_type(f32_type(), args.size()); |
838 | Value image_val = make_value(spv::OpImage, image_type, sampled_image); |
839 | |
840 | Value uv; |
841 | if (args.size() == 1) { |
842 | uv = args[0]; |
843 | } else if (args.size() == 2) { |
844 | uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]); |
845 | } else if (args.size() == 3) { |
846 | uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1], |
847 | args[2]); |
848 | } else { |
849 | TI_ERROR("Unsupported number of texture coordinates" ); |
850 | } |
851 | uint32_t lod_operand = 0x2; |
852 | auto res_vec4 = make_value(spv::OpImageFetch, t_v4_fp32_, image_val, uv, |
853 | lod_operand, lod); |
854 | return res_vec4; |
855 | } |
856 | |
857 | Value IRBuilder::image_load(Value image_var, const std::vector<Value> &args) { |
858 | auto image = this->load_variable(image_var, image_var.stype); |
859 | Value uv; |
860 | if (args.size() == 1) { |
861 | uv = args[0]; |
862 | } else if (args.size() == 2) { |
863 | uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]); |
864 | } else if (args.size() == 3) { |
865 | uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1], |
866 | args[2]); |
867 | } else { |
868 | TI_ERROR("Unsupported number of texture coordinates" ); |
869 | } |
870 | auto res_vec4 = make_value(spv::OpImageRead, t_v4_fp32_, image, uv); |
871 | return res_vec4; |
872 | } |
873 | |
874 | void IRBuilder::image_store(Value image_var, const std::vector<Value> &args) { |
875 | auto image = this->load_variable(image_var, image_var.stype); |
876 | Value uv; |
877 | if (args.size() == 1 + 4) { |
878 | uv = args[0]; |
879 | } else if (args.size() == 2 + 4) { |
880 | uv = make_value(spv::OpCompositeConstruct, t_v2_int_, args[0], args[1]); |
881 | } else if (args.size() == 3 + 4) { |
882 | uv = make_value(spv::OpCompositeConstruct, t_v3_int_, args[0], args[1], |
883 | args[2]); |
884 | } else { |
885 | TI_ERROR("Unsupported number of image coordinates" ); |
886 | } |
887 | int base = args.size() - 4; |
888 | Value data = make_value(spv::OpCompositeConstruct, t_v4_fp32_, args[base], |
889 | args[base + 1], args[base + 2], args[base + 3]); |
890 | make_inst(spv::OpImageWrite, image, uv, data); |
891 | } |
892 | |
893 | void IRBuilder::set_work_group_size(const std::array<int, 3> group_size) { |
894 | Value size_x = |
895 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[0])); |
896 | Value size_y = |
897 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[1])); |
898 | Value size_z = |
899 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(group_size[2])); |
900 | |
901 | if (gl_work_group_size_.id == 0) { |
902 | gl_work_group_size_.id = id_counter_++; |
903 | } |
904 | ib_.begin(spv::OpConstantComposite) |
905 | .add_seq(t_v3_uint_, gl_work_group_size_, size_x, size_y, size_z) |
906 | .commit(&global_); |
907 | this->decorate(spv::OpDecorate, gl_work_group_size_, spv::DecorationBuiltIn, |
908 | spv::BuiltInWorkgroupSize); |
909 | } |
910 | |
911 | Value IRBuilder::get_num_work_groups(uint32_t dim_index) { |
912 | if (gl_num_work_groups_.id == 0) { |
913 | SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput); |
914 | gl_num_work_groups_ = new_value(ptr_type, ValueKind::kVectorPtr); |
915 | ib_.begin(spv::OpVariable) |
916 | .add_seq(ptr_type, gl_num_work_groups_, spv::StorageClassInput) |
917 | .commit(&global_); |
918 | this->decorate(spv::OpDecorate, gl_num_work_groups_, spv::DecorationBuiltIn, |
919 | spv::BuiltInNumWorkgroups); |
920 | } |
921 | SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); |
922 | Value ptr = this->make_value( |
923 | spv::OpAccessChain, pint_type, gl_num_work_groups_, |
924 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index))); |
925 | |
926 | return this->make_value(spv::OpLoad, t_uint32_, ptr); |
927 | } |
928 | |
929 | Value IRBuilder::get_local_invocation_id(uint32_t dim_index) { |
930 | if (gl_local_invocation_id_.id == 0) { |
931 | SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput); |
932 | gl_local_invocation_id_ = new_value(ptr_type, ValueKind::kVectorPtr); |
933 | ib_.begin(spv::OpVariable) |
934 | .add_seq(ptr_type, gl_local_invocation_id_, spv::StorageClassInput) |
935 | .commit(&global_); |
936 | this->decorate(spv::OpDecorate, gl_local_invocation_id_, |
937 | spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId); |
938 | } |
939 | SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); |
940 | Value ptr = this->make_value( |
941 | spv::OpAccessChain, pint_type, gl_local_invocation_id_, |
942 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index))); |
943 | |
944 | return this->make_value(spv::OpLoad, t_uint32_, ptr); |
945 | } |
946 | |
947 | Value IRBuilder::get_global_invocation_id(uint32_t dim_index) { |
948 | if (gl_global_invocation_id_.id == 0) { |
949 | SType ptr_type = this->get_pointer_type(t_v3_uint_, spv::StorageClassInput); |
950 | gl_global_invocation_id_ = new_value(ptr_type, ValueKind::kVectorPtr); |
951 | ib_.begin(spv::OpVariable) |
952 | .add_seq(ptr_type, gl_global_invocation_id_, spv::StorageClassInput) |
953 | .commit(&global_); |
954 | this->decorate(spv::OpDecorate, gl_global_invocation_id_, |
955 | spv::DecorationBuiltIn, spv::BuiltInGlobalInvocationId); |
956 | } |
957 | SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); |
958 | Value ptr = this->make_value( |
959 | spv::OpAccessChain, pint_type, gl_global_invocation_id_, |
960 | uint_immediate_number(t_uint32_, static_cast<uint64_t>(dim_index))); |
961 | |
962 | return this->make_value(spv::OpLoad, t_uint32_, ptr); |
963 | } |
964 | |
965 | Value IRBuilder::get_subgroup_invocation_id() { |
966 | if (subgroup_local_invocation_id_.id == 0) { |
967 | SType ptr_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); |
968 | subgroup_local_invocation_id_ = |
969 | new_value(ptr_type, ValueKind::kVariablePtr); |
970 | ib_.begin(spv::OpVariable) |
971 | .add_seq(ptr_type, subgroup_local_invocation_id_, |
972 | spv::StorageClassInput) |
973 | .commit(&global_); |
974 | this->decorate(spv::OpDecorate, subgroup_local_invocation_id_, |
975 | spv::DecorationBuiltIn, |
976 | spv::BuiltInSubgroupLocalInvocationId); |
977 | global_values.push_back(subgroup_local_invocation_id_); |
978 | } |
979 | |
980 | return this->make_value(spv::OpLoad, t_uint32_, |
981 | subgroup_local_invocation_id_); |
982 | } |
983 | |
984 | Value IRBuilder::get_subgroup_size() { |
985 | if (subgroup_size_.id == 0) { |
986 | SType ptr_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput); |
987 | subgroup_size_ = new_value(ptr_type, ValueKind::kVariablePtr); |
988 | ib_.begin(spv::OpVariable) |
989 | .add_seq(ptr_type, subgroup_size_, spv::StorageClassInput) |
990 | .commit(&global_); |
991 | this->decorate(spv::OpDecorate, subgroup_size_, spv::DecorationBuiltIn, |
992 | spv::BuiltInSubgroupSize); |
993 | global_values.push_back(subgroup_size_); |
994 | } |
995 | |
996 | return this->make_value(spv::OpLoad, t_uint32_, subgroup_size_); |
997 | } |
998 | |
// Defines IRBuilder::<_OpName>(a, b) for binary ops whose SPIR-V opcode does
// not depend on signedness: integral operands emit OpI<_Op>, real operands
// emit OpF<_Op>. Both operands must have the same SPIR-V type.
#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op)       \
  Value IRBuilder::_OpName(Value a, Value b) {             \
    TI_ASSERT(a.stype.id == b.stype.id);                   \
    if (is_integral(a.stype.dt)) {                         \
      return make_value(spv::OpI##_Op, a.stype, a, b);     \
    } else {                                               \
      TI_ASSERT(is_real(a.stype.dt));                      \
      return make_value(spv::OpF##_Op, a.stype, a, b);     \
    }                                                      \
  }
1009 | |
// Defines IRBuilder::<_OpName>(a, b) for binary ops whose SPIR-V opcode
// depends on signedness: signed ints emit OpS<_Op>, unsigned ints OpU<_Op>,
// reals OpF<_Op>. Both operands must have the same SPIR-V type.
#define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op)        \
  Value IRBuilder::_OpName(Value a, Value b) {             \
    TI_ASSERT(a.stype.id == b.stype.id);                   \
    if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) { \
      return make_value(spv::OpS##_Op, a.stype, a, b);     \
    } else if (is_integral(a.stype.dt)) {                  \
      return make_value(spv::OpU##_Op, a.stype, a, b);     \
    } else {                                               \
      TI_ASSERT(is_real(a.stype.dt));                      \
      return make_value(spv::OpF##_Op, a.stype, a, b);     \
    }                                                      \
  }
1022 | |
// Arithmetic builders: add/sub/mul dispatch on int vs. float only; div also
// distinguishes signed (OpSDiv) from unsigned (OpUDiv) integer division.
DEFINE_BUILDER_BINARY_USIGN_OP(add, Add);
DEFINE_BUILDER_BINARY_USIGN_OP(sub, Sub);
DEFINE_BUILDER_BINARY_USIGN_OP(mul, Mul);
DEFINE_BUILDER_BINARY_SIGN_OP(div, Div);
1027 | |
// Emits a remainder operation matched to the operand type:
//   - signed int:   emulated as a - b * (a / b) (see FIXME below),
//   - unsigned int: OpUMod,
//   - float:        OpFRem (result takes the sign of the dividend).
// Both operands must have the same SPIR-V type.
Value IRBuilder::mod(Value a, Value b) {
  TI_ASSERT(a.stype.id == b.stype.id);
  if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) {
    // FIXME: figure out why OpSRem does not work
    return sub(a, mul(b, div(a, b)));
  } else if (is_integral(a.stype.dt)) {
    return make_value(spv::OpUMod, a.stype, a, b);
  } else {
    TI_ASSERT(is_real(a.stype.dt));
    return make_value(spv::OpFRem, a.stype, a, b);
  }
}
1040 | |
// Defines IRBuilder::<_OpName>(a, b) for ordering comparisons returning a
// scalar bool: signed ints emit OpS<_Op>, unsigned ints OpU<_Op>, reals the
// ordered float compare OpFOrd<_Op>. Operand types must match.
#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                              \
  Value IRBuilder::_OpName(Value a, Value b) {                           \
    TI_ASSERT(a.stype.id == b.stype.id);                                 \
    const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
    if (is_integral(a.stype.dt) && is_signed(a.stype.dt)) {              \
      return make_value(spv::OpS##_Op, bool_type, a, b);                 \
    } else if (is_integral(a.stype.dt)) {                                \
      return make_value(spv::OpU##_Op, bool_type, a, b);                 \
    } else {                                                             \
      TI_ASSERT(is_real(a.stype.dt));                                    \
      return make_value(spv::OpFOrd##_Op, bool_type, a, b);              \
    }                                                                    \
  }
1054 | |
// Ordering comparisons (<, <=, >, >=), all yielding a scalar bool.
DEFINE_BUILDER_CMP_OP(lt, LessThan);
DEFINE_BUILDER_CMP_OP(le, LessThanEqual);
DEFINE_BUILDER_CMP_OP(gt, GreaterThan);
DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
1059 | |
// Defines IRBuilder::<_OpName>(a, b) for (in)equality comparisons returning a
// scalar bool: ints emit OpI<_Op>, bools OpLogical<_Op>, reals the ordered
// float compare OpFOrd<_Op>. Operand types must match.
#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                             \
  Value IRBuilder::_OpName(Value a, Value b) {                           \
    TI_ASSERT(a.stype.id == b.stype.id);                                 \
    const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
    if (is_integral(a.stype.dt)) {                                       \
      return make_value(spv::OpI##_Op, bool_type, a, b);                 \
    } else if (a.stype.id == bool_type.id) {                             \
      return make_value(spv::OpLogical##_Op, bool_type, a, b);           \
    } else {                                                             \
      TI_ASSERT(is_real(a.stype.dt));                                    \
      return make_value(spv::OpFOrd##_Op, bool_type, a, b);              \
    }                                                                    \
  }
1073 | |
// Equality comparisons (==, !=), yielding a scalar bool.
DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);
1076 | |
1077 | Value IRBuilder::(Value base, Value offset, Value count) { |
1078 | TI_ASSERT(is_integral(base.stype.dt)); |
1079 | TI_ASSERT(is_integral(offset.stype.dt)); |
1080 | TI_ASSERT(is_integral(count.stype.dt)); |
1081 | return make_value(spv::OpBitFieldUExtract, base.stype, base, offset, count); |
1082 | } |
1083 | |
1084 | Value IRBuilder::select(Value cond, Value a, Value b) { |
1085 | TI_ASSERT(a.stype.id == b.stype.id); |
1086 | TI_ASSERT(cond.stype.id == t_bool_.id); |
1087 | return make_value(spv::OpSelect, a.stype, cond, a, b); |
1088 | } |
1089 | |
// Emits a SPIR-V conversion of `value` to `dst_type`. Covers:
//   bool <-> int/uint/float, int <-> int (width and/or signedness change),
//   int <-> float, and float <-> float.
// Unsupported pairs raise TI_ERROR. If the types already match, `value` is
// returned unchanged.
Value IRBuilder::cast(const SType &dst_type, Value value) {
  TI_ASSERT(value.stype.id > 0U);
  if (value.stype.id == dst_type.id)
    return value;
  const DataType &from = value.stype.dt;
  const DataType &to = dst_type.dt;
  if (from->is_primitive(PrimitiveTypeID::u1)) {  // Bool
    // Bools have no numeric representation in SPIR-V; convert via OpSelect.
    if (is_integral(to) && is_signed(to)) {  // Bool -> Int
      return select(value, int_immediate_number(dst_type, 1),
                    int_immediate_number(dst_type, 0));
    } else if (is_integral(to) && is_unsigned(to)) {  // Bool -> UInt
      return select(value, uint_immediate_number(dst_type, 1),
                    uint_immediate_number(dst_type, 0));
    } else if (is_real(to)) {  // Bool -> Float
      return make_value(spv::OpConvertUToF, dst_type,
                        select(value, uint_immediate_number(t_uint32_, 1),
                               uint_immediate_number(t_uint32_, 0)));
    } else {
      TI_ERROR("do not support type cast from {} to {}" , from.to_string(),
               to.to_string());
      return Value();
    }
  } else if (to->is_primitive(PrimitiveTypeID::u1)) {  // Bool
    // Numeric -> bool: true iff the value is non-zero.
    if (is_integral(from) && is_signed(from)) {  // Int -> Bool
      return ne(value, int_immediate_number(value.stype, 0));
    } else if (is_integral(from) && is_unsigned(from)) {  // UInt -> Bool
      return ne(value, uint_immediate_number(value.stype, 0));
    } else {
      TI_ERROR("do not support type cast from {} to {}" , from.to_string(),
               to.to_string());
      return Value();
    }
  } else if (is_integral(from) && is_integral(to)) {
    auto ret = value;

    if (data_type_bits(from) == data_type_bits(to)) {
      // Same width conversion
      ret = make_value(spv::OpBitcast, dst_type, ret);
    } else {
      // Different width
      // Step 1. Sign extend / truncate value to width of `to`
      // Step 2. Bitcast to signess of `to`
      auto get_signed_type = [](DataType dt) -> DataType {
        // Create a output signed type with the same width as `dt`
        if (data_type_bits(dt) == 8)
          return PrimitiveType::i8;
        else if (data_type_bits(dt) == 16)
          return PrimitiveType::i16;
        else if (data_type_bits(dt) == 32)
          return PrimitiveType::i32;
        else if (data_type_bits(dt) == 64)
          return PrimitiveType::i64;
        else
          return PrimitiveType::unknown;
      };
      auto get_unsigned_type = [](DataType dt) -> DataType {
        // Create a output unsigned type with the same width as `dt`
        if (data_type_bits(dt) == 8)
          return PrimitiveType::u8;
        else if (data_type_bits(dt) == 16)
          return PrimitiveType::u16;
        else if (data_type_bits(dt) == 32)
          return PrimitiveType::u32;
        else if (data_type_bits(dt) == 64)
          return PrimitiveType::u64;
        else
          return PrimitiveType::unknown;
      };

      // Resize while keeping the source signedness so OpSConvert sign-extends
      // and OpUConvert zero-extends; signedness is fixed up by the bitcast.
      if (is_signed(from)) {
        ret = make_value(spv::OpSConvert,
                         get_primitive_type(get_signed_type(to)), ret);
      } else {
        ret = make_value(spv::OpUConvert,
                         get_primitive_type(get_unsigned_type(to)), ret);
      }

      ret = make_value(spv::OpBitcast, dst_type, ret);
    }

    return ret;
  } else if (is_real(from) && is_integral(to) &&
             is_signed(to)) {  // Float -> Int
    return make_value(spv::OpConvertFToS, dst_type, value);
  } else if (is_real(from) && is_integral(to) &&
             is_unsigned(to)) {  // Float -> UInt
    return make_value(spv::OpConvertFToU, dst_type, value);
  } else if (is_integral(from) && is_signed(from) &&
             is_real(to)) {  // Int -> Float
    return make_value(spv::OpConvertSToF, dst_type, value);
  } else if (is_integral(from) && is_unsigned(from) &&
             is_real(to)) {  // UInt -> Float
    return make_value(spv::OpConvertUToF, dst_type, value);
  } else if (is_real(from) && is_real(to)) {  // Float -> Float
    return make_value(spv::OpFConvert, dst_type, value);
  } else {
    TI_ERROR("do not support type cast from {} to {}" , from.to_string(),
             to.to_string());
    return Value();
  }
}
1191 | |
1192 | Value IRBuilder::alloca_variable(const SType &type) { |
1193 | SType ptr_type = get_pointer_type(type, spv::StorageClassFunction); |
1194 | Value ret = new_value(ptr_type, ValueKind::kVariablePtr); |
1195 | ib_.begin(spv::OpVariable) |
1196 | .add_seq(ptr_type, ret, spv::StorageClassFunction) |
1197 | .commit(&func_header_); |
1198 | return ret; |
1199 | } |
1200 | |
1201 | Value IRBuilder::alloca_workgroup_array(const SType &arr_type) { |
1202 | SType ptr_type = get_pointer_type(arr_type, spv::StorageClassWorkgroup); |
1203 | Value ret = new_value(ptr_type, ValueKind::kVariablePtr); |
1204 | ib_.begin(spv::OpVariable) |
1205 | .add_seq(ptr_type, ret, spv::StorageClassWorkgroup) |
1206 | .commit(&global_); |
1207 | return ret; |
1208 | } |
1209 | |
1210 | Value IRBuilder::load_variable(Value pointer, const SType &res_type) { |
1211 | TI_ASSERT(pointer.flag == ValueKind::kVariablePtr || |
1212 | pointer.flag == ValueKind::kStructArrayPtr || |
1213 | pointer.flag == ValueKind::kPhysicalPtr); |
1214 | Value ret = new_value(res_type, ValueKind::kNormal); |
1215 | if (pointer.flag == ValueKind::kPhysicalPtr) { |
1216 | uint32_t alignment = uint32_t(get_primitive_type_size(res_type.dt)); |
1217 | ib_.begin(spv::OpLoad) |
1218 | .add_seq(res_type, ret, pointer, spv::MemoryAccessAlignedMask, |
1219 | alignment) |
1220 | .commit(&function_); |
1221 | } else { |
1222 | ib_.begin(spv::OpLoad).add_seq(res_type, ret, pointer).commit(&function_); |
1223 | } |
1224 | return ret; |
1225 | } |
1226 | void IRBuilder::store_variable(Value pointer, Value value) { |
1227 | TI_ASSERT(pointer.flag == ValueKind::kVariablePtr || |
1228 | pointer.flag == ValueKind::kPhysicalPtr); |
1229 | TI_ASSERT(value.stype.id == pointer.stype.element_type_id); |
1230 | if (pointer.flag == ValueKind::kPhysicalPtr) { |
1231 | uint32_t alignment = uint32_t(get_primitive_type_size(value.stype.dt)); |
1232 | ib_.begin(spv::OpStore) |
1233 | .add_seq(pointer, value, spv::MemoryAccessAlignedMask, alignment) |
1234 | .commit(&function_); |
1235 | } else { |
1236 | ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&function_); |
1237 | } |
1238 | } |
1239 | |
1240 | void IRBuilder::register_value(std::string name, Value value) { |
1241 | auto it = value_name_tbl_.find(name); |
1242 | if (it != value_name_tbl_.end() && it->second.flag != ValueKind::kConstant) { |
1243 | TI_ERROR("{} already exists." , name); |
1244 | } |
1245 | this->debug_name( |
1246 | spv::OpName, value, |
1247 | fmt::format("{}_{}" , name, value.stype.dt.to_string())); // Debug info |
1248 | value_name_tbl_[name] = value; |
1249 | } |
1250 | |
1251 | Value IRBuilder::query_value(std::string name) const { |
1252 | auto it = value_name_tbl_.find(name); |
1253 | if (it != value_name_tbl_.end()) { |
1254 | return it->second; |
1255 | } |
1256 | TI_ERROR("Value \"{}\" does not yet exist." , name); |
1257 | } |
1258 | |
1259 | bool IRBuilder::check_value_existence(const std::string &name) const { |
1260 | return value_name_tbl_.find(name) != value_name_tbl_.end(); |
1261 | } |
1262 | |
// Emulates a 32-bit floating-point atomic (add/sub/min/max) on `addr_ptr`
// using an OpAtomicCompareExchange retry loop: load the old bits, compute the
// new float value, CAS; retry until the CAS observes the value unchanged.
// Returns the float value held at the address just before the successful
// exchange (the "old" value, as atomics conventionally return).
Value IRBuilder::float_atomic(AtomicOpType op_type,
                              Value addr_ptr,
                              Value data) {
  auto atomic_func_ = [&](std::function<Value(Value, Value)> atomic_op) {
    // Holds the integer bits returned by the last CAS attempt.
    Value ret_val_int = alloca_variable(t_uint32_);

    // do-while
    Label head = new_label();
    Label body = new_label();
    Label branch_true = new_label();
    Label branch_false = new_label();
    Label merge = new_label();
    Label exit = new_label();

    make_inst(spv::OpBranch, head);
    start_label(head);
    // Structured-loop header: merge block and continue target for the CFG.
    make_inst(spv::OpLoopMerge, branch_true, merge, 0);
    make_inst(spv::OpBranch, body);
    make_inst(spv::OpLabel, body);
    // while (true)
    {
      // int old = addr_ptr[0];
      Value old_val = load_variable(addr_ptr, t_uint32_);
      // int new = floatBitsToInt(atomic_op(intBitsToFloat(old), data));
      Value old_float = make_value(spv::OpBitcast, t_fp32_, old_val);
      Value new_float = atomic_op(old_float, data);
      Value new_val = make_value(spv::OpBitcast, t_uint32_, new_float);
      // int loaded = atomicCompSwap(vals[0], old, new);
      /*
       * Don't need this part, theoretically
      auto semantics = uint_immediate_number(
          t_uint32_, spv::MemorySemanticsAcquireReleaseMask |
                         spv::MemorySemanticsUniformMemoryMask);
      make_inst(spv::OpMemoryBarrier, const_i32_one_, semantics);
      */
      Value loaded = make_value(
          spv::OpAtomicCompareExchange, t_uint32_, addr_ptr,
          /*scope=*/const_i32_one_, /*semantics if equal=*/const_i32_zero_,
          /*semantics if unequal=*/const_i32_zero_, new_val, old_val);
      // bool ok = (loaded == old);
      Value ok = make_value(spv::OpIEqual, t_bool_, loaded, old_val);
      // int ret_val_int = loaded;
      store_variable(ret_val_int, loaded);
      // if (ok)
      make_inst(spv::OpSelectionMerge, branch_false, 0);
      make_inst(spv::OpBranchConditional, ok, branch_true, branch_false);
      {
        // CAS succeeded: leave the loop.
        make_inst(spv::OpLabel, branch_true);
        make_inst(spv::OpBranch, exit);
      }
      // else
      {
        // CAS lost the race: fall through to the continue block.
        make_inst(spv::OpLabel, branch_false);
        make_inst(spv::OpBranch, merge);
      }
      // continue;
      make_inst(spv::OpLabel, merge);
      make_inst(spv::OpBranch, head);
    }
    start_label(exit);

    // Reinterpret the winning integer bits back as a float.
    return make_value(spv::OpBitcast, t_fp32_,
                      load_variable(ret_val_int, t_uint32_));
  };

  if (op_type == AtomicOpType::add) {
    return atomic_func_([&](Value lhs, Value rhs) { return add(lhs, rhs); });
  } else if (op_type == AtomicOpType::sub) {
    return atomic_func_([&](Value lhs, Value rhs) { return sub(lhs, rhs); });
  } else if (op_type == AtomicOpType::min) {
    return atomic_func_([&](Value lhs, Value rhs) {
      return call_glsl450(t_fp32_, /*FMin*/ 37, lhs, rhs);
    });
  } else if (op_type == AtomicOpType::max) {
    return atomic_func_([&](Value lhs, Value rhs) {
      return call_glsl450(t_fp32_, /*FMax*/ 40, lhs, rhs);
    });
  } else {
    TI_NOT_IMPLEMENTED
  }
}
1344 | |
// Emits one step of the xorshift-style PRNG over the per-thread state
// variables rand_x_..rand_w_ and returns a fresh pseudo-random u32:
//   t = x ^ (x << 11);  x = y;  y = z;  z = w;
//   w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
// The returned value is w * 1000000007 to further scramble the bits.
Value IRBuilder::rand_u32(Value global_tmp_) {
  // Lazily emit the state variables and seeding code on first use.
  if (!init_rand_) {
    init_random_function(global_tmp_);
  }

  Value _11u = uint_immediate_number(t_uint32_, 11u);
  Value _19u = uint_immediate_number(t_uint32_, 19u);
  Value _8u = uint_immediate_number(t_uint32_, 8u);
  Value _1000000007u = uint_immediate_number(t_uint32_, 1000000007u);
  Value tmp0 = load_variable(rand_x_, t_uint32_);
  Value tmp1 = make_value(spv::OpShiftLeftLogical, t_uint32_, tmp0, _11u);
  Value tmp_t = make_value(spv::OpBitwiseXor, t_uint32_, tmp0, tmp1);  // t
  // Rotate the state: x <- y, y <- z, z <- w.
  store_variable(rand_x_, load_variable(rand_y_, t_uint32_));
  store_variable(rand_y_, load_variable(rand_z_, t_uint32_));
  Value tmp_w = load_variable(rand_w_, t_uint32_);  // reuse w
  store_variable(rand_z_, tmp_w);
  Value tmp2 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_w, _19u);
  Value tmp3 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_w, tmp2);
  Value tmp4 = make_value(spv::OpShiftRightLogical, t_uint32_, tmp_t, _8u);
  Value tmp5 = make_value(spv::OpBitwiseXor, t_uint32_, tmp_t, tmp4);
  Value new_w = make_value(spv::OpBitwiseXor, t_uint32_, tmp3, tmp5);
  store_variable(rand_w_, new_w);
  Value val = make_value(spv::OpIMul, t_uint32_, new_w, _1000000007u);

  return val;
}
1371 | |
1372 | Value IRBuilder::rand_f32(Value global_tmp_) { |
1373 | if (!init_rand_) { |
1374 | init_random_function(global_tmp_); |
1375 | } |
1376 | |
1377 | Value _1_4294967296f = float_immediate_number(t_fp32_, 1.0f / 4294967296.0f); |
1378 | Value tmp0 = rand_u32(global_tmp_); |
1379 | Value tmp1 = cast(t_fp32_, tmp0); |
1380 | Value val = mul(tmp1, _1_4294967296f); |
1381 | |
1382 | return val; |
1383 | } |
1384 | |
1385 | Value IRBuilder::rand_i32(Value global_tmp_) { |
1386 | if (!init_rand_) { |
1387 | init_random_function(global_tmp_); |
1388 | } |
1389 | |
1390 | Value tmp0 = rand_u32(global_tmp_); |
1391 | Value val = cast(t_int32_, tmp0); |
1392 | return val; |
1393 | } |
1394 | |
// Emits (or retrieves from the constant cache) an OpConstant* whose raw bits
// are pvalue[0], typed as `dtype`. When `cache` is true, identical
// (type id, bits) pairs are deduplicated through const_tbl_.
Value IRBuilder::get_const(const SType &dtype,
                           const uint64_t *pvalue,
                           bool cache) {
  auto key = std::make_pair(dtype.id, pvalue[0]);
  if (cache) {
    auto it = const_tbl_.find(key);
    if (it != const_tbl_.end()) {
      return it->second;
    }
  }

  TI_ASSERT(dtype.flag == TypeKind::kPrimitive);
  Value ret = new_value(dtype, ValueKind::kConstant);
  if (dtype.dt->is_primitive(PrimitiveTypeID::u1)) {
    // bool type
    if (*pvalue) {
      ib_.begin(spv::OpConstantTrue).add_seq(dtype, ret);
    } else {
      ib_.begin(spv::OpConstantFalse).add_seq(dtype, ret);
    }
  } else {
    // Integral/floating-point types.
    ib_.begin(spv::OpConstant).add_seq(dtype, ret);
    uint64_t mask = 0xFFFFFFFFUL;
    // SPIR-V encodes constants as 32-bit words, low word first.
    ib_.add(static_cast<uint32_t>(pvalue[0] & mask));
    if (data_type_bits(dtype.dt) > 32) {
      if (is_integral(dtype.dt)) {
        // Signed 64-bit: use an arithmetic shift so the high word keeps the
        // sign bits.
        int64_t sign_mask = 0xFFFFFFFFL;
        const int64_t *sign_ptr = reinterpret_cast<const int64_t *>(pvalue);
        ib_.add(static_cast<uint32_t>((sign_ptr[0] >> 32L) & sign_mask));
      } else {
        ib_.add(static_cast<uint32_t>((pvalue[0] >> 32UL) & mask));
      }
    }
  }

  ib_.commit(&global_);
  if (cache) {
    const_tbl_[key] = ret;
  }
  return ret;
}
1437 | |
1438 | SType IRBuilder::declare_primitive_type(DataType dt) { |
1439 | SType t; |
1440 | t.id = id_counter_++; |
1441 | t.dt = dt; |
1442 | t.flag = TypeKind::kPrimitive; |
1443 | |
1444 | dt.set_is_pointer(false); |
1445 | if (dt->is_primitive(PrimitiveTypeID::u1)) |
1446 | ib_.begin(spv::OpTypeBool).add(t).commit(&global_); |
1447 | else if (is_real(dt)) |
1448 | ib_.begin(spv::OpTypeFloat).add_seq(t, data_type_bits(dt)).commit(&global_); |
1449 | else if (is_integral(dt)) |
1450 | ib_.begin(spv::OpTypeInt) |
1451 | .add_seq(t, data_type_bits(dt), static_cast<int>(is_signed(dt))) |
1452 | .commit(&global_); |
1453 | else { |
1454 | TI_ERROR("Type {} not supported." , dt->to_string()); |
1455 | } |
1456 | |
1457 | return t; |
1458 | } |
1459 | |
// Declares the per-thread PRNG state (rand_x_..rand_w_, Private storage) and
// emits seeding code into the function header. The seed mixes
// gl_GlobalInvocationID.x with a counter read from slot 1024 of the global
// temporary buffer, so different threads and successive launches get
// different streams. Sets init_rand_ so this runs at most once.
// Precondition: gl_global_invocation_id_ must already be declared (see the
// note below).
void IRBuilder::init_random_function(Value global_tmp_) {
  // variables declare
  SType local_type = get_pointer_type(t_uint32_, spv::StorageClassPrivate);
  rand_x_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_y_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_z_ = new_value(local_type, ValueKind::kVariablePtr);
  rand_w_ = new_value(local_type, ValueKind::kVariablePtr);
  global_values.push_back(rand_x_);
  global_values.push_back(rand_y_);
  global_values.push_back(rand_z_);
  global_values.push_back(rand_w_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_x_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_y_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_z_, spv::StorageClassPrivate)
      .commit(&global_);
  ib_.begin(spv::OpVariable)
      .add_seq(local_type, rand_w_, spv::StorageClassPrivate)
      .commit(&global_);
  debug_name(spv::OpName, rand_x_, "_rand_x" );
  debug_name(spv::OpName, rand_y_, "_rand_y" );
  debug_name(spv::OpName, rand_z_, "_rand_z" );
  debug_name(spv::OpName, rand_w_, "_rand_w" );
  // Pointer into the global temporary buffer holding the launch counter.
  SType gtmp_type = get_pointer_type(t_uint32_, spv::StorageClassStorageBuffer);
  Value rand_gtmp_ = new_value(gtmp_type, ValueKind::kVariablePtr);
  debug_name(spv::OpName, rand_gtmp_, "rand_gtmp" );

  // Local load/store helpers that commit to the function header (the
  // member load_variable/store_variable commit to the function body).
  auto load_var = [&](Value pointer, const SType &res_type) {
    TI_ASSERT(pointer.flag == ValueKind::kVariablePtr ||
              pointer.flag == ValueKind::kStructArrayPtr);
    Value ret = new_value(res_type, ValueKind::kNormal);
    ib_.begin(spv::OpLoad)
        .add_seq(res_type, ret, pointer)
        .commit(&func_header_);
    return ret;
  };

  auto store_var = [&](Value pointer, Value value) {
    TI_ASSERT(pointer.flag == ValueKind::kVariablePtr);
    TI_ASSERT(value.stype.id == pointer.stype.element_type_id);
    ib_.begin(spv::OpStore).add_seq(pointer, value).commit(&func_header_);
  };

  // Constant Number
  Value _7654321u = uint_immediate_number(t_uint32_, 7654321u);
  Value _1234567u = uint_immediate_number(t_uint32_, 1234567u);
  Value _9723451u = uint_immediate_number(t_uint32_, 9723451u);
  Value _123456789u = uint_immediate_number(t_uint32_, 123456789u);
  Value _1000000007u = uint_immediate_number(t_uint32_, 1000000007u);
  Value _362436069u = uint_immediate_number(t_uint32_, 362436069u);
  Value _521288629u = uint_immediate_number(t_uint32_, 521288629u);
  Value _88675123u = uint_immediate_number(t_uint32_, 88675123u);
  Value _1 = int_immediate_number(t_uint32_, 1);
  Value _1024 = int_immediate_number(t_uint32_, 1024);

  // init_rand_ segment (inline to main)
  // ad-hoc: hope no kernel will use more than 1024 gtmp variables...
  ib_.begin(spv::OpAccessChain)
      .add_seq(gtmp_type, rand_gtmp_, global_tmp_, const_i32_zero_, _1024)
      .commit(&func_header_);
  // Get gl_GlobalInvocationID.x; it must already have been declared
  // (in generate_serial_kernel/generate_range_for_kernel).
  SType pint_type = this->get_pointer_type(t_uint32_, spv::StorageClassInput);
  Value tmp0 = new_value(pint_type, ValueKind::kVariablePtr);
  ib_.begin(spv::OpAccessChain)
      .add_seq(pint_type, tmp0, gl_global_invocation_id_,
               uint_immediate_number(t_uint32_, 0))
      .commit(&func_header_);
  Value tmp1 = load_var(tmp0, t_uint32_);
  Value tmp2_ = load_var(rand_gtmp_, t_uint32_);
  Value tmp2 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpBitcast)
      .add_seq(t_uint32_, tmp2, tmp2_)
      .commit(&func_header_);
  // Seed: rand_x = 1000000007 * (123456789 * ((7654321 + tid) *
  //       (1234567 + 9723451 * counter)))
  Value tmp3 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIAdd)
      .add_seq(t_uint32_, tmp3, _7654321u, tmp1)
      .commit(&func_header_);
  Value tmp4 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp4, _9723451u, tmp2)
      .commit(&func_header_);
  Value tmp5 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIAdd)
      .add_seq(t_uint32_, tmp5, _1234567u, tmp4)
      .commit(&func_header_);
  Value tmp6 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp6, tmp3, tmp5)
      .commit(&func_header_);
  Value tmp7 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp7, _123456789u, tmp6)
      .commit(&func_header_);
  Value tmp8 = new_value(t_uint32_, ValueKind::kNormal);
  ib_.begin(spv::OpIMul)
      .add_seq(t_uint32_, tmp8, _1000000007u, tmp7)
      .commit(&func_header_);
  store_var(rand_x_, tmp8);
  // y/z/w start from the classic xorshift128 seed constants.
  store_var(rand_y_, _362436069u);
  store_var(rand_z_, _521288629u);
  store_var(rand_w_, _88675123u);

  // enum spv::Op add_op = spv::OpIAdd;
  bool use_atomic_increment = false;

// use atomic increment for DX API to avoid error X3694
#ifdef TI_WITH_DX11
  if (arch_ == Arch::dx11) {
    use_atomic_increment = true;
  }
#endif

  // Bump the launch counter so the next kernel launch is seeded differently.
  if (use_atomic_increment) {
    Value tmp9 = new_value(t_uint32_, ValueKind::kNormal);
    ib_.begin(spv::Op::OpAtomicIIncrement)
        .add_seq(t_uint32_, tmp9, rand_gtmp_,
                 /*scope_id*/ const_i32_one_,
                 /*semantics*/ const_i32_zero_)
        .commit(&func_header_);
  } else {
    // Yes, this is not an atomic operation, but just fine since no matter
    // how RAND_STATE changes, `gl_GlobalInvocationID.x` can still help
    // us to set different seeds for different threads.
    // Discussion:
    // https://github.com/taichi-dev/taichi/pull/912#discussion_r419021918
    Value tmp9 = load_var(rand_gtmp_, t_uint32_);
    Value tmp10 = new_value(t_uint32_, ValueKind::kNormal);
    ib_.begin(spv::Op::OpIAdd)
        .add_seq(t_uint32_, tmp10, tmp9, _1)
        .commit(&func_header_);
    store_var(rand_gtmp_, tmp10);
  }

  init_rand_ = true;
}
1600 | |
1601 | } // namespace spirv |
1602 | } // namespace taichi::lang |
1603 | |