/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Lower TVM related builtin intrinsics such as packed call.
 * \file tir/transforms/lower_tvm_builtin.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "ir_utils.h"

namespace tvm {
namespace tir {

// Calculate the statistics of packed function calls.
// This information is needed during codegen.
class BuiltinLower : public StmtExprMutator {
 public:
  // NOTE: Right now, we impose the following scoping requirement
  // on memory allocated by the following primitives:
  // - tvm_stack_make_array
  // - tvm_stack_make_shape
  // - arg stack
  //
  // Scoping and liveness rules:
  // - Every call_packed introduces a new scope.
  // - The memory allocated by tvm_stack_make_array/make_shape is
  //   no longer valid outside the scope (and may be reused by a
  //   subsequent call_packed).
  // - TODO(tvm-team): we might consider a root scope so stack_make_shape
  //   can be called outside of call_packed.
  //
  // Example:
  // {
  //   call_packed(make_shape1(...),
  //               call_packed(make_shape2(...)))
  //   call_packed(make_shape3(...))
  // }
  //
  // In this case, make_shape1 and make_shape2 should not share memory,
  // but either of them can share memory with make_shape3.
  //
  // Rationale: most packed calls need their own internal
  // argument stack, and those stacks can be shared across calls.
  // Scoping is a quick way to enable sharing without having
  // to do full-scale liveness analysis, and it does its job.
  // Alternative approaches can also be used.
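  //
  // The two structs below implement this bookkeeping:
  // - StackSizes counts, per scope, how many slots of each stack are
  //   required: shape_stack in int64 shape words, array_stack in
  //   DLTensor slots, and arg_stack in TVMValue/type-code slots.
  // - AllocaScope pairs the running counters (run_sizes, updated as calls
  //   are visited) with the high-water marks (max_sizes, filled in by the
  //   precheck pass in GetMaxStack) that are used to size the realized
  //   stack buffers.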
  struct StackSizes {
    // If a tvm_stack_make_shape call has no arguments, it is still
    // valid and represents a scalar shape ().  Therefore, -1 is used
    // to represent "no shape arguments exist", while 0 represents
    // "shape arguments exist, all of which are size 0".
    int64_t shape_stack{-1};
    uint64_t array_stack{0};
    uint64_t arg_stack{0};
  };

  // Record the stack frame for an existing scope.
  struct AllocaScope {
    Buffer stack_shape;
    Var stack_array = Var("stack_array", DataType::Handle());
    Var stack_value = Var("stack_value", DataType::Handle());
    Buffer stack_tcode;

    StackSizes max_sizes;
    StackSizes run_sizes;

    void UpdateMax() {
      max_sizes.shape_stack = std::max(max_sizes.shape_stack, run_sizes.shape_stack);
      max_sizes.array_stack = std::max(max_sizes.array_stack, run_sizes.array_stack);
      max_sizes.arg_stack = std::max(max_sizes.arg_stack, run_sizes.arg_stack);
    }

    void AssertMaxIsValid() const {
      // Every recorded maximum must bound the current running size.
      ICHECK_GE(max_sizes.shape_stack, run_sizes.shape_stack);
      ICHECK_GE(max_sizes.array_stack, run_sizes.array_stack);
      ICHECK_GE(max_sizes.arg_stack, run_sizes.arg_stack);
    }
  };

  Stmt Build(Stmt stmt) { return this->VisitBodyAndRealizeAlloca(stmt); }

  StackSizes GetMaxStack(Stmt stmt) {
    BuiltinLower precheck;
    precheck.is_precheck_ = true;
    precheck.device_id_ = this->device_id_;
    precheck.device_type_ = this->device_type_;

    precheck.alloca_scope_.emplace_back();
    {
      // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_.
      auto& scope = precheck.alloca_scope_.back();
      scope.stack_shape =
          decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape");
      scope.stack_tcode =
          decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode");
    }

    precheck.VisitStmt(stmt);

    ICHECK_EQ(precheck.alloca_scope_.size(), 1);
    return precheck.alloca_scope_[0].max_sizes;
  }

  // Allocate stack frames, only at parallel-for or root.
  Stmt VisitBodyAndRealizeAlloca(Stmt stmt) {
    // Only perform the precheck up to the point where we would add a
    // new scope.
    if (is_precheck_) {
      return stmt;
    }

    alloca_scope_.emplace_back();
    {
      // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_.
      auto& scope = alloca_scope_.back();

      // Initial check to identify maximum stack sizes.  These are used
      // to construct Buffer objects to hold the stack, which are then
      // used when mutating.
      scope.max_sizes = GetMaxStack(stmt);

      if (scope.max_sizes.shape_stack != -1) {
        scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)},
                                        DataType::Int(64), "stack_shape");
        stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack),
                       stmt);
      }

      if (scope.max_sizes.array_stack != 0) {
        stmt = LetStmt(scope.stack_array, StackAlloca("array", scope.max_sizes.array_stack), stmt);
      }

      if (scope.max_sizes.arg_stack != 0) {
        scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)},
                                        DataType::Int(32), "stack_tcode");
        stmt =
            LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt);

        stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack),
                       stmt);
      }
    }

    stmt = this->VisitStmt(stmt);

    ICHECK(!alloca_scope_.empty());
    alloca_scope_.pop_back();

    return stmt;
  }
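
  // For reference, when the body contains packed calls, the realized result
  // of the function above is a nest of LetStmt bindings around the mutated
  // body, roughly (an illustrative sketch, not the exact generated TIR):
  //
  //   let stack_tcode = StackAlloca("arg_tcode", max arg slots)
  //   let stack_value = StackAlloca("arg_value", max arg slots)
  //   let stack_array = StackAlloca("array", max DLTensor slots)
  //   let stack_shape = StackAlloca("shape", max shape words)
  //   <body with packed calls rewritten to use the stacks above>
  //
  // Each binding is emitted only when the corresponding maximum is non-zero
  // (for the shape stack, whenever it is not -1).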

  Stmt VisitStmt(const Stmt& s) final {
    // allocate space to hold prepare stmts before s
    prep_seq_stack_.emplace_back(std::vector<Stmt>());

    auto scope_size = alloca_scope_.size();
    auto stmt = StmtExprMutator::VisitStmt(s);
    {
      // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_.
      auto& scope = alloca_scope_.back();
      // This invariant asserts the assumption that
      // make_stack_shape only happens within a call_packed.
      // We could relax this in the future if we want to
      // introduce a root scope as a separate scope.
      ICHECK_EQ(alloca_scope_.size(), scope_size)
          << "alloca_scope_ length is different before and after recursion";
      ICHECK_EQ(scope.run_sizes.shape_stack, -1)
          << "Expect no tvm_stack_make_shape outside of CallNodes";
      ICHECK_EQ(scope.run_sizes.array_stack, 0)
          << "Expect no tvm_stack_make_array outside of CallNodes";
    }

    auto prep_seq = std::move(prep_seq_stack_.back());
    prep_seq_stack_.pop_back();

    if (prep_seq.size() != 0) {
      Stmt ret = SeqStmt::Flatten(prep_seq, stmt);
      return ret;
    } else {
      return stmt;
    }
  }

  Stmt VisitStmt_(const LetStmtNode* op) final {
    if (const CallNode* call = op->value.as<CallNode>()) {
      if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) {
        return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call));
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    // Lower allocate to device allocate when needed.
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    // Number of bytes per element of the allocation's dtype.
    int64_t nbytes = GetVectorBytes(op->dtype);
    // CPU allocations with global scope that are smaller than the
    // runtime::kMaxStackAlloca heuristic are not serviced with
    // TVMBackendAllocWorkspace calls; they are left to be placed on the stack.
    if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) {
      if (Downcast<Bool>(op->annotations[transform::kDisableLowerTVMBuiltin])) {
        return stmt;
      }
    }
    if (device_type_.defined()) {
      if (const auto* dev_type = device_type_.as<IntImmNode>()) {
        auto storage_scope = Downcast<PointerType>(op->buffer_var->type_annotation)->storage_scope;
        if (dev_type->value == kDLCPU && storage_scope == "global") {
          size_t constant_size = op->ConstantAllocationSize();
          if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
            return stmt;
          }
        }
      }
    }
    PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      total_bytes = total_bytes * op->extents[i];
    }
    ICHECK(device_type_.defined()) << "Unknown device type in current IR";
    ICHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
                                    throw_last_error),
                         op->body});
    Stmt alloca = LetStmt(
        op->buffer_var,
        Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
             {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
              cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
              IntImm(DataType::Int(32), op->dtype.bits())}),
        body);

    PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                            {cast(DataType::Int(32), device_type_),
                             cast(DataType::Int(32), device_id_), op->buffer_var});
    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
    body = SeqStmt({alloca, free_stmt});
    body = AttrStmt(op->buffer_var, attr::storage_alignment,
                    make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
    return body;
  }
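
  // Roughly, an Allocate that is not left on the stack is rewritten into
  // (an illustrative sketch, not the exact generated TIR):
  //
  //   // attr [buf] storage_alignment = kTempAllocaAlignment
  //   let buf = TVMBackendAllocWorkspace(dev_type, dev_id, total_bytes, dtype_code, dtype_bits)
  //   if (isnullptr(buf)) { tvm_throw_last_error() }
  //   <original allocate body>
  //   if (TVMBackendFreeWorkspace(dev_type, dev_id, buf) != 0) { tvm_throw_last_error() }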

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::device_id) {
      ICHECK(!device_id_.defined());
      device_id_ = op->value;
      return this->VisitStmt(op->body);
    } else if (op->attr_key == attr::device_type) {
      ICHECK(!device_type_.defined());
      device_type_ = op->value;
      return this->VisitStmt(op->body);
    } else {
      return StmtExprMutator::VisitStmt_(op);
    }
  }
  Stmt VisitStmt_(const ForNode* op) final {
    PrimExpr min = this->VisitExpr(op->min);
    PrimExpr extent = this->VisitExpr(op->extent);
    Stmt body;

    if (op->kind == ForKind::kParallel) {
      body = this->VisitBodyAndRealizeAlloca(op->body);
    } else {
      body = this->VisitStmt(op->body);
    }

    if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->min = std::move(min);
      n->extent = std::move(extent);
      n->body = std::move(body);
      return Stmt(n);
    }
  }
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_call_packed())) {
      return MakeCallPacked(op, /* use_string_lookup */ true);
    } else if (op->op.same_as(builtin::tvm_call_cpacked())) {
      return MakeCallPacked(op, /* use_string_lookup */ false);
    } else if (op->op.same_as(builtin::tvm_call_trace_packed())) {
      return MakeCallTracePacked(op);
    } else if (op->op.same_as(builtin::tvm_stack_make_shape())) {
      return MakeShape(op);
    } else if (op->op.same_as(builtin::tvm_stack_make_array())) {
      return MakeArray(op);
    } else if (op->op.same_as(builtin::tvm_context_id())) {
      return make_zero(op->dtype);
    } else if (op->op.same_as(builtin::dma_copy())) {
      return MakeDMACopy(op);
    } else if (op->op.same_as(builtin::dma_wait())) {
      return MakeDMAWait(op);
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

  PrimExpr MakeDMACopy(const CallNode* op) {
    PrimExpr queue_id = op->args[0];
    PrimExpr dst = op->args[1];
    PrimExpr src = op->args[2];
    PrimExpr size = op->args[3];
    PrimExpr bypass_cache = op->args[4];

    std::string fdevapi_prefix =
        "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));

    Call call_packed =
        Call(DataType::Int(32), builtin::tvm_call_packed(),
             {StringImm(fdevapi_prefix + ".dma_copy"), queue_id, dst, src, size, bypass_cache});
    return VisitExpr(call_packed);
  }

  PrimExpr MakeDMAWait(const CallNode* op) {
    PrimExpr queue_id = op->args[0];
    PrimExpr inflight = op->args[1];

    std::string fdevapi_prefix =
        "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));

    Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
                            {StringImm(fdevapi_prefix + ".dma_wait"), queue_id, inflight});
    return VisitExpr(call_packed);
  }
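
  // Both DMA helpers rebuild the intrinsic as a tvm_call_packed to the
  // corresponding device-API function ("device_api.<device>.dma_copy" /
  // ".dma_wait") and then re-visit the result, so it is lowered further by
  // MakeCallPacked like any other packed call.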

  // call shape
  PrimExpr MakeShape(const CallNode* op) {
    // if args.size() == 0, it represents a scalar shape ()
    ICHECK(!alloca_scope_.empty());
    auto& scope = alloca_scope_.back();
    auto& prep_seq = prep_seq_stack_.back();
    if (scope.run_sizes.shape_stack == -1) {
      scope.run_sizes.shape_stack = 0;
    }
    int64_t stack_begin = scope.run_sizes.shape_stack;
    scope.run_sizes.shape_stack += op->args.size();
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    // no need to perform any store for a scalar shape
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq.emplace_back(BufferStore(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
                                        {ConstInt32(stack_begin + i)}));
    }
    return AddressOffset(scope.stack_shape->data, DataType::Int(64), stack_begin);
  }
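
  // In effect, tvm_stack_make_shape(s0, s1, ...) turns into int64 stores of
  // s0, s1, ... into stack_shape starting at stack_begin (emitted into
  // prep_seq ahead of the enclosing statement), and the call itself is
  // replaced by the address of that slice of the shape stack.
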
  // make array
  PrimExpr MakeArray(const CallNode* op) {
    ICHECK(!alloca_scope_.empty());
    auto& scope = alloca_scope_.back();
    auto& prep_seq = prep_seq_stack_.back();

    size_t idx = scope.run_sizes.array_stack;
    scope.run_sizes.array_stack += 1;
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();

    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
    PrimExpr strides = op->args[2];
    if (!strides.defined() || is_zero(strides)) {
      strides = make_zero(DataType::Handle());
    }
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
    DataType dtype = op->args[4].dtype();
    prep_seq.emplace_back(
        TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode,
                     make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
                                       make_const(DataType::UInt(8), dtype.bits())));
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
                                       make_const(DataType::UInt(16), dtype.lanes())));
    // set byte offset
    int data_bytes = GetVectorBytes(dtype);
    PrimExpr elem_offset = op->args[5];
    PrimExpr byte_offset;
    if (!is_zero(elem_offset)) {
      byte_offset = elem_offset * make_const(elem_offset.dtype(), data_bytes);
    } else {
      byte_offset = elem_offset;
    }
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
                                       cast(DataType::UInt(64), byte_offset)));
    ICHECK(device_type_.defined()) << "Unknown device type in current IR";
    ICHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
                                       cast(DataType::Int(32), device_id_)));
    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
                                       cast(DataType::Int(32), device_type_)));
    return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
  }
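
  // Each tvm_stack_make_array claims one slot on the array stack and fills
  // in the DLTensor-style fields (data, shape, strides, ndim, dtype,
  // byte_offset, device) via TVMStructSet; the call is replaced by the
  // address of that slot, which is what a DLTensor*-taking callee expects.
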
  // call packed.
  PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) {
    auto& scope = alloca_scope_.back();
    auto& prep_seq = prep_seq_stack_.back();

    int64_t restore_shape_stack = scope.run_sizes.shape_stack;
    size_t restore_array_stack = scope.run_sizes.array_stack;
    size_t arg_stack_begin = scope.run_sizes.arg_stack;

    size_t arg_count = op->args.size();

    // cpacked expects a resource_handle parameter
    if (!use_string_lookup) {
      arg_count--;
    }

    scope.run_sizes.arg_stack += arg_count;
    // Specially handle the buffer packed intrinsic
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < arg_count; ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      if (t != api_type) {
        arg = Cast(api_type, arg);
      }
      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
                                         static_cast<int>(arg_stack_begin + i - 1),
                                         builtin::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      if (api_type.is_handle() && arg.as<StringImmNode>()) {
        arg_tcode = kTVMStr;
      }
      if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
    }
    // Verify that the stack size matches the value found by the precheck.
    if (is_precheck_) {
      scope.UpdateMax();
    } else {
      scope.AssertMaxIsValid();
    }
    scope.run_sizes.shape_stack = restore_shape_stack;
    scope.run_sizes.array_stack = restore_array_stack;
    scope.run_sizes.arg_stack = arg_stack_begin;
    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
                                   ConstInt32(arg_stack_begin),
                                   ConstInt32(arg_stack_begin + op->args.size() - 1)};

    // For cpacked, pass the resource_handle through as the trailing argument.
    if (!use_string_lookup) {
      PrimExpr last_arg = op->args[arg_count];
      const VarNode* var_node = last_arg.as<VarNode>();
      if (var_node != nullptr) {
        tir::Var resource_handle = GetRef<Var>(var_node);
        packed_args.push_back(StringImm(resource_handle->name_hint));
      } else {
        packed_args.push_back(last_arg);
      }
    }

    auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered()
                                          : builtin::tvm_call_cpacked_lowered();
    return Call(op->dtype, builtin_call, packed_args);
  }
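
  // For reference, a call such as tvm_call_packed("f", x, y) is lowered,
  // roughly (illustrative sketch), into prep statements that write x and y
  // into stack_value/stack_tcode starting at arg_stack_begin, followed by
  //
  //   tvm_call_packed_lowered("f", stack_value, stack_tcode, begin, end)
  //
  // (or tvm_call_cpacked_lowered, with the resource_handle appended, when
  // use_string_lookup is false).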

  PrimExpr MakeCallTracePacked(const CallNode* op) {
    ICHECK(!alloca_scope_.empty());
    auto& scope = alloca_scope_.back();
    auto& prep_seq = prep_seq_stack_.back();

    int64_t restore_shape_stack = scope.run_sizes.shape_stack;
    size_t restore_array_stack = scope.run_sizes.array_stack;
    size_t arg_stack_begin = scope.run_sizes.arg_stack;
    scope.run_sizes.arg_stack += op->args.size();
    size_t args_size = op->args.size();
    ICHECK_GT(args_size, 0);
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1);
      PrimExpr arg = op->args[i];
      DataType t = arg.dtype();
      DataType api_type = APIType(t);
      if (t != api_type) {
        arg = Cast(api_type, arg);
      }
      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
                                         static_cast<int>(arg_stack_begin + i - 1),
                                         builtin::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
      prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index}));
    }
    // Verify that the stack size matches the value found by the precheck.
    if (is_precheck_) {
      scope.UpdateMax();
    } else {
      scope.AssertMaxIsValid();
    }
    scope.run_sizes.shape_stack = restore_shape_stack;
    scope.run_sizes.array_stack = restore_array_stack;
    // Update the top of the stack, so that more than one packed
    // function's arguments can share the one stack.
    scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1;
    Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data,
                                   ConstInt32(arg_stack_begin),
                                   ConstInt32(arg_stack_begin + op->args.size() - 1),
                                   // Pass the traced value.
                                   op->args[args_size - 1]};
    return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args);
  }

  Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) {
    ICHECK(device_type_.defined()) << "Unknown device type in current IR";
    ICHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

    Stmt body = SeqStmt(
        {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
         let->body});

    DataType dtype =
        let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

    std::string fdevapi_prefix = "device_api.";
    fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);

    Array<PrimExpr> args = {
        StringImm(fdevapi_prefix + ".alloc_nd"),
        device_type_,
        device_id_,
        IntImm(DataType::Int(32), dtype.code()),
        IntImm(DataType::Int(32), dtype.bits()),
    };

    for (size_t i = 0; i < call->args.size(); ++i) {
      args.push_back(call->args[i]);
    }

    Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
    Stmt alloca = LetStmt(let->var, call_packed, body);

    PrimExpr storage_scope = call->args[0];
    Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(fdevapi_prefix + ".free_nd"), device_type_, device_id_,
                         storage_scope, let->var});

    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
    body = SeqStmt({alloca, free_stmt});
    return body;
  }
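
  // The net effect is roughly (an illustrative sketch):
  //
  //   let var = tvm_call_packed("device_api.<device>.alloc_nd", dev_type, dev_id,
  //                             dtype_code, dtype_bits, <original call args>)
  //   if (isnullptr(var)) { tvm_throw_last_error() }
  //   <original let body>
  //   if (tvm_call_packed("device_api.<device>.free_nd", ...) != 0) { tvm_throw_last_error() }
  //
  // The packed calls produced here are lowered further when the result is
  // re-visited from VisitStmt_(const LetStmtNode*).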

 private:
  bool IsArrayHandle(const PrimExpr& arg) {
    // specially set array handle.
    if (const CallNode* buf = arg.as<CallNode>()) {
      if (buf->op.same_as(builtin::tvm_struct_get()) &&
          buf->args[2].as<IntImmNode>()->value == builtin::kArrAddr) {
        return true;
      }
    }
    return false;
  }

  // The preparation sequence to be emitted before the current statement.
  std::vector<std::vector<Stmt>> prep_seq_stack_;
  PrimExpr device_type_;
  PrimExpr device_id_;

  bool is_precheck_{false};

  // Record all stack frames.
  std::vector<AllocaScope> alloca_scope_;
};

namespace transform {

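// For reference, a minimal sketch of applying this pass directly (assuming an
// IRModule `mod` whose PrimFunc bodies already carry the device_id/device_type
// AttrStmt annotations read by BuiltinLower):
//
//   IRModule lowered = tir::transform::LowerTVMBuiltin()(mod);
//
// or, from Python, tvm.tir.transform.LowerTVMBuiltin()(mod). In the normal
// build flow the pass runs as part of the default TIR lowering pipeline.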
Pass LowerTVMBuiltin() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = BuiltinLower().Build(n->body);
    VLOG(2) << "LowerTVMBuiltin: " << f;
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin);

}  // namespace transform
}  // namespace tir
}  // namespace tvm