1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Fast Gemv (i.e. matrix*vector multiplication) paths. |
17 | // TODO(b/132094390): remove when GEMM performance is good enough on GEMV cases. |
18 | |
// TFLite's runtime ops funnel the matrix*vector use cases, as much as
// possible, into the (matrix) * (column-vector) case, as opposed to
// (row-vector) * (matrix). So that is what we focus on optimizing here.
22 | // Accordingly, the public cpu_backend_gemm::Gemm() entry point checks |
23 | // if we are in this (matrix) * (column-vector) case, and if so calls |
24 | // CustomGemv. |
25 | // |
26 | // cpu_backend_gemm::Gemm is also currently restricted (as enforced in |
27 | // ValidateParams) to the case where the left-hand side matrix is row-major. |
28 | // |
29 | // So the current scope of this CustomGemv function really is: |
30 | // (row-major matrix) * (column-vector). |
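//
// For illustration only (a sketch, not code used by this library), the
// shapes this path handles look like:
//
//   MatrixParams<float> lhs_params;  // M x K, lhs_params.order == kRowMajor
//   MatrixParams<float> rhs_params;  // K x 1 column-vector
//   MatrixParams<float> dst_params;  // M x 1 column-vector
//
// computing dst = lhs * rhs, i.e. dst[m] = sum over k of lhs[m, k] * rhs[k].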
31 | |
32 | #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_ |
33 | #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_ |
34 | |
35 | #include <stdint.h> |
36 | |
37 | #include <algorithm> |
38 | #include <type_traits> |
39 | #include <vector> |
40 | |
41 | #include "ruy/profiler/instrumentation.h" // from @ruy |
42 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
43 | #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" |
44 | #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" |
45 | #include "tensorflow/lite/kernels/internal/common.h" |
46 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
47 | #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" |
48 | |
49 | namespace tflite { |
50 | namespace cpu_backend_gemm { |
51 | namespace detail { |
52 | |
53 | // CustomGemvImpl is what needs to be specialized for each custom GEMV path. |
54 | // |
55 | // It does not deal with any multi-threaded implementation detail. Rather, |
56 | // it provides the single-thread implementation to be run by each thread. |
57 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
58 | typename DstScalar, QuantizationFlavor quantization_flavor> |
59 | struct CustomGemvImpl { |
60 | // The number of rows of the left-hand-side matrix (and equivalently of the |
61 | // destination column-vector) that the kernel processes at a time. |
62 | // This will also be the minimum required number of rows for a Gemv shape |
63 | // to be supported by this path. |
64 | // |
// Gemv implementations are expected to be able to deal with numbers of
// rows that aren't multiples of kKernelRows by running the kernel again
// at a row_start that overlaps the previous call, e.g. if kKernelRows==4,
// Run() should still support 7 rows by running twice: once with row_start=0
// and then another time with row_start=3.
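// In code, that 7-row example amounts to two calls along these lines (a
// sketch, with the params/data arguments elided):
//   Run(..., /*row_start=*/0, /*row_end=*/4);  // rows 0..3
//   Run(..., /*row_start=*/3, /*row_end=*/7);  // rows 3..6, overlapping row 3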
70 | // |
// On the other hand, gemv implementations are not expected to support
// running on fewer than kKernelRows rows. There is no interest in
// optimizing Gemv's so narrow that they amount to just a few dot-products;
// supporting that would require kernel code written specifically for that
// case.
75 | static constexpr int kKernelRows = 1; |
76 | |
// Returns true if the Gemv shape is supported by Run(), provided that
// (row_end - row_start) >= kKernelRows.
79 | static bool IsSupportedGivenSufficientlyManyRows( |
80 | const MatrixParams<LhsScalar>& lhs_params, |
81 | const MatrixParams<RhsScalar>& rhs_params, |
82 | const MatrixParams<DstScalar>& dst_params, |
83 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) { |
84 | return false; |
85 | } |
86 | |
87 | // Performs the Gemv. |
88 | static void Run( |
89 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
90 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
91 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
92 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
93 | int row_start, int row_end) {} |
94 | }; |
95 | |
96 | // Wraps CustomGemvImpl for multi-threaded operation. |
97 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
98 | typename DstScalar, QuantizationFlavor quantization_flavor> |
99 | class CustomGemvTask : public cpu_backend_threadpool::Task { |
100 | public: |
101 | CustomGemvTask( |
102 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
103 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
104 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
105 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
106 | int row_start, int row_end) |
107 | : lhs_params_(lhs_params), |
108 | lhs_data_(lhs_data), |
109 | rhs_params_(rhs_params), |
110 | rhs_data_(rhs_data), |
111 | dst_params_(dst_params), |
112 | dst_data_(dst_data), |
113 | params_(params), |
114 | row_start_(row_start), |
115 | row_end_(row_end) {} |
116 | |
117 | void Run() override { |
118 | using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
119 | quantization_flavor>; |
120 | Impl::Run(lhs_params_, lhs_data_, rhs_params_, rhs_data_, dst_params_, |
121 | dst_data_, params_, row_start_, row_end_); |
122 | } |
123 | |
124 | private: |
125 | const MatrixParams<LhsScalar>& lhs_params_; |
126 | const LhsScalar* lhs_data_; |
127 | const MatrixParams<RhsScalar>& rhs_params_; |
128 | const RhsScalar* rhs_data_; |
129 | const MatrixParams<DstScalar>& dst_params_; |
130 | DstScalar* dst_data_; |
131 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params_; |
132 | int row_start_; |
133 | int row_end_; |
134 | }; |
135 | |
136 | // Either performs the requested Gemv operation and returns true, |
137 | // or immediately returns false. |
138 | // |
139 | // See the comment at the top of the file for the scope of what this handles. |
140 | // In summary: (row-major matrix) * (column-vector). |
141 | // |
// This function contains only the high-level logic; the actual
// implementation details are in the specializations of CustomGemvImpl
// below.
145 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
146 | typename DstScalar, QuantizationFlavor quantization_flavor> |
147 | bool CustomGemv( |
148 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
149 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
150 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
151 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
152 | CpuBackendContext* context) { |
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm: CustomGemv");
154 | using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
155 | quantization_flavor>; |
156 | if (lhs_params.rows < Impl::kKernelRows) { |
157 | return false; |
158 | } |
159 | if (!Impl::IsSupportedGivenSufficientlyManyRows(lhs_params, rhs_params, |
160 | dst_params, params)) { |
161 | return false; |
162 | } |
163 | TFLITE_DCHECK_GE(lhs_params.rows, Impl::kKernelRows); |
164 | int thread_count = LegacyHowManyThreads<Impl::kKernelRows>( |
165 | context->max_num_threads(), dst_params.rows, dst_params.cols, |
166 | lhs_params.cols); |
167 | if (thread_count == 1) { |
168 | Impl::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, |
169 | params, 0, lhs_params.rows); |
170 | } else { |
171 | using Task = CustomGemvTask<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
172 | quantization_flavor>; |
173 | std::vector<Task> tasks; |
174 | tasks.reserve(thread_count); |
175 | const int kRowsPerThread = |
176 | RoundUp<Impl::kKernelRows>(CeilQuotient(dst_params.rows, thread_count)); |
177 | int row_start = 0; |
178 | for (int i = 0; i < thread_count; i++) { |
179 | int row_end = std::min(dst_params.rows, row_start + kRowsPerThread); |
180 | tasks.emplace_back(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, |
181 | dst_data, params, row_start, row_end); |
182 | row_start = row_end; |
183 | } |
184 | cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), context); |
185 | } |
186 | return true; |
187 | } |
188 | |
// USE_NEON may still be defined on x86, where we may be using the
// arm_neon_sse.h wrapper that implements NEON intrinsics on top of SSE4
// intrinsics.
191 | #ifdef USE_NEON |
192 | |
193 | // Some NEON helper functions used by CustomGemvImpl specializations below, |
194 | // allowing for some type genericity in them. |
195 | |
196 | inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src, |
197 | std::uint8_t zero_point) { |
198 | uint8x16_t src_u8 = vld1q_u8(src); |
199 | int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8))); |
200 | int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8))); |
201 | int16x8x2_t result; |
202 | int16x8_t zero_point_vec = vdupq_n_s16(zero_point); |
203 | result.val[0] = vsubq_s16(src_s16_0, zero_point_vec); |
204 | result.val[1] = vsubq_s16(src_s16_1, zero_point_vec); |
205 | return result; |
206 | } |
207 | |
208 | inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src, |
209 | std::int8_t zero_point) { |
210 | int8x16_t src_s8 = vld1q_s8(src); |
211 | int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8)); |
212 | int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8)); |
213 | int16x8x2_t result; |
214 | int16x8_t zero_point_vec = vdupq_n_s16(zero_point); |
215 | result.val[0] = vsubq_s16(src_s16_0, zero_point_vec); |
216 | result.val[1] = vsubq_s16(src_s16_1, zero_point_vec); |
217 | return result; |
218 | } |
219 | |
220 | inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src, |
221 | std::uint8_t zero_point) { |
222 | uint8x8_t src_u8 = vld1_u8(src); |
223 | int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8)); |
224 | int16x8_t zero_point_vec = vdupq_n_s16(zero_point); |
225 | return vsubq_s16(src_s16, zero_point_vec); |
226 | } |
227 | |
228 | inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src, |
229 | std::int8_t zero_point) { |
230 | int8x8_t src_s8 = vld1_s8(src); |
231 | int16x8_t src_s16 = vmovl_s8(src_s8); |
232 | int16x8_t zero_point_vec = vdupq_n_s16(zero_point); |
233 | return vsubq_s16(src_s16, zero_point_vec); |
234 | } |
235 | |
236 | inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min, |
237 | std::uint8_t clamp_max, std::uint8_t* dst) { |
238 | // Narrow values down to 16 bit signed. |
239 | const int16x4_t res16 = vqmovn_s32(src); |
240 | // Narrow values down to 8 bit unsigned, saturating. |
241 | uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16)); |
242 | // Apply the clamping from the activation function |
243 | res8 = vmax_u8(res8, vdup_n_u8(clamp_min)); |
244 | res8 = vmin_u8(res8, vdup_n_u8(clamp_max)); |
245 | // Store results to destination. |
246 | vst1_lane_u8(dst + 0, res8, 0); |
247 | vst1_lane_u8(dst + 1, res8, 1); |
248 | vst1_lane_u8(dst + 2, res8, 2); |
249 | vst1_lane_u8(dst + 3, res8, 3); |
250 | } |
251 | |
252 | inline void ClampAndStore(int32x4_t src, std::int8_t clamp_min, |
253 | std::int8_t clamp_max, std::int8_t* dst) { |
254 | // Narrow values down to 16 bit signed. |
255 | const int16x4_t res16 = vqmovn_s32(src); |
// Narrow values down to 8 bit signed, saturating.
257 | int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16)); |
258 | // Apply the clamping from the activation function |
259 | res8 = vmax_s8(res8, vdup_n_s8(clamp_min)); |
260 | res8 = vmin_s8(res8, vdup_n_s8(clamp_max)); |
261 | // Store results to destination. |
262 | vst1_lane_s8(dst + 0, res8, 0); |
263 | vst1_lane_s8(dst + 1, res8, 1); |
264 | vst1_lane_s8(dst + 2, res8, 2); |
265 | vst1_lane_s8(dst + 3, res8, 3); |
266 | } |
267 | |
268 | inline void ClampAndStore(int32x4_t src, std::int16_t clamp_min, |
269 | std::int16_t clamp_max, std::int16_t* dst) { |
270 | // Narrow values down to 16 bit signed. |
271 | int16x4_t res16 = vqmovn_s32(src); |
272 | // Apply the clamping from the activation function |
273 | res16 = vmax_s16(res16, vdup_n_s16(clamp_min)); |
274 | res16 = vmin_s16(res16, vdup_n_s16(clamp_max)); |
275 | // Store results to destination. |
276 | vst1_lane_s16(dst + 0, res16, 0); |
277 | vst1_lane_s16(dst + 1, res16, 1); |
278 | vst1_lane_s16(dst + 2, res16, 2); |
279 | vst1_lane_s16(dst + 3, res16, 3); |
280 | } |
281 | |
282 | template <typename LhsScalar, typename RhsScalar, typename DstScalar, |
283 | QuantizationFlavor quantization_flavor> |
284 | struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar, |
285 | quantization_flavor> { |
286 | // This partial template specialization is less generic than its declaration |
287 | // implies: it assumes the following constraints on its free template |
288 | // parameters. We guard these assumptions in the following static_assert's. |
289 | static_assert(std::is_same<LhsScalar, std::uint8_t>::value || |
290 | std::is_same<LhsScalar, std::int8_t>::value, |
291 | "" ); |
292 | static_assert(std::is_same<RhsScalar, std::uint8_t>::value || |
293 | std::is_same<RhsScalar, std::int8_t>::value, |
294 | "" ); |
295 | static_assert(std::is_same<DstScalar, std::uint8_t>::value || |
296 | std::is_same<DstScalar, std::int8_t>::value || |
297 | std::is_same<DstScalar, std::int16_t>::value, |
298 | "" ); |
299 | static_assert(quantization_flavor == |
300 | QuantizationFlavor::kIntegerWithUniformMultiplier || |
301 | quantization_flavor == |
302 | QuantizationFlavor::kIntegerWithPerRowMultiplier, |
303 | "" ); |
304 | |
305 | // This implementation's inner loop processes 4 rows of the left-hand side |
306 | // matrix at a time. |
307 | static constexpr int kKernelRows = 4; |
308 | |
309 | static bool IsSupportedGivenSufficientlyManyRows( |
310 | const MatrixParams<LhsScalar>& lhs_params, |
311 | const MatrixParams<RhsScalar>& rhs_params, |
312 | const MatrixParams<DstScalar>& dst_params, |
313 | const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) { |
314 | // The kernel processes at least 8 LHS columns at once to fill NEON |
315 | // registers. The leftovers-handling code at the end works by loading a |
316 | // partially overlapping final register by walking back by a few (<8) values |
317 | // to avoid running past the row's end. This relies on there being |
318 | // at least 8 LHS columns. |
319 | return lhs_params.cols >= 8; |
320 | } |
321 | |
322 | static void Run( |
323 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
324 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
325 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
326 | const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params, |
327 | int row_start, int row_end) { |
328 | // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each |
329 | // iteration of this for loop. |
330 | TFLITE_DCHECK_GE(row_end - row_start, kKernelRows); |
331 | for (int row = row_start; row < row_end; row += kKernelRows) { |
// Here is the trick that lets this kernel handle any number of rows
// >= kKernelRows, not just multiples of kKernelRows: the last group of
// `kKernelRows` rows is nudged to fit, possibly by starting at a value of
// `row` that overlaps the previous group.
336 | row = std::min(row, row_end - kKernelRows); |
337 | const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols; |
338 | |
339 | static constexpr int kCacheLineSize = 64; |
340 | for (int k = 0; k < rhs_params.rows; |
341 | k += kCacheLineSize / sizeof(RhsScalar)) { |
342 | optimized_ops_preload_l1_keep(rhs_data + k); |
343 | } |
344 | |
345 | // kPreloadAhead is empirically determined. |
346 | // End-to-end latency (ms) on mobilenet_v2_0.35_96_8bit, 1 thread, |
347 | // on Qualcomm S855: |
348 | // |
349 | // kPreloadAhead | big core | little core |
350 | // --------------+----------+------------ |
351 | // 64 | 1.26 | 5.45 |
352 | // 128 | 1.23 | 5.01 |
353 | // 256 | 1.18 | 4.9 |
354 | // 512 | 1.18 | 5.45 |
355 | // 1024 | 1.18 | 6.5 |
356 | // no prefetch | 1.25 | 8.1 |
357 | static constexpr int kPreloadAhead = 256; |
358 | |
359 | // 4 accumulator registers, one for each row being processed. |
// Each has 4 int32 lanes that correspond to columns modulo 4, and
// will need to be horizontally reduced at the end.
362 | int32x4_t acc0 = vdupq_n_s32(0); |
363 | int32x4_t acc1 = acc0; |
364 | int32x4_t acc2 = acc0; |
365 | int32x4_t acc3 = acc0; |
366 | int in = 0; |
367 | // As much as possible, handle 16 columns of the left-hand side matrix |
// at a time. This allows for a decent NEON implementation.
369 | for (; in <= lhs_params.cols - 16; in += 16) { |
370 | const LhsScalar* local_filter_ptr = filter_ptr; |
371 | int16x8x2_t input_val = |
372 | Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point); |
373 | int16x8x2_t filter_val_0 = |
374 | Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
375 | optimized_ops_preload_l1_stream(local_filter_ptr + |
376 | kPreloadAhead / sizeof(LhsScalar)); |
377 | local_filter_ptr += lhs_params.cols; |
378 | int16x8x2_t filter_val_1 = |
379 | Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
380 | optimized_ops_preload_l1_stream(local_filter_ptr + |
381 | kPreloadAhead / sizeof(LhsScalar)); |
382 | local_filter_ptr += lhs_params.cols; |
383 | int16x8x2_t filter_val_2 = |
384 | Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
385 | optimized_ops_preload_l1_stream(local_filter_ptr + |
386 | kPreloadAhead / sizeof(LhsScalar)); |
387 | local_filter_ptr += lhs_params.cols; |
388 | int16x8x2_t filter_val_3 = |
389 | Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
390 | optimized_ops_preload_l1_stream(local_filter_ptr + |
391 | kPreloadAhead / sizeof(LhsScalar)); |
392 | filter_ptr += 16; |
393 | acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]), |
394 | vget_low_s16(input_val.val[0])); |
395 | acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[0]), |
396 | vget_low_s16(input_val.val[0])); |
397 | acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[0]), |
398 | vget_low_s16(input_val.val[0])); |
399 | acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[0]), |
400 | vget_low_s16(input_val.val[0])); |
401 | acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[1]), |
402 | vget_low_s16(input_val.val[1])); |
403 | acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[1]), |
404 | vget_low_s16(input_val.val[1])); |
405 | acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[1]), |
406 | vget_low_s16(input_val.val[1])); |
407 | acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[1]), |
408 | vget_low_s16(input_val.val[1])); |
409 | acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[0]), |
410 | vget_high_s16(input_val.val[0])); |
411 | acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[0]), |
412 | vget_high_s16(input_val.val[0])); |
413 | acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[0]), |
414 | vget_high_s16(input_val.val[0])); |
415 | acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[0]), |
416 | vget_high_s16(input_val.val[0])); |
417 | acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[1]), |
418 | vget_high_s16(input_val.val[1])); |
419 | acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[1]), |
420 | vget_high_s16(input_val.val[1])); |
421 | acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[1]), |
422 | vget_high_s16(input_val.val[1])); |
423 | acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]), |
424 | vget_high_s16(input_val.val[1])); |
425 | } |
// Less than 16 values remain. Try to handle 8 more.
427 | if (in <= lhs_params.cols - 8) { |
428 | int16x8_t input_val = |
429 | Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point); |
430 | int16x8_t filter_val_0 = Load8AndSubtractZeroPoint( |
431 | filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point); |
432 | int16x8_t filter_val_1 = Load8AndSubtractZeroPoint( |
433 | filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point); |
434 | int16x8_t filter_val_2 = Load8AndSubtractZeroPoint( |
435 | filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point); |
436 | int16x8_t filter_val_3 = Load8AndSubtractZeroPoint( |
437 | filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point); |
438 | filter_ptr += 8; |
439 | acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0), |
440 | vget_low_s16(input_val)); |
441 | acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1), |
442 | vget_low_s16(input_val)); |
443 | acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2), |
444 | vget_low_s16(input_val)); |
445 | acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3), |
446 | vget_low_s16(input_val)); |
447 | acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0), |
448 | vget_high_s16(input_val)); |
449 | acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1), |
450 | vget_high_s16(input_val)); |
451 | acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2), |
452 | vget_high_s16(input_val)); |
453 | acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3), |
454 | vget_high_s16(input_val)); |
455 | in += 8; |
456 | } |
457 | // Less than 8 values remain. Handle the remaining values |
458 | // in one more copy of the above code handling 8, where we |
459 | // walk back a few values to be able to load 8 values without |
460 | // overrunning the buffer. This is where we make use of the requirement |
// (see IsSupportedGivenSufficientlyManyRows) that there are at least
462 | // 8 LHS columns. |
463 | if (in < lhs_params.cols) { |
464 | // `back` is how many entries to walk back by. |
465 | // Its value is necessarily between 1 and 7. |
466 | const int back = in + 8 - lhs_params.cols; |
467 | TFLITE_DCHECK_GE(back, 1); |
468 | TFLITE_DCHECK_LE(back, 7); |
469 | // Load 8 values as usual. |
470 | int16x8_t input_val = Load8AndSubtractZeroPoint( |
471 | rhs_data + lhs_params.cols - 8, rhs_params.zero_point); |
472 | const LhsScalar* local_filter_ptr = filter_ptr - back; |
473 | filter_ptr += lhs_params.cols - in; |
474 | int16x8_t filter_val_0 = |
475 | Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
476 | local_filter_ptr += lhs_params.cols; |
477 | int16x8_t filter_val_1 = |
478 | Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
479 | local_filter_ptr += lhs_params.cols; |
480 | int16x8_t filter_val_2 = |
481 | Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
482 | local_filter_ptr += lhs_params.cols; |
483 | int16x8_t filter_val_3 = |
484 | Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point); |
485 | // Now zero out the `back` first entries of input_val. |
486 | // vsetq_lane_s16 takes a literal index, so we need unrolled code. |
487 | switch (back) { |
488 | case 7: |
489 | input_val = vsetq_lane_s16(0, input_val, 6); |
490 | [[clang::fallthrough]]; |
491 | case 6: |
492 | input_val = vsetq_lane_s16(0, input_val, 5); |
493 | [[clang::fallthrough]]; |
494 | case 5: |
495 | input_val = vsetq_lane_s16(0, input_val, 4); |
496 | [[clang::fallthrough]]; |
497 | case 4: |
498 | input_val = vsetq_lane_s16(0, input_val, 3); |
499 | [[clang::fallthrough]]; |
500 | case 3: |
501 | input_val = vsetq_lane_s16(0, input_val, 2); |
502 | [[clang::fallthrough]]; |
503 | case 2: |
504 | input_val = vsetq_lane_s16(0, input_val, 1); |
505 | [[clang::fallthrough]]; |
506 | default: |
507 | input_val = vsetq_lane_s16(0, input_val, 0); |
508 | } |
509 | // Multiply-accumulate 8 values as usual. The `back` first lanes |
510 | // of filter_val_* are junk, but it doesn't matter since they get |
511 | // multiplied by the zeros that we just wrote in the corresponding |
512 | // lanes of input_val. |
513 | acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0), |
514 | vget_low_s16(input_val)); |
515 | acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1), |
516 | vget_low_s16(input_val)); |
517 | acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2), |
518 | vget_low_s16(input_val)); |
519 | acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3), |
520 | vget_low_s16(input_val)); |
521 | acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0), |
522 | vget_high_s16(input_val)); |
523 | acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1), |
524 | vget_high_s16(input_val)); |
525 | acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2), |
526 | vget_high_s16(input_val)); |
527 | acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3), |
528 | vget_high_s16(input_val)); |
529 | } |
530 | |
531 | // Horizontally reduce accumulators |
532 | int32x2_t pairwise_reduced_acc_0 = |
533 | vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0)); |
534 | int32x2_t pairwise_reduced_acc_1 = |
535 | vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1)); |
536 | int32x2_t pairwise_reduced_acc_2 = |
537 | vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2)); |
538 | int32x2_t pairwise_reduced_acc_3 = |
539 | vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3)); |
540 | const int32x2_t reduced_lo = |
541 | vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); |
542 | const int32x2_t reduced_hi = |
543 | vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); |
544 | int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); |
545 | // End of horizontal reduction: now `reduced` is a single int32x4 |
546 | // containing the 4 int32 accumulators corresponding to the 4 rows |
547 | // being processed. |
548 | |
549 | // Add bias values. |
550 | if (params.bias) { |
551 | int32x4_t bias_vec = vld1q_s32(params.bias + row); |
552 | reduced = vaddq_s32(reduced, bias_vec); |
553 | } |
554 | |
555 | // Get multiplier parameters. |
556 | int32x4_t multiplier_fixedpoint; |
557 | int32x4_t multiplier_exponent; |
558 | if (quantization_flavor == |
559 | QuantizationFlavor::kIntegerWithPerRowMultiplier) { |
560 | multiplier_exponent = |
561 | vld1q_s32(params.multiplier_exponent_perchannel + row); |
562 | multiplier_fixedpoint = |
563 | vld1q_s32(params.multiplier_fixedpoint_perchannel + row); |
564 | } else { |
565 | multiplier_exponent = vdupq_n_s32(params.multiplier_exponent); |
566 | multiplier_fixedpoint = vdupq_n_s32(params.multiplier_fixedpoint); |
567 | } |
568 | |
569 | // If positive exponent, shift left. |
570 | int32x4_t exponent_positive_part = |
571 | vmaxq_s32(multiplier_exponent, vdupq_n_s32(0)); |
572 | reduced = vshlq_s32(reduced, exponent_positive_part); |
573 | // Multiply by the fixed-point multiplier. |
574 | reduced = vqrdmulhq_s32(reduced, multiplier_fixedpoint); |
575 | // If negative exponent, rounding-shift-right. |
576 | int32x4_t exponent_negative_part = |
577 | vminq_s32(multiplier_exponent, vdupq_n_s32(0)); |
578 | reduced = vrshlq_s32(reduced, exponent_negative_part); |
579 | |
580 | // Add the output offset. |
581 | const int32x4_t output_offset_vec = vdupq_n_s32(dst_params.zero_point); |
582 | reduced = vaddq_s32(reduced, output_offset_vec); |
583 | |
584 | // Finally, clamp and store to the destination. |
585 | ClampAndStore(reduced, params.clamp_min, params.clamp_max, |
586 | dst_data + row); |
587 | } |
588 | } |
589 | }; |
590 | |
591 | // The float specialization below is unconditionally faster than ruy |
592 | // because ruy does not currently have any Gemv path. |
593 | // But it is not unconditionally faster than Eigen, which is what is used |
594 | // unless TFLITE_WITH_RUY is defined. Indeed, Eigen has decently efficient |
595 | // Gemv paths, and they may use AVX instructions, while the present |
596 | // NEON intrinsics code maps at best to SSE4 on x86. |
597 | #ifdef TFLITE_WITH_RUY |
598 | |
// We want to use fused multiply-add when it's available (that is, on A64
// unconditionally and on A32 with VFPv4) because it's often faster, and
// because non-fused multiply-add seems not to be available on A64, so a
// conscientious compiler might emit slow code (separate mul and add
// instructions) in order to implement the vmlaq_f32 intrinsic with strict
// bit-for-bit exactness on A64. (Compilers seem to be generating a fused
// fmla instruction at the moment, but that could change.)
606 | // |
607 | // We still want to support building for A32 without VFPv4. |
608 | inline float32x4_t mul_add(float32x4_t acc, float32x4_t lhs, float32x4_t rhs) { |
609 | #ifdef __ARM_FEATURE_FMA |
610 | return vfmaq_f32(acc, lhs, rhs); |
611 | #else |
612 | return vmlaq_f32(acc, lhs, rhs); |
613 | #endif |
614 | } |
615 | |
616 | template <> |
617 | struct CustomGemvImpl<float, float, float, float, |
618 | QuantizationFlavor::kFloatingPoint> { |
619 | // This implementation's inner loop processes 4 rows of the left-hand side |
620 | // matrix at a time. |
621 | static constexpr int kKernelRows = 4; |
622 | |
623 | static bool IsSupportedGivenSufficientlyManyRows( |
624 | const MatrixParams<float>& lhs_params, |
625 | const MatrixParams<float>& rhs_params, |
626 | const MatrixParams<float>& dst_params, |
627 | const GemmParams<float, float>& params) { |
628 | // The kernel processes 4 LHS columns at once to fill float32x4 registers. |
629 | // The leftovers-handling code at the end works by loading a partially |
630 | // overlapping final register by walking back by a few (<4) floats |
631 | // to avoid running past the row's end. This relies on there being |
632 | // at least 4 LHS columns. |
633 | return lhs_params.cols >= 4; |
634 | } |
635 | static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data, |
636 | const MatrixParams<float>& rhs_params, const float* rhs_data, |
637 | const MatrixParams<float>& dst_params, float* dst_data, |
638 | const GemmParams<float, float>& params, int row_start, |
639 | int row_end) { |
640 | // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each |
641 | // iteration of this for loop. |
642 | TFLITE_DCHECK_GE(row_end - row_start, kKernelRows); |
643 | for (int row = row_start; row < row_end; row += kKernelRows) { |
// Here is the trick that lets this kernel handle any number of rows
// >= kKernelRows, not just multiples of kKernelRows: the last group of
// `kKernelRows` rows is nudged to fit, possibly by starting at a value of
// `row` that overlaps the previous group.
648 | row = std::min(row, row_end - kKernelRows); |
649 | const float* filter_ptr = lhs_data + row * lhs_params.cols; |
650 | |
651 | static constexpr int kCacheLineSize = 64; |
652 | for (int k = 0; k < rhs_params.rows; |
653 | k += kCacheLineSize / sizeof(float)) { |
654 | optimized_ops_preload_l1_keep(rhs_data + k); |
655 | } |
656 | |
657 | // kPreloadAhead is empirically determined. |
658 | // End-to-end latency (ms) on mobilenet_v2_0.35_96_float, 1 thread, |
659 | // on Qualcomm S855: |
660 | // |
661 | // kPreloadAhead | big core | little core |
662 | // --------------+----------+------------ |
663 | // 64 | 2.4 | 15.2 |
664 | // 128 | 2.15 | 12.9 |
665 | // 256 | 2 | 12.9 |
666 | // 512 | 2.08 | 13.3 |
667 | // 1024 | 2.05 | 14.7 |
668 | // no prefetch | 2.1 | 28 |
669 | static constexpr int kPreloadAhead = 256; |
670 | |
671 | // 4 accumulator registers, one for each row being processed. |
// Each has 4 float32 lanes that correspond to columns modulo 4, and
// will need to be horizontally reduced at the end.
674 | float32x4_t acc0 = vdupq_n_f32(0); |
675 | float32x4_t acc1 = acc0; |
676 | float32x4_t acc2 = acc0; |
677 | float32x4_t acc3 = acc0; |
678 | int in = 0; |
679 | // As much as possible, handle 4 columns of the left-hand side matrix |
// at a time. This allows for a decent NEON implementation.
681 | for (; in <= lhs_params.cols - 4; in += 4) { |
682 | float32x4_t input_val = vld1q_f32(rhs_data + in); |
683 | const float* local_filter_ptr = filter_ptr; |
684 | float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr); |
685 | optimized_ops_preload_l1_stream(local_filter_ptr + |
686 | kPreloadAhead / sizeof(float)); |
687 | local_filter_ptr += lhs_params.cols; |
688 | float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr); |
689 | optimized_ops_preload_l1_stream(local_filter_ptr + |
690 | kPreloadAhead / sizeof(float)); |
691 | local_filter_ptr += lhs_params.cols; |
692 | float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr); |
693 | optimized_ops_preload_l1_stream(local_filter_ptr + |
694 | kPreloadAhead / sizeof(float)); |
695 | local_filter_ptr += lhs_params.cols; |
696 | float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr); |
697 | optimized_ops_preload_l1_stream(local_filter_ptr + |
698 | kPreloadAhead / sizeof(float)); |
699 | filter_ptr += 4; |
700 | acc0 = mul_add(acc0, filter_val_0, input_val); |
701 | acc1 = mul_add(acc1, filter_val_1, input_val); |
702 | acc2 = mul_add(acc2, filter_val_2, input_val); |
703 | acc3 = mul_add(acc3, filter_val_3, input_val); |
704 | } |
705 | // Less than 4 values remain. Handle the remaining values |
706 | // in one more copy of the above code handling 4, where we |
707 | // walk back a few values to be able to load 4 values without |
708 | // overrunning the buffer. This is where we make use of the requirement |
// (see IsSupportedGivenSufficientlyManyRows) that there are at least
710 | // 4 LHS columns. |
711 | if (in < lhs_params.cols) { |
712 | // `back` is how many entries to walk back by. |
713 | // Its value is necessarily between 1 and 3. |
714 | const int back = in + 4 - lhs_params.cols; |
715 | TFLITE_DCHECK_GE(back, 1); |
716 | TFLITE_DCHECK_LE(back, 3); |
717 | // Load 4 values as usual. |
718 | float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4); |
719 | const float* local_filter_ptr = filter_ptr - back; |
720 | filter_ptr += lhs_params.cols - in; |
721 | float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr); |
722 | local_filter_ptr += lhs_params.cols; |
723 | float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr); |
724 | local_filter_ptr += lhs_params.cols; |
725 | float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr); |
726 | local_filter_ptr += lhs_params.cols; |
727 | float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr); |
728 | // Now zero out the `back` first entries of input_val. |
729 | // vsetq_lane_f32 takes a literal index, so we need unrolled code. |
730 | switch (back) { |
731 | case 3: |
732 | input_val = vsetq_lane_f32(0, input_val, 2); |
733 | [[clang::fallthrough]]; |
734 | case 2: |
735 | input_val = vsetq_lane_f32(0, input_val, 1); |
736 | [[clang::fallthrough]]; |
737 | default: |
738 | input_val = vsetq_lane_f32(0, input_val, 0); |
739 | } |
740 | // Multiply-accumulate 4 values as usual. The `back` first lanes |
741 | // of filter_val_* are junk, but it doesn't matter since they get |
742 | // multiplied by the zeros that we just wrote in the corresponding |
743 | // lanes of input_val. |
744 | acc0 = mul_add(acc0, filter_val_0, input_val); |
745 | acc1 = mul_add(acc1, filter_val_1, input_val); |
746 | acc2 = mul_add(acc2, filter_val_2, input_val); |
747 | acc3 = mul_add(acc3, filter_val_3, input_val); |
748 | } |
749 | |
750 | // Horizontally reduce accumulators |
751 | float32x2_t pairwise_reduced_acc_0 = |
752 | vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)); |
753 | float32x2_t pairwise_reduced_acc_1 = |
754 | vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)); |
755 | float32x2_t pairwise_reduced_acc_2 = |
756 | vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2)); |
757 | float32x2_t pairwise_reduced_acc_3 = |
758 | vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3)); |
759 | float32x2_t reduced_lo = |
760 | vpadd_f32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); |
761 | float32x2_t reduced_hi = |
762 | vpadd_f32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); |
763 | float32x4_t reduced = vcombine_f32(reduced_lo, reduced_hi); |
764 | // End of horizontal reduction: now `reduced` is a single float32x4 |
765 | // containing the 4 float32 accumulators corresponding to the 4 rows |
766 | // being processed. |
767 | |
768 | if (params.bias) { |
769 | // Add bias values. |
770 | reduced = vaddq_f32(reduced, vld1q_f32(params.bias + row)); |
771 | } |
772 | |
773 | // Clamp and store to destination. |
774 | reduced = vminq_f32(reduced, vdupq_n_f32(params.clamp_max)); |
775 | reduced = vmaxq_f32(reduced, vdupq_n_f32(params.clamp_min)); |
776 | vst1q_f32(dst_data + row, reduced); |
777 | } |
778 | } |
779 | }; |
780 | |
781 | #endif // TFLITE_WITH_RUY |
782 | |
783 | #endif // USE_NEON |
784 | |
785 | } // namespace detail |
786 | } // namespace cpu_backend_gemm |
787 | } // namespace tflite |
788 | |
789 | #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_ |
790 | |