/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

#ifndef TFLITE_WITH_RUY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h"
#endif

namespace tflite {

namespace cpu_backend_gemm {
// The main entry point for CpuBackendGemm::Gemm.
//
// If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm will always go to Ruy aka
// GemmImplUsingRuy. Other cases are as follows:
//
//                   | Quantized (uint8) | Quantized (int8) | Float |
// TFLITE_WITH_RUY   | Ruy               | Ruy              | Ruy   |
// !TFLITE_WITH_RUY  | gemmlowp          | Ruy/gemmlowp*    | eigen |
// * - Ruy if NEON is not available.
//
// On x86 platforms:
// (default)         | gemmlowp          | Ruy              | eigen |
// TFLITE_X86_RUY_ENABLED && (AVX or above available)
//                   | Ruy               | Ruy              | Ruy   |
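//
// Illustrative usage sketch of the float entry point (not a definitive
// recipe; the buffer variables and dimensions M, K, N are hypothetical, and
// field names follow MatrixParams/GemmParams in cpu_backend_gemm_params.h):
//
//   // Compute C (MxN, col-major) = A (MxK, row-major) * B (KxN, col-major).
//   MatrixParams<float> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = M;
//   lhs_params.cols = K;
//   MatrixParams<float> rhs_params;
//   rhs_params.order = Order::kColMajor;
//   rhs_params.rows = K;
//   rhs_params.cols = N;
//   MatrixParams<float> dst_params;
//   dst_params.order = Order::kColMajor;
//   dst_params.rows = M;
//   dst_params.cols = N;
//   GemmParams<float, float> gemm_params;  // defaults: no bias, no clamping
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);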

#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
/* GEMM dispatch implementation for x86.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
                                      DstScalar, quantization_flavor> {};
#else
/* Generic implementation using ruy.
 * Non-ruy implementations will be partial specializations of this template.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
                                           DstScalar, quantization_flavor> {};

#if !defined(TFLITE_WITH_RUY)

/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
                                    DstScalar, quantization_flavor> {};
// When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
// outside of NEON. We avoid the compilation failure by sub-specializing
// these cases and rerouting them back to ruy.
#if !defined(GEMMLOWP_NEON)
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                               quantization_flavor> {};

template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               DstScalar, quantization_flavor> {};

template <QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               std::int8_t, quantization_flavor> {};
#endif  // not GEMMLOWP_NEON

/* Specializations using Eigen */

template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
    : detail::GemmImplUsingEigen {};

#endif  // not TFLITE_WITH_RUY

#endif  // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM

/* Public entry point */

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    // For now, assert in debug mode, return in opt.
    // TODO(b/183099395) Eliminate debug/release discrepancy by plumbing in
    // TFLiteStatus so we can return an error here.
    TFLITE_DCHECK(false);
    return;
  }
  // In some cases we want to unconditionally use ruy as the backend, overriding
  // the `tflite_with_ruy` setting and the platform default.
  bool must_use_ruy = false;
  if (context->use_caching()) {
    // Only ruy supports caching of pre-packed matrices. Due to the large
    // performance impact in the cases where it's typically used, this overrides
    // the default.
    must_use_ruy = true;
  }
  if (lhs_params.order != Order::kRowMajor ||
      rhs_params.order != Order::kColMajor ||
      dst_params.order != Order::kColMajor) {
    // ruy supports all 2^3 = 8 combinations of storage orders with comparable
    // performance: in ruy, the storage order is only a runtime switch. In the
    // other backends (gemmlowp, Eigen), storage orders are template
    // parameters, so supporting all 8 combinations would cause up to an
    // 8-fold code size increase. We therefore prefer to force usage of ruy
    // in these cases.
    must_use_ruy = true;
  }
  if (must_use_ruy) {
    detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                             quantization_flavor>::Run(lhs_params, lhs_data,
                                                       rhs_params, rhs_data,
                                                       dst_params, dst_data,
                                                       params, context);
    return;
  }
  // If we did not choose to force usage of ruy above, then we may now consider
  // using custom GEMV code for the matrix*vector cases.
  const bool try_custom_gemv = (dst_params.cols == 1);
  if (try_custom_gemv) {
    // GEMV case: try a custom fast GEMV path. It will return true if it
    // actually handled it.
    if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                           dst_params, dst_data, params, context)) {
      return;
    }
  }
  // Generic case: dispatch to any backend as a general GEMM.
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
}
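
// Illustrative sketch of a uint8 quantized call through the entry point
// above. The buffer variables, dimensions (M, K, N) and quantization values
// are hypothetical; see cpu_backend_gemm_params.h for the requantization
// fields:
//
//   MatrixParams<uint8_t> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = M;
//   lhs_params.cols = K;
//   lhs_params.zero_point = lhs_zero_point;
//   // ... rhs_params (KxN, col-major) and dst_params (MxN, col-major)
//   // set up similarly ...
//   GemmParams<int32_t, uint8_t> gemm_params;
//   gemm_params.bias = bias_data;  // int32, one entry per destination row
//   gemm_params.multiplier_fixedpoint = output_multiplier_fixedpoint;
//   gemm_params.multiplier_exponent = output_multiplier_exponent;
//   gemm_params.clamp_min = output_activation_min;
//   gemm_params.clamp_max = output_activation_max;
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);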

// Special path for 16x8 quant gemm.
template <QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<int8_t>& lhs_params, const int8_t* lhs_data,
          const MatrixParams<int16_t>& rhs_params, const int16_t* rhs_data,
          const MatrixParams<int16_t>& dst_params, int16_t* dst_data,
          const GemmParams<int32_t, int16_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    TFLITE_DCHECK(false);
    return;
  }

  // Currently, only the Ruy backend supports the 16x8 quantized GEMM, so we
  // use ruy unconditionally.
  detail::GemmImplUsingRuy<int8_t, int16_t, int32_t, int16_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}
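
// Illustrative sketch of the 16x8 path above (hypothetical buffers and
// dimensions; orders, shapes and zero points set up as in the uint8 sketch):
//
//   MatrixParams<int8_t> lhs_params;   // e.g. int8 weights
//   MatrixParams<int16_t> rhs_params;  // e.g. int16 activations
//   MatrixParams<int16_t> dst_params;  // int16 outputs
//   GemmParams<int32_t, int16_t> gemm_params;
//   gemm_params.multiplier_fixedpoint = output_multiplier_fixedpoint;
//   gemm_params.multiplier_exponent = output_multiplier_exponent;
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);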

// Special path for GEMM with raw accumulators, i.e. the
// AccumScalar == DstScalar == int32_t case.
template <typename LhsScalar, typename RhsScalar,
          QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
          const GemmParams<int32_t, int32_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);

  // Currently, only the Ruy backend supports returning raw accumulators, so
  // we use ruy unconditionally.
  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
  detail::GemmImplUsingRuy<LhsScalar, RhsScalar, int32_t, int32_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}
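
// Illustrative sketch of the raw-accumulator path above (hypothetical buffers
// and dimensions). With AccumScalar == DstScalar == int32_t, no
// requantization fields are set and the destination receives the raw int32
// accumulators:
//
//   MatrixParams<int8_t> lhs_params;
//   MatrixParams<int8_t> rhs_params;
//   MatrixParams<int32_t> dst_params;
//   // ... orders and shapes set up as in the sketches above ...
//   GemmParams<int32_t, int32_t> gemm_params;  // no multiplier, no clamping
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, cpu_backend_context);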

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_