/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

namespace cpu_backend_gemm {

// Matrix storage order: column-major or row-major.
enum class Order { kColMajor, kRowMajor };

enum class CachePolicy : std::uint8_t {
  kNeverCache,
  kCacheIfLargeSpeedup,
  kAlwaysCache,
};

inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
  return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
                          : CachePolicy::kNeverCache;
}

// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: including it would
// require complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar>
struct MatrixParams {
  // Storage layout order. For now we only do plain linear non-strided
  // layout. It would be easy to support a stride if needed.
  Order order = Order::kColMajor;
  // Number of rows of the matrix.
  int rows = 0;
  // Number of columns of the matrix.
  int cols = 0;
  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
  // When Scalar is floating-point, this must be 0.
  Scalar zero_point = 0;
  // When the data pointed to by this matrix is constant, so that it is
  // valid to assume that equality of pointers implies equality of data,
  // a CachePolicy other than the default kNeverCache may be used. This
  // enables ruy to cache the packing work for that data, which can be a
  // large speedup for matrix*vector and other narrow shapes.
  CachePolicy cache_policy = CachePolicy::kNeverCache;
};
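
// A minimal usage sketch (not part of the API; names such as `M`, `K` and
// `lhs_zero_point` are illustrative only): describing a constant, column-major
// LHS matrix of shape MxK whose packed form may be cached.
//
//   MatrixParams<std::int8_t> lhs_params;
//   lhs_params.order = Order::kColMajor;
//   lhs_params.rows = M;
//   lhs_params.cols = K;
//   lhs_params.zero_point = lhs_zero_point;
//   lhs_params.cache_policy = DefaultCachePolicy(/*is_constant_data=*/true);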

// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor {
  // Floating-point Gemm: the accumulators are not multiplied by any
  // 'multiplier'.
  kFloatingPoint,
  // Quantized Gemm using a single multiplier for all accumulators.
  kIntegerWithUniformMultiplier,
  // Quantized Gemm using separate multipliers for the accumulators of each
  // row of the destination matrix. This is what is called 'per-channel'
  // in GemmParams. Here we use the more specific 'per-row' terminology
  // to allow for the possibility of 'per-column' in the future, and to
  // allow for that to be a separate code path in some back-end such as
  // gemmlowp.
  kIntegerWithPerRowMultiplier
};

// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from
// DstScalar) is useful future-proofing. Think of a float16 path using
// float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and it allows Gemm to be
// templatized in quantization_flavor via the GemmParams that it takes,
// allowing for automatic template parameter deduction to take place, so
// that most call sites don't need to specify a QuantizationFlavor
// (only those that need per-channel quantization do).
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor =
              std::is_floating_point<AccumScalar>::value
                  ? QuantizationFlavor::kFloatingPoint
                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams {
  // Only for non-floating-point cases. The fixed-point part (i.e. the
  // mantissa) of the multiplier by which accumulators are multiplied before
  // being cast to the destination type.
  AccumScalar multiplier_fixedpoint = 0;
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent = 0;
  // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Per-channel variant of multiplier_exponent. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_exponent.
  //
  // Either none or both of multiplier_exponent_perchannel and
  // multiplier_fixedpoint_perchannel must be nullptr.
  const int* multiplier_exponent_perchannel = nullptr;
  // The bias vector data, if not null.
  const AccumScalar* bias = nullptr;
  // min clamp bound of destination values.
  DstScalar clamp_min = std::is_floating_point<DstScalar>::value
                            ? -std::numeric_limits<DstScalar>::infinity()
                            : std::numeric_limits<DstScalar>::lowest();
  // max clamp bound of destination values.
  DstScalar clamp_max = std::is_floating_point<DstScalar>::value
                            ? std::numeric_limits<DstScalar>::infinity()
                            : std::numeric_limits<DstScalar>::max();
};
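
// Minimal illustrative sketch of setting up a uniformly quantized GemmParams
// (values such as `output_multiplier_fixedpoint` and `bias_data` are
// hypothetical). Note that the third template argument is deduced from
// AccumScalar: std::int32_t yields kIntegerWithUniformMultiplier, float
// yields kFloatingPoint.
//
//   GemmParams<std::int32_t, std::int8_t> gemm_params;
//   gemm_params.multiplier_fixedpoint = output_multiplier_fixedpoint;
//   gemm_params.multiplier_exponent = output_multiplier_exponent;
//   gemm_params.bias = bias_data;  // may stay nullptr if there is no bias
//   gemm_params.clamp_min = output_activation_min;
//   gemm_params.clamp_max = output_activation_max;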

/* Convenience typedefs */

template <typename DstScalar>
using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;

using FloatGemmParams = GemmParams<float, float>;

/* Validation functions */

// Note that this uses TFLITE_DCHECK from kernels/internal/compatibility.h
// and not TF_LITE_ASSERT from op_macros.h. We want these to be explicitly
// debug-build-only assertions so that there's no reason not to
// generously validate, and TF_LITE_ASSERT is actually at the moment
// a release-build assertion. See b/131587258.

// Validates self-consistency of GemmParams.
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
void ValidateGemmParams(
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
  // Guard consistency of the quantized multiplier fields.
  if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
    TFLITE_DCHECK(!params.multiplier_fixedpoint);
    TFLITE_DCHECK(!params.multiplier_exponent);
    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
  } else if (quantization_flavor ==
                 QuantizationFlavor::kIntegerWithUniformMultiplier &&
             !std::is_same<DstScalar, int32_t>::value) {
    TFLITE_DCHECK(params.multiplier_fixedpoint);
    // Nothing to check about multiplier_exponent
    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
  } else if (quantization_flavor ==
                 QuantizationFlavor::kIntegerWithPerRowMultiplier &&
             !std::is_same<DstScalar, int32_t>::value) {
    TFLITE_DCHECK(!params.multiplier_fixedpoint);
    TFLITE_DCHECK(!params.multiplier_exponent);
    TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
    TFLITE_DCHECK(params.multiplier_exponent_perchannel);
  } else {
    // For the raw-accumulator case (DstScalar == int32_t), make sure that
    // none of the quantization params are set.
    TFLITE_DCHECK(!params.multiplier_fixedpoint);
    TFLITE_DCHECK(!params.multiplier_exponent);
    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
  }
}
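
// For instance (purely illustrative; `per_channel_fixedpoint` and
// `per_channel_exponent` are hypothetical buffers with one entry per
// destination row), a per-channel-quantized setup that passes these checks
// sets both per-channel pointers and leaves the uniform multiplier fields at
// their zero defaults:
//
//   GemmParams<std::int32_t, std::int8_t,
//              QuantizationFlavor::kIntegerWithPerRowMultiplier> params;
//   params.multiplier_fixedpoint_perchannel = per_channel_fixedpoint;
//   params.multiplier_exponent_perchannel = per_channel_exponent;
//   ValidateGemmParams(params);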

namespace detail {

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct ValidateTypes {
  // This generic implementation is for quantized flavors.
  // kFloatingPoint will be a specialization below.
  static_assert(!std::is_floating_point<LhsScalar>::value, "");
  static_assert(!std::is_floating_point<RhsScalar>::value, "");
  static_assert(!std::is_floating_point<AccumScalar>::value, "");
  // No requirement on DstScalar --- we might in the future allow it
  // to be floating point even in a quantized Gemm.
};

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                     QuantizationFlavor::kFloatingPoint> {
  static_assert(std::is_floating_point<LhsScalar>::value, "");
  static_assert(std::is_floating_point<RhsScalar>::value, "");
  static_assert(std::is_floating_point<AccumScalar>::value, "");
  static_assert(std::is_floating_point<DstScalar>::value, "");
};

}  // namespace detail

// Validates overall consistency of all the parameters taken by a Gemm call:
// the 3 MatrixParams and the GemmParams.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateParams(
    const MatrixParams<LhsScalar>& lhs_params,
    const MatrixParams<RhsScalar>& rhs_params,
    const MatrixParams<DstScalar>& dst_params,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
  (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                              quantization_flavor>();
  ValidateGemmParams(params);
}

// Returns false if the Gemm is degenerate in some way, e.g. nonsensical
// dimensions.
template <typename LhsScalar, typename RhsScalar, typename DstScalar>
bool IsValidGemm(const MatrixParams<LhsScalar>& lhs_params,
                 const MatrixParams<RhsScalar>& rhs_params,
                 const MatrixParams<DstScalar>& dst_params) {
  bool valid = true;
  valid &= lhs_params.rows >= 1;
  valid &= lhs_params.cols >= 1;
  valid &= rhs_params.rows >= 1;
  valid &= rhs_params.cols >= 1;
  valid &= dst_params.rows >= 1;
  valid &= dst_params.cols >= 1;
  valid &= lhs_params.cols == rhs_params.rows;
  valid &= rhs_params.cols == dst_params.cols;
  valid &= lhs_params.rows == dst_params.rows;
  return valid;
}
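
// In other words (sketch only, with hypothetical M, K, N): a valid shape
// triple is lhs MxK, rhs KxN, dst MxN.
//
//   MatrixParams<float> lhs, rhs, dst;
//   lhs.rows = M; lhs.cols = K;  // lhs is MxK
//   rhs.rows = K; rhs.cols = N;  // rhs is KxN
//   dst.rows = M; dst.cols = N;  // dst is MxN
//   TFLITE_DCHECK(IsValidGemm(lhs, rhs, dst));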

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_