1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | |
6 | #include <c10/util/complex.h> |
7 | #include <c10/util/irange.h> |
8 | #include <ATen/cpu/vec/intrinsics.h> |
9 | #include <ATen/cpu/vec/vec_base.h> |
10 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
11 | #include <sleef.h> |
12 | #endif |
13 | |
14 | namespace at { |
15 | namespace vec { |
16 | // See Note [CPU_CAPABILITY namespace] |
17 | inline namespace CPU_CAPABILITY { |
18 | |
19 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
20 | |
// AVX2 specialization of Vectorized for c10::complex<float>.
//
// One 256-bit register holds size() == 4 complex numbers stored interleaved as
// [re0, im0, re1, im1, re2, im2, re3, im3] (see the constructors). In the lane
// comments below, `a` / `b` denote the real / imaginary part of one element of
// *this, and each comment shows the (even lane, odd lane) pair of one element.
template <> class Vectorized<c10::complex<float>> {
private:
  // Interleaved (real, imag) float pairs.
  __m256 values;
public:
  using value_type = c10::complex<float>;
  using size_type = int;
  // Number of complex elements per vector (8 float lanes / 2).
  static constexpr size_type size() {
    return 4;
  }
  // Default construction leaves `values` uninitialized.
  Vectorized() {}
  // Wraps a raw register that is already in interleaved (real, imag) layout.
  Vectorized(__m256 v) : values(v) {}
  // Broadcasts a single complex value to all four elements.
  Vectorized(c10::complex<float> val) {
    float real_value = val.real();
    float imag_value = val.imag();
    values = _mm256_setr_ps(real_value, imag_value,
                            real_value, imag_value,
                            real_value, imag_value,
                            real_value, imag_value
                            );
  }
  // Builds a vector from four complex values; val1 occupies the lowest lanes.
  Vectorized(c10::complex<float> val1, c10::complex<float> val2, c10::complex<float> val3, c10::complex<float> val4) {
    values = _mm256_setr_ps(val1.real(), val1.imag(),
                            val2.real(), val2.imag(),
                            val3.real(), val3.imag(),
                            val4.real(), val4.imag()
                            );
  }
  // Implicit access to the underlying register so intrinsics can take a
  // Vectorized argument directly.
  operator __m256() const {
    return values;
  }
  // Compile-time blend: bit i of `mask` set selects complex element i from b,
  // otherwise from a. Each complex-index bit is widened to the two float-lane
  // bits of _mm256_blend_ps. mask == 15 falls through to `return b`.
  template <int64_t mask>
  static Vectorized<c10::complex<float>> blend(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) {
    // convert c10::complex<V> index mask to V index mask: xy -> xxyy
    static_assert(mask > -1 && mask < 16, "Unexpected mask range" );
    switch (mask) {
      case 0:
        return a;
      case 1:
        return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011
      case 2:
        return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100
      case 3:
        return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111
      case 4:
        return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000
      case 5:
        return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011
      case 6:
        return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100
      case 7:
        return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111
      case 8:
        return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000
      case 9:
        return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011
      case 10:
        return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100
      case 11:
        return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111
      case 12:
        return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000
      case 13:
        return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011
      case 14:
        return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100
      default: break;
    }
    return b;
  }
  // Run-time blend: a float lane is taken from b when the sign bit of the
  // corresponding lane of the widened mask is set, otherwise from a.
  static Vectorized<c10::complex<float>> blendv(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
                               const Vectorized<c10::complex<float>>& mask) {
    // convert c10::complex<V> index mask to V index mask: xy -> xxyy
    // NOTE(review): _mm256_unpacklo_ps works per 128-bit half, so in float-lane
    // terms the widened mask is [m0, m0, m1, m1 | m4, m4, m5, m5]: complex
    // elements 1 and 3 are selected by the *imaginary* lanes of mask elements
    // 0 and 2. Verify that this is the lane mapping callers expect.
    auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values);
    return _mm256_blendv_ps(a.values, b.values, mask_);

  }
  // Returns {base, base + step, base + 2*step, base + 3*step}.
  template<typename step_t>
  static Vectorized<c10::complex<float>> arange(c10::complex<float> base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<c10::complex<float>>(base,
                                        base + step,
                                        base + c10::complex<float>(2)*step,
                                        base + c10::complex<float>(3)*step);
  }
  // Returns a vector whose first `count` elements come from b and the rest
  // from a. Any count outside 0..3 returns b unchanged.
  static Vectorized<c10::complex<float>> set(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b,
                            int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
  // Loads `count` complex elements from possibly unaligned memory at `ptr`.
  // For a partial load the remaining elements are zero.
  static Vectorized<c10::complex<float>> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));

    __at_align__ float tmp_values[2*size()];
    // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502
    // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two
    // instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(2*size())) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const float*>(ptr),
        count * sizeof(c10::complex<float>));
    return _mm256_load_ps(tmp_values);
  }
  // Stores the first `count` complex elements to possibly unaligned memory at
  // `ptr`. Nothing is written when count <= 0.
  // NOTE(review): a count greater than size() would copy past the end of
  // tmp_values; presumably callers never pass that - confirm.
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
    } else if (count > 0) {
      // Spill the whole register, then copy only the requested prefix.
      float tmp_values[2*size()];
      _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(c10::complex<float>));
    }
  }
  // Element access is intentionally unavailable.
  const c10::complex<float>& operator[](int idx) const  = delete;
  c10::complex<float>& operator[](int idx) = delete;
  // Applies the scalar function `f` to every element by spilling the vector
  // to a temporary array and reloading the results.
  Vectorized<c10::complex<float>> map(c10::complex<float> (*const f)(const c10::complex<float> &)) const {
    __at_align__ c10::complex<float> tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  // Squared magnitude a*a + b*b, duplicated into both lanes of each element.
  __m256 abs_2_() const {
    auto val_2 = _mm256_mul_ps(values, values);     // a*a     b*b
    auto ret = _mm256_hadd_ps(val_2, val_2);        // a*a+b*b a*a+b*b
    // hadd leaves [s0, s1, s0, s1] in each 128-bit half; 0xD8 reorders that to
    // [s0, s0, s1, s1] so every element carries its own sum in both lanes.
    return _mm256_permute_ps(ret, 0xD8);
  }
  // Magnitude |z|, duplicated into both lanes of each element.
  __m256 abs_() const {
    return _mm256_sqrt_ps(abs_2_());                // abs     abs
  }
  // Magnitude as a complex vector: (|z|, 0).
  Vectorized<c10::complex<float>> abs() const {
    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
    return _mm256_and_ps(abs_(), real_mask);        // abs     0
  }
  // Raw argument computation: the odd lane holds atan2(b, a); the even lane
  // holds atan2(a, b).
  __m256 angle_() const {
    //angle = atan2(b, a)
    auto b_a = _mm256_permute_ps(values, 0xB1);     // b        a
    return Sleef_atan2f8_u10(values, b_a);          // 90-angle angle
  }
  // Argument as a complex vector: (atan2(b, a), 0).
  Vectorized<c10::complex<float>> angle() const {
    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
    auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle    90-angle
    return _mm256_and_ps(angle, real_mask);         // angle    0
  }
  // Complex sign z / |z|; elements whose magnitude compares equal to zero
  // yield 0.
  Vectorized<c10::complex<float>> sgn() const {
    auto abs = abs_();
    auto zero = _mm256_setzero_ps();
    auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
    auto abs_val = Vectorized(abs);

    // Both operands are raw __m256 values, so this is a per-float-lane
    // division, not the complex operator/ defined for Vectorized.
    auto div = values / abs_val.values;       // x / abs(x)

    return _mm256_blendv_ps(div, zero, mask);
  }
  // Keeps the real lanes and zeroes the imaginary lanes: (a, 0).
  __m256 real_() const {
    const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000,
                                                                   0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000));
    return _mm256_and_ps(values, real_mask);
  }
  Vectorized<c10::complex<float>> real() const {
    return real_();
  }
  // Keeps the imaginary lanes in place and zeroes the real lanes: (0, b).
  __m256 imag_() const {
    const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF,
                                                                   0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF));
    return _mm256_and_ps(values, imag_mask);
  }
  // Imaginary part moved into the real lane: (b, 0).
  Vectorized<c10::complex<float>> imag() const {
    return _mm256_permute_ps(imag_(), 0xB1);        //b        a
  }
  // Complex conjugate: flips the sign bit of every imaginary lane.
  __m256 conj_() const {
    const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
    return _mm256_xor_ps(values, sign_mask);        // a       -b
  }
  Vectorized<c10::complex<float>> conj() const {
    return conj_();
  }
  // Natural logarithm, computed element by element with std::log.
  Vectorized<c10::complex<float>> log() const {
    // Most trigonometric ops use the log() op to improve complex number performance.
    return map(std::log);
  }
  // log2(z) = ln(z) / ln(2); both lanes are divided by the real constant.
  Vectorized<c10::complex<float>> log2() const {
    const __m256 log2_ = _mm256_set1_ps(std::log(2));
    return _mm256_div_ps(log(), log2_);
  }
  // log10(z) = ln(z) / ln(10); both lanes are divided by the real constant.
  Vectorized<c10::complex<float>> log10() const {
    const __m256 log10_ = _mm256_set1_ps(std::log(10));
    return _mm256_div_ps(log(), log10_);
  }
  Vectorized<c10::complex<float>> log1p() const {
    return map(std::log1p);
  }
  // Inverse sine via the logarithmic identity below.
  Vectorized<c10::complex<float>> asin() const {
    // asin(x)
    // = -i*ln(iz + sqrt(1 -z^2))
    // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
    // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
    const __m256 one = _mm256_set1_ps(1);

    auto conj = conj_();
    auto b_a = _mm256_permute_ps(conj, 0xB1);                         //-b        a
    auto ab = _mm256_mul_ps(conj, b_a);                               //-ab       -ab
    auto im = _mm256_add_ps(ab, ab);                                  //-2ab      -2ab

    auto val_2 = _mm256_mul_ps(values, values);                       // a*a      b*b
    auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1));  // a*a-b*b  b*b-a*a
    re = _mm256_permute_ps(re, 0xD8);
    re = _mm256_sub_ps(one, re);

    // Only the even lanes of `re` and the odd lanes of `im` survive the blend.
    auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt();         //sqrt(re + i*im)
    auto ln = Vectorized(_mm256_add_ps(b_a, root)).log();                 //ln(iz + sqrt())
    // Swapping the lanes and conjugating multiplies by -i.
    return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj();         //-i*ln()
  }
  Vectorized<c10::complex<float>> acos() const {
    return map(std::acos);
  }
  // Defined after the class because it needs the complex operator/ and operator*.
  Vectorized<c10::complex<float>> atan() const;
  Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>>& /*b*/) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> erf() const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> erfc() const {
    AT_ERROR("not supported for complex numbers" );
  }
  // Complex exponential via Euler's formula.
  Vectorized<c10::complex<float>> exp() const {
    //exp(a + bi)
    // = exp(a)*(cos(b) + sin(b)i)
    auto exp = Sleef_expf8_u10(values);                               //exp(a)           exp(b)
    exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA);   //exp(a)           exp(a)

    auto sin_cos = Sleef_sincosf8_u10(values);                        //[sin(a), cos(a)] [sin(b), cos(b)]
    auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1),
                                   sin_cos.x, 0xAA);                  //cos(b)           sin(b)
    return _mm256_mul_ps(exp, cos_sin);
  }
  Vectorized<c10::complex<float>> exp2() const {
    // Use identity 2**x = exp(log(2) * x)
    const __m256 ln_2 = _mm256_set1_ps(c10::ln_2<float>);
    Vectorized<c10::complex<float>> scaled_values = _mm256_mul_ps(values, ln_2);
    return scaled_values.exp();
  }
  Vectorized<c10::complex<float>> expm1() const {
    AT_ERROR("not supported for complex numbers" );
  }
  // The functions below fall back to the scalar std:: implementation per element.
  Vectorized<c10::complex<float>> sin() const {
    return map(std::sin);
  }
  Vectorized<c10::complex<float>> sinh() const {
    return map(std::sinh);
  }
  Vectorized<c10::complex<float>> cos() const {
    return map(std::cos);
  }
  Vectorized<c10::complex<float>> cosh() const {
    return map(std::cosh);
  }
  // ceil / floor are applied to the real and imaginary lanes independently.
  Vectorized<c10::complex<float>> ceil() const {
    return _mm256_ceil_ps(values);
  }
  Vectorized<c10::complex<float>> floor() const {
    return _mm256_floor_ps(values);
  }
  Vectorized<c10::complex<float>> hypot(const Vectorized<c10::complex<float>>& /*b*/) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> igamma(const Vectorized<c10::complex<float>>& /*x*/) const {
    AT_ERROR("not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> igammac(const Vectorized<c10::complex<float>>& /*x*/) const {
    AT_ERROR("not supported for complex numbers" );
  }
  // Negation, computed as 0 - z in every lane.
  Vectorized<c10::complex<float>> neg() const {
    auto zero = _mm256_setzero_ps();
    return _mm256_sub_ps(zero, values);
  }
  Vectorized<c10::complex<float>> nextafter(const Vectorized<c10::complex<float>>& /*b*/) const {
    AT_ERROR("not supported for complex numbers" );
  }
  // Rounds the real and imaginary lanes independently to the nearest integer.
  Vectorized<c10::complex<float>> round() const {
    return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<c10::complex<float>> tan() const {
    return map(std::tan);
  }
  Vectorized<c10::complex<float>> tanh() const {
    return map(std::tanh);
  }
  // Truncates the real and imaginary lanes independently toward zero.
  Vectorized<c10::complex<float>> trunc() const {
    return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<c10::complex<float>> sqrt() const {
    return map(std::sqrt);
  }
  // Defined after the class.
  Vectorized<c10::complex<float>> reciprocal() const;
  // 1 / sqrt(z).
  Vectorized<c10::complex<float>> rsqrt() const {
    return sqrt().reciprocal();
  }
  // Element-wise complex power (*this)[i] ** exp[i], evaluated with the scalar
  // std::pow on spilled copies of both vectors.
  Vectorized<c10::complex<float>> pow(const Vectorized<c10::complex<float>> &exp) const {
    __at_align__ c10::complex<float> x_tmp[size()];
    __at_align__ c10::complex<float> y_tmp[size()];
    store(x_tmp);
    exp.store(y_tmp);
    for (const auto i : c10::irange(size())) {
      x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
    }
    return loadu(x_tmp);
  }
  // operator== / operator!= compare every float lane on its own, so the real
  // and imaginary parts of an element get separate all-ones / all-zero masks.
  // operator== uses the _CMP_EQ_OQ predicate:
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  // operator!= uses _CMP_NEQ_UQ (`U`: get true if an operand is NaN).
  Vectorized<c10::complex<float>> operator==(const Vectorized<c10::complex<float>>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
  }
  Vectorized<c10::complex<float>> operator!=(const Vectorized<c10::complex<float>>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
  }
  // Ordering comparisons always fail the check below.
  Vectorized<c10::complex<float>> operator<(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> operator<=(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> operator>(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> operator>=(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }

  // eq / ne are defined after the class.
  Vectorized<c10::complex<float>> eq(const Vectorized<c10::complex<float>>& other) const;
  Vectorized<c10::complex<float>> ne(const Vectorized<c10::complex<float>>& other) const;
  Vectorized<c10::complex<float>> lt(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> le(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> gt(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
  Vectorized<c10::complex<float>> ge(const Vectorized<c10::complex<float>>& /*other*/) const {
    TORCH_CHECK(false, "not supported for complex numbers" );
  }
};
380 | |
381 | template <> Vectorized<c10::complex<float>> inline operator+(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) { |
382 | return _mm256_add_ps(a, b); |
383 | } |
384 | |
385 | template <> Vectorized<c10::complex<float>> inline operator-(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) { |
386 | return _mm256_sub_ps(a, b); |
387 | } |
388 | |
389 | template <> Vectorized<c10::complex<float>> inline operator*(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) { |
390 | //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i |
391 | const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); |
392 | auto ac_bd = _mm256_mul_ps(a, b); //ac bd |
393 | |
394 | auto d_c = _mm256_permute_ps(b, 0xB1); //d c |
395 | d_c = _mm256_xor_ps(sign_mask, d_c); //d -c |
396 | auto ad_bc = _mm256_mul_ps(a, d_c); //ad -bc |
397 | |
398 | auto ret = _mm256_hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc |
399 | ret = _mm256_permute_ps(ret, 0xD8); |
400 | return ret; |
401 | } |
402 | |
403 | template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c10::complex<float>> &a, const Vectorized<c10::complex<float>> &b) { |
404 | //re + im*i = (a + bi) / (c + di) |
405 | auto mask = _mm256_set1_ps(-0.f); |
406 | auto fabs_cd = _mm256_andnot_ps(mask, b); // |c| |d| |
407 | auto fabs_dc = _mm256_permute_ps(fabs_cd, 0xB1); // |d| |c| |
408 | auto scale = _mm256_rcp_ps(_mm256_max_ps(fabs_cd, fabs_dc)); // 1/sc 1/sc |
409 | auto a2 = _mm256_mul_ps(a, scale); // a/sc b/sc |
410 | auto b2 = _mm256_mul_ps(b, scale); // c/sc d/sc |
411 | auto acbd2 = _mm256_mul_ps(a2, b2); |
412 | |
413 | const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); |
414 | auto dc2 = _mm256_permute_ps(b2, 0xB1); // d/sc c/sc |
415 | dc2 = _mm256_xor_ps(sign_mask, dc2); // -d/|c,d| c/sc |
416 | auto adbc2 = _mm256_mul_ps(a2, dc2); //-ad/sc^2 bc/sc^2 |
417 | auto res2 = _mm256_hadd_ps(acbd2, adbc2); //(ac+bd)/sc^2 (bc-ad)/sc^2 |
418 | res2 = _mm256_permute_ps(res2, 0xD8); |
419 | |
420 | // get the denominator |
421 | auto denom2 = Vectorized<c10::complex<float>>(b2).abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 |
422 | res2 = _mm256_div_ps(res2, denom2); |
423 | return res2; |
424 | } |
425 | |
426 | // reciprocal. Implement this here so we can use multiplication. |
427 | inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const { |
428 | //re + im*i = (a + bi) / (c + di) |
429 | //re = (ac + bd)/abs_2() = c/abs_2() |
430 | //im = (bc - ad)/abs_2() = d/abs_2() |
431 | const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); |
432 | auto c_d = _mm256_xor_ps(sign_mask, values); //c -d |
433 | return _mm256_div_ps(c_d, abs_2_()); |
434 | } |
435 | |
436 | inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const { |
437 | // atan(x) = i/2 * ln((i + z)/(i - z)) |
438 | const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); |
439 | const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); |
440 | |
441 | auto sum = Vectorized(_mm256_add_ps(i, values)); // a 1+b |
442 | auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b |
443 | auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) |
444 | return i_half*ln; // i/2*ln() |
445 | } |
446 | |
447 | template <> |
448 | Vectorized<c10::complex<float>> inline maximum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) { |
449 | auto abs_a = a.abs_2_(); |
450 | auto abs_b = b.abs_2_(); |
451 | auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); |
452 | auto max = _mm256_blendv_ps(a, b, mask); |
453 | // Exploit the fact that all-ones is a NaN. |
454 | auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); |
455 | return _mm256_or_ps(max, isnan); |
456 | } |
457 | |
458 | template <> |
459 | Vectorized<c10::complex<float>> inline minimum(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) { |
460 | auto abs_a = a.abs_2_(); |
461 | auto abs_b = b.abs_2_(); |
462 | auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); |
463 | auto min = _mm256_blendv_ps(a, b, mask); |
464 | // Exploit the fact that all-ones is a NaN. |
465 | auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); |
466 | return _mm256_or_ps(min, isnan); |
467 | } |
468 | |
469 | template <> |
470 | Vectorized<c10::complex<float>> inline operator&(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) { |
471 | return _mm256_and_ps(a, b); |
472 | } |
473 | |
474 | template <> |
475 | Vectorized<c10::complex<float>> inline operator|(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) { |
476 | return _mm256_or_ps(a, b); |
477 | } |
478 | |
479 | template <> |
480 | Vectorized<c10::complex<float>> inline operator^(const Vectorized<c10::complex<float>>& a, const Vectorized<c10::complex<float>>& b) { |
481 | return _mm256_xor_ps(a, b); |
482 | } |
483 | |
484 | inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::eq( |
485 | const Vectorized<c10::complex<float>>& other) const { |
486 | return (*this == other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f)); |
487 | } |
488 | |
489 | inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::ne( |
490 | const Vectorized<c10::complex<float>>& other) const { |
491 | return (*this != other) & Vectorized<c10::complex<float>>(_mm256_set1_ps(1.0f)); |
492 | } |
493 | |
494 | #endif |
495 | |
496 | }}} |
497 | |