1#pragma once
2#include <c10/macros/Macros.h>
3#include <c10/util/BFloat16.h>
4#include <c10/util/Half.h>
5
6#include <type_traits>
7
8C10_CLANG_DIAGNOSTIC_PUSH()
9#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
10C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
11#endif
12#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
13C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
14#endif
15
16namespace c10 {
17
18template <typename dest_t, typename src_t>
19struct needs_real {
20 constexpr static bool value =
21 (is_complex<src_t>::value && !is_complex<dest_t>::value);
22};
23
24template <bool, typename src_t>
25struct maybe_real {
26 C10_HOST_DEVICE static inline src_t apply(src_t src) {
27 return src;
28 }
29};
30
31template <typename src_t>
32struct maybe_real<true, src_t> {
33 C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
34 return src.real();
35 }
36};
37
38// Note: deliberately ignores undefined behavior, consistent with NumPy.
39// PyTorch's type conversions can cause a variety of undefined behavior,
40// including float to integral overflow and signed to unsigned integer overflow.
41// Some of this undefined behavior is addressed below.
42template <typename dest_t, typename src_t>
43struct static_cast_with_inter_type {
44 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply(
45 src_t src) {
46 constexpr bool real = needs_real<dest_t, src_t>::value;
47 auto r = maybe_real<real, src_t>::apply(src);
48 return static_cast<dest_t>(r);
49 }
50};
51
52// Partial template instantiation for casting to uint8.
53// Note: Converting from negative float values to unsigned integer types is
54// undefined behavior in C++, and current CPU and GPU compilers exhibit
55// divergent behavior. Casting from negative float values to signed
56// integer types and then to unsigned integer types is not undefined,
57// however, so this cast improves the consistency of type conversions
58// to uint8 across compilers.
59// Further note: Type conversions across compilers still have other undefined
60// and divergent behavior.
61template <typename src_t>
62struct static_cast_with_inter_type<uint8_t, src_t> {
63 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply(
64 src_t src) {
65 constexpr bool real = needs_real<uint8_t, src_t>::value;
66 return static_cast<uint8_t>(
67 static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
68 }
69};
70
71template <>
72struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> {
73 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
74 c10::Half>
75 apply(c10::BFloat16 src) {
76 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
77 }
78};
79
80template <>
81struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
82 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
83 c10::Half>
84 apply(c10::Half src) {
85 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
86 }
87};
88
89template <>
90struct static_cast_with_inter_type<
91 c10::complex<c10::Half>,
92 c10::complex<double>> {
93 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
94 c10::Half>
95 apply(c10::complex<double> src) {
96 return static_cast<c10::complex<c10::Half>>(
97 static_cast<c10::complex<float>>(src));
98 }
99};
100
101template <typename To, typename From>
102C10_HOST_DEVICE To convert(From f) {
103 return static_cast_with_inter_type<To, From>::apply(f);
104}
105
// Reports that converting a value overflowed its target type; `name` is the
// label passed through from checked_convert to identify what was converted.
// Only declared here and defined out of line, so its body is not inlined into
// every checked_convert instantiation (this keeps code size down).
// NOTE(review): whether this throws, aborts or only logs is decided by the
// out-of-line definition, which is not visible here — confirm before relying
// on it.
C10_API void report_overflow(const char* name);
108
109template <typename To, typename From>
110To checked_convert(From f, const char* name) {
111 // Converting to bool can't overflow so we exclude this case from checking.
112 if (!std::is_same<To, bool>::value && overflows<To, From>(f)) {
113 report_overflow(name);
114 }
115 return convert<To, From>(f);
116}
117
118} // namespace c10
119
120C10_CLANG_DIAGNOSTIC_POP()
121
122// Trigger tests for D25440771. TODO: Remove this line any time you want.
123