1#pragma once
2#include <c10/macros/Macros.h>
3#include <c10/util/ArrayRef.h>
4
5#include <iterator>
6#include <numeric>
7#include <type_traits>
8
9// GCC has __builtin_mul_overflow from before it supported __has_builtin
10#ifdef _MSC_VER
11#define C10_HAS_BUILTIN_OVERFLOW() (0)
12#include <c10/util/llvmMathExtras.h>
13#include <intrin.h>
14#else
15#define C10_HAS_BUILTIN_OVERFLOW() (1)
16#endif
17
18namespace c10 {
19
20C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
21#if C10_HAS_BUILTIN_OVERFLOW()
22 return __builtin_add_overflow(a, b, out);
23#else
24 unsigned long long tmp;
25#if defined(_M_IX86) || defined(_M_X64)
26 auto carry = _addcarry_u64(0, a, b, &tmp);
27#else
28 tmp = a + b;
29 unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
30 auto carry = vector >> 63;
31#endif
32 *out = tmp;
33 return carry;
34#endif
35}
36
37C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
38#if C10_HAS_BUILTIN_OVERFLOW()
39 return __builtin_mul_overflow(a, b, out);
40#else
41 *out = a * b;
42 // This test isnt exact, but avoids doing integer division
43 return (
44 (c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
45#endif
46}
47
48template <typename It>
49bool safe_multiplies_u64(It first, It last, uint64_t* out) {
50#if C10_HAS_BUILTIN_OVERFLOW()
51 uint64_t prod = 1;
52 bool overflow = false;
53 for (; first != last; ++first) {
54 overflow |= c10::mul_overflows(prod, *first, &prod);
55 }
56 *out = prod;
57 return overflow;
58#else
59 uint64_t prod = 1;
60 uint64_t prod_log2 = 0;
61 bool is_zero = false;
62 for (; first != last; ++first) {
63 auto x = static_cast<uint64_t>(*first);
64 prod *= x;
65 // log2(0) isn't valid, so need to track it specially
66 is_zero |= (x == 0);
67 prod_log2 += c10::llvm::Log2_64_Ceil(x);
68 }
69 *out = prod;
70 // This test isnt exact, but avoids doing integer division
71 return !is_zero && (prod_log2 >= 64);
72#endif
73}
74
// Convenience overload: multiplies all elements of the container c by
// forwarding its begin()/end() range to the iterator overload. The product is
// written to *out; the return value is that overload's overflow flag.
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
  const auto range_begin = c.begin();
  const auto range_end = c.end();
  return safe_multiplies_u64(range_begin, range_end, out);
}
79
80} // namespace c10
81