1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */
25class BlockPredicateAppender : public StmtMutator {
26 public:
27 /*!
28 * \brief Constructor
29 * \param to_append The predicate to be appended to BlockRealizeNode
30 */
31 explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {}
32
33 private:
34 // For each direct child of type BlockRealizeNode, append the predicate
35 Stmt VisitStmt_(const BlockRealizeNode* realize) final {
36 // We do not recursively do this
37 ObjectPtr<BlockRealizeNode> n = CopyOnWrite(realize);
38 n->predicate = n->predicate && to_append_;
39 return BlockRealize(n);
40 }
41
42 /*! \brief The predicate to be appended */
43 const PrimExpr& to_append_;
44};
45
46/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */
47class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
48 public:
49 explicit SubstituteVarAndCollectOpaqueBlock(std::function<Optional<PrimExpr>(const Var&)> vmap,
50 Map<Block, Block>* opaque_blocks)
51 : vmap_(vmap), opaque_blocks_(opaque_blocks) {}
52
53 private:
54 PrimExpr VisitExpr_(const VarNode* op) final {
55 Var var = GetRef<Var>(op);
56 if (Optional<PrimExpr> ret = vmap_(var)) {
57 return tvm::cast(var.dtype(), ret.value());
58 } else {
59 return std::move(var);
60 }
61 }
62
63 Stmt VisitStmt_(const BlockRealizeNode* op) final {
64 BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
65 if (realize->block->iter_vars.empty()) {
66 opaque_blocks_->Set(op->block, realize->block);
67 }
68 return std::move(realize);
69 }
70
71 /*! \brief The substitute function */
72 std::function<Optional<PrimExpr>(const Var&)> vmap_;
73 /*! \brief The reuse mapping of opaque blocks */
74 Map<Block, Block>* opaque_blocks_;
75};
76
77/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
78class IterMapSimplifyBlockBinding : public StmtExprMutator {
79 public:
80 explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent,
81 bool preserve_unit_iters)
82 : opaque_blocks_(opaque_blocks),
83 loop_var2extent_(loop_var2extent),
84 preserve_unit_iters_(preserve_unit_iters) {}
85
86 static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks,
87 bool preserve_unit_iters) {
88 Map<Var, Range> loop_var2extent;
89 for (const StmtSRef& sref : loop_srefs) {
90 const ForNode* loop = TVM_SREF_TO_FOR(sref);
91 loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
92 }
93 return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
94 preserve_unit_iters)(std::move(stmt)));
95 }
96
97 private:
98 Stmt VisitStmt_(const ForNode* op) final {
99 loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent));
100 Stmt res = StmtMutator::VisitStmt_(op);
101 loop_var2extent_.erase(op->loop_var);
102 return res;
103 }
104
105 Stmt VisitStmt_(const BlockRealizeNode* op) final {
106 // skip opaque block and update mapping
107 if (op->iter_values.empty()) {
108 Block block = op->block;
109 BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
110 for (const std::pair<ObjectRef, ObjectRef>& entry : *opaque_blocks_) {
111 if (entry.second.same_as(block)) {
112 opaque_blocks_->at(entry.first) = realize->block;
113 break;
114 }
115 }
116 return std::move(realize);
117 }
118 Array<PrimExpr> v =
119 arith::IterMapSimplify(/*indices=*/op->iter_values,
120 /*input_iters=*/loop_var2extent_,
121 /*input_pred=*/op->predicate,
122 /*check_level=*/arith::IterMapLevel::Surjective,
123 /*simplify_trivial_iterators=*/!preserve_unit_iters_);
124 if (v.same_as(op->iter_values)) {
125 return GetRef<Stmt>(op);
126 } else {
127 ObjectPtr<BlockRealizeNode> n = CopyOnWrite(op);
128 n->iter_values = std::move(v);
129 return Stmt(n);
130 }
131 }
132
133 /*! \brief The reuse mapping */
134 MapNode* opaque_blocks_;
135 /*! \brief The range of loops */
136 Map<Var, Range> loop_var2extent_;
137 /*! \brief Whether or not to simplify unit iterators */
138 bool preserve_unit_iters_;
139};
140
141class BlockPropertyError : public ScheduleError {
142 public:
143 /*!
144 * \brief Check that all the blocks under the specific stmt have affine bindings
145 * wrt top loop sref and only have data-parallel or reduction block iters
146 * \param self The state of the schedule
147 * \param sref The sref to the specific stmt
148 */
149 static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top,
150 const StmtSRefNode* sref) {
151 class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
152 public:
153 explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state,
154 const StmtSRefNode* top)
155 : state_(state), top_(top) {}
156
157 private:
158 void VisitStmt_(const BlockNode* op) final {
159 for (const IterVar& iter_var : op->iter_vars) {
160 if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
161 throw BlockPropertyError(state_->mod, GetRef<Block>(op));
162 }
163 Optional<StmtSRef> high_exclusive =
164 top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt);
165 CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive);
166 }
167 }
168 const ScheduleState& state_;
169 const StmtSRefNode* top_;
170 };
171
172 BlockIterTypeAndAffineBindingChecker checker(self, top);
173 checker(GetRef<Stmt>(sref->stmt));
174 }
175
176 explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {}
177
178 String FastErrorString() const final {
179 return "ScheduleError: The block under the loops to be reordered have block iter type other "
180 "than data-parallel or reduction";
181 }
182
183 String DetailRenderTemplate() const final {
184 return "The block {0} under the loops to be reordered have block iter type other than "
185 "data-parallel or reduction";
186 }
187
188 IRModule mod() const final { return mod_; }
189 Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
190
191 IRModule mod_;
192 Block block_;
193};
194
195class HasAnnotationOrThreadBindingError : public ScheduleError {
196 public:
197 explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop)
198 : mod_(mod), loop_(std::move(loop)) {}
199
200 String FastErrorString() const final {
201 return "ScheduleError: The primitive can't be applied because the loop has annotation or "
202 "thread binding";
203 }
204
205 String DetailRenderTemplate() const final {
206 return "The primitive can't be applied because the loop {0} has annotation or thread binding";
207 }
208
209 IRModule mod() const final { return mod_; }
210 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
211
212 IRModule mod_;
213 For loop_;
214};
215
216class OuterNotInnerParent : public ScheduleError {
217 public:
218 explicit OuterNotInnerParent(IRModule mod, For outer, For inner)
219 : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}
220
221 String FastErrorString() const final {
222 return "ScheduleError: The outer loop is not the parent of the inner loop";
223 }
224
225 String DetailRenderTemplate() const final {
226 return "The loops can't be fused because the outer loop {0} is not the parent of the inner "
227 "loop {1}";
228 }
229
230 IRModule mod() const final { return mod_; }
231 Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }
232
233 IRModule mod_;
234 For outer_;
235 For inner_;
236};
237
238class NotOnlyChildError : public ScheduleError {
239 public:
240 explicit NotOnlyChildError(IRModule mod, For outer, For inner)
241 : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}
242
243 String FastErrorString() const final {
244 return "ScheduleError: The inner loop is not the only child of outer loop";
245 }
246
247 String DetailRenderTemplate() const final {
248 return "The loops can't be fused because the inner loop {1} is not the only child of outer "
249 "loop {0}.";
250 }
251
252 IRModule mod() const final { return mod_; }
253 Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }
254
255 IRModule mod_;
256 For outer_;
257 For inner_;
258};
259
260class NotSingleInferFactorError : public ScheduleError {
261 public:
262 explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
263
264 String FastErrorString() const final {
265 return "ScheduleError: only one factor can be specified as -1 or none";
266 }
267
268 String DetailRenderTemplate() const final {
269 return "Only one factor can be specified as -1 or none";
270 }
271
272 IRModule mod() const final { return mod_; }
273 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
274
275 IRModule mod_;
276};
277
278class WrongFactorProductError : public ScheduleError {
279 public:
280 explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
281
282 String FastErrorString() const final {
283 return "ScheduleError: The product of factors is not larger than or equal to the extent of "
284 "loop";
285 }
286
287 String DetailRenderTemplate() const final {
288 return "The product of factors is not larger than or equal to the extent of loop {0}";
289 }
290
291 IRModule mod() const final { return mod_; }
292 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
293
294 IRModule mod_;
295 For loop_;
296};
297
298class LoopMultiAppearanceError : public ScheduleError {
299 public:
300 explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
301
302 String FastErrorString() const final {
303 return "ScheduleError: Some loop appears in the input array for multiple times.";
304 }
305
306 String DetailRenderTemplate() const final {
307 return "Loop {0} appears in the input array for multiple times.";
308 }
309
310 IRModule mod() const final { return mod_; }
311 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
312
313 IRModule mod_;
314 For loop_;
315};
316
317class LoopsNotAChainError : public ScheduleError {
318 public:
319 enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt };
320
321 explicit LoopsNotAChainError(IRModule mod, Optional<Stmt> problematic_loop, ProblemKind kind)
322 : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {}
323
324 String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; }
325
326 String DetailRenderTemplate() const final {
327 std::stringstream ss;
328 ss << "The loops are not in a chain because";
329 if (kind_ == ProblemKind::kNotUnderAScope) {
330 ss << " they are not under the same scope.";
331 } else {
332 ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}";
333 }
334 return ss.str();
335 }
336
337 IRModule mod() const final { return mod_; }
338 Array<ObjectRef> LocationsOfInterest() const final {
339 if (kind_ == ProblemKind::kNotUnderAScope) {
340 return {};
341 } else {
342 ICHECK(problematic_loop_.defined());
343 return {problematic_loop_.value()};
344 }
345 }
346
347 IRModule mod_;
348 Optional<Stmt> problematic_loop_;
349 ProblemKind kind_;
350};
351
352class DependentLoopError : public ScheduleError {
353 public:
354 enum class PrimitiveKind { kFuse, kReorder };
355 explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind)
356 : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {}
357
358 String FastErrorString() const final {
359 if (kind_ == PrimitiveKind::kReorder) {
360 return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop "
361 "in the new order";
362 } else {
363 return "ScheduleError: A loop's `extent` is dependent on another loop";
364 }
365 }
366
367 String DetailRenderTemplate() const final {
368 if (kind_ == PrimitiveKind::kReorder) {
369 return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ +
370 " in the new order";
371 } else {
372 return "A loop {0}'s `extent` is dependent on another loop " + inner_var_;
373 }
374 }
375
376 IRModule mod() const final { return mod_; }
377 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
378
379 IRModule mod_;
380 For loop_;
381 String inner_var_;
382 PrimitiveKind kind_;
383};
384
385Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors,
386 bool preserve_unit_iters) {
387 // Invariance
388 // - The total repeat number has not changed for each direct child block with updating predicate.
389 // - The execution order has not changed. (The block executes with the same args and the same
390 // order with before.
391 // Step 1. Check correctness
392 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
393 if (!loop->annotations.empty() || loop->thread_binding.defined()) {
394 throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
395 }
396 // Currently, loops not starting with 0 are not supported
397 arith::Analyzer analyzer;
398 CheckLoopStartsWithZero(self, loop_sref, &analyzer);
399
400 // Find the most common dtype
401 DataType dtype;
402 {
403 int bits = loop->loop_var.dtype().bits();
404 for (const PrimExpr& factor : factors) {
405 bits = std::max(bits, factor.dtype().bits());
406 }
407 dtype = DataType::Int(bits);
408 }
409 int n = factors.size();
410 PrimExpr substitute_value = make_const(dtype, 0);
411 std::vector<Var> new_loop_vars;
412 new_loop_vars.reserve(n);
413 for (int i = 0; i < n; i++) {
414 const PrimExpr& factor = factors[i];
415 Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)).copy_with_dtype(dtype);
416 substitute_value = substitute_value * factor + var;
417 analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor)));
418 new_loop_vars.emplace_back(std::move(var));
419 }
420 Map<Block, Block> opaque_block_reuse;
421 Stmt new_stmt = loop->body;
422 new_stmt = SubstituteVarAndCollectOpaqueBlock(
423 [&](const Var& v) -> Optional<PrimExpr> {
424 if (v.same_as(loop->loop_var)) {
425 return substitute_value;
426 } else {
427 return NullOpt;
428 }
429 },
430 &opaque_block_reuse)(std::move(new_stmt));
431 // Step 3. Update predicate to guard the loop
432 PrimExpr predicate = substitute_value < loop->extent;
433 if (!analyzer.CanProve(predicate)) {
434 new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt));
435 }
436 // Step 4. Generate nested loops to replace the original loop and simplify the binding
437 for (int i = n - 1; i >= 0; i--) {
438 new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
439 }
440 new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
441 opaque_block_reuse.CopyOnWrite(),
442 preserve_unit_iters);
443 self->Replace(loop_sref, new_stmt, opaque_block_reuse);
444 Array<StmtSRef> result_srefs;
445 result_srefs.reserve(n);
446 for (int i = 0; i < n; i++) {
447 result_srefs.push_back(self->stmt2ref.at(new_stmt.get()));
448 const ForNode* outer_loop = TVM_TYPE_AS(new_stmt, ForNode);
449 new_stmt = outer_loop->body;
450 }
451 return result_srefs;
452}
453
454StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
455 // Invariance
456 // - The total repeat number has not changed for each direct child block.
457 // - The execution order has not changed. (The block executes with the same
458 // args and the same order with before.)
459 std::vector<const ForNode*> loops;
460 loops.reserve(loop_srefs.size());
461 StmtSRef outer_loop_sref{nullptr};
462 const ForNode* outer_loop = nullptr;
463 arith::Analyzer analyzer;
464 std::unordered_set<const VarNode*> outer_loop_vars;
465 // Step 1. check correctness
466 for (const StmtSRef& sref : loop_srefs) {
467 const ForNode* loop = TVM_SREF_TO_FOR(sref);
468 if (!loop->annotations.empty() || loop->thread_binding.defined()) {
469 throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
470 }
471 if (outer_loop_sref.defined()) {
472 if (sref->parent != outer_loop_sref.get()) {
473 throw OuterNotInnerParent(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
474 }
475 if (!outer_loop->body.same_as(GetRef<For>(loop))) {
476 throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
477 }
478 }
479 outer_loop_sref = sref;
480 outer_loop = loop;
481 CheckLoopStartsWithZero(self, sref, &analyzer);
482 const VarNode* used_var = nullptr;
483 auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) {
484 if (outer_loop_vars.count(var)) {
485 used_var = var;
486 return true;
487 }
488 return false;
489 };
490 if (UsesVar(loop->extent, f_contain)) {
491 throw DependentLoopError(self->mod, GetRef<For>(loop), used_var->name_hint,
492 DependentLoopError::PrimitiveKind::kFuse);
493 }
494 outer_loop_vars.insert(loop->loop_var.get());
495 loops.push_back(loop);
496 }
497 // Step 2. Create fused loop var and replace the original loop vars
498 std::string suffix;
499 int n = loops.size();
500 int bits = loops[0]->loop_var.dtype().bits();
501 for (int i = 1; i < n; i++) {
502 suffix += "_" + loops[i]->loop_var->name_hint;
503 bits = std::max(bits, loops[i]->loop_var.dtype().bits());
504 }
505 suffix += "_fused";
506 Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
507 Array<PrimExpr> substitute_value;
508 substitute_value.resize(loops.size());
509 PrimExpr lower = 1;
510 for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
511 substitute_value.Set(i, is_one(loops[i]->extent)
512 ? 0
513 : floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
514 lower = lower * loops[i]->extent;
515 }
516 substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
517 Stmt new_stmt = loops.back()->body;
518 Map<Block, Block> opaque_block_reuse;
519 auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
520 for (int i = 0; i < n; i++) {
521 if (v.same_as(loops[i]->loop_var)) {
522 return substitute_value[i];
523 }
524 }
525 return NullOpt;
526 };
527 new_stmt =
528 SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
529 // Step 3. Generate a loop to replace the original loops
530 PrimExpr fused_extent = 1;
531 for (int i = 0; i < n; i++) {
532 fused_extent *= loops[i]->extent;
533 }
534 fused_extent = analyzer.Simplify(fused_extent);
535 new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
536 new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
537 std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(),
538 preserve_unit_iters);
539 self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
540 return self->stmt2ref.at(new_stmt.get());
541}
542
543/*!
544 * \brief Collect an array of loop srefs into a set
545 * \param self The schedule state
546 * \param ordered_loop_srefs The array of loop srefs
547 * \return A set containing all loops in the array
548 * \throws ScheduleError If there are duplicate loops in the array
549 */
550std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
551 const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
552 std::unordered_set<const StmtSRefNode*> loop_srefs;
553 loop_srefs.reserve(ordered_loop_srefs.size());
554 for (const StmtSRef& loop_sref : ordered_loop_srefs) {
555 auto inserted = loop_srefs.insert(loop_sref.get());
556 if (!inserted.second) {
557 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
558 throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
559 }
560 }
561 return loop_srefs;
562}
563
564/*!
565 * \brief Get the top and bottom boundary of reorder range (which should be a chain)
566 * \param self The schedule state
567 * \param loop_srefs The set containing the srefs to the loops to be reordered
568 * \return A pair containing the top and bottom boundary of the reorder range
569 * \throws ScheduleError If the loops to be reordered is not in a chain
570 */
571std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
572 const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
573 const StmtSRefNode* top = nullptr;
574 const StmtSRefNode* bottom = *loop_srefs.begin();
575 std::unordered_set<const StmtSRefNode*> visited;
576 bool scope_block_visited = false;
577 bool first_traversal = true;
578 for (const StmtSRefNode* loop_sref : loop_srefs) {
579 if (visited.count(loop_sref)) {
580 continue;
581 }
582 for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
583 // Case 1. If `v` corresponds to a block, stop traversal.
584 if (v->stmt->IsInstance<BlockNode>()) {
585 if (scope_block_visited) {
586 throw LoopsNotAChainError(self->mod, NullOpt,
587 LoopsNotAChainError::ProblemKind::kNotUnderAScope);
588 }
589 scope_block_visited = true;
590 break;
591 }
592 // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update
593 // `bottom`.
594 if (visited.count(v)) {
595 if (v != bottom) {
596 throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
597 LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
598 }
599 bottom = loop_sref;
600 break;
601 }
602 // Case 3. Add `v` into `visited`
603 visited.insert(v);
604 // If it's the first traversal and the loop corresponding to `v` is in the input array,
605 // update `top`.
606 if (first_traversal && loop_srefs.count(v)) {
607 top = v;
608 }
609 }
610 first_traversal = false;
611 }
612 return std::make_pair(top, bottom);
613}
614
615/*!
616 * \brief Get all the loops in the reorder range
617 * \param self The schedule state
618 * \param top The top boundary of the reorder range
619 * \param bottom The bottom boundary of the reorder range
620 * \return An array containing all the loops in the reorder range
621 * \throws ScheduleError If some loop in the reorder range is not single-branch
622 */
623std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& self,
624 const StmtSRefNode* top,
625 const StmtSRefNode* bottom) {
626 std::vector<const StmtSRefNode*> chain;
627 for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) {
628 const StmtSRefNode* parent_loop_sref = loop_sref->parent;
629 const ForNode* outer = parent_loop_sref->StmtAs<ForNode>();
630 const ForNode* inner = loop_sref->StmtAs<ForNode>();
631 ICHECK(outer != nullptr && inner != nullptr);
632 if (outer->body.get() != inner) {
633 throw LoopsNotAChainError(self->mod, GetRef<For>(outer),
634 LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
635 }
636 chain.push_back(loop_sref);
637 loop_sref = parent_loop_sref;
638 }
639 chain.push_back(top);
640 return chain;
641}
642
643/*!
644 * \brief Construct a loop chain in the new order
645 * \param self The schedule state
646 * \param chain The loops in the reorder range
647 * \param ordered_loop_srefs The loop srefs to be reordered
648 * \param loop_srefs The set containing loop srefs to be reordered
649 * \return The new loop chain
650 * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after
651 * reordering
652 */
653For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefNode*> chain,
654 const Array<StmtSRef>& ordered_loop_srefs,
655 const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
656 std::unordered_set<const VarNode*> inner_vars;
657 inner_vars.reserve(chain.size());
658 For new_loop{nullptr};
659 int index = static_cast<int>(ordered_loop_srefs.size()) - 1;
660 for (const StmtSRefNode* loop_sref : chain) {
661 const ForNode* copy = nullptr;
662 if (loop_srefs.count(loop_sref)) {
663 copy = ordered_loop_srefs[index]->StmtAs<ForNode>();
664 --index;
665 } else {
666 copy = loop_sref->StmtAs<ForNode>();
667 }
668 ICHECK(copy != nullptr);
669 ObjectPtr<ForNode> n = make_object<ForNode>(*copy);
670 if (new_loop.defined()) {
671 n->body = new_loop;
672 } else {
673 n->body = loop_sref->StmtAs<ForNode>()->body;
674 }
675 const VarNode* used_var = nullptr;
676 auto f_contain = [&inner_vars, &used_var](const VarNode* var) {
677 if (inner_vars.count(var)) {
678 used_var = var;
679 return true;
680 }
681 return false;
682 };
683 if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
684 throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint,
685 DependentLoopError::PrimitiveKind::kReorder);
686 }
687 inner_vars.insert(copy->loop_var.get());
688 new_loop = For(std::move(n));
689 }
690 return new_loop;
691}
692
693void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
694 if (ordered_loop_srefs.size() <= 1) {
695 return;
696 }
697 // Step 1. Check uniqueness and collect the input loop srefs into a set
698 std::unordered_set<const StmtSRefNode*> loop_srefs =
699 CollectLoopsIntoSet(self, ordered_loop_srefs);
700 // Step 2. Gather loops to be reordered
701 // For each loop sref in the input sref array, traverse upwards along its parent pointer in the
702 // sref tree, and stop on either a block, or a previously-visited loop
703 // - the top of the reorder range is the last loop visited in the first traversal which exists in
704 // the input array
705 // - the bottom of the reorder range is the last loop in the input array which is not visited in
706 // the previous traversals
707 auto [top, bottom] = GetBoundaryOfReorderRange(self, loop_srefs);
708 // Step 3. Collect all loops in the chain and check the loops are single-branch
709 std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
710 // Step 4. Check the block below has all its block_var to be data-parallel or reduction,
711 // and the block has an affine binding wrt top of the loop range.
712 BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom);
713 // Step 5. Replace the original loops with the reordered loops and check that outer loop is
714 // not dependent on inner loop
715 For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
716 self->Replace(GetRef<StmtSRef>(top), new_loop, {});
717}
718
719StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
720 if (sref->stmt->IsInstance<ForNode>()) {
721 For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
722 self->Replace(sref, new_loop, {});
723 return self->stmt2ref.at(new_loop.get());
724 }
725 class NewLoopCreator : public StmtMutator {
726 public:
727 explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {}
728
729 Stmt VisitStmt_(const BlockRealizeNode* realize) final {
730 if (realize->block.get() == src_block_) {
731 new_loop_ =
732 For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<BlockRealize>(realize));
733 return new_loop_;
734 }
735 return StmtMutator::VisitStmt_(realize);
736 }
737
738 const StmtNode* src_block_;
739 For new_loop_{nullptr};
740 };
741
742 CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";
743 StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
744 NewLoopCreator creator(sref->stmt);
745 Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt));
746 if (new_stmt->IsInstance<ForNode>()) {
747 self->Replace(parent_sref, std::move(new_stmt), {});
748 } else {
749 Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
750 Block new_parent_block = Downcast<Block>(new_stmt);
751 self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}});
752 }
753 return self->stmt2ref.at(creator.new_loop_.get());
754}
755
756/******** InstructionKind Registration ********/
757
758struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
759 static constexpr const char* kName = "Split";
760 static constexpr bool kIsPure = false;
761
762 private:
763 static constexpr size_t kNumInputs = 2;
764 static constexpr size_t kNumAttrs = 1;
765 static constexpr size_t kNumDecisions = 0;
766
767 template <size_t delta>
768 static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
769 const Array<ObjectRef>& inputs) {
770 thread_local ObjectRef loop_rv{nullptr};
771 thread_local Array<ObjectRef> factors{nullptr};
772 loop_rv = inputs[0];
773 factors = Array<ObjectRef>{inputs.begin() + 1, inputs.end()};
774 setter(delta, loop_rv);
775 setter(delta + 1, factors);
776 }
777
778 static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv,
779 Array<Optional<ExprRV>> factors,
780 Bool preserve_unit_iters) {
781 return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool());
782 }
783
784 static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors,
785 Bool preserve_unit_iters) {
786 PythonAPICall py("split");
787 py.Input("loop", loop_rv);
788 py.Input("factors", factors);
789 py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
790 py.OutputList(outputs);
791 return py.Str();
792 }
793
794 template <typename>
795 friend struct ::tvm::tir::UnpackedInstTraits;
796};
797
798struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
799 static constexpr const char* kName = "Fuse";
800 static constexpr bool kIsPure = false;
801
802 private:
803 static constexpr size_t kNumInputs = 1;
804 static constexpr size_t kNumAttrs = 1;
805 static constexpr size_t kNumDecisions = 0;
806
807 template <size_t delta>
808 static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
809 const Array<ObjectRef>& inputs) {
810 setter(delta, inputs);
811 }
812
813 static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs,
814 Bool preserve_unit_iters) {
815 return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool());
816 }
817
818 static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs,
819 Bool preserve_unit_iters) {
820 PythonAPICall py("fuse");
821 for (const String& loop_rv : loop_rvs) {
822 py.Input("", loop_rv);
823 }
824 py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
825 py.SingleOutput(outputs);
826 return py.Str();
827 }
828
829 template <typename>
830 friend struct ::tvm::tir::UnpackedInstTraits;
831};
832
833struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
834 static constexpr const char* kName = "Reorder";
835 static constexpr bool kIsPure = false;
836
837 private:
838 static constexpr size_t kNumInputs = 1;
839 static constexpr size_t kNumAttrs = 0;
840 static constexpr size_t kNumDecisions = 0;
841
842 template <size_t delta>
843 static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
844 const Array<ObjectRef>& inputs) {
845 setter(delta, inputs);
846 }
847
848 static void UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
849 return sch->Reorder(loop_rvs);
850 }
851
852 static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
853 PythonAPICall py("reorder");
854 for (const String& loop_rv : loop_rvs) {
855 py.Input("", loop_rv);
856 }
857 return py.Str();
858 }
859
860 template <typename>
861 friend struct ::tvm::tir::UnpackedInstTraits;
862};
863
864struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
865 static constexpr const char* kName = "AddUnitLoop";
866 static constexpr bool kIsPure = false;
867
868 private:
869 static constexpr size_t kNumInputs = 1;
870 static constexpr size_t kNumAttrs = 0;
871 static constexpr size_t kNumDecisions = 0;
872
873 static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) {
874 if (const auto* block = rv.as<BlockRVNode>()) {
875 return sch->AddUnitLoop(GetRef<BlockRV>(block));
876 } else if (const auto* loop = rv.as<LoopRVNode>()) {
877 return sch->AddUnitLoop(GetRef<LoopRV>(loop));
878 } else {
879 LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block";
880 throw;
881 }
882 }
883
884 static String UnpackedAsPython(Array<String> outputs, String rv) {
885 PythonAPICall py("add_unit_loop");
886 py.Input("block_or_loop", rv);
887 py.SingleOutput(outputs);
888 return py.Str();
889 }
890
891 template <typename>
892 friend struct ::tvm::tir::UnpackedInstTraits;
893};
894
// One registration per traits struct defined in this file.
// NOTE(review): TVM_REGISTER_INST_KIND_TRAITS is defined elsewhere; presumably it registers an
// instruction kind under each struct's `kName` -- confirm against the macro definition.
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
899
900} // namespace tir
901} // namespace tvm
902