#include <c10/util/complex.h>
#include <c10/util/math_compat.h>

#include <cmath>
#include <limits>
5 | |
// Note [ Complex Square root in libc++]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In libc++, the complex square root is computed using the polar form.
// This is a reasonably fast algorithm, but it can produce significant
// numerical errors when arg(z) is close to 0, pi/2, pi, or 3pi/2
// (i.e. when z lies near the real or imaginary axis).
// In that case we provide a more conservative implementation, which is
// slower but less prone to those kinds of errors.
// In libstdc++, the complex square root yields invalid results
// for -x-0.0j unless the C99 csqrt/csqrtf fallbacks are used.
15 | |
16 | #if defined(_LIBCPP_VERSION) || \ |
17 | (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) |
18 | |
19 | namespace { |
20 | template <typename T> |
21 | c10::complex<T> compute_csqrt(const c10::complex<T>& z) { |
22 | constexpr auto half = T(.5); |
23 | |
24 | // Trust standard library to correctly handle infs and NaNs |
25 | if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || |
26 | std::isnan(z.imag())) { |
27 | return static_cast<c10::complex<T>>( |
28 | std::sqrt(static_cast<std::complex<T>>(z))); |
29 | } |
30 | |
31 | // Special case for square root of pure imaginary values |
32 | if (z.real() == T(0)) { |
33 | if (z.imag() == T(0)) { |
34 | return c10::complex<T>(T(0), z.imag()); |
35 | } |
36 | auto v = std::sqrt(half * std::abs(z.imag())); |
37 | return c10::complex<T>(v, std::copysign(v, z.imag())); |
38 | } |
39 | |
40 | // At this point, z is non-zero and finite |
41 | if (z.real() >= 0.0) { |
42 | auto t = std::sqrt((z.real() + std::abs(z)) * half); |
43 | return c10::complex<T>(t, half * (z.imag() / t)); |
44 | } |
45 | |
46 | auto t = std::sqrt((-z.real() + std::abs(z)) * half); |
47 | return c10::complex<T>( |
48 | half * std::abs(z.imag() / t), std::copysign(t, z.imag())); |
49 | } |
50 | |
51 | // Compute complex arccosine using formula from W. Kahan |
52 | // "Branch Cuts for Complex Elementary Functions" 1986 paper: |
53 | // cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re()) |
54 | // cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im()) |
55 | template <typename T> |
56 | c10::complex<T> compute_cacos(const c10::complex<T>& z) { |
57 | auto constexpr one = T(1); |
58 | // Trust standard library to correctly handle infs and NaNs |
59 | if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) || |
60 | std::isnan(z.imag())) { |
61 | return static_cast<c10::complex<T>>( |
62 | std::acos(static_cast<std::complex<T>>(z))); |
63 | } |
64 | auto a = compute_csqrt(c10::complex<T>(one - z.real(), -z.imag())); |
65 | auto b = compute_csqrt(c10::complex<T>(one + z.real(), z.imag())); |
66 | auto c = compute_csqrt(c10::complex<T>(one + z.real(), -z.imag())); |
67 | auto r = T(2) * std::atan2(a.real(), b.real()); |
68 | // Explicitly unroll (a*c).imag() |
69 | auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real()); |
70 | return c10::complex<T>(r, i); |
71 | } |
72 | } // anonymous namespace |
73 | |
74 | namespace c10_complex_math { |
75 | namespace _detail { |
76 | c10::complex<float> sqrt(const c10::complex<float>& in) { |
77 | return compute_csqrt(in); |
78 | } |
79 | |
80 | c10::complex<double> sqrt(const c10::complex<double>& in) { |
81 | return compute_csqrt(in); |
82 | } |
83 | |
84 | c10::complex<float> acos(const c10::complex<float>& in) { |
85 | return compute_cacos(in); |
86 | } |
87 | |
88 | c10::complex<double> acos(const c10::complex<double>& in) { |
89 | return compute_cacos(in); |
90 | } |
91 | |
92 | } // namespace _detail |
93 | } // namespace c10_complex_math |
94 | #endif |
95 | |