1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc
 * \brief This pass converts the pool allocations to offsets from pools
23 */
24
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/transform.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <limits>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
36
37namespace tvm {
38namespace tir {
39namespace usmp {
40
41/*!
42 * \brief The StmtExpr mutator class to replace allocate nodes
43 * with offsets within memory pools
44 *
45 * This mutator class will add Pool variables recursively to every PrimFunc
46 * starting from the main PrimFunc. For all allocate nodes, that have been
47 * memory planned, will be mutated into an offset using a Let binding.
48 */
49class PoolAllocationToOffsetConverter : public StmtExprMutator {
50 public:
51 PoolAllocationToOffsetConverter(const IRModule& module,
52 const Map<tir::Stmt, PoolAllocation>& pool_allocations,
53 bool emit_tvmscript_printable = false)
54 : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
55 module_ = module->ShallowCopy();
56 for (const auto& kv : pool_allocations) {
57 size_t extent_size = -1;
58 if (kv.first->IsInstance<AllocateNode>()) {
59 Allocate allocate_node = Downcast<Allocate>(kv.first);
60 extent_size = CalculateExtentsSize(allocate_node.operator->()).IntValue();
61 } else if (kv.first->IsInstance<AllocateConstNode>()) {
62 AllocateConst allocate_const_node = Downcast<AllocateConst>(kv.first);
63 extent_size = CalculateExtentsSize(allocate_const_node.operator->()).IntValue();
64 } else {
65 ICHECK(false) << "Not supported node type " << kv.first->GetTypeKey();
66 }
67 PoolAllocation pool_allocation = kv.second;
68 PoolInfo pool_info = pool_allocation->pool_info;
69 int byte_pool_offset = pool_allocation->byte_offset->value;
70 int required_pool_size_for_allocation = byte_pool_offset + extent_size;
71 if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) {
72 all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
73 } else {
74 int prev_required_pool_size = all_pools_sizes_[pool_info];
75 if (prev_required_pool_size < required_pool_size_for_allocation) {
76 all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
77 }
78 }
79 }
80
81 for (const auto& kv : all_pools_sizes_) {
82 PoolInfo pi = kv.first;
83 int allocated_size = kv.second;
84 allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size));
85 }
86 std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
87 [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
88 if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) {
89 return true;
90 }
91 return false;
92 });
93 }
94 IRModule operator()();
95
96 private:
97 PrimExpr VisitExpr_(const CallNode* op) override;
98 Stmt VisitStmt_(const AllocateNode* op) override;
99 PrimExpr VisitExpr_(const VarNode* op) override;
100 PrimExpr VisitExpr_(const BufferLoadNode* op) override;
101 Stmt VisitStmt_(const BufferStoreNode* op) override;
102
103 Stmt VisitStmt_(const AllocateConstNode* op) override;
104 LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body);
105 /*! \brief This is a structure where the modified function
106 * signature is kept while body of the function is mutated
107 */
108 struct ScopeInfo {
109 Array<tir::Var> params;
110 Map<PoolInfo, tir::Var> pools_to_params;
111 Array<AllocatedPoolInfo> allocated_pool_params;
112 Map<tir::Var, Buffer> buffer_map;
113 };
114
115 /*! \brief The function scope information that are needed
116 * in the mutation of the function need to be stacked and
117 * popped when each function is entered/exited in the
118 * mutation process.
119 */
120 std::stack<ScopeInfo> scope_stack;
121 /*! \brief Each PrimFunc signature needs to be updated
122 * with pool variables. This is a helper function to
123 * capture the updated information to ScopeInfo object.
124 */
125 ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func);
126 /*! \brief This is a helper to create the PrimFunc with
127 * pool variables that calls the UpdateFunctionScopeInfo
128 * inside of it.
129 */
130 PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc);
131 /*! \brief This is a helper to append the pool args to
132 * the callsite of the function.
133 */
134 Array<PrimExpr> AppendPoolParamsToArgs(Array<PrimExpr> args, bool has_device_context);
135 /*! \brief Some arguments that used to be Allocate nodes
136 * should be replaced by Let nodes in the pass that loads
137 * the space from a pool variable.
138 */
139 Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args);
140 /*! \brief Obtain a resource handle if its there
141 */
142 Optional<Var> GetResourceHandle(const PrimFunc& func);
143 /*! \brief Get the Buffer object representing the mapped access into
144 * the pool.
145 */
146 Buffer GetRemappedBuffer(Buffer buf);
147
148 /*! \brief The tir::Var map to PoolInfo objects */
149 Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_;
150 /*! \brief The buffer var map to their allocate nodes */
151 Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
152 /*! \brief The IRModule being constructed/mutated */
153 IRModule module_;
154 /*! \brief The input allocate node to PoolAllocation map */
155 Map<tir::Stmt, PoolAllocation> pool_allocations_;
156 /*! \brief The set of ordered pools to ensure an unique order of args for functions */
157 std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
158 /*! \brief The storage of calculated pool size at init */
159 std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
160 /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded
161 * to position from a pool as designated by a PoolAllocation
162 */
163 Map<tir::Var, tir::Var> allocate_var_to_let_var_;
164 /*! \brief A map from the original buffer object
165 *
166 * Each key-value pair in this map satisfies
167 * ``allocate_buf_to_let_var[key->data] = value->data``. However,
168 * since more than one `tir::Buffer` may use the same Var, they must
169 * be tracked separately.
170 */
171 Map<tir::Buffer, tir::Buffer> original_buf_to_let_buf_;
172
173 Map<String, Bool> signature_has_device_context_;
174 /*! \brief A counter to give references to pools a reproducible unique set of names */
175 int pool_var_count_ = 0;
176 /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */
177 bool emit_tvmscript_printable_ = false;
178 /*! \brief A counter to give references to pools a reproducible unique set of names */
179 std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
180
181 Map<PoolInfo, Array<ConstantInfo>> pool_initializations_;
182 void AppdendConstInitializationData(ScopeInfo si);
183};
184
185Optional<Var> PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) {
186 if (!func->params.empty() &&
187 func->buffer_map.find(func->params.back()) == func->buffer_map.end()) {
188 return func->params.back();
189 }
190 return Optional<Var>();
191}
192
193PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo(
194 const PrimFunc& original_func) {
195 ScopeInfo si;
196
197 Optional<Var> resource_handle = GetResourceHandle(original_func);
198 si.params = original_func->params;
199 if (resource_handle) {
200 si.params.pop_back();
201 ICHECK(si.params.size() == original_func->params.size() - 1);
202 }
203 si.buffer_map = original_func->buffer_map;
204 Map<tir::Var, PoolInfo> ret;
205 for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
206 PoolInfo pool_info = allocated_pool_info->pool_info;
207 String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++);
208 String var_name = pool_ref_name + "_var";
209 DataType elem_dtype = DataType::UInt(8);
210 Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
211 Var pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
212 si.params.push_back(pool_var);
213 si.pools_to_params.Set(pool_info, pool_var);
214 si.allocated_pool_params.push_back(AllocatedPoolInfo(
215 allocated_pool_info->pool_info, allocated_pool_info->allocated_size, si.params.size() - 1));
216
217 int pool_size = all_pools_sizes_[pool_info];
218 String buffer_var_name = pool_ref_name + "_buffer_var";
219 si.buffer_map.Set(pool_var,
220 Buffer(buffer_var /* data */, elem_dtype /* dtype */, {pool_size} /* shape */,
221 {1} /* strides */, 0 /* elem_offset */, buffer_var_name /* name */,
222 16 /* data_alignment */, 1 /* offset_factor */,
223 BufferType::kDefault /* buffer-type */));
224 }
225 if (resource_handle) {
226 si.params.push_back(resource_handle.value());
227 }
228 return si;
229}
230
231PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
232 const PrimFunc& original_primfunc) {
233 // Only create the new function if it was not modified with pool params
234 if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) {
235 ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
236 this->scope_stack.push(si);
237 Stmt new_body = this->VisitStmt(original_primfunc->body);
238 this->scope_stack.pop();
239 DictAttrs original_attrs = original_primfunc->attrs;
240 // We dont need attrs of PrimFunc that might include non printable attrs such as target
241 // for unit tests where emit_tvmscript_printable_ is to be used.
242 if (emit_tvmscript_printable_) {
243 original_attrs = DictAttrs();
244 }
245 PrimFunc ret =
246 PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
247 if (!emit_tvmscript_printable_) {
248 ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
249 }
250 visited_primfuncs.insert(ret);
251 return ret;
252 }
253 return original_primfunc;
254}
255
256Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(Array<PrimExpr> args,
257 bool has_device_context) {
258 Array<PrimExpr> new_args;
259 PrimExpr resource_handle_arg;
260 // name, params...params[, context]
261 if (has_device_context) {
262 resource_handle_arg = args.back();
263 args.pop_back();
264 }
265 for (const auto& arg : args) {
266 new_args.push_back(VisitExpr(arg));
267 }
268 ScopeInfo top_scope = this->scope_stack.top();
269 for (const auto& pools_vars : top_scope.pools_to_params) {
270 tir::Var pool_var = pools_vars.second;
271 Buffer buffer_var = top_scope.buffer_map[pool_var];
272 new_args.push_back(buffer_var->data);
273 }
274 if (resource_handle_arg.defined()) {
275 new_args.push_back(resource_handle_arg);
276 }
277 return new_args;
278}
279
280Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
281 const Array<PrimExpr>& args) {
282 Array<PrimExpr> ret;
283 for (const PrimExpr& arg : args) {
284 if (arg->IsInstance<VarNode>() &&
285 allocate_var_to_let_var_.find(Downcast<Var>(arg)) != allocate_var_to_let_var_.end()) {
286 ret.push_back(allocate_var_to_let_var_[Downcast<Var>(arg)]);
287 } else {
288 ret.push_back(VisitExpr(arg));
289 }
290 }
291 return ret;
292}
293
294PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
295 if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
296 String func_name = Downcast<StringImm>(op->args[0])->value;
297 Array<PrimExpr> new_args;
298 if (module_->ContainGlobalVar(func_name) &&
299 module_->Lookup(func_name)->IsInstance<PrimFuncNode>()) {
300 GlobalVar gv = module_->GetGlobalVar(func_name);
301 PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
302
303 if (!signature_has_device_context_.count(func_name)) {
304 if (op->args.size() == func->params.size() + 2) {
305 signature_has_device_context_.Set(func_name, Bool(true));
306 } else {
307 signature_has_device_context_.Set(func_name, Bool(false));
308 }
309 }
310
311 PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
312 module_->Update(gv, prim_func);
313 new_args = AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]);
314 new_args = ReplaceAllocateArgsWithLetArgs(new_args);
315 } else {
316 new_args = ReplaceAllocateArgsWithLetArgs(op->args);
317 }
318 return Call(op->dtype, op->op, new_args);
319 }
320 if (op->op->IsInstance<PrimFuncNode>()) {
321 String func_name = Downcast<StringImm>(op->args[0])->value;
322 PrimFunc func = Downcast<PrimFunc>(op->op);
323 PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
324 Array<PrimExpr> new_args =
325 AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]);
326 new_args = ReplaceAllocateArgsWithLetArgs(new_args);
327 return Call(op->dtype, prim_func, new_args);
328 }
329 return StmtExprMutator::VisitExpr_(op);
330}
331
332LetStmt PoolAllocationToOffsetConverter::ToLetStmt(const PoolAllocation& pool_allocation,
333 const Var& buffer_var, const Stmt& body) {
334 ScopeInfo scope_info = scope_stack.top();
335 Var param = scope_info.pools_to_params[pool_allocation->pool_info];
336 BufferLoad load_node = BufferLoad(scope_info.buffer_map[param], {pool_allocation->byte_offset});
337 Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node});
338
339 Type let_var_type = buffer_var->type_annotation;
340 if (emit_tvmscript_printable_) {
341 // Strip the storage_scope from the variable type, as TVMScript
342 // doesn't parsethe scoped pointers (e.g. ``T.Ptr[global T.int32]``)
343 // correctly.
344 let_var_type = PointerType(Downcast<PointerType>(let_var_type)->element_type);
345 }
346 Var let_var(buffer_var->name_hint + "_let", let_var_type);
347 allocate_var_to_let_var_.Set(buffer_var, let_var);
348 Stmt new_body = VisitStmt(body);
349 allocate_var_to_let_var_.erase(buffer_var);
350 return LetStmt(let_var, address_of_load, new_body);
351}
352
353Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
354 if (pool_allocations_.count(GetRef<Allocate>(op))) {
355 return ToLetStmt(pool_allocations_[GetRef<Stmt>(op)], op->buffer_var, op->body);
356 }
357 return StmtExprMutator::VisitStmt_(op);
358}
359
360Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateConstNode* op) {
361 if (pool_allocations_.count(GetRef<AllocateConst>(op))) {
362 const auto& result = ToLetStmt(pool_allocations_[GetRef<Stmt>(op)], op->buffer_var, op->body);
363
364 PoolInfo pool_info = pool_allocations_[GetRef<Stmt>(op)]->pool_info;
365 if (pool_initializations_.find(pool_info) == pool_initializations_.end()) {
366 pool_initializations_.Set(pool_info, {});
367 }
368
369 auto consts = pool_initializations_[pool_info];
370 consts.push_back({result->var->name_hint, pool_allocations_[GetRef<Stmt>(op)]->byte_offset,
371 op->data.value()});
372
373 pool_initializations_.Set(pool_info, consts);
374 return result;
375 }
376 return StmtExprMutator::VisitStmt_(op);
377}
378
379Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) {
380 BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
381
382 Buffer remapped = GetRemappedBuffer(store->buffer);
383 if (!op->buffer.same_as(remapped)) {
384 store.CopyOnWrite()->buffer = remapped;
385 }
386 return std::move(store);
387}
388
389PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) {
390 BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
391
392 Buffer remapped = GetRemappedBuffer(load->buffer);
393 if (!op->buffer.same_as(remapped)) {
394 load.CopyOnWrite()->buffer = remapped;
395 }
396 return std::move(load);
397}
398
399PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const VarNode* op) {
400 auto it = allocate_var_to_let_var_.find(GetRef<Var>(op));
401 if (it != allocate_var_to_let_var_.end()) {
402 return (*it).second;
403 }
404
405 return StmtExprMutator::VisitExpr_(op);
406}
407
408Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) {
409 {
410 auto it = original_buf_to_let_buf_.find(original);
411 if (it != original_buf_to_let_buf_.end()) {
412 return (*it).second;
413 }
414 }
415
416 Buffer remapped = original;
417
418 auto it = allocate_var_to_let_var_.find(original->data);
419 if (it != allocate_var_to_let_var_.end()) {
420 remapped = Buffer((*it).second, original->dtype, original->shape, original->strides,
421 original->elem_offset, original->name, original->data_alignment,
422 original->offset_factor, original->buffer_type, original->axis_separators,
423 original->span);
424 }
425
426 original_buf_to_let_buf_.Set(original, remapped);
427 return remapped;
428}
429
430void PoolAllocationToOffsetConverter::AppdendConstInitializationData(
431 PoolAllocationToOffsetConverter::ScopeInfo si) {
432 for (AllocatedPoolInfo api : si.allocated_pool_params) {
433 const auto& it = pool_initializations_.find(api->pool_info);
434 if (it != pool_initializations_.end()) {
435 auto* pi = const_cast<ConstantPoolInfoNode*>(api->pool_info.as<ConstantPoolInfoNode>());
436 pi->constant_info_array = (*it).second;
437 }
438 }
439}
440
441IRModule PoolAllocationToOffsetConverter::operator()() {
442 GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
443 PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv));
444 ScopeInfo si = UpdateFunctionScopeInfo(main_func);
445 this->scope_stack.push(si);
446 Stmt main_func_body = this->VisitStmt(main_func->body);
447 this->scope_stack.pop();
448 AppdendConstInitializationData(si);
449 // We dont need attrs of PrimFunc that might include non printable attrs such as target
450 // for unit tests where emit_tvmscript_printable_ is to be used.
451 if (!emit_tvmscript_printable_) {
452 main_func =
453 PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs);
454 main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params);
455 } else {
456 main_func =
457 PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs());
458 }
459 module_->Update(gv, main_func);
460 if (!emit_tvmscript_printable_) {
461 return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params);
462 }
463 return this->module_;
464}
465
466namespace transform {
467
468tvm::transform::Pass ConvertPoolAllocationsToOffsets(
469 const Map<tir::Stmt, PoolAllocation>& pool_allocations, Bool emit_tvmscript_printable) {
470 auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
471 return Downcast<IRModule>(PoolAllocationToOffsetConverter(
472 m, pool_allocations, emit_tvmscript_printable->value != 0)());
473 };
474 return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets",
475 {});
476}
477
478TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets")
479 .set_body_typed(ConvertPoolAllocationsToOffsets);
480
481} // namespace transform
482
483} // namespace usmp
484} // namespace tir
485} // namespace tvm
486