1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file loop_partition.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/arith/bound.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/transform.h>
31
32#include <optional>
33#include <unordered_map>
34#include <unordered_set>
35
36#include "../../arith/interval_set.h"
37#include "../../runtime/thread_storage_scope.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
43struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> {
44 bool partition_const_loop;
45 bool no_unroll_loop_with_extent_one;
46 bool unroll_loop_with_partition_hint_no_interval;
47
48 TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") {
49 TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false);
50 TVM_ATTR_FIELD(no_unroll_loop_with_extent_one)
51 .describe("Don't unroll loops with extent 1")
52 .set_default(false);
53 TVM_ATTR_FIELD(unroll_loop_with_partition_hint_no_interval)
54 .describe("Unroll loops with pragma_loop_partition_hint and no interval")
55 .set_default(false);
56 }
57};
58
/*!
 * \brief Reference class for LoopPartitionConfigNode.
 *
 * This is the type registered for the "tir.LoopPartition" pass-config option and
 * the type requested from the PassContext by the pass.
 */
class LoopPartitionConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopPartitionConfig, Attrs, LoopPartitionConfigNode);
};
63
// Register the config node type and expose it as the "tir.LoopPartition" pass-config option.
TVM_REGISTER_NODE_TYPE(LoopPartitionConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.LoopPartition", LoopPartitionConfig);
66
67using arith::DeduceBound;
68using arith::Intersect;
69using arith::IntSet;
70
71using PartitionKey = std::pair<PrimExpr, bool>;
72struct PartitionKeyHash {
73 std::size_t operator()(PartitionKey const& k) const noexcept {
74 std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces)
75 std::size_t h2 = std::hash<bool>{}(k.second);
76 return h1 ^ h2;
77 }
78};
79
80struct PartitionKeyEqual {
81 bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
82 // NOLINTNEXTLINE(whitespace/braces)
83 return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first);
84 }
85};
86
// Each mapping (cond, cond_value) -> interval records that the condition `cond` is
// proven to have the value `cond_value` (true or false) inside `interval`.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;

// Set of condition expressions, hashed and compared with ObjectPtrHash / ObjectPtrEqual.
using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
92
93// Select potential candidate IRs that can be partitioned.
94// Rule:
95// - the range should not be const
96// - there exist a condition expression in the scope that use the var
97class CandidateSelector final : public StmtExprVisitor {
98 public:
99 using VarIsUsed = bool;
100 explicit CandidateSelector(bool partition_const_loop)
101 : partition_const_loop_(partition_const_loop) {}
102
103 void VisitStmt_(const ForNode* op) final {
104 // partition const loop when sets partition_const_loop_
105 if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
106 // always treat var with hint to be partitioned
107 const VarNode* var = op->loop_var.get();
108 if (partition_hint_vars.count(var)) {
109 candidates.insert(GetRef<Stmt>(op));
110 StmtExprVisitor::VisitStmt_(op);
111 return;
112 }
113 record_.insert({var, false});
114 StmtExprVisitor::VisitStmt_(op);
115 if (record_.at(var) && !no_split_) {
116 candidates.insert(GetRef<Stmt>(op));
117 }
118 record_.erase(var);
119 } else {
120 StmtExprVisitor::VisitStmt_(op);
121 }
122 }
123
124 void VisitStmt_(const AttrStmtNode* op) final {
125 if (op->attr_key == attr::thread_extent) {
126 const IterVarNode* iv = op->node.as<IterVarNode>();
127 ICHECK(iv);
128 Var var = iv->var;
129 runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
130 if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
131 // always treat var with hint to be partitioned
132 if (partition_hint_vars.count(var.get())) {
133 candidates.insert(GetRef<Stmt>(op));
134 StmtExprVisitor::VisitStmt_(op);
135 return;
136 }
137 record_.insert({var.get(), false});
138 StmtExprVisitor::VisitStmt_(op);
139 if (record_.at(var.get()) && !no_split_) {
140 candidates.insert(GetRef<Stmt>(op));
141 }
142 record_.erase(var.get());
143 return;
144 }
145 } else if (op->attr_key == attr::pragma_loop_partition_hint) {
146 if (analyzer_.CanProve(op->value)) {
147 const VarNode* var = nullptr;
148 if (op->node->IsInstance<VarNode>()) {
149 var = op->node.as<VarNode>();
150 } else if (op->node->IsInstance<IterVarNode>()) {
151 var = op->node.as<IterVarNode>()->var.get();
152 }
153 ICHECK(var);
154 partition_hint_vars.insert(var);
155 }
156 }
157 StmtExprVisitor::VisitStmt_(op);
158 }
159
160 void VisitStmt_(const SeqStmtNode* op) final {
161 bool init_no_split = no_split_;
162 for (Stmt stmt : op->seq) {
163 // erase the no split state of before visiting the next one.
164 bool temp = init_no_split;
165 std::swap(temp, no_split_);
166 this->VisitStmt(stmt);
167 // restore the no split flag.
168 no_split_ = no_split_ || temp;
169 }
170 }
171
172 void VisitExpr_(const CallNode* op) final {
173 if (op->op.same_as(builtin::likely())) {
174 in_likely_ = true;
175 StmtExprVisitor::VisitExpr_(op);
176 in_likely_ = false;
177 } else if (op->op.same_as(builtin::tvm_thread_allreduce())) {
178 // no split if the body contains allreduce.
179 no_split_ = true;
180 return;
181 } else {
182 StmtExprVisitor::VisitExpr_(op);
183 }
184 }
185
186 void VisitExpr_(const VarNode* op) final {
187 if (in_likely_ && record_.count(op)) {
188 record_.at(op) = true;
189 }
190 }
191
192 std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;
193 std::unordered_set<const VarNode*> partition_hint_vars;
194
195 private:
196 bool in_likely_{false};
197 bool no_split_{false};
198 bool partition_const_loop_{false};
199 std::unordered_map<const VarNode*, VarIsUsed> record_;
200 arith::Analyzer analyzer_;
201};
202
// Defines the PartitionFinder visitor for one comparison node type. When the loop
// variable carries a partition hint, the comparison itself is handed to
// DeduceCondition and its operands are not visited; otherwise the default
// traversal is used.
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
  void VisitExpr_(const OpNodeT* op) final {          \
    if (!has_partition_hint_) {                       \
      StmtExprVisitor::VisitExpr_(op);                \
      return;                                         \
    }                                                 \
    DeduceCondition(GetRef<PrimExpr>(op));            \
  }
212
213// Populate partitions data structure, i.e., for a specific variable,
214// find an interval in which each condition has fixed true or false value
215class PartitionFinder : public StmtExprVisitor {
216 public:
217 explicit PartitionFinder(Var current_var,
218 const std::unordered_map<const VarNode*, IntSet>& hint_map,
219 const std::unordered_map<const VarNode*, IntSet>& relax_map,
220 bool has_partition_hint)
221 : current_var_(current_var),
222 has_partition_hint_(has_partition_hint),
223 hint_map_(hint_map),
224 relax_map_(relax_map) {
225 for (const auto& kv : hint_map) {
226 out_vars_.insert(kv.first);
227 }
228 for (const auto& kv : relax_map) {
229 out_vars_.insert(kv.first);
230 }
231 }
232
233 void VisitStmt_(const ForNode* op) final {
234 auto f_vset_contains = [this](const VarNode* var) { return out_vars_.count(var); };
235 if (UsesVar(op->min, f_vset_contains) || UsesVar(op->extent, f_vset_contains)) return;
236
237 const VarNode* var = op->loop_var.get();
238 hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
239 relax_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
240 StmtExprVisitor::VisitStmt_(op);
241 relax_map_.erase(var);
242 hint_map_.erase(var);
243 }
244
245 void VisitStmt_(const AttrStmtNode* op) final {
246 // handle thread_axis
247 if (op->attr_key == attr::thread_extent) {
248 const IterVarNode* thread_axis = op->node.as<IterVarNode>();
249 ICHECK(thread_axis);
250 const VarNode* var = thread_axis->var.get();
251 IntSet dom = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
252 hint_map_.insert({var, dom});
253 relax_map_.insert({var, dom});
254 StmtExprVisitor::VisitStmt_(op);
255 relax_map_.erase(var);
256 hint_map_.erase(var);
257 } else {
258 StmtExprVisitor::VisitStmt_(op);
259 }
260 }
261
262 void VisitExpr_(const CallNode* op) final {
263 if (op->op.same_as(builtin::likely())) {
264 DeduceCondition(op->args[0]);
265 } else {
266 StmtExprVisitor::VisitExpr_(op);
267 }
268 }
269
270 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode);
271 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode);
272 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode);
273 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode);
274 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode);
275 DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode);
276
277 Partition partitions;
278
279 private:
280 void DeduceCondition(const PrimExpr& cond) {
281 // For cond, find out the interval, if exists, in which we can prove that cond is
282 // true. Also find the interval, if exists, in which we can prove that cond is
283 // false.
284 if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) {
285 IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
286 if (!interval.IsNothing()) {
287 // cond is true within interval
288 partitions[{cond, true}] = interval;
289 }
290 PrimExpr inverse_cond = InverseCond(cond);
291 if (inverse_cond.defined()) {
292 IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
293 if (!interval.IsNothing()) {
294 // cond is false within interval
295 partitions[{cond, false}] = interval;
296 }
297 }
298 }
299 }
300
301 PrimExpr InverseCond(const PrimExpr& cond) {
302 PrimExpr inverse_cond;
303 if (const LTNode* op = cond.as<LTNode>()) {
304 // a < b -> a >= b
305 inverse_cond = GE(op->a, op->b);
306 } else if (const GTNode* op = cond.as<GTNode>()) {
307 // a > b -> a <= b
308 inverse_cond = LE(op->a, op->b);
309 } else if (const LENode* op = cond.as<LENode>()) {
310 // a <= b -> a > b
311 inverse_cond = GT(op->a, op->b);
312 } else if (const GENode* op = cond.as<GENode>()) {
313 // a >= b -> a < b
314 inverse_cond = LT(op->a, op->b);
315 } else if (const EQNode* op = cond.as<EQNode>()) {
316 // a == b -> a != b
317 inverse_cond = NE(op->a, op->b);
318 // a != b -> a == b
319 } else if (const NENode* op = cond.as<NENode>()) {
320 inverse_cond = EQ(op->a, op->b);
321 }
322 return inverse_cond;
323 }
324
325 Var current_var_;
326 bool has_partition_hint_;
327 std::unordered_set<const VarNode*> out_vars_;
328 std::unordered_map<const VarNode*, IntSet> hint_map_;
329 std::unordered_map<const VarNode*, IntSet> relax_map_;
330};
331
332// Replace the set of conditions given by ps with cond_value (true or false)
333class ConditionEliminator : public StmtExprMutator {
334 public:
335 explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
336 : ps_(ps), cond_value_(cond_value) {}
337
338 PrimExpr VisitExpr(const PrimExpr& e) final {
339 if (ps_.find(e) != ps_.end()) {
340 return VisitExpr(cond_value_ ? const_true() : const_false());
341 }
342 return StmtExprMutator::VisitExpr(e);
343 }
344
345 private:
346 ExpressionSet ps_;
347 bool cond_value_;
348};
349
350// Insert the partition branch at the innermost thread scope
351class ThreadPartitionInserter : public StmtMutator {
352 public:
353 explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
354 : ps_(ps), cond_(cond), innermost_thread_scope_(false) {}
355
356 Stmt VisitStmt_(const AttrStmtNode* op) final {
357 if (op->attr_key == attr::thread_extent) {
358 innermost_thread_scope_ = true;
359 Stmt stmt = StmtMutator::VisitStmt_(op);
360 // add branch code inside the innermost thread scope
361 if (innermost_thread_scope_) {
362 Stmt simplified_body = ConditionEliminator(ps_)(op->body);
363 Stmt body = IfThenElse(cond_, simplified_body, op->body);
364 PrimExpr value = this->VisitExpr(op->value);
365 stmt = AttrStmt(op->node, op->attr_key, value, body);
366 }
367 innermost_thread_scope_ = false;
368 return stmt;
369 } else {
370 return StmtMutator::VisitStmt_(op);
371 }
372 }
373
374 private:
375 const ExpressionSet& ps_;
376 PrimExpr cond_;
377 bool innermost_thread_scope_;
378};
379
380// Try to partition range of iteration variables in order to remove (some)
381// likely conditions
382class LoopPartitioner : public StmtMutator {
383 public:
384 explicit LoopPartitioner(bool partition_const_loop, bool no_unroll_loop_with_extent_one,
385 bool unroll_loop_with_partition_hint_no_interval)
386 : selector(CandidateSelector(partition_const_loop)),
387 no_unroll_loop_with_extent_one_(no_unroll_loop_with_extent_one),
388 unroll_loop_with_partition_hint_no_interval_(unroll_loop_with_partition_hint_no_interval) {}
389
390 Stmt VisitAndMutate(Stmt stmt) {
391 selector(stmt);
392 return operator()(std::move(stmt));
393 }
394
395 Stmt VisitStmt_(const ForNode* op) final {
396 analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true);
397 auto fs = GetRef<Stmt>(op);
398 if (selector.candidates.count(fs)) {
399 Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
400 if (s.defined()) return s;
401 }
402
403 // normal path when loop partition fails
404 // normal loop variable can be put into hint map.
405 hint_map_.insert({op->loop_var.get(), IntSet::Interval(op->min, op->min + op->extent - 1)});
406 Stmt res = StmtMutator::VisitStmt_(op);
407 hint_map_.erase(op->loop_var.get());
408 return res;
409 }
410
411 Stmt VisitStmt_(const AttrStmtNode* op) final {
412 if (op->attr_key != attr::thread_extent) {
413 return StmtMutator::VisitStmt_(op);
414 }
415
416 const IterVarNode* iv = op->node.as<IterVarNode>();
417 ICHECK(iv);
418 Var var = iv->var;
419 auto as = GetRef<Stmt>(op);
420 if (selector.candidates.count(as)) {
421 Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
422 if (s.defined()) return s;
423 }
424
425 // normal path when loop parittion fails.
426 runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
427 Stmt res;
428 if (scope.rank == 1) {
429 // threadIdx should be put into relax map, in case of divergence.
430 relax_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
431 res = StmtMutator::VisitStmt_(op);
432 relax_map_.erase(var.get());
433 } else {
434 hint_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
435 res = StmtMutator::VisitStmt_(op);
436 hint_map_.erase(var.get());
437 }
438 return res;
439 }
440
441 private:
442 Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
443 bool partition_thread_scope);
444
445 std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
446 const arith::IntervalSet& for_interval,
447 bool cond_value, bool has_partition_hint);
448
449 inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);
450
451 /* Candidate IRs that may be partitioned potentially */
452 std::unordered_map<const VarNode*, IntSet> hint_map_;
453 std::unordered_map<const VarNode*, IntSet> relax_map_;
454 arith::Analyzer analyzer_;
455 CandidateSelector selector;
456 bool no_unroll_loop_with_extent_one_;
457 bool unroll_loop_with_partition_hint_no_interval_;
458};
459
460// Returns an interval (in the first component) in which all the conditions
461// given in the second component provably have value given by cond_value
462std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
463 const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value,
464 bool has_partition_hint) {
465 Array<IntSet> sets;
466 ExpressionSet cond_set;
467
468 for (const auto& kv : partitions) {
469 if (kv.first.second == cond_value) {
470 arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
471 arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);
472 if (!intersection->IsEmpty()) {
473 sets.push_back(kv.second);
474 cond_set.insert(kv.first.first);
475 }
476 }
477 }
478 IntSet interval = sets.empty() ? IntSet::Nothing() : Intersect(sets);
479
480 // Try to find the intersection of the cond_intervals until the intersection
481 // is nothing when has_partition_hint is true.
482 if (interval.IsNothing() && has_partition_hint) {
483 arith::IntervalSet cond_intersection = arith::IntervalSet::Everything();
484 cond_set.clear();
485
486 for (const auto& kv : partitions) {
487 if (kv.first.second == cond_value) {
488 arith::IntervalSet cond_interval = Downcast<arith::IntervalSet>(kv.second);
489 arith::IntervalSet intersection = arith::Intersect(&analyzer_, cond_interval, for_interval);
490 if (!intersection->IsEmpty()) {
491 cond_intersection = arith::Intersect(&analyzer_, cond_intersection, cond_interval);
492 // Return the latest interval and cond_set if the cond_intersection is nothing.
493 if (!cond_intersection->IsEmpty()) {
494 cond_set.insert(kv.first.first);
495 interval = arith::IntervalSet(analyzer_.Simplify(cond_intersection->min_value),
496 analyzer_.Simplify(cond_intersection->max_value));
497 } else {
498 break;
499 }
500 }
501 }
502 }
503 }
504
505 return std::make_pair(interval, cond_set);
506}
507
508/*
509 * Tries to recursively partition the range of the variable (given by var) of
510 * the for loop (given by node and stmt) into a
511 * number of disjoint ranges such that in some ranges one or more predicates
512 * in the loopnest are provably true or false in each range. For example, given the
513 * following loop to partition:
514 * for (i = 0; i < 4; i++)
515 * for (j = 0; j < 10; j++)
516 * if (likely(i*10 + j < 36))
517 * A[10*i+j] = B[10*i+j]
518 *
519 * We first partition range of i, i.e., [0,3] into subranges [0,2] and [3,3] because the
520 * likely condition is always true for the first subrange but not always true for the
521 * second subrange. Therefore, we'll have
522 * for (i = 0; i < 3; i++)
523 * for (j = 0; j < 10; j++)
524 * if (likely(1))
525 * A[10*i+j] = B[10*i+j]
526 * for (i = 0; i < 1; i++)
527 * for (j = 0; j < 10; j++)
528 * if (likely((i+3)*10 + j < 36))
529 * A[10*(i+3)+j] = B[10*(i+3)+j]
530 * Which is simplified as:
531 * for (i = 0; i < 3; i++)
532 * for (j = 0; j < 10; j++)
533 * A[10*i+j] = B[10*i+j]
534 * for (j = 0; j < 10; j++) // loopnest 1
535 * if (likely(j < 6))
536 * A[30+j] = B[30+j]
537 * Now, we recursively partition j in loopnest 1 into subranges [0,5] and [6,9] where the
 * condition is true for the first subrange and not always true for the second subrange.
539 * for (j = 0; j < 6; j++)
540 * if (likely(1))
541 * A[30+j] = B[30+j]
542 * for (j = 0; j < 4; j++) // loop 2
543 * if (likely(j < 0))
544 * A[36+j] = B[36+j]
545 * Finally we recursively partition loop 2 above into subrange [0,3] where the
546 * condition is false and empty interval where the condition is not false,
547 * therefore we generate
548 * for (j = 0; j < 4; j++)
549 * if (likely(0))
550 * A[36+j] = B[36+j]
551 * which will eventually be simplified to empty code. And because only one loop was generated
552 * from loop 2 we stop recursing.
553 */
554Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
555 bool partition_thread_scope) {
556 using namespace arith;
557 // include hint of var.
558 hint_map_.insert({var.get(), IntSet::Interval(min, max)});
559
560 bool has_partition_hint_ = selector.partition_hint_vars.count(var.get());
561 PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_);
562 finder(body);
563
564 hint_map_.erase(var.get());
565 if (finder.partitions.empty()) return Stmt();
566
567 arith::IntervalSet for_interval(min, max);
568
569 auto [middle_interval, cond_set,
570 opt_cond_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> {
571 {
572 // find an interval in which all conditions on var are true
573 auto [middle_interval, cond_set] =
574 GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_);
575 if (!middle_interval.IsNothing()) {
576 return {middle_interval, cond_set, true};
577 }
578 }
579
580 {
581 // if such interval doesn't exist, find an interval in which all
582 // conditions on var are false
583 auto [middle_interval, cond_set] =
584 GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_);
585
586 if (!middle_interval.IsNothing()) {
587 return {middle_interval, cond_set, false};
588 }
589 }
590
591 // we couldn't find an interval in which the conditions are
592 // provably true or false. Therefore, we can't partition the loop
593 // based on those conds
594 return {{}, {}, std::nullopt};
595 }();
596
597 if (!opt_cond_value.has_value()) {
598 if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
599 analyzer_.CanProve(max - min > 0)) {
600 auto new_body = VisitAndMutate(body);
601 return For(var, min, max - min + 1, ForKind::kUnrolled, new_body);
602 }
603 return Stmt();
604 }
605 bool cond_value = opt_cond_value.value();
606
607 IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
608 // middle_interval is the subrange of the loop variable range for which a
609 // set of conditions are true (or false resp.)
610 // The part of the loop variable range that is before (after resp.) that
611 // subrange is prefixed with pre- (post- resp.)
612
613 // Calculating pre-subrange and generating code for it.
614 // pre-subrange = [min, body_begin)
615 PrimExpr body_begin;
616 Stmt pre_stmt;
617 bool pre_stmt_recurse = true;
618 if (middle_interval_i->HasLowerBound()) {
619 body_begin = analyzer_.Simplify(middle_interval.min());
620 if (!analyzer_.CanProve(body_begin == min)) {
621 PrimExpr extent = analyzer_.Simplify(body_begin - min);
622 if (!analyzer_.CanProve(extent > 0)) {
623 body_begin = tvm::max(body_begin, min);
624 // stop recursing on this interval if we can't prove it has non-negative length
625 pre_stmt_recurse = false;
626 }
627 if (!analyzer_.CanProve(extent <= 0)) {
628 if (!partition_thread_scope) {
629 Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
630 pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
631 }
632 }
633 }
634 } else {
635 body_begin = min;
636 }
637
638 // Calculating post-subrange and generating code for it.
639 // post-subrange = [post_doubt_begin, max+1)
640 PrimExpr post_doubt_begin;
641 Stmt post_stmt;
642 bool post_stmt_recurse = true;
643 if (middle_interval_i->HasUpperBound()) {
644 post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
645 if (!analyzer_.CanProve(middle_interval.max() == max)) {
646 // require the extent to be non-negative
647 PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1);
648 if (!analyzer_.CanProve(extent > 0)) {
649 post_doubt_begin = tvm::min(post_doubt_begin, max + 1);
650 // stop recursing on this interval if we can't prove it has non-negative length
651 post_stmt_recurse = false;
652 }
653 if (!analyzer_.CanProve(extent <= 0)) {
654 if (!partition_thread_scope) {
655 Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
656 post_stmt = MakeFor(stmt.get(), extent, post_body);
657 }
658 }
659 }
660 } else {
661 post_doubt_begin = max + 1;
662 }
663
664 Stmt s;
665
666 // Generating code for middle subrange
667 if (!partition_thread_scope) {
668 Stmt mid_stmt;
669 if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
670 // [body_begin, post_doubt_begin)
671 Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
672 Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
673 mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);
674 // Recurse until partitions is empty
675 mid_stmt = VisitAndMutate(mid_stmt);
676 // Recurse for each non-empty subrange only if there are at least
677 // two non-empty subranges
678 if (pre_stmt.defined() || post_stmt.defined()) {
679 if (pre_stmt.defined() && pre_stmt_recurse) {
680 pre_stmt = VisitAndMutate(pre_stmt);
681 }
682 if (post_stmt.defined() && post_stmt_recurse) {
683 post_stmt = VisitAndMutate(post_stmt);
684 }
685 }
686 }
687 s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
688 } else {
689 PrimExpr cond = const_true();
690 if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin);
691 if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin);
692 s = ThreadPartitionInserter(cond_set, cond)(stmt);
693 }
694 s = ConvertSSA(s);
695 return s;
696}
697
698inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
699 const ForNode* for_node = static_cast<const ForNode*>(node);
700 ICHECK(for_node);
701 if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
702 !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
703 // If the loop extent is 1, do not create the loop anymore
704 return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
705 } else {
706 ICHECK(for_node->kind != ForKind::kThreadBinding);
707 return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body,
708 for_node->thread_binding, for_node->annotations);
709 }
710}
711
712class RemoveLikelyTagsAndHints : public StmtExprMutator {
713 public:
714 PrimExpr VisitExpr_(const CallNode* op) final {
715 if (op->op.same_as(builtin::likely())) {
716 ICHECK_EQ(op->args.size(), 1);
717 return StmtExprMutator::VisitExpr(op->args[0]);
718 } else {
719 return StmtExprMutator::VisitExpr_(op);
720 }
721 }
722
723 Stmt VisitStmt_(const AttrStmtNode* op) final {
724 if (op->attr_key == attr::pragma_loop_partition_hint) {
725 return VisitStmt(op->body);
726 }
727 return StmtExprMutator::VisitStmt_(op);
728 }
729};
730
731Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one,
732 bool unroll_loop_with_partition_hint_no_interval) {
733 stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one,
734 unroll_loop_with_partition_hint_no_interval)
735 .VisitAndMutate(std::move(stmt));
736 stmt = RemoveLikelyTagsAndHints()(std::move(stmt));
737 return stmt;
738}
739
740namespace transform {
741
742Pass LoopPartition() {
743 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
744 auto* n = f.CopyOnWrite();
745 auto cfg = ctx->GetConfig<LoopPartitionConfig>("tir.LoopPartition");
746 if (!cfg.defined()) {
747 cfg = AttrsWithDefaultValues<LoopPartitionConfig>();
748 }
749 n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop,
750 cfg.value()->no_unroll_loop_with_extent_one,
751 cfg.value()->unroll_loop_with_partition_hint_no_interval);
752 return f;
753 };
754 return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
755}
756
757TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition);
758
759} // namespace transform
760
761} // namespace tir
762} // namespace tvm
763