1#pragma once
2
3// DO NOT DEFINE STATIC DATA IN THIS HEADER!
4// See Note [Do not compile initializers with AVX]
5
6#include <c10/util/complex.h>
7#include <c10/util/irange.h>
8#include <ATen/cpu/vec/intrinsics.h>
9#include <ATen/cpu/vec/vec_base.h>
10
11#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
12#include <sleef.h>
13#endif
14
15namespace at {
16namespace vec {
17// See Note [CPU_CAPABILITY namespace]
18inline namespace CPU_CAPABILITY {
19
20#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
21
22template <> class Vectorized<c10::complex<double>> {
23private:
24 __m256d values;
25public:
26 using value_type = c10::complex<double>;
27 using size_type = int;
28 static constexpr size_type size() {
29 return 2;
30 }
31 Vectorized() {}
32 Vectorized(__m256d v) : values(v) {}
33 Vectorized(c10::complex<double> val) {
34 double real_value = val.real();
35 double imag_value = val.imag();
36 values = _mm256_setr_pd(real_value, imag_value,
37 real_value, imag_value);
38 }
39 Vectorized(c10::complex<double> val1, c10::complex<double> val2) {
40 values = _mm256_setr_pd(val1.real(), val1.imag(),
41 val2.real(), val2.imag());
42 }
43 operator __m256d() const {
44 return values;
45 }
46 template <int64_t mask>
47 static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
48 // convert c10::complex<V> index mask to V index mask: xy -> xxyy
49 static_assert (mask > -1 && mask < 4, "Unexpected mask value");
50 switch (mask) {
51 case 0:
52 return a;
53 case 1:
54 return _mm256_blend_pd(a.values, b.values, 0x03);
55 case 2:
56 return _mm256_blend_pd(a.values, b.values, 0x0c);
57 case 3: break;
58 }
59 return b;
60 }
61 static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
62 const Vectorized<c10::complex<double>>& mask) {
63 // convert c10::complex<V> index mask to V index mask: xy -> xxyy
64 auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values);
65 return _mm256_blendv_pd(a.values, b.values, mask_);
66
67 }
68 template<typename step_t>
69 static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0., step_t step = static_cast<step_t>(1)) {
70 return Vectorized<c10::complex<double>>(base,
71 base + step);
72 }
73 static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
74 int64_t count = size()) {
75 switch (count) {
76 case 0:
77 return a;
78 case 1:
79 return blend<1>(a, b);
80 }
81 return b;
82 }
83 static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
84 if (count == size())
85 return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
86
87 __at_align__ double tmp_values[2*size()];
88 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
89 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
90 // instructions while a loop would be compiled to one instruction.
91 for (const auto i : c10::irange(2*size())) {
92 tmp_values[i] = 0.0;
93 }
94 std::memcpy(
95 tmp_values,
96 reinterpret_cast<const double*>(ptr),
97 count * sizeof(c10::complex<double>));
98 return _mm256_load_pd(tmp_values);
99 }
100 void store(void* ptr, int count = size()) const {
101 if (count == size()) {
102 _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
103 } else if (count > 0) {
104 double tmp_values[2*size()];
105 _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
106 std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
107 }
108 }
109 const c10::complex<double>& operator[](int idx) const = delete;
110 c10::complex<double>& operator[](int idx) = delete;
111 Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
112 __at_align__ c10::complex<double> tmp[size()];
113 store(tmp);
114 for (const auto i : c10::irange(size())) {
115 tmp[i] = f(tmp[i]);
116 }
117 return loadu(tmp);
118 }
119 __m256d abs_2_() const {
120 auto val_2 = _mm256_mul_pd(values, values); // a*a b*b
121 return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b
122 }
123 __m256d abs_() const {
124 return _mm256_sqrt_pd(abs_2_()); // abs abs
125 }
126 Vectorized<c10::complex<double>> abs() const {
127 const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
128 0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
129 return _mm256_and_pd(abs_(), real_mask); // abs 0
130 }
131 __m256d angle_() const {
132 //angle = atan2(b/a)
133 auto b_a = _mm256_permute_pd(values, 0x05); // b a
134 return Sleef_atan2d4_u10(values, b_a); // 90-angle angle
135 }
136 Vectorized<c10::complex<double>> angle() const {
137 const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
138 0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
139 auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle
140 return _mm256_and_pd(angle, real_mask); // angle 0
141 }
142 Vectorized<c10::complex<double>> sgn() const {
143 auto abs = abs_();
144 auto zero = _mm256_setzero_pd();
145 auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
146 auto abs_val = Vectorized(abs);
147
148 auto div = values / abs_val.values; // x / abs(x)
149
150 return blendv(div, zero, mask);
151 }
152 __m256d real_() const {
153 const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
154 0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
155 return _mm256_and_pd(values, real_mask);
156 }
157 Vectorized<c10::complex<double>> real() const {
158 return real_();
159 }
160 __m256d imag_() const {
161 const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
162 0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
163 return _mm256_and_pd(values, imag_mask);
164 }
165 Vectorized<c10::complex<double>> imag() const {
166 return _mm256_permute_pd(imag_(), 0x05); //b a
167 }
168 __m256d conj_() const {
169 const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
170 return _mm256_xor_pd(values, sign_mask); // a -b
171 }
172 Vectorized<c10::complex<double>> conj() const {
173 return conj_();
174 }
175 Vectorized<c10::complex<double>> log() const {
176 // Most trigonomic ops use the log() op to improve complex number performance.
177 return map(std::log);
178 }
179 Vectorized<c10::complex<double>> log2() const {
180 const __m256d log2_ = _mm256_set1_pd(std::log(2));
181 return _mm256_div_pd(log(), log2_);
182 }
183 Vectorized<c10::complex<double>> log10() const {
184 const __m256d log10_ = _mm256_set1_pd(std::log(10));
185 return _mm256_div_pd(log(), log10_);
186 }
187 Vectorized<c10::complex<double>> log1p() const {
188 return map(std::log1p);
189 }
190 Vectorized<c10::complex<double>> asin() const {
191 // asin(x)
192 // = -i*ln(iz + sqrt(1 -z^2))
193 // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
194 // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
195 const __m256d one = _mm256_set1_pd(1);
196
197 auto conj = conj_();
198 auto b_a = _mm256_permute_pd(conj, 0x05); //-b a
199 auto ab = _mm256_mul_pd(conj, b_a); //-ab -ab
200 auto im = _mm256_add_pd(ab, ab); //-2ab -2ab
201
202 auto val_2 = _mm256_mul_pd(values, values); // a*a b*b
203 auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // a*a-b*b b*b-a*a
204 re = _mm256_sub_pd(one, re);
205
206 auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + i*im)
207 auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + sqrt())
208 return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); //-i*ln()
209 }
210 Vectorized<c10::complex<double>> acos() const {
211 // acos(x) = pi/2 - asin(x)
212 constexpr auto pi_2d = c10::pi<double> / 2;
213 const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0);
214 return _mm256_sub_pd(pi_2, asin());
215 }
216 Vectorized<c10::complex<double>> atan() const;
217 Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>>&) const {
218 AT_ERROR("not supported for complex numbers");
219 }
220 Vectorized<c10::complex<double>> erf() const {
221 AT_ERROR("not supported for complex numbers");
222 }
223 Vectorized<c10::complex<double>> erfc() const {
224 AT_ERROR("not supported for complex numbers");
225 }
226 Vectorized<c10::complex<double>> exp() const {
227 //exp(a + bi)
228 // = exp(a)*(cos(b) + sin(b)i)
229 auto exp = Sleef_expd4_u10(values); //exp(a) exp(b)
230 exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) exp(a)
231
232 auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
233 auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05),
234 sin_cos.x, 0x0A); //cos(b) sin(b)
235 return _mm256_mul_pd(exp, cos_sin);
236 }
237 Vectorized<c10::complex<double>> exp2() const {
238 // Use identity 2**x = exp(log(2) * x)
239 const __m256d ln_2 = _mm256_set1_pd(c10::ln_2<double>);
240 Vectorized<c10::complex<double>> scaled_values = _mm256_mul_pd(values, ln_2);
241 return scaled_values.exp();
242 }
243 Vectorized<c10::complex<double>> expm1() const {
244 AT_ERROR("not supported for complex numbers");
245 }
246 Vectorized<c10::complex<double>> sin() const {
247 return map(std::sin);
248 }
249 Vectorized<c10::complex<double>> sinh() const {
250 return map(std::sinh);
251 }
252 Vectorized<c10::complex<double>> cos() const {
253 return map(std::cos);
254 }
255 Vectorized<c10::complex<double>> cosh() const {
256 return map(std::cosh);
257 }
258 Vectorized<c10::complex<double>> ceil() const {
259 return _mm256_ceil_pd(values);
260 }
261 Vectorized<c10::complex<double>> floor() const {
262 return _mm256_floor_pd(values);
263 }
264 Vectorized<c10::complex<double>> hypot(const Vectorized<c10::complex<double>> &) const {
265 AT_ERROR("not supported for complex numbers");
266 }
267 Vectorized<c10::complex<double>> igamma(const Vectorized<c10::complex<double>> &) const {
268 AT_ERROR("not supported for complex numbers");
269 }
270 Vectorized<c10::complex<double>> igammac(const Vectorized<c10::complex<double>> &) const {
271 AT_ERROR("not supported for complex numbers");
272 }
273 Vectorized<c10::complex<double>> neg() const {
274 auto zero = _mm256_setzero_pd();
275 return _mm256_sub_pd(zero, values);
276 }
277 Vectorized<c10::complex<double>> nextafter(const Vectorized<c10::complex<double>> &) const {
278 AT_ERROR("not supported for complex numbers");
279 }
280 Vectorized<c10::complex<double>> round() const {
281 return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
282 }
283 Vectorized<c10::complex<double>> tan() const {
284 return map(std::tan);
285 }
286 Vectorized<c10::complex<double>> tanh() const {
287 return map(std::tanh);
288 }
289 Vectorized<c10::complex<double>> trunc() const {
290 return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
291 }
292 Vectorized<c10::complex<double>> sqrt() const {
293 return map(std::sqrt);
294 }
295 Vectorized<c10::complex<double>> reciprocal() const;
296 Vectorized<c10::complex<double>> rsqrt() const {
297 return sqrt().reciprocal();
298 }
299 Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
300 __at_align__ c10::complex<double> x_tmp[size()];
301 __at_align__ c10::complex<double> y_tmp[size()];
302 store(x_tmp);
303 exp.store(y_tmp);
304 for (const auto i : c10::irange(size())) {
305 x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
306 }
307 return loadu(x_tmp);
308 }
309 // Comparison using the _CMP_**_OQ predicate.
310 // `O`: get false if an operand is NaN
311 // `Q`: do not raise if an operand is NaN
312 Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
313 return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
314 }
315 Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
316 return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
317 }
318 Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>&) const {
319 TORCH_CHECK(false, "not supported for complex numbers");
320 }
321 Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>&) const {
322 TORCH_CHECK(false, "not supported for complex numbers");
323 }
324 Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>&) const {
325 TORCH_CHECK(false, "not supported for complex numbers");
326 }
327 Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>&) const {
328 TORCH_CHECK(false, "not supported for complex numbers");
329 }
330
331 Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
332 Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
333 Vectorized<c10::complex<double>> lt(const Vectorized<c10::complex<double>>&) const {
334 TORCH_CHECK(false, "not supported for complex numbers");
335 }
336 Vectorized<c10::complex<double>> le(const Vectorized<c10::complex<double>>&) const {
337 TORCH_CHECK(false, "not supported for complex numbers");
338 }
339 Vectorized<c10::complex<double>> gt(const Vectorized<c10::complex<double>>&) const {
340 TORCH_CHECK(false, "not supported for complex numbers");
341 }
342 Vectorized<c10::complex<double>> ge(const Vectorized<c10::complex<double>>&) const {
343 TORCH_CHECK(false, "not supported for complex numbers");
344 }
345};
346
347template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
348 return _mm256_add_pd(a, b);
349}
350
351template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
352 return _mm256_sub_pd(a, b);
353}
354
355template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
356 //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
357 const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
358 auto ac_bd = _mm256_mul_pd(a, b); //ac bd
359
360 auto d_c = _mm256_permute_pd(b, 0x05); //d c
361 d_c = _mm256_xor_pd(sign_mask, d_c); //d -c
362 auto ad_bc = _mm256_mul_pd(a, d_c); //ad -bc
363
364 auto ret = _mm256_hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc
365 return ret;
366}
367
368template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
369 //re + im*i = (a + bi) / (c + di)
370 auto mask = _mm256_set1_pd(-0.f);
371 auto fabs_cd = _mm256_andnot_pd(mask, b); // |c| |d|
372 auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05); // |d| |c|
373 auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, fabs_dc)); // 1/sc 1/sc
374 auto a2 = _mm256_mul_pd(a, scale); // a/sc b/sc
375 auto b2 = _mm256_mul_pd(b, scale); // c/sc d/sc
376 auto acbd2 = _mm256_mul_pd(a2, b2);
377
378 const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0);
379 auto dc2 = _mm256_permute_pd(b2, 0x05); // d/sc c/sc
380 dc2 = _mm256_xor_pd(sign_mask, dc2); // -d/|c,d| c/sc
381 auto adbc2 = _mm256_mul_pd(a2, dc2); //-ad/sc^2 bc/sc^2
382 auto res2 = _mm256_hadd_pd(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
383
384 // get the denominator
385 auto denom2 = Vectorized<c10::complex<double>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
386 res2 = _mm256_div_pd(res2, denom2);
387 return res2;
388}
389
390// reciprocal. Implement this here so we can use multiplication.
391inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
392 //re + im*i = (a + bi) / (c + di)
393 //re = (ac + bd)/abs_2() = c/abs_2()
394 //im = (bc - ad)/abs_2() = d/abs_2()
395 const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
396 auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
397 return _mm256_div_pd(c_d, abs_2_());
398}
399
400inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
401 // atan(x) = i/2 * ln((i + z)/(i - z))
402 const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0);
403 const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5);
404
405 auto sum = Vectorized(_mm256_add_pd(i, values)); // a 1+b
406 auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b
407 auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
408 return i_half*ln; // i/2*ln()
409}
410
411template <>
412Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
413 auto abs_a = a.abs_2_();
414 auto abs_b = b.abs_2_();
415 auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ);
416 auto max = _mm256_blendv_pd(a, b, mask);
417 // Exploit the fact that all-ones is a NaN.
418 auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
419 return _mm256_or_pd(max, isnan);
420}
421
422template <>
423Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
424 auto abs_a = a.abs_2_();
425 auto abs_b = b.abs_2_();
426 auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ);
427 auto min = _mm256_blendv_pd(a, b, mask);
428 // Exploit the fact that all-ones is a NaN.
429 auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q);
430 return _mm256_or_pd(min, isnan);
431}
432
433template <>
434Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
435 return _mm256_and_pd(a, b);
436}
437
438template <>
439Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
440 return _mm256_or_pd(a, b);
441}
442
443template <>
444Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
445 return _mm256_xor_pd(a, b);
446}
447
448inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const {
449 return (*this == other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
450}
451
452inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const {
453 return (*this != other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0));
454}
455
456#endif
457
458}}}
459