1 | #pragma once |
2 | |
3 | #ifndef AT_PER_OPERATOR_HEADERS |
4 | #include <ATen/Functions.h> |
5 | #else |
6 | #include <ATen/ops/view.h> |
7 | #include <ATen/ops/view_copy.h> |
8 | #endif |
9 | |
10 | #include <ATen/Tensor.h> |
11 | #include <ATen/core/DimVector.h> |
12 | #include <c10/util/Exception.h> |
13 | #include <c10/util/MaybeOwned.h> |
14 | #include <c10/util/irange.h> |
15 | |
16 | #include <functional> |
17 | #include <sstream> |
18 | #include <tuple> |
19 | #include <utility> |
20 | |
21 | namespace at { |
22 | |
// Shape helpers taking two shapes `a` and `b`. The three declarations differ
// in the container they return (std::vector<int64_t>, DimVector,
// SymDimVector); the last one takes symbolic sizes.
// NOTE(review): the definitions are not in this header; presumably they
// compute the broadcast shape of `a` and `b` -- confirm against the
// implementation.
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
TORCH_API SymDimVector
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
27 | |
// Named type instead of a pair/tuple so that we can be sure to
// construct the vectors in place and get NRVO.
//
// Bundles a `sizes` and a `strides` container of the same `Container` type.
template <typename Container>
struct InferExpandGeometryResult {
  Container sizes;
  Container strides;
  // Constructs both containers from `ndim` via Container(ndim).
  explicit InferExpandGeometryResult(size_t ndim)
      : sizes(ndim), strides(ndim) {}
  // Copies the elements of `sizes_` into `sizes`; `strides` is constructed
  // from `ndim`.
  // NOTE(review): presumably callers pass ndim == sizes_.size(); this is not
  // enforced here.
  explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
      : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
};
39 | |
// Declaration only. Takes a tensor's sizes and strides plus the requested
// `sizes`, and returns two vectors packed in a std::tuple.
// NOTE(review): presumably the tuple is (expanded sizes, expanded strides) --
// confirm against the definition.
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

// Same parameters as inferExpandGeometry, but the result is returned as an
// InferExpandGeometryResult<DimVector> (named `sizes` / `strides` members).
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);
50 | |
// Declaration only. Takes a tensor's sizes and strides and returns a vector
// of int64_t.
// NOTE(review): judging by the name, the result is a set of dense strides for
// the given geometry -- confirm against the definition.
TORCH_API std::vector<int64_t> infer_dense_strides(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides);
54 | |
55 | // True if input shapes are expandable |
56 | // NOTE: infer_size did a similar check, please keep them sync if change is |
57 | // needed |
58 | inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) { |
59 | size_t ndim1 = shape1.size(); |
60 | size_t ndim2 = shape2.size(); |
61 | size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2; |
62 | |
63 | for (int64_t i = ndim - 1; i >= 0; --i) { |
64 | if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 || |
65 | shape2[ndim2] == 1) { |
66 | continue; |
67 | } |
68 | return false; |
69 | } |
70 | return true; |
71 | } |
72 | |
73 | // avoid copy-construction of Tensor by using a reference_wrapper. |
74 | inline void check_defined( |
75 | std::initializer_list<std::reference_wrapper<const Tensor>> tensors, |
76 | const char* api_name) { |
77 | for (auto& t : tensors) { |
78 | if (!t.get().defined()) { |
79 | AT_ERROR(api_name, "(...) called with an undefined Tensor" ); |
80 | } |
81 | } |
82 | } |
83 | |
84 | // NOTE [ ExpandUtils Borrowing ] |
85 | // |
86 | // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because |
87 | // expansion may not actually be needed, in which case we can improve |
88 | // efficiency by returning |
89 | // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means |
90 | // that you need to be careful: the returned `c10::MaybeOwned<Tensor>` |
91 | // must not outlive the original `Tensor` object that `to_expand` |
92 | // referred to! The deleted rvalue reference overloads of these |
93 | // functions help with this by preventing trivial use of a temporary |
94 | // resulting from a function call, but it is still possible to make a |
95 | // mistake. |
96 | |
97 | inline c10::MaybeOwned<Tensor> expand_inplace( |
98 | const Tensor& tensor, |
99 | const Tensor& to_expand) { |
100 | if (tensor.sym_sizes().equals(to_expand.sym_sizes())) { |
101 | return c10::MaybeOwned<Tensor>::borrowed(to_expand); |
102 | } |
103 | return c10::MaybeOwned<Tensor>::owned( |
104 | to_expand.expand_symint(tensor.sym_sizes())); |
105 | } |
106 | |
// Deleted rvalue overload: the result may borrow `to_expand`, so passing a
// temporary is rejected at compile time. See NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand) = delete;
110 | |
111 | inline c10::MaybeOwned<Tensor> expand_inplace( |
112 | const Tensor& tensor, |
113 | const Tensor& to_expand, |
114 | const char* api_name) { |
115 | check_defined({tensor, to_expand}, api_name); |
116 | return expand_inplace(tensor, to_expand); |
117 | } |
118 | |
// Deleted rvalue overload of the checked variant: the result may borrow
// `to_expand`, so a temporary is rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand,
    const char* api_name) = delete;
123 | |
124 | inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> |
125 | expand_inplace( |
126 | const Tensor& tensor, |
127 | const Tensor& to_expand1, |
128 | const Tensor& to_expand2) { |
129 | if (tensor.sizes().equals(to_expand1.sizes()) && |
130 | tensor.sizes().equals((to_expand2.sizes()))) { |
131 | return std::make_tuple( |
132 | c10::MaybeOwned<Tensor>::borrowed(to_expand1), |
133 | c10::MaybeOwned<Tensor>::borrowed(to_expand2)); |
134 | } |
135 | |
136 | return std::make_tuple( |
137 | c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())), |
138 | c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes()))); |
139 | } |
140 | |
// Deleted rvalue overloads (every combination with at least one rvalue
// operand): the results may borrow the operands, so temporaries are rejected
// at compile time. See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand1,
    const Tensor& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand1,
    Tensor&& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
    delete;
154 | |
155 | inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> |
156 | expand_inplace( |
157 | const Tensor& tensor, |
158 | const Tensor& to_expand1, |
159 | const Tensor& to_expand2, |
160 | const char* api_name) { |
161 | check_defined({tensor, to_expand1, to_expand2}, api_name); |
162 | return expand_inplace(tensor, to_expand1, to_expand2); |
163 | } |
164 | |
// Deleted rvalue overloads of the checked variant (every combination with at
// least one rvalue operand): the results may borrow the operands, so
// temporaries are rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(
    const Tensor& tensor,
    Tensor&& to_expand1,
    Tensor&& to_expand2,
    const char* api_name) = delete;
183 | |
184 | // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation. |
185 | inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> |
186 | expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) { |
187 | if (to_expand1.sizes().equals(to_expand2.sizes())) { |
188 | return std::make_tuple( |
189 | c10::MaybeOwned<Tensor>::borrowed(to_expand1), |
190 | c10::MaybeOwned<Tensor>::borrowed(to_expand2)); |
191 | } |
192 | |
193 | auto expanded_size = |
194 | infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); |
195 | return std::make_tuple( |
196 | c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)), |
197 | c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size))); |
198 | } |
199 | |
// Deleted rvalue overloads (every combination with at least one rvalue
// operand): the results may borrow the operands, so temporaries are rejected
// at compile time. See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
206 | |
207 | inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> |
208 | expand_outplace( |
209 | const Tensor& to_expand1, |
210 | const Tensor& to_expand2, |
211 | const char* api_name) { |
212 | check_defined({to_expand1, to_expand2}, api_name); |
213 | return expand_outplace(to_expand1, to_expand2); |
214 | } |
215 | |
// Deleted rvalue overloads of the checked variant (every combination with at
// least one rvalue operand): the results may borrow the operands, so
// temporaries are rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    const char* api_name) = delete;
inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    Tensor&& to_expand2,
    const char* api_name) = delete;
231 | |
232 | inline std::tuple< |
233 | c10::MaybeOwned<Tensor>, |
234 | c10::MaybeOwned<Tensor>, |
235 | c10::MaybeOwned<Tensor>> |
236 | expand_outplace( |
237 | const Tensor& to_expand1, |
238 | const Tensor& to_expand2, |
239 | const Tensor& to_expand3) { |
240 | if (to_expand1.sizes().equals(to_expand2.sizes()) && |
241 | to_expand1.sizes().equals(to_expand3.sizes())) { |
242 | return std::make_tuple( |
243 | c10::MaybeOwned<Tensor>::borrowed(to_expand1), |
244 | c10::MaybeOwned<Tensor>::borrowed(to_expand2), |
245 | c10::MaybeOwned<Tensor>::borrowed(to_expand3)); |
246 | } |
247 | |
248 | auto expanded_size12 = |
249 | infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); |
250 | auto expanded_size = |
251 | infer_size_dimvector(expanded_size12, to_expand3.sizes()); |
252 | return std::make_tuple( |
253 | c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)), |
254 | c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)), |
255 | c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size))); |
256 | } |
257 | |
// Deleted rvalue overloads of the three-tensor expand_outplace: all seven
// combinations in which at least one operand is an rvalue. The results may
// borrow the operands, so temporaries are rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    const Tensor& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    const Tensor& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    Tensor&& to_expand2,
    const Tensor& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    Tensor&& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    Tensor&& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    Tensor&& to_expand3) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
    delete;
312 | |
313 | inline std::tuple< |
314 | c10::MaybeOwned<Tensor>, |
315 | c10::MaybeOwned<Tensor>, |
316 | c10::MaybeOwned<Tensor>> |
317 | expand_outplace( |
318 | const Tensor& to_expand1, |
319 | const Tensor& to_expand2, |
320 | const Tensor& to_expand3, |
321 | const char* api_name) { |
322 | check_defined({to_expand1, to_expand2, to_expand3}, api_name); |
323 | return expand_outplace(to_expand1, to_expand2, to_expand3); |
324 | } |
325 | |
// Deleted rvalue overloads of the checked three-tensor expand_outplace: all
// seven combinations in which at least one operand is an rvalue. The results
// may borrow the operands, so temporaries are rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    const Tensor& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    const Tensor& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    Tensor&& to_expand2,
    const Tensor& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    const Tensor& to_expand2,
    Tensor&& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    const Tensor& to_expand2,
    Tensor&& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    const Tensor& to_expand1,
    Tensor&& to_expand2,
    Tensor&& to_expand3,
    const char* api_name) = delete;
inline std::tuple<
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>,
    c10::MaybeOwned<Tensor>>
expand_outplace(
    Tensor&& to_expand1,
    Tensor&& to_expand2,
    Tensor&& to_expand3,
    const char* api_name) = delete;
389 | |
390 | inline c10::MaybeOwned<Tensor> expand_size( |
391 | const Tensor& to_expand, |
392 | IntArrayRef sizes) { |
393 | if (to_expand.sizes().equals(sizes)) { |
394 | return c10::MaybeOwned<Tensor>::borrowed(to_expand); |
395 | } |
396 | |
397 | return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes)); |
398 | } |
399 | |
// Deleted rvalue overload: the result may borrow `to_expand`, so passing a
// temporary is rejected at compile time. See NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_size(
    Tensor&& to_expand,
    IntArrayRef sizes) = delete;
403 | |
404 | inline c10::MaybeOwned<Tensor> expand_size( |
405 | const Tensor& to_expand, |
406 | IntArrayRef sizes, |
407 | const char* api_name) { |
408 | check_defined({to_expand}, api_name); |
409 | return expand_size(to_expand, sizes); |
410 | } |
411 | |
// Deleted rvalue overload of the checked variant: the result may borrow
// `to_expand`, so a temporary is rejected at compile time.
// See NOTE [ ExpandUtils Borrowing ].
inline c10::MaybeOwned<Tensor> expand_size(
    Tensor&& to_expand,
    IntArrayRef sizes,
    const char* api_name) = delete;
416 | |
417 | inline std::vector<Tensor> expand_outplace(TensorList to_expand) { |
418 | // expands a list of Tensors; ignores undefined (null) tensors |
419 | bool first = true; |
420 | DimVector sizes; |
421 | for (const auto i : c10::irange(to_expand.size())) { |
422 | if (!to_expand[i].defined()) { |
423 | continue; |
424 | } else if (first) { |
425 | sizes = to_expand[i].sizes(); |
426 | first = false; |
427 | } else { |
428 | sizes = infer_size_dimvector(sizes, to_expand[i].sizes()); |
429 | } |
430 | } |
431 | |
432 | std::vector<Tensor> result(to_expand.size()); |
433 | for (const auto i : c10::irange(to_expand.size())) { |
434 | if (!to_expand[i].defined()) { |
435 | continue; |
436 | } else if (to_expand[i].sizes().equals(sizes)) { |
437 | result[i] = to_expand[i]; |
438 | } else { |
439 | result[i] = to_expand[i].expand(sizes); |
440 | } |
441 | } |
442 | return result; |
443 | } |
444 | |
445 | template <typename T> |
446 | inline Tensor _sum_to( |
447 | Tensor tensor, |
448 | const c10::ArrayRef<T> shape, |
449 | bool always_return_non_view = false) { |
450 | if (shape.size() == 0) { |
451 | return tensor.sum(); |
452 | } |
453 | |
454 | auto sizes = at::symint::sizes<T>(tensor); |
455 | c10::SmallVector<int64_t, 8> reduce_dims; |
456 | const int64_t leading_dims = sizes.size() - shape.size(); |
457 | for (const auto i : c10::irange(leading_dims)) { |
458 | reduce_dims.push_back(i); |
459 | } |
460 | for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) { |
461 | if (shape[i - leading_dims] == 1 && sizes[i] != 1) { |
462 | reduce_dims.push_back(i); |
463 | } |
464 | } |
465 | |
466 | if (!reduce_dims.empty()) { |
467 | tensor = tensor.sum(reduce_dims, /*keepdim=*/true); |
468 | } |
469 | |
470 | if (always_return_non_view) { |
471 | // This is only actually used by the functionalization pass. |
472 | // We want to be able to guarantee that this function doesn't return a view |
473 | // of the input. |
474 | return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape) |
475 | : tensor.clone(); |
476 | } else { |
477 | return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor; |
478 | } |
479 | } |
480 | |
481 | inline Tensor sum_to( |
482 | Tensor tensor, |
483 | const c10::SymIntArrayRef shape, |
484 | bool always_return_non_view = false) { |
485 | return _sum_to(std::move(tensor), shape, always_return_non_view); |
486 | } |
487 | |
488 | // Sums `tensor` repeatedly to produce a tensor of shape `shape`. |
489 | // Precondition: is_expandable_to(shape, tensor.sizes()) must be true |
490 | inline Tensor sum_to( |
491 | Tensor tensor, |
492 | const IntArrayRef shape, |
493 | bool always_return_non_view = false) { |
494 | return _sum_to(std::move(tensor), shape, always_return_non_view); |
495 | } |
496 | |
497 | static inline bool is_expandable_to( |
498 | SymIntArrayRef shape, |
499 | c10::SymIntArrayRef desired) { |
500 | size_t ndim = shape.size(); |
501 | size_t target_dim = desired.size(); |
502 | if (ndim > target_dim) { |
503 | return false; |
504 | } |
505 | for (const auto i : c10::irange(ndim)) { |
506 | const auto& size = shape[ndim - i - 1]; |
507 | const auto& target = desired[target_dim - i - 1]; |
508 | if (size != target && size != 1) { |
509 | return false; |
510 | } |
511 | } |
512 | return true; |
513 | } |
514 | |
515 | static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { |
516 | auto sym_shape = c10::SymIntArrayRef( |
517 | reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size()); |
518 | auto sym_desired = c10::SymIntArrayRef( |
519 | reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size()); |
520 | return is_expandable_to(sym_shape, sym_desired); |
521 | } |
522 | |
523 | } // namespace at |
524 | |