1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/analysis.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/program/kernel.h"
7#include "taichi/program/program.h"
8#include "taichi/transforms/lower_access.h"
9#include "taichi/transforms/scalar_pointer_lowerer.h"
10
11#include <deque>
12#include <set>
13
14namespace taichi::lang {
15namespace {
16
17class LowerAccess;
18
// Scalar-pointer lowerer used by the LowerAccess pass. It expands a
// GlobalPtrStmt into per-level SNode lookup / get-child micro-ops; the
// per-level behavior is implemented in handle_snode_at_level().
class PtrLowererImpl : public ScalarPointerLowerer {
 public:
  using ScalarPointerLowerer::ScalarPointerLowerer;

  // Attaches the owning pass. Also snapshots the SNodes covered by the
  // pass's current struct-for (see the definition below).
  void set_lower_access(LowerAccess *la);

  // Whether the lowered access should activate sparse SNodes along the path.
  void set_pointer_needs_activation(bool v) {
    pointer_needs_activation_ = v;
  }

 protected:
  // Emits the access ops for the SNode at |level| of the path.
  Stmt *handle_snode_at_level(int level,
                              LinearizeStmt *linearized,
                              Stmt *last) override;

 private:
  // Owning pass; non-owning pointer, set via set_lower_access().
  LowerAccess *la_{nullptr};
  // SNodes on the chain (leaf to root) iterated by the active struct-for.
  std::unordered_set<SNode *> snodes_on_loop_;
  bool pointer_needs_activation_{false};
};
39
40// Lower GlobalPtrStmt into smaller pieces for access optimization
41
42class LowerAccess : public IRVisitor {
43 public:
44 DelayedIRModifier modifier;
45 StructForStmt *current_struct_for;
46 const std::vector<SNode *> &kernel_forces_no_activate;
47 bool lower_atomic_ptr;
48
49 LowerAccess(const std::vector<SNode *> &kernel_forces_no_activate,
50 bool lower_atomic_ptr)
51 : kernel_forces_no_activate(kernel_forces_no_activate),
52 lower_atomic_ptr(lower_atomic_ptr) {
53 // TODO: change this to false
54 allow_undefined_visitor = true;
55 current_struct_for = nullptr;
56 }
57
58 void visit(Block *stmt_list) override {
59 for (auto &stmt : stmt_list->statements) {
60 stmt->accept(this);
61 }
62 }
63
64 void visit(IfStmt *if_stmt) override {
65 if (if_stmt->true_statements)
66 if_stmt->true_statements->accept(this);
67 if (if_stmt->false_statements) {
68 if_stmt->false_statements->accept(this);
69 }
70 }
71
72 void visit(OffloadedStmt *stmt) override {
73 stmt->all_blocks_accept(this);
74 }
75
76 void visit(WhileStmt *stmt) override {
77 stmt->body->accept(this);
78 }
79
80 void visit(RangeForStmt *for_stmt) override {
81 for_stmt->body->accept(this);
82 }
83
84 void visit(StructForStmt *for_stmt) override {
85 current_struct_for = for_stmt;
86 for_stmt->body->accept(this);
87 current_struct_for = nullptr;
88 }
89
90 VecStatement lower_ptr(GlobalPtrStmt *ptr,
91 bool activate,
92 SNodeOpType snode_op = SNodeOpType::undefined) {
93 VecStatement lowered;
94 if (snode_op == SNodeOpType::is_active) {
95 // For ti.is_active
96 TI_ASSERT(!activate);
97 }
98 PtrLowererImpl lowerer{ptr->snode, ptr->indices, snode_op,
99 ptr->is_bit_vectorized, &lowered};
100 lowerer.set_pointer_needs_activation(activate);
101 lowerer.set_lower_access(this);
102 lowerer.run();
103 TI_ASSERT(lowered.size() > 0);
104 auto lowered_ptr = lowered.back().get();
105 if (ptr->is_bit_vectorized) {
106 // if the global ptr is bit vectorized, we start from the place snode
107 // and find the parent quant array snode, use its physical type
108 auto parent_ret_type = ptr->snode->parent->physical_type;
109 auto ptr_ret_type =
110 TypeFactory::get_instance().get_pointer_type(parent_ret_type);
111 lowered_ptr->ret_type = DataType(ptr_ret_type);
112 } else {
113 lowered_ptr->ret_type = ptr->snode->dt;
114 }
115 return lowered;
116 }
117
118 void visit(GlobalLoadStmt *stmt) override {
119 if (!stmt->src->is<GlobalPtrStmt>())
120 return;
121 // No need to activate for all read accesses
122 auto lowered = lower_ptr(stmt->src->as<GlobalPtrStmt>(), false);
123 stmt->src = lowered.back().get();
124 modifier.insert_before(stmt, std::move(lowered));
125 }
126
127 // TODO: this seems to be redundant
128 void visit(MatrixPtrStmt *stmt) override {
129 if (!stmt->is_unlowered_global_ptr())
130 return;
131 auto ptr = stmt->origin->as<GlobalPtrStmt>();
132 // If ptr already has activate = false, no need to activate all the
133 // generated micro-access ops. Otherwise, activate the nodes.
134 auto lowered = lower_ptr(ptr, ptr->activate);
135 stmt->origin = lowered.back().get();
136 modifier.insert_before(stmt, std::move(lowered));
137 }
138
139 void visit(GlobalStoreStmt *stmt) override {
140 if (!stmt->dest->is<GlobalPtrStmt>())
141 return;
142 auto ptr = stmt->dest->as<GlobalPtrStmt>();
143 // If ptr already has activate = false, no need to activate all the
144 // generated micro-access ops. Otherwise, activate the nodes.
145 auto lowered = lower_ptr(ptr, ptr->activate);
146 stmt->dest = lowered.back().get();
147 modifier.insert_before(stmt, std::move(lowered));
148 }
149
150 void visit(SNodeOpStmt *stmt) override {
151 if (stmt->ptr->is<GlobalPtrStmt>()) {
152 auto global_ptr = stmt->ptr->as<GlobalPtrStmt>();
153 if (global_ptr->is_cell_access) {
154 auto lowered = lower_ptr(global_ptr, false, stmt->op_type);
155 modifier.replace_with(stmt, std::move(lowered), true);
156 } else if (stmt->op_type == SNodeOpType::get_addr) {
157 auto lowered = lower_ptr(global_ptr, false);
158 auto cast = lowered.push_back<UnaryOpStmt>(UnaryOpType::cast_bits,
159 lowered.back().get());
160 cast->cast_type = TypeFactory::get_instance().get_primitive_type(
161 PrimitiveTypeID::u64);
162 stmt->ptr = lowered.back().get();
163 modifier.replace_with(stmt, std::move(lowered));
164 } else {
165 auto lowered =
166 lower_ptr(global_ptr, SNodeOpStmt::need_activation(stmt->op_type));
167 stmt->ptr = lowered.back().get();
168 modifier.insert_before(stmt, std::move(lowered));
169 }
170 }
171 }
172
173 void visit(AtomicOpStmt *stmt) override {
174 if (!lower_atomic_ptr)
175 return;
176 if (stmt->dest->is<GlobalPtrStmt>()) {
177 auto lowered = lower_ptr(stmt->dest->as<GlobalPtrStmt>(),
178 stmt->dest->as<GlobalPtrStmt>()->activate);
179 stmt->dest = lowered.back().get();
180 modifier.insert_before(stmt, std::move(lowered));
181 }
182 }
183
184 void visit(LocalStoreStmt *stmt) override {
185 if (stmt->val->is<GlobalPtrStmt>()) {
186 auto lowered = lower_ptr(stmt->val->as<GlobalPtrStmt>(), true);
187 stmt->val = lowered.back().get();
188 modifier.insert_before(stmt, std::move(lowered));
189 }
190 }
191
192 static bool run(IRNode *node,
193 const std::vector<SNode *> &kernel_forces_no_activate,
194 bool lower_atomic) {
195 LowerAccess inst(kernel_forces_no_activate, lower_atomic);
196 bool modified = false;
197 while (true) {
198 node->accept(&inst);
199 if (inst.modifier.modify_ir()) {
200 modified = true;
201 } else {
202 break;
203 }
204 }
205 return modified;
206 }
207};
208
209void PtrLowererImpl::set_lower_access(LowerAccess *la) {
210 la_ = la;
211
212 snodes_on_loop_.clear();
213 if (la_->current_struct_for) {
214 for (SNode *s = la_->current_struct_for->snode; s != nullptr;
215 s = s->parent) {
216 snodes_on_loop_.insert(s);
217 }
218 }
219}
220
// Emits the micro-access ops for the SNode at |level| along the pointer's
// path. Returns the statement that subsequent levels should chain from.
// |linearized| is the linearized index into this level; |last| is the
// result of the previous level (the parent cell pointer).
Stmt *PtrLowererImpl::handle_snode_at_level(int level,
                                            LinearizeStmt *linearized,
                                            Stmt *last) {
  // Check whether |snode| is part of the tree being iterated over by struct for
  auto *snode = snodes()[level];
  bool on_loop_tree = (snodes_on_loop_.find(snode) != snodes_on_loop_.end());
  auto *current_struct_for = la_->current_struct_for;
  if (on_loop_tree && current_struct_for &&
      (indices_.size() == current_struct_for->snode->num_active_indices)) {
    // Even if the SNode is on the loop tree, the access only counts as
    // "following the loop" when every index is (close to) the corresponding
    // loop index. Otherwise demote |on_loop_tree| so activation still happens.
    for (int j = 0; j < (int)indices_.size(); j++) {
      auto diff = irpass::analysis::value_diff_loop_index(
          indices_[j], current_struct_for, j);
      if (!diff.linear_related()) {
        // Index is unrelated to the loop index.
        on_loop_tree = false;
      } else if (j == (int)indices_.size() - 1) {
        // Last index: allow an offset of 0 or within [low, high) = [0, 1).
        if (!(0 <= diff.low && diff.high <= 1)) {  // TODO: Vectorize
          on_loop_tree = false;
        }
      } else {
        // Non-last indices must match the loop index exactly.
        if (!diff.certain() || diff.low != 0) {
          on_loop_tree = false;
        }
      }
    }
  }

  // Generates the SNode access operations at the current |level|.
  if ((snode_op_ != SNodeOpType::undefined) &&
      (level == (int)snodes().size() - 1)) {
    // Create a SNodeOp querying if element i(linearized) of node is active
    lowered_->push_back<SNodeOpStmt>(snode_op_, snode, last, linearized);
  } else {
    // Activation is skipped if the kernel forbids it for this SNode, or if
    // the access follows the struct-for loop (the loop only visits cells
    // that are already active).
    const bool kernel_forces_no_activate_snode =
        std::find(la_->kernel_forces_no_activate.begin(),
                  la_->kernel_forces_no_activate.end(),
                  snode) != la_->kernel_forces_no_activate.end();

    const bool needs_activation =
        snode->need_activation() && pointer_needs_activation_ &&
        !kernel_forces_no_activate_snode && !on_loop_tree;

    auto lookup = lowered_->push_back<SNodeLookupStmt>(snode, last, linearized,
                                                       needs_activation);
    int chid = snode->child_id(snodes()[level + 1]);
    // For bit-vectorized accesses, the GetChStmt at the second-to-last level
    // (the dense SNode right above the quant array) is marked bit-vectorized.
    if (is_bit_vectorized_ && (snode->type == SNodeType::dense) &&
        (level == path_length() - 2)) {
      last = lowered_->push_back<GetChStmt>(lookup, chid,
                                            /*is_bit_vectorized=*/true);
    } else {
      last = lowered_->push_back<GetChStmt>(lookup, chid,
                                            /*is_bit_vectorized=*/false);
    }
  }
  return last;
}
276
277} // namespace
278
// Unique identifier used to register and look up this pass.
const PassID LowerAccessPass::id = "LowerAccessPass";
280
281namespace irpass {
282
283bool lower_access(IRNode *root,
284 const CompileConfig &config,
285 const LowerAccessPass::Args &args) {
286 bool modified =
287 LowerAccess::run(root, args.kernel_forces_no_activate, args.lower_atomic);
288 type_check(root, config);
289 return modified;
290}
291
292} // namespace irpass
293} // namespace taichi::lang
294