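// Warning: this file is intentionally included twice in
// aten/src/ATen/test/cuda_complex_math_test.cu, so it has no include guard and
// the macros it defines below are guarded against redefinition.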
3
4#include <c10/util/complex.h>
5#include <gtest/gtest.h>
6
// Default values for the constants used by the tests below. Each is defined
// only when absent, so an existing definition (for instance one left over from
// an earlier inclusion of this file) is kept as-is.
#if !defined(PI)
#define PI 3.141592653589793238463
#endif

#if !defined(tol)
#define tol 1e-6
#endif
14
// Exponential and logarithmic functions
16
17C10_DEFINE_TEST(TestExponential, IPi) {
18 // exp(i*pi) = -1
19 {
20 c10::complex<float> e_i_pi = std::exp(c10::complex<float>(0, float(PI)));
21 C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
22 C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
23 }
24 {
25 c10::complex<float> e_i_pi = ::exp(c10::complex<float>(0, float(PI)));
26 C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
27 C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
28 }
29 {
30 c10::complex<double> e_i_pi = std::exp(c10::complex<double>(0, PI));
31 C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
32 C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
33 }
34 {
35 c10::complex<double> e_i_pi = ::exp(c10::complex<double>(0, PI));
36 C10_ASSERT_NEAR(e_i_pi.real(), -1, tol);
37 C10_ASSERT_NEAR(e_i_pi.imag(), 0, tol);
38 }
39}
40
41C10_DEFINE_TEST(TestExponential, EulerFormula) {
42 // exp(ix) = cos(x) + i * sin(x)
43 {
44 c10::complex<float> x(0.1, 1.2);
45 c10::complex<float> e = std::exp(x);
46 float expected_real = std::exp(x.real()) * std::cos(x.imag());
47 float expected_imag = std::exp(x.real()) * std::sin(x.imag());
48 C10_ASSERT_NEAR(e.real(), expected_real, tol);
49 C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
50 }
51 {
52 c10::complex<float> x(0.1, 1.2);
53 c10::complex<float> e = ::exp(x);
54 float expected_real = ::exp(x.real()) * ::cos(x.imag());
55 float expected_imag = ::exp(x.real()) * ::sin(x.imag());
56 C10_ASSERT_NEAR(e.real(), expected_real, tol);
57 C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
58 }
59 {
60 c10::complex<double> x(0.1, 1.2);
61 c10::complex<double> e = std::exp(x);
62 float expected_real = std::exp(x.real()) * std::cos(x.imag());
63 float expected_imag = std::exp(x.real()) * std::sin(x.imag());
64 C10_ASSERT_NEAR(e.real(), expected_real, tol);
65 C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
66 }
67 {
68 c10::complex<double> x(0.1, 1.2);
69 c10::complex<double> e = ::exp(x);
70 float expected_real = ::exp(x.real()) * ::cos(x.imag());
71 float expected_imag = ::exp(x.real()) * ::sin(x.imag());
72 C10_ASSERT_NEAR(e.real(), expected_real, tol);
73 C10_ASSERT_NEAR(e.imag(), expected_imag, tol);
74 }
75}
76
77C10_DEFINE_TEST(TestLog, Definition) {
78 // log(x) = log(r) + i*theta
79 {
80 c10::complex<float> x(1.2, 3.4);
81 c10::complex<float> l = std::log(x);
82 float expected_real = std::log(std::abs(x));
83 float expected_imag = std::arg(x);
84 C10_ASSERT_NEAR(l.real(), expected_real, tol);
85 C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
86 }
87 {
88 c10::complex<float> x(1.2, 3.4);
89 c10::complex<float> l = ::log(x);
90 float expected_real = ::log(std::abs(x));
91 float expected_imag = std::arg(x);
92 C10_ASSERT_NEAR(l.real(), expected_real, tol);
93 C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
94 }
95 {
96 c10::complex<double> x(1.2, 3.4);
97 c10::complex<double> l = std::log(x);
98 float expected_real = std::log(std::abs(x));
99 float expected_imag = std::arg(x);
100 C10_ASSERT_NEAR(l.real(), expected_real, tol);
101 C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
102 }
103 {
104 c10::complex<double> x(1.2, 3.4);
105 c10::complex<double> l = ::log(x);
106 float expected_real = ::log(std::abs(x));
107 float expected_imag = std::arg(x);
108 C10_ASSERT_NEAR(l.real(), expected_real, tol);
109 C10_ASSERT_NEAR(l.imag(), expected_imag, tol);
110 }
111}
112
113C10_DEFINE_TEST(TestLog10, Rev) {
114 // log10(10^x) = x
115 {
116 c10::complex<float> x(0.1, 1.2);
117 c10::complex<float> l = std::log10(std::pow(float(10), x));
118 C10_ASSERT_NEAR(l.real(), float(0.1), tol);
119 C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
120 }
121 {
122 c10::complex<float> x(0.1, 1.2);
123 c10::complex<float> l = ::log10(::pow(float(10), x));
124 C10_ASSERT_NEAR(l.real(), float(0.1), tol);
125 C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
126 }
127 {
128 c10::complex<double> x(0.1, 1.2);
129 c10::complex<double> l = std::log10(std::pow(double(10), x));
130 C10_ASSERT_NEAR(l.real(), double(0.1), tol);
131 C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
132 }
133 {
134 c10::complex<double> x(0.1, 1.2);
135 c10::complex<double> l = ::log10(::pow(double(10), x));
136 C10_ASSERT_NEAR(l.real(), double(0.1), tol);
137 C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
138 }
139}
140
141C10_DEFINE_TEST(TestLog2, Rev) {
142 // log2(2^x) = x
143 {
144 c10::complex<float> x(0.1, 1.2);
145 c10::complex<float> l = std::log2(std::pow(float(2), x));
146 C10_ASSERT_NEAR(l.real(), float(0.1), tol);
147 C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
148 }
149 {
150 c10::complex<float> x(0.1, 1.2);
151 c10::complex<float> l = ::log2(std::pow(float(2), x));
152 C10_ASSERT_NEAR(l.real(), float(0.1), tol);
153 C10_ASSERT_NEAR(l.imag(), float(1.2), tol);
154 }
155 {
156 c10::complex<double> x(0.1, 1.2);
157 c10::complex<double> l = std::log2(std::pow(double(2), x));
158 C10_ASSERT_NEAR(l.real(), double(0.1), tol);
159 C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
160 }
161 {
162 c10::complex<double> x(0.1, 1.2);
163 c10::complex<double> l = ::log2(std::pow(double(2), x));
164 C10_ASSERT_NEAR(l.real(), double(0.1), tol);
165 C10_ASSERT_NEAR(l.imag(), double(1.2), tol);
166 }
167}
168
169C10_DEFINE_TEST(TestLog1p, Normal) {
170 // log1p(x) = log(1 + x)
171 {
172 c10::complex<float> x(0.1, 1.2);
173 c10::complex<float> l1 = std::log1p(x);
174 c10::complex<float> l2 = std::log(1.0f + x);
175 C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
176 C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
177 }
178 {
179 c10::complex<double> x(0.1, 1.2);
180 c10::complex<double> l1 = std::log1p(x);
181 c10::complex<double> l2 = std::log(1.0 + x);
182 C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
183 C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
184 }
185}
186
187C10_DEFINE_TEST(TestLog1p, Small) {
188 // log(1 + x) ~ x for |x| << 1
189 {
190 c10::complex<float> x(1e-9, 2e-9);
191 c10::complex<float> l = std::log1p(x);
192 C10_ASSERT_NEAR(l.real() / x.real(), 1, tol);
193 C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol);
194 }
195 {
196 c10::complex<double> x(1e-100, 2e-100);
197 c10::complex<double> l = std::log1p(x);
198 C10_ASSERT_NEAR(l.real() / x.real(), 1, tol);
199 C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol);
200 }
201}
202
203C10_DEFINE_TEST(TestLog1p, Extreme) {
204 // log(1 + x) ~ x for |x| << 1 and in the brink of overflow / underflow
205 {
206 c10::complex<float> x(-1, 1e-30);
207 c10::complex<float> l = std::log1p(x);
208 C10_ASSERT_NEAR(l.real(), -69.07755278982137, tol);
209 C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
210 }
211 {
212 c10::complex<float> x(-1, 1e30);
213 c10::complex<float> l = std::log1p(x);
214 C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol);
215 C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
216 }
217 {
218 c10::complex<float> x(1e30, 1);
219 c10::complex<float> l = std::log1p(x);
220 C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol);
221 C10_ASSERT_NEAR(l.imag(), 1e-30, tol);
222 }
223 {
224 c10::complex<float> x(1e-30, 1);
225 c10::complex<float> l = std::log1p(x);
226 C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol);
227 C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
228 }
229 {
230 c10::complex<float> x(1e30, 1e30);
231 c10::complex<float> l = std::log1p(x);
232 C10_ASSERT_NEAR(l.real(), 69.42412638010134, tol);
233 C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
234 }
235 {
236 c10::complex<float> x(1e-38, 1e-38);
237 c10::complex<float> l = std::log1p(x);
238 C10_ASSERT_NEAR(l.real(), 1e-38, tol);
239 C10_ASSERT_NEAR(l.imag(), 1e-38, tol);
240 }
241 {
242 c10::complex<float> x(1e-38, 2e-30);
243 c10::complex<float> l = std::log1p(x);
244 C10_ASSERT_NEAR(l.real(), 1e-30, tol);
245 C10_ASSERT_NEAR(l.imag(), 2e-30, tol);
246 }
247 {
248 c10::complex<double> x(-1, 1e-250);
249 c10::complex<double> l = std::log1p(x);
250 C10_ASSERT_NEAR(l.real(), -575.6462732485114, tol);
251 C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
252 }
253 {
254 c10::complex<double> x(-1, 1e250);
255 c10::complex<double> l = std::log1p(x);
256 C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol);
257 C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
258 }
259 {
260 c10::complex<double> x(1e250, 1);
261 c10::complex<double> l = std::log1p(x);
262 C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol);
263 C10_ASSERT_NEAR(l.imag(), 1e-250, tol);
264 }
265 {
266 c10::complex<double> x(1e-250, 1);
267 c10::complex<double> l = std::log1p(x);
268 C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol);
269 C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
270 }
271 {
272 c10::complex<double> x(1e250, 1e250);
273 c10::complex<double> l = std::log1p(x);
274 C10_ASSERT_NEAR(l.real(), 575.9928468387914, tol);
275 C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
276 }
277 {
278 c10::complex<double> x(1e-250, 1e-250);
279 c10::complex<double> l = std::log1p(x);
280 C10_ASSERT_NEAR(l.real(), 1e-250, tol);
281 C10_ASSERT_NEAR(l.imag(), 1e-250, tol);
282 }
283 {
284 c10::complex<double> x(1e-250, 2e-250);
285 c10::complex<double> l = std::log1p(x);
286 C10_ASSERT_NEAR(l.real(), 1e-250, tol);
287 C10_ASSERT_NEAR(l.imag(), 2e-250, tol);
288 }
289 {
290 c10::complex<double> x(2e-308, 1.5e-250);
291 c10::complex<double> l = std::log1p(x);
292 C10_ASSERT_NEAR(l.real(), 2e-308, tol);
293 C10_ASSERT_NEAR(l.imag(), 1.5e-308, tol);
294 }
295}
296
297// Power functions
298
299C10_DEFINE_TEST(TestPowSqrt, Equal) {
300 // x^0.5 = sqrt(x)
301 {
302 c10::complex<float> x(0.1, 1.2);
303 c10::complex<float> y = std::pow(x, float(0.5));
304 c10::complex<float> z = std::sqrt(x);
305 C10_ASSERT_NEAR(y.real(), z.real(), tol);
306 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
307 }
308 {
309 c10::complex<float> x(0.1, 1.2);
310 c10::complex<float> y = ::pow(x, float(0.5));
311 c10::complex<float> z = ::sqrt(x);
312 C10_ASSERT_NEAR(y.real(), z.real(), tol);
313 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
314 }
315 {
316 c10::complex<double> x(0.1, 1.2);
317 c10::complex<double> y = std::pow(x, double(0.5));
318 c10::complex<double> z = std::sqrt(x);
319 C10_ASSERT_NEAR(y.real(), z.real(), tol);
320 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
321 }
322 {
323 c10::complex<double> x(0.1, 1.2);
324 c10::complex<double> y = ::pow(x, double(0.5));
325 c10::complex<double> z = ::sqrt(x);
326 C10_ASSERT_NEAR(y.real(), z.real(), tol);
327 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
328 }
329}
330
331C10_DEFINE_TEST(TestPow, Square) {
332 // x^2 = x * x
333 {
334 c10::complex<float> x(0.1, 1.2);
335 c10::complex<float> y = std::pow(x, float(2));
336 c10::complex<float> z = x * x;
337 C10_ASSERT_NEAR(y.real(), z.real(), tol);
338 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
339 }
340 {
341 c10::complex<float> x(0.1, 1.2);
342 c10::complex<float> y = ::pow(x, float(2));
343 c10::complex<float> z = x * x;
344 C10_ASSERT_NEAR(y.real(), z.real(), tol);
345 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
346 }
347 {
348 c10::complex<double> x(0.1, 1.2);
349 c10::complex<double> y = std::pow(x, double(2));
350 c10::complex<double> z = x * x;
351 C10_ASSERT_NEAR(y.real(), z.real(), tol);
352 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
353 }
354 {
355 c10::complex<double> x(0.1, 1.2);
356 c10::complex<double> y = ::pow(x, double(2));
357 c10::complex<double> z = x * x;
358 C10_ASSERT_NEAR(y.real(), z.real(), tol);
359 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
360 }
361}
362
363// Trigonometric functions and hyperbolic functions
364
365C10_DEFINE_TEST(TestSinCosSinhCosh, Identity) {
366 // sin(x + i * y) = sin(x) * cosh(y) + i * cos(x) * sinh(y)
367 // cos(x + i * y) = cos(x) * cosh(y) - i * sin(x) * sinh(y)
368 {
369 c10::complex<float> x(0.1, 1.2);
370 c10::complex<float> y = std::sin(x);
371 float expected_real = std::sin(x.real()) * std::cosh(x.imag());
372 float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
373 C10_ASSERT_NEAR(y.real(), expected_real, tol);
374 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
375 }
376 {
377 c10::complex<float> x(0.1, 1.2);
378 c10::complex<float> y = ::sin(x);
379 float expected_real = ::sin(x.real()) * ::cosh(x.imag());
380 float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
381 C10_ASSERT_NEAR(y.real(), expected_real, tol);
382 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
383 }
384 {
385 c10::complex<float> x(0.1, 1.2);
386 c10::complex<float> y = std::cos(x);
387 float expected_real = std::cos(x.real()) * std::cosh(x.imag());
388 float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
389 C10_ASSERT_NEAR(y.real(), expected_real, tol);
390 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
391 }
392 {
393 c10::complex<float> x(0.1, 1.2);
394 c10::complex<float> y = ::cos(x);
395 float expected_real = ::cos(x.real()) * ::cosh(x.imag());
396 float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
397 C10_ASSERT_NEAR(y.real(), expected_real, tol);
398 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
399 }
400 {
401 c10::complex<double> x(0.1, 1.2);
402 c10::complex<double> y = std::sin(x);
403 float expected_real = std::sin(x.real()) * std::cosh(x.imag());
404 float expected_imag = std::cos(x.real()) * std::sinh(x.imag());
405 C10_ASSERT_NEAR(y.real(), expected_real, tol);
406 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
407 }
408 {
409 c10::complex<double> x(0.1, 1.2);
410 c10::complex<double> y = ::sin(x);
411 float expected_real = ::sin(x.real()) * ::cosh(x.imag());
412 float expected_imag = ::cos(x.real()) * ::sinh(x.imag());
413 C10_ASSERT_NEAR(y.real(), expected_real, tol);
414 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
415 }
416 {
417 c10::complex<double> x(0.1, 1.2);
418 c10::complex<double> y = std::cos(x);
419 float expected_real = std::cos(x.real()) * std::cosh(x.imag());
420 float expected_imag = -std::sin(x.real()) * std::sinh(x.imag());
421 C10_ASSERT_NEAR(y.real(), expected_real, tol);
422 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
423 }
424 {
425 c10::complex<double> x(0.1, 1.2);
426 c10::complex<double> y = ::cos(x);
427 float expected_real = ::cos(x.real()) * ::cosh(x.imag());
428 float expected_imag = -::sin(x.real()) * ::sinh(x.imag());
429 C10_ASSERT_NEAR(y.real(), expected_real, tol);
430 C10_ASSERT_NEAR(y.imag(), expected_imag, tol);
431 }
432}
433
434C10_DEFINE_TEST(TestTan, Identity) {
435 // tan(x) = sin(x) / cos(x)
436 {
437 c10::complex<float> x(0.1, 1.2);
438 c10::complex<float> y = std::tan(x);
439 c10::complex<float> z = std::sin(x) / std::cos(x);
440 C10_ASSERT_NEAR(y.real(), z.real(), tol);
441 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
442 }
443 {
444 c10::complex<float> x(0.1, 1.2);
445 c10::complex<float> y = ::tan(x);
446 c10::complex<float> z = ::sin(x) / ::cos(x);
447 C10_ASSERT_NEAR(y.real(), z.real(), tol);
448 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
449 }
450 {
451 c10::complex<double> x(0.1, 1.2);
452 c10::complex<double> y = std::tan(x);
453 c10::complex<double> z = std::sin(x) / std::cos(x);
454 C10_ASSERT_NEAR(y.real(), z.real(), tol);
455 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
456 }
457 {
458 c10::complex<double> x(0.1, 1.2);
459 c10::complex<double> y = ::tan(x);
460 c10::complex<double> z = ::sin(x) / ::cos(x);
461 C10_ASSERT_NEAR(y.real(), z.real(), tol);
462 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
463 }
464}
465
466C10_DEFINE_TEST(TestTanh, Identity) {
467 // tanh(x) = sinh(x) / cosh(x)
468 {
469 c10::complex<float> x(0.1, 1.2);
470 c10::complex<float> y = std::tanh(x);
471 c10::complex<float> z = std::sinh(x) / std::cosh(x);
472 C10_ASSERT_NEAR(y.real(), z.real(), tol);
473 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
474 }
475 {
476 c10::complex<float> x(0.1, 1.2);
477 c10::complex<float> y = ::tanh(x);
478 c10::complex<float> z = ::sinh(x) / ::cosh(x);
479 C10_ASSERT_NEAR(y.real(), z.real(), tol);
480 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
481 }
482 {
483 c10::complex<double> x(0.1, 1.2);
484 c10::complex<double> y = std::tanh(x);
485 c10::complex<double> z = std::sinh(x) / std::cosh(x);
486 C10_ASSERT_NEAR(y.real(), z.real(), tol);
487 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
488 }
489 {
490 c10::complex<double> x(0.1, 1.2);
491 c10::complex<double> y = ::tanh(x);
492 c10::complex<double> z = ::sinh(x) / ::cosh(x);
493 C10_ASSERT_NEAR(y.real(), z.real(), tol);
494 C10_ASSERT_NEAR(y.imag(), z.imag(), tol);
495 }
496}
497
// Inverse trigonometric functions
499
500C10_DEFINE_TEST(TestRevTrigonometric, Rev) {
501 // asin(sin(x)) = x
502 // acos(cos(x)) = x
503 // atan(tan(x)) = x
504 {
505 c10::complex<float> x(0.5, 0.6);
506 c10::complex<float> s = std::sin(x);
507 c10::complex<float> ss = std::asin(s);
508 c10::complex<float> c = std::cos(x);
509 c10::complex<float> cc = std::acos(c);
510 c10::complex<float> t = std::tan(x);
511 c10::complex<float> tt = std::atan(t);
512 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
513 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
514 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
515 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
516 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
517 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
518 }
519 {
520 c10::complex<float> x(0.5, 0.6);
521 c10::complex<float> s = ::sin(x);
522 c10::complex<float> ss = ::asin(s);
523 c10::complex<float> c = ::cos(x);
524 c10::complex<float> cc = ::acos(c);
525 c10::complex<float> t = ::tan(x);
526 c10::complex<float> tt = ::atan(t);
527 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
528 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
529 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
530 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
531 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
532 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
533 }
534 {
535 c10::complex<double> x(0.5, 0.6);
536 c10::complex<double> s = std::sin(x);
537 c10::complex<double> ss = std::asin(s);
538 c10::complex<double> c = std::cos(x);
539 c10::complex<double> cc = std::acos(c);
540 c10::complex<double> t = std::tan(x);
541 c10::complex<double> tt = std::atan(t);
542 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
543 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
544 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
545 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
546 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
547 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
548 }
549 {
550 c10::complex<double> x(0.5, 0.6);
551 c10::complex<double> s = ::sin(x);
552 c10::complex<double> ss = ::asin(s);
553 c10::complex<double> c = ::cos(x);
554 c10::complex<double> cc = ::acos(c);
555 c10::complex<double> t = ::tan(x);
556 c10::complex<double> tt = ::atan(t);
557 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
558 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
559 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
560 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
561 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
562 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
563 }
564}
565
// Inverse hyperbolic functions
567
568C10_DEFINE_TEST(TestRevHyperbolic, Rev) {
569 // asinh(sinh(x)) = x
570 // acosh(cosh(x)) = x
571 // atanh(tanh(x)) = x
572 {
573 c10::complex<float> x(0.5, 0.6);
574 c10::complex<float> s = std::sinh(x);
575 c10::complex<float> ss = std::asinh(s);
576 c10::complex<float> c = std::cosh(x);
577 c10::complex<float> cc = std::acosh(c);
578 c10::complex<float> t = std::tanh(x);
579 c10::complex<float> tt = std::atanh(t);
580 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
581 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
582 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
583 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
584 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
585 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
586 }
587 {
588 c10::complex<float> x(0.5, 0.6);
589 c10::complex<float> s = ::sinh(x);
590 c10::complex<float> ss = ::asinh(s);
591 c10::complex<float> c = ::cosh(x);
592 c10::complex<float> cc = ::acosh(c);
593 c10::complex<float> t = ::tanh(x);
594 c10::complex<float> tt = ::atanh(t);
595 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
596 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
597 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
598 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
599 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
600 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
601 }
602 {
603 c10::complex<double> x(0.5, 0.6);
604 c10::complex<double> s = std::sinh(x);
605 c10::complex<double> ss = std::asinh(s);
606 c10::complex<double> c = std::cosh(x);
607 c10::complex<double> cc = std::acosh(c);
608 c10::complex<double> t = std::tanh(x);
609 c10::complex<double> tt = std::atanh(t);
610 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
611 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
612 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
613 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
614 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
615 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
616 }
617 {
618 c10::complex<double> x(0.5, 0.6);
619 c10::complex<double> s = ::sinh(x);
620 c10::complex<double> ss = ::asinh(s);
621 c10::complex<double> c = ::cosh(x);
622 c10::complex<double> cc = ::acosh(c);
623 c10::complex<double> t = ::tanh(x);
624 c10::complex<double> tt = ::atanh(t);
625 C10_ASSERT_NEAR(x.real(), ss.real(), tol);
626 C10_ASSERT_NEAR(x.imag(), ss.imag(), tol);
627 C10_ASSERT_NEAR(x.real(), cc.real(), tol);
628 C10_ASSERT_NEAR(x.imag(), cc.imag(), tol);
629 C10_ASSERT_NEAR(x.real(), tt.real(), tol);
630 C10_ASSERT_NEAR(x.imag(), tt.imag(), tol);
631 }
632}
633