1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/arith/iter_affine_map.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/arith/iter_affine_map.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/expr_functor.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/stmt_functor.h>
30
31#include <utility>
32
33#include "../support/utils.h"
34#include "const_fold.h"
35#include "pattern_match.h"
36
37namespace tvm {
38namespace arith {
39
40using namespace tir;
41
42IterMark::IterMark(PrimExpr source, PrimExpr extent) {
43 auto n = make_object<IterMarkNode>();
44 n->source = std::move(source);
45 n->extent = std::move(extent);
46 data_ = std::move(n);
47}
48
49TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
50 return IterMark(source, extent);
51});
52
53TVM_REGISTER_NODE_TYPE(IterMarkNode);
54
55TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
56 .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
57 auto* op = static_cast<const IterMarkNode*>(node.get());
58 p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")";
59 });
60
61IterSplitExpr::IterSplitExpr(IterMark source) {
62 auto n = make_object<IterSplitExprNode>();
63 auto one = make_const(source->source->dtype, 1);
64 n->dtype = source->source->dtype;
65 n->source = std::move(source);
66 n->extent = n->source->extent;
67 n->lower_factor = one;
68 n->scale = one;
69 data_ = std::move(n);
70}
71
72IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) {
73 auto n = make_object<IterSplitExprNode>();
74 auto one = make_const(source->source->dtype, 1);
75 n->dtype = source->source->dtype;
76 n->source = std::move(source);
77 n->extent = n->source->extent;
78 n->lower_factor = one;
79 n->scale = std::move(scale);
80 data_ = std::move(n);
81}
82
83IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
84 PrimExpr scale) {
85 auto n = make_object<IterSplitExprNode>();
86 n->dtype = source->source->dtype;
87 n->source = std::move(source);
88 n->lower_factor = std::move(lower_factor);
89 n->extent = std::move(extent);
90 n->scale = std::move(scale);
91 data_ = std::move(n);
92}
93
94TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
95 .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
96 return IterSplitExpr(source, lower_factor, extent, scale);
97 });
98
99TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
100
101TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
102 .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
103 auto* op = static_cast<const IterSplitExprNode*>(node.get());
104 p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
105 << ", extent=" << op->extent << ", scale=" << op->scale << ")";
106 });
107
108IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
109 auto n = make_object<IterSumExprNode>();
110 n->dtype = base->dtype;
111 n->args = std::move(args);
112 n->base = std::move(base);
113 data_ = std::move(n);
114}
115
116TVM_REGISTER_GLOBAL("arith.IterSumExpr")
117 .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
118 return IterSumExpr(args, base);
119 });
120
121TVM_REGISTER_NODE_TYPE(IterSumExprNode);
122
123TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
124 .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
125 auto* op = static_cast<const IterSumExprNode*>(node.get());
126 p->stream << "IterSum(" << op->args << ", " << op->base << ")";
127 });
128
129/*!
130 * \brief Collector that collects the outgoing split reference of each IterMark.
131 *
132 * These out-going splits can then be used to check if the iterators are independent.
133 */
134class IterMarkSplitCollector {
135 public:
136 // mark all IterMarks that are visited.
137 std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
138 // each iter mark to its outgoing splits that are referenced.
139 std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
140 mark2splits_;
141 /*!
142 * \brief Collect all mark2splits recursively from indices.
143 * \param indices The iterator of interest.
144 */
145 void Collect(const Array<IterSumExpr>& indices) {
146 for (IterSumExpr sum_expr : indices) {
147 for (IterSplitExpr split : sum_expr->args) {
148 this->CollectInternal(split->source);
149 mark2splits_[split->source].push_back(split);
150 }
151 }
152 }
153
154 void CollectInternal(const IterMark& mark) {
155 if (visited_.count(mark)) return;
156 visited_.insert(mark);
157 if (auto* op = mark->source.as<IterSumExprNode>()) {
158 for (IterSplitExpr split : op->args) {
159 this->CollectInternal(split->source);
160 mark2splits_[split->source].push_back(split);
161 }
162 }
163 }
164};
165
166/*! \brief Record form of IterMark(x, extent) + offset */
167struct IterMarkWithOffset {
168 IterMark mark;
169 PrimExpr offset{0};
170 IterMarkWithOffset() {}
171 IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {}
172};
173
174/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
175class IterMapRewriter : public ExprMutator {
176 public:
177 using Parent = ExprMutator;
178
179 explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
180 IterMapLevel check_level, bool simplify_trivial_iterators,
181 Array<String>* errors)
182 : analyzer_(analyzer),
183 check_level_(check_level),
184 errors_(*errors),
185 padding_predicate_(const_false()) {
186 for (auto kv : input_iters) {
187 const Var& var = kv.first;
188 const Range& vrng = kv.second;
189 if (simplify_trivial_iterators && is_one(vrng->extent)) {
190 var_map_[var] = IterSumExpr({}, vrng->min);
191 } else if (is_zero(vrng->min)) {
192 IterMark mark(var, vrng->extent);
193 var_map_[var] = IterSplitExpr(mark);
194 input_marks_.push_back(mark);
195 } else {
196 IterMark mark(var - vrng->min, vrng->extent);
197 IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark));
198 sum_expr.CopyOnWrite()->base = vrng->min;
199 var_map_[var] = sum_expr;
200 input_marks_.push_back(mark);
201 }
202 }
203 }
204
205 PrimExpr padding_predicate() const { return padding_predicate_; }
206 bool requires_padding() const { return requires_padding_; }
207
208 IterSumExpr Rewrite(const PrimExpr& expr) {
209 return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
210 }
211
212 IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) {
213 update_iterator_padding_ = true;
214 auto res = Rewrite(expr);
215 update_iterator_padding_ = false;
216 return res;
217 }
218
219 IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
220 const Optional<PrimExpr>& predicate_induced_min,
221 const Optional<PrimExpr>& predicate_induced_max) {
222 return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
223 predicate_induced_max);
224 }
225
226 /*!
227 * \brief If require bijective mapping, this function checks two conditions:
228 * - C0: Each iter mark should be fully covered by non-overlapping splits.
229 * - C1: All of the input iterators are used.
230 * Example: given x in [0, 8) y in [0, 6)
231 * - bindings = [x, x + 1, y] won't pass because x and x+1 contribute
232 * two splits that overlaps with each other.
233 * - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4
234 * contribute two non-overlapping splits that covers x.
235 * - bindings = [x / 4, x % 4] won't pass because y is not used.
236 *
237 * If only require surjective mapping, this function checks one condition:
238 * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits.
239 * Example: given x in [0, 8) y in [0, 6)
240 * - bindings = [x / 4] will pass because x / 4 can be one split of x
241 * - bindings = [x / 4, x % 4] will pass because x / 4 and x % 4
242 * contribute two non-overlapping splits that covers x.
243 * - bindings = [x / 3] will not pass because x / 3 can not be one split of x
244 * \return whether the bindings are valid
245 */
246 bool CheckMapping(const Array<IterSumExpr>& bindings, IterMapLevel check_level) {
247 IterMarkSplitCollector collector;
248 // We can check that for each iter mark:
249 // All the splits that refers to the iter_mark covers its extent.
250 // The splits do not overlap with each other.
251 collector.Collect(bindings);
252
253 for (const IterMark& mark : collector.visited_) {
254 if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) {
255 return false;
256 }
257 }
258 if (check_level == IterMapLevel::Bijective) {
259 // all input marks must be visited
260 for (const IterMark& mark : input_marks_) {
261 if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) {
262 return false;
263 }
264 }
265 }
266 return true;
267 }
268
269 /*!
270 * \brief Check the validity of iterator constraints
271 * The flattened forms of two different iterator constraints
272 * either 1) follow inclusion relation or 2) have no intersection
273 *
274 * For Example, x = i0*30 + i1*15 + i2*3 + i3,
275 * 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \\intersect {i2, i3} = empty set.
276 * 2) [i0*2 + i1 < 3, i1*5 + i2 < 5] is not valid,
277 * since {i0, i1} \\intersect {i1, i2} = {i1}, i0 \\in {i0, i1}, i0 \\notin {i1, i2}
278 * \return whether the predicates are valid;
279 */
280 bool CheckConstraints() const {
281 // the constrained_iters_flattened_ are in the order of shorter to longer
282 // since we visit the predicates in the order of size
283 for (size_t i = 0; i < constrained_iters_flattened_.size(); ++i) {
284 for (size_t j = i + 1; j < constrained_iters_flattened_.size(); ++j) {
285 // state: 0(start), -1(no intersection), 1(inclusion)
286 int state = 0;
287 for (const IterSplitExpr& arg1 : constrained_iters_flattened_[i]->args) {
288 bool found = false;
289 for (const IterSplitExpr& arg2 : constrained_iters_flattened_[j]->args) {
290 if (IterSplitEqual(arg1, arg2)) {
291 found = true;
292 break;
293 }
294 }
295 // Check either it is inclusion or intersection, but not both
296 if (state == 0) {
297 state = found ? 1 : -1;
298 } else if ((state == -1 && found) || (state == 1 && !found)) {
299 return false;
300 }
301 }
302 }
303 }
304 return true;
305 }
306
307 // override the original mutate function.
308 PrimExpr VisitExpr(const PrimExpr& input_expr) final {
309 auto expr = ExprMutator::VisitExpr(input_expr);
310 if (expr->IsInstance<IterMapExprNode>()) {
311 ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in "
312 << "IterMapRewriter using DirectMutate. "
313 << "Indirect return occurred in " << input_expr;
314 }
315 return expr;
316 }
317
318 // Normal mutation without normalization.
319 PrimExpr DirectMutate(const PrimExpr& expr) { return ExprMutator::VisitExpr(expr); }
320
321 PrimExpr VisitExpr_(const VarNode* op) final;
322 PrimExpr VisitExpr_(const AddNode* op) final;
323 PrimExpr VisitExpr_(const SubNode* op) final;
324 PrimExpr VisitExpr_(const MulNode* op) final;
325 PrimExpr VisitExpr_(const FloorDivNode* op) final;
326 PrimExpr VisitExpr_(const FloorModNode* op) final;
327
328 private:
329 /* \brief Preprocessing common to both FloorDiv and FloorMod
330 *
331 * \param dividend The dividend to be manipulated.
332 */
333 IterSumExpr PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend);
334
335 // Create an iterator that represents the expression (split+base), with
336 // padding such that the iterator's extents are evenly divisible by
337 // `divisor`.
338 //
339 // If iterators can have padding added through UpdatePadding, pad a
340 // dividend out to be evenly divisible. Otherwise, validate that the
341 // padding previously defined for the split using UpdatePadding can be
342 // used. If no such previous padding exists, return an empty
343 // IterMark.
344 //
345 // Returns a pair of IterSplit that represents (split+base) in a
346 // form that can be dividied by divisors, and PrimExpr that
347 // represents the left padding applied to split.
348 std::pair<IterSplitExpr, PrimExpr> PadDividendToDivisor(IterSplitExpr split, PrimExpr base,
349 PrimExpr divisor);
350
351 friend struct ErrorLogger;
352
353 /* \brief Utility class for logging errors.
354 *
355 * It is not an error for IterMapRewriter to receive an expression that
356 * cannot be represented as an IterSumExpr. In these cases,
357 * IterMapRewriter returns the unrepresentable portions of the TIR graph
358 * without modification. As a result, the usual ICHECK or LOG(FATAL)
359 * macros cannot be used. Instead, ErrorLogger(this) can be used to
360 * report an unrepresentable TIR graph, which may be used in error
361 * messages at the calling scope.
362 */
363 class ErrorLogger {
364 public:
365 explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {}
366 ~ErrorLogger() { rewriter->errors_.push_back(os.str()); }
367
368 template <typename T>
369 ErrorLogger& operator<<(T&& t) {
370 os << std::forward<T>(t);
371 return *this;
372 }
373
374 private:
375 IterMapRewriter* rewriter;
376 std::ostringstream os;
377 };
378
  // Padding bookkeeping for one iter mark; used as the mapped type of padded_iter_map_.
  struct IterPaddingInfo {
    // GCD of padding factor collected during first pass
    PrimExpr padding_factor{1};

    // Padding on the left/right of the original iterator range (presumably in iteration
    // units; the producer, PadDividendToDivisor, is defined outside this view -- confirm).
    PrimExpr left_pad{0};
    PrimExpr right_pad{0};

    // Padded form of original iter mark
    IterMark padded;
  };
389
390 // temp hash for de-duplication purposes.
391 struct IterSumHash {
392 size_t operator()(const IterSumExpr& value) const {
393 // for now only hash on source index.
394 size_t hash = value->args.size();
395 for (const IterSplitExpr& arg : value->args) {
396 hash = support::HashCombine(hash, std::hash<const Object*>()(arg->source.get()));
397 }
398 return hash;
399 }
400 };
401
402 static bool IterSplitEqual(const IterSplitExpr& lhs, const IterSplitExpr& rhs,
403 bool check_scale = true) {
404 tir::ExprDeepEqual equal;
405 if (!lhs->source.same_as(rhs->source)) return false;
406 if (!equal(lhs->lower_factor, rhs->lower_factor)) return false;
407 if (check_scale && !equal(lhs->scale, rhs->scale)) return false;
408 if (!equal(lhs->extent, rhs->extent)) return false;
409 return true;
410 }
411
412 struct IterSumEqual {
413 bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
414 tir::ExprDeepEqual equal;
415 if (lhs->args.size() != rhs->args.size()) return false;
416 if (!equal(lhs->base, rhs->base)) return false;
417 for (size_t i = 0; i < lhs->args.size(); ++i) {
418 if (!IterSplitEqual(lhs->args[i], rhs->args[i])) return false;
419 }
420 return true;
421 }
422 };
423
  // Internal analyzer
  Analyzer* analyzer_;
  // Iter map check level
  IterMapLevel check_level_;
  // Error messages for each unresolved expression. Refers to the caller-owned array passed
  // to the constructor; ErrorLogger appends to it.
  Array<String>& errors_;
  // The var map: input iterator var -> its iter-map form, filled by the constructor
  // (an IterSplitExpr for zero-min ranges, otherwise an IterSumExpr).
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
  // input iter marks (one per input iterator that was not simplified away as trivial)
  std::vector<IterMark> input_marks_;

  // Map from an iter mark to the padded iterator information for
  // it. This is necessary for introducing the same padding in all
  // usage of an input iterator. (e.g. (i-1) occurring in the
  // expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
  // left-padded by 31 for each occurrence.)
  // Keyed structurally (StructuralHash/StructuralEqual), not by object identity.
  std::unordered_map<IterMark, IterPaddingInfo, StructuralHash, StructuralEqual> padded_iter_map_;

  // Map from padded iter mark to its origin (pre-padding) mark
  std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual> padded_origin_map_;

  /* If update_iterator_padding_ is true, allow the extents of the IterMap to be
   * padded beyond the original iterators.
   *
   * For example, if update_iterator_padding_ is true, the expressions i//4 and
   * i%4, where i is on the range [0,18), would be represented as
   * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4).
   * This representation would be forbidden if update_iterator_padding_ is false,
   * because lower_factor=4 does not evenly divide the original extent of
   * 18.
   */
  bool update_iterator_padding_{false};

  /* A boolean expression that is true for any padding that has been
   * introduced, and false otherwise. If update_iterator_padding_ is false,
   * padding_predicate_ will always be false.
   *
   * Example: [i//4, i%4], i in range [0,16)
   * padding_predicate_ will be false
   *
   * Example: [i//4, i%4], i in range [0,18)
   * padding_predicate_ will be `(i//4 == 3) && (i%4 >= 2)`
   *
   * Example: [i//4, i%4], i in range [0,N)
   * padding_predicate_ will be `(N%4!=0) && (i//4 == (N+3)//4-1) && (i%4 >= N%4)`
   */
  PrimExpr padding_predicate_;

  /* A boolean flag denotes there are padding iterations detected
   * in the first round of indices rewriting.
   */
  bool requires_padding_{false};

  // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
  // an extra offset)
  // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
  //          predicate: j*2 + k < 9
  // Then,    flattened form = IterSum(IterSplit(i, scale=9),
  //                                   IterSplit(j, scale=2),
  //                                   IterSplit(k, scale=1))
  //          normal form    = IterSum(IterSplit(i, scale=9),
  //                                   IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
  //                                                              IterSplit(k, scale=1)),
  //                                                      extent=9)
  //                                             scale=1))
  // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
  //          predicate: 1 <= j*2 + k < 9
  // Then,    flattened form = IterSum(IterSplit(i, scale=8),
  //                                   IterSplit(j, scale=2),
  //                                   IterSplit(k, scale=1))
  //          normal form    = IterSum(IterSplit(i, scale=8),
  //                                   IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
  //                                                              IterSplit(k, scale=1), base=-1),
  //                                                      extent=9-1)
  //                                             scale=1),
  //                                   base=1)
  std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
  // The map for sum that maps normal form to flattened form
  std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
  // The flattened forms of constrained iters, in the order they were normalized
  // (validated by CheckConstraints()).
  std::vector<IterSumExpr> constrained_iters_flattened_;
505
506 /*!
507 * \brief Look for a split in splits that is not used such that its lower_factor is smallest.
508 * Note that here we use division to compare lower_factor.
509 * \param splits the split array to search in.
510 * \param used the input used array.
511 * \param expected_lower_factor the skipped lower factor.
512 * \return the index of the expected split, split.size() if not found.
513 */
514 size_t SearchSkipLowerFactor(const std::vector<IterSplitExpr>& splits,
515 const std::vector<bool>& used,
516 const PrimExpr& expected_lower_factor) {
517 size_t res = splits.size();
518 for (size_t i = 0; i < splits.size(); ++i) {
519 if (used[i]) continue;
520 if (!used[i] && !CanProveDivisible(splits[i]->lower_factor, expected_lower_factor)) {
521 // all the remaining unused splits should have their lower factor divisible
522 return splits.size();
523 }
524 if (res == splits.size() ||
525 CanProveDivisible(splits[res]->lower_factor, splits[i]->lower_factor)) {
526 // note down the split with smaller lower factor
527 res = i;
528 }
529 }
530 return res;
531 }
532
533 /*!
534 * \brief If bijective is required, verify that splits fully covers mark in a non-overlapping
535 * fashion, If not, verify that splits are valid and compatible for the mark.
536 * If verification passes, return splits from outermost to innermost order.
537 * If not, return an empty array.
538 * \param mark The iterator of interest.
539 * \param splits The splits to be verified.
540 * \param check_level Iteration mapping's check level.
541 * \return The normalized splits.
542 */
543 Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
544 const std::vector<IterSplitExpr>& splits,
545 IterMapLevel check_level) {
546 std::vector<bool> used(splits.size(), false);
547 std::vector<IterSplitExpr> iters;
548 PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
549
550 for (size_t i = 0; i < splits.size(); ++i) {
551 size_t j = 0;
552 for (; j < splits.size(); ++j) {
553 if (used[j]) continue;
554 if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) {
555 break;
556 }
557 }
558 if (j == splits.size()) {
559 // we do not allow incomplete split if the bindings should be bijective
560 if (check_level == IterMapLevel::Bijective) {
561 return Array<IterSplitExpr>();
562 }
563 // look for the next split skipping this lower factor
564 // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
565 // It is valid to only have [y / 6, y % 2] if bijective is not required
566 // We can skip (y / 2) % 6
567 j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
568 // split not found
569 if (j == splits.size()) {
570 return Array<IterSplitExpr>();
571 }
572 }
573
574 used[j] = true;
575 iters.push_back(splits[j]);
576 expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
577 }
578
579 // Extract iteration mark info before padding
580 auto pad_mark_it = padded_origin_map_.find(mark);
581 bool has_padding = pad_mark_it != padded_origin_map_.end();
582
583 bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent);
584 bool match_iter_divisor =
585 match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor);
586
587 // Case 1. bijective is required.
588 // We check the extent we calculate is consistent with the extent of the mark and
589 // iteration mark's padding is not allowed.
590 //
591 // Case 2. bijective is not required and there is no padding.
592 // We check the extent we calculate is a factor of the extent of the mark
593 // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
594 //
595 // Case 3. bijective is not required and there exists padding. We check either
596 // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is
597 // the single split for the iter mark.
598 // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective
599 // according to how we pad the original iteration mark.
600 // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent
601 // before padding is greater or equal than the extent we calculate.
602 // For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24.
603 //
604 if (check_level == IterMapLevel::Bijective) {
605 if (has_padding) {
606 ErrorLogger(this) << "Bijectvie mapping should not take iter paddings";
607 return {};
608 } else if (!match_full_iter) {
609 ErrorLogger(this) << "The iterations do not traverse full iter space";
610 return {};
611 }
612 } else if (!has_padding) {
613 if (!match_iter_divisor) {
614 ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
615 return {};
616 }
617 } else if (check_level == IterMapLevel::Surjective) {
618 PrimExpr extent_before_padding = pad_mark_it->second->extent;
619 if (match_full_iter) {
620 if (splits.size() != 1) {
621 ErrorLogger(this) << "Dependent iterations on padding iter space";
622 return Array<IterSplitExpr>();
623 } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) &&
624 !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
625 ErrorLogger(this) << "Split on padding iteration is not surjective "
626 << "if the split extent equals to the full iter space extent";
627 return Array<IterSplitExpr>();
628 }
629 } else if (match_iter_divisor) {
630 if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) {
631 ErrorLogger(this) << "The extent before padding is less than lower factor";
632 return Array<IterSplitExpr>();
633 }
634 } else {
635 ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent";
636 return {};
637 }
638 }
639 return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
640 }
641
642 /*!
643 * \brief Normalize the iter expression with constraint (min <= expr < max)
644 * \param expr The iter expression.
645 * \param predicate_induced_min Closed lower bound from iter constraint, maybe undefined.
646 * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
647 * \return The Normalized expression.
648 */
649 IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
650 Optional<PrimExpr> predicate_induced_max) {
651 // normalize to zero base
652 PrimExpr base = expr->base;
653 if (!is_zero(base)) {
654 expr.CopyOnWrite()->base = 0;
655 if (predicate_induced_min.defined())
656 predicate_induced_min = predicate_induced_min.value() - base;
657 if (predicate_induced_max.defined())
658 predicate_induced_max = predicate_induced_max.value() - base;
659 }
660 Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
661 ICHECK(!opt.defined() || opt.value()->args.size() == 1);
662 // scale should be 1
663 if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
664 const IterSplitExpr split = opt.value()->args[0];
665 IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
666 // get the flattened form
667 auto it = flattened_map_.find(structured_form);
668 ICHECK(it != flattened_map_.end());
669 IterSumExpr flattened_form = it->second;
670 // get the mark and offset of the structured_form
671 auto it_mark = sum_fuse_map_.find(flattened_form);
672 ICHECK(it_mark != sum_fuse_map_.end());
673 IterMark mark = it_mark->second.mark;
674 PrimExpr mark_offset = it_mark->second.offset;
675 PrimExpr iter_min = mark_offset;
676 PrimExpr iter_max = iter_min + mark->extent;
677 if (predicate_induced_min.defined()) {
678 iter_min = max(predicate_induced_min.value(), iter_min);
679 }
680 if (predicate_induced_max.defined()) {
681 iter_max = min(predicate_induced_max.value(), iter_max);
682 }
683 if (!is_zero(iter_min)) {
684 // structured form's offset should be updated
685 flattened_map_.erase(structured_form);
686 structured_form.CopyOnWrite()->base = -iter_min;
687 mark.CopyOnWrite()->source = structured_form;
688 flattened_map_[structured_form] = flattened_form;
689 }
690 mark.CopyOnWrite()->extent = iter_max - iter_min;
691 sum_fuse_map_[flattened_form] = {mark, iter_min};
692 // we need to note down the flattened form of constrained iterators
693 // to check the validity of constraints, see also CheckConstraints()
694 constrained_iters_flattened_.push_back(flattened_form);
695 expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
696 expr.CopyOnWrite()->base = base + iter_min;
697 return expr;
698 }
699 ErrorLogger(this) << "Could not normalize iterators using the constraints given.";
700 return expr;
701 }
702
703 /*!
704 * \brief Normalize expr to an iterator + offset.
705 * \param expr The input expression.
706 * \return The Normalized expression.
707 */
708 IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
709 // We are normalizing a regular iter
710 if (expr->args.size() < 1) return expr;
711 Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
712 if (opt.defined()) {
713 return opt.value();
714 } else {
715 ErrorLogger(this) << "Could not normalize iterators";
716 return expr;
717 }
718 }
719
720 /*!
721 * \brief Create a IterSumExpr from expr.
722 * \param expr The input expr.
723 * \return The transformed IterSumExpr.
724 */
725 static IterSumExpr ToIterSumExpr(const PrimExpr& expr) {
726 if (const auto* op = expr.as<IterSumExprNode>()) {
727 return GetRef<IterSumExpr>(op);
728 } else if (const auto* op = expr.as<IterSplitExprNode>()) {
729 return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
730 } else {
731 ICHECK(!expr->IsInstance<IterMapExprNode>());
732 return IterSumExpr({}, expr);
733 }
734 }
735
736 /*!
737 * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base
738 * = (x1*s1 + x2*s2 + ... + xn)*cn + base
739 * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base
740 * = [IterSplit(IterMark(y), scale=cn)] + base
741 * return a corresponding IterSumExpr with extra offset if needed.
742 * Try to normalize IterSum into a fused IterMark
743 * \param expr The input sum.
744 * \param check_level The check level if iter mapping.
745 * \return The sum with the fused IterMark and extra offset if succeed.
746 */
747 Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
748 // select the iterators in order
749 std::vector<bool> visited(expr->args.size(), false);
750 std::vector<IterSplitExpr> flattened_iters, grouped_iters;
751 // canonicalize the expression into two different forms: flattened form and structured form
752 // step0. check if find the base scale first
753 Optional<IntImm> base_scale = NullOpt;
754 size_t base_index = 0;
755 for (size_t i = 0; i < expr->args.size(); ++i) {
756 if (const auto* op = expr->args[i]->scale.as<IntImmNode>()) {
757 if (!base_scale || op->value < base_scale.value()->value) {
758 base_scale = GetRef<IntImm>(op);
759 base_index = i;
760 }
761 }
762 }
763 if (!base_scale) {
764 return NullOpt;
765 }
766 // check if it can be remapped into a fused pattern.
767 PrimExpr expected_extra_base = 0;
768 PrimExpr tail_extent = 0;
769 PrimExpr expected_scale = base_scale.value();
770 for (size_t i = 0; i < expr->args.size();) {
771 // find position such that expr->args[j] match expected scale
772 int j = i == 0 ? base_index : expr->args.size() - 1;
773
774 size_t matched_pos = expr->args.size();
775 PrimExpr matched_scale{nullptr};
776 bool is_exact_match{false};
777
778 for (; j >= 0; --j) {
779 if (visited[j]) {
780 continue;
781 }
782 const PrimExpr& cur_scale = expr->args[j]->scale;
783
784 // for bijective mapping, the matched scale must equal to expected scale
785 if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
786 matched_pos = j;
787 matched_scale = cur_scale;
788 is_exact_match = true;
789 break;
790 }
791 if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
792 // find the closest scale which is less or equal to expected scale
793 if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
794 analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
795 if (matched_pos == expr->args.size() ||
796 analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
797 matched_pos = j;
798 matched_scale = cur_scale;
799 }
800 }
801 }
802 }
803 if (matched_pos == expr->args.size()) {
804 return NullOpt;
805 }
806 // look for the longest constrained iter started from expr->args[j]
807 // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
808 // predicate: j*2 + k < 9
809 // We need to match the predicate in expr and adjust the expected scale,
810 // otherwise we expect the scale of i to be 2*5=10
811 Optional<IterSumExpr> constraint_to_match;
812 for (const IterSumExpr& iter : constrained_iters_flattened_) {
813 if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
814 // find a predicate started from match position
815 if (!constraint_to_match ||
816 constraint_to_match.value()->args.size() < iter->args.size()) {
817 constraint_to_match = iter;
818 }
819 }
820 }
821 if (constraint_to_match) {
822 // match the predicate and mark the iterators in the constraint_to_match as visited
823 // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
824 // predicate = j*2 + k < 9
825 // then j*2 + k matches the lower two splits of expr
826 for (auto it = constraint_to_match.value()->args.rbegin();
827 it != constraint_to_match.value()->args.rend(); ++it) {
828 size_t k = 0;
829 for (; k < expr->args.size(); ++k) {
830 if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
831 if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
832 break;
833 }
834 }
835 if (k == expr->args.size()) {
836 return NullOpt;
837 }
838 visited[k] = true;
839 flattened_iters.push_back(expr->args[k]);
840 }
841 auto iter = sum_fuse_map_.find(constraint_to_match.value());
842 ICHECK(iter != sum_fuse_map_.end());
843 const IterMarkWithOffset& iter_matched = iter->second;
844 grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
845 expected_extra_base += iter_matched.offset * matched_scale;
846 if (!is_exact_match) {
847 tail_extent += expected_scale - matched_scale;
848 }
849 expected_scale = matched_scale * iter_matched.mark->extent;
850 // move forward
851 i += constraint_to_match.value()->args.size();
852 } else {
853 // constraint_to_match not found, skip this iterator
854 visited[matched_pos] = true;
855 IterSplitExpr arg = expr->args[matched_pos];
856 arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
857 flattened_iters.push_back(arg);
858 grouped_iters.push_back(arg);
859 if (!is_exact_match) {
860 tail_extent += expected_scale - matched_scale;
861 }
862 expected_scale = matched_scale * expr->args[matched_pos]->extent;
863 ++i;
864 }
865 }
866 // Get the flattened form and structured form
867 // both forms have splits from outermost to innermost
868 IterSumExpr structured_form = expr, flattened_form = expr;
869 flattened_form.CopyOnWrite()->args =
870 Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
871 flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
872 structured_form.CopyOnWrite()->args =
873 Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
874 structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0);
875 auto it = sum_fuse_map_.find(flattened_form);
876 if (it != sum_fuse_map_.end()) {
877 // old iter
878 if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) {
879 // the extra offset is not consistent with old
880 return NullOpt;
881 }
882 return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())},
883 expr->base + expected_extra_base);
884 } else {
885 // new iter, form a new mark
886 IterMark mark =
887 IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
888 sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
889 flattened_map_[structured_form] = flattened_form;
890 return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
891 expr->base + expected_extra_base);
892 }
893 }
894
  /*!
   * \brief Whether `lhs` can be proven to be a multiple of `rhs`.
   * NOTE(review): the definition is not visible in this chunk; the meaning is inferred from the
   * call sites (e.g. a divisible right edge needs no right padding) -- confirm.
   */
  bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs);

  /*!
   * \brief Rewrite `(lhs + base) // rhs` where the dividend is the single split `lhs` plus `base`.
   * \return The rewritten expression, or an undefined PrimExpr if it cannot be represented.
   */
  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
  /*!
   * \brief Rewrite `(lhs + base) % rhs` where the dividend is the single split `lhs` plus `base`.
   * \return The rewritten expression, or an undefined PrimExpr if it cannot be represented.
   */
  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
899
900 static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
901 tir::ExprDeepEqual equal;
902 for (size_t i = 0; i < lhs->args.size(); ++i) {
903 IterSplitExpr lvalue = lhs->args[i];
904 if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
905 equal(lvalue->extent, rhs->extent)) {
906 if (sign > 0) {
907 rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
908 } else {
909 rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
910 }
911 lhs->args.Set(i, rhs);
912 return;
913 }
914 }
915 if (sign > 0) {
916 lhs->args.push_back(rhs);
917 } else {
918 rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
919 lhs->args.push_back(rhs);
920 }
921 }
922
923 static void AddToLhs(IterSumExprNode* lhs, const IterSumExpr& rhs, int sign) {
924 for (const auto& arg : rhs->args) {
925 AddToLhs(lhs, arg, sign);
926 }
927 if (sign > 0) {
928 lhs->base += rhs->base;
929 } else {
930 lhs->base -= rhs->base;
931 }
932 }
933
934 static void MulToLhs(IterSumExprNode* lhs, const PrimExpr& rhs) {
935 for (size_t i = 0; i < lhs->args.size(); ++i) {
936 IterSplitExpr lvalue = lhs->args[i];
937 lvalue.CopyOnWrite()->scale *= rhs;
938 lhs->args.Set(i, lvalue);
939 }
940 lhs->base *= rhs;
941 }
942};
943
944/*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */
945struct IterConstraint {
946 // The expr of the iter
947 PrimExpr iter;
948 // The expr of the lower_bound, maybe undefined
949 Optional<PrimExpr> lower_bound;
950 // The expr of the upper_bound, maybe undefined
951 Optional<PrimExpr> upper_bound;
952 // The size of the iter, which is the number of nodes
953 size_t expr_size = 0;
954
955 IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
956 size_t size)
957 : iter(std::move(iter)),
958 lower_bound(std::move(lower_bound)),
959 upper_bound(std::move(upper_bound)),
960 expr_size(size) {}
961};
962
/*!
 * \brief Split the predicate `(a < b) && (c < d) && ...` into bound constraints.
 *
 * Each conjunct must be one of `<`, `<=`, `>`, `>=` between integer (signed or unsigned)
 * expressions. Conjuncts that reference no input iterator are skipped. A conjunct whose
 * iterator side is a single input iterator narrows that iterator's range in `input_iters`.
 * Any other conjunct is appended to `result` as an IterConstraint, with an inclusive lower
 * bound and/or an exclusive upper bound.
 *
 * \param pred The predicate to be split.
 * \param input_iters The input iterators; ranges of iterators bounded directly by the
 *        predicate are narrowed in place.
 * \param result Output list receiving the constraints on compound iterator expressions
 *        (their expr_size is set to 0 here).
 * \return true if every conjunct was parsed; false if some conjunct is not a supported
 *         comparison or compares non-integer expressions. On failure, `input_iters` and
 *         `result` may already hold updates from the conjuncts processed before.
 */
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
                           std::vector<IterConstraint>* result) {
  arith::PVar<PrimExpr> lhs, rhs, rest;
  for (;;) {
    // Peel one comparison off the conjunction; `rest` binds the remaining conjuncts.
    // `is_finish` marks that `pred` was a single comparison with nothing left.
    bool is_finish = false;
    bool is_greater = false;
    bool is_equal = false;
    if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) {
      // pass
    } else if ((lhs < rhs).Match(pred)) {
      is_finish = true;
    } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) {
      is_equal = true;
    } else if ((lhs <= rhs).Match(pred)) {
      is_equal = true;
      is_finish = true;
    } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) {
      is_greater = true;
    } else if ((lhs > rhs).Match(pred)) {
      is_greater = true;
      is_finish = true;
    } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) {
      is_greater = true;
      is_equal = true;
    } else if ((lhs >= rhs).Match(pred)) {
      is_greater = true;
      is_equal = true;
      is_finish = true;
    } else {
      return false;
    }
    PrimExpr lhs_expr = lhs.Eval();
    PrimExpr rhs_expr = rhs.Eval();
    // we only accept predicate of integers
    if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
          (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
      return false;
    }
    // determine iter and bound, if we can not distinguish them simply,
    // try divide (lhs - rhs) into itervar aware and itervar free parts
    auto f_use_itervar = [&input_iters](const VarNode* v) {
      return input_iters->count(GetRef<Var>(v));
    };
    bool bound_at_left;
    if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) {
      // At least it uses one input iter
      if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
        bound_at_left = true;
      } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
        bound_at_left = false;
      } else {
        // Both sides use input iters: move every iter-dependent term of (lhs - rhs) to the
        // left and every iter-free term, negated, to the right, so that the comparison
        // keeps its direction.
        bound_at_left = false;  // accumulate bound to rhs
        PrimExpr sum_parts = lhs_expr - rhs_expr;
        lhs_expr = 0;
        rhs_expr = 0;
        // Only Add/Sub nodes are descended into; `sign` tracks whether `part` is added
        // (true) or subtracted (false) in `sum_parts`.
        std::function<void(const PrimExpr&, bool)> f_extract =
            [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
              if (const AddNode* add = part.as<AddNode>()) {
                f_extract(add->a, sign);
                f_extract(add->b, sign);
              } else if (const SubNode* sub = part.as<SubNode>()) {
                f_extract(sub->a, sign);
                f_extract(sub->b, !sign);
              } else if (UsesVar(part, f_use_itervar)) {
                lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
              } else {
                rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
              }
            };
        f_extract(sum_parts, true);
        arith::Analyzer analyzer;
        lhs_expr = analyzer.Simplify(lhs_expr);
        rhs_expr = analyzer.Simplify(rhs_expr);
      }
      // Normalize to an inclusive lower bound and/or an exclusive upper bound.
      Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
      PrimExpr iter;
      if (is_greater) {
        if (bound_at_left) {
          // bound > iter / bound >= iter
          upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
          iter = rhs_expr;
        } else {
          // iter > bound / iter >= bound
          lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
          iter = lhs_expr;
        }
      } else {
        if (bound_at_left) {
          // bound < iter / bound <= iter
          lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
          iter = rhs_expr;
        } else {
          // iter < bound / iter <= bound
          upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
          iter = lhs_expr;
        }
      }
      // If it is a predicate for a single input iter, intersect it with the iter's range
      // instead of recording a separate constraint.
      if (const auto* var_ptr = iter.as<VarNode>()) {
        auto it = input_iters->find(GetRef<Var>(var_ptr));
        if (it != input_iters->end()) {
          PrimExpr iter_min = (*it).second->min;
          PrimExpr iter_max = (*it).second->min + (*it).second->extent;
          if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
          if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
          input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
        }
      } else {
        result->emplace_back(iter, lower_bound, upper_bound, 0);
      }
    }
    if (is_finish) {
      break;
    }
    pred = rest.Eval();
  }
  return true;
}
1089
1090bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
1091 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> iters;
1092 for (const auto& it : iter_ranges) iters.insert(it.first);
1093 auto f = [&](const VarNode* var) { return iters.count(GetRef<Var>(var)); };
1094 for (const auto& it : iter_ranges) {
1095 if (UsesVar(it.second->min, f) || UsesVar(it.second->extent, f)) return false;
1096 }
1097 return true;
1098}
1099
1100IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
1101 const PrimExpr& predicate, IterMapLevel check_level,
1102 arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
1103 IterMapResult result;
1104
1105 // Overall detection algorithm is divided into two steps:
1106 // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
1107 // - Step1: IterIndependenceChecker checks if the iterator are independent.
1108 if (!IterRangeSanityCheck(input_iters)) {
1109 result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other.");
1110 return result;
1111 }
1112 Map<Var, Range> constrained_input_iters = input_iters;
1113 std::vector<IterConstraint> constraints;
1114 if (!is_one(predicate) &&
1115 !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
1116 result->errors.push_back("Could not parse predicate as constraints on the input iterators.");
1117 return result;
1118 }
1119 // We have to make sure when we visit an iterator, all the constraints related with its successors
1120 // in the iter var graph has been visited, where the expression of this iterator will contain the
1121 // expression of its successor, so we sort them by their sizes.
1122 for (IterConstraint& constraint : constraints) {
1123 constraint.expr_size = CalculateExprComplexity(constraint.iter);
1124 }
1125
1126 std::sort(
1127 constraints.begin(), constraints.end(),
1128 [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
1129
1130 IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
1131 simplify_trivial_iterators, &result->errors);
1132 // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
1133 for (const IterConstraint& constraint : constraints) {
1134 auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
1135 constraint.upper_bound);
1136 if (result->errors.size() > 0) {
1137 return result;
1138 }
1139 }
1140 if (!rewriter.CheckConstraints()) {
1141 result->errors.push_back("Invalid constraints.");
1142 return result;
1143 }
1144
1145 // Step0.1: Rewrite indicies and determine required padding,
1146 // if there is no padding, it should be the final result.
1147 Array<IterSumExpr> rewrite_indices;
1148 rewrite_indices.reserve(indices.size());
1149 bool allow_padding = check_level != IterMapLevel::Bijective;
1150 if (allow_padding) {
1151 for (PrimExpr value : indices) {
1152 rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value));
1153 if (result->errors.size() > 0) {
1154 return result;
1155 }
1156 }
1157 }
1158
1159 // Step0.2: Rewrite indices in the second round.
1160 if (!allow_padding || rewriter.requires_padding()) {
1161 rewrite_indices.clear();
1162 for (PrimExpr value : indices) {
1163 rewrite_indices.push_back(rewriter.Rewrite(value));
1164 if (result->errors.size() > 0) {
1165 return result;
1166 }
1167 }
1168 }
1169 result->padding_predicate = rewriter.padding_predicate();
1170
1171 // Step1: IterIndependenceChecker checks if the iterator are independent.
1172 if (!rewriter.CheckMapping(rewrite_indices, check_level)) {
1173 if (check_level == IterMapLevel::Bijective) {
1174 result->errors.push_back("Index mapping does not form a bijective transform.");
1175 } else {
1176 result->errors.push_back("Mapped indices are not independent.");
1177 }
1178 return result;
1179 }
1180 result->indices = rewrite_indices;
1181 return result;
1182}
1183
1184TVM_REGISTER_GLOBAL("arith.DetectIterMap")
1185 .set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
1186 const PrimExpr& input_pred, int check_level,
1187 bool simplify_trivial_iterators) {
1188 arith::Analyzer ana;
1189 return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana,
1190 simplify_trivial_iterators);
1191 });
1192
1193PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
1194 auto var = GetRef<Var>(op);
1195 auto it = var_map_.find(var);
1196 if (it != var_map_.end()) return it->second;
1197 return std::move(var);
1198}
1199
1200PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
1201 if (!IsIndexType(op->dtype)) {
1202 return Parent::VisitExpr_(op);
1203 }
1204 PrimExpr a = this->DirectMutate(op->a);
1205 PrimExpr b = this->DirectMutate(op->b);
1206
1207 // const folding
1208 if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();
1209 // does not contain iter map.
1210 if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
1211 if (op->a.same_as(a) && op->b.same_as(b)) {
1212 return GetRef<PrimExpr>(op);
1213 } else {
1214 return Add(a, b);
1215 }
1216 }
1217
1218 // canonical form simplification.
1219 IterSumExpr ret = ToIterSumExpr(a);
1220
1221 if (!b->IsInstance<IterMapExprNode>()) {
1222 ret.CopyOnWrite()->base += b;
1223 } else if (const auto* op = b.as<IterSumExprNode>()) {
1224 AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), 1);
1225 } else if (const auto* op = b.as<IterSplitExprNode>()) {
1226 AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), 1);
1227 } else {
1228 AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1);
1229 }
1230 return std::move(ret);
1231}
1232
1233PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) {
1234 if (!IsIndexType(op->dtype)) {
1235 return Parent::VisitExpr_(op);
1236 }
1237
1238 PrimExpr a = this->DirectMutate(op->a);
1239 PrimExpr b = this->DirectMutate(op->b);
1240
1241 // const folding
1242 if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();
1243
1244 // does not contain iter map.
1245 if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
1246 if (op->a.same_as(a) && op->b.same_as(b)) {
1247 return GetRef<PrimExpr>(op);
1248 } else {
1249 return Sub(a, b);
1250 }
1251 }
1252
1253 // canonical form simplification.
1254 IterSumExpr ret = ToIterSumExpr(a);
1255
1256 if (!b->IsInstance<IterMapExprNode>()) {
1257 ret.CopyOnWrite()->base -= b;
1258 } else if (const auto* op = b.as<IterSumExprNode>()) {
1259 AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), -1);
1260 } else if (const auto* op = b.as<IterSplitExprNode>()) {
1261 AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), -1);
1262 } else {
1263 AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1);
1264 }
1265 return std::move(ret);
1266}
1267
1268PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
1269 if (!IsIndexType(op->dtype)) {
1270 return Parent::VisitExpr_(op);
1271 }
1272 // normalize
1273 PrimExpr a = this->DirectMutate(op->a);
1274 PrimExpr b = this->DirectMutate(op->b);
1275
1276 // const folding
1277 if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();
1278
1279 // does not contain iter map.
1280 if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
1281 if (op->a.same_as(a) && op->b.same_as(b)) {
1282 return GetRef<PrimExpr>(op);
1283 } else {
1284 return Mul(a, b);
1285 }
1286 }
1287
1288 if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
1289 // cannot multiply two iterators, mark as unresolved.
1290 ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, "
1291 << "occurs in " << GetRef<Mul>(op);
1292 return GetRef<PrimExpr>(op);
1293 }
1294
1295 if (!a->IsInstance<IterMapExprNode>()) {
1296 std::swap(a, b);
1297 }
1298
1299 if (a->IsInstance<IterSumExprNode>()) {
1300 IterSumExpr ret = Downcast<IterSumExpr>(std::move(a));
1301 MulToLhs(ret.CopyOnWrite(), b);
1302 return std::move(ret);
1303 } else {
1304 ICHECK(a->IsInstance<IterSplitExprNode>());
1305 IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
1306 ret.CopyOnWrite()->scale *= b;
1307 return std::move(ret);
1308 }
1309}
1310
1311IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) {
1312 if (dividend->IsInstance<IterSplitExprNode>()) {
1313 auto split = Downcast<IterSplitExpr>(dividend);
1314 return IterSumExpr({split}, make_zero(split.dtype()));
1315 } else if (dividend->IsInstance<IterSumExprNode>()) {
1316 auto sum = Downcast<IterSumExpr>(dividend);
1317 if (sum->args.empty()) {
1318 return IterSumExpr();
1319 } else if (sum->args.size() == 1) {
1320 return sum;
1321 }
1322 auto opt_fused = TryFuseIters(sum, check_level_);
1323 if (!opt_fused) {
1324 ErrorLogger(this) << "Dividend " << original_dividend
1325 << ", can't be written as a single fused IterSum";
1326 return IterSumExpr();
1327 }
1328 IterSumExpr fused = opt_fused.value();
1329 ICHECK_EQ(fused->args.size(), 1U);
1330 return fused;
1331 } else {
1332 LOG(FATAL) << "Unsupported subclass of IterMarkExpr";
1333 }
1334}
1335
1336/*! \brief Find approximate least common multiplier. */
1337PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) {
1338 auto fsplit = [](const PrimExpr& e) -> std::pair<PrimExpr, int64_t> {
1339 if (const IntImmNode* imm = e.as<IntImmNode>()) {
1340 return {1, imm->value};
1341 }
1342 PVar<PrimExpr> pv;
1343 PVar<IntImm> pc;
1344 if ((pv * pc).Match(e) || (pc * pv).Match(e)) {
1345 return {pv.Eval(), pc.Eval()->value};
1346 } else {
1347 return {e, 1};
1348 }
1349 };
1350 auto p1 = fsplit(a);
1351 auto p2 = fsplit(b);
1352 auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
1353 if (analyzer->CanProveEqual(p1.first, p2.first)) {
1354 return p1.first * const_lcm;
1355 } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) {
1356 return p1.first * const_lcm;
1357 } else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) {
1358 return p2.first * const_lcm;
1359 } else {
1360 return (p1.first * p2.first) * const_lcm;
1361 }
1362}
1363
/*!
 * \brief Pad the dividend `split + base` so that it lines up with `divisor`.
 *
 * Runs in two passes. When `update_iterator_padding_` is set (first pass), it only records,
 * per source mark, a padding factor (an approximate LCM of `divisor * lower_factor` over all
 * calls) and the largest left padding, and returns the split with a locally padded extent.
 * Otherwise (second pass), it redirects the split to the padded mark recorded for its source.
 * It creates that mark the first time the mark is seen and then extends `padding_predicate_`
 * with the padded positions.
 *
 * \param split The split being divided.
 * \param base The offset added to the split before the division.
 * \param divisor The divisor of the FloorDiv/FloorMod.
 * \return The (possibly padded) split and this split's left padding. Both are undefined when
 *         an incompatible left padding is detected (an error is logged in that case).
 */
std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSplitExpr split,
                                                                         PrimExpr base,
                                                                         PrimExpr divisor) {
  // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor
  // If FloorMod: (((source//lower_factor) % extent) + base) % divisor

  // First, adding any padding that is on the lower side of a
  // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0
  // when iter == 0.
  PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor));

  // Next, adding any padding that is on the upper side of a
  // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, divisor) == 0
  // when iter == extent.
  PrimExpr right_edge = left_pad + split->extent;
  PrimExpr right_pad;
  if (CanProveDivisible(right_edge, divisor)) {
    right_pad = 0;
  } else {
    right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
  }

  const IterMark& mark = split->source;
  if (update_iterator_padding_) {
    // In the first pass, the primary goal is to collect all the divisors
    // that may be used for padding. These will impact the divisor used
    // to determine padding in the second pass. We try to add padding to
    // the split's source iteration mark so that all splits under the same mark
    // share the same padded source iteration.
    auto& info = padded_iter_map_[mark];
    info.padding_factor =
        ApproxLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_);

    // If the split itself requires no padding, return directly.
    if (is_zero(left_pad) && is_zero(right_pad)) {
      return {split, 0};
    }

    // Update padding requirement on the lower side of the source iter mark.
    // In the second pass, all splits check whether the maximum left padding
    // on the iter mark is compatible with their own left padding.
    requires_padding_ = true;
    PrimExpr mark_left_pad = left_pad * split->lower_factor;
    info.left_pad = max(info.left_pad, mark_left_pad);

    // Since we only care about the extent in the first pass's result,
    // we just create a result with a compatible padded extent, ignoring
    // possible relations between different padded iters.
    PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad);
    split.CopyOnWrite()->extent = padded_extent;
    return {split, left_pad};
  }

  // In the second pass, update the iteration mark to its padded form.
  auto it = padded_iter_map_.find(mark);
  if (it == padded_iter_map_.end()) {
    return {split, left_pad};
  }
  auto& info = it->second;
  if (is_zero(info.left_pad) && CanProveDivisible(mark->extent, info.padding_factor)) {
    // the iter mark requires no padding
    return {split, left_pad};
  }

  // check that padding factor is compatible with current split and divisor
  ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor))
      << "The padding factor " << info.padding_factor << " is not divisible by "
      << split->lower_factor << " for the split " << split;
  ICHECK(CanProveDivisible(info.padding_factor, divisor))
      << "The padding factor " << info.padding_factor << " is not divisible by " << divisor
      << " for the split " << split;

  // NOTE(review): the left_pad correction and the divisibility check below run only for the
  // first split that reaches this mark; later splits on the same mark keep their own
  // floormod(base, divisor) as left_pad. Confirm that mismatched left pads on one mark are
  // rejected elsewhere.
  if (!info.padded.defined()) {
    // the first time encounter the iter mark to pad, update the padded mark.
    PrimExpr mark_left_pad = info.left_pad;
    if (CanProveDivisible(mark_left_pad, split->lower_factor)) {
      // correct current split's left padding
      // (mark_left_pad + iter) // lower_factor % extent =>
      // (left_pad * lower_factor + mark) // lower_factor % extent =>
      // (left_pad + mark // lower_factor) % extent =>
      // left_pad + (mark // lower_factor % extent) =>
      // left_pad + split
      // since the extent covers the full padding range.
      left_pad = floordiv(mark_left_pad, split->lower_factor);
    } else {
      ErrorLogger(this) << "Detect incompatible left padding on " << NormalizeIterMapToExpr(split)
                        << ", the iter mark is left padded with " << mark_left_pad;
      return {IterSplitExpr(), PrimExpr()};
    }

    // Pad the mark on the right so that (left pad + extent + right pad) is a multiple of
    // the padding factor.
    PrimExpr right_edge = mark->extent + mark_left_pad;
    PrimExpr mark_right_pad;
    if (CanProveDivisible(right_edge, info.padding_factor)) {
      mark_right_pad = 0;
    } else {
      mark_right_pad = floormod(-right_edge, info.padding_factor);
    }
    PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad);
    info.right_pad = mark_right_pad;
    // The padded mark reads the original mark shifted by the left padding.
    info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent);
    padded_origin_map_[info.padded] = mark;

    auto left_padding_introduced = (mark_left_pad != 0);

    // Equivalent to (0 <= split < left_pad), but easier to simplify in
    // terms of the transformed variables.
    auto left_padding_predicate =
        left_padding_introduced &&
        (floordiv(info.padded->source, info.padding_factor) == 0 &&
         floormod(info.padded->source, info.padding_factor) < mark_left_pad);
    auto right_padding_introduced = (mark_right_pad != 0);

    // Equivalent to (right_edge <= split < right_edge + right_pad), but
    // easier to simplify in terms of the transformed variables.
    auto right_padding_predicate =
        right_padding_introduced && (floordiv(info.padded->source, info.padding_factor) ==
                                         floordiv(right_edge, info.padding_factor) &&
                                     floormod(info.padded->source, info.padding_factor) >=
                                         floormod(right_edge, info.padding_factor));
    // Accumulate the positions that belong to padding into the overall padding predicate.
    padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate);
  }
  // Redirect the split to the padded mark; its extent covers the whole padded range.
  split.CopyOnWrite()->source = info.padded;
  split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor);
  return {split, left_pad};
}
1489
/*!
 * \brief Rewrite `(lhs + base) // rhs` as an IterMap expression.
 *
 * A non-unit scale on `lhs` is first removed by dividing it into `rhs` or vice versa. The
 * branches are tried in order and the first applicable one wins; some rewrite `rhs`/`base`
 * in place and fall through. The remaining unit-scale case pads the dividend with
 * PadDividendToDivisor and multiplies the split's lower factor by `rhs`.
 *
 * \param lhs The single split of the dividend.
 * \param base The offset of the dividend.
 * \param rhs The divisor; the visible caller has already rejected iterator divisors.
 * \return The rewritten expression, or an undefined PrimExpr when it cannot be represented.
 */
PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) {
  // (lhs + base) // rhs

  if (is_one(rhs)) {
    if (is_zero(base)) {
      // floordiv(x, 1) = x
      return std::move(lhs);
    } else {
      // floordiv(x+y, 1) = x+y
      return IterSumExpr({lhs}, base);
    }
  }

  if (!is_one(lhs->scale)) {
    if (CanProveDivisible(lhs->scale, rhs) && is_zero(base)) {
      // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs
      lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs);
      return std::move(lhs);
    } else if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) {
      // floordiv(x*c1*c2 + y*c2, c2) = x*c1 + y, c1=scale/rhs
      lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs);
      return IterSumExpr({lhs}, floordiv(base, rhs));
    } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) {
      // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale
      rhs = floordiv(rhs, lhs->scale);
      lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1);
    } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) {
      // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale
      base = floordiv(base, lhs->scale);
      rhs = floordiv(rhs, lhs->scale);
      lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1);
    } else {
      // mark as unresolved.
      ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, "
                        << lhs->scale << " and the divisor " << rhs
                        << " cannot be simplified to remove the scaling factor.";
      return PrimExpr();
    }
  }

  // We handle scale!=1 in above code, hence we only consider floordiv(x, rhs) below
  // where x=floormod(floordiv(iter, lower_factor), extent) + base

  auto pair = PadDividendToDivisor(lhs, base, rhs);
  IterSplitExpr padded = pair.first;
  PrimExpr left_pad = pair.second;
  if (!padded.defined()) {
    // PadDividendToDivisor has already logged the error.
    return PrimExpr();
  }

  // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1)
  // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor)
  // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, t<c2, u<c1
  // = t
  // = floormod(sc2+t, c2)
  // = floormod(floordiv(y, c1), c2)
  // = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs
  IterSplitExpr new_split(padded->source,
                          /* lower_factor = */ padded->lower_factor * rhs,
                          /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)),
                          /* scale = */ padded->scale);

  // The left padding is accounted for by the padded split, so only the rest of the base
  // remains. NOTE(review): the extra argument 6 presumably bounds simplification effort.
  auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6);
  if (is_zero(new_base)) {
    return std::move(new_split);
  } else {
    return IterSumExpr({new_split}, new_base);
  }
}
1559
1560PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
1561 if (!IsIndexType(op->dtype)) {
1562 return Parent::VisitExpr_(op);
1563 }
1564
1565 PrimExpr a = this->DirectMutate(op->a);
1566 PrimExpr b = this->DirectMutate(op->b);
1567
1568 // const folding
1569 if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
1570
1571 // does not contain iter map.
1572 if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
1573 if (op->a.same_as(a) && op->b.same_as(b)) {
1574 return GetRef<PrimExpr>(op);
1575 } else {
1576 return FloorDiv(a, b);
1577 }
1578 }
1579
1580 if (b->IsInstance<IterMapExprNode>()) {
1581 // cannot divide an iterator, mark as unresolved.
1582 ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef<PrimExpr>(op)
1583 << " may not be an iterator";
1584 return GetRef<PrimExpr>(op);
1585 }
1586
1587 IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a);
1588 if (!preprocessed.defined()) {
1589 return GetRef<PrimExpr>(op);
1590 }
1591 ICHECK_EQ(preprocessed->args.size(), 1U);
1592 PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b);
1593 if (!remainder.defined()) {
1594 return GetRef<PrimExpr>(op);
1595 }
1596 return remainder;
1597}
1598
/*!
 * \brief Rewrite `(lhs + base) % rhs` as an IterMap expression.
 *
 * A non-unit scale on `lhs` is handled first. The branches are tried in order and the first
 * applicable one wins. The split's scale is kept as the multiplier of the result, while
 * `rhs` (and possibly `base`) is divided by it. The dividend is then padded with
 * PadDividendToDivisor and the split's extent is replaced by `rhs`.
 *
 * \param lhs The single split of the dividend.
 * \param base The offset of the dividend.
 * \param rhs The modulus; the visible caller has already rejected iterator moduli.
 * \return The rewritten expression, or an undefined PrimExpr when it cannot be represented.
 */
PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) {
  // (lhs + base) % rhs

  if (is_one(rhs)) {
    // floormod(x, 1) = 0
    return make_zero(lhs->dtype);
  }

  if (!is_one(lhs->scale)) {
    if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) {
      // floormod(x*c1*c2 + y*c2, c2) = 0, where c2 = rhs
      return make_zero(lhs->dtype);
    } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) {
      // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale
      rhs = floordiv(rhs, lhs->scale);
    } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) {
      // floormod(x*c1 + y*c1, c1*c2) = (floormod(x+y, c2)) * c1, where c2 = rhs/scale
      rhs = floordiv(rhs, lhs->scale);
      base = floordiv(base, lhs->scale);
    } else {
      // mark as unresolved.
      ErrorLogger(this)
          << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, "
          << lhs->scale << " and the right-hand " << rhs
          << " cannot be used to simplify out the scaling factor.";
      return PrimExpr();
    }
  }

  // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below
  // where x=floormod(floordiv(iter, lower_factor), extent) + base
  // (the split's scale is still applied to the result below).
  // NOTE(review): `base` only enters through the padding computed by PadDividendToDivisor;
  // the returned split does not include it otherwise. Confirm this holds when no padding is
  // recorded for the mark.
  auto pair = PadDividendToDivisor(lhs, base, rhs);
  IterSplitExpr padded = pair.first;
  if (!padded.defined()) {
    // PadDividendToDivisor has already logged the error.
    return PrimExpr();
  }

  // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1)
  // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs
  return IterSplitExpr(padded->source,
                       /* lower_factor = */ padded->lower_factor,
                       /* extent = */ rhs,
                       /* scale = */ padded->scale);
}
1643
1644PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
1645 if (!IsIndexType(op->dtype)) {
1646 return Parent::VisitExpr_(op);
1647 }
1648
1649 PrimExpr a = this->DirectMutate(op->a);
1650 PrimExpr b = this->DirectMutate(op->b);
1651
1652 // const folding
1653 if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();
1654
1655 // does not contain iter map.
1656 if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
1657 if (op->a.same_as(a) && op->b.same_as(b)) {
1658 return GetRef<PrimExpr>(op);
1659 } else {
1660 return FloorMod(a, b);
1661 }
1662 }
1663
1664 if (b->IsInstance<IterMapExprNode>()) {
1665 // cannot mod an iterator, mark as unresolved.
1666 ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in "
1667 << GetRef<PrimExpr>(op) << " may not be an iterator";
1668 return GetRef<PrimExpr>(op);
1669 }
1670
1671 IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a);
1672 if (!preprocessed.defined()) {
1673 return GetRef<PrimExpr>(op);
1674 }
1675
1676 ICHECK_EQ(preprocessed->args.size(), 1U);
1677 PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b);
1678 if (!remainder.defined()) {
1679 return GetRef<PrimExpr>(op);
1680 }
1681 return remainder;
1682}
1683
1684/*! * \brief Given an expression that may contain IterVarMapExpr, transform it to normal PrimExpr.
1685 */
1686class IterMapToExprNormalizer : public ExprMutator {
1687 public:
1688 explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
1689
1690 PrimExpr Convert(const PrimExpr& expr) { return VisitExpr(expr); }
1691
1692 private:
1693 /*! \brief Override VisitExpr for iter expr type processing */
1694 PrimExpr VisitExpr(const PrimExpr& expr) override {
1695 if (const auto* op = expr.as<IterSplitExprNode>()) {
1696 return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
1697 } else if (const auto* op = expr.as<IterSumExprNode>()) {
1698 return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
1699 } else {
1700 return ExprMutator::VisitExpr(expr);
1701 }
1702 }
1703
1704 PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
1705 PrimExpr res = 0;
1706 for (const IterSplitExpr& arg : expr->args) {
1707 res += ConvertIterSplitExpr(arg);
1708 }
1709 res += expr->base;
1710 return res;
1711 }
1712
1713 PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
1714 PrimExpr source;
1715 if (const auto* op = expr->source->source.as<VarNode>()) {
1716 source = GetRef<Var>(op);
1717 } else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
1718 source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
1719 } else {
1720 source = VisitExpr(expr->source->source);
1721 }
1722 if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
1723 return source * expr->scale;
1724 } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
1725 return floordiv(source, expr->lower_factor) * expr->scale;
1726 } else {
1727 return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
1728 expr->scale;
1729 }
1730 }
1731
1732 private:
1733 Analyzer* analyzer_;
1734};
1735
1736bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
1737 const auto* clhs = lhs.as<IntImmNode>();
1738 const auto* crhs = rhs.as<IntImmNode>();
1739 if (crhs && crhs->value == 0) {
1740 return false;
1741 } else if (clhs && crhs) {
1742 return clhs->value % crhs->value == 0;
1743 }
1744
1745 IterMapToExprNormalizer normalizer(analyzer_);
1746 PrimExpr dividend = normalizer.Convert(lhs);
1747 PrimExpr divisor = normalizer.Convert(rhs);
1748
1749 return analyzer_->CanProveEqual(dividend, divisor) ||
1750 analyzer_->CanProve(floormod(dividend, divisor) == 0);
1751}
1752
1753PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) {
1754 arith::Analyzer analyzer;
1755 IterMapToExprNormalizer normalizer(&analyzer);
1756 return normalizer.Convert(expr);
1757}
1758
1759TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed(NormalizeIterMapToExpr);
1760
1761Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
1762 const PrimExpr& input_pred, IterMapLevel check_level,
1763 bool simplify_trivial_iterators) {
1764 if (!IterRangeSanityCheck(input_iters)) return indices;
1765 Analyzer analyzer;
1766 auto res = DetectIterMap(indices, input_iters, input_pred, check_level, &analyzer,
1767 /*simplify_trivial_iterators=*/simplify_trivial_iterators);
1768 Array<IterSumExpr> rewrite = res->indices;
1769
1770 if (rewrite.empty()) {
1771 return indices;
1772 }
1773 Array<PrimExpr> simplified;
1774 simplified.reserve(rewrite.size());
1775 IterMapToExprNormalizer converter(&analyzer);
1776 for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr));
1777 return simplified;
1778}
1779
1780/*!
1781 * \brief Divider to divide the bindings into two sets of bindings(outer and inner)
1782 * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
1783 * We do message passing among IterSplitExpr and IterSumExpr.
1784 *
1785 * Example
1786 * - If we encounter sum = i*10 + j*5 + k, and i, j, k are splits,
1787 * and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing,
1788 * then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk.
1789 * - If we encounter split = (i / 2) % 4, and we know i = Y*E(X) + X through message passing.
1790 * We inspect all the splits of i, which are i / 8, (i / 2) % 4, i % 2.
1791 * Their extents are 2, 4, 2, if E(X) = 2, 8, 16, the splits can be divided.
1792 */
1793class SubspaceDivider {
1794 public:
1795 explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
1796 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
1797 : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}
1798
1799 size_t unresolved_count() const { return unresolved_count_; }
1800
1801 // Denotes outer*inner_extent + inner, used as message passing carrier
1802 struct DivisionResult {
1803 public:
1804 // IterMapExpr of outer iters
1805 IterMapExpr outer;
1806 // IterMapExpr of inner iters
1807 IterMapExpr inner;
1808 // extent of outer
1809 PrimExpr outer_extent;
1810 // extent of inner
1811 PrimExpr inner_extent;
1812
1813 // The kind of the division result.
1814 enum class Kind {
1815 kInner, // Indicates the division result is totally in inner subspace.
1816 kOuter, // Indicates the division result is totally in outer subspace.
1817 kMixed, // Indicates the division result is mixed in both subspace.
1818 } kind;
1819
1820 DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
1821 PrimExpr inner_extent, Kind kind = Kind::kMixed)
1822 : outer(std::move(outer)),
1823 inner(std::move(inner)),
1824 outer_extent(std::move(outer_extent)),
1825 inner_extent(std::move(inner_extent)),
1826 kind(kind) {}
1827
1828 // whether the division result is totally in outer subspace
1829 bool IsOuter() const { return kind == Kind::kOuter; }
1830
1831 // whether the division result is totally in inner subspace
1832 bool IsInner() const { return kind == Kind::kInner; }
1833
1834 IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }
1835
1836 IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }
1837
1838 static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
1839 auto dtype = iter.dtype();
1840 return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter,
1841 extent, Kind::kInner);
1842 }
1843
1844 static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
1845 auto dtype = iter.dtype();
1846 return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)),
1847 make_const(dtype, 1), Kind::kOuter);
1848 }
1849
1850 // Special value to indicate the division is not possible
1851 static DivisionResult Failure() {
1852 return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
1853 }
1854
1855 private:
1856 static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
1857 if (const auto* op = expr.as<IterSplitExprNode>()) {
1858 return GetRef<IterSplitExpr>(op);
1859 } else if (const auto* op = expr.as<IterSumExprNode>()) {
1860 return IterSplitExpr(IterMark(GetRef<IterSumExpr>(op), extent));
1861 } else {
1862 LOG(FATAL) << "Unknown IterMapExpr type";
1863 }
1864 }
1865 };
1866
1867 // Divide an IterSumExpr
1868 DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) {
1869 auto dtype = expr.dtype();
1870 if (expr->args.empty()) {
1871 // base
1872 return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1),
1873 IterSumExpr({}, expr->base), make_const(dtype, 1));
1874 } else if (expr->args.size() == 1) {
1875 // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
1876 if (!is_one(expr->args[0]->scale)) {
1877 unresolved_count_++;
1878 return DivisionResult::Failure();
1879 }
1880 DivisionResult res = DivideIterSplitExpr(expr->args[0]);
1881 if (!is_zero(expr->base)) res = AddBase(res, expr->base);
1882 return res;
1883 }
1884 // arg1 + arg2 + ... + argn + base
1885 // then we can write it as Y*E(X)+X
1886 // if it starts with contiguous outer splits, followed by contiguous inner splits
1887 PrimExpr extent = make_const(dtype, 1);
1888 std::vector<IterSplitExpr> outer_args, inner_args;
1889 bool inner = true, scale_is_one = false;
1890 // we check in inverse order so we can visit from inner to outer
1891 for (auto it = expr->args.rbegin(); it != expr->args.rend(); ++it) {
1892 const IterSplitExpr& arg = *it;
1893 if (is_one(arg->scale)) scale_is_one = true;
1894 DivisionResult arg_division = DivideIterSplitExpr(arg);
1895 IterSplitExpr new_arg;
1896 if (arg_division.IsInner()) {
1897 if (!inner) {
1898 unresolved_count_++;
1899 return DivisionResult::Failure();
1900 }
1901 new_arg = arg_division.GetInnerAsSplit();
1902 inner_args.push_back(new_arg);
1903 inner = true;
1904 } else if (arg_division.IsOuter()) {
1905 new_arg = arg_division.GetOuterAsSplit();
1906 outer_args.push_back(new_arg);
1907 inner = false;
1908 } else {
1909 unresolved_count_++;
1910 return DivisionResult::Failure();
1911 }
1912 extent *= new_arg->extent;
1913 }
1914 if (!scale_is_one) {
1915 unresolved_count_++;
1916 return DivisionResult::Failure();
1917 }
1918 bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
1919 const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, make_const(dtype, 0));
1920 const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
1921 IterSumExpr outer_source = Downcast<IterSumExpr>(outer_mark->source);
1922 IterSumExpr inner_source = Downcast<IterSumExpr>(inner_mark->source);
1923 if (need_predicate) {
1924 // if we have a predicate on this sum expr, then we cannot divide it into Y*E+X
1925 // it should either be Y*1+0 or 0*E(X)+X
1926 IterMapToExprNormalizer converter(analyzer_);
1927 if (inner_args.empty()) {
1928 // Y*1+0
1929 outer_preds_ = outer_preds_ && (converter.Convert(outer_source) < mark_extent);
1930 return DivisionResult::Outer(outer_source, mark_extent);
1931 } else if (outer_args.empty()) {
1932 // 0*E(X)+X
1933 inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
1934 return DivisionResult::Inner(inner_source, mark_extent);
1935 } else {
1936 unresolved_count_++;
1937 return DivisionResult::Failure();
1938 }
1939 }
1940 return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
1941 }
1942
1943 PrimExpr GetOuterPreds() const { return outer_preds_; }
1944 PrimExpr GetInnerPreds() const { return inner_preds_; }
1945
1946 private:
1947 DivisionResult AddBase(DivisionResult division, PrimExpr base) {
1948 DivisionResult res = division;
1949 if (const auto* op = division.inner.as<IterSplitExprNode>()) {
1950 res.inner = IterSumExpr({GetRef<IterSplitExpr>(op)}, base);
1951 } else if (const auto* op = division.inner.as<IterSumExprNode>()) {
1952 const auto& expr = GetRef<IterSumExpr>(op);
1953 res.inner = IterSumExpr(expr->args, expr->base + base);
1954 }
1955 return res;
1956 }
1957
1958 // args are sorted from inner to outer
1959 static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args, PrimExpr base) {
1960 std::vector<IterSplitExpr> res;
1961 PrimExpr extent = make_const(base.dtype(), 1);
1962 for (const IterSplitExpr& it : args) {
1963 IterSplitExpr arg = it;
1964 arg.CopyOnWrite()->scale = extent;
1965 extent *= arg->extent;
1966 res.push_back(arg);
1967 }
1968 return IterMark(IterSumExpr(Array<IterSplitExpr>(res.rbegin(), res.rend()), base), extent);
1969 }
1970
1971 DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) {
1972 auto it = split_map_.find(expr);
1973 if (it != split_map_.end()) {
1974 // We will calculate all the splits of an IterMark's division form when we first
1975 // encounter one of them. If we encounter another later, we directly return the record.
1976 return it->second;
1977 }
1978 const Array<IterSplitExpr>& splits = collector_.mark2splits_.at(expr->source);
1979 if (const auto* iter_ptr = expr->source->source.as<VarNode>()) {
1980 // source is input_iter
1981 bool inner = sub_iters_.count(GetRef<Var>(iter_ptr));
1982 for (const IterSplitExpr& split : splits) {
1983 if (inner) {
1984 // 0*E(split)+split
1985 split_map_.emplace(split, DivisionResult::Inner(split, split->extent));
1986 } else {
1987 // split*1 + 0
1988 split_map_.emplace(split, DivisionResult::Outer(split, split->extent));
1989 }
1990 }
1991 } else if (const auto* iter_ptr = expr->source->source.as<IterSumExprNode>()) {
1992 // source = Y*E+X
1993 // splits = [s1, s2, ..., sn]
1994 // we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y)
1995 // extent(si+1)...extent(sn)=extent(X)
1996 // For example, if source = Y*3+X \in [0, 12), Y \in [0, 4), X \in [0, 3)
1997 // Case 1. splits = [s1, s2, s3] = [source / 6, (source / 3) % 2, source % 3],
1998 // where extent(s1) = 2, extent(s2) = 2, extent(s3) = 3.
1999 // Since extent(s1)extent(s2) = extent(Y), extent(s3) = extent(X), we have
2000 // s1 = (Y / 2)*1 + 0, s2 = (Y % 2)*1 + 0, s3 = 0*3 + X
2001 // Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2],
2002 // where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2.
2003 // It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X.
2004 DivisionResult mark_division =
2005 DivideIterSumExpr(GetRef<IterSumExpr>(iter_ptr), expr->source->extent);
2006 if (splits.size() == 1) {
2007 return mark_division;
2008 }
2009 IterMark outer_mark(Downcast<IterSumExpr>(mark_division.outer), mark_division.outer_extent);
2010 IterMark inner_mark(Downcast<IterSumExpr>(mark_division.inner), mark_division.inner_extent);
2011 bool encountered_boundary = mark_division.IsOuter();
2012 std::vector<bool> used(splits.size(), false);
2013 std::vector<IterSplitExpr> inner_iters, outer_iters;
2014 PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1);
2015 // find the boundary of outer and inner, like case 1 above
2016 for (size_t i = 0; i < splits.size(); ++i) {
2017 size_t j = 0;
2018 for (; j < splits.size(); ++j) {
2019 if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
2020 break;
2021 }
2022 if (j == splits.size()) {
2023 unresolved_count_++;
2024 return DivisionResult::Failure();
2025 }
2026 used[j] = true;
2027 if (!encountered_boundary) {
2028 inner_iters.push_back(splits[j]);
2029 } else {
2030 outer_iters.push_back(splits[j]);
2031 }
2032 expected_lower_factor *= splits[j]->extent;
2033 if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
2034 encountered_boundary = true;
2035 }
2036 if (!encountered_boundary) {
2037 unresolved_count_++;
2038 return DivisionResult::Failure();
2039 }
2040 for (const IterSplitExpr& inner_iter : inner_iters) {
2041 IterSplitExpr new_iter = inner_iter;
2042 new_iter.CopyOnWrite()->source = inner_mark;
2043 split_map_.emplace(inner_iter, DivisionResult::Inner(new_iter, inner_iter->extent));
2044 }
2045 for (const IterSplitExpr& outer_iter : outer_iters) {
2046 IterSplitExpr new_iter = outer_iter;
2047 new_iter.CopyOnWrite()->source = outer_mark;
2048 new_iter.CopyOnWrite()->lower_factor =
2049 floordiv(outer_iter->lower_factor, outer_iters[0]->lower_factor);
2050 split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
2051 }
2052 } else {
2053 unresolved_count_++;
2054 return DivisionResult::Failure();
2055 }
2056 return split_map_.at(expr);
2057 }
2058
2059 size_t unresolved_count_{0};
2060 // arithmetic analyzer used to call CanProve
2061 Analyzer* analyzer_;
2062 // collector that collects the outgoing split reference of each IterMark
2063 const IterMarkSplitCollector collector_;
2064 // the set of subspace iters
2065 const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters_;
2066 // map from SplitExpr to its corresponding DivisionResult(Y*E(X)+X)
2067 std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
2068 // predicate of outer space and inner space;
2069 PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)};
2070};
2071
2072Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
2073 const Map<Var, Range>& input_iters,
2074 const Array<Var>& sub_iters, const PrimExpr& predicate,
2075 IterMapLevel check_level, arith::Analyzer* analyzer,
2076 bool simplify_trivial_iterators) {
2077 if (!IterRangeSanityCheck(input_iters)) return Array<Array<IterMark>>();
2078 auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer,
2079 simplify_trivial_iterators);
2080 const Array<IterSumExpr>& maps = res->indices;
2081 if (maps.empty()) return {};
2082
2083 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inner_iter_set;
2084 for (const Var& inner_iter : sub_iters) {
2085 inner_iter_set.insert(inner_iter);
2086 }
2087
2088 IterMarkSplitCollector collector;
2089 collector.Collect(maps);
2090 SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set);
2091
2092 std::vector<Array<IterMark>> results;
2093 for (const IterSumExpr& expr : maps) {
2094 SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0);
2095 if (subspace_divider.unresolved_count()) return {};
2096 results.push_back(
2097 {IterMark(res.outer, res.outer_extent), IterMark(res.inner, res.inner_extent)});
2098 }
2099
2100 results.push_back({IterMark(IterSumExpr({}, 0), subspace_divider.GetOuterPreds()),
2101 IterMark(IterSumExpr({}, 0), subspace_divider.GetInnerPreds())});
2102 return results;
2103}
2104
2105TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
2106 .set_body_typed([](const Array<PrimExpr>& bindings, const Map<Var, Range>& root_iters,
2107 const Array<Var>& sub_iters, const PrimExpr& predicate, int check_level,
2108 bool simplify_trivial_iterators) {
2109 arith::Analyzer ana;
2110 return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level),
2111 &ana, simplify_trivial_iterators);
2112 });
2113
2114class InverseAffineIterMapTransformer {
2115 public:
2116 explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {}
2117
2118 Map<Var, PrimExpr> operator()(const Array<IterSumExpr>& iter_map,
2119 const Array<PrimExpr>& outputs) {
2120 ICHECK(iter_map.size() == outputs.size());
2121 std::vector<const IterMapExprNode*> post_dfs_order = ReverseTopologyOrder(iter_map);
2122
2123 // initialize back propagation accumulator
2124 for (const IterMapExprNode* node : post_dfs_order) {
2125 backprop_.Set(GetRef<IterMapExpr>(node), Integer(0));
2126 }
2127 for (size_t i = 0; i < iter_map.size(); i++) {
2128 backprop_.Set(iter_map[i], outputs[i]);
2129 }
2130
2131 // run back propagation
2132 for (const IterMapExprNode* node : post_dfs_order) {
2133 if (node->IsInstance<IterSumExprNode>()) {
2134 Visit_(Downcast<IterSumExpr>(GetRef<IterMapExpr>(node)));
2135 } else {
2136 ICHECK(node->IsInstance<IterSplitExprNode>());
2137 Visit_(Downcast<IterSplitExpr>(GetRef<IterMapExpr>(node)));
2138 }
2139 }
2140 return std::move(inverse_);
2141 }
2142
2143 private:
2144 void Visit_(const IterSumExpr& iter_map_expr) {
2145 PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base;
2146
2147 // Case 1: Propagate to the input node directly when the sum expression has only one components
2148 if (iter_map_expr->args.size() == 1) {
2149 const auto& source = iter_map_expr->args[0];
2150 backprop_.Set(source, backprop_.at(source) + input);
2151 return;
2152 }
2153
2154 // Case 2: If the sum expression has multiple components, check the fuse pattern and then split
2155 // the sum expression for each components.
2156 // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
2157 // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the
2158 // propagated value to get the corresponding components of i1 and i2, which are
2159 // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
2160 CheckFusePattern(iter_map_expr);
2161 for (size_t i = iter_map_expr->args.size(); i > 0; i--) {
2162 const IterSplitExpr& split = iter_map_expr->args[i - 1];
2163 PrimExpr prop_value = floordiv(input, split->scale);
2164 // the first part has the same extent as the split expression, floormod is not needed
2165 if (i > 1) {
2166 prop_value = floormod(prop_value, split->extent);
2167 }
2168 backprop_.Set(split, backprop_.at(split) + prop_value);
2169 }
2170 }
2171
2172 std::vector<const IterMapExprNode*> ReverseTopologyOrder(const Array<IterSumExpr>& iter_map) {
2173 std::vector<const IterMapExprNode*> post_dfs_order;
2174 std::unordered_map<IterMapExpr, bool, ObjectPtrHash, ObjectPtrEqual> visited;
2175
2176 std::function<void(const IterMapExpr&)> fvisit = [&](const IterMapExpr& expr) {
2177 if (visited[expr]) {
2178 return;
2179 }
2180 visited[expr] = true;
2181 if (const auto* sum_expr = expr.as<IterSumExprNode>()) {
2182 for (const IterSplitExpr& child : sum_expr->args) {
2183 fvisit(child);
2184 }
2185 } else {
2186 const auto* split_expr = expr.as<IterSplitExprNode>();
2187 ICHECK(split_expr);
2188 if (const auto* source = split_expr->source->source.as<IterMapExprNode>()) {
2189 fvisit(GetRef<IterMapExpr>(source));
2190 }
2191 }
2192 post_dfs_order.push_back(expr.get());
2193 };
2194 for (const IterSumExpr& expr : iter_map) {
2195 fvisit(expr);
2196 }
2197 std::reverse(post_dfs_order.begin(), post_dfs_order.end());
2198 return post_dfs_order;
2199 }
2200
2201 void Visit_(const IterSplitExpr& iter_map_expr) {
2202 PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor;
2203 const IterMark& source = iter_map_expr->source;
2204 if (source->source.as<IterSumExprNode>()) {
2205 IterSumExpr source_expr = Downcast<IterSumExpr>(source->source);
2206 backprop_.Set(source_expr, backprop_.at(source_expr) + input);
2207 } else {
2208 Var source_var = Downcast<Var>(source->source);
2209 if (inverse_.count(source_var)) {
2210 inverse_.Set(source_var, inverse_.at(source_var) + input);
2211 } else {
2212 inverse_.Set(source_var, input);
2213 }
2214 }
2215 }
2216
2217 /*
2218 * \brief Check the fuse pattern of sum_expr. We assume components of sum_expr is sorted in
2219 * descending order of lower_factor.
2220 */
2221 void CheckFusePattern(const IterSumExpr sum_expr) {
2222 if (sum_expr->args.empty()) {
2223 return;
2224 }
2225 PrimExpr expected_scale = sum_expr->args.back()->scale;
2226 for (size_t i = sum_expr->args.size(); i > 0; i--) {
2227 ICHECK(analyzer_->CanProveEqual(sum_expr->args[i - 1]->scale, expected_scale));
2228 expected_scale *= sum_expr->args[i - 1]->extent;
2229 }
2230 }
2231
2232 Analyzer* analyzer_;
2233 Map<IterMapExpr, PrimExpr> backprop_; // the accumulator of backpropgation
2234 Map<Var, PrimExpr> inverse_; // the result of inverse transformation
2235};
2236
2237Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
2238 const Array<PrimExpr> outputs) {
2239 Analyzer analyzer;
2240 return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs);
2241}
2242
// Expose InverseAffineIterMap through the FFI as "arith.InverseAffineIterMap".
TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap);

// Register the IterMapResultNode object type with the node registry.
TVM_REGISTER_NODE_TYPE(IterMapResultNode);
2246
2247} // namespace arith
2248} // namespace tvm
2249