#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/TensorUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm> // for std::find in checkScalarTypes
#include <ostream>
#include <sstream>

namespace at {

std::ostream& operator<<(std::ostream& out, TensorGeometryArg t) {
  if (t.pos == 0) {
    // 0 is distinguished; it usually indicates 'self' or the return
    // tensor
    out << "'" << t.name << "'";
  } else {
    out << "argument #" << t.pos << " '" << t.name << "'";
  }
  return out;
}
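
// Illustrative output (hypothetical values, not tied to any real call site):
//   pos == 0:                   prints  'self'
//   pos == 2, name == "weight": prints  argument #2 'weight'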

void checkDim(
    CheckedFrom c,
    const Tensor& tensor,
    const char* name,
    int pos, // 1-indexed
    int64_t dim) {
  TORCH_CHECK(
      tensor.dim() == dim,
      "Expected ",
      dim,
      "-dimensional tensor, but got ",
      tensor.dim(),
      "-dimensional tensor for ",
      TensorGeometryArg(TensorArg({tensor, name, pos})),
      " (while checking arguments for ",
      c,
      ")");
}

void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
  TORCH_CHECK(t->dim() == dim,
      "Expected ", dim, "-dimensional tensor, but got ", t->dim(),
      "-dimensional tensor for ", t, " (while checking arguments for ", c, ")");
}
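
// Illustrative usage (hypothetical call site; `input` is a Tensor defined by
// the caller):
//   checkDim("conv2d", input, "input", 1, 4);
// On failure this raises, e.g.:
//   Expected 4-dimensional tensor, but got 3-dimensional tensor for
//   argument #1 'input' (while checking arguments for conv2d)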

void checkDimRange(CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) {
  TORCH_CHECK(
      t->dim() >= dim_start && t->dim() < dim_end,
      "Expected ", dim_start, " to ", (dim_end - 1), " dimensions, but got ",
      t->dim(), "-dimensional tensor for ", t, " (while checking arguments for ",
      c, ")");
}

void checkContiguous(CheckedFrom c, const TensorGeometryArg& t) {
  TORCH_CHECK(
      t->is_contiguous(),
      "Expected contiguous tensor, but got non-contiguous tensor for ", t,
      " (while checking arguments for ", c, ")");
}

void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
  for (auto& t : ts) {
    if (!t->defined()) continue;
    checkContiguous(c, t);
  }
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, IntArrayRef sizes) {
  checkDim(c, t, static_cast<int64_t>(sizes.size()));
  TORCH_CHECK(
      t->sizes().equals(sizes),
      "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(),
      " for ", t, " (while checking arguments for ", c, ")");
}
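
// Illustrative usage (hypothetical call site): verify that `bias` is a 1-D
// tensor of length `out_channels`:
//   checkSize("conv2d", TensorArg{bias, "bias", 3}, {out_channels});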

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, c10::SymIntArrayRef sizes) {
  checkDim(c, t, static_cast<int64_t>(sizes.size()));
  TORCH_CHECK(
      t->sym_sizes().equals(sizes),
      "Expected tensor of size ", sizes, ", but got tensor of size ", t->sym_sizes(),
      " for ", t, " (while checking arguments for ", c, ")");
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) {
  TORCH_CHECK(
      t->size(dim) == size,
      "Expected tensor to have size ", size, " at dimension ", dim,
      ", but got size ", t->size(dim), " for ", t,
      " (while checking arguments for ", c, ")");
}

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, c10::SymInt size) {
  TORCH_CHECK(
      t->sym_size(dim) == size,
      "Expected tensor to have size ", size, " at dimension ", dim,
      ", but got size ", t->sym_size(dim), " for ", t,
      " (while checking arguments for ", c, ")");
}

void checkAllSame(CheckedFrom c, ArrayRef<TensorArg> tensors, void(*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) {
  const TensorArg* t0 = nullptr;
  for (auto& t : tensors) {
    if (!t->defined()) continue;
    if (t0 != nullptr) {
      fn(c, *t0, t);
    } else {
      t0 = &t;
    }
  }
}
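
// Note: each defined tensor is compared against the first defined one, so
// `fn`'s error message always uses that first tensor as the reference.
// Illustrative usage (hypothetical TensorArgs):
//   checkAllSame("cat", {self_arg, other_arg}, checkSameSize);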

void checkSameSize(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->sizes().equals(t2->sizes()),
      "Expected tensor for ", t1, " to have the same size as tensor for ", t2,
      "; but ", t1->sizes(), " does not equal ", t2->sizes(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameSize);
}

void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel) {
  TORCH_CHECK(
      t->numel() == numel,
      "Expected tensor for ", t, " to have ", numel,
      " elements; but it actually has ", t->numel(), " elements",
      " (while checking arguments for ", c, ")");
}

void checkSameNumel(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->numel() == t2->numel(),
      "Expected tensor for ", t1,
      " to have the same number of elements as tensor for ", t2, "; but ",
      t1->numel(), " does not equal ", t2->numel(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameNumel);
}

void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  if (t1->is_cpu() || t2->is_cpu()) {
    std::ostringstream oss;
    if (t1->is_cpu()) {
      oss << "Tensor for " << t1 << " is on CPU, ";
    }
    if (t2->is_cpu()) {
      oss << "Tensor for " << t2 << " is on CPU, ";
    }
    // Say "them" only when both tensors are on CPU; otherwise "it".
    oss << "but expected " << ((t1->is_cpu() && t2->is_cpu()) ? "them" : "it")
        << " to be on GPU (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
  TORCH_CHECK(
      t1->get_device() == t2->get_device(),
      "Expected tensor for ", t1, " to have the same device as tensor for ", t2,
      "; but device ", t1->get_device(), " does not equal ", t2->get_device(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameGPU);
}

void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->options().type_equal(t2->options()),
      "Expected tensor for ", t1, " to have the same type as tensor for ", t2,
      "; but type ", t1->toString(), " does not equal ", t2->toString(),
      " (while checking arguments for ", c, ")");
}

void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
  TORCH_CHECK(
      t->scalar_type() == ty,
      "Expected tensor for ", t, " to have scalar type ", toString(ty),
      "; but got ", t->toString(), " instead (while checking arguments for ", c,
      ")");
}

void checkScalarTypes(CheckedFrom c, const TensorArg& t,
                      at::ArrayRef<ScalarType> l) {
  if (std::find(l.begin(), l.end(), t->scalar_type()) == l.end()) {
    std::ostringstream oss;
    oss << "Expected tensor for " << t << " to have one of the following "
        << "scalar types: ";
    size_t i = 0;
    for (auto ty : l) {
      if (i != 0) {
        oss << ", ";
      }
      oss << toString(ty);
      i++;
    }
    oss << "; but got " << t->toString()
        << " instead (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
}

void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameType);
}

void checkSameDim(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) {
  TORCH_CHECK(
      t1->dim() == t2->dim(),
      "Expected tensor for ", t1, " to have the same dimension as tensor for ",
      t2, "; but ", t1->dim(), " does not equal ", t2->dim(),
      " (while checking arguments for ", c, ")");
}

void checkDefined(CheckedFrom c, const TensorArg& t) {
  TORCH_CHECK(
      t->defined(),
      "Expected tensor for ", t, " to be non-null, but it was undefined",
      " (while checking arguments for ", c, ")");
}

void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
  // NB: don't filter defined here
  for (auto t : ts) {
    checkDefined(c, t);
  }
}

void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
  TORCH_CHECK(
      !t.defined() || t.options().backend() == backend,
      "Expected tensor to have ", toString(backend),
      " Backend, but got tensor with ", toString(t.options().backend()), " Backend ",
      "(while checking arguments for ", c, ")");
}

void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Backend backend) {
  for (auto& t : tensors) {
    checkBackend(c, t, backend);
  }
}

void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) {
  TORCH_CHECK(
      !t.defined() || t.device().type() == device_type,
      "Expected tensor to have ", device_type,
      " DeviceType, but got tensor with ", t.device().type(), " DeviceType ",
      "(while checking arguments for ", c, ")");
}

void checkDeviceType(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::DeviceType device_type) {
  for (auto& t : tensors) {
    checkDeviceType(c, t, device_type);
  }
}

void checkLayout(CheckedFrom c, const Tensor& t, Layout layout) {
  TORCH_CHECK(
      !t.defined() || t.layout() == layout,
      "Expected tensor to have ", layout,
      " Layout, but got tensor with ", t.layout(), " Layout ",
      "(while checking arguments for ", c, ")");
}

void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout) {
  for (auto& t : tensors) {
    checkLayout(c, t, layout);
  }
}

void* maybe_data_ptr(const Tensor& tensor) {
  return tensor.defined() ? (void*)tensor.data_ptr() : nullptr;
}

void* maybe_data_ptr(const TensorArg& tensor) {
  return tensor->defined() ? (void*)tensor->data_ptr() : nullptr;
}

// Check the dimensionality of a tensor and the size of one of its dimensions.
void check_dim_size(
    const Tensor& tensor,
    int64_t dim,
    int64_t dim_size,
    int64_t size) {
  TORCH_CHECK(
      tensor.dim() == dim && tensor.size(dim_size) == size,
      "Expected a tensor of dimension ",
      dim,
      " and tensor.size[",
      dim_size,
      "] == ",
      size,
      " but got: dimension ",
      tensor.dim(),
      " and tensor.size[",
      dim_size,
      "] = ",
      tensor.size(dim_size));
}
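
// The parameter names here are easy to misread: `dim` is the expected
// dimensionality of the whole tensor, while `dim_size` is the *index* of the
// dimension whose extent must equal `size`. Illustrative check that a 4-D
// NCHW tensor has 3 channels (`input` is hypothetical):
//   check_dim_size(input, 4, 1, 3);  // input.dim() == 4 && input.size(1) == 3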

namespace detail {

std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  for (size_t i = sizes.size(); i > 0; --i) {
    strides[i - 1] = stride;
    stride *= sizes[i - 1];
  }
  return strides;
}
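
// Worked example: for sizes [2, 3, 4] the loop fills strides right to left,
// yielding the contiguous (row-major) strides [12, 4, 1]:
//   strides[2] = 1;  strides[1] = 1 * 4 = 4;  strides[0] = 4 * 3 = 12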

// At a high level,
// 1. separate `oldshape` into chunks of dimensions, where the dimensions are
//    ``contiguous'' in each chunk, i.e.,
//    oldstride[i] = oldshape[i+1] * oldstride[i+1]
// 2. `newshape` must be able to be separated into the same number of chunks
//    as `oldshape` was, where each chunk of `newshape` has matching
//    ``numel'', i.e., number of subspaces, as the corresponding chunk of
//    `oldshape`.
//
// Templatized for the DimVector and IntArrayRef use cases;
// see the overloads of computeStride() below.
//
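// Worked example: oldshape = [4, 6], oldstride = [6, 1] (contiguous), and
// newshape = [2, 2, 6]. All of `oldshape` is one chunk (oldstride[0] ==
// oldshape[1] * oldstride[1]), and [2, 2, 6] spans the same 24 subspaces,
// so the view succeeds with newstride = [12, 6, 1]. By contrast, with
// oldstride = [10, 1] (padded rows) there are two chunks, [4] and [6];
// newshape = [24] would have to merge them, so the function returns
// c10::nullopt and the caller must materialize a copy instead.
//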
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline c10::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
  const Numel numel = c10::multiply_integers(oldshape);
  if (numel == 0 && oldshape.equals(newshape)) {
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (numel == 0) {
    for (int64_t view_d = static_cast<int64_t>(newshape.size()) - 1; view_d >= 0; view_d--) {
      if (view_d == static_cast<int64_t>(newshape.size()) - 1) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
            std::max<Numel>(newshape[view_d + 1], Numel(1)) * newstride[view_d + 1];
      }
    }
    return newstride;
  }

  int64_t view_d = static_cast<int64_t>(newshape.size()) - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  Numel view_numel = 1;
  for (int64_t tensor_d = static_cast<int64_t>(oldshape.size()) - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    if ((tensor_d == 0) ||
        (oldshape[tensor_d - 1] != 1 &&
         oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
      while (view_d >= 0 &&
             (view_numel < tensor_numel || newshape[view_d] == 1)) {
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      if (view_numel != tensor_numel) {
        return c10::nullopt;
      }
      if (tensor_d > 0) {
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  if (view_d != -1) {
    return c10::nullopt;
  }
  return newstride;
}

c10::optional<std::vector<int64_t>> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    IntArrayRef newshape) {
  auto toResult = [](const IntArrayRef& a) { return a.vec(); };
  return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}
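
// Illustrative calls (values from the worked example above):
//   auto ok = computeStride(
//       IntArrayRef{4, 6}, IntArrayRef{6, 1}, IntArrayRef{2, 2, 6});
//   // ok.has_value() && ok->at(0) == 12 && ok->at(1) == 6 && ok->at(2) == 1
//   auto fail = computeStride(
//       IntArrayRef{4, 6}, IntArrayRef{10, 1}, IntArrayRef{24});
//   // !fail.has_value(): padded rows cannot be merged into one dimension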

c10::optional<SymDimVector> computeStride(
    c10::SymIntArrayRef oldshape,
    c10::SymIntArrayRef oldstride,
    c10::SymIntArrayRef newshape) {
  auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
  return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
}

c10::optional<DimVector> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    const DimVector& newshape) {
  auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
  return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

} // namespace detail
} // namespace at