1 | #pragma once |
2 | |
3 | #ifdef __HIPCC__ |
4 | #include <hip/hip_runtime.h> |
5 | #endif |
6 | |
7 | #include <c10/macros/Macros.h> |
8 | #include <c10/util/BFloat16.h> |
9 | #include <c10/util/Half.h> |
10 | #include <c10/util/complex.h> |
11 | |
12 | #include <cmath> |
13 | #include <type_traits> |
14 | |
15 | namespace at { |
16 | |
// std::isnan isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// at::_isnan is performant: its integral overload just returns false.
20 | |
// Integral overload of _isnan. The argument is never inspected and the result
// is always false, so no conversion to floating point takes place.
template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
  return false;
}
27 | |
// Floating-point overload of _isnan (enabled when std::is_floating_point<T>
// holds). Without __CUDACC__ / __HIPCC__ it forwards to std::isnan; when either
// macro is defined it forwards to the global-namespace ::isnan instead.
template <
    typename T,
    std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isnan(val);
#else
  return ::isnan(val);
#endif
}
38 | |
39 | template < |
40 | typename T, |
41 | typename std::enable_if<c10::is_complex<T>::value, int>::type = 0> |
42 | inline bool _isnan(T val) { |
43 | return std::isnan(val.real()) || std::isnan(val.imag()); |
44 | } |
45 | |
46 | template < |
47 | typename T, |
48 | typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0> |
49 | inline C10_HOST_DEVICE bool _isnan(T val) { |
50 | return at::_isnan(static_cast<float>(val)); |
51 | } |
52 | |
53 | template < |
54 | typename T, |
55 | typename std::enable_if<std::is_same<T, at::BFloat16>::value, int>::type = |
56 | 0> |
57 | inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { |
58 | return at::_isnan(static_cast<float>(val)); |
59 | } |
60 | |
61 | inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { |
62 | return at::_isnan(static_cast<float>(val)); |
63 | } |
64 | |
// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// at::_isinf is performant: its integral overload just returns false.
68 | |
// Integral overload of _isinf. The argument is never inspected and the result
// is always false, so no conversion to floating point takes place.
template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
  return false;
}
75 | |
// Floating-point overload of _isinf (enabled when std::is_floating_point<T>
// holds). Without __CUDACC__ / __HIPCC__ it forwards to std::isinf; when either
// macro is defined it forwards to the global-namespace ::isinf instead.
template <
    typename T,
    std::enable_if_t<std::is_floating_point<T>::value, int> = 0>
inline C10_HOST_DEVICE bool _isinf(T val) {
#if !defined(__CUDACC__) && !defined(__HIPCC__)
  return std::isinf(val);
#else
  return ::isinf(val);
#endif
}
86 | |
87 | inline C10_HOST_DEVICE bool _isinf(at::Half val) { |
88 | return at::_isinf(static_cast<float>(val)); |
89 | } |
90 | |
91 | inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) { |
92 | return at::_isinf(static_cast<float>(val)); |
93 | } |
94 | |
// exp for float and less precise types. Instantiating this primary template
// with double is rejected by the static_assert; double is served by the
// explicit specialization that follows. When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined the call goes to the global-namespace ::exp;
// otherwise the __expf intrinsic is used.
template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
  static_assert(
      !std::is_same<T, double>::value,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::exp(x);
#else
  // use __expf fast approximation for peak bandwidth
  return __expf(x);
#endif
}

// double specialization: unconditionally calls ::exp, never the intrinsic.
template <>
C10_HOST_DEVICE inline double exp<double>(double x) {
  return ::exp(x);
}
112 | |
// log for float and less precise types. Instantiating this primary template
// with double is rejected by the static_assert; double is served by the
// explicit specialization that follows. When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined the call goes to the global-namespace ::log;
// otherwise the __logf intrinsic is used.
template <typename T>
C10_HOST_DEVICE inline T log(T x) {
  static_assert(
      !std::is_same<T, double>::value,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log(x);
#else
  // use __logf fast approximation for peak bandwidth
  return __logf(x);
#endif
}

// double specialization: unconditionally calls ::log, never the intrinsic.
template <>
C10_HOST_DEVICE inline double log<double>(double x) {
  return ::log(x);
}
130 | |
// log1p for float and less precise types. Instantiating this primary template
// with double is rejected by the static_assert; double is served by the
// explicit specialization that follows. When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined the call goes to the global-namespace ::log1p;
// otherwise it is computed as __logf(1.0f + x).
template <typename T>
C10_HOST_DEVICE inline T log1p(T x) {
  static_assert(
      !std::is_same<T, double>::value,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::log1p(x);
#else
  // use __logf fast approximation for peak bandwidth
  // NOTE: There is no __log1pf so unfortunately we lose precision.
  return __logf(1.0f + x);
#endif
}

// double specialization: unconditionally calls ::log1p.
template <>
C10_HOST_DEVICE inline double log1p<double>(double x) {
  return ::log1p(x);
}
149 | |
// tan for float and less precise types. Instantiating this primary template
// with double is rejected by the static_assert; double is served by the
// explicit specialization that follows. When neither __CUDA_ARCH__ nor
// __HIP_ARCH__ is defined the call goes to the global-namespace ::tan;
// otherwise the __tanf intrinsic is used.
template <typename T>
C10_HOST_DEVICE inline T tan(T x) {
  static_assert(
      !std::is_same<T, double>::value,
      "this template must be used with float or less precise type");
#if !defined(__CUDA_ARCH__) && !defined(__HIP_ARCH__)
  return ::tan(x);
#else
  // use __tanf fast approximation for peak bandwidth
  return __tanf(x);
#endif
}

// double specialization: unconditionally calls ::tan, never the intrinsic.
template <>
C10_HOST_DEVICE inline double tan<double>(double x) {
  return ::tan(x);
}
167 | |
168 | } // namespace at |
169 | |