1#pragma once
2#include <ATen/Utils.h>
3#include <c10/util/ArrayRef.h>
4
5#include <vector>
6
7namespace at {
/// MatrixRef - Like an ArrayRef, but with an extra recorded stride so that
/// we can easily view it as a multidimensional array.
10///
11/// Like ArrayRef, this class does not own the underlying data, it is expected
12/// to be used in situations where the data resides in some other buffer.
13///
14/// This is intended to be trivially copyable, so it should be passed by
15/// value.
16///
17/// For now, 2D only (so the copies are actually cheap, without having
18/// to write a SmallVector class) and contiguous only (so we can
19/// return non-strided ArrayRef on index).
20///
21/// P.S. dimension 0 indexes rows, dimension 1 indexes columns
22template <typename T>
23class MatrixRef {
24 public:
25 typedef size_t size_type;
26
27 private:
28 /// Underlying ArrayRef
29 ArrayRef<T> arr;
30
31 /// Stride of dim 0 (outer dimension)
32 size_type stride0;
33
34 // Stride of dim 1 is assumed to be 1
35
36 public:
37 /// Construct an empty Matrixref.
38 /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
39
40 /// Construct an MatrixRef from an ArrayRef and outer stride.
41 /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
42 : arr(arr), stride0(stride0) {
43 TORCH_CHECK(
44 arr.size() % stride0 == 0,
45 "MatrixRef: ArrayRef size ",
46 arr.size(),
47 " not divisible by stride ",
48 stride0)
49 }
50
51 /// @}
52 /// @name Simple Operations
53 /// @{
54
55 /// empty - Check if the matrix is empty.
56 bool empty() const {
57 return arr.empty();
58 }
59
60 const T* data() const {
61 return arr.data();
62 }
63
64 /// size - Get size a dimension
65 size_t size(size_t dim) const {
66 if (dim == 0) {
67 return arr.size() / stride0;
68 } else if (dim == 1) {
69 return stride0;
70 } else {
71 TORCH_CHECK(
72 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
73 }
74 }
75
76 size_t numel() const {
77 return arr.size();
78 }
79
80 /// equals - Check for element-wise equality.
81 bool equals(MatrixRef RHS) const {
82 return stride0 == RHS.stride0 && arr.equals(RHS.arr);
83 }
84
85 /// @}
86 /// @name Operator Overloads
87 /// @{
88 ArrayRef<T> operator[](size_t Index) const {
89 return arr.slice(Index * stride0, stride0);
90 }
91
92 /// Disallow accidental assignment from a temporary.
93 ///
94 /// The declaration here is extra complicated so that "arrayRef = {}"
95 /// continues to select the move assignment operator.
96 template <typename U>
97 typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
98 operator=(U&& Temporary) = delete;
99
100 /// Disallow accidental assignment from a temporary.
101 ///
102 /// The declaration here is extra complicated so that "arrayRef = {}"
103 /// continues to select the move assignment operator.
104 template <typename U>
105 typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
106 operator=(std::initializer_list<U>) = delete;
107};
108
109} // end namespace at
110