/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file canonical_simplify.cc
 * \brief Canonical form based simplification.
 */
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>

#include "const_fold.h"
#include "pattern_match.h"
#include "rewrite_simplify.h"

namespace tvm {
namespace arith {

using namespace tir;

class SumExpr;
class SplitExpr;

/*!
 * \brief Base class of all temporary expressions introduced
 *        for canonicalization.
 */
class CanonicalExprNode : public PrimExprNode {
 public:
  virtual ~CanonicalExprNode() {}
  /*!
   * \brief Return the normal Expr that is equivalent to self.
   * \note Can mutate the internal data structure.
   * \return The normal expression.
   */
  virtual PrimExpr Normalize() const = 0;

  // overrides
  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "arith.CanonicalExpr";
  static constexpr const uint32_t _type_child_slots = 2;
  TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
  if (mode == kTruncDiv) {
    return truncmod(a, b);
  } else {
    ICHECK_EQ(mode, kFloorDiv);
    return floormod(a, b);
  }
}

inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
  if (mode == kTruncDiv) {
    return truncdiv(a, b);
  } else {
    ICHECK_EQ(mode, kFloorDiv);
    return floordiv(a, b);
  }
}

/*!
 * \brief check if value fits in dtype
 * \param dtype The target dtype
 * \param value The value to be analyzed
 * \param analyzer The analyzer
 * \return whether value fits in dtype
 */
bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
  if (!IsIndexType(dtype)) {
    return false;
  }
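  // Query a conservative integer bound of the value and check that it lies
  // within the representable range of the target dtype.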
  ConstIntBound bound = analyzer->const_int_bound(value);
  int64_t ubound = Downcast<IntImm>(max_value(dtype))->value;
  int64_t lbound = Downcast<IntImm>(min_value(dtype))->value;
  if (value.dtype().bits() <= dtype.bits() ||  // upcast is safe
      (bound->max_value <= ubound && bound->min_value >= lbound)) {
    return true;
  }
  return false;
}

/*!
 * \brief Internal "Split normal form" of expression.
 *
 * This is a special expression that represents
 * a scaled value derived from a split of an index.
 *
 * result = ((index % upper_factor) / lower_factor) * scale
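 *
 * For example, with index = i, upper_factor = 16, lower_factor = 4 and
 * scale = 2, the split expression denotes ((i % 16) / 4) * 2.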
 */
class SplitExprNode : public CanonicalExprNode {
 public:
  /*! \brief The base index expression. */
  PrimExpr index;
  /*! \brief The division factor ratio. */
  int64_t lower_factor{1};
  /*!
   * \brief The upper factor.
   * invariant: (upper_factor == kPosInf || upper_factor % lower_factor == 0)
   */
  int64_t upper_factor{kPosInf};
  /*! \brief scale to the expression. */
  int64_t scale{1};
  /*! \brief Division mode. */
  DivMode div_mode{kTruncDiv};

  /*! \brief verify that this is a valid entry. */
  void Verify() const { ICHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); }

  PrimExpr NormalizeWithScale(int64_t sscale) const {
    PrimExpr res = this->index;
    DataType dtype = this->dtype;
    if (this->scale == 0) {
      return make_const(dtype, 0);
    }
    if (this->upper_factor != SplitExprNode::kPosInf) {
      res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode);
    }
    if (this->lower_factor != 1) {
      res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode);
    }
    sscale *= this->scale;
    if (sscale != 1) {
      ICHECK(!dtype.is_uint() || sscale > 0);
      res = res * make_const(dtype, sscale);
    }
    return res;
  }

  PrimExpr Normalize() const final { return NormalizeWithScale(1); }

  void MulToSelf(int64_t scale) { this->scale *= scale; }

  /*!
   * \brief check if cast can be pushed to sub-expressions
   * \param dtype The target datatype
   * \param analyzer The analyzer
   * \return whether the cast can be safely pushed to children
   */
  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
    // cast(dtype, index % upper_factor / lower_factor * scale) ==
    // cast(dtype, index) % upper_factor / lower_factor * scale
    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
    // its intermediate results fit in the range of dtype
    if (dtype.bits() >= this->dtype.bits()) {
      return true;  // upcast is safe
    }
    PrimExpr res = this->index;
    if (this->scale == 0) {
      return true;
    }
    if (!CastIsSafe(dtype, res, analyzer)) {
      return false;
    }
    if (this->upper_factor != SplitExprNode::kPosInf) {
      res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
    if (this->lower_factor != 1) {
      res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
    if (this->scale != 1) {
      ICHECK(!this->dtype.is_uint() || this->scale > 0);
      res = res * make_const(this->dtype, this->scale);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief self = cast(dtype, self)
   * \param dtype The target datatype
   */
  void PushCastToChildren(DataType dtype) {
    this->index = cast(dtype, this->index);
    this->dtype = dtype;
  }

  inline bool IndexEqual(const SplitExpr& other) const;
  inline bool DivModeCompatibleTo(DivMode mode) const;

  /*! \brief positive infty */
  static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
  static constexpr const char* _type_key = "arith.SplitExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode);
};

class SplitExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};

inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const {
  if (index.same_as(other->index)) return true;
  return tir::ExprDeepEqual()(index, other->index);
}

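// A split whose lower_factor is 1 and whose upper_factor is kPosInf degenerates
// to index * scale, so it is compatible with either division mode.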
inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const {
  if (this->div_mode == mode) return true;
  if (lower_factor == 1 && upper_factor == kPosInf) return true;
  return false;
}

/*!
 * \brief Normal form that represents sum of expressions.
 *
 * result = sum(args) + base.
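 *
 * For example, x * 4 + (y % 8) / 2 + 3 is represented with
 * args = {x * 4, (y % 8) / 2} and base = 3.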
 */
class SumExprNode : public CanonicalExprNode {
 public:
  /*!
   * \brief arguments to be summed up.
   *
   * args are divided into segments with the same index.
   * within each segment, the SplitExprs are ordered in descending order of lower_factor.
   */
  std::vector<SplitExpr> args;
  /*! \brief Base value in the summation. */
  int64_t base{0};
  /*! \brief Whether the expression equals zero. */
  bool IsZero() const { return base == 0 && args.size() == 0; }
  /*!
   * \brief Return the normal Expr that is equivalent to self.
   * \return The normal expression.
   */
  PrimExpr Normalize() const final {
    // quick path 1.
    if (this->args.size() == 0) {
      return make_const(this->dtype, this->base);
    }
    return Normalize_(this->dtype, SimplifySplitExprs(args), base);
  }
  /*!
   * \brief Whether self is divisible by scale.
   * \param scale The scale to be applied.
   */
  bool DivisibleBy(int64_t scale) {
    if (base % scale != 0) return false;
    for (size_t i = 0; i < this->args.size(); ++i) {
      if (args[i]->scale % scale != 0) return false;
    }
    return true;
  }
  /*!
   * \brief mul scale to self.
   * \param scale The scale to be applied.
   */
  void MulToSelf(int64_t scale) {
    this->base *= scale;
    for (size_t i = 0; i < this->args.size(); ++i) {
      args[i].CopyOnWrite()->scale *= scale;
    }
  }
  /*!
   * \brief divide by scale.
   * \param scale The scale to be applied.
   */
  void DivideBy(int64_t scale) {
    ICHECK_EQ(this->base % scale, 0);
    this->base /= scale;
    for (size_t i = 0; i < this->args.size(); ++i) {
      ICHECK_EQ(args[i]->scale % scale, 0);
      args[i].CopyOnWrite()->scale /= scale;
    }
  }
  /*!
   * \brief add constant value to self.
   * \param value to be added.
   */
  void AddToSelf(int64_t value) { this->base += value; }
  /*!
   * \brief self += other * scale;
   * \param other The expression to be added.
   * \param scale The additional scale on value.
   */
  void AddToSelf(SplitExpr other, int64_t scale) {
    if (other->scale == 0) return;
    // We need to maintain the segment invariant:
    // entries with the same index are stored next to each other,
    // sorted in descending order of lower_factor.
    size_t start = 0;
    for (; start < args.size(); ++start) {
      if (args[start]->IndexEqual(other)) break;
    }
    for (size_t j = start; j < args.size(); ++j) {
      if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) {
        other.CopyOnWrite()->scale *= scale;
        this->args.insert(this->args.begin() + j, other);
        return;
      }
      if (other->lower_factor == args[j]->lower_factor &&
          other->upper_factor == args[j]->upper_factor &&
          other->DivModeCompatibleTo(args[j]->div_mode)) {
        args[j].CopyOnWrite()->scale += other->scale * scale;
        return;
      }
    }
    // Insert other at the end.
    other.CopyOnWrite()->scale *= scale;
    this->args.emplace_back(std::move(other));
  }

  void AddToSelf(const SumExpr& other, int64_t scale);

  /*!
   * \brief check if cast can be pushed to sub-expressions
   * \param dtype The target datatype
   * \param analyzer The analyzer
   * \return whether the cast can be safely pushed to children
   */
  bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
    bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
                                           : base == -(1LL << (dtype.bits() - 1));
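    // A base equal to the minimum value of the target dtype cannot be negated
    // without overflow, so it is handled on the addition path below.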
    // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
    // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
    // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
    // its intermediate results fit in the range of dtype
    if (dtype.bits() >= this->dtype.bits()) {
      return true;  // upcast is safe
    }
    PrimExpr res = make_const(dtype, 0);
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale > 0) {
        res = res + args[i]->Normalize();
        if (!CastIsSafe(dtype, res, analyzer)) {
          return false;
        }
      }
    }
    if (base > 0 || is_min_value) {
      res = res + make_const(dtype, base);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
    // negative scales follow using sub.
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale < 0) {
        res = res - args[i]->NormalizeWithScale(-1);
        if (!CastIsSafe(dtype, res, analyzer)) {
          return false;
        }
      }
    }
    if (base < 0 && !is_min_value) {
      res = res - make_const(dtype, -base);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
    for (const auto& arg : args) {
      if (!arg->CanPushCastToChildren(dtype, analyzer)) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief self = cast(dtype, self)
   * \param dtype The target datatype
   */
  void PushCastToChildren(DataType dtype) {
    for (auto& arg : args) {
      arg.CopyOnWrite()->PushCastToChildren(dtype);
    }
    this->dtype = dtype;
  }

  static constexpr const char* _type_key = "arith.SumExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

 private:
  /*!
   * \brief Simplify the args by merging SplitExprs
   * \param args The original list of arguments.
   * \return simplified version.
   */
  static std::vector<SplitExpr> SimplifySplitExprs(std::vector<SplitExpr> args) {
    // NOTE: This algorithm relies on the fact that args are divided into segments
    // and each segment is sorted in descending order of lower_factor.
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale == 0) continue;
      for (size_t j = i + 1; j < args.size(); ++j) {
        SplitExpr& lhs = args[i];
        SplitExpr& rhs = args[j];
        if (!lhs->IndexEqual(rhs)) break;
        if (lhs->upper_factor < rhs->lower_factor) break;
        if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor &&
            lhs->DivModeCompatibleTo(rhs->div_mode)) {
          // fold entries with the same coefficient.
          rhs.CopyOnWrite()->scale += lhs->scale;
          lhs.CopyOnWrite()->scale = 0;
        } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 &&
                   lhs->scale % rhs->scale == 0 &&
                   lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor &&
                   lhs->DivModeCompatibleTo(rhs->div_mode)) {
          // Rules used in the proof:
          //
          // Rule 1: (x % (c * s)) / c = (x / c) % s
          // Proof:
          //  x can always be decomposed into p * c * s + q * c + r
          //  where 0 <= q * c + r < c * s and 0 <= r < c.
          //  Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q
          //  rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q
          //  Thus, lhs = rhs
          //
          //  The above proof is for the floordiv.
          //  The same rule also holds for truncdiv (the division rule in C).
          //  Because both sides only involve mul, div and mod,
          //  we can take abs of x, c and s, apply the floordiv proof,
          //  and finally add the sign back.
          //
          // Rule 2: (x / s) * s + x % s = x (true for both trunc and floor div)
          //
          // General merge condition and proof:
          // - x = lhs->index % lhs->upper_factor
          // - s = lhs->scale / rhs->scale
          // - c = rhs->lower_factor
          //
          //    (x / (c * s)) * s + (x % (c * s)) / c
          // => ((x / c) / s) * s + ((x / c) % s)
          // => (x / c)
          //
          // Examples:
          //
          //    (z / 6) * 6 + ((z % 6) / 3) * 3
          // => ((z / 6) * 2 + (z % 6) / 3) * 3
          // => (z / 3) * 3
          // note: x = z, c = 3, s = 2
          //
          //    ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3
          // => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3
          // => ((z % 12) / 3) * 3
          // note: x = z % 12, c = 3, s = 2
          // note also the invariant lhs->upper_factor % lhs->lower_factor == 0
          //
          SplitExprNode* merged = rhs.CopyOnWrite();
          merged->upper_factor = lhs->upper_factor;
          // reset args[i] to be zero.
          lhs.CopyOnWrite()->scale = 0;
          break;
        }
      }
    }
    // Sort the entries.
    // Here we simply sort by descending order of scale, then the factors and div mode.
    // We do not sort by index for now because deep comparison of Exprs can be costly
    // and the addresses of Vars are runtime dependent, which would introduce
    // indeterminism.
    auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) {
      // order by scale first
      if (lhs->scale > rhs->scale) return true;
      if (lhs->scale < rhs->scale) return false;
      // then order by factor
      if (lhs->lower_factor > rhs->lower_factor) return true;
      if (lhs->lower_factor < rhs->lower_factor) return false;
      // then order by upper factor
      if (lhs->upper_factor > rhs->upper_factor) return true;
      if (lhs->upper_factor < rhs->upper_factor) return false;
      // then order by div mode
      if (lhs->div_mode > rhs->div_mode) return true;
      if (lhs->div_mode < rhs->div_mode) return false;
      // tie.
      // TODO(tvm-team) We might consider index as the last comparison point,
      // after we make the deep comparator more deterministic.
      // Specifically, we can consider comparing names of vars and breaking ties with addresses.
      return false;
    };
    std::stable_sort(args.begin(), args.end(), fcompare);
    return args;
  }
  static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) {
    bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest()
                                           : base == -(1LL << (dtype.bits() - 1));
    // Positive scales first
    PrimExpr res = make_const(dtype, 0);
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale > 0) {
        res = res + args[i]->Normalize();
      }
    }
    if (base > 0 || is_min_value) {
      res = res + make_const(dtype, base);
    }
    // negative scales follow using sub.
    for (size_t i = 0; i < args.size(); ++i) {
      if (args[i]->scale < 0) {
        res = res - args[i]->NormalizeWithScale(-1);
      }
    }
    if (base < 0 && !is_min_value) {
      res = res - make_const(dtype, -base);
    }
    return res;
  }
};

class SumExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};

void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) {
  // NOTE: it is rare to have a balanced long expression;
  // a linear scan is fine for our case.
  for (size_t i = 0; i < other->args.size(); ++i) {
    this->AddToSelf(other->args[i], scale);
  }
  this->AddToSelf(other->base * scale);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<SplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SplitExprNode*>(node.get());
      auto factor_str = [](int64_t f) {
        return f == SplitExprNode::kPosInf ? std::string("+inf") : std::to_string(f);
      };
      p->stream << "split(";
      p->Print(op->index);
      p->stream << ", lower=" << factor_str(op->lower_factor)
                << ", upper=" << factor_str(op->upper_factor) << ", scale=" << op->scale
                << ", div_mode=";
      switch (op->div_mode) {
        // No "default", so that the compiler will emit a warning if more div modes are
        // added that are not covered by the switch.
        case kTruncDiv:
          p->stream << "truncdiv";
          break;
        case kFloorDiv:
          p->stream << "floordiv";
          break;
      }
      p->stream << ')';
    });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<SumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SumExprNode*>(node.get());
      p->stream << "sum(base=" << op->base;
      for (const SplitExpr& s : op->args) {
        p->stream << ", ";
        p->Print(s);
      }
      p->stream << ')';
    });

// Subclass RewriteSimplifier::Impl to take advantage of the
// rewriter for condition simplification etc.
class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
 public:
  using Rewriter = RewriteSimplifier::Impl;

  explicit Impl(Analyzer* parent) : Rewriter(parent) {}

  PrimExpr CanonicalSimplify(PrimExpr expr) {
    expr = operator()(expr);
    return expr;
  }

  // override the original mutate function.
  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
    auto expr = Rewriter::VisitExpr(input_expr);
    return Normalize(expr);
  }

  // Normal mutation without normalization.
  PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); }

  using Rewriter::VisitExpr_;
  PrimExpr VisitExpr_(const AddNode* op) final;
  PrimExpr VisitExpr_(const SubNode* op) final;
  PrimExpr VisitExpr_(const MulNode* op) final;
  PrimExpr VisitExpr_(const DivNode* op) final;
  PrimExpr VisitExpr_(const ModNode* op) final;
  PrimExpr VisitExpr_(const FloorDivNode* op) final;
  PrimExpr VisitExpr_(const FloorModNode* op) final;
  PrimExpr VisitExpr_(const ReduceNode* op) final;
  PrimExpr VisitExpr_(const CastNode* op) final;

 private:
  /*!
   * \brief compute lhs / cval
   * \param lhs The left operand.
   * \param cval The constant value.
   * \param div_mode The division mode.
   * \return The result expression.
   */
  SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
  /*!
   * \brief compute lhs % cval
   * \param lhs The left operand.
   * \param cval The constant value.
   * \param div_mode The division mode.
   * \return The result expression.
   */
  SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode);
  /*!
   * \brief Separate psum into divisible and non-divisible parts.
   * \param psum The sum expression.
   * \param coeff The coefficient.
   * \param out_divisible The result divisible component.
   * \param out_non_divisible The non-divisible component.
   */
  void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
                              SumExpr* out_non_divisible);
  /*!
   * \brief Normalize expr to normal expr.
   * \param expr The input expression.
   * \return Normalized expr.
   */
  PrimExpr Normalize(PrimExpr expr) {
    if (const auto* op = expr.as<CanonicalExprNode>()) {
      return op->Normalize();
    } else {
      return expr;
    }
  }
  /*!
   * \brief Create a SplitExpr from expr.
   * \param expr The input expr.
   * \return The transformed SplitExpr.
   */
  SplitExpr ToSplitExpr(PrimExpr expr) {
    if (const auto* op = expr.as<SplitExprNode>()) {
      return GetRef<SplitExpr>(op);
    }
    if (const auto* op = expr.as<SumExprNode>()) {
      if (op->base == 0 && op->args.size() == 1) return op->args[0];
    }
    if (const auto* op = expr.as<CanonicalExprNode>()) {
      expr = op->Normalize();
    }
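    // Plain expressions are wrapped with the default factors
    // (lower_factor = 1, upper_factor = kPosInf, scale = 1),
    // i.e. the resulting split simply denotes expr itself.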
    ObjectPtr<SplitExprNode> n = make_object<SplitExprNode>();
    n->dtype = expr.dtype();
    n->index = std::move(expr);
    n->div_mode = kTruncDiv;
    return SplitExpr(n);
  }
  /*!
   * \brief Convert expr to an equivalent SplitExpr
   *        that has the specified div_mode.
   *
   * This function will return the same expr if its
   * div_mode already satisfies the need.
   *
   * \param expr The input expr.
   * \param div_mode The new div_mode.
   * \return The transformed SplitExpr.
   */
  SplitExpr ConvertDivMode(SplitExpr expr, DivMode div_mode) {
    if (expr->div_mode == div_mode) return expr;
    if (expr->DivModeCompatibleTo(div_mode)) {
      expr.CopyOnWrite()->div_mode = div_mode;
      return expr;
    }
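    // Otherwise fall back: normalize to a plain expression and re-wrap it,
    // which yields a split with trivial factors that is compatible with any div mode.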
    expr = ToSplitExpr(Normalize(expr));
    ICHECK(expr->DivModeCompatibleTo(div_mode));
    expr.CopyOnWrite()->div_mode = div_mode;
    return expr;
  }
  /*!
   * \brief Create a SumExpr from expr.
   * \param expr The input expr.
   * \return The transformed SumExpr.
   */
  SumExpr ToSumExpr(PrimExpr expr) {
    if (const auto* op = expr.as<SumExprNode>()) {
      return GetRef<SumExpr>(op);
    }
    ObjectPtr<SumExprNode> n = make_object<SumExprNode>();
    n->dtype = expr.dtype();
    if (const auto* op = expr.as<IntImmNode>()) {
      n->base = op->value;
      return SumExpr(n);
    } else {
      n->args.emplace_back(ToSplitExpr(expr));
      return SumExpr(n);
    }
  }
  // Simplify the combiner used in reduce.
  PrimExpr SimplifyReduceCombiner(const ReduceNode* op);
};

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value();

  // canonical form simplification.
  SumExpr ret = ToSumExpr(std::move(a));

  if (const auto* op = b.as<IntImmNode>()) {
    ret.CopyOnWrite()->AddToSelf(op->value);
  } else if (const auto* op = b.as<SumExprNode>()) {
    ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1);
  } else {
    ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1);
  }
  return std::move(ret);
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value();

  // canonical form simplification.
  SumExpr ret = ToSumExpr(std::move(a));

  if (const auto* op = b.as<IntImmNode>()) {
    ret.CopyOnWrite()->AddToSelf(-op->value);
  } else if (const auto* op = b.as<SumExprNode>()) {
    ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1);
  } else {
    ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1);
  }
  return std::move(ret);
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value();

  // x * c
  if (a.as<IntImmNode>()) {
    std::swap(a, b);
  }
  if (const auto* bconst = b.as<IntImmNode>()) {
    if (a.as<SumExprNode>()) {
      SumExpr ret = Downcast<SumExpr>(std::move(a));
      ret.CopyOnWrite()->MulToSelf(bconst->value);
      return std::move(ret);
    } else {
      SplitExpr ret = ToSplitExpr(std::move(a));
      ret.CopyOnWrite()->MulToSelf(bconst->value);
      return std::move(ret);
    }
  }

  // normal path.
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return Mul(a, b);
  }
}

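// For example, separating x * 6 + y * 4 + 3 with coeff = 3 yields
// divisible = x * 6 + 3 and non_divisible = y * 4.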
void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff,
                                                       SumExpr* out_divisible,
                                                       SumExpr* out_non_divisible) {
  auto divisible = make_object<SumExprNode>();
  auto non_divisible = make_object<SumExprNode>();
  divisible->dtype = psum->dtype;
  non_divisible->dtype = psum->dtype;

  if (psum->base % coeff == 0) {
    divisible->base = psum->base;
  } else {
    non_divisible->base = psum->base;
  }
  for (const auto& e : psum->args) {
    if (e->scale % coeff == 0) {
      divisible->args.push_back(e);
    } else {
      non_divisible->args.push_back(e);
    }
  }
  *out_divisible = SumExpr(divisible);
  *out_non_divisible = SumExpr(non_divisible);
}

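// For example, ((i % 16) / 2) * 2 divided by 4 (scaled_cval = 2) folds directly
// into (i % 16) / 4, since 16 % (2 * 2) == 0.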
SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
  ICHECK_GT(cval, 0);
  lhs = ConvertDivMode(lhs, div_mode);

  // the following rule works for both floordiv and truncdiv
  if (lhs->scale % cval == 0) {
    lhs.CopyOnWrite()->scale /= cval;
    return lhs;
  }

  if (cval % lhs->scale == 0) {
    int64_t scaled_cval = cval / lhs->scale;
    if (lhs->upper_factor == SplitExprNode::kPosInf ||
        lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) {
      // directly fold division.
      lhs.CopyOnWrite()->scale = 1;
      lhs.CopyOnWrite()->lower_factor *= scaled_cval;
      lhs->Verify();
      return lhs;
    } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) {
      // (x % c1) / c2 => 0 when c2 >= c1
      return ToSplitExpr(make_zero(lhs.dtype()));
    } else {
      // move the upper_factor modular into index.
      lhs.CopyOnWrite()->index =
          ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode);
      lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
      lhs.CopyOnWrite()->scale = 1;
      lhs.CopyOnWrite()->lower_factor *= scaled_cval;
      lhs->Verify();
      return lhs;
    }
  }
  // directly return the split with cval == 1
  lhs = ToSplitExpr(Normalize(lhs));
  ICHECK(lhs->DivModeCompatibleTo(div_mode));
  ICHECK_EQ(lhs->scale, 1);
  lhs.CopyOnWrite()->lower_factor *= cval;
  lhs.CopyOnWrite()->div_mode = div_mode;
  return lhs;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }

  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<Div>(a, b)) return const_res.value();
  PVar<IntImm> c1;
  // x / c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (cval == 1) return a;

    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      // can be divided by cval
      if (extra->IsZero()) {
        lhs.CopyOnWrite()->DivideBy(cval);
        return std::move(lhs);
      }
      // both lhs and extra are non-negative
      if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
          analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
        lhs.CopyOnWrite()->DivideBy(cval);
        PrimExpr temp = Normalize(extra);
        if (const auto* pconst = temp.as<IntImmNode>()) {
          lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
        } else {
          // if 0 <= extra < cval, it means the extra can be eliminated.
          if (TryCompare(temp, cval) != CompareResult::kLT) {
            lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
          }
        }
        return std::move(lhs);
      }
    } else {
      // if a >= 0 && a < cval, then result == 0
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return make_zero(a.dtype());
      }
    }
    return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return Div(a, b);
  }
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value();
  PVar<IntImm> c1;
  // x / c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (cval == 1) return a;

    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      if (extra->IsZero()) {
        lhs.CopyOnWrite()->DivideBy(cval);
        return std::move(lhs);
      }
      // continue simplification.
      lhs.CopyOnWrite()->DivideBy(cval);
      PrimExpr temp = Normalize(extra);
      if (const auto* pconst = temp.as<IntImmNode>()) {
        lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
      } else {
        // if 0 <= extra < cval, it means the extra can be eliminated.
        if (!(TryCompare(temp, cval) == CompareResult::kLT &&
              analyzer_->CanProveGreaterEqual(temp, 0))) {
          lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
        }
      }
      return std::move(lhs);
    } else {
      // if a >= 0 && a < cval, then result == 0
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return make_zero(a.dtype());
      }
    }
    return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return FloorDiv(a, b);
  }
}

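// For example, (i / 4) * 2 % 8 has cval % scale == 0 with scaled_cval = 4 and
// new_upper_factor = 16, which simplifies to ((i % 16) / 4) * 2.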
SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
  ICHECK_GT(cval, 0);
  lhs = ConvertDivMode(lhs, div_mode);

  if (lhs->scale % cval == 0) {
    lhs.CopyOnWrite()->scale = 0;
    return lhs;
  }
  if (cval % lhs->scale == 0) {
    // The rationale:
    // (index % upper) / lower * scale % cval, given cval = scaled_cval * scale
    // by the rule (x * c1) % (c2 * c1) => (x % c2) * c1,
    // = (index % upper) / lower % scaled_cval * scale
    // by the rule (x / c1) % c2 => (x % (c1 * c2)) / c1,
    // = (index % upper) % (new_upper_factor) / lower * scale
    int64_t scaled_cval = cval / lhs->scale;
    int64_t new_upper_factor = lhs->lower_factor * scaled_cval;
    // try to see if we can reduce the existing upper modular.
    if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) {
      // we gained a new upper factor that is smaller
      // than the original one
      // Perhaps there are more chances in simplifying the index
      // Do a recursive call to simplify the mod with the new factor.
      if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
        auto updated = ToSplitExpr(this->VisitExpr(
            ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
        // re-apply the lower_factor
        if (lhs->lower_factor != 1) {
          auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
          ret.CopyOnWrite()->MulToSelf(lhs->scale);
          return ret;
        } else {
          updated.CopyOnWrite()->MulToSelf(lhs->scale);
          return updated;
        }
      } else {
        lhs.CopyOnWrite()->upper_factor = new_upper_factor;
        return lhs;
      }
    } else if (new_upper_factor % lhs->upper_factor == 0) {
      // (x % 2) % 4 => x % 2
      return lhs;
    }
  }
  // Normalize the value.
  lhs = ToSplitExpr(Normalize(lhs));
  ICHECK(lhs->DivModeCompatibleTo(div_mode));
  ICHECK_EQ(lhs->scale, 1);
  ICHECK_EQ(lhs->lower_factor, 1);
  lhs.CopyOnWrite()->div_mode = div_mode;
  lhs.CopyOnWrite()->upper_factor = cval;
  return lhs;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<Mod>(a, b)) return const_res.value();

  PVar<IntImm> c1;
  // x % c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      if (extra->IsZero()) {
        return make_zero(a.dtype());
      }
      // both lhs and extra are non-negative
      if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) &&
          analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) {
        PrimExpr temp = Normalize(extra);
        if (temp.as<IntImmNode>()) {
          return truncmod(temp, c1.Eval());
        } else {
          // If temp < cval && temp >= 0 then we can remove the mod.
          if (TryCompare(temp, cval) == CompareResult::kLT) {
            return temp;
          } else {
            // continue to use the logic below.
            a = extra;
            psum = a.as<SumExprNode>();
            ICHECK(psum != nullptr);
          }
        }
      }
      // Simplify the offset constant if necessary.
      // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      int64_t new_base = psum->base % cval;
      if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) {
        SumExpr sum_expr = Downcast<SumExpr>(a);
        sum_expr.CopyOnWrite()->base = new_base;
        return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv);
      }
    } else {
      // if a >= 0 && a < cval, then result == a
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return a;
      }
    }
    return SplitModConst(ToSplitExpr(std::move(a)), cval, kTruncDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return Mod(a, b);
  }
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr a = this->CanonicalMutate(op->a);
  PrimExpr b = this->CanonicalMutate(op->b);

  // const folding
  if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value();

  PVar<IntImm> c1;
  // x % c1
  if (c1.Match(b) && c1.Eval()->value > 0) {
    int64_t cval = c1.Eval()->value;
    if (const auto* psum = a.as<SumExprNode>()) {
      SumExpr lhs, extra;
      SeparateDivisibleParts(psum, cval, &lhs, &extra);
      PrimExpr temp = Normalize(extra);
      if (temp.as<IntImmNode>()) {
        return floormod(temp, c1.Eval());
      } else {
        // If temp < cval && temp >= 0 then we can remove the mod.
        if (TryCompare(temp, cval) == CompareResult::kLT &&
            analyzer_->CanProveGreaterEqual(temp, 0)) {
          return temp;
        } else {
          // continue to use the logic below.
          a = extra;
          psum = a.as<SumExprNode>();
          ICHECK(psum != nullptr);
        }
      }
      // Simplify the offset constant if necessary.
      // floormod(x - 5, 3) => floormod(x + 1, 3)
      int64_t new_base = floormod(psum->base, cval);
      SumExpr sum_expr = Downcast<SumExpr>(std::move(a));
      sum_expr.CopyOnWrite()->base = new_base;
      return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv);
    } else {
      // if a >= 0 && a < cval, then result == a
      auto cbound = analyzer_->const_int_bound(Normalize(a));
      if (cbound->min_value >= 0 && cbound->max_value < cval) {
        return a;
      }
    }
    return SplitModConst(ToSplitExpr(std::move(a)), cval, kFloorDiv);
  }
  // normal path
  a = Normalize(a);
  b = Normalize(b);
  if (op->a.same_as(a) && op->b.same_as(b)) {
    return GetRef<PrimExpr>(op);
  } else {
    return FloorMod(a, b);
  }
}

// Simplify reduce expression.
PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) {
  // First simplify the results
  Array<PrimExpr> simplified_result;
  for (const auto& res : op->combiner->result) {
    PrimExpr new_res = this->VisitExpr(res);
    simplified_result.push_back(new_res);
  }

  // Which components to keep
  std::vector<int> used(op->combiner->result.size(), false);

  // This function recursively marks the used components starting from
  // the index idx
  std::function<void(int)> mark_used;
  mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
    // if the idx-th component was marked as used before, do nothing
    if (used[idx]) return;
    used[idx] = true;

    // check if the idx-th result expr uses some lhs or rhs variables
    // and recursively mark the corresponding components
    for (size_t i = 0; i < simplified_result.size(); ++i)
      if (!used[i]) {
        if (UsesVar(simplified_result[idx],
                    [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
            UsesVar(simplified_result[idx],
                    [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
          mark_used(i);
      }
  };

  // mark all used components starting from the value_index
  mark_used(op->value_index);

  // components which have side effects should also be preserved
  for (size_t i = 0; i < used.size(); ++i) {
    if (SideEffect(op->source[i]) > CallEffectKind::kReadState ||
        SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState ||
        SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState ||
        (!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) {
      mark_used(i);
    }
  }

  int new_value_index = op->value_index;
  Array<PrimExpr> new_result;
  Array<PrimExpr> new_identity;
  Array<Var> new_lhs;
  Array<Var> new_rhs;
  Array<PrimExpr> new_source;
  Array<PrimExpr> new_init;

  // new stuff is old stuff which is used
  for (size_t i = 0; i < used.size(); ++i) {
    if (used[i]) {
      // We simplify the result and identity, but not the source
      new_result.push_back(simplified_result[i]);
      new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i]));
      new_lhs.push_back(op->combiner->lhs[i]);
      new_rhs.push_back(op->combiner->rhs[i]);
      new_source.push_back(op->source[i]);
      if (!op->init.empty()) new_init.push_back(op->init[i]);
    } else if (static_cast<int>(i) < op->value_index) {
      // value_index should also be adjusted
      new_value_index--;
    }
  }

  CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity);
  return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init);
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
  // Recursively call simplification when necessary.
  PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op);
  op = ret.as<ReduceNode>();
  // already been simplified by const reduction axis removal
  if (op == nullptr) return ret;
  if (op->axis.empty()) {
    if (!op->init.empty()) {
      return this->VisitExpr(Select(op->condition,
                                    (*op->combiner.get())(op->init, op->source)[op->value_index],
                                    op->init[op->value_index]));
    }
    // Note that here we assume that the identity element is indeed identity. Without this
    // assumption we would have to perform a single iteration of the loop, i.e. use
    // `(*op->combiner.get())(op->combiner->identity_element, op->source)[op->value_index]`
    // instead of `op->source[op->value_index]`. The former may be more difficult to simplify.
    return this->VisitExpr(Select(op->condition, op->source[op->value_index],
                                  op->combiner->identity_element[op->value_index]));
  }
  // combiner simplification.
  ret = SimplifyReduceCombiner(op);
  return ret;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
  if (!IsIndexType(op->dtype)) {
    return Rewriter::VisitExpr_(op);
  }
  // normalize
  PrimExpr value = this->CanonicalMutate(op->value);
  // PushCastToChildren
  if (value.as<SumExprNode>()) {
    SumExpr se = Downcast<SumExpr>(value);
    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
      se.CopyOnWrite()->PushCastToChildren(op->dtype);
      return std::move(se);
    }
  }
  if (value.as<SplitExprNode>()) {
    SplitExpr se = Downcast<SplitExpr>(value);
    if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
      se.CopyOnWrite()->PushCastToChildren(op->dtype);
      return std::move(se);
    }
  }
  return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
  return impl_->CanonicalSimplify(expr);
}

void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) {
  impl_->Update(var, info, override);
}

CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}

CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; }

}  // namespace arith
}  // namespace tvm