#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/ExpandBase.h>

#include <c10/util/irange.h>

#include <numeric> // std::iota

namespace at {
namespace internal {
TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) {
  return OptionalTensorRef(self)->expand(size);
}
} // namespace internal

namespace {
// NOTE: are_expandable does a similar check; please keep the two in sync if changes are needed
template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  size_t dimsA = a.size();
  size_t dimsB = b.size();
  size_t ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  // Use ptrdiff_t to ensure signed comparison.
  for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
    auto sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA);
  }

  return expandedSizes;
}
} // namespace

std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<std::vector<int64_t>>(a, b);
}

DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<DimVector, IntArrayRef>(a, b);
}

SymDimVector infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<SymDimVector, SymIntArrayRef>(a, b);
}
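
// Example (illustrative values): sizes are aligned from the trailing dimension, and a
// size of 1 broadcasts to the other operand's size:
//   infer_size({3, 1, 7}, {5, 1})  ->  {3, 5, 7}
//   infer_size({2, 3},    {4, 3})  ->  error at non-singleton dimension 0 (2 vs. 4)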

template <typename Container>
C10_ALWAYS_INLINE InferExpandGeometryResult<Container> inferExpandGeometryImpl(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = sizes.size();
  int64_t tensor_dim = tensor_sizes.size();

  if (tensor_dim == 0) {
    return InferExpandGeometryResult<Container>(sizes, ndim);
  }

  InferExpandGeometryResult<Container> result(ndim);
  auto& expandedSizes = result.sizes;
  auto& expandedStrides = result.strides;

  // create a new geometry for the tensors
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dim = tensor_dim - 1 - offset;
    int64_t size = (dim >= 0) ? tensor_sizes[dim] : 1;
    int64_t stride = (dim >= 0) ? tensor_strides[dim]
                                : expandedSizes[i + 1] * expandedStrides[i + 1];
    int64_t targetSize = sizes[i];
    if (targetSize == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          targetSize,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      targetSize = size;
    }
    if (size != targetSize) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          targetSize,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ". Target sizes: ",
          sizes,
          ". Tensor sizes: ",
          tensor_sizes);
      size = targetSize;
      stride = 0;
    }
    expandedSizes[i] = size;
    expandedStrides[i] = stride;
  }
  return result;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  auto result = inferExpandGeometryImpl<std::vector<int64_t>>(
      tensor_sizes, tensor_strides, sizes);
  return std::make_tuple(std::move(result.sizes), std::move(result.strides));
}

InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  return inferExpandGeometryImpl<DimVector>(
      tensor_sizes, tensor_strides, sizes);
}
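
// Example (illustrative values): expanding a tensor with sizes {3, 1} / strides {1, 1}
// to target sizes {2, 3, 4}. The size-1 dimension and the new leading dimension are
// broadcast with stride 0, while the matching dimension keeps its stride:
//   inferExpandGeometry({3, 1}, {1, 1}, {2, 3, 4})  ->  sizes {2, 3, 4}, strides {0, 1, 0}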

// This function returns dense, non-overlapping strides that keep the same layout permutation
// as the input `tensor_strides`, computed based on the input `tensor_sizes`.
// Note:
// 1. This function expects the inputs `tensor_strides` and `tensor_sizes` to be non-dense or
//    overlapping. If the inputs are already dense and non-overlapping, the output strides will be
//    the same as `tensor_strides`. However, this function does not check whether the inputs are
//    dense or overlapping, so the whole function is executed even when they already are, which
//    causes unnecessary slowness.
//
//    If possible, verify whether the inputs are non-dense or overlapping before calling this
//    function; if the inputs come from a tensor, you can check this via
//    `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule used in this function is exactly the same as the one used in
//    TensorIterator. Please refer to https://github.com/pytorch/pytorch/pull/42922 for more details.
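//
// Example (illustrative values): a non-dense tensor with sizes {2, 3, 4, 5} and strides
// {120, 1, 30, 6} has dim 1 innermost, then dims 3, 2, 0. Packing it densely while
// preserving that ordering gives
//   infer_dense_strides({2, 3, 4, 5}, {120, 1, 30, 6})  ->  {60, 1, 15, 3}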

std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {

  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
    "Input sizes and strides should have the same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  if (ndim == 0) {
    return {};
  }
  if (ndim == 1) {
    return {1};
  }

  std::vector<int64_t> perm(ndim);
  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm.rbegin(), perm.rend(), 0);

  // The following sorting algorithm has exactly the same behavior as TensorIterator.
  // This is to make sure we have the same stride propagation everywhere.

  // return -1 if dim0 should come before dim1
  // return 1 if dim0 should come after dim1
  // return 0 if the comparison is ambiguous
  auto should_swap = [&](size_t dim0, size_t dim1) {
    int64_t stride0 = tensor_strides[dim0];
    int64_t stride1 = tensor_strides[dim1];

    // if either stride is 0, treat the comparison as ambiguous to
    // keep the same behavior as TensorIterator
    if (stride0 == 0 || stride1 == 0) {
      return 0;
    }
    if (stride0 < stride1) {
      return -1;
    }
    if (stride0 > stride1) {
      return 1;
    }
    // for equal strides, the dimension with the smaller size goes first
    if (tensor_sizes[dim0] > tensor_sizes[dim1]) {
      return 1;
    }
    return 0;
  };

  // Stable insertion sort of the indices in `perm` based on the input tensor's strides and shape;
  // dimensions with stride 0 won't move. This is the same behavior as TensorIterator.
  // E.g. given a tensor with sizes/strides (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1), the initial `perm`
  // is (4, 3, 2, 1, 0) and the sorted `perm` will be (4, 3, 0, 1, 2).
  for (const auto i : c10::irange(1, ndim)) {
    auto dim1 = i;
    for (const auto j : c10::irange(1, i + 1)) {
      auto dim0 = i - j;
      int comparison = should_swap(perm[dim0], perm[dim1]);
      if (comparison > 0) {
        std::swap(perm[dim0], perm[dim1]);
        dim1 = dim0;
      }
      else if (comparison < 0) {
        break;
      }
    }
  }

  // compute output strides which preserve the input tensor's memory layout
  std::vector<int64_t> out_strides(ndim);
  int64_t curr_stride = 1;
  for (const auto i : c10::irange(ndim)) {
    int64_t idx = perm[i];
    out_strides[idx] = curr_stride;
    // Note: a size of 0 is simply treated as 1; it really doesn't matter here
    // since the total number of elements is 0.
    if (tensor_sizes[idx] > 1) {
      curr_stride *= tensor_sizes[idx];
    }
  }
  return out_strides;
}

} // namespace at