1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Out of bounds array access static analyzer.
22 */
23
#include <tvm/tir/transform.h>

#include <sstream>
#include <utility>
#include <vector>

#include "../../arith/ir_visitor_with_analyzer.h"
#include "../schedule/error.h"
28
29namespace tvm {
30namespace tir {
31namespace transform {
/*!
 * \brief A single buffer access whose index was found to leave the buffer's bounds.
 *
 * NOTE: field order matters -- instances are created with aggregate (brace)
 * initialization in OOBCheckerVisitor::CheckBounds.
 */
struct OOBLocation {
  /*! \brief The buffer being accessed. */
  Buffer buf;
  /*! \brief Which dimension of `buf` the offending index addresses. */
  size_t dimension;
  /*! \brief The index expression used for that dimension. */
  ObjectRef index;
  /*! \brief Range of values the index may take, as computed by the analyzer. */
  arith::IntSet index_bounds;
  /*! \brief Range of the buffer's extent in `dimension`, as computed by the analyzer. */
  arith::IntSet shape_bounds;
};
39
40class OOBError : public ScheduleError {
41 public:
42 OOBError(IRModule mod, std::vector<OOBLocation> locations) : mod_(mod), locations_(locations) {}
43 String FastErrorString() const final { return "Out of bound memory access"; }
44
45 String DetailRenderTemplate() const final {
46 std::stringstream s;
47 for (const auto& oob : locations_) {
48 s << "Out of bounds memory access on buffer " << oob.buf->name << " dimension "
49 << oob.dimension << ".";
50 s << " index " << oob.index << " with bounds [" << oob.index_bounds.min() << ", "
51 << oob.index_bounds.max() << "] is outside the range [0, " << oob.shape_bounds.min()
52 << "].";
53 s << "\n";
54 }
55 return s.str();
56 }
57 IRModule mod() const final { return mod_; }
58 Array<ObjectRef> LocationsOfInterest() const final {
59 std::vector<ObjectRef> locs;
60 for (auto loc : locations_) {
61 locs.push_back(loc.index);
62 }
63 return locs;
64 }
65
66 private:
67 IRModule mod_;
68 std::vector<OOBLocation> locations_;
69};
70class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer {
71 using IRVisitorWithAnalyzer::VisitExpr_;
72 using IRVisitorWithAnalyzer::VisitStmt_;
73
74 public:
75 void VisitStmt_(const BufferStoreNode* node) final {
76 for (size_t i = 0; i < node->buffer->shape.size(); i++) {
77 CheckBounds(node, i);
78 }
79 IRVisitorWithAnalyzer::VisitStmt_(node);
80 }
81 void VisitExpr_(const BufferLoadNode* node) final {
82 for (size_t i = 0; i < node->buffer->shape.size(); i++) {
83 CheckBounds(node, i);
84 }
85 IRVisitorWithAnalyzer::VisitExpr_(node);
86 }
87
88 template <class T>
89 void CheckBounds(const T* node, size_t i) {
90 auto ind_bounds = analyzer_.int_set(node->indices[i]);
91 auto shape_bounds = analyzer_.int_set(node->buffer->shape[i]);
92 // We would expect that
93 // `analyzer_.CanProve(node->indices[i] < 0 || node->indices[i] >= node->buffer->shape[i])`
94 // would be the way to check if any out of bounds access occurs here, but `CanProve` checks if
95 // the statement is true for all possible values (universal quantification). For a mix of in
96 // bounds and out of bounds access, no out of bounds access would be reported. We instead want
97 // to check if there is any value for which the access is out of bounds (existential
98 // quantification).
99 // An solution would be to check that the index is in bounds for every possible value. This
100 // has the problem that some valid access patterns maybe be valid but not provably valid. We
101 // prefer that this analysis is conservative and only shows errors that are provable. This leads
102 // us to the following check: are the bounds of the index outside the bounds of the shape.
103 if (analyzer_.CanProve(ind_bounds.max() >= shape_bounds.min()) ||
104 analyzer_.CanProve(ind_bounds.min() < 0)) {
105 errors.push_back({node->buffer, i, node->indices[i], ind_bounds, shape_bounds});
106 }
107 }
108
109 std::vector<OOBLocation> errors;
110};
111
112transform::Pass OOBChecker() {
113 auto pass_func = [=](tir::PrimFunc func, IRModule mod, transform::PassContext ctx) {
114 OOBCheckerVisitor checker;
115 checker(func->body);
116 if (checker.errors.size() > 0) {
117 // mod doesn't contain our function, so we construct a new mod with out function
118 IRModule func_mod({{GlobalVar("main"), func}});
119 LOG(FATAL) << OOBError(func_mod, checker.errors).RenderReport("Out of bounds checker");
120 }
121 return func;
122 };
123 return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {});
124}
125
126TVM_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker);
127} // namespace transform
128} // namespace tir
129} // namespace tvm
130