1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_ |
17 | #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_ |
18 | |
// If TFLITE_WITH_RUY is set, Ruy is the only GEMM option and this header
// contributes nothing. Otherwise, this header selects either Ruy or an
// alternative backend (Eigen for float, gemmlowp when the context requests
// it) for x86 platforms, based on the operand types.
22 | #ifndef TFLITE_WITH_RUY |
23 | |
24 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
25 | #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" |
26 | #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" |
27 | #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" |
28 | #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" |
29 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
30 | |
31 | namespace tflite { |
32 | namespace cpu_backend_gemm { |
33 | namespace detail { |
34 | |
35 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
36 | typename DstScalar, QuantizationFlavor quantization_flavor> |
37 | struct GemmImplX86 { |
38 | static void Run( |
39 | const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data, |
40 | const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data, |
41 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
42 | const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params, |
43 | CpuBackendContext* context) { |
44 | // TODO(b/168923364) Ruy is preferred on x86, but check if the deprecated |
45 | // path is enabled. |
46 | if (context->PreferGemmlowpOnX86()) { |
47 | // Dispatch to gemmlowp. |
48 | detail::GemmImplUsingGemmlowp< |
49 | LhsScalar, RhsScalar, AccumScalar, DstScalar, |
50 | quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data, |
51 | dst_params, dst_data, params, context); |
52 | |
53 | return; |
54 | } |
55 | // Run-time dispatch to Ruy for platforms with AVX or above. |
56 | detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
57 | quantization_flavor>::Run(lhs_params, lhs_data, |
58 | rhs_params, rhs_data, |
59 | dst_params, dst_data, |
60 | params, context); |
61 | } |
62 | }; |
63 | |
64 | // For float, defer to eigen for now. |
65 | template <> |
66 | struct GemmImplX86<float, float, float, float, |
67 | QuantizationFlavor::kFloatingPoint> { |
68 | static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data, |
69 | const MatrixParams<float>& rhs_params, const float* rhs_data, |
70 | const MatrixParams<float>& dst_params, float* dst_data, |
71 | const GemmParams<float, float, |
72 | QuantizationFlavor::kFloatingPoint>& params, |
73 | CpuBackendContext* context) { |
74 | GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, |
75 | dst_params, dst_data, params, context); |
76 | } |
77 | }; |
78 | |
79 | // gemmlowp requires NEON for certain quantization cases. See note in |
80 | // cpu_backend_gemm.h |
81 | #if !defined(GEMMLOWP_NEON) |
82 | template <typename SrcScalar, QuantizationFlavor quantization_flavor> |
83 | struct GemmImplX86<SrcScalar, SrcScalar, std::int32_t, std::int8_t, |
84 | quantization_flavor> |
85 | : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t, |
86 | quantization_flavor> {}; |
87 | |
88 | template <typename DstScalar, QuantizationFlavor quantization_flavor> |
89 | struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, DstScalar, |
90 | quantization_flavor> |
91 | : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t, |
92 | DstScalar, quantization_flavor> {}; |
93 | |
94 | template <QuantizationFlavor quantization_flavor> |
95 | struct GemmImplX86<std::int8_t, std::int8_t, std::int32_t, std::int8_t, |
96 | quantization_flavor> |
97 | : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t, |
98 | std::int8_t, quantization_flavor> {}; |
99 | #endif // not GEMMLOWP_NEON |
100 | } // namespace detail |
101 | } // namespace cpu_backend_gemm |
102 | } // namespace tflite |
103 | |
104 | #endif // not TFLITE_WITH_RUY |
105 | |
106 | #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_X86_H_ |
107 | |