1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/arith/iter_affine_map.cc |
22 | */ |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/arith/iter_affine_map.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/expr_functor.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | |
31 | #include <utility> |
32 | |
33 | #include "../support/utils.h" |
34 | #include "const_fold.h" |
35 | #include "pattern_match.h" |
36 | |
37 | namespace tvm { |
38 | namespace arith { |
39 | |
40 | using namespace tir; |
41 | |
42 | IterMark::IterMark(PrimExpr source, PrimExpr extent) { |
43 | auto n = make_object<IterMarkNode>(); |
44 | n->source = std::move(source); |
45 | n->extent = std::move(extent); |
46 | data_ = std::move(n); |
47 | } |
48 | |
49 | TVM_REGISTER_GLOBAL("arith.IterMark" ).set_body_typed([](PrimExpr source, PrimExpr extent) { |
50 | return IterMark(source, extent); |
51 | }); |
52 | |
53 | TVM_REGISTER_NODE_TYPE(IterMarkNode); |
54 | |
55 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
56 | .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) { |
57 | auto* op = static_cast<const IterMarkNode*>(node.get()); |
58 | p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")" ; |
59 | }); |
60 | |
61 | IterSplitExpr::IterSplitExpr(IterMark source) { |
62 | auto n = make_object<IterSplitExprNode>(); |
63 | auto one = make_const(source->source->dtype, 1); |
64 | n->dtype = source->source->dtype; |
65 | n->source = std::move(source); |
66 | n->extent = n->source->extent; |
67 | n->lower_factor = one; |
68 | n->scale = one; |
69 | data_ = std::move(n); |
70 | } |
71 | |
72 | IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { |
73 | auto n = make_object<IterSplitExprNode>(); |
74 | auto one = make_const(source->source->dtype, 1); |
75 | n->dtype = source->source->dtype; |
76 | n->source = std::move(source); |
77 | n->extent = n->source->extent; |
78 | n->lower_factor = one; |
79 | n->scale = std::move(scale); |
80 | data_ = std::move(n); |
81 | } |
82 | |
83 | IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, |
84 | PrimExpr scale) { |
85 | auto n = make_object<IterSplitExprNode>(); |
86 | n->dtype = source->source->dtype; |
87 | n->source = std::move(source); |
88 | n->lower_factor = std::move(lower_factor); |
89 | n->extent = std::move(extent); |
90 | n->scale = std::move(scale); |
91 | data_ = std::move(n); |
92 | } |
93 | |
94 | TVM_REGISTER_GLOBAL("arith.IterSplitExpr" ) |
95 | .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { |
96 | return IterSplitExpr(source, lower_factor, extent, scale); |
97 | }); |
98 | |
99 | TVM_REGISTER_NODE_TYPE(IterSplitExprNode); |
100 | |
101 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
102 | .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) { |
103 | auto* op = static_cast<const IterSplitExprNode*>(node.get()); |
104 | p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor |
105 | << ", extent=" << op->extent << ", scale=" << op->scale << ")" ; |
106 | }); |
107 | |
108 | IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) { |
109 | auto n = make_object<IterSumExprNode>(); |
110 | n->dtype = base->dtype; |
111 | n->args = std::move(args); |
112 | n->base = std::move(base); |
113 | data_ = std::move(n); |
114 | } |
115 | |
116 | TVM_REGISTER_GLOBAL("arith.IterSumExpr" ) |
117 | .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) { |
118 | return IterSumExpr(args, base); |
119 | }); |
120 | |
121 | TVM_REGISTER_NODE_TYPE(IterSumExprNode); |
122 | |
123 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
124 | .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) { |
125 | auto* op = static_cast<const IterSumExprNode*>(node.get()); |
126 | p->stream << "IterSum(" << op->args << ", " << op->base << ")" ; |
127 | }); |
128 | |
129 | /*! |
130 | * \brief Collector that collects the outgoing split reference of each IterMark. |
131 | * |
132 | * These out-going splits can then be used to check if the iterators are independent. |
133 | */ |
134 | class IterMarkSplitCollector { |
135 | public: |
136 | // mark all IterMarks that are visited. |
137 | std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_; |
138 | // each iter mark to its outgoing splits that are referenced. |
139 | std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual> |
140 | mark2splits_; |
141 | /*! |
142 | * \brief Collect all mark2splits recursively from indices. |
143 | * \param indices The iterator of interest. |
144 | */ |
145 | void Collect(const Array<IterSumExpr>& indices) { |
146 | for (IterSumExpr sum_expr : indices) { |
147 | for (IterSplitExpr split : sum_expr->args) { |
148 | this->CollectInternal(split->source); |
149 | mark2splits_[split->source].push_back(split); |
150 | } |
151 | } |
152 | } |
153 | |
154 | void CollectInternal(const IterMark& mark) { |
155 | if (visited_.count(mark)) return; |
156 | visited_.insert(mark); |
157 | if (auto* op = mark->source.as<IterSumExprNode>()) { |
158 | for (IterSplitExpr split : op->args) { |
159 | this->CollectInternal(split->source); |
160 | mark2splits_[split->source].push_back(split); |
161 | } |
162 | } |
163 | } |
164 | }; |
165 | |
166 | /*! \brief Record form of IterMark(x, extent) + offset */ |
167 | struct IterMarkWithOffset { |
168 | IterMark mark; |
169 | PrimExpr offset{0}; |
170 | IterMarkWithOffset() {} |
171 | IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {} |
172 | }; |
173 | |
174 | /*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */ |
175 | class IterMapRewriter : public ExprMutator { |
176 | public: |
177 | using Parent = ExprMutator; |
178 | |
179 | explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters, |
180 | IterMapLevel check_level, bool simplify_trivial_iterators, |
181 | Array<String>* errors) |
182 | : analyzer_(analyzer), |
183 | check_level_(check_level), |
184 | errors_(*errors), |
185 | padding_predicate_(const_false()) { |
186 | for (auto kv : input_iters) { |
187 | const Var& var = kv.first; |
188 | const Range& vrng = kv.second; |
189 | if (simplify_trivial_iterators && is_one(vrng->extent)) { |
190 | var_map_[var] = IterSumExpr({}, vrng->min); |
191 | } else if (is_zero(vrng->min)) { |
192 | IterMark mark(var, vrng->extent); |
193 | var_map_[var] = IterSplitExpr(mark); |
194 | input_marks_.push_back(mark); |
195 | } else { |
196 | IterMark mark(var - vrng->min, vrng->extent); |
197 | IterSumExpr sum_expr = ToIterSumExpr(IterSplitExpr(mark)); |
198 | sum_expr.CopyOnWrite()->base = vrng->min; |
199 | var_map_[var] = sum_expr; |
200 | input_marks_.push_back(mark); |
201 | } |
202 | } |
203 | } |
204 | |
205 | PrimExpr padding_predicate() const { return padding_predicate_; } |
206 | bool requires_padding() const { return requires_padding_; } |
207 | |
208 | IterSumExpr Rewrite(const PrimExpr& expr) { |
209 | return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); |
210 | } |
211 | |
212 | IterSumExpr RewriteAndUpdatePadding(const PrimExpr& expr) { |
213 | update_iterator_padding_ = true; |
214 | auto res = Rewrite(expr); |
215 | update_iterator_padding_ = false; |
216 | return res; |
217 | } |
218 | |
219 | IterSumExpr RewriteIterConstraint(const PrimExpr& expr, |
220 | const Optional<PrimExpr>& predicate_induced_min, |
221 | const Optional<PrimExpr>& predicate_induced_max) { |
222 | return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, |
223 | predicate_induced_max); |
224 | } |
225 | |
226 | /*! |
227 | * \brief If require bijective mapping, this function checks two conditions: |
228 | * - C0: Each iter mark should be fully covered by non-overlapping splits. |
229 | * - C1: All of the input iterators are used. |
230 | * Example: given x in [0, 8) y in [0, 6) |
231 | * - bindings = [x, x + 1, y] won't pass because x and x+1 contribute |
232 | * two splits that overlaps with each other. |
233 | * - bindings = [x / 4, x % 4, y] will pass because x / 4 and x % 4 |
234 | * contribute two non-overlapping splits that covers x. |
235 | * - bindings = [x / 4, x % 4] won't pass because y is not used. |
236 | * |
237 | * If only require surjective mapping, this function checks one condition: |
238 | * - C0: Each iter mark has a chance to be fully covered by non-overlapping splits. |
239 | * Example: given x in [0, 8) y in [0, 6) |
240 | * - bindings = [x / 4] will pass because x / 4 can be one split of x |
241 | * - bindings = [x / 4, x % 4] will pass because x / 4 and x % 4 |
242 | * contribute two non-overlapping splits that covers x. |
243 | * - bindings = [x / 3] will not pass because x / 3 can not be one split of x |
244 | * \return whether the bindings are valid |
245 | */ |
246 | bool CheckMapping(const Array<IterSumExpr>& bindings, IterMapLevel check_level) { |
247 | IterMarkSplitCollector collector; |
248 | // We can check that for each iter mark: |
249 | // All the splits that refers to the iter_mark covers its extent. |
250 | // The splits do not overlap with each other. |
251 | collector.Collect(bindings); |
252 | |
253 | for (const IterMark& mark : collector.visited_) { |
254 | if (TryNormalizeSplits(mark, collector.mark2splits_[mark], check_level).empty()) { |
255 | return false; |
256 | } |
257 | } |
258 | if (check_level == IterMapLevel::Bijective) { |
259 | // all input marks must be visited |
260 | for (const IterMark& mark : input_marks_) { |
261 | if (collector.visited_.count(mark) == 0 && !is_one(mark->extent)) { |
262 | return false; |
263 | } |
264 | } |
265 | } |
266 | return true; |
267 | } |
268 | |
269 | /*! |
270 | * \brief Check the validity of iterator constraints |
271 | * The flattened forms of two different iterator constraints |
272 | * either 1) follow inclusion relation or 2) have no intersection |
273 | * |
274 | * For Example, x = i0*30 + i1*15 + i2*3 + i3, |
275 | * 1) [i0*2 + i1 < 3, i2*3 + i3 < 5] is valid, since {i0, i1} \\intersect {i2, i3} = empty set. |
276 | * 2) [i0*2 + i1 < 3, i1*5 + i2 < 5] is not valid, |
277 | * since {i0, i1} \\intersect {i1, i2} = {i1}, i0 \\in {i0, i1}, i0 \\notin {i1, i2} |
278 | * \return whether the predicates are valid; |
279 | */ |
280 | bool CheckConstraints() const { |
281 | // the constrained_iters_flattened_ are in the order of shorter to longer |
282 | // since we visit the predicates in the order of size |
283 | for (size_t i = 0; i < constrained_iters_flattened_.size(); ++i) { |
284 | for (size_t j = i + 1; j < constrained_iters_flattened_.size(); ++j) { |
285 | // state: 0(start), -1(no intersection), 1(inclusion) |
286 | int state = 0; |
287 | for (const IterSplitExpr& arg1 : constrained_iters_flattened_[i]->args) { |
288 | bool found = false; |
289 | for (const IterSplitExpr& arg2 : constrained_iters_flattened_[j]->args) { |
290 | if (IterSplitEqual(arg1, arg2)) { |
291 | found = true; |
292 | break; |
293 | } |
294 | } |
295 | // Check either it is inclusion or intersection, but not both |
296 | if (state == 0) { |
297 | state = found ? 1 : -1; |
298 | } else if ((state == -1 && found) || (state == 1 && !found)) { |
299 | return false; |
300 | } |
301 | } |
302 | } |
303 | } |
304 | return true; |
305 | } |
306 | |
307 | // override the original mutate function. |
308 | PrimExpr VisitExpr(const PrimExpr& input_expr) final { |
309 | auto expr = ExprMutator::VisitExpr(input_expr); |
310 | if (expr->IsInstance<IterMapExprNode>()) { |
311 | ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " |
312 | << "IterMapRewriter using DirectMutate. " |
313 | << "Indirect return occurred in " << input_expr; |
314 | } |
315 | return expr; |
316 | } |
317 | |
318 | // Normal mutation without normalization. |
319 | PrimExpr DirectMutate(const PrimExpr& expr) { return ExprMutator::VisitExpr(expr); } |
320 | |
321 | PrimExpr VisitExpr_(const VarNode* op) final; |
322 | PrimExpr VisitExpr_(const AddNode* op) final; |
323 | PrimExpr VisitExpr_(const SubNode* op) final; |
324 | PrimExpr VisitExpr_(const MulNode* op) final; |
325 | PrimExpr VisitExpr_(const FloorDivNode* op) final; |
326 | PrimExpr VisitExpr_(const FloorModNode* op) final; |
327 | |
private:
/* \brief Preprocessing common to both FloorDiv and FloorMod
 *
 * \param dividend The dividend to be manipulated.
 * \param original_dividend NOTE(review): presumably the dividend expression as it was
 *        before being rewritten into an IterMapExpr — confirm against the definition.
 * \return The dividend in IterSumExpr form. NOTE(review): the exact normalization
 *         applied is in the out-of-class definition — confirm there.
 */
IterSumExpr PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend);

// Create an iterator that represents the expression (split+base), with
// padding such that the iterator's extents are evenly divisible by
// `divisor`.
//
// If iterators can have padding added through UpdatePadding, pad a
// dividend out to be evenly divisible. Otherwise, validate that the
// padding previously defined for the split using UpdatePadding can be
// used. If no such previous padding exists, return an empty
// IterMark.
//
// Returns a pair of IterSplit that represents (split+base) in a
// form that can be divided by divisors, and PrimExpr that
// represents the left padding applied to split.
std::pair<IterSplitExpr, PrimExpr> PadDividendToDivisor(IterSplitExpr split, PrimExpr base,
                                                        PrimExpr divisor);
350 | |
351 | friend struct ErrorLogger; |
352 | |
353 | /* \brief Utility class for logging errors. |
354 | * |
355 | * It is not an error for IterMapRewriter to receive an expression that |
356 | * cannot be represented as an IterSumExpr. In these cases, |
357 | * IterMapRewriter returns the unrepresentable portions of the TIR graph |
358 | * without modification. As a result, the usual ICHECK or LOG(FATAL) |
359 | * macros cannot be used. Instead, ErrorLogger(this) can be used to |
360 | * report an unrepresentable TIR graph, which may be used in error |
361 | * messages at the calling scope. |
362 | */ |
363 | class ErrorLogger { |
364 | public: |
365 | explicit ErrorLogger(IterMapRewriter* rewriter) : rewriter(rewriter) {} |
366 | ~ErrorLogger() { rewriter->errors_.push_back(os.str()); } |
367 | |
368 | template <typename T> |
369 | ErrorLogger& operator<<(T&& t) { |
370 | os << std::forward<T>(t); |
371 | return *this; |
372 | } |
373 | |
374 | private: |
375 | IterMapRewriter* rewriter; |
376 | std::ostringstream os; |
377 | }; |
378 | |
// Per-iter-mark padding bookkeeping; this is the value type of padded_iter_map_.
struct IterPaddingInfo {
  // GCD of padding factor collected during first pass
  PrimExpr padding_factor{1};

  // NOTE(review): presumably the amount of padding added before / after the
  // original range of the mark — confirm against PadDividendToDivisor.
  PrimExpr left_pad{0};
  PrimExpr right_pad{0};

  // Padded form of original iter mark
  IterMark padded;
};
389 | |
390 | // temp hash for de-duplication purposes. |
391 | struct IterSumHash { |
392 | size_t operator()(const IterSumExpr& value) const { |
393 | // for now only hash on source index. |
394 | size_t hash = value->args.size(); |
395 | for (const IterSplitExpr& arg : value->args) { |
396 | hash = support::HashCombine(hash, std::hash<const Object*>()(arg->source.get())); |
397 | } |
398 | return hash; |
399 | } |
400 | }; |
401 | |
402 | static bool IterSplitEqual(const IterSplitExpr& lhs, const IterSplitExpr& rhs, |
403 | bool check_scale = true) { |
404 | tir::ExprDeepEqual equal; |
405 | if (!lhs->source.same_as(rhs->source)) return false; |
406 | if (!equal(lhs->lower_factor, rhs->lower_factor)) return false; |
407 | if (check_scale && !equal(lhs->scale, rhs->scale)) return false; |
408 | if (!equal(lhs->extent, rhs->extent)) return false; |
409 | return true; |
410 | } |
411 | |
412 | struct IterSumEqual { |
413 | bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const { |
414 | tir::ExprDeepEqual equal; |
415 | if (lhs->args.size() != rhs->args.size()) return false; |
416 | if (!equal(lhs->base, rhs->base)) return false; |
417 | for (size_t i = 0; i < lhs->args.size(); ++i) { |
418 | if (!IterSplitEqual(lhs->args[i], rhs->args[i])) return false; |
419 | } |
420 | return true; |
421 | } |
422 | }; |
423 | |
// Internal analyzer
Analyzer* analyzer_;
// Iter map check level
IterMapLevel check_level_;
// Error messages for each unresolved expression. This is a reference to the
// caller-provided array passed to the constructor.
Array<String>& errors_;
// The var map: input iterator -> its IterMapExpr form (filled in the constructor).
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
// input iter marks (the marks created for the input iterators in the constructor)
std::vector<IterMark> input_marks_;

// Map from an iter mark to the padded iterator information for
// it. This is necessary for introducing the same padding in all
// usage of an input iterator. (e.g. (i-1) occurring in the
// expressions [(i-1)%8, ((i-1)//8)%4, (i-1)//32] should be
// left-padded by 31 for each occurrence.)
std::unordered_map<IterMark, IterPaddingInfo, StructuralHash, StructuralEqual> padded_iter_map_;

// Map from padded iter mark to its origin mark
std::unordered_map<IterMark, IterMark, StructuralHash, StructuralEqual> padded_origin_map_;

/* If update_iterator_padding_ is true, allow the extents of the IterMap to be
 * padded beyond the original iterators.
 *
 * For example, if update_iterator_padding_ is true, the expressions i//4 and
 * i%4, where i is on the range [0,18), would be represented as
 * IterSplit(i, lower_factor=4, extent=5) and IterSplit(i, extent=4).
 * This representation would be forbidden if update_iterator_padding_ is false,
 * because lower_factor=4 does not evenly divide the original extent of
 * 18.
 */
bool update_iterator_padding_{false};

/* A boolean expression that is true for any padding that has been
 * introduced, and false otherwise. If update_iterator_padding_ is false,
 * padding_predicate_ will always be false.
 *
 * Example: [i//4, i%4], i in range [0,16)
 * padding_predicate_ will be false
 *
 * Example: [i//4, i%4], i in range [0,18)
 * padding_predicate_ will be `(i//4 == 4) && (i%4 >= 2)`
 * (i is padded to [0,20); the padded points are i = 18, 19.)
 *
 * Example: [i//4, i%4], i in range [0,N)
 * padding_predicate_ will be `(N%4!=0) && (i//4 == (N+3)//4-1) && (i%4 >= N%4)`
 */
PrimExpr padding_predicate_;

/* A boolean flag denotes there are padding iterations detected
 * in the first round of indices rewriting.
 */
bool requires_padding_{false};

// The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
// an extra offset)
// Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
//          predicate: j*2 + k < 9
// Then,    flattened form = IterSum(IterSplit(i, scale=9),
//                                   IterSplit(j, scale=2),
//                                   IterSplit(k, scale=1))
//          normal form    = IterSum(IterSplit(i, scale=9),
//                                   IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
//                                                              IterSplit(k, scale=1)),
//                                                      extent=9)
//                                             scale=1))
// Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
//          predicate: 1 <= j*2 + k < 9
// Then,    flattened form = IterSum(IterSplit(i, scale=8),
//                                   IterSplit(j, scale=2),
//                                   IterSplit(k, scale=1))
//          normal form    = IterSum(IterSplit(i, scale=8),
//                                   IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
//                                                              IterSplit(k, scale=1), base=-1),
//                                                      extent=9-1)
//                                             scale=1),
//                                   base=1)
std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
// The map for sum that maps normal form to flattened form
std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
// The flattened forms of constrained iters (validated by CheckConstraints())
std::vector<IterSumExpr> constrained_iters_flattened_;
505 | |
506 | /*! |
507 | * \brief Look for a split in splits that is not used such that its lower_factor is smallest. |
508 | * Note that here we use division to compare lower_factor. |
509 | * \param splits the split array to search in. |
510 | * \param used the input used array. |
511 | * \param expected_lower_factor the skipped lower factor. |
512 | * \return the index of the expected split, split.size() if not found. |
513 | */ |
514 | size_t SearchSkipLowerFactor(const std::vector<IterSplitExpr>& splits, |
515 | const std::vector<bool>& used, |
516 | const PrimExpr& expected_lower_factor) { |
517 | size_t res = splits.size(); |
518 | for (size_t i = 0; i < splits.size(); ++i) { |
519 | if (used[i]) continue; |
520 | if (!used[i] && !CanProveDivisible(splits[i]->lower_factor, expected_lower_factor)) { |
521 | // all the remaining unused splits should have their lower factor divisible |
522 | return splits.size(); |
523 | } |
524 | if (res == splits.size() || |
525 | CanProveDivisible(splits[res]->lower_factor, splits[i]->lower_factor)) { |
526 | // note down the split with smaller lower factor |
527 | res = i; |
528 | } |
529 | } |
530 | return res; |
531 | } |
532 | |
533 | /*! |
534 | * \brief If bijective is required, verify that splits fully covers mark in a non-overlapping |
535 | * fashion, If not, verify that splits are valid and compatible for the mark. |
536 | * If verification passes, return splits from outermost to innermost order. |
537 | * If not, return an empty array. |
538 | * \param mark The iterator of interest. |
539 | * \param splits The splits to be verified. |
540 | * \param check_level Iteration mapping's check level. |
541 | * \return The normalized splits. |
542 | */ |
543 | Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark, |
544 | const std::vector<IterSplitExpr>& splits, |
545 | IterMapLevel check_level) { |
546 | std::vector<bool> used(splits.size(), false); |
547 | std::vector<IterSplitExpr> iters; |
548 | PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); |
549 | |
550 | for (size_t i = 0; i < splits.size(); ++i) { |
551 | size_t j = 0; |
552 | for (; j < splits.size(); ++j) { |
553 | if (used[j]) continue; |
554 | if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) { |
555 | break; |
556 | } |
557 | } |
558 | if (j == splits.size()) { |
559 | // we do not allow incomplete split if the bindings should be bijective |
560 | if (check_level == IterMapLevel::Bijective) { |
561 | return Array<IterSplitExpr>(); |
562 | } |
563 | // look for the next split skipping this lower factor |
564 | // For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2] |
565 | // It is valid to only have [y / 6, y % 2] if bijective is not required |
566 | // We can skip (y / 2) % 6 |
567 | j = SearchSkipLowerFactor(splits, used, expected_lower_factor); |
568 | // split not found |
569 | if (j == splits.size()) { |
570 | return Array<IterSplitExpr>(); |
571 | } |
572 | } |
573 | |
574 | used[j] = true; |
575 | iters.push_back(splits[j]); |
576 | expected_lower_factor = splits[j]->lower_factor * splits[j]->extent; |
577 | } |
578 | |
579 | // Extract iteration mark info before padding |
580 | auto pad_mark_it = padded_origin_map_.find(mark); |
581 | bool has_padding = pad_mark_it != padded_origin_map_.end(); |
582 | |
583 | bool match_full_iter = analyzer_->CanProveEqual(expected_lower_factor, mark->extent); |
584 | bool match_iter_divisor = |
585 | match_full_iter || CanProveDivisible(mark->extent, expected_lower_factor); |
586 | |
587 | // Case 1. bijective is required. |
588 | // We check the extent we calculate is consistent with the extent of the mark and |
589 | // iteration mark's padding is not allowed. |
590 | // |
591 | // Case 2. bijective is not required and there is no padding. |
592 | // We check the extent we calculate is a factor of the extent of the mark |
593 | // For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not. |
594 | // |
595 | // Case 3. bijective is not required and there exists padding. We check either |
596 | // (3.1) The extent we calculate is consistent with the extent of the padded mark and it is |
597 | // the single split for the iter mark. |
598 | // For example, padded iter p in [0, 24), [(p / 12)] is valid because it is surjective |
599 | // according to how we pad the original iteration mark. |
600 | // (3.2) The extent we calculate is a factor of the extent of the padded mark, and the extent |
601 | // before padding is greater or equal than the extent we calculate. |
602 | // For example, the original extent is 14, [(p % 12)] is valid, with p padded to 24. |
603 | // |
604 | if (check_level == IterMapLevel::Bijective) { |
605 | if (has_padding) { |
606 | ErrorLogger(this) << "Bijectvie mapping should not take iter paddings" ; |
607 | return {}; |
608 | } else if (!match_full_iter) { |
609 | ErrorLogger(this) << "The iterations do not traverse full iter space" ; |
610 | return {}; |
611 | } |
612 | } else if (!has_padding) { |
613 | if (!match_iter_divisor) { |
614 | ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent" ; |
615 | return {}; |
616 | } |
617 | } else if (check_level == IterMapLevel::Surjective) { |
618 | PrimExpr extent_before_padding = pad_mark_it->second->extent; |
619 | if (match_full_iter) { |
620 | if (splits.size() != 1) { |
621 | ErrorLogger(this) << "Dependent iterations on padding iter space" ; |
622 | return Array<IterSplitExpr>(); |
623 | } else if (analyzer_->CanProveEqual(splits[0]->extent, expected_lower_factor) && |
624 | !analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { |
625 | ErrorLogger(this) << "Split on padding iteration is not surjective " |
626 | << "if the split extent equals to the full iter space extent" ; |
627 | return Array<IterSplitExpr>(); |
628 | } |
629 | } else if (match_iter_divisor) { |
630 | if (!analyzer_->CanProve(extent_before_padding >= expected_lower_factor)) { |
631 | ErrorLogger(this) << "The extent before padding is less than lower factor" ; |
632 | return Array<IterSplitExpr>(); |
633 | } |
634 | } else { |
635 | ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent" ; |
636 | return {}; |
637 | } |
638 | } |
639 | return Array<IterSplitExpr>(iters.rbegin(), iters.rend()); |
640 | } |
641 | |
/*!
 * \brief Normalize the iter expression with constraint (min <= expr < max)
 *
 * The expression is first shifted to a zero base (the bounds are shifted by the
 * same amount) and then fused with TryFuseIters. On success with a unit-scale
 * result, the fused mark recorded in sum_fuse_map_ is narrowed to the
 * intersection of its current range and the predicate-induced range,
 * flattened_map_ is kept in sync with the re-based structured form, and the
 * flattened form is appended to constrained_iters_flattened_ for a later
 * CheckConstraints(). Otherwise an error is logged and the (zero-based) expr
 * is returned.
 *
 * \param expr The iter expression.
 * \param predicate_induced_min Closed lower bound from iter constraint, maybe undefined.
 * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined.
 * \return The Normalized expression.
 */
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
                                       Optional<PrimExpr> predicate_induced_max) {
  // normalize to zero base
  PrimExpr base = expr->base;
  if (!is_zero(base)) {
    expr.CopyOnWrite()->base = 0;
    if (predicate_induced_min.defined())
      predicate_induced_min = predicate_induced_min.value() - base;
    if (predicate_induced_max.defined())
      predicate_induced_max = predicate_induced_max.value() - base;
  }
  Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
  // A successful fusion yields exactly one split.
  ICHECK(!opt.defined() || opt.value()->args.size() == 1);
  // scale should be 1
  if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
    // NOTE(review): `split` is captured before `mark` is copied-on-write below, so the
    // returned expression may still reference the pre-update mark — verify whether
    // callers use the return value or only the updated maps.
    const IterSplitExpr split = opt.value()->args[0];
    IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
    // get the flattened form
    auto it = flattened_map_.find(structured_form);
    ICHECK(it != flattened_map_.end());
    IterSumExpr flattened_form = it->second;
    // get the mark and offset of the structured_form
    auto it_mark = sum_fuse_map_.find(flattened_form);
    ICHECK(it_mark != sum_fuse_map_.end());
    IterMark mark = it_mark->second.mark;
    PrimExpr mark_offset = it_mark->second.offset;
    // Current range of the fused iterator: [offset, offset + extent).
    PrimExpr iter_min = mark_offset;
    PrimExpr iter_max = iter_min + mark->extent;
    // Intersect with the predicate-induced bounds when present.
    if (predicate_induced_min.defined()) {
      iter_min = max(predicate_induced_min.value(), iter_min);
    }
    if (predicate_induced_max.defined()) {
      iter_max = min(predicate_induced_max.value(), iter_max);
    }
    if (!is_zero(iter_min)) {
      // structured form's offset should be updated
      // (the key is erased before CopyOnWrite and re-inserted afterwards, so the
      // map entry never holds a key that is being mutated)
      flattened_map_.erase(structured_form);
      structured_form.CopyOnWrite()->base = -iter_min;
      mark.CopyOnWrite()->source = structured_form;
      flattened_map_[structured_form] = flattened_form;
    }
    mark.CopyOnWrite()->extent = iter_max - iter_min;
    sum_fuse_map_[flattened_form] = {mark, iter_min};
    // we need to note down the flattened form of constrained iterators
    // to check the validity of constraints, see also CheckConstraints()
    constrained_iters_flattened_.push_back(flattened_form);
    expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
    // Restore the original base and add the new lower bound.
    expr.CopyOnWrite()->base = base + iter_min;
    return expr;
  }
  ErrorLogger(this) << "Could not normalize iterators using the constraints given.";
  return expr;
}
702 | |
703 | /*! |
704 | * \brief Normalize expr to an iterator + offset. |
705 | * \param expr The input expression. |
706 | * \return The Normalized expression. |
707 | */ |
708 | IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { |
709 | // We are normalizing a regular iter |
710 | if (expr->args.size() < 1) return expr; |
711 | Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_); |
712 | if (opt.defined()) { |
713 | return opt.value(); |
714 | } else { |
715 | ErrorLogger(this) << "Could not normalize iterators" ; |
716 | return expr; |
717 | } |
718 | } |
719 | |
720 | /*! |
721 | * \brief Create a IterSumExpr from expr. |
722 | * \param expr The input expr. |
723 | * \return The transformed IterSumExpr. |
724 | */ |
725 | static IterSumExpr ToIterSumExpr(const PrimExpr& expr) { |
726 | if (const auto* op = expr.as<IterSumExprNode>()) { |
727 | return GetRef<IterSumExpr>(op); |
728 | } else if (const auto* op = expr.as<IterSplitExprNode>()) { |
729 | return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype)); |
730 | } else { |
731 | ICHECK(!expr->IsInstance<IterMapExprNode>()); |
732 | return IterSumExpr({}, expr); |
733 | } |
734 | } |
735 | |
736 | /*! |
737 | * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base |
738 | * = (x1*s1 + x2*s2 + ... + xn)*cn + base |
739 | * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base |
740 | * = [IterSplit(IterMark(y), scale=cn)] + base |
741 | * return a corresponding IterSumExpr with extra offset if needed. |
742 | * Try to normalize IterSum into a fused IterMark |
743 | * \param expr The input sum. |
744 | * \param check_level The check level if iter mapping. |
745 | * \return The sum with the fused IterMark and extra offset if succeed. |
746 | */ |
747 | Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { |
748 | // select the iterators in order |
749 | std::vector<bool> visited(expr->args.size(), false); |
750 | std::vector<IterSplitExpr> flattened_iters, grouped_iters; |
751 | // canonicalize the expression into two different forms: flattened form and structured form |
752 | // step0. check if find the base scale first |
753 | Optional<IntImm> base_scale = NullOpt; |
754 | size_t base_index = 0; |
755 | for (size_t i = 0; i < expr->args.size(); ++i) { |
756 | if (const auto* op = expr->args[i]->scale.as<IntImmNode>()) { |
757 | if (!base_scale || op->value < base_scale.value()->value) { |
758 | base_scale = GetRef<IntImm>(op); |
759 | base_index = i; |
760 | } |
761 | } |
762 | } |
763 | if (!base_scale) { |
764 | return NullOpt; |
765 | } |
766 | // check if it can be remapped into a fused pattern. |
767 | PrimExpr = 0; |
768 | PrimExpr tail_extent = 0; |
769 | PrimExpr expected_scale = base_scale.value(); |
770 | for (size_t i = 0; i < expr->args.size();) { |
771 | // find position such that expr->args[j] match expected scale |
772 | int j = i == 0 ? base_index : expr->args.size() - 1; |
773 | |
774 | size_t matched_pos = expr->args.size(); |
775 | PrimExpr matched_scale{nullptr}; |
776 | bool is_exact_match{false}; |
777 | |
778 | for (; j >= 0; --j) { |
779 | if (visited[j]) { |
780 | continue; |
781 | } |
782 | const PrimExpr& cur_scale = expr->args[j]->scale; |
783 | |
784 | // for bijective mapping, the matched scale must equal to expected scale |
785 | if (analyzer_->CanProveEqual(cur_scale, expected_scale)) { |
786 | matched_pos = j; |
787 | matched_scale = cur_scale; |
788 | is_exact_match = true; |
789 | break; |
790 | } |
791 | if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) { |
792 | // find the closest scale which is less or equal to expected scale |
793 | if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) && |
794 | analyzer_->CanProveGreaterEqual(cur_scale, 0)) { |
795 | if (matched_pos == expr->args.size() || |
796 | analyzer_->CanProveLess(matched_scale - cur_scale, 0)) { |
797 | matched_pos = j; |
798 | matched_scale = cur_scale; |
799 | } |
800 | } |
801 | } |
802 | } |
803 | if (matched_pos == expr->args.size()) { |
804 | return NullOpt; |
805 | } |
806 | // look for the longest constrained iter started from expr->args[j] |
807 | // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) |
808 | // predicate: j*2 + k < 9 |
809 | // We need to match the predicate in expr and adjust the expected scale, |
810 | // otherwise we expect the scale of i to be 2*5=10 |
811 | Optional<IterSumExpr> constraint_to_match; |
812 | for (const IterSumExpr& iter : constrained_iters_flattened_) { |
813 | if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) { |
814 | // find a predicate started from match position |
815 | if (!constraint_to_match || |
816 | constraint_to_match.value()->args.size() < iter->args.size()) { |
817 | constraint_to_match = iter; |
818 | } |
819 | } |
820 | } |
821 | if (constraint_to_match) { |
822 | // match the predicate and mark the iterators in the constraint_to_match as visited |
823 | // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2) |
824 | // predicate = j*2 + k < 9 |
825 | // then j*2 + k matches the lower two splits of expr |
826 | for (auto it = constraint_to_match.value()->args.rbegin(); |
827 | it != constraint_to_match.value()->args.rend(); ++it) { |
828 | size_t k = 0; |
829 | for (; k < expr->args.size(); ++k) { |
830 | if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) { |
831 | if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale)) |
832 | break; |
833 | } |
834 | } |
835 | if (k == expr->args.size()) { |
836 | return NullOpt; |
837 | } |
838 | visited[k] = true; |
839 | flattened_iters.push_back(expr->args[k]); |
840 | } |
841 | auto iter = sum_fuse_map_.find(constraint_to_match.value()); |
842 | ICHECK(iter != sum_fuse_map_.end()); |
843 | const IterMarkWithOffset& iter_matched = iter->second; |
844 | grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value())); |
845 | expected_extra_base += iter_matched.offset * matched_scale; |
846 | if (!is_exact_match) { |
847 | tail_extent += expected_scale - matched_scale; |
848 | } |
849 | expected_scale = matched_scale * iter_matched.mark->extent; |
850 | // move forward |
851 | i += constraint_to_match.value()->args.size(); |
852 | } else { |
853 | // constraint_to_match not found, skip this iterator |
854 | visited[matched_pos] = true; |
855 | IterSplitExpr arg = expr->args[matched_pos]; |
856 | arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value())); |
857 | flattened_iters.push_back(arg); |
858 | grouped_iters.push_back(arg); |
859 | if (!is_exact_match) { |
860 | tail_extent += expected_scale - matched_scale; |
861 | } |
862 | expected_scale = matched_scale * expr->args[matched_pos]->extent; |
863 | ++i; |
864 | } |
865 | } |
866 | // Get the flattened form and structured form |
867 | // both forms have splits from outermost to innermost |
868 | IterSumExpr structured_form = expr, flattened_form = expr; |
869 | flattened_form.CopyOnWrite()->args = |
870 | Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend()); |
871 | flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); |
872 | structured_form.CopyOnWrite()->args = |
873 | Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend()); |
874 | structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); |
875 | auto it = sum_fuse_map_.find(flattened_form); |
876 | if (it != sum_fuse_map_.end()) { |
877 | // old iter |
878 | if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) { |
879 | // the extra offset is not consistent with old |
880 | return NullOpt; |
881 | } |
882 | return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())}, |
883 | expr->base + expected_extra_base); |
884 | } else { |
885 | // new iter, form a new mark |
886 | IterMark mark = |
887 | IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent); |
888 | sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0); |
889 | flattened_map_[structured_form] = flattened_form; |
890 | return IterSumExpr({IterSplitExpr(mark, base_scale.value())}, |
891 | expr->base + expected_extra_base); |
892 | } |
893 | } |
894 | |
  /*! \brief Whether `lhs` can be proven to be divisible by `rhs`. */
  bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs);

  /*!
   * \brief Rewrite `(lhs + base) // rhs` as an iterator expression.
   * May return an undefined PrimExpr (after logging an error) when it cannot be represented.
   */
  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
  /*! \brief Presumably the floormod counterpart of SplitFloorDivConst — TODO confirm. */
  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);
899 | |
900 | static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { |
901 | tir::ExprDeepEqual equal; |
902 | for (size_t i = 0; i < lhs->args.size(); ++i) { |
903 | IterSplitExpr lvalue = lhs->args[i]; |
904 | if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) && |
905 | equal(lvalue->extent, rhs->extent)) { |
906 | if (sign > 0) { |
907 | rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale; |
908 | } else { |
909 | rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale; |
910 | } |
911 | lhs->args.Set(i, rhs); |
912 | return; |
913 | } |
914 | } |
915 | if (sign > 0) { |
916 | lhs->args.push_back(rhs); |
917 | } else { |
918 | rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale; |
919 | lhs->args.push_back(rhs); |
920 | } |
921 | } |
922 | |
923 | static void AddToLhs(IterSumExprNode* lhs, const IterSumExpr& rhs, int sign) { |
924 | for (const auto& arg : rhs->args) { |
925 | AddToLhs(lhs, arg, sign); |
926 | } |
927 | if (sign > 0) { |
928 | lhs->base += rhs->base; |
929 | } else { |
930 | lhs->base -= rhs->base; |
931 | } |
932 | } |
933 | |
934 | static void MulToLhs(IterSumExprNode* lhs, const PrimExpr& rhs) { |
935 | for (size_t i = 0; i < lhs->args.size(); ++i) { |
936 | IterSplitExpr lvalue = lhs->args[i]; |
937 | lvalue.CopyOnWrite()->scale *= rhs; |
938 | lhs->args.Set(i, lvalue); |
939 | } |
940 | lhs->base *= rhs; |
941 | } |
942 | }; |
943 | |
944 | /*! \brief An internal struct to represent range extent on iterators(iter < upper_bound). */ |
945 | struct IterConstraint { |
946 | // The expr of the iter |
947 | PrimExpr iter; |
948 | // The expr of the lower_bound, maybe undefined |
949 | Optional<PrimExpr> lower_bound; |
950 | // The expr of the upper_bound, maybe undefined |
951 | Optional<PrimExpr> upper_bound; |
952 | // The size of the iter, which is the number of nodes |
953 | size_t expr_size = 0; |
954 | |
955 | IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound, |
956 | size_t size) |
957 | : iter(std::move(iter)), |
958 | lower_bound(std::move(lower_bound)), |
959 | upper_bound(std::move(upper_bound)), |
960 | expr_size(size) {} |
961 | }; |
962 | |
963 | /*! |
964 | * \brief Split the predicate into `(a < b) && (c < d) && ...` |
965 | * \param pred The predicate to be split. |
966 | * \param input_iters The input iterators. |
967 | * \param result The result of predicate split. |
968 | * \return A list of IterConstraint, empty if the split failed. |
969 | */ |
970 | bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters, |
971 | std::vector<IterConstraint>* result) { |
972 | arith::PVar<PrimExpr> lhs, rhs, rest; |
973 | for (;;) { |
974 | // try extract comparisions |
975 | bool is_finish = false; |
976 | bool is_greater = false; |
977 | bool is_equal = false; |
978 | if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) { |
979 | // pass |
980 | } else if ((lhs < rhs).Match(pred)) { |
981 | is_finish = true; |
982 | } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) { |
983 | is_equal = true; |
984 | } else if ((lhs <= rhs).Match(pred)) { |
985 | is_equal = true; |
986 | is_finish = true; |
987 | } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) { |
988 | is_greater = true; |
989 | } else if ((lhs > rhs).Match(pred)) { |
990 | is_greater = true; |
991 | is_finish = true; |
992 | } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) { |
993 | is_greater = true; |
994 | is_equal = true; |
995 | } else if ((lhs >= rhs).Match(pred)) { |
996 | is_greater = true; |
997 | is_equal = true; |
998 | is_finish = true; |
999 | } else { |
1000 | return false; |
1001 | } |
1002 | PrimExpr lhs_expr = lhs.Eval(); |
1003 | PrimExpr rhs_expr = rhs.Eval(); |
1004 | // we only accept predicate of integers |
1005 | if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && |
1006 | (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { |
1007 | return false; |
1008 | } |
1009 | // determine iter and bound, if we can not distinguish them simply, |
1010 | // try divide (lhs - rhs) into itervar aware and itervar free parts |
1011 | auto f_use_itervar = [&input_iters](const VarNode* v) { |
1012 | return input_iters->count(GetRef<Var>(v)); |
1013 | }; |
1014 | bool bound_at_left; |
1015 | if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) { |
1016 | // At least it uses one input iter |
1017 | if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) { |
1018 | bound_at_left = true; |
1019 | } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) { |
1020 | bound_at_left = false; |
1021 | } else { |
1022 | bound_at_left = false; // accumulate bound to rhs |
1023 | PrimExpr sum_parts = lhs_expr - rhs_expr; |
1024 | lhs_expr = 0; |
1025 | rhs_expr = 0; |
1026 | std::function<void(const PrimExpr&, bool)> = |
1027 | [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) { |
1028 | if (const AddNode* add = part.as<AddNode>()) { |
1029 | f_extract(add->a, sign); |
1030 | f_extract(add->b, sign); |
1031 | } else if (const SubNode* sub = part.as<SubNode>()) { |
1032 | f_extract(sub->a, sign); |
1033 | f_extract(sub->b, !sign); |
1034 | } else if (UsesVar(part, f_use_itervar)) { |
1035 | lhs_expr = sign ? lhs_expr + part : lhs_expr - part; |
1036 | } else { |
1037 | rhs_expr = sign ? rhs_expr - part : rhs_expr + part; |
1038 | } |
1039 | }; |
1040 | f_extract(sum_parts, true); |
1041 | arith::Analyzer analyzer; |
1042 | lhs_expr = analyzer.Simplify(lhs_expr); |
1043 | rhs_expr = analyzer.Simplify(rhs_expr); |
1044 | } |
1045 | Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt; |
1046 | PrimExpr iter; |
1047 | if (is_greater) { |
1048 | if (bound_at_left) { |
1049 | // bound > iter / bound >= iter |
1050 | upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; |
1051 | iter = rhs_expr; |
1052 | } else { |
1053 | // iter > bound / iter >= bound |
1054 | lower_bound = is_equal ? rhs_expr : rhs_expr + 1; |
1055 | iter = lhs_expr; |
1056 | } |
1057 | } else { |
1058 | if (bound_at_left) { |
1059 | // bound < iter / bound <= iter |
1060 | lower_bound = is_equal ? lhs_expr : lhs_expr + 1; |
1061 | iter = rhs_expr; |
1062 | } else { |
1063 | // iter < bound / iter <= bound |
1064 | upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; |
1065 | iter = lhs_expr; |
1066 | } |
1067 | } |
1068 | // If it is a predicate for a single input iter |
1069 | if (const auto* var_ptr = iter.as<VarNode>()) { |
1070 | auto it = input_iters->find(GetRef<Var>(var_ptr)); |
1071 | if (it != input_iters->end()) { |
1072 | PrimExpr iter_min = (*it).second->min; |
1073 | PrimExpr iter_max = (*it).second->min + (*it).second->extent; |
1074 | if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); |
1075 | if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); |
1076 | input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max)); |
1077 | } |
1078 | } else { |
1079 | result->emplace_back(iter, lower_bound, upper_bound, 0); |
1080 | } |
1081 | } |
1082 | if (is_finish) { |
1083 | break; |
1084 | } |
1085 | pred = rest.Eval(); |
1086 | } |
1087 | return true; |
1088 | } |
1089 | |
1090 | bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) { |
1091 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> iters; |
1092 | for (const auto& it : iter_ranges) iters.insert(it.first); |
1093 | auto f = [&](const VarNode* var) { return iters.count(GetRef<Var>(var)); }; |
1094 | for (const auto& it : iter_ranges) { |
1095 | if (UsesVar(it.second->min, f) || UsesVar(it.second->extent, f)) return false; |
1096 | } |
1097 | return true; |
1098 | } |
1099 | |
1100 | IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, |
1101 | const PrimExpr& predicate, IterMapLevel check_level, |
1102 | arith::Analyzer* analyzer, bool simplify_trivial_iterators) { |
1103 | IterMapResult result; |
1104 | |
1105 | // Overall detection algorithm is divided into two steps: |
1106 | // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. |
1107 | // - Step1: IterIndependenceChecker checks if the iterator are independent. |
1108 | if (!IterRangeSanityCheck(input_iters)) { |
1109 | result->errors.push_back("Invalid iterators. Iterators may not be expressions of each other." ); |
1110 | return result; |
1111 | } |
1112 | Map<Var, Range> constrained_input_iters = input_iters; |
1113 | std::vector<IterConstraint> constraints; |
1114 | if (!is_one(predicate) && |
1115 | !MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) { |
1116 | result->errors.push_back("Could not parse predicate as constraints on the input iterators." ); |
1117 | return result; |
1118 | } |
1119 | // We have to make sure when we visit an iterator, all the constraints related with its successors |
1120 | // in the iter var graph has been visited, where the expression of this iterator will contain the |
1121 | // expression of its successor, so we sort them by their sizes. |
1122 | for (IterConstraint& constraint : constraints) { |
1123 | constraint.expr_size = CalculateExprComplexity(constraint.iter); |
1124 | } |
1125 | |
1126 | std::sort( |
1127 | constraints.begin(), constraints.end(), |
1128 | [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); |
1129 | |
1130 | IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level, |
1131 | simplify_trivial_iterators, &result->errors); |
1132 | // Step0.0: rewrite constraints in the order from size-small ones to size-big ones |
1133 | for (const IterConstraint& constraint : constraints) { |
1134 | auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, |
1135 | constraint.upper_bound); |
1136 | if (result->errors.size() > 0) { |
1137 | return result; |
1138 | } |
1139 | } |
1140 | if (!rewriter.CheckConstraints()) { |
1141 | result->errors.push_back("Invalid constraints." ); |
1142 | return result; |
1143 | } |
1144 | |
1145 | // Step0.1: Rewrite indicies and determine required padding, |
1146 | // if there is no padding, it should be the final result. |
1147 | Array<IterSumExpr> rewrite_indices; |
1148 | rewrite_indices.reserve(indices.size()); |
1149 | bool allow_padding = check_level != IterMapLevel::Bijective; |
1150 | if (allow_padding) { |
1151 | for (PrimExpr value : indices) { |
1152 | rewrite_indices.push_back(rewriter.RewriteAndUpdatePadding(value)); |
1153 | if (result->errors.size() > 0) { |
1154 | return result; |
1155 | } |
1156 | } |
1157 | } |
1158 | |
1159 | // Step0.2: Rewrite indices in the second round. |
1160 | if (!allow_padding || rewriter.requires_padding()) { |
1161 | rewrite_indices.clear(); |
1162 | for (PrimExpr value : indices) { |
1163 | rewrite_indices.push_back(rewriter.Rewrite(value)); |
1164 | if (result->errors.size() > 0) { |
1165 | return result; |
1166 | } |
1167 | } |
1168 | } |
1169 | result->padding_predicate = rewriter.padding_predicate(); |
1170 | |
1171 | // Step1: IterIndependenceChecker checks if the iterator are independent. |
1172 | if (!rewriter.CheckMapping(rewrite_indices, check_level)) { |
1173 | if (check_level == IterMapLevel::Bijective) { |
1174 | result->errors.push_back("Index mapping does not form a bijective transform." ); |
1175 | } else { |
1176 | result->errors.push_back("Mapped indices are not independent." ); |
1177 | } |
1178 | return result; |
1179 | } |
1180 | result->indices = rewrite_indices; |
1181 | return result; |
1182 | } |
1183 | |
1184 | TVM_REGISTER_GLOBAL("arith.DetectIterMap" ) |
1185 | .set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, |
1186 | const PrimExpr& input_pred, int check_level, |
1187 | bool simplify_trivial_iterators) { |
1188 | arith::Analyzer ana; |
1189 | return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, |
1190 | simplify_trivial_iterators); |
1191 | }); |
1192 | |
1193 | PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { |
1194 | auto var = GetRef<Var>(op); |
1195 | auto it = var_map_.find(var); |
1196 | if (it != var_map_.end()) return it->second; |
1197 | return std::move(var); |
1198 | } |
1199 | |
1200 | PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) { |
1201 | if (!IsIndexType(op->dtype)) { |
1202 | return Parent::VisitExpr_(op); |
1203 | } |
1204 | PrimExpr a = this->DirectMutate(op->a); |
1205 | PrimExpr b = this->DirectMutate(op->b); |
1206 | |
1207 | // const folding |
1208 | if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value(); |
1209 | // does not contain iter map. |
1210 | if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) { |
1211 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1212 | return GetRef<PrimExpr>(op); |
1213 | } else { |
1214 | return Add(a, b); |
1215 | } |
1216 | } |
1217 | |
1218 | // canonical form simplification. |
1219 | IterSumExpr ret = ToIterSumExpr(a); |
1220 | |
1221 | if (!b->IsInstance<IterMapExprNode>()) { |
1222 | ret.CopyOnWrite()->base += b; |
1223 | } else if (const auto* op = b.as<IterSumExprNode>()) { |
1224 | AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), 1); |
1225 | } else if (const auto* op = b.as<IterSplitExprNode>()) { |
1226 | AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), 1); |
1227 | } else { |
1228 | AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), 1); |
1229 | } |
1230 | return std::move(ret); |
1231 | } |
1232 | |
1233 | PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) { |
1234 | if (!IsIndexType(op->dtype)) { |
1235 | return Parent::VisitExpr_(op); |
1236 | } |
1237 | |
1238 | PrimExpr a = this->DirectMutate(op->a); |
1239 | PrimExpr b = this->DirectMutate(op->b); |
1240 | |
1241 | // const folding |
1242 | if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value(); |
1243 | |
1244 | // does not contain iter map. |
1245 | if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) { |
1246 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1247 | return GetRef<PrimExpr>(op); |
1248 | } else { |
1249 | return Sub(a, b); |
1250 | } |
1251 | } |
1252 | |
1253 | // canonical form simplification. |
1254 | IterSumExpr ret = ToIterSumExpr(a); |
1255 | |
1256 | if (!b->IsInstance<IterMapExprNode>()) { |
1257 | ret.CopyOnWrite()->base -= b; |
1258 | } else if (const auto* op = b.as<IterSumExprNode>()) { |
1259 | AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(op), -1); |
1260 | } else if (const auto* op = b.as<IterSplitExprNode>()) { |
1261 | AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(op), -1); |
1262 | } else { |
1263 | AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1); |
1264 | } |
1265 | return std::move(ret); |
1266 | } |
1267 | |
1268 | PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { |
1269 | if (!IsIndexType(op->dtype)) { |
1270 | return Parent::VisitExpr_(op); |
1271 | } |
1272 | // normalize |
1273 | PrimExpr a = this->DirectMutate(op->a); |
1274 | PrimExpr b = this->DirectMutate(op->b); |
1275 | |
1276 | // const folding |
1277 | if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value(); |
1278 | |
1279 | // does not contain iter map. |
1280 | if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) { |
1281 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1282 | return GetRef<PrimExpr>(op); |
1283 | } else { |
1284 | return Mul(a, b); |
1285 | } |
1286 | } |
1287 | |
1288 | if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) { |
1289 | // cannot multiply two iterators, mark as unresolved. |
1290 | ErrorLogger(this) << "Product of two iterators cannot be represented as an IterMap, " |
1291 | << "occurs in " << GetRef<Mul>(op); |
1292 | return GetRef<PrimExpr>(op); |
1293 | } |
1294 | |
1295 | if (!a->IsInstance<IterMapExprNode>()) { |
1296 | std::swap(a, b); |
1297 | } |
1298 | |
1299 | if (a->IsInstance<IterSumExprNode>()) { |
1300 | IterSumExpr ret = Downcast<IterSumExpr>(std::move(a)); |
1301 | MulToLhs(ret.CopyOnWrite(), b); |
1302 | return std::move(ret); |
1303 | } else { |
1304 | ICHECK(a->IsInstance<IterSplitExprNode>()); |
1305 | IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a)); |
1306 | ret.CopyOnWrite()->scale *= b; |
1307 | return std::move(ret); |
1308 | } |
1309 | } |
1310 | |
1311 | IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { |
1312 | if (dividend->IsInstance<IterSplitExprNode>()) { |
1313 | auto split = Downcast<IterSplitExpr>(dividend); |
1314 | return IterSumExpr({split}, make_zero(split.dtype())); |
1315 | } else if (dividend->IsInstance<IterSumExprNode>()) { |
1316 | auto sum = Downcast<IterSumExpr>(dividend); |
1317 | if (sum->args.empty()) { |
1318 | return IterSumExpr(); |
1319 | } else if (sum->args.size() == 1) { |
1320 | return sum; |
1321 | } |
1322 | auto opt_fused = TryFuseIters(sum, check_level_); |
1323 | if (!opt_fused) { |
1324 | ErrorLogger(this) << "Dividend " << original_dividend |
1325 | << ", can't be written as a single fused IterSum" ; |
1326 | return IterSumExpr(); |
1327 | } |
1328 | IterSumExpr fused = opt_fused.value(); |
1329 | ICHECK_EQ(fused->args.size(), 1U); |
1330 | return fused; |
1331 | } else { |
1332 | LOG(FATAL) << "Unsupported subclass of IterMarkExpr" ; |
1333 | } |
1334 | } |
1335 | |
1336 | /*! \brief Find approximate least common multiplier. */ |
1337 | PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { |
1338 | auto fsplit = [](const PrimExpr& e) -> std::pair<PrimExpr, int64_t> { |
1339 | if (const IntImmNode* imm = e.as<IntImmNode>()) { |
1340 | return {1, imm->value}; |
1341 | } |
1342 | PVar<PrimExpr> pv; |
1343 | PVar<IntImm> pc; |
1344 | if ((pv * pc).Match(e) || (pc * pv).Match(e)) { |
1345 | return {pv.Eval(), pc.Eval()->value}; |
1346 | } else { |
1347 | return {e, 1}; |
1348 | } |
1349 | }; |
1350 | auto p1 = fsplit(a); |
1351 | auto p2 = fsplit(b); |
1352 | auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second)); |
1353 | if (analyzer->CanProveEqual(p1.first, p2.first)) { |
1354 | return p1.first * const_lcm; |
1355 | } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { |
1356 | return p1.first * const_lcm; |
1357 | } else if (analyzer->CanProveEqual(floormod(p2.first, p1.first), 0)) { |
1358 | return p2.first * const_lcm; |
1359 | } else { |
1360 | return (p1.first * p2.first) * const_lcm; |
1361 | } |
1362 | } |
1363 | |
/*!
 * \brief Pad the dividend of `(split + base) // divisor` or `(split + base) % divisor` so that
 *        its range is aligned to `divisor`.
 *
 * Behaves in one of two passes, selected by `update_iterator_padding_`:
 *  - First pass: for the split's source IterMark, accumulate in `padded_iter_map_` an
 *    approximate-LCM padding factor and the maximum left padding; set `requires_padding_` when
 *    this split needs padding; return the split with a padded extent.
 *  - Second pass: replace the split's source with the padded IterMark (built once per mark) and
 *    OR a predicate identifying the padded positions into `padding_predicate_`.
 * \param split The split being divided.
 * \param base The offset added to the split.
 * \param divisor The divisor.
 * \return The (possibly padded) split and the left padding in units of the split. On
 *         incompatible left padding, an error is logged and a pair of undefined values returned.
 */
std::pair<IterSplitExpr, PrimExpr> IterMapRewriter::PadDividendToDivisor(IterSplitExpr split,
                                                                         PrimExpr base,
                                                                         PrimExpr divisor) {
  // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor
  // If FloorMod: (((source//lower_factor) % extent) + base) % divisor

  // First, adding any padding that is on the lower side of a
  // FloorDiv/FloorMod, such that floormod(split - left_pad, divisor) == 0
  // when iter == 0.
  PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor));

  // Next, adding any padding that is on the upper side of a
  // FloorDiv/FloorMod, such that floormod(left_pad + split + right_pad, divisor) == 0
  // when iter == extent.
  PrimExpr right_edge = left_pad + split->extent;
  PrimExpr right_pad;
  if (CanProveDivisible(right_edge, divisor)) {
    right_pad = 0;
  } else {
    right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
  }

  const IterMark& mark = split->source;
  if (update_iterator_padding_) {
    // In the first pass, the primary goal is to collect all the divisors
    // that may be used for padding. These will impact the divisor used
    // to determine padding in the second pass. We try to add padding to
    // the split's source iteration mark so that all splits under the same mark
    // share the same padded source iteration.
    auto& info = padded_iter_map_[mark];
    info.padding_factor =
        ApproxLeastCommonMultiple(info.padding_factor, divisor * split->lower_factor, analyzer_);

    // If the split itself require no padding, return directly.
    if (is_zero(left_pad) && is_zero(right_pad)) {
      return {split, 0};
    }

    // Update padding requirement on the lower side of the source iter mark.
    // In the second pass, all splits would check whether the maximum left padding
    // on the iter mark is compatible with its own left padding.
    // mark_left_pad is the split's left_pad converted to units of the source mark.
    requires_padding_ = true;
    PrimExpr mark_left_pad = left_pad * split->lower_factor;
    info.left_pad = max(info.left_pad, mark_left_pad);

    // Since we only care about the extent in the first pass's result,
    // we just create a result with a compatible padded extent, ignoring
    // possible relations between different padded iters.
    PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + right_pad);
    split.CopyOnWrite()->extent = padded_extent;
    return {split, left_pad};
  }

  // In the second pass, update the iteration mark to its padded form
  auto it = padded_iter_map_.find(mark);
  if (it == padded_iter_map_.end()) {
    return {split, left_pad};
  }
  auto& info = it->second;
  if (is_zero(info.left_pad) && CanProveDivisible(mark->extent, info.padding_factor)) {
    // the iter mark requires no padding
    return {split, left_pad};
  }

  // check that padding factor is compatible with current split and divisor
  ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor))
      << "The padding factor " << info.padding_factor << " is not divisible by "
      << split->lower_factor << " for the split " << split;
  ICHECK(CanProveDivisible(info.padding_factor, divisor))
      << "The padding factor " << info.padding_factor << " is not divisible by " << divisor
      << " for the split " << split;

  if (!info.padded.defined()) {
    // the first time encounter the iter mark to pad, update the padded mark.
    PrimExpr mark_left_pad = info.left_pad;
    if (CanProveDivisible(mark_left_pad, split->lower_factor)) {
      // correct current split's left padding
      // (mark_left_pad + iter) // lower_factor % extent =>
      // (left_pad * lower_factor + mark) // lower_factor % extent =>
      // (left_pad + mark // lower_factor) % extent =>
      // left_pad + (mark // lower_factor % extent) =>
      // left_pad + split
      // since the extent covers the full padding range.
      left_pad = floordiv(mark_left_pad, split->lower_factor);
    } else {
      ErrorLogger(this) << "Detect incompatible left padding on " << NormalizeIterMapToExpr(split)
                        << ", the iter mark is left padded with " << mark_left_pad;
      return {IterSplitExpr(), PrimExpr()};
    }

    // Pad the right side of the mark so that the padded extent is a multiple of padding_factor.
    PrimExpr right_edge = mark->extent + mark_left_pad;
    PrimExpr mark_right_pad;
    if (CanProveDivisible(right_edge, info.padding_factor)) {
      mark_right_pad = 0;
    } else {
      mark_right_pad = floormod(-right_edge, info.padding_factor);
    }
    PrimExpr padded_extent = analyzer_->Simplify(right_edge + mark_right_pad);
    info.right_pad = mark_right_pad;
    info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent);
    padded_origin_map_[info.padded] = mark;

    auto left_padding_introduced = (mark_left_pad != 0);

    // Equivalent to (0 <= split < left_pad), but easier to simplify in
    // terms of the transformed variables.
    auto left_padding_predicate =
        left_padding_introduced &&
        (floordiv(info.padded->source, info.padding_factor) == 0 &&
         floormod(info.padded->source, info.padding_factor) < mark_left_pad);
    auto right_padding_introduced = (mark_right_pad != 0);

    // Equivalent to (right_edge <= split < right_edge + right_pad), but
    // easier to simplify in terms of the transformed variables.
    auto right_padding_predicate =
        right_padding_introduced && (floordiv(info.padded->source, info.padding_factor) ==
                                         floordiv(right_edge, info.padding_factor) &&
                                     floormod(info.padded->source, info.padding_factor) >=
                                         floormod(right_edge, info.padding_factor));
    padding_predicate_ = padding_predicate_ || (left_padding_predicate || right_padding_predicate);
  }
  split.CopyOnWrite()->source = info.padded;
  split.CopyOnWrite()->extent = floordiv(info.padded->extent, split->lower_factor);
  return {split, left_pad};
}
1489 | |
1490 | PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { |
1491 | // (lhs + base) // rhs |
1492 | |
1493 | if (is_one(rhs)) { |
1494 | if (is_zero(base)) { |
1495 | // floordiv(x, 1) = x |
1496 | return std::move(lhs); |
1497 | } else { |
1498 | // floordiv(x+y, 1) = x+y |
1499 | return IterSumExpr({lhs}, base); |
1500 | } |
1501 | } |
1502 | |
1503 | if (!is_one(lhs->scale)) { |
1504 | if (CanProveDivisible(lhs->scale, rhs) && is_zero(base)) { |
1505 | // floordiv(x*c1*c2, c2) = x*c1, c1=scale/rhs |
1506 | lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); |
1507 | return std::move(lhs); |
1508 | } else if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { |
1509 | // floordiv(x*c1*c2 + y*c2, c2) = x*c1 + y, c1=scale/rhs |
1510 | lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs); |
1511 | return IterSumExpr({lhs}, floordiv(base, rhs)); |
1512 | } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { |
1513 | // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale |
1514 | rhs = floordiv(rhs, lhs->scale); |
1515 | lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); |
1516 | } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { |
1517 | // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale |
1518 | base = floordiv(base, lhs->scale); |
1519 | rhs = floordiv(rhs, lhs->scale); |
1520 | lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); |
1521 | } else { |
1522 | // mark as unresolved. |
1523 | ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " |
1524 | << lhs->scale << " and the divisor " << rhs |
1525 | << " cannot be simplified to remove the scaling factor." ; |
1526 | return PrimExpr(); |
1527 | } |
1528 | } |
1529 | |
1530 | // We handle scale!=1 in above code, hence we only consider floordiv(x, rhs) below |
1531 | // where x=floormod(floordiv(iter, lower_factor), extent) + base |
1532 | |
1533 | auto pair = PadDividendToDivisor(lhs, base, rhs); |
1534 | IterSplitExpr padded = pair.first; |
1535 | PrimExpr left_pad = pair.second; |
1536 | if (!padded.defined()) { |
1537 | return PrimExpr(); |
1538 | } |
1539 | |
1540 | // floordiv(floormod(floordiv(iter, lower_factor), c1c2), c1) |
1541 | // = floordiv(floormod(y, c1c2), c1), where y=floordiv(iter, lower_factor) |
1542 | // = floordiv(floormod(sc1c2+tc1+u, c1c2), c1), where y=sc1c2+tc1+u, t<c2, u<c1 |
1543 | // = t |
1544 | // = floormod(sc2+t, c2) |
1545 | // = floormod(floordiv(y, c1), c2) |
1546 | // = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs |
1547 | IterSplitExpr new_split(padded->source, |
1548 | /* lower_factor = */ padded->lower_factor * rhs, |
1549 | /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), |
1550 | /* scale = */ padded->scale); |
1551 | |
1552 | auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); |
1553 | if (is_zero(new_base)) { |
1554 | return std::move(new_split); |
1555 | } else { |
1556 | return IterSumExpr({new_split}, new_base); |
1557 | } |
1558 | } |
1559 | |
1560 | PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { |
1561 | if (!IsIndexType(op->dtype)) { |
1562 | return Parent::VisitExpr_(op); |
1563 | } |
1564 | |
1565 | PrimExpr a = this->DirectMutate(op->a); |
1566 | PrimExpr b = this->DirectMutate(op->b); |
1567 | |
1568 | // const folding |
1569 | if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value(); |
1570 | |
1571 | // does not contain iter map. |
1572 | if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) { |
1573 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1574 | return GetRef<PrimExpr>(op); |
1575 | } else { |
1576 | return FloorDiv(a, b); |
1577 | } |
1578 | } |
1579 | |
1580 | if (b->IsInstance<IterMapExprNode>()) { |
1581 | // cannot divide an iterator, mark as unresolved. |
1582 | ErrorLogger(this) << "Cannot represent as an IterMap: the divisor in " << GetRef<PrimExpr>(op) |
1583 | << " may not be an iterator" ; |
1584 | return GetRef<PrimExpr>(op); |
1585 | } |
1586 | |
1587 | IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a); |
1588 | if (!preprocessed.defined()) { |
1589 | return GetRef<PrimExpr>(op); |
1590 | } |
1591 | ICHECK_EQ(preprocessed->args.size(), 1U); |
1592 | PrimExpr remainder = SplitFloorDivConst(preprocessed->args[0], preprocessed->base, b); |
1593 | if (!remainder.defined()) { |
1594 | return GetRef<PrimExpr>(op); |
1595 | } |
1596 | return remainder; |
1597 | } |
1598 | |
1599 | PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs) { |
1600 | // (lhs + base) % rhs |
1601 | |
1602 | if (is_one(rhs)) { |
1603 | // floormod(x, 1) = 0 |
1604 | return make_zero(lhs->dtype); |
1605 | } |
1606 | |
1607 | if (!is_one(lhs->scale)) { |
1608 | if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { |
1609 | // floormod(x*c1*c2, c1) = 0 |
1610 | return make_zero(lhs->dtype); |
1611 | } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { |
1612 | // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale |
1613 | rhs = floordiv(rhs, lhs->scale); |
1614 | } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { |
1615 | // floormod(x*c1 + y*c1, c1*c2) = (floormod(x+y, c2)) * c1, where c2 = rhs/scale |
1616 | rhs = floordiv(rhs, lhs->scale); |
1617 | base = floordiv(base, lhs->scale); |
1618 | } else { |
1619 | // mark as unresolved. |
1620 | ErrorLogger(this) |
1621 | << "Cannot represent as IterMap: the left-hand side of FloorMod has a scaling factor, " |
1622 | << lhs->scale << " and the right-hand " << rhs |
1623 | << " cannot be used to simplify out the scaling factor." ; |
1624 | return PrimExpr(); |
1625 | } |
1626 | } |
1627 | |
1628 | // We handle scale!=1 in above code, hence we only consider floormod(x, rhs) below |
1629 | // where x=floormod(floordiv(iter, lower_factor), extent) + base |
1630 | auto pair = PadDividendToDivisor(lhs, base, rhs); |
1631 | IterSplitExpr padded = pair.first; |
1632 | if (!padded.defined()) { |
1633 | return PrimExpr(); |
1634 | } |
1635 | |
1636 | // floormod(floormod(floordiv(iter, lower_factor), c1c2), c1) |
1637 | // = floormod(floordiv(iter, lower_factor), c1), where c1=rhs |
1638 | return IterSplitExpr(padded->source, |
1639 | /* lower_factor = */ padded->lower_factor, |
1640 | /* extent = */ rhs, |
1641 | /* scale = */ padded->scale); |
1642 | } |
1643 | |
1644 | PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { |
1645 | if (!IsIndexType(op->dtype)) { |
1646 | return Parent::VisitExpr_(op); |
1647 | } |
1648 | |
1649 | PrimExpr a = this->DirectMutate(op->a); |
1650 | PrimExpr b = this->DirectMutate(op->b); |
1651 | |
1652 | // const folding |
1653 | if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value(); |
1654 | |
1655 | // does not contain iter map. |
1656 | if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) { |
1657 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1658 | return GetRef<PrimExpr>(op); |
1659 | } else { |
1660 | return FloorMod(a, b); |
1661 | } |
1662 | } |
1663 | |
1664 | if (b->IsInstance<IterMapExprNode>()) { |
1665 | // cannot mod an iterator, mark as unresolved. |
1666 | ErrorLogger(this) << "Cannot represent as an IterMap: the right-hand side of FloorMod in " |
1667 | << GetRef<PrimExpr>(op) << " may not be an iterator" ; |
1668 | return GetRef<PrimExpr>(op); |
1669 | } |
1670 | |
1671 | IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a); |
1672 | if (!preprocessed.defined()) { |
1673 | return GetRef<PrimExpr>(op); |
1674 | } |
1675 | |
1676 | ICHECK_EQ(preprocessed->args.size(), 1U); |
1677 | PrimExpr remainder = SplitFloorModConst(preprocessed->args[0], preprocessed->base, b); |
1678 | if (!remainder.defined()) { |
1679 | return GetRef<PrimExpr>(op); |
1680 | } |
1681 | return remainder; |
1682 | } |
1683 | |
/*!
 * \brief Given an expression that may contain IterMapExpr nodes (IterSplitExpr / IterSumExpr),
 *        transform it into an equivalent plain PrimExpr.
 */
1686 | class IterMapToExprNormalizer : public ExprMutator { |
1687 | public: |
1688 | explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} |
1689 | |
1690 | PrimExpr Convert(const PrimExpr& expr) { return VisitExpr(expr); } |
1691 | |
1692 | private: |
1693 | /*! \brief Override VisitExpr for iter expr type processing */ |
1694 | PrimExpr VisitExpr(const PrimExpr& expr) override { |
1695 | if (const auto* op = expr.as<IterSplitExprNode>()) { |
1696 | return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op)); |
1697 | } else if (const auto* op = expr.as<IterSumExprNode>()) { |
1698 | return ConvertIterSumExpr(GetRef<IterSumExpr>(op)); |
1699 | } else { |
1700 | return ExprMutator::VisitExpr(expr); |
1701 | } |
1702 | } |
1703 | |
1704 | PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) { |
1705 | PrimExpr res = 0; |
1706 | for (const IterSplitExpr& arg : expr->args) { |
1707 | res += ConvertIterSplitExpr(arg); |
1708 | } |
1709 | res += expr->base; |
1710 | return res; |
1711 | } |
1712 | |
1713 | PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) { |
1714 | PrimExpr source; |
1715 | if (const auto* op = |
---|