1#include <vector>
2#include <unordered_set>
3
4#include "taichi/ir/ir.h"
5#include "taichi/ir/analysis.h"
6#include "taichi/ir/statements.h"
7#include "taichi/ir/visitors.h"
8#include "taichi/ir/transforms.h"
9#include "taichi/system/profiler.h"
10
11namespace taichi::lang {
12
13class IRVerifier : public BasicStmtVisitor {
14 private:
15 Block *current_block_;
16 Stmt *current_container_stmt_;
17 // each scope corresponds to an unordered_set
18 std::vector<std::unordered_set<Stmt *>> visible_stmts_;
19
20 public:
21 using BasicStmtVisitor::visit;
22
23 explicit IRVerifier(IRNode *root)
24 : current_block_(nullptr), current_container_stmt_(nullptr) {
25 allow_undefined_visitor = true;
26 invoke_default_visitor = true;
27 if (!root->is<Block>())
28 visible_stmts_.emplace_back();
29 if (root->is<Stmt>() && root->as<Stmt>()->is_container_statement()) {
30 current_container_stmt_ = root->as<Stmt>();
31 }
32 }
33
34 void basic_verify(Stmt *stmt) {
35 TI_ASSERT_INFO(stmt->parent == current_block_,
36 "stmt({})->parent({}) != current_block({})", stmt->id,
37 fmt::ptr(stmt->parent), fmt::ptr(current_block_));
38 for (auto &op : stmt->get_operands()) {
39 if (op == nullptr)
40 continue;
41 bool found = false;
42 for (int depth = (int)visible_stmts_.size() - 1; depth >= 0; depth--) {
43 if (visible_stmts_[depth].find(op) != visible_stmts_[depth].end()) {
44 found = true;
45 break;
46 }
47 }
48 TI_ASSERT_INFO(found,
49 "IR broken: stmt {} {} cannot have operand {} {}."
50 " If you are using autodiff, please check out"
51 " https://docs.taichi-lang.org/docs/"
52 "differences_between_taichi_and_python_programs"
53 " If it doesn't help, please open an issue at"
54 " https://github.com/taichi-dev/taichi to help us improve."
55 " Thanks in advance!",
56 stmt->type(), stmt->id, op->type(), op->id);
57 }
58 visible_stmts_.back().insert(stmt);
59 }
60
61 void preprocess_container_stmt(Stmt *stmt) override {
62 basic_verify(stmt);
63 }
64
65 void visit(Stmt *stmt) override {
66 basic_verify(stmt);
67 }
68
69 void visit(Block *block) override {
70 TI_ASSERT_INFO(
71 block->parent_stmt == current_container_stmt_,
72 "block({})->parent({}) != current_container_stmt({})", fmt::ptr(block),
73 block->parent_stmt ? block->parent_stmt->name() : "nullptr",
74 current_container_stmt_ ? current_container_stmt_->name() : "nullptr");
75 auto backup_block = current_block_;
76 current_block_ = block;
77 auto backup_container_stmt = current_container_stmt_;
78 if (!block->parent_stmt || !block->parent_stmt->is<OffloadedStmt>())
79 visible_stmts_.emplace_back();
80 for (auto &stmt : block->statements) {
81 if (stmt->is_container_statement())
82 current_container_stmt_ = stmt.get();
83 stmt->accept(this);
84 if (stmt->is_container_statement())
85 current_container_stmt_ = backup_container_stmt;
86 }
87 current_block_ = backup_block;
88 if (!block->parent_stmt || !block->parent_stmt->is<OffloadedStmt>())
89 current_block_ = backup_block;
90 }
91
92 void visit(OffloadedStmt *stmt) override {
93 basic_verify(stmt);
94 if (stmt->has_body() && !stmt->body) {
95 TI_ERROR("offloaded {} ({})->body is nullptr",
96 offloaded_task_type_name(stmt->task_type), stmt->name());
97 } else if (!stmt->has_body() && stmt->body) {
98 TI_ERROR("offloaded {} ({})->body is {} (should be nullptr)",
99 offloaded_task_type_name(stmt->task_type), stmt->name(),
100 fmt::ptr(stmt->body));
101 }
102 stmt->all_blocks_accept(this);
103 }
104
105 void visit(LocalLoadStmt *stmt) override {
106 basic_verify(stmt);
107 TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>());
108 }
109
110 void visit(LocalStoreStmt *stmt) override {
111 basic_verify(stmt);
112 TI_ASSERT(stmt->dest->is<AllocaStmt>() ||
113 (stmt->dest->is<MatrixPtrStmt>() &&
114 stmt->dest->cast<MatrixPtrStmt>()->offset_used_as_index()));
115 }
116
117 void visit(LoopIndexStmt *stmt) override {
118 basic_verify(stmt);
119 TI_ASSERT(stmt->loop);
120 if (stmt->loop->is<OffloadedStmt>()) {
121 TI_ASSERT(stmt->loop->as<OffloadedStmt>()->task_type ==
122 OffloadedStmt::TaskType::struct_for ||
123 stmt->loop->as<OffloadedStmt>()->task_type ==
124 OffloadedStmt::TaskType::mesh_for ||
125 stmt->loop->as<OffloadedStmt>()->task_type ==
126 OffloadedStmt::TaskType::range_for);
127 } else {
128 TI_ASSERT(stmt->loop->is<StructForStmt>() ||
129 stmt->loop->is<MeshForStmt>() ||
130 stmt->loop->is<RangeForStmt>());
131 }
132 }
133
134 static void run(IRNode *root) {
135 IRVerifier verifier(root);
136 root->accept(&verifier);
137 }
138};
139
140namespace irpass::analysis {
141void verify(IRNode *root) {
142 TI_AUTO_PROF;
143 if (!root->is<Block>() && !root->is<OffloadedStmt>()) {
144 TI_WARN(
145 "IR root is neither a Block nor an OffloadedStmt."
146 " Skipping verification.");
147 } else {
148 IRVerifier::run(root);
149 }
150}
151} // namespace irpass::analysis
152
153} // namespace taichi::lang
154