1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/transforms/utils.h"
6
7namespace taichi::lang {
8
9namespace {
10
11using TaskType = OffloadedStmt::TaskType;
12
// Converts a dense struct-for offloaded task into an equivalent range-for
// over the flattened cell index [0, total_n). The original logical loop
// indices are reconstructed from the linear index by statements prepended
// to the loop body.
void convert_to_range_for(OffloadedStmt *offloaded) {
  TI_ASSERT(offloaded->task_type == TaskType::struct_for);

  // Walk from the loop SNode up to (but excluding) the root, collecting the
  // SNode path while accumulating the total cell count and the combined
  // shape along every physical index.
  std::vector<SNode *> snodes;
  auto *snode = offloaded->snode;
  int64 total_n = 1;
  std::array<int, taichi_max_num_indices> total_shape;
  total_shape.fill(1);
  while (snode->type != SNodeType::root) {
    snodes.push_back(snode);
    for (int j = 0; j < taichi_max_num_indices; j++) {
      total_shape[j] *= snode->extractors[j].shape;
    }
    total_n *= snode->num_cells_per_container;
    snode = snode->parent;
  }
  // The flattened range is stored in int fields (begin_value/end_value), so
  // it must fit in a 32-bit int.
  TI_ASSERT(total_n <= std::numeric_limits<int>::max());
  // Reorder the path so snodes[0] is the topmost SNode (just below root)
  // and snodes.back() is the leaf being looped over.
  std::reverse(snodes.begin(), snodes.end());

  // The resulting range-for iterates over the constant range [0, total_n).
  offloaded->const_begin = true;
  offloaded->const_end = true;
  offloaded->begin_value = 0;
  offloaded->end_value = total_n;

  ////// Begin core transformation
  auto body = std::move(offloaded->body);
  // Number of logical loop indices, taken from the leaf SNode.
  const int num_loop_vars =
      snodes.empty() ? 0 : snodes.back()->num_active_indices;

  // new_loop_vars[j] holds the partially reconstructed value of the j-th
  // logical index; components are accumulated into it level by level below.
  std::vector<Stmt *> new_loop_vars;

  // Index-decoding statements to be prepended to the loop body.
  VecStatement body_header;

  // Physical index position for each logical loop index of the leaf SNode.
  std::vector<int> physical_indices;

  // Start every reconstructed index at 0.
  for (int i = 0; i < num_loop_vars; i++) {
    new_loop_vars.push_back(body_header.push_back<ConstStmt>(TypedConstant(0)));
    physical_indices.push_back(snodes.back()->physical_index_position[i]);
  }

  // The linear cell index supplied by the enclosing range-for.
  auto main_loop_var = body_header.push_back<LoopIndexStmt>(nullptr, 0);
  // We will set main_loop_var->loop later.

  // Peel one SNode level at a time, from top to leaf. At each level,
  // `extracted` becomes this level's cell index within its container:
  // mod by the remaining cell count (skipped at the top level, where the
  // linear index is already in range), then div by the cell count of the
  // levels below.
  for (int i = 0; i < (int)snodes.size(); i++) {
    auto snode = snodes[i];
    Stmt *extracted = main_loop_var;
    if (i != 0) { // first extraction doesn't need a mod
      extracted = generate_mod(&body_header, extracted, total_n);
    }
    // After this division, total_n is the number of cells contained in one
    // cell at the current level.
    total_n /= snode->num_cells_per_container;
    extracted = generate_div(&body_header, extracted, total_n);
    bool is_first_extraction = true;
    // Split this level's cell index into per-axis components and add each
    // component, scaled by its global stride, into the corresponding
    // reconstructed loop index.
    for (int j = 0; j < (int)physical_indices.size(); j++) {
      auto p = physical_indices[j];
      auto ext = snode->extractors[p];
      if (!ext.active)
        continue;
      Stmt *index = extracted;
      if (is_first_extraction) { // first extraction doesn't need a mod
        is_first_extraction = false;
      } else {
        index = generate_mod(&body_header, index, ext.acc_shape * ext.shape);
      }
      index = generate_div(&body_header, index, ext.acc_shape);
      // total_shape[p] shrinks to the shape remaining below this level along
      // axis p, which is exactly this component's stride in the global index.
      total_shape[p] /= ext.shape;
      auto multiplier =
          body_header.push_back<ConstStmt>(TypedConstant(total_shape[p]));
      auto delta = body_header.push_back<BinaryOpStmt>(BinaryOpType::mul, index,
                                                       multiplier);
      new_loop_vars[j] = body_header.push_back<BinaryOpStmt>(
          BinaryOpType::add, new_loop_vars[j], delta);
    }
  }

  // Rewrite every LoopIndexStmt that referred to this struct-for so it uses
  // the corresponding reconstructed index instead.
  irpass::replace_statements(
      body.get(), /*filter=*/
      [&](Stmt *s) {
        if (auto loop_index = s->cast<LoopIndexStmt>()) {
          return loop_index->loop == offloaded;
        } else {
          return false;
        }
      },
      /*finder=*/
      [&](Stmt *s) {
        auto index = std::find(physical_indices.begin(), physical_indices.end(),
                               s->as<LoopIndexStmt>()->index);
        TI_ASSERT(index != physical_indices.end());
        return new_loop_vars[index - physical_indices.begin()];
      });

  // Prepend the decoding statements, re-attach the body, and point the main
  // loop variable at the (now range-for) offloaded task.
  body->insert(std::move(body_header), 0);

  offloaded->body = std::move(body);
  offloaded->body->parent_stmt = offloaded;
  main_loop_var->loop = offloaded;
  ////// End core transformation

  offloaded->task_type = TaskType::range_for;
}
113
114void maybe_convert(OffloadedStmt *stmt) {
115 if ((stmt->task_type == TaskType::struct_for) &&
116 stmt->snode->is_path_all_dense) {
117 convert_to_range_for(stmt);
118 }
119}
120
121} // namespace
122
123namespace irpass {
124
125void demote_dense_struct_fors(IRNode *root) {
126 if (auto *block = root->cast<Block>()) {
127 for (auto &s_ : block->statements) {
128 if (auto *s = s_->cast<OffloadedStmt>()) {
129 maybe_convert(s);
130 }
131 }
132 } else if (auto *s = root->cast<OffloadedStmt>()) {
133 maybe_convert(s);
134 }
135 re_id(root);
136}
137
138} // namespace irpass
139
140} // namespace taichi::lang
141