1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file loop_partition.cc |
22 | */ |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/arith/bound.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/transform.h> |
31 | |
32 | #include <optional> |
33 | #include <unordered_map> |
34 | #include <unordered_set> |
35 | |
36 | #include "../../arith/interval_set.h" |
37 | #include "../../runtime/thread_storage_scope.h" |
38 | #include "ir_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace tir { |
42 | |
43 | struct LoopPartitionConfigNode : public tvm::AttrsNode<LoopPartitionConfigNode> { |
44 | bool partition_const_loop; |
45 | bool no_unroll_loop_with_extent_one; |
46 | bool unroll_loop_with_partition_hint_no_interval; |
47 | |
48 | TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig" ) { |
49 | TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop" ).set_default(false); |
50 | TVM_ATTR_FIELD(no_unroll_loop_with_extent_one) |
51 | .describe("Don't unroll loops with extent 1" ) |
52 | .set_default(false); |
53 | TVM_ATTR_FIELD(unroll_loop_with_partition_hint_no_interval) |
54 | .describe("Unroll loops with pragma_loop_partition_hint and no interval" ) |
55 | .set_default(false); |
56 | } |
57 | }; |
58 | |
59 | class LoopPartitionConfig : public Attrs { |
60 | public: |
61 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopPartitionConfig, Attrs, LoopPartitionConfigNode); |
62 | }; |
63 | |
64 | TVM_REGISTER_NODE_TYPE(LoopPartitionConfigNode); |
65 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.LoopPartition" , LoopPartitionConfig); |
66 | |
67 | using arith::DeduceBound; |
68 | using arith::Intersect; |
69 | using arith::IntSet; |
70 | |
71 | using PartitionKey = std::pair<PrimExpr, bool>; |
72 | struct PartitionKeyHash { |
73 | std::size_t operator()(PartitionKey const& k) const noexcept { |
74 | std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces) |
75 | std::size_t h2 = std::hash<bool>{}(k.second); |
76 | return h1 ^ h2; |
77 | } |
78 | }; |
79 | |
80 | struct PartitionKeyEqual { |
81 | bool operator()(const PartitionKey& k1, const PartitionKey& k2) const { |
82 | // NOLINTNEXTLINE(whitespace/braces) |
83 | return k1.second == k2.second && ObjectPtrEqual{}(k1.first, k2.first); |
84 | } |
85 | }; |
86 | |
87 | // Each mapping (cond, cond_value) -> interval represents the fact that |
88 | // condition cond is proven to have value cond_value (true or false) in interval. |
89 | using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>; |
90 | |
91 | using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>; |
92 | |
93 | // Select potential candidate IRs that can be partitioned. |
94 | // Rule: |
95 | // - the range should not be const |
96 | // - there exist a condition expression in the scope that use the var |
97 | class CandidateSelector final : public StmtExprVisitor { |
98 | public: |
99 | using VarIsUsed = bool; |
100 | explicit CandidateSelector(bool partition_const_loop) |
101 | : partition_const_loop_(partition_const_loop) {} |
102 | |
103 | void VisitStmt_(const ForNode* op) final { |
104 | // partition const loop when sets partition_const_loop_ |
105 | if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) { |
106 | // always treat var with hint to be partitioned |
107 | const VarNode* var = op->loop_var.get(); |
108 | if (partition_hint_vars.count(var)) { |
109 | candidates.insert(GetRef<Stmt>(op)); |
110 | StmtExprVisitor::VisitStmt_(op); |
111 | return; |
112 | } |
113 | record_.insert({var, false}); |
114 | StmtExprVisitor::VisitStmt_(op); |
115 | if (record_.at(var) && !no_split_) { |
116 | candidates.insert(GetRef<Stmt>(op)); |
117 | } |
118 | record_.erase(var); |
119 | } else { |
120 | StmtExprVisitor::VisitStmt_(op); |
121 | } |
122 | } |
123 | |
124 | void VisitStmt_(const AttrStmtNode* op) final { |
125 | if (op->attr_key == attr::thread_extent) { |
126 | const IterVarNode* iv = op->node.as<IterVarNode>(); |
127 | ICHECK(iv); |
128 | Var var = iv->var; |
129 | runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); |
130 | if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) { |
131 | // always treat var with hint to be partitioned |
132 | if (partition_hint_vars.count(var.get())) { |
133 | candidates.insert(GetRef<Stmt>(op)); |
134 | StmtExprVisitor::VisitStmt_(op); |
135 | return; |
136 | } |
137 | record_.insert({var.get(), false}); |
138 | StmtExprVisitor::VisitStmt_(op); |
139 | if (record_.at(var.get()) && !no_split_) { |
140 | candidates.insert(GetRef<Stmt>(op)); |
141 | } |
142 | record_.erase(var.get()); |
143 | return; |
144 | } |
145 | } else if (op->attr_key == attr::pragma_loop_partition_hint) { |
146 | if (analyzer_.CanProve(op->value)) { |
147 | const VarNode* var = nullptr; |
148 | if (op->node->IsInstance<VarNode>()) { |
149 | var = op->node.as<VarNode>(); |
150 | } else if (op->node->IsInstance<IterVarNode>()) { |
151 | var = op->node.as<IterVarNode>()->var.get(); |
152 | } |
153 | ICHECK(var); |
154 | partition_hint_vars.insert(var); |
155 | } |
156 | } |
157 | StmtExprVisitor::VisitStmt_(op); |
158 | } |
159 | |
160 | void VisitStmt_(const SeqStmtNode* op) final { |
161 | bool init_no_split = no_split_; |
162 | for (Stmt stmt : op->seq) { |
163 | // erase the no split state of before visiting the next one. |
164 | bool temp = init_no_split; |
165 | std::swap(temp, no_split_); |
166 | this->VisitStmt(stmt); |
167 | // restore the no split flag. |
168 | no_split_ = no_split_ || temp; |
169 | } |
170 | } |
171 | |
172 | void VisitExpr_(const CallNode* op) final { |
173 | if (op->op.same_as(builtin::likely())) { |
174 | in_likely_ = true; |
175 | StmtExprVisitor::VisitExpr_(op); |
176 | in_likely_ = false; |
177 | } else if (op->op.same_as(builtin::tvm_thread_allreduce())) { |
178 | // no split if the body contains allreduce. |
179 | no_split_ = true; |
180 | return; |
181 | } else { |
182 | StmtExprVisitor::VisitExpr_(op); |
183 | } |
184 | } |
185 | |
186 | void VisitExpr_(const VarNode* op) final { |
187 | if (in_likely_ && record_.count(op)) { |
188 | record_.at(op) = true; |
189 | } |
190 | } |
191 | |
192 | std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates; |
193 | std::unordered_set<const VarNode*> partition_hint_vars; |
194 | |
195 | private: |
196 | bool in_likely_{false}; |
197 | bool no_split_{false}; |
198 | bool partition_const_loop_{false}; |
199 | std::unordered_map<const VarNode*, VarIsUsed> record_; |
200 | arith::Analyzer analyzer_; |
201 | }; |
202 | |
// Defines a comparison visitor for PartitionFinder. When the current var carries a
// partition hint, every comparison (not only those wrapped in likely()) is handed to
// DeduceCondition and not visited further; otherwise the default traversal is used.
#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \
  void VisitExpr_(const OpNodeT* op) final {          \
    if (!has_partition_hint_) {                       \
      StmtExprVisitor::VisitExpr_(op);                \
      return;                                         \
    }                                                 \
    DeduceCondition(GetRef<PrimExpr>(op));            \
  }
212 | |
213 | // Populate partitions data structure, i.e., for a specific variable, |
214 | // find an interval in which each condition has fixed true or false value |
215 | class PartitionFinder : public StmtExprVisitor { |
216 | public: |
217 | explicit PartitionFinder(Var current_var, |
218 | const std::unordered_map<const VarNode*, IntSet>& hint_map, |
219 | const std::unordered_map<const VarNode*, IntSet>& relax_map, |
220 | bool has_partition_hint) |
221 | : current_var_(current_var), |
222 | has_partition_hint_(has_partition_hint), |
223 | hint_map_(hint_map), |
224 | relax_map_(relax_map) { |
225 | for (const auto& kv : hint_map) { |
226 | out_vars_.insert(kv.first); |
227 | } |
228 | for (const auto& kv : relax_map) { |
229 | out_vars_.insert(kv.first); |
230 | } |
231 | } |
232 | |
233 | void VisitStmt_(const ForNode* op) final { |
234 | auto f_vset_contains = [this](const VarNode* var) { return out_vars_.count(var); }; |
235 | if (UsesVar(op->min, f_vset_contains) || UsesVar(op->extent, f_vset_contains)) return; |
236 | |
237 | const VarNode* var = op->loop_var.get(); |
238 | hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); |
239 | relax_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); |
240 | StmtExprVisitor::VisitStmt_(op); |
241 | relax_map_.erase(var); |
242 | hint_map_.erase(var); |
243 | } |
244 | |
245 | void VisitStmt_(const AttrStmtNode* op) final { |
246 | // handle thread_axis |
247 | if (op->attr_key == attr::thread_extent) { |
248 | const IterVarNode* thread_axis = op->node.as<IterVarNode>(); |
249 | ICHECK(thread_axis); |
250 | const VarNode* var = thread_axis->var.get(); |
251 | IntSet dom = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value)); |
252 | hint_map_.insert({var, dom}); |
253 | relax_map_.insert({var, dom}); |
254 | StmtExprVisitor::VisitStmt_(op); |
255 | relax_map_.erase(var); |
256 | hint_map_.erase(var); |
257 | } else { |
258 | StmtExprVisitor::VisitStmt_(op); |
259 | } |
260 | } |
261 | |
262 | void VisitExpr_(const CallNode* op) final { |
263 | if (op->op.same_as(builtin::likely())) { |
264 | DeduceCondition(op->args[0]); |
265 | } else { |
266 | StmtExprVisitor::VisitExpr_(op); |
267 | } |
268 | } |
269 | |
270 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode); |
271 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode); |
272 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode); |
273 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode); |
274 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode); |
275 | DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode); |
276 | |
277 | Partition partitions; |
278 | |
279 | private: |
280 | void DeduceCondition(const PrimExpr& cond) { |
281 | // For cond, find out the interval, if exists, in which we can prove that cond is |
282 | // true. Also find the interval, if exists, in which we can prove that cond is |
283 | // false. |
284 | if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { |
285 | IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); |
286 | if (!interval.IsNothing()) { |
287 | // cond is true within interval |
288 | partitions[{cond, true}] = interval; |
289 | } |
290 | PrimExpr inverse_cond = InverseCond(cond); |
291 | if (inverse_cond.defined()) { |
292 | IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); |
293 | if (!interval.IsNothing()) { |
294 | // cond is false within interval |
295 | partitions[{cond, false}] = interval; |
296 | } |
297 | } |
298 | } |
299 | } |
300 | |
301 | PrimExpr InverseCond(const PrimExpr& cond) { |
302 | PrimExpr inverse_cond; |
303 | if (const LTNode* op = cond.as<LTNode>()) { |
304 | // a < b -> a >= b |
305 | inverse_cond = GE(op->a, op->b); |
306 | } else if (const GTNode* op = cond.as<GTNode>()) { |
307 | // a > b -> a <= b |
308 | inverse_cond = LE(op->a, op->b); |
309 | } else if (const LENode* op = cond.as<LENode>()) { |
310 | // a <= b -> a > b |
311 | inverse_cond = GT(op->a, op->b); |
312 | } else if (const GENode* op = cond.as<GENode>()) { |
313 | // a >= b -> a < b |
314 | inverse_cond = LT(op->a, op->b); |
315 | } else if (const EQNode* op = cond.as<EQNode>()) { |
316 | // a == b -> a != b |
317 | inverse_cond = NE(op->a, op->b); |
318 | // a != b -> a == b |
319 | } else if (const NENode* op = cond.as<NENode>()) { |
320 | inverse_cond = EQ(op->a, op->b); |
321 | } |
322 | return inverse_cond; |
323 | } |
324 | |
325 | Var current_var_; |
326 | bool has_partition_hint_; |
327 | std::unordered_set<const VarNode*> out_vars_; |
328 | std::unordered_map<const VarNode*, IntSet> hint_map_; |
329 | std::unordered_map<const VarNode*, IntSet> relax_map_; |
330 | }; |
331 | |
332 | // Replace the set of conditions given by ps with cond_value (true or false) |
333 | class ConditionEliminator : public StmtExprMutator { |
334 | public: |
335 | explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true) |
336 | : ps_(ps), cond_value_(cond_value) {} |
337 | |
338 | PrimExpr VisitExpr(const PrimExpr& e) final { |
339 | if (ps_.find(e) != ps_.end()) { |
340 | return VisitExpr(cond_value_ ? const_true() : const_false()); |
341 | } |
342 | return StmtExprMutator::VisitExpr(e); |
343 | } |
344 | |
345 | private: |
346 | ExpressionSet ps_; |
347 | bool cond_value_; |
348 | }; |
349 | |
350 | // Insert the partition branch at the innermost thread scope |
351 | class ThreadPartitionInserter : public StmtMutator { |
352 | public: |
353 | explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond) |
354 | : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} |
355 | |
356 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
357 | if (op->attr_key == attr::thread_extent) { |
358 | innermost_thread_scope_ = true; |
359 | Stmt stmt = StmtMutator::VisitStmt_(op); |
360 | // add branch code inside the innermost thread scope |
361 | if (innermost_thread_scope_) { |
362 | Stmt simplified_body = ConditionEliminator(ps_)(op->body); |
363 | Stmt body = IfThenElse(cond_, simplified_body, op->body); |
364 | PrimExpr value = this->VisitExpr(op->value); |
365 | stmt = AttrStmt(op->node, op->attr_key, value, body); |
366 | } |
367 | innermost_thread_scope_ = false; |
368 | return stmt; |
369 | } else { |
370 | return StmtMutator::VisitStmt_(op); |
371 | } |
372 | } |
373 | |
374 | private: |
375 | const ExpressionSet& ps_; |
376 | PrimExpr cond_; |
377 | bool innermost_thread_scope_; |
378 | }; |
379 | |
380 | // Try to partition range of iteration variables in order to remove (some) |
381 | // likely conditions |
382 | class LoopPartitioner : public StmtMutator { |
383 | public: |
384 | explicit LoopPartitioner(bool partition_const_loop, bool no_unroll_loop_with_extent_one, |
385 | bool unroll_loop_with_partition_hint_no_interval) |
386 | : selector(CandidateSelector(partition_const_loop)), |
387 | no_unroll_loop_with_extent_one_(no_unroll_loop_with_extent_one), |
388 | unroll_loop_with_partition_hint_no_interval_(unroll_loop_with_partition_hint_no_interval) {} |
389 | |
390 | Stmt VisitAndMutate(Stmt stmt) { |
391 | selector(stmt); |
392 | return operator()(std::move(stmt)); |
393 | } |
394 | |
395 | Stmt VisitStmt_(const ForNode* op) final { |
396 | analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true); |
397 | auto fs = GetRef<Stmt>(op); |
398 | if (selector.candidates.count(fs)) { |
399 | Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); |
400 | if (s.defined()) return s; |
401 | } |
402 | |
403 | // normal path when loop partition fails |
404 | // normal loop variable can be put into hint map. |
405 | hint_map_.insert({op->loop_var.get(), IntSet::Interval(op->min, op->min + op->extent - 1)}); |
406 | Stmt res = StmtMutator::VisitStmt_(op); |
407 | hint_map_.erase(op->loop_var.get()); |
408 | return res; |
409 | } |
410 | |
411 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
412 | if (op->attr_key != attr::thread_extent) { |
413 | return StmtMutator::VisitStmt_(op); |
414 | } |
415 | |
416 | const IterVarNode* iv = op->node.as<IterVarNode>(); |
417 | ICHECK(iv); |
418 | Var var = iv->var; |
419 | auto as = GetRef<Stmt>(op); |
420 | if (selector.candidates.count(as)) { |
421 | Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true); |
422 | if (s.defined()) return s; |
423 | } |
424 | |
425 | // normal path when loop parittion fails. |
426 | runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); |
427 | Stmt res; |
428 | if (scope.rank == 1) { |
429 | // threadIdx should be put into relax map, in case of divergence. |
430 | relax_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)}); |
431 | res = StmtMutator::VisitStmt_(op); |
432 | relax_map_.erase(var.get()); |
433 | } else { |
434 | hint_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)}); |
435 | res = StmtMutator::VisitStmt_(op); |
436 | hint_map_.erase(var.get()); |
437 | } |
438 | return res; |
439 | } |
440 | |
441 | private: |
442 | Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, |
443 | bool partition_thread_scope); |
444 | |
445 | std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions, |
446 | const arith::IntervalSet& for_interval, |
447 | bool cond_value, bool has_partition_hint); |
448 | |
449 | inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); |
450 | |
451 | /* Candidate IRs that may be partitioned potentially */ |
452 | std::unordered_map<const VarNode*, IntSet> hint_map_; |
453 | std::unordered_map<const VarNode*, IntSet> relax_map_; |
454 | arith::Analyzer analyzer_; |
455 | CandidateSelector selector; |
456 | bool no_unroll_loop_with_extent_one_; |
457 | bool unroll_loop_with_partition_hint_no_interval_; |
458 | }; |
459 | |
460 | // Returns an interval (in the first component) in which all the conditions |
461 | // given in the second component provably have value given by cond_value |
462 | std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset( |
463 | const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value, |
464 | bool has_partition_hint) { |
465 | Array<IntSet> sets; |
466 | ExpressionSet cond_set; |
467 | |
468 | for (const auto& kv : partitions) { |
469 | if (kv.first.second == cond_value) { |
470 | arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second); |
471 | arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval); |
472 | if (!intersection->IsEmpty()) { |
473 | sets.push_back(kv.second); |
474 | cond_set.insert(kv.first.first); |
475 | } |
476 | } |
477 | } |
478 | IntSet interval = sets.empty() ? IntSet::Nothing() : Intersect(sets); |
479 | |
480 | // Try to find the intersection of the cond_intervals until the intersection |
481 | // is nothing when has_partition_hint is true. |
482 | if (interval.IsNothing() && has_partition_hint) { |
483 | arith::IntervalSet cond_intersection = arith::IntervalSet::Everything(); |
484 | cond_set.clear(); |
485 | |
486 | for (const auto& kv : partitions) { |
487 | if (kv.first.second == cond_value) { |
488 | arith::IntervalSet cond_interval = Downcast<arith::IntervalSet>(kv.second); |
489 | arith::IntervalSet intersection = arith::Intersect(&analyzer_, cond_interval, for_interval); |
490 | if (!intersection->IsEmpty()) { |
491 | cond_intersection = arith::Intersect(&analyzer_, cond_intersection, cond_interval); |
492 | // Return the latest interval and cond_set if the cond_intersection is nothing. |
493 | if (!cond_intersection->IsEmpty()) { |
494 | cond_set.insert(kv.first.first); |
495 | interval = arith::IntervalSet(analyzer_.Simplify(cond_intersection->min_value), |
496 | analyzer_.Simplify(cond_intersection->max_value)); |
497 | } else { |
498 | break; |
499 | } |
500 | } |
501 | } |
502 | } |
503 | } |
504 | |
505 | return std::make_pair(interval, cond_set); |
506 | } |
507 | |
508 | /* |
509 | * Tries to recursively partition the range of the variable (given by var) of |
510 | * the for loop (given by node and stmt) into a |
511 | * number of disjoint ranges such that in some ranges one or more predicates |
512 | * in the loopnest are provably true or false in each range. For example, given the |
513 | * following loop to partition: |
514 | * for (i = 0; i < 4; i++) |
515 | * for (j = 0; j < 10; j++) |
516 | * if (likely(i*10 + j < 36)) |
517 | * A[10*i+j] = B[10*i+j] |
518 | * |
519 | * We first partition range of i, i.e., [0,3] into subranges [0,2] and [3,3] because the |
520 | * likely condition is always true for the first subrange but not always true for the |
521 | * second subrange. Therefore, we'll have |
522 | * for (i = 0; i < 3; i++) |
523 | * for (j = 0; j < 10; j++) |
524 | * if (likely(1)) |
525 | * A[10*i+j] = B[10*i+j] |
526 | * for (i = 0; i < 1; i++) |
527 | * for (j = 0; j < 10; j++) |
528 | * if (likely((i+3)*10 + j < 36)) |
529 | * A[10*(i+3)+j] = B[10*(i+3)+j] |
530 | * Which is simplified as: |
531 | * for (i = 0; i < 3; i++) |
532 | * for (j = 0; j < 10; j++) |
533 | * A[10*i+j] = B[10*i+j] |
534 | * for (j = 0; j < 10; j++) // loopnest 1 |
535 | * if (likely(j < 6)) |
536 | * A[30+j] = B[30+j] |
537 | * Now, we recursively partition j in loopnest 1 into subranges [0,5] and [6,9] where the |
538 | * condition is true for the first subrange and now always true for the second subrange. |
539 | * for (j = 0; j < 6; j++) |
540 | * if (likely(1)) |
541 | * A[30+j] = B[30+j] |
542 | * for (j = 0; j < 4; j++) // loop 2 |
543 | * if (likely(j < 0)) |
544 | * A[36+j] = B[36+j] |
545 | * Finally we recursively partition loop 2 above into subrange [0,3] where the |
546 | * condition is false and empty interval where the condition is not false, |
547 | * therefore we generate |
548 | * for (j = 0; j < 4; j++) |
549 | * if (likely(0)) |
550 | * A[36+j] = B[36+j] |
551 | * which will eventually be simplified to empty code. And because only one loop was generated |
552 | * from loop 2 we stop recursing. |
553 | */ |
554 | Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body, |
555 | bool partition_thread_scope) { |
556 | using namespace arith; |
557 | // include hint of var. |
558 | hint_map_.insert({var.get(), IntSet::Interval(min, max)}); |
559 | |
560 | bool has_partition_hint_ = selector.partition_hint_vars.count(var.get()); |
561 | PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_); |
562 | finder(body); |
563 | |
564 | hint_map_.erase(var.get()); |
565 | if (finder.partitions.empty()) return Stmt(); |
566 | |
567 | arith::IntervalSet for_interval(min, max); |
568 | |
569 | auto [middle_interval, cond_set, |
570 | opt_cond_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> { |
571 | { |
572 | // find an interval in which all conditions on var are true |
573 | auto [middle_interval, cond_set] = |
574 | GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_); |
575 | if (!middle_interval.IsNothing()) { |
576 | return {middle_interval, cond_set, true}; |
577 | } |
578 | } |
579 | |
580 | { |
581 | // if such interval doesn't exist, find an interval in which all |
582 | // conditions on var are false |
583 | auto [middle_interval, cond_set] = |
584 | GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_); |
585 | |
586 | if (!middle_interval.IsNothing()) { |
587 | return {middle_interval, cond_set, false}; |
588 | } |
589 | } |
590 | |
591 | // we couldn't find an interval in which the conditions are |
592 | // provably true or false. Therefore, we can't partition the loop |
593 | // based on those conds |
594 | return {{}, {}, std::nullopt}; |
595 | }(); |
596 | |
597 | if (!opt_cond_value.has_value()) { |
598 | if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ && |
599 | analyzer_.CanProve(max - min > 0)) { |
600 | auto new_body = VisitAndMutate(body); |
601 | return For(var, min, max - min + 1, ForKind::kUnrolled, new_body); |
602 | } |
603 | return Stmt(); |
604 | } |
605 | bool cond_value = opt_cond_value.value(); |
606 | |
607 | IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval); |
608 | // middle_interval is the subrange of the loop variable range for which a |
609 | // set of conditions are true (or false resp.) |
610 | // The part of the loop variable range that is before (after resp.) that |
611 | // subrange is prefixed with pre- (post- resp.) |
612 | |
613 | // Calculating pre-subrange and generating code for it. |
614 | // pre-subrange = [min, body_begin) |
615 | PrimExpr body_begin; |
616 | Stmt pre_stmt; |
617 | bool pre_stmt_recurse = true; |
618 | if (middle_interval_i->HasLowerBound()) { |
619 | body_begin = analyzer_.Simplify(middle_interval.min()); |
620 | if (!analyzer_.CanProve(body_begin == min)) { |
621 | PrimExpr extent = analyzer_.Simplify(body_begin - min); |
622 | if (!analyzer_.CanProve(extent > 0)) { |
623 | body_begin = tvm::max(body_begin, min); |
624 | // stop recursing on this interval if we can't prove it has non-negative length |
625 | pre_stmt_recurse = false; |
626 | } |
627 | if (!analyzer_.CanProve(extent <= 0)) { |
628 | if (!partition_thread_scope) { |
629 | Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); |
630 | pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body); |
631 | } |
632 | } |
633 | } |
634 | } else { |
635 | body_begin = min; |
636 | } |
637 | |
638 | // Calculating post-subrange and generating code for it. |
639 | // post-subrange = [post_doubt_begin, max+1) |
640 | PrimExpr post_doubt_begin; |
641 | Stmt post_stmt; |
642 | bool post_stmt_recurse = true; |
643 | if (middle_interval_i->HasUpperBound()) { |
644 | post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1); |
645 | if (!analyzer_.CanProve(middle_interval.max() == max)) { |
646 | // require the extent to be non-negative |
647 | PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1); |
648 | if (!analyzer_.CanProve(extent > 0)) { |
649 | post_doubt_begin = tvm::min(post_doubt_begin, max + 1); |
650 | // stop recursing on this interval if we can't prove it has non-negative length |
651 | post_stmt_recurse = false; |
652 | } |
653 | if (!analyzer_.CanProve(extent <= 0)) { |
654 | if (!partition_thread_scope) { |
655 | Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); |
656 | post_stmt = MakeFor(stmt.get(), extent, post_body); |
657 | } |
658 | } |
659 | } |
660 | } else { |
661 | post_doubt_begin = max + 1; |
662 | } |
663 | |
664 | Stmt s; |
665 | |
666 | // Generating code for middle subrange |
667 | if (!partition_thread_scope) { |
668 | Stmt mid_stmt; |
669 | if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { |
670 | // [body_begin, post_doubt_begin) |
671 | Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body); |
672 | Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); |
673 | mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body); |
674 | // Recurse until partitions is empty |
675 | mid_stmt = VisitAndMutate(mid_stmt); |
676 | // Recurse for each non-empty subrange only if there are at least |
677 | // two non-empty subranges |
678 | if (pre_stmt.defined() || post_stmt.defined()) { |
679 | if (pre_stmt.defined() && pre_stmt_recurse) { |
680 | pre_stmt = VisitAndMutate(pre_stmt); |
681 | } |
682 | if (post_stmt.defined() && post_stmt_recurse) { |
683 | post_stmt = VisitAndMutate(post_stmt); |
684 | } |
685 | } |
686 | } |
687 | s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt); |
688 | } else { |
689 | PrimExpr cond = const_true(); |
690 | if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); |
691 | if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); |
692 | s = ThreadPartitionInserter(cond_set, cond)(stmt); |
693 | } |
694 | s = ConvertSSA(s); |
695 | return s; |
696 | } |
697 | |
698 | inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { |
699 | const ForNode* for_node = static_cast<const ForNode*>(node); |
700 | ICHECK(for_node); |
701 | if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && |
702 | !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { |
703 | // If the loop extent is 1, do not create the loop anymore |
704 | return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); |
705 | } else { |
706 | ICHECK(for_node->kind != ForKind::kThreadBinding); |
707 | return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body, |
708 | for_node->thread_binding, for_node->annotations); |
709 | } |
710 | } |
711 | |
712 | class RemoveLikelyTagsAndHints : public StmtExprMutator { |
713 | public: |
714 | PrimExpr VisitExpr_(const CallNode* op) final { |
715 | if (op->op.same_as(builtin::likely())) { |
716 | ICHECK_EQ(op->args.size(), 1); |
717 | return StmtExprMutator::VisitExpr(op->args[0]); |
718 | } else { |
719 | return StmtExprMutator::VisitExpr_(op); |
720 | } |
721 | } |
722 | |
723 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
724 | if (op->attr_key == attr::pragma_loop_partition_hint) { |
725 | return VisitStmt(op->body); |
726 | } |
727 | return StmtExprMutator::VisitStmt_(op); |
728 | } |
729 | }; |
730 | |
731 | Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool no_unroll_loop_with_extent_one, |
732 | bool unroll_loop_with_partition_hint_no_interval) { |
733 | stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one, |
734 | unroll_loop_with_partition_hint_no_interval) |
735 | .VisitAndMutate(std::move(stmt)); |
736 | stmt = RemoveLikelyTagsAndHints()(std::move(stmt)); |
737 | return stmt; |
738 | } |
739 | |
740 | namespace transform { |
741 | |
742 | Pass LoopPartition() { |
743 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
744 | auto* n = f.CopyOnWrite(); |
745 | auto cfg = ctx->GetConfig<LoopPartitionConfig>("tir.LoopPartition" ); |
746 | if (!cfg.defined()) { |
747 | cfg = AttrsWithDefaultValues<LoopPartitionConfig>(); |
748 | } |
749 | n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop, |
750 | cfg.value()->no_unroll_loop_with_extent_one, |
751 | cfg.value()->unroll_loop_with_partition_hint_no_interval); |
752 | return f; |
753 | }; |
754 | return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition" , {}); |
755 | } |
756 | |
757 | TVM_REGISTER_GLOBAL("tir.transform.LoopPartition" ).set_body_typed(LoopPartition); |
758 | |
759 | } // namespace transform |
760 | |
761 | } // namespace tir |
762 | } // namespace tvm |
763 | |