1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/arith/int_set.h> |
20 | |
#include "./utils.h"

namespace tvm {
23 | namespace tir { |
24 | |
25 | template <class K, class V> |
26 | using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>; |
27 | |
28 | /**************** Utility functions ****************/ |
29 | |
30 | /*! |
31 | * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) |
32 | * Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added |
33 | * to the result. |
34 | * \param region The buffer region to be analyzed |
35 | * \param dom_low_inclusive The lowest node in the sref tree path |
36 | * \param dom_high_exclusive The highest node in the sref tree path |
37 | * \return An n-dimensional integer set |
38 | */ |
39 | Array<arith::IntSet> AnalyzeRegionUpperBound(const BufferRegion& region, // |
40 | const PrimExpr& predicate, // |
41 | const StmtSRef& dom_low_inclusive, // |
42 | const StmtSRef& dom_high_exclusive, // |
43 | arith::Analyzer* analyzer) { |
44 | Map<Var, Range> var_dom = LoopDomainOfSRefTreePath( |
45 | /*low_inclusive=*/dom_low_inclusive, |
46 | /*high_exclusive=*/dom_high_exclusive, |
47 | /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); |
48 | return EstimateRegionUpperBound( |
49 | /*region=*/region->region, |
50 | /*var_dom=*/var_dom, |
51 | /*predicate=*/predicate, /*analyzer=*/analyzer); |
52 | } |
53 | |
54 | /*! |
55 | * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) |
56 | * Some subregion may be discarded during the lower-bound analysis. |
57 | * \param realize The block realize that touches the buffer region |
58 | * \param region The buffer region to be analyzed |
59 | * \param dom_low_inclusive The lowest node in the sref tree path |
60 | * \param dom_high_exclusive The highest node in the sref tree path |
61 | * \param analyzer The analyzer |
62 | * \return An n-dimensional integer set |
63 | */ |
64 | Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, // |
65 | const PrimExpr& predicate, // |
66 | const StmtSRef& dom_low_inclusive, // |
67 | const StmtSRef& dom_high_exclusive, // |
68 | arith::Analyzer* analyzer) { |
69 | Map<Var, Range> var_dom = LoopDomainOfSRefTreePath( |
70 | /*low_inclusive=*/dom_low_inclusive, |
71 | /*high_exclusive=*/dom_high_exclusive, |
72 | /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); |
73 | if (Optional<Array<arith::IntSet>> result = EstimateRegionLowerBound( |
74 | /*region=*/region->region, |
75 | /*var_dom=*/var_dom, |
76 | /*predicate=*/predicate, /*analyzer=*/analyzer)) { |
77 | return result.value(); |
78 | } |
79 | return Array<arith::IntSet>(region->buffer->shape.size(), arith::IntSet::Nothing()); |
80 | } |
81 | |
82 | /*! |
83 | * \brief Checks if the produced region can cover the consumed region |
84 | * \param buffer_shape The shape of the buffer |
85 | * \param produced_region The N-dimensional produced region |
86 | * \param consumed_region The N-dimensional consumed region |
87 | * \param analyzer The analyzer |
88 | * \return A boolean indicating if the produced region could cover the consumed region |
89 | */ |
90 | bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape, |
91 | const Array<arith::IntSet>& produced_region, |
92 | const Array<arith::IntSet>& consumed_region, |
93 | arith::Analyzer* analyzer) { |
94 | ICHECK_EQ(buffer_shape.size(), consumed_region.size()); |
95 | ICHECK_EQ(produced_region.size(), consumed_region.size()); |
96 | int ndim = produced_region.size(); |
97 | for (int i = 0; i < ndim; ++i) { |
98 | arith::IntSet buffer_size = arith::IntSet::FromMinExtent(0, buffer_shape[i]); |
99 | if (produced_region[i].IsNothing()) { |
100 | return false; |
101 | } |
102 | arith::IntSet produced = |
103 | arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()), |
104 | analyzer->canonical_simplify(produced_region[i].max())); |
105 | arith::IntSet consumed = |
106 | arith::IntSet::Interval(analyzer->canonical_simplify(consumed_region[i].min()), |
107 | analyzer->canonical_simplify(consumed_region[i].max())); |
108 | produced = arith::Intersect({produced, buffer_size}); |
109 | consumed = arith::Intersect({consumed, buffer_size}); |
110 | |
111 | produced = arith::IntSet::Interval(analyzer->Simplify(produced.min()), |
112 | analyzer->Simplify(produced.max())); |
113 | consumed = arith::IntSet::Interval(analyzer->Simplify(consumed.min()), |
114 | analyzer->Simplify(consumed.max())); |
115 | |
116 | if (!analyzer->CanProve((analyzer->canonical_simplify(produced.min() - consumed.min()) <= 0) && |
117 | (analyzer->canonical_simplify(consumed.max() - produced.max()) <= 0))) { |
118 | return false; |
119 | } |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | /*! |
125 | * \brief Set the `StmtSRefNode::seq_index` field for stmt |
126 | * \param self The schedule class |
127 | * \param stmt The statement, or the realize node of the statement whose sref to be set |
128 | * \param seq_index The seq_index to be set |
129 | * \note The method is NOP for statements that are not schedulable, i.e. not For or Block |
130 | */ |
131 | void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { |
132 | if (const auto* realize = stmt.as<BlockRealizeNode>()) { |
133 | const BlockNode* block = realize->block.get(); |
134 | ICHECK(self->stmt2ref.count(block)); |
135 | self->stmt2ref.at(block)->seq_index = seq_index; |
136 | } else if (const auto* block = stmt.as<BlockNode>()) { |
137 | ICHECK(self->stmt2ref.count(block)); |
138 | self->stmt2ref.at(block)->seq_index = seq_index; |
139 | } else if (const auto* loop = stmt.as<ForNode>()) { |
140 | ICHECK(self->stmt2ref.count(loop)); |
141 | self->stmt2ref.at(loop)->seq_index = seq_index; |
142 | } else { |
143 | // do nothing |
144 | } |
145 | } |
146 | |
147 | /*! |
148 | * \brief Update seq_index of the children of a SeqStmt |
 * \param self The schedule state
150 | * \param seq_stmt The SeqStmt whose children need updating |
151 | */ |
152 | void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) { |
153 | int i = 0; |
154 | for (const Stmt& stmt : seq_stmt->seq) { |
155 | SetSeqIndex(self, stmt, i); |
156 | ++i; |
157 | } |
158 | } |
159 | |
160 | /*! |
161 | * \brief Update the sref information on the schedule class, as well as the statement of sref itself |
162 | * More specifically, update |
163 | * `sref->stmt` to `new_stmt` |
164 | * `self->stmt2ref`, remove the old statement that sref points to, and add the new statement |
165 | * \param self The schedule class to be updated |
166 | * \param sref The sref to be updated |
167 | * \param new_stmt The statement that replaces the statement inside the sref |
168 | */ |
169 | void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) { |
170 | ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>()); |
171 | const StmtNode* old_stmt = sref->stmt; |
172 | ICHECK_NE(new_stmt, old_stmt); |
173 | self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref); |
174 | self->stmt2ref.erase(sref->stmt); |
175 | sref->stmt = new_stmt; |
176 | } |
177 | |
178 | /**************** Creation ****************/ |
179 | /*! \brief A helper class to update BlockInfo for a ScheduleStateNode */ |
180 | class BlockInfoCollector : private StmtVisitor { |
181 | public: |
182 | static void Collect(ScheduleStateNode* self, const Stmt& stmt) { |
183 | BlockInfoCollector collector(self); |
184 | collector.VisitStmt(stmt); |
185 | } |
186 | |
187 | private: |
188 | explicit BlockInfoCollector(ScheduleStateNode* self) |
189 | : self_(self), srefs_{}, block2realize_{}, block_frames_{} { |
    block_frames_.emplace_back();
191 | } |
192 | |
193 | /*! |
194 | * \brief Add a new statement to the stack, which becomes the current scope |
195 | * \param stmt A for-loop statement or a block statement |
196 | * \return A sref to the stmt |
197 | */ |
198 | void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); } |
199 | |
  /*! \brief Pop the sref at the top of the scope stack */
201 | StmtSRef PopSRef() { |
202 | StmtSRef sref = srefs_.back(); |
203 | srefs_.pop_back(); |
204 | return sref; |
205 | } |
206 | |
207 | void MakeBlockInfo(StmtSRef scope_root) { |
208 | bool is_root_block = srefs_.empty(); |
209 | // Calculate `BlockInfo::scope` |
210 | Array<StmtSRef> child_block_srefs = std::move(block_frames_.back()); |
211 | BlockInfo& info = self_->block_info[scope_root] = BlockInfo(BlockScope(child_block_srefs)); |
212 | // Set `affine_binding` |
213 | if (is_root_block) { |
      // If the block has neither outer loops nor an outer BlockRealize,
      // then we set the affine binding flag to true only if the block has no block vars
216 | const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root); |
217 | if (block->iter_vars.empty()) info.affine_binding = true; |
218 | } else { |
219 | info.affine_binding = |
220 | IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt), |
221 | /*loop_var_ranges=*/LoopDomainOfSRefTreePath(srefs_.back()), |
222 | /*analyzer=*/&analyzer_); |
223 | } |
    // Set `region_cover` to true; it will be updated when the surrounding scope block is processed
    info.region_cover = true;
    // Set `stage_pipeline` for this scope, and `region_cover` for its child blocks
227 | info.scope->stage_pipeline = |
228 | CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs); |
229 | } |
230 | |
231 | bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& scope_root, |
232 | const Array<StmtSRef>& child_block_srefs) { |
233 | const StmtSRefNode* limit = scope_root->parent; |
234 | bool stage_pipeline = true; |
235 | // Step 1. Unbind the read/write regions of each child block |
236 | std::unordered_map<const StmtSRefNode*, Array<BufferRegion>> block_reads_unbound; |
237 | std::unordered_map<const StmtSRefNode*, Array<BufferRegion>> block_writes_unbound; |
238 | block_reads_unbound.reserve(child_block_srefs.size()); |
239 | block_writes_unbound.reserve(child_block_srefs.size()); |
240 | for (const StmtSRef& block_sref : child_block_srefs) { |
241 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
242 | Map<Var, PrimExpr> binding = GetBindings(block2realize_.at(block)); |
243 | // Step 1.1. Unbind read regions |
244 | Array<BufferRegion> reads; |
245 | reads.reserve(block->reads.size()); |
246 | for (const BufferRegion& region : block->reads) { |
247 | reads.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); |
248 | } |
249 | block_reads_unbound.emplace(block_sref.get(), std::move(reads)); |
250 | // Step 1.2. Unbind write regions |
251 | Array<BufferRegion> writes; |
252 | writes.reserve(block->writes.size()); |
253 | for (const BufferRegion& region : block->writes) { |
254 | writes.push_back(BufferRegion(region->buffer, Substitute(region->region, binding))); |
255 | } |
256 | block_writes_unbound.emplace(block_sref.get(), std::move(writes)); |
257 | } |
258 | // Step 2. For each consumer, check the region cover property |
259 | for (const auto& kv : info.scope->dst2deps) { |
260 | const StmtSRef& consumer_block_sref = kv.first; |
261 | const Array<Dependency>& deps = kv.second; |
262 | const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); |
263 | const BlockRealize& consumer_realize = block2realize_.at(consumer_block); |
264 | bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true; |
265 | // Step 2.1. Extract the path to the scope root |
266 | std::unordered_map<const StmtSRefNode*, std::vector<const StmtSRefNode*>> lca_loc; |
267 | for (const StmtSRefNode* p = consumer_block_sref.get(); p != limit; p = p->parent) { |
268 | ICHECK(p != nullptr); |
269 | lca_loc[p] = {}; |
270 | } |
271 | // Step 2.2. For each producer, find the LCA of the consumer |
272 | for (const Dependency& dep : deps) { |
273 | if (dep->kind == DepKind::kWAR || dep->kind == DepKind::kOpaque) { |
274 | stage_pipeline = false; |
275 | } |
276 | // Only care about producer-consumer relationship |
277 | if (dep->kind != DepKind::kRAW) { |
278 | continue; |
279 | } |
280 | const StmtSRef& producer = dep->src; |
281 | for (const StmtSRefNode* p = producer.get();; p = p->parent) { |
282 | ICHECK(p != nullptr); |
283 | auto it = lca_loc.find(p); |
        // The first (lowest) hit among the ancestors of the consumer
        // is the LCA by definition
286 | if (it != lca_loc.end()) { |
287 | it->second.push_back(producer.get()); |
288 | break; |
289 | } |
290 | } |
291 | } |
292 | // Step 2.3. For each LCA, gather the produced regions, |
293 | // then check if it could cover the consumed region |
294 | for (StmtSRef lca = consumer_block_sref; region_cover && lca.get() != limit; |
295 | lca = GetRef<StmtSRef>(lca->parent)) { |
296 | const std::vector<const StmtSRefNode*>& producer_block_srefs = lca_loc.at(lca.get()); |
297 | // Skip empty LCA positions |
298 | if (producer_block_srefs.empty()) { |
299 | continue; |
300 | } |
301 | // For each buffer, record the regions generated under this loop |
302 | std::unordered_map<const BufferNode*, std::vector<Array<arith::IntSet>>> touched_regions; |
303 | // Step 2.3.1. Find all the regions read by the consumer that we care about |
304 | for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { |
305 | const BufferNode* buffer = region->buffer.get(); |
306 | touched_regions[buffer] = {}; |
307 | } |
308 | // Step 2.3.2. Find all the regions written by each producer |
309 | for (const StmtSRefNode* producer_block_sref : producer_block_srefs) { |
310 | const BlockRealize& producer_realize = block2realize_.at(producer_block_sref->stmt); |
311 | StmtSRef parent_sref = GetRef<StmtSRef>(producer_block_sref->parent); |
312 | for (const BufferRegion& region : block_writes_unbound.at(producer_block_sref)) { |
313 | const BufferNode* buffer = region->buffer.get(); |
314 | auto it = touched_regions.find(buffer); |
          // Skip the regions that are not read by the consumer
316 | if (it != touched_regions.end()) { |
317 | std::vector<Array<arith::IntSet>>& touched_region = it->second; |
            // The analysis here tries to be conservative to rule out false positives,
            // and to make sure the region cover property is indeed satisfied once the flag is on.
            // Therefore, we use lower-bound analysis for the producers and upper-bound analysis
            // for the consumer, and require that the produced region covers the consumed region
322 | touched_region.push_back(AnalyzeRegionLowerBound( |
323 | /*region=*/region, |
324 | /*predicate=*/producer_realize->predicate, |
325 | /*dom_low_inclusive=*/parent_sref, |
326 | /*dom_high_exclusive=*/lca, |
327 | /*analyzer=*/&analyzer_)); |
328 | } |
329 | } |
330 | } |
331 | // Step 2.3.3. For each buffer, check the region cover property |
332 | { |
333 | StmtSRef parent_sref = GetRef<StmtSRef>(consumer_block_sref->parent); |
334 | for (const BufferRegion& region : block_reads_unbound.at(consumer_block_sref.get())) { |
335 | const BufferNode* buffer = region->buffer.get(); |
336 | const std::vector<Array<arith::IntSet>>& touched_region = touched_regions.at(buffer); |
337 | if (!touched_region.empty()) { |
338 | Array<arith::IntSet> produced_region = |
339 | arith::UnionRegionLowerBound({touched_region.begin(), touched_region.end()}); |
340 | Array<arith::IntSet> consumed_region = AnalyzeRegionUpperBound( |
341 | /*region=*/region, |
342 | /*predicate=*/consumer_realize->predicate, |
343 | /*dom_low_inclusive=*/parent_sref, |
344 | /*dom_high_exclusive=*/lca, |
345 | /*analyzer=*/&analyzer_); |
346 | if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, |
347 | &analyzer_)) { |
348 | region_cover = false; |
349 | self_->block_info.at(consumer_block_sref).region_cover = region_cover; |
350 | break; |
351 | } |
352 | } |
353 | } |
354 | } |
355 | } |
356 | stage_pipeline = stage_pipeline && region_cover; |
357 | } |
358 | return stage_pipeline; |
359 | } |
360 | |
361 | void VisitStmt_(const ForNode* loop) final { |
362 | analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
363 | PushSRef(loop); |
364 | VisitStmt(loop->body); |
365 | PopSRef(); |
366 | } |
367 | |
368 | void VisitStmt_(const BlockRealizeNode* realize) final { |
369 | block_frames_.emplace_back(); |
370 | const BlockNode* block = realize->block.get(); |
371 | block2realize_.emplace(block, GetRef<BlockRealize>(realize)); |
372 | // Recursive visit |
373 | PushSRef(block); |
374 | VisitStmt(block->body); // `block->init` is not visited |
375 | StmtSRef sref = PopSRef(); |
376 | // Create BlockInfo for the block |
377 | MakeBlockInfo(sref); |
378 | // Update parent scope |
379 | block_frames_.pop_back(); |
380 | block_frames_.back().push_back(sref); |
381 | } |
382 | |
383 | void VisitStmt_(const SeqStmtNode* seq_stmt) final { |
384 | // Set `seq_index` information for SeqStmtNode |
385 | StmtVisitor::VisitStmt_(seq_stmt); |
386 | SetSeqIndexInChildren(self_, seq_stmt); |
387 | } |
388 | |
389 | /*! \brief The ScheduleStateNode we are operating on */ |
390 | ScheduleStateNode* self_; |
391 | /*! \brief The stack frame used to indicate the current scope */ |
392 | std::vector<StmtSRef> srefs_; |
393 | /*! \brief The BlockRealize corresponding to blocks */ |
394 | std::unordered_map<const StmtNode*, BlockRealize> block2realize_; |
395 | /*! \brief The stack frames of blocks in the DFS visit. */ |
396 | std::vector<Array<StmtSRef>> block_frames_; |
397 | /*! \brief The auxiliary analyzer */ |
398 | arith::Analyzer analyzer_; |
399 | }; |
400 | |
401 | /*! \brief A helper class to create a new ScheduleStateNode from an IRModule */ |
402 | class StateCreator : private StmtVisitor { |
403 | public: |
404 | /*! |
405 | * \brief The entry function |
406 | * \param self The schedule state to be completed |
407 | */ |
408 | static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) { |
409 | ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>(); |
410 | ScheduleStateNode* self = n.get(); |
411 | // Set `n->mod` |
412 | n->mod = std::move(mod); |
413 | // Set `n->debug_mask` |
414 | n->debug_mask = debug_mask; |
415 | // Set `n->stmt2ref` and `n->block_info` |
416 | StateCreator creator(self); |
417 | for (const auto& kv : n->mod->functions) { |
418 | const BaseFunc& base_func = kv.second; |
419 | if (const auto* func = base_func.as<PrimFuncNode>()) { |
420 | VerifyWellFormed(GetRef<PrimFunc>(func)); |
421 | creator.VisitStmt(func->body); |
422 | BlockInfoCollector::Collect(self, func->body); |
423 | } |
424 | } |
425 | return n; |
426 | } |
427 | |
428 | private: |
429 | explicit StateCreator(ScheduleStateNode* self) : self_(self) {} |
430 | |
431 | /*! |
432 | * \brief Add a new statement to the stack, which becomes the current scope |
433 | * \param stmt A for-loop statement or a block statement |
434 | * \return A sref to the stmt |
435 | */ |
436 | void PushSRef(const StmtNode* stmt) { |
437 | if (srefs_.empty()) { |
438 | srefs_.push_back( |
439 | StmtSRef(stmt, |
440 | /*parent=*/nullptr, |
441 | /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex |
442 | } else { |
443 | StmtSRefNode* parent = srefs_.back().get(); |
444 | srefs_.push_back( |
445 | StmtSRef(stmt, parent, |
446 | /*seq_index=*/-1)); // `seq_index` will be set properly in SetSeqIndex |
447 | } |
448 | } |
449 | |
450 | /*! \brief Pop the top of the scope and record it in stmt2ref map */ |
451 | void PopAndRecordSRef() { |
452 | StmtSRef sref = std::move(srefs_.back()); |
453 | self_->stmt2ref[sref->stmt] = sref; |
454 | srefs_.pop_back(); |
455 | } |
456 | |
457 | void VisitStmt_(const ForNode* loop) final { |
458 | PushSRef(loop); |
459 | VisitStmt(loop->body); |
460 | PopAndRecordSRef(); |
461 | } |
462 | |
463 | void VisitStmt_(const BlockRealizeNode* realize) final { |
464 | const BlockNode* block = realize->block.get(); |
465 | PushSRef(block); |
466 | VisitStmt(block->body); // `block->init` is not visited |
467 | PopAndRecordSRef(); |
468 | } |
469 | |
470 | void VisitStmt_(const SeqStmtNode* seq_stmt) final { |
471 | // Set `seq_index` information for SeqStmtNode |
472 | StmtVisitor::VisitStmt_(seq_stmt); |
473 | SetSeqIndexInChildren(self_, seq_stmt); |
474 | } |
475 | |
476 | /*! \brief The result ScheduleStateNode */ |
477 | ScheduleStateNode* self_; |
478 | /*! \brief The stack frame used to indicate the current scope */ |
479 | std::vector<StmtSRef> srefs_; |
480 | }; |
481 | |
482 | /**************** Constructor ****************/ |
483 | |
484 | ScheduleState::ScheduleState(IRModule mod, int debug_mask) { |
  CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported";
486 | data_ = StateCreator::Create(mod, debug_mask); |
487 | } |
488 | |
489 | /**************** Replace ****************/ |
490 | |
491 | /* |
492 | * The goal of the replacement algorithm is to substitute a subtree `src_stmt` of the AST to a new |
493 | * subtree `tgt_stmt`, and maintain the corresponding sref tree accordingly, with some srefs reused, |
494 | * so that the srefs users hold doesn't expire. For example, if we split a loop into 2, and the |
495 | * original loop has a child block, then the sref to the child block should be reused, so that users |
496 | * won't have to acquire that sref again. |
497 | * |
498 | * The workflow of the replacement algorithm is: |
499 | * 1) Detect all possible reuses in class ReuseInfo |
500 | * 2) Remove the expired srefs in class SRefTreePruner |
501 | * 3) Update the reused the sref, and create the srefs for new statements, in class SRefUpdater |
502 | * 4) Renew the ancestors of `src_stmt` to reflect the replacement |
503 | */ |
504 | |
505 | /*! |
506 | * \brief Record the different sref reuse types in the replacement |
507 | * |
508 | * 1) Intact: the subtree appears as the same object on both `src_stmt` and `tgt_stmt`, |
509 | * which, given the immutability of the IR, means the entire subtree is unchanged, |
510 | * and we do not need to recurse into the subtree. |
511 | * |
 * 2) Loop/Block sref reuse: for two different objects (`src`, `tgt`),
 * which are both loops or both blocks,
 * there is a correspondence between them,
 * which allows us to reuse the sref pointing to `src` and change it to point to `tgt`.
516 | * |
517 | * \note The intact reuse and loop sref reuse are collected in the ReuseCollector, |
518 | * while the block reuse is specified by the caller. |
519 | * |
520 | * \sa ReuseCollector |
521 | */ |
522 | struct ReuseInfo { |
523 | /*! |
   * \brief Kind 1. Intact reuse. If a stmt is in `intact`, its corresponding
   * sref is reused, and the reuse is an intact reuse.
526 | */ |
527 | std::unordered_set<const StmtNode*> intact; |
528 | /*! |
529 | * \brief Kind 2.1. Loop sref reuse |
530 | * If the loop var of a loop is in `loop_sref_possible_reuse`, |
531 | * it means that when `src_stmt` has a loop that uses this loop var, |
532 | * the reuse kind is loop sref reuse. |
   * \note For each loop var in `loop_sref_possible_reuse`, it is possible that `src_stmt` doesn't
   * contain a loop that uses this loop var, which is why it is named "possible".
535 | */ |
536 | std::unordered_set<const VarNode*> loop_sref_possible_reuse; |
537 | /*! |
538 | * \brief Kind 2.2. Block sref reuse. |
539 | * Maps an old Block in `src_stmt` to a new block in `tgt_stmt`, |
540 | * indicating the sref to the old block should be reused in the sref to the new block. |
541 | */ |
542 | std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse; |
543 | }; |
544 | |
545 | /*! |
546 | * \brief A helper visitor which collects two cases of sref reuses in the `tgt_stmt`: |
547 | * |
 * 1) Intact: the subtree represented by `intact` appears in both the old and new IR.
 * Given the immutability of the IR, we can quickly decide that the entire subtree is unchanged,
 * which means we do not need to visit into the subtree of the old statement.
 *
 * 2) Reused block/loop: for two different objects (`src`, `tgt`),
 * which are both loops or both blocks,
 * there is a correspondence between them,
 * which allows us to reuse the sref pointing to `src` and change it to point to `tgt`.
556 | */ |
557 | class ReuseCollector : public StmtVisitor { |
558 | public: |
559 | static ReuseInfo Collect(const ScheduleStateNode* self, const Stmt& tgt_stmt) { |
560 | ReuseCollector collector(self); |
561 | collector.VisitStmt(tgt_stmt); |
562 | ReuseInfo result; |
563 | result.intact = {collector.intact_.begin(), collector.intact_.end()}; |
564 | result.loop_sref_possible_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()}; |
    // `result.block_sref_reuse` is not set here because ReuseCollector doesn't collect it,
566 | // and it is supposed to be properly set by the caller. |
567 | return result; |
568 | } |
569 | |
570 | private: |
571 | explicit ReuseCollector(const ScheduleStateNode* self) : self_(self) {} |
572 | |
573 | void VisitStmt_(const ForNode* op) final { |
574 | if (self_->stmt2ref.count(op)) { |
575 | intact_.push_back(op); |
576 | } else { |
577 | // Collect loop vars for detecting reuse of loop sref |
578 | loop_vars_.push_back(op->loop_var.get()); |
579 | StmtVisitor::VisitStmt_(op); |
580 | } |
581 | } |
582 | |
583 | void VisitStmt_(const BlockNode* op) final { |
584 | if (self_->stmt2ref.count(op)) { |
585 | intact_.push_back(op); |
586 | } else { |
587 | StmtVisitor::VisitStmt_(op); |
588 | } |
589 | } |
590 | |
591 | /*! \brief The schedule state to be worked on */ |
592 | const ScheduleStateNode* self_; |
593 | /*! \brief The intact statements we have collected along the way of visiting */ |
594 | std::vector<const StmtNode*> intact_; |
595 | /*! \brief The loop variable we collected in the tgt_stmt */ |
596 | std::vector<const VarNode*> loop_vars_; |
597 | }; |
598 | |
599 | /*! |
600 | * \brief A helper visitor which removes the stale srefs in the `src_stmt` |
601 | * that are useless after the replacement. |
602 | * |
603 | * It uses the reuse information previously collected to |
604 | * 1) delete those srefs that are not reused. |
605 | * 2) return the sref objects that are loop/block sref reuses, but not intact reuses |
606 | */ |
607 | class SRefTreePruner : public StmtVisitor { |
608 | public: |
609 | /*! |
610 | * \brief The entry function |
611 | * \param self The schedule class |
612 | * \param info The reuse info about intact reuse and loop/block reuse |
613 | * \param src_stmt The `src_stmt` where stale srefs to be removed |
614 | * \return Mapping from the reuse elements to reused srefs, more specifically: |
615 | * 1) Loop reuse: maps a loop var to the reused sref |
616 | * 2) Block reuse: maps a block stmt to the reused sref, |
617 | * where the block comes from the subtree of `tgt_stmt` |
618 | * 3) Intact reuse: not returned |
619 | */ |
620 | static std::unordered_map<const Object*, StmtSRef> Prune(ScheduleStateNode* self, |
621 | const ReuseInfo& reuse_info, |
622 | const Stmt& src_stmt) { |
623 | SRefTreePruner pruner(self, reuse_info); |
624 | pruner.VisitStmt(src_stmt); |
625 | return std::move(pruner.reused_srefs_); |
626 | } |
627 | |
628 | private: |
629 | explicit SRefTreePruner(ScheduleStateNode* self, const ReuseInfo& reuse_info) |
630 | : self_(self), reuse_info_(reuse_info) {} |
631 | |
632 | void VisitStmt_(const ForNode* op) final { |
633 | if (reuse_info_.intact.count(op)) { |
634 | return; |
635 | } |
636 | auto it = self_->stmt2ref.find(op); |
637 | ICHECK(it != self_->stmt2ref.end()) |
638 | << "IndexError: Cannot find corresponding StmtSRef for the loop:\n" |
639 | << GetRef<For>(op); |
640 | StmtSRef& sref = it->second; |
641 | // Detect reuse |
642 | const VarNode* loop_var = op->loop_var.get(); |
643 | if (reuse_info_.loop_sref_possible_reuse.count(loop_var)) { |
644 | // sref can be reused |
645 | reused_srefs_.emplace(loop_var, std::move(sref)); |
646 | } else { |
647 | sref->Reset(); |
648 | } |
649 | // erase the statement |
650 | self_->stmt2ref.erase(it); |
651 | // detect recursively |
652 | VisitStmt(op->body); |
653 | } |
654 | |
655 | void VisitStmt_(const BlockNode* op) final { |
656 | if (reuse_info_.intact.count(op)) { |
657 | return; |
658 | } |
659 | auto it = self_->stmt2ref.find(op); |
660 | ICHECK(it != self_->stmt2ref.end()) |
661 | << "IndexError: Cannot find corresponding StmtSRef for the block:\n" |
662 | << GetRef<Block>(op); |
663 | StmtSRef& sref = it->second; |
664 | // Detect reuse |
665 | const auto& sref_reuse = reuse_info_.block_sref_reuse; |
666 | if (auto reuse_it = sref_reuse.find(op); reuse_it != sref_reuse.end()) { |
667 | const BlockNode* to_reuse = reuse_it->second; |
668 | // sref can be reused |
669 | reused_srefs_.emplace(to_reuse, std::move(sref)); |
670 | } else { |
671 | sref->Reset(); |
672 | self_->block_info.erase(sref); |
673 | } |
674 | // erase the statement |
675 | self_->stmt2ref.erase(it); |
676 | // detect recursively |
677 | // op->init is omitted |
678 | VisitStmt(op->body); |
679 | } |
680 | |
681 | /*! \brief The schedule state we are working on */ |
682 | ScheduleStateNode* self_; |
683 | /*! \brief The reuse information we collected previously */ |
684 | const ReuseInfo& reuse_info_; |
685 | /*! |
686 | * \brief Reused srefs: |
687 | * 1) loop var -> StmtSRef |
688 | * 2) block stmt -> StmtSRef, where the block comes from the subtree of `tgt_stmt` |
689 | */ |
690 | std::unordered_map<const Object*, StmtSRef> reused_srefs_; |
691 | }; |
692 | |
693 | /*! |
694 | * \brief Update the sref in the `tgt_stmt` given the reuse information |
695 | * |
696 | * After being updated, in the `tgt_stmt` subtree, |
697 | * 1) all `StmtSRefNode::parent`s are correct |
698 | * 2) all `StmtSRefNode::seq_index`s are correct, except for the root |
699 | * 3) all `StmtSRefNode::stmt`s are correct, except for the root |
700 | */ |
701 | class SRefUpdater : public StmtVisitor { |
702 | public: |
703 | static void Update(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, |
704 | const std::unordered_map<const Object*, StmtSRef>& reused_srefs, |
705 | const Stmt& tgt_stmt) { |
706 | SRefUpdater(self, src_stmt_parent, reused_srefs).VisitStmt(tgt_stmt); |
707 | } |
708 | |
709 | private: |
710 | explicit SRefUpdater(ScheduleStateNode* self, StmtSRefNode* src_stmt_parent, |
711 | const std::unordered_map<const Object*, StmtSRef>& reused_srefs) |
712 | : self_(GetRef<ScheduleState>(self)), |
713 | ancestors_{src_stmt_parent}, |
714 | reused_srefs_(reused_srefs) {} |
715 | |
716 | void VisitStmt_(const ForNode* op) final { |
717 | StmtSRef& sref = self_->stmt2ref[op]; |
718 | // Detect intact reuse |
719 | if (sref.defined()) { |
720 | sref->parent = ancestors_.back(); |
721 | sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex |
722 | return; |
723 | } |
724 | // Detect loop reuse |
725 | auto it = reused_srefs_.find(op->loop_var.get()); |
726 | if (it != reused_srefs_.end()) { |
727 | // Update `stmt2ref[op]` to `reused_srefs_[op->loop_var]` |
728 | sref = it->second; |
729 | sref->stmt = op; |
730 | sref->parent = ancestors_.back(); |
731 | sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex |
732 | } else { |
733 | // A new loop sref without reuse |
734 | sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(), |
735 | /*seq_index=*/-1); // `seq_index` will be set properly in SetSeqIndex |
736 | } |
737 | // Recursive visit |
738 | ancestors_.push_back(sref.get()); |
739 | VisitStmt(op->body); |
740 | ancestors_.pop_back(); |
741 | } |
742 | |
743 | void VisitStmt_(const BlockNode* op) final { |
744 | StmtSRef& sref = self_->stmt2ref[op]; |
745 | // Detect intact |
746 | if (sref.defined()) { |
747 | sref->parent = ancestors_.back(); |
748 | sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex |
749 | return; |
750 | } |
751 | // Detect block reuse |
752 | auto it = reused_srefs_.find(op); |
753 | if (it != reused_srefs_.end()) { |
754 | // Update `stmt2ref[op]` to `reused_srefs_[op]` |
755 | sref = it->second; |
756 | sref->stmt = op; |
757 | sref->parent = ancestors_.back(); |
758 | sref->seq_index = -1; // `seq_index` will be set properly in SetSeqIndex |
759 | } else { |
760 | // A new block sref without reuse |
761 | sref = StmtSRef(/*stmt=*/op, /*parent=*/ancestors_.back(), |
762 | /*seq_index=*/-1); // `seq_index` will be set properly in SetSeqIndex |
763 | } |
764 | // Recursive visit |
765 | ancestors_.push_back(sref.get()); |
766 | VisitStmt(op->body); |
767 | ancestors_.pop_back(); |
    // Additionally, update the scope information because the block has changed
769 | UpdateBlockInfo(sref); |
770 | } |
771 | |
772 | void VisitStmt_(const SeqStmtNode* seq_stmt) final { |
773 | StmtVisitor::VisitStmt_(seq_stmt); |
774 | SetSeqIndexInChildren(self_.get(), seq_stmt); |
775 | } |
776 | |
777 | void UpdateBlockInfo(const StmtSRef& block_sref) { |
778 | using TIter = std::unordered_map<StmtSRef, BlockInfo, ObjectPtrHash, ObjectPtrEqual>::iterator; |
779 | // The caller is responsible for correcting the flags |
780 | BlockInfo new_info((BlockScope(GetChildBlockSRefOnSRefTree(self_, block_sref)))); |
781 | std::pair<TIter, bool> insert_result = self_->block_info.emplace(block_sref, new_info); |
782 | bool inserted = insert_result.second; |
    BlockInfo& info = insert_result.first->second;
    if (inserted) {
      // Insertion has happened, update the flags accordingly
787 | info.affine_binding = false; |
788 | info.region_cover = false; |
789 | info.scope->stage_pipeline = false; |
790 | } else { |
      // Insertion didn't take place because the entry was already there.
      // In this case, we assume the flags are still valid, so we intentionally keep them unchanged
793 | new_info.scope->stage_pipeline = info.scope->stage_pipeline; |
794 | info.scope = std::move(new_info.scope); |
795 | } |
796 | } |
797 | |
  /*! \brief The schedule state to be worked on */
799 | ScheduleState self_; |
800 | /*! \brief A stack containing all the ancestor For/Block nodes during the visit */ |
801 | std::vector<StmtSRefNode*> ancestors_; |
802 | /*! \brief Maps the loop var / block to the reused sref */ |
803 | const std::unordered_map<const Object*, StmtSRef>& reused_srefs_; |
804 | }; |
805 | |
806 | /*! |
807 | * \brief A helper that returns a new copy of `parent_stmt`, |
808 | * where the subtree `child_src_stmt` is replaced with the subtree `child_tgt_stmt`. |
809 | * \note The visitor assumes `child_src_stmt` is the child of `parent_stmt` in the sref tree. |
810 | */ |
811 | class ChildReplacer : private StmtMutator { |
812 | public: |
813 | static Stmt Replace(const StmtNode* parent_stmt, const StmtNode* child_src_stmt, |
814 | const Stmt& child_tgt_stmt, int seq_index, bool allow_copy_on_write) { |
815 | // Check the invariant |
816 | ICHECK(child_src_stmt->IsInstance<BlockNode>() || // |
817 | child_src_stmt->IsInstance<ForNode>()); |
818 | ICHECK(child_tgt_stmt->IsInstance<BlockNode>() || // |
819 | child_tgt_stmt->IsInstance<ForNode>() || // |
820 | child_tgt_stmt->IsInstance<BlockRealizeNode>()); |
821 | ChildReplacer replacer(child_src_stmt, child_tgt_stmt, seq_index); |
822 | replacer.allow_copy_on_write_ = allow_copy_on_write; |
823 | return replacer.CopyOnWriteAndVisit(parent_stmt); |
824 | } |
825 | |
826 | private: |
827 | explicit ChildReplacer(const StmtNode* src_stmt, const Stmt& tgt_stmt, int seq_index) |
828 | : src_stmt_(src_stmt), tgt_stmt_(tgt_stmt), seq_index_(seq_index) {} |
829 | |
830 | Stmt VisitStmt(const Stmt& stmt) final { |
831 | if (stmt.get() == src_stmt_) { |
832 | // If the statement matches the `src_stmt` to be replaced, just return the `tgt_stmt` |
833 | return tgt_stmt_; |
834 | } else { |
835 | return StmtMutator::VisitStmt(stmt); |
836 | } |
837 | } |
838 | |
839 | // Skipping sibling blocks and loops other than `src_stmt_` |
840 | Stmt VisitStmt_(const BlockNode* op) final { return GetRef<Stmt>(op); } |
841 | Stmt VisitStmt_(const ForNode* op) final { return GetRef<Stmt>(op); } |
842 | |
843 | Stmt VisitStmt_(const SeqStmtNode* op) final { |
844 | int i = this->seq_index_; |
845 | int n = static_cast<int>(op->seq.size()); |
846 | if (0 <= i && i < n) { |
847 | const Stmt& stmt = op->seq[i]; |
848 | Optional<Stmt> new_stmt = NullOpt; |
849 | const StmtNode* src_stmt = this->src_stmt_; |
850 | // `stmt` can be For or BlockRealize |
851 | // `src_stmt` can be For or Block |
852 | // so the match from `stmt` to `src_stmt` can be |
853 | // 1) For -> For |
854 | // 2) BlockRealize -> Block |
855 | if (stmt.get() == src_stmt) { |
856 | // Case 1. src_stmt is For, stmt is For |
857 | new_stmt = tgt_stmt_; |
858 | } else if (const auto* realize = stmt.as<BlockRealizeNode>()) { |
859 | // Case 2. stmt is BlockRealize, src_stmt is Block |
860 | if (realize->block.get() == src_stmt) { |
861 | const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode); |
862 | ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize); |
863 | new_realize->block = GetRef<Block>(tgt_block); |
864 | new_stmt = BlockRealize(std::move(new_realize)); |
865 | } |
866 | } |
867 | // Move new_stmt to position i |
868 | if (new_stmt.defined()) { |
869 | ObjectPtr<SeqStmtNode> new_seq_stmt = CopyOnWrite(op); |
870 | new_seq_stmt->seq.Set(i, new_stmt.value()); |
871 | return SeqStmt(std::move(new_seq_stmt)); |
872 | } |
873 | } |
874 | return StmtMutator::VisitStmt_(op); |
875 | } |
876 | |
877 | Stmt CopyOnWriteAndVisit(const StmtNode* parent_stmt) { |
878 | // Step 1. Copy-on-write the `parent_stmt` and extract its `body`, |
879 | // where `body` means the body of either a block or a loop |
880 | // Step 2. Mutate the `block/loop->body`, searching for `child_old_stmt` |
881 | // and replace it with `child_tgt_stmt` |
882 | if (parent_stmt->IsInstance<BlockNode>()) { |
883 | auto* block = const_cast<BlockNode*>(static_cast<const BlockNode*>(parent_stmt)); |
884 | ObjectPtr<BlockNode> new_block = CopyOnWrite(block); |
885 | new_block->body = this->VisitStmt(new_block->body); |
886 | return Block(std::move(new_block)); |
887 | } else if (parent_stmt->IsInstance<ForNode>()) { |
888 | auto* loop = const_cast<ForNode*>(static_cast<const ForNode*>(parent_stmt)); |
889 | ObjectPtr<ForNode> new_loop = CopyOnWrite(loop); |
890 | new_loop->body = this->VisitStmt(new_loop->body); |
891 | return For(std::move(new_loop)); |
892 | } |
893 | LOG(FATAL) << "TypeError: Unexpected type: " << parent_stmt->GetTypeKey(); |
894 | throw; |
895 | } |
896 | |
897 | /*! \brief The `src_stmt` to be replaced */ |
898 | const StmtNode* src_stmt_; |
  /*! \brief The `tgt_stmt` that replaces the `src_stmt` */
900 | const Stmt& tgt_stmt_; |
901 | /*! |
902 | * \brief The `seq_index` of the `src_stmt` |
903 | * \sa StmtSRefNode |
904 | */ |
905 | int seq_index_; |
906 | }; |
907 | |
908 | void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, |
909 | const Map<Block, Block>& _block_sref_reuse) { |
910 | if (this->debug_mask != 0) { |
911 | const StmtNode* src_stmt = _src_sref->stmt; |
912 | bool input_correct = |
913 | (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<ForNode>()) || |
914 | (src_stmt->IsInstance<ForNode>() && tgt_stmt->IsInstance<BlockRealizeNode>()) || |
915 | (src_stmt->IsInstance<BlockNode>() && tgt_stmt->IsInstance<BlockNode>()); |
916 | if (!input_correct) { |
917 | LOG(FATAL) << "TypeError: src_stmt has type: " << src_stmt->GetTypeKey() |
918 | << ". tgt_stmt has type: " << tgt_stmt->GetTypeKey() << ".\nsrc_stmt:\n" |
919 | << GetRef<Stmt>(src_stmt) << "\ntgt_stmt:\n" |
920 | << tgt_stmt; |
921 | } |
922 | } |
923 | // Rule out the case that no replacement happens |
924 | if (_src_sref->stmt == tgt_stmt.get()) { |
925 | return; |
926 | } |
927 | // Reset sref as a new sref so that its content won't be affected by subsequent changes |
928 | StmtSRef src_sref(_src_sref->stmt, _src_sref->parent, _src_sref->seq_index); |
929 | Stmt src_stmt = GetRef<Stmt>(src_sref->stmt); |
930 | // Step 1. Create all the nodes needed for the new sref tree. |
931 | // After this step |
932 | // 1) all `parent`s are correct |
933 | // 2) all `seq_index`s are correct, except for the root |
934 | // 3) all `stmt`s are correct, except for the root |
935 | { |
936 | // Step 0. Setup block_sref_reuse |
937 | std::unordered_map<const BlockNode*, const BlockNode*> block_sref_reuse; |
938 | block_sref_reuse.reserve(_block_sref_reuse.size() + 1); |
939 | for (const auto& kv : _block_sref_reuse) { |
940 | block_sref_reuse.emplace(kv.first.get(), kv.second.get()); |
941 | } |
942 | // Step 1.1. Collect info for different kinds of reuses |
943 | // 1) intact |
944 | // 2) loop/block reuse |
945 | ReuseInfo reuse_info = ReuseCollector::Collect(this, tgt_stmt); |
946 | reuse_info.block_sref_reuse = std::move(block_sref_reuse); |
947 | // Step 1.2. Collect loop/block reuse to their corresponding srefs |
948 | // and remove those srefs in the `src_stmt` that are no longer used after replacement |
949 | std::unordered_map<const Object*, StmtSRef> reused_srefs = |
950 | SRefTreePruner::Prune(this, reuse_info, src_stmt); |
951 | // Step 1.3. Update the sref tree, inserting newly created srefs and properly handle reused |
952 | // srefs in `tgt_stmt` |
953 | SRefUpdater::Update(this, src_sref->parent, reused_srefs, tgt_stmt); |
954 | } |
955 | // Step 2. Set the ancestors' children properly |
  // Iteratively visit the ancestors, creating new ones whose `body`s are properly fixed.
  // The visit stops when all the remaining ancestors are uniquely referenced, i.e. can be
  // mutated in place. Along the way, because we create a new ancestor path, we need to update
  // the srefs that point to the old ancestors so that they point to the newly created ones
960 | // Variables: |
961 | // 1) `num_copy_steps`. The maximum number of hops until we need to copy. To reach a node that |
962 | // can be mutated inplace, it needs `num_copy_steps + 1` hops. |
963 | // 2) `need_module_copy`. If true, need to mutate the PrimFunc and IRModule the sref belongs to. |
964 | // 3) `g_var` and `g_func`. Indicate which GlobalVar and PrimFunc the sref corresponds to |
965 | int num_copy_steps = -1; |
966 | bool need_module_copy = false; |
967 | const PrimFuncNode* g_func = nullptr; |
968 | GlobalVar g_var; |
969 | { |
970 | int i = 0; |
971 | const StmtSRefNode* p = src_sref.get(); |
972 | while (true) { |
973 | if (!p->stmt->unique()) { |
974 | num_copy_steps = i; |
975 | } |
976 | if (p->parent == nullptr) { |
977 | break; |
978 | } |
979 | ++i; |
980 | p = p->parent; |
981 | } |
982 | // Find `g_func` and `g_var` where the `src_sref` is in |
983 | g_func = GetRootPrimFunc(this->mod, p->stmt, &g_var); |
984 | need_module_copy = num_copy_steps == i || // |
985 | !this->mod.unique() || // |
986 | !this->mod->functions.unique() || // |
987 | !g_func->unique(); |
988 | } |
989 | // Loop invariant: |
990 | // |
991 | // Before step `i`: |
992 | // 1) `child_sref` is `src_sref` going up by `i` steps |
993 | // 2) `child_tgt_stmt` is the subtree that `child_sref` should correspond to after replacement |
994 | // 3) except for the subtree root, srefs that point to the subtree of `child_tgt_stmt` are correct |
995 | // 4) for the subtree root of `child_tgt_stmt`, `child_sref` has not pointed to it yet |
  // 5) `tgt_stmt` is of type For, Block or BlockRealize
997 | // |
998 | // During step `i`: |
999 | // 1) Create `parent_stmt` that corresponds to `child_sref->parent` |
1000 | // 2) Point `child_sref` to `child_tgt_stmt` |
  // 3) `tgt_stmt` is of type For or Block
1002 | StmtSRefNode* child_sref = src_sref.get(); |
1003 | Stmt child_tgt_stmt = std::move(tgt_stmt); |
1004 | for (int i = 0; (need_module_copy || i <= num_copy_steps) && child_sref->parent != nullptr; ++i) { |
1005 | bool can_directly_mutate_parent = !need_module_copy && i == num_copy_steps; |
1006 | // Replace `child_sref->stmt` to `child_tgt_stmt`. |
1007 | const StmtNode* parent_stmt = child_sref->parent->stmt; |
1008 | const StmtNode* child_src_stmt = child_sref->stmt; |
1009 | // Step 2.1. Link `child_sref` to `child_tgt_stmt` |
1010 | if (i == 0) { |
      // As an invariant of SRefUpdater,
      // the `seq_index` of the root of `tgt_stmt` is set to -1,
      // which might be incorrect
1014 | SetSeqIndex(this, child_tgt_stmt, child_sref->seq_index); |
1015 | } else { |
1016 | // Point `child_sref` to `child_tgt_stmt` |
1017 | UpdateSRef(this, child_sref, child_tgt_stmt.get()); |
1018 | } |
1019 | // Step 2.2. Create `new_parent_stmt`, by mutating the body of `parent_stmt` |
1020 | Stmt new_parent_stmt = |
1021 | ChildReplacer::Replace(parent_stmt, child_src_stmt, child_tgt_stmt, |
1022 | /*seq_index=*/child_sref->seq_index, |
1023 | /*allow_copy_on_write=*/can_directly_mutate_parent); |
1024 | // Step 2.3. Go to next parent |
1025 | if (can_directly_mutate_parent) { |
1026 | // If the node can be directly mutated inplace, |
1027 | // then there is no need to update its parent and the function |
1028 | break; |
1029 | } |
1030 | child_tgt_stmt = std::move(new_parent_stmt); |
1031 | child_sref = child_sref->parent; |
1032 | } |
1033 | // Step 3. Handle the case that we mutate the root |
1034 | if (need_module_copy) { |
1035 | // From the loop invariant, upon exit, while its subtree is properly set, |
    // `child_sref` has not been pointed to `child_tgt_stmt` yet.
1037 | if (src_sref->parent != nullptr) { |
1038 | // Not replacing a root |
1039 | UpdateSRef(this, child_sref, child_tgt_stmt.get()); |
1040 | } |
1041 | // Ensure the uniqueness of `this->mod` and `this->mod->functions` |
1042 | IRModuleNode* new_mod = this->mod.CopyOnWrite(); |
1043 | MapNode* new_map = new_mod->functions.CopyOnWrite(); |
    // Move out the PrimFunc where the sref belongs, while ensuring uniqueness
1045 | PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var))); |
1046 | ICHECK(ref_new_func.get() == g_func); |
1047 | PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); |
    // If `g_func` was not unique, after the 3 lines above:
    //   `ref_new_func` points to a new, unique PrimFunc
    //   `g_func` still points to the previous, shared PrimFunc
    // If `g_func` was unique, after the 3 lines above:
    //   `ref_new_func` points to the same unique function that `g_func` points to
    // Update the body of the function the sref belongs to
1054 | const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode); |
1055 | // Make `child_tgt_stmt` the root block |
1056 | const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode); |
1057 | ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize); |
1058 | new_realize->block = GetRef<Block>(child_block); |
1059 | new_func->body = BlockRealize(std::move(new_realize)); |
1060 | // Finally, move the `ref_new_func` back and update `this->mod` |
1061 | new_map->at(g_var) = std::move(ref_new_func); |
1062 | this->mod = GetRef<IRModule>(new_mod); |
1063 | } |
1064 | uint32_t flag = (debug_mask != -1) // |
1065 | ? static_cast<uint32_t>(debug_mask) // |
1066 | : std::numeric_limits<uint32_t>::max(); |
1067 | if (flag & ScheduleDebugMask::kVerifySRefTree) { |
1068 | VerifySRefTree(GetRef<ScheduleState>(this)); |
1069 | } |
1070 | } |
1071 | |
1072 | void ScheduleStateNode::DebugVerify() const { |
1073 | ICHECK_GE(debug_mask, -1); |
1074 | uint32_t flag = (debug_mask != -1) // |
1075 | ? static_cast<uint32_t>(debug_mask) // |
1076 | : std::numeric_limits<uint32_t>::max(); |
1077 | if (flag & ScheduleDebugMask::kVerifySRefTree) { |
1078 | VerifySRefTree(GetRef<ScheduleState>(this)); |
1079 | } |
1080 | if (flag & ScheduleDebugMask::kVerifyCachedFlags) { |
1081 | VerifyCachedFlags(GetRef<ScheduleState>(this)); |
1082 | } |
1083 | } |
1084 | |
1085 | /**************** BlockInfo-related ****************/ |
1086 | |
1087 | BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { |
1088 | TVM_SREF_TO_BLOCK(block_sref); |
1089 | auto it = this->block_info.find(block_sref); |
1090 | CHECK(it != this->block_info.end()) |
1091 | << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" |
1092 | << GetRef<Stmt>(block_sref->stmt); |
1093 | return it->second; |
1094 | } |
1095 | |
1096 | void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) { |
1097 | BlockInfoCollector::Collect(this, stmt); |
1098 | } |
1099 | |
1100 | TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) { |
1101 | const BlockInfo& info = self->GetBlockInfo(block_sref); |
1102 | return {Bool(info.affine_binding), // |
1103 | Bool(info.region_cover), // |
1104 | Bool(info.scope->stage_pipeline)}; |
1105 | } |
1106 | |
1107 | /**************** FFI ****************/ |
1108 | |
1109 | TVM_REGISTER_NODE_TYPE(ScheduleStateNode); |
1110 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState" ) |
1111 | .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState { |
1112 | return ScheduleState(mod, debug_mask); |
1113 | }); |
1114 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope" ) |
1115 | .set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope); |
1116 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateReplace" ) |
1117 | .set_body_method<ScheduleState>(&ScheduleStateNode::Replace); |
1118 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetSRef" ) |
1119 | .set_body_typed([](ScheduleState self, Stmt stmt) -> Optional<StmtSRef> { |
1120 | auto it = self->stmt2ref.find(stmt.get()); |
1121 | return it != self->stmt2ref.end() ? it->second : Optional<StmtSRef>(NullOpt); |
1122 | }); |
1123 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetCachedFlags" ).set_body_typed(GetCachedFlags); |
1124 | |
1125 | } // namespace tir |
1126 | } // namespace tvm |
1127 | |