1#pragma once
2
3// DO NOT DEFINE STATIC DATA IN THIS HEADER!
4// See Note [Do not compile initializers with AVX]
5//
6// Note [Do not compile initializers with AVX]
7// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8// If you define a static initializer in this file, the initialization will use
9// AVX instructions because these object files are compiled with AVX enabled.
10// We need to avoid non-trivial global data in these architecture specific files
11// because there's no way to guard the global initializers with CPU capability
12// detection.
13//
14// See https://github.com/pytorch/pytorch/issues/37577 for an instance
15// of this bug in the past.
16
#include <cassert>
#include <climits>
#include <cstdint>
#include <cstring>
#include <functional>
#include <cmath>
#include <type_traits>
#include <bitset>

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/copysign.h>
#include <c10/util/math_compat.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <c10/util/Load.h>
37
// These macros helped us unify vec_base.h.
//
// For the active CPU_CAPABILITY:
//   VECTOR_WIDTH  - vector register width in bytes (32 by default, 64 for AVX512)
//   int_vector    - the integer intrinsic register type of that width
//   __at_align__  - aligns a declaration to VECTOR_WIDTH bytes on GCC-compatible
//                   compilers and MSVC; expands to nothing elsewhere.
#ifndef CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 32
#define int_vector __m256i
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#else // CPU_CAPABILITY_AVX512
#define VECTOR_WIDTH 64
#define int_vector __m512i
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#endif // CPU_CAPABILITY_AVX512
60
61namespace at {
62namespace vec {
63// See Note [CPU_CAPABILITY namespace]
64inline namespace CPU_CAPABILITY {
65// at::Half and at::BFloat16 should be treated as floating point
66template <typename T>
67struct is_floating_point:
68 std::integral_constant<bool,
69 std::is_floating_point<T>::value ||
70 std::is_same<T, at::Half>::value ||
71 std::is_same<T, at::BFloat16>::value> {
72};
73
// int_of_size<n>::type is the signed fixed-width integer type that occupies
// exactly n bytes. Only n equal to the sizes of int8_t/int16_t/int32_t/int64_t
// is defined; any other n is an incomplete type.
template<size_t n> struct int_of_size;

template<> struct int_of_size<sizeof(int64_t)> { using type = int64_t; };
template<> struct int_of_size<sizeof(int32_t)> { using type = int32_t; };
template<> struct int_of_size<sizeof(int16_t)> { using type = int16_t; };
template<> struct int_of_size<sizeof(int8_t)> { using type = int8_t; };

// Signed integer type with the same size as T; used to reinterpret a lane of
// T as an integer (e.g. blend masks and gather indices).
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
88
// NOTE: If you specialize on a type, you must define all operations!

// emulates Vectorized types
//
// This is the generic fallback for Vectorized<T>. The lanes are stored in a
// plain array occupying VECTOR_WIDTH bytes (aligned with __at_align__). Every
// operation below is a scalar loop over the lanes. A specialization for a
// particular element type replaces this template and must therefore provide
// every operation itself (see the NOTE above).
#if defined(__s390x__)
template <class T, class TEMP=void>
#else
template <class T>
#endif
struct Vectorized {
private:
  // Lane storage: VECTOR_WIDTH / sizeof(T) elements of T.
  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
public:
  using value_type = T;
  using size_type = int;
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions. First, recall that this class (along with all
  // of its derivations) lives in a per-capability namespace (the inline
  // namespace CPU_CAPABILITY; see Note [CPU_CAPABILITY namespace]). It is
  // intended to be *completely* inlined at its use-sites, because we need to
  // compile it multiple times for different instruction sets.
  //
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition. Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition. Without the
  // definition, you get a compile error like:
  //
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  //
  // In C++17 we could replace a static constexpr variable with an inline
  // variable, which doesn't require a separate definition. When this note was
  // written, the codebase did not require C++17. So the next best thing is to
  // replace the member with a static constexpr (and therefore inline)
  // function, which does not require an ODR definition either.
  //
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it. But it seems that some
  // versions of GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used, and in any case it's hard to tell if a variable
  // is odr-used or not. So best to just cut the problem at the root.
  static constexpr size_type size_T = sizeof(T); // Workaround to compile with VS2022.
  // Number of T lanes in one vector.
  static constexpr size_type size() {
    return VECTOR_WIDTH / size_T;
  }
  // All lanes zero. The first lane is set explicitly; aggregate
  // initialization value-initializes the rest.
  Vectorized() : values{static_cast<T>(0)} {}
  // Broadcast: every lane is set to val. Intentionally non-explicit.
  Vectorized(T val) {
    for (int i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  // Initializes the lanes, in order, from exactly size() arguments.
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) : values{vals...}{
  }
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  }
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  }
  // Return the values as char* for type punning
  auto as_bytes() const -> const char* {
    return reinterpret_cast<const char*>(values);
  }
  // Compile-time blend: lane i comes from b when bit i of mask_ is set,
  // otherwise from a.
  template <int64_t mask_>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    int64_t mask = mask_;
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (mask & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
      mask = mask >> 1;
    }
    return vector;
  }
  // Run-time blend: each mask lane is reinterpreted as a same-sized integer.
  // Lane i comes from b when the lowest bit of that integer is set, otherwise
  // from a. Masks produced by the comparison operators below have every bit of
  // a lane either set or clear, so the lowest bit is representative.
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    Vectorized vector;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (const auto i : c10::irange(size())) {
      if (buffer[i] & 0x01)
      {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Lane i is set to base + i * step.
  template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      vector.values[i] = base + i * step;
    }
    return vector;
  }
  // Lanes [0, count) are taken from b and the remaining lanes from a.
  static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (i < count) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Loads a full vector (VECTOR_WIDTH bytes) from a possibly unaligned ptr.
  static Vectorized<T> loadu(const void* ptr) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
    return vector;
  }
  // Loads the first count lanes from a possibly unaligned ptr. The remaining
  // lanes keep the zeros from the default constructor. count is not
  // range-checked, so callers must ensure count <= size().
  static Vectorized<T> loadu(const void* ptr, int64_t count) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, count * sizeof(T));
    return vector;
  }
  // Copies the first count lanes (all by default) to ptr. count is not
  // range-checked, so callers must ensure count <= size().
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    // NOTE(review): with a 32-bit int, `1 << i` is undefined for i >= 32
    // (e.g. 1-byte T with VECTOR_WIDTH 64). An int cannot represent more than
    // 32 lane bits in any case.
    int mask = 0;
    for (int i = 0; i < size(); ++ i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  // Lane-wise NaN test: every bit of a lane is set when the lane is NaN and
  // cleared otherwise.
  Vectorized<T> isnan() const {
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (_isnan(values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }
  // Applies f to every lane. There are two overloads so that both by-value and
  // by-const-reference unary functions can be passed as function pointers.
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  Vectorized<T> map(T (*const f)(const T &)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  // abs() for types that are neither floating point nor complex.
  template <typename other_t_abs = T,
            typename std::enable_if<!is_floating_point<other_t_abs>::value && !c10::is_complex<other_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  }
  // abs() for floating-point types (including at::Half / at::BFloat16).
  template <typename float_t_abs = T,
            typename std::enable_if<is_floating_point<float_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
    // 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  }
  // abs() for complex types: the magnitude is stored back as a T.
  template <typename complex_t_abs = T,
            typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
    // Specifically map() does not perform the type conversion needed by abs.
    return map([](T x) { return static_cast<T>(std::abs(x)); });
  }

  // sgn() is only provided for complex types here.
  template <typename other_t_sgn = T,
            typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
  Vectorized<T> sgn() const {
    return map(at::native::sgn_impl);
  }

  template <typename other_t_angle = T,
            typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
    return map(at::native::angle_impl<T>);  // compiler is unable to resolve the overload without <T>
  }
  template <typename complex_t_angle = T,
            typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
    return map([](T x) { return static_cast<T>(std::arg(x)); });
  }
  // real() of a non-complex vector is the vector itself.
  template <typename other_t_real = T,
            typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // other_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
    return *this;
  }
  template <typename complex_t_real = T,
            typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
    return map([](T x) { return static_cast<T>(x.real()); });
  }
  // imag() of a non-complex vector is all zeros.
  template <typename other_t_imag = T,
            typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
    return Vectorized(0);
  }
  template <typename complex_t_imag = T,
            typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
    return map([](T x) { return static_cast<T>(x.imag()); });
  }
  // conj() of a non-complex vector is the vector itself.
  template <typename other_t_conj = T,
            typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
    return *this;
  }
  template <typename complex_t_conj = T,
            typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
    return map([](T x) { return static_cast<T>(std::conj(x)); });
  }
  // Lane-wise math functions. Most forward to the corresponding std:: or
  // at::native function via map(); the overload is selected by map()'s
  // function-pointer parameter type.
  Vectorized<T> acos() const {
    return map(std::acos);
  }
  Vectorized<T> asin() const {
    return map(std::asin);
  }
  Vectorized<T> atan() const {
    return map(std::atan);
  }
  // Lane-wise atan2(this[i], exp[i]).
  Vectorized<T> atan2(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  // Lane-wise copysign(this[i], sign[i]); floating-point types only.
  template <
    typename U = T,
    typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    Vectorized<T> ret;
    for (size_type i = 0; i < size(); i++) {
      ret[i] = c10::copysign(values[i], sign[i]);
    }
    return ret;
  }
  Vectorized<T> erf() const {
    return map(std::erf);
  }
  Vectorized<T> erfc() const {
    return map(std::erfc);
  }
  Vectorized<T> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<T> exp() const {
    return map(std::exp);
  }
  Vectorized<T> exp2() const {
    return map(exp2_impl);
  }
  Vectorized<T> expm1() const {
    return map(std::expm1);
  }
  // Fractional part, computed as x - trunc(x).
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
  // Lane-wise fmod(this[i], q[i]); floating-point types only.
  template <
    typename U = T,
    typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T");
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::fmod(values[i], q[i]);
    }
    return ret;
  }
  Vectorized<T> log() const {
    return map(std::log);
  }
  Vectorized<T> log10() const {
    return map(std::log10);
  }
  Vectorized<T> log1p() const {
    return map(std::log1p);
  }
  template <typename other_t_log2 = T,
            typename std::enable_if<!c10::is_complex<other_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_log2, T>::value, "other_t_log2 must be T");
    return map(std::log2);
  }
  // Complex log2, computed as log(x) / log(2).
  template <typename complex_t_log2 = T,
            typename std::enable_if<c10::is_complex<complex_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_log2, T>::value, "complex_t_log2 must be T");
    const T log_2 = T(std::log(2.0));
    return Vectorized(map(std::log))/Vectorized(log_2);
  }
  Vectorized<T> ceil() const {
    return map(at::native::ceil_impl);
  }
  Vectorized<T> cos() const {
    return map(std::cos);
  }
  Vectorized<T> cosh() const {
    return map(std::cosh);
  }
  Vectorized<T> floor() const {
    return map(at::native::floor_impl);
  }
  // Lane-wise hypot(this[i], b[i]).
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::hypot(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> i0() const {
    return map(calc_i0);
  }
  Vectorized<T> i0e() const {
    return map(calc_i0e);
  }
  // Lane-wise calc_igamma(this[i], x[i]).
  Vectorized<T> igamma(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igamma(values[i], x[i]);
    }
    return ret;
  }
  // Lane-wise calc_igammac(this[i], x[i]).
  Vectorized<T> igammac(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igammac(values[i], x[i]);
    }
    return ret;
  }
  Vectorized<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incurring a
    // promotion
    return map([](T x) -> T { return -x; });
  }
  // Lane-wise nextafter(this[i], b[i]).
  Vectorized<T> nextafter(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::nextafter(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> round() const {
    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
    return map(at::native::round_impl);
  }
  Vectorized<T> sin() const {
    return map(std::sin);
  }
  Vectorized<T> sinh() const {
    return map(std::sinh);
  }
  Vectorized<T> tan() const {
    return map(std::tan);
  }
  Vectorized<T> tanh() const {
    return map(std::tanh);
  }
  Vectorized<T> trunc() const {
    return map(at::native::trunc_impl);
  }
  Vectorized<T> lgamma() const {
    return map(std::lgamma);
  }
  Vectorized<T> sqrt() const {
    return map(std::sqrt);
  }
  // 1 / x per lane.
  Vectorized<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  // 1 / sqrt(x) per lane.
  Vectorized<T> rsqrt() const {
    return map([](T x) { return (T)1 / std::sqrt(x); });
  }
  // Lane-wise pow(this[i], exp[i]).
  Vectorized<T> pow(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
private:
  // Applies the binary predicate op lane-wise and encodes the result as a bit
  // mask: all bits of a lane set when op is true, all cleared otherwise.
  template <typename Op>
  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
    // All bits are set to 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (op(values[i], other.values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }

public:
  // Comparison operators returning per-lane bit masks (see binary_pred).
  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }

private:
  // Applies the binary predicate op lane-wise and stores the boolean result
  // converted to T, i.e. 1 when op is true and 0 otherwise.
  template <typename Op>
  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
    // 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int i = 0; i != size(); ++ i) {
      vector[i] = static_cast<T>(op(values[i], other.values[i]));
    }
    return vector;
  }

public:
  // Comparisons returning 1 / 0 values (as T) per lane (see binary_pred_bool).
  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
};
559
560template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
561 Vectorized<T> c;
562 for (int i = 0; i != Vectorized<T>::size(); i++) {
563 c[i] = a[i] + b[i];
564 }
565 return c;
566}
567
568template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
569 Vectorized<T> c;
570 for (int i = 0; i != Vectorized<T>::size(); i++) {
571 c[i] = a[i] - b[i];
572 }
573 return c;
574}
575
576template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
577 Vectorized<T> c;
578 for (int i = 0; i != Vectorized<T>::size(); i++) {
579 c[i] = a[i] * b[i];
580 }
581 return c;
582}
583
584template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
585 Vectorized<T> c;
586 for (int i = 0; i != Vectorized<T>::size(); i++) {
587 c[i] = a[i] / b[i];
588 }
589 return c;
590}
591
592template <class T> Vectorized<T> inline operator||(
593 const Vectorized<T> &a, const Vectorized<T> &b) {
594 Vectorized<T> c;
595 for (int i = 0; i != Vectorized<T>::size(); i++) {
596 c[i] = a[i] || b[i];
597 }
598 return c;
599}
600
601// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
602// either input is a NaN.
603template <class T,
604 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
605Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
606 Vectorized<T> c;
607 for (int i = 0; i != Vectorized<T>::size(); i++) {
608 c[i] = (a[i] > b[i]) ? a[i] : b[i];
609 if (_isnan(a[i])) {
610 // If either input is NaN, propagate a NaN.
611 // NOTE: The case where b[i] was NaN is handled correctly by the naive
612 // ternary operator above.
613 c[i] = a[i];
614 }
615 }
616 return c;
617}
618
619template <class T,
620 typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
621Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
622 Vectorized<T> c;
623 for (int i = 0; i != Vectorized<T>::size(); i++) {
624 c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
625 if (_isnan(a[i])) {
626 // If either input is NaN, propagate a NaN.
627 // NOTE: The case where b[i] was NaN is handled correctly by the naive
628 // ternary operator above.
629 c[i] = a[i];
630 }
631 }
632 return c;
633}
634
635// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
636// either input is a NaN.
637template <class T,
638 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
639Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
640 Vectorized<T> c;
641 for (int i = 0; i != Vectorized<T>::size(); i++) {
642 c[i] = (a[i] < b[i]) ? a[i] : b[i];
643 if (_isnan(a[i])) {
644 // If either input is NaN, propagate a NaN.
645 // NOTE: The case where b[i] was NaN is handled correctly by the naive
646 // ternary operator above.
647 c[i] = a[i];
648 }
649 }
650 return c;
651}
652
653template <class T,
654 typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
655Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
656 Vectorized<T> c;
657 for (int i = 0; i != Vectorized<T>::size(); i++) {
658 c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
659 if (_isnan(a[i])) {
660 // If either input is NaN, propagate a NaN.
661 // NOTE: The case where b[i] was NaN is handled correctly by the naive
662 // ternary operator above.
663 c[i] = a[i];
664 }
665 }
666 return c;
667}
668
669template <class T,
670 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
671Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
672 Vectorized<T> c;
673 for (int i = 0; i != Vectorized<T>::size(); i++) {
674 c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
675 }
676 return c;
677}
678
679template <class T,
680 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
681Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
682 Vectorized<T> c;
683 for (int i = 0; i != Vectorized<T>::size(); i++) {
684 c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
685 }
686 return c;
687}
688
689template <class T,
690 typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
691Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
692 Vectorized<T> c;
693 for (int i = 0; i != Vectorized<T>::size(); i++) {
694 c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
695 }
696 return c;
697}
698
// Forward declaration only. The bitwise operator templates below use
// std::is_base_of<Vectorizedi, Vectorized<T>> to exclude any Vectorized type
// that derives from Vectorizedi. Presumably those are the integer
// specializations, which supply their own bitwise operators; verify against
// their definitions.
struct Vectorizedi;
700
701#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
702template <class T, typename Op>
703static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
704 int_vector buffer;
705#if defined(CPU_CAPABILITY_AVX2)
706 int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
707 int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
708#elif defined(CPU_CAPABILITY_AVX512)
709 int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
710 int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
711#endif
712 buffer = op(a_buffer, b_buffer);
713 __at_align__ T results[Vectorized<T>::size()];
714
715#if defined(CPU_CAPABILITY_AVX2)
716 _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
717#elif defined(CPU_CAPABILITY_AVX512)
718 _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
719#endif
720 return Vectorized<T>::loadu(results);
721}
722
723template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
724inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
725 // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
726#if defined(CPU_CAPABILITY_AVX2)
727 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
728#elif defined(CPU_CAPABILITY_AVX512)
729 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
730#endif
731}
732template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
733inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
734 // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
735#if defined(CPU_CAPABILITY_AVX2)
736 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
737#elif defined(CPU_CAPABILITY_AVX512)
738 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
739#endif
740}
741template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
742inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
743 // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
744#if defined(CPU_CAPABILITY_AVX2)
745 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
746#elif defined(CPU_CAPABILITY_AVX512)
747 return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
748#endif
749}
750
751#else
752
// Reads a T from the (possibly unaligned) bytes at data. memcpy is used
// instead of a pointer cast so that strict aliasing is respected.
template <typename T>
auto load(char const* data) -> T {
  T value;
  std::memcpy(static_cast<void*>(&value), data, sizeof(T));
  return value;
}
759
760template<class T, typename Op>
761static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
762 static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
763 __at_align__ intmax_t buffer[element_no];
764 static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
765 static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
766 // We should be using memcpy in order to respect the strict aliasing rule
767 // see: https://github.com/pytorch/pytorch/issues/66119
768 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
769 // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
770 const auto* a_data = a.as_bytes();
771 const auto* b_data = b.as_bytes();
772 // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
773 for (auto& out : buffer) {
774 out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
775 a_data += sizeof(intmax_t);
776 b_data += sizeof(intmax_t);
777 }
778 assert(a_data == a.as_bytes() + sizeof(a));
779 assert(b_data == b.as_bytes() + sizeof(b));
780 return Vectorized<T>::loadu(buffer);
781}
782
783template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
784inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
785 return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
786}
787template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
788inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
789 return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
790}
791template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
792inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
793 return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
794}
795
796#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
797
798template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
799inline Vectorized<T> operator~(const Vectorized<T>& a) {
800 Vectorized<T> ones; // All bits are 1
801 memset((T*) ones, 0xFF, VECTOR_WIDTH);
802 return a ^ ones;
803}
804
805template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
806 Vectorized<T> c;
807 for (int i = 0; i != Vectorized<T>::size(); i++) {
808 c[i] = a[i] << b[i];
809 }
810 return c;
811}
812
813template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
814 Vectorized<T> c;
815 for (int i = 0; i != Vectorized<T>::size(); i++) {
816 c[i] = a[i] >> b[i];
817 }
818 return c;
819}
820
821template <typename T>
822inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
823 a = a + b;
824 return a;
825}
826template <typename T>
827inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
828 a = a - b;
829 return a;
830}
831template <typename T>
832inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
833 a = a / b;
834 return a;
835}
836template <typename T>
837inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
838 a = a % b;
839 return a;
840}
841template <typename T>
842inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
843 a = a * b;
844 return a;
845}
846
847template <typename T>
848inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
849 a = a << b;
850 return a;
851}
852
853template <typename T>
854inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
855 a = a >> b;
856 return a;
857}
858
859template <typename T>
860inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
861 return a * b + c;
862}
863
864template <typename T>
865inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
866 return a * b - c;
867}
868
869template <int64_t scale = 1, typename T = void>
870std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
871inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
872 static constexpr int size = Vectorized<T>::size();
873 int_same_size_t<T> index_arr[size];
874 vindex.store(static_cast<void*>(index_arr));
875 T buffer[size];
876 for (const auto i : c10::irange(size)) {
877 buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
878 }
879 return Vectorized<T>::loadu(static_cast<void*>(buffer));
880}
881
882template <int64_t scale = 1, typename T = void>
883std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
884inline mask_gather(const Vectorized<T>& src, T const* base_addr,
885 const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
886 static constexpr int size = Vectorized<T>::size();
887 T src_arr[size];
888 int_same_size_t<T> mask_arr[size]; // use int type so we can logical and
889 int_same_size_t<T> index_arr[size];
890 src.store(static_cast<void*>(src_arr));
891 mask.store(static_cast<void*>(mask_arr));
892 vindex.store(static_cast<void*>(index_arr));
893 T buffer[size];
894 for (const auto i : c10::irange(size)) {
895 if (mask_arr[i] & 0x01) { // check highest bit
896 buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
897 } else {
898 buffer[i] = src_arr[i];
899 }
900 }
901 mask = Vectorized<T>(); // "zero out" mask
902 return Vectorized<T>::loadu(static_cast<void*>(buffer));
903}
904
905// Cast a given vector to another type without changing the bits representation.
906// So a Vectorized<double> of 512 bits containing all ones can be cast to a
907// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
908// A Vec<double> of 256 bits containing all ones can be cast to a
909// Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
910// There is a struct here because we don't have static_if and I can't
911// partially specialize a templated function.
912template<typename dst_t, typename src_t>
913struct CastImpl {
914 static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
915 src_t src_arr[Vectorized<src_t>::size()];
916 src.store(static_cast<void*>(src_arr));
917 return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
918 }
919};
920
921template<typename scalar_t>
922struct CastImpl<scalar_t, scalar_t> {
923 static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
924 return src;
925 }
926};
927
928template<typename dst_t, typename src_t>
929inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
930 return CastImpl<dst_t, src_t>::apply(src);
931}
932
933template <typename T>
934inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectorized<T>& src) {
935 static constexpr int size = Vectorized<T>::size();
936 T src_arr[size];
937 src.store(static_cast<void*>(src_arr));
938 int_same_size_t<T> buffer[size];
939 for (const auto i : c10::irange(size)) {
940 buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
941 }
942 return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
943}
944
945// Example inputs for AVX512:
946// a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
947// b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
948// returns:
949// Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
950// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
951// Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
952// b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
953// returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
954// Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
955template <typename T>
956inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
957deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
958 static constexpr int size = Vectorized<T>::size();
959 static constexpr int half_size = size / 2;
960 T a_arr[size];
961 T b_arr[size];
962 T buffer1[size];
963 T buffer2[size];
964 a.store(static_cast<void*>(a_arr));
965 b.store(static_cast<void*>(b_arr));
966 for (const auto i : c10::irange(half_size)) {
967 buffer1[i] = a_arr[i * 2];
968 buffer1[half_size + i] = b_arr[i * 2];
969 buffer2[i] = a_arr[i * 2 + 1];
970 buffer2[half_size + i] = b_arr[i * 2 + 1];
971 }
972 return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
973 Vectorized<T>::loadu(static_cast<void*>(buffer2)));
974}
975
976// inverse operation of deinterleave2
977// Example inputs for AVX512:
978// a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
979// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
980// returns, for AVX512:
981// Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
982// Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
983// Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7}
984// b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7}
985// returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3}
986// Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7}
987template <typename T>
988inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
989interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
990 static constexpr int size = Vectorized<T>::size();
991 static constexpr int half_size = size / 2;
992 T a_arr[size];
993 T b_arr[size];
994 T buffer1[size];
995 T buffer2[size];
996 a.store(static_cast<void*>(a_arr));
997 b.store(static_cast<void*>(b_arr));
998 for (const auto i : c10::irange(half_size)) {
999 buffer1[i * 2] = a_arr[i];
1000 buffer1[i * 2 + 1] = b_arr[i];
1001 buffer2[i * 2] = a_arr[half_size + i];
1002 buffer2[i * 2 + 1] = b_arr[half_size + i];
1003 }
1004 return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
1005 Vectorized<T>::loadu(static_cast<void*>(buffer2)));
1006}
1007
1008template <typename src_T, typename dst_T>
1009inline void convert(const src_T *src, dst_T *dst, int64_t n) {
1010#ifndef _MSC_VER
1011# pragma unroll
1012#endif
1013 for (const auto i : c10::irange(n)) {
1014 (void)i; //Suppress unused variable warning
1015 *dst = c10::convert<dst_T>(c10::load(src));
1016 src++;
1017 dst++;
1018 }
1019}
1020
1021template <typename T>
1022inline Vectorized<T> flip(const Vectorized<T> & data) {
1023 static constexpr int size = Vectorized<T>::size();
1024 T output[size];
1025 T buffer[size];
1026 data.store(static_cast<void*>(buffer));
1027 for (const auto i : c10::irange(size)) {
1028 output[i] = buffer[size - i - 1];
1029 }
1030 return Vectorized<T>::loadu(static_cast<void*>(output));
1031}
1032
// Write the transpose of the M x N tile `src` into `dst`, i.e.
// dst[j * ld_dst + i] = src[i * ld_src + j]. `ld_src` and `ld_dst` are the
// leading dimensions (row strides, in elements) of `src` and `dst`.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int row = 0; row < M; ++row) {
    const T* src_row = src + row * ld_src;
    for (int col = 0; col < N; ++col) {
      dst[col * ld_dst + row] = src_row[col];
    }
  }
}
1043
1044}}}
1045