1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/analysis.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/transforms/simplify.h"
7#include "taichi/program/kernel.h"
8#include "taichi/program/program.h"
9#include "taichi/transforms/utils.h"
10#include <set>
11#include <unordered_set>
12#include <utility>
13
14namespace taichi::lang {
15
16// Common subexpression elimination, store forwarding, useless local store
17// elimination; Simplify if statements into conditional stores.
18class BasicBlockSimplify : public IRVisitor {
19 public:
20 Block *block;
21
22 int current_stmt_id;
23 std::set<int> &visited;
24 StructForStmt *current_struct_for;
25 CompileConfig config;
26 DelayedIRModifier modifier;
27
28 BasicBlockSimplify(Block *block,
29 std::set<int> &visited,
30 StructForStmt *current_struct_for,
31 const CompileConfig &config)
32 : block(block),
33 visited(visited),
34 current_struct_for(current_struct_for),
35 config(config) {
36 allow_undefined_visitor = true;
37 invoke_default_visitor = false;
38 }
39
40 bool is_done(Stmt *stmt) {
41 return visited.find(stmt->instance_id) != visited.end();
42 }
43
44 void set_done(Stmt *stmt) {
45 visited.insert(stmt->instance_id);
46 }
47
48 void accept_block() {
49 for (int i = 0; i < (int)block->statements.size(); i++) {
50 current_stmt_id = i;
51 block->statements[i]->accept(this);
52 }
53 }
54
55 static bool run(Block *block,
56 std::set<int> &visited,
57 StructForStmt *current_struct_for,
58 const CompileConfig &config) {
59 BasicBlockSimplify simplifier(block, visited, current_struct_for, config);
60 bool ir_modified = false;
61 while (true) {
62 simplifier.accept_block();
63 if (simplifier.modifier.modify_ir()) {
64 ir_modified = true;
65 } else {
66 break;
67 }
68 }
69 return ir_modified;
70 }
71
72 void visit(GlobalLoadStmt *stmt) override {
73 if (is_done(stmt))
74 return;
75 for (int i = 0; i < current_stmt_id; i++) {
76 auto &bstmt = block->statements[i];
77 if (stmt->ret_type == bstmt->ret_type) {
78 auto &bstmt_data = *bstmt;
79 if (typeid(bstmt_data) == typeid(*stmt)) {
80 auto bstmt_ = bstmt->as<GlobalLoadStmt>();
81 bool same = stmt->src == bstmt_->src;
82 if (same) {
83 // no store to the var?
84 bool has_store = false;
85 auto advanced_optimization = config.advanced_optimization;
86 for (int j = i + 1; j < current_stmt_id; j++) {
87 if (!advanced_optimization) {
88 if (block->statements[j]
89 ->is_container_statement()) { // no if, while, etc..
90 has_store = true;
91 break;
92 }
93 if (block->statements[j]->is<GlobalStoreStmt>()) {
94 has_store = true;
95 }
96 continue;
97 }
98 if (block->statements[j]->is<FuncCallStmt>()) {
99 has_store = true;
100 }
101 if (!irpass::analysis::gather_statements(
102 block->statements[j].get(),
103 [&](Stmt *s) {
104 if (auto store = s->cast<GlobalStoreStmt>())
105 return irpass::analysis::maybe_same_address(
106 store->dest, stmt->src);
107 else if (auto atomic = s->cast<AtomicOpStmt>())
108 return irpass::analysis::maybe_same_address(
109 atomic->dest, stmt->src);
110 else
111 return false;
112 })
113 .empty()) {
114 has_store = true;
115 break;
116 }
117 }
118 if (!has_store) {
119 stmt->replace_usages_with(bstmt.get());
120 modifier.erase(stmt);
121 return;
122 }
123 }
124 }
125 }
126 }
127 set_done(stmt);
128 }
129
130 void visit(IntegerOffsetStmt *stmt) override {
131 if (stmt->offset == 0) {
132 stmt->replace_usages_with(stmt->input);
133 modifier.erase(stmt);
134 }
135 }
136
137 template <typename T>
138 static bool identical_vectors(const std::vector<T> &a,
139 const std::vector<T> &b) {
140 if (a.size() != b.size()) {
141 return false;
142 } else {
143 for (int i = 0; i < (int)a.size(); i++) {
144 if (a[i] != b[i])
145 return false;
146 }
147 }
148 return true;
149 }
150
151 void visit(LinearizeStmt *stmt) override {
152 if (!stmt->inputs.empty() && stmt->inputs.back()->is<IntegerOffsetStmt>()) {
153 auto previous_offset = stmt->inputs.back()->as<IntegerOffsetStmt>();
154 // push forward offset
155 auto offset_stmt =
156 Stmt::make<IntegerOffsetStmt>(stmt, previous_offset->offset);
157
158 stmt->inputs.back() = previous_offset->input;
159 stmt->replace_usages_with(offset_stmt.get());
160 offset_stmt->as<IntegerOffsetStmt>()->input = stmt;
161 modifier.insert_after(stmt, std::move(offset_stmt));
162 return;
163 }
164
165 // Lower into a series of adds and muls.
166 auto sum = Stmt::make<ConstStmt>(TypedConstant(0));
167 auto stride_product = 1;
168 for (int i = (int)stmt->inputs.size() - 1; i >= 0; i--) {
169 auto stride_stmt = Stmt::make<ConstStmt>(TypedConstant(stride_product));
170 auto mul = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->inputs[i],
171 stride_stmt.get());
172 auto newsum =
173 Stmt::make<BinaryOpStmt>(BinaryOpType::add, sum.get(), mul.get());
174 modifier.insert_before(stmt, std::move(sum));
175 sum = std::move(newsum);
176 modifier.insert_before(stmt, std::move(stride_stmt));
177 modifier.insert_before(stmt, std::move(mul));
178 stride_product *= stmt->strides[i];
179 }
180 // Compare the result with 0 to make sure no overflow occurs under Debug
181 // Mode.
182 bool debug = config.debug;
183 if (debug) {
184 auto zero = Stmt::make<ConstStmt>(TypedConstant(0));
185 auto check_sum =
186 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ge, sum.get(), zero.get());
187 auto assert = Stmt::make<AssertStmt>(
188 check_sum.get(), "The indices provided are too big!\n" + stmt->tb,
189 std::vector<Stmt *>());
190 // Because Taichi's assertion is checked only after the execution of the
191 // kernel, when the linear index overflows and goes negative, we have to
192 // replace that with 0 to make sure that the rest of the kernel can still
193 // complete. Otherwise, Taichi would crash due to illegal mem address.
194 auto select = Stmt::make<TernaryOpStmt>(
195 TernaryOpType::select, check_sum.get(), sum.get(), zero.get());
196
197 modifier.insert_before(stmt, std::move(zero));
198 modifier.insert_before(stmt, std::move(sum));
199 modifier.insert_before(stmt, std::move(check_sum));
200 modifier.insert_before(stmt, std::move(assert));
201 stmt->replace_usages_with(select.get());
202 modifier.insert_before(stmt, std::move(select));
203 } else {
204 stmt->replace_usages_with(sum.get());
205 modifier.insert_before(stmt, std::move(sum));
206 }
207 modifier.erase(stmt);
208 // get types of adds and muls
209 modifier.type_check(stmt->parent, config);
210 }
211
212 void visit(SNodeLookupStmt *stmt) override {
213 if (is_done(stmt))
214 return;
215
216 if (stmt->input_index->is<IntegerOffsetStmt>()) {
217 auto previous_offset = stmt->input_index->as<IntegerOffsetStmt>();
218 // push forward offset
219
220 auto snode = stmt->snode;
221 // compute offset...
222 for (int i = 0; i < (int)snode->ch.size(); i++) {
223 TI_ASSERT(snode->ch[i]->type == SNodeType::place);
224 TI_ASSERT(snode->ch[i]->dt->is_primitive(PrimitiveTypeID::i32) ||
225 snode->ch[i]->dt->is_primitive(PrimitiveTypeID::f32));
226 }
227
228 auto offset_stmt = Stmt::make<IntegerOffsetStmt>(
229 stmt, previous_offset->offset * sizeof(int32) * (snode->ch.size()));
230
231 stmt->input_index = previous_offset->input;
232 stmt->replace_usages_with(offset_stmt.get());
233 offset_stmt->as<IntegerOffsetStmt>()->input = stmt;
234 modifier.insert_after(stmt, std::move(offset_stmt));
235 return;
236 }
237
238 set_done(stmt);
239 }
240
241 void visit(GetChStmt *stmt) override {
242 if (is_done(stmt))
243 return;
244
245 if (stmt->input_ptr->is<IntegerOffsetStmt>()) {
246 auto previous_offset = stmt->input_ptr->as<IntegerOffsetStmt>();
247 // push forward offset
248
249 // auto snode = stmt->input_snode;
250 auto offset_stmt = Stmt::make<IntegerOffsetStmt>(
251 stmt, stmt->chid * sizeof(int32) + previous_offset->offset);
252
253 stmt->input_ptr = previous_offset->input;
254 stmt->replace_usages_with(offset_stmt.get());
255 stmt->chid = 0;
256 stmt->output_snode = stmt->input_snode->ch[stmt->chid].get();
257 offset_stmt->as<IntegerOffsetStmt>()->input = stmt;
258 modifier.insert_after(stmt, std::move(offset_stmt));
259 return;
260 }
261
262 set_done(stmt);
263 }
264
265 void visit(WhileControlStmt *stmt) override {
266 if (stmt->mask) {
267 stmt->mask = nullptr;
268 modifier.mark_as_modified();
269 return;
270 }
271 }
272
273 static bool is_global_write(Stmt *stmt) {
274 return stmt->is<GlobalStoreStmt>() || stmt->is<AtomicOpStmt>();
275 }
276
277 static bool is_atomic_value_used(const stmt_vector &clause,
278 int atomic_stmt_i) {
279 // Cast type to check precondition
280 const auto *stmt = clause[atomic_stmt_i]->as<AtomicOpStmt>();
281 auto alloca = stmt->dest;
282
283 for (std::size_t i = atomic_stmt_i + 1; i < clause.size(); ++i) {
284 for (const auto &op : clause[i]->get_operands()) {
285 if (op && (op->instance_id == stmt->instance_id ||
286 op->instance_id == alloca->instance_id)) {
287 return true;
288 }
289 }
290 }
291 return false;
292 }
293
294 void visit(IfStmt *if_stmt) override {
295 auto flatten = [&](stmt_vector &clause, bool true_branch) {
296 bool plain_clause = true; // no global store, no container
297
298 // Here we try to move statements outside the clause;
299 // Keep only global atomics/store, and other statements that have no
300 // global side effects. LocalStore is kept and specially treated later.
301
302 bool global_state_changed = false;
303 for (int i = 0; i < (int)clause.size() && plain_clause; i++) {
304 bool has_side_effects = clause[i]->is_container_statement() ||
305 clause[i]->has_global_side_effect();
306
307 if (global_state_changed && clause[i]->is<GlobalLoadStmt>()) {
308 // This clause cannot be trivially simplified, since there's a global
309 // load after store and they must be kept in order
310 plain_clause = false;
311 }
312
313 if (clause[i]->is<GlobalStoreStmt>() ||
314 clause[i]->is<LocalStoreStmt>() || !has_side_effects) {
315 // This stmt can be kept.
316 } else if (clause[i]->is<AtomicOpStmt>()) {
317 plain_clause = plain_clause && !is_atomic_value_used(clause, i);
318 } else {
319 plain_clause = false;
320 }
321 if (is_global_write(clause[i].get()) || has_side_effects) {
322 global_state_changed = true;
323 }
324 }
325 if (plain_clause) {
326 for (int i = 0; i < (int)clause.size(); i++) {
327 if (is_global_write(clause[i].get())) {
328 // do nothing. Keep the statement.
329 continue;
330 }
331 if (clause[i]->is<LocalStoreStmt>()) {
332 auto store = clause[i]->as<LocalStoreStmt>();
333 auto load = Stmt::make<LocalLoadStmt>(store->dest);
334 modifier.type_check(load.get(), config);
335 auto select = Stmt::make<TernaryOpStmt>(
336 TernaryOpType::select, if_stmt->cond,
337 true_branch ? store->val : load.get(),
338 true_branch ? load.get() : store->val);
339 modifier.type_check(select.get(), config);
340 store->val = select.get();
341 modifier.insert_before(if_stmt, std::move(load));
342 modifier.insert_before(if_stmt, std::move(select));
343 modifier.insert_before(if_stmt, std::move(clause[i]));
344 } else {
345 modifier.insert_before(if_stmt, std::move(clause[i]));
346 }
347 }
348 auto clean_clause = stmt_vector();
349 bool reduced = false;
350 for (auto &&stmt : clause) {
351 if (stmt != nullptr) {
352 clean_clause.push_back(std::move(stmt));
353 } else {
354 reduced = true;
355 }
356 }
357 clause = std::move(clean_clause);
358 return reduced;
359 }
360 return false;
361 };
362
363 if (config.flatten_if) {
364 if (if_stmt->true_statements &&
365 flatten(if_stmt->true_statements->statements, true)) {
366 modifier.mark_as_modified();
367 return;
368 }
369 if (if_stmt->false_statements &&
370 flatten(if_stmt->false_statements->statements, false)) {
371 modifier.mark_as_modified();
372 return;
373 }
374 }
375
376 if (if_stmt->true_statements) {
377 if (if_stmt->true_statements->statements.empty()) {
378 if_stmt->set_true_statements(nullptr);
379 modifier.mark_as_modified();
380 return;
381 }
382 }
383
384 if (if_stmt->false_statements) {
385 if (if_stmt->false_statements->statements.empty()) {
386 if_stmt->set_false_statements(nullptr);
387 modifier.mark_as_modified();
388 return;
389 }
390 }
391
392 if (!if_stmt->true_statements && !if_stmt->false_statements) {
393 modifier.erase(if_stmt);
394 return;
395 }
396
397 if (config.advanced_optimization) {
398 // Merge adjacent if's with the identical condition.
399 // TODO: What about IfStmt::true_mask and IfStmt::false_mask?
400 if (current_stmt_id < block->size() - 1 &&
401 block->statements[current_stmt_id + 1]->is<IfStmt>()) {
402 auto bstmt = block->statements[current_stmt_id + 1]->as<IfStmt>();
403 if (bstmt->cond == if_stmt->cond) {
404 auto concatenate = [](std::unique_ptr<Block> &clause1,
405 std::unique_ptr<Block> &clause2) {
406 if (clause1 == nullptr) {
407 clause1 = std::move(clause2);
408 return;
409 }
410 if (clause2 != nullptr)
411 clause1->insert(VecStatement(std::move(clause2->statements)), 0);
412 };
413 concatenate(bstmt->true_statements, if_stmt->true_statements);
414 concatenate(bstmt->false_statements, if_stmt->false_statements);
415 modifier.erase(if_stmt);
416 return;
417 }
418 }
419 }
420 }
421
422 void visit(OffloadedStmt *stmt) override {
423 if (stmt->has_body() && stmt->body->statements.empty()) {
424 modifier.erase(stmt);
425 return;
426 }
427 }
428
429 void visit(WhileStmt *stmt) override {
430 if (stmt->mask) {
431 stmt->mask = nullptr;
432 modifier.mark_as_modified();
433 return;
434 }
435 }
436};
437
438class Simplify : public IRVisitor {
439 public:
440 StructForStmt *current_struct_for;
441 bool modified;
442 const CompileConfig &config;
443
444 Simplify(IRNode *node, const CompileConfig &config) : config(config) {
445 modified = false;
446 allow_undefined_visitor = true;
447 invoke_default_visitor = true;
448 current_struct_for = nullptr;
449 node->accept(this);
450 }
451
452 void visit(Block *block) override {
453 std::set<int> visited;
454 if (BasicBlockSimplify::run(block, visited, current_struct_for, config)) {
455 modified = true;
456 }
457 for (auto &stmt : block->statements) {
458 stmt->accept(this);
459 }
460 }
461
462 void visit(IfStmt *if_stmt) override {
463 if (if_stmt->true_statements)
464 if_stmt->true_statements->accept(this);
465 if (if_stmt->false_statements)
466 if_stmt->false_statements->accept(this);
467 }
468
469 void visit(RangeForStmt *for_stmt) override {
470 for_stmt->body->accept(this);
471 }
472
473 void visit(StructForStmt *for_stmt) override {
474 TI_ASSERT_INFO(current_struct_for == nullptr,
475 "Nested struct-fors are not supported for now. "
476 "Please try to use range-fors for inner loops.");
477 current_struct_for = for_stmt;
478 for_stmt->body->accept(this);
479 current_struct_for = nullptr;
480 }
481
482 void visit(MeshForStmt *for_stmt) override {
483 for_stmt->body->accept(this);
484 }
485
486 void visit(WhileStmt *stmt) override {
487 stmt->body->accept(this);
488 }
489
490 void visit(OffloadedStmt *stmt) override {
491 stmt->all_blocks_accept(this);
492 }
493};
494
// Out-of-class definition of the static `id` member of FullSimplifyPass; its
// value is the string "FullSimplifyPass".
const PassID FullSimplifyPass::id = "FullSimplifyPass";
496
497namespace irpass {
498
499bool simplify(IRNode *root, const CompileConfig &config) {
500 TI_AUTO_PROF;
501 bool modified = false;
502 while (true) {
503 Simplify pass(root, config);
504 if (pass.modified)
505 modified = true;
506 else
507 break;
508 }
509 return modified;
510}
511
512void full_simplify(IRNode *root,
513 const CompileConfig &config,
514 const FullSimplifyPass::Args &args) {
515 TI_AUTO_PROF;
516 if (config.advanced_optimization) {
517 bool first_iteration = true;
518 while (true) {
519 bool modified = false;
520 if (extract_constant(root, config))
521 modified = true;
522 if (unreachable_code_elimination(root))
523 modified = true;
524 if (binary_op_simplify(root, config))
525 modified = true;
526 if (config.constant_folding &&
527 constant_fold(root, config, {args.program}))
528 modified = true;
529 if (die(root))
530 modified = true;
531 if (alg_simp(root, config))
532 modified = true;
533 if (loop_invariant_code_motion(root, config))
534 modified = true;
535 if (die(root))
536 modified = true;
537 if (simplify(root, config))
538 modified = true;
539 if (die(root))
540 modified = true;
541 if (config.opt_level > 0 && whole_kernel_cse(root))
542 modified = true;
543 // Don't do this time-consuming optimization pass again if the IR is
544 // not modified.
545 if (config.opt_level > 0 && first_iteration && config.cfg_optimization &&
546 cfg_optimization(root, args.after_lower_access, args.autodiff_enabled,
547 !config.real_matrix_scalarize))
548 modified = true;
549 first_iteration = false;
550 if (!modified)
551 break;
552 }
553 return;
554 }
555 if (config.constant_folding) {
556 constant_fold(root, config, {args.program});
557 die(root);
558 }
559 simplify(root, config);
560 die(root);
561}
562
563} // namespace irpass
564
565} // namespace taichi::lang
566