1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc |
 * \brief This pass converts memory-planned pool allocations into offsets within memory pools
23 | */ |
24 | |
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/transform.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <limits>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
36 | |
37 | namespace tvm { |
38 | namespace tir { |
39 | namespace usmp { |
40 | |
41 | /*! |
42 | * \brief The StmtExpr mutator class to replace allocate nodes |
43 | * with offsets within memory pools |
44 | * |
45 | * This mutator class will add Pool variables recursively to every PrimFunc |
46 | * starting from the main PrimFunc. For all allocate nodes, that have been |
47 | * memory planned, will be mutated into an offset using a Let binding. |
48 | */ |
49 | class PoolAllocationToOffsetConverter : public StmtExprMutator { |
50 | public: |
51 | PoolAllocationToOffsetConverter(const IRModule& module, |
52 | const Map<tir::Stmt, PoolAllocation>& pool_allocations, |
53 | bool emit_tvmscript_printable = false) |
54 | : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) { |
55 | module_ = module->ShallowCopy(); |
56 | for (const auto& kv : pool_allocations) { |
57 | size_t extent_size = -1; |
58 | if (kv.first->IsInstance<AllocateNode>()) { |
59 | Allocate allocate_node = Downcast<Allocate>(kv.first); |
60 | extent_size = CalculateExtentsSize(allocate_node.operator->()).IntValue(); |
61 | } else if (kv.first->IsInstance<AllocateConstNode>()) { |
62 | AllocateConst allocate_const_node = Downcast<AllocateConst>(kv.first); |
63 | extent_size = CalculateExtentsSize(allocate_const_node.operator->()).IntValue(); |
64 | } else { |
65 | ICHECK(false) << "Not supported node type " << kv.first->GetTypeKey(); |
66 | } |
67 | PoolAllocation pool_allocation = kv.second; |
68 | PoolInfo pool_info = pool_allocation->pool_info; |
69 | int byte_pool_offset = pool_allocation->byte_offset->value; |
70 | int required_pool_size_for_allocation = byte_pool_offset + extent_size; |
71 | if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { |
72 | all_pools_sizes_[pool_info] = required_pool_size_for_allocation; |
73 | } else { |
74 | int prev_required_pool_size = all_pools_sizes_[pool_info]; |
75 | if (prev_required_pool_size < required_pool_size_for_allocation) { |
76 | all_pools_sizes_[pool_info] = required_pool_size_for_allocation; |
77 | } |
78 | } |
79 | } |
80 | |
81 | for (const auto& kv : all_pools_sizes_) { |
82 | PoolInfo pi = kv.first; |
83 | int allocated_size = kv.second; |
84 | allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size)); |
85 | } |
86 | std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(), |
87 | [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) { |
88 | if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) { |
89 | return true; |
90 | } |
91 | return false; |
92 | }); |
93 | } |
94 | IRModule operator()(); |
95 | |
96 | private: |
97 | PrimExpr VisitExpr_(const CallNode* op) override; |
98 | Stmt VisitStmt_(const AllocateNode* op) override; |
99 | PrimExpr VisitExpr_(const VarNode* op) override; |
100 | PrimExpr VisitExpr_(const BufferLoadNode* op) override; |
101 | Stmt VisitStmt_(const BufferStoreNode* op) override; |
102 | |
103 | Stmt VisitStmt_(const AllocateConstNode* op) override; |
104 | LetStmt ToLetStmt(const PoolAllocation& pool_allocation, const Var& buffer_var, const Stmt& body); |
105 | /*! \brief This is a structure where the modified function |
106 | * signature is kept while body of the function is mutated |
107 | */ |
108 | struct ScopeInfo { |
109 | Array<tir::Var> params; |
110 | Map<PoolInfo, tir::Var> pools_to_params; |
111 | Array<AllocatedPoolInfo> allocated_pool_params; |
112 | Map<tir::Var, Buffer> buffer_map; |
113 | }; |
114 | |
115 | /*! \brief The function scope information that are needed |
116 | * in the mutation of the function need to be stacked and |
117 | * popped when each function is entered/exited in the |
118 | * mutation process. |
119 | */ |
120 | std::stack<ScopeInfo> scope_stack; |
121 | /*! \brief Each PrimFunc signature needs to be updated |
122 | * with pool variables. This is a helper function to |
123 | * capture the updated information to ScopeInfo object. |
124 | */ |
125 | ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func); |
126 | /*! \brief This is a helper to create the PrimFunc with |
127 | * pool variables that calls the UpdateFunctionScopeInfo |
128 | * inside of it. |
129 | */ |
130 | PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc); |
131 | /*! \brief This is a helper to append the pool args to |
132 | * the callsite of the function. |
133 | */ |
134 | Array<PrimExpr> AppendPoolParamsToArgs(Array<PrimExpr> args, bool has_device_context); |
135 | /*! \brief Some arguments that used to be Allocate nodes |
136 | * should be replaced by Let nodes in the pass that loads |
137 | * the space from a pool variable. |
138 | */ |
139 | Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args); |
140 | /*! \brief Obtain a resource handle if its there |
141 | */ |
142 | Optional<Var> GetResourceHandle(const PrimFunc& func); |
143 | /*! \brief Get the Buffer object representing the mapped access into |
144 | * the pool. |
145 | */ |
146 | Buffer GetRemappedBuffer(Buffer buf); |
147 | |
148 | /*! \brief The tir::Var map to PoolInfo objects */ |
149 | Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_; |
150 | /*! \brief The buffer var map to their allocate nodes */ |
151 | Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_; |
152 | /*! \brief The IRModule being constructed/mutated */ |
153 | IRModule module_; |
154 | /*! \brief The input allocate node to PoolAllocation map */ |
155 | Map<tir::Stmt, PoolAllocation> pool_allocations_; |
156 | /*! \brief The set of ordered pools to ensure an unique order of args for functions */ |
157 | std::vector<AllocatedPoolInfo> allocated_pool_ordering_; |
158 | /*! \brief The storage of calculated pool size at init */ |
159 | std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_; |
160 | /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded |
161 | * to position from a pool as designated by a PoolAllocation |
162 | */ |
163 | Map<tir::Var, tir::Var> allocate_var_to_let_var_; |
164 | /*! \brief A map from the original buffer object |
165 | * |
166 | * Each key-value pair in this map satisfies |
167 | * ``allocate_buf_to_let_var[key->data] = value->data``. However, |
168 | * since more than one `tir::Buffer` may use the same Var, they must |
169 | * be tracked separately. |
170 | */ |
171 | Map<tir::Buffer, tir::Buffer> original_buf_to_let_buf_; |
172 | |
173 | Map<String, Bool> signature_has_device_context_; |
174 | /*! \brief A counter to give references to pools a reproducible unique set of names */ |
175 | int pool_var_count_ = 0; |
176 | /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */ |
177 | bool emit_tvmscript_printable_ = false; |
178 | /*! \brief A counter to give references to pools a reproducible unique set of names */ |
179 | std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs; |
180 | |
181 | Map<PoolInfo, Array<ConstantInfo>> pool_initializations_; |
182 | void AppdendConstInitializationData(ScopeInfo si); |
183 | }; |
184 | |
185 | Optional<Var> PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) { |
186 | if (!func->params.empty() && |
187 | func->buffer_map.find(func->params.back()) == func->buffer_map.end()) { |
188 | return func->params.back(); |
189 | } |
190 | return Optional<Var>(); |
191 | } |
192 | |
193 | PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( |
194 | const PrimFunc& original_func) { |
195 | ScopeInfo si; |
196 | |
197 | Optional<Var> resource_handle = GetResourceHandle(original_func); |
198 | si.params = original_func->params; |
199 | if (resource_handle) { |
200 | si.params.pop_back(); |
201 | ICHECK(si.params.size() == original_func->params.size() - 1); |
202 | } |
203 | si.buffer_map = original_func->buffer_map; |
204 | Map<tir::Var, PoolInfo> ret; |
205 | for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) { |
206 | PoolInfo pool_info = allocated_pool_info->pool_info; |
207 | String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++); |
208 | String var_name = pool_ref_name + "_var" ; |
209 | DataType elem_dtype = DataType::UInt(8); |
210 | Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global" )); |
211 | Var pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global" )); |
212 | si.params.push_back(pool_var); |
213 | si.pools_to_params.Set(pool_info, pool_var); |
214 | si.allocated_pool_params.push_back(AllocatedPoolInfo( |
215 | allocated_pool_info->pool_info, allocated_pool_info->allocated_size, si.params.size() - 1)); |
216 | |
217 | int pool_size = all_pools_sizes_[pool_info]; |
218 | String buffer_var_name = pool_ref_name + "_buffer_var" ; |
219 | si.buffer_map.Set(pool_var, |
220 | Buffer(buffer_var /* data */, elem_dtype /* dtype */, {pool_size} /* shape */, |
221 | {1} /* strides */, 0 /* elem_offset */, buffer_var_name /* name */, |
222 | 16 /* data_alignment */, 1 /* offset_factor */, |
223 | BufferType::kDefault /* buffer-type */)); |
224 | } |
225 | if (resource_handle) { |
226 | si.params.push_back(resource_handle.value()); |
227 | } |
228 | return si; |
229 | } |
230 | |
231 | PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( |
232 | const PrimFunc& original_primfunc) { |
233 | // Only create the new function if it was not modified with pool params |
234 | if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) { |
235 | ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc); |
236 | this->scope_stack.push(si); |
237 | Stmt new_body = this->VisitStmt(original_primfunc->body); |
238 | this->scope_stack.pop(); |
239 | DictAttrs original_attrs = original_primfunc->attrs; |
240 | // We dont need attrs of PrimFunc that might include non printable attrs such as target |
241 | // for unit tests where emit_tvmscript_printable_ is to be used. |
242 | if (emit_tvmscript_printable_) { |
243 | original_attrs = DictAttrs(); |
244 | } |
245 | PrimFunc ret = |
246 | PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); |
247 | if (!emit_tvmscript_printable_) { |
248 | ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); |
249 | } |
250 | visited_primfuncs.insert(ret); |
251 | return ret; |
252 | } |
253 | return original_primfunc; |
254 | } |
255 | |
256 | Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(Array<PrimExpr> args, |
257 | bool has_device_context) { |
258 | Array<PrimExpr> new_args; |
259 | PrimExpr resource_handle_arg; |
260 | // name, params...params[, context] |
261 | if (has_device_context) { |
262 | resource_handle_arg = args.back(); |
263 | args.pop_back(); |
264 | } |
265 | for (const auto& arg : args) { |
266 | new_args.push_back(VisitExpr(arg)); |
267 | } |
268 | ScopeInfo top_scope = this->scope_stack.top(); |
269 | for (const auto& pools_vars : top_scope.pools_to_params) { |
270 | tir::Var pool_var = pools_vars.second; |
271 | Buffer buffer_var = top_scope.buffer_map[pool_var]; |
272 | new_args.push_back(buffer_var->data); |
273 | } |
274 | if (resource_handle_arg.defined()) { |
275 | new_args.push_back(resource_handle_arg); |
276 | } |
277 | return new_args; |
278 | } |
279 | |
280 | Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( |
281 | const Array<PrimExpr>& args) { |
282 | Array<PrimExpr> ret; |
283 | for (const PrimExpr& arg : args) { |
284 | if (arg->IsInstance<VarNode>() && |
285 | allocate_var_to_let_var_.find(Downcast<Var>(arg)) != allocate_var_to_let_var_.end()) { |
286 | ret.push_back(allocate_var_to_let_var_[Downcast<Var>(arg)]); |
287 | } else { |
288 | ret.push_back(VisitExpr(arg)); |
289 | } |
290 | } |
291 | return ret; |
292 | } |
293 | |
294 | PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { |
295 | if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { |
296 | String func_name = Downcast<StringImm>(op->args[0])->value; |
297 | Array<PrimExpr> new_args; |
298 | if (module_->ContainGlobalVar(func_name) && |
299 | module_->Lookup(func_name)->IsInstance<PrimFuncNode>()) { |
300 | GlobalVar gv = module_->GetGlobalVar(func_name); |
301 | PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv)); |
302 | |
303 | if (!signature_has_device_context_.count(func_name)) { |
304 | if (op->args.size() == func->params.size() + 2) { |
305 | signature_has_device_context_.Set(func_name, Bool(true)); |
306 | } else { |
307 | signature_has_device_context_.Set(func_name, Bool(false)); |
308 | } |
309 | } |
310 | |
311 | PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); |
312 | module_->Update(gv, prim_func); |
313 | new_args = AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]); |
314 | new_args = ReplaceAllocateArgsWithLetArgs(new_args); |
315 | } else { |
316 | new_args = ReplaceAllocateArgsWithLetArgs(op->args); |
317 | } |
318 | return Call(op->dtype, op->op, new_args); |
319 | } |
320 | if (op->op->IsInstance<PrimFuncNode>()) { |
321 | String func_name = Downcast<StringImm>(op->args[0])->value; |
322 | PrimFunc func = Downcast<PrimFunc>(op->op); |
323 | PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); |
324 | Array<PrimExpr> new_args = |
325 | AppendPoolParamsToArgs(op->args, signature_has_device_context_[func_name]); |
326 | new_args = ReplaceAllocateArgsWithLetArgs(new_args); |
327 | return Call(op->dtype, prim_func, new_args); |
328 | } |
329 | return StmtExprMutator::VisitExpr_(op); |
330 | } |
331 | |
332 | LetStmt PoolAllocationToOffsetConverter::ToLetStmt(const PoolAllocation& pool_allocation, |
333 | const Var& buffer_var, const Stmt& body) { |
334 | ScopeInfo scope_info = scope_stack.top(); |
335 | Var param = scope_info.pools_to_params[pool_allocation->pool_info]; |
336 | BufferLoad load_node = BufferLoad(scope_info.buffer_map[param], {pool_allocation->byte_offset}); |
337 | Call address_of_load = Call(DataType::Handle(), builtin::address_of(), {load_node}); |
338 | |
339 | Type let_var_type = buffer_var->type_annotation; |
340 | if (emit_tvmscript_printable_) { |
341 | // Strip the storage_scope from the variable type, as TVMScript |
342 | // doesn't parsethe scoped pointers (e.g. ``T.Ptr[global T.int32]``) |
343 | // correctly. |
344 | let_var_type = PointerType(Downcast<PointerType>(let_var_type)->element_type); |
345 | } |
346 | Var let_var(buffer_var->name_hint + "_let" , let_var_type); |
347 | allocate_var_to_let_var_.Set(buffer_var, let_var); |
348 | Stmt new_body = VisitStmt(body); |
349 | allocate_var_to_let_var_.erase(buffer_var); |
350 | return LetStmt(let_var, address_of_load, new_body); |
351 | } |
352 | |
353 | Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { |
354 | if (pool_allocations_.count(GetRef<Allocate>(op))) { |
355 | return ToLetStmt(pool_allocations_[GetRef<Stmt>(op)], op->buffer_var, op->body); |
356 | } |
357 | return StmtExprMutator::VisitStmt_(op); |
358 | } |
359 | |
360 | Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateConstNode* op) { |
361 | if (pool_allocations_.count(GetRef<AllocateConst>(op))) { |
362 | const auto& result = ToLetStmt(pool_allocations_[GetRef<Stmt>(op)], op->buffer_var, op->body); |
363 | |
364 | PoolInfo pool_info = pool_allocations_[GetRef<Stmt>(op)]->pool_info; |
365 | if (pool_initializations_.find(pool_info) == pool_initializations_.end()) { |
366 | pool_initializations_.Set(pool_info, {}); |
367 | } |
368 | |
369 | auto consts = pool_initializations_[pool_info]; |
370 | consts.push_back({result->var->name_hint, pool_allocations_[GetRef<Stmt>(op)]->byte_offset, |
371 | op->data.value()}); |
372 | |
373 | pool_initializations_.Set(pool_info, consts); |
374 | return result; |
375 | } |
376 | return StmtExprMutator::VisitStmt_(op); |
377 | } |
378 | |
379 | Stmt PoolAllocationToOffsetConverter::VisitStmt_(const BufferStoreNode* op) { |
380 | BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
381 | |
382 | Buffer remapped = GetRemappedBuffer(store->buffer); |
383 | if (!op->buffer.same_as(remapped)) { |
384 | store.CopyOnWrite()->buffer = remapped; |
385 | } |
386 | return std::move(store); |
387 | } |
388 | |
389 | PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) { |
390 | BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
391 | |
392 | Buffer remapped = GetRemappedBuffer(load->buffer); |
393 | if (!op->buffer.same_as(remapped)) { |
394 | load.CopyOnWrite()->buffer = remapped; |
395 | } |
396 | return std::move(load); |
397 | } |
398 | |
399 | PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const VarNode* op) { |
400 | auto it = allocate_var_to_let_var_.find(GetRef<Var>(op)); |
401 | if (it != allocate_var_to_let_var_.end()) { |
402 | return (*it).second; |
403 | } |
404 | |
405 | return StmtExprMutator::VisitExpr_(op); |
406 | } |
407 | |
408 | Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) { |
409 | { |
410 | auto it = original_buf_to_let_buf_.find(original); |
411 | if (it != original_buf_to_let_buf_.end()) { |
412 | return (*it).second; |
413 | } |
414 | } |
415 | |
416 | Buffer remapped = original; |
417 | |
418 | auto it = allocate_var_to_let_var_.find(original->data); |
419 | if (it != allocate_var_to_let_var_.end()) { |
420 | remapped = Buffer((*it).second, original->dtype, original->shape, original->strides, |
421 | original->elem_offset, original->name, original->data_alignment, |
422 | original->offset_factor, original->buffer_type, original->axis_separators, |
423 | original->span); |
424 | } |
425 | |
426 | original_buf_to_let_buf_.Set(original, remapped); |
427 | return remapped; |
428 | } |
429 | |
430 | void PoolAllocationToOffsetConverter::AppdendConstInitializationData( |
431 | PoolAllocationToOffsetConverter::ScopeInfo si) { |
432 | for (AllocatedPoolInfo api : si.allocated_pool_params) { |
433 | const auto& it = pool_initializations_.find(api->pool_info); |
434 | if (it != pool_initializations_.end()) { |
435 | auto* pi = const_cast<ConstantPoolInfoNode*>(api->pool_info.as<ConstantPoolInfoNode>()); |
436 | pi->constant_info_array = (*it).second; |
437 | } |
438 | } |
439 | } |
440 | |
441 | IRModule PoolAllocationToOffsetConverter::operator()() { |
442 | GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main); |
443 | PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv)); |
444 | ScopeInfo si = UpdateFunctionScopeInfo(main_func); |
445 | this->scope_stack.push(si); |
446 | Stmt main_func_body = this->VisitStmt(main_func->body); |
447 | this->scope_stack.pop(); |
448 | AppdendConstInitializationData(si); |
449 | // We dont need attrs of PrimFunc that might include non printable attrs such as target |
450 | // for unit tests where emit_tvmscript_printable_ is to be used. |
451 | if (!emit_tvmscript_printable_) { |
452 | main_func = |
453 | PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); |
454 | main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); |
455 | } else { |
456 | main_func = |
457 | PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); |
458 | } |
459 | module_->Update(gv, main_func); |
460 | if (!emit_tvmscript_printable_) { |
461 | return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params); |
462 | } |
463 | return this->module_; |
464 | } |
465 | |
466 | namespace transform { |
467 | |
468 | tvm::transform::Pass ConvertPoolAllocationsToOffsets( |
469 | const Map<tir::Stmt, PoolAllocation>& pool_allocations, Bool emit_tvmscript_printable) { |
470 | auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { |
471 | return Downcast<IRModule>(PoolAllocationToOffsetConverter( |
472 | m, pool_allocations, emit_tvmscript_printable->value != 0)()); |
473 | }; |
474 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets" , |
475 | {}); |
476 | } |
477 | |
478 | TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets" ) |
479 | .set_body_typed(ConvertPoolAllocationsToOffsets); |
480 | |
481 | } // namespace transform |
482 | |
483 | } // namespace usmp |
484 | } // namespace tir |
485 | } // namespace tvm |
486 | |