1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <unordered_set>
21
22#include "../utils.h"
23
24namespace tvm {
25namespace tir {
26
27/******** Error Classes ********/
28
29class NotSingleWriteBlock : public ScheduleError {
30 public:
31 explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array<StmtSRef> write_blocks)
32 : mod_(std::move(mod)), buffer_(std::move(buffer)) {
33 ICHECK_GT(write_blocks.size(), 1);
34 write_blocks_.reserve(write_blocks.size());
35 for (const StmtSRef& block_sref : write_blocks) {
36 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
37 write_blocks_.push_back(GetRef<Block>(block));
38 }
39 }
40
41 String FastErrorString() const final {
42 return "ScheduleError: The buffer is allowed to be written by single block.";
43 }
44
45 String DetailRenderTemplate() const final {
46 size_t k = write_blocks_.size();
47 return "The buffer " + buffer_->name + " is expected to be written by single block, but got " +
48 std::to_string(k) + " blocks who write it.";
49 }
50
51 IRModule mod() const final { return mod_; }
52 Array<ObjectRef> LocationsOfInterest() const final {
53 return {write_blocks_.begin(), write_blocks_.end()};
54 }
55
56 private:
57 IRModule mod_;
58 Buffer buffer_;
59 Array<Block> write_blocks_;
60};
61
62/******** Helper Functions/Classes ********/
63
64/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
65struct CacheStageInfo {
66 /*! \brief The buffer to be read. */
67 Buffer read_buffer;
68 /*! \brief The buffer to be written. */
69 Buffer write_buffer;
70 /*! \brief The buffer allocation to be inserted into the block signature. */
71 Optional<Buffer> alloc;
72 /*! \brief The AST node whose body is where the cache stage should be inserted. */
73 StmtSRef loc_sref;
74 /*! \brief The index to insert the cache_read/cache_write stage. */
75 size_t loc_pos;
76 /*! \brief The cache_read/cache_write stage to be inserted. */
77 Stmt cache_stage;
78 /*! \brief The map used for ScheduleStateNode::Replace. */
79 Map<Block, Block> block_reuse;
80 /*! \brief A set of blocks that will consume the new cache. */
81 std::unordered_set<StmtSRef, ObjectHash, ObjectEqual> consumer_blocks;
82};
83
/*! \brief Return the buffer region related to the given buffer */
85Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buffer_regions,
86 const Buffer& buffer) {
87 Optional<BufferRegion> res = NullOpt;
88 for (const auto& region : buffer_regions) {
89 if (region->buffer.same_as(buffer)) {
90 ICHECK(!res.defined());
91 res = region;
92 }
93 }
94 return res;
95}
96
97/*!
98 * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer
99 * to write buffer.
 * \note This function stores the complete loop nest into the CacheStageInfo, and only returns the
 * innermost block.
102 * \param cache_region The cached copy region.
103 * \param info The cache stage information, which will be updated in the function.
104 * \param storage_scope The storage scope of the cached buffer (only used in naming here)
105 * \returns A block indicating the body of the loop nesting.
106 */
107Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
108 const String& storage_scope) {
109 // loop variables
110 std::vector<Var> loop_vars;
111 // bindings in block realize
112 std::vector<PrimExpr> iter_values;
113 // Create loop vars and block vars' binding_value
114 for (const Range& axis_range : cache_region->region) {
115 Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
116 loop_vars.push_back(loop_var);
117 iter_values.push_back(axis_range->min + loop_var);
118 }
119 // block variables
120 Array<IterVar> block_vars;
121 // block access region for read/write buffers
122 Region access_region;
123 // indices used in block body
124 Array<PrimExpr> access_indices;
125 // Create block vars, block's accessed region and accessing indices
126 for (const PrimExpr& dim : cache_region->buffer->shape) {
127 Var var("v" + std::to_string(access_indices.size()), dim.dtype());
128 block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
129 /*var=*/var,
130 /*IterVarType=*/kDataPar));
131 access_indices.push_back(var);
132 access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
133 }
134
135 // Create the body block:
136 // reads = [read_buffer[access_region]]
137 // writes = [write_buffer[access_region]]
138 // write_buffer[access_indices] = read_buffer[access_indices]
139 Block block(
140 /*iter_vars=*/std::move(block_vars),
141 /*reads=*/{BufferRegion(info->read_buffer, access_region)},
142 /*writes=*/{BufferRegion(info->write_buffer, access_region)},
143 /*name_hint=*/cache_region->buffer->name + "_" + storage_scope,
144 /*body=*/
145 BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices),
146 access_indices),
147 /*init=*/NullOpt,
148 /*alloc_buffers=*/{},
149 /*match_buffers=*/{},
150 /*annotations=*/{});
151 // Create the block realize node
152 Stmt body = BlockRealize(/*values=*/iter_values,
153 /*predicate=*/const_true(),
154 /*block=*/block);
155 // Create surrounding loops
156 for (size_t i = loop_vars.size(); i >= 1; --i) {
157 body = For(/*loop_var=*/loop_vars[i - 1],
158 /*min=*/0,
159 /*extent=*/cache_region->region[i - 1]->extent,
160 /*kind=*/ForKind::kSerial,
161 /*body=*/body);
162 }
163 info->cache_stage = std::move(body);
164 return block;
165}
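
// For illustration only: assuming a hypothetical 128x128 buffer A, storage scope "shared", and
// cache_region A[0:16, 8:24], the stage built by MakeCacheStage above is roughly the following
// (TVMScript-like sketch, simplified):
//
//   for ax0, ax1 in T.grid(16, 16):
//       with T.block("A_shared"):
//           v0 = T.axis.spatial(128, ax0)
//           v1 = T.axis.spatial(128, 8 + ax1)
//           A_shared[v0, v1] = A[v0, v1]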
166
167/*!
168 * \brief Create the reindex block and generate the corresponding outer loops.
 * \details The reindex block is a data copy block between the reindex buffer (the intermediate
 * buffer) and the target buffer.
 * If buffer_index_type == kWrite, it copies from the reindex buffer to the target buffer.
 * If buffer_index_type == kRead, it copies from the target buffer to the reindex buffer.
 * The reindex block has the same block iters and surrounding loops as the input block.
 * However, if a block iter is not used in the indices of the target buffer being reindexed, that
 * block iter and its corresponding outer loop are omitted from the reindex block.
177 * \param block The block to be reindexed
178 * \param info The cache info
179 * \param covered The set of block iter vars covered in the buffer access indices
180 * \param original_indices The original buffer access indices
181 * \param buffer_index The index of the target buffer
182 * \param buffer_index_type The type of buffer index
183 * \return The reindex block.
184 */
185Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
186 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered,
187 const Array<PrimExpr>& original_indices, int buffer_index,
188 BufferIndexType buffer_index_type) {
189 // iters of the reindex block
190 Array<IterVar> new_block_iters;
  // the substitution map from the original block iters to the iters of the reindex block
192 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectEqual> block_var_replace_map;
193 // indices to access the reindex buffer and the target buffer
194 Array<PrimExpr> reindex_indices, target_indices;
195
196 // Step 1: Create block iters, access regions of the reindex block, and accessing indices to the
197 // reindex buffer.
198 std::unordered_set<int> skipped_block_iters;
199 for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
200 const IterVar& iter = block->iter_vars[i];
201 Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype);
202 bool used = covered.count(iter->var);
203 if (used) {
204 new_block_iters.push_back(IterVar(/*dom=*/iter->dom,
205 /*var=*/var,
206 /*IterVarType=*/kDataPar));
207 } else {
208 skipped_block_iters.insert(i);
209 }
210 if (used) {
211 reindex_indices.push_back(var);
212 }
213 block_var_replace_map[iter->var] = var;
214 }
215
216 // Step 2: Replace the original block iters with the new block iters
217 for (const PrimExpr& index : original_indices) {
218 target_indices.push_back(Substitute(index, block_var_replace_map));
219 }
220
221 // Step 3: Create the reindex block
222
223 // The src and the dst region and indices of the data copy
224 Region src_region{nullptr};
225 Region dst_region{nullptr};
226 Array<PrimExpr> src_indices{nullptr};
227 Array<PrimExpr> dst_indices{nullptr};
228
229 if (buffer_index_type == BufferIndexType::kWrite) {
230 src_indices = reindex_indices;
231 dst_indices = target_indices;
232 } else {
233 src_indices = target_indices;
234 dst_indices = reindex_indices;
235 }
236
237 // Create the body block
238 Block new_block(
239 /*iter_vars=*/new_block_iters,
240 /*reads=*/{BufferRegion::FromPoint(info->read_buffer, src_indices)},
241 /*writes=*/{BufferRegion::FromPoint(info->write_buffer, dst_indices)},
242 /*name_hint=*/info->write_buffer->name + "_reindex",
243 /*body=*/
244 BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices));
245
246 // Step 4: Create surrounding loops
247
248 // Create loop vars and bindings for block iters
249 std::vector<Var> loop_vars; // loop variables
250 std::vector<PrimExpr> iter_values; // bindings in block realize
251 for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
252 if (skipped_block_iters.count(i)) {
253 continue;
254 }
255 Var loop_var("ax" + std::to_string(loop_vars.size()), block->iter_vars[i]->var->dtype);
256 loop_vars.push_back(loop_var);
257 iter_values.push_back(loop_var);
258 }
259
260 // Create the block realize node
261 Stmt body = BlockRealize(/*values=*/iter_values,
262 /*predicate=*/const_true(),
263 /*block=*/new_block);
264
265 // Create the chain of loops
266 for (int i = static_cast<int>(new_block_iters.size()) - 1; i >= 0; --i) {
267 body = For(/*loop_var=*/loop_vars[i],
268 /*min=*/new_block_iters[i]->dom->min,
269 /*extent=*/new_block_iters[i]->dom->extent,
270 /*kind=*/ForKind::kSerial,
271 /*body=*/std::move(body));
272 }
273 // Update cache info, which will be used in the later rewriting.
274 info->cache_stage = std::move(body);
275 return new_block;
276}
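
// For illustration only (hypothetical names, TVMScript-like sketch, simplified): for a block with
// iters (vi, vj, vk) whose write indices are C[vi, vj], reindexing the write buffer builds roughly
//
//   for ax0, ax1 in T.grid(128, 128):
//       with T.block("C_reindex"):
//           v0, v1 = T.axis.remap("SS", [ax0, ax1])
//           C[v0, v1] = C_reindex[v0, v1]
//
// The reduction iter vk is not covered by the access indices, so it does not appear in the copy stage.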
277
278/*!
 * \brief Recalculate the `affine_binding` flag of a specific block
 * \param self The state of the schedule
 * \param block_sref The sref to the specific block
 * \return Whether the binding of the block is affine
281 */
282bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) {
283 if (block_sref->parent == nullptr) {
284 return true;
285 }
286 arith::Analyzer analyzer;
287 StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
288 return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref),
289 /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
290 /*analyzer=*/&analyzer);
291}
292
293/*!
294 * \brief Insert the cache_read/cache_write stage into the specific position
 * \param stmt A sequence of statements, or a single statement, into which the new stage is inserted
296 * \param pos The position where the cache stage is inserted
297 * \param stage The stage to be inserted
298 * \return A SeqStmt, the result after insertion
299 */
300Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
301 if (const auto* alloc = stmt.as<AllocateConstNode>()) {
302 auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
303 return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
304 alloc->annotations, alloc->span);
305 }
306 if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
307 ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
308 result->seq.insert(result->seq.begin() + pos, stage);
309 return SeqStmt(result);
310 }
311 if (pos == 0) {
312 return SeqStmt({stage, stmt});
313 }
314 ICHECK_EQ(pos, 1);
315 return SeqStmt({stmt, stage});
316}
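
// A few illustrative cases for InsertCacheStage, with hypothetical statements S0/S1 and `stage`:
//   stmt = SeqStmt([S0, S1]), pos = 1  ->  SeqStmt([S0, stage, S1])
//   stmt = S0 (not a SeqStmt), pos = 0 ->  SeqStmt([stage, S0])
//   stmt = S0 (not a SeqStmt), pos = 1 ->  SeqStmt([S0, stage])
// If `stmt` is wrapped in AllocateConst nodes, the insertion happens in the wrapped body and the
// AllocateConst wrapper is rebuilt around the result.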
317
318/*!
319 * \brief Get the only writer block of the input buffer in a given scope block.
320 * \param self The state of the schedule
321 * \param scope_sref The scope block where the write is considered
322 * \param buffer The queried buffer
323 * \return The sref of the only writer of the input buffer in the given scope,
324 * or `NullOpt` if no block writes it in the scope.
 * \throw NotSingleWriteBlock if more than one block writes the buffer in the scope.
326 */
327Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref,
328 const Buffer& buffer) {
329 BlockScope scope = self->GetBlockScope(scope_sref);
330 auto it = scope->buffer_writers.find(buffer);
331 if (it == scope->buffer_writers.end()) {
332 return NullOpt;
333 } else {
334 const Array<StmtSRef>& block_srefs = it->second;
335 ICHECK(!block_srefs.empty());
336 if (block_srefs.size() > 1) {
337 throw NotSingleWriteBlock(self->mod, buffer, block_srefs);
338 }
339 return block_srefs[0];
340 }
341}
342
343/*!
344 * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
345 * \param self The state of the schedule.
346 * \param buffer_region The buffer region to be analyzed.
347 * \param block_sref The sref of the block related to the region.
348 * \param dom_low_inclusive The lowest node in the sref tree path.
349 * \param dom_high_exclusive The highest node in the sref tree path.
350 * \return The relaxed buffer region.
351 */
352BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
353 const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
354 const StmtSRef& dom_high_exclusive) {
355 BlockRealize realize = GetBlockRealize(self, block_sref);
356 Map<Var, PrimExpr> binding = GetBindings(realize);
357 const Buffer& buffer = buffer_region->buffer;
358 arith::Analyzer analyzer;
359 BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding));
360 Array<arith::IntSet> int_sets = AnalyzeRegionUpperBound(
361 /*region=*/subst_region,
362 /*predicate=*/realize->predicate,
363 /*dom_low_inclusive=*/dom_low_inclusive,
364 /*dom_high_exclusive=*/dom_high_exclusive,
365 /*analyzer=*/&analyzer);
366 ICHECK_EQ(buffer_region->region.size(), int_sets.size());
367
368 Region region;
369 region.reserve(int_sets.size());
370 for (size_t i = 0; i < int_sets.size(); ++i) {
371 region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
372 }
373 return BufferRegion(buffer, region);
374}
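
// A rough illustration (hypothetical loops and shapes): suppose a block under loops i0 (extent 32)
// and i1 (extent 4) binds vi = i0 * 4 + i1 and accesses B[vi, 0:16]. After substituting the
// binding, the region is B[i0 * 4 + i1, 0:16]; relaxing only i1 gives B[i0 * 4 : i0 * 4 + 4, 0:16],
// while relaxing both i0 and i1 gives B[0:128, 0:16]. Which loop vars get relaxed is determined by
// the sref tree path [dom_low_inclusive, dom_high_exclusive).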
375
376/*! \brief Detect the insertion position of the new cache stage */
377class CacheLocDetector : public StmtVisitor {
378 public:
379 /*!
   * \brief Detect the insertion position of the cache stage, and write the position into the
   * CacheStageInfo
   * \param self The state of the schedule
   * \param block_sref The sref of the unique writer block of the buffer to which cache_read or
   * cache_write is applied
   * \param scope_sref The sref of the scope block of the cached block
   * \param info The cache stage info
384 */
385 template <bool is_cache_read>
386 static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
387 const StmtSRef& scope_sref, CacheStageInfo* info) {
388 std::vector<StmtSRef> related_blocks;
389 // If consumer is specified, skip detecting the others
390 if (is_cache_read) {
391 if (info->consumer_blocks.size() > 0) {
392 for (StmtSRef consumer : info->consumer_blocks) {
393 related_blocks.emplace_back(consumer);
394 }
395 } else {
396 for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
397 if (def->kind == DepKind::kRAW) {
398 related_blocks.push_back(def->dst);
399 }
400 }
401 }
402 } else {
403 for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
404 if (def->kind == DepKind::kRAW) {
405 if (info->consumer_blocks.count(def->dst)) {
406 continue;
407 }
408 related_blocks.push_back(def->dst);
409 }
410 }
411 }
412
413 if (!related_blocks.empty()) {
414 CacheLocDetector detector(self, block_sref, scope_sref, related_blocks);
415 detector(GetRef<Stmt>(scope_sref->stmt));
416 info->loc_sref = detector.loc_sref_;
417 info->loc_pos = detector.loc_pos_;
418 } else {
419 info->loc_sref = scope_sref;
420
421 auto block_body = scope_sref->StmtAs<BlockNode>()->body;
422 // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
423 while (block_body->IsInstance<AllocateConstNode>()) {
424 block_body = block_body.as<AllocateConstNode>()->body;
425 }
426 const auto* body = block_body.as<SeqStmtNode>();
427 info->loc_pos = body == nullptr ? 1 : body->size();
428 }
429 }
430
431 private:
432 /*!
433 * \brief Constructor
434 * \param self The state of the schedule
   * \param block_sref The sref of the unique writer block of the buffer to which cache_read or
   * cache_write is applied
   * \param scope_sref The sref of the scope block of the cached block
   * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read
438 */
439 CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref,
440 const std::vector<StmtSRef>& related_blocks)
441 : self_(self),
442 block_sref_(block_sref),
443 scope_sref_(scope_sref),
444 related_blocks_(related_blocks) {}
445
446 void VisitStmt_(const SeqStmtNode* seq_stmt) final {
447 bool previous_visited_block = visited_block_;
448 visited_block_ = false;
449
450 for (size_t i = 0; i < seq_stmt->size(); ++i) {
451 if (loc_pos_ != -1) {
452 break;
453 }
454 VisitStmt(seq_stmt->seq[i]);
      // `loc_pos_` can be assigned only once, after `block_sref` has been visited
      if (visited_block_ && visited_related_ && loc_pos_ == -1) {
        // Insert right before the first element that contains a related block
        loc_pos_ = i;
459 break;
460 } else if (visited_related_) {
        // A related block appeared before the writer block, so stop searching
462 break;
463 }
464 }
465 visited_block_ = visited_block_ || previous_visited_block;
466 }
467
468 void VisitStmt_(const BlockNode* block) final {
469 // Only visit the current scope under buffer writer's parent block
470 if (block == scope_sref_->stmt) {
      // The visited block is the current parent scope
472 StmtVisitor::VisitStmt_(block);
      // Handle the case where the insertion is outside any loop, or cache_read of an input buffer
474 if (visited_related_ && !loc_sref_.defined()) {
475 loc_sref_ = self_->stmt2ref.at(block);
476 // Handling cache_read for input buffer
477 if (visited_block_ == false && loc_pos_ == -1) {
478 loc_pos_ = 0;
479 }
480 }
481 return;
482 }
483 // Update `visited_block`
484 if (block_sref_->stmt == block) {
485 visited_block_ = true;
486 return;
487 }
488 // Update `visited_related`
489 for (const StmtSRef& related_block : related_blocks_) {
490 if (related_block->stmt == block) {
491 visited_related_ = true;
492 return;
493 }
494 }
495 }
496
497 void VisitStmt_(const ForNode* loop) final {
498 StmtVisitor::VisitStmt_(loop);
499 if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_ != -1) {
500 loc_sref_ = self_->stmt2ref.at(loop);
501 }
502 }
503
504 private:
  /*! \brief The schedule state */
  const ScheduleState self_;
  /*! \brief The dominant block which writes the buffer */
  const StmtSRef& block_sref_;
  /*! \brief The parent scope of the dominant block */
  const StmtSRef& scope_sref_;
  /*! \brief Producer blocks for cache_write and consumer blocks for cache_read */
  const std::vector<StmtSRef>& related_blocks_;
  /*! \brief The flag indicating whether we have visited the dominant block */
  bool visited_block_{false};
  /*! \brief The flag indicating whether we have visited at least one related block */
  bool visited_related_{false};
517 /*! \brief The AST node whose body is where the cache stage should be inserted */
518 StmtSRef loc_sref_{nullptr};
519 /*! \brief The index to insert the cache_read/cache_write stage */
520 int loc_pos_{-1};
521};
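
// A rough illustration with hypothetical blocks: for a scope whose body is
//   SeqStmt([producer, consumer0, consumer1])
// a cache_read of the buffer written by `producer` detects loc_pos = 1, i.e. the cache stage is
// inserted right after `producer` and before its first RAW consumer. If consumer_blocks is
// restricted to {consumer1}, only consumer1 is treated as related, so detection yields loc_pos = 2
// and the stage is inserted just before consumer1.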
522
523/*! \brief Detect the insertion position of the new cache stage */
524class CacheInplaceLocDetector : public StmtVisitor {
525 public:
526 /*!
   * \brief Detect the insertion position of the cache stage, and write the position into the
   * CacheStageInfo
   * \param self The state of the schedule
   * \param block_sref The sref of the unique block of the buffer to which cache_inplace is applied
   * \param scope_sref The sref of the scope block of the cached block
   * \param info The cache stage info
531 */
532 static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
533 const StmtSRef& scope_sref, CacheStageInfo* info) {
534 CacheInplaceLocDetector detector(self, block_sref, scope_sref);
535 detector(GetRef<Stmt>(scope_sref->stmt));
536 info->loc_sref = detector.loc_sref_;
537 info->loc_pos = detector.loc_pos_;
538 }
539
540 private:
541 /*!
542 * \brief Constructor
543 * \param self The state of the schedule
544 * \param block_sref The sref of the unique writer block of the buffer being applied cache_inplace
545 * \param scope_sref The sref of the scope block of the cached block
546 */
547 CacheInplaceLocDetector(const ScheduleState self, const StmtSRef& block_sref,
548 const StmtSRef& scope_sref)
549 : self_(self), block_sref_(block_sref), scope_sref_(scope_sref) {}
550
551 void VisitStmt_(const SeqStmtNode* seq_stmt) final {
552 for (size_t i = 0; i < seq_stmt->size(); ++i) {
553 if (loc_pos_ != -1) {
554 break;
555 }
556 VisitStmt(seq_stmt->seq[i]);
      // `loc_pos_` can be assigned only once, after `block_sref` has been visited
      if (visited_block_ && loc_pos_ == -1) {
        // Insert right before the element that contains the target block
        loc_pos_ = i;
561 return;
562 }
563 }
564 }
565
566 void VisitStmt_(const BlockNode* block) final {
567 // Only visit the current scope under buffer writer's parent block
568 if (block == scope_sref_->stmt) {
      // The visited block is the current parent scope
      StmtVisitor::VisitStmt_(block);
      // Handle the case where the insertion point is outside any loop
      if (visited_block_ && !loc_sref_.defined()) {
        loc_sref_ = self_->stmt2ref.at(block);
        // Handle the input buffer case
575 if (loc_pos_ == -1) {
576 loc_pos_ = 0;
577 }
578 }
579 } else if (block_sref_->stmt == block) {
580 visited_block_ = true;
581 }
582 }
583
584 void VisitStmt_(const ForNode* loop) final {
585 StmtVisitor::VisitStmt_(loop);
586 if (visited_block_ && !loc_sref_.defined()) {
587 loc_sref_ = self_->stmt2ref.at(loop);
588 if (loc_pos_ == -1) {
589 loc_pos_ = 0;
590 }
591 }
592 }
593
594 private:
  /*! \brief The schedule state */
  const ScheduleState self_;
  /*! \brief The dominant block which writes the buffer */
  const StmtSRef& block_sref_;
  /*! \brief The parent scope of the dominant block */
  const StmtSRef& scope_sref_;
  /*! \brief The flag indicating whether we have visited the target block */
602 bool visited_block_{false};
603 /*! \brief The AST node whose body is where the cache stage should be inserted */
604 StmtSRef loc_sref_{nullptr};
605 /*! \brief The index to insert the cache_read/cache_write stage */
606 int loc_pos_{-1};
607};
608
609/*! \brief Mutator for CacheRead. */
610class CacheReadRewriter : public StmtExprMutator {
611 public:
612 /*!
613 * \brief Rewrite the AST and add a cache_read stage with the information provided
614 * \param scope_sref The parent scope of this mutation
615 * \param info The cache stage information
616 * \return The new AST rooting at the original parent scope
617 */
618 static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) {
619 CacheReadRewriter rewriter(scope_sref, info);
620 return rewriter(GetRef<Stmt>(scope_sref->stmt));
621 }
622
623 private:
624 explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info)
625 : scope_sref_(scope_sref), info_(info) {}
626
627 Stmt VisitStmt_(const ForNode* loop) final {
628 Stmt stmt = StmtMutator::VisitStmt_(loop);
629 // Check the insertion point
630 if (loop == info_->loc_sref->stmt) {
631 // Insert cache stage into the loop if it is the right place
632 ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
633 n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
634 stmt = Stmt(n);
635 }
636 return stmt;
637 }
638
639 Stmt VisitStmt_(const BlockNode* block) final {
640 Block old_stmt = GetRef<Block>(block);
641 // Check if this block is one of the specified consumers.
642 // If no consumer blocks are specified, all blocks should be considered consumers.
643 bool is_consumer = info_->consumer_blocks.empty();
644 // Otherwise check if this is one of the specified blocks.
645 for (StmtSRef consumer_sref : info_->consumer_blocks) {
646 const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
647 Block consumer_block = GetRef<Block>(consumer_node);
648 if (old_stmt.same_as(consumer_block)) {
649 is_consumer = true;
650 }
651 }
    // Keep track of this block's status; it is used when rewriting loads.
653 current_block_consumes = is_consumer;
654 // We don't mutate the block which generates info->read_buffer.
655 if (block != scope_sref_->stmt &&
656 GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
657 return std::move(old_stmt);
658 }
659 // Mutate the body
660 Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
661 // Check the insertion point
662 if (block == info_->loc_sref->stmt) {
663 // Insert cache stage into the block if it is the right place
664 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
665 n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
666 stmt = Block(n);
667 }
668 // Check if it is the block corresponding to the parent scope
669 if (block == scope_sref_->stmt) {
670 // If so, put buffer allocation on the parent scope
671 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      // In the cache_inplace case, the alloc_buffer may already exist.
673 if (info_->alloc.defined()) {
674 n->alloc_buffers.push_back(info_->alloc.value());
675 stmt = Block(n);
676 }
677 } else {
678 // Otherwise, update read regions and match_buffers
679 // Only make this change if the block is one of the specified consumers.
680 if (is_consumer) {
681 Array<BufferRegion> reads =
682 ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
683 Array<MatchBufferRegion> match_buffers =
684 ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
685 if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
686 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
687 n->reads = std::move(reads);
688 n->match_buffers = std::move(match_buffers);
689 stmt = Block(n);
690 }
691 }
692 }
693 info_->block_reuse.Set(old_stmt, stmt);
694 return std::move(stmt);
695 }
696
697 PrimExpr VisitExpr_(const BufferLoadNode* load) final {
698 if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) {
699 ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
700 n->buffer = info_->write_buffer;
701 return PrimExpr(n);
702 }
703 return ExprMutator::VisitExpr_(load);
704 }
705
706 PrimExpr VisitExpr_(const LoadNode* op) final {
707 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
708 }
709
710 PrimExpr VisitExpr_(const VarNode* op) final {
711 if (op == info_->read_buffer->data.get()) {
712 return info_->write_buffer->data;
713 }
714 return GetRef<PrimExpr>(op);
715 }
716
717 private:
718 /*! \brief The parent scope of the insertion */
719 const StmtSRef& scope_sref_;
720 /*! \brief The info for inserting cache stage */
721 CacheStageInfo* info_;
722 /*! \brief Whether the most recently visited block is a specified consumer. */
723 bool current_block_consumes;
724};
725
726/*! \brief Mutator for CacheWrite */
727class CacheWriteRewriter : public StmtExprMutator {
728 public:
729 /*!
730 * \brief Rewrite the AST and add a cache_write stage with the information provided.
731 * \param scope_sref The parent scope of this mutation.
732 * \param writer_block_sref The only writer block in the scope.
733 * \param info The cache stage information.
734 * \return The new AST rooting at the original parent scope.
735 */
736 static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
737 CacheStageInfo* info) {
738 CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
739 return rewriter(GetRef<Stmt>(scope_sref->stmt));
740 }
741
742 private:
743 explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
744 CacheStageInfo* info)
745 : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {}
746
747 Stmt VisitStmt_(const ForNode* loop) final {
748 Stmt stmt = StmtMutator::VisitStmt_(loop);
749 // Check the insertion point
750 if (loop == info_->loc_sref->stmt) {
751 // Insert cache stage into the loop if it is the right place
752 ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
753 n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
754 stmt = Stmt(n);
755 }
756 return stmt;
757 }
758
759 Stmt VisitStmt_(const BlockNode* block) final {
760 Block old_stmt = GetRef<Block>(block);
761
    // Check if this block is one of the specified cache consumers.
    // If so, update its read buffer to the cache buffer.
764 for (StmtSRef consumer_sref : info_->consumer_blocks) {
765 const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
766 Block consumer_block = GetRef<Block>(consumer_node);
767 if (old_stmt.same_as(consumer_block)) {
768 Array<BufferRegion> reads =
769 ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
770 Array<MatchBufferRegion> match_buffers =
771 ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
772 if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
773 auto n = CopyOnWrite(block);
774 n->reads = std::move(reads);
775 n->match_buffers = std::move(match_buffers);
776 n->body = VisitStmt(block->body);
777 Block new_consumer = Block(n);
778 info_->block_reuse.Set(old_stmt, new_consumer);
779 return std::move(new_consumer);
780 }
781 return std::move(old_stmt);
782 }
783 }
784
785 // We only mutate the block which generates info->write_buffer
786 if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) {
787 return std::move(old_stmt);
788 }
789
790 // Mutate the body
791 bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt;
792 std::swap(under_scope, under_writer_block_);
793 Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
794 std::swap(under_scope, under_writer_block_);
795
796 // Find the insertion point
797 if (block == info_->loc_sref->stmt) {
798 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
799 n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
800 stmt = Block(n);
801 }
802 // Put buffer allocation on the parent scope
803 if (block == scope_sref_->stmt) {
804 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      // In the cache_inplace case, the alloc_buffer may already exist.
806 if (info_->alloc.defined()) {
807 n->alloc_buffers.push_back(info_->alloc.value());
808 stmt = Block(n);
809 }
810 } else {
811 // Since cache_write changes the block, we need to update the buffer it writes
812 auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer);
813 auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
814 auto match_buffers =
815 ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
816 if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
817 !match_buffers.same_as(block->match_buffers)) {
818 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
819 n->writes = std::move(writes);
820 n->reads = std::move(reads);
821 n->match_buffers = std::move(match_buffers);
822 stmt = Block(n);
823 }
824 }
825 info_->block_reuse.Set(old_stmt, stmt);
826 return std::move(stmt);
827 }
828
829 Stmt VisitStmt_(const BufferStoreNode* store) final {
830 BufferStore stmt = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
831 if (stmt->buffer.same_as(info_->write_buffer)) {
832 auto n = CopyOnWrite(stmt.get());
833 n->buffer = info_->read_buffer;
834 return Stmt(n);
835 } else {
836 return std::move(stmt);
837 }
838 }
839
840 PrimExpr VisitExpr_(const BufferLoadNode* load) final {
841 if (load->buffer.same_as(info_->write_buffer)) {
842 ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
843 n->buffer = info_->read_buffer;
844 return PrimExpr(n);
845 }
846 return ExprMutator::VisitExpr_(load);
847 }
848
849 PrimExpr VisitExpr_(const LoadNode* op) final {
850 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
851 }
852
853 Stmt VisitStmt_(const StoreNode* op) final {
854 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
855 }
856
857 PrimExpr VisitExpr_(const VarNode* op) final {
858 if (op == info_->write_buffer->data.get()) {
859 return info_->read_buffer->data;
860 }
861 return GetRef<PrimExpr>(op);
862 }
863
864 private:
865 /*! \brief The parent scope of the insertion. */
866 const StmtSRef& scope_sref_;
  /*! \brief The sref of the only writer block in the scope. */
  const StmtSRef& writer_block_sref_;
869 /*! \brief The info for inserting cache stage. */
870 CacheStageInfo* info_;
871 /*! \brief Whether the current node is under the given block. */
872 bool under_writer_block_{false};
873};
874
875/*!
 * \brief Create a new buffer to be used as the reindex buffer, whose shape is derived from the covered block iters
877 * \param buffer The given buffer.
878 * \param block_iters The block iters.
879 * \param covered Set of block iter vars covered by the buffer access indices
880 * \return The new buffer with target shape.
881 */
882Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& block_iters,
883 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered) {
884 ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
885 ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
886 std::vector<PrimExpr> new_shape;
887 std::vector<PrimExpr> new_strides;
888 for (const auto& iter : block_iters) {
889 if (covered.count(iter->var)) {
890 new_shape.push_back(iter->dom->min + iter->dom->extent);
891 }
892 }
893 new_strides.clear();
894 new_buffer->shape = new_shape;
895 new_buffer->strides = new_strides;
896 new_buffer->data = buffer->data.copy_with_suffix("_reindex");
897 new_buffer->name = buffer->name + "_reindex";
898 return Buffer(new_buffer);
899}
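
// For illustration only (hypothetical buffer and iters): given a 1-D buffer C of shape (4096,)
// accessed by a block with iters vi (extent 128) and vj (extent 32), both covered by the access
// indices, the buffer returned here is a fresh "C_reindex" buffer of shape (128, 32) with no
// explicit strides.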
900
901/*! \brief The schedule error that the target is not a leaf block. */
902class NotLeafBlockError : public ScheduleError {
903 public:
904 NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {}
905 String FastErrorString() const final {
906 return "ScheduleError: The target block is not a leaf block.";
907 }
908
909 String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; }
910
911 IRModule mod() const final { return mod_; }
912 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
913 IRModule mod_;
914 Block block_;
915};
916
917/*! \brief The schedule error that the buffer access is invalid for reindex. */
918class InvalidBufferAccessError : public ScheduleError {
919 public:
920 enum class ErrorKind {
921 kNoAccess, // buffer access not found
922 kNonUniqueAccess, // multiple buffer accesses with different indices
923 kOpaqueAccess, // opaque access to the buffer
924 };
925
926 InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind)
927 : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {}
928 String FastErrorString() const final {
929 return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The "
930 "indices should be the same if there are multiple accesses to the target buffer.";
931 }
932
933 String DetailRenderTemplate() const final {
934 std::ostringstream os;
935 os << "The target buffer " << buffer_->name
936 << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices "
937 "should be the same if there are multiple accesses to the target buffer. ";
938 if (kind_ == ErrorKind::kNoAccess) {
939 os << "No buffer accesses found.";
940 } else if (kind_ == ErrorKind::kNonUniqueAccess) {
941 os << "Multiple buffer accesses have non-unique indices.";
942 } else if (kind_ == ErrorKind::kOpaqueAccess) {
943 os << "Opaque buffer accesses found.";
944 }
945 return os.str();
946 }
947 IRModule mod() const final { return mod_; }
948 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
949
950 private:
951 IRModule mod_;
952 Buffer buffer_;
953 Block block_;
954 ErrorKind kind_;
955};
956
957/*! \brief Collect the related Load/Store to reindex */
958class ReIndexCollector : public StmtExprVisitor {
959 public:
960 static Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer, const Block& block) {
961 ReIndexCollector collector(mod, buffer, block);
962 collector(block->body);
963 if (!collector.buffer_access_indices_.defined()) {
964 throw InvalidBufferAccessError(mod, buffer, block,
965 InvalidBufferAccessError::ErrorKind::kNoAccess);
966 }
967 return collector.buffer_access_indices_.value();
968 }
969
970 private:
971 explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const Block& block)
972 : mod_(mod), buffer_(buffer), block_(block) {}
973
974 void VisitExpr_(const BufferLoadNode* load) final {
975 StmtExprVisitor::VisitExpr_(load);
976 if (load->buffer.same_as(buffer_)) {
977 CheckAndUpdateBufferAccessIndices(load->indices);
978 }
979 }
980
981 void VisitStmt_(const BlockNode* block) final {
982 // no sub-blocks under this block
983 throw NotLeafBlockError(mod_, block_);
984 }
985
986 void VisitStmt_(const BufferStoreNode* store) final {
987 StmtExprVisitor::VisitStmt_(store);
988 if (store->buffer.same_as(buffer_)) {
989 CheckAndUpdateBufferAccessIndices(store->indices);
990 }
991 }
992
993 void CheckAndUpdateBufferAccessIndices(const Array<PrimExpr> indices) {
994 if (!buffer_access_indices_.defined()) {
995 buffer_access_indices_ = indices;
996 return;
997 } else if (!std::equal(buffer_access_indices_.value().begin(),
998 buffer_access_indices_.value().end(), indices.begin(), indices.end(),
999 ExprDeepEqual())) {
1000 throw InvalidBufferAccessError(mod_, buffer_, block_,
1001 InvalidBufferAccessError::ErrorKind::kNonUniqueAccess);
1002 }
1003 }
1004
1005 void VisitExpr_(const VarNode* var) final {
1006 if (var == buffer_->data.get()) {
1007 throw InvalidBufferAccessError(mod_, buffer_, block_,
1008 InvalidBufferAccessError::ErrorKind::kOpaqueAccess);
1009 }
1010 }
1011 /*! \brief The IR module */
1012 IRModule mod_;
1013 /*! \brief The buffer to rewrite */
1014 Buffer buffer_;
1015 /*! \brief The block to visit */
1016 Block block_;
  /*! \brief The indices of the buffer access to rewrite */
1018 Optional<Array<PrimExpr>> buffer_access_indices_;
1019};
1020
1021/*! \brief Mutator of ReIndex */
1022class ReIndexRewriter : public StmtExprMutator {
1023 public:
1024 static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info,
1025 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered) {
1026 ReIndexRewriter rewriter(block_sref, info, covered);
1027 return rewriter(GetRef<Stmt>(scope_sref->stmt));
1028 }
1029
1030 private:
1031 explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info,
1032 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered)
1033 : block_sref_(block_sref), info_(info), covered_(covered) {
1034 new_buffer_ = info->alloc.value();
1035 old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer : info->read_buffer;
1036 }
1037
1038 Stmt VisitStmt_(const BlockNode* block) final {
1039 Block old_stmt = GetRef<Block>(block);
1040 if (is_scope_) {
1041 is_scope_ = false;
1042 Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
      // Insert the reindex stage into the body of the scope block
1044 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
1045 n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
1046 n->alloc_buffers.push_back(info_->alloc.value());
1047 stmt = Block(n);
1048 info_->block_reuse.Set(old_stmt, stmt);
1049 return std::move(stmt);
1050 }
1051
    // Visiting the block being reindexed
1053 if (block == block_sref_->stmt) {
1054 // Collect the updated indices and regions
1055 for (const IterVar& iter : block->iter_vars) {
1056 if (covered_.count(iter->var)) {
1057 indices_.push_back(iter->var);
1058 region_.push_back(Range::FromMinExtent(iter->var, IntImm(iter->var->dtype, 1)));
1059 }
1060 }
1061 Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
1062 // Update block reads/writes to use the intermediate reindex buffer
1063 auto writes =
1064 ReplaceBufferRegion(block->writes, old_buffer_, BufferRegion{new_buffer_, region_});
1065 auto reads =
1066 ReplaceBufferRegion(block->reads, old_buffer_, BufferRegion{new_buffer_, region_});
1067 auto match_buffers = ReplaceBufferRegion(block->match_buffers, old_buffer_,
1068 BufferRegion{new_buffer_, region_});
1069 if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
1070 !match_buffers.same_as(block->match_buffers)) {
1071 ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
1072 n->writes = std::move(writes);
1073 n->reads = std::move(reads);
1074 n->match_buffers = std::move(match_buffers);
1075 stmt = Block(n);
1076 }
1077 info_->block_reuse.Set(old_stmt, stmt);
1078 return std::move(stmt);
1079 }
1080 return std::move(old_stmt);
1081 }
1082
1083 template <typename Node>
1084 Node VisitBufferAccess(Node node) {
1085 if (node->buffer.same_as(old_buffer_)) {
1086 auto* n = node.CopyOnWrite();
1087 n->buffer = new_buffer_;
1088 n->indices = indices_;
1089 }
1090 return node;
1091 }
1092 Stmt VisitStmt_(const BufferStoreNode* op) final {
1093 BufferStore buffer_store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
1094 return VisitBufferAccess(std::move(buffer_store));
1095 }
1096
1097 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
1098 BufferLoad buffer_load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
1099 return VisitBufferAccess(std::move(buffer_load));
1100 }
1101
1102 private:
  /*! \brief The sref of the block to be reindexed. */
  const StmtSRef& block_sref_;
1105 /*! \brief The info for inserting reindex stage. */
1106 CacheStageInfo* info_;
  /*! \brief The set of original block vars covered by the access indices */
1108 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered_;
1109 /*! \brief Whether the current block is scope block */
1110 bool is_scope_{true};
1111 /*! \brief The buffer to be replaced */
1112 Buffer old_buffer_;
1113 /*! \brief The reindex buffer */
1114 Buffer new_buffer_;
1115 /*! \brief The new indices */
1116 Array<PrimExpr> indices_;
1117 /*! \brief The new region */
1118 Region region_;
1119};
1120
1121void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) {
1122 class NotRegionCoverError : public ScheduleError {
1123 public:
1124 explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
1125 IRModule mod() const final { return mod_; }
1126 String FastErrorString() const final {
1127 return "ScheduleError: The scope root's region cover is not complete.";
1128 }
1129 String DetailRenderTemplate() const final {
1130 return R"(The scope {0} 's region cover is not complete.
1131The region cover property require to hold for every of its child blocks
1132)";
1133 }
1134 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
1135 IRModule mod_;
1136 Block block_;
1137 };
1138
1139 for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) {
1140 const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
1141 for (const BufferRegion& region : child_block->reads) {
1142 if (region->buffer.same_as(read_buffer)) {
1143 if (!self->block_info.at(child_block_sref).region_cover) {
1144 const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
1145 throw NotRegionCoverError(self->mod, GetRef<Block>(block));
1146 }
1147 }
1148 }
1149 }
1150}
1151
1152/******** Implementation ********/
1153
1154StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
1155 const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
1156 /*!
1157 * Check:
   * - The index is within the range of the block's read regions
   * - There is at most one block that writes the buffer in the scope
1160 *
1161 * Mutate:
1162 * - Allocate new cache buffer under the current scope.
   * - Find the lowest ancestor of the block and ANY ONE of the consumer blocks.
1164 * - Copy the buffer with the consumed region.
1165 */
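
  // A rough before/after sketch (TVMScript-like, simplified; names hypothetical). Applying
  // cache_read to block "B" for its read buffer A with storage scope "shared" turns
  //
  //   for i, j in T.grid(128, 128):
  //       with T.block("B"):
  //           B[i, j] = A[i, j] * 2.0
  //
  // into roughly
  //
  //   for ax0, ax1 in T.grid(128, 128):
  //       with T.block("A_shared"):
  //           A_shared[ax0, ax1] = A[ax0, ax1]
  //   for i, j in T.grid(128, 128):
  //       with T.block("B"):
  //           B[i, j] = A_shared[i, j] * 2.0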
1166
1167 // Step 0. Check the input storage scope.
1168 CheckStorageScope(self, storage_scope);
1169
1170 // Step 1. Check index, getting the target buffer and the parent scope
1171 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1172 Buffer read_buffer =
1173 GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
1174 StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
1175 // Check required region cover for cache_read
1176 CheckRegionCover(self, scope_sref, read_buffer);
1177 const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
1178
1179 // Step 2. Create CacheStageInfo
1180 CacheStageInfo info;
1181 info.read_buffer = read_buffer;
1182 // Create the corresponding buffer to be written, i.e. result of cache_read
1183 info.write_buffer = WithScope(read_buffer, storage_scope);
1184 // Create the corresponding buffer allocation
1185 info.alloc = info.write_buffer;
1186
  // info.consumer_blocks indicates which blocks should consume the cache.
1188 for (auto consumer : consumer_blocks) {
1189 info.consumer_blocks.insert(consumer);
1190 for (auto child : tir::GetChildBlocks(self, consumer)) {
1191 info.consumer_blocks.insert(child);
1192 }
1193 }
1194
1195 // Step 3. Update cache stage info.
1196 BufferRegion cache_region{nullptr};
1197 if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
1198 // Case 1. The buffer is written inside the block.
1199 StmtSRef write_block_sref = _write_block_sref.value();
1200 const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
1201 // Find the producing region
1202 BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
1203 StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
1204
1205 // Detect insert position
1206 CacheLocDetector::Detect</*is_cache_read=*/true>(self, write_block_sref, scope_sref, &info);
1207 cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
1208 } else {
    // Case 2. The buffer is an input buffer of the scope.
1210 info.loc_sref = scope_sref;
1211 info.loc_pos = 0;
1212 if (Optional<BufferRegion> region =
1213 GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
1214 cache_region = region.value();
1215 } else {
1216 cache_region = BufferRegion::FullRegion(read_buffer);
1217 }
1218 }
1219
1220 // Step 4. Making new cache stage block and rewrite readers.
1221 Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
1222 /*storage_scope=*/storage_scope);
1223 Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
1224
1225 // Step 5. Replacing and updating flags.
1226 self->Replace(scope_sref, new_scope, info.block_reuse);
1227 StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
1228 BlockInfo& block_info = self->block_info[result_block_sref];
1229 block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
1230 block_info.region_cover = true;
1231 block_info.scope->stage_pipeline = true;
1232 return result_block_sref;
1233}
1234
1235StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
1236 const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
1237 /*!
1238 * Check:
   * - The index is within the range of the block's write regions
   * - There is only one block that writes the buffer in the scope
1241 *
1242 * Mutate:
1243 * - Allocate new cache buffer under the current scope.
1244 * - Find the lowest ancestor of the block and ANY ONE of the producer blocks.
1245 * - Copy the buffer with the consumed region.
1246 */
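
  // A rough before/after sketch (TVMScript-like, simplified; names hypothetical). Applying
  // cache_write to block "B" for its write buffer B with storage scope "local" turns
  //
  //   for i, j in T.grid(128, 128):
  //       with T.block("B"):
  //           B[i, j] = A[i, j] * 2.0
  //
  // into roughly
  //
  //   for i, j in T.grid(128, 128):
  //       with T.block("B"):
  //           B_local[i, j] = A[i, j] * 2.0
  //   for ax0, ax1 in T.grid(128, 128):
  //       with T.block("B_local"):
  //           B[ax0, ax1] = B_local[ax0, ax1]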
1247
1248 // Step 0. Check the input storage scope.
1249 CheckStorageScope(self, storage_scope);
1250
1251 // Step 1. Checking index, getting the target buffer and the parent scope
1252 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1253 Buffer write_buffer =
1254 GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
1255 StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
1256
1257 // Step 2. Creating CacheStageInfo
1258 CacheStageInfo info;
1259 info.read_buffer = WithScope(write_buffer, storage_scope);
1260 // Create the corresponding buffer to be written, i.e. result of cache_write
1261 info.write_buffer = write_buffer;
1262 // Create the corresponding buffer allocation
1263 info.alloc = info.read_buffer;
1264
  // info.consumer_blocks indicates which blocks should consume the cache.
1266 for (auto consumer : consumer_blocks) {
1267 info.consumer_blocks.insert(consumer);
1268 for (auto child : tir::GetChildBlocks(self, consumer)) {
1269 info.consumer_blocks.insert(child);
1270 }
1271 }
1272
1273 // Step 3. Check the only writer block.
1274 ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());
1275
1276 // Step 4. Find the producing region and insert position
1277 BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
1278 StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
1279 // Detect insert position
1280 CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref, scope_sref, &info);
1281 BufferRegion cache_region =
1282 RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
1283
1284 // Step 5. Making new cache stage block and rewrite readers.
1285 Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
1286 /*storage_scope=*/storage_scope);
1287 Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
1288 /*writer_block_sref=*/block_sref, /*info=*/&info);
1289
1290 // Step 6. Replacing and updating flags.
1291 self->Replace(scope_sref, new_scope, info.block_reuse);
1292 StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
1293 BlockInfo& block_info = self->block_info[result_block_sref];
1294 block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
1295 block_info.region_cover = true;
1296 block_info.scope->stage_pipeline = true;
1297 return result_block_sref;
1298}
1299
1300/*! \brief The schedule error that the target block doesn't both read&write target buffer. */
1301class NotReadWriteError : public ScheduleError {
1302 public:
1303 NotReadWriteError(IRModule mod, Block block, Buffer buffer)
1304 : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {}
1305 String FastErrorString() const final {
1306 return "ScheduleError: The target block does not both read & write target buffer.";
1307 }
1308
1309 String DetailRenderTemplate() const final {
1310 return "The target block {0} does not both read & write target buffer {1}.";
1311 }
1312
1313 IRModule mod() const final { return mod_; }
1314 Array<ObjectRef> LocationsOfInterest() const final { return {block_, buffer_}; }
1315 IRModule mod_;
1316 Block block_;
1317 Buffer buffer_;
1318};
1319
1320Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
1321 const String& storage_scope) {
1322 /*!
1323 * Do cache read then cache write
1324 */
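
  // A rough sketch (TVMScript-like, simplified; names hypothetical). For a block that both reads
  // and writes a buffer X, cache_inplace stages the data through a scoped buffer, e.g.
  //
  //   with T.block("update"):
  //       X[vi] = X[vi] + 1.0
  //
  // becomes roughly
  //
  //   with T.block("X_local"):      # cache read:  X -> X_local
  //       X_local[v0] = X[v0]
  //   with T.block("update"):
  //       X_local[vi] = X_local[vi] + 1.0
  //   with T.block("X_local_1"):    # cache write: X_local -> X
  //       X[v0] = X_local[v0]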
1325
1326 // Check 0. Check the input storage scope.
1327 CheckStorageScope(self, storage_scope);
1328
1329 // Check 1. Check index, get the target buffer and the parent scope
1330 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
1331 Buffer buffer =
1332 GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
1333 StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
1334
  // Check 2. Check required region cover for cache_read
1336 CheckRegionCover(self, scope_sref, buffer);
1337
  // Check 3. Check if the target block both reads & writes the target buffer.
1339 const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref);
1340 Optional<BufferRegion> read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer);
1341 Optional<BufferRegion> write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer);
1342 if (!read_region.defined() || !write_region.defined()) {
1343 throw NotReadWriteError(self->mod, GetRef<Block>(rw_block), buffer);
1344 }
1345
1346 Array<StmtSRef> results_block_sref;
1347 Buffer new_buffer = WithScope(buffer, storage_scope);
1348
1349 // Do cache read
1350 // Cache read step 0. Create CacheStageInfo
1351 CacheStageInfo info;
1352 info.read_buffer = buffer;
1353 // Create the corresponding buffer to be written for cache_read
1354 info.write_buffer = new_buffer;
1355 // Create the corresponding buffer allocation
1356 info.alloc = info.write_buffer;
  // Indicate which blocks should consume the cache.
1358 info.consumer_blocks.insert(block_sref);
1359
1360 // Cache read step 1. Detect insert position
1361 CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info);
1362
1363 // Cache read step 2. Making new cache stage block and rewrite readers.
1364 Block cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info,
1365 /*storage_scope=*/storage_scope);
1366 Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
1367
1368 // Cache read step 3. Replacing and updating flags for cache read.
1369 self->Replace(scope_sref, new_scope, info.block_reuse);
1370 StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
1371 BlockInfo& block_info_read = self->block_info[result_block_sref];
1372 block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
1373 block_info_read.region_cover = true;
1374 block_info_read.scope->stage_pipeline = false;
1375 results_block_sref.push_back(result_block_sref);
1376
1377 // Do cache write
  // Cache write step 0. Update the cache stage info for cache_write.
1379 info.read_buffer = new_buffer;
1380 // Create the corresponding buffer to be written, i.e. result of cache_write
1381 info.write_buffer = buffer;
1382 // Create the corresponding buffer allocation
1383 info.alloc = nullptr;
1384 info.consumer_blocks.clear();
1385
1386 // Cache write step 1. Detect insert position
1387 CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info);
1388 // insert after target block for cache write
1389 info.loc_pos += 1;
1390
1391 // Cache write step 2. Making new cache stage block and rewrite readers.
1392 Block cache_write_stage = MakeCacheStage(/*cache_region=*/write_region.value(), /*info=*/&info,
1393 /*storage_scope=*/storage_scope);
1394 new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
1395 /*writer_block_sref=*/block_sref, /*info=*/&info);
1396
  // Cache write step 3. Replacing and updating flags for cache write.
1398 self->Replace(scope_sref, new_scope, info.block_reuse);
1399 result_block_sref = self->stmt2ref.at(cache_write_stage.get());
1400 BlockInfo& block_info_write = self->block_info[result_block_sref];
1401 block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref);
1402 block_info_write.region_cover = true;
1403 block_info_write.scope->stage_pipeline = false;
1404 results_block_sref.push_back(result_block_sref);
1405
1406 return results_block_sref;
1407}
1408
1409StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
1410 BufferIndexType buffer_index_type) {
1411 const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
1412 Block block = GetRef<Block>(block_ptr);
1413 Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type);
1414 StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
1415 arith::Analyzer analyzer;
1416
1417 // Step 1. Collect the original indices and check that there is only a single pattern of
1418 // related loads/stores and that the buffer is not accessed opaquely.
1419 Array<PrimExpr> original_indices = ReIndexCollector::Collect(self->mod, buffer, block);
1420 // Simplify the indices if possible
1421 for (const IterVar& iter : block->iter_vars) {
1422 analyzer.Bind(iter->var, iter->dom);
1423 }
1424 original_indices.MutateByApply(
1425 [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); });
1426
1427 // Collect block iters appearing in the original_indices
1428 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> covered;
1429 for (const PrimExpr& index : original_indices) {
1430 PreOrderVisit(index, [&](const ObjectRef& obj) -> bool {
1431 if (const VarNode* var = obj.as<VarNode>()) {
1432 covered.insert(GetRef<Var>(var));
1433 }
1434 return true;
1435 });
1436 }
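  // For example (hypothetical values), if the simplified indices were [vi * 4 + vk, vj], the
  // `covered` set would end up holding {vi, vk, vj}; block iters that never appear in any index
  // are not expected to contribute to the reindex buffer created below.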
1437
1438 // Step 2. Create the CacheStageInfo.
1439 CacheStageInfo info;
1440 // Create the corresponding buffer to be read (written), i.e. the result of the reindex read (write).
1441 if (buffer_index_type == BufferIndexType::kWrite) {
1442 info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
1443 info.write_buffer = buffer;
1444 info.alloc = info.read_buffer;
1445 } else {
1446 info.read_buffer = buffer;
1447 info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered);
1448 info.alloc = info.write_buffer;
1449 }
1450
1451 // Step 3. Check that the block belongs to a chain of loop nests under the scope,
1452 // and get the insert location.
1453 const StmtSRefNode* loop;
1454 for (loop = block_sref->parent; loop->parent != scope_sref.get();) {
1455 const ForNode* outer = loop->parent->StmtAs<ForNode>();
1456 const ForNode* inner = loop->StmtAs<ForNode>();
1457 ICHECK(outer != nullptr && inner != nullptr);
1458 ICHECK(outer->body.get() == inner);
1459 loop = loop->parent;
1460 }
1461
1462 info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index;
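  // For a write buffer the reindex stage consumes the block's output, so it must be placed
  // right after the target block (hence the increment below); for a read buffer it stays in
  // front of the block.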
1463 if (buffer_index_type == BufferIndexType::kWrite) {
1464 info.loc_pos++;
1465 }
1466
1467 // Step 4. Make the new reindex stage block and rewrite the scope.
1468 Block reindex_stage =
1469 MakeReIndexStage(block, &info, covered, original_indices, buffer_index, buffer_index_type);
1470 Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, covered);
1471
1472 // Step 5. Replace the scope and update flags.
1473 self->Replace(scope_sref, new_scope, info.block_reuse);
1474 StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get());
1475 BlockInfo& block_info = self->block_info[result_block_sref];
1476 block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
1477 block_info.region_cover = true;
1478 block_info.scope->stage_pipeline = true;
1479 return result_block_sref;
1480}
1481
1482/******** Instruction Registration ********/
1483
1484struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
1485 static constexpr const char* kName = "CacheRead";
1486 static constexpr bool kIsPure = false;
1487
1488 private:
1489 static constexpr size_t kNumInputs = 2;
1490 static constexpr size_t kNumAttrs = 2;
1491 static constexpr size_t kNumDecisions = 0;
1492
1493 static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block,
1494 Array<BlockRV> consumer_blocks, Integer read_buffer_index,
1495 String storage_scope) {
1496 return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks);
1497 }
1498
1499 static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks,
1500 Integer read_buffer_index, String storage_scope) {
1501 PythonAPICall py("cache_read");
1502 py.Input("block", block);
1503 py.Input("read_buffer_index", read_buffer_index->value);
1504 py.Input("storage_scope", storage_scope);
1505 // Only write out consumer blocks if provided.
1506 if (!consumer_blocks.empty()) {
1507 py.Input("consumer_blocks", consumer_blocks);
1508 }
1509 py.SingleOutput(outputs);
1510 return py.Str();
1511 }
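  // Illustration only (hypothetical values): with outputs = ["b1"], block = "b0",
  // read_buffer_index = 0, storage_scope = "shared" and no consumer blocks, the rendered
  // trace line is expected to look roughly like:
  //   b1 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")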
1512
1513 template <typename>
1514 friend struct ::tvm::tir::UnpackedInstTraits;
1515};
1516
1517struct CacheWriteTraits : public UnpackedInstTraits<CacheWriteTraits> {
1518 static constexpr const char* kName = "CacheWrite";
1519 static constexpr bool kIsPure = false;
1520
1521 private:
1522 static constexpr size_t kNumInputs = 2;
1523 static constexpr size_t kNumAttrs = 2;
1524 static constexpr size_t kNumDecisions = 0;
1525
1526 static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block,
1527 Array<BlockRV> consumer_blocks, Integer write_buffer_index,
1528 String storage_scope) {
1529 return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks);
1530 }
1531
1532 static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks,
1533 Integer write_buffer_index, String storage_scope) {
1534 PythonAPICall py("cache_write");
1535 py.Input("block", block);
1536 py.Input("write_buffer_index", write_buffer_index->value);
1537 py.Input("storage_scope", storage_scope);
1538 // Only write out consumer blocks if provided.
1539 if (!consumer_blocks.empty()) {
1540 py.Input("consumer_blocks", consumer_blocks);
1541 }
1542 py.SingleOutput(outputs);
1543 return py.Str();
1544 }
1545
1546 template <typename>
1547 friend struct ::tvm::tir::UnpackedInstTraits;
1548};
1549
1550struct CacheInplaceTraits : public UnpackedInstTraits<CacheInplaceTraits> {
1551 static constexpr const char* kName = "CacheInplace";
1552 static constexpr bool kIsPure = false;
1553
1554 private:
1555 static constexpr size_t kNumInputs = 1;
1556 static constexpr size_t kNumAttrs = 2;
1557 static constexpr size_t kNumDecisions = 0;
1558
1559 static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block,
1560 Integer read_buffer_index, String storage_scope) {
1561 return sch->CacheInplace(block, read_buffer_index->value, storage_scope);
1562 }
1563
1564 static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index,
1565 String storage_scope) {
1566 PythonAPICall py("cache_inplace");
1567 py.Input("block", block);
1568 py.Input("read_buffer_index", read_buffer_index->value);
1569 py.Input("storage_scope", storage_scope);
1570 py.OutputList(outputs);
1571 return py.Str();
1572 }
1573
1574 template <typename>
1575 friend struct ::tvm::tir::UnpackedInstTraits;
1576};
1577
1578struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
1579 static constexpr const char* kName = "ReIndex";
1580 static constexpr bool kIsPure = false;
1581
1582 private:
1583 static constexpr size_t kNumInputs = 1;
1584 static constexpr size_t kNumAttrs = 2;
1585 static constexpr size_t kNumDecisions = 0;
1586
1587 static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
1588 Integer buffer_index_type) {
1589 return sch->ReIndex(block, buffer_index.IntValue(),
1590 static_cast<BufferIndexType>(buffer_index_type->value));
1591 }
1592
1593 static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
1594 Integer buffer_index_type) {
1595 PythonAPICall py("reindex");
1596 py.Input("block", block);
1597 std::ostringstream os;
1598 os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
1599 << "\", " << buffer_index << ")";
1600 py.Input("buffer", os.str());
1601 py.SingleOutput(outputs);
1602 return py.Str();
1603 }
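  // Illustration only (hypothetical values): the buffer kind and index are rendered as a
  // single tuple argument, so the trace line is expected to look roughly like:
  //   b1 = sch.reindex(block=b0, buffer=("write", 0))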
1604
1605 template <typename>
1606 friend struct ::tvm::tir::UnpackedInstTraits;
1607};
1608
1609TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits);
1610TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits);
1611TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits);
1612TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits);
1613} // namespace tir
1614} // namespace tvm
1615