#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <> class Vectorized<float> {
private:
  __m256 values;
public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(__m256 v) : values(v) {}
  Vectorized(float val) {
    values = _mm256_set1_ps(val);
  }
  Vectorized(float val1, float val2, float val3, float val4,
             float val5, float val6, float val7, float val8) {
    values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  operator __m256() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
    return _mm256_blend_ps(a.values, b.values, mask);
  }
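  // Usage sketch (illustrative only): bit i of `mask` selects lane i from
  // `b`; a cleared bit keeps the lane from `a`. For example:
  //   auto a = Vectorized<float>(0.f);
  //   auto b = Vectorized<float>(1.f);
  //   auto r = Vectorized<float>::blend<0b00000101>(a, b);
  //   // r == {1, 0, 1, 0, 0, 0, 0, 0}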
  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
                                  const Vectorized<float>& mask) {
    return _mm256_blendv_ps(a.values, b.values, mask.values);
  }
  template<typename step_t>
  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
    return Vectorized<float>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
  }
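  // e.g. arange(10.f, 2) yields {10, 12, 14, 16, 18, 20, 22, 24}.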
  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
                               int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
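  // `set` returns a vector whose first `count` lanes come from `b` and whose
  // remaining lanes come from `a`; e.g. set(a, b, 3) == {b0, b1, b2, a3, a4,
  // a5, a6, a7}. The blend masks above are simply (1 << count) - 1.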
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
    __at_align__ float tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions, while a loop compiles to a single instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
    return _mm256_loadu_ps(tmp_values);
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
    } else if (count > 0) {
      float tmp_values[size()];
      _mm256_storeu_ps(tmp_values, values);
      std::memcpy(ptr, tmp_values, count * sizeof(float));
    }
  }
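  // Tail-handling sketch (illustrative only):
  //   float buf[5] = {1, 2, 3, 4, 5};
  //   auto v = Vectorized<float>::loadu(buf, 5);  // lanes 5..7 read as zero
  //   v.store(buf, 5);                            // writes back only 5 floats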
  const float& operator[](int idx) const = delete;
  float& operator[](int idx) = delete;
  int zero_mask() const {
    // Returns an integer bitmask in which bit i is set iff lane i equals zero
    // (all other bits are cleared).
    __m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
    return _mm256_movemask_ps(cmp);
  }
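  // e.g. for values {0, 3, 0, 0, 1, 2, 4, 5}, zero_mask() == 0b00001101.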
  Vectorized<float> isnan() const {
    return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q);
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
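  // Note: map() is a scalar fallback - it spills the vector to the stack,
  // applies `f` lane by lane, and reloads, so it gains nothing from SIMD and
  // exists only for ops without a vectorized implementation.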
  Vectorized<float> abs() const {
    auto mask = _mm256_set1_ps(-0.f);
    return _mm256_andnot_ps(mask, values);
  }
  Vectorized<float> angle() const {
    const auto zero_vec = _mm256_set1_ps(0.f);
    const auto nan_vec = _mm256_set1_ps(NAN);
    const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
    const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
    const auto pi = _mm256_set1_ps(c10::pi<float>);

    const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
    auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
    angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
    return angle;
  }
  Vectorized<float> real() const {
    return *this;
  }
  Vectorized<float> imag() const {
    return _mm256_set1_ps(0);
  }
  Vectorized<float> conj() const {
    return *this;
  }
  Vectorized<float> acos() const {
    return Vectorized<float>(Sleef_acosf8_u10(values));
  }
  Vectorized<float> asin() const {
    return Vectorized<float>(Sleef_asinf8_u10(values));
  }
  Vectorized<float> atan() const {
    return Vectorized<float>(Sleef_atanf8_u10(values));
  }
  Vectorized<float> atan2(const Vectorized<float> &b) const {
    return Vectorized<float>(Sleef_atan2f8_u10(values, b));
  }
  Vectorized<float> copysign(const Vectorized<float> &sign) const {
    return Vectorized<float>(Sleef_copysignf8(values, sign));
  }
  Vectorized<float> erf() const {
    // Polynomial constants from Abramowitz & Stegun, formula 7.1.26
    // (maximum absolute error ~1.5e-7).
    const auto neg_zero_vec = _mm256_set1_ps(-0.f);
    const auto one_vec = _mm256_set1_ps(1.0f);
    const auto p = _mm256_set1_ps(0.3275911f);
    const auto p1 = _mm256_set1_ps(0.254829592f);
    const auto p2 = _mm256_set1_ps(-0.284496736f);
    const auto p3 = _mm256_set1_ps(1.421413741f);
    const auto p4 = _mm256_set1_ps(-1.453152027f);
    const auto p5 = _mm256_set1_ps(1.061405429f);
    // sign(x)
    auto sign_mask = _mm256_and_ps(neg_zero_vec, values);
    auto abs_vec = _mm256_xor_ps(sign_mask, values);
    // t = 1 / (p * abs(x) + 1)
    auto tmp0 = _mm256_fmadd_ps(p, abs_vec, one_vec);
    auto t = _mm256_div_ps(one_vec, tmp0);
    // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1
    auto tmp1 = _mm256_fmadd_ps(p5, t, p4);
    auto tmp2 = _mm256_fmadd_ps(tmp1, t, p3);
    auto tmp3 = _mm256_fmadd_ps(tmp2, t, p2);
    auto r = _mm256_fmadd_ps(tmp3, t, p1);
    // - exp(- x * x)
    auto pow_2 = _mm256_mul_ps(values, values);
    auto neg_pow_2 = _mm256_xor_ps(neg_zero_vec, pow_2);
    // auto tmp4 = exp(neg_pow_2);
    auto tmp4 = Vectorized<float>(Sleef_expf8_u10(neg_pow_2));
    auto tmp5 = _mm256_xor_ps(neg_zero_vec, tmp4);
    // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
    auto tmp6 = _mm256_mul_ps(tmp5, t);
    auto tmp7 = _mm256_fmadd_ps(tmp6, r, one_vec);
    return _mm256_xor_ps(sign_mask, tmp7);
  }
  Vectorized<float> erfc() const {
    return Vectorized<float>(Sleef_erfcf8_u15(values));
  }
  Vectorized<float> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<float> exp() const {
    return Vectorized<float>(Sleef_expf8_u10(values));
  }
  Vectorized<float> exp2() const {
    return Vectorized<float>(Sleef_exp2f8_u10(values));
  }
  Vectorized<float> expm1() const {
    return Vectorized<float>(Sleef_expm1f8_u10(values));
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    return Vectorized<float>(Sleef_fmodf8(values, q));
  }
  Vectorized<float> log() const {
    return Vectorized<float>(Sleef_logf8_u10(values));
  }
  Vectorized<float> log2() const {
    return Vectorized<float>(Sleef_log2f8_u10(values));
  }
  Vectorized<float> log10() const {
    return Vectorized<float>(Sleef_log10f8_u10(values));
  }
  Vectorized<float> log1p() const {
    return Vectorized<float>(Sleef_log1pf8_u10(values));
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return Vectorized<float>(Sleef_sinf8_u10(values));
  }
  Vectorized<float> sinh() const {
    return Vectorized<float>(Sleef_sinhf8_u10(values));
  }
  Vectorized<float> cos() const {
    return Vectorized<float>(Sleef_cosf8_u10(values));
  }
  Vectorized<float> cosh() const {
    return Vectorized<float>(Sleef_coshf8_u10(values));
  }
  Vectorized<float> ceil() const {
    return _mm256_ceil_ps(values);
  }
  Vectorized<float> floor() const {
    return _mm256_floor_ps(values);
  }
  Vectorized<float> hypot(const Vectorized<float> &b) const {
    return Vectorized<float>(Sleef_hypotf8_u05(values, b));
  }
  Vectorized<float> i0() const {
    return map(calc_i0);
  }
  Vectorized<float> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<float> igamma(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> igammac(const Vectorized<float> &x) const {
    __at_align__ float tmp[size()];
    __at_align__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> neg() const {
    return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
  }
  Vectorized<float> nextafter(const Vectorized<float> &b) const {
    return Vectorized<float>(Sleef_nextafterf8(values, b));
  }
  Vectorized<float> round() const {
    return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<float> tan() const {
    return Vectorized<float>(Sleef_tanf8_u10(values));
  }
  Vectorized<float> tanh() const {
    return Vectorized<float>(Sleef_tanhf8_u10(values));
  }
  Vectorized<float> trunc() const {
    return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<float> lgamma() const {
    return Vectorized<float>(Sleef_lgammaf8_u10(values));
  }
  Vectorized<float> sqrt() const {
    return _mm256_sqrt_ps(values);
  }
  Vectorized<float> reciprocal() const {
    return _mm256_div_ps(_mm256_set1_ps(1), values);
  }
  Vectorized<float> rsqrt() const {
    return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
  }
  Vectorized<float> pow(const Vectorized<float> &b) const {
    return Vectorized<float>(Sleef_powf8_u10(values, b));
  }
  // Comparisons use the _CMP_**_OQ predicate:
  //   `O` (ordered): return false if either operand is NaN
  //   `Q` (quiet):   do not raise an exception on NaN
  // operator!= is the exception: it uses _CMP_NEQ_UQ (unordered), so any
  // comparison involving NaN yields true, matching IEEE semantics for !=.
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ);
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ);
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ);
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ);
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ);
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_add_ps(a, b);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_sub_ps(a, b);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_mul_ps(a, b);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_div_ps(a, b);
}
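
// Usage sketch (illustrative only; `sum_floats` is a hypothetical helper):
// the canonical vectorized-main-loop + scalar-tail pattern built on this
// class looks like the following.
//
//   float sum_floats(const float* x, int64_t n) {
//     auto acc = Vectorized<float>(0.f);
//     int64_t i = 0;
//     for (; i + Vectorized<float>::size() <= n; i += Vectorized<float>::size()) {
//       acc = acc + Vectorized<float>::loadu(x + i);
//     }
//     __at_align__ float lanes[Vectorized<float>::size()];
//     acc.store(lanes);
//     float s = 0.f;
//     for (const auto j : c10::irange(Vectorized<float>::size())) {
//       s += lanes[j];
//     }
//     for (; i < n; i++) {  // scalar tail
//       s += x[i];
//     }
//     return s;
//   }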

// frac. Implemented here so it can use the subtraction operator defined above.
inline Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}
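// frac keeps the sign of its input: frac(-2.5f) == -0.5f, frac(2.5f) == 0.5f.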

// Implements the IEEE 754-2019 `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  Vectorized<float> max = _mm256_max_ps(a, b);
  Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
  // Exploit the fact that an all-ones bit pattern is a NaN.
  return _mm256_or_ps(max, isnan);
}

// Implements the IEEE 754-2019 `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  Vectorized<float> min = _mm256_min_ps(a, b);
  Vectorized<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
  // Exploit the fact that an all-ones bit pattern is a NaN.
  return _mm256_or_ps(min, isnan);
}
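
// The _CMP_UNORD_Q compare yields all-ones in every lane where either input
// is NaN; OR-ing that all-ones pattern into the max/min result turns exactly
// those lanes into NaN, which is what implements the propagation above.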

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return _mm256_min_ps(max, _mm256_max_ps(min, a));
}
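
// Note the operand order here and in clamp_max/clamp_min below:
// _mm256_min_ps/_mm256_max_ps return the second operand when either input is
// NaN, so passing `a` (and then the partial result) second propagates NaN
// from `a` through the clamp.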

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return _mm256_min_ps(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return _mm256_max_ps(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_and_ps(a, b);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_or_ps(a, b);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  return _mm256_xor_ps(a, b);
}

inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}
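
// Unlike the operators above, which return all-ones bitmasks per lane,
// eq/ne/gt/ge/lt/le normalize the result to numeric 0.0f / 1.0f by masking
// the bit pattern of 1.0f with the comparison result.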

template <>
inline void convert(const float* src, float* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  return _mm256_fmadd_ps(a, b, c);
}

template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  return _mm256_fmsub_ps(a, b, c);
}
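
// fmadd computes a * b + c and fmsub computes a * b - c, each with a single
// rounding of the intermediate product (fused multiply-add).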

// Used by Inductor CPP codegen
template<>
inline void transpose_mxn<float, 8, 8>(
    const float* src,
    int64_t ld_src,
    float* dst,
    int64_t ld_dst) {
  // load from src to registers
  // a: a0 a1 a2 a3 a4 a5 a6 a7
  // b: b0 b1 b2 b3 b4 b5 b6 b7
  // c: c0 c1 c2 c3 c4 c5 c6 c7
  // d: d0 d1 d2 d3 d4 d5 d6 d7
  // e: e0 e1 e2 e3 e4 e5 e6 e7
  // f: f0 f1 f2 f3 f4 f5 f6 f7
  // g: g0 g1 g2 g3 g4 g5 g6 g7
  // h: h0 h1 h2 h3 h4 h5 h6 h7
  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);

  __m256 ta, tb, tc, td, te, tf, tg, th;
  // unpacking and interleaving 32-bit elements
  // a0 b0 a1 b1 a4 b4 a5 b5
  // a2 b2 a3 b3 a6 b6 a7 b7
  // c0 d0 c1 d1 ...
  // c2 d2 c3 d3 ...
  // e0 f0 e1 f1 ...
  // e2 f2 e3 f3 ...
  // g0 h0 g1 h1 ...
  // g2 h2 g3 h3 ...
  ta = _mm256_unpacklo_ps(a, b);
  tb = _mm256_unpackhi_ps(a, b);
  tc = _mm256_unpacklo_ps(c, d);
  td = _mm256_unpackhi_ps(c, d);
  te = _mm256_unpacklo_ps(e, f);
  tf = _mm256_unpackhi_ps(e, f);
  tg = _mm256_unpacklo_ps(g, h);
  th = _mm256_unpackhi_ps(g, h);

  // unpacking and interleaving 64-bit elements
  // a0 b0 c0 d0 a4 b4 c4 d4
  // a1 b1 c1 d1 ...
  // a2 b2 c2 d2 ...
  // a3 b3 c3 d3 ...
  // e0 f0 g0 h0 e4 f4 g4 h4
  // e1 f1 g1 h1 ...
  // e2 f2 g2 h2 ...
  // e3 f3 g3 h3 ...
  a = _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
  b = _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
  c = _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
  d = _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
  e = _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
  f = _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
  g = _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
  h = _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));

  // shuffle 128-bits (composed of 4 32-bit elements)
  // a0 b0 c0 d0 e0 f0 g0 h0
  // a1 b1 c1 d1 ...
  // a2 b2 c2 d2 ...
  // a3 b3 c3 d3 ...
  // a4 b4 c4 d4 ...
  // a5 b5 c5 d5 ...
  // a6 b6 c6 d6 ...
  // a7 b7 c7 d7 ...
  ta = _mm256_permute2f128_ps(a, e, 0x20);
  tb = _mm256_permute2f128_ps(b, f, 0x20);
  tc = _mm256_permute2f128_ps(c, g, 0x20);
  td = _mm256_permute2f128_ps(d, h, 0x20);
  te = _mm256_permute2f128_ps(a, e, 0x31);
  tf = _mm256_permute2f128_ps(b, f, 0x31);
  tg = _mm256_permute2f128_ps(c, g, 0x31);
  th = _mm256_permute2f128_ps(d, h, 0x31);

  // store from registers to dst
  _mm256_storeu_ps(&dst[0 * ld_dst], ta);
  _mm256_storeu_ps(&dst[1 * ld_dst], tb);
  _mm256_storeu_ps(&dst[2 * ld_dst], tc);
  _mm256_storeu_ps(&dst[3 * ld_dst], td);
  _mm256_storeu_ps(&dst[4 * ld_dst], te);
  _mm256_storeu_ps(&dst[5 * ld_dst], tf);
  _mm256_storeu_ps(&dst[6 * ld_dst], tg);
  _mm256_storeu_ps(&dst[7 * ld_dst], th);
}
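
// Usage sketch (illustrative; `in`, `out`, `M`, `N`, `r`, `c` are
// hypothetical): transpose the 8x8 tile at (r, c) of a row-major MxN matrix
// into position (c, r) of a row-major NxM matrix. The leading dimensions
// ld_src/ld_dst are strides between rows, in elements rather than bytes.
//
//   transpose_mxn<float, 8, 8>(in + r * N + c, /*ld_src=*/N,
//                              out + c * M + r, /*ld_dst=*/M);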

#endif

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at