1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_MATRIX_H_
17#define RUY_RUY_MATRIX_H_
18
19#include <cstddef>
20#include <cstdint> // IWYU pragma: keep
21#include <type_traits>
22
23#include "ruy/check_macros.h"
24
25namespace ruy {
26
// Storage order of a matrix layout. Here and elsewhere, 'col' is short for
// 'column'.
//   kColMajor: each column is contiguous in memory.
//   kRowMajor: each row is contiguous in memory.
enum class Order : std::uint8_t {
  kColMajor = 0,
  kRowMajor = 1,
};
30
31// Describes the shape and storage layout of a matrix.
32class Layout final {
33 public:
34 int rows() const { return rows_; }
35 void set_rows(int val) { rows_ = val; }
36 int cols() const { return cols_; }
37 void set_cols(int val) { cols_ = val; }
38 int stride() const { return stride_; }
39 void set_stride(int val) { stride_ = val; }
40 Order order() const { return order_; }
41 void set_order(Order val) { order_ = val; }
42
43 private:
44 int rows_ = 0;
45 int cols_ = 0;
46 // Stride is the offset between two adjacent matrix elements
47 // in the non-contiguous direction.
48 int stride_ = 0;
49 Order order_ = Order::kColMajor;
50};
51
52namespace detail {
53
54// Thin wrapper around a pointer with a constness model that works for the
55// purposes of the Matrix class.
56//
57// A typical conundrum of any C++ container class is what type constness should
58// encode at compile time constancy of the contained data?
59// `Matrix<const T>` or `const Matrix<T>`?
60// With either approach it is very difficult to achieve perfect
61// const-correctness that that can only be done with some combination of
62// inconvenient interface and c++ complexity/abstraction.
63//
64// Here we opt for something that's entirely tailored to the needs of the Ruy
65// interface. The only purpose of the Matrix class is to pass matrix data
66// pointers to ruy. There is an asymmetry here: the caller of ruy::Mul only
67// needs to `set` the data; ruy itself only needs to `get` the data. In the
68// caller's code, it's convenient to be able to just deal with `Matrix<T>`
69// without having to sprinkle `const` keywords in the right places, so we want
70// to track whether the data is constant in a way that's decoupled from the
71// constness of `this`, and we never want to have Matrix<const T>. Inside ruy
72// code, as input matrices are passed by const-reference and output matrices are
73// passed by pointer (to non-const), the constness of `this` is telling whether
74// the data is constant. See the `get` and `set` methods below and the comment
75// explaining the core logic that they encapsulate.
76template <typename T>
77class ConstCheckingPtr final {
78 public:
79 using element_type = T;
80
81 // Core accessors. These encapsulate the main logic:
82 // - for `set`, the constness of the argument determines whether internal
83 // pointer should be tracked as const/mutable.
84 // - for `get`, the constness of `this` determines whether the call
85 // counts as a const or mutable use of the internal pointer.
86 void set(T* ptr) {
87 ptr_ = ptr;
88 set_mutable(true);
89 }
90 void set(const T* ptr) {
91 ptr_ = ptr;
92 set_mutable(false);
93 }
94 void set(std::nullptr_t) { ptr_ = nullptr; }
95 T* get() /* NOT const */ {
96 assert_mutable();
97 return const_cast<T*>(ptr_);
98 }
99 const T* get() const { return ptr_; }
100
101 private:
102 // There's never a need for Matrix<const T>.
103 static_assert(!std::is_const<T>::value, "");
104 const T* ptr_ = nullptr;
105#ifndef NDEBUG
106 bool is_mutable_ = true;
107 void set_mutable(bool val) { is_mutable_ = val; }
108 void assert_mutable() { RUY_DCHECK(is_mutable_); }
109#else
110 void set_mutable(bool) {}
111 void assert_mutable() {}
112#endif
113};
114
115} // namespace detail
116
// Whether ruy may cache work derived from a matrix's data (such as packing)
// across calls. Caching is only valid when the data is constant, so that
// equality of pointers implies equality of data. Matrix defaults to
// kNeverCache. Judging by their names, the two intermediate policies cache
// only when the expected speedup is large / significant. The thresholds are
// not defined in this header (TODO confirm where they are decided).
enum class CachePolicy : std::uint8_t {
  kNeverCache = 0,
  kCacheIfLargeSpeedup = 1,
  kCacheIfSignificantSpeedup = 2,
  kAlwaysCache = 3,
};
123
124// A Matrix merely wraps existing data as a matrix. It doesn't own any buffer.
125// The purpose of Matrix is only to be used in ruy's interface -- it's just
126// a structured way for the user to pass to ruy::Mul the matrix data pointers
127// together with other matrix parameters.
128// Scalar may be any floating-point or integral type. When integral, it may be
129// signed or unsigned. It's never const: use Matrix<T> for both input and output
130// matrices, never use Matrix<const T>.
131// See the comments on detail::ConstCheckingPointer.
132template <typename Scalar>
133class Matrix final {
134 public:
135 static_assert(!std::is_const<Scalar>::value,
136 "Never use Matrix<const T>. Just use Matrix<T>. Constness of "
137 "the data is guarded by debug-only runtime assertions. See "
138 "detail::ConstCheckingPtr.");
139
140 Scalar* data() { return data_.get(); }
141 const Scalar* data() const { return data_.get(); }
142 void set_data(Scalar* ptr) { data_.set(ptr); }
143 void set_data(const Scalar* ptr) { data_.set(ptr); }
144 void set_data(std::nullptr_t) { data_.set(nullptr); }
145 const Layout& layout() const { return layout_; }
146 Layout* mutable_layout() { return &layout_; }
147 Scalar zero_point() const { return zero_point_; }
148 void set_zero_point(Scalar value) { zero_point_ = value; }
149 CachePolicy cache_policy() const { return cache_policy_; }
150 void set_cache_policy(CachePolicy value) { cache_policy_ = value; }
151
152 private:
153 // The underlying buffer wrapped by this matrix.
154 detail::ConstCheckingPtr<Scalar> data_;
155 // The shape and data layout of this matrix.
156 Layout layout_;
157 // The zero_point, i.e. which Scalar value is to be interpreted as zero.
158 // When Scalar is floating-point, this must be 0.
159 Scalar zero_point_ = 0;
160 // When the data pointed to by this matrix is constant data, so that it is
161 // valid to assume that equality of pointers implies equality of data,
162 // a CachePolicy may be used instead of the default kNeverCache,
163 // which will enable ruy to take advantage of this constancy of the data to
164 // cache the packing work, which can be a large speedup in matrix*vector
165 // and other narrow shapes.
166 CachePolicy cache_policy_ = CachePolicy::kNeverCache;
167};
168
169inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
170 layout->set_rows(rows);
171 layout->set_cols(cols);
172 layout->set_order(order);
173 layout->set_stride(order == Order::kColMajor ? rows : cols);
174}
175
176template <typename StreamType, typename Scalar>
177StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
178 for (int row = 0; row < mat.layout().rows(); row++) {
179 for (int col = 0; col < mat.layout().cols(); col++) {
180 stream << static_cast<double>(Element(mat, row, col)) << " ";
181 }
182 stream << "\n";
183 }
184 return stream;
185}
186
187// TODO(b/130417400) add a unit test
188inline int Offset(const Layout& layout, int row, int col) {
189 // TODO(benoitjacob) - should check this but this make the _slow tests take
190 // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
191 // bypassing the check?
192 // RUY_DCHECK_GE(row, 0);
193 // RUY_DCHECK_GE(col, 0);
194 // RUY_DCHECK_LT(row, layout.rows());
195 // RUY_DCHECK_LT(col, layout.cols());
196 int row_stride = layout.order() == Order::kColMajor ? 1 : layout.stride();
197 int col_stride = layout.order() == Order::kRowMajor ? 1 : layout.stride();
198 return row * row_stride + col * col_stride;
199}
200
201template <typename Scalar>
202const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
203 return mat.data() + Offset(mat.layout(), row, col);
204}
205
206template <typename Scalar>
207Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
208 return mat->data() + Offset(mat->layout(), row, col);
209}
210
211template <typename Scalar>
212Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
213 return *ElementPtr(mat, row, col);
214}
215
216} // namespace ruy
217
218#endif // RUY_RUY_MATRIX_H_
219