1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../../transforms/ir_utils.h" |
20 | #include "../utils.h" |
21 | |
22 | namespace tvm { |
23 | namespace tir { |
24 | |
25 | /*! \brief Information used to create new padding block */ |
26 | struct PaddingBlockInfo { |
27 | /*! \brief In-bound block iter regions, wrt loop vars. */ |
28 | Array<Range> in_bound_region; |
29 | /*! \brief In-bound value, wrt block iter vars. */ |
30 | PrimExpr in_bound_value; |
31 | /*! \brief Condition of in-bound write, wrt loop vars. */ |
32 | PrimExpr in_bound_predicate; |
33 | /*! \brief Padding value, should be a constant. */ |
34 | PrimExpr pad_value; |
35 | }; |
36 | |
37 | class PaddingPatternMatchError : public ScheduleError { |
38 | public: |
39 | PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg) |
40 | : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {} |
41 | |
42 | String FastErrorString() const final { |
43 | return "ScheduleError: decompose_padding expect the block to match padding pattern\n " + |
44 | error_msg_; |
45 | } |
46 | |
47 | String DetailRenderTemplate() const final { |
48 | std::ostringstream os; |
49 | os << "ScheduleError: decompose_padding expect the block {0} to match padding pattern\n " |
50 | << error_msg_; |
51 | return os.str(); |
52 | } |
53 | |
54 | IRModule mod() const final { return mod_; } |
55 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
56 | |
57 | IRModule mod_; |
58 | Block block_; |
59 | std::string error_msg_; |
60 | }; |
61 | |
62 | /*! |
63 | * \brief Helper class to analyze and check the padding pattern of the block, |
64 | * then return the padding information. |
65 | */ |
/*!
 * \brief Helper class to analyze and check the padding pattern of the block,
 * then return the padding information.
 */
class PaddingInfoAnalyzer {
 public:
  /*!
   * \brief Analyze the block realization and extract its padding info.
   * \param mod The module, used only for error reporting.
   * \param realize The block realization to analyze.
   * \param dom_map Ranges of the loop vars appearing in the block's iter bindings.
   * \param analyzer Arithmetic analyzer with the loop vars already bound.
   * \return The extracted padding information.
   * \throws PaddingPatternMatchError if the block does not match the padding pattern.
   */
  static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize,
                                                 const Map<Var, Range>& dom_map,
                                                 arith::Analyzer* analyzer) {
    PaddingInfoAnalyzer padding_analyzer(analyzer);
    if (!padding_analyzer.MatchPadding(realize, dom_map)) {
      throw PaddingPatternMatchError(mod, realize->block, padding_analyzer.error_msg_);
    }
    return padding_analyzer.info_;
  }

 private:
  explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {}

  /*! \brief Detect padding pattern and update result. Returns false (with error_msg_ set)
   *  on the first check that fails. */
  bool MatchPadding(const BlockRealizeNode* realize, const Map<Var, Range>& dom_map) {
    // Step 1. Check match padding computation pattern.
    //   A[...] = T.if_then_else(predicate, B[...], imm)
    Block block = realize->block;
    // Map each block iter var to its binding value, so expressions written
    // wrt block vars can be rewritten wrt loop vars.
    std::unordered_map<const VarNode*, PrimExpr> iter_values;
    for (size_t i = 0; i < realize->iter_values.size(); ++i) {
      Var block_var = block->iter_vars[i]->var;
      iter_values[block_var.get()] = realize->iter_values[i];
    }
    const BufferStoreNode* store = block->body.as<BufferStoreNode>();
    if (!store) {
      SetError("Block body expect a BufferStore to the write buffer");
      return false;
    }
    // The stored value must be an if_then_else intrinsic call.
    const CallNode* if_then_else = store->value.as<CallNode>();
    if (!if_then_else || !if_then_else->op.same_as(tir::builtin::if_then_else())) {
      SetError("Value of BufferStore expect to be constrained by a padding predicate");
      return false;
    }
    // args[0]: padding condition, args[1]: in-bound value, args[2]: pad value.
    PrimExpr pad_predicate = Substitute(if_then_else->args[0], iter_values);
    PrimExpr in_bound_value = if_then_else->args[1];
    PrimExpr pad_value = if_then_else->args[2];
    if (!is_const_number(pad_value)) {
      SetError("Pad value should be constant");
      return false;
    }

    // Step 2. Check the in-bound computation is side-effect free (reads at most).
    if (SideEffect(if_then_else->args[1]) > CallEffectKind::kReadState) {
      SetError("Inbound computation should not have side-effect");
      return false;
    }

    // Step 3. Analyze in-bound write region.
    // Conjoin the padding condition with the block realize predicate.
    PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
    // A predicate that is always true means there is no padding to decompose.
    if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
      SetError("The in-bound predicate is trivial");
      return false;
    }
    Array<Range> in_bound_region = this->EstimateInBoundRegion(
        /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
        /*in_bound_predicate=*/in_bound_predicate);
    if (in_bound_region.empty()) {
      // error_msg_ was already set by EstimateInBoundRegion.
      return false;
    }

    // Step 4. Update result information.
    info_.in_bound_value = if_then_else->args[1];
    info_.in_bound_region = in_bound_region;
    info_.in_bound_predicate = in_bound_predicate;
    info_.pad_value = pad_value;
    return true;
  }

  /*! \brief Rewrite predicate to left recursive conjunction, drop likely annotation. */
  PrimExpr RewritePredicate(const PrimExpr& predicate) {
    PrimExpr res = const_true();
    // Recursively split conjunctions; for each leaf term, unwrap a
    // builtin::likely(...) annotation before appending it to `res`.
    std::function<void(PrimExpr)> update = [&res, &update](PrimExpr e) {
      arith::PVar<PrimExpr> a, b;
      if ((a && b).Match(e)) {
        update(a.Eval());
        update(b.Eval());
      } else {
        if (const CallNode* call = e.as<CallNode>()) {
          if (call->op.same_as(builtin::likely())) {
            e = call->args[0];
          }
        }
        res = res && e;
      }
    };
    update(predicate);
    return analyzer_->Simplify(res);
  }

  /*! \brief Return iteration region of block vars where the padding predicate evals to true.
   *  Returns an empty array (with error_msg_ set) if the region cannot be estimated. */
  Array<Range> EstimateInBoundRegion(const Array<PrimExpr>& iter_values,
                                     const Map<Var, Range>& dom_map,
                                     const PrimExpr& in_bound_predicate) {
    Array<Range> region;

    // Detect an iter mapping of the bindings constrained by the in-bound predicate;
    // failure means the block iters are not independent under that constraint.
    auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate,
                                    arith::IterMapLevel::Surjective, analyzer_);
    if (res->indices.empty()) {
      SetError("Block iters are not independent wrt padding condition");
      return {};
    }
    for (const arith::IterSumExpr& sum : res->indices) {
      if (sum->args.empty()) {
        // Constant binding: a single point.
        region.push_back(Range::FromMinExtent(sum->base, 1));
      } else {
        ICHECK_EQ(sum->args.size(), 1U);
        // Only unit-stride iteration is supported.
        if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) {
          SetError("Strided iteration is not supported");
          return {};
        }
        region.push_back(Range::FromMinExtent(sum->base, sum->args[0]->extent));
      }
    }
    return region;
  }

  // Record the failure reason for the eventual PaddingPatternMatchError.
  void SetError(const std::string& msg) { error_msg_ = msg; }

  /*! \brief padding info analyse result. */
  PaddingBlockInfo info_;
  /*! \brief current error message. */
  std::string error_msg_;
  /*! \brief arithmetic analyzer. */
  arith::Analyzer* analyzer_;
};
193 | |
194 | /*! \brief Create block to fill constant pad values into full region */ |
/*!
 * \brief Create block to fill constant pad values into full region.
 * \param realize The original padding block realization.
 * \param info The analyzed padding info (provides pad_value).
 * \param loops The loops around the original block; the nest construction below
 *  wraps loops[0] innermost.
 * \param highest_pos_inclusive The loop up to which (inclusive) new loops are created.
 * \param analyzer Arithmetic analyzer used for simplification.
 * \return The new loop nest and the const-filling block realization it contains.
 */
static std::pair<Stmt, BlockRealize> CreateConstBlock(const BlockRealizeNode* realize,
                                                      const PaddingBlockInfo& info,
                                                      const Array<For>& loops,
                                                      const Stmt& highest_pos_inclusive,
                                                      arith::Analyzer* analyzer) {
  const Block& block = realize->block;
  Array<IterVar> new_iter_vars;
  Map<Var, PrimExpr> repl_dict;

  // create new block itervars: same domains as the originals, fresh vars,
  // and all forced to data-parallel.
  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
    const IterVar& origin_iter = block->iter_vars[i];
    Var new_var = origin_iter->var.copy_with_suffix("");
    new_iter_vars.push_back(IterVar(origin_iter->dom, new_var, IterVarType::kDataPar));
    repl_dict.Set(origin_iter->var, new_var);
  }

  // rewrite expr helper; `repl_dict` is captured by reference, so rewrites done
  // after more entries are added (the loop vars below) will also use them.
  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
    return analyzer->Simplify(Substitute(e, repl_dict));
  };

  // create new write region (at this point only block vars are replaced)
  ICHECK_EQ(block->writes.size(), 1U);
  BufferRegion write_region = BufferRegion(
      block->writes[0]->buffer, block->writes[0]->region.Map([rewrite_expr](const Range& r) {
        return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
      }));

  // create block to fill const pad values: reuse the original store, but write
  // the constant pad value; the new block reads nothing.
  BufferStore store = Downcast<BufferStore>(block->body);
  store.CopyOnWrite()->value = info.pad_value;
  store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr);
  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region},
                  /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store));

  // create new loop vars up to `highest_pos_inclusive`; also extend repl_dict
  // so the iter bindings below are rewritten wrt the new loop vars.
  Array<Var> new_loop_vars;
  for (const For& loop : loops) {
    Var new_var = loop->loop_var.copy_with_suffix("");
    new_loop_vars.push_back(new_var);
    repl_dict.Set(loop->loop_var, new_var);
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }

  // create new block realize node with rewritten bindings and predicate
  Array<PrimExpr> new_iter_values;
  for (size_t i = 0; i < realize->iter_values.size(); ++i) {
    new_iter_values.push_back(rewrite_expr(realize->iter_values[i]));
  }
  BlockRealize new_realize(/*iter_values=*/new_iter_values,
                           /*predicate=*/rewrite_expr(realize->predicate),
                           /*block=*/new_block);

  // create new loops: wrap bottom-up, so loops[0] becomes the innermost loop
  // of the new (serial) nest.
  Stmt nest_stmt_root = new_realize;
  for (size_t i = 0; i < new_loop_vars.size(); ++i) {
    For loop = loops[i];
    nest_stmt_root =
        For(new_loop_vars[i], loop->min, loop->extent, ForKind::kSerial, nest_stmt_root);
  }

  return {nest_stmt_root, new_realize};
}
261 | |
262 | /*! \brief Create block to fill in-bound region values. */ |
/*!
 * \brief Create block to fill in-bound region values.
 * \param realize The original padding block realization.
 * \param info The analyzed padding info (in-bound region/value/predicate).
 * \param loops The loops around the original block; the nest construction below
 *  wraps loops[0] innermost.
 * \param highest_pos_inclusive The loop up to which (inclusive) loops are recreated.
 * \param analyzer Arithmetic analyzer used for simplification and rebinding.
 * \return The new loop nest and the in-bound-filling block realization it contains.
 */
static std::pair<Stmt, BlockRealize> CreateInBoundBlock(const BlockRealizeNode* realize,
                                                        const PaddingBlockInfo& info,
                                                        const Array<For>& loops,
                                                        const Stmt& highest_pos_inclusive,
                                                        arith::Analyzer* analyzer) {
  const Block& block = realize->block;
  Array<IterVar> new_iter_vars;
  Map<Var, PrimExpr> repl_dict;

  // record loop ranges to be mutated (only loops up to `highest_pos_inclusive`)
  Map<Var, Range> new_loop_ranges;
  for (const For& loop : loops) {
    new_loop_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }

  // create new block iter vars and iter bindings; each new iter ranges over
  // [0, extent) of the in-bound region, and the original var maps to
  // `new_var + region_min`.
  Array<PrimExpr> new_iter_binding;
  for (size_t i = 0; i < info.in_bound_region.size(); ++i) {
    // add new block itervar
    const IterVar& origin_itervar = block->iter_vars[i];
    Var new_var = origin_itervar->var.copy_with_suffix("");
    Range new_range =
        Range::FromMinExtent(make_const(new_var->dtype, 0), info.in_bound_region[i]->extent);
    new_iter_vars.push_back(IterVar(new_range, new_var, IterVarType::kDataPar));
    repl_dict.Set(origin_itervar->var, new_var + info.in_bound_region[i]->min);

    // update new loop range
    Var loop_var = GetRef<Var>(realize->iter_values[i].as<VarNode>());
    if (loop_var.defined() && new_loop_ranges.count(loop_var)) {
      // if the block binding is the loop var with single child, mutate loop range
      // instead of insert extra block predicate
      new_loop_ranges.Set(loop_var, new_range);
      new_iter_binding.push_back(realize->iter_values[i]);
      repl_dict.Set(loop_var, loop_var + info.in_bound_region[i]->min);
      // rebind the loop var to the trimmed range for later simplifications
      analyzer->Bind(loop_var, new_range, /*allow_override=*/true);
    } else {
      // otherwise shift the binding so the new iter starts at zero
      new_iter_binding.push_back(
          analyzer->Simplify(realize->iter_values[i] - info.in_bound_region[i]->min));
    }
  }

  // rewrite helpers: substitute by `repl_dict` and simplify
  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
    return analyzer->Simplify(Substitute(e, repl_dict));
  };
  auto rewrite_region = [rewrite_expr](const Region& region) {
    return region.Map([rewrite_expr](const Range& r) {
      return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
    });
  };

  // create new read/write region for in-bound accesses
  Array<BufferRegion> reads, writes;
  for (const BufferRegion& read : block->reads) {
    reads.push_back(BufferRegion(read->buffer, rewrite_region(read->region)));
  }
  for (const BufferRegion& write : block->writes) {
    writes.push_back(BufferRegion(write->buffer, rewrite_region(write->region)));
  }

  // create new block realize node: store the in-bound value directly (the
  // if_then_else wrapper is gone; the in-bound predicate guards the realize).
  BufferStore store = Downcast<BufferStore>(block->body);
  store.CopyOnWrite()->value = rewrite_expr(info.in_bound_value);
  store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr);
  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes,
                  /*name_hint=*/block->name_hint, /*body=*/std::move(store));
  PrimExpr new_predicate = rewrite_expr(info.in_bound_predicate);
  BlockRealize new_realize(/*iter_values=*/new_iter_binding, /*predicate=*/new_predicate,
                           /*block=*/new_block);

  // create new loops: wrap bottom-up up to `highest_pos_inclusive`, keeping each
  // loop's kind/thread binding/annotations but using the trimmed range if any.
  Stmt nest_stmt_root = new_realize;
  for (const For& loop : loops) {
    auto it = new_loop_ranges.find(loop->loop_var);
    PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min;
    PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent;
    nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root,
                         loop->thread_binding, loop->annotations, loop->span);
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }
  return {nest_stmt_root, new_realize};
}
351 | |
352 | /*! |
353 | * \brief A helper class to create a new scope that contains decomposed padding blocks. |
354 | */ |
355 | class DecomposePaddingBlockReplacer : public StmtMutator { |
356 | public: |
357 | /*! \brief Replacement information */ |
358 | struct ReplaceDesc { |
359 | /*! \brief loop above which to insert const pad value filling code. */ |
360 | For const_filling_pos; |
361 | /*! \brief loop under which to insert in bound value filling code. */ |
362 | For in_bound_filling_pos; |
363 | /*! \brief const pad value filling loop. */ |
364 | Stmt const_filling_loop; |
365 | /*! \brief highest in bound value filling loop with single child. */ |
366 | Stmt in_bound_filling_loop; |
367 | /*! \brief const pad value filling block. */ |
368 | BlockRealize const_filling_block; |
369 | /*! \brief in bound value filling block. */ |
370 | BlockRealize in_bound_filling_block; |
371 | }; |
372 | |
373 | static Block Replace(Block scope_root, const ReplaceDesc& desc) { |
374 | DecomposePaddingBlockReplacer replacer(desc); |
375 | return Downcast<Block>(replacer(std::move(scope_root))); |
376 | } |
377 | |
378 | private: |
379 | explicit DecomposePaddingBlockReplacer(const ReplaceDesc& desc) : desc_(desc) {} |
380 | |
381 | Stmt VisitStmt_(const ForNode* op) final { |
382 | Stmt new_loop; |
383 | if (op == desc_.in_bound_filling_pos.get()) { |
384 | // position to rewrite inbound filling code |
385 | new_loop = desc_.in_bound_filling_loop; |
386 | } else { |
387 | new_loop = StmtMutator::VisitStmt_(op); |
388 | } |
389 | if (op == desc_.const_filling_pos.get()) { |
390 | // position to insert pad value filling code |
391 | return std::move(SeqStmt({desc_.const_filling_loop, new_loop})); |
392 | } |
393 | return std::move(new_loop); |
394 | } |
395 | |
396 | Stmt VisitStmt_(const SeqStmtNode* seq) final { |
397 | Array<Stmt> new_stmts; |
398 | new_stmts.reserve(seq->seq.size()); |
399 | for (const Stmt& old_stmt : seq->seq) { |
400 | new_stmts.push_back(VisitStmt(old_stmt)); |
401 | } |
402 | return SeqStmt::Flatten(new_stmts); |
403 | } |
404 | |
405 | private: |
406 | const ReplaceDesc& desc_; |
407 | }; |
408 | |
/*!
 * \brief Implementation of decompose_padding.
 *
 * Check
 * - the block is a complete block
 * - the loop is an ancestor of the block
 * - the block matches the padding pattern
 * Mutate
 * - generate a new block to fill padding values
 * - trim the original block to write the non-padding part only
 *
 * \param self The schedule state.
 * \param block_sref The sref of the padding block.
 * \param loop_sref The loop above which the const filling block is inserted.
 * \param check_only If true, perform all checks and construction but do not
 *  mutate the schedule state; returns `block_sref` unchanged in that case.
 * \return The sref of the newly created const pad value filling block
 *  (or `block_sref` when check_only).
 */
StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
                              const StmtSRef& loop_sref, bool check_only) {
  // Condition Checks and Information Collection
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
  Map<Var, Range> dom_map;
  arith::Analyzer analyzer;

  // Check 1. check the block is complete.
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  CheckCompleteBlock(self, block_sref, scope_root_sref);

  // Check 2. Check loop_sref is an ancestor of block_sref. Also collect
  // - the highest loop position (inclusive) to insert const pad value filling code above.
  // - the highest loop position (inclusive) to replace with in-bound value filling code.
  Array<StmtSRef> loop_srefs = GetLoops(block_sref);
  Array<For> loops;
  bool found_const_filling_pos = false;
  bool found_in_bound_filling_pos = false;
  // NOTE(review): assumes loop_sref refers to a For; StmtAs returns nullptr
  // otherwise, which is only detected later via found_const_filling_pos.
  For const_filling_pos = GetRef<For>(loop_sref->StmtAs<ForNode>());
  For in_bound_filling_pos{nullptr};
  // Traverse loop_srefs in reverse order; bind each loop var's range for the analyzer.
  for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) {
    For cur_loop = GetRef<For>((*it)->StmtAs<ForNode>());
    Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent);
    dom_map.Set(cur_loop->loop_var, range);
    analyzer.Bind(cur_loop->loop_var, range);
    loops.push_back(cur_loop);

    if (cur_loop.same_as(const_filling_pos)) {
      ICHECK(!found_const_filling_pos);
      found_const_filling_pos = true;
      if (!found_in_bound_filling_pos) {
        found_in_bound_filling_pos = true;
        in_bound_filling_pos = cur_loop;
      }
    } else if (!found_in_bound_filling_pos) {
      // Keep extending the in-bound filling position while the loop body is a
      // single-child chain (a nested For or a BlockRealize).
      if (!cur_loop->body->IsInstance<ForNode>() &&
          !cur_loop->body->IsInstance<BlockRealizeNode>()) {
        found_in_bound_filling_pos = true;
      } else {
        in_bound_filling_pos = cur_loop;
      }
    }
  }
  ICHECK(in_bound_filling_pos.defined());
  if (!found_const_filling_pos) {
    throw LoopPositionError(self->mod, const_filling_pos, GetRef<Block>(block),
                            "decompose_padding");
  }

  // Check 3. match padding pattern and return padding operation info.
  PaddingBlockInfo info =
      PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, &analyzer);

  // IR Manipulation
  // Step 1. Create const pad value filling part and in-bound value filling part.
  DecomposePaddingBlockReplacer::ReplaceDesc replace_desc;
  replace_desc.const_filling_pos = const_filling_pos;
  replace_desc.in_bound_filling_pos = in_bound_filling_pos;
  std::tie(replace_desc.const_filling_loop, replace_desc.const_filling_block) =
      CreateConstBlock(realize, info, loops, const_filling_pos, &analyzer);
  std::tie(replace_desc.in_bound_filling_loop, replace_desc.in_bound_filling_block) =
      CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer);

  // Step 2. Execute IR replacement. The new scope root is constructed even in
  // check-only mode so that construction errors surface before returning.
  Block old_scope_root_block = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc);
  if (check_only) {
    return block_sref;
  }

  // Step 3. Update schedule states.
  self->Replace(scope_root_sref, new_scope_root,
                {{old_scope_root_block, new_scope_root},
                 {GetRef<Block>(block), replace_desc.in_bound_filling_block->block}});
  auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get());

  // Set block info of created const pad value filling block
  BlockInfo& block_info = self->block_info[new_block_sref];
  block_info.affine_binding = true;
  block_info.region_cover = true;
  block_info.scope->stage_pipeline = true;

  // If the const pad value filling block is lifted out of the original subtree,
  // set the region_cover flag as false since region_cover is the property under the subtree.
  bool preserve_stage_pipeline = true;
  for (const StmtSRef& consumer_sref : GetConsumers(self, block_sref)) {
    StmtSRef lca = GetSRefLowestCommonAncestor({consumer_sref, block_sref});
    // Walk up from the new block to see whether it still lives under the LCA
    // of the original block and this consumer.
    const StmtSRefNode* parent = new_block_sref->parent;
    bool is_under_lca = false;
    while (parent) {
      if (parent == lca.get()) {
        is_under_lca = true;
        break;
      }
      parent = parent->parent;
    }
    if (!is_under_lca) {
      preserve_stage_pipeline = false;
      self->block_info[consumer_sref].region_cover = false;
    }
  }
  if (!preserve_stage_pipeline) {
    self->block_info[scope_root_sref].scope->stage_pipeline = false;
  }
  return new_block_sref;
}
525 | |
/*!
 * \brief Decompose the padding block over the given loop, mutating the schedule state.
 * \return The sref of the newly created const pad value filling block.
 */
StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref) {
  return DecomposePaddingImpl(self, block_sref, loop_sref, false);
}
530 | |
531 | bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref, |
532 | const StmtSRef& loop_sref) { |
533 | try { |
534 | DecomposePaddingImpl(self, block_sref, loop_sref, true); |
535 | } catch (const tvm::runtime::Error& e) { |
536 | return false; |
537 | } |
538 | return true; |
539 | } |
540 | |
541 | /******** FFI ********/ |
542 | |
// FFI binding: expose the applicability check to the frontend.
TVM_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding")
    .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) {
      return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv));
    });
547 | |
548 | /******** InstructionKind Registration ********/ |
549 | |
550 | struct DecomposPaddingTraits : public UnpackedInstTraits<DecomposPaddingTraits> { |
551 | static constexpr const char* kName = "DecomposePadding" ; |
552 | static constexpr bool kIsPure = false; |
553 | |
554 | private: |
555 | static constexpr size_t kNumInputs = 2; |
556 | static constexpr size_t kNumAttrs = 0; |
557 | static constexpr size_t kNumDecisions = 0; |
558 | |
559 | static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) { |
560 | return sch->DecomposePadding(block_rv, loop_rv); |
561 | } |
562 | |
563 | static String UnpackedAsPython(Array<String> outputs, String block_rv, LoopRV loop_rv) { |
564 | PythonAPICall py("decompose_padding" ); |
565 | py.Input("block" , block_rv); |
566 | py.Input("loop" , loop_rv); |
567 | py.SingleOutput(outputs); |
568 | return py.Str(); |
569 | } |
570 | |
571 | template <typename> |
572 | friend struct ::tvm::tir::UnpackedInstTraits; |
573 | }; |
574 | |
575 | TVM_REGISTER_INST_KIND_TRAITS(DecomposPaddingTraits); |
576 | |
577 | } // namespace tir |
578 | } // namespace tvm |
579 | |