/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file merge_dynamic_shared_memory_allocations.cc
 * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
 *  This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
 */
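
// A rough sketch of the transformation (simplified TIR pseudocode; the buffer names and sizes
// below are hypothetical). Before the pass, a kernel may contain several "shared.dyn" buffers:
//
//   A = allocate([1024], "float32", scope = "shared.dyn")   // 4096 bytes
//   B = allocate([512],  "float32", scope = "shared.dyn")   // 2048 bytes
//
// After the pass, a single uint8 buffer backs all of them, and the accesses to A and B are
// redirected into it at the offsets assigned by the planner (6144 bytes in total when the live
// ranges of A and B overlap; less when their storage can be reused):
//
//   buf_dyn_shmem = allocate([6144], "uint8", scope = "shared.dyn")
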
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <list>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using runtime::StorageRank;
using runtime::StorageScope;

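// Check whether the buffer var belongs to dynamic shared memory, i.e. its pointer storage scope
// is "shared" with the ".dyn" tag ("shared.dyn").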
bool IsDynamicSharedMemory(Var buffer_var) {
  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

/*!
 * \brief Collect the mapping from each dynamic shared memory buffer var to its Allocate.
 */
class AllocateCollector : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      dyn_shmem_allocs_[op->buffer_var.get()] = op;
    }
    StmtExprVisitor::VisitStmt_(op);
  }
  // The mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
};

// Find a linear pattern of storage access, used for liveness analysis.
// "Linear" means fitting a complex access pattern into an array of StmtEntry.
//
// Define "scope" as the body of a For/thread_launch/IfThenElse.
// A composite scope (loop/thread_launch/IfThenElse) is represented by the sequence
//   before_scope -> scope_body -> after_scope
// where before_scope and after_scope are StmtEntry markers around the entries of the body.
//
// This pass tries to detect the last point at which the memory needs to be kept
// alive, under the same scope as the Allocate.
// The storage needs to be kept alive between the Allocate and its last access.
// The free point is only inserted at the same scope as the Allocate.
//
class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
 public:
  /*! \brief record the touch list of a statement. */
  struct StmtEntry {
    // The statement
    const Object* stmt;
    // The index in linear_seq_ pointing to the other end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // If offset > 0, this is the begin entry and the end entry is at current_index + offset;
    // if offset < 0, this is the end entry and the begin entry is at current_index + offset.
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const VarNode*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // the level in the scope stack
    size_t level{0};
    // allocation stmt
    const AllocateNode* alloc{nullptr};
  };

  void VisitStmt_(const AllocateNode* op) final {
    size_t level = scope_.size();
    const VarNode* buf = op->buffer_var.get();
    alloc_info_[buf].alloc = op;
    alloc_info_[buf].level = level;
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    // Add write access.
    const VarNode* buf = op->buffer->data.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size());
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }

  void VisitStmt_(const EvaluateNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    // Add read access.
    StmtExprVisitor::VisitExpr_(op);
    const VarNode* buf = op->buffer->data.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
  }

  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::address_of())) {
      const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
      for (const auto& index : load->indices) {
        this->VisitExpr(index);
      }
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void VisitExpr_(const VarNode* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size());
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
  }
  template <typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    StmtExprVisitor::VisitStmt_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    ICHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to end index.
    ICHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }
  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }

  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The allocation info of each buffer
  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

 private:
  // Whether already in thread env.
  bool in_thread_env_{false};
  // The scope stack.
  std::vector<StmtEntry> scope_;
};

/*!
 * \brief Merge the buffers whose live ranges do not intersect and rewrite the body.
 */
class DynamicSharedMemoryRewriter : public StmtExprMutator {
 public:
  explicit DynamicSharedMemoryRewriter(
      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
      : dyn_shmem_allocs_{dyn_shmem_allocs} {}

  /*!
   * \brief Plan the memory reuse for all the buffers allocated in the statement.
   * \param stmt the statement
   */
  void PlanReuse(const Stmt& stmt) {
    DynSharedMemLinearAccessPatternFinder finder;
    finder(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_);
  }

 private:
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent && !allocated_) {
      // Allocate one dynamic shared memory allocation at the beginning of the thread scope.
      int max_layer_num = 0;
      std::vector<const StorageEntry*> all_entry;
      for (const auto& e : const_free_map_) {
        all_entry.push_back(e.second);
      }
      for (const StorageEntry* e : sym_free_list_) {
        all_entry.push_back(e);
      }
      for (const StorageEntry* e : all_entry) {
        max_layer_num = std::max(max_layer_num, static_cast<int>(e->allocs.size()));
      }
      // Calculate the alignment for each layer, taking the maximum over all storage entries.
      std::vector<int> align(max_layer_num, 0);
      for (const StorageEntry* e : all_entry) {
        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
          for (const VarNode* buffer : e->allocs[i]) {
            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
            align[i] = std::max(align[i], alloc->dtype.bytes());
          }
        }
      }
      // Calculate the offset of each buffer based on the alignment of each layer.
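      // Illustrative example (hypothetical sizes): suppose one storage entry has
      // allocs = {{A: 4096 bytes}, {C: 8192 bytes}} (C reuses A's bytes because their live
      // ranges do not intersect) and another entry has allocs = {{B: 2048 bytes}}.
      // Then both A and C get byte offset 0, the first entry occupies max(4096, 8192) = 8192
      // bytes, and B is placed after it at byte offset 8192, so merged_alloc_size_ = 10240.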
      for (const StorageEntry* e : all_entry) {
        PrimExpr max_inner_offset = 0;
        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
          PrimExpr inner_offset = 0;
          for (const VarNode* buffer : e->allocs[i]) {
            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
            buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
            inner_offset += alloc->extents[0] * alloc->dtype.bytes();
            inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
          }
          max_inner_offset = max(max_inner_offset, inner_offset);
        }
        merged_alloc_size_ += max_inner_offset;
      }

      allocated_ = true;
      Allocate new_body(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(),
                        StmtExprMutator::VisitStmt(op->body));
      return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
    }
    return StmtMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      return StmtExprMutator::VisitStmt(op->body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (IsDynamicSharedMemory(node->buffer->data)) {
      ICHECK_EQ(node->indices.size(), 1)
          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
          << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
      Array<PrimExpr> indices = {node->indices[0] +
                                 this->GetBufferOffset(node->buffer->data, node->buffer->dtype)};

      auto writer = node.CopyOnWrite();
      writer->buffer = GetUpdatedBuffer(node->buffer);
      writer->indices = indices;
    }

    return node;
  }

  Buffer GetUpdatedBuffer(Buffer buffer) {
    auto key = buffer.get();
    auto it = buffer_remap_.find(key);
    if (it != buffer_remap_.end()) {
      return it->second;
    }

    if (IsDynamicSharedMemory(buffer->data)) {
      ICHECK_EQ(buffer->shape.size(), 1)
          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
          << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
      auto writer = buffer.CopyOnWrite();
      writer->data = merged_buf_var_;
    }

    buffer_remap_[key] = buffer;
    return buffer;
  }

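  // Rewrite tvm_access_ptr(dtype, data, offset, extent, rw_mask) calls whose data var is a
  // dynamic shared memory buffer so that they point into the merged buffer, with the buffer's
  // element offset added to the original offset.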
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      ICHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      Var buffer = Downcast<Var>(op->args[1]);
      if (!IsDynamicSharedMemory(buffer)) {
        return StmtExprMutator::VisitExpr_(op);
      }
      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);

      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      return Call(op->dtype, op->op,
                  {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

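  // Convert the byte offset recorded for buffer_var into an element offset for the given dtype.
  // For example (illustrative numbers), a byte offset of 4096 with a 4-byte dtype becomes an
  // element offset of 1024.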
  PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
    auto it = buffer_byte_offsets_.find(buffer_var.get());
    ICHECK(it != buffer_byte_offsets_.end());
    return indexdiv(it->second, dtype.bytes());
  }

  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
  struct StorageEntry {
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // Allocs that share this entry.
    // The inner vector represents a "layer".
    // For example, if we need to allocate C in the memory of A and B:
    // |  A: 4096 bytes  |  B: 4096 bytes  |
    // |           C: 8192 bytes           |
    // then allocs = {{A, B}, {C}}
    std::vector<std::vector<const VarNode*>> allocs;
  };

  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
    std::vector<const VarNode*> gen;
    // variables we kill
    std::vector<const VarNode*> kill;
  };
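  // Illustrative example: if, within the linearized sequence, the first entry whose touched list
  // contains buffer A belongs to statement s1 and the last such entry belongs to statement s3,
  // then A is recorded in event_map_[s1].gen and event_map_[s3].kill, so A's storage becomes
  // reusable after s3.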

  /*!
   * \brief Liveness analysis to find the gen and kill point of each variable.
   * \param seq the linear pattern of storage access
   */
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill point, do a reverse linear scan.
    std::unordered_set<const VarNode*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen point, do a forward scan.
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }

  /*!
   * \brief Memory planning algorithm.
   * \param seq the linear pattern of storage access
   */
  void PlanMemory(const std::vector<StmtEntry>& seq) {
    std::unordered_set<const VarNode*> inplace_flag;

    for (size_t i = 0; i < seq.size(); ++i) {
      auto it = event_map_.find(seq[i].stmt);
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
      // In both cases, we need to handle the kill event correctly.
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode* var : it->second.kill) {
          this->Free(var);
        }
      }
      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly.
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        for (const VarNode* var : it->second.gen) {
          ICHECK(dyn_shmem_allocs_.count(var));
          const AllocateNode* alloc = dyn_shmem_allocs_[var];
          StorageEntry* dst_entry = FindAlloc(alloc);
          alloc_map_[var] = dst_entry;
        }
      }
    }
  }
  /*!
   * \brief Allocate a new storage entry.
   * \param op the allocate node
   * \param const_nbits the size of the allocation in bits
   * \return the new storage entry
   */
  StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) {
    ICHECK(op != nullptr);
    // Reuse was not successful; allocate a new buffer.
    StorageEntry* entry = arena_.make<StorageEntry>();
    entry->allocs.push_back({op->buffer_var.get()});
    entry->const_nbits = const_nbits;
    return entry;
  }
  /*!
   * \brief Find a storage entry in the free list that can be reused for the allocation.
   * \param op the allocate node
   * \return the storage entry
   */
  StorageEntry* FindAlloc(const AllocateNode* op) {
    ICHECK(op != nullptr);
    // match_range bounds the search window when looking for a free entry to reuse:
    // only free entries up to match_range times the requested size are considered.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
    // Disable reuse of small arrays; they will be lowered to registers in LLVM.
    // This rule only applies if we are using non-special memory.
    if (const_nbits > 0 && const_nbits <= 32) {
      return NewAlloc(op, const_nbits);
    }

    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(0);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // First look at free buffers that are at least as big as the required size.
      // If we find one, allocate directly in its location and remove its entry from the
      // free list.
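      // For example (illustrative numbers): with const_nbits = 1024 and match_range = 16,
      // [mid, end) covers free entries of 1024 to 16384 bits, while [begin, mid) holds the
      // smaller entries considered below.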
      for (auto it = mid; it != end; ++it) {
        StorageEntry* e = it->second;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // Then look at smaller buffers.
      // Keep collecting buffers until the sum of their sizes covers the size to allocate,
      // then remove all of the collected entries from the free list.
      std::vector<std::multimap<uint64_t, StorageEntry*>::iterator> delete_it;
      // the alloc list for the new entry
      std::vector<std::vector<const VarNode*>> reuse_allocs;
      uint64_t mem_ct = 0;
      for (auto it = mid; it != begin;) {
        --it;
        delete_it.push_back(it);
        mem_ct += it->second->const_nbits;
        int n = it->second->allocs.size();
        if (n > static_cast<int>(reuse_allocs.size())) {
          reuse_allocs.resize(n, {});
        }
        for (int i = 0; i < n; i++) {
          for (const VarNode* alloc : it->second->allocs[i]) {
            reuse_allocs[i].push_back(alloc);
          }
        }
        if (mem_ct >= const_nbits) {
          break;
        }
      }
      reuse_allocs.push_back({op->buffer_var.get()});
      if (mem_ct != 0) {
        StorageEntry* e = arena_.make<StorageEntry>();
        e->const_nbits = std::max(const_nbits, mem_ct);
        e->allocs = reuse_allocs;
        for (auto it : delete_it) {
          const_free_map_.erase(it);
        }
        return e;
      }
    } else {
      // If it is a symbolic allocation, arbitrarily pick one free entry to fit into,
      // since we do not know its actual size.
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, const_nbits);
  }

  /*!
   * \brief Return the storage entry of the given buffer var to the free list.
   * \param var the buffer var
   */
  void Free(const VarNode* var) {
    auto it = alloc_map_.find(var);
    ICHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    ICHECK_NE(e->allocs.size(), 0U);

    // disable reuse of small arrays
    if (e->const_nbits > 0 && e->const_nbits <= 32) return;

    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // The var for the merged buffer
  Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
  // The mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
  // The size in bytes of the merged buffer
  PrimExpr merged_alloc_size_{0};
  // The mapping from the original buffer var to its byte offset in the merged buffer
  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
  // The mapping from the original buffer objects to their rewritten counterparts backed by the
  // merged buffer.
  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
  // The flag indicating whether the merged buffer has been allocated
  bool allocated_{false};
  // The gen/kill events of each statement, produced by the liveness analysis.
  std::unordered_map<const Object*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocation assign map
  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
  /*! \brief allocator of all the StorageEntry */
  support::Arena arena_;
};

Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
  AllocateCollector collector;
  collector(stmt);
  if (collector.dyn_shmem_allocs_.size() > 1) {
    DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
    rewriter.PlanReuse(stmt);
    return rewriter(std::move(stmt));
  }
  return stmt;
}

namespace transform {

Pass MergeDynamicSharedMemoryAllocations() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
    .set_body_typed(MergeDynamicSharedMemoryAllocations);

}  // namespace transform
}  // namespace tir
}  // namespace tvm