1#pragma once
2
3#ifndef AT_PER_OPERATOR_HEADERS
4#include <ATen/Functions.h>
5#else
6#include <ATen/ops/view.h>
7#include <ATen/ops/view_copy.h>
8#endif
9
10#include <ATen/Tensor.h>
11#include <ATen/core/DimVector.h>
12#include <c10/util/Exception.h>
13#include <c10/util/MaybeOwned.h>
14#include <c10/util/irange.h>
15
16#include <functional>
17#include <sstream>
18#include <tuple>
19#include <utility>
20
21namespace at {
22
// Broadcast shape inference (defined out of line).
//
// Each overload computes the shape that the two operand shapes `a` and `b`
// are jointly expanded to; `expand_outplace` in this header uses the result
// as the target size passed to `expand` on both operands.
// NOTE(review): presumably these raise an error when `a` and `b` are not
// broadcast-compatible — confirm against the implementation.
//
// The overloads differ in return container: `std::vector<int64_t>`,
// `DimVector`, and `SymDimVector` (the latter taking `SymIntArrayRef`
// inputs, i.e. possibly symbolic sizes).
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
TORCH_API SymDimVector
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
27
28// Named type instead of a pair/tuple so that we can be sure to
29// construct the vectors in place and get NRVO.
30template <typename Container>
31struct InferExpandGeometryResult {
32 Container sizes;
33 Container strides;
34 explicit InferExpandGeometryResult(size_t ndim)
35 : sizes(ndim), strides(ndim) {}
36 explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
37 : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
38};
39
// Computes the sizes and strides of a view that expands a tensor whose
// geometry is (`tensor_sizes`, `tensor_strides`) to the requested `sizes`.
// Returns the result as a (sizes, strides) tuple. Defined out of line.
// NOTE(review): presumably raises if `sizes` is not a valid expansion of
// `tensor_sizes` — confirm against the implementation.
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

// Variant of `inferExpandGeometry` that returns the result in a named
// `InferExpandGeometryResult<DimVector>` instead of a tuple.
// NOTE(review): presumably the same computation — confirm.
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

// Returns a stride vector for `tensor_sizes`, derived with reference to the
// existing `tensor_strides`. Defined out of line.
// NOTE(review): from the name, this presumably produces strides for a dense
// (non-overlapping, gap-free) layout that preserves the dimension ordering
// implied by `tensor_strides` — confirm against the implementation.
TORCH_API std::vector<int64_t> infer_dense_strides(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides);
54
55// True if input shapes are expandable
56// NOTE: infer_size did a similar check, please keep them sync if change is
57// needed
58inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
59 size_t ndim1 = shape1.size();
60 size_t ndim2 = shape2.size();
61 size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
62
63 for (int64_t i = ndim - 1; i >= 0; --i) {
64 if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
65 shape2[ndim2] == 1) {
66 continue;
67 }
68 return false;
69 }
70 return true;
71}
72
73// avoid copy-construction of Tensor by using a reference_wrapper.
74inline void check_defined(
75 std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
76 const char* api_name) {
77 for (auto& t : tensors) {
78 if (!t.get().defined()) {
79 AT_ERROR(api_name, "(...) called with an undefined Tensor");
80 }
81 }
82}
83
84// NOTE [ ExpandUtils Borrowing ]
85//
86// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
87// expansion may not actually be needed, in which case we can improve
88// efficiency by returning
89// `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
90// that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
91// must not outlive the original `Tensor` object that `to_expand`
92// referred to! The deleted rvalue reference overloads of these
93// functions help with this by preventing trivial use of a temporary
94// resulting from a function call, but it is still possible to make a
95// mistake.
96
97inline c10::MaybeOwned<Tensor> expand_inplace(
98 const Tensor& tensor,
99 const Tensor& to_expand) {
100 if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
101 return c10::MaybeOwned<Tensor>::borrowed(to_expand);
102 }
103 return c10::MaybeOwned<Tensor>::owned(
104 to_expand.expand_symint(tensor.sym_sizes()));
105}
106
107inline c10::MaybeOwned<Tensor> expand_inplace(
108 const Tensor& tensor,
109 Tensor&& to_expand) = delete;
110
111inline c10::MaybeOwned<Tensor> expand_inplace(
112 const Tensor& tensor,
113 const Tensor& to_expand,
114 const char* api_name) {
115 check_defined({tensor, to_expand}, api_name);
116 return expand_inplace(tensor, to_expand);
117}
118
119inline c10::MaybeOwned<Tensor> expand_inplace(
120 const Tensor& tensor,
121 Tensor&& to_expand,
122 const char* api_name) = delete;
123
124inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
125expand_inplace(
126 const Tensor& tensor,
127 const Tensor& to_expand1,
128 const Tensor& to_expand2) {
129 if (tensor.sizes().equals(to_expand1.sizes()) &&
130 tensor.sizes().equals((to_expand2.sizes()))) {
131 return std::make_tuple(
132 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
133 c10::MaybeOwned<Tensor>::borrowed(to_expand2));
134 }
135
136 return std::make_tuple(
137 c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
138 c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
139}
140
141inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
142expand_inplace(
143 const Tensor& tensor,
144 Tensor&& to_expand1,
145 const Tensor& to_expand2) = delete;
146inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
147expand_inplace(
148 const Tensor& tensor,
149 const Tensor& to_expand1,
150 Tensor&& to_expand2) = delete;
151inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
152expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
153 delete;
154
155inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
156expand_inplace(
157 const Tensor& tensor,
158 const Tensor& to_expand1,
159 const Tensor& to_expand2,
160 const char* api_name) {
161 check_defined({tensor, to_expand1, to_expand2}, api_name);
162 return expand_inplace(tensor, to_expand1, to_expand2);
163}
164
165inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
166expand_inplace(
167 const Tensor& tensor,
168 Tensor&& to_expand1,
169 const Tensor& to_expand2,
170 const char* api_name) = delete;
171inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
172expand_inplace(
173 const Tensor& tensor,
174 const Tensor& to_expand1,
175 Tensor&& to_expand2,
176 const char* api_name) = delete;
177inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
178expand_inplace(
179 const Tensor& tensor,
180 Tensor&& to_expand1,
181 Tensor&& to_expand2,
182 const char* api_name) = delete;
183
184// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
185inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
186expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
187 if (to_expand1.sizes().equals(to_expand2.sizes())) {
188 return std::make_tuple(
189 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
190 c10::MaybeOwned<Tensor>::borrowed(to_expand2));
191 }
192
193 auto expanded_size =
194 infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
195 return std::make_tuple(
196 c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
197 c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)));
198}
199
200inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
201expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
202inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
203expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
204inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
205expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
206
207inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
208expand_outplace(
209 const Tensor& to_expand1,
210 const Tensor& to_expand2,
211 const char* api_name) {
212 check_defined({to_expand1, to_expand2}, api_name);
213 return expand_outplace(to_expand1, to_expand2);
214}
215
216inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
217expand_outplace(
218 Tensor&& to_expand1,
219 const Tensor& to_expand2,
220 const char* api_name) = delete;
221inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
222expand_outplace(
223 const Tensor& to_expand1,
224 Tensor&& to_expand2,
225 const char* api_name) = delete;
226inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
227expand_outplace(
228 Tensor&& to_expand1,
229 Tensor&& to_expand2,
230 const char* api_name) = delete;
231
232inline std::tuple<
233 c10::MaybeOwned<Tensor>,
234 c10::MaybeOwned<Tensor>,
235 c10::MaybeOwned<Tensor>>
236expand_outplace(
237 const Tensor& to_expand1,
238 const Tensor& to_expand2,
239 const Tensor& to_expand3) {
240 if (to_expand1.sizes().equals(to_expand2.sizes()) &&
241 to_expand1.sizes().equals(to_expand3.sizes())) {
242 return std::make_tuple(
243 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
244 c10::MaybeOwned<Tensor>::borrowed(to_expand2),
245 c10::MaybeOwned<Tensor>::borrowed(to_expand3));
246 }
247
248 auto expanded_size12 =
249 infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
250 auto expanded_size =
251 infer_size_dimvector(expanded_size12, to_expand3.sizes());
252 return std::make_tuple(
253 c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
254 c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
255 c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
256}
257
258inline std::tuple<
259 c10::MaybeOwned<Tensor>,
260 c10::MaybeOwned<Tensor>,
261 c10::MaybeOwned<Tensor>>
262expand_outplace(
263 Tensor&& to_expand1,
264 const Tensor& to_expand2,
265 const Tensor& to_expand3) = delete;
266inline std::tuple<
267 c10::MaybeOwned<Tensor>,
268 c10::MaybeOwned<Tensor>,
269 c10::MaybeOwned<Tensor>>
270expand_outplace(
271 const Tensor& to_expand1,
272 Tensor&& to_expand2,
273 const Tensor& to_expand3) = delete;
274inline std::tuple<
275 c10::MaybeOwned<Tensor>,
276 c10::MaybeOwned<Tensor>,
277 c10::MaybeOwned<Tensor>>
278expand_outplace(
279 Tensor&& to_expand1,
280 Tensor&& to_expand2,
281 const Tensor& to_expand3) = delete;
282inline std::tuple<
283 c10::MaybeOwned<Tensor>,
284 c10::MaybeOwned<Tensor>,
285 c10::MaybeOwned<Tensor>>
286expand_outplace(
287 const Tensor& to_expand1,
288 const Tensor& to_expand2,
289 Tensor&& to_expand3) = delete;
290inline std::tuple<
291 c10::MaybeOwned<Tensor>,
292 c10::MaybeOwned<Tensor>,
293 c10::MaybeOwned<Tensor>>
294expand_outplace(
295 Tensor&& to_expand1,
296 const Tensor& to_expand2,
297 Tensor&& to_expand3) = delete;
298inline std::tuple<
299 c10::MaybeOwned<Tensor>,
300 c10::MaybeOwned<Tensor>,
301 c10::MaybeOwned<Tensor>>
302expand_outplace(
303 const Tensor& to_expand1,
304 Tensor&& to_expand2,
305 Tensor&& to_expand3) = delete;
306inline std::tuple<
307 c10::MaybeOwned<Tensor>,
308 c10::MaybeOwned<Tensor>,
309 c10::MaybeOwned<Tensor>>
310expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
311 delete;
312
313inline std::tuple<
314 c10::MaybeOwned<Tensor>,
315 c10::MaybeOwned<Tensor>,
316 c10::MaybeOwned<Tensor>>
317expand_outplace(
318 const Tensor& to_expand1,
319 const Tensor& to_expand2,
320 const Tensor& to_expand3,
321 const char* api_name) {
322 check_defined({to_expand1, to_expand2, to_expand3}, api_name);
323 return expand_outplace(to_expand1, to_expand2, to_expand3);
324}
325
326inline std::tuple<
327 c10::MaybeOwned<Tensor>,
328 c10::MaybeOwned<Tensor>,
329 c10::MaybeOwned<Tensor>>
330expand_outplace(
331 Tensor&& to_expand1,
332 const Tensor& to_expand2,
333 const Tensor& to_expand3,
334 const char* api_name) = delete;
335inline std::tuple<
336 c10::MaybeOwned<Tensor>,
337 c10::MaybeOwned<Tensor>,
338 c10::MaybeOwned<Tensor>>
339expand_outplace(
340 const Tensor& to_expand1,
341 Tensor&& to_expand2,
342 const Tensor& to_expand3,
343 const char* api_name) = delete;
344inline std::tuple<
345 c10::MaybeOwned<Tensor>,
346 c10::MaybeOwned<Tensor>,
347 c10::MaybeOwned<Tensor>>
348expand_outplace(
349 Tensor&& to_expand1,
350 Tensor&& to_expand2,
351 const Tensor& to_expand3,
352 const char* api_name) = delete;
353inline std::tuple<
354 c10::MaybeOwned<Tensor>,
355 c10::MaybeOwned<Tensor>,
356 c10::MaybeOwned<Tensor>>
357expand_outplace(
358 const Tensor& to_expand1,
359 const Tensor& to_expand2,
360 Tensor&& to_expand3,
361 const char* api_name) = delete;
362inline std::tuple<
363 c10::MaybeOwned<Tensor>,
364 c10::MaybeOwned<Tensor>,
365 c10::MaybeOwned<Tensor>>
366expand_outplace(
367 Tensor&& to_expand1,
368 const Tensor& to_expand2,
369 Tensor&& to_expand3,
370 const char* api_name) = delete;
371inline std::tuple<
372 c10::MaybeOwned<Tensor>,
373 c10::MaybeOwned<Tensor>,
374 c10::MaybeOwned<Tensor>>
375expand_outplace(
376 const Tensor& to_expand1,
377 Tensor&& to_expand2,
378 Tensor&& to_expand3,
379 const char* api_name) = delete;
380inline std::tuple<
381 c10::MaybeOwned<Tensor>,
382 c10::MaybeOwned<Tensor>,
383 c10::MaybeOwned<Tensor>>
384expand_outplace(
385 Tensor&& to_expand1,
386 Tensor&& to_expand2,
387 Tensor&& to_expand3,
388 const char* api_name) = delete;
389
390inline c10::MaybeOwned<Tensor> expand_size(
391 const Tensor& to_expand,
392 IntArrayRef sizes) {
393 if (to_expand.sizes().equals(sizes)) {
394 return c10::MaybeOwned<Tensor>::borrowed(to_expand);
395 }
396
397 return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
398}
399
400inline c10::MaybeOwned<Tensor> expand_size(
401 Tensor&& to_expand,
402 IntArrayRef sizes) = delete;
403
404inline c10::MaybeOwned<Tensor> expand_size(
405 const Tensor& to_expand,
406 IntArrayRef sizes,
407 const char* api_name) {
408 check_defined({to_expand}, api_name);
409 return expand_size(to_expand, sizes);
410}
411
412inline c10::MaybeOwned<Tensor> expand_size(
413 Tensor&& to_expand,
414 IntArrayRef sizes,
415 const char* api_name) = delete;
416
417inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
418 // expands a list of Tensors; ignores undefined (null) tensors
419 bool first = true;
420 DimVector sizes;
421 for (const auto i : c10::irange(to_expand.size())) {
422 if (!to_expand[i].defined()) {
423 continue;
424 } else if (first) {
425 sizes = to_expand[i].sizes();
426 first = false;
427 } else {
428 sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
429 }
430 }
431
432 std::vector<Tensor> result(to_expand.size());
433 for (const auto i : c10::irange(to_expand.size())) {
434 if (!to_expand[i].defined()) {
435 continue;
436 } else if (to_expand[i].sizes().equals(sizes)) {
437 result[i] = to_expand[i];
438 } else {
439 result[i] = to_expand[i].expand(sizes);
440 }
441 }
442 return result;
443}
444
445template <typename T>
446inline Tensor _sum_to(
447 Tensor tensor,
448 const c10::ArrayRef<T> shape,
449 bool always_return_non_view = false) {
450 if (shape.size() == 0) {
451 return tensor.sum();
452 }
453
454 auto sizes = at::symint::sizes<T>(tensor);
455 c10::SmallVector<int64_t, 8> reduce_dims;
456 const int64_t leading_dims = sizes.size() - shape.size();
457 for (const auto i : c10::irange(leading_dims)) {
458 reduce_dims.push_back(i);
459 }
460 for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
461 if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
462 reduce_dims.push_back(i);
463 }
464 }
465
466 if (!reduce_dims.empty()) {
467 tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
468 }
469
470 if (always_return_non_view) {
471 // This is only actually used by the functionalization pass.
472 // We want to be able to guarantee that this function doesn't return a view
473 // of the input.
474 return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
475 : tensor.clone();
476 } else {
477 return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
478 }
479}
480
481inline Tensor sum_to(
482 Tensor tensor,
483 const c10::SymIntArrayRef shape,
484 bool always_return_non_view = false) {
485 return _sum_to(std::move(tensor), shape, always_return_non_view);
486}
487
488// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
489// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
490inline Tensor sum_to(
491 Tensor tensor,
492 const IntArrayRef shape,
493 bool always_return_non_view = false) {
494 return _sum_to(std::move(tensor), shape, always_return_non_view);
495}
496
497static inline bool is_expandable_to(
498 SymIntArrayRef shape,
499 c10::SymIntArrayRef desired) {
500 size_t ndim = shape.size();
501 size_t target_dim = desired.size();
502 if (ndim > target_dim) {
503 return false;
504 }
505 for (const auto i : c10::irange(ndim)) {
506 const auto& size = shape[ndim - i - 1];
507 const auto& target = desired[target_dim - i - 1];
508 if (size != target && size != 1) {
509 return false;
510 }
511 }
512 return true;
513}
514
515static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
516 auto sym_shape = c10::SymIntArrayRef(
517 reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
518 auto sym_desired = c10::SymIntArrayRef(
519 reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
520 return is_expandable_to(sym_shape, sym_desired);
521}
522
523} // namespace at
524