1#pragma once
2
#include <cstdint>
#include <cstring>

#if defined(CPU_CAPABILITY_AVX512)
#include <ATen/cpu/vec/vec512/vec512.h>
#else
#include <ATen/cpu/vec/vec256/vec256.h>
#endif
8
9namespace at {
10namespace vec {
11// See Note [CPU_CAPABILITY namespace]
12inline namespace CPU_CAPABILITY {
13
14inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
15 __at_align__ bool buffer[x.size()];
16 x.ne(Vectorized<int8_t>(0)).store(buffer);
17
18 Vectorized<bool> ret;
19 static_assert(x.size() == ret.size(), "");
20 std::memcpy(ret, buffer, ret.size() * sizeof(bool));
21 return ret;
22}
23
24template <>
25inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
26 // See NOTE [Loading boolean values]
27 return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
28}
29
30template <>
31inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
32 // See NOTE [Loading boolean values]
33 return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
34}
35
36}}} // namespace at::vec::CPU_CAPABILITY
37