#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <iostream>

namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {

#ifdef CPU_CAPABILITY_AVX2

struct Vectorizedi {
protected:
  __m256i values;

  static inline __m256i invert(const __m256i& v) {
    const auto ones = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(ones, v);
  }
public:
  Vectorizedi() {}
  Vectorizedi(__m256i v) : values(v) {}
  operator __m256i() const {
    return values;
  }
};

#else

struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined

#endif // CPU_CAPABILITY_AVX2

#ifdef CPU_CAPABILITY_AVX2

template <>
class Vectorized<int64_t> : public Vectorizedi {
private:
  static const Vectorized<int64_t> ones;
public:
  using value_type = int64_t;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int64_t v) { values = _mm256_set1_epi64x(v); }
  Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
    values = _mm256_setr_epi64x(val1, val2, val3, val4);
  }
  template <int64_t mask>
  static Vectorized<int64_t> blend(Vectorized<int64_t> a, Vectorized<int64_t> b) {
    __at_align__ int64_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
    return loadu(tmp_values);
  }
  static Vectorized<int64_t> blendv(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b,
                                    const Vectorized<int64_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vectorized<int64_t>
  set(Vectorized<int64_t> a, Vectorized<int64_t> b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
  static Vectorized<int64_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int64_t> loadu(const void* ptr, int64_t count) {
    __at_align__ int64_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not initialize
    // arrays to zero using "={0}" because gcc would compile it to two instructions while a
    // loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int64_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
    }
  }
  const int64_t& operator[](int idx) const = delete;
  int64_t& operator[](int idx) = delete;
  Vectorized<int64_t> abs() const {
    // Two's-complement abs: (x ^ m) - m, where m is all-ones for negative lanes.
    auto zero = _mm256_set1_epi64x(0);
    auto is_larger = _mm256_cmpgt_epi64(zero, values);
    auto inverse = _mm256_xor_si256(values, is_larger);
    return _mm256_sub_epi64(inverse, is_larger);
  }
  Vectorized<int64_t> real() const {
    return *this;
  }
  Vectorized<int64_t> imag() const {
    return _mm256_set1_epi64x(0);
  }
  Vectorized<int64_t> conj() const {
    return *this;
  }
  Vectorized<int64_t> neg() const;
  Vectorized<int64_t> operator==(const Vectorized<int64_t>& other) const {
    return _mm256_cmpeq_epi64(values, other.values);
  }
  Vectorized<int64_t> operator!=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpeq_epi64(values, other.values));
  }
  Vectorized<int64_t> operator<(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(other.values, values);
  }
  Vectorized<int64_t> operator<=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(values, other.values));
  }
  Vectorized<int64_t> operator>(const Vectorized<int64_t>& other) const {
    return _mm256_cmpgt_epi64(values, other.values);
  }
  Vectorized<int64_t> operator>=(const Vectorized<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(other.values, values));
  }

  Vectorized<int64_t> eq(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ne(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> gt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> ge(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> lt(const Vectorized<int64_t>& other) const;
  Vectorized<int64_t> le(const Vectorized<int64_t>& other) const;
};
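// Illustrative usage only (not part of the original header), assuming an AVX2
// build where Vectorized<int64_t> is available:
//
//   int64_t buf[4] = {10, 20, 30, 40};
//   auto v = Vectorized<int64_t>::loadu(buf);           // {10, 20, 30, 40}
//   auto r = Vectorized<int64_t>::arange(0, 1);         // {0, 1, 2, 3}
//   auto m = Vectorized<int64_t>::blend<0b0101>(v, r);  // {0, 20, 2, 40}
//   m.store(buf, 2);                                    // writes only {0, 20}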

template <>
class Vectorized<int32_t> : public Vectorizedi {
private:
  static const Vectorized<int32_t> ones;
public:
  using value_type = int32_t;
  static constexpr int size() {
    return 8;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int32_t v) { values = _mm256_set1_epi32(v); }
  Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
             int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
    values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  template <int64_t mask>
  static Vectorized<int32_t> blend(Vectorized<int32_t> a, Vectorized<int32_t> b) {
    return _mm256_blend_epi32(a, b, mask);
  }
  static Vectorized<int32_t> blendv(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b,
                                    const Vectorized<int32_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int32_t> arange(int32_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int32_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
  }
  static Vectorized<int32_t>
  set(Vectorized<int32_t> a, Vectorized<int32_t> b, int32_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
  static Vectorized<int32_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int32_t> loadu(const void* ptr, int32_t count) {
    __at_align__ int32_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not initialize
    // arrays to zero using "={0}" because gcc would compile it to two instructions while a
    // loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int32_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
    }
  }
  const int32_t& operator[](int idx) const = delete;
  int32_t& operator[](int idx) = delete;
  Vectorized<int32_t> abs() const {
    return _mm256_abs_epi32(values);
  }
  Vectorized<int32_t> real() const {
    return *this;
  }
  Vectorized<int32_t> imag() const {
    return _mm256_set1_epi32(0);
  }
  Vectorized<int32_t> conj() const {
    return *this;
  }
  Vectorized<int32_t> neg() const;
  Vectorized<int32_t> operator==(const Vectorized<int32_t>& other) const {
    return _mm256_cmpeq_epi32(values, other.values);
  }
  Vectorized<int32_t> operator!=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpeq_epi32(values, other.values));
  }
  Vectorized<int32_t> operator<(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(other.values, values);
  }
  Vectorized<int32_t> operator<=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(values, other.values));
  }
  Vectorized<int32_t> operator>(const Vectorized<int32_t>& other) const {
    return _mm256_cmpgt_epi32(values, other.values);
  }
  Vectorized<int32_t> operator>=(const Vectorized<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(other.values, values));
  }
  Vectorized<int32_t> eq(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ne(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> gt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> ge(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> lt(const Vectorized<int32_t>& other) const;
  Vectorized<int32_t> le(const Vectorized<int32_t>& other) const;
};
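// Illustrative usage only (not part of the original header). For int32_t the
// compile-time blend maps directly onto _mm256_blend_epi32, while blendv picks
// lanes at run time from a comparison mask:
//
//   auto a = Vectorized<int32_t>(0);                  // {0, 0, ..., 0}
//   auto b = Vectorized<int32_t>::arange(1, 1);       // {1, 2, ..., 8}
//   auto c = Vectorized<int32_t>::blend<0x0F>(a, b);  // {1, 2, 3, 4, 0, 0, 0, 0}
//   auto m = a < b;                                   // all-ones mask per lane
//   auto d = Vectorized<int32_t>::blendv(a, b, m);    // equals b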

template <>
inline void convert(const int32_t *src, float *dst, int64_t n) {
  int64_t i;
  // int32_t and float have the same size
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<int32_t>::size()); i += Vectorized<int32_t>::size()) {
    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
    auto output_vec = _mm256_cvtepi32_ps(input_vec);
    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}
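// Illustrative usage only (not part of the original header): the vectorized
// loop handles full 8-element blocks and the scalar tail loop covers the rest,
// so any n >= 0 works:
//
//   int32_t src[11] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
//   float dst[11];
//   convert(src, dst, 11);  // dst[i] == static_cast<float>(src[i]) for all i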

template <>
inline void convert(const int32_t *src, double *dst, int64_t n) {
  int64_t i;
  // int32_t has half the size of double
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}

template <>
class Vectorized<int16_t> : public Vectorizedi {
private:
  static const Vectorized<int16_t> ones;
public:
  using value_type = int16_t;
  static constexpr int size() {
    return 16;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized() {}
  Vectorized(int16_t v) { values = _mm256_set1_epi16(v); }
  Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
             int16_t val5, int16_t val6, int16_t val7, int16_t val8,
             int16_t val9, int16_t val10, int16_t val11, int16_t val12,
             int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
    values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
                               val9, val10, val11, val12, val13, val14, val15, val16);
  }
  template <int64_t mask>
  static Vectorized<int16_t> blend(Vectorized<int16_t> a, Vectorized<int16_t> b) {
    __at_align__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
    return loadu(tmp_values);
  }
  static Vectorized<int16_t> blendv(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b,
                                    const Vectorized<int16_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int16_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
  }
  static Vectorized<int16_t>
  set(Vectorized<int16_t> a, Vectorized<int16_t> b, int16_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
    }
    return b;
  }
  static Vectorized<int16_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<int16_t> loadu(const void* ptr, int16_t count) {
    __at_align__ int16_t tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not initialize
    // arrays to zero using "={0}" because gcc would compile it to two instructions while a
    // loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int16_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
    }
  }
  const int16_t& operator[](int idx) const = delete;
  int16_t& operator[](int idx) = delete;
  Vectorized<int16_t> abs() const {
    return _mm256_abs_epi16(values);
  }
  Vectorized<int16_t> real() const {
    return *this;
  }
  Vectorized<int16_t> imag() const {
    return _mm256_set1_epi16(0);
  }
  Vectorized<int16_t> conj() const {
    return *this;
  }
  Vectorized<int16_t> neg() const;
  Vectorized<int16_t> operator==(const Vectorized<int16_t>& other) const {
    return _mm256_cmpeq_epi16(values, other.values);
  }
  Vectorized<int16_t> operator!=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpeq_epi16(values, other.values));
  }
  Vectorized<int16_t> operator<(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(other.values, values);
  }
  Vectorized<int16_t> operator<=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(values, other.values));
  }
  Vectorized<int16_t> operator>(const Vectorized<int16_t>& other) const {
    return _mm256_cmpgt_epi16(values, other.values);
  }
  Vectorized<int16_t> operator>=(const Vectorized<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(other.values, values));
  }

  Vectorized<int16_t> eq(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ne(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> gt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> ge(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> lt(const Vectorized<int16_t>& other) const;
  Vectorized<int16_t> le(const Vectorized<int16_t>& other) const;
};

template <typename T>
class Vectorized8 : public Vectorizedi {
  static_assert(
    std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
    "Only int8_t/uint8_t are supported");
protected:
  static const Vectorized<T> ones;
public:
  using value_type = T;
  static constexpr int size() {
    return 32;
  }
  using Vectorizedi::Vectorizedi;
  Vectorized8() {}
  Vectorized8(T v) { values = _mm256_set1_epi8(v); }
  Vectorized8(T val1, T val2, T val3, T val4,
              T val5, T val6, T val7, T val8,
              T val9, T val10, T val11, T val12,
              T val13, T val14, T val15, T val16,
              T val17, T val18, T val19, T val20,
              T val21, T val22, T val23, T val24,
              T val25, T val26, T val27, T val28,
              T val29, T val30, T val31, T val32) {
    values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8,
                              val9, val10, val11, val12, val13, val14, val15, val16,
                              val17, val18, val19, val20, val21, val22, val23, val24,
                              val25, val26, val27, val28, val29, val30, val31, val32);
  }
  template <int64_t mask>
  static Vectorized<T> blend(Vectorized<T> a, Vectorized<T> b) {
    __at_align__ T tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi8(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi8(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi8(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi8(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi8(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi8(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi8(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi8(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi8(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi8(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi8(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi8(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi8(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi8(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi8(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi8(b.values, 15);
    if (mask & 0x010000)
      tmp_values[16] = _mm256_extract_epi8(b.values, 16);
    if (mask & 0x020000)
      tmp_values[17] = _mm256_extract_epi8(b.values, 17);
    if (mask & 0x040000)
      tmp_values[18] = _mm256_extract_epi8(b.values, 18);
    if (mask & 0x080000)
      tmp_values[19] = _mm256_extract_epi8(b.values, 19);
    if (mask & 0x100000)
      tmp_values[20] = _mm256_extract_epi8(b.values, 20);
    if (mask & 0x200000)
      tmp_values[21] = _mm256_extract_epi8(b.values, 21);
    if (mask & 0x400000)
      tmp_values[22] = _mm256_extract_epi8(b.values, 22);
    if (mask & 0x800000)
      tmp_values[23] = _mm256_extract_epi8(b.values, 23);
    if (mask & 0x1000000)
      tmp_values[24] = _mm256_extract_epi8(b.values, 24);
    if (mask & 0x2000000)
      tmp_values[25] = _mm256_extract_epi8(b.values, 25);
    if (mask & 0x4000000)
      tmp_values[26] = _mm256_extract_epi8(b.values, 26);
    if (mask & 0x8000000)
      tmp_values[27] = _mm256_extract_epi8(b.values, 27);
    if (mask & 0x10000000)
      tmp_values[28] = _mm256_extract_epi8(b.values, 28);
    if (mask & 0x20000000)
      tmp_values[29] = _mm256_extract_epi8(b.values, 29);
    if (mask & 0x40000000)
      tmp_values[30] = _mm256_extract_epi8(b.values, 30);
    if (mask & 0x80000000)
      tmp_values[31] = _mm256_extract_epi8(b.values, 31);
    return loadu(tmp_values);
  }
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template <typename step_t>
  static Vectorized<T> arange(T base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<T>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step,
      base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step,
      base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step,
      base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step,
      base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step);
  }
  static Vectorized<T>
  set(Vectorized<T> a, Vectorized<T> b, T count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<0x1>(a, b);
      case 2:
        return blend<0x3>(a, b);
      case 3:
        return blend<0x7>(a, b);
      case 4:
        return blend<0xF>(a, b);
      case 5:
        return blend<0x1F>(a, b);
      case 6:
        return blend<0x3F>(a, b);
      case 7:
        return blend<0x7F>(a, b);
      case 8:
        return blend<0xFF>(a, b);
      case 9:
        return blend<0x1FF>(a, b);
      case 10:
        return blend<0x3FF>(a, b);
      case 11:
        return blend<0x7FF>(a, b);
      case 12:
        return blend<0xFFF>(a, b);
      case 13:
        return blend<0x1FFF>(a, b);
      case 14:
        return blend<0x3FFF>(a, b);
      case 15:
        return blend<0x7FFF>(a, b);
      case 16:
        return blend<0xFFFF>(a, b);
      case 17:
        return blend<0x1FFFF>(a, b);
      case 18:
        return blend<0x3FFFF>(a, b);
      case 19:
        return blend<0x7FFFF>(a, b);
      case 20:
        return blend<0xFFFFF>(a, b);
      case 21:
        return blend<0x1FFFFF>(a, b);
      case 22:
        return blend<0x3FFFFF>(a, b);
      case 23:
        return blend<0x7FFFFF>(a, b);
      case 24:
        return blend<0xFFFFFF>(a, b);
      case 25:
        return blend<0x1FFFFFF>(a, b);
      case 26:
        return blend<0x3FFFFFF>(a, b);
      case 27:
        return blend<0x7FFFFFF>(a, b);
      case 28:
        return blend<0xFFFFFFF>(a, b);
      case 29:
        return blend<0x1FFFFFFF>(a, b);
      case 30:
        return blend<0x3FFFFFFF>(a, b);
      case 31:
        return blend<0x7FFFFFFF>(a, b);
    }
    return b;
  }
  static Vectorized<T> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<T> loadu(const void* ptr, T count) {
    __at_align__ T tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not initialize
    // arrays to zero using "={0}" because gcc would compile it to two instructions while a
    // loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0;
    }
    std::memcpy(tmp_values, ptr, count * sizeof(T));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      // ptr need not be aligned here. See
      // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ T tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(T));
    }
  }
  const T& operator[](int idx) const = delete;
  T& operator[](int idx) = delete;
  Vectorized<T> real() const {
    return *this;
  }
  Vectorized<T> imag() const {
    return _mm256_set1_epi8(0);
  }
  Vectorized<T> conj() const {
    return *this;
  }
};

template<>
class Vectorized<int8_t>: public Vectorized8<int8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<int8_t> neg() const;

  Vectorized<int8_t> abs() const {
    return _mm256_abs_epi8(values);
  }

  Vectorized<int8_t> operator==(const Vectorized<int8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<int8_t> operator!=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<int8_t> operator<(const Vectorized<int8_t>& other) const {
    return _mm256_cmpgt_epi8(other.values, values);
  }
  Vectorized<int8_t> operator<=(const Vectorized<int8_t>& other) const {
    return invert(_mm256_cmpgt_epi8(values, other.values));
  }
  Vectorized<int8_t> operator>(const Vectorized<int8_t>& other) const {
    return other < *this;
  }
  Vectorized<int8_t> operator>=(const Vectorized<int8_t>& other) const {
    return other <= *this;
  }

  Vectorized<int8_t> eq(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ne(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> gt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> ge(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> lt(const Vectorized<int8_t>& other) const;
  Vectorized<int8_t> le(const Vectorized<int8_t>& other) const;
};

template<>
class Vectorized<uint8_t>: public Vectorized8<uint8_t> {
public:
  using Vectorized8::Vectorized8;

  Vectorized<uint8_t> neg() const;

  Vectorized<uint8_t> abs() const {
    return *this;
  }

  Vectorized<uint8_t> operator==(const Vectorized<uint8_t>& other) const {
    return _mm256_cmpeq_epi8(values, other.values);
  }
  Vectorized<uint8_t> operator!=(const Vectorized<uint8_t>& other) const {
    return invert(_mm256_cmpeq_epi8(values, other.values));
  }
  Vectorized<uint8_t> operator<(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return invert(_mm256_cmpeq_epi8(max, values));
  }
  Vectorized<uint8_t> operator<=(const Vectorized<uint8_t>& other) const {
    __m256i max = _mm256_max_epu8(values, other.values);
    return _mm256_cmpeq_epi8(max, other.values);
  }
  Vectorized<uint8_t> operator>(const Vectorized<uint8_t>& other) const {
    return other < *this;
  }
  Vectorized<uint8_t> operator>=(const Vectorized<uint8_t>& other) const {
    return other <= *this;
  }

  Vectorized<uint8_t> eq(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ne(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> gt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> ge(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> lt(const Vectorized<uint8_t>& other) const;
  Vectorized<uint8_t> le(const Vectorized<uint8_t>& other) const;
};
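// Illustrative note (not part of the original header): AVX2 only provides
// signed 8-bit compares, so Vectorized<uint8_t>::operator< above is built from
// _mm256_max_epu8 instead: per lane, a < b exactly when max(a, b) != a. For
// example, with lanes a = 200 and b = 100, a signed compare would treat 200 as
// -56 and report a < b, while the max_epu8 formulation correctly reports a > b.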

template <>
Vectorized<int64_t> inline operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_add_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_add_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_add_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator+(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator+(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_add_epi8(a, b);
}

template <>
Vectorized<int64_t> inline operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sub_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sub_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_sub_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator-(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline operator-(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_sub_epi8(a, b);
}

// Negation. Defined here so we can utilize operator-
inline Vectorized<int64_t> Vectorized<int64_t>::neg() const {
  return Vectorized<int64_t>(0) - *this;
}

inline Vectorized<int32_t> Vectorized<int32_t>::neg() const {
  return Vectorized<int32_t>(0) - *this;
}

inline Vectorized<int16_t> Vectorized<int16_t>::neg() const {
  return Vectorized<int16_t>(0) - *this;
}

inline Vectorized<int8_t> Vectorized<int8_t>::neg() const {
  return Vectorized<int8_t>(0) - *this;
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::neg() const {
  return Vectorized<uint8_t>(0) - *this;
}

// Emulate operations with no native 64-bit support in AVX,
// by extracting each element, performing the operation pointwise,
// then combining the results into a vector.
template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = op(a0, b0);
  int64_t c1 = op(a1, b1);
  int64_t c2 = op(a2, b2);
  int64_t c3 = op(a3, b3);

  return _mm256_set_epi64x(c3, c2, c1, c0);
}
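// Illustrative usage only (not part of the original header): emulate() lets a
// scalar lambda stand in for a missing 64-bit intrinsic, e.g. a lane-wise
// absolute difference:
//
//   auto absdiff = emulate(a, b, [](int64_t x, int64_t y) {
//     return x > y ? x - y : y - x;
//   });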

template <typename op_t>
Vectorized<int64_t> inline emulate(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b, const Vectorized<int64_t>& c, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = _mm256_extract_epi64(c, 0);
  int64_t c1 = _mm256_extract_epi64(c, 1);
  int64_t c2 = _mm256_extract_epi64(c, 2);
  int64_t c3 = _mm256_extract_epi64(c, 3);

  int64_t d0 = op(a0, b0, c0);
  int64_t d1 = op(a1, b1, c1);
  int64_t d2 = op(a2, b2, c2);
  int64_t d3 = op(a3, b3, c3);

  return _mm256_set_epi64x(d3, d2, d1, d0);
}

// AVX2 has no intrinsic for int64_t multiply, so it needs to be emulated.
// This could be implemented more efficiently using epi32 instructions.
// This is also technically AVX compatible, but then we'll need AVX code
// for add as well.
// Note: intentionally ignores undefined behavior like (-lowest * -1).
template <>
Vectorized<int64_t> inline operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) __ubsan_ignore_undefined__ {return a_point * b_point;});
}

template <>
Vectorized<int32_t> inline operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_mullo_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_mullo_epi16(a, b);
}

template <typename T, typename Op>
Vectorized<T> inline int_elementwise_binary_256(const Vectorized<T>& a, const Vectorized<T>& b, Op op) {
  T values_a[Vectorized<T>::size()];
  T values_b[Vectorized<T>::size()];
  a.store(values_a);
  b.store(values_b);
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    values_a[i] = op(values_a[i], values_b[i]);
  }
  return Vectorized<T>::loadu(values_a);
}
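// Illustrative usage only (not part of the original header):
// int_elementwise_binary_256 round-trips through stack buffers, so any scalar
// binary functor works, at scalar speed:
//
//   auto rem = int_elementwise_binary_256(a, b, std::modulus<int32_t>());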

template <>
Vectorized<int8_t> inline operator*(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  // We don't have an instruction for multiplying int8_t
  return int_elementwise_binary_256(a, b, std::multiplies<int8_t>());
}

template <>
Vectorized<uint8_t> inline operator*(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  // We don't have an instruction for multiplying uint8_t
  return int_elementwise_binary_256(a, b, std::multiplies<uint8_t>());
}

template <>
Vectorized<int64_t> inline minimum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline minimum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_min_epi32(a, b);
}

template <>
Vectorized<int16_t> inline minimum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_min_epi16(a, b);
}

template <>
Vectorized<int8_t> inline minimum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_min_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline minimum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_min_epu8(a, b);
}

template <>
Vectorized<int64_t> inline maximum(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
}

template <>
Vectorized<int32_t> inline maximum(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_max_epi32(a, b);
}

template <>
Vectorized<int16_t> inline maximum(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return _mm256_max_epi16(a, b);
}

template <>
Vectorized<int8_t> inline maximum(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return _mm256_max_epi8(a, b);
}

template <>
Vectorized<uint8_t> inline maximum(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return _mm256_max_epu8(a, b);
}

template <>
Vectorized<int64_t> inline clamp(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val, const Vectorized<int64_t>& max_val) {
  return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));});
}

template <>
Vectorized<int32_t> inline clamp(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, _mm256_max_epi32(a, min_val));
}

template <>
Vectorized<int16_t> inline clamp(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, _mm256_max_epi16(a, min_val));
}

template <>
Vectorized<int8_t> inline clamp(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, _mm256_max_epi8(a, min_val));
}

template <>
Vectorized<uint8_t> inline clamp(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val));
}
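// Illustrative semantics (not part of the original header): clamp composes the
// two bounds as min(max_val, max(a, min_val)), so for int32_t lanes:
//
//   auto a  = Vectorized<int32_t>(-5, 0, 7, 100, -5, 0, 7, 100);
//   auto lo = Vectorized<int32_t>(0);
//   auto hi = Vectorized<int32_t>(10);
//   clamp(a, lo, hi);  // {0, 0, 7, 10, 0, 0, 7, 10}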

template <>
Vectorized<int64_t> inline clamp_max(const Vectorized<int64_t>& a, const Vectorized<int64_t>& max_val) {
  return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_max(const Vectorized<int32_t>& a, const Vectorized<int32_t>& max_val) {
  return _mm256_min_epi32(max_val, a);
}

template <>
Vectorized<int16_t> inline clamp_max(const Vectorized<int16_t>& a, const Vectorized<int16_t>& max_val) {
  return _mm256_min_epi16(max_val, a);
}

template <>
Vectorized<int8_t> inline clamp_max(const Vectorized<int8_t>& a, const Vectorized<int8_t>& max_val) {
  return _mm256_min_epi8(max_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_max(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& max_val) {
  return _mm256_min_epu8(max_val, a);
}

template <>
Vectorized<int64_t> inline clamp_min(const Vectorized<int64_t>& a, const Vectorized<int64_t>& min_val) {
  return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);});
}

template <>
Vectorized<int32_t> inline clamp_min(const Vectorized<int32_t>& a, const Vectorized<int32_t>& min_val) {
  return _mm256_max_epi32(min_val, a);
}

template <>
Vectorized<int16_t> inline clamp_min(const Vectorized<int16_t>& a, const Vectorized<int16_t>& min_val) {
  return _mm256_max_epi16(min_val, a);
}

template <>
Vectorized<int8_t> inline clamp_min(const Vectorized<int8_t>& a, const Vectorized<int8_t>& min_val) {
  return _mm256_max_epi8(min_val, a);
}

template <>
Vectorized<uint8_t> inline clamp_min(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& min_val) {
  return _mm256_max_epu8(min_val, a);
}

template<typename T>
Vectorized<int32_t> inline convert_to_int32(const T* ptr) {
  return Vectorized<int32_t>::loadu(ptr);
}

template<>
Vectorized<int32_t> inline convert_to_int32<int8_t>(const int8_t* ptr) {
  return _mm256_cvtepi8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}

template<>
Vectorized<int32_t> inline convert_to_int32<uint8_t>(const uint8_t* ptr) {
  return _mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr)));
}

template <>
Vectorized<int64_t> inline operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int64_t>());
}
template <>
Vectorized<int32_t> inline operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int32_t>());
}
template <>
Vectorized<int16_t> inline operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int16_t>());
}
template <>
Vectorized<int8_t> inline operator/(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<int8_t>());
}
template <>
Vectorized<uint8_t> inline operator/(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return int_elementwise_binary_256(a, b, std::divides<uint8_t>());
}

template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_and_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_or_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  return _mm256_xor_si256(a, b);
}
template<class T, typename std::enable_if_t<std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  return _mm256_xor_si256(a, _mm256_set1_epi32(-1));
}
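// Illustrative usage only (not part of the original header): these operators
// apply to every Vectorized<T> derived from Vectorizedi, and operator~ is just
// XOR with an all-ones register, e.g.:
//
//   auto mask = a == b;           // all-ones lanes where equal, zero otherwise
//   auto not_equal_mask = ~mask;  // same lanes as a != b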

inline Vectorized<int64_t> Vectorized<int64_t>::eq(const Vectorized<int64_t>& other) const {
  return (*this == other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ne(const Vectorized<int64_t>& other) const {
  return (*this != other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::gt(const Vectorized<int64_t>& other) const {
  return (*this > other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::ge(const Vectorized<int64_t>& other) const {
  return (*this >= other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::lt(const Vectorized<int64_t>& other) const {
  return (*this < other) & Vectorized<int64_t>(1);
}

inline Vectorized<int64_t> Vectorized<int64_t>::le(const Vectorized<int64_t>& other) const {
  return (*this <= other) & Vectorized<int64_t>(1);
}
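// Illustrative distinction (not part of the original header): operator== and
// friends return full bitmasks (all-ones or zero per lane), while eq()/ne()/...
// normalize them to numeric 0/1 by AND-ing with Vectorized<int64_t>(1):
//
//   Vectorized<int64_t> a(3), b(3);
//   (a == b);  // every lane is -1 (0xFFFFFFFFFFFFFFFF)
//   a.eq(b);   // every lane is 1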

inline Vectorized<int32_t> Vectorized<int32_t>::eq(const Vectorized<int32_t>& other) const {
  return (*this == other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ne(const Vectorized<int32_t>& other) const {
  return (*this != other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::gt(const Vectorized<int32_t>& other) const {
  return (*this > other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::ge(const Vectorized<int32_t>& other) const {
  return (*this >= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::lt(const Vectorized<int32_t>& other) const {
  return (*this < other) & Vectorized<int32_t>(1);
}

inline Vectorized<int32_t> Vectorized<int32_t>::le(const Vectorized<int32_t>& other) const {
  return (*this <= other) & Vectorized<int32_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::eq(const Vectorized<int16_t>& other) const {
  return (*this == other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ne(const Vectorized<int16_t>& other) const {
  return (*this != other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::gt(const Vectorized<int16_t>& other) const {
  return (*this > other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::ge(const Vectorized<int16_t>& other) const {
  return (*this >= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::lt(const Vectorized<int16_t>& other) const {
  return (*this < other) & Vectorized<int16_t>(1);
}

inline Vectorized<int16_t> Vectorized<int16_t>::le(const Vectorized<int16_t>& other) const {
  return (*this <= other) & Vectorized<int16_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::eq(const Vectorized<int8_t>& other) const {
  return (*this == other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ne(const Vectorized<int8_t>& other) const {
  return (*this != other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::gt(const Vectorized<int8_t>& other) const {
  return (*this > other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::ge(const Vectorized<int8_t>& other) const {
  return (*this >= other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::lt(const Vectorized<int8_t>& other) const {
  return (*this < other) & Vectorized<int8_t>(1);
}

inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other) const {
  return (*this <= other) & Vectorized<int8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::eq(const Vectorized<uint8_t>& other) const {
  return (*this == other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ne(const Vectorized<uint8_t>& other) const {
  return (*this != other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::gt(const Vectorized<uint8_t>& other) const {
  return (*this > other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::ge(const Vectorized<uint8_t>& other) const {
  return (*this >= other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::lt(const Vectorized<uint8_t>& other) const {
  return (*this < other) & Vectorized<uint8_t>(1);
}

inline Vectorized<uint8_t> Vectorized<uint8_t>::le(const Vectorized<uint8_t>& other) const {
  return (*this <= other) & Vectorized<uint8_t>(1);
}

template <bool left_shift>
Vectorized<int16_t> inline shift_256_16(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  // No vector instruction for shifting int16_t, so emulating it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 16-bit elements, and considering pairs of neighboring
  // elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
  // M!=N) is set so that shuffle will move element with index M from
  // input pair into element with index N in output pair, and element
  // with index M in output pair will be set to all 0s.
  __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80,
                                    21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80,
                                    13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80,
                                    5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26,
                                    0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18,
                                    0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10,
                                    0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 16-bit elements, and considering them in pairs of neighboring
  // elements. A mask named "keep_M" (M in [0,1]) is set so that
  // bitwise and will copy element with index M from input pair into
  // element with the same index in output pair, while the other
  // element in output pair will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFFFF);
  __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000);

  // Take each 16-bit element with idx%2==0 from input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right. Then, perform shifting on this 32-bit number. Upper 16
  // bits will be proper result of shifting original 16-bit number, so
  // write them to result array, into the same position from which
  // corresponding input element is taken. Also, make sure that
  // result array elements with idx%2!=0 are set to all 0s.
  //
  // Note that number of bits to shift for is extended to 32 bits by
  // adding 0s to the left. That means this number is not properly
  // sign-extended for negative values. However, number of bits to
  // shift is treated as an unsigned integer by respective shift
  // intrinsics anyway so if negative then either with or without
  // proper sign extension, it will be interpreted as a number greater
  // than 32, and the shifting result will be the same.
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    c0 = _mm256_srav_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_1_0);

  // Perform shifting the same way for input array elements with
  // idx%2==1.
  __m256i a1 = _mm256_and_si256(a, keep_1);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    c1 = _mm256_srav_epi32(a1, b1);
  c1 = _mm256_and_si256(c1, keep_1);

  // Merge partial results into the final result.
  __m256i c = _mm256_or_si256(c0, c1);

  return c;
}
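// Illustrative per-lane behaviour (not part of the original header):
// shift_256_16<true> performs a variable left shift and shift_256_16<false> an
// arithmetic right shift, e.g. for single lanes:
//
//   shift_256_16<true>:  a = 0x0001, b = 3        ->  0x0008
//   shift_256_16<false>: a = 0xFFF0 (-16), b = 2  ->  0xFFFC (-4)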

template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0>
Vectorized<T> inline shift_256_8(const Vectorized<T>& a, const Vectorized<T>& b) {
  // No vector instruction for shifting int8_t/uint8_t, so emulating
  // it instead.

  // Control masks for shuffle operation, treating 256 bits as an
  // array of 8-bit elements, and considering quadruples of
  // neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
  // in [0,1,2,3], and M!=N) is set so that shuffle will move element
  // with index M from input quadruple into element with index N in
  // output quadruple, and other elements in output quadruple will be
  // set to all 0s.
  __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80,
                                    20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80,
                                    12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80,
                                    4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80);
  __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25,
                                    0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17,
                                    0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9,
                                    0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1);
  __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80,
                                    21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80,
                                    13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80,
                                    5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80);
  __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26,
                                    0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18,
                                    0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10,
                                    0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2);
  __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80,
                                    22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80,
                                    14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80,
                                    6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80);
  __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27,
                                    0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19,
                                    0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11,
                                    0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3);
  __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80,
                                    0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80,
                                    0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80,
                                    0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80);
  __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80,
                                    0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80,
                                    0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80,
                                    0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80);

  // Masks for bitwise and operation, treating 256 bits as an array of
  // 8-bit elements, and considering them in quadruples of neighboring
  // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that
  // bitwise and will copy element with index M from input quadruple
  // into element with the same index in output quadruple, while the
  // other elements in output quadruple will be set to all 0s.
  __m256i keep_0 = _mm256_set1_epi32(0xFF);
  __m256i keep_3 = _mm256_set1_epi32(0xFF000000);

  // Take each 8-bit element with idx%4==0 from input array to be
  // shifted and extend it to 32 bits so that 0s are added to the
  // right. Then, perform shifting on this 32-bit number. Upper 8
  // bits will be proper result of shifting original 8-bit number, so
  // write them to result array, into the same position from which
  // corresponding input element is taken. Also, make sure that
  // result array elements with idx%4!=0 are set to all 0s.
  //
  // Note that number of bits to shift for is extended to 32 bits by
  // adding 0s to the left. That means this number is not properly
  // sign-extended for negative values. However, number of bits to
  // shift is treated as an unsigned integer by respective shift
  // intrinsics anyway so if negative then either with or without
  // proper sign extension, it will be interpreted as a number greater
  // than 32, and the shifting result will be the same.
  __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3);
  __m256i b0 = _mm256_and_si256(b, keep_0);
  __m256i c0;
  if (left_shift)
    c0 = _mm256_sllv_epi32(a0, b0);
  else
    if (std::is_same<T, int8_t>::value)
      c0 = _mm256_srav_epi32(a0, b0);
    else
      c0 = _mm256_srlv_epi32(a0, b0);
  c0 = _mm256_shuffle_epi8(c0, ctl_3_0);

  // Perform shifting the same way for input array elements with
  // idx%4==1.
  __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
  __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
  __m256i c1;
  if (left_shift)
    c1 = _mm256_sllv_epi32(a1, b1);
  else
    if (std::is_same<T, int8_t>::value)
      c1 = _mm256_srav_epi32(a1, b1);
    else
      c1 = _mm256_srlv_epi32(a1, b1);
  c1 = _mm256_shuffle_epi8(c1, ctl_3_1);

  // Perform shifting the same way for input array elements with
  // idx%4==2.
  __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
  __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
  __m256i c2;
  if (left_shift)
    c2 = _mm256_sllv_epi32(a2, b2);
  else
    if (std::is_same<T, int8_t>::value)
      c2 = _mm256_srav_epi32(a2, b2);
    else
      c2 = _mm256_srlv_epi32(a2, b2);
  c2 = _mm256_shuffle_epi8(c2, ctl_3_2);

  // Perform shifting the same way for input array elements with
  // idx%4==3.
  __m256i a3 = _mm256_and_si256(a, keep_3);
  __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
  __m256i c3;
  if (left_shift)
    c3 = _mm256_sllv_epi32(a3, b3);
  else
    if (std::is_same<T, int8_t>::value)
      c3 = _mm256_srav_epi32(a3, b3);
    else
      c3 = _mm256_srlv_epi32(a3, b3);
  c3 = _mm256_and_si256(c3, keep_3);

  // Merge partial results into the final result.
  __m256i c01 = _mm256_or_si256(c0, c1);
  __m256i c23 = _mm256_or_si256(c2, c3);
  __m256i c = _mm256_or_si256(c01, c23);

  return c;
}

template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return _mm256_sllv_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_sllv_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<true>(a, b);
}

template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<true>(a, b);
}

template <>
Vectorized<uint8_t> inline operator<<(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<true>(a, b);
}

template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  // No vector instruction for right shifting int64_t, so emulating it
  // instead.

  // Shift the number logically to the right, thus filling the most
  // significant bits with 0s. Then, replace these bits with the sign
  // bit.
  __m256i sign_bits = _mm256_cmpgt_epi64(_mm256_set1_epi64x(0), a);
  __m256i b_inv_mod_64 = _mm256_sub_epi64(_mm256_set1_epi64x(64), b);
  __m256i sign_ext = _mm256_sllv_epi64(sign_bits, b_inv_mod_64);
  __m256i c = _mm256_srlv_epi64(a, b);
  c = _mm256_or_si256(c, sign_ext);

  return c;
}
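// Illustrative per-lane behaviour (not part of the original header), showing
// how the logical shift plus re-inserted sign bits reproduce an arithmetic
// shift. For a lane with a = -8 and b = 1:
//
//   _mm256_srlv_epi64:  0xFFFFFFFFFFFFFFF8 >> 1  ->  0x7FFFFFFFFFFFFFFC
//   sign_ext:           all-ones << (64 - 1)     ->  0x8000000000000000
//   OR of the two:                                   0xFFFFFFFFFFFFFFFC == -4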

template <>
Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
  return _mm256_srav_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
  return shift_256_16<false>(a, b);
}

template <>
Vectorized<int8_t> inline operator>>(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
  return shift_256_8<false>(a, b);
}

template <>
Vectorized<uint8_t> inline operator>>(const Vectorized<uint8_t>& a, const Vectorized<uint8_t>& b) {
  return shift_256_8<false>(a, b);
}

#endif // CPU_CAPABILITY_AVX2

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at