/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_MUL_PARAMS_H_
#define RUY_RUY_MUL_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/size_util.h"

namespace ruy {

// Enumeration to designate which dimension is the 'channels', for MulParams
// features that are 'per-channel', namely the bias-vector and the quantized
// multiplier.
enum class ChannelDimension : std::int8_t {
  // kRow means that 'per-channel' means 'per row of the destination matrix'
  kRow,
  // kCol means that 'per-channel' means 'per column of the destination matrix'
  kCol
};

namespace detail {
template <typename tAccumScalar, typename tDstScalar>
struct MulParamsStorage;
}  // namespace detail

// MulParams describes everything about a matrix multiplication that isn't
// already encoded in the LHS, RHS and destination matrices. Some of that
// information is encoded as compile-time constants and types (for instance,
// the choice of accumulator type, AccumScalar). Some of that information is
// encoded as runtime values (for instance, the optional bias vector).
//
// Template parameters:
// AccumScalar: Accumulator type. The type of accumulators used to compute the
//   dot-products before they are ultimately cast to the destination type.
// DstScalar: The destination scalar type.
//
// Constraints on these template parameters (see also the ruy::Mul comment):
// * If DstScalar is floating-point then AccumScalar must also be.
// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
//   in that integral case, there is a mode switch:
//   - If DstScalar is std::int32_t then the multiplier_* fields are all
//     disabled, and ruy::Mul will just return raw (unscaled) accumulators.
//   - If DstScalar is not std::int32_t then the multiplier_* fields are
//     enabled, and ruy::Mul will use them to scale internal std::int32_t
//     accumulators before casting them to the DstScalar type. The default
//     values are such that the effective multiplier is 1 (no scaling).
//
// For the latter case (DstScalar integral and narrower than std::int32_t),
// reference code can be found in the implementation of ruy::ApplyMultiplier.
// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// The explanation of this warning is that, as of early 2021, we still don't
// know whether it is advisable to give this code, as-is, normative value, or
// whether that would only become advisable after some specific final change.
//
// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
// bit-exactly to this reference, but we also know that x86 could be faster if
// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
// know that this particular reference code is inherently better than other
// forms that could perform better on these architectures --- in fact, the
// alternative proposed in [2] as better-performing on ARM Cortex-M is also
// inherently more accurate thanks to rounding only once, but it would perform
// worse on both ARM NEON and x86.
//
// In fact, if we look at other hardware architectures beyond current Ruy
// targets, namely "hardware accelerators", it becomes clear that there is no
// hope for any form of this to be efficiently implementable simultaneously on
// all current relevant hardware. Indeed, some accelerators prefer to perform
// the multiplication in IEEE float32, others in IEEE float16, others in
// bfloat16, others in 16-bit fixed-point...
//
// See:
// [1] https://github.com/google/ruy/pull/227
// [2] https://github.com/tensorflow/tensorflow/issues/25087
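//
// Example (informative only, not normative): a quantized matmul producing
// std::int8_t destination values from std::int32_t accumulators via a single
// scalar multiplier. The buffers and multiplier values below are
// hypothetical, and the ruy::Mul call is only sketched; see ruy.h for the
// exact signature.
//
//   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
//   mul_params.set_bias(bias_data);  // one std::int32_t entry per dst row
//   mul_params.set_multiplier_fixedpoint(multiplier_fixedpoint);
//   mul_params.set_multiplier_exponent(multiplier_exponent);
//   mul_params.set_clamp_min(-128);
//   mul_params.set_clamp_max(127);
//   ruy::Mul(lhs, rhs, mul_params, &context, &dst);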
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
 public:
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;

  // The bias vector data, if not null.
  const AccumScalar* bias() const { return storage_.bias; }
  void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
  // Only for non-floating-point cases. The fixed-point part (i.e. the
  // mantissa) of the multiplier by which accumulators are multiplied before
  // being cast to the destination type.
  AccumScalar multiplier_fixedpoint() const {
    return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
  }
  void set_multiplier_fixedpoint(const AccumScalar value) {
    set_perchannel(false);
    storage_.multiplier_fixedpoint = value;
  }
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent() const {
    return storage_.perchannel ? 0 : storage_.multiplier_exponent;
  }
  void set_multiplier_exponent(const int value) {
    set_perchannel(false);
    storage_.multiplier_exponent = value;
  }
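  // Informative sketch (a hypothetical helper, not part of ruy): one common
  // way, in the style of TFLite/gemmlowp quantization utilities, to decompose
  // a positive real multiplier into the two fields above. This assumes the
  // usual convention, consistent with the defaults documented in the class
  // comment, that the effective multiplier is
  // multiplier_fixedpoint * 2^(multiplier_exponent - 31). Uses std::frexp and
  // std::round from <cmath>.
  //
  //   void DecomposeMultiplier(double real_multiplier,  // assumed > 0
  //                            std::int32_t* fixedpoint, int* exponent) {
  //     // std::frexp returns a mantissa in [0.5, 1) and the matching
  //     // power-of-two exponent.
  //     const double mantissa = std::frexp(real_multiplier, exponent);
  //     std::int64_t q =
  //         static_cast<std::int64_t>(std::round(mantissa * (1ll << 31)));
  //     if (q == (1ll << 31)) {  // rounding pushed the mantissa up to 1.0
  //       q /= 2;
  //       ++*exponent;
  //     }
  //     *fixedpoint = static_cast<std::int32_t>(q);
  //   }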
  // Per-channel variant of multiplier_fixedpoint. Setting this switches
  // to per-channel mode, where `multiplier_fixedpoint` and
  // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
  // and `multiplier_exponent_perchannel` are used instead.
  //
  // This must point to a buffer of as many values as there are channels, that
  // is, as many values as there are rows (respectively columns) in the
  // destination matrix when `channel_dimension()` is kRow (respectively kCol).
  // Each channel of the destination matrix will use the corresponding buffer
  // element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
                               : nullptr;
  }
  void set_multiplier_fixedpoint_perchannel(const AccumScalar* ptr) {
    set_perchannel(true);
    storage_.multiplier_fixedpoint_perchannel = ptr;
  }
  // Per-channel variant of multiplier_exponent. Same comments as for
  // multiplier_fixedpoint_perchannel.
  const int* multiplier_exponent_perchannel() const {
    return storage_.perchannel ? storage_.multiplier_exponent_perchannel
                               : nullptr;
  }
  void set_multiplier_exponent_perchannel(const int* ptr) {
    set_perchannel(true);
    storage_.multiplier_exponent_perchannel = ptr;
  }
  // Min clamp bound of destination values.
  DstScalar clamp_min() const { return storage_.clamp_min; }
  void set_clamp_min(const DstScalar value) { storage_.clamp_min = value; }
  // Max clamp bound of destination values.
  DstScalar clamp_max() const { return storage_.clamp_max; }
  void set_clamp_max(const DstScalar value) { storage_.clamp_max = value; }
  // Designates which dimension is the 'channels', for per-channel features
  // such as bias-addition and per-channel quantization multipliers.
  ChannelDimension channel_dimension() const {
    return storage_.channel_dimension;
  }
  void set_channel_dimension(ChannelDimension value) {
    storage_.channel_dimension = value;
  }
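  // For example, with a 100x300 destination matrix, channel_dimension() ==
  // kRow (the default) means that per-channel buffers such as the bias vector
  // hold 100 entries (one per row), while kCol means they hold 300 entries
  // (one per column).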
  // Specifies the upward rounding of the allocated capacity of per-channel
  // buffers such as bias vectors and per-channel quantization multipliers.
  // The unit is matrix entries, not bytes.
  //
  // This value must be a power of two.
  //
  // The default value, 1, means no upward rounding, meaning that the buffers
  // are not required to have a capacity greater than the size of the
  // corresponding matrix dimension, i.e. the number of rows (respectively
  // columns) of the destination matrix if `channel_dimension()` is kRow
  // (respectively kCol).
  //
  // Higher values allow the implementation to assume that it is OK to access
  // these buffers a little past this boundary, which is useful in SIMD
  // optimized kernels. In practice, when this value is lower than what the
  // kernel requires, ruy has to internally reallocate and copy per-channel
  // buffers. When this value is high enough, this reallocation and copy is
  // avoided.
  //
  // When a value greater than 1 is specified, the tail region of the buffer
  // (past the end of the values actually corresponding to channels) is
  // required to be zero-initialized.
  //
  // As of 2020, values as high as 16 may be useful on some CPU architectures
  // (corresponding to the widest kernels used on any CPU architecture).
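  //
  // Example: with `channel_dimension()` being kRow, a destination matrix of
  // 10 rows, and set_perchannel_buffers_capacity_rounding(16), the bias and
  // per-channel multiplier buffers must have a capacity of at least 16
  // entries, with entries 10..15 zero-initialized.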
  int perchannel_buffers_capacity_rounding() const {
    return 1 << storage_.perchannel_buffers_capacity_rounding_log2;
  }
  void set_perchannel_buffers_capacity_rounding(int value) {
    // Note: pot_log2 asserts (debug-only) that its argument is a power-of-two.
    storage_.perchannel_buffers_capacity_rounding_log2 = pot_log2(value);
  }

 private:
  detail::MulParamsStorage<AccumScalar, DstScalar> storage_;

  void set_perchannel(bool perchannel) {
    if (storage_.perchannel == perchannel) {
      return;
    }
    if (perchannel) {
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
      RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
    } else {
      RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
      RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
    }
    storage_.perchannel = perchannel;
  }
};

namespace detail {

// Floating-point case.
template <typename AccumScalar, typename DstScalar>
struct MulParamsStorage final {
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) <= sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  DstScalar clamp_min = -std::numeric_limits<DstScalar>::infinity();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::infinity();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr bool perchannel = false;
};

// Specialization for the integer-quantized type, with down-quantization of
// int32 accumulators to a narrower destination scalar type.
template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
  using AccumScalar = std::int32_t;
  static_assert(std::is_integral<DstScalar>::value, "");
  static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");

  const AccumScalar* bias = nullptr;
  // union {  // This used to be a union, temporarily flattened to debug a crash
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Let the default multiplier be effectively a multiplication by 1, so that
  // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
  // 1 is not exactly representable in fixedpoint with 0 integer bits, but
  // using the highest representable value is a sufficiently good
  // approximation: since this specialization of MulParams is for the case
  // where DstScalar is at least 2x narrower than AccumScalar, the values
  // for which there would be a difference will get saturated anyway.
  AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
  // };
  // union {  // This used to be a union, temporarily flattened to debug a crash
  const int* multiplier_exponent_perchannel = nullptr;
  // See the above comment about the default value of multiplier_fixedpoint.
  int multiplier_exponent = 0;
  // };
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  bool perchannel = false;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;
};

// Specialization used in the integer case when outputting raw int32
// accumulators, without down-quantization to a narrower destination scalar
// type. In this case, the feature of clamping destination values is not
// available.
template <>
struct MulParamsStorage<std::int32_t, std::int32_t> final {
  using AccumScalar = std::int32_t;
  using DstScalar = std::int32_t;

  const AccumScalar* bias = nullptr;
  ChannelDimension channel_dimension = ChannelDimension::kRow;
  std::int8_t perchannel_buffers_capacity_rounding_log2 = 0;

  // Data members that are disabled in this case are left as `static constexpr`
  // so that one can write some generic code.
  static constexpr const AccumScalar* multiplier_fixedpoint_perchannel =
      nullptr;
  static constexpr const int* multiplier_exponent_perchannel = nullptr;
  static constexpr AccumScalar multiplier_fixedpoint = 0;
  static constexpr int multiplier_exponent = 0;
  static constexpr DstScalar clamp_min =
      std::numeric_limits<DstScalar>::lowest();
  static constexpr DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
  static constexpr bool perchannel = false;
};

}  // namespace detail

}  // namespace ruy

#endif  // RUY_RUY_MUL_PARAMS_H_