1 | #pragma once |
2 | |
3 | #include <ATen/core/TensorBase.h> |
4 | #include <c10/core/WrapDimMinimal.h> |
5 | |
6 | namespace at { |
7 | |
// Returns whether the tensor geometry described by `sizes` and `strides` is
// contiguous. Although tensors now cache is_contiguous, this is still useful
// because it allows checking whether a particular geometry is contiguous
// without explicitly constructing a tensor, e.g., when choosing a kernel
// strategy based on whether a subgeometry is contiguous.
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
14 | |
15 | struct TORCH_API TensorGeometry { |
16 | TensorGeometry() = default; |
17 | |
18 | explicit TensorGeometry(c10::SymIntArrayRef sizes) |
19 | : sizes_(sizes.vec()), |
20 | strides_(sizes.size()), |
21 | has_symbolic_sizes_strides_( |
22 | !c10::asIntArrayRefSlowOpt(sizes).has_value()) { |
23 | int64_t dim = sizes.size(); |
24 | c10::SymInt expected_stride = 1; |
25 | for (int64_t i = dim - 1; i >= 0; i--) { |
26 | strides_[i] = expected_stride; |
27 | expected_stride *= sizes_[i]; |
28 | } |
29 | numel_ = expected_stride; |
30 | } |
31 | |
32 | explicit TensorGeometry(const TensorBase& t) |
33 | : sizes_(t.sym_sizes().vec()), |
34 | strides_(t.sym_strides().vec()), |
35 | storage_offset_(t.sym_storage_offset()), |
36 | numel_(t.sym_numel()), |
37 | has_symbolic_sizes_strides_( |
38 | t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} |
39 | |
40 | // true if the tensor is contiguous |
41 | bool is_contiguous() const; |
42 | |
43 | int64_t dim() const { |
44 | return sizes_.size(); |
45 | } |
46 | |
47 | int64_t size(int64_t dim) const { |
48 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
49 | dim = c10::maybe_wrap_dim(dim, this->dim()); |
50 | return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked(); |
51 | } |
52 | c10::IntArrayRef sizes() const { |
53 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
54 | return c10::asIntArrayRefUnchecked(sizes_); |
55 | } |
56 | int64_t stride(int64_t dim) const { |
57 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
58 | dim = c10::maybe_wrap_dim(dim, this->dim()); |
59 | return strides_.at(static_cast<size_t>(dim)).as_int_unchecked(); |
60 | } |
61 | c10::IntArrayRef strides() const { |
62 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
63 | return c10::asIntArrayRefUnchecked(strides_); |
64 | } |
65 | int64_t storage_offset() const { |
66 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
67 | return storage_offset_.as_int_unchecked(); |
68 | } |
69 | int64_t numel() const { |
70 | TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); |
71 | return numel_.as_int_unchecked(); |
72 | } |
73 | |
74 | c10::SymInt sym_size(int64_t dim) const { |
75 | dim = c10::maybe_wrap_dim(dim, this->dim()); |
76 | return sizes_.at(static_cast<size_t>(dim)); |
77 | } |
78 | c10::SymIntArrayRef sym_sizes() const { |
79 | return sizes_; |
80 | } |
81 | c10::SymInt sym_stride(int64_t dim) const { |
82 | dim = c10::maybe_wrap_dim(dim, this->dim()); |
83 | return strides_.at(static_cast<size_t>(dim)); |
84 | } |
85 | c10::SymIntArrayRef sym_strides() const { |
86 | return strides_; |
87 | } |
88 | c10::SymInt sym_storage_offset() const { |
89 | return storage_offset_; |
90 | } |
91 | c10::SymInt sym_numel() const { |
92 | return numel_; |
93 | } |
94 | |
95 | TensorGeometry transpose(int64_t dim0, int64_t dim1) { |
96 | TensorGeometry r = *this; // copy |
97 | TORCH_CHECK( |
98 | dim0 < dim(), |
99 | "transpose: dim0=" , |
100 | dim0, |
101 | " out of range (dim=" , |
102 | dim(), |
103 | ")" ) |
104 | TORCH_CHECK( |
105 | dim1 < dim(), |
106 | "transpose: dim1=" , |
107 | dim1, |
108 | " out of range (dim=" , |
109 | dim(), |
110 | ")" ) |
111 | std::swap(r.sizes_[dim0], r.sizes_[dim1]); |
112 | std::swap(r.strides_[dim0], r.strides_[dim1]); |
113 | return r; |
114 | } |
115 | |
116 | private: |
117 | std::vector<c10::SymInt> sizes_; |
118 | std::vector<c10::SymInt> strides_; |
119 | c10::SymInt storage_offset_; |
120 | c10::SymInt numel_; |
121 | bool has_symbolic_sizes_strides_{false}; |
122 | }; |
123 | |
124 | } // namespace at |
125 | |