1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file verify_memory.cc
22 * \brief Pass to check if memory accesses are legal.
23 */
24#include <tvm/ir/transform.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/target/target.h>
27#include <tvm/tir/analysis.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt_functor.h>
31
32namespace tvm {
33namespace tir {
34namespace {
35
36/*!
37 * \brief Verify if memory accesses are legal.
38 *
39 * In the case that tgt is cuda, if workload is not bound with
40 * threads, CPU code is generated that tries to access GPU memory,
41 * which is illegal.
42 *
43 * This pass performs such verification by checking if all
44 * memory accesses are bound with threads when device type is GPU.
45 */
46class MemoryAccessVerifier final : protected StmtExprVisitor {
47 public:
48 /// Special member functions
49 //@{
50 explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {}
51 virtual ~MemoryAccessVerifier() = default;
52 MemoryAccessVerifier(const MemoryAccessVerifier&) = delete;
53 MemoryAccessVerifier(MemoryAccessVerifier&&) = delete;
54 MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete;
55 MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete;
56 //@}
57
58 /// Interface to perform memory access verification
59 void Run() {
60 if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return;
61 StmtExprVisitor::VisitStmt(func_->body);
62 }
63
64 /// Verification result
65 std::vector<String> Errors() const { return errs_; }
66
67 protected:
68 /// Visitor implementation
69 //@{
70 void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }
71
72 void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }
73
74 void VisitStmt_(const LetStmtNode* op) final {
75 // Book keep definitions
76 defs_[op->var.get()] = op->value;
77 return StmtExprVisitor::VisitStmt_(op);
78 }
79
80 void VisitStmt_(const AttrStmtNode* op) final {
81 if (!InThreadEnv() &&
82 (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) {
83 EnterThreadEnv();
84 StmtExprVisitor::VisitStmt_(op);
85 ExitThreadEnv();
86 } else {
87 StmtExprVisitor::VisitStmt_(op);
88 }
89 }
90
91 void VisitExpr_(const LoadNode* op) final {
92 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
93 }
94
95 void VisitStmt_(const StoreNode* op) final {
96 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
97 }
98
99 void VisitExpr_(const BufferLoadNode* op) final {
100 HandleLoadStoreToVariable(op->buffer->data);
101 return StmtExprVisitor::VisitExpr_(op);
102 }
103
104 void VisitStmt_(const BufferStoreNode* op) final {
105 HandleLoadStoreToVariable(op->buffer->data);
106 return StmtExprVisitor::VisitStmt_(op);
107 }
108 //@}
109
110 /// Check if the value of a Variable comes from function argument.
111 bool IsFromFunctionArgs(const VarNode* var) const {
112 const VarNode* V = var;
113 for (auto kv : func_->buffer_map) {
114 if (V == kv.second->data.get()) return true;
115 }
116
117 while (true) {
118 // Variable is from function args. Return true.
119 if (V == func_->params[0].get()) return true;
120
121 // The value is expected to come from a tvm_struct_get Call.
122 // Get the first argument of tvm_struct_get, and continue.
123 const auto& iter = defs_.find(V);
124 if (iter == defs_.end()) return false;
125 const CallNode* C = iter->second.as<const CallNode>();
126 if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false;
127 V = C->args[0].as<VarNode>();
128 }
129 return false;
130 }
131
132 /// Handle memory access to a Variable
133 void HandleLoadStoreToVariable(const Var& var) {
134 // We skip the access within thread env.
135 if (InThreadEnv()) return;
136
137 // We only handle the variable from function argument.
138 // If it does not come from args, then it could be allocated internally,
139 // it may possibly be in host or device address space.
140 // We do not handle this case, and skip it conservatively.
141 if (!IsFromFunctionArgs(var.get())) return;
142
143 // The verification fails in this case.
144 std::stringstream s;
145 s << "Variable `" << var
146 << "` is directly accessed by host memory (it is not contained in a thread environment or in "
147 "the function arguments.";
148 errs_.push_back(s.str());
149 }
150
151 /// Status getter/setter
152 //@{
153 bool InThreadEnv() const { return in_thread_env_; }
154 void EnterThreadEnv() { in_thread_env_ = true; }
155 void ExitThreadEnv() { in_thread_env_ = false; }
156 //@}
157
158 /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
159 static bool IsGPUDevice(int dev_type) {
160 return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
161 kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type;
162 }
163 /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
164 static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; }
165
166 private:
167 /// Status of visitor
168 //@{
169 bool in_thread_env_{false};
170 std::vector<String> errs_;
171 //@}
172 tir::PrimFunc func_{nullptr}; ///< Function to be verified.
173 int dev_type_{kDLCPU}; ///< Device type
174 std::unordered_map<const VarNode*, PrimExpr> defs_; ///< Variable definitions
175};
176} // namespace
177
178/// Interface of VerifyMemory pass
179std::vector<String> VerifyMemory_(const PrimFunc& func) {
180 auto target = func->GetAttr<Target>(tvm::attr::kTarget);
181 ICHECK(target.defined()) << "VerifyMemory: Require the target attribute";
182
183 VLOG(1) << "verifying memory for target '" << target.value()->str()
184 << "' for primitive:" << std::endl
185 << func;
186
187 if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
188 CallingConv::kDefault) {
189 MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
190 v.Run();
191 return v.Errors();
192 } else {
193 return {};
194 }
195}
196
197bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; }
198
199TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);
200
201namespace transform {
202
203Pass VerifyMemory() {
204 auto pass_func = [=](IRModule mod, PassContext ctx) {
205 for (auto kv : mod->functions) {
206 if (auto* n = kv.second.as<PrimFuncNode>()) {
207 auto func = GetRef<PrimFunc>(n);
208 auto errs = VerifyMemory_(func);
209 if (errs.size() > 0) {
210 std::stringstream s;
211 for (auto& err : errs) {
212 s << " " << err << "\n";
213 }
214 LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
215 << s.str() << " Did you forget to bind?\n"
216 << func;
217 }
218 }
219 }
220 return mod;
221 };
222 return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
223}
224
225TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory);
226
227} // namespace transform
228} // namespace tir
229} // namespace tvm
230