1#pragma once
2
3#include <ATen/DimVector.h>
4#include <ATen/EmptyTensor.h>
5#include <ATen/Tensor.h>
6#include <ATen/TensorGeometry.h>
7#include <ATen/Utils.h>
8
9#include <utility>
10
// These functions are NOT in Utils.h because this file depends on Tensor.h.
12
// Fails a TORCH_CHECK (forwarding the remaining arguments as the error
// message) unless `(cond)._is_all_true().item<bool>()` is true -- i.e.,
// presumably, unless every element of the tensor expression `cond` is true.
//
// The expansion intentionally ends WITHOUT a semicolon: the caller's own `;`
// terminates the statement, so `TORCH_CHECK_TENSOR_ALL(x, "msg");` no longer
// leaves an extra empty statement behind. NOTE(review): this assumes
// TORCH_CHECK expands to a braced `if` statement (as in c10/util/Exception.h),
// so any call site that omitted the `;` still compiles -- confirm.
#define TORCH_CHECK_TENSOR_ALL(cond, ...) \
  TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__)
15
16namespace at {
17
18// The following are utility functions for checking that arguments
19// make sense. These are particularly useful for native functions,
20// which do NO argument checking by default.
21
22struct TORCH_API TensorArg {
23 const Tensor& tensor;
24 const char* name;
25 int pos; // 1-indexed
26 TensorArg(const Tensor& tensor, const char* name, int pos)
27 : tensor(tensor), name(name), pos(pos) {}
28 // Try to mitigate any possibility of dangling reference to temporaries.
29 TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
30 const Tensor* operator->() const {
31 return &tensor;
32 }
33 const Tensor& operator*() const {
34 return tensor;
35 }
36};
37
38struct TORCH_API TensorGeometryArg {
39 TensorGeometry tensor;
40 const char* name;
41 int pos; // 1-indexed
42 /* implicit */ TensorGeometryArg(TensorArg arg)
43 : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
44 TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
45 : tensor(std::move(tensor)), name(name), pos(pos) {}
46 const TensorGeometry* operator->() const {
47 return &tensor;
48 }
49 const TensorGeometry& operator*() const {
50 return tensor;
51 }
52};
53
// C string naming the function that is running checks on its input
// arguments.
// TODO: Consider generalizing this into a call stack.
using CheckedFrom = char const*;
58
// The undefined convention: checks on a single tensor assume their argument
// is defined, but functions which take multiple tensors implicitly filter out
// undefined tensors. This makes it easier to write checks that should apply
// when a tensor is defined and be skipped when it is not.
//
// NB: This means that the multi-tensor checks take lists of TensorArg,
// not TensorGeometryArg, because the Tensor to TensorGeometry
// conversion will blow up if you have undefined tensors.
68
69TORCH_API std::ostream& operator<<(std::ostream& out, TensorGeometryArg t);
70TORCH_API void checkDim(
71 CheckedFrom c,
72 const Tensor& tensor,
73 const char* name,
74 int pos, // 1-indexed
75 int64_t dim);
76TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
77// NB: this is an inclusive-exclusive range
78TORCH_API void checkDimRange(
79 CheckedFrom c,
80 const TensorGeometryArg& t,
81 int64_t dim_start,
82 int64_t dim_end);
83TORCH_API void checkSameDim(
84 CheckedFrom c,
85 const TensorGeometryArg& t1,
86 const TensorGeometryArg& t2);
87TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
88TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
89TORCH_API void checkSize(
90 CheckedFrom c,
91 const TensorGeometryArg& t,
92 IntArrayRef sizes);
93TORCH_API void checkSize_symint(
94 CheckedFrom c,
95 const TensorGeometryArg& t,
96 c10::SymIntArrayRef sizes);
97TORCH_API void checkSize(
98 CheckedFrom c,
99 const TensorGeometryArg& t,
100 int64_t dim,
101 int64_t size);
102TORCH_API void checkSize_symint(
103 CheckedFrom c,
104 const TensorGeometryArg& t,
105 int64_t dim,
106 c10::SymInt size);
107TORCH_API void checkNumel(
108 CheckedFrom c,
109 const TensorGeometryArg& t,
110 int64_t numel);
111TORCH_API void checkSameNumel(
112 CheckedFrom c,
113 const TensorGeometryArg& t1,
114 const TensorGeometryArg& t2);
115TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
116TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
117TORCH_API void checkScalarTypes(
118 CheckedFrom c,
119 const TensorArg& t,
120 at::ArrayRef<ScalarType> l);
121TORCH_API void checkSameGPU(
122 CheckedFrom c,
123 const TensorArg& t1,
124 const TensorArg& t2);
125TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
126TORCH_API void checkSameType(
127 CheckedFrom c,
128 const TensorArg& t1,
129 const TensorArg& t2);
130TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
131TORCH_API void checkSameSize(
132 CheckedFrom c,
133 const TensorArg& t1,
134 const TensorArg& t2);
135TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
136TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
137
138// FixMe: does TensorArg slow things down?
139TORCH_API void checkBackend(
140 CheckedFrom c,
141 at::ArrayRef<Tensor> t,
142 at::Backend backend);
143
144TORCH_API void checkDeviceType(
145 CheckedFrom c,
146 at::ArrayRef<Tensor> tensors,
147 at::DeviceType device_type);
148
149TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
150
151TORCH_API void checkLayout(
152 CheckedFrom c,
153 at::ArrayRef<Tensor> tensors,
154 at::Layout layout);
155
156// Methods for getting data_ptr if tensor is defined
157TORCH_API void* maybe_data_ptr(const Tensor& tensor);
158TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
159
160TORCH_API void check_dim_size(
161 const Tensor& tensor,
162 int64_t dim,
163 int64_t dim_size,
164 int64_t size);
165
166namespace detail {
167TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
168
169TORCH_API c10::optional<std::vector<int64_t>> computeStride(
170 IntArrayRef oldshape,
171 IntArrayRef oldstride,
172 IntArrayRef newshape);
173
174TORCH_API c10::optional<SymDimVector> computeStride(
175 c10::SymIntArrayRef oldshape,
176 c10::SymIntArrayRef oldstride,
177 c10::SymIntArrayRef newshape);
178
179TORCH_API c10::optional<DimVector> computeStride(
180 IntArrayRef oldshape,
181 IntArrayRef oldstride,
182 const DimVector& newshape);
183
184} // namespace detail
185} // namespace at
186