1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file make_packed_call.cc |
22 | * \brief Rewrite packed calls in AOT so that the arguments are packed |
23 | */ |
24 | #include <tvm/tir/builtin.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/function.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include <unordered_map> |
32 | |
33 | #include "ir_utils.h" |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
/*!
 * \brief Map from an expression to a boolean flag.
 *
 * In this file it records which expressions are parameters of the PrimFunc being
 * legalized (flag set to true); those are forwarded to packed calls without packing.
 * Keys are hashed/compared with runtime::ObjectPtrHash / runtime::ObjectPtrEqual
 * (presumably object identity rather than structural equality -- confirm against the
 * TVM runtime headers).
 */
using InputMap =
    std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
40 | /** |
41 | * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize |
42 | * packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in) |
43 | */ |
44 | class PackedCallLegalizer : public StmtExprMutator { |
45 | public: |
46 | PackedCallLegalizer(IRModule m, const InputMap& inputs) : mod_{m}, inputs_{inputs} {} |
47 | |
48 | Stmt Legalize(tir::Stmt body) { return StmtExprMutator::VisitStmt(body); } |
49 | |
50 | Stmt VisitStmt_(const EvaluateNode* op) final { |
51 | if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op); |
52 | const CallNode* call = op->value.as<CallNode>(); |
53 | // Given a packed call f(A,B,C), we need a set of new statements |
54 | // let A_packed = set_struct(tvm_value1, A) |
55 | // let B_packed = set_struct(tvm_value2, B) |
56 | // let C_packed = set_struct(tvm_value3, C) |
57 | // call_packed(f, A_packed, B_packed, C_packed) |
58 | if (call) { |
59 | if (call->op.same_as(builtin::tvm_call_cpacked())) { |
60 | Array<PrimExpr> packed_args{call->args[0]}; |
61 | VLOG(2) << "Legalize call:" << call; |
62 | BaseFunc base_func = mod_->Lookup(Downcast<StringImm>(call->args[0])->value); |
63 | const PrimFuncNode* prim_func = base_func.as<PrimFuncNode>(); |
64 | VLOG(2) << " to func " << base_func; |
65 | for (unsigned i = 1; i < call->args.size() - 1; i++) { |
66 | // No need to pack inputs of the prim_func |
67 | if (inputs_[call->args[i]] == true) { |
68 | packed_args.push_back(call->args[i]); |
69 | } else { |
70 | // Stack-allocate a DLTensor for this parameter. Note that LowerTVMBuiltin will collect |
71 | // all such stack-allocated tensors and minimize the storage needed by reusing |
72 | // DLTensors. |
73 | Array<PrimExpr> call_args{call->args[i]}; |
74 | tvm::runtime::Map<tvm::tir::Var, tvm::tir::Buffer>::iterator param_buf_it; |
75 | if (prim_func != nullptr) { |
76 | auto param_var = prim_func->params[i - 1]; |
77 | param_buf_it = prim_func->buffer_map.find(param_var); |
78 | } |
79 | if (prim_func != nullptr && param_buf_it != prim_func->buffer_map.end()) { |
80 | Buffer param = (*param_buf_it).second; |
81 | PrimExpr shape = tvm::tir::Call( |
82 | DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), param->shape); |
83 | Cast var_type(param->dtype, IntImm(DataType::Int(32), 0)); |
84 | call_args.push_back(shape /* shape */); |
85 | call_args.push_back(make_zero(DataType::Handle()) /* strides */); |
86 | call_args.push_back(tvm::IntImm(DataType::UInt(32), param->shape.size()) /* ndim */); |
87 | call_args.push_back(var_type /* carries dtype */); |
88 | call_args.push_back(param->elem_offset /* elem_offset */); |
89 | } else { |
90 | // When the PrimFunc cannot be found, most DLTensor information cannot be populated. |
91 | PrimExpr shape = tvm::tir::Call( |
92 | DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), Array<PrimExpr>()); |
93 | Cast var_type(DataType::Handle(), IntImm(DataType::Int(32), 0)); |
94 | call_args.push_back(shape /* shape */); |
95 | call_args.push_back(make_zero(DataType::Handle()) /* strides */); |
96 | call_args.push_back(tvm::IntImm(DataType::UInt(32), 0) /* ndim */); |
97 | call_args.push_back(var_type /* carries dtype */); |
98 | call_args.push_back(tvm::IntImm(DataType::UInt(64), 0) /* elem_offset */); |
99 | } |
100 | packed_args.push_back(tvm::tir::Call( |
101 | DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), call_args)); |
102 | } |
103 | } |
104 | packed_args.push_back(call->args[call->args.size() - 1]); // push device_context |
105 | // Evaluate the packed call |
106 | return tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)); |
107 | } |
108 | } |
109 | return StmtExprMutator::VisitStmt_(op); |
110 | } |
111 | |
112 | private: |
113 | IRModule mod_; |
114 | InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed. |
115 | }; |
116 | |
117 | namespace transform { |
118 | |
119 | Pass LegalizePackedCalls() { |
120 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
121 | auto* n = f.CopyOnWrite(); |
122 | |
123 | // Note which Var are inputs and exclude them from packing. |
124 | InputMap inputs; |
125 | for (auto i : f->params) { |
126 | inputs[i] = true; |
127 | } |
128 | n->body = PackedCallLegalizer(m, inputs).Legalize(std::move(n->body)); |
129 | return f; |
130 | }; |
131 | return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls" , {}); |
132 | } |
133 | |
// Register the pass constructor in the global function registry under
// "tir.transform.LegalizePackedCalls".
TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls);
135 | } // namespace transform |
136 | |
137 | } // namespace tir |
138 | } // namespace tvm |
139 | |