1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_KERNEL_ARM_H_ |
17 | #define RUY_RUY_KERNEL_ARM_H_ |
18 | |
19 | #include <cstddef> |
20 | #include <cstdint> |
21 | |
22 | #include "ruy/asm_helpers.h" |
23 | #include "ruy/kernel_common.h" |
24 | #include "ruy/mat.h" |
25 | #include "ruy/mul_params.h" |
26 | #include "ruy/opt_set.h" |
27 | #include "ruy/path.h" |
28 | #include "ruy/platform.h" |
29 | #include "ruy/profiler/instrumentation.h" |
30 | #include "ruy/side_pair.h" |
31 | #include "ruy/size_util.h" |
32 | #include "ruy/tune.h" |
33 | |
34 | namespace ruy { |
35 | |
36 | #if RUY_PLATFORM_NEON && RUY_OPT(ASM) |
37 | |
38 | RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) |
39 | RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) |
40 | |
41 | #if RUY_PLATFORM_NEON_64 |
42 | void Kernel8bitNeon(const KernelParams8bit<4, 4>& params); |
43 | void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params); |
44 | #elif RUY_PLATFORM_NEON_32 |
45 | void Kernel8bitNeon(const KernelParams8bit<4, 2>& params); |
46 | void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params); |
47 | #endif |
48 | void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params); |
49 | void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params); |
50 | void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params); |
51 | void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params); |
52 | void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params); |
53 | |
#if RUY_PLATFORM_NEON_64
// 8-bit NEON kernel for ARM64: int8 LHS/RHS, int32 accumulators, DstScalar
// destination.
//
// Both operands use FixedKernelLayout<Order::kColMajor, 16, 4>, so the
// parameter block is KernelParams8bit<4, 4>, which matches the ARM64
// declarations of Kernel8bitNeon / Kernel8bitNeon1Col / Kernel8bitNeonA55ish.
//
// Dispatch order in Run():
//   1. dst has exactly one column and mul_params' channel dimension is kRow
//      -> Kernel8bitNeon1Col (tuning is ignored);
//   2. tuning == Tuning::kA55ish -> Kernel8bitNeonA55ish;
//   3. anything else             -> Kernel8bitNeon.
template <typename DstScalar>
struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kNeon;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  Tuning tuning = Tuning::kAuto;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // `params` is passed by address; presumably MakeKernelParams8bit fills it.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeon1Col(params);
      return;
    }
    // The branch is hinted as likely to take the A55ish (in-order) path.
    if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      Kernel8bitNeonA55ish(params);
    } else {
      Kernel8bitNeon(params);
    }
  }
};
#endif
81 | |
#if RUY_PLATFORM_NEON_32
// 8-bit NEON kernel for ARM32: int8 LHS/RHS, int32 accumulators, DstScalar
// destination.
//
// The RHS layout is narrower than on ARM64 (16x2 instead of 16x4), so the
// parameter block is KernelParams8bit<4, 2>. That matches the ARM32
// declarations of Kernel8bitNeon / Kernel8bitNeon1Col.
//
// `tuning` is stored but not consulted: apart from the single-column case,
// every call goes to Kernel8bitNeon.
template <typename DstScalar>
struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kNeon;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
  Tuning tuning = Tuning::kAuto;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // `params` is passed by address; presumably MakeKernelParams8bit fills it.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    // A single-column destination with per-row channel parameters uses the
    // dedicated 1-column kernel.
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeon1Col(params);
      return;
    }
    Kernel8bitNeon(params);
  }
};
#endif
105 | |
#if RUY_PLATFORM_NEON_64
// 8-bit kernel for ARM64 cores with the NEON dot-product extension: int8
// LHS/RHS, int32 accumulators, DstScalar destination.
//
// Both operands use FixedKernelLayout<Order::kColMajor, 4, 8>, so the
// parameter block is KernelParams8bit<8, 8>.
//
// Dispatch order in Run():
//   1. dst has exactly one column and mul_params' channel dimension is kRow
//      -> Kernel8bitNeonDotprod1Col;
//   2. tuning == Tuning::kA55ish -> Kernel8bitNeonDotprodA55ish;
//   3. tuning == Tuning::kX1     -> Kernel8bitNeonDotprodX1;
//   4. anything else             -> Kernel8bitNeonDotprod.
template <typename DstScalar>
struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kNeonDotprod;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // `params` is passed by address; presumably MakeKernelParams8bit fills it.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeonDotprod1Col(params);
    } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      // Hinted as the likely branch.
      Kernel8bitNeonDotprodA55ish(params);
    } else if (tuning == Tuning::kX1) {
      Kernel8bitNeonDotprodX1(params);
    } else {
      Kernel8bitNeonDotprod(params);
    }
  }
};
#endif
134 | |
135 | void KernelFloatNeon(const KernelParamsFloat<8, 8>& params); |
136 | void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params); |
137 | void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params); |
138 | void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params); |
139 | void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params); |
140 | |
#if RUY_PLATFORM_NEON_64
// A Float kernel for ARM64 Neon.
//
// Both operands use FixedKernelLayout<Order::kRowMajor, 1, 8>, so the
// parameter block is KernelParamsFloat<8, 8>.
//
// Dispatch by tuning:
//   - Tuning::kA55ish -> KernelFloatNeonA55ish (hinted as the likely branch);
//   - Tuning::kX1     -> KernelFloatNeonX1;
//   - anything else   -> KernelFloatNeon.
template <>
struct Kernel<Path::kNeon, float, float, float, float> {
  static constexpr Path kPath = Path::kNeon;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    // `params` is passed by address; presumably MakeKernelParamsFloat fills it.
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      KernelFloatNeonA55ish(params);
    } else if (tuning == Tuning::kX1) {
      KernelFloatNeonX1(params);
    } else {
      KernelFloatNeon(params);
    }
  }
};
#endif
166 | |
#if RUY_PLATFORM_NEON_32
// A Float kernel for ARM32 Neon.
//
// The RHS layout is narrower than on ARM64 (1x4 instead of 1x8), so the
// parameter block is KernelParamsFloat<8, 4>, which is what KernelFloat32Neon
// takes. `tuning` is stored but not consulted: every call goes to
// KernelFloat32Neon.
template <>
struct Kernel<Path::kNeon, float, float, float, float> {
  static constexpr Path kPath = Path::kNeon;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    // Derive the block shape from the layouts, as the other kernels do,
    // instead of repeating the literal <8, 4>.
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    // `params` is passed by address; presumably MakeKernelParamsFloat fills it.
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    KernelFloat32Neon(params);
  }
};
#endif
188 | |
189 | // While the dotprod NEON extension does not concern floating-point arithmetic, |
190 | // its presence allows us to distinguish, in the in-order tuning case, between |
191 | // A53 and A55r1. TODO: should this be folded into tuning? |
192 | template <> |
193 | struct Kernel<Path::kNeonDotprod, float, float, float, float> { |
194 | static constexpr Path kPath = Path::kNeonDotprod; |
195 | Tuning tuning = Tuning::kAuto; |
196 | using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
197 | using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
198 | using Base = Kernel<Path::kNeon, float, float, float, float>; |
199 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
200 | void Run(const PMat<float>& lhs, const PMat<float>& rhs, |
201 | const MulParams<float, float>& mul_params, int start_row, |
202 | int start_col, int end_row, int end_col, Mat<float>* dst) const { |
203 | KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; |
204 | MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, |
205 | end_col, dst, ¶ms); |
206 | if (__builtin_expect(tuning == Tuning::kA55ish, true)) { |
207 | KernelFloatNeonDotprodA55ish(params); |
208 | } else if (tuning == Tuning::kX1) { |
209 | KernelFloatNeonX1(params); |
210 | } else { |
211 | KernelFloatNeon(params); |
212 | } |
213 | } |
214 | }; |
215 | |
216 | #endif // RUY_PLATFORM_NEON && RUY_OPT(ASM) |
217 | |
218 | } // namespace ruy |
219 | |
220 | #endif // RUY_RUY_KERNEL_ARM_H_ |
221 | |