/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

#include <cstdint>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_custom_gemv.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"

#ifndef TFLITE_WITH_RUY
#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h"
#endif

namespace tflite {

namespace cpu_backend_gemm {

// The main entry point for CpuBackendGemm::Gemm.
//
// If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm will always go to Ruy aka
// GemmImplUsingRuy. Other cases are as follows:
//
//                    |Quantized (uint8)|Quantized (int8)| Float |
// TFLITE_WITH_RUY    |       Ruy       |      Ruy       |  Ruy  |
// !TFLITE_WITH_RUY   |    gemmlowp     | Ruy/gemmlowp*  | eigen |
// * - Ruy if NEON is not available.
//
// On x86 platforms:
// (default)          |    gemmlowp     |      Ruy       | eigen |
// TFLITE_X86_RUY_ENABLED &&
//   (AVX or above
//   available)       |       Ruy       |      Ruy       |  Ruy  |
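//
// Illustrative float call site (a minimal sketch, shown as a comment only;
// the MatrixParams/GemmParams fields are declared in
// cpu_backend_gemm_params.h, and m/k/n, the data pointers and `context` are
// placeholder names):
//
//   MatrixParams<float> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = m;
//   lhs_params.cols = k;
//   MatrixParams<float> rhs_params;
//   rhs_params.order = Order::kColMajor;
//   rhs_params.rows = k;
//   rhs_params.cols = n;
//   MatrixParams<float> dst_params;
//   dst_params.order = Order::kColMajor;
//   dst_params.rows = m;
//   dst_params.cols = n;
//   GemmParams<float, float> gemm_params;
//   gemm_params.bias = bias_data;  // optional, may be left null
//   cpu_backend_gemm::Gemm(lhs_params, lhs_data, rhs_params, rhs_data,
//                          dst_params, dst_data, gemm_params, context);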

#if !defined(TFLITE_WITH_RUY) && defined(TFLITE_X86_PLATFORM)
/* GEMM dispatch implementation for x86.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplX86<LhsScalar, RhsScalar, AccumScalar,
                                      DstScalar, quantization_flavor> {};
#else
/* Generic implementation using ruy.
 * Non-ruy implementations will be partial specializations of this template.
 */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
                                           DstScalar, quantization_flavor> {};

#if !defined(TFLITE_WITH_RUY)

/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
                                    DstScalar, quantization_flavor> {};

// When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
// outside of NEON. We avoid the compilation failure by subspecializing these
// cases, rerouting them back to ruy.
#if !defined(GEMMLOWP_NEON)
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
                               quantization_flavor> {};

template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               DstScalar, quantization_flavor> {};

template <QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
                               std::int8_t, quantization_flavor> {};
#endif  // not GEMMLOWP_NEON

/* Specializations using Eigen */

template <>
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
    : detail::GemmImplUsingEigen {};

#endif  // not TFLITE_WITH_RUY

#endif  // not TFLITE_WITH_RUY and TFLITE_X86_PLATFORM

/* Public entry point */

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    // For now, assert in debug mode, return in opt.
    // TODO(b/183099395) Eliminate debug/release discrepancy by plumbing in
    // TFLiteStatus so we can return an error here.
    TFLITE_DCHECK(false);
    return;
  }
  // In some cases we want to unconditionally use ruy as the backend,
  // overriding the `tflite_with_ruy` setting and the platform default.
  bool must_use_ruy = false;
  if (context->use_caching()) {
    // Only ruy supports caching of pre-packed matrices. Due to the large
    // performance impact in the cases where it's typically used, this
    // overrides the default.
    must_use_ruy = true;
  }
  if (lhs_params.order != Order::kRowMajor ||
      rhs_params.order != Order::kColMajor ||
      dst_params.order != Order::kColMajor) {
    // ruy supports all 2^3=8 combinations of storage orders with comparable
    // performance. In ruy, it's only a runtime switch. In other backends
    // (gemmlowp, Eigen), storage orders are template parameters, so supporting
    // all 8 combinations would mean up to an 8-fold code size increase; we
    // prefer to force usage of ruy in these cases.
    must_use_ruy = true;
  }
  if (must_use_ruy) {
    detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                             quantization_flavor>::Run(lhs_params, lhs_data,
                                                       rhs_params, rhs_data,
                                                       dst_params, dst_data,
                                                       params, context);
    return;
  }
  // If we did not choose to force usage of ruy above, then we may now consider
  // using custom GEMV code for the matrix*vector cases.
  const bool try_custom_gemv = (dst_params.cols == 1);
  if (try_custom_gemv) {
    // GEMV case: try a custom fast GEMV path. It will return true if it
    // actually handled it.
    if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                           dst_params, dst_data, params, context)) {
      return;
    }
  }
  // Generic case: dispatch to any backend as a general GEMM.
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                     dst_params, dst_data, params, context);
}
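
// Illustrative per-tensor-quantized int8 call site (a sketch, shown as a
// comment only; the multiplier, bias and clamp fields are those declared on
// GemmParams in cpu_backend_gemm_params.h, and the right-hand-side values are
// placeholder names):
//
//   MatrixParams<int8_t> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = m;
//   lhs_params.cols = k;
//   lhs_params.zero_point = lhs_zero_point;
//   // ... rhs_params and dst_params are filled in the same way, including
//   // their zero points ...
//   GemmParams<int32_t, int8_t> gemm_params;
//   gemm_params.bias = bias_data;  // int32 bias vector, optional
//   gemm_params.multiplier_fixedpoint = output_multiplier_fixedpoint;
//   gemm_params.multiplier_exponent = output_multiplier_exponent;
//   gemm_params.clamp_min = output_activation_min;
//   gemm_params.clamp_max = output_activation_max;
//   Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
//        gemm_params, context);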

// Special path for 16x8 quantized gemm: int8 LHS, int16 RHS and destination,
// int32 accumulators.
template <QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<int8_t>& lhs_params, const int8_t* lhs_data,
          const MatrixParams<int16_t>& rhs_params, const int16_t* rhs_data,
          const MatrixParams<int16_t>& dst_params, int16_t* dst_data,
          const GemmParams<int32_t, int16_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);
  if (!IsValidGemm(lhs_params, rhs_params, dst_params)) {
    TFLITE_DCHECK(false);
    return;
  }

  // Currently only the ruy backend supports 16x8 quantized gemm, so we use
  // ruy unconditionally here.
  detail::GemmImplUsingRuy<int8_t, int16_t, int32_t, int16_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}

// Special path for gemm with raw accumulator output, i.e. the
// AccumScalar == DstScalar == int32 case.
template <typename LhsScalar, typename RhsScalar,
          QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<int32_t>& dst_params, int32_t* dst_data,
          const GemmParams<int32_t, int32_t, quantization_flavor>& params,
          CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
  ValidateParams(lhs_params, rhs_params, dst_params, params);

  // Currently only the ruy backend supports returning raw accumulators, so we
  // use ruy unconditionally here.
  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
  detail::GemmImplUsingRuy<LhsScalar, RhsScalar, int32_t, int32_t,
                           quantization_flavor>::Run(lhs_params, lhs_data,
                                                     rhs_params, rhs_data,
                                                     dst_params, dst_data,
                                                     params, context);
}

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_