1 | #include <ATen/core/TensorBase.h> |
---|---|
2 | |
3 | // Broadcasting utilities for working with TensorBase |
4 | namespace at { |
5 | namespace internal { |
6 | TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size); |
7 | } // namespace internal |
8 | |
9 | inline c10::MaybeOwned<TensorBase> expand_size( |
10 | const TensorBase& self, |
11 | IntArrayRef size) { |
12 | if (size.equals(self.sizes())) { |
13 | return c10::MaybeOwned<TensorBase>::borrowed(self); |
14 | } |
15 | return c10::MaybeOwned<TensorBase>::owned( |
16 | at::internal::expand_slow_path(self, size)); |
17 | } |
18 | c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) = |
19 | delete; |
20 | |
21 | inline c10::MaybeOwned<TensorBase> expand_inplace( |
22 | const TensorBase& tensor, |
23 | const TensorBase& to_expand) { |
24 | return expand_size(to_expand, tensor.sizes()); |
25 | } |
26 | c10::MaybeOwned<TensorBase> expand_inplace( |
27 | const TensorBase& tensor, |
28 | TensorBase&& to_expand) = delete; |
29 | |
30 | } // namespace at |
31 |