1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | // |
6 | // Note [Do not compile initializers with AVX] |
7 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
8 | // If you define a static initializer in this file, the initialization will use |
9 | // AVX instructions because these object files are compiled with AVX enabled. |
10 | // We need to avoid non-trivial global data in these architecture specific files |
11 | // because there's no way to guard the global initializers with CPU capability |
12 | // detection. |
13 | // |
14 | // See https://github.com/pytorch/pytorch/issues/37577 for an instance |
15 | // of this bug in the past. |
16 | |
17 | #include <cassert> |
18 | #include <cstring> |
19 | #include <functional> |
20 | #include <cmath> |
21 | #include <type_traits> |
22 | #include <bitset> |
23 | |
24 | #include <ATen/cpu/vec/intrinsics.h> |
25 | #include <ATen/native/Math.h> |
26 | #include <ATen/NumericUtils.h> |
27 | #include <c10/util/C++17.h> |
28 | #include <c10/util/BFloat16.h> |
29 | #include <c10/util/BFloat16-math.h> |
30 | #include <c10/util/copysign.h> |
31 | #include <c10/util/math_compat.h> |
32 | #include <ATen/native/cpu/zmath.h> |
33 | #include <c10/util/TypeCast.h> |
34 | #include <c10/macros/Macros.h> |
35 | #include <c10/util/irange.h> |
36 | #include <c10/util/Load.h> |
37 | |
38 | // These macros helped us unify vec_base.h |
39 | #ifdef CPU_CAPABILITY_AVX512 |
40 | #if defined(__GNUC__) |
41 | #define __at_align__ __attribute__((aligned(64))) |
42 | #elif defined(_WIN32) |
43 | #define __at_align__ __declspec(align(64)) |
44 | #else |
45 | #define __at_align__ |
46 | #endif |
47 | #define VECTOR_WIDTH 64 |
48 | #define int_vector __m512i |
49 | #else // CPU_CAPABILITY_AVX512 |
50 | #if defined(__GNUC__) |
51 | #define __at_align__ __attribute__((aligned(32))) |
52 | #elif defined(_WIN32) |
53 | #define __at_align__ __declspec(align(32)) |
54 | #else |
55 | #define __at_align__ |
56 | #endif |
57 | #define VECTOR_WIDTH 32 |
58 | #define int_vector __m256i |
59 | #endif // CPU_CAPABILITY_AVX512 |
60 | |
61 | namespace at { |
62 | namespace vec { |
63 | // See Note [CPU_CAPABILITY namespace] |
64 | inline namespace CPU_CAPABILITY { |
65 | // at::Half and at::BFloat16 should be treated as floating point |
66 | template <typename T> |
67 | struct is_floating_point: |
68 | std::integral_constant<bool, |
69 | std::is_floating_point<T>::value || |
70 | std::is_same<T, at::Half>::value || |
71 | std::is_same<T, at::BFloat16>::value> { |
72 | }; |
73 | |
// Maps a byte width to the signed integer type of exactly that width.
// Only the widths of the fixed-size integer types are specialized; any
// other width is a compile error (the primary template is left undefined).
template<size_t n> struct int_of_size;

template<> struct int_of_size<sizeof(int64_t)> { using type = int64_t; };
template<> struct int_of_size<sizeof(int32_t)> { using type = int32_t; };
template<> struct int_of_size<sizeof(int16_t)> { using type = int16_t; };
template<> struct int_of_size<sizeof(int8_t)>  { using type = int8_t; };

// Convenience alias: the signed integer type occupying the same number of
// bytes as T (used e.g. to reinterpret float masks as integers).
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
88 | |
89 | // NOTE: If you specialize on a type, you must define all operations! |
90 | |
91 | // emulates Vectorized types |
#if defined(__s390x__)
template <class T, class TEMP=void>
#else
template <class T>
#endif
struct Vectorized {
private:
  // Backing storage: one full SIMD register's worth of scalars
  // (VECTOR_WIDTH bytes). __at_align__ keeps it register-aligned so the
  // architecture-specific code paths can load/store it directly.
  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
public:
  using value_type = T;
  using size_type = int;
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions. First, recall that this class (along with all
  // of its derivations) live in an anonymous namespace: they are intended to be
  // *completely* inlined at their use-sites, because we need to compile it
  // multiple times for different instruction sets.
  //
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition. Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition. Without the
  // definition, you get a compile error like:
  //
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  //
  // If this were C++17, we could replace a static constexpr variable with
  // an inline variable which doesn't require one definition. But we are not
  // C++17. So the next best thing is to replace the member with a static
  // constexpr (and therefore inline) function, which does not require ODR
  // either.
  //
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it. But it seems that some
  // versions GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used or not, and in any case it's hard to tell if
  // a variable is odr-used or not. So best to just cut the problem at the root.
  static constexpr size_type size_T = sizeof(T); // Workaround to compile with VS2022.
  // Number of scalar elements held in one Vectorized.
  static constexpr size_type size() {
    return VECTOR_WIDTH / size_T;
  }
  // Zero-initializes all elements: the first explicitly, the rest via
  // aggregate value-initialization.
  Vectorized() : values{static_cast<T>(0)} {}
  // Broadcast constructor: every element becomes `val`.
  Vectorized(T val) {
    for (int i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  // Element-wise constructor; SFINAE requires exactly size() arguments.
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) : values{vals...}{
  }
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  }
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  }
  // Return the values as char* for type punning
  auto as_bytes() const -> const char* {
    return reinterpret_cast<const char*>(values);
  }
  // Compile-time select: element i comes from `b` when bit i of mask_ is 1,
  // otherwise from `a`.
  template <int64_t mask_>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    int64_t mask = mask_;
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (mask & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
      mask = mask >> 1;
    }
    return vector;
  }
  // Runtime select: element i comes from `b` when the lowest bit of mask's
  // i-th element (viewed as a same-size integer) is set, otherwise from `a`.
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    Vectorized vector;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (const auto i : c10::irange(size())) {
      if (buffer[i] & 0x01)
      {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Produces {base, base + step, base + 2*step, ...}.
  template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      vector.values[i] = base + i * step;
    }
    return vector;
  }
  // First `count` elements taken from `b`, the remainder from `a`.
  static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (i < count) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  // Unaligned load of a full vector from `ptr`.
  static Vectorized<T> loadu(const void* ptr) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
    return vector;
  }
  // Unaligned load of `count` elements; the tail stays zeroed (see the
  // default constructor).
  static Vectorized<T> loadu(const void* ptr, int64_t count) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, count * sizeof(T));
    return vector;
  }
  // Unaligned store of the first `count` elements to `ptr`.
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    int mask = 0;
    for (int i = 0; i < size(); ++ i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  // Per-element NaN test: a NaN element yields all-ones bits, otherwise
  // all-zero bits (matching the SIMD comparison-mask convention).
  Vectorized<T> isnan() const {
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (_isnan(values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }
  // Applies `f` element-wise (by-value overload).
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  // Applies `f` element-wise (by-const-reference overload).
  Vectorized<T> map(T (*const f)(const T &)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  template <typename other_t_abs = T,
            typename std::enable_if<!is_floating_point<other_t_abs>::value && !c10::is_complex<other_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T" );
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  }
  template <typename float_t_abs = T,
            typename std::enable_if<is_floating_point<float_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T" );
    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
    // 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  }
  template <typename complex_t_abs = T,
            typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T" );
    // Specifically map() does not perform the type conversion needed by abs.
    return map([](T x) { return static_cast<T>(std::abs(x)); });
  }

  template <typename other_t_sgn = T,
            typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
  Vectorized<T> sgn() const {
    return map(at::native::sgn_impl);
  }

  template <typename other_t_angle = T,
            typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T" );
    return map(at::native::angle_impl<T>);  // compiler is unable to resolve the overload without <T>
  }
  template <typename complex_t_angle = T,
            typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T" );
    return map([](T x) { return static_cast<T>(std::arg(x)); });
  }
  template <typename other_t_real = T,
            typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // other_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T" );
    return *this;
  }
  template <typename complex_t_real = T,
            typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T" );
    return map([](T x) { return static_cast<T>(x.real()); });
  }
  template <typename other_t_imag = T,
            typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T" );
    return Vectorized(0);
  }
  template <typename complex_t_imag = T,
            typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T" );
    return map([](T x) { return static_cast<T>(x.imag()); });
  }
  template <typename other_t_conj = T,
            typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T" );
    return *this;
  }
  template <typename complex_t_conj = T,
            typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T" );
    return map([](T x) { return static_cast<T>(std::conj(x)); });
  }
  Vectorized<T> acos() const {
    return map(std::acos);
  }
  Vectorized<T> asin() const {
    return map(std::asin);
  }
  Vectorized<T> atan() const {
    return map(std::atan);
  }
  Vectorized<T> atan2(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  template <
      typename U = T,
      typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    Vectorized<T> ret;
    for (size_type i = 0; i < size(); i++) {
      ret[i] = c10::copysign(values[i], sign[i]);
    }
    return ret;
  }
  Vectorized<T> erf() const {
    return map(std::erf);
  }
  Vectorized<T> erfc() const {
    return map(std::erfc);
  }
  Vectorized<T> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<T> exp() const {
    return map(std::exp);
  }
  Vectorized<T> exp2() const {
    return map(exp2_impl);
  }
  Vectorized<T> expm1() const {
    return map(std::expm1);
  }
  // Fractional part: x - trunc(x).
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
  template <
      typename U = T,
      typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T" );
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::fmod(values[i], q[i]);
    }
    return ret;
  }
  Vectorized<T> log() const {
    return map(std::log);
  }
  Vectorized<T> log10() const {
    return map(std::log10);
  }
  Vectorized<T> log1p() const {
    return map(std::log1p);
  }
  template <typename other_t_log2 = T,
            typename std::enable_if<!c10::is_complex<other_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_log2, T>::value, "other_t_log2 must be T" );
    return map(std::log2);
  }
  template <typename complex_t_log2 = T,
            typename std::enable_if<c10::is_complex<complex_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_log2, T>::value, "complex_t_log2 must be T" );
    // There is no std::log2 for complex; use log(x)/log(2) instead.
    const T log_2 = T(std::log(2.0));
    return Vectorized(map(std::log))/Vectorized(log_2);
  }
  Vectorized<T> ceil() const {
    return map(at::native::ceil_impl);
  }
  Vectorized<T> cos() const {
    return map(std::cos);
  }
  Vectorized<T> cosh() const {
    return map(std::cosh);
  }
  Vectorized<T> floor() const {
    return map(at::native::floor_impl);
  }
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::hypot(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> i0() const {
    return map(calc_i0);
  }
  Vectorized<T> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<T> igamma(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igamma(values[i], x[i]);
    }
    return ret;
  }
  Vectorized<T> igammac(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igammac(values[i], x[i]);
    }
    return ret;
  }
  Vectorized<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incurring a
    // promotion
    return map([](T x) -> T { return -x; });
  }
  Vectorized<T> nextafter(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::nextafter(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> round() const {
    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
    return map(at::native::round_impl);
  }
  Vectorized<T> sin() const {
    return map(std::sin);
  }
  Vectorized<T> sinh() const {
    return map(std::sinh);
  }
  Vectorized<T> tan() const {
    return map(std::tan);
  }
  Vectorized<T> tanh() const {
    return map(std::tanh);
  }
  Vectorized<T> trunc() const {
    return map(at::native::trunc_impl);
  }
  Vectorized<T> lgamma() const {
    return map(std::lgamma);
  }
  Vectorized<T> sqrt() const {
    return map(std::sqrt);
  }
  Vectorized<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  Vectorized<T> rsqrt() const {
    return map([](T x) { return (T)1 / std::sqrt(x); });
  }
  Vectorized<T> pow(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
private:
  template <typename Op>
  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
    // All bits are set to 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (op(values[i], other.values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }

public:
  // SIMD-style comparisons: produce full bit masks per element.
  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }

private:
  template <typename Op>
  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
    // 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int i = 0; i != size(); ++ i) {
      vector[i] = static_cast<T>(op(values[i], other.values[i]));
    }
    return vector;
  }

public:
  // Boolean-style comparisons: produce 1/0 values (not bit masks).
  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
};
559 | |
560 | template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) { |
561 | Vectorized<T> c; |
562 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
563 | c[i] = a[i] + b[i]; |
564 | } |
565 | return c; |
566 | } |
567 | |
568 | template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) { |
569 | Vectorized<T> c; |
570 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
571 | c[i] = a[i] - b[i]; |
572 | } |
573 | return c; |
574 | } |
575 | |
576 | template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) { |
577 | Vectorized<T> c; |
578 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
579 | c[i] = a[i] * b[i]; |
580 | } |
581 | return c; |
582 | } |
583 | |
584 | template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ { |
585 | Vectorized<T> c; |
586 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
587 | c[i] = a[i] / b[i]; |
588 | } |
589 | return c; |
590 | } |
591 | |
592 | template <class T> Vectorized<T> inline operator||( |
593 | const Vectorized<T> &a, const Vectorized<T> &b) { |
594 | Vectorized<T> c; |
595 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
596 | c[i] = a[i] || b[i]; |
597 | } |
598 | return c; |
599 | } |
600 | |
601 | // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if |
602 | // either input is a NaN. |
603 | template <class T, |
604 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
605 | Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) { |
606 | Vectorized<T> c; |
607 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
608 | c[i] = (a[i] > b[i]) ? a[i] : b[i]; |
609 | if (_isnan(a[i])) { |
610 | // If either input is NaN, propagate a NaN. |
611 | // NOTE: The case where b[i] was NaN is handled correctly by the naive |
612 | // ternary operator above. |
613 | c[i] = a[i]; |
614 | } |
615 | } |
616 | return c; |
617 | } |
618 | |
619 | template <class T, |
620 | typename std::enable_if<c10::is_complex<T>::value, int>::type = 0> |
621 | Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) { |
622 | Vectorized<T> c; |
623 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
624 | c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; |
625 | if (_isnan(a[i])) { |
626 | // If either input is NaN, propagate a NaN. |
627 | // NOTE: The case where b[i] was NaN is handled correctly by the naive |
628 | // ternary operator above. |
629 | c[i] = a[i]; |
630 | } |
631 | } |
632 | return c; |
633 | } |
634 | |
635 | // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if |
636 | // either input is a NaN. |
637 | template <class T, |
638 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
639 | Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) { |
640 | Vectorized<T> c; |
641 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
642 | c[i] = (a[i] < b[i]) ? a[i] : b[i]; |
643 | if (_isnan(a[i])) { |
644 | // If either input is NaN, propagate a NaN. |
645 | // NOTE: The case where b[i] was NaN is handled correctly by the naive |
646 | // ternary operator above. |
647 | c[i] = a[i]; |
648 | } |
649 | } |
650 | return c; |
651 | } |
652 | |
653 | template <class T, |
654 | typename std::enable_if<c10::is_complex<T>::value, int>::type = 0> |
655 | Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) { |
656 | Vectorized<T> c; |
657 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
658 | c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; |
659 | if (_isnan(a[i])) { |
660 | // If either input is NaN, propagate a NaN. |
661 | // NOTE: The case where b[i] was NaN is handled correctly by the naive |
662 | // ternary operator above. |
663 | c[i] = a[i]; |
664 | } |
665 | } |
666 | return c; |
667 | } |
668 | |
669 | template <class T, |
670 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
671 | Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) { |
672 | Vectorized<T> c; |
673 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
674 | c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); |
675 | } |
676 | return c; |
677 | } |
678 | |
679 | template <class T, |
680 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
681 | Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) { |
682 | Vectorized<T> c; |
683 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
684 | c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; |
685 | } |
686 | return c; |
687 | } |
688 | |
689 | template <class T, |
690 | typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0> |
691 | Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) { |
692 | Vectorized<T> c; |
693 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
694 | c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; |
695 | } |
696 | return c; |
697 | } |
698 | |
699 | struct Vectorizedi; |
700 | |
701 | #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) |
// Applies a bitwise operation `op` to the raw bit patterns of `a` and `b` by
// round-tripping through integer vector registers. Aligned loads/stores are
// safe here because Vectorized's storage carries __at_align__ alignment.
template <class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  int_vector buffer;
#if defined(CPU_CAPABILITY_AVX2)
  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
#elif defined(CPU_CAPABILITY_AVX512)
  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
#endif
  buffer = op(a_buffer, b_buffer);
  // Spill the integer result to an aligned scalar buffer, then reinterpret
  // it back as a Vectorized<T>.
  __at_align__ T results[Vectorized<T>::size()];

#if defined(CPU_CAPABILITY_AVX2)
  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
#elif defined(CPU_CAPABILITY_AVX512)
  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
#endif
  return Vectorized<T>::loadu(results);
}
722 | |
// Bitwise AND on raw bit patterns. Disabled (via SFINAE) when Vectorized<T>
// derives from Vectorizedi.
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
#endif
}
// Bitwise OR on raw bit patterns. Disabled (via SFINAE) when Vectorized<T>
// derives from Vectorizedi.
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
#endif
}
// Bitwise XOR on raw bit patterns. Disabled (via SFINAE) when Vectorized<T>
// derives from Vectorizedi.
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
#endif
}
750 | |
751 | #else |
752 | |
// Reads a T from an arbitrary byte buffer via memcpy, which is the
// strict-aliasing-safe way to reinterpret raw bytes.
template <typename T>
auto load(char const* data) -> T {
  T value;
  std::memcpy(&value, data, sizeof(value));
  return value;
}
759 | |
// Scalar fallback for bitwise binary ops (no AVX2/AVX512): view both vectors
// as arrays of intmax_t chunks, apply `op` chunk by chunk into an aligned
// scratch buffer, and reload the buffer as a Vectorized<T>.
template<class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
  __at_align__ intmax_t buffer[element_no];
  static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)" );
  static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)" );
  // We should be using memcpy in order to respect the strict aliasing rule
  // see: https://github.com/pytorch/pytorch/issues/66119
  // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
  // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
  const auto* a_data = a.as_bytes();
  const auto* b_data = b.as_bytes();
  // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
  for (auto& out : buffer) {
    out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
    a_data += sizeof(intmax_t);
    b_data += sizeof(intmax_t);
  }
  // Sanity check: both byte cursors must have walked exactly one full vector.
  assert(a_data == a.as_bytes() + sizeof(a));
  assert(b_data == b.as_bytes() + sizeof(b));
  return Vectorized<T>::loadu(buffer);
}
782 | |
783 | template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
784 | inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) { |
785 | return bitwise_binary_op(a, b, std::bit_and<intmax_t>()); |
786 | } |
787 | template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
788 | inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) { |
789 | return bitwise_binary_op(a, b, std::bit_or<intmax_t>()); |
790 | } |
791 | template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0> |
792 | inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) { |
793 | return bitwise_binary_op(a, b, std::bit_xor<intmax_t>()); |
794 | } |
795 | |
796 | #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) |
797 | |
// Bitwise NOT for non-integer Vectorized types, implemented as XOR against an
// all-ones vector (a ^ 1...1 == ~a).
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  Vectorized<T> ones; // All bits are 1
  // NOTE(review): the C-style cast goes through Vectorized<T>'s pointer
  // conversion operator (and may shed constness) so memset can fill the
  // underlying VECTOR_WIDTH bytes with 0xFF.
  memset((T*) ones, 0xFF, VECTOR_WIDTH);
  return a ^ ones;
}
804 | |
805 | template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) { |
806 | Vectorized<T> c; |
807 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
808 | c[i] = a[i] << b[i]; |
809 | } |
810 | return c; |
811 | } |
812 | |
813 | template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) { |
814 | Vectorized<T> c; |
815 | for (int i = 0; i != Vectorized<T>::size(); i++) { |
816 | c[i] = a[i] >> b[i]; |
817 | } |
818 | return c; |
819 | } |
820 | |
821 | template <typename T> |
822 | inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) { |
823 | a = a + b; |
824 | return a; |
825 | } |
826 | template <typename T> |
827 | inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) { |
828 | a = a - b; |
829 | return a; |
830 | } |
831 | template <typename T> |
832 | inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) { |
833 | a = a / b; |
834 | return a; |
835 | } |
836 | template <typename T> |
837 | inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) { |
838 | a = a % b; |
839 | return a; |
840 | } |
841 | template <typename T> |
842 | inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) { |
843 | a = a * b; |
844 | return a; |
845 | } |
846 | |
847 | template <typename T> |
848 | inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) { |
849 | a = a << b; |
850 | return a; |
851 | } |
852 | |
853 | template <typename T> |
854 | inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) { |
855 | a = a >> b; |
856 | return a; |
857 | } |
858 | |
859 | template <typename T> |
860 | inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) { |
861 | return a * b + c; |
862 | } |
863 | |
864 | template <typename T> |
865 | inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) { |
866 | return a * b - c; |
867 | } |
868 | |
869 | template <int64_t scale = 1, typename T = void> |
870 | std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>> |
871 | inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) { |
872 | static constexpr int size = Vectorized<T>::size(); |
873 | int_same_size_t<T> index_arr[size]; |
874 | vindex.store(static_cast<void*>(index_arr)); |
875 | T buffer[size]; |
876 | for (const auto i : c10::irange(size)) { |
877 | buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; |
878 | } |
879 | return Vectorized<T>::loadu(static_cast<void*>(buffer)); |
880 | } |
881 | |
882 | template <int64_t scale = 1, typename T = void> |
883 | std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>> |
884 | inline mask_gather(const Vectorized<T>& src, T const* base_addr, |
885 | const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) { |
886 | static constexpr int size = Vectorized<T>::size(); |
887 | T src_arr[size]; |
888 | int_same_size_t<T> mask_arr[size]; // use int type so we can logical and |
889 | int_same_size_t<T> index_arr[size]; |
890 | src.store(static_cast<void*>(src_arr)); |
891 | mask.store(static_cast<void*>(mask_arr)); |
892 | vindex.store(static_cast<void*>(index_arr)); |
893 | T buffer[size]; |
894 | for (const auto i : c10::irange(size)) { |
895 | if (mask_arr[i] & 0x01) { // check highest bit |
896 | buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; |
897 | } else { |
898 | buffer[i] = src_arr[i]; |
899 | } |
900 | } |
901 | mask = Vectorized<T>(); // "zero out" mask |
902 | return Vectorized<T>::loadu(static_cast<void*>(buffer)); |
903 | } |
904 | |
905 | // Cast a given vector to another type without changing the bits representation. |
906 | // So a Vectorized<double> of 512 bits containing all ones can be cast to a |
907 | // Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s). |
// A Vectorized<double> of 256 bits containing all ones can be cast to a
// Vectorized<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
910 | // There is a struct here because we don't have static_if and I can't |
911 | // partially specialize a templated function. |
912 | template<typename dst_t, typename src_t> |
913 | struct CastImpl { |
914 | static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) { |
915 | src_t src_arr[Vectorized<src_t>::size()]; |
916 | src.store(static_cast<void*>(src_arr)); |
917 | return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr)); |
918 | } |
919 | }; |
920 | |
// Identity specialization: casting to the same element type is a no-op, so
// skip the store/load round trip entirely.
template<typename scalar_t>
struct CastImpl<scalar_t, scalar_t> {
  static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
    return src;
  }
};
927 | |
// Public entry point for bit-preserving casts. Dispatches through the
// CastImpl struct so the same-type case can be handled by partial
// specialization (functions cannot be partially specialized).
template<typename dst_t, typename src_t>
inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
  return CastImpl<dst_t, src_t>::apply(src);
}
932 | |
933 | template <typename T> |
934 | inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectorized<T>& src) { |
935 | static constexpr int size = Vectorized<T>::size(); |
936 | T src_arr[size]; |
937 | src.store(static_cast<void*>(src_arr)); |
938 | int_same_size_t<T> buffer[size]; |
939 | for (const auto i : c10::irange(size)) { |
940 | buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]); |
941 | } |
942 | return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer)); |
943 | } |
944 | |
945 | // Example inputs for AVX512: |
946 | // a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} |
947 | // b Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} |
948 | // returns: |
949 | // Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} |
950 | // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} |
951 | // Example inputs for AVX2: a Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3} |
952 | // b Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7} |
953 | // returns: Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7} |
954 | // Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7} |
955 | template <typename T> |
956 | inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>> |
957 | deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) { |
958 | static constexpr int size = Vectorized<T>::size(); |
959 | static constexpr int half_size = size / 2; |
960 | T a_arr[size]; |
961 | T b_arr[size]; |
962 | T buffer1[size]; |
963 | T buffer2[size]; |
964 | a.store(static_cast<void*>(a_arr)); |
965 | b.store(static_cast<void*>(b_arr)); |
966 | for (const auto i : c10::irange(half_size)) { |
967 | buffer1[i] = a_arr[i * 2]; |
968 | buffer1[half_size + i] = b_arr[i * 2]; |
969 | buffer2[i] = a_arr[i * 2 + 1]; |
970 | buffer2[half_size + i] = b_arr[i * 2 + 1]; |
971 | } |
972 | return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)), |
973 | Vectorized<T>::loadu(static_cast<void*>(buffer2))); |
974 | } |
975 | |
976 | // inverse operation of deinterleave2 |
977 | // Example inputs for AVX512: |
978 | // a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} |
979 | // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} |
980 | // returns, for AVX512: |
981 | // Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} |
982 | // Vectorized<float> = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} |
983 | // Example inputs for AVX2 : a Vectorized<float> = {a0, a1, a2, a3, a4, a5, a6, a7} |
984 | // b Vectorized<float> = {b0, b1, b2, b3, b4, b5, b6, b7} |
985 | // returns: Vectorized<float> = {a0, b0, a1, b1, a2, b2, a3, b3} |
986 | // Vectorized<float> = {a4, b4, a5, b5, a6, b6, a7, b7} |
987 | template <typename T> |
988 | inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>> |
989 | interleave2(const Vectorized<T>& a, const Vectorized<T>& b) { |
990 | static constexpr int size = Vectorized<T>::size(); |
991 | static constexpr int half_size = size / 2; |
992 | T a_arr[size]; |
993 | T b_arr[size]; |
994 | T buffer1[size]; |
995 | T buffer2[size]; |
996 | a.store(static_cast<void*>(a_arr)); |
997 | b.store(static_cast<void*>(b_arr)); |
998 | for (const auto i : c10::irange(half_size)) { |
999 | buffer1[i * 2] = a_arr[i]; |
1000 | buffer1[i * 2 + 1] = b_arr[i]; |
1001 | buffer2[i * 2] = a_arr[half_size + i]; |
1002 | buffer2[i * 2 + 1] = b_arr[half_size + i]; |
1003 | } |
1004 | return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)), |
1005 | Vectorized<T>::loadu(static_cast<void*>(buffer2))); |
1006 | } |
1007 | |
1008 | template <typename src_T, typename dst_T> |
1009 | inline void convert(const src_T *src, dst_T *dst, int64_t n) { |
1010 | #ifndef _MSC_VER |
1011 | # pragma unroll |
1012 | #endif |
1013 | for (const auto i : c10::irange(n)) { |
1014 | (void)i; //Suppress unused variable warning |
1015 | *dst = c10::convert<dst_T>(c10::load(src)); |
1016 | src++; |
1017 | dst++; |
1018 | } |
1019 | } |
1020 | |
1021 | template <typename T> |
1022 | inline Vectorized<T> flip(const Vectorized<T> & data) { |
1023 | static constexpr int size = Vectorized<T>::size(); |
1024 | T output[size]; |
1025 | T buffer[size]; |
1026 | data.store(static_cast<void*>(buffer)); |
1027 | for (const auto i : c10::irange(size)) { |
1028 | output[i] = buffer[size - i - 1]; |
1029 | } |
1030 | return Vectorized<T>::loadu(static_cast<void*>(output)); |
1031 | } |
1032 | |
// Transpose the `src` buffer of type `T` and size (M,N) into the `dst`
// buffer of size (N,M). `ld_src` is the leading dimension of `src` and
// `ld_dst` is the leading dimension of `dst`.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  // Walk dst row-major (src column-major): dst[j][i] = src[i][j].
  for (int j = 0; j < N; j++) {
    for (int i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}
1043 | |
1044 | }}} |
1045 | |