1 | #include <ATen/TensorGeometry.h> |
---|---|
2 | |
3 | #include <limits> |
4 | #include <cstddef> |
5 | |
6 | namespace at { |
7 | |
8 | // See TensorGeometry.h on why this is useful now that we cache is_contiguous. |
9 | template <typename T> |
10 | bool _geometry_is_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides) { |
11 | assert(!overflows<std::int64_t>(sizes.size())); |
12 | auto dim = static_cast<std::int64_t>(sizes.size()); |
13 | T expected_stride = 1; |
14 | bool contig_if_nonempty = true; |
15 | for (int64_t i = dim - 1; i >= 0; i--) { |
16 | if (sizes[i] == 0) { |
17 | return true; |
18 | } |
19 | if (contig_if_nonempty) { |
20 | if (sizes[i] != 1 && strides[i] != expected_stride) { |
21 | contig_if_nonempty = false; |
22 | } |
23 | expected_stride *= sizes[i]; |
24 | } |
25 | } |
26 | return contig_if_nonempty; |
27 | } |
28 | |
29 | bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) { |
30 | return _geometry_is_contiguous(sizes, strides); |
31 | } |
32 | |
33 | bool TensorGeometry::is_contiguous() const { |
34 | if (numel_ == 0) { |
35 | return true; |
36 | } |
37 | return at::_geometry_is_contiguous<c10::SymInt>(sizes_, strides_); |
38 | } |
39 | |
40 | } // namespace at |
41 |