1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/usmp/utils.h
22 * \brief Utilities for Unified Static Memory Planner
23 */
24
25#ifndef TVM_TIR_USMP_UTILS_H_
26#define TVM_TIR_USMP_UTILS_H_
27
28#include <tvm/ir/expr.h>
29#include <tvm/ir/memory_pools.h>
30#include <tvm/runtime/device_api.h>
31#include <tvm/target/target.h>
32#include <tvm/tir/stmt.h>
33
34namespace tvm {
35
/*!
 * \brief PassContext option to enable the USMP
 */
constexpr const char* kUSMPEnableOption = "tir.usmp.enable";
/*!
 * \brief PassContext option to select the memory planning algorithm in USMP
 */
constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
/*!
 * \brief PassContext option to enable placing I/O tensors in the workspace
 */
constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
/*!
 * \brief PassContext option to specify a custom memory planning algorithm in USMP.
 * The algorithm should be provided as a registered PackedFunc with the name
 * tir.usmp.algorithm.NAME
 */
constexpr const char* kUSMPCustomAlgorithmOption = "tir.usmp.custom_algorithm";
53
54namespace tir {
55namespace usmp {
/*!
 * \brief A special kind to distinguish between I/O tensors to the model
 * and intermediate tensors of the model
 */
enum class BufferInfoKind {
  kIntermediate = 0,  // an intermediate tensor of the model
  kInput = 1,         // an input tensor to the model
  kOutput = 2         // an output tensor of the model
};
61
/*!
 * \brief Describes an abstract memory buffer that will get allocated inside a pool.
 * The actual memory buffer is represented by PoolAllocationNode after static memory planning.
 *
 * See also for relay-level counterparts:
 * relay::StorageToken (graph_plan_memory.cc)
 * relay::backend::StorageInfoNode (relay/backend/utils.h)
 * Region (python/tvm/relay/transform/memory_plan.py)
 */
struct BufferInfoNode : public Object {
  /*! \brief The name of the buffer var */
  String name_hint;
  /*! \brief The size in terms of bytes */
  Integer size_bytes;
  /*! \brief The pool candidates that this buffer can get pooled to */
  Array<PoolInfo> pool_candidates;
  /*! \brief The byte alignment required for buffers that will be placed within the pool */
  Integer alignment;
  /*! \brief The liveness conflicting other buffer info objects */
  Array<ObjectRef> conflicts;
  /*! \brief Whether BufferInfo object retains info about IO tensors or intermediaries */
  BufferInfoKind kind;

  // Expose all fields to the TVM reflection/visitor machinery.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("name_hint", &name_hint);
    v->Visit("size_bytes", &size_bytes);
    v->Visit("pool_candidates", &pool_candidates);
    v->Visit("alignment", &alignment);
    v->Visit("conflicts", &conflicts);
    v->Visit("kind", &kind);
  }

  // Structural equality: two BufferInfoNodes are equal iff all fields compare equal.
  bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
    return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
           equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) &&
           equal(conflicts, other->conflicts) && equal(kind, other->kind);
  }

  // Structural hash over the same fields used by SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name_hint);
    hash_reduce(size_bytes);
    hash_reduce(alignment);
    hash_reduce(conflicts);
    hash_reduce(pool_candidates);
    hash_reduce(kind);
  }
  /*!
   * \brief Set the liveness conflicts of this BufferInfo
   *
   * \param conflicting_buffer_info_objs An array of BufferInfo that conflicts in liveness
   */
  TVM_DLL void SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs);

  static constexpr const char* _type_key = "tir.usmp.BufferInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object);
};
118
/*!
 * \brief Managed reference to BufferInfoNode.
 * \sa BufferInfoNode
 */
class BufferInfo : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param name_hint the name of the buffer var
   * \param size_bytes the size of the buffer in bytes
   * \param pool_candidates the pools this buffer can be placed in
   * \param alignment byte alignment required within the pool
   * \param kind whether this buffer holds an intermediate, input or output tensor
   */
  TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
                     Integer alignment = runtime::kDefaultWorkspaceAlignment,
                     BufferInfoKind kind = BufferInfoKind::kIntermediate);
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
};
126
127/*!
128 * \brief This is a composite node that is produced by extract_buffer_info
129 * analysis pass that contains useful global information that could be useful
130 * for memory planning algorithms.
131 */
132struct BufferInfoAnalysisNode : public Object {
133 /*! \brief The BufferInfo object and its associated TIR statement */
134 Map<BufferInfo, tir::Stmt> buffer_info_stmts;
135 /*! \brief This represent maximum amount of memory being used at
136 * any point of time in the inference. This value is largely the
137 * best allocation an algorithm could achieve. Due to
138 * the complexities of conflict graphs, it would not be feasible
139 * to achieve this value, practically. However, it can be useful
140 * for iterative algorithms to know this value to define termination
141 * criteria.*/
142 Integer memory_pressure;
143
144 void VisitAttrs(tvm::AttrVisitor* v) {
145 v->Visit("buffer_info_stmts", &buffer_info_stmts);
146 v->Visit("memory_pressure", &memory_pressure);
147 }
148
149 bool SEqualReduce(const BufferInfoAnalysisNode* other, SEqualReducer equal) const {
150 return equal(buffer_info_stmts, other->buffer_info_stmts) &&
151 equal(memory_pressure, other->memory_pressure);
152 }
153
154 void SHashReduce(SHashReducer hash_reduce) const {
155 hash_reduce(buffer_info_stmts);
156 hash_reduce(memory_pressure);
157 }
158};
159
/*!
 * \brief Managed reference to BufferInfoAnalysisNode.
 * \sa BufferInfoAnalysisNode
 */
class BufferInfoAnalysis : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param buffer_info_stmts map of BufferInfo objects to their TIR statements
   * \param memory_pressure the peak memory usage of the model (see node docs)
   */
  TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure);
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode);
};
165
/*!
 * \brief The pool allocation produced after the USMP algorithm
 */
struct PoolAllocationNode : public Object {
  /*! \brief The assigned WorkspacePoolInfo or ConstantPoolInfo object */
  PoolInfo pool_info;
  /*! \brief The byte offset within the pool */
  Integer byte_offset;

  // Expose all fields to the TVM reflection/visitor machinery.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pool_info", &pool_info);
    v->Visit("byte_offset", &byte_offset);
  }

  // Structural equality: all fields must compare equal.
  bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const {
    return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset);
  }

  // Structural hash over the same fields used by SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(pool_info);
    hash_reduce(byte_offset);
  }

  static constexpr const char* _type_key = "tir.usmp.PoolAllocation";
  TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object);
};
192
/*!
 * \brief Managed reference to PoolAllocationNode.
 * \sa PoolAllocationNode
 */
class PoolAllocation : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param pool_info the pool this allocation was assigned to
   * \param byte_offset the byte offset of the buffer within the pool
   */
  TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset);
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
};
198
/*!
 * \brief This object contains information post-allocation for PoolInfo objects
 */
struct AllocatedPoolInfoNode : public Object {
  /*! \brief The assigned PoolInfo object */
  PoolInfo pool_info;
  /*! \brief The allocated size into this pool */
  Integer allocated_size;
  /*! \brief An optional associated pool Var index of PrimFunc params */
  Optional<Integer> pool_var_idx;

  // Expose all fields to the TVM reflection/visitor machinery.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pool_info", &pool_info);
    v->Visit("allocated_size", &allocated_size);
    v->Visit("pool_var_idx", &pool_var_idx);
  }

  // Structural equality: all fields must compare equal.
  bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
    return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
           equal(pool_var_idx, other->pool_var_idx);
  }

  // Structural hash over the same fields used by SEqualReduce.
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(pool_info);
    hash_reduce(allocated_size);
    hash_reduce(pool_var_idx);
  }

  static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object);
};
230
/*!
 * \brief Managed reference to AllocatedPoolInfoNode.
 * \sa AllocatedPoolInfoNode
 */
class AllocatedPoolInfo : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param pool_info the pool that was allocated
   * \param allocated_size the total size allocated into the pool
   * \param pool_var_idx optional PrimFunc param index of the pool Var; a
   *        default-constructed (null) Integer means no index is associated
   */
  TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
                            Integer pool_var_idx = Integer());
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
};
237
/*!
 * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
 *
 * \param buffer_info_map IR-bound BufferInfo map
 * \return The BufferInfo objects of the map as a flat array
 */
Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map);
244
/*!
 * \brief Calculate workspace required to execute an IRModule with main expressed in TIR
 *
 * \param mod the IRModule with TIR-based main function
 * \return The required workspace size in bytes
 */
Integer CalculateModuleWorkspaceSize(const IRModule& mod);
251
/*!
 * \brief The allocate node attribute to indicate candidate memory pools.
 * This needs to be kept in sync with CANDIDATE_MEMORY_POOL_ATTR in
 * python/tvm/tir/usmp/utils.py.
 */
static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";

/*!
 * \brief The allocate node attribute to indicate it is being used to hold
 * an input tensor that the allocation needs to be initialized with.
 */
static constexpr const char* kInputTensorAllocate = "input_tensor";

/*!
 * \brief The allocate node attribute to indicate it is being used to hold
 * an output tensor.
 */
static constexpr const char* kOutputTensorAllocate = "output_tensor";
270
/*!
 * \brief Calculate the size of the extents in bytes
 *
 * \param op the allocate node
 * \return The size of the allocation in bytes
 */
Integer CalculateExtentsSize(const AllocateNode* op);

/*!
 * \brief Calculate the size of the extents in bytes
 *
 * \param op the allocate const node
 * \return The size of the allocation in bytes
 */
Integer CalculateExtentsSize(const AllocateConstNode* op);
284
/*!
 * \brief Joins the Stmt nodes with PoolAllocation objects
 *
 * \param buffer_info_to_stmt the map of BufferInfo objects to Stmt nodes
 * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects
 * \return A map of each Stmt to the PoolAllocation assigned to its BufferInfo
 */
Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
    const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
294
/*!
 * \brief Obtains I/O tensor names to their PoolAllocation objects
 *
 * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects
 * \return A map of I/O tensor names to their assigned PoolAllocation objects
 *
 * This function will obtain pool allocations for I/O tensors if that had been planned
 */
Map<String, PoolAllocation> GetIOPoolAllocations(
    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
304
305} // namespace usmp
306} // namespace tir
307
308namespace attr {
/*!
 * \brief This is a BaseFunc attribute to indicate which input var represent
 * a PoolInfo Object in the form of a Map<Var, PoolInfo>.
 */
static constexpr const char* kPoolArgs = "pool_args";

/*!
 * \brief This is an IRModule attribute that contains I/O Tensor names to pool
 * allocations.
 */
static constexpr const char* kIOTensorPoolAllocations = "io_tensor_pool_allocations";
320
321} // namespace attr
322
323} // namespace tvm
324
325#endif // TVM_TIR_USMP_UTILS_H_
326