/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file merge_dynamic_shared_memory_allocations.cc
 * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
 * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <list>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using runtime::StorageRank;
using runtime::StorageScope;

bool IsDynamicSharedMemory(Var buffer_var) {
  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
  return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
}

/*!
 * \brief Collect the mapping from each dynamic shared memory buffer var to its Allocate node.
 */
class AllocateCollector : public StmtExprVisitor {
 public:
  void VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      dyn_shmem_allocs_[op->buffer_var.get()] = op;
    }
    StmtExprVisitor::VisitStmt_(op);
  }
  // The mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
};
// Find a linear pattern of storage access
// Used for liveness analysis.
// "linear" means fitting a complex access pattern into an array of StmtEntry
//
// Define "scope" as the body of For/thread_launch/IfThenElse
// Composite scopes (loop/thread_launch/IfThenElse) are represented by three StmtEntry records:
// before_scope -> scope_body -> after_scope
//
// This pass tries to detect the last point where we need to keep the memory
// alive under the same scope as the Allocate.
// The storage needs to be kept alive between the Allocate and the last access.
// The free point is only inserted at the same scope as the Allocate.
//
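// For example, assuming A is a dynamic shared buffer whose Allocate sits inside the
// loop body, the loop
//   for (i, 0, n) { A[i] = 0; }
// is linearized into three entries:
//   [ {stmt = For,         scope_pair_offset = +2},
//     {stmt = BufferStore, scope_pair_offset =  0, touched = {A}},
//     {stmt = For,         scope_pair_offset = -2} ]
//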
class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
 public:
  /*! \brief record the touch list of a statement. */
  struct StmtEntry {
    // The statement
    const Object* stmt;
    // The index in linear_seq_ pointing to the end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // If offset > 0, this is the begin entry and the end entry is at current_index + offset.
    // If offset < 0, this is the end entry and the begin entry is at current_index + offset.
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const VarNode*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // the level in the scope stack
    size_t level{0};
    // allocation stmt
    const AllocateNode* alloc{nullptr};
  };

  void VisitStmt_(const AllocateNode* op) final {
    size_t level = scope_.size();
    const VarNode* buf = op->buffer_var.get();
    alloc_info_[buf].alloc = op;
    alloc_info_[buf].level = level;
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    // Add write access.
    const VarNode* buf = op->buffer->data.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size());
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }

  void VisitStmt_(const EvaluateNode* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    StmtExprVisitor::VisitStmt_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    // Add read access.
    StmtExprVisitor::VisitExpr_(op);
    const VarNode* buf = op->buffer->data.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
  }

  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::address_of())) {
      const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
      ICHECK(load != nullptr) << "address_of expects a BufferLoad as its argument.";
      for (const auto& index : load->indices) {
        this->VisitExpr(index);
      }
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void VisitExpr_(const VarNode* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size());
      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
        scope_[it->second.level].touched.push_back(buf);
      }
    }
  }
  template <typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    StmtExprVisitor::VisitStmt_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    ICHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to the end index.
    ICHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void VisitStmt_(const AttrStmtNode* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else {
      StmtExprVisitor::VisitStmt_(op);
    }
  }
  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }

  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }

  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The storage scope of each buffer
  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

 private:
  // Whether we are already inside a thread env.
  bool in_thread_env_{false};
  // The scope stack.
  std::vector<StmtEntry> scope_;
};

/*!
 * \brief Merge buffers whose live ranges do not intersect, and rewrite the body accordingly.
 */
class DynamicSharedMemoryRewriter : public StmtExprMutator {
 public:
  explicit DynamicSharedMemoryRewriter(
      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
      : dyn_shmem_allocs_{dyn_shmem_allocs} {}

  /*!
   * \brief Plan memory reuse for all buffers allocated in the statement.
   * \param stmt the statement
   */
  void PlanReuse(const Stmt& stmt) {
    DynSharedMemLinearAccessPatternFinder finder;
    finder(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_);
  }

 private:
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent && !allocated_) {
      // Allocate one dynamic shared memory allocation at the beginning of the thread scope.
      int max_layer_num = 0;
      std::vector<const StorageEntry*> all_entry;
      for (const auto& e : const_free_map_) {
        all_entry.push_back(e.second);
      }
      for (const StorageEntry* e : sym_free_list_) {
        all_entry.push_back(e);
      }
      for (const StorageEntry* e : all_entry) {
        max_layer_num = std::max(max_layer_num, static_cast<int>(e->allocs.size()));
      }
      // calculate the alignment for each layer of each storage entry.
      std::vector<int> align(max_layer_num, 0);
      for (const StorageEntry* e : all_entry) {
        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
          for (const VarNode* buffer : e->allocs[i]) {
            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
            align[i] = std::max(align[i], alloc->dtype.bytes());
          }
        }
      }
      // calculate the offset of each buffer based on the alignment of each layer.
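      // Illustrative example (hypothetical sizes, not taken from a real trace): for an
      // entry with allocs = {{A, B}, {C}}, where A is 1024 x float32, B is 2048 x float16,
      // and C is 1024 x float64, layer 0 places A at +0 and B at +4096, and layer 1
      // places C at +0. Both layers overlay the same region starting at the entry's
      // base, so the entry contributes max(4096 + 4096, 8192) = 8192 bytes to
      // merged_alloc_size_.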
      for (const StorageEntry* e : all_entry) {
        PrimExpr max_inner_offset = 0;
        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
          PrimExpr inner_offset = 0;
          for (const VarNode* buffer : e->allocs[i]) {
            const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
            buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
            inner_offset += alloc->extents[0] * alloc->dtype.bytes();
            inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
          }
          max_inner_offset = max(max_inner_offset, inner_offset);
        }
        merged_alloc_size_ += max_inner_offset;
      }

      allocated_ = true;
      Allocate new_body(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(),
                        StmtExprMutator::VisitStmt(op->body));
      return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
    }
    return StmtMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const AllocateNode* op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      return StmtExprMutator::VisitStmt(op->body);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (IsDynamicSharedMemory(node->buffer->data)) {
      ICHECK_EQ(node->indices.size(), 1)
          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
          << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
      Array<PrimExpr> indices = {node->indices[0] +
                                 this->GetBufferOffset(node->buffer->data, node->buffer->dtype)};

      auto writer = node.CopyOnWrite();
      writer->buffer = GetUpdatedBuffer(node->buffer);
      writer->indices = indices;
    }

    return node;
  }

  Buffer GetUpdatedBuffer(Buffer buffer) {
    auto key = buffer.get();
    auto it = buffer_remap_.find(key);
    if (it != buffer_remap_.end()) {
      return it->second;
    }

    if (IsDynamicSharedMemory(buffer->data)) {
      ICHECK_EQ(buffer->shape.size(), 1)
          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
          << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
      auto writer = buffer.CopyOnWrite();
      writer->data = merged_buf_var_;
    }

    buffer_remap_[key] = buffer;
    return buffer;
  }

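  // Rewrite tvm_access_ptr(ptype, data, offset, extent, rw_mask) calls that touch a
  // dynamic shared buffer so that they point into the merged buffer, with the buffer's
  // element offset folded into the access offset.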
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      ICHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      Var buffer = Downcast<Var>(op->args[1]);
      if (!IsDynamicSharedMemory(buffer)) {
        return StmtExprMutator::VisitExpr_(op);
      }
      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);

      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      return Call(op->dtype, op->op,
                  {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }

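  // Returns the offset of the buffer within the merged buffer, converted from the
  // stored byte offset into elements of the given dtype.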
  PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
    auto it = buffer_byte_offsets_.find(buffer_var.get());
    ICHECK(it != buffer_byte_offsets_.end());
    return indexdiv(it->second, dtype.bytes());
  }

  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
  struct StorageEntry {
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // Allocs that share this entry.
    // The inner vector represents a "layer".
    // For example, if we need to allocate C in the memory of A and B:
    // | A: 4096 bytes | B: 4096 bytes |
    // | C: 8192 bytes                 |
    // then allocs = {{A, B}, {C}}.
    std::vector<std::vector<const VarNode*>> allocs;
  };

  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
    std::vector<const VarNode*> gen;
    // variables we kill
    std::vector<const VarNode*> kill;
  };

  /*!
   * \brief Liveness analysis to find the gen and kill points of each variable.
   * \param seq the linear pattern of storage access
   */
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill points with a reverse linear scan.
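    // The first time a buffer appears in the reverse scan is its last use in program
    // order, so that statement becomes the buffer's kill point.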
    std::unordered_set<const VarNode*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen points with a forward scan.
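    // Symmetrically, the first forward touch is the gen (first-use) point. Scope begin
    // entries (offset > 0) read the touch list from their matching end entry, since
    // only the end entry carries the touches collected over the scope body.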
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const VarNode* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }

  /*!
   * \brief Memory planning algorithm
   * \param seq the linear pattern of storage access
   */
  void PlanMemory(const std::vector<StmtEntry>& seq) {
    for (size_t i = 0; i < seq.size(); ++i) {
      auto it = event_map_.find(seq[i].stmt);
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
      // In both cases, we need to handle the kill event correctly.
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode* var : it->second.kill) {
          this->Free(var);
        }
      }
      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly.
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        for (const VarNode* var : it->second.gen) {
          ICHECK(dyn_shmem_allocs_.count(var));
          const AllocateNode* alloc = dyn_shmem_allocs_[var];
          StorageEntry* dst_entry = FindAlloc(alloc);
          alloc_map_[var] = dst_entry;
        }
      }
    }
  }
  /*!
   * \brief Allocate a new storage entry.
   * \param op the allocate node
   * \param const_nbits the size of the allocation in bits
   * \return the new storage entry
   */
  StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) {
    ICHECK(op != nullptr);
    // Reuse was not successful; allocate a new buffer.
    StorageEntry* entry = arena_.make<StorageEntry>();
    entry->allocs.push_back({op->buffer_var.get()});
    entry->const_nbits = const_nbits;
    return entry;
  }
  /*!
   * \brief Find a storage entry in the free list that can hold the allocate.
   * \param op the allocate node
   * \return the storage entry
   */
  StorageEntry* FindAlloc(const AllocateNode* op) {
    ICHECK(op != nullptr);
    // skip the plan for small arrays; the compiler can do a better job with
    // register allocation for them.
    const uint64_t match_range = 16;
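    // A free block is reused directly only when its size lies within
    // [const_nbits, const_nbits * match_range] bits; requests outside that window
    // either coalesce smaller blocks or trigger a fresh allocation.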
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
    // Disable reuse of small arrays; they will be lowered to registers in LLVM.
    // This rule only applies to non-special memory.
    if (const_nbits > 0 && const_nbits <= 32) {
      return NewAlloc(op, const_nbits);
    }

    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(0);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // Start looking at buffers that are bigger than the required size.
      // If we find one, allocate directly in its location and remove its entry
      // from the free list.
      for (auto it = mid; it != end; ++it) {
        StorageEntry* e = it->second;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // Then look at smaller buffers.
      // Keep collecting buffers until the sum of their sizes covers the requested
      // allocation, then remove all of their entries from the free list.
      std::vector<std::multimap<uint64_t, StorageEntry*>::iterator> delete_it;
      // the alloc list for the new entry
      std::vector<std::vector<const VarNode*>> reuse_allocs;
      uint64_t mem_ct = 0;
      for (auto it = mid; it != begin;) {
        --it;
        delete_it.push_back(it);
        mem_ct += it->second->const_nbits;
        int n = static_cast<int>(it->second->allocs.size());
        if (n > static_cast<int>(reuse_allocs.size())) {
          reuse_allocs.resize(n, {});
        }
        for (int i = 0; i < n; i++) {
          for (const VarNode* alloc : it->second->allocs[i]) {
            reuse_allocs[i].push_back(alloc);
          }
        }
        if (mem_ct >= const_nbits) {
          break;
        }
      }
      reuse_allocs.push_back({op->buffer_var.get()});
      if (mem_ct != 0) {
        StorageEntry* e = arena_.make<StorageEntry>();
        e->const_nbits = std::max(const_nbits, mem_ct);
        e->allocs = reuse_allocs;
        for (auto it : delete_it) {
          const_free_map_.erase(it);
        }
        return e;
      }
    } else {
      // For a symbolic allocation, arbitrarily pick the first free entry to reuse,
      // because the actual size is unknown.
      if (!sym_free_list_.empty()) {
        StorageEntry* e = sym_free_list_.front();
        sym_free_list_.pop_front();
        return e;
      }
    }
    return NewAlloc(op, const_nbits);
  }

  /*!
   * \brief Return the storage entry bound to the buffer var back to the free list.
   * \param var the buffer var
   */
  void Free(const VarNode* var) {
    auto it = alloc_map_.find(var);
    ICHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    ICHECK_NE(e->allocs.size(), 0U);

    // disable reuse of small arrays.
    if (e->const_nbits > 0 && e->const_nbits <= 32) return;

    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // The var for the merged buffer
  Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
  // The mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
  // The size of the merged buffer
  PrimExpr merged_alloc_size_{0};
  // The mapping from the original buffer var to its offset in the merged buffer
  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
  // The mapping from the original buffer objects to their location in the merged buffer
  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
  // The flag indicating whether the merged buffer has been allocated
  bool allocated_{false};
  // The gen/kill events at each statement
  std::unordered_map<const Object*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non-constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocation assignment map
  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
  /*! \brief allocator of all the StorageEntry */
  support::Arena arena_;
};

Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
  AllocateCollector collector;
  collector(stmt);
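  // Only rewrite when more than one dynamic shared memory allocation exists; a single
  // allocation already satisfies the one-allocation-per-kernel constraint.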
  if (collector.dyn_shmem_allocs_.size() > 1) {
    DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
    rewriter.PlanReuse(stmt);
    return rewriter(std::move(stmt));
  }
  return stmt;
}

namespace transform {

Pass MergeDynamicSharedMemoryAllocations() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
}

TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
    .set_body_typed(MergeDynamicSharedMemoryAllocations);

}  // namespace transform
}  // namespace tir
}  // namespace tvm