1 | #include "taichi/ir/ir.h" |
---|---|
2 | #include "taichi/ir/transforms.h" |
3 | #include "taichi/ir/visitors.h" |
4 | #include "taichi/ir/frontend_ir.h" |
5 | #include "taichi/system/profiler.h" |
6 | |
7 | #include <set> |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | namespace irpass { |
12 | |
13 | // TODO: gather Expr as well? |
14 | class GatherStmts : public BasicStmtVisitor { |
15 | public: |
16 | using BasicStmtVisitor::visit; |
17 | |
18 | std::vector<Stmt *> stmts; |
19 | |
20 | GatherStmts() { |
21 | invoke_default_visitor = true; |
22 | } |
23 | |
24 | void visit(Stmt *stmt) override { |
25 | stmts.push_back(stmt); |
26 | } |
27 | }; |
28 | |
29 | void reverse_segments(IRNode *root) { |
30 | TI_AUTO_PROF; |
31 | auto block = dynamic_cast<Block *>(root); |
32 | std::vector<std::vector<pStmt>> statement_blocks(1); |
33 | bool has_for = false; |
34 | bool has_non_for = false; |
35 | for (auto &&s : block->statements) { |
36 | if (s->is<FrontendForStmt>()) { |
37 | has_for = true; |
38 | statement_blocks.emplace_back(); |
39 | statement_blocks.back().push_back(std::move(s)); |
40 | statement_blocks.emplace_back(); |
41 | } else { |
42 | has_non_for = true; |
43 | statement_blocks.back().push_back(std::move(s)); |
44 | } |
45 | } |
46 | block->statements.clear(); |
47 | std::reverse(statement_blocks.begin(), statement_blocks.end()); |
48 | /* |
49 | for (auto &b : statement_blocks) { |
50 | std::vector<Stmt *> stmts; |
51 | for (auto &s : b) { |
52 | GatherStmts gather; |
53 | s->accept(&gather); |
54 | stmts.insert(stmts.end(), gather.stmts.begin(), gather.stmts.end()); |
55 | } |
56 | std::set<Stmt *> stmt_set(stmts.begin(), stmts.end()); |
57 | bool valid = true; |
58 | for (auto s : stmts) { |
59 | for (auto op : s->get_operands()) { |
60 | if (stmt_set.find(op) == stmt_set.end()) { |
61 | valid = false; |
62 | } |
63 | } |
64 | } |
65 | } |
66 | */ |
67 | if (has_for && has_non_for) { |
68 | TI_ERROR( |
69 | "Invalid program input for autodiff: " |
70 | "Mixed usage of for-loops and statements without looping. \n" |
71 | "Please split them into two kernels " |
72 | "and check the documentation for more details:\n" |
73 | "https://docs.taichi-lang.org/docs/" |
74 | "differentiable_programming"); |
75 | } |
76 | for (auto &sblock : statement_blocks) { |
77 | for (auto &&s : sblock) { |
78 | block->statements.push_back(std::move(s)); |
79 | } |
80 | } |
81 | } |
82 | |
83 | } // namespace irpass |
84 | |
85 | } // namespace taichi::lang |
86 |