1#pragma once
2
#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Optional.h>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <vector>
10
11namespace at {
12
13// Infers the size of a dim with size -1, if it exists. Also checks that new
14// shape is compatible with the number of elements.
15//
16// templated to handle std::vector<int64_t> and DimVector use cases, see
17// below
18//
19template <typename InputArrayRef, typename NumelType, typename ResultVec>
20inline void infer_size_impl(
21 InputArrayRef shape,
22 NumelType numel,
23 ResultVec& res) {
24 NumelType newsize = 1;
25 // N.B. this is an index, not a sym dim!
26 auto infer_dim = c10::optional<int64_t>();
27 for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
28 if (shape[dim] == -1) {
29 if (infer_dim) {
30 throw std::runtime_error("only one dimension can be inferred");
31 }
32 infer_dim = dim;
33 } else if (shape[dim] >= 0) {
34 newsize *= shape[dim];
35 } else {
36 AT_ERROR("invalid shape dimension ", shape[dim]);
37 }
38 }
39
40 if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
41 if (infer_dim) {
42 // We have a degree of freedom here to select the dimension size; follow
43 // NumPy semantics and just bail. However, a nice error message is needed
44 // because users often use `view` as a way to flatten & unflatten
45 // dimensions and will otherwise be confused why
46 // empty_tensor.view( 0, 0)
47 // works yet
48 // empty_tensor.view(-1, 0)
49 // doesn't.
50 TORCH_CHECK(
51 newsize != 0,
52 "cannot reshape tensor of 0 elements into shape ",
53 shape,
54 " because the unspecified dimension size -1 can be any "
55 "value and is ambiguous");
56 res[*infer_dim] = numel / newsize;
57 }
58 return;
59 }
60
61 std::ostringstream ss;
62 ss << "shape '" << shape << "' is invalid for input of size " << numel;
63 throw std::runtime_error(ss.str());
64}
65
66inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
67 auto res = shape.vec();
68 infer_size_impl(shape, numel, res);
69 return res;
70}
71
72inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
73 auto res = at::DimVector(shape);
74 infer_size_impl(shape, numel, res);
75 return res;
76}
77
78inline at::SymDimVector infer_size_dv(
79 c10::SymIntArrayRef shape,
80 c10::SymInt numel) {
81 auto res = at::SymDimVector(shape);
82 infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
83 shape, std::move(numel), res);
84 return res;
85}
86
87} // namespace at
88