1#include "taichi/ir/ir.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/transforms/utils.h"
7
8namespace taichi::lang {
9
10namespace {
11
12void convert_to_range_for(OffloadedStmt *offloaded) {
13 TI_ASSERT(offloaded->task_type == OffloadedTaskType::mesh_for);
14
15 DelayedIRModifier modifier;
16 auto stmts = irpass::analysis::gather_statements(
17 offloaded->body.get(),
18 [&](Stmt *stmt) { return stmt->is<MeshIndexConversionStmt>(); });
19 for (size_t i = 0; i < stmts.size(); ++i) {
20 auto conv_stmt = stmts[i]->cast<MeshIndexConversionStmt>();
21 if (conv_stmt->conv_type == mesh::ConvType::l2g) {
22 stmts[i]->replace_usages_with(conv_stmt->idx);
23 modifier.erase(stmts[i]);
24 } else if (conv_stmt->conv_type == mesh::ConvType::l2r) {
25 stmts[i]->as<MeshIndexConversionStmt>()->conv_type = mesh::ConvType::g2r;
26 }
27 }
28
29 modifier.modify_ir();
30
31 offloaded->const_begin = true;
32 offloaded->const_end = true;
33 offloaded->begin_value = 0;
34 offloaded->end_value =
35 offloaded->mesh->num_elements.find(offloaded->major_from_type)->second;
36 offloaded->mesh = nullptr;
37 offloaded->task_type = OffloadedTaskType::range_for;
38}
39
40void maybe_convert(OffloadedStmt *offloaded) {
41 if (offloaded->task_type == OffloadedTaskType::mesh_for &&
42 offloaded->major_to_types.size() == 0) {
43 auto stmts = irpass::analysis::gather_statements( // ti.mesh_patch_idx()
44 // relies on mesh-for
45 offloaded->body.get(),
46 [&](Stmt *stmt) { return stmt->is<MeshPatchIndexStmt>(); });
47 if (stmts.size() == 0) {
48 convert_to_range_for(offloaded);
49 }
50 }
51}
52
53} // namespace
54
55namespace irpass {
56
57void demote_no_access_mesh_fors(IRNode *root) {
58 if (auto *block = root->cast<Block>()) {
59 for (auto &s_ : block->statements) {
60 if (auto *s = s_->cast<OffloadedStmt>()) {
61 maybe_convert(s);
62 }
63 }
64 } else if (auto *s = root->cast<OffloadedStmt>()) {
65 maybe_convert(s);
66 }
67 re_id(root);
68}
69
70} // namespace irpass
71
72} // namespace taichi::lang
73