/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file compact_buffer_region.cc
 * \brief Compact the size of each allocated buffer to the exact region it needs.
 */
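
// A sketch of the transformation (TVMScript-like pseudocode with hypothetical
// shapes): if a [16, 16] buffer B is allocated inside loop i but each
// iteration only touches the row B[i, 0:16], the pass shrinks the allocation
// to a single row and shifts the accesses accordingly:
//
//   Before:
//     for i in range(16):
//       with T.block():
//         B = T.alloc_buffer((16, 16))
//         for j in range(16):
//           B[i, j] = A[i, j] + 1.0
//         for j in range(16):
//           C[i, j] = B[i, j] * 2.0
//
//   After:
//     for i in range(16):
//       with T.block():
//         B = T.alloc_buffer((1, 16))
//         for j in range(16):
//           B[0, j] = A[i, j] + 1.0
//         for j in range(16):
//           C[0, j] = B[0, j] * 2.0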

#include <tvm/arith/int_set.h>
#include <tvm/arith/int_solver.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <stack>

#include "../../support/arena.h"
#include "../../support/nd_int_set.h"
#include "../../support/utils.h"
#include "../schedule/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using support::NDIntSet;

/*!
 * \brief Simplify and return the region collected by the NDIntSet. Return the
 * original buffer shape if the int set is empty, or if the narrowed extent is
 * loop-dependent and has no constant upper bound.
 */
Region SimplifyAndNarrowBufferRegionFromNDIntSet(
    const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape, arith::Analyzer* analyzer,
    const std::vector<const ForNode*>& ancestor_loops) {
  Array<Range> result;
  result.reserve(nd_int_set.size());
  for (size_t i = 0; i < nd_int_set.size(); ++i) {
    const arith::IntSet& int_set = nd_int_set[i];
    Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
    PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
    PrimExpr extent = analyzer->Simplify(tvm::min(original_shape[i], range->extent));

    // Check that the buffer region is not loop-dependent, since loop-dependent
    // allocation is not supported yet.
    auto is_loop_var = [&ancestor_loops](const VarNode* v) {
      return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
                         [v](const ForNode* n) { return n->loop_var.get() == v; });
    };
    if (UsesVar(extent, is_loop_var)) {
      // Try to estimate a constant upper bound on the region's extent.
      int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
      if (upperbound != arith::ConstIntBound::kPosInf) {
        extent = make_const(extent->dtype, upperbound);
      } else {
        // Otherwise we have to fall back to the full region.
        min = make_zero(original_shape[i]->dtype);
        extent = original_shape[i];
      }
    }

    result.push_back(Range::FromMinExtent(min, extent));
  }
  return result;
}
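
// A worked example of the narrowing above (hypothetical values): for a buffer
// of shape [16] whose collected int set covers [2, 9], the narrowed range is
// Range::FromMinExtent(2, 8). If instead the extent simplifies to `i + 1` for
// an ancestor loop i in [0, 16), the constant upper bound 16 is used; and if
// no finite upper bound can be estimated, the full range [0, 16) is kept.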

/*! \brief A more constrained bound estimate for an n-dimensional int set. */
NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                      const std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
                      arith::Analyzer* analyzer) {
  std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_dom;
  for (const auto& it : dom_map) {
    var_dom[GetRef<Var>(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0));
  }
  Optional<Array<arith::IntSet>> eval_res =
      arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer);
  if (eval_res.defined()) {
    return NDIntSet(eval_res.value().begin(), eval_res.value().end());
  }
  return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
}
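
// Note on the fallback above: EstimateRegionUpperBound can exploit the block
// predicate (e.g. a bound like `i * 4 + j < 50`) to derive a tighter region;
// when it fails and returns NullOpt, the predicate-unaware
// support::NDIntSetEval is used instead.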

/*!
 * \brief Collect the access region of each buffer.
 * \note The regions of the parameter buffers are not collected.
 */
class BufferAccessRegionCollector : public StmtExprVisitor {
 public:
  static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
      const PrimFunc& f) {
    BufferAccessRegionCollector collector;
    collector(f->body);
    return std::move(collector.buffer_access_region_);
  }

 private:
  struct BufferAccessInfo {
    /*! \brief The buffer. */
    Buffer buffer;
    /*! \brief The buffer access region, which can be updated during visiting. */
    NDIntSet accessed_region;

    explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region)
        : buffer(buffer), accessed_region(region) {}
  };

  BufferAccessRegionCollector() = default;

  /**************** Visitor overload ****************/

  void VisitStmt_(const BufferStoreNode* op) final {
    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
    VisitExpr(op->value);
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
  }

  void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }

  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const ForNode* op) final {
    ancestor_loops_.push_back(op);
    Range loop_range = Range::FromMinExtent(op->min, op->extent);
    dom_analyzer_.Bind(op->loop_var, loop_range);
    dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
    StmtExprVisitor::VisitStmt_(op);
    dom_map_.erase(op->loop_var.get());
    ancestor_loops_.pop_back();
  }

  void VisitStmt_(const LetStmtNode* op) final {
    StmtExprVisitor::VisitExpr(op->value);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_analyzer_.Bind(op->var, op->value);
      dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
    }
    StmtExprVisitor::VisitStmt(op->body);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_map_.erase(op->var.get());
    }
  }

  void VisitExpr_(const LetNode* op) final {
    StmtExprVisitor::VisitExpr(op->value);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_analyzer_.Bind(op->var, op->value);
      dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value));
    }
    StmtExprVisitor::VisitExpr(op->body);
    if (arith::IsIndexType(op->value->dtype)) {
      dom_map_.erase(op->var.get());
    }
  }
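
  // For example (a hypothetical binding): for `let x = i * 4 in B[x]`, x is
  // bound to the single point i * 4, so the access region of B is expressed in
  // terms of the loop var i and can later be relaxed over i's domain.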

  void VisitStmt_(const IfThenElseNode* op) final {
    // Visit condition
    StmtExprVisitor::VisitExpr(op->condition);
    {
      // Visit then branch
      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
      StmtExprVisitor::VisitStmt(op->then_case);
    }
    if (op->else_case) {
      // Visit else branch
      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
      StmtExprVisitor::VisitStmt(op->else_case.value());
    }
  }

  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      // Visit condition
      StmtExprVisitor::VisitExpr(op->args[0]);
      {
        // Visit then branch
        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
        StmtExprVisitor::VisitExpr(op->args[1]);
      }
      {
        // Visit else branch
        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
        StmtExprVisitor::VisitExpr(op->args[2]);
      }
      return;
    }
    StmtExprVisitor::VisitExpr_(op);
  }
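
  // For example (a hypothetical access): in `if_then_else(i < 10, A[i], 0f)`,
  // the then-branch access A[i] is evaluated under the extra bound i < 10, so
  // A can be narrowed to [0, 10) even if the loop i ranges over [0, 16).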

  void VisitStmt_(const BlockNode* op) final {
    // Step 0. Check there is no init part.
    ICHECK(!op->init.defined());
    // Step 1. Record and update current read/write region annotations
    std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
        cur_access_annotations;
    for (const BufferRegion& region : op->reads) {
      cur_access_annotations[region->buffer].push_back(region);
    }
    for (const BufferRegion& region : op->writes) {
      cur_access_annotations[region->buffer].push_back(region);
    }
    for (auto& p : cur_access_annotations) {
      auto& regions = access_annotations_[p.first];
      p.second.swap(regions);
    }
    // Step 2. Record relax position of ancestor_loops_ into buffer_var_in_scope_
    for (const Buffer& buffer : op->alloc_buffers) {
      buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, ancestor_loops_.size()));
    }
    // Step 3. Visit match buffers
    for (const MatchBufferRegion& region : op->match_buffers) {
      VisitBufferAccess(region->source);
    }
    // Step 4. Visit block body recursively
    StmtExprVisitor::VisitStmt_(op);
    // Step 5. Recover read/write region annotations
    for (auto& p : cur_access_annotations) {
      auto& regions = access_annotations_[p.first];
      if (p.second.empty()) {
        access_annotations_.erase(p.first);
      } else {
        regions.swap(p.second);
      }
    }
    // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
    for (const Buffer& buffer : op->alloc_buffers) {
      auto it = relaxed_accesses_.find(buffer);
      ICHECK(it != relaxed_accesses_.end())
          << buffer << " is allocated but not accessed within block scope";
      const NDIntSet& nd_int_set = it->second;
      buffer_access_region_[buffer] = SimplifyAndNarrowBufferRegionFromNDIntSet(
          nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
    }
  }

  void VisitStmt_(const BlockRealizeNode* op) final {
    PrimExpr cur_predicate = predicate_in_scope;
    predicate_in_scope = op->predicate;
    StmtExprVisitor::VisitStmt_(op);
    predicate_in_scope = cur_predicate;
  }

  /**************** Helper functions ****************/

  void VisitBufferAccess(const BufferRegion& buffer_region) {
    const BufferNode* buffer = buffer_region->buffer.get();
    auto it = buffer_var_in_scope_.find(buffer->data);
    if (it != buffer_var_in_scope_.end()) {
      const Buffer& buffer = it->second.first;
      size_t n_ancestor_loops = it->second.second;
      // Step 1. Stop the ancestor loop vars that are outside the allocation
      // block from being relaxed, unless NeedRelaxThread() is true.
      std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
      for (size_t i = 0; i < n_ancestor_loops; ++i) {
        const ForNode* loop = ancestor_loops_[i];
        const VarNode* v = loop->loop_var.get();
        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
          continue;
        }
        auto dom_it = dom_map_.find(v);
        ICHECK(dom_it != dom_map_.end())
            << "Could not find domain for loop variable " << v->name_hint;
        non_relaxed[i] = dom_it->second;
        dom_map_.erase(dom_it);
      }
      // Step 2. Relax the access region
      NDIntSet nd_int_set =
          NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
      // Step 3. Restore the non-relaxed ancestor loops domain
      for (size_t i = 0; i < n_ancestor_loops; ++i) {
        const VarNode* v = ancestor_loops_[i]->loop_var.get();
        dom_map_.emplace(v, non_relaxed[i]);
      }
      // Step 4. Update relaxed_accesses_ dict
      auto access_it = relaxed_accesses_.find(buffer);
      if (access_it != relaxed_accesses_.end()) {
        support::NDIntSetUnionWith(&access_it->second, nd_int_set);
      } else {
        relaxed_accesses_.insert(access_it, {buffer, nd_int_set});
      }
    }
  }
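
  // A sketch of the effect: suppose B is allocated under loop i
  // (n_ancestor_loops = 1) and is accessed at B[i, j] with j in [0, 16).
  // Loop i is *not* relaxed, so the relaxed region is [i, i + 1) x [0, 16),
  // i.e. one row of B per iteration of i.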

  void VisitBufferVar(const Var& var) {
    auto it = buffer_var_in_scope_.find(var);
    if (it != buffer_var_in_scope_.end()) {
      const Buffer& buffer = it->second.first;
      auto annotation_it = access_annotations_.find(buffer);
      if (annotation_it != access_annotations_.end()) {
        // The opaque buffer has explicitly annotated access regions.
        for (const BufferRegion& region : annotation_it->second) {
          VisitBufferAccess(region);
        }
      } else {
        VisitBufferAccess(BufferRegion::FullRegion(buffer));
      }
    }
  }

  /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */
  static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& scope) {
    if (loop->kind != ForKind::kThreadBinding) {
      return false;
    }
    ICHECK(loop->thread_binding.defined());
    const String& thread_tag = loop->thread_binding.value()->thread_tag;
    // When there is warp memory,
    // threadIdx.x must be bound to the warp index.
    return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create(thread_tag));
  }
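
  // For example: a loop bound to threadIdx.x must stay relaxed for a "shared"
  // scope buffer (one allocation serves every thread in the block), but not
  // for a "local" scope buffer that is private to each thread.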

  /**************** Class members ****************/
  /*! \brief The loops from the current node up to the root. */
  std::vector<const ForNode*> ancestor_loops_;

  /*!
   * \brief The vars of the buffers allocated under the current block.
   * Maps each buffer var to a (buffer_obj, n_ancestor_loops) pair, where
   * n_ancestor_loops is the number of loops outside the current block.
   * ancestor_loops_[0: n_ancestor_loops] should not be relaxed when
   * we evaluate this buffer's access regions.
   */
  std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
      buffer_var_in_scope_;
  /*! \brief The block predicate of the current scope. */
  PrimExpr predicate_in_scope{true};

  /*! \brief The map from loop vars to their iter range. */
  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
  /*! \brief Extra map from free vars to their iter range hints. */
  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
  /*! \brief The analyzer aware of loop domains. */
  arith::Analyzer dom_analyzer_;
  /*! \brief The map from Buffer to its relaxed access set. */
  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
  /*! \brief The map from Buffer to its entire access region, used as the result. */
  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
  /*! \brief The map from Buffer to its access regions annotated by the current block. */
  std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
      access_annotations_;
};

/*! \brief Collect storage alignment information from block annotations. */
class StorageAlignCollector : public StmtVisitor {
 public:
  static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> Collect(
      const PrimFunc& f) {
    StorageAlignCollector collector;
    collector(f->body);
    return std::move(collector.storage_align_);
  }

 private:
  void VisitStmt_(const BlockNode* op) final {
    auto it = op->annotations.find(attr::buffer_dim_align);
    if (it != op->annotations.end()) {
      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
      for (const auto& storage_align_tuple : storage_align_annotation) {
        int buffer_index = storage_align_tuple[0]->value;
        const Buffer& buffer = op->writes[buffer_index]->buffer;
        storage_align_[buffer].push_back(storage_align_tuple);
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

  /*! \brief The map from Buffer to its storage alignment information. */
  std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> storage_align_;
};

/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
 public:
  static Stmt Compact(
      const PrimFunc& f,
      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
      const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
          storage_align) {
    std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;

    for (const auto& kv : regions) {
      const Buffer& buffer = kv.first;
      Region region = kv.second;
      BufferAllocInfo buffer_alloc_info(std::move(region));
      auto it = storage_align.find(buffer);
      if (it != storage_align.end()) {
        std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
        for (const StorageAlignTuple& dim_align : (*it).second) {
          ICHECK(dim_align.size() == 4);
          int dim = dim_align[1]->value;
          int factor = dim_align[2]->value;
          int offset = dim_align[3]->value;
          dim_aligns.at(dim) = {factor, offset};
        }
        buffer_alloc_info.dim_aligns = std::move(dim_aligns);
      }
      buffer_info.emplace(buffer, std::move(buffer_alloc_info));
    }
    BufferCompactor compactor(std::move(buffer_info));
    Stmt stmt = compactor(f->body);
    return stmt;
  }

 private:
  /*! \brief The storage alignment for a dimension */
  struct DimAlignInfo {
    /*! \brief The factor of the alignment */
    int align_factor{0};
    /*! \brief The offset of the alignment */
    int align_offset{0};
  };

  struct BufferAllocInfo {
    /*! \brief The buffer access region. */
    Region region;
    /*! \brief The storage alignment information. */
    std::vector<DimAlignInfo> dim_aligns;
    /*!
     * \brief The reallocated buffer with minimal size.
     * \note The value remains undefined if the buffer does not need to be
     * reallocated (e.g. a parameter buffer).
     */
    Buffer new_buffer;

    explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
  };

  explicit BufferCompactor(
      std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
      : buffer_info_(std::move(buffer_info)) {}

  Stmt VisitStmt_(const BufferStoreNode* _op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
    BufferStoreNode* op = store.CopyOnWrite();
    RewriteBufferAccess(&op->buffer, &op->indices);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
    BufferLoadNode* op = load.CopyOnWrite();
    RewriteBufferAccess(&op->buffer, &op->indices);
    return std::move(load);
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    // Step 0. Check there is no Init part.
    ICHECK(!op->init.defined());
    // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo.
    Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
    // Step 2. Recursively rewrite BufferLoad/BufferStore.
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    // Step 3. Update block signature.
    BlockNode* n = block.CopyOnWrite();
    RewriteBufferRegions(&n->reads);
    RewriteBufferRegions(&n->writes);
    RewriteMatchBuffers(&n->match_buffers);
    n->alloc_buffers = std::move(alloc_buffers);
    return std::move(block);
  }

  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
    Array<Buffer> result;
    result.reserve(buffers.size());
    for (const Buffer& buffer : buffers) {
      auto it = buffer_info_.find(buffer);
      ICHECK(it != buffer_info_.end());
      BufferAllocInfo& info = it->second;
      Array<PrimExpr> shape;
      shape.reserve(info.region.size());
      for (const Range& range : info.region) {
        shape.push_back(range->extent);
      }
      Array<PrimExpr> strides;
      if (info.dim_aligns.size()) {
        ICHECK(info.dim_aligns.size() == shape.size());
        strides.resize(shape.size());
        PrimExpr stride = make_const(shape[0].dtype(), 1);
        for (size_t i = shape.size(); i != 0; --i) {
          size_t dim = i - 1;
          if (info.dim_aligns[dim].align_factor != 0) {
            PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor);
            PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset);
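            // Pad the stride so that the new stride satisfies
            // stride % factor == offset (assuming offset < factor). E.g. with
            // stride = 16, factor = 32, offset = 8, the padding is
            // (32 + 8 - 16) % 32 = 24, giving a new stride of 40 (40 % 32 == 8).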
            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
          }
          strides.Set(dim, stride);
          stride = stride * shape[dim];
        }
      }
      ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
      n->shape = std::move(shape);
      n->strides = std::move(strides);
      info.new_buffer = Buffer(std::move(n));
      result.push_back(info.new_buffer);
    }
    return result;
  }

  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
    auto it = buffer_info_.find(*buffer);
    if (it == buffer_info_.end()) {
      // Skip if the buffer is a parameter buffer.
      return;
    }
    const BufferAllocInfo& info = it->second;
    ICHECK_EQ(indices->size(), info.region.size());
    int ndim = info.region.size();
    Array<PrimExpr> new_indices;
    new_indices.reserve(ndim);
    for (int i = 0; i < ndim; ++i) {
      new_indices.push_back((*indices)[i] - info.region[i]->min);
    }
    *buffer = info.new_buffer;
    *indices = std::move(new_indices);
  }
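
  // A sketch of the rewrite: if B was compacted to the region
  // [i, i + 1) x [0, 16), the access B[i, j] becomes B[i - i, j - 0] on the
  // new [1, 16] buffer; later simplification folds the indices to B[0, j].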

  void RewriteBufferRegion(Buffer* buffer, Region* region) const {
    auto it = buffer_info_.find(*buffer);
    if (it == buffer_info_.end()) {
      // Skip if the buffer is a parameter buffer.
      return;
    }
    const BufferAllocInfo& info = it->second;
    ICHECK_EQ(region->size(), info.region.size());
    Region new_region;
    new_region.reserve(info.region.size());
    for (size_t i = 0; i < info.region.size(); ++i) {
      const Range& range = (*region)[i];
      new_region.push_back(Range::FromMinExtent(range->min - info.region[i]->min, range->extent));
    }
    *buffer = info.new_buffer;
    *region = std::move(new_region);
  }

  void RewriteBufferRegions(Array<BufferRegion>* regions) const {
    Array<BufferRegion> new_regions;
    new_regions.reserve(regions->size());
    for (const auto& region : *regions) {
      BufferRegion buffer_region = region;
      BufferRegionNode* p = buffer_region.CopyOnWrite();
      RewriteBufferRegion(&p->buffer, &p->region);
      new_regions.push_back(buffer_region);
    }
    *regions = std::move(new_regions);
  }

  void RewriteMatchBuffers(Array<MatchBufferRegion>* match_buffers) const {
    Array<MatchBufferRegion> result;
    result.reserve(match_buffers->size());
    for (const auto& match_buffer : *match_buffers) {
      const BufferRegion& buffer_region = match_buffer->source;
      auto p = make_object<BufferRegionNode>(*buffer_region.get());
      RewriteBufferRegion(&p->buffer, &p->region);
      result.push_back(MatchBufferRegion(match_buffer->buffer, BufferRegion(p)));
    }
    *match_buffers = std::move(result);
  }

  /*! \brief The allocation information about each buffer. */
  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
};

PrimFunc CompactBufferAllocation(PrimFunc f) {
  // Only apply this pass to TIR that is not from TE schedules
  if (!IsFromLegacyTESchedule(f)) {
    PrimFuncNode* fptr = f.CopyOnWrite();
    std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
        BufferAccessRegionCollector::Collect(f);
    std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
        storage_align = StorageAlignCollector::Collect(f);
    fptr->body = BufferCompactor::Compact(f, region, storage_align);
    return f;
  } else {
    return f;
  }
}
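
// A usage sketch from the Python side (assuming a module `mod` containing TIR
// PrimFuncs):
//   mod = tvm.tir.transform.CompactBufferAllocation()(mod)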

namespace transform {

Pass CompactBufferAllocation() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return CompactBufferAllocation(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
}

TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation")
    .set_body_typed(CompactBufferAllocation);
}  // namespace transform

}  // namespace tir
}  // namespace tvm