1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_ |
17 | #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_ |
18 | |
19 | #include <cstdint> |
20 | #include <limits> |
21 | #include <type_traits> |
22 | |
23 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
24 | |
25 | namespace tflite { |
26 | |
27 | namespace cpu_backend_gemm { |
28 | |
// Matrix storage order: column-major or row-major.
// kColMajor: elements of a column are contiguous in memory.
// kRowMajor: elements of a row are contiguous in memory.
enum class Order { kColMajor, kRowMajor };
31 | |
// Policy controlling whether a matrix's packed representation may be cached
// across Gemm calls. Caching is only sound when equality of data pointers
// implies equality of data (i.e. constant data); see
// MatrixParams::cache_policy.
enum class CachePolicy : std::uint8_t {
  // Never cache the packed matrix (safe default for non-constant data).
  kNeverCache,
  // Cache only when the back-end estimates a large speedup from doing so.
  kCacheIfLargeSpeedup,
  // Always cache the packed matrix.
  kAlwaysCache,
};
37 | |
38 | inline CachePolicy DefaultCachePolicy(bool is_constant_data) { |
39 | return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup |
40 | : CachePolicy::kNeverCache; |
41 | } |
42 | |
// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar>
struct MatrixParams {
  // Storage layout order. For now we only do plain linear non-strided
  // layout. It would be easy to support a stride if needed.
  Order order = Order::kColMajor;
  // Number of rows of the matrix.
  int rows = 0;
  // Number of columns of the matrix.
  int cols = 0;
  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
  // When Scalar is floating-point, this must be 0.
  Scalar zero_point = 0;
  // When the data pointed to by this matrix is constant data, so that it is
  // valid to assume that equality of pointers implies equality of data,
  // a CachePolicy may be used instead of the default kNeverCache,
  // which will enable ruy to take advantage of this constancy of the data to
  // cache the packing work, which can be a large speedup in matrix*vector
  // and other narrow shapes.
  CachePolicy cache_policy = CachePolicy::kNeverCache;
};
69 | |
70 | // Enumeration of broad categories of Gemm. |
71 | // |
72 | // The primary reason for this to exist is to allow Gemm to compile |
73 | // only uniform-quantized or only per-channel-quantized code paths. |
74 | // This is unneeded with ruy as the back-end, as this is only a runtime |
75 | // difference in ruy, but with gemmlowp these really are separate code |
76 | // paths and templatizing in a QuantizationFlavor is necessary to avoid |
77 | // compiling unused gemmlowp code. Indeed, TFLite currently uses |
78 | // uint8 with uniform quantization and int8 with per-channel quantization, |
79 | // and does not use uint8 with per-channel. We want to avoid compiling |
80 | // the gemmlowp uint8 per-channel path when gemmlowp is the back-end. |
81 | // |
82 | // It's possible to drop this in the future if gemmlowp goes away and no |
83 | // other then-relevant backend library handles quantized paths in a way that |
84 | // requires knowing this at compile-time. |
enum class QuantizationFlavor {
  // Floating-point Gemm: the accumulators are not multiplied by any
  // 'multiplier'.
  kFloatingPoint,
  // Quantized Gemm using a single multiplier for all accumulators.
  kIntegerWithUniformMultiplier,
  // Quantized Gemm using separate multipliers for accumulators of each
  // row of the destination matrix. This is what is called 'per-channel'
  // in GemmParams. Here we use the more specific 'per-row' terminology
  // to allow for the possibility of 'per-column' in the future, and to
  // allow for that to be a separate code path in some back-end such as
  // gemmlowp.
  kIntegerWithPerRowMultiplier
};
99 | |
// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need per-channel quantization do).
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor =
              std::is_floating_point<AccumScalar>::value
                  ? QuantizationFlavor::kFloatingPoint
                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams {
  // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
  // of the multiplier by which accumulators are multiplied before being casted
  // to the destination type.
  AccumScalar multiplier_fixedpoint = 0;
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent = 0;
  // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Per-channel variant of multiplier_exponent. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_exponent.
  //
  // Either none or both of multiplier_exponent_perchannel and
  // multiplier_fixedpoint_perchannel must be nullptr.
  const int* multiplier_exponent_perchannel = nullptr;
  // The bias vector data, if not null.
  const AccumScalar* bias = nullptr;
  // min clamp bound of destination values. Defaults to the widest
  // representable range of DstScalar (-infinity for floating-point types).
  DstScalar clamp_min = std::is_floating_point<DstScalar>::value
                            ? -std::numeric_limits<DstScalar>::infinity()
                            : std::numeric_limits<DstScalar>::lowest();
  // max clamp bound of destination values. Defaults to the widest
  // representable range of DstScalar (+infinity for floating-point types).
  DstScalar clamp_max = std::is_floating_point<DstScalar>::value
                            ? std::numeric_limits<DstScalar>::infinity()
                            : std::numeric_limits<DstScalar>::max();
};
150 | |
/* Convenience typedefs */

// Quantized Gemm parameters: 32-bit integer accumulators.
template <typename DstScalar>
using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;

// Floating-point Gemm parameters: float accumulators, float destination.
using FloatGemmParams = GemmParams<float, float>;
157 | |
158 | /* Validation functions */ |
159 | |
160 | // Note that this uses TFLITE_DCHECK from kernels/internal/compatibility.h |
161 | // and not TF_LITE_ASSERT from op_macros.h. We want this to be explicitly |
// debug-build-only assertions so that there's no reason not to
163 | // generously validate, and TF_LITE_ASSERT is actually at the moment |
164 | // a release-build assertion. See b/131587258. |
165 | |
166 | // Validates self-consistency of GemmParams. |
167 | template <typename AccumScalar, typename DstScalar, |
168 | QuantizationFlavor quantization_flavor> |
169 | void ValidateGemmParams( |
170 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) { |
171 | // Guard consistency of the quantized multiplier fields. |
172 | if (quantization_flavor == QuantizationFlavor::kFloatingPoint) { |
173 | TFLITE_DCHECK(!params.multiplier_fixedpoint); |
174 | TFLITE_DCHECK(!params.multiplier_exponent); |
175 | TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel); |
176 | TFLITE_DCHECK(!params.multiplier_exponent_perchannel); |
177 | } else if (quantization_flavor == |
178 | QuantizationFlavor::kIntegerWithUniformMultiplier && |
179 | !std::is_same<DstScalar, int32_t>::value) { |
180 | TFLITE_DCHECK(params.multiplier_fixedpoint); |
181 | // Nothing to check about multiplier_exponent |
182 | TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel); |
183 | TFLITE_DCHECK(!params.multiplier_exponent_perchannel); |
184 | } else if (quantization_flavor == |
185 | QuantizationFlavor::kIntegerWithPerRowMultiplier && |
186 | !std::is_same<DstScalar, int32_t>::value) { |
187 | TFLITE_DCHECK(!params.multiplier_fixedpoint); |
188 | TFLITE_DCHECK(!params.multiplier_exponent); |
189 | TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel); |
190 | TFLITE_DCHECK(params.multiplier_exponent_perchannel); |
191 | } else { |
192 | // For the get raw accumulator case, we should make sure none of the |
193 | // quantization params are set. |
194 | TFLITE_DCHECK(!params.multiplier_fixedpoint); |
195 | TFLITE_DCHECK(!params.multiplier_exponent); |
196 | TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel); |
197 | TFLITE_DCHECK(!params.multiplier_exponent_perchannel); |
198 | } |
199 | } |
200 | |
201 | namespace detail { |
202 | |
203 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
204 | typename DstScalar, QuantizationFlavor quantization_flavor> |
205 | struct ValidateTypes { |
206 | // This generic implementation is for quantized flavors. |
207 | // kFloatingPoint will be a specialization below. |
208 | static_assert(!std::is_floating_point<LhsScalar>::value, "" ); |
209 | static_assert(!std::is_floating_point<RhsScalar>::value, "" ); |
210 | static_assert(!std::is_floating_point<AccumScalar>::value, "" ); |
211 | // No requirement on DstScalar --- we might in the future allow it |
212 | // to be floating point even in a quantized Gemm. |
213 | }; |
214 | |
215 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
216 | typename DstScalar> |
217 | struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
218 | QuantizationFlavor::kFloatingPoint> { |
219 | static_assert(std::is_floating_point<LhsScalar>::value, "" ); |
220 | static_assert(std::is_floating_point<RhsScalar>::value, "" ); |
221 | static_assert(std::is_floating_point<AccumScalar>::value, "" ); |
222 | static_assert(std::is_floating_point<DstScalar>::value, "" ); |
223 | }; |
224 | |
225 | } // namespace detail |
226 | |
227 | // Validates overall consistency of all the parameters taken by a Gemm call: |
228 | // the 3 MatrixParams and the GemmParams. |
229 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
230 | typename DstScalar, QuantizationFlavor quantization_flavor> |
231 | void ValidateParams( |
232 | const MatrixParams<LhsScalar>& lhs_params, |
233 | const MatrixParams<RhsScalar>& rhs_params, |
234 | const MatrixParams<DstScalar>& dst_params, |
235 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) { |
236 | (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
237 | quantization_flavor>(); |
238 | ValidateGemmParams(params); |
239 | } |
240 | |
241 | // Test if the Gemm is degenerate in some way, e.g. nonsensical dimenions. |
242 | template <typename LhsScalar, typename RhsScalar, typename DstScalar> |
243 | bool IsValidGemm(const MatrixParams<LhsScalar>& lhs_params, |
244 | const MatrixParams<RhsScalar>& rhs_params, |
245 | const MatrixParams<DstScalar>& dst_params) { |
246 | bool valid = true; |
247 | valid &= lhs_params.rows >= 1; |
248 | valid &= lhs_params.cols >= 1; |
249 | valid &= rhs_params.rows >= 1; |
250 | valid &= rhs_params.cols >= 1; |
251 | valid &= dst_params.rows >= 1; |
252 | valid &= dst_params.cols >= 1; |
253 | valid &= lhs_params.cols == rhs_params.rows; |
254 | valid &= rhs_params.cols == dst_params.cols; |
255 | valid &= lhs_params.rows == lhs_params.rows; |
256 | return valid; |
257 | } |
258 | |
259 | } // namespace cpu_backend_gemm |
260 | |
261 | } // namespace tflite |
262 | |
263 | #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_ |
264 | |