1#include "taichi/codegen/spirv/spirv_codegen.h"
2
3#include <string>
4#include <vector>
5#include <variant>
6
7#include "taichi/program/program.h"
8#include "taichi/program/kernel.h"
9#include "taichi/ir/statements.h"
10#include "taichi/ir/ir.h"
11#include "taichi/util/line_appender.h"
12#include "taichi/codegen/spirv/kernel_utils.h"
13#include "taichi/codegen/spirv/spirv_ir_builder.h"
14#include "taichi/ir/transforms.h"
15#include "taichi/math/arithmetic.h"
16
17#include <spirv-tools/libspirv.hpp>
18#include <spirv-tools/optimizer.hpp>
19
20namespace taichi::lang {
21namespace spirv {
22namespace {
23
24constexpr char kRootBufferName[] = "root_buffer";
25constexpr char kGlobalTmpsBufferName[] = "global_tmps_buffer";
26constexpr char kArgsBufferName[] = "args_buffer";
27constexpr char kRetBufferName[] = "ret_buffer";
28constexpr char kListgenBufferName[] = "listgen_buffer";
29constexpr char kExtArrBufferName[] = "ext_arr_buffer";
30
31constexpr int kMaxNumThreadsGridStrideLoop = 65536 * 2;
32
33using BufferType = TaskAttributes::BufferType;
34using BufferInfo = TaskAttributes::BufferInfo;
35using BufferBind = TaskAttributes::BufferBind;
36using BufferInfoHasher = TaskAttributes::BufferInfoHasher;
37
38using TextureBind = TaskAttributes::TextureBind;
39
40std::string buffer_instance_name(BufferInfo b) {
41 // https://www.khronos.org/opengl/wiki/Interface_Block_(GLSL)#Syntax
42 switch (b.type) {
43 case BufferType::Root:
44 return std::string(kRootBufferName) + "_" + std::to_string(b.root_id);
45 case BufferType::GlobalTmps:
46 return kGlobalTmpsBufferName;
47 case BufferType::Args:
48 return kArgsBufferName;
49 case BufferType::Rets:
50 return kRetBufferName;
51 case BufferType::ListGen:
52 return kListgenBufferName;
53 case BufferType::ExtArr:
54 return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id);
55 default:
56 TI_NOT_IMPLEMENTED;
57 break;
58 }
59 return {};
60}
61
62class TaskCodegen : public IRVisitor {
63 public:
64 struct Params {
65 OffloadedStmt *task_ir;
66 Arch arch;
67 DeviceCapabilityConfig *caps;
68 std::vector<CompiledSNodeStructs> compiled_structs;
69 const KernelContextAttributes *ctx_attribs;
70 std::string ti_kernel_name;
71 int task_id_in_kernel;
72 };
73
74 const bool use_64bit_pointers = false;
75
76 explicit TaskCodegen(const Params &params)
77 : arch_(params.arch),
78 caps_(params.caps),
79 task_ir_(params.task_ir),
80 compiled_structs_(params.compiled_structs),
81 ctx_attribs_(params.ctx_attribs),
82 task_name_(fmt::format("{}_t{:02d}",
83 params.ti_kernel_name,
84 params.task_id_in_kernel)) {
85 allow_undefined_visitor = true;
86 invoke_default_visitor = true;
87
88 fill_snode_to_root();
89 ir_ = std::make_shared<spirv::IRBuilder>(arch_, caps_);
90 }
91
92 void fill_snode_to_root() {
93 for (int root = 0; root < compiled_structs_.size(); ++root) {
94 for (auto &[node_id, node] : compiled_structs_[root].snode_descriptors) {
95 snode_to_root_[node_id] = root;
96 }
97 }
98 }
99
100 struct Result {
101 std::vector<uint32_t> spirv_code;
102 TaskAttributes task_attribs;
103 std::unordered_map<int, irpass::ExternalPtrAccess> arr_access;
104 };
105
106 Result run() {
107 ir_->init_header();
108 kernel_function_ = ir_->new_function(); // void main();
109 ir_->debug_name(spv::OpName, kernel_function_, "main");
110
111 if (task_ir_->task_type == OffloadedTaskType::serial) {
112 generate_serial_kernel(task_ir_);
113 } else if (task_ir_->task_type == OffloadedTaskType::range_for) {
114 // struct_for is automatically lowered to ranged_for for dense snodes
115 generate_range_for_kernel(task_ir_);
116 } else if (task_ir_->task_type == OffloadedTaskType::struct_for) {
117 generate_struct_for_kernel(task_ir_);
118 } else {
119 TI_ERROR("Unsupported offload type={} on SPIR-V codegen",
120 task_ir_->task_name());
121 }
122 // Headers need global information, so it has to be delayed after visiting
123 // the task IR.
124 emit_headers();
125
126 Result res;
127 res.spirv_code = ir_->finalize();
128 res.task_attribs = std::move(task_attribs_);
129 res.arr_access = irpass::detect_external_ptr_access_in_task(task_ir_);
130
131 return res;
132 }
133
134 void visit(OffloadedStmt *) override {
135 TI_ERROR("This codegen is supposed to deal with one offloaded task");
136 }
137
138 void visit(Block *stmt) override {
139 for (auto &s : stmt->statements) {
140 if (offload_loop_motion_.find(s.get()) == offload_loop_motion_.end()) {
141 s->accept(this);
142 }
143 }
144 }
145
146 void visit(PrintStmt *stmt) override {
147 if (!caps_->get(DeviceCapability::spirv_has_non_semantic_info)) {
148 return;
149 }
150
151 std::string formats;
152 std::vector<Value> vals;
153
154 for (auto const &content : stmt->contents) {
155 if (std::holds_alternative<Stmt *>(content)) {
156 auto arg_stmt = std::get<Stmt *>(content);
157 TI_ASSERT(!arg_stmt->ret_type->is<TensorType>());
158
159 auto value = ir_->query_value(arg_stmt->raw_name());
160 vals.push_back(value);
161 formats += data_type_format(arg_stmt->ret_type, Arch::vulkan);
162 } else {
163 auto arg_str = std::get<std::string>(content);
164 formats += arg_str;
165 }
166 }
167 ir_->call_debugprintf(formats, vals);
168 }
169
170 void visit(ConstStmt *const_stmt) override {
171 auto get_const = [&](const TypedConstant &const_val) {
172 auto dt = const_val.dt.ptr_removed();
173 spirv::SType stype = ir_->get_primitive_type(dt);
174
175 if (dt->is_primitive(PrimitiveTypeID::f32)) {
176 return ir_->float_immediate_number(
177 stype, static_cast<double>(const_val.val_f32), false);
178 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
179 return ir_->int_immediate_number(
180 stype, static_cast<int64_t>(const_val.val_i32), false);
181 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
182 return ir_->int_immediate_number(
183 stype, static_cast<int64_t>(const_val.val_i64), false);
184 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
185 return ir_->float_immediate_number(
186 stype, static_cast<double>(const_val.val_f64), false);
187 } else if (dt->is_primitive(PrimitiveTypeID::i8)) {
188 return ir_->int_immediate_number(
189 stype, static_cast<int64_t>(const_val.val_i8), false);
190 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
191 return ir_->int_immediate_number(
192 stype, static_cast<int64_t>(const_val.val_i16), false);
193 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
194 return ir_->uint_immediate_number(
195 stype, static_cast<uint64_t>(const_val.val_u8), false);
196 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
197 return ir_->uint_immediate_number(
198 stype, static_cast<uint64_t>(const_val.val_u16), false);
199 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
200 return ir_->uint_immediate_number(
201 stype, static_cast<uint64_t>(const_val.val_u32), false);
202 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
203 return ir_->uint_immediate_number(
204 stype, static_cast<uint64_t>(const_val.val_u64), false);
205 } else {
206 TI_P(data_type_name(dt));
207 TI_NOT_IMPLEMENTED
208 return spirv::Value();
209 }
210 };
211
212 spirv::Value val = get_const(const_stmt->val);
213 ir_->register_value(const_stmt->raw_name(), val);
214 }
215
216 void visit(AllocaStmt *alloca) override {
217 spirv::Value ptr_val;
218 if (alloca->ret_type->is<TensorType>()) {
219 auto tensor_type = alloca->ret_type->cast<TensorType>();
220 auto elem_num = tensor_type->get_num_elements();
221 spirv::SType elem_type =
222 ir_->get_primitive_type(tensor_type->get_element_type());
223 spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num);
224 if (alloca->is_shared) { // for shared memory / workgroup memory
225 ptr_val = ir_->alloca_workgroup_array(arr_type);
226 shared_array_binds_.push_back(ptr_val);
227 } else { // for function memory
228 ptr_val = ir_->alloca_variable(arr_type);
229 }
230 } else {
231 // Alloca for a single variable
232 spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
233 ptr_val = ir_->alloca_variable(src_type);
234 ir_->store_variable(ptr_val, ir_->get_zero(src_type));
235 }
236 ir_->register_value(alloca->raw_name(), ptr_val);
237 }
238
239 void visit(MatrixPtrStmt *stmt) override {
240 spirv::Value ptr_val;
241 spirv::Value origin_val = ir_->query_value(stmt->origin->raw_name());
242 spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
243 auto dt = stmt->element_type().ptr_removed();
244 if (stmt->offset_used_as_index()) {
245 if (stmt->origin->is<AllocaStmt>()) {
246 spirv::SType ptr_type = ir_->get_pointer_type(
247 ir_->get_primitive_type(dt), origin_val.stype.storage_class);
248 ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
249 offset_val);
250 } else if (stmt->origin->is<GlobalTemporaryStmt>()) {
251 spirv::Value dt_bytes = ir_->int_immediate_number(
252 ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
253 spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
254 ptr_val = ir_->add(origin_val, offset_bytes);
255 ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
256 } else {
257 TI_NOT_IMPLEMENTED;
258 }
259 } else { // offset used as bytes
260 ptr_val = ir_->add(origin_val, ir_->cast(origin_val.stype, offset_val));
261 ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
262 }
263 ir_->register_value(stmt->raw_name(), ptr_val);
264 }
265
266 void visit(LocalLoadStmt *stmt) override {
267 auto ptr = stmt->src;
268 spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
269 spirv::Value val = ir_->load_variable(
270 ptr_val, ir_->get_primitive_type(stmt->element_type()));
271 ir_->register_value(stmt->raw_name(), val);
272 }
273
274 void visit(LocalStoreStmt *stmt) override {
275 spirv::Value ptr_val = ir_->query_value(stmt->dest->raw_name());
276 spirv::Value val = ir_->query_value(stmt->val->raw_name());
277 ir_->store_variable(ptr_val, val);
278 }
279
280 void visit(GetRootStmt *stmt) override {
281 const int root_id = snode_to_root_.at(stmt->root()->id);
282 root_stmts_[root_id] = stmt;
283 // get_buffer_value({BufferType::Root, root_id}, PrimitiveType::u32);
284 spirv::Value root_val = make_pointer(0);
285 ir_->register_value(stmt->raw_name(), root_val);
286 }
287
288 void visit(GetChStmt *stmt) override {
289 // TODO: GetChStmt -> GetComponentStmt ?
290 const int root = snode_to_root_.at(stmt->input_snode->id);
291
292 const auto &snode_descs = compiled_structs_[root].snode_descriptors;
293 auto *out_snode = stmt->output_snode;
294 TI_ASSERT(snode_descs.at(stmt->input_snode->id).get_child(stmt->chid) ==
295 out_snode);
296
297 const auto &desc = snode_descs.at(out_snode->id);
298
299 spirv::Value input_ptr_val = ir_->query_value(stmt->input_ptr->raw_name());
300 spirv::Value offset = make_pointer(desc.mem_offset_in_parent_cell);
301 spirv::Value val = ir_->add(input_ptr_val, offset);
302 ir_->register_value(stmt->raw_name(), val);
303
304 if (out_snode->is_place()) {
305 TI_ASSERT(ptr_to_buffers_.count(stmt) == 0);
306 ptr_to_buffers_[stmt] = BufferInfo(BufferType::Root, root);
307 }
308 }
309
310 enum class ActivationOp { activate, deactivate, query };
311
312 spirv::Value bitmasked_activation(ActivationOp op,
313 spirv::Value parent_ptr,
314 int root_id,
315 const SNode *sn,
316 spirv::Value input_index) {
317 spirv::SType ptr_dt = parent_ptr.stype;
318 const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
319 const auto &desc = snode_descs.at(sn->id);
320
321 auto bitmask_word_index =
322 ir_->make_value(spv::OpShiftRightLogical, ptr_dt, input_index,
323 ir_->uint_immediate_number(ptr_dt, 5));
324 auto bitmask_bit_index =
325 ir_->make_value(spv::OpBitwiseAnd, ptr_dt, input_index,
326 ir_->uint_immediate_number(ptr_dt, 31));
327 auto bitmask_mask = ir_->make_value(spv::OpShiftLeftLogical, ptr_dt,
328 ir_->const_i32_one_, bitmask_bit_index);
329
330 auto buffer = get_buffer_value(BufferInfo(BufferType::Root, root_id),
331 PrimitiveType::u32);
332 auto bitmask_word_ptr =
333 ir_->make_value(spv::OpShiftLeftLogical, ptr_dt, bitmask_word_index,
334 ir_->uint_immediate_number(ir_->u32_type(), 2));
335 bitmask_word_ptr = ir_->add(
336 bitmask_word_ptr,
337 make_pointer(desc.cell_stride * desc.snode->num_cells_per_container));
338 bitmask_word_ptr = ir_->add(parent_ptr, bitmask_word_ptr);
339 bitmask_word_ptr = ir_->make_value(
340 spv::OpShiftRightLogical, ir_->u32_type(), bitmask_word_ptr,
341 ir_->uint_immediate_number(ir_->u32_type(), 2));
342 bitmask_word_ptr =
343 ir_->struct_array_access(ir_->u32_type(), buffer, bitmask_word_ptr);
344
345 if (op == ActivationOp::activate) {
346 return ir_->make_value(spv::OpAtomicOr, ir_->u32_type(), bitmask_word_ptr,
347 /*scope=*/ir_->const_i32_one_,
348 /*semantics=*/ir_->const_i32_zero_, bitmask_mask);
349 } else if (op == ActivationOp::deactivate) {
350 bitmask_mask = ir_->make_value(spv::OpNot, ir_->u32_type(), bitmask_mask);
351 return ir_->make_value(spv::OpAtomicAnd, ir_->u32_type(),
352 bitmask_word_ptr,
353 /*scope=*/ir_->const_i32_one_,
354 /*semantics=*/ir_->const_i32_zero_, bitmask_mask);
355 } else {
356 auto bitmask_val = ir_->load_variable(bitmask_word_ptr, ir_->u32_type());
357 auto bit = ir_->make_value(spv::OpShiftRightLogical, ir_->u32_type(),
358 bitmask_val, bitmask_bit_index);
359 bit = ir_->make_value(spv::OpBitwiseAnd, ir_->u32_type(), bit,
360 ir_->uint_immediate_number(ir_->u32_type(), 1));
361 return ir_->make_value(spv::OpUGreaterThan, ir_->bool_type(), bit,
362 ir_->uint_immediate_number(ir_->u32_type(), 0));
363 }
364 }
365
366 void visit(SNodeOpStmt *stmt) override {
367 const int root_id = snode_to_root_.at(stmt->snode->id);
368 std::string parent = stmt->ptr->raw_name();
369 spirv::Value parent_val = ir_->query_value(parent);
370
371 if (stmt->snode->type == SNodeType::bitmasked) {
372 spirv::Value input_index_val =
373 ir_->cast(parent_val.stype, ir_->query_value(stmt->val->raw_name()));
374
375 if (stmt->op_type == SNodeOpType::is_active) {
376 auto is_active =
377 bitmasked_activation(ActivationOp::query, parent_val, root_id,
378 stmt->snode, input_index_val);
379 is_active =
380 ir_->cast(ir_->get_primitive_type(stmt->ret_type), is_active);
381 is_active = ir_->make_value(spv::OpSNegate, is_active.stype, is_active);
382 ir_->register_value(stmt->raw_name(), is_active);
383 } else if (stmt->op_type == SNodeOpType::deactivate) {
384 bitmasked_activation(ActivationOp::deactivate, parent_val, root_id,
385 stmt->snode, input_index_val);
386 } else if (stmt->op_type == SNodeOpType::activate) {
387 bitmasked_activation(ActivationOp::activate, parent_val, root_id,
388 stmt->snode, input_index_val);
389 } else {
390 TI_NOT_IMPLEMENTED;
391 }
392 } else {
393 TI_NOT_IMPLEMENTED;
394 }
395 }
396
397 void visit(SNodeLookupStmt *stmt) override {
398 // TODO: SNodeLookupStmt -> GetSNodeCellStmt ?
399 bool is_root{false}; // Eliminate first root snode access
400 const int root_id = snode_to_root_.at(stmt->snode->id);
401 std::string parent;
402
403 if (stmt->input_snode) {
404 parent = stmt->input_snode->raw_name();
405 } else {
406 TI_ASSERT(root_stmts_.at(root_id) != nullptr);
407 parent = root_stmts_.at(root_id)->raw_name();
408 }
409 const auto *sn = stmt->snode;
410
411 spirv::Value parent_val = ir_->query_value(parent);
412
413 if (stmt->activate) {
414 if (sn->type == SNodeType::dense) {
415 // Do nothing
416 } else if (sn->type == SNodeType::bitmasked) {
417 spirv::Value input_index_val =
418 ir_->query_value(stmt->input_index->raw_name());
419 bitmasked_activation(ActivationOp::activate, parent_val, root_id, sn,
420 input_index_val);
421 } else {
422 TI_NOT_IMPLEMENTED;
423 }
424 }
425
426 spirv::Value val;
427 if (is_root) {
428 val = parent_val; // Assert Root[0] access at first time
429 } else {
430 const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
431 const auto &desc = snode_descs.at(sn->id);
432
433 spirv::Value input_index_val = ir_->cast(
434 parent_val.stype, ir_->query_value(stmt->input_index->raw_name()));
435 spirv::Value stride = make_pointer(desc.cell_stride);
436 spirv::Value offset = ir_->mul(input_index_val, stride);
437 val = ir_->add(parent_val, offset);
438 }
439 ir_->register_value(stmt->raw_name(), val);
440 }
441
442 void visit(RandStmt *stmt) override {
443 spirv::Value val;
444 spirv::Value global_tmp =
445 get_buffer_value(BufferType::GlobalTmps, PrimitiveType::u32);
446 if (stmt->element_type()->is_primitive(PrimitiveTypeID::i32)) {
447 val = ir_->rand_i32(global_tmp);
448 } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::u32)) {
449 val = ir_->rand_u32(global_tmp);
450 } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f32)) {
451 val = ir_->rand_f32(global_tmp);
452 } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f16)) {
453 auto highp_val = ir_->rand_f32(global_tmp);
454 val = ir_->cast(ir_->f16_type(), highp_val);
455 } else {
456 TI_ERROR("rand only support 32-bit type");
457 }
458 ir_->register_value(stmt->raw_name(), val);
459 }
460
461 void visit(LinearizeStmt *stmt) override {
462 spirv::Value val = ir_->const_i32_zero_;
463 for (size_t i = 0; i < stmt->inputs.size(); ++i) {
464 spirv::Value strides_val =
465 ir_->int_immediate_number(ir_->i32_type(), stmt->strides[i]);
466 spirv::Value input_val = ir_->query_value(stmt->inputs[i]->raw_name());
467 val = ir_->add(ir_->mul(val, strides_val), input_val);
468 }
469 ir_->register_value(stmt->raw_name(), val);
470 }
471
472 void visit(LoopIndexStmt *stmt) override {
473 const auto stmt_name = stmt->raw_name();
474 if (stmt->loop->is<OffloadedStmt>()) {
475 const auto type = stmt->loop->as<OffloadedStmt>()->task_type;
476 if (type == OffloadedTaskType::range_for) {
477 TI_ASSERT(stmt->index == 0);
478 spirv::Value loop_var = ir_->query_value("ii");
479 // spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_);
480 ir_->register_value(stmt_name, loop_var);
481 } else {
482 TI_NOT_IMPLEMENTED;
483 }
484 } else if (stmt->loop->is<RangeForStmt>()) {
485 TI_ASSERT(stmt->index == 0);
486 spirv::Value loop_var = ir_->query_value(stmt->loop->raw_name());
487 spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_);
488 ir_->register_value(stmt_name, val);
489 } else {
490 TI_NOT_IMPLEMENTED;
491 }
492 }
493
494 void visit(GlobalStoreStmt *stmt) override {
495 spirv::Value val = ir_->query_value(stmt->val->raw_name());
496
497 store_buffer(stmt->dest, val);
498 }
499
500 void visit(GlobalLoadStmt *stmt) override {
501 auto dt = stmt->element_type();
502
503 auto val = load_buffer(stmt->src, dt);
504
505 ir_->register_value(stmt->raw_name(), val);
506 }
507
508 void visit(ArgLoadStmt *stmt) override {
509 const auto arg_id = stmt->arg_id;
510 const auto &arg_attribs = ctx_attribs_->args()[arg_id];
511 if (stmt->is_ptr) {
512 // Do not shift! We are indexing the buffers at byte granularity.
513 // spirv::Value val =
514 // ir_->int_immediate_number(ir_->i32_type(), offset_in_mem);
515 // ir_->register_value(stmt->raw_name(), val);
516 } else {
517 const auto dt = PrimitiveType::get(arg_attribs.dtype);
518 const auto val_type = ir_->get_primitive_type(dt);
519 spirv::Value buffer_val = ir_->make_value(
520 spv::OpAccessChain,
521 ir_->get_pointer_type(val_type, spv::StorageClassUniform),
522 get_buffer_value(BufferType::Args, PrimitiveType::i32),
523 ir_->int_immediate_number(ir_->i32_type(), arg_id));
524 buffer_val.flag = ValueKind::kVariablePtr;
525 spirv::Value val = ir_->load_variable(buffer_val, val_type);
526 ir_->register_value(stmt->raw_name(), val);
527 }
528 }
529
530 void visit(ReturnStmt *stmt) override {
531 // Now we only support one ret
532 auto dt = stmt->element_types()[0];
533 for (int i = 0; i < stmt->values.size(); i++) {
534 spirv::Value buffer_val = ir_->make_value(
535 spv::OpAccessChain,
536 ir_->get_storage_pointer_type(ir_->get_primitive_type(dt)),
537 get_buffer_value(BufferType::Rets, dt),
538 ir_->int_immediate_number(ir_->i32_type(), 0),
539 ir_->int_immediate_number(ir_->i32_type(), i));
540 buffer_val.flag = ValueKind::kVariablePtr;
541 spirv::Value val = ir_->query_value(stmt->values[i]->raw_name());
542 ir_->store_variable(buffer_val, val);
543 }
544 }
545
546 void visit(GlobalTemporaryStmt *stmt) override {
547 spirv::Value val = ir_->int_immediate_number(ir_->i32_type(), stmt->offset,
548 false); // Named Constant
549 ir_->register_value(stmt->raw_name(), val);
550 ptr_to_buffers_[stmt] = BufferType::GlobalTmps;
551 }
552
553 void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
554 const auto name = stmt->raw_name();
555 const auto arg_id = stmt->arg_id;
556 const auto axis = stmt->axis;
557
558 const auto extra_args_member_index = ctx_attribs_->args().size();
559
560 const auto extra_arg_index = (arg_id * taichi_max_num_indices) + axis;
561 spirv::Value var_ptr = ir_->make_value(
562 spv::OpAccessChain,
563 ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform),
564 get_buffer_value(BufferType::Args, PrimitiveType::i32),
565 ir_->int_immediate_number(ir_->i32_type(),
566 extra_args_member_index + extra_arg_index));
567 spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type());
568
569 ir_->register_value(name, var);
570 }
571
572 void visit(ExternalPtrStmt *stmt) override {
573 // Used mostly for transferring data between host (e.g. numpy array) and
574 // device.
575 spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0);
576 const auto *argload = stmt->base_ptr->as<ArgLoadStmt>();
577 const int arg_id = argload->arg_id;
578 {
579 const int num_indices = stmt->indices.size();
580 std::vector<std::string> size_var_names;
581 const auto &element_shape = stmt->element_shape;
582 const auto layout = stmt->element_dim <= 0 ? ExternalArrayLayout::kAOS
583 : ExternalArrayLayout::kSOA;
584 const auto extra_args_member_index = ctx_attribs_->args().size();
585 const size_t element_shape_index_offset =
586 (layout == ExternalArrayLayout::kAOS)
587 ? num_indices - element_shape.size()
588 : 0;
589 for (int i = 0; i < num_indices - element_shape.size(); i++) {
590 std::string var_name = fmt::format("{}_size{}_", stmt->raw_name(), i);
591 const auto extra_arg_index = (arg_id * taichi_max_num_indices) + i;
592 spirv::Value var_ptr = ir_->make_value(
593 spv::OpAccessChain,
594 ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform),
595 get_buffer_value(BufferType::Args, PrimitiveType::i32),
596 ir_->int_immediate_number(
597 ir_->i32_type(), extra_args_member_index + extra_arg_index));
598 spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type());
599 ir_->register_value(var_name, var);
600 size_var_names.push_back(std::move(var_name));
601 }
602 int size_var_names_idx = 0;
603 for (int i = 0; i < num_indices; i++) {
604 spirv::Value size_var;
605 // Use immediate numbers to flatten index for element shapes.
606 if (i >= element_shape_index_offset &&
607 i < element_shape_index_offset + element_shape.size()) {
608 size_var = ir_->uint_immediate_number(
609 ir_->i32_type(), element_shape[i - element_shape_index_offset]);
610 } else {
611 size_var = ir_->query_value(size_var_names[size_var_names_idx++]);
612 }
613 spirv::Value indices = ir_->query_value(stmt->indices[i]->raw_name());
614 linear_offset = ir_->mul(linear_offset, size_var);
615 linear_offset = ir_->add(linear_offset, indices);
616 }
617 linear_offset = ir_->make_value(
618 spv::OpShiftLeftLogical, ir_->i32_type(), linear_offset,
619 ir_->int_immediate_number(ir_->i32_type(),
620 log2int(ir_->get_primitive_type_size(
621 argload->ret_type.ptr_removed()))));
622 if (caps_->get(DeviceCapability::spirv_has_no_integer_wrap_decoration)) {
623 ir_->decorate(spv::OpDecorate, linear_offset,
624 spv::DecorationNoSignedWrap);
625 }
626 }
627 if (caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) {
628 spirv::Value addr_ptr = ir_->make_value(
629 spv::OpAccessChain,
630 ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform),
631 get_buffer_value(BufferType::Args, PrimitiveType::i32),
632 ir_->int_immediate_number(ir_->i32_type(), arg_id));
633 spirv::Value addr = ir_->load_variable(addr_ptr, ir_->u64_type());
634 addr = ir_->add(addr, ir_->make_value(spv::OpSConvert, ir_->u64_type(),
635 linear_offset));
636 ir_->register_value(stmt->raw_name(), addr);
637 } else {
638 ir_->register_value(stmt->raw_name(), linear_offset);
639 }
640
641 if (ctx_attribs_->args()[arg_id].is_array) {
642 ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id};
643 } else {
644 ptr_to_buffers_[stmt] = BufferType::Args;
645 }
646 }
647
648 void visit(DecorationStmt *stmt) override {
649 }
650
651 void visit(UnaryOpStmt *stmt) override {
652 const auto operand_name = stmt->operand->raw_name();
653
654 const auto src_dt = stmt->operand->element_type();
655 const auto dst_dt = stmt->element_type();
656 spirv::SType src_type = ir_->get_primitive_type(src_dt);
657 spirv::SType dst_type = ir_->get_primitive_type(dst_dt);
658 spirv::Value operand_val = ir_->query_value(operand_name);
659 spirv::Value val = spirv::Value();
660
661 if (stmt->op_type == UnaryOpType::logic_not) {
662 spirv::Value zero =
663 ir_->get_zero(src_type); // Math zero type to left hand side
664 if (is_integral(src_dt)) {
665 if (is_signed(src_dt)) {
666 zero = ir_->int_immediate_number(src_type, 0);
667 } else {
668 zero = ir_->uint_immediate_number(src_type, 0);
669 }
670 } else if (is_real(src_dt)) {
671 zero = ir_->float_immediate_number(src_type, 0);
672 } else {
673 TI_NOT_IMPLEMENTED
674 }
675 val = ir_->cast(dst_type, ir_->eq(operand_val, zero));
676 } else if (stmt->op_type == UnaryOpType::neg) {
677 operand_val = ir_->cast(dst_type, operand_val);
678 if (is_integral(dst_dt)) {
679 if (is_signed(dst_dt)) {
680 val = ir_->make_value(spv::OpSNegate, dst_type, operand_val);
681 } else {
682 TI_NOT_IMPLEMENTED
683 }
684 } else if (is_real(dst_dt)) {
685 val = ir_->make_value(spv::OpFNegate, dst_type, operand_val);
686 } else {
687 TI_NOT_IMPLEMENTED
688 }
689 } else if (stmt->op_type == UnaryOpType::rsqrt) {
690 const uint32_t InverseSqrt_id = 32;
691 if (is_real(src_dt)) {
692 val = ir_->call_glsl450(src_type, InverseSqrt_id, operand_val);
693 val = ir_->cast(dst_type, val);
694 } else {
695 TI_NOT_IMPLEMENTED
696 }
697 } else if (stmt->op_type == UnaryOpType::sgn) {
698 const uint32_t FSign_id = 6;
699 const uint32_t SSign_id = 7;
700 if (is_integral(src_dt)) {
701 if (is_signed(src_dt)) {
702 val = ir_->call_glsl450(src_type, SSign_id, operand_val);
703 } else {
704 TI_NOT_IMPLEMENTED
705 }
706 } else if (is_real(src_dt)) {
707 val = ir_->call_glsl450(src_type, FSign_id, operand_val);
708 } else {
709 TI_NOT_IMPLEMENTED
710 }
711 val = ir_->cast(dst_type, val);
712 } else if (stmt->op_type == UnaryOpType::bit_not) {
713 operand_val = ir_->cast(dst_type, operand_val);
714 if (is_integral(dst_dt)) {
715 val = ir_->make_value(spv::OpNot, dst_type, operand_val);
716 } else {
717 TI_NOT_IMPLEMENTED
718 }
719 } else if (stmt->op_type == UnaryOpType::cast_value) {
720 val = ir_->cast(dst_type, operand_val);
721 } else if (stmt->op_type == UnaryOpType::cast_bits) {
722 if (data_type_bits(src_dt) == data_type_bits(dst_dt)) {
723 val = ir_->make_value(spv::OpBitcast, dst_type, operand_val);
724 } else {
725 TI_ERROR("bit_cast is only supported between data type with same size");
726 }
727 } else if (stmt->op_type == UnaryOpType::abs) {
728 const uint32_t FAbs_id = 4;
729 const uint32_t SAbs_id = 5;
730 if (src_type.id == dst_type.id) {
731 if (is_integral(src_dt)) {
732 if (is_signed(src_dt)) {
733 val = ir_->call_glsl450(src_type, SAbs_id, operand_val);
734 } else {
735 TI_NOT_IMPLEMENTED
736 }
737 } else if (is_real(src_dt)) {
738 val = ir_->call_glsl450(src_type, FAbs_id, operand_val);
739 } else {
740 TI_NOT_IMPLEMENTED
741 }
742 } else {
743 TI_NOT_IMPLEMENTED
744 }
745 } else if (stmt->op_type == UnaryOpType::inv) {
746 if (is_real(dst_dt)) {
747 val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val);
748 } else {
749 TI_NOT_IMPLEMENTED
750 }
751 }
752#define UNARY_OP_TO_SPIRV(op, instruction, instruction_id, max_bits) \
753 else if (stmt->op_type == UnaryOpType::op) { \
754 const uint32_t instruction = instruction_id; \
755 if (is_real(src_dt)) { \
756 if (data_type_bits(src_dt) > max_bits) { \
757 TI_ERROR("Instruction {}({}) does not {}bits operation", #instruction, \
758 instruction_id, data_type_bits(src_dt)); \
759 } \
760 val = ir_->call_glsl450(src_type, instruction, operand_val); \
761 } else { \
762 TI_NOT_IMPLEMENTED \
763 } \
764 }
765 UNARY_OP_TO_SPIRV(round, Round, 1, 64)
766 UNARY_OP_TO_SPIRV(floor, Floor, 8, 64)
767 UNARY_OP_TO_SPIRV(ceil, Ceil, 9, 64)
768 UNARY_OP_TO_SPIRV(sin, Sin, 13, 32)
769 UNARY_OP_TO_SPIRV(asin, Asin, 16, 32)
770 UNARY_OP_TO_SPIRV(cos, Cos, 14, 32)
771 UNARY_OP_TO_SPIRV(acos, Acos, 17, 32)
772 UNARY_OP_TO_SPIRV(tan, Tan, 15, 32)
773 UNARY_OP_TO_SPIRV(tanh, Tanh, 21, 32)
774 UNARY_OP_TO_SPIRV(exp, Exp, 27, 32)
775 UNARY_OP_TO_SPIRV(log, Log, 28, 32)
776 UNARY_OP_TO_SPIRV(sqrt, Sqrt, 31, 64)
777#undef UNARY_OP_TO_SPIRV
778 else {TI_NOT_IMPLEMENTED} ir_->register_value(stmt->raw_name(), val);
779 }
780
781 void generate_overflow_branch(const spirv::Value &cond_v,
782 const std::string &op,
783 const std::string &tb) {
784 spirv::Value cond =
785 ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
786 spirv::Label then_label = ir_->new_label();
787 spirv::Label merge_label = ir_->new_label();
788 ir_->make_inst(spv::OpSelectionMerge, merge_label,
789 spv::SelectionControlMaskNone);
790 ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
791 // then block
792 ir_->start_label(then_label);
793 ir_->call_debugprintf(op + " overflow detected in " + tb, {});
794 ir_->make_inst(spv::OpBranch, merge_label);
795 // merge label
796 ir_->start_label(merge_label);
797 }
798
799 spirv::Value generate_uadd_overflow(const spirv::Value &a,
800 const spirv::Value &b,
801 const std::string &tb) {
802 std::vector<std::tuple<spirv::SType, std::string, size_t>>
803 struct_components_;
804 struct_components_.emplace_back(a.stype, "result", 0);
805 struct_components_.emplace_back(a.stype, "carry",
806 ir_->get_primitive_type_size(a.stype.dt));
807 auto struct_type = ir_->create_struct_type(struct_components_);
808 auto add_carry = ir_->make_value(spv::OpIAddCarry, struct_type, a, b);
809 auto result =
810 ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 0);
811 auto carry =
812 ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 1);
813 generate_overflow_branch(carry, "Addition", tb);
814 return result;
815 }
816
817 spirv::Value generate_usub_overflow(const spirv::Value &a,
818 const spirv::Value &b,
819 const std::string &tb) {
820 std::vector<std::tuple<spirv::SType, std::string, size_t>>
821 struct_components_;
822 struct_components_.emplace_back(a.stype, "result", 0);
823 struct_components_.emplace_back(a.stype, "borrow",
824 ir_->get_primitive_type_size(a.stype.dt));
825 auto struct_type = ir_->create_struct_type(struct_components_);
826 auto add_carry = ir_->make_value(spv::OpISubBorrow, struct_type, a, b);
827 auto result =
828 ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 0);
829 auto borrow =
830 ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 1);
831 generate_overflow_branch(borrow, "Subtraction", tb);
832 return result;
833 }
834
835 spirv::Value generate_sadd_overflow(const spirv::Value &a,
836 const spirv::Value &b,
837 const std::string &tb) {
838 // overflow iff (sign(a) == sign(b)) && (sign(a) != sign(result))
839 auto result = ir_->make_value(spv::OpIAdd, a.stype, a, b);
840 auto zero = ir_->int_immediate_number(a.stype, 0);
841 auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero);
842 auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero);
843 auto r_sign =
844 ir_->make_value(spv::OpSLessThan, ir_->bool_type(), result, zero);
845 auto a_eq_b =
846 ir_->make_value(spv::OpLogicalEqual, ir_->bool_type(), a_sign, b_sign);
847 auto a_neq_r = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(),
848 a_sign, r_sign);
849 auto overflow =
850 ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(), a_eq_b, a_neq_r);
851 generate_overflow_branch(overflow, "Addition", tb);
852 return result;
853 }
854
855 spirv::Value generate_ssub_overflow(const spirv::Value &a,
856 const spirv::Value &b,
857 const std::string &tb) {
858 // overflow iff (sign(a) != sign(b)) && (sign(a) != sign(result))
859 auto result = ir_->make_value(spv::OpISub, a.stype, a, b);
860 auto zero = ir_->int_immediate_number(a.stype, 0);
861 auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero);
862 auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero);
863 auto r_sign =
864 ir_->make_value(spv::OpSLessThan, ir_->bool_type(), result, zero);
865 auto a_neq_b = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(),
866 a_sign, b_sign);
867 auto a_neq_r = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(),
868 a_sign, r_sign);
869 auto overflow =
870 ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(), a_neq_b, a_neq_r);
871 generate_overflow_branch(overflow, "Subtraction", tb);
872 return result;
873 }
874
875 spirv::Value generate_umul_overflow(const spirv::Value &a,
876 const spirv::Value &b,
877 const std::string &tb) {
878 // overflow iff high bits != 0
879 std::vector<std::tuple<spirv::SType, std::string, size_t>>
880 struct_components_;
881 struct_components_.emplace_back(a.stype, "low", 0);
882 struct_components_.emplace_back(a.stype, "high",
883 ir_->get_primitive_type_size(a.stype.dt));
884 auto struct_type = ir_->create_struct_type(struct_components_);
885 auto mul_ext = ir_->make_value(spv::OpUMulExtended, struct_type, a, b);
886 auto low = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 0);
887 auto high = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 1);
888 generate_overflow_branch(high, "Multiplication", tb);
889 return low;
890 }
891
892 spirv::Value generate_smul_overflow(const spirv::Value &a,
893 const spirv::Value &b,
894 const std::string &tb) {
895 // overflow if high bits are not all sign bit (0 if positive, -1 if
896 // negative) or the sign bit of the low bits is not the expected sign bit.
897 std::vector<std::tuple<spirv::SType, std::string, size_t>>
898 struct_components_;
899 struct_components_.emplace_back(a.stype, "low", 0);
900 struct_components_.emplace_back(a.stype, "high",
901 ir_->get_primitive_type_size(a.stype.dt));
902 auto struct_type = ir_->create_struct_type(struct_components_);
903 auto mul_ext = ir_->make_value(spv::OpSMulExtended, struct_type, a, b);
904 auto low = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 0);
905 auto high = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 1);
906 auto zero = ir_->int_immediate_number(a.stype, 0);
907 auto minus_one = ir_->int_immediate_number(a.stype, -1);
908 auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero);
909 auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero);
910 auto a_not_zero = ir_->ne(a, zero);
911 auto b_not_zero = ir_->ne(b, zero);
912 auto a_b_not_zero = ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(),
913 a_not_zero, b_not_zero);
914 auto low_sign =
915 ir_->make_value(spv::OpSLessThan, ir_->bool_type(), low, zero);
916 auto expected_sign = ir_->make_value(spv::OpLogicalNotEqual,
917 ir_->bool_type(), a_sign, b_sign);
918 expected_sign = ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(),
919 expected_sign, a_b_not_zero);
920 auto not_expected_sign = ir_->ne(low_sign, expected_sign);
921 auto expected_high = ir_->select(expected_sign, minus_one, zero);
922 auto not_expected_high = ir_->ne(high, expected_high);
923 auto overflow = ir_->make_value(spv::OpLogicalOr, ir_->bool_type(),
924 not_expected_high, not_expected_sign);
925 generate_overflow_branch(overflow, "Multiplication", tb);
926 return low;
927 }
928
929 spirv::Value generate_ushl_overflow(const spirv::Value &a,
930 const spirv::Value &b,
931 const std::string &tb) {
932 // overflow iff a << b >> b != a
933 auto result = ir_->make_value(spv::OpShiftLeftLogical, a.stype, a, b);
934 auto restore =
935 ir_->make_value(spv::OpShiftRightLogical, a.stype, result, b);
936 auto overflow = ir_->ne(a, restore);
937 generate_overflow_branch(overflow, "Shift left", tb);
938 return result;
939 }
940
941 spirv::Value generate_sshl_overflow(const spirv::Value &a,
942 const spirv::Value &b,
943 const std::string &tb) {
944 // overflow iff a << b >> b != a
945 auto result = ir_->make_value(spv::OpShiftLeftLogical, a.stype, a, b);
946 auto restore =
947 ir_->make_value(spv::OpShiftRightArithmetic, a.stype, result, b);
948 auto overflow = ir_->ne(a, restore);
949 generate_overflow_branch(overflow, "Shift left", tb);
950 return result;
951 }
952
953 void visit(BinaryOpStmt *bin) override {
954 const auto lhs_name = bin->lhs->raw_name();
955 const auto rhs_name = bin->rhs->raw_name();
956 const auto bin_name = bin->raw_name();
957 const auto op_type = bin->op_type;
958
959 spirv::SType dst_type = ir_->get_primitive_type(bin->element_type());
960 spirv::Value lhs_value = ir_->query_value(lhs_name);
961 spirv::Value rhs_value = ir_->query_value(rhs_name);
962 spirv::Value bin_value = spirv::Value();
963
964 TI_WARN_IF(lhs_value.stype.id != rhs_value.stype.id,
965 "${} type {} != ${} type {}\n{}", lhs_name,
966 lhs_value.stype.dt->to_string(), rhs_name,
967 rhs_value.stype.dt->to_string(), bin->tb);
968
969 bool debug = caps_->get(DeviceCapability::spirv_has_non_semantic_info);
970
971 if (debug && op_type == BinaryOpType::add && is_integral(dst_type.dt)) {
972 if (is_unsigned(dst_type.dt)) {
973 bin_value = generate_uadd_overflow(lhs_value, rhs_value, bin->tb);
974 } else {
975 bin_value = generate_sadd_overflow(lhs_value, rhs_value, bin->tb);
976 }
977 bin_value = ir_->cast(dst_type, bin_value);
978 } else if (debug && op_type == BinaryOpType::sub &&
979 is_integral(dst_type.dt)) {
980 if (is_unsigned(dst_type.dt)) {
981 bin_value = generate_usub_overflow(lhs_value, rhs_value, bin->tb);
982 } else {
983 bin_value = generate_ssub_overflow(lhs_value, rhs_value, bin->tb);
984 }
985 bin_value = ir_->cast(dst_type, bin_value);
986 } else if (debug && op_type == BinaryOpType::mul &&
987 is_integral(dst_type.dt)) {
988 if (is_unsigned(dst_type.dt)) {
989 bin_value = generate_umul_overflow(lhs_value, rhs_value, bin->tb);
990 } else {
991 bin_value = generate_smul_overflow(lhs_value, rhs_value, bin->tb);
992 }
993 bin_value = ir_->cast(dst_type, bin_value);
994 }
995#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func) \
996 else if (op_type == BinaryOpType::op) { \
997 bin_value = ir_->func(lhs_value, rhs_value); \
998 bin_value = ir_->cast(dst_type, bin_value); \
999 }
1000
1001 BINARY_OP_TO_SPIRV_ARTHIMATIC(add, add)
1002 BINARY_OP_TO_SPIRV_ARTHIMATIC(sub, sub)
1003 BINARY_OP_TO_SPIRV_ARTHIMATIC(mul, mul)
1004 BINARY_OP_TO_SPIRV_ARTHIMATIC(div, div)
1005 BINARY_OP_TO_SPIRV_ARTHIMATIC(mod, mod)
1006#undef BINARY_OP_TO_SPIRV_ARTHIMATIC
1007
1008#define BINARY_OP_TO_SPIRV_BITWISE(op, sym) \
1009 else if (op_type == BinaryOpType::op) { \
1010 bin_value = ir_->make_value(spv::sym, dst_type, lhs_value, rhs_value); \
1011 }
1012
1013 else if (debug && op_type == BinaryOpType::bit_shl) {
1014 if (is_unsigned(dst_type.dt)) {
1015 bin_value = generate_ushl_overflow(lhs_value, rhs_value, bin->tb);
1016 } else {
1017 bin_value = generate_sshl_overflow(lhs_value, rhs_value, bin->tb);
1018 }
1019 }
1020 BINARY_OP_TO_SPIRV_BITWISE(bit_and, OpBitwiseAnd)
1021 BINARY_OP_TO_SPIRV_BITWISE(bit_or, OpBitwiseOr)
1022 BINARY_OP_TO_SPIRV_BITWISE(bit_xor, OpBitwiseXor)
1023 BINARY_OP_TO_SPIRV_BITWISE(bit_shl, OpShiftLeftLogical)
1024 // NOTE: `OpShiftRightArithmetic` will treat the first bit as sign bit even
1025 // it's the unsigned type
1026 else if (op_type == BinaryOpType::bit_sar) {
1027 bin_value = ir_->make_value(is_unsigned(dst_type.dt)
1028 ? spv::OpShiftRightLogical
1029 : spv::OpShiftRightArithmetic,
1030 dst_type, lhs_value, rhs_value);
1031 }
1032#undef BINARY_OP_TO_SPIRV_BITWISE
1033
1034#define BINARY_OP_TO_SPIRV_LOGICAL(op, func) \
1035 else if (op_type == BinaryOpType::op) { \
1036 bin_value = ir_->func(lhs_value, rhs_value); \
1037 bin_value = ir_->cast(dst_type, bin_value); \
1038 bin_value = ir_->make_value(spv::OpSNegate, dst_type, bin_value); \
1039 }
1040
1041 BINARY_OP_TO_SPIRV_LOGICAL(cmp_lt, lt)
1042 BINARY_OP_TO_SPIRV_LOGICAL(cmp_le, le)
1043 BINARY_OP_TO_SPIRV_LOGICAL(cmp_gt, gt)
1044 BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge)
1045 BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq)
1046 BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
1047#undef BINARY_OP_TO_SPIRV_LOGICAL
1048
1049#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
1050 max_bits) \
1051 else if (op_type == BinaryOpType::op) { \
1052 const uint32_t instruction = instruction_id; \
1053 if (is_real(bin->element_type())) { \
1054 if (data_type_bits(bin->element_type()) > max_bits) { \
1055 TI_ERROR( \
1056 "[glsl450] the operand type of instruction {}({}) must <= {}bits", \
1057 #instruction, instruction_id, max_bits); \
1058 } \
1059 bin_value = \
1060 ir_->call_glsl450(dst_type, instruction, lhs_value, rhs_value); \
1061 } else { \
1062 TI_NOT_IMPLEMENTED \
1063 } \
1064 }
1065
1066 FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(atan2, Atan2, 25, 32)
1067 FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
1068#undef FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC
1069
1070#define BINARY_OP_TO_SPIRV_FUNC(op, S_inst, S_inst_id, U_inst, U_inst_id, \
1071 F_inst, F_inst_id) \
1072 else if (op_type == BinaryOpType::op) { \
1073 const uint32_t S_inst = S_inst_id; \
1074 const uint32_t U_inst = U_inst_id; \
1075 const uint32_t F_inst = F_inst_id; \
1076 const auto dst_dt = bin->element_type(); \
1077 if (is_integral(dst_dt)) { \
1078 if (is_signed(dst_dt)) { \
1079 bin_value = ir_->call_glsl450(dst_type, S_inst, lhs_value, rhs_value); \
1080 } else { \
1081 bin_value = ir_->call_glsl450(dst_type, U_inst, lhs_value, rhs_value); \
1082 } \
1083 } else if (is_real(dst_dt)) { \
1084 bin_value = ir_->call_glsl450(dst_type, F_inst, lhs_value, rhs_value); \
1085 } else { \
1086 TI_NOT_IMPLEMENTED \
1087 } \
1088 }
1089
1090 BINARY_OP_TO_SPIRV_FUNC(min, SMin, 39, UMin, 38, FMin, 37)
1091 BINARY_OP_TO_SPIRV_FUNC(max, SMax, 42, UMax, 41, FMax, 40)
1092#undef BINARY_OP_TO_SPIRV_FUNC
1093 else if (op_type == BinaryOpType::truediv) {
1094 lhs_value = ir_->cast(dst_type, lhs_value);
1095 rhs_value = ir_->cast(dst_type, rhs_value);
1096 bin_value = ir_->div(lhs_value, rhs_value);
1097 }
1098 else {TI_NOT_IMPLEMENTED} ir_->register_value(bin_name, bin_value);
1099 }
1100
1101 void visit(TernaryOpStmt *tri) override {
1102 TI_ASSERT(tri->op_type == TernaryOpType::select);
1103 spirv::Value op1 = ir_->query_value(tri->op1->raw_name());
1104 spirv::Value op2 = ir_->query_value(tri->op2->raw_name());
1105 spirv::Value op3 = ir_->query_value(tri->op3->raw_name());
1106 spirv::Value tri_val =
1107 ir_->cast(ir_->get_primitive_type(tri->element_type()),
1108 ir_->select(ir_->cast(ir_->bool_type(), op1), op2, op3));
1109 ir_->register_value(tri->raw_name(), tri_val);
1110 }
1111
1112 inline bool ends_with(std::string const &value, std::string const &ending) {
1113 if (ending.size() > value.size())
1114 return false;
1115 return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
1116 }
1117
1118 void visit(TexturePtrStmt *stmt) override {
1119 spirv::Value val;
1120
1121 int arg_id = stmt->arg_load_stmt->as<ArgLoadStmt>()->arg_id;
1122 if (argid_to_tex_value_.find(arg_id) != argid_to_tex_value_.end()) {
1123 val = argid_to_tex_value_.at(arg_id);
1124 } else {
1125 if (stmt->is_storage) {
1126 BufferFormat format = BufferFormat::unknown;
1127
1128 if (stmt->num_channels == 1) {
1129 if (stmt->channel_format == PrimitiveType::u8 ||
1130 stmt->channel_format == PrimitiveType::i8) {
1131 format = BufferFormat::r8;
1132 } else if (stmt->channel_format == PrimitiveType::u16 ||
1133 stmt->channel_format == PrimitiveType::i16) {
1134 format = BufferFormat::r16;
1135 } else if (stmt->channel_format == PrimitiveType::f16) {
1136 format = BufferFormat::r16f;
1137 } else if (stmt->channel_format == PrimitiveType::f32) {
1138 format = BufferFormat::r32f;
1139 }
1140 } else if (stmt->num_channels == 2) {
1141 if (stmt->channel_format == PrimitiveType::u8 ||
1142 stmt->channel_format == PrimitiveType::i8) {
1143 format = BufferFormat::rg8;
1144 } else if (stmt->channel_format == PrimitiveType::u16 ||
1145 stmt->channel_format == PrimitiveType::i16) {
1146 format = BufferFormat::rg16;
1147 } else if (stmt->channel_format == PrimitiveType::f16) {
1148 format = BufferFormat::rg16f;
1149 } else if (stmt->channel_format == PrimitiveType::f32) {
1150 format = BufferFormat::rg32f;
1151 }
1152 } else if (stmt->num_channels == 4) {
1153 if (stmt->channel_format == PrimitiveType::u8 ||
1154 stmt->channel_format == PrimitiveType::i8) {
1155 format = BufferFormat::rgba8;
1156 } else if (stmt->channel_format == PrimitiveType::u16 ||
1157 stmt->channel_format == PrimitiveType::i16) {
1158 format = BufferFormat::rgba16;
1159 } else if (stmt->channel_format == PrimitiveType::f16) {
1160 format = BufferFormat::rgba16f;
1161 } else if (stmt->channel_format == PrimitiveType::f32) {
1162 format = BufferFormat::rgba32f;
1163 }
1164 }
1165
1166 int binding = binding_head_++;
1167 val =
1168 ir_->storage_image_argument(/*num_channels=*/4, stmt->dimensions,
1169 /*descriptor_set=*/0, binding, format);
1170 TextureBind bind;
1171 bind.arg_id = arg_id;
1172 bind.binding = binding;
1173 bind.is_storage = true;
1174 texture_binds_.push_back(bind);
1175 argid_to_tex_value_[arg_id] = val;
1176 } else {
1177 int binding = binding_head_++;
1178 val = ir_->texture_argument(/*num_channels=*/4, stmt->dimensions,
1179 /*descriptor_set=*/0, binding);
1180 TextureBind bind;
1181 bind.arg_id = arg_id;
1182 bind.binding = binding;
1183 texture_binds_.push_back(bind);
1184 argid_to_tex_value_[arg_id] = val;
1185 }
1186 }
1187
1188 ir_->register_value(stmt->raw_name(), val);
1189 }
1190
1191 void visit(TextureOpStmt *stmt) override {
1192 spirv::Value tex = ir_->query_value(stmt->texture_ptr->raw_name());
1193 spirv::Value val;
1194 if (stmt->op == TextureOpType::kSampleLod ||
1195 stmt->op == TextureOpType::kFetchTexel) {
1196 // Texture Ops
1197 std::vector<spirv::Value> args;
1198 for (int i = 0; i < stmt->args.size() - 1; i++) {
1199 args.push_back(ir_->query_value(stmt->args[i]->raw_name()));
1200 }
1201 spirv::Value lod = ir_->query_value(stmt->args.back()->raw_name());
1202 if (stmt->op == TextureOpType::kSampleLod) {
1203 // Sample
1204 val = ir_->sample_texture(tex, args, lod);
1205 } else if (stmt->op == TextureOpType::kFetchTexel) {
1206 // Texel fetch
1207 val = ir_->fetch_texel(tex, args, lod);
1208 }
1209 ir_->register_value(stmt->raw_name(), val);
1210 } else if (stmt->op == TextureOpType::kLoad ||
1211 stmt->op == TextureOpType::kStore) {
1212 // Image Ops
1213 std::vector<spirv::Value> args;
1214 for (int i = 0; i < stmt->args.size(); i++) {
1215 args.push_back(ir_->query_value(stmt->args[i]->raw_name()));
1216 }
1217 if (stmt->op == TextureOpType::kLoad) {
1218 // Image Load
1219 val = ir_->image_load(tex, args);
1220 ir_->register_value(stmt->raw_name(), val);
1221 } else if (stmt->op == TextureOpType::kStore) {
1222 // Image Store
1223 ir_->image_store(tex, args);
1224 }
1225 } else {
1226 TI_NOT_IMPLEMENTED;
1227 }
1228 }
1229
1230 void visit(InternalFuncStmt *stmt) override {
1231 spirv::Value val;
1232
1233 if (stmt->func_name == "composite_extract_0") {
1234 val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
1235 ir_->query_value(stmt->args[0]->raw_name()), 0);
1236 } else if (stmt->func_name == "composite_extract_1") {
1237 val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
1238 ir_->query_value(stmt->args[0]->raw_name()), 1);
1239 } else if (stmt->func_name == "composite_extract_2") {
1240 val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
1241 ir_->query_value(stmt->args[0]->raw_name()), 2);
1242 } else if (stmt->func_name == "composite_extract_3") {
1243 val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
1244 ir_->query_value(stmt->args[0]->raw_name()), 3);
1245 }
1246
1247 const std::unordered_set<std::string> reduction_ops{
1248 "subgroupAdd", "subgroupMul", "subgroupMin", "subgroupMax",
1249 "subgroupAnd", "subgroupOr", "subgroupXor"};
1250
1251 const std::unordered_set<std::string> inclusive_scan_ops{
1252 "subgroupInclusiveAdd", "subgroupInclusiveMul", "subgroupInclusiveMin",
1253 "subgroupInclusiveMax", "subgroupInclusiveAnd", "subgroupInclusiveOr",
1254 "subgroupInclusiveXor"};
1255
1256 const std::unordered_set<std::string> shuffle_ops{
1257 "subgroupShuffleDown", "subgroupShuffleUp", "subgroupShuffle"};
1258
1259 if (stmt->func_name == "workgroupBarrier") {
1260 ir_->make_inst(
1261 spv::OpControlBarrier,
1262 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
1263 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
1264 ir_->int_immediate_number(
1265 ir_->i32_type(), spv::MemorySemanticsWorkgroupMemoryMask |
1266 spv::MemorySemanticsAcquireReleaseMask));
1267 val = ir_->const_i32_zero_;
1268 } else if (stmt->func_name == "localInvocationId") {
1269 val = ir_->cast(ir_->i32_type(), ir_->get_local_invocation_id(0));
1270 } else if (stmt->func_name == "globalInvocationId") {
1271 val = ir_->cast(ir_->i32_type(), ir_->get_global_invocation_id(0));
1272 } else if (stmt->func_name == "workgroupMemoryBarrier") {
1273 ir_->make_inst(
1274 spv::OpMemoryBarrier,
1275 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
1276 ir_->int_immediate_number(
1277 ir_->i32_type(), spv::MemorySemanticsWorkgroupMemoryMask |
1278 spv::MemorySemanticsAcquireReleaseMask));
1279 val = ir_->const_i32_zero_;
1280 } else if (stmt->func_name == "subgroupElect") {
1281 val = ir_->make_value(
1282 spv::OpGroupNonUniformElect, ir_->bool_type(),
1283 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup));
1284 val = ir_->cast(ir_->i32_type(), val);
1285 } else if (stmt->func_name == "subgroupBarrier") {
1286 ir_->make_inst(
1287 spv::OpControlBarrier,
1288 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
1289 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
1290 ir_->const_i32_zero_);
1291 val = ir_->const_i32_zero_;
1292 } else if (stmt->func_name == "subgroupMemoryBarrier") {
1293 ir_->make_inst(
1294 spv::OpMemoryBarrier,
1295 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
1296 ir_->const_i32_zero_);
1297 val = ir_->const_i32_zero_;
1298 } else if (stmt->func_name == "subgroupSize") {
1299 val = ir_->cast(ir_->i32_type(), ir_->get_subgroup_size());
1300 } else if (stmt->func_name == "subgroupInvocationId") {
1301 val = ir_->cast(ir_->i32_type(), ir_->get_subgroup_invocation_id());
1302 } else if (stmt->func_name == "subgroupBroadcast") {
1303 auto value = ir_->query_value(stmt->args[0]->raw_name());
1304 auto index = ir_->query_value(stmt->args[1]->raw_name());
1305 val = ir_->make_value(
1306 spv::OpGroupNonUniformBroadcast, value.stype,
1307 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), value,
1308 index);
1309 } else if (reduction_ops.find(stmt->func_name) != reduction_ops.end() ||
1310 inclusive_scan_ops.find(stmt->func_name) !=
1311 inclusive_scan_ops.end()) {
1312 auto arg = ir_->query_value(stmt->args[0]->raw_name());
1313 auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type);
1314 spv::Op spv_op;
1315
1316 if (ends_with(stmt->func_name, "Add")) {
1317 if (is_integral(stmt->args[0]->ret_type)) {
1318 spv_op = spv::OpGroupNonUniformIAdd;
1319 } else {
1320 spv_op = spv::OpGroupNonUniformFAdd;
1321 }
1322 } else if (ends_with(stmt->func_name, "Mul")) {
1323 if (is_integral(stmt->args[0]->ret_type)) {
1324 spv_op = spv::OpGroupNonUniformIMul;
1325 } else {
1326 spv_op = spv::OpGroupNonUniformFMul;
1327 }
1328 } else if (ends_with(stmt->func_name, "Min")) {
1329 if (is_integral(stmt->args[0]->ret_type)) {
1330 if (is_signed(stmt->args[0]->ret_type)) {
1331 spv_op = spv::OpGroupNonUniformSMin;
1332 } else {
1333 spv_op = spv::OpGroupNonUniformUMin;
1334 }
1335 } else {
1336 spv_op = spv::OpGroupNonUniformFMin;
1337 }
1338 } else if (ends_with(stmt->func_name, "Max")) {
1339 if (is_integral(stmt->args[0]->ret_type)) {
1340 if (is_signed(stmt->args[0]->ret_type)) {
1341 spv_op = spv::OpGroupNonUniformSMax;
1342 } else {
1343 spv_op = spv::OpGroupNonUniformUMax;
1344 }
1345 } else {
1346 spv_op = spv::OpGroupNonUniformFMax;
1347 }
1348 } else if (ends_with(stmt->func_name, "And")) {
1349 spv_op = spv::OpGroupNonUniformBitwiseAnd;
1350 } else if (ends_with(stmt->func_name, "Or")) {
1351 spv_op = spv::OpGroupNonUniformBitwiseOr;
1352 } else if (ends_with(stmt->func_name, "Xor")) {
1353 spv_op = spv::OpGroupNonUniformBitwiseXor;
1354 } else {
1355 TI_ERROR("Unsupported operation: {}", stmt->func_name);
1356 }
1357
1358 spv::GroupOperation group_op;
1359
1360 if (reduction_ops.find(stmt->func_name) != reduction_ops.end()) {
1361 group_op = spv::GroupOperationReduce;
1362 } else if (inclusive_scan_ops.find(stmt->func_name) !=
1363 inclusive_scan_ops.end()) {
1364 group_op = spv::GroupOperationInclusiveScan;
1365 }
1366
1367 val = ir_->make_value(
1368 spv_op, stype,
1369 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
1370 group_op, arg);
1371 } else if (shuffle_ops.find(stmt->func_name) != shuffle_ops.end()) {
1372 auto arg0 = ir_->query_value(stmt->args[0]->raw_name());
1373 auto arg1 = ir_->query_value(stmt->args[1]->raw_name());
1374 auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type);
1375 spv::Op spv_op;
1376
1377 if (ends_with(stmt->func_name, "Down")) {
1378 spv_op = spv::OpGroupNonUniformShuffleDown;
1379 } else if (ends_with(stmt->func_name, "Up")) {
1380 spv_op = spv::OpGroupNonUniformShuffleUp;
1381 } else if (ends_with(stmt->func_name, "Shuffle")) {
1382 spv_op = spv::OpGroupNonUniformShuffle;
1383 } else {
1384 TI_ERROR("Unsupported operation: {}", stmt->func_name);
1385 }
1386
1387 val = ir_->make_value(
1388 spv_op, stype,
1389 ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), arg0,
1390 arg1);
1391 }
1392 ir_->register_value(stmt->raw_name(), val);
1393 }
1394
1395 void visit(AtomicOpStmt *stmt) override {
1396 const auto dt = stmt->dest->element_type().ptr_removed();
1397
1398 spirv::Value data = ir_->query_value(stmt->val->raw_name());
1399 spirv::Value val;
1400 bool use_subgroup_reduction = false;
1401
1402 if (stmt->is_reduction &&
1403 caps_->get(DeviceCapability::spirv_has_subgroup_arithmetic)) {
1404 spv::Op atomic_op = spv::OpNop;
1405 bool negation = false;
1406 if (is_integral(dt)) {
1407 if (stmt->op_type == AtomicOpType::add) {
1408 atomic_op = spv::OpGroupIAdd;
1409 } else if (stmt->op_type == AtomicOpType::sub) {
1410 atomic_op = spv::OpGroupIAdd;
1411 negation = true;
1412 } else if (stmt->op_type == AtomicOpType::min) {
1413 atomic_op = is_signed(dt) ? spv::OpGroupSMin : spv::OpGroupUMin;
1414 } else if (stmt->op_type == AtomicOpType::max) {
1415 atomic_op = is_signed(dt) ? spv::OpGroupSMax : spv::OpGroupUMax;
1416 }
1417 } else if (is_real(dt)) {
1418 if (stmt->op_type == AtomicOpType::add) {
1419 atomic_op = spv::OpGroupFAdd;
1420 } else if (stmt->op_type == AtomicOpType::sub) {
1421 atomic_op = spv::OpGroupFAdd;
1422 negation = true;
1423 } else if (stmt->op_type == AtomicOpType::min) {
1424 atomic_op = spv::OpGroupFMin;
1425 } else if (stmt->op_type == AtomicOpType::max) {
1426 atomic_op = spv::OpGroupFMax;
1427 }
1428 }
1429
1430 if (atomic_op != spv::OpNop) {
1431 spirv::Value scope_subgroup =
1432 ir_->int_immediate_number(ir_->i32_type(), 3);
1433 spirv::Value operation_reduce = ir_->const_i32_zero_;
1434 if (negation) {
1435 if (is_integral(dt)) {
1436 data = ir_->make_value(spv::OpSNegate, data.stype, data);
1437 } else {
1438 data = ir_->make_value(spv::OpFNegate, data.stype, data);
1439 }
1440 }
1441 data = ir_->make_value(atomic_op, ir_->get_primitive_type(dt),
1442 scope_subgroup, operation_reduce, data);
1443 val = data;
1444 use_subgroup_reduction = true;
1445 }
1446 }
1447
1448 spirv::Label then_label;
1449 spirv::Label merge_label;
1450
1451 if (use_subgroup_reduction) {
1452 spirv::Value subgroup_id = ir_->get_subgroup_invocation_id();
1453 spirv::Value cond = ir_->make_value(spv::OpIEqual, ir_->bool_type(),
1454 subgroup_id, ir_->const_i32_zero_);
1455
1456 then_label = ir_->new_label();
1457 merge_label = ir_->new_label();
1458 ir_->make_inst(spv::OpSelectionMerge, merge_label,
1459 spv::SelectionControlMaskNone);
1460 ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
1461 ir_->start_label(then_label);
1462 }
1463
1464 spirv::Value addr_ptr;
1465
1466 if (dt->is_primitive(PrimitiveTypeID::f64)) {
1467 if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) &&
1468 stmt->op_type == AtomicOpType::add) {
1469 addr_ptr = at_buffer(stmt->dest, dt);
1470 } else {
1471 addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt));
1472 }
1473 } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
1474 if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) &&
1475 stmt->op_type == AtomicOpType::add) {
1476 addr_ptr = at_buffer(stmt->dest, dt);
1477 } else {
1478 addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt));
1479 }
1480 } else {
1481 addr_ptr = at_buffer(stmt->dest, dt);
1482 }
1483
1484 auto ret_type = ir_->get_primitive_type(dt);
1485
1486 if (is_real(dt)) {
1487 spv::Op atomic_fp_op;
1488 if (stmt->op_type == AtomicOpType::add) {
1489 atomic_fp_op = spv::OpAtomicFAddEXT;
1490 }
1491
1492 bool use_native_atomics = false;
1493
1494 if (dt->is_primitive(PrimitiveTypeID::f64)) {
1495 if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) &&
1496 stmt->op_type == AtomicOpType::add) {
1497 use_native_atomics = true;
1498 }
1499 } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
1500 if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) &&
1501 stmt->op_type == AtomicOpType::add) {
1502 use_native_atomics = true;
1503 }
1504 } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
1505 if (caps_->get(DeviceCapability::spirv_has_atomic_float16_add) &&
1506 stmt->op_type == AtomicOpType::add) {
1507 use_native_atomics = true;
1508 }
1509 }
1510
1511 if (use_native_atomics) {
1512 val =
1513 ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr,
1514 /*scope=*/ir_->const_i32_one_,
1515 /*semantics=*/ir_->const_i32_zero_, data);
1516 } else {
1517 val = ir_->float_atomic(stmt->op_type, addr_ptr, data);
1518 }
1519 } else if (is_integral(dt)) {
1520 spv::Op op;
1521 if (stmt->op_type == AtomicOpType::add) {
1522 op = spv::OpAtomicIAdd;
1523 } else if (stmt->op_type == AtomicOpType::sub) {
1524 op = spv::OpAtomicISub;
1525 } else if (stmt->op_type == AtomicOpType::min) {
1526 op = is_signed(dt) ? spv::OpAtomicSMin : spv::OpAtomicUMin;
1527 } else if (stmt->op_type == AtomicOpType::max) {
1528 op = is_signed(dt) ? spv::OpAtomicSMax : spv::OpAtomicUMax;
1529 } else if (stmt->op_type == AtomicOpType::bit_or) {
1530 op = spv::OpAtomicOr;
1531 } else if (stmt->op_type == AtomicOpType::bit_and) {
1532 op = spv::OpAtomicAnd;
1533 } else if (stmt->op_type == AtomicOpType::bit_xor) {
1534 op = spv::OpAtomicXor;
1535 } else {
1536 TI_NOT_IMPLEMENTED
1537 }
1538
1539 auto uint_type = ir_->get_primitive_uint_type(dt);
1540
1541 if (data.stype.id != addr_ptr.stype.element_type_id) {
1542 data = ir_->make_value(spv::OpBitcast, ret_type, data);
1543 }
1544
1545 // Semantics = (UniformMemory 0x40) | (AcquireRelease 0x8)
1546 ir_->make_inst(
1547 spv::OpMemoryBarrier, ir_->const_i32_one_,
1548 ir_->uint_immediate_number(
1549 ir_->u32_type(), spv::MemorySemanticsAcquireReleaseMask |
1550 spv::MemorySemanticsUniformMemoryMask));
1551 val = ir_->make_value(op, ret_type, addr_ptr,
1552 /*scope=*/ir_->const_i32_one_,
1553 /*semantics=*/ir_->const_i32_zero_, data);
1554
1555 if (val.stype.id != ret_type.id) {
1556 val = ir_->make_value(spv::OpBitcast, ret_type, val);
1557 }
1558 } else {
1559 TI_NOT_IMPLEMENTED
1560 }
1561
1562 if (use_subgroup_reduction) {
1563 ir_->make_inst(spv::OpBranch, merge_label);
1564 ir_->start_label(merge_label);
1565 }
1566
1567 ir_->register_value(stmt->raw_name(), val);
1568 }
1569
1570 void visit(IfStmt *if_stmt) override {
1571 spirv::Value cond_v = ir_->query_value(if_stmt->cond->raw_name());
1572 spirv::Value cond =
1573 ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
1574 spirv::Label then_label = ir_->new_label();
1575 spirv::Label merge_label = ir_->new_label();
1576 spirv::Label else_label = ir_->new_label();
1577 ir_->make_inst(spv::OpSelectionMerge, merge_label,
1578 spv::SelectionControlMaskNone);
1579 ir_->make_inst(spv::OpBranchConditional, cond, then_label, else_label);
1580 // then block
1581 ir_->start_label(then_label);
1582 if (if_stmt->true_statements) {
1583 if_stmt->true_statements->accept(this);
1584 }
1585 // ContinueStmt must be in IfStmt
1586 if (gen_label_) { // Skip OpBranch, because ContinueStmt already generated
1587 // one
1588 gen_label_ = false;
1589 } else {
1590 ir_->make_inst(spv::OpBranch, merge_label);
1591 }
1592 // else block
1593 ir_->start_label(else_label);
1594 if (if_stmt->false_statements) {
1595 if_stmt->false_statements->accept(this);
1596 }
1597 if (gen_label_) {
1598 gen_label_ = false;
1599 } else {
1600 ir_->make_inst(spv::OpBranch, merge_label);
1601 }
1602 // merge label
1603 ir_->start_label(merge_label);
1604 }
1605
1606 void visit(RangeForStmt *for_stmt) override {
1607 auto loop_var_name = for_stmt->raw_name();
1608 // Must get init label after making value(to make sure they are correct)
1609 spirv::Label init_label = ir_->current_label();
1610 spirv::Label head_label = ir_->new_label();
1611 spirv::Label body_label = ir_->new_label();
1612 spirv::Label continue_label = ir_->new_label();
1613 spirv::Label merge_label = ir_->new_label();
1614
1615 spirv::Value begin_ = ir_->query_value(for_stmt->begin->raw_name());
1616 spirv::Value end_ = ir_->query_value(for_stmt->end->raw_name());
1617 spirv::Value init_value;
1618 spirv::Value extent_value;
1619 if (!for_stmt->reversed) {
1620 init_value = begin_;
1621 extent_value = end_;
1622 } else {
1623 // reversed for loop
1624 init_value = ir_->sub(end_, ir_->const_i32_one_);
1625 extent_value = begin_;
1626 }
1627 ir_->make_inst(spv::OpBranch, head_label);
1628
1629 // Loop head
1630 ir_->start_label(head_label);
1631 spirv::PhiValue loop_var = ir_->make_phi(init_value.stype, 2);
1632 loop_var.set_incoming(0, init_value, init_label);
1633 spirv::Value loop_cond;
1634 if (!for_stmt->reversed) {
1635 loop_cond = ir_->lt(loop_var, extent_value);
1636 } else {
1637 loop_cond = ir_->ge(loop_var, extent_value);
1638 }
1639 ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
1640 spv::LoopControlMaskNone);
1641 ir_->make_inst(spv::OpBranchConditional, loop_cond, body_label,
1642 merge_label);
1643
1644 // loop body
1645 ir_->start_label(body_label);
1646 push_loop_control_labels(continue_label, merge_label);
1647 ir_->register_value(loop_var_name, spirv::Value(loop_var));
1648 for_stmt->body->accept(this);
1649 pop_loop_control_labels();
1650 ir_->make_inst(spv::OpBranch, continue_label);
1651
1652 // loop continue
1653 ir_->start_label(continue_label);
1654 spirv::Value next_value;
1655 if (!for_stmt->reversed) {
1656 next_value = ir_->add(loop_var, ir_->const_i32_one_);
1657 } else {
1658 next_value = ir_->sub(loop_var, ir_->const_i32_one_);
1659 }
1660 loop_var.set_incoming(1, next_value, ir_->current_label());
1661 ir_->make_inst(spv::OpBranch, head_label);
1662 // loop merge
1663 ir_->start_label(merge_label);
1664 }
1665
1666 void visit(WhileStmt *stmt) override {
1667 spirv::Label head_label = ir_->new_label();
1668 spirv::Label body_label = ir_->new_label();
1669 spirv::Label continue_label = ir_->new_label();
1670 spirv::Label merge_label = ir_->new_label();
1671 ir_->make_inst(spv::OpBranch, head_label);
1672
1673 // Loop head
1674 ir_->start_label(head_label);
1675 ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
1676 spv::LoopControlMaskNone);
1677 ir_->make_inst(spv::OpBranch, body_label);
1678
1679 // loop body
1680 ir_->start_label(body_label);
1681 push_loop_control_labels(continue_label, merge_label);
1682 stmt->body->accept(this);
1683 pop_loop_control_labels();
1684 ir_->make_inst(spv::OpBranch, continue_label);
1685
1686 // loop continue
1687 ir_->start_label(continue_label);
1688 ir_->make_inst(spv::OpBranch, head_label);
1689
1690 // loop merge
1691 ir_->start_label(merge_label);
1692 }
1693
1694 void visit(WhileControlStmt *stmt) override {
1695 spirv::Value cond_v = ir_->query_value(stmt->cond->raw_name());
1696 spirv::Value cond =
1697 ir_->eq(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
1698 spirv::Label then_label = ir_->new_label();
1699 spirv::Label merge_label = ir_->new_label();
1700
1701 ir_->make_inst(spv::OpSelectionMerge, merge_label,
1702 spv::SelectionControlMaskNone);
1703 ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
1704 ir_->start_label(then_label);
1705 ir_->make_inst(spv::OpBranch, current_merge_label()); // break;
1706 ir_->start_label(merge_label);
1707 }
1708
1709 void visit(ContinueStmt *stmt) override {
1710 auto stmt_in_off_for = [stmt]() {
1711 TI_ASSERT(stmt->scope != nullptr);
1712 if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) {
1713 TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for ||
1714 offl->task_type == OffloadedStmt::TaskType::struct_for);
1715 return true;
1716 }
1717 return false;
1718 };
1719 if (stmt_in_off_for()) {
1720 // Return means end THIS main loop and start next loop, not exit kernel
1721 ir_->make_inst(spv::OpBranch, return_label());
1722 } else {
1723 ir_->make_inst(spv::OpBranch, current_continue_label());
1724 }
1725 gen_label_ = true; // Only ContinueStmt will cause duplicate OpBranch,
1726 // which should be eliminated
1727 }
1728
1729 private:
1730 void emit_headers() {
1731 /*
1732 for (int root = 0; root < compiled_structs_.size(); ++root) {
1733 get_buffer_value({BufferType::Root, root});
1734 }
1735 */
1736 std::array<int, 3> group_size = {
1737 task_attribs_.advisory_num_threads_per_group, 1, 1};
1738 ir_->set_work_group_size(group_size);
1739 std::vector<spirv::Value> buffers;
1740 if (caps_->get(DeviceCapability::spirv_version) > 0x10300) {
1741 buffers = shared_array_binds_;
1742 // One buffer can be bound to different bind points but has to be unique
1743 // in OpEntryPoint interface declarations.
1744 // From Spec: before SPIR-V version 1.4, duplication of these interface id
1745 // is tolerated. Starting with version 1.4, an interface id must not
1746 // appear more than once.
1747 std::unordered_set<spirv::Value, spirv::ValueHasher> entry_point_values;
1748 for (const auto &bb : task_attribs_.buffer_binds) {
1749 for (auto &it : buffer_value_map_) {
1750 if (it.first.first == bb.buffer) {
1751 entry_point_values.insert(it.second);
1752 }
1753 }
1754 }
1755 buffers.insert(buffers.end(), entry_point_values.begin(),
1756 entry_point_values.end());
1757 }
1758 ir_->commit_kernel_function(kernel_function_, "main", buffers,
1759 group_size); // kernel entry
1760 }
1761
1762 void generate_serial_kernel(OffloadedStmt *stmt) {
1763 task_attribs_.name = task_name_;
1764 task_attribs_.task_type = OffloadedTaskType::serial;
1765 task_attribs_.advisory_total_num_threads = 1;
1766 task_attribs_.advisory_num_threads_per_group = 1;
1767
1768 // The computation for a single work is wrapped inside a function, so that
1769 // we can do grid-strided loop.
1770 ir_->start_function(kernel_function_);
1771 spirv::Value cond =
1772 ir_->eq(ir_->get_global_invocation_id(0),
1773 ir_->uint_immediate_number(
1774 ir_->u32_type(), 0)); // if (gl_GlobalInvocationID.x > 0)
1775 spirv::Label then_label = ir_->new_label();
1776 spirv::Label merge_label = ir_->new_label();
1777 kernel_return_label_ = merge_label;
1778
1779 ir_->make_inst(spv::OpSelectionMerge, merge_label,
1780 spv::SelectionControlMaskNone);
1781 ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
1782 ir_->start_label(then_label);
1783
1784 // serial kernel
1785 stmt->body->accept(this);
1786
1787 ir_->make_inst(spv::OpBranch, merge_label);
1788 ir_->start_label(merge_label);
1789 ir_->make_inst(spv::OpReturn); // return;
1790 ir_->make_inst(spv::OpFunctionEnd); // } Close kernel
1791
1792 task_attribs_.buffer_binds = get_buffer_binds();
1793 task_attribs_.texture_binds = get_texture_binds();
1794 }
1795
1796 void gen_array_range(Stmt *stmt) {
1797 int num_operands = stmt->num_operands();
1798 for (int i = 0; i < num_operands; i++) {
1799 gen_array_range(stmt->operand(i));
1800 }
1801 offload_loop_motion_.insert(stmt);
1802 stmt->accept(this);
1803 }
1804
1805 void generate_range_for_kernel(OffloadedStmt *stmt) {
1806 task_attribs_.name = task_name_;
1807 task_attribs_.task_type = OffloadedTaskType::range_for;
1808
1809 task_attribs_.range_for_attribs = TaskAttributes::RangeForAttributes();
1810 auto &range_for_attribs = task_attribs_.range_for_attribs.value();
1811 range_for_attribs.const_begin = stmt->const_begin;
1812 range_for_attribs.const_end = stmt->const_end;
1813 range_for_attribs.begin =
1814 (stmt->const_begin ? stmt->begin_value : stmt->begin_offset);
1815 range_for_attribs.end =
1816 (stmt->const_end ? stmt->end_value : stmt->end_offset);
1817
1818 ir_->start_function(kernel_function_);
1819 const std::string total_elems_name("total_elems");
1820 spirv::Value total_elems;
1821 spirv::Value begin_expr_value;
1822 if (range_for_attribs.const_range()) {
1823 const int num_elems = range_for_attribs.end - range_for_attribs.begin;
1824 begin_expr_value = ir_->int_immediate_number(
1825 ir_->i32_type(), stmt->begin_value, false); // Named Constant
1826 total_elems = ir_->int_immediate_number(ir_->i32_type(), num_elems,
1827 false); // Named Constant
1828 task_attribs_.advisory_total_num_threads = num_elems;
1829 } else {
1830 spirv::Value end_expr_value;
1831 if (stmt->end_stmt) {
1832 // Range from args
1833 TI_ASSERT(stmt->const_begin);
1834 begin_expr_value = ir_->int_immediate_number(ir_->i32_type(),
1835 stmt->begin_value, false);
1836 gen_array_range(stmt->end_stmt);
1837 end_expr_value = ir_->query_value(stmt->end_stmt->raw_name());
1838 } else {
1839 // Range from gtmp / constant
1840 if (!stmt->const_begin) {
1841 spirv::Value begin_idx = ir_->make_value(
1842 spv::OpShiftRightArithmetic, ir_->i32_type(),
1843 ir_->int_immediate_number(ir_->i32_type(), stmt->begin_offset),
1844 ir_->int_immediate_number(ir_->i32_type(), 2));
1845 begin_expr_value = ir_->load_variable(
1846 ir_->struct_array_access(
1847 ir_->i32_type(),
1848 get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32),
1849 begin_idx),
1850 ir_->i32_type());
1851 } else {
1852 begin_expr_value = ir_->int_immediate_number(
1853 ir_->i32_type(), stmt->begin_value, false); // Named Constant
1854 }
1855 if (!stmt->const_end) {
1856 spirv::Value end_idx = ir_->make_value(
1857 spv::OpShiftRightArithmetic, ir_->i32_type(),
1858 ir_->int_immediate_number(ir_->i32_type(), stmt->end_offset),
1859 ir_->int_immediate_number(ir_->i32_type(), 2));
1860 end_expr_value = ir_->load_variable(
1861 ir_->struct_array_access(
1862 ir_->i32_type(),
1863 get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32),
1864 end_idx),
1865 ir_->i32_type());
1866 } else {
1867 end_expr_value =
1868 ir_->int_immediate_number(ir_->i32_type(), stmt->end_value, true);
1869 }
1870 }
1871 total_elems = ir_->sub(end_expr_value, begin_expr_value);
1872 task_attribs_.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop;
1873 }
1874 task_attribs_.advisory_num_threads_per_group = stmt->block_dim;
1875 ir_->debug_name(spv::OpName, begin_expr_value, "begin_expr_value");
1876 ir_->debug_name(spv::OpName, total_elems, total_elems_name);
1877
1878 spirv::Value begin_ =
1879 ir_->add(ir_->cast(ir_->i32_type(), ir_->get_global_invocation_id(0)),
1880 begin_expr_value);
1881 ir_->debug_name(spv::OpName, begin_, "begin_");
1882 spirv::Value end_ = ir_->add(total_elems, begin_expr_value);
1883 ir_->debug_name(spv::OpName, end_, "end_");
1884 const std::string total_invocs_name = "total_invocs";
1885 // For now, |total_invocs_name| is equal to |total_elems|. Once we support
1886 // dynamic range, they will be different.
1887 // https://www.khronos.org/opengl/wiki/Compute_Shader#Inputs
1888
1889 // HLSL & WGSL cross compilers do not support this builtin
1890 spirv::Value total_invocs = ir_->cast(
1891 ir_->i32_type(),
1892 ir_->mul(ir_->get_num_work_groups(0),
1893 ir_->uint_immediate_number(
1894 ir_->u32_type(),
1895 task_attribs_.advisory_num_threads_per_group, true)));
1896 /*
1897 const int group_x = (task_attribs_.advisory_total_num_threads +
1898 task_attribs_.advisory_num_threads_per_group - 1) /
1899 task_attribs_.advisory_num_threads_per_group;
1900 spirv::Value total_invocs = ir_->uint_immediate_number(
1901 ir_->i32_type(), group_x * task_attribs_.advisory_num_threads_per_group,
1902 false);
1903 */
1904
1905 ir_->debug_name(spv::OpName, total_invocs, total_invocs_name);
1906
1907 // Must get init label after making value(to make sure they are correct)
1908 spirv::Label init_label = ir_->current_label();
1909 spirv::Label head_label = ir_->new_label();
1910 spirv::Label body_label = ir_->new_label();
1911 spirv::Label continue_label = ir_->new_label();
1912 spirv::Label merge_label = ir_->new_label();
1913 ir_->make_inst(spv::OpBranch, head_label);
1914
1915 // loop head
1916 ir_->start_label(head_label);
1917 spirv::PhiValue loop_var = ir_->make_phi(begin_.stype, 2);
1918 ir_->register_value("ii", loop_var);
1919 loop_var.set_incoming(0, begin_, init_label);
1920 spirv::Value loop_cond = ir_->lt(loop_var, end_);
1921 ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
1922 spv::LoopControlMaskNone);
1923 ir_->make_inst(spv::OpBranchConditional, loop_cond, body_label,
1924 merge_label);
1925
1926 // loop body
1927 ir_->start_label(body_label);
1928 push_loop_control_labels(continue_label, merge_label);
1929
1930 // loop kernel
1931 stmt->body->accept(this);
1932 pop_loop_control_labels();
1933 ir_->make_inst(spv::OpBranch, continue_label);
1934
1935 // loop continue
1936 ir_->start_label(continue_label);
1937 spirv::Value next_value = ir_->add(loop_var, total_invocs);
1938 loop_var.set_incoming(1, next_value, ir_->current_label());
1939 ir_->make_inst(spv::OpBranch, head_label);
1940
1941 // loop merge
1942 ir_->start_label(merge_label);
1943
1944 ir_->make_inst(spv::OpReturn);
1945 ir_->make_inst(spv::OpFunctionEnd);
1946
1947 task_attribs_.buffer_binds = get_buffer_binds();
1948 task_attribs_.texture_binds = get_texture_binds();
1949 }
1950
1951 void generate_struct_for_kernel(OffloadedStmt *stmt) {
1952 task_attribs_.name = task_name_;
1953 task_attribs_.task_type = OffloadedTaskType::struct_for;
1954 task_attribs_.advisory_total_num_threads = 65536;
1955 task_attribs_.advisory_num_threads_per_group = 128;
1956
1957 // The computation for a single work is wrapped inside a function, so that
1958 // we can do grid-strided loop.
1959 ir_->start_function(kernel_function_);
1960
1961 auto listgen_buffer =
1962 get_buffer_value(BufferType::ListGen, PrimitiveType::u32);
1963 auto listgen_count_ptr = ir_->struct_array_access(
1964 ir_->u32_type(), listgen_buffer, ir_->const_i32_zero_);
1965 auto listgen_count = ir_->load_variable(listgen_count_ptr, ir_->u32_type());
1966
1967 auto invoc_index = ir_->get_global_invocation_id(0);
1968
1969 spirv::Label loop_head = ir_->new_label();
1970 spirv::Label loop_body = ir_->new_label();
1971 spirv::Label loop_merge = ir_->new_label();
1972
1973 auto loop_index_var = ir_->alloca_variable(ir_->u32_type());
1974 ir_->store_variable(loop_index_var, invoc_index);
1975
1976 ir_->make_inst(spv::OpBranch, loop_head);
1977 ir_->start_label(loop_head);
1978 // for (; index < list_size; index += gl_NumWorkGroups.x *
1979 // gl_WorkGroupSize.x)
1980 auto loop_index = ir_->load_variable(loop_index_var, ir_->u32_type());
1981 auto loop_cond = ir_->make_value(spv::OpULessThan, ir_->bool_type(),
1982 loop_index, listgen_count);
1983 ir_->make_inst(spv::OpLoopMerge, loop_merge, loop_body,
1984 spv::LoopControlMaskNone);
1985 ir_->make_inst(spv::OpBranchConditional, loop_cond, loop_body, loop_merge);
1986 {
1987 ir_->start_label(loop_body);
1988 auto listgen_index_ptr = ir_->struct_array_access(
1989 ir_->u32_type(), listgen_buffer,
1990 ir_->add(ir_->uint_immediate_number(ir_->u32_type(), 1), loop_index));
1991 auto listgen_index =
1992 ir_->load_variable(listgen_index_ptr, ir_->u32_type());
1993
1994 // kernel
1995 ir_->register_value("ii", listgen_index);
1996 stmt->body->accept(this);
1997
1998 // continue
1999 spirv::Value total_invocs = ir_->cast(
2000 ir_->u32_type(),
2001 ir_->mul(ir_->get_num_work_groups(0),
2002 ir_->uint_immediate_number(
2003 ir_->u32_type(),
2004 task_attribs_.advisory_num_threads_per_group, true)));
2005 auto next_index = ir_->add(loop_index, total_invocs);
2006 ir_->store_variable(loop_index_var, next_index);
2007 ir_->make_inst(spv::OpBranch, loop_head);
2008 }
2009 ir_->start_label(loop_merge);
2010
2011 ir_->make_inst(spv::OpReturn); // return;
2012 ir_->make_inst(spv::OpFunctionEnd); // } Close kernel
2013
2014 task_attribs_.buffer_binds = get_buffer_binds();
2015 task_attribs_.texture_binds = get_texture_binds();
2016 }
2017
2018 spirv::Value at_buffer(const Stmt *ptr, DataType dt) {
2019 spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
2020
2021 if (ptr_val.stype.dt == PrimitiveType::u64) {
2022 spirv::Value paddr_ptr = ir_->make_value(
2023 spv::OpConvertUToPtr,
2024 ir_->get_pointer_type(ir_->get_primitive_type(dt),
2025 spv::StorageClassPhysicalStorageBuffer),
2026 ptr_val);
2027 paddr_ptr.flag = ValueKind::kPhysicalPtr;
2028 return paddr_ptr;
2029 }
2030
2031 spirv::Value buffer = get_buffer_value(ptr_to_buffers_.at(ptr), dt);
2032 size_t width = ir_->get_primitive_type_size(dt);
2033 spirv::Value idx_val = ir_->make_value(
2034 spv::OpShiftRightLogical, ptr_val.stype, ptr_val,
2035 ir_->uint_immediate_number(ptr_val.stype, size_t(std::log2(width))));
2036 spirv::Value ret =
2037 ir_->struct_array_access(ir_->get_primitive_type(dt), buffer, idx_val);
2038 return ret;
2039 }
2040
2041 spirv::Value load_buffer(const Stmt *ptr, DataType dt) {
2042 spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
2043
2044 DataType ti_buffer_type = ir_->get_taichi_uint_type(dt);
2045
2046 if (ptr_val.stype.dt == PrimitiveType::u64) {
2047 ti_buffer_type = dt;
2048 }
2049
2050 auto buf_ptr = at_buffer(ptr, ti_buffer_type);
2051 auto val_bits =
2052 ir_->load_variable(buf_ptr, ir_->get_primitive_type(ti_buffer_type));
2053 auto ret = ti_buffer_type == dt
2054 ? val_bits
2055 : ir_->make_value(spv::OpBitcast,
2056 ir_->get_primitive_type(dt), val_bits);
2057 return ret;
2058 }
2059
2060 void store_buffer(const Stmt *ptr, spirv::Value val) {
2061 spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
2062
2063 DataType ti_buffer_type = ir_->get_taichi_uint_type(val.stype.dt);
2064
2065 if (ptr_val.stype.dt == PrimitiveType::u64) {
2066 ti_buffer_type = val.stype.dt;
2067 }
2068
2069 auto buf_ptr = at_buffer(ptr, ti_buffer_type);
2070 auto val_bits =
2071 val.stype.dt == ti_buffer_type
2072 ? val
2073 : ir_->make_value(spv::OpBitcast,
2074 ir_->get_primitive_type(ti_buffer_type), val);
2075 ir_->store_variable(buf_ptr, val_bits);
2076 }
2077
2078 spirv::Value get_buffer_value(BufferInfo buffer, DataType dt) {
2079 auto type = ir_->get_primitive_type(dt);
2080 auto key = std::make_pair(buffer, type.id);
2081
2082 const auto it = buffer_value_map_.find(key);
2083 if (it != buffer_value_map_.end()) {
2084 return it->second;
2085 }
2086
2087 if (buffer.type == BufferType::Args) {
2088 compile_args_struct();
2089
2090 buffer_binding_map_[key] = 0;
2091 buffer_value_map_[key] = args_buffer_value_;
2092 return args_buffer_value_;
2093 }
2094
2095 if (buffer.type == BufferType::Rets) {
2096 compile_ret_struct();
2097
2098 buffer_binding_map_[key] = 1;
2099 buffer_value_map_[key] = ret_buffer_value_;
2100 return ret_buffer_value_;
2101 }
2102
2103 // Binding head starts at 2, so we don't break args and rets
2104 int binding = binding_head_++;
2105 buffer_binding_map_[key] = binding;
2106
2107 spirv::Value buffer_value =
2108 ir_->buffer_argument(type, 0, binding, buffer_instance_name(buffer));
2109 buffer_value_map_[key] = buffer_value;
2110 TI_TRACE("buffer name = {}, value = {}", buffer_instance_name(buffer),
2111 buffer_value.id);
2112
2113 return buffer_value;
2114 }
2115
2116 spirv::Value make_pointer(size_t offset) {
2117 if (use_64bit_pointers) {
2118 // This is hacky, should check out how to encode uint64 values in spirv
2119 return ir_->cast(ir_->u64_type(), ir_->uint_immediate_number(
2120 ir_->u32_type(), uint32_t(offset)));
2121 } else {
2122 return ir_->uint_immediate_number(ir_->u32_type(), uint32_t(offset));
2123 }
2124 }
2125
2126 void compile_args_struct() {
2127 if (!ctx_attribs_->has_args())
2128 return;
2129
2130 // Generate struct IR
2131 tinyir::Block blk;
2132 std::vector<const tinyir::Type *> element_types;
2133 for (auto &arg : ctx_attribs_->args()) {
2134 const tinyir::Type *t;
2135 if (arg.is_array &&
2136 caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) {
2137 t = blk.emplace_back<IntType>(/*num_bits=*/64, /*is_signed=*/false);
2138 } else {
2139 t = translate_ti_primitive(blk, PrimitiveType::get(arg.dtype));
2140 }
2141 element_types.push_back(t);
2142 }
2143 const tinyir::Type *i32_type =
2144 blk.emplace_back<IntType>(/*num_bits=*/32, /*is_signed=*/true);
2145 for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) {
2146 element_types.push_back(i32_type);
2147 }
2148 const tinyir::Type *struct_type =
2149 blk.emplace_back<StructType>(element_types);
2150
2151 // Reduce struct IR
2152 std::unordered_map<const tinyir::Type *, const tinyir::Type *> old2new;
2153 auto reduced_blk = ir_reduce_types(&blk, old2new);
2154 struct_type = old2new[struct_type];
2155
2156 // Layout & translate to SPIR-V
2157 STD140LayoutContext layout_ctx;
2158 auto ir2spirv_map =
2159 ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get());
2160 args_struct_type_.id = ir2spirv_map[struct_type];
2161
2162 args_buffer_value_ =
2163 ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args");
2164 }
2165
2166 void compile_ret_struct() {
2167 if (!ctx_attribs_->has_rets())
2168 return;
2169
2170 std::vector<std::tuple<spirv::SType, std::string, size_t>>
2171 struct_components_;
2172 // Now we only have one ret
2173 TI_ASSERT(ctx_attribs_->rets().size() == 1);
2174 for (auto &ret : ctx_attribs_->rets()) {
2175 // Use array size = 0 to generate a RuntimeArray
2176 if (auto tensor_type =
2177 PrimitiveType::get(ret.dtype)->cast<TensorType>()) {
2178 struct_components_.emplace_back(
2179 ir_->get_array_type(
2180 ir_->get_primitive_type(tensor_type->get_element_type()), 0),
2181 "ret" + std::to_string(ret.index), ret.offset_in_mem);
2182 } else {
2183 struct_components_.emplace_back(
2184 ir_->get_array_type(
2185 ir_->get_primitive_type(PrimitiveType::get(ret.dtype)), 0),
2186 "ret" + std::to_string(ret.index), ret.offset_in_mem);
2187 }
2188 }
2189 ret_struct_type_ = ir_->create_struct_type(struct_components_);
2190
2191 ret_buffer_value_ =
2192 ir_->buffer_struct_argument(ret_struct_type_, 0, 1, "rets");
2193 }
2194
2195 std::vector<BufferBind> get_buffer_binds() {
2196 std::vector<BufferBind> result;
2197 for (auto &[key, val] : buffer_binding_map_) {
2198 result.push_back(BufferBind{key.first, int(val)});
2199 }
2200 return result;
2201 }
2202
2203 std::vector<TextureBind> get_texture_binds() {
2204 return texture_binds_;
2205 }
2206
2207 void push_loop_control_labels(spirv::Label continue_label,
2208 spirv::Label merge_label) {
2209 continue_label_stack_.push_back(continue_label);
2210 merge_label_stack_.push_back(merge_label);
2211 }
2212
2213 void pop_loop_control_labels() {
2214 continue_label_stack_.pop_back();
2215 merge_label_stack_.pop_back();
2216 }
2217
2218 const spirv::Label current_continue_label() const {
2219 return continue_label_stack_.back();
2220 }
2221
2222 const spirv::Label current_merge_label() const {
2223 return merge_label_stack_.back();
2224 }
2225
2226 const spirv::Label return_label() const {
2227 return continue_label_stack_.front();
2228 }
2229
2230 Arch arch_;
2231 DeviceCapabilityConfig *caps_;
2232
2233 struct BufferInfoTypeTupleHasher {
2234 std::size_t operator()(const std::pair<BufferInfo, int> &buf) const {
2235 return BufferInfoHasher()(buf.first) ^ (buf.second << 5);
2236 }
2237 };
2238
2239 spirv::SType args_struct_type_;
2240 spirv::Value args_buffer_value_;
2241
2242 spirv::SType ret_struct_type_;
2243 spirv::Value ret_buffer_value_;
2244
2245 std::shared_ptr<spirv::IRBuilder> ir_; // spirv binary code builder
2246 std::unordered_map<std::pair<BufferInfo, int>,
2247 spirv::Value,
2248 BufferInfoTypeTupleHasher>
2249 buffer_value_map_;
2250 std::unordered_map<std::pair<BufferInfo, int>,
2251 uint32_t,
2252 BufferInfoTypeTupleHasher>
2253 buffer_binding_map_;
2254 std::vector<TextureBind> texture_binds_;
2255 std::vector<spirv::Value> shared_array_binds_;
2256 spirv::Value kernel_function_;
2257 spirv::Label kernel_return_label_;
2258 bool gen_label_{false};
2259
2260 int binding_head_{2}; // Args:0, Ret:1
2261
2262 /*
2263 std::unordered_map<int, spirv::CompiledSpirvSNode>
2264 spirv_snodes_; // maps root id to spirv snode
2265 */
2266
2267 OffloadedStmt *const task_ir_; // not owned
2268 std::vector<CompiledSNodeStructs> compiled_structs_;
2269 std::unordered_map<int, int> snode_to_root_;
2270 const KernelContextAttributes *const ctx_attribs_; // not owned
2271 const std::string task_name_;
2272 std::vector<spirv::Label> continue_label_stack_;
2273 std::vector<spirv::Label> merge_label_stack_;
2274
2275 std::unordered_set<const Stmt *> offload_loop_motion_;
2276
2277 TaskAttributes task_attribs_;
2278 std::unordered_map<int, GetRootStmt *>
2279 root_stmts_; // maps root id to get root stmt
2280 std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_;
2281 std::unordered_map<int, Value> argid_to_tex_value_;
2282};
2283} // namespace
2284
2285static void spriv_message_consumer(spv_message_level_t level,
2286 const char *source,
2287 const spv_position_t &position,
2288 const char *message) {
2289 // TODO: Maybe we can add a macro, e.g. TI_LOG_AT_LEVEL(lv, ...)
2290 if (level <= SPV_MSG_FATAL) {
2291 TI_ERROR("{}\n[{}:{}:{}] {}", source, position.index, position.line,
2292 position.column, message);
2293 } else if (level <= SPV_MSG_WARNING) {
2294 TI_WARN("{}\n[{}:{}:{}] {}", source, position.index, position.line,
2295 position.column, message);
2296 } else if (level <= SPV_MSG_INFO) {
2297 TI_INFO("{}\n[{}:{}:{}] {}", source, position.index, position.line,
2298 position.column, message);
2299 } else if (level <= SPV_MSG_INFO) {
2300 TI_TRACE("{}\n[{}:{}:{}] {}", source, position.index, position.line,
2301 position.column, message);
2302 }
2303}
2304
2305KernelCodegen::KernelCodegen(const Params &params)
2306 : params_(params), ctx_attribs_(*params.kernel, &params.caps) {
2307 uint32_t spirv_version = params.caps.get(DeviceCapability::spirv_version);
2308
2309 spv_target_env target_env;
2310 if (spirv_version >= 0x10600) {
2311 target_env = SPV_ENV_VULKAN_1_3;
2312 } else if (spirv_version >= 0x10500) {
2313 target_env = SPV_ENV_VULKAN_1_2;
2314 } else if (spirv_version >= 0x10400) {
2315 target_env = SPV_ENV_VULKAN_1_1_SPIRV_1_4;
2316 } else if (spirv_version >= 0x10300) {
2317 target_env = SPV_ENV_VULKAN_1_1;
2318 } else {
2319 target_env = SPV_ENV_VULKAN_1_0;
2320 }
2321
2322 spirv_opt_ = std::make_unique<spvtools::Optimizer>(target_env);
2323 spirv_opt_->SetMessageConsumer(spriv_message_consumer);
2324 if (params.enable_spv_opt) {
2325 // From: SPIRV-Tools/source/opt/optimizer.cpp
2326 spirv_opt_->RegisterPass(spvtools::CreateWrapOpKillPass())
2327 .RegisterPass(spvtools::CreateDeadBranchElimPass())
2328 .RegisterPass(spvtools::CreateMergeReturnPass())
2329 .RegisterPass(spvtools::CreateInlineExhaustivePass())
2330 .RegisterPass(spvtools::CreateEliminateDeadFunctionsPass())
2331 .RegisterPass(spvtools::CreateAggressiveDCEPass())
2332 .RegisterPass(spvtools::CreatePrivateToLocalPass())
2333 .RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass())
2334 .RegisterPass(spvtools::CreateLocalSingleStoreElimPass())
2335 .RegisterPass(spvtools::CreateScalarReplacementPass())
2336 .RegisterPass(spvtools::CreateLocalAccessChainConvertPass())
2337 .RegisterPass(spvtools::CreateLocalMultiStoreElimPass())
2338 .RegisterPass(spvtools::CreateCCPPass())
2339 .RegisterPass(spvtools::CreateLoopUnrollPass(true))
2340 .RegisterPass(spvtools::CreateRedundancyEliminationPass())
2341 .RegisterPass(spvtools::CreateCombineAccessChainsPass())
2342 .RegisterPass(spvtools::CreateSimplificationPass())
2343 .RegisterPass(spvtools::CreateSSARewritePass())
2344 .RegisterPass(spvtools::CreateVectorDCEPass())
2345 .RegisterPass(spvtools::CreateDeadInsertElimPass())
2346 .RegisterPass(spvtools::CreateIfConversionPass())
2347 .RegisterPass(spvtools::CreateCopyPropagateArraysPass())
2348 .RegisterPass(spvtools::CreateReduceLoadSizePass())
2349 .RegisterPass(spvtools::CreateBlockMergePass());
2350 }
2351 spirv_opt_options_.set_run_validator(false);
2352
2353 spirv_tools_ = std::make_unique<spvtools::SpirvTools>(target_env);
2354}
2355
2356void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs,
2357 std::vector<std::vector<uint32_t>> &generated_spirv) {
2358 auto *root = params_.kernel->ir->as<Block>();
2359 auto &tasks = root->statements;
2360 for (int i = 0; i < tasks.size(); ++i) {
2361 TaskCodegen::Params tp;
2362 tp.task_ir = tasks[i]->as<OffloadedStmt>();
2363 tp.task_id_in_kernel = i;
2364 tp.compiled_structs = params_.compiled_structs;
2365 tp.ctx_attribs = &ctx_attribs_;
2366 tp.ti_kernel_name = fmt::format("{}_{}", params_.ti_kernel_name, i);
2367 tp.arch = params_.arch;
2368 tp.caps = &params_.caps;
2369
2370 TaskCodegen cgen(tp);
2371 auto task_res = cgen.run();
2372
2373 for (auto &[id, access] : task_res.arr_access) {
2374 ctx_attribs_.arr_access[id] = ctx_attribs_.arr_access[id] | access;
2375 }
2376
2377 std::vector<uint32_t> optimized_spv(task_res.spirv_code);
2378
2379 bool success = true;
2380 {
2381 bool result = false;
2382 TI_ERROR_IF(
2383 (result = !spirv_opt_->Run(optimized_spv.data(), optimized_spv.size(),
2384 &optimized_spv, spirv_opt_options_)),
2385 "SPIRV optimization failed");
2386 if (result) {
2387 success = false;
2388 }
2389 }
2390
2391 TI_TRACE("SPIRV-Tools-opt: binary size, before={}, after={}",
2392 task_res.spirv_code.size(), optimized_spv.size());
2393
2394 // Enable to dump SPIR-V assembly of kernels
2395 if constexpr (false) {
2396 std::vector<uint32_t> &spirv =
2397 success ? optimized_spv : task_res.spirv_code;
2398
2399 std::string spirv_asm;
2400 spirv_tools_->Disassemble(optimized_spv, &spirv_asm);
2401 auto kernel_name = tp.ti_kernel_name;
2402 TI_WARN("SPIR-V Assembly dump for {} :\n{}\n\n", kernel_name, spirv_asm);
2403
2404 std::ofstream fout(kernel_name + ".spv",
2405 std::ios::binary | std::ios::out);
2406 fout.write(reinterpret_cast<const char *>(spirv.data()),
2407 spirv.size() * sizeof(uint32_t));
2408 fout.close();
2409 }
2410
2411 kernel_attribs.tasks_attribs.push_back(std::move(task_res.task_attribs));
2412 generated_spirv.push_back(std::move(optimized_spv));
2413 }
2414 kernel_attribs.ctx_attribs = std::move(ctx_attribs_);
2415 kernel_attribs.name = params_.ti_kernel_name;
2416 kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator;
2417}
2418
2419void lower(const CompileConfig &config, Kernel *kernel) {
2420 if (!kernel->lowered()) {
2421 irpass::compile_to_executable(kernel->ir.get(), config, kernel,
2422 kernel->autodiff_mode,
2423 /*ad_use_stack=*/false, config.print_ir,
2424 /*lower_global_access=*/true,
2425 /*make_thread_local=*/false);
2426 kernel->set_lowered(true);
2427 }
2428}
2429
2430} // namespace spirv
2431} // namespace taichi::lang
2432