#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/TensorUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <ostream>
#include <sstream>

namespace at {

std::ostream& operator<<(std::ostream & out, TensorGeometryArg t) {
  if (t.pos == 0) {
    // 0 is distinguished; it usually indicates 'self' or the return
    // tensor
    out << "'" << t.name << "'";
  } else {
    out << "argument #" << t.pos << " '" << t.name << "'";
  }
  return out;
}

void checkDim(
    CheckedFrom c,
    const Tensor& tensor,
    const char* name,
    int pos, // 1-indexed
    int64_t dim) {
  TORCH_CHECK(
      tensor.dim() == dim,
      "Expected ",
      dim,
      "-dimensional tensor, but got ",
      tensor.dim(),
      "-dimensional tensor for ",
      TensorGeometryArg(TensorArg({tensor, name, pos})),
      " (while checking arguments for ",
      c,
      ")");
}

void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
  TORCH_CHECK(t->dim() == dim,
    "Expected ", dim, "-dimensional tensor, but got ", t->dim(),
    "-dimensional tensor for ", t, " (while checking arguments for ", c, ")");
}

void checkDimRange(CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) {
  TORCH_CHECK(
      t->dim() >= dim_start && t->dim() < dim_end,
      "Expected ", dim_start, " to ", (dim_end - 1), " dimensions, but got ",
      t->dim(), "-dimensional tensor for ", t, " (while checking arguments for ",
      c, ")");
}

void checkContiguous(CheckedFrom c, const TensorGeometryArg& t) {
  TORCH_CHECK(
      t->is_contiguous(),
      "Expected contiguous tensor, but got non-contiguous tensor for ", t,
      " (while checking arguments for ", c, ")");
}

void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
  for (auto& t : ts) {
    if (!t->defined()) continue;
    checkContiguous(c, t);
  }
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, IntArrayRef sizes) {
  checkDim(c, t, sizes.size());
  TORCH_CHECK(
      t->sizes().equals(sizes),
      "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(),
      " for ", t, " (while checking arguments for ", c, ")");
}

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, c10::SymIntArrayRef sizes) {
  checkDim(c, t, sizes.size());
  TORCH_CHECK(
      t->sym_sizes().equals(sizes),
      "Expected tensor of size ", sizes, ", but got tensor of size ", t->sym_sizes(),
      " for ", t, " (while checking arguments for ", c, ")");
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) {
  TORCH_CHECK(
      t->size(dim) == size,
      "Expected tensor to have size ", size, " at dimension ", dim,
      ", but got size ", t->size(dim), " for ", t,
      " (while checking arguments for ", c, ")");
}

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, c10::SymInt size) {
  TORCH_CHECK(
      t->sym_size(dim) == size,
      "Expected tensor to have size ", size, " at dimension ", dim,
      ", but got size ", t->sym_size(dim), " for ", t,
      " (while checking arguments for ", c, ")");
}

void checkAllSame(CheckedFrom c, ArrayRef<TensorArg> tensors, void(*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) {
  const TensorArg* t0 = nullptr;
  for (auto& t : tensors) {
    if (!t->defined()) continue;
    if (t0 != nullptr) {
      fn(c, *t0, t);
    } else {
      t0 = &t;
    }
  }
}

void checkSameSize(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->sizes().equals(t2->sizes()),
      "Expected tensor for ", t1, " to have same size as tensor for ", t2,
      "; but ", t1->sizes(), " does not equal ", t2->sizes(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameSize);
}

void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel) {
  TORCH_CHECK(
      t->numel() == numel,
      "Expected tensor for ", t, " to have ", numel,
      " elements; but it actually has ", t->numel(), " elements",
      " (while checking arguments for ", c, ")");
}

void checkSameNumel(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->numel() == t2->numel(),
      "Expected tensor for ", t1,
      " to have same number of elements as tensor for ", t2, "; but ",
      t1->numel(), " does not equal ", t2->numel(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameNumel);
}

void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  if (t1->is_cpu() || t2->is_cpu()) {
    std::ostringstream oss;
    if (t1->is_cpu()) {
      oss << "Tensor for " << t1 << " is on CPU, ";
    }
    if (t2->is_cpu()) {
      oss << "Tensor for " << t2 << " is on CPU, ";
    }
    oss << "but expected " << ((t1->is_cpu() && t2->is_cpu()) ? "them" : "it")
        << " to be on GPU (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
  TORCH_CHECK(
      t1->get_device() == t2->get_device(),
      "Expected tensor for ", t1, " to have the same device as tensor for ", t2,
      "; but device ", t1->get_device(), " does not equal ", t2->get_device(),
      " (while checking arguments for ", c, ")");
}

void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameGPU);
}

void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
      t1->options().type_equal(t2->options()),
      "Expected tensor for ", t1, " to have the same type as tensor for ", t2,
      "; but type ", t1->toString(), " does not equal ", t2->toString(),
      " (while checking arguments for ", c, ")");
}

void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
  TORCH_CHECK(
      t->scalar_type() == ty,
      "Expected tensor for ", t, " to have scalar type ", toString(ty),
      "; but got ", t->toString(), " instead (while checking arguments for ", c,
      ")");
}

void checkScalarTypes(CheckedFrom c, const TensorArg& t,
                      at::ArrayRef<ScalarType> l) {
  if (std::find(l.begin(), l.end(), t->scalar_type()) == l.end()) {
    std::ostringstream oss;
    oss << "Expected tensor for " << t << " to have one of the following "
        << "scalar types: ";
    size_t i = 0;
    for (auto ty : l) {
      if (i != 0) {
        oss << ", ";
      }
      oss << toString(ty);
      i++;
    }
    oss << "; but got " << t->toString()
        << " instead (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
}

void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameType);
}

void checkSameDim(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) {
  TORCH_CHECK(
      t1->dim() == t2->dim(),
      "Expected tensor for ", t1, " to have the same dimension as tensor for ",
      t2, "; but ", t1->dim(), " does not equal ", t2->dim(),
      " (while checking arguments for ", c, ")");
}

void checkDefined(CheckedFrom c, const TensorArg& t) {
  TORCH_CHECK(
      t->defined(),
      "Expected tensor for ", t, " to be non-null, but it was undefined",
      " (while checking arguments for ", c, ")");
}

void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
  // NB: don't filter defined here
  for (auto t : ts) {
    checkDefined(c, t);
  }
}

void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
  TORCH_CHECK(
      !t.defined() || t.options().backend() == backend,
      "Expected tensor to have ", toString(backend),
      " Backend, but got tensor with ", toString(t.options().backend()), " Backend ",
      "(while checking arguments for ", c, ")");
}

void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Backend backend) {
  for (auto &t : tensors) {
    checkBackend(c, t, backend);
  }
}

void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) {
  TORCH_CHECK(
      !t.defined() || t.device().type() == device_type,
      "Expected tensor to have ", device_type,
      " DeviceType, but got tensor with ", t.device().type(), " DeviceType ",
      "(while checking arguments for ", c, ")");
}

void checkDeviceType(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::DeviceType device_type) {
  for (auto &t : tensors) {
    checkDeviceType(c, t, device_type);
  }
}

void checkLayout(CheckedFrom c, const Tensor& t, Layout layout) {
  TORCH_CHECK(
      !t.defined() || t.layout() == layout,
      "Expected tensor to have ", layout,
      " Layout, but got tensor with ", t.layout(), " Layout ",
      "(while checking arguments for ", c, ")");
}

void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout) {
  for (auto &t : tensors) {
    checkLayout(c, t, layout);
  }
}

void * maybe_data_ptr(const Tensor& tensor) {
  return tensor.defined() ? (void *)tensor.data_ptr() : nullptr;
}

void * maybe_data_ptr(const TensorArg& tensor) {
  return tensor->defined() ? (void *)tensor->data_ptr() : nullptr;
}

void check_dim_size(
    const Tensor& tensor,
    int64_t dim,
    int64_t dim_size,
    int64_t size) {
  /* Check dimension size of a tensor */
  TORCH_CHECK(
      tensor.dim() == dim && tensor.size(dim_size) == size,
      "Expected a tensor of dimension ",
      dim,
      " and tensor.size[",
      dim_size,
      "] == ",
      size,
      " but got: dimension ",
      tensor.dim(),
      " and tensor.size[",
      dim_size,
      "] = ",
      tensor.size(dim_size));
}

namespace detail {

// Default (contiguous, row-major) strides for the given sizes,
// e.g. sizes {2, 3, 4} yield strides {12, 4, 1}.
std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  for(size_t i = sizes.size(); i > 0; --i) {
    strides[i-1] = stride;
    stride *= sizes[i-1];
  }
  return strides;
}

// At a high level,
// 1. separate `oldshape` into chunks of dimensions, where the dimensions are
//    ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] *
//    oldstride[i+1]
// 2. `newshape` must be able to be separated into the same number of chunks as
//    `oldshape` was separated into, where each chunk of newshape has matching
//    ``numel'', i.e., number of subspaces, as the corresponding chunk of
//    `oldshape`.
//
// templatized for DimVector and IntArrayRef use cases,
// see overloads of computeStride() below.
//
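// A small worked example (illustrative only, not tied to any particular
// caller): a contiguous tensor with oldshape = [2, 3, 4] and
// oldstride = [12, 4, 1] forms a single chunk, since 12 == 3 * 4 and
// 4 == 4 * 1; viewing it as newshape = [6, 4] therefore succeeds with
// newstride = [4, 1]. By contrast, a transposed tensor with
// oldshape = [3, 2] and oldstride = [1, 3] splits into two
// single-dimension chunks, so newshape = [6] cannot be matched
// chunk-by-chunk and the function returns nullopt (the caller must copy
// instead of viewing).
//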
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline c10::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
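  // For example (illustrative only): viewing a size-[0, 4] tensor as
  // [0, 2, 2] produces strides [4, 2, 1], computed below as if the tensor
  // had been freshly resized, while viewing it as [0, 4] again keeps
  // whatever strides it already had.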
  const Numel numel = c10::multiply_integers(oldshape);
  if (numel == 0 && oldshape.equals(newshape)) {
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (numel == 0) {
    for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
      if (view_d == (int64_t)(newshape.size() - 1)) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
      }
    }
    return newstride;
  }

  int64_t view_d = (int64_t)newshape.size() - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  Numel view_numel = 1;
  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    if ((tensor_d == 0) ||
        (oldshape[tensor_d - 1] != 1 &&
         oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
      while (view_d >= 0 &&
            (view_numel < tensor_numel || newshape[view_d] == 1)) {
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      if (view_numel != tensor_numel) {
        return c10::nullopt;
      }
      if (tensor_d > 0) {
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  if (view_d != -1) {
    return c10::nullopt;
  }
  return newstride;
}

c10::optional<std::vector<int64_t>> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    IntArrayRef newshape) {
  auto toResult = [](const IntArrayRef& a) { return a.vec(); };
  return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

c10::optional<SymDimVector> computeStride(
    c10::SymIntArrayRef oldshape,
    c10::SymIntArrayRef oldstride,
    c10::SymIntArrayRef newshape) {
  auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
  return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
}

c10::optional<DimVector> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    const DimVector& newshape) {
  auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
  return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

} // namespace detail
} // namespace at