1 | #pragma once |
2 | |
3 | #include <ATen/core/IListRef.h> |
4 | #include <ATen/core/Tensor.h> |
5 | #include <c10/core/TensorImpl.h> |
6 | #include <c10/core/WrapDimMinimal.h> |
7 | #include <c10/util/irange.h> |
8 | |
9 | namespace at { |
10 | |
11 | // if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the |
12 | // range [-1, 0]. This is a special case for scalar tensors and manifests in |
13 | // e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range |
14 | // [-dim_post_expr, dim_post_expr-1]. |
15 | using c10::maybe_wrap_dim; |
16 | |
17 | inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) { |
18 | return maybe_wrap_dim(dim, tensor->dim()); |
19 | } |
20 | |
21 | inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) { |
22 | if (tensors.empty()) { |
23 | // can't wrap empty TensorList; rely on underlying implementation to throw |
24 | // error if necessary. |
25 | return dim; |
26 | } |
27 | return maybe_wrap_dim(dim, tensors[0].dim()); |
28 | } |
29 | |
30 | inline int64_t maybe_wrap_dim( |
31 | int64_t dim, |
32 | const std::vector<std::vector<int64_t>>& tensor_sizes) { |
33 | if (tensor_sizes.empty()) { |
34 | // can't wrap empty list; rely on underlying implementation to throw error |
35 | // if necessary |
36 | return dim; |
37 | } |
38 | return maybe_wrap_dim(dim, tensor_sizes[0].size()); |
39 | } |
40 | |
41 | // Given an array of dimensions `dims` of length `ndims`, this function "Wraps" |
42 | // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be |
43 | // specified using negative indices. |
44 | // |
45 | // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will |
46 | // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for |
47 | // dimensions not in the range [-dim_post_expr, dim_post_expr). |
48 | inline void maybe_wrap_dims_n( |
49 | int64_t* dims, |
50 | int64_t ndims, |
51 | int64_t dim_post_expr, |
52 | bool wrap_scalars = true) { |
53 | if (dim_post_expr <= 0) { |
54 | if (wrap_scalars) { |
55 | dim_post_expr = 1; // this will make range [-1, 0] |
56 | } else { |
57 | TORCH_CHECK_INDEX( |
58 | ndims == 0, |
59 | "Dimension specified as " , |
60 | dims[0], |
61 | " but tensor has no dimensions" ); |
62 | return; |
63 | } |
64 | } |
65 | int64_t min = -dim_post_expr; |
66 | int64_t max = dim_post_expr - 1; |
67 | for (const auto i : c10::irange(ndims)) { |
68 | auto& dim = dims[i]; |
69 | if (dim < min || dim > max) { |
70 | TORCH_CHECK_INDEX( |
71 | false, |
72 | "Dimension out of range (expected to be in range of [" , |
73 | min, |
74 | ", " , |
75 | max, |
76 | "], but got " , |
77 | dim, |
78 | ")" ); |
79 | } |
80 | if (dim < 0) |
81 | dim += dim_post_expr; |
82 | } |
83 | } |
84 | |
// Container convenience wrapper around maybe_wrap_dims_n. `dims` may be any
// contiguous container exposing `data()` and `size()`. Each element is
// wrapped in place for a tensor of rank `dim_post_expr`, so that negative
// indices count from the end.
//
// Accepted values lie in [-dim_post_expr, dim_post_expr - 1]. For a rank-0
// tensor, `wrap_scalars` == true accepts -1 and 0. With `wrap_scalars` ==
// false, any element at all is rejected with an IndexError.
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  auto* const first = dims.data();
  const int64_t count = dims.size();
  maybe_wrap_dims_n(first, count, dim_post_expr, wrap_scalars);
}
100 | |
// Legacy dim-wrapping rule for cat, generic over the size element type `T`.
//
// Background: size-[0] tensors used to be the only empty tensors, so an empty
// tensor could only be concatenated when all the other inputs were
// 1-dimensional. Such inputs were therefore "skipped", both for dimension
// wrapping and for size checking. That behavior is kept for backwards
// compatibility, but only for sizes of exactly [0]. Empty tensors of any other
// shape are not skipped.
//
// This function implements the wrapping half of that rule. It wraps `dim`
// against the rank (number of sizes) of the first entry of `tensor_sizes`
// that is not exactly [0]. If every entry is skipped, or the list is empty,
// `dim` is returned unchanged.
template <typename T>
inline int64_t _legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<T>>& tensor_sizes) {
  for (const std::vector<T>& sizes : tensor_sizes) {
    const bool is_legacy_empty = sizes.size() == 1 && sizes[0] == 0;
    if (!is_legacy_empty) {
      return maybe_wrap_dim(dim, sizes.size());
    }
  }
  return dim;
}
119 | |
120 | inline int64_t legacy_cat_wrap_dim( |
121 | int64_t dim, |
122 | const std::vector<std::vector<int64_t>>& tensor_sizes) { |
123 | return _legacy_cat_wrap_dim<int64_t>(dim, tensor_sizes); |
124 | } |
125 | |
126 | inline int64_t legacy_cat_wrap_dim_symint( |
127 | int64_t dim, |
128 | const std::vector<std::vector<c10::SymInt>>& tensor_sizes) { |
129 | return _legacy_cat_wrap_dim<c10::SymInt>(dim, tensor_sizes); |
130 | } |
131 | |
132 | inline int64_t legacy_cat_wrap_dim( |
133 | int64_t dim, |
134 | const MaterializedITensorListRef& tensors) { |
135 | for (const Tensor& tensor : tensors) { |
136 | if (tensor.dim() == 1 && tensor.sizes()[0] == 0) { |
137 | continue; |
138 | } |
139 | return maybe_wrap_dim(dim, tensor.dim()); |
140 | } |
141 | return dim; |
142 | } |
143 | |
144 | // wrap negative dims in a vector |
145 | inline void wrap_all_dims( |
146 | std::vector<int64_t>& dims_to_wrap, |
147 | int64_t tensor_total_dims) { |
148 | for (const auto i : c10::irange(dims_to_wrap.size())) { |
149 | dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims); |
150 | } |
151 | } |
152 | |
153 | } // namespace at |
154 | |