1#include <c10/macros/Macros.h>
2#include <c10/util/complex.h>
3#include <c10/util/hash.h>
4#include <gtest/gtest.h>
5#include <sstream>
6#include <tuple>
7#include <type_traits>
8#include <unordered_map>
9
// When compiled by nvcc or hipcc, MAYBE_GLOBAL turns the compile-time test
// functions in this file into __global__ kernels -- presumably so that their
// static_asserts are also checked during device compilation. In a host-only
// build it expands to nothing and they are ordinary functions.
#if (defined(__CUDACC__) || defined(__HIPCC__))
#define MAYBE_GLOBAL __global__
#else
#define MAYBE_GLOBAL
#endif

// Pi, written as an (unsuffixed) double literal; used by the arg/polar
// checks in the test_std namespace.
#define PI 3.141592653589793238463
17
18namespace memory {
19
20MAYBE_GLOBAL void test_size() {
21 static_assert(sizeof(c10::complex<float>) == 2 * sizeof(float), "");
22 static_assert(sizeof(c10::complex<double>) == 2 * sizeof(double), "");
23}
24
25MAYBE_GLOBAL void test_align() {
26 static_assert(alignof(c10::complex<float>) == 2 * sizeof(float), "");
27 static_assert(alignof(c10::complex<double>) == 2 * sizeof(double), "");
28}
29
30MAYBE_GLOBAL void test_pod() {
31 static_assert(std::is_standard_layout<c10::complex<float>>::value, "");
32 static_assert(std::is_standard_layout<c10::complex<double>>::value, "");
33}
34
35TEST(TestMemory, ReinterpretCast) {
36 {
37 std::complex<float> z(1, 2);
38 c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
39 ASSERT_EQ(zz.real(), float(1));
40 ASSERT_EQ(zz.imag(), float(2));
41 }
42
43 {
44 c10::complex<float> z(3, 4);
45 std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z);
46 ASSERT_EQ(zz.real(), float(3));
47 ASSERT_EQ(zz.imag(), float(4));
48 }
49
50 {
51 std::complex<double> z(1, 2);
52 c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
53 ASSERT_EQ(zz.real(), double(1));
54 ASSERT_EQ(zz.imag(), double(2));
55 }
56
57 {
58 c10::complex<double> z(3, 4);
59 std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z);
60 ASSERT_EQ(zz.real(), double(3));
61 ASSERT_EQ(zz.imag(), double(4));
62 }
63}
64
65#if defined(__CUDACC__) || defined(__HIPCC__)
66TEST(TestMemory, ThrustReinterpretCast) {
67 {
68 thrust::complex<float> z(1, 2);
69 c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z);
70 ASSERT_EQ(zz.real(), float(1));
71 ASSERT_EQ(zz.imag(), float(2));
72 }
73
74 {
75 c10::complex<float> z(3, 4);
76 thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z);
77 ASSERT_EQ(zz.real(), float(3));
78 ASSERT_EQ(zz.imag(), float(4));
79 }
80
81 {
82 thrust::complex<double> z(1, 2);
83 c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z);
84 ASSERT_EQ(zz.real(), double(1));
85 ASSERT_EQ(zz.imag(), double(2));
86 }
87
88 {
89 c10::complex<double> z(3, 4);
90 thrust::complex<double> zz =
91 *reinterpret_cast<thrust::complex<double>*>(&z);
92 ASSERT_EQ(zz.real(), double(3));
93 ASSERT_EQ(zz.imag(), double(4));
94 }
95}
96#endif
97
98} // namespace memory
99
100namespace constructors {
101
102template <typename scalar_t>
103C10_HOST_DEVICE void test_construct_from_scalar() {
104 constexpr scalar_t num1 = scalar_t(1.23);
105 constexpr scalar_t num2 = scalar_t(4.56);
106 constexpr scalar_t zero = scalar_t();
107 static_assert(c10::complex<scalar_t>(num1, num2).real() == num1, "");
108 static_assert(c10::complex<scalar_t>(num1, num2).imag() == num2, "");
109 static_assert(c10::complex<scalar_t>(num1).real() == num1, "");
110 static_assert(c10::complex<scalar_t>(num1).imag() == zero, "");
111 static_assert(c10::complex<scalar_t>().real() == zero, "");
112 static_assert(c10::complex<scalar_t>().imag() == zero, "");
113}
114
115template <typename scalar_t, typename other_t>
116C10_HOST_DEVICE void test_construct_from_other() {
117 constexpr other_t num1 = other_t(1.23);
118 constexpr other_t num2 = other_t(4.56);
119 constexpr scalar_t num3 = scalar_t(num1);
120 constexpr scalar_t num4 = scalar_t(num2);
121 static_assert(
122 c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3,
123 "");
124 static_assert(
125 c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4,
126 "");
127}
128
129MAYBE_GLOBAL void test_convert_constructors() {
130 test_construct_from_scalar<float>();
131 test_construct_from_scalar<double>();
132
133 static_assert(
134 std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "");
135 static_assert(
136 !std::is_convertible<c10::complex<double>, c10::complex<float>>::value,
137 "");
138 static_assert(
139 std::is_convertible<c10::complex<float>, c10::complex<double>>::value,
140 "");
141 static_assert(
142 std::is_convertible<c10::complex<double>, c10::complex<double>>::value,
143 "");
144
145 static_assert(
146 std::is_constructible<c10::complex<float>, c10::complex<float>>::value,
147 "");
148 static_assert(
149 std::is_constructible<c10::complex<double>, c10::complex<float>>::value,
150 "");
151 static_assert(
152 std::is_constructible<c10::complex<float>, c10::complex<double>>::value,
153 "");
154 static_assert(
155 std::is_constructible<c10::complex<double>, c10::complex<double>>::value,
156 "");
157
158 test_construct_from_other<float, float>();
159 test_construct_from_other<float, double>();
160 test_construct_from_other<double, float>();
161 test_construct_from_other<double, double>();
162}
163
164template <typename scalar_t>
165C10_HOST_DEVICE void test_construct_from_std() {
166 constexpr scalar_t num1 = scalar_t(1.23);
167 constexpr scalar_t num2 = scalar_t(4.56);
168 static_assert(
169 c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1,
170 "");
171 static_assert(
172 c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2,
173 "");
174}
175
176MAYBE_GLOBAL void test_std_conversion() {
177 test_construct_from_std<float>();
178 test_construct_from_std<double>();
179}
180
181#if defined(__CUDACC__) || defined(__HIPCC__)
182template <typename scalar_t>
183void test_construct_from_thrust() {
184 constexpr scalar_t num1 = scalar_t(1.23);
185 constexpr scalar_t num2 = scalar_t(4.56);
186 ASSERT_EQ(
187 c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(),
188 num1);
189 ASSERT_EQ(
190 c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(),
191 num2);
192}
193
194TEST(TestConstructors, FromThrust) {
195 test_construct_from_thrust<float>();
196 test_construct_from_thrust<double>();
197}
198#endif
199
200TEST(TestConstructors, UnorderedMap) {
201 std::unordered_map<
202 c10::complex<double>,
203 c10::complex<double>,
204 c10::hash<c10::complex<double>>>
205 m;
206 auto key1 = c10::complex<double>(2.5, 3);
207 auto key2 = c10::complex<double>(2, 0);
208 auto val1 = c10::complex<double>(2, -3.2);
209 auto val2 = c10::complex<double>(0, -3);
210 m[key1] = val1;
211 m[key2] = val2;
212 ASSERT_EQ(m[key1], val1);
213 ASSERT_EQ(m[key2], val2);
214}
215
216} // namespace constructors
217
218namespace assignment {
219
220template <typename scalar_t>
221constexpr c10::complex<scalar_t> one() {
222 c10::complex<scalar_t> result(3, 4);
223 result = scalar_t(1);
224 return result;
225}
226
227MAYBE_GLOBAL void test_assign_real() {
228 static_assert(one<float>().real() == float(1), "");
229 static_assert(one<float>().imag() == float(), "");
230 static_assert(one<double>().real() == double(1), "");
231 static_assert(one<double>().imag() == double(), "");
232}
233
234constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two() {
235 constexpr c10::complex<float> src(1, 2);
236 c10::complex<double> ret0;
237 c10::complex<float> ret1;
238 ret0 = ret1 = src;
239 return std::make_tuple(ret0, ret1);
240}
241
242MAYBE_GLOBAL void test_assign_other() {
243 constexpr auto tup = one_two();
244 static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
245 static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
246 static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
247 static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
248}
249
250constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two_std() {
251 constexpr std::complex<float> src(1, 1);
252 c10::complex<double> ret0;
253 c10::complex<float> ret1;
254 ret0 = ret1 = src;
255 return std::make_tuple(ret0, ret1);
256}
257
258MAYBE_GLOBAL void test_assign_std() {
259 constexpr auto tup = one_two();
260 static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "");
261 static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "");
262 static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "");
263 static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "");
264}
265
266#if defined(__CUDACC__) || defined(__HIPCC__)
267C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>>
268one_two_thrust() {
269 thrust::complex<float> src(1, 2);
270 c10::complex<double> ret0;
271 c10::complex<float> ret1;
272 ret0 = ret1 = src;
273 return std::make_tuple(ret0, ret1);
274}
275
276TEST(TestAssignment, FromThrust) {
277 auto tup = one_two_thrust();
278 ASSERT_EQ(std::get<c10::complex<double>>(tup).real(), double(1));
279 ASSERT_EQ(std::get<c10::complex<double>>(tup).imag(), double(2));
280 ASSERT_EQ(std::get<c10::complex<float>>(tup).real(), float(1));
281 ASSERT_EQ(std::get<c10::complex<float>>(tup).imag(), float(2));
282}
283#endif
284
285} // namespace assignment
286
287namespace literals {
288
289MAYBE_GLOBAL void test_complex_literals() {
290 using namespace c10::complex_literals;
291 static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "");
292 static_assert((0.5_if).real() == float(), "");
293 static_assert((0.5_if).imag() == float(0.5), "");
294 static_assert(
295 std::is_same<decltype(0.5_id), c10::complex<double>>::value, "");
296 static_assert((0.5_id).real() == float(), "");
297 static_assert((0.5_id).imag() == float(0.5), "");
298
299 static_assert(std::is_same<decltype(1_if), c10::complex<float>>::value, "");
300 static_assert((1_if).real() == float(), "");
301 static_assert((1_if).imag() == float(1), "");
302 static_assert(std::is_same<decltype(1_id), c10::complex<double>>::value, "");
303 static_assert((1_id).real() == double(), "");
304 static_assert((1_id).imag() == double(1), "");
305}
306
307} // namespace literals
308
309namespace real_imag {
310
311template <typename scalar_t>
312constexpr c10::complex<scalar_t> zero_one() {
313 c10::complex<scalar_t> result;
314 result.imag(scalar_t(1));
315 return result;
316}
317
318template <typename scalar_t>
319constexpr c10::complex<scalar_t> one_zero() {
320 c10::complex<scalar_t> result;
321 result.real(scalar_t(1));
322 return result;
323}
324
325MAYBE_GLOBAL void test_real_imag_modify() {
326 static_assert(zero_one<float>().real() == float(0), "");
327 static_assert(zero_one<float>().imag() == float(1), "");
328 static_assert(zero_one<double>().real() == double(0), "");
329 static_assert(zero_one<double>().imag() == double(1), "");
330
331 static_assert(one_zero<float>().real() == float(1), "");
332 static_assert(one_zero<float>().imag() == float(0), "");
333 static_assert(one_zero<double>().real() == double(1), "");
334 static_assert(one_zero<double>().imag() == double(0), "");
335}
336
337} // namespace real_imag
338
339namespace arithmetic_assign {
340
341template <typename scalar_t>
342constexpr c10::complex<scalar_t> p(scalar_t value) {
343 c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
344 result += value;
345 return result;
346}
347
348template <typename scalar_t>
349constexpr c10::complex<scalar_t> m(scalar_t value) {
350 c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
351 result -= value;
352 return result;
353}
354
355template <typename scalar_t>
356constexpr c10::complex<scalar_t> t(scalar_t value) {
357 c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
358 result *= value;
359 return result;
360}
361
362template <typename scalar_t>
363constexpr c10::complex<scalar_t> d(scalar_t value) {
364 c10::complex<scalar_t> result(scalar_t(2), scalar_t(2));
365 result /= value;
366 return result;
367}
368
369template <typename scalar_t>
370C10_HOST_DEVICE void test_arithmetic_assign_scalar() {
371 constexpr c10::complex<scalar_t> x = p(scalar_t(1));
372 static_assert(x.real() == scalar_t(3), "");
373 static_assert(x.imag() == scalar_t(2), "");
374 constexpr c10::complex<scalar_t> y = m(scalar_t(1));
375 static_assert(y.real() == scalar_t(1), "");
376 static_assert(y.imag() == scalar_t(2), "");
377 constexpr c10::complex<scalar_t> z = t(scalar_t(2));
378 static_assert(z.real() == scalar_t(4), "");
379 static_assert(z.imag() == scalar_t(4), "");
380 constexpr c10::complex<scalar_t> t = d(scalar_t(2));
381 static_assert(t.real() == scalar_t(1), "");
382 static_assert(t.imag() == scalar_t(1), "");
383}
384
385template <typename scalar_t, typename rhs_t>
386constexpr c10::complex<scalar_t> p(
387 scalar_t real,
388 scalar_t imag,
389 c10::complex<rhs_t> rhs) {
390 c10::complex<scalar_t> result(real, imag);
391 result += rhs;
392 return result;
393}
394
395template <typename scalar_t, typename rhs_t>
396constexpr c10::complex<scalar_t> m(
397 scalar_t real,
398 scalar_t imag,
399 c10::complex<rhs_t> rhs) {
400 c10::complex<scalar_t> result(real, imag);
401 result -= rhs;
402 return result;
403}
404
405template <typename scalar_t, typename rhs_t>
406constexpr c10::complex<scalar_t> t(
407 scalar_t real,
408 scalar_t imag,
409 c10::complex<rhs_t> rhs) {
410 c10::complex<scalar_t> result(real, imag);
411 result *= rhs;
412 return result;
413}
414
415template <typename scalar_t, typename rhs_t>
416constexpr c10::complex<scalar_t> d(
417 scalar_t real,
418 scalar_t imag,
419 c10::complex<rhs_t> rhs) {
420 c10::complex<scalar_t> result(real, imag);
421 result /= rhs;
422 return result;
423}
424
425template <typename scalar_t>
426C10_HOST_DEVICE void test_arithmetic_assign_complex() {
427 using namespace c10::complex_literals;
428 constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if);
429 static_assert(x2.real() == scalar_t(2), "");
430 static_assert(x2.imag() == scalar_t(3), "");
431 constexpr c10::complex<scalar_t> x3 = p(scalar_t(2), scalar_t(2), 1.0_id);
432 static_assert(x3.real() == scalar_t(2), "");
433
434 // this test is skipped due to a bug in constexpr evaluation
435 // in nvcc. This bug has already been fixed since CUDA 11.2
436#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
437 static_assert(x3.imag() == scalar_t(3), "");
438#endif
439
440 constexpr c10::complex<scalar_t> y2 = m(scalar_t(2), scalar_t(2), 1.0_if);
441 static_assert(y2.real() == scalar_t(2), "");
442 static_assert(y2.imag() == scalar_t(1), "");
443 constexpr c10::complex<scalar_t> y3 = m(scalar_t(2), scalar_t(2), 1.0_id);
444 static_assert(y3.real() == scalar_t(2), "");
445
446 // this test is skipped due to a bug in constexpr evaluation
447 // in nvcc. This bug has already been fixed since CUDA 11.2
448#if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020)
449 static_assert(y3.imag() == scalar_t(1), "");
450#endif
451
452 constexpr c10::complex<scalar_t> z2 = t(scalar_t(1), scalar_t(-2), 1.0_if);
453 static_assert(z2.real() == scalar_t(2), "");
454 static_assert(z2.imag() == scalar_t(1), "");
455 constexpr c10::complex<scalar_t> z3 = t(scalar_t(1), scalar_t(-2), 1.0_id);
456 static_assert(z3.real() == scalar_t(2), "");
457 static_assert(z3.imag() == scalar_t(1), "");
458
459 constexpr c10::complex<scalar_t> t2 = d(scalar_t(-1), scalar_t(2), 1.0_if);
460 static_assert(t2.real() == scalar_t(2), "");
461 static_assert(t2.imag() == scalar_t(1), "");
462 constexpr c10::complex<scalar_t> t3 = d(scalar_t(-1), scalar_t(2), 1.0_id);
463 static_assert(t3.real() == scalar_t(2), "");
464 static_assert(t3.imag() == scalar_t(1), "");
465}
466
467MAYBE_GLOBAL void test_arithmetic_assign() {
468 test_arithmetic_assign_scalar<float>();
469 test_arithmetic_assign_scalar<double>();
470 test_arithmetic_assign_complex<float>();
471 test_arithmetic_assign_complex<double>();
472}
473
474} // namespace arithmetic_assign
475
476namespace arithmetic {
477
478template <typename scalar_t>
479C10_HOST_DEVICE void test_arithmetic_() {
480 static_assert(
481 c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "");
482 static_assert(
483 c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "");
484
485 static_assert(
486 c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) ==
487 c10::complex<scalar_t>(4, 6),
488 "");
489 static_assert(
490 c10::complex<scalar_t>(1, 2) + scalar_t(3) ==
491 c10::complex<scalar_t>(4, 2),
492 "");
493 static_assert(
494 scalar_t(3) + c10::complex<scalar_t>(1, 2) ==
495 c10::complex<scalar_t>(4, 2),
496 "");
497
498 static_assert(
499 c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) ==
500 c10::complex<scalar_t>(-2, -2),
501 "");
502 static_assert(
503 c10::complex<scalar_t>(1, 2) - scalar_t(3) ==
504 c10::complex<scalar_t>(-2, 2),
505 "");
506 static_assert(
507 scalar_t(3) - c10::complex<scalar_t>(1, 2) ==
508 c10::complex<scalar_t>(2, -2),
509 "");
510
511 static_assert(
512 c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) ==
513 c10::complex<scalar_t>(-5, 10),
514 "");
515 static_assert(
516 c10::complex<scalar_t>(1, 2) * scalar_t(3) ==
517 c10::complex<scalar_t>(3, 6),
518 "");
519 static_assert(
520 scalar_t(3) * c10::complex<scalar_t>(1, 2) ==
521 c10::complex<scalar_t>(3, 6),
522 "");
523
524 static_assert(
525 c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) ==
526 c10::complex<scalar_t>(1, 2),
527 "");
528 static_assert(
529 c10::complex<scalar_t>(5, 10) / scalar_t(5) ==
530 c10::complex<scalar_t>(1, 2),
531 "");
532 static_assert(
533 scalar_t(25) / c10::complex<scalar_t>(3, 4) ==
534 c10::complex<scalar_t>(3, -4),
535 "");
536}
537
538MAYBE_GLOBAL void test_arithmetic() {
539 test_arithmetic_<float>();
540 test_arithmetic_<double>();
541}
542
543template <typename T, typename int_t>
544void test_binary_ops_for_int_type_(T real, T img, int_t num) {
545 c10::complex<T> c(real, img);
546 ASSERT_EQ(c + num, c10::complex<T>(real + num, img));
547 ASSERT_EQ(num + c, c10::complex<T>(num + real, img));
548 ASSERT_EQ(c - num, c10::complex<T>(real - num, img));
549 ASSERT_EQ(num - c, c10::complex<T>(num - real, -img));
550 ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num));
551 ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img));
552 ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num));
553 ASSERT_EQ(
554 num / c,
555 c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
556}
557
558template <typename T>
559void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) {
560 test_binary_ops_for_int_type_<T, int8_t>(real, img, i);
561 test_binary_ops_for_int_type_<T, int16_t>(real, img, i);
562 test_binary_ops_for_int_type_<T, int32_t>(real, img, i);
563 test_binary_ops_for_int_type_<T, int64_t>(real, img, i);
564}
565
566TEST(TestArithmeticIntScalar, All) {
567 test_binary_ops_for_all_int_types_<float>(1.0, 0.1, 1);
568 test_binary_ops_for_all_int_types_<double>(-1.3, -0.2, -2);
569}
570
571} // namespace arithmetic
572
573namespace equality {
574
575template <typename scalar_t>
576C10_HOST_DEVICE void test_equality_() {
577 static_assert(
578 c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "");
579 static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "");
580 static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "");
581 static_assert(
582 c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "");
583 static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "");
584 static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "");
585}
586
587MAYBE_GLOBAL void test_equality() {
588 test_equality_<float>();
589 test_equality_<double>();
590}
591
592} // namespace equality
593
594namespace io {
595
596template <typename scalar_t>
597void test_io_() {
598 std::stringstream ss;
599 c10::complex<scalar_t> a(1, 2);
600 ss << a;
601 ASSERT_EQ(ss.str(), "(1,2)");
602 ss.str("(3,4)");
603 ss >> a;
604 ASSERT_TRUE(a == c10::complex<scalar_t>(3, 4));
605}
606
607TEST(TestIO, All) {
608 test_io_<float>();
609 test_io_<double>();
610}
611
612} // namespace io
613
614namespace test_std {
615
616template <typename scalar_t>
617C10_HOST_DEVICE void test_callable_() {
618 static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "");
619 static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "");
620 std::abs(c10::complex<scalar_t>(1, 2));
621 std::arg(c10::complex<scalar_t>(1, 2));
622 static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "");
623 static_assert(
624 std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4),
625 "");
626 c10::polar(float(1), float(PI / 2));
627 c10::polar(double(1), double(PI / 2));
628}
629
630MAYBE_GLOBAL void test_callable() {
631 test_callable_<float>();
632 test_callable_<double>();
633}
634
635template <typename scalar_t>
636void test_values_() {
637 ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5));
638 ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6);
639 ASSERT_LT(
640 std::abs(
641 c10::polar(scalar_t(1), scalar_t(PI / 2)) -
642 c10::complex<scalar_t>(0, 1)),
643 1e-6);
644}
645
646TEST(TestStd, BasicFunctions) {
647 test_values_<float>();
648 test_values_<double>();
649 // CSQRT edge cases: checks for overflows which are likely to occur
650 // if square root is computed using polar form
651 ASSERT_LT(
652 std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4);
653 ASSERT_LT(
654 std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()),
655 3e-4);
656}
657
658} // namespace test_std
659