1#pragma once
2
3#include <complex>
4
5#include <c10/macros/Macros.h>
6
7#if defined(__CUDACC__) || defined(__HIPCC__)
8#include <thrust/complex.h>
9#endif
10
11C10_CLANG_DIAGNOSTIC_PUSH()
12#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
13C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
14#endif
15#if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
16C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
17#endif
18
19namespace c10 {
20
21// c10::complex is an implementation of complex numbers that aims
22// to work on all devices supported by PyTorch
23//
// Most of the APIs duplicate std::complex
25// Reference: https://en.cppreference.com/w/cpp/numeric/complex
26//
27// [NOTE: Complex Operator Unification]
28// Operators currently use a mix of std::complex, thrust::complex, and
29// c10::complex internally. The end state is that all operators will use
30// c10::complex internally. Until then, there may be some hacks to support all
31// variants.
32//
33//
34// [Note on Constructors]
35//
36// The APIs of constructors are mostly copied from C++ standard:
37// https://en.cppreference.com/w/cpp/numeric/complex/complex
38//
39// Since C++14, all constructors are constexpr in std::complex
40//
41// There are three types of constructors:
42// - initializing from real and imag:
43// `constexpr complex( const T& re = T(), const T& im = T() );`
44// - implicitly-declared copy constructor
45// - converting constructors
46//
47// Converting constructors:
48// - std::complex defines converting constructor between float/double/long
49// double,
50// while we define converting constructor between float/double.
51// - For these converting constructors, upcasting is implicit, downcasting is
52// explicit.
53// - We also define explicit casting from std::complex/thrust::complex
// - Note that the conversion from thrust is not constexpr, because
// thrust does not appear to define them as constexpr (unverified)
56//
57//
58// [Operator =]
59//
60// The APIs of operator = are mostly copied from C++ standard:
61// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
62//
63// Since C++20, all operator= are constexpr. Although we are not building with
64// C++20, we also obey this behavior.
65//
66// There are three types of assign operator:
67// - Assign a real value from the same scalar type
68// - In std, this is templated as complex& operator=(const T& x)
69// with specialization `complex& operator=(T x)` for float/double/long
// double. Since we only support float and double, we will use `complex&
// operator=(T x)`
72// - Copy assignment operator and converting assignment operator
// - There is no specialization of converting assignment operators; which type
// is convertible depends solely on whether the scalar type is convertible
76//
77// In addition to the standard assignment, we also provide assignment operators
78// with std and thrust
79//
80//
81// [Casting operators]
82//
83// std::complex does not have casting operators. We define casting operators
84// casting to std::complex and thrust::complex
85//
86//
87// [Operator ""]
88//
89// std::complex has custom literals `i`, `if` and `il` defined in namespace
90// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if, _id to construct
// float/double complex literals.
94//
95//
96// [real() and imag()]
97//
// In C++20, there are two overloads of these functions: one returns the
// real/imag part, the other sets it. Both are constexpr. We follow
// this design.
101//
102//
103// [Operator +=,-=,*=,/=]
104//
105// Since C++20, these operators become constexpr. In our implementation, they
106// are also constexpr.
107//
108// There are two types of such operators: operating with a real number, or
109// operating with another complex number. For the operating with a real number,
110// the generic template form has argument type `const T &`, while the overload
111// for float/double/long double has `T`. We will follow the same type as
112// float/double/long double in std.
113//
114// [Unary operator +-]
115//
// Since C++20, they are constexpr. We also make them constexpr
117//
118// [Binary operators +-*/]
119//
120// Each operator has three versions (taking + as example):
121// - complex + complex
122// - complex + real
123// - real + complex
124//
125// [Operator ==, !=]
126//
127// Each operator has three versions (taking == as example):
128// - complex == complex
129// - complex == real
130// - real == complex
131//
// Some of them are removed in C++20, but we decide to keep them
133//
134// [Operator <<, >>]
135//
136// These are implemented by casting to std::complex
137//
138//
139//
140// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
141// because:
142// - lots of members and functions of c10::Half are not constexpr
143// - thrust::complex only support float and double
144
145template <typename T>
146struct alignas(sizeof(T) * 2) complex {
147 using value_type = T;
148
149 T real_ = T(0);
150 T imag_ = T(0);
151
152 constexpr complex() = default;
153 C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
154 : real_(re), imag_(im) {}
155 template <typename U>
156 explicit constexpr complex(const std::complex<U>& other)
157 : complex(other.real(), other.imag()) {}
158#if defined(__CUDACC__) || defined(__HIPCC__)
159 template <typename U>
160 explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
161 : real_(other.real()), imag_(other.imag()) {}
162// NOTE can not be implemented as follow due to ROCm bug:
163// explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
164// complex(other.real(), other.imag()) {}
165#endif
166
167 // Use SFINAE to specialize casting constructor for c10::complex<float> and
168 // c10::complex<double>
169 template <typename U = T>
170 C10_HOST_DEVICE explicit constexpr complex(
171 const std::enable_if_t<std::is_same<U, float>::value, complex<double>>&
172 other)
173 : real_(other.real_), imag_(other.imag_) {}
174 template <typename U = T>
175 C10_HOST_DEVICE constexpr complex(
176 const std::enable_if_t<std::is_same<U, double>::value, complex<float>>&
177 other)
178 : real_(other.real_), imag_(other.imag_) {}
179
180 constexpr complex<T>& operator=(T re) {
181 real_ = re;
182 imag_ = 0;
183 return *this;
184 }
185
186 constexpr complex<T>& operator+=(T re) {
187 real_ += re;
188 return *this;
189 }
190
191 constexpr complex<T>& operator-=(T re) {
192 real_ -= re;
193 return *this;
194 }
195
196 constexpr complex<T>& operator*=(T re) {
197 real_ *= re;
198 imag_ *= re;
199 return *this;
200 }
201
202 constexpr complex<T>& operator/=(T re) {
203 real_ /= re;
204 imag_ /= re;
205 return *this;
206 }
207
208 template <typename U>
209 constexpr complex<T>& operator=(const complex<U>& rhs) {
210 real_ = rhs.real();
211 imag_ = rhs.imag();
212 return *this;
213 }
214
215 template <typename U>
216 constexpr complex<T>& operator+=(const complex<U>& rhs) {
217 real_ += rhs.real();
218 imag_ += rhs.imag();
219 return *this;
220 }
221
222 template <typename U>
223 constexpr complex<T>& operator-=(const complex<U>& rhs) {
224 real_ -= rhs.real();
225 imag_ -= rhs.imag();
226 return *this;
227 }
228
229 template <typename U>
230 constexpr complex<T>& operator*=(const complex<U>& rhs) {
231 // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
232 T a = real_;
233 T b = imag_;
234 U c = rhs.real();
235 U d = rhs.imag();
236 real_ = a * c - b * d;
237 imag_ = a * d + b * c;
238 return *this;
239 }
240
241#ifdef __APPLE__
242#define FORCE_INLINE_APPLE __attribute__((always_inline))
243#else
244#define FORCE_INLINE_APPLE
245#endif
246 template <typename U>
247 constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
248 __ubsan_ignore_float_divide_by_zero__ {
249 // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
250 // the calculation below follows numpy's complex division
251 T ar = real_;
252 T ai = imag_;
253 U br = rhs.real();
254 U bi = rhs.imag();
255
256#if defined(__GNUC__) && !defined(__clang__)
257 // std::abs is already constexpr by gcc
258 auto abs_br = std::abs(br);
259 auto abs_bi = std::abs(bi);
260#else
261 auto abs_br = br < 0 ? -br : br;
262 auto abs_bi = bi < 0 ? -bi : bi;
263#endif
264
265 if (abs_br >= abs_bi) {
266 if (abs_br == 0 && abs_bi == 0) {
267 /* divide by zeros should yield a complex inf or nan */
268 real_ = ar / abs_br;
269 imag_ = ai / abs_bi;
270 } else {
271 auto rat = bi / br;
272 auto scl = 1.0 / (br + bi * rat);
273 real_ = (ar + ai * rat) * scl;
274 imag_ = (ai - ar * rat) * scl;
275 }
276 } else {
277 auto rat = br / bi;
278 auto scl = 1.0 / (bi + br * rat);
279 real_ = (ar * rat + ai) * scl;
280 imag_ = (ai * rat - ar) * scl;
281 }
282 return *this;
283 }
284#undef FORCE_INLINE_APPLE
285
286 template <typename U>
287 constexpr complex<T>& operator=(const std::complex<U>& rhs) {
288 real_ = rhs.real();
289 imag_ = rhs.imag();
290 return *this;
291 }
292
293#if defined(__CUDACC__) || defined(__HIPCC__)
294 template <typename U>
295 C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
296 real_ = rhs.real();
297 imag_ = rhs.imag();
298 return *this;
299 }
300#endif
301
302 template <typename U>
303 explicit constexpr operator std::complex<U>() const {
304 return std::complex<U>(std::complex<T>(real(), imag()));
305 }
306
307#if defined(__CUDACC__) || defined(__HIPCC__)
308 template <typename U>
309 C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
310 return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
311 }
312#endif
313
314 // consistent with NumPy behavior
315 explicit constexpr operator bool() const {
316 return real() || imag();
317 }
318
319 C10_HOST_DEVICE constexpr T real() const {
320 return real_;
321 }
322 constexpr void real(T value) {
323 real_ = value;
324 }
325 constexpr T imag() const {
326 return imag_;
327 }
328 constexpr void imag(T value) {
329 imag_ = value;
330 }
331};
332
333namespace complex_literals {
334
335constexpr complex<float> operator"" _if(long double imag) {
336 return complex<float>(0.0f, static_cast<float>(imag));
337}
338
339constexpr complex<double> operator"" _id(long double imag) {
340 return complex<double>(0.0, static_cast<double>(imag));
341}
342
343constexpr complex<float> operator"" _if(unsigned long long imag) {
344 return complex<float>(0.0f, static_cast<float>(imag));
345}
346
347constexpr complex<double> operator"" _id(unsigned long long imag) {
348 return complex<double>(0.0, static_cast<double>(imag));
349}
350
351} // namespace complex_literals
352
353template <typename T>
354constexpr complex<T> operator+(const complex<T>& val) {
355 return val;
356}
357
358template <typename T>
359constexpr complex<T> operator-(const complex<T>& val) {
360 return complex<T>(-val.real(), -val.imag());
361}
362
363template <typename T>
364constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
365 complex<T> result = lhs;
366 return result += rhs;
367}
368
369template <typename T>
370constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
371 complex<T> result = lhs;
372 return result += rhs;
373}
374
375template <typename T>
376constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
377 return complex<T>(lhs + rhs.real(), rhs.imag());
378}
379
380template <typename T>
381constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
382 complex<T> result = lhs;
383 return result -= rhs;
384}
385
386template <typename T>
387constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
388 complex<T> result = lhs;
389 return result -= rhs;
390}
391
392template <typename T>
393constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
394 complex<T> result = -rhs;
395 return result += lhs;
396}
397
398template <typename T>
399constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
400 complex<T> result = lhs;
401 return result *= rhs;
402}
403
404template <typename T>
405constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
406 complex<T> result = lhs;
407 return result *= rhs;
408}
409
410template <typename T>
411constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
412 complex<T> result = rhs;
413 return result *= lhs;
414}
415
416template <typename T>
417constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
418 complex<T> result = lhs;
419 return result /= rhs;
420}
421
422template <typename T>
423constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
424 complex<T> result = lhs;
425 return result /= rhs;
426}
427
428template <typename T>
429constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
430 complex<T> result(lhs, T());
431 return result /= rhs;
432}
433
434// Define operators between integral scalars and c10::complex. std::complex does
435// not support this when T is a floating-point number. This is useful because it
436// saves a lot of "static_cast" when operate a complex and an integer. This
437// makes the code both less verbose and potentially more efficient.
438#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
439 typename std::enable_if_t< \
440 std::is_floating_point<fT>::value && std::is_integral<iT>::value, \
441 int> = 0
442
443template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
444constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
445 return a + static_cast<fT>(b);
446}
447
448template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
449constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
450 return static_cast<fT>(a) + b;
451}
452
453template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
454constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
455 return a - static_cast<fT>(b);
456}
457
458template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
459constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
460 return static_cast<fT>(a) - b;
461}
462
463template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
464constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
465 return a * static_cast<fT>(b);
466}
467
468template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
469constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
470 return static_cast<fT>(a) * b;
471}
472
473template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
474constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
475 return a / static_cast<fT>(b);
476}
477
478template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
479constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
480 return static_cast<fT>(a) / b;
481}
482
483#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
484
485template <typename T>
486constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
487 return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
488}
489
490template <typename T>
491constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
492 return (lhs.real() == rhs) && (lhs.imag() == T());
493}
494
495template <typename T>
496constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
497 return (lhs == rhs.real()) && (T() == rhs.imag());
498}
499
500template <typename T>
501constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
502 return !(lhs == rhs);
503}
504
505template <typename T>
506constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
507 return !(lhs == rhs);
508}
509
510template <typename T>
511constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
512 return !(lhs == rhs);
513}
514
515template <typename T, typename CharT, typename Traits>
516std::basic_ostream<CharT, Traits>& operator<<(
517 std::basic_ostream<CharT, Traits>& os,
518 const complex<T>& x) {
519 return (os << static_cast<std::complex<T>>(x));
520}
521
522template <typename T, typename CharT, typename Traits>
523std::basic_istream<CharT, Traits>& operator>>(
524 std::basic_istream<CharT, Traits>& is,
525 complex<T>& x) {
526 std::complex<T> tmp;
527 is >> tmp;
528 x = tmp;
529 return is;
530}
531
532} // namespace c10
533
534// std functions
535//
// The implementation of these functions also follows the design of C++20
537
538namespace std {
539
540template <typename T>
541constexpr T real(const c10::complex<T>& z) {
542 return z.real();
543}
544
545template <typename T>
546constexpr T imag(const c10::complex<T>& z) {
547 return z.imag();
548}
549
550template <typename T>
551C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
552#if defined(__CUDACC__) || defined(__HIPCC__)
553 return thrust::abs(static_cast<thrust::complex<T>>(z));
554#else
555 return std::abs(static_cast<std::complex<T>>(z));
556#endif
557}
558
559#if defined(USE_ROCM)
560#define ROCm_Bug(x)
561#else
562#define ROCm_Bug(x) x
563#endif
564
565template <typename T>
566C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
567 return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
568}
569
570#undef ROCm_Bug
571
572template <typename T>
573constexpr T norm(const c10::complex<T>& z) {
574 return z.real() * z.real() + z.imag() * z.imag();
575}
576
577// For std::conj, there are other versions of it:
578// constexpr std::complex<float> conj( float z );
579// template< class DoubleOrInteger >
580// constexpr std::complex<double> conj( DoubleOrInteger z );
581// constexpr std::complex<long double> conj( long double z );
582// These are not implemented
583// TODO(@zasdfgbnm): implement them as c10::conj
584template <typename T>
585constexpr c10::complex<T> conj(const c10::complex<T>& z) {
586 return c10::complex<T>(z.real(), -z.imag());
587}
588
589// Thrust does not have complex --> complex version of thrust::proj,
590// so this function is not implemented at c10 right now.
591// TODO(@zasdfgbnm): implement it by ourselves
592
593// There is no c10 version of std::polar, because std::polar always
594// returns std::complex. Use c10::polar instead;
595
596} // namespace std
597
598namespace c10 {
599
600template <typename T>
601C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
602#if defined(__CUDACC__) || defined(__HIPCC__)
603 return static_cast<complex<T>>(thrust::polar(r, theta));
604#else
605 // std::polar() requires r >= 0, so spell out the explicit implementation to
606 // avoid a branch.
607 return complex<T>(r * std::cos(theta), r * std::sin(theta));
608#endif
609}
610
611} // namespace c10
612
613C10_CLANG_DIAGNOSTIC_POP()
614
615#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
616// math functions are included in a separate file
617#include <c10/util/complex_math.h> // IWYU pragma: keep
618// utilities for complex types
619#include <c10/util/complex_utils.h> // IWYU pragma: keep
620#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
621