#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/ExpandBase.h>

#include <c10/util/irange.h>

#include <numeric>

namespace at {
namespace internal {
TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) {
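  // OptionalTensorRef borrows `self` as a Tensor without bumping its
  // reference count, which lets us call the Tensor::expand method from a
  // TensorBase here.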
  return OptionalTensorRef(self)->expand(size);
}
} // namespace internal

namespace {
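// infer_size_impl computes the broadcast of shapes `a` and `b`: the shapes
// are aligned at the trailing dimension, a missing leading dimension is
// treated as size 1, and a size-1 dimension stretches to the other operand's
// size (including 0).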
// NOTE: are_expandable does a similar check; please keep the two in sync if changes are needed
template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  // Use ptrdiff_t to ensure signed comparison (the sizes are unsigned).
  auto dimsA = static_cast<ptrdiff_t>(a.size());
  auto dimsB = static_cast<ptrdiff_t>(b.size());
  auto ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  for (ptrdiff_t i = ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
    auto sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA);
  }

  return expandedSizes;
}
} // namespace
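
// Examples (illustrative):
//   infer_size({5, 1, 3}, {4, 1}) -> {5, 4, 3}
//   infer_size({2, 0}, {1, 1})    -> {2, 0}  (size 1 broadcasts even to 0)
//   infer_size({2, 3}, {2, 4})    throws: 3 and 4 cannot broadcast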

std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<std::vector<int64_t>>(a, b);
}

DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<DimVector, IntArrayRef>(a, b);
}

SymDimVector infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<SymDimVector, SymIntArrayRef>(a, b);
}
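
// inferExpandGeometryImpl computes the geometry (sizes and strides) of the
// view produced by expanding a tensor with (tensor_sizes, tensor_strides) to
// `sizes`: a target size of -1 keeps the existing size, and new leading
// dimensions as well as stretched size-1 dimensions get stride 0, so the
// expanded view revisits the same memory instead of copying it.
// Illustrative example: expanding sizes (1, 3) / strides (3, 1) to (2, 1, 3)
// yields sizes (2, 1, 3) / strides (0, 3, 1).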

template <typename Container>
C10_ALWAYS_INLINE InferExpandGeometryResult<Container> inferExpandGeometryImpl(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = sizes.size();
  int64_t tensor_dim = tensor_sizes.size();

  if (tensor_dim == 0) {
    return InferExpandGeometryResult<Container>(sizes, ndim);
  }

  InferExpandGeometryResult<Container> result(ndim);
  auto& expandedSizes = result.sizes;
  auto& expandedStrides = result.strides;

  // create a new geometry for the tensors
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dim = tensor_dim - 1 - offset;
    int64_t size = (dim >= 0) ? tensor_sizes[dim] : 1;
    int64_t stride = (dim >= 0) ? tensor_strides[dim]
                                : expandedSizes[i + 1] * expandedStrides[i + 1];
    int64_t targetSize = sizes[i];
    if (targetSize == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          targetSize,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      targetSize = size;
    }
    if (size != targetSize) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          targetSize,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ". Target sizes: ",
          sizes,
          ". Tensor sizes: ",
          tensor_sizes);
      size = targetSize;
      stride = 0;
    }
    expandedSizes[i] = size;
    expandedStrides[i] = stride;
  }
  return result;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  auto result = inferExpandGeometryImpl<std::vector<int64_t>>(
      tensor_sizes, tensor_strides, sizes);
  return std::make_tuple(std::move(result.sizes), std::move(result.strides));
}

InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  return inferExpandGeometryImpl<DimVector>(
      tensor_sizes, tensor_strides, sizes);
}

// This function returns dense, non-overlapping strides that keep the same
// layout permutation as the input `tensor_strides`, computed based on the
// input `tensor_sizes`.
// Note:
// 1. This function expects the inputs `tensor_strides` and `tensor_sizes` to
//    be non-dense or overlapping. If the inputs are dense and non-overlapping,
//    the output strides will be the same as `tensor_strides`. However, this
//    function won't check whether the inputs are dense or overlapping, so the
//    whole function will still be executed even if the inputs are already
//    dense and non-overlapping, which will cause slowness.
//
//    Please verify whether the inputs are non-dense or overlapping before
//    calling this function if possible; if the inputs come from a tensor, you
//    can check this through `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule used in this function is exactly the same as
//    the one used in TensorIterator. Please refer to
//    https://github.com/pytorch/pytorch/pull/42922 for more details.
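//
// Illustrative example: infer_dense_strides({3, 2}, {2, 6}) returns {1, 3}:
// the layout permutation (dim 0 innermost) is preserved while the gaps in
// the original strides are squeezed out.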

std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {
  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
      "Input sizes and strides should have same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  if (ndim == 0) {
    return {};
  }
  if (ndim == 1) {
    return {1};
  }

  std::vector<int64_t> perm(ndim);
  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm.rbegin(), perm.rend(), 0);

  // The following sorting algorithm has exactly the same behavior as
  // TensorIterator; this is to make sure we have the same stride propagation
  // everywhere.

  // return -1 if dim0 should come before dim1
  // return  1 if dim0 should come after dim1
  // return  0 if the comparison is ambiguous
  auto should_swap = [&](size_t dim0, size_t dim1) {
    int64_t stride0 = tensor_strides[dim0];
    int64_t stride1 = tensor_strides[dim1];

    // if any stride is 0, treat it as an ambiguous comparison to
    // keep the same behavior as TensorIterator
    if (stride0 == 0 || stride1 == 0) {
      return 0;
    }
    if (stride0 < stride1) {
      return -1;
    }
    if (stride0 > stride1) {
      return 1;
    }
    // for equal strides, the dimension with the smaller size goes in front
    if (tensor_sizes[dim0] > tensor_sizes[dim1]) {
      return 1;
    }
    return 0;
  };

  // Stable insertion sort of the indices in `perm`, based on the input
  // tensor's strides and shape; all dimensions with stride 0 won't move.
  // This is the same behavior as TensorIterator.
  // e.g. given a tensor with size/stride (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1),
  // the initial `perm` is (4, 3, 2, 1, 0) and the sorted `perm` will be
  // (4, 3, 0, 1, 2)
  for (const auto i : c10::irange(1, ndim)) {
    auto dim1 = i;
    for (const auto j : c10::irange(1, i + 1)) {
      auto dim0 = i - j;
      int comparison = should_swap(perm[dim0], perm[dim1]);
      if (comparison > 0) {
        std::swap(perm[dim0], perm[dim1]);
        dim1 = dim0;
      } else if (comparison < 0) {
        break;
      }
    }
  }

  // compute output strides which preserve the input tensor's memory layout
  std::vector<int64_t> out_strides(ndim);
  int64_t curr_stride = 1;
  for (const auto i : c10::irange(ndim)) {
    int64_t idx = perm[i];
    out_strides[idx] = curr_stride;
    // Note: for size 0, we simply treat it as 1; it really doesn't matter
    // here since the total number of elements is 0.
    if (tensor_sizes[idx] > 1) {
      curr_stride *= tensor_sizes[idx];
    }
  }
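  // For the (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1) example above, the sorted perm
  // (4, 3, 0, 1, 2) assigns strides innermost-first, giving dense strides
  // (6, 36, 180, 2, 1).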
  return out_strides;
}

} // namespace at