1 | #include <vector> |
2 | #include <unordered_set> |
3 | |
4 | #include "taichi/ir/ir.h" |
5 | #include "taichi/ir/analysis.h" |
6 | #include "taichi/ir/statements.h" |
7 | #include "taichi/ir/visitors.h" |
8 | #include "taichi/ir/transforms.h" |
9 | #include "taichi/system/profiler.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | class IRVerifier : public BasicStmtVisitor { |
14 | private: |
15 | Block *current_block_; |
16 | Stmt *current_container_stmt_; |
17 | // each scope corresponds to an unordered_set |
18 | std::vector<std::unordered_set<Stmt *>> visible_stmts_; |
19 | |
20 | public: |
21 | using BasicStmtVisitor::visit; |
22 | |
23 | explicit IRVerifier(IRNode *root) |
24 | : current_block_(nullptr), current_container_stmt_(nullptr) { |
25 | allow_undefined_visitor = true; |
26 | invoke_default_visitor = true; |
27 | if (!root->is<Block>()) |
28 | visible_stmts_.emplace_back(); |
29 | if (root->is<Stmt>() && root->as<Stmt>()->is_container_statement()) { |
30 | current_container_stmt_ = root->as<Stmt>(); |
31 | } |
32 | } |
33 | |
34 | void basic_verify(Stmt *stmt) { |
35 | TI_ASSERT_INFO(stmt->parent == current_block_, |
36 | "stmt({})->parent({}) != current_block({})" , stmt->id, |
37 | fmt::ptr(stmt->parent), fmt::ptr(current_block_)); |
38 | for (auto &op : stmt->get_operands()) { |
39 | if (op == nullptr) |
40 | continue; |
41 | bool found = false; |
42 | for (int depth = (int)visible_stmts_.size() - 1; depth >= 0; depth--) { |
43 | if (visible_stmts_[depth].find(op) != visible_stmts_[depth].end()) { |
44 | found = true; |
45 | break; |
46 | } |
47 | } |
48 | TI_ASSERT_INFO(found, |
49 | "IR broken: stmt {} {} cannot have operand {} {}." |
50 | " If you are using autodiff, please check out" |
51 | " https://docs.taichi-lang.org/docs/" |
52 | "differences_between_taichi_and_python_programs" |
53 | " If it doesn't help, please open an issue at" |
54 | " https://github.com/taichi-dev/taichi to help us improve." |
55 | " Thanks in advance!" , |
56 | stmt->type(), stmt->id, op->type(), op->id); |
57 | } |
58 | visible_stmts_.back().insert(stmt); |
59 | } |
60 | |
61 | void preprocess_container_stmt(Stmt *stmt) override { |
62 | basic_verify(stmt); |
63 | } |
64 | |
65 | void visit(Stmt *stmt) override { |
66 | basic_verify(stmt); |
67 | } |
68 | |
69 | void visit(Block *block) override { |
70 | TI_ASSERT_INFO( |
71 | block->parent_stmt == current_container_stmt_, |
72 | "block({})->parent({}) != current_container_stmt({})" , fmt::ptr(block), |
73 | block->parent_stmt ? block->parent_stmt->name() : "nullptr" , |
74 | current_container_stmt_ ? current_container_stmt_->name() : "nullptr" ); |
75 | auto backup_block = current_block_; |
76 | current_block_ = block; |
77 | auto backup_container_stmt = current_container_stmt_; |
78 | if (!block->parent_stmt || !block->parent_stmt->is<OffloadedStmt>()) |
79 | visible_stmts_.emplace_back(); |
80 | for (auto &stmt : block->statements) { |
81 | if (stmt->is_container_statement()) |
82 | current_container_stmt_ = stmt.get(); |
83 | stmt->accept(this); |
84 | if (stmt->is_container_statement()) |
85 | current_container_stmt_ = backup_container_stmt; |
86 | } |
87 | current_block_ = backup_block; |
88 | if (!block->parent_stmt || !block->parent_stmt->is<OffloadedStmt>()) |
89 | current_block_ = backup_block; |
90 | } |
91 | |
92 | void visit(OffloadedStmt *stmt) override { |
93 | basic_verify(stmt); |
94 | if (stmt->has_body() && !stmt->body) { |
95 | TI_ERROR("offloaded {} ({})->body is nullptr" , |
96 | offloaded_task_type_name(stmt->task_type), stmt->name()); |
97 | } else if (!stmt->has_body() && stmt->body) { |
98 | TI_ERROR("offloaded {} ({})->body is {} (should be nullptr)" , |
99 | offloaded_task_type_name(stmt->task_type), stmt->name(), |
100 | fmt::ptr(stmt->body)); |
101 | } |
102 | stmt->all_blocks_accept(this); |
103 | } |
104 | |
105 | void visit(LocalLoadStmt *stmt) override { |
106 | basic_verify(stmt); |
107 | TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>()); |
108 | } |
109 | |
110 | void visit(LocalStoreStmt *stmt) override { |
111 | basic_verify(stmt); |
112 | TI_ASSERT(stmt->dest->is<AllocaStmt>() || |
113 | (stmt->dest->is<MatrixPtrStmt>() && |
114 | stmt->dest->cast<MatrixPtrStmt>()->offset_used_as_index())); |
115 | } |
116 | |
117 | void visit(LoopIndexStmt *stmt) override { |
118 | basic_verify(stmt); |
119 | TI_ASSERT(stmt->loop); |
120 | if (stmt->loop->is<OffloadedStmt>()) { |
121 | TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type == |
122 | OffloadedStmt::TaskType::struct_for || |
123 | stmt->loop->as<OffloadedStmt>()->task_type == |
124 | OffloadedStmt::TaskType::mesh_for || |
125 | stmt->loop->as<OffloadedStmt>()->task_type == |
126 | OffloadedStmt::TaskType::range_for); |
127 | } else { |
128 | TI_ASSERT(stmt->loop->is<StructForStmt>() || |
129 | stmt->loop->is<MeshForStmt>() || |
130 | stmt->loop->is<RangeForStmt>()); |
131 | } |
132 | } |
133 | |
134 | static void run(IRNode *root) { |
135 | IRVerifier verifier(root); |
136 | root->accept(&verifier); |
137 | } |
138 | }; |
139 | |
140 | namespace irpass::analysis { |
141 | void verify(IRNode *root) { |
142 | TI_AUTO_PROF; |
143 | if (!root->is<Block>() && !root->is<OffloadedStmt>()) { |
144 | TI_WARN( |
145 | "IR root is neither a Block nor an OffloadedStmt." |
146 | " Skipping verification." ); |
147 | } else { |
148 | IRVerifier::run(root); |
149 | } |
150 | } |
151 | } // namespace irpass::analysis |
152 | |
153 | } // namespace taichi::lang |
154 | |