1 | #pragma once |
2 | |
3 | #include <ATen/DimVector.h> |
4 | #include <ATen/EmptyTensor.h> |
5 | #include <ATen/Tensor.h> |
6 | #include <ATen/TensorGeometry.h> |
7 | #include <ATen/Utils.h> |
8 | |
9 | #include <utility> |
10 | |
11 | // These functions are NOT in Utils.h, because this file has a dep on Tensor.h |
12 | |
// Runs TORCH_CHECK on `(cond)._is_all_true().item<bool>()`, forwarding the
// remaining macro arguments as the error message. Presumably this passes iff
// every element of the boolean tensor `cond` is true — confirm against
// Tensor::_is_all_true.
// NOTE(review): `.item<bool>()` extracts a scalar value from the tensor; for
// non-CPU tensors this may force a device synchronization — confirm before
// using this on hot paths.
// NOTE(review): the expansion already ends in ';', so the usual call form
// `TORCH_CHECK_TENSOR_ALL(x, "msg");` produces an extra empty statement.
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
  TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
15 | |
16 | namespace at { |
17 | |
18 | // The following are utility functions for checking that arguments |
19 | // make sense. These are particularly useful for native functions, |
20 | // which do NO argument checking by default. |
21 | |
22 | struct TORCH_API TensorArg { |
23 | const Tensor& tensor; |
24 | const char* name; |
25 | int pos; // 1-indexed |
26 | TensorArg(const Tensor& tensor, const char* name, int pos) |
27 | : tensor(tensor), name(name), pos(pos) {} |
28 | // Try to mitigate any possibility of dangling reference to temporaries. |
29 | TensorArg(Tensor&& tensor, const char* name, int pos) = delete; |
30 | const Tensor* operator->() const { |
31 | return &tensor; |
32 | } |
33 | const Tensor& operator*() const { |
34 | return tensor; |
35 | } |
36 | }; |
37 | |
38 | struct TORCH_API TensorGeometryArg { |
39 | TensorGeometry tensor; |
40 | const char* name; |
41 | int pos; // 1-indexed |
42 | /* implicit */ TensorGeometryArg(TensorArg arg) |
43 | : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {} |
44 | TensorGeometryArg(TensorGeometry tensor, const char* name, int pos) |
45 | : tensor(std::move(tensor)), name(name), pos(pos) {} |
46 | const TensorGeometry* operator->() const { |
47 | return &tensor; |
48 | } |
49 | const TensorGeometry& operator*() const { |
50 | return tensor; |
51 | } |
52 | }; |
53 | |
// CheckedFrom is a plain C string naming the function that is validating its
// input arguments.
// TODO: Consider generalizing this into a call stack.
using CheckedFrom = char const*;
58 | |
59 | // The undefined convention: singular operators assume their arguments |
60 | // are defined, but functions which take multiple tensors will |
61 | // implicitly filter out undefined tensors (to make it easier to perform |
62 | // tests which should apply if the tensor is defined, and should not |
63 | // otherwise.) |
64 | // |
65 | // NB: This means that the n-ary operators take lists of TensorArg, |
66 | // not TensorGeometryArg, because the Tensor to TensorGeometry |
67 | // conversion will blow up if you have undefined tensors. |
68 | |
69 | TORCH_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t); |
70 | TORCH_API void checkDim( |
71 | CheckedFrom c, |
72 | const Tensor& tensor, |
73 | const char* name, |
74 | int pos, // 1-indexed |
75 | int64_t dim); |
76 | TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim); |
77 | // NB: this is an inclusive-exclusive range |
78 | TORCH_API void checkDimRange( |
79 | CheckedFrom c, |
80 | const TensorGeometryArg& t, |
81 | int64_t dim_start, |
82 | int64_t dim_end); |
83 | TORCH_API void checkSameDim( |
84 | CheckedFrom c, |
85 | const TensorGeometryArg& t1, |
86 | const TensorGeometryArg& t2); |
87 | TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); |
88 | TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts); |
89 | TORCH_API void checkSize( |
90 | CheckedFrom c, |
91 | const TensorGeometryArg& t, |
92 | IntArrayRef sizes); |
93 | TORCH_API void checkSize_symint( |
94 | CheckedFrom c, |
95 | const TensorGeometryArg& t, |
96 | c10::SymIntArrayRef sizes); |
97 | TORCH_API void checkSize( |
98 | CheckedFrom c, |
99 | const TensorGeometryArg& t, |
100 | int64_t dim, |
101 | int64_t size); |
102 | TORCH_API void checkSize_symint( |
103 | CheckedFrom c, |
104 | const TensorGeometryArg& t, |
105 | int64_t dim, |
106 | c10::SymInt size); |
107 | TORCH_API void checkNumel( |
108 | CheckedFrom c, |
109 | const TensorGeometryArg& t, |
110 | int64_t numel); |
111 | TORCH_API void ( |
112 | CheckedFrom c, |
113 | const TensorGeometryArg& t1, |
114 | const TensorGeometryArg& t2); |
115 | TORCH_API void (CheckedFrom c, ArrayRef<TensorArg> tensors); |
116 | TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s); |
117 | TORCH_API void checkScalarTypes( |
118 | CheckedFrom c, |
119 | const TensorArg& t, |
120 | at::ArrayRef<ScalarType> l); |
121 | TORCH_API void checkSameGPU( |
122 | CheckedFrom c, |
123 | const TensorArg& t1, |
124 | const TensorArg& t2); |
125 | TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors); |
126 | TORCH_API void checkSameType( |
127 | CheckedFrom c, |
128 | const TensorArg& t1, |
129 | const TensorArg& t2); |
130 | TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors); |
131 | TORCH_API void checkSameSize( |
132 | CheckedFrom c, |
133 | const TensorArg& t1, |
134 | const TensorArg& t2); |
135 | TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t); |
136 | TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t); |
137 | |
138 | // FixMe: does TensorArg slow things down? |
139 | TORCH_API void checkBackend( |
140 | CheckedFrom c, |
141 | at::ArrayRef<Tensor> t, |
142 | at::Backend backend); |
143 | |
144 | TORCH_API void checkDeviceType( |
145 | CheckedFrom c, |
146 | at::ArrayRef<Tensor> tensors, |
147 | at::DeviceType device_type); |
148 | |
149 | TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); |
150 | |
151 | TORCH_API void checkLayout( |
152 | CheckedFrom c, |
153 | at::ArrayRef<Tensor> tensors, |
154 | at::Layout layout); |
155 | |
156 | // Methods for getting data_ptr if tensor is defined |
157 | TORCH_API void* maybe_data_ptr(const Tensor& tensor); |
158 | TORCH_API void* maybe_data_ptr(const TensorArg& tensor); |
159 | |
160 | TORCH_API void check_dim_size( |
161 | const Tensor& tensor, |
162 | int64_t dim, |
163 | int64_t dim_size, |
164 | int64_t size); |
165 | |
166 | namespace detail { |
167 | TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes); |
168 | |
169 | TORCH_API c10::optional<std::vector<int64_t>> computeStride( |
170 | IntArrayRef oldshape, |
171 | IntArrayRef oldstride, |
172 | IntArrayRef newshape); |
173 | |
174 | TORCH_API c10::optional<SymDimVector> computeStride( |
175 | c10::SymIntArrayRef oldshape, |
176 | c10::SymIntArrayRef oldstride, |
177 | c10::SymIntArrayRef newshape); |
178 | |
179 | TORCH_API c10::optional<DimVector> computeStride( |
180 | IntArrayRef oldshape, |
181 | IntArrayRef oldstride, |
182 | const DimVector& newshape); |
183 | |
184 | } // namespace detail |
185 | } // namespace at |
186 | |