1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | |
6 | #include <c10/util/complex.h> |
7 | #include <c10/util/irange.h> |
8 | #include <ATen/cpu/vec/intrinsics.h> |
9 | #include <ATen/cpu/vec/vec_base.h> |
10 | |
11 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
12 | #include <sleef.h> |
13 | #endif |
14 | |
15 | namespace at { |
16 | namespace vec { |
17 | // See Note [CPU_CAPABILITY namespace] |
18 | inline namespace CPU_CAPABILITY { |
19 | |
20 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
21 | |
// Vectorized specialization packing 2 c10::complex<double> values into one
// AVX2 register, laid out as [re0, im0, re1, im1] (one complex number per
// 128-bit lane). In the per-line comments below, "a b" denotes the
// (real, imag) doubles of one complex element.
template <> class Vectorized<c10::complex<double>> {
private:
  __m256d values;
public:
  using value_type = c10::complex<double>;
  using size_type = int;
  // Number of complex<double> elements per vector (4 doubles / 2 per complex).
  static constexpr size_type size() {
    return 2;
  }
  Vectorized() {}
  Vectorized(__m256d v) : values(v) {}
  // Broadcast one complex value into both element slots.
  Vectorized(c10::complex<double> val) {
    double real_value = val.real();
    double imag_value = val.imag();
    values = _mm256_setr_pd(real_value, imag_value,
                            real_value, imag_value);
  }
  Vectorized(c10::complex<double> val1, c10::complex<double> val2) {
    values = _mm256_setr_pd(val1.real(), val1.imag(),
                            val2.real(), val2.imag());
  }
  operator __m256d() const {
    return values;
  }
  // Compile-time blend: bit i of `mask` selects complex element i from b
  // (clear bit selects from a).
  template <int64_t mask>
  static Vectorized<c10::complex<double>> blend(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) {
    // convert c10::complex<V> index mask to V index mask: xy -> xxyy
    static_assert (mask > -1 && mask < 4, "Unexpected mask value" );
    switch (mask) {
      case 0:
        return a;
      case 1:
        return _mm256_blend_pd(a.values, b.values, 0x03);  // take b's element 0
      case 2:
        return _mm256_blend_pd(a.values, b.values, 0x0c);  // take b's element 1
      case 3: break;
    }
    return b;
  }
  // Runtime blend: the sign bit of each mask element selects from b. Only the
  // real-lane bit of each complex element is consulted (it is duplicated into
  // the imag lane below so whole complex elements are selected together).
  static Vectorized<c10::complex<double>> blendv(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
                                                 const Vectorized<c10::complex<double>>& mask) {
    // convert c10::complex<V> index mask to V index mask: xy -> xxyy
    auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values);
    return _mm256_blendv_pd(a.values, b.values, mask_);

  }
  // Returns {base, base + step}.
  template<typename step_t>
  static Vectorized<c10::complex<double>> arange(c10::complex<double> base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<c10::complex<double>>(base,
                                            base + step);
  }
  // Returns a vector whose first `count` elements come from b and the
  // remainder from a.
  static Vectorized<c10::complex<double>> set(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b,
                                              int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
    }
    return b;
  }
  // Load `count` complex values from (possibly unaligned) memory; elements
  // beyond `count` are zero-filled.
  static Vectorized<c10::complex<double>> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align__ double tmp_values[2*size()];
    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
    // instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(2*size())) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(c10::complex<double>));
    return _mm256_load_pd(tmp_values);
  }
  // Store the first `count` complex values to (possibly unaligned) memory.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[2*size()];
      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<double>));
    }
  }
  // Element access is deliberately disabled; go through store()/loadu().
  const c10::complex<double>& operator[](int idx) const = delete;
  c10::complex<double>& operator[](int idx) = delete;
  // Apply a scalar complex function element-wise (scalar fallback path used
  // by ops with no vectorized implementation).
  Vectorized<c10::complex<double>> map(c10::complex<double> (*const f)(const c10::complex<double> &)) const {
    __at_align__ c10::complex<double> tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  // Squared magnitude |z|^2, replicated into both lanes of each element.
  __m256d abs_2_() const {
    auto val_2 = _mm256_mul_pd(values, values);     // a*a     b*b
    return _mm256_hadd_pd(val_2, val_2);            // a*a+b*b a*a+b*b
  }
  // Magnitude |z|, replicated into both lanes of each element.
  __m256d abs_() const {
    return _mm256_sqrt_pd(abs_2_());                // abs     abs
  }
  // |z| in the real lane, 0 in the imag lane.
  Vectorized<c10::complex<double>> abs() const {
    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
    return _mm256_and_pd(abs_(), real_mask);        // abs     0
  }
  __m256d angle_() const {
    // angle = atan2(b, a); computed in the imag lane, with the real lane
    // holding the complementary atan2(a, b).
    auto b_a = _mm256_permute_pd(values, 0x05);     // b        a
    return Sleef_atan2d4_u10(values, b_a);          // 90-angle angle
  }
  // arg(z) in the real lane, 0 in the imag lane.
  Vectorized<c10::complex<double>> angle() const {
    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
    auto angle = _mm256_permute_pd(angle_(), 0x05); // angle    90-angle
    return _mm256_and_pd(angle, real_mask);         // angle    0
  }
  // z / |z|, with sgn(0) defined as 0 (the blendv forces zero where |z| == 0,
  // so the NaN produced by the 0/0 division never escapes).
  Vectorized<c10::complex<double>> sgn() const {
    auto abs = abs_();
    auto zero = _mm256_setzero_pd();
    auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
    auto abs_val = Vectorized(abs);

    auto div = values / abs_val.values;  // x / abs(x)

    return blendv(div, zero, mask);
  }
  // Zero out the imag lane, keeping the real lane bits.
  __m256d real_() const {
    const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000,
                                                                     0xFFFFFFFFFFFFFFFF, 0x0000000000000000));
    return _mm256_and_pd(values, real_mask);
  }
  Vectorized<c10::complex<double>> real() const {
    return real_();
  }
  // Zero out the real lane, keeping the imag lane bits (imag stays in place).
  __m256d imag_() const {
    const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF,
                                                                     0x0000000000000000, 0xFFFFFFFFFFFFFFFF));
    return _mm256_and_pd(values, imag_mask);
  }
  // Imag part moved into the real lane (imag lane becomes 0).
  Vectorized<c10::complex<double>> imag() const {
    return _mm256_permute_pd(imag_(), 0x05);        // b        a
  }
  // Complex conjugate: flip the sign bit of the imag lane only.
  __m256d conj_() const {
    const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
    return _mm256_xor_pd(values, sign_mask);        // a       -b
  }
  Vectorized<c10::complex<double>> conj() const {
    return conj_();
  }
  Vectorized<c10::complex<double>> log() const {
    // Most trigonometric ops use the log() op to improve complex number performance.
    return map(std::log);
  }
  Vectorized<c10::complex<double>> log2() const {
    // log2(z) = log(z) / log(2)
    const __m256d log2_ = _mm256_set1_pd(std::log(2));
    return _mm256_div_pd(log(), log2_);
  }
  Vectorized<c10::complex<double>> log10() const {
    // log10(z) = log(z) / log(10)
    const __m256d log10_ = _mm256_set1_pd(std::log(10));
    return _mm256_div_pd(log(), log10_);
  }
  Vectorized<c10::complex<double>> log1p() const {
    return map(std::log1p);
  }
  Vectorized<c10::complex<double>> asin() const {
    // asin(x)
    // = -i*ln(iz + sqrt(1 -z^2))
    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
    const __m256d one = _mm256_set1_pd(1);

    auto conj = conj_();
    auto b_a = _mm256_permute_pd(conj, 0x05);                        // -b       a
    auto ab = _mm256_mul_pd(conj, b_a);                              // -ab     -ab
    auto im = _mm256_add_pd(ab, ab);                                 // -2ab    -2ab

    auto val_2 = _mm256_mul_pd(values, values);                      // a*a      b*b
    auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // a*a-b*b  b*b-a*a
    re = _mm256_sub_pd(one, re);                                     // 1 - (a*a-b*b)

    auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt();    // sqrt(re + i*im)
    auto ln = Vectorized(_mm256_add_pd(b_a, root)).log();            // ln(iz + sqrt())
    return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj();    // -i*ln()
  }
  Vectorized<c10::complex<double>> acos() const {
    // acos(x) = pi/2 - asin(x)
    constexpr auto pi_2d = c10::pi<double> / 2;
    const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0);
    return _mm256_sub_pd(pi_2, asin());
  }
  // Defined out-of-line below: needs operator/ which is declared after the class.
  Vectorized<c10::complex<double>> atan() const;
  Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>>&) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> erf() const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> erfc() const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> exp() const {
    //exp(a + bi)
    // = exp(a)*(cos(b) + sin(b)i)
    auto exp = Sleef_expd4_u10(values);                             // exp(a) exp(b)
    exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); // exp(a) exp(a)

    auto sin_cos = Sleef_sincosd4_u10(values);       // [sin(a), cos(a)] [sin(b), cos(b)]
    auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05),
                                   sin_cos.x, 0x0A); // cos(b) sin(b)
    return _mm256_mul_pd(exp, cos_sin);
  }
  Vectorized<c10::complex<double>> exp2() const {
    // Use identity 2**x = exp(log(2) * x)
    const __m256d ln_2 = _mm256_set1_pd(c10::ln_2<double>);
    Vectorized<c10::complex<double>> scaled_values = _mm256_mul_pd(values, ln_2);
    return scaled_values.exp();
  }
  Vectorized<c10::complex<double>> expm1() const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> sin() const {
    return map(std::sin);
  }
  Vectorized<c10::complex<double>> sinh() const {
    return map(std::sinh);
  }
  Vectorized<c10::complex<double>> cos() const {
    return map(std::cos);
  }
  Vectorized<c10::complex<double>> cosh() const {
    return map(std::cosh);
  }
  // ceil/floor/round/trunc act per double lane, i.e. independently on the
  // real and imaginary parts.
  Vectorized<c10::complex<double>> ceil() const {
    return _mm256_ceil_pd(values);
  }
  Vectorized<c10::complex<double>> floor() const {
    return _mm256_floor_pd(values);
  }
  Vectorized<c10::complex<double>> hypot(const Vectorized<c10::complex<double>> &) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> igamma(const Vectorized<c10::complex<double>> &) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> igammac(const Vectorized<c10::complex<double>> &) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> neg() const {
    auto zero = _mm256_setzero_pd();
    return _mm256_sub_pd(zero, values);
  }
  Vectorized<c10::complex<double>> nextafter(const Vectorized<c10::complex<double>> &) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> round() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<c10::complex<double>> tan() const {
    return map(std::tan);
  }
  Vectorized<c10::complex<double>> tanh() const {
    return map(std::tanh);
  }
  Vectorized<c10::complex<double>> trunc() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<c10::complex<double>> sqrt() const {
    return map(std::sqrt);
  }
  // Defined out-of-line below: needs abs_2_() plus the division identity.
  Vectorized<c10::complex<double>> reciprocal() const;
  Vectorized<c10::complex<double>> rsqrt() const {
    return sqrt().reciprocal();
  }
  // Element-wise complex power via the scalar std::pow (no vectorized path).
  Vectorized<c10::complex<double>> pow(const Vectorized<c10::complex<double>> &exp) const {
    __at_align__ c10::complex<double> x_tmp[size()];
    __at_align__ c10::complex<double> y_tmp[size()];
    store(x_tmp);
    exp.store(y_tmp);
    for (const auto i : c10::irange(size())) {
      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
    }
    return loadu(x_tmp);
  }
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  // Results are per-double bit masks (all-ones / all-zeros per lane).
  Vectorized<c10::complex<double>> operator==(const Vectorized<c10::complex<double>>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
  }
  Vectorized<c10::complex<double>> operator!=(const Vectorized<c10::complex<double>>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
  }
  // Ordering comparisons are undefined for complex numbers.
  Vectorized<c10::complex<double>> operator<(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> operator<=(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> operator>(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> operator>=(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }

  // eq/ne return 1.0/0.0 values (not bit masks); defined out-of-line below.
  Vectorized<c10::complex<double>> eq(const Vectorized<c10::complex<double>>& other) const;
  Vectorized<c10::complex<double>> ne(const Vectorized<c10::complex<double>>& other) const;
  Vectorized<c10::complex<double>> lt(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> le(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> gt(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<double>> ge(const Vectorized<c10::complex<double>>&) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
};
346 | |
347 | template <> Vectorized<c10::complex<double>> inline operator+(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) { |
348 | return _mm256_add_pd(a, b); |
349 | } |
350 | |
351 | template <> Vectorized<c10::complex<double>> inline operator-(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) { |
352 | return _mm256_sub_pd(a, b); |
353 | } |
354 | |
355 | template <> Vectorized<c10::complex<double>> inline operator*(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) { |
356 | //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i |
357 | const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); |
358 | auto ac_bd = _mm256_mul_pd(a, b); //ac bd |
359 | |
360 | auto d_c = _mm256_permute_pd(b, 0x05); //d c |
361 | d_c = _mm256_xor_pd(sign_mask, d_c); //d -c |
362 | auto ad_bc = _mm256_mul_pd(a, d_c); //ad -bc |
363 | |
364 | auto ret = _mm256_hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc |
365 | return ret; |
366 | } |
367 | |
// Complex division (a + bi) / (c + di). Numerator and denominator are first
// scaled by 1/max(|c|, |d|) per element to avoid overflow/underflow in the
// intermediate products (Smith's method style scaling).
template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c10::complex<double>> &a, const Vectorized<c10::complex<double>> &b) {
  //re + im*i = (a + bi) / (c + di)
  auto mask = _mm256_set1_pd(-0.f);
  auto fabs_cd = _mm256_andnot_pd(mask, b);     // |c|    |d|    (clear sign bits)
  auto fabs_dc = _mm256_permute_pd(fabs_cd, 0x05);   // |d|    |c|
  auto scale = _mm256_div_pd(_mm256_set1_pd(1.0f), _mm256_max_pd(fabs_cd, fabs_dc));  // 1/sc   1/sc
  auto a2 = _mm256_mul_pd(a, scale);         // a/sc     b/sc
  auto b2 = _mm256_mul_pd(b, scale);         // c/sc     d/sc
  auto acbd2 = _mm256_mul_pd(a2, b2);        // ac/sc^2  bd/sc^2

  const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0);
  auto dc2 = _mm256_permute_pd(b2, 0x05);    // d/sc         c/sc
  dc2 = _mm256_xor_pd(sign_mask, dc2);       // -d/sc        c/sc
  auto adbc2 = _mm256_mul_pd(a2, dc2);       // -ad/sc^2     bc/sc^2
  auto res2 = _mm256_hadd_pd(acbd2, adbc2);  // (ac+bd)/sc^2 (bc-ad)/sc^2

  // get the denominator
  auto denom2 = Vectorized<c10::complex<double>>(b2).abs_2_();  // (c^2+d^2)/sc^2   (c^2+d^2)/sc^2
  res2 = _mm256_div_pd(res2, denom2);
  return res2;
}
389 | |
// reciprocal. Implemented out-of-line so the division identity can be stated
// with the general complex-division formula specialized to a=1, b=0:
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
  //re + im*i = (a + bi) / (c + di)
  //re = (ac + bd)/abs_2() = c/abs_2()
  //im = (bc - ad)/abs_2() = -d/abs_2()
  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
  auto c_d = _mm256_xor_pd(sign_mask, values);    //c       -d
  return _mm256_div_pd(c_d, abs_2_());
}
399 | |
// Complex arctangent, defined out-of-line because it needs operator/ and
// operator* which are declared after the class.
inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
  // atan(x) = i/2 * ln((i + z)/(i - z))
  const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0);       // the constant i in each element
  const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5);  // the constant i/2

  auto sum = Vectorized(_mm256_add_pd(i, values));  // a    1+b
  auto sub = Vectorized(_mm256_sub_pd(i, values));  // -a   1-b
  auto ln = (sum/sub).log();                        // ln((i + z)/(i - z))
  return i_half*ln;                                 // i/2*ln()
}
410 | |
411 | template <> |
412 | Vectorized<c10::complex<double>> inline maximum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) { |
413 | auto abs_a = a.abs_2_(); |
414 | auto abs_b = b.abs_2_(); |
415 | auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); |
416 | auto max = _mm256_blendv_pd(a, b, mask); |
417 | // Exploit the fact that all-ones is a NaN. |
418 | auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); |
419 | return _mm256_or_pd(max, isnan); |
420 | } |
421 | |
422 | template <> |
423 | Vectorized<c10::complex<double>> inline minimum(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) { |
424 | auto abs_a = a.abs_2_(); |
425 | auto abs_b = b.abs_2_(); |
426 | auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); |
427 | auto min = _mm256_blendv_pd(a, b, mask); |
428 | // Exploit the fact that all-ones is a NaN. |
429 | auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); |
430 | return _mm256_or_pd(min, isnan); |
431 | } |
432 | |
433 | template <> |
434 | Vectorized<c10::complex<double>> inline operator&(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) { |
435 | return _mm256_and_pd(a, b); |
436 | } |
437 | |
438 | template <> |
439 | Vectorized<c10::complex<double>> inline operator|(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) { |
440 | return _mm256_or_pd(a, b); |
441 | } |
442 | |
443 | template <> |
444 | Vectorized<c10::complex<double>> inline operator^(const Vectorized<c10::complex<double>>& a, const Vectorized<c10::complex<double>>& b) { |
445 | return _mm256_xor_pd(a, b); |
446 | } |
447 | |
448 | inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::eq(const Vectorized<c10::complex<double>>& other) const { |
449 | return (*this == other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0)); |
450 | } |
451 | |
452 | inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::ne(const Vectorized<c10::complex<double>>& other) const { |
453 | return (*this != other) & Vectorized<c10::complex<double>>(_mm256_set1_pd(1.0)); |
454 | } |
455 | |
456 | #endif |
457 | |
458 | }}} |
459 | |