#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"

#include <set>
#include <unordered_map>
#include <utility>

namespace taichi::lang {

namespace irpass {
namespace {
bool demotable_axis_load(Stmt *stmt) {
  // Stmts built only from simple arithmetic on
  // ExternalTensorShapeAlongAxisStmt and ConstStmt don't need to be saved in
  // global temporaries; instead, clone them into each shader separately.
  int n_op = stmt->num_operands();
  if (n_op == 0) {
    return stmt->is<ExternalTensorShapeAlongAxisStmt>() ||
           stmt->is<ConstStmt>();
  }
  for (int i = 0; i < n_op; i++) {
    auto op = stmt->operand(i);
    if (!demotable_axis_load(op))
      return false;
  }
  return true;
}
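
// Strips away MatrixPtrStmt offsets to find the top-level pointer a statement
// ultimately refers to (e.g. the underlying AllocaStmt or
// GlobalTemporaryStmt).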
class SquashPtrOffset : public IRVisitor {
 public:
  SquashPtrOffset() {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }
  void visit(Stmt *stmt) override {
    top_level_ptr_ = stmt;
  }
  void visit(MatrixPtrStmt *stmt) override {
    stmt->origin->accept(this);
  }
  static Stmt *run(Stmt *root) {
    SquashPtrOffset v;
    root->accept(&v);
    return v.top_level_ptr_;
  }

 private:
  Stmt *top_level_ptr_ = nullptr;
};

// Maps each offloaded local variable to its offset (in bytes) in the global
// temporaries memory.
using StmtToOffsetMap = std::unordered_map<const Stmt *, std::size_t>;

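// Records the non-constant begin/end statements of each offloaded range-for
// task, so that later passes can promote them to global temporaries and patch
// the task's begin_offset/end_offset.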
struct OffloadedRanges {
  using Map = std::unordered_map<const OffloadedStmt *, Stmt *>;
  Map begin_stmts;
  Map end_stmts;
};

// Breaks the kernel into multiple offloaded tasks and emits clear-list/listgen
// tasks for struct-fors.
// For GPU backends this pass also determines the grid and block dimensions.
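//
// Roughly (an illustrative sketch, not actual IR syntax), a top-level
//
//   for i in range(begin, end): body(i)
//
// becomes its own
//
//   OffloadedStmt(range_for, begin, end) { body(i) }
//
// task, while the statements between parallel loops are gathered into serial
// tasks with grid_dim = block_dim = 1.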
class Offloader {
 public:
  static OffloadedRanges run(IRNode *root, const CompileConfig &config) {
    OffloadedRanges offloaded_ranges;

    auto root_block = dynamic_cast<Block *>(root);
    auto root_statements = std::move(root_block->statements);
    root_block->statements.clear();
    const auto arch = config.arch;
    auto pending_serial_statements =
        Stmt::make_typed<OffloadedStmt>(OffloadedStmt::TaskType::serial, arch);
    pending_serial_statements->grid_dim = 1;
    pending_serial_statements->block_dim = 1;

    auto assemble_serial_statements = [&]() {
      if (!pending_serial_statements->body->statements.empty()) {
        root_block->insert(std::move(pending_serial_statements));
        pending_serial_statements = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::serial, arch);
        pending_serial_statements->grid_dim = 1;
        pending_serial_statements->block_dim = 1;
      }
    };

    for (int i = 0; i < (int)root_statements.size(); i++) {
      auto &stmt = root_statements[i];
      // Note that stmt->parent is root_block, which doesn't contain stmt now.
      if (auto s = stmt->cast<RangeForStmt>(); s && !s->strictly_serialized) {
        assemble_serial_statements();
        auto offloaded = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::range_for, arch);
        // offloaded->body is an empty block now.
        offloaded->grid_dim = config.saturating_grid_dim;
        if (s->block_dim == 0) {
          offloaded->block_dim = Program::default_block_dim(config);
        } else {
          offloaded->block_dim = s->block_dim;
        }
        if (auto val = s->begin->cast<ConstStmt>()) {
          offloaded->const_begin = true;
          offloaded->begin_value = val->val.val_int32();
        } else {
          offloaded_ranges.begin_stmts.insert(
              std::make_pair(offloaded.get(), s->begin));
        }

        if (auto val = s->end->cast<ConstStmt>()) {
          offloaded->const_end = true;
          offloaded->end_value = val->val.val_int32();
        } else {
          if ((arch == Arch::opengl || arch == Arch::vulkan ||
               arch == Arch::gles) &&
              demotable_axis_load(s->end)) {
            // TODO: We need to update codegen for each backend gradually, so
            // let's limit this to the opengl/vulkan/gles backends for now.
            auto end_copy = s->end->clone();
            offloaded->end_stmt = end_copy.get();
            offloaded->body->insert(std::move(end_copy));
          }
          offloaded_ranges.end_stmts.insert(
              std::make_pair(offloaded.get(), s->end));
        }

        offloaded->num_cpu_threads =
            std::min(s->num_cpu_threads, config.cpu_max_num_threads);
        replace_all_usages_with(s, s, offloaded.get());
        for (int j = 0; j < (int)s->body->statements.size(); j++) {
          offloaded->body->insert(std::move(s->body->statements[j]));
        }
        offloaded->range_hint = s->range_hint;
        root_block->insert(std::move(offloaded));
      } else if (auto st = stmt->cast<StructForStmt>()) {
        assemble_serial_statements();
        emit_struct_for(st, root_block, config, st->mem_access_opt);
      } else if (auto st = stmt->cast<MeshForStmt>()) {
        assemble_serial_statements();
        auto offloaded = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::mesh_for, arch);
        offloaded->grid_dim = config.saturating_grid_dim;
        if (st->block_dim == 0) {
          offloaded->block_dim = Program::default_block_dim(config);
        } else {
          offloaded->block_dim = st->block_dim;
        }
        offloaded->num_cpu_threads =
            std::min(st->num_cpu_threads, config.cpu_max_num_threads);
        replace_all_usages_with(st, st, offloaded.get());
        for (int j = 0; j < (int)st->body->statements.size(); j++) {
          offloaded->body->insert(std::move(st->body->statements[j]));
        }
        offloaded->mesh = st->mesh;
        offloaded->major_from_type = std::move(st->major_from_type);
        offloaded->major_to_types = std::move(st->major_to_types);
        offloaded->minor_relation_types = std::move(st->minor_relation_types);
        offloaded->mem_access_opt = st->mem_access_opt;
        root_block->insert(std::move(offloaded));
      } else {
        pending_serial_statements->body->insert(std::move(stmt));
      }
    }
    assemble_serial_statements();
    return offloaded_ranges;
  }

 private:
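  // For a struct-for over |for_stmt->snode|, emits a clear-list + listgen task
  // pair for every SNode on the path from the root to the leaf block (unless
  // the loop will later be demoted to a range-for), followed by the struct_for
  // task itself.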
  static void emit_struct_for(StructForStmt *for_stmt,
                              Block *root_block,
                              const CompileConfig &config,
                              const MemoryAccessOptions &mem_access_opt) {
    auto leaf = for_stmt->snode;
    // make a list of nodes, from the leaf block (instead of 'place') to root
    std::vector<SNode *> path;
    // leaf is the place (scalar)
    // leaf->parent is the leaf block
    // so listgen should be invoked from the root to leaf->parent
    for (auto p = leaf; p; p = p->parent) {
      path.push_back(p);
    }
    std::reverse(path.begin(), path.end());

    // If |demotable| is true, this loop will later be demoted into a range-for
    // task, so we don't need to generate clear/listgen tasks.
    const bool demotable =
        (leaf->is_path_all_dense && config.demote_dense_struct_fors);
    const auto arch = config.arch;
    if (!demotable) {
      for (int i = 1; i < path.size(); i++) {
        auto snode_child = path[i];
        if (snode_child->type == SNodeType::quant_array &&
            for_stmt->is_bit_vectorized) {
          TI_ASSERT(i == path.size() - 1);
          continue;
        }
        auto offloaded_clear_list = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::serial, arch);
        offloaded_clear_list->body->insert(
            Stmt::make<ClearListStmt>(snode_child));
        offloaded_clear_list->grid_dim = 1;
        offloaded_clear_list->block_dim = 1;
        // Intentionally do not set offloaded_clear_list->snode, so that there
        // is nothing special about this task, which could otherwise cause
        // problems when fused with other serial tasks.
        root_block->insert(std::move(offloaded_clear_list));
        auto offloaded_listgen = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::listgen, arch);
        offloaded_listgen->snode = snode_child;
        offloaded_listgen->grid_dim = config.saturating_grid_dim;
        offloaded_listgen->block_dim =
            std::min(snode_child->max_num_elements(),
                     (int64)std::min(Program::default_block_dim(config),
                                     config.max_block_dim));
        root_block->insert(std::move(offloaded_listgen));
      }
    }

    auto offloaded_struct_for = Stmt::make_typed<OffloadedStmt>(
        OffloadedStmt::TaskType::struct_for, arch);

    offloaded_struct_for->index_offsets = for_stmt->index_offsets;

    offloaded_struct_for->grid_dim = config.saturating_grid_dim;

    const auto snode_num_elements = for_stmt->snode->max_num_elements();
    if (for_stmt->block_dim == 0) {
      // adaptive
      offloaded_struct_for->block_dim =
          std::min(snode_num_elements, (int64)config.default_gpu_block_dim);
    } else {
      if (for_stmt->block_dim > snode_num_elements) {
        TI_WARN(
            "Specified block dim {} is bigger than SNode element size {}. "
            "Clipping.\n{}",
            for_stmt->block_dim, snode_num_elements, for_stmt->tb);
        offloaded_struct_for->block_dim = snode_num_elements;
      } else {
        offloaded_struct_for->block_dim = for_stmt->block_dim;
      }
    }

    replace_all_usages_with(for_stmt, for_stmt, offloaded_struct_for.get());

    for (int i = 0; i < (int)for_stmt->body->statements.size(); i++) {
      offloaded_struct_for->body->insert(
          std::move(for_stmt->body->statements[i]));
    }

    offloaded_struct_for->snode = for_stmt->snode;
    offloaded_struct_for->is_bit_vectorized = for_stmt->is_bit_vectorized;
    offloaded_struct_for->num_cpu_threads =
        std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads);
    offloaded_struct_for->mem_access_opt = mem_access_opt;

    root_block->insert(std::move(offloaded_struct_for));
  }
};

// Builds a mapping from each statement to its containing OffloadedStmt.
class StmtToOffloaded : public BasicStmtVisitor {
 private:
  StmtToOffloaded() {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
    current_offloaded_ = nullptr;
  }

 public:
  void visit(OffloadedStmt *stmt) override {
    current_offloaded_ = stmt;
    stmt_to_offloaded_[stmt] = current_offloaded_;
    if (stmt->body)
      stmt->body->accept(this);
    current_offloaded_ = nullptr;
  }

  void visit(Stmt *stmt) override {
    if (current_offloaded_ != nullptr) {
      // Inside an offloaded stmt; record its enclosing OffloadedStmt.
      stmt_to_offloaded_[stmt] = current_offloaded_;
    }
  }

  void preprocess_container_stmt(Stmt *stmt) override {
    if (current_offloaded_ != nullptr) {
      // Inside an offloaded stmt; record its enclosing OffloadedStmt.
      stmt_to_offloaded_[stmt] = current_offloaded_;
    }
  }

 public:
  static std::unordered_map<Stmt *, Stmt *> run(IRNode *ir) {
    StmtToOffloaded pass;
    ir->accept(&pass);
    return pass.stmt_to_offloaded_;
  }

 private:
  using BasicStmtVisitor::visit;

  // Maps each visited statement to its containing offloaded statement.
  std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded_;

  Stmt *current_offloaded_;
};

/*
After offloading, some local variables/instructions are accessed across
offloaded blocks. This pass promotes these local values to global variables.

Steps:
  1. IdentifyValuesUsedInOtherOffloads
  2. PromoteIntermediateToGlobalTmp
  3. FixCrossOffloadReferences
*/
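//
// Roughly (an illustrative sketch, not actual IR syntax), a value $x produced
// in offloaded task A and used in offloaded task B is rewritten from
//
//   A: $x = ...                 B: use($x)
//
// into
//
//   A: $x = ...; store gtmp[offset], $x
//   B: $y = load gtmp[offset];  use($y)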

// Traverses offloaded blocks to identify local loads/stores and statement
// references that cross offloaded-task boundaries.
class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
  using BasicStmtVisitor::visit;

 private:
  IdentifyValuesUsedInOtherOffloads(
      const CompileConfig &config,
      const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
      OffloadedRanges *offloaded_ranges)
      : config_(config),
        stmt_to_offloaded_(stmt_to_offloaded),
        offloaded_ranges_(offloaded_ranges) {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
    current_offloaded_ = nullptr;
    global_offset_ = 0;
  }

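  // Reserves room in the global temporaries buffer for one value of |type| and
  // returns its byte offset. Scalar allocations are aligned to their own size;
  // tensor allocations take num_elements * element_size bytes.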
  std::size_t allocate_global(DataType type) {
    auto ret = global_offset_;
    if (type->is<TensorType>()) {
      auto tensor_type = type->cast<TensorType>();
      global_offset_ += tensor_type->get_num_elements() *
                        data_type_size(tensor_type->get_element_type());
    } else {
      std::size_t type_size = data_type_size(type);
      // align global_offset to a multiple of type_size
      global_offset_ =
          ((global_offset_ + type_size - 1) / type_size) * type_size;
      ret = global_offset_;
      global_offset_ += type_size;
    }
    TI_ASSERT(global_offset_ < taichi_global_tmp_buffer_size);
    return ret;
  }

 public:
  void visit(OffloadedStmt *stmt) override {
    current_offloaded_ = stmt;
    if (auto begin = offloaded_ranges_->begin_stmts.find(stmt);
        begin != offloaded_ranges_->begin_stmts.end()) {
      test_and_allocate(begin->second);
    }
    if (auto end = offloaded_ranges_->end_stmts.find(stmt);
        end != offloaded_ranges_->end_stmts.end()) {
      test_and_allocate(end->second);
    }
    if (stmt->body)
      stmt->body->accept(this);
    current_offloaded_ = nullptr;
  }

  void visit(AllocaStmt *stmt) override {
    TI_ASSERT(current_offloaded_);
  }

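  // If |stmt| is defined in a different offloaded task than the current one
  // and cannot simply be cloned into it (it is not a constant, a pointer, or a
  // demotable axis-size expression on opengl/vulkan/gles), reserve a global
  // temporary slot for its top-level pointer.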
  void test_and_allocate(Stmt *stmt) {
    if (stmt == nullptr)
      return;
    if (stmt_to_offloaded_[stmt] == current_offloaded_)
      return;
    // Directly insert copies of ConstStmts later
    if (stmt->is<ConstStmt>())
      return;
    auto top_level_ptr = SquashPtrOffset::run(stmt);
    // We don't support storing a pointer for now.
    if (top_level_ptr->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
        (stmt->is<ArgLoadStmt>() && stmt->as<ArgLoadStmt>()->is_ptr))
      return;
    if ((config_.arch == Arch::opengl || config_.arch == Arch::vulkan ||
         config_.arch == Arch::gles) &&
        demotable_axis_load(stmt))
      return;
    // Not yet allocated
    if (local_to_global_.find(top_level_ptr) == local_to_global_.end()) {
      local_to_global_[top_level_ptr] =
          allocate_global(top_level_ptr->ret_type);
    }
  }

  void generic_visit(Stmt *stmt) {
    int n_op = stmt->num_operands();
    for (int i = 0; i < n_op; i++) {
      auto op = stmt->operand(i);
      test_and_allocate(op);
    }
  }

  void preprocess_container_stmt(Stmt *stmt) override {
    generic_visit(stmt);
  }

  void visit(Stmt *stmt) override {
    generic_visit(stmt);
  }

  static StmtToOffsetMap run(
      IRNode *root,
      const CompileConfig &config,
      const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
      OffloadedRanges *offloaded_ranges) {
    IdentifyValuesUsedInOtherOffloads pass(config, stmt_to_offloaded,
                                           offloaded_ranges);
    root->accept(&pass);
    return pass.local_to_global_;
  }

 private:
  CompileConfig config_;
  std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded_;
  OffloadedRanges *const offloaded_ranges_;
  // Maps local values (their top-level pointers) to global temporary offsets
  // (in bytes).
  StmtToOffsetMap local_to_global_;
  Stmt *current_offloaded_;
  std::size_t global_offset_;
};

// Stores intermediate values to global temporaries so that statements in
// later offloaded tasks can load them.
class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor {
  using BasicStmtVisitor::visit;

 private:
  explicit PromoteIntermediateToGlobalTmp(
      const StmtToOffsetMap &local_to_global_offset)
      : local_to_global_offset_(local_to_global_offset) {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }

 public:
  void visit(Stmt *stmt) override {
    if (!stmt->is<AllocaStmt>() &&
        local_to_global_offset_.find(stmt) != local_to_global_offset_.end() &&
        stored_to_global_.find(stmt) == stored_to_global_.end()) {
      stored_to_global_.insert(stmt);
      auto offset = local_to_global_offset_[stmt];
      auto ptr = stmt->insert_after_me(
          Stmt::make<GlobalTemporaryStmt>(offset, stmt->ret_type));
      ptr->insert_after_me(Stmt::make<GlobalStoreStmt>(ptr, stmt));
    }
  }

  static void run(IRNode *root, const StmtToOffsetMap &local_to_global_offset) {
    PromoteIntermediateToGlobalTmp pass(local_to_global_offset);
    root->accept(&pass);
  }

 private:
  StmtToOffsetMap local_to_global_offset_;
  std::set<Stmt *> stored_to_global_;
};

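// Rewrites references that cross offloaded-task boundaries:
//  - promoted AllocaStmts become zero-initialized global temporaries,
//  - local loads/stores on them become global loads/stores,
//  - operands defined in another task are either cloned into the current task
//    (e.g. constants and pointers) or re-loaded from their global temporary.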
class FixCrossOffloadReferences : public BasicStmtVisitor {
  using BasicStmtVisitor::visit;

 private:
  FixCrossOffloadReferences(
      const CompileConfig &config,
      const StmtToOffsetMap &local_to_global_offset,
      const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
      OffloadedRanges *offloaded_ranges)
      : config_(config),
        local_to_global_offset_(local_to_global_offset),
        stmt_to_offloaded_(stmt_to_offloaded),
        offloaded_ranges_(offloaded_ranges) {
    allow_undefined_visitor = true;
    invoke_default_visitor = true;
  }

  void visit(OffloadedStmt *stmt) override {
    if (stmt->body)
      stmt->body->accept(this);
    if (stmt->task_type == OffloadedStmt::TaskType::range_for) {
      if (!stmt->const_begin) {
        TI_ASSERT(offloaded_ranges_->begin_stmts.find(stmt) !=
                  offloaded_ranges_->begin_stmts.end())
        TI_ASSERT_INFO(local_to_global_offset_.find(
                           offloaded_ranges_->begin_stmts.find(stmt)->second) !=
                           local_to_global_offset_.end(),
                       "Begin fails.")
        stmt->begin_offset =
            local_to_global_offset_[offloaded_ranges_->begin_stmts.find(stmt)
                                        ->second];
      }
      if (!stmt->const_end) {
        if (stmt->end_stmt) {
          stmt->end_stmt->accept(this);
          stmt->end_offset = 0;
        } else {
          TI_ASSERT(offloaded_ranges_->end_stmts.find(stmt) !=
                    offloaded_ranges_->end_stmts.end())
          TI_ASSERT_INFO(local_to_global_offset_.find(
                             offloaded_ranges_->end_stmts.find(stmt)->second) !=
                             local_to_global_offset_.end(),
                         "End fails.")
          stmt->end_offset =
              local_to_global_offset_[offloaded_ranges_->end_stmts.find(stmt)
                                          ->second];
        }
      }
    }
  }

  // Replace alloca with global var initialization (set to 0)
  void visit(AllocaStmt *stmt) override {
    if (local_to_global_offset_.find(stmt) == local_to_global_offset_.end())
      return;
    VecStatement replacement;
    auto ret_type = stmt->ret_type;
    local_to_global_vector_type_[stmt] = ret_type;
    auto ptr = replacement.push_back<GlobalTemporaryStmt>(
        local_to_global_offset_[stmt], ret_type);
    auto offloaded = stmt_to_offloaded_[stmt];
    stmt_to_offloaded_[ptr] = offloaded;
    if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
      TypedConstant zero(tensor_type->get_element_type());
      auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
      stmt_to_offloaded_[const_zero_stmt] = offloaded;
      for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
        auto const_offset_stmt =
            replacement.push_back<ConstStmt>(TypedConstant(i));
        auto ptr_offset_stmt =
            replacement.push_back<MatrixPtrStmt>(ptr, const_offset_stmt);
        auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
            ptr_offset_stmt, const_zero_stmt);
        stmt_to_offloaded_[const_offset_stmt] = offloaded;
        stmt_to_offloaded_[ptr_offset_stmt] = offloaded;
        stmt_to_offloaded_[global_store_stmt] = offloaded;
      }
    } else {
      TypedConstant zero(stmt->ret_type);
      auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
      auto global_store_stmt =
          replacement.push_back<GlobalStoreStmt>(ptr, const_zero_stmt);
      stmt_to_offloaded_[global_store_stmt] = offloaded;
    }

    stmt->parent->replace_with(stmt, std::move(replacement), false);
    // Unset the mapping so that visit_operand() will rewrite any remaining
    // uses of the (now replaced) alloca, even within the same offloaded task.
    stmt_to_offloaded_[stmt] = nullptr;
  }

  // Replace local LD/ST with global LD/ST
  void visit(LocalLoadStmt *stmt) override {
    generic_visit(stmt);
    auto ptr = stmt->src;
    auto top_level_ptr = SquashPtrOffset::run(ptr);
    if (top_level_ptr->is<GlobalTemporaryStmt>()) {
      VecStatement replacement;
      auto global_load = replacement.push_back<GlobalLoadStmt>(ptr);
      stmt_to_offloaded_[global_load] = stmt_to_offloaded_[stmt];
      stmt->parent->replace_with(stmt, std::move(replacement));
    }
  }

  void visit(LocalStoreStmt *stmt) override {
    generic_visit(stmt);
    auto ptr = stmt->dest;
    auto top_level_ptr = SquashPtrOffset::run(ptr);
    if (top_level_ptr->is<GlobalTemporaryStmt>()) {
      VecStatement replacement;
      auto global_store =
          replacement.push_back<GlobalStoreStmt>(ptr, stmt->val);
      stmt_to_offloaded_[global_store] = stmt_to_offloaded_[stmt];
      stmt->parent->replace_with(stmt, std::move(replacement));
    }
  }

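  // Rewrites |stmt|'s |index|-th operand if it was defined in a different
  // offloaded task: GlobalPtrStmts and non-promoted statements are cloned in
  // place, while promoted statements are replaced with a GlobalTemporaryStmt
  // (plus a GlobalLoadStmt when a value rather than an address is needed).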
  bool visit_operand(Stmt *stmt, int index) {
    // return true if modified
    TI_ASSERT(index >= 0 && index < stmt->num_operands());
    auto op = stmt->operand(index);
    if (op == nullptr)
      return false;
    if (stmt_to_offloaded_[stmt] ==
        stmt_to_offloaded_[op])  // same OffloadedStmt
      return false;

    auto offloaded = stmt_to_offloaded_[stmt];

    if (op->is<GlobalPtrStmt>()) {
      auto copy = op->clone();
      auto pcopy = copy.get();
      copy->as<GlobalPtrStmt>()->activate = false;
      stmt_to_offloaded_[copy.get()] = offloaded;
      stmt->set_operand(index, copy.get());
      stmt->insert_before_me(std::move(copy));
      generic_visit(pcopy);
      return true;
    }

    if (local_to_global_offset_.find(op) == local_to_global_offset_.end()) {
      // For stmts that are not promoted to global temporaries, clone them into
      // the current offloaded task, e.g.
      // ConstStmt/MatrixPtrStmt/GlobalTemporaryStmt/ExternalTensorShapeAlongAxisStmt.
      auto copy = op->clone();
      auto pcopy = copy.get();
      stmt_to_offloaded_[copy.get()] = offloaded;
      stmt->set_operand(index, copy.get());
      stmt->insert_before_me(std::move(copy));
      generic_visit(pcopy);
    } else {
      auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
          local_to_global_offset_[op], op->ret_type);
      stmt_to_offloaded_[global_temporary.get()] = offloaded;
      stmt->set_operand(index, global_temporary.get());
      if (op->is<AllocaStmt>() || op->ret_type.is_pointer()) {
        // For an AllocaStmt (TensorType or scalar), which will be followed by
        // a LocalLoadStmt; avoid inserting a redundant load here.
        stmt->insert_before_me(std::move(global_temporary));
      } else {
        // For other cases, such as ArgLoadStmt or UnaryOpStmt, whose value
        // needs to be loaded.
        auto load = Stmt::make<GlobalLoadStmt>(global_temporary.get());
        stmt_to_offloaded_[load.get()] = offloaded;
        stmt->set_operand(index, load.get());
        stmt->insert_before_me(std::move(global_temporary));
        stmt->insert_before_me(std::move(load));
      }
    }
    return true;
  }

  void generic_visit(Stmt *stmt) {
    int n_op = stmt->num_operands();
    for (int i = 0; i < n_op; i++) {
      visit_operand(stmt, i);
    }
  }

  void visit(Stmt *stmt) override {
    generic_visit(stmt);
  }

  void preprocess_container_stmt(Stmt *stmt) override {
    generic_visit(stmt);
  }

 public:
  static void run(IRNode *root,
                  const CompileConfig &config,
                  const StmtToOffsetMap &local_to_global_offset,
                  const std::unordered_map<Stmt *, Stmt *> &stmt_to_offloaded,
                  OffloadedRanges *offloaded_ranges) {
    FixCrossOffloadReferences pass(config, local_to_global_offset,
                                   stmt_to_offloaded, offloaded_ranges);
    root->accept(&pass);
  }

 private:
  [[maybe_unused]] const CompileConfig &config_;
  StmtToOffsetMap local_to_global_offset_;
  std::unordered_map<Stmt *, Stmt *> stmt_to_offloaded_;
  OffloadedRanges *const offloaded_ranges_;
  std::unordered_map<Stmt *, DataType> local_to_global_vector_type_;
};

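// Inserts a garbage-collection task after every top-level task that may
// deactivate a gc-able SNode, plus a final gc_rc task if the kernel contains
// function calls.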
void insert_gc(IRNode *root, const CompileConfig &config) {
  auto *b = dynamic_cast<Block *>(root);
  TI_ASSERT(b);
  std::vector<std::pair<int, std::vector<SNode *>>> gc_statements;
  for (int i = 0; i < (int)b->statements.size(); i++) {
    auto snodes =
        irpass::analysis::gather_deactivations(b->statements[i].get());
    gc_statements.emplace_back(
        std::make_pair(i, std::vector<SNode *>(snodes.begin(), snodes.end())));
  }

  for (int i = (int)b->statements.size() - 1; i >= 0; i--) {
    auto snodes = gc_statements[i].second;
    for (auto *snode : snodes) {
      if (is_gc_able(snode->type)) {
        auto gc_task = Stmt::make_typed<OffloadedStmt>(
            OffloadedStmt::TaskType::gc, config.arch);
        gc_task->snode = snode;
        b->insert(std::move(gc_task), i + 1);
      }
    }
  }
  if (!irpass::analysis::gather_statements(root, [](Stmt *stmt) {
         return stmt->is<FuncCallStmt>();
       }).empty()) {
    auto gc_task = Stmt::make_typed<OffloadedStmt>(
        OffloadedStmt::TaskType::gc_rc, config.arch);
    b->insert(std::move(gc_task));
  }
}

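// Binds ContinueStmts that have no explicit scope to their nearest enclosing
// while/range-for loop, or to the surrounding OffloadedStmt when the continue
// appears at the top level of the task.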
class AssociateContinueScope : public BasicStmtVisitor {
 public:
  using BasicStmtVisitor::visit;
  using Parent = BasicStmtVisitor;

  void visit(WhileStmt *stmt) override {
    auto *old_loop = cur_internal_loop_;
    cur_internal_loop_ = stmt;
    Parent::visit(stmt);
    cur_internal_loop_ = old_loop;
  }

  void visit(RangeForStmt *stmt) override {
    auto *old_loop = cur_internal_loop_;
    cur_internal_loop_ = stmt;
    Parent::visit(stmt);
    cur_internal_loop_ = old_loop;
  }

  void visit(StructForStmt *stmt) override {
    TI_ERROR("struct_for cannot be nested inside a kernel, stmt={}",
             stmt->name());
  }

  void visit(OffloadedStmt *stmt) override {
    TI_ASSERT(cur_offloaded_stmt_ == nullptr);
    TI_ASSERT(cur_internal_loop_ == nullptr);
    cur_offloaded_stmt_ = stmt;
    Parent::visit(stmt);
    cur_offloaded_stmt_ = nullptr;
  }

  void visit(ContinueStmt *stmt) override {
    if (stmt->scope == nullptr) {
      if (cur_internal_loop_ != nullptr) {
        stmt->scope = cur_internal_loop_;
      } else {
        stmt->scope = cur_offloaded_stmt_;
      }
      modified_ = true;
    }
    TI_ASSERT(stmt->scope != nullptr);
  }

  static void run(IRNode *root) {
    while (true) {
      AssociateContinueScope pass;
      root->accept(&pass);
      if (!pass.modified_) {
        break;
      }
    }
  }

 private:
  explicit AssociateContinueScope()
      : modified_(false),
        cur_offloaded_stmt_(nullptr),
        cur_internal_loop_(nullptr) {
  }

  bool modified_;
  OffloadedStmt *cur_offloaded_stmt_;
  Stmt *cur_internal_loop_;
};

}  // namespace

void offload(IRNode *root, const CompileConfig &config) {
  TI_AUTO_PROF;
  auto offloaded_ranges = Offloader::run(root, config);
  type_check(root, config);
  {
    auto stmt_to_offloaded = StmtToOffloaded::run(root);
    const auto local_to_global_offset = IdentifyValuesUsedInOtherOffloads::run(
        root, config, stmt_to_offloaded, &offloaded_ranges);
    PromoteIntermediateToGlobalTmp::run(root, local_to_global_offset);
    stmt_to_offloaded = StmtToOffloaded::run(root);
    FixCrossOffloadReferences::run(root, config, local_to_global_offset,
                                   stmt_to_offloaded, &offloaded_ranges);
  }
  insert_gc(root, config);
  // TODO(k-ye): Move this into its own pass. However, we need to wait for all
  // backends to integrate with https://github.com/taichi-dev/taichi/pull/700
  AssociateContinueScope::run(root);
  type_check(root, config);
  re_id(root);
}

}  // namespace irpass

}  // namespace taichi::lang