1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file verify_memory.cc |
22 | * \brief Pass to check if memory accesses are legal. |
23 | */ |
24 | #include <tvm/ir/transform.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/target/target.h> |
27 | #include <tvm/tir/analysis.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | namespace { |
35 | |
36 | /*! |
37 | * \brief Verify if memory accesses are legal. |
38 | * |
39 | * In the case that tgt is cuda, if workload is not bound with |
40 | * threads, CPU code is generated that tries to access GPU memory, |
41 | * which is illegal. |
42 | * |
43 | * This pass performs such verification by checking if all |
44 | * memory accesses are bound with threads when device type is GPU. |
45 | */ |
46 | class MemoryAccessVerifier final : protected StmtExprVisitor { |
47 | public: |
48 | /// Special member functions |
49 | //@{ |
50 | explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} |
51 | virtual ~MemoryAccessVerifier() = default; |
52 | MemoryAccessVerifier(const MemoryAccessVerifier&) = delete; |
53 | MemoryAccessVerifier(MemoryAccessVerifier&&) = delete; |
54 | MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete; |
55 | MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete; |
56 | //@} |
57 | |
58 | /// Interface to perform memory access verification |
59 | void Run() { |
60 | if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return; |
61 | StmtExprVisitor::VisitStmt(func_->body); |
62 | } |
63 | |
64 | /// Verification result |
65 | std::vector<String> Errors() const { return errs_; } |
66 | |
67 | protected: |
68 | /// Visitor implementation |
69 | //@{ |
70 | void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); } |
71 | |
72 | void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); } |
73 | |
74 | void VisitStmt_(const LetStmtNode* op) final { |
75 | // Book keep definitions |
76 | defs_[op->var.get()] = op->value; |
77 | return StmtExprVisitor::VisitStmt_(op); |
78 | } |
79 | |
80 | void VisitStmt_(const AttrStmtNode* op) final { |
81 | if (!InThreadEnv() && |
82 | (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { |
83 | EnterThreadEnv(); |
84 | StmtExprVisitor::VisitStmt_(op); |
85 | ExitThreadEnv(); |
86 | } else { |
87 | StmtExprVisitor::VisitStmt_(op); |
88 | } |
89 | } |
90 | |
91 | void VisitExpr_(const LoadNode* op) final { |
92 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
93 | } |
94 | |
95 | void VisitStmt_(const StoreNode* op) final { |
96 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
97 | } |
98 | |
99 | void VisitExpr_(const BufferLoadNode* op) final { |
100 | HandleLoadStoreToVariable(op->buffer->data); |
101 | return StmtExprVisitor::VisitExpr_(op); |
102 | } |
103 | |
104 | void VisitStmt_(const BufferStoreNode* op) final { |
105 | HandleLoadStoreToVariable(op->buffer->data); |
106 | return StmtExprVisitor::VisitStmt_(op); |
107 | } |
108 | //@} |
109 | |
110 | /// Check if the value of a Variable comes from function argument. |
111 | bool IsFromFunctionArgs(const VarNode* var) const { |
112 | const VarNode* V = var; |
113 | for (auto kv : func_->buffer_map) { |
114 | if (V == kv.second->data.get()) return true; |
115 | } |
116 | |
117 | while (true) { |
118 | // Variable is from function args. Return true. |
119 | if (V == func_->params[0].get()) return true; |
120 | |
121 | // The value is expected to come from a tvm_struct_get Call. |
122 | // Get the first argument of tvm_struct_get, and continue. |
123 | const auto& iter = defs_.find(V); |
124 | if (iter == defs_.end()) return false; |
125 | const CallNode* C = iter->second.as<const CallNode>(); |
126 | if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false; |
127 | V = C->args[0].as<VarNode>(); |
128 | } |
129 | return false; |
130 | } |
131 | |
132 | /// Handle memory access to a Variable |
133 | void HandleLoadStoreToVariable(const Var& var) { |
134 | // We skip the access within thread env. |
135 | if (InThreadEnv()) return; |
136 | |
137 | // We only handle the variable from function argument. |
138 | // If it does not come from args, then it could be allocated internally, |
139 | // it may possibly be in host or device address space. |
140 | // We do not handle this case, and skip it conservatively. |
141 | if (!IsFromFunctionArgs(var.get())) return; |
142 | |
143 | // The verification fails in this case. |
144 | std::stringstream s; |
145 | s << "Variable `" << var |
146 | << "` is directly accessed by host memory (it is not contained in a thread environment or in " |
147 | "the function arguments." ; |
148 | errs_.push_back(s.str()); |
149 | } |
150 | |
151 | /// Status getter/setter |
152 | //@{ |
153 | bool InThreadEnv() const { return in_thread_env_; } |
154 | void EnterThreadEnv() { in_thread_env_ = true; } |
155 | void ExitThreadEnv() { in_thread_env_ = false; } |
156 | //@} |
157 | |
158 | /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. |
159 | static bool IsGPUDevice(int dev_type) { |
160 | return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || |
161 | kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type; |
162 | } |
163 | /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device. |
164 | static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; } |
165 | |
166 | private: |
167 | /// Status of visitor |
168 | //@{ |
169 | bool in_thread_env_{false}; |
170 | std::vector<String> errs_; |
171 | //@} |
172 | tir::PrimFunc func_{nullptr}; ///< Function to be verified. |
173 | int dev_type_{kDLCPU}; ///< Device type |
174 | std::unordered_map<const VarNode*, PrimExpr> defs_; ///< Variable definitions |
175 | }; |
176 | } // namespace |
177 | |
178 | /// Interface of VerifyMemory pass |
179 | std::vector<String> VerifyMemory_(const PrimFunc& func) { |
180 | auto target = func->GetAttr<Target>(tvm::attr::kTarget); |
181 | ICHECK(target.defined()) << "VerifyMemory: Require the target attribute" ; |
182 | |
183 | VLOG(1) << "verifying memory for target '" << target.value()->str() |
184 | << "' for primitive:" << std::endl |
185 | << func; |
186 | |
187 | if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == |
188 | CallingConv::kDefault) { |
189 | MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType()); |
190 | v.Run(); |
191 | return v.Errors(); |
192 | } else { |
193 | return {}; |
194 | } |
195 | } |
196 | |
197 | bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } |
198 | |
199 | TVM_REGISTER_GLOBAL("tir.analysis.verify_memory" ).set_body_typed(VerifyMemory); |
200 | |
201 | namespace transform { |
202 | |
203 | Pass VerifyMemory() { |
204 | auto pass_func = [=](IRModule mod, PassContext ctx) { |
205 | for (auto kv : mod->functions) { |
206 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
207 | auto func = GetRef<PrimFunc>(n); |
208 | auto errs = VerifyMemory_(func); |
209 | if (errs.size() > 0) { |
210 | std::stringstream s; |
211 | for (auto& err : errs) { |
212 | s << " " << err << "\n" ; |
213 | } |
214 | LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n" |
215 | << s.str() << " Did you forget to bind?\n" |
216 | << func; |
217 | } |
218 | } |
219 | } |
220 | return mod; |
221 | }; |
222 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory" , {}); |
223 | } |
224 | |
225 | TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory" ).set_body_typed(VerifyMemory); |
226 | |
227 | } // namespace transform |
228 | } // namespace tir |
229 | } // namespace tvm |
230 | |