1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_
17#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_
18
// If TFLITE_WITH_RUY is set, Ruy is the only GEMM option. Otherwise, this
// header selects either Ruy or an alternative based on the SIMD extensions
// available on the given x86 platform.
22#ifndef TFLITE_WITH_RUY
23
24#include "tensorflow/lite/kernels/cpu_backend_context.h"
25#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h"
26#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h"
27#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
28#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h"
29#include "tensorflow/lite/kernels/internal/compatibility.h"
30
31namespace tflite {
32namespace cpu_backend_gemm {
33namespace detail {
34
35template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
36 typename DstScalar, QuantizationFlavor quantization_flavor>
37struct GemmImplX86 {
38 static void Run(
39 const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
40 const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
41 const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
42 const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
43 CpuBackendContext* context) {
44 // TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated
45 // path is enabled.
46 if (context->PreferGemmlowpOnX86()) {
47 // Dispatch to gemmlowp.
48 detail::GemmImplUsingGemmlowp<
49 LhsScalar, RhsScalar, AccumScalar, DstScalar,
50 quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
51 dst_params, dst_data, params, context);
52
53 return;
54 }
55 // Run-time dispatch to Ruy for platforms with AVX or above.
56 detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
57 quantization_flavor>::Run(lhs_params, lhs_data,
58 rhs_params, rhs_data,
59 dst_params, dst_data,
60 params, context);
61 }
62};
63
64// For float, defer to eigen for now.
65template <>
66struct GemmImplX86<float, float, float, float,
67 QuantizationFlavor::kFloatingPoint> {
68 static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
69 const MatrixParams<float>& rhs_params, const float* rhs_data,
70 const MatrixParams<float>& dst_params, float* dst_data,
71 const GemmParams<float, float,
72 QuantizationFlavor::kFloatingPoint>& params,
73 CpuBackendContext* context) {
74 GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data,
75 dst_params, dst_data, params, context);
76 }
77};
78
79// gemmlowp requires NEON for certain quantization cases. See note in
80// cpu_backend_gemm.h
81#if !defined(GEMMLOWP_NEON)
82template <typename SrcScalar, QuantizationFlavor quantization_flavor>
83struct GemmImplX86<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
84 quantization_flavor>
85 : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
86 quantization_flavor> {};
87
88template <typename DstScalar, QuantizationFlavor quantization_flavor>
89struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, DstScalar,
90 quantization_flavor>
91 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
92 DstScalar, quantization_flavor> {};
93
94template <QuantizationFlavor quantization_flavor>
95struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
96 quantization_flavor>
97 : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
98 std::int8_t, quantization_flavor> {};
99#endif // not GEMMLOWP_NEON
100} // namespace detail
101} // namespace cpu_backend_gemm
102} // namespace tflite
103
104#endif // not TFLITE_WITH_RUY
105
106#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_
107