1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_KERNEL_COMMON_H_
17#define RUY_RUY_KERNEL_COMMON_H_
18
19#include <algorithm>
20#include <cstdint>
21#include <type_traits>
22
23#include "ruy/apply_multiplier.h"
24#include "ruy/check_macros.h"
25#include "ruy/mat.h"
26#include "ruy/matrix.h"
27#include "ruy/mul_params.h"
28#include "ruy/opt_set.h"
29#include "ruy/path.h"
30#include "ruy/platform.h"
31#include "ruy/profiler/instrumentation.h"
32#include "ruy/side_pair.h"
33#include "ruy/size_util.h"
34#include "ruy/tune.h"
35
36namespace ruy {
37
// Primary Kernel template. It is only declared here, not defined: usable
// kernels are provided as specializations, selected by the code path `ThePath`
// together with the LHS, RHS, accumulator and destination scalar types.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
struct Kernel;
41
// Defines the partial specialization of Kernel for path CHILD, for every
// combination of scalar types, as a struct deriving from the Kernel
// specialization for path PARENT with the same scalar types. The only member
// it adds is an explicit constructor that forwards its Tuning argument to the
// parent's constructor.
//
// Note: the template parameter list below names DstScalar before AccumScalar,
// which differs from the primary template. This is harmless because the
// specialization's argument list passes the types in the primary template's
// order (LhsScalar, RhsScalar, AccumScalar, DstScalar).
#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                  \
  template <typename LhsScalar, typename RhsScalar, typename DstScalar,    \
            typename AccumScalar>                                          \
  struct Kernel<CHILD, LhsScalar, RhsScalar, AccumScalar, DstScalar>       \
      : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar> {     \
    explicit Kernel(Tuning tuning)                                         \
        : Kernel<PARENT, LhsScalar, RhsScalar, AccumScalar, DstScalar>(    \
              tuning) {}                                                   \
  };
51
52// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
53//
54// In other cases, we still define (empty) versions, so that dummy kernels
55// can use the classes in function signatures.
56#if ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)) || \
57 RUY_PLATFORM_X86
58
// Bit flags OR-ed together into the `flags` member of the kernel params
// structs. Each value is a distinct single bit.
// NOTE(review): the RUY_ASM_ prefix suggests these literals are also consumed
// by assembly kernels, possibly textually -- confirm before changing how the
// values are spelled.

// Set when MulParams provides a bias pointer.
#define RUY_ASM_FLAG_HAS_BIAS 0x1
// Set when the LHS / RHS matrix carries a sums pointer.
#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
// Set when MulParams provides per-channel fixed-point multipliers.
#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
// Set unconditionally by MakeKernelParams8bit.
#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
// Set when MulParams' channel dimension is ChannelDimension::kCol.
#define RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL 0x20

// Numeric identifiers for the destination scalar type. DstTypeId maps a C++
// type to one of these, and MakeKernelParams8bit stores the result in
// KernelParams8bit::dst_type_id.
#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4
70
71template <typename DstScalar>
72struct DstTypeId {};
73
74template <>
75struct DstTypeId<std::uint8_t> {
76 static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
77};
78
79template <>
80struct DstTypeId<std::int8_t> {
81 static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
82};
83
84template <>
85struct DstTypeId<std::int16_t> {
86 static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
87};
88
89template <>
90struct DstTypeId<std::int32_t> {
91 static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
92};
93
// Parameter block describing one invocation of an 8-bit kernel; it is filled
// in by MakeKernelParams8bit. LhsCols and RhsCols size the fixed arrays at the
// end of the struct and are the amounts subtracted from end_row / end_col to
// obtain last_row / last_col.
//
// NOTE(review): this struct is shared with NEON and x86 kernel code, which
// presumably relies on the exact member order and types -- confirm against the
// kernels before reordering, resizing or inserting members.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {
  // Largest supported sizeof(DstScalar); enforced by a static_assert in
  // MakeKernelParams8bit and used to size dst_tmp_buf.
  static constexpr int kMaxDstTypeSize = 4;

  // Bias pointer from MulParams, or zero_data when no bias was provided.
  const std::int32_t* bias;
  // The `sums` pointers of the LHS / RHS matrices. Only meaningful when the
  // matching RUY_ASM_FLAG_HAS_LHS_SUMS / RUY_ASM_FLAG_HAS_RHS_SUMS bit is set
  // in `flags`.
  const std::int32_t* lhs_sums;
  const std::int32_t* rhs_sums;
  // LHS data, offset by start_row * (LHS stride).
  const std::int8_t* lhs_base_ptr;
  // Multiplier arrays: either the per-channel arrays from MulParams, or
  // multiplier_fixedpoint_buf / multiplier_exponent_buf below.
  const std::int32_t* multiplier_fixedpoint;
  const std::int32_t* multiplier_exponent;
  // Make it void* to support 8bit(LHS)x16bit(RHS) case.
  // RHS data, offset by start_col * (RHS stride).
  const void* rhs_base_ptr;
  // Destination data, offset to element (start_row, start_col).
  void* dst_base_ptr;
  // Zero points copied from the LHS, RHS and destination matrices.
  std::int32_t lhs_zero_point;
  std::int32_t rhs_zero_point;
  std::int32_t dst_zero_point;
  // lhs_zero_point * rhs_zero_point * depth.
  std::int32_t prod_zp_depth;
  // First row / column of the destination block to compute.
  std::int32_t start_row;
  std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols respectively.
  std::int32_t last_row;
  std::int32_t last_col;
  // Row and column counts taken from the destination layout.
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // LHS layout stride, stored unscaled (LHS elements are std::int8_t).
  std::int32_t lhs_stride;
  // RHS layout stride multiplied by rhs_scalar_size.
  std::int32_t rhs_stride;
  // Destination layout stride multiplied by sizeof(DstScalar).
  std::int32_t dst_stride;
  // Taken from the LHS layout's `rows`.
  std::int32_t depth;
  // Clamp bounds copied from MulParams.
  std::int32_t clamp_min;
  std::int32_t clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values.
  std::uint8_t flags;
  // One of the RUY_ASM_TYPE_ID_* values, describing DstScalar.
  std::uint8_t dst_type_id;
  // All-zero array that `bias` points to when no bias was provided.
  const std::int32_t zero_data[LhsCols] = {0};
  // Byte buffer sized for LhsCols * RhsCols destination values of at most
  // kMaxDstTypeSize bytes each. Not written by MakeKernelParams8bit.
  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
  // Backing storage for multiplier_fixedpoint / multiplier_exponent when the
  // multiplier is not per-channel: every entry holds the same uniform value.
  std::int32_t multiplier_fixedpoint_buf[LhsCols];
  std::int32_t multiplier_exponent_buf[LhsCols];
  // sizeof(RhsScalar).
  std::size_t rhs_scalar_size;
};
131
132template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols>
133void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
134 const PMat<RhsScalar>& rhs,
135 const MulParams<std::int32_t, DstScalar>& mul_params,
136 int start_row, int start_col, int end_row,
137 int end_col, Mat<DstScalar>* dst,
138 KernelParams8bit<LhsCols, RhsCols>* params) {
139 using Params = KernelParams8bit<LhsCols, RhsCols>;
140
141 static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
142
143 const int depth = lhs.layout.rows;
144 RUY_DCHECK_EQ(start_row % LhsCols, 0);
145 RUY_DCHECK_EQ(start_col % RhsCols, 0);
146 RUY_DCHECK_EQ(end_row % LhsCols, 0);
147 RUY_DCHECK_EQ(end_col % RhsCols, 0);
148
149 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
150 params->rhs_scalar_size = sizeof(RhsScalar);
151 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
152 params->flags = 0;
153 params->bias = params->zero_data;
154 if (mul_params.bias()) {
155 params->bias = mul_params.bias();
156 params->flags |= RUY_ASM_FLAG_HAS_BIAS;
157 }
158 if (lhs.sums) {
159 params->lhs_sums = lhs.sums;
160 params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
161 }
162 if (rhs.sums) {
163 params->rhs_sums = rhs.sums;
164 params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
165 }
166 if (mul_params.channel_dimension() == ChannelDimension::kCol) {
167 params->flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
168 }
169 params->start_row = start_row;
170 params->start_col = start_col;
171 params->last_row = end_row - LhsCols;
172 params->last_col = end_col - RhsCols;
173 params->lhs_stride = lhs.layout.stride;
174 params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride;
175 params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
176 params->lhs_zero_point = lhs.zero_point;
177 params->rhs_zero_point = rhs.zero_point;
178 params->dst_zero_point = dst->zero_point;
179 params->depth = depth;
180 params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
181 params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
182 if (mul_params.multiplier_fixedpoint_perchannel()) {
183 // Temporary release-assert to debug some crashes in an application.
184 RUY_CHECK(mul_params.multiplier_exponent_perchannel());
185 params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
186 params->multiplier_fixedpoint =
187 mul_params.multiplier_fixedpoint_perchannel();
188 params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
189 } else {
190 params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
191 params->multiplier_exponent = params->multiplier_exponent_buf;
192 for (int i = 0; i < LhsCols; i++) {
193 params->multiplier_fixedpoint_buf[i] = mul_params.multiplier_fixedpoint();
194 params->multiplier_exponent_buf[i] = mul_params.multiplier_exponent();
195 }
196 }
197 params->clamp_min = mul_params.clamp_min();
198 params->clamp_max = mul_params.clamp_max();
199 params->dst_rows = dst->layout.rows;
200 params->dst_cols = dst->layout.cols;
201
202 RUY_DCHECK_LT(params->last_row, params->dst_rows);
203 RUY_DCHECK_LT(params->last_col, params->dst_cols);
204
205 params->dst_type_id = DstTypeId<DstScalar>::kValue;
206 params->dst_base_ptr =
207 dst->data.get() + start_col * dst->layout.stride + start_row;
208
209 // Temporary release-asserts to debug some crashes in an application.
210 RUY_CHECK(params->multiplier_fixedpoint);
211 RUY_CHECK(params->multiplier_exponent);
212 RUY_CHECK(params->bias);
213}
214
// Parameter block describing one invocation of a float kernel; it is filled in
// by MakeKernelParamsFloat. LhsCols and RhsCols size the fixed arrays at the
// end of the struct and are the amounts subtracted from end_row / end_col to
// obtain last_row / last_col.
//
// NOTE(review): this struct is shared with NEON and x86 kernel code, which
// presumably relies on the exact member order and types -- confirm against the
// kernels before reordering, resizing or inserting members.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {
  // LHS data offset by start_row * (LHS stride), and RHS data offset by
  // start_col * (RHS stride), both counted in elements.
  const float* lhs_base_ptr;
  const float* rhs_base_ptr;
  // Destination data offset to element (start_row, start_col).
  float* dst_base_ptr;
  // Bias pointer from MulParams, or zero_data when no bias was provided.
  const float* bias;
  // First row / column of the destination block to compute.
  std::int32_t start_row;
  std::int32_t start_col;
  // end_row - LhsCols and end_col - RhsCols respectively.
  std::int32_t last_row;
  std::int32_t last_col;
  // Row and column counts taken from the destination layout.
  std::int32_t dst_rows;
  std::int32_t dst_cols;
  // Layout strides, each multiplied by sizeof(float).
  std::int32_t lhs_stride;
  std::int32_t rhs_stride;
  std::int32_t dst_stride;
  // Taken from the LHS layout's `rows`.
  std::int32_t depth;
  // Clamp bounds copied from MulParams.
  float clamp_min;
  float clamp_max;
  // Bitwise OR of RUY_ASM_FLAG_* values.
  std::uint8_t flags;
  // All-zero array that `bias` points to when no bias was provided.
  const float zero_data[LhsCols] = {0};
  // Buffer sized for LhsCols * RhsCols floats. Not written by
  // MakeKernelParamsFloat.
  float dst_tmp_buf[LhsCols * RhsCols];
};
237
238template <int LhsCols, int RhsCols>
239inline void MakeKernelParamsFloat(const PMat<float>& lhs,
240 const PMat<float>& rhs,
241 const MulParams<float, float>& mul_params,
242 int start_row, int start_col, int end_row,
243 int end_col, Mat<float>* dst,
244 KernelParamsFloat<LhsCols, RhsCols>* params) {
245 const int depth = lhs.layout.rows;
246 RUY_DCHECK_EQ(start_row % LhsCols, 0);
247 RUY_DCHECK_EQ(start_col % RhsCols, 0);
248 RUY_DCHECK_EQ(end_row % LhsCols, 0);
249 RUY_DCHECK_EQ(end_col % RhsCols, 0);
250
251 params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
252 params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
253 params->dst_base_ptr =
254 dst->data.get() + start_col * dst->layout.stride + start_row;
255
256 std::uint8_t flags = 0;
257 params->bias = params->zero_data;
258 if (mul_params.bias()) {
259 params->bias = mul_params.bias();
260 flags |= RUY_ASM_FLAG_HAS_BIAS;
261 }
262 if (mul_params.channel_dimension() == ChannelDimension::kCol) {
263 flags |= RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
264 }
265 params->flags = flags;
266 params->start_row = start_row;
267 params->start_col = start_col;
268 params->last_row = end_row - LhsCols;
269 params->last_col = end_col - RhsCols;
270 params->lhs_stride = sizeof(float) * lhs.layout.stride;
271 params->rhs_stride = sizeof(float) * rhs.layout.stride;
272 params->dst_stride = sizeof(float) * dst->layout.stride;
273 params->depth = depth;
274 params->clamp_min = mul_params.clamp_min();
275 params->clamp_max = mul_params.clamp_max();
276 params->dst_rows = dst->layout.rows;
277 params->dst_cols = dst->layout.cols;
278
279 RUY_DCHECK_LT(params->last_row, params->dst_rows);
280 RUY_DCHECK_LT(params->last_col, params->dst_cols);
281}
282
283#else // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
284 // RUY_OPT(ASM)) || RUY_PLATFORM_X86
285
// Empty stand-in for KernelParams8bit, used when neither the NEON-with-asm nor
// the x86 configuration is enabled, so that dummy kernels can still name
// KernelParams8bit<LhsCols, RhsCols> in function signatures.
template <int LhsCols, int RhsCols>
struct KernelParams8bit {};
288
// Empty stand-in for KernelParamsFloat, used when neither the NEON-with-asm
// nor the x86 configuration is enabled, so that dummy kernels can still name
// KernelParamsFloat<LhsCols, RhsCols> in function signatures.
template <int LhsCols, int RhsCols>
struct KernelParamsFloat {};
291
292#endif // ((RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) &&
293 // RUY_OPT(ASM)) || RUY_PLATFORM_X86
294
295} // namespace ruy
296
297#endif // RUY_RUY_KERNEL_COMMON_H_
298