/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file compact_buffer_region.cc
 * \brief Compact each allocated buffer to the exact region it needs.
 */
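
/*
 * Illustrative sketch of the intended effect of this pass, written in TIR-like pseudocode.
 * The exact syntax and the names A, B, C below are illustrative assumptions, not taken from
 * this file:
 *
 *   // Before: B is allocated with shape (16, 16), although each iteration of i only
 *   // touches the single row B[i, 0:16].
 *   for i in range(0, 16):
 *     with block():
 *       B = alloc_buffer(shape=(16, 16))
 *       for j in range(0, 16):
 *         B[i, j] = A[i, j] + 1
 *       for j in range(0, 16):
 *         C[i, j] = B[i, j] * 2
 *
 *   // After: B is compacted to the region it actually needs, and its accesses are
 *   // shifted to index into the smaller allocation.
 *   for i in range(0, 16):
 *     with block():
 *       B = alloc_buffer(shape=(1, 16))
 *       for j in range(0, 16):
 *         B[0, j] = A[i, j] + 1
 *       for j in range(0, 16):
 *         C[i, j] = B[0, j] * 2
 */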

#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <stack>

#include "../../support/arena.h"
#include "../../support/nd_int_set.h"
#include "../../support/utils.h"
#include "../schedule/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using support::NDIntSet;

/*!
 * \brief Simplify and return the region collected by the NDIntSet. Return the original
 *        buffer shape if the int set is empty.
 */
Region SimplifyAndNarrowBufferRegionFromNDIntSet(
    const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape, arith::Analyzer* analyzer,
    const std::vector<const ForNode*>& ancestor_loops) {
  Array<Range> result;
  result.reserve(nd_int_set.size());
  for (size_t i = 0; i < nd_int_set.size(); ++i) {
    const arith::IntSet& int_set = nd_int_set[i];
    Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
    PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
    PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i], range->extent));

    // Check that the buffer region is not loop-dependent, since loop-dependent
    // allocation is not supported yet.
    auto is_loop_var = [&ancestor_loops](const VarNode* v) {
      return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
                         [v](const ForNode* n) { return n->loop_var.get() == v; });
    };
    if (UsesVar(extent, is_loop_var)) {
      // Try to estimate a constant upper bound on the region's extent.
      int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
      if (upperbound != arith::ConstIntBound::kPosInf) {
        extent = make_const(extent->dtype, upperbound);
      } else {
        // Otherwise we have to fall back to the full region.
        min = make_zero(original_shape[i]->dtype);
        extent = original_shape[i];
      }
    }

    result.push_back(Range::FromMinExtent(min, extent));
  }
  return result;
}
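
// Illustrative sketch of how the narrowing above behaves on a single dimension, under assumed
// inputs (not taken from this file):
//   * If the collected set for a buffer of shape [64] is [i * 16, i * 16 + 15] and `i` is not
//     an ancestor loop var, the result is Range(min = i * 16, extent = 16); only the extent
//     has to be loop-independent.
//   * If the extent itself depends on an ancestor loop var, e.g. extent = i + 1 with i bound
//     to [0, 8), the constant upper bound 8 is used instead.
//   * If no constant upper bound can be proven, the dimension falls back to the full
//     original shape.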

/*! \brief A more constrained bound estimate for an n-dimensional int set. */
NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                      const std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
                      arith::Analyzer* analyzer) {
  std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_dom;
  for (const auto& it : dom_map) {
    var_dom[GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
  }
  Optional<Array<arith::IntSet>> eval_res =
      arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
  if (eval_res.defined()) {
    return NDIntSet(eval_res.value().begin(), eval_res.value().end());
  }
  return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}
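
// Illustrative note (assumed example, not asserted by this file): the block predicate can
// tighten the estimated bound. For a one-dimensional region [0 : i + 1] with `i` bound to
// [0, 8) in `dom_map` and predicate `i < 4`, EstimateRegionUpperBound may bound the extent by
// 4 rather than 8; only when that symbolic estimation fails do we fall back to the plain
// support::NDIntSetEval relaxation, which ignores the predicate.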

/*!
 * \brief Collect the access region of each buffer.
 * \note The regions of the parameter buffers are not collected.
 */
class BufferAccessRegionCollector : public StmtExprVisitor {
 public:
  static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
      const PrimFunc& f) {
    BufferAccessRegionCollector collector;
    collector(f->body);
    return std::move(collector.buffer_access_region_);
  }

 private:
  struct BufferAccessInfo {
    /*! \brief The buffer. */
    Buffer buffer;
    /*! \brief The buffer access region, which can be updated during visiting. */
    NDIntSet accessed_region;

    explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region)
        : buffer(buffer), accessed_region(region) {}
  };

  BufferAccessRegionCollector() = default;

  /**************** Visitor overload ****************/

  void VisitStmt_(const BufferStoreNode* op) final {
    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
    VisitExpr(op->value);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
  }

  void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const ForNode* op) final {
    ancestor_loops_.push_back(op);
    Range loop_range = Range::FromMinExtent(op->min, op->extent);
    dom_analyzer_.Bind(op->loop_var, loop_range);
    dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
    StmtExprVisitor::VisitStmt_(op);
    dom_map_.erase(op->loop_var.get());
    ancestor_loops_.pop_back();
  }

  void VisitStmt_(const LetStmtNode* op) final {
    StmtExprVisitor::VisitExpr(op->value);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_analyzer_.Bind(op->var, op->value);
      dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
    }
    StmtExprVisitor::VisitStmt(op->body);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_map_.erase(op->var.get());
    }
  }

  void VisitExpr_(const LetNode* op) final {
    StmtExprVisitor::VisitExpr(op->value);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_analyzer_.Bind(op->var, op->value);
      dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
    }
    StmtExprVisitor::VisitExpr(op->body);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_map_.erase(op->var.get());
    }
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    // Visit condition
    StmtExprVisitor::VisitExpr(op->condition);
    {
      // Visit then branch
      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
      StmtExprVisitor::VisitStmt(op->then_case);
    }
    if (op->else_case) {
      // Visit else branch
      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
      StmtExprVisitor::VisitStmt(op->else_case.value());
    }
  }

  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      // Visit condition
      StmtExprVisitor::VisitExpr(op->args[0]);
      {
        // Visit then branch
        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
        StmtExprVisitor::VisitExpr(op->args[1]);
      }
      {
        // Visit else branch
        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
        StmtExprVisitor::VisitExpr(op->args[2]);
      }
      return;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BlockNode* op) final {
    // Step 0. Check that there is no init part.
    ICHECK(!op->init.defined());
    // Step 1. Record and update the current read/write region annotations.
    std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
        cur_access_annotations;
    for (const BufferRegion& region : op->reads) {
      cur_access_annotations[region->buffer].push_back(region);
    }
    for (const BufferRegion& region : op->writes) {
      cur_access_annotations[region->buffer].push_back(region);
    }
    for (auto& p : cur_access_annotations) {
      auto& regions = access_annotations_[p.first];
      p.second.swap(regions);
    }
    // Step 2. Record the relax position (the number of enclosing ancestor_loops_) of each
    // allocated buffer into buffer_var_in_scope_.
    for (const Buffer& buffer : op->alloc_buffers) {
      buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, ancestor_loops_.size()));
    }
    // Step 3. Visit match buffers.
    for (const MatchBufferRegion& region : op->match_buffers) {
      VisitBufferAccess(region->source);
    }
    // Step 4. Visit the block body recursively.
    StmtExprVisitor::VisitStmt_(op);
    // Step 5. Recover the read/write region annotations.
    for (auto& p : cur_access_annotations) {
      auto& regions = access_annotations_[p.first];
      if (p.second.empty()) {
        access_annotations_.erase(p.first);
      } else {
        regions.swap(p.second);
      }
    }
    // Step 6. Update buffer_access_region_ from relaxed_accesses_ for the buffers allocated here.
    for (const Buffer& buffer : op->alloc_buffers) {
      auto it = relaxed_accesses_.find(buffer);
      ICHECK(it != relaxed_accesses_.end())
          << buffer << " is allocated but not accessed within block scope";
      const NDIntSet& nd_int_set = it->second;
      buffer_access_region_[buffer] = SimplifyAndNarrowBufferRegionFromNDIntSet(
          nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
    }
  }

  void VisitStmt_(const BlockRealizeNode* op) final {
    PrimExpr cur_predicate = predicate_in_scope;
    predicate_in_scope = op->predicate;
    StmtExprVisitor::VisitStmt_(op);
    predicate_in_scope = cur_predicate;
  }

  /**************** Helper functions ****************/

  void VisitBufferAccess(const BufferRegion& buffer_region) {
    const BufferNode* buffer = buffer_region->buffer.get();
    auto it = buffer_var_in_scope_.find(buffer->data);
    if (it != buffer_var_in_scope_.end()) {
      const Buffer& buffer = it->second.first;
      size_t n_ancestor_loops = it->second.second;
      // Step 1. Keep the ancestor loop vars outside the allocation block from being relaxed,
      // unless NeedRelaxThread() is true.
      std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
      for (size_t i = 0; i < n_ancestor_loops; ++i) {
        const ForNode* loop = ancestor_loops_[i];
        const VarNode* v = loop->loop_var.get();
        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
          continue;
        }
        auto dom_it = dom_map_.find(v);
        ICHECK(dom_it != dom_map_.end())
            << "Could not find domain for loop variable " << v->name_hint;
        non_relaxed[i] = dom_it->second;
        dom_map_.erase(dom_it);
      }
      // Step 2. Relax the access region.
      NDIntSet nd_int_set =
          NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
      // Step 3. Restore the domains of the non-relaxed ancestor loops.
      for (size_t i = 0; i < n_ancestor_loops; ++i) {
        const VarNode* v = ancestor_loops_[i]->loop_var.get();
        dom_map_.emplace(v, non_relaxed[i]);
      }
      // Step 4. Update the relaxed_accesses_ dict.
      auto access_it = relaxed_accesses_.find(buffer);
      if (access_it != relaxed_accesses_.end()) {
        support::NDIntSetUnionWith(&access_it->second, nd_int_set);
      } else {
        relaxed_accesses_.insert(access_it, {buffer, nd_int_set});
      }
    }
  }
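
  // Illustrative sketch of the relaxation above (assumed shapes and loop structure, not taken
  // from this file): suppose a buffer B is allocated in a block nested under loop i, so i is a
  // recorded ancestor and stays non-relaxed, and B is accessed as B[i, j] with j in [0, 16).
  // Only j is relaxed, so the collected region is [i : i + 1, 0 : 16]; after narrowing, B can
  // be reallocated with shape [1, 16] and the access later rewritten to B[0, j].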

  void VisitBufferVar(const Var& var) {
    auto it = buffer_var_in_scope_.find(var);
    if (it != buffer_var_in_scope_.end()) {
      const Buffer& buffer = it->second.first;
      auto annotation_it = access_annotations_.find(buffer);
      if (annotation_it != access_annotations_.end()) {
        // An opaque buffer access has explicit access region annotations on the block.
        for (const BufferRegion& region : annotation_it->second) {
          VisitBufferAccess(region);
        }
      } else {
        VisitBufferAccess(BufferRegion::FullRegion(buffer));
      }
    }
  }

  /*! \brief Check whether the thread-binding loop should be relaxed under the given scope. */
  static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& scope) {
    if (loop->kind != ForKind::kThreadBinding) {
      return false;
    }
    ICHECK(loop->thread_binding.defined());
    const String& thread_tag = loop->thread_binding.value()->thread_tag;
    // When there is warp memory, threadIdx.x must be set to be the warp index.
    return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create(thread_tag));
  }
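
  // Illustrative note (assumed scopes, not asserted by this file): for a loop bound to
  // threadIdx.x, a buffer in "shared" scope is visible to the whole thread block, so the
  // thread loop is relaxed and the buffer keeps the union of the regions accessed by all
  // threads; a buffer in "local" scope is private to one thread, so the loop stays fixed and
  // the buffer is compacted per thread.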

  /**************** Class members ****************/
  /*! \brief The loops from the current node up to the root. */
  std::vector<const ForNode*> ancestor_loops_;

  /*!
   * \brief The vars of the buffers allocated under the current block.
   * Map each buffer var to a (buffer_obj, n_ancestor_loops) pair, where
   * n_ancestor_loops is the number of loops outside the current block.
   * ancestor_loops_[0 : n_ancestor_loops] should not be relaxed when
   * we evaluate this buffer's access regions.
   */
  std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
      buffer_var_in_scope_;
  /*! \brief The block predicate of the current scope. */
  PrimExpr predicate_in_scope{true};

  /*! \brief The map from loop vars to their iter range. */
  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
  /*! \brief Extra map from free vars to their iter range hints. */
  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
  /*! \brief The analyzer aware of loop domains. */
  arith::Analyzer dom_analyzer_;
  /*! \brief The map from each Buffer to its relaxed access set. */
  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
  /*! \brief The map from each Buffer to its entire access region, used as the result. */
  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
  /*! \brief The map from each Buffer to the access regions annotated by the current block. */
  std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
      access_annotations_;
};

/*! \brief Collect storage alignment information from block annotations. */
class StorageAlignCollector : public StmtVisitor {
 public:
  static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> Collect(
      const PrimFunc& f) {
    StorageAlignCollector collector;
    collector(f->body);
    return std::move(collector.storage_align_);
  }

 private:
  void VisitStmt_(const BlockNode* op) final {
    auto it = op->annotations.find(attr::buffer_dim_align);
    if (it != op->annotations.end()) {
      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
      for (const auto& storage_align_tuple : storage_align_annotation) {
        int buffer_index = storage_align_tuple[0]->value;
        const Buffer& buffer = op->writes[buffer_index]->buffer;
        storage_align_[buffer].push_back(storage_align_tuple);
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

  /*! \brief The map from Buffer to its storage alignment information. */
  std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> storage_align_;
};
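
// Note on the annotation layout consumed above (inferred from how the fields are read here and
// in BufferCompactor::Compact below, not asserted elsewhere in this file): each storage-align
// tuple is interpreted as (buffer_index, dim, factor, offset), where buffer_index selects a
// write buffer of the block and (factor, offset) request that the stride of dimension `dim`
// satisfy stride % factor == offset.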

/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
 public:
  static Stmt Compact(
      const PrimFunc& f,
      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
      const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
          storage_align) {
    std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;

    for (const auto& kv : regions) {
      const Buffer& buffer = kv.first;
      Region region = kv.second;
      BufferAllocInfo buffer_alloc_info(std::move(region));
      auto it = storage_align.find(buffer);
      if (it != storage_align.end()) {
        std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
        for (const StorageAlignTuple& dim_align : (*it).second) {
          ICHECK(dim_align.size() == 4);
          int dim = dim_align[1]->value;
          int factor = dim_align[2]->value;
          int offset = dim_align[3]->value;
          dim_aligns.at(dim) = {factor, offset};
        }
        buffer_alloc_info.dim_aligns = std::move(dim_aligns);
      }
      buffer_info.emplace(buffer, std::move(buffer_alloc_info));
    }
    BufferCompactor compactor(std::move(buffer_info));
    Stmt stmt = compactor(f->body);
    return stmt;
  }

 private:
  /*! \brief The storage alignment for a dimension. */
  struct DimAlignInfo {
    /*! \brief The factor of the alignment. */
    int align_factor{0};
    /*! \brief The offset of the alignment. */
    int align_offset{0};
  };

  struct BufferAllocInfo {
    /*! \brief The buffer access region. */
    Region region;
    /*! \brief The storage alignment information. */
    std::vector<DimAlignInfo> dim_aligns;
    /*!
     * \brief The reallocated buffer with minimal size.
     * \note The value is NullOpt if the buffer does not need to be reallocated
     *       (e.g. a parameter buffer).
     */
    Buffer new_buffer;

    explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
  };

  explicit BufferCompactor(
      std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
      : buffer_info_(std::move(buffer_info)) {}

  Stmt VisitStmt_(const BufferStoreNode* _op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
    BufferStoreNode* op = store.CopyOnWrite();
    RewriteBufferAccess(&op->buffer, &op->indices);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
    BufferLoadNode* op = load.CopyOnWrite();
    RewriteBufferAccess(&op->buffer, &op->indices);
    return std::move(load);
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    // Step 0. Check there is no Init part.
    ICHECK(!op->init.defined());
    // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
    Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
    // Step 2. Recursively rewrite BufferLoad/BufferStore.
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    // Step 3. Update block signature.
    BlockNode* n = block.CopyOnWrite();
    RewriteBufferRegions(&n->reads);
    RewriteBufferRegions(&n->writes);
    RewriteMatchBuffers(&n->match_buffers);
    n->alloc_buffers = std::move(alloc_buffers);
    return std::move(block);
  }

  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
    Array<Buffer> result;
    result.reserve(buffers.size());
    for (const Buffer& buffer : buffers) {
      auto it = buffer_info_.find(buffer);
      ICHECK(it != buffer_info_.end());
      BufferAllocInfo& info = it->second;
      Array<PrimExpr> shape;
      shape.reserve(info.region.size());
      for (const Range& range : info.region) {
        shape.push_back(range->extent);
      }
      Array<PrimExpr> strides;
      if (info.dim_aligns.size()) {
        ICHECK(info.dim_aligns.size() == shape.size());
        strides.resize(shape.size());
        PrimExpr stride = make_const(shape[0].dtype(), 1);
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (info.dim_aligns[dim].align_factor != 0) {
            PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor);
            PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset);
            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
          }
          strides.Set(dim, stride);
          stride = stride * shape[dim];
        }
      }
      ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
      n->shape = std::move(shape);
      n->strides = std::move(strides);
      info.new_buffer = Buffer(std::move(n));
      result.push_back(info.new_buffer);
    }
    return result;
  }
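
  // Worked example for the stride computation above (assumed numbers, for illustration only):
  // for a compacted shape of [16, 30] with a storage-align request of factor = 32, offset = 1
  // on dimension 0, we get strides[1] = 1; the running stride 1 * 30 = 30 is then padded to
  // 30 + ((32 + 1 - 30 % 32) % 32) = 33, so strides = [33, 1], and 33 % 32 == 1 matches the
  // requested offset.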

  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
    auto it = buffer_info_.find(*buffer);
    if (it == buffer_info_.end()) {
      // Skip if the buffer is a parameter.
      return;
    }
    const BufferAllocInfo& info = it->second;
    ICHECK_EQ(indices->size(), info.region.size());
    int ndim = info.region.size();
    Array<PrimExpr> new_indices;
    new_indices.reserve(ndim);
    for (int i = 0; i < ndim; ++i) {
      new_indices.push_back((*indices)[i] - info.region[i]->min);
    }
    *buffer = info.new_buffer;
    *indices = std::move(new_indices);
  }
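
  // Illustrative sketch (assumed indices, not taken from this file): if the compacted region
  // of a buffer starts at [i * 16, 0], an access at indices [i * 16 + j, k] is rewritten to
  // [j, k] on the reallocated buffer, i.e. each index is shifted down by the region's minimum.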

  void RewriteBufferRegion(Buffer* buffer, Region* region) const {
    auto it = buffer_info_.find(*buffer);
    if (it == buffer_info_.end()) {
      // Skip if the buffer is a parameter.
      return;
    }
    const BufferAllocInfo& info = it->second;
    ICHECK_EQ(region->size(), info.region.size());
    Region new_region;
    new_region.reserve(info.region.size());
    for (size_t i = 0; i < info.region.size(); ++i) {
      const Range& range = (*region)[i];
      new_region.push_back(Range::FromMinExtent(range->min - info.region[i]->min, range->extent));
    }
    *buffer = info.new_buffer;
    *region = std::move(new_region);
  }

  void RewriteBufferRegions(Array<BufferRegion>* regions) const {
    Array<BufferRegion> new_regions;
    new_regions.reserve(regions->size());
    for (const auto& region : *regions) {
      BufferRegion buffer_region = region;
      BufferRegionNode* p = buffer_region.CopyOnWrite();
      RewriteBufferRegion(&p->buffer, &p->region);
      new_regions.push_back(buffer_region);
    }
    *regions = std::move(new_regions);
  }

  void RewriteMatchBuffers(Array<MatchBufferRegion>* match_buffers) const {
    Array<MatchBufferRegion> result;
    result.reserve(match_buffers->size());
    for (const auto& match_buffer : *match_buffers) {
      const BufferRegion& buffer_region = match_buffer->source;
      auto p = make_object<BufferRegionNode>(*buffer_region.get());
      RewriteBufferRegion(&p->buffer, &p->region);
      result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p)));
    }
    *match_buffers = std::move(result);
  }

  /*! \brief The allocation information about each buffer. */
  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
};

PrimFunc CompactBufferAllocation(PrimFunc f) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(f)) {
    PrimFuncNode* fptr = f.CopyOnWrite();
    std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
        BufferAccessRegionCollector::Collect(f);
    std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
        storage_align = StorageAlignCollector::Collect(f);
    fptr->body = BufferCompactor::Compact(f, region, storage_align);
    return f;
  } else {
    return f;
  }
}

namespace transform {

Pass CompactBufferAllocation() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return CompactBufferAllocation(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
}
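
// Minimal usage sketch from C++ (assumes an existing IRModule `mod`; illustrative only):
//
//   IRModule mod = ...;
//   mod = tvm::tir::transform::CompactBufferAllocation()(mod);
//
// From Python the same pass is typically applied as
// tvm.tir.transform.CompactBufferAllocation()(mod).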

TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation")
    .set_body_typed(CompactBufferAllocation);
}  // namespace transform

}  // namespace tir
}  // namespace tvm