/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <unordered_set>

#include "../utils.h"

namespace tvm {
namespace tir {

/******** Error Classes ********/

class NotSingleWriteBlock : public ScheduleError {
 public:
  explicit NotSingleWriteBlock(IRModule mod, Buffer buffer, Array<StmtSRef> write_blocks)
      : mod_(std::move(mod)), buffer_(std::move(buffer)) {
    ICHECK_GT(write_blocks.size(), 1);
    write_blocks_.reserve(write_blocks.size());
    for (const StmtSRef& block_sref : write_blocks) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
      write_blocks_.push_back(GetRef<Block>(block));
    }
  }

  String FastErrorString() const final {
    return "ScheduleError: The buffer is only allowed to be written by a single block.";
  }

  String DetailRenderTemplate() const final {
    size_t k = write_blocks_.size();
    return "The buffer " + buffer_->name + " is expected to be written by exactly one block, " +
           "but " + std::to_string(k) + " blocks write it.";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final {
    return {write_blocks_.begin(), write_blocks_.end()};
  }

 private:
  IRModule mod_;
  Buffer buffer_;
  Array<Block> write_blocks_;
};

/******** Helper Functions/Classes ********/

/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
struct CacheStageInfo {
  /*! \brief The buffer to be read. */
  Buffer read_buffer;
  /*! \brief The buffer to be written. */
  Buffer write_buffer;
  /*! \brief The buffer allocation to be inserted into the block signature. */
  Optional<Buffer> alloc;
  /*! \brief The AST node whose body is where the cache stage should be inserted. */
  StmtSRef loc_sref;
  /*! \brief The index to insert the cache_read/cache_write stage. */
  size_t loc_pos;
  /*! \brief The cache_read/cache_write stage to be inserted. */
  Stmt cache_stage;
  /*! \brief The map used for ScheduleStateNode::Replace. */
  Map<Block, Block> block_reuse;
  /*! \brief A set of blocks that will consume the new cache. */
  std::unordered_set<StmtSRef, ObjectHash, ObjectEqual> consumer_blocks;
};
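
// Descriptive note on the typical flow through CacheStageInfo in this file: a
// schedule primitive below fills in read_buffer / write_buffer / alloc, a
// location detector fills in loc_sref / loc_pos, MakeCacheStage (or
// MakeReIndexStage) fills in cache_stage, and a rewriter consumes the whole
// struct to produce the new scope.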

/*! \brief Return the single buffer region related to the given buffer, or NullOpt if none. */
Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buffer_regions,
                                                 const Buffer& buffer) {
  Optional<BufferRegion> res = NullOpt;
  for (const auto& region : buffer_regions) {
    if (region->buffer.same_as(buffer)) {
      ICHECK(!res.defined());
      res = region;
    }
  }
  return res;
}
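
// Illustrative usage (hypothetical regions): if a block reads A[0:16] and
// B[0:4], GetBufferRegionFromBuffer(block->reads, A) yields the A[0:16]
// region, while querying a buffer the block never accesses yields NullOpt.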

/*!
 * \brief Create a loop nest that represents the cache copy (cache_read / cache_write) from the
 * read buffer to the write buffer.
 * \note This function stores the stmt with loop nesting into the CacheStageInfo, but only returns
 * the inside block.
 * \param cache_region The cached copy region.
 * \param info The cache stage information, which will be updated in the function.
 * \param storage_scope The storage scope of the cached buffer (only used in naming here)
 * \return A block indicating the body of the loop nesting.
 */
Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info,
                     const String& storage_scope) {
  // loop variables
  std::vector<Var> loop_vars;
  // bindings in block realize
  std::vector<PrimExpr> iter_values;
  // Create loop vars and the block vars' binding values
  for (const Range& axis_range : cache_region->region) {
    Var loop_var("ax" + std::to_string(loop_vars.size()), axis_range->extent.dtype());
    loop_vars.push_back(loop_var);
    iter_values.push_back(axis_range->min + loop_var);
  }
  // block variables
  Array<IterVar> block_vars;
  // block access region for read/write buffers
  Region access_region;
  // indices used in the block body
  Array<PrimExpr> access_indices;
  // Create block vars, the block's accessed region and the accessing indices
  for (const PrimExpr& dim : cache_region->buffer->shape) {
    Var var("v" + std::to_string(access_indices.size()), dim.dtype());
    block_vars.push_back(IterVar(/*dom=*/Range::FromMinExtent(make_zero(dim->dtype), dim),
                                 /*var=*/var,
                                 /*IterVarType=*/kDataPar));
    access_indices.push_back(var);
    access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
  }

  // Create the body block:
  //   reads = [read_buffer[access_region]]
  //   writes = [write_buffer[access_region]]
  //   write_buffer[access_indices] = read_buffer[access_indices]
  Block block(
      /*iter_vars=*/std::move(block_vars),
      /*reads=*/{BufferRegion(info->read_buffer, access_region)},
      /*writes=*/{BufferRegion(info->write_buffer, access_region)},
      /*name_hint=*/cache_region->buffer->name + "_" + storage_scope,
      /*body=*/
      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, access_indices),
                  access_indices),
      /*init=*/NullOpt,
      /*alloc_buffers=*/{},
      /*match_buffers=*/{},
      /*annotations=*/{});
  // Create the block realize node
  Stmt body = BlockRealize(/*values=*/iter_values,
                           /*predicate=*/const_true(),
                           /*block=*/block);
  // Create the surrounding loops, from innermost to outermost
  for (size_t i = loop_vars.size(); i >= 1; --i) {
    body = For(/*loop_var=*/loop_vars[i - 1],
               /*min=*/0,
               /*extent=*/cache_region->region[i - 1]->extent,
               /*kind=*/ForKind::kSerial,
               /*body=*/body);
  }
  info->cache_stage = std::move(body);
  return block;
}
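
// For a 2-D cache region, the generated stage looks roughly like the following
// TVMScript-style sketch (buffer names are illustrative):
//   for ax0, ax1 in T.grid(extent0, extent1):
//     with T.block("A_shared"):
//       v0, v1 = T.axis.remap("SS", [ax0, ax1])
//       write_buffer[v0, v1] = read_buffer[v0, v1]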

/*!
 * \brief Create the reindex block and generate the corresponding outer loops.
 * \details The reindex block is a data copy block between the reindex buffer (the intermediate
 * buffer) and the target buffer.
 * If buffer_index_type == kWrite, copy from the reindex buffer to the target buffer.
 * If buffer_index_type == kRead, copy from the target buffer to the reindex buffer.
 * The reindex block has the same block iters and surrounding loops as the input block.
 * However, if a block iter is not used in the indices of the target buffer being reindexed, that
 * block iter, and the corresponding outer loop, are omitted from the reindex block.
 * \param block The block to be reindexed
 * \param info The cache info
 * \param covered The set of block iter vars covered in the buffer access indices
 * \param original_indices The original buffer access indices
 * \param buffer_index The index of the target buffer
 * \param buffer_index_type The type of the buffer index
 * \return The reindex block.
 */
Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
                       const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered,
                       const Array<PrimExpr>& original_indices, int buffer_index,
                       BufferIndexType buffer_index_type) {
  // iters of the reindex block
  Array<IterVar> new_block_iters;
  // the substitution map from the original block iters to the iters of the reindex block
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_replace_map;
  // indices to access the reindex buffer and the target buffer
  Array<PrimExpr> reindex_indices, target_indices;

  // Step 1: Create the block iters of the reindex block, and the accessing indices to the
  // reindex buffer.
  std::unordered_set<int> skipped_block_iters;
  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
    const IterVar& iter = block->iter_vars[i];
    Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype);
    bool used = covered.count(iter->var);
    if (used) {
      new_block_iters.push_back(IterVar(/*dom=*/iter->dom,
                                        /*var=*/var,
                                        /*IterVarType=*/kDataPar));
      reindex_indices.push_back(var);
    } else {
      skipped_block_iters.insert(i);
    }
    block_var_replace_map[iter->var] = var;
  }

  // Step 2: Replace the original block iters with the new block iters
  for (const PrimExpr& index : original_indices) {
    target_indices.push_back(Substitute(index, block_var_replace_map));
  }

  // Step 3: Create the reindex block

  // The src and the dst indices of the data copy
  Array<PrimExpr> src_indices{nullptr};
  Array<PrimExpr> dst_indices{nullptr};

  if (buffer_index_type == BufferIndexType::kWrite) {
    src_indices = reindex_indices;
    dst_indices = target_indices;
  } else {
    src_indices = target_indices;
    dst_indices = reindex_indices;
  }

  // Create the body block
  Block new_block(
      /*iter_vars=*/new_block_iters,
      /*reads=*/{BufferRegion::FromPoint(info->read_buffer, src_indices)},
      /*writes=*/{BufferRegion::FromPoint(info->write_buffer, dst_indices)},
      /*name_hint=*/info->write_buffer->name + "_reindex",
      /*body=*/
      BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices));

  // Step 4: Create the surrounding loops

  // Create loop vars and bindings for the block iters
  std::vector<Var> loop_vars;         // loop variables
  std::vector<PrimExpr> iter_values;  // bindings in block realize
  for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
    if (skipped_block_iters.count(i)) {
      continue;
    }
    Var loop_var("ax" + std::to_string(loop_vars.size()), block->iter_vars[i]->var->dtype);
    loop_vars.push_back(loop_var);
    iter_values.push_back(loop_var);
  }

  // Create the block realize node
  Stmt body = BlockRealize(/*values=*/iter_values,
                           /*predicate=*/const_true(),
                           /*block=*/new_block);

  // Create the chain of loops
  for (int i = static_cast<int>(new_block_iters.size()) - 1; i >= 0; --i) {
    body = For(/*loop_var=*/loop_vars[i],
               /*min=*/new_block_iters[i]->dom->min,
               /*extent=*/new_block_iters[i]->dom->extent,
               /*kind=*/ForKind::kSerial,
               /*body=*/std::move(body));
  }
  // Update the cache info, which will be used in the later rewriting.
  info->cache_stage = std::move(body);
  return new_block;
}
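
// Illustrative effect (hypothetical shapes): for a block with iters
// (vi, vj, vk) where only (vi, vj) index the target buffer, the reindex stage
// is a 2-level loop nest copying between the target buffer and the "_reindex"
// buffer, with vk dropped from the copy block.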

/*!
 * \brief Recalculate the `affine_binding` flag of a specific block
 * \param self The state of the schedule
 * \param block_sref The sref to the specific block
 * \return Whether the binding of the block is affine
 */
bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) {
  if (block_sref->parent == nullptr) {
    return true;
  }
  arith::Analyzer analyzer;
  StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
  return IsAffineBinding(/*realize=*/GetBlockRealize(self, block_sref),
                         /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
                         /*analyzer=*/&analyzer);
}

/*!
 * \brief Insert the cache_read/cache_write stage into the specific position
 * \param stmt A sequence of statements, or a single statement, that the new stage is inserted into
 * \param pos The position where the cache stage is inserted
 * \param stage The stage to be inserted
 * \return A SeqStmt, the result after insertion
 */
Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) {
  if (const auto* alloc = stmt.as<AllocateConstNode>()) {
    auto seq_stmt = InsertCacheStage(alloc->body, pos, stage);
    return AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, seq_stmt,
                         alloc->annotations, alloc->span);
  }
  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
    result->seq.insert(result->seq.begin() + pos, stage);
    return SeqStmt(result);
  }
  if (pos == 0) {
    return SeqStmt({stage, stmt});
  }
  ICHECK_EQ(pos, 1);
  return SeqStmt({stmt, stage});
}
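
// Illustrative behavior: inserting stage S at pos 1 into SeqStmt([a, b])
// yields SeqStmt([a, S, b]); for a single (non-sequence) statement, pos 0
// yields SeqStmt([S, stmt]) and pos 1 yields SeqStmt([stmt, S]).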

/*!
 * \brief Get the only writer block of the input buffer in a given scope block.
 * \param self The state of the schedule
 * \param scope_sref The scope block where the write is considered
 * \param buffer The queried buffer
 * \return The sref of the only writer of the input buffer in the given scope,
 *         or `NullOpt` if no block writes it in the scope.
 * \throw NotSingleWriteBlock if more than one block writes the buffer in the scope.
 */
Optional<StmtSRef> GetOnlyWriteBlock(ScheduleState self, const StmtSRef& scope_sref,
                                     const Buffer& buffer) {
  BlockScope scope = self->GetBlockScope(scope_sref);
  auto it = scope->buffer_writers.find(buffer);
  if (it == scope->buffer_writers.end()) {
    return NullOpt;
  } else {
    const Array<StmtSRef>& block_srefs = it->second;
    ICHECK(!block_srefs.empty());
    if (block_srefs.size() > 1) {
      throw NotSingleWriteBlock(self->mod, buffer, block_srefs);
    }
    return block_srefs[0];
  }
}

/*!
 * \brief Relax the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
 * \param self The state of the schedule.
 * \param buffer_region The buffer region to be analyzed.
 * \param block_sref The sref of the block related to the region.
 * \param dom_low_inclusive The lowest node in the sref tree path.
 * \param dom_high_exclusive The highest node in the sref tree path.
 * \return The relaxed buffer region.
 */
BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
                               const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
                               const StmtSRef& dom_high_exclusive) {
  BlockRealize realize = GetBlockRealize(self, block_sref);
  Map<Var, PrimExpr> binding = GetBindings(realize);
  const Buffer& buffer = buffer_region->buffer;
  arith::Analyzer analyzer;
  BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding));
  Array<arith::IntSet> int_sets = AnalyzeRegionUpperBound(
      /*region=*/subst_region,
      /*predicate=*/realize->predicate,
      /*dom_low_inclusive=*/dom_low_inclusive,
      /*dom_high_exclusive=*/dom_high_exclusive,
      /*analyzer=*/&analyzer);
  ICHECK_EQ(buffer_region->region.size(), int_sets.size());

  Region region;
  region.reserve(int_sets.size());
  for (size_t i = 0; i < int_sets.size(); ++i) {
    region.push_back(int_sets[i].CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
  }
  return BufferRegion(buffer, region);
}
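
// Illustrative example: a per-iteration region A[vi, 0:16], with vi bound to a
// loop over [0, 32) that lies inside the relaxation path, relaxes to
// A[0:32, 0:16]; loops outside the path keep their bindings fixed.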

/*! \brief Detect the insertion position of the new cache stage */
class CacheLocDetector : public StmtVisitor {
 public:
  /*!
   * \brief Detect the insertion position of the cache stage, and write the position into the
   * CacheStageInfo
   * \tparam is_cache_read Whether the stage to insert is a cache_read stage
   * \param self The state of the schedule
   * \param block_sref The sref of the unique writer block of the buffer being applied cache_read
   * or cache_write
   * \param scope_sref The sref of the scope block of the cached block
   * \param info The cache stage info.
   */
  template <bool is_cache_read>
  static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
                     const StmtSRef& scope_sref, CacheStageInfo* info) {
    std::vector<StmtSRef> related_blocks;
    // If consumers are specified, skip detecting the others
    if (is_cache_read) {
      if (!info->consumer_blocks.empty()) {
        for (StmtSRef consumer : info->consumer_blocks) {
          related_blocks.emplace_back(consumer);
        }
      } else {
        for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
          if (def->kind == DepKind::kRAW) {
            related_blocks.push_back(def->dst);
          }
        }
      }
    } else {
      for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
        if (def->kind == DepKind::kRAW) {
          if (info->consumer_blocks.count(def->dst)) {
            continue;
          }
          related_blocks.push_back(def->dst);
        }
      }
    }

    if (!related_blocks.empty()) {
      CacheLocDetector detector(self, block_sref, scope_sref, related_blocks);
      detector(GetRef<Stmt>(scope_sref->stmt));
      info->loc_sref = detector.loc_sref_;
      info->loc_pos = detector.loc_pos_;
    } else {
      info->loc_sref = scope_sref;

      auto block_body = scope_sref->StmtAs<BlockNode>()->body;
      // Find the SeqStmtNode within (potentially nested) AllocateConstNodes
      while (block_body->IsInstance<AllocateConstNode>()) {
        block_body = block_body.as<AllocateConstNode>()->body;
      }
      const auto* body = block_body.as<SeqStmtNode>();
      info->loc_pos = body == nullptr ? 1 : body->size();
    }
  }

 private:
  /*!
   * \brief Constructor
   * \param self The state of the schedule
   * \param block_sref The sref of the unique writer block of the buffer being applied cache_read
   * or cache_write
   * \param scope_sref The sref of the scope block of the cached block
   * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read
   */
  CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref,
                   const std::vector<StmtSRef>& related_blocks)
      : self_(self),
        block_sref_(block_sref),
        scope_sref_(scope_sref),
        related_blocks_(related_blocks) {}

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    bool previous_visited_block = visited_block_;
    visited_block_ = false;

    for (size_t i = 0; i < seq_stmt->size(); ++i) {
      if (loc_pos_ != -1) {
        break;
      }
      VisitStmt(seq_stmt->seq[i]);
      // `loc_pos_` can be assigned only once, after we have visited `block_sref`
      if (visited_block_ && visited_related_ && loc_pos_ == -1) {
        // The offset of the insert position from the block
        loc_pos_ = i;
        break;
      } else if (visited_related_) {
        // If we have met the target consumer, stop searching
        break;
      }
    }
    visited_block_ = visited_block_ || previous_visited_block;
  }

  void VisitStmt_(const BlockNode* block) final {
    // Only visit the current scope under the buffer writer's parent block
    if (block == scope_sref_->stmt) {
      // The visited block is the current parent scope
      StmtVisitor::VisitStmt_(block);
      // Handle the cases of inserting outside any loop, or cache_read for an input buffer
      if (visited_related_ && !loc_sref_.defined()) {
        loc_sref_ = self_->stmt2ref.at(block);
        // Handle cache_read for an input buffer
        if (!visited_block_ && loc_pos_ == -1) {
          loc_pos_ = 0;
        }
      }
      return;
    }
    // Update `visited_block_`
    if (block_sref_->stmt == block) {
      visited_block_ = true;
      return;
    }
    // Update `visited_related_`
    for (const StmtSRef& related_block : related_blocks_) {
      if (related_block->stmt == block) {
        visited_related_ = true;
        return;
      }
    }
  }

  void VisitStmt_(const ForNode* loop) final {
    StmtVisitor::VisitStmt_(loop);
    if (visited_block_ && visited_related_ && !loc_sref_.defined() && loc_pos_ != -1) {
      loc_sref_ = self_->stmt2ref.at(loop);
    }
  }

 private:
  /*! \brief The state of the schedule */
  const ScheduleState self_;
  /*! \brief The dominant block which writes the buffer */
  const StmtSRef& block_sref_;
  /*! \brief The parent scope of the dominant block */
  const StmtSRef& scope_sref_;
  /*! \brief Producer blocks for cache_write, or consumer blocks for cache_read */
  const std::vector<StmtSRef>& related_blocks_;
  /*! \brief The flag indicating whether we have visited the dominant block */
  bool visited_block_{false};
  /*! \brief The flag indicating whether we have visited at least one related block */
  bool visited_related_{false};
  /*! \brief The AST node whose body is where the cache stage should be inserted */
  StmtSRef loc_sref_{nullptr};
  /*! \brief The index to insert the cache_read/cache_write stage */
  int loc_pos_{-1};
};
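
// Illustrative outcome: with a writer block W followed by consumers C1 and C2
// in one SeqStmt, loc_sref is their common parent and loc_pos is the index of
// C1's statement, so the cache stage lands after W and right before the first
// consumer.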

/*! \brief Detect the insertion position of the new cache stage */
class CacheInplaceLocDetector : public StmtVisitor {
 public:
  /*!
   * \brief Detect the insertion position of the cache stage, and write the position into the
   * CacheStageInfo
   * \param self The state of the schedule
   * \param block_sref The sref of the unique block of the buffer being applied cache_inplace
   * \param scope_sref The sref of the scope block of the cached block
   * \param info The cache stage info.
   */
  static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
                     const StmtSRef& scope_sref, CacheStageInfo* info) {
    CacheInplaceLocDetector detector(self, block_sref, scope_sref);
    detector(GetRef<Stmt>(scope_sref->stmt));
    info->loc_sref = detector.loc_sref_;
    info->loc_pos = detector.loc_pos_;
  }

 private:
  /*!
   * \brief Constructor
   * \param self The state of the schedule
   * \param block_sref The sref of the unique writer block of the buffer being applied
   * cache_inplace
   * \param scope_sref The sref of the scope block of the cached block
   */
  CacheInplaceLocDetector(const ScheduleState self, const StmtSRef& block_sref,
                          const StmtSRef& scope_sref)
      : self_(self), block_sref_(block_sref), scope_sref_(scope_sref) {}

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    for (size_t i = 0; i < seq_stmt->size(); ++i) {
      if (loc_pos_ != -1) {
        break;
      }
      VisitStmt(seq_stmt->seq[i]);
      // `loc_pos_` can be assigned only once, after we have visited `block_sref`
      if (visited_block_ && loc_pos_ == -1) {
        // The offset of the insert position from the block
        loc_pos_ = i;
        return;
      }
    }
  }

  void VisitStmt_(const BlockNode* block) final {
    // Only visit the current scope under the buffer writer's parent block
    if (block == scope_sref_->stmt) {
      // The visited block is the current parent scope
      StmtVisitor::VisitStmt_(block);
      // Handle the case of inserting outside any loop
      if (visited_block_ && !loc_sref_.defined()) {
        loc_sref_ = self_->stmt2ref.at(block);
        // Handle the case of an input buffer
        if (loc_pos_ == -1) {
          loc_pos_ = 0;
        }
      }
    } else if (block_sref_->stmt == block) {
      visited_block_ = true;
    }
  }

  void VisitStmt_(const ForNode* loop) final {
    StmtVisitor::VisitStmt_(loop);
    if (visited_block_ && !loc_sref_.defined()) {
      loc_sref_ = self_->stmt2ref.at(loop);
      if (loc_pos_ == -1) {
        loc_pos_ = 0;
      }
    }
  }

 private:
  /*! \brief The state of the schedule */
  const ScheduleState self_;
  /*! \brief The dominant block which writes the buffer */
  const StmtSRef& block_sref_;
  /*! \brief The parent scope of the dominant block */
  const StmtSRef& scope_sref_;
  /*! \brief The flag indicating whether we have visited the target block */
  bool visited_block_{false};
  /*! \brief The AST node whose body is where the cache stage should be inserted */
  StmtSRef loc_sref_{nullptr};
  /*! \brief The index to insert the cache_read/cache_write stage */
  int loc_pos_{-1};
};

/*! \brief Mutator for CacheRead. */
class CacheReadRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Rewrite the AST and add a cache_read stage with the information provided
   * \param scope_sref The parent scope of this mutation
   * \param info The cache stage information
   * \return The new AST rooted at the original parent scope
   */
  static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info) {
    CacheReadRewriter rewriter(scope_sref, info);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info)
      : scope_sref_(scope_sref), info_(info) {}

  Stmt VisitStmt_(const ForNode* loop) final {
    Stmt stmt = StmtMutator::VisitStmt_(loop);
    // Check the insertion point
    if (loop == info_->loc_sref->stmt) {
      // Insert the cache stage into the loop if it is the right place
      ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
      stmt = Stmt(n);
    }
    return stmt;
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);
    // Check if this block is one of the specified consumers.
    // If no consumer blocks are specified, all blocks should be considered consumers.
    bool is_consumer = info_->consumer_blocks.empty();
    // Otherwise check if this is one of the specified blocks.
    for (StmtSRef consumer_sref : info_->consumer_blocks) {
      const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
      Block consumer_block = GetRef<Block>(consumer_node);
      if (old_stmt.same_as(consumer_block)) {
        is_consumer = true;
      }
    }
    // Keep track of this block's status. We'll use this when rewriting loads.
    current_block_consumes = is_consumer;
    // We don't mutate the block which generates info->read_buffer.
    if (block != scope_sref_->stmt &&
        GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
      return std::move(old_stmt);
    }
    // Mutate the body
    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
    // Check the insertion point
    if (block == info_->loc_sref->stmt) {
      // Insert the cache stage into the block if it is the right place
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
      stmt = Block(n);
    }
    // Check if it is the block corresponding to the parent scope
    if (block == scope_sref_->stmt) {
      // If so, put the buffer allocation in the parent scope
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      // In the cache_inplace case, the alloc_buffer may already exist.
      if (info_->alloc.defined()) {
        n->alloc_buffers.push_back(info_->alloc.value());
        stmt = Block(n);
      }
    } else {
      // Otherwise, update the read regions and match_buffers.
      // Only make this change if the block is one of the specified consumers.
      if (is_consumer) {
        Array<BufferRegion> reads =
            ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
        Array<MatchBufferRegion> match_buffers =
            ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
        if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
          ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
          n->reads = std::move(reads);
          n->match_buffers = std::move(match_buffers);
          stmt = Block(n);
        }
      }
    }
    info_->block_reuse.Set(old_stmt, stmt);
    return std::move(stmt);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
    if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) {
      ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
      n->buffer = info_->write_buffer;
      return PrimExpr(n);
    }
    return ExprMutator::VisitExpr_(load);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    if (op == info_->read_buffer->data.get()) {
      return info_->write_buffer->data;
    }
    return GetRef<PrimExpr>(op);
  }

 private:
  /*! \brief The parent scope of the insertion */
  const StmtSRef& scope_sref_;
  /*! \brief The info for inserting the cache stage */
  CacheStageInfo* info_;
  /*! \brief Whether the most recently visited block is a specified consumer. */
  bool current_block_consumes{false};
};

/*! \brief Mutator for CacheWrite */
class CacheWriteRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Rewrite the AST and add a cache_write stage with the information provided.
   * \param scope_sref The parent scope of this mutation.
   * \param writer_block_sref The only writer block in the scope.
   * \param info The cache stage information.
   * \return The new AST rooted at the original parent scope.
   */
  static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
                      CacheStageInfo* info) {
    CacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
                              CacheStageInfo* info)
      : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {}

  Stmt VisitStmt_(const ForNode* loop) final {
    Stmt stmt = StmtMutator::VisitStmt_(loop);
    // Check the insertion point
    if (loop == info_->loc_sref->stmt) {
      // Insert the cache stage into the loop if it is the right place
      ObjectPtr<ForNode> n = make_object<ForNode>(*stmt.as<ForNode>());
      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
      stmt = Stmt(n);
    }
    return stmt;
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);

    // Check if this block is one of the specified cache consumers.
    // If so, update its read buffers to read from the cache.
    for (StmtSRef consumer_sref : info_->consumer_blocks) {
      const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
      Block consumer_block = GetRef<Block>(consumer_node);
      if (old_stmt.same_as(consumer_block)) {
        Array<BufferRegion> reads =
            ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
        Array<MatchBufferRegion> match_buffers =
            ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
        if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
          auto n = CopyOnWrite(block);
          n->reads = std::move(reads);
          n->match_buffers = std::move(match_buffers);
          n->body = VisitStmt(block->body);
          Block new_consumer = Block(n);
          info_->block_reuse.Set(old_stmt, new_consumer);
          return std::move(new_consumer);
        }
        return std::move(old_stmt);
      }
    }

    // We only mutate the block which generates info->write_buffer
    if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) {
      return std::move(old_stmt);
    }

    // Mutate the body
    bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt;
    std::swap(under_scope, under_writer_block_);
    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
    std::swap(under_scope, under_writer_block_);

    // Find the insertion point
    if (block == info_->loc_sref->stmt) {
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
      stmt = Block(n);
    }
    // Put the buffer allocation in the parent scope
    if (block == scope_sref_->stmt) {
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      // In the cache_inplace case, the alloc_buffer may already exist.
      if (info_->alloc.defined()) {
        n->alloc_buffers.push_back(info_->alloc.value());
        stmt = Block(n);
      }
    } else {
      // Since cache_write changes the block, we need to update the buffers it accesses
      auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer);
      auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
      auto match_buffers =
          ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
      if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
          !match_buffers.same_as(block->match_buffers)) {
        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
        n->writes = std::move(writes);
        n->reads = std::move(reads);
        n->match_buffers = std::move(match_buffers);
        stmt = Block(n);
      }
    }
    info_->block_reuse.Set(old_stmt, stmt);
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BufferStoreNode* store) final {
    BufferStore stmt = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
    if (stmt->buffer.same_as(info_->write_buffer)) {
      auto n = CopyOnWrite(stmt.get());
      n->buffer = info_->read_buffer;
      return Stmt(n);
    } else {
      return std::move(stmt);
    }
  }

  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
    if (load->buffer.same_as(info_->write_buffer)) {
      ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
      n->buffer = info_->read_buffer;
      return PrimExpr(n);
    }
    return ExprMutator::VisitExpr_(load);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const VarNode* op) final {
    if (op == info_->write_buffer->data.get()) {
      return info_->read_buffer->data;
    }
    return GetRef<PrimExpr>(op);
  }

 private:
  /*! \brief The parent scope of the insertion. */
  const StmtSRef& scope_sref_;
  /*! \brief The sref of the only writer block in the scope. */
  const StmtSRef& writer_block_sref_;
  /*! \brief The info for inserting the cache stage. */
  CacheStageInfo* info_;
  /*! \brief Whether the current node is under the given block. */
  bool under_writer_block_{false};
};

/*!
 * \brief Create a new buffer, by changing the shape according to the block iters, to be used as
 * the reindex buffer
 * \param buffer The given buffer.
 * \param block_iters The block iters.
 * \param covered The set of block iter vars covered by the buffer access indices
 * \return The new buffer with the target shape.
 */
Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& block_iters,
                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered) {
  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
  std::vector<PrimExpr> new_shape;
  for (const auto& iter : block_iters) {
    if (covered.count(iter->var)) {
      new_shape.push_back(iter->dom->min + iter->dom->extent);
    }
  }
  new_buffer->shape = new_shape;
  new_buffer->strides = {};
  new_buffer->data = buffer->data.copy_with_suffix("_reindex");
  new_buffer->name = buffer->name + "_reindex";
  return Buffer(new_buffer);
}
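
// Illustrative example: if a block with iters (vi, vj, vk) accesses the buffer
// as B[vi, vj], then covered = {vi, vj} and the reindex buffer gets one
// dimension per covered iter (sized by that iter's dom->min + dom->extent),
// with empty strides so that the new buffer is compact.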

/*! \brief The schedule error that the target is not a leaf block. */
class NotLeafBlockError : public ScheduleError {
 public:
  NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {}
  String FastErrorString() const final {
    return "ScheduleError: The target block is not a leaf block.";
  }

  String DetailRenderTemplate() const final { return "The target block {0} is not a leaf block."; }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
  IRModule mod_;
  Block block_;
};

/*! \brief The schedule error that the buffer access is invalid for reindex. */
class InvalidBufferAccessError : public ScheduleError {
 public:
  enum class ErrorKind {
    kNoAccess,         // buffer access not found
    kNonUniqueAccess,  // multiple buffer accesses with different indices
    kOpaqueAccess,     // opaque access to the buffer
  };

  InvalidBufferAccessError(IRModule mod, Buffer buffer, Block block, ErrorKind kind)
      : mod_(std::move(mod)), buffer_(std::move(buffer)), block_(std::move(block)), kind_(kind) {}
  String FastErrorString() const final {
    return "ScheduleError: The target buffer should be accessed via BufferLoad or BufferStore. The "
           "indices should be the same if there are multiple accesses to the target buffer.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The target buffer " << buffer_->name
       << " should be accessed in the leaf block {0} via BufferLoad or BufferStore. The indices "
          "should be the same if there are multiple accesses to the target buffer. ";
    if (kind_ == ErrorKind::kNoAccess) {
      os << "No buffer accesses found.";
    } else if (kind_ == ErrorKind::kNonUniqueAccess) {
      os << "Multiple buffer accesses have non-unique indices.";
    } else if (kind_ == ErrorKind::kOpaqueAccess) {
      os << "Opaque buffer accesses found.";
    }
    return os.str();
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

 private:
  IRModule mod_;
  Buffer buffer_;
  Block block_;
  ErrorKind kind_;
};

/*! \brief Collect the related Load/Store to reindex */
class ReIndexCollector : public StmtExprVisitor {
 public:
  static Array<PrimExpr> Collect(const IRModule& mod, const Buffer& buffer, const Block& block) {
    ReIndexCollector collector(mod, buffer, block);
    collector(block->body);
    if (!collector.buffer_access_indices_.defined()) {
      throw InvalidBufferAccessError(mod, buffer, block,
                                     InvalidBufferAccessError::ErrorKind::kNoAccess);
    }
    return collector.buffer_access_indices_.value();
  }

 private:
  explicit ReIndexCollector(const IRModule& mod, const Buffer& buffer, const Block& block)
      : mod_(mod), buffer_(buffer), block_(block) {}

  void VisitExpr_(const BufferLoadNode* load) final {
    StmtExprVisitor::VisitExpr_(load);
    if (load->buffer.same_as(buffer_)) {
      CheckAndUpdateBufferAccessIndices(load->indices);
    }
  }

  void VisitStmt_(const BlockNode* block) final {
    // No sub-blocks are allowed under the target block
    throw NotLeafBlockError(mod_, block_);
  }

  void VisitStmt_(const BufferStoreNode* store) final {
    StmtExprVisitor::VisitStmt_(store);
    if (store->buffer.same_as(buffer_)) {
      CheckAndUpdateBufferAccessIndices(store->indices);
    }
  }

  void CheckAndUpdateBufferAccessIndices(const Array<PrimExpr>& indices) {
    if (!buffer_access_indices_.defined()) {
      buffer_access_indices_ = indices;
      return;
    } else if (!std::equal(buffer_access_indices_.value().begin(),
                           buffer_access_indices_.value().end(), indices.begin(), indices.end(),
                           ExprDeepEqual())) {
      throw InvalidBufferAccessError(mod_, buffer_, block_,
                                     InvalidBufferAccessError::ErrorKind::kNonUniqueAccess);
    }
  }

  void VisitExpr_(const VarNode* var) final {
    if (var == buffer_->data.get()) {
      throw InvalidBufferAccessError(mod_, buffer_, block_,
                                     InvalidBufferAccessError::ErrorKind::kOpaqueAccess);
    }
  }
  /*! \brief The IR module */
  IRModule mod_;
  /*! \brief The buffer to rewrite */
  Buffer buffer_;
  /*! \brief The block to visit */
  Block block_;
  /*! \brief The indices of the buffer access to rewrite */
  Optional<Array<PrimExpr>> buffer_access_indices_;
};

/*! \brief Mutator of ReIndex */
class ReIndexRewriter : public StmtExprMutator {
 public:
  static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info,
                      const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered) {
    ReIndexRewriter rewriter(block_sref, info, covered);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info,
                           const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered)
      : block_sref_(block_sref), info_(info), covered_(covered) {
    new_buffer_ = info->alloc.value();
    old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer : info->read_buffer;
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);
    if (is_scope_) {
      is_scope_ = false;
      Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
      // Insert the cache stage into the scope block
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage);
      n->alloc_buffers.push_back(info_->alloc.value());
      stmt = Block(n);
      info_->block_reuse.Set(old_stmt, stmt);
      return std::move(stmt);
    }

    // Visiting the block being reindexed
    if (block == block_sref_->stmt) {
      // Collect the updated indices and regions
      for (const IterVar& iter : block->iter_vars) {
        if (covered_.count(iter->var)) {
          indices_.push_back(iter->var);
          region_.push_back(Range::FromMinExtent(iter->var, IntImm(iter->var->dtype, 1)));
        }
      }
      Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
      // Update the block reads/writes to use the intermediate reindex buffer
      auto writes =
          ReplaceBufferRegion(block->writes, old_buffer_, BufferRegion{new_buffer_, region_});
      auto reads =
          ReplaceBufferRegion(block->reads, old_buffer_, BufferRegion{new_buffer_, region_});
      auto match_buffers = ReplaceBufferRegion(block->match_buffers, old_buffer_,
                                               BufferRegion{new_buffer_, region_});
      if (!writes.same_as(block->writes) || !reads.same_as(block->reads) ||
          !match_buffers.same_as(block->match_buffers)) {
        ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
        n->writes = std::move(writes);
        n->reads = std::move(reads);
        n->match_buffers = std::move(match_buffers);
        stmt = Block(n);
      }
      info_->block_reuse.Set(old_stmt, stmt);
      return std::move(stmt);
    }
    return std::move(old_stmt);
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (node->buffer.same_as(old_buffer_)) {
      auto* n = node.CopyOnWrite();
      n->buffer = new_buffer_;
      n->indices = indices_;
    }
    return node;
  }
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore buffer_store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(buffer_store));
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad buffer_load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(buffer_load));
  }

 private:
  /*! \brief The block to be reindexed */
  const StmtSRef& block_sref_;
  /*! \brief The info for inserting the reindex stage. */
  CacheStageInfo* info_;
  /*! \brief The set of block iter vars covered by the buffer access indices */
  const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& covered_;
  /*! \brief Whether the current block is the scope block */
  bool is_scope_{true};
  /*! \brief The buffer to be replaced */
  Buffer old_buffer_;
  /*! \brief The reindex buffer */
  Buffer new_buffer_;
  /*! \brief The new indices */
  Array<PrimExpr> indices_;
  /*! \brief The new region */
  Region region_;
};

void CheckRegionCover(const ScheduleState& self, StmtSRef scope_root, Buffer read_buffer) {
  class NotRegionCoverError : public ScheduleError {
   public:
    explicit NotRegionCoverError(IRModule mod, Block block) : mod_(mod), block_(block) {}
    IRModule mod() const final { return mod_; }
    String FastErrorString() const final {
      return "ScheduleError: The scope root's region cover is not complete.";
    }
    String DetailRenderTemplate() const final {
      return R"(The scope {0}'s region cover is not complete.
The region cover property is required to hold for every one of its child blocks.
)";
    }
    Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
    IRModule mod_;
    Block block_;
  };

  for (const auto& child_block_sref : tir::GetChildBlocks(self, scope_root)) {
    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
    for (const BufferRegion& region : child_block->reads) {
      if (region->buffer.same_as(read_buffer)) {
        if (!self->block_info.at(child_block_sref).region_cover) {
          const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
          throw NotRegionCoverError(self->mod, GetRef<Block>(block));
        }
      }
    }
  }
}

/******** Implementation ********/

StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
                   const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
  /*!
   * Check:
   * - The index is in the array of the block's read regions
   * - There is at most one block that writes the buffer in the scope
   *
   * Mutate:
   * - Allocate the new cache buffer under the current scope.
   * - Find the lowest ancestor of the block and ANY ONE of the consumer blocks.
   * - Copy the buffer with the consumed region.
   */

  // Step 0. Check the input storage scope.
  CheckStorageScope(self, storage_scope);

  // Step 1. Check the index, getting the target buffer and the parent scope
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  Buffer read_buffer =
      GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  // Check the required region cover for cache_read
  CheckRegionCover(self, scope_sref, read_buffer);
  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

  // Step 2. Create CacheStageInfo
  CacheStageInfo info;
  info.read_buffer = read_buffer;
  // Create the corresponding buffer to be written, i.e. the result of cache_read
  info.write_buffer = WithScope(read_buffer, storage_scope);
  // Create the corresponding buffer allocation
  info.alloc = info.write_buffer;

  // info.consumer_blocks indicates which blocks should consume the cache.
  for (auto consumer : consumer_blocks) {
    info.consumer_blocks.insert(consumer);
    for (auto child : tir::GetChildBlocks(self, consumer)) {
      info.consumer_blocks.insert(child);
    }
  }

  // Step 3. Update the cache stage info.
  BufferRegion cache_region{nullptr};
  if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
    // Case 1. The buffer is written inside the scope.
    StmtSRef write_block_sref = _write_block_sref.value();
    const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
    // Find the producing region
    BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
    StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);

    // Detect the insert position
    CacheLocDetector::Detect</*is_cache_read=*/true>(self, write_block_sref, scope_sref, &info);
    cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
  } else {
    // Case 2. The buffer is an input buffer of the scope.
    info.loc_sref = scope_sref;
    info.loc_pos = 0;
    if (Optional<BufferRegion> region =
            GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
      cache_region = region.value();
    } else {
      cache_region = BufferRegion::FullRegion(read_buffer);
    }
  }

  // Step 4. Make the new cache stage block and rewrite the readers.
  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
                                          /*storage_scope=*/storage_scope);
  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);

  // Step 5. Replace the scope and update the flags.
  self->Replace(scope_sref, new_scope, info.block_reuse);
  StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
  BlockInfo& block_info = self->block_info[result_block_sref];
  block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
  block_info.region_cover = true;
  block_info.scope->stage_pipeline = true;
  return result_block_sref;
}
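
// Illustrative effect of cache_read (TVMScript-style sketch, names hypothetical):
//   before:  C[vi] = A[vi] * 2.0
//   after:   A_shared[v0] = A[v0]        <- inserted cache stage
//            C[vi] = A_shared[vi] * 2.0  <- consumer rewritten to read the cache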

StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                    const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
  /*!
   * Check:
   * - The index is in the array of the block's write regions
   * - There is only one block that writes the buffer in the scope
   *
   * Mutate:
   * - Allocate the new cache buffer under the current scope.
   * - Find the lowest ancestor of the block and ANY ONE of the producer blocks.
   * - Copy the buffer with the produced region.
   */

  // Step 0. Check the input storage scope.
  CheckStorageScope(self, storage_scope);

  // Step 1. Check the index, getting the target buffer and the parent scope
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  Buffer write_buffer =
      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

  // Step 2. Create CacheStageInfo
  CacheStageInfo info;
  info.read_buffer = WithScope(write_buffer, storage_scope);
  // Create the corresponding buffer to be written, i.e. the result of cache_write
  info.write_buffer = write_buffer;
  // Create the corresponding buffer allocation
  info.alloc = info.read_buffer;

  // info.consumer_blocks indicates which blocks should consume the cache.
  for (auto consumer : consumer_blocks) {
    info.consumer_blocks.insert(consumer);
    for (auto child : tir::GetChildBlocks(self, consumer)) {
      info.consumer_blocks.insert(child);
    }
  }

  // Step 3. Check that the block is the only writer of the buffer.
  ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());

  // Step 4. Find the producing region and the insert position
  BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
  StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
  // Detect the insert position
  CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref, scope_sref, &info);
  BufferRegion cache_region =
      RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);

  // Step 5. Make the new cache stage block and rewrite the readers.
  Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
                                           /*storage_scope=*/storage_scope);
  Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref,
                                               /*writer_block_sref=*/block_sref, /*info=*/&info);

  // Step 6. Replace the scope and update the flags.
  self->Replace(scope_sref, new_scope, info.block_reuse);
  StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get());
  BlockInfo& block_info = self->block_info[result_block_sref];
  block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
  block_info.region_cover = true;
  block_info.scope->stage_pipeline = true;
  return result_block_sref;
}
1299 | |
/*! \brief The schedule error raised when the target block does not both read and write the
 * target buffer. */
1301 | class NotReadWriteError : public ScheduleError { |
1302 | public: |
1303 | NotReadWriteError(IRModule mod, Block block, Buffer buffer) |
1304 | : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {} |
  String FastErrorString() const final {
    return "ScheduleError: The target block does not both read & write target buffer.";
  }

  String DetailRenderTemplate() const final {
    return "The target block {0} does not both read & write target buffer {1}.";
  }
1312 | |
1313 | IRModule mod() const final { return mod_; } |
  Array<ObjectRef> LocationsOfInterest() const final { return {block_, buffer_}; }

 private:
  IRModule mod_;
1316 | Block block_; |
1317 | Buffer buffer_; |
1318 | }; |
1319 | |
1320 | Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, |
1321 | const String& storage_scope) { |
1322 | /*! |
1323 | * Do cache read then cache write |
1324 | */ |
1325 | |
1326 | // Check 0. Check the input storage scope. |
1327 | CheckStorageScope(self, storage_scope); |
1328 | |
1329 | // Check 1. Check index, get the target buffer and the parent scope |
1330 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
1331 | Buffer buffer = |
1332 | GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead); |
1333 | StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
1334 | |
  // Check 2. Check the required region cover for cache_read
1336 | CheckRegionCover(self, scope_sref, buffer); |
1337 | |
  // Check 3. Check that the target block both reads & writes the target buffer.
  Optional<BufferRegion> read_region = GetBufferRegionFromBuffer(block->reads, buffer);
  Optional<BufferRegion> write_region = GetBufferRegionFromBuffer(block->writes, buffer);
  if (!read_region.defined() || !write_region.defined()) {
    throw NotReadWriteError(self->mod, GetRef<Block>(block), buffer);
  }
1344 | } |
1345 | |
1346 | Array<StmtSRef> results_block_sref; |
1347 | Buffer new_buffer = WithScope(buffer, storage_scope); |
1348 | |
1349 | // Do cache read |
1350 | // Cache read step 0. Create CacheStageInfo |
1351 | CacheStageInfo info; |
1352 | info.read_buffer = buffer; |
1353 | // Create the corresponding buffer to be written for cache_read |
1354 | info.write_buffer = new_buffer; |
1355 | // Create the corresponding buffer allocation |
1356 | info.alloc = info.write_buffer; |
  // Indicate which blocks should consume the cache.
1358 | info.consumer_blocks.insert(block_sref); |
1359 | |
1360 | // Cache read step 1. Detect insert position |
1361 | CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info); |
1362 | |
  // Cache read step 2. Making the new cache stage block and rewriting readers.
1364 | Block cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info, |
1365 | /*storage_scope=*/storage_scope); |
1366 | Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); |
1367 | |
1368 | // Cache read step 3. Replacing and updating flags for cache read. |
1369 | self->Replace(scope_sref, new_scope, info.block_reuse); |
1370 | StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); |
1371 | BlockInfo& block_info_read = self->block_info[result_block_sref]; |
1372 | block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref); |
1373 | block_info_read.region_cover = true; |
1374 | block_info_read.scope->stage_pipeline = false; |
1375 | results_block_sref.push_back(result_block_sref); |
1376 | |
1377 | // Do cache write |
  // Cache write step 0. Update the cache stage info for cache_write.
  info.read_buffer = new_buffer;
  // Create the corresponding buffer to be written, i.e. result of cache_write
  info.write_buffer = buffer;
  // The original buffer is already allocated in the scope, so no new allocation is needed.
  info.alloc = NullOpt;
1384 | info.consumer_blocks.clear(); |
1385 | |
1386 | // Cache write step 1. Detect insert position |
1387 | CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info); |
  // Insert after the target block for cache write
  info.loc_pos += 1;
1390 | |
  // Cache write step 2. Making the new cache stage block and rewriting the writer.
1392 | Block cache_write_stage = MakeCacheStage(/*cache_region=*/write_region.value(), /*info=*/&info, |
1393 | /*storage_scope=*/storage_scope); |
1394 | new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, |
1395 | /*writer_block_sref=*/block_sref, /*info=*/&info); |
1396 | |
  // Cache write step 3. Replacing and updating flags for cache write.
1398 | self->Replace(scope_sref, new_scope, info.block_reuse); |
1399 | result_block_sref = self->stmt2ref.at(cache_write_stage.get()); |
1400 | BlockInfo& block_info_write = self->block_info[result_block_sref]; |
1401 | block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref); |
1402 | block_info_write.region_cover = true; |
1403 | block_info_write.scope->stage_pipeline = false; |
1404 | results_block_sref.push_back(result_block_sref); |
1405 | |
1406 | return results_block_sref; |
1407 | } |
1408 | |
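// ReIndex stages the chosen access of a block through a new "reindex" buffer whose
// shape is derived from the block iters that appear in the access. A minimal sketch
// of the Python-level usage (names are hypothetical):
//
//   reindexed = sch.reindex(blk, ("write", 0))
//   # `blk` now writes the reindex buffer; the returned stage copies it back to
//   # the original buffer using the original indices.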
1409 | StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
1410 | BufferIndexType buffer_index_type) { |
1411 | const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); |
1412 | Block block = GetRef<Block>(block_ptr); |
1413 | Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type); |
1414 | StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); |
1415 | arith::Analyzer analyzer; |
1416 | |
  // Step 1. Collect the original indices, and check that the related loads/stores share a
  // single access pattern and that the buffer is not accessed opaquely
1419 | Array<PrimExpr> original_indices = ReIndexCollector::Collect(self->mod, buffer, block); |
1420 | // Simplify the indices if possible |
1421 | for (const IterVar& iter : block->iter_vars) { |
1422 | analyzer.Bind(iter->var, iter->dom); |
1423 | } |
1424 | original_indices.MutateByApply( |
1425 | [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); |
1426 | |
1427 | // Collect block iters appearing in the original_indices |
1428 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> covered; |
1429 | for (const PrimExpr& index : original_indices) { |
1430 | PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { |
1431 | if (const VarNode* var = obj.as<VarNode>()) { |
1432 | covered.insert(GetRef<Var>(var)); |
1433 | } |
1434 | return true; |
1435 | }); |
1436 | } |
1437 | |
1438 | // Step 2. Creating CacheStageInfo |
1439 | CacheStageInfo info; |
  // Create the reindex buffer: for a write reindex it is read back into the original
  // buffer; for a read reindex it is written from the original buffer
1441 | if (buffer_index_type == BufferIndexType::kWrite) { |
1442 | info.read_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); |
1443 | info.write_buffer = buffer; |
1444 | info.alloc = info.read_buffer; |
1445 | } else { |
1446 | info.read_buffer = buffer; |
1447 | info.write_buffer = CreateReindexBuffer(buffer, block->iter_vars, covered); |
1448 | info.alloc = info.write_buffer; |
1449 | } |
1450 | |
  // Step 3. Check that the block sits under a single chain of loops in the scope
  // (each loop's body is exactly the next loop), and get the insert location
1453 | const StmtSRefNode* loop; |
1454 | for (loop = block_sref->parent; loop->parent != scope_sref.get();) { |
1455 | const ForNode* outer = loop->parent->StmtAs<ForNode>(); |
1456 | const ForNode* inner = loop->StmtAs<ForNode>(); |
1457 | ICHECK(outer != nullptr && inner != nullptr); |
1458 | ICHECK(outer->body.get() == inner); |
1459 | loop = loop->parent; |
1460 | } |
1461 | |
1462 | info.loc_pos = loop->seq_index == -1 ? 0 : loop->seq_index; |
1463 | if (buffer_index_type == BufferIndexType::kWrite) { |
1464 | info.loc_pos++; |
1465 | } |
1466 | |
1467 | // Step 4. Making new reindex stage block and rewrite |
1468 | Block reindex_stage = |
1469 | MakeReIndexStage(block, &info, covered, original_indices, buffer_index, buffer_index_type); |
1470 | Stmt new_scope = ReIndexRewriter::Rewrite(scope_sref, block_sref, &info, covered); |
1471 | |
1472 | // Step 5. Replacing and updating flags |
1473 | self->Replace(scope_sref, new_scope, info.block_reuse); |
1474 | StmtSRef result_block_sref = self->stmt2ref.at(reindex_stage.get()); |
1475 | BlockInfo& block_info = self->block_info[result_block_sref]; |
1476 | block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); |
1477 | block_info.region_cover = true; |
1478 | block_info.scope->stage_pipeline = true; |
1479 | return result_block_sref; |
1480 | } |
1481 | |
1482 | /******** Instruction Registration ********/ |
1483 | |
1484 | struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> { |
1485 | static constexpr const char* kName = "CacheRead" ; |
1486 | static constexpr bool kIsPure = false; |
1487 | |
1488 | private: |
1489 | static constexpr size_t kNumInputs = 2; |
1490 | static constexpr size_t kNumAttrs = 2; |
1491 | static constexpr size_t kNumDecisions = 0; |
1492 | |
1493 | static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, |
1494 | Array<BlockRV> consumer_blocks, Integer read_buffer_index, |
1495 | String storage_scope) { |
1496 | return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks); |
1497 | } |
1498 | |
1499 | static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks, |
1500 | Integer read_buffer_index, String storage_scope) { |
1501 | PythonAPICall py("cache_read" ); |
1502 | py.Input("block" , block); |
1503 | py.Input("read_buffer_index" , read_buffer_index->value); |
1504 | py.Input("storage_scope" , storage_scope); |
1505 | // Only write out consumer blocks if provided. |
1506 | if (!consumer_blocks.empty()) { |
1507 | py.Input("consumer_blocks" , consumer_blocks); |
1508 | } |
1509 | py.SingleOutput(outputs); |
1510 | return py.Str(); |
1511 | } |
1512 | |
1513 | template <typename> |
1514 | friend struct ::tvm::tir::UnpackedInstTraits; |
1515 | }; |
1516 | |
1517 | struct CacheWriteTraits : public UnpackedInstTraits<CacheWriteTraits> { |
1518 | static constexpr const char* kName = "CacheWrite" ; |
1519 | static constexpr bool kIsPure = false; |
1520 | |
1521 | private: |
1522 | static constexpr size_t kNumInputs = 2; |
1523 | static constexpr size_t kNumAttrs = 2; |
1524 | static constexpr size_t kNumDecisions = 0; |
1525 | |
1526 | static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, |
1527 | Array<BlockRV> consumer_blocks, Integer write_buffer_index, |
1528 | String storage_scope) { |
1529 | return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks); |
1530 | } |
1531 | |
1532 | static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks, |
1533 | Integer write_buffer_index, String storage_scope) { |
1534 | PythonAPICall py("cache_write" ); |
1535 | py.Input("block" , block); |
1536 | py.Input("write_buffer_index" , write_buffer_index->value); |
1537 | py.Input("storage_scope" , storage_scope); |
1538 | // Only write out consumer blocks if provided. |
1539 | if (!consumer_blocks.empty()) { |
1540 | py.Input("consumer_blocks" , consumer_blocks); |
1541 | } |
1542 | py.SingleOutput(outputs); |
1543 | return py.Str(); |
1544 | } |
1545 | |
1546 | template <typename> |
1547 | friend struct ::tvm::tir::UnpackedInstTraits; |
1548 | }; |
1549 | |
1550 | struct CacheInplaceTraits : public UnpackedInstTraits<CacheInplaceTraits> { |
1551 | static constexpr const char* kName = "CacheInplace" ; |
1552 | static constexpr bool kIsPure = false; |
1553 | |
1554 | private: |
1555 | static constexpr size_t kNumInputs = 1; |
1556 | static constexpr size_t kNumAttrs = 2; |
1557 | static constexpr size_t kNumDecisions = 0; |
1558 | |
1559 | static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, |
1560 | Integer read_buffer_index, String storage_scope) { |
1561 | return sch->CacheInplace(block, read_buffer_index->value, storage_scope); |
1562 | } |
1563 | |
1564 | static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index, |
1565 | String storage_scope) { |
1566 | PythonAPICall py("cache_inplace" ); |
1567 | py.Input("block" , block); |
1568 | py.Input("read_buffer_index" , read_buffer_index->value); |
1569 | py.Input("storage_scope" , storage_scope); |
1570 | py.OutputList(outputs); |
1571 | return py.Str(); |
1572 | } |
1573 | |
1574 | template <typename> |
1575 | friend struct ::tvm::tir::UnpackedInstTraits; |
1576 | }; |
1577 | |
struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
  static constexpr const char* kName = "ReIndex";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 2;
  static constexpr size_t kNumDecisions = 0;

  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
                                         Integer buffer_index_type) {
    return sch->ReIndex(block, buffer_index.IntValue(),
                        static_cast<BufferIndexType>(buffer_index_type->value));
  }

  static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
                                 Integer buffer_index_type) {
    PythonAPICall py("reindex");
    py.Input("block", block);
    std::ostringstream os;
    os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
       << "\", " << buffer_index << ")";
    py.Input("buffer", os.str());
    py.SingleOutput(outputs);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};
1608 | |
1609 | TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); |
1610 | TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); |
1611 | TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); |
1612 | TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); |
1613 | } // namespace tir |
1614 | } // namespace tvm |
1615 | |