1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file verify_gpu_code.cc
22 * \brief Verify the correctness of a GPU IR.
23 * It will check the whether the amount of memory usage or the number of threads
24 * in a block exceeds the limit
25 */
26
27#include <tvm/runtime/registry.h>
28#include <tvm/tir/analysis.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt.h>
31#include <tvm/tir/stmt_functor.h>
32
33#include "../../runtime/thread_storage_scope.h"
34#include "../transforms/ir_utils.h"
35
36namespace tvm {
37namespace tir {
38
39class GPUCodeVerifier : public StmtExprVisitor {
40 public:
41 std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
42 int64_t max_shared_memory_per_block, int64_t max_threads_per_block,
43 int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z,
44 int64_t max_vthread, int64_t max_vector_bytes, int64_t max_kernels) {
45 max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
46 max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
47 max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
48 max_thread_x_ = static_cast<size_t>(max_thread_x);
49 max_thread_y_ = static_cast<size_t>(max_thread_y);
50 max_thread_z_ = static_cast<size_t>(max_thread_z);
51 max_vthread_ = static_cast<size_t>(max_vthread);
52 max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
53 max_kernels_ = static_cast<size_t>(max_kernels);
54 Reset_();
55
56 // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
57 this->VisitStmt(stmt);
58
59 return errors_;
60 }
61
62 void VisitStmt_(const AllocateNode* op) final {
63 StmtVisitor::VisitStmt_(op);
64 auto scope = GetPtrStorageScope(op->buffer_var);
65 runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope);
66 // visit an allocation of a buffer in shared memory, record its size
67 if (storage_scope.rank == runtime::StorageRank::kLocal) {
68 size_t size = static_cast<size_t>(op->ConstantAllocationSize());
69 local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
70 } else if (storage_scope.rank == runtime::StorageRank::kShared) {
71 size_t size = static_cast<size_t>(op->ConstantAllocationSize());
72 shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
73 }
74 if (op->dtype.lanes() > 1) {
75 if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
76 std::stringstream s;
77 s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
78 << op->dtype.bytes() << ") for dtype " << op->dtype
79 << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
80 errors_.push_back(s.str());
81 }
82 }
83 }
84
85 void VisitStmt_(const AttrStmtNode* op) final {
86 if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
87 if (nest_level_ == 0) {
88 // enter a new kernel, reset statistics
89 Reset_();
90 kernels_launched_++;
91 }
92
93 Var var = op->node.as<IterVarNode>()->var;
94 const auto* extent = op->value.as<IntImmNode>();
95 ICHECK(extent);
96
97 std::string name = var.get()->name_hint;
98 // record the number of threads in a block
99 if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
100 name == "vthread") {
101 size_t length = static_cast<size_t>(extent->value);
102 if (!visited_threads_.count(name)) {
103 visited_threads_.insert(name);
104 thread_per_block_ *= length;
105
106 auto err = [this](std::string id, size_t ext, size_t m) {
107 if (ext > m) {
108 std::stringstream s;
109 s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m
110 << ");";
111 errors_.push_back(s.str());
112 }
113 };
114
115 if (name == "threadIdx.x") {
116 err("threadIdx.x", length, max_thread_x_);
117 thread_x_extent_ = length;
118 } else if (name == "threadIdx.y") {
119 err("threadIdx.y", length, max_thread_y_);
120 thread_y_extent_ = length;
121 } else if (name == "threadIdx.z") {
122 err("threadIdx.z", length, max_thread_z_);
123 thread_z_extent_ = length;
124 } else if (name == "vthread") {
125 err("vthread", length, max_vthread_);
126 }
127 } else {
128 // the thread should be bound to axes with the same length
129 auto err = [this, name](std::string id, size_t ext, size_t m) {
130 if (name == id && ext != m) {
131 std::stringstream s;
132 s << "Extent of " << id << " (" << ext << ") does not match the bound " << m;
133 errors_.push_back(s.str());
134 }
135 };
136 err("threadIdx.x", length, thread_x_extent_);
137 err("threadIdx.y", length, thread_y_extent_);
138 err("threadIdx.z", length, thread_z_extent_);
139 }
140 }
141
142 nest_level_++;
143 StmtVisitor::VisitStmt_(op);
144 nest_level_--;
145
146 if (nest_level_ == 0) {
147 // exit a kernel, check the validity
148 auto err = [this](std::string id, size_t num, size_t m) {
149 if (num > m) {
150 std::stringstream s;
151 s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m
152 << ")";
153 errors_.push_back(s.str());
154 }
155 };
156 err("threads per block", thread_per_block_, max_threads_per_block_);
157 err("local memory per block", local_memory_per_block_, max_local_memory_per_block_);
158 err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_);
159
160 if (kernels_launched_ > max_kernels_) {
161 std::stringstream s;
162 s << "Number of launched kernels (" << kernels_launched_
163 << ") is greater than the allowed maximum (" << max_kernels_ << ")";
164 errors_.push_back(s.str());
165 }
166 }
167 } else {
168 StmtVisitor::VisitStmt_(op);
169 }
170 }
171
172 void VisitStmt_(const ForNode* op) {
173 if (op->loop_var->name_hint == "vthread.s") {
174 const auto* extent = op->extent.as<IntImmNode>();
175 ICHECK(extent);
176
177 size_t num_vthread = static_cast<size_t>(extent->value);
178 if (num_vthread > max_vthread_) {
179 std::stringstream s;
180 s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum ("
181 << max_vthread_ << ")";
182 errors_.push_back(s.str());
183 }
184 }
185
186 StmtVisitor::VisitStmt_(op);
187 }
188
189 void VisitExpr_(const LoadNode* op) final {
190 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
191 }
192
193 void VisitStmt_(const StoreNode* op) final {
194 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
195 }
196
197 void CheckBufferIndicesVectorizable(const Array<PrimExpr> indices) {
198 for (const auto index : indices) {
199 if (const auto* ramp = index.as<RampNode>()) {
200 if (!is_one(ramp->stride) &&
201 static_cast<size_t>(ramp->dtype.lanes() * ramp->dtype.bytes()) > max_vector_bytes_) {
202 std::stringstream s;
203 s << "Number of lanes (" << ramp->dtype.lanes() << ") times number of bytes ("
204 << ramp->dtype.bytes() << ") for dtype " << ramp->dtype
205 << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
206 errors_.push_back(s.str());
207 }
208 }
209 }
210 }
211
212 void VisitExpr_(const CastNode* op) {
213 if (op->dtype.lanes() > 1) {
214 if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
215 std::stringstream s;
216 s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
217 << op->dtype.bytes() << ") for dtype " << op->dtype
218 << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
219 errors_.push_back(s.str());
220 }
221 }
222 ExprVisitor::VisitExpr_(op);
223 }
224
225 void VisitExpr_(const BufferLoadNode* op) {
226 if (op->dtype.lanes() > 1) {
227 if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
228 std::stringstream s;
229 s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
230 << op->dtype.bytes() << ") for dtype " << op->dtype
231 << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
232 errors_.push_back(s.str());
233 }
234 CheckBufferIndicesVectorizable(op->indices);
235 }
236 ExprVisitor::VisitExpr_(op);
237 }
238
239 void VisitStmt_(const BufferStoreNode* op) {
240 if (op->value->dtype.lanes() > 1) {
241 if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
242 max_vector_bytes_) {
243 std::stringstream s;
244 s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes ("
245 << op->value->dtype.bytes() << ") for dtype " << op->value->dtype
246 << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
247 errors_.push_back(s.str());
248 }
249 CheckBufferIndicesVectorizable(op->indices);
250 }
251 StmtVisitor::VisitStmt_(op);
252 }
253
254 private:
255 int nest_level_{0};
256
257 std::unordered_set<std::string> visited_threads_;
258
259 size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;
260
261 size_t local_memory_per_block_;
262 size_t shared_memory_per_block_;
263 size_t thread_per_block_;
264 size_t kernels_launched_{0};
265
266 size_t max_local_memory_per_block_;
267 size_t max_shared_memory_per_block_;
268 size_t max_threads_per_block_;
269 size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
270 size_t max_vector_bytes_;
271 size_t max_kernels_;
272
273 std::vector<String> errors_;
274
275 void Reset_() {
276 local_memory_per_block_ = 0;
277 shared_memory_per_block_ = 0;
278
279 visited_threads_.clear();
280 thread_per_block_ = 1;
281 }
282};
283
284std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) {
285 GPUCodeVerifier verifier;
286
287 int64_t max_local_memory_per_block = INT64_MAX;
288 int64_t max_shared_memory_per_block = INT64_MAX;
289 int64_t max_threads_per_block = INT64_MAX;
290 int64_t max_thread_x = INT64_MAX;
291 int64_t max_thread_y = INT64_MAX;
292 int64_t max_thread_z = INT64_MAX;
293 int64_t max_vthread = INT64_MAX;
294 int64_t max_vector_bytes = INT64_MAX;
295 int64_t max_kernels = INT64_MAX;
296
297 for (auto iter : constraints) {
298 const IntImmNode* val = iter.second.as<IntImmNode>();
299 if (iter.first == "max_local_memory_per_block") {
300 max_local_memory_per_block = val->value;
301 } else if (iter.first == "max_shared_memory_per_block") {
302 max_shared_memory_per_block = val->value;
303 } else if (iter.first == "max_threads_per_block") {
304 max_threads_per_block = val->value;
305 } else if (iter.first == "max_thread_x") {
306 max_thread_x = val->value;
307 } else if (iter.first == "max_thread_y") {
308 max_thread_y = val->value;
309 } else if (iter.first == "max_thread_z") {
310 max_thread_z = val->value;
311 } else if (iter.first == "max_vthread") {
312 max_vthread = val->value;
313 } else if (iter.first == "max_vector_bytes") {
314 max_vector_bytes = val->value;
315 } else if (iter.first == "max_kernels") {
316 max_kernels = val->value;
317 } else {
318 LOG(FATAL) << "Invalid check item: " << iter.first;
319 }
320 }
321
322 return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
323 max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
324 max_vthread, max_vector_bytes, max_kernels);
325}
326
327bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
328 auto errs = VerifyGPUCode_(func, constraints);
329 return errs.size() == 0;
330}
331
// Expose the boolean verifier to the FFI as `tir.analysis.verify_gpu_code`.
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
333
334namespace transform {
335
336Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
337 auto pass_func = [=](IRModule mod, PassContext ctx) {
338 for (auto kv : mod->functions) {
339 if (auto* n = kv.second.as<PrimFuncNode>()) {
340 auto func = GetRef<PrimFunc>(n);
341 auto errs = VerifyGPUCode_(func, constraints);
342 if (errs.size() != 0) {
343 std::stringstream s;
344 for (auto& err : errs) {
345 s << " " << err << std::endl;
346 }
347 LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
348 << s.str() << " In function\n"
349 << func;
350 }
351 }
352 }
353 return mod;
354 };
355 return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {});
356}
357
358TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode);
359
360} // namespace transform
361} // namespace tir
362} // namespace tvm
363