1#include "taichi/ir/ir.h"
2#include "taichi/ir/statements.h"
3#include "taichi/ir/transforms.h"
4#include "taichi/ir/visitors.h"
5#include "taichi/transforms/check_out_of_bound.h"
6#include "taichi/transforms/utils.h"
7#include <set>
8
9namespace taichi::lang {
10
11// TODO: also check RangeAssumptionStmt
12
13class CheckOutOfBound : public BasicStmtVisitor {
14 public:
15 using BasicStmtVisitor::visit;
16 std::set<int> visited;
17 DelayedIRModifier modifier;
18 std::string kernel_name;
19
20 explicit CheckOutOfBound(const std::string &kernel_name)
21 : kernel_name(kernel_name) {
22 }
23
24 bool is_done(Stmt *stmt) {
25 return visited.find(stmt->instance_id) != visited.end();
26 }
27
28 void set_done(Stmt *stmt) {
29 visited.insert(stmt->instance_id);
30 }
31
32 void visit(SNodeOpStmt *stmt) override {
33 if (stmt->ptr != nullptr) {
34 TI_ASSERT(stmt->ptr->is<GlobalPtrStmt>());
35 // We have already done the check on its ptr argument. No need to do
36 // anything here.
37 return;
38 }
39
40 // TODO: implement bound check here for other situations.
41 }
42
43 void visit(GlobalPtrStmt *stmt) override {
44 if (is_done(stmt))
45 return;
46 auto snode = stmt->snode;
47 bool has_offset = !(snode->index_offsets.empty());
48 auto new_stmts = VecStatement();
49 auto zero = new_stmts.push_back<ConstStmt>(TypedConstant(0));
50 Stmt *result = new_stmts.push_back<ConstStmt>(TypedConstant(true));
51
52 std::string msg =
53 fmt::format("(kernel={}) Accessing field ({}) of size (", kernel_name,
54 snode->get_node_type_name_hinted());
55 std::string offset_msg = "offset (";
56 std::vector<Stmt *> args;
57 for (int i = 0; i < stmt->indices.size(); i++) {
58 int offset_i = has_offset ? snode->index_offsets[i] : 0;
59
60 // Note that during lower_ast, index arguments to GlobalPtrStmt are
61 // already converted to [0, +inf) range.
62
63 auto lower_bound = zero;
64 auto check_lower_bound = new_stmts.push_back<BinaryOpStmt>(
65 BinaryOpType::cmp_ge, stmt->indices[i], lower_bound);
66 int size_i = snode->shape_along_axis(i);
67 int upper_bound_i = size_i;
68 auto upper_bound =
69 new_stmts.push_back<ConstStmt>(TypedConstant(upper_bound_i));
70 auto check_upper_bound = new_stmts.push_back<BinaryOpStmt>(
71 BinaryOpType::cmp_lt, stmt->indices[i], upper_bound);
72 auto check_i = new_stmts.push_back<BinaryOpStmt>(
73 BinaryOpType::bit_and, check_lower_bound, check_upper_bound);
74 result = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::bit_and, result,
75 check_i);
76 if (i > 0) {
77 msg += ", ";
78 offset_msg += ", ";
79 }
80 msg += std::to_string(size_i);
81 offset_msg += std::to_string(offset_i);
82
83 auto input_index = stmt->indices[i];
84 if (offset_i != 0) {
85 auto offset = new_stmts.push_back<ConstStmt>(TypedConstant(offset_i));
86 input_index = new_stmts.push_back<BinaryOpStmt>(BinaryOpType::add,
87 input_index, offset);
88 }
89 args.emplace_back(input_index);
90 }
91 offset_msg += ") ";
92 msg += ") " + (has_offset ? offset_msg : "") + "with indices (";
93 for (int i = 0; i < stmt->indices.size(); i++) {
94 if (i > 0)
95 msg += ", ";
96 msg += "%d";
97 }
98 msg += ")";
99 msg += "\n" + stmt->tb;
100
101 new_stmts.push_back<AssertStmt>(result, msg, args);
102 modifier.insert_before(stmt, std::move(new_stmts));
103 set_done(stmt);
104 }
105
106 void visit(BinaryOpStmt *stmt) override {
107 // Insert assertions if debug is on
108 if (is_done(stmt)) {
109 return;
110 }
111 if (stmt->op_type == BinaryOpType::pow) {
112 if (is_integral(stmt->rhs->ret_type) &&
113 is_integral(stmt->lhs->ret_type)) {
114 auto compare_rhs = Stmt::make<ConstStmt>(TypedConstant(0));
115 auto compare = std::make_unique<BinaryOpStmt>(
116 BinaryOpType::cmp_ge, stmt->rhs, compare_rhs.get());
117 compare->ret_type = PrimitiveType::i32;
118 std::string msg = "Negative exponent in pow(int, int) is not allowed.";
119 msg += "\n" + stmt->tb;
120 auto assert_stmt = std::make_unique<AssertStmt>(compare.get(), msg,
121 std::vector<Stmt *>());
122 assert_stmt->accept(this);
123 modifier.insert_before(stmt, std::move(compare_rhs));
124 modifier.insert_before(stmt, std::move(compare));
125 modifier.insert_before(stmt, std::move(assert_stmt));
126 set_done(stmt);
127 }
128 }
129 }
130
131 static bool run(IRNode *node,
132 const CompileConfig &config,
133 const std::string &kernel_name) {
134 CheckOutOfBound checker(kernel_name);
135 bool modified = false;
136 while (true) {
137 node->accept(&checker);
138 if (checker.modifier.modify_ir()) {
139 modified = true;
140 } else {
141 break;
142 }
143 }
144 if (modified)
145 irpass::type_check(node, config);
146 return modified;
147 }
148};
149
150const PassID CheckOutOfBoundPass::id = "CheckOutOfBoundPass";
151
152namespace irpass {
153
154bool check_out_of_bound(IRNode *root,
155 const CompileConfig &config,
156 const CheckOutOfBoundPass::Args &args) {
157 TI_AUTO_PROF;
158 return CheckOutOfBound::run(root, config, args.kernel_name);
159}
160
161} // namespace irpass
162
163} // namespace taichi::lang
164