1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file bounds_checker.cc |
22 | */ |
23 | // Instrument checkers for out of the bounds access. |
24 | |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | #include <tvm/tir/transform.h> |
32 | |
33 | #include <unordered_map> |
34 | #include <utility> |
35 | #include <vector> |
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | |
40 | // TODO(Lunderberg): Move this pass to be before |
41 | // StorageFlatten/FlattenBuffer. That will simplify this pass, |
42 | // because it can check directly against the buffer limits. |
43 | class BoundCollector : public StmtVisitor { |
44 | public: |
  BoundCollector() = default;
46 | |
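  // Record shape annotations of the form emitted earlier in lowering
  // (e.g. by StorageFlatten), roughly:
  //   // attr [buf] buffer_bound = tvm_tuple(dim0, dim1, ...)
  // mapping the buffer variable to its pre-flattening shape.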
47 | void VisitStmt_(const AttrStmtNode* op) final { |
48 | if (op->attr_key == tir::attr::buffer_bound) { |
49 | const VarNode* key = op->node.as<VarNode>(); |
50 | const CallNode* container = op->value.as<CallNode>(); |
51 | if (key && container) { |
52 | mem_to_shape[key] = container->args; |
53 | } |
54 | } |
55 | StmtVisitor::VisitStmt_(op); |
56 | } |
57 | // Hashtable which maps buffer_var to shape. |
58 | std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape; |
59 | }; |
60 | |
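// Rewrites each buffer store so it is guarded by a bounds predicate.
// Sketch of the emitted IR for a store A[i] = v with collected bound n:
//   if (i >= 0 && i < n) {
//     A[i] = v
//   } else {
//     assert(i >= 0 && i < n, "OUT OF THE BOUNDS")
//   }
// Loads inside the store are folded into the same predicate.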
61 | class BoundChecker : public StmtExprMutator { |
62 | public: |
63 | explicit BoundChecker(const std::unordered_map<const VarNode*, Array<PrimExpr>>& mem_to_shape) |
64 | : mem_to_shape_(mem_to_shape) {} |
65 | |
66 | Stmt VisitStmt_(const AllocateNode* op) final { |
    // If the allocation redefines a tracked buffer, update its recorded shape.
68 | if (UpdateIsNeeded(op->buffer_var)) { |
69 | Update(op->buffer_var, op->extents, op->dtype); |
70 | } |
71 | return StmtExprMutator::VisitStmt_(op); |
72 | } |
73 | |
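  // A tvm_if_then_else inside the store value may intentionally guard an
  // out-of-range access (e.g. boundary padding), so instrumenting such a
  // store could raise false alarms; mark it unsafe and skip it.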
74 | PrimExpr VisitExpr_(const CallNode* op) final { |
75 | if (process_store_ && op->op.same_as(builtin::if_then_else())) { |
76 | unsafe_rewritten_ = true; |
77 | } |
78 | return StmtExprMutator::VisitExpr_(op); |
79 | } |
80 | |
81 | PrimExpr VisitExpr_(const LoadNode* op) final { |
82 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
83 | } |
84 | |
85 | Stmt VisitStmt_(const StoreNode* op) final { |
86 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
87 | } |
88 | |
89 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
90 | store_scope_bound_collector_.clear(); |
91 | process_store_ = true; |
92 | unsafe_rewritten_ = false; |
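    // Visit the store only to detect tvm_if_then_else in its value; the
    // mutation result is intentionally discarded.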
93 | StmtExprMutator::VisitStmt_(op); |
94 | process_store_ = false; |
95 | if (CanInstrument(op->indices, op->buffer->data)) { |
96 | Collect(op->indices, op->buffer->data); |
97 | } |
    // The collector should have at least one item.
    if (!store_scope_bound_collector_.empty()) {
100 | PrimExpr condition = MakeCondition(); |
101 | if (!condition.as<StringImmNode>()) { |
102 | Stmt nop = Evaluate(1); |
103 | Stmt then_case = GetRef<Stmt>(op); |
104 | Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); |
105 | Stmt body = IfThenElse(condition, then_case, else_case); |
106 | return body; |
107 | } |
108 | } |
109 | return GetRef<Stmt>(op); |
110 | } |
111 | |
112 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
113 | if (CanInstrument(op->indices, op->buffer->data)) { |
114 | Collect(op->indices, op->buffer->data); |
115 | } |
116 | return StmtExprMutator::VisitExpr_(op); |
117 | } |
118 | |
119 | private: |
120 | bool UpdateIsNeeded(const Var& buffer_var) const { |
121 | return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); |
122 | } |
123 | |
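  // Record the flattened bound of an allocation. Vectorized buffers are
  // indexed per lane after flattening, so, e.g., a float32x4 buffer with
  // extent n is tracked with bound 4 * n.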
124 | void Update(const Var& buffer_var, Array<PrimExpr> new_shape, const DataType& type) { |
    // Sanity-check the shape first.
126 | if (!ShapeIsValid(new_shape)) { |
127 | return; |
128 | } |
129 | |
130 | new_shape.MutateByApply([&](const PrimExpr& dim) { |
      // Scale by the number of vector lanes; use uint64 arithmetic to avoid overflow.
132 | return make_const(DataType::UInt(64), type.lanes()) * dim; |
133 | }); |
134 | mem_to_shape_[buffer_var.get()] = new_shape; |
135 | } |
136 | |
137 | bool ShapeIsValid(const Array<PrimExpr>& shape) const { |
138 | if (!shape.defined()) { |
139 | return false; |
140 | } |
141 | for (const auto& dim : shape) { |
142 | if (!IsValidScalar(dim) || is_negative_const(dim)) { |
143 | return false; |
144 | } |
145 | } |
146 | |
147 | return true; |
148 | } |
149 | |
150 | bool IndicesAreValid(const Array<PrimExpr>& indices) const { |
151 | if (!indices.defined()) { |
152 | return false; |
153 | } |
154 | |
155 | for (const auto& index : indices) { |
156 | if (!index.defined()) { |
157 | return false; |
158 | } |
159 | |
160 | if (const RampNode* ramp_index = index.as<RampNode>()) { |
161 | if (!IsValidScalar(ramp_index->base)) { |
162 | return false; |
163 | } |
164 | if (!IsValidScalar(ramp_index->stride)) { |
165 | return false; |
166 | } |
167 | if (ramp_index->lanes <= 0) { |
168 | return false; |
169 | } |
170 | } |
171 | } |
172 | return true; |
173 | } |
174 | |
175 | bool IsValidScalar(const PrimExpr& expr) const { |
176 | return expr.defined() && expr.dtype().is_scalar(); |
177 | } |
178 | |
179 | bool CanInstrument(const Array<PrimExpr>& indices, const Var& buffer_var) const { |
180 | return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && |
181 | IndicesAreValid(indices) && !unsafe_rewritten_; |
182 | } |
183 | |
184 | void Collect(Array<PrimExpr> indices, Var buffer_var) { |
185 | store_scope_bound_collector_.push_back( |
186 | std::make_pair(indices, mem_to_shape_[buffer_var.get()])); |
187 | } |
188 | |
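  // Conjoin one range check per collected index, e.g. for indices (i, j)
  // with shape (n, m) the result is roughly:
  //   i >= 0 && i < n && j >= 0 && j < m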
189 | PrimExpr MakeCondition() { |
190 | PrimExpr condition; |
191 | for (const auto& pair : store_scope_bound_collector_) { |
192 | Array<PrimExpr> indices = pair.first; |
193 | Array<PrimExpr> shape = pair.second; |
194 | |
195 | ICHECK_EQ(indices.size(), shape.size()) |
196 | << "Mismatch between dimension of physical shape and physical indices" ; |
197 | |
198 | for (size_t i = 0; i < indices.size(); i++) { |
199 | PrimExpr index = indices[i]; |
200 | PrimExpr upper_bound = shape[i]; |
201 | |
202 | if (const RampNode* ramp_index = index.as<RampNode>()) { |
          // The index is a vector: base + stride * lane. Check its last lane;
          // the upper-bound comparison below is exclusive.
205 | index = Add(ramp_index->base, |
206 | Mul(ramp_index->stride, |
207 | make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); |
208 | } |
209 | |
210 | // Try to simplify index and bound. |
211 | index = analyzer_.Simplify(index); |
212 | upper_bound = analyzer_.Simplify(upper_bound); |
213 | |
        // Cast both to the same signed type so the lower bound can be checked.
215 | index = Cast(DataType::Int(64), index); |
216 | upper_bound = Cast(DataType::Int(64), upper_bound); |
217 | |
        // After normalization the lower bound is expected to be zero.
219 | PrimExpr lower_bound = make_zero(DataType::Int(64)); |
220 | |
221 | PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); |
222 | condition = condition.defined() ? And(condition, current_condition) : current_condition; |
223 | } |
224 | } |
225 | return condition; |
226 | } |
227 | |
  // Whether we are currently mutating the value of a store.
  bool process_store_{false};
  // Whether the current store value contains a tvm_if_then_else intrinsic.
  bool unsafe_rewritten_{false};
  // Collected (indices, shape) pairs for the current store and its loads.
  std::vector<std::pair<Array<PrimExpr>, Array<PrimExpr>>> store_scope_bound_collector_;
  // Error message.
  const char* const error_message_ = "OUT OF THE BOUNDS";
236 | // Hashtable which maps buffer_var to shape. |
237 | std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape_; |
238 | // internal analyzer |
239 | arith::Analyzer analyzer_; |
240 | }; |
241 | |
242 | Stmt InstrumentBoundCheckers(Stmt stmt) { |
243 | BoundCollector bound_collector; |
  // First, walk the statement recursively and collect bound attributes.
245 | bound_collector(stmt); |
246 | return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); |
247 | } |
248 | |
249 | namespace transform { |
250 | |
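// A sketch of invoking the pass from Python through the registration below,
// assuming the standard tvm.tir.transform bindings:
//   mod = tvm.tir.transform.InstrumentBoundCheckers()(mod)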
251 | Pass InstrumentBoundCheckers() { |
252 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
253 | auto* n = f.CopyOnWrite(); |
254 | BoundCollector bound_collector; |
    // First, walk the body recursively and collect bound attributes.
256 | bound_collector(n->body); |
257 | n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body)); |
258 | return f; |
259 | }; |
  return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
261 | } |
262 | |
263 | TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers" ) |
264 | .set_body_typed(InstrumentBoundCheckers); |
265 | |
266 | } // namespace transform |
267 | |
268 | } // namespace tir |
269 | } // namespace tvm |
270 | |