#pragma once

#include <ATen/core/TensorBase.h>
#include <c10/core/WrapDimMinimal.h>

namespace at {

// Returns true if the tensor geometry represented by `sizes` and `strides`
// is contiguous. Although we now cache is_contiguous in the tensor, this is
// still useful because it allows checking whether a particular geometry is
// contiguous without explicitly constructing a tensor, e.g., when you want
// to choose a kernel strategy based on whether a subgeometry is contiguous.
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
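//
// A quick usage sketch (illustrative values, not from the source): a 2x3
// geometry with row-major strides {3, 1} is contiguous, while the same sizes
// with transposed strides {1, 3} are not:
//
//   geometry_is_contiguous({2, 3}, {3, 1}); // true
//   geometry_is_contiguous({2, 3}, {1, 3}); // false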

struct TORCH_API TensorGeometry {
  TensorGeometry() = default;

  explicit TensorGeometry(c10::SymIntArrayRef sizes)
      : sizes_(sizes.vec()),
        strides_(sizes.size()),
        has_symbolic_sizes_strides_(
            !c10::asIntArrayRefSlowOpt(sizes).has_value()) {
    int64_t dim = static_cast<int64_t>(sizes.size());
    c10::SymInt expected_stride = 1;
    // Fill in contiguous (row-major) strides from the innermost dimension
    // outward.
    for (int64_t i = dim - 1; i >= 0; i--) {
      strides_[i] = expected_stride;
      expected_stride *= sizes_[i];
    }
    // After the loop, expected_stride is the product of all sizes.
    numel_ = expected_stride;
  }
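
  // For example (illustrative): sizes {2, 3, 4} produce contiguous strides
  // {12, 4, 1} (the stride of dim i is the product of the sizes of all dims
  // after i) and numel 24.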

  explicit TensorGeometry(const TensorBase& t)
      : sizes_(t.sym_sizes().vec()),
        strides_(t.sym_strides().vec()),
        storage_offset_(t.sym_storage_offset()),
        numel_(t.sym_numel()),
        has_symbolic_sizes_strides_(
            t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

  // Returns true if the geometry is contiguous.
  bool is_contiguous() const;

  int64_t dim() const {
    return static_cast<int64_t>(sizes_.size());
  }

  // The accessors below require concrete (non-symbolic) sizes and strides;
  // use the sym_* variants further down when the geometry may be symbolic.
  int64_t size(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef sizes() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(sizes_);
  }
  int64_t stride(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef strides() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(strides_);
  }
  int64_t storage_offset() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return storage_offset_.as_int_unchecked();
  }
  int64_t numel() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return numel_.as_int_unchecked();
  }
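
  // A minimal usage sketch (illustrative): with concrete sizes the plain
  // accessors return the cached values directly,
  //
  //   TensorGeometry g(c10::SymIntArrayRef({2, 3}));
  //   g.sizes();   // {2, 3}
  //   g.strides(); // {3, 1}
  //
  // whereas a geometry built from a tensor with symbolic shapes must go
  // through the sym_* accessors below.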

  // Symbolic variants of the accessors above; these are safe to call
  // whether or not the sizes and strides are symbolic.
  c10::SymInt sym_size(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_sizes() const {
    return sizes_;
  }
  c10::SymInt sym_stride(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_strides() const {
    return strides_;
  }
  c10::SymInt sym_storage_offset() const {
    return storage_offset_;
  }
  c10::SymInt sym_numel() const {
    return numel_;
  }

  // Returns a copy of this geometry with dimensions dim0 and dim1 swapped.
  TensorGeometry transpose(int64_t dim0, int64_t dim1) {
    TensorGeometry r = *this; // copy
    TORCH_CHECK(
        dim0 >= 0 && dim0 < dim(),
        "transpose: dim0=",
        dim0,
        " out of range (dim=",
        dim(),
        ")");
    TORCH_CHECK(
        dim1 >= 0 && dim1 < dim(),
        "transpose: dim1=",
        dim1,
        " out of range (dim=",
        dim(),
        ")");
    std::swap(r.sizes_[dim0], r.sizes_[dim1]);
    std::swap(r.strides_[dim0], r.strides_[dim1]);
    return r;
  }
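
  // For example (illustrative): transposing a contiguous {2, 3} geometry
  // swaps both sizes and strides, yielding a non-contiguous result:
  //
  //   TensorGeometry g(c10::SymIntArrayRef({2, 3})); // strides {3, 1}
  //   TensorGeometry t = g.transpose(0, 1); // sizes {3, 2}, strides {1, 3}
  //   t.is_contiguous();                    // false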

 private:
  std::vector<c10::SymInt> sizes_;
  std::vector<c10::SymInt> strides_;
  c10::SymInt storage_offset_;
  c10::SymInt numel_;
  bool has_symbolic_sizes_strides_{false};
};

} // namespace at