1 | #include <c10/macros/Macros.h> |
2 | #include <c10/util/complex.h> |
3 | #include <c10/util/hash.h> |
4 | #include <gtest/gtest.h> |
5 | #include <sstream> |
6 | #include <tuple> |
7 | #include <type_traits> |
8 | #include <unordered_map> |
9 | |
10 | #if (defined(__CUDACC__) || defined(__HIPCC__)) |
11 | #define MAYBE_GLOBAL __global__ |
12 | #else |
13 | #define MAYBE_GLOBAL |
14 | #endif |
15 | |
16 | #define PI 3.141592653589793238463 |
17 | |
18 | namespace memory { |
19 | |
20 | MAYBE_GLOBAL void test_size() { |
21 | static_assert(sizeof(c10::complex<float>) == 2 * sizeof(float), "" ); |
22 | static_assert(sizeof(c10::complex<double>) == 2 * sizeof(double), "" ); |
23 | } |
24 | |
25 | MAYBE_GLOBAL void test_align() { |
26 | static_assert(alignof(c10::complex<float>) == 2 * sizeof(float), "" ); |
27 | static_assert(alignof(c10::complex<double>) == 2 * sizeof(double), "" ); |
28 | } |
29 | |
30 | MAYBE_GLOBAL void test_pod() { |
31 | static_assert(std::is_standard_layout<c10::complex<float>>::value, "" ); |
32 | static_assert(std::is_standard_layout<c10::complex<double>>::value, "" ); |
33 | } |
34 | |
35 | TEST(TestMemory, ReinterpretCast) { |
36 | { |
37 | std::complex<float> z(1, 2); |
38 | c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z); |
39 | ASSERT_EQ(zz.real(), float(1)); |
40 | ASSERT_EQ(zz.imag(), float(2)); |
41 | } |
42 | |
43 | { |
44 | c10::complex<float> z(3, 4); |
45 | std::complex<float> zz = *reinterpret_cast<std::complex<float>*>(&z); |
46 | ASSERT_EQ(zz.real(), float(3)); |
47 | ASSERT_EQ(zz.imag(), float(4)); |
48 | } |
49 | |
50 | { |
51 | std::complex<double> z(1, 2); |
52 | c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z); |
53 | ASSERT_EQ(zz.real(), double(1)); |
54 | ASSERT_EQ(zz.imag(), double(2)); |
55 | } |
56 | |
57 | { |
58 | c10::complex<double> z(3, 4); |
59 | std::complex<double> zz = *reinterpret_cast<std::complex<double>*>(&z); |
60 | ASSERT_EQ(zz.real(), double(3)); |
61 | ASSERT_EQ(zz.imag(), double(4)); |
62 | } |
63 | } |
64 | |
65 | #if defined(__CUDACC__) || defined(__HIPCC__) |
66 | TEST(TestMemory, ThrustReinterpretCast) { |
67 | { |
68 | thrust::complex<float> z(1, 2); |
69 | c10::complex<float> zz = *reinterpret_cast<c10::complex<float>*>(&z); |
70 | ASSERT_EQ(zz.real(), float(1)); |
71 | ASSERT_EQ(zz.imag(), float(2)); |
72 | } |
73 | |
74 | { |
75 | c10::complex<float> z(3, 4); |
76 | thrust::complex<float> zz = *reinterpret_cast<thrust::complex<float>*>(&z); |
77 | ASSERT_EQ(zz.real(), float(3)); |
78 | ASSERT_EQ(zz.imag(), float(4)); |
79 | } |
80 | |
81 | { |
82 | thrust::complex<double> z(1, 2); |
83 | c10::complex<double> zz = *reinterpret_cast<c10::complex<double>*>(&z); |
84 | ASSERT_EQ(zz.real(), double(1)); |
85 | ASSERT_EQ(zz.imag(), double(2)); |
86 | } |
87 | |
88 | { |
89 | c10::complex<double> z(3, 4); |
90 | thrust::complex<double> zz = |
91 | *reinterpret_cast<thrust::complex<double>*>(&z); |
92 | ASSERT_EQ(zz.real(), double(3)); |
93 | ASSERT_EQ(zz.imag(), double(4)); |
94 | } |
95 | } |
96 | #endif |
97 | |
98 | } // namespace memory |
99 | |
100 | namespace constructors { |
101 | |
102 | template <typename scalar_t> |
103 | C10_HOST_DEVICE void test_construct_from_scalar() { |
104 | constexpr scalar_t num1 = scalar_t(1.23); |
105 | constexpr scalar_t num2 = scalar_t(4.56); |
106 | constexpr scalar_t zero = scalar_t(); |
107 | static_assert(c10::complex<scalar_t>(num1, num2).real() == num1, "" ); |
108 | static_assert(c10::complex<scalar_t>(num1, num2).imag() == num2, "" ); |
109 | static_assert(c10::complex<scalar_t>(num1).real() == num1, "" ); |
110 | static_assert(c10::complex<scalar_t>(num1).imag() == zero, "" ); |
111 | static_assert(c10::complex<scalar_t>().real() == zero, "" ); |
112 | static_assert(c10::complex<scalar_t>().imag() == zero, "" ); |
113 | } |
114 | |
115 | template <typename scalar_t, typename other_t> |
116 | C10_HOST_DEVICE void test_construct_from_other() { |
117 | constexpr other_t num1 = other_t(1.23); |
118 | constexpr other_t num2 = other_t(4.56); |
119 | constexpr scalar_t num3 = scalar_t(num1); |
120 | constexpr scalar_t num4 = scalar_t(num2); |
121 | static_assert( |
122 | c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).real() == num3, |
123 | "" ); |
124 | static_assert( |
125 | c10::complex<scalar_t>(c10::complex<other_t>(num1, num2)).imag() == num4, |
126 | "" ); |
127 | } |
128 | |
129 | MAYBE_GLOBAL void test_convert_constructors() { |
130 | test_construct_from_scalar<float>(); |
131 | test_construct_from_scalar<double>(); |
132 | |
133 | static_assert( |
134 | std::is_convertible<c10::complex<float>, c10::complex<float>>::value, "" ); |
135 | static_assert( |
136 | !std::is_convertible<c10::complex<double>, c10::complex<float>>::value, |
137 | "" ); |
138 | static_assert( |
139 | std::is_convertible<c10::complex<float>, c10::complex<double>>::value, |
140 | "" ); |
141 | static_assert( |
142 | std::is_convertible<c10::complex<double>, c10::complex<double>>::value, |
143 | "" ); |
144 | |
145 | static_assert( |
146 | std::is_constructible<c10::complex<float>, c10::complex<float>>::value, |
147 | "" ); |
148 | static_assert( |
149 | std::is_constructible<c10::complex<double>, c10::complex<float>>::value, |
150 | "" ); |
151 | static_assert( |
152 | std::is_constructible<c10::complex<float>, c10::complex<double>>::value, |
153 | "" ); |
154 | static_assert( |
155 | std::is_constructible<c10::complex<double>, c10::complex<double>>::value, |
156 | "" ); |
157 | |
158 | test_construct_from_other<float, float>(); |
159 | test_construct_from_other<float, double>(); |
160 | test_construct_from_other<double, float>(); |
161 | test_construct_from_other<double, double>(); |
162 | } |
163 | |
164 | template <typename scalar_t> |
165 | C10_HOST_DEVICE void test_construct_from_std() { |
166 | constexpr scalar_t num1 = scalar_t(1.23); |
167 | constexpr scalar_t num2 = scalar_t(4.56); |
168 | static_assert( |
169 | c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).real() == num1, |
170 | "" ); |
171 | static_assert( |
172 | c10::complex<scalar_t>(std::complex<scalar_t>(num1, num2)).imag() == num2, |
173 | "" ); |
174 | } |
175 | |
176 | MAYBE_GLOBAL void test_std_conversion() { |
177 | test_construct_from_std<float>(); |
178 | test_construct_from_std<double>(); |
179 | } |
180 | |
181 | #if defined(__CUDACC__) || defined(__HIPCC__) |
182 | template <typename scalar_t> |
183 | void test_construct_from_thrust() { |
184 | constexpr scalar_t num1 = scalar_t(1.23); |
185 | constexpr scalar_t num2 = scalar_t(4.56); |
186 | ASSERT_EQ( |
187 | c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).real(), |
188 | num1); |
189 | ASSERT_EQ( |
190 | c10::complex<scalar_t>(thrust::complex<scalar_t>(num1, num2)).imag(), |
191 | num2); |
192 | } |
193 | |
194 | TEST(TestConstructors, FromThrust) { |
195 | test_construct_from_thrust<float>(); |
196 | test_construct_from_thrust<double>(); |
197 | } |
198 | #endif |
199 | |
200 | TEST(TestConstructors, UnorderedMap) { |
201 | std::unordered_map< |
202 | c10::complex<double>, |
203 | c10::complex<double>, |
204 | c10::hash<c10::complex<double>>> |
205 | m; |
206 | auto key1 = c10::complex<double>(2.5, 3); |
207 | auto key2 = c10::complex<double>(2, 0); |
208 | auto val1 = c10::complex<double>(2, -3.2); |
209 | auto val2 = c10::complex<double>(0, -3); |
210 | m[key1] = val1; |
211 | m[key2] = val2; |
212 | ASSERT_EQ(m[key1], val1); |
213 | ASSERT_EQ(m[key2], val2); |
214 | } |
215 | |
216 | } // namespace constructors |
217 | |
218 | namespace assignment { |
219 | |
220 | template <typename scalar_t> |
221 | constexpr c10::complex<scalar_t> one() { |
222 | c10::complex<scalar_t> result(3, 4); |
223 | result = scalar_t(1); |
224 | return result; |
225 | } |
226 | |
227 | MAYBE_GLOBAL void test_assign_real() { |
228 | static_assert(one<float>().real() == float(1), "" ); |
229 | static_assert(one<float>().imag() == float(), "" ); |
230 | static_assert(one<double>().real() == double(1), "" ); |
231 | static_assert(one<double>().imag() == double(), "" ); |
232 | } |
233 | |
234 | constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two() { |
235 | constexpr c10::complex<float> src(1, 2); |
236 | c10::complex<double> ret0; |
237 | c10::complex<float> ret1; |
238 | ret0 = ret1 = src; |
239 | return std::make_tuple(ret0, ret1); |
240 | } |
241 | |
242 | MAYBE_GLOBAL void test_assign_other() { |
243 | constexpr auto tup = one_two(); |
244 | static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "" ); |
245 | static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "" ); |
246 | static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "" ); |
247 | static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "" ); |
248 | } |
249 | |
250 | constexpr std::tuple<c10::complex<double>, c10::complex<float>> one_two_std() { |
251 | constexpr std::complex<float> src(1, 1); |
252 | c10::complex<double> ret0; |
253 | c10::complex<float> ret1; |
254 | ret0 = ret1 = src; |
255 | return std::make_tuple(ret0, ret1); |
256 | } |
257 | |
258 | MAYBE_GLOBAL void test_assign_std() { |
259 | constexpr auto tup = one_two(); |
260 | static_assert(std::get<c10::complex<double>>(tup).real() == double(1), "" ); |
261 | static_assert(std::get<c10::complex<double>>(tup).imag() == double(2), "" ); |
262 | static_assert(std::get<c10::complex<float>>(tup).real() == float(1), "" ); |
263 | static_assert(std::get<c10::complex<float>>(tup).imag() == float(2), "" ); |
264 | } |
265 | |
266 | #if defined(__CUDACC__) || defined(__HIPCC__) |
267 | C10_HOST_DEVICE std::tuple<c10::complex<double>, c10::complex<float>> |
268 | one_two_thrust() { |
269 | thrust::complex<float> src(1, 2); |
270 | c10::complex<double> ret0; |
271 | c10::complex<float> ret1; |
272 | ret0 = ret1 = src; |
273 | return std::make_tuple(ret0, ret1); |
274 | } |
275 | |
276 | TEST(TestAssignment, FromThrust) { |
277 | auto tup = one_two_thrust(); |
278 | ASSERT_EQ(std::get<c10::complex<double>>(tup).real(), double(1)); |
279 | ASSERT_EQ(std::get<c10::complex<double>>(tup).imag(), double(2)); |
280 | ASSERT_EQ(std::get<c10::complex<float>>(tup).real(), float(1)); |
281 | ASSERT_EQ(std::get<c10::complex<float>>(tup).imag(), float(2)); |
282 | } |
283 | #endif |
284 | |
285 | } // namespace assignment |
286 | |
287 | namespace literals { |
288 | |
289 | MAYBE_GLOBAL void test_complex_literals() { |
290 | using namespace c10::complex_literals; |
291 | static_assert(std::is_same<decltype(0.5_if), c10::complex<float>>::value, "" ); |
292 | static_assert((0.5_if).real() == float(), "" ); |
293 | static_assert((0.5_if).imag() == float(0.5), "" ); |
294 | static_assert( |
295 | std::is_same<decltype(0.5_id), c10::complex<double>>::value, "" ); |
296 | static_assert((0.5_id).real() == float(), "" ); |
297 | static_assert((0.5_id).imag() == float(0.5), "" ); |
298 | |
299 | static_assert(std::is_same<decltype(1_if), c10::complex<float>>::value, "" ); |
300 | static_assert((1_if).real() == float(), "" ); |
301 | static_assert((1_if).imag() == float(1), "" ); |
302 | static_assert(std::is_same<decltype(1_id), c10::complex<double>>::value, "" ); |
303 | static_assert((1_id).real() == double(), "" ); |
304 | static_assert((1_id).imag() == double(1), "" ); |
305 | } |
306 | |
307 | } // namespace literals |
308 | |
309 | namespace real_imag { |
310 | |
311 | template <typename scalar_t> |
312 | constexpr c10::complex<scalar_t> zero_one() { |
313 | c10::complex<scalar_t> result; |
314 | result.imag(scalar_t(1)); |
315 | return result; |
316 | } |
317 | |
318 | template <typename scalar_t> |
319 | constexpr c10::complex<scalar_t> one_zero() { |
320 | c10::complex<scalar_t> result; |
321 | result.real(scalar_t(1)); |
322 | return result; |
323 | } |
324 | |
325 | MAYBE_GLOBAL void test_real_imag_modify() { |
326 | static_assert(zero_one<float>().real() == float(0), "" ); |
327 | static_assert(zero_one<float>().imag() == float(1), "" ); |
328 | static_assert(zero_one<double>().real() == double(0), "" ); |
329 | static_assert(zero_one<double>().imag() == double(1), "" ); |
330 | |
331 | static_assert(one_zero<float>().real() == float(1), "" ); |
332 | static_assert(one_zero<float>().imag() == float(0), "" ); |
333 | static_assert(one_zero<double>().real() == double(1), "" ); |
334 | static_assert(one_zero<double>().imag() == double(0), "" ); |
335 | } |
336 | |
337 | } // namespace real_imag |
338 | |
339 | namespace arithmetic_assign { |
340 | |
341 | template <typename scalar_t> |
342 | constexpr c10::complex<scalar_t> p(scalar_t value) { |
343 | c10::complex<scalar_t> result(scalar_t(2), scalar_t(2)); |
344 | result += value; |
345 | return result; |
346 | } |
347 | |
348 | template <typename scalar_t> |
349 | constexpr c10::complex<scalar_t> m(scalar_t value) { |
350 | c10::complex<scalar_t> result(scalar_t(2), scalar_t(2)); |
351 | result -= value; |
352 | return result; |
353 | } |
354 | |
355 | template <typename scalar_t> |
356 | constexpr c10::complex<scalar_t> t(scalar_t value) { |
357 | c10::complex<scalar_t> result(scalar_t(2), scalar_t(2)); |
358 | result *= value; |
359 | return result; |
360 | } |
361 | |
362 | template <typename scalar_t> |
363 | constexpr c10::complex<scalar_t> d(scalar_t value) { |
364 | c10::complex<scalar_t> result(scalar_t(2), scalar_t(2)); |
365 | result /= value; |
366 | return result; |
367 | } |
368 | |
369 | template <typename scalar_t> |
370 | C10_HOST_DEVICE void test_arithmetic_assign_scalar() { |
371 | constexpr c10::complex<scalar_t> x = p(scalar_t(1)); |
372 | static_assert(x.real() == scalar_t(3), "" ); |
373 | static_assert(x.imag() == scalar_t(2), "" ); |
374 | constexpr c10::complex<scalar_t> y = m(scalar_t(1)); |
375 | static_assert(y.real() == scalar_t(1), "" ); |
376 | static_assert(y.imag() == scalar_t(2), "" ); |
377 | constexpr c10::complex<scalar_t> z = t(scalar_t(2)); |
378 | static_assert(z.real() == scalar_t(4), "" ); |
379 | static_assert(z.imag() == scalar_t(4), "" ); |
380 | constexpr c10::complex<scalar_t> t = d(scalar_t(2)); |
381 | static_assert(t.real() == scalar_t(1), "" ); |
382 | static_assert(t.imag() == scalar_t(1), "" ); |
383 | } |
384 | |
385 | template <typename scalar_t, typename rhs_t> |
386 | constexpr c10::complex<scalar_t> p( |
387 | scalar_t real, |
388 | scalar_t imag, |
389 | c10::complex<rhs_t> rhs) { |
390 | c10::complex<scalar_t> result(real, imag); |
391 | result += rhs; |
392 | return result; |
393 | } |
394 | |
395 | template <typename scalar_t, typename rhs_t> |
396 | constexpr c10::complex<scalar_t> m( |
397 | scalar_t real, |
398 | scalar_t imag, |
399 | c10::complex<rhs_t> rhs) { |
400 | c10::complex<scalar_t> result(real, imag); |
401 | result -= rhs; |
402 | return result; |
403 | } |
404 | |
405 | template <typename scalar_t, typename rhs_t> |
406 | constexpr c10::complex<scalar_t> t( |
407 | scalar_t real, |
408 | scalar_t imag, |
409 | c10::complex<rhs_t> rhs) { |
410 | c10::complex<scalar_t> result(real, imag); |
411 | result *= rhs; |
412 | return result; |
413 | } |
414 | |
415 | template <typename scalar_t, typename rhs_t> |
416 | constexpr c10::complex<scalar_t> d( |
417 | scalar_t real, |
418 | scalar_t imag, |
419 | c10::complex<rhs_t> rhs) { |
420 | c10::complex<scalar_t> result(real, imag); |
421 | result /= rhs; |
422 | return result; |
423 | } |
424 | |
425 | template <typename scalar_t> |
426 | C10_HOST_DEVICE void test_arithmetic_assign_complex() { |
427 | using namespace c10::complex_literals; |
428 | constexpr c10::complex<scalar_t> x2 = p(scalar_t(2), scalar_t(2), 1.0_if); |
429 | static_assert(x2.real() == scalar_t(2), "" ); |
430 | static_assert(x2.imag() == scalar_t(3), "" ); |
431 | constexpr c10::complex<scalar_t> x3 = p(scalar_t(2), scalar_t(2), 1.0_id); |
432 | static_assert(x3.real() == scalar_t(2), "" ); |
433 | |
434 | // this test is skipped due to a bug in constexpr evaluation |
435 | // in nvcc. This bug has already been fixed since CUDA 11.2 |
436 | #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) |
437 | static_assert(x3.imag() == scalar_t(3), "" ); |
438 | #endif |
439 | |
440 | constexpr c10::complex<scalar_t> y2 = m(scalar_t(2), scalar_t(2), 1.0_if); |
441 | static_assert(y2.real() == scalar_t(2), "" ); |
442 | static_assert(y2.imag() == scalar_t(1), "" ); |
443 | constexpr c10::complex<scalar_t> y3 = m(scalar_t(2), scalar_t(2), 1.0_id); |
444 | static_assert(y3.real() == scalar_t(2), "" ); |
445 | |
446 | // this test is skipped due to a bug in constexpr evaluation |
447 | // in nvcc. This bug has already been fixed since CUDA 11.2 |
448 | #if !defined(__CUDACC__) || (defined(CUDA_VERSION) && CUDA_VERSION >= 11020) |
449 | static_assert(y3.imag() == scalar_t(1), "" ); |
450 | #endif |
451 | |
452 | constexpr c10::complex<scalar_t> z2 = t(scalar_t(1), scalar_t(-2), 1.0_if); |
453 | static_assert(z2.real() == scalar_t(2), "" ); |
454 | static_assert(z2.imag() == scalar_t(1), "" ); |
455 | constexpr c10::complex<scalar_t> z3 = t(scalar_t(1), scalar_t(-2), 1.0_id); |
456 | static_assert(z3.real() == scalar_t(2), "" ); |
457 | static_assert(z3.imag() == scalar_t(1), "" ); |
458 | |
459 | constexpr c10::complex<scalar_t> t2 = d(scalar_t(-1), scalar_t(2), 1.0_if); |
460 | static_assert(t2.real() == scalar_t(2), "" ); |
461 | static_assert(t2.imag() == scalar_t(1), "" ); |
462 | constexpr c10::complex<scalar_t> t3 = d(scalar_t(-1), scalar_t(2), 1.0_id); |
463 | static_assert(t3.real() == scalar_t(2), "" ); |
464 | static_assert(t3.imag() == scalar_t(1), "" ); |
465 | } |
466 | |
467 | MAYBE_GLOBAL void test_arithmetic_assign() { |
468 | test_arithmetic_assign_scalar<float>(); |
469 | test_arithmetic_assign_scalar<double>(); |
470 | test_arithmetic_assign_complex<float>(); |
471 | test_arithmetic_assign_complex<double>(); |
472 | } |
473 | |
474 | } // namespace arithmetic_assign |
475 | |
476 | namespace arithmetic { |
477 | |
478 | template <typename scalar_t> |
479 | C10_HOST_DEVICE void test_arithmetic_() { |
480 | static_assert( |
481 | c10::complex<scalar_t>(1, 2) == +c10::complex<scalar_t>(1, 2), "" ); |
482 | static_assert( |
483 | c10::complex<scalar_t>(-1, -2) == -c10::complex<scalar_t>(1, 2), "" ); |
484 | |
485 | static_assert( |
486 | c10::complex<scalar_t>(1, 2) + c10::complex<scalar_t>(3, 4) == |
487 | c10::complex<scalar_t>(4, 6), |
488 | "" ); |
489 | static_assert( |
490 | c10::complex<scalar_t>(1, 2) + scalar_t(3) == |
491 | c10::complex<scalar_t>(4, 2), |
492 | "" ); |
493 | static_assert( |
494 | scalar_t(3) + c10::complex<scalar_t>(1, 2) == |
495 | c10::complex<scalar_t>(4, 2), |
496 | "" ); |
497 | |
498 | static_assert( |
499 | c10::complex<scalar_t>(1, 2) - c10::complex<scalar_t>(3, 4) == |
500 | c10::complex<scalar_t>(-2, -2), |
501 | "" ); |
502 | static_assert( |
503 | c10::complex<scalar_t>(1, 2) - scalar_t(3) == |
504 | c10::complex<scalar_t>(-2, 2), |
505 | "" ); |
506 | static_assert( |
507 | scalar_t(3) - c10::complex<scalar_t>(1, 2) == |
508 | c10::complex<scalar_t>(2, -2), |
509 | "" ); |
510 | |
511 | static_assert( |
512 | c10::complex<scalar_t>(1, 2) * c10::complex<scalar_t>(3, 4) == |
513 | c10::complex<scalar_t>(-5, 10), |
514 | "" ); |
515 | static_assert( |
516 | c10::complex<scalar_t>(1, 2) * scalar_t(3) == |
517 | c10::complex<scalar_t>(3, 6), |
518 | "" ); |
519 | static_assert( |
520 | scalar_t(3) * c10::complex<scalar_t>(1, 2) == |
521 | c10::complex<scalar_t>(3, 6), |
522 | "" ); |
523 | |
524 | static_assert( |
525 | c10::complex<scalar_t>(-5, 10) / c10::complex<scalar_t>(3, 4) == |
526 | c10::complex<scalar_t>(1, 2), |
527 | "" ); |
528 | static_assert( |
529 | c10::complex<scalar_t>(5, 10) / scalar_t(5) == |
530 | c10::complex<scalar_t>(1, 2), |
531 | "" ); |
532 | static_assert( |
533 | scalar_t(25) / c10::complex<scalar_t>(3, 4) == |
534 | c10::complex<scalar_t>(3, -4), |
535 | "" ); |
536 | } |
537 | |
538 | MAYBE_GLOBAL void test_arithmetic() { |
539 | test_arithmetic_<float>(); |
540 | test_arithmetic_<double>(); |
541 | } |
542 | |
543 | template <typename T, typename int_t> |
544 | void test_binary_ops_for_int_type_(T real, T img, int_t num) { |
545 | c10::complex<T> c(real, img); |
546 | ASSERT_EQ(c + num, c10::complex<T>(real + num, img)); |
547 | ASSERT_EQ(num + c, c10::complex<T>(num + real, img)); |
548 | ASSERT_EQ(c - num, c10::complex<T>(real - num, img)); |
549 | ASSERT_EQ(num - c, c10::complex<T>(num - real, -img)); |
550 | ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num)); |
551 | ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img)); |
552 | ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num)); |
553 | ASSERT_EQ( |
554 | num / c, |
555 | c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c))); |
556 | } |
557 | |
558 | template <typename T> |
559 | void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) { |
560 | test_binary_ops_for_int_type_<T, int8_t>(real, img, i); |
561 | test_binary_ops_for_int_type_<T, int16_t>(real, img, i); |
562 | test_binary_ops_for_int_type_<T, int32_t>(real, img, i); |
563 | test_binary_ops_for_int_type_<T, int64_t>(real, img, i); |
564 | } |
565 | |
566 | TEST(TestArithmeticIntScalar, All) { |
567 | test_binary_ops_for_all_int_types_<float>(1.0, 0.1, 1); |
568 | test_binary_ops_for_all_int_types_<double>(-1.3, -0.2, -2); |
569 | } |
570 | |
571 | } // namespace arithmetic |
572 | |
573 | namespace equality { |
574 | |
575 | template <typename scalar_t> |
576 | C10_HOST_DEVICE void test_equality_() { |
577 | static_assert( |
578 | c10::complex<scalar_t>(1, 2) == c10::complex<scalar_t>(1, 2), "" ); |
579 | static_assert(c10::complex<scalar_t>(1, 0) == scalar_t(1), "" ); |
580 | static_assert(scalar_t(1) == c10::complex<scalar_t>(1, 0), "" ); |
581 | static_assert( |
582 | c10::complex<scalar_t>(1, 2) != c10::complex<scalar_t>(3, 4), "" ); |
583 | static_assert(c10::complex<scalar_t>(1, 2) != scalar_t(1), "" ); |
584 | static_assert(scalar_t(1) != c10::complex<scalar_t>(1, 2), "" ); |
585 | } |
586 | |
587 | MAYBE_GLOBAL void test_equality() { |
588 | test_equality_<float>(); |
589 | test_equality_<double>(); |
590 | } |
591 | |
592 | } // namespace equality |
593 | |
594 | namespace io { |
595 | |
596 | template <typename scalar_t> |
597 | void test_io_() { |
598 | std::stringstream ss; |
599 | c10::complex<scalar_t> a(1, 2); |
600 | ss << a; |
601 | ASSERT_EQ(ss.str(), "(1,2)" ); |
602 | ss.str("(3,4)" ); |
603 | ss >> a; |
604 | ASSERT_TRUE(a == c10::complex<scalar_t>(3, 4)); |
605 | } |
606 | |
607 | TEST(TestIO, All) { |
608 | test_io_<float>(); |
609 | test_io_<double>(); |
610 | } |
611 | |
612 | } // namespace io |
613 | |
614 | namespace test_std { |
615 | |
616 | template <typename scalar_t> |
617 | C10_HOST_DEVICE void test_callable_() { |
618 | static_assert(std::real(c10::complex<scalar_t>(1, 2)) == scalar_t(1), "" ); |
619 | static_assert(std::imag(c10::complex<scalar_t>(1, 2)) == scalar_t(2), "" ); |
620 | std::abs(c10::complex<scalar_t>(1, 2)); |
621 | std::arg(c10::complex<scalar_t>(1, 2)); |
622 | static_assert(std::norm(c10::complex<scalar_t>(3, 4)) == scalar_t(25), "" ); |
623 | static_assert( |
624 | std::conj(c10::complex<scalar_t>(3, 4)) == c10::complex<scalar_t>(3, -4), |
625 | "" ); |
626 | c10::polar(float(1), float(PI / 2)); |
627 | c10::polar(double(1), double(PI / 2)); |
628 | } |
629 | |
630 | MAYBE_GLOBAL void test_callable() { |
631 | test_callable_<float>(); |
632 | test_callable_<double>(); |
633 | } |
634 | |
635 | template <typename scalar_t> |
636 | void test_values_() { |
637 | ASSERT_EQ(std::abs(c10::complex<scalar_t>(3, 4)), scalar_t(5)); |
638 | ASSERT_LT(std::abs(std::arg(c10::complex<scalar_t>(0, 1)) - PI / 2), 1e-6); |
639 | ASSERT_LT( |
640 | std::abs( |
641 | c10::polar(scalar_t(1), scalar_t(PI / 2)) - |
642 | c10::complex<scalar_t>(0, 1)), |
643 | 1e-6); |
644 | } |
645 | |
646 | TEST(TestStd, BasicFunctions) { |
647 | test_values_<float>(); |
648 | test_values_<double>(); |
649 | // CSQRT edge cases: checks for overflows which are likely to occur |
650 | // if square root is computed using polar form |
651 | ASSERT_LT( |
652 | std::abs(std::sqrt(c10::complex<float>(-1e20, -4988429.2)).real()), 3e-4); |
653 | ASSERT_LT( |
654 | std::abs(std::sqrt(c10::complex<double>(-1e60, -4988429.2)).real()), |
655 | 3e-4); |
656 | } |
657 | |
658 | } // namespace test_std |
659 | |