1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/target/target.h> |
21 | #include <tvm/tir/builtin.h> |
22 | #include <tvm/tir/op.h> |
23 | #include <tvm/tir/stmt_functor.h> |
24 | #include <tvm/tir/transform.h> |
25 | #include <tvm/tir/usmp/algorithms.h> |
26 | #include <tvm/tir/usmp/analysis.h> |
27 | #include <tvm/tir/usmp/transform.h> |
28 | #include <tvm/tir/usmp/utils.h> |
29 | |
30 | #include <stack> |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | namespace usmp { |
36 | |
37 | /*! \brief Creates Allocate nodes with special annotations |
38 | * for I/O tensors in the graph to be memory planned.*/ |
39 | class IOAllocateCreator : public StmtExprVisitor { |
40 | public: |
41 | explicit IOAllocateCreator(const IRModule& module) { |
42 | main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
43 | ICHECK(main_func_.defined()) << "main function is not in the module" ; |
44 | for (const auto& gv_func : module->functions) { |
45 | if (gv_func.second->IsInstance<PrimFuncNode>()) { |
46 | functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second)); |
47 | } |
48 | } |
49 | mod_ = module->ShallowCopy(); |
50 | } |
51 | IRModule operator()(); |
52 | |
53 | private: |
54 | void VisitExpr_(const BufferLoadNode* op) override; |
55 | void VisitExpr_(const LoadNode* op) override; |
56 | void VisitExpr_(const CallNode* op) override; |
57 | void VisitStmt_(const BufferStoreNode* op) override; |
58 | void VisitStmt_(const StoreNode* op) override; |
59 | |
60 | /*! \brief Updates aliases that buffer vars inside the primfunc refer |
61 | * to in terms call arguments they get bound to.*/ |
62 | void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func); |
63 | |
64 | /*! \brief The IRModule that is being mutated */ |
65 | IRModule mod_; |
66 | /*! \brief The main function that calls into operator subgraphs */ |
67 | PrimFunc main_func_; |
68 | /*! \brief The input Vars of the main function */ |
69 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inputs_; |
70 | /*! \brief The output Vars of the main function */ |
71 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_; |
72 | /*! \brief The buffer vars associated with the I/O Vars */ |
73 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> io_buffer_vars_; |
74 | /*! \brief The aliases that buffer vars inside the primfunc refer |
75 | * to in terms call arguments */ |
76 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> aliases_; |
77 | /*! |
78 | * \brief The TIR main function calls by name to PrimFuncs to be able to |
79 | * support BYOC. Therefore, this Map records functions that are present |
80 | * in the IRModule by name/ |
81 | */ |
82 | Map<String, PrimFunc> functions_; |
83 | }; |
84 | |
85 | /*! |
86 | * \brief The function obtains the matched buffer vars for |
87 | * the params of the PrimFunc. |
88 | */ |
89 | Array<Var> static GetMatchedBuffers(const PrimFunc& func) { |
90 | Array<Var> buffer_vars; |
91 | for (unsigned int i = 0; i < func->params.size() - 1; i++) { |
92 | Var param = func->params[i]; |
93 | buffer_vars.push_back(func->buffer_map[param]->data); |
94 | } |
95 | Var last_param = func->params.back(); |
96 | // Checks whether last var is present in the buffer map |
97 | // because it could be the resource handle |
98 | if (func->buffer_map.find(last_param) != func->buffer_map.end()) { |
99 | buffer_vars.push_back(func->buffer_map[last_param]->data); |
100 | } |
101 | return buffer_vars; |
102 | } |
103 | |
104 | /*! |
105 | * \brief The function updates aliases that each buffer var with its |
106 | * associated argument in the callsite. |
107 | */ |
108 | void IOAllocateCreator::UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func) { |
109 | auto param_buffers = GetMatchedBuffers(func); |
110 | // Last var could be a resource handle that does not have a Buffer |
111 | ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size()); |
112 | for (size_t i = 0; i < param_buffers.size(); i++) { |
113 | auto arg = args[i]; |
114 | if (arg->IsInstance<VarNode>()) { |
115 | auto param_buf = param_buffers[i]; |
116 | aliases_[param_buf] = Downcast<Var>(arg); |
117 | } |
118 | } |
119 | } |
120 | |
121 | void IOAllocateCreator::VisitExpr_(const CallNode* op) { |
122 | if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { |
123 | StringImm func_name = Downcast<StringImm>(op->args[0])->value; |
124 | if (functions_.find(func_name->value) != functions_.end()) { |
125 | auto func = functions_.at(func_name->value); |
126 | auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end()); |
127 | this->UpdateAliases(actual_args, func); |
128 | VisitStmt(func->body); |
129 | return; |
130 | } |
131 | } |
132 | if (op->op->IsInstance<PrimFuncNode>()) { |
133 | auto func = Downcast<PrimFunc>(op->op); |
134 | this->UpdateAliases(op->args, func); |
135 | VisitStmt(func->body); |
136 | return; |
137 | } |
138 | StmtExprVisitor::VisitExpr_(op); |
139 | } |
140 | |
141 | void IOAllocateCreator::VisitExpr_(const BufferLoadNode* op) { |
142 | if (aliases_.find(op->buffer->data) != aliases_.end()) { |
143 | Var aliased_var = aliases_[op->buffer->data]; |
144 | if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) { |
145 | ICHECK(outputs_.find(aliased_var) == outputs_.end()) |
146 | << "BufferLoad nodes should not be reading from output buffer vars." ; |
147 | inputs_.insert(aliased_var); |
148 | } |
149 | } |
150 | StmtExprVisitor::VisitExpr_(op); |
151 | } |
152 | |
153 | void IOAllocateCreator::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "should not come here" ; } |
154 | |
155 | void IOAllocateCreator::VisitStmt_(const BufferStoreNode* op) { |
156 | if (aliases_.find(op->buffer->data) != aliases_.end()) { |
157 | Var aliased_var = aliases_[op->buffer->data]; |
158 | if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) { |
159 | ICHECK(inputs_.find(aliased_var) == inputs_.end()) |
160 | << "BufferStore nodes should not be writing to input buffer vars." ; |
161 | outputs_.insert(aliased_var); |
162 | } |
163 | } |
164 | StmtExprVisitor::VisitStmt_(op); |
165 | } |
166 | |
167 | void IOAllocateCreator::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "should not come here" ; } |
168 | |
169 | IRModule IOAllocateCreator::operator()() { |
170 | Array<Var> new_main_params; |
171 | Stmt main_body = main_func_->body; |
172 | for (const Var& param : main_func_->params) { |
173 | if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) { |
174 | Var buffer_var = main_func_->buffer_map[param]->data; |
175 | io_buffer_vars_.insert(buffer_var); |
176 | aliases_[buffer_var] = buffer_var; |
177 | } |
178 | } |
179 | VisitStmt(main_body); |
180 | ICHECK(io_buffer_vars_.size() == inputs_.size() + outputs_.size()) |
181 | << "Every IO Buffer var should be categorized either to be input or output" ; |
182 | for (const Var& param : main_func_->params) { |
183 | if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) { |
184 | Buffer param_buffer = main_func_->buffer_map[param]; |
185 | String io_annotation; |
186 | if (inputs_.find(param_buffer->data) != inputs_.end()) { |
187 | io_annotation = String(kInputTensorAllocate); |
188 | } else { |
189 | io_annotation = String(kOutputTensorAllocate); |
190 | } |
191 | main_body = Allocate(param_buffer->data, param_buffer->dtype, param_buffer->shape, |
192 | const_true(), main_body, {{io_annotation, param->name_hint}}); |
193 | } else { |
194 | new_main_params.push_back(param); |
195 | } |
196 | } |
197 | const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); |
198 | mod_->Update(gv, PrimFunc(new_main_params, main_body, main_func_->ret_type, |
199 | main_func_->buffer_map, main_func_->attrs, main_func_->span)); |
200 | return mod_; |
201 | } |
202 | |
203 | namespace transform { |
204 | |
205 | tvm::transform::Pass CreateAllocatesForIO() { |
206 | auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { |
207 | return IOAllocateCreator(m)(); |
208 | }; |
209 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.CreateAllocatesForIO" , {}); |
210 | } |
211 | |
212 | TVM_REGISTER_GLOBAL("tir.usmp.transform.CreateAllocatesForIO" ).set_body_typed(CreateAllocatesForIO); |
213 | |
214 | } // namespace transform |
215 | |
216 | } // namespace usmp |
217 | } // namespace tir |
218 | } // namespace tvm |
219 | |