1 | #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) |
2 | #error \ |
3 | "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead." |
4 | #endif |
5 | |
6 | namespace c10_complex_math { |
7 | |
8 | // Exponential functions |
9 | |
10 | template <typename T> |
11 | C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) { |
12 | #if defined(__CUDACC__) || defined(__HIPCC__) |
13 | return static_cast<c10::complex<T>>( |
14 | thrust::exp(static_cast<thrust::complex<T>>(x))); |
15 | #else |
16 | return static_cast<c10::complex<T>>( |
17 | std::exp(static_cast<std::complex<T>>(x))); |
18 | #endif |
19 | } |
20 | |
21 | template <typename T> |
22 | C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) { |
23 | #if defined(__CUDACC__) || defined(__HIPCC__) |
24 | return static_cast<c10::complex<T>>( |
25 | thrust::log(static_cast<thrust::complex<T>>(x))); |
26 | #else |
27 | return static_cast<c10::complex<T>>( |
28 | std::log(static_cast<std::complex<T>>(x))); |
29 | #endif |
30 | } |
31 | |
32 | template <typename T> |
33 | C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) { |
34 | #if defined(__CUDACC__) || defined(__HIPCC__) |
35 | return static_cast<c10::complex<T>>( |
36 | thrust::log10(static_cast<thrust::complex<T>>(x))); |
37 | #else |
38 | return static_cast<c10::complex<T>>( |
39 | std::log10(static_cast<std::complex<T>>(x))); |
40 | #endif |
41 | } |
42 | |
43 | template <typename T> |
44 | C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) { |
45 | const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0); |
46 | return c10_complex_math::log(x) / log2; |
47 | } |
48 | |
49 | // Power functions |
50 | // |
51 | #if defined(_LIBCPP_VERSION) || \ |
52 | (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) |
53 | namespace _detail { |
54 | C10_API c10::complex<float> sqrt(const c10::complex<float>& in); |
55 | C10_API c10::complex<double> sqrt(const c10::complex<double>& in); |
56 | C10_API c10::complex<float> acos(const c10::complex<float>& in); |
57 | C10_API c10::complex<double> acos(const c10::complex<double>& in); |
58 | }; // namespace _detail |
59 | #endif |
60 | |
61 | template <typename T> |
62 | C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) { |
63 | #if defined(__CUDACC__) || defined(__HIPCC__) |
64 | return static_cast<c10::complex<T>>( |
65 | thrust::sqrt(static_cast<thrust::complex<T>>(x))); |
66 | #elif !( \ |
67 | defined(_LIBCPP_VERSION) || \ |
68 | (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) |
69 | return static_cast<c10::complex<T>>( |
70 | std::sqrt(static_cast<std::complex<T>>(x))); |
71 | #else |
72 | return _detail::sqrt(x); |
73 | #endif |
74 | } |
75 | |
76 | template <typename T> |
77 | C10_HOST_DEVICE inline c10::complex<T> pow( |
78 | const c10::complex<T>& x, |
79 | const c10::complex<T>& y) { |
80 | #if defined(__CUDACC__) || defined(__HIPCC__) |
81 | return static_cast<c10::complex<T>>(thrust::pow( |
82 | static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y))); |
83 | #else |
84 | return static_cast<c10::complex<T>>(std::pow( |
85 | static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y))); |
86 | #endif |
87 | } |
88 | |
89 | template <typename T> |
90 | C10_HOST_DEVICE inline c10::complex<T> pow( |
91 | const c10::complex<T>& x, |
92 | const T& y) { |
93 | #if defined(__CUDACC__) || defined(__HIPCC__) |
94 | return static_cast<c10::complex<T>>( |
95 | thrust::pow(static_cast<thrust::complex<T>>(x), y)); |
96 | #else |
97 | return static_cast<c10::complex<T>>( |
98 | std::pow(static_cast<std::complex<T>>(x), y)); |
99 | #endif |
100 | } |
101 | |
102 | template <typename T> |
103 | C10_HOST_DEVICE inline c10::complex<T> pow( |
104 | const T& x, |
105 | const c10::complex<T>& y) { |
106 | #if defined(__CUDACC__) || defined(__HIPCC__) |
107 | return static_cast<c10::complex<T>>( |
108 | thrust::pow(x, static_cast<thrust::complex<T>>(y))); |
109 | #else |
110 | return static_cast<c10::complex<T>>( |
111 | std::pow(x, static_cast<std::complex<T>>(y))); |
112 | #endif |
113 | } |
114 | |
115 | template <typename T, typename U> |
116 | C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow( |
117 | const c10::complex<T>& x, |
118 | const c10::complex<U>& y) { |
119 | #if defined(__CUDACC__) || defined(__HIPCC__) |
120 | return static_cast<c10::complex<T>>(thrust::pow( |
121 | static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y))); |
122 | #else |
123 | return static_cast<c10::complex<T>>(std::pow( |
124 | static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y))); |
125 | #endif |
126 | } |
127 | |
128 | template <typename T, typename U> |
129 | C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow( |
130 | const c10::complex<T>& x, |
131 | const U& y) { |
132 | #if defined(__CUDACC__) || defined(__HIPCC__) |
133 | return static_cast<c10::complex<T>>( |
134 | thrust::pow(static_cast<thrust::complex<T>>(x), y)); |
135 | #else |
136 | return static_cast<c10::complex<T>>( |
137 | std::pow(static_cast<std::complex<T>>(x), y)); |
138 | #endif |
139 | } |
140 | |
141 | template <typename T, typename U> |
142 | C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow( |
143 | const T& x, |
144 | const c10::complex<U>& y) { |
145 | #if defined(__CUDACC__) || defined(__HIPCC__) |
146 | return static_cast<c10::complex<T>>( |
147 | thrust::pow(x, static_cast<thrust::complex<T>>(y))); |
148 | #else |
149 | return static_cast<c10::complex<T>>( |
150 | std::pow(x, static_cast<std::complex<T>>(y))); |
151 | #endif |
152 | } |
153 | |
154 | // Trigonometric functions |
155 | |
156 | template <typename T> |
157 | C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) { |
158 | #if defined(__CUDACC__) || defined(__HIPCC__) |
159 | return static_cast<c10::complex<T>>( |
160 | thrust::sin(static_cast<thrust::complex<T>>(x))); |
161 | #else |
162 | return static_cast<c10::complex<T>>( |
163 | std::sin(static_cast<std::complex<T>>(x))); |
164 | #endif |
165 | } |
166 | |
167 | template <typename T> |
168 | C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) { |
169 | #if defined(__CUDACC__) || defined(__HIPCC__) |
170 | return static_cast<c10::complex<T>>( |
171 | thrust::cos(static_cast<thrust::complex<T>>(x))); |
172 | #else |
173 | return static_cast<c10::complex<T>>( |
174 | std::cos(static_cast<std::complex<T>>(x))); |
175 | #endif |
176 | } |
177 | |
178 | template <typename T> |
179 | C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) { |
180 | #if defined(__CUDACC__) || defined(__HIPCC__) |
181 | return static_cast<c10::complex<T>>( |
182 | thrust::tan(static_cast<thrust::complex<T>>(x))); |
183 | #else |
184 | return static_cast<c10::complex<T>>( |
185 | std::tan(static_cast<std::complex<T>>(x))); |
186 | #endif |
187 | } |
188 | |
189 | template <typename T> |
190 | C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) { |
191 | #if defined(__CUDACC__) || defined(__HIPCC__) |
192 | return static_cast<c10::complex<T>>( |
193 | thrust::asin(static_cast<thrust::complex<T>>(x))); |
194 | #else |
195 | return static_cast<c10::complex<T>>( |
196 | std::asin(static_cast<std::complex<T>>(x))); |
197 | #endif |
198 | } |
199 | |
200 | template <typename T> |
201 | C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) { |
202 | #if defined(__CUDACC__) || defined(__HIPCC__) |
203 | return static_cast<c10::complex<T>>( |
204 | thrust::acos(static_cast<thrust::complex<T>>(x))); |
205 | #elif !defined(_LIBCPP_VERSION) |
206 | return static_cast<c10::complex<T>>( |
207 | std::acos(static_cast<std::complex<T>>(x))); |
208 | #else |
209 | return _detail::acos(x); |
210 | #endif |
211 | } |
212 | |
213 | template <typename T> |
214 | C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) { |
215 | #if defined(__CUDACC__) || defined(__HIPCC__) |
216 | return static_cast<c10::complex<T>>( |
217 | thrust::atan(static_cast<thrust::complex<T>>(x))); |
218 | #else |
219 | return static_cast<c10::complex<T>>( |
220 | std::atan(static_cast<std::complex<T>>(x))); |
221 | #endif |
222 | } |
223 | |
224 | // Hyperbolic functions |
225 | |
226 | template <typename T> |
227 | C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) { |
228 | #if defined(__CUDACC__) || defined(__HIPCC__) |
229 | return static_cast<c10::complex<T>>( |
230 | thrust::sinh(static_cast<thrust::complex<T>>(x))); |
231 | #else |
232 | return static_cast<c10::complex<T>>( |
233 | std::sinh(static_cast<std::complex<T>>(x))); |
234 | #endif |
235 | } |
236 | |
237 | template <typename T> |
238 | C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) { |
239 | #if defined(__CUDACC__) || defined(__HIPCC__) |
240 | return static_cast<c10::complex<T>>( |
241 | thrust::cosh(static_cast<thrust::complex<T>>(x))); |
242 | #else |
243 | return static_cast<c10::complex<T>>( |
244 | std::cosh(static_cast<std::complex<T>>(x))); |
245 | #endif |
246 | } |
247 | |
248 | template <typename T> |
249 | C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) { |
250 | #if defined(__CUDACC__) || defined(__HIPCC__) |
251 | return static_cast<c10::complex<T>>( |
252 | thrust::tanh(static_cast<thrust::complex<T>>(x))); |
253 | #else |
254 | return static_cast<c10::complex<T>>( |
255 | std::tanh(static_cast<std::complex<T>>(x))); |
256 | #endif |
257 | } |
258 | |
259 | template <typename T> |
260 | C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) { |
261 | #if defined(__CUDACC__) || defined(__HIPCC__) |
262 | return static_cast<c10::complex<T>>( |
263 | thrust::asinh(static_cast<thrust::complex<T>>(x))); |
264 | #else |
265 | return static_cast<c10::complex<T>>( |
266 | std::asinh(static_cast<std::complex<T>>(x))); |
267 | #endif |
268 | } |
269 | |
270 | template <typename T> |
271 | C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) { |
272 | #if defined(__CUDACC__) || defined(__HIPCC__) |
273 | return static_cast<c10::complex<T>>( |
274 | thrust::acosh(static_cast<thrust::complex<T>>(x))); |
275 | #else |
276 | return static_cast<c10::complex<T>>( |
277 | std::acosh(static_cast<std::complex<T>>(x))); |
278 | #endif |
279 | } |
280 | |
281 | template <typename T> |
282 | C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) { |
283 | #if defined(__CUDACC__) || defined(__HIPCC__) |
284 | return static_cast<c10::complex<T>>( |
285 | thrust::atanh(static_cast<thrust::complex<T>>(x))); |
286 | #else |
287 | return static_cast<c10::complex<T>>( |
288 | std::atanh(static_cast<std::complex<T>>(x))); |
289 | #endif |
290 | } |
291 | |
292 | template <typename T> |
293 | C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) { |
294 | // log1p(z) = log(1 + z) |
295 | // Let's define 1 + z = r * e ^ (i * a), then we have |
296 | // log(r * e ^ (i * a)) = log(r) + i * a |
297 | // With z = x + iy, the term r can be written as |
298 | // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 |
299 | // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 |
300 | // So, log(r) is |
301 | // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) |
302 | // = 0.5 * log1p(x * (x + 2) + y ^ 2) |
303 | // we need to use the expression only on certain condition to avoid overflow |
304 | // and underflow from `(x * (x + 2) + y ^ 2)` |
305 | T x = z.real(); |
306 | T y = z.imag(); |
307 | T zabs = std::abs(z); |
308 | T theta = std::atan2(y, x + T(1)); |
309 | if (zabs < 0.5) { |
310 | T r = x * (T(2) + x) + y * y; |
311 | if (r == 0) { // handle underflow |
312 | return {x, theta}; |
313 | } |
314 | return {T(0.5) * std::log1p(r), theta}; |
315 | } else { |
316 | T z0 = std::hypot(x + 1, y); |
317 | return {std::log(z0), theta}; |
318 | } |
319 | } |
320 | |
321 | } // namespace c10_complex_math |
322 | |
323 | using c10_complex_math::acos; |
324 | using c10_complex_math::acosh; |
325 | using c10_complex_math::asin; |
326 | using c10_complex_math::asinh; |
327 | using c10_complex_math::atan; |
328 | using c10_complex_math::atanh; |
329 | using c10_complex_math::cos; |
330 | using c10_complex_math::cosh; |
331 | using c10_complex_math::exp; |
332 | using c10_complex_math::log; |
333 | using c10_complex_math::log10; |
334 | using c10_complex_math::log1p; |
335 | using c10_complex_math::log2; |
336 | using c10_complex_math::pow; |
337 | using c10_complex_math::sin; |
338 | using c10_complex_math::sinh; |
339 | using c10_complex_math::sqrt; |
340 | using c10_complex_math::tan; |
341 | using c10_complex_math::tanh; |
342 | |
343 | namespace std { |
344 | |
345 | using c10_complex_math::acos; |
346 | using c10_complex_math::acosh; |
347 | using c10_complex_math::asin; |
348 | using c10_complex_math::asinh; |
349 | using c10_complex_math::atan; |
350 | using c10_complex_math::atanh; |
351 | using c10_complex_math::cos; |
352 | using c10_complex_math::cosh; |
353 | using c10_complex_math::exp; |
354 | using c10_complex_math::log; |
355 | using c10_complex_math::log10; |
356 | using c10_complex_math::log1p; |
357 | using c10_complex_math::log2; |
358 | using c10_complex_math::pow; |
359 | using c10_complex_math::sin; |
360 | using c10_complex_math::sinh; |
361 | using c10_complex_math::sqrt; |
362 | using c10_complex_math::tan; |
363 | using c10_complex_math::tanh; |
364 | |
365 | } // namespace std |
366 | |