1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file canonical_simplify.cc |
22 | * \brief Canonical form based simplification. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | #include "const_fold.h" |
29 | #include "pattern_match.h" |
30 | #include "rewrite_simplify.h" |
31 | |
32 | namespace tvm { |
33 | namespace arith { |
34 | |
35 | using namespace tir; |
36 | |
37 | class SumExpr; |
38 | class SplitExpr; |
39 | |
40 | /*! |
41 | * \brief Base class of all temporary expression introduced |
42 | * for canonicalization. |
43 | */ |
44 | class CanonicalExprNode : public PrimExprNode { |
45 | public: |
46 | virtual ~CanonicalExprNode() {} |
47 | /*! |
48 | * \brief Return the normal Expr that is equivalent to self. |
49 | * \note Can mutate the internal data structure. |
50 | * \return The normal expression. |
51 | */ |
52 | virtual PrimExpr Normalize() const = 0; |
53 | |
54 | // overrides |
55 | void VisitAttrs(tvm::AttrVisitor* v) {} |
56 | |
57 | static constexpr const char* _type_key = "arith.CanonicalExpr" ; |
58 | static constexpr const uint32_t _type_child_slots = 2; |
59 | TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); |
60 | }; |
61 | |
62 | inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) { |
63 | if (mode == kTruncDiv) { |
64 | return truncmod(a, b); |
65 | } else { |
66 | ICHECK_EQ(mode, kFloorDiv); |
67 | return floormod(a, b); |
68 | } |
69 | } |
70 | |
71 | inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { |
72 | if (mode == kTruncDiv) { |
73 | return truncdiv(a, b); |
74 | } else { |
75 | ICHECK_EQ(mode, kFloorDiv); |
76 | return floordiv(a, b); |
77 | } |
78 | } |
79 | |
80 | /*! |
81 | * \brief check if value fits in dtype |
82 | * \param value The value to be analyzed |
83 | * \param dtype The target dtype |
84 | * \param analyzer The analyzer |
85 | * \return whether value fits in dtype |
86 | */ |
87 | bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) { |
88 | if (!IsIndexType(dtype)) { |
89 | return false; |
90 | } |
91 | ConstIntBound bound = analyzer->const_int_bound(value); |
92 | int64_t ubound = Downcast<IntImm>(max_value(dtype))->value; |
93 | int64_t lbound = Downcast<IntImm>(min_value(dtype))->value; |
94 | if (value.dtype().bits() <= dtype.bits() || // upcast is safe |
95 | (bound->max_value <= ubound && bound->min_value >= lbound)) { |
96 | return true; |
97 | } |
98 | return false; |
99 | } |
100 | |
101 | /*! |
102 | * \brief Internal "Split normal form" of expression. |
103 | * |
104 | * This is a special expression that represents |
105 | * a scaled value derived from a split of an index. |
106 | * |
107 | * result = ((index % upper_factor) / lower_factor) * scale |
108 | */ |
109 | class SplitExprNode : public CanonicalExprNode { |
110 | public: |
111 | /*! \brief The base index expression. */ |
112 | PrimExpr index; |
113 | /*! \brief The division factor ratio. */ |
114 | int64_t lower_factor{1}; |
115 | /*! |
116 | * \brief The upper factor. |
117 | * invariance: (upper_factor == kPosInf || upper_factor % lower_factor == 0) |
118 | */ |
119 | int64_t upper_factor{kPosInf}; |
120 | /*! \brief scale to the expression. */ |
121 | int64_t scale{1}; |
122 | /*! \brief Division mode. */ |
123 | DivMode div_mode{kTruncDiv}; |
124 | |
125 | /*! \brief verify that this is a valid entry. */ |
126 | void Verify() const { ICHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } |
127 | |
128 | PrimExpr NormalizeWithScale(int64_t sscale) const { |
129 | PrimExpr res = this->index; |
130 | DataType dtype = this->dtype; |
131 | if (this->scale == 0) { |
132 | return make_const(dtype, 0); |
133 | } |
134 | if (this->upper_factor != SplitExprNode::kPosInf) { |
135 | res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode); |
136 | } |
137 | if (this->lower_factor != 1) { |
138 | res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode); |
139 | } |
140 | sscale *= this->scale; |
141 | if (sscale != 1) { |
142 | ICHECK(!dtype.is_uint() || sscale > 0); |
143 | res = res * make_const(dtype, sscale); |
144 | } |
145 | return res; |
146 | } |
147 | |
148 | PrimExpr Normalize() const final { return NormalizeWithScale(1); } |
149 | |
150 | void MulToSelf(int64_t scale) { this->scale *= scale; } |
151 | |
152 | /*! |
153 | * \brief check if cast can be pushed to sub-expressions |
154 | * \param dtype The target datatype |
155 | * \param analyzer The analyzer |
156 | * \return whether the cast can be safely pushed to children |
157 | */ |
158 | bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { |
159 | // cast(dtype, index % upper_factor / lower_factor * scale) == |
160 | // cast(dtype, index) % upper_factor / lower_factor * scale |
161 | // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of |
162 | // its intermediate results fit in the range of dtype |
163 | if (dtype.bits() >= this->dtype.bits()) { |
164 | return true; // upcast is safe |
165 | } |
166 | PrimExpr res = this->index; |
167 | if (this->scale == 0) { |
168 | return true; |
169 | } |
170 | if (!CastIsSafe(dtype, res, analyzer)) { |
171 | return false; |
172 | } |
173 | if (this->upper_factor != SplitExprNode::kPosInf) { |
174 | res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode); |
175 | if (!CastIsSafe(dtype, res, analyzer)) { |
176 | return false; |
177 | } |
178 | } |
179 | if (this->lower_factor != 1) { |
180 | res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode); |
181 | if (!CastIsSafe(dtype, res, analyzer)) { |
182 | return false; |
183 | } |
184 | } |
185 | if (this->scale != 1) { |
186 | ICHECK(!this->dtype.is_uint() || this->scale > 0); |
187 | res = res * make_const(this->dtype, this->scale); |
188 | if (!CastIsSafe(dtype, res, analyzer)) { |
189 | return false; |
190 | } |
191 | } |
192 | return true; |
193 | } |
194 | |
195 | /*! |
196 | * \brief self = cast(dtype, self) |
197 | * \param dtype The target datatype |
198 | */ |
199 | void PushCastToChildren(DataType dtype) { |
200 | this->index = cast(dtype, this->index); |
201 | this->dtype = dtype; |
202 | } |
203 | |
204 | inline bool IndexEqual(const SplitExpr& other) const; |
205 | inline bool DivModeCompatibleTo(DivMode mode) const; |
206 | |
207 | /*! \brief positive infty */ |
208 | static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; |
209 | static constexpr const char* _type_key = "arith.SplitExpr" ; |
210 | TVM_DECLARE_FINAL_OBJECT_INFO(SplitExprNode, CanonicalExprNode); |
211 | }; |
212 | |
/*!
 * \brief Managed reference to SplitExprNode.
 *  Provides CopyOnWrite() so that a shared node is copied before it is mutated.
 */
class SplitExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SplitExpr, PrimExpr, SplitExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitExprNode);
};
218 | |
219 | inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { |
220 | if (index.same_as(other->index)) return true; |
221 | return tir::ExprDeepEqual()(index, other->index); |
222 | } |
223 | |
224 | inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { |
225 | if (this->div_mode == mode) return true; |
226 | if (lower_factor == 1 && upper_factor == kPosInf) return true; |
227 | return false; |
228 | } |
229 | |
230 | /*! |
231 | * \brief Normal form that represents sum of expressions. |
232 | * |
233 | * result = sum(args) + base. |
234 | */ |
235 | class SumExprNode : public CanonicalExprNode { |
236 | public: |
237 | /*! |
238 | * \brief arguments to be summed up. |
239 | * |
240 | * args are divided into segments with the same index. |
241 | * within each segment, the SplitExpr is ordered in descending order of lower_factor. |
242 | */ |
243 | std::vector<SplitExpr> args; |
244 | /*! \brief Base value in the summation. */ |
245 | int64_t base{0}; |
246 | /*! \brief The expression equals zero. */ |
247 | bool IsZero() const { return base == 0 && args.size() == 0; } |
248 | /*! |
249 | * \brief Return the normal Expr that is equivalent to self. |
250 | * \return The normal expression. |
251 | */ |
252 | PrimExpr Normalize() const final { |
253 | // quick path 1. |
254 | if (this->args.size() == 0) { |
255 | return make_const(this->dtype, this->base); |
256 | } |
257 | return Normalize_(this->dtype, SimplifySplitExprs(args), base); |
258 | } |
259 | /*! |
260 | * \brief Whether self is divisible by scale. |
261 | * \param scale The scale to be applied. |
262 | */ |
263 | bool DivisibleBy(int64_t scale) { |
264 | if (base % scale != 0) return false; |
265 | for (size_t i = 0; i < this->args.size(); ++i) { |
266 | if (args[i]->scale % scale != 0) return false; |
267 | } |
268 | return true; |
269 | } |
270 | /*! |
271 | * \brief mul scale to self. |
272 | * \param scale The scale to be applied. |
273 | */ |
274 | void MulToSelf(int64_t scale) { |
275 | this->base *= scale; |
276 | for (size_t i = 0; i < this->args.size(); ++i) { |
277 | args[i].CopyOnWrite()->scale *= scale; |
278 | } |
279 | } |
280 | /*! |
281 | * \brief divide by scale. |
282 | * \param scale The scale to be applied. |
283 | */ |
284 | void DivideBy(int64_t scale) { |
285 | ICHECK_EQ(this->base % scale, 0); |
286 | this->base /= scale; |
287 | for (size_t i = 0; i < this->args.size(); ++i) { |
288 | ICHECK_EQ(args[i]->scale % scale, 0); |
289 | args[i].CopyOnWrite()->scale /= scale; |
290 | } |
291 | } |
292 | /*! |
293 | * \brief add constant value to self. |
294 | * \param value to be added. |
295 | */ |
296 | void AddToSelf(int64_t value) { this->base += value; } |
297 | /*! |
298 | * \brief self += other * scale; |
299 | * \param other The expression to be added. |
300 | * \param scale The additional scale on value. |
301 | */ |
302 | void AddToSelf(SplitExpr other, int64_t scale) { |
303 | if (other->scale == 0) return; |
304 | // We need to maintain the segment invariance: |
305 | // Same index are stored close to each other. |
306 | // sorted from big lower_factor to small one. |
307 | size_t start = 0; |
308 | for (; start < args.size(); ++start) { |
309 | if (args[start]->IndexEqual(other)) break; |
310 | } |
311 | for (size_t j = start; j < args.size(); ++j) { |
312 | if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) { |
313 | other.CopyOnWrite()->scale *= scale; |
314 | this->args.insert(this->args.begin() + j, other); |
315 | return; |
316 | } |
317 | if (other->lower_factor == args[j]->lower_factor && |
318 | other->upper_factor == args[j]->upper_factor && |
319 | other->DivModeCompatibleTo(args[j]->div_mode)) { |
320 | args[j].CopyOnWrite()->scale += other->scale * scale; |
321 | return; |
322 | } |
323 | } |
324 | // Insert other in the end. |
325 | other.CopyOnWrite()->scale *= scale; |
326 | this->args.emplace_back(std::move(other)); |
327 | } |
328 | |
329 | void AddToSelf(const SumExpr& other, int64_t scale); |
330 | |
331 | /*! |
332 | * \brief check if cast can be pushed to sub-expressions |
333 | * \param dtype The target datatype |
334 | * \param analyzer The analyzer |
335 | * \return whether the cast can be safely pushed to children |
336 | */ |
337 | bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { |
338 | bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest() |
339 | : base == -(1LL << (dtype.bits() - 1)); |
340 | // cast(dtype, arg_1 + arg_2 + ... arg_n) == |
341 | // cast(dtype, arg_1) + ... + cast(dtype, arg_n) |
342 | // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of |
343 | // its intermediate results fit in the range of dtype |
344 | if (dtype.bits() >= this->dtype.bits()) { |
345 | return true; // upcast is safe |
346 | } |
347 | PrimExpr res = make_const(dtype, 0); |
348 | for (size_t i = 0; i < args.size(); ++i) { |
349 | if (args[i]->scale > 0) { |
350 | res = res + args[i]->Normalize(); |
351 | if (!CastIsSafe(dtype, res, analyzer)) { |
352 | return false; |
353 | } |
354 | } |
355 | } |
356 | if (base > 0 || is_min_value) { |
357 | res = res + make_const(dtype, base); |
358 | if (!CastIsSafe(dtype, res, analyzer)) { |
359 | return false; |
360 | } |
361 | } |
362 | // negative scales follows using sub. |
363 | for (size_t i = 0; i < args.size(); ++i) { |
364 | if (args[i]->scale < 0) { |
365 | res = res - args[i]->NormalizeWithScale(-1); |
366 | if (!CastIsSafe(dtype, res, analyzer)) { |
367 | return false; |
368 | } |
369 | } |
370 | } |
371 | if (base < 0 && !is_min_value) { |
372 | res = res - make_const(dtype, -base); |
373 | if (!CastIsSafe(dtype, res, analyzer)) { |
374 | return false; |
375 | } |
376 | } |
377 | for (const auto& arg : args) { |
378 | if (!arg->CanPushCastToChildren(dtype, analyzer)) { |
379 | return false; |
380 | } |
381 | } |
382 | return true; |
383 | } |
384 | |
385 | /*! |
386 | * \brief self = cast(dtype, self) |
387 | * \param dtype The target datatype |
388 | */ |
389 | void PushCastToChildren(DataType dtype) { |
390 | for (auto& arg : args) { |
391 | arg.CopyOnWrite()->PushCastToChildren(dtype); |
392 | } |
393 | this->dtype = dtype; |
394 | } |
395 | |
396 | static constexpr const char* _type_key = "arith.SumExpr" ; |
397 | TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode); |
398 | |
399 | private: |
400 | /*! |
401 | * \brief Simplify the args by merging SplitExprs |
402 | * \param args The original list of arguments. |
403 | * \return simplified version. |
404 | */ |
405 | static std::vector<SplitExpr> SimplifySplitExprs(std::vector<SplitExpr> args) { |
406 | // NOTE: This algorithm relies on the factor that args are divided into segments |
407 | // and each segment is sorted in descending order of lower_factor. |
408 | for (size_t i = 0; i < args.size(); ++i) { |
409 | if (args[i]->scale == 0) continue; |
410 | for (size_t j = i + 1; j < args.size(); ++j) { |
411 | SplitExpr& lhs = args[i]; |
412 | SplitExpr& rhs = args[j]; |
413 | if (!lhs->IndexEqual(rhs)) break; |
414 | if (lhs->upper_factor < rhs->lower_factor) break; |
415 | if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor && |
416 | lhs->DivModeCompatibleTo(rhs->div_mode)) { |
417 | // folding same co-efficient. |
418 | rhs.CopyOnWrite()->scale += lhs->scale; |
419 | lhs.CopyOnWrite()->scale = 0; |
420 | } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 && |
421 | lhs->scale % rhs->scale == 0 && |
422 | lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor && |
423 | lhs->DivModeCompatibleTo(rhs->div_mode)) { |
424 | // Rules used in the proof: |
425 | // |
426 | // Rule 1: (x % (c * s)) / c = (x / c) % s |
427 | // Proof: |
428 | // x can always be decomposed into p * c * s + q * c + r |
429 | // where 0 <= q * c + r < c * s and 0 <= r < c. |
430 | // Then, lhs = ((p * c * s + q * c + r) % (c * s)) / c = (q * c + r) / c = q |
431 | // rhs = ((p * c * s + q * c + r) / c) % s = (p * s + q) % s = q |
432 | // Thus, lhs = rhs |
433 | // |
434 | // The above proof is for the floordiv. |
435 | // The same rule also holds for truncdiv(division rule in C). |
436 | // Because both sides only involve mul, div and mod, |
437 | // we can take abs of x, c and s, apply the floordiv proof, |
438 | // and finally add the sign back. |
439 | // |
440 | // Rule 2: (x / s) * s + x % s = x (true for both trunc and floor div) |
441 | // |
442 | // General merge condition and proof: |
443 | // - x = lhs->index % lhs->upper_factor |
444 | // - s = lhs->scale / rhs->scale |
445 | // - c = rhs->lower_factor |
446 | // |
447 | // (x / (c * s)) * s + (x % (c * s)) / c |
448 | // => ((x / c) / s) * s + ((x / c) % s) |
449 | // => (x / c) |
450 | // |
451 | // Examples: |
452 | // |
453 | // (z / 6) * 6 + ((z % 6) / 3) * 3 |
454 | // => ((z / 6) * 2 + (z % 6) / 3) * 3 |
455 | // => (z / 3) * 3 |
456 | // note: x = z, c = 3, s = 2 |
457 | // |
458 | // ((z % 12) / 6) * 6 + ((z % 6) / 3) * 3 |
459 | // => (((z % 12) / 6) * 2 + ((z % 12) % 6) / 3) * 3 |
460 | // => ((z % 12) / 3) * 3 |
461 | // note: x = z % 12, c = 3, s = 2 |
462 | // note also the invariance lhs->upper_factor % lhs->lower_factor == 0 |
463 | // |
464 | SplitExprNode* merged = rhs.CopyOnWrite(); |
465 | merged->upper_factor = lhs->upper_factor; |
466 | // reset args[i] to be zero. |
467 | lhs.CopyOnWrite()->scale = 0; |
468 | break; |
469 | } |
470 | } |
471 | } |
472 | // sort by the entry |
473 | // Here we simply sort by descending order of scales. |
474 | // For now, we do not compare by index because that comparison |
475 | // can be runtime dependent and create inderminism. |
476 | // we do not sort by index for now because it can be costly |
477 | // to deep compare Exprs, and address of Vars can be runtime dependent. |
478 | // |
479 | auto fcompare = [](const SplitExpr& lhs, const SplitExpr& rhs) { |
480 | // order by scale first |
481 | if (lhs->scale > rhs->scale) return true; |
482 | if (lhs->scale < rhs->scale) return false; |
483 | // then order by factor |
484 | if (lhs->lower_factor > rhs->lower_factor) return true; |
485 | if (lhs->lower_factor < rhs->lower_factor) return false; |
486 | // then order by upper factor |
487 | if (lhs->upper_factor > rhs->upper_factor) return true; |
488 | if (lhs->upper_factor < rhs->upper_factor) return false; |
489 | // then order by div mode |
490 | if (lhs->div_mode > rhs->div_mode) return true; |
491 | if (lhs->div_mode < rhs->div_mode) return false; |
492 | // tie. |
493 | // TODO(tvm-team) We might consider index as the last comparison point, |
494 | // after we make deep comparator more derministic. |
495 | // Specifically, we can consider comparing names of vars and break ties with address. |
496 | return false; |
497 | }; |
498 | std::stable_sort(args.begin(), args.end(), fcompare); |
499 | return args; |
500 | } |
501 | static PrimExpr Normalize_(DataType dtype, const std::vector<SplitExpr>& args, int64_t base) { |
502 | bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits<int64_t>::lowest() |
503 | : base == -(1LL << (dtype.bits() - 1)); |
504 | // Positive scales first |
505 | PrimExpr res = make_const(dtype, 0); |
506 | for (size_t i = 0; i < args.size(); ++i) { |
507 | if (args[i]->scale > 0) { |
508 | res = res + args[i]->Normalize(); |
509 | } |
510 | } |
511 | if (base > 0 || is_min_value) { |
512 | res = res + make_const(dtype, base); |
513 | } |
514 | // negative scales follows using sub. |
515 | for (size_t i = 0; i < args.size(); ++i) { |
516 | if (args[i]->scale < 0) { |
517 | res = res - args[i]->NormalizeWithScale(-1); |
518 | } |
519 | } |
520 | if (base < 0 && !is_min_value) { |
521 | res = res - make_const(dtype, -base); |
522 | } |
523 | return res; |
524 | } |
525 | }; |
526 | |
/*!
 * \brief Managed reference to SumExprNode.
 *  Provides CopyOnWrite() so that a shared node is copied before it is mutated.
 */
class SumExpr : public PrimExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(SumExpr, PrimExpr, SumExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(SumExprNode);
};
532 | |
533 | void SumExprNode::AddToSelf(const SumExpr& other, int64_t scale) { |
534 | // NOTE: it is rare to have a balanced long expression, |
535 | // linear scan is fine for our case. |
536 | for (size_t i = 0; i < other->args.size(); ++i) { |
537 | this->AddToSelf(other->args[i], scale); |
538 | } |
539 | this->AddToSelf(other->base * scale); |
540 | } |
541 | |
542 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
543 | .set_dispatch<SplitExprNode>([](const ObjectRef& node, ReprPrinter* p) { |
544 | auto* op = static_cast<const SplitExprNode*>(node.get()); |
545 | auto factor_str = [](int64_t f) { |
546 | return f == SplitExprNode::kPosInf ? std::string("+inf" ) : std::to_string(f); |
547 | }; |
548 | p->stream << "split(" ; |
549 | p->Print(op->index); |
550 | p->stream << ", lower=" << factor_str(op->lower_factor) |
551 | << ", upper=" << factor_str(op->upper_factor) << ", scale=" << op->scale |
552 | << ", div_mode=" ; |
553 | switch (op->div_mode) { |
554 | // No "default", so that the compiler will emit a warning if more div modes are |
555 | // added that are not covered by the switch. |
556 | case kTruncDiv: |
557 | p->stream << "truncdiv" ; |
558 | break; |
559 | case kFloorDiv: |
560 | p->stream << "floordiv" ; |
561 | break; |
562 | } |
563 | p->stream << ')'; |
564 | }); |
565 | |
566 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
567 | .set_dispatch<SumExprNode>([](const ObjectRef& node, ReprPrinter* p) { |
568 | auto* op = static_cast<const SumExprNode*>(node.get()); |
569 | p->stream << "sum(base=" << op->base; |
570 | for (const SplitExpr& s : op->args) { |
571 | p->stream << ", " ; |
572 | p->Print(s); |
573 | } |
574 | p->stream << ')'; |
575 | }); |
576 | |
577 | // Sub-class RewriteSimplifier::Impl to take benefit of |
578 | // rewriter for condition simplification etc. |
579 | class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { |
580 | public: |
581 | using Rewriter = RewriteSimplifier::Impl; |
582 | |
583 | explicit Impl(Analyzer* parent) : Rewriter(parent) {} |
584 | |
585 | PrimExpr CanonicalSimplify(PrimExpr expr) { |
586 | expr = operator()(expr); |
587 | return expr; |
588 | } |
589 | |
590 | // override the original mutate function. |
591 | PrimExpr VisitExpr(const PrimExpr& input_expr) final { |
592 | auto expr = Rewriter::VisitExpr(input_expr); |
593 | return Normalize(expr); |
594 | } |
595 | |
596 | // Normal mutation without normalization. |
597 | PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } |
598 | |
599 | using Rewriter::VisitExpr_; |
600 | PrimExpr VisitExpr_(const AddNode* op) final; |
601 | PrimExpr VisitExpr_(const SubNode* op) final; |
602 | PrimExpr VisitExpr_(const MulNode* op) final; |
603 | PrimExpr VisitExpr_(const DivNode* op) final; |
604 | PrimExpr VisitExpr_(const ModNode* op) final; |
605 | PrimExpr VisitExpr_(const FloorDivNode* op) final; |
606 | PrimExpr VisitExpr_(const FloorModNode* op) final; |
607 | PrimExpr VisitExpr_(const ReduceNode* op) final; |
608 | PrimExpr VisitExpr_(const CastNode* op) final; |
609 | |
610 | private: |
611 | /*! |
612 | * \brief compute lhs / cval |
613 | * \param lhs The left operand. |
614 | * \param cval The constant value. |
615 | * \param div_mode The division mode. |
616 | * \return The result expression; |
617 | */ |
618 | SplitExpr SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode); |
619 | /*! |
620 | * \brief compute lhs % cval |
621 | * \param lhs The left operand. |
622 | * \param cval The constant value. |
623 | * \param div_mode The division mode. |
624 | * \return The result expression; |
625 | */ |
626 | SplitExpr SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode); |
627 | /*! |
628 | * \brief Separate psum into divisible and non-divisible parts. |
629 | * \param psum The sum expression. |
630 | * \param coeff The co-efficient. |
631 | * \param out_divisible The result divisible component. |
632 | * \param out_non_divisible The non-divisible component. |
633 | */ |
634 | void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, |
635 | SumExpr* out_non_divisible); |
636 | /*! |
637 | * \brief Normalize expr to normal expr. |
638 | * \param expr The input expression. |
639 | * \return Normalized expr. |
640 | */ |
641 | PrimExpr Normalize(PrimExpr expr) { |
642 | if (const auto* op = expr.as<CanonicalExprNode>()) { |
643 | return op->Normalize(); |
644 | } else { |
645 | return expr; |
646 | } |
647 | } |
648 | /*! |
649 | * \brief Create a SplitExpr from expr. |
650 | * \param expr The input expr. |
651 | * \return The transformed SplitExpr. |
652 | */ |
653 | SplitExpr ToSplitExpr(PrimExpr expr) { |
654 | if (const auto* op = expr.as<SplitExprNode>()) { |
655 | return GetRef<SplitExpr>(op); |
656 | } |
657 | if (const auto* op = expr.as<SumExprNode>()) { |
658 | if (op->base == 0 && op->args.size() == 1) return op->args[0]; |
659 | } |
660 | if (const auto* op = expr.as<CanonicalExprNode>()) { |
661 | expr = op->Normalize(); |
662 | } |
663 | ObjectPtr<SplitExprNode> n = make_object<SplitExprNode>(); |
664 | n->dtype = expr.dtype(); |
665 | n->index = std::move(expr); |
666 | n->div_mode = kTruncDiv; |
667 | return SplitExpr(n); |
668 | } |
669 | /*! |
670 | * \brief Convert expr to an equivalent SplitExpr |
671 | * that has the specified div_mode. |
672 | * |
673 | * This function will return the same expr if its |
674 | * div_mode already satisfies the need. |
675 | * |
676 | * \param expr The input expr. |
677 | * \param div_mode The new div_mode. |
678 | * \return The transformed SplitExpr. |
679 | */ |
680 | SplitExpr ConvertDivMode(SplitExpr expr, DivMode div_mode) { |
681 | if (expr->div_mode == div_mode) return expr; |
682 | if (expr->DivModeCompatibleTo(div_mode)) { |
683 | expr.CopyOnWrite()->div_mode = div_mode; |
684 | return expr; |
685 | } |
686 | expr = ToSplitExpr(Normalize(expr)); |
687 | ICHECK(expr->DivModeCompatibleTo(div_mode)); |
688 | expr.CopyOnWrite()->div_mode = div_mode; |
689 | return expr; |
690 | } |
691 | /*! |
692 | * \brief Create a SumExpr from expr. |
693 | * \param expr The input expr. |
694 | * \return The transformed SumExpr. |
695 | */ |
696 | SumExpr ToSumExpr(PrimExpr expr) { |
697 | if (const auto* op = expr.as<SumExprNode>()) { |
698 | return GetRef<SumExpr>(op); |
699 | } |
700 | ObjectPtr<SumExprNode> n = make_object<SumExprNode>(); |
701 | n->dtype = expr.dtype(); |
702 | if (const auto* op = expr.as<IntImmNode>()) { |
703 | n->base = op->value; |
704 | return SumExpr(n); |
705 | } else { |
706 | n->args.emplace_back(ToSplitExpr(expr)); |
707 | return SumExpr(n); |
708 | } |
709 | } |
710 | // Simplify the combiner used in reduce. |
711 | PrimExpr SimplifyReduceCombiner(const ReduceNode* op); |
712 | }; |
713 | |
714 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { |
715 | if (!IsIndexType(op->dtype)) { |
716 | return Rewriter::VisitExpr_(op); |
717 | } |
718 | // normalize |
719 | PrimExpr a = this->CanonicalMutate(op->a); |
720 | PrimExpr b = this->CanonicalMutate(op->b); |
721 | |
722 | // const folding |
723 | if (auto const_res = TryConstFold<Add>(a, b)) return const_res.value(); |
724 | |
725 | // canonical form simplification. |
726 | SumExpr ret = ToSumExpr(std::move(a)); |
727 | |
728 | if (const auto* op = b.as<IntImmNode>()) { |
729 | ret.CopyOnWrite()->AddToSelf(op->value); |
730 | } else if (const auto* op = b.as<SumExprNode>()) { |
731 | ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), 1); |
732 | } else { |
733 | ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), 1); |
734 | } |
735 | return std::move(ret); |
736 | } |
737 | |
738 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { |
739 | if (!IsIndexType(op->dtype)) { |
740 | return Rewriter::VisitExpr_(op); |
741 | } |
742 | // normalize |
743 | PrimExpr a = this->CanonicalMutate(op->a); |
744 | PrimExpr b = this->CanonicalMutate(op->b); |
745 | |
746 | // const folding |
747 | if (auto const_res = TryConstFold<Sub>(a, b)) return const_res.value(); |
748 | |
749 | // canonical form simplification. |
750 | SumExpr ret = ToSumExpr(std::move(a)); |
751 | |
752 | if (const auto* op = b.as<IntImmNode>()) { |
753 | ret.CopyOnWrite()->AddToSelf(-op->value); |
754 | } else if (const auto* op = b.as<SumExprNode>()) { |
755 | ret.CopyOnWrite()->AddToSelf(GetRef<SumExpr>(op), -1); |
756 | } else { |
757 | ret.CopyOnWrite()->AddToSelf(ToSplitExpr(b), -1); |
758 | } |
759 | return std::move(ret); |
760 | } |
761 | |
762 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { |
763 | if (!IsIndexType(op->dtype)) { |
764 | return Rewriter::VisitExpr_(op); |
765 | } |
766 | // normalize |
767 | PrimExpr a = this->CanonicalMutate(op->a); |
768 | PrimExpr b = this->CanonicalMutate(op->b); |
769 | |
770 | // const folding |
771 | if (auto const_res = TryConstFold<Mul>(a, b)) return const_res.value(); |
772 | |
773 | // x * c |
774 | if (a.as<IntImmNode>()) { |
775 | std::swap(a, b); |
776 | } |
777 | if (const auto* bconst = b.as<IntImmNode>()) { |
778 | if (a.as<SumExprNode>()) { |
779 | SumExpr ret = Downcast<SumExpr>(std::move(a)); |
780 | ret.CopyOnWrite()->MulToSelf(bconst->value); |
781 | return std::move(ret); |
782 | } else { |
783 | SplitExpr ret = ToSplitExpr(std::move(a)); |
784 | ret.CopyOnWrite()->MulToSelf(bconst->value); |
785 | return std::move(ret); |
786 | } |
787 | } |
788 | |
789 | // normal path. |
790 | a = Normalize(a); |
791 | b = Normalize(b); |
792 | if (op->a.same_as(a) && op->b.same_as(b)) { |
793 | return GetRef<PrimExpr>(op); |
794 | } else { |
795 | return Mul(a, b); |
796 | } |
797 | } |
798 | |
799 | void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, |
800 | SumExpr* out_divisible, |
801 | SumExpr* out_non_divisible) { |
802 | auto divisible = make_object<SumExprNode>(); |
803 | auto non_divisible = make_object<SumExprNode>(); |
804 | divisible->dtype = psum->dtype; |
805 | non_divisible->dtype = psum->dtype; |
806 | |
807 | if (psum->base % coeff == 0) { |
808 | divisible->base = psum->base; |
809 | } else { |
810 | non_divisible->base = psum->base; |
811 | } |
812 | for (const auto& e : psum->args) { |
813 | if (e->scale % coeff == 0) { |
814 | divisible->args.push_back(e); |
815 | } else { |
816 | non_divisible->args.push_back(e); |
817 | } |
818 | } |
819 | *out_divisible = SumExpr(divisible); |
820 | *out_non_divisible = SumExpr(non_divisible); |
821 | } |
822 | |
823 | SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { |
824 | ICHECK_GT(cval, 0); |
825 | lhs = ConvertDivMode(lhs, div_mode); |
826 | |
827 | // the following rule works for both floordiv and truncdiv |
828 | if (lhs->scale % cval == 0) { |
829 | lhs.CopyOnWrite()->scale /= cval; |
830 | return lhs; |
831 | } |
832 | |
833 | if (cval % lhs->scale == 0) { |
834 | int64_t scaled_cval = cval / lhs->scale; |
835 | if (lhs->upper_factor == SplitExprNode::kPosInf || |
836 | lhs->upper_factor % (lhs->lower_factor * scaled_cval) == 0) { |
837 | // directly fold division. |
838 | lhs.CopyOnWrite()->scale = 1; |
839 | lhs.CopyOnWrite()->lower_factor *= scaled_cval; |
840 | lhs->Verify(); |
841 | return lhs; |
842 | } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { |
843 | // (x % c1) / c2 => 0 when c2 >= c1 |
844 | return ToSplitExpr(make_zero(lhs.dtype())); |
845 | } else { |
846 | // move the upper_factor modular into index. |
847 | lhs.CopyOnWrite()->index = |
848 | ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode); |
849 | lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; |
850 | lhs.CopyOnWrite()->scale = 1; |
851 | lhs.CopyOnWrite()->lower_factor *= scaled_cval; |
852 | lhs->Verify(); |
853 | return lhs; |
854 | } |
855 | } |
856 | // directly return the split with cval == 1 |
857 | lhs = ToSplitExpr(Normalize(lhs)); |
858 | ICHECK(lhs->DivModeCompatibleTo(div_mode)); |
859 | ICHECK_EQ(lhs->scale, 1); |
860 | lhs.CopyOnWrite()->lower_factor *= cval; |
861 | lhs.CopyOnWrite()->div_mode = div_mode; |
862 | return lhs; |
863 | } |
864 | |
865 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { |
866 | if (!IsIndexType(op->dtype)) { |
867 | return Rewriter::VisitExpr_(op); |
868 | } |
869 | |
870 | PrimExpr a = this->CanonicalMutate(op->a); |
871 | PrimExpr b = this->CanonicalMutate(op->b); |
872 | |
873 | // const folding |
874 | if (auto const_res = TryConstFold<Div>(a, b)) return const_res.value(); |
875 | PVar<IntImm> c1; |
876 | // x / c1 |
877 | if (c1.Match(b) && c1.Eval()->value > 0) { |
878 | int64_t cval = c1.Eval()->value; |
879 | if (cval == 1) return a; |
880 | |
881 | if (const auto* psum = a.as<SumExprNode>()) { |
882 | SumExpr lhs, ; |
883 | SeparateDivisibleParts(psum, cval, &lhs, &extra); |
884 | // can be divided by cval |
885 | if (extra->IsZero()) { |
886 | lhs.CopyOnWrite()->DivideBy(cval); |
887 | return std::move(lhs); |
888 | } |
889 | // both lhs and extra are non-negative |
890 | if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && |
891 | analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { |
892 | lhs.CopyOnWrite()->DivideBy(cval); |
893 | PrimExpr temp = Normalize(extra); |
894 | if (const auto* pconst = temp.as<IntImmNode>()) { |
895 | lhs.CopyOnWrite()->AddToSelf(pconst->value / cval); |
896 | } else { |
897 | // if 0 <= extra < cval, it means the extra can be eliminated. |
898 | if (TryCompare(temp, cval) != CompareResult::kLT) { |
899 | lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); |
900 | } |
901 | } |
902 | return std::move(lhs); |
903 | } |
904 | } else { |
905 | // if a >= 0 && a < cval, then result == 0 |
906 | auto cbound = analyzer_->const_int_bound(Normalize(a)); |
907 | if (cbound->min_value >= 0 && cbound->max_value < cval) { |
908 | return make_zero(a.dtype()); |
909 | } |
910 | } |
911 | return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); |
912 | } |
913 | // normal path |
914 | a = Normalize(a); |
915 | b = Normalize(b); |
916 | if (op->a.same_as(a) && op->b.same_as(b)) { |
917 | return GetRef<PrimExpr>(op); |
918 | } else { |
919 | return Div(a, b); |
920 | } |
921 | } |
922 | |
923 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { |
924 | if (!IsIndexType(op->dtype)) { |
925 | return Rewriter::VisitExpr_(op); |
926 | } |
927 | PrimExpr a = this->CanonicalMutate(op->a); |
928 | PrimExpr b = this->CanonicalMutate(op->b); |
929 | |
930 | // const folding |
931 | if (auto const_res = TryConstFold<FloorDiv>(a, b)) return const_res.value(); |
932 | PVar<IntImm> c1; |
933 | // x / c1 |
934 | if (c1.Match(b) && c1.Eval()->value > 0) { |
935 | int64_t cval = c1.Eval()->value; |
936 | if (cval == 1) return a; |
937 | |
938 | if (const auto* psum = a.as<SumExprNode>()) { |
939 | SumExpr lhs, ; |
940 | SeparateDivisibleParts(psum, cval, &lhs, &extra); |
941 | if (extra->IsZero()) { |
942 | lhs.CopyOnWrite()->DivideBy(cval); |
943 | return std::move(lhs); |
944 | } |
945 | // continue simplification. |
946 | lhs.CopyOnWrite()->DivideBy(cval); |
947 | PrimExpr temp = Normalize(extra); |
948 | if (const auto* pconst = temp.as<IntImmNode>()) { |
949 | lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval)); |
950 | } else { |
951 | // if 0 <= extra < cval, it means the extra can be eliminated. |
952 | if (!(TryCompare(temp, cval) == CompareResult::kLT && |
953 | analyzer_->CanProveGreaterEqual(temp, 0))) { |
954 | lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); |
955 | } |
956 | } |
957 | return std::move(lhs); |
958 | } else { |
959 | // if a >= 0 && a < cval, then result == 0 |
960 | auto cbound = analyzer_->const_int_bound(Normalize(a)); |
961 | if (cbound->min_value >= 0 && cbound->max_value < cval) { |
962 | return make_zero(a.dtype()); |
963 | } |
964 | } |
965 | return SplitDivConst(ToSplitExpr(std::move(a)), cval, kFloorDiv); |
966 | } |
967 | // normal path |
968 | a = Normalize(a); |
969 | b = Normalize(b); |
970 | if (op->a.same_as(a) && op->b.same_as(b)) { |
971 | return GetRef<PrimExpr>(op); |
972 | } else { |
973 | return FloorDiv(a, b); |
974 | } |
975 | } |
976 | |
/*!
 * \brief Take a SplitExpr modulo a positive constant, staying in split form.
 *
 * A SplitExpr stands for `(index % upper_factor) / lower_factor * scale`
 * (see the rationale comment below). The modulus is absorbed into `scale`
 * or `upper_factor` when the constants line up. Otherwise the split is
 * normalized and a fresh split with `upper_factor = cval` is returned.
 *
 * \param lhs The split expression (converted to `div_mode` first).
 * \param cval The modulus; must be positive.
 * \param div_mode Truncated or floor modulo semantics.
 * \return The resulting split expression.
 */
SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) {
  ICHECK_GT(cval, 0);
  lhs = ConvertDivMode(lhs, div_mode);

  // scale is a multiple of cval: the whole term is a multiple of cval, so the
  // result is zero, expressed here by setting scale to 0.
  if (lhs->scale % cval == 0) {
    lhs.CopyOnWrite()->scale = 0;
    return lhs;
  }
  if (cval % lhs->scale == 0) {
    // The rationale:
    // (index % upper) / lower * scale % cval, given cval = scaled_cval * scale
    // by the rule (x * c1) % (c2 * c1) => (x % c2) * c1,
    // = (index % upper) / lower % scaled_cval * scale
    // by the rule (x / c1) % c2 => (x % (c1 * c2)) / c1,
    // = (index % upper) % (new_upper_factor) / lower * scale
    int64_t scaled_cval = cval / lhs->scale;
    int64_t new_upper_factor = lhs->lower_factor * scaled_cval;
    // try to see if we can reduce the existing upper modular.
    if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) {
      // we gained a new upper factor that is smaller
      // than the original one
      // Perhaps there are more chances in simplifying the index
      // Do a recursive call to simplify the mod with the new factor.
      if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
        // Build `index % new_upper_factor` and run it through the visitor again.
        auto updated = ToSplitExpr(this->VisitExpr(
            ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode)));
        // re-apply the lower_factor
        if (lhs->lower_factor != 1) {
          auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
          ret.CopyOnWrite()->MulToSelf(lhs->scale);
          return ret;
        } else {
          updated.CopyOnWrite()->MulToSelf(lhs->scale);
          return updated;
        }
      } else {
        // No tighter modulus to exploit; just record the new upper factor.
        lhs.CopyOnWrite()->upper_factor = new_upper_factor;
        return lhs;
      }
    } else if (new_upper_factor % lhs->upper_factor == 0) {
      // (x % 2) % 4 => x % 2
      return lhs;
    }
  }
  // Normalize the value.
  // Fallback: the constants do not line up, so wrap the normalized value as a
  // plain split (scale == 1, lower_factor == 1) and set upper_factor to cval.
  lhs = ToSplitExpr(Normalize(lhs));
  ICHECK(lhs->DivModeCompatibleTo(div_mode));
  ICHECK_EQ(lhs->scale, 1);
  ICHECK_EQ(lhs->lower_factor, 1);
  lhs.CopyOnWrite()->div_mode = div_mode;
  lhs.CopyOnWrite()->upper_factor = cval;
  return lhs;
}
1030 | |
1031 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { |
1032 | if (!IsIndexType(op->dtype)) { |
1033 | return Rewriter::VisitExpr_(op); |
1034 | } |
1035 | // normalize |
1036 | PrimExpr a = this->CanonicalMutate(op->a); |
1037 | PrimExpr b = this->CanonicalMutate(op->b); |
1038 | |
1039 | // const folding |
1040 | if (auto const_res = TryConstFold<Mod>(a, b)) return const_res.value(); |
1041 | |
1042 | PVar<IntImm> c1; |
1043 | // x % c1 |
1044 | if (c1.Match(b) && c1.Eval()->value > 0) { |
1045 | int64_t cval = c1.Eval()->value; |
1046 | if (const auto* psum = a.as<SumExprNode>()) { |
1047 | SumExpr lhs, ; |
1048 | SeparateDivisibleParts(psum, cval, &lhs, &extra); |
1049 | if (extra->IsZero()) { |
1050 | return make_zero(a.dtype()); |
1051 | } |
1052 | // both lhs and extra are non-negative |
1053 | if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && |
1054 | analyzer_->CanProveGreaterEqual(extra->Normalize(), 0)) { |
1055 | PrimExpr temp = Normalize(extra); |
1056 | if (temp.as<IntImmNode>()) { |
1057 | return truncmod(temp, c1.Eval()); |
1058 | } else { |
1059 | // If temp < cval && temp >=0 then can remove the mod. |
1060 | if (TryCompare(temp, cval) == CompareResult::kLT) { |
1061 | return temp; |
1062 | } else { |
1063 | // contonue to use logic below. |
1064 | a = extra; |
1065 | psum = a.as<SumExprNode>(); |
1066 | ICHECK(psum != nullptr); |
1067 | } |
1068 | } |
1069 | } |
1070 | // Simplify the offset constant if necessary. |
1071 | // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 |
1072 | auto cbound = analyzer_->const_int_bound(Normalize(a)); |
1073 | int64_t new_base = psum->base % cval; |
1074 | if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { |
1075 | SumExpr sum_expr = Downcast<SumExpr>(a); |
1076 | sum_expr.CopyOnWrite()->base = new_base; |
1077 | return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); |
1078 | } |
1079 | } else { |
1080 | // if a >= 0 && a < cval, then result == 0 |
1081 | auto cbound = analyzer_->const_int_bound(Normalize(a)); |
1082 | if (cbound->min_value >= 0 && cbound->max_value < cval) { |
1083 | return a; |
1084 | } |
1085 | } |
1086 | return SplitModConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); |
1087 | } |
1088 | // normal path |
1089 | a = Normalize(a); |
1090 | b = Normalize(b); |
1091 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1092 | return GetRef<PrimExpr>(op); |
1093 | } else { |
1094 | return Mod(a, b); |
1095 | } |
1096 | } |
1097 | |
1098 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { |
1099 | if (!IsIndexType(op->dtype)) { |
1100 | return Rewriter::VisitExpr_(op); |
1101 | } |
1102 | // normalize |
1103 | PrimExpr a = this->CanonicalMutate(op->a); |
1104 | PrimExpr b = this->CanonicalMutate(op->b); |
1105 | |
1106 | // const folding |
1107 | if (auto const_res = TryConstFold<FloorMod>(a, b)) return const_res.value(); |
1108 | |
1109 | PVar<IntImm> c1; |
1110 | // x % c1 |
1111 | if (c1.Match(b) && c1.Eval()->value > 0) { |
1112 | int64_t cval = c1.Eval()->value; |
1113 | if (const auto* psum = a.as<SumExprNode>()) { |
1114 | SumExpr lhs, ; |
1115 | SeparateDivisibleParts(psum, cval, &lhs, &extra); |
1116 | PrimExpr temp = Normalize(extra); |
1117 | if (temp.as<IntImmNode>()) { |
1118 | return floormod(temp, c1.Eval()); |
1119 | } else { |
1120 | // If temp < cval && temp >=0 then can remove the mod. |
1121 | if (TryCompare(temp, cval) == CompareResult::kLT && |
1122 | analyzer_->CanProveGreaterEqual(temp, 0)) { |
1123 | return temp; |
1124 | } else { |
1125 | // contonue to use logic below. |
1126 | a = extra; |
1127 | psum = a.as<SumExprNode>(); |
1128 | ICHECK(psum != nullptr); |
1129 | } |
1130 | } |
1131 | // Simplify the offset constant if necessary. |
1132 | // floormod(x - 5, 3) => floormod(x + 1, 3) |
1133 | int64_t new_base = floormod(psum->base, cval); |
1134 | SumExpr sum_expr = Downcast<SumExpr>(std::move(a)); |
1135 | sum_expr.CopyOnWrite()->base = new_base; |
1136 | return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kFloorDiv); |
1137 | } else { |
1138 | // if a >= 0 && a < cval, then result == a |
1139 | auto cbound = analyzer_->const_int_bound(Normalize(a)); |
1140 | if (cbound->min_value >= 0 && cbound->max_value < cval) { |
1141 | return a; |
1142 | } |
1143 | } |
1144 | return SplitModConst(ToSplitExpr(std::move(a)), cval, kFloorDiv); |
1145 | } |
1146 | // normal path |
1147 | a = Normalize(a); |
1148 | b = Normalize(b); |
1149 | if (op->a.same_as(a) && op->b.same_as(b)) { |
1150 | return GetRef<PrimExpr>(op); |
1151 | } else { |
1152 | return FloorMod(a, b); |
1153 | } |
1154 | } |
1155 | |
1156 | // Simplify reduce expression. |
1157 | PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { |
1158 | // First simplify the results |
1159 | Array<PrimExpr> simplified_result; |
1160 | for (const auto& res : op->combiner->result) { |
1161 | PrimExpr new_res = this->VisitExpr(res); |
1162 | simplified_result.push_back(new_res); |
1163 | } |
1164 | |
1165 | // Which components to keep |
1166 | std::vector<int> used(op->combiner->result.size(), false); |
1167 | |
1168 | // This function recursively marks the used components starting from |
1169 | // the index idx |
1170 | std::function<void(int)> mark_used; |
1171 | mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { |
1172 | // if the idx-th component was marked as used before, do nothing |
1173 | if (used[idx]) return; |
1174 | used[idx] = true; |
1175 | |
1176 | // check if the idx-th result expr uses some lhs or rhs variables |
1177 | // and recursively mark the corresponding components |
1178 | for (size_t i = 0; i < simplified_result.size(); ++i) |
1179 | if (!used[i]) { |
1180 | if (UsesVar(simplified_result[idx], |
1181 | [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) || |
1182 | UsesVar(simplified_result[idx], |
1183 | [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; })) |
1184 | mark_used(i); |
1185 | } |
1186 | }; |
1187 | |
1188 | // mark all used components starting from the value_index |
1189 | mark_used(op->value_index); |
1190 | |
1191 | // components which have side effects should also be preserved |
1192 | for (size_t i = 0; i < used.size(); ++i) { |
1193 | if (SideEffect(op->source[i]) > CallEffectKind::kReadState || |
1194 | SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState || |
1195 | SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState || |
1196 | (!op->init.empty() && SideEffect(op->init[i]) > CallEffectKind::kReadState)) { |
1197 | mark_used(i); |
1198 | } |
1199 | } |
1200 | |
1201 | int new_value_index = op->value_index; |
1202 | Array<PrimExpr> new_result; |
1203 | Array<PrimExpr> new_identity; |
1204 | Array<Var> new_lhs; |
1205 | Array<Var> new_rhs; |
1206 | Array<PrimExpr> new_source; |
1207 | Array<PrimExpr> new_init; |
1208 | |
1209 | // new stuff is old stuff which is used |
1210 | for (size_t i = 0; i < used.size(); ++i) { |
1211 | if (used[i]) { |
1212 | // We simplify the result and identity, but not the source |
1213 | new_result.push_back(simplified_result[i]); |
1214 | new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i])); |
1215 | new_lhs.push_back(op->combiner->lhs[i]); |
1216 | new_rhs.push_back(op->combiner->rhs[i]); |
1217 | new_source.push_back(op->source[i]); |
1218 | if (!op->init.empty()) new_init.push_back(op->init[i]); |
1219 | } else if (static_cast<int>(i) < op->value_index) { |
1220 | // value_index should also be adjusted |
1221 | new_value_index--; |
1222 | } |
1223 | } |
1224 | |
1225 | CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); |
1226 | return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index, new_init); |
1227 | } |
1228 | |
1229 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { |
1230 | // Recursively call simplification when necessary. |
1231 | PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); |
1232 | op = ret.as<ReduceNode>(); |
1233 | // already been simplified by const reduction axis removal |
1234 | if (op == nullptr) return ret; |
1235 | if (op->axis.empty()) { |
1236 | if (!op->init.empty()) { |
1237 | return this->VisitExpr(Select(op->condition, |
1238 | (*op->combiner.get())(op->init, op->source)[op->value_index], |
1239 | op->init[op->value_index])); |
1240 | } |
1241 | // Note that here we assume that the identity element is indeed identity. Without this |
1242 | // assumption we would have to perform a single iteration of the loop, i.e. use |
1243 | // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` |
1244 | // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. |
1245 | return this->VisitExpr(Select(op->condition, op->source[op->value_index], |
1246 | op->combiner->identity_element[op->value_index])); |
1247 | } |
1248 | // combiner simplification. |
1249 | ret = SimplifyReduceCombiner(op); |
1250 | return ret; |
1251 | } |
1252 | |
1253 | PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) { |
1254 | if (!IsIndexType(op->dtype)) { |
1255 | return Rewriter::VisitExpr_(op); |
1256 | } |
1257 | // normalize |
1258 | PrimExpr value = this->CanonicalMutate(op->value); |
1259 | // PushCastToChildren |
1260 | if (value.as<SumExprNode>()) { |
1261 | SumExpr se = Downcast<SumExpr>(value); |
1262 | if (se->CanPushCastToChildren(op->dtype, analyzer_)) { |
1263 | se.CopyOnWrite()->PushCastToChildren(op->dtype); |
1264 | return std::move(se); |
1265 | } |
1266 | } |
1267 | if (value.as<SplitExprNode>()) { |
1268 | SplitExpr se = Downcast<SplitExpr>(value); |
1269 | if (se->CanPushCastToChildren(op->dtype, analyzer_)) { |
1270 | se.CopyOnWrite()->PushCastToChildren(op->dtype); |
1271 | return std::move(se); |
1272 | } |
1273 | } |
1274 | return Rewriter::VisitExpr_(op); |
1275 | } |
1276 | |
1277 | PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { |
1278 | return impl_->CanonicalSimplify(expr); |
1279 | } |
1280 | |
1281 | void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { |
1282 | impl_->Update(var, info, override); |
1283 | } |
1284 | |
1285 | CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} |
1286 | |
1287 | CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } |
1288 | |
1289 | } // namespace arith |
1290 | } // namespace tvm |
1291 | |