1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ |
17 | #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ |
18 | |
19 | #include <tuple> |
20 | |
21 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
22 | #ifndef TFLITE_WITH_RUY |
23 | |
24 | #include <cstdint> |
25 | #include <type_traits> |
26 | |
27 | #include "public/gemmlowp.h" |
28 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
29 | #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" |
30 | #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" |
31 | |
32 | namespace tflite { |
33 | namespace cpu_backend_gemm { |
34 | namespace detail { |
35 | |
// Maps a destination scalar type to the gemmlowp output stage that saturates
// and casts the int32 accumulator down to that type. Only specialized for the
// types gemmlowp supports; other instantiations have no `Type` member and
// fail at compile time.
template <typename DstScalar>
struct GemmlowpSaturatingCastStage {};
38 | |
// uint8 destinations use gemmlowp's saturating cast to uint8.
template <>
struct GemmlowpSaturatingCastStage<std::uint8_t> {
  using Type = gemmlowp::OutputStageSaturatingCastToUint8;
};
43 | |
// int8 destinations use gemmlowp's saturating cast to int8.
template <>
struct GemmlowpSaturatingCastStage<std::int8_t> {
  using Type = gemmlowp::OutputStageSaturatingCastToInt8;
};
48 | |
// int16 destinations use gemmlowp's saturating cast to int16.
template <>
struct GemmlowpSaturatingCastStage<std::int16_t> {
  using Type = gemmlowp::OutputStageSaturatingCastToInt16;
};
53 | |
// Maps a source scalar type to the gemmlowp BitDepthParams describing the
// operand bit depth / signedness. Only specialized for uint8 and int8
// sources; other instantiations have no `Type` member.
template <typename DstScalar>
struct GemmlowpBitDepthParams {};
56 | |
// Unsigned 8-bit operands, with the LHS guaranteed nonzero (a gemmlowp
// fast-path precondition).
template <>
struct GemmlowpBitDepthParams<std::uint8_t> {
  using Type = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
};
61 | |
// Signed 8-bit operands, with the LHS guaranteed nonzero.
template <>
struct GemmlowpBitDepthParams<std::int8_t> {
  using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
};
66 | |
// Primary template for the gemmlowp-backed GEMM implementation. Left empty:
// only the quantization flavors specialized below are supported via gemmlowp;
// other flavors must be handled by a different backend.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImplUsingGemmlowp {};
70 | |
71 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
72 | typename DstScalar> |
73 | struct GemmImplUsingGemmlowp< |
74 | LhsScalar, RhsScalar, AccumScalar, DstScalar, |
75 | QuantizationFlavor::kIntegerWithUniformMultiplier> { |
76 | static_assert(std::is_same<LhsScalar, RhsScalar>::value, "" ); |
77 | static_assert(std::is_same<AccumScalar, std::int32_t>::value, "" ); |
78 | using SrcScalar = LhsScalar; |
79 | |
80 | static void Run( |
81 | const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data, |
82 | const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data, |
83 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
84 | const GemmParams<std::int32_t, DstScalar, |
85 | QuantizationFlavor::kIntegerWithUniformMultiplier>& |
86 | params, |
87 | CpuBackendContext* context) { |
88 | gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor> |
89 | gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols); |
90 | gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor> |
91 | gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols); |
92 | gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst( |
93 | dst_data, dst_params.rows, dst_params.cols); |
94 | |
95 | using ColVectorMap = |
96 | gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>; |
97 | gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; |
98 | scale_stage.result_offset_after_shift = dst_params.zero_point; |
99 | scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint; |
100 | scale_stage.result_exponent = params.multiplier_exponent; |
101 | using SaturatingCastStageType = |
102 | typename GemmlowpSaturatingCastStage<DstScalar>::Type; |
103 | gemmlowp::OutputStageClamp clamp_stage; |
104 | clamp_stage.min = params.clamp_min; |
105 | clamp_stage.max = params.clamp_max; |
106 | SaturatingCastStageType saturating_cast_stage; |
107 | using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type; |
108 | if (params.bias) { |
109 | ColVectorMap bias_vector(params.bias, lhs_params.rows); |
110 | gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage; |
111 | bias_addition_stage.bias_vector = bias_vector; |
112 | auto output_pipeline = std::make_tuple( |
113 | bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage); |
114 | gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>( |
115 | context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, |
116 | &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point, |
117 | output_pipeline); |
118 | } else { |
119 | auto output_pipeline = |
120 | std::make_tuple(scale_stage, clamp_stage, saturating_cast_stage); |
121 | gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>( |
122 | context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, |
123 | &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point, |
124 | output_pipeline); |
125 | } |
126 | } |
127 | }; |
128 | |
129 | template <typename LhsScalar, typename RhsScalar, typename AccumScalar, |
130 | typename DstScalar> |
131 | struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
132 | QuantizationFlavor::kIntegerWithPerRowMultiplier> { |
133 | static_assert(std::is_same<LhsScalar, RhsScalar>::value, "" ); |
134 | static_assert(std::is_same<AccumScalar, std::int32_t>::value, "" ); |
135 | using SrcScalar = LhsScalar; |
136 | |
137 | static void Run( |
138 | const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data, |
139 | const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data, |
140 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, |
141 | const GemmParams<std::int32_t, DstScalar, |
142 | QuantizationFlavor::kIntegerWithPerRowMultiplier>& |
143 | params, |
144 | CpuBackendContext* context) { |
145 | // gemmlowp support for this per-channel path is limited to NEON. |
146 | // We fall back to ruy outside of NEON. |
147 | #ifdef GEMMLOWP_NEON |
148 | gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor> |
149 | gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols); |
150 | gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor> |
151 | gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols); |
152 | gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst( |
153 | dst_data, dst_params.rows, dst_params.cols); |
154 | |
155 | using ColVectorMap = |
156 | gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>; |
157 | ColVectorMap bias_vector(params.bias, lhs_params.rows); |
158 | gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage; |
159 | bias_addition_stage.bias_vector = bias_vector; |
160 | gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< |
161 | gemmlowp::VectorShape::Col> |
162 | scale_stage; |
163 | scale_stage.result_offset_after_shift = dst_params.zero_point; |
164 | scale_stage.result_fixedpoint_multiplier = |
165 | ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows); |
166 | scale_stage.result_exponent = |
167 | ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows); |
168 | using SaturatingCastStageType = |
169 | typename GemmlowpSaturatingCastStage<DstScalar>::Type; |
170 | gemmlowp::OutputStageClamp clamp_stage; |
171 | clamp_stage.min = params.clamp_min; |
172 | clamp_stage.max = params.clamp_max; |
173 | SaturatingCastStageType saturating_cast_stage; |
174 | auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, |
175 | clamp_stage, saturating_cast_stage); |
176 | using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type; |
177 | gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>( |
178 | context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, |
179 | -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline); |
180 | #else |
181 | GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar, |
182 | QuantizationFlavor::kIntegerWithPerRowMultiplier>:: |
183 | Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, |
184 | params, context); |
185 | #endif |
186 | } |
187 | }; |
188 | |
189 | } // namespace detail |
190 | } // namespace cpu_backend_gemm |
191 | } // namespace tflite |
192 | |
193 | #endif // not TFLITE_WITH_RUY |
194 | |
195 | #endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_GEMMLOWP_H_ |
196 | |