1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_KERNEL_ARM_H_
17#define RUY_RUY_KERNEL_ARM_H_
18
19#include <cstddef>
20#include <cstdint>
21
22#include "ruy/asm_helpers.h"
23#include "ruy/kernel_common.h"
24#include "ruy/mat.h"
25#include "ruy/mul_params.h"
26#include "ruy/opt_set.h"
27#include "ruy/path.h"
28#include "ruy/platform.h"
29#include "ruy/profiler/instrumentation.h"
30#include "ruy/side_pair.h"
31#include "ruy/size_util.h"
32#include "ruy/tune.h"
33
34namespace ruy {
35
36#if RUY_PLATFORM_NEON && RUY_OPT(ASM)
37
38RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
39RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
40
41#if RUY_PLATFORM_NEON_64
42void Kernel8bitNeon(const KernelParams8bit<4, 4>& params);
43void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params);
44#elif RUY_PLATFORM_NEON_32
45void Kernel8bitNeon(const KernelParams8bit<4, 2>& params);
46void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params);
47#endif
48void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params);
49void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
50void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
51void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
52void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params);
53
54#if RUY_PLATFORM_NEON_64
55template <typename DstScalar>
56struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
57 static constexpr Path kPath = Path::kNeon;
58 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
59 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
60 Tuning tuning = Tuning::kAuto;
61 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
62 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
63 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
64 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
65 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
66 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
67 end_col, dst, &params);
68 if (dst->layout.cols == 1 &&
69 mul_params.channel_dimension() == ChannelDimension::kRow) {
70 Kernel8bitNeon1Col(params);
71 return;
72 }
73 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
74 Kernel8bitNeonA55ish(params);
75 } else {
76 Kernel8bitNeon(params);
77 }
78 }
79};
80#endif
81
82#if RUY_PLATFORM_NEON_32
83template <typename DstScalar>
84struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
85 static constexpr Path kPath = Path::kNeon;
86 using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
87 using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
88 Tuning tuning = Tuning::kAuto;
89 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
90 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
91 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
92 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
93 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
94 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
95 end_col, dst, &params);
96 if (dst->layout.cols == 1 &&
97 mul_params.channel_dimension() == ChannelDimension::kRow) {
98 Kernel8bitNeon1Col(params);
99 return;
100 }
101 Kernel8bitNeon(params);
102 }
103};
104#endif
105
106#if RUY_PLATFORM_NEON_64
107template <typename DstScalar>
108struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t,
109 DstScalar> {
110 static constexpr Path kPath = Path::kNeonDotprod;
111 Tuning tuning = Tuning::kAuto;
112 using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
113 using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
114 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
115 void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
116 const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
117 int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
118 KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
119 MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
120 end_col, dst, &params);
121 if (dst->layout.cols == 1 &&
122 mul_params.channel_dimension() == ChannelDimension::kRow) {
123 Kernel8bitNeonDotprod1Col(params);
124 } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
125 Kernel8bitNeonDotprodA55ish(params);
126 } else if (tuning == Tuning::kX1) {
127 Kernel8bitNeonDotprodX1(params);
128 } else {
129 Kernel8bitNeonDotprod(params);
130 }
131 }
132};
133#endif
134
135void KernelFloatNeon(const KernelParamsFloat<8, 8>& params);
136void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params);
137void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params);
138void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params);
139void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params);
140
141#if RUY_PLATFORM_NEON_64
142// A Float kernel for ARM64 Neon.
143template <>
144struct Kernel<Path::kNeon, float, float, float, float> {
145 static constexpr Path kPath = Path::kNeon;
146 Tuning tuning = Tuning::kAuto;
147 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
148 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
149 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
150 void Run(const PMat<float>& lhs, const PMat<float>& rhs,
151 const MulParams<float, float>& mul_params, int start_row,
152 int start_col, int end_row, int end_col, Mat<float>* dst) const {
153 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
154 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
155 end_col, dst, &params);
156 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
157 KernelFloatNeonA55ish(params);
158 } else if (tuning == Tuning::kX1) {
159 KernelFloatNeonX1(params);
160 } else {
161 KernelFloatNeon(params);
162 }
163 }
164};
165#endif
166
167#if RUY_PLATFORM_NEON_32
168// A Float kernel for ARM32 Neon.
169template <>
170struct Kernel<Path::kNeon, float, float, float, float> {
171 static constexpr Path kPath = Path::kNeon;
172 Tuning tuning = Tuning::kAuto;
173 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
174 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
175 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
176 void Run(const PMat<float>& lhs, const PMat<float>& rhs,
177 const MulParams<float, float>& mul_params, int start_row,
178 int start_col, int end_row, int end_col, Mat<float>* dst) const {
179 KernelParamsFloat<8, 4> params;
180
181 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
182 end_col, dst, &params);
183
184 KernelFloat32Neon(params);
185 }
186};
187#endif
188
189// While the dotprod NEON extension does not concern floating-point arithmetic,
190// its presence allows us to distinguish, in the in-order tuning case, between
191// A53 and A55r1. TODO: should this be folded into tuning?
192template <>
193struct Kernel<Path::kNeonDotprod, float, float, float, float> {
194 static constexpr Path kPath = Path::kNeonDotprod;
195 Tuning tuning = Tuning::kAuto;
196 using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
197 using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
198 using Base = Kernel<Path::kNeon, float, float, float, float>;
199 explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
200 void Run(const PMat<float>& lhs, const PMat<float>& rhs,
201 const MulParams<float, float>& mul_params, int start_row,
202 int start_col, int end_row, int end_col, Mat<float>* dst) const {
203 KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
204 MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
205 end_col, dst, &params);
206 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
207 KernelFloatNeonDotprodA55ish(params);
208 } else if (tuning == Tuning::kX1) {
209 KernelFloatNeonX1(params);
210 } else {
211 KernelFloatNeon(params);
212 }
213 }
214};
215
216#endif // RUY_PLATFORM_NEON && RUY_OPT(ASM)
217
218} // namespace ruy
219
220#endif // RUY_RUY_KERNEL_ARM_H_
221