#pragma once

#include <ATen/core/IListRef.h>
#include <ATen/core/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>

namespace at {

// if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in,
// e.g., torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr - 1].
using c10::maybe_wrap_dim;
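// For a quick illustration of the wrapping rule (values follow directly from
// the ranges described above):
//   maybe_wrap_dim(-1, /*dim_post_expr=*/4) == 3
//   maybe_wrap_dim(2, /*dim_post_expr=*/4) == 2
//   maybe_wrap_dim(0, /*dim_post_expr=*/0) == 0 // scalar special case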

inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  return maybe_wrap_dim(dim, tensor->dim());
}

inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  if (tensors.empty()) {
    // can't wrap empty TensorList; rely on underlying implementation to throw
    // error if necessary.
    return dim;
  }
  return maybe_wrap_dim(dim, tensors[0].dim());
}

inline int64_t maybe_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  if (tensor_sizes.empty()) {
    // can't wrap empty list; rely on underlying implementation to throw error
    // if necessary
    return dim;
  }
  return maybe_wrap_dim(dim, tensor_sizes[0].size());
}

// Given an array of dimensions `dims` of length `ndims`, this function wraps
// each dim in place for a tensor of rank `dim_post_expr`, allowing dims to be
// specified using negative indices.
//
// Additionally, if `wrap_scalars` is true, scalar tensors (rank 0) accept
// dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
// dimensions not in the range [-dim_post_expr, dim_post_expr).
inline void maybe_wrap_dims_n(
    int64_t* dims,
    int64_t ndims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  if (dim_post_expr <= 0) {
    if (wrap_scalars) {
      dim_post_expr = 1; // this will make range [-1, 0]
    } else {
      TORCH_CHECK_INDEX(
          ndims == 0,
          "Dimension specified as ",
          dims[0],
          " but tensor has no dimensions");
      return;
    }
  }
  int64_t min = -dim_post_expr;
  int64_t max = dim_post_expr - 1;
  for (const auto i : c10::irange(ndims)) {
    auto& dim = dims[i];
    if (dim < min || dim > max) {
      TORCH_CHECK_INDEX(
          false,
          "Dimension out of range (expected to be in range of [",
          min,
          ", ",
          max,
          "], but got ",
          dim,
          ")");
    }
    if (dim < 0)
      dim += dim_post_expr;
  }
}
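
// A small sketch of the in-place wrapping (illustrative values only):
//
//   int64_t dims[] = {-1, 0, -3};
//   at::maybe_wrap_dims_n(dims, /*ndims=*/3, /*dim_post_expr=*/4);
//   // dims is now {3, 0, 1}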

// Given a contiguous container of dimensions `dims`, this function wraps
// each dim in place for a tensor of rank `dim_post_expr`, allowing dims to be
// specified using negative indices.
//
// Additionally, if `wrap_scalars` is true, scalar tensors (rank 0) accept
// dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
// dimensions not in the range [-dim_post_expr, dim_post_expr).
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  return maybe_wrap_dims_n(
      dims.data(), dims.size(), dim_post_expr, wrap_scalars);
}
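
// For instance, with a std::vector as the container (illustrative):
//
//   std::vector<int64_t> dims = {-1, -2};
//   at::maybe_wrap_dims(dims, /*dim_post_expr=*/5);
//   // dims is now {4, 3}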

// previously, size [0] tensors were the only possible empty tensors; thus, it
// wasn't possible to cat empty tensors unless all the other tensors were
// 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
// dimension behavior and dimension size checking). We maintain this behavior
// for backwards compatibility, but only for this specific size (i.e. other
// empty sizes are not skipped).
template <typename T>
inline int64_t _legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<T>>& tensor_sizes) {
  for (auto& sizes : tensor_sizes) {
    if (sizes.size() == 1 && sizes[0] == 0) {
      continue;
    }
    return maybe_wrap_dim(dim, sizes.size());
  }
  return dim;
}

inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<int64_t>>& tensor_sizes) {
  return _legacy_cat_wrap_dim<int64_t>(dim, tensor_sizes);
}

inline int64_t legacy_cat_wrap_dim_symint(
    int64_t dim,
    const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  return _legacy_cat_wrap_dim<c10::SymInt>(dim, tensor_sizes);
}

inline int64_t legacy_cat_wrap_dim(
    int64_t dim,
    const MaterializedITensorListRef& tensors) {
  for (const Tensor& tensor : tensors) {
    if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
      continue;
    }
    return maybe_wrap_dim(dim, tensor.dim());
  }
  return dim;
}
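
// As an illustration of the size-[0] skipping described above:
//
//   std::vector<std::vector<int64_t>> sizes = {{0}, {2, 3}};
//   at::legacy_cat_wrap_dim(-1, sizes); // the {0} entry is skipped, so dim is
//                                       // wrapped against rank 2 and this
//                                       // returns 1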

// wrap negative dims in a vector
inline void wrap_all_dims(
    std::vector<int64_t>& dims_to_wrap,
    int64_t tensor_total_dims) {
  for (const auto i : c10::irange(dims_to_wrap.size())) {
    dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  }
}
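
// For example (illustrative):
//
//   std::vector<int64_t> dims = {-1, -2, 0};
//   at::wrap_all_dims(dims, /*tensor_total_dims=*/4);
//   // dims is now {3, 2, 0}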

} // namespace at