/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../../transforms/ir_utils.h"
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Information used to create the new padding blocks. */
struct PaddingBlockInfo {
  /*! \brief In-bound block iter regions, wrt loop vars. */
  Array<Range> in_bound_region;
  /*! \brief In-bound value, wrt block iter vars. */
  PrimExpr in_bound_value;
  /*! \brief Condition of in-bound write, wrt loop vars. */
  PrimExpr in_bound_predicate;
  /*! \brief Padding value, should be a constant. */
  PrimExpr pad_value;
};
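// For intuition, consider an illustrative padding block (names are hypothetical):
//   A[vi, vj] = T.if_then_else(vi < 14 and vj < 30, B[vi, vj], 0.0)
// with bindings vi = i, vj = j for i in [0, 16), j in [0, 32). The analysis below
// would fill in roughly:
//   in_bound_region    = [[0, 14), [0, 30)]   (wrt the loop vars i, j)
//   in_bound_value     = B[vi, vj]
//   in_bound_predicate = (i < 14) && (j < 30)
//   pad_value          = 0.0f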

class PaddingPatternMatchError : public ScheduleError {
 public:
  PaddingPatternMatchError(IRModule mod, Block block, const std::string& error_msg)
      : mod_(std::move(mod)), block_(std::move(block)), error_msg_(error_msg) {}

  String FastErrorString() const final {
    return "ScheduleError: decompose_padding expects the block to match the padding pattern\n " +
           error_msg_;
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "ScheduleError: decompose_padding expects the block {0} to match the padding pattern\n "
       << error_msg_;
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

  IRModule mod_;
  Block block_;
  std::string error_msg_;
};

/*!
 * \brief Helper class to analyze and check the padding pattern of the block,
 * then return the padding information.
 */
class PaddingInfoAnalyzer {
 public:
  static PaddingBlockInfo CheckAndGetPaddingInfo(IRModule mod, const BlockRealizeNode* realize,
                                                 const Map<Var, Range>& dom_map,
                                                 arith::Analyzer* analyzer) {
    PaddingInfoAnalyzer padding_analyzer(analyzer);
    if (!padding_analyzer.MatchPadding(realize, dom_map)) {
      throw PaddingPatternMatchError(mod, realize->block, padding_analyzer.error_msg_);
    }
    return padding_analyzer.info_;
  }

 private:
  explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {}

  /*! \brief Detect the padding pattern and update the result. */
  bool MatchPadding(const BlockRealizeNode* realize, const Map<Var, Range>& dom_map) {
    // Step 1. Check that the block matches the padding computation pattern:
    //   A[...] = T.if_then_else(predicate, B[...], imm)
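    //   e.g. (illustrative) A[vi, vj] = T.if_then_else(vi < 14 and vj < 30, B[vi, vj], 0.0)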
    Block block = realize->block;
    std::unordered_map<const VarNode*, PrimExpr> iter_values;
    for (size_t i = 0; i < realize->iter_values.size(); ++i) {
      Var block_var = block->iter_vars[i]->var;
      iter_values[block_var.get()] = realize->iter_values[i];
    }
    const BufferStoreNode* store = block->body.as<BufferStoreNode>();
    if (!store) {
      SetError("The block body is expected to be a BufferStore to the write buffer");
      return false;
    }
    const CallNode* if_then_else = store->value.as<CallNode>();
    if (!if_then_else || !if_then_else->op.same_as(tir::builtin::if_then_else())) {
      SetError("The value of the BufferStore is expected to be constrained by a padding predicate");
      return false;
    }
    PrimExpr pad_predicate = Substitute(if_then_else->args[0], iter_values);
    PrimExpr in_bound_value = if_then_else->args[1];
    PrimExpr pad_value = if_then_else->args[2];
    if (!is_const_number(pad_value)) {
      SetError("The pad value should be a constant");
      return false;
    }

    // Step 2. Check that the in-bound computation is free of side effects.
    if (SideEffect(in_bound_value) > CallEffectKind::kReadState) {
      SetError("In-bound computation should not have side effects");
      return false;
    }

    // Step 3. Analyze the in-bound write region.
    PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
    if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
      SetError("The in-bound predicate is trivial");
      return false;
    }
    Array<Range> in_bound_region = this->EstimateInBoundRegion(
        /*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
        /*in_bound_predicate=*/in_bound_predicate);
    if (in_bound_region.empty()) {
      return false;
    }

    // Step 4. Update the result information.
    info_.in_bound_value = in_bound_value;
    info_.in_bound_region = in_bound_region;
    info_.in_bound_predicate = in_bound_predicate;
    info_.pad_value = pad_value;
    return true;
  }

  /*! \brief Rewrite the predicate to a left-recursive conjunction and drop likely annotations. */
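  // For example (illustrative): T.likely(i < 14) && (j < 30) is split into its
  // conjuncts, the likely() wrappers are dropped, and the result is rebuilt as
  // ((True && (i < 14)) && (j < 30)) before simplification.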
  PrimExpr RewritePredicate(const PrimExpr& predicate) {
    PrimExpr res = const_true();
    std::function<void(PrimExpr)> update = [&res, &update](PrimExpr e) {
      arith::PVar<PrimExpr> a, b;
      if ((a && b).Match(e)) {
        update(a.Eval());
        update(b.Eval());
      } else {
        if (const CallNode* call = e.as<CallNode>()) {
          if (call->op.same_as(builtin::likely())) {
            e = call->args[0];
          }
        }
        res = res && e;
      }
    };
    update(predicate);
    return analyzer_->Simplify(res);
  }

  /*! \brief Return the iteration region of the block vars where the padding predicate
   * evaluates to true. */
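  // For example (illustrative): with bindings (vi=i, vj=j), i in [0, 16), j in [0, 32),
  // and the in-bound predicate (i < 14) && (j < 30), the result is [[0, 14), [0, 30)].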
  Array<Range> EstimateInBoundRegion(const Array<PrimExpr>& iter_values,
                                     const Map<Var, Range>& dom_map,
                                     const PrimExpr& in_bound_predicate) {
    Array<Range> region;

    auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate,
                                    arith::IterMapLevel::Surjective, analyzer_);
    if (res->indices.empty()) {
      SetError("Block iters are not independent wrt the padding condition");
      return {};
    }
    for (const arith::IterSumExpr& sum : res->indices) {
      if (sum->args.empty()) {
        region.push_back(Range::FromMinExtent(sum->base, 1));
      } else {
        ICHECK_EQ(sum->args.size(), 1U);
        if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) {
          SetError("Strided iteration is not supported");
          return {};
        }
        region.push_back(Range::FromMinExtent(sum->base, sum->args[0]->extent));
      }
    }
    return region;
  }

  void SetError(const std::string& msg) { error_msg_ = msg; }

  /*! \brief Padding info analysis result. */
  PaddingBlockInfo info_;
  /*! \brief Current error message. */
  std::string error_msg_;
  /*! \brief Arithmetic analyzer. */
  arith::Analyzer* analyzer_;
};

/*! \brief Create a block that fills the constant pad value into the full write region. */
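// For the illustrative example above (block name assumed to be "pad"), the generated
// filling part is roughly the following TVMScript-like sketch:
//   for i, j in T.grid(16, 32):
//     with T.block("pad_pad_const"):
//       vi, vj = T.axis.remap("SS", [i, j])
//       A[vi, vj] = 0.0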
static std::pair<Stmt, BlockRealize> CreateConstBlock(const BlockRealizeNode* realize,
                                                      const PaddingBlockInfo& info,
                                                      const Array<For>& loops,
                                                      const Stmt& highest_pos_inclusive,
                                                      arith::Analyzer* analyzer) {
  const Block& block = realize->block;
  Array<IterVar> new_iter_vars;
  Map<Var, PrimExpr> repl_dict;

  // create new block iter vars
  for (size_t i = 0; i < block->iter_vars.size(); ++i) {
    const IterVar& origin_iter = block->iter_vars[i];
    Var new_var = origin_iter->var.copy_with_suffix("");
    new_iter_vars.push_back(IterVar(origin_iter->dom, new_var, IterVarType::kDataPar));
    repl_dict.Set(origin_iter->var, new_var);
  }

  // rewrite expr helper
  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
    return analyzer->Simplify(Substitute(e, repl_dict));
  };

  // create new write region
  ICHECK_EQ(block->writes.size(), 1U);
  BufferRegion write_region = BufferRegion(
      block->writes[0]->buffer, block->writes[0]->region.Map([rewrite_expr](const Range& r) {
        return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
      }));

  // create block to fill const pad values
  BufferStore store = Downcast<BufferStore>(block->body);
  store.CopyOnWrite()->value = info.pad_value;
  store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr);
  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/{}, /*writes=*/{write_region},
                  /*name_hint=*/block->name_hint + "_pad_const", /*body=*/std::move(store));

  // create new loop vars
  Array<Var> new_loop_vars;
  for (const For& loop : loops) {
    Var new_var = loop->loop_var.copy_with_suffix("");
    new_loop_vars.push_back(new_var);
    repl_dict.Set(loop->loop_var, new_var);
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }

  // create new block realize node
  Array<PrimExpr> new_iter_values;
  for (size_t i = 0; i < realize->iter_values.size(); ++i) {
    new_iter_values.push_back(rewrite_expr(realize->iter_values[i]));
  }
  BlockRealize new_realize(/*iter_values=*/new_iter_values,
                           /*predicate=*/rewrite_expr(realize->predicate),
                           /*block=*/new_block);

  // create new loops, wrapping from the innermost loop outwards
  Stmt nest_stmt_root = new_realize;
  for (size_t i = 0; i < new_loop_vars.size(); ++i) {
    For loop = loops[i];
    nest_stmt_root =
        For(new_loop_vars[i], loop->min, loop->extent, ForKind::kSerial, nest_stmt_root);
  }

  return {nest_stmt_root, new_realize};
}

/*! \brief Create a block that computes the values of the in-bound region. */
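// For the illustrative example above, the generated in-bound part is roughly the
// following TVMScript-like sketch:
//   for i, j in T.grid(14, 30):
//     with T.block("pad"):
//       vi, vj = T.axis.remap("SS", [i, j])
//       A[vi, vj] = B[vi, vj]
// If an in-bound region has a non-zero min, e.g. [2, 14), the new block iter starts
// at 0 with extent 12 and the original iter is replaced by (new_var + 2).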
static std::pair<Stmt, BlockRealize> CreateInBoundBlock(const BlockRealizeNode* realize,
                                                        const PaddingBlockInfo& info,
                                                        const Array<For>& loops,
                                                        const Stmt& highest_pos_inclusive,
                                                        arith::Analyzer* analyzer) {
  const Block& block = realize->block;
  Array<IterVar> new_iter_vars;
  Map<Var, PrimExpr> repl_dict;

  // record loop ranges to be mutated
  Map<Var, Range> new_loop_ranges;
  for (const For& loop : loops) {
    new_loop_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }

  // create new block iter vars and iter bindings
  Array<PrimExpr> new_iter_binding;
  for (size_t i = 0; i < info.in_bound_region.size(); ++i) {
    // add the new block iter var
    const IterVar& origin_itervar = block->iter_vars[i];
    Var new_var = origin_itervar->var.copy_with_suffix("");
    Range new_range =
        Range::FromMinExtent(make_const(new_var->dtype, 0), info.in_bound_region[i]->extent);
    new_iter_vars.push_back(IterVar(new_range, new_var, IterVarType::kDataPar));
    repl_dict.Set(origin_itervar->var, new_var + info.in_bound_region[i]->min);

    // update the new loop range
    Var loop_var = GetRef<Var>(realize->iter_values[i].as<VarNode>());
    if (loop_var.defined() && new_loop_ranges.count(loop_var)) {
      // If the block binding is a loop var with a single child, mutate the loop range
      // instead of inserting an extra block predicate.
      new_loop_ranges.Set(loop_var, new_range);
      new_iter_binding.push_back(realize->iter_values[i]);
      repl_dict.Set(loop_var, loop_var + info.in_bound_region[i]->min);
      analyzer->Bind(loop_var, new_range, /*allow_override=*/true);
    } else {
      new_iter_binding.push_back(
          analyzer->Simplify(realize->iter_values[i] - info.in_bound_region[i]->min));
    }
  }

  // rewrite helpers
  auto rewrite_expr = [&repl_dict, analyzer](const PrimExpr& e) {
    return analyzer->Simplify(Substitute(e, repl_dict));
  };
  auto rewrite_region = [rewrite_expr](const Region& region) {
    return region.Map([rewrite_expr](const Range& r) {
      return Range::FromMinExtent(rewrite_expr(r->min), rewrite_expr(r->extent));
    });
  };

  // create new read/write regions for the in-bound accesses
  Array<BufferRegion> reads, writes;
  for (const BufferRegion& read : block->reads) {
    reads.push_back(BufferRegion(read->buffer, rewrite_region(read->region)));
  }
  for (const BufferRegion& write : block->writes) {
    writes.push_back(BufferRegion(write->buffer, rewrite_region(write->region)));
  }

  // create new block realize node
  BufferStore store = Downcast<BufferStore>(block->body);
  store.CopyOnWrite()->value = rewrite_expr(info.in_bound_value);
  store.CopyOnWrite()->indices = store->indices.Map(rewrite_expr);
  Block new_block(/*iter_vars=*/new_iter_vars, /*reads=*/reads, /*writes=*/writes,
                  /*name_hint=*/block->name_hint, /*body=*/std::move(store));
  PrimExpr new_predicate = rewrite_expr(info.in_bound_predicate);
  BlockRealize new_realize(/*iter_values=*/new_iter_binding, /*predicate=*/new_predicate,
                           /*block=*/new_block);

  // create new loops
  Stmt nest_stmt_root = new_realize;
  for (const For& loop : loops) {
    auto it = new_loop_ranges.find(loop->loop_var);
    PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min;
    PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent;
    nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root,
                         loop->thread_binding, loop->annotations, loop->span);
    if (loop.same_as(highest_pos_inclusive)) {
      break;
    }
  }
  return {nest_stmt_root, new_realize};
}

/*!
 * \brief A helper class to create a new scope that contains the decomposed padding blocks.
 */
class DecomposePaddingBlockReplacer : public StmtMutator {
 public:
  /*! \brief Replacement information. */
  struct ReplaceDesc {
    /*! \brief The loop above which to insert the const pad value filling code. */
    For const_filling_pos;
    /*! \brief The loop under which to insert the in-bound value filling code. */
    For in_bound_filling_pos;
    /*! \brief The const pad value filling loop. */
    Stmt const_filling_loop;
    /*! \brief The highest in-bound value filling loop with a single child. */
    Stmt in_bound_filling_loop;
    /*! \brief The const pad value filling block. */
    BlockRealize const_filling_block;
    /*! \brief The in-bound value filling block. */
    BlockRealize in_bound_filling_block;
  };
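
  // Illustrative effect of Replace (sketch): when const_filling_pos and
  // in_bound_filling_pos refer to the same loop, a nest
  //   for i: <original padding block>
  // in the scope root is rewritten into
  //   SeqStmt([<const_filling_loop>, <in_bound_filling_loop>])
  // i.e. the const filling loop is inserted right above the in-bound filling loop.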

  static Block Replace(Block scope_root, const ReplaceDesc& desc) {
    DecomposePaddingBlockReplacer replacer(desc);
    return Downcast<Block>(replacer(std::move(scope_root)));
  }

 private:
  explicit DecomposePaddingBlockReplacer(const ReplaceDesc& desc) : desc_(desc) {}

  Stmt VisitStmt_(const ForNode* op) final {
    Stmt new_loop;
    if (op == desc_.in_bound_filling_pos.get()) {
      // position to rewrite with the in-bound filling code
      new_loop = desc_.in_bound_filling_loop;
    } else {
      new_loop = StmtMutator::VisitStmt_(op);
    }
    if (op == desc_.const_filling_pos.get()) {
      // position to insert the const pad value filling code
      return SeqStmt({desc_.const_filling_loop, new_loop});
    }
    return new_loop;
  }

  Stmt VisitStmt_(const SeqStmtNode* seq) final {
    Array<Stmt> new_stmts;
    new_stmts.reserve(seq->seq.size());
    for (const Stmt& old_stmt : seq->seq) {
      new_stmts.push_back(VisitStmt(old_stmt));
    }
    return SeqStmt::Flatten(new_stmts);
  }

 private:
  const ReplaceDesc& desc_;
};

StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
                              const StmtSRef& loop_sref, bool check_only) {
  /*!
   * Check that
   * - the block is a complete block,
   * - the loop is an ancestor of the block,
   * - the block matches the padding pattern.
   * Mutate
   * - generate a new block that fills the padding values,
   * - trim the original block to write only the non-padding part.
   */
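  // Illustrative overall effect (sketch), with loop_sref pointing to the i loop:
  //   for i, j in T.grid(16, 32):
  //     with T.block("pad"):
  //       A[i, j] = T.if_then_else(i < 14 and j < 30, B[i, j], 0.0)
  // is decomposed into
  //   for i, j in T.grid(16, 32):
  //     with T.block("pad_pad_const"):
  //       A[i, j] = 0.0
  //   for i, j in T.grid(14, 30):
  //     with T.block("pad"):
  //       A[i, j] = B[i, j]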
  // Condition checks and information collection
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
  Map<Var, Range> dom_map;
  arith::Analyzer analyzer;

  // Check 1. Check the block is complete.
  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  CheckCompleteBlock(self, block_sref, scope_root_sref);

  // Check 2. Check loop_sref is an ancestor of block_sref. Also collect
  // - the highest loop position (inclusive) above which to insert the const pad value filling
  //   code;
  // - the highest loop position (inclusive) to replace with the in-bound value filling code.
  Array<StmtSRef> loop_srefs = GetLoops(block_sref);
  Array<For> loops;
  bool found_const_filling_pos = false;
  bool found_in_bound_filling_pos = false;
  For const_filling_pos = GetRef<For>(loop_sref->StmtAs<ForNode>());
  For in_bound_filling_pos{nullptr};
  for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) {
    For cur_loop = GetRef<For>((*it)->StmtAs<ForNode>());
    Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent);
    dom_map.Set(cur_loop->loop_var, range);
    analyzer.Bind(cur_loop->loop_var, range);
    loops.push_back(cur_loop);

    if (cur_loop.same_as(const_filling_pos)) {
      ICHECK(!found_const_filling_pos);
      found_const_filling_pos = true;
      if (!found_in_bound_filling_pos) {
        found_in_bound_filling_pos = true;
        in_bound_filling_pos = cur_loop;
      }
    } else if (!found_in_bound_filling_pos) {
      if (!cur_loop->body->IsInstance<ForNode>() &&
          !cur_loop->body->IsInstance<BlockRealizeNode>()) {
        found_in_bound_filling_pos = true;
      } else {
        in_bound_filling_pos = cur_loop;
      }
    }
  }
  ICHECK(in_bound_filling_pos.defined());
  if (!found_const_filling_pos) {
    throw LoopPositionError(self->mod, const_filling_pos, GetRef<Block>(block),
                            "decompose_padding");
  }

  // Check 3. Match the padding pattern and return the padding operation info.
  PaddingBlockInfo info =
      PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, &analyzer);

  // IR manipulation
  // Step 1. Create the const pad value filling part and the in-bound value filling part.
  DecomposePaddingBlockReplacer::ReplaceDesc replace_desc;
  replace_desc.const_filling_pos = const_filling_pos;
  replace_desc.in_bound_filling_pos = in_bound_filling_pos;
  std::tie(replace_desc.const_filling_loop, replace_desc.const_filling_block) =
      CreateConstBlock(realize, info, loops, const_filling_pos, &analyzer);
  std::tie(replace_desc.in_bound_filling_loop, replace_desc.in_bound_filling_block) =
      CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer);

  // Step 2. Execute the IR replacement.
  Block old_scope_root_block = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
  Block new_scope_root = DecomposePaddingBlockReplacer::Replace(old_scope_root_block, replace_desc);
  if (check_only) {
    return block_sref;
  }

  // Step 3. Update the schedule states.
  self->Replace(scope_root_sref, new_scope_root,
                {{old_scope_root_block, new_scope_root},
                 {GetRef<Block>(block), replace_desc.in_bound_filling_block->block}});
  auto new_block_sref = self->stmt2ref.at(replace_desc.const_filling_block->block.get());

  // Set the block info of the created const pad value filling block.
  BlockInfo& block_info = self->block_info[new_block_sref];
  block_info.affine_binding = true;
  block_info.region_cover = true;
  block_info.scope->stage_pipeline = true;

  // If the const pad value filling block is lifted out of the original subtree, set the
  // region_cover flag of the affected consumers to false, since region_cover is a property
  // defined under the subtree.
  bool preserve_stage_pipeline = true;
  for (const StmtSRef& consumer_sref : GetConsumers(self, block_sref)) {
    StmtSRef lca = GetSRefLowestCommonAncestor({consumer_sref, block_sref});
    const StmtSRefNode* parent = new_block_sref->parent;
    bool is_under_lca = false;
    while (parent) {
      if (parent == lca.get()) {
        is_under_lca = true;
        break;
      }
      parent = parent->parent;
    }
    if (!is_under_lca) {
      preserve_stage_pipeline = false;
      self->block_info[consumer_sref].region_cover = false;
    }
  }
  if (!preserve_stage_pipeline) {
    self->block_info[scope_root_sref].scope->stage_pipeline = false;
  }
  return new_block_sref;
}

StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref) {
  return DecomposePaddingImpl(self, block_sref, loop_sref, false);
}

bool CanDecomposePadding(ScheduleState self, const StmtSRef& block_sref,
                         const StmtSRef& loop_sref) {
  try {
    DecomposePaddingImpl(self, block_sref, loop_sref, true);
  } catch (const tvm::runtime::Error& e) {
    return false;
  }
  return true;
}

/******** FFI ********/

TVM_REGISTER_GLOBAL("tir.schedule.CanDecomposePadding")
    .set_body_typed([](Schedule self, BlockRV block_rv, LoopRV loop_rv) {
      return CanDecomposePadding(self->state(), self->GetSRef(block_rv), self->GetSRef(loop_rv));
    });
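
// Illustrative Python-side usage (sketch; block/loop names are hypothetical):
//   sch = tir.Schedule(mod)
//   pad_block = sch.get_block("pad")
//   const_fill = sch.decompose_padding(pad_block, sch.get_loops(pad_block)[0])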

/******** InstructionKind Registration ********/

struct DecomposePaddingTraits : public UnpackedInstTraits<DecomposePaddingTraits> {
  static constexpr const char* kName = "DecomposePadding";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) {
    return sch->DecomposePadding(block_rv, loop_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv) {
    PythonAPICall py("decompose_padding");
    py.Input("block", block_rv);
    py.Input("loop", loop_rv);
    py.SingleOutput(outputs);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(DecomposePaddingTraits);

}  // namespace tir
}  // namespace tvm