1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/transforms/check_out_of_bound.h" |
6 | #include "taichi/transforms/utils.h" |
7 | #include <set> |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | // TODO: also check RangeAssumptionStmt |
12 | |
13 | class CheckOutOfBound : public BasicStmtVisitor { |
14 | public: |
15 | using BasicStmtVisitor::visit; |
16 | std::set<int> visited; |
17 | DelayedIRModifier modifier; |
18 | std::string kernel_name; |
19 | |
20 | explicit CheckOutOfBound(const std::string &kernel_name) |
21 | : kernel_name(kernel_name) { |
22 | } |
23 | |
24 | bool is_done(Stmt *stmt) { |
25 | return visited.find(stmt->instance_id) != visited.end(); |
26 | } |
27 | |
28 | void set_done(Stmt *stmt) { |
29 | visited.insert(stmt->instance_id); |
30 | } |
31 | |
32 | void visit(SNodeOpStmt *stmt) override { |
33 | if (stmt->ptr != nullptr) { |
34 | TI_ASSERT(stmt->ptr->is<GlobalPtrStmt>()); |
35 | // We have already done the check on its ptr argument. No need to do |
36 | // anything here. |
37 | return; |
38 | } |
39 | |
40 | // TODO: implement bound check here for other situations. |
41 | } |
42 | |
43 | void visit(GlobalPtrStmt *stmt) override { |
44 | if (is_done(stmt)) |
45 | return; |
46 | auto snode = stmt->snode; |
47 | bool has_offset = !(snode->index_offsets.empty()); |
48 | auto new_stmts = VecStatement(); |
49 | auto zero = new_stmts.push_back<ConstStmt>(TypedConstant(0)); |
50 | Stmt *result = new_stmts.push_back<ConstStmt>(TypedConstant(true)); |
51 | |
52 | std::string msg = |
53 | fmt::format("(kernel={}) Accessing field ({}) of size (" , kernel_name, |
54 | snode->get_node_type_name_hinted()); |
55 | std::string offset_msg = "offset (" ; |
56 | std::vector<Stmt *> args; |
57 | for (int i = 0; i < stmt->indices.size(); i++) { |
58 | int offset_i = has_offset ? snode->index_offsets[i] : 0; |
59 | |
60 | // Note that during lower_ast, index arguments to GlobalPtrStmt are |
61 | // already converted to [0, +inf) range. |
62 | |
63 | auto lower_bound = zero; |
64 | auto check_lower_bound = new_stmts.push_back<BinaryOpStmt>( |
65 | BinaryOpType::cmp_ge, stmt->indices[i], lower_bound); |
66 | int size_i = snode->shape_along_axis(i); |
67 | int upper_bound_i = size_i; |
68 | auto upper_bound = |
69 | new_stmts.push_back<ConstStmt>(TypedConstant(upper_bound_i)); |
70 | auto check_upper_bound = new_stmts.push_back<BinaryOpStmt>( |
71 | BinaryOpType::cmp_lt, stmt->indices[i], upper_bound); |
72 | auto check_i = new_stmts.push_back<BinaryOpStmt>( |
73 | BinaryOpType::bit_and, check_lower_bound, check_upper_bound); |
74 | result = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::bit_and, result, |
75 | check_i); |
76 | if (i > 0) { |
77 | msg += ", " ; |
78 | offset_msg += ", " ; |
79 | } |
80 | msg += std::to_string(size_i); |
81 | offset_msg += std::to_string(offset_i); |
82 | |
83 | auto input_index = stmt->indices[i]; |
84 | if (offset_i != 0) { |
85 | auto offset = new_stmts.push_back<ConstStmt>(TypedConstant(offset_i)); |
86 | input_index = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add, |
87 | input_index, offset); |
88 | } |
89 | args.emplace_back(input_index); |
90 | } |
91 | offset_msg += ") " ; |
92 | msg += ") " + (has_offset ? offset_msg : "" ) + "with indices (" ; |
93 | for (int i = 0; i < stmt->indices.size(); i++) { |
94 | if (i > 0) |
95 | msg += ", " ; |
96 | msg += "%d" ; |
97 | } |
98 | msg += ")" ; |
99 | msg += "\n" + stmt->tb; |
100 | |
101 | new_stmts.push_back<AssertStmt>(result, msg, args); |
102 | modifier.insert_before(stmt, std::move(new_stmts)); |
103 | set_done(stmt); |
104 | } |
105 | |
106 | void visit(BinaryOpStmt *stmt) override { |
107 | // Insert assertions if debug is on |
108 | if (is_done(stmt)) { |
109 | return; |
110 | } |
111 | if (stmt->op_type == BinaryOpType::pow) { |
112 | if (is_integral(stmt->rhs->ret_type) && |
113 | is_integral(stmt->lhs->ret_type)) { |
114 | auto compare_rhs = Stmt::make<ConstStmt>(TypedConstant(0)); |
115 | auto compare = std::make_unique<BinaryOpStmt>( |
116 | BinaryOpType::cmp_ge, stmt->rhs, compare_rhs.get()); |
117 | compare->ret_type = PrimitiveType::i32; |
118 | std::string msg = "Negative exponent in pow(int, int) is not allowed." ; |
119 | msg += "\n" + stmt->tb; |
120 | auto assert_stmt = std::make_unique<AssertStmt>(compare.get(), msg, |
121 | std::vector<Stmt *>()); |
122 | assert_stmt->accept(this); |
123 | modifier.insert_before(stmt, std::move(compare_rhs)); |
124 | modifier.insert_before(stmt, std::move(compare)); |
125 | modifier.insert_before(stmt, std::move(assert_stmt)); |
126 | set_done(stmt); |
127 | } |
128 | } |
129 | } |
130 | |
131 | static bool run(IRNode *node, |
132 | const CompileConfig &config, |
133 | const std::string &kernel_name) { |
134 | CheckOutOfBound checker(kernel_name); |
135 | bool modified = false; |
136 | while (true) { |
137 | node->accept(&checker); |
138 | if (checker.modifier.modify_ir()) { |
139 | modified = true; |
140 | } else { |
141 | break; |
142 | } |
143 | } |
144 | if (modified) |
145 | irpass::type_check(node, config); |
146 | return modified; |
147 | } |
148 | }; |
149 | |
150 | const PassID CheckOutOfBoundPass::id = "CheckOutOfBoundPass" ; |
151 | |
152 | namespace irpass { |
153 | |
154 | bool check_out_of_bound(IRNode *root, |
155 | const CompileConfig &config, |
156 | const CheckOutOfBoundPass::Args &args) { |
157 | TI_AUTO_PROF; |
158 | return CheckOutOfBound::run(root, config, args.kernel_name); |
159 | } |
160 | |
161 | } // namespace irpass |
162 | |
163 | } // namespace taichi::lang |
164 | |