#include <c10/util/complex.h>
#include <c10/util/math_compat.h>

#include <cmath>
#include <limits>
5
// Note [ Complex Square root in libc++]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In libc++, complex square root is computed using the polar form.
// This is a reasonably fast algorithm, but it can result in significant
// numerical errors when arg is close to 0, pi/2, pi, or 3pi/4.
// In that case, provide a more conservative implementation, which is
// slower but less prone to those kinds of errors.
// In libstdc++, complex square root yields invalid results
// for -x-0.0j unless the C99 csqrt/csqrtf fallbacks are used.
15
16#if defined(_LIBCPP_VERSION) || \
17 (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
18
19namespace {
20template <typename T>
21c10::complex<T> compute_csqrt(const c10::complex<T>& z) {
22 constexpr auto half = T(.5);
23
24 // Trust standard library to correctly handle infs and NaNs
25 if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
26 std::isnan(z.imag())) {
27 return static_cast<c10::complex<T>>(
28 std::sqrt(static_cast<std::complex<T>>(z)));
29 }
30
31 // Special case for square root of pure imaginary values
32 if (z.real() == T(0)) {
33 if (z.imag() == T(0)) {
34 return c10::complex<T>(T(0), z.imag());
35 }
36 auto v = std::sqrt(half * std::abs(z.imag()));
37 return c10::complex<T>(v, std::copysign(v, z.imag()));
38 }
39
40 // At this point, z is non-zero and finite
41 if (z.real() >= 0.0) {
42 auto t = std::sqrt((z.real() + std::abs(z)) * half);
43 return c10::complex<T>(t, half * (z.imag() / t));
44 }
45
46 auto t = std::sqrt((-z.real() + std::abs(z)) * half);
47 return c10::complex<T>(
48 half * std::abs(z.imag() / t), std::copysign(t, z.imag()));
49}
50
51// Compute complex arccosine using formula from W. Kahan
52// "Branch Cuts for Complex Elementary Functions" 1986 paper:
53// cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re())
54// cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im())
55template <typename T>
56c10::complex<T> compute_cacos(const c10::complex<T>& z) {
57 auto constexpr one = T(1);
58 // Trust standard library to correctly handle infs and NaNs
59 if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
60 std::isnan(z.imag())) {
61 return static_cast<c10::complex<T>>(
62 std::acos(static_cast<std::complex<T>>(z)));
63 }
64 auto a = compute_csqrt(c10::complex<T>(one - z.real(), -z.imag()));
65 auto b = compute_csqrt(c10::complex<T>(one + z.real(), z.imag()));
66 auto c = compute_csqrt(c10::complex<T>(one + z.real(), -z.imag()));
67 auto r = T(2) * std::atan2(a.real(), b.real());
68 // Explicitly unroll (a*c).imag()
69 auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real());
70 return c10::complex<T>(r, i);
71}
72} // anonymous namespace
73
74namespace c10_complex_math {
75namespace _detail {
76c10::complex<float> sqrt(const c10::complex<float>& in) {
77 return compute_csqrt(in);
78}
79
80c10::complex<double> sqrt(const c10::complex<double>& in) {
81 return compute_csqrt(in);
82}
83
84c10::complex<float> acos(const c10::complex<float>& in) {
85 return compute_cacos(in);
86}
87
88c10::complex<double> acos(const c10::complex<double>& in) {
89 return compute_cacos(in);
90}
91
92} // namespace _detail
93} // namespace c10_complex_math
94#endif
95