1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/transforms/simplify.h" |
7 | #include "taichi/program/kernel.h" |
8 | #include "taichi/program/program.h" |
9 | #include "taichi/transforms/utils.h" |
10 | #include <set> |
11 | #include <unordered_set> |
12 | #include <utility> |
13 | |
14 | namespace taichi::lang { |
15 | |
16 | // Common subexpression elimination, store forwarding, useless local store |
17 | // elimination; Simplify if statements into conditional stores. |
18 | class BasicBlockSimplify : public IRVisitor { |
19 | public: |
20 | Block *block; |
21 | |
22 | int current_stmt_id; |
23 | std::set<int> &visited; |
24 | StructForStmt *current_struct_for; |
25 | CompileConfig config; |
26 | DelayedIRModifier modifier; |
27 | |
28 | BasicBlockSimplify(Block *block, |
29 | std::set<int> &visited, |
30 | StructForStmt *current_struct_for, |
31 | const CompileConfig &config) |
32 | : block(block), |
33 | visited(visited), |
34 | current_struct_for(current_struct_for), |
35 | config(config) { |
36 | allow_undefined_visitor = true; |
37 | invoke_default_visitor = false; |
38 | } |
39 | |
40 | bool is_done(Stmt *stmt) { |
41 | return visited.find(stmt->instance_id) != visited.end(); |
42 | } |
43 | |
44 | void set_done(Stmt *stmt) { |
45 | visited.insert(stmt->instance_id); |
46 | } |
47 | |
48 | void accept_block() { |
49 | for (int i = 0; i < (int)block->statements.size(); i++) { |
50 | current_stmt_id = i; |
51 | block->statements[i]->accept(this); |
52 | } |
53 | } |
54 | |
55 | static bool run(Block *block, |
56 | std::set<int> &visited, |
57 | StructForStmt *current_struct_for, |
58 | const CompileConfig &config) { |
59 | BasicBlockSimplify simplifier(block, visited, current_struct_for, config); |
60 | bool ir_modified = false; |
61 | while (true) { |
62 | simplifier.accept_block(); |
63 | if (simplifier.modifier.modify_ir()) { |
64 | ir_modified = true; |
65 | } else { |
66 | break; |
67 | } |
68 | } |
69 | return ir_modified; |
70 | } |
71 | |
72 | void visit(GlobalLoadStmt *stmt) override { |
73 | if (is_done(stmt)) |
74 | return; |
75 | for (int i = 0; i < current_stmt_id; i++) { |
76 | auto &bstmt = block->statements[i]; |
77 | if (stmt->ret_type == bstmt->ret_type) { |
78 | auto &bstmt_data = *bstmt; |
79 | if (typeid(bstmt_data) == typeid(*stmt)) { |
80 | auto bstmt_ = bstmt->as<GlobalLoadStmt>(); |
81 | bool same = stmt->src == bstmt_->src; |
82 | if (same) { |
83 | // no store to the var? |
84 | bool has_store = false; |
85 | auto advanced_optimization = config.advanced_optimization; |
86 | for (int j = i + 1; j < current_stmt_id; j++) { |
87 | if (!advanced_optimization) { |
88 | if (block->statements[j] |
89 | ->is_container_statement()) { // no if, while, etc.. |
90 | has_store = true; |
91 | break; |
92 | } |
93 | if (block->statements[j]->is<GlobalStoreStmt>()) { |
94 | has_store = true; |
95 | } |
96 | continue; |
97 | } |
98 | if (block->statements[j]->is<FuncCallStmt>()) { |
99 | has_store = true; |
100 | } |
101 | if (!irpass::analysis::gather_statements( |
102 | block->statements[j].get(), |
103 | [&](Stmt *s) { |
104 | if (auto store = s->cast<GlobalStoreStmt>()) |
105 | return irpass::analysis::maybe_same_address( |
106 | store->dest, stmt->src); |
107 | else if (auto atomic = s->cast<AtomicOpStmt>()) |
108 | return irpass::analysis::maybe_same_address( |
109 | atomic->dest, stmt->src); |
110 | else |
111 | return false; |
112 | }) |
113 | .empty()) { |
114 | has_store = true; |
115 | break; |
116 | } |
117 | } |
118 | if (!has_store) { |
119 | stmt->replace_usages_with(bstmt.get()); |
120 | modifier.erase(stmt); |
121 | return; |
122 | } |
123 | } |
124 | } |
125 | } |
126 | } |
127 | set_done(stmt); |
128 | } |
129 | |
130 | void visit(IntegerOffsetStmt *stmt) override { |
131 | if (stmt->offset == 0) { |
132 | stmt->replace_usages_with(stmt->input); |
133 | modifier.erase(stmt); |
134 | } |
135 | } |
136 | |
137 | template <typename T> |
138 | static bool identical_vectors(const std::vector<T> &a, |
139 | const std::vector<T> &b) { |
140 | if (a.size() != b.size()) { |
141 | return false; |
142 | } else { |
143 | for (int i = 0; i < (int)a.size(); i++) { |
144 | if (a[i] != b[i]) |
145 | return false; |
146 | } |
147 | } |
148 | return true; |
149 | } |
150 | |
151 | void visit(LinearizeStmt *stmt) override { |
152 | if (!stmt->inputs.empty() && stmt->inputs.back()->is<IntegerOffsetStmt>()) { |
153 | auto previous_offset = stmt->inputs.back()->as<IntegerOffsetStmt>(); |
154 | // push forward offset |
155 | auto offset_stmt = |
156 | Stmt::make<IntegerOffsetStmt>(stmt, previous_offset->offset); |
157 | |
158 | stmt->inputs.back() = previous_offset->input; |
159 | stmt->replace_usages_with(offset_stmt.get()); |
160 | offset_stmt->as<IntegerOffsetStmt>()->input = stmt; |
161 | modifier.insert_after(stmt, std::move(offset_stmt)); |
162 | return; |
163 | } |
164 | |
165 | // Lower into a series of adds and muls. |
166 | auto sum = Stmt::make<ConstStmt>(TypedConstant(0)); |
167 | auto stride_product = 1; |
168 | for (int i = (int)stmt->inputs.size() - 1; i >= 0; i--) { |
169 | auto stride_stmt = Stmt::make<ConstStmt>(TypedConstant(stride_product)); |
170 | auto mul = Stmt::make<BinaryOpStmt>(BinaryOpType::mul, stmt->inputs[i], |
171 | stride_stmt.get()); |
172 | auto newsum = |
173 | Stmt::make<BinaryOpStmt>(BinaryOpType::add, sum.get(), mul.get()); |
174 | modifier.insert_before(stmt, std::move(sum)); |
175 | sum = std::move(newsum); |
176 | modifier.insert_before(stmt, std::move(stride_stmt)); |
177 | modifier.insert_before(stmt, std::move(mul)); |
178 | stride_product *= stmt->strides[i]; |
179 | } |
180 | // Compare the result with 0 to make sure no overflow occurs under Debug |
181 | // Mode. |
182 | bool debug = config.debug; |
183 | if (debug) { |
184 | auto zero = Stmt::make<ConstStmt>(TypedConstant(0)); |
185 | auto check_sum = |
186 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ge, sum.get(), zero.get()); |
187 | auto assert = Stmt::make<AssertStmt>( |
188 | check_sum.get(), "The indices provided are too big!\n" + stmt->tb, |
189 | std::vector<Stmt *>()); |
190 | // Because Taichi's assertion is checked only after the execution of the |
191 | // kernel, when the linear index overflows and goes negative, we have to |
192 | // replace that with 0 to make sure that the rest of the kernel can still |
193 | // complete. Otherwise, Taichi would crash due to illegal mem address. |
194 | auto select = Stmt::make<TernaryOpStmt>( |
195 | TernaryOpType::select, check_sum.get(), sum.get(), zero.get()); |
196 | |
197 | modifier.insert_before(stmt, std::move(zero)); |
198 | modifier.insert_before(stmt, std::move(sum)); |
199 | modifier.insert_before(stmt, std::move(check_sum)); |
200 | modifier.insert_before(stmt, std::move(assert)); |
201 | stmt->replace_usages_with(select.get()); |
202 | modifier.insert_before(stmt, std::move(select)); |
203 | } else { |
204 | stmt->replace_usages_with(sum.get()); |
205 | modifier.insert_before(stmt, std::move(sum)); |
206 | } |
207 | modifier.erase(stmt); |
208 | // get types of adds and muls |
209 | modifier.type_check(stmt->parent, config); |
210 | } |
211 | |
212 | void visit(SNodeLookupStmt *stmt) override { |
213 | if (is_done(stmt)) |
214 | return; |
215 | |
216 | if (stmt->input_index->is<IntegerOffsetStmt>()) { |
217 | auto previous_offset = stmt->input_index->as<IntegerOffsetStmt>(); |
218 | // push forward offset |
219 | |
220 | auto snode = stmt->snode; |
221 | // compute offset... |
222 | for (int i = 0; i < (int)snode->ch.size(); i++) { |
223 | TI_ASSERT(snode->ch[i]->type == SNodeType::place); |
224 | TI_ASSERT(snode->ch[i]->dt->is_primitive(PrimitiveTypeID::i32) || |
225 | snode->ch[i]->dt->is_primitive(PrimitiveTypeID::f32)); |
226 | } |
227 | |
228 | auto offset_stmt = Stmt::make<IntegerOffsetStmt>( |
229 | stmt, previous_offset->offset * sizeof(int32) * (snode->ch.size())); |
230 | |
231 | stmt->input_index = previous_offset->input; |
232 | stmt->replace_usages_with(offset_stmt.get()); |
233 | offset_stmt->as<IntegerOffsetStmt>()->input = stmt; |
234 | modifier.insert_after(stmt, std::move(offset_stmt)); |
235 | return; |
236 | } |
237 | |
238 | set_done(stmt); |
239 | } |
240 | |
241 | void visit(GetChStmt *stmt) override { |
242 | if (is_done(stmt)) |
243 | return; |
244 | |
245 | if (stmt->input_ptr->is<IntegerOffsetStmt>()) { |
246 | auto previous_offset = stmt->input_ptr->as<IntegerOffsetStmt>(); |
247 | // push forward offset |
248 | |
249 | // auto snode = stmt->input_snode; |
250 | auto offset_stmt = Stmt::make<IntegerOffsetStmt>( |
251 | stmt, stmt->chid * sizeof(int32) + previous_offset->offset); |
252 | |
253 | stmt->input_ptr = previous_offset->input; |
254 | stmt->replace_usages_with(offset_stmt.get()); |
255 | stmt->chid = 0; |
256 | stmt->output_snode = stmt->input_snode->ch[stmt->chid].get(); |
257 | offset_stmt->as<IntegerOffsetStmt>()->input = stmt; |
258 | modifier.insert_after(stmt, std::move(offset_stmt)); |
259 | return; |
260 | } |
261 | |
262 | set_done(stmt); |
263 | } |
264 | |
265 | void visit(WhileControlStmt *stmt) override { |
266 | if (stmt->mask) { |
267 | stmt->mask = nullptr; |
268 | modifier.mark_as_modified(); |
269 | return; |
270 | } |
271 | } |
272 | |
273 | static bool is_global_write(Stmt *stmt) { |
274 | return stmt->is<GlobalStoreStmt>() || stmt->is<AtomicOpStmt>(); |
275 | } |
276 | |
277 | static bool is_atomic_value_used(const stmt_vector &clause, |
278 | int atomic_stmt_i) { |
279 | // Cast type to check precondition |
280 | const auto *stmt = clause[atomic_stmt_i]->as<AtomicOpStmt>(); |
281 | auto alloca = stmt->dest; |
282 | |
283 | for (std::size_t i = atomic_stmt_i + 1; i < clause.size(); ++i) { |
284 | for (const auto &op : clause[i]->get_operands()) { |
285 | if (op && (op->instance_id == stmt->instance_id || |
286 | op->instance_id == alloca->instance_id)) { |
287 | return true; |
288 | } |
289 | } |
290 | } |
291 | return false; |
292 | } |
293 | |
294 | void visit(IfStmt *if_stmt) override { |
295 | auto flatten = [&](stmt_vector &clause, bool true_branch) { |
296 | bool plain_clause = true; // no global store, no container |
297 | |
298 | // Here we try to move statements outside the clause; |
299 | // Keep only global atomics/store, and other statements that have no |
300 | // global side effects. LocalStore is kept and specially treated later. |
301 | |
302 | bool global_state_changed = false; |
303 | for (int i = 0; i < (int)clause.size() && plain_clause; i++) { |
304 | bool has_side_effects = clause[i]->is_container_statement() || |
305 | clause[i]->has_global_side_effect(); |
306 | |
307 | if (global_state_changed && clause[i]->is<GlobalLoadStmt>()) { |
308 | // This clause cannot be trivially simplified, since there's a global |
309 | // load after store and they must be kept in order |
310 | plain_clause = false; |
311 | } |
312 | |
313 | if (clause[i]->is<GlobalStoreStmt>() || |
314 | clause[i]->is<LocalStoreStmt>() || !has_side_effects) { |
315 | // This stmt can be kept. |
316 | } else if (clause[i]->is<AtomicOpStmt>()) { |
317 | plain_clause = plain_clause && !is_atomic_value_used(clause, i); |
318 | } else { |
319 | plain_clause = false; |
320 | } |
321 | if (is_global_write(clause[i].get()) || has_side_effects) { |
322 | global_state_changed = true; |
323 | } |
324 | } |
325 | if (plain_clause) { |
326 | for (int i = 0; i < (int)clause.size(); i++) { |
327 | if (is_global_write(clause[i].get())) { |
328 | // do nothing. Keep the statement. |
329 | continue; |
330 | } |
331 | if (clause[i]->is<LocalStoreStmt>()) { |
332 | auto store = clause[i]->as<LocalStoreStmt>(); |
333 | auto load = Stmt::make<LocalLoadStmt>(store->dest); |
334 | modifier.type_check(load.get(), config); |
335 | auto select = Stmt::make<TernaryOpStmt>( |
336 | TernaryOpType::select, if_stmt->cond, |
337 | true_branch ? store->val : load.get(), |
338 | true_branch ? load.get() : store->val); |
339 | modifier.type_check(select.get(), config); |
340 | store->val = select.get(); |
341 | modifier.insert_before(if_stmt, std::move(load)); |
342 | modifier.insert_before(if_stmt, std::move(select)); |
343 | modifier.insert_before(if_stmt, std::move(clause[i])); |
344 | } else { |
345 | modifier.insert_before(if_stmt, std::move(clause[i])); |
346 | } |
347 | } |
348 | auto clean_clause = stmt_vector(); |
349 | bool reduced = false; |
350 | for (auto &&stmt : clause) { |
351 | if (stmt != nullptr) { |
352 | clean_clause.push_back(std::move(stmt)); |
353 | } else { |
354 | reduced = true; |
355 | } |
356 | } |
357 | clause = std::move(clean_clause); |
358 | return reduced; |
359 | } |
360 | return false; |
361 | }; |
362 | |
363 | if (config.flatten_if) { |
364 | if (if_stmt->true_statements && |
365 | flatten(if_stmt->true_statements->statements, true)) { |
366 | modifier.mark_as_modified(); |
367 | return; |
368 | } |
369 | if (if_stmt->false_statements && |
370 | flatten(if_stmt->false_statements->statements, false)) { |
371 | modifier.mark_as_modified(); |
372 | return; |
373 | } |
374 | } |
375 | |
376 | if (if_stmt->true_statements) { |
377 | if (if_stmt->true_statements->statements.empty()) { |
378 | if_stmt->set_true_statements(nullptr); |
379 | modifier.mark_as_modified(); |
380 | return; |
381 | } |
382 | } |
383 | |
384 | if (if_stmt->false_statements) { |
385 | if (if_stmt->false_statements->statements.empty()) { |
386 | if_stmt->set_false_statements(nullptr); |
387 | modifier.mark_as_modified(); |
388 | return; |
389 | } |
390 | } |
391 | |
392 | if (!if_stmt->true_statements && !if_stmt->false_statements) { |
393 | modifier.erase(if_stmt); |
394 | return; |
395 | } |
396 | |
397 | if (config.advanced_optimization) { |
398 | // Merge adjacent if's with the identical condition. |
399 | // TODO: What about IfStmt::true_mask and IfStmt::false_mask? |
400 | if (current_stmt_id < block->size() - 1 && |
401 | block->statements[current_stmt_id + 1]->is<IfStmt>()) { |
402 | auto bstmt = block->statements[current_stmt_id + 1]->as<IfStmt>(); |
403 | if (bstmt->cond == if_stmt->cond) { |
404 | auto concatenate = [](std::unique_ptr<Block> &clause1, |
405 | std::unique_ptr<Block> &clause2) { |
406 | if (clause1 == nullptr) { |
407 | clause1 = std::move(clause2); |
408 | return; |
409 | } |
410 | if (clause2 != nullptr) |
411 | clause1->insert(VecStatement(std::move(clause2->statements)), 0); |
412 | }; |
413 | concatenate(bstmt->true_statements, if_stmt->true_statements); |
414 | concatenate(bstmt->false_statements, if_stmt->false_statements); |
415 | modifier.erase(if_stmt); |
416 | return; |
417 | } |
418 | } |
419 | } |
420 | } |
421 | |
422 | void visit(OffloadedStmt *stmt) override { |
423 | if (stmt->has_body() && stmt->body->statements.empty()) { |
424 | modifier.erase(stmt); |
425 | return; |
426 | } |
427 | } |
428 | |
429 | void visit(WhileStmt *stmt) override { |
430 | if (stmt->mask) { |
431 | stmt->mask = nullptr; |
432 | modifier.mark_as_modified(); |
433 | return; |
434 | } |
435 | } |
436 | }; |
437 | |
438 | class Simplify : public IRVisitor { |
439 | public: |
440 | StructForStmt *current_struct_for; |
441 | bool modified; |
442 | const CompileConfig &config; |
443 | |
444 | Simplify(IRNode *node, const CompileConfig &config) : config(config) { |
445 | modified = false; |
446 | allow_undefined_visitor = true; |
447 | invoke_default_visitor = true; |
448 | current_struct_for = nullptr; |
449 | node->accept(this); |
450 | } |
451 | |
452 | void visit(Block *block) override { |
453 | std::set<int> visited; |
454 | if (BasicBlockSimplify::run(block, visited, current_struct_for, config)) { |
455 | modified = true; |
456 | } |
457 | for (auto &stmt : block->statements) { |
458 | stmt->accept(this); |
459 | } |
460 | } |
461 | |
462 | void visit(IfStmt *if_stmt) override { |
463 | if (if_stmt->true_statements) |
464 | if_stmt->true_statements->accept(this); |
465 | if (if_stmt->false_statements) |
466 | if_stmt->false_statements->accept(this); |
467 | } |
468 | |
469 | void visit(RangeForStmt *for_stmt) override { |
470 | for_stmt->body->accept(this); |
471 | } |
472 | |
473 | void visit(StructForStmt *for_stmt) override { |
474 | TI_ASSERT_INFO(current_struct_for == nullptr, |
475 | "Nested struct-fors are not supported for now. " |
476 | "Please try to use range-fors for inner loops." ); |
477 | current_struct_for = for_stmt; |
478 | for_stmt->body->accept(this); |
479 | current_struct_for = nullptr; |
480 | } |
481 | |
482 | void visit(MeshForStmt *for_stmt) override { |
483 | for_stmt->body->accept(this); |
484 | } |
485 | |
486 | void visit(WhileStmt *stmt) override { |
487 | stmt->body->accept(this); |
488 | } |
489 | |
490 | void visit(OffloadedStmt *stmt) override { |
491 | stmt->all_blocks_accept(this); |
492 | } |
493 | }; |
494 | |
// Out-of-class definition of FullSimplifyPass::id; the identifier string is
// the pass's own name.
const PassID FullSimplifyPass::id = "FullSimplifyPass";
496 | |
497 | namespace irpass { |
498 | |
499 | bool simplify(IRNode *root, const CompileConfig &config) { |
500 | TI_AUTO_PROF; |
501 | bool modified = false; |
502 | while (true) { |
503 | Simplify pass(root, config); |
504 | if (pass.modified) |
505 | modified = true; |
506 | else |
507 | break; |
508 | } |
509 | return modified; |
510 | } |
511 | |
512 | void full_simplify(IRNode *root, |
513 | const CompileConfig &config, |
514 | const FullSimplifyPass::Args &args) { |
515 | TI_AUTO_PROF; |
516 | if (config.advanced_optimization) { |
517 | bool first_iteration = true; |
518 | while (true) { |
519 | bool modified = false; |
520 | if (extract_constant(root, config)) |
521 | modified = true; |
522 | if (unreachable_code_elimination(root)) |
523 | modified = true; |
524 | if (binary_op_simplify(root, config)) |
525 | modified = true; |
526 | if (config.constant_folding && |
527 | constant_fold(root, config, {args.program})) |
528 | modified = true; |
529 | if (die(root)) |
530 | modified = true; |
531 | if (alg_simp(root, config)) |
532 | modified = true; |
533 | if (loop_invariant_code_motion(root, config)) |
534 | modified = true; |
535 | if (die(root)) |
536 | modified = true; |
537 | if (simplify(root, config)) |
538 | modified = true; |
539 | if (die(root)) |
540 | modified = true; |
541 | if (config.opt_level > 0 && whole_kernel_cse(root)) |
542 | modified = true; |
543 | // Don't do this time-consuming optimization pass again if the IR is |
544 | // not modified. |
545 | if (config.opt_level > 0 && first_iteration && config.cfg_optimization && |
546 | cfg_optimization(root, args.after_lower_access, args.autodiff_enabled, |
547 | !config.real_matrix_scalarize)) |
548 | modified = true; |
549 | first_iteration = false; |
550 | if (!modified) |
551 | break; |
552 | } |
553 | return; |
554 | } |
555 | if (config.constant_folding) { |
556 | constant_fold(root, config, {args.program}); |
557 | die(root); |
558 | } |
559 | simplify(root, config); |
560 | die(root); |
561 | } |
562 | |
563 | } // namespace irpass |
564 | |
565 | } // namespace taichi::lang |
566 | |