1 | #pragma once |
2 | |
#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Optional.h>
#include <sstream>
#include <stdexcept>
#include <vector>
10 | |
11 | namespace at { |
12 | |
13 | // Infers the size of a dim with size -1, if it exists. Also checks that new |
14 | // shape is compatible with the number of elements. |
15 | // |
16 | // templated to handle std::vector<int64_t> and DimVector use cases, see |
17 | // below |
18 | // |
19 | template <typename InputArrayRef, typename NumelType, typename ResultVec> |
20 | inline void infer_size_impl( |
21 | InputArrayRef shape, |
22 | NumelType numel, |
23 | ResultVec& res) { |
24 | NumelType newsize = 1; |
25 | // N.B. this is an index, not a sym dim! |
26 | auto infer_dim = c10::optional<int64_t>(); |
27 | for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) { |
28 | if (shape[dim] == -1) { |
29 | if (infer_dim) { |
30 | throw std::runtime_error("only one dimension can be inferred" ); |
31 | } |
32 | infer_dim = dim; |
33 | } else if (shape[dim] >= 0) { |
34 | newsize *= shape[dim]; |
35 | } else { |
36 | AT_ERROR("invalid shape dimension " , shape[dim]); |
37 | } |
38 | } |
39 | |
40 | if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) { |
41 | if (infer_dim) { |
42 | // We have a degree of freedom here to select the dimension size; follow |
43 | // NumPy semantics and just bail. However, a nice error message is needed |
44 | // because users often use `view` as a way to flatten & unflatten |
45 | // dimensions and will otherwise be confused why |
46 | // empty_tensor.view( 0, 0) |
47 | // works yet |
48 | // empty_tensor.view(-1, 0) |
49 | // doesn't. |
50 | TORCH_CHECK( |
51 | newsize != 0, |
52 | "cannot reshape tensor of 0 elements into shape " , |
53 | shape, |
54 | " because the unspecified dimension size -1 can be any " |
55 | "value and is ambiguous" ); |
56 | res[*infer_dim] = numel / newsize; |
57 | } |
58 | return; |
59 | } |
60 | |
61 | std::ostringstream ss; |
62 | ss << "shape '" << shape << "' is invalid for input of size " << numel; |
63 | throw std::runtime_error(ss.str()); |
64 | } |
65 | |
66 | inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) { |
67 | auto res = shape.vec(); |
68 | infer_size_impl(shape, numel, res); |
69 | return res; |
70 | } |
71 | |
72 | inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) { |
73 | auto res = at::DimVector(shape); |
74 | infer_size_impl(shape, numel, res); |
75 | return res; |
76 | } |
77 | |
78 | inline at::SymDimVector infer_size_dv( |
79 | c10::SymIntArrayRef shape, |
80 | c10::SymInt numel) { |
81 | auto res = at::SymDimVector(shape); |
82 | infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>( |
83 | shape, std::move(numel), res); |
84 | return res; |
85 | } |
86 | |
87 | } // namespace at |
88 | |