1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file int_operator.h |
22 | * \brief Additional useful operators for integer. |
23 | */ |
24 | #ifndef TVM_ARITH_INT_OPERATOR_H_ |
25 | #define TVM_ARITH_INT_OPERATOR_H_ |
26 | |
27 | #include <limits> |
28 | #include <utility> |
29 | |
30 | namespace tvm { |
31 | namespace arith { |
32 | |
/*!
 * \brief Check whether applying the integer operator Op to x and y can
 *        produce a value outside [min_value, max_value].
 *
 * This generic definition is the fallback for every Op that has no explicit
 * specialization: it unconditionally answers false, i.e. such operators are
 * treated as never overflowing.
 *
 * \param x The left operand.
 * \param y The right operand.
 * \param min_value The minimum value of the domain.
 * \param max_value The maximum value of the domain.
 * \return Whether overflow can happen.
 * \tparam Op The integer operator.
 */
template <typename Op>
inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // No operator-specific knowledge here: report "no overflow".
  return false;
}
46 | |
47 | template <> |
48 | inline bool WillOverflow<tir::AddNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { |
49 | if ((y > 0) && (x > max_value - y)) return true; |
50 | if ((y < 0) && (x < min_value - y)) return true; |
51 | return false; |
52 | } |
53 | |
54 | template <> |
55 | inline bool WillOverflow<tir::SubNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { |
56 | if ((y > 0) && (x < min_value + y)) return true; |
57 | if ((y < 0) && (x > max_value + y)) return true; |
58 | return false; |
59 | } |
60 | |
61 | template <> |
62 | inline bool WillOverflow<tir::MulNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { |
63 | if (y == 0) return false; |
64 | if (y > 0) { |
65 | if (x < min_value / y) return true; |
66 | if (x > max_value / y) return true; |
67 | } else { |
68 | if (y == -1 && x == std::numeric_limits<int64_t>::min()) return true; |
69 | if (x > min_value / y) return true; |
70 | if (x < max_value / y) return true; |
71 | } |
72 | return false; |
73 | } |
74 | |
75 | template <> |
76 | inline bool WillOverflow<tir::ModNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { |
77 | return y == 0; |
78 | } |
79 | |
/*!
 * \brief Divide two integers, rounding the quotient toward zero.
 * \param x The dividend.
 * \param y The divisor; must be non-zero.
 * \return x / y, truncated toward zero.
 */
inline int64_t truncdiv(int64_t x, int64_t y) {
  // The built-in integer division already truncates toward zero.
  const int64_t quotient = x / y;
  return quotient;
}
87 | |
/*!
 * \brief Remainder of truncating division.
 *
 * The result has the sign of x (or is zero) and satisfies
 * x == (x / y) * y + result.
 * \param x The dividend.
 * \param y The divisor; must be non-zero.
 * \return The remainder.
 */
inline int64_t truncmod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  return remainder;
}
95 | |
/*!
 * \brief Divide two integers, rounding the quotient toward negative infinity.
 * \param x The dividend.
 * \param y The divisor; must be non-zero.
 * \return floor(x / y).
 */
inline int64_t floordiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;  // truncated toward zero
  const int64_t remainder = x % y;
  // Truncation and flooring disagree exactly when the division is inexact and
  // the remainder's sign differs from the divisor's; the truncated quotient is
  // then one too large.
  const bool truncated_upward = remainder != 0 && ((remainder < 0) != (y < 0));
  return truncated_upward ? quotient - 1 : quotient;
}
108 | |
/*!
 * \brief Remainder of floor division.
 *
 * The result is zero or has the sign of y, and pairs with floordiv so that
 * x == floordiv(x, y) * y + result.
 * \param x The dividend.
 * \param y The divisor; must be non-zero.
 * \return The remainder.
 */
inline int64_t floormod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  // A non-zero remainder whose sign disagrees with y must be shifted by one
  // divisor to land on y's side of zero.
  const bool sign_mismatch = remainder != 0 && ((remainder < 0) != (y < 0));
  return sign_mismatch ? remainder + y : remainder;
}
120 | |
/*!
 * \brief Use the extended Euclidean algorithm to solve a * x + b * y = gcd(a, b).
 *
 * The sign of a is folded into x: the recurrence runs on |a| and x is negated
 * afterwards when a < 0, so a * x + b * y still equals the returned value.
 * No such normalization is applied to b, so for a negative b the returned
 * value may be negative (e.g. a = 4, b = -6 yields -2 with x = 1, y = 1).
 * When b == 0 the result is |a| with x = +-1 (the sign of a, 1 for a == 0) and y = 1.
 *
 * Both Bezout coefficients are carried through the recurrence, so y is obtained
 * without computing a * x, which can overflow int64_t for large operands even
 * when x and y themselves fit.
 *
 * \param a The first coefficient.
 * \param b The second coefficient.
 * \param x Output: the coefficient of a.
 * \param y Output: the coefficient of b.
 * \return The GCD of a and b (see the sign caveat above).
 */
inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
  // Invariant kept for both rows (old_r, old_s, old_t) and (r, s, t):
  //   |a| * s + b * t == r
  // Initial rows:
  //   |a| * 1 + b * 0 == |a|
  //   |a| * 0 + b * 1 == b
  // Each step replaces (r1, r2) by (r2, r1 - q * r2) with q = r1 / r2, applying
  // the same combination to the coefficients. Truncated division guarantees
  // |r1 - q * r2| < |r2|, so the loop terminates.
  int64_t old_r = a >= 0 ? a : -a;
  int64_t r = b;
  int64_t old_s = 1;
  int64_t s = 0;
  int64_t old_t = 0;
  int64_t t = 1;
  while (r != 0) {
    const int64_t q = old_r / r;

    const int64_t next_r = old_r - q * r;
    old_r = r;
    r = next_r;

    const int64_t next_s = old_s - q * s;
    old_s = s;
    s = next_s;

    const int64_t next_t = old_t - q * t;
    old_t = t;
    t = next_t;
  }

  // The recurrence solved for |a|; undo the sign folding.
  *x = a >= 0 ? old_s : -old_s;
  // With b == 0 the loop never ran and any y satisfies the equation; keep y = 1.
  *y = b != 0 ? old_t : 1;

  return old_r;
}
164 | |
/*!
 * \brief Greatest common divisor of |a| and |b|, tolerating zero operands.
 *
 * gcd(v, 0) == |v| and gcd(0, 0) == 0. INT64_MIN is not supported, because
 * negating it overflows.
 * \param a The first operand.
 * \param b The second operand.
 * \return The non-negative GCD.
 */
inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
  int64_t lhs = a < 0 ? -a : a;
  int64_t rhs = b < 0 ? -b : b;
  // Euclid: (lhs, rhs) -> (rhs, lhs % rhs) until the remainder vanishes.
  // If lhs < rhs, the first iteration simply swaps them.
  while (rhs != 0) {
    const int64_t rem = lhs % rhs;
    lhs = rhs;
    rhs = rem;
  }
  return lhs;
}
184 | |
185 | /*! |
186 | * \brief Calculate the least common multiple for two values. |
187 | * \param a an integer number |
188 | * \param b an integer number |
189 | * \return the least common multiple. |
190 | */ |
191 | inline int64_t LeastCommonMultiple(int64_t a, int64_t b) { |
192 | int64_t x, y; |
193 | return (a * b) / ExtendedEuclidean(a, b, &x, &y); |
194 | } |
195 | |
196 | } // namespace arith |
197 | } // namespace tvm |
198 | #endif // TVM_ARITH_INT_OPERATOR_H_ |
199 | |