/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Fast Gemv (i.e. matrix*vector multiplication) paths.
// TODO(b/132094390): remove when GEMM performance is good enough on GEMV cases.

// TFLite's runtime ops funnel matrix*vector use cases, as much as possible,
// into the (matrix) * (column-vector) case, as opposed to
// (row-vector) * (matrix); that is therefore what we focus on optimizing here.
// Accordingly, the public cpu_backend_gemm::Gemm() entry point checks
// whether we are in this (matrix) * (column-vector) case, and if so calls
// CustomGemv.
//
// cpu_backend_gemm::Gemm is also currently restricted (as enforced in
// ValidateParams) to the case where the left-hand side matrix is row-major.
//
// So the current scope of this CustomGemv function really is:
// (row-major matrix) * (column-vector).
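//
// In other words, ignoring quantization details, CustomGemv computes the
// following scalar reference (a sketch using the parameter naming found
// below, not actual library code):
//
//   for (int r = 0; r < lhs_params.rows; ++r) {
//     AccumScalar acc = 0;
//     for (int c = 0; c < lhs_params.cols; ++c) {
//       acc += lhs_data[r * lhs_params.cols + c] * rhs_data[c];
//     }
//     dst_data[r] = acc;  // plus bias addition, rescaling and clamping.
//   }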

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_

#include <stdint.h>

#include <algorithm>
#include <type_traits>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"

namespace tflite {
namespace cpu_backend_gemm {
namespace detail {

// CustomGemvImpl is what needs to be specialized for each custom GEMV path.
//
// It does not deal with any multi-threaded implementation detail. Rather,
// it provides the single-thread implementation to be run by each thread.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct CustomGemvImpl {
  // The number of rows of the left-hand-side matrix (and equivalently of the
  // destination column-vector) that the kernel processes at a time.
  // This will also be the minimum required number of rows for a Gemv shape
  // to be supported by this path.
  //
  // Gemv implementations are expected to be able to deal with numbers of
  // rows that aren't multiples of kKernelRows by possibly running the kernel
  // again at an odd row_start, e.g. if kKernelRows==4, Run() should still
  // support running on 7 rows by running twice: once with row_start=0 and then
  // another time with row_start=3.
  //
  // On the other hand, gemv implementations are not expected to support
  // running on fewer than kKernelRows rows. There is no interest in
  // optimizing Gemvs so narrow that they amount to just a few dot-products;
  // supporting them would require custom kernel code for that case alone.
  static constexpr int kKernelRows = 1;

  // Returns true if the Gemv shape is supported by Run(), provided that
  // (row_end - row_start) >= kKernelRows.
  static bool IsSupportedGivenSufficientlyManyRows(
      const MatrixParams<LhsScalar>& lhs_params,
      const MatrixParams<RhsScalar>& rhs_params,
      const MatrixParams<DstScalar>& dst_params,
      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
    return false;
  }

  // Performs the Gemv.
  static void Run(
      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
      int row_start, int row_end) {}
};

// Wraps CustomGemvImpl for multi-threaded operation.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
class CustomGemvTask : public cpu_backend_threadpool::Task {
 public:
  CustomGemvTask(
      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
      int row_start, int row_end)
      : lhs_params_(lhs_params),
        lhs_data_(lhs_data),
        rhs_params_(rhs_params),
        rhs_data_(rhs_data),
        dst_params_(dst_params),
        dst_data_(dst_data),
        params_(params),
        row_start_(row_start),
        row_end_(row_end) {}

  void Run() override {
    using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                                quantization_flavor>;
    Impl::Run(lhs_params_, lhs_data_, rhs_params_, rhs_data_, dst_params_,
              dst_data_, params_, row_start_, row_end_);
  }

 private:
  const MatrixParams<LhsScalar>& lhs_params_;
  const LhsScalar* lhs_data_;
  const MatrixParams<RhsScalar>& rhs_params_;
  const RhsScalar* rhs_data_;
  const MatrixParams<DstScalar>& dst_params_;
  DstScalar* dst_data_;
  const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params_;
  int row_start_;
  int row_end_;
};

// Either performs the requested Gemv operation and returns true,
// or immediately returns false.
//
// See the comment at the top of the file for the scope of what this handles.
// In summary: (row-major matrix) * (column-vector).
//
// This function contains only the high-level logic; the actual implementation
// details are in specializations of CustomGemvImpl.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
bool CustomGemv(
    const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
    const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
    const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
    CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm: CustomGemv");
  using Impl = CustomGemvImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                              quantization_flavor>;
  if (lhs_params.rows < Impl::kKernelRows) {
    return false;
  }
  if (!Impl::IsSupportedGivenSufficientlyManyRows(lhs_params, rhs_params,
                                                  dst_params, params)) {
    return false;
  }
  TFLITE_DCHECK_GE(lhs_params.rows, Impl::kKernelRows);
  int thread_count = LegacyHowManyThreads<Impl::kKernelRows>(
      context->max_num_threads(), dst_params.rows, dst_params.cols,
      lhs_params.cols);
  if (thread_count == 1) {
    Impl::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
              params, 0, lhs_params.rows);
  } else {
    using Task = CustomGemvTask<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                                quantization_flavor>;
    std::vector<Task> tasks;
    tasks.reserve(thread_count);
    const int kRowsPerThread =
        RoundUp<Impl::kKernelRows>(CeilQuotient(dst_params.rows, thread_count));
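    // For example (illustrative numbers only): with dst_params.rows == 100,
    // thread_count == 3 and Impl::kKernelRows == 4, CeilQuotient(100, 3) == 34
    // rounds up to kRowsPerThread == 36, so the threads get the row ranges
    // [0, 36), [36, 72) and [72, 100).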
    int row_start = 0;
    for (int i = 0; i < thread_count; i++) {
      int row_end = std::min(dst_params.rows, row_start + kRowsPerThread);
      tasks.emplace_back(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
                         dst_data, params, row_start, row_end);
      row_start = row_end;
    }
    cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), context);
  }
  return true;
}
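
// Sketch of the expected call site (see cpu_backend_gemm.h for the actual
// dispatch logic): the Gemm() entry point is assumed to try CustomGemv first
// on GEMV shapes and to fall back to the general GEMM path when it returns
// false. Roughly:
//
//   if (dst_params.cols == 1) {
//     if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
//                            dst_params, dst_data, params, context)) {
//       return;
//     }
//   }
//   // ... otherwise, continue with the general GEMM implementation.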

// USE_NEON still allows for x86 where we may be using the arm_neon_sse.h
// wrapper implementing NEON intrinsics on top of SSE4 intrinsics.
#ifdef USE_NEON

// Some NEON helper functions used by CustomGemvImpl specializations below,
// allowing for some type genericity in them.

inline int16x8x2_t Load16AndSubtractZeroPoint(const std::uint8_t* src,
                                              std::uint8_t zero_point) {
  uint8x16_t src_u8 = vld1q_u8(src);
  int16x8_t src_s16_0 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(src_u8)));
  int16x8_t src_s16_1 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(src_u8)));
  int16x8x2_t result;
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
  result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
  return result;
}

inline int16x8x2_t Load16AndSubtractZeroPoint(const std::int8_t* src,
                                              std::int8_t zero_point) {
  int8x16_t src_s8 = vld1q_s8(src);
  int16x8_t src_s16_0 = vmovl_s8(vget_low_s8(src_s8));
  int16x8_t src_s16_1 = vmovl_s8(vget_high_s8(src_s8));
  int16x8x2_t result;
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  result.val[0] = vsubq_s16(src_s16_0, zero_point_vec);
  result.val[1] = vsubq_s16(src_s16_1, zero_point_vec);
  return result;
}

inline int16x8_t Load8AndSubtractZeroPoint(const std::uint8_t* src,
                                           std::uint8_t zero_point) {
  uint8x8_t src_u8 = vld1_u8(src);
  int16x8_t src_s16 = vreinterpretq_s16_u16(vmovl_u8(src_u8));
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  return vsubq_s16(src_s16, zero_point_vec);
}

inline int16x8_t Load8AndSubtractZeroPoint(const std::int8_t* src,
                                           std::int8_t zero_point) {
  int8x8_t src_s8 = vld1_s8(src);
  int16x8_t src_s16 = vmovl_s8(src_s8);
  int16x8_t zero_point_vec = vdupq_n_s16(zero_point);
  return vsubq_s16(src_s16, zero_point_vec);
}
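
// For example (uint8 overload): with zero_point == 128, an input byte of 200
// maps to the int16 value 72, and an input byte of 0 maps to -128. Widening
// to int16 before subtracting keeps the result exact over the full uint8 and
// int8 input ranges.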

inline void ClampAndStore(int32x4_t src, std::uint8_t clamp_min,
                          std::uint8_t clamp_max, std::uint8_t* dst) {
  // Narrow values down to 16 bit signed.
  const int16x4_t res16 = vqmovn_s32(src);
  // Narrow values down to 8 bit unsigned, saturating.
  uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
  // Apply the clamping from the activation function
  res8 = vmax_u8(res8, vdup_n_u8(clamp_min));
  res8 = vmin_u8(res8, vdup_n_u8(clamp_max));
  // Store results to destination.
  vst1_lane_u8(dst + 0, res8, 0);
  vst1_lane_u8(dst + 1, res8, 1);
  vst1_lane_u8(dst + 2, res8, 2);
  vst1_lane_u8(dst + 3, res8, 3);
}
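
// For example, an input vector of {131072, -5, 260, 42} with clamp_min == 0
// and clamp_max == 255 is stored as the bytes {255, 0, 255, 42}: the two
// saturating narrowing steps already pin out-of-range values to the uint8
// range before the activation clamp is applied.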

inline void ClampAndStore(int32x4_t src, std::int8_t clamp_min,
                          std::int8_t clamp_max, std::int8_t* dst) {
  // Narrow values down to 16 bit signed.
  const int16x4_t res16 = vqmovn_s32(src);
  // Narrow values down to 8 bit signed, saturating.
  int8x8_t res8 = vqmovn_s16(vcombine_s16(res16, res16));
  // Apply the clamping from the activation function
  res8 = vmax_s8(res8, vdup_n_s8(clamp_min));
  res8 = vmin_s8(res8, vdup_n_s8(clamp_max));
  // Store results to destination.
  vst1_lane_s8(dst + 0, res8, 0);
  vst1_lane_s8(dst + 1, res8, 1);
  vst1_lane_s8(dst + 2, res8, 2);
  vst1_lane_s8(dst + 3, res8, 3);
}

inline void ClampAndStore(int32x4_t src, std::int16_t clamp_min,
                          std::int16_t clamp_max, std::int16_t* dst) {
  // Narrow values down to 16 bit signed.
  int16x4_t res16 = vqmovn_s32(src);
  // Apply the clamping from the activation function
  res16 = vmax_s16(res16, vdup_n_s16(clamp_min));
  res16 = vmin_s16(res16, vdup_n_s16(clamp_max));
  // Store results to destination.
  vst1_lane_s16(dst + 0, res16, 0);
  vst1_lane_s16(dst + 1, res16, 1);
  vst1_lane_s16(dst + 2, res16, 2);
  vst1_lane_s16(dst + 3, res16, 3);
}

template <typename LhsScalar, typename RhsScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
struct CustomGemvImpl<LhsScalar, RhsScalar, std::int32_t, DstScalar,
                      quantization_flavor> {
  // This partial template specialization is less generic than its declaration
  // implies: it assumes the following constraints on its free template
  // parameters. We guard these assumptions in the following static_assert's.
  static_assert(std::is_same<LhsScalar, std::uint8_t>::value ||
                    std::is_same<LhsScalar, std::int8_t>::value,
                "");
  static_assert(std::is_same<RhsScalar, std::uint8_t>::value ||
                    std::is_same<RhsScalar, std::int8_t>::value,
                "");
  static_assert(std::is_same<DstScalar, std::uint8_t>::value ||
                    std::is_same<DstScalar, std::int8_t>::value ||
                    std::is_same<DstScalar, std::int16_t>::value,
                "");
  static_assert(quantization_flavor ==
                        QuantizationFlavor::kIntegerWithUniformMultiplier ||
                    quantization_flavor ==
                        QuantizationFlavor::kIntegerWithPerRowMultiplier,
                "");

  // This implementation's inner loop processes 4 rows of the left-hand side
  // matrix at a time.
  static constexpr int kKernelRows = 4;

  static bool IsSupportedGivenSufficientlyManyRows(
      const MatrixParams<LhsScalar>& lhs_params,
      const MatrixParams<RhsScalar>& rhs_params,
      const MatrixParams<DstScalar>& dst_params,
      const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params) {
    // The kernel processes at least 8 LHS columns at once to fill NEON
    // registers. The leftovers-handling code at the end works by loading a
    // partially overlapping final register by walking back by a few (<8)
    // values to avoid running past the row's end. This relies on there being
    // at least 8 LHS columns.
    return lhs_params.cols >= 8;
  }

  static void Run(
      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
      const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
      int row_start, int row_end) {
    // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each
    // iteration of this for loop.
    TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
    for (int row = row_start; row < row_end; row += kKernelRows) {
      // Here is the magic where we allow this kernel to handle any number of
      // rows, not necessarily a multiple of kKernelRows, as long as it is
      // >= kKernelRows: the last group of kKernelRows rows is nudged to fit,
      // possibly by starting at an unaligned value of `row`.
      row = std::min(row, row_end - kKernelRows);
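      // For example, with row_start == 0 and row_end == 7, the loop processes
      // rows [0, 4) and then, with `row` nudged from 4 down to 3, rows [3, 7);
      // row 3 is simply recomputed with the same result.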
      const LhsScalar* filter_ptr = lhs_data + row * lhs_params.cols;

      static constexpr int kCacheLineSize = 64;
      for (int k = 0; k < rhs_params.rows;
           k += kCacheLineSize / sizeof(RhsScalar)) {
        optimized_ops_preload_l1_keep(rhs_data + k);
      }

      // kPreloadAhead is empirically determined.
      // End-to-end latency (ms) on mobilenet_v2_0.35_96_8bit, 1 thread,
      // on Qualcomm S855:
      //
      // kPreloadAhead | big core | little core
      // --------------+----------+------------
      // 64            | 1.26     | 5.45
      // 128           | 1.23     | 5.01
      // 256           | 1.18     | 4.9
      // 512           | 1.18     | 5.45
      // 1024          | 1.18     | 6.5
      // no prefetch   | 1.25     | 8.1
      static constexpr int kPreloadAhead = 256;

      // 4 accumulator registers, one for each row being processed.
      // Each has 4 int32 lanes that correspond to columns modulo 4 and will
      // need to be horizontally reduced at the end.
      int32x4_t acc0 = vdupq_n_s32(0);
      int32x4_t acc1 = acc0;
      int32x4_t acc2 = acc0;
      int32x4_t acc3 = acc0;
      int in = 0;
      // As much as possible, handle 16 columns of the left-hand side matrix
      // at a time. This allows for decent NEON implementation.
      for (; in <= lhs_params.cols - 16; in += 16) {
        const LhsScalar* local_filter_ptr = filter_ptr;
        int16x8x2_t input_val =
            Load16AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
        int16x8x2_t filter_val_0 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_1 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_2 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        local_filter_ptr += lhs_params.cols;
        int16x8x2_t filter_val_3 =
            Load16AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(LhsScalar));
        filter_ptr += 16;
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[0]),
                         vget_low_s16(input_val.val[0]));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[0]),
                         vget_low_s16(input_val.val[0]));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[0]),
                         vget_low_s16(input_val.val[0]));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[0]),
                         vget_low_s16(input_val.val[0]));
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0.val[1]),
                         vget_low_s16(input_val.val[1]));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1.val[1]),
                         vget_low_s16(input_val.val[1]));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2.val[1]),
                         vget_low_s16(input_val.val[1]));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3.val[1]),
                         vget_low_s16(input_val.val[1]));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[0]),
                         vget_high_s16(input_val.val[0]));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[0]),
                         vget_high_s16(input_val.val[0]));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[0]),
                         vget_high_s16(input_val.val[0]));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[0]),
                         vget_high_s16(input_val.val[0]));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0.val[1]),
                         vget_high_s16(input_val.val[1]));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1.val[1]),
                         vget_high_s16(input_val.val[1]));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2.val[1]),
                         vget_high_s16(input_val.val[1]));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3.val[1]),
                         vget_high_s16(input_val.val[1]));
      }
      // Less than 16 values remain. Try to handle 8 more.
      if (in <= lhs_params.cols - 8) {
        int16x8_t input_val =
            Load8AndSubtractZeroPoint(rhs_data + in, rhs_params.zero_point);
        int16x8_t filter_val_0 = Load8AndSubtractZeroPoint(
            filter_ptr + 0 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_1 = Load8AndSubtractZeroPoint(
            filter_ptr + 1 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_2 = Load8AndSubtractZeroPoint(
            filter_ptr + 2 * lhs_params.cols, lhs_params.zero_point);
        int16x8_t filter_val_3 = Load8AndSubtractZeroPoint(
            filter_ptr + 3 * lhs_params.cols, lhs_params.zero_point);
        filter_ptr += 8;
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
                         vget_low_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
                         vget_low_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
                         vget_low_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
                         vget_low_s16(input_val));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                         vget_high_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                         vget_high_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                         vget_high_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                         vget_high_s16(input_val));
        in += 8;
      }
      // Less than 8 values remain. Handle the remaining values
      // in one more copy of the above code handling 8, where we
      // walk back a few values to be able to load 8 values without
      // overrunning the buffer. This is where we make use of the requirement
      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
      // 8 LHS columns.
      if (in < lhs_params.cols) {
        // `back` is how many entries to walk back by.
        // Its value is necessarily between 1 and 7.
        const int back = in + 8 - lhs_params.cols;
        TFLITE_DCHECK_GE(back, 1);
        TFLITE_DCHECK_LE(back, 7);
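        // For example, with lhs_params.cols == 13, the 8-wide step above has
        // already handled columns 0..7, so in == 8 here and back == 3: we
        // reload starting at column 5 and zero out the 3 lanes that were
        // already accumulated.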
        // Load 8 values as usual.
        int16x8_t input_val = Load8AndSubtractZeroPoint(
            rhs_data + lhs_params.cols - 8, rhs_params.zero_point);
        const LhsScalar* local_filter_ptr = filter_ptr - back;
        filter_ptr += lhs_params.cols - in;
        int16x8_t filter_val_0 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_1 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_2 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        local_filter_ptr += lhs_params.cols;
        int16x8_t filter_val_3 =
            Load8AndSubtractZeroPoint(local_filter_ptr, lhs_params.zero_point);
        // Now zero out the `back` first entries of input_val.
        // vsetq_lane_s16 takes a literal index, so we need unrolled code.
        switch (back) {
          case 7:
            input_val = vsetq_lane_s16(0, input_val, 6);
            [[clang::fallthrough]];
          case 6:
            input_val = vsetq_lane_s16(0, input_val, 5);
            [[clang::fallthrough]];
          case 5:
            input_val = vsetq_lane_s16(0, input_val, 4);
            [[clang::fallthrough]];
          case 4:
            input_val = vsetq_lane_s16(0, input_val, 3);
            [[clang::fallthrough]];
          case 3:
            input_val = vsetq_lane_s16(0, input_val, 2);
            [[clang::fallthrough]];
          case 2:
            input_val = vsetq_lane_s16(0, input_val, 1);
            [[clang::fallthrough]];
          default:
            input_val = vsetq_lane_s16(0, input_val, 0);
        }
        // Multiply-accumulate 8 values as usual. The `back` first lanes
        // of filter_val_* are junk, but it doesn't matter since they get
        // multiplied by the zeros that we just wrote in the corresponding
        // lanes of input_val.
        acc0 = vmlal_s16(acc0, vget_low_s16(filter_val_0),
                         vget_low_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_low_s16(filter_val_1),
                         vget_low_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_low_s16(filter_val_2),
                         vget_low_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_low_s16(filter_val_3),
                         vget_low_s16(input_val));
        acc0 = vmlal_s16(acc0, vget_high_s16(filter_val_0),
                         vget_high_s16(input_val));
        acc1 = vmlal_s16(acc1, vget_high_s16(filter_val_1),
                         vget_high_s16(input_val));
        acc2 = vmlal_s16(acc2, vget_high_s16(filter_val_2),
                         vget_high_s16(input_val));
        acc3 = vmlal_s16(acc3, vget_high_s16(filter_val_3),
                         vget_high_s16(input_val));
      }

      // Horizontally reduce accumulators
      int32x2_t pairwise_reduced_acc_0 =
          vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
      int32x2_t pairwise_reduced_acc_1 =
          vpadd_s32(vget_low_s32(acc1), vget_high_s32(acc1));
      int32x2_t pairwise_reduced_acc_2 =
          vpadd_s32(vget_low_s32(acc2), vget_high_s32(acc2));
      int32x2_t pairwise_reduced_acc_3 =
          vpadd_s32(vget_low_s32(acc3), vget_high_s32(acc3));
      const int32x2_t reduced_lo =
          vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
      const int32x2_t reduced_hi =
          vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
      int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
      // End of horizontal reduction: now `reduced` is a single int32x4
      // containing the 4 int32 accumulators corresponding to the 4 rows
      // being processed.

      // Add bias values.
      if (params.bias) {
        int32x4_t bias_vec = vld1q_s32(params.bias + row);
        reduced = vaddq_s32(reduced, bias_vec);
      }

      // Get multiplier parameters.
      int32x4_t multiplier_fixedpoint;
      int32x4_t multiplier_exponent;
      if (quantization_flavor ==
          QuantizationFlavor::kIntegerWithPerRowMultiplier) {
        multiplier_exponent =
            vld1q_s32(params.multiplier_exponent_perchannel + row);
        multiplier_fixedpoint =
            vld1q_s32(params.multiplier_fixedpoint_perchannel + row);
      } else {
        multiplier_exponent = vdupq_n_s32(params.multiplier_exponent);
        multiplier_fixedpoint = vdupq_n_s32(params.multiplier_fixedpoint);
      }

      // If positive exponent, shift left.
      int32x4_t exponent_positive_part =
          vmaxq_s32(multiplier_exponent, vdupq_n_s32(0));
      reduced = vshlq_s32(reduced, exponent_positive_part);
      // Multiply by the fixed-point multiplier.
      reduced = vqrdmulhq_s32(reduced, multiplier_fixedpoint);
      // If negative exponent, rounding-shift-right.
      int32x4_t exponent_negative_part =
          vminq_s32(multiplier_exponent, vdupq_n_s32(0));
      reduced = vrshlq_s32(reduced, exponent_negative_part);
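      // The three steps above apply the effective output scale, which is
      // assumed to be represented as
      // multiplier_fixedpoint * 2^(-31) * 2^multiplier_exponent.
      // For example, a scale of 0.0015 would be encoded as
      // multiplier_exponent == -9 and multiplier_fixedpoint ==
      // round(0.768 * 2^31), since 0.0015 == 0.768 * 2^-9.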

      // Add the output offset.
      const int32x4_t output_offset_vec = vdupq_n_s32(dst_params.zero_point);
      reduced = vaddq_s32(reduced, output_offset_vec);

      // Finally, clamp and store to the destination.
      ClampAndStore(reduced, params.clamp_min, params.clamp_max,
                    dst_data + row);
    }
  }
};

// The float specialization below is unconditionally faster than ruy
// because ruy does not currently have any Gemv path.
// But it is not unconditionally faster than Eigen, which is what is used
// unless TFLITE_WITH_RUY is defined. Indeed, Eigen has decently efficient
// Gemv paths, and they may use AVX instructions, while the present
// NEON intrinsics code maps at best to SSE4 on x86.
#ifdef TFLITE_WITH_RUY

// We want to use fused multiply-add when it's available (that is, on A64
// unconditionally and on A32 with VFPv4) because it's often faster, and
// because non-fused multiply-add seems not to be available on A64, so a
// conscientious compiler might emit slow code (separate mul and add
// instructions) in order to implement the vmlaq_f32 intrinsic with strict
// bit-for-bit exactness on A64. (Compilers seem to be generating a fused
// fmla instruction at the moment, but that could change.)
//
// We still want to support building for A32 without VFPv4.
inline float32x4_t mul_add(float32x4_t acc, float32x4_t lhs, float32x4_t rhs) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(acc, lhs, rhs);
#else
  return vmlaq_f32(acc, lhs, rhs);
#endif
}

template <>
struct CustomGemvImpl<float, float, float, float,
                      QuantizationFlavor::kFloatingPoint> {
  // This implementation's inner loop processes 4 rows of the left-hand side
  // matrix at a time.
  static constexpr int kKernelRows = 4;

  static bool IsSupportedGivenSufficientlyManyRows(
      const MatrixParams<float>& lhs_params,
      const MatrixParams<float>& rhs_params,
      const MatrixParams<float>& dst_params,
      const GemmParams<float, float>& params) {
    // The kernel processes 4 LHS columns at once to fill float32x4 registers.
    // The leftovers-handling code at the end works by loading a partially
    // overlapping final register by walking back by a few (<4) floats
    // to avoid running past the row's end. This relies on there being
    // at least 4 LHS columns.
    return lhs_params.cols >= 4;
  }
  static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                  const MatrixParams<float>& rhs_params, const float* rhs_data,
                  const MatrixParams<float>& dst_params, float* dst_data,
                  const GemmParams<float, float>& params, int row_start,
                  int row_end) {
    // Handle kKernelRows ( == 4) rows of the left-hand side matrix at each
    // iteration of this for loop.
    TFLITE_DCHECK_GE(row_end - row_start, kKernelRows);
    for (int row = row_start; row < row_end; row += kKernelRows) {
      // Here is the magic where we allow this kernel to handle any number of
      // rows, not necessarily a multiple of kKernelRows, as long as it is
      // >= kKernelRows: the last group of kKernelRows rows is nudged to fit,
      // possibly by starting at an unaligned value of `row`.
      row = std::min(row, row_end - kKernelRows);
      const float* filter_ptr = lhs_data + row * lhs_params.cols;

      static constexpr int kCacheLineSize = 64;
      for (int k = 0; k < rhs_params.rows;
           k += kCacheLineSize / sizeof(float)) {
        optimized_ops_preload_l1_keep(rhs_data + k);
      }

      // kPreloadAhead is empirically determined.
      // End-to-end latency (ms) on mobilenet_v2_0.35_96_float, 1 thread,
      // on Qualcomm S855:
      //
      // kPreloadAhead | big core | little core
      // --------------+----------+------------
      // 64            | 2.4      | 15.2
      // 128           | 2.15     | 12.9
      // 256           | 2        | 12.9
      // 512           | 2.08     | 13.3
      // 1024          | 2.05     | 14.7
      // no prefetch   | 2.1      | 28
      static constexpr int kPreloadAhead = 256;

      // 4 accumulator registers, one for each row being processed.
      // Each has 4 float32 lanes that correspond to columns modulo 4 and will
      // need to be horizontally reduced at the end.
      float32x4_t acc0 = vdupq_n_f32(0);
      float32x4_t acc1 = acc0;
      float32x4_t acc2 = acc0;
      float32x4_t acc3 = acc0;
      int in = 0;
      // As much as possible, handle 4 columns of the left-hand side matrix
      // at a time. This allows for decent NEON implementation.
      for (; in <= lhs_params.cols - 4; in += 4) {
        float32x4_t input_val = vld1q_f32(rhs_data + in);
        const float* local_filter_ptr = filter_ptr;
        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
        optimized_ops_preload_l1_stream(local_filter_ptr +
                                        kPreloadAhead / sizeof(float));
        filter_ptr += 4;
        acc0 = mul_add(acc0, filter_val_0, input_val);
        acc1 = mul_add(acc1, filter_val_1, input_val);
        acc2 = mul_add(acc2, filter_val_2, input_val);
        acc3 = mul_add(acc3, filter_val_3, input_val);
      }
      // Less than 4 values remain. Handle the remaining values
      // in one more copy of the above code handling 4, where we
      // walk back a few values to be able to load 4 values without
      // overrunning the buffer. This is where we make use of the requirement
      // (see IsSupportedGivenSufficientlyManyRows) that there are at least
      // 4 LHS columns.
      if (in < lhs_params.cols) {
        // `back` is how many entries to walk back by.
        // Its value is necessarily between 1 and 3.
        const int back = in + 4 - lhs_params.cols;
        TFLITE_DCHECK_GE(back, 1);
        TFLITE_DCHECK_LE(back, 3);
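        // For example, with lhs_params.cols == 10, the 4-wide loop above has
        // already handled columns 0..7, so in == 8 here and back == 2: we
        // reload starting at column 6 and zero out the 2 lanes that were
        // already accumulated.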
        // Load 4 values as usual.
        float32x4_t input_val = vld1q_f32(rhs_data + lhs_params.cols - 4);
        const float* local_filter_ptr = filter_ptr - back;
        filter_ptr += lhs_params.cols - in;
        float32x4_t filter_val_0 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_1 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_2 = vld1q_f32(local_filter_ptr);
        local_filter_ptr += lhs_params.cols;
        float32x4_t filter_val_3 = vld1q_f32(local_filter_ptr);
        // Now zero out the `back` first entries of input_val.
        // vsetq_lane_f32 takes a literal index, so we need unrolled code.
        switch (back) {
          case 3:
            input_val = vsetq_lane_f32(0, input_val, 2);
            [[clang::fallthrough]];
          case 2:
            input_val = vsetq_lane_f32(0, input_val, 1);
            [[clang::fallthrough]];
          default:
            input_val = vsetq_lane_f32(0, input_val, 0);
        }
        // Multiply-accumulate 4 values as usual. The `back` first lanes
        // of filter_val_* are junk, but it doesn't matter since they get
        // multiplied by the zeros that we just wrote in the corresponding
        // lanes of input_val.
        acc0 = mul_add(acc0, filter_val_0, input_val);
        acc1 = mul_add(acc1, filter_val_1, input_val);
        acc2 = mul_add(acc2, filter_val_2, input_val);
        acc3 = mul_add(acc3, filter_val_3, input_val);
      }

      // Horizontally reduce accumulators
      float32x2_t pairwise_reduced_acc_0 =
          vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
      float32x2_t pairwise_reduced_acc_1 =
          vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
      float32x2_t pairwise_reduced_acc_2 =
          vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
      float32x2_t pairwise_reduced_acc_3 =
          vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
      float32x2_t reduced_lo =
          vpadd_f32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
      float32x2_t reduced_hi =
          vpadd_f32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
      float32x4_t reduced = vcombine_f32(reduced_lo, reduced_hi);
      // End of horizontal reduction: now `reduced` is a single float32x4
      // containing the 4 float32 accumulators corresponding to the 4 rows
      // being processed.

      if (params.bias) {
        // Add bias values.
        reduced = vaddq_f32(reduced, vld1q_f32(params.bias + row));
      }

      // Clamp and store to destination.
      reduced = vminq_f32(reduced, vdupq_n_f32(params.clamp_max));
      reduced = vmaxq_f32(reduced, vdupq_n_f32(params.clamp_min));
      vst1q_f32(dst_data + row, reduced);
    }
  }
};

#endif  // TFLITE_WITH_RUY

#endif  // USE_NEON

}  // namespace detail
}  // namespace cpu_backend_gemm
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_CUSTOM_GEMV_H_