1 | #pragma once |
2 | #include <ATen/Utils.h> |
3 | #include <c10/util/ArrayRef.h> |
4 | |
5 | #include <vector> |
6 | |
7 | namespace at { |
8 | /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that |
9 | /// we can easily view it as a multidimensional array. |
10 | /// |
11 | /// Like ArrayRef, this class does not own the underlying data, it is expected |
12 | /// to be used in situations where the data resides in some other buffer. |
13 | /// |
14 | /// This is intended to be trivially copyable, so it should be passed by |
15 | /// value. |
16 | /// |
17 | /// For now, 2D only (so the copies are actually cheap, without having |
18 | /// to write a SmallVector class) and contiguous only (so we can |
19 | /// return non-strided ArrayRef on index). |
20 | /// |
21 | /// P.S. dimension 0 indexes rows, dimension 1 indexes columns |
22 | template <typename T> |
23 | class MatrixRef { |
24 | public: |
25 | typedef size_t size_type; |
26 | |
27 | private: |
28 | /// Underlying ArrayRef |
29 | ArrayRef<T> arr; |
30 | |
31 | /// Stride of dim 0 (outer dimension) |
32 | size_type stride0; |
33 | |
34 | // Stride of dim 1 is assumed to be 1 |
35 | |
36 | public: |
37 | /// Construct an empty Matrixref. |
38 | /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} |
39 | |
40 | /// Construct an MatrixRef from an ArrayRef and outer stride. |
41 | /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0) |
42 | : arr(arr), stride0(stride0) { |
43 | TORCH_CHECK( |
44 | arr.size() % stride0 == 0, |
45 | "MatrixRef: ArrayRef size " , |
46 | arr.size(), |
47 | " not divisible by stride " , |
48 | stride0) |
49 | } |
50 | |
51 | /// @} |
52 | /// @name Simple Operations |
53 | /// @{ |
54 | |
55 | /// empty - Check if the matrix is empty. |
56 | bool empty() const { |
57 | return arr.empty(); |
58 | } |
59 | |
60 | const T* data() const { |
61 | return arr.data(); |
62 | } |
63 | |
64 | /// size - Get size a dimension |
65 | size_t size(size_t dim) const { |
66 | if (dim == 0) { |
67 | return arr.size() / stride0; |
68 | } else if (dim == 1) { |
69 | return stride0; |
70 | } else { |
71 | TORCH_CHECK( |
72 | 0, "MatrixRef: out of bounds dimension " , dim, "; expected 0 or 1" ); |
73 | } |
74 | } |
75 | |
76 | size_t numel() const { |
77 | return arr.size(); |
78 | } |
79 | |
80 | /// equals - Check for element-wise equality. |
81 | bool equals(MatrixRef RHS) const { |
82 | return stride0 == RHS.stride0 && arr.equals(RHS.arr); |
83 | } |
84 | |
85 | /// @} |
86 | /// @name Operator Overloads |
87 | /// @{ |
88 | ArrayRef<T> operator[](size_t Index) const { |
89 | return arr.slice(Index * stride0, stride0); |
90 | } |
91 | |
92 | /// Disallow accidental assignment from a temporary. |
93 | /// |
94 | /// The declaration here is extra complicated so that "arrayRef = {}" |
95 | /// continues to select the move assignment operator. |
96 | template <typename U> |
97 | typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type& |
98 | operator=(U&& Temporary) = delete; |
99 | |
100 | /// Disallow accidental assignment from a temporary. |
101 | /// |
102 | /// The declaration here is extra complicated so that "arrayRef = {}" |
103 | /// continues to select the move assignment operator. |
104 | template <typename U> |
105 | typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type& |
106 | operator=(std::initializer_list<U>) = delete; |
107 | }; |
108 | |
109 | } // end namespace at |
110 | |