1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Lower TVM related builtin intrinsics such as packed call. |
22 | * \file tir/transforms/lower_tvm_buildin.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include <unordered_set> |
31 | |
32 | #include "ir_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | // Calculate the statistics of packed function. |
38 | // These information are needed during codegen. |
39 | class BuiltinLower : public StmtExprMutator { |
40 | public: |
41 | // NOTE: Right now, we make the following scoping requirement |
42 | // for memory allocated by the following primitives |
43 | // - tvm_stack_make_array |
44 | // - tvm_stack_make_shape |
45 | // - arg stack |
46 | // |
47 | // Scoping and liveness rules: |
48 | // - Every call_packed introduce a new scope. |
49 | // - The memory allocated by tvm_stack_make_array/make_shape will |
50 | // no longer become valid outside the scope (and may be reused by |
51 | // subsequent call_packed. |
52 | // - TODO(tvm-team): we might consider a root scope so stack_make_shape |
53 | // can be called out-side call_packed. |
54 | // |
55 | // Example: |
56 | // { |
57 | // call_packed(make_shape1(...), |
58 | // call_packed(make_shape2(...)) |
59 | // call_packed(make_shape3(...)) |
60 | // } |
61 | // |
62 | // In this case, make_shape1 and make_shape2 should not share memory, |
63 | // but they can share memory with make_shape3. |
64 | // |
65 | // Rationale: most of the packed calls needs their own internal |
66 | // argument stack, and those stack can be shared across calls. |
67 | // Scoping is a quick way to enable sharing without having |
68 | // to do full-scale liveness analysis and it does its job. |
69 | // Alternative approaches can also be used. |
70 | struct StackSizes { |
71 | // If a tvm_stack_make_shape call has no arguments, it is still |
72 | // valid and represents a scalar shape (). Therefore, -1 is used |
73 | // to represent "no shape arguments exist", while 0 represents |
74 | // "shape arguments exist, all of which are size 0". |
75 | int64_t shape_stack{-1}; |
76 | uint64_t array_stack{0}; |
77 | uint64_t arg_stack{0}; |
78 | }; |
79 | |
80 | // Record stack frame for existing scope. |
81 | struct AllocaScope { |
82 | Buffer stack_shape; |
83 | Var stack_array = Var("stack_array" , DataType::Handle()); |
84 | Var stack_value = Var("stack_value" , DataType::Handle()); |
85 | Buffer stack_tcode; |
86 | |
87 | StackSizes max_sizes; |
88 | StackSizes run_sizes; |
89 | |
90 | void UpdateMax() { |
91 | max_sizes.shape_stack = std::max(max_sizes.shape_stack, run_sizes.shape_stack); |
92 | max_sizes.array_stack = std::max(max_sizes.array_stack, run_sizes.array_stack); |
93 | max_sizes.arg_stack = std::max(max_sizes.arg_stack, run_sizes.arg_stack); |
94 | } |
95 | |
96 | void AssertMaxIsValid() const { |
97 | ICHECK((max_sizes.shape_stack >= run_sizes.shape_stack) || |
98 | (max_sizes.array_stack >= run_sizes.array_stack) || |
99 | (max_sizes.arg_stack >= run_sizes.arg_stack)); |
100 | } |
101 | }; |
102 | |
103 | Stmt Build(Stmt stmt) { return this->VisitBodyAndRealizeAlloca(stmt); } |
104 | |
105 | StackSizes GetMaxStack(Stmt stmt) { |
106 | BuiltinLower precheck; |
107 | precheck.is_precheck_ = true; |
108 | precheck.device_id_ = this->device_id_; |
109 | precheck.device_type_ = this->device_type_; |
110 | |
111 | precheck.alloca_scope_.emplace_back(); |
112 | { |
113 | // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. |
114 | auto& scope = precheck.alloca_scope_.back(); |
115 | scope.stack_shape = |
116 | decl_buffer({IntImm(DataType::Int(64), 0)}, DataType::Int(64), "stack_shape" ); |
117 | scope.stack_tcode = |
118 | decl_buffer({IntImm(DataType::UInt(64), 0)}, DataType::Int(32), "stack_tcode" ); |
119 | } |
120 | |
121 | precheck.VisitStmt(stmt); |
122 | |
123 | ICHECK_EQ(precheck.alloca_scope_.size(), 1); |
124 | return precheck.alloca_scope_[0].max_sizes; |
125 | } |
126 | |
127 | // Allcoate stack frames, only at parallel-for or root. |
128 | Stmt VisitBodyAndRealizeAlloca(Stmt stmt) { |
129 | // Only perform the precheck up to the point where we would add a |
130 | // new scope. |
131 | if (is_precheck_) { |
132 | return stmt; |
133 | } |
134 | |
135 | alloca_scope_.emplace_back(); |
136 | { |
137 | // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. |
138 | auto& scope = alloca_scope_.back(); |
139 | |
140 | // Initial check to identify maximum stack sizes. These are used |
141 | // to construct Buffer objects to hold the stack, which are then |
142 | // used when mutating. |
143 | scope.max_sizes = GetMaxStack(stmt); |
144 | |
145 | if (scope.max_sizes.shape_stack != -1) { |
146 | scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, |
147 | DataType::Int(64), "stack_shape" ); |
148 | stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape" , scope.max_sizes.shape_stack), |
149 | stmt); |
150 | } |
151 | |
152 | if (scope.max_sizes.array_stack != 0) { |
153 | stmt = LetStmt(scope.stack_array, StackAlloca("array" , scope.max_sizes.array_stack), stmt); |
154 | } |
155 | |
156 | if (scope.max_sizes.arg_stack != 0) { |
157 | scope.stack_tcode = decl_buffer({IntImm(DataType::UInt(64), scope.max_sizes.arg_stack)}, |
158 | DataType::Int(32), "stack_tcode" ); |
159 | stmt = |
160 | LetStmt(scope.stack_value, StackAlloca("arg_value" , scope.max_sizes.arg_stack), stmt); |
161 | |
162 | stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode" , scope.max_sizes.arg_stack), |
163 | stmt); |
164 | } |
165 | } |
166 | |
167 | stmt = this->VisitStmt(stmt); |
168 | |
169 | ICHECK(!alloca_scope_.empty()); |
170 | alloca_scope_.pop_back(); |
171 | |
172 | return stmt; |
173 | } |
174 | |
175 | Stmt VisitStmt(const Stmt& s) final { |
176 | // allocate space to hold prepare stmts before s |
177 | prep_seq_stack_.emplace_back(std::vector<Stmt>()); |
178 | |
179 | auto scope_size = alloca_scope_.size(); |
180 | auto stmt = StmtExprMutator::VisitStmt(s); |
181 | { |
182 | // NOTE: this scope reference is invalid after any mutation is applied to alloca_scope_. |
183 | auto& scope = alloca_scope_.back(); |
184 | // This invariant asserts the assumption that |
185 | // make_stack_shape only happens within a call_packed. |
186 | // We could relax this in the future if we want to |
187 | // introduce root scope as a separate scope |
188 | ICHECK_EQ(alloca_scope_.size(), scope_size) |
189 | << "alloca_scope_ length is different before and after recursion" ; |
190 | ICHECK_EQ(scope.run_sizes.shape_stack, -1) |
191 | << "Expect no tvm_stack_make_shape outside of CallNodes" ; |
192 | ICHECK_EQ(scope.run_sizes.array_stack, 0) |
193 | << "Expect no tvm_stack_make_array outside of CallNodes" ; |
194 | } |
195 | |
196 | auto prep_seq = std::move(prep_seq_stack_.back()); |
197 | prep_seq_stack_.pop_back(); |
198 | |
199 | if (prep_seq.size() != 0) { |
200 | Stmt ret = SeqStmt::Flatten(prep_seq, stmt); |
201 | return ret; |
202 | } else { |
203 | return stmt; |
204 | } |
205 | } |
206 | |
207 | Stmt VisitStmt_(const LetStmtNode* op) final { |
208 | if (const CallNode* call = op->value.as<CallNode>()) { |
209 | if (call->op.same_as(builtin::nd_mem_alloc_with_scope())) { |
210 | return StmtExprMutator::VisitStmt(MakeNdMemAllocWithScope(op, call)); |
211 | } |
212 | } |
213 | return StmtExprMutator::VisitStmt_(op); |
214 | } |
215 | |
216 | Stmt VisitStmt_(const AllocateNode* op) { |
217 | // Lower allocate to device allocate when needed. |
218 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
219 | op = stmt.as<AllocateNode>(); |
220 | // Get constant allocation bound. |
221 | int64_t nbytes = GetVectorBytes(op->dtype); |
222 | // If the buffers are for CPU and have global scope, |
223 | // and less than runtime::kMaxStackAlloca heuristic |
224 | // they are not serviced with TVMBackendWorkspaceAlloc calls |
225 | // to be placed on stack. |
226 | if (op->annotations.count(transform::kDisableLowerTVMBuiltin)) { |
227 | if (Downcast<Bool>(op->annotations[transform::kDisableLowerTVMBuiltin])) { |
228 | return stmt; |
229 | } |
230 | } |
231 | if (device_type_.defined()) { |
232 | if (const auto* dev_type = device_type_.as<IntImmNode>()) { |
233 | auto storage_scope = Downcast<PointerType>(op->buffer_var->type_annotation)->storage_scope; |
234 | if (dev_type->value == kDLCPU && storage_scope == "global" ) { |
235 | size_t constant_size = op->ConstantAllocationSize(); |
236 | if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { |
237 | return stmt; |
238 | } |
239 | } |
240 | } |
241 | } |
242 | PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); |
243 | for (size_t i = 0; i < op->extents.size(); ++i) { |
244 | total_bytes = total_bytes * op->extents[i]; |
245 | } |
246 | ICHECK(device_type_.defined()) << "Unknown device type in current IR" ; |
247 | ICHECK(device_id_.defined()) << "Unknown device id in current IR" ; |
248 | Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); |
249 | |
250 | Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), |
251 | throw_last_error), |
252 | op->body}); |
253 | Stmt alloca = LetStmt( |
254 | op->buffer_var, |
255 | Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace" ), |
256 | {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), |
257 | cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), |
258 | IntImm(DataType::Int(32), op->dtype.bits())}), |
259 | body); |
260 | |
261 | PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace" ), |
262 | {cast(DataType::Int(32), device_type_), |
263 | cast(DataType::Int(32), device_id_), op->buffer_var}); |
264 | Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); |
265 | body = SeqStmt({alloca, free_stmt}); |
266 | body = AttrStmt(op->buffer_var, attr::storage_alignment, |
267 | make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); |
268 | return body; |
269 | } |
270 | |
271 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
272 | if (op->attr_key == attr::device_id) { |
273 | ICHECK(!device_id_.defined()); |
274 | device_id_ = op->value; |
275 | return this->VisitStmt(op->body); |
276 | } else if (op->attr_key == attr::device_type) { |
277 | ICHECK(!device_type_.defined()); |
278 | device_type_ = op->value; |
279 | return this->VisitStmt(op->body); |
280 | } else { |
281 | return StmtExprMutator::VisitStmt_(op); |
282 | } |
283 | } |
284 | Stmt VisitStmt_(const ForNode* op) final { |
285 | PrimExpr min = this->VisitExpr(op->min); |
286 | PrimExpr extent = this->VisitExpr(op->extent); |
287 | Stmt body; |
288 | |
289 | if (op->kind == ForKind::kParallel) { |
290 | body = this->VisitBodyAndRealizeAlloca(op->body); |
291 | } else { |
292 | body = this->VisitStmt(op->body); |
293 | } |
294 | |
295 | if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { |
296 | return GetRef<Stmt>(op); |
297 | } else { |
298 | auto n = CopyOnWrite(op); |
299 | n->min = std::move(min); |
300 | n->extent = std::move(extent); |
301 | n->body = std::move(body); |
302 | return Stmt(n); |
303 | } |
304 | } |
305 | PrimExpr VisitExpr_(const CallNode* op) final { |
306 | if (op->op.same_as(builtin::tvm_call_packed())) { |
307 | return MakeCallPacked(op, /* use_string_lookup */ true); |
308 | } else if (op->op.same_as(builtin::tvm_call_cpacked())) { |
309 | return MakeCallPacked(op, /* use_string_lookup */ false); |
310 | } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { |
311 | return MakeCallTracePacked(op); |
312 | } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { |
313 | return MakeShape(op); |
314 | } else if (op->op.same_as(builtin::tvm_stack_make_array())) { |
315 | return MakeArray(op); |
316 | } else if (op->op.same_as(builtin::tvm_context_id())) { |
317 | return make_zero(op->dtype); |
318 | } else if (op->op.same_as(builtin::dma_copy())) { |
319 | return MakeDMACopy(op); |
320 | } else if (op->op.same_as(builtin::dma_wait())) { |
321 | return MakeDMAWait(op); |
322 | } else { |
323 | return StmtExprMutator::VisitExpr_(op); |
324 | } |
325 | } |
326 | |
327 | PrimExpr MakeDMACopy(const CallNode* op) { |
328 | PrimExpr queue_id = op->args[0]; |
329 | PrimExpr dst = op->args[1]; |
330 | PrimExpr src = op->args[2]; |
331 | PrimExpr size = op->args[3]; |
332 | PrimExpr bypass_cache = op->args[4]; |
333 | |
334 | std::string fdevapi_prefix = |
335 | "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value)); |
336 | |
337 | Call call_packed = |
338 | Call(DataType::Int(32), builtin::tvm_call_packed(), |
339 | {StringImm(fdevapi_prefix + ".dma_copy" ), queue_id, dst, src, size, bypass_cache}); |
340 | return VisitExpr(call_packed); |
341 | } |
342 | |
343 | PrimExpr MakeDMAWait(const CallNode* op) { |
344 | PrimExpr queue_id = op->args[0]; |
345 | PrimExpr inflight = op->args[1]; |
346 | |
347 | std::string fdevapi_prefix = |
348 | "device_api." + std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value)); |
349 | |
350 | Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(), |
351 | {StringImm(fdevapi_prefix + ".dma_wait" ), queue_id, inflight}); |
352 | return VisitExpr(call_packed); |
353 | } |
354 | |
355 | // call shape |
356 | PrimExpr MakeShape(const CallNode* op) { |
357 | // if args.size() == 0, it represents a scalar shape () |
358 | ICHECK(!alloca_scope_.empty()); |
359 | auto& scope = alloca_scope_.back(); |
360 | auto& prep_seq = prep_seq_stack_.back(); |
361 | if (scope.run_sizes.shape_stack == -1) { |
362 | scope.run_sizes.shape_stack = 0; |
363 | } |
364 | int64_t stack_begin = scope.run_sizes.shape_stack; |
365 | scope.run_sizes.shape_stack += op->args.size(); |
366 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
367 | op = expr.as<CallNode>(); |
368 | // no need to perform any store for a scalar shape |
369 | for (size_t i = 0; i < op->args.size(); ++i) { |
370 | prep_seq.emplace_back(BufferStore(scope.stack_shape, cast(DataType::Int(64), op->args[i]), |
371 | {ConstInt32(stack_begin + i)})); |
372 | } |
373 | return AddressOffset(scope.stack_shape->data, DataType::Int(64), stack_begin); |
374 | } |
375 | // make array |
376 | PrimExpr MakeArray(const CallNode* op) { |
377 | ICHECK(!alloca_scope_.empty()); |
378 | auto& scope = alloca_scope_.back(); |
379 | auto& prep_seq = prep_seq_stack_.back(); |
380 | |
381 | size_t idx = scope.run_sizes.array_stack; |
382 | scope.run_sizes.array_stack += 1; |
383 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
384 | op = expr.as<CallNode>(); |
385 | |
386 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0])); |
387 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1])); |
388 | PrimExpr strides = op->args[2]; |
389 | if (!strides.defined() || is_zero(strides)) { |
390 | strides = make_zero(DataType::Handle()); |
391 | } |
392 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides)); |
393 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3])); |
394 | DataType dtype = op->args[4].dtype(); |
395 | prep_seq.emplace_back( |
396 | TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode, |
397 | make_const(DataType::UInt(8), static_cast<int>(dtype.code())))); |
398 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits, |
399 | make_const(DataType::UInt(8), dtype.bits()))); |
400 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes, |
401 | make_const(DataType::UInt(16), dtype.lanes()))); |
402 | // set byte offset |
403 | int data_bytes = GetVectorBytes(dtype); |
404 | PrimExpr elem_offset = op->args[5]; |
405 | PrimExpr byte_offset; |
406 | if (!is_zero(elem_offset)) { |
407 | byte_offset = elem_offset * make_const(elem_offset.dtype(), data_bytes); |
408 | } else { |
409 | byte_offset = elem_offset; |
410 | } |
411 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, |
412 | cast(DataType::UInt(64), byte_offset))); |
413 | ICHECK(device_type_.defined()) << "Unknown device type in current IR" ; |
414 | ICHECK(device_id_.defined()) << "Unknown device id in current IR" ; |
415 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId, |
416 | cast(DataType::Int(32), device_id_))); |
417 | prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType, |
418 | cast(DataType::Int(32), device_type_))); |
419 | return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); |
420 | } |
421 | // call packed. |
422 | PrimExpr MakeCallPacked(const CallNode* op, bool use_string_lookup) { |
423 | auto& scope = alloca_scope_.back(); |
424 | auto& prep_seq = prep_seq_stack_.back(); |
425 | |
426 | int64_t restore_shape_stack = scope.run_sizes.shape_stack; |
427 | size_t restore_array_stack = scope.run_sizes.array_stack; |
428 | size_t arg_stack_begin = scope.run_sizes.arg_stack; |
429 | |
430 | size_t arg_count = op->args.size(); |
431 | |
432 | // cpacked expects a resource_handle parameter |
433 | if (!use_string_lookup) { |
434 | arg_count--; |
435 | } |
436 | |
437 | scope.run_sizes.arg_stack += arg_count; |
438 | // Specially handle the buffer packed intrinsic |
439 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
440 | op = expr.as<CallNode>(); |
441 | for (size_t i = 1; i < arg_count; ++i) { |
442 | PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); |
443 | PrimExpr arg = op->args[i]; |
444 | DataType t = arg.dtype(); |
445 | DataType api_type = APIType(t); |
446 | if (t != api_type) { |
447 | arg = Cast(api_type, arg); |
448 | } |
449 | prep_seq.emplace_back(TVMStructSet(scope.stack_value, |
450 | static_cast<int>(arg_stack_begin + i - 1), |
451 | builtin::kTVMValueContent, arg)); |
452 | int arg_tcode = api_type.code(); |
453 | if (api_type.is_handle() && arg.as<StringImmNode>()) { |
454 | arg_tcode = kTVMStr; |
455 | } |
456 | if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; |
457 | prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); |
458 | } |
459 | // Verify stack size matches earlier value. |
460 | if (is_precheck_) { |
461 | scope.UpdateMax(); |
462 | } else { |
463 | scope.AssertMaxIsValid(); |
464 | } |
465 | scope.run_sizes.shape_stack = restore_shape_stack; |
466 | scope.run_sizes.array_stack = restore_array_stack; |
467 | scope.run_sizes.arg_stack = arg_stack_begin; |
468 | Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, |
469 | ConstInt32(arg_stack_begin), |
470 | ConstInt32(arg_stack_begin + op->args.size() - 1)}; |
471 | |
472 | // cpacked call resource_handle |
473 | if (!use_string_lookup) { |
474 | PrimExpr last_arg = op->args[arg_count]; |
475 | const VarNode* var_node = last_arg.as<VarNode>(); |
476 | if (var_node != nullptr) { |
477 | tir::Var resource_handle = GetRef<Var>(var_node); |
478 | packed_args.push_back(StringImm(resource_handle->name_hint)); |
479 | } else { |
480 | packed_args.push_back(last_arg); |
481 | } |
482 | } |
483 | |
484 | auto builtin_call = use_string_lookup ? builtin::tvm_call_packed_lowered() |
485 | : builtin::tvm_call_cpacked_lowered(); |
486 | return Call(op->dtype, builtin_call, packed_args); |
487 | } |
488 | |
489 | PrimExpr MakeCallTracePacked(const CallNode* op) { |
490 | ICHECK(!alloca_scope_.empty()); |
491 | auto& scope = alloca_scope_.back(); |
492 | auto& prep_seq = prep_seq_stack_.back(); |
493 | |
494 | int64_t restore_shape_stack = scope.run_sizes.shape_stack; |
495 | size_t restore_array_stack = scope.run_sizes.array_stack; |
496 | size_t arg_stack_begin = scope.run_sizes.arg_stack; |
497 | scope.run_sizes.arg_stack += op->args.size(); |
498 | size_t args_size = op->args.size(); |
499 | ICHECK_GT(args_size, 0); |
500 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
501 | op = expr.as<CallNode>(); |
502 | for (size_t i = 1; i < op->args.size(); ++i) { |
503 | PrimExpr stack_index = ConstInt32(arg_stack_begin + i - 1); |
504 | PrimExpr arg = op->args[i]; |
505 | DataType t = arg.dtype(); |
506 | DataType api_type = APIType(t); |
507 | if (t != api_type) { |
508 | arg = Cast(api_type, arg); |
509 | } |
510 | prep_seq.emplace_back(TVMStructSet(scope.stack_value, |
511 | static_cast<int>(arg_stack_begin + i - 1), |
512 | builtin::kTVMValueContent, arg)); |
513 | int arg_tcode = api_type.code(); |
514 | ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers" ; |
515 | prep_seq.emplace_back(BufferStore(scope.stack_tcode, ConstInt32(arg_tcode), {stack_index})); |
516 | } |
517 | // Verify stack size matches earlier value. |
518 | if (is_precheck_) { |
519 | scope.UpdateMax(); |
520 | } else { |
521 | scope.AssertMaxIsValid(); |
522 | } |
523 | scope.run_sizes.shape_stack = restore_shape_stack; |
524 | scope.run_sizes.array_stack = restore_array_stack; |
525 | // Update the top of the stack, so we can use more than one |
526 | // packed function's arguments with the one stack. |
527 | scope.run_sizes.arg_stack = arg_stack_begin + args_size - 1; |
528 | Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode->data, |
529 | ConstInt32(arg_stack_begin), |
530 | ConstInt32(arg_stack_begin + op->args.size() - 1), |
531 | // Pass traced value. |
532 | op->args[args_size - 1]}; |
533 | return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args); |
534 | } |
535 | |
536 | Stmt MakeNdMemAllocWithScope(const LetStmtNode* let, const CallNode* call) { |
537 | ICHECK(device_type_.defined()) << "Unknown device type in current IR" ; |
538 | ICHECK(device_id_.defined()) << "Unknown device id in current IR" ; |
539 | Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); |
540 | |
541 | Stmt body = SeqStmt( |
542 | {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), |
543 | let->body}); |
544 | |
545 | DataType dtype = |
546 | let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype; |
547 | |
548 | std::string fdevapi_prefix = "device_api." ; |
549 | fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value); |
550 | |
551 | Array<PrimExpr> args = { |
552 | StringImm(fdevapi_prefix + ".alloc_nd" ), |
553 | device_type_, |
554 | device_id_, |
555 | IntImm(DataType::Int(32), dtype.code()), |
556 | IntImm(DataType::Int(32), dtype.bits()), |
557 | }; |
558 | |
559 | for (size_t i = 0; i < call->args.size(); ++i) { |
560 | args.push_back(call->args[i]); |
561 | } |
562 | |
563 | Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); |
564 | Stmt alloca = LetStmt(let->var, call_packed, body); |
565 | |
566 | PrimExpr storage_scope = call->args[0]; |
567 | Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), |
568 | {StringImm(fdevapi_prefix + ".free_nd" ), device_type_, device_id_, |
569 | storage_scope, let->var}); |
570 | |
571 | Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); |
572 | body = SeqStmt({alloca, free_stmt}); |
573 | return body; |
574 | } |
575 | |
576 | private: |
577 | bool IsArrayHandle(const PrimExpr& arg) { |
578 | // specially set array handle. |
579 | if (const CallNode* buf = arg.as<CallNode>()) { |
580 | if (buf->op.same_as(builtin::tvm_struct_get()) && |
581 | buf->args[2].as<IntImmNode>()->value == builtin::kArrAddr) { |
582 | return true; |
583 | } |
584 | } |
585 | return false; |
586 | } |
587 | |
588 | // The prepration sequence to be emitted before the current statement. |
589 | std::vector<std::vector<Stmt>> prep_seq_stack_; |
590 | PrimExpr device_type_; |
591 | PrimExpr device_id_; |
592 | |
593 | bool is_precheck_{false}; |
594 | |
595 | // Record all stack frames. |
596 | std::vector<AllocaScope> alloca_scope_; |
597 | }; |
598 | |
599 | namespace transform { |
600 | |
601 | Pass LowerTVMBuiltin() { |
602 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
603 | auto* n = f.CopyOnWrite(); |
604 | n->body = BuiltinLower().Build(n->body); |
605 | VLOG(2) << "LowerTVMBuiltin: " << f; |
606 | return f; |
607 | }; |
608 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin" , {}); |
609 | } |
610 | |
// Expose the pass constructor to the FFI under "tir.transform.LowerTVMBuiltin".
TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin" ).set_body_typed(LowerTVMBuiltin);
612 | |
613 | } // namespace transform |
614 | } // namespace tir |
615 | } // namespace tvm |
616 | |