1#include <ATen/TensorGeometry.h>
2
3#include <limits>
4#include <cstddef>
5
6namespace at {
7
8// See TensorGeometry.h on why this is useful now that we cache is_contiguous.
9template <typename T>
10bool _geometry_is_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides) {
11 assert(!overflows<std::int64_t>(sizes.size()));
12 auto dim = static_cast<std::int64_t>(sizes.size());
13 T expected_stride = 1;
14 bool contig_if_nonempty = true;
15 for (int64_t i = dim - 1; i >= 0; i--) {
16 if (sizes[i] == 0) {
17 return true;
18 }
19 if (contig_if_nonempty) {
20 if (sizes[i] != 1 && strides[i] != expected_stride) {
21 contig_if_nonempty = false;
22 }
23 expected_stride *= sizes[i];
24 }
25 }
26 return contig_if_nonempty;
27}
28
29bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
30 return _geometry_is_contiguous(sizes, strides);
31}
32
33bool TensorGeometry::is_contiguous() const {
34 if (numel_ == 0) {
35 return true;
36 }
37 return at::_geometry_is_contiguous<c10::SymInt>(sizes_, strides_);
38}
39
40} // namespace at
41