1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rewrite_simplify.cc
22 * \brief Rewrite-rule based simplification.
23 */
24// Acknowledgement: Most rewrite-rules are from Halide.
25#include "rewrite_simplify.h"
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/tir/builtin.h>
29#include <tvm/tir/op.h>
30
31#include <algorithm>
32#include <utility>
33
34#include "../target/datatype/registry.h"
35#include "conjunctive_normal_form.h"
36#include "const_fold.h"
37#include "constraint_extract.h"
38#include "pattern_match.h"
39
40namespace tvm {
41namespace arith {
42
43using namespace tir;
44
45// Rewrite-rule helper macros.
// Each macro expands to an `if` statement that matches a pattern against a
// variable named `ret`, which must be in scope where the macro is used, and
// that returns from the enclosing function when the rule fires.
//
// Simple rewrite: if SrcExpr matches `ret`, return the evaluated ResExpr.
46#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
47 if ((SrcExpr).Match(ret)) { \
48 return (ResExpr).Eval(); \
49 }
50
51// Rewrite + recursive rewrite: if SrcExpr matches `ret`, evaluate ResExpr and
// return the result of passing it through RecursiveRewrite.
52#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
53 if ((SrcExpr).Match(ret)) { \
54 return RecursiveRewrite((ResExpr).Eval()); \
55 }
56
57// Conditional rewrite: same as TVM_TRY_REWRITE, but the rule only fires when
// CondExpr is also true. CondExpr is evaluated after a successful match, so it
// may inspect the values bound to the pattern variables.
58#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
59 if ((SrcExpr).Match(ret) && (CondExpr)) { \
60 return (ResExpr).Eval(); \
61 }
62
63// Conditional rewrite + recursive rewrite: same as TVM_TRY_RECURSIVE_REWRITE,
// but the rule only fires when CondExpr is also true after the match.
64#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
65 if ((SrcExpr).Match(ret) && (CondExpr)) { \
66 return RecursiveRewrite((ResExpr).Eval()); \
67 }
68
69// NOTE for developers:
70//
71// We mainly focus on index expression simplification.
72// Besides the RewriteSimplifier, some cases can be better
73// handled by CanonicalSimplifier.
74//
75
76/* Utility for rewriting only boolean portions of an expression
77 *
78 * Performs a subset of simplifications done by RewriteSimplifier,
79 * sufficient to negate a simplified expression. Intended for
80 * application on an expression that has previously been simplified.
81 *
82 * \param expr The boolean expression to be normalized
83 *
84 * \returns The normalized boolean expression
85 */
86PrimExpr NormalizeBooleanOperators(PrimExpr expr) {
87 PVar<PrimExpr> x, y;
88
89 while (true) {
90 if ((!!x).Match(expr)) {
91 expr = x.Eval();
92 } else if ((!(x || y)).Match(expr)) {
93 return NormalizeBooleanOperators(!x.Eval()) && NormalizeBooleanOperators(!y.Eval());
94 } else if ((!(x && y)).Match(expr)) {
95 return NormalizeBooleanOperators(!x.Eval()) || NormalizeBooleanOperators(!y.Eval());
96 } else if ((x >= y).Match(expr) || (!(x < y)).Match(expr) || (!(y > x)).Match(expr)) {
97 return y.Eval() <= x.Eval();
98 } else if ((x > y).Match(expr) || (!(x <= y)).Match(expr) || (!(y >= x)).Match(expr)) {
99 return y.Eval() < x.Eval();
100 } else if ((!(x == y)).Match(expr)) {
101 return x.Eval() != y.Eval();
102 } else if ((!(x != y)).Match(expr)) {
103 return x.Eval() == y.Eval();
104 } else {
105 return expr;
106 }
107 }
108}
109
110CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) {
111 CompareResult output = CompareResult::kUnknown;
112
113 auto is_finished = [&output]() {
114 return output == CompareResult::kEQ || output == CompareResult::kLT ||
115 output == CompareResult::kGT;
116 };
117
118 output = CompareResult(output & TryCompareUsingConstIntBounds(x, y));
119
120 if (is_finished()) return output;
121
122 output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
123
124 return output;
125}
126
127CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimExpr& x,
128 const PrimExpr y) {
129 return TryCompare(x - y, 0);
130}
131
132CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x,
133 const PrimExpr& y) {
134 bool propagate_inequalities = enabled_extensions_ & kTransitivelyProveInequalities;
135 return analyzer_->transitive_comparisons.TryCompare(x, y, propagate_inequalities);
136}
137
138// try to prove x equals val
139CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) {
140 PrimExpr diff = this->VisitExpr(x);
141 if (const auto* ptr = diff.as<IntImmNode>()) {
142 if (ptr->value == val) {
143 return CompareResult::kEQ;
144 } else if (ptr->value > val) {
145 return CompareResult::kGT;
146 } else if (ptr->value < val) {
147 return CompareResult::kLT;
148 }
149 }
150 ConstIntBound dbound = analyzer_->const_int_bound(diff);
151 if (dbound->min_value == val && dbound->max_value == val) {
152 return CompareResult::kEQ;
153 }
154 if (dbound->min_value > val) {
155 return CompareResult::kGT;
156 }
157 if (dbound->max_value < val) {
158 return CompareResult::kLT;
159 }
160 if (dbound->min_value >= val) {
161 return CompareResult::kGE;
162 }
163 if (dbound->max_value <= val) {
164 return CompareResult::kLE;
165 }
166 if (val == 0) {
167 ModularSet dmod = analyzer_->modular_set(diff);
168 if (dmod->base != 0) {
169 return CompareResult::kNE;
170 }
171 }
172 return CompareResult::kUnknown;
173}
174
175void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
176 if (!can_override) {
177 auto it = var_map_.find(var);
178 if (it != var_map_.end()) {
179 ICHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'"
180 << " with a different value: "
181 << "original=" << it->second << ", new=" << info;
182 }
183 }
184 var_map_[var] = info;
185}
186
// Simplify an Add expression.
//
// Steps, in order:
//  1. Mutate the node through IRMutatorWithAnalyzer::VisitExpr_.
//  2. Try constant folding with TryConstFold<Add>; return the folded value if any.
//  3. Try the rewrite rules below from top to bottom.  Every TVM_TRY_* macro
//     matches against the local variable `ret` and returns from this function
//     as soon as its rule fires, so the order of the rules is significant.
//  4. If no rule fires, return the mutated expression unchanged.
187PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
188 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
// NOTE(review): the result of as<AddNode>() is dereferenced without a null
// check; this assumes the base mutator still returns an Add node -- confirm.
189 op = ret.as<AddNode>();
190 if (auto const_res = TryConstFold<Add>(op->a, op->b)) return const_res.value();
191 // Pattern var to match any expression
192 PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
193 // Pattern var match IntImm
194 PVar<IntImm> c1, c2, c3;
195 // Pattern var match FloatImm
196 PVar<FloatImm> c4;
197 // Pattern var for lanes in broadcast and ramp
198 PVar<int> lanes;
199 // Vector rules: only tried when the result has more than one lane.
200 if (op->dtype.lanes() != 1) {
201 TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes));
202 TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes));
203 TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes));
204 TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes));
// Adding a broadcast floating point zero leaves x unchanged.
205 TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f);
206 }
207
208 if (IsIndexType(op->dtype)) {
209 // Index rules
210 // cancellation rules
211 TVM_TRY_REWRITE((x - y) + y, x);
212 TVM_TRY_REWRITE(x + (y - x), y);
213
214 TVM_TRY_REWRITE((x - y) + (y - z), x - z);
215 TVM_TRY_REWRITE((x - y) + (z - x), z - y);
216
// Move an addend into a min/max when it cancels a subtraction inside it.
217 TVM_TRY_REWRITE(min(x, y - z) + z, min(x + z, y));
218 TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z));
219 TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
220 TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
221
// Same idea with scaled terms: z * c1 inside and z * c2 outside cancel when c1 == -c2.
222 TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
223 c1.Eval()->value == -c2.Eval()->value);
224 TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
225 c1.Eval()->value == -c2.Eval()->value);
226 TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y),
227 c1.Eval()->value == -c2.Eval()->value);
228 TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y),
229 c1.Eval()->value == -c2.Eval()->value);
230
// max(x, y) + min(x, y) is x + y, in every operand order.
231 TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
232 TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
233 TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
234 TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y);
235
// Constant inside a min/max cancels against the constant outside when c1 == -c2.
236 TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
237 TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value);
238 TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
239 TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value);
240
241 // constant folding
242 // NOTE: canonicalization might be better at this.
243 TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2));
244
245 // mul coefficient folding
246 TVM_TRY_REWRITE(x + x, x * 2);
247 TVM_TRY_REWRITE(x * y + x, x * (y + 1));
248 TVM_TRY_REWRITE(y * x + x, x * (y + 1));
249 TVM_TRY_REWRITE(x + y * x, x * (1 + y));
250 TVM_TRY_REWRITE(x + x * y, x * (1 + y));
251 TVM_TRY_REWRITE(x * y + x * z, x * (y + z));
252 TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
253 TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
254 TVM_TRY_REWRITE(y * x + z * x, x * (y + z));
255
256 // DivMod rules
257 // trunc div
258 TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
259 // floor div
260 TVM_TRY_REWRITE(floordiv(x, y) * y + floormod(x, y), x);
261 TVM_TRY_REWRITE(y * floordiv(x, y) + floormod(x, y), x);
262 TVM_TRY_REWRITE(floormod(x, y) + floordiv(x, y) * y, x);
263 TVM_TRY_REWRITE(floormod(x, y) + y * floordiv(x, y), x);
264
265 TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
266 c2.Eval()->value > 0);
267
268 // canonicalization rule
269 // will try rewrite again after canonicalization.
270 TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1);
271 TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1);
272 TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
273 TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
274 TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
275 TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
276
277 // DivMod rules
278 // trunc div
279 TVM_TRY_RECURSIVE_REWRITE(truncmod(y, c1) + x * c1, x * c1 + truncmod(y, c1));
280 // floor div
281 TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1));
282 }
283
284 // condition rules (tried for every dtype).
285 TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2));
286 // default value: no rule fired, keep the mutated expression.
287 return ret;
288}
289
290std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
291 size_t old_literal_size = literal_constraints_.size();
292 // we will compare the already simplified result with the constraint,
293 // so simplify the constraint as well
294 PrimExpr new_constraint = operator()(constraint);
295 for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) {
296 if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
297 literal_constraints_.push_back(subconstraint);
298 PrimExpr negation;
299 if (subconstraint.dtype().is_bool()) {
300 // We could apply NormalizeBooleanOperators during
301 // TryMatchLiteralConstraint, but that would require
302 // performing a rewrite of each expression being checked.
303 // This way, we only apply a rewrite for each constraint being
304 // applied.
305 negation = NormalizeBooleanOperators(Not(subconstraint));
306 } else {
307 negation = subconstraint == make_zero(subconstraint.dtype());
308 }
309 literal_constraints_.push_back(Not(negation));
310 }
311 }
312 size_t new_literal_size = literal_constraints_.size();
313 auto frecover = [old_literal_size, new_literal_size, this]() {
314 ICHECK_EQ(literal_constraints_.size(), new_literal_size);
315 literal_constraints_.resize(old_literal_size);
316 };
317 return frecover;
318}
319
320void RewriteSimplifier::Impl::SetEnabledExtensions(Extension flags) { enabled_extensions_ = flags; }
321
322RewriteSimplifier::Extension RewriteSimplifier::Impl::GetEnabledExtensions() const {
323 return enabled_extensions_;
324}
325
// Simplify a Sub expression.
//
// Steps, in order:
//  1. Mutate the node through IRMutatorWithAnalyzer::VisitExpr_.
//  2. Try constant folding with TryConstFold<Sub>; return the folded value if any.
//  3. Try the rewrite rules below from top to bottom.  Every TVM_TRY_* macro
//     matches against the local variable `ret` and returns from this function
//     as soon as its rule fires, so the order of the rules is significant.
//  4. If no rule fires, return the mutated expression unchanged.
326PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
327 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
// NOTE(review): the result of as<SubNode>() is dereferenced without a null
// check; this assumes the base mutator still returns a Sub node -- confirm.
328 op = ret.as<SubNode>();
329 if (auto const_res = TryConstFold<Sub>(op->a, op->b)) return const_res.value();
330 // Pattern var to match any expression
331 PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
332 // Pattern var match IntImm
333 PVar<IntImm> c1, c2, c3;
334 // Pattern var for lanes in broadcast and ramp
335 PVar<int> lanes;
336 // Vector rules: only tried when the result has more than one lane.
337 if (op->dtype.lanes() != 1) {
338 TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
339 TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes));
340 TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes));
341 TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
342 }
343
344 if (IsIndexType(op->dtype)) {
345 // Index rules
346 // cancellation rules
347 TVM_TRY_REWRITE((x + y) - y, x);
348 TVM_TRY_REWRITE((x + y) - x, y);
349 TVM_TRY_REWRITE(x - (y + x), 0 - y);
350 TVM_TRY_REWRITE(x - (x + y), 0 - y);
351
// Subtracting one of the operands of a min/max: push the subtraction inside.
352 TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x));
353 TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0));
354 TVM_TRY_REWRITE(max(x, y) - x, max(0, y - x));
355 TVM_TRY_REWRITE(max(x, y) - y, max(x - y, 0));
356
// Subtracting a min/max from one of its operands flips min <-> max.
357 TVM_TRY_REWRITE(x - max(x, y), min(0, x - y));
358 TVM_TRY_REWRITE(y - max(x, y), min(y - x, 0));
359 TVM_TRY_REWRITE(x - min(x, y), max(0, x - y));
360 TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));
361
362 // mul coefficient folding
363 TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
364 TVM_TRY_REWRITE(x * y - x, x * (y - 1));
365 TVM_TRY_REWRITE(y * x - x, x * (y - 1));
366 TVM_TRY_REWRITE(x - y * x, x * (1 - y));
367 TVM_TRY_REWRITE(x - x * y, x * (1 - y));
368 TVM_TRY_REWRITE(x * y - x * z, x * (y - z));
369 TVM_TRY_REWRITE(y * x - x * z, x * (y - z));
370 TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
371 TVM_TRY_REWRITE(y * x - z * x, x * (y - z));
372
373 // constant cancellation
374 TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
375 TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));
376
377 // cancellation rules involving 4 operands
378 TVM_TRY_REWRITE((x + y) - (x + z), y - z);
379 TVM_TRY_REWRITE((x + y) - (z + x), y - z);
380 TVM_TRY_REWRITE((y + x) - (z + x), y - z);
381 TVM_TRY_REWRITE((y + x) - (x + z), y - z);
382
383 TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x));
384 TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x));
385 TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
386 TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));
387
388 TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x));
389 TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x));
390 TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y));
391 TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y));
392
393 TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
394 TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
395 TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
396 TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y));
397
398 TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
399 TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));
400
// min/max of two pairs whose element-wise differences can be proven equal:
// the difference of the min/max is that common difference.
401 TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s1,
402 CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
403
404 TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s2,
405 CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
406 TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s1,
407 CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
408 TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s2,
409 CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
410
411 // DivMod rules
412 // truncdiv
413 // NOTE: c*(x/c) + x % c == x holds in all division modes.
414 TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0);
415 TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0);
416 TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
417 c1.Eval()->value != 0);
418 TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
419 c1.Eval()->value != 0);
420 TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y,
421 c1.Eval()->value != 0);
422 TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
423 c1.Eval()->value != 0);
424
// Scaled variants of the rules above: they apply when c3 == c1 * c2.
425 TVM_TRY_REWRITE_IF(
426 x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
427 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
428 TVM_TRY_REWRITE_IF(
429 truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
430 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
431 TVM_TRY_REWRITE_IF(
432 x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
433 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
434 TVM_TRY_REWRITE_IF(
435 truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
436 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
437 TVM_TRY_REWRITE_IF(
438 x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
439 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
440 TVM_TRY_REWRITE_IF(
441 truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
442 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
443
444 // Proof in the case of floordiv, need positive condition.
445 // let x = a * c3 + r
446 // (x + c1) / c3 - x / c3 => (r + c1) / c3
447 // NOTE: the use of floormod(c2, c3) was intentional to simplify the const.
448 TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
449 truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
450 CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
451 c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0);
452 TVM_TRY_REWRITE_IF(
453 truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3),
454 CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0);
455
456 // floordiv: same set of rules as the truncdiv rules above.
457 TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0);
458 TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0);
459 TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y,
460 c1.Eval()->value != 0);
461 TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1),
462 c1.Eval()->value != 0);
463 TVM_TRY_REWRITE_IF(x - floordiv(x - y, c1) * c1, floormod(x - y, c1) + y,
464 c1.Eval()->value != 0);
465 TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
466 c1.Eval()->value != 0);
467
468 TVM_TRY_REWRITE_IF(
469 x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
470 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
471 TVM_TRY_REWRITE_IF(
472 floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
473 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
474 TVM_TRY_REWRITE_IF(
475 x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
476 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
477 TVM_TRY_REWRITE_IF(
478 floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
479 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
480 TVM_TRY_REWRITE_IF(
481 x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
482 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
483 TVM_TRY_REWRITE_IF(
484 floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
485 c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
486
// Unlike the truncdiv versions, these two only require a positive divisor.
487 TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
488 floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
489 c3.Eval()->value > 0);
490 TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3),
491 c3.Eval()->value > 0);
492
493 // canonicalization rule
494 // will try rewrite again after canonicalization.
// NOTE: the first rule below uses the non-recursive macro, so its result is
// returned without a further rewrite pass.
495 TVM_TRY_REWRITE(x - c1, x + (0 - c1));
496 TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
497 TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
498 TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
499 } else if (op->dtype.is_float()) {
500 // Cancellation rules. Deliberately off of the integer path, to
501 // avoid introducing checks on the side effects for the fast path.
// Each rule only fires when the cancelled operand has a side effect of at
// most kReadState.
502 TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
503 SideEffect(x.Eval()) <= CallEffectKind::kReadState);
504 TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
505 TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
506 TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
507 TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
508 }
509
510 // condition rules (tried for every dtype).
511 TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2));
512 TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z)));
513 TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y));
// No rule fired: keep the mutated expression.
514 return ret;
515}
516
// Simplify a Mul expression.
//
// Steps, in order:
//  1. Mutate the node through IRMutatorWithAnalyzer::VisitExpr_.
//  2. Try constant folding with TryConstFold<Mul>; return the folded value if any.
//  3. Try the rewrite rules below from top to bottom.  Every TVM_TRY_* macro
//     matches against the local variable `ret` and returns from this function
//     as soon as its rule fires, so the order of the rules is significant.
//  4. If no rule fires, return the mutated expression unchanged.
517PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
518 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
// NOTE(review): the result of as<MulNode>() is dereferenced without a null
// check; this assumes the base mutator still returns a Mul node -- confirm.
519 op = ret.as<MulNode>();
520 if (auto const_res = TryConstFold<Mul>(op->a, op->b)) return const_res.value();
521 // Pattern var to match any expression
522 PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
523 // Pattern var match IntImm
524 PVar<IntImm> c1, c2;
525 // Pattern var match FloatImm
526 PVar<FloatImm> c3;
527 // Pattern var for lanes in broadcast and ramp
528 PVar<int> lanes;
529 // Vector rules: only tried when the result has more than one lane.
530 if (op->dtype.lanes() != 1) {
531 TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes));
532 TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes));
533 TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes));
// A broadcast floating point zero times x is rewritten to that broadcast zero.
534 TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f);
535 }
536
537 if (IsIndexType(op->dtype)) {
538 // constant simplification rule
539 TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2);
540 TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2));
541 TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y);
542 TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y);
543
544 // Two representations of const*ceildiv(x, c1)
// The rule requires c1 == -c2.
545 TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c2), c1) * c1, x - floormod(x, c2),
546 c1.Eval()->value == -c2.Eval()->value);
547
548 // canonicalization: move constants to the right-hand side and prefer a
// non-negative constant factor; the result is rewritten again.
549 TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
550 TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
551 TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0);
552 }
// No rule fired: keep the mutated expression.
553 return ret;
554}
555
// Simplify a Div expression.
//
// Steps, in order:
//  1. Mutate the node through IRMutatorWithAnalyzer::VisitExpr_.
//  2. Try constant folding with TryConstFold<Div>; return the folded value if any.
//  3. If the divisor is a floating point immediate, rewrite the division as a
//     multiplication by its reciprocal and return.
//  4. Try the vector rules and then the index rules from top to bottom.  Every
//     rule returns from this function as soon as it fires, so the order of
//     the rules is significant.
//  5. If no rule fires, return the mutated expression unchanged.
556PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
557 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
// NOTE(review): the result of as<DivNode>() is dereferenced without a null
// check; this assumes the base mutator still returns a Div node -- confirm.
558 op = ret.as<DivNode>();
559 if (auto const_res = TryConstFold<Div>(op->a, op->b)) return const_res.value();
560 // Pattern var to match any expression
561 PVar<PrimExpr> x, y, z, b1;
562 // Pattern var match IntImm
563 PVar<IntImm> c1, c2, c3;
564 // Pattern var for lanes in broadcast and ramp
565 PVar<int> lanes;
566
567 // x / 2.0 = x * 0.5
// NOTE(review): ptr->value is not checked against zero before taking the
// reciprocal -- confirm that a zero divisor cannot reach this point.
568 if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
569 ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
570 datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
571 return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
572 }
573
574 // Vector rules: only tried when the result has more than one lane.
575 if (op->dtype.lanes() != 1) {
576 // NOTE: use div as the pattern also works for float.
577 TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes));
578 // ramp / bcast
579 if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
580 int64_t c1val = c1.Eval()->value;
581 int64_t c2val = c2.Eval()->value;
582 ICHECK(c2val != 0) << "division by zero";
// The stride is a multiple of the divisor: divide base and stride separately.
583 if (c1val % c2val == 0) {
584 return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
585 }
586 // If all possible indices in ramp are the same.
// ramp_min / ramp_max are the quotients of the first and last lane computed
// from the modular-set base of b1; when they agree (and the modular-set
// coefficient is a multiple of the divisor) the result is a broadcast.
587 if (CanProveGreaterEqual(b1.Eval(), 0)) {
588 ModularSet bmod = analyzer_->modular_set(b1.Eval());
589 int64_t ramp_min = bmod->base / c2val;
590 int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
591 if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
592 return broadcast(div(b1, c2), lanes).Eval();
593 }
594 }
595 }
596 }
597
598 if (IsIndexType(op->dtype)) {
599 // Be aware of the division rules:
600 // We adopt the default C division, which uses truncation instead of floordiv.
601 // This means most rules need to check non-negativeness of the operands.
602
603 // TryConstFold doesn't work for negative cases because it is also used by legacy
604 // parts of tvm which still assume euclidean div. In this simplifier we assume that the division
605 // is truncated, so perform const folding again.
606 // NOTE: trunc div required
// NOTE(review): c2val is not checked against zero here -- presumably
// TryConstFold above already handled that case; confirm.
607 if (truncdiv(c1, c2).Match(ret)) {
608 int64_t c1val = c1.Eval()->value;
609 int64_t c2val = c2.Eval()->value;
610 return make_const(op->dtype, truncdiv(c1val, c2val));
611 }
612
613 // while it is always true for trunc div
614 // restrict to common case(positive div)
615 TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2),
616 c1.Eval()->value > 0 && c2.Eval()->value > 0);
617
618 TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3),
619 c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 &&
620 CanProveGreaterEqual(x.Eval(), 0));
621
// (x * c1) / c2 with positive constants: cancel the common factor when one
// constant divides the other.  Falls through when neither divides.
622 if (truncdiv(x * c1, c2).Match(ret)) {
623 int64_t c1val = c1.Eval()->value;
624 int64_t c2val = c2.Eval()->value;
625 if (c1val > 0 && c2val > 0) {
626 if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval();
627 if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval();
628 }
629 }
630
631 TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x));
632 TVM_TRY_REWRITE(truncdiv(x * c1, x), c1);
633 TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1);
634
635 // Rules involving 2-operands.
// Each rule requires c1 >= 0, c2 > 0, c1 divisible by c2, and both x and y
// provably non-negative.
636 TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2),
637 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
638 c1.Eval()->value % c2.Eval()->value == 0 &&
639 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
640
641 TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)),
642 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
643 c1.Eval()->value % c2.Eval()->value == 0 &&
644 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
645
646 TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)),
647 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
648 c1.Eval()->value % c2.Eval()->value == 0 &&
649 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
650
651 TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2),
652 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
653 c1.Eval()->value % c2.Eval()->value == 0 &&
654 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
655
656 TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)),
657 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
658 c1.Eval()->value % c2.Eval()->value == 0 &&
659 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
660
661 TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)),
662 c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
663 c1.Eval()->value % c2.Eval()->value == 0 &&
664 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
665
666 // Rules involving 3-operands.
// Here the non-negativity requirement is placed on x and on the combined
// remainder term (y + z, z - y or y - z).
667 TVM_TRY_REWRITE_IF(
668 truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
669 c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
670 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
671
672 TVM_TRY_REWRITE_IF(
673 truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2),
674 c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
675 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0));
676
677 TVM_TRY_REWRITE_IF(
678 truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2),
679 c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
680 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0));
681
// NOTE(review): this rule and the next require c1 > 0, whereas the sibling
// rules above accept c1 >= 0.
682 TVM_TRY_REWRITE_IF(
683 truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
684 c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
685 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
686
687 TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2),
688 c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
689 c1.Eval()->value % c2.Eval()->value == 0 &&
690 CanProveGreaterEqual(x.Eval(), 0));
691
// The divisor appears as an addend of the dividend: split off a quotient of 1.
692 TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1,
693 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
694 TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1,
695 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
696
697 TVM_TRY_REWRITE_IF(
698 truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1,
699 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
700 TVM_TRY_REWRITE_IF(
701 truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1,
702 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
703 TVM_TRY_REWRITE_IF(
704 truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1,
705 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
706 TVM_TRY_REWRITE_IF(
707 truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1,
708 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
709
// The divisor appears as a factor of the dividend (or of one of its addends).
710 TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x,
711 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
712 TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x,
713 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
714
715 TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z),
716 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
717 CanProveGreaterEqual(z.Eval(), 0));
718 TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z),
719 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
720 CanProveGreaterEqual(z.Eval(), 0));
721 TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x,
722 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
723 CanProveGreaterEqual(z.Eval(), 0));
724 TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x,
725 CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
726 CanProveGreaterEqual(z.Eval(), 0));
727 }
// No rule fired: keep the mutated expression.
728 return ret;
729}
730
731PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) {
732 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
733 op = ret.as<ModNode>();
734 if (auto const_res = TryConstFold<Mod>(op->a, op->b)) return const_res.value();
735
736 // Pattern var to match any expression
737 PVar<PrimExpr> x, y, z, b1;
738 // Pattern var match IntImm
739 PVar<IntImm> c1, c2;
740 // Pattern var for lanes in broadcast and ramp
741 PVar<int> lanes;
742
743 // Vector rules
744 if (op->dtype.lanes() != 1) {
745 TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)),
746 broadcast(truncmod(x, y), lanes));
747
748 // ramp % bcast
749 if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
750 int64_t c1val = c1.Eval()->value;
751 int64_t c2val = c2.Eval()->value;
752 ICHECK(c2val != 0) << "division by zero";
753 if (c1val % c2val == 0) {
754 return broadcast(truncmod(b1, c2), lanes).Eval();
755 }
756 // If all possible indices in ramp are the same.
757 if (CanProveGreaterEqual(b1.Eval(), 0)) {
758 ModularSet bmod = analyzer_->modular_set(b1.Eval());
759 int64_t ramp_min = bmod->base / c2val;
760 int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
761 if (bmod->coeff % c2val == 0) {
762 if (ramp_min == ramp_max) {
763 return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
764 } else {
765 return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
766 }
767 }
768 }
769 }
770 }
771
772 if (IsIndexType(op->dtype)) {
773 // Be-aware of the division rules:
774 // We adopt the default C division uses truncation instead of floordiv.
775 // This means most rules need to check non-negativeness of the operands.
776 TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x),
777 c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0);
778
779 TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2),
780 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
781 CanProveGreaterEqual((x * c1).Eval(), 0) &&
782 CanProveGreaterEqual(y.Eval(), 0));
783
784 TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2),
785 c2.Eval()->value > 0 && c1.Eval()->value >= 0 &&
786 c1.Eval()->value % c2.Eval()->value == 0 &&
787 CanProveGreaterEqual(x.Eval(), 0));
788
789 TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2),
790 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
791 CanProveGreaterEqual(x.Eval(), 0) &&
792 CanProveGreaterEqual((y * c1).Eval(), 0));
793
794 // canonicalization: x % c == x % (-c) for truncated division
795 // NOTE: trunc div required
796 TVM_TRY_RECURSIVE_REWRITE_IF(
797 truncmod(x, c1), truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))),
798 c1.Eval()->value < 0);
799
800 // try modular analysis
801 if (truncmod(x, c1).Match(ret)) {
802 ModularSet mod = analyzer_->modular_set(x.Eval());
803 int64_t c1val = c1.Eval()->value;
804 if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) {
805 return truncmod(mod->base, c1).Eval();
806 }
807 }
808 }
809 return ret;
810}
811
812PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
813 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
814 op = ret.as<FloorDivNode>();
815 if (auto const_res = TryConstFold<FloorDiv>(op->a, op->b)) return const_res.value();
816 // Pattern var to match any expression
817 PVar<PrimExpr> x, y, z, b1;
818 // Pattern var match IntImm
819 PVar<IntImm> c1, c2, c3;
820 // Pattern var for lanes in broadcast and ramp
821 PVar<int> lanes;
822
823 // Vector rules
824 if (op->dtype.lanes() != 1) {
825 TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
826 broadcast(floordiv(x, y), lanes));
827 // ramp // bcast
828 if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
829 int64_t c1val = c1.Eval()->value;
830 int64_t c2val = c2.Eval()->value;
831 ICHECK(c2val != 0) << "division by zero";
832 if (c1val % c2val == 0) {
833 return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
834 }
835 // If all possible indices in ramp are the same.
836 ModularSet bmod = analyzer_->modular_set(b1.Eval());
837 int64_t ramp_min = floordiv(bmod->base, c2val);
838 int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
839 if (ramp_min == ramp_max) {
840 // If b1 can devide c2
841 if (bmod->coeff % c2val == 0) {
842 return broadcast(floordiv(b1, c2), lanes).Eval();
843 }
844 // If all indices can be guaranteed to settle inside a coeff range
845 if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
846 return broadcast(floordiv(b1, c2), lanes).Eval();
847 }
848 }
849 }
850 }
851
852 if (IsIndexType(op->dtype)) {
853 // Be-aware of the division rules: this is floor division.
854 TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2),
855 c1.Eval()->value > 0 && c2.Eval()->value > 0);
856
857 TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3),
858 c1.Eval()->value > 0 && c3.Eval()->value > 0);
859
860 if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) ||
861 floordiv(y + x * c1, c2).Match(ret)) {
862 int64_t c1val = c1.Eval()->value;
863 int64_t c2val = c2.Eval()->value;
864 PrimExpr yval = y.EvalOr(Integer(0));
865 if (c2val == 0) return ret;
866
867 // try eliminate residue part
868 PrimExpr residue =
869 floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val);
870 PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val);
871 auto bound = analyzer_->const_int_bound(residue);
872 if (bound.defined() && bound->max_value == bound->min_value) {
873 return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value));
874 }
875
876 // try simplify divisor
877 if (c1val > 0 && c2val > 0 && c2val % c1val == 0 &&
878 CanProveLess(floormod(yval, c2val), c1val)) {
879 // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then
880 // (x * c1 + y) // c2
881 // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1)
882 // ==> x' + d + (b * c1 + e) // c2
883 // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1
884 // ==> x // (c2 // c1) + (y // c2)
885 return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div;
886 }
887 }
888
889 TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
890 TVM_TRY_REWRITE(floordiv(x * c1, x), c1);
891 TVM_TRY_REWRITE(floordiv(c1 * x, x), c1);
892
893 // Rules involving 2-operands.
894 TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
895 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
896
897 TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)),
898 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
899
900 TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
901 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
902
903 TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)),
904 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
905
906 // Rules involving 3-operands.
907 TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
908 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
909 TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), floordiv(x, floordiv(c2, c1)),
910 c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
911 c2.Eval()->value % c1.Eval()->value == 0 &&
912 CanProveEqual(floordiv(y.Eval() + z.Eval(), c1.Eval()), 0));
913
914 TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2),
915 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
916
917 TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2),
918 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
919
920 TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
921 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
922
923 TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
924 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
925
926 TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);
927
928 TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
929
930 TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
931
932 TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1,
933 CanProveGreaterEqual(x.Eval(), 0));
934 TVM_TRY_REWRITE_IF(floordiv((y + x) + z, x), floordiv(y + z, x) + 1,
935 CanProveGreaterEqual(x.Eval(), 0));
936 TVM_TRY_REWRITE_IF(floordiv(y + (z + x), x), floordiv(y + z, x) + 1,
937 CanProveGreaterEqual(x.Eval(), 0));
938 TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1,
939 CanProveGreaterEqual(x.Eval(), 0));
940
941 TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0));
942 TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0));
943
944 TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z),
945 CanProveGreaterEqual(z.Eval(), 0));
946 TVM_TRY_REWRITE_IF(floordiv(z * x + y, z), x + floordiv(y, z),
947 CanProveGreaterEqual(z.Eval(), 0));
948 TVM_TRY_REWRITE_IF(floordiv(y + x * z, z), floordiv(y, z) + x,
949 CanProveGreaterEqual(z.Eval(), 0));
950 TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x,
951 CanProveGreaterEqual(z.Eval(), 0));
952
953 TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0);
954 }
955 return ret;
956}
957
958PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
959 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
960 op = ret.as<FloorModNode>();
961 if (auto const_res = TryConstFold<FloorMod>(op->a, op->b)) return const_res.value();
962
963 // Pattern var to match any expression
964 PVar<PrimExpr> x, y, z, b1;
965 // Pattern var match IntImm
966 PVar<IntImm> c1, c2;
967 // Pattern var for lanes in broadcast and ramp
968 PVar<int> lanes;
969
970 // Vector rules
971 if (op->dtype.lanes() != 1) {
972 TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)),
973 broadcast(floormod(x, y), lanes));
974
975 // floormod(ramp, bcast)
976 if (floormod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
977 int64_t c1val = c1.Eval()->value;
978 int64_t c2val = c2.Eval()->value;
979 ICHECK(c2val != 0) << "division by zero";
980 if (c1val % c2val == 0) {
981 return broadcast(floormod(b1, c2), lanes).Eval();
982 }
983 // If all possible indices in ramp are the same.
984 ModularSet bmod = analyzer_->modular_set(b1.Eval());
985 int64_t ramp_min = floordiv(bmod->base, c2val);
986 int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
987 if (ramp_min == ramp_max) {
988 // If b1 can devide c2
989 if (bmod->coeff % c2val == 0) {
990 return ramp(floormod(bmod->base, c2), c1, lanes).Eval();
991 }
992 // If all indices can be guaranteed to settle inside a coeff range
993 if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
994 return ramp(floormod(b1, c2), c1, lanes).Eval();
995 }
996 }
997 if (bmod->coeff % c2val == 0) {
998 return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
999 }
1000 }
1001 }
1002
1003 if (IsIndexType(op->dtype)) {
1004 // Be-aware of the division rules: we use floordiv/floormod here
1005 TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2),
1006 c2.Eval()->value != 0);
1007
1008 TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y,
1009 c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
1010 c2.Eval()->value % c1.Eval()->value == 0 &&
1011 CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
1012
1013 TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
1014 c2.Eval()->value > 0);
1015
1016 TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
1017 c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
1018
1019 TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
1020 c2.Eval()->value > 0);
1021
1022 TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
1023
1024 TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
1025 TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));
1026
1027 // try modular analysis
1028 if (floormod(x, c1).Match(ret)) {
1029 ModularSet mod = analyzer_->modular_set(x.Eval());
1030 int64_t c1val = c1.Eval()->value;
1031 if (mod->coeff % c1val == 0 && c1val > 0) {
1032 return floormod(mod->base, c1).Eval();
1033 }
1034 }
1035 }
1036 return ret;
1037}
1038
1039PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
1040 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1041 op = ret.as<MinNode>();
1042 if (auto const_res = TryConstFold<Min>(op->a, op->b)) return const_res.value();
1043
1044 // Pattern var to match any expression
1045 PVar<PrimExpr> x, y, z, s1, s2;
1046 // Pattern var match IntImm
1047 PVar<IntImm> c1, c2;
1048 PVar<int> lanes;
1049
1050 // vector rule
1051 if (op->dtype.lanes() != 1) {
1052 TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes));
1053 TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
1054 min(x, broadcast(min(y, z), lanes)));
1055 }
1056 if (IsIndexType(op->dtype)) {
1057 TVM_TRY_REWRITE(min(x, x), x);
1058
1059 // constant int bound
1060 ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
1061 ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
1062 if (a_bound->max_value <= b_bound->min_value) {
1063 return op->a;
1064 }
1065 if (b_bound->max_value <= a_bound->min_value) {
1066 return op->b;
1067 }
1068
1069 // constant comparison
1070 if (min(x + c1, x + c2).Match(ret)) {
1071 if (c1.Eval()->value < c2.Eval()->value) {
1072 return (x + c1).Eval();
1073 } else {
1074 return (x + c2).Eval();
1075 }
1076 }
1077 if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) {
1078 if (c1.Eval()->value < 0) {
1079 return (x + c1).Eval();
1080 } else {
1081 return x.Eval();
1082 }
1083 }
1084 if (min(c1 - x, c2 - x).Match(ret)) {
1085 if (c1.Eval()->value < c2.Eval()->value) {
1086 return (c1 - x).Eval();
1087 } else {
1088 return (c2 - x).Eval();
1089 }
1090 }
1091
1092 // DivMod rules
1093 // Divide up rounding: truc div
1094 // NOTE: trucdiv(x, y) >= floordiv(x, y)
1095 TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x,
1096 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1097 TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
1098 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
1099 CanProveGreaterEqual(x.Eval(), 1));
1100
1101 TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
1102 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1103 TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
1104 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
1105 CanProveGreaterEqual(x.Eval(), 1));
1106
1107 // Divide up rounding: floor div
1108 TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x,
1109 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1110 TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
1111 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
1112 CanProveGreaterEqual(x.Eval(), 1));
1113
1114 TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x,
1115 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1116 TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2),
1117 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
1118 CanProveGreaterEqual(x.Eval(), 1));
1119
1120 TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0);
1121 TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0);
1122
1123 TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
1124 TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
1125 TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y));
1126 TVM_TRY_REWRITE(min(min(x, y), max(y, x)), min(x, y));
1127
1128 TVM_TRY_REWRITE(min(max(x, y), x), x);
1129 TVM_TRY_REWRITE(min(max(x, y), y), y);
1130 TVM_TRY_REWRITE(min(min(x, y), x), min(x, y));
1131 TVM_TRY_REWRITE(min(min(x, y), y), min(x, y));
1132
1133 TVM_TRY_REWRITE(min(x, max(x, y)), x);
1134 TVM_TRY_REWRITE(min(y, max(x, y)), y);
1135 TVM_TRY_REWRITE(min(x, min(x, y)), min(x, y));
1136 TVM_TRY_REWRITE(min(y, min(x, y)), min(x, y));
1137
1138 TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z));
1139 TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1));
1140 TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y),
1141 min(min(min(min(x, y), z), s1), s2));
1142
1143 TVM_TRY_REWRITE(min(max(x, y), max(x, z)), max(min(y, z), x));
1144 TVM_TRY_REWRITE(min(max(x, y), max(z, x)), max(min(y, z), x));
1145 TVM_TRY_REWRITE(min(max(y, x), max(x, z)), max(min(y, z), x));
1146 TVM_TRY_REWRITE(min(max(y, x), max(z, x)), max(min(y, z), x));
1147
1148 TVM_TRY_REWRITE(min(min(x, y), min(x, z)), min(min(y, z), x));
1149 TVM_TRY_REWRITE(min(min(x, y), min(z, x)), min(min(y, z), x));
1150 TVM_TRY_REWRITE(min(min(y, x), min(x, z)), min(min(y, z), x));
1151 TVM_TRY_REWRITE(min(min(y, x), min(z, x)), min(min(y, z), x));
1152
1153 TVM_TRY_REWRITE(min(y + x, z + x), min(y, z) + x);
1154 TVM_TRY_REWRITE(min(y + x, x + z), min(y, z) + x);
1155 TVM_TRY_REWRITE(min(x + y, x + z), min(y, z) + x);
1156 TVM_TRY_REWRITE(min(x + y, z + x), min(y, z) + x);
1157
1158 // sub distribution
1159 TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x);
1160 TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z));
1161
1162 // constant folding rule.
1163 TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));
1164
1165 // scaling rule
1166 if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
1167 if (c1.Eval()->value > 0) {
1168 return truncdiv(min(x, y), c1).Eval();
1169 } else {
1170 return truncdiv(max(x, y), c1).Eval();
1171 }
1172 }
1173 if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
1174 if (c1.Eval()->value > 0) {
1175 return floordiv(min(x, y), c1).Eval();
1176 } else {
1177 return floordiv(max(x, y), c1).Eval();
1178 }
1179 }
1180 if (min(x * c1, y * c1).Match(ret)) {
1181 if (c1.Eval()->value > 0) {
1182 return (min(x, y) * c1).Eval();
1183 } else {
1184 return (max(x, y) * c1).Eval();
1185 }
1186 }
1187 if (min(x * c1, c2).Match(ret)) {
1188 int64_t c1val = c1.Eval()->value;
1189 int64_t c2val = c2.Eval()->value;
1190 if (c1val == 0) {
1191 return c2val < 0 ? c2.Eval() : c1.Eval();
1192 }
1193 if (c2val % c1val == 0) {
1194 if (c1val > 0) {
1195 return (min(x, c2val / c1val) * c1val).Eval();
1196 } else {
1197 return (max(x, c2val / c1val) * c1val).Eval();
1198 }
1199 }
1200 }
1201
1202 // canonicalization
1203 TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
1204 TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0);
1205 }
1206
1207 // condition rules.
1208 TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2)));
1209 return ret;
1210}
1211
1212PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
1213 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1214 op = ret.as<MaxNode>();
1215 if (auto const_res = TryConstFold<Max>(op->a, op->b)) return const_res.value();
1216
1217 // Pattern var to match any expression
1218 PVar<PrimExpr> x, y, z, s1, s2;
1219 // Pattern var match IntImm
1220 PVar<IntImm> c1, c2;
1221 PVar<int> lanes;
1222
1223 // vector rule
1224 if (op->dtype.lanes() != 1) {
1225 TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes));
1226 TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
1227 max(x, broadcast(max(y, z), lanes)));
1228 }
1229 if (IsIndexType(op->dtype)) {
1230 TVM_TRY_REWRITE(max(x, x), x);
1231
1232 // constant int bound
1233 ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
1234 ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
1235 if (a_bound->min_value >= b_bound->max_value) {
1236 return op->a;
1237 }
1238 if (b_bound->min_value >= a_bound->max_value) {
1239 return op->b;
1240 }
1241
1242 // constant comparison
1243 if (max(x + c1, x + c2).Match(ret)) {
1244 if (c1.Eval()->value > c2.Eval()->value) {
1245 return (x + c1).Eval();
1246 } else {
1247 return (x + c2).Eval();
1248 }
1249 }
1250 if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) {
1251 if (c1.Eval()->value > 0) {
1252 return (x + c1).Eval();
1253 } else {
1254 return x.Eval();
1255 }
1256 }
1257 if (max(c1 - x, c2 - x).Match(ret)) {
1258 if (c1.Eval()->value > c2.Eval()->value) {
1259 return (c1 - x).Eval();
1260 } else {
1261 return (c2 - x).Eval();
1262 }
1263 }
1264
1265 // DivMod rules
1266 // Divide up rounding: truc div
1267 // NOTE: trucdiv(x, y) >= floordiv(x, y)
1268 TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2,
1269 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1270 TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2,
1271 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1272
1273 // Divide up rounding: floor div
1274 TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2,
1275 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1276 TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2,
1277 c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
1278
1279 TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0);
1280 TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0);
1281
1282 TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
1283 TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
1284 TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y));
1285 TVM_TRY_REWRITE(max(max(x, y), min(y, x)), max(x, y));
1286
1287 TVM_TRY_REWRITE(max(min(x, y), x), x);
1288 TVM_TRY_REWRITE(max(min(x, y), y), y);
1289 TVM_TRY_REWRITE(max(max(x, y), x), max(x, y));
1290 TVM_TRY_REWRITE(max(max(x, y), y), max(x, y));
1291
1292 TVM_TRY_REWRITE(max(x, min(x, y)), x);
1293 TVM_TRY_REWRITE(max(y, min(x, y)), y);
1294 TVM_TRY_REWRITE(max(x, max(x, y)), max(x, y));
1295 TVM_TRY_REWRITE(max(y, max(x, y)), max(x, y));
1296
1297 TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z));
1298 TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1));
1299 TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y),
1300 max(max(max(max(x, y), z), s1), s2));
1301
1302 // max/max cancelation
1303 TVM_TRY_REWRITE(max(max(x, y), max(x, z)), max(max(y, z), x));
1304 TVM_TRY_REWRITE(max(max(x, y), max(z, x)), max(max(y, z), x));
1305 TVM_TRY_REWRITE(max(max(y, x), max(x, z)), max(max(y, z), x));
1306 TVM_TRY_REWRITE(max(max(y, x), max(z, x)), max(max(y, z), x));
1307
1308 // max/min distribution
1309 TVM_TRY_REWRITE(max(min(x, y), min(x, z)), min(max(y, z), x));
1310 TVM_TRY_REWRITE(max(min(x, y), min(z, x)), min(max(y, z), x));
1311 TVM_TRY_REWRITE(max(min(y, x), min(x, z)), min(max(y, z), x));
1312 TVM_TRY_REWRITE(max(min(y, x), min(z, x)), min(max(y, z), x));
1313
1314 // add distribution
1315 TVM_TRY_REWRITE(max(y + x, z + x), max(y, z) + x);
1316 TVM_TRY_REWRITE(max(y + x, x + z), max(y, z) + x);
1317 TVM_TRY_REWRITE(max(x + y, x + z), max(y, z) + x);
1318 TVM_TRY_REWRITE(max(x + y, z + x), max(y, z) + x);
1319
1320 // sub distribution
1321 TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x);
1322 TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z));
1323
1324 // constant folding rule.
1325 TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
1326
1327 // scaling rule
1328 if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
1329 if (c1.Eval()->value > 0) {
1330 return truncdiv(max(x, y), c1).Eval();
1331 } else {
1332 return truncdiv(min(x, y), c1).Eval();
1333 }
1334 }
1335 if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
1336 if (c1.Eval()->value > 0) {
1337 return floordiv(max(x, y), c1).Eval();
1338 } else {
1339 return floordiv(min(x, y), c1).Eval();
1340 }
1341 }
1342 if (max(x * c1, y * c1).Match(ret)) {
1343 if (c1.Eval()->value > 0) {
1344 return (max(x, y) * c1).Eval();
1345 } else {
1346 return (min(x, y) * c1).Eval();
1347 }
1348 }
1349 if (max(x * c1, c2).Match(ret)) {
1350 int64_t c1val = c1.Eval()->value;
1351 int64_t c2val = c2.Eval()->value;
1352 if (c1val == 0) {
1353 return c2val > 0 ? c2.Eval() : c1.Eval();
1354 }
1355 if (c2val % c1val == 0) {
1356 if (c1val > 0) {
1357 return (max(x, c2val / c1val) * c1val).Eval();
1358 } else {
1359 return (min(x, c2val / c1val) * c1val).Eval();
1360 }
1361 }
1362 }
1363
1364 // canonicalization
1365 TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
1366 TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
1367 }
1368
1369 // condition rules.
1370 TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2)));
1371 return ret;
1372}
1373
1374Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const {
1375 PrimExpr negation = Not(expr);
1376
1377 ExprDeepEqual expr_equal;
1378 for (const auto& constraint : literal_constraints_) {
1379 if (expr_equal(constraint, expr)) {
1380 return make_const(expr->dtype, true);
1381 }
1382 if (expr_equal(constraint, negation)) {
1383 return make_const(expr->dtype, false);
1384 }
1385 }
1386 return NullOpt;
1387}
1388
1389PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
1390 EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op));
1391 op = ret.get();
1392
1393 if (auto const_res = TryConstFold<EQ>(op->a, op->b)) {
1394 return const_res.value();
1395 }
1396 if (auto match = TryMatchLiteralConstraint(ret)) {
1397 return match.value();
1398 }
1399
1400 return ApplyRewriteRules(ret);
1401}
1402
1403PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
1404 // Pattern var to match any expression
1405 PVar<PrimExpr> x, y;
1406 // Pattern var match IntImm
1407 PVar<IntImm> c1;
1408 PVar<int> lanes;
1409
1410 // vector rule
1411 if (ret->dtype.lanes() != 1) {
1412 TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
1413 }
1414
1415 if (IsIndexType(ret->a.dtype())) {
1416 CompareResult result = TryCompare(ret->a, ret->b);
1417 if (result == CompareResult::kEQ) {
1418 return make_const(ret->dtype, true);
1419 } else if (result == CompareResult::kNE || result == CompareResult::kGT ||
1420 result == CompareResult::kLT) {
1421 return make_const(ret->dtype, false);
1422 }
1423 TVM_TRY_REWRITE(c1 == x, x == c1);
1424
1425 TVM_TRY_REWRITE(x - c1 == 0, x == c1);
1426 TVM_TRY_REWRITE(c1 - x == 0, x == c1);
1427 TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
1428 TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);
1429 }
1430 return std::move(ret);
1431}
1432
1433PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) {
1434 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1435 op = ret.as<NENode>();
1436
1437 if (auto const_res = TryConstFold<NE>(op->a, op->b)) return const_res.value();
1438 if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
1439
1440 if (IsIndexType(op->a.dtype())) {
1441 CompareResult result = TryCompare(op->a, op->b);
1442 if (result == CompareResult::kNE || result == CompareResult::kGT ||
1443 result == CompareResult::kLT) {
1444 return make_const(op->dtype, true);
1445 } else if (result == CompareResult::kEQ) {
1446 return make_const(op->dtype, false);
1447 } else if (result == CompareResult::kGE) {
1448 // Known: a >= b
1449 //
1450 // a != b
1451 // (a < b) or (b < a)
1452 // False or (b < a)
1453 // b < a
1454 return ApplyRewriteRules(LT(op->b, op->a));
1455 } else if (result == CompareResult::kLE) {
1456 // Known: a <= b
1457 //
1458 // a != b
1459 // (a < b) or (b < a)
1460 // (a < b) or False
1461 // a < b
1462 return ApplyRewriteRules(LT(op->a, op->b));
1463 }
1464 }
1465
1466 return ApplyRewriteRules(Not(ApplyRewriteRules(EQ(op->a, op->b))));
1467}
1468
1469PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) {
1470 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1471 op = ret.as<LENode>();
1472 ICHECK(op);
1473
1474 if (auto const_res = TryConstFold<LE>(op->a, op->b)) return const_res.value();
1475 if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
1476
1477 // Check for applicable rewrites before attempting to prove/disprove
1478 // the inequality. This preserves earlier behavior, where (A<=B*x)
1479 // simplifies to (ceildiv(A,B)<=x) when (A%B!=0). Performing the
1480 // TryCompare first would simplify to the equivalent
1481 // (floordiv(A,B)<x) in these cases instead.
1482 ret = ApplyRewriteRules(Not(ApplyRewriteRules(LT(op->b, op->a))));
1483
1484 if (auto op = ret.as<LENode>(); op && IsIndexType(op->a.dtype())) {
1485 CompareResult result = TryCompare(op->a, op->b);
1486 if (result == CompareResult::kLE || result == CompareResult::kLT ||
1487 result == CompareResult::kEQ) {
1488 return make_const(op->dtype, true);
1489 } else if (result == CompareResult::kGT) {
1490 return make_const(op->dtype, false);
1491 } else if (result == CompareResult::kNE) {
1492 // Known: a != b
1493 //
1494 // a <= b
1495 // (a < b) or (a == b)
1496 // (a < b) or False
1497 // a < b
1498 return ApplyRewriteRules(LT(op->a, op->b));
1499 } else if (result == CompareResult::kGE) {
1500 // Known: a >= b
1501 //
1502 // a <= b
1503 // (a < b) or (a == b)
1504 // False or (a == b)
1505 // a == b
1506 return ApplyRewriteRules(EQ(op->a, op->b));
1507 }
1508 }
1509
1510 return ret;
1511}
1512
1513PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) {
1514 return this->VisitExpr(op->b < op->a);
1515}
1516
1517PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) {
1518 return this->VisitExpr(op->b <= op->a);
1519}
1520
1521PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
1522 LT node = Downcast<LT>(IRMutatorWithAnalyzer::VisitExpr_(op));
1523 op = node.get();
1524
1525 if (auto const_res = TryConstFold<LT>(op->a, op->b)) return const_res.value();
1526 if (auto match = TryMatchLiteralConstraint(node)) return match.value();
1527
1528 return ApplyRewriteRules(node);
1529}
1530
1531PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
1532 // Pattern var to match any expression
1533 PVar<PrimExpr> x, y, z, s1, s2;
1534 // Pattern var match IntImm
1535 PVar<IntImm> c1, c2;
1536 PVar<int> lanes;
1537
1538 // vector rule
1539 if (ret->dtype.lanes() != 1) {
1540 TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes));
1541 TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes));
1542 }
1543
1544 if (IsIndexType(ret->a.dtype())) {
1545 CompareResult result = TryCompare(ret->a, ret->b);
1546 if (result == CompareResult::kLT) {
1547 return make_const(ret->dtype, true);
1548 }
1549 if (result == CompareResult::kEQ || result == CompareResult::kGT ||
1550 result == CompareResult::kGE) {
1551 return make_const(ret->dtype, false);
1552 }
1553
1554 // clang-format off
1555 TVM_TRY_REWRITE(x + y < x + z, y < z);
1556 TVM_TRY_REWRITE(x + y < z + x, y < z);
1557 TVM_TRY_REWRITE(y + x < x + z, y < z);
1558 TVM_TRY_REWRITE(y + x < z + x, y < z);
1559 TVM_TRY_REWRITE(y - x < z - x, y < z);
1560 TVM_TRY_REWRITE(x - y < x - z, z < y);
1561
1562 TVM_TRY_REWRITE(x < x + z, 0 < z);
1563 TVM_TRY_REWRITE(x < z + x, 0 < z);
1564 TVM_TRY_REWRITE(x < x - z, z < 0);
1565 TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x);
1566 TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1);
1567
1568 TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0);
1569 TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0);
1570
1571 // constant cancelation: only need to make use of one mod
1572 // truc div
1573 TVM_TRY_REWRITE_IF(x * c2 < c1,
1574 x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0);
1575 // NOTE: trunc div required
1576 TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2),
1577 c1.Eval()->value <= 0 && c2.Eval()->value > 0);
1578 // NOTE: trunc div required (euclidean is ok too, floored is not)
1579 TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 &&
1580 c2.Eval()->value < 0);
1581 // NOTE: trunc div required (floored is ok too, euclidean is not)
1582 TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x,
1583 c1.Eval()->value <= 0 && c2.Eval()->value < 0);
1584 // NOTE: trunc div required
1585 TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x,
1586 c1.Eval()->value < 0 && c2.Eval()->value > 0);
1587 TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x,
1588 c1.Eval()->value >= 0 && c2.Eval()->value > 0);
1589 // NOTE: trunc div required (floored is ok too, euclidean is not)
1590 TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1,
1591 c1.Eval()->value < 0 && c2.Eval()->value < 0);
1592 // NOTE: trunc div required (euclidean is ok too, floored is not)
1593 TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2),
1594 c1.Eval()->value >= 0 && c2.Eval()->value < 0);
1595 // DivMod rules
1596 // trucdiv
1597 TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2,
1598 x<c1 * c2, c1.Eval()->value> 0 && c2.Eval()->value > 0);
1599 // NOTE: trunc div required
1600 TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2,
1601 x<c1*(c2 - 1) + 1, c1.Eval()->value> 0 && c2.Eval()->value <= 0);
1602
1603 TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x,
1604 c1.Eval()->value >= 0 && c2.Eval()->value > 0);
1605 // NOTE: trunc div required
1606 TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x,
1607 c1.Eval()->value < 0 && c2.Eval()->value > 0);
1608
1609 // invariance for any div mod: x - (x / c1) * c1 == x % c1
1610 TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0);
1611 TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y,
1612 0 < truncmod(x, c1) + y, c1.Eval()->value > 0);
1613 TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y,
1614 y < truncmod(x, c1), c1.Eval()->value > 0);
1615
1616 TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x,
1617 c2 < truncmod(x + c2, c1), c1.Eval()->value > 0);
1618 TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y,
1619 c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0);
1620 TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y,
1621 y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0);
1622
1623 // floordiv
1624 TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0);
1625 TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0);
1626
1627 TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0);
1628 TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y,
1629 0 < floormod(x, c1) + y, c1.Eval()->value > 0);
1630 TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y,
1631 y < floormod(x, c1), c1.Eval()->value > 0);
1632 TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x,
1633 c2 < floormod(x + c2, c1), c1.Eval()->value > 0);
1634 TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y,
1635 c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0);
1636 TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y,
1637 y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0);
1638
1639 // canonicalization rule
1640 TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z);
1641 TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z);
1642 TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y);
1643 TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y);
1644
1645 TVM_TRY_RECURSIVE_REWRITE(x < c1 - y, x + y < c1);
1646 TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1);
1647 TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y);
1648 TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);
1649
1650 TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
1651 TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
1652 TVM_TRY_REWRITE(x - c1 < 0, x < c1);
1653
1654 TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y);
1655 TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y);
1656 TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y);
1657 TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y);
1658 // clang-format on
1659 }
1660 return std::move(ret);
1661}
1662
1663PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
1664 Not ret = Downcast<Not>(IRMutatorWithAnalyzer::VisitExpr_(op));
1665 if (auto const_res = TryConstFold<Not>(ret->a)) return const_res.value();
1666 if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
1667
1668 return ApplyRewriteRules(ret);
1669}
1670
1671PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) {
1672 // Pattern var to match any expression
1673 PVar<PrimExpr> x, y;
1674 PVar<int> lanes;
1675 if (ret->dtype.lanes() != 1) {
1676 TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
1677 }
1678
1679 TVM_TRY_REWRITE(!(!x), x);
1680 TVM_TRY_REWRITE(!(x <= y), y < x);
1681 TVM_TRY_REWRITE(!(x >= y), x < y);
1682 TVM_TRY_REWRITE(!(x < y), y <= x);
1683 TVM_TRY_REWRITE(!(x > y), x <= y);
1684 TVM_TRY_REWRITE(!(x == y), x != y);
1685 TVM_TRY_REWRITE(!(x != y), x == y);
1686 TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
1687 TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
1688 return std::move(ret);
1689}
1690
/*!
 * \brief Simplify a logical conjunction.
 *
 * The operands are simplified first: either by the base mutator, or, when
 * kApplyConstraintsToBooleanBranches is enabled, by visiting each operand
 * while the other operand is entered as a constraint.  The result is then
 * resolved by constant folding, by a literal-constraint match, optionally
 * by SimplifyAsAndOfOrs, and finally by the rewrite rules below.  Rules
 * are tried in order and the first match wins.
 *
 * \param op The conjunction node being visited.
 * \return The simplified expression.
 */
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
  PrimExpr ret = [&]() -> PrimExpr {
    // If this extension isn't enabled, just delegate out.
    if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
      return IRMutatorWithAnalyzer::VisitExpr_(op);
    }

    PrimExpr a = op->a;
    PrimExpr b = op->b;

    // Alternate which branch is used as the constraint, and which is
    // being simplified.  Because some sub-analyzers expect their
    // constraints to already be simplified, each branch may require
    // more than one update.  The loop condition allows each branch to
    // be visited up to twice, but only performs the second visit if
    // necessary.
    size_t iterations_since_update = 0;
    for (size_t i = 0; i < 4; i++) {
      // Even iterations simplify `a` under `b`; odd iterations the reverse.
      PrimExpr& to_update = (i % 2 == 0) ? a : b;
      const PrimExpr& constraint = (i % 2 == 0) ? b : a;

      // The context object lives only for this loop iteration.
      With<ConstraintContext> context(analyzer_, constraint);
      PrimExpr updated = VisitExpr(to_update);

      if (!to_update.same_as(updated)) {
        to_update = updated;
        iterations_since_update = 0;
      } else {
        // Two consecutive visits without a change means both branches
        // are stable, so stop early.
        iterations_since_update++;
        if (iterations_since_update >= 2) {
          break;
        }
      }
    }

    // Only construct a new object if a change has been made.
    // Otherwise, follow ExprMutator's convention of returning the
    // original object.
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      return And(a, b);
    }
  }();

  // From here on `op` refers to the (possibly rebuilt) node held by `ret`.
  op = ret.as<AndNode>();

  if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
  if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
  // Hand over to the and-of-ors simplification when that extension is
  // enabled, unless this visit is already part of such a simplification.
  if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) &&
      !recursively_visiting_boolean_) {
    return SimplifyAsAndOfOrs(ret, analyzer_);
  }

  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  PVar<int> lanes;

  // vector rule: combine the scalars, then broadcast
  if (op->dtype.lanes() != 1) {
    TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
  }

  // Directly contradictory operands fold to false.
  auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
  TVM_TRY_REWRITE(x == y && x != y, cfalse);
  TVM_TRY_REWRITE(x != y && x == y, cfalse);
  TVM_TRY_REWRITE(x && !x, cfalse);
  TVM_TRY_REWRITE(x <= y && y < x, cfalse);
  TVM_TRY_REWRITE(y < x && x <= y, cfalse);

  // An empty integer range between two constant bounds folds to false.
  TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);

  TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);

  TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);

  // With x pinned to c1, the inequality only depends on the constants.
  TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
  TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);

  // A quotient and a remainder for the same divisor identify x.
  TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3);
  TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3);

  // 0 <= x - y*c1 < c1 characterizes y as the floored quotient of x by c1.
  TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 &&
                               x - y * c1<c1, y == floordiv(x, c1), c1.Eval()->value> 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1),
                               c1.Eval()->value > 0);

  // NOTE(review): the second rule below repeats the pattern of the guarded
  // rule just above, but without the c1 > 0 condition.  Confirm whether it
  // is still reachable.
  TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1));
  TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1));
  // Same shape written as an addition with c2 == -c1.
  TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1),
                               c2.Eval()->value == -c1.Eval()->value);
  TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == floordiv(x, c1),
                               c2.Eval()->value == -c1.Eval()->value);

  // Tighten the upper bound on x using the bound on floormod(x, c2).
  // NOTE(review): the conditions below evaluate `% c2` on the matched
  // constants; this assumes c2 is never zero here.  Confirm.
  TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
                               x < c1 - c2 + c3 && floormod(x, c2) < c3,
                               c1.Eval()->value % c2.Eval()->value == 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
      (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
          c3.Eval()->value);

  TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3,
                               x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
                               (c1.Eval()->value + 1) % c2.Eval()->value == 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
      (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
          c3.Eval()->value);

  // A fixed quotient plus a bound on the remainder becomes a range on x.
  TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3,
                            c1 * c2 <= x && x < c1 * c2 + c3);
  TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1,
                            c1 * c2 <= x && x < c1 * c2 + c3);
  TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
                            c1 * c2 <= x && x <= c1 * c2 + c3);
  TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1,
                            c1 * c2 <= x && x <= c1 * c2 + c3);

  TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
                            c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
  TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1,
                            c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
  TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2),
                            c1 * c2 + c3 < x && x < (c1 + 1) * c2);
  TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1,
                            c1 * c2 + c3 < x && x < (c1 + 1) * c2);

  // Re-associate nested conjunctions to the left.
  TVM_TRY_RECURSIVE_REWRITE(x && (y && z), (x && y) && z);

  return ret;
}
1829
1830PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
1831 PrimExpr orig = GetRef<PrimExpr>(op);
1832
1833 PrimExpr ret = [&]() -> PrimExpr {
1834 // If this extension isn't enabled, just delegate out.
1835 if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
1836 return IRMutatorWithAnalyzer::VisitExpr_(op);
1837 }
1838
1839 PrimExpr a = op->a;
1840 PrimExpr b = op->b;
1841
1842 // Alternate which branch is used as the constraint, and which
1843 // is being simplified. Because some sub-analyzers expect their
1844 // constraints to already be simplified, each branch may require
1845 // more than update. The loop condition allows each branch to be
1846 // visited up to twice, but only if performs the second visit if
1847 // necessary.
1848 size_t iterations_since_update = 0;
1849 for (size_t i = 0; i < 4; i++) {
1850 PrimExpr& to_update = (i % 2 == 0) ? a : b;
1851 const PrimExpr& constraint = (i % 2 == 0) ? b : a;
1852
1853 With<ConstraintContext> context(analyzer_, NormalizeBooleanOperators(Not(constraint)));
1854 PrimExpr updated = VisitExpr(to_update);
1855
1856 if (!to_update.same_as(updated)) {
1857 to_update = updated;
1858 iterations_since_update = 0;
1859 } else {
1860 iterations_since_update++;
1861 if (iterations_since_update >= 2) {
1862 break;
1863 }
1864 }
1865 }
1866
1867 // Only construct a new object if a change has been made.
1868 // Otherwise, follow ExprMutator's convention of returning the
1869 // original object.
1870 if (a.same_as(op->a) && b.same_as(op->b)) {
1871 return GetRef<PrimExpr>(op);
1872 } else {
1873 return Or(a, b);
1874 }
1875 }();
1876
1877 op = ret.as<OrNode>();
1878 if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();
1879 if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
1880 if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) &&
1881 !recursively_visiting_boolean_) {
1882 return SimplifyAsAndOfOrs(ret, analyzer_);
1883 }
1884
1885 // Pattern var to match any expression
1886 PVar<PrimExpr> x, y, z;
1887 // Pattern var match IntImm
1888 PVar<IntImm> c1, c2;
1889 PVar<int> lanes;
1890
1891 if (op->dtype.lanes() != 1) {
1892 TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
1893 }
1894
1895 auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));
1896
1897 TVM_TRY_REWRITE(x == y || x != y, ctrue);
1898 TVM_TRY_REWRITE(x != y || x == y, ctrue);
1899 TVM_TRY_REWRITE(x || !x, ctrue);
1900 TVM_TRY_REWRITE(x <= y || y < x, ctrue);
1901 TVM_TRY_REWRITE(y < x || x <= y, ctrue);
1902
1903 TVM_TRY_REWRITE(x < y || y < x, x != y);
1904
1905 TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
1906 TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);
1907
1908 TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
1909 TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
1910 TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
1911 TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
1912
1913 TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
1914 TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
1915
1916 TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
1917 TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
1918
1919 TVM_TRY_RECURSIVE_REWRITE(x < y || x == y, x <= y);
1920 TVM_TRY_RECURSIVE_REWRITE(x < y || y == x, x <= y);
1921 TVM_TRY_RECURSIVE_REWRITE(x == y || x < y, x <= y);
1922 TVM_TRY_RECURSIVE_REWRITE(y == x || x < y, x <= y);
1923
1924 TVM_TRY_RECURSIVE_REWRITE(x || (y || z), (x || y) || z);
1925
1926 return ret;
1927}
1928
1929PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) {
1930 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1931 op = ret.as<SelectNode>();
1932 if (op == nullptr) return ret;
1933 // Pattern var to match any expression
1934 PVar<PrimExpr> x, y;
1935 TVM_TRY_REWRITE(select(x, y, y), y);
1936 return ret;
1937}
1938
1939PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
1940 // add condition context to if_then_else
1941 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
1942 op = ret.as<CallNode>();
1943 if (op == nullptr) return ret;
1944
1945 if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
1946 return op->args[0];
1947 } else if (op->op.same_as(tir::builtin::shift_right())) {
1948 if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
1949 // the operator overload will eagerly constant fold.
1950 return op->args[0] >> op->args[1];
1951 }
1952 } else if (op->op.same_as(tir::builtin::shift_left())) {
1953 if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
1954 // the operator overload will eagerly constant fold.
1955 return op->args[0] << op->args[1];
1956 }
1957 } else if (op->op.same_as(Op::Get("tir.ceil"))) {
1958 PrimExpr ceil_arg = op->args[0];
1959 if (auto arg_int = op->args[0].as<IntImmNode>()) {
1960 return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value));
1961 } else if (auto arg_float = ceil_arg.as<FloatImmNode>()) {
1962 return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
1963 } else if (auto arg_call = ceil_arg.as<CallNode>()) {
1964 // ceil(log2(cast(n,"float64"))) is used as the implementation of
1965 // topi.math.ceil_log2, and appears in iteration bounds.
1966 if (arg_call->op.same_as(Op::Get("tir.log2"))) {
1967 PrimExpr log_arg = arg_call->args[0];
1968 if (auto as_float = log_arg.as<FloatImmNode>()) {
1969 // ceil(log2(n)) can be simplified, and should produce the
1970 // same integer result regardless of the target's rounding
1971 // conventions.
1972 return FloatImm(op->dtype, std::ceil(std::log2(as_float->value)));
1973 }
1974 }
1975 }
1976 }
1977
1978 if (op->op.same_as(tir::builtin::likely())) {
1979 // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
1980 if (auto match = TryMatchLiteralConstraint(op->args[0])) {
1981 return match.value();
1982 }
1983 }
1984
1985 if (op->op.same_as(tir::builtin::if_then_else())) {
1986 // Simplify nested if_then_else
1987 // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr }
1988 // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
1989 const PrimExpr& cond = op->args[0];
1990 const PrimExpr& then_expr = op->args[1];
1991 const PrimExpr& else_expr = op->args[2];
1992 const CallNode* inner_call = then_expr.as<CallNode>();
1993 if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) {
1994 const PrimExpr& inner_cond = inner_call->args[0];
1995 const PrimExpr& inner_then_expr = inner_call->args[1];
1996 const PrimExpr& inner_else_expr = inner_call->args[2];
1997 // Only check constant cases to avoid recursion
1998 if (is_const_number(inner_else_expr) && is_const_number(else_expr) &&
1999 analyzer_->CanProve(inner_else_expr == else_expr)) {
2000 return if_then_else(cond && inner_cond, inner_then_expr, else_expr);
2001 }
2002 }
2003 }
2004
2005 return ret;
2006}
2007
2008PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) {
2009 Var var = GetRef<Var>(op);
2010 if (op->dtype == DataType::Bool()) {
2011 if (auto match = TryMatchLiteralConstraint(var)) {
2012 return match.value();
2013 }
2014 }
2015
2016 auto it = var_map_.find(var);
2017 if (it != var_map_.end()) {
2018 return it->second;
2019 }
2020 return GetRef<PrimExpr>(op);
2021}
2022
2023PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) {
2024 PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
2025 op = ret.as<CastNode>();
2026 return cast(op->dtype, op->value);
2027}
2028
2029bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) {
2030 // Only inline trivial bindings to avoid deep expression explosion
2031 // when we need let to construct complicated expressions.
2032 if (is_const_number(op->value)) return true;
2033 if (op->value.as<VarNode>()) return true;
2034 return false;
2035}
2036
2037PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
2038 PrimExpr value = this->VisitExpr(op->value);
2039 if (CanInlineLet(op)) {
2040 // it is fine to discard the let binding
2041 // because the value will always be inlined in the simplifier.
2042 analyzer_->Bind(op->var, value);
2043 return this->VisitExpr(op->body);
2044 }
2045 PrimExpr body = this->VisitExpr(op->body);
2046 if (value.same_as(op->value) && body.same_as(op->body)) {
2047 return GetRef<PrimExpr>(op);
2048 } else {
2049 return Let(op->var, value, body);
2050 }
2051}
2052
2053PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) {
2054 // Run simplification in post order
2055 PrimExpr res = expr;
2056 int max_iter = 2;
2057 for (int i = 0; i < max_iter; ++i) {
2058 PrimExpr new_expr = impl_->operator()(res);
2059 if (new_expr.same_as(res)) return res;
2060 res = new_expr;
2061 }
2062 return res;
2063}
2064
/*!
 * \brief Forward an update for a variable to the implementation object.
 * \param var The variable being updated.
 * \param info The expression associated with the variable.
 * \param allow_override Passed through unchanged to Impl::Update.
 */
void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_override) {
  impl_->Update(var, info, allow_override);
}
2068
/*!
 * \brief Forward a constraint to the implementation object.
 * \param constraint The constraint expression to enter.
 * \return The callable returned by Impl::EnterConstraint.
 *         NOTE(review): presumably invoking it leaves the constraint
 *         scope; confirm against Impl::EnterConstraint.
 */
std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
  return impl_->EnterConstraint(constraint);
}
2072
/*!
 * \brief Forward the set of enabled extensions to the implementation object.
 * \param flags The extension flags to enable.
 */
void RewriteSimplifier::SetEnabledExtensions(Extension flags) {
  impl_->SetEnabledExtensions(flags);
}
/*!
 * \brief Query the enabled extensions from the implementation object.
 * \return The extension flags reported by Impl::GetEnabledExtensions.
 */
RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
  return impl_->GetEnabledExtensions();
}
2079
/*!
 * \brief Construct the simplifier and allocate its implementation object.
 * \param parent The analyzer handed to the Impl constructor.
 * \note The Impl is allocated with new and held through the raw impl_
 *       pointer; the destructor deletes it.
 */
RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}
2081
/*! \brief Destroy the simplifier, deleting the Impl allocated by the constructor. */
RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
2083
2084} // namespace arith
2085} // namespace tvm
2086