1 | #include "taichi/codegen/spirv/spirv_codegen.h" |
2 | |
3 | #include <string> |
4 | #include <vector> |
5 | #include <variant> |
6 | |
7 | #include "taichi/program/program.h" |
8 | #include "taichi/program/kernel.h" |
9 | #include "taichi/ir/statements.h" |
10 | #include "taichi/ir/ir.h" |
11 | #include "taichi/util/line_appender.h" |
12 | #include "taichi/codegen/spirv/kernel_utils.h" |
13 | #include "taichi/codegen/spirv/spirv_ir_builder.h" |
14 | #include "taichi/ir/transforms.h" |
15 | #include "taichi/math/arithmetic.h" |
16 | |
17 | #include <spirv-tools/libspirv.hpp> |
18 | #include <spirv-tools/optimizer.hpp> |
19 | |
20 | namespace taichi::lang { |
21 | namespace spirv { |
22 | namespace { |
23 | |
// Instance names of the storage buffers bound to a compute task; these become
// the GLSL interface-block instance names in the generated shader.
constexpr char kRootBufferName[] = "root_buffer" ;
constexpr char kGlobalTmpsBufferName[] = "global_tmps_buffer" ;
constexpr char kArgsBufferName[] = "args_buffer" ;
constexpr char kRetBufferName[] = "ret_buffer" ;
constexpr char kListgenBufferName[] = "listgen_buffer" ;
constexpr char kExtArrBufferName[] = "ext_arr_buffer" ;

// Cap on the total number of threads launched for a grid-stride loop.
constexpr int kMaxNumThreadsGridStrideLoop = 65536 * 2;

// Shorthand for the task-attribute helper types used throughout this file.
using BufferType = TaskAttributes::BufferType;
using BufferInfo = TaskAttributes::BufferInfo;
using BufferBind = TaskAttributes::BufferBind;
using BufferInfoHasher = TaskAttributes::BufferInfoHasher;

using TextureBind = TaskAttributes::TextureBind;
39 | |
40 | std::string buffer_instance_name(BufferInfo b) { |
41 | // https://www.khronos.org/opengl/wiki/Interface_Block_(GLSL)#Syntax |
42 | switch (b.type) { |
43 | case BufferType::Root: |
44 | return std::string(kRootBufferName) + "_" + std::to_string(b.root_id); |
45 | case BufferType::GlobalTmps: |
46 | return kGlobalTmpsBufferName; |
47 | case BufferType::Args: |
48 | return kArgsBufferName; |
49 | case BufferType::Rets: |
50 | return kRetBufferName; |
51 | case BufferType::ListGen: |
52 | return kListgenBufferName; |
53 | case BufferType::ExtArr: |
54 | return std::string(kExtArrBufferName) + "_" + std::to_string(b.root_id); |
55 | default: |
56 | TI_NOT_IMPLEMENTED; |
57 | break; |
58 | } |
59 | return {}; |
60 | } |
61 | |
62 | class TaskCodegen : public IRVisitor { |
63 | public: |
  // Everything needed to compile one offloaded task into a SPIR-V kernel.
  struct Params {
    OffloadedStmt *task_ir;        // the single offloaded task to lower
    Arch arch;                     // target architecture (e.g. vulkan)
    DeviceCapabilityConfig *caps;  // device capabilities to generate against
    std::vector<CompiledSNodeStructs> compiled_structs;  // per-root layouts
    const KernelContextAttributes *ctx_attribs;  // kernel arg/ret attributes
    std::string ti_kernel_name;  // kernel name; used to derive the task name
    int task_id_in_kernel;       // index of this task within its kernel
  };
73 | |
74 | const bool use_64bit_pointers = false; |
75 | |
  // Builds the per-task codegen context. `params.caps` and
  // `params.ctx_attribs` are stored as raw pointers and must outlive this
  // object.
  explicit TaskCodegen(const Params &params)
      : arch_(params.arch),
        caps_(params.caps),
        task_ir_(params.task_ir),
        compiled_structs_(params.compiled_structs),
        ctx_attribs_(params.ctx_attribs),
        task_name_(fmt::format("{}_t{:02d}" ,
                               params.ti_kernel_name,
                               params.task_id_in_kernel)) {
    // Route statement types without an explicit visit() through the default
    // visitor instead of hard-failing during traversal.
    allow_undefined_visitor = true;
    invoke_default_visitor = true;

    fill_snode_to_root();
    ir_ = std::make_shared<spirv::IRBuilder>(arch_, caps_);
  }
91 | |
92 | void fill_snode_to_root() { |
93 | for (int root = 0; root < compiled_structs_.size(); ++root) { |
94 | for (auto &[node_id, node] : compiled_structs_[root].snode_descriptors) { |
95 | snode_to_root_[node_id] = root; |
96 | } |
97 | } |
98 | } |
99 | |
  // Output of run(): the compiled SPIR-V binary plus task metadata.
  struct Result {
    std::vector<uint32_t> spirv_code;
    TaskAttributes task_attribs;
    // Per-arg access modes (read/write) detected for external array pointers.
    std::unordered_map<int, irpass::ExternalPtrAccess> arr_access;
  };
105 | |
  // Compiles the task: emits the kernel body according to the offload type,
  // then the module headers (which depend on information collected while
  // visiting the IR), and finalizes the SPIR-V binary.
  Result run() {
    ir_->init_header();
    kernel_function_ = ir_->new_function();  // void main();
    ir_->debug_name(spv::OpName, kernel_function_, "main" );

    if (task_ir_->task_type == OffloadedTaskType::serial) {
      generate_serial_kernel(task_ir_);
    } else if (task_ir_->task_type == OffloadedTaskType::range_for) {
      // struct_for is automatically lowered to ranged_for for dense snodes
      generate_range_for_kernel(task_ir_);
    } else if (task_ir_->task_type == OffloadedTaskType::struct_for) {
      generate_struct_for_kernel(task_ir_);
    } else {
      TI_ERROR("Unsupported offload type={} on SPIR-V codegen" ,
               task_ir_->task_name());
    }
    // Headers need global information, so it has to be delayed after visiting
    // the task IR.
    emit_headers();

    Result res;
    res.spirv_code = ir_->finalize();
    res.task_attribs = std::move(task_attribs_);
    res.arr_access = irpass::detect_external_ptr_access_in_task(task_ir_);

    return res;
  }
133 | |
  // A TaskCodegen handles exactly one offloaded task (driven by run());
  // encountering a nested OffloadedStmt is a logic error.
  void visit(OffloadedStmt *) override {
    TI_ERROR("This codegen is supposed to deal with one offloaded task" );
  }
137 | |
138 | void visit(Block *stmt) override { |
139 | for (auto &s : stmt->statements) { |
140 | if (offload_loop_motion_.find(s.get()) == offload_loop_motion_.end()) { |
141 | s->accept(this); |
142 | } |
143 | } |
144 | } |
145 | |
146 | void visit(PrintStmt *stmt) override { |
147 | if (!caps_->get(DeviceCapability::spirv_has_non_semantic_info)) { |
148 | return; |
149 | } |
150 | |
151 | std::string formats; |
152 | std::vector<Value> vals; |
153 | |
154 | for (auto const &content : stmt->contents) { |
155 | if (std::holds_alternative<Stmt *>(content)) { |
156 | auto arg_stmt = std::get<Stmt *>(content); |
157 | TI_ASSERT(!arg_stmt->ret_type->is<TensorType>()); |
158 | |
159 | auto value = ir_->query_value(arg_stmt->raw_name()); |
160 | vals.push_back(value); |
161 | formats += data_type_format(arg_stmt->ret_type, Arch::vulkan); |
162 | } else { |
163 | auto arg_str = std::get<std::string>(content); |
164 | formats += arg_str; |
165 | } |
166 | } |
167 | ir_->call_debugprintf(formats, vals); |
168 | } |
169 | |
170 | void visit(ConstStmt *const_stmt) override { |
171 | auto get_const = [&](const TypedConstant &const_val) { |
172 | auto dt = const_val.dt.ptr_removed(); |
173 | spirv::SType stype = ir_->get_primitive_type(dt); |
174 | |
175 | if (dt->is_primitive(PrimitiveTypeID::f32)) { |
176 | return ir_->float_immediate_number( |
177 | stype, static_cast<double>(const_val.val_f32), false); |
178 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
179 | return ir_->int_immediate_number( |
180 | stype, static_cast<int64_t>(const_val.val_i32), false); |
181 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
182 | return ir_->int_immediate_number( |
183 | stype, static_cast<int64_t>(const_val.val_i64), false); |
184 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
185 | return ir_->float_immediate_number( |
186 | stype, static_cast<double>(const_val.val_f64), false); |
187 | } else if (dt->is_primitive(PrimitiveTypeID::i8)) { |
188 | return ir_->int_immediate_number( |
189 | stype, static_cast<int64_t>(const_val.val_i8), false); |
190 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
191 | return ir_->int_immediate_number( |
192 | stype, static_cast<int64_t>(const_val.val_i16), false); |
193 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
194 | return ir_->uint_immediate_number( |
195 | stype, static_cast<uint64_t>(const_val.val_u8), false); |
196 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
197 | return ir_->uint_immediate_number( |
198 | stype, static_cast<uint64_t>(const_val.val_u16), false); |
199 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
200 | return ir_->uint_immediate_number( |
201 | stype, static_cast<uint64_t>(const_val.val_u32), false); |
202 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
203 | return ir_->uint_immediate_number( |
204 | stype, static_cast<uint64_t>(const_val.val_u64), false); |
205 | } else { |
206 | TI_P(data_type_name(dt)); |
207 | TI_NOT_IMPLEMENTED |
208 | return spirv::Value(); |
209 | } |
210 | }; |
211 | |
212 | spirv::Value val = get_const(const_stmt->val); |
213 | ir_->register_value(const_stmt->raw_name(), val); |
214 | } |
215 | |
216 | void visit(AllocaStmt *alloca) override { |
217 | spirv::Value ptr_val; |
218 | if (alloca->ret_type->is<TensorType>()) { |
219 | auto tensor_type = alloca->ret_type->cast<TensorType>(); |
220 | auto elem_num = tensor_type->get_num_elements(); |
221 | spirv::SType elem_type = |
222 | ir_->get_primitive_type(tensor_type->get_element_type()); |
223 | spirv::SType arr_type = ir_->get_array_type(elem_type, elem_num); |
224 | if (alloca->is_shared) { // for shared memory / workgroup memory |
225 | ptr_val = ir_->alloca_workgroup_array(arr_type); |
226 | shared_array_binds_.push_back(ptr_val); |
227 | } else { // for function memory |
228 | ptr_val = ir_->alloca_variable(arr_type); |
229 | } |
230 | } else { |
231 | // Alloca for a single variable |
232 | spirv::SType src_type = ir_->get_primitive_type(alloca->element_type()); |
233 | ptr_val = ir_->alloca_variable(src_type); |
234 | ir_->store_variable(ptr_val, ir_->get_zero(src_type)); |
235 | } |
236 | ir_->register_value(alloca->raw_name(), ptr_val); |
237 | } |
238 | |
  // Computes an element pointer into a matrix/tensor. Depending on the
  // origin, `offset` is either an element index (local allocas use an
  // OpAccessChain; global temporaries scale by the element size to get a
  // byte offset) or already a byte offset.
  void visit(MatrixPtrStmt *stmt) override {
    spirv::Value ptr_val;
    spirv::Value origin_val = ir_->query_value(stmt->origin->raw_name());
    spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
    auto dt = stmt->element_type().ptr_removed();
    if (stmt->offset_used_as_index()) {
      if (stmt->origin->is<AllocaStmt>()) {
        // Local array: real SPIR-V pointer, index with OpAccessChain.
        spirv::SType ptr_type = ir_->get_pointer_type(
            ir_->get_primitive_type(dt), origin_val.stype.storage_class);
        ptr_val = ir_->make_value(spv::OpAccessChain, ptr_type, origin_val,
                                  offset_val);
      } else if (stmt->origin->is<GlobalTemporaryStmt>()) {
        // Global tmps are addressed in bytes: offset_bytes = sizeof(dt) * idx.
        spirv::Value dt_bytes = ir_->int_immediate_number(
            ir_->i32_type(), ir_->get_primitive_type_size(dt), false);
        spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
        ptr_val = ir_->add(origin_val, offset_bytes);
        // The element addresses the same buffer as its origin.
        ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
      } else {
        TI_NOT_IMPLEMENTED;
      }
    } else {  // offset used as bytes
      ptr_val = ir_->add(origin_val, ir_->cast(origin_val.stype, offset_val));
      ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
    }
    ir_->register_value(stmt->raw_name(), ptr_val);
  }
265 | |
266 | void visit(LocalLoadStmt *stmt) override { |
267 | auto ptr = stmt->src; |
268 | spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); |
269 | spirv::Value val = ir_->load_variable( |
270 | ptr_val, ir_->get_primitive_type(stmt->element_type())); |
271 | ir_->register_value(stmt->raw_name(), val); |
272 | } |
273 | |
274 | void visit(LocalStoreStmt *stmt) override { |
275 | spirv::Value ptr_val = ir_->query_value(stmt->dest->raw_name()); |
276 | spirv::Value val = ir_->query_value(stmt->val->raw_name()); |
277 | ir_->store_variable(ptr_val, val); |
278 | } |
279 | |
  // Entry point of a struct access chain: records the root statement for
  // this root id and starts addressing at byte offset 0 of its root buffer.
  void visit(GetRootStmt *stmt) override {
    const int root_id = snode_to_root_.at(stmt->root()->id);
    root_stmts_[root_id] = stmt;
    // get_buffer_value({BufferType::Root, root_id}, PrimitiveType::u32);
    spirv::Value root_val = make_pointer(0);
    ir_->register_value(stmt->raw_name(), root_val);
  }
287 | |
  // Steps from a parent cell pointer to one of its child components by
  // adding the child's byte offset within the parent cell.
  void visit(GetChStmt *stmt) override {
    // TODO: GetChStmt -> GetComponentStmt ?
    const int root = snode_to_root_.at(stmt->input_snode->id);

    const auto &snode_descs = compiled_structs_[root].snode_descriptors;
    auto *out_snode = stmt->output_snode;
    TI_ASSERT(snode_descs.at(stmt->input_snode->id).get_child(stmt->chid) ==
              out_snode);

    const auto &desc = snode_descs.at(out_snode->id);

    spirv::Value input_ptr_val = ir_->query_value(stmt->input_ptr->raw_name());
    spirv::Value offset = make_pointer(desc.mem_offset_in_parent_cell);
    spirv::Value val = ir_->add(input_ptr_val, offset);
    ir_->register_value(stmt->raw_name(), val);

    // Only leaf ("place") snodes are loaded/stored directly; remember which
    // buffer such a pointer addresses.
    if (out_snode->is_place()) {
      TI_ASSERT(ptr_to_buffers_.count(stmt) == 0);
      ptr_to_buffers_[stmt] = BufferInfo(BufferType::Root, root);
    }
  }
309 | |
  // Operations on a bitmasked snode's activity bitmask.
  enum class ActivationOp { activate, deactivate, query };

  // Atomically sets/clears/tests the activity bit of cell `input_index` in a
  // bitmasked snode. The bitmask words are laid out right after the cell
  // payload, i.e. at cell_stride * num_cells_per_container bytes from the
  // container start. Returns the raw atomic result for activate/deactivate,
  // or a bool value for query.
  spirv::Value bitmasked_activation(ActivationOp op,
                                    spirv::Value parent_ptr,
                                    int root_id,
                                    const SNode *sn,
                                    spirv::Value input_index) {
    spirv::SType ptr_dt = parent_ptr.stype;
    const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
    const auto &desc = snode_descs.at(sn->id);

    // word = index / 32, bit = index % 32, mask = 1 << bit.
    auto bitmask_word_index =
        ir_->make_value(spv::OpShiftRightLogical, ptr_dt, input_index,
                        ir_->uint_immediate_number(ptr_dt, 5));
    auto bitmask_bit_index =
        ir_->make_value(spv::OpBitwiseAnd, ptr_dt, input_index,
                        ir_->uint_immediate_number(ptr_dt, 31));
    auto bitmask_mask = ir_->make_value(spv::OpShiftLeftLogical, ptr_dt,
                                        ir_->const_i32_one_, bitmask_bit_index);

    auto buffer = get_buffer_value(BufferInfo(BufferType::Root, root_id),
                                   PrimitiveType::u32);
    // Build the word's byte address (word_index*4 + bitmask base + parent),
    // then shift right by 2 to convert it into a u32-element index for
    // struct_array_access.
    auto bitmask_word_ptr =
        ir_->make_value(spv::OpShiftLeftLogical, ptr_dt, bitmask_word_index,
                        ir_->uint_immediate_number(ir_->u32_type(), 2));
    bitmask_word_ptr = ir_->add(
        bitmask_word_ptr,
        make_pointer(desc.cell_stride * desc.snode->num_cells_per_container));
    bitmask_word_ptr = ir_->add(parent_ptr, bitmask_word_ptr);
    bitmask_word_ptr = ir_->make_value(
        spv::OpShiftRightLogical, ir_->u32_type(), bitmask_word_ptr,
        ir_->uint_immediate_number(ir_->u32_type(), 2));
    bitmask_word_ptr =
        ir_->struct_array_access(ir_->u32_type(), buffer, bitmask_word_ptr);

    if (op == ActivationOp::activate) {
      return ir_->make_value(spv::OpAtomicOr, ir_->u32_type(), bitmask_word_ptr,
                             /*scope=*/ir_->const_i32_one_,
                             /*semantics=*/ir_->const_i32_zero_, bitmask_mask);
    } else if (op == ActivationOp::deactivate) {
      // Clear the bit: AND with the complement of the mask.
      bitmask_mask = ir_->make_value(spv::OpNot, ir_->u32_type(), bitmask_mask);
      return ir_->make_value(spv::OpAtomicAnd, ir_->u32_type(),
                             bitmask_word_ptr,
                             /*scope=*/ir_->const_i32_one_,
                             /*semantics=*/ir_->const_i32_zero_, bitmask_mask);
    } else {
      // query: non-atomic load, extract the bit, compare against zero.
      auto bitmask_val = ir_->load_variable(bitmask_word_ptr, ir_->u32_type());
      auto bit = ir_->make_value(spv::OpShiftRightLogical, ir_->u32_type(),
                                 bitmask_val, bitmask_bit_index);
      bit = ir_->make_value(spv::OpBitwiseAnd, ir_->u32_type(), bit,
                            ir_->uint_immediate_number(ir_->u32_type(), 1));
      return ir_->make_value(spv::OpUGreaterThan, ir_->bool_type(), bit,
                             ir_->uint_immediate_number(ir_->u32_type(), 0));
    }
  }
365 | |
  // Lowers snode ops (is_active / activate / deactivate). Only bitmasked
  // snodes are supported.
  void visit(SNodeOpStmt *stmt) override {
    const int root_id = snode_to_root_.at(stmt->snode->id);
    std::string parent = stmt->ptr->raw_name();
    spirv::Value parent_val = ir_->query_value(parent);

    if (stmt->snode->type == SNodeType::bitmasked) {
      spirv::Value input_index_val =
          ir_->cast(parent_val.stype, ir_->query_value(stmt->val->raw_name()));

      if (stmt->op_type == SNodeOpType::is_active) {
        auto is_active =
            bitmasked_activation(ActivationOp::query, parent_val, root_id,
                                 stmt->snode, input_index_val);
        is_active =
            ir_->cast(ir_->get_primitive_type(stmt->ret_type), is_active);
        // Negate the 0/1 cast result — presumably to produce 0/-1 truth
        // values; NOTE(review): confirm the expected boolean convention.
        is_active = ir_->make_value(spv::OpSNegate, is_active.stype, is_active);
        ir_->register_value(stmt->raw_name(), is_active);
      } else if (stmt->op_type == SNodeOpType::deactivate) {
        bitmasked_activation(ActivationOp::deactivate, parent_val, root_id,
                             stmt->snode, input_index_val);
      } else if (stmt->op_type == SNodeOpType::activate) {
        bitmasked_activation(ActivationOp::activate, parent_val, root_id,
                             stmt->snode, input_index_val);
      } else {
        TI_NOT_IMPLEMENTED;
      }
    } else {
      TI_NOT_IMPLEMENTED;
    }
  }
396 | |
  // Advances a container pointer to the cell selected by input_index
  // (pointer += index * cell_stride), optionally activating the cell first
  // for bitmasked snodes.
  void visit(SNodeLookupStmt *stmt) override {
    // TODO: SNodeLookupStmt -> GetSNodeCellStmt ?
    bool is_root{false};  // Eliminate first root snode access
    const int root_id = snode_to_root_.at(stmt->snode->id);
    std::string parent;

    if (stmt->input_snode) {
      parent = stmt->input_snode->raw_name();
    } else {
      // No explicit parent: fall back to the root statement recorded by
      // GetRootStmt for this root.
      TI_ASSERT(root_stmts_.at(root_id) != nullptr);
      parent = root_stmts_.at(root_id)->raw_name();
    }
    const auto *sn = stmt->snode;

    spirv::Value parent_val = ir_->query_value(parent);

    if (stmt->activate) {
      if (sn->type == SNodeType::dense) {
        // Do nothing
      } else if (sn->type == SNodeType::bitmasked) {
        spirv::Value input_index_val =
            ir_->query_value(stmt->input_index->raw_name());
        bitmasked_activation(ActivationOp::activate, parent_val, root_id, sn,
                             input_index_val);
      } else {
        TI_NOT_IMPLEMENTED;
      }
    }

    spirv::Value val;
    if (is_root) {
      val = parent_val;  // Assert Root[0] access at first time
    } else {
      const auto &snode_descs = compiled_structs_[root_id].snode_descriptors;
      const auto &desc = snode_descs.at(sn->id);

      spirv::Value input_index_val = ir_->cast(
          parent_val.stype, ir_->query_value(stmt->input_index->raw_name()));
      spirv::Value stride = make_pointer(desc.cell_stride);
      spirv::Value offset = ir_->mul(input_index_val, stride);
      val = ir_->add(parent_val, offset);
    }
    ir_->register_value(stmt->raw_name(), val);
  }
441 | |
442 | void visit(RandStmt *stmt) override { |
443 | spirv::Value val; |
444 | spirv::Value global_tmp = |
445 | get_buffer_value(BufferType::GlobalTmps, PrimitiveType::u32); |
446 | if (stmt->element_type()->is_primitive(PrimitiveTypeID::i32)) { |
447 | val = ir_->rand_i32(global_tmp); |
448 | } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::u32)) { |
449 | val = ir_->rand_u32(global_tmp); |
450 | } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f32)) { |
451 | val = ir_->rand_f32(global_tmp); |
452 | } else if (stmt->element_type()->is_primitive(PrimitiveTypeID::f16)) { |
453 | auto highp_val = ir_->rand_f32(global_tmp); |
454 | val = ir_->cast(ir_->f16_type(), highp_val); |
455 | } else { |
456 | TI_ERROR("rand only support 32-bit type" ); |
457 | } |
458 | ir_->register_value(stmt->raw_name(), val); |
459 | } |
460 | |
461 | void visit(LinearizeStmt *stmt) override { |
462 | spirv::Value val = ir_->const_i32_zero_; |
463 | for (size_t i = 0; i < stmt->inputs.size(); ++i) { |
464 | spirv::Value strides_val = |
465 | ir_->int_immediate_number(ir_->i32_type(), stmt->strides[i]); |
466 | spirv::Value input_val = ir_->query_value(stmt->inputs[i]->raw_name()); |
467 | val = ir_->add(ir_->mul(val, strides_val), input_val); |
468 | } |
469 | ir_->register_value(stmt->raw_name(), val); |
470 | } |
471 | |
  // Resolves a loop index to the SSA value of the enclosing loop variable.
  void visit(LoopIndexStmt *stmt) override {
    const auto stmt_name = stmt->raw_name();
    if (stmt->loop->is<OffloadedStmt>()) {
      const auto type = stmt->loop->as<OffloadedStmt>()->task_type;
      if (type == OffloadedTaskType::range_for) {
        TI_ASSERT(stmt->index == 0);
        // "ii" is the name under which the offloaded range-for induction
        // variable is registered elsewhere in this codegen.
        spirv::Value loop_var = ir_->query_value("ii" );
        // spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_);
        ir_->register_value(stmt_name, loop_var);
      } else {
        TI_NOT_IMPLEMENTED;
      }
    } else if (stmt->loop->is<RangeForStmt>()) {
      TI_ASSERT(stmt->index == 0);
      spirv::Value loop_var = ir_->query_value(stmt->loop->raw_name());
      // Add-with-zero copies the loop variable into a fresh SSA id.
      spirv::Value val = ir_->add(loop_var, ir_->const_i32_zero_);
      ir_->register_value(stmt_name, val);
    } else {
      TI_NOT_IMPLEMENTED;
    }
  }
493 | |
494 | void visit(GlobalStoreStmt *stmt) override { |
495 | spirv::Value val = ir_->query_value(stmt->val->raw_name()); |
496 | |
497 | store_buffer(stmt->dest, val); |
498 | } |
499 | |
500 | void visit(GlobalLoadStmt *stmt) override { |
501 | auto dt = stmt->element_type(); |
502 | |
503 | auto val = load_buffer(stmt->src, dt); |
504 | |
505 | ir_->register_value(stmt->raw_name(), val); |
506 | } |
507 | |
  // Loads a scalar kernel argument from member `arg_id` of the args buffer.
  // Pointer (array) arguments emit nothing here — their addressing is done
  // at byte granularity by ExternalPtrStmt.
  void visit(ArgLoadStmt *stmt) override {
    const auto arg_id = stmt->arg_id;
    const auto &arg_attribs = ctx_attribs_->args()[arg_id];
    if (stmt->is_ptr) {
      // Do not shift! We are indexing the buffers at byte granularity.
      // spirv::Value val =
      //     ir_->int_immediate_number(ir_->i32_type(), offset_in_mem);
      // ir_->register_value(stmt->raw_name(), val);
    } else {
      const auto dt = PrimitiveType::get(arg_attribs.dtype);
      const auto val_type = ir_->get_primitive_type(dt);
      spirv::Value buffer_val = ir_->make_value(
          spv::OpAccessChain,
          ir_->get_pointer_type(val_type, spv::StorageClassUniform),
          get_buffer_value(BufferType::Args, PrimitiveType::i32),
          ir_->int_immediate_number(ir_->i32_type(), arg_id));
      buffer_val.flag = ValueKind::kVariablePtr;
      spirv::Value val = ir_->load_variable(buffer_val, val_type);
      ir_->register_value(stmt->raw_name(), val);
    }
  }
529 | |
  // Stores each return value into the ret buffer at (member 0, element i).
  void visit(ReturnStmt *stmt) override {
    // Now we only support one ret
    // NOTE(review): every element is stored using the dtype of element 0 —
    // this assumes homogeneous return element types; confirm before relying
    // on multi-value returns of mixed types.
    auto dt = stmt->element_types()[0];
    for (int i = 0; i < stmt->values.size(); i++) {
      spirv::Value buffer_val = ir_->make_value(
          spv::OpAccessChain,
          ir_->get_storage_pointer_type(ir_->get_primitive_type(dt)),
          get_buffer_value(BufferType::Rets, dt),
          ir_->int_immediate_number(ir_->i32_type(), 0),
          ir_->int_immediate_number(ir_->i32_type(), i));
      buffer_val.flag = ValueKind::kVariablePtr;
      spirv::Value val = ir_->query_value(stmt->values[i]->raw_name());
      ir_->store_variable(buffer_val, val);
    }
  }
545 | |
546 | void visit(GlobalTemporaryStmt *stmt) override { |
547 | spirv::Value val = ir_->int_immediate_number(ir_->i32_type(), stmt->offset, |
548 | false); // Named Constant |
549 | ir_->register_value(stmt->raw_name(), val); |
550 | ptr_to_buffers_[stmt] = BufferType::GlobalTmps; |
551 | } |
552 | |
553 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
554 | const auto name = stmt->raw_name(); |
555 | const auto arg_id = stmt->arg_id; |
556 | const auto axis = stmt->axis; |
557 | |
558 | const auto = ctx_attribs_->args().size(); |
559 | |
560 | const auto = (arg_id * taichi_max_num_indices) + axis; |
561 | spirv::Value var_ptr = ir_->make_value( |
562 | spv::OpAccessChain, |
563 | ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform), |
564 | get_buffer_value(BufferType::Args, PrimitiveType::i32), |
565 | ir_->int_immediate_number(ir_->i32_type(), |
566 | extra_args_member_index + extra_arg_index)); |
567 | spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type()); |
568 | |
569 | ir_->register_value(name, var); |
570 | } |
571 | |
572 | void visit(ExternalPtrStmt *stmt) override { |
573 | // Used mostly for transferring data between host (e.g. numpy array) and |
574 | // device. |
575 | spirv::Value linear_offset = ir_->int_immediate_number(ir_->i32_type(), 0); |
576 | const auto *argload = stmt->base_ptr->as<ArgLoadStmt>(); |
577 | const int arg_id = argload->arg_id; |
578 | { |
579 | const int num_indices = stmt->indices.size(); |
580 | std::vector<std::string> size_var_names; |
581 | const auto &element_shape = stmt->element_shape; |
582 | const auto layout = stmt->element_dim <= 0 ? ExternalArrayLayout::kAOS |
583 | : ExternalArrayLayout::kSOA; |
584 | const auto = ctx_attribs_->args().size(); |
585 | const size_t element_shape_index_offset = |
586 | (layout == ExternalArrayLayout::kAOS) |
587 | ? num_indices - element_shape.size() |
588 | : 0; |
589 | for (int i = 0; i < num_indices - element_shape.size(); i++) { |
590 | std::string var_name = fmt::format("{}_size{}_" , stmt->raw_name(), i); |
591 | const auto = (arg_id * taichi_max_num_indices) + i; |
592 | spirv::Value var_ptr = ir_->make_value( |
593 | spv::OpAccessChain, |
594 | ir_->get_pointer_type(ir_->i32_type(), spv::StorageClassUniform), |
595 | get_buffer_value(BufferType::Args, PrimitiveType::i32), |
596 | ir_->int_immediate_number( |
597 | ir_->i32_type(), extra_args_member_index + extra_arg_index)); |
598 | spirv::Value var = ir_->load_variable(var_ptr, ir_->i32_type()); |
599 | ir_->register_value(var_name, var); |
600 | size_var_names.push_back(std::move(var_name)); |
601 | } |
602 | int size_var_names_idx = 0; |
603 | for (int i = 0; i < num_indices; i++) { |
604 | spirv::Value size_var; |
605 | // Use immediate numbers to flatten index for element shapes. |
606 | if (i >= element_shape_index_offset && |
607 | i < element_shape_index_offset + element_shape.size()) { |
608 | size_var = ir_->uint_immediate_number( |
609 | ir_->i32_type(), element_shape[i - element_shape_index_offset]); |
610 | } else { |
611 | size_var = ir_->query_value(size_var_names[size_var_names_idx++]); |
612 | } |
613 | spirv::Value indices = ir_->query_value(stmt->indices[i]->raw_name()); |
614 | linear_offset = ir_->mul(linear_offset, size_var); |
615 | linear_offset = ir_->add(linear_offset, indices); |
616 | } |
617 | linear_offset = ir_->make_value( |
618 | spv::OpShiftLeftLogical, ir_->i32_type(), linear_offset, |
619 | ir_->int_immediate_number(ir_->i32_type(), |
620 | log2int(ir_->get_primitive_type_size( |
621 | argload->ret_type.ptr_removed())))); |
622 | if (caps_->get(DeviceCapability::spirv_has_no_integer_wrap_decoration)) { |
623 | ir_->decorate(spv::OpDecorate, linear_offset, |
624 | spv::DecorationNoSignedWrap); |
625 | } |
626 | } |
627 | if (caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) { |
628 | spirv::Value addr_ptr = ir_->make_value( |
629 | spv::OpAccessChain, |
630 | ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform), |
631 | get_buffer_value(BufferType::Args, PrimitiveType::i32), |
632 | ir_->int_immediate_number(ir_->i32_type(), arg_id)); |
633 | spirv::Value addr = ir_->load_variable(addr_ptr, ir_->u64_type()); |
634 | addr = ir_->add(addr, ir_->make_value(spv::OpSConvert, ir_->u64_type(), |
635 | linear_offset)); |
636 | ir_->register_value(stmt->raw_name(), addr); |
637 | } else { |
638 | ir_->register_value(stmt->raw_name(), linear_offset); |
639 | } |
640 | |
641 | if (ctx_attribs_->args()[arg_id].is_array) { |
642 | ptr_to_buffers_[stmt] = {BufferType::ExtArr, arg_id}; |
643 | } else { |
644 | ptr_to_buffers_[stmt] = BufferType::Args; |
645 | } |
646 | } |
647 | |
  // Decorations carry no SPIR-V lowering in this backend; intentionally a
  // no-op (the statement still needs a visitor so traversal succeeds).
  void visit(DecorationStmt *stmt) override {
  }
650 | |
  // Lowers a unary op. Ops without a dedicated SPIR-V opcode are
  // special-cased below; the remaining math functions are mapped to
  // GLSL.std.450 extended instructions via UNARY_OP_TO_SPIRV (the numeric
  // constants are GLSL.std.450 instruction numbers).
  void visit(UnaryOpStmt *stmt) override {
    const auto operand_name = stmt->operand->raw_name();

    const auto src_dt = stmt->operand->element_type();
    const auto dst_dt = stmt->element_type();
    spirv::SType src_type = ir_->get_primitive_type(src_dt);
    spirv::SType dst_type = ir_->get_primitive_type(dst_dt);
    spirv::Value operand_val = ir_->query_value(operand_name);
    spirv::Value val = spirv::Value();

    if (stmt->op_type == UnaryOpType::logic_not) {
      // logic_not(x) := cast(x == 0) using a zero of the operand's type.
      spirv::Value zero =
          ir_->get_zero(src_type);  // Math zero type to left hand side
      if (is_integral(src_dt)) {
        if (is_signed(src_dt)) {
          zero = ir_->int_immediate_number(src_type, 0);
        } else {
          zero = ir_->uint_immediate_number(src_type, 0);
        }
      } else if (is_real(src_dt)) {
        zero = ir_->float_immediate_number(src_type, 0);
      } else {
        TI_NOT_IMPLEMENTED
      }
      val = ir_->cast(dst_type, ir_->eq(operand_val, zero));
    } else if (stmt->op_type == UnaryOpType::neg) {
      // Cast first so negation happens in the destination type.
      operand_val = ir_->cast(dst_type, operand_val);
      if (is_integral(dst_dt)) {
        if (is_signed(dst_dt)) {
          val = ir_->make_value(spv::OpSNegate, dst_type, operand_val);
        } else {
          TI_NOT_IMPLEMENTED
        }
      } else if (is_real(dst_dt)) {
        val = ir_->make_value(spv::OpFNegate, dst_type, operand_val);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (stmt->op_type == UnaryOpType::rsqrt) {
      const uint32_t InverseSqrt_id = 32;  // GLSL.std.450 InverseSqrt
      if (is_real(src_dt)) {
        val = ir_->call_glsl450(src_type, InverseSqrt_id, operand_val);
        val = ir_->cast(dst_type, val);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (stmt->op_type == UnaryOpType::sgn) {
      const uint32_t FSign_id = 6;  // GLSL.std.450 FSign
      const uint32_t SSign_id = 7;  // GLSL.std.450 SSign
      if (is_integral(src_dt)) {
        if (is_signed(src_dt)) {
          val = ir_->call_glsl450(src_type, SSign_id, operand_val);
        } else {
          TI_NOT_IMPLEMENTED
        }
      } else if (is_real(src_dt)) {
        val = ir_->call_glsl450(src_type, FSign_id, operand_val);
      } else {
        TI_NOT_IMPLEMENTED
      }
      val = ir_->cast(dst_type, val);
    } else if (stmt->op_type == UnaryOpType::bit_not) {
      operand_val = ir_->cast(dst_type, operand_val);
      if (is_integral(dst_dt)) {
        val = ir_->make_value(spv::OpNot, dst_type, operand_val);
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (stmt->op_type == UnaryOpType::cast_value) {
      val = ir_->cast(dst_type, operand_val);
    } else if (stmt->op_type == UnaryOpType::cast_bits) {
      // Reinterpret the bit pattern; only legal between same-width types.
      if (data_type_bits(src_dt) == data_type_bits(dst_dt)) {
        val = ir_->make_value(spv::OpBitcast, dst_type, operand_val);
      } else {
        TI_ERROR("bit_cast is only supported between data type with same size" );
      }
    } else if (stmt->op_type == UnaryOpType::abs) {
      const uint32_t FAbs_id = 4;  // GLSL.std.450 FAbs
      const uint32_t SAbs_id = 5;  // GLSL.std.450 SAbs
      if (src_type.id == dst_type.id) {
        if (is_integral(src_dt)) {
          if (is_signed(src_dt)) {
            val = ir_->call_glsl450(src_type, SAbs_id, operand_val);
          } else {
            TI_NOT_IMPLEMENTED
          }
        } else if (is_real(src_dt)) {
          val = ir_->call_glsl450(src_type, FAbs_id, operand_val);
        } else {
          TI_NOT_IMPLEMENTED
        }
      } else {
        TI_NOT_IMPLEMENTED
      }
    } else if (stmt->op_type == UnaryOpType::inv) {
      // inv(x) := 1 / x, real types only.
      if (is_real(dst_dt)) {
        val = ir_->div(ir_->float_immediate_number(dst_type, 1), operand_val);
      } else {
        TI_NOT_IMPLEMENTED
      }
    }
// Continues the else-if chain: maps UnaryOpType::op to the GLSL.std.450
// instruction `instruction_id`, rejecting real operands wider than
// `max_bits`.
#define UNARY_OP_TO_SPIRV(op, instruction, instruction_id, max_bits)           \
  else if (stmt->op_type == UnaryOpType::op) {                                 \
    const uint32_t instruction = instruction_id;                               \
    if (is_real(src_dt)) {                                                     \
      if (data_type_bits(src_dt) > max_bits) {                                 \
        TI_ERROR("Instruction {}({}) does not {}bits operation", #instruction, \
                 instruction_id, data_type_bits(src_dt));                      \
      }                                                                        \
      val = ir_->call_glsl450(src_type, instruction, operand_val);             \
    } else {                                                                   \
      TI_NOT_IMPLEMENTED                                                       \
    }                                                                          \
  }
    UNARY_OP_TO_SPIRV(round, Round, 1, 64)
    UNARY_OP_TO_SPIRV(floor, Floor, 8, 64)
    UNARY_OP_TO_SPIRV(ceil, Ceil, 9, 64)
    UNARY_OP_TO_SPIRV(sin, Sin, 13, 32)
    UNARY_OP_TO_SPIRV(asin, Asin, 16, 32)
    UNARY_OP_TO_SPIRV(cos, Cos, 14, 32)
    UNARY_OP_TO_SPIRV(acos, Acos, 17, 32)
    UNARY_OP_TO_SPIRV(tan, Tan, 15, 32)
    UNARY_OP_TO_SPIRV(tanh, Tanh, 21, 32)
    UNARY_OP_TO_SPIRV(exp, Exp, 27, 32)
    UNARY_OP_TO_SPIRV(log, Log, 28, 32)
    UNARY_OP_TO_SPIRV(sqrt, Sqrt, 31, 64)
#undef UNARY_OP_TO_SPIRV
    else {TI_NOT_IMPLEMENTED} ir_->register_value(stmt->raw_name(), val);
  }
780 | |
781 | void generate_overflow_branch(const spirv::Value &cond_v, |
782 | const std::string &op, |
783 | const std::string &tb) { |
784 | spirv::Value cond = |
785 | ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_)); |
786 | spirv::Label then_label = ir_->new_label(); |
787 | spirv::Label merge_label = ir_->new_label(); |
788 | ir_->make_inst(spv::OpSelectionMerge, merge_label, |
789 | spv::SelectionControlMaskNone); |
790 | ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label); |
791 | // then block |
792 | ir_->start_label(then_label); |
793 | ir_->call_debugprintf(op + " overflow detected in " + tb, {}); |
794 | ir_->make_inst(spv::OpBranch, merge_label); |
795 | // merge label |
796 | ir_->start_label(merge_label); |
797 | } |
798 | |
799 | spirv::Value generate_uadd_overflow(const spirv::Value &a, |
800 | const spirv::Value &b, |
801 | const std::string &tb) { |
802 | std::vector<std::tuple<spirv::SType, std::string, size_t>> |
803 | struct_components_; |
804 | struct_components_.emplace_back(a.stype, "result" , 0); |
805 | struct_components_.emplace_back(a.stype, "carry" , |
806 | ir_->get_primitive_type_size(a.stype.dt)); |
807 | auto struct_type = ir_->create_struct_type(struct_components_); |
808 | auto add_carry = ir_->make_value(spv::OpIAddCarry, struct_type, a, b); |
809 | auto result = |
810 | ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 0); |
811 | auto carry = |
812 | ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 1); |
813 | generate_overflow_branch(carry, "Addition" , tb); |
814 | return result; |
815 | } |
816 | |
817 | spirv::Value generate_usub_overflow(const spirv::Value &a, |
818 | const spirv::Value &b, |
819 | const std::string &tb) { |
820 | std::vector<std::tuple<spirv::SType, std::string, size_t>> |
821 | struct_components_; |
822 | struct_components_.emplace_back(a.stype, "result" , 0); |
823 | struct_components_.emplace_back(a.stype, "borrow" , |
824 | ir_->get_primitive_type_size(a.stype.dt)); |
825 | auto struct_type = ir_->create_struct_type(struct_components_); |
826 | auto add_carry = ir_->make_value(spv::OpISubBorrow, struct_type, a, b); |
827 | auto result = |
828 | ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 0); |
829 | auto borrow = |
830 | ir_->make_value(spv::OpCompositeExtract, a.stype, add_carry, 1); |
831 | generate_overflow_branch(borrow, "Subtraction" , tb); |
832 | return result; |
833 | } |
834 | |
835 | spirv::Value generate_sadd_overflow(const spirv::Value &a, |
836 | const spirv::Value &b, |
837 | const std::string &tb) { |
838 | // overflow iff (sign(a) == sign(b)) && (sign(a) != sign(result)) |
839 | auto result = ir_->make_value(spv::OpIAdd, a.stype, a, b); |
840 | auto zero = ir_->int_immediate_number(a.stype, 0); |
841 | auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero); |
842 | auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero); |
843 | auto r_sign = |
844 | ir_->make_value(spv::OpSLessThan, ir_->bool_type(), result, zero); |
845 | auto a_eq_b = |
846 | ir_->make_value(spv::OpLogicalEqual, ir_->bool_type(), a_sign, b_sign); |
847 | auto a_neq_r = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(), |
848 | a_sign, r_sign); |
849 | auto overflow = |
850 | ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(), a_eq_b, a_neq_r); |
851 | generate_overflow_branch(overflow, "Addition" , tb); |
852 | return result; |
853 | } |
854 | |
855 | spirv::Value generate_ssub_overflow(const spirv::Value &a, |
856 | const spirv::Value &b, |
857 | const std::string &tb) { |
858 | // overflow iff (sign(a) != sign(b)) && (sign(a) != sign(result)) |
859 | auto result = ir_->make_value(spv::OpISub, a.stype, a, b); |
860 | auto zero = ir_->int_immediate_number(a.stype, 0); |
861 | auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero); |
862 | auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero); |
863 | auto r_sign = |
864 | ir_->make_value(spv::OpSLessThan, ir_->bool_type(), result, zero); |
865 | auto a_neq_b = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(), |
866 | a_sign, b_sign); |
867 | auto a_neq_r = ir_->make_value(spv::OpLogicalNotEqual, ir_->bool_type(), |
868 | a_sign, r_sign); |
869 | auto overflow = |
870 | ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(), a_neq_b, a_neq_r); |
871 | generate_overflow_branch(overflow, "Subtraction" , tb); |
872 | return result; |
873 | } |
874 | |
875 | spirv::Value generate_umul_overflow(const spirv::Value &a, |
876 | const spirv::Value &b, |
877 | const std::string &tb) { |
878 | // overflow iff high bits != 0 |
879 | std::vector<std::tuple<spirv::SType, std::string, size_t>> |
880 | struct_components_; |
881 | struct_components_.emplace_back(a.stype, "low" , 0); |
882 | struct_components_.emplace_back(a.stype, "high" , |
883 | ir_->get_primitive_type_size(a.stype.dt)); |
884 | auto struct_type = ir_->create_struct_type(struct_components_); |
885 | auto mul_ext = ir_->make_value(spv::OpUMulExtended, struct_type, a, b); |
886 | auto low = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 0); |
887 | auto high = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 1); |
888 | generate_overflow_branch(high, "Multiplication" , tb); |
889 | return low; |
890 | } |
891 | |
  // Signed multiplication with a runtime overflow check via OpSMulExtended.
  // Returns the low (wrapped) half of the double-width product.
  spirv::Value generate_smul_overflow(const spirv::Value &a,
                                      const spirv::Value &b,
                                      const std::string &tb) {
    // overflow if high bits are not all sign bit (0 if positive, -1 if
    // negative) or the sign bit of the low bits is not the expected sign bit.
    // Struct receiving OpSMulExtended's two outputs: {low, high}.
    std::vector<std::tuple<spirv::SType, std::string, size_t>>
        struct_components_;
    struct_components_.emplace_back(a.stype, "low" , 0);
    struct_components_.emplace_back(a.stype, "high" ,
                                    ir_->get_primitive_type_size(a.stype.dt));
    auto struct_type = ir_->create_struct_type(struct_components_);
    auto mul_ext = ir_->make_value(spv::OpSMulExtended, struct_type, a, b);
    auto low = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 0);
    auto high = ir_->make_value(spv::OpCompositeExtract, a.stype, mul_ext, 1);
    auto zero = ir_->int_immediate_number(a.stype, 0);
    auto minus_one = ir_->int_immediate_number(a.stype, -1);
    // Sign bits of the two operands.
    auto a_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), a, zero);
    auto b_sign = ir_->make_value(spv::OpSLessThan, ir_->bool_type(), b, zero);
    // Special-case zero: if either operand is zero the product is zero, whose
    // sign bit is 0 regardless of the operand signs.
    auto a_not_zero = ir_->ne(a, zero);
    auto b_not_zero = ir_->ne(b, zero);
    auto a_b_not_zero = ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(),
                                        a_not_zero, b_not_zero);
    auto low_sign =
        ir_->make_value(spv::OpSLessThan, ir_->bool_type(), low, zero);
    // Expected result sign: negative iff the operand signs differ AND neither
    // operand is zero.
    auto expected_sign = ir_->make_value(spv::OpLogicalNotEqual,
                                         ir_->bool_type(), a_sign, b_sign);
    expected_sign = ir_->make_value(spv::OpLogicalAnd, ir_->bool_type(),
                                    expected_sign, a_b_not_zero);
    auto not_expected_sign = ir_->ne(low_sign, expected_sign);
    // The high half must be the sign-extension of the expected sign:
    // all-ones (-1) for a negative product, all-zeros otherwise.
    auto expected_high = ir_->select(expected_sign, minus_one, zero);
    auto not_expected_high = ir_->ne(high, expected_high);
    auto overflow = ir_->make_value(spv::OpLogicalOr, ir_->bool_type(),
                                    not_expected_high, not_expected_sign);
    generate_overflow_branch(overflow, "Multiplication" , tb);
    return low;
  }
928 | |
929 | spirv::Value generate_ushl_overflow(const spirv::Value &a, |
930 | const spirv::Value &b, |
931 | const std::string &tb) { |
932 | // overflow iff a << b >> b != a |
933 | auto result = ir_->make_value(spv::OpShiftLeftLogical, a.stype, a, b); |
934 | auto restore = |
935 | ir_->make_value(spv::OpShiftRightLogical, a.stype, result, b); |
936 | auto overflow = ir_->ne(a, restore); |
937 | generate_overflow_branch(overflow, "Shift left" , tb); |
938 | return result; |
939 | } |
940 | |
941 | spirv::Value generate_sshl_overflow(const spirv::Value &a, |
942 | const spirv::Value &b, |
943 | const std::string &tb) { |
944 | // overflow iff a << b >> b != a |
945 | auto result = ir_->make_value(spv::OpShiftLeftLogical, a.stype, a, b); |
946 | auto restore = |
947 | ir_->make_value(spv::OpShiftRightArithmetic, a.stype, result, b); |
948 | auto overflow = ir_->ne(a, restore); |
949 | generate_overflow_branch(overflow, "Shift left" , tb); |
950 | return result; |
951 | } |
952 | |
  // Lowers a BinaryOpStmt to SPIR-V. The long else-if chain below (partly
  // macro-generated) dispatches on the op type. When the device exposes
  // non-semantic debug info, integer add/sub/mul/shl additionally route
  // through the generate_*_overflow helpers, which print a diagnostic at
  // runtime on overflow.
  void visit(BinaryOpStmt *bin) override {
    const auto lhs_name = bin->lhs->raw_name();
    const auto rhs_name = bin->rhs->raw_name();
    const auto bin_name = bin->raw_name();
    const auto op_type = bin->op_type;

    spirv::SType dst_type = ir_->get_primitive_type(bin->element_type());
    spirv::Value lhs_value = ir_->query_value(lhs_name);
    spirv::Value rhs_value = ir_->query_value(rhs_name);
    spirv::Value bin_value = spirv::Value();

    // Operand types are expected to match; a mismatch is only warned about
    // (with the statement's traceback), not rejected.
    TI_WARN_IF(lhs_value.stype.id != rhs_value.stype.id,
               "${} type {} != ${} type {}\n{}" , lhs_name,
               lhs_value.stype.dt->to_string(), rhs_name,
               rhs_value.stype.dt->to_string(), bin->tb);

    // Overflow diagnostics rely on debugPrintf, which needs the
    // non-semantic-info capability.
    bool debug = caps_->get(DeviceCapability::spirv_has_non_semantic_info);

    // Debug-mode integer add/sub/mul with overflow detection. These branches
    // shadow the plain macro-generated arithmetic branches below.
    if (debug && op_type == BinaryOpType::add && is_integral(dst_type.dt)) {
      if (is_unsigned(dst_type.dt)) {
        bin_value = generate_uadd_overflow(lhs_value, rhs_value, bin->tb);
      } else {
        bin_value = generate_sadd_overflow(lhs_value, rhs_value, bin->tb);
      }
      bin_value = ir_->cast(dst_type, bin_value);
    } else if (debug && op_type == BinaryOpType::sub &&
               is_integral(dst_type.dt)) {
      if (is_unsigned(dst_type.dt)) {
        bin_value = generate_usub_overflow(lhs_value, rhs_value, bin->tb);
      } else {
        bin_value = generate_ssub_overflow(lhs_value, rhs_value, bin->tb);
      }
      bin_value = ir_->cast(dst_type, bin_value);
    } else if (debug && op_type == BinaryOpType::mul &&
               is_integral(dst_type.dt)) {
      if (is_unsigned(dst_type.dt)) {
        bin_value = generate_umul_overflow(lhs_value, rhs_value, bin->tb);
      } else {
        bin_value = generate_smul_overflow(lhs_value, rhs_value, bin->tb);
      }
      bin_value = ir_->cast(dst_type, bin_value);
    }
// Each use expands to an `else if` continuing the chain above: plain
// arithmetic via the IR builder, result cast to the destination type.
#define BINARY_OP_TO_SPIRV_ARTHIMATIC(op, func)  \
  else if (op_type == BinaryOpType::op) {        \
    bin_value = ir_->func(lhs_value, rhs_value); \
    bin_value = ir_->cast(dst_type, bin_value);  \
  }

    BINARY_OP_TO_SPIRV_ARTHIMATIC(add, add)
    BINARY_OP_TO_SPIRV_ARTHIMATIC(sub, sub)
    BINARY_OP_TO_SPIRV_ARTHIMATIC(mul, mul)
    BINARY_OP_TO_SPIRV_ARTHIMATIC(div, div)
    BINARY_OP_TO_SPIRV_ARTHIMATIC(mod, mod)
#undef BINARY_OP_TO_SPIRV_ARTHIMATIC

// Bitwise ops map 1:1 onto a single SPIR-V instruction; no cast needed.
#define BINARY_OP_TO_SPIRV_BITWISE(op, sym)                                  \
  else if (op_type == BinaryOpType::op) {                                    \
    bin_value = ir_->make_value(spv::sym, dst_type, lhs_value, rhs_value);   \
  }

    // Debug-mode left shift with lost-bit detection; takes precedence over
    // the plain bit_shl expansion below.
    else if (debug && op_type == BinaryOpType::bit_shl) {
      if (is_unsigned(dst_type.dt)) {
        bin_value = generate_ushl_overflow(lhs_value, rhs_value, bin->tb);
      } else {
        bin_value = generate_sshl_overflow(lhs_value, rhs_value, bin->tb);
      }
    }
    BINARY_OP_TO_SPIRV_BITWISE(bit_and, OpBitwiseAnd)
    BINARY_OP_TO_SPIRV_BITWISE(bit_or, OpBitwiseOr)
    BINARY_OP_TO_SPIRV_BITWISE(bit_xor, OpBitwiseXor)
    BINARY_OP_TO_SPIRV_BITWISE(bit_shl, OpShiftLeftLogical)
    // NOTE: `OpShiftRightArithmetic` will treat the first bit as sign bit even
    // it's the unsigned type
    else if (op_type == BinaryOpType::bit_sar) {
      bin_value = ir_->make_value(is_unsigned(dst_type.dt)
                                      ? spv::OpShiftRightLogical
                                      : spv::OpShiftRightArithmetic,
                                  dst_type, lhs_value, rhs_value);
    }
#undef BINARY_OP_TO_SPIRV_BITWISE

// Comparisons: the boolean result is cast to the destination type and then
// negated with OpSNegate, mapping true to -1 and false to 0.
#define BINARY_OP_TO_SPIRV_LOGICAL(op, func)                          \
  else if (op_type == BinaryOpType::op) {                             \
    bin_value = ir_->func(lhs_value, rhs_value);                      \
    bin_value = ir_->cast(dst_type, bin_value);                       \
    bin_value = ir_->make_value(spv::OpSNegate, dst_type, bin_value); \
  }

    BINARY_OP_TO_SPIRV_LOGICAL(cmp_lt, lt)
    BINARY_OP_TO_SPIRV_LOGICAL(cmp_le, le)
    BINARY_OP_TO_SPIRV_LOGICAL(cmp_gt, gt)
    BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge)
    BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq)
    BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
#undef BINARY_OP_TO_SPIRV_LOGICAL

// Float-only GLSL.std.450 extended instructions; `max_bits` guards against
// operand widths the extended instruction does not support.
#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
                                            max_bits)                        \
  else if (op_type == BinaryOpType::op) {                                    \
    const uint32_t instruction = instruction_id;                             \
    if (is_real(bin->element_type())) {                                      \
      if (data_type_bits(bin->element_type()) > max_bits) {                  \
        TI_ERROR(                                                            \
            "[glsl450] the operand type of instruction {}({}) must <= {}bits", \
            #instruction, instruction_id, max_bits);                         \
      }                                                                      \
      bin_value =                                                            \
          ir_->call_glsl450(dst_type, instruction, lhs_value, rhs_value);    \
    } else {                                                                 \
      TI_NOT_IMPLEMENTED                                                     \
    }                                                                        \
  }

    FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(atan2, Atan2, 25, 32)
    FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
#undef FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC

// GLSL.std.450 instructions that come in signed/unsigned/float variants;
// the variant is chosen from the destination data type.
#define BINARY_OP_TO_SPIRV_FUNC(op, S_inst, S_inst_id, U_inst, U_inst_id,      \
                                F_inst, F_inst_id)                             \
  else if (op_type == BinaryOpType::op) {                                      \
    const uint32_t S_inst = S_inst_id;                                         \
    const uint32_t U_inst = U_inst_id;                                         \
    const uint32_t F_inst = F_inst_id;                                         \
    const auto dst_dt = bin->element_type();                                   \
    if (is_integral(dst_dt)) {                                                 \
      if (is_signed(dst_dt)) {                                                 \
        bin_value = ir_->call_glsl450(dst_type, S_inst, lhs_value, rhs_value); \
      } else {                                                                 \
        bin_value = ir_->call_glsl450(dst_type, U_inst, lhs_value, rhs_value); \
      }                                                                        \
    } else if (is_real(dst_dt)) {                                              \
      bin_value = ir_->call_glsl450(dst_type, F_inst, lhs_value, rhs_value);   \
    } else {                                                                   \
      TI_NOT_IMPLEMENTED                                                       \
    }                                                                          \
  }

    BINARY_OP_TO_SPIRV_FUNC(min, SMin, 39, UMin, 38, FMin, 37)
    BINARY_OP_TO_SPIRV_FUNC(max, SMax, 42, UMax, 41, FMax, 40)
#undef BINARY_OP_TO_SPIRV_FUNC
    // True division: both operands are first cast to the (real) destination
    // type so integer inputs divide as floats.
    else if (op_type == BinaryOpType::truediv) {
      lhs_value = ir_->cast(dst_type, lhs_value);
      rhs_value = ir_->cast(dst_type, rhs_value);
      bin_value = ir_->div(lhs_value, rhs_value);
    }
    else {TI_NOT_IMPLEMENTED} ir_->register_value(bin_name, bin_value);
  }
1100 | |
1101 | void visit(TernaryOpStmt *tri) override { |
1102 | TI_ASSERT(tri->op_type == TernaryOpType::select); |
1103 | spirv::Value op1 = ir_->query_value(tri->op1->raw_name()); |
1104 | spirv::Value op2 = ir_->query_value(tri->op2->raw_name()); |
1105 | spirv::Value op3 = ir_->query_value(tri->op3->raw_name()); |
1106 | spirv::Value tri_val = |
1107 | ir_->cast(ir_->get_primitive_type(tri->element_type()), |
1108 | ir_->select(ir_->cast(ir_->bool_type(), op1), op2, op3)); |
1109 | ir_->register_value(tri->raw_name(), tri_val); |
1110 | } |
1111 | |
1112 | inline bool ends_with(std::string const &value, std::string const &ending) { |
1113 | if (ending.size() > value.size()) |
1114 | return false; |
1115 | return std::equal(ending.rbegin(), ending.rend(), value.rbegin()); |
1116 | } |
1117 | |
  // Resolves the SPIR-V value for a texture pointer, creating (and caching by
  // kernel-argument id) the corresponding sampled-texture or storage-image
  // binding on first use.
  void visit(TexturePtrStmt *stmt) override {
    spirv::Value val;

    int arg_id = stmt->arg_load_stmt->as<ArgLoadStmt>()->arg_id;
    if (argid_to_tex_value_.find(arg_id) != argid_to_tex_value_.end()) {
      // A binding for this argument already exists; reuse it.
      val = argid_to_tex_value_.at(arg_id);
    } else {
      if (stmt->is_storage) {
        // Storage images need an explicit image format, chosen from the
        // channel count and per-channel primitive type. Combinations not
        // listed below fall through as BufferFormat::unknown.
        BufferFormat format = BufferFormat::unknown;

        if (stmt->num_channels == 1) {
          if (stmt->channel_format == PrimitiveType::u8 ||
              stmt->channel_format == PrimitiveType::i8) {
            format = BufferFormat::r8;
          } else if (stmt->channel_format == PrimitiveType::u16 ||
                     stmt->channel_format == PrimitiveType::i16) {
            format = BufferFormat::r16;
          } else if (stmt->channel_format == PrimitiveType::f16) {
            format = BufferFormat::r16f;
          } else if (stmt->channel_format == PrimitiveType::f32) {
            format = BufferFormat::r32f;
          }
        } else if (stmt->num_channels == 2) {
          if (stmt->channel_format == PrimitiveType::u8 ||
              stmt->channel_format == PrimitiveType::i8) {
            format = BufferFormat::rg8;
          } else if (stmt->channel_format == PrimitiveType::u16 ||
                     stmt->channel_format == PrimitiveType::i16) {
            format = BufferFormat::rg16;
          } else if (stmt->channel_format == PrimitiveType::f16) {
            format = BufferFormat::rg16f;
          } else if (stmt->channel_format == PrimitiveType::f32) {
            format = BufferFormat::rg32f;
          }
        } else if (stmt->num_channels == 4) {
          if (stmt->channel_format == PrimitiveType::u8 ||
              stmt->channel_format == PrimitiveType::i8) {
            format = BufferFormat::rgba8;
          } else if (stmt->channel_format == PrimitiveType::u16 ||
                     stmt->channel_format == PrimitiveType::i16) {
            format = BufferFormat::rgba16;
          } else if (stmt->channel_format == PrimitiveType::f16) {
            format = BufferFormat::rgba16f;
          } else if (stmt->channel_format == PrimitiveType::f32) {
            format = BufferFormat::rgba32f;
          }
        }

        // Allocate the next binding slot in descriptor set 0.
        int binding = binding_head_++;
        // NOTE(review): num_channels is hard-coded to 4 here even though the
        // format above was derived from stmt->num_channels — confirm this is
        // intentional.
        val =
            ir_->storage_image_argument(/*num_channels=*/4, stmt->dimensions,
                                        /*descriptor_set=*/0, binding, format);
        TextureBind bind;
        bind.arg_id = arg_id;
        bind.binding = binding;
        bind.is_storage = true;
        texture_binds_.push_back(bind);
        argid_to_tex_value_[arg_id] = val;
      } else {
        // Sampled texture: no format needed; bind.is_storage keeps its
        // default (non-storage) value.
        int binding = binding_head_++;
        val = ir_->texture_argument(/*num_channels=*/4, stmt->dimensions,
                                    /*descriptor_set=*/0, binding);
        TextureBind bind;
        bind.arg_id = arg_id;
        bind.binding = binding;
        texture_binds_.push_back(bind);
        argid_to_tex_value_[arg_id] = val;
      }
    }

    ir_->register_value(stmt->raw_name(), val);
  }
1190 | |
1191 | void visit(TextureOpStmt *stmt) override { |
1192 | spirv::Value tex = ir_->query_value(stmt->texture_ptr->raw_name()); |
1193 | spirv::Value val; |
1194 | if (stmt->op == TextureOpType::kSampleLod || |
1195 | stmt->op == TextureOpType::kFetchTexel) { |
1196 | // Texture Ops |
1197 | std::vector<spirv::Value> args; |
1198 | for (int i = 0; i < stmt->args.size() - 1; i++) { |
1199 | args.push_back(ir_->query_value(stmt->args[i]->raw_name())); |
1200 | } |
1201 | spirv::Value lod = ir_->query_value(stmt->args.back()->raw_name()); |
1202 | if (stmt->op == TextureOpType::kSampleLod) { |
1203 | // Sample |
1204 | val = ir_->sample_texture(tex, args, lod); |
1205 | } else if (stmt->op == TextureOpType::kFetchTexel) { |
1206 | // Texel fetch |
1207 | val = ir_->fetch_texel(tex, args, lod); |
1208 | } |
1209 | ir_->register_value(stmt->raw_name(), val); |
1210 | } else if (stmt->op == TextureOpType::kLoad || |
1211 | stmt->op == TextureOpType::kStore) { |
1212 | // Image Ops |
1213 | std::vector<spirv::Value> args; |
1214 | for (int i = 0; i < stmt->args.size(); i++) { |
1215 | args.push_back(ir_->query_value(stmt->args[i]->raw_name())); |
1216 | } |
1217 | if (stmt->op == TextureOpType::kLoad) { |
1218 | // Image Load |
1219 | val = ir_->image_load(tex, args); |
1220 | ir_->register_value(stmt->raw_name(), val); |
1221 | } else if (stmt->op == TextureOpType::kStore) { |
1222 | // Image Store |
1223 | ir_->image_store(tex, args); |
1224 | } |
1225 | } else { |
1226 | TI_NOT_IMPLEMENTED; |
1227 | } |
1228 | } |
1229 | |
  // Lowers Taichi internal-function (intrinsic) calls: vector component
  // extraction, workgroup/subgroup barriers, invocation-id queries, and
  // subgroup reduction / inclusive-scan / shuffle operations. Whatever `val`
  // holds at the end (a default spirv::Value for names matched by no branch)
  // is registered under the statement's name.
  void visit(InternalFuncStmt *stmt) override {
    spirv::Value val;

    // composite_extract_<i>: pull component i out of an f32 composite.
    if (stmt->func_name == "composite_extract_0" ) {
      val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
                            ir_->query_value(stmt->args[0]->raw_name()), 0);
    } else if (stmt->func_name == "composite_extract_1" ) {
      val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
                            ir_->query_value(stmt->args[0]->raw_name()), 1);
    } else if (stmt->func_name == "composite_extract_2" ) {
      val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
                            ir_->query_value(stmt->args[0]->raw_name()), 2);
    } else if (stmt->func_name == "composite_extract_3" ) {
      val = ir_->make_value(spv::OpCompositeExtract, ir_->f32_type(),
                            ir_->query_value(stmt->args[0]->raw_name()), 3);
    }

    // Name sets used below to dispatch subgroup operations by family.
    const std::unordered_set<std::string> reduction_ops{
        "subgroupAdd" , "subgroupMul" , "subgroupMin" , "subgroupMax" ,
        "subgroupAnd" , "subgroupOr" , "subgroupXor" };

    const std::unordered_set<std::string> inclusive_scan_ops{
        "subgroupInclusiveAdd" , "subgroupInclusiveMul" , "subgroupInclusiveMin" ,
        "subgroupInclusiveMax" , "subgroupInclusiveAnd" , "subgroupInclusiveOr" ,
        "subgroupInclusiveXor" };

    const std::unordered_set<std::string> shuffle_ops{
        "subgroupShuffleDown" , "subgroupShuffleUp" , "subgroupShuffle" };

    if (stmt->func_name == "workgroupBarrier" ) {
      // Full control barrier across the workgroup with acquire-release
      // semantics on workgroup memory.
      ir_->make_inst(
          spv::OpControlBarrier,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
          ir_->int_immediate_number(
              ir_->i32_type(), spv::MemorySemanticsWorkgroupMemoryMask |
                                   spv::MemorySemanticsAcquireReleaseMask));
      val = ir_->const_i32_zero_;
    } else if (stmt->func_name == "localInvocationId" ) {
      // x-component of the local invocation id, as i32.
      val = ir_->cast(ir_->i32_type(), ir_->get_local_invocation_id(0));
    } else if (stmt->func_name == "globalInvocationId" ) {
      // x-component of the global invocation id, as i32.
      val = ir_->cast(ir_->i32_type(), ir_->get_global_invocation_id(0));
    } else if (stmt->func_name == "workgroupMemoryBarrier" ) {
      // Memory-only barrier (no execution sync) at workgroup scope.
      ir_->make_inst(
          spv::OpMemoryBarrier,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeWorkgroup),
          ir_->int_immediate_number(
              ir_->i32_type(), spv::MemorySemanticsWorkgroupMemoryMask |
                                   spv::MemorySemanticsAcquireReleaseMask));
      val = ir_->const_i32_zero_;
    } else if (stmt->func_name == "subgroupElect" ) {
      // True for exactly one active invocation in the subgroup; the boolean
      // result is widened to i32.
      val = ir_->make_value(
          spv::OpGroupNonUniformElect, ir_->bool_type(),
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup));
      val = ir_->cast(ir_->i32_type(), val);
    } else if (stmt->func_name == "subgroupBarrier" ) {
      // Control barrier at subgroup scope (no memory semantics).
      ir_->make_inst(
          spv::OpControlBarrier,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
          ir_->const_i32_zero_);
      val = ir_->const_i32_zero_;
    } else if (stmt->func_name == "subgroupMemoryBarrier" ) {
      ir_->make_inst(
          spv::OpMemoryBarrier,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
          ir_->const_i32_zero_);
      val = ir_->const_i32_zero_;
    } else if (stmt->func_name == "subgroupSize" ) {
      val = ir_->cast(ir_->i32_type(), ir_->get_subgroup_size());
    } else if (stmt->func_name == "subgroupInvocationId" ) {
      val = ir_->cast(ir_->i32_type(), ir_->get_subgroup_invocation_id());
    } else if (stmt->func_name == "subgroupBroadcast" ) {
      // Broadcast args[0]'s value from the invocation selected by args[1].
      auto value = ir_->query_value(stmt->args[0]->raw_name());
      auto index = ir_->query_value(stmt->args[1]->raw_name());
      val = ir_->make_value(
          spv::OpGroupNonUniformBroadcast, value.stype,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), value,
          index);
    } else if (reduction_ops.find(stmt->func_name) != reduction_ops.end() ||
               inclusive_scan_ops.find(stmt->func_name) !=
                   inclusive_scan_ops.end()) {
      // Subgroup reduce / inclusive scan: pick the SPIR-V opcode from the
      // operation suffix and the argument's data type.
      auto arg = ir_->query_value(stmt->args[0]->raw_name());
      auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type);
      spv::Op spv_op;

      if (ends_with(stmt->func_name, "Add" )) {
        if (is_integral(stmt->args[0]->ret_type)) {
          spv_op = spv::OpGroupNonUniformIAdd;
        } else {
          spv_op = spv::OpGroupNonUniformFAdd;
        }
      } else if (ends_with(stmt->func_name, "Mul" )) {
        if (is_integral(stmt->args[0]->ret_type)) {
          spv_op = spv::OpGroupNonUniformIMul;
        } else {
          spv_op = spv::OpGroupNonUniformFMul;
        }
      } else if (ends_with(stmt->func_name, "Min" )) {
        // Min/max additionally distinguish signed vs unsigned integers.
        if (is_integral(stmt->args[0]->ret_type)) {
          if (is_signed(stmt->args[0]->ret_type)) {
            spv_op = spv::OpGroupNonUniformSMin;
          } else {
            spv_op = spv::OpGroupNonUniformUMin;
          }
        } else {
          spv_op = spv::OpGroupNonUniformFMin;
        }
      } else if (ends_with(stmt->func_name, "Max" )) {
        if (is_integral(stmt->args[0]->ret_type)) {
          if (is_signed(stmt->args[0]->ret_type)) {
            spv_op = spv::OpGroupNonUniformSMax;
          } else {
            spv_op = spv::OpGroupNonUniformUMax;
          }
        } else {
          spv_op = spv::OpGroupNonUniformFMax;
        }
      } else if (ends_with(stmt->func_name, "And" )) {
        spv_op = spv::OpGroupNonUniformBitwiseAnd;
      } else if (ends_with(stmt->func_name, "Or" )) {
        spv_op = spv::OpGroupNonUniformBitwiseOr;
      } else if (ends_with(stmt->func_name, "Xor" )) {
        spv_op = spv::OpGroupNonUniformBitwiseXor;
      } else {
        TI_ERROR("Unsupported operation: {}" , stmt->func_name);
      }

      // Reduce vs inclusive scan — exactly one of the two sets matched the
      // name (that is the condition of this branch), so group_op is always
      // assigned before use.
      spv::GroupOperation group_op;

      if (reduction_ops.find(stmt->func_name) != reduction_ops.end()) {
        group_op = spv::GroupOperationReduce;
      } else if (inclusive_scan_ops.find(stmt->func_name) !=
                 inclusive_scan_ops.end()) {
        group_op = spv::GroupOperationInclusiveScan;
      }

      val = ir_->make_value(
          spv_op, stype,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup),
          group_op, arg);
    } else if (shuffle_ops.find(stmt->func_name) != shuffle_ops.end()) {
      // Subgroup shuffle: args[0] is the value, args[1] the delta or lane id.
      auto arg0 = ir_->query_value(stmt->args[0]->raw_name());
      auto arg1 = ir_->query_value(stmt->args[1]->raw_name());
      auto stype = ir_->get_primitive_type(stmt->args[0]->ret_type);
      spv::Op spv_op;

      // Order matters: "Down"/"Up" must be tested before the bare "Shuffle"
      // suffix, which all three names end with.
      if (ends_with(stmt->func_name, "Down" )) {
        spv_op = spv::OpGroupNonUniformShuffleDown;
      } else if (ends_with(stmt->func_name, "Up" )) {
        spv_op = spv::OpGroupNonUniformShuffleUp;
      } else if (ends_with(stmt->func_name, "Shuffle" )) {
        spv_op = spv::OpGroupNonUniformShuffle;
      } else {
        TI_ERROR("Unsupported operation: {}" , stmt->func_name);
      }

      val = ir_->make_value(
          spv_op, stype,
          ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), arg0,
          arg1);
    }
    ir_->register_value(stmt->raw_name(), val);
  }
1394 | |
1395 | void visit(AtomicOpStmt *stmt) override { |
1396 | const auto dt = stmt->dest->element_type().ptr_removed(); |
1397 | |
1398 | spirv::Value data = ir_->query_value(stmt->val->raw_name()); |
1399 | spirv::Value val; |
1400 | bool use_subgroup_reduction = false; |
1401 | |
1402 | if (stmt->is_reduction && |
1403 | caps_->get(DeviceCapability::spirv_has_subgroup_arithmetic)) { |
1404 | spv::Op atomic_op = spv::OpNop; |
1405 | bool negation = false; |
1406 | if (is_integral(dt)) { |
1407 | if (stmt->op_type == AtomicOpType::add) { |
1408 | atomic_op = spv::OpGroupIAdd; |
1409 | } else if (stmt->op_type == AtomicOpType::sub) { |
1410 | atomic_op = spv::OpGroupIAdd; |
1411 | negation = true; |
1412 | } else if (stmt->op_type == AtomicOpType::min) { |
1413 | atomic_op = is_signed(dt) ? spv::OpGroupSMin : spv::OpGroupUMin; |
1414 | } else if (stmt->op_type == AtomicOpType::max) { |
1415 | atomic_op = is_signed(dt) ? spv::OpGroupSMax : spv::OpGroupUMax; |
1416 | } |
1417 | } else if (is_real(dt)) { |
1418 | if (stmt->op_type == AtomicOpType::add) { |
1419 | atomic_op = spv::OpGroupFAdd; |
1420 | } else if (stmt->op_type == AtomicOpType::sub) { |
1421 | atomic_op = spv::OpGroupFAdd; |
1422 | negation = true; |
1423 | } else if (stmt->op_type == AtomicOpType::min) { |
1424 | atomic_op = spv::OpGroupFMin; |
1425 | } else if (stmt->op_type == AtomicOpType::max) { |
1426 | atomic_op = spv::OpGroupFMax; |
1427 | } |
1428 | } |
1429 | |
1430 | if (atomic_op != spv::OpNop) { |
1431 | spirv::Value scope_subgroup = |
1432 | ir_->int_immediate_number(ir_->i32_type(), 3); |
1433 | spirv::Value operation_reduce = ir_->const_i32_zero_; |
1434 | if (negation) { |
1435 | if (is_integral(dt)) { |
1436 | data = ir_->make_value(spv::OpSNegate, data.stype, data); |
1437 | } else { |
1438 | data = ir_->make_value(spv::OpFNegate, data.stype, data); |
1439 | } |
1440 | } |
1441 | data = ir_->make_value(atomic_op, ir_->get_primitive_type(dt), |
1442 | scope_subgroup, operation_reduce, data); |
1443 | val = data; |
1444 | use_subgroup_reduction = true; |
1445 | } |
1446 | } |
1447 | |
1448 | spirv::Label then_label; |
1449 | spirv::Label merge_label; |
1450 | |
1451 | if (use_subgroup_reduction) { |
1452 | spirv::Value subgroup_id = ir_->get_subgroup_invocation_id(); |
1453 | spirv::Value cond = ir_->make_value(spv::OpIEqual, ir_->bool_type(), |
1454 | subgroup_id, ir_->const_i32_zero_); |
1455 | |
1456 | then_label = ir_->new_label(); |
1457 | merge_label = ir_->new_label(); |
1458 | ir_->make_inst(spv::OpSelectionMerge, merge_label, |
1459 | spv::SelectionControlMaskNone); |
1460 | ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label); |
1461 | ir_->start_label(then_label); |
1462 | } |
1463 | |
1464 | spirv::Value addr_ptr; |
1465 | |
1466 | if (dt->is_primitive(PrimitiveTypeID::f64)) { |
1467 | if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) && |
1468 | stmt->op_type == AtomicOpType::add) { |
1469 | addr_ptr = at_buffer(stmt->dest, dt); |
1470 | } else { |
1471 | addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt)); |
1472 | } |
1473 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
1474 | if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) && |
1475 | stmt->op_type == AtomicOpType::add) { |
1476 | addr_ptr = at_buffer(stmt->dest, dt); |
1477 | } else { |
1478 | addr_ptr = at_buffer(stmt->dest, ir_->get_taichi_uint_type(dt)); |
1479 | } |
1480 | } else { |
1481 | addr_ptr = at_buffer(stmt->dest, dt); |
1482 | } |
1483 | |
1484 | auto ret_type = ir_->get_primitive_type(dt); |
1485 | |
1486 | if (is_real(dt)) { |
1487 | spv::Op atomic_fp_op; |
1488 | if (stmt->op_type == AtomicOpType::add) { |
1489 | atomic_fp_op = spv::OpAtomicFAddEXT; |
1490 | } |
1491 | |
1492 | bool use_native_atomics = false; |
1493 | |
1494 | if (dt->is_primitive(PrimitiveTypeID::f64)) { |
1495 | if (caps_->get(DeviceCapability::spirv_has_atomic_float64_add) && |
1496 | stmt->op_type == AtomicOpType::add) { |
1497 | use_native_atomics = true; |
1498 | } |
1499 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
1500 | if (caps_->get(DeviceCapability::spirv_has_atomic_float_add) && |
1501 | stmt->op_type == AtomicOpType::add) { |
1502 | use_native_atomics = true; |
1503 | } |
1504 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
1505 | if (caps_->get(DeviceCapability::spirv_has_atomic_float16_add) && |
1506 | stmt->op_type == AtomicOpType::add) { |
1507 | use_native_atomics = true; |
1508 | } |
1509 | } |
1510 | |
1511 | if (use_native_atomics) { |
1512 | val = |
1513 | ir_->make_value(atomic_fp_op, ir_->get_primitive_type(dt), addr_ptr, |
1514 | /*scope=*/ir_->const_i32_one_, |
1515 | /*semantics=*/ir_->const_i32_zero_, data); |
1516 | } else { |
1517 | val = ir_->float_atomic(stmt->op_type, addr_ptr, data); |
1518 | } |
1519 | } else if (is_integral(dt)) { |
1520 | spv::Op op; |
1521 | if (stmt->op_type == AtomicOpType::add) { |
1522 | op = spv::OpAtomicIAdd; |
1523 | } else if (stmt->op_type == AtomicOpType::sub) { |
1524 | op = spv::OpAtomicISub; |
1525 | } else if (stmt->op_type == AtomicOpType::min) { |
1526 | op = is_signed(dt) ? spv::OpAtomicSMin : spv::OpAtomicUMin; |
1527 | } else if (stmt->op_type == AtomicOpType::max) { |
1528 | op = is_signed(dt) ? spv::OpAtomicSMax : spv::OpAtomicUMax; |
1529 | } else if (stmt->op_type == AtomicOpType::bit_or) { |
1530 | op = spv::OpAtomicOr; |
1531 | } else if (stmt->op_type == AtomicOpType::bit_and) { |
1532 | op = spv::OpAtomicAnd; |
1533 | } else if (stmt->op_type == AtomicOpType::bit_xor) { |
1534 | op = spv::OpAtomicXor; |
1535 | } else { |
1536 | TI_NOT_IMPLEMENTED |
1537 | } |
1538 | |
1539 | auto uint_type = ir_->get_primitive_uint_type(dt); |
1540 | |
1541 | if (data.stype.id != addr_ptr.stype.element_type_id) { |
1542 | data = ir_->make_value(spv::OpBitcast, ret_type, data); |
1543 | } |
1544 | |
1545 | // Semantics = (UniformMemory 0x40) | (AcquireRelease 0x8) |
1546 | ir_->make_inst( |
1547 | spv::OpMemoryBarrier, ir_->const_i32_one_, |
1548 | ir_->uint_immediate_number( |
1549 | ir_->u32_type(), spv::MemorySemanticsAcquireReleaseMask | |
1550 | spv::MemorySemanticsUniformMemoryMask)); |
1551 | val = ir_->make_value(op, ret_type, addr_ptr, |
1552 | /*scope=*/ir_->const_i32_one_, |
1553 | /*semantics=*/ir_->const_i32_zero_, data); |
1554 | |
1555 | if (val.stype.id != ret_type.id) { |
1556 | val = ir_->make_value(spv::OpBitcast, ret_type, val); |
1557 | } |
1558 | } else { |
1559 | TI_NOT_IMPLEMENTED |
1560 | } |
1561 | |
1562 | if (use_subgroup_reduction) { |
1563 | ir_->make_inst(spv::OpBranch, merge_label); |
1564 | ir_->start_label(merge_label); |
1565 | } |
1566 | |
1567 | ir_->register_value(stmt->raw_name(), val); |
1568 | } |
1569 | |
  // Lowers an IfStmt into SPIR-V structured control flow:
  // OpSelectionMerge + OpBranchConditional with separate then/else blocks
  // that both fall through to a common merge block.
  void visit(IfStmt *if_stmt) override {
    // Taichi conditions are integral; compare against 0 to get a SPIR-V bool.
    spirv::Value cond_v = ir_->query_value(if_stmt->cond->raw_name());
    spirv::Value cond =
        ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
    spirv::Label then_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();
    spirv::Label else_label = ir_->new_label();
    ir_->make_inst(spv::OpSelectionMerge, merge_label,
                   spv::SelectionControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block
    ir_->start_label(then_label);
    if (if_stmt->true_statements) {
      if_stmt->true_statements->accept(this);
    }
    // ContinueStmt must be in IfStmt
    if (gen_label_) {  // Skip OpBranch, because ContinueStmt already generated
                       // one (a block may have only one terminator)
      gen_label_ = false;
    } else {
      ir_->make_inst(spv::OpBranch, merge_label);
    }
    // else block
    ir_->start_label(else_label);
    if (if_stmt->false_statements) {
      if_stmt->false_statements->accept(this);
    }
    if (gen_label_) {
      gen_label_ = false;
    } else {
      ir_->make_inst(spv::OpBranch, merge_label);
    }
    // merge label
    ir_->start_label(merge_label);
  }
1605 | |
  // Lowers a nested (non-offloaded) range-for into a SPIR-V structured loop.
  // The loop variable is a 2-way phi fed by the init block and the continue
  // block; reversed loops start at end-1 and count down while >= begin.
  void visit(RangeForStmt *for_stmt) override {
    auto loop_var_name = for_stmt->raw_name();
    // Must get init label after making value(to make sure they are correct)
    spirv::Label init_label = ir_->current_label();
    spirv::Label head_label = ir_->new_label();
    spirv::Label body_label = ir_->new_label();
    spirv::Label continue_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();

    spirv::Value begin_ = ir_->query_value(for_stmt->begin->raw_name());
    spirv::Value end_ = ir_->query_value(for_stmt->end->raw_name());
    spirv::Value init_value;
    spirv::Value extent_value;
    if (!for_stmt->reversed) {
      init_value = begin_;
      extent_value = end_;
    } else {
      // reversed for loop: iterate from end-1 down to begin (inclusive)
      init_value = ir_->sub(end_, ir_->const_i32_one_);
      extent_value = begin_;
    }
    ir_->make_inst(spv::OpBranch, head_label);

    // Loop head: phi merges the initial value with the stepped value.
    ir_->start_label(head_label);
    spirv::PhiValue loop_var = ir_->make_phi(init_value.stype, 2);
    loop_var.set_incoming(0, init_value, init_label);
    spirv::Value loop_cond;
    if (!for_stmt->reversed) {
      loop_cond = ir_->lt(loop_var, extent_value);
    } else {
      loop_cond = ir_->ge(loop_var, extent_value);
    }
    ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
                   spv::LoopControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, loop_cond, body_label,
                   merge_label);

    // loop body
    ir_->start_label(body_label);
    push_loop_control_labels(continue_label, merge_label);
    ir_->register_value(loop_var_name, spirv::Value(loop_var));
    for_stmt->body->accept(this);
    pop_loop_control_labels();
    ir_->make_inst(spv::OpBranch, continue_label);

    // loop continue: step the loop variable by +1 (or -1 when reversed)
    ir_->start_label(continue_label);
    spirv::Value next_value;
    if (!for_stmt->reversed) {
      next_value = ir_->add(loop_var, ir_->const_i32_one_);
    } else {
      next_value = ir_->sub(loop_var, ir_->const_i32_one_);
    }
    loop_var.set_incoming(1, next_value, ir_->current_label());
    ir_->make_inst(spv::OpBranch, head_label);
    // loop merge
    ir_->start_label(merge_label);
  }
1665 | |
  // Lowers a while(true) loop. The head branches unconditionally into the
  // body; the loop only exits via WhileControlStmt, which branches to the
  // merge label (see visit(WhileControlStmt) below).
  void visit(WhileStmt *stmt) override {
    spirv::Label head_label = ir_->new_label();
    spirv::Label body_label = ir_->new_label();
    spirv::Label continue_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();
    ir_->make_inst(spv::OpBranch, head_label);

    // Loop head: no condition here, just the structured-loop declaration.
    ir_->start_label(head_label);
    ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
                   spv::LoopControlMaskNone);
    ir_->make_inst(spv::OpBranch, body_label);

    // loop body
    ir_->start_label(body_label);
    push_loop_control_labels(continue_label, merge_label);
    stmt->body->accept(this);
    pop_loop_control_labels();
    ir_->make_inst(spv::OpBranch, continue_label);

    // loop continue: jump straight back to the head
    ir_->start_label(continue_label);
    ir_->make_inst(spv::OpBranch, head_label);

    // loop merge
    ir_->start_label(merge_label);
  }
1693 | |
  // Lowers the `while` guard check: when the condition evaluates to zero,
  // break out of the enclosing loop by branching to its merge label.
  void visit(WhileControlStmt *stmt) override {
    spirv::Value cond_v = ir_->query_value(stmt->cond->raw_name());
    // Note: compares *equal* to zero — the then-branch is the break path.
    spirv::Value cond =
        ir_->eq(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
    spirv::Label then_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();

    ir_->make_inst(spv::OpSelectionMerge, merge_label,
                   spv::SelectionControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
    ir_->start_label(then_label);
    ir_->make_inst(spv::OpBranch, current_merge_label());  // break;
    ir_->start_label(merge_label);
  }
1708 | |
1709 | void visit(ContinueStmt *stmt) override { |
1710 | auto stmt_in_off_for = [stmt]() { |
1711 | TI_ASSERT(stmt->scope != nullptr); |
1712 | if (auto *offl = stmt->scope->cast<OffloadedStmt>(); offl) { |
1713 | TI_ASSERT(offl->task_type == OffloadedStmt::TaskType::range_for || |
1714 | offl->task_type == OffloadedStmt::TaskType::struct_for); |
1715 | return true; |
1716 | } |
1717 | return false; |
1718 | }; |
1719 | if (stmt_in_off_for()) { |
1720 | // Return means end THIS main loop and start next loop, not exit kernel |
1721 | ir_->make_inst(spv::OpBranch, return_label()); |
1722 | } else { |
1723 | ir_->make_inst(spv::OpBranch, current_continue_label()); |
1724 | } |
1725 | gen_label_ = true; // Only ContinueStmt will cause duplicate OpBranch, |
1726 | // which should be eliminated |
1727 | } |
1728 | |
1729 | private: |
1730 | void () { |
1731 | /* |
1732 | for (int root = 0; root < compiled_structs_.size(); ++root) { |
1733 | get_buffer_value({BufferType::Root, root}); |
1734 | } |
1735 | */ |
1736 | std::array<int, 3> group_size = { |
1737 | task_attribs_.advisory_num_threads_per_group, 1, 1}; |
1738 | ir_->set_work_group_size(group_size); |
1739 | std::vector<spirv::Value> buffers; |
1740 | if (caps_->get(DeviceCapability::spirv_version) > 0x10300) { |
1741 | buffers = shared_array_binds_; |
1742 | // One buffer can be bound to different bind points but has to be unique |
1743 | // in OpEntryPoint interface declarations. |
1744 | // From Spec: before SPIR-V version 1.4, duplication of these interface id |
1745 | // is tolerated. Starting with version 1.4, an interface id must not |
1746 | // appear more than once. |
1747 | std::unordered_set<spirv::Value, spirv::ValueHasher> entry_point_values; |
1748 | for (const auto &bb : task_attribs_.buffer_binds) { |
1749 | for (auto &it : buffer_value_map_) { |
1750 | if (it.first.first == bb.buffer) { |
1751 | entry_point_values.insert(it.second); |
1752 | } |
1753 | } |
1754 | } |
1755 | buffers.insert(buffers.end(), entry_point_values.begin(), |
1756 | entry_point_values.end()); |
1757 | } |
1758 | ir_->commit_kernel_function(kernel_function_, "main" , buffers, |
1759 | group_size); // kernel entry |
1760 | } |
1761 | |
  // Generates the SPIR-V function for a serial offloaded task: only
  // invocation 0 executes the body; every other invocation skips straight to
  // the merge block and returns.
  void generate_serial_kernel(OffloadedStmt *stmt) {
    task_attribs_.name = task_name_;
    task_attribs_.task_type = OffloadedTaskType::serial;
    task_attribs_.advisory_total_num_threads = 1;
    task_attribs_.advisory_num_threads_per_group = 1;

    // The computation for a single work is wrapped inside a function, so that
    // we can do grid-strided loop.
    ir_->start_function(kernel_function_);
    spirv::Value cond =
        ir_->eq(ir_->get_global_invocation_id(0),
                ir_->uint_immediate_number(
                    ir_->u32_type(), 0));  // if (gl_GlobalInvocationID.x == 0)
    spirv::Label then_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();
    // Record the merge label as the kernel's return target.
    kernel_return_label_ = merge_label;

    ir_->make_inst(spv::OpSelectionMerge, merge_label,
                   spv::SelectionControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, cond, then_label, merge_label);
    ir_->start_label(then_label);

    // serial kernel
    stmt->body->accept(this);

    ir_->make_inst(spv::OpBranch, merge_label);
    ir_->start_label(merge_label);
    ir_->make_inst(spv::OpReturn);       // return;
    ir_->make_inst(spv::OpFunctionEnd);  // } Close kernel

    task_attribs_.buffer_binds = get_buffer_binds();
    task_attribs_.texture_binds = get_texture_binds();
  }
1795 | |
1796 | void gen_array_range(Stmt *stmt) { |
1797 | int num_operands = stmt->num_operands(); |
1798 | for (int i = 0; i < num_operands; i++) { |
1799 | gen_array_range(stmt->operand(i)); |
1800 | } |
1801 | offload_loop_motion_.insert(stmt); |
1802 | stmt->accept(this); |
1803 | } |
1804 | |
  // Generates the SPIR-V function for an offloaded range-for task as a
  // grid-stride loop: each invocation starts at
  // (global_invocation_id + begin) and strides by the total invocation count
  // until it reaches `end`. Loop bounds may be compile-time constants or be
  // read at runtime from the global tmps buffer / kernel args.
  void generate_range_for_kernel(OffloadedStmt *stmt) {
    task_attribs_.name = task_name_;
    task_attribs_.task_type = OffloadedTaskType::range_for;

    task_attribs_.range_for_attribs = TaskAttributes::RangeForAttributes();
    auto &range_for_attribs = task_attribs_.range_for_attribs.value();
    range_for_attribs.const_begin = stmt->const_begin;
    range_for_attribs.const_end = stmt->const_end;
    range_for_attribs.begin =
        (stmt->const_begin ? stmt->begin_value : stmt->begin_offset);
    range_for_attribs.end =
        (stmt->const_end ? stmt->end_value : stmt->end_offset);

    ir_->start_function(kernel_function_);
    const std::string total_elems_name("total_elems" );
    spirv::Value total_elems;
    spirv::Value begin_expr_value;
    if (range_for_attribs.const_range()) {
      // Both bounds known at compile time: bake them in as named constants
      // and launch exactly one thread per element.
      const int num_elems = range_for_attribs.end - range_for_attribs.begin;
      begin_expr_value = ir_->int_immediate_number(
          ir_->i32_type(), stmt->begin_value, false);  // Named Constant
      total_elems = ir_->int_immediate_number(ir_->i32_type(), num_elems,
                                              false);  // Named Constant
      task_attribs_.advisory_total_num_threads = num_elems;
    } else {
      spirv::Value end_expr_value;
      if (stmt->end_stmt) {
        // Range from args
        TI_ASSERT(stmt->const_begin);
        begin_expr_value = ir_->int_immediate_number(ir_->i32_type(),
                                                     stmt->begin_value, false);
        gen_array_range(stmt->end_stmt);
        end_expr_value = ir_->query_value(stmt->end_stmt->raw_name());
      } else {
        // Range from gtmp / constant
        if (!stmt->const_begin) {
          // begin_offset is a byte offset into the global tmps buffer;
          // shift right by 2 to convert it to an i32 element index.
          spirv::Value begin_idx = ir_->make_value(
              spv::OpShiftRightArithmetic, ir_->i32_type(),
              ir_->int_immediate_number(ir_->i32_type(), stmt->begin_offset),
              ir_->int_immediate_number(ir_->i32_type(), 2));
          begin_expr_value = ir_->load_variable(
              ir_->struct_array_access(
                  ir_->i32_type(),
                  get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32),
                  begin_idx),
              ir_->i32_type());
        } else {
          begin_expr_value = ir_->int_immediate_number(
              ir_->i32_type(), stmt->begin_value, false);  // Named Constant
        }
        if (!stmt->const_end) {
          // Same byte-offset-to-index conversion for the end bound.
          spirv::Value end_idx = ir_->make_value(
              spv::OpShiftRightArithmetic, ir_->i32_type(),
              ir_->int_immediate_number(ir_->i32_type(), stmt->end_offset),
              ir_->int_immediate_number(ir_->i32_type(), 2));
          end_expr_value = ir_->load_variable(
              ir_->struct_array_access(
                  ir_->i32_type(),
                  get_buffer_value(BufferType::GlobalTmps, PrimitiveType::i32),
                  end_idx),
              ir_->i32_type());
        } else {
          end_expr_value =
              ir_->int_immediate_number(ir_->i32_type(), stmt->end_value, true);
        }
      }
      total_elems = ir_->sub(end_expr_value, begin_expr_value);
      // Dynamic range: launch a fixed thread count and grid-stride over it.
      task_attribs_.advisory_total_num_threads = kMaxNumThreadsGridStrideLoop;
    }
    task_attribs_.advisory_num_threads_per_group = stmt->block_dim;
    ir_->debug_name(spv::OpName, begin_expr_value, "begin_expr_value" );
    ir_->debug_name(spv::OpName, total_elems, total_elems_name);

    spirv::Value begin_ =
        ir_->add(ir_->cast(ir_->i32_type(), ir_->get_global_invocation_id(0)),
                 begin_expr_value);
    ir_->debug_name(spv::OpName, begin_, "begin_" );
    spirv::Value end_ = ir_->add(total_elems, begin_expr_value);
    ir_->debug_name(spv::OpName, end_, "end_" );
    const std::string total_invocs_name = "total_invocs" ;
    // For now, |total_invocs_name| is equal to |total_elems|. Once we support
    // dynamic range, they will be different.
    // https://www.khronos.org/opengl/wiki/Compute_Shader#Inputs

    // HLSL & WGSL cross compilers do not support this builtin
    spirv::Value total_invocs = ir_->cast(
        ir_->i32_type(),
        ir_->mul(ir_->get_num_work_groups(0),
                 ir_->uint_immediate_number(
                     ir_->u32_type(),
                     task_attribs_.advisory_num_threads_per_group, true)));
    /*
    const int group_x = (task_attribs_.advisory_total_num_threads +
                         task_attribs_.advisory_num_threads_per_group - 1) /
                        task_attribs_.advisory_num_threads_per_group;
    spirv::Value total_invocs = ir_->uint_immediate_number(
        ir_->i32_type(), group_x * task_attribs_.advisory_num_threads_per_group,
        false);
    */

    ir_->debug_name(spv::OpName, total_invocs, total_invocs_name);

    // Must get init label after making value(to make sure they are correct)
    spirv::Label init_label = ir_->current_label();
    spirv::Label head_label = ir_->new_label();
    spirv::Label body_label = ir_->new_label();
    spirv::Label continue_label = ir_->new_label();
    spirv::Label merge_label = ir_->new_label();
    ir_->make_inst(spv::OpBranch, head_label);

    // loop head: phi merges the initial index with the strided next index.
    ir_->start_label(head_label);
    spirv::PhiValue loop_var = ir_->make_phi(begin_.stype, 2);
    ir_->register_value("ii" , loop_var);
    loop_var.set_incoming(0, begin_, init_label);
    spirv::Value loop_cond = ir_->lt(loop_var, end_);
    ir_->make_inst(spv::OpLoopMerge, merge_label, continue_label,
                   spv::LoopControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, loop_cond, body_label,
                   merge_label);

    // loop body
    ir_->start_label(body_label);
    push_loop_control_labels(continue_label, merge_label);

    // loop kernel
    stmt->body->accept(this);
    pop_loop_control_labels();
    ir_->make_inst(spv::OpBranch, continue_label);

    // loop continue: stride the index by the total invocation count.
    ir_->start_label(continue_label);
    spirv::Value next_value = ir_->add(loop_var, total_invocs);
    loop_var.set_incoming(1, next_value, ir_->current_label());
    ir_->make_inst(spv::OpBranch, head_label);

    // loop merge
    ir_->start_label(merge_label);

    ir_->make_inst(spv::OpReturn);
    ir_->make_inst(spv::OpFunctionEnd);

    task_attribs_.buffer_binds = get_buffer_binds();
    task_attribs_.texture_binds = get_texture_binds();
  }
1950 | |
  // Generates the SPIR-V function for an offloaded struct-for task. Element
  // indices come from the listgen buffer: word 0 holds the element count and
  // words 1..count hold the indices. Iteration is a grid-stride loop.
  void generate_struct_for_kernel(OffloadedStmt *stmt) {
    task_attribs_.name = task_name_;
    task_attribs_.task_type = OffloadedTaskType::struct_for;
    task_attribs_.advisory_total_num_threads = 65536;
    task_attribs_.advisory_num_threads_per_group = 128;

    // The computation for a single work is wrapped inside a function, so that
    // we can do grid-strided loop.
    ir_->start_function(kernel_function_);

    auto listgen_buffer =
        get_buffer_value(BufferType::ListGen, PrimitiveType::u32);
    // Element count lives at index 0 of the listgen buffer.
    auto listgen_count_ptr = ir_->struct_array_access(
        ir_->u32_type(), listgen_buffer, ir_->const_i32_zero_);
    auto listgen_count = ir_->load_variable(listgen_count_ptr, ir_->u32_type());

    auto invoc_index = ir_->get_global_invocation_id(0);

    spirv::Label loop_head = ir_->new_label();
    spirv::Label loop_body = ir_->new_label();
    spirv::Label loop_merge = ir_->new_label();

    // The loop index is kept in a function-local variable instead of a phi.
    auto loop_index_var = ir_->alloca_variable(ir_->u32_type());
    ir_->store_variable(loop_index_var, invoc_index);

    ir_->make_inst(spv::OpBranch, loop_head);
    ir_->start_label(loop_head);
    // for (; index < list_size; index += gl_NumWorkGroups.x *
    // gl_WorkGroupSize.x)
    auto loop_index = ir_->load_variable(loop_index_var, ir_->u32_type());
    auto loop_cond = ir_->make_value(spv::OpULessThan, ir_->bool_type(),
                                     loop_index, listgen_count);
    ir_->make_inst(spv::OpLoopMerge, loop_merge, loop_body,
                   spv::LoopControlMaskNone);
    ir_->make_inst(spv::OpBranchConditional, loop_cond, loop_body, loop_merge);
    {
      ir_->start_label(loop_body);
      // Indices start at word 1 (word 0 is the count).
      auto listgen_index_ptr = ir_->struct_array_access(
          ir_->u32_type(), listgen_buffer,
          ir_->add(ir_->uint_immediate_number(ir_->u32_type(), 1), loop_index));
      auto listgen_index =
          ir_->load_variable(listgen_index_ptr, ir_->u32_type());

      // kernel
      ir_->register_value("ii" , listgen_index);
      stmt->body->accept(this);

      // continue: stride the index by the total invocation count.
      spirv::Value total_invocs = ir_->cast(
          ir_->u32_type(),
          ir_->mul(ir_->get_num_work_groups(0),
                   ir_->uint_immediate_number(
                       ir_->u32_type(),
                       task_attribs_.advisory_num_threads_per_group, true)));
      auto next_index = ir_->add(loop_index, total_invocs);
      ir_->store_variable(loop_index_var, next_index);
      ir_->make_inst(spv::OpBranch, loop_head);
    }
    ir_->start_label(loop_merge);

    ir_->make_inst(spv::OpReturn);       // return;
    ir_->make_inst(spv::OpFunctionEnd);  // } Close kernel

    task_attribs_.buffer_binds = get_buffer_binds();
    task_attribs_.texture_binds = get_texture_binds();
  }
2017 | |
2018 | spirv::Value at_buffer(const Stmt *ptr, DataType dt) { |
2019 | spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); |
2020 | |
2021 | if (ptr_val.stype.dt == PrimitiveType::u64) { |
2022 | spirv::Value paddr_ptr = ir_->make_value( |
2023 | spv::OpConvertUToPtr, |
2024 | ir_->get_pointer_type(ir_->get_primitive_type(dt), |
2025 | spv::StorageClassPhysicalStorageBuffer), |
2026 | ptr_val); |
2027 | paddr_ptr.flag = ValueKind::kPhysicalPtr; |
2028 | return paddr_ptr; |
2029 | } |
2030 | |
2031 | spirv::Value buffer = get_buffer_value(ptr_to_buffers_.at(ptr), dt); |
2032 | size_t width = ir_->get_primitive_type_size(dt); |
2033 | spirv::Value idx_val = ir_->make_value( |
2034 | spv::OpShiftRightLogical, ptr_val.stype, ptr_val, |
2035 | ir_->uint_immediate_number(ptr_val.stype, size_t(std::log2(width)))); |
2036 | spirv::Value ret = |
2037 | ir_->struct_array_access(ir_->get_primitive_type(dt), buffer, idx_val); |
2038 | return ret; |
2039 | } |
2040 | |
2041 | spirv::Value load_buffer(const Stmt *ptr, DataType dt) { |
2042 | spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); |
2043 | |
2044 | DataType ti_buffer_type = ir_->get_taichi_uint_type(dt); |
2045 | |
2046 | if (ptr_val.stype.dt == PrimitiveType::u64) { |
2047 | ti_buffer_type = dt; |
2048 | } |
2049 | |
2050 | auto buf_ptr = at_buffer(ptr, ti_buffer_type); |
2051 | auto val_bits = |
2052 | ir_->load_variable(buf_ptr, ir_->get_primitive_type(ti_buffer_type)); |
2053 | auto ret = ti_buffer_type == dt |
2054 | ? val_bits |
2055 | : ir_->make_value(spv::OpBitcast, |
2056 | ir_->get_primitive_type(dt), val_bits); |
2057 | return ret; |
2058 | } |
2059 | |
2060 | void store_buffer(const Stmt *ptr, spirv::Value val) { |
2061 | spirv::Value ptr_val = ir_->query_value(ptr->raw_name()); |
2062 | |
2063 | DataType ti_buffer_type = ir_->get_taichi_uint_type(val.stype.dt); |
2064 | |
2065 | if (ptr_val.stype.dt == PrimitiveType::u64) { |
2066 | ti_buffer_type = val.stype.dt; |
2067 | } |
2068 | |
2069 | auto buf_ptr = at_buffer(ptr, ti_buffer_type); |
2070 | auto val_bits = |
2071 | val.stype.dt == ti_buffer_type |
2072 | ? val |
2073 | : ir_->make_value(spv::OpBitcast, |
2074 | ir_->get_primitive_type(ti_buffer_type), val); |
2075 | ir_->store_variable(buf_ptr, val_bits); |
2076 | } |
2077 | |
2078 | spirv::Value get_buffer_value(BufferInfo buffer, DataType dt) { |
2079 | auto type = ir_->get_primitive_type(dt); |
2080 | auto key = std::make_pair(buffer, type.id); |
2081 | |
2082 | const auto it = buffer_value_map_.find(key); |
2083 | if (it != buffer_value_map_.end()) { |
2084 | return it->second; |
2085 | } |
2086 | |
2087 | if (buffer.type == BufferType::Args) { |
2088 | compile_args_struct(); |
2089 | |
2090 | buffer_binding_map_[key] = 0; |
2091 | buffer_value_map_[key] = args_buffer_value_; |
2092 | return args_buffer_value_; |
2093 | } |
2094 | |
2095 | if (buffer.type == BufferType::Rets) { |
2096 | compile_ret_struct(); |
2097 | |
2098 | buffer_binding_map_[key] = 1; |
2099 | buffer_value_map_[key] = ret_buffer_value_; |
2100 | return ret_buffer_value_; |
2101 | } |
2102 | |
2103 | // Binding head starts at 2, so we don't break args and rets |
2104 | int binding = binding_head_++; |
2105 | buffer_binding_map_[key] = binding; |
2106 | |
2107 | spirv::Value buffer_value = |
2108 | ir_->buffer_argument(type, 0, binding, buffer_instance_name(buffer)); |
2109 | buffer_value_map_[key] = buffer_value; |
2110 | TI_TRACE("buffer name = {}, value = {}" , buffer_instance_name(buffer), |
2111 | buffer_value.id); |
2112 | |
2113 | return buffer_value; |
2114 | } |
2115 | |
2116 | spirv::Value make_pointer(size_t offset) { |
2117 | if (use_64bit_pointers) { |
2118 | // This is hacky, should check out how to encode uint64 values in spirv |
2119 | return ir_->cast(ir_->u64_type(), ir_->uint_immediate_number( |
2120 | ir_->u32_type(), uint32_t(offset))); |
2121 | } else { |
2122 | return ir_->uint_immediate_number(ir_->u32_type(), uint32_t(offset)); |
2123 | } |
2124 | } |
2125 | |
  // Builds the SPIR-V struct type and uniform buffer binding (set 0,
  // binding 0) describing the kernel's argument buffer. No-op when the
  // kernel takes no arguments.
  void compile_args_struct() {
    if (!ctx_attribs_->has_args())
      return;

    // Generate struct IR
    tinyir::Block blk;
    std::vector<const tinyir::Type *> element_types;
    for (auto &arg : ctx_attribs_->args()) {
      const tinyir::Type *t;
      if (arg.is_array &&
          caps_->get(DeviceCapability::spirv_has_physical_storage_buffer)) {
        // Array args are passed as 64-bit device addresses when the device
        // supports physical storage buffers.
        t = blk.emplace_back<IntType>(/*num_bits=*/64, /*is_signed=*/false);
      } else {
        t = translate_ti_primitive(blk, PrimitiveType::get(arg.dtype));
      }
      element_types.push_back(t);
    }
    // Extra args are appended as a tail of i32 words.
    const tinyir::Type *i32_type =
        blk.emplace_back<IntType>(/*num_bits=*/32, /*is_signed=*/true);
    for (int i = 0; i < ctx_attribs_->extra_args_bytes() / 4; i++) {
      element_types.push_back(i32_type);
    }
    const tinyir::Type *struct_type =
        blk.emplace_back<StructType>(element_types);

    // Reduce struct IR (presumably deduplicates equivalent types — see
    // ir_reduce_types); struct_type is remapped through old2new.
    std::unordered_map<const tinyir::Type *, const tinyir::Type *> old2new;
    auto reduced_blk = ir_reduce_types(&blk, old2new);
    struct_type = old2new[struct_type];

    // Layout (std140) & translate to SPIR-V
    STD140LayoutContext layout_ctx;
    auto ir2spirv_map =
        ir_translate_to_spirv(reduced_blk.get(), layout_ctx, ir_.get());
    args_struct_type_.id = ir2spirv_map[struct_type];

    args_buffer_value_ =
        ir_->uniform_struct_argument(args_struct_type_, 0, 0, "args" );
  }
2165 | |
2166 | void compile_ret_struct() { |
2167 | if (!ctx_attribs_->has_rets()) |
2168 | return; |
2169 | |
2170 | std::vector<std::tuple<spirv::SType, std::string, size_t>> |
2171 | struct_components_; |
2172 | // Now we only have one ret |
2173 | TI_ASSERT(ctx_attribs_->rets().size() == 1); |
2174 | for (auto &ret : ctx_attribs_->rets()) { |
2175 | // Use array size = 0 to generate a RuntimeArray |
2176 | if (auto tensor_type = |
2177 | PrimitiveType::get(ret.dtype)->cast<TensorType>()) { |
2178 | struct_components_.emplace_back( |
2179 | ir_->get_array_type( |
2180 | ir_->get_primitive_type(tensor_type->get_element_type()), 0), |
2181 | "ret" + std::to_string(ret.index), ret.offset_in_mem); |
2182 | } else { |
2183 | struct_components_.emplace_back( |
2184 | ir_->get_array_type( |
2185 | ir_->get_primitive_type(PrimitiveType::get(ret.dtype)), 0), |
2186 | "ret" + std::to_string(ret.index), ret.offset_in_mem); |
2187 | } |
2188 | } |
2189 | ret_struct_type_ = ir_->create_struct_type(struct_components_); |
2190 | |
2191 | ret_buffer_value_ = |
2192 | ir_->buffer_struct_argument(ret_struct_type_, 0, 1, "rets" ); |
2193 | } |
2194 | |
2195 | std::vector<BufferBind> get_buffer_binds() { |
2196 | std::vector<BufferBind> result; |
2197 | for (auto &[key, val] : buffer_binding_map_) { |
2198 | result.push_back(BufferBind{key.first, int(val)}); |
2199 | } |
2200 | return result; |
2201 | } |
2202 | |
2203 | std::vector<TextureBind> get_texture_binds() { |
2204 | return texture_binds_; |
2205 | } |
2206 | |
2207 | void push_loop_control_labels(spirv::Label continue_label, |
2208 | spirv::Label merge_label) { |
2209 | continue_label_stack_.push_back(continue_label); |
2210 | merge_label_stack_.push_back(merge_label); |
2211 | } |
2212 | |
2213 | void pop_loop_control_labels() { |
2214 | continue_label_stack_.pop_back(); |
2215 | merge_label_stack_.pop_back(); |
2216 | } |
2217 | |
2218 | const spirv::Label current_continue_label() const { |
2219 | return continue_label_stack_.back(); |
2220 | } |
2221 | |
2222 | const spirv::Label current_merge_label() const { |
2223 | return merge_label_stack_.back(); |
2224 | } |
2225 | |
2226 | const spirv::Label return_label() const { |
2227 | return continue_label_stack_.front(); |
2228 | } |
2229 | |
  Arch arch_;
  // Capability table used to query optional SPIR-V / device features.
  // Raw pointer — presumably not owned; confirm against the ctor (not shown).
  DeviceCapabilityConfig *caps_;

  // Hash for the (BufferInfo, type-id) keys used by the two maps below.
  struct BufferInfoTypeTupleHasher {
    std::size_t operator()(const std::pair<BufferInfo, int> &buf) const {
      return BufferInfoHasher()(buf.first) ^ (buf.second << 5);
    }
  };

  // SPIR-V struct type / buffer value for the kernel argument buffer.
  spirv::SType args_struct_type_;
  spirv::Value args_buffer_value_;

  // SPIR-V struct type / buffer value for the kernel return-value buffer.
  spirv::SType ret_struct_type_;
  spirv::Value ret_buffer_value_;

  std::shared_ptr<spirv::IRBuilder> ir_;  // spirv binary code builder
  // Caches buffer values / binding indices keyed by (buffer, element type id);
  // populated lazily by get_buffer_value().
  std::unordered_map<std::pair<BufferInfo, int>,
                     spirv::Value,
                     BufferInfoTypeTupleHasher>
      buffer_value_map_;
  std::unordered_map<std::pair<BufferInfo, int>,
                     uint32_t,
                     BufferInfoTypeTupleHasher>
      buffer_binding_map_;
  std::vector<TextureBind> texture_binds_;
  std::vector<spirv::Value> shared_array_binds_;
  spirv::Value kernel_function_;
  spirv::Label kernel_return_label_;
  // Set when a ContinueStmt has already emitted the branch terminating the
  // current block, so IfStmt must not emit a duplicate OpBranch.
  bool gen_label_{false};

  int binding_head_{2};  // Args:0, Ret:1

  /*
  std::unordered_map<int, spirv::CompiledSpirvSNode>
      spirv_snodes_;  // maps root id to spirv snode
  */

  OffloadedStmt *const task_ir_;  // not owned
  std::vector<CompiledSNodeStructs> compiled_structs_;
  std::unordered_map<int, int> snode_to_root_;
  const KernelContextAttributes *const ctx_attribs_;  // not owned
  const std::string task_name_;
  // Stacks of continue/merge targets for the loops currently being generated
  // (see push_loop_control_labels / pop_loop_control_labels).
  std::vector<spirv::Label> continue_label_stack_;
  std::vector<spirv::Label> merge_label_stack_;

  // Statements hoisted out of the offloaded loop body (see gen_array_range).
  std::unordered_set<const Stmt *> offload_loop_motion_;

  TaskAttributes task_attribs_;
  std::unordered_map<int, GetRootStmt *>
      root_stmts_;  // maps root id to get root stmt
  std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_;
  std::unordered_map<int, Value> argid_to_tex_value_;
2282 | }; |
2283 | } // namespace |
2284 | |
2285 | static void spriv_message_consumer(spv_message_level_t level, |
2286 | const char *source, |
2287 | const spv_position_t &position, |
2288 | const char *message) { |
2289 | // TODO: Maybe we can add a macro, e.g. TI_LOG_AT_LEVEL(lv, ...) |
2290 | if (level <= SPV_MSG_FATAL) { |
2291 | TI_ERROR("{}\n[{}:{}:{}] {}" , source, position.index, position.line, |
2292 | position.column, message); |
2293 | } else if (level <= SPV_MSG_WARNING) { |
2294 | TI_WARN("{}\n[{}:{}:{}] {}" , source, position.index, position.line, |
2295 | position.column, message); |
2296 | } else if (level <= SPV_MSG_INFO) { |
2297 | TI_INFO("{}\n[{}:{}:{}] {}" , source, position.index, position.line, |
2298 | position.column, message); |
2299 | } else if (level <= SPV_MSG_INFO) { |
2300 | TI_TRACE("{}\n[{}:{}:{}] {}" , source, position.index, position.line, |
2301 | position.column, message); |
2302 | } |
2303 | } |
2304 | |
2305 | KernelCodegen::KernelCodegen(const Params ¶ms) |
2306 | : params_(params), ctx_attribs_(*params.kernel, ¶ms.caps) { |
2307 | uint32_t spirv_version = params.caps.get(DeviceCapability::spirv_version); |
2308 | |
2309 | spv_target_env target_env; |
2310 | if (spirv_version >= 0x10600) { |
2311 | target_env = SPV_ENV_VULKAN_1_3; |
2312 | } else if (spirv_version >= 0x10500) { |
2313 | target_env = SPV_ENV_VULKAN_1_2; |
2314 | } else if (spirv_version >= 0x10400) { |
2315 | target_env = SPV_ENV_VULKAN_1_1_SPIRV_1_4; |
2316 | } else if (spirv_version >= 0x10300) { |
2317 | target_env = SPV_ENV_VULKAN_1_1; |
2318 | } else { |
2319 | target_env = SPV_ENV_VULKAN_1_0; |
2320 | } |
2321 | |
2322 | spirv_opt_ = std::make_unique<spvtools::Optimizer>(target_env); |
2323 | spirv_opt_->SetMessageConsumer(spriv_message_consumer); |
2324 | if (params.enable_spv_opt) { |
2325 | // From: SPIRV-Tools/source/opt/optimizer.cpp |
2326 | spirv_opt_->RegisterPass(spvtools::CreateWrapOpKillPass()) |
2327 | .RegisterPass(spvtools::CreateDeadBranchElimPass()) |
2328 | .RegisterPass(spvtools::CreateMergeReturnPass()) |
2329 | .RegisterPass(spvtools::CreateInlineExhaustivePass()) |
2330 | .RegisterPass(spvtools::CreateEliminateDeadFunctionsPass()) |
2331 | .RegisterPass(spvtools::CreateAggressiveDCEPass()) |
2332 | .RegisterPass(spvtools::CreatePrivateToLocalPass()) |
2333 | .RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass()) |
2334 | .RegisterPass(spvtools::CreateLocalSingleStoreElimPass()) |
2335 | .RegisterPass(spvtools::CreateScalarReplacementPass()) |
2336 | .RegisterPass(spvtools::CreateLocalAccessChainConvertPass()) |
2337 | .RegisterPass(spvtools::CreateLocalMultiStoreElimPass()) |
2338 | .RegisterPass(spvtools::CreateCCPPass()) |
2339 | .RegisterPass(spvtools::CreateLoopUnrollPass(true)) |
2340 | .RegisterPass(spvtools::CreateRedundancyEliminationPass()) |
2341 | .RegisterPass(spvtools::CreateCombineAccessChainsPass()) |
2342 | .RegisterPass(spvtools::CreateSimplificationPass()) |
2343 | .RegisterPass(spvtools::CreateSSARewritePass()) |
2344 | .RegisterPass(spvtools::CreateVectorDCEPass()) |
2345 | .RegisterPass(spvtools::CreateDeadInsertElimPass()) |
2346 | .RegisterPass(spvtools::CreateIfConversionPass()) |
2347 | .RegisterPass(spvtools::CreateCopyPropagateArraysPass()) |
2348 | .RegisterPass(spvtools::CreateReduceLoadSizePass()) |
2349 | .RegisterPass(spvtools::CreateBlockMergePass()); |
2350 | } |
2351 | spirv_opt_options_.set_run_validator(false); |
2352 | |
2353 | spirv_tools_ = std::make_unique<spvtools::SpirvTools>(target_env); |
2354 | } |
2355 | |
2356 | void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, |
2357 | std::vector<std::vector<uint32_t>> &generated_spirv) { |
2358 | auto *root = params_.kernel->ir->as<Block>(); |
2359 | auto &tasks = root->statements; |
2360 | for (int i = 0; i < tasks.size(); ++i) { |
2361 | TaskCodegen::Params tp; |
2362 | tp.task_ir = tasks[i]->as<OffloadedStmt>(); |
2363 | tp.task_id_in_kernel = i; |
2364 | tp.compiled_structs = params_.compiled_structs; |
2365 | tp.ctx_attribs = &ctx_attribs_; |
2366 | tp.ti_kernel_name = fmt::format("{}_{}" , params_.ti_kernel_name, i); |
2367 | tp.arch = params_.arch; |
2368 | tp.caps = ¶ms_.caps; |
2369 | |
2370 | TaskCodegen cgen(tp); |
2371 | auto task_res = cgen.run(); |
2372 | |
2373 | for (auto &[id, access] : task_res.arr_access) { |
2374 | ctx_attribs_.arr_access[id] = ctx_attribs_.arr_access[id] | access; |
2375 | } |
2376 | |
2377 | std::vector<uint32_t> optimized_spv(task_res.spirv_code); |
2378 | |
2379 | bool success = true; |
2380 | { |
2381 | bool result = false; |
2382 | TI_ERROR_IF( |
2383 | (result = !spirv_opt_->Run(optimized_spv.data(), optimized_spv.size(), |
2384 | &optimized_spv, spirv_opt_options_)), |
2385 | "SPIRV optimization failed" ); |
2386 | if (result) { |
2387 | success = false; |
2388 | } |
2389 | } |
2390 | |
2391 | TI_TRACE("SPIRV-Tools-opt: binary size, before={}, after={}" , |
2392 | task_res.spirv_code.size(), optimized_spv.size()); |
2393 | |
2394 | // Enable to dump SPIR-V assembly of kernels |
2395 | if constexpr (false) { |
2396 | std::vector<uint32_t> &spirv = |
2397 | success ? optimized_spv : task_res.spirv_code; |
2398 | |
2399 | std::string spirv_asm; |
2400 | spirv_tools_->Disassemble(optimized_spv, &spirv_asm); |
2401 | auto kernel_name = tp.ti_kernel_name; |
2402 | TI_WARN("SPIR-V Assembly dump for {} :\n{}\n\n" , kernel_name, spirv_asm); |
2403 | |
2404 | std::ofstream fout(kernel_name + ".spv" , |
2405 | std::ios::binary | std::ios::out); |
2406 | fout.write(reinterpret_cast<const char *>(spirv.data()), |
2407 | spirv.size() * sizeof(uint32_t)); |
2408 | fout.close(); |
2409 | } |
2410 | |
2411 | kernel_attribs.tasks_attribs.push_back(std::move(task_res.task_attribs)); |
2412 | generated_spirv.push_back(std::move(optimized_spv)); |
2413 | } |
2414 | kernel_attribs.ctx_attribs = std::move(ctx_attribs_); |
2415 | kernel_attribs.name = params_.ti_kernel_name; |
2416 | kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator; |
2417 | } |
2418 | |
2419 | void lower(const CompileConfig &config, Kernel *kernel) { |
2420 | if (!kernel->lowered()) { |
2421 | irpass::compile_to_executable(kernel->ir.get(), config, kernel, |
2422 | kernel->autodiff_mode, |
2423 | /*ad_use_stack=*/false, config.print_ir, |
2424 | /*lower_global_access=*/true, |
2425 | /*make_thread_local=*/false); |
2426 | kernel->set_lowered(true); |
2427 | } |
2428 | } |
2429 | |
2430 | } // namespace spirv |
2431 | } // namespace taichi::lang |
2432 | |