1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | |
6 | #include <ATen/cpu/vec/intrinsics.h> |
7 | #include <ATen/cpu/vec/vec_base.h> |
8 | #include <c10/macros/Macros.h> |
9 | #include <c10/util/irange.h> |
10 | #include <iostream> |
11 | |
12 | namespace at { |
13 | namespace vec { |
14 | inline namespace CPU_CAPABILITY { |
15 | |
16 | #ifdef CPU_CAPABILITY_AVX2 |
17 | |
18 | struct Vectorizedi { |
19 | protected: |
20 | __m256i values; |
21 | |
  // Bitwise NOT, implemented as XOR with all-ones. AVX2 only provides
  // equality and greater-than compares, so operator!=, operator<= and
  // operator>= below are synthesized by inverting the complementary test.
  static inline __m256i invert(const __m256i& v) {
    const auto ones = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(ones, v);
  }
26 | public: |
27 | Vectorizedi() {} |
28 | Vectorizedi(__m256i v) : values(v) {} |
29 | operator __m256i() const { |
30 | return values; |
31 | } |
32 | }; |
33 | |
34 | #else |
35 | |
36 | struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined |
37 | |
38 | #endif // CPU_CAPABILITY_AVX2 |
39 | |
40 | #ifdef CPU_CAPABILITY_AVX2 |
41 | |
42 | template <> |
43 | class Vectorized<int64_t> : public Vectorizedi { |
44 | private: |
45 | static const Vectorized<int64_t> ones; |
46 | public: |
47 | using value_type = int64_t; |
48 | using size_type = int; |
49 | static constexpr size_type size() { |
50 | return 4; |
51 | } |
52 | using Vectorizedi::Vectorizedi; |
53 | Vectorized() {} |
54 | Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); } |
55 | Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) { |
56 | values = _mm256_setr_epi64x(val1, val2, val3, val4); |
57 | } |
58 | template <int64_t mask> |
59 | static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) { |
60 | __at_align__ int64_t tmp_values[size()]; |
61 | a.store(tmp_values); |
62 | if (mask & 0x01) |
63 | tmp_values[0] = _mm256_extract_epi64(b.values, 0); |
64 | if (mask & 0x02) |
65 | tmp_values[1] = _mm256_extract_epi64(b.values, 1); |
66 | if (mask & 0x04) |
67 | tmp_values[2] = _mm256_extract_epi64(b.values, 2); |
68 | if (mask & 0x08) |
69 | tmp_values[3] = _mm256_extract_epi64(b.values, 3); |
70 | return loadu(tmp_values); |
71 | } |
72 | static Vectorized<int64_t> blendv(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, |
73 | const Vectorized<int64_t>& mask) { |
74 | return _mm256_blendv_epi8(a.values, b.values, mask.values); |
75 | } |
76 | template <typename step_t> |
77 | static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) { |
78 | return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step); |
79 | } |
80 | static Vectorized<int64_t> |
81 | set(Vectorized<int64_t> a, Vectorized<int64_t> b, int64_t count = size()) { |
82 | switch (count) { |
83 | case 0: |
84 | return a; |
85 | case 1: |
86 | return blend<1>(a, b); |
87 | case 2: |
88 | return blend<3>(a, b); |
89 | case 3: |
90 | return blend<7>(a, b); |
91 | } |
92 | return b; |
93 | } |
94 | static Vectorized<int64_t> loadu(const void* ptr) { |
95 | return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); |
96 | } |
97 | static Vectorized<int64_t> loadu(const void* ptr, int64_t count) { |
98 | __at_align__ int64_t tmp_values[size()]; |
    // Ensure that uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions, while a loop would be compiled to one instruction.
102 | for (const auto i : c10::irange(size())) { |
103 | tmp_values[i] = 0; |
104 | } |
105 | std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); |
106 | return loadu(tmp_values); |
107 | } |
108 | void store(void* ptr, int count = size()) const { |
109 | if (count == size()) { |
      // ptr need not be aligned here. See
111 | // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html |
112 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); |
113 | } else if (count > 0) { |
114 | __at_align__ int64_t tmp_values[size()]; |
115 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); |
116 | std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); |
117 | } |
118 | } |
119 | const int64_t& operator[](int idx) const = delete; |
120 | int64_t& operator[](int idx) = delete; |
  Vectorized<int64_t> abs() const {
    // AVX2 has no _mm256_abs_epi64, so use the branchless identity
    // abs(x) == (x ^ mask) - mask, where mask is all-ones iff x < 0.
    auto zero = _mm256_set1_epi64x(0);
    auto is_negative = _mm256_cmpgt_epi64(zero, values);
    auto inverse = _mm256_xor_si256(values, is_negative);
    return _mm256_sub_epi64(inverse, is_negative);
  }
127 | Vectorized<int64_t> real() const { |
128 | return *this; |
129 | } |
130 | Vectorized<int64_t> imag() const { |
131 | return _mm256_set1_epi64x(0); |
132 | } |
133 | Vectorized<int64_t> conj() const { |
134 | return *this; |
135 | } |
136 | Vectorized<int64_t> neg() const; |
137 | Vectorized<int64_t> operator==(const Vectorized<int64_t>& other) const { |
138 | return _mm256_cmpeq_epi64(values, other.values); |
139 | } |
140 | Vectorized<int64_t> operator!=(const Vectorized<int64_t>& other) const { |
141 | return invert(_mm256_cmpeq_epi64(values, other.values)); |
142 | } |
143 | Vectorized<int64_t> operator<(const Vectorized<int64_t>& other) const { |
144 | return _mm256_cmpgt_epi64(other.values, values); |
145 | } |
146 | Vectorized<int64_t> operator<=(const Vectorized<int64_t>& other) const { |
147 | return invert(_mm256_cmpgt_epi64(values, other.values)); |
148 | } |
149 | Vectorized<int64_t> operator>(const Vectorized<int64_t>& other) const { |
150 | return _mm256_cmpgt_epi64(values, other.values); |
151 | } |
152 | Vectorized<int64_t> operator>=(const Vectorized<int64_t>& other) const { |
153 | return invert(_mm256_cmpgt_epi64(other.values, values)); |
154 | } |
155 | |
156 | Vectorized<int64_t> eq(const Vectorized<int64_t>& other) const; |
157 | Vectorized<int64_t> ne(const Vectorized<int64_t>& other) const; |
158 | Vectorized<int64_t> gt(const Vectorized<int64_t>& other) const; |
159 | Vectorized<int64_t> ge(const Vectorized<int64_t>& other) const; |
160 | Vectorized<int64_t> lt(const Vectorized<int64_t>& other) const; |
161 | Vectorized<int64_t> le(const Vectorized<int64_t>& other) const; |
162 | }; |
163 | |
164 | template <> |
165 | class Vectorized<int32_t> : public Vectorizedi { |
166 | private: |
167 | static const Vectorized<int32_t> ones; |
168 | public: |
169 | using value_type = int32_t; |
170 | static constexpr int size() { |
171 | return 8; |
172 | } |
173 | using Vectorizedi::Vectorizedi; |
174 | Vectorized() {} |
175 | Vectorized(int32_t v) { values = _mm256_set1_epi32(v); } |
176 | Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4, |
177 | int32_t val5, int32_t val6, int32_t val7, int32_t val8) { |
178 | values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8); |
179 | } |
180 | template <int64_t mask> |
181 | static Vectorized<int32_t> blend(Vectorized<int32_t> a, Vectorized<int32_t> b) { |
182 | return _mm256_blend_epi32(a, b, mask); |
183 | } |
184 | static Vectorized<int32_t> blendv(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b, |
185 | const Vectorized<int32_t>& mask) { |
186 | return _mm256_blendv_epi8(a.values, b.values, mask.values); |
187 | } |
188 | template <typename step_t> |
189 | static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) { |
190 | return Vectorized<int32_t>( |
191 | base, base + step, base + 2 * step, base + 3 * step, |
192 | base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step); |
193 | } |
194 | static Vectorized<int32_t> |
195 | set(Vectorized<int32_t> a, Vectorized<int32_t> b, int32_t count = size()) { |
196 | switch (count) { |
197 | case 0: |
198 | return a; |
199 | case 1: |
200 | return blend<1>(a, b); |
201 | case 2: |
202 | return blend<3>(a, b); |
203 | case 3: |
204 | return blend<7>(a, b); |
205 | case 4: |
206 | return blend<15>(a, b); |
207 | case 5: |
208 | return blend<31>(a, b); |
209 | case 6: |
210 | return blend<63>(a, b); |
211 | case 7: |
212 | return blend<127>(a, b); |
213 | } |
214 | return b; |
215 | } |
216 | static Vectorized<int32_t> loadu(const void* ptr) { |
217 | return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); |
218 | } |
219 | static Vectorized<int32_t> loadu(const void* ptr, int32_t count) { |
220 | __at_align__ int32_t tmp_values[size()]; |
    // Ensure that uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions, while a loop would be compiled to one instruction.
224 | for (const auto i : c10::irange(size())) { |
225 | tmp_values[i] = 0; |
226 | } |
227 | std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); |
228 | return loadu(tmp_values); |
229 | } |
230 | void store(void* ptr, int count = size()) const { |
231 | if (count == size()) { |
      // ptr need not be aligned here. See
233 | // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html |
234 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); |
235 | } else if (count > 0) { |
236 | __at_align__ int32_t tmp_values[size()]; |
237 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); |
238 | std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); |
239 | } |
240 | } |
241 | const int32_t& operator[](int idx) const = delete; |
242 | int32_t& operator[](int idx) = delete; |
243 | Vectorized<int32_t> abs() const { |
244 | return _mm256_abs_epi32(values); |
245 | } |
246 | Vectorized<int32_t> real() const { |
247 | return *this; |
248 | } |
249 | Vectorized<int32_t> imag() const { |
250 | return _mm256_set1_epi32(0); |
251 | } |
252 | Vectorized<int32_t> conj() const { |
253 | return *this; |
254 | } |
255 | Vectorized<int32_t> neg() const; |
256 | Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const { |
257 | return _mm256_cmpeq_epi32(values, other.values); |
258 | } |
259 | Vectorized<int32_t> operator!=(const Vectorized<int32_t>& other) const { |
260 | return invert(_mm256_cmpeq_epi32(values, other.values)); |
261 | } |
262 | Vectorized<int32_t> operator<(const Vectorized<int32_t>& other) const { |
263 | return _mm256_cmpgt_epi32(other.values, values); |
264 | } |
265 | Vectorized<int32_t> operator<=(const Vectorized<int32_t>& other) const { |
266 | return invert(_mm256_cmpgt_epi32(values, other.values)); |
267 | } |
268 | Vectorized<int32_t> operator>(const Vectorized<int32_t>& other) const { |
269 | return _mm256_cmpgt_epi32(values, other.values); |
270 | } |
271 | Vectorized<int32_t> operator>=(const Vectorized<int32_t>& other) const { |
272 | return invert(_mm256_cmpgt_epi32(other.values, values)); |
273 | } |
274 | Vectorized<int32_t> eq(const Vectorized<int32_t>& other) const; |
275 | Vectorized<int32_t> ne(const Vectorized<int32_t>& other) const; |
276 | Vectorized<int32_t> gt(const Vectorized<int32_t>& other) const; |
277 | Vectorized<int32_t> ge(const Vectorized<int32_t>& other) const; |
278 | Vectorized<int32_t> lt(const Vectorized<int32_t>& other) const; |
279 | Vectorized<int32_t> le(const Vectorized<int32_t>& other) const; |
280 | }; |
281 | |
282 | template <> |
283 | inline void convert(const int32_t *src, float *dst, int64_t n) { |
284 | int64_t i; |
  // int32_t and float have the same size
286 | #ifndef _MSC_VER |
287 | # pragma unroll |
288 | #endif |
289 | for (i = 0; i <= (n - Vectorized<int32_t>::size()); i += Vectorized<int32_t>::size()) { |
290 | auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i)); |
291 | auto output_vec = _mm256_cvtepi32_ps(input_vec); |
292 | _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec); |
293 | } |
294 | #ifndef _MSC_VER |
295 | # pragma unroll |
296 | #endif |
297 | for (; i < n; i++) { |
298 | dst[i] = static_cast<float>(src[i]); |
299 | } |
300 | } |
301 | |
302 | template <> |
303 | inline void convert(const int32_t *src, double *dst, int64_t n) { |
304 | int64_t i; |
305 | // int32_t has half the size of double |
306 | #ifndef _MSC_VER |
307 | # pragma unroll |
308 | #endif |
309 | for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) { |
310 | auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i)); |
311 | auto output_vec = _mm256_cvtepi32_pd(input_128_vec); |
312 | _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec); |
313 | } |
314 | #ifndef _MSC_VER |
315 | # pragma unroll |
316 | #endif |
317 | for (; i < n; i++) { |
318 | dst[i] = static_cast<double>(src[i]); |
319 | } |
320 | } |
321 | |
322 | template <> |
323 | class Vectorized<int16_t> : public Vectorizedi { |
324 | private: |
325 | static const Vectorized<int16_t> ones; |
326 | public: |
327 | using value_type = int16_t; |
328 | static constexpr int size() { |
329 | return 16; |
330 | } |
331 | using Vectorizedi::Vectorizedi; |
332 | Vectorized() {} |
333 | Vectorized(int16_t v) { values = _mm256_set1_epi16(v); } |
334 | Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4, |
335 | int16_t val5, int16_t val6, int16_t val7, int16_t val8, |
336 | int16_t val9, int16_t val10, int16_t val11, int16_t val12, |
337 | int16_t val13, int16_t val14, int16_t val15, int16_t val16) { |
338 | values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8, |
339 | val9, val10, val11, val12, val13, val14, val15, val16); |
340 | } |
341 | template <int64_t mask> |
342 | static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) { |
343 | __at_align__ int16_t tmp_values[size()]; |
344 | a.store(tmp_values); |
345 | if (mask & 0x01) |
346 | tmp_values[0] = _mm256_extract_epi16(b.values, 0); |
347 | if (mask & 0x02) |
348 | tmp_values[1] = _mm256_extract_epi16(b.values, 1); |
349 | if (mask & 0x04) |
350 | tmp_values[2] = _mm256_extract_epi16(b.values, 2); |
351 | if (mask & 0x08) |
352 | tmp_values[3] = _mm256_extract_epi16(b.values, 3); |
353 | if (mask & 0x10) |
354 | tmp_values[4] = _mm256_extract_epi16(b.values, 4); |
355 | if (mask & 0x20) |
356 | tmp_values[5] = _mm256_extract_epi16(b.values, 5); |
357 | if (mask & 0x40) |
358 | tmp_values[6] = _mm256_extract_epi16(b.values, 6); |
359 | if (mask & 0x80) |
360 | tmp_values[7] = _mm256_extract_epi16(b.values, 7); |
361 | if (mask & 0x100) |
362 | tmp_values[8] = _mm256_extract_epi16(b.values, 8); |
363 | if (mask & 0x200) |
364 | tmp_values[9] = _mm256_extract_epi16(b.values, 9); |
365 | if (mask & 0x400) |
366 | tmp_values[10] = _mm256_extract_epi16(b.values, 10); |
367 | if (mask & 0x800) |
368 | tmp_values[11] = _mm256_extract_epi16(b.values, 11); |
369 | if (mask & 0x1000) |
370 | tmp_values[12] = _mm256_extract_epi16(b.values, 12); |
371 | if (mask & 0x2000) |
372 | tmp_values[13] = _mm256_extract_epi16(b.values, 13); |
373 | if (mask & 0x4000) |
374 | tmp_values[14] = _mm256_extract_epi16(b.values, 14); |
375 | if (mask & 0x8000) |
376 | tmp_values[15] = _mm256_extract_epi16(b.values, 15); |
377 | return loadu(tmp_values); |
378 | } |
379 | static Vectorized<int16_t> blendv(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b, |
380 | const Vectorized<int16_t>& mask) { |
381 | return _mm256_blendv_epi8(a.values, b.values, mask.values); |
382 | } |
383 | template <typename step_t> |
384 | static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) { |
385 | return Vectorized<int16_t>( |
386 | base, base + step, base + 2 * step, base + 3 * step, |
387 | base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, |
388 | base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, |
389 | base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); |
390 | } |
391 | static Vectorized<int16_t> |
392 | set(Vectorized<int16_t> a, Vectorized<int16_t> b, int16_t count = size()) { |
393 | switch (count) { |
394 | case 0: |
395 | return a; |
396 | case 1: |
397 | return blend<1>(a, b); |
398 | case 2: |
399 | return blend<3>(a, b); |
400 | case 3: |
401 | return blend<7>(a, b); |
402 | case 4: |
403 | return blend<15>(a, b); |
404 | case 5: |
405 | return blend<31>(a, b); |
406 | case 6: |
407 | return blend<63>(a, b); |
408 | case 7: |
409 | return blend<127>(a, b); |
410 | case 8: |
411 | return blend<255>(a, b); |
412 | case 9: |
413 | return blend<511>(a, b); |
414 | case 10: |
415 | return blend<1023>(a, b); |
416 | case 11: |
417 | return blend<2047>(a, b); |
418 | case 12: |
419 | return blend<4095>(a, b); |
420 | case 13: |
421 | return blend<8191>(a, b); |
422 | case 14: |
423 | return blend<16383>(a, b); |
424 | case 15: |
425 | return blend<32767>(a, b); |
426 | } |
427 | return b; |
428 | } |
429 | static Vectorized<int16_t> loadu(const void* ptr) { |
430 | return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); |
431 | } |
432 | static Vectorized<int16_t> loadu(const void* ptr, int16_t count) { |
433 | __at_align__ int16_t tmp_values[size()]; |
    // Ensure that uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions, while a loop would be compiled to one instruction.
437 | for (const auto i : c10::irange(size())) { |
438 | tmp_values[i] = 0; |
439 | } |
440 | std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); |
441 | return loadu(tmp_values); |
442 | } |
443 | void store(void* ptr, int count = size()) const { |
444 | if (count == size()) { |
      // ptr need not be aligned here. See
446 | // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html |
447 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); |
448 | } else if (count > 0) { |
449 | __at_align__ int16_t tmp_values[size()]; |
450 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); |
451 | std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); |
452 | } |
453 | } |
454 | const int16_t& operator[](int idx) const = delete; |
455 | int16_t& operator[](int idx) = delete; |
456 | Vectorized<int16_t> abs() const { |
457 | return _mm256_abs_epi16(values); |
458 | } |
459 | Vectorized<int16_t> real() const { |
460 | return *this; |
461 | } |
462 | Vectorized<int16_t> imag() const { |
463 | return _mm256_set1_epi16(0); |
464 | } |
465 | Vectorized<int16_t> conj() const { |
466 | return *this; |
467 | } |
468 | Vectorized<int16_t> neg() const; |
469 | Vectorized<int16_t> operator==(const Vectorized<int16_t>& other) const { |
470 | return _mm256_cmpeq_epi16(values, other.values); |
471 | } |
472 | Vectorized<int16_t> operator!=(const Vectorized<int16_t>& other) const { |
473 | return invert(_mm256_cmpeq_epi16(values, other.values)); |
474 | } |
475 | Vectorized<int16_t> operator<(const Vectorized<int16_t>& other) const { |
476 | return _mm256_cmpgt_epi16(other.values, values); |
477 | } |
478 | Vectorized<int16_t> operator<=(const Vectorized<int16_t>& other) const { |
479 | return invert(_mm256_cmpgt_epi16(values, other.values)); |
480 | } |
481 | Vectorized<int16_t> operator>(const Vectorized<int16_t>& other) const { |
482 | return _mm256_cmpgt_epi16(values, other.values); |
483 | } |
484 | Vectorized<int16_t> operator>=(const Vectorized<int16_t>& other) const { |
485 | return invert(_mm256_cmpgt_epi16(other.values, values)); |
486 | } |
487 | |
488 | Vectorized<int16_t> eq(const Vectorized<int16_t>& other) const; |
489 | Vectorized<int16_t> ne(const Vectorized<int16_t>& other) const; |
490 | Vectorized<int16_t> gt(const Vectorized<int16_t>& other) const; |
491 | Vectorized<int16_t> ge(const Vectorized<int16_t>& other) const; |
492 | Vectorized<int16_t> lt(const Vectorized<int16_t>& other) const; |
493 | Vectorized<int16_t> le(const Vectorized<int16_t>& other) const; |
494 | }; |
495 | |
496 | template <typename T> |
497 | class Vectorized8 : public Vectorizedi { |
498 | static_assert( |
499 | std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, |
500 | "Only int8_t/uint8_t are supported" ); |
501 | protected: |
502 | static const Vectorized<T> ones; |
503 | public: |
504 | using value_type = T; |
505 | static constexpr int size() { |
506 | return 32; |
507 | } |
508 | using Vectorizedi::Vectorizedi; |
509 | Vectorized8() {} |
510 | Vectorized8(T v) { values = _mm256_set1_epi8(v); } |
511 | Vectorized8(T val1, T val2, T val3, T val4, |
512 | T val5, T val6, T val7, T val8, |
513 | T val9, T val10, T val11, T val12, |
514 | T val13, T val14, T val15, T val16, |
515 | T val17, T val18, T val19, T val20, |
516 | T val21, T val22, T val23, T val24, |
517 | T val25, T val26, T val27, T val28, |
518 | T val29, T val30, T val31, T val32) { |
519 | values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8, |
520 | val9, val10, val11, val12, val13, val14, val15, val16, |
521 | val17, val18, val19, val20, val21, val22, val23, val24, |
522 | val25, val26, val27, val28, val29, val30, val31, val32); |
523 | } |
524 | template <int64_t mask> |
525 | static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) { |
526 | __at_align__ T tmp_values[size()]; |
527 | a.store(tmp_values); |
528 | if (mask & 0x01) |
529 | tmp_values[0] = _mm256_extract_epi8(b.values, 0); |
530 | if (mask & 0x02) |
531 | tmp_values[1] = _mm256_extract_epi8(b.values, 1); |
532 | if (mask & 0x04) |
533 | tmp_values[2] = _mm256_extract_epi8(b.values, 2); |
534 | if (mask & 0x08) |
535 | tmp_values[3] = _mm256_extract_epi8(b.values, 3); |
536 | if (mask & 0x10) |
537 | tmp_values[4] = _mm256_extract_epi8(b.values, 4); |
538 | if (mask & 0x20) |
539 | tmp_values[5] = _mm256_extract_epi8(b.values, 5); |
540 | if (mask & 0x40) |
541 | tmp_values[6] = _mm256_extract_epi8(b.values, 6); |
542 | if (mask & 0x80) |
543 | tmp_values[7] = _mm256_extract_epi8(b.values, 7); |
544 | if (mask & 0x100) |
545 | tmp_values[8] = _mm256_extract_epi8(b.values, 8); |
546 | if (mask & 0x200) |
547 | tmp_values[9] = _mm256_extract_epi8(b.values, 9); |
548 | if (mask & 0x400) |
549 | tmp_values[10] = _mm256_extract_epi8(b.values, 10); |
550 | if (mask & 0x800) |
551 | tmp_values[11] = _mm256_extract_epi8(b.values, 11); |
552 | if (mask & 0x1000) |
553 | tmp_values[12] = _mm256_extract_epi8(b.values, 12); |
554 | if (mask & 0x2000) |
555 | tmp_values[13] = _mm256_extract_epi8(b.values, 13); |
556 | if (mask & 0x4000) |
557 | tmp_values[14] = _mm256_extract_epi8(b.values, 14); |
558 | if (mask & 0x8000) |
559 | tmp_values[15] = _mm256_extract_epi8(b.values, 15); |
560 | if (mask & 0x010000) |
561 | tmp_values[16] = _mm256_extract_epi8(b.values, 16); |
562 | if (mask & 0x020000) |
563 | tmp_values[17] = _mm256_extract_epi8(b.values, 17); |
564 | if (mask & 0x040000) |
565 | tmp_values[18] = _mm256_extract_epi8(b.values, 18); |
566 | if (mask & 0x080000) |
567 | tmp_values[19] = _mm256_extract_epi8(b.values, 19); |
568 | if (mask & 0x100000) |
569 | tmp_values[20] = _mm256_extract_epi8(b.values, 20); |
570 | if (mask & 0x200000) |
571 | tmp_values[21] = _mm256_extract_epi8(b.values, 21); |
572 | if (mask & 0x400000) |
573 | tmp_values[22] = _mm256_extract_epi8(b.values, 22); |
574 | if (mask & 0x800000) |
575 | tmp_values[23] = _mm256_extract_epi8(b.values, 23); |
576 | if (mask & 0x1000000) |
577 | tmp_values[24] = _mm256_extract_epi8(b.values, 24); |
578 | if (mask & 0x2000000) |
579 | tmp_values[25] = _mm256_extract_epi8(b.values, 25); |
580 | if (mask & 0x4000000) |
581 | tmp_values[26] = _mm256_extract_epi8(b.values, 26); |
582 | if (mask & 0x8000000) |
583 | tmp_values[27] = _mm256_extract_epi8(b.values, 27); |
584 | if (mask & 0x10000000) |
585 | tmp_values[28] = _mm256_extract_epi8(b.values, 28); |
586 | if (mask & 0x20000000) |
587 | tmp_values[29] = _mm256_extract_epi8(b.values, 29); |
588 | if (mask & 0x40000000) |
589 | tmp_values[30] = _mm256_extract_epi8(b.values, 30); |
590 | if (mask & 0x80000000) |
591 | tmp_values[31] = _mm256_extract_epi8(b.values, 31); |
592 | return loadu(tmp_values); |
593 | } |
594 | static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b, |
595 | const Vectorized<T>& mask) { |
596 | return _mm256_blendv_epi8(a.values, b.values, mask.values); |
597 | } |
598 | template <typename step_t> |
599 | static Vectorized<T> arange(T base = 0, step_t step = static_cast<step_t>(1)) { |
600 | return Vectorized<T>( |
601 | base, base + step, base + 2 * step, base + 3 * step, |
602 | base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, |
603 | base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, |
604 | base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, |
605 | base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, |
606 | base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, |
607 | base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, |
608 | base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step); |
609 | } |
610 | static Vectorized<T> |
611 | set(Vectorized<T> a, Vectorized<T> b, T count = size()) { |
612 | switch (count) { |
613 | case 0: |
614 | return a; |
615 | case 1: |
616 | return blend<0x1>(a, b); |
617 | case 2: |
618 | return blend<0x3>(a, b); |
619 | case 3: |
620 | return blend<0x7>(a, b); |
621 | case 4: |
622 | return blend<0xF>(a, b); |
623 | case 5: |
624 | return blend<0x1F>(a, b); |
625 | case 6: |
626 | return blend<0x3F>(a, b); |
627 | case 7: |
628 | return blend<0x7F>(a, b); |
629 | case 8: |
630 | return blend<0xFF>(a, b); |
631 | case 9: |
632 | return blend<0x1FF>(a, b); |
633 | case 10: |
634 | return blend<0x3FF>(a, b); |
635 | case 11: |
636 | return blend<0x7FF>(a, b); |
637 | case 12: |
638 | return blend<0xFFF>(a, b); |
639 | case 13: |
640 | return blend<0x1FFF>(a, b); |
641 | case 14: |
642 | return blend<0x3FFF>(a, b); |
643 | case 15: |
644 | return blend<0x7FFF>(a, b); |
645 | case 16: |
646 | return blend<0xFFFF>(a, b); |
647 | case 17: |
648 | return blend<0x1FFFF>(a, b); |
649 | case 18: |
650 | return blend<0x3FFFF>(a, b); |
651 | case 19: |
652 | return blend<0x7FFFF>(a, b); |
653 | case 20: |
654 | return blend<0xFFFFF>(a, b); |
655 | case 21: |
656 | return blend<0x1FFFFF>(a, b); |
657 | case 22: |
658 | return blend<0x3FFFFF>(a, b); |
659 | case 23: |
660 | return blend<0x7FFFFF>(a, b); |
661 | case 24: |
662 | return blend<0xFFFFFF>(a, b); |
663 | case 25: |
664 | return blend<0x1FFFFFF>(a, b); |
665 | case 26: |
666 | return blend<0x3FFFFFF>(a, b); |
667 | case 27: |
668 | return blend<0x7FFFFFF>(a, b); |
669 | case 28: |
670 | return blend<0xFFFFFFF>(a, b); |
671 | case 29: |
672 | return blend<0x1FFFFFFF>(a, b); |
673 | case 30: |
674 | return blend<0x3FFFFFFF>(a, b); |
675 | case 31: |
676 | return blend<0x7FFFFFFF>(a, b); |
677 | } |
678 | return b; |
679 | } |
680 | static Vectorized<T> loadu(const void* ptr) { |
681 | return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); |
682 | } |
683 | static Vectorized<T> loadu(const void* ptr, T count) { |
684 | __at_align__ T tmp_values[size()]; |
    // Ensure that uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do
    // not initialize arrays to zero using "={0}" because gcc would compile it
    // to two instructions, while a loop would be compiled to one instruction.
688 | for (const auto i : c10::irange(size())) { |
689 | tmp_values[i] = 0; |
690 | } |
691 | std::memcpy(tmp_values, ptr, count * sizeof(T)); |
692 | return loadu(tmp_values); |
693 | } |
694 | void store(void* ptr, int count = size()) const { |
695 | if (count == size()) { |
      // ptr need not be aligned here. See
697 | // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html |
698 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); |
699 | } else if (count > 0) { |
700 | __at_align__ T tmp_values[size()]; |
701 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); |
702 | std::memcpy(ptr, tmp_values, count * sizeof(T)); |
703 | } |
704 | } |
705 | const T& operator[](int idx) const = delete; |
706 | T& operator[](int idx) = delete; |
707 | Vectorized<T> real() const { |
708 | return *this; |
709 | } |
710 | Vectorized<T> imag() const { |
711 | return _mm256_set1_epi8(0); |
712 | } |
713 | Vectorized<T> conj() const { |
714 | return *this; |
715 | } |
716 | }; |
717 | |
718 | template<> |
719 | class Vectorized<int8_t>: public Vectorized8<int8_t> { |
720 | public: |
721 | using Vectorized8::Vectorized8; |
722 | |
723 | Vectorized<int8_t> neg() const; |
724 | |
725 | Vectorized<int8_t> abs() const { |
726 | return _mm256_abs_epi8(values); |
727 | } |
728 | |
729 | Vectorized<int8_t> operator==(const Vectorized<int8_t>& other) const { |
730 | return _mm256_cmpeq_epi8(values, other.values); |
731 | } |
732 | Vectorized<int8_t> operator!=(const Vectorized<int8_t>& other) const { |
733 | return invert(_mm256_cmpeq_epi8(values, other.values)); |
734 | } |
735 | Vectorized<int8_t> operator<(const Vectorized<int8_t>& other) const { |
736 | return _mm256_cmpgt_epi8(other.values, values); |
737 | } |
738 | Vectorized<int8_t> operator<=(const Vectorized<int8_t>& other) const { |
739 | return invert(_mm256_cmpgt_epi8(values, other.values)); |
740 | } |
741 | Vectorized<int8_t> operator>(const Vectorized<int8_t>& other) const { |
742 | return other < *this; |
743 | } |
744 | Vectorized<int8_t> operator>=(const Vectorized<int8_t>& other) const { |
745 | return other <= *this; |
746 | } |
747 | |
748 | Vectorized<int8_t> eq(const Vectorized<int8_t>& other) const; |
749 | Vectorized<int8_t> ne(const Vectorized<int8_t>& other) const; |
750 | Vectorized<int8_t> gt(const Vectorized<int8_t>& other) const; |
751 | Vectorized<int8_t> ge(const Vectorized<int8_t>& other) const; |
752 | Vectorized<int8_t> lt(const Vectorized<int8_t>& other) const; |
753 | Vectorized<int8_t> le(const Vectorized<int8_t>& other) const; |
754 | }; |
755 | |
756 | template<> |
757 | class Vectorized<uint8_t>: public Vectorized8<uint8_t> { |
758 | public: |
759 | using Vectorized8::Vectorized8; |
760 | |
761 | Vectorized<uint8_t> neg() const; |
762 | |
763 | Vectorized<uint8_t> abs() const { |
764 | return *this; |
765 | } |
766 | |
767 | Vectorized<uint8_t> operator==(const Vectorized<uint8_t>& other) const { |
768 | return _mm256_cmpeq_epi8(values, other.values); |
769 | } |
770 | Vectorized<uint8_t> operator!=(const Vectorized<uint8_t>& other) const { |
771 | return invert(_mm256_cmpeq_epi8(values, other.values)); |
772 | } |
  Vectorized<uint8_t> operator<(const Vectorized<uint8_t>& other) const {
    // AVX2 has no unsigned greater-than compare, so use max_epu8:
    // a < b iff max(a, b) != a.
    __m256i max = _mm256_max_epu8(values, other.values);
    return invert(_mm256_cmpeq_epi8(max, values));
  }
  Vectorized<uint8_t> operator<=(const Vectorized<uint8_t>& other) const {
    // Likewise, a <= b iff max(a, b) == b.
    __m256i max = _mm256_max_epu8(values, other.values);
    return _mm256_cmpeq_epi8(max, other.values);
  }
781 | Vectorized<uint8_t> operator>(const Vectorized<uint8_t>& other) const { |
782 | return other < *this; |
783 | } |
784 | Vectorized<uint8_t> operator>=(const Vectorized<uint8_t>& other) const { |
785 | return other <= *this; |
786 | } |
787 | |
788 | Vectorized<uint8_t> eq(const Vectorized<uint8_t>& other) const; |
789 | Vectorized<uint8_t> ne(const Vectorized<uint8_t>& other) const; |
790 | Vectorized<uint8_t> gt(const Vectorized<uint8_t>& other) const; |
791 | Vectorized<uint8_t> ge(const Vectorized<uint8_t>& other) const; |
792 | Vectorized<uint8_t> lt(const Vectorized<uint8_t>& other) const; |
793 | Vectorized<uint8_t> le(const Vectorized<uint8_t>& other) const; |
794 | }; |
795 | |
796 | template <> |
797 | Vectorized<int64_t> inline operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
798 | return _mm256_add_epi64(a, b); |
799 | } |
800 | |
801 | template <> |
802 | Vectorized<int32_t> inline operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
803 | return _mm256_add_epi32(a, b); |
804 | } |
805 | |
806 | template <> |
807 | Vectorized<int16_t> inline operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
808 | return _mm256_add_epi16(a, b); |
809 | } |
810 | |
811 | template <> |
812 | Vectorized<int8_t> inline operator+(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
813 | return _mm256_add_epi8(a, b); |
814 | } |
815 | |
816 | template <> |
817 | Vectorized<uint8_t> inline operator+(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
818 | return _mm256_add_epi8(a, b); |
819 | } |
820 | |
821 | template <> |
822 | Vectorized<int64_t> inline operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
823 | return _mm256_sub_epi64(a, b); |
824 | } |
825 | |
826 | template <> |
827 | Vectorized<int32_t> inline operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
828 | return _mm256_sub_epi32(a, b); |
829 | } |
830 | |
831 | template <> |
832 | Vectorized<int16_t> inline operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
833 | return _mm256_sub_epi16(a, b); |
834 | } |
835 | |
836 | template <> |
837 | Vectorized<int8_t> inline operator-(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
838 | return _mm256_sub_epi8(a, b); |
839 | } |
840 | |
841 | template <> |
842 | Vectorized<uint8_t> inline operator-(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
843 | return _mm256_sub_epi8(a, b); |
844 | } |
845 | |
846 | // Negation. Defined here so we can utilize operator- |
847 | inline Vectorized<int64_t> Vectorized<int64_t>::neg() const { |
848 | return Vectorized<int64_t>(0) - *this; |
849 | } |
850 | |
851 | inline Vectorized<int32_t> Vectorized<int32_t>::neg() const { |
852 | return Vectorized<int32_t>(0) - *this; |
853 | } |
854 | |
855 | inline Vectorized<int16_t> Vectorized<int16_t>::neg() const { |
856 | return Vectorized<int16_t>(0) - *this; |
857 | } |
858 | |
859 | inline Vectorized<int8_t> Vectorized<int8_t>::neg() const { |
860 | return Vectorized<int8_t>(0) - *this; |
861 | } |
862 | |
863 | inline Vectorized<uint8_t> Vectorized<uint8_t>::neg() const { |
864 | return Vectorized<uint8_t>(0) - *this; |
865 | } |
866 | |
// Emulate operations with no native 64-bit support in AVX
// by extracting each element, performing the operation pointwise,
// then combining the results back into a vector.
870 | template <typename op_t> |
871 | Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const op_t& op) { |
872 | int64_t a0 = _mm256_extract_epi64(a, 0); |
873 | int64_t a1 = _mm256_extract_epi64(a, 1); |
874 | int64_t a2 = _mm256_extract_epi64(a, 2); |
875 | int64_t a3 = _mm256_extract_epi64(a, 3); |
876 | |
877 | int64_t b0 = _mm256_extract_epi64(b, 0); |
878 | int64_t b1 = _mm256_extract_epi64(b, 1); |
879 | int64_t b2 = _mm256_extract_epi64(b, 2); |
880 | int64_t b3 = _mm256_extract_epi64(b, 3); |
881 | |
882 | int64_t c0 = op(a0, b0); |
883 | int64_t c1 = op(a1, b1); |
884 | int64_t c2 = op(a2, b2); |
885 | int64_t c3 = op(a3, b3); |
886 | |
887 | return _mm256_set_epi64x(c3, c2, c1, c0); |
888 | } |
889 | |
890 | template <typename op_t> |
891 | Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const Vectorized<int64_t>& c, const op_t& op) { |
892 | int64_t a0 = _mm256_extract_epi64(a, 0); |
893 | int64_t a1 = _mm256_extract_epi64(a, 1); |
894 | int64_t a2 = _mm256_extract_epi64(a, 2); |
895 | int64_t a3 = _mm256_extract_epi64(a, 3); |
896 | |
897 | int64_t b0 = _mm256_extract_epi64(b, 0); |
898 | int64_t b1 = _mm256_extract_epi64(b, 1); |
899 | int64_t b2 = _mm256_extract_epi64(b, 2); |
900 | int64_t b3 = _mm256_extract_epi64(b, 3); |
901 | |
902 | int64_t c0 = _mm256_extract_epi64(c, 0); |
903 | int64_t c1 = _mm256_extract_epi64(c, 1); |
904 | int64_t c2 = _mm256_extract_epi64(c, 2); |
905 | int64_t c3 = _mm256_extract_epi64(c, 3); |
906 | |
907 | int64_t d0 = op(a0, b0, c0); |
908 | int64_t d1 = op(a1, b1, c1); |
909 | int64_t d2 = op(a2, b2, c2); |
910 | int64_t d3 = op(a3, b3, c3); |
911 | |
912 | return _mm256_set_epi64x(d3, d2, d1, d0); |
913 | } |
914 | |
// AVX2 has no intrinsic for int64_t multiply, so it needs to be emulated.
// This could be implemented more efficiently using epi32 instructions.
// That approach would also technically be AVX compatible, but then we would
// need AVX code for add as well.
// Note: this intentionally ignores undefined behavior like (-lowest * -1).
920 | template <> |
921 | Vectorized<int64_t> inline operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
922 | return emulate(a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ {return a_point * b_point;}); |
923 | } |
924 | |
925 | template <> |
926 | Vectorized<int32_t> inline operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
927 | return _mm256_mullo_epi32(a, b); |
928 | } |
929 | |
930 | template <> |
931 | Vectorized<int16_t> inline operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
932 | return _mm256_mullo_epi16(a, b); |
933 | } |
934 | |
935 | template <typename T, typename Op> |
936 | Vectorized<T> inline int_elementwise_binary_256(const Vectorized<T>& a, const Vectorized<T>& b, Op op) { |
937 | T values_a[Vectorized<T>::size()]; |
938 | T values_b[Vectorized<T>::size()]; |
939 | a.store(values_a); |
940 | b.store(values_b); |
941 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
942 | values_a[i] = op(values_a[i], values_b[i]); |
943 | } |
944 | return Vectorized<T>::loadu(values_a); |
945 | } |
946 | |
947 | template <> |
948 | Vectorized<int8_t> inline operator*(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
949 | // We don't have an instruction for multiplying int8_t |
950 | return int_elementwise_binary_256(a, b, std::multiplies<int8_t>()); |
951 | } |
952 | |
953 | template <> |
954 | Vectorized<uint8_t> inline operator*(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
955 | // We don't have an instruction for multiplying uint8_t |
956 | return int_elementwise_binary_256(a, b, std::multiplies<uint8_t>()); |
957 | } |
958 | |
959 | template <> |
960 | Vectorized<int64_t> inline minimum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
961 | return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);}); |
962 | } |
963 | |
964 | template <> |
965 | Vectorized<int32_t> inline minimum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
966 | return _mm256_min_epi32(a, b); |
967 | } |
968 | |
969 | template <> |
970 | Vectorized<int16_t> inline minimum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
971 | return _mm256_min_epi16(a, b); |
972 | } |
973 | |
974 | template <> |
975 | Vectorized<int8_t> inline minimum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
976 | return _mm256_min_epi8(a, b); |
977 | } |
978 | |
979 | template <> |
980 | Vectorized<uint8_t> inline minimum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
981 | return _mm256_min_epu8(a, b); |
982 | } |
983 | |
984 | template <> |
985 | Vectorized<int64_t> inline maximum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
986 | return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);}); |
987 | } |
988 | |
989 | template <> |
990 | Vectorized<int32_t> inline maximum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
991 | return _mm256_max_epi32(a, b); |
992 | } |
993 | |
994 | template <> |
995 | Vectorized<int16_t> inline maximum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
996 | return _mm256_max_epi16(a, b); |
997 | } |
998 | |
999 | template <> |
1000 | Vectorized<int8_t> inline maximum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
1001 | return _mm256_max_epi8(a, b); |
1002 | } |
1003 | |
1004 | template <> |
1005 | Vectorized<uint8_t> inline maximum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
1006 | return _mm256_max_epu8(a, b); |
1007 | } |
1008 | |
1009 | template <> |
1010 | Vectorized<int64_t> inline clamp(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val, const Vectorized<int64_t>& max_val) { |
1011 | return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));}); |
1012 | } |
1013 | |
1014 | template <> |
1015 | Vectorized<int32_t> inline clamp(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val, const Vectorized<int32_t>& max_val) { |
1016 | return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val)); |
1017 | } |
1018 | |
1019 | template <> |
1020 | Vectorized<int16_t> inline clamp(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val, const Vectorized<int16_t>& max_val) { |
1021 | return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val)); |
1022 | } |
1023 | |
1024 | template <> |
1025 | Vectorized<int8_t> inline clamp(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val, const Vectorized<int8_t>& max_val) { |
1026 | return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val)); |
1027 | } |
1028 | |
1029 | template <> |
1030 | Vectorized<uint8_t> inline clamp(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val, const Vectorized<uint8_t>& max_val) { |
1031 | return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val)); |
1032 | } |
1033 | |
1034 | template <> |
1035 | Vectorized<int64_t> inline clamp_max(const Vectorized<int64_t>& a, const Vectorized<int64_t>& max_val) { |
1036 | return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);}); |
1037 | } |
1038 | |
1039 | template <> |
1040 | Vectorized<int32_t> inline clamp_max(const Vectorized<int32_t>& a, const Vectorized<int32_t>& max_val) { |
1041 | return _mm256_min_epi32(max_val, a); |
1042 | } |
1043 | |
1044 | template <> |
1045 | Vectorized<int16_t> inline clamp_max(const Vectorized<int16_t>& a, const Vectorized<int16_t>& max_val) { |
1046 | return _mm256_min_epi16(max_val, a); |
1047 | } |
1048 | |
1049 | template <> |
1050 | Vectorized<int8_t> inline clamp_max(const Vectorized<int8_t>& a, const Vectorized<int8_t>& max_val) { |
1051 | return _mm256_min_epi8(max_val, a); |
1052 | } |
1053 | |
1054 | template <> |
1055 | Vectorized<uint8_t> inline clamp_max(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& max_val) { |
1056 | return _mm256_min_epu8(max_val, a); |
1057 | } |
1058 | |
1059 | template <> |
1060 | Vectorized<int64_t> inline clamp_min(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val) { |
1061 | return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);}); |
1062 | } |
1063 | |
1064 | template <> |
1065 | Vectorized<int32_t> inline clamp_min(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val) { |
1066 | return _mm256_max_epi32(min_val, a); |
1067 | } |
1068 | |
1069 | template <> |
1070 | Vectorized<int16_t> inline clamp_min(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val) { |
1071 | return _mm256_max_epi16(min_val, a); |
1072 | } |
1073 | |
1074 | template <> |
1075 | Vectorized<int8_t> inline clamp_min(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val) { |
1076 | return _mm256_max_epi8(min_val, a); |
1077 | } |
1078 | |
1079 | template <> |
1080 | Vectorized<uint8_t> inline clamp_min(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val) { |
1081 | return _mm256_max_epu8(min_val, a); |
1082 | } |
1083 | |
1084 | template<typename T> |
1085 | Vectorized<int32_t> inline convert_to_int32(const T* ptr) { |
1086 | return Vectorized<int32_t>::loadu(ptr); |
1087 | } |
1088 | |
1089 | template<> |
1090 | Vectorized<int32_t> inline convert_to_int32<int8_t>(const int8_t* ptr) { |
1091 | return _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr))); |
1092 | } |
1093 | |
1094 | template<> |
1095 | Vectorized<int32_t> inline convert_to_int32<uint8_t>(const uint8_t* ptr) { |
1096 | return _mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr))); |
1097 | } |
1098 | |
1099 | template <> |
1100 | Vectorized<int64_t> inline operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
1101 | return int_elementwise_binary_256(a, b, std::divides<int64_t>()); |
1102 | } |
1103 | template <> |
1104 | Vectorized<int32_t> inline operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
1105 | return int_elementwise_binary_256(a, b, std::divides<int32_t>()); |
1106 | } |
1107 | template <> |
1108 | Vectorized<int16_t> inline operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
1109 | return int_elementwise_binary_256(a, b, std::divides<int16_t>()); |
1110 | } |
1111 | template <> |
1112 | Vectorized<int8_t> inline operator/(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
1113 | return int_elementwise_binary_256(a, b, std::divides<int8_t>()); |
1114 | } |
1115 | template <> |
1116 | Vectorized<uint8_t> inline operator/(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
1117 | return int_elementwise_binary_256(a, b, std::divides<uint8_t>()); |
1118 | } |
1119 | |
1120 | template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
1121 | inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) { |
1122 | return _mm256_and_si256(a, b); |
1123 | } |
1124 | template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
1125 | inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) { |
1126 | return _mm256_or_si256(a, b); |
1127 | } |
1128 | template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
1129 | inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) { |
1130 | return _mm256_xor_si256(a, b); |
1131 | } |
1132 | template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
1133 | inline Vectorized<T> operator~(const Vectorized<T>& a) { |
1134 | return _mm256_xor_si256(a, _mm256_set1_epi32(-1)); |
1135 | } |
1136 | |
1137 | inline Vectorized<int64_t> Vectorized<int64_t>::eq(const Vectorized<int64_t>& other) const { |
1138 | return (*this == other) & Vectorized<int64_t>(1); |
1139 | } |
1140 | |
1141 | inline Vectorized<int64_t> Vectorized<int64_t>::ne(const Vectorized<int64_t>& other) const { |
1142 | return (*this != other) & Vectorized<int64_t>(1); |
1143 | } |
1144 | |
1145 | inline Vectorized<int64_t> Vectorized<int64_t>::gt(const Vectorized<int64_t>& other) const { |
1146 | return (*this > other) & Vectorized<int64_t>(1); |
1147 | } |
1148 | |
1149 | inline Vectorized<int64_t> Vectorized<int64_t>::ge(const Vectorized<int64_t>& other) const { |
1150 | return (*this >= other) & Vectorized<int64_t>(1); |
1151 | } |
1152 | |
1153 | inline Vectorized<int64_t> Vectorized<int64_t>::lt(const Vectorized<int64_t>& other) const { |
1154 | return (*this < other) & Vectorized<int64_t>(1); |
1155 | } |
1156 | |
1157 | inline Vectorized<int64_t> Vectorized<int64_t>::le(const Vectorized<int64_t>& other) const { |
1158 | return (*this <= other) & Vectorized<int64_t>(1); |
1159 | } |
1160 | |
1161 | inline Vectorized<int32_t> Vectorized<int32_t>::eq(const Vectorized<int32_t>& other) const { |
1162 | return (*this == other) & Vectorized<int32_t>(1); |
1163 | } |
1164 | |
1165 | inline Vectorized<int32_t> Vectorized<int32_t>::ne(const Vectorized<int32_t>& other) const { |
1166 | return (*this != other) & Vectorized<int32_t>(1); |
1167 | } |
1168 | |
1169 | inline Vectorized<int32_t> Vectorized<int32_t>::gt(const Vectorized<int32_t>& other) const { |
1170 | return (*this > other) & Vectorized<int32_t>(1); |
1171 | } |
1172 | |
1173 | inline Vectorized<int32_t> Vectorized<int32_t>::ge(const Vectorized<int32_t>& other) const { |
1174 | return (*this >= other) & Vectorized<int32_t>(1); |
1175 | } |
1176 | |
1177 | inline Vectorized<int32_t> Vectorized<int32_t>::lt(const Vectorized<int32_t>& other) const { |
1178 | return (*this < other) & Vectorized<int32_t>(1); |
1179 | } |
1180 | |
1181 | inline Vectorized<int32_t> Vectorized<int32_t>::le(const Vectorized<int32_t>& other) const { |
1182 | return (*this <= other) & Vectorized<int32_t>(1); |
1183 | } |
1184 | |
1185 | inline Vectorized<int16_t> Vectorized<int16_t>::eq(const Vectorized<int16_t>& other) const { |
1186 | return (*this == other) & Vectorized<int16_t>(1); |
1187 | } |
1188 | |
1189 | inline Vectorized<int16_t> Vectorized<int16_t>::ne(const Vectorized<int16_t>& other) const { |
1190 | return (*this != other) & Vectorized<int16_t>(1); |
1191 | } |
1192 | |
1193 | inline Vectorized<int16_t> Vectorized<int16_t>::gt(const Vectorized<int16_t>& other) const { |
1194 | return (*this > other) & Vectorized<int16_t>(1); |
1195 | } |
1196 | |
1197 | inline Vectorized<int16_t> Vectorized<int16_t>::ge(const Vectorized<int16_t>& other) const { |
1198 | return (*this >= other) & Vectorized<int16_t>(1); |
1199 | } |
1200 | |
1201 | inline Vectorized<int16_t> Vectorized<int16_t>::lt(const Vectorized<int16_t>& other) const { |
1202 | return (*this < other) & Vectorized<int16_t>(1); |
1203 | } |
1204 | |
1205 | inline Vectorized<int16_t> Vectorized<int16_t>::le(const Vectorized<int16_t>& other) const { |
1206 | return (*this <= other) & Vectorized<int16_t>(1); |
1207 | } |
1208 | |
1209 | inline Vectorized<int8_t> Vectorized<int8_t>::eq(const Vectorized<int8_t>& other) const { |
1210 | return (*this == other) & Vectorized<int8_t>(1); |
1211 | } |
1212 | |
1213 | inline Vectorized<int8_t> Vectorized<int8_t>::ne(const Vectorized<int8_t>& other) const { |
1214 | return (*this != other) & Vectorized<int8_t>(1); |
1215 | } |
1216 | |
1217 | inline Vectorized<int8_t> Vectorized<int8_t>::gt(const Vectorized<int8_t>& other) const { |
1218 | return (*this > other) & Vectorized<int8_t>(1); |
1219 | } |
1220 | |
1221 | inline Vectorized<int8_t> Vectorized<int8_t>::ge(const Vectorized<int8_t>& other) const { |
1222 | return (*this >= other) & Vectorized<int8_t>(1); |
1223 | } |
1224 | |
1225 | inline Vectorized<int8_t> Vectorized<int8_t>::lt(const Vectorized<int8_t>& other) const { |
1226 | return (*this < other) & Vectorized<int8_t>(1); |
1227 | } |
1228 | |
1229 | inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other) const { |
1230 | return (*this <= other) & Vectorized<int8_t>(1); |
1231 | } |
1232 | |
1233 | inline Vectorized<uint8_t> Vectorized<uint8_t>::eq(const Vectorized<uint8_t>& other) const { |
1234 | return (*this == other) & Vectorized<uint8_t>(1); |
1235 | } |
1236 | |
1237 | inline Vectorized<uint8_t> Vectorized<uint8_t>::ne(const Vectorized<uint8_t>& other) const { |
1238 | return (*this != other) & Vectorized<uint8_t>(1); |
1239 | } |
1240 | |
1241 | inline Vectorized<uint8_t> Vectorized<uint8_t>::gt(const Vectorized<uint8_t>& other) const { |
1242 | return (*this > other) & Vectorized<uint8_t>(1); |
1243 | } |
1244 | |
1245 | inline Vectorized<uint8_t> Vectorized<uint8_t>::ge(const Vectorized<uint8_t>& other) const { |
1246 | return (*this >= other) & Vectorized<uint8_t>(1); |
1247 | } |
1248 | |
1249 | inline Vectorized<uint8_t> Vectorized<uint8_t>::lt(const Vectorized<uint8_t>& other) const { |
1250 | return (*this < other) & Vectorized<uint8_t>(1); |
1251 | } |
1252 | |
1253 | inline Vectorized<uint8_t> Vectorized<uint8_t>::le(const Vectorized<uint8_t>& other) const { |
1254 | return (*this <= other) & Vectorized<uint8_t>(1); |
1255 | } |
1256 | |
1257 | template <bool left_shift> |
1258 | Vectorized<int16_t> inline shift_256_16(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
  // There is no vector instruction for shifting int16_t, so it is
  // emulated here.
1260 | |
1261 | // Control masks for shuffle operation, treating 256 bits as an |
1262 | // array of 16-bit elements, and considering pairs of neighboring |
  // elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
1264 | // M!=N) is set so that shuffle will move element with index M from |
1265 | // input pair into element with index N in output pair, and element |
1266 | // with index M in output pair will be set to all 0s. |
1267 | __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80, |
1268 | 21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80, |
1269 | 13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80, |
1270 | 5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80); |
1271 | __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26, |
1272 | 0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18, |
1273 | 0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10, |
1274 | 0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2); |
1275 | |
1276 | // Masks for bitwise and operation, treating 256 bits as an array of |
1277 | // 16-bit elements, and considering them in pairs of neighboring |
1278 | // elements. A mask named "keep_M" (M in [0,1]) is set so that |
1279 | // bitwise and will copy element with index M from input pair into |
1280 | // element with the same index in output pair, while the other |
1281 | // element in output pair will be set to all 0s. |
1282 | __m256i keep_0 = _mm256_set1_epi32(0xFFFF); |
1283 | __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); |
1284 | |
  // Take each 16-bit element with idx%2==0 from the input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right. Then, perform shifting on this 32-bit number. The upper 16
  // bits will be the proper result of shifting the original 16-bit
  // number, so write them to the result array, into the same position
  // from which the corresponding input element was taken. Also, make
  // sure that result array elements with idx%2!=0 are set to all 0s.
1292 | // |
  // Note that the number of bits to shift by is extended to 32 bits by
  // adding 0s to the left. That means this number is not properly
  // sign-extended for negative values. However, the number of bits to
  // shift by is treated as an unsigned integer by the respective shift
  // intrinsics anyway, so with or without proper sign extension a
  // negative value will be interpreted as a number greater than 32,
  // and the shifting result will be the same.
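  //
  // Worked example (illustrative, left shift): for the 16-bit element
  // 0x1234 at idx%2==0 with shift count 4, the shuffle produces the
  // 32-bit lane 0x12340000; _mm256_sllv_epi32 yields 0x23400000, whose
  // upper 16 bits, 0x2340, equal (0x1234 << 4) truncated to 16 bits.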
1300 | __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); |
1301 | __m256i b0 = _mm256_and_si256(b, keep_0); |
1302 | __m256i c0; |
1303 | if (left_shift) |
1304 | c0 = _mm256_sllv_epi32(a0, b0); |
1305 | else |
1306 | c0 = _mm256_srav_epi32(a0, b0); |
1307 | c0 = _mm256_shuffle_epi8(c0, ctl_1_0); |
1308 | |
  // Perform the shifting the same way for input array elements with
1310 | // idx%2==1. |
1311 | __m256i a1 = _mm256_and_si256(a, keep_1); |
1312 | __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); |
1313 | __m256i c1; |
1314 | if (left_shift) |
1315 | c1 = _mm256_sllv_epi32(a1, b1); |
1316 | else |
1317 | c1 = _mm256_srav_epi32(a1, b1); |
1318 | c1 = _mm256_and_si256(c1, keep_1); |
1319 | |
1320 | // Merge partial results into the final result. |
1321 | __m256i c = _mm256_or_si256(c0, c1); |
1322 | |
1323 | return c; |
1324 | } |
1325 | |
1326 | template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0> |
1327 | Vectorized<T> inline shift_256_8(const Vectorized<T>& a, const Vectorized<T>& b) { |
  // There is no vector instruction for shifting int8_t/uint8_t, so it
  // is emulated here.
1330 | |
1331 | // Control masks for shuffle operation, treating 256 bits as an |
1332 | // array of 8-bit elements, and considering quadruples of |
  // neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
1334 | // in [0,1,2,3], and M!=N) is set so that shuffle will move element |
1335 | // with index M from input quadruple into element with index N in |
1336 | // output quadruple, and other elements in output quadruple will be |
1337 | // set to all 0s. |
1338 | __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80, |
1339 | 20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80, |
1340 | 12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80, |
1341 | 4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80); |
1342 | __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25, |
1343 | 0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17, |
1344 | 0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9, |
1345 | 0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1); |
1346 | __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80, |
1347 | 21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80, |
1348 | 13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80, |
1349 | 5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80); |
1350 | __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26, |
1351 | 0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18, |
1352 | 0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10, |
1353 | 0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2); |
1354 | __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80, |
1355 | 22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80, |
1356 | 14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80, |
1357 | 6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80); |
1358 | __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, |
1359 | 0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, |
1360 | 0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, |
1361 | 0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3); |
1362 | __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, |
1363 | 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, |
1364 | 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, |
1365 | 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80); |
1366 | __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80, |
1367 | 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80, |
1368 | 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80, |
1369 | 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80); |
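// For example, within each quadruple (x0, x1, x2, x3) of 8-bit
// elements, shuffling with ctl_0_3 produces (0, 0, 0, x0), and
// shuffling with ctl_3_0 produces (x3, 0, 0, 0).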
1370 | |
1371 | // Masks for bitwise and operation, treating 256 bits as an array of |
1372 | // 8-bit elements, and considering them in quadruples of neighboring |
1373 | // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that |
1374 | // bitwise and will copy element with index M from input quadruple |
1375 | // into element with the same index in output quadruple, while the |
1376 | // other elements in output quadruple will be set to all 0s. |
1377 | __m256i keep_0 = _mm256_set1_epi32(0xFF); |
1378 | __m256i keep_3 = _mm256_set1_epi32(0xFF000000); |
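// For example, ANDing a quadruple (x0, x1, x2, x3) of 8-bit
// elements with keep_0 produces (x0, 0, 0, 0), and ANDing it with
// keep_3 produces (0, 0, 0, x3).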
1379 | |
// Take each 8-bit element with idx%4==0 from the input array to be
// shifted, and extend it to 32 bits by appending 0s to the right
// (i.e. place it in the top byte of a 32-bit lane). Then, perform
// the shift on this 32-bit number; for right shifts, the signed
// int8_t case uses an arithmetic shift and the unsigned uint8_t
// case uses a logical shift. The upper 8 bits are the proper result
// of shifting the original 8-bit number, so write them to the
// result array, into the same position from which the corresponding
// input element was taken. Also, make sure that result array
// elements with idx%4!=0 are set to all 0s.
//
// Note that the number of bits to shift by is extended to 32 bits
// by adding 0s to the left, which means it is not properly
// sign-extended for negative values. However, the respective shift
// intrinsics treat the shift count as an unsigned integer anyway,
// so a negative count, with or without proper sign extension, is
// interpreted as a number greater than 32, and the shift result is
// the same.
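//
// Worked example (illustrative): for the uint8_t element x0 = 0x90
// and shift count 1, the shuffle produces the 32-bit lane
// 0x90000000; a logical right shift by 1 yields 0x48000000, whose
// top byte, 0x48, is exactly x0 >> 1.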
1395 | __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); |
1396 | __m256i b0 = _mm256_and_si256(b, keep_0); |
1397 | __m256i c0; |
1398 | if (left_shift) |
1399 | c0 = _mm256_sllv_epi32(a0, b0); |
else if (std::is_same<T, int8_t>::value)
  c0 = _mm256_srav_epi32(a0, b0);
else
  c0 = _mm256_srlv_epi32(a0, b0);
1405 | c0 = _mm256_shuffle_epi8(c0, ctl_3_0); |
1406 | |
// Perform shifting the same way for input array elements with
1408 | // idx%4==1. |
1409 | __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); |
1410 | __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); |
1411 | __m256i c1; |
1412 | if (left_shift) |
1413 | c1 = _mm256_sllv_epi32(a1, b1); |
else if (std::is_same<T, int8_t>::value)
  c1 = _mm256_srav_epi32(a1, b1);
else
  c1 = _mm256_srlv_epi32(a1, b1);
1419 | c1 = _mm256_shuffle_epi8(c1, ctl_3_1); |
1420 | |
// Perform shifting the same way for input array elements with
1422 | // idx%4==2. |
1423 | __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); |
1424 | __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); |
1425 | __m256i c2; |
1426 | if (left_shift) |
1427 | c2 = _mm256_sllv_epi32(a2, b2); |
else if (std::is_same<T, int8_t>::value)
  c2 = _mm256_srav_epi32(a2, b2);
else
  c2 = _mm256_srlv_epi32(a2, b2);
1433 | c2 = _mm256_shuffle_epi8(c2, ctl_3_2); |
1434 | |
// Perform shifting the same way for input array elements with
1436 | // idx%4==3. |
1437 | __m256i a3 = _mm256_and_si256(a, keep_3); |
1438 | __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); |
1439 | __m256i c3; |
1440 | if (left_shift) |
1441 | c3 = _mm256_sllv_epi32(a3, b3); |
else if (std::is_same<T, int8_t>::value)
  c3 = _mm256_srav_epi32(a3, b3);
else
  c3 = _mm256_srlv_epi32(a3, b3);
1447 | c3 = _mm256_and_si256(c3, keep_3); |
1448 | |
1449 | // Merge partial results into the final result. |
1450 | __m256i c01 = _mm256_or_si256(c0, c1); |
1451 | __m256i c23 = _mm256_or_si256(c2, c3); |
1452 | __m256i c = _mm256_or_si256(c01, c23); |
1453 | |
1454 | return c; |
1455 | } |
1456 | |
1457 | template <> |
1458 | Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
1459 | return _mm256_sllv_epi64(a, b); |
1460 | } |
1461 | |
1462 | template <> |
1463 | Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
1464 | return _mm256_sllv_epi32(a, b); |
1465 | } |
1466 | |
1467 | template <> |
1468 | Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
1469 | return shift_256_16<true>(a, b); |
1470 | } |
1471 | |
1472 | template <> |
1473 | Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
1474 | return shift_256_8<true>(a, b); |
1475 | } |
1476 | |
1477 | template <> |
1478 | Vectorized<uint8_t> inline operator<<(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
1479 | return shift_256_8<true>(a, b); |
1480 | } |
1481 | |
1482 | template <> |
1483 | Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) { |
// No vector instruction for right shifting int64_t, so it is
// emulated instead.
1486 | |
1487 | // Shift the number logically to the right, thus filling the most |
1488 | // significant bits with 0s. Then, replace these bits with the sign |
1489 | // bit. |
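//
// Worked example (illustrative): for a = -8 (0xFFFFFFFFFFFFFFF8)
// and b = 2, sign_bits is all 1s, the logical shift gives
// 0x3FFFFFFFFFFFFFFE, and sign_ext is all 1s shifted left by
// 64 - 2 = 62, i.e. 0xC000000000000000; ORing the two yields
// 0xFFFFFFFFFFFFFFFE, i.e. -2 == -8 >> 2. For b == 0, the shift
// count of 64 flushes sign_ext to 0, so a is returned unchanged.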
1490 | __m256i sign_bits = _mm256_cmpgt_epi64(_mm256_set1_epi64x(0), a); |
1491 | __m256i b_inv_mod_64 = _mm256_sub_epi64(_mm256_set1_epi64x(64), b); |
1492 | __m256i sign_ext = _mm256_sllv_epi64(sign_bits, b_inv_mod_64); |
1493 | __m256i c = _mm256_srlv_epi64(a, b); |
1494 | c = _mm256_or_si256(c, sign_ext); |
1495 | |
1496 | return c; |
1497 | } |
1498 | |
1499 | template <> |
1500 | Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) { |
1501 | return _mm256_srav_epi32(a, b); |
1502 | } |
1503 | |
1504 | template <> |
1505 | Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) { |
1506 | return shift_256_16<false>(a, b); |
1507 | } |
1508 | |
1509 | template <> |
1510 | Vectorized<int8_t> inline operator>>(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) { |
1511 | return shift_256_8<false>(a, b); |
1512 | } |
1513 | |
1514 | template <> |
1515 | Vectorized<uint8_t> inline operator>>(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) { |
1516 | return shift_256_8<false>(a, b); |
1517 | } |
1518 | |
#endif // CPU_CAPABILITY_AVX2
1520 | |
1521 | }}} |
1522 | |