1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/transforms/utils.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | namespace { |
10 | |
11 | using TaskType = OffloadedStmt::TaskType; |
12 | |
13 | void convert_to_range_for(OffloadedStmt *offloaded) { |
14 | TI_ASSERT(offloaded->task_type == TaskType::struct_for); |
15 | |
16 | std::vector<SNode *> snodes; |
17 | auto *snode = offloaded->snode; |
18 | int64 total_n = 1; |
19 | std::array<int, taichi_max_num_indices> total_shape; |
20 | total_shape.fill(1); |
21 | while (snode->type != SNodeType::root) { |
22 | snodes.push_back(snode); |
23 | for (int j = 0; j < taichi_max_num_indices; j++) { |
24 | total_shape[j] *= snode->extractors[j].shape; |
25 | } |
26 | total_n *= snode->num_cells_per_container; |
27 | snode = snode->parent; |
28 | } |
29 | TI_ASSERT(total_n <= std::numeric_limits<int>::max()); |
30 | std::reverse(snodes.begin(), snodes.end()); |
31 | |
32 | offloaded->const_begin = true; |
33 | offloaded->const_end = true; |
34 | offloaded->begin_value = 0; |
35 | offloaded->end_value = total_n; |
36 | |
37 | ////// Begin core transformation |
38 | auto body = std::move(offloaded->body); |
39 | const int num_loop_vars = |
40 | snodes.empty() ? 0 : snodes.back()->num_active_indices; |
41 | |
42 | std::vector<Stmt *> new_loop_vars; |
43 | |
44 | VecStatement body_header; |
45 | |
46 | std::vector<int> physical_indices; |
47 | |
48 | for (int i = 0; i < num_loop_vars; i++) { |
49 | new_loop_vars.push_back(body_header.push_back<ConstStmt>(TypedConstant(0))); |
50 | physical_indices.push_back(snodes.back()->physical_index_position[i]); |
51 | } |
52 | |
53 | auto main_loop_var = body_header.push_back<LoopIndexStmt>(nullptr, 0); |
54 | // We will set main_loop_var->loop later. |
55 | |
56 | for (int i = 0; i < (int)snodes.size(); i++) { |
57 | auto snode = snodes[i]; |
58 | Stmt * = main_loop_var; |
59 | if (i != 0) { // first extraction doesn't need a mod |
60 | extracted = generate_mod(&body_header, extracted, total_n); |
61 | } |
62 | total_n /= snode->num_cells_per_container; |
63 | extracted = generate_div(&body_header, extracted, total_n); |
64 | bool = true; |
65 | for (int j = 0; j < (int)physical_indices.size(); j++) { |
66 | auto p = physical_indices[j]; |
67 | auto ext = snode->extractors[p]; |
68 | if (!ext.active) |
69 | continue; |
70 | Stmt *index = extracted; |
71 | if (is_first_extraction) { // first extraction doesn't need a mod |
72 | is_first_extraction = false; |
73 | } else { |
74 | index = generate_mod(&body_header, index, ext.acc_shape * ext.shape); |
75 | } |
76 | index = generate_div(&body_header, index, ext.acc_shape); |
77 | total_shape[p] /= ext.shape; |
78 | auto multiplier = |
79 | body_header.push_back<ConstStmt>(TypedConstant(total_shape[p])); |
80 | auto delta = body_header.push_back<BinaryOpStmt>(BinaryOpType::mul, index, |
81 | multiplier); |
82 | new_loop_vars[j] = body_header.push_back<BinaryOpStmt>( |
83 | BinaryOpType::add, new_loop_vars[j], delta); |
84 | } |
85 | } |
86 | |
87 | irpass::replace_statements( |
88 | body.get(), /*filter=*/ |
89 | [&](Stmt *s) { |
90 | if (auto loop_index = s->cast<LoopIndexStmt>()) { |
91 | return loop_index->loop == offloaded; |
92 | } else { |
93 | return false; |
94 | } |
95 | }, |
96 | /*finder=*/ |
97 | [&](Stmt *s) { |
98 | auto index = std::find(physical_indices.begin(), physical_indices.end(), |
99 | s->as<LoopIndexStmt>()->index); |
100 | TI_ASSERT(index != physical_indices.end()); |
101 | return new_loop_vars[index - physical_indices.begin()]; |
102 | }); |
103 | |
104 | body->insert(std::move(body_header), 0); |
105 | |
106 | offloaded->body = std::move(body); |
107 | offloaded->body->parent_stmt = offloaded; |
108 | main_loop_var->loop = offloaded; |
109 | ////// End core transformation |
110 | |
111 | offloaded->task_type = TaskType::range_for; |
112 | } |
113 | |
114 | void maybe_convert(OffloadedStmt *stmt) { |
115 | if ((stmt->task_type == TaskType::struct_for) && |
116 | stmt->snode->is_path_all_dense) { |
117 | convert_to_range_for(stmt); |
118 | } |
119 | } |
120 | |
121 | } // namespace |
122 | |
123 | namespace irpass { |
124 | |
125 | void demote_dense_struct_fors(IRNode *root) { |
126 | if (auto *block = root->cast<Block>()) { |
127 | for (auto &s_ : block->statements) { |
128 | if (auto *s = s_->cast<OffloadedStmt>()) { |
129 | maybe_convert(s); |
130 | } |
131 | } |
132 | } else if (auto *s = root->cast<OffloadedStmt>()) { |
133 | maybe_convert(s); |
134 | } |
135 | re_id(root); |
136 | } |
137 | |
138 | } // namespace irpass |
139 | |
140 | } // namespace taichi::lang |
141 | |