1 | #pragma once |
2 | |
3 | #include <ATen/EmptyTensor.h> |
4 | #include <ATen/Formatting.h> |
5 | #include <ATen/core/ATenGeneral.h> |
6 | #include <ATen/core/Generator.h> |
7 | #include <c10/core/ScalarType.h> |
8 | #include <c10/core/StorageImpl.h> |
9 | #include <c10/core/UndefinedTensorImpl.h> |
10 | #include <c10/util/ArrayRef.h> |
11 | #include <c10/util/Exception.h> |
12 | #include <c10/util/accumulate.h> |
13 | #include <c10/util/irange.h> |
14 | |
15 | #include <algorithm> |
16 | #include <memory> |
17 | #include <numeric> |
18 | #include <sstream> |
19 | #include <typeinfo> |
20 | |
// Deletes the copy constructor and copy assignment operator of `TypeName`.
// Expand inside the class body (conventionally in the private section) to
// make a class non-copyable, e.g.:
//   class Foo { AT_DISALLOW_COPY_AND_ASSIGN(Foo); ... };
// Note the deliberately missing trailing semicolon: callers supply it.
#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;         \
  void operator=(const TypeName&) = delete
24 | |
25 | namespace at { |
26 | |
27 | TORCH_API int _crash_if_asan(int); |
28 | |
29 | // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*) |
30 | // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat. |
31 | // Once cat is ported entirely to ATen this can be deleted! |
32 | static inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap( |
33 | ArrayRef<Tensor> tensors, |
34 | const char* name, |
35 | int pos, |
36 | DeviceType device_type, |
37 | ScalarType scalar_type) { |
38 | std::vector<TensorImpl*> unwrapped; |
39 | unwrapped.reserve(tensors.size()); |
40 | for (const auto i : c10::irange(tensors.size())) { |
41 | const auto& expr = tensors[i]; |
42 | if (expr.layout() != Layout::Strided) { |
43 | AT_ERROR( |
44 | "Expected dense tensor but got " , |
45 | expr.layout(), |
46 | " for sequence element " , |
47 | i, |
48 | " in sequence argument at position #" , |
49 | pos, |
50 | " '" , |
51 | name, |
52 | "'" ); |
53 | } |
54 | if (expr.device().type() != device_type) { |
55 | AT_ERROR( |
56 | "Expected object of device type " , |
57 | device_type, |
58 | " but got device type " , |
59 | expr.device().type(), |
60 | " for sequence element " , |
61 | i, |
62 | " in sequence argument at position #" , |
63 | pos, |
64 | " '" , |
65 | name, |
66 | "'" ); |
67 | } |
68 | if (expr.scalar_type() != scalar_type) { |
69 | AT_ERROR( |
70 | "Expected object of scalar type " , |
71 | scalar_type, |
72 | " but got scalar type " , |
73 | expr.scalar_type(), |
74 | " for sequence element " , |
75 | i, |
76 | " in sequence argument at position #" , |
77 | pos, |
78 | " '" , |
79 | name, |
80 | "'" ); |
81 | } |
82 | unwrapped.emplace_back(expr.unsafeGetTensorImpl()); |
83 | } |
84 | return unwrapped; |
85 | } |
86 | |
87 | template <size_t N> |
88 | std::array<int64_t, N> check_intlist( |
89 | ArrayRef<int64_t> list, |
90 | const char* name, |
91 | int pos) { |
92 | if (list.empty()) { |
93 | // TODO: is this necessary? We used to treat nullptr-vs-not in IntList |
94 | // differently with strides as a way of faking optional. |
95 | list = {}; |
96 | } |
97 | auto res = std::array<int64_t, N>(); |
98 | if (list.size() == 1 && N > 1) { |
99 | res.fill(list[0]); |
100 | return res; |
101 | } |
102 | if (list.size() != N) { |
103 | AT_ERROR( |
104 | "Expected a list of " , |
105 | N, |
106 | " ints but got " , |
107 | list.size(), |
108 | " for argument #" , |
109 | pos, |
110 | " '" , |
111 | name, |
112 | "'" ); |
113 | } |
114 | std::copy_n(list.begin(), N, res.begin()); |
115 | return res; |
116 | } |
117 | |
118 | using at::detail::check_size_nonnegative; |
119 | |
namespace detail {

// Forward declarations of tensor-from-values factory helpers; definitions
// are not visible from this header. Each takes a flat list of `values` and
// a TensorOptions describing the result.
// NOTE(review): based on the naming, the *_cpu variants presumably handle
// CPU construction, the *_backend variants other backends, and the
// *_complex* variants complex element types — confirm at the definition
// sites, since this header only declares them.

template <typename T>
TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_backend(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);

template <typename T>
TORCH_API Tensor
tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
} // namespace detail
137 | |
138 | } // namespace at |
139 | |