/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/arith/int_set.h>

#include "./utils.h"

namespace tvm {
namespace tir {

template <class K, class V>
using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;

/**************** Utility functions ****************/

/*!
 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
 * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added
 * to the result.
 * \param region The buffer region to be analyzed
 * \param predicate The predicate of the block realize that touches the buffer region
 * \param dom_low_inclusive The lowest node in the sref tree path
 * \param dom_high_exclusive The highest node in the sref tree path
 * \param analyzer The analyzer
 * \return An n-dimensional integer set
 */
Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region,          //
                                             const PrimExpr& predicate,           //
                                             const StmtSRef& dom_low_inclusive,   //
                                             const StmtSRef& dom_high_exclusive,  //
                                             arith::Analyzer* analyzer) {
  Map<Var, Range> var_dom = LoopDomainOfSRefTreePath(
      /*low_inclusive=*/dom_low_inclusive,
      /*high_exclusive=*/dom_high_exclusive,
      /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()));
  return EstimateRegionUpperBound(
      /*region=*/region->region,
      /*var_dom=*/var_dom,
      /*predicate=*/predicate, /*analyzer=*/analyzer);
}

/*!
 * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
 * Some subregion may be discarded during the lower-bound analysis.
 * \param region The buffer region to be analyzed
 * \param predicate The predicate of the block realize that touches the buffer region
 * \param dom_low_inclusive The lowest node in the sref tree path
 * \param dom_high_exclusive The highest node in the sref tree path
 * \param analyzer The analyzer
 * \return An n-dimensional integer set
 */
Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region,          //
                                             const PrimExpr& predicate,           //
                                             const StmtSRef& dom_low_inclusive,   //
                                             const StmtSRef& dom_high_exclusive,  //
                                             arith::Analyzer* analyzer) {
  Map<Var, Range> var_dom = LoopDomainOfSRefTreePath(
      /*low_inclusive=*/dom_low_inclusive,
      /*high_exclusive=*/dom_high_exclusive,
      /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope()));
  if (Optional<Array<arith::IntSet>> result = EstimateRegionLowerBound(
          /*region=*/region->region,
          /*var_dom=*/var_dom,
          /*predicate=*/predicate, /*analyzer=*/analyzer)) {
    return result.value();
  }
  return Array<arith::IntSet>(region->buffer->shape.size(), arith::IntSet::Nothing());
}
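
// Illustrative sketch (hypothetical variable names, shown only as documentation): a caller that
// wants both bounds for one read region of a consumer block realize could combine the two
// helpers above as follows. `read_region`, `realize`, `parent_sref` and `lca_sref` are
// assumptions made for the sake of the example, not names used elsewhere in this file.
//
//   arith::Analyzer analyzer;
//   Array<arith::IntSet> upper = AnalyzeRegionUpperBound(read_region, realize->predicate,
//                                                        parent_sref, lca_sref, &analyzer);
//   Array<arith::IntSet> lower = AnalyzeRegionLowerBound(read_region, realize->predicate,
//                                                        parent_sref, lca_sref, &analyzer);
//
// Per dimension, `upper[i]` may over-approximate the touched interval, while `lower[i]` may
// under-approximate it, or degrade to `IntSet::Nothing()` when no exact subset can be proven.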

/*!
 * \brief Checks if the produced region can cover the consumed region
 * \param buffer_shape The shape of the buffer
 * \param produced_region The N-dimensional produced region
 * \param consumed_region The N-dimensional consumed region
 * \param analyzer The analyzer
 * \return A boolean indicating if the produced region could cover the consumed region
 */
bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
                            const Array<arith::IntSet>& produced_region,
                            const Array<arith::IntSet>& consumed_region,
                            arith::Analyzer* analyzer) {
  ICHECK_EQ(buffer_shape.size(), consumed_region.size());
  ICHECK_EQ(produced_region.size(), consumed_region.size());
  int ndim = produced_region.size();
  for (int i = 0; i < ndim; ++i) {
    arith::IntSet buffer_size = arith::IntSet::FromMinExtent(0, buffer_shape[i]);
    if (produced_region[i].IsNothing()) {
      return false;
    }
    arith::IntSet produced =
        arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()),
                                analyzer->canonical_simplify(produced_region[i].max()));
    arith::IntSet consumed =
        arith::IntSet::Interval(analyzer->canonical_simplify(consumed_region[i].min()),
                                analyzer->canonical_simplify(consumed_region[i].max()));
    produced = arith::Intersect({produced, buffer_size});
    consumed = arith::Intersect({consumed, buffer_size});

    produced = arith::IntSet::Interval(analyzer->Simplify(produced.min()),
                                       analyzer->Simplify(produced.max()));
    consumed = arith::IntSet::Interval(analyzer->Simplify(consumed.min()),
                                       analyzer->Simplify(consumed.max()));

    if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) &&
                            (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) {
      return false;
    }
  }
  return true;
}
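
// Illustrative sketch (hypothetical values): for a buffer of shape [128], a produced interval
// [0, 99] covers a consumed interval [16, 63], because produced.min <= consumed.min and
// consumed.max <= produced.max after both intervals are clipped to [0, 127]; a produced interval
// of [0, 39] would not, and the check above would then return false for that dimension.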

/*!
 * \brief Set the `StmtSRefNode::seq_index` field for stmt
 * \param self The schedule class
 * \param stmt The statement, or the realize node of the statement whose sref to be set
 * \param seq_index The seq_index to be set
 * \note The method is NOP for statements that are not schedulable, i.e. not For or Block
 */
void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) {
  if (const auto* realize = stmt.as<BlockRealizeNode>()) {
    const BlockNode* block = realize->block.get();
    ICHECK(self->stmt2ref.count(block));
    self->stmt2ref.at(block)->seq_index = seq_index;
  } else if (const auto* block = stmt.as<BlockNode>()) {
    ICHECK(self->stmt2ref.count(block));
    self->stmt2ref.at(block)->seq_index = seq_index;
  } else if (const auto* loop = stmt.as<ForNode>()) {
    ICHECK(self->stmt2ref.count(loop));
    self->stmt2ref.at(loop)->seq_index = seq_index;
  } else {
    // do nothing
  }
}

/*!
 * \brief Update seq_index of the children of a SeqStmt
 * \param self The schedule class
 * \param seq_stmt The SeqStmt whose children need updating
 */
void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) {
  int i = 0;
  for (const Stmt& stmt : seq_stmt->seq) {
    SetSeqIndex(self, stmt, i);
    ++i;
  }
}

/*!
 * \brief Update the sref information on the schedule class, as well as the statement of sref itself
 * More specifically, update
 *  `sref->stmt` to `new_stmt`
 *  `self->stmt2ref`, remove the old statement that sref points to, and add the new statement
 * \param self The schedule class to be updated
 * \param sref The sref to be updated
 * \param new_stmt The statement that replaces the statement inside the sref
 */
void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
  ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>());
  const StmtNode* old_stmt = sref->stmt;
  ICHECK_NE(new_stmt, old_stmt);
  self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref);
  self->stmt2ref.erase(sref->stmt);
  sref->stmt = new_stmt;
}

/**************** Creation ****************/

/*! \brief A helper class to update BlockInfo for a ScheduleStateNode */
class BlockInfoCollector : private StmtVisitor {
 public:
  static void Collect(ScheduleStateNode* self, const Stmt& stmt) {
    BlockInfoCollector collector(self);
    collector.VisitStmt(stmt);
  }

 private:
  explicit BlockInfoCollector(ScheduleStateNode* self)
      : self_(self), srefs_{}, block2realize_{}, block_frames_{} {
    block_frames_.emplace_back();
  }

  /*!
   * \brief Push the sref of a statement onto the stack, making it the current scope
   * \param stmt A for-loop statement or a block statement whose sref has already been created
   */
  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }

  /*! \brief Pop the top of the scope */
  StmtSRef PopSRef() {
    StmtSRef sref = srefs_.back();
    srefs_.pop_back();
    return sref;
  }

  void MakeBlockInfo(StmtSRef scope_root) {
    bool is_root_block = srefs_.empty();
    // Calculate `BlockInfo::scope`
    Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
    BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs));
    // Set `affine_binding`
    if (is_root_block) {
      // If the block doesn't have outer loops and BlockRealize,
      // then we set the affine binding flag as true only if the block has no block vars
      const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
      if (block->iter_vars.empty()) info.affine_binding = true;
    } else {
      info.affine_binding =
          IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt),
                          /*loop_var_ranges=*/LoopDomainOfSRefTreePath(srefs_.back()),
                          /*analyzer=*/&analyzer_);
    }
    // Set `region_cover` to true, will be updated on its scope block
    info.region_cover = true;
    // Set `stage_pipeline` and `region_cover` for its intermediate children
    info.scope->stage_pipeline =
        CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs);
  }

  bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root,
                                        const Array<StmtSRef>& child_block_srefs) {
    const StmtSRefNode* limit = scope_root->parent;
    bool stage_pipeline = true;
    // Step 1. Unbind the read/write regions of each child block
    std::unordered_map<const StmtSRefNode*, Array<BufferRegion>> block_reads_unbound;
    std::unordered_map<const StmtSRefNode*, Array<BufferRegion>> block_writes_unbound;
    block_reads_unbound.reserve(child_block_srefs.size());
    block_writes_unbound.reserve(child_block_srefs.size());
    for (const StmtSRef& block_sref : child_block_srefs) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
      Map<Var, PrimExpr> binding = GetBindings(block2realize_.at(block));
      // Step 1.1. Unbind read regions
      Array<BufferRegion> reads;
      reads.reserve(block->reads.size());
      for (const BufferRegion& region : block->reads) {
        reads.push_back(BufferRegion(region->buffer, Substitute(region->region, binding)));
      }
      block_reads_unbound.emplace(block_sref.get(), std::move(reads));
      // Step 1.2. Unbind write regions
      Array<BufferRegion> writes;
      writes.reserve(block->writes.size());
      for (const BufferRegion& region : block->writes) {
        writes.push_back(BufferRegion(region->buffer, Substitute(region->region, binding)));
      }
      block_writes_unbound.emplace(block_sref.get(), std::move(writes));
    }
    // Step 2. For each consumer, check the region cover property
    for (const auto& kv : info.scope->dst2deps) {
      const StmtSRef& consumer_block_sref = kv.first;
      const Array<Dependency>& deps = kv.second;
      const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
      const BlockRealize& consumer_realize = block2realize_.at(consumer_block);
      bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true;
      // Step 2.1. Extract the path to the scope root
      std::unordered_map<const StmtSRefNode*, std::vector<const StmtSRefNode*>> lca_loc;
      for (const StmtSRefNode* p = consumer_block_sref.get(); p != limit; p = p->parent) {
        ICHECK(p != nullptr);
        lca_loc[p] = {};
      }
      // Step 2.2. For each producer, find the LCA of the consumer
      for (const Dependency& dep : deps) {
        if (dep->kind == DepKind::kWAR || dep->kind == DepKind::kOpaque) {
          stage_pipeline = false;
        }
        // Only care about producer-consumer relationship
        if (dep->kind != DepKind::kRAW) {
          continue;
        }
        const StmtSRef& producer = dep->src;
        for (const StmtSRefNode* p = producer.get();; p = p->parent) {
          ICHECK(p != nullptr);
          auto it = lca_loc.find(p);
          // Find the first (lowest) position in the ancestor of the consumer,
          // which is the LCA by definition
          if (it != lca_loc.end()) {
            it->second.push_back(producer.get());
            break;
          }
        }
      }
      // Step 2.3. For each LCA, gather the produced regions,
      // then check if it could cover the consumed region
      for (StmtSRef lca = consumer_block_sref; region_cover && lca.get() != limit;
           lca = GetRef<StmtSRef>(lca->parent)) {
        const std::vector<const StmtSRefNode*>& producer_block_srefs = lca_loc.at(lca.get());
        // Skip empty LCA positions
        if (producer_block_srefs.empty()) {
          continue;
        }
        // For each buffer, record the regions generated under this loop
        std::unordered_map<const BufferNode*, std::vector<Array<arith::IntSet>>> touched_regions;
        // Step 2.3.1. Find all the regions read by the consumer that we care about
        for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) {
          const BufferNode* buffer = region->buffer.get();
          touched_regions[buffer] = {};
        }
        // Step 2.3.2. Find all the regions written by each producer
        for (const StmtSRefNode* producer_block_sref : producer_block_srefs) {
          const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt);
          StmtSRef parent_sref = GetRef<StmtSRef>(producer_block_sref->parent);
          for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) {
            const BufferNode* buffer = region->buffer.get();
            auto it = touched_regions.find(buffer);
            // Skip the regions that are not read by the consumer
            if (it != touched_regions.end()) {
              std::vector<Array<arith::IntSet>>& touched_region = it->second;
              // The analysis here is intentionally conservative in order to rule out false
              // positives and to make sure the region cover property holds whenever the flag is
              // on. Therefore, we use lower-bound analysis for producers and upper-bound analysis
              // for the consumer, and require that the produced region covers the consumed region
              touched_region.push_back(AnalyzeRegionLowerBound(
                  /*region=*/region,
                  /*predicate=*/producer_realize->predicate,
                  /*dom_low_inclusive=*/parent_sref,
                  /*dom_high_exclusive=*/lca,
                  /*analyzer=*/&analyzer_));
            }
          }
        }
        // Step 2.3.3. For each buffer, check the region cover property
        {
          StmtSRef parent_sref = GetRef<StmtSRef>(consumer_block_sref->parent);
          for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) {
            const BufferNode* buffer = region->buffer.get();
            const std::vector<Array<arith::IntSet>>& touched_region = touched_regions.at(buffer);
            if (!touched_region.empty()) {
              Array<arith::IntSet> produced_region =
                  arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()});
              Array<arith::IntSet> consumed_region = AnalyzeRegionUpperBound(
                  /*region=*/region,
                  /*predicate=*/consumer_realize->predicate,
                  /*dom_low_inclusive=*/parent_sref,
                  /*dom_high_exclusive=*/lca,
                  /*analyzer=*/&analyzer_);
              if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region,
                                          &analyzer_)) {
                region_cover = false;
                self_->block_info.at(consumer_block_sref).region_cover = region_cover;
                break;
              }
            }
          }
        }
      }
      stage_pipeline = stage_pipeline && region_cover;
    }
    return stage_pipeline;
  }

  void VisitStmt_(const ForNode* loop) final {
    analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    PushSRef(loop);
    VisitStmt(loop->body);
    PopSRef();
  }

  void VisitStmt_(const BlockRealizeNode* realize) final {
    block_frames_.emplace_back();
    const BlockNode* block = realize->block.get();
    block2realize_.emplace(block, GetRef<BlockRealize>(realize));
    // Recursive visit
    PushSRef(block);
    VisitStmt(block->body);  // `block->init` is not visited
    StmtSRef sref = PopSRef();
    // Create BlockInfo for the block
    MakeBlockInfo(sref);
    // Update parent scope
    block_frames_.pop_back();
    block_frames_.back().push_back(sref);
  }

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    // Set `seq_index` information for SeqStmtNode
    StmtVisitor::VisitStmt_(seq_stmt);
    SetSeqIndexInChildren(self_, seq_stmt);
  }

  /*! \brief The ScheduleStateNode we are operating on */
  ScheduleStateNode* self_;
  /*! \brief The stack frame used to indicate the current scope */
  std::vector<StmtSRef> srefs_;
  /*! \brief The BlockRealize corresponding to blocks */
  std::unordered_map<const StmtNode*, BlockRealize> block2realize_;
  /*! \brief The stack frames of blocks in the DFS visit. */
  std::vector<Array<StmtSRef>> block_frames_;
  /*! \brief The auxiliary analyzer */
  arith::Analyzer analyzer_;
};

/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
class StateCreator : private StmtVisitor {
 public:
  /*!
   * \brief The entry function
   * \param mod The IRModule to be scheduled
   * \param debug_mask The debug flags to be carried by the schedule state
   * \return The created ScheduleStateNode, with `mod`, `debug_mask`, `stmt2ref` and
   *         `block_info` filled in
   */
  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
    ScheduleStateNode* self = n.get();
    // Set `n->mod`
    n->mod = std::move(mod);
    // Set `n->debug_mask`
    n->debug_mask = debug_mask;
    // Set `n->stmt2ref` and `n->block_info`
    StateCreator creator(self);
    for (const auto& kv : n->mod->functions) {
      const BaseFunc& base_func = kv.second;
      if (const auto* func = base_func.as<PrimFuncNode>()) {
        VerifyWellFormed(GetRef<PrimFunc>(func));
        creator.VisitStmt(func->body);
        BlockInfoCollector::Collect(self, func->body);
      }
    }
    return n;
  }

 private:
  explicit StateCreator(ScheduleStateNode* self) : self_(self) {}

  /*!
   * \brief Push a new sref for a statement onto the stack, making it the current scope
   * \param stmt A for-loop statement or a block statement
   */
  void PushSRef(const StmtNode* stmt) {
    if (srefs_.empty()) {
      srefs_.push_back(
          StmtSRef(stmt,
                   /*parent=*/nullptr,
                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
    } else {
      StmtSRefNode* parent = srefs_.back().get();
      srefs_.push_back(
          StmtSRef(stmt, parent,
                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
    }
  }

  /*! \brief Pop the top of the scope and record it in stmt2ref map */
  void PopAndRecordSRef() {
    StmtSRef sref = std::move(srefs_.back());
    self_->stmt2ref[sref->stmt] = sref;
    srefs_.pop_back();
  }

  void VisitStmt_(const ForNode* loop) final {
    PushSRef(loop);
    VisitStmt(loop->body);
    PopAndRecordSRef();
  }

  void VisitStmt_(const BlockRealizeNode* realize) final {
    const BlockNode* block = realize->block.get();
    PushSRef(block);
    VisitStmt(block->body);  // `block->init` is not visited
    PopAndRecordSRef();
  }

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    // Set `seq_index` information for SeqStmtNode
    StmtVisitor::VisitStmt_(seq_stmt);
    SetSeqIndexInChildren(self_, seq_stmt);
  }

  /*! \brief The result ScheduleStateNode */
  ScheduleStateNode* self_;
  /*! \brief The stack frame used to indicate the current scope */
  std::vector<StmtSRef> srefs_;
};

/**************** Constructor ****************/

ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
  CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported";
  data_ = StateCreator::Create(mod, debug_mask);
}
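
// Illustrative sketch (hypothetical variables, shown only as documentation): a typical caller
// constructs the state once per IRModule and then navigates the sref tree through `stmt2ref`.
// `mod` and `some_block_node` are assumptions made for the sake of the example.
//
//   ScheduleState state(mod, /*debug_mask=*/0);
//   const StmtSRef& block_sref = state->stmt2ref.at(some_block_node);
//   BlockInfo info = state->GetBlockInfo(block_sref);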

/**************** Replace ****************/

/*
 * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST with a
 * new subtree `tgt_stmt`, and to maintain the corresponding sref tree accordingly, reusing some
 * srefs so that the srefs held by users don't expire. For example, if we split a loop into 2, and
 * the original loop has a child block, then the sref to the child block should be reused, so that
 * users won't have to acquire that sref again.
 *
 * The workflow of the replacement algorithm is:
 * 1) Detect all possible reuses in class ReuseInfo
 * 2) Remove the expired srefs in class SRefTreePruner
 * 3) Update the reused srefs, and create the srefs for new statements, in class SRefUpdater
 * 4) Renew the ancestors of `src_stmt` to reflect the replacement
 */
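
// Illustrative sketch (hypothetical primitive code, not implemented in this file): a split
// primitive that rewrites the subtree rooted at `loop_sref` into a two-level loop nest would
// typically call something like
//
//   self->Replace(loop_sref, new_loop_nest, /*block_sref_reuse=*/{{old_block, new_block}});
//
// where `new_loop_nest` contains the rewritten block `new_block`, and the reuse map tells the
// machinery below to keep the sref of `old_block` alive and repoint it to `new_block`.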

/*!
 * \brief Record the different sref reuse types in the replacement
 *
 * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`,
 * which, given the immutability of the IR, means the entire subtree is unchanged,
 * and we do not need to recurse into the subtree.
 *
 * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
 * which are both loops or both blocks,
 * there is correspondence between them,
 * which lets us reuse the sref pointing to `src` and change it to point to `tgt`.
 *
 * \note The intact reuse and loop sref reuse are collected in the ReuseCollector,
 * while the block reuse is specified by the caller.
 *
 * \sa ReuseCollector
 */
struct ReuseInfo {
  /*!
   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, it means its corresponding
   * sref is reused and it is intact reuse.
   */
  std::unordered_set<const StmtNode*> intact;
  /*!
   * \brief Kind 2.1. Loop sref reuse
   * If the loop var of a loop is in `loop_sref_possible_reuse`,
   * it means that when `src_stmt` has a loop that uses this loop var,
   * the reuse kind is loop sref reuse.
   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
   * contain a loop that uses this loop var, and that is the reason why it is named "possible".
   */
  std::unordered_set<const VarNode*> loop_sref_possible_reuse;
  /*!
   * \brief Kind 2.2. Block sref reuse.
   * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`,
   * indicating the sref to the old block should be reused in the sref to the new block.
   */
  std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
};

/*!
 * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`:
 *
 * 1) Intact: the subtree represented by `intact` appears on both old and new IR.
 * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
 * which means we do not need to visit into the subtree of the old statement.
 *
 * 2) Reused block/loop: for two different objects (`src`, `tgt`),
 * which are both loops or both blocks,
 * and there is correspondence between them,
 * which lets us reuse the sref pointing to `src` and change it to point to `tgt`.
 */
class ReuseCollector : public StmtVisitor {
 public:
  static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) {
    ReuseCollector collector(self);
    collector.VisitStmt(tgt_stmt);
    ReuseInfo result;
    result.intact = {collector.intact_.begin(), collector.intact_.end()};
    result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()};
    // `result.block_sref_reuse` is not set here because ReuseCollector doesn't collect it,
    // and it is supposed to be properly set by the caller.
    return result;
  }

 private:
  explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {}

  void VisitStmt_(const ForNode* op) final {
    if (self_->stmt2ref.count(op)) {
      intact_.push_back(op);
    } else {
      // Collect loop vars for detecting reuse of loop sref
      loop_vars_.push_back(op->loop_var.get());
      StmtVisitor::VisitStmt_(op);
    }
  }

  void VisitStmt_(const BlockNode* op) final {
    if (self_->stmt2ref.count(op)) {
      intact_.push_back(op);
    } else {
      StmtVisitor::VisitStmt_(op);
    }
  }

  /*! \brief The schedule state to be worked on */
  const ScheduleStateNode* self_;
  /*! \brief The intact statements we have collected along the way of visiting */
  std::vector<const StmtNode*> intact_;
  /*! \brief The loop variables we collected in the tgt_stmt */
  std::vector<const VarNode*> loop_vars_;
};

/*!
 * \brief A helper visitor which removes the stale srefs in the `src_stmt`
 * that are useless after the replacement.
 *
 * It uses the reuse information previously collected to
 * 1) delete those srefs that are not reused.
 * 2) return the sref objects that are loop/block sref reuses, but not intact reuses
 */
class SRefTreePruner : public StmtVisitor {
 public:
  /*!
   * \brief The entry function
   * \param self The schedule class
   * \param reuse_info The reuse info about intact reuse and loop/block reuse
   * \param src_stmt The `src_stmt` whose stale srefs are to be removed
   * \return Mapping from the reuse elements to reused srefs, more specifically:
   * 1) Loop reuse: maps a loop var to the reused sref
   * 2) Block reuse: maps a block stmt to the reused sref,
   * where the block comes from the subtree of `tgt_stmt`
   * 3) Intact reuse: not returned
   */
  static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self,
                                                           const ReuseInfo& reuse_info,
                                                           const Stmt& src_stmt) {
    SRefTreePruner pruner(self, reuse_info);
    pruner.VisitStmt(src_stmt);
    return std::move(pruner.reused_srefs_);
  }

 private:
  explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info)
      : self_(self), reuse_info_(reuse_info) {}

  void VisitStmt_(const ForNode* op) final {
    if (reuse_info_.intact.count(op)) {
      return;
    }
    auto it = self_->stmt2ref.find(op);
    ICHECK(it != self_->stmt2ref.end())
        << "IndexError: Cannot find corresponding StmtSRef for the loop:\n"
        << GetRef<For>(op);
    StmtSRef& sref = it->second;
    // Detect reuse
    const VarNode* loop_var = op->loop_var.get();
    if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) {
      // sref can be reused
      reused_srefs_.emplace(loop_var, std::move(sref));
    } else {
      sref->Reset();
    }
    // erase the statement
    self_->stmt2ref.erase(it);
    // detect recursively
    VisitStmt(op->body);
  }

  void VisitStmt_(const BlockNode* op) final {
    if (reuse_info_.intact.count(op)) {
      return;
    }
    auto it = self_->stmt2ref.find(op);
    ICHECK(it != self_->stmt2ref.end())
        << "IndexError: Cannot find corresponding StmtSRef for the block:\n"
        << GetRef<Block>(op);
    StmtSRef& sref = it->second;
    // Detect reuse
    const auto& sref_reuse = reuse_info_.block_sref_reuse;
    if (auto reuse_it = sref_reuse.find(op); reuse_it != sref_reuse.end()) {
      const BlockNode* to_reuse = reuse_it->second;
      // sref can be reused
      reused_srefs_.emplace(to_reuse, std::move(sref));
    } else {
      sref->Reset();
      self_->block_info.erase(sref);
    }
    // erase the statement
    self_->stmt2ref.erase(it);
    // detect recursively
    // op->init is omitted
    VisitStmt(op->body);
  }

  /*! \brief The schedule state we are working on */
  ScheduleStateNode* self_;
  /*! \brief The reuse information we collected previously */
  const ReuseInfo& reuse_info_;
  /*!
   * \brief Reused srefs:
   * 1) loop var -> StmtSRef
   * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt`
   */
  std::unordered_map<const Object*, StmtSRef> reused_srefs_;
};

/*!
 * \brief Update the sref in the `tgt_stmt` given the reuse information
 *
 * After being updated, in the `tgt_stmt` subtree,
 * 1) all `StmtSRefNode::parent`s are correct
 * 2) all `StmtSRefNode::seq_index`s are correct, except for the root
 * 3) all `StmtSRefNode::stmt`s are correct, except for the root
 */
class SRefUpdater : public StmtVisitor {
 public:
  static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
                     const std::unordered_map<const Object*, StmtSRef>& reused_srefs,
                     const Stmt& tgt_stmt) {
    SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt);
  }

 private:
  explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent,
                       const std::unordered_map<const Object*, StmtSRef>& reused_srefs)
      : self_(GetRef<ScheduleState>(self)),
        ancestors_{src_stmt_parent},
        reused_srefs_(reused_srefs) {}

  void VisitStmt_(const ForNode* op) final {
    StmtSRef& sref = self_->stmt2ref[op];
    // Detect intact reuse
    if (sref.defined()) {
      sref->parent = ancestors_.back();
      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
      return;
    }
    // Detect loop reuse
    auto it = reused_srefs_.find(op->loop_var.get());
    if (it != reused_srefs_.end()) {
      // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]`
      sref = it->second;
      sref->stmt = op;
      sref->parent = ancestors_.back();
      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
    } else {
      // A new loop sref without reuse
      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
    }
    // Recursive visit
    ancestors_.push_back(sref.get());
    VisitStmt(op->body);
    ancestors_.pop_back();
  }

  void VisitStmt_(const BlockNode* op) final {
    StmtSRef& sref = self_->stmt2ref[op];
    // Detect intact
    if (sref.defined()) {
      sref->parent = ancestors_.back();
      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
      return;
    }
    // Detect block reuse
    auto it = reused_srefs_.find(op);
    if (it != reused_srefs_.end()) {
      // Update `stmt2ref[op]` to `reused_srefs_[op]`
      sref = it->second;
      sref->stmt = op;
      sref->parent = ancestors_.back();
      sref->seq_index = -1;  // `seq_index` will be set properly in SetSeqIndex
    } else {
      // A new block sref without reuse
      sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(),
                      /*seq_index=*/-1);  // `seq_index` will be set properly in SetSeqIndex
    }
    // Recursive visit
    ancestors_.push_back(sref.get());
    VisitStmt(op->body);
    ancestors_.pop_back();
    // Additionally, need to update the scope because the block is changed
    UpdateBlockInfo(sref);
  }

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    StmtVisitor::VisitStmt_(seq_stmt);
    SetSeqIndexInChildren(self_.get(), seq_stmt);
  }

  void UpdateBlockInfo(const StmtSRef& block_sref) {
    using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator;
    // The caller is responsible for correcting the flags
    BlockInfo new_info((BlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref))));
    std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info);
    bool inserted = insert_result.second;
    BlockInfo& info = insert_result.first->second;
    if (inserted) {
      // Insertion has happened, update the flags accordingly
      info.affine_binding = false;
      info.region_cover = false;
      info.scope->stage_pipeline = false;
    } else {
      // Insertion didn't take place, because the entry has been there before.
      // In this case, we assume that flags are still valid so intentionally keep them unchanged
      new_info.scope->stage_pipeline = info.scope->stage_pipeline;
      info.scope = std::move(new_info.scope);
    }
  }

  /*! \brief The schedule state class to be worked on */
  ScheduleState self_;
  /*! \brief A stack containing all the ancestor For/Block nodes during the visit */
  std::vector<StmtSRefNode*> ancestors_;
  /*! \brief Maps the loop var / block to the reused sref */
  const std::unordered_map<const Object*, StmtSRef>& reused_srefs_;
};

/*!
 * \brief A helper that returns a new copy of `parent_stmt`,
 * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`.
 * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree.
 */
class ChildReplacer : private StmtMutator {
 public:
  static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt,
                      const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) {
    // Check the invariant
    ICHECK(child_src_stmt->IsInstance<BlockNode>() ||  //
           child_src_stmt->IsInstance<ForNode>());
    ICHECK(child_tgt_stmt->IsInstance<BlockNode>() ||  //
           child_tgt_stmt->IsInstance<ForNode>() ||    //
           child_tgt_stmt->IsInstance<BlockRealizeNode>());
    ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index);
    replacer.allow_copy_on_write_ = allow_copy_on_write;
    return replacer.CopyOnWriteAndVisit(parent_stmt);
  }

 private:
  explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index)
      : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {}

  Stmt VisitStmt(const Stmt& stmt) final {
    if (stmt.get() == src_stmt_) {
      // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt`
      return tgt_stmt_;
    } else {
      return StmtMutator::VisitStmt(stmt);
    }
  }

  // Skipping sibling blocks and loops other than `src_stmt_`
  Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); }
  Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); }

  Stmt VisitStmt_(const SeqStmtNode* op) final {
    int i = this->seq_index_;
    int n = static_cast<int>(op->seq.size());
    if (0 <= i && i < n) {
      const Stmt& stmt = op->seq[i];
      Optional<Stmt> new_stmt = NullOpt;
      const StmtNode* src_stmt = this->src_stmt_;
      // `stmt` can be For or BlockRealize
      // `src_stmt` can be For or Block
      // so the match from `stmt` to `src_stmt` can be
      // 1) For -> For
      // 2) BlockRealize -> Block
      if (stmt.get() == src_stmt) {
        // Case 1. src_stmt is For, stmt is For
        new_stmt = tgt_stmt_;
      } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
        // Case 2. stmt is BlockRealize, src_stmt is Block
        if (realize->block.get() == src_stmt) {
          const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode);
          ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
          new_realize->block = GetRef<Block>(tgt_block);
          new_stmt = BlockRealize(std::move(new_realize));
        }
      }
      // Move new_stmt to position i
      if (new_stmt.defined()) {
        ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op);
        new_seq_stmt->seq.Set(i, new_stmt.value());
        return SeqStmt(std::move(new_seq_stmt));
      }
    }
    return StmtMutator::VisitStmt_(op);
  }

  Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) {
    // Step 1. Copy-on-write the `parent_stmt` and extract its `body`,
    // where `body` means the body of either a block or a loop
    // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt`
    // and replace it with `child_tgt_stmt`
    if (parent_stmt->IsInstance<BlockNode>()) {
      auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt));
      ObjectPtr<BlockNode> new_block = CopyOnWrite(block);
      new_block->body = this->VisitStmt(new_block->body);
      return Block(std::move(new_block));
    } else if (parent_stmt->IsInstance<ForNode>()) {
      auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt));
      ObjectPtr<ForNode> new_loop = CopyOnWrite(loop);
      new_loop->body = this->VisitStmt(new_loop->body);
      return For(std::move(new_loop));
    }
    LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey();
    throw;
  }

  /*! \brief The `src_stmt` to be replaced */
  const StmtNode* src_stmt_;
  /*! \brief The `tgt_stmt` that takes the place of `src_stmt` */
  const Stmt& tgt_stmt_;
  /*!
   * \brief The `seq_index` of the `src_stmt`
   * \sa StmtSRefNode
   */
  int seq_index_;
};

void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt,
                                const Map<Block, Block>& _block_sref_reuse) {
  if (this->debug_mask != 0) {
    const StmtNode* src_stmt = _src_sref->stmt;
    bool input_correct =
        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) ||
        (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) ||
        (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>());
    if (!input_correct) {
      LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey()
                 << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n"
                 << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n"
                 << tgt_stmt;
    }
  }
  // Rule out the case that no replacement happens
  if (_src_sref->stmt == tgt_stmt.get()) {
    return;
  }
  // Reset sref as a new sref so that its content won't be affected by subsequent changes
  StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index);
  Stmt src_stmt = GetRef<Stmt>(src_sref->stmt);
  // Step 1. Create all the nodes needed for the new sref tree.
  // After this step
  // 1) all `parent`s are correct
  // 2) all `seq_index`s are correct, except for the root
  // 3) all `stmt`s are correct, except for the root
  {
    // Step 0. Setup block_sref_reuse
    std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse;
    block_sref_reuse.reserve(_block_sref_reuse.size() + 1);
    for (const auto& kv : _block_sref_reuse) {
      block_sref_reuse.emplace(kv.first.get(), kv.second.get());
    }
    // Step 1.1. Collect info for different kinds of reuses
    // 1) intact
    // 2) loop/block reuse
    ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt);
    reuse_info.block_sref_reuse = std::move(block_sref_reuse);
    // Step 1.2. Collect loop/block reuse to their corresponding srefs
    // and remove those srefs in the `src_stmt` that are no longer used after replacement
    std::unordered_map<const Object*, StmtSRef> reused_srefs =
        SRefTreePruner::Prune(this, reuse_info, src_stmt);
    // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused
    // srefs in `tgt_stmt`
    SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt);
  }
  // Step 2. Set the ancestors' children properly
  // Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
  // The visit stops when all the ancestors are uniquely referenced, i.e. can mutate inplace.
  // Along the way, because we create a new ancestor path,
  // we need to update the srefs that point to the old ancestors so that they point to the newly
  // created ones.
  // Variables:
  // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that
  //    can be mutated inplace, it needs `num_copy_steps + 1` hops.
  // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to.
  // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to
  int num_copy_steps = -1;
  bool need_module_copy = false;
  const PrimFuncNode* g_func = nullptr;
  GlobalVar g_var;
  {
    int i = 0;
    const StmtSRefNode* p = src_sref.get();
    while (true) {
      if (!p->stmt->unique()) {
        num_copy_steps = i;
      }
      if (p->parent == nullptr) {
        break;
      }
      ++i;
      p = p->parent;
    }
    // Find `g_func` and `g_var` where the `src_sref` is in
    g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var);
    need_module_copy = num_copy_steps == i ||             //
                       !this->mod.unique() ||             //
                       !this->mod->functions.unique() ||  //
                       !g_func->unique();
  }
  // Loop invariant:
  //
  // Before step `i`:
  // 1) `child_sref` is `src_sref` going up by `i` steps
  // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement
  // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are correct
  // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet
  // 5) `tgt_stmt` is of type Loop, Block or BlockRealize
  //
  // During step `i`:
  // 1) Create `parent_stmt` that corresponds to `child_sref->parent`
  // 2) Point `child_sref` to `child_tgt_stmt`
  // 3) `tgt_stmt` is of type Loop or Block
  StmtSRefNode* child_sref = src_sref.get();
  Stmt child_tgt_stmt = std::move(tgt_stmt);
  for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) {
    bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps;
    // Replace `child_sref->stmt` to `child_tgt_stmt`.
    const StmtNode* parent_stmt = child_sref->parent->stmt;
    const StmtNode* child_src_stmt = child_sref->stmt;
    // Step 2.1. Link `child_sref` to `child_tgt_stmt`
    if (i == 0) {
      // Per the invariant of SRefUpdater,
      // the `seq_index` of the root of `tgt_stmt` is set as -1,
      // which might be incorrect
      SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index);
    } else {
      // Point `child_sref` to `child_tgt_stmt`
      UpdateSRef(this, child_sref, child_tgt_stmt.get());
    }
    // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt`
    Stmt new_parent_stmt =
        ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt,
                               /*seq_index=*/child_sref->seq_index,
                               /*allow_copy_on_write=*/can_directly_mutate_parent);
    // Step 2.3. Go to next parent
    if (can_directly_mutate_parent) {
      // If the node can be directly mutated inplace,
      // then there is no need to update its parent and the function
      break;
    }
    child_tgt_stmt = std::move(new_parent_stmt);
    child_sref = child_sref->parent;
  }
  // Step 3. Handle the case that we mutate the root
  if (need_module_copy) {
    // From the loop invariant, upon exit, while its subtree is properly set,
    // `child_sref` is not yet pointing to `child_tgt_stmt`.
    if (src_sref->parent != nullptr) {
      // Not replacing a root
      UpdateSRef(this, child_sref, child_tgt_stmt.get());
    }
    // Ensure the uniqueness of `this->mod` and `this->mod->functions`
    IRModuleNode* new_mod = this->mod.CopyOnWrite();
    MapNode* new_map = new_mod->functions.CopyOnWrite();
    // Move out the PrimFunc where the sref belongs while ensuring uniqueness
    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
    ICHECK(ref_new_func.get() == g_func);
    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
    // If `g_func` was not unique, after the 3 lines above:
    //   `ref_new_func` points to a unique PrimFunc
    //   `g_func` points to the previous PrimFunc if it is not unique
    // If `g_func` was unique, after the 3 lines above:
    //   `ref_new_func` points to the same unique function that `g_func` points to
    // Update the body of the function the sref belongs to
    const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode);
    // Make `child_tgt_stmt` the root block
    const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode);
    ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
    new_realize->block = GetRef<Block>(child_block);
    new_func->body = BlockRealize(std::move(new_realize));
    // Finally, move the `ref_new_func` back and update `this->mod`
    new_map->at(g_var) = std::move(ref_new_func);
    this->mod = GetRef<IRModule>(new_mod);
  }
  uint32_t flag = (debug_mask != -1)                       //
                      ? static_cast<uint32_t>(debug_mask)  //
                      : std::numeric_limits<uint32_t>::max();
  if (flag & ScheduleDebugMask::kVerifySRefTree) {
    VerifySRefTree(GetRef<ScheduleState>(this));
  }
}
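
// Illustrative sketch (hypothetical variables, shown only as documentation): a schedule
// primitive usually builds the new subtree first, then calls Replace with the block reuse map
// so that srefs held by users stay valid, and finally verifies the state if desired.
//
//   Map<Block, Block> reuse = {{old_block, new_block}};
//   self->Replace(loop_sref, new_subtree, reuse);
//   self->DebugVerify();  // honors `debug_mask`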

void ScheduleStateNode::DebugVerify() const {
  ICHECK_GE(debug_mask, -1);
  uint32_t flag = (debug_mask != -1)                       //
                      ? static_cast<uint32_t>(debug_mask)  //
                      : std::numeric_limits<uint32_t>::max();
  if (flag & ScheduleDebugMask::kVerifySRefTree) {
    VerifySRefTree(GetRef<ScheduleState>(this));
  }
  if (flag & ScheduleDebugMask::kVerifyCachedFlags) {
    VerifyCachedFlags(GetRef<ScheduleState>(this));
  }
}

/**************** BlockInfo-related ****************/

BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
  TVM_SREF_TO_BLOCK(block_sref);
  auto it = this->block_info.find(block_sref);
  CHECK(it != this->block_info.end())
      << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
      << GetRef<Stmt>(block_sref->stmt);
  return it->second;
}

void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) {
  BlockInfoCollector::Collect(this, stmt);
}

TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) {
  const BlockInfo& info = self->GetBlockInfo(block_sref);
  return {Bool(info.affine_binding),  //
          Bool(info.region_cover),    //
          Bool(info.scope->stage_pipeline)};
}

/**************** FFI ****************/

TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
    .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
      return ScheduleState(mod, debug_mask);
    });
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
    .set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace")
    .set_body_method<ScheduleState>(&ScheduleStateNode::Replace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef")
    .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional<StmtSRef> {
      auto it = self->stmt2ref.find(stmt.get());
      return it != self->stmt2ref.end() ? it->second : Optional<StmtSRef>(NullOpt);
    });
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags").set_body_typed(GetCachedFlags);

}  // namespace tir
}  // namespace tvm