1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_KERNEL_COMMON_H_ |
17 | #define RUY_RUY_KERNEL_COMMON_H_ |
18 | |
19 | #include <algorithm> |
20 | #include <cstdint> |
21 | #include <type_traits> |
22 | |
23 | #include "ruy/apply_multiplier.h" |
24 | #include "ruy/check_macros.h" |
25 | #include "ruy/mat.h" |
26 | #include "ruy/matrix.h" |
27 | #include "ruy/mul_params.h" |
28 | #include "ruy/opt_set.h" |
29 | #include "ruy/path.h" |
30 | #include "ruy/platform.h" |
31 | #include "ruy/profiler/instrumentation.h" |
32 | #include "ruy/side_pair.h" |
33 | #include "ruy/size_util.h" |
34 | #include "ruy/tune.h" |
35 | |
36 | namespace ruy { |
37 | |
38 | template <Path ThePath, typename LhsScalar, typename RhsScalar, |
39 | typename AccumScalar, typename DstScalar> |
40 | struct Kernel; |
41 | |
42 | #define RUY_INHERIT_KERNEL(PARENT, CHILD) \ |
43 | template <typename LhsScalar, typename RhsScalar, typename DstScalar, \ |
44 | typename AccumScalar> \ |
45 | struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar> \ |
46 | : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> { \ |
47 | explicit Kernel(Tuning tuning) \ |
48 | : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>( \ |
49 | tuning) {} \ |
50 | }; |
51 | |
52 | // KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. |
53 | // |
54 | // In other cases, we still define (empty) versions, so that dummy kernels |
55 | // can use the classes in function signatures. |
56 | #if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \ |
57 | RUY_PLATFORM_X86 |
58 | |
// Bit values for the `flags` field of KernelParams8bit / KernelParamsFloat.
// NOTE(review): these stay preprocessor macros presumably because asm kernel
// code consumes them; confirm before converting them to constexpr values.
#define RUY_ASM_FLAG_HAS_BIAS 0x1
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20

// Codes identifying the destination scalar type (stored in
// KernelParams8bit::dst_type_id).
#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4

// Compile-time map from a destination scalar type to its RUY_ASM_TYPE_ID_*
// code, exposed as the static member `kValue`. The primary template has no
// `kValue` on purpose: naming it for any type other than the four specialized
// below fails to compile.
template <typename DstScalar>
struct DstTypeId {};

template <>
struct DstTypeId<std::int8_t> {
  static constexpr int kValue{RUY_ASM_TYPE_ID_INT8};
};

template <>
struct DstTypeId<std::uint8_t> {
  static constexpr int kValue{RUY_ASM_TYPE_ID_UINT8};
};

template <>
struct DstTypeId<std::int32_t> {
  static constexpr int kValue{RUY_ASM_TYPE_ID_INT32};
};

template <>
struct DstTypeId<std::int16_t> {
  static constexpr int kValue{RUY_ASM_TYPE_ID_INT16};
};
93 | |
// Argument block for 8-bit kernels: int8 LHS, RHS elements of
// `rhs_scalar_size` bytes, destination identified by `dst_type_id`.
// Populated by MakeKernelParams8bit. LhsCols x RhsCols is the shape of the
// destination tile the kernel works on (destination rows x columns).
//
// NOTE(review): these params are shared with NEON and x86 kernel code. If any
// of it addresses fields by hard-coded byte offsets, the field order and
// types below must not change without updating that code -- confirm first.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {
  // Largest supported destination scalar size in bytes; sizes dst_tmp_buf.
  static constexpr int kMaxDstTypeSize = 4;

  // Per-channel bias; points at zero_data when no bias was supplied
  // (RUY_ASM_FLAG_HAS_BIAS clear).
  const std::int32_t* bias;
  // Packed LHS / RHS sums; meaningful only when RUY_ASM_FLAG_HAS_LHS_SUMS /
  // RUY_ASM_FLAG_HAS_RHS_SUMS is set in `flags`.
  const std::int32_t* lhs_sums;
  const std::int32_t* rhs_sums;
  // First LHS element of the tile range.
  const std::int8_t* lhs_base_ptr;
  // Fixed-point multipliers and exponents, one per channel: either the
  // caller's per-channel arrays (RUY_ASM_FLAG_HAS_PERCHANNEL set) or the
  // *_buf arrays below, filled with a single uniform value.
  const std::int32_t* multiplier_fixedpoint;
  const std::int32_t* multiplier_exponent;
  // Make it void* to support 8bit(LHS)x16bit(RHS) case.
  // The element size is stored in rhs_scalar_size.
  const void* rhs_base_ptr;
  // Destination element at (start_row, start_col); its scalar type is given
  // by dst_type_id.
  void* dst_base_ptr;
  std::int32_t lhs_zero_point;
  std::int32_t rhs_zero_point;
  std::int32_t dst_zero_point;
  // lhs_zero_point * rhs_zero_point * depth.
  std::int32_t prod_zp_depth;
  // Origin of the first tile and origin of the last tile (end - tile size).
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  // Full destination matrix dimensions.
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // lhs_stride is in elements (1 byte each for int8); rhs_stride and
  // dst_stride are in bytes.
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  // Accumulation depth (lhs.layout.rows).
  std::int32_t depth;
  std::int32_t clamp_min;
  std::int32_t clamp_max;
  // Bitmask of RUY_ASM_FLAG_* values.
  std::uint8_t flags;
  // One of RUY_ASM_TYPE_ID_* (see DstTypeId).
  std::uint8_t dst_type_id;
  // LhsCols zeros; used as the bias when none is supplied.
  const std::int32_t zero_data[LhsCols] = {0};
  // Scratch space for one LhsCols x RhsCols tile of destination values of up
  // to kMaxDstTypeSize bytes each. NOTE(review): presumably used by kernels
  // for partial tiles at matrix edges -- TODO confirm.
  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  // Storage for a uniform (non-per-channel) multiplier broadcast to LhsCols
  // entries; multiplier_fixedpoint / multiplier_exponent point here then.
  std::int32_t multiplier_fixedpoint_buf[LhsCols];
  std::int32_t multiplier_exponent_buf[LhsCols];
  // sizeof(RhsScalar): size in bytes of one RHS element.
  std::size_t rhs_scalar_size;
};
131 | |
132 | template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols> |
133 | void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, |
134 | const PMat<RhsScalar>& rhs, |
135 | const MulParams<std::int32_t, DstScalar>& mul_params, |
136 | int start_row, int start_col, int end_row, |
137 | int end_col, Mat<DstScalar>* dst, |
138 | KernelParams8bit<LhsCols, RhsCols>* params) { |
139 | using Params = KernelParams8bit<LhsCols, RhsCols>; |
140 | |
141 | static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "" ); |
142 | |
143 | const int depth = lhs.layout.rows; |
144 | RUY_DCHECK_EQ(start_row % LhsCols, 0); |
145 | RUY_DCHECK_EQ(start_col % RhsCols, 0); |
146 | RUY_DCHECK_EQ(end_row % LhsCols, 0); |
147 | RUY_DCHECK_EQ(end_col % RhsCols, 0); |
148 | |
149 | params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; |
150 | params->rhs_scalar_size = sizeof(RhsScalar); |
151 | params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; |
152 | params->flags = 0; |
153 | params->bias = params->zero_data; |
154 | if (mul_params.bias()) { |
155 | params->bias = mul_params.bias(); |
156 | params->flags |= RUY_ASM_FLAG_HAS_BIAS; |
157 | } |
158 | if (lhs.sums) { |
159 | params->lhs_sums = lhs.sums; |
160 | params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; |
161 | } |
162 | if (rhs.sums) { |
163 | params->rhs_sums = rhs.sums; |
164 | params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; |
165 | } |
166 | if (mul_params.channel_dimension() == ChannelDimension::kCol) { |
167 | params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; |
168 | } |
169 | params->start_row = start_row; |
170 | params->start_col = start_col; |
171 | params->last_row = end_row - LhsCols; |
172 | params->last_col = end_col - RhsCols; |
173 | params->lhs_stride = lhs.layout.stride; |
174 | params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride; |
175 | params->dst_stride = sizeof(DstScalar) * dst->layout.stride; |
176 | params->lhs_zero_point = lhs.zero_point; |
177 | params->rhs_zero_point = rhs.zero_point; |
178 | params->dst_zero_point = dst->zero_point; |
179 | params->depth = depth; |
180 | params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; |
181 | params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; |
182 | if (mul_params.multiplier_fixedpoint_perchannel()) { |
183 | // Temporary release-assert to debug some crashes in an application. |
184 | RUY_CHECK(mul_params.multiplier_exponent_perchannel()); |
185 | params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; |
186 | params->multiplier_fixedpoint = |
187 | mul_params.multiplier_fixedpoint_perchannel(); |
188 | params->multiplier_exponent = mul_params.multiplier_exponent_perchannel(); |
189 | } else { |
190 | params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; |
191 | params->multiplier_exponent = params->multiplier_exponent_buf; |
192 | for (int i = 0; i < LhsCols; i++) { |
193 | params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint(); |
194 | params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent(); |
195 | } |
196 | } |
197 | params->clamp_min = mul_params.clamp_min(); |
198 | params->clamp_max = mul_params.clamp_max(); |
199 | params->dst_rows = dst->layout.rows; |
200 | params->dst_cols = dst->layout.cols; |
201 | |
202 | RUY_DCHECK_LT(params->last_row, params->dst_rows); |
203 | RUY_DCHECK_LT(params->last_col, params->dst_cols); |
204 | |
205 | params->dst_type_id = DstTypeId<DstScalar>::kValue; |
206 | params->dst_base_ptr = |
207 | dst->data.get() + start_col * dst->layout.stride + start_row; |
208 | |
209 | // Temporary release-asserts to debug some crashes in an application. |
210 | RUY_CHECK(params->multiplier_fixedpoint); |
211 | RUY_CHECK(params->multiplier_exponent); |
212 | RUY_CHECK(params->bias); |
213 | } |
214 | |
// Argument block for float kernels, populated by MakeKernelParamsFloat.
// LhsCols x RhsCols is the shape of the destination tile the kernel works on
// (destination rows x columns).
//
// NOTE(review): these params are shared with NEON and x86 kernel code. If any
// of it addresses fields by hard-coded byte offsets, the field order and
// types below must not change without updating that code -- confirm first.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {
  // First element of the tile range in LHS, RHS and destination.
  const float* lhs_base_ptr;
  const float* rhs_base_ptr;
  float* dst_base_ptr;
  // Per-channel bias; points at zero_data when no bias was supplied
  // (RUY_ASM_FLAG_HAS_BIAS clear).
  const float* bias;
  // Origin of the first tile and origin of the last tile (end - tile size).
  std::int32_t start_row;
  std::int32_t start_col;
  std::int32_t last_row;
  std::int32_t last_col;
  // Full destination matrix dimensions.
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // Strides in bytes (sizeof(float) * layout stride).
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  // Accumulation depth (lhs.layout.rows).
  std::int32_t depth;
  float clamp_min;
  float clamp_max;
  // Bitmask of RUY_ASM_FLAG_* values (HAS_BIAS and
  // CHANNEL_DIMENSION_IS_COL are the ones MakeKernelParamsFloat sets).
  std::uint8_t flags;
  // LhsCols zeros; used as the bias when none is supplied.
  const float zero_data[LhsCols] = {0};
  // Scratch space for one LhsCols x RhsCols destination tile.
  // NOTE(review): presumably used by kernels for partial tiles at matrix
  // edges -- TODO confirm.
  float dst_tmp_buf[LhsCols * RhsCols];
};
237 | |
238 | template <int LhsCols, int RhsCols> |
239 | inline void MakeKernelParamsFloat(const PMat<float>& lhs, |
240 | const PMat<float>& rhs, |
241 | const MulParams<float, float>& mul_params, |
242 | int start_row, int start_col, int end_row, |
243 | int end_col, Mat<float>* dst, |
244 | KernelParamsFloat<LhsCols, RhsCols>* params) { |
245 | const int depth = lhs.layout.rows; |
246 | RUY_DCHECK_EQ(start_row % LhsCols, 0); |
247 | RUY_DCHECK_EQ(start_col % RhsCols, 0); |
248 | RUY_DCHECK_EQ(end_row % LhsCols, 0); |
249 | RUY_DCHECK_EQ(end_col % RhsCols, 0); |
250 | |
251 | params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; |
252 | params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; |
253 | params->dst_base_ptr = |
254 | dst->data.get() + start_col * dst->layout.stride + start_row; |
255 | |
256 | std::uint8_t flags = 0; |
257 | params->bias = params->zero_data; |
258 | if (mul_params.bias()) { |
259 | params->bias = mul_params.bias(); |
260 | flags |= RUY_ASM_FLAG_HAS_BIAS; |
261 | } |
262 | if (mul_params.channel_dimension() == ChannelDimension::kCol) { |
263 | flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; |
264 | } |
265 | params->flags = flags; |
266 | params->start_row = start_row; |
267 | params->start_col = start_col; |
268 | params->last_row = end_row - LhsCols; |
269 | params->last_col = end_col - RhsCols; |
270 | params->lhs_stride = sizeof(float) * lhs.layout.stride; |
271 | params->rhs_stride = sizeof(float) * rhs.layout.stride; |
272 | params->dst_stride = sizeof(float) * dst->layout.stride; |
273 | params->depth = depth; |
274 | params->clamp_min = mul_params.clamp_min(); |
275 | params->clamp_max = mul_params.clamp_max(); |
276 | params->dst_rows = dst->layout.rows; |
277 | params->dst_cols = dst->layout.cols; |
278 | |
279 | RUY_DCHECK_LT(params->last_row, params->dst_rows); |
280 | RUY_DCHECK_LT(params->last_col, params->dst_cols); |
281 | } |
282 | |
283 | #else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && |
284 | // RUY_OPT(ASM)) || RUY_PLATFORM_X86 |
285 | |
// Empty placeholders for configurations without the NEON-asm / x86 kernels,
// so that code naming these types (e.g. dummy kernels' function signatures)
// still compiles. Their template signatures must match the full definitions
// in the other branch of the enclosing #if.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {};

template <int LhsCols, int RhsCols>
struct KernelParamsFloat {};
291 | |
292 | #endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && |
293 | // RUY_OPT(ASM)) || RUY_PLATFORM_X86 |
294 | |
295 | } // namespace ruy |
296 | |
297 | #endif // RUY_RUY_KERNEL_COMMON_H_ |
298 | |