1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file const_fold.h |
22 | * \brief Centralized location for constant folding. |
23 | */ |
24 | #ifndef TVM_ARITH_CONST_FOLD_H_ |
25 | #define TVM_ARITH_CONST_FOLD_H_ |
26 | |
27 | #include <tvm/runtime/container/optional.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <algorithm> |
32 | #include <cmath> |
33 | #include <limits> |
34 | |
35 | #include "int_operator.h" |
36 | |
37 | namespace tvm { |
38 | namespace arith { |
39 | |
/*!
 * \brief Try to run binary compute with constant folding.
 *
 * Only the primary template is declared here; the behavior for each
 * operator comes from the explicit specializations for that operator type.
 *
 * \param a The left operand.
 * \param b The right operand.
 * \tparam Op The operator type.
 *
 * \note a and b must already have matching data types.
 * \return NullOpt if constant folding fails, otherwise the folded result.
 */
template <typename Op>
inline Optional<PrimExpr> TryConstFold(PrimExpr a, PrimExpr b);
52 | |
/*!
 * \brief Try to run unary compute with constant folding.
 *
 * Only the primary template is declared here; the behavior for each
 * operator comes from the explicit specializations for that operator type.
 *
 * \param a The operand.
 * \tparam Op The operator type.
 *
 * \return NullOpt if constant folding fails, otherwise the folded result.
 */
template <typename Op>
inline Optional<PrimExpr> TryConstFold(PrimExpr a);
64 | |
65 | /*! |
66 | * \brief Check whether type is used to represent index. |
67 | * |
68 | * Index types are frequently used in shape computation |
69 | * and need to be aggressively constant-folded. |
70 | * |
71 | * \param type The type to represent index. |
72 | * \return the checked result. |
73 | */ |
74 | inline bool IsIndexType(const DataType& type) { |
75 | return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); |
76 | } |
77 | |
78 | /*! \brief Helper to get const folding result repr in int64. */ |
79 | inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) { |
80 | if (dtype.bits() < 64) { |
81 | x &= (1LL << dtype.bits()) - 1; |
82 | } |
83 | if (dtype.is_int()) { |
84 | // get sign extended value of integer with specified bits |
85 | int64_t m = 1LL << (dtype.bits() - 1); |
86 | x = (x ^ m) - m; |
87 | } |
88 | return x; |
89 | } |
90 | |
91 | /*! \brief Helper to get fp32 const folding result repr in double. */ |
92 | inline double GetFoldResultDoubleRepr(float x) { |
93 | double res = static_cast<double>(x); |
94 | if (std::isinf(res) || std::isnan(res)) { |
95 | return res; |
96 | } |
97 | // certain platform (eg, on gcc7-i386) do the folding arithmetic |
98 | // on float and write back to double is optimized to double |
99 | // precision arithmetic, this is legal and we check the output |
100 | // range thus to ensure consistency when the float result is inf. |
101 | if (res < std::numeric_limits<float>::lowest()) { |
102 | LOG(WARNING) << "underlying float value overflow" ; |
103 | return -std::numeric_limits<double>::infinity(); |
104 | } else if (res > std::numeric_limits<float>::max()) { |
105 | LOG(WARNING) << "underlying float value overflow" ; |
106 | return std::numeric_limits<double>::infinity(); |
107 | } |
108 | return res; |
109 | } |
110 | |
/*!
 * \brief Common prologue of the binary constant folders.
 *
 * Must be expanded in a scope where expressions named `a` and `b` exist.
 * It declares `pa`/`pb` (a and b viewed as IntImmNode) and `fa`/`fb`
 * (a and b viewed as FloatImmNode), then expands BODY, which may refer to
 * these four pointers. The folders test each pointer for null before use.
 *
 * \note BODY is a single macro argument, so it must not contain a comma
 *       that is not enclosed in parentheses.
 */
#define TVM_ARITH_CONST_PROPAGATION(BODY)        \
  using tir::FloatImmNode;                       \
  const IntImmNode* pa = a.as<IntImmNode>();     \
  const IntImmNode* pb = b.as<IntImmNode>();     \
  const FloatImmNode* fa = a.as<FloatImmNode>(); \
  const FloatImmNode* fb = b.as<FloatImmNode>(); \
  BODY;
118 | |
/*!
 * \brief Prologue of the constant folders that only apply to index types.
 *
 * Must be expanded in a scope where expressions named `a` and `b` exist.
 * It declares `pa`/`pb` (a and b viewed as IntImmNode) and `ta`/`tb` (the
 * data types of a and b). BODY is expanded only when arith::IsIndexType
 * accepts both data types; otherwise nothing is executed.
 *
 * \note BODY is a single macro argument, so it must not contain a comma
 *       that is not enclosed in parentheses.
 */
#define TVM_INDEX_CONST_PROPAGATION(BODY)                 \
  const IntImmNode* pa = a.as<IntImmNode>();              \
  const IntImmNode* pb = b.as<IntImmNode>();              \
  const DataType& ta = a.dtype();                         \
  const DataType& tb = b.dtype();                         \
  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
    BODY;                                                 \
  }
127 | |
128 | // specialization of constant folders. |
129 | template <> |
130 | inline Optional<PrimExpr> TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) { |
131 | TVM_ARITH_CONST_PROPAGATION({ |
132 | const DataType& rtype = a.dtype(); |
133 | if (pa && pb) { |
134 | int64_t res = pa->value + pb->value; |
135 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
136 | } |
137 | if (pa && pa->value == 0) return b; |
138 | if (pb && pb->value == 0) return a; |
139 | if (fa && fb) { |
140 | if (rtype.bits() == 32) { |
141 | return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) + |
142 | static_cast<float>(fb->value))); |
143 | } else if (rtype.bits() == 64) { |
144 | return FloatImm(rtype, fa->value + fb->value); |
145 | } |
146 | } |
147 | if (fa && fa->value == 0) return b; |
148 | if (fb && fb->value == 0) return a; |
149 | }); |
150 | return NullOpt; |
151 | } |
152 | |
153 | template <> |
154 | inline Optional<PrimExpr> TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) { |
155 | TVM_ARITH_CONST_PROPAGATION({ |
156 | ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && |
157 | (pb && pb->dtype.is_uint() && pb->value > 0U))) |
158 | << "Checked failed. Minuend 's value is 0U and it's dtype is uint " |
159 | << "while Subtrahend's dtype is uint; which will cause a negative uint" ; |
160 | const DataType& rtype = a.dtype(); |
161 | if (pa && pb) { |
162 | int64_t res = pa->value - pb->value; |
163 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
164 | } |
165 | if (pb && pb->value == 0) return a; |
166 | if (fa && fb) { |
167 | if (rtype.bits() == 32) { |
168 | return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) - |
169 | static_cast<float>(fb->value))); |
170 | } else if (rtype.bits() == 64) { |
171 | return FloatImm(rtype, fa->value - fb->value); |
172 | } |
173 | } |
174 | if (fb && fb->value == 0) return a; |
175 | }); |
176 | return NullOpt; |
177 | } |
178 | |
179 | template <> |
180 | inline Optional<PrimExpr> TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) { |
181 | TVM_ARITH_CONST_PROPAGATION({ |
182 | const DataType& rtype = a.dtype(); |
183 | if (pa && pb) { |
184 | int64_t res = pa->value * pb->value; |
185 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
186 | } |
187 | if (pa) { |
188 | if (pa->value == 1) return b; |
189 | if (pa->value == 0) return a; |
190 | } |
191 | if (pb) { |
192 | if (pb->value == 1) return a; |
193 | if (pb->value == 0) return b; |
194 | } |
195 | if (fa && fb) { |
196 | if (rtype.bits() == 32) { |
197 | return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) * |
198 | static_cast<float>(fb->value))); |
199 | } else if (rtype.bits() == 64) { |
200 | return FloatImm(rtype, fa->value * fb->value); |
201 | } |
202 | } |
203 | if (fa) { |
204 | if (fa->value == 1) return b; |
205 | if (fa->value == 0) return a; |
206 | } |
207 | if (fb) { |
208 | if (fb->value == 1) return a; |
209 | if (fb->value == 0) return b; |
210 | } |
211 | }); |
212 | return NullOpt; |
213 | } |
214 | |
215 | template <> |
216 | inline Optional<PrimExpr> TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) { |
217 | TVM_ARITH_CONST_PROPAGATION({ |
218 | const DataType& rtype = a.dtype(); |
219 | if (pa && pb) { |
220 | // due to division and mod can have different modes |
221 | // NOTE: this will assumes truc div. |
222 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
223 | int64_t res = pa->value / pb->value; |
224 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
225 | } |
226 | if (pa) { |
227 | if (pa->value == 0) return a; |
228 | } |
229 | if (pb) { |
230 | if (pb->value == 1) return a; |
231 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
232 | } |
233 | if (fa && fb) { |
234 | ICHECK_NE(fb->value, 0) << "Divide by zero" ; |
235 | if (rtype.bits() == 32) { |
236 | return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) / |
237 | static_cast<float>(fb->value))); |
238 | } else if (rtype.bits() == 64) { |
239 | return FloatImm(rtype, fa->value / fb->value); |
240 | } |
241 | } |
242 | if (fa && fa->value == 0) return a; |
243 | if (fb) { |
244 | if (fb->value == 1) return a; |
245 | ICHECK_NE(fb->value, 0) << "Divide by zero" ; |
246 | } |
247 | }); |
248 | return NullOpt; |
249 | } |
250 | |
251 | template <> |
252 | inline Optional<PrimExpr> TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) { |
253 | TVM_INDEX_CONST_PROPAGATION({ |
254 | const DataType& rtype = a.dtype(); |
255 | if (pa && pb) { |
256 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
257 | int64_t res = pa->value % pb->value; |
258 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
259 | } |
260 | if (pa) { |
261 | if (pa->value == 0) return a; |
262 | } |
263 | if (pb) { |
264 | if (pb->value == 1) return tir::make_zero(rtype); |
265 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
266 | } |
267 | }); |
268 | return NullOpt; |
269 | } |
270 | |
271 | template <> |
272 | inline Optional<PrimExpr> TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) { |
273 | TVM_ARITH_CONST_PROPAGATION({ |
274 | const DataType& rtype = a.dtype(); |
275 | if (pa && pb) { |
276 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
277 | int64_t res = arith::floordiv(pa->value, pb->value); |
278 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
279 | } |
280 | if (pa) { |
281 | if (pa->value == 0) return a; |
282 | } |
283 | if (pb) { |
284 | if (pb->value == 1) return a; |
285 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
286 | } |
287 | if (fa && fb && fb->value != 0) { |
288 | if (rtype.bits() == 32) { |
289 | return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) / |
290 | static_cast<float>(fb->value)))); |
291 | } else if (rtype.bits() == 64) { |
292 | return FloatImm(rtype, std::floor(fa->value / fb->value)); |
293 | } else { |
294 | return NullOpt; |
295 | } |
296 | } |
297 | if (fa && fa->value == 0) return a; |
298 | if (fb) { |
299 | if (fb->value == 1) return a; |
300 | ICHECK_NE(fb->value, 0) << "Divide by zero" ; |
301 | } |
302 | }); |
303 | return NullOpt; |
304 | } |
305 | |
306 | template <> |
307 | inline Optional<PrimExpr> TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) { |
308 | TVM_INDEX_CONST_PROPAGATION({ |
309 | const DataType& rtype = a.dtype(); |
310 | if (pa && pb) { |
311 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
312 | int64_t res = arith::floormod(pa->value, pb->value); |
313 | return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); |
314 | } |
315 | if (pa) { |
316 | if (pa->value == 0) return a; |
317 | } |
318 | if (pb) { |
319 | if (pb->value == 1) return tir::make_zero(rtype); |
320 | ICHECK_NE(pb->value, 0) << "Divide by zero" ; |
321 | } |
322 | }); |
323 | return NullOpt; |
324 | } |
325 | |
326 | template <> |
327 | inline Optional<PrimExpr> TryConstFold<tir::Min>(PrimExpr a, PrimExpr b) { |
328 | TVM_ARITH_CONST_PROPAGATION({ |
329 | const DataType& rtype = a.dtype(); |
330 | if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); |
331 | if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); |
332 | }); |
333 | if (a.same_as(b)) return a; |
334 | return NullOpt; |
335 | } |
336 | |
337 | template <> |
338 | inline Optional<PrimExpr> TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) { |
339 | TVM_ARITH_CONST_PROPAGATION({ |
340 | const DataType& rtype = a.dtype(); |
341 | if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); |
342 | if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); |
343 | }); |
344 | if (a.same_as(b)) return a; |
345 | return NullOpt; |
346 | } |
347 | |
348 | template <> |
349 | inline Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) { |
350 | TVM_ARITH_CONST_PROPAGATION({ |
351 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); |
352 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); |
353 | }); |
354 | return NullOpt; |
355 | } |
356 | |
357 | template <> |
358 | inline Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) { |
359 | TVM_ARITH_CONST_PROPAGATION({ |
360 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); |
361 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); |
362 | }); |
363 | return NullOpt; |
364 | } |
365 | |
366 | template <> |
367 | inline Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) { |
368 | TVM_ARITH_CONST_PROPAGATION({ |
369 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); |
370 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); |
371 | }); |
372 | return NullOpt; |
373 | } |
374 | |
375 | template <> |
376 | inline Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) { |
377 | TVM_ARITH_CONST_PROPAGATION({ |
378 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); |
379 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); |
380 | }); |
381 | return NullOpt; |
382 | } |
383 | |
384 | template <> |
385 | inline Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) { |
386 | TVM_ARITH_CONST_PROPAGATION({ |
387 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); |
388 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); |
389 | }); |
390 | return NullOpt; |
391 | } |
392 | |
393 | template <> |
394 | inline Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) { |
395 | TVM_ARITH_CONST_PROPAGATION({ |
396 | if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); |
397 | if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); |
398 | }); |
399 | return NullOpt; |
400 | } |
401 | |
402 | template <> |
403 | inline Optional<PrimExpr> TryConstFold<tir::And>(PrimExpr a, PrimExpr b) { |
404 | const IntImmNode* pa = a.as<IntImmNode>(); |
405 | const IntImmNode* pb = b.as<IntImmNode>(); |
406 | if (pa && pa->value) return b; |
407 | if (pa && !pa->value) return a; |
408 | if (pb && pb->value) return a; |
409 | if (pb && !pb->value) return b; |
410 | return NullOpt; |
411 | } |
412 | |
413 | template <> |
414 | inline Optional<PrimExpr> TryConstFold<tir::Or>(PrimExpr a, PrimExpr b) { |
415 | const IntImmNode* pa = a.as<IntImmNode>(); |
416 | const IntImmNode* pb = b.as<IntImmNode>(); |
417 | if (pa && pa->value) return a; |
418 | if (pa && !pa->value) return b; |
419 | if (pb && pb->value) return b; |
420 | if (pb && !pb->value) return a; |
421 | return NullOpt; |
422 | } |
423 | |
424 | template <> |
425 | inline Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) { |
426 | const IntImmNode* pa = a.as<IntImmNode>(); |
427 | if (pa) { |
428 | return IntImm(DataType::UInt(1), !(pa->value)); |
429 | } |
430 | return NullOpt; |
431 | } |
432 | |
/*!
 * \brief Helper struct holding the symbolic value limits.
 *
 * The static members are only declared here; their definitions are not part
 * of this header.
 */
struct SymbolicLimits {
  /*! \brief positive infinity */
  static PrimExpr pos_inf_;
  /*! \brief negative infinity */
  static PrimExpr neg_inf_;
};
440 | |
441 | /*! |
442 | * \brief Opaque expression representing positive infinity. |
443 | * |
444 | * It can can only be used as parameter of by min/max |
445 | * for integer analysis and cannot be used in normal expressions. |
446 | * |
447 | * \return positive infinity. |
448 | */ |
449 | inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; } |
450 | |
451 | /*! |
452 | * \brief Check if value is positive infinity. |
453 | * \param value The value to be checked. |
454 | * |
455 | * \return The check result. |
456 | */ |
457 | inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); } |
458 | |
459 | /*! |
460 | * \brief Opaque expression representing negative infinity. |
461 | * |
462 | * It can can only be used as parameter of by min/max |
463 | * for integer analysis and cannot be used in normal expressions. |
464 | * |
465 | * \return negative infinity. |
466 | */ |
467 | inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; } |
468 | |
469 | /*! |
470 | * \brief Check if value is negative infinity. |
471 | * \param value The value to be checked. |
472 | * |
473 | * \return The check result. |
474 | */ |
475 | inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); } |
476 | |
477 | } // namespace arith |
478 | } // namespace tvm |
479 | #endif // TVM_ARITH_CONST_FOLD_H_ |
480 | |