1 | #pragma once |
2 | #include <ATen/core/Tensor.h> |
3 | #include <c10/util/irange.h> |
4 | #include <ATen/core/IListRef.h> |
5 | |
6 | namespace at { |
7 | namespace native { |
8 | |
9 | TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self); |
10 | |
11 | inline bool cat_should_skip_tensor(const Tensor& t) { |
12 | return t.numel() == 0 && t.dim() == 1; |
13 | } |
14 | |
15 | // Check to see if the shape of tensors is compatible |
16 | // for being concatenated along a given dimension. |
17 | inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) { |
18 | int64_t first_dims = first.dim(); |
19 | int64_t second_dims = second.dim(); |
20 | TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got " , |
21 | first_dims, " and " , second_dims); |
22 | for (const auto dim : c10::irange(first_dims)) { |
23 | if (dim == dimension) { |
24 | continue; |
25 | } |
26 | int64_t first_dim_size = first.sizes()[dim]; |
27 | int64_t second_dim_size = second.sizes()[dim]; |
28 | TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension " , |
29 | dimension, ". Expected size " , static_cast<long long>(first_dim_size), " but got size " , static_cast<long long>(second_dim_size), " for tensor number " , index, " in the list." ); |
30 | } |
31 | } |
32 | |
33 | inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) { |
34 | int64_t i = 0; |
35 | for(const Tensor& t : tensors) { |
36 | TORCH_CHECK(t.dim() > 0, |
37 | "zero-dimensional tensor (at position " , i, ") cannot be concatenated" ); |
38 | i++; |
39 | } |
40 | } |
41 | |
42 | inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) { |
43 | TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor" ); |
44 | TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=" , split_size); |
45 | int64_t dim_size = self.size(dim); |
46 | TORCH_CHECK(split_size > 0 || dim_size == 0, |
47 | "split_size can only be 0 if dimension size is 0, " |
48 | "but got dimension size of " , dim_size); |
49 | // if split_size is 0 and dimension size is 0, there is 1 split. |
50 | int64_t num_splits = 1; |
51 | if (split_size != 0) { |
52 | // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size |
53 | // (returns a single split). We might want to error here, but keep it for BC. |
54 | num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1); |
55 | } |
56 | return num_splits; |
57 | } |
58 | |
59 | }} // namespace at::native |
60 | |