1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rewrite_simplify.cc |
22 | * \brief Rewrite-rule based simplification. |
23 | */ |
24 | // Acknowledgement: Most rewrite-rules are from Halide. |
25 | #include "rewrite_simplify.h" |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/tir/builtin.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <algorithm> |
32 | #include <utility> |
33 | |
34 | #include "../target/datatype/registry.h" |
35 | #include "conjunctive_normal_form.h" |
36 | #include "const_fold.h" |
37 | #include "constraint_extract.h" |
38 | #include "pattern_match.h" |
39 | |
40 | namespace tvm { |
41 | namespace arith { |
42 | |
43 | using namespace tir; |
44 | |
// Rewrite helper macros.
//
// Every macro below expects a local variable named `ret` that holds the
// expression being simplified. When the rule fires, the macro returns
// from the enclosing function, so rules listed earlier take priority
// over rules listed later.

// Simple rewrite: if SrcExpr matches `ret`, return the evaluated ResExpr.
#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \
  if ((SrcExpr).Match(ret)) {             \
    return (ResExpr).Eval();              \
  }

// Rewrite, then pass the evaluated ResExpr through RecursiveRewrite.
#define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \
  if ((SrcExpr).Match(ret)) {                       \
    return RecursiveRewrite((ResExpr).Eval());      \
  }

// Rewrite only if CondExpr is true after the match.
// CondExpr is evaluated only when SrcExpr has matched (short-circuit &&).
#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
  if ((SrcExpr).Match(ret) && (CondExpr)) {            \
    return (ResExpr).Eval();                           \
  }

// Rewrite + recursive rewrite only if CondExpr is true after the match.
// CondExpr is evaluated only when SrcExpr has matched (short-circuit &&).
#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
  if ((SrcExpr).Match(ret) && (CondExpr)) {                      \
    return RecursiveRewrite((ResExpr).Eval());                   \
  }
68 | |
69 | // NOTE for developers: |
70 | // |
71 | // We mainly focus on index expression simplification. |
72 | // Besides the RewriteSimplifier, some cases can be better |
73 | // handled by CanonicalSimplifier. |
74 | // |
75 | |
76 | /* Utility for rewriting only boolean portions of an expression |
77 | * |
78 | * Performs a subset of simplifications done by RewriteSimplifier, |
79 | * sufficient to negate a simplified expression. Intended for |
80 | * application on an expression that has previously been simplified. |
81 | * |
82 | * \param expr The boolean expression to be normalized |
83 | * |
84 | * \returns The normalized boolean expression |
85 | */ |
86 | PrimExpr NormalizeBooleanOperators(PrimExpr expr) { |
87 | PVar<PrimExpr> x, y; |
88 | |
89 | while (true) { |
90 | if ((!!x).Match(expr)) { |
91 | expr = x.Eval(); |
92 | } else if ((!(x || y)).Match(expr)) { |
93 | return NormalizeBooleanOperators(!x.Eval()) && NormalizeBooleanOperators(!y.Eval()); |
94 | } else if ((!(x && y)).Match(expr)) { |
95 | return NormalizeBooleanOperators(!x.Eval()) || NormalizeBooleanOperators(!y.Eval()); |
96 | } else if ((x >= y).Match(expr) || (!(x < y)).Match(expr) || (!(y > x)).Match(expr)) { |
97 | return y.Eval() <= x.Eval(); |
98 | } else if ((x > y).Match(expr) || (!(x <= y)).Match(expr) || (!(y >= x)).Match(expr)) { |
99 | return y.Eval() < x.Eval(); |
100 | } else if ((!(x == y)).Match(expr)) { |
101 | return x.Eval() != y.Eval(); |
102 | } else if ((!(x != y)).Match(expr)) { |
103 | return x.Eval() == y.Eval(); |
104 | } else { |
105 | return expr; |
106 | } |
107 | } |
108 | } |
109 | |
110 | CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) { |
111 | CompareResult output = CompareResult::kUnknown; |
112 | |
113 | auto is_finished = [&output]() { |
114 | return output == CompareResult::kEQ || output == CompareResult::kLT || |
115 | output == CompareResult::kGT; |
116 | }; |
117 | |
118 | output = CompareResult(output & TryCompareUsingConstIntBounds(x, y)); |
119 | |
120 | if (is_finished()) return output; |
121 | |
122 | output = CompareResult(output & TryCompareUsingKnownInequalities(x, y)); |
123 | |
124 | return output; |
125 | } |
126 | |
127 | CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimExpr& x, |
128 | const PrimExpr y) { |
129 | return TryCompare(x - y, 0); |
130 | } |
131 | |
132 | CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x, |
133 | const PrimExpr& y) { |
134 | bool propagate_inequalities = enabled_extensions_ & kTransitivelyProveInequalities; |
135 | return analyzer_->transitive_comparisons.TryCompare(x, y, propagate_inequalities); |
136 | } |
137 | |
138 | // try to prove x equals val |
139 | CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) { |
140 | PrimExpr diff = this->VisitExpr(x); |
141 | if (const auto* ptr = diff.as<IntImmNode>()) { |
142 | if (ptr->value == val) { |
143 | return CompareResult::kEQ; |
144 | } else if (ptr->value > val) { |
145 | return CompareResult::kGT; |
146 | } else if (ptr->value < val) { |
147 | return CompareResult::kLT; |
148 | } |
149 | } |
150 | ConstIntBound dbound = analyzer_->const_int_bound(diff); |
151 | if (dbound->min_value == val && dbound->max_value == val) { |
152 | return CompareResult::kEQ; |
153 | } |
154 | if (dbound->min_value > val) { |
155 | return CompareResult::kGT; |
156 | } |
157 | if (dbound->max_value < val) { |
158 | return CompareResult::kLT; |
159 | } |
160 | if (dbound->min_value >= val) { |
161 | return CompareResult::kGE; |
162 | } |
163 | if (dbound->max_value <= val) { |
164 | return CompareResult::kLE; |
165 | } |
166 | if (val == 0) { |
167 | ModularSet dmod = analyzer_->modular_set(diff); |
168 | if (dmod->base != 0) { |
169 | return CompareResult::kNE; |
170 | } |
171 | } |
172 | return CompareResult::kUnknown; |
173 | } |
174 | |
175 | void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { |
176 | if (!can_override) { |
177 | auto it = var_map_.find(var); |
178 | if (it != var_map_.end()) { |
179 | ICHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" |
180 | << " with a different value: " |
181 | << "original=" << it->second << ", new=" << info; |
182 | } |
183 | } |
184 | var_map_[var] = info; |
185 | } |
186 | |
/*!
 * \brief Simplify an addition with rewrite rules.
 *
 * The node is first passed through IRMutatorWithAnalyzer::VisitExpr_,
 * then constant folding is attempted. After that the rules below are
 * tried in order; each TVM_TRY_* macro returns from this function as
 * soon as its pattern (and condition, if any) matches, so earlier rules
 * take priority over later ones.
 *
 * \param op The addition node to simplify.
 * \return The rewritten expression, or the result of the base visitor
 *         when no rule applies.
 */
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  // NOTE(review): as<AddNode>() is not null-checked here; this presumably
  // relies on the base visitor always returning an Add node -- confirm.
  op = ret.as<AddNode>();
  if (auto const_res = TryConstFold<Add>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var match FloatImm
  PVar<FloatImm> c4;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->dtype.lanes() != 1) {
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes));
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes));
    // Adding a broadcast floating-point zero leaves x unchanged.
    TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f);
  }

  if (IsIndexType(op->dtype)) {
    // Index rules
    // cancellation rules
    TVM_TRY_REWRITE((x - y) + y, x);
    TVM_TRY_REWRITE(x + (y - x), y);

    TVM_TRY_REWRITE((x - y) + (y - z), x - z);
    TVM_TRY_REWRITE((x - y) + (z - x), z - y);

    // Push the added term into min/max when it cancels a subtraction inside.
    TVM_TRY_REWRITE(min(x, y - z) + z, min(x + z, y));
    TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z));
    TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
    TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));

    // Same idea with scaled terms: z * c1 and z * c2 cancel when c1 == -c2.
    TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
                       c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
                       c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y),
                       c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y),
                       c1.Eval()->value == -c2.Eval()->value);

    // max(x, y) + min(x, y) is always x + y, in any operand order.
    TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
    TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
    TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
    TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y);

    // Constant offsets inside min/max cancel when c1 == -c2.
    TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value);

    // constant folding
    // NOTE: canonicalization might be better at this.
    TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2));

    // mul coefficient folding
    TVM_TRY_REWRITE(x + x, x * 2);
    TVM_TRY_REWRITE(x * y + x, x * (y + 1));
    TVM_TRY_REWRITE(y * x + x, x * (y + 1));
    TVM_TRY_REWRITE(x + y * x, x * (1 + y));
    TVM_TRY_REWRITE(x + x * y, x * (1 + y));
    TVM_TRY_REWRITE(x * y + x * z, x * (y + z));
    TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
    TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
    TVM_TRY_REWRITE(y * x + z * x, x * (y + z));

    // DivMod rules
    // trunc div
    TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
    // floor div
    TVM_TRY_REWRITE(floordiv(x, y) * y + floormod(x, y), x);
    TVM_TRY_REWRITE(y * floordiv(x, y) + floormod(x, y), x);
    TVM_TRY_REWRITE(floormod(x, y) + floordiv(x, y) * y, x);
    TVM_TRY_REWRITE(floormod(x, y) + y * floordiv(x, y), x);

    TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
                       c2.Eval()->value > 0);

    // canonicalization rule
    // will try rewrite again after canonicalization.
    TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1);
    TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
    TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);

    // DivMod rules
    // trunc div
    TVM_TRY_RECURSIVE_REWRITE(truncmod(y, c1) + x * c1, x * c1 + truncmod(y, c1));
    // floor div
    TVM_TRY_RECURSIVE_REWRITE(floormod(y, c1) + x * c1, x * c1 + floormod(y, c1));
  }

  // condition rules.
  TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2));
  // default value
  return ret;
}
289 | |
290 | std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { |
291 | size_t old_literal_size = literal_constraints_.size(); |
292 | // we will compare the already simplified result with the constraint, |
293 | // so simplify the constraint as well |
294 | PrimExpr new_constraint = operator()(constraint); |
295 | for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { |
296 | if (SideEffect(subconstraint) <= CallEffectKind::kPure) { |
297 | literal_constraints_.push_back(subconstraint); |
298 | PrimExpr negation; |
299 | if (subconstraint.dtype().is_bool()) { |
300 | // We could apply NormalizeBooleanOperators during |
301 | // TryMatchLiteralConstraint, but that would require |
302 | // performing a rewrite of each expression being checked. |
303 | // This way, we only apply a rewrite for each constraint being |
304 | // applied. |
305 | negation = NormalizeBooleanOperators(Not(subconstraint)); |
306 | } else { |
307 | negation = subconstraint == make_zero(subconstraint.dtype()); |
308 | } |
309 | literal_constraints_.push_back(Not(negation)); |
310 | } |
311 | } |
312 | size_t new_literal_size = literal_constraints_.size(); |
313 | auto frecover = [old_literal_size, new_literal_size, this]() { |
314 | ICHECK_EQ(literal_constraints_.size(), new_literal_size); |
315 | literal_constraints_.resize(old_literal_size); |
316 | }; |
317 | return frecover; |
318 | } |
319 | |
320 | void RewriteSimplifier::Impl::SetEnabledExtensions(Extension flags) { enabled_extensions_ = flags; } |
321 | |
322 | RewriteSimplifier::Extension RewriteSimplifier::Impl::GetEnabledExtensions() const { |
323 | return enabled_extensions_; |
324 | } |
325 | |
/*!
 * \brief Simplify a subtraction with rewrite rules.
 *
 * The node is first passed through IRMutatorWithAnalyzer::VisitExpr_,
 * then constant folding is attempted. After that the rules below are
 * tried in order; each TVM_TRY_* macro returns from this function as
 * soon as its pattern (and condition, if any) matches, so earlier rules
 * take priority over later ones.
 *
 * Index-typed expressions use the large integer rule set; floating
 * point expressions only get a few cancellation rules that are guarded
 * by a side-effect check.
 *
 * \param op The subtraction node to simplify.
 * \return The rewritten expression, or the result of the base visitor
 *         when no rule applies.
 */
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  // NOTE(review): as<SubNode>() is not null-checked here; this presumably
  // relies on the base visitor always returning a Sub node -- confirm.
  op = ret.as<SubNode>();
  if (auto const_res = TryConstFold<Sub>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->dtype.lanes() != 1) {
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
  }

  if (IsIndexType(op->dtype)) {
    // Index rules
    // cancellation rules
    TVM_TRY_REWRITE((x + y) - y, x);
    TVM_TRY_REWRITE((x + y) - x, y);
    TVM_TRY_REWRITE(x - (y + x), 0 - y);
    TVM_TRY_REWRITE(x - (x + y), 0 - y);

    // Subtracting one operand of a min/max moves the subtraction inside.
    TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x));
    TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0));
    TVM_TRY_REWRITE(max(x, y) - x, max(0, y - x));
    TVM_TRY_REWRITE(max(x, y) - y, max(x - y, 0));

    // Subtracting a min/max from one of its operands flips min <-> max.
    TVM_TRY_REWRITE(x - max(x, y), min(0, x - y));
    TVM_TRY_REWRITE(y - max(x, y), min(y - x, 0));
    TVM_TRY_REWRITE(x - min(x, y), max(0, x - y));
    TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));

    // mul coefficient folding
    TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(x * y - x, x * (y - 1));
    TVM_TRY_REWRITE(y * x - x, x * (y - 1));
    TVM_TRY_REWRITE(x - y * x, x * (1 - y));
    TVM_TRY_REWRITE(x - x * y, x * (1 - y));
    TVM_TRY_REWRITE(x * y - x * z, x * (y - z));
    TVM_TRY_REWRITE(y * x - x * z, x * (y - z));
    TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
    TVM_TRY_REWRITE(y * x - z * x, x * (y - z));

    // constant cancellation
    TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
    TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));

    // cancellation rules involving 4 operands
    TVM_TRY_REWRITE((x + y) - (x + z), y - z);
    TVM_TRY_REWRITE((x + y) - (z + x), y - z);
    TVM_TRY_REWRITE((y + x) - (z + x), y - z);
    TVM_TRY_REWRITE((y + x) - (x + z), y - z);

    TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x));
    TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x));
    TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y));
    TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y));

    TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x));
    TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x));
    TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y));
    TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y));

    TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z));
    TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z));
    TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y));
    TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y));

    // The same min/max with swapped operands cancels to zero.
    TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));

    // min/max pairs whose operands differ by the same amount: the
    // condition proves that the two pairwise differences are equal.
    TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s1,
                       CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));

    TVM_TRY_REWRITE_IF(min(b1, b2) - min(s1, s2), b1 - s2,
                       CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));
    TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s1,
                       CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
    TVM_TRY_REWRITE_IF(max(b1, b2) - max(s1, s2), b1 - s2,
                       CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0));

    // DivMod rules
    // truncdiv
    // NOTE: c*(x/c) + x % c == x holds in all division modes.
    TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
                       c1.Eval()->value != 0);

    // Scaled variants of the rules above; they require c3 == c1 * c2.
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

    // The proof is for floordiv; the truncdiv form needs the positivity conditions.
    // let x = a * c3 + r
    // (x + c1) / c3 - x / c3 => (r + c1) / c3
    // NOTE: the use of floormod(c2, c3) was intentional to simplify the const.
    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
                       truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
                       CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
                           c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(
        truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3),
        CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0);

    // floordiv
    TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1),
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - floordiv(x - y, c1) * c1, floormod(x - y, c1) + y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
                       c1.Eval()->value != 0);

    // Scaled floordiv variants; they require c3 == c1 * c2.
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);

    TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
                       floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
                       c3.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3),
                       c3.Eval()->value > 0);

    // canonicalization rule
    // will try rewrite again after canonicalization.
    // NOTE: the first rule uses the non-recursive macro, unlike the three after it.
    TVM_TRY_REWRITE(x - c1, x + (0 - c1));
    TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
    TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
  } else if (op->dtype.is_float()) {
    // Cancellation rules. Deliberately off of the integer path, to
    // avoid introducing checks on the side effects for the fast path.
    // Each rule fires only if the cancelled operand's side effect is
    // at most kReadState.
    TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
                       SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
  }

  // condition rules.
  TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2));
  TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z)));
  TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y));
  // default value: no rule applied
  return ret;
}
516 | |
/*!
 * \brief Simplify a multiplication with rewrite rules.
 *
 * The node is first passed through IRMutatorWithAnalyzer::VisitExpr_,
 * then constant folding is attempted. After that the rules below are
 * tried in order; each TVM_TRY_* macro returns from this function as
 * soon as its pattern (and condition, if any) matches, so earlier rules
 * take priority over later ones.
 *
 * \param op The multiplication node to simplify.
 * \return The rewritten expression, or the result of the base visitor
 *         when no rule applies.
 */
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  // NOTE(review): as<MulNode>() is not null-checked here; this presumably
  // relies on the base visitor always returning a Mul node -- confirm.
  op = ret.as<MulNode>();
  if (auto const_res = TryConstFold<Mul>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2;
  // Pattern var match FloatImm
  PVar<FloatImm> c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;
  // Vector rules
  if (op->dtype.lanes() != 1) {
    TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes));
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes));
    // A broadcast floating-point zero on the left collapses the product to that zero.
    TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f);
  }

  if (IsIndexType(op->dtype)) {
    // constant simplification rule
    TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2);
    TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2));
    TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y);
    TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y);

    // Two representations of const*ceildiv(x, c1)
    TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c2), c1) * c1, x - floormod(x, c2),
                       c1.Eval()->value == -c2.Eval()->value);

    // canonicalization: move constants to the right and make a negative
    // constant factor of a difference positive by swapping the operands.
    TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1);
    TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
    TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0);
  }
  // default value: no rule applied
  return ret;
}
555 | |
556 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { |
557 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
558 | op = ret.as<DivNode>(); |
559 | if (auto const_res = TryConstFold<Div>(op->a, op->b)) return const_res.value(); |
560 | // Pattern var to match any expression |
561 | PVar<PrimExpr> x, y, z, b1; |
562 | // Pattern var match IntImm |
563 | PVar<IntImm> c1, c2, c3; |
564 | // Pattern var for lanes in broadcast and ramp |
565 | PVar<int> lanes; |
566 | |
567 | // x / 2.0 = x * 0.5 |
568 | if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) { |
569 | ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() || |
570 | datatype::Registry::Global()->GetTypeRegistered(op->dtype.code())); |
571 | return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); |
572 | } |
573 | |
574 | // Vector rules |
575 | if (op->dtype.lanes() != 1) { |
576 | // NOTE: use div as the pattern also works for float. |
577 | TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); |
578 | // ramp / bcast |
579 | if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { |
580 | int64_t c1val = c1.Eval()->value; |
581 | int64_t c2val = c2.Eval()->value; |
582 | ICHECK(c2val != 0) << "division by zero" ; |
583 | if (c1val % c2val == 0) { |
584 | return ramp(div(b1, c2), div(c1, c2), lanes).Eval(); |
585 | } |
586 | // If all possible indices in ramp are the same. |
587 | if (CanProveGreaterEqual(b1.Eval(), 0)) { |
588 | ModularSet bmod = analyzer_->modular_set(b1.Eval()); |
589 | int64_t ramp_min = bmod->base / c2val; |
590 | int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; |
591 | if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) { |
592 | return broadcast(div(b1, c2), lanes).Eval(); |
593 | } |
594 | } |
595 | } |
596 | } |
597 | |
598 | if (IsIndexType(op->dtype)) { |
599 | // Be-aware of the division rules: |
600 | // We adopt the default C division uses truncation instead of floordiv. |
601 | // This means most rules need to check non-negativeness of the operands. |
602 | |
603 | // TryConstFold doesn't work for negative cases because it is also used by legacy |
604 | // parts of tvm which still assume euclidean div. In this simplifier we assume that the division |
605 | // is truncated, so perform const folding again. |
606 | // NOTE: trunc div required |
607 | if (truncdiv(c1, c2).Match(ret)) { |
608 | int64_t c1val = c1.Eval()->value; |
609 | int64_t c2val = c2.Eval()->value; |
610 | return make_const(op->dtype, truncdiv(c1val, c2val)); |
611 | } |
612 | |
613 | // while it is always true for trunc div |
614 | // restrict to common case(positive div) |
615 | TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2), |
616 | c1.Eval()->value > 0 && c2.Eval()->value > 0); |
617 | |
618 | TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3), |
619 | c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 && |
620 | CanProveGreaterEqual(x.Eval(), 0)); |
621 | |
622 | if (truncdiv(x * c1, c2).Match(ret)) { |
623 | int64_t c1val = c1.Eval()->value; |
624 | int64_t c2val = c2.Eval()->value; |
625 | if (c1val > 0 && c2val > 0) { |
626 | if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval(); |
627 | if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval(); |
628 | } |
629 | } |
630 | |
631 | TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x)); |
632 | TVM_TRY_REWRITE(truncdiv(x * c1, x), c1); |
633 | TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1); |
634 | |
635 | // Rules involving 2-operands. |
636 | TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2), |
637 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
638 | c1.Eval()->value % c2.Eval()->value == 0 && |
639 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
640 | |
641 | TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)), |
642 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
643 | c1.Eval()->value % c2.Eval()->value == 0 && |
644 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
645 | |
646 | TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)), |
647 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
648 | c1.Eval()->value % c2.Eval()->value == 0 && |
649 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
650 | |
651 | TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2), |
652 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
653 | c1.Eval()->value % c2.Eval()->value == 0 && |
654 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
655 | |
656 | TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)), |
657 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
658 | c1.Eval()->value % c2.Eval()->value == 0 && |
659 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
660 | |
661 | TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)), |
662 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && |
663 | c1.Eval()->value % c2.Eval()->value == 0 && |
664 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
665 | |
666 | // Rules involving 3-operands. |
667 | TVM_TRY_REWRITE_IF( |
668 | truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), |
669 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
670 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
671 | |
672 | TVM_TRY_REWRITE_IF( |
673 | truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2), |
674 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
675 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0)); |
676 | |
677 | TVM_TRY_REWRITE_IF( |
678 | truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2), |
679 | c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
680 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0)); |
681 | |
682 | TVM_TRY_REWRITE_IF( |
683 | truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), |
684 | c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
685 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
686 | |
687 | TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2), |
688 | c1.Eval()->value > 0 && c2.Eval()->value > 0 && |
689 | c1.Eval()->value % c2.Eval()->value == 0 && |
690 | CanProveGreaterEqual(x.Eval(), 0)); |
691 | |
692 | TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1, |
693 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
694 | TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1, |
695 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
696 | |
697 | TVM_TRY_REWRITE_IF( |
698 | truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1, |
699 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
700 | TVM_TRY_REWRITE_IF( |
701 | truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1, |
702 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
703 | TVM_TRY_REWRITE_IF( |
704 | truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1, |
705 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
706 | TVM_TRY_REWRITE_IF( |
707 | truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1, |
708 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); |
709 | |
710 | TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x, |
711 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
712 | TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x, |
713 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); |
714 | |
715 | TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z), |
716 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && |
717 | CanProveGreaterEqual(z.Eval(), 0)); |
718 | TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z), |
719 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && |
720 | CanProveGreaterEqual(z.Eval(), 0)); |
721 | TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x, |
722 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && |
723 | CanProveGreaterEqual(z.Eval(), 0)); |
724 | TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x, |
725 | CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && |
726 | CanProveGreaterEqual(z.Eval(), 0)); |
727 | } |
728 | return ret; |
729 | } |
730 | |
731 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { |
732 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
733 | op = ret.as<ModNode>(); |
734 | if (auto const_res = TryConstFold<Mod>(op->a, op->b)) return const_res.value(); |
735 | |
736 | // Pattern var to match any expression |
737 | PVar<PrimExpr> x, y, z, b1; |
738 | // Pattern var match IntImm |
739 | PVar<IntImm> c1, c2; |
740 | // Pattern var for lanes in broadcast and ramp |
741 | PVar<int> lanes; |
742 | |
743 | // Vector rules |
744 | if (op->dtype.lanes() != 1) { |
745 | TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)), |
746 | broadcast(truncmod(x, y), lanes)); |
747 | |
748 | // ramp % bcast |
749 | if (truncmod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { |
750 | int64_t c1val = c1.Eval()->value; |
751 | int64_t c2val = c2.Eval()->value; |
752 | ICHECK(c2val != 0) << "division by zero" ; |
753 | if (c1val % c2val == 0) { |
754 | return broadcast(truncmod(b1, c2), lanes).Eval(); |
755 | } |
756 | // If all possible indices in ramp are the same. |
757 | if (CanProveGreaterEqual(b1.Eval(), 0)) { |
758 | ModularSet bmod = analyzer_->modular_set(b1.Eval()); |
759 | int64_t ramp_min = bmod->base / c2val; |
760 | int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val; |
761 | if (bmod->coeff % c2val == 0) { |
762 | if (ramp_min == ramp_max) { |
763 | return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); |
764 | } else { |
765 | return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); |
766 | } |
767 | } |
768 | } |
769 | } |
770 | } |
771 | |
772 | if (IsIndexType(op->dtype)) { |
773 | // Be-aware of the division rules: |
774 | // We adopt the default C division uses truncation instead of floordiv. |
775 | // This means most rules need to check non-negativeness of the operands. |
776 | TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x), |
777 | c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); |
778 | |
779 | TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2), |
780 | c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
781 | CanProveGreaterEqual((x * c1).Eval(), 0) && |
782 | CanProveGreaterEqual(y.Eval(), 0)); |
783 | |
784 | TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2), |
785 | c2.Eval()->value > 0 && c1.Eval()->value >= 0 && |
786 | c1.Eval()->value % c2.Eval()->value == 0 && |
787 | CanProveGreaterEqual(x.Eval(), 0)); |
788 | |
789 | TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2), |
790 | c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && |
791 | CanProveGreaterEqual(x.Eval(), 0) && |
792 | CanProveGreaterEqual((y * c1).Eval(), 0)); |
793 | |
794 | // canonicalization: x % c == x % (-c) for truncated division |
795 | // NOTE: trunc div required |
796 | TVM_TRY_RECURSIVE_REWRITE_IF( |
797 | truncmod(x, c1), truncmod(x, PConst<PrimExpr>(make_const(op->dtype, -c1.Eval()->value))), |
798 | c1.Eval()->value < 0); |
799 | |
800 | // try modular analysis |
801 | if (truncmod(x, c1).Match(ret)) { |
802 | ModularSet mod = analyzer_->modular_set(x.Eval()); |
803 | int64_t c1val = c1.Eval()->value; |
804 | if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) { |
805 | return truncmod(mod->base, c1).Eval(); |
806 | } |
807 | } |
808 | } |
809 | return ret; |
810 | } |
811 | |
// Simplify a floor-division node.
//
// The operands are rewritten first by the base mutator. The result is then
// tried, in order, against: constant folding, the vector (broadcast / ramp)
// rules when the result has more than one lane, and the scalar rules guarded
// by IsIndexType. Rule order matters: the first rule that fires returns
// immediately. When nothing fires the operand-simplified expression is
// returned unchanged.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<FloorDivNode>();
  if (auto const_res = TryConstFold<FloorDiv>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<int> lanes;

  // Vector rules
  if (op->dtype.lanes() != 1) {
    // bcast(x) // bcast(y) => bcast(x // y)
    TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
                    broadcast(floordiv(x, y), lanes));
    // ramp // bcast
    if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      ICHECK(c2val != 0) << "division by zero";
      // The stride is a multiple of the divisor: divide base and stride separately.
      if (c1val % c2val == 0) {
        return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
      }
      // If all possible indices in ramp are the same.
      ModularSet bmod = analyzer_->modular_set(b1.Eval());
      // Quotients of the first and last lane, computed from the modular-set base.
      int64_t ramp_min = floordiv(bmod->base, c2val);
      int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
      if (ramp_min == ramp_max) {
        // If c2 divides the coefficient of b1's modular set
        if (bmod->coeff % c2val == 0) {
          return broadcast(floordiv(b1, c2), lanes).Eval();
        }
        // If all indices can be guaranteed to settle inside a coeff range
        if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
          return broadcast(floordiv(b1, c2), lanes).Eval();
        }
      }
    }
  }

  if (IsIndexType(op->dtype)) {
    // Be-aware of the division rules: this is floor division.

    // (x // c1) // c2 => x // (c1 * c2) for positive constants
    TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0);

    // (x // c1 + c2) // c3 => (x + c1 * c2) // (c1 * c3) for positive c1, c3
    TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3),
                       c1.Eval()->value > 0 && c3.Eval()->value > 0);

    // Handles (x * c1 + y) // c2, (x * c1) // c2 and (y + x * c1) // c2 together.
    // When the middle pattern matched, y is unbound and is treated as 0 below.
    if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) ||
        floordiv(y + x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      PrimExpr yval = y.EvalOr(Integer(0));
      // Leave a division by a zero constant untouched.
      if (c2val == 0) return ret;

      // try eliminate residue part
      PrimExpr residue =
          floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val);
      // Drop the y // c2 term when it can be proven to be zero.
      PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val);
      auto bound = analyzer_->const_int_bound(residue);
      // The residue is a known constant when its bounds collapse to a single value.
      if (bound.defined() && bound->max_value == bound->min_value) {
        return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value));
      }

      // try simplify divisor
      if (c1val > 0 && c2val > 0 && c2val % c1val == 0 &&
          CanProveLess(floormod(yval, c2val), c1val)) {
        // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then
        // (x * c1 + y) // c2
        // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1)
        // ==> x' + d + (b * c1 + e) // c2
        // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1
        // ==> x // (c2 // c1) + (y // c2)
        return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div;
      }
    }

    // Division of an expression by itself or by one of its own factors.
    TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
    TVM_TRY_REWRITE(floordiv(x * c1, x), c1);
    TVM_TRY_REWRITE(floordiv(c1 * x, x), c1);

    // Rules involving 2-operands.
    // Push the division into min / max when c2 > 0 divides c1.
    TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    // Rules involving 3-operands.
    // Pull out the x * c1 term when c2 > 0 divides c1.
    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // Same pattern, opposite divisibility: c1 divides c2 and (y + z) // c1 is provably 0.
    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), floordiv(x, floordiv(c2, c1)),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
                           c2.Eval()->value % c1.Eval()->value == 0 &&
                           CanProveEqual(floordiv(y.Eval() + z.Eval(), c1.Eval()), 0));

    TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    // (x + c1) // c2 => x // c2 + c1 // c2 when c2 > 0 divides c1
    TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

    // Cancel the common factor x.
    TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);

    // The divisor x appears as an additive term of the dividend: it contributes 1.
    // Each rule requires x to be provably >= 0.
    TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

    TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

    TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv((y + x) + z, x), floordiv(y + z, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(y + (z + x), x), floordiv(y + z, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0));

    // The divisor appears as a multiplicative factor of the dividend.
    TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0));

    // The divisor z is a factor of one additive term: that term contributes x.
    TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z),
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(z * x + y, z), x + floordiv(y, z),
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(y + x * z, z), floordiv(y, z) + x,
                       CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x,
                       CanProveGreaterEqual(z.Eval(), 0));

    // (x - x % c1) // c1 => x // c1
    TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0);
  }
  return ret;
}
957 | |
958 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { |
959 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
960 | op = ret.as<FloorModNode>(); |
961 | if (auto const_res = TryConstFold<FloorMod>(op->a, op->b)) return const_res.value(); |
962 | |
963 | // Pattern var to match any expression |
964 | PVar<PrimExpr> x, y, z, b1; |
965 | // Pattern var match IntImm |
966 | PVar<IntImm> c1, c2; |
967 | // Pattern var for lanes in broadcast and ramp |
968 | PVar<int> lanes; |
969 | |
970 | // Vector rules |
971 | if (op->dtype.lanes() != 1) { |
972 | TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)), |
973 | broadcast(floormod(x, y), lanes)); |
974 | |
975 | // floormod(ramp, bcast) |
976 | if (floormod(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) { |
977 | int64_t c1val = c1.Eval()->value; |
978 | int64_t c2val = c2.Eval()->value; |
979 | ICHECK(c2val != 0) << "division by zero" ; |
980 | if (c1val % c2val == 0) { |
981 | return broadcast(floormod(b1, c2), lanes).Eval(); |
982 | } |
983 | // If all possible indices in ramp are the same. |
984 | ModularSet bmod = analyzer_->modular_set(b1.Eval()); |
985 | int64_t ramp_min = floordiv(bmod->base, c2val); |
986 | int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); |
987 | if (ramp_min == ramp_max) { |
988 | // If b1 can devide c2 |
989 | if (bmod->coeff % c2val == 0) { |
990 | return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); |
991 | } |
992 | // If all indices can be guaranteed to settle inside a coeff range |
993 | if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { |
994 | return ramp(floormod(b1, c2), c1, lanes).Eval(); |
995 | } |
996 | } |
997 | if (bmod->coeff % c2val == 0) { |
998 | return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); |
999 | } |
1000 | } |
1001 | } |
1002 | |
1003 | if (IsIndexType(op->dtype)) { |
1004 | // Be-aware of the division rules: we use floordiv/floormod here |
1005 | TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), |
1006 | c2.Eval()->value != 0); |
1007 | |
1008 | TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, |
1009 | c1.Eval()->value > 0 && c2.Eval()->value > 0 && |
1010 | c2.Eval()->value % c1.Eval()->value == 0 && |
1011 | CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); |
1012 | |
1013 | TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), |
1014 | c2.Eval()->value > 0); |
1015 | |
1016 | TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), |
1017 | c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); |
1018 | |
1019 | TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), |
1020 | c2.Eval()->value > 0); |
1021 | |
1022 | TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); |
1023 | |
1024 | TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); |
1025 | TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); |
1026 | |
1027 | // try modular analysis |
1028 | if (floormod(x, c1).Match(ret)) { |
1029 | ModularSet mod = analyzer_->modular_set(x.Eval()); |
1030 | int64_t c1val = c1.Eval()->value; |
1031 | if (mod->coeff % c1val == 0 && c1val > 0) { |
1032 | return floormod(mod->base, c1).Eval(); |
1033 | } |
1034 | } |
1035 | } |
1036 | return ret; |
1037 | } |
1038 | |
// Simplify a min node.
//
// The operands are rewritten first by the base mutator. The result is then
// tried, in order, against: constant folding, the vector (broadcast) rules
// when the result has more than one lane, the scalar rules guarded by
// IsIndexType, and finally a select-merging rule that is tried regardless of
// dtype. Rule order matters: the first rule that fires returns immediately.
// When nothing fires the operand-simplified expression is returned unchanged.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<MinNode>();
  if (auto const_res = TryConstFold<Min>(op->a, op->b)) return const_res.value();

  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2;
  // Pattern var for lanes in broadcast
  PVar<int> lanes;

  // vector rule
  if (op->dtype.lanes() != 1) {
    // Merge broadcasts so the min is computed on the scalar operands.
    TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes));
    TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
                    min(x, broadcast(min(y, z), lanes)));
  }
  if (IsIndexType(op->dtype)) {
    TVM_TRY_REWRITE(min(x, x), x);

    // constant int bound
    // If the bound intervals of the operands do not interleave, one operand always wins.
    ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
    ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
    if (a_bound->max_value <= b_bound->min_value) {
      return op->a;
    }
    if (b_bound->max_value <= a_bound->min_value) {
      return op->b;
    }

    // constant comparison
    // Both operands are the same expression offset by constants: keep the smaller offset.
    if (min(x + c1, x + c2).Match(ret)) {
      if (c1.Eval()->value < c2.Eval()->value) {
        return (x + c1).Eval();
      } else {
        return (x + c2).Eval();
      }
    }
    if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) {
      if (c1.Eval()->value < 0) {
        return (x + c1).Eval();
      } else {
        return x.Eval();
      }
    }
    if (min(c1 - x, c2 - x).Match(ret)) {
      if (c1.Eval()->value < c2.Eval()->value) {
        return (c1 - x).Eval();
      } else {
        return (c2 - x).Eval();
      }
    }

    // DivMod rules
    // Divide up rounding: trunc div
    // NOTE: truncdiv(x, y) >= floordiv(x, y)
    // The condition c1 + 1 == c2 identifies the round-up form (x + c2 - 1) / c2 * c2.
    TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
                           CanProveGreaterEqual(x.Eval(), 1));

    TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
                           CanProveGreaterEqual(x.Eval(), 1));

    // Divide up rounding: floor div
    TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
                           CanProveGreaterEqual(x.Eval(), 1));

    TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
                           CanProveGreaterEqual(x.Eval(), 1));

    // Rounding x down to a multiple of a positive c2 never exceeds x.
    TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0);

    // min of a max and a min over the same two operands
    TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y));
    TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y));
    TVM_TRY_REWRITE(min(min(x, y), max(x, y)), min(x, y));
    TVM_TRY_REWRITE(min(min(x, y), max(y, x)), min(x, y));

    // absorption: one operand already appears inside the other
    TVM_TRY_REWRITE(min(max(x, y), x), x);
    TVM_TRY_REWRITE(min(max(x, y), y), y);
    TVM_TRY_REWRITE(min(min(x, y), x), min(x, y));
    TVM_TRY_REWRITE(min(min(x, y), y), min(x, y));

    TVM_TRY_REWRITE(min(x, max(x, y)), x);
    TVM_TRY_REWRITE(min(y, max(x, y)), y);
    TVM_TRY_REWRITE(min(x, min(x, y)), min(x, y));
    TVM_TRY_REWRITE(min(y, min(x, y)), min(x, y));

    // duplicate operand buried deeper in a left-nested min chain
    TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z));
    TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1));
    TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y),
                    min(min(min(min(x, y), z), s1), s2));

    // min/max distribution: factor out the operand shared by both max terms
    TVM_TRY_REWRITE(min(max(x, y), max(x, z)), max(min(y, z), x));
    TVM_TRY_REWRITE(min(max(x, y), max(z, x)), max(min(y, z), x));
    TVM_TRY_REWRITE(min(max(y, x), max(x, z)), max(min(y, z), x));
    TVM_TRY_REWRITE(min(max(y, x), max(z, x)), max(min(y, z), x));

    // min/min cancellation: the shared operand is kept once
    TVM_TRY_REWRITE(min(min(x, y), min(x, z)), min(min(y, z), x));
    TVM_TRY_REWRITE(min(min(x, y), min(z, x)), min(min(y, z), x));
    TVM_TRY_REWRITE(min(min(y, x), min(x, z)), min(min(y, z), x));
    TVM_TRY_REWRITE(min(min(y, x), min(z, x)), min(min(y, z), x));

    // add distribution
    TVM_TRY_REWRITE(min(y + x, z + x), min(y, z) + x);
    TVM_TRY_REWRITE(min(y + x, x + z), min(y, z) + x);
    TVM_TRY_REWRITE(min(x + y, x + z), min(y, z) + x);
    TVM_TRY_REWRITE(min(x + y, z + x), min(y, z) + x);

    // sub distribution
    TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x);
    TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z));

    // constant folding rule.
    TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));

    // scaling rule
    // A positive common constant keeps the min; otherwise it flips to max.
    if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return truncdiv(min(x, y), c1).Eval();
      } else {
        return truncdiv(max(x, y), c1).Eval();
      }
    }
    if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return floordiv(min(x, y), c1).Eval();
      } else {
        return floordiv(max(x, y), c1).Eval();
      }
    }
    if (min(x * c1, y * c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (min(x, y) * c1).Eval();
      } else {
        return (max(x, y) * c1).Eval();
      }
    }
    if (min(x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      // x * 0 is 0, so the result is min(0, c2); this also keeps c1val out of the `%` below.
      if (c1val == 0) {
        return c2val < 0 ? c2.Eval() : c1.Eval();
      }
      // When c1 divides c2, factor c1 out of both operands.
      if (c2val % c1val == 0) {
        if (c1val > 0) {
          return (min(x, c2val / c1val) * c1val).Eval();
        } else {
          return (max(x, c2val / c1val) * c1val).Eval();
        }
      }
    }

    // canonicalization
    // Move the constant to the outermost position, and pull it out of c1 - x.
    TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
    TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0);
  }

  // condition rules.
  // Two selects on the same condition are merged branch by branch.
  TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2)));
  return ret;
}
1211 | |
// Simplify a max node.
//
// The operands are rewritten first by the base mutator. The result is then
// tried, in order, against: constant folding, the vector (broadcast) rules
// when the result has more than one lane, the scalar rules guarded by
// IsIndexType, and finally a select-merging rule that is tried regardless of
// dtype. Rule order matters: the first rule that fires returns immediately.
// When nothing fires the operand-simplified expression is returned unchanged.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<MaxNode>();
  if (auto const_res = TryConstFold<Max>(op->a, op->b)) return const_res.value();

  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2;
  // Pattern var for lanes in broadcast
  PVar<int> lanes;

  // vector rule
  if (op->dtype.lanes() != 1) {
    // Merge broadcasts so the max is computed on the scalar operands.
    TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes));
    TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
                    max(x, broadcast(max(y, z), lanes)));
  }
  if (IsIndexType(op->dtype)) {
    TVM_TRY_REWRITE(max(x, x), x);

    // constant int bound
    // If the bound intervals of the operands do not interleave, one operand always wins.
    ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
    ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
    if (a_bound->min_value >= b_bound->max_value) {
      return op->a;
    }
    if (b_bound->min_value >= a_bound->max_value) {
      return op->b;
    }

    // constant comparison
    // Both operands are the same expression offset by constants: keep the larger offset.
    if (max(x + c1, x + c2).Match(ret)) {
      if (c1.Eval()->value > c2.Eval()->value) {
        return (x + c1).Eval();
      } else {
        return (x + c2).Eval();
      }
    }
    if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (x + c1).Eval();
      } else {
        return x.Eval();
      }
    }
    if (max(c1 - x, c2 - x).Match(ret)) {
      if (c1.Eval()->value > c2.Eval()->value) {
        return (c1 - x).Eval();
      } else {
        return (c2 - x).Eval();
      }
    }

    // DivMod rules
    // Divide up rounding: trunc div
    // NOTE: truncdiv(x, y) >= floordiv(x, y)
    // The condition c1 + 1 == c2 identifies the round-up form (x + c2 - 1) / c2 * c2.
    TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);

    // Divide up rounding: floor div
    TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);

    // Rounding x down to a multiple of a positive c2 never exceeds x.
    TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0);

    // max of a min and a max over the same two operands
    TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y));
    TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y));
    TVM_TRY_REWRITE(max(max(x, y), min(x, y)), max(x, y));
    TVM_TRY_REWRITE(max(max(x, y), min(y, x)), max(x, y));

    // absorption: one operand already appears inside the other
    TVM_TRY_REWRITE(max(min(x, y), x), x);
    TVM_TRY_REWRITE(max(min(x, y), y), y);
    TVM_TRY_REWRITE(max(max(x, y), x), max(x, y));
    TVM_TRY_REWRITE(max(max(x, y), y), max(x, y));

    TVM_TRY_REWRITE(max(x, min(x, y)), x);
    TVM_TRY_REWRITE(max(y, min(x, y)), y);
    TVM_TRY_REWRITE(max(x, max(x, y)), max(x, y));
    TVM_TRY_REWRITE(max(y, max(x, y)), max(x, y));

    // duplicate operand buried deeper in a left-nested max chain
    TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z));
    TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1));
    TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y),
                    max(max(max(max(x, y), z), s1), s2));

    // max/max cancelation
    TVM_TRY_REWRITE(max(max(x, y), max(x, z)), max(max(y, z), x));
    TVM_TRY_REWRITE(max(max(x, y), max(z, x)), max(max(y, z), x));
    TVM_TRY_REWRITE(max(max(y, x), max(x, z)), max(max(y, z), x));
    TVM_TRY_REWRITE(max(max(y, x), max(z, x)), max(max(y, z), x));

    // max/min distribution
    TVM_TRY_REWRITE(max(min(x, y), min(x, z)), min(max(y, z), x));
    TVM_TRY_REWRITE(max(min(x, y), min(z, x)), min(max(y, z), x));
    TVM_TRY_REWRITE(max(min(y, x), min(x, z)), min(max(y, z), x));
    TVM_TRY_REWRITE(max(min(y, x), min(z, x)), min(max(y, z), x));

    // add distribution
    TVM_TRY_REWRITE(max(y + x, z + x), max(y, z) + x);
    TVM_TRY_REWRITE(max(y + x, x + z), max(y, z) + x);
    TVM_TRY_REWRITE(max(x + y, x + z), max(y, z) + x);
    TVM_TRY_REWRITE(max(x + y, z + x), max(y, z) + x);

    // sub distribution
    TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x);
    TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z));

    // constant folding rule.
    TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));

    // scaling rule
    // A positive common constant keeps the max; otherwise it flips to min.
    if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return truncdiv(max(x, y), c1).Eval();
      } else {
        return truncdiv(min(x, y), c1).Eval();
      }
    }
    if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return floordiv(max(x, y), c1).Eval();
      } else {
        return floordiv(min(x, y), c1).Eval();
      }
    }
    if (max(x * c1, y * c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (max(x, y) * c1).Eval();
      } else {
        return (min(x, y) * c1).Eval();
      }
    }
    if (max(x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      // x * 0 is 0, so the result is max(0, c2); this also keeps c1val out of the `%` below.
      if (c1val == 0) {
        return c2val > 0 ? c2.Eval() : c1.Eval();
      }
      // When c1 divides c2, factor c1 out of both operands.
      if (c2val % c1val == 0) {
        if (c1val > 0) {
          return (max(x, c2val / c1val) * c1val).Eval();
        } else {
          return (min(x, c2val / c1val) * c1val).Eval();
        }
      }
    }

    // canonicalization
    // Move the constant to the outermost position, and pull it out of c1 - x.
    TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
    TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
  }

  // condition rules.
  // Two selects on the same condition are merged branch by branch.
  TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2)));
  return ret;
}
1373 | |
1374 | Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const { |
1375 | PrimExpr negation = Not(expr); |
1376 | |
1377 | ExprDeepEqual expr_equal; |
1378 | for (const auto& constraint : literal_constraints_) { |
1379 | if (expr_equal(constraint, expr)) { |
1380 | return make_const(expr->dtype, true); |
1381 | } |
1382 | if (expr_equal(constraint, negation)) { |
1383 | return make_const(expr->dtype, false); |
1384 | } |
1385 | } |
1386 | return NullOpt; |
1387 | } |
1388 | |
1389 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { |
1390 | EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op)); |
1391 | op = ret.get(); |
1392 | |
1393 | if (auto const_res = TryConstFold<EQ>(op->a, op->b)) { |
1394 | return const_res.value(); |
1395 | } |
1396 | if (auto match = TryMatchLiteralConstraint(ret)) { |
1397 | return match.value(); |
1398 | } |
1399 | |
1400 | return ApplyRewriteRules(ret); |
1401 | } |
1402 | |
1403 | PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { |
1404 | // Pattern var to match any expression |
1405 | PVar<PrimExpr> x, y; |
1406 | // Pattern var match IntImm |
1407 | PVar<IntImm> c1; |
1408 | PVar<int> lanes; |
1409 | |
1410 | // vector rule |
1411 | if (ret->dtype.lanes() != 1) { |
1412 | TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); |
1413 | } |
1414 | |
1415 | if (IsIndexType(ret->a.dtype())) { |
1416 | CompareResult result = TryCompare(ret->a, ret->b); |
1417 | if (result == CompareResult::kEQ) { |
1418 | return make_const(ret->dtype, true); |
1419 | } else if (result == CompareResult::kNE || result == CompareResult::kGT || |
1420 | result == CompareResult::kLT) { |
1421 | return make_const(ret->dtype, false); |
1422 | } |
1423 | TVM_TRY_REWRITE(c1 == x, x == c1); |
1424 | |
1425 | TVM_TRY_REWRITE(x - c1 == 0, x == c1); |
1426 | TVM_TRY_REWRITE(c1 - x == 0, x == c1); |
1427 | TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1); |
1428 | TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); |
1429 | } |
1430 | return std::move(ret); |
1431 | } |
1432 | |
1433 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { |
1434 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
1435 | op = ret.as<NENode>(); |
1436 | |
1437 | if (auto const_res = TryConstFold<NE>(op->a, op->b)) return const_res.value(); |
1438 | if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); |
1439 | |
1440 | if (IsIndexType(op->a.dtype())) { |
1441 | CompareResult result = TryCompare(op->a, op->b); |
1442 | if (result == CompareResult::kNE || result == CompareResult::kGT || |
1443 | result == CompareResult::kLT) { |
1444 | return make_const(op->dtype, true); |
1445 | } else if (result == CompareResult::kEQ) { |
1446 | return make_const(op->dtype, false); |
1447 | } else if (result == CompareResult::kGE) { |
1448 | // Known: a >= b |
1449 | // |
1450 | // a != b |
1451 | // (a < b) or (b < a) |
1452 | // False or (b < a) |
1453 | // b < a |
1454 | return ApplyRewriteRules(LT(op->b, op->a)); |
1455 | } else if (result == CompareResult::kLE) { |
1456 | // Known: a <= b |
1457 | // |
1458 | // a != b |
1459 | // (a < b) or (b < a) |
1460 | // (a < b) or False |
1461 | // a < b |
1462 | return ApplyRewriteRules(LT(op->a, op->b)); |
1463 | } |
1464 | } |
1465 | |
1466 | return ApplyRewriteRules(Not(ApplyRewriteRules(EQ(op->a, op->b)))); |
1467 | } |
1468 | |
1469 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { |
1470 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
1471 | op = ret.as<LENode>(); |
1472 | ICHECK(op); |
1473 | |
1474 | if (auto const_res = TryConstFold<LE>(op->a, op->b)) return const_res.value(); |
1475 | if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); |
1476 | |
1477 | // Check for applicable rewrites before attempting to prove/disprove |
1478 | // the inequality. This preserves earlier behavior, where (A<=B*x) |
1479 | // simplifies to (ceildiv(A,B)<=x) when (A%B!=0). Performing the |
1480 | // TryCompare first would simplify to the equivalent |
1481 | // (floordiv(A,B)<x) in these cases instead. |
1482 | ret = ApplyRewriteRules(Not(ApplyRewriteRules(LT(op->b, op->a)))); |
1483 | |
1484 | if (auto op = ret.as<LENode>(); op && IsIndexType(op->a.dtype())) { |
1485 | CompareResult result = TryCompare(op->a, op->b); |
1486 | if (result == CompareResult::kLE || result == CompareResult::kLT || |
1487 | result == CompareResult::kEQ) { |
1488 | return make_const(op->dtype, true); |
1489 | } else if (result == CompareResult::kGT) { |
1490 | return make_const(op->dtype, false); |
1491 | } else if (result == CompareResult::kNE) { |
1492 | // Known: a != b |
1493 | // |
1494 | // a <= b |
1495 | // (a < b) or (a == b) |
1496 | // (a < b) or False |
1497 | // a < b |
1498 | return ApplyRewriteRules(LT(op->a, op->b)); |
1499 | } else if (result == CompareResult::kGE) { |
1500 | // Known: a >= b |
1501 | // |
1502 | // a <= b |
1503 | // (a < b) or (a == b) |
1504 | // False or (a == b) |
1505 | // a == b |
1506 | return ApplyRewriteRules(EQ(op->a, op->b)); |
1507 | } |
1508 | } |
1509 | |
1510 | return ret; |
1511 | } |
1512 | |
1513 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { |
1514 | return this->VisitExpr(op->b < op->a); |
1515 | } |
1516 | |
1517 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { |
1518 | return this->VisitExpr(op->b <= op->a); |
1519 | } |
1520 | |
1521 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { |
1522 | LT node = Downcast<LT>(IRMutatorWithAnalyzer::VisitExpr_(op)); |
1523 | op = node.get(); |
1524 | |
1525 | if (auto const_res = TryConstFold<LT>(op->a, op->b)) return const_res.value(); |
1526 | if (auto match = TryMatchLiteralConstraint(node)) return match.value(); |
1527 | |
1528 | return ApplyRewriteRules(node); |
1529 | } |
1530 | |
1531 | PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { |
1532 | // Pattern var to match any expression |
1533 | PVar<PrimExpr> x, y, z, s1, s2; |
1534 | // Pattern var match IntImm |
1535 | PVar<IntImm> c1, c2; |
1536 | PVar<int> lanes; |
1537 | |
1538 | // vector rule |
1539 | if (ret->dtype.lanes() != 1) { |
1540 | TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); |
1541 | TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); |
1542 | } |
1543 | |
1544 | if (IsIndexType(ret->a.dtype())) { |
1545 | CompareResult result = TryCompare(ret->a, ret->b); |
1546 | if (result == CompareResult::kLT) { |
1547 | return make_const(ret->dtype, true); |
1548 | } |
1549 | if (result == CompareResult::kEQ || result == CompareResult::kGT || |
1550 | result == CompareResult::kGE) { |
1551 | return make_const(ret->dtype, false); |
1552 | } |
1553 | |
1554 | // clang-format off |
1555 | TVM_TRY_REWRITE(x + y < x + z, y < z); |
1556 | TVM_TRY_REWRITE(x + y < z + x, y < z); |
1557 | TVM_TRY_REWRITE(y + x < x + z, y < z); |
1558 | TVM_TRY_REWRITE(y + x < z + x, y < z); |
1559 | TVM_TRY_REWRITE(y - x < z - x, y < z); |
1560 | TVM_TRY_REWRITE(x - y < x - z, z < y); |
1561 | |
1562 | TVM_TRY_REWRITE(x < x + z, 0 < z); |
1563 | TVM_TRY_REWRITE(x < z + x, 0 < z); |
1564 | TVM_TRY_REWRITE(x < x - z, z < 0); |
1565 | TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x); |
1566 | TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1); |
1567 | |
1568 | TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0); |
1569 | TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); |
1570 | |
1571 | // constant cancelation: only need to make use of one mod |
1572 | // truc div |
1573 | TVM_TRY_REWRITE_IF(x * c2 < c1, |
1574 | x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); |
1575 | // NOTE: trunc div required |
1576 | TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2), |
1577 | c1.Eval()->value <= 0 && c2.Eval()->value > 0); |
1578 | // NOTE: trunc div required (euclidean is ok too, floored is not) |
1579 | TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 && |
1580 | c2.Eval()->value < 0); |
1581 | // NOTE: trunc div required (floored is ok too, euclidean is not) |
1582 | TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x, |
1583 | c1.Eval()->value <= 0 && c2.Eval()->value < 0); |
1584 | // NOTE: trunc div required |
1585 | TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x, |
1586 | c1.Eval()->value < 0 && c2.Eval()->value > 0); |
1587 | TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x, |
1588 | c1.Eval()->value >= 0 && c2.Eval()->value > 0); |
1589 | // NOTE: trunc div required (floored is ok too, euclidean is not) |
1590 | TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1, |
1591 | c1.Eval()->value < 0 && c2.Eval()->value < 0); |
1592 | // NOTE: trunc div required (euclidean is ok too, floored is not) |
1593 | TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2), |
1594 | c1.Eval()->value >= 0 && c2.Eval()->value < 0); |
1595 | // DivMod rules |
1596 | // trucdiv |
1597 | TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, |
1598 | x<c1 * c2, c1.Eval()->value> 0 && c2.Eval()->value > 0); |
1599 | // NOTE: trunc div required |
1600 | TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, |
1601 | x<c1*(c2 - 1) + 1, c1.Eval()->value> 0 && c2.Eval()->value <= 0); |
1602 | |
1603 | TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x, |
1604 | c1.Eval()->value >= 0 && c2.Eval()->value > 0); |
1605 | // NOTE: trunc div required |
1606 | TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x, |
1607 | c1.Eval()->value < 0 && c2.Eval()->value > 0); |
1608 | |
1609 | // invariance for any div mod: x - (x / c1) * c1 == x % c1 |
1610 | TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0); |
1611 | TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, |
1612 | 0 < truncmod(x, c1) + y, c1.Eval()->value > 0); |
1613 | TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, |
1614 | y < truncmod(x, c1), c1.Eval()->value > 0); |
1615 | |
1616 | TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x, |
1617 | c2 < truncmod(x + c2, c1), c1.Eval()->value > 0); |
1618 | TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y, |
1619 | c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0); |
1620 | TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y, |
1621 | y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); |
1622 | |
1623 | // floordiv |
1624 | TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0); |
1625 | TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0); |
1626 | |
1627 | TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0); |
1628 | TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, |
1629 | 0 < floormod(x, c1) + y, c1.Eval()->value > 0); |
1630 | TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, |
1631 | y < floormod(x, c1), c1.Eval()->value > 0); |
1632 | TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x, |
1633 | c2 < floormod(x + c2, c1), c1.Eval()->value > 0); |
1634 | TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y, |
1635 | c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0); |
1636 | TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y, |
1637 | y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); |
1638 | |
1639 | // canonicalization rule |
1640 | TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z); |
1641 | TVM_TRY_RECURSIVE_REWRITE(max(x, y) < z, x < z && y < z); |
1642 | TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y); |
1643 | TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y); |
1644 | |
1645 | TVM_TRY_RECURSIVE_REWRITE(x < c1 - y, x + y < c1); |
1646 | TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1); |
1647 | TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y); |
1648 | TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); |
1649 | |
1650 | TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); |
1651 | TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); |
1652 | TVM_TRY_REWRITE(x - c1 < 0, x < c1); |
1653 | |
1654 | TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y); |
1655 | TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y); |
1656 | TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y); |
1657 | TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y); |
1658 | // clang-format on |
1659 | } |
1660 | return std::move(ret); |
1661 | } |
1662 | |
1663 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { |
1664 | Not ret = Downcast<Not>(IRMutatorWithAnalyzer::VisitExpr_(op)); |
1665 | if (auto const_res = TryConstFold<Not>(ret->a)) return const_res.value(); |
1666 | if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); |
1667 | |
1668 | return ApplyRewriteRules(ret); |
1669 | } |
1670 | |
1671 | PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) { |
1672 | // Pattern var to match any expression |
1673 | PVar<PrimExpr> x, y; |
1674 | PVar<int> lanes; |
1675 | if (ret->dtype.lanes() != 1) { |
1676 | TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); |
1677 | } |
1678 | |
1679 | TVM_TRY_REWRITE(!(!x), x); |
1680 | TVM_TRY_REWRITE(!(x <= y), y < x); |
1681 | TVM_TRY_REWRITE(!(x >= y), x < y); |
1682 | TVM_TRY_REWRITE(!(x < y), y <= x); |
1683 | TVM_TRY_REWRITE(!(x > y), x <= y); |
1684 | TVM_TRY_REWRITE(!(x == y), x != y); |
1685 | TVM_TRY_REWRITE(!(x != y), x == y); |
1686 | TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); |
1687 | TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); |
1688 | return std::move(ret); |
1689 | } |
1690 | |
1691 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { |
1692 | PrimExpr ret = [&]() -> PrimExpr { |
1693 | // If this extension isn't enabled, just delegate out. |
1694 | if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) { |
1695 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
1696 | } |
1697 | |
1698 | PrimExpr a = op->a; |
1699 | PrimExpr b = op->b; |
1700 | |
1701 | // Alternate which branch is used as the constraint, and which is |
1702 | // being simplified. Because some sub-analyzers expect their |
1703 | // constraints to already be simplified, each branch may require |
1704 | // more than one update. The loop condition allows each branch to |
1705 | // be visited up to twice, but only performs the second visit if |
1706 | // necessary. |
1707 | size_t iterations_since_update = 0; |
1708 | for (size_t i = 0; i < 4; i++) { |
1709 | PrimExpr& to_update = (i % 2 == 0) ? a : b; |
1710 | const PrimExpr& constraint = (i % 2 == 0) ? b : a; |
1711 | |
1712 | With<ConstraintContext> context(analyzer_, constraint); |
1713 | PrimExpr updated = VisitExpr(to_update); |
1714 | |
1715 | if (!to_update.same_as(updated)) { |
1716 | to_update = updated; |
1717 | iterations_since_update = 0; |
1718 | } else { |
1719 | iterations_since_update++; |
1720 | if (iterations_since_update >= 2) { |
1721 | break; |
1722 | } |
1723 | } |
1724 | } |
1725 | |
1726 | // Only construct a new object if a change has been made. |
1727 | // Otherwise, follow ExprMutator's convention of returning the |
1728 | // original object. |
1729 | if (a.same_as(op->a) && b.same_as(op->b)) { |
1730 | return GetRef<PrimExpr>(op); |
1731 | } else { |
1732 | return And(a, b); |
1733 | } |
1734 | }(); |
1735 | |
1736 | op = ret.as<AndNode>(); |
1737 | |
1738 | if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value(); |
1739 | if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); |
1740 | if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && |
1741 | !recursively_visiting_boolean_) { |
1742 | return SimplifyAsAndOfOrs(ret, analyzer_); |
1743 | } |
1744 | |
1745 | // Pattern var to match any expression |
1746 | PVar<PrimExpr> x, y, z; |
1747 | // Pattern var match IntImm |
1748 | PVar<IntImm> c1, c2, c3; |
1749 | PVar<int> lanes; |
1750 | |
1751 | if (op->dtype.lanes() != 1) { |
1752 | TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); |
1753 | } |
1754 | |
1755 | auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false)); |
1756 | TVM_TRY_REWRITE(x == y && x != y, cfalse); |
1757 | TVM_TRY_REWRITE(x != y && x == y, cfalse); |
1758 | TVM_TRY_REWRITE(x && !x, cfalse); |
1759 | TVM_TRY_REWRITE(x <= y && y < x, cfalse); |
1760 | TVM_TRY_REWRITE(y < x && x <= y, cfalse); |
1761 | |
1762 | TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); |
1763 | TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); |
1764 | |
1765 | TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); |
1766 | TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); |
1767 | TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); |
1768 | TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); |
1769 | |
1770 | TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); |
1771 | TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); |
1772 | |
1773 | TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); |
1774 | TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); |
1775 | |
1776 | TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3); |
1777 | TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3); |
1778 | |
1779 | TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 && |
1780 | x - y * c1<c1, y == floordiv(x, c1), c1.Eval()->value> 0); |
1781 | TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1), |
1782 | c1.Eval()->value > 0); |
1783 | |
1784 | TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1)); |
1785 | TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1)); |
1786 | TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1), |
1787 | c2.Eval()->value == -c1.Eval()->value); |
1788 | TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == floordiv(x, c1), |
1789 | c2.Eval()->value == -c1.Eval()->value); |
1790 | |
1791 | TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3, |
1792 | x < c1 - c2 + c3 && floormod(x, c2) < c3, |
1793 | c1.Eval()->value % c2.Eval()->value == 0); |
1794 | TVM_TRY_RECURSIVE_REWRITE_IF( |
1795 | x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3, |
1796 | (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value > |
1797 | c3.Eval()->value); |
1798 | |
1799 | TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3, |
1800 | x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3, |
1801 | (c1.Eval()->value + 1) % c2.Eval()->value == 0); |
1802 | TVM_TRY_RECURSIVE_REWRITE_IF( |
1803 | x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3, |
1804 | (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value > |
1805 | c3.Eval()->value); |
1806 | |
1807 | TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3, |
1808 | c1 * c2 <= x && x < c1 * c2 + c3); |
1809 | TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1, |
1810 | c1 * c2 <= x && x < c1 * c2 + c3); |
1811 | TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3, |
1812 | c1 * c2 <= x && x <= c1 * c2 + c3); |
1813 | TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1, |
1814 | c1 * c2 <= x && x <= c1 * c2 + c3); |
1815 | |
1816 | TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2), |
1817 | c1 * c2 + c3 <= x && x < (c1 + 1) * c2); |
1818 | TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1, |
1819 | c1 * c2 + c3 <= x && x < (c1 + 1) * c2); |
1820 | TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2), |
1821 | c1 * c2 + c3 < x && x < (c1 + 1) * c2); |
1822 | TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1, |
1823 | c1 * c2 + c3 < x && x < (c1 + 1) * c2); |
1824 | |
1825 | TVM_TRY_RECURSIVE_REWRITE(x && (y && z), (x && y) && z); |
1826 | |
1827 | return ret; |
1828 | } |
1829 | |
1830 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { |
1831 | PrimExpr orig = GetRef<PrimExpr>(op); |
1832 | |
1833 | PrimExpr ret = [&]() -> PrimExpr { |
1834 | // If this extension isn't enabled, just delegate out. |
1835 | if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) { |
1836 | return IRMutatorWithAnalyzer::VisitExpr_(op); |
1837 | } |
1838 | |
1839 | PrimExpr a = op->a; |
1840 | PrimExpr b = op->b; |
1841 | |
1842 | // Alternate which branch is used as the constraint, and which |
1843 | // is being simplified. Because some sub-analyzers expect their |
1844 | // constraints to already be simplified, each branch may require |
1845 | // more than update. The loop condition allows each branch to be |
1846 | // visited up to twice, but only if performs the second visit if |
1847 | // necessary. |
1848 | size_t iterations_since_update = 0; |
1849 | for (size_t i = 0; i < 4; i++) { |
1850 | PrimExpr& to_update = (i % 2 == 0) ? a : b; |
1851 | const PrimExpr& constraint = (i % 2 == 0) ? b : a; |
1852 | |
1853 | With<ConstraintContext> context(analyzer_, NormalizeBooleanOperators(Not(constraint))); |
1854 | PrimExpr updated = VisitExpr(to_update); |
1855 | |
1856 | if (!to_update.same_as(updated)) { |
1857 | to_update = updated; |
1858 | iterations_since_update = 0; |
1859 | } else { |
1860 | iterations_since_update++; |
1861 | if (iterations_since_update >= 2) { |
1862 | break; |
1863 | } |
1864 | } |
1865 | } |
1866 | |
1867 | // Only construct a new object if a change has been made. |
1868 | // Otherwise, follow ExprMutator's convention of returning the |
1869 | // original object. |
1870 | if (a.same_as(op->a) && b.same_as(op->b)) { |
1871 | return GetRef<PrimExpr>(op); |
1872 | } else { |
1873 | return Or(a, b); |
1874 | } |
1875 | }(); |
1876 | |
1877 | op = ret.as<OrNode>(); |
1878 | if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value(); |
1879 | if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); |
1880 | if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && |
1881 | !recursively_visiting_boolean_) { |
1882 | return SimplifyAsAndOfOrs(ret, analyzer_); |
1883 | } |
1884 | |
1885 | // Pattern var to match any expression |
1886 | PVar<PrimExpr> x, y, z; |
1887 | // Pattern var match IntImm |
1888 | PVar<IntImm> c1, c2; |
1889 | PVar<int> lanes; |
1890 | |
1891 | if (op->dtype.lanes() != 1) { |
1892 | TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); |
1893 | } |
1894 | |
1895 | auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true)); |
1896 | |
1897 | TVM_TRY_REWRITE(x == y || x != y, ctrue); |
1898 | TVM_TRY_REWRITE(x != y || x == y, ctrue); |
1899 | TVM_TRY_REWRITE(x || !x, ctrue); |
1900 | TVM_TRY_REWRITE(x <= y || y < x, ctrue); |
1901 | TVM_TRY_REWRITE(y < x || x <= y, ctrue); |
1902 | |
1903 | TVM_TRY_REWRITE(x < y || y < x, x != y); |
1904 | |
1905 | TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); |
1906 | TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); |
1907 | |
1908 | TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); |
1909 | TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); |
1910 | TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); |
1911 | TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); |
1912 | |
1913 | TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); |
1914 | TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); |
1915 | |
1916 | TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); |
1917 | TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); |
1918 | |
1919 | TVM_TRY_RECURSIVE_REWRITE(x < y || x == y, x <= y); |
1920 | TVM_TRY_RECURSIVE_REWRITE(x < y || y == x, x <= y); |
1921 | TVM_TRY_RECURSIVE_REWRITE(x == y || x < y, x <= y); |
1922 | TVM_TRY_RECURSIVE_REWRITE(y == x || x < y, x <= y); |
1923 | |
1924 | TVM_TRY_RECURSIVE_REWRITE(x || (y || z), (x || y) || z); |
1925 | |
1926 | return ret; |
1927 | } |
1928 | |
1929 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) { |
1930 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
1931 | op = ret.as<SelectNode>(); |
1932 | if (op == nullptr) return ret; |
1933 | // Pattern var to match any expression |
1934 | PVar<PrimExpr> x, y; |
1935 | TVM_TRY_REWRITE(select(x, y, y), y); |
1936 | return ret; |
1937 | } |
1938 | |
1939 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { |
1940 | // add condition context to if_then_else |
1941 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
1942 | op = ret.as<CallNode>(); |
1943 | if (op == nullptr) return ret; |
1944 | |
1945 | if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) { |
1946 | return op->args[0]; |
1947 | } else if (op->op.same_as(tir::builtin::shift_right())) { |
1948 | if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) { |
1949 | // the operator overload will eagerly constant fold. |
1950 | return op->args[0] >> op->args[1]; |
1951 | } |
1952 | } else if (op->op.same_as(tir::builtin::shift_left())) { |
1953 | if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) { |
1954 | // the operator overload will eagerly constant fold. |
1955 | return op->args[0] << op->args[1]; |
1956 | } |
1957 | } else if (op->op.same_as(Op::Get("tir.ceil" ))) { |
1958 | PrimExpr ceil_arg = op->args[0]; |
1959 | if (auto arg_int = op->args[0].as<IntImmNode>()) { |
1960 | return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); |
1961 | } else if (auto arg_float = ceil_arg.as<FloatImmNode>()) { |
1962 | return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value))); |
1963 | } else if (auto arg_call = ceil_arg.as<CallNode>()) { |
1964 | // ceil(log2(cast(n,"float64"))) is used as the implementation of |
1965 | // topi.math.ceil_log2, and appears in iteration bounds. |
1966 | if (arg_call->op.same_as(Op::Get("tir.log2" ))) { |
1967 | PrimExpr log_arg = arg_call->args[0]; |
1968 | if (auto as_float = log_arg.as<FloatImmNode>()) { |
1969 | // ceil(log2(n)) can be simplified, and should produce the |
1970 | // same integer result regardless of the target's rounding |
1971 | // conventions. |
1972 | return FloatImm(op->dtype, std::ceil(std::log2(as_float->value))); |
1973 | } |
1974 | } |
1975 | } |
1976 | } |
1977 | |
1978 | if (op->op.same_as(tir::builtin::likely())) { |
1979 | // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } |
1980 | if (auto match = TryMatchLiteralConstraint(op->args[0])) { |
1981 | return match.value(); |
1982 | } |
1983 | } |
1984 | |
1985 | if (op->op.same_as(tir::builtin::if_then_else())) { |
1986 | // Simplify nested if_then_else |
1987 | // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr } |
1988 | // => if (cond && inner_cond) { inner_then_expr } else { else_expr } |
1989 | const PrimExpr& cond = op->args[0]; |
1990 | const PrimExpr& then_expr = op->args[1]; |
1991 | const PrimExpr& else_expr = op->args[2]; |
1992 | const CallNode* inner_call = then_expr.as<CallNode>(); |
1993 | if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) { |
1994 | const PrimExpr& inner_cond = inner_call->args[0]; |
1995 | const PrimExpr& inner_then_expr = inner_call->args[1]; |
1996 | const PrimExpr& inner_else_expr = inner_call->args[2]; |
1997 | // Only check constant cases to avoid recursion |
1998 | if (is_const_number(inner_else_expr) && is_const_number(else_expr) && |
1999 | analyzer_->CanProve(inner_else_expr == else_expr)) { |
2000 | return if_then_else(cond && inner_cond, inner_then_expr, else_expr); |
2001 | } |
2002 | } |
2003 | } |
2004 | |
2005 | return ret; |
2006 | } |
2007 | |
2008 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { |
2009 | Var var = GetRef<Var>(op); |
2010 | if (op->dtype == DataType::Bool()) { |
2011 | if (auto match = TryMatchLiteralConstraint(var)) { |
2012 | return match.value(); |
2013 | } |
2014 | } |
2015 | |
2016 | auto it = var_map_.find(var); |
2017 | if (it != var_map_.end()) { |
2018 | return it->second; |
2019 | } |
2020 | return GetRef<PrimExpr>(op); |
2021 | } |
2022 | |
2023 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { |
2024 | PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); |
2025 | op = ret.as<CastNode>(); |
2026 | return cast(op->dtype, op->value); |
2027 | } |
2028 | |
2029 | bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) { |
2030 | // Only inline trivial bindings to avoid deep expression explosion |
2031 | // when we need let to construct complicated expressions. |
2032 | if (is_const_number(op->value)) return true; |
2033 | if (op->value.as<VarNode>()) return true; |
2034 | return false; |
2035 | } |
2036 | |
2037 | PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { |
2038 | PrimExpr value = this->VisitExpr(op->value); |
2039 | if (CanInlineLet(op)) { |
2040 | // it is fine to discard the let binding |
2041 | // because the value will always be inlined in the simplifier. |
2042 | analyzer_->Bind(op->var, value); |
2043 | return this->VisitExpr(op->body); |
2044 | } |
2045 | PrimExpr body = this->VisitExpr(op->body); |
2046 | if (value.same_as(op->value) && body.same_as(op->body)) { |
2047 | return GetRef<PrimExpr>(op); |
2048 | } else { |
2049 | return Let(op->var, value, body); |
2050 | } |
2051 | } |
2052 | |
2053 | PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { |
2054 | // Run simplification in post order |
2055 | PrimExpr res = expr; |
2056 | int max_iter = 2; |
2057 | for (int i = 0; i < max_iter; ++i) { |
2058 | PrimExpr new_expr = impl_->operator()(res); |
2059 | if (new_expr.same_as(res)) return res; |
2060 | res = new_expr; |
2061 | } |
2062 | return res; |
2063 | } |
2064 | |
2065 | void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_override) { |
2066 | impl_->Update(var, info, allow_override); |
2067 | } |
2068 | |
2069 | std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { |
2070 | return impl_->EnterConstraint(constraint); |
2071 | } |
2072 | |
2073 | void RewriteSimplifier::SetEnabledExtensions(Extension flags) { |
2074 | impl_->SetEnabledExtensions(flags); |
2075 | } |
2076 | RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const { |
2077 | return impl_->GetEnabledExtensions(); |
2078 | } |
2079 | |
2080 | RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} |
2081 | |
2082 | RewriteSimplifier::~RewriteSimplifier() { delete impl_; } |
2083 | |
2084 | } // namespace arith |
2085 | } // namespace tvm |
2086 | |