1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file verify_gpu_code.cc |
22 | * \brief Verify the correctness of a GPU IR. |
 * It will check whether the amount of memory used or the number of threads
24 | * in a block exceeds the limit |
25 | */ |
26 | |
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "../transforms/ir_utils.h"
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | class GPUCodeVerifier : public StmtExprVisitor { |
40 | public: |
41 | std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block, |
42 | int64_t max_shared_memory_per_block, int64_t max_threads_per_block, |
43 | int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z, |
44 | int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) { |
45 | max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block); |
46 | max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block); |
47 | max_threads_per_block_ = static_cast<size_t>(max_threads_per_block); |
48 | max_thread_x_ = static_cast<size_t>(max_thread_x); |
49 | max_thread_y_ = static_cast<size_t>(max_thread_y); |
50 | max_thread_z_ = static_cast<size_t>(max_thread_z); |
51 | max_vthread_ = static_cast<size_t>(max_vthread); |
52 | max_vector_bytes_ = static_cast<size_t>(max_vector_bytes); |
53 | max_kernels_ = static_cast<size_t>(max_kernels); |
54 | Reset_(); |
55 | |
56 | // TODO(jcf94): Add support of detecting CUDA Misaligned Address error |
57 | this->VisitStmt(stmt); |
58 | |
59 | return errors_; |
60 | } |
61 | |
62 | void VisitStmt_(const AllocateNode* op) final { |
63 | StmtVisitor::VisitStmt_(op); |
64 | auto scope = GetPtrStorageScope(op->buffer_var); |
65 | runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope); |
66 | // visit an allocation of a buffer in shared memory, record its size |
67 | if (storage_scope.rank == runtime::StorageRank::kLocal) { |
68 | size_t size = static_cast<size_t>(op->ConstantAllocationSize()); |
69 | local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); |
70 | } else if (storage_scope.rank == runtime::StorageRank::kShared) { |
71 | size_t size = static_cast<size_t>(op->ConstantAllocationSize()); |
72 | shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); |
73 | } |
74 | if (op->dtype.lanes() > 1) { |
75 | if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { |
76 | std::stringstream s; |
77 | s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" |
78 | << op->dtype.bytes() << ") for dtype " << op->dtype |
79 | << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")" ; |
80 | errors_.push_back(s.str()); |
81 | } |
82 | } |
83 | } |
84 | |
85 | void VisitStmt_(const AttrStmtNode* op) final { |
86 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { |
87 | if (nest_level_ == 0) { |
88 | // enter a new kernel, reset statistics |
89 | Reset_(); |
90 | kernels_launched_++; |
91 | } |
92 | |
93 | Var var = op->node.as<IterVarNode>()->var; |
94 | const auto* extent = op->value.as<IntImmNode>(); |
95 | ICHECK(extent); |
96 | |
97 | std::string name = var.get()->name_hint; |
98 | // record the number of threads in a block |
99 | if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || |
100 | name == "vthread" ) { |
101 | size_t length = static_cast<size_t>(extent->value); |
102 | if (!visited_threads_.count(name)) { |
103 | visited_threads_.insert(name); |
104 | thread_per_block_ *= length; |
105 | |
106 | auto err = [this](std::string id, size_t ext, size_t m) { |
107 | if (ext > m) { |
108 | std::stringstream s; |
109 | s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m |
110 | << ");" ; |
111 | errors_.push_back(s.str()); |
112 | } |
113 | }; |
114 | |
115 | if (name == "threadIdx.x" ) { |
116 | err("threadIdx.x" , length, max_thread_x_); |
117 | thread_x_extent_ = length; |
118 | } else if (name == "threadIdx.y" ) { |
119 | err("threadIdx.y" , length, max_thread_y_); |
120 | thread_y_extent_ = length; |
121 | } else if (name == "threadIdx.z" ) { |
122 | err("threadIdx.z" , length, max_thread_z_); |
123 | thread_z_extent_ = length; |
124 | } else if (name == "vthread" ) { |
125 | err("vthread" , length, max_vthread_); |
126 | } |
127 | } else { |
128 | // the thread should be bound to axes with the same length |
129 | auto err = [this, name](std::string id, size_t ext, size_t m) { |
130 | if (name == id && ext != m) { |
131 | std::stringstream s; |
132 | s << "Extent of " << id << " (" << ext << ") does not match the bound " << m; |
133 | errors_.push_back(s.str()); |
134 | } |
135 | }; |
136 | err("threadIdx.x" , length, thread_x_extent_); |
137 | err("threadIdx.y" , length, thread_y_extent_); |
138 | err("threadIdx.z" , length, thread_z_extent_); |
139 | } |
140 | } |
141 | |
142 | nest_level_++; |
143 | StmtVisitor::VisitStmt_(op); |
144 | nest_level_--; |
145 | |
146 | if (nest_level_ == 0) { |
147 | // exit a kernel, check the validity |
148 | auto err = [this](std::string id, size_t num, size_t m) { |
149 | if (num > m) { |
150 | std::stringstream s; |
151 | s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m |
152 | << ")" ; |
153 | errors_.push_back(s.str()); |
154 | } |
155 | }; |
156 | err("threads per block" , thread_per_block_, max_threads_per_block_); |
157 | err("local memory per block" , local_memory_per_block_, max_local_memory_per_block_); |
158 | err("shared memory per block" , shared_memory_per_block_, max_shared_memory_per_block_); |
159 | |
160 | if (kernels_launched_ > max_kernels_) { |
161 | std::stringstream s; |
162 | s << "Number of launched kernels (" << kernels_launched_ |
163 | << ") is greater than the allowed maximum (" << max_kernels_ << ")" ; |
164 | errors_.push_back(s.str()); |
165 | } |
166 | } |
167 | } else { |
168 | StmtVisitor::VisitStmt_(op); |
169 | } |
170 | } |
171 | |
172 | void VisitStmt_(const ForNode* op) { |
173 | if (op->loop_var->name_hint == "vthread.s" ) { |
174 | const auto* extent = op->extent.as<IntImmNode>(); |
175 | ICHECK(extent); |
176 | |
177 | size_t num_vthread = static_cast<size_t>(extent->value); |
178 | if (num_vthread > max_vthread_) { |
179 | std::stringstream s; |
180 | s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum (" |
181 | << max_vthread_ << ")" ; |
182 | errors_.push_back(s.str()); |
183 | } |
184 | } |
185 | |
186 | StmtVisitor::VisitStmt_(op); |
187 | } |
188 | |
189 | void VisitExpr_(const LoadNode* op) final { |
190 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
191 | } |
192 | |
193 | void VisitStmt_(const StoreNode* op) final { |
194 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
195 | } |
196 | |
197 | void CheckBufferIndicesVectorizable(const Array<PrimExpr> indices) { |
198 | for (const auto index : indices) { |
199 | if (const auto* ramp = index.as<RampNode>()) { |
200 | if (!is_one(ramp->stride) && |
201 | static_cast<size_t>(ramp->dtype.lanes() * ramp->dtype.bytes()) > max_vector_bytes_) { |
202 | std::stringstream s; |
203 | s << "Number of lanes (" << ramp->dtype.lanes() << ") times number of bytes (" |
204 | << ramp->dtype.bytes() << ") for dtype " << ramp->dtype |
205 | << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")" ; |
206 | errors_.push_back(s.str()); |
207 | } |
208 | } |
209 | } |
210 | } |
211 | |
212 | void VisitExpr_(const CastNode* op) { |
213 | if (op->dtype.lanes() > 1) { |
214 | if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { |
215 | std::stringstream s; |
216 | s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" |
217 | << op->dtype.bytes() << ") for dtype " << op->dtype |
218 | << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")" ; |
219 | errors_.push_back(s.str()); |
220 | } |
221 | } |
222 | ExprVisitor::VisitExpr_(op); |
223 | } |
224 | |
225 | void VisitExpr_(const BufferLoadNode* op) { |
226 | if (op->dtype.lanes() > 1) { |
227 | if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { |
228 | std::stringstream s; |
229 | s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" |
230 | << op->dtype.bytes() << ") for dtype " << op->dtype |
231 | << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")" ; |
232 | errors_.push_back(s.str()); |
233 | } |
234 | CheckBufferIndicesVectorizable(op->indices); |
235 | } |
236 | ExprVisitor::VisitExpr_(op); |
237 | } |
238 | |
239 | void VisitStmt_(const BufferStoreNode* op) { |
240 | if (op->value->dtype.lanes() > 1) { |
241 | if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) > |
242 | max_vector_bytes_) { |
243 | std::stringstream s; |
244 | s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes (" |
245 | << op->value->dtype.bytes() << ") for dtype " << op->value->dtype |
246 | << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")" ; |
247 | errors_.push_back(s.str()); |
248 | } |
249 | CheckBufferIndicesVectorizable(op->indices); |
250 | } |
251 | StmtVisitor::VisitStmt_(op); |
252 | } |
253 | |
254 | private: |
255 | int nest_level_{0}; |
256 | |
257 | std::unordered_set<std::string> visited_threads_; |
258 | |
259 | size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; |
260 | |
261 | size_t local_memory_per_block_; |
262 | size_t shared_memory_per_block_; |
263 | size_t thread_per_block_; |
264 | size_t kernels_launched_{0}; |
265 | |
266 | size_t max_local_memory_per_block_; |
267 | size_t max_shared_memory_per_block_; |
268 | size_t max_threads_per_block_; |
269 | size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_; |
270 | size_t max_vector_bytes_; |
271 | size_t max_kernels_; |
272 | |
273 | std::vector<String> errors_; |
274 | |
275 | void Reset_() { |
276 | local_memory_per_block_ = 0; |
277 | shared_memory_per_block_ = 0; |
278 | |
279 | visited_threads_.clear(); |
280 | thread_per_block_ = 1; |
281 | } |
282 | }; |
283 | |
284 | std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) { |
285 | GPUCodeVerifier verifier; |
286 | |
287 | int64_t max_local_memory_per_block = INT64_MAX; |
288 | int64_t max_shared_memory_per_block = INT64_MAX; |
289 | int64_t max_threads_per_block = INT64_MAX; |
290 | int64_t max_thread_x = INT64_MAX; |
291 | int64_t max_thread_y = INT64_MAX; |
292 | int64_t max_thread_z = INT64_MAX; |
293 | int64_t max_vthread = INT64_MAX; |
294 | int64_t max_vector_bytes = INT64_MAX; |
295 | int64_t max_kernels = INT64_MAX; |
296 | |
297 | for (auto iter : constraints) { |
298 | const IntImmNode* val = iter.second.as<IntImmNode>(); |
299 | if (iter.first == "max_local_memory_per_block" ) { |
300 | max_local_memory_per_block = val->value; |
301 | } else if (iter.first == "max_shared_memory_per_block" ) { |
302 | max_shared_memory_per_block = val->value; |
303 | } else if (iter.first == "max_threads_per_block" ) { |
304 | max_threads_per_block = val->value; |
305 | } else if (iter.first == "max_thread_x" ) { |
306 | max_thread_x = val->value; |
307 | } else if (iter.first == "max_thread_y" ) { |
308 | max_thread_y = val->value; |
309 | } else if (iter.first == "max_thread_z" ) { |
310 | max_thread_z = val->value; |
311 | } else if (iter.first == "max_vthread" ) { |
312 | max_vthread = val->value; |
313 | } else if (iter.first == "max_vector_bytes" ) { |
314 | max_vector_bytes = val->value; |
315 | } else if (iter.first == "max_kernels" ) { |
316 | max_kernels = val->value; |
317 | } else { |
318 | LOG(FATAL) << "Invalid check item: " << iter.first; |
319 | } |
320 | } |
321 | |
322 | return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, |
323 | max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, |
324 | max_vthread, max_vector_bytes, max_kernels); |
325 | } |
326 | |
327 | bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) { |
328 | auto errs = VerifyGPUCode_(func, constraints); |
329 | return errs.size() == 0; |
330 | } |
331 | |
// Expose the boolean verifier to the FFI (callable as tir.analysis.verify_gpu_code).
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
333 | |
334 | namespace transform { |
335 | |
336 | Pass VerifyGPUCode(Map<String, PrimExpr> constraints) { |
337 | auto pass_func = [=](IRModule mod, PassContext ctx) { |
338 | for (auto kv : mod->functions) { |
339 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
340 | auto func = GetRef<PrimFunc>(n); |
341 | auto errs = VerifyGPUCode_(func, constraints); |
342 | if (errs.size() != 0) { |
343 | std::stringstream s; |
344 | for (auto& err : errs) { |
345 | s << " " << err << std::endl; |
346 | } |
347 | LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n" |
348 | << s.str() << " In function\n" |
349 | << func; |
350 | } |
351 | } |
352 | } |
353 | return mod; |
354 | }; |
355 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode" , {}); |
356 | } |
357 | |
// Expose the verification pass to the FFI (callable as tir.transform.VerifyGPUCode).
TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode);
359 | |
360 | } // namespace transform |
361 | } // namespace tir |
362 | } // namespace tvm |
363 | |