1#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
2#error \
3 "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead."
4#endif
5
6namespace c10_complex_math {
7
8// Exponential functions
9
10template <typename T>
11C10_HOST_DEVICE inline c10::complex<T> exp(const c10::complex<T>& x) {
12#if defined(__CUDACC__) || defined(__HIPCC__)
13 return static_cast<c10::complex<T>>(
14 thrust::exp(static_cast<thrust::complex<T>>(x)));
15#else
16 return static_cast<c10::complex<T>>(
17 std::exp(static_cast<std::complex<T>>(x)));
18#endif
19}
20
21template <typename T>
22C10_HOST_DEVICE inline c10::complex<T> log(const c10::complex<T>& x) {
23#if defined(__CUDACC__) || defined(__HIPCC__)
24 return static_cast<c10::complex<T>>(
25 thrust::log(static_cast<thrust::complex<T>>(x)));
26#else
27 return static_cast<c10::complex<T>>(
28 std::log(static_cast<std::complex<T>>(x)));
29#endif
30}
31
32template <typename T>
33C10_HOST_DEVICE inline c10::complex<T> log10(const c10::complex<T>& x) {
34#if defined(__CUDACC__) || defined(__HIPCC__)
35 return static_cast<c10::complex<T>>(
36 thrust::log10(static_cast<thrust::complex<T>>(x)));
37#else
38 return static_cast<c10::complex<T>>(
39 std::log10(static_cast<std::complex<T>>(x)));
40#endif
41}
42
43template <typename T>
44C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T>& x) {
45 const c10::complex<T> log2 = c10::complex<T>(::log(2.0), 0.0);
46 return c10_complex_math::log(x) / log2;
47}
48
49// Power functions
50//
#if defined(_LIBCPP_VERSION) || \
    (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
// Out-of-line float/double replacements for the standard-library complex
// sqrt/acos. Only their declarations (marked C10_API) appear here; the
// definitions live outside this header.
// sqrt() below calls _detail::sqrt in every configuration matched by this
// guard. acos() calls _detail::acos only under libc++.
namespace _detail {
C10_API c10::complex<float> sqrt(const c10::complex<float>& in);
C10_API c10::complex<double> sqrt(const c10::complex<double>& in);
C10_API c10::complex<float> acos(const c10::complex<float>& in);
C10_API c10::complex<double> acos(const c10::complex<double>& in);
} // namespace _detail
#endif
60
61template <typename T>
62C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T>& x) {
63#if defined(__CUDACC__) || defined(__HIPCC__)
64 return static_cast<c10::complex<T>>(
65 thrust::sqrt(static_cast<thrust::complex<T>>(x)));
66#elif !( \
67 defined(_LIBCPP_VERSION) || \
68 (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)))
69 return static_cast<c10::complex<T>>(
70 std::sqrt(static_cast<std::complex<T>>(x)));
71#else
72 return _detail::sqrt(x);
73#endif
74}
75
76template <typename T>
77C10_HOST_DEVICE inline c10::complex<T> pow(
78 const c10::complex<T>& x,
79 const c10::complex<T>& y) {
80#if defined(__CUDACC__) || defined(__HIPCC__)
81 return static_cast<c10::complex<T>>(thrust::pow(
82 static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
83#else
84 return static_cast<c10::complex<T>>(std::pow(
85 static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
86#endif
87}
88
89template <typename T>
90C10_HOST_DEVICE inline c10::complex<T> pow(
91 const c10::complex<T>& x,
92 const T& y) {
93#if defined(__CUDACC__) || defined(__HIPCC__)
94 return static_cast<c10::complex<T>>(
95 thrust::pow(static_cast<thrust::complex<T>>(x), y));
96#else
97 return static_cast<c10::complex<T>>(
98 std::pow(static_cast<std::complex<T>>(x), y));
99#endif
100}
101
102template <typename T>
103C10_HOST_DEVICE inline c10::complex<T> pow(
104 const T& x,
105 const c10::complex<T>& y) {
106#if defined(__CUDACC__) || defined(__HIPCC__)
107 return static_cast<c10::complex<T>>(
108 thrust::pow(x, static_cast<thrust::complex<T>>(y)));
109#else
110 return static_cast<c10::complex<T>>(
111 std::pow(x, static_cast<std::complex<T>>(y)));
112#endif
113}
114
115template <typename T, typename U>
116C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
117 const c10::complex<T>& x,
118 const c10::complex<U>& y) {
119#if defined(__CUDACC__) || defined(__HIPCC__)
120 return static_cast<c10::complex<T>>(thrust::pow(
121 static_cast<thrust::complex<T>>(x), static_cast<thrust::complex<T>>(y)));
122#else
123 return static_cast<c10::complex<T>>(std::pow(
124 static_cast<std::complex<T>>(x), static_cast<std::complex<T>>(y)));
125#endif
126}
127
128template <typename T, typename U>
129C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
130 const c10::complex<T>& x,
131 const U& y) {
132#if defined(__CUDACC__) || defined(__HIPCC__)
133 return static_cast<c10::complex<T>>(
134 thrust::pow(static_cast<thrust::complex<T>>(x), y));
135#else
136 return static_cast<c10::complex<T>>(
137 std::pow(static_cast<std::complex<T>>(x), y));
138#endif
139}
140
141template <typename T, typename U>
142C10_HOST_DEVICE inline c10::complex<decltype(T() * U())> pow(
143 const T& x,
144 const c10::complex<U>& y) {
145#if defined(__CUDACC__) || defined(__HIPCC__)
146 return static_cast<c10::complex<T>>(
147 thrust::pow(x, static_cast<thrust::complex<T>>(y)));
148#else
149 return static_cast<c10::complex<T>>(
150 std::pow(x, static_cast<std::complex<T>>(y)));
151#endif
152}
153
154// Trigonometric functions
155
156template <typename T>
157C10_HOST_DEVICE inline c10::complex<T> sin(const c10::complex<T>& x) {
158#if defined(__CUDACC__) || defined(__HIPCC__)
159 return static_cast<c10::complex<T>>(
160 thrust::sin(static_cast<thrust::complex<T>>(x)));
161#else
162 return static_cast<c10::complex<T>>(
163 std::sin(static_cast<std::complex<T>>(x)));
164#endif
165}
166
167template <typename T>
168C10_HOST_DEVICE inline c10::complex<T> cos(const c10::complex<T>& x) {
169#if defined(__CUDACC__) || defined(__HIPCC__)
170 return static_cast<c10::complex<T>>(
171 thrust::cos(static_cast<thrust::complex<T>>(x)));
172#else
173 return static_cast<c10::complex<T>>(
174 std::cos(static_cast<std::complex<T>>(x)));
175#endif
176}
177
178template <typename T>
179C10_HOST_DEVICE inline c10::complex<T> tan(const c10::complex<T>& x) {
180#if defined(__CUDACC__) || defined(__HIPCC__)
181 return static_cast<c10::complex<T>>(
182 thrust::tan(static_cast<thrust::complex<T>>(x)));
183#else
184 return static_cast<c10::complex<T>>(
185 std::tan(static_cast<std::complex<T>>(x)));
186#endif
187}
188
189template <typename T>
190C10_HOST_DEVICE inline c10::complex<T> asin(const c10::complex<T>& x) {
191#if defined(__CUDACC__) || defined(__HIPCC__)
192 return static_cast<c10::complex<T>>(
193 thrust::asin(static_cast<thrust::complex<T>>(x)));
194#else
195 return static_cast<c10::complex<T>>(
196 std::asin(static_cast<std::complex<T>>(x)));
197#endif
198}
199
200template <typename T>
201C10_HOST_DEVICE inline c10::complex<T> acos(const c10::complex<T>& x) {
202#if defined(__CUDACC__) || defined(__HIPCC__)
203 return static_cast<c10::complex<T>>(
204 thrust::acos(static_cast<thrust::complex<T>>(x)));
205#elif !defined(_LIBCPP_VERSION)
206 return static_cast<c10::complex<T>>(
207 std::acos(static_cast<std::complex<T>>(x)));
208#else
209 return _detail::acos(x);
210#endif
211}
212
213template <typename T>
214C10_HOST_DEVICE inline c10::complex<T> atan(const c10::complex<T>& x) {
215#if defined(__CUDACC__) || defined(__HIPCC__)
216 return static_cast<c10::complex<T>>(
217 thrust::atan(static_cast<thrust::complex<T>>(x)));
218#else
219 return static_cast<c10::complex<T>>(
220 std::atan(static_cast<std::complex<T>>(x)));
221#endif
222}
223
224// Hyperbolic functions
225
226template <typename T>
227C10_HOST_DEVICE inline c10::complex<T> sinh(const c10::complex<T>& x) {
228#if defined(__CUDACC__) || defined(__HIPCC__)
229 return static_cast<c10::complex<T>>(
230 thrust::sinh(static_cast<thrust::complex<T>>(x)));
231#else
232 return static_cast<c10::complex<T>>(
233 std::sinh(static_cast<std::complex<T>>(x)));
234#endif
235}
236
237template <typename T>
238C10_HOST_DEVICE inline c10::complex<T> cosh(const c10::complex<T>& x) {
239#if defined(__CUDACC__) || defined(__HIPCC__)
240 return static_cast<c10::complex<T>>(
241 thrust::cosh(static_cast<thrust::complex<T>>(x)));
242#else
243 return static_cast<c10::complex<T>>(
244 std::cosh(static_cast<std::complex<T>>(x)));
245#endif
246}
247
248template <typename T>
249C10_HOST_DEVICE inline c10::complex<T> tanh(const c10::complex<T>& x) {
250#if defined(__CUDACC__) || defined(__HIPCC__)
251 return static_cast<c10::complex<T>>(
252 thrust::tanh(static_cast<thrust::complex<T>>(x)));
253#else
254 return static_cast<c10::complex<T>>(
255 std::tanh(static_cast<std::complex<T>>(x)));
256#endif
257}
258
259template <typename T>
260C10_HOST_DEVICE inline c10::complex<T> asinh(const c10::complex<T>& x) {
261#if defined(__CUDACC__) || defined(__HIPCC__)
262 return static_cast<c10::complex<T>>(
263 thrust::asinh(static_cast<thrust::complex<T>>(x)));
264#else
265 return static_cast<c10::complex<T>>(
266 std::asinh(static_cast<std::complex<T>>(x)));
267#endif
268}
269
270template <typename T>
271C10_HOST_DEVICE inline c10::complex<T> acosh(const c10::complex<T>& x) {
272#if defined(__CUDACC__) || defined(__HIPCC__)
273 return static_cast<c10::complex<T>>(
274 thrust::acosh(static_cast<thrust::complex<T>>(x)));
275#else
276 return static_cast<c10::complex<T>>(
277 std::acosh(static_cast<std::complex<T>>(x)));
278#endif
279}
280
281template <typename T>
282C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
283#if defined(__CUDACC__) || defined(__HIPCC__)
284 return static_cast<c10::complex<T>>(
285 thrust::atanh(static_cast<thrust::complex<T>>(x)));
286#else
287 return static_cast<c10::complex<T>>(
288 std::atanh(static_cast<std::complex<T>>(x)));
289#endif
290}
291
292template <typename T>
293C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
294 // log1p(z) = log(1 + z)
295 // Let's define 1 + z = r * e ^ (i * a), then we have
296 // log(r * e ^ (i * a)) = log(r) + i * a
297 // With z = x + iy, the term r can be written as
298 // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
299 // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
300 // So, log(r) is
301 // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
302 // = 0.5 * log1p(x * (x + 2) + y ^ 2)
303 // we need to use the expression only on certain condition to avoid overflow
304 // and underflow from `(x * (x + 2) + y ^ 2)`
305 T x = z.real();
306 T y = z.imag();
307 T zabs = std::abs(z);
308 T theta = std::atan2(y, x + T(1));
309 if (zabs < 0.5) {
310 T r = x * (T(2) + x) + y * y;
311 if (r == 0) { // handle underflow
312 return {x, theta};
313 }
314 return {T(0.5) * std::log1p(r), theta};
315 } else {
316 T z0 = std::hypot(x + 1, y);
317 return {std::log(z0), theta};
318 }
319}
320
321} // namespace c10_complex_math
322
323using c10_complex_math::acos;
324using c10_complex_math::acosh;
325using c10_complex_math::asin;
326using c10_complex_math::asinh;
327using c10_complex_math::atan;
328using c10_complex_math::atanh;
329using c10_complex_math::cos;
330using c10_complex_math::cosh;
331using c10_complex_math::exp;
332using c10_complex_math::log;
333using c10_complex_math::log10;
334using c10_complex_math::log1p;
335using c10_complex_math::log2;
336using c10_complex_math::pow;
337using c10_complex_math::sin;
338using c10_complex_math::sinh;
339using c10_complex_math::sqrt;
340using c10_complex_math::tan;
341using c10_complex_math::tanh;
342
343namespace std {
344
345using c10_complex_math::acos;
346using c10_complex_math::acosh;
347using c10_complex_math::asin;
348using c10_complex_math::asinh;
349using c10_complex_math::atan;
350using c10_complex_math::atanh;
351using c10_complex_math::cos;
352using c10_complex_math::cosh;
353using c10_complex_math::exp;
354using c10_complex_math::log;
355using c10_complex_math::log10;
356using c10_complex_math::log1p;
357using c10_complex_math::log2;
358using c10_complex_math::pow;
359using c10_complex_math::sin;
360using c10_complex_math::sinh;
361using c10_complex_math::sqrt;
362using c10_complex_math::tan;
363using c10_complex_math::tanh;
364
365} // namespace std
366