1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file int_operator.h
22 * \brief Additional useful operators for integer.
23 */
24#ifndef TVM_ARITH_INT_OPERATOR_H_
25#define TVM_ARITH_INT_OPERATOR_H_
26
27#include <limits>
28#include <utility>
29
30namespace tvm {
31namespace arith {
32
/*!
 * \brief Check whether applying an integer operator to x and y overflows the domain.
 *
 * This primary template is the fallback used when no specialization exists
 * for the operator: it ignores its arguments and reports no overflow.
 *
 * \tparam Op The integer operator.
 * \param x The left operand.
 * \param y The right operand.
 * \param min_value The minimum value of the domain.
 * \param max_value The maximum value of the domain.
 * \return Whether overflow can happen.
 */
template <typename Op>
inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Generic operators are treated as never overflowing.
  constexpr bool kCanOverflow = false;
  return kCanOverflow;
}
46
47template <>
48inline bool WillOverflow<tir::AddNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
49 if ((y > 0) && (x > max_value - y)) return true;
50 if ((y < 0) && (x < min_value - y)) return true;
51 return false;
52}
53
54template <>
55inline bool WillOverflow<tir::SubNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
56 if ((y > 0) && (x < min_value + y)) return true;
57 if ((y < 0) && (x > max_value + y)) return true;
58 return false;
59}
60
61template <>
62inline bool WillOverflow<tir::MulNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
63 if (y == 0) return false;
64 if (y > 0) {
65 if (x < min_value / y) return true;
66 if (x > max_value / y) return true;
67 } else {
68 if (y == -1 && x == std::numeric_limits<int64_t>::min()) return true;
69 if (x > min_value / y) return true;
70 if (x < max_value / y) return true;
71 }
72 return false;
73}
74
75template <>
76inline bool WillOverflow<tir::ModNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
77 return y == 0;
78}
79
/*!
 * \brief Divide two integers, rounding the quotient toward zero.
 * \param x The dividend.
 * \param y The divisor.
 * \return The quotient x / y as computed by the built-in integer division.
 */
inline int64_t truncdiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;
  return quotient;
}
87
/*!
 * \brief Compute the remainder that pairs with truncating division.
 * \param x The dividend.
 * \param y The divisor.
 * \return The remainder x % y as computed by the built-in integer modulo.
 */
inline int64_t truncmod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  return remainder;
}
95
/*!
 * \brief Divide two integers, rounding the quotient toward negative infinity.
 * \param x The dividend.
 * \param y The divisor.
 * \return The floored quotient of x and y.
 */
inline int64_t floordiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;
  const int64_t remainder = x % y;
  // The truncated quotient is one too large exactly when the remainder is
  // nonzero and its sign is opposite to the sign of the divisor.
  const bool round_down = (y >= 0) ? (remainder < 0) : (remainder > 0);
  if (round_down) return quotient - 1;
  return quotient;
}
108
/*!
 * \brief Compute the remainder that pairs with floor division.
 * \param x The dividend.
 * \param y The divisor.
 * \return The remainder, which is zero or has the same sign as y.
 */
inline int64_t floormod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  // A nonzero truncated remainder whose sign is opposite to the divisor's
  // has to be shifted by one multiple of the divisor.
  const bool needs_shift = (y >= 0) ? (remainder < 0) : (remainder > 0);
  if (needs_shift) return remainder + y;
  return remainder;
}
120
/*!
 * \brief Use the Extended Euclidean algorithm to solve a * x + b * y = gcd(a, b).
 *
 * Both Bezout coefficients are carried through the iteration, so the product
 * x * a is never formed (it can exceed the int64_t range for large inputs).
 *
 * \param a The first coefficient. Its absolute value is taken internally, so
 *          the minimum int64_t value is not supported.
 * \param b The second coefficient.
 * \param x Output: the solution of x.
 * \param y Output: the solution of y. Set to 1 when b is 0.
 * \return The value g for which a * x + b * y = g holds. This is the last
 *         nonzero remainder of the iteration started from (|a|, b); it can be
 *         negative when b is negative.
 */
inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
  // If a < 0 the problem is converted into
  //   |a| * (-x) + b * y = gcd(|a|, b)
  // so the sign of x is flipped once the iteration is done.
  //
  // Invariants kept by the loop (initially |a| * 1 + b * 0 = |a| and
  // |a| * 0 + b * 1 = b):
  //   |a| * old_s + b * old_t = old_r
  //   |a| * s     + b * t     = r
  int64_t old_r = a >= 0 ? a : -a;
  int64_t r = b;
  int64_t old_s = 1;
  int64_t s = 0;
  int64_t old_t = 0;
  int64_t t = 1;
  // With q = old_r / r, subtracting q times the second equation from the
  // first gives the equation for the next remainder old_r - q * r, whose
  // magnitude is smaller than |r|; hence the loop terminates.
  while (r != 0) {
    const int64_t q = old_r / r;

    const int64_t next_r = old_r - q * r;
    old_r = r;
    r = next_r;

    const int64_t next_s = old_s - q * s;
    old_s = s;
    s = next_s;

    const int64_t next_t = old_t - q * t;
    old_t = t;
    t = next_t;
  }

  *x = a >= 0 ? old_s : -old_s;
  // For b == 0 any y satisfies the equation; 1 is reported by convention.
  *y = b != 0 ? old_t : 1;
  return old_r;
}
164
/*!
 * \brief Compute the greatest common divisor of |a| and |b|.
 *
 * A zero operand is allowed: the result is then the absolute value of the
 * other operand (and 0 when both are zero).
 *
 * \param a The first operand.
 * \param b The second operand.
 * \return The result.
 */
inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
  int64_t larger = a < 0 ? -a : a;
  int64_t smaller = b < 0 ? -b : b;
  if (larger < smaller) std::swap(larger, smaller);
  // Euclid's algorithm: replace (larger, smaller) by (smaller, larger % smaller)
  // until the remainder vanishes; the surviving value is the GCD.
  while (smaller != 0) {
    const int64_t remainder = larger % smaller;
    larger = smaller;
    smaller = remainder;
  }
  return larger;
}
184
185/*!
186 * \brief Calculate the least common multiple for two values.
187 * \param a an integer number
188 * \param b an integer number
189 * \return the least common multiple.
190 */
191inline int64_t LeastCommonMultiple(int64_t a, int64_t b) {
192 int64_t x, y;
193 return (a * b) / ExtendedEuclidean(a, b, &x, &y);
194}
195
196} // namespace arith
197} // namespace tvm
198#endif // TVM_ARITH_INT_OPERATOR_H_
199