/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/size_util.h"

namespace ruy {

// Enumeration to designate which dimension is the 'channels', for MulParams
// features that are 'per-channel', namely the bias-vector and the quantized
// multiplier.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination matrix'
  kCol
};
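
// For example, with channel_dimension() == kRow (the default), a bias buffer
// set via MulParams::set_bias() holds one value per destination row, each
// added to every entry of that row; with kCol, it holds one value per
// destination column.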

namespace detail {
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}

// MulParams describes everything about a matrix multiplication that
// isn't encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance,
// the choice of accumulator type, AccumScalar). Some of that information is
// encoded as runtime values (for instance, the optional bias vector).
//
// Template parameters:
//   AccumScalar: Accumulator type. The type of accumulators used to compute
//     the dot-products before being ultimately cast to the destination type.
//   DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
//   * If DstScalar is floating-point then AccumScalar must also be.
//   * If DstScalar is integral then AccumScalar must be std::int32_t.
//     Moreover, in that integral case, there is a mode switch:
//     - If DstScalar is std::int32_t then the multiplier_* fields are all
//       disabled, and ruy::Mul will just return raw (unscaled) accumulators.
//     - If DstScalar is not std::int32_t then the multiplier_* fields are
//       enabled, and ruy::Mul will use them to scale internal std::int32_t
//       accumulators before casting them to the DstScalar type. The default
//       values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The explanation of this warning is that, as of early 2021, we still don't
// know whether it is advisable to let this code as-is have normative value,
// or whether that would become advisable after some specific final change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
// know that this particular reference code is inherently better than other
// forms that could perform better on these architectures --- in fact, the
// alternative that was proposed in [2] as better performing on ARM Cortex-M
// is also inherently more accurate thanks to rounding only once, but it would
// perform worse on both ARM NEON and x86.
//
// In fact, if we look at other hardware architectures beyond current Ruy
// targets, namely "hardware accelerators", it becomes clear that there is no
// hope for any form of this to be efficiently implementable simultaneously on
// all current relevant hardware. Indeed, some accelerators prefer to perform
// the multiplication in IEEE float32, others in IEEE float16, others in
// bfloat16, others in 16-bit fixed-point...
//
// See:
// [1] https://github.com/google/ruy/pull/227
// [2] https://github.com/tensorflow/tensorflow/issues/25087
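//
// Usage sketch (illustrative only; `context`, `lhs`, `rhs`, `dst` and the
// `bias_data`/multiplier values are assumed to have been set up as described
// in the ruy::Mul documentation, they are not defined in this header):
//
//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
//   mul_params.set_bias(bias_data);                    // optional bias
//   mul_params.set_multiplier_fixedpoint(fixedpoint);  // mantissa part
//   mul_params.set_multiplier_exponent(exponent);      // exponent part
//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);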
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
 public:
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;

  // The bias vector data, if not null.
  const AccumScalar* bias() const { return storage_.bias; }
  void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
  // Only for non-floating-point cases. The fixed-point part (i.e. the
  // mantissa) of the multiplier by which accumulators are multiplied before
  // being cast to the destination type.
  AccumScalar multiplier_fixedpoint() const {
    return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
  }
  void set_multiplier_fixedpoint(const AccumScalar value) {
    set_perchannel(false);
    storage_.multiplier_fixedpoint = value;
  }
  // Only for non-floating-point cases. The exponent part of the
  // aforementioned multiplier.
  int multiplier_exponent() const {
    return storage_.perchannel ? 0 : storage_.multiplier_exponent;
  }
  void set_multiplier_exponent(const int value) {
    set_perchannel(false);
    storage_.multiplier_exponent = value;
  }
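  // (Illustrative note, an assumption about the reference rather than a
  // normative statement: in the gemmlowp-style convention used by
  // ruy::ApplyMultiplier, the effective multiplier applied to each
  // accumulator is approximately
  //   multiplier_fixedpoint * 2^(multiplier_exponent - 31),
  // i.e. multiplier_fixedpoint acts as a Q0.31 fixed-point mantissa,
  // typically normalized into [2^30, 2^31) to represent [0.5, 1), and
  // multiplier_exponent is a power-of-two scaling, positive meaning a left
  // shift.)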
  // Per-channel variant of multiplier_fixedpoint. Setting this switches
  // to per-channel mode, where `multiplier_fixedpoint` and
  // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
  // and `multiplier_exponent_perchannel` are used instead.
  //
  // This must point to a buffer of as many values as there are channels,
  // i.e. as many values as there are rows (respectively columns) of the
  // destination matrix if `channel_dimension()` is kRow (respectively kCol).
  // Each channel of the destination matrix will use the corresponding buffer
  // element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                               : nullptr;
  }
  void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
    set_perchannel(true);
    storage_.multiplier_fixedpoint_perchannel = ptr;
  }
  // Per-channel variant of multiplier_exponent. Same comments as for
  // multiplier_fixedpoint_perchannel.
  const int* multiplier_exponent_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_exponent_perchannel
                               : nullptr;
  }
  void set_multiplier_exponent_perchannel(const int* ptr) {
    set_perchannel(true);
    storage_.multiplier_exponent_perchannel = ptr;
  }
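  // (Per-channel usage sketch, illustrative only: with channel_dimension()
  // == kRow and `n` destination rows, `fixedpoint_buf` and `exponent_buf`
  // are hypothetical caller-owned buffers of `n` values each, and both
  // variants are normally set together:
  //   mul_params.set_multiplier_fixedpoint_perchannel(fixedpoint_buf);
  //   mul_params.set_multiplier_exponent_perchannel(exponent_buf);
  // Once in per-channel mode, the scalar multiplier_fixedpoint() and
  // multiplier_exponent() getters return 0, as seen above.)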
  // min clamp bound of destination values.
  DstScalar clamp_min() const { return storage_.clamp_min; }
  void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
  // max clamp bound of destination values.
  DstScalar clamp_max() const { return storage_.clamp_max; }
  void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
  // Designates which dimension is the 'channels', for per-channel features
  // such as bias-addition and per-channel quantization multipliers.
  ChannelDimension channel_dimension() const {
    return storage_.channel_dimension;
  }
  void set_channel_dimension(ChannelDimension value) {
    storage_.channel_dimension = value;
  }
  // Specifies the upward rounding of the allocated capacity of per-channel
  // buffers such as bias vectors and per-channel quantization multipliers.
  // The unit is matrix entries, not bytes.
  //
  // This value must be a power of two.
  //
  // The default value, 1, means no upward rounding, meaning that the buffers
  // are not required to have a capacity greater than the size of the
  // corresponding matrix dimension, i.e. the number of rows (respectively
  // columns) of the destination matrix if `channel_dimension()` is kRow
  // (respectively kCol).
  //
  // Higher values allow the implementation to assume that it is OK to access
  // these buffers a little past this boundary, which is useful in SIMD
  // optimized kernels. In practice, when this value is lower than what the
  // kernel requires, ruy has to internally reallocate and copy per-channel
  // buffers. When this value is high enough, this reallocation and copy is
  // avoided.
  //
  // When a value greater than 1 is specified, the tail region of the buffer
  // (past the end of the values actually corresponding to channels) is
  // required to be zero-initialized.
  //
  // As of 2020, values as high as 16 may be useful on some CPU architectures
  // (corresponding to the widest kernels used on any CPU architecture).
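  //
  // Example (illustrative numbers, not a requirement stated by this header):
  // with a destination matrix of 100 rows, channel_dimension() == kRow and
  // this value set to 16, per-channel buffers must have a capacity of at
  // least round_up_pot(100, 16) = 112 entries (round_up_pot is from
  // "ruy/size_util.h"), with entries [100, 112) zero-initialized; in
  // exchange, ruy can use the buffers in place instead of reallocating and
  // copying them.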
  int perchannel_buffers_capacity_rounding() const {
    return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
  }
  void set_perchannel_buffers_capacity_rounding(int value) {
    // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
    storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
  }

 private:
  detail::MulParamsStorage<AccumScalar, DstScalar> storage_;

  void set_perchannel(bool perchannel) {
    if (storage_.perchannel == perchannel) {
      return;
    }
    if (perchannel) {
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
      RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
    } else {
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
      RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
    }
    storage_.perchannel = perchannel;
  }
};
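
// Usage sketch for the raw-accumulator mode (illustrative only; `lhs`, `rhs`,
// `context`, `dst` and `bias_data` as in the ruy::Mul documentation). With
// DstScalar = std::int32_t, the multiplier_* and clamp_* storage members are
// static constexpr (see MulParamsStorage below), so the corresponding setters
// would fail to compile if called; ruy::Mul returns raw (unscaled, unclamped)
// accumulators:
//
//   ruy::MulParams<std::int32_t, std::int32_t> mul_params;
//   mul_params.set_bias(bias_data);  // optional bias remains available
//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);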

namespace detail {

// Floating-point case.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr bool perchannel = false;
};
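
// (For instance, the MulParams accessors above can read members such as
// `storage_.multiplier_fixedpoint` and `storage_.perchannel` unconditionally:
// in this floating-point specialization they resolve to the static constexpr
// values, so no separate code path is needed. Illustrative note, not code
// from ruy.)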

// Specialization for the integer-quantized type, with down-quantization of
// int32 accumulators to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  // union { // This used to be a union, temporarily flattened to debug a crash
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Let the default multiplier be effectively a multiplication by 1, so that
  // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
  // 1 is not exactly representable in fixedpoint with 0 integer bits, but
  // using the highest representable value is a sufficiently good
  // approximation: since this specialization of MulParams is for the case
  // where DstScalar is at least 2x narrower than AccumScalar, the values
  // for which there would be a difference will get saturated anyway.
  AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
  // };
  // union { // This used to be a union, temporarily flattened to debug a crash
  const int* multiplier_exponent_perchannel = nullptr;
  // See the above comment about the default value of multiplier_fixedpoint.
  int multiplier_exponent = 0;
  // };
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  bool perchannel = false;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};

// Specialization used in the integer case when outputting raw int32
// accumulators, without down-quantization to a narrower destination scalar
// type. In this case, the feature of clamping destination values is not
// available.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  const AccumScalar* bias = nullptr;
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr DstScalar clamp_min =
      std::numeric_limits<DstScalar>::lowest();
  static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  static constexpr bool perchannel = false;
};

}  // namespace detail

}  // namespace ruy

#endif  // RUY_RUY_MUL_PARAMS_H_