1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_MATRIX_H_ |
17 | #define RUY_RUY_MATRIX_H_ |
18 | |
19 | #include <cstddef> |
20 | #include <cstdint> // IWYU pragma: keep |
21 | #include <type_traits> |
22 | |
23 | #include "ruy/check_macros.h" |
24 | |
25 | namespace ruy { |
26 | |
// Storage order of a matrix. Throughout ruy, 'col' abbreviates 'column'.
// kColMajor means each column occupies contiguous memory; kRowMajor means
// each row does.
enum class Order : std::uint8_t { kColMajor, kRowMajor };

// Plain description of a matrix's shape (rows/cols) and of how its elements
// are arranged in memory (storage order and stride).
class Layout final {
 public:
  // Getters.
  int rows() const { return rows_; }
  int cols() const { return cols_; }
  int stride() const { return stride_; }
  Order order() const { return order_; }

  // Setters.
  void set_rows(int val) { rows_ = val; }
  void set_cols(int val) { cols_ = val; }
  void set_stride(int val) { stride_ = val; }
  void set_order(Order val) { order_ = val; }

 private:
  int rows_ = 0;
  int cols_ = 0;
  // Offset, in elements, between adjacent matrix elements along the
  // non-contiguous direction (between adjacent columns when column-major,
  // between adjacent rows when row-major).
  int stride_ = 0;
  Order order_ = Order::kColMajor;
};
51 | |
52 | namespace detail { |
53 | |
54 | // Thin wrapper around a pointer with a constness model that works for the |
55 | // purposes of the Matrix class. |
56 | // |
// A typical conundrum of any C++ container class is: which type should
// encode, at compile time, the constancy of the contained data —
// `Matrix<const T>` or `const Matrix<T>`?
// With either approach it is very difficult to achieve perfect
// const-correctness; that can only be done with some combination of an
// inconvenient interface and C++ complexity/abstraction.
63 | // |
64 | // Here we opt for something that's entirely tailored to the needs of the Ruy |
65 | // interface. The only purpose of the Matrix class is to pass matrix data |
66 | // pointers to ruy. There is an asymmetry here: the caller of ruy::Mul only |
67 | // needs to `set` the data; ruy itself only needs to `get` the data. In the |
68 | // caller's code, it's convenient to be able to just deal with `Matrix<T>` |
69 | // without having to sprinkle `const` keywords in the right places, so we want |
70 | // to track whether the data is constant in a way that's decoupled from the |
71 | // constness of `this`, and we never want to have Matrix<const T>. Inside ruy |
72 | // code, as input matrices are passed by const-reference and output matrices are |
73 | // passed by pointer (to non-const), the constness of `this` is telling whether |
74 | // the data is constant. See the `get` and `set` methods below and the comment |
75 | // explaining the core logic that they encapsulate. |
76 | template <typename T> |
77 | class ConstCheckingPtr final { |
78 | public: |
79 | using element_type = T; |
80 | |
81 | // Core accessors. These encapsulate the main logic: |
82 | // - for `set`, the constness of the argument determines whether internal |
83 | // pointer should be tracked as const/mutable. |
84 | // - for `get`, the constness of `this` determines whether the call |
85 | // counts as a const or mutable use of the internal pointer. |
86 | void set(T* ptr) { |
87 | ptr_ = ptr; |
88 | set_mutable(true); |
89 | } |
90 | void set(const T* ptr) { |
91 | ptr_ = ptr; |
92 | set_mutable(false); |
93 | } |
94 | void set(std::nullptr_t) { ptr_ = nullptr; } |
95 | T* get() /* NOT const */ { |
96 | assert_mutable(); |
97 | return const_cast<T*>(ptr_); |
98 | } |
99 | const T* get() const { return ptr_; } |
100 | |
101 | private: |
102 | // There's never a need for Matrix<const T>. |
103 | static_assert(!std::is_const<T>::value, "" ); |
104 | const T* ptr_ = nullptr; |
105 | #ifndef NDEBUG |
106 | bool is_mutable_ = true; |
107 | void set_mutable(bool val) { is_mutable_ = val; } |
108 | void assert_mutable() { RUY_DCHECK(is_mutable_); } |
109 | #else |
110 | void set_mutable(bool) {} |
111 | void assert_mutable() {} |
112 | #endif |
113 | }; |
114 | |
115 | } // namespace detail |
116 | |
// Policy governing whether ruy may cache a prepacked copy of a matrix's
// data (see the comment on Matrix::cache_policy_ below). Enumerators are
// ordered from never caching to always caching; the middle values
// presumably condition caching on the expected speedup — confirm exact
// thresholds against the cache implementation.
enum class CachePolicy : std::uint8_t {
  kNeverCache,
  kCacheIfLargeSpeedup,
  kCacheIfSignificantSpeedup,
  kAlwaysCache,
};
123 | |
124 | // A Matrix merely wraps existing data as a matrix. It doesn't own any buffer. |
125 | // The purpose of Matrix is only to be used in ruy's interface -- it's just |
126 | // a structured way for the user to pass to ruy::Mul the matrix data pointers |
127 | // together with other matrix parameters. |
128 | // Scalar may be any floating-point or integral type. When integral, it may be |
129 | // signed or unsigned. It's never const: use Matrix<T> for both input and output |
130 | // matrices, never use Matrix<const T>. |
131 | // See the comments on detail::ConstCheckingPointer. |
132 | template <typename Scalar> |
133 | class Matrix final { |
134 | public: |
135 | static_assert(!std::is_const<Scalar>::value, |
136 | "Never use Matrix<const T>. Just use Matrix<T>. Constness of " |
137 | "the data is guarded by debug-only runtime assertions. See " |
138 | "detail::ConstCheckingPtr." ); |
139 | |
140 | Scalar* data() { return data_.get(); } |
141 | const Scalar* data() const { return data_.get(); } |
142 | void set_data(Scalar* ptr) { data_.set(ptr); } |
143 | void set_data(const Scalar* ptr) { data_.set(ptr); } |
144 | void set_data(std::nullptr_t) { data_.set(nullptr); } |
145 | const Layout& layout() const { return layout_; } |
146 | Layout* mutable_layout() { return &layout_; } |
147 | Scalar zero_point() const { return zero_point_; } |
148 | void set_zero_point(Scalar value) { zero_point_ = value; } |
149 | CachePolicy cache_policy() const { return cache_policy_; } |
150 | void set_cache_policy(CachePolicy value) { cache_policy_ = value; } |
151 | |
152 | private: |
153 | // The underlying buffer wrapped by this matrix. |
154 | detail::ConstCheckingPtr<Scalar> data_; |
155 | // The shape and data layout of this matrix. |
156 | Layout layout_; |
157 | // The zero_point, i.e. which Scalar value is to be interpreted as zero. |
158 | // When Scalar is floating-point, this must be 0. |
159 | Scalar zero_point_ = 0; |
160 | // When the data pointed to by this matrix is constant data, so that it is |
161 | // valid to assume that equality of pointers implies equality of data, |
162 | // a CachePolicy may be used instead of the default kNeverCache, |
163 | // which will enable ruy to take advantage of this constancy of the data to |
164 | // cache the packing work, which can be a large speedup in matrix*vector |
165 | // and other narrow shapes. |
166 | CachePolicy cache_policy_ = CachePolicy::kNeverCache; |
167 | }; |
168 | |
169 | inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { |
170 | layout->set_rows(rows); |
171 | layout->set_cols(cols); |
172 | layout->set_order(order); |
173 | layout->set_stride(order == Order::kColMajor ? rows : cols); |
174 | } |
175 | |
176 | template <typename StreamType, typename Scalar> |
177 | StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) { |
178 | for (int row = 0; row < mat.layout().rows(); row++) { |
179 | for (int col = 0; col < mat.layout().cols(); col++) { |
180 | stream << static_cast<double>(Element(mat, row, col)) << " " ; |
181 | } |
182 | stream << "\n" ; |
183 | } |
184 | return stream; |
185 | } |
186 | |
187 | // TODO(b/130417400) add a unit test |
188 | inline int Offset(const Layout& layout, int row, int col) { |
189 | // TODO(benoitjacob) - should check this but this make the _slow tests take |
190 | // 5x longer. Find a mitigation like in Eigen with an 'internal' variant |
191 | // bypassing the check? |
192 | // RUY_DCHECK_GE(row, 0); |
193 | // RUY_DCHECK_GE(col, 0); |
194 | // RUY_DCHECK_LT(row, layout.rows()); |
195 | // RUY_DCHECK_LT(col, layout.cols()); |
196 | int row_stride = layout.order() == Order::kColMajor ? 1 : layout.stride(); |
197 | int col_stride = layout.order() == Order::kRowMajor ? 1 : layout.stride(); |
198 | return row * row_stride + col * col_stride; |
199 | } |
200 | |
201 | template <typename Scalar> |
202 | const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) { |
203 | return mat.data() + Offset(mat.layout(), row, col); |
204 | } |
205 | |
206 | template <typename Scalar> |
207 | Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) { |
208 | return mat->data() + Offset(mat->layout(), row, col); |
209 | } |
210 | |
211 | template <typename Scalar> |
212 | Scalar Element(const Matrix<Scalar>& mat, int row, int col) { |
213 | return *ElementPtr(mat, row, col); |
214 | } |
215 | |
216 | } // namespace ruy |
217 | |
218 | #endif // RUY_RUY_MATRIX_H_ |
219 | |