1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/usmp/utils.cc
22 * \brief Utilities for Unified Static Memory Planner
23 */
24
25#include <tvm/ir/memory_pools.h>
26#include <tvm/runtime/device_api.h>
27#include <tvm/runtime/registry.h>
28#include <tvm/tir/analysis.h>
29#include <tvm/tir/builtin.h>
30#include <tvm/tir/function.h>
31#include <tvm/tir/stmt.h>
32#include <tvm/tir/stmt_functor.h>
33#include <tvm/tir/usmp/utils.h>
34
35namespace tvm {
36namespace tir {
37namespace usmp {
38
39BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
40 Integer alignment, BufferInfoKind kind) {
41 auto bufinfo_node = make_object<BufferInfoNode>();
42 bufinfo_node->name_hint = name_hint;
43 bufinfo_node->size_bytes = size_bytes;
44 bufinfo_node->pool_candidates = pool_candidates;
45 bufinfo_node->alignment = alignment;
46 bufinfo_node->kind = kind;
47 data_ = std::move(bufinfo_node);
48}
49
50void BufferInfoNode::SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs) {
51 this->conflicts = conflicting_buffer_info_objs;
52}
53
54TVM_REGISTER_NODE_TYPE(BufferInfoNode);
55TVM_REGISTER_GLOBAL("tir.usmp.BufferInfo")
56 .set_body_typed([](String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
57 Integer alignment) {
58 if (!alignment.defined()) {
59 return BufferInfo(name_hint, size_bytes, pool_candidates);
60 }
61 return BufferInfo(name_hint, size_bytes, pool_candidates, alignment);
62 });
63TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts")
64 .set_body_method<BufferInfo>(&BufferInfoNode::SetConflicts);
65
66TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
67 .set_dispatch<BufferInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
68 auto* node = static_cast<const BufferInfoNode*>(ref.get());
69 std::unordered_map<BufferInfoKind, String> toString = {
70 {BufferInfoKind::kIntermediate, "kIntermediate"},
71 {BufferInfoKind::kInput, "kInput"},
72 {BufferInfoKind::kOutput, "kOutput"}};
73 p->stream << "BufferInfoNode(\n"
74 << "name_hint=" << node->name_hint << ",\n size_bytes=" << node->size_bytes
75 << ",\n pool_candidates=" << node->pool_candidates
76 << ",\n alignment=" << node->alignment << ",\n kind=" << toString[node->kind]
77 << ",\n conflicts=" << node->conflicts.size() << ")";
78 });
79
80BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts,
81 Integer memory_pressure) {
82 auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>();
83 bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts;
84 bufinfo_analysis_node->memory_pressure = memory_pressure;
85 data_ = std::move(bufinfo_analysis_node);
86}
87
88TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode);
89TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis")
90 .set_body_typed([](Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure) {
91 return BufferInfoAnalysis(buffer_info_stmts, memory_pressure);
92 });
93
94TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
95 .set_dispatch<BufferInfoAnalysisNode>([](const ObjectRef& ref, ReprPrinter* p) {
96 auto* node = static_cast<const BufferInfoAnalysisNode*>(ref.get());
97 p->stream << "BufferInfoAnalysisNode(\n"
98 << "buffer_info_stmts=" << node->buffer_info_stmts
99 << ",\n memory_pressure=" << node->memory_pressure << ")";
100 });
101
102PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) {
103 auto pool_allocation_node = make_object<PoolAllocationNode>();
104 pool_allocation_node->pool_info = pool_info;
105 pool_allocation_node->byte_offset = byte_offset;
106 data_ = std::move(pool_allocation_node);
107}
108
109TVM_REGISTER_NODE_TYPE(PoolAllocationNode);
110TVM_REGISTER_GLOBAL("tir.usmp.PoolAllocation")
111 .set_body_typed([](PoolInfo pool_info, Integer byte_offset) {
112 return PoolAllocation(pool_info, byte_offset);
113 });
114
115TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
116 .set_dispatch<PoolAllocationNode>([](const ObjectRef& ref, ReprPrinter* p) {
117 auto* node = static_cast<const PoolAllocationNode*>(ref.get());
118 p->stream << "PoolAllocationNode(\n"
119 << "pool_info=" << node->pool_info << ",\n byte_offset=" << node->byte_offset
120 << ")";
121 });
122
123AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
124 Integer pool_var_idx) {
125 auto allocated_poolinfo_node = make_object<AllocatedPoolInfoNode>();
126 allocated_poolinfo_node->pool_info = pool_info;
127 allocated_poolinfo_node->allocated_size = allocated_size;
128 if (pool_var_idx.defined()) {
129 allocated_poolinfo_node->pool_var_idx = pool_var_idx;
130 }
131 data_ = std::move(allocated_poolinfo_node);
132}
133
134TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode);
135TVM_REGISTER_GLOBAL("ir.AllocatedPoolInfo")
136 .set_body_typed([](PoolInfo pool_info, Integer allocated_size, Integer pool_var_idx) {
137 return AllocatedPoolInfo(pool_info, allocated_size, pool_var_idx);
138 });
139
140TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
141 .set_dispatch<AllocatedPoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
142 auto* node = static_cast<const AllocatedPoolInfoNode*>(ref.get());
143 p->stream << "AllocatedPoolInfoNode(\n"
144 << "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size
145 << ")";
146 });
147
148Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
149 Array<BufferInfo> ret;
150 for (const auto& kv : buffer_info_map) {
151 auto buffer_info = kv.first;
152 ret.push_back(buffer_info);
153 }
154 return ret;
155}
156
157Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
158 const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
159 const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
160 Map<Stmt, PoolAllocation> ret;
161 for (const auto& kv : buffer_info_to_pool_allocation) {
162 BufferInfo bi = kv.first;
163 Stmt stmt_ = buffer_info_to_stmt[bi];
164 PoolAllocation pa = kv.second;
165 ret.Set(stmt_, pa);
166 }
167 return ret;
168}
169
170Map<String, PoolAllocation> GetIOPoolAllocations(
171 const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
172 Map<String, PoolAllocation> io_tensor_name_to_pool_allocation;
173 for (const auto& kv : buffer_info_to_pool_allocation) {
174 BufferInfo buffer_info = kv.first;
175 PoolAllocation pool_allocation = kv.second;
176 if (buffer_info->kind != BufferInfoKind::kIntermediate) {
177 io_tensor_name_to_pool_allocation.Set(buffer_info->name_hint, pool_allocation);
178 }
179 }
180 return io_tensor_name_to_pool_allocation;
181}
182
183static Integer CalculateExtentsSize(const DataType& dtype, const Array<PrimExpr>& extents) {
184 size_t element_size_bytes = dtype.bytes();
185 size_t num_elements = 1;
186 for (const auto& ext : extents) {
187 if (ext->IsInstance<IntImmNode>()) {
188 num_elements *= Downcast<IntImm>(ext)->value;
189 } else {
190 // We can't statically calculate workspace for dynamic shapes
191 return Integer();
192 }
193 }
194 return Integer(num_elements * element_size_bytes);
195}
196
197Integer CalculateExtentsSize(const AllocateNode* op) {
198 return CalculateExtentsSize(op->dtype, op->extents);
199}
200
201Integer CalculateExtentsSize(const AllocateConstNode* op) {
202 return CalculateExtentsSize(op->dtype, op->extents);
203}
204
205class ModuleWorkspaceSizeCalculator : public StmtExprVisitor {
206 public:
207 explicit ModuleWorkspaceSizeCalculator(const IRModule& module) : mod_(module) {
208 for (const auto& gv_func : mod_->functions) {
209 if ((gv_func.second)->IsInstance<tir::PrimFuncNode>()) {
210 functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
211 }
212 }
213 main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
214 ICHECK(main_func_.defined()) << "main function is not in the module";
215 Optional<Target> target_host = main_func_->GetAttr<Target>(tvm::attr::kTarget);
216 ICHECK(target_host) << "main function does not have a target attr";
217 target_host_ = target_host.value();
218 }
219
220 Integer operator()() {
221 UpdateWorkspaceData(main_func_);
222 return Integer(max_workspace_size);
223 }
224
225 private:
226 void UpdateWorkspaceData(const PrimFunc& func) {
227 Target tgt = func->GetAttr<Target>(tvm::attr::kTarget).value_or(target_host_);
228 Integer workspace_byte_alignment =
229 tgt->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
230 Integer workspace_req = CalculateWorkspaceBytes(func, workspace_byte_alignment);
231 if (workspace_req.IntValue() != 0) {
232 current_workspace_size_ += workspace_req->value;
233 }
234 if (max_workspace_size < current_workspace_size_) {
235 max_workspace_size = current_workspace_size_;
236 }
237 this->VisitStmt(func->body);
238 if (workspace_req.IntValue() != 0) {
239 current_workspace_size_ -= workspace_req->value;
240 }
241 }
242
243 void VisitExpr_(const CallNode* op) override {
244 if (op->op.same_as(builtin::call_extern())) {
245 PrimFunc func = functions_.at(Downcast<StringImm>(op->args[0])->value);
246 UpdateWorkspaceData(func);
247 } else if (op->op->IsInstance<PrimFuncNode>()) {
248 PrimFunc func = Downcast<PrimFunc>(op->op);
249 UpdateWorkspaceData(func);
250 } else {
251 StmtExprVisitor::VisitExpr_(op);
252 }
253 }
254
255 IRModule mod_;
256 Target target_host_;
257 PrimFunc main_func_;
258 Map<String, PrimFunc> functions_;
259 size_t current_workspace_size_ = 0;
260 size_t max_workspace_size = 0;
261};
262
263Integer CalculateModuleWorkspaceSize(const IRModule& mod) {
264 return ModuleWorkspaceSizeCalculator(mod)();
265}
266
267TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo")
268 .set_body_typed([](Map<BufferInfo, Stmt> buffer_info_map) {
269 return (ConvertToArrayOfBufferInfo(buffer_info_map));
270 });
271
272TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations").set_body_typed(AssignStmtPoolAllocations);
273
274} // namespace usmp
275} // namespace tir
276} // namespace tvm
277