1 | #pragma once |
2 | #include <c10/macros/Macros.h> |
3 | #include <c10/util/ArrayRef.h> |
4 | |
5 | #include <iterator> |
6 | #include <numeric> |
7 | #include <type_traits> |
8 | |
// C10_HAS_BUILTIN_OVERFLOW() selects which implementation the helpers in this
// header compile:
//   (1) - use the compiler's __builtin_*_overflow functions;
//   (0) - MSVC build: use <intrin.h> intrinsics and the c10::llvm bit
//         utilities from llvmMathExtras.h, which are only included here.
//
// The switch is keyed on _MSC_VER instead of probing with __has_builtin:
// GCC has __builtin_mul_overflow from before it supported __has_builtin
#ifdef _MSC_VER
#define C10_HAS_BUILTIN_OVERFLOW() (0)
#include <c10/util/llvmMathExtras.h>
#include <intrin.h>
#else
#define C10_HAS_BUILTIN_OVERFLOW() (1)
#endif
17 | |
18 | namespace c10 { |
19 | |
20 | C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { |
21 | #if C10_HAS_BUILTIN_OVERFLOW() |
22 | return __builtin_add_overflow(a, b, out); |
23 | #else |
24 | unsigned long long tmp; |
25 | #if defined(_M_IX86) || defined(_M_X64) |
26 | auto carry = _addcarry_u64(0, a, b, &tmp); |
27 | #else |
28 | tmp = a + b; |
29 | unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); |
30 | auto carry = vector >> 63; |
31 | #endif |
32 | *out = tmp; |
33 | return carry; |
34 | #endif |
35 | } |
36 | |
37 | C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) { |
38 | #if C10_HAS_BUILTIN_OVERFLOW() |
39 | return __builtin_mul_overflow(a, b, out); |
40 | #else |
41 | *out = a * b; |
42 | // This test isnt exact, but avoids doing integer division |
43 | return ( |
44 | (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64); |
45 | #endif |
46 | } |
47 | |
48 | template <typename It> |
49 | bool safe_multiplies_u64(It first, It last, uint64_t* out) { |
50 | #if C10_HAS_BUILTIN_OVERFLOW() |
51 | uint64_t prod = 1; |
52 | bool overflow = false; |
53 | for (; first != last; ++first) { |
54 | overflow |= c10::mul_overflows(prod, *first, &prod); |
55 | } |
56 | *out = prod; |
57 | return overflow; |
58 | #else |
59 | uint64_t prod = 1; |
60 | uint64_t prod_log2 = 0; |
61 | bool is_zero = false; |
62 | for (; first != last; ++first) { |
63 | auto x = static_cast<uint64_t>(*first); |
64 | prod *= x; |
65 | // log2(0) isn't valid, so need to track it specially |
66 | is_zero |= (x == 0); |
67 | prod_log2 += c10::llvm::Log2_64_Ceil(x); |
68 | } |
69 | *out = prod; |
70 | // This test isnt exact, but avoids doing integer division |
71 | return !is_zero && (prod_log2 >= 64); |
72 | #endif |
73 | } |
74 | |
// Convenience overload: multiplies all elements of a container by forwarding
// its begin()/end() range to the iterator overload of safe_multiplies_u64.
//
// c   - container exposing begin() and end().
// out - receives the wrapped product, as written by the iterator overload.
//
// Returns whatever the iterator overload reports for the range.
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
  auto range_begin = c.begin();
  auto range_end = c.end();
  return safe_multiplies_u64(range_begin, range_end, out);
}
79 | |
80 | } // namespace c10 |
81 | |