1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file make_packed_call.cc
22 * \brief Rewrite packed calls in AOT so that the arguments are packed
23 */
24#include <tvm/tir/builtin.h>
25#include <tvm/tir/expr.h>
26#include <tvm/tir/function.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include <unordered_map>
32
33#include "ir_utils.h"
34
35namespace tvm {
36namespace tir {
37
38using InputMap =
39 std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
40/**
41 * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize
42 * packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in)
43 */
44class PackedCallLegalizer : public StmtExprMutator {
45 public:
46 PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {}
47
48 Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); }
49
50 Stmt VisitStmt_(const EvaluateNode* op) final {
51 if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op);
52 const CallNode* call = op->value.as<CallNode>();
53 // Given a packed call f(A,B,C), we need a set of new statements
54 // let A_packed = set_struct(tvm_value1, A)
55 // let B_packed = set_struct(tvm_value2, B)
56 // let C_packed = set_struct(tvm_value3, C)
57 // call_packed(f, A_packed, B_packed, C_packed)
58 if (call) {
59 if (call->op.same_as(builtin::tvm_call_cpacked())) {
60 Array<PrimExpr> packed_args{call->args[0]};
61 VLOG(2) << "Legalize call:" << call;
62 BaseFunc base_func = mod_->Lookup(Downcast<StringImm>(call->args[0])->value);
63 const PrimFuncNode* prim_func = base_func.as<PrimFuncNode>();
64 VLOG(2) << " to func " << base_func;
65 for (unsigned i = 1; i < call->args.size() - 1; i++) {
66 // No need to pack inputs of the prim_func
67 if (inputs_[call->args[i]] == true) {
68 packed_args.push_back(call->args[i]);
69 } else {
70 // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect
71 // all such stack-allocated tensors and minimize the storage needed by reusing
72 // DLTensors.
73 Array<PrimExpr> call_args{call->args[i]};
74 tvm::runtime::Map<tvm::tir::Var, tvm::tir::Buffer>::iterator param_buf_it;
75 if (prim_func != nullptr) {
76 auto param_var = prim_func->params[i - 1];
77 param_buf_it = prim_func->buffer_map.find(param_var);
78 }
79 if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) {
80 Buffer param = (*param_buf_it).second;
81 PrimExpr shape = tvm::tir::Call(
82 DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape);
83 Cast var_type(param->dtype, IntImm(DataType::Int(32), 0));
84 call_args.push_back(shape /* shape */);
85 call_args.push_back(make_zero(DataType::Handle()) /* strides */);
86 call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */);
87 call_args.push_back(var_type /* carries dtype */);
88 call_args.push_back(param->elem_offset /* elem_offset */);
89 } else {
90 // When the PrimFunc cannot be found, most DLTensor information cannot be populated.
91 PrimExpr shape = tvm::tir::Call(
92 DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array<PrimExpr>());
93 Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0));
94 call_args.push_back(shape /* shape */);
95 call_args.push_back(make_zero(DataType::Handle()) /* strides */);
96 call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */);
97 call_args.push_back(var_type /* carries dtype */);
98 call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */);
99 }
100 packed_args.push_back(tvm::tir::Call(
101 DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args));
102 }
103 }
104 packed_args.push_back(call->args[call->args.size() - 1]); // push device_context
105 // Evaluate the packed call
106 return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args));
107 }
108 }
109 return StmtExprMutator::VisitStmt_(op);
110 }
111
112 private:
113 IRModule mod_;
114 InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed.
115};
116
117namespace transform {
118
119Pass LegalizePackedCalls() {
120 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
121 auto* n = f.CopyOnWrite();
122
123 // Note which Var are inputs and exclude them from packing.
124 InputMap inputs;
125 for (auto i : f->params) {
126 inputs[i] = true;
127 }
128 n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body));
129 return f;
130 };
131 return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {});
132}
133
134TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls);
135} // namespace transform
136
137} // namespace tir
138} // namespace tvm
139