1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/usmp/utils.cc |
22 | * \brief Utilities for Unified Static Memory Planner |
23 | */ |
24 | |
25 | #include <tvm/ir/memory_pools.h> |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/runtime/registry.h> |
28 | #include <tvm/tir/analysis.h> |
29 | #include <tvm/tir/builtin.h> |
30 | #include <tvm/tir/function.h> |
31 | #include <tvm/tir/stmt.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | #include <tvm/tir/usmp/utils.h> |
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | namespace usmp { |
38 | |
39 | BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates, |
40 | Integer alignment, BufferInfoKind kind) { |
41 | auto bufinfo_node = make_object<BufferInfoNode>(); |
42 | bufinfo_node->name_hint = name_hint; |
43 | bufinfo_node->size_bytes = size_bytes; |
44 | bufinfo_node->pool_candidates = pool_candidates; |
45 | bufinfo_node->alignment = alignment; |
46 | bufinfo_node->kind = kind; |
47 | data_ = std::move(bufinfo_node); |
48 | } |
49 | |
50 | void BufferInfoNode::SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs) { |
51 | this->conflicts = conflicting_buffer_info_objs; |
52 | } |
53 | |
54 | TVM_REGISTER_NODE_TYPE(BufferInfoNode); |
55 | TVM_REGISTER_GLOBAL("tir.usmp.BufferInfo" ) |
56 | .set_body_typed([](String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates, |
57 | Integer alignment) { |
58 | if (!alignment.defined()) { |
59 | return BufferInfo(name_hint, size_bytes, pool_candidates); |
60 | } |
61 | return BufferInfo(name_hint, size_bytes, pool_candidates, alignment); |
62 | }); |
63 | TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts" ) |
64 | .set_body_method<BufferInfo>(&BufferInfoNode::SetConflicts); |
65 | |
66 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
67 | .set_dispatch<BufferInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { |
68 | auto* node = static_cast<const BufferInfoNode*>(ref.get()); |
69 | std::unordered_map<BufferInfoKind, String> toString = { |
70 | {BufferInfoKind::kIntermediate, "kIntermediate" }, |
71 | {BufferInfoKind::kInput, "kInput" }, |
72 | {BufferInfoKind::kOutput, "kOutput" }}; |
73 | p->stream << "BufferInfoNode(\n" |
74 | << "name_hint=" << node->name_hint << ",\n size_bytes=" << node->size_bytes |
75 | << ",\n pool_candidates=" << node->pool_candidates |
76 | << ",\n alignment=" << node->alignment << ",\n kind=" << toString[node->kind] |
77 | << ",\n conflicts=" << node->conflicts.size() << ")" ; |
78 | }); |
79 | |
80 | BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, |
81 | Integer memory_pressure) { |
82 | auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>(); |
83 | bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts; |
84 | bufinfo_analysis_node->memory_pressure = memory_pressure; |
85 | data_ = std::move(bufinfo_analysis_node); |
86 | } |
87 | |
88 | TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode); |
89 | TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis" ) |
90 | .set_body_typed([](Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure) { |
91 | return BufferInfoAnalysis(buffer_info_stmts, memory_pressure); |
92 | }); |
93 | |
94 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
95 | .set_dispatch<BufferInfoAnalysisNode>([](const ObjectRef& ref, ReprPrinter* p) { |
96 | auto* node = static_cast<const BufferInfoAnalysisNode*>(ref.get()); |
97 | p->stream << "BufferInfoAnalysisNode(\n" |
98 | << "buffer_info_stmts=" << node->buffer_info_stmts |
99 | << ",\n memory_pressure=" << node->memory_pressure << ")" ; |
100 | }); |
101 | |
102 | PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) { |
103 | auto pool_allocation_node = make_object<PoolAllocationNode>(); |
104 | pool_allocation_node->pool_info = pool_info; |
105 | pool_allocation_node->byte_offset = byte_offset; |
106 | data_ = std::move(pool_allocation_node); |
107 | } |
108 | |
109 | TVM_REGISTER_NODE_TYPE(PoolAllocationNode); |
110 | TVM_REGISTER_GLOBAL("tir.usmp.PoolAllocation" ) |
111 | .set_body_typed([](PoolInfo pool_info, Integer byte_offset) { |
112 | return PoolAllocation(pool_info, byte_offset); |
113 | }); |
114 | |
115 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
116 | .set_dispatch<PoolAllocationNode>([](const ObjectRef& ref, ReprPrinter* p) { |
117 | auto* node = static_cast<const PoolAllocationNode*>(ref.get()); |
118 | p->stream << "PoolAllocationNode(\n" |
119 | << "pool_info=" << node->pool_info << ",\n byte_offset=" << node->byte_offset |
120 | << ")" ; |
121 | }); |
122 | |
123 | AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, |
124 | Integer pool_var_idx) { |
125 | auto allocated_poolinfo_node = make_object<AllocatedPoolInfoNode>(); |
126 | allocated_poolinfo_node->pool_info = pool_info; |
127 | allocated_poolinfo_node->allocated_size = allocated_size; |
128 | if (pool_var_idx.defined()) { |
129 | allocated_poolinfo_node->pool_var_idx = pool_var_idx; |
130 | } |
131 | data_ = std::move(allocated_poolinfo_node); |
132 | } |
133 | |
134 | TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode); |
135 | TVM_REGISTER_GLOBAL("ir.AllocatedPoolInfo" ) |
136 | .set_body_typed([](PoolInfo pool_info, Integer allocated_size, Integer pool_var_idx) { |
137 | return AllocatedPoolInfo(pool_info, allocated_size, pool_var_idx); |
138 | }); |
139 | |
140 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
141 | .set_dispatch<AllocatedPoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { |
142 | auto* node = static_cast<const AllocatedPoolInfoNode*>(ref.get()); |
143 | p->stream << "AllocatedPoolInfoNode(\n" |
144 | << "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size |
145 | << ")" ; |
146 | }); |
147 | |
148 | Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) { |
149 | Array<BufferInfo> ret; |
150 | for (const auto& kv : buffer_info_map) { |
151 | auto buffer_info = kv.first; |
152 | ret.push_back(buffer_info); |
153 | } |
154 | return ret; |
155 | } |
156 | |
157 | Map<Stmt, PoolAllocation> AssignStmtPoolAllocations( |
158 | const Map<BufferInfo, Stmt>& buffer_info_to_stmt, |
159 | const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) { |
160 | Map<Stmt, PoolAllocation> ret; |
161 | for (const auto& kv : buffer_info_to_pool_allocation) { |
162 | BufferInfo bi = kv.first; |
163 | Stmt stmt_ = buffer_info_to_stmt[bi]; |
164 | PoolAllocation pa = kv.second; |
165 | ret.Set(stmt_, pa); |
166 | } |
167 | return ret; |
168 | } |
169 | |
170 | Map<String, PoolAllocation> GetIOPoolAllocations( |
171 | const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) { |
172 | Map<String, PoolAllocation> io_tensor_name_to_pool_allocation; |
173 | for (const auto& kv : buffer_info_to_pool_allocation) { |
174 | BufferInfo buffer_info = kv.first; |
175 | PoolAllocation pool_allocation = kv.second; |
176 | if (buffer_info->kind != BufferInfoKind::kIntermediate) { |
177 | io_tensor_name_to_pool_allocation.Set(buffer_info->name_hint, pool_allocation); |
178 | } |
179 | } |
180 | return io_tensor_name_to_pool_allocation; |
181 | } |
182 | |
183 | static Integer CalculateExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents) { |
184 | size_t element_size_bytes = dtype.bytes(); |
185 | size_t num_elements = 1; |
186 | for (const auto& ext : extents) { |
187 | if (ext->IsInstance<IntImmNode>()) { |
188 | num_elements *= Downcast<IntImm>(ext)->value; |
189 | } else { |
190 | // We can't statically calculate workspace for dynamic shapes |
191 | return Integer(); |
192 | } |
193 | } |
194 | return Integer(num_elements * element_size_bytes); |
195 | } |
196 | |
197 | Integer CalculateExtentsSize(const AllocateNode* op) { |
198 | return CalculateExtentsSize(op->dtype, op->extents); |
199 | } |
200 | |
201 | Integer CalculateExtentsSize(const AllocateConstNode* op) { |
202 | return CalculateExtentsSize(op->dtype, op->extents); |
203 | } |
204 | |
205 | class ModuleWorkspaceSizeCalculator : public StmtExprVisitor { |
206 | public: |
207 | explicit ModuleWorkspaceSizeCalculator(const IRModule& module) : mod_(module) { |
208 | for (const auto& gv_func : mod_->functions) { |
209 | if ((gv_func.second)->IsInstance<tir::PrimFuncNode>()) { |
210 | functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second)); |
211 | } |
212 | } |
213 | main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
214 | ICHECK(main_func_.defined()) << "main function is not in the module" ; |
215 | Optional<Target> target_host = main_func_->GetAttr<Target>(tvm::attr::kTarget); |
216 | ICHECK(target_host) << "main function does not have a target attr" ; |
217 | target_host_ = target_host.value(); |
218 | } |
219 | |
220 | Integer operator()() { |
221 | UpdateWorkspaceData(main_func_); |
222 | return Integer(max_workspace_size); |
223 | } |
224 | |
225 | private: |
226 | void UpdateWorkspaceData(const PrimFunc& func) { |
227 | Target tgt = func->GetAttr<Target>(tvm::attr::kTarget).value_or(target_host_); |
228 | Integer workspace_byte_alignment = |
229 | tgt->GetAttr<Integer>("workspace-byte-alignment" ).value_or(16); |
230 | Integer workspace_req = CalculateWorkspaceBytes(func, workspace_byte_alignment); |
231 | if (workspace_req.IntValue() != 0) { |
232 | current_workspace_size_ += workspace_req->value; |
233 | } |
234 | if (max_workspace_size < current_workspace_size_) { |
235 | max_workspace_size = current_workspace_size_; |
236 | } |
237 | this->VisitStmt(func->body); |
238 | if (workspace_req.IntValue() != 0) { |
239 | current_workspace_size_ -= workspace_req->value; |
240 | } |
241 | } |
242 | |
243 | void VisitExpr_(const CallNode* op) override { |
244 | if (op->op.same_as(builtin::call_extern())) { |
245 | PrimFunc func = functions_.at(Downcast<StringImm>(op->args[0])->value); |
246 | UpdateWorkspaceData(func); |
247 | } else if (op->op->IsInstance<PrimFuncNode>()) { |
248 | PrimFunc func = Downcast<PrimFunc>(op->op); |
249 | UpdateWorkspaceData(func); |
250 | } else { |
251 | StmtExprVisitor::VisitExpr_(op); |
252 | } |
253 | } |
254 | |
255 | IRModule mod_; |
256 | Target target_host_; |
257 | PrimFunc main_func_; |
258 | Map<String, PrimFunc> functions_; |
259 | size_t current_workspace_size_ = 0; |
260 | size_t max_workspace_size = 0; |
261 | }; |
262 | |
263 | Integer CalculateModuleWorkspaceSize(const IRModule& mod) { |
264 | return ModuleWorkspaceSizeCalculator(mod)(); |
265 | } |
266 | |
267 | TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo" ) |
268 | .set_body_typed([](Map<BufferInfo, Stmt> buffer_info_map) { |
269 | return (ConvertToArrayOfBufferInfo(buffer_info_map)); |
270 | }); |
271 | |
272 | TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations" ).set_body_typed(AssignStmtPoolAllocations); |
273 | |
274 | } // namespace usmp |
275 | } // namespace tir |
276 | } // namespace tvm |
277 | |