1#pragma once
2
3// DO NOT DEFINE STATIC DATA IN THIS HEADER!
4// See Note [Do not compile initializers with AVX]
5
6#include <c10/util/complex.h>
7#include <c10/util/irange.h>
8#include <ATen/cpu/vec/intrinsics.h>
9#include <ATen/cpu/vec/vec_base.h>
10#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
11#include <sleef.h>
12#endif
13
14namespace at {
15namespace vec {
16// See Note [CPU_CAPABILITY namespace]
17inline namespace CPU_CAPABILITY {
18
19#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
20
21template <> class Vectorized<c10::complex<float>> {
22private:
23 __m256 values;
24public:
25 using value_type = c10::complex<float>;
26 using size_type = int;
27 static constexpr size_type size() {
28 return 4;
29 }
30 Vectorized() {}
31 Vectorized(__m256 v) : values(v) {}
32 Vectorized(c10::complex<float> val) {
33 float real_value = val.real();
34 float imag_value = val.imag();
35 values = _mm256_setr_ps(real_value, imag_value,
36 real_value, imag_value,
37 real_value, imag_value,
38 real_value, imag_value
39 );
40 }
41 Vectorized(c10::complex<float> val1, c10::complex<float> val2, c10::complex<float> val3, c10::complex<float> val4) {
42 values = _mm256_setr_ps(val1.real(), val1.imag(),
43 val2.real(), val2.imag(),
44 val3.real(), val3.imag(),
45 val4.real(), val4.imag()
46 );
47 }
48 operator __m256() const {
49 return values;
50 }
51 template <int64_t mask>
52 static Vectorized<c10::complex<float>> blend(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
53 // convert c10::complex<V> index mask to V index mask: xy -> xxyy
54 static_assert(mask > -1 && mask < 16, "Unexpected mask range");
55 switch (mask) {
56 case 0:
57 return a;
58 case 1:
59 return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011
60 case 2:
61 return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100
62 case 3:
63 return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111
64 case 4:
65 return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000
66 case 5:
67 return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011
68 case 6:
69 return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100
70 case 7:
71 return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111
72 case 8:
73 return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000
74 case 9:
75 return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011
76 case 10:
77 return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100
78 case 11:
79 return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111
80 case 12:
81 return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000
82 case 13:
83 return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011
84 case 14:
85 return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100
86 default: break;
87 }
88 return b;
89 }
90 static Vectorized<c10::complex<float>> blendv(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
91 const Vectorized<c10::complex<float>>& mask) {
92 // convert c10::complex<V> index mask to V index mask: xy -> xxyy
93 auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values);
94 return _mm256_blendv_ps(a.values, b.values, mask_);
95
96 }
97 template<typename step_t>
98 static Vectorized<c10::complex<float>> arange(c10::complex<float> base = 0., step_t step = static_cast<step_t>(1)) {
99 return Vectorized<c10::complex<float>>(base,
100 base + step,
101 base + c10::complex<float>(2)*step,
102 base + c10::complex<float>(3)*step);
103 }
104 static Vectorized<c10::complex<float>> set(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
105 int64_t count = size()) {
106 switch (count) {
107 case 0:
108 return a;
109 case 1:
110 return blend<1>(a, b);
111 case 2:
112 return blend<3>(a, b);
113 case 3:
114 return blend<7>(a, b);
115 }
116 return b;
117 }
118 static Vectorized<c10::complex<float>> loadu(const void* ptr, int64_t count = size()) {
119 if (count == size())
120 return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
121
122 __at_align__ float tmp_values[2*size()];
123 // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
124 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
125 // instructions while a loop would be compiled to one instruction.
126 for (const auto i : c10::irange(2*size())) {
127 tmp_values[i] = 0.0;
128 }
129 std::memcpy(
130 tmp_values,
131 reinterpret_cast<const float*>(ptr),
132 count * sizeof(c10::complex<float>));
133 return _mm256_load_ps(tmp_values);
134 }
135 void store(void* ptr, int count = size()) const {
136 if (count == size()) {
137 _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
138 } else if (count > 0) {
139 float tmp_values[2*size()];
140 _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
141 std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<float>));
142 }
143 }
144 const c10::complex<float>& operator[](int idx) const = delete;
145 c10::complex<float>& operator[](int idx) = delete;
146 Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
147 __at_align__ c10::complex<float> tmp[size()];
148 store(tmp);
149 for (const auto i : c10::irange(size())) {
150 tmp[i] = f(tmp[i]);
151 }
152 return loadu(tmp);
153 }
154 __m256 abs_2_() const {
155 auto val_2 = _mm256_mul_ps(values, values); // a*a b*b
156 auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b
157 return _mm256_permute_ps(ret, 0xD8);
158 }
159 __m256 abs_() const {
160 return _mm256_sqrt_ps(abs_2_()); // abs abs
161 }
162 Vectorized<c10::complex<float>> abs() const {
163 const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
164 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
165 return _mm256_and_ps(abs_(), real_mask); // abs 0
166 }
167 __m256 angle_() const {
168 //angle = atan2(b/a)
169 auto b_a = _mm256_permute_ps(values, 0xB1); // b a
170 return Sleef_atan2f8_u10(values, b_a); // 90-angle angle
171 }
172 Vectorized<c10::complex<float>> angle() const {
173 const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
174 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
175 auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle
176 return _mm256_and_ps(angle, real_mask); // angle 0
177 }
178 Vectorized<c10::complex<float>> sgn() const {
179 auto abs = abs_();
180 auto zero = _mm256_setzero_ps();
181 auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
182 auto abs_val = Vectorized(abs);
183
184 auto div = values / abs_val.values; // x / abs(x)
185
186 return _mm256_blendv_ps(div, zero, mask);
187 }
188 __m256 real_() const {
189 const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
190 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
191 return _mm256_and_ps(values, real_mask);
192 }
193 Vectorized<c10::complex<float>> real() const {
194 return real_();
195 }
196 __m256 imag_() const {
197 const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
198 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF));
199 return _mm256_and_ps(values, imag_mask);
200 }
201 Vectorized<c10::complex<float>> imag() const {
202 return _mm256_permute_ps(imag_(), 0xB1); //b a
203 }
204 __m256 conj_() const {
205 const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
206 return _mm256_xor_ps(values, sign_mask); // a -b
207 }
208 Vectorized<c10::complex<float>> conj() const {
209 return conj_();
210 }
211 Vectorized<c10::complex<float>> log() const {
212 // Most trigonomic ops use the log() op to improve complex number performance.
213 return map(std::log);
214 }
215 Vectorized<c10::complex<float>> log2() const {
216 const __m256 log2_ = _mm256_set1_ps(std::log(2));
217 return _mm256_div_ps(log(), log2_);
218 }
219 Vectorized<c10::complex<float>> log10() const {
220 const __m256 log10_ = _mm256_set1_ps(std::log(10));
221 return _mm256_div_ps(log(), log10_);
222 }
223 Vectorized<c10::complex<float>> log1p() const {
224 return map(std::log1p);
225 }
226 Vectorized<c10::complex<float>> asin() const {
227 // asin(x)
228 // = -i*ln(iz + sqrt(1 -z^2))
229 // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
230 // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
231 const __m256 one = _mm256_set1_ps(1);
232
233 auto conj = conj_();
234 auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a
235 auto ab = _mm256_mul_ps(conj, b_a); //-ab -ab
236 auto im = _mm256_add_ps(ab, ab); //-2ab -2ab
237
238 auto val_2 = _mm256_mul_ps(values, values); // a*a b*b
239 auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // a*a-b*b b*b-a*a
240 re = _mm256_permute_ps(re, 0xD8);
241 re = _mm256_sub_ps(one, re);
242
243 auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + i*im)
244 auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + sqrt())
245 return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); //-i*ln()
246 }
247 Vectorized<c10::complex<float>> acos() const {
248 return map(std::acos);
249 }
250 Vectorized<c10::complex<float>> atan() const;
251 Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>>& /*b*/) const {
252 AT_ERROR("not supported for complex numbers");
253 }
254 Vectorized<c10::complex<float>> erf() const {
255 AT_ERROR("not supported for complex numbers");
256 }
257 Vectorized<c10::complex<float>> erfc() const {
258 AT_ERROR("not supported for complex numbers");
259 }
260 Vectorized<c10::complex<float>> exp() const {
261 //exp(a + bi)
262 // = exp(a)*(cos(b) + sin(b)i)
263 auto exp = Sleef_expf8_u10(values); //exp(a) exp(b)
264 exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) exp(a)
265
266 auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)]
267 auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1),
268 sin_cos.x, 0xAA); //cos(b) sin(b)
269 return _mm256_mul_ps(exp, cos_sin);
270 }
271 Vectorized<c10::complex<float>> exp2() const {
272 // Use identity 2**x = exp(log(2) * x)
273 const __m256 ln_2 = _mm256_set1_ps(c10::ln_2<float>);
274 Vectorized<c10::complex<float>> scaled_values = _mm256_mul_ps(values, ln_2);
275 return scaled_values.exp();
276 }
277 Vectorized<c10::complex<float>> expm1() const {
278 AT_ERROR("not supported for complex numbers");
279 }
280 Vectorized<c10::complex<float>> sin() const {
281 return map(std::sin);
282 }
283 Vectorized<c10::complex<float>> sinh() const {
284 return map(std::sinh);
285 }
286 Vectorized<c10::complex<float>> cos() const {
287 return map(std::cos);
288 }
289 Vectorized<c10::complex<float>> cosh() const {
290 return map(std::cosh);
291 }
292 Vectorized<c10::complex<float>> ceil() const {
293 return _mm256_ceil_ps(values);
294 }
295 Vectorized<c10::complex<float>> floor() const {
296 return _mm256_floor_ps(values);
297 }
298 Vectorized<c10::complex<float>> hypot(const Vectorized<c10::complex<float>>& /*b*/) const {
299 AT_ERROR("not supported for complex numbers");
300 }
301 Vectorized<c10::complex<float>> igamma(const Vectorized<c10::complex<float>>& /*x*/) const {
302 AT_ERROR("not supported for complex numbers");
303 }
304 Vectorized<c10::complex<float>> igammac(const Vectorized<c10::complex<float>>& /*x*/) const {
305 AT_ERROR("not supported for complex numbers");
306 }
307 Vectorized<c10::complex<float>> neg() const {
308 auto zero = _mm256_setzero_ps();
309 return _mm256_sub_ps(zero, values);
310 }
311 Vectorized<c10::complex<float>> nextafter(const Vectorized<c10::complex<float>>& /*b*/) const {
312 AT_ERROR("not supported for complex numbers");
313 }
314 Vectorized<c10::complex<float>> round() const {
315 return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
316 }
317 Vectorized<c10::complex<float>> tan() const {
318 return map(std::tan);
319 }
320 Vectorized<c10::complex<float>> tanh() const {
321 return map(std::tanh);
322 }
323 Vectorized<c10::complex<float>> trunc() const {
324 return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
325 }
326 Vectorized<c10::complex<float>> sqrt() const {
327 return map(std::sqrt);
328 }
329 Vectorized<c10::complex<float>> reciprocal() const;
330 Vectorized<c10::complex<float>> rsqrt() const {
331 return sqrt().reciprocal();
332 }
333 Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
334 __at_align__ c10::complex<float> x_tmp[size()];
335 __at_align__ c10::complex<float> y_tmp[size()];
336 store(x_tmp);
337 exp.store(y_tmp);
338 for (const auto i : c10::irange(size())) {
339 x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
340 }
341 return loadu(x_tmp);
342 }
343 // Comparison using the _CMP_**_OQ predicate.
344 // `O`: get false if an operand is NaN
345 // `Q`: do not raise if an operand is NaN
346 Vectorized<c10::complex<float>> operator==(const Vectorized<c10::complex<float>>& other) const {
347 return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
348 }
349 Vectorized<c10::complex<float>> operator!=(const Vectorized<c10::complex<float>>& other) const {
350 return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
351 }
352 Vectorized<c10::complex<float>> operator<(const Vectorized<c10::complex<float>>& /*other*/) const {
353 TORCH_CHECK(false, "not supported for complex numbers");
354 }
355 Vectorized<c10::complex<float>> operator<=(const Vectorized<c10::complex<float>>& /*other*/) const {
356 TORCH_CHECK(false, "not supported for complex numbers");
357 }
358 Vectorized<c10::complex<float>> operator>(const Vectorized<c10::complex<float>>& /*other*/) const {
359 TORCH_CHECK(false, "not supported for complex numbers");
360 }
361 Vectorized<c10::complex<float>> operator>=(const Vectorized<c10::complex<float>>& /*other*/) const {
362 TORCH_CHECK(false, "not supported for complex numbers");
363 }
364
365 Vectorized<c10::complex<float>> eq(const Vectorized<c10::complex<float>>& other) const;
366 Vectorized<c10::complex<float>> ne(const Vectorized<c10::complex<float>>& other) const;
367 Vectorized<c10::complex<float>> lt(const Vectorized<c10::complex<float>>& /*other*/) const {
368 TORCH_CHECK(false, "not supported for complex numbers");
369 }
370 Vectorized<c10::complex<float>> le(const Vectorized<c10::complex<float>>& /*other*/) const {
371 TORCH_CHECK(false, "not supported for complex numbers");
372 }
373 Vectorized<c10::complex<float>> gt(const Vectorized<c10::complex<float>>& /*other*/) const {
374 TORCH_CHECK(false, "not supported for complex numbers");
375 }
376 Vectorized<c10::complex<float>> ge(const Vectorized<c10::complex<float>>& /*other*/) const {
377 TORCH_CHECK(false, "not supported for complex numbers");
378 }
379};
380
381template <> Vectorized<c10::complex<float>> inline operator+(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
382 return _mm256_add_ps(a, b);
383}
384
385template <> Vectorized<c10::complex<float>> inline operator-(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
386 return _mm256_sub_ps(a, b);
387}
388
389template <> Vectorized<c10::complex<float>> inline operator*(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
390 //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
391 const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
392 auto ac_bd = _mm256_mul_ps(a, b); //ac bd
393
394 auto d_c = _mm256_permute_ps(b, 0xB1); //d c
395 d_c = _mm256_xor_ps(sign_mask, d_c); //d -c
396 auto ad_bc = _mm256_mul_ps(a, d_c); //ad -bc
397
398 auto ret = _mm256_hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc
399 ret = _mm256_permute_ps(ret, 0xD8);
400 return ret;
401}
402
403template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) {
404 //re + im*i = (a + bi) / (c + di)
405 auto mask = _mm256_set1_ps(-0.f);
406 auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d|
407 auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c|
408 auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc
409 auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc
410 auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc
411 auto acbd2 = _mm256_mul_ps(a2, b2);
412
413 const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
414 auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc
415 dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc
416 auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2
417 auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2
418 res2 = _mm256_permute_ps(res2, 0xD8);
419
420 // get the denominator
421 auto denom2 = Vectorized<c10::complex<float>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
422 res2 = _mm256_div_ps(res2, denom2);
423 return res2;
424}
425
426// reciprocal. Implement this here so we can use multiplication.
427inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
428 //re + im*i = (a + bi) / (c + di)
429 //re = (ac + bd)/abs_2() = c/abs_2()
430 //im = (bc - ad)/abs_2() = d/abs_2()
431 const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
432 auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
433 return _mm256_div_ps(c_d, abs_2_());
434}
435
436inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
437 // atan(x) = i/2 * ln((i + z)/(i - z))
438 const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
439 const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5);
440
441 auto sum = Vectorized(_mm256_add_ps(i, values)); // a 1+b
442 auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b
443 auto ln = (sum/sub).log(); // ln((i + z)/(i - z))
444 return i_half*ln; // i/2*ln()
445}
446
447template <>
448Vectorized<c10::complex<float>> inline maximum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
449 auto abs_a = a.abs_2_();
450 auto abs_b = b.abs_2_();
451 auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
452 auto max = _mm256_blendv_ps(a, b, mask);
453 // Exploit the fact that all-ones is a NaN.
454 auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
455 return _mm256_or_ps(max, isnan);
456}
457
458template <>
459Vectorized<c10::complex<float>> inline minimum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
460 auto abs_a = a.abs_2_();
461 auto abs_b = b.abs_2_();
462 auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
463 auto min = _mm256_blendv_ps(a, b, mask);
464 // Exploit the fact that all-ones is a NaN.
465 auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
466 return _mm256_or_ps(min, isnan);
467}
468
469template <>
470Vectorized<c10::complex<float>> inline operator&(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
471 return _mm256_and_ps(a, b);
472}
473
474template <>
475Vectorized<c10::complex<float>> inline operator|(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
476 return _mm256_or_ps(a, b);
477}
478
479template <>
480Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
481 return _mm256_xor_ps(a, b);
482}
483
484inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq(
485 const Vectorized<c10::complex<float>>& other) const {
486 return (*this == other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
487}
488
489inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne(
490 const Vectorized<c10::complex<float>>& other) const {
491 return (*this != other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f));
492}
493
494#endif
495
496}}}
497