1 | #pragma once |
---|---|
2 | |
3 | #if defined(CPU_CAPABILITY_AVX512) |
4 | #include <ATen/cpu/vec/vec512/vec512.h> |
5 | #else |
6 | #include <ATen/cpu/vec/vec256/vec256.h> |
7 | #endif |
8 | |
9 | namespace at { |
10 | namespace vec { |
11 | // See Note [CPU_CAPABILITY namespace] |
12 | inline namespace CPU_CAPABILITY { |
13 | |
14 | inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) { |
15 | __at_align__ bool buffer[x.size()]; |
16 | x.ne(Vectorized<int8_t>(0)).store(buffer); |
17 | |
18 | Vectorized<bool> ret; |
19 | static_assert(x.size() == ret.size(), ""); |
20 | std::memcpy(ret, buffer, ret.size() * sizeof(bool)); |
21 | return ret; |
22 | } |
23 | |
24 | template <> |
25 | inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) { |
26 | // See NOTE [Loading boolean values] |
27 | return convert_to_bool(Vectorized<int8_t>::loadu(ptr)); |
28 | } |
29 | |
30 | template <> |
31 | inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) { |
32 | // See NOTE [Loading boolean values] |
33 | return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count)); |
34 | } |
35 | |
36 | }}} // namespace at::vec::CPU_CAPABILITY |
37 |