1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/usmp/utils.h |
22 | * \brief Utilities for Unified Static Memory Planner |
23 | */ |
24 | |
25 | #ifndef TVM_TIR_USMP_UTILS_H_ |
26 | #define TVM_TIR_USMP_UTILS_H_ |
27 | |
28 | #include <tvm/ir/expr.h> |
29 | #include <tvm/ir/memory_pools.h> |
30 | #include <tvm/runtime/device_api.h> |
31 | #include <tvm/target/target.h> |
32 | #include <tvm/tir/stmt.h> |
33 | |
34 | namespace tvm { |
35 | |
/*!
 * \brief Key of the PassContext option that enables the USMP.
 */
constexpr const char* kUSMPEnableOption{"tir.usmp.enable"};

/*!
 * \brief Key of the PassContext option that selects which memory planning
 * algorithm the USMP uses.
 */
constexpr const char* kUSMPAlgorithmOption{"tir.usmp.algorithm"};

/*!
 * \brief Key of the PassContext option that enables placing I/O tensors in
 * the workspace.
 */
constexpr const char* kUSMPUseWorkspaceIO{"tir.usmp.use_workspace_io"};

/*!
 * \brief Key of the PassContext option that names a custom memory planning
 * algorithm for the USMP.
 *
 * The algorithm is expected to be provided as a registered PackedFunc named
 * tir.usmp.algorithm.NAME.
 */
constexpr const char* kUSMPCustomAlgorithmOption{"tir.usmp.custom_algorithm"};
53 | |
54 | namespace tir { |
55 | namespace usmp { |
/*!
 * \brief Tags a buffer as either an I/O tensor of the model or an
 * intermediate tensor of the model.
 */
enum class BufferInfoKind {
  /*! \brief An intermediate tensor of the model */
  kIntermediate = 0,
  /*! \brief An input tensor to the model */
  kInput = 1,
  /*! \brief An output tensor of the model */
  kOutput = 2,
};
61 | |
62 | /*! |
63 | * \brief Describes an abstract memory buffer that will get allocated inside a pool. |
64 | * The actual memory buffer in represented by PoolAllocationNode after static memory planning. |
65 | * |
66 | * See also for relay-level counterparts: |
67 | * relay::StorageToken (graph_plan_memory.cc) |
68 | * relay::backend::StorageInfoNode (relay/backend/utils.h) |
69 | * Region (python/tvm/relay/transform/memory_plan.py) |
70 | */ |
71 | struct BufferInfoNode : public Object { |
72 | /*! \brief The name of the buffer var */ |
73 | String name_hint; |
74 | /*! \brief The size in terms of bytes */ |
75 | Integer size_bytes; |
76 | /*! \brief The pool candidates that this buffer can get pooled to*/ |
77 | Array<PoolInfo> pool_candidates; |
78 | /*! \brief The byte alignment required for buffers that will placed within the pool */ |
79 | Integer alignment; |
80 | /*! \brief The liveness conflicting other buffer info objects */ |
81 | Array<ObjectRef> conflicts; |
82 | /*! \brief Whether BufferInfo object retains info about IO tensors or intermediaries */ |
83 | BufferInfoKind kind; |
84 | |
85 | void VisitAttrs(tvm::AttrVisitor* v) { |
86 | v->Visit("name_hint" , &name_hint); |
87 | v->Visit("size_bytes" , &size_bytes); |
88 | v->Visit("pool_candidates" , &pool_candidates); |
89 | v->Visit("alignment" , &alignment); |
90 | v->Visit("conflicts" , &conflicts); |
91 | v->Visit("kind" , &kind); |
92 | } |
93 | |
94 | bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const { |
95 | return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) && |
96 | equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) && |
97 | equal(conflicts, other->conflicts) && equal(kind, other->kind); |
98 | } |
99 | |
100 | void SHashReduce(SHashReducer hash_reduce) const { |
101 | hash_reduce(name_hint); |
102 | hash_reduce(size_bytes); |
103 | hash_reduce(alignment); |
104 | hash_reduce(conflicts); |
105 | hash_reduce(pool_candidates); |
106 | hash_reduce(kind); |
107 | } |
108 | /*! |
109 | * \brief Set the liveness conflicts of this BufferInfo |
110 | * |
111 | * \param conflicting_buffer_info_objs An array of BufferInfo that conflicts in liveness |
112 | */ |
113 | TVM_DLL void SetConflicts(Array<ObjectRef> conflicting_buffer_info_objs); |
114 | |
115 | static constexpr const char* _type_key = "tir.usmp.BufferInfo" ; |
116 | TVM_DECLARE_FINAL_OBJECT_INFO(BufferInfoNode, Object); |
117 | }; |
118 | |
119 | class BufferInfo : public ObjectRef { |
120 | public: |
121 | TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates, |
122 | Integer alignment = runtime::kDefaultWorkspaceAlignment, |
123 | BufferInfoKind kind = BufferInfoKind::kIntermediate); |
124 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode); |
125 | }; |
126 | |
127 | /*! |
128 | * \brief This is a composite node that is produced by extract_buffer_info |
129 | * analysis pass that contains useful global information that could be useful |
130 | * for memory planning algorithms. |
131 | */ |
132 | struct BufferInfoAnalysisNode : public Object { |
133 | /*! \brief The BufferInfo object and its associated TIR statement */ |
134 | Map<BufferInfo, tir::Stmt> buffer_info_stmts; |
135 | /*! \brief This represent maximum amount of memory being used at |
136 | * any point of time in the inference. This value is largely the |
137 | * best allocation an algorithm could achieve. Due to |
138 | * the complexities of conflict graphs, it would not be feasible |
139 | * to achieve this value, practically. However, it can be useful |
140 | * for iterative algorithms to know this value to define termination |
141 | * criteria.*/ |
142 | Integer memory_pressure; |
143 | |
144 | void VisitAttrs(tvm::AttrVisitor* v) { |
145 | v->Visit("buffer_info_stmts" , &buffer_info_stmts); |
146 | v->Visit("memory_pressure" , &memory_pressure); |
147 | } |
148 | |
149 | bool SEqualReduce(const BufferInfoAnalysisNode* other, SEqualReducer equal) const { |
150 | return equal(buffer_info_stmts, other->buffer_info_stmts) && |
151 | equal(memory_pressure, other->memory_pressure); |
152 | } |
153 | |
154 | void SHashReduce(SHashReducer hash_reduce) const { |
155 | hash_reduce(buffer_info_stmts); |
156 | hash_reduce(memory_pressure); |
157 | } |
158 | }; |
159 | |
160 | class BufferInfoAnalysis : public ObjectRef { |
161 | public: |
162 | TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure); |
163 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode); |
164 | }; |
165 | |
166 | /*! |
167 | * \brief The pool allocation produced after the USMP algorithm |
168 | */ |
169 | struct PoolAllocationNode : public Object { |
170 | /*! \brief The assigned WorkspacePoolInfo or ConstantPoolInfo object */ |
171 | PoolInfo pool_info; |
172 | /*! \brief The byte offset within the pool*/ |
173 | Integer byte_offset; |
174 | |
175 | void VisitAttrs(tvm::AttrVisitor* v) { |
176 | v->Visit("pool_info" , &pool_info); |
177 | v->Visit("byte_offset" , &byte_offset); |
178 | } |
179 | |
180 | bool SEqualReduce(const PoolAllocationNode* other, SEqualReducer equal) const { |
181 | return equal(pool_info, other->pool_info) && equal(byte_offset, other->byte_offset); |
182 | } |
183 | |
184 | void SHashReduce(SHashReducer hash_reduce) const { |
185 | hash_reduce(pool_info); |
186 | hash_reduce(byte_offset); |
187 | } |
188 | |
189 | static constexpr const char* _type_key = "tir.usmp.PoolAllocation" ; |
190 | TVM_DECLARE_FINAL_OBJECT_INFO(PoolAllocationNode, Object); |
191 | }; |
192 | |
193 | class PoolAllocation : public ObjectRef { |
194 | public: |
195 | TVM_DLL PoolAllocation(PoolInfo pool_info, Integer byte_offset); |
196 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode); |
197 | }; |
198 | |
199 | /*! |
200 | * \brief This object contains information post-allocation for PoolInfo objects |
201 | */ |
202 | struct AllocatedPoolInfoNode : public Object { |
203 | /*! \brief The assigned PoolInfo object */ |
204 | PoolInfo pool_info; |
205 | /*! \brief The allocated size into this pool */ |
206 | Integer allocated_size; |
207 | /*! \brief An optional associated pool Var index of PrimFunc params*/ |
208 | Optional<Integer> pool_var_idx; |
209 | |
210 | void VisitAttrs(tvm::AttrVisitor* v) { |
211 | v->Visit("pool_info" , &pool_info); |
212 | v->Visit("allocated_size" , &allocated_size); |
213 | v->Visit("pool_var_idx" , &pool_var_idx); |
214 | } |
215 | |
216 | bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const { |
217 | return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) && |
218 | equal(pool_var_idx, other->pool_var_idx); |
219 | } |
220 | |
221 | void SHashReduce(SHashReducer hash_reduce) const { |
222 | hash_reduce(pool_info); |
223 | hash_reduce(allocated_size); |
224 | hash_reduce(pool_var_idx); |
225 | } |
226 | |
227 | static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo" ; |
228 | TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object); |
229 | }; |
230 | |
231 | class AllocatedPoolInfo : public ObjectRef { |
232 | public: |
233 | TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, |
234 | Integer pool_var_idx = Integer()); |
235 | TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode); |
236 | }; |
237 | |
238 | /*! |
239 | * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo |
240 | * |
241 | * \param buffer_info_map IR-bound BufferInfo map |
242 | */ |
243 | Array<BufferInfo> ConvertToArrayOfBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map); |
244 | |
245 | /*! |
246 | * \brief Calculate workspace required to execute a IRModule with main expressed in TIR |
247 | * |
248 | * \param mod the IRModule with TIR-based main function |
249 | */ |
250 | Integer CalculateModuleWorkspaceSize(const IRModule& mod); |
251 | |
/*!
 * \brief Allocate node attribute that indicates the candidate memory pools.
 *
 * Keep this in sync with CANDIDATE_MEMORY_POOL_ATTR in
 * python/tvm/tir/usmp/utils.py.
 */
static constexpr const char* kPoolCandidatesAllocateAttr{"candidate_memory_pools"};

/*!
 * \brief Allocate node attribute that indicates the allocation is used to
 * hold an input tensor, which needs to be initialized.
 *
 * NOTE(review): the previous wording of this comment ended mid-sentence at
 * "initialized with"; what it is initialized with is not stated here.
 */
static constexpr const char* kInputTensorAllocate{"input_tensor"};

/*!
 * \brief Allocate node attribute that indicates the allocation is used to
 * hold an output tensor.
 */
static constexpr const char* kOutputTensorAllocate{"output_tensor"};
270 | |
271 | /*! |
272 | * \brief Calculate the size of the extents in bytes |
273 | * |
274 | * \param op the allocate node |
275 | */ |
276 | Integer CalculateExtentsSize(const AllocateNode* op); |
277 | |
278 | /*! |
279 | * \brief Calculate the size of the extents in bytes |
280 | * |
281 | * \param op the allocate const node |
282 | */ |
283 | Integer CalculateExtentsSize(const AllocateConstNode* op); |
284 | |
285 | /*! |
286 | * \brief Joins the Stmt nodes with PoolAllocation objects |
287 | * |
288 | * \param buffer_info_to_stmt the map of BufferInfo objects to Stmt nodes |
289 | * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects |
290 | */ |
291 | Map<Stmt, PoolAllocation> AssignStmtPoolAllocations( |
292 | const Map<BufferInfo, Stmt>& buffer_info_to_stmt, |
293 | const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation); |
294 | |
295 | /*! |
296 | * \brief Obtains I/O tensor names to their PoolAllocation objects |
297 | * |
298 | * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects |
299 | * |
300 | * This function will obtain pool allocations for I/O tensors if that had been planned |
301 | */ |
302 | Map<String, PoolAllocation> GetIOPoolAllocations( |
303 | const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation); |
304 | |
305 | } // namespace usmp |
306 | } // namespace tir |
307 | |
namespace attr {
/*!
 * \brief BaseFunc attribute that indicates which input vars represent a
 * PoolInfo object, in the form of a Map<Var, PoolInfo>.
 */
static constexpr const char* kPoolArgs{"pool_args"};

/*!
 * \brief IRModule attribute that contains the I/O tensor names mapped to
 * their pool allocations.
 */
static constexpr const char* kIOTensorPoolAllocations{"io_tensor_pool_allocations"};

}  // namespace attr
322 | |
323 | } // namespace tvm |
324 | |
325 | #endif // TVM_TIR_USMP_UTILS_H_ |
326 | |