1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file bounds_checker.cc
22 */
23// Instrument checkers for out of the bounds access.
24
25#include <tvm/arith/analyzer.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/stmt_functor.h>
31#include <tvm/tir/transform.h>
32
33#include <unordered_map>
34#include <utility>
35#include <vector>
36
37namespace tvm {
38namespace tir {
39
40// TODO(Lunderberg): Move this pass to be before
41// StorageFlatten/FlattenBuffer. That will simplify this pass,
42// because it can check directly against the buffer limits.
43class BoundCollector : public StmtVisitor {
44 public:
45 BoundCollector() {}
46
47 void VisitStmt_(const AttrStmtNode* op) final {
48 if (op->attr_key == tir::attr::buffer_bound) {
49 const VarNode* key = op->node.as<VarNode>();
50 const CallNode* container = op->value.as<CallNode>();
51 if (key && container) {
52 mem_to_shape[key] = container->args;
53 }
54 }
55 StmtVisitor::VisitStmt_(op);
56 }
57 // Hashtable which maps buffer_var to shape.
58 std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape;
59};
60
61class BoundChecker : public StmtExprMutator {
62 public:
63 explicit BoundChecker(const std::unordered_map<const VarNode*, Array<PrimExpr>>& mem_to_shape)
64 : mem_to_shape_(mem_to_shape) {}
65
66 Stmt VisitStmt_(const AllocateNode* op) final {
67 // If the shape was updated we should update the hashtable.
68 if (UpdateIsNeeded(op->buffer_var)) {
69 Update(op->buffer_var, op->extents, op->dtype);
70 }
71 return StmtExprMutator::VisitStmt_(op);
72 }
73
74 PrimExpr VisitExpr_(const CallNode* op) final {
75 if (process_store_ && op->op.same_as(builtin::if_then_else())) {
76 unsafe_rewritten_ = true;
77 }
78 return StmtExprMutator::VisitExpr_(op);
79 }
80
81 PrimExpr VisitExpr_(const LoadNode* op) final {
82 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
83 }
84
85 Stmt VisitStmt_(const StoreNode* op) final {
86 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
87 }
88
89 Stmt VisitStmt_(const BufferStoreNode* op) final {
90 store_scope_bound_collector_.clear();
91 process_store_ = true;
92 unsafe_rewritten_ = false;
93 StmtExprMutator::VisitStmt_(op);
94 process_store_ = false;
95 if (CanInstrument(op->indices, op->buffer->data)) {
96 Collect(op->indices, op->buffer->data);
97 }
98 // The collector should has at least one item.
99 if (store_scope_bound_collector_.size()) {
100 PrimExpr condition = MakeCondition();
101 if (!condition.as<StringImmNode>()) {
102 Stmt nop = Evaluate(1);
103 Stmt then_case = GetRef<Stmt>(op);
104 Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop);
105 Stmt body = IfThenElse(condition, then_case, else_case);
106 return body;
107 }
108 }
109 return GetRef<Stmt>(op);
110 }
111
112 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
113 if (CanInstrument(op->indices, op->buffer->data)) {
114 Collect(op->indices, op->buffer->data);
115 }
116 return StmtExprMutator::VisitExpr_(op);
117 }
118
119 private:
120 bool UpdateIsNeeded(const Var& buffer_var) const {
121 return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get()));
122 }
123
124 void Update(const Var& buffer_var, Array<PrimExpr> new_shape, const DataType& type) {
125 // Sanity check at first.
126 if (!ShapeIsValid(new_shape)) {
127 return;
128 }
129
130 new_shape.MutateByApply([&](const PrimExpr& dim) {
131 // Cast to uint64 to avoid potential overflow.
132 return make_const(DataType::UInt(64), type.lanes()) * dim;
133 });
134 mem_to_shape_[buffer_var.get()] = new_shape;
135 }
136
137 bool ShapeIsValid(const Array<PrimExpr>& shape) const {
138 if (!shape.defined()) {
139 return false;
140 }
141 for (const auto& dim : shape) {
142 if (!IsValidScalar(dim) || is_negative_const(dim)) {
143 return false;
144 }
145 }
146
147 return true;
148 }
149
150 bool IndicesAreValid(const Array<PrimExpr>& indices) const {
151 if (!indices.defined()) {
152 return false;
153 }
154
155 for (const auto& index : indices) {
156 if (!index.defined()) {
157 return false;
158 }
159
160 if (const RampNode* ramp_index = index.as<RampNode>()) {
161 if (!IsValidScalar(ramp_index->base)) {
162 return false;
163 }
164 if (!IsValidScalar(ramp_index->stride)) {
165 return false;
166 }
167 if (ramp_index->lanes <= 0) {
168 return false;
169 }
170 }
171 }
172 return true;
173 }
174
175 bool IsValidScalar(const PrimExpr& expr) const {
176 return expr.defined() && expr.dtype().is_scalar();
177 }
178
179 bool CanInstrument(const Array<PrimExpr>& indices, const Var& buffer_var) const {
180 return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) &&
181 IndicesAreValid(indices) && !unsafe_rewritten_;
182 }
183
184 void Collect(Array<PrimExpr> indices, Var buffer_var) {
185 store_scope_bound_collector_.push_back(
186 std::make_pair(indices, mem_to_shape_[buffer_var.get()]));
187 }
188
189 PrimExpr MakeCondition() {
190 PrimExpr condition;
191 for (const auto& pair : store_scope_bound_collector_) {
192 Array<PrimExpr> indices = pair.first;
193 Array<PrimExpr> shape = pair.second;
194
195 ICHECK_EQ(indices.size(), shape.size())
196 << "Mismatch between dimension of physical shape and physical indices";
197
198 for (size_t i = 0; i < indices.size(); i++) {
199 PrimExpr index = indices[i];
200 PrimExpr upper_bound = shape[i];
201
202 if (const RampNode* ramp_index = index.as<RampNode>()) {
203 // In case index is base + stride * i.
204 // Non inclusive range.
205 index = Add(ramp_index->base,
206 Mul(ramp_index->stride,
207 make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1)));
208 }
209
210 // Try to simplify index and bound.
211 index = analyzer_.Simplify(index);
212 upper_bound = analyzer_.Simplify(upper_bound);
213
214 // Cast to the same type - signed, to be able to check lower bound.
215 index = Cast(DataType::Int(64), index);
216 upper_bound = Cast(DataType::Int(64), upper_bound);
217
218 // Looks like a lower bound should always be zero after normalization.
219 PrimExpr lower_bound = make_zero(DataType::Int(64));
220
221 PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound));
222 condition = condition.defined() ? And(condition, current_condition) : current_condition;
223 }
224 }
225 return condition;
226 }
227
228 // Whether we process store value recursively.
229 bool process_store_{false};
230 // Whether we face tvm_if_then_else intrinsic.
231 bool unsafe_rewritten_{false};
232 // Pool which collects the pair of index and shape for specific store/load.
233 std::vector<std::pair<Array<PrimExpr>, Array<PrimExpr>>> store_scope_bound_collector_;
234 // Error message.
235 const char* const error_message_ = "OUT OF THE BOUNDS";
236 // Hashtable which maps buffer_var to shape.
237 std::unordered_map<const VarNode*, Array<PrimExpr>> mem_to_shape_;
238 // internal analyzer
239 arith::Analyzer analyzer_;
240};
241
242Stmt InstrumentBoundCheckers(Stmt stmt) {
243 BoundCollector bound_collector;
244 // At first walk recursively and collect bound attributes.
245 bound_collector(stmt);
246 return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
247}
248
249namespace transform {
250
251Pass InstrumentBoundCheckers() {
252 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
253 auto* n = f.CopyOnWrite();
254 BoundCollector bound_collector;
255 // At first walk recursively and collect bound attributes.
256 bound_collector(n->body);
257 n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
258 return f;
259 };
260 return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
261}
262
263TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
264 .set_body_typed(InstrumentBoundCheckers);
265
266} // namespace transform
267
268} // namespace tir
269} // namespace tvm
270