1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/WrapDimUtils.h> |
4 | #include <c10/core/TensorImpl.h> |
5 | #include <c10/util/irange.h> |
6 | #include <bitset> |
7 | #include <sstream> |
8 | |
9 | namespace at { |
10 | |
11 | // This lives in a separate file to work around a strange interaction |
12 | // between std::bitset and operator overloading on Windows. |
13 | |
// Number of bits in the bitsets returned by dim_list_to_bitset, which rejects
// any `ndims` greater than this value.
14 | constexpr size_t dim_bitset_size = 64; |
15 | |
16 | static inline std::bitset<dim_bitset_size> dim_list_to_bitset( |
17 | OptionalIntArrayRef opt_dims, |
18 | int64_t ndims) { |
19 | TORCH_CHECK( |
20 | ndims <= (int64_t)dim_bitset_size, |
21 | "only tensors with up to ", |
22 | dim_bitset_size, |
23 | " dims are supported"); |
24 | std::bitset<dim_bitset_size> seen; |
25 | if (opt_dims.has_value()) { |
26 | auto dims = opt_dims.value(); |
27 | for (const auto i : c10::irange(dims.size())) { |
28 | size_t dim = maybe_wrap_dim(dims[i], ndims); |
29 | TORCH_CHECK( |
30 | !seen[dim], |
31 | "dim ", |
32 | dim, |
33 | " appears multiple times in the list of dims"); |
34 | seen[dim] = true; |
35 | } |
36 | } else { |
37 | for (int64_t dim = 0; dim < ndims; dim++) { |
38 | seen[dim] = true; |
39 | } |
40 | } |
41 | return seen; |
42 | } |
43 | |
44 | } // namespace at |
45 |