1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file const_fold.h
22 * \brief Centralized location for constant folding.
23 */
24#ifndef TVM_ARITH_CONST_FOLD_H_
25#define TVM_ARITH_CONST_FOLD_H_
26
#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "int_operator.h"
36
37namespace tvm {
38namespace arith {
39
40/*!
41 * \brief Try to run binary compute with constant folding.
42 *
43 * \param a The left operand.
44 * \param b The right operand.
45 * \tparam Op The operator type.
46 *
47 * \note a and b Must already matched data types with each other.
48 * \return NullOpt if constant fold fails, otherwise return folded result.
49 */
50template <typename Op>
51inline Optional<PrimExpr> TryConstFold(PrimExpr a, PrimExpr b);
52
53/*!
54 * \brief Try to run unary compute with constant folding.
55 *
56 * \param a The left operand.
57 * \tparam Op The operator type.
58 *
59 * \note a and b Must already matched data types with each other.
60 * \return NullOpt if constant fold fails, otherwise return folded result.
61 */
62template <typename Op>
63inline Optional<PrimExpr> TryConstFold(PrimExpr a);
64
65/*!
66 * \brief Check whether type is used to represent index.
67 *
68 * Index types are frequently used in shape computation
69 * and need to be aggressively constant-folded.
70 *
71 * \param type The type to represent index.
72 * \return the checked result.
73 */
74inline bool IsIndexType(const DataType& type) {
75 return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
76}
77
78/*! \brief Helper to get const folding result repr in int64. */
79inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) {
80 if (dtype.bits() < 64) {
81 x &= (1LL << dtype.bits()) - 1;
82 }
83 if (dtype.is_int()) {
84 // get sign extended value of integer with specified bits
85 int64_t m = 1LL << (dtype.bits() - 1);
86 x = (x ^ m) - m;
87 }
88 return x;
89}
90
91/*! \brief Helper to get fp32 const folding result repr in double. */
92inline double GetFoldResultDoubleRepr(float x) {
93 double res = static_cast<double>(x);
94 if (std::isinf(res) || std::isnan(res)) {
95 return res;
96 }
97 // certain platform (eg, on gcc7-i386) do the folding arithmetic
98 // on float and write back to double is optimized to double
99 // precision arithmetic, this is legal and we check the output
100 // range thus to ensure consistency when the float result is inf.
101 if (res < std::numeric_limits<float>::lowest()) {
102 LOG(WARNING) << "underlying float value overflow";
103 return -std::numeric_limits<double>::infinity();
104 } else if (res > std::numeric_limits<float>::max()) {
105 LOG(WARNING) << "underlying float value overflow";
106 return std::numeric_limits<double>::infinity();
107 }
108 return res;
109}
110
// Prologue shared by the binary constant folders: expects PrimExprs named
// `a` and `b` in scope, declares
//   pa / pb : the operands viewed as IntImmNode
//   fa / fb : the operands viewed as FloatImmNode
// (the folders below null-check these before use), and then expands BODY.
// BODY is a macro argument, so it must not contain commas outside of
// parentheses (braces do not protect commas in macro arguments).
#define TVM_ARITH_CONST_PROPAGATION(BODY)        \
  using tir::FloatImmNode;                       \
  const IntImmNode* pa = a.as<IntImmNode>();     \
  const IntImmNode* pb = b.as<IntImmNode>();     \
  const FloatImmNode* fa = a.as<FloatImmNode>(); \
  const FloatImmNode* fb = b.as<FloatImmNode>(); \
  BODY;
118
// Integer-only variant of TVM_ARITH_CONST_PROPAGATION: declares pa / pb (the
// operands `a` and `b` viewed as IntImmNode) plus their dtypes ta / tb, and
// expands BODY only when both dtypes satisfy arith::IsIndexType (scalar
// int32 / int64). For any other dtype, BODY is skipped entirely.
#define TVM_INDEX_CONST_PROPAGATION(BODY)                   \
  const IntImmNode* pa = a.as<IntImmNode>();                \
  const IntImmNode* pb = b.as<IntImmNode>();                \
  const DataType& ta = a.dtype();                           \
  const DataType& tb = b.dtype();                           \
  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {   \
    BODY;                                                   \
  }
127
// Specializations of the constant folders.
129template <>
130inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
131 TVM_ARITH_CONST_PROPAGATION({
132 const DataType& rtype = a.dtype();
133 if (pa && pb) {
134 int64_t res = pa->value + pb->value;
135 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
136 }
137 if (pa && pa->value == 0) return b;
138 if (pb && pb->value == 0) return a;
139 if (fa && fb) {
140 if (rtype.bits() == 32) {
141 return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) +
142 static_cast<float>(fb->value)));
143 } else if (rtype.bits() == 64) {
144 return FloatImm(rtype, fa->value + fb->value);
145 }
146 }
147 if (fa && fa->value == 0) return b;
148 if (fb && fb->value == 0) return a;
149 });
150 return NullOpt;
151}
152
153template <>
154inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
155 TVM_ARITH_CONST_PROPAGATION({
156 ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) &&
157 (pb && pb->dtype.is_uint() && pb->value > 0U)))
158 << "Checked failed. Minuend 's value is 0U and it's dtype is uint "
159 << "while Subtrahend's dtype is uint; which will cause a negative uint";
160 const DataType& rtype = a.dtype();
161 if (pa && pb) {
162 int64_t res = pa->value - pb->value;
163 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
164 }
165 if (pb && pb->value == 0) return a;
166 if (fa && fb) {
167 if (rtype.bits() == 32) {
168 return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) -
169 static_cast<float>(fb->value)));
170 } else if (rtype.bits() == 64) {
171 return FloatImm(rtype, fa->value - fb->value);
172 }
173 }
174 if (fb && fb->value == 0) return a;
175 });
176 return NullOpt;
177}
178
179template <>
180inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
181 TVM_ARITH_CONST_PROPAGATION({
182 const DataType& rtype = a.dtype();
183 if (pa && pb) {
184 int64_t res = pa->value * pb->value;
185 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
186 }
187 if (pa) {
188 if (pa->value == 1) return b;
189 if (pa->value == 0) return a;
190 }
191 if (pb) {
192 if (pb->value == 1) return a;
193 if (pb->value == 0) return b;
194 }
195 if (fa && fb) {
196 if (rtype.bits() == 32) {
197 return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) *
198 static_cast<float>(fb->value)));
199 } else if (rtype.bits() == 64) {
200 return FloatImm(rtype, fa->value * fb->value);
201 }
202 }
203 if (fa) {
204 if (fa->value == 1) return b;
205 if (fa->value == 0) return a;
206 }
207 if (fb) {
208 if (fb->value == 1) return a;
209 if (fb->value == 0) return b;
210 }
211 });
212 return NullOpt;
213}
214
215template <>
216inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
217 TVM_ARITH_CONST_PROPAGATION({
218 const DataType& rtype = a.dtype();
219 if (pa && pb) {
220 // due to division and mod can have different modes
221 // NOTE: this will assumes truc div.
222 ICHECK_NE(pb->value, 0) << "Divide by zero";
223 int64_t res = pa->value / pb->value;
224 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
225 }
226 if (pa) {
227 if (pa->value == 0) return a;
228 }
229 if (pb) {
230 if (pb->value == 1) return a;
231 ICHECK_NE(pb->value, 0) << "Divide by zero";
232 }
233 if (fa && fb) {
234 ICHECK_NE(fb->value, 0) << "Divide by zero";
235 if (rtype.bits() == 32) {
236 return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) /
237 static_cast<float>(fb->value)));
238 } else if (rtype.bits() == 64) {
239 return FloatImm(rtype, fa->value / fb->value);
240 }
241 }
242 if (fa && fa->value == 0) return a;
243 if (fb) {
244 if (fb->value == 1) return a;
245 ICHECK_NE(fb->value, 0) << "Divide by zero";
246 }
247 });
248 return NullOpt;
249}
250
251template <>
252inline Optional<PrimExpr> TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
253 TVM_INDEX_CONST_PROPAGATION({
254 const DataType& rtype = a.dtype();
255 if (pa && pb) {
256 ICHECK_NE(pb->value, 0) << "Divide by zero";
257 int64_t res = pa->value % pb->value;
258 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
259 }
260 if (pa) {
261 if (pa->value == 0) return a;
262 }
263 if (pb) {
264 if (pb->value == 1) return tir::make_zero(rtype);
265 ICHECK_NE(pb->value, 0) << "Divide by zero";
266 }
267 });
268 return NullOpt;
269}
270
271template <>
272inline Optional<PrimExpr> TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
273 TVM_ARITH_CONST_PROPAGATION({
274 const DataType& rtype = a.dtype();
275 if (pa && pb) {
276 ICHECK_NE(pb->value, 0) << "Divide by zero";
277 int64_t res = arith::floordiv(pa->value, pb->value);
278 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
279 }
280 if (pa) {
281 if (pa->value == 0) return a;
282 }
283 if (pb) {
284 if (pb->value == 1) return a;
285 ICHECK_NE(pb->value, 0) << "Divide by zero";
286 }
287 if (fa && fb && fb->value != 0) {
288 if (rtype.bits() == 32) {
289 return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) /
290 static_cast<float>(fb->value))));
291 } else if (rtype.bits() == 64) {
292 return FloatImm(rtype, std::floor(fa->value / fb->value));
293 } else {
294 return NullOpt;
295 }
296 }
297 if (fa && fa->value == 0) return a;
298 if (fb) {
299 if (fb->value == 1) return a;
300 ICHECK_NE(fb->value, 0) << "Divide by zero";
301 }
302 });
303 return NullOpt;
304}
305
306template <>
307inline Optional<PrimExpr> TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
308 TVM_INDEX_CONST_PROPAGATION({
309 const DataType& rtype = a.dtype();
310 if (pa && pb) {
311 ICHECK_NE(pb->value, 0) << "Divide by zero";
312 int64_t res = arith::floormod(pa->value, pb->value);
313 return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
314 }
315 if (pa) {
316 if (pa->value == 0) return a;
317 }
318 if (pb) {
319 if (pb->value == 1) return tir::make_zero(rtype);
320 ICHECK_NE(pb->value, 0) << "Divide by zero";
321 }
322 });
323 return NullOpt;
324}
325
326template <>
327inline Optional<PrimExpr> TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) {
328 TVM_ARITH_CONST_PROPAGATION({
329 const DataType& rtype = a.dtype();
330 if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
331 if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
332 });
333 if (a.same_as(b)) return a;
334 return NullOpt;
335}
336
337template <>
338inline Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
339 TVM_ARITH_CONST_PROPAGATION({
340 const DataType& rtype = a.dtype();
341 if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
342 if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
343 });
344 if (a.same_as(b)) return a;
345 return NullOpt;
346}
347
348template <>
349inline Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
350 TVM_ARITH_CONST_PROPAGATION({
351 if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
352 if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
353 });
354 return NullOpt;
355}
356
357template <>
358inline Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
359 TVM_ARITH_CONST_PROPAGATION({
360 if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
361 if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
362 });
363 return NullOpt;
364}
365
366template <>
367inline Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
368 TVM_ARITH_CONST_PROPAGATION({
369 if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
370 if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
371 });
372 return NullOpt;
373}
374
375template <>
376inline Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
377 TVM_ARITH_CONST_PROPAGATION({
378 if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
379 if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
380 });
381 return NullOpt;
382}
383
384template <>
385inline Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
386 TVM_ARITH_CONST_PROPAGATION({
387 if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
388 if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
389 });
390 return NullOpt;
391}
392
393template <>
394inline Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
395 TVM_ARITH_CONST_PROPAGATION({
396 if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
397 if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
398 });
399 return NullOpt;
400}
401
402template <>
403inline Optional<PrimExpr> TryConstFold<tir::And>(PrimExpr a, PrimExpr b) {
404 const IntImmNode* pa = a.as<IntImmNode>();
405 const IntImmNode* pb = b.as<IntImmNode>();
406 if (pa && pa->value) return b;
407 if (pa && !pa->value) return a;
408 if (pb && pb->value) return a;
409 if (pb && !pb->value) return b;
410 return NullOpt;
411}
412
413template <>
414inline Optional<PrimExpr> TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) {
415 const IntImmNode* pa = a.as<IntImmNode>();
416 const IntImmNode* pb = b.as<IntImmNode>();
417 if (pa && pa->value) return a;
418 if (pa && !pa->value) return b;
419 if (pb && pb->value) return b;
420 if (pb && !pb->value) return a;
421 return NullOpt;
422}
423
424template <>
425inline Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
426 const IntImmNode* pa = a.as<IntImmNode>();
427 if (pa) {
428 return IntImm(DataType::UInt(1), !(pa->value));
429 }
430 return NullOpt;
431}
432
433/*! \brief Helper namespace for symbolic value limits */
434struct SymbolicLimits {
435 /*! \brief positive infinity */
436 static PrimExpr pos_inf_;
437 /*! \brief negative infinity */
438 static PrimExpr neg_inf_;
439};
440
441/*!
442 * \brief Opaque expression representing positive infinity.
443 *
444 * It can can only be used as parameter of by min/max
445 * for integer analysis and cannot be used in normal expressions.
446 *
447 * \return positive infinity.
448 */
449inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; }
450
451/*!
452 * \brief Check if value is positive infinity.
453 * \param value The value to be checked.
454 *
455 * \return The check result.
456 */
457inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); }
458
459/*!
460 * \brief Opaque expression representing negative infinity.
461 *
462 * It can can only be used as parameter of by min/max
463 * for integer analysis and cannot be used in normal expressions.
464 *
465 * \return negative infinity.
466 */
467inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; }
468
469/*!
470 * \brief Check if value is negative infinity.
471 * \param value The value to be checked.
472 *
473 * \return The check result.
474 */
475inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); }
476
477} // namespace arith
478} // namespace tvm
479#endif // TVM_ARITH_CONST_FOLD_H_
480