1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/target/target.h>
21#include <tvm/tir/builtin.h>
22#include <tvm/tir/op.h>
23#include <tvm/tir/stmt_functor.h>
24#include <tvm/tir/transform.h>
25#include <tvm/tir/usmp/algorithms.h>
26#include <tvm/tir/usmp/analysis.h>
27#include <tvm/tir/usmp/transform.h>
28#include <tvm/tir/usmp/utils.h>
29
30#include <stack>
31#include <string>
32
33namespace tvm {
34namespace tir {
35namespace usmp {
36
37/*! \brief Creates Allocate nodes with special annotations
38 * for I/O tensors in the graph to be memory planned.*/
39class IOAllocateCreator : public StmtExprVisitor {
40 public:
41 explicit IOAllocateCreator(const IRModule& module) {
42 main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
43 ICHECK(main_func_.defined()) << "main function is not in the module";
44 for (const auto& gv_func : module->functions) {
45 if (gv_func.second->IsInstance<PrimFuncNode>()) {
46 functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
47 }
48 }
49 mod_ = module->ShallowCopy();
50 }
51 IRModule operator()();
52
53 private:
54 void VisitExpr_(const BufferLoadNode* op) override;
55 void VisitExpr_(const LoadNode* op) override;
56 void VisitExpr_(const CallNode* op) override;
57 void VisitStmt_(const BufferStoreNode* op) override;
58 void VisitStmt_(const StoreNode* op) override;
59
60 /*! \brief Updates aliases that buffer vars inside the primfunc refer
61 * to in terms call arguments they get bound to.*/
62 void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
63
64 /*! \brief The IRModule that is being mutated */
65 IRModule mod_;
66 /*! \brief The main function that calls into operator subgraphs */
67 PrimFunc main_func_;
68 /*! \brief The input Vars of the main function */
69 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
70 /*! \brief The output Vars of the main function */
71 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
72 /*! \brief The buffer vars associated with the I/O Vars */
73 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> io_buffer_vars_;
74 /*! \brief The aliases that buffer vars inside the primfunc refer
75 * to in terms call arguments */
76 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> aliases_;
77 /*!
78 * \brief The TIR main function calls by name to PrimFuncs to be able to
79 * support BYOC. Therefore, this Map records functions that are present
80 * in the IRModule by name/
81 */
82 Map<String, PrimFunc> functions_;
83};
84
85/*!
86 * \brief The function obtains the matched buffer vars for
87 * the params of the PrimFunc.
88 */
89Array<Var> static GetMatchedBuffers(const PrimFunc& func) {
90 Array<Var> buffer_vars;
91 for (unsigned int i = 0; i < func->params.size() - 1; i++) {
92 Var param = func->params[i];
93 buffer_vars.push_back(func->buffer_map[param]->data);
94 }
95 Var last_param = func->params.back();
96 // Checks whether last var is present in the buffer map
97 // because it could be the resource handle
98 if (func->buffer_map.find(last_param) != func->buffer_map.end()) {
99 buffer_vars.push_back(func->buffer_map[last_param]->data);
100 }
101 return buffer_vars;
102}
103
104/*!
105 * \brief The function updates aliases that each buffer var with its
106 * associated argument in the callsite.
107 */
108void IOAllocateCreator::UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func) {
109 auto param_buffers = GetMatchedBuffers(func);
110 // Last var could be a resource handle that does not have a Buffer
111 ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size());
112 for (size_t i = 0; i < param_buffers.size(); i++) {
113 auto arg = args[i];
114 if (arg->IsInstance<VarNode>()) {
115 auto param_buf = param_buffers[i];
116 aliases_[param_buf] = Downcast<Var>(arg);
117 }
118 }
119}
120
121void IOAllocateCreator::VisitExpr_(const CallNode* op) {
122 if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
123 StringImm func_name = Downcast<StringImm>(op->args[0])->value;
124 if (functions_.find(func_name->value) != functions_.end()) {
125 auto func = functions_.at(func_name->value);
126 auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end());
127 this->UpdateAliases(actual_args, func);
128 VisitStmt(func->body);
129 return;
130 }
131 }
132 if (op->op->IsInstance<PrimFuncNode>()) {
133 auto func = Downcast<PrimFunc>(op->op);
134 this->UpdateAliases(op->args, func);
135 VisitStmt(func->body);
136 return;
137 }
138 StmtExprVisitor::VisitExpr_(op);
139}
140
141void IOAllocateCreator::VisitExpr_(const BufferLoadNode* op) {
142 if (aliases_.find(op->buffer->data) != aliases_.end()) {
143 Var aliased_var = aliases_[op->buffer->data];
144 if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) {
145 ICHECK(outputs_.find(aliased_var) == outputs_.end())
146 << "BufferLoad nodes should not be reading from output buffer vars.";
147 inputs_.insert(aliased_var);
148 }
149 }
150 StmtExprVisitor::VisitExpr_(op);
151}
152
153void IOAllocateCreator::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "should not come here"; }
154
155void IOAllocateCreator::VisitStmt_(const BufferStoreNode* op) {
156 if (aliases_.find(op->buffer->data) != aliases_.end()) {
157 Var aliased_var = aliases_[op->buffer->data];
158 if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) {
159 ICHECK(inputs_.find(aliased_var) == inputs_.end())
160 << "BufferStore nodes should not be writing to input buffer vars.";
161 outputs_.insert(aliased_var);
162 }
163 }
164 StmtExprVisitor::VisitStmt_(op);
165}
166
167void IOAllocateCreator::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "should not come here"; }
168
169IRModule IOAllocateCreator::operator()() {
170 Array<Var> new_main_params;
171 Stmt main_body = main_func_->body;
172 for (const Var& param : main_func_->params) {
173 if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) {
174 Var buffer_var = main_func_->buffer_map[param]->data;
175 io_buffer_vars_.insert(buffer_var);
176 aliases_[buffer_var] = buffer_var;
177 }
178 }
179 VisitStmt(main_body);
180 ICHECK(io_buffer_vars_.size() == inputs_.size() + outputs_.size())
181 << "Every IO Buffer var should be categorized either to be input or output";
182 for (const Var& param : main_func_->params) {
183 if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) {
184 Buffer param_buffer = main_func_->buffer_map[param];
185 String io_annotation;
186 if (inputs_.find(param_buffer->data) != inputs_.end()) {
187 io_annotation = String(kInputTensorAllocate);
188 } else {
189 io_annotation = String(kOutputTensorAllocate);
190 }
191 main_body = Allocate(param_buffer->data, param_buffer->dtype, param_buffer->shape,
192 const_true(), main_body, {{io_annotation, param->name_hint}});
193 } else {
194 new_main_params.push_back(param);
195 }
196 }
197 const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
198 mod_->Update(gv, PrimFunc(new_main_params, main_body, main_func_->ret_type,
199 main_func_->buffer_map, main_func_->attrs, main_func_->span));
200 return mod_;
201}
202
203namespace transform {
204
205tvm::transform::Pass CreateAllocatesForIO() {
206 auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
207 return IOAllocateCreator(m)();
208 };
209 return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.CreateAllocatesForIO", {});
210}
211
212TVM_REGISTER_GLOBAL("tir.usmp.transform.CreateAllocatesForIO").set_body_typed(CreateAllocatesForIO);
213
214} // namespace transform
215
216} // namespace usmp
217} // namespace tir
218} // namespace tvm
219