1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | |
18 | #include "ruy/asm_helpers.h" |
19 | #include "ruy/check_macros.h" |
20 | #include "ruy/kernel_arm.h" |
21 | #include "ruy/opt_set.h" |
22 | #include "ruy/platform.h" |
23 | #include "ruy/profiler/instrumentation.h" |
24 | |
25 | namespace ruy { |
26 | |
27 | #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) |
28 | |
29 | #define RUY_ASM_LABEL_STORE_UINT8 91 |
30 | #define RUY_ASM_LABEL_STORE_INT8 92 |
31 | #define RUY_ASM_LABEL_STORE_INT16 93 |
32 | #define RUY_ASM_LABEL_STORE_INT32 94 |
33 | #define RUY_ASM_LABEL_AFTER_STORE 99 |
34 | |
35 | #define RUY_OFFSET_BIAS 0 |
36 | #define RUY_OFFSET_LHS_SUMS 8 |
37 | #define RUY_OFFSET_RHS_SUMS 16 |
38 | #define RUY_OFFSET_LHS_BASE_PTR 24 |
39 | #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32 |
40 | #define RUY_OFFSET_MULTIPLIER_EXPONENT 40 |
41 | #define RUY_OFFSET_RHS_BASE_PTR 48 |
42 | #define RUY_OFFSET_DST_BASE_PTR 56 |
43 | #define RUY_OFFSET_LHS_ZERO_POINT 64 |
44 | #define RUY_OFFSET_RHS_ZERO_POINT 68 |
45 | #define RUY_OFFSET_DST_ZERO_POINT 72 |
46 | #define RUY_OFFSET_PROD_ZP_DEPTH 76 |
47 | #define RUY_OFFSET_START_ROW 80 |
48 | #define RUY_OFFSET_START_COL 84 |
49 | #define RUY_OFFSET_LAST_ROW 88 |
50 | #define RUY_OFFSET_LAST_COL 92 |
51 | #define RUY_OFFSET_DST_ROWS 96 |
52 | #define RUY_OFFSET_DST_COLS 100 |
53 | #define RUY_OFFSET_LHS_STRIDE 104 |
54 | #define RUY_OFFSET_RHS_STRIDE 108 |
55 | #define RUY_OFFSET_DST_STRIDE 112 |
56 | #define RUY_OFFSET_DEPTH 116 |
57 | #define RUY_OFFSET_CLAMP_MIN 120 |
58 | #define RUY_OFFSET_CLAMP_MAX 124 |
59 | #define RUY_OFFSET_FLAGS 128 |
60 | |
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
                "");
  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
                "");
  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
                "");
  static_assert(offsetof(Params, multiplier_fixedpoint) ==
                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
                "");
  static_assert(
      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
      "");
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
}
92 | |
93 | // Fast-int8-trick kernel, similar to this production gemmlowp kernel: |
94 | // NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits |
95 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296 |
96 | // |
97 | // Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, |
98 | // since these are 64-bit, out-of-order and without dotprod support. |
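//
// A note on why the trick is safe (illustrative reasoning, not from the
// original comment): each int8 operand is in [-128, 127], so each product
// lies in [-16256, 16384] and the sum of two products lies in
// [-32512, 32768]. The only value outside the int16 range [-32768, 32767]
// is +32768, which requires both products to be exactly (-128) * (-128);
// that is the corner case this "accumulate two within 16 bits" trick has
// to care about (see the gemmlowp kernel linked above for the original
// discussion).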
99 | void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) { |
100 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
101 | CheckOffsetsInKernelParams8bit(params); |
102 | |
103 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
  const std::int8_t* rhs_col_ptr =
      static_cast<const std::int8_t*>(params.rhs_base_ptr);
106 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
107 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
108 | void* dst_col_ptr = params.dst_base_ptr; |
109 | void* dst_ptr = dst_col_ptr; |
110 | int row = params.start_row; |
111 | int col = params.start_col; |
112 | |
113 | // The asm kernel below has the following NEON register allocation: |
114 | // |
115 | // v16 -- v31 are int32 accumulators. |
116 | // During accumulation, v0 -- v3 are used to load int8 data from LHS and |
117 | // v4 -- v7 from RHS: |
118 | // |
119 | // int8 RHS 16x4 block |
120 | // /-----------------------------------------| |
121 | // |v4.b[0] ... v7.b[0] | |
122 | // | ... ... | |
123 | // |v4.b[15] ... v7.b[15] | |
124 | // \-----------------------------------------/ |
125 | // int8 LHS 4x16 block |
126 | // /---------------------\ /-----------------------------------------| |
127 | // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | |
128 | // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | |
129 | // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | |
130 | // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | |
131 | // \---------------------/ \-----------------------------------------/ |
132 | // int32 accumulators 4x4 block |
133 | // |
  // No attempt has been made so far to implement the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
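  //
  // As a rough scalar sketch (illustration only; the lhs/rhs/acc_lane arrays
  // named here are hypothetical stand-ins for the packed data and for the
  // four int32 lanes of each accumulator register), one 16-deep step of the
  // accumulation below computes, for each entry (r, c) of the 4x4 block:
  //
  //   std::int16_t tmp[8];
  //   for (int i = 0; i < 8; ++i) {
  //     // smull covers depth levels d+0..d+7; smlal2 adds depth levels
  //     // d+8..d+15 into the same int16 lanes.
  //     tmp[i] = lhs[r][d + i] * rhs[d + i][c] +
  //              lhs[r][d + i + 8] * rhs[d + i + 8][c];
  //   }
  //   for (int i = 0; i < 4; ++i) {
  //     // sadalp then pairwise-adds the int16 lanes into the int32 lanes.
  //     acc_lane[r][c][i] += tmp[2 * i] + tmp[2 * i + 1];
  //   }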
136 | asm volatile( |
137 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
138 | |
139 | // clang-format off |
140 | |
141 | // Load some parameters into registers. |
142 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
143 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
144 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
145 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
146 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
147 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
148 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
149 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
150 | |
151 | // Load the first 64 bytes of LHS and RHS data. |
152 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
153 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
154 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
155 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
156 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
157 | "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" |
158 | "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" |
159 | "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" |
160 | |
161 | // Clear accumulators. |
162 | RUY_MAKE_ZERO(v16) |
163 | RUY_MAKE_ZERO(v17) |
164 | RUY_MAKE_ZERO(v18) |
165 | RUY_MAKE_ZERO(v19) |
166 | RUY_MAKE_ZERO(v20) |
167 | RUY_MAKE_ZERO(v21) |
168 | RUY_MAKE_ZERO(v22) |
169 | RUY_MAKE_ZERO(v23) |
170 | RUY_MAKE_ZERO(v24) |
171 | RUY_MAKE_ZERO(v25) |
172 | RUY_MAKE_ZERO(v26) |
173 | RUY_MAKE_ZERO(v27) |
174 | RUY_MAKE_ZERO(v28) |
175 | RUY_MAKE_ZERO(v29) |
176 | RUY_MAKE_ZERO(v30) |
177 | RUY_MAKE_ZERO(v31) |
178 | |
179 | // w1 is the number of levels of depth that we have already loaded |
180 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
181 | // above, this is currently 16. |
182 | "mov w1, #16\n" |
183 | |
184 | // Perform the first few multiply-adds on the data that we have already |
185 | // loaded. |
186 | "smull v8.8h, v0.8b, v4.8b\n" |
187 | "smull v9.8h, v1.8b, v4.8b\n" |
188 | "smull v10.8h, v2.8b, v4.8b\n" |
189 | "smull v11.8h, v3.8b, v4.8b\n" |
190 | "smull v12.8h, v0.8b, v5.8b\n" |
191 | "smull v13.8h, v1.8b, v5.8b\n" |
192 | "smull v14.8h, v2.8b, v5.8b\n" |
193 | "smull v15.8h, v3.8b, v5.8b\n" |
194 | |
195 | // Multiply-accumulate second-half, again into the same |
196 | // 16bit local accumulator registers. This is where we |
197 | // take advantage of having int8 instead of uint8 and therefore |
198 | // being able to accumulate two products into int16. |
199 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
200 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
201 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
202 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
203 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
204 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
205 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
206 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
207 | |
208 | |
209 | // Main loop of the whole GEMM, over rows and columns of the |
210 | // destination matrix. |
211 | "1:\n" |
212 | |
213 | // Reminder - w1 is how many levels of depth we have already loaded |
214 | // data for, w12 is the total depth. |
215 | "cmp w1, w12\n" |
216 | "beq 79f\n" |
217 | |
218 | "2:\n" |
219 | |
220 | // Some multiplications and 16-bit accumulation were already done above, |
221 | // so we start right away in the middle. |
222 | "sadalp v16.4s, v8.8h\n" |
223 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
224 | "smull v8.8h, v0.8b, v6.8b\n" |
225 | "sadalp v17.4s, v9.8h\n" |
226 | "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" |
227 | "smull v9.8h, v1.8b, v6.8b\n" |
228 | "sadalp v18.4s, v10.8h\n" |
229 | "smull v10.8h, v2.8b, v6.8b\n" |
230 | "sadalp v19.4s, v11.8h\n" |
231 | "smull v11.8h, v3.8b, v6.8b\n" |
232 | "sadalp v20.4s, v12.8h\n" |
233 | "smull v12.8h, v0.8b, v7.8b\n" |
234 | "sadalp v21.4s, v13.8h\n" |
235 | "smull v13.8h, v1.8b, v7.8b\n" |
236 | "sadalp v22.4s, v14.8h\n" |
237 | "smull v14.8h, v2.8b, v7.8b\n" |
238 | "sadalp v23.4s, v15.8h\n" |
239 | "smull v15.8h, v3.8b, v7.8b\n" |
240 | |
241 | // Multiply-accumulate second-half, again into the same |
242 | // 16bit local accumulator registers. This is where we |
243 | // take advantage of having int8 instead of uint8 and therefore |
244 | // being able to accumulate two products into int16. |
245 | "smlal2 v8.8h, v0.16b, v6.16b\n" |
246 | "smlal2 v9.8h, v1.16b, v6.16b\n" |
247 | "smlal2 v10.8h, v2.16b, v6.16b\n" |
248 | "smlal2 v11.8h, v3.16b, v6.16b\n" |
249 | |
250 | "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" |
251 | |
252 | "smlal2 v12.8h, v0.16b, v7.16b\n" |
253 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
254 | "smlal2 v13.8h, v1.16b, v7.16b\n" |
255 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
256 | "smlal2 v14.8h, v2.16b, v7.16b\n" |
257 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
258 | "smlal2 v15.8h, v3.16b, v7.16b\n" |
259 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
260 | |
261 | "sadalp v24.4s, v8.8h\n" |
262 | "smull v8.8h, v0.8b, v4.8b\n" |
263 | "sadalp v25.4s, v9.8h\n" |
264 | "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" |
265 | "smull v9.8h, v1.8b, v4.8b\n" |
266 | "sadalp v26.4s, v10.8h\n" |
267 | "smull v10.8h, v2.8b, v4.8b\n" |
268 | "sadalp v27.4s, v11.8h\n" |
269 | "smull v11.8h, v3.8b, v4.8b\n" |
270 | "sadalp v28.4s, v12.8h\n" |
271 | "smull v12.8h, v0.8b, v5.8b\n" |
272 | "sadalp v29.4s, v13.8h\n" |
273 | "smull v13.8h, v1.8b, v5.8b\n" |
274 | "sadalp v30.4s, v14.8h\n" |
275 | "smull v14.8h, v2.8b, v5.8b\n" |
276 | "sadalp v31.4s, v15.8h\n" |
277 | "smull v15.8h, v3.8b, v5.8b\n" |
278 | |
279 | // Multiply-accumulate second-half, again into the same |
280 | // 16bit local accumulator registers. This is where we |
281 | // take advantage of having int8 instead of uint8 and therefore |
282 | // being able to accumulate two products into int16. |
283 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
284 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
285 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
286 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
287 | |
288 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
289 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
290 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
291 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
292 | |
293 | |
294 | |
295 | // Each iteration of this loop advances by 16 levels of depth. |
296 | "add w1, w1, #16\n" |
297 | |
298 | // Loop termination condition |
299 | "cmp w1, w12\n" |
300 | |
301 | "blt 2b\n" |
302 | |
303 | "79:\n" |
304 | |
305 | "sadalp v16.4s, v8.8h\n" |
306 | "smull v8.8h, v0.8b, v6.8b\n" |
307 | "sadalp v17.4s, v9.8h\n" |
308 | "smull v9.8h, v1.8b, v6.8b\n" |
309 | "sadalp v18.4s, v10.8h\n" |
310 | "smull v10.8h, v2.8b, v6.8b\n" |
311 | "sadalp v19.4s, v11.8h\n" |
312 | "smull v11.8h, v3.8b, v6.8b\n" |
313 | "sadalp v20.4s, v12.8h\n" |
314 | "smull v12.8h, v0.8b, v7.8b\n" |
315 | "sadalp v21.4s, v13.8h\n" |
316 | "smull v13.8h, v1.8b, v7.8b\n" |
317 | "sadalp v22.4s, v14.8h\n" |
318 | "smull v14.8h, v2.8b, v7.8b\n" |
319 | "sadalp v23.4s, v15.8h\n" |
320 | "smull v15.8h, v3.8b, v7.8b\n" |
321 | |
322 | // Multiply-accumulate second-half, again into the same |
323 | // 16bit local accumulator registers. This is where we |
324 | // take advantage of having int8 instead of uint8 and therefore |
325 | // being able to accumulate two products into int16. |
326 | "smlal2 v8.8h, v0.16b, v6.16b\n" |
327 | "smlal2 v9.8h, v1.16b, v6.16b\n" |
328 | "smlal2 v10.8h, v2.16b, v6.16b\n" |
329 | "smlal2 v11.8h, v3.16b, v6.16b\n" |
330 | |
331 | "smlal2 v12.8h, v0.16b, v7.16b\n" |
332 | "smlal2 v13.8h, v1.16b, v7.16b\n" |
333 | "smlal2 v14.8h, v2.16b, v7.16b\n" |
334 | "smlal2 v15.8h, v3.16b, v7.16b\n" |
335 | |
336 | "sadalp v24.4s, v8.8h\n" |
337 | "sadalp v25.4s, v9.8h\n" |
338 | "sadalp v26.4s, v10.8h\n" |
339 | "sadalp v27.4s, v11.8h\n" |
340 | "sadalp v28.4s, v12.8h\n" |
341 | "sadalp v29.4s, v13.8h\n" |
342 | "sadalp v30.4s, v14.8h\n" |
343 | "sadalp v31.4s, v15.8h\n" |
344 | |
345 | // End of accumulation. The registers v16 -- v31 contain the final |
346 | // int32 accumulator values of the current 4x4 destination block. |
347 | // We now have to compute the final 8-bit values from these int32 |
348 | // accumulators, and advance to the next 4x4 block. We intertwine |
349 | // these two aspects whenever possible for optimal pipelining, both |
350 | // at the data flow level (prefetch data for next block as early as |
351 | // possible) and instruction pipelining level (some of the next-block |
352 | // work can dual-issue with some of the final work on the current |
353 | // block). |
354 | |
355 | // Reduce 32bit accumulators horizontally. |
356 | "addp v16.4s, v16.4s, v17.4s\n" |
357 | "addp v18.4s, v18.4s, v19.4s\n" |
358 | "addp v20.4s, v20.4s, v21.4s\n" |
359 | "addp v22.4s, v22.4s, v23.4s\n" |
360 | "addp v24.4s, v24.4s, v25.4s\n" |
361 | "addp v26.4s, v26.4s, v27.4s\n" |
362 | "addp v28.4s, v28.4s, v29.4s\n" |
363 | "addp v30.4s, v30.4s, v31.4s\n" |
364 | |
365 | // Reduce 32bit accumulators horizontally, second pass |
        // (each pass adds pairwise; we need to add 4-wise).
367 | "addp v16.4s, v16.4s, v18.4s\n" |
368 | "addp v17.4s, v20.4s, v22.4s\n" |
369 | "addp v18.4s, v24.4s, v26.4s\n" |
370 | "addp v19.4s, v28.4s, v30.4s\n" |
371 | |
372 | // Logic to advance to the next block in preparation for the next |
373 | // iteration of the main loop. For now, we only want to compute |
374 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
375 | // not yet ready to update the values of row and col, as we still need |
376 | // the current values for the rest of the work on the current block. |
377 | |
378 | "cmp %w[row], w7\n" // Have we finished the last row? |
379 | "bge 4f\n" // If finished last row, go to 4 |
380 | // Not finished last row: then advance to next row. |
381 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" |
382 | "b 5f\n" |
383 | "4:\n" // Finished last row... |
384 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
385 | // Now we need to advance to the next column. If we already |
386 | // finished the last column, then in principle we are done, however |
387 | // we can't just return here, as we need to allow the end work of the |
388 | // current block to complete. The good news is that at this point it |
389 | // doesn't matter what data we load for the next column, since |
390 | // we will exit from the main loop below before actually storing |
391 | // anything computed from that data. |
392 | "cmp %w[col], w8\n" // Have we finished the last column? |
393 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
394 | // Not finished last column: then advance to next column. |
395 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" |
396 | "5:\n" |
397 | |
398 | // Set the LHS and RHS data pointers to the start of the columns just |
399 | // computed. |
400 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
401 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
402 | |
403 | // Load some parameters needed for the end work on current block. |
404 | "mvni v8.4s, #0\n" |
405 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
406 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
407 | "ins v13.h[4], w4\n" // dst_zero_point |
408 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
409 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
410 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
411 | |
412 | // Now we load: bias data, LHS sums data, RHS sums data. |
413 | |
414 | // First, load the base pointers from the params. |
415 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
416 | |
417 | // Determine the channel index. |
418 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
419 | "csel w3, %w[row], %w[col], eq\n" |
420 | |
421 | // Offset the bias pointer as needed given the current row, col. |
422 | "add x5, x1, x3, lsl #2\n" |
423 | |
424 | // If there is no bias, use no offset, just address the passed zero |
425 | // data. |
426 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
427 | "csel x1, x1, x5, eq\n" |
428 | |
429 | // Load 4 bias values. |
430 | "ld1 {v14.4s}, [x1]\n" |
431 | |
432 | // Load the multiplier_fixedpoint values. |
433 | "add x5, x4, x3, lsl #2\n" |
434 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
435 | "csel x4, x4, x5, eq\n" |
436 | "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint |
437 | |
        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 64 bytes of
        // each of LHS and RHS, into v0 -- v7, as we don't need v0 -- v7
        // anymore in the rest of the work on the current block.
442 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
443 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
444 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
445 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
446 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
447 | "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" |
448 | "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" |
449 | "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" |
450 | |
451 | // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), |
452 | // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
453 | "add v14.4s, v14.4s, v9.4s\n" |
454 | |
455 | // Perform the bias-addition (per the above, we have just folded into |
456 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
457 | // Jump based on channel dimension. |
458 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
459 | "bne 6f\n" |
460 | // Case where channels are rows |
461 | "add v16.4s, v16.4s, v14.4s\n" |
462 | "add v17.4s, v17.4s, v14.4s\n" |
463 | "add v18.4s, v18.4s, v14.4s\n" |
464 | "add v19.4s, v19.4s, v14.4s\n" |
465 | "b 7f\n" |
466 | |
467 | "6:\n" |
468 | // Case where channels are columns |
469 | "dup v20.4s, v14.s[0]\n" |
470 | "dup v21.4s, v14.s[1]\n" |
471 | "dup v22.4s, v14.s[2]\n" |
472 | "dup v23.4s, v14.s[3]\n" |
473 | "add v16.4s, v16.4s, v20.4s\n" |
474 | "add v17.4s, v17.4s, v21.4s\n" |
475 | "add v18.4s, v18.4s, v22.4s\n" |
476 | "add v19.4s, v19.4s, v23.4s\n" |
477 | "7:\n" |
478 | |
479 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
480 | "beq 401f\n" |
481 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
482 | "add x3, x3, %x[col], lsl #2\n" |
483 | "ld1 {v14.4s}, [x3]\n" |
484 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
485 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
486 | // Subtract rhs_sums * lhs_zero_point, per |
487 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
488 | "mls v16.4s, v10.4s, v14.s[0]\n" |
489 | "mls v17.4s, v10.4s, v14.s[1]\n" |
490 | "mls v18.4s, v10.4s, v14.s[2]\n" |
491 | "mls v19.4s, v10.4s, v14.s[3]\n" |
492 | "401:\n" |
493 | |
494 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
495 | "beq 402f\n" |
496 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
497 | "add x2, x2, %x[row], lsl #2\n" |
498 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
499 | // Load 4 lhs_sums values. |
500 | "ld1 {v11.4s}, [x2]\n" |
501 | "ins v13.s[1], w5\n" // rhs_zero_point |
502 | // Compute lhs_sums * rhs_zero_point. |
503 | "mul v11.4s, v11.4s, v13.s[1]\n" |
504 | // Subtract lhs_sums * rhs_zero_point, per |
505 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
506 | "sub v16.4s, v16.4s, v11.4s\n" |
507 | "sub v17.4s, v17.4s, v11.4s\n" |
508 | "sub v18.4s, v18.4s, v11.4s\n" |
509 | "sub v19.4s, v19.4s, v11.4s\n" |
510 | |
511 | // If the destination is int32, it means the user asks for the raw |
512 | // accumulators, no need for us to downquantize the value. |
513 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
514 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
515 | |
516 | "402:\n" |
517 | |
518 | // At this point we have computed the final int32 values. Now we |
519 | // start down-quantizing them to obtain the final 8bit values from them. |
520 | |
521 | // As part of this down-quantization, our int32 values will be |
522 | // multiplied by a multiplier that has a fixed-point component and an |
523 | // exponent component. |
524 | |
        // Load the exponent part of the multiplier.
526 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
527 | // Determine the channel index. |
528 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
529 | "csel w3, %w[row], %w[col], eq\n" |
530 | |
531 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
532 | "add x5, x1, x3, lsl #2\n" |
533 | "csel x1, x1, x5, eq\n" |
534 | |
535 | "ld1 {v14.4s}, [x1]\n" |
536 | |
537 | "smin v11.4s, v8.4s, v14.4s\n" |
538 | "sub v12.4s, v14.4s, v11.4s\n" |
539 | |
540 | // Jump based on channel dimension. |
541 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
542 | "bne 8f\n" |
543 | // Case where channels are rows |
544 | |
545 | // Apply the positive exponent part of the multiplier. |
546 | "sshl v16.4s, v16.4s, v12.4s\n" |
547 | "sshl v17.4s, v17.4s, v12.4s\n" |
548 | "sshl v18.4s, v18.4s, v12.4s\n" |
549 | "sshl v19.4s, v19.4s, v12.4s\n" |
550 | |
551 | // Apply the fixed-point part of the multiplier. |
552 | "sqdmulh v16.4s, v16.4s, v15.4s\n" |
553 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
554 | "sqdmulh v18.4s, v18.4s, v15.4s\n" |
555 | "sqdmulh v19.4s, v19.4s, v15.4s\n" |
556 | |
557 | // Apply the negative exponent part of the multiplier. |
558 | "srshl v16.4s, v16.4s, v11.4s\n" |
559 | "srshl v17.4s, v17.4s, v11.4s\n" |
560 | "srshl v18.4s, v18.4s, v11.4s\n" |
561 | "srshl v19.4s, v19.4s, v11.4s\n" |
562 | "b 9f\n" |
563 | |
564 | "8:\n" |
565 | // Case where channels are columns |
566 | |
567 | // Apply the positive exponent part of the multiplier. |
568 | "dup v20.4s, v12.s[0]\n" |
569 | "dup v21.4s, v12.s[1]\n" |
570 | "dup v22.4s, v12.s[2]\n" |
571 | "dup v23.4s, v12.s[3]\n" |
572 | "sshl v16.4s, v16.4s, v20.4s\n" |
573 | "sshl v17.4s, v17.4s, v21.4s\n" |
574 | "sshl v18.4s, v18.4s, v22.4s\n" |
575 | "sshl v19.4s, v19.4s, v23.4s\n" |
576 | |
577 | // Apply the fixed-point part of the multiplier. |
578 | "sqdmulh v16.4s, v16.4s, v15.s[0]\n" |
579 | "sqdmulh v17.4s, v17.4s, v15.s[1]\n" |
580 | "sqdmulh v18.4s, v18.4s, v15.s[2]\n" |
581 | "sqdmulh v19.4s, v19.4s, v15.s[3]\n" |
582 | |
583 | // Apply the negative exponent part of the multiplier. |
584 | "dup v20.4s, v11.s[0]\n" |
585 | "dup v21.4s, v11.s[1]\n" |
586 | "dup v22.4s, v11.s[2]\n" |
587 | "dup v23.4s, v11.s[3]\n" |
588 | "srshl v16.4s, v16.4s, v20.4s\n" |
589 | "srshl v17.4s, v17.4s, v21.4s\n" |
590 | "srshl v18.4s, v18.4s, v22.4s\n" |
591 | "srshl v19.4s, v19.4s, v23.4s\n" |
592 | "9:\n" |
593 | |
594 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
595 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
596 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
597 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
598 | |
599 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
600 | |
601 | // Cast-and-saturate from int32 to int16 |
602 | "sqxtn v16.4h, v16.4s\n" |
603 | "sqxtn2 v16.8h, v17.4s\n" |
604 | "sqxtn v17.4h, v18.4s\n" |
605 | "sqxtn2 v17.8h, v19.4s\n" |
606 | |
607 | // At this point, v18 -- v31 aren't used anymore for the current block, |
608 | // so we can start clearing these accumulators for the next block |
609 | // (next iteration of the main loop). |
610 | RUY_MAKE_ZERO(v18) |
611 | RUY_MAKE_ZERO(v19) |
612 | RUY_MAKE_ZERO(v20) |
613 | RUY_MAKE_ZERO(v21) |
614 | RUY_MAKE_ZERO(v22) |
615 | RUY_MAKE_ZERO(v23) |
616 | RUY_MAKE_ZERO(v24) |
617 | RUY_MAKE_ZERO(v25) |
618 | RUY_MAKE_ZERO(v26) |
619 | RUY_MAKE_ZERO(v27) |
620 | RUY_MAKE_ZERO(v28) |
621 | RUY_MAKE_ZERO(v29) |
622 | RUY_MAKE_ZERO(v30) |
623 | RUY_MAKE_ZERO(v31) |
624 | |
625 | // Add the destination zero point |
626 | "dup v14.8h, v13.h[4]\n" |
627 | "sqadd v16.8h, v16.8h, v14.8h\n" |
628 | "sqadd v17.8h, v17.8h, v14.8h\n" |
629 | |
630 | // Cast-and-saturate from int16 to uint8 |
631 | "sqxtun v16.8b, v16.8h\n" |
632 | "sqxtun2 v16.16b, v17.8h\n" |
633 | |
634 | // Load the clamp_min, clamp_max bounds |
635 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
636 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
637 | "dup v14.16b, w2\n" // clamp_min |
638 | "dup v15.16b, w3\n" // clamp_max |
639 | |
640 | // Apply the clamp_min bound |
641 | "umax v16.16b, v16.16b, v14.16b\n" |
642 | // Apply the clamp_max bound |
643 | "umin v16.16b, v16.16b, v15.16b\n" |
644 | |
        // Compute how much of the 4x4 block of destination 8bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
650 | "sub w1, %w[dst_rows], %w[row]\n" |
651 | "sub w2, %w[dst_cols], %w[col]\n" |
652 | "mov w3, #4\n" |
653 | "cmp w1, #4\n" |
654 | // Compute w1 = how many rows of the 4x4 block fit |
655 | "csel w1, w1, w3, le\n" |
656 | "cmp w2, #4\n" |
657 | // Compute w2 = how many cols of the 4x4 block fit |
658 | "csel w2, w2, w3, le\n" |
659 | |
660 | // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. |
661 | "cmp w1, w3\n" |
662 | "ccmp w2, w3, 0, eq\n" |
663 | "mov x4, %[dst_ptr]\n" |
664 | // Yes, all of the 4x4 block fits, go to fast path. |
665 | "beq 30f\n" |
666 | // Not all of the 4x4 block fits. |
667 | // Store to dst_tmp_buf |
668 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
669 | // Slow loop copying from dst_tmp_buf to dst. |
670 | "mov x3, %[dst_tmp_buf]\n" |
671 | "mov w6, #0\n" |
672 | "50:\n" |
673 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
674 | "mov w5, #0\n" |
675 | "51:\n" |
676 | "ldrb w7, [x3, w5, uxtw]\n" |
677 | "strb w7, [x4, w5, uxtw]\n" |
678 | "add w5, w5, #1\n" |
679 | "cmp w5, w1\n" |
680 | "blt 51b\n" |
681 | "add w6, w6, #1\n" |
682 | "add x3, x3, #4\n" |
683 | "add x4, x4, x11\n" |
684 | "cmp w6, w2\n" |
685 | "blt 50b\n" |
686 | "b 31f\n" |
687 | "30:\n" |
688 | // Yes, all of the 4x4 block fits. |
689 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
690 | "mov x3, x4\n" |
691 | "st1 {v16.b}[0], [x3], #1\n" |
692 | "add x4, x4, x11\n" |
693 | "st1 {v16.b}[1], [x3], #1\n" |
694 | "st1 {v16.b}[2], [x3], #1\n" |
695 | "st1 {v16.b}[3], [x3], #1\n" |
696 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
697 | "mov x3, x4\n" |
698 | "st1 {v16.b}[4], [x3], #1\n" |
699 | "add x4, x4, x11\n" |
700 | "st1 {v16.b}[5], [x3], #1\n" |
701 | "st1 {v16.b}[6], [x3], #1\n" |
702 | "st1 {v16.b}[7], [x3], #1\n" |
703 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
704 | "mov x3, x4\n" |
705 | "st1 {v16.b}[8], [x3], #1\n" |
706 | "add x4, x4, x11\n" |
707 | "st1 {v16.b}[9], [x3], #1\n" |
708 | "st1 {v16.b}[10], [x3], #1\n" |
709 | "st1 {v16.b}[11], [x3], #1\n" |
710 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
711 | "mov x3, x4\n" |
712 | "st1 {v16.b}[12], [x3], #1\n" |
713 | "add x4, x4, x11\n" |
714 | "st1 {v16.b}[13], [x3], #1\n" |
715 | "st1 {v16.b}[14], [x3], #1\n" |
716 | "st1 {v16.b}[15], [x3], #1\n" |
717 | "31:\n" |
718 | |
719 | "add %[dst_ptr], %[dst_ptr], #4\n" |
720 | |
721 | RUY_MAKE_ZERO(v16) |
722 | RUY_MAKE_ZERO(v17) |
723 | |
724 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
725 | |
726 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
727 | |
728 | // Cast-and-saturate from int32 to int16 |
729 | "sqxtn v16.4h, v16.4s\n" |
730 | "sqxtn2 v16.8h, v17.4s\n" |
731 | "sqxtn v17.4h, v18.4s\n" |
732 | "sqxtn2 v17.8h, v19.4s\n" |
733 | |
734 | // At this point, v18 -- v31 aren't used anymore for the current block, |
735 | // so we can start clearing these accumulators for the next block |
736 | // (next iteration of the main loop). |
737 | RUY_MAKE_ZERO(v18) |
738 | RUY_MAKE_ZERO(v19) |
739 | RUY_MAKE_ZERO(v20) |
740 | RUY_MAKE_ZERO(v21) |
741 | RUY_MAKE_ZERO(v22) |
742 | RUY_MAKE_ZERO(v23) |
743 | RUY_MAKE_ZERO(v24) |
744 | RUY_MAKE_ZERO(v25) |
745 | RUY_MAKE_ZERO(v26) |
746 | RUY_MAKE_ZERO(v27) |
747 | RUY_MAKE_ZERO(v28) |
748 | RUY_MAKE_ZERO(v29) |
749 | RUY_MAKE_ZERO(v30) |
750 | RUY_MAKE_ZERO(v31) |
751 | |
752 | // Add the destination zero point |
753 | "dup v14.8h, v13.h[4]\n" |
754 | "sqadd v16.8h, v16.8h, v14.8h\n" |
755 | "sqadd v17.8h, v17.8h, v14.8h\n" |
756 | |
757 | // Cast-and-saturate from int16 to int8 |
758 | "sqxtn v16.8b, v16.8h\n" |
759 | "sqxtn2 v16.16b, v17.8h\n" |
760 | |
761 | // Load the clamp_min, clamp_max bounds |
762 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
763 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
764 | "dup v14.16b, w2\n" // clamp_min |
765 | "dup v15.16b, w3\n" // clamp_max |
766 | |
767 | // Apply the clamp_min bound |
768 | "smax v16.16b, v16.16b, v14.16b\n" |
769 | // Apply the clamp_max bound |
770 | "smin v16.16b, v16.16b, v15.16b\n" |
771 | |
        // Compute how much of the 4x4 block of destination 8bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
777 | "sub w1, %w[dst_rows], %w[row]\n" |
778 | "sub w2, %w[dst_cols], %w[col]\n" |
779 | "mov w3, #4\n" |
780 | "cmp w1, #4\n" |
781 | // Compute w1 = how many rows of the 4x4 block fit |
782 | "csel w1, w1, w3, le\n" |
783 | "cmp w2, #4\n" |
784 | // Compute w2 = how many cols of the 4x4 block fit |
785 | "csel w2, w2, w3, le\n" |
786 | |
787 | // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. |
788 | "cmp w1, w3\n" |
789 | "ccmp w2, w3, 0, eq\n" |
790 | "mov x4, %[dst_ptr]\n" |
791 | // Yes, all of the 4x4 block fits, go to fast path. |
792 | "beq 30f\n" |
793 | // Not all of the 4x4 block fits. |
794 | // Store to dst_tmp_buf |
795 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
796 | // Slow loop copying from dst_tmp_buf to dst. |
797 | "mov x3, %[dst_tmp_buf]\n" |
798 | "mov w6, #0\n" |
799 | "50:\n" |
800 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
801 | "mov w5, #0\n" |
802 | "51:\n" |
803 | "ldrb w7, [x3, w5, uxtw]\n" |
804 | "strb w7, [x4, w5, uxtw]\n" |
805 | "add w5, w5, #1\n" |
806 | "cmp w5, w1\n" |
807 | "blt 51b\n" |
808 | "add w6, w6, #1\n" |
809 | "add x3, x3, #4\n" |
810 | "add x4, x4, x11\n" |
811 | "cmp w6, w2\n" |
812 | "blt 50b\n" |
813 | "b 31f\n" |
814 | "30:\n" |
815 | // Yes, all of the 4x4 block fits. |
816 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
817 | "mov x3, x4\n" |
818 | "st1 {v16.b}[0], [x3], #1\n" |
819 | "add x4, x4, x11\n" |
820 | "st1 {v16.b}[1], [x3], #1\n" |
821 | "st1 {v16.b}[2], [x3], #1\n" |
822 | "st1 {v16.b}[3], [x3], #1\n" |
823 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
824 | "mov x3, x4\n" |
825 | "st1 {v16.b}[4], [x3], #1\n" |
826 | "add x4, x4, x11\n" |
827 | "st1 {v16.b}[5], [x3], #1\n" |
828 | "st1 {v16.b}[6], [x3], #1\n" |
829 | "st1 {v16.b}[7], [x3], #1\n" |
830 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
831 | "mov x3, x4\n" |
832 | "st1 {v16.b}[8], [x3], #1\n" |
833 | "add x4, x4, x11\n" |
834 | "st1 {v16.b}[9], [x3], #1\n" |
835 | "st1 {v16.b}[10], [x3], #1\n" |
836 | "st1 {v16.b}[11], [x3], #1\n" |
837 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
838 | "mov x3, x4\n" |
839 | "st1 {v16.b}[12], [x3], #1\n" |
840 | "add x4, x4, x11\n" |
841 | "st1 {v16.b}[13], [x3], #1\n" |
842 | "st1 {v16.b}[14], [x3], #1\n" |
843 | "st1 {v16.b}[15], [x3], #1\n" |
844 | "31:\n" |
845 | |
846 | "add %[dst_ptr], %[dst_ptr], #4\n" |
847 | |
848 | RUY_MAKE_ZERO(v16) |
849 | RUY_MAKE_ZERO(v17) |
850 | |
851 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
852 | |
853 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
854 | |
855 | // Add the destination zero point |
856 | "dup v14.4h, v13.h[4]\n" |
857 | "saddw v16.4s, v16.4s, v14.4h\n" |
858 | "saddw v17.4s, v17.4s, v14.4h\n" |
859 | "saddw v18.4s, v18.4s, v14.4h\n" |
860 | "saddw v19.4s, v19.4s, v14.4h\n" |
861 | |
862 | // Cast-and-saturate from int32 to int16 |
863 | "sqxtn v16.4h, v16.4s\n" |
864 | "sqxtn2 v16.8h, v17.4s\n" |
865 | "sqxtn v17.4h, v18.4s\n" |
866 | "sqxtn2 v17.8h, v19.4s\n" |
867 | |
868 | // At this point, v18 -- v31 aren't used anymore for the current block, |
869 | // so we can start clearing these accumulators for the next block |
870 | // (next iteration of the main loop). |
871 | RUY_MAKE_ZERO(v18) |
872 | RUY_MAKE_ZERO(v19) |
873 | RUY_MAKE_ZERO(v20) |
874 | RUY_MAKE_ZERO(v21) |
875 | RUY_MAKE_ZERO(v22) |
876 | RUY_MAKE_ZERO(v23) |
877 | RUY_MAKE_ZERO(v24) |
878 | RUY_MAKE_ZERO(v25) |
879 | RUY_MAKE_ZERO(v26) |
880 | RUY_MAKE_ZERO(v27) |
881 | RUY_MAKE_ZERO(v28) |
882 | RUY_MAKE_ZERO(v29) |
883 | RUY_MAKE_ZERO(v30) |
884 | RUY_MAKE_ZERO(v31) |
885 | |
886 | // Load the clamp_min, clamp_max bounds |
887 | "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
888 | "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
889 | "dup v14.8h, w2\n" // clamp_min |
890 | "dup v15.8h, w3\n" // clamp_max |
891 | |
892 | // Apply the clamp_min bound |
893 | "smax v16.8h, v16.8h, v14.8h\n" |
894 | "smax v17.8h, v17.8h, v14.8h\n" |
895 | // Apply the clamp_max bound |
896 | "smin v16.8h, v16.8h, v15.8h\n" |
897 | "smin v17.8h, v17.8h, v15.8h\n" |
898 | |
        // Compute how much of the 4x4 block of destination 16bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
904 | "sub w1, %w[dst_rows], %w[row]\n" |
905 | "sub w2, %w[dst_cols], %w[col]\n" |
906 | "mov w3, #4\n" |
907 | "cmp w1, #4\n" |
908 | // Compute w1 = how many rows of the 4x4 block fit |
909 | "csel w1, w1, w3, le\n" |
910 | "cmp w2, #4\n" |
911 | // Compute w2 = how many cols of the 4x4 block fit |
912 | "csel w2, w2, w3, le\n" |
913 | |
        // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
915 | "cmp w1, w3\n" |
916 | "ccmp w2, w3, 0, eq\n" |
917 | "mov x4, %[dst_ptr]\n" |
918 | // Yes, all of the 4x4 block fits, go to fast path. |
919 | "beq 30f\n" |
920 | // Not all of the 4x4 block fits. |
921 | // Store to dst_tmp_buf |
922 | "str q16, [%[dst_tmp_buf], #0]\n" |
923 | "str q17, [%[dst_tmp_buf], #16]\n" |
924 | // Slow loop copying from dst_tmp_buf to dst. |
925 | "mov x3, %[dst_tmp_buf]\n" |
926 | "mov w6, #0\n" |
927 | "50:\n" |
928 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
929 | "mov w5, #0\n" |
930 | "51:\n" |
931 | "ldrh w7, [x3, x5, lsl #1]\n" |
932 | "strh w7, [x4, x5, lsl #1]\n" |
933 | "add w5, w5, #1\n" |
934 | "cmp w5, w1\n" |
935 | "blt 51b\n" |
936 | "add w6, w6, #1\n" |
937 | "add x3, x3, #8\n" |
938 | "add x4, x4, x11\n" |
939 | "cmp w6, w2\n" |
940 | "blt 50b\n" |
941 | "b 31f\n" |
942 | "30:\n" |
943 | // Yes, all of the 4x4 block fits. |
944 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
945 | "mov x3, x4\n" |
946 | "st1 {v16.h}[0], [x3], #2\n" |
947 | "add x4, x4, x11\n" |
948 | "st1 {v16.h}[1], [x3], #2\n" |
949 | "st1 {v16.h}[2], [x3], #2\n" |
950 | "st1 {v16.h}[3], [x3], #2\n" |
951 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
952 | "mov x3, x4\n" |
953 | "st1 {v16.h}[4], [x3], #2\n" |
954 | "add x4, x4, x11\n" |
955 | "st1 {v16.h}[5], [x3], #2\n" |
956 | "st1 {v16.h}[6], [x3], #2\n" |
957 | "st1 {v16.h}[7], [x3], #2\n" |
958 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
959 | "mov x3, x4\n" |
960 | "st1 {v17.h}[0], [x3], #2\n" |
961 | "add x4, x4, x11\n" |
962 | "st1 {v17.h}[1], [x3], #2\n" |
963 | "st1 {v17.h}[2], [x3], #2\n" |
964 | "st1 {v17.h}[3], [x3], #2\n" |
965 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
966 | "mov x3, x4\n" |
967 | "st1 {v17.h}[4], [x3], #2\n" |
968 | "add x4, x4, x11\n" |
969 | "st1 {v17.h}[5], [x3], #2\n" |
970 | "st1 {v17.h}[6], [x3], #2\n" |
971 | "st1 {v17.h}[7], [x3], #2\n" |
972 | "31:\n" |
973 | |
974 | "add %[dst_ptr], %[dst_ptr], #8\n" |
975 | |
976 | RUY_MAKE_ZERO(v16) |
977 | RUY_MAKE_ZERO(v17) |
978 | |
979 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
980 | |
981 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
982 | |
983 | // Since the store type is the same as the accum type, no need for |
984 | // downcast. There's also no need for clamp by min/max. |
985 | |
986 | // At this point, v20 -- v31 aren't used anymore for the current block, |
987 | // so we can start clearing these accumulators for the next block |
988 | // (next iteration of the main loop). |
989 | RUY_MAKE_ZERO(v20) |
990 | RUY_MAKE_ZERO(v21) |
991 | RUY_MAKE_ZERO(v22) |
992 | RUY_MAKE_ZERO(v23) |
993 | RUY_MAKE_ZERO(v24) |
994 | RUY_MAKE_ZERO(v25) |
995 | RUY_MAKE_ZERO(v26) |
996 | RUY_MAKE_ZERO(v27) |
997 | RUY_MAKE_ZERO(v28) |
998 | RUY_MAKE_ZERO(v29) |
999 | RUY_MAKE_ZERO(v30) |
1000 | RUY_MAKE_ZERO(v31) |
1001 | |
        // Compute how much of the 4x4 block of destination 32bit values that
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 4x4, there are some 4x4 blocks along the boundaries that do
        // not fit entirely.
1007 | "sub w1, %w[dst_rows], %w[row]\n" |
1008 | "sub w2, %w[dst_cols], %w[col]\n" |
1009 | "mov w3, #4\n" |
1010 | "cmp w1, #4\n" |
1011 | // Compute w1 = how many rows of the 4x4 block fit |
1012 | "csel w1, w1, w3, le\n" |
1013 | "cmp w2, #4\n" |
1014 | // Compute w2 = how many cols of the 4x4 block fit |
1015 | "csel w2, w2, w3, le\n" |
1016 | |
        // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
1018 | "cmp w1, w3\n" |
1019 | "ccmp w2, w3, 0, eq\n" |
1020 | "mov x4, %[dst_ptr]\n" |
1021 | // Yes, all of the 4x4 block fits, go to fast path. |
1022 | "beq 30f\n" |
1023 | // Not all of the 4x4 block fits. |
1024 | // Store to dst_tmp_buf |
1025 | "str q16, [%[dst_tmp_buf], #0]\n" |
1026 | "str q17, [%[dst_tmp_buf], #16]\n" |
1027 | "str q18, [%[dst_tmp_buf], #32]\n" |
1028 | "str q19, [%[dst_tmp_buf], #48]\n" |
1029 | // Slow loop copying from dst_tmp_buf to dst. |
1030 | "mov x3, %[dst_tmp_buf]\n" |
1031 | "mov w6, #0\n" |
1032 | "50:\n" |
1033 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1034 | "mov w5, #0\n" |
1035 | "51:\n" |
1036 | "ldr w7, [x3, x5, lsl #2]\n" |
1037 | "str w7, [x4, x5, lsl #2]\n" |
1038 | "add w5, w5, #1\n" |
1039 | "cmp w5, w1\n" |
1040 | "blt 51b\n" |
1041 | "add w6, w6, #1\n" |
1042 | "add x3, x3, #16\n" |
1043 | "add x4, x4, x11\n" |
1044 | "cmp w6, w2\n" |
1045 | "blt 50b\n" |
1046 | "b 31f\n" |
1047 | "30:\n" |
1048 | // Yes, all of the 4x4 block fits. |
1049 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1050 | "mov x3, x4\n" |
1051 | "st1 {v16.s}[0], [x3], #4\n" |
1052 | "add x4, x4, x11\n" |
1053 | "st1 {v16.s}[1], [x3], #4\n" |
1054 | "st1 {v16.s}[2], [x3], #4\n" |
1055 | "st1 {v16.s}[3], [x3], #4\n" |
1056 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1057 | "mov x3, x4\n" |
1058 | "st1 {v17.s}[0], [x3], #4\n" |
1059 | "add x4, x4, x11\n" |
1060 | "st1 {v17.s}[1], [x3], #4\n" |
1061 | "st1 {v17.s}[2], [x3], #4\n" |
1062 | "st1 {v17.s}[3], [x3], #4\n" |
1063 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1064 | "mov x3, x4\n" |
1065 | "st1 {v18.s}[0], [x3], #4\n" |
1066 | "add x4, x4, x11\n" |
1067 | "st1 {v18.s}[1], [x3], #4\n" |
1068 | "st1 {v18.s}[2], [x3], #4\n" |
1069 | "st1 {v18.s}[3], [x3], #4\n" |
1070 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1071 | "mov x3, x4\n" |
1072 | "st1 {v19.s}[0], [x3], #4\n" |
1073 | "add x4, x4, x11\n" |
1074 | "st1 {v19.s}[1], [x3], #4\n" |
1075 | "st1 {v19.s}[2], [x3], #4\n" |
1076 | "st1 {v19.s}[3], [x3], #4\n" |
1077 | "31:\n" |
1078 | |
1079 | "add %[dst_ptr], %[dst_ptr], #16\n" |
1080 | |
1081 | RUY_MAKE_ZERO(v16) |
1082 | RUY_MAKE_ZERO(v17) |
1083 | RUY_MAKE_ZERO(v18) |
1084 | RUY_MAKE_ZERO(v19) |
1085 | |
1086 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
1087 | |
1088 | // For the next block: perform the first few multiply-adds on the data |
1089 | // that we have already loaded. |
1090 | "smull v8.8h, v0.8b, v4.8b\n" |
1091 | "smull v9.8h, v1.8b, v4.8b\n" |
1092 | "smull v10.8h, v2.8b, v4.8b\n" |
1093 | "smull v11.8h, v3.8b, v4.8b\n" |
1094 | "smull v12.8h, v0.8b, v5.8b\n" |
1095 | "smull v13.8h, v1.8b, v5.8b\n" |
1096 | "smull v14.8h, v2.8b, v5.8b\n" |
1097 | "smull v15.8h, v3.8b, v5.8b\n" |
1098 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
1099 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
1100 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
1101 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
1102 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
1103 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
1104 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
1105 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
1106 | |
1107 | // Reload some params --- we had used x5 -- x7 for a few other things |
1108 | // since the last time we had loaded them. |
1109 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1110 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1111 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1112 | |
1113 | // Move to the next block of the destination matrix, for the next iter |
1114 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
1115 | // been updated earlier. |
1116 | // Have we reached the end row? |
1117 | "cmp %w[row], w7\n" |
1118 | "beq 20f\n" // yes, end row. |
1119 | // Not end row. Move to the next row. |
1120 | "add %w[row], %w[row], #4\n" |
1121 | "b 21f\n" |
1122 | "20:\n" |
1123 | // Was already at end row. |
1124 | "mov %w[row], w6\n" // Move back to first row. |
1125 | "add %w[col], %w[col], #4\n" // Move to the next column. |
1126 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" |
1127 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
1128 | "21:\n" |
1129 | |
1130 | // Main loop exit condition: have we hit the end column? |
1131 | "cmp %w[col], w8\n" |
1132 | |
        // w1 is the number of levels of depth that we have already loaded
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently 16.
1136 | "mov w1, #16\n" |
1137 | |
1138 | "ble 1b\n" |
1139 | |
1140 | // clang-format on |
1141 | |
      : [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
        [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
        [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr),
        [row] "+r"(row), [col] "+r"(col)
      : [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
        [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
        [dst_type_id] "r"(params.dst_type_id)
      : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11",
        "x12", "x13", "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
1152 | } |
1153 | |
1154 | // Similar to existing Kernel8bitNeon but specialized for the case of |
1155 | // RHS cols == 1. |
1156 | // Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, |
1157 | // since these are 64-bit, out-of-order and without dotprod support. |
1158 | void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) { |
1159 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
1160 | |
1161 | CheckOffsetsInKernelParams8bit(params); |
1162 | |
1163 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
  const std::int8_t* rhs_col_ptr =
      static_cast<const std::int8_t*>(params.rhs_base_ptr);
1166 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
1167 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
1168 | void* dst_col_ptr = params.dst_base_ptr; |
1169 | void* dst_ptr = dst_col_ptr; |
1170 | int row = params.start_row; |
1171 | int col = params.start_col; |
1172 | |
1173 | RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); |
1174 | |
1175 | // The asm kernel below has the following NEON register allocation: |
1176 | // |
1177 | // v16 -- v19 are int32 accumulators. |
1178 | // During accumulation, v0 -- v3 are used to load int8 data from LHS and |
1179 | // v4 from RHS: |
1180 | // |
1181 | // int8 RHS 16x1 block |
1182 | // /-----------| |
1183 | // |v4.b[0] | |
1184 | // | ... | |
1185 | // |v4.b[15] | |
1186 | // \-----------/ |
1187 | // int8 LHS 4x16 block |
1188 | // /---------------------\ /-----------| |
1189 | // |v0.b[0] ... v0.b[15] | |v16.4s | |
1190 | // |v1.b[0] ... v1.b[15] | |v17.4s | |
1191 | // |v2.b[0] ... v2.b[15] | |v18.4s | |
1192 | // |v3.b[0] ... v3.b[15] | |v19.4s | |
1193 | // \---------------------/ \-----------/ |
1194 | // int32 accumulators 4x1 block |
1195 | // |
  // No attempt has been made so far to implement the RUY_OPT_MAX_STREAMING
  // optimization for this kernel.
1198 | asm volatile( |
1199 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
1200 | |
1201 | // clang-format off |
1202 | |
1203 | // Load some parameters into registers. |
1204 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1205 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1206 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1207 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
1208 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
1209 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
1210 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1211 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
1212 | |
        // Load the first 64 bytes of LHS data and the first 16 bytes of RHS
        // data.
1214 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
1215 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
1216 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
1217 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
1218 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
1219 | "add %[rhs_ptr], %[rhs_ptr], #48\n" |
1220 | |
1221 | // Clear accumulators. |
1222 | RUY_MAKE_ZERO(v16) |
1223 | RUY_MAKE_ZERO(v17) |
1224 | RUY_MAKE_ZERO(v18) |
1225 | RUY_MAKE_ZERO(v19) |
1226 | |
1227 | // w1 is the number of levels of depth that we have already loaded |
1228 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
1229 | // above, this is currently 16. |
1230 | "mov w1, #16\n" |
1231 | |
1232 | // Perform the first few multiply-adds on the data that we have already |
1233 | // loaded. |
1234 | "smull v8.8h, v0.8b, v4.8b\n" |
1235 | "smull v9.8h, v1.8b, v4.8b\n" |
1236 | "smull v10.8h, v2.8b, v4.8b\n" |
1237 | "smull v11.8h, v3.8b, v4.8b\n" |
1238 | |
1239 | // Multiply-accumulate second-half, again into the same |
1240 | // 16bit local accumulator registers. This is where we |
1241 | // take advantage of having int8 instead of uint8 and therefore |
1242 | // being able to accumulate two products into int16. |
1243 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
1244 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
1245 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
1246 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
1247 | |
1248 | // Main loop of the whole GEMM, over rows and columns of the |
1249 | // destination matrix. |
1250 | "1:\n" |
1251 | |
1252 | // Reminder - w1 is how many levels of depth we have already loaded |
1253 | // data for, w12 is the total depth. |
1254 | "cmp w1, w12\n" |
1255 | "beq 79f\n" |
1256 | |
1257 | "2:\n" |
1258 | |
1259 | // Some multiplications and 16-bit accumulation were already done above, |
1260 | // so we start right away in the middle. |
1261 | "sadalp v16.4s, v8.8h\n" |
1262 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
1263 | "add %[rhs_ptr], %[rhs_ptr], #48\n" |
1264 | "sadalp v17.4s, v9.8h\n" |
1265 | "sadalp v18.4s, v10.8h\n" |
1266 | "sadalp v19.4s, v11.8h\n" |
1267 | |
1268 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
1269 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
1270 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
1271 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
1272 | |
1273 | "smull v8.8h, v0.8b, v4.8b\n" |
1274 | "smull v9.8h, v1.8b, v4.8b\n" |
1275 | "smull v10.8h, v2.8b, v4.8b\n" |
1276 | "smull v11.8h, v3.8b, v4.8b\n" |
1277 | |
1278 | // Multiply-accumulate second-half, again into the same |
1279 | // 16bit local accumulator registers. This is where we |
1280 | // take advantage of having int8 instead of uint8 and therefore |
1281 | // being able to accumulate two products into int16. |
1282 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
1283 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
1284 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
1285 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
1286 | |
1287 | // Each iteration of this loop advances by 16 levels of depth. |
1288 | "add w1, w1, #16\n" |
1289 | |
1290 | // Loop termination condition |
1291 | "cmp w1, w12\n" |
1292 | |
1293 | "blt 2b\n" |
1294 | |
1295 | "79:\n" |
1296 | |
1297 | "sadalp v16.4s, v8.8h\n" |
1298 | "sadalp v17.4s, v9.8h\n" |
1299 | "sadalp v18.4s, v10.8h\n" |
1300 | "sadalp v19.4s, v11.8h\n" |
1301 | |
1302 | // End of accumulation. The registers v16 -- v19 contain the final |
1303 | // int32 accumulator values of the current 4x1 destination block. |
1304 | // We now have to compute the final 8-bit values from these int32 |
1305 | // accumulators, and advance to the next 4x1 block. We intertwine |
1306 | // these two aspects whenever possible for optimal pipelining, both |
1307 | // at the data flow level (prefetch data for next block as early as |
1308 | // possible) and instruction pipelining level (some of the next-block |
1309 | // work can dual-issue with some of the final work on the current |
1310 | // block). |
1311 | |
1312 | // Reduce 32bit accumulators horizontally. |
1313 | "addp v16.4s, v16.4s, v17.4s\n" |
1314 | "addp v18.4s, v18.4s, v19.4s\n" |
1315 | |
1316 | // Reduce 32bit accumulators horizontally, second pass |
        // (each pass adds pairwise; we need to add 4-wise).
1318 | "addp v16.4s, v16.4s, v18.4s\n" |
1319 | |
1320 | // Logic to advance to the next block in preparation for the next |
1321 | // iteration of the main loop. For now, we only want to compute |
1322 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
1323 | // not yet ready to update the values of row and col, as we still need |
1324 | // the current values for the rest of the work on the current block. |
1325 | |
1326 | "cmp %w[row], w7\n" // Have we finished the last row? |
1327 | "bge 4f\n" // If finished last row, go to 4 |
1328 | // Not finished last row: then advance to next row. |
1329 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" |
1330 | "b 5f\n" |
1331 | "4:\n" // Finished last row... |
1332 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
1333 | // Now we need to advance to the next column. If we already |
1334 | // finished the last column, then in principle we are done, however |
1335 | // we can't just return here, as we need to allow the end work of the |
1336 | // current block to complete. The good news is that at this point it |
1337 | // doesn't matter what data we load for the next column, since |
1338 | // we will exit from the main loop below before actually storing |
1339 | // anything computed from that data. |
1340 | "cmp %w[col], w8\n" // Have we finished the last column? |
1341 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
1342 | // Not finished last column: then advance to next column. |
1343 | // (still multiply column stride by 4 due to packing) |
1344 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" |
1345 | "5:\n" |
1346 | |
1347 | // Set the LHS and RHS data pointers to the start of the columns just |
1348 | // computed. |
1349 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
1350 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
1351 | |
1352 | // Load some parameters needed for the end work on current block. |
1353 | "mvni v8.4s, #0\n" |
1354 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
1355 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
1356 | "ins v13.h[4], w4\n" // dst_zero_point |
1357 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
1358 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
1359 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
1360 | "add x5, x4, %x[row], lsl #2\n" |
1361 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
1362 | "csel x4, x4, x5, eq\n" |
1363 | |
1364 | "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint |
1365 | |
1366 | // Now we load: bias data, LHS sums data, RHS sums data. |
1367 | |
1368 | // First, load the base pointers from the params. |
1369 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
1370 | |
1371 | "add x5, x1, %x[row], lsl #2\n" |
1372 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
1373 | "csel x1, x1, x5, eq\n" |
1374 | |
1375 | // Load 4 bias values. |
1376 | "ld1 {v14.4s}, [x1]\n" |
1377 | |
        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 64 bytes of
        // LHS into v0 -- v3 and the first 16 bytes of RHS into v4, as we
        // don't need those registers anymore in the rest of the work on the
        // current block.
1382 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
1383 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
1384 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
1385 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
1386 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
1387 | "add %[rhs_ptr], %[rhs_ptr], #48\n" |
1388 | |
1389 | // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), |
1390 | // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
1391 | "add v14.4s, v14.4s, v9.4s\n" |
1392 | |
1393 | // Perform the bias-addition (per the above, we have just folded into |
1394 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
1395 | // (all four 32-bit accumulators are in v16 at this point) |
1396 | "add v16.4s, v16.4s, v14.4s\n" |
1397 | |
1398 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
1399 | "beq 401f\n" |
1400 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
1401 | "add x3, x3, %x[col], lsl #2\n" |
1402 | "ld1 {v14.4s}, [x3]\n" |
1403 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
1404 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
1405 | // Subtract rhs_sums * lhs_zero_point, per |
1406 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
1407 | "mls v16.4s, v10.4s, v14.s[0]\n" |
1408 | "401:\n" |
1409 | |
1410 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
1411 | "beq 402f\n" |
1412 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
1413 | "add x2, x2, %x[row], lsl #2\n" |
1414 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
1415 | // Load 4 lhs_sums values. |
1416 | "ld1 {v11.4s}, [x2]\n" |
1417 | "ins v13.s[1], w5\n" // rhs_zero_point |
1418 | // Compute lhs_sums * rhs_zero_point. |
1419 | "mul v11.4s, v11.4s, v13.s[1]\n" |
1420 | // Subtract lhs_sums * rhs_zero_point, per |
1421 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
1422 | "sub v16.4s, v16.4s, v11.4s\n" |
1423 | |
// If the destination is int32, the user wants the raw accumulators,
// so there is no need for us to down-quantize them.
1426 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
1427 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
1428 | |
1429 | "402:\n" |
1430 | |
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values.
1433 | |
1434 | // As part of this down-quantization, our int32 values will be |
1435 | // multiplied by a multiplier that has a fixed-point component and an |
1436 | // exponent component. |
1437 | |
// Load the exponent part of the multiplier.
1439 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
1440 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
1441 | "add x5, x1, %x[row], lsl #2\n" |
1442 | "csel x1, x1, x5, eq\n" |
1443 | |
1444 | "ld1 {v14.4s}, [x1]\n" |
1445 | |
1446 | "smin v11.4s, v8.4s, v14.4s\n" |
1447 | "sub v12.4s, v14.4s, v11.4s\n" |
1448 | |
1449 | // Apply the positive exponent part of the multiplier. |
1450 | "sshl v16.4s, v16.4s, v12.4s\n" |
1451 | |
1452 | // Apply the fixed-point part of the multiplier. |
1453 | "sqdmulh v16.4s, v16.4s, v15.4s\n" |
1454 | |
1455 | // Apply the negative exponent part of the multiplier. |
1456 | "srshl v16.4s, v16.4s, v11.4s\n" |
1457 | |
1458 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
1459 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
1460 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
1461 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
1462 | |
1463 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
1464 | |
1465 | // Cast-and-saturate from int32 to int16 |
// After this instruction, all data is in the lower half (64 bits) of v16
1467 | "sqxtn v16.4h, v16.4s\n" |
1468 | |
1469 | // At this point, v18 -- v31 aren't used anymore for the current block, |
1470 | // so we can start clearing these accumulators for the next block |
1471 | // (next iteration of the main loop). |
1472 | RUY_MAKE_ZERO(v18) |
1473 | RUY_MAKE_ZERO(v19) |
1474 | |
1475 | // Add the destination zero point |
1476 | "dup v14.8h, v13.h[4]\n" |
1477 | "sqadd v16.8h, v16.8h, v14.8h\n" |
1478 | |
1479 | // Cast-and-saturate from int16 to uint8 |
// Now all data is in the first 32 bits of v16
1481 | "sqxtun v16.8b, v16.8h\n" |
1482 | |
1483 | // Load the clamp_min, clamp_max bounds |
1484 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1485 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1486 | "dup v14.16b, w2\n" // clamp_min |
1487 | "dup v15.16b, w3\n" // clamp_max |
1488 | |
1489 | // Apply the clamp_min bound |
1490 | "umax v16.16b, v16.16b, v14.16b\n" |
1491 | // Apply the clamp_max bound |
1492 | "umin v16.16b, v16.16b, v15.16b\n" |
1493 | |
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix's number of rows is not a
// multiple of 4, there are some 4x1 blocks along the boundary that do
// not fit entirely.
1499 | "sub w1, %w[dst_rows], %w[row]\n" |
1500 | "mov w3, #4\n" |
1501 | "cmp w1, #4\n" |
1502 | // Compute w1 = how many rows of the 4x1 block fit |
1503 | "csel w1, w1, w3, le\n" |
1504 | |
1505 | // Test if w1==4, i.e. if all of the 4x1 block fits. |
1506 | "cmp w1, w3\n" |
1507 | |
1508 | "mov x4, %[dst_ptr]\n" |
1509 | // Yes, all of the 4x1 block fits, go to fast path. |
1510 | "beq 30f\n" |
1511 | // Not all of the 4x1 block fits. |
1512 | // Store to dst_tmp_buf |
1513 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
1514 | // Slow loop copying from dst_tmp_buf to dst. |
1515 | "mov x3, %[dst_tmp_buf]\n" |
1516 | "mov w6, #0\n" |
1517 | "50:\n" |
1518 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1519 | "mov w5, #0\n" |
1520 | "51:\n" |
1521 | "ldrb w7, [x3, w5, uxtw]\n" |
1522 | "strb w7, [x4, w5, uxtw]\n" |
1523 | "add w5, w5, #1\n" |
1524 | "cmp w5, w1\n" |
1525 | "blt 51b\n" |
1526 | "b 31f\n" |
1527 | "30:\n" |
1528 | // Yes, all of the 4x1 block fits. |
1529 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1530 | "mov x3, x4\n" |
1531 | "st1 {v16.b}[0], [x3], #1\n" |
1532 | "st1 {v16.b}[1], [x3], #1\n" |
1533 | "st1 {v16.b}[2], [x3], #1\n" |
1534 | "st1 {v16.b}[3], [x3], #1\n" |
1535 | "31:\n" |
1536 | |
1537 | "add %[dst_ptr], %[dst_ptr], #4\n" |
1538 | |
1539 | RUY_MAKE_ZERO(v16) |
1540 | RUY_MAKE_ZERO(v17) |
1541 | |
1542 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1543 | |
1544 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
1545 | |
1546 | // Cast-and-saturate from int32 to int16 |
1547 | // After this, all values for output are in the lower half (64 bits) of v16. |
1548 | "sqxtn v16.4h, v16.4s\n" |
1549 | |
1550 | // At this point, v18 -- v31 aren't used anymore for the current block, |
1551 | // so we can start clearing these accumulators for the next block |
1552 | // (next iteration of the main loop). |
1553 | RUY_MAKE_ZERO(v18) |
1554 | RUY_MAKE_ZERO(v19) |
1555 | |
1556 | // Add the destination zero point |
1557 | "dup v14.8h, v13.h[4]\n" |
1558 | "sqadd v16.8h, v16.8h, v14.8h\n" |
1559 | |
1560 | // Cast-and-saturate from int16 to int8 |
1561 | "sqxtn v16.8b, v16.8h\n" |
1562 | // At this point, we only need 4 lowest 8-bit values in v16. |
1563 | |
1564 | // Load the clamp_min, clamp_max bounds |
1565 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1566 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1567 | "dup v14.16b, w2\n" // clamp_min |
1568 | "dup v15.16b, w3\n" // clamp_max |
1569 | |
1570 | // Apply the clamp_min bound |
1571 | "smax v16.16b, v16.16b, v14.16b\n" |
1572 | // Apply the clamp_max bound |
1573 | "smin v16.16b, v16.16b, v15.16b\n" |
1574 | |
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix's number of rows is not a
// multiple of 4, there are some 4x1 blocks along the boundary that do
// not fit entirely.
1580 | "sub w1, %w[dst_rows], %w[row]\n" |
1581 | "sub w2, %w[dst_cols], %w[col]\n" |
1582 | "mov w3, #4\n" |
1583 | "cmp w1, #4\n" |
1584 | // Compute w1 = how many rows of the 4x1 block fit |
1585 | "csel w1, w1, w3, le\n" |
1586 | "cmp w2, #4\n" |
1587 | |
1588 | // Test if w1==4, i.e. if all of the 4x1 block fits. |
1589 | "cmp w1, w3\n" |
1590 | "ccmp w2, w3, 0, eq\n" |
1591 | "mov x4, %[dst_ptr]\n" |
1592 | // Yes, all of the 4x1 block fits, go to fast path. |
1593 | "beq 30f\n" |
// Not all of the 4x1 block fits.
1595 | // Store to dst_tmp_buf |
1596 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
1597 | // Slow loop copying from dst_tmp_buf to dst. |
1598 | "mov x3, %[dst_tmp_buf]\n" |
1599 | "mov w6, #0\n" |
1600 | "50:\n" |
1601 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1602 | "mov w5, #0\n" |
1603 | "51:\n" |
1604 | "ldrb w7, [x3, w5, uxtw]\n" |
1605 | "strb w7, [x4, w5, uxtw]\n" |
1606 | "add w5, w5, #1\n" |
1607 | "cmp w5, w1\n" |
1608 | "blt 51b\n" |
1609 | "b 31f\n" |
1610 | "30:\n" |
// Yes, all of the 4x1 block fits.
1612 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1613 | "mov x3, x4\n" |
1614 | "st1 {v16.b}[0], [x3], #1\n" |
1615 | "st1 {v16.b}[1], [x3], #1\n" |
1616 | "st1 {v16.b}[2], [x3], #1\n" |
1617 | "st1 {v16.b}[3], [x3], #1\n" |
1618 | "31:\n" |
1619 | |
1620 | "add %[dst_ptr], %[dst_ptr], #4\n" |
1621 | |
1622 | RUY_MAKE_ZERO(v16) |
1623 | RUY_MAKE_ZERO(v17) |
1624 | |
1625 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1626 | |
1627 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
1628 | |
1629 | // Add the destination zero point |
1630 | "dup v14.4h, v13.h[4]\n" |
1631 | "saddw v16.4s, v16.4s, v14.4h\n" |
1632 | |
1633 | // Cast-and-saturate from int32 to int16 |
1634 | // After this instruction, all data is in lower half of v16. |
1635 | "sqxtn v16.4h, v16.4s\n" |
1636 | |
1637 | // At this point, v18 -- v31 aren't used anymore for the current block, |
1638 | // so we can start clearing these accumulators for the next block |
1639 | // (next iteration of the main loop). |
1640 | RUY_MAKE_ZERO(v18) |
1641 | RUY_MAKE_ZERO(v19) |
1642 | |
1643 | // Load the clamp_min, clamp_max bounds |
1644 | "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1645 | "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1646 | "dup v14.8h, w2\n" // clamp_min |
1647 | "dup v15.8h, w3\n" // clamp_max |
1648 | |
1649 | // Apply the clamp_min bound |
1650 | "smax v16.8h, v16.8h, v14.8h\n" |
1651 | // Apply the clamp_max bound |
1652 | "smin v16.8h, v16.8h, v15.8h\n" |
1653 | |
// Compute how much of the 4x1 block of destination 16-bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix's number of rows is not a
// multiple of 4, there are some 4x1 blocks along the boundary that do
// not fit entirely.
1659 | "sub w1, %w[dst_rows], %w[row]\n" |
1660 | "sub w2, %w[dst_cols], %w[col]\n" |
1661 | "mov w3, #4\n" |
1662 | "cmp w1, #4\n" |
// Compute w1 = how many rows of the 4x1 block fit
1664 | "csel w1, w1, w3, le\n" |
1665 | "cmp w2, #4\n" |
1666 | |
// Test if w1==4, i.e. if all rows of the 4x1 block fit.
1668 | "cmp w1, w3\n" |
1669 | "mov x4, %[dst_ptr]\n" |
// Yes, all of the 4x1 block fits, go to fast path.
1671 | "beq 30f\n" |
// Not all of the 4x1 block fits.
1673 | // Store to dst_tmp_buf |
1674 | "str q16, [%[dst_tmp_buf], #0]\n" |
1675 | // Slow loop copying from dst_tmp_buf to dst. |
1676 | "mov x3, %[dst_tmp_buf]\n" |
1677 | "mov w6, #0\n" |
1678 | "50:\n" |
1679 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1680 | "mov w5, #0\n" |
1681 | "51:\n" |
1682 | "ldrh w7, [x3, x5, lsl #1]\n" |
1683 | "strh w7, [x4, x5, lsl #1]\n" |
1684 | "add w5, w5, #1\n" |
1685 | "cmp w5, w1\n" |
1686 | "blt 51b\n" |
1687 | "blt 50b\n" |
1688 | "b 31f\n" |
1689 | "30:\n" |
// Yes, all of the 4x1 block fits.
1691 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1692 | "mov x3, x4\n" |
1693 | "st1 {v16.h}[0], [x3], #2\n" |
1694 | "st1 {v16.h}[1], [x3], #2\n" |
1695 | "st1 {v16.h}[2], [x3], #2\n" |
1696 | "st1 {v16.h}[3], [x3], #2\n" |
1697 | "31:\n" |
1698 | |
1699 | "add %[dst_ptr], %[dst_ptr], #8\n" |
1700 | |
1701 | RUY_MAKE_ZERO(v16) |
1702 | RUY_MAKE_ZERO(v17) |
1703 | |
1704 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1705 | |
1706 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
1707 | |
1708 | // Since the store type is the same as the accum type, no need for |
1709 | // downcast. There's also no need for clamp by min/max. |
1710 | |
// Compute how much of the 4x1 block of destination int32 values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix's number of rows is not a
// multiple of 4, there are some 4x1 blocks along the boundary that do
// not fit entirely.
1716 | "sub w1, %w[dst_rows], %w[row]\n" |
1717 | "sub w2, %w[dst_cols], %w[col]\n" |
1718 | "mov w3, #4\n" |
1719 | "cmp w1, #4\n" |
// Compute w1 = how many rows of the 4x1 block fit
1721 | "csel w1, w1, w3, le\n" |
1722 | "cmp w2, #4\n" |
1723 | |
// Test if w1==4 && w2==4, i.e. if all of the block fits.
1725 | "cmp w1, w3\n" |
1726 | "ccmp w2, w3, 0, eq\n" |
1727 | "mov x4, %[dst_ptr]\n" |
1728 | // Yes, all of the 4x1 block fits, go to fast path. |
1729 | "beq 30f\n" |
// Not all of the 4x1 block fits.
1731 | // Store to dst_tmp_buf |
1732 | "str q16, [%[dst_tmp_buf], #0]\n" |
1733 | // Slow loop copying from dst_tmp_buf to dst. |
1734 | "mov x3, %[dst_tmp_buf]\n" |
1735 | "mov w6, #0\n" |
1736 | "50:\n" |
1737 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1738 | "mov w5, #0\n" |
1739 | "51:\n" |
1740 | "ldr w7, [x3, x5, lsl #2]\n" |
1741 | "str w7, [x4, x5, lsl #2]\n" |
1742 | "add w5, w5, #1\n" |
1743 | "cmp w5, w1\n" |
1744 | "blt 51b\n" |
1745 | "b 31f\n" |
1746 | "30:\n" |
// Yes, all of the 4x1 block fits.
1748 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
1749 | "mov x3, x4\n" |
1750 | "st1 {v16.s}[0], [x3], #4\n" |
1751 | "st1 {v16.s}[1], [x3], #4\n" |
1752 | "st1 {v16.s}[2], [x3], #4\n" |
1753 | "st1 {v16.s}[3], [x3], #4\n" |
1754 | "31:\n" |
1755 | |
1756 | "add %[dst_ptr], %[dst_ptr], #16\n" |
1757 | |
1758 | RUY_MAKE_ZERO(v16) |
1759 | RUY_MAKE_ZERO(v17) |
1760 | RUY_MAKE_ZERO(v18) |
1761 | RUY_MAKE_ZERO(v19) |
1762 | |
1763 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
1764 | |
1765 | // For the next block: perform the first few multiply-adds on the data |
1766 | // that we have already loaded. |
1767 | "smull v8.8h, v0.8b, v4.8b\n" |
1768 | "smull v9.8h, v1.8b, v4.8b\n" |
1769 | "smull v10.8h, v2.8b, v4.8b\n" |
1770 | "smull v11.8h, v3.8b, v4.8b\n" |
1771 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
1772 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
1773 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
1774 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
1775 | |
1776 | // Reload some params --- we had used x5 -- x7 for a few other things |
1777 | // since the last time we had loaded them. |
1778 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1779 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1780 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1781 | |
1782 | // Move to the next block of the destination matrix, for the next iter |
1783 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
1784 | // been updated earlier. |
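// In C terms, a sketch of the update below (illustrative only):
//   if (row != last_row) {
//     row += 4;
//   } else {
//     row = start_row;
//     col += 4;
//     dst_col_ptr += 4 * dst_stride;
//     dst_ptr = dst_col_ptr;
//   }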
1785 | // Have we reached the end row? |
1786 | "cmp %w[row], w7\n" |
1787 | "beq 20f\n" // yes, end row. |
1788 | // Not end row. Move to the next row. |
1789 | "add %w[row], %w[row], #4\n" |
1790 | "b 21f\n" |
1791 | "20:\n" |
1792 | // Was already at end row. |
1793 | "mov %w[row], w6\n" // Move back to first row. |
1794 | "add %w[col], %w[col], #4\n" // Move to the next column. |
1795 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" |
1796 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
1797 | "21:\n" |
1798 | |
1799 | // Main loop exit condition: have we hit the end column? |
1800 | "cmp %w[col], w8\n" |
1801 | |
1802 | // w1 is the number of levels of depth that we have already loaded |
1803 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
1804 | // above, this is currently 16. |
1805 | "mov w1, #16\n" |
1806 | |
1807 | "ble 1b\n" |
1808 | |
1809 | // clang-format on |
1810 | |
: [lhs_col_ptr] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr),
1812 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
1813 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
: [params] "r" (&params), [dst_rows] "r" (params.dst_rows),
1815 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf), |
1816 | [dst_type_id] "r" (params.dst_type_id) |
1817 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
1818 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
1819 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" ); |
1820 | } |
1821 | |
1822 | // Variant of the above Kernel8bitNeon, tuned for A55-ish CPUs. |
1823 | // Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and |
1824 | // the original Cortex-A55, since these are 64-bit and do not support dotprod. |
1825 | // |
1826 | // While this kernel does not have a direct equivalent in gemmlowp, it was |
1827 | // developed based on insights that David Mansell at ARM shared with their |
1828 | // contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful |
1829 | // comments. Specifically, see this comment about tuning for Cortex-A53: |
1830 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 |
1831 | void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) { |
1832 | profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)" ); |
1833 | |
1834 | CheckOffsetsInKernelParams8bit(params); |
1835 | |
1836 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
const std::int8_t* rhs_col_ptr =
static_cast<const std::int8_t*>(params.rhs_base_ptr);
1839 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
1840 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
1841 | void* dst_col_ptr = params.dst_base_ptr; |
1842 | void* dst_ptr = dst_col_ptr; |
1843 | int row = params.start_row; |
1844 | int col = params.start_col; |
1845 | |
1846 | // The asm kernel below has the following NEON register allocation: |
1847 | // |
1848 | // v16 -- v31 are int32 accumulators. |
1849 | // During accumulation, v0 -- v3 are used to load int8 data from LHS and |
1850 | // v4 -- v7 from RHS: |
1851 | // |
1852 | // int8 RHS 16x4 block |
1853 | // /-----------------------------------------| |
1854 | // |v4.b[0] ... v7.b[0] | |
1855 | // | ... ... | |
1856 | // |v4.b[15] ... v7.b[15] | |
1857 | // \-----------------------------------------/ |
1858 | // int8 LHS 4x16 block |
1859 | // /---------------------\ /-----------------------------------------| |
1860 | // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | |
1861 | // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | |
1862 | // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | |
1863 | // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | |
1864 | // \---------------------/ \-----------------------------------------/ |
1865 | // int32 accumulators 4x4 block |
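// For reference, and ignoring the zero-point corrections applied at the
// end, the accumulation below computes (a plain C sketch, not the actual
// implementation):
//
//   for (int i = 0; i < 4; ++i)
//     for (int j = 0; j < 4; ++j)
//       for (int d = 0; d < depth; ++d)
//         acc[i][j] += std::int32_t(lhs(i, d)) * std::int32_t(rhs(d, j));
//
// with loads, multiplies and accumulator updates interleaved by hand for
// in-order pipelines.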
1866 | asm volatile( |
1867 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
1868 | |
1869 | // clang-format off |
1870 | |
1871 | // Load some parameters into registers. |
1872 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1873 | RUY_MAKE_ZERO(v16) |
1874 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1875 | RUY_MAKE_ZERO(v17) |
1876 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1877 | RUY_MAKE_ZERO(v18) |
1878 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
1879 | RUY_MAKE_ZERO(v19) |
1880 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
1881 | RUY_MAKE_ZERO(v20) |
1882 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
1883 | RUY_MAKE_ZERO(v21) |
1884 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1885 | RUY_MAKE_ZERO(v22) |
1886 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
1887 | RUY_MAKE_ZERO(v23) |
1888 | |
1889 | // Load the first 64 bytes of LHS and RHS data. |
1890 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
1891 | RUY_MAKE_ZERO(v24) |
1892 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
1893 | RUY_MAKE_ZERO(v25) |
1894 | "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" |
1895 | RUY_MAKE_ZERO(v26) |
1896 | "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" |
1897 | RUY_MAKE_ZERO(v27) |
1898 | "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" |
1899 | RUY_MAKE_ZERO(v28) |
1900 | "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" |
1901 | RUY_MAKE_ZERO(v29) |
1902 | "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" |
1903 | RUY_MAKE_ZERO(v30) |
1904 | "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" |
1905 | RUY_MAKE_ZERO(v31) |
1906 | |
1907 | |
1908 | // w1 is the number of levels of depth that we have already loaded |
1909 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
1910 | // above, this is currently 16. |
1911 | "mov w1, #16\n" |
1912 | |
1913 | // Perform the first few multiply-adds on the data that we have already |
1914 | // loaded. |
1915 | "smull v8.8h, v0.8b, v4.8b\n" |
1916 | "smull v9.8h, v1.8b, v4.8b\n" |
1917 | "smull v10.8h, v2.8b, v4.8b\n" |
1918 | "smull v11.8h, v3.8b, v4.8b\n" |
1919 | "smull v12.8h, v0.8b, v5.8b\n" |
1920 | "smull v13.8h, v1.8b, v5.8b\n" |
1921 | "smull v14.8h, v2.8b, v5.8b\n" |
1922 | "smull v15.8h, v3.8b, v5.8b\n" |
1923 | |
1924 | // Multiply-accumulate second-half, again into the same |
1925 | // 16bit local accumulator registers. This is where we |
1926 | // take advantage of having int8 instead of uint8 and therefore |
1927 | // being able to accumulate two products into int16. |
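// In scalar terms (a sketch), after the smull/smlal2 pair each int16 lane
// k of the local accumulator holds lhs[k]*rhs[k] + lhs[k+8]*rhs[k+8], i.e.
// two of the 16 depth levels summed before widening into the int32
// accumulators via sadalp.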
1928 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
1929 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
1930 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
1931 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
1932 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
1933 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
1934 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
1935 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
1936 | |
1937 | |
1938 | // Main loop of the whole GEMM, over rows and columns of the |
1939 | // destination matrix. |
1940 | "1:\n" |
1941 | |
1942 | // Reminder - w1 is how many levels of depth we have already loaded |
1943 | // data for, w12 is the total depth. |
1944 | "cmp w1, w12\n" |
1945 | "beq 79f\n" |
1946 | |
1947 | "2:\n" |
1948 | |
1949 | // Some multiplications and 16-bit accumulation were already done above, |
1950 | // so we start right away in the middle. |
1951 | "sadalp v16.4s, v8.8h\n" |
1952 | "ldr d4, [%[rhs_ptr], #0]\n" |
1953 | "smull v8.8h, v0.8b, v6.8b\n" |
1954 | "ldr x7, [%[rhs_ptr], #8]\n" |
1955 | "sadalp v17.4s, v9.8h\n" |
1956 | "ldr d5, [%[rhs_ptr], #16]\n" |
1957 | "smull v9.8h, v1.8b, v6.8b\n" |
1958 | "ldr x8, [%[rhs_ptr], #24]\n" |
1959 | "sadalp v18.4s, v10.8h\n" |
1960 | "smull v10.8h, v2.8b, v6.8b\n" |
1961 | "sadalp v19.4s, v11.8h\n" |
1962 | "add %[lhs_ptr], %[lhs_ptr], #64\n" |
1963 | "smull v11.8h, v3.8b, v6.8b\n" |
1964 | "add %[rhs_ptr], %[rhs_ptr], #64\n" |
1965 | "sadalp v20.4s, v12.8h\n" |
1966 | // Each iteration of this loop advances by 16 levels of depth. |
1967 | "add w1, w1, #16\n" |
1968 | "smull v12.8h, v0.8b, v7.8b\n" |
1969 | // Loop termination condition |
1970 | "cmp w1, w12\n" |
1971 | "sadalp v21.4s, v13.8h\n" |
1972 | "ldr x3, [%[lhs_ptr], #-56]\n" |
1973 | "smull v13.8h, v1.8b, v7.8b\n" |
1974 | "ldr x4, [%[lhs_ptr], #-40]\n" |
1975 | "sadalp v22.4s, v14.8h\n" |
1976 | "ldr x5, [%[lhs_ptr], #-24]\n" |
1977 | "smull v14.8h, v2.8b, v7.8b\n" |
1978 | "ldr x6, [%[lhs_ptr], #-8]\n" |
1979 | "sadalp v23.4s, v15.8h\n" |
1980 | "smull v15.8h, v3.8b, v7.8b\n" |
1981 | |
1982 | // Multiply-accumulate second-half, again into the same |
1983 | // 16bit local accumulator registers. This is where we |
1984 | // take advantage of having int8 instead of uint8 and therefore |
1985 | // being able to accumulate two products into int16. |
1986 | "smlal2 v8.8h, v0.16b, v6.16b\n" |
1987 | "smlal2 v9.8h, v1.16b, v6.16b\n" |
1988 | "smlal2 v10.8h, v2.16b, v6.16b\n" |
1989 | "ldr x9, [%[rhs_ptr], #-24]\n" |
1990 | "smlal2 v11.8h, v3.16b, v6.16b\n" |
1991 | "ldr d6, [%[rhs_ptr], #-32]\n" |
1992 | "smlal2 v12.8h, v0.16b, v7.16b\n" |
1993 | "ldr d0, [%[lhs_ptr], #-64]\n" |
1994 | "smlal2 v13.8h, v1.16b, v7.16b\n" |
1995 | "ldr d1, [%[lhs_ptr], #-48]\n" |
1996 | "smlal2 v14.8h, v2.16b, v7.16b\n" |
1997 | "ins v4.d[1], x7\n" |
1998 | "smlal2 v15.8h, v3.16b, v7.16b\n" |
1999 | "ins v5.d[1], x8\n" |
2000 | |
2001 | "ldr d2, [%[lhs_ptr], #-32]\n" |
2002 | "ins v0.d[1], x3\n" |
2003 | "sadalp v24.4s, v8.8h\n" |
2004 | "ldr d3, [%[lhs_ptr], #-16]\n" |
2005 | "ins v1.d[1], x4\n" |
2006 | "smull v8.8h, v0.8b, v4.8b\n" |
2007 | "ins v2.d[1], x5\n" |
2008 | "sadalp v25.4s, v9.8h\n" |
2009 | "ins v3.d[1], x6\n" |
2010 | "smull v9.8h, v1.8b, v4.8b\n" |
2011 | "ldr d7, [%[rhs_ptr], #-16]\n" |
2012 | "sadalp v26.4s, v10.8h\n" |
2013 | "ldr x10, [%[rhs_ptr], #-8]\n" |
2014 | "smull v10.8h, v2.8b, v4.8b\n" |
2015 | "sadalp v27.4s, v11.8h\n" |
2016 | "smull v11.8h, v3.8b, v4.8b\n" |
2017 | "sadalp v28.4s, v12.8h\n" |
2018 | "smull v12.8h, v0.8b, v5.8b\n" |
2019 | "sadalp v29.4s, v13.8h\n" |
2020 | "smull v13.8h, v1.8b, v5.8b\n" |
2021 | "sadalp v30.4s, v14.8h\n" |
2022 | "smull v14.8h, v2.8b, v5.8b\n" |
2023 | "sadalp v31.4s, v15.8h\n" |
2024 | "smull v15.8h, v3.8b, v5.8b\n" |
2025 | |
2026 | // Multiply-accumulate second-half, again into the same |
2027 | // 16bit local accumulator registers. This is where we |
2028 | // take advantage of having int8 instead of uint8 and therefore |
2029 | // being able to accumulate two products into int16. |
2030 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
2031 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
2032 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
2033 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
2034 | |
2035 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
2036 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
2037 | "ins v6.d[1], x9\n" |
2038 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
2039 | "ins v7.d[1], x10\n" |
2040 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
2041 | |
2042 | "blt 2b\n" |
2043 | |
2044 | "79:\n" |
2045 | |
2046 | "sadalp v16.4s, v8.8h\n" |
2047 | "smull v8.8h, v0.8b, v6.8b\n" |
2048 | "sadalp v17.4s, v9.8h\n" |
2049 | "smull v9.8h, v1.8b, v6.8b\n" |
2050 | "sadalp v18.4s, v10.8h\n" |
2051 | "smull v10.8h, v2.8b, v6.8b\n" |
2052 | "sadalp v19.4s, v11.8h\n" |
2053 | "smull v11.8h, v3.8b, v6.8b\n" |
2054 | "sadalp v20.4s, v12.8h\n" |
2055 | "smull v12.8h, v0.8b, v7.8b\n" |
2056 | "sadalp v21.4s, v13.8h\n" |
2057 | "smull v13.8h, v1.8b, v7.8b\n" |
2058 | "sadalp v22.4s, v14.8h\n" |
2059 | "smull v14.8h, v2.8b, v7.8b\n" |
2060 | "sadalp v23.4s, v15.8h\n" |
2061 | "smull v15.8h, v3.8b, v7.8b\n" |
2062 | |
2063 | // Multiply-accumulate second-half, again into the same |
2064 | // 16bit local accumulator registers. This is where we |
2065 | // take advantage of having int8 instead of uint8 and therefore |
2066 | // being able to accumulate two products into int16. |
2067 | "smlal2 v8.8h, v0.16b, v6.16b\n" |
2068 | "smlal2 v9.8h, v1.16b, v6.16b\n" |
2069 | "smlal2 v10.8h, v2.16b, v6.16b\n" |
2070 | "smlal2 v11.8h, v3.16b, v6.16b\n" |
2071 | |
2072 | "smlal2 v12.8h, v0.16b, v7.16b\n" |
2073 | "smlal2 v13.8h, v1.16b, v7.16b\n" |
2074 | "smlal2 v14.8h, v2.16b, v7.16b\n" |
2075 | "smlal2 v15.8h, v3.16b, v7.16b\n" |
2076 | |
2077 | "sadalp v24.4s, v8.8h\n" |
2078 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
2079 | "sadalp v25.4s, v9.8h\n" |
2080 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
2081 | "sadalp v26.4s, v10.8h\n" |
2082 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
2083 | "sadalp v27.4s, v11.8h\n" |
2084 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
2085 | "sadalp v28.4s, v12.8h\n" |
2086 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
2087 | "sadalp v29.4s, v13.8h\n" |
2088 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
2089 | "sadalp v30.4s, v14.8h\n" |
2090 | "sadalp v31.4s, v15.8h\n" |
2091 | |
2092 | // End of accumulation. The registers v16 -- v31 contain the final |
2093 | // int32 accumulator values of the current 4x4 destination block. |
2094 | // We now have to compute the final 8-bit values from these int32 |
2095 | // accumulators, and advance to the next 4x4 block. We intertwine |
2096 | // these two aspects whenever possible for optimal pipelining, both |
2097 | // at the data flow level (prefetch data for next block as early as |
2098 | // possible) and instruction pipelining level (some of the next-block |
2099 | // work can dual-issue with some of the final work on the current |
2100 | // block). |
2101 | |
2102 | // Reduce 32bit accumulators horizontally. |
2103 | "addp v16.4s, v16.4s, v17.4s\n" |
2104 | "addp v18.4s, v18.4s, v19.4s\n" |
2105 | "addp v20.4s, v20.4s, v21.4s\n" |
2106 | "addp v22.4s, v22.4s, v23.4s\n" |
2107 | "addp v24.4s, v24.4s, v25.4s\n" |
2108 | "addp v26.4s, v26.4s, v27.4s\n" |
2109 | "addp v28.4s, v28.4s, v29.4s\n" |
2110 | "addp v30.4s, v30.4s, v31.4s\n" |
2111 | |
2112 | // Reduce 32bit accumulators horizontally, second pass |
2113 | // (each pass adds pairwise. we need to add 4-wise). |
2114 | "addp v16.4s, v16.4s, v18.4s\n" |
2115 | "addp v17.4s, v20.4s, v22.4s\n" |
2116 | "addp v18.4s, v24.4s, v26.4s\n" |
2117 | "addp v19.4s, v28.4s, v30.4s\n" |
2118 | |
2119 | // Logic to advance to the next block in preparation for the next |
2120 | // iteration of the main loop. For now, we only want to compute |
2121 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
2122 | // not yet ready to update the values of row and col, as we still need |
2123 | // the current values for the rest of the work on the current block. |
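// In C terms, a sketch of the pointer update below (illustrative only):
//   if (row < last_row) {
//     lhs_col_ptr += 4 * lhs_stride;
//   } else {
//     lhs_col_ptr = lhs_base_ptr;
//     if (col < last_col) rhs_col_ptr += 4 * rhs_stride;
//   }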
2124 | |
2125 | "cmp %w[row], w7\n" // Have we finished the last row? |
2126 | "bge 4f\n" // If finished last row, go to 4 |
2127 | // Not finished last row: then advance to next row. |
2128 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" |
2129 | "b 5f\n" |
2130 | "4:\n" // Finished last row... |
2131 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
2132 | // Now we need to advance to the next column. If we already |
2133 | // finished the last column, then in principle we are done, however |
2134 | // we can't just return here, as we need to allow the end work of the |
2135 | // current block to complete. The good news is that at this point it |
2136 | // doesn't matter what data we load for the next column, since |
2137 | // we will exit from the main loop below before actually storing |
2138 | // anything computed from that data. |
2139 | "cmp %w[col], w8\n" // Have we finished the last column? |
2140 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
2141 | // Not finished last column: then advance to next column. |
2142 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" |
2143 | "5:\n" |
2144 | |
2145 | // Set the LHS and RHS data pointers to the start of the columns just |
2146 | // computed. |
2147 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
2148 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
2149 | |
2150 | // Load some parameters needed for the end work on current block. |
2151 | "mvni v8.4s, #0\n" |
2152 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
2153 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
2154 | "ins v13.h[4], w4\n" // dst_zero_point |
2155 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
2156 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
2157 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
2158 | |
2159 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
2160 | |
2161 | // Determine the channel index. |
2162 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
2163 | "csel w3, %w[row], %w[col], eq\n" |
2164 | |
2165 | // Offset the bias pointer as needed given the current row, col. |
2166 | "add x5, x1, x3, lsl #2\n" |
2167 | |
2168 | // If there is no bias, use no offset, just address the passed zero |
2169 | // data. |
2170 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
2171 | "csel x1, x1, x5, eq\n" |
2172 | |
2173 | // Load 4 bias values. |
2174 | "ld1 {v14.4s}, [x1]\n" |
2175 | |
2176 | // Load the multiplier_fixedpoint values. |
2177 | "add x5, x4, x3, lsl #2\n" |
2178 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
2179 | "csel x4, x4, x5, eq\n" |
2180 | "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint |
2181 | |
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// LHS (into v0 -- v3) and of RHS (into v4 -- v7), as we don't need the
// previous contents of those registers anymore in the rest of the work
// on the current block.
2186 | |
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point);
// see the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
2189 | "add v14.4s, v14.4s, v9.4s\n" |
2190 | "ldr d0, [%[lhs_ptr], #0]\n" |
2191 | |
2192 | // Perform the bias-addition (per the above, we have just folded into |
2193 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
2194 | // Jump based on channel dimension. |
2195 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
2196 | "bne 6f\n" |
2197 | // Case where channels are rows |
2198 | |
2199 | "add v16.4s, v16.4s, v14.4s\n" |
2200 | "ldr d1, [%[lhs_ptr], #16]\n" |
2201 | "add v17.4s, v17.4s, v14.4s\n" |
2202 | "ldr d2, [%[lhs_ptr], #32]\n" |
2203 | "add v18.4s, v18.4s, v14.4s\n" |
2204 | "ldr d3, [%[lhs_ptr], #48]\n" |
2205 | "add v19.4s, v19.4s, v14.4s\n" |
2206 | "ldr d4, [%[rhs_ptr], #0]\n" |
2207 | "ldr d5, [%[rhs_ptr], #16]\n" |
2208 | "ldr d6, [%[rhs_ptr], #32]\n" |
2209 | "ldr d7, [%[rhs_ptr], #48]\n" |
2210 | |
2211 | "b 7f\n" |
2212 | |
2213 | "6:\n" |
2214 | // Case where channels are columns |
2215 | "dup v20.4s, v14.s[0]\n" |
2216 | "ldr d1, [%[lhs_ptr], #16]\n" |
2217 | "dup v21.4s, v14.s[1]\n" |
2218 | "ldr d2, [%[lhs_ptr], #32]\n" |
2219 | "dup v22.4s, v14.s[2]\n" |
2220 | "ldr d3, [%[lhs_ptr], #48]\n" |
2221 | "dup v23.4s, v14.s[3]\n" |
2222 | "ldr d4, [%[rhs_ptr], #0]\n" |
2223 | "add v16.4s, v16.4s, v20.4s\n" |
2224 | "ldr d5, [%[rhs_ptr], #16]\n" |
2225 | "add v17.4s, v17.4s, v21.4s\n" |
2226 | "ldr d6, [%[rhs_ptr], #32]\n" |
2227 | "add v18.4s, v18.4s, v22.4s\n" |
2228 | "ldr d7, [%[rhs_ptr], #48]\n" |
2229 | "add v19.4s, v19.4s, v23.4s\n" |
2230 | "7:\n" |
2231 | |
2232 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
2233 | "beq 401f\n" |
2234 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
2235 | "add x3, x3, %x[col], lsl #2\n" |
2236 | "ld1 {v14.4s}, [x3]\n" |
2237 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
2238 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
2239 | // Subtract rhs_sums * lhs_zero_point, per |
2240 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
2241 | "mls v16.4s, v10.4s, v14.s[0]\n" |
2242 | "mls v17.4s, v10.4s, v14.s[1]\n" |
2243 | "mls v18.4s, v10.4s, v14.s[2]\n" |
2244 | "mls v19.4s, v10.4s, v14.s[3]\n" |
2245 | "401:\n" |
2246 | |
2247 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
2248 | "beq 402f\n" |
2249 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
2250 | "add x2, x2, %x[row], lsl #2\n" |
2251 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
2252 | // Load 4 lhs_sums values. |
2253 | "ld1 {v11.4s}, [x2]\n" |
2254 | "ins v13.s[1], w5\n" // rhs_zero_point |
2255 | // Compute lhs_sums * rhs_zero_point. |
2256 | "mul v11.4s, v11.4s, v13.s[1]\n" |
2257 | // Subtract lhs_sums * rhs_zero_point, per |
2258 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
2259 | "sub v16.4s, v16.4s, v11.4s\n" |
2260 | "sub v17.4s, v17.4s, v11.4s\n" |
2261 | "sub v18.4s, v18.4s, v11.4s\n" |
2262 | "sub v19.4s, v19.4s, v11.4s\n" |
2263 | |
// If the destination is int32, the user wants the raw accumulators,
// so there is no need for us to down-quantize them.
2266 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
2267 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
2268 | |
2269 | "402:\n" |
2270 | |
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values.
2273 | |
2274 | // As part of this down-quantization, our int32 values will be |
2275 | // multiplied by a multiplier that has a fixed-point component and an |
2276 | // exponent component. |
2277 | |
2278 | // Determine the channel index. |
2279 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
2280 | "csel w3, %w[row], %w[col], eq\n" |
2281 | |
2282 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
2283 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
2284 | "add x5, x1, x3, lsl #2\n" |
2285 | "csel x1, x1, x5, eq\n" |
2286 | |
2287 | "ld1 {v14.4s}, [x1]\n" |
2288 | |
2289 | "smin v11.4s, v8.4s, v14.4s\n" |
2290 | "ldr x1, [%[lhs_ptr], #8]\n" |
2291 | "sub v12.4s, v14.4s, v11.4s\n" |
2292 | |
2293 | // Jump based on channel dimension. |
2294 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
2295 | "bne 8f\n" |
2296 | // Case where channels are rows |
2297 | |
2298 | |
2299 | // Apply the positive exponent part of the multiplier. |
2300 | "sshl v16.4s, v16.4s, v12.4s\n" |
2301 | "ldr x2, [%[lhs_ptr], #24]\n" |
2302 | "sshl v17.4s, v17.4s, v12.4s\n" |
2303 | "ldr x3, [%[lhs_ptr], #40]\n" |
2304 | "sshl v18.4s, v18.4s, v12.4s\n" |
2305 | "ldr x4, [%[lhs_ptr], #56]\n" |
2306 | "sshl v19.4s, v19.4s, v12.4s\n" |
2307 | |
2308 | |
2309 | // Apply the fixed-point part of the multiplier. |
2310 | "ins v0.d[1], x1\n" |
2311 | "ldr x1, [%[rhs_ptr], #8]\n" |
2312 | "sqdmulh v16.4s, v16.4s, v15.4s\n" |
2313 | "ins v1.d[1], x2\n" |
2314 | "ldr x2, [%[rhs_ptr], #24]\n" |
2315 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
2316 | "ins v2.d[1], x3\n" |
2317 | "ldr x3, [%[rhs_ptr], #40]\n" |
2318 | "sqdmulh v18.4s, v18.4s, v15.4s\n" |
2319 | "ins v3.d[1], x4\n" |
2320 | "ldr x4, [%[rhs_ptr], #56]\n" |
2321 | "sqdmulh v19.4s, v19.4s, v15.4s\n" |
2322 | |
2323 | // Apply the negative exponent part of the multiplier. |
2324 | "srshl v16.4s, v16.4s, v11.4s\n" |
2325 | "srshl v17.4s, v17.4s, v11.4s\n" |
2326 | "srshl v18.4s, v18.4s, v11.4s\n" |
2327 | "srshl v19.4s, v19.4s, v11.4s\n" |
2328 | |
2329 | "b 9f\n" |
2330 | |
2331 | "8:\n" |
2332 | // Case where channels are columns |
2333 | |
2334 | // Apply the positive exponent part of the multiplier. |
2335 | "dup v20.4s, v12.s[0]\n" |
2336 | "ldr x2, [%[lhs_ptr], #24]\n" |
2337 | "ldr x3, [%[lhs_ptr], #40]\n" |
2338 | "dup v21.4s, v12.s[1]\n" |
2339 | "ldr x4, [%[lhs_ptr], #56]\n" |
2340 | "dup v22.4s, v12.s[2]\n" |
2341 | "ins v0.d[1], x1\n" |
2342 | "dup v23.4s, v12.s[3]\n" |
2343 | "ldr x1, [%[rhs_ptr], #8]\n" |
2344 | "sshl v16.4s, v16.4s, v20.4s\n" |
2345 | "ins v1.d[1], x2\n" |
2346 | "sshl v17.4s, v17.4s, v21.4s\n" |
2347 | "ldr x2, [%[rhs_ptr], #24]\n" |
2348 | "sshl v18.4s, v18.4s, v22.4s\n" |
2349 | "ins v2.d[1], x3\n" |
2350 | "sshl v19.4s, v19.4s, v23.4s\n" |
2351 | "ldr x3, [%[rhs_ptr], #40]\n" |
2352 | |
2353 | // Apply the fixed-point part of the multiplier. |
2354 | "sqdmulh v16.4s, v16.4s, v15.s[0]\n" |
2355 | "ins v3.d[1], x4\n" |
2356 | "sqdmulh v17.4s, v17.4s, v15.s[1]\n" |
2357 | "ldr x4, [%[rhs_ptr], #56]\n" |
2358 | "sqdmulh v18.4s, v18.4s, v15.s[2]\n" |
2359 | "dup v20.4s, v11.s[0]\n" |
2360 | "sqdmulh v19.4s, v19.4s, v15.s[3]\n" |
2361 | |
2362 | // Apply the negative exponent part of the multiplier. |
2363 | "dup v21.4s, v11.s[1]\n" |
2364 | "srshl v16.4s, v16.4s, v20.4s\n" |
2365 | "dup v22.4s, v11.s[2]\n" |
2366 | "srshl v17.4s, v17.4s, v21.4s\n" |
2367 | "dup v23.4s, v11.s[3]\n" |
2368 | "srshl v18.4s, v18.4s, v22.4s\n" |
2369 | "srshl v19.4s, v19.4s, v23.4s\n" |
2370 | |
2371 | "9:\n" |
2372 | |
2373 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
2374 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
2375 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
2376 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
2377 | |
2378 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
2379 | |
2380 | "ins v4.d[1], x1\n" |
2381 | "sqxtn v16.4h, v16.4s\n" |
2382 | "ins v5.d[1], x2\n" |
2383 | "sqxtn2 v16.8h, v17.4s\n" |
2384 | "ins v6.d[1], x3\n" |
2385 | "sqxtn v17.4h, v18.4s\n" |
2386 | "ins v7.d[1], x4\n" |
2387 | RUY_MAKE_ZERO(v18) |
2388 | "sqxtn2 v17.8h, v19.4s\n" |
2389 | |
2390 | // At this point, v18 -- v31 aren't used anymore for the current block, |
2391 | // so we can start clearing these accumulators for the next block |
2392 | // (next iteration of the main loop). |
2393 | RUY_MAKE_ZERO(v19) |
2394 | |
2395 | // Add the destination zero point |
2396 | "add %[lhs_ptr], %[lhs_ptr], #64\n" |
2397 | "dup v14.8h, v13.h[4]\n" |
2398 | RUY_MAKE_ZERO(v20) |
2399 | "add %[rhs_ptr], %[rhs_ptr], #64\n" |
2400 | "sqadd v16.8h, v16.8h, v14.8h\n" |
2401 | RUY_MAKE_ZERO(v21) |
2402 | "sqadd v17.8h, v17.8h, v14.8h\n" |
2403 | RUY_MAKE_ZERO(v22) |
2404 | |
2405 | // Cast-and-saturate from int16 to uint8 |
2406 | "sqxtun v16.8b, v16.8h\n" |
2407 | RUY_MAKE_ZERO(v23) |
2408 | "sqxtun2 v16.16b, v17.8h\n" |
2409 | RUY_MAKE_ZERO(v24) |
2410 | |
2411 | // Load the clamp_min, clamp_max bounds |
2412 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2413 | RUY_MAKE_ZERO(v25) |
2414 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2415 | RUY_MAKE_ZERO(v26) |
2416 | "dup v14.16b, w2\n" // clamp_min |
2417 | RUY_MAKE_ZERO(v27) |
2418 | "dup v15.16b, w3\n" // clamp_max |
2419 | RUY_MAKE_ZERO(v28) |
2420 | |
2421 | // Apply the clamp_min bound |
2422 | "umax v16.16b, v16.16b, v14.16b\n" |
2423 | RUY_MAKE_ZERO(v29) |
2424 | // Apply the clamp_max bound |
2425 | "umin v16.16b, v16.16b, v15.16b\n" |
2426 | RUY_MAKE_ZERO(v30) |
2427 | |
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
2433 | "sub w1, %w[dst_rows], %w[row]\n" |
2434 | RUY_MAKE_ZERO(v31) |
2435 | "sub w2, %w[dst_cols], %w[col]\n" |
2436 | "mov w3, #4\n" |
2437 | "cmp w1, #4\n" |
2438 | // Compute w1 = how many rows of the 4x4 block fit |
2439 | "csel w1, w1, w3, le\n" |
2440 | "cmp w2, #4\n" |
2441 | // Compute w2 = how many cols of the 4x4 block fit |
2442 | "csel w2, w2, w3, le\n" |
2443 | |
// Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
2445 | "cmp w1, w3\n" |
2446 | "ccmp w2, w3, 0, eq\n" |
2447 | "mov x4, %[dst_ptr]\n" |
2448 | // Yes, all of the 4x4 block fits, go to fast path. |
2449 | "beq 30f\n" |
2450 | // Not all of the 4x4 block fits. |
2451 | // Store to dst_tmp_buf |
2452 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
2453 | // Slow loop copying from dst_tmp_buf to dst. |
2454 | "mov x3, %[dst_tmp_buf]\n" |
2455 | "mov w6, #0\n" |
2456 | "50:\n" |
2457 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2458 | "mov w5, #0\n" |
2459 | "51:\n" |
2460 | "ldrb w7, [x3, w5, uxtw]\n" |
2461 | "strb w7, [x4, w5, uxtw]\n" |
2462 | "add w5, w5, #1\n" |
2463 | "cmp w5, w1\n" |
2464 | "blt 51b\n" |
2465 | "add w6, w6, #1\n" |
2466 | "add x3, x3, #4\n" |
2467 | "add x4, x4, x11\n" |
2468 | "cmp w6, w2\n" |
2469 | "blt 50b\n" |
2470 | "b 31f\n" |
2471 | "30:\n" |
2472 | // Yes, all of the 4x4 block fits. |
2473 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2474 | "mov x3, x4\n" |
2475 | "st1 {v16.b}[0], [x3], #1\n" |
2476 | "add x4, x4, x11\n" |
2477 | "st1 {v16.b}[1], [x3], #1\n" |
2478 | "st1 {v16.b}[2], [x3], #1\n" |
2479 | "st1 {v16.b}[3], [x3], #1\n" |
2480 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2481 | "mov x3, x4\n" |
2482 | "st1 {v16.b}[4], [x3], #1\n" |
2483 | "add x4, x4, x11\n" |
2484 | "st1 {v16.b}[5], [x3], #1\n" |
2485 | "st1 {v16.b}[6], [x3], #1\n" |
2486 | "st1 {v16.b}[7], [x3], #1\n" |
2487 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2488 | "mov x3, x4\n" |
2489 | "st1 {v16.b}[8], [x3], #1\n" |
2490 | "add x4, x4, x11\n" |
2491 | "st1 {v16.b}[9], [x3], #1\n" |
2492 | "st1 {v16.b}[10], [x3], #1\n" |
2493 | "st1 {v16.b}[11], [x3], #1\n" |
2494 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2495 | "mov x3, x4\n" |
2496 | "st1 {v16.b}[12], [x3], #1\n" |
2497 | "add x4, x4, x11\n" |
2498 | "st1 {v16.b}[13], [x3], #1\n" |
2499 | "st1 {v16.b}[14], [x3], #1\n" |
2500 | "st1 {v16.b}[15], [x3], #1\n" |
2501 | "31:\n" |
2502 | |
2503 | "add %[dst_ptr], %[dst_ptr], #4\n" |
2504 | |
2505 | RUY_MAKE_ZERO(v16) |
2506 | RUY_MAKE_ZERO(v17) |
2507 | |
2508 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2509 | |
2510 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
2511 | |
2512 | "ins v4.d[1], x1\n" |
2513 | "sqxtn v16.4h, v16.4s\n" |
2514 | "ins v5.d[1], x2\n" |
2515 | "sqxtn2 v16.8h, v17.4s\n" |
2516 | "ins v6.d[1], x3\n" |
2517 | "sqxtn v17.4h, v18.4s\n" |
2518 | "ins v7.d[1], x4\n" |
2519 | RUY_MAKE_ZERO(v18) |
2520 | "sqxtn2 v17.8h, v19.4s\n" |
2521 | |
2522 | // At this point, v18 -- v31 aren't used anymore for the current block, |
2523 | // so we can start clearing these accumulators for the next block |
2524 | // (next iteration of the main loop). |
2525 | RUY_MAKE_ZERO(v19) |
2526 | |
2527 | // Add the destination zero point |
2528 | "add %[lhs_ptr], %[lhs_ptr], #64\n" |
2529 | "dup v14.8h, v13.h[4]\n" |
2530 | RUY_MAKE_ZERO(v20) |
2531 | "add %[rhs_ptr], %[rhs_ptr], #64\n" |
2532 | "sqadd v16.8h, v16.8h, v14.8h\n" |
2533 | RUY_MAKE_ZERO(v21) |
2534 | "sqadd v17.8h, v17.8h, v14.8h\n" |
2535 | RUY_MAKE_ZERO(v22) |
2536 | |
// Cast-and-saturate from int16 to int8
2538 | "sqxtn v16.8b, v16.8h\n" |
2539 | RUY_MAKE_ZERO(v23) |
2540 | "sqxtn2 v16.16b, v17.8h\n" |
2541 | RUY_MAKE_ZERO(v24) |
2542 | |
2543 | // Load the clamp_min, clamp_max bounds |
2544 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2545 | RUY_MAKE_ZERO(v25) |
2546 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2547 | RUY_MAKE_ZERO(v26) |
2548 | "dup v14.16b, w2\n" // clamp_min |
2549 | RUY_MAKE_ZERO(v27) |
2550 | "dup v15.16b, w3\n" // clamp_max |
2551 | RUY_MAKE_ZERO(v28) |
2552 | |
2553 | // Apply the clamp_min bound |
2554 | "smax v16.16b, v16.16b, v14.16b\n" |
2555 | RUY_MAKE_ZERO(v29) |
2556 | // Apply the clamp_max bound |
2557 | "smin v16.16b, v16.16b, v15.16b\n" |
2558 | RUY_MAKE_ZERO(v30) |
2559 | |
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
2565 | "sub w1, %w[dst_rows], %w[row]\n" |
2566 | RUY_MAKE_ZERO(v31) |
2567 | "sub w2, %w[dst_cols], %w[col]\n" |
2568 | "mov w3, #4\n" |
2569 | "cmp w1, #4\n" |
2570 | // Compute w1 = how many rows of the 4x4 block fit |
2571 | "csel w1, w1, w3, le\n" |
2572 | "cmp w2, #4\n" |
2573 | // Compute w2 = how many cols of the 4x4 block fit |
2574 | "csel w2, w2, w3, le\n" |
2575 | |
// Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
2577 | "cmp w1, w3\n" |
2578 | "ccmp w2, w3, 0, eq\n" |
2579 | "mov x4, %[dst_ptr]\n" |
2580 | // Yes, all of the 4x4 block fits, go to fast path. |
2581 | "beq 30f\n" |
2582 | // Not all of the 4x4 block fits. |
2583 | // Store to dst_tmp_buf |
2584 | "st1 {v16.16b}, [%[dst_tmp_buf]]\n" |
2585 | // Slow loop copying from dst_tmp_buf to dst. |
2586 | "mov x3, %[dst_tmp_buf]\n" |
2587 | "mov w6, #0\n" |
2588 | "50:\n" |
2589 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2590 | "mov w5, #0\n" |
2591 | "51:\n" |
2592 | "ldrb w7, [x3, w5, uxtw]\n" |
2593 | "strb w7, [x4, w5, uxtw]\n" |
2594 | "add w5, w5, #1\n" |
2595 | "cmp w5, w1\n" |
2596 | "blt 51b\n" |
2597 | "add w6, w6, #1\n" |
2598 | "add x3, x3, #4\n" |
2599 | "add x4, x4, x11\n" |
2600 | "cmp w6, w2\n" |
2601 | "blt 50b\n" |
2602 | "b 31f\n" |
2603 | "30:\n" |
2604 | // Yes, all of the 4x4 block fits. |
2605 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2606 | "mov x3, x4\n" |
2607 | "st1 {v16.b}[0], [x3], #1\n" |
2608 | "add x4, x4, x11\n" |
2609 | "st1 {v16.b}[1], [x3], #1\n" |
2610 | "st1 {v16.b}[2], [x3], #1\n" |
2611 | "st1 {v16.b}[3], [x3], #1\n" |
2612 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2613 | "mov x3, x4\n" |
2614 | "st1 {v16.b}[4], [x3], #1\n" |
2615 | "add x4, x4, x11\n" |
2616 | "st1 {v16.b}[5], [x3], #1\n" |
2617 | "st1 {v16.b}[6], [x3], #1\n" |
2618 | "st1 {v16.b}[7], [x3], #1\n" |
2619 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2620 | "mov x3, x4\n" |
2621 | "st1 {v16.b}[8], [x3], #1\n" |
2622 | "add x4, x4, x11\n" |
2623 | "st1 {v16.b}[9], [x3], #1\n" |
2624 | "st1 {v16.b}[10], [x3], #1\n" |
2625 | "st1 {v16.b}[11], [x3], #1\n" |
2626 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2627 | "mov x3, x4\n" |
2628 | "st1 {v16.b}[12], [x3], #1\n" |
2629 | "add x4, x4, x11\n" |
2630 | "st1 {v16.b}[13], [x3], #1\n" |
2631 | "st1 {v16.b}[14], [x3], #1\n" |
2632 | "st1 {v16.b}[15], [x3], #1\n" |
2633 | "31:\n" |
2634 | |
2635 | "add %[dst_ptr], %[dst_ptr], #4\n" |
2636 | |
2637 | RUY_MAKE_ZERO(v16) |
2638 | RUY_MAKE_ZERO(v17) |
2639 | |
2640 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2641 | |
2642 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
2643 | |
2644 | // Add the destination zero point |
2645 | "dup v14.4h, v13.h[4]\n" |
2646 | "saddw v16.4s, v16.4s, v14.4h\n" |
2647 | "saddw v17.4s, v17.4s, v14.4h\n" |
2648 | "saddw v18.4s, v18.4s, v14.4h\n" |
2649 | "saddw v19.4s, v19.4s, v14.4h\n" |
2650 | |
2651 | // Cast-and-saturate from int32 to int16 |
2652 | "ins v4.d[1], x1\n" |
2653 | "sqxtn v16.4h, v16.4s\n" |
2654 | "ins v5.d[1], x2\n" |
2655 | "sqxtn2 v16.8h, v17.4s\n" |
2656 | "ins v6.d[1], x3\n" |
2657 | "sqxtn v17.4h, v18.4s\n" |
2658 | "ins v7.d[1], x4\n" |
2659 | RUY_MAKE_ZERO(v18) |
2660 | "sqxtn2 v17.8h, v19.4s\n" |
2661 | |
2662 | // At this point, v18 -- v31 aren't used anymore for the current block, |
2663 | // so we can start clearing these accumulators for the next block |
2664 | // (next iteration of the main loop). |
2665 | RUY_MAKE_ZERO(v19) |
2666 | |
2667 | "add %[lhs_ptr], %[lhs_ptr], #64\n" |
2668 | RUY_MAKE_ZERO(v20) |
2669 | "add %[rhs_ptr], %[rhs_ptr], #64\n" |
2670 | RUY_MAKE_ZERO(v21) |
2671 | RUY_MAKE_ZERO(v22) |
2672 | |
2673 | RUY_MAKE_ZERO(v23) |
2674 | RUY_MAKE_ZERO(v24) |
2675 | |
2676 | // Load the clamp_min, clamp_max bounds |
2677 | "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2678 | RUY_MAKE_ZERO(v25) |
2679 | "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2680 | RUY_MAKE_ZERO(v26) |
2681 | "dup v14.8h, w2\n" // clamp_min |
2682 | RUY_MAKE_ZERO(v27) |
2683 | "dup v15.8h, w3\n" // clamp_max |
2684 | RUY_MAKE_ZERO(v28) |
2685 | |
2686 | // Apply the clamp_min bound |
2687 | "smax v16.8h, v16.8h, v14.8h\n" |
2688 | "smax v17.8h, v17.8h, v14.8h\n" |
2689 | RUY_MAKE_ZERO(v29) |
2690 | // Apply the clamp_max bound |
2691 | "smin v16.8h, v16.8h, v15.8h\n" |
2692 | "smin v17.8h, v17.8h, v15.8h\n" |
2693 | RUY_MAKE_ZERO(v30) |
2694 | |
// Compute how much of the 4x4 block of destination 16-bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
2700 | "sub w1, %w[dst_rows], %w[row]\n" |
2701 | RUY_MAKE_ZERO(v31) |
2702 | "sub w2, %w[dst_cols], %w[col]\n" |
2703 | "mov w3, #4\n" |
2704 | "cmp w1, #4\n" |
2705 | // Compute w1 = how many rows of the 4x4 block fit |
2706 | "csel w1, w1, w3, le\n" |
2707 | "cmp w2, #4\n" |
2708 | // Compute w2 = how many cols of the 4x4 block fit |
2709 | "csel w2, w2, w3, le\n" |
2710 | |
2711 | // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. |
2712 | "cmp w1, w3\n" |
2713 | "ccmp w2, w3, 0, eq\n" |
2714 | "mov x4, %[dst_ptr]\n" |
2715 | // Yes, all of the 4x4 block fits, go to fast path. |
2716 | "beq 30f\n" |
2717 | // Not all of the 4x4 block fits. |
2718 | // Store to dst_tmp_buf |
2719 | "str q16, [%[dst_tmp_buf], #0]\n" |
2720 | "str q17, [%[dst_tmp_buf], #16]\n" |
2721 | // Slow loop copying from dst_tmp_buf to dst. |
2722 | "mov x3, %[dst_tmp_buf]\n" |
2723 | "mov w6, #0\n" |
2724 | "50:\n" |
2725 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2726 | "mov w5, #0\n" |
2727 | "51:\n" |
2728 | "ldrh w7, [x3, x5, lsl #1]\n" |
2729 | "strh w7, [x4, x5, lsl #1]\n" |
2730 | "add w5, w5, #1\n" |
2731 | "cmp w5, w1\n" |
2732 | "blt 51b\n" |
2733 | "add w6, w6, #1\n" |
2734 | "add x3, x3, #8\n" |
2735 | "add x4, x4, x11\n" |
2736 | "cmp w6, w2\n" |
2737 | "blt 50b\n" |
2738 | "b 31f\n" |
2739 | "30:\n" |
2740 | // Yes, all of the 4x4 block fits. |
2741 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2742 | "mov x3, x4\n" |
2743 | "st1 {v16.h}[0], [x3], #2\n" |
2744 | "add x4, x4, x11\n" |
2745 | "st1 {v16.h}[1], [x3], #2\n" |
2746 | "st1 {v16.h}[2], [x3], #2\n" |
2747 | "st1 {v16.h}[3], [x3], #2\n" |
2748 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2749 | "mov x3, x4\n" |
2750 | "st1 {v16.h}[4], [x3], #2\n" |
2751 | "add x4, x4, x11\n" |
2752 | "st1 {v16.h}[5], [x3], #2\n" |
2753 | "st1 {v16.h}[6], [x3], #2\n" |
2754 | "st1 {v16.h}[7], [x3], #2\n" |
2755 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2756 | "mov x3, x4\n" |
2757 | "st1 {v17.h}[0], [x3], #2\n" |
2758 | "add x4, x4, x11\n" |
2759 | "st1 {v17.h}[1], [x3], #2\n" |
2760 | "st1 {v17.h}[2], [x3], #2\n" |
2761 | "st1 {v17.h}[3], [x3], #2\n" |
2762 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2763 | "mov x3, x4\n" |
2764 | "st1 {v17.h}[4], [x3], #2\n" |
2765 | "add x4, x4, x11\n" |
2766 | "st1 {v17.h}[5], [x3], #2\n" |
2767 | "st1 {v17.h}[6], [x3], #2\n" |
2768 | "st1 {v17.h}[7], [x3], #2\n" |
2769 | "31:\n" |
2770 | |
2771 | "add %[dst_ptr], %[dst_ptr], #8\n" |
2772 | |
2773 | RUY_MAKE_ZERO(v16) |
2774 | RUY_MAKE_ZERO(v17) |
2775 | |
2776 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2777 | |
2778 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
2779 | |
2780 | "ldr x1, [%[lhs_ptr], #8]\n" |
2781 | "ldr x2, [%[lhs_ptr], #24]\n" |
2782 | "ldr x3, [%[lhs_ptr], #40]\n" |
2783 | "ldr x4, [%[lhs_ptr], #56]\n" |
2784 | |
2785 | "ins v0.d[1], x1\n" |
2786 | "ldr x1, [%[rhs_ptr], #8]\n" |
2787 | "ins v1.d[1], x2\n" |
2788 | "ldr x2, [%[rhs_ptr], #24]\n" |
2789 | "ins v2.d[1], x3\n" |
2790 | "ldr x3, [%[rhs_ptr], #40]\n" |
2791 | "ins v3.d[1], x4\n" |
2792 | "ldr x4, [%[rhs_ptr], #56]\n" |
2793 | "ins v4.d[1], x1\n" |
2794 | "ins v5.d[1], x2\n" |
2795 | "ins v6.d[1], x3\n" |
2796 | "ins v7.d[1], x4\n" |
2797 | |
2798 | // Since the store type is the same as the accum type, no need for |
2799 | // downcast. There's also no need for clamp by min/max. |
2800 | |
2801 | // At this point, v20 -- v31 aren't used anymore for the current block, |
2802 | // so we can start clearing these accumulators for the next block |
2803 | // (next iteration of the main loop). |
2804 | |
2805 | RUY_MAKE_ZERO(v20) |
2806 | "add %[lhs_ptr], %[lhs_ptr], #64\n" |
2807 | RUY_MAKE_ZERO(v21) |
2808 | "add %[rhs_ptr], %[rhs_ptr], #64\n" |
2809 | RUY_MAKE_ZERO(v22) |
2810 | |
2811 | RUY_MAKE_ZERO(v23) |
2812 | RUY_MAKE_ZERO(v24) |
2813 | RUY_MAKE_ZERO(v25) |
2814 | RUY_MAKE_ZERO(v26) |
2815 | RUY_MAKE_ZERO(v27) |
2816 | RUY_MAKE_ZERO(v28) |
2817 | RUY_MAKE_ZERO(v29) |
2818 | RUY_MAKE_ZERO(v30) |
2819 | |
// Compute how much of the 4x4 block of destination int32 values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
2825 | "sub w1, %w[dst_rows], %w[row]\n" |
2826 | RUY_MAKE_ZERO(v31) |
2827 | "sub w2, %w[dst_cols], %w[col]\n" |
2828 | "mov w3, #4\n" |
2829 | "cmp w1, #4\n" |
2830 | // Compute w1 = how many rows of the 4x4 block fit |
2831 | "csel w1, w1, w3, le\n" |
2832 | "cmp w2, #4\n" |
2833 | // Compute w2 = how many cols of the 4x4 block fit |
2834 | "csel w2, w2, w3, le\n" |
2835 | |
// Test if w1==4 && w2==4, i.e. if all of the 4x4 block fits.
2837 | "cmp w1, w3\n" |
2838 | "ccmp w2, w3, 0, eq\n" |
2839 | "mov x4, %[dst_ptr]\n" |
2840 | // Yes, all of the 4x4 block fits, go to fast path. |
2841 | "beq 30f\n" |
2842 | // Not all of the 4x4 block fits. |
2843 | // Store to dst_tmp_buf |
2844 | "str q16, [%[dst_tmp_buf], #0]\n" |
2845 | "str q17, [%[dst_tmp_buf], #16]\n" |
2846 | "str q18, [%[dst_tmp_buf], #32]\n" |
2847 | "str q19, [%[dst_tmp_buf], #48]\n" |
2848 | // Slow loop copying from dst_tmp_buf to dst. |
2849 | "mov x3, %[dst_tmp_buf]\n" |
2850 | "mov w6, #0\n" |
2851 | "50:\n" |
2852 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2853 | "mov w5, #0\n" |
2854 | "51:\n" |
2855 | "ldr w7, [x3, x5, lsl #2]\n" |
2856 | "str w7, [x4, x5, lsl #2]\n" |
2857 | "add w5, w5, #1\n" |
2858 | "cmp w5, w1\n" |
2859 | "blt 51b\n" |
2860 | "add w6, w6, #1\n" |
2861 | "add x3, x3, #16\n" |
2862 | "add x4, x4, x11\n" |
2863 | "cmp w6, w2\n" |
2864 | "blt 50b\n" |
2865 | "b 31f\n" |
2866 | "30:\n" |
2867 | // Yes, all of the 4x4 block fits. |
2868 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2869 | "mov x3, x4\n" |
2870 | "st1 {v16.s}[0], [x3], #4\n" |
2871 | "add x4, x4, x11\n" |
2872 | "st1 {v16.s}[1], [x3], #4\n" |
2873 | "st1 {v16.s}[2], [x3], #4\n" |
2874 | "st1 {v16.s}[3], [x3], #4\n" |
2875 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2876 | "mov x3, x4\n" |
2877 | "st1 {v17.s}[0], [x3], #4\n" |
2878 | "add x4, x4, x11\n" |
2879 | "st1 {v17.s}[1], [x3], #4\n" |
2880 | "st1 {v17.s}[2], [x3], #4\n" |
2881 | "st1 {v17.s}[3], [x3], #4\n" |
2882 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2883 | "mov x3, x4\n" |
2884 | "st1 {v18.s}[0], [x3], #4\n" |
2885 | "add x4, x4, x11\n" |
2886 | "st1 {v18.s}[1], [x3], #4\n" |
2887 | "st1 {v18.s}[2], [x3], #4\n" |
2888 | "st1 {v18.s}[3], [x3], #4\n" |
2889 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
2890 | "mov x3, x4\n" |
2891 | "st1 {v19.s}[0], [x3], #4\n" |
2892 | "add x4, x4, x11\n" |
2893 | "st1 {v19.s}[1], [x3], #4\n" |
2894 | "st1 {v19.s}[2], [x3], #4\n" |
2895 | "st1 {v19.s}[3], [x3], #4\n" |
2896 | "31:\n" |
2897 | |
2898 | "add %[dst_ptr], %[dst_ptr], #16\n" |
2899 | |
2900 | RUY_MAKE_ZERO(v16) |
2901 | RUY_MAKE_ZERO(v17) |
2902 | RUY_MAKE_ZERO(v18) |
2903 | RUY_MAKE_ZERO(v19) |
2904 | |
2905 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
2906 | |
2907 | // For the next block: perform the first few multiply-adds on the data |
2908 | // that we have already loaded. |
2909 | "smull v8.8h, v0.8b, v4.8b\n" |
2910 | "smull v9.8h, v1.8b, v4.8b\n" |
2911 | "smull v10.8h, v2.8b, v4.8b\n" |
2912 | // Reload some params --- we had used x5 -- x7 for a few other things |
2913 | // since the last time we had loaded them. |
2914 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
2915 | "smull v11.8h, v3.8b, v4.8b\n" |
2916 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
2917 | "smull v12.8h, v0.8b, v5.8b\n" |
2918 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
2919 | "smull v13.8h, v1.8b, v5.8b\n" |
2920 | "smull v14.8h, v2.8b, v5.8b\n" |
2921 | "smull v15.8h, v3.8b, v5.8b\n" |
2922 | // Move to the next block of the destination matrix, for the next iter |
2923 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
2924 | // been updated earlier. |
2925 | // Have we reached the end row? |
2926 | "cmp %w[row], w7\n" |
2927 | "smlal2 v8.8h, v0.16b, v4.16b\n" |
2928 | "smlal2 v9.8h, v1.16b, v4.16b\n" |
2929 | "smlal2 v10.8h, v2.16b, v4.16b\n" |
2930 | "smlal2 v11.8h, v3.16b, v4.16b\n" |
2931 | "smlal2 v12.8h, v0.16b, v5.16b\n" |
2932 | "smlal2 v13.8h, v1.16b, v5.16b\n" |
2933 | "smlal2 v14.8h, v2.16b, v5.16b\n" |
2934 | "smlal2 v15.8h, v3.16b, v5.16b\n" |
2935 | |
2936 | |
2937 | "beq 20f\n" // yes, end row. |
2938 | // Not end row. Move to the next row. |
2939 | "add %w[row], %w[row], #4\n" |
2940 | "b 21f\n" |
2941 | "20:\n" |
2942 | // Was already at end row. |
2943 | "mov %w[row], w6\n" // Move back to first row. |
2944 | "add %w[col], %w[col], #4\n" // Move to the next column. |
2945 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" |
2946 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
2947 | "21:\n" |
2948 | |
2949 | // Main loop exit condition: have we hit the end column? |
2950 | "cmp %w[col], w8\n" |
2951 | |
2952 | // w1 is the number of levels of depth that we have already loaded |
2953 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
// above, this is currently 16.
2955 | "mov w1, #16\n" |
2956 | |
2957 | "ble 1b\n" |
2958 | |
2959 | // clang-format on |
2960 | |
: [lhs_col_ptr] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr),
[lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr),
[dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col)
: [params] "r" (&params), [dst_rows] "r" (params.dst_rows),
[dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf),
[dst_type_id] "r" (params.dst_type_id)
2967 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
2968 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
2969 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
2970 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
2971 | } |
2972 | |
2973 | // Kernel taking advantage of the optional dotprod instruction. |
2974 | // This is very similar to (and directly inspired by) this gemmlowp kernel |
2975 | // which was contributed by David Mansell at ARM: |
2976 | // NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct |
2977 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 |
2978 | // |
// Besides the ruy-ification, the main difference here is that we use an 8x8
2980 | // instead of 12x8 width, so as to stick to power-of-two widths. This slightly |
2981 | // narrower kernel layout is still wide enough to achieve high performance |
2982 | // although we haven't actually performed a real comparison to know exactly |
2983 | // how this compares to ARM's aforementioned kernel. |
2984 | // |
// Relevant target CPUs for this kernel include the ARM Cortex-A76,
// which is 64-bit, out-of-order and supports dotprod.
2987 | void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) { |
2988 | profiler::ScopeLabel label("Kernel (kNeonDotprod)" ); |
2989 | |
2990 | CheckOffsetsInKernelParams8bit(params); |
2991 | |
2992 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
const std::int8_t* rhs_col_ptr =
    static_cast<const std::int8_t*>(params.rhs_base_ptr);
2995 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
2996 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
2997 | void* dst_col_ptr = params.dst_base_ptr; |
2998 | void* dst_ptr = dst_col_ptr; |
2999 | int row = params.start_row; |
3000 | int col = params.start_col; |
3001 | |
3002 | // The asm kernel below has the following NEON register allocation: |
3003 | // |
3004 | // v16 -- v31 are int32 accumulators. |
3005 | // During accumulation, v0 -- v15 are used to load int8 data from LHS and |
// RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
3007 | // v3 are used to load a 4x8 block of RHS, like this: |
3008 | // |
3009 | // int8 RHS 4x8 block |
3010 | // /-----------------------------------------| |
3011 | // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| |
3012 | // | ... ... | |
3013 | // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| |
3014 | // \-----------------------------------------/ |
3015 | // int8 LHS 8x4 block |
3016 | // /---------------------\ /-----------------------------------------| |
3017 | // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| |
3018 | // | ... ... | | ... ... | |
3019 | // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| |
3020 | // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| |
3021 | // | ... ... | | ... ... | |
3022 | // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| |
3023 | // \---------------------/ \-----------------------------------------/ |
3024 | // int32 accumulators 8x8 block |
3025 | // |
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS: there,
// instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15.
3029 | // |
3030 | // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are |
3031 | // unused, and v8 -- v15 are used for loading parameters used for the |
3032 | // post-accumulation part of the kernel. |
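//
// As a rough scalar sketch of the elementary accumulation step: an indexed
// dot-product instruction such as
//
//   sdot v16.4s, v0.16b, v2.4b[0]
//
// behaves like the pseudo-C
//
//   for (int i = 0; i < 4; i++) {
//     for (int k = 0; k < 4; k++) {
//       v16.s[i] += int32(v0.b[4 * i + k]) * int32(v2.b[k]);
//     }
//   }
//
// i.e. it accumulates 4 levels of depth for 4 LHS rows against the 4 depth
// values of one RHS column held in byte-group 0 of v2.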
3033 | asm volatile( |
3034 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
3035 | |
3036 | // clang-format off |
3037 | |
3038 | // Load some parameters into registers. |
3039 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
3040 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
3041 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
3042 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
3043 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
3044 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
3045 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
3046 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
3047 | |
3048 | // Load the first 32 bytes of LHS and RHS data. |
3049 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
3050 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
3051 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
3052 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
3053 | |
3054 | // Clear accumulators. |
3055 | RUY_MAKE_ZERO(v16) |
3056 | RUY_MAKE_ZERO(v17) |
3057 | RUY_MAKE_ZERO(v18) |
3058 | RUY_MAKE_ZERO(v19) |
3059 | RUY_MAKE_ZERO(v20) |
3060 | RUY_MAKE_ZERO(v21) |
3061 | RUY_MAKE_ZERO(v22) |
3062 | RUY_MAKE_ZERO(v23) |
3063 | RUY_MAKE_ZERO(v24) |
3064 | RUY_MAKE_ZERO(v25) |
3065 | RUY_MAKE_ZERO(v26) |
3066 | RUY_MAKE_ZERO(v27) |
3067 | RUY_MAKE_ZERO(v28) |
3068 | RUY_MAKE_ZERO(v29) |
3069 | RUY_MAKE_ZERO(v30) |
3070 | RUY_MAKE_ZERO(v31) |
3071 | |
3072 | // w1 is the number of levels of depth that we have already loaded |
3073 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
3074 | // above, this is currently 4. |
3075 | "mov w1, #4\n" |
3076 | |
3077 | // Perform the first few multiply-adds on the data that we have already |
3078 | // loaded. |
3079 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
3080 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
3081 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
3082 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
3083 | |
3084 | // Main loop of the whole GEMM, over rows and columns of the |
3085 | // destination matrix. |
3086 | "1:\n" |
3087 | |
3088 | // Optional, maximally-streaming, partial-unrolling (4x unrolled) |
3089 | // optimization of the kernel inner loop (over depth). For more |
3090 | // comments, see the non-unrolled loop below after the #endif. |
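// Each iteration of the unrolled loop below handles 16 levels of depth
// (4x the 4 levels of the plain loop), and the loads feeding the next
// iteration are interleaved with the dot-products of the current one.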
3091 | #if RUY_OPT(MAX_STREAMING) |
3092 | "cmp w12, #32\n" |
3093 | "blt 78f\n" |
3094 | |
3095 | "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" |
3096 | "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" |
3097 | "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" |
3098 | "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" |
3099 | "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" |
3100 | "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" |
3101 | "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" |
3102 | "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" |
3103 | "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" |
3104 | "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" |
3105 | "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" |
3106 | "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" |
3107 | "mov w1, #16\n" |
3108 | |
3109 | "and w3, w12, #-16\n" |
3110 | "81:\n" |
3111 | "add w1, w1, #16\n" |
3112 | |
3113 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
3114 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
3115 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
3116 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
3117 | "ldr q0, [%[lhs_ptr], #0]\n" |
3118 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
3119 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
3120 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
3121 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
3122 | "ldr q2, [%[rhs_ptr], #0]\n" |
3123 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
3124 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
3125 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
3126 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
3127 | "ldr q1, [%[lhs_ptr], #16]\n" |
3128 | |
3129 | ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" |
3130 | ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" |
3131 | "ldr q3, [%[rhs_ptr], #16]\n" |
3132 | ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" |
3133 | ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" |
3134 | ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" |
3135 | ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" |
3136 | ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" |
3137 | ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" |
3138 | ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" |
3139 | ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" |
3140 | ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" |
3141 | ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" |
3142 | "ldr q5, [%[lhs_ptr], #48]\n" |
3143 | ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" |
3144 | ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" |
3145 | "ldr q7, [%[rhs_ptr], #48]\n" |
3146 | ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" |
3147 | ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" |
3148 | "ldr q4, [%[lhs_ptr], #32]\n" |
3149 | |
3150 | ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" |
3151 | ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" |
3152 | "ldr q6, [%[rhs_ptr], #32]\n" |
3153 | ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" |
3154 | ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" |
3155 | ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" |
3156 | ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" |
3157 | ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" |
3158 | ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" |
3159 | ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" |
3160 | ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" |
3161 | ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" |
3162 | ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" |
3163 | "ldr q9, [%[lhs_ptr], #80]\n" |
3164 | ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" |
3165 | ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" |
3166 | "ldr q11, [%[rhs_ptr], #80]\n" |
3167 | ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" |
3168 | ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" |
3169 | "ldr q8, [%[lhs_ptr], #64]\n" |
3170 | |
3171 | ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" |
3172 | ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" |
3173 | "ldr q10, [%[rhs_ptr], #64]\n" |
3174 | ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" |
3175 | ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" |
3176 | "add %[lhs_ptr], %[lhs_ptr], #128\n" |
3177 | ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" |
3178 | ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" |
3179 | "add %[rhs_ptr], %[rhs_ptr], #128\n" |
3180 | ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" |
3181 | ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" |
3182 | ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" |
3183 | ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" |
3184 | "cmp w1, w3\n" |
3185 | ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" |
3186 | ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" |
3187 | "ldr q13, [%[lhs_ptr], #-16]\n" |
3188 | ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" |
3189 | ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" |
3190 | "ldr q15, [%[rhs_ptr], #-16]\n" |
3191 | ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" |
3192 | ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" |
3193 | "ldr q12, [%[lhs_ptr], #-32]\n" |
3194 | |
3195 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
3196 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
3197 | "ldr q14, [%[rhs_ptr], #-32]\n" |
3198 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
3199 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
3200 | |
3201 | "blt 81b\n" |
3202 | |
3203 | ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" |
3204 | ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" |
3205 | ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" |
3206 | ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" |
3207 | ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" |
3208 | ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" |
3209 | ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" |
3210 | ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" |
3211 | ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" |
3212 | ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" |
3213 | ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" |
3214 | ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" |
3215 | ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" |
3216 | ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" |
3217 | ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" |
3218 | ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" |
3219 | |
3220 | ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" |
3221 | ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" |
3222 | ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" |
3223 | ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" |
3224 | ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" |
3225 | ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" |
3226 | ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" |
3227 | ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" |
3228 | ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" |
3229 | ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" |
3230 | ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" |
3231 | ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" |
3232 | ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" |
3233 | ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" |
3234 | ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" |
3235 | ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" |
3236 | |
3237 | ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" |
3238 | ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" |
3239 | ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" |
3240 | ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" |
3241 | ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" |
3242 | ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" |
3243 | ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" |
3244 | ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" |
3245 | ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" |
3246 | ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" |
3247 | ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" |
3248 | ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" |
3249 | ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" |
3250 | ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" |
3251 | ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" |
3252 | ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" |
3253 | |
3254 | "78:\n" |
3255 | |
3256 | #endif // #if RUY_OPT(MAX_STREAMING) |
3257 | |
3258 | // Ordinary kernel inner loop (over depth), the simpler loop that the |
3259 | // above was an equivalent 4x-partially-unrolled version of. |
3260 | |
3261 | // Reminder - w1 is how many levels of depth we have already loaded |
3262 | // data for, w12 is the total depth. |
3263 | "cmp w1, w12\n" |
3264 | "beq 79f\n" |
3265 | |
3266 | "2:\n" |
3267 | |
3268 | // Because of the data that we have already loaded, we can start the |
3269 | // loop body right away with some multiply-adds. |
3270 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
3271 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
3272 | // Each iteration of this loop advances by 4 levels of depth. |
3273 | "add w1, w1, #4\n" |
3274 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
3275 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
3276 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
3277 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
3278 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
3279 | // Loop termination condition. |
3280 | "cmp w1, w12\n" |
3281 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
3282 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
3283 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
3284 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
3285 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
3286 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
3287 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
3288 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
3289 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
3290 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
3291 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
3292 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
3293 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
3294 | |
3295 | "blt 2b\n" |
3296 | |
3297 | "79:\n" |
3298 | // End of the inner loop on depth. Now perform the remaining |
3299 | // multiply-adds of the last 4 levels of depth, for which the LHS |
3300 | // and RHS data is already loaded. |
3301 | |
3302 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
3303 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
3304 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
3305 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
3306 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
3307 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
3308 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
3309 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
3310 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
3311 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
3312 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
3313 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
3314 | |
3315 | // End of accumulation. The registers v16 -- v31 contain the final |
3316 | // int32 accumulator values of the current 8x8 destination block. |
3317 | // We now have to compute the final 8-bit values from these int32 |
3318 | // accumulators, and advance to the next 8x8 block. We intertwine |
3319 | // these two aspects whenever possible for optimal pipelining, both |
3320 | // at the data flow level (prefetch data for next block as early as |
3321 | // possible) and instruction pipelining level (some of the next-block |
3322 | // work can dual-issue with some of the final work on the current |
3323 | // block). |
3324 | |
3325 | // Logic to advance to the next block in preparation for the next |
3326 | // iteration of the main loop. For now, we only want to compute |
3327 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
3328 | // not yet ready to update the values of row and col, as we still need |
3329 | // the current values for the rest of the work on the current block. |
3330 | |
3331 | "cmp %w[row], w7\n" // Have we finished the last row? |
3332 | "bge 4f\n" // If finished last row, go to 4 |
3333 | // Not finished last row: then advance to next row. |
3334 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
3335 | "b 5f\n" |
3336 | "4:\n" // Finished last row... |
3337 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
3338 | // Now we need to advance to the next column. If we already |
3339 | // finished the last column, then in principle we are done, however |
3340 | // we can't just return here, as we need to allow the end work of the |
3341 | // current block to complete. The good news is that at this point it |
3342 | // doesn't matter what data we load for the next column, since |
3343 | // we will exit from the main loop below before actually storing |
3344 | // anything computed from that data. |
3345 | "cmp %w[col], w8\n" // Have we finished the last column? |
3346 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
3347 | // Not finished last column: then advance to next column. |
3348 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
3349 | "5:\n" |
3350 | |
3351 | // Set the LHS and RHS data pointers to the start of the columns just |
3352 | // computed. |
3353 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
3354 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
3355 | |
3356 | // Load some parameters needed for the end work on current block. |
3357 | "mvni v8.4s, #0\n" |
3358 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
3359 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
3360 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
3361 | |
3362 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
3363 | // Determine the channel index. |
3364 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
3365 | "csel w3, %w[row], %w[col], eq\n" |
3366 | |
3367 | // Offset the bias pointer as needed given the current row, col. |
3368 | "add x5, x1, x3, lsl #2\n" |
3369 | |
3370 | // If there is no bias, use no offset, just address the passed zero |
3371 | // data. |
3372 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
3373 | "csel x1, x1, x5, eq\n" |
3374 | |
3375 | // Load 8 bias values. |
3376 | "ld1 {v14.4s}, [x1], #16\n" |
3377 | "ld1 {v15.4s}, [x1]\n" |
3378 | |
3379 | // Now that we know what LHS and RHS data the next iteration of the |
3380 | // main loop will need to load, we start loading the first 32 bytes of |
3381 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
3382 | // in the rest of the work on the current block. |
3383 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
3384 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
3385 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
3386 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
3387 | |
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point).
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
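// For reference, the zero-point corrections applied here and below follow
// the decomposition (equation (7) in the paper cited above):
//
//   sum_k (lhs[i,k] - Z1) * (rhs[k,j] - Z2)
//     = sum_k lhs[i,k] * rhs[k,j]
//       - Z2 * lhs_sums[i] - Z1 * rhs_sums[j] + depth * Z1 * Z2
//
// The accumulators currently hold the first term; prod_zp_depth is the
// depth * Z1 * Z2 term, folded into the bias here; the sums terms are
// subtracted further below.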
3390 | "add v14.4s, v14.4s, v9.4s\n" |
3391 | "add v15.4s, v15.4s, v9.4s\n" |
3392 | |
3393 | // Perform the bias-addition (per the above, we have just folded into |
3394 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
3395 | // Jump based on channel dimension. |
3396 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
3397 | "bne 6f\n" |
3398 | // Case where channels are rows |
3399 | "add v16.4s, v16.4s, v14.4s\n" |
3400 | "add v17.4s, v17.4s, v15.4s\n" |
3401 | "add v18.4s, v18.4s, v14.4s\n" |
3402 | "add v19.4s, v19.4s, v15.4s\n" |
3403 | "add v20.4s, v20.4s, v14.4s\n" |
3404 | "add v21.4s, v21.4s, v15.4s\n" |
3405 | "add v22.4s, v22.4s, v14.4s\n" |
3406 | "add v23.4s, v23.4s, v15.4s\n" |
3407 | "add v24.4s, v24.4s, v14.4s\n" |
3408 | "add v25.4s, v25.4s, v15.4s\n" |
3409 | "add v26.4s, v26.4s, v14.4s\n" |
3410 | "add v27.4s, v27.4s, v15.4s\n" |
3411 | "add v28.4s, v28.4s, v14.4s\n" |
3412 | "add v29.4s, v29.4s, v15.4s\n" |
3413 | "add v30.4s, v30.4s, v14.4s\n" |
3414 | "add v31.4s, v31.4s, v15.4s\n" |
3415 | "b 7f\n" |
3416 | |
3417 | "6:\n" |
3418 | // Case where channels are columns |
3419 | "dup v10.4s, v14.s[0]\n" |
3420 | "dup v11.4s, v14.s[1]\n" |
3421 | "dup v12.4s, v14.s[2]\n" |
3422 | "dup v13.4s, v14.s[3]\n" |
3423 | "add v16.4s, v16.4s, v10.4s\n" |
3424 | "add v17.4s, v17.4s, v10.4s\n" |
3425 | "add v18.4s, v18.4s, v11.4s\n" |
3426 | "add v19.4s, v19.4s, v11.4s\n" |
3427 | "add v20.4s, v20.4s, v12.4s\n" |
3428 | "add v21.4s, v21.4s, v12.4s\n" |
3429 | "add v22.4s, v22.4s, v13.4s\n" |
3430 | "add v23.4s, v23.4s, v13.4s\n" |
3431 | "dup v10.4s, v15.s[0]\n" |
3432 | "dup v11.4s, v15.s[1]\n" |
3433 | "dup v12.4s, v15.s[2]\n" |
3434 | "dup v13.4s, v15.s[3]\n" |
3435 | "add v24.4s, v24.4s, v10.4s\n" |
3436 | "add v25.4s, v25.4s, v10.4s\n" |
3437 | "add v26.4s, v26.4s, v11.4s\n" |
3438 | "add v27.4s, v27.4s, v11.4s\n" |
3439 | "add v28.4s, v28.4s, v12.4s\n" |
3440 | "add v29.4s, v29.4s, v12.4s\n" |
3441 | "add v30.4s, v30.4s, v13.4s\n" |
3442 | "add v31.4s, v31.4s, v13.4s\n" |
3443 | "7:\n" |
3444 | |
3445 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
3446 | "beq 401f\n" |
3447 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
3448 | "add x3, x3, %x[col], lsl #2\n" |
3449 | "ld1 {v14.4s}, [x3], #16\n" |
3450 | "ld1 {v15.4s}, [x3]\n" |
3451 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
3452 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
3453 | // Subtract rhs_sums * lhs_zero_point, per |
3454 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
3455 | "mls v16.4s, v10.4s, v14.s[0]\n" |
3456 | "mls v17.4s, v10.4s, v14.s[0]\n" |
3457 | "mls v18.4s, v10.4s, v14.s[1]\n" |
3458 | "mls v19.4s, v10.4s, v14.s[1]\n" |
3459 | "mls v20.4s, v10.4s, v14.s[2]\n" |
3460 | "mls v21.4s, v10.4s, v14.s[2]\n" |
3461 | "mls v22.4s, v10.4s, v14.s[3]\n" |
3462 | "mls v23.4s, v10.4s, v14.s[3]\n" |
3463 | "mls v24.4s, v10.4s, v15.s[0]\n" |
3464 | "mls v25.4s, v10.4s, v15.s[0]\n" |
3465 | "mls v26.4s, v10.4s, v15.s[1]\n" |
3466 | "mls v27.4s, v10.4s, v15.s[1]\n" |
3467 | "mls v28.4s, v10.4s, v15.s[2]\n" |
3468 | "mls v29.4s, v10.4s, v15.s[2]\n" |
3469 | "mls v30.4s, v10.4s, v15.s[3]\n" |
3470 | "mls v31.4s, v10.4s, v15.s[3]\n" |
3471 | "401:\n" |
3472 | |
3473 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
3474 | "beq 402f\n" |
3475 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
3476 | "add x2, x2, %x[row], lsl #2\n" |
3477 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
3478 | // Load 4 lhs_sums values. |
3479 | "ld1 {v11.4s}, [x2], #16\n" |
3480 | "ld1 {v12.4s}, [x2]\n" |
3481 | "ins v13.s[1], w5\n" // rhs_zero_point |
3482 | // Compute lhs_sums * rhs_zero_point. |
3483 | "mul v11.4s, v11.4s, v13.s[1]\n" |
3484 | "mul v12.4s, v12.4s, v13.s[1]\n" |
3485 | // Subtract lhs_sums * rhs_zero_point, per |
3486 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
3487 | "sub v16.4s, v16.4s, v11.4s\n" |
3488 | "sub v17.4s, v17.4s, v12.4s\n" |
3489 | "sub v18.4s, v18.4s, v11.4s\n" |
3490 | "sub v19.4s, v19.4s, v12.4s\n" |
3491 | "sub v20.4s, v20.4s, v11.4s\n" |
3492 | "sub v21.4s, v21.4s, v12.4s\n" |
3493 | "sub v22.4s, v22.4s, v11.4s\n" |
3494 | "sub v23.4s, v23.4s, v12.4s\n" |
3495 | "sub v24.4s, v24.4s, v11.4s\n" |
3496 | "sub v25.4s, v25.4s, v12.4s\n" |
3497 | "sub v26.4s, v26.4s, v11.4s\n" |
3498 | "sub v27.4s, v27.4s, v12.4s\n" |
3499 | "sub v28.4s, v28.4s, v11.4s\n" |
3500 | "sub v29.4s, v29.4s, v12.4s\n" |
3501 | "sub v30.4s, v30.4s, v11.4s\n" |
3502 | "sub v31.4s, v31.4s, v12.4s\n" |
3503 | |
3504 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
3505 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
3506 | |
3507 | "402:\n" |
3508 | |
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values.
3511 | |
3512 | // As part of this down-quantization, our int32 values will be |
3513 | // multiplied by a multiplier that has a fixed-point component and an |
3514 | // exponent component. |
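//
// Roughly, in scalar pseudo-C terms, each int32 accumulator v is rescaled
// as
//
//   v = v << left_shift;                          // sshl below
//   v = SaturatingDoublingHighMul(v, fixedpoint); // sqdmulh below
//   v = RoundingRightShift(v, right_shift);       // srshl below
//
// where left_shift and right_shift are derived from the (possibly
// per-channel) exponent, so that the net effect approximates
// multiplication by fixedpoint * 2^exponent.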
3515 | |
// Load the exponent part of the multiplier.
3517 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
3518 | // Determine the channel index. |
3519 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
3520 | "csel w3, %w[row], %w[col], eq\n" |
3521 | // Compute the multiplier_exponent pointer |
3522 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
3523 | "add x5, x1, x3, lsl #2\n" |
3524 | "csel x1, x1, x5, eq\n" |
3525 | // Load multiplier_exponent |
3526 | "ldr q9, [x1]\n" |
3527 | "ldr q10, [x1, #16]\n" |
3528 | // Separate positive and negative exponents |
3529 | "smin v11.4s, v8.4s, v9.4s\n" |
3530 | "smin v12.4s, v8.4s, v10.4s\n" |
3531 | "sub v9.4s, v9.4s, v11.4s\n" |
3532 | "sub v10.4s, v10.4s, v12.4s\n" |
3533 | |
3534 | // Compute the multiplier_fixedpoint pointer |
3535 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
3536 | "add x5, x4, x3, lsl #2\n" |
3537 | "csel x4, x4, x5, eq\n" |
3538 | // Load multiplier_fixedpoint |
3539 | "ldr q14, [x4]\n" |
3540 | "ldr q15, [x4, #16]\n" |
3541 | |
3542 | // Jump based on channel dimension. |
3543 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
3544 | "bne 8f\n" |
3545 | // Case where channels are rows |
3546 | |
3547 | // Apply the positive exponent part of the multiplier. |
3548 | "sshl v16.4s, v16.4s, v9.4s\n" |
3549 | "sshl v17.4s, v17.4s, v10.4s\n" |
3550 | "sshl v18.4s, v18.4s, v9.4s\n" |
3551 | "sshl v19.4s, v19.4s, v10.4s\n" |
3552 | "sshl v20.4s, v20.4s, v9.4s\n" |
3553 | "sshl v21.4s, v21.4s, v10.4s\n" |
3554 | "sshl v22.4s, v22.4s, v9.4s\n" |
3555 | "sshl v23.4s, v23.4s, v10.4s\n" |
3556 | "sshl v24.4s, v24.4s, v9.4s\n" |
3557 | "sshl v25.4s, v25.4s, v10.4s\n" |
3558 | "sshl v26.4s, v26.4s, v9.4s\n" |
3559 | "sshl v27.4s, v27.4s, v10.4s\n" |
3560 | "sshl v28.4s, v28.4s, v9.4s\n" |
3561 | "sshl v29.4s, v29.4s, v10.4s\n" |
3562 | "sshl v30.4s, v30.4s, v9.4s\n" |
3563 | "sshl v31.4s, v31.4s, v10.4s\n" |
3564 | "10:\n" |
3565 | |
3566 | // Apply the fixed-point part of the multiplier. |
3567 | "sqdmulh v16.4s, v16.4s, v14.4s\n" |
3568 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
3569 | "sqdmulh v18.4s, v18.4s, v14.4s\n" |
3570 | "sqdmulh v19.4s, v19.4s, v15.4s\n" |
3571 | "sqdmulh v20.4s, v20.4s, v14.4s\n" |
3572 | "sqdmulh v21.4s, v21.4s, v15.4s\n" |
3573 | "sqdmulh v22.4s, v22.4s, v14.4s\n" |
3574 | "sqdmulh v23.4s, v23.4s, v15.4s\n" |
3575 | "sqdmulh v24.4s, v24.4s, v14.4s\n" |
3576 | "sqdmulh v25.4s, v25.4s, v15.4s\n" |
3577 | "sqdmulh v26.4s, v26.4s, v14.4s\n" |
3578 | "sqdmulh v27.4s, v27.4s, v15.4s\n" |
3579 | "sqdmulh v28.4s, v28.4s, v14.4s\n" |
3580 | "sqdmulh v29.4s, v29.4s, v15.4s\n" |
3581 | "sqdmulh v30.4s, v30.4s, v14.4s\n" |
3582 | "sqdmulh v31.4s, v31.4s, v15.4s\n" |
3583 | |
3584 | // Apply the negative exponent part of the multiplier. |
3585 | "srshl v16.4s, v16.4s, v11.4s\n" |
3586 | "srshl v17.4s, v17.4s, v12.4s\n" |
3587 | "srshl v18.4s, v18.4s, v11.4s\n" |
3588 | "srshl v19.4s, v19.4s, v12.4s\n" |
3589 | "srshl v20.4s, v20.4s, v11.4s\n" |
3590 | "srshl v21.4s, v21.4s, v12.4s\n" |
3591 | "srshl v22.4s, v22.4s, v11.4s\n" |
3592 | "srshl v23.4s, v23.4s, v12.4s\n" |
3593 | "srshl v24.4s, v24.4s, v11.4s\n" |
3594 | "srshl v25.4s, v25.4s, v12.4s\n" |
3595 | "srshl v26.4s, v26.4s, v11.4s\n" |
3596 | "srshl v27.4s, v27.4s, v12.4s\n" |
3597 | "srshl v28.4s, v28.4s, v11.4s\n" |
3598 | "srshl v29.4s, v29.4s, v12.4s\n" |
3599 | "srshl v30.4s, v30.4s, v11.4s\n" |
3600 | "srshl v31.4s, v31.4s, v12.4s\n" |
3601 | "b 9f\n" |
3602 | |
3603 | "8:\n" |
3604 | // Case where channels are columns |
3605 | |
3606 | // Apply the positive exponent part of the multiplier. |
3607 | "dup v4.4s, v9.s[0]\n" |
3608 | "dup v5.4s, v9.s[1]\n" |
3609 | "dup v6.4s, v9.s[2]\n" |
3610 | "dup v7.4s, v9.s[3]\n" |
3611 | "sshl v16.4s, v16.4s, v4.4s\n" |
3612 | "sshl v17.4s, v17.4s, v4.4s\n" |
3613 | "sshl v18.4s, v18.4s, v5.4s\n" |
3614 | "sshl v19.4s, v19.4s, v5.4s\n" |
3615 | "sshl v20.4s, v20.4s, v6.4s\n" |
3616 | "sshl v21.4s, v21.4s, v6.4s\n" |
3617 | "sshl v22.4s, v22.4s, v7.4s\n" |
3618 | "sshl v23.4s, v23.4s, v7.4s\n" |
3619 | "dup v4.4s, v10.s[0]\n" |
3620 | "dup v5.4s, v10.s[1]\n" |
3621 | "dup v6.4s, v10.s[2]\n" |
3622 | "dup v7.4s, v10.s[3]\n" |
3623 | "sshl v24.4s, v24.4s, v4.4s\n" |
3624 | "sshl v25.4s, v25.4s, v4.4s\n" |
3625 | "sshl v26.4s, v26.4s, v5.4s\n" |
3626 | "sshl v27.4s, v27.4s, v5.4s\n" |
3627 | "sshl v28.4s, v28.4s, v6.4s\n" |
3628 | "sshl v29.4s, v29.4s, v6.4s\n" |
3629 | "sshl v30.4s, v30.4s, v7.4s\n" |
3630 | "sshl v31.4s, v31.4s, v7.4s\n" |
3631 | "11:\n" |
3632 | |
3633 | // Apply the fixed-point part of the multiplier. |
3634 | "sqdmulh v16.4s, v16.4s, v14.s[0]\n" |
3635 | "sqdmulh v17.4s, v17.4s, v14.s[0]\n" |
3636 | "sqdmulh v18.4s, v18.4s, v14.s[1]\n" |
3637 | "sqdmulh v19.4s, v19.4s, v14.s[1]\n" |
3638 | "sqdmulh v20.4s, v20.4s, v14.s[2]\n" |
3639 | "sqdmulh v21.4s, v21.4s, v14.s[2]\n" |
3640 | "sqdmulh v22.4s, v22.4s, v14.s[3]\n" |
3641 | "sqdmulh v23.4s, v23.4s, v14.s[3]\n" |
3642 | "sqdmulh v24.4s, v24.4s, v15.s[0]\n" |
3643 | "sqdmulh v25.4s, v25.4s, v15.s[0]\n" |
3644 | "sqdmulh v26.4s, v26.4s, v15.s[1]\n" |
3645 | "sqdmulh v27.4s, v27.4s, v15.s[1]\n" |
3646 | "sqdmulh v28.4s, v28.4s, v15.s[2]\n" |
3647 | "sqdmulh v29.4s, v29.4s, v15.s[2]\n" |
3648 | "sqdmulh v30.4s, v30.4s, v15.s[3]\n" |
3649 | "sqdmulh v31.4s, v31.4s, v15.s[3]\n" |
3650 | |
3651 | // Apply the negative exponent part of the multiplier. |
3652 | "dup v4.4s, v11.s[0]\n" |
3653 | "dup v5.4s, v11.s[1]\n" |
3654 | "dup v6.4s, v11.s[2]\n" |
3655 | "dup v7.4s, v11.s[3]\n" |
3656 | "srshl v16.4s, v16.4s, v4.4s\n" |
3657 | "srshl v17.4s, v17.4s, v4.4s\n" |
3658 | "srshl v18.4s, v18.4s, v5.4s\n" |
3659 | "srshl v19.4s, v19.4s, v5.4s\n" |
3660 | "srshl v20.4s, v20.4s, v6.4s\n" |
3661 | "srshl v21.4s, v21.4s, v6.4s\n" |
3662 | "srshl v22.4s, v22.4s, v7.4s\n" |
3663 | "srshl v23.4s, v23.4s, v7.4s\n" |
3664 | "dup v4.4s, v12.s[0]\n" |
3665 | "dup v5.4s, v12.s[1]\n" |
3666 | "dup v6.4s, v12.s[2]\n" |
3667 | "dup v7.4s, v12.s[3]\n" |
3668 | "srshl v24.4s, v24.4s, v4.4s\n" |
3669 | "srshl v25.4s, v25.4s, v4.4s\n" |
3670 | "srshl v26.4s, v26.4s, v5.4s\n" |
3671 | "srshl v27.4s, v27.4s, v5.4s\n" |
3672 | "srshl v28.4s, v28.4s, v6.4s\n" |
3673 | "srshl v29.4s, v29.4s, v6.4s\n" |
3674 | "srshl v30.4s, v30.4s, v7.4s\n" |
3675 | "srshl v31.4s, v31.4s, v7.4s\n" |
3676 | "9:\n" |
3677 | |
3678 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
3679 | "ins v13.h[4], w4\n" // dst_zero_point |
3680 | |
3681 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
3682 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
3683 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
3684 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
3685 | |
3686 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
3687 | |
3688 | // Cast-and-saturate from int32 to int16 |
3689 | "sqxtn v16.4h, v16.4s\n" |
3690 | "sqxtn2 v16.8h, v17.4s\n" |
3691 | "sqxtn v17.4h, v18.4s\n" |
3692 | "sqxtn2 v17.8h, v19.4s\n" |
3693 | "sqxtn v18.4h, v20.4s\n" |
3694 | "sqxtn2 v18.8h, v21.4s\n" |
3695 | "sqxtn v19.4h, v22.4s\n" |
3696 | "sqxtn2 v19.8h, v23.4s\n" |
3697 | "sqxtn v20.4h, v24.4s\n" |
3698 | "sqxtn2 v20.8h, v25.4s\n" |
3699 | "sqxtn v21.4h, v26.4s\n" |
3700 | "sqxtn2 v21.8h, v27.4s\n" |
3701 | "sqxtn v22.4h, v28.4s\n" |
3702 | "sqxtn2 v22.8h, v29.4s\n" |
3703 | "sqxtn v23.4h, v30.4s\n" |
3704 | "sqxtn2 v23.8h, v31.4s\n" |
3705 | |
3706 | // At this point, v24 -- v31 aren't used anymore for the current block, |
3707 | // so we can start clearing these accumulators for the next block |
3708 | // (next iteration of the main loop). |
3709 | RUY_MAKE_ZERO(v24) |
3710 | RUY_MAKE_ZERO(v25) |
3711 | RUY_MAKE_ZERO(v26) |
3712 | RUY_MAKE_ZERO(v27) |
3713 | RUY_MAKE_ZERO(v28) |
3714 | RUY_MAKE_ZERO(v29) |
3715 | RUY_MAKE_ZERO(v30) |
3716 | RUY_MAKE_ZERO(v31) |
3717 | |
3718 | // Add the destination zero point |
3719 | "dup v14.8h, v13.h[4]\n" |
3720 | "sqadd v16.8h, v16.8h, v14.8h\n" |
3721 | "sqadd v17.8h, v17.8h, v14.8h\n" |
3722 | "sqadd v18.8h, v18.8h, v14.8h\n" |
3723 | "sqadd v19.8h, v19.8h, v14.8h\n" |
3724 | "sqadd v20.8h, v20.8h, v14.8h\n" |
3725 | "sqadd v21.8h, v21.8h, v14.8h\n" |
3726 | "sqadd v22.8h, v22.8h, v14.8h\n" |
3727 | "sqadd v23.8h, v23.8h, v14.8h\n" |
3728 | |
3729 | // Cast-and-saturate from int16 to uint8 |
3730 | "sqxtun v16.8b, v16.8h\n" |
3731 | "sqxtun2 v16.16b, v17.8h\n" |
3732 | "sqxtun v17.8b, v18.8h\n" |
3733 | "sqxtun2 v17.16b, v19.8h\n" |
3734 | "sqxtun v18.8b, v20.8h\n" |
3735 | "sqxtun2 v18.16b, v21.8h\n" |
3736 | "sqxtun v19.8b, v22.8h\n" |
3737 | "sqxtun2 v19.16b, v23.8h\n" |
3738 | |
3739 | // Load the clamp_min, clamp_max bounds |
3740 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
3741 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
3742 | "dup v14.16b, w2\n" // clamp_min |
3743 | "dup v15.16b, w3\n" // clamp_max |
3744 | |
3745 | // Apply the clamp_min bound |
3746 | "umax v16.16b, v16.16b, v14.16b\n" |
3747 | "umax v17.16b, v17.16b, v14.16b\n" |
3748 | "umax v18.16b, v18.16b, v14.16b\n" |
3749 | "umax v19.16b, v19.16b, v14.16b\n" |
3750 | |
3751 | // Apply the clamp_max bound |
3752 | "umin v16.16b, v16.16b, v15.16b\n" |
3753 | "umin v17.16b, v17.16b, v15.16b\n" |
3754 | "umin v18.16b, v18.16b, v15.16b\n" |
3755 | "umin v19.16b, v19.16b, v15.16b\n" |
3756 | |
3757 | // Make it so that all of the final 8bit values are stored in the |
3758 | // first 64bits of 128bit NEON registers, so they can be stored |
3759 | // by 64bit st1 store instructions with byte alignment. |
3760 | "dup d20, v16.d[1]\n" |
3761 | "dup d21, v17.d[1]\n" |
3762 | "dup d22, v18.d[1]\n" |
3763 | "dup d23, v19.d[1]\n" |
3764 | |
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
3767 | // it fits, but when the destination matrix shape is not a multiple |
3768 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
3769 | // not fit entirely. |
3770 | "sub w1, %w[dst_rows], %w[row]\n" |
3771 | "sub w2, %w[dst_cols], %w[col]\n" |
3772 | "mov w3, #8\n" |
3773 | "cmp w1, #8\n" |
3774 | // Compute w1 = how many rows of the 8x8 block fit |
3775 | "csel w1, w1, w3, le\n" |
3776 | "cmp w2, #8\n" |
3777 | // Compute w2 = how many cols of the 8x8 block fit |
3778 | "csel w2, w2, w3, le\n" |
3779 | |
3780 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
3781 | "cmp w1, w3\n" |
3782 | "ccmp w2, w3, 0, eq\n" |
3783 | // Yes, all of the 8x8 block fits, go to fast path. |
3784 | "beq 30f\n" |
3785 | // Not all of the 8x8 block fits. |
3786 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
3787 | "mov x3, %[dst_tmp_buf]\n" |
3788 | "mov x4, #8\n" |
3789 | "b 31f\n" |
3790 | "30:\n" |
3791 | // Yes, all of the 8x8 block fits. |
3792 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
3793 | "mov x3, %[dst_ptr]\n" |
3794 | "mov x4, x11\n" |
3795 | "31:\n" |
3796 | |
3797 | // Write our 8bit values to the destination described by |
3798 | // (x3 address, x4 stride). |
3799 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3800 | "st1 {v16.8b}, [x3], x4\n" |
3801 | RUY_MAKE_ZERO(v16) |
3802 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3803 | "st1 {v20.8b}, [x3], x4\n" |
3804 | RUY_MAKE_ZERO(v20) |
3805 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3806 | "st1 {v17.8b}, [x3], x4\n" |
3807 | RUY_MAKE_ZERO(v17) |
3808 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3809 | "st1 {v21.8b}, [x3], x4\n" |
3810 | RUY_MAKE_ZERO(v21) |
3811 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3812 | "st1 {v18.8b}, [x3], x4\n" |
3813 | RUY_MAKE_ZERO(v18) |
3814 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3815 | "st1 {v22.8b}, [x3], x4\n" |
3816 | RUY_MAKE_ZERO(v22) |
3817 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3818 | "st1 {v19.8b}, [x3], x4\n" |
3819 | RUY_MAKE_ZERO(v19) |
3820 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3821 | "st1 {v23.8b}, [x3], x4\n" |
3822 | RUY_MAKE_ZERO(v23) |
3823 | |
3824 | // For the next block: perform the first few multiply-adds on the data |
3825 | // that we have already loaded. |
3826 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
3827 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
3828 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
3829 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
3830 | |
3831 | // If all of the 8x8 block fits, we just finished writing it to the |
3832 | // destination, so we skip the next part. |
3833 | "beq 41f\n" |
3834 | // Not all of the 8x8 block fits in the destination matrix. We just |
3835 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
3836 | // it to copy into the destination matrix the part that fits. |
3837 | "mov x3, %[dst_tmp_buf]\n" |
3838 | "mov x4, %[dst_ptr]\n" |
3839 | "mov w6, #0\n" |
3840 | "50:\n" |
3841 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
3842 | "mov w5, #0\n" |
3843 | "51:\n" |
3844 | "ldrb w7, [x3, w5, uxtw]\n" |
3845 | "strb w7, [x4, w5, uxtw]\n" |
3846 | "add w5, w5, #1\n" |
3847 | "cmp w5, w1\n" |
3848 | "blt 51b\n" |
3849 | "add w6, w6, #1\n" |
3850 | "add x3, x3, #8\n" |
3851 | "add x4, x4, x11\n" |
3852 | "cmp w6, w2\n" |
3853 | "blt 50b\n" |
3854 | "41:\n" |
3855 | "add %[dst_ptr], %[dst_ptr], #8\n" |
3856 | // At this point we have completely finished writing values to the |
3857 | // destination matrix for the current block. |
3858 | |
3859 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
3860 | |
3861 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
3862 | |
3863 | // Cast-and-saturate from int32 to int16 |
3864 | "sqxtn v16.4h, v16.4s\n" |
3865 | "sqxtn2 v16.8h, v17.4s\n" |
3866 | "sqxtn v17.4h, v18.4s\n" |
3867 | "sqxtn2 v17.8h, v19.4s\n" |
3868 | "sqxtn v18.4h, v20.4s\n" |
3869 | "sqxtn2 v18.8h, v21.4s\n" |
3870 | "sqxtn v19.4h, v22.4s\n" |
3871 | "sqxtn2 v19.8h, v23.4s\n" |
3872 | "sqxtn v20.4h, v24.4s\n" |
3873 | "sqxtn2 v20.8h, v25.4s\n" |
3874 | "sqxtn v21.4h, v26.4s\n" |
3875 | "sqxtn2 v21.8h, v27.4s\n" |
3876 | "sqxtn v22.4h, v28.4s\n" |
3877 | "sqxtn2 v22.8h, v29.4s\n" |
3878 | "sqxtn v23.4h, v30.4s\n" |
3879 | "sqxtn2 v23.8h, v31.4s\n" |
3880 | |
3881 | // At this point, v24 -- v31 aren't used anymore for the current block, |
3882 | // so we can start clearing these accumulators for the next block |
3883 | // (next iteration of the main loop). |
3884 | RUY_MAKE_ZERO(v24) |
3885 | RUY_MAKE_ZERO(v25) |
3886 | RUY_MAKE_ZERO(v26) |
3887 | RUY_MAKE_ZERO(v27) |
3888 | RUY_MAKE_ZERO(v28) |
3889 | RUY_MAKE_ZERO(v29) |
3890 | RUY_MAKE_ZERO(v30) |
3891 | RUY_MAKE_ZERO(v31) |
3892 | |
3893 | // Add the destination zero point |
3894 | "dup v14.8h, v13.h[4]\n" |
3895 | "sqadd v16.8h, v16.8h, v14.8h\n" |
3896 | "sqadd v17.8h, v17.8h, v14.8h\n" |
3897 | "sqadd v18.8h, v18.8h, v14.8h\n" |
3898 | "sqadd v19.8h, v19.8h, v14.8h\n" |
3899 | "sqadd v20.8h, v20.8h, v14.8h\n" |
3900 | "sqadd v21.8h, v21.8h, v14.8h\n" |
3901 | "sqadd v22.8h, v22.8h, v14.8h\n" |
3902 | "sqadd v23.8h, v23.8h, v14.8h\n" |
3903 | |
// Cast-and-saturate from int16 to int8
3905 | "sqxtn v16.8b, v16.8h\n" |
3906 | "sqxtn2 v16.16b, v17.8h\n" |
3907 | "sqxtn v17.8b, v18.8h\n" |
3908 | "sqxtn2 v17.16b, v19.8h\n" |
3909 | "sqxtn v18.8b, v20.8h\n" |
3910 | "sqxtn2 v18.16b, v21.8h\n" |
3911 | "sqxtn v19.8b, v22.8h\n" |
3912 | "sqxtn2 v19.16b, v23.8h\n" |
3913 | |
3914 | // Load the clamp_min, clamp_max bounds |
3915 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
3916 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
3917 | "dup v14.16b, w2\n" // clamp_min |
3918 | "dup v15.16b, w3\n" // clamp_max |
3919 | |
3920 | // Apply the clamp_min bound |
3921 | "smax v16.16b, v16.16b, v14.16b\n" |
3922 | "smax v17.16b, v17.16b, v14.16b\n" |
3923 | "smax v18.16b, v18.16b, v14.16b\n" |
3924 | "smax v19.16b, v19.16b, v14.16b\n" |
3925 | |
3926 | // Apply the clamp_max bound |
3927 | "smin v16.16b, v16.16b, v15.16b\n" |
3928 | "smin v17.16b, v17.16b, v15.16b\n" |
3929 | "smin v18.16b, v18.16b, v15.16b\n" |
3930 | "smin v19.16b, v19.16b, v15.16b\n" |
3931 | |
3932 | // Make it so that all of the final 8bit values are stored in the |
3933 | // first 64bits of 128bit NEON registers, so they can be stored |
3934 | // by 64bit st1 store instructions with byte alignment. |
3935 | "dup d20, v16.d[1]\n" |
3936 | "dup d21, v17.d[1]\n" |
3937 | "dup d22, v18.d[1]\n" |
3938 | "dup d23, v19.d[1]\n" |
3939 | |
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
3942 | // it fits, but when the destination matrix shape is not a multiple |
3943 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
3944 | // not fit entirely. |
3945 | "sub w1, %w[dst_rows], %w[row]\n" |
3946 | "sub w2, %w[dst_cols], %w[col]\n" |
3947 | "mov w3, #8\n" |
3948 | "cmp w1, #8\n" |
3949 | // Compute w1 = how many rows of the 8x8 block fit |
3950 | "csel w1, w1, w3, le\n" |
3951 | "cmp w2, #8\n" |
3952 | // Compute w2 = how many cols of the 8x8 block fit |
3953 | "csel w2, w2, w3, le\n" |
3954 | |
3955 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
3956 | "cmp w1, w3\n" |
3957 | "ccmp w2, w3, 0, eq\n" |
3958 | // Yes, all of the 8x8 block fits, go to fast path. |
3959 | "beq 130f\n" |
3960 | // Not all of the 8x8 block fits. |
3961 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
3962 | "mov x3, %[dst_tmp_buf]\n" |
3963 | "mov x4, #8\n" |
3964 | "b 131f\n" |
3965 | "130:\n" |
3966 | // Yes, all of the 8x8 block fits. |
3967 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
3968 | "mov x3, %[dst_ptr]\n" |
3969 | "mov x4, x11\n" |
3970 | "131:\n" |
3971 | |
3972 | // Write our 8bit values to the destination described by |
3973 | // (x3 address, x4 stride). |
3974 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3975 | "st1 {v16.8b}, [x3], x4\n" |
3976 | RUY_MAKE_ZERO(v16) |
3977 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3978 | "st1 {v20.8b}, [x3], x4\n" |
3979 | RUY_MAKE_ZERO(v20) |
3980 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3981 | "st1 {v17.8b}, [x3], x4\n" |
3982 | RUY_MAKE_ZERO(v17) |
3983 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3984 | "st1 {v21.8b}, [x3], x4\n" |
3985 | RUY_MAKE_ZERO(v21) |
3986 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3987 | "st1 {v18.8b}, [x3], x4\n" |
3988 | RUY_MAKE_ZERO(v18) |
3989 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3990 | "st1 {v22.8b}, [x3], x4\n" |
3991 | RUY_MAKE_ZERO(v22) |
3992 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3993 | "st1 {v19.8b}, [x3], x4\n" |
3994 | RUY_MAKE_ZERO(v19) |
3995 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
3996 | "st1 {v23.8b}, [x3], x4\n" |
3997 | RUY_MAKE_ZERO(v23) |
3998 | |
3999 | // For the next block: perform the first few multiply-adds on the data |
4000 | // that we have already loaded. |
4001 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
4002 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
4003 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
4004 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
4005 | |
4006 | // If all of the 8x8 block fits, we just finished writing it to the |
4007 | // destination, so we skip the next part. |
4008 | "beq 141f\n" |
4009 | // Not all of the 8x8 block fits in the destination matrix. We just |
4010 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
4011 | // it to copy into the destination matrix the part that fits. |
4012 | "mov x3, %[dst_tmp_buf]\n" |
4013 | "mov x4, %[dst_ptr]\n" |
4014 | "mov w6, #0\n" |
4015 | "150:\n" |
4016 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4017 | "mov w5, #0\n" |
4018 | "151:\n" |
4019 | "ldrb w7, [x3, w5, uxtw]\n" |
4020 | "strb w7, [x4, w5, uxtw]\n" |
4021 | "add w5, w5, #1\n" |
4022 | "cmp w5, w1\n" |
4023 | "blt 151b\n" |
4024 | "add w6, w6, #1\n" |
4025 | "add x3, x3, #8\n" |
4026 | "add x4, x4, x11\n" |
4027 | "cmp w6, w2\n" |
4028 | "blt 150b\n" |
4029 | "141:\n" |
4030 | "add %[dst_ptr], %[dst_ptr], #8\n" |
4031 | // At this point we have completely finished writing values to the |
4032 | // destination matrix for the current block. |
4033 | |
4034 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
4035 | |
4036 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
4037 | |
4038 | // Add the destination zero point |
4039 | "dup v14.8h, v13.h[4]\n" |
4040 | "saddw v16.4s, v16.4s, v14.4h\n" |
4041 | "saddw v17.4s, v17.4s, v14.4h\n" |
4042 | "saddw v18.4s, v18.4s, v14.4h\n" |
4043 | "saddw v19.4s, v19.4s, v14.4h\n" |
4044 | "saddw v20.4s, v20.4s, v14.4h\n" |
4045 | "saddw v21.4s, v21.4s, v14.4h\n" |
4046 | "saddw v22.4s, v22.4s, v14.4h\n" |
4047 | "saddw v23.4s, v23.4s, v14.4h\n" |
4048 | "saddw v24.4s, v24.4s, v14.4h\n" |
4049 | "saddw v25.4s, v25.4s, v14.4h\n" |
4050 | "saddw v26.4s, v26.4s, v14.4h\n" |
4051 | "saddw v27.4s, v27.4s, v14.4h\n" |
4052 | "saddw v28.4s, v28.4s, v14.4h\n" |
4053 | "saddw v29.4s, v29.4s, v14.4h\n" |
4054 | "saddw v30.4s, v30.4s, v14.4h\n" |
4055 | "saddw v31.4s, v31.4s, v14.4h\n" |
4056 | |
4057 | // Cast-and-saturate from int32 to int16 |
4058 | "sqxtn v16.4h, v16.4s\n" |
4059 | "sqxtn2 v16.8h, v17.4s\n" |
4060 | "sqxtn v17.4h, v18.4s\n" |
4061 | "sqxtn2 v17.8h, v19.4s\n" |
4062 | "sqxtn v18.4h, v20.4s\n" |
4063 | "sqxtn2 v18.8h, v21.4s\n" |
4064 | "sqxtn v19.4h, v22.4s\n" |
4065 | "sqxtn2 v19.8h, v23.4s\n" |
4066 | "sqxtn v20.4h, v24.4s\n" |
4067 | "sqxtn2 v20.8h, v25.4s\n" |
4068 | "sqxtn v21.4h, v26.4s\n" |
4069 | "sqxtn2 v21.8h, v27.4s\n" |
4070 | "sqxtn v22.4h, v28.4s\n" |
4071 | "sqxtn2 v22.8h, v29.4s\n" |
4072 | "sqxtn v23.4h, v30.4s\n" |
4073 | "sqxtn2 v23.8h, v31.4s\n" |
4074 | |
4075 | // At this point, v24 -- v31 aren't used anymore for the current block, |
4076 | // so we can start clearing these accumulators for the next block |
4077 | // (next iteration of the main loop). |
4078 | RUY_MAKE_ZERO(v24) |
4079 | RUY_MAKE_ZERO(v25) |
4080 | RUY_MAKE_ZERO(v26) |
4081 | RUY_MAKE_ZERO(v27) |
4082 | RUY_MAKE_ZERO(v28) |
4083 | RUY_MAKE_ZERO(v29) |
4084 | RUY_MAKE_ZERO(v30) |
4085 | RUY_MAKE_ZERO(v31) |
4086 | |
4087 | // Load the clamp_min, clamp_max bounds |
4088 | "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
4089 | "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
4090 | "dup v14.8h, w2\n" // clamp_min |
4091 | "dup v15.8h, w3\n" // clamp_max |
4092 | |
4093 | // Apply the clamp_min bound |
4094 | "smax v16.8h, v16.8h, v14.8h\n" |
4095 | "smax v17.8h, v17.8h, v14.8h\n" |
4096 | "smax v18.8h, v18.8h, v14.8h\n" |
4097 | "smax v19.8h, v19.8h, v14.8h\n" |
4098 | "smax v20.8h, v20.8h, v14.8h\n" |
4099 | "smax v21.8h, v21.8h, v14.8h\n" |
4100 | "smax v22.8h, v22.8h, v14.8h\n" |
4101 | "smax v23.8h, v23.8h, v14.8h\n" |
4102 | // Apply the clamp_max bound |
4103 | "smin v16.8h, v16.8h, v15.8h\n" |
4104 | "smin v17.8h, v17.8h, v15.8h\n" |
4105 | "smin v18.8h, v18.8h, v15.8h\n" |
4106 | "smin v19.8h, v19.8h, v15.8h\n" |
4107 | "smin v20.8h, v20.8h, v15.8h\n" |
4108 | "smin v21.8h, v21.8h, v15.8h\n" |
4109 | "smin v22.8h, v22.8h, v15.8h\n" |
4110 | "smin v23.8h, v23.8h, v15.8h\n" |
4111 | |
// Compute how much of the 8x8 block of destination 16bit values that
// we have computed fits in the destination matrix. Typically, all of
4114 | // it fits, but when the destination matrix shape is not a multiple |
4115 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
4116 | // not fit entirely. |
4117 | "sub w1, %w[dst_rows], %w[row]\n" |
4118 | "sub w2, %w[dst_cols], %w[col]\n" |
4119 | "mov w3, #8\n" |
4120 | "cmp w1, #8\n" |
4121 | // Compute w1 = how many rows of the 8x8 block fit |
4122 | "csel w1, w1, w3, le\n" |
4123 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
4125 | "csel w2, w2, w3, le\n" |
4126 | |
4127 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
4128 | "cmp w1, w3\n" |
4129 | "ccmp w2, w3, 0, eq\n" |
4130 | // Yes, all of the 8x8 block fits, go to fast path. |
4131 | "beq 230f\n" |
4132 | // Not all of the 8x8 block fits. |
4133 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
4134 | "mov x3, %[dst_tmp_buf]\n" |
4135 | "mov x4, #16\n" |
4136 | "b 231f\n" |
4137 | "230:\n" |
4138 | // Yes, all of the 8x8 block fits. |
4139 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
4140 | "mov x3, %[dst_ptr]\n" |
4141 | "mov x4, x11\n" |
4142 | "231:\n" |
4143 | |
4144 | // Write our 16bit values to the destination described by |
4145 | // (x3 address, x4 stride). |
4146 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4147 | "st1 {v16.8h}, [x3], x4\n" |
4148 | RUY_MAKE_ZERO(v16) |
4149 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4150 | "st1 {v17.8h}, [x3], x4\n" |
4151 | RUY_MAKE_ZERO(v17) |
4152 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4153 | "st1 {v18.8h}, [x3], x4\n" |
4154 | RUY_MAKE_ZERO(v18) |
4155 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4156 | "st1 {v19.8h}, [x3], x4\n" |
4157 | RUY_MAKE_ZERO(v19) |
4158 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4159 | "st1 {v20.8h}, [x3], x4\n" |
4160 | RUY_MAKE_ZERO(v20) |
4161 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4162 | "st1 {v21.8h}, [x3], x4\n" |
4163 | RUY_MAKE_ZERO(v21) |
4164 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4165 | "st1 {v22.8h}, [x3], x4\n" |
4166 | RUY_MAKE_ZERO(v22) |
4167 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
4168 | "st1 {v23.8h}, [x3], x4\n" |
4169 | RUY_MAKE_ZERO(v23) |
4170 | |
4171 | // For the next block: perform the first few multiply-adds on the data |
4172 | // that we have already loaded. |
4173 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
4174 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
4175 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
4176 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
4177 | |
4178 | // If all of the 8x8 block fits, we just finished writing it to the |
4179 | // destination, so we skip the next part. |
4180 | "beq 241f\n" |
4181 | // Not all of the 8x8 block fits in the destination matrix. We just |
4182 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
4183 | // it to copy into the destination matrix the part that fits. |
4184 | "mov x3, %[dst_tmp_buf]\n" |
4185 | "mov x4, %[dst_ptr]\n" |
4186 | "mov w6, #0\n" |
4187 | "250:\n" |
4188 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4189 | "mov w5, #0\n" |
4190 | "251:\n" |
4191 | "ldrsh w7, [x3, x5, lsl #1]\n" |
4192 | "strh w7, [x4, x5, lsl #1]\n" |
4193 | "add w5, w5, #1\n" |
4194 | "cmp w5, w1\n" |
4195 | "blt 251b\n" |
4196 | "add w6, w6, #1\n" |
4197 | "add x3, x3, #16\n" |
4198 | "add x4, x4, x11\n" |
4199 | "cmp w6, w2\n" |
4200 | "blt 250b\n" |
4201 | "241:\n" |
4202 | "add %[dst_ptr], %[dst_ptr], #16\n" |
4203 | // At this point we have completely finished writing values to the |
4204 | // destination matrix for the current block. |
4205 | |
4206 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
4207 | |
4208 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
4209 | |
4210 | // Since the store type is the same as the accum type, no need for |
4211 | // downcast. There's also no need for clamp by min/max. |
4212 | |
// Compute how much of the 8x8 block of destination 32bit values that
4214 | // we have computed, fit in the destination matrix. Typically, all of |
4215 | // it fits, but when the destination matrix shape is not a multiple |
4216 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
4217 | // not fit entirely. |
4218 | "sub w1, %w[dst_rows], %w[row]\n" |
4219 | "sub w2, %w[dst_cols], %w[col]\n" |
4220 | "mov w3, #8\n" |
4221 | "cmp w1, #8\n" |
4222 | // Compute w1 = how many rows of the 8x8 block fit |
4223 | "csel w1, w1, w3, le\n" |
4224 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
4226 | "csel w2, w2, w3, le\n" |
4227 | |
4228 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
4229 | "cmp w1, w3\n" |
4230 | "ccmp w2, w3, 0, eq\n" |
4231 | // Yes, all of the 8x8 block fits, go to fast path. |
4232 | "beq 330f\n" |
4233 | // Not all of the 8x8 block fits. |
4234 | // Write to dst_tmp_buf |
4235 | "mov x3, %[dst_tmp_buf]\n" |
4236 | "st1 {v16.4s}, [x3], #16\n" |
4237 | RUY_MAKE_ZERO(v16) |
4238 | "st1 {v17.4s}, [x3], #16\n" |
4239 | RUY_MAKE_ZERO(v17) |
4240 | "st1 {v18.4s}, [x3], #16\n" |
4241 | RUY_MAKE_ZERO(v18) |
4242 | "st1 {v19.4s}, [x3], #16\n" |
4243 | RUY_MAKE_ZERO(v19) |
4244 | "st1 {v20.4s}, [x3], #16\n" |
4245 | RUY_MAKE_ZERO(v20) |
4246 | "st1 {v21.4s}, [x3], #16\n" |
4247 | RUY_MAKE_ZERO(v21) |
4248 | "st1 {v22.4s}, [x3], #16\n" |
4249 | RUY_MAKE_ZERO(v22) |
4250 | "st1 {v23.4s}, [x3], #16\n" |
4251 | RUY_MAKE_ZERO(v23) |
4252 | "st1 {v24.4s}, [x3], #16\n" |
4253 | RUY_MAKE_ZERO(v24) |
4254 | "st1 {v25.4s}, [x3], #16\n" |
4255 | RUY_MAKE_ZERO(v25) |
4256 | "st1 {v26.4s}, [x3], #16\n" |
4257 | RUY_MAKE_ZERO(v26) |
4258 | "st1 {v27.4s}, [x3], #16\n" |
4259 | RUY_MAKE_ZERO(v27) |
4260 | "st1 {v28.4s}, [x3], #16\n" |
4261 | RUY_MAKE_ZERO(v28) |
4262 | "st1 {v29.4s}, [x3], #16\n" |
4263 | RUY_MAKE_ZERO(v29) |
4264 | "st1 {v30.4s}, [x3], #16\n" |
4265 | RUY_MAKE_ZERO(v30) |
4266 | "st1 {v31.4s}, [x3], #16\n" |
4267 | RUY_MAKE_ZERO(v31) |
4268 | |
4269 | "b 331f\n" |
4270 | |
4271 | "330:\n" |
4272 | // Yes, all of the 8x8 block fits. |
4273 | "mov x4, %[dst_ptr]\n" |
4274 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4275 | "mov x3, x4\n" |
4276 | "st1 {v16.4s, v17.4s}, [x3], #32\n" |
4277 | RUY_MAKE_ZERO(v16) |
4278 | RUY_MAKE_ZERO(v17) |
4279 | "add x4, x4, x11\n" |
4280 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4281 | "mov x3, x4\n" |
4282 | "st1 {v18.4s, v19.4s}, [x3], #32\n" |
4283 | RUY_MAKE_ZERO(v18) |
4284 | RUY_MAKE_ZERO(v19) |
4285 | "add x4, x4, x11\n" |
4286 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4287 | "mov x3, x4\n" |
4288 | "st1 {v20.4s, v21.4s}, [x3], #32\n" |
4289 | RUY_MAKE_ZERO(v20) |
4290 | RUY_MAKE_ZERO(v21) |
4291 | "add x4, x4, x11\n" |
4292 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4293 | "mov x3, x4\n" |
4294 | "st1 {v22.4s, v23.4s}, [x3], #32\n" |
4295 | RUY_MAKE_ZERO(v22) |
4296 | RUY_MAKE_ZERO(v23) |
4297 | "add x4, x4, x11\n" |
4298 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4299 | "mov x3, x4\n" |
4300 | "st1 {v24.4s, v25.4s}, [x3], #32\n" |
4301 | RUY_MAKE_ZERO(v24) |
4302 | RUY_MAKE_ZERO(v25) |
4303 | "add x4, x4, x11\n" |
4304 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4305 | "mov x3, x4\n" |
4306 | "st1 {v26.4s, v27.4s}, [x3], #32\n" |
4307 | RUY_MAKE_ZERO(v26) |
4308 | RUY_MAKE_ZERO(v27) |
4309 | "add x4, x4, x11\n" |
4310 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4311 | "mov x3, x4\n" |
4312 | "st1 {v28.4s, v29.4s}, [x3], #32\n" |
4313 | RUY_MAKE_ZERO(v28) |
4314 | RUY_MAKE_ZERO(v29) |
4315 | "add x4, x4, x11\n" |
4316 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4317 | "mov x3, x4\n" |
4318 | "st1 {v30.4s, v31.4s}, [x3], #32\n" |
4319 | RUY_MAKE_ZERO(v30) |
4320 | RUY_MAKE_ZERO(v31) |
4321 | |
4322 | "331:\n" |
4323 | |
4324 | // For the next block: perform the first few multiply-adds on the data |
4325 | // that we have already loaded. |
4326 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
4327 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
4328 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
4329 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
4330 | |
4331 | // If all of the 8x8 block fits, we just finished writing it to the |
4332 | // destination, so we skip the next part. |
4333 | "beq 341f\n" |
4334 | |
4335 | // Not all of the 8x8 block fits in the destination matrix. We just |
4336 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
4337 | // it to copy into the destination matrix the part that fits. |
4338 | "mov x3, %[dst_tmp_buf]\n" |
4339 | "mov x4, %[dst_ptr]\n" |
4340 | "mov w6, #0\n" |
4341 | "350:\n" |
4342 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
4343 | "mov w5, #0\n" |
4344 | "351:\n" |
4345 | "ldr w7, [x3, x5, lsl #2]\n" |
4346 | "str w7, [x4, x5, lsl #2]\n" |
4347 | "add w5, w5, #1\n" |
4348 | "cmp w5, w1\n" |
4349 | "blt 351b\n" |
4350 | "add w6, w6, #1\n" |
4351 | "add x3, x3, #32\n" |
4352 | "add x4, x4, x11\n" |
4353 | "cmp w6, w2\n" |
4354 | "blt 350b\n" |
4355 | "341:\n" |
4356 | "add %[dst_ptr], %[dst_ptr], #32\n" |
4357 | // At this point we have completely finished writing values to the |
4358 | // destination matrix for the current block. |
4359 | |
4360 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
4361 | |
4362 | // Reload some params --- we had used x5 -- x7 for a few other things |
4363 | // since the last time we had loaded them. |
4364 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
4365 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
4366 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
4367 | |
4368 | // Move to the next block of the destination matrix, for the next iter |
4369 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
4370 | // been updated earlier. |
4371 | // Have we reached the end row? |
4372 | "cmp %w[row], w7\n" |
4373 | "beq 20f\n" // yes, end row. |
4374 | // Not end row. Move to the next row. |
4375 | "add %w[row], %w[row], #8\n" |
4376 | "b 21f\n" |
4377 | "20:\n" |
4378 | // Was already at end row. |
4379 | "mov %w[row], w6\n" // Move back to first row. |
4380 | "add %w[col], %w[col], #8\n" // Move to the next column. |
4381 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
4382 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
4383 | "21:\n" |
4384 | |
4385 | // Main loop exit condition: have we hit the end column? |
4386 | "cmp %w[col], w8\n" |
4387 | |
4388 | // w1 is the number of levels of depth that we have already loaded |
4389 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
4390 | // above, this is currently 4. |
4391 | "mov w1, #4\n" |
4392 | |
4393 | "ble 1b\n" |
4394 | |
4395 | // clang-format on |
4396 | |
: [lhs_col_ptr] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr),
4398 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
4399 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
: [params] "r" (&params), [dst_rows] "r" (params.dst_rows),
4401 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf), |
4402 | [dst_type_id] "r" (params.dst_type_id) |
4403 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
4404 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
4405 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
4406 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
4407 | } |
4408 | |
// A fork of the above 8bitNeonDotprod kernel that removes the MAX_STREAMING
4410 | // manual unrolling. Manually unrolling the inner loops benefits some GEMM |
4411 | // shapes on the Cortex-A76 but destroys performance on the X1 by increasing |
4412 | // backend stalls. Therefore, we remove the MAX_STREAMING option in this |
4413 | // kernel. The target CPU for this kernel is currently only the Cortex-X1. |
4414 | void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) { |
4415 | profiler::ScopeLabel label("Kernel (kNeonDotprod)" ); |
4416 | |
4417 | CheckOffsetsInKernelParams8bit(params); |
4418 | |
4419 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
4420 | const std::int8_t* rhs_col_ptr = |
4421 | static_cast<const int8_t*>(params.rhs_base_ptr); |
4422 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
4423 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
4424 | void* dst_col_ptr = params.dst_base_ptr; |
4425 | void* dst_ptr = dst_col_ptr; |
4426 | int row = params.start_row; |
4427 | int col = params.start_col; |
4428 | |
4429 | // The asm kernel below has the following NEON register allocation: |
4430 | // |
4431 | // v16 -- v31 are int32 accumulators. |
4432 | // During accumulation, v0 -- v15 are used to load int8 data from LHS and |
4433 | // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and |
4434 | // v3 are used to load a 4x8 block of RHS, like this: |
4435 | // |
4436 | // int8 RHS 4x8 block |
4437 | // /-----------------------------------------| |
4438 | // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| |
4439 | // | ... ... | |
4440 | // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| |
4441 | // \-----------------------------------------/ |
4442 | // int8 LHS 8x4 block |
4443 | // /---------------------\ /-----------------------------------------| |
4444 | // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| |
4445 | // | ... ... | | ... ... | |
4446 | // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| |
4447 | // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| |
4448 | // | ... ... | | ... ... | |
4449 | // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| |
4450 | // \---------------------/ \-----------------------------------------/ |
4451 | // int32 accumulators 8x8 block |
4452 | // |
// Unlike the kernel above, this variant has no RUY_OPT_MAX_STREAMING
// path, so only v0 -- v3 are used to load LHS and RHS data throughout
// the accumulation loop. v4 -- v15 are used as temporaries and for
// loading parameters in the post-accumulation part of the kernel.
4460 | asm volatile( |
4461 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
4462 | |
4463 | // clang-format off |
4464 | |
4465 | // Load some parameters into registers. |
4466 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
4467 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
4468 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
4469 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
4470 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
4471 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
4472 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
4473 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
4474 | |
4475 | // Load the first 32 bytes of LHS and RHS data. |
4476 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
4477 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
4478 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
4479 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
4480 | |
4481 | // Clear accumulators. |
4482 | RUY_MAKE_ZERO(v16) |
4483 | RUY_MAKE_ZERO(v17) |
4484 | RUY_MAKE_ZERO(v18) |
4485 | RUY_MAKE_ZERO(v19) |
4486 | RUY_MAKE_ZERO(v20) |
4487 | RUY_MAKE_ZERO(v21) |
4488 | RUY_MAKE_ZERO(v22) |
4489 | RUY_MAKE_ZERO(v23) |
4490 | RUY_MAKE_ZERO(v24) |
4491 | RUY_MAKE_ZERO(v25) |
4492 | RUY_MAKE_ZERO(v26) |
4493 | RUY_MAKE_ZERO(v27) |
4494 | RUY_MAKE_ZERO(v28) |
4495 | RUY_MAKE_ZERO(v29) |
4496 | RUY_MAKE_ZERO(v30) |
4497 | RUY_MAKE_ZERO(v31) |
4498 | |
4499 | // w1 is the number of levels of depth that we have already loaded |
4500 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
4501 | // above, this is currently 4. |
4502 | "mov w1, #4\n" |
4503 | |
4504 | // Perform the first few multiply-adds on the data that we have already |
4505 | // loaded. |
4506 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
4507 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
4508 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
4509 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
4510 | |
4511 | // Main loop of the whole GEMM, over rows and columns of the |
4512 | // destination matrix. |
4513 | "1:\n" |
4514 | |
4515 | // Kernel inner loop (over depth). |
4516 | // Reminder - w1 is how many levels of depth we have already loaded |
4517 | // data for, w12 is the total depth. |
4518 | "cmp w1, w12\n" |
4519 | "beq 79f\n" |
4520 | |
4521 | "2:\n" |
4522 | |
4523 | // Because of the data that we have already loaded, we can start the |
4524 | // loop body right away with some multiply-adds. |
4525 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
4526 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
4527 | // Each iteration of this loop advances by 4 levels of depth. |
4528 | "add w1, w1, #4\n" |
4529 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
4530 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
4531 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
4532 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
4533 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
4534 | // Loop termination condition. |
4535 | "cmp w1, w12\n" |
4536 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
4537 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
4538 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
4539 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
4540 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
4541 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
4542 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
4543 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
4544 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
4545 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
4546 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
4547 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
4548 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
4549 | |
4550 | "blt 2b\n" |
4551 | |
4552 | "79:\n" |
4553 | // End of the inner loop on depth. Now perform the remaining |
4554 | // multiply-adds of the last 4 levels of depth, for which the LHS |
4555 | // and RHS data is already loaded. |
4556 | |
4557 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
4558 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
4559 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
4560 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
4561 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
4562 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
4563 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
4564 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
4565 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
4566 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
4567 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
4568 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
4569 | |
4570 | // End of accumulation. The registers v16 -- v31 contain the final |
4571 | // int32 accumulator values of the current 8x8 destination block. |
4572 | // We now have to compute the final 8-bit values from these int32 |
4573 | // accumulators, and advance to the next 8x8 block. We intertwine |
4574 | // these two aspects whenever possible for optimal pipelining, both |
4575 | // at the data flow level (prefetch data for next block as early as |
4576 | // possible) and instruction pipelining level (some of the next-block |
4577 | // work can dual-issue with some of the final work on the current |
4578 | // block). |
4579 | |
4580 | // Logic to advance to the next block in preparation for the next |
4581 | // iteration of the main loop. For now, we only want to compute |
4582 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
4583 | // not yet ready to update the values of row and col, as we still need |
4584 | // the current values for the rest of the work on the current block. |
4585 | |
4586 | "cmp %w[row], w7\n" // Have we finished the last row? |
4587 | "bge 4f\n" // If finished last row, go to 4 |
4588 | // Not finished last row: then advance to next row. |
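// (x9 is lhs_stride, so "x9, lsl #3" steps past the 8 rows of the
// current block.)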
4589 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
4590 | "b 5f\n" |
4591 | "4:\n" // Finished last row... |
4592 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
4593 | // Now we need to advance to the next column. If we already |
4594 | // finished the last column, then in principle we are done, however |
4595 | // we can't just return here, as we need to allow the end work of the |
4596 | // current block to complete. The good news is that at this point it |
4597 | // doesn't matter what data we load for the next column, since |
4598 | // we will exit from the main loop below before actually storing |
4599 | // anything computed from that data. |
4600 | "cmp %w[col], w8\n" // Have we finished the last column? |
4601 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
4602 | // Not finished last column: then advance to next column. |
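// (x10 is rhs_stride, so "x10, lsl #3" steps past the 8 columns of the
// current block.)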
4603 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
4604 | "5:\n" |
4605 | |
4606 | // Set the LHS and RHS data pointers to the start of the columns just |
4607 | // computed. |
4608 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
4609 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
4610 | |
4611 | // Load some parameters needed for the end work on current block. |
4612 | "mvni v8.4s, #0\n" |
4613 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
4614 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
4615 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
4616 | |
4617 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
4618 | // Determine the channel index. |
4619 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
4620 | "csel w3, %w[row], %w[col], eq\n" |
4621 | |
4622 | // Offset the bias pointer as needed given the current row, col. |
4623 | "add x5, x1, x3, lsl #2\n" |
4624 | |
4625 | // If there is no bias, use no offset, just address the passed zero |
4626 | // data. |
4627 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
4628 | "csel x1, x1, x5, eq\n" |
4629 | |
4630 | // Load 8 bias values. |
4631 | "ld1 {v14.4s}, [x1], #16\n" |
4632 | "ld1 {v15.4s}, [x1]\n" |
4633 | |
4634 | // Now that we know what LHS and RHS data the next iteration of the |
4635 | // main loop will need to load, we start loading the first 32 bytes of |
4636 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
4637 | // in the rest of the work on the current block. |
4638 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
4639 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
4640 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
4641 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
4642 | |
4643 | // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), |
4644 | // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
4645 | "add v14.4s, v14.4s, v9.4s\n" |
4646 | "add v15.4s, v15.4s, v9.4s\n" |
4647 | |
4648 | // Perform the bias-addition (per the above, we have just folded into |
4649 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
4650 | // Jump based on channel dimension. |
4651 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
4652 | "bne 6f\n" |
4653 | // Case where channels are rows |
4654 | "add v16.4s, v16.4s, v14.4s\n" |
4655 | "add v17.4s, v17.4s, v15.4s\n" |
4656 | "add v18.4s, v18.4s, v14.4s\n" |
4657 | "add v19.4s, v19.4s, v15.4s\n" |
4658 | "add v20.4s, v20.4s, v14.4s\n" |
4659 | "add v21.4s, v21.4s, v15.4s\n" |
4660 | "add v22.4s, v22.4s, v14.4s\n" |
4661 | "add v23.4s, v23.4s, v15.4s\n" |
4662 | "add v24.4s, v24.4s, v14.4s\n" |
4663 | "add v25.4s, v25.4s, v15.4s\n" |
4664 | "add v26.4s, v26.4s, v14.4s\n" |
4665 | "add v27.4s, v27.4s, v15.4s\n" |
4666 | "add v28.4s, v28.4s, v14.4s\n" |
4667 | "add v29.4s, v29.4s, v15.4s\n" |
4668 | "add v30.4s, v30.4s, v14.4s\n" |
4669 | "add v31.4s, v31.4s, v15.4s\n" |
4670 | "b 7f\n" |
4671 | |
4672 | "6:\n" |
4673 | // Case where channels are columns |
4674 | "dup v10.4s, v14.s[0]\n" |
4675 | "dup v11.4s, v14.s[1]\n" |
4676 | "dup v12.4s, v14.s[2]\n" |
4677 | "dup v13.4s, v14.s[3]\n" |
4678 | "add v16.4s, v16.4s, v10.4s\n" |
4679 | "add v17.4s, v17.4s, v10.4s\n" |
4680 | "add v18.4s, v18.4s, v11.4s\n" |
4681 | "add v19.4s, v19.4s, v11.4s\n" |
4682 | "add v20.4s, v20.4s, v12.4s\n" |
4683 | "add v21.4s, v21.4s, v12.4s\n" |
4684 | "add v22.4s, v22.4s, v13.4s\n" |
4685 | "add v23.4s, v23.4s, v13.4s\n" |
4686 | "dup v10.4s, v15.s[0]\n" |
4687 | "dup v11.4s, v15.s[1]\n" |
4688 | "dup v12.4s, v15.s[2]\n" |
4689 | "dup v13.4s, v15.s[3]\n" |
4690 | "add v24.4s, v24.4s, v10.4s\n" |
4691 | "add v25.4s, v25.4s, v10.4s\n" |
4692 | "add v26.4s, v26.4s, v11.4s\n" |
4693 | "add v27.4s, v27.4s, v11.4s\n" |
4694 | "add v28.4s, v28.4s, v12.4s\n" |
4695 | "add v29.4s, v29.4s, v12.4s\n" |
4696 | "add v30.4s, v30.4s, v13.4s\n" |
4697 | "add v31.4s, v31.4s, v13.4s\n" |
4698 | "7:\n" |
4699 | |
4700 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
4701 | "beq 401f\n" |
4702 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
4703 | "add x3, x3, %x[col], lsl #2\n" |
4704 | "ld1 {v14.4s}, [x3], #16\n" |
4705 | "ld1 {v15.4s}, [x3]\n" |
4706 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
4707 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
4708 | // Subtract rhs_sums * lhs_zero_point, per |
4709 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
4710 | "mls v16.4s, v10.4s, v14.s[0]\n" |
4711 | "mls v17.4s, v10.4s, v14.s[0]\n" |
4712 | "mls v18.4s, v10.4s, v14.s[1]\n" |
4713 | "mls v19.4s, v10.4s, v14.s[1]\n" |
4714 | "mls v20.4s, v10.4s, v14.s[2]\n" |
4715 | "mls v21.4s, v10.4s, v14.s[2]\n" |
4716 | "mls v22.4s, v10.4s, v14.s[3]\n" |
4717 | "mls v23.4s, v10.4s, v14.s[3]\n" |
4718 | "mls v24.4s, v10.4s, v15.s[0]\n" |
4719 | "mls v25.4s, v10.4s, v15.s[0]\n" |
4720 | "mls v26.4s, v10.4s, v15.s[1]\n" |
4721 | "mls v27.4s, v10.4s, v15.s[1]\n" |
4722 | "mls v28.4s, v10.4s, v15.s[2]\n" |
4723 | "mls v29.4s, v10.4s, v15.s[2]\n" |
4724 | "mls v30.4s, v10.4s, v15.s[3]\n" |
4725 | "mls v31.4s, v10.4s, v15.s[3]\n" |
4726 | "401:\n" |
4727 | |
4728 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
4729 | "beq 402f\n" |
4730 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
4731 | "add x2, x2, %x[row], lsl #2\n" |
4732 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
4733 | // Load 4 lhs_sums values. |
4734 | "ld1 {v11.4s}, [x2], #16\n" |
4735 | "ld1 {v12.4s}, [x2]\n" |
4736 | "ins v13.s[1], w5\n" // rhs_zero_point |
4737 | // Compute lhs_sums * rhs_zero_point. |
4738 | "mul v11.4s, v11.4s, v13.s[1]\n" |
4739 | "mul v12.4s, v12.4s, v13.s[1]\n" |
4740 | // Subtract lhs_sums * rhs_zero_point, per |
4741 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
4742 | "sub v16.4s, v16.4s, v11.4s\n" |
4743 | "sub v17.4s, v17.4s, v12.4s\n" |
4744 | "sub v18.4s, v18.4s, v11.4s\n" |
4745 | "sub v19.4s, v19.4s, v12.4s\n" |
4746 | "sub v20.4s, v20.4s, v11.4s\n" |
4747 | "sub v21.4s, v21.4s, v12.4s\n" |
4748 | "sub v22.4s, v22.4s, v11.4s\n" |
4749 | "sub v23.4s, v23.4s, v12.4s\n" |
4750 | "sub v24.4s, v24.4s, v11.4s\n" |
4751 | "sub v25.4s, v25.4s, v12.4s\n" |
4752 | "sub v26.4s, v26.4s, v11.4s\n" |
4753 | "sub v27.4s, v27.4s, v12.4s\n" |
4754 | "sub v28.4s, v28.4s, v11.4s\n" |
4755 | "sub v29.4s, v29.4s, v12.4s\n" |
4756 | "sub v30.4s, v30.4s, v11.4s\n" |
4757 | "sub v31.4s, v31.4s, v12.4s\n" |
4758 | |
4759 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
4760 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
4761 | |
4762 | "402:\n" |
4763 | |
4764 | // At this point we have computed the final int32 values. Now we |
4765 | // start down-quantizing them to obtain the final 8bit values from them. |
4766 | |
4767 | // As part of this down-quantization, our int32 values will be |
4768 | // multiplied by a multiplier that has a fixed-point component and an |
4769 | // exponent component. |
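//
// In scalar terms (a rough sketch of the arithmetic only; the helper names
// here are descriptive, not ruy functions), each accumulator value acc
// becomes approximately
//   RoundingRightShift(
//       SaturatingDoublingHighMul(acc << left_shift, multiplier_fixedpoint),
//       right_shift)
// where left_shift / right_shift are the positive / negative parts of the
// exponent. The sshl, sqdmulh and srshl sequences below implement these
// three steps.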
4770 | |
// Load the exponent part of the multiplier.
4772 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
4773 | // Determine the channel index. |
4774 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
4775 | "csel w3, %w[row], %w[col], eq\n" |
4776 | // Compute the multiplier_exponent pointer |
4777 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
4778 | "add x5, x1, x3, lsl #2\n" |
4779 | "csel x1, x1, x5, eq\n" |
4780 | // Load multiplier_exponent |
4781 | "ldr q9, [x1]\n" |
4782 | "ldr q10, [x1, #16]\n" |
4783 | // Separate positive and negative exponents |
4784 | "smin v11.4s, v8.4s, v9.4s\n" |
4785 | "smin v12.4s, v8.4s, v10.4s\n" |
4786 | "sub v9.4s, v9.4s, v11.4s\n" |
4787 | "sub v10.4s, v10.4s, v12.4s\n" |
4788 | |
4789 | // Compute the multiplier_fixedpoint pointer |
4790 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
4791 | "add x5, x4, x3, lsl #2\n" |
4792 | "csel x4, x4, x5, eq\n" |
4793 | // Load multiplier_fixedpoint |
4794 | "ldr q14, [x4]\n" |
4795 | "ldr q15, [x4, #16]\n" |
4796 | |
4797 | // Jump based on channel dimension. |
4798 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
4799 | "bne 8f\n" |
4800 | // Case where channels are rows |
4801 | |
4802 | // Apply the positive exponent part of the multiplier. |
4803 | "sshl v16.4s, v16.4s, v9.4s\n" |
4804 | "sshl v17.4s, v17.4s, v10.4s\n" |
4805 | "sshl v18.4s, v18.4s, v9.4s\n" |
4806 | "sshl v19.4s, v19.4s, v10.4s\n" |
4807 | "sshl v20.4s, v20.4s, v9.4s\n" |
4808 | "sshl v21.4s, v21.4s, v10.4s\n" |
4809 | "sshl v22.4s, v22.4s, v9.4s\n" |
4810 | "sshl v23.4s, v23.4s, v10.4s\n" |
4811 | "sshl v24.4s, v24.4s, v9.4s\n" |
4812 | "sshl v25.4s, v25.4s, v10.4s\n" |
4813 | "sshl v26.4s, v26.4s, v9.4s\n" |
4814 | "sshl v27.4s, v27.4s, v10.4s\n" |
4815 | "sshl v28.4s, v28.4s, v9.4s\n" |
4816 | "sshl v29.4s, v29.4s, v10.4s\n" |
4817 | "sshl v30.4s, v30.4s, v9.4s\n" |
4818 | "sshl v31.4s, v31.4s, v10.4s\n" |
4819 | "10:\n" |
4820 | |
4821 | // Apply the fixed-point part of the multiplier. |
4822 | "sqdmulh v16.4s, v16.4s, v14.4s\n" |
4823 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
4824 | "sqdmulh v18.4s, v18.4s, v14.4s\n" |
4825 | "sqdmulh v19.4s, v19.4s, v15.4s\n" |
4826 | "sqdmulh v20.4s, v20.4s, v14.4s\n" |
4827 | "sqdmulh v21.4s, v21.4s, v15.4s\n" |
4828 | "sqdmulh v22.4s, v22.4s, v14.4s\n" |
4829 | "sqdmulh v23.4s, v23.4s, v15.4s\n" |
4830 | "sqdmulh v24.4s, v24.4s, v14.4s\n" |
4831 | "sqdmulh v25.4s, v25.4s, v15.4s\n" |
4832 | "sqdmulh v26.4s, v26.4s, v14.4s\n" |
4833 | "sqdmulh v27.4s, v27.4s, v15.4s\n" |
4834 | "sqdmulh v28.4s, v28.4s, v14.4s\n" |
4835 | "sqdmulh v29.4s, v29.4s, v15.4s\n" |
4836 | "sqdmulh v30.4s, v30.4s, v14.4s\n" |
4837 | "sqdmulh v31.4s, v31.4s, v15.4s\n" |
4838 | |
4839 | // Apply the negative exponent part of the multiplier. |
4840 | "srshl v16.4s, v16.4s, v11.4s\n" |
4841 | "srshl v17.4s, v17.4s, v12.4s\n" |
4842 | "srshl v18.4s, v18.4s, v11.4s\n" |
4843 | "srshl v19.4s, v19.4s, v12.4s\n" |
4844 | "srshl v20.4s, v20.4s, v11.4s\n" |
4845 | "srshl v21.4s, v21.4s, v12.4s\n" |
4846 | "srshl v22.4s, v22.4s, v11.4s\n" |
4847 | "srshl v23.4s, v23.4s, v12.4s\n" |
4848 | "srshl v24.4s, v24.4s, v11.4s\n" |
4849 | "srshl v25.4s, v25.4s, v12.4s\n" |
4850 | "srshl v26.4s, v26.4s, v11.4s\n" |
4851 | "srshl v27.4s, v27.4s, v12.4s\n" |
4852 | "srshl v28.4s, v28.4s, v11.4s\n" |
4853 | "srshl v29.4s, v29.4s, v12.4s\n" |
4854 | "srshl v30.4s, v30.4s, v11.4s\n" |
4855 | "srshl v31.4s, v31.4s, v12.4s\n" |
4856 | "b 9f\n" |
4857 | |
4858 | "8:\n" |
4859 | // Case where channels are columns |
4860 | |
4861 | // Apply the positive exponent part of the multiplier. |
4862 | "dup v4.4s, v9.s[0]\n" |
4863 | "dup v5.4s, v9.s[1]\n" |
4864 | "dup v6.4s, v9.s[2]\n" |
4865 | "dup v7.4s, v9.s[3]\n" |
4866 | "sshl v16.4s, v16.4s, v4.4s\n" |
4867 | "sshl v17.4s, v17.4s, v4.4s\n" |
4868 | "sshl v18.4s, v18.4s, v5.4s\n" |
4869 | "sshl v19.4s, v19.4s, v5.4s\n" |
4870 | "sshl v20.4s, v20.4s, v6.4s\n" |
4871 | "sshl v21.4s, v21.4s, v6.4s\n" |
4872 | "sshl v22.4s, v22.4s, v7.4s\n" |
4873 | "sshl v23.4s, v23.4s, v7.4s\n" |
4874 | "dup v4.4s, v10.s[0]\n" |
4875 | "dup v5.4s, v10.s[1]\n" |
4876 | "dup v6.4s, v10.s[2]\n" |
4877 | "dup v7.4s, v10.s[3]\n" |
4878 | "sshl v24.4s, v24.4s, v4.4s\n" |
4879 | "sshl v25.4s, v25.4s, v4.4s\n" |
4880 | "sshl v26.4s, v26.4s, v5.4s\n" |
4881 | "sshl v27.4s, v27.4s, v5.4s\n" |
4882 | "sshl v28.4s, v28.4s, v6.4s\n" |
4883 | "sshl v29.4s, v29.4s, v6.4s\n" |
4884 | "sshl v30.4s, v30.4s, v7.4s\n" |
4885 | "sshl v31.4s, v31.4s, v7.4s\n" |
4886 | "11:\n" |
4887 | |
4888 | // Apply the fixed-point part of the multiplier. |
4889 | "sqdmulh v16.4s, v16.4s, v14.s[0]\n" |
4890 | "sqdmulh v17.4s, v17.4s, v14.s[0]\n" |
4891 | "sqdmulh v18.4s, v18.4s, v14.s[1]\n" |
4892 | "sqdmulh v19.4s, v19.4s, v14.s[1]\n" |
4893 | "sqdmulh v20.4s, v20.4s, v14.s[2]\n" |
4894 | "sqdmulh v21.4s, v21.4s, v14.s[2]\n" |
4895 | "sqdmulh v22.4s, v22.4s, v14.s[3]\n" |
4896 | "sqdmulh v23.4s, v23.4s, v14.s[3]\n" |
4897 | "sqdmulh v24.4s, v24.4s, v15.s[0]\n" |
4898 | "sqdmulh v25.4s, v25.4s, v15.s[0]\n" |
4899 | "sqdmulh v26.4s, v26.4s, v15.s[1]\n" |
4900 | "sqdmulh v27.4s, v27.4s, v15.s[1]\n" |
4901 | "sqdmulh v28.4s, v28.4s, v15.s[2]\n" |
4902 | "sqdmulh v29.4s, v29.4s, v15.s[2]\n" |
4903 | "sqdmulh v30.4s, v30.4s, v15.s[3]\n" |
4904 | "sqdmulh v31.4s, v31.4s, v15.s[3]\n" |
4905 | |
4906 | // Apply the negative exponent part of the multiplier. |
4907 | "dup v4.4s, v11.s[0]\n" |
4908 | "dup v5.4s, v11.s[1]\n" |
4909 | "dup v6.4s, v11.s[2]\n" |
4910 | "dup v7.4s, v11.s[3]\n" |
4911 | "srshl v16.4s, v16.4s, v4.4s\n" |
4912 | "srshl v17.4s, v17.4s, v4.4s\n" |
4913 | "srshl v18.4s, v18.4s, v5.4s\n" |
4914 | "srshl v19.4s, v19.4s, v5.4s\n" |
4915 | "srshl v20.4s, v20.4s, v6.4s\n" |
4916 | "srshl v21.4s, v21.4s, v6.4s\n" |
4917 | "srshl v22.4s, v22.4s, v7.4s\n" |
4918 | "srshl v23.4s, v23.4s, v7.4s\n" |
4919 | "dup v4.4s, v12.s[0]\n" |
4920 | "dup v5.4s, v12.s[1]\n" |
4921 | "dup v6.4s, v12.s[2]\n" |
4922 | "dup v7.4s, v12.s[3]\n" |
4923 | "srshl v24.4s, v24.4s, v4.4s\n" |
4924 | "srshl v25.4s, v25.4s, v4.4s\n" |
4925 | "srshl v26.4s, v26.4s, v5.4s\n" |
4926 | "srshl v27.4s, v27.4s, v5.4s\n" |
4927 | "srshl v28.4s, v28.4s, v6.4s\n" |
4928 | "srshl v29.4s, v29.4s, v6.4s\n" |
4929 | "srshl v30.4s, v30.4s, v7.4s\n" |
4930 | "srshl v31.4s, v31.4s, v7.4s\n" |
4931 | "9:\n" |
4932 | |
4933 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
4934 | "ins v13.h[4], w4\n" // dst_zero_point |
4935 | |
4936 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
4937 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
4938 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
4939 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
4940 | |
4941 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
4942 | |
4943 | // Cast-and-saturate from int32 to int16 |
4944 | "sqxtn v16.4h, v16.4s\n" |
4945 | "sqxtn2 v16.8h, v17.4s\n" |
4946 | "sqxtn v17.4h, v18.4s\n" |
4947 | "sqxtn2 v17.8h, v19.4s\n" |
4948 | "sqxtn v18.4h, v20.4s\n" |
4949 | "sqxtn2 v18.8h, v21.4s\n" |
4950 | "sqxtn v19.4h, v22.4s\n" |
4951 | "sqxtn2 v19.8h, v23.4s\n" |
4952 | "sqxtn v20.4h, v24.4s\n" |
4953 | "sqxtn2 v20.8h, v25.4s\n" |
4954 | "sqxtn v21.4h, v26.4s\n" |
4955 | "sqxtn2 v21.8h, v27.4s\n" |
4956 | "sqxtn v22.4h, v28.4s\n" |
4957 | "sqxtn2 v22.8h, v29.4s\n" |
4958 | "sqxtn v23.4h, v30.4s\n" |
4959 | "sqxtn2 v23.8h, v31.4s\n" |
4960 | |
4961 | // At this point, v24 -- v31 aren't used anymore for the current block, |
4962 | // so we can start clearing these accumulators for the next block |
4963 | // (next iteration of the main loop). |
4964 | RUY_MAKE_ZERO(v24) |
4965 | RUY_MAKE_ZERO(v25) |
4966 | RUY_MAKE_ZERO(v26) |
4967 | RUY_MAKE_ZERO(v27) |
4968 | RUY_MAKE_ZERO(v28) |
4969 | RUY_MAKE_ZERO(v29) |
4970 | RUY_MAKE_ZERO(v30) |
4971 | RUY_MAKE_ZERO(v31) |
4972 | |
4973 | // Add the destination zero point |
4974 | "dup v14.8h, v13.h[4]\n" |
4975 | "sqadd v16.8h, v16.8h, v14.8h\n" |
4976 | "sqadd v17.8h, v17.8h, v14.8h\n" |
4977 | "sqadd v18.8h, v18.8h, v14.8h\n" |
4978 | "sqadd v19.8h, v19.8h, v14.8h\n" |
4979 | "sqadd v20.8h, v20.8h, v14.8h\n" |
4980 | "sqadd v21.8h, v21.8h, v14.8h\n" |
4981 | "sqadd v22.8h, v22.8h, v14.8h\n" |
4982 | "sqadd v23.8h, v23.8h, v14.8h\n" |
4983 | |
4984 | // Cast-and-saturate from int16 to uint8 |
4985 | "sqxtun v16.8b, v16.8h\n" |
4986 | "sqxtun2 v16.16b, v17.8h\n" |
4987 | "sqxtun v17.8b, v18.8h\n" |
4988 | "sqxtun2 v17.16b, v19.8h\n" |
4989 | "sqxtun v18.8b, v20.8h\n" |
4990 | "sqxtun2 v18.16b, v21.8h\n" |
4991 | "sqxtun v19.8b, v22.8h\n" |
4992 | "sqxtun2 v19.16b, v23.8h\n" |
4993 | |
4994 | // Load the clamp_min, clamp_max bounds |
4995 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
4996 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
4997 | "dup v14.16b, w2\n" // clamp_min |
4998 | "dup v15.16b, w3\n" // clamp_max |
4999 | |
5000 | // Apply the clamp_min bound |
5001 | "umax v16.16b, v16.16b, v14.16b\n" |
5002 | "umax v17.16b, v17.16b, v14.16b\n" |
5003 | "umax v18.16b, v18.16b, v14.16b\n" |
5004 | "umax v19.16b, v19.16b, v14.16b\n" |
5005 | |
5006 | // Apply the clamp_max bound |
5007 | "umin v16.16b, v16.16b, v15.16b\n" |
5008 | "umin v17.16b, v17.16b, v15.16b\n" |
5009 | "umin v18.16b, v18.16b, v15.16b\n" |
5010 | "umin v19.16b, v19.16b, v15.16b\n" |
5011 | |
5012 | // Make it so that all of the final 8bit values are stored in the |
5013 | // first 64bits of 128bit NEON registers, so they can be stored |
5014 | // by 64bit st1 store instructions with byte alignment. |
5015 | "dup d20, v16.d[1]\n" |
5016 | "dup d21, v17.d[1]\n" |
5017 | "dup d22, v18.d[1]\n" |
5018 | "dup d23, v19.d[1]\n" |
5019 | |
5020 | // Compute how much of the 8x8 block of destination 8bit values that |
5021 | // we have computed, fit in the destination matrix. Typically, all of |
5022 | // it fits, but when the destination matrix shape is not a multiple |
5023 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
5024 | // not fit entirely. |
5025 | "sub w1, %w[dst_rows], %w[row]\n" |
5026 | "sub w2, %w[dst_cols], %w[col]\n" |
5027 | "mov w3, #8\n" |
5028 | "cmp w1, #8\n" |
5029 | // Compute w1 = how many rows of the 8x8 block fit |
5030 | "csel w1, w1, w3, le\n" |
5031 | "cmp w2, #8\n" |
5032 | // Compute w2 = how many cols of the 8x8 block fit |
5033 | "csel w2, w2, w3, le\n" |
5034 | |
5035 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
5036 | "cmp w1, w3\n" |
5037 | "ccmp w2, w3, 0, eq\n" |
5038 | // Yes, all of the 8x8 block fits, go to fast path. |
5039 | "beq 30f\n" |
5040 | // Not all of the 8x8 block fits. |
5041 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
5042 | "mov x3, %[dst_tmp_buf]\n" |
5043 | "mov x4, #8\n" |
5044 | "b 31f\n" |
5045 | "30:\n" |
5046 | // Yes, all of the 8x8 block fits. |
5047 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
5048 | "mov x3, %[dst_ptr]\n" |
5049 | "mov x4, x11\n" |
5050 | "31:\n" |
5051 | |
5052 | // Write our 8bit values to the destination described by |
5053 | // (x3 address, x4 stride). |
5054 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5055 | "st1 {v16.8b}, [x3], x4\n" |
5056 | RUY_MAKE_ZERO(v16) |
5057 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5058 | "st1 {v20.8b}, [x3], x4\n" |
5059 | RUY_MAKE_ZERO(v20) |
5060 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5061 | "st1 {v17.8b}, [x3], x4\n" |
5062 | RUY_MAKE_ZERO(v17) |
5063 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5064 | "st1 {v21.8b}, [x3], x4\n" |
5065 | RUY_MAKE_ZERO(v21) |
5066 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5067 | "st1 {v18.8b}, [x3], x4\n" |
5068 | RUY_MAKE_ZERO(v18) |
5069 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5070 | "st1 {v22.8b}, [x3], x4\n" |
5071 | RUY_MAKE_ZERO(v22) |
5072 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5073 | "st1 {v19.8b}, [x3], x4\n" |
5074 | RUY_MAKE_ZERO(v19) |
5075 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5076 | "st1 {v23.8b}, [x3], x4\n" |
5077 | RUY_MAKE_ZERO(v23) |
5078 | |
5079 | // For the next block: perform the first few multiply-adds on the data |
5080 | // that we have already loaded. |
5081 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5082 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
5083 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
5084 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
5085 | |
5086 | // If all of the 8x8 block fits, we just finished writing it to the |
5087 | // destination, so we skip the next part. |
5088 | "beq 41f\n" |
5089 | // Not all of the 8x8 block fits in the destination matrix. We just |
5090 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
5091 | // it to copy into the destination matrix the part that fits. |
5092 | "mov x3, %[dst_tmp_buf]\n" |
5093 | "mov x4, %[dst_ptr]\n" |
5094 | "mov w6, #0\n" |
5095 | "50:\n" |
5096 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5097 | "mov w5, #0\n" |
5098 | "51:\n" |
5099 | "ldrb w7, [x3, w5, uxtw]\n" |
5100 | "strb w7, [x4, w5, uxtw]\n" |
5101 | "add w5, w5, #1\n" |
5102 | "cmp w5, w1\n" |
5103 | "blt 51b\n" |
5104 | "add w6, w6, #1\n" |
5105 | "add x3, x3, #8\n" |
5106 | "add x4, x4, x11\n" |
5107 | "cmp w6, w2\n" |
5108 | "blt 50b\n" |
5109 | "41:\n" |
5110 | "add %[dst_ptr], %[dst_ptr], #8\n" |
5111 | // At this point we have completely finished writing values to the |
5112 | // destination matrix for the current block. |
5113 | |
5114 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
5115 | |
5116 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
5117 | |
5118 | // Cast-and-saturate from int32 to int16 |
5119 | "sqxtn v16.4h, v16.4s\n" |
5120 | "sqxtn2 v16.8h, v17.4s\n" |
5121 | "sqxtn v17.4h, v18.4s\n" |
5122 | "sqxtn2 v17.8h, v19.4s\n" |
5123 | "sqxtn v18.4h, v20.4s\n" |
5124 | "sqxtn2 v18.8h, v21.4s\n" |
5125 | "sqxtn v19.4h, v22.4s\n" |
5126 | "sqxtn2 v19.8h, v23.4s\n" |
5127 | "sqxtn v20.4h, v24.4s\n" |
5128 | "sqxtn2 v20.8h, v25.4s\n" |
5129 | "sqxtn v21.4h, v26.4s\n" |
5130 | "sqxtn2 v21.8h, v27.4s\n" |
5131 | "sqxtn v22.4h, v28.4s\n" |
5132 | "sqxtn2 v22.8h, v29.4s\n" |
5133 | "sqxtn v23.4h, v30.4s\n" |
5134 | "sqxtn2 v23.8h, v31.4s\n" |
5135 | |
5136 | // At this point, v24 -- v31 aren't used anymore for the current block, |
5137 | // so we can start clearing these accumulators for the next block |
5138 | // (next iteration of the main loop). |
5139 | RUY_MAKE_ZERO(v24) |
5140 | RUY_MAKE_ZERO(v25) |
5141 | RUY_MAKE_ZERO(v26) |
5142 | RUY_MAKE_ZERO(v27) |
5143 | RUY_MAKE_ZERO(v28) |
5144 | RUY_MAKE_ZERO(v29) |
5145 | RUY_MAKE_ZERO(v30) |
5146 | RUY_MAKE_ZERO(v31) |
5147 | |
5148 | // Add the destination zero point |
5149 | "dup v14.8h, v13.h[4]\n" |
5150 | "sqadd v16.8h, v16.8h, v14.8h\n" |
5151 | "sqadd v17.8h, v17.8h, v14.8h\n" |
5152 | "sqadd v18.8h, v18.8h, v14.8h\n" |
5153 | "sqadd v19.8h, v19.8h, v14.8h\n" |
5154 | "sqadd v20.8h, v20.8h, v14.8h\n" |
5155 | "sqadd v21.8h, v21.8h, v14.8h\n" |
5156 | "sqadd v22.8h, v22.8h, v14.8h\n" |
5157 | "sqadd v23.8h, v23.8h, v14.8h\n" |
5158 | |
5159 | // Cast-and-saturate from int16 to uint8 |
5160 | "sqxtn v16.8b, v16.8h\n" |
5161 | "sqxtn2 v16.16b, v17.8h\n" |
5162 | "sqxtn v17.8b, v18.8h\n" |
5163 | "sqxtn2 v17.16b, v19.8h\n" |
5164 | "sqxtn v18.8b, v20.8h\n" |
5165 | "sqxtn2 v18.16b, v21.8h\n" |
5166 | "sqxtn v19.8b, v22.8h\n" |
5167 | "sqxtn2 v19.16b, v23.8h\n" |
5168 | |
5169 | // Load the clamp_min, clamp_max bounds |
5170 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
5171 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
5172 | "dup v14.16b, w2\n" // clamp_min |
5173 | "dup v15.16b, w3\n" // clamp_max |
5174 | |
5175 | // Apply the clamp_min bound |
5176 | "smax v16.16b, v16.16b, v14.16b\n" |
5177 | "smax v17.16b, v17.16b, v14.16b\n" |
5178 | "smax v18.16b, v18.16b, v14.16b\n" |
5179 | "smax v19.16b, v19.16b, v14.16b\n" |
5180 | |
5181 | // Apply the clamp_max bound |
5182 | "smin v16.16b, v16.16b, v15.16b\n" |
5183 | "smin v17.16b, v17.16b, v15.16b\n" |
5184 | "smin v18.16b, v18.16b, v15.16b\n" |
5185 | "smin v19.16b, v19.16b, v15.16b\n" |
5186 | |
5187 | // Make it so that all of the final 8bit values are stored in the |
5188 | // first 64bits of 128bit NEON registers, so they can be stored |
5189 | // by 64bit st1 store instructions with byte alignment. |
5190 | "dup d20, v16.d[1]\n" |
5191 | "dup d21, v17.d[1]\n" |
5192 | "dup d22, v18.d[1]\n" |
5193 | "dup d23, v19.d[1]\n" |
5194 | |
5195 | // Compute how much of the 8x8 block of destination 8bit values that |
5196 | // we have computed, fit in the destination matrix. Typically, all of |
5197 | // it fits, but when the destination matrix shape is not a multiple |
5198 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
5199 | // not fit entirely. |
5200 | "sub w1, %w[dst_rows], %w[row]\n" |
5201 | "sub w2, %w[dst_cols], %w[col]\n" |
5202 | "mov w3, #8\n" |
5203 | "cmp w1, #8\n" |
5204 | // Compute w1 = how many rows of the 8x8 block fit |
5205 | "csel w1, w1, w3, le\n" |
5206 | "cmp w2, #8\n" |
5207 | // Compute w2 = how many cols of the 8x8 block fit |
5208 | "csel w2, w2, w3, le\n" |
5209 | |
5210 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
5211 | "cmp w1, w3\n" |
5212 | "ccmp w2, w3, 0, eq\n" |
5213 | // Yes, all of the 8x8 block fits, go to fast path. |
5214 | "beq 130f\n" |
5215 | // Not all of the 8x8 block fits. |
5216 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
5217 | "mov x3, %[dst_tmp_buf]\n" |
5218 | "mov x4, #8\n" |
5219 | "b 131f\n" |
5220 | "130:\n" |
5221 | // Yes, all of the 8x8 block fits. |
5222 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
5223 | "mov x3, %[dst_ptr]\n" |
5224 | "mov x4, x11\n" |
5225 | "131:\n" |
5226 | |
5227 | // Write our 8bit values to the destination described by |
5228 | // (x3 address, x4 stride). |
5229 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5230 | "st1 {v16.8b}, [x3], x4\n" |
5231 | RUY_MAKE_ZERO(v16) |
5232 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5233 | "st1 {v20.8b}, [x3], x4\n" |
5234 | RUY_MAKE_ZERO(v20) |
5235 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5236 | "st1 {v17.8b}, [x3], x4\n" |
5237 | RUY_MAKE_ZERO(v17) |
5238 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5239 | "st1 {v21.8b}, [x3], x4\n" |
5240 | RUY_MAKE_ZERO(v21) |
5241 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5242 | "st1 {v18.8b}, [x3], x4\n" |
5243 | RUY_MAKE_ZERO(v18) |
5244 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5245 | "st1 {v22.8b}, [x3], x4\n" |
5246 | RUY_MAKE_ZERO(v22) |
5247 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5248 | "st1 {v19.8b}, [x3], x4\n" |
5249 | RUY_MAKE_ZERO(v19) |
5250 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5251 | "st1 {v23.8b}, [x3], x4\n" |
5252 | RUY_MAKE_ZERO(v23) |
5253 | |
5254 | // For the next block: perform the first few multiply-adds on the data |
5255 | // that we have already loaded. |
5256 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5257 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
5258 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
5259 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
5260 | |
5261 | // If all of the 8x8 block fits, we just finished writing it to the |
5262 | // destination, so we skip the next part. |
5263 | "beq 141f\n" |
5264 | // Not all of the 8x8 block fits in the destination matrix. We just |
5265 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
5266 | // it to copy into the destination matrix the part that fits. |
5267 | "mov x3, %[dst_tmp_buf]\n" |
5268 | "mov x4, %[dst_ptr]\n" |
5269 | "mov w6, #0\n" |
5270 | "150:\n" |
5271 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5272 | "mov w5, #0\n" |
5273 | "151:\n" |
5274 | "ldrb w7, [x3, w5, uxtw]\n" |
5275 | "strb w7, [x4, w5, uxtw]\n" |
5276 | "add w5, w5, #1\n" |
5277 | "cmp w5, w1\n" |
5278 | "blt 151b\n" |
5279 | "add w6, w6, #1\n" |
5280 | "add x3, x3, #8\n" |
5281 | "add x4, x4, x11\n" |
5282 | "cmp w6, w2\n" |
5283 | "blt 150b\n" |
5284 | "141:\n" |
5285 | "add %[dst_ptr], %[dst_ptr], #8\n" |
5286 | // At this point we have completely finished writing values to the |
5287 | // destination matrix for the current block. |
5288 | |
5289 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
5290 | |
5291 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
5292 | |
5293 | // Add the destination zero point |
5294 | "dup v14.8h, v13.h[4]\n" |
5295 | "saddw v16.4s, v16.4s, v14.4h\n" |
5296 | "saddw v17.4s, v17.4s, v14.4h\n" |
5297 | "saddw v18.4s, v18.4s, v14.4h\n" |
5298 | "saddw v19.4s, v19.4s, v14.4h\n" |
5299 | "saddw v20.4s, v20.4s, v14.4h\n" |
5300 | "saddw v21.4s, v21.4s, v14.4h\n" |
5301 | "saddw v22.4s, v22.4s, v14.4h\n" |
5302 | "saddw v23.4s, v23.4s, v14.4h\n" |
5303 | "saddw v24.4s, v24.4s, v14.4h\n" |
5304 | "saddw v25.4s, v25.4s, v14.4h\n" |
5305 | "saddw v26.4s, v26.4s, v14.4h\n" |
5306 | "saddw v27.4s, v27.4s, v14.4h\n" |
5307 | "saddw v28.4s, v28.4s, v14.4h\n" |
5308 | "saddw v29.4s, v29.4s, v14.4h\n" |
5309 | "saddw v30.4s, v30.4s, v14.4h\n" |
5310 | "saddw v31.4s, v31.4s, v14.4h\n" |
5311 | |
5312 | // Cast-and-saturate from int32 to int16 |
5313 | "sqxtn v16.4h, v16.4s\n" |
5314 | "sqxtn2 v16.8h, v17.4s\n" |
5315 | "sqxtn v17.4h, v18.4s\n" |
5316 | "sqxtn2 v17.8h, v19.4s\n" |
5317 | "sqxtn v18.4h, v20.4s\n" |
5318 | "sqxtn2 v18.8h, v21.4s\n" |
5319 | "sqxtn v19.4h, v22.4s\n" |
5320 | "sqxtn2 v19.8h, v23.4s\n" |
5321 | "sqxtn v20.4h, v24.4s\n" |
5322 | "sqxtn2 v20.8h, v25.4s\n" |
5323 | "sqxtn v21.4h, v26.4s\n" |
5324 | "sqxtn2 v21.8h, v27.4s\n" |
5325 | "sqxtn v22.4h, v28.4s\n" |
5326 | "sqxtn2 v22.8h, v29.4s\n" |
5327 | "sqxtn v23.4h, v30.4s\n" |
5328 | "sqxtn2 v23.8h, v31.4s\n" |
5329 | |
5330 | // At this point, v24 -- v31 aren't used anymore for the current block, |
5331 | // so we can start clearing these accumulators for the next block |
5332 | // (next iteration of the main loop). |
5333 | RUY_MAKE_ZERO(v24) |
5334 | RUY_MAKE_ZERO(v25) |
5335 | RUY_MAKE_ZERO(v26) |
5336 | RUY_MAKE_ZERO(v27) |
5337 | RUY_MAKE_ZERO(v28) |
5338 | RUY_MAKE_ZERO(v29) |
5339 | RUY_MAKE_ZERO(v30) |
5340 | RUY_MAKE_ZERO(v31) |
5341 | |
5342 | // Load the clamp_min, clamp_max bounds |
5343 | "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
5344 | "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
5345 | "dup v14.8h, w2\n" // clamp_min |
5346 | "dup v15.8h, w3\n" // clamp_max |
5347 | |
5348 | // Apply the clamp_min bound |
5349 | "smax v16.8h, v16.8h, v14.8h\n" |
5350 | "smax v17.8h, v17.8h, v14.8h\n" |
5351 | "smax v18.8h, v18.8h, v14.8h\n" |
5352 | "smax v19.8h, v19.8h, v14.8h\n" |
5353 | "smax v20.8h, v20.8h, v14.8h\n" |
5354 | "smax v21.8h, v21.8h, v14.8h\n" |
5355 | "smax v22.8h, v22.8h, v14.8h\n" |
5356 | "smax v23.8h, v23.8h, v14.8h\n" |
5357 | // Apply the clamp_max bound |
5358 | "smin v16.8h, v16.8h, v15.8h\n" |
5359 | "smin v17.8h, v17.8h, v15.8h\n" |
5360 | "smin v18.8h, v18.8h, v15.8h\n" |
5361 | "smin v19.8h, v19.8h, v15.8h\n" |
5362 | "smin v20.8h, v20.8h, v15.8h\n" |
5363 | "smin v21.8h, v21.8h, v15.8h\n" |
5364 | "smin v22.8h, v22.8h, v15.8h\n" |
5365 | "smin v23.8h, v23.8h, v15.8h\n" |
5366 | |
5367 | // Compute how much of the 8x8 block of destination 16bit values that |
5368 | // we have computed, fit in the destination matrix. Typically, all of |
5369 | // it fits, but when the destination matrix shape is not a multiple |
5370 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
5371 | // not fit entirely. |
5372 | "sub w1, %w[dst_rows], %w[row]\n" |
5373 | "sub w2, %w[dst_cols], %w[col]\n" |
5374 | "mov w3, #8\n" |
5375 | "cmp w1, #8\n" |
5376 | // Compute w1 = how many rows of the 8x8 block fit |
5377 | "csel w1, w1, w3, le\n" |
5378 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
5380 | "csel w2, w2, w3, le\n" |
5381 | |
5382 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
5383 | "cmp w1, w3\n" |
5384 | "ccmp w2, w3, 0, eq\n" |
5385 | // Yes, all of the 8x8 block fits, go to fast path. |
5386 | "beq 230f\n" |
5387 | // Not all of the 8x8 block fits. |
5388 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
5389 | "mov x3, %[dst_tmp_buf]\n" |
5390 | "mov x4, #16\n" |
5391 | "b 231f\n" |
5392 | "230:\n" |
5393 | // Yes, all of the 8x8 block fits. |
5394 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
5395 | "mov x3, %[dst_ptr]\n" |
5396 | "mov x4, x11\n" |
5397 | "231:\n" |
5398 | |
5399 | // Write our 16bit values to the destination described by |
5400 | // (x3 address, x4 stride). |
5401 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5402 | "st1 {v16.8h}, [x3], x4\n" |
5403 | RUY_MAKE_ZERO(v16) |
5404 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5405 | "st1 {v17.8h}, [x3], x4\n" |
5406 | RUY_MAKE_ZERO(v17) |
5407 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5408 | "st1 {v18.8h}, [x3], x4\n" |
5409 | RUY_MAKE_ZERO(v18) |
5410 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5411 | "st1 {v19.8h}, [x3], x4\n" |
5412 | RUY_MAKE_ZERO(v19) |
5413 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5414 | "st1 {v20.8h}, [x3], x4\n" |
5415 | RUY_MAKE_ZERO(v20) |
5416 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5417 | "st1 {v21.8h}, [x3], x4\n" |
5418 | RUY_MAKE_ZERO(v21) |
5419 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5420 | "st1 {v22.8h}, [x3], x4\n" |
5421 | RUY_MAKE_ZERO(v22) |
5422 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
5423 | "st1 {v23.8h}, [x3], x4\n" |
5424 | RUY_MAKE_ZERO(v23) |
5425 | |
5426 | // For the next block: perform the first few multiply-adds on the data |
5427 | // that we have already loaded. |
5428 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5429 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
5430 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
5431 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
5432 | |
5433 | // If all of the 8x8 block fits, we just finished writing it to the |
5434 | // destination, so we skip the next part. |
5435 | "beq 241f\n" |
5436 | // Not all of the 8x8 block fits in the destination matrix. We just |
5437 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
5438 | // it to copy into the destination matrix the part that fits. |
5439 | "mov x3, %[dst_tmp_buf]\n" |
5440 | "mov x4, %[dst_ptr]\n" |
5441 | "mov w6, #0\n" |
5442 | "250:\n" |
5443 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5444 | "mov w5, #0\n" |
5445 | "251:\n" |
5446 | "ldrsh w7, [x3, x5, lsl #1]\n" |
5447 | "strh w7, [x4, x5, lsl #1]\n" |
5448 | "add w5, w5, #1\n" |
5449 | "cmp w5, w1\n" |
5450 | "blt 251b\n" |
5451 | "add w6, w6, #1\n" |
5452 | "add x3, x3, #16\n" |
5453 | "add x4, x4, x11\n" |
5454 | "cmp w6, w2\n" |
5455 | "blt 250b\n" |
5456 | "241:\n" |
5457 | "add %[dst_ptr], %[dst_ptr], #16\n" |
5458 | // At this point we have completely finished writing values to the |
5459 | // destination matrix for the current block. |
5460 | |
5461 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
5462 | |
5463 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
5464 | |
5465 | // Since the store type is the same as the accum type, no need for |
5466 | // downcast. There's also no need for clamp by min/max. |
5467 | |
// Compute how much of the 8x8 block of destination 32bit values that
5469 | // we have computed, fit in the destination matrix. Typically, all of |
5470 | // it fits, but when the destination matrix shape is not a multiple |
5471 | // of 8x8, there are some 8x8 blocks along the boundaries that do |
5472 | // not fit entirely. |
5473 | "sub w1, %w[dst_rows], %w[row]\n" |
5474 | "sub w2, %w[dst_cols], %w[col]\n" |
5475 | "mov w3, #8\n" |
5476 | "cmp w1, #8\n" |
5477 | // Compute w1 = how many rows of the 8x8 block fit |
5478 | "csel w1, w1, w3, le\n" |
5479 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
5481 | "csel w2, w2, w3, le\n" |
5482 | |
5483 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
5484 | "cmp w1, w3\n" |
5485 | "ccmp w2, w3, 0, eq\n" |
5486 | // Yes, all of the 8x8 block fits, go to fast path. |
5487 | "beq 330f\n" |
5488 | // Not all of the 8x8 block fits. |
5489 | // Write to dst_tmp_buf |
5490 | "mov x3, %[dst_tmp_buf]\n" |
5491 | "st1 {v16.4s}, [x3], #16\n" |
5492 | RUY_MAKE_ZERO(v16) |
5493 | "st1 {v17.4s}, [x3], #16\n" |
5494 | RUY_MAKE_ZERO(v17) |
5495 | "st1 {v18.4s}, [x3], #16\n" |
5496 | RUY_MAKE_ZERO(v18) |
5497 | "st1 {v19.4s}, [x3], #16\n" |
5498 | RUY_MAKE_ZERO(v19) |
5499 | "st1 {v20.4s}, [x3], #16\n" |
5500 | RUY_MAKE_ZERO(v20) |
5501 | "st1 {v21.4s}, [x3], #16\n" |
5502 | RUY_MAKE_ZERO(v21) |
5503 | "st1 {v22.4s}, [x3], #16\n" |
5504 | RUY_MAKE_ZERO(v22) |
5505 | "st1 {v23.4s}, [x3], #16\n" |
5506 | RUY_MAKE_ZERO(v23) |
5507 | "st1 {v24.4s}, [x3], #16\n" |
5508 | RUY_MAKE_ZERO(v24) |
5509 | "st1 {v25.4s}, [x3], #16\n" |
5510 | RUY_MAKE_ZERO(v25) |
5511 | "st1 {v26.4s}, [x3], #16\n" |
5512 | RUY_MAKE_ZERO(v26) |
5513 | "st1 {v27.4s}, [x3], #16\n" |
5514 | RUY_MAKE_ZERO(v27) |
5515 | "st1 {v28.4s}, [x3], #16\n" |
5516 | RUY_MAKE_ZERO(v28) |
5517 | "st1 {v29.4s}, [x3], #16\n" |
5518 | RUY_MAKE_ZERO(v29) |
5519 | "st1 {v30.4s}, [x3], #16\n" |
5520 | RUY_MAKE_ZERO(v30) |
5521 | "st1 {v31.4s}, [x3], #16\n" |
5522 | RUY_MAKE_ZERO(v31) |
5523 | |
5524 | "b 331f\n" |
5525 | |
5526 | "330:\n" |
5527 | // Yes, all of the 8x8 block fits. |
5528 | "mov x4, %[dst_ptr]\n" |
5529 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5530 | "mov x3, x4\n" |
5531 | "st1 {v16.4s, v17.4s}, [x3], #32\n" |
5532 | RUY_MAKE_ZERO(v16) |
5533 | RUY_MAKE_ZERO(v17) |
5534 | "add x4, x4, x11\n" |
5535 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5536 | "mov x3, x4\n" |
5537 | "st1 {v18.4s, v19.4s}, [x3], #32\n" |
5538 | RUY_MAKE_ZERO(v18) |
5539 | RUY_MAKE_ZERO(v19) |
5540 | "add x4, x4, x11\n" |
5541 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5542 | "mov x3, x4\n" |
5543 | "st1 {v20.4s, v21.4s}, [x3], #32\n" |
5544 | RUY_MAKE_ZERO(v20) |
5545 | RUY_MAKE_ZERO(v21) |
5546 | "add x4, x4, x11\n" |
5547 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5548 | "mov x3, x4\n" |
5549 | "st1 {v22.4s, v23.4s}, [x3], #32\n" |
5550 | RUY_MAKE_ZERO(v22) |
5551 | RUY_MAKE_ZERO(v23) |
5552 | "add x4, x4, x11\n" |
5553 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5554 | "mov x3, x4\n" |
5555 | "st1 {v24.4s, v25.4s}, [x3], #32\n" |
5556 | RUY_MAKE_ZERO(v24) |
5557 | RUY_MAKE_ZERO(v25) |
5558 | "add x4, x4, x11\n" |
5559 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5560 | "mov x3, x4\n" |
5561 | "st1 {v26.4s, v27.4s}, [x3], #32\n" |
5562 | RUY_MAKE_ZERO(v26) |
5563 | RUY_MAKE_ZERO(v27) |
5564 | "add x4, x4, x11\n" |
5565 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5566 | "mov x3, x4\n" |
5567 | "st1 {v28.4s, v29.4s}, [x3], #32\n" |
5568 | RUY_MAKE_ZERO(v28) |
5569 | RUY_MAKE_ZERO(v29) |
5570 | "add x4, x4, x11\n" |
5571 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5572 | "mov x3, x4\n" |
5573 | "st1 {v30.4s, v31.4s}, [x3], #32\n" |
5574 | RUY_MAKE_ZERO(v30) |
5575 | RUY_MAKE_ZERO(v31) |
5576 | |
5577 | "331:\n" |
5578 | |
5579 | // For the next block: perform the first few multiply-adds on the data |
5580 | // that we have already loaded. |
5581 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5582 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
5583 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
5584 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
5585 | |
5586 | // If all of the 8x8 block fits, we just finished writing it to the |
5587 | // destination, so we skip the next part. |
5588 | "beq 341f\n" |
5589 | |
5590 | // Not all of the 8x8 block fits in the destination matrix. We just |
5591 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
5592 | // it to copy into the destination matrix the part that fits. |
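        // In scalar terms, the loop below is roughly (an illustrative
        // sketch; rows_fit/cols_fit are the values computed into w1/w2
        // above, dst_stride is the byte stride held in x11):
        //   for (int c = 0; c < cols_fit; ++c) {
        //     std::int32_t* dst_col =
        //         (std::int32_t*)((char*)dst_ptr + c * dst_stride);
        //     for (int r = 0; r < rows_fit; ++r) {
        //       dst_col[r] = ((const std::int32_t*)dst_tmp_buf)[r + 8 * c];
        //     }
        //   }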
5593 | "mov x3, %[dst_tmp_buf]\n" |
5594 | "mov x4, %[dst_ptr]\n" |
5595 | "mov w6, #0\n" |
5596 | "350:\n" |
5597 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
5598 | "mov w5, #0\n" |
5599 | "351:\n" |
5600 | "ldr w7, [x3, x5, lsl #2]\n" |
5601 | "str w7, [x4, x5, lsl #2]\n" |
5602 | "add w5, w5, #1\n" |
5603 | "cmp w5, w1\n" |
5604 | "blt 351b\n" |
5605 | "add w6, w6, #1\n" |
5606 | "add x3, x3, #32\n" |
5607 | "add x4, x4, x11\n" |
5608 | "cmp w6, w2\n" |
5609 | "blt 350b\n" |
5610 | "341:\n" |
5611 | "add %[dst_ptr], %[dst_ptr], #32\n" |
5612 | // At this point we have completely finished writing values to the |
5613 | // destination matrix for the current block. |
5614 | |
5615 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
5616 | |
5617 | // Reload some params --- we had used x5 -- x7 for a few other things |
5618 | // since the last time we had loaded them. |
5619 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
5620 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
5621 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
5622 | |
5623 | // Move to the next block of the destination matrix, for the next iter |
5624 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
5625 | // been updated earlier. |
5626 | // Have we reached the end row? |
5627 | "cmp %w[row], w7\n" |
5628 | "beq 20f\n" // yes, end row. |
5629 | // Not end row. Move to the next row. |
5630 | "add %w[row], %w[row], #8\n" |
5631 | "b 21f\n" |
5632 | "20:\n" |
5633 | // Was already at end row. |
5634 | "mov %w[row], w6\n" // Move back to first row. |
5635 | "add %w[col], %w[col], #8\n" // Move to the next column. |
5636 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
5637 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
5638 | "21:\n" |
5639 | |
5640 | // Main loop exit condition: have we hit the end column? |
5641 | "cmp %w[col], w8\n" |
5642 | |
5643 | // w1 is the number of levels of depth that we have already loaded |
5644 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
5645 | // above, this is currently 4. |
5646 | "mov w1, #4\n" |
5647 | |
5648 | "ble 1b\n" |
5649 | |
5650 | // clang-format on |
5651 | |
5652 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
5653 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
5654 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
5655 | : [ params ] "r" (¶ms), [dst_rows] "r" (params.dst_rows), |
5656 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf), |
5657 | [dst_type_id] "r" (params.dst_type_id) |
5658 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
5659 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
5660 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
5661 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
5662 | } |
5663 | |
5664 | |
5665 | // Similar to the above 8-bit dotprod kernel, but specialized for the case of |
5666 | // RHS cols == 1. |
5667 | // Relevant target CPUs for this kernel include ARM Cortex-A76, |
// since these are 64-bit, out-of-order, and support dotprod.
5669 | void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) { |
5670 | profiler::ScopeLabel label("Kernel (kNeonDotprod)" ); |
5671 | |
5672 | CheckOffsetsInKernelParams8bit(params); |
5673 | |
5674 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
  const std::int8_t* rhs_col_ptr =
      static_cast<const std::int8_t*>(params.rhs_base_ptr);
5677 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
5678 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
5679 | void* dst_col_ptr = params.dst_base_ptr; |
5680 | void* dst_ptr = dst_col_ptr; |
5681 | int row = params.start_row; |
5682 | int col = params.start_col; |
5683 | |
5684 | RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); |
5685 | |
5686 | // The asm kernel below has the following NEON register allocation: |
5687 | // |
  // v16 and v17 are int32 accumulators.
  // During accumulation, v0 and v1 are used to load an 8x4 block of int8
  // data from the LHS, and v2 is used to load a 4x1 block of int8 data from
  // the RHS, like this:
5692 | // |
5693 | // int8 RHS 4x1 block |
5694 | // /-------| |
5695 | // |v2.b[0]| |
5696 | // | ... | |
5697 | // |v2.b[3]| |
5698 | // \-------/ |
5699 | // int8 LHS 8x4 block |
5700 | // /---------------------\ /--------| |
5701 | // |v0.b[0] ... v0.b[3] | |v16.s[0]| |
5702 | // | ... ... | | ... | |
5703 | // |v0.b[12] ... v0.b[15]| |v16.s[3]| |
5704 | // |v1.b[0] ... v1.b[3] | |v17.s[0]| |
5705 | // | ... ... | | ... | |
5706 | // |v1.b[12] ... v1.b[15]| |v17.s[3]| |
5707 | // \---------------------/ \--------/ |
5708 | // int32 accumulators 8x1 block |
5709 | // |
  // Unlike the general 8x8 kernel above, this kernel has no
  // RUY_OPT_MAX_STREAMING 4x-unrolled variant, so v3 -- v7 are unused.
  //
  // v8 -- v15 are used for loading parameters used for the
  // post-accumulation part of the kernel.
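  //
  // Ignoring the packed storage order and the zero-point/multiplier work
  // that follows accumulation, the math performed per 8x1 block is roughly
  // (an illustrative sketch, not kernel code):
  //
  //   for (int r = 0; r < 8; ++r) {
  //     std::int32_t acc = 0;
  //     for (int d = 0; d < depth; ++d) {
  //       acc += std::int32_t(lhs(row + r, d)) * std::int32_t(rhs(d, col));
  //     }
  //     // acc ends up in one lane of v16 (rows 0--3) or v17 (rows 4--7).
  //   }
  //
  // Each sdot instruction below performs this inner product for 4
  // consecutive values of d across 4 consecutive rows at once.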
5717 | asm volatile( |
5718 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
5719 | |
5720 | // clang-format off |
5721 | |
5722 | // Load some parameters into registers. |
5723 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
5724 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
5725 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
5726 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
5727 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
5728 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
5729 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
5730 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
5731 | |
        // Load the first 32 bytes of LHS data and the first 8 bytes of RHS
        // data (the RHS pointer still advances by 32 bytes per 4 levels of
        // depth).
5733 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
5734 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
5735 | "ld1 {v2.8b}, [%[rhs_ptr]]\n" |
5736 | "add %[rhs_ptr], %[rhs_ptr], #32\n" |
5737 | |
5738 | // Clear accumulators. |
5739 | RUY_MAKE_ZERO(v16) |
5740 | RUY_MAKE_ZERO(v17) |
5741 | |
5742 | // w1 is the number of levels of depth that we have already loaded |
5743 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
5744 | // above, this is currently 4. |
5745 | "mov w1, #4\n" |
5746 | |
5747 | // Perform the first few multiply-adds on the data that we have already |
5748 | // loaded. |
5749 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5750 | |
5751 | // Main loop of the whole GEMM, over rows and columns of the |
5752 | // destination matrix. |
5753 | "1:\n" |
5754 | |
        // Kernel inner loop (over depth).
5757 | |
5758 | // Reminder - w1 is how many levels of depth we have already loaded |
5759 | // data for, w12 is the total depth. |
5760 | "cmp w1, w12\n" |
5761 | "beq 79f\n" |
5762 | |
5763 | "2:\n" |
5764 | |
5765 | // Because of the data that we have already loaded, we can start the |
5766 | // loop body right away with some multiply-adds. |
5767 | // Each iteration of this loop advances by 4 levels of depth. |
5768 | "add w1, w1, #4\n" |
5769 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
5770 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
5771 | // Loop termination condition. |
5772 | "cmp w1, w12\n" |
5773 | "ld1 {v2.8b}, [%[rhs_ptr]]\n" |
5774 | "add %[rhs_ptr], %[rhs_ptr], #32\n" |
5775 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
5776 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
5777 | |
5778 | "blt 2b\n" |
5779 | |
5780 | "79:\n" |
5781 | // End of the inner loop on depth. Now perform the remaining |
5782 | // multiply-adds of the last 4 levels of depth, for which the LHS |
5783 | // and RHS data is already loaded. |
5784 | |
5785 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
5786 | |
        // End of accumulation. The registers v16 and v17 contain the final
        // int32 accumulator values of the current 8x1 destination block.
        // We now have to compute the final 8-bit values from these int32
        // accumulators, and advance to the next 8x1 block. We intertwine
5791 | // these two aspects whenever possible for optimal pipelining, both |
5792 | // at the data flow level (prefetch data for next block as early as |
5793 | // possible) and instruction pipelining level (some of the next-block |
5794 | // work can dual-issue with some of the final work on the current |
5795 | // block). |
5796 | |
5797 | // Logic to advance to the next block in preparation for the next |
5798 | // iteration of the main loop. For now, we only want to compute |
5799 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
5800 | // not yet ready to update the values of row and col, as we still need |
5801 | // the current values for the rest of the work on the current block. |
5802 | |
5803 | "cmp %w[row], w7\n" // Have we finished the last row? |
5804 | "bge 4f\n" // If finished last row, go to 4 |
5805 | // Not finished last row: then advance to next row. |
5806 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
5807 | "b 5f\n" |
5808 | "4:\n" // Finished last row... |
5809 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
5810 | // Now we need to advance to the next column. If we already |
5811 | // finished the last column, then in principle we are done, however |
5812 | // we can't just return here, as we need to allow the end work of the |
5813 | // current block to complete. The good news is that at this point it |
5814 | // doesn't matter what data we load for the next column, since |
5815 | // we will exit from the main loop below before actually storing |
5816 | // anything computed from that data. |
5817 | "cmp %w[col], w8\n" // Have we finished the last column? |
5818 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
5819 | // Not finished last column: then advance to next column. |
5820 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
5821 | "5:\n" |
5822 | |
5823 | // Set the LHS and RHS data pointers to the start of the columns just |
5824 | // computed. |
5825 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
5826 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
5827 | |
5828 | // Load some parameters needed for the end work on current block. |
5829 | "mvni v8.4s, #0\n" |
5830 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
5831 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
5832 | "ins v13.h[4], w4\n" // dst_zero_point |
5833 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
5834 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
5835 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
5836 | "add x5, x4, %x[row], lsl #2\n" |
5837 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
5838 | "csel x4, x4, x5, eq\n" |
5839 | |
5840 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
5841 | "add x5, x1, %x[row], lsl #2\n" |
5842 | |
5843 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
5844 | "csel x1, x1, x5, eq\n" |
5845 | |
5846 | // Load 8 bias values. |
5847 | "ld1 {v14.4s}, [x1], #16\n" |
5848 | "ld1 {v15.4s}, [x1]\n" |
5849 | |
        // Now that we know what LHS and RHS data the next iteration of the
        // main loop will need to load, we start loading the first 32 bytes of
        // LHS (into v0 -- v1) and the first 8 bytes of RHS (into v2), as we
        // don't need v0 -- v2 anymore in the rest of the work on the current
        // block.
5854 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
5855 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
5856 | "ld1 {v2.8b}, [%[rhs_ptr]]\n" |
5857 | "add %[rhs_ptr], %[rhs_ptr], #32\n" |
5858 | |
5859 | // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), |
5860 | // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
5861 | "add v14.4s, v14.4s, v9.4s\n" |
5862 | "add v15.4s, v15.4s, v9.4s\n" |
5863 | |
5864 | // Perform the bias-addition (per the above, we have just folded into |
5865 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
5866 | "add v16.4s, v16.4s, v14.4s\n" |
5867 | "add v17.4s, v17.4s, v15.4s\n" |
5868 | |
5869 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
5870 | "beq 401f\n" |
5871 | "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
5872 | "add x3, x3, %x[col], lsl #2\n" |
5873 | "ld1 {v14.4s}, [x3], #16\n" |
5874 | "ld1 {v15.4s}, [x3]\n" |
5875 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
5876 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
5877 | // Subtract rhs_sums * lhs_zero_point, per |
5878 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
5879 | "mls v16.4s, v10.4s, v14.s[0]\n" |
5880 | "mls v17.4s, v10.4s, v14.s[0]\n" |
5881 | "401:\n" |
5882 | |
5883 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
5884 | "beq 402f\n" |
5885 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
5886 | "add x2, x2, %x[row], lsl #2\n" |
5887 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
        // Load 8 lhs_sums values.
5889 | "ld1 {v11.4s}, [x2], #16\n" |
5890 | "ld1 {v12.4s}, [x2]\n" |
5891 | "ins v13.s[1], w5\n" // rhs_zero_point |
5892 | // Compute lhs_sums * rhs_zero_point. |
5893 | "mul v11.4s, v11.4s, v13.s[1]\n" |
5894 | "mul v12.4s, v12.4s, v13.s[1]\n" |
5895 | // Subtract lhs_sums * rhs_zero_point, per |
5896 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
5897 | "sub v16.4s, v16.4s, v11.4s\n" |
5898 | "sub v17.4s, v17.4s, v12.4s\n" |
5899 | |
5900 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
5901 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
5902 | |
5903 | "402:\n" |
5904 | |
5905 | // At this point we have computed the final int32 values. Now we |
5906 | // start down-quantizing them to obtain the final 8bit values from them. |
5907 | |
5908 | // As part of this down-quantization, our int32 values will be |
5909 | // multiplied by a multiplier that has a fixed-point component and an |
5910 | // exponent component. |
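        // In scalar terms, what follows is roughly (an illustrative sketch;
        // the names left_shift/right_shift and the helper functions are ours,
        // not ruy functions, and just describe the smin/sub split below):
        //   right_shift = std::min(multiplier_exponent, -1);  // always < 0
        //   left_shift  = multiplier_exponent - right_shift;  // always >= 0
        //   acc = RoundingRightShiftBy(-right_shift,
        //             SaturatingDoublingHighMul(acc << left_shift,
        //                                       multiplier_fixedpoint));
        // where sshl performs the left shift, sqdmulh the saturating
        // doubling high multiply, and srshl (with a negative shift amount)
        // the rounding right shift.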
5911 | |
        // Load the exponent part of the multiplier.
5913 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
5914 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
5915 | "add x5, x1, %x[row], lsl #2\n" |
5916 | "csel x1, x1, x5, eq\n" |
5917 | |
5918 | "ldr q9, [x1]\n" |
5919 | "ldr q10, [x1, #16]\n" |
5920 | |
5921 | "smin v11.4s, v8.4s, v9.4s\n" |
5922 | "smin v12.4s, v8.4s, v10.4s\n" |
5923 | "sub v9.4s, v9.4s, v11.4s\n" |
5924 | "sub v10.4s, v10.4s, v12.4s\n" |
5925 | |
5926 | // Apply the positive exponent part of the multiplier. |
5927 | "sshl v16.4s, v16.4s, v9.4s\n" |
5928 | "sshl v17.4s, v17.4s, v10.4s\n" |
5929 | "403:\n" |
5930 | |
5931 | "ldr q14, [x4]\n" // multiplier_fixedpoint |
5932 | "ldr q15, [x4, #16]\n" // multiplier_fixedpoint |
5933 | |
5934 | // Apply the fixed-point part of the multiplier. |
5935 | "sqdmulh v16.4s, v16.4s, v14.4s\n" |
5936 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
5937 | |
5938 | // Apply the negative exponent part of the multiplier. |
5939 | "srshl v16.4s, v16.4s, v11.4s\n" |
5940 | "srshl v17.4s, v17.4s, v12.4s\n" |
5941 | |
5942 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
5943 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
5944 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
5945 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
5946 | |
5947 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
5948 | |
5949 | // Cast-and-saturate from int32 to int16 |
5950 | "sqxtn v16.4h, v16.4s\n" |
5951 | "sqxtn2 v16.8h, v17.4s\n" |
5952 | // All data in v16 at this point. |
5953 | |
5954 | // Add the destination zero point |
5955 | "dup v14.8h, v13.h[4]\n" |
5956 | "sqadd v16.8h, v16.8h, v14.8h\n" |
5957 | |
5958 | // Cast-and-saturate from int16 to uint8, leaving all data in the |
5959 | // lower half of v16. |
5960 | "sqxtun v16.8b, v16.8h\n" |
5961 | |
5962 | // Load the clamp_min, clamp_max bounds |
5963 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
5964 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
5965 | "dup v14.16b, w2\n" // clamp_min |
5966 | "dup v15.16b, w3\n" // clamp_max |
5967 | |
5968 | // Apply the clamp_min bound |
5969 | "umax v16.16b, v16.16b, v14.16b\n" |
5970 | |
5971 | // Apply the clamp_max bound |
5972 | "umin v16.16b, v16.16b, v15.16b\n" |
5973 | |
5974 | // Make it so that all of the final 8bit values are stored in the |
5975 | // first 64bits of 128bit NEON registers, so they can be stored |
5976 | // by 64bit st1 store instructions with byte alignment. |
5977 | "dup d20, v16.d[1]\n" |
5978 | |
5979 | // Compute how much of the 8x1 block of destination 8bit values that |
5980 | // we have computed, fit in the destination matrix. Typically, all of |
5981 | // it fits, but when the destination matrix shape is not a multiple |
5982 | // of 8x1, there are some 8x1 blocks along the boundaries that do |
5983 | // not fit entirely. |
5984 | "sub w1, %w[dst_rows], %w[row]\n" |
5985 | "sub w2, %w[dst_cols], %w[col]\n" |
5986 | "mov w3, #8\n" |
5987 | "cmp w1, #8\n" |
5988 | // Compute w1 = how many rows of the 8x1 block fit |
5989 | "csel w1, w1, w3, le\n" |
5990 | "cmp w2, #8\n" |
5991 | |
5992 | // Test if w1==8, i.e. if all of the 8x1 block fits. |
5993 | "cmp w1, w3\n" |
5994 | // Yes, all of the 8x1 block fits, go to fast path. |
5995 | "beq 30f\n" |
5996 | // Not all of the 8x1 block fits. |
5997 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
5998 | "mov x3, %[dst_tmp_buf]\n" |
5999 | "mov x4, #8\n" |
6000 | "b 31f\n" |
6001 | "30:\n" |
6002 | // Yes, all of the 8x1 block fits. |
6003 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
6004 | "mov x3, %[dst_ptr]\n" |
6005 | "mov x4, x11\n" |
6006 | "31:\n" |
6007 | |
6008 | // Write our 8bit values to the destination |
6009 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6010 | "st1 {v16.8b}, [x3]\n" |
6011 | RUY_MAKE_ZERO(v16) |
6012 | RUY_MAKE_ZERO(v17) |
6013 | |
6014 | // For the next block: perform the first few multiply-adds on the data |
6015 | // that we have already loaded. |
6016 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6017 | |
        // If all of the 8x1 block fits, we just finished writing it to the
6019 | // destination, so we skip the next part. |
6020 | "beq 41f\n" |
        // Not all of the 8x1 block fits in the destination matrix. We just
6022 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
6023 | // it to copy into the destination matrix the part that fits. |
6024 | "mov x3, %[dst_tmp_buf]\n" |
6025 | "mov x4, %[dst_ptr]\n" |
6026 | "mov w6, #0\n" |
6027 | "50:\n" |
6028 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6029 | "mov w5, #0\n" |
6030 | "51:\n" |
6031 | "ldrb w7, [x3, w5, uxtw]\n" |
6032 | "strb w7, [x4, w5, uxtw]\n" |
6033 | "add w5, w5, #1\n" |
6034 | "cmp w5, w1\n" |
6035 | "blt 51b\n" |
6036 | "41:\n" |
6037 | "add %[dst_ptr], %[dst_ptr], #8\n" |
6038 | // At this point we have completely finished writing values to the |
6039 | // destination matrix for the current block. |
6040 | |
6041 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
6042 | |
6043 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
6044 | |
6045 | // Cast-and-saturate from int32 to int16 |
6046 | "sqxtn v16.4h, v16.4s\n" |
6047 | "sqxtn2 v16.8h, v17.4s\n" |
6048 | |
6049 | |
6050 | // Add the destination zero point |
6051 | "dup v14.8h, v13.h[4]\n" |
6052 | "sqadd v16.8h, v16.8h, v14.8h\n" |
6053 | |
        // Cast-and-saturate from int16 to int8
6055 | "sqxtn v16.8b, v16.8h\n" |
6056 | |
6057 | // Load the clamp_min, clamp_max bounds |
6058 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
6059 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
6060 | "dup v14.16b, w2\n" // clamp_min |
6061 | "dup v15.16b, w3\n" // clamp_max |
6062 | |
6063 | // Apply the clamp_min bound |
6064 | "smax v16.16b, v16.16b, v14.16b\n" |
6065 | |
6066 | // Apply the clamp_max bound |
6067 | "smin v16.16b, v16.16b, v15.16b\n" |
6068 | |
6069 | // Make it so that all of the final 8bit values are stored in the |
6070 | // first 64bits of 128bit NEON registers, so they can be stored |
6071 | // by 64bit st1 store instructions with byte alignment. |
6072 | "dup d20, v16.d[1]\n" |
6073 | |
6074 | // Compute how much of the 8x1 block of destination 8bit values that |
6075 | // we have computed, fit in the destination matrix. Typically, all of |
6076 | // it fits, but when the destination matrix shape is not a multiple |
        // of 8x1, there are some 8x1 blocks along the boundaries that do
6078 | // not fit entirely. |
6079 | "sub w1, %w[dst_rows], %w[row]\n" |
6080 | "sub w2, %w[dst_cols], %w[col]\n" |
6081 | "mov w3, #8\n" |
6082 | "cmp w1, #8\n" |
6083 | // Compute w1 = how many rows of the 8x1 block fit |
6084 | "csel w1, w1, w3, le\n" |
6085 | "cmp w2, #8\n" |
6086 | |
6087 | // Test if w1==8, i.e. if all of the 8x1 block fits. |
6088 | "cmp w1, w3\n" |
6089 | // Yes, all of the 8x1 block fits, go to fast path. |
6090 | "beq 130f\n" |
6091 | // Not all of the 8x1 block fits. |
6092 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
6093 | "mov x3, %[dst_tmp_buf]\n" |
6094 | "mov x4, #8\n" |
6095 | "b 131f\n" |
6096 | "130:\n" |
        // Yes, all of the 8x1 block fits.
6098 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
6099 | "mov x3, %[dst_ptr]\n" |
6100 | "mov x4, x11\n" |
6101 | "131:\n" |
6102 | |
6103 | // Write our 8bit values to the destination |
6104 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6105 | "st1 {v16.8b}, [x3]\n" |
6106 | RUY_MAKE_ZERO(v16) |
6107 | RUY_MAKE_ZERO(v17) |
6108 | |
6109 | // For the next block: perform the first few multiply-adds on the data |
6110 | // that we have already loaded. |
6111 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6112 | |
        // If all of the 8x1 block fits, we just finished writing it to the
6114 | // destination, so we skip the next part. |
6115 | "beq 141f\n" |
        // Not all of the 8x1 block fits in the destination matrix. We just
6117 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
6118 | // it to copy into the destination matrix the part that fits. |
6119 | "mov x3, %[dst_tmp_buf]\n" |
6120 | "mov x4, %[dst_ptr]\n" |
6121 | "mov w6, #0\n" |
6122 | "150:\n" |
6123 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6124 | "mov w5, #0\n" |
6125 | "151:\n" |
6126 | "ldrb w7, [x3, w5, uxtw]\n" |
6127 | "strb w7, [x4, w5, uxtw]\n" |
6128 | "add w5, w5, #1\n" |
6129 | "cmp w5, w1\n" |
6130 | "blt 151b\n" |
6131 | "141:\n" |
6132 | "add %[dst_ptr], %[dst_ptr], #8\n" |
6133 | // At this point we have completely finished writing values to the |
6134 | // destination matrix for the current block. |
6135 | |
6136 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
6137 | |
6138 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
6139 | |
6140 | // Add the destination zero point |
6141 | "dup v14.8h, v13.h[4]\n" |
6142 | "saddw v16.4s, v16.4s, v14.4h\n" |
6143 | "saddw v17.4s, v17.4s, v14.4h\n" |
6144 | |
6145 | // Cast-and-saturate from int32 to int16 |
6146 | "sqxtn v16.4h, v16.4s\n" |
6147 | "sqxtn2 v16.8h, v17.4s\n" |
6148 | |
6149 | // Load the clamp_min, clamp_max bounds |
6150 | "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
6151 | "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
6152 | "dup v14.8h, w2\n" // clamp_min |
6153 | "dup v15.8h, w3\n" // clamp_max |
6154 | |
6155 | // Apply the clamp_min bound |
6156 | "smax v16.8h, v16.8h, v14.8h\n" |
6157 | // Apply the clamp_max bound |
6158 | "smin v16.8h, v16.8h, v15.8h\n" |
6159 | |
6160 | // Compute how much of the 8x1 block of destination 16bit values that |
6161 | // we have computed, fit in the destination matrix. Typically, all of |
6162 | // it fits, but when the destination matrix shape is not a multiple |
        // of 8x1, there are some 8x1 blocks along the boundaries that do
6164 | // not fit entirely. |
6165 | "sub w1, %w[dst_rows], %w[row]\n" |
6166 | "sub w2, %w[dst_cols], %w[col]\n" |
6167 | "mov w3, #8\n" |
6168 | "cmp w1, #8\n" |
6169 | // Compute w1 = how many rows of the 8x1 block fit |
6170 | "csel w1, w1, w3, le\n" |
6171 | "cmp w2, #8\n" |
6172 | |
        // Test if w1==8, i.e. if all of the 8x1 block fits.
6174 | "cmp w1, w3\n" |
6175 | // Yes, all of the 8x1 block fits, go to fast path. |
6176 | "beq 230f\n" |
6177 | // Not all of the 8x1 block fits. |
6178 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
6179 | "mov x3, %[dst_tmp_buf]\n" |
6180 | "mov x4, #16\n" |
6181 | "b 231f\n" |
6182 | "230:\n" |
6183 | // Yes, all of the 8x1 block fits. |
6184 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
6185 | "mov x3, %[dst_ptr]\n" |
6186 | "mov x4, x11\n" |
6187 | "231:\n" |
6188 | |
6189 | // Write our 16bit values to the destination |
6190 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6191 | "st1 {v16.8h}, [x3]\n" |
6192 | RUY_MAKE_ZERO(v16) |
6193 | RUY_MAKE_ZERO(v17) |
6194 | |
6195 | // For the next block: perform the first few multiply-adds on the data |
6196 | // that we have already loaded. |
6197 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6198 | |
6199 | // If all of the 8x1 block fits, we just finished writing it to the |
6200 | // destination, so we skip the next part. |
6201 | "beq 241f\n" |
6202 | // Not all of the 8x1 block fits in the destination matrix. We just |
6203 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
6204 | // it to copy into the destination matrix the part that fits. |
6205 | "mov x3, %[dst_tmp_buf]\n" |
6206 | "mov x4, %[dst_ptr]\n" |
6207 | "mov w6, #0\n" |
6208 | "250:\n" |
6209 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
6210 | "mov w5, #0\n" |
6211 | "251:\n" |
6212 | "ldrsh w7, [x3, x5, lsl #1]\n" |
6213 | "strh w7, [x4, x5, lsl #1]\n" |
6214 | "add w5, w5, #1\n" |
6215 | "cmp w5, w1\n" |
6216 | "blt 251b\n" |
6217 | "241:\n" |
6218 | "add %[dst_ptr], %[dst_ptr], #16\n" |
6219 | // At this point we have completely finished writing values to the |
6220 | // destination matrix for the current block. |
6221 | |
6222 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
6223 | |
6224 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
6225 | |
6226 | // Since the store type is the same as the accum type, no need for |
6227 | // downcast. There's also no need for clamp by min/max. |
6228 | |
        // Compute how much of the 8x1 block of destination 32bit values that
6230 | // we have computed, fit in the destination matrix. Typically, all of |
6231 | // it fits, but when the destination matrix shape is not a multiple |
6232 | // of 8x1, there are some 8x1 blocks along the boundaries that do |
6233 | // not fit entirely. |
6234 | "sub w1, %w[dst_rows], %w[row]\n" |
6235 | "sub w2, %w[dst_cols], %w[col]\n" |
6236 | "mov w3, #8\n" |
6237 | "cmp w1, #8\n" |
6238 | // Compute w1 = how many rows of the 8x1 block fit |
6239 | "csel w1, w1, w3, le\n" |
6240 | "cmp w2, #8\n" |
        // Compute w2 = how many cols of the 8x1 block fit
6242 | "csel w2, w2, w3, le\n" |
6243 | |
        // Test if w1==8, i.e. if all of the 8x1 block fits.
6245 | "cmp w1, w3\n" |
6246 | // Yes, all of the 8x1 block fits, go to fast path. |
6247 | "beq 330f\n" |
6248 | // Not all of the 8x1 block fits. |
6249 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
6250 | "mov x3, %[dst_tmp_buf]\n" |
6251 | "mov x4, #16\n" |
6252 | |
6253 | // Write our 32bit values to the destination described by |
6254 | // (x3 address, x4 stride). |
6255 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
6256 | "st1 {v16.4s}, [x3], x4\n" |
6257 | RUY_MAKE_ZERO(v16) |
6258 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
6259 | "st1 {v17.4s}, [x3], x4\n" |
6260 | RUY_MAKE_ZERO(v17) |
6261 | |
6262 | "b 331f\n" |
6263 | |
6264 | "330:\n" |
6265 | // Yes, all of the 8x1 block fits. |
6266 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
6267 | "mov x4, %[dst_ptr]\n" |
6268 | "mov x3, x4\n" |
6269 | |
6270 | // Write our 32bit values to the destination described by |
6271 | // (x3 address, x4 stride). |
6272 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
6273 | "st1 {v16.4s, v17.4s}, [x3], #32\n" |
6274 | RUY_MAKE_ZERO(v16) |
6275 | RUY_MAKE_ZERO(v17) |
6276 | |
6277 | "331:\n" |
6278 | |
6279 | // For the next block: perform the first few multiply-adds on the data |
6280 | // that we have already loaded. |
6281 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6282 | |
        // If all of the 8x1 block fits, we just finished writing it to the
6284 | // destination, so we skip the next part. |
6285 | "beq 341f\n" |
6286 | |
        // Not all of the 8x1 block fits in the destination matrix. We just
6288 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
6289 | // it to copy into the destination matrix the part that fits. |
6290 | "mov x3, %[dst_tmp_buf]\n" |
6291 | "mov x4, %[dst_ptr]\n" |
6292 | "mov w6, #0\n" |
6293 | "350:\n" |
6294 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
6295 | "mov w5, #0\n" |
6296 | "351:\n" |
6297 | "ldr w7, [x3, x5, lsl #2]\n" |
6298 | "str w7, [x4, x5, lsl #2]\n" |
6299 | "add w5, w5, #1\n" |
6300 | "cmp w5, w1\n" |
6301 | "blt 351b\n" |
6302 | "341:\n" |
6303 | "add %[dst_ptr], %[dst_ptr], #32\n" |
6304 | // At this point we have completely finished writing values to the |
6305 | // destination matrix for the current block. |
6306 | |
6307 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
6308 | |
6309 | // Reload some params --- we had used x5 -- x7 for a few other things |
6310 | // since the last time we had loaded them. |
6311 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
6312 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
6313 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
6314 | |
6315 | // Move to the next block of the destination matrix, for the next iter |
6316 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
6317 | // been updated earlier. |
6318 | // Have we reached the end row? |
6319 | "cmp %w[row], w7\n" |
6320 | "beq 20f\n" // yes, end row. |
6321 | // Not end row. Move to the next row. |
6322 | "add %w[row], %w[row], #8\n" |
6323 | "b 21f\n" |
6324 | "20:\n" |
6325 | // Was already at end row. |
6326 | "mov %w[row], w6\n" // Move back to first row. |
6327 | "add %w[col], %w[col], #8\n" // Move to the next column. |
6328 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
6329 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
6330 | "21:\n" |
6331 | |
6332 | // Main loop exit condition: have we hit the end column? |
6333 | "cmp %w[col], w8\n" |
6334 | |
6335 | // w1 is the number of levels of depth that we have already loaded |
6336 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
6337 | // above, this is currently 4. |
6338 | "mov w1, #4\n" |
6339 | |
6340 | "ble 1b\n" |
6341 | |
6342 | // clang-format on |
6343 | |
6344 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
6345 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
6346 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
6347 | : [ params ] "r" (¶ms), [dst_rows] "r" (params.dst_rows), |
6348 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf), |
6349 | [dst_type_id] "r" (params.dst_type_id) |
6350 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
6351 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
6352 | "v13" , "v14" , "v15" , "v16" , "v17" ); |
6353 | } |
6354 | |
6355 | // Variant of the above Kernel8bitNeonDotprod, tuned for in-order |
6356 | // CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, |
6357 | // since these are 64-bit and support dotprod. |
6358 | // |
6359 | // While this kernel does not have a direct equivalent in gemmlowp, it was |
// developed based on insights that David Mansell at ARM shared in their
// contribution of gemmlowp kernels tuned for Cortex-A55r1, which contains
// very helpful comments. Specifically, see this comment about tuning for
// Cortex-A55r1:
6363 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 |
6364 | void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) { |
6365 | profiler::ScopeLabel label( |
6366 | "Kernel (kNeonDotprod, optimized for in-order cores)" ); |
6367 | |
6368 | CheckOffsetsInKernelParams8bit(params); |
6369 | |
6370 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
  const std::int8_t* rhs_col_ptr =
      static_cast<const std::int8_t*>(params.rhs_base_ptr);
6373 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
6374 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
6375 | void* dst_col_ptr = params.dst_base_ptr; |
6376 | void* dst_ptr = dst_col_ptr; |
6377 | int row = params.start_row; |
6378 | int col = params.start_col; |
6379 | |
6380 | // The asm kernel below has the following NEON register allocation: |
6381 | // |
6382 | // v16 -- v31 are int32 accumulators. |
6383 | // During accumulation, v0 -- v3 are used to load int8 data from LHS and |
6384 | // RHS. |
6385 | // |
6386 | // int8 RHS 4x8 block |
6387 | // /-----------------------------------------| |
6388 | // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| |
6389 | // | ... ... | |
6390 | // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| |
6391 | // \-----------------------------------------/ |
6392 | // int8 LHS 8x4 block |
6393 | // /---------------------\ /-----------------------------------------| |
6394 | // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| |
6395 | // | ... ... | | ... ... | |
6396 | // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| |
6397 | // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| |
6398 | // | ... ... | | ... ... | |
6399 | // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| |
6400 | // \---------------------/ \-----------------------------------------/ |
6401 | // int32 accumulators 8x8 block |
6402 | // |
6403 | // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because |
6404 | // we did not observe a benefit of such partial unrolling on in-order CPUs. |
6405 | // |
  // v4 -- v7 are unused during accumulation, and v8 -- v15 are used for
  // loading parameters for the post-accumulation part of the kernel.
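  //
  // A note on the load pattern used in the inner loop below: rather than a
  // single 16-byte "ld1 {v0.16b}", each 128-bit register is assembled from
  // two 64-bit loads that can be interleaved with the sdot instructions,
  // e.g.
  //
  //   "ldr d0, [%[lhs_ptr], #0]"   // low 64 bits of v0
  //   "ldr x1, [%[lhs_ptr], #8]"   // high 64 bits, via a general register
  //   "ins v0.d[1], x1"            // merge into v0
  //
  // which, per the gemmlowp notes linked above, issues more smoothly on
  // these in-order cores.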
6408 | asm volatile( |
6409 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
6410 | |
6411 | // clang-format off |
6412 | |
6413 | // Load some parameters into registers. |
6414 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
6415 | RUY_MAKE_ZERO(v16) |
6416 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
6417 | RUY_MAKE_ZERO(v17) |
6418 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
6419 | RUY_MAKE_ZERO(v18) |
6420 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
6421 | RUY_MAKE_ZERO(v19) |
6422 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
6423 | RUY_MAKE_ZERO(v20) |
6424 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
6425 | RUY_MAKE_ZERO(v21) |
6426 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
6427 | RUY_MAKE_ZERO(v22) |
6428 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
6429 | |
6430 | // Load the first 32 bytes of LHS and RHS data. |
6431 | "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" |
6432 | "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" |
6433 | "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" |
6434 | "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" |
6435 | |
6436 | // Clear accumulators. |
6437 | RUY_MAKE_ZERO(v23) |
6438 | RUY_MAKE_ZERO(v24) |
6439 | RUY_MAKE_ZERO(v25) |
6440 | RUY_MAKE_ZERO(v26) |
6441 | RUY_MAKE_ZERO(v27) |
6442 | // Perform the first few multiply-adds on the data that we have already |
6443 | // loaded. |
6444 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6445 | RUY_MAKE_ZERO(v28) |
6446 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
6447 | RUY_MAKE_ZERO(v29) |
6448 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
6449 | RUY_MAKE_ZERO(v30) |
6450 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
6451 | RUY_MAKE_ZERO(v31) |
6452 | |
6453 | |
6454 | "1:\n" |
6455 | |
6456 | "add x5, %[lhs_ptr], x12, lsl #3\n" |
6457 | "sub x5, x5, #32\n" |
6458 | "cmp %[lhs_ptr], x5\n" |
6459 | |
6460 | "beq 79f\n" |
6461 | |
6462 | // Main accumulation loop |
6463 | "2:\n" |
6464 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
6465 | "ldr x1, [%[lhs_ptr], #8]\n" |
6466 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
6467 | "ldr x3, [%[rhs_ptr], #8]\n" |
6468 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
6469 | "ldr x4, [%[rhs_ptr], #24]\n" |
6470 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
6471 | "ldr d0, [%[lhs_ptr], #0]\n" |
6472 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
6473 | "ins v0.d[1], x1\n" |
6474 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
6475 | "ldr x2, [%[lhs_ptr], #24]\n" |
6476 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
6477 | "add %[lhs_ptr], %[lhs_ptr], #32\n" |
6478 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
6479 | "ldr d2, [%[rhs_ptr], #0]\n" |
6480 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
6481 | "ins v2.d[1], x3\n" |
6482 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
6483 | "cmp %[lhs_ptr], x5\n" |
6484 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
6485 | "add %[rhs_ptr], %[rhs_ptr], #32\n" |
6486 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
6487 | "ldr d3, [%[rhs_ptr], #-16]\n" |
6488 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
6489 | "ldr d1, [%[lhs_ptr], #-16]\n" |
6490 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
6491 | "ins v3.d[1], x4\n" |
6492 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
6493 | "ins v1.d[1], x2\n" |
6494 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
6495 | "blt 2b\n" |
6496 | |
6497 | // Last accumulation steps, nothing left to load. |
6498 | "79:\n" |
6499 | ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" |
6500 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
6501 | ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" |
6502 | "cmp %w[row], w7\n" // Have we finished the last row? |
6503 | ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" |
6504 | ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" |
6505 | ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" |
6506 | ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" |
6507 | ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" |
6508 | ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" |
6509 | ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" |
6510 | ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" |
6511 | ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" |
6512 | ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" |
6513 | |
6514 | // End of accumulation. The registers v16 -- v31 contain the final |
6515 | // int32 accumulator values of the current 8x8 destination block. |
6516 | // We now have to compute the final 8-bit values from these int32 |
6517 | // accumulators, and advance to the next 8x8 block. We intertwine |
6518 | // these two aspects whenever possible for optimal pipelining, both |
6519 | // at the data flow level (prefetch data for next block as early as |
6520 | // possible) and instruction pipelining level (some of the next-block |
6521 | // work can dual-issue with some of the final work on the current |
6522 | // block). |
6523 | |
6524 | // Logic to advance to the next block in preparation for the next |
6525 | // iteration of the main loop. For now, we only want to compute |
6526 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
6527 | // not yet ready to update the values of row and col, as we still need |
6528 | // the current values for the rest of the work on the current block. |
6529 | |
6530 | "bge 4f\n" // If finished last row, go to 4 |
6531 | // Not finished last row: then advance to next row. |
6532 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
6533 | "b 5f\n" |
6534 | "4:\n" // Finished last row... |
6535 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
6536 | // Now we need to advance to the next column. If we already |
6537 | // finished the last column, then in principle we are done, however |
6538 | // we can't just return here, as we need to allow the end work of the |
6539 | // current block to complete. The good news is that at this point it |
6540 | // doesn't matter what data we load for the next column, since |
6541 | // we will exit from the main loop below before actually storing |
6542 | // anything computed from that data. |
6543 | "cmp %w[col], w8\n" // Have we finished the last column? |
6544 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
6545 | // Not finished last column: then advance to next column. |
6546 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
6547 | "5:\n" |
6548 | |
6549 | // Set the LHS and RHS data pointers to the start of the columns just |
6550 | // computed. |
6551 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
6552 | // Load some parameters needed for the end work on current block. |
6553 | "mvni v8.4s, #0\n" |
6554 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
6555 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
6556 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
6557 | "dup v9.4s, w3\n" // create prod_zp_depth_vec |
6558 | |
6559 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
6560 | // Determine the channel index. |
6561 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
6562 | "csel w3, %w[row], %w[col], eq\n" |
6563 | |
6564 | // Offset the bias pointer as needed given the current row, col. |
6565 | "add x5, x1, x3, lsl #2\n" |
6566 | |
6567 | // If there is no bias, use no offset, just address the passed zero |
6568 | // data. |
6569 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
6570 | "csel x1, x1, x5, eq\n" |
6571 | |
6572 | // Load 8 bias values. |
6573 | "ld1 {v14.2s}, [x1], #8\n" |
6574 | "ldr x5, [x1], #8\n" |
6575 | "ins v14.d[1], x5\n" |
6576 | "ld1 {v15.2s}, [x1], #8\n" |
6577 | "ldr x5, [x1], #8\n" |
6578 | "ins v15.d[1], x5\n" |
6579 | |
6580 | // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), |
6581 | // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
6582 | "add v14.4s, v14.4s, v9.4s\n" |
6583 | "add v15.4s, v15.4s, v9.4s\n" |
6584 | // Perform the bias-addition (per the above, we have just folded into |
6585 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
6586 | // Jump based on channel dimension. |
6587 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
6588 | "bne 6f\n" |
6589 | // Case where channels are rows |
6590 | "add v16.4s, v16.4s, v14.4s\n" |
6591 | "add v17.4s, v17.4s, v15.4s\n" |
6592 | "add v18.4s, v18.4s, v14.4s\n" |
6593 | "add v19.4s, v19.4s, v15.4s\n" |
6594 | "add v20.4s, v20.4s, v14.4s\n" |
6595 | "add v21.4s, v21.4s, v15.4s\n" |
6596 | "add v22.4s, v22.4s, v14.4s\n" |
6597 | "add v23.4s, v23.4s, v15.4s\n" |
6598 | "add v24.4s, v24.4s, v14.4s\n" |
6599 | "add v25.4s, v25.4s, v15.4s\n" |
6600 | "add v26.4s, v26.4s, v14.4s\n" |
6601 | "add v27.4s, v27.4s, v15.4s\n" |
6602 | "add v28.4s, v28.4s, v14.4s\n" |
6603 | "add v29.4s, v29.4s, v15.4s\n" |
6604 | "add v30.4s, v30.4s, v14.4s\n" |
6605 | "add v31.4s, v31.4s, v15.4s\n" |
6606 | "b 7f\n" |
6607 | |
6608 | "6:\n" |
6609 | // Case where channels are columns |
6610 | "dup v10.4s, v14.s[0]\n" |
6611 | "dup v11.4s, v14.s[1]\n" |
6612 | "add v16.4s, v16.4s, v10.4s\n" |
6613 | "dup v12.4s, v14.s[2]\n" |
6614 | "add v17.4s, v17.4s, v10.4s\n" |
6615 | "dup v13.4s, v14.s[3]\n" |
6616 | "add v18.4s, v18.4s, v11.4s\n" |
6617 | "dup v10.4s, v15.s[0]\n" |
6618 | "add v19.4s, v19.4s, v11.4s\n" |
6619 | "dup v11.4s, v15.s[1]\n" |
6620 | "add v20.4s, v20.4s, v12.4s\n" |
6621 | "add v21.4s, v21.4s, v12.4s\n" |
6622 | "dup v12.4s, v15.s[2]\n" |
6623 | "add v22.4s, v22.4s, v13.4s\n" |
6624 | "add v23.4s, v23.4s, v13.4s\n" |
6625 | "dup v13.4s, v15.s[3]\n" |
6626 | "add v24.4s, v24.4s, v10.4s\n" |
6627 | "add v25.4s, v25.4s, v10.4s\n" |
6628 | "add v26.4s, v26.4s, v11.4s\n" |
6629 | "add v27.4s, v27.4s, v11.4s\n" |
6630 | "add v28.4s, v28.4s, v12.4s\n" |
6631 | "add v29.4s, v29.4s, v12.4s\n" |
6632 | "add v30.4s, v30.4s, v13.4s\n" |
6633 | "add v31.4s, v31.4s, v13.4s\n" |
6634 | "7:\n" |
6635 | |
6636 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
6637 | "beq 401f\n" |
6638 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
6639 | "dup v10.4s, w5\n" // create lhs_zero_point_vec |
6640 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
6641 | "add x5, x5, %x[col], lsl #2\n" |
6642 | // Load 8 rhs_sums values. |
6643 | "ld1 {v14.2s}, [x5], #8\n" |
6644 | "ldr x7, [x5], #8\n" |
6645 | "ld1 {v15.2s}, [x5], #8\n" |
6646 | "ins v14.d[1], x7\n" |
6647 | "ldr x7, [x5], #8\n" |
6648 | "ins v15.d[1], x7\n" |
6649 | // Subtract rhs_sums * lhs_zero_point, per |
6650 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
6651 | "mls v16.4s, v10.4s, v14.s[0]\n" |
6652 | "mls v17.4s, v10.4s, v14.s[0]\n" |
6653 | "mls v18.4s, v10.4s, v14.s[1]\n" |
6654 | "mls v19.4s, v10.4s, v14.s[1]\n" |
6655 | "mls v20.4s, v10.4s, v14.s[2]\n" |
6656 | "mls v21.4s, v10.4s, v14.s[2]\n" |
6657 | "mls v22.4s, v10.4s, v14.s[3]\n" |
6658 | "mls v23.4s, v10.4s, v14.s[3]\n" |
6659 | "mls v24.4s, v10.4s, v15.s[0]\n" |
6660 | "mls v25.4s, v10.4s, v15.s[0]\n" |
6661 | "mls v26.4s, v10.4s, v15.s[1]\n" |
6662 | "mls v27.4s, v10.4s, v15.s[1]\n" |
6663 | "mls v28.4s, v10.4s, v15.s[2]\n" |
6664 | "mls v29.4s, v10.4s, v15.s[2]\n" |
6665 | "mls v30.4s, v10.4s, v15.s[3]\n" |
6666 | "mls v31.4s, v10.4s, v15.s[3]\n" |
6667 | "401:\n" |
6668 | |
6669 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
6670 | "beq 402f\n" |
6671 | "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
6672 | "add x2, x2, %x[row], lsl #2\n" |
6673 | "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
6674 | "ins v13.s[1], w5\n" // rhs_zero_point |
6675 | // Load 8 lhs_sums values. |
6676 | "ld1 {v11.2s}, [x2], #8\n" |
6677 | "ldr x4, [x2], #8\n" |
6678 | "ins v11.d[1], x4\n" |
6679 | "ld1 {v12.2s}, [x2], #8\n" |
6680 | "ldr x4, [x2], #8\n" |
6681 | "ins v12.d[1], x4\n" |
6682 | // Compute lhs_sums * rhs_zero_point. |
6683 | "mul v11.4s, v11.4s, v13.s[1]\n" |
6684 | "mul v12.4s, v12.4s, v13.s[1]\n" |
6685 | // Subtract lhs_sums * rhs_zero_point, per |
6686 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
6687 | "sub v16.4s, v16.4s, v11.4s\n" |
6688 | "sub v17.4s, v17.4s, v12.4s\n" |
6689 | "sub v18.4s, v18.4s, v11.4s\n" |
6690 | "sub v19.4s, v19.4s, v12.4s\n" |
6691 | "sub v20.4s, v20.4s, v11.4s\n" |
6692 | "sub v21.4s, v21.4s, v12.4s\n" |
6693 | "sub v22.4s, v22.4s, v11.4s\n" |
6694 | "sub v23.4s, v23.4s, v12.4s\n" |
6695 | "sub v24.4s, v24.4s, v11.4s\n" |
6696 | "sub v25.4s, v25.4s, v12.4s\n" |
6697 | "sub v26.4s, v26.4s, v11.4s\n" |
6698 | "sub v27.4s, v27.4s, v12.4s\n" |
6699 | "sub v28.4s, v28.4s, v11.4s\n" |
6700 | "sub v29.4s, v29.4s, v12.4s\n" |
6701 | "sub v30.4s, v30.4s, v11.4s\n" |
6702 | "sub v31.4s, v31.4s, v12.4s\n" |
6703 | |
6704 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
6705 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
6706 | |
6707 | "402:\n" |
6708 | |
6709 | // At this point we have computed the final int32 values. Now we |
6710 | // start down-quantizing them to obtain the final 8bit values from them. |
6711 | |
6712 | // As part of this down-quantization, our int32 values will be |
6713 | // multiplied by a multiplier that has a fixed-point component and an |
6714 | // exponent component. |
6715 | |
        // Load the exponent part of the multiplier.
6717 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
6718 | // Compute the multiplier_exponent pointer |
6719 | "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
6720 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
6721 | "add x5, x1, x3, lsl #2\n" |
6722 | "csel x1, x1, x5, eq\n" |
6723 | // Load multiplier_exponent |
6724 | "ldr q9, [x1]\n" |
6725 | "ldr q10, [x1, #16]\n" |
6726 | // Separate positive and negative exponents |
6727 | "smin v11.4s, v8.4s, v9.4s\n" |
6728 | "smin v12.4s, v8.4s, v10.4s\n" |
6729 | "sub v9.4s, v9.4s, v11.4s\n" |
6730 | "sub v10.4s, v10.4s, v12.4s\n" |
6731 | |
6732 | // Compute the multiplier_fixedpoint pointer |
6733 | "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
6734 | "add x5, x4, x3, lsl #2\n" |
6735 | "csel x4, x4, x5, eq\n" |
6736 | // Load multiplier_fixedpoint |
6737 | "ldr q14, [x4]\n" |
6738 | "ldr q15, [x4, #16]\n" |
6739 | |
6740 | // Jump based on channel dimension. |
6741 | "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
6742 | "bne 8f\n" |
6743 | // Case where channels are rows |
6744 | |
6745 | // Apply the positive exponent part of the multiplier. |
6746 | "sshl v16.4s, v16.4s, v9.4s\n" |
6747 | "sshl v17.4s, v17.4s, v10.4s\n" |
6748 | "sshl v18.4s, v18.4s, v9.4s\n" |
6749 | "sshl v19.4s, v19.4s, v10.4s\n" |
6750 | "sshl v20.4s, v20.4s, v9.4s\n" |
6751 | "sshl v21.4s, v21.4s, v10.4s\n" |
6752 | "sshl v22.4s, v22.4s, v9.4s\n" |
6753 | "sshl v23.4s, v23.4s, v10.4s\n" |
6754 | "sshl v24.4s, v24.4s, v9.4s\n" |
6755 | "sshl v25.4s, v25.4s, v10.4s\n" |
6756 | "sshl v26.4s, v26.4s, v9.4s\n" |
6757 | "sshl v27.4s, v27.4s, v10.4s\n" |
6758 | "sshl v28.4s, v28.4s, v9.4s\n" |
6759 | "sshl v29.4s, v29.4s, v10.4s\n" |
6760 | "sshl v30.4s, v30.4s, v9.4s\n" |
6761 | "sshl v31.4s, v31.4s, v10.4s\n" |
6762 | "10:\n" |
6763 | |
6764 | // Apply the fixed-point part of the multiplier. |
6765 | // |
6766 | // ... and, interleaved into that: |
6767 | // Now that we know what LHS and RHS data the next iteration of the |
6768 | // main loop will need to load, we start loading the first 32 bytes of |
6769 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
6770 | // in the rest of the work on the current block. |
6771 | "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" |
6772 | "sqdmulh v16.4s, v16.4s, v14.4s\n" |
6773 | "ldr x1, [%[lhs_ptr]], #8\n" |
6774 | "sqdmulh v17.4s, v17.4s, v15.4s\n" |
6775 | "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" |
6776 | "sqdmulh v18.4s, v18.4s, v14.4s\n" |
6777 | "ldr x2, [%[lhs_ptr]], #8\n" |
6778 | "sqdmulh v19.4s, v19.4s, v15.4s\n" |
6779 | "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" |
6780 | "sqdmulh v20.4s, v20.4s, v14.4s\n" |
6781 | "ldr x5, [%[rhs_ptr]], #8\n" |
6782 | "sqdmulh v21.4s, v21.4s, v15.4s\n" |
6783 | "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" |
6784 | "sqdmulh v22.4s, v22.4s, v14.4s\n" |
6785 | "ldr x6, [%[rhs_ptr]], #8\n" |
6786 | "sqdmulh v23.4s, v23.4s, v15.4s\n" |
6787 | "sqdmulh v24.4s, v24.4s, v14.4s\n" |
6788 | "sqdmulh v25.4s, v25.4s, v15.4s\n" |
6789 | "sqdmulh v26.4s, v26.4s, v14.4s\n" |
6790 | "sqdmulh v27.4s, v27.4s, v15.4s\n" |
6791 | "sqdmulh v28.4s, v28.4s, v14.4s\n" |
6792 | "sqdmulh v29.4s, v29.4s, v15.4s\n" |
6793 | "sqdmulh v30.4s, v30.4s, v14.4s\n" |
6794 | "sqdmulh v31.4s, v31.4s, v15.4s\n" |
6795 | |
6796 | // Apply the negative exponent part of the multiplier. |
6797 | "srshl v16.4s, v16.4s, v11.4s\n" |
6798 | "srshl v17.4s, v17.4s, v12.4s\n" |
6799 | "srshl v18.4s, v18.4s, v11.4s\n" |
6800 | "srshl v19.4s, v19.4s, v12.4s\n" |
6801 | "srshl v20.4s, v20.4s, v11.4s\n" |
6802 | "srshl v21.4s, v21.4s, v12.4s\n" |
6803 | "srshl v22.4s, v22.4s, v11.4s\n" |
6804 | "srshl v23.4s, v23.4s, v12.4s\n" |
6805 | "srshl v24.4s, v24.4s, v11.4s\n" |
6806 | "srshl v25.4s, v25.4s, v12.4s\n" |
6807 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
6808 | "srshl v26.4s, v26.4s, v11.4s\n" |
6809 | "ins v13.h[4], w4\n" // dst_zero_point |
6810 | "srshl v27.4s, v27.4s, v12.4s\n" |
6811 | "ins v0.d[1], x1\n" |
6812 | "srshl v28.4s, v28.4s, v11.4s\n" |
6813 | "ins v1.d[1], x2\n" |
6814 | "srshl v29.4s, v29.4s, v12.4s\n" |
6815 | "ins v2.d[1], x5\n" |
6816 | "srshl v30.4s, v30.4s, v11.4s\n" |
6817 | "ins v3.d[1], x6\n" |
6818 | "srshl v31.4s, v31.4s, v12.4s\n" |
6819 | "b 9f\n" |
6820 | |
6821 | "8:\n" |
6822 | // Case where channels are columns |
6823 | |
6824 | // Apply the positive exponent part of the multiplier. |
6825 | "dup v4.4s, v9.s[0]\n" |
6826 | "dup v5.4s, v9.s[1]\n" |
6827 | "sshl v16.4s, v16.4s, v4.4s\n" |
6828 | "dup v6.4s, v9.s[2]\n" |
6829 | "sshl v17.4s, v17.4s, v4.4s\n" |
6830 | "dup v7.4s, v9.s[3]\n" |
6831 | "sshl v18.4s, v18.4s, v5.4s\n" |
6832 | "dup v4.4s, v10.s[0]\n" |
6833 | "sshl v19.4s, v19.4s, v5.4s\n" |
6834 | "dup v5.4s, v10.s[1]\n" |
6835 | "sshl v20.4s, v20.4s, v6.4s\n" |
6836 | "sshl v21.4s, v21.4s, v6.4s\n" |
6837 | "dup v6.4s, v10.s[2]\n" |
6838 | "sshl v22.4s, v22.4s, v7.4s\n" |
6839 | "sshl v23.4s, v23.4s, v7.4s\n" |
6840 | "dup v7.4s, v10.s[3]\n" |
6841 | "sshl v24.4s, v24.4s, v4.4s\n" |
6842 | "sshl v25.4s, v25.4s, v4.4s\n" |
6843 | "sshl v26.4s, v26.4s, v5.4s\n" |
6844 | "sshl v27.4s, v27.4s, v5.4s\n" |
6845 | "sshl v28.4s, v28.4s, v6.4s\n" |
6846 | "sshl v29.4s, v29.4s, v6.4s\n" |
6847 | "sshl v30.4s, v30.4s, v7.4s\n" |
6848 | "sshl v31.4s, v31.4s, v7.4s\n" |
6849 | "11:\n" |
6850 | |
6851 | // Apply the fixed-point part of the multiplier. |
6852 | // |
6853 | // ... and, interleaved into that: |
6854 | // Now that we know what LHS and RHS data the next iteration of the |
6855 | // main loop will need to load, we start loading the first 32 bytes of |
6856 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
6857 | // in the rest of the work on the current block. |
6858 | "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" |
6859 | "sqdmulh v16.4s, v16.4s, v14.s[0]\n" |
6860 | "ldr x1, [%[lhs_ptr]], #8\n" |
6861 | "sqdmulh v17.4s, v17.4s, v14.s[0]\n" |
6862 | "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" |
6863 | "sqdmulh v18.4s, v18.4s, v14.s[1]\n" |
6864 | "ldr x2, [%[lhs_ptr]], #8\n" |
6865 | "sqdmulh v19.4s, v19.4s, v14.s[1]\n" |
6866 | "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" |
6867 | "sqdmulh v20.4s, v20.4s, v14.s[2]\n" |
6868 | "ldr x5, [%[rhs_ptr]], #8\n" |
6869 | "sqdmulh v21.4s, v21.4s, v14.s[2]\n" |
6870 | "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" |
6871 | "sqdmulh v22.4s, v22.4s, v14.s[3]\n" |
6872 | "ldr x6, [%[rhs_ptr]], #8\n" |
6873 | "sqdmulh v23.4s, v23.4s, v14.s[3]\n" |
6874 | "dup v4.4s, v11.s[0]\n" |
6875 | "sqdmulh v24.4s, v24.4s, v15.s[0]\n" |
6876 | "dup v5.4s, v11.s[1]\n" |
6877 | "sqdmulh v25.4s, v25.4s, v15.s[0]\n" |
6878 | "dup v6.4s, v11.s[2]\n" |
6879 | "sqdmulh v26.4s, v26.4s, v15.s[1]\n" |
6880 | "dup v7.4s, v11.s[3]\n" |
6881 | "sqdmulh v27.4s, v27.4s, v15.s[1]\n" |
6882 | "sqdmulh v28.4s, v28.4s, v15.s[2]\n" |
6883 | "sqdmulh v29.4s, v29.4s, v15.s[2]\n" |
6884 | "sqdmulh v30.4s, v30.4s, v15.s[3]\n" |
6885 | "sqdmulh v31.4s, v31.4s, v15.s[3]\n" |
6886 | |
6887 | // Apply the negative exponent part of the multiplier. |
6888 | "srshl v16.4s, v16.4s, v4.4s\n" |
6889 | "srshl v17.4s, v17.4s, v4.4s\n" |
6890 | "dup v4.4s, v12.s[0]\n" |
6891 | "srshl v18.4s, v18.4s, v5.4s\n" |
6892 | "srshl v19.4s, v19.4s, v5.4s\n" |
6893 | "dup v5.4s, v12.s[1]\n" |
6894 | "srshl v20.4s, v20.4s, v6.4s\n" |
6895 | "srshl v21.4s, v21.4s, v6.4s\n" |
6896 | "dup v6.4s, v12.s[2]\n" |
6897 | "srshl v22.4s, v22.4s, v7.4s\n" |
6898 | "srshl v23.4s, v23.4s, v7.4s\n" |
6899 | "dup v7.4s, v12.s[3]\n" |
6900 | "srshl v24.4s, v24.4s, v4.4s\n" |
6901 | "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
6902 | "srshl v25.4s, v25.4s, v4.4s\n" |
6903 | "ins v13.h[4], w4\n" // dst_zero_point |
6904 | "srshl v26.4s, v26.4s, v5.4s\n" |
6905 | "ins v0.d[1], x1\n" |
6906 | "srshl v27.4s, v27.4s, v5.4s\n" |
6907 | "ins v1.d[1], x2\n" |
6908 | "srshl v28.4s, v28.4s, v6.4s\n" |
6909 | "ins v2.d[1], x5\n" |
6910 | "srshl v29.4s, v29.4s, v6.4s\n" |
6911 | "ins v3.d[1], x6\n" |
6912 | "srshl v30.4s, v30.4s, v7.4s\n" |
6913 | "srshl v31.4s, v31.4s, v7.4s\n" |
6914 | "9:\n" |
6915 | |
6916 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
6917 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
6918 | "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
6919 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
6920 | |
6921 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
6922 | |
6923 | // Cast-and-saturate from int32 to int16 |
6924 | "sqxtn v16.4h, v16.4s\n" |
6925 | "sqxtn2 v16.8h, v17.4s\n" |
6926 | "sqxtn v17.4h, v18.4s\n" |
6927 | "sqxtn2 v17.8h, v19.4s\n" |
6928 | "sqxtn v18.4h, v20.4s\n" |
6929 | "sqxtn2 v18.8h, v21.4s\n" |
6930 | "sqxtn v19.4h, v22.4s\n" |
6931 | "sqxtn2 v19.8h, v23.4s\n" |
6932 | "sqxtn v20.4h, v24.4s\n" |
6933 | "sqxtn2 v20.8h, v25.4s\n" |
6934 | "sqxtn v21.4h, v26.4s\n" |
6935 | "sqxtn2 v21.8h, v27.4s\n" |
6936 | "sqxtn v22.4h, v28.4s\n" |
6937 | "sqxtn2 v22.8h, v29.4s\n" |
6938 | "sqxtn v23.4h, v30.4s\n" |
6939 | "sqxtn2 v23.8h, v31.4s\n" |
6940 | |
6941 | // Destination zero_point |
6942 | "dup v14.8h, v13.h[4]\n" |
6943 | // At this point, v24 -- v31 aren't used anymore for the current block, |
6944 | // so we can start clearing these accumulators for the next block |
6945 | // (next iteration of the main loop). |
6946 | RUY_MAKE_ZERO(v24) |
6947 | RUY_MAKE_ZERO(v25) |
6948 | RUY_MAKE_ZERO(v26) |
6949 | RUY_MAKE_ZERO(v27) |
6950 | RUY_MAKE_ZERO(v28) |
6951 | RUY_MAKE_ZERO(v29) |
6952 | RUY_MAKE_ZERO(v30) |
6953 | RUY_MAKE_ZERO(v31) |
6954 | |
6955 | // Add the destination zero point |
6956 | "sqadd v16.8h, v16.8h, v14.8h\n" |
6957 | "sqadd v17.8h, v17.8h, v14.8h\n" |
6958 | "sqadd v18.8h, v18.8h, v14.8h\n" |
6959 | "sqadd v19.8h, v19.8h, v14.8h\n" |
6960 | "sqadd v20.8h, v20.8h, v14.8h\n" |
6961 | "sqadd v21.8h, v21.8h, v14.8h\n" |
6962 | "sqadd v22.8h, v22.8h, v14.8h\n" |
6963 | "sqadd v23.8h, v23.8h, v14.8h\n" |
6964 | |
6965 | // Load the clamp_min, clamp_max bounds |
6966 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
6967 | // Cast-and-saturate from int16 to uint8 |
6968 | "sqxtun v16.8b, v16.8h\n" |
6969 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
6970 | "sqxtun2 v16.16b, v17.8h\n" |
6971 | "sqxtun v17.8b, v18.8h\n" |
6972 | "sqxtun2 v17.16b, v19.8h\n" |
6973 | "sqxtun v18.8b, v20.8h\n" |
6974 | "sqxtun2 v18.16b, v21.8h\n" |
6975 | "sqxtun v19.8b, v22.8h\n" |
6976 | "sqxtun2 v19.16b, v23.8h\n" |
6977 | |
6978 | "dup v14.16b, w2\n" // clamp_min |
6979 | "dup v15.16b, w3\n" // clamp_max |
6980 | |
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
6986 | "sub w1, %w[dst_rows], %w[row]\n" |
6987 | // Apply the clamp_min bound |
6988 | "umax v16.16b, v16.16b, v14.16b\n" |
6989 | "sub w2, %w[dst_cols], %w[col]\n" |
6990 | "umax v17.16b, v17.16b, v14.16b\n" |
6991 | "mov w3, #8\n" |
6992 | "umax v18.16b, v18.16b, v14.16b\n" |
6993 | "cmp w1, #8\n" |
6994 | "umax v19.16b, v19.16b, v14.16b\n" |
6995 | // Compute w1 = how many rows of the 8x8 block fit |
6996 | "csel w1, w1, w3, le\n" |
6997 | // Apply the clamp_max bound |
6998 | "umin v16.16b, v16.16b, v15.16b\n" |
6999 | "cmp w2, #8\n" |
7000 | "umin v17.16b, v17.16b, v15.16b\n" |
7001 | // Compute w2 = how many cols of the 8x8 block fit |
7002 | "csel w2, w2, w3, le\n" |
7003 | "umin v18.16b, v18.16b, v15.16b\n" |
7004 | "umin v19.16b, v19.16b, v15.16b\n" |
7005 | |
7006 | // Make it so that all of the final 8bit values are stored in the |
7007 | // first 64bits of 128bit NEON registers, so they can be stored |
7008 | // by 64bit st1 store instructions with byte alignment. |
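// Each "dup dN, vM.d[1]" below copies the upper 64-bit half of vM into the
// low half of vN, so that after these four instructions each of v16 -- v23
// holds one 8-byte destination column in its low 64 bits, matching the
// 8-byte st1 stores further down.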
7009 | "dup d20, v16.d[1]\n" |
7010 | "dup d21, v17.d[1]\n" |
7011 | "dup d22, v18.d[1]\n" |
7012 | "dup d23, v19.d[1]\n" |
7013 | |
7014 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
7015 | "cmp w1, w3\n" |
7016 | "ccmp w2, w3, 0, eq\n" |
7017 | // Yes, all of the 8x8 block fits, go to fast path. |
7018 | "beq 30f\n" |
7019 | // Not all of the 8x8 block fits. |
7020 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
7021 | "mov x3, %[dst_tmp_buf]\n" |
7022 | "mov x4, #8\n" |
7023 | "b 31f\n" |
7024 | "30:\n" |
7025 | // Yes, all of the 8x8 block fits. |
7026 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
7027 | "mov x3, %[dst_ptr]\n" |
7028 | "mov x4, x11\n" |
7029 | "31:\n" |
7030 | |
7031 | // Write our 8bit values to the destination described by |
7032 | // (x3 address, x4 stride). |
7033 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7034 | "st1 {v16.8b}, [x3], x4\n" |
7035 | RUY_MAKE_ZERO(v16) |
7036 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7037 | "st1 {v20.8b}, [x3], x4\n" |
7038 | RUY_MAKE_ZERO(v20) |
7039 | // For the next block: perform the first few multiply-adds on the data |
7040 | // that we have already loaded. |
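// The sdot instructions are emitted as raw .word encodings so that this
// kernel assembles even with toolchains that do not know the dotprod
// extension; the intended mnemonic is given in each trailing comment.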
7041 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
7042 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7043 | "st1 {v17.8b}, [x3], x4\n" |
7044 | RUY_MAKE_ZERO(v17) |
7045 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
7046 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7047 | "st1 {v21.8b}, [x3], x4\n" |
7048 | RUY_MAKE_ZERO(v21) |
7049 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7050 | "st1 {v18.8b}, [x3], x4\n" |
7051 | RUY_MAKE_ZERO(v18) |
7052 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7053 | "st1 {v22.8b}, [x3], x4\n" |
7054 | RUY_MAKE_ZERO(v22) |
7055 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
7056 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7057 | "st1 {v19.8b}, [x3], x4\n" |
7058 | RUY_MAKE_ZERO(v19) |
7059 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
7060 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7061 | "st1 {v23.8b}, [x3], x4\n" |
7062 | RUY_MAKE_ZERO(v23) |
7063 | |
7064 | // If all of the 8x8 block fits, we just finished writing it to the |
7065 | // destination, so we skip the next part. |
7066 | "beq 41f\n" |
7067 | // Not all of the 8x8 block fits in the destination matrix. We just |
7068 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
7069 | // it to copy into the destination matrix the part that fits. |
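// In C terms, this copy loop is roughly (w1 = rows that fit, w2 = cols that
// fit, x11 = dst stride in bytes; a sketch, not actual ruy code):
//   for (int c = 0; c < w2; ++c)
//     for (int r = 0; r < w1; ++r)
//       dst[r + c * dst_stride] = dst_tmp_buf[r + c * 8];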
7070 | "mov x3, %[dst_tmp_buf]\n" |
7071 | "mov x4, %[dst_ptr]\n" |
7072 | "mov w6, #0\n" |
7073 | "50:\n" |
7074 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7075 | "mov w5, #0\n" |
7076 | "51:\n" |
7077 | "ldrb w7, [x3, w5, uxtw]\n" |
7078 | "strb w7, [x4, w5, uxtw]\n" |
7079 | "add w5, w5, #1\n" |
7080 | "cmp w5, w1\n" |
7081 | "blt 51b\n" |
7082 | "add w6, w6, #1\n" |
7083 | "add x3, x3, #8\n" |
7084 | "add x4, x4, x11\n" |
7085 | "cmp w6, w2\n" |
7086 | "blt 50b\n" |
7087 | "41:\n" |
7088 | "add %[dst_ptr], %[dst_ptr], #8\n" |
7089 | |
7090 | // At this point we have completely finished writing values to the |
7091 | // destination matrix for the current block. |
7092 | |
7093 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
7094 | |
7095 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
7096 | |
7097 | // Cast-and-saturate from int32 to int16 |
7098 | "sqxtn v16.4h, v16.4s\n" |
7099 | "sqxtn2 v16.8h, v17.4s\n" |
7100 | "sqxtn v17.4h, v18.4s\n" |
7101 | "sqxtn2 v17.8h, v19.4s\n" |
7102 | "sqxtn v18.4h, v20.4s\n" |
7103 | "sqxtn2 v18.8h, v21.4s\n" |
7104 | "sqxtn v19.4h, v22.4s\n" |
7105 | "sqxtn2 v19.8h, v23.4s\n" |
7106 | "sqxtn v20.4h, v24.4s\n" |
7107 | "sqxtn2 v20.8h, v25.4s\n" |
7108 | "sqxtn v21.4h, v26.4s\n" |
7109 | "sqxtn2 v21.8h, v27.4s\n" |
7110 | "sqxtn v22.4h, v28.4s\n" |
7111 | "sqxtn2 v22.8h, v29.4s\n" |
7112 | "sqxtn v23.4h, v30.4s\n" |
7113 | "sqxtn2 v23.8h, v31.4s\n" |
7114 | |
7115 | // Destination zero_point |
7116 | "dup v14.8h, v13.h[4]\n" |
7117 | // At this point, v24 -- v31 aren't used anymore for the current block, |
7118 | // so we can start clearing these accumulators for the next block |
7119 | // (next iteration of the main loop). |
7120 | RUY_MAKE_ZERO(v24) |
7121 | RUY_MAKE_ZERO(v25) |
7122 | RUY_MAKE_ZERO(v26) |
7123 | RUY_MAKE_ZERO(v27) |
7124 | RUY_MAKE_ZERO(v28) |
7125 | RUY_MAKE_ZERO(v29) |
7126 | RUY_MAKE_ZERO(v30) |
7127 | RUY_MAKE_ZERO(v31) |
7128 | |
7129 | // Add the destination zero point |
7130 | "sqadd v16.8h, v16.8h, v14.8h\n" |
7131 | "sqadd v17.8h, v17.8h, v14.8h\n" |
7132 | "sqadd v18.8h, v18.8h, v14.8h\n" |
7133 | "sqadd v19.8h, v19.8h, v14.8h\n" |
7134 | "sqadd v20.8h, v20.8h, v14.8h\n" |
7135 | "sqadd v21.8h, v21.8h, v14.8h\n" |
7136 | "sqadd v22.8h, v22.8h, v14.8h\n" |
7137 | "sqadd v23.8h, v23.8h, v14.8h\n" |
7138 | |
7139 | // Load the clamp_min, clamp_max bounds |
7140 | "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
// Cast-and-saturate from int16 to int8
7142 | "sqxtn v16.8b, v16.8h\n" |
7143 | "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
7144 | "sqxtn2 v16.16b, v17.8h\n" |
7145 | "sqxtn v17.8b, v18.8h\n" |
7146 | "sqxtn2 v17.16b, v19.8h\n" |
7147 | "sqxtn v18.8b, v20.8h\n" |
7148 | "sqxtn2 v18.16b, v21.8h\n" |
7149 | "sqxtn v19.8b, v22.8h\n" |
7150 | "sqxtn2 v19.16b, v23.8h\n" |
7151 | |
7152 | "dup v14.16b, w2\n" // clamp_min |
7153 | "dup v15.16b, w3\n" // clamp_max |
7154 | |
// Compute how much of the 8x8 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
7160 | "sub w1, %w[dst_rows], %w[row]\n" |
7161 | // Apply the clamp_min bound |
7162 | "smax v16.16b, v16.16b, v14.16b\n" |
7163 | "sub w2, %w[dst_cols], %w[col]\n" |
7164 | "smax v17.16b, v17.16b, v14.16b\n" |
7165 | "mov w3, #8\n" |
7166 | "smax v18.16b, v18.16b, v14.16b\n" |
7167 | "cmp w1, #8\n" |
7168 | "smax v19.16b, v19.16b, v14.16b\n" |
7169 | // Compute w1 = how many rows of the 8x8 block fit |
7170 | "csel w1, w1, w3, le\n" |
7171 | // Apply the clamp_max bound |
7172 | "smin v16.16b, v16.16b, v15.16b\n" |
7173 | "cmp w2, #8\n" |
7174 | "smin v17.16b, v17.16b, v15.16b\n" |
7175 | // Compute w2 = how many cols of the 8x8 block fit |
7176 | "csel w2, w2, w3, le\n" |
7177 | "smin v18.16b, v18.16b, v15.16b\n" |
7178 | "smin v19.16b, v19.16b, v15.16b\n" |
7179 | |
7180 | // Make it so that all of the final 8bit values are stored in the |
7181 | // first 64bits of 128bit NEON registers, so they can be stored |
7182 | // by 64bit st1 store instructions with byte alignment. |
7183 | "dup d20, v16.d[1]\n" |
7184 | "dup d21, v17.d[1]\n" |
7185 | "dup d22, v18.d[1]\n" |
7186 | "dup d23, v19.d[1]\n" |
7187 | |
7188 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
7189 | "cmp w1, w3\n" |
7190 | "ccmp w2, w3, 0, eq\n" |
7191 | // Yes, all of the 8x8 block fits, go to fast path. |
7192 | "beq 130f\n" |
7193 | // Not all of the 8x8 block fits. |
7194 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
7195 | "mov x3, %[dst_tmp_buf]\n" |
7196 | "mov x4, #8\n" |
7197 | "b 131f\n" |
7198 | "130:\n" |
7199 | // Yes, all of the 8x8 block fits. |
7200 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
7201 | "mov x3, %[dst_ptr]\n" |
7202 | "mov x4, x11\n" |
7203 | "131:\n" |
7204 | |
7205 | // Write our 8bit values to the destination described by |
7206 | // (x3 address, x4 stride). |
7207 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7208 | "st1 {v16.8b}, [x3], x4\n" |
7209 | RUY_MAKE_ZERO(v16) |
7210 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7211 | "st1 {v20.8b}, [x3], x4\n" |
7212 | RUY_MAKE_ZERO(v20) |
7213 | // For the next block: perform the first few multiply-adds on the data |
7214 | // that we have already loaded. |
7215 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
7216 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7217 | "st1 {v17.8b}, [x3], x4\n" |
7218 | RUY_MAKE_ZERO(v17) |
7219 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
7220 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7221 | "st1 {v21.8b}, [x3], x4\n" |
7222 | RUY_MAKE_ZERO(v21) |
7223 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7224 | "st1 {v18.8b}, [x3], x4\n" |
7225 | RUY_MAKE_ZERO(v18) |
7226 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7227 | "st1 {v22.8b}, [x3], x4\n" |
7228 | RUY_MAKE_ZERO(v22) |
7229 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
7230 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7231 | "st1 {v19.8b}, [x3], x4\n" |
7232 | RUY_MAKE_ZERO(v19) |
7233 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
7234 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7235 | "st1 {v23.8b}, [x3], x4\n" |
7236 | RUY_MAKE_ZERO(v23) |
7237 | |
7238 | // If all of the 8x8 block fits, we just finished writing it to the |
7239 | // destination, so we skip the next part. |
7240 | "beq 141f\n" |
7241 | // Not all of the 8x8 block fits in the destination matrix. We just |
7242 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
7243 | // it to copy into the destination matrix the part that fits. |
7244 | "mov x3, %[dst_tmp_buf]\n" |
7245 | "mov x4, %[dst_ptr]\n" |
7246 | "mov w6, #0\n" |
7247 | "150:\n" |
7248 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7249 | "mov w5, #0\n" |
7250 | "151:\n" |
7251 | "ldrb w7, [x3, w5, uxtw]\n" |
7252 | "strb w7, [x4, w5, uxtw]\n" |
7253 | "add w5, w5, #1\n" |
7254 | "cmp w5, w1\n" |
7255 | "blt 151b\n" |
7256 | "add w6, w6, #1\n" |
7257 | "add x3, x3, #8\n" |
7258 | "add x4, x4, x11\n" |
7259 | "cmp w6, w2\n" |
7260 | "blt 150b\n" |
7261 | "141:\n" |
7262 | "add %[dst_ptr], %[dst_ptr], #8\n" |
7263 | |
7264 | // At this point we have completely finished writing values to the |
7265 | // destination matrix for the current block. |
7266 | |
7267 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
7268 | |
7269 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
7270 | |
7271 | // Add the destination zero point |
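// Unlike the 8bit paths, the zero point is added here while the accumulators
// are still int32: saddw sign-extends the int16 dst_zero_point lanes of
// v14.4h and adds them to each int32 accumulator lane.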
7272 | "dup v14.8h, v13.h[4]\n" |
7273 | "saddw v16.4s, v16.4s, v14.4h\n" |
7274 | "saddw v17.4s, v17.4s, v14.4h\n" |
7275 | "saddw v18.4s, v18.4s, v14.4h\n" |
7276 | "saddw v19.4s, v19.4s, v14.4h\n" |
7277 | "saddw v20.4s, v20.4s, v14.4h\n" |
7278 | "saddw v21.4s, v21.4s, v14.4h\n" |
7279 | "saddw v22.4s, v22.4s, v14.4h\n" |
7280 | "saddw v23.4s, v23.4s, v14.4h\n" |
7281 | "saddw v24.4s, v24.4s, v14.4h\n" |
7282 | "saddw v25.4s, v25.4s, v14.4h\n" |
7283 | "saddw v26.4s, v26.4s, v14.4h\n" |
7284 | "saddw v27.4s, v27.4s, v14.4h\n" |
7285 | "saddw v28.4s, v28.4s, v14.4h\n" |
7286 | "saddw v29.4s, v29.4s, v14.4h\n" |
7287 | "saddw v30.4s, v30.4s, v14.4h\n" |
7288 | "saddw v31.4s, v31.4s, v14.4h\n" |
7289 | |
7290 | // Cast-and-saturate from int32 to int16 |
7291 | "sqxtn v16.4h, v16.4s\n" |
7292 | "sqxtn2 v16.8h, v17.4s\n" |
7293 | "sqxtn v17.4h, v18.4s\n" |
7294 | "sqxtn2 v17.8h, v19.4s\n" |
7295 | "sqxtn v18.4h, v20.4s\n" |
7296 | "sqxtn2 v18.8h, v21.4s\n" |
7297 | "sqxtn v19.4h, v22.4s\n" |
7298 | "sqxtn2 v19.8h, v23.4s\n" |
7299 | "sqxtn v20.4h, v24.4s\n" |
7300 | "sqxtn2 v20.8h, v25.4s\n" |
7301 | "sqxtn v21.4h, v26.4s\n" |
7302 | "sqxtn2 v21.8h, v27.4s\n" |
7303 | "sqxtn v22.4h, v28.4s\n" |
7304 | "sqxtn2 v22.8h, v29.4s\n" |
7305 | "sqxtn v23.4h, v30.4s\n" |
7306 | "sqxtn2 v23.8h, v31.4s\n" |
7307 | |
7308 | // At this point, v24 -- v31 aren't used anymore for the current block, |
7309 | // so we can start clearing these accumulators for the next block |
7310 | // (next iteration of the main loop). |
7311 | RUY_MAKE_ZERO(v24) |
7312 | RUY_MAKE_ZERO(v25) |
7313 | RUY_MAKE_ZERO(v26) |
7314 | RUY_MAKE_ZERO(v27) |
7315 | RUY_MAKE_ZERO(v28) |
7316 | RUY_MAKE_ZERO(v29) |
7317 | RUY_MAKE_ZERO(v30) |
7318 | RUY_MAKE_ZERO(v31) |
7319 | |
7320 | // Load the clamp_min, clamp_max bounds |
7321 | "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
7322 | "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
7323 | "dup v14.8h, w2\n" // clamp_min |
7324 | "dup v15.8h, w3\n" // clamp_max |
7325 | |
7326 | // Apply the clamp_min bound |
7327 | "smax v16.8h, v16.8h, v14.8h\n" |
7328 | "smax v17.8h, v17.8h, v14.8h\n" |
7329 | "smax v18.8h, v18.8h, v14.8h\n" |
7330 | "smax v19.8h, v19.8h, v14.8h\n" |
7331 | "smax v20.8h, v20.8h, v14.8h\n" |
7332 | "smax v21.8h, v21.8h, v14.8h\n" |
7333 | "smax v22.8h, v22.8h, v14.8h\n" |
7334 | "smax v23.8h, v23.8h, v14.8h\n" |
7335 | // Apply the clamp_max bound |
7336 | "smin v16.8h, v16.8h, v15.8h\n" |
7337 | "smin v17.8h, v17.8h, v15.8h\n" |
7338 | "smin v18.8h, v18.8h, v15.8h\n" |
7339 | "smin v19.8h, v19.8h, v15.8h\n" |
7340 | "smin v20.8h, v20.8h, v15.8h\n" |
7341 | "smin v21.8h, v21.8h, v15.8h\n" |
7342 | "smin v22.8h, v22.8h, v15.8h\n" |
7343 | "smin v23.8h, v23.8h, v15.8h\n" |
7344 | |
// Compute how much of the 8x8 block of destination 16bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
7350 | "sub w1, %w[dst_rows], %w[row]\n" |
7351 | "sub w2, %w[dst_cols], %w[col]\n" |
7352 | "mov w3, #8\n" |
7353 | "cmp w1, #8\n" |
7354 | // Compute w1 = how many rows of the 8x8 block fit |
7355 | "csel w1, w1, w3, le\n" |
7356 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
7358 | "csel w2, w2, w3, le\n" |
7359 | |
7360 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
7361 | "cmp w1, w3\n" |
7362 | "ccmp w2, w3, 0, eq\n" |
7363 | // Yes, all of the 8x8 block fits, go to fast path. |
7364 | "beq 230f\n" |
7365 | // Not all of the 8x8 block fits. |
7366 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
7367 | "mov x3, %[dst_tmp_buf]\n" |
7368 | "mov x4, #16\n" |
7369 | "b 231f\n" |
7370 | "230:\n" |
7371 | // Yes, all of the 8x8 block fits. |
7372 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
7373 | "mov x3, %[dst_ptr]\n" |
7374 | "mov x4, x11\n" |
7375 | "231:\n" |
7376 | |
// Write our 16bit values to the destination described by
7378 | // (x3 address, x4 stride). |
7379 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7380 | "st1 {v16.8h}, [x3], x4\n" |
7381 | RUY_MAKE_ZERO(v16) |
7382 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7383 | "st1 {v17.8h}, [x3], x4\n" |
7384 | RUY_MAKE_ZERO(v17) |
7385 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7386 | "st1 {v18.8h}, [x3], x4\n" |
7387 | RUY_MAKE_ZERO(v18) |
7388 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7389 | "st1 {v19.8h}, [x3], x4\n" |
7390 | RUY_MAKE_ZERO(v19) |
7391 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7392 | "st1 {v20.8h}, [x3], x4\n" |
7393 | RUY_MAKE_ZERO(v20) |
7394 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7395 | "st1 {v21.8h}, [x3], x4\n" |
7396 | RUY_MAKE_ZERO(v21) |
7397 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7398 | "st1 {v22.8h}, [x3], x4\n" |
7399 | RUY_MAKE_ZERO(v22) |
7400 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
7401 | "st1 {v23.8h}, [x3], x4\n" |
7402 | RUY_MAKE_ZERO(v23) |
7403 | |
7404 | // For the next block: perform the first few multiply-adds on the data |
7405 | // that we have already loaded. |
7406 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
7407 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
7408 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
7409 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
7410 | |
7411 | // If all of the 8x8 block fits, we just finished writing it to the |
7412 | // destination, so we skip the next part. |
7413 | "beq 241f\n" |
7414 | // Not all of the 8x8 block fits in the destination matrix. We just |
7415 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
7416 | // it to copy into the destination matrix the part that fits. |
7417 | "mov x3, %[dst_tmp_buf]\n" |
7418 | "mov x4, %[dst_ptr]\n" |
7419 | "mov w6, #0\n" |
7420 | "250:\n" |
7421 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7422 | "mov w5, #0\n" |
7423 | "251:\n" |
7424 | "ldrsh w7, [x3, x5, lsl #1]\n" |
7425 | "strh w7, [x4, x5, lsl #1]\n" |
7426 | "add w5, w5, #1\n" |
7427 | "cmp w5, w1\n" |
7428 | "blt 251b\n" |
7429 | "add w6, w6, #1\n" |
7430 | "add x3, x3, #16\n" |
7431 | "add x4, x4, x11\n" |
7432 | "cmp w6, w2\n" |
7433 | "blt 250b\n" |
7434 | "241:\n" |
7435 | "add %[dst_ptr], %[dst_ptr], #16\n" |
7436 | // At this point we have completely finished writing values to the |
7437 | // destination matrix for the current block. |
7438 | |
7439 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
7440 | |
7441 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
7442 | |
7443 | "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" |
7444 | "ldr x1, [%[lhs_ptr]], #8\n" |
7445 | "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" |
7446 | "ldr x2, [%[lhs_ptr]], #8\n" |
7447 | "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" |
7448 | "ldr x5, [%[rhs_ptr]], #8\n" |
7449 | "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" |
7450 | "ldr x6, [%[rhs_ptr]], #8\n" |
7451 | "ins v0.d[1], x1\n" |
7452 | "ins v1.d[1], x2\n" |
7453 | "ins v2.d[1], x5\n" |
7454 | "ins v3.d[1], x6\n" |
7455 | |
// Since the store type is the same as the accumulator type, there is no
// need for a downcast, nor to clamp to min/max.
7458 | |
// Compute how much of the 8x8 block of destination 32bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
7464 | "sub w1, %w[dst_rows], %w[row]\n" |
7465 | "sub w2, %w[dst_cols], %w[col]\n" |
7466 | "mov w3, #8\n" |
7467 | "cmp w1, #8\n" |
7468 | // Compute w1 = how many rows of the 8x8 block fit |
7469 | "csel w1, w1, w3, le\n" |
7470 | "cmp w2, #8\n" |
// Compute w2 = how many cols of the 8x8 block fit
7472 | "csel w2, w2, w3, le\n" |
7473 | |
7474 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
7475 | "cmp w1, w3\n" |
7476 | "ccmp w2, w3, 0, eq\n" |
7477 | // Yes, all of the 8x8 block fits, go to fast path. |
7478 | "beq 330f\n" |
7479 | // Not all of the 8x8 block fits. |
7480 | // Write to dst_tmp_buf |
7481 | "mov x3, %[dst_tmp_buf]\n" |
7482 | "st1 {v16.4s}, [x3], #16\n" |
7483 | RUY_MAKE_ZERO(v16) |
7484 | "st1 {v17.4s}, [x3], #16\n" |
7485 | RUY_MAKE_ZERO(v17) |
7486 | "st1 {v18.4s}, [x3], #16\n" |
7487 | RUY_MAKE_ZERO(v18) |
7488 | "st1 {v19.4s}, [x3], #16\n" |
7489 | RUY_MAKE_ZERO(v19) |
7490 | "st1 {v20.4s}, [x3], #16\n" |
7491 | RUY_MAKE_ZERO(v20) |
7492 | "st1 {v21.4s}, [x3], #16\n" |
7493 | RUY_MAKE_ZERO(v21) |
7494 | "st1 {v22.4s}, [x3], #16\n" |
7495 | RUY_MAKE_ZERO(v22) |
7496 | "st1 {v23.4s}, [x3], #16\n" |
7497 | RUY_MAKE_ZERO(v23) |
7498 | "st1 {v24.4s}, [x3], #16\n" |
7499 | RUY_MAKE_ZERO(v24) |
7500 | "st1 {v25.4s}, [x3], #16\n" |
7501 | RUY_MAKE_ZERO(v25) |
7502 | "st1 {v26.4s}, [x3], #16\n" |
7503 | RUY_MAKE_ZERO(v26) |
7504 | "st1 {v27.4s}, [x3], #16\n" |
7505 | RUY_MAKE_ZERO(v27) |
7506 | "st1 {v28.4s}, [x3], #16\n" |
7507 | RUY_MAKE_ZERO(v28) |
7508 | "st1 {v29.4s}, [x3], #16\n" |
7509 | RUY_MAKE_ZERO(v29) |
7510 | "st1 {v30.4s}, [x3], #16\n" |
7511 | RUY_MAKE_ZERO(v30) |
7512 | "st1 {v31.4s}, [x3], #16\n" |
7513 | RUY_MAKE_ZERO(v31) |
7514 | |
7515 | "b 331f\n" |
7516 | |
7517 | "330:\n" |
7518 | // Yes, all of the 8x8 block fits. |
7519 | "mov x4, %[dst_ptr]\n" |
7520 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7521 | "st1 {v16.4s, v17.4s}, [x4], x11\n" |
7522 | RUY_MAKE_ZERO(v16) |
7523 | RUY_MAKE_ZERO(v17) |
7524 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7525 | "st1 {v18.4s, v19.4s}, [x4], x11\n" |
7526 | RUY_MAKE_ZERO(v18) |
7527 | RUY_MAKE_ZERO(v19) |
7528 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7529 | "st1 {v20.4s, v21.4s}, [x4], x11\n" |
7530 | RUY_MAKE_ZERO(v20) |
7531 | RUY_MAKE_ZERO(v21) |
7532 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7533 | "st1 {v22.4s, v23.4s}, [x4], x11\n" |
7534 | RUY_MAKE_ZERO(v22) |
7535 | RUY_MAKE_ZERO(v23) |
7536 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7537 | "st1 {v24.4s, v25.4s}, [x4], x11\n" |
7538 | RUY_MAKE_ZERO(v24) |
7539 | RUY_MAKE_ZERO(v25) |
7540 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7541 | "st1 {v26.4s, v27.4s}, [x4], x11\n" |
7542 | RUY_MAKE_ZERO(v26) |
7543 | RUY_MAKE_ZERO(v27) |
7544 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7545 | "st1 {v28.4s, v29.4s}, [x4], x11\n" |
7546 | RUY_MAKE_ZERO(v28) |
7547 | RUY_MAKE_ZERO(v29) |
7548 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7549 | "st1 {v30.4s, v31.4s}, [x4], x11\n" |
7550 | RUY_MAKE_ZERO(v30) |
7551 | RUY_MAKE_ZERO(v31) |
7552 | |
7553 | "331:\n" |
7554 | |
7555 | // For the next block: perform the first few multiply-adds on the data |
7556 | // that we have already loaded. |
7557 | ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" |
7558 | ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" |
7559 | ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" |
7560 | ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" |
7561 | |
7562 | // If all of the 8x8 block fits, we just finished writing it to the |
7563 | // destination, so we skip the next part. |
7564 | "beq 341f\n" |
7565 | |
7566 | // Not all of the 8x8 block fits in the destination matrix. We just |
7567 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
7568 | // it to copy into the destination matrix the part that fits. |
7569 | "mov x3, %[dst_tmp_buf]\n" |
7570 | "mov x4, %[dst_ptr]\n" |
7571 | "mov w6, #0\n" |
7572 | "350:\n" |
7573 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
7574 | "mov w5, #0\n" |
7575 | "351:\n" |
7576 | "ldr w7, [x3, x5, lsl #2]\n" |
7577 | "str w7, [x4, x5, lsl #2]\n" |
7578 | "add w5, w5, #1\n" |
7579 | "cmp w5, w1\n" |
7580 | "blt 351b\n" |
7581 | "add w6, w6, #1\n" |
7582 | "add x3, x3, #32\n" |
7583 | "add x4, x4, x11\n" |
7584 | "cmp w6, w2\n" |
7585 | "blt 350b\n" |
7586 | "341:\n" |
7587 | "add %[dst_ptr], %[dst_ptr], #32\n" |
7588 | // At this point we have completely finished writing values to the |
7589 | // destination matrix for the current block. |
7590 | |
7591 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
7592 | |
7593 | // Reload some params --- we had used x5 -- x7 for a few other things |
7594 | // since the last time we had loaded them. |
7595 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
7596 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
7597 | |
7598 | // Move to the next block of the destination matrix, for the next iter |
7599 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
7600 | // been updated earlier. |
7601 | // Have we reached the end row? |
7602 | "cmp %w[row], w7\n" |
7603 | "beq 20f\n" // yes, end row. |
7604 | // Not end row. Move to the next row. |
7605 | "add %w[row], %w[row], #8\n" |
7606 | "b 21f\n" |
7607 | "20:\n" |
7608 | // Was already at end row. |
7609 | "mov %w[row], w6\n" // Move back to first row. |
7610 | "add %w[col], %w[col], #8\n" // Move to the next column. |
7611 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
7612 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
7613 | "21:\n" |
7614 | |
7615 | // Main loop exit condition: have we hit the end column? |
7616 | "cmp %w[col], w8\n" |
7617 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
7618 | "ble 1b\n" |
7619 | |
7620 | // clang-format on |
7621 | |
7622 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
7623 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
7624 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
: [ params ] "r" (&params), [dst_rows] "r" (params.dst_rows),
7626 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf), |
7627 | [dst_type_id] "r" (params.dst_type_id) |
7628 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
7629 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
7630 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
7631 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
7632 | } |
7633 | #undef RUY_OFFSET_BIAS |
7634 | #undef RUY_OFFSET_LHS_SUMS |
7635 | #undef RUY_OFFSET_RHS_SUMS |
7636 | #undef RUY_OFFSET_LHS_BASE_PTR |
7637 | #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT |
7638 | #undef RUY_OFFSET_MULTIPLIER_EXPONENT |
7639 | #undef RUY_OFFSET_RHS_BASE_PTR |
7640 | #undef RUY_OFFSET_DST_BASE_PTR |
7641 | #undef RUY_OFFSET_LHS_ZERO_POINT |
7642 | #undef RUY_OFFSET_RHS_ZERO_POINT |
7643 | #undef RUY_OFFSET_DST_ZERO_POINT |
7644 | #undef RUY_OFFSET_PROD_ZP_DEPTH |
7645 | #undef RUY_OFFSET_START_ROW |
7646 | #undef RUY_OFFSET_START_COL |
7647 | #undef RUY_OFFSET_LAST_ROW |
7648 | #undef RUY_OFFSET_LAST_COL |
7649 | #undef RUY_OFFSET_DST_ROWS |
7650 | #undef RUY_OFFSET_DST_COLS |
7651 | #undef RUY_OFFSET_LHS_STRIDE |
7652 | #undef RUY_OFFSET_RHS_STRIDE |
7653 | #undef RUY_OFFSET_DST_STRIDE |
7654 | #undef RUY_OFFSET_DEPTH |
7655 | #undef RUY_OFFSET_CLAMP_MIN |
7656 | #undef RUY_OFFSET_CLAMP_MAX |
7657 | #undef RUY_OFFSET_FLAGS |
7658 | |
7659 | #define RUY_OFFSET_LHS_BASE_PTR 0 |
7660 | #define RUY_OFFSET_RHS_BASE_PTR 8 |
7661 | #define RUY_OFFSET_DST_BASE_PTR 16 |
7662 | #define RUY_OFFSET_BIAS 24 |
7663 | #define RUY_OFFSET_START_ROW 32 |
7664 | #define RUY_OFFSET_START_COL 36 |
7665 | #define RUY_OFFSET_LAST_ROW 40 |
7666 | #define RUY_OFFSET_LAST_COL 44 |
7667 | #define RUY_OFFSET_LHS_STRIDE 56 |
7668 | #define RUY_OFFSET_RHS_STRIDE 60 |
7669 | #define RUY_OFFSET_DST_STRIDE 64 |
7670 | #define RUY_OFFSET_DEPTH 68 |
7671 | #define RUY_OFFSET_CLAMP_MIN 72 |
7672 | #define RUY_OFFSET_CLAMP_MAX 76 |
7673 | #define RUY_OFFSET_FLAGS 80 |
7674 | |
7675 | template <typename Params> |
7676 | void CheckOffsetsInKernelParamsFloat(const Params&) { |
7677 | static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "" ); |
7678 | static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "" ); |
7679 | static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "" ); |
7680 | static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "" ); |
7681 | static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "" ); |
7682 | static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "" ); |
7683 | static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "" ); |
7684 | static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "" ); |
7685 | static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "" ); |
7686 | static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "" ); |
7687 | static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "" ); |
7688 | static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "" ); |
7689 | static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "" ); |
7690 | static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "" ); |
7691 | static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "" ); |
7692 | } |
7693 | |
7694 | // Just a plain float kernel; good enough for out-of-order cores. |
7695 | // The closest to it in the gemmlowp collection would be |
7696 | // NEON_64bit_GEMM_Float32_WithScalar, |
7697 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 |
7698 | // |
// Besides ruy-ification, the main nuance here is that we stick to an 8x8
7700 | // width instead of the wider 12x8 that the register space permits and that |
7701 | // the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now |
7702 | // and we don't have evidence that going beyond 8x8 is needed. |
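// In outline, for each 8x8 destination block this kernel computes
//   dst = clamp(lhs * rhs + bias, clamp_min, clamp_max)
// in float32, accumulating over `depth` levels of the packed LHS and RHS.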
7703 | void KernelFloatNeon(const KernelParamsFloat<8, 8>& params) { |
7704 | CheckOffsetsInKernelParamsFloat(params); |
7705 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
7706 | |
7707 | const float* lhs_col_ptr = params.lhs_base_ptr; |
7708 | const float* rhs_col_ptr = params.rhs_base_ptr; |
7709 | const float* lhs_ptr = lhs_col_ptr; |
7710 | const float* rhs_ptr = rhs_col_ptr; |
7711 | float* dst_col_ptr = params.dst_base_ptr; |
7712 | float* dst_ptr = dst_col_ptr; |
7713 | int row = params.start_row; |
7714 | int col = params.start_col; |
7715 | |
7716 | // The asm kernel below has the following NEON register allocation: |
7717 | // |
7718 | // v16 -- v31 are accumulators. |
7719 | // During accumulation, v0 -- v15 are used to load data from LHS and RHS. |
7720 | // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and |
7721 | // v3 are used to load a 1x8 block of RHS, like this: |
7722 | // |
7723 | // RHS 1x8 block |
7724 | // /-----------------------------------------| |
7725 | // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| |
7726 | // \-----------------------------------------/ |
7727 | // LHS 8x1 block |
7728 | // /---------------------\ /-----------------------------------------| |
7729 | // | v0.s[0] | |v16.s[0] ... v30.s[0]| |
7730 | // | ... | | ... ... | |
7731 | // | v0.s[3] | |v16.s[3] ... v30.s[3]| |
7732 | // | v1.s[0] | |v17.s[0] ... v31.s[0]| |
7733 | // | ... | | ... ... | |
7734 | // | v1.s[3] | |v17.s[3] ... v31.s[3]| |
7735 | // \---------------------/ \-----------------------------------------/ |
7736 | // accumulators 8x8 block |
7737 | // |
7738 | // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step |
7739 | // is repeated 4 times, using 4x more registers for LHS and RHS, so that |
7740 | // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. |
7741 | // |
7742 | // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are |
// unused, and v8 -- v15 are used to load parameters for the
// post-accumulation part of the kernel.
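// In scalar terms, one elementary step (one level of depth) accumulates
//   acc[i][j] += lhs[i] * rhs[j]   for i, j in 0..7,
// with lhs[0..7] in v0/v1 and rhs[0..7] in v2/v3; each fmla below computes
// four such products at once (four rows i of one column j).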
7745 | asm volatile( |
7746 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
7747 | |
7748 | // clang-format off |
7749 | |
7750 | // Load some parameters into registers. |
7751 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
7752 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
7753 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
7754 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
7755 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
7756 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
7757 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
7758 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
7759 | |
7760 | // Load the first 32 bytes of LHS and RHS data. |
7761 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
7762 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
7763 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
7764 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
7765 | |
7766 | // Clear accumulators. |
7767 | RUY_MAKE_ZERO(v16) |
7768 | RUY_MAKE_ZERO(v17) |
7769 | RUY_MAKE_ZERO(v18) |
7770 | RUY_MAKE_ZERO(v19) |
7771 | RUY_MAKE_ZERO(v20) |
7772 | RUY_MAKE_ZERO(v21) |
7773 | RUY_MAKE_ZERO(v22) |
7774 | RUY_MAKE_ZERO(v23) |
7775 | RUY_MAKE_ZERO(v24) |
7776 | RUY_MAKE_ZERO(v25) |
7777 | RUY_MAKE_ZERO(v26) |
7778 | RUY_MAKE_ZERO(v27) |
7779 | RUY_MAKE_ZERO(v28) |
7780 | RUY_MAKE_ZERO(v29) |
7781 | RUY_MAKE_ZERO(v30) |
7782 | RUY_MAKE_ZERO(v31) |
7783 | |
7784 | // w1 is the number of levels of depth that we have already loaded |
7785 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
7786 | // above, this is currently 1. |
7787 | "mov w1, #1\n" |
7788 | |
7789 | // Main loop of the whole GEMM, over rows and columns of the |
7790 | // destination matrix. |
7791 | "1:\n" |
7792 | |
7793 | "fmla v16.4s, v0.4s, v2.s[0]\n" |
7794 | "fmla v18.4s, v0.4s, v2.s[1]\n" |
7795 | "fmla v20.4s, v0.4s, v2.s[2]\n" |
7796 | "fmla v22.4s, v0.4s, v2.s[3]\n" |
7797 | |
7798 | #if RUY_OPT(MAX_STREAMING) |
7799 | "cmp w12, #8\n" |
7800 | "blt 78f\n" |
7801 | "and w2, w12, #-4\n" |
7802 | |
7803 | "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" |
7804 | "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" |
7805 | "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" |
7806 | "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" |
7807 | |
7808 | "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" |
7809 | "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" |
7810 | "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" |
7811 | "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" |
7812 | |
7813 | "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" |
7814 | "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" |
7815 | "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" |
7816 | "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" |
7817 | "mov w1, #4\n" |
7818 | |
7819 | "80:\n" |
7820 | |
7821 | "add %[lhs_ptr], %[lhs_ptr], #128\n" |
7822 | "add %[rhs_ptr], %[rhs_ptr], #128\n" |
7823 | |
7824 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
7825 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
7826 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
7827 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
7828 | "ldr q0, [%[lhs_ptr], #-128]\n" |
7829 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
7830 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
7831 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
7832 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
7833 | "ldr q3, [%[rhs_ptr], #-112]\n" |
7834 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
7835 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
7836 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
7837 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
7838 | "ldr q1, [%[lhs_ptr], #-112]\n" |
7839 | "fmla v16.4s, v4.4s, v6.s[0]\n" |
7840 | "fmla v18.4s, v4.4s, v6.s[1]\n" |
7841 | "ldr q2, [%[rhs_ptr], #-128]\n" |
7842 | "fmla v20.4s, v4.4s, v6.s[2]\n" |
7843 | "fmla v22.4s, v4.4s, v6.s[3]\n" |
7844 | |
7845 | "fmla v24.4s, v4.4s, v7.s[0]\n" |
7846 | "fmla v26.4s, v4.4s, v7.s[1]\n" |
7847 | "fmla v28.4s, v4.4s, v7.s[2]\n" |
7848 | "fmla v30.4s, v4.4s, v7.s[3]\n" |
7849 | "ldr q4, [%[lhs_ptr], #-96]\n" |
7850 | "fmla v25.4s, v5.4s, v7.s[0]\n" |
7851 | "fmla v27.4s, v5.4s, v7.s[1]\n" |
7852 | "fmla v29.4s, v5.4s, v7.s[2]\n" |
7853 | "fmla v31.4s, v5.4s, v7.s[3]\n" |
7854 | "ldr q7, [%[rhs_ptr], #-80]\n" |
7855 | "fmla v17.4s, v5.4s, v6.s[0]\n" |
7856 | "fmla v19.4s, v5.4s, v6.s[1]\n" |
7857 | "fmla v21.4s, v5.4s, v6.s[2]\n" |
7858 | "fmla v23.4s, v5.4s, v6.s[3]\n" |
7859 | "ldr q5, [%[lhs_ptr], #-80]\n" |
7860 | "fmla v16.4s, v8.4s, v10.s[0]\n" |
7861 | "fmla v18.4s, v8.4s, v10.s[1]\n" |
7862 | "ldr q6, [%[rhs_ptr], #-96]\n" |
7863 | "fmla v20.4s, v8.4s, v10.s[2]\n" |
7864 | "fmla v22.4s, v8.4s, v10.s[3]\n" |
7865 | |
7866 | "fmla v24.4s, v8.4s, v11.s[0]\n" |
7867 | "fmla v26.4s, v8.4s, v11.s[1]\n" |
7868 | "fmla v28.4s, v8.4s, v11.s[2]\n" |
7869 | "fmla v30.4s, v8.4s, v11.s[3]\n" |
7870 | "ldr q8, [%[lhs_ptr], #-64]\n" |
7871 | "fmla v25.4s, v9.4s, v11.s[0]\n" |
7872 | "fmla v27.4s, v9.4s, v11.s[1]\n" |
7873 | "fmla v29.4s, v9.4s, v11.s[2]\n" |
7874 | "fmla v31.4s, v9.4s, v11.s[3]\n" |
7875 | "ldr q11, [%[rhs_ptr], #-48]\n" |
7876 | "fmla v17.4s, v9.4s, v10.s[0]\n" |
7877 | "fmla v19.4s, v9.4s, v10.s[1]\n" |
7878 | "fmla v21.4s, v9.4s, v10.s[2]\n" |
7879 | "fmla v23.4s, v9.4s, v10.s[3]\n" |
7880 | "ldr q9, [%[lhs_ptr], #-48]\n" |
7881 | "fmla v16.4s, v12.4s, v14.s[0]\n" |
7882 | "fmla v18.4s, v12.4s, v14.s[1]\n" |
7883 | "ldr q10, [%[rhs_ptr], #-64]\n" |
7884 | "fmla v20.4s, v12.4s, v14.s[2]\n" |
7885 | "fmla v22.4s, v12.4s, v14.s[3]\n" |
7886 | |
7887 | "fmla v24.4s, v12.4s, v15.s[0]\n" |
7888 | "fmla v26.4s, v12.4s, v15.s[1]\n" |
7889 | "fmla v28.4s, v12.4s, v15.s[2]\n" |
7890 | "fmla v30.4s, v12.4s, v15.s[3]\n" |
7891 | "ldr q12, [%[lhs_ptr], #-32]\n" |
7892 | "fmla v25.4s, v13.4s, v15.s[0]\n" |
7893 | "fmla v27.4s, v13.4s, v15.s[1]\n" |
7894 | "fmla v29.4s, v13.4s, v15.s[2]\n" |
7895 | "fmla v31.4s, v13.4s, v15.s[3]\n" |
7896 | "ldr q15, [%[rhs_ptr], #-16]\n" |
7897 | "fmla v17.4s, v13.4s, v14.s[0]\n" |
7898 | "fmla v19.4s, v13.4s, v14.s[1]\n" |
7899 | "fmla v21.4s, v13.4s, v14.s[2]\n" |
7900 | "fmla v23.4s, v13.4s, v14.s[3]\n" |
7901 | "ldr q13, [%[lhs_ptr], #-16]\n" |
7902 | "fmla v16.4s, v0.4s, v2.s[0]\n" |
7903 | "fmla v18.4s, v0.4s, v2.s[1]\n" |
7904 | "ldr q14, [%[rhs_ptr], #-32]\n" |
7905 | "fmla v20.4s, v0.4s, v2.s[2]\n" |
7906 | "fmla v22.4s, v0.4s, v2.s[3]\n" |
7907 | |
7908 | "add w1, w1, #4\n" |
7909 | "cmp w1, w2\n" |
7910 | "blt 80b\n" |
7911 | |
7912 | "fmla v16.4s, v4.4s, v6.s[0]\n" |
7913 | "fmla v18.4s, v4.4s, v6.s[1]\n" |
7914 | "fmla v20.4s, v4.4s, v6.s[2]\n" |
7915 | "fmla v22.4s, v4.4s, v6.s[3]\n" |
7916 | "fmla v24.4s, v4.4s, v7.s[0]\n" |
7917 | "fmla v26.4s, v4.4s, v7.s[1]\n" |
7918 | "fmla v28.4s, v4.4s, v7.s[2]\n" |
7919 | "fmla v30.4s, v4.4s, v7.s[3]\n" |
7920 | "fmla v25.4s, v5.4s, v7.s[0]\n" |
7921 | "fmla v27.4s, v5.4s, v7.s[1]\n" |
7922 | "fmla v29.4s, v5.4s, v7.s[2]\n" |
7923 | "fmla v31.4s, v5.4s, v7.s[3]\n" |
7924 | "fmla v17.4s, v5.4s, v6.s[0]\n" |
7925 | "fmla v19.4s, v5.4s, v6.s[1]\n" |
7926 | "fmla v21.4s, v5.4s, v6.s[2]\n" |
7927 | "fmla v23.4s, v5.4s, v6.s[3]\n" |
7928 | |
7929 | "fmla v16.4s, v8.4s, v10.s[0]\n" |
7930 | "fmla v18.4s, v8.4s, v10.s[1]\n" |
7931 | "fmla v20.4s, v8.4s, v10.s[2]\n" |
7932 | "fmla v22.4s, v8.4s, v10.s[3]\n" |
7933 | "fmla v24.4s, v8.4s, v11.s[0]\n" |
7934 | "fmla v26.4s, v8.4s, v11.s[1]\n" |
7935 | "fmla v28.4s, v8.4s, v11.s[2]\n" |
7936 | "fmla v30.4s, v8.4s, v11.s[3]\n" |
7937 | "fmla v25.4s, v9.4s, v11.s[0]\n" |
7938 | "fmla v27.4s, v9.4s, v11.s[1]\n" |
7939 | "fmla v29.4s, v9.4s, v11.s[2]\n" |
7940 | "fmla v31.4s, v9.4s, v11.s[3]\n" |
7941 | "fmla v17.4s, v9.4s, v10.s[0]\n" |
7942 | "fmla v19.4s, v9.4s, v10.s[1]\n" |
7943 | "fmla v21.4s, v9.4s, v10.s[2]\n" |
7944 | "fmla v23.4s, v9.4s, v10.s[3]\n" |
7945 | |
7946 | "fmla v16.4s, v12.4s, v14.s[0]\n" |
7947 | "fmla v18.4s, v12.4s, v14.s[1]\n" |
7948 | "fmla v20.4s, v12.4s, v14.s[2]\n" |
7949 | "fmla v22.4s, v12.4s, v14.s[3]\n" |
7950 | "fmla v24.4s, v12.4s, v15.s[0]\n" |
7951 | "fmla v26.4s, v12.4s, v15.s[1]\n" |
7952 | "fmla v28.4s, v12.4s, v15.s[2]\n" |
7953 | "fmla v30.4s, v12.4s, v15.s[3]\n" |
7954 | "fmla v25.4s, v13.4s, v15.s[0]\n" |
7955 | "fmla v27.4s, v13.4s, v15.s[1]\n" |
7956 | "fmla v29.4s, v13.4s, v15.s[2]\n" |
7957 | "fmla v31.4s, v13.4s, v15.s[3]\n" |
7958 | "fmla v17.4s, v13.4s, v14.s[0]\n" |
7959 | "fmla v19.4s, v13.4s, v14.s[1]\n" |
7960 | "fmla v21.4s, v13.4s, v14.s[2]\n" |
7961 | "fmla v23.4s, v13.4s, v14.s[3]\n" |
7962 | |
7963 | "78:\n" |
7964 | #endif |
7965 | |
7966 | // Accumulation loop |
7967 | "cmp w1, w12\n" |
7968 | "beq 79f\n" |
7969 | |
7970 | "2:\n" |
7971 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
7972 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
7973 | "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" |
7974 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
7975 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
7976 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
7977 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
7978 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
7979 | "add w1, w1, #1\n" |
7980 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
7981 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
7982 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
7983 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
7984 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
7985 | "cmp w1, w12\n" |
7986 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
7987 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
7988 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
7989 | "fmla v16.4s, v0.4s, v4.s[0]\n" |
7990 | "fmla v18.4s, v0.4s, v4.s[1]\n" |
7991 | "mov v2.16b, v4.16b\n" |
7992 | "fmla v20.4s, v0.4s, v4.s[2]\n" |
7993 | "fmla v22.4s, v0.4s, v4.s[3]\n" |
7994 | "blt 2b\n" |
7995 | |
7996 | "79:\n" |
7997 | |
7998 | // End of the inner loop on depth. Now perform the remaining |
7999 | // multiply-adds of the last level of depth, for which the LHS |
8000 | // and RHS data is already loaded. |
8001 | |
8002 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
8003 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
8004 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
8005 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
8006 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
8007 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
8008 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
8009 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
8010 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
8011 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
8012 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
8013 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
8014 | |
8015 | // End of accumulation. The registers v16 -- v31 contain the final |
8016 | // int32 accumulator values of the current 8x8 destination block. |
8017 | // We now have to compute the final 8-bit values from these int32 |
8018 | // accumulators, and advance to the next 8x8 block. We intertwine |
8019 | // these two aspects whenever possible for optimal pipelining, both |
8020 | // at the data flow level (prefetch data for next block as early as |
8021 | // possible) and instruction pipelining level (some of the next-block |
8022 | // work can dual-issue with some of the final work on the current |
8023 | // block). |
8024 | |
8025 | // Logic to advance to the next block in preparation for the next |
8026 | // iteration of the main loop. For now, we only want to compute |
8027 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
8028 | // not yet ready to update the values of row and col, as we still need |
8029 | // the current values for the rest of the work on the current block. |
8030 | |
8031 | "cmp %w[row], w7\n" // Have we finished the last row? |
8032 | "bge 4f\n" // If finished last row, go to 4 |
8033 | // Not finished last row: then advance to next row. |
8034 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
8035 | "b 5f\n" |
8036 | "4:\n" // Finished last row... |
8037 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
8038 | // Now we need to advance to the next column. If we already |
8039 | // finished the last column, then in principle we are done, however |
8040 | // we can't just return here, as we need to allow the end work of the |
8041 | // current block to complete. The good news is that at this point it |
8042 | // doesn't matter what data we load for the next column, since |
8043 | // we will exit from the main loop below before actually storing |
8044 | // anything computed from that data. |
8045 | "cmp %w[col], w8\n" // Have we finished the last column? |
8046 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
8047 | // Not finished last column: then advance to next column. |
8048 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
8049 | "5:\n" |
8050 | |
8051 | // Set the LHS and RHS data pointers to the start of the columns just |
8052 | // computed. |
8053 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
8054 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
8055 | |
8056 | // Load some parameters needed for the end work on current block. |
8057 | "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
8058 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
8059 | |
8060 | // Determine the channel index. |
8061 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
8062 | "csel w3, %w[row], %w[col], eq\n" |
8063 | |
8064 | // Offset the bias pointer as needed given the current row, col. |
8065 | "add x5, x1, x3, lsl #2\n" |
8066 | |
8067 | // If there is no bias, use no offset, just address the passed zero |
8068 | // data. |
8069 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
8070 | "csel x1, x1, x5, eq\n" |
8071 | |
8072 | // Load 8 bias values. |
8073 | "ld1 {v14.4s}, [x1], #16\n" |
8074 | "ld1 {v15.4s}, [x1]\n" |
8075 | |
8076 | // Now that we know what LHS and RHS data the next iteration of the |
8077 | // main loop will need to load, we start loading the first 32 bytes of |
8078 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
8079 | // in the rest of the work on the current block. |
8080 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
8081 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
8082 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
8083 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
8084 | |
8085 | // Perform the bias-addition. |
8086 | // Jump based on channel dimension. |
8087 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
8088 | "bne 6f\n" |
8089 | // Case where channels are rows |
8090 | "fadd v16.4s, v16.4s, v14.4s\n" |
8091 | "fadd v17.4s, v17.4s, v15.4s\n" |
8092 | "fadd v18.4s, v18.4s, v14.4s\n" |
8093 | "fadd v19.4s, v19.4s, v15.4s\n" |
8094 | "fadd v20.4s, v20.4s, v14.4s\n" |
8095 | "fadd v21.4s, v21.4s, v15.4s\n" |
8096 | "fadd v22.4s, v22.4s, v14.4s\n" |
8097 | "fadd v23.4s, v23.4s, v15.4s\n" |
8098 | "fadd v24.4s, v24.4s, v14.4s\n" |
8099 | "fadd v25.4s, v25.4s, v15.4s\n" |
8100 | "fadd v26.4s, v26.4s, v14.4s\n" |
8101 | "fadd v27.4s, v27.4s, v15.4s\n" |
8102 | "fadd v28.4s, v28.4s, v14.4s\n" |
8103 | "fadd v29.4s, v29.4s, v15.4s\n" |
8104 | "fadd v30.4s, v30.4s, v14.4s\n" |
8105 | "fadd v31.4s, v31.4s, v15.4s\n" |
8106 | "b 7f\n" |
8107 | |
8108 | "6:\n" |
8109 | // Case where channels are columns |
8110 | "dup v8.4s, v14.s[0]\n" |
8111 | "dup v9.4s, v14.s[1]\n" |
8112 | "dup v10.4s, v14.s[2]\n" |
8113 | "dup v11.4s, v14.s[3]\n" |
8114 | "dup v12.4s, v15.s[0]\n" |
8115 | "dup v13.4s, v15.s[1]\n" |
8116 | "dup v14.4s, v15.s[2]\n" |
8117 | "dup v15.4s, v15.s[3]\n" |
8118 | "fadd v16.4s, v16.4s, v8.4s\n" |
8119 | "fadd v17.4s, v17.4s, v8.4s\n" |
8120 | "fadd v18.4s, v18.4s, v9.4s\n" |
8121 | "fadd v19.4s, v19.4s, v9.4s\n" |
8122 | "fadd v20.4s, v20.4s, v10.4s\n" |
8123 | "fadd v21.4s, v21.4s, v10.4s\n" |
8124 | "fadd v22.4s, v22.4s, v11.4s\n" |
8125 | "fadd v23.4s, v23.4s, v11.4s\n" |
8126 | "fadd v24.4s, v24.4s, v12.4s\n" |
8127 | "fadd v25.4s, v25.4s, v12.4s\n" |
8128 | "fadd v26.4s, v26.4s, v13.4s\n" |
8129 | "fadd v27.4s, v27.4s, v13.4s\n" |
8130 | "fadd v28.4s, v28.4s, v14.4s\n" |
8131 | "fadd v29.4s, v29.4s, v14.4s\n" |
8132 | "fadd v30.4s, v30.4s, v15.4s\n" |
8133 | "fadd v31.4s, v31.4s, v15.4s\n" |
8134 | "7:\n" |
8135 | |
8136 | // Load the clamp_min, clamp_max bounds |
8137 | "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
8138 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
8139 | "dup v14.4s, w2\n" // clamp_min |
8140 | "dup v15.4s, w3\n" // clamp_max |
8141 | |
8142 | // Apply the clamp_min bound |
8143 | "fmax v16.4s, v16.4s, v14.4s\n" |
8144 | "fmax v17.4s, v17.4s, v14.4s\n" |
8145 | "fmax v18.4s, v18.4s, v14.4s\n" |
8146 | "fmax v19.4s, v19.4s, v14.4s\n" |
8147 | "fmax v20.4s, v20.4s, v14.4s\n" |
8148 | "fmax v21.4s, v21.4s, v14.4s\n" |
8149 | "fmax v22.4s, v22.4s, v14.4s\n" |
8150 | "fmax v23.4s, v23.4s, v14.4s\n" |
8151 | "fmax v24.4s, v24.4s, v14.4s\n" |
8152 | "fmax v25.4s, v25.4s, v14.4s\n" |
8153 | "fmax v26.4s, v26.4s, v14.4s\n" |
8154 | "fmax v27.4s, v27.4s, v14.4s\n" |
8155 | "fmax v28.4s, v28.4s, v14.4s\n" |
8156 | "fmax v29.4s, v29.4s, v14.4s\n" |
8157 | "fmax v30.4s, v30.4s, v14.4s\n" |
8158 | "fmax v31.4s, v31.4s, v14.4s\n" |
8159 | |
8160 | // Apply the clamp_max bound |
8161 | "fmin v16.4s, v16.4s, v15.4s\n" |
8162 | "fmin v17.4s, v17.4s, v15.4s\n" |
8163 | "fmin v18.4s, v18.4s, v15.4s\n" |
8164 | "fmin v19.4s, v19.4s, v15.4s\n" |
8165 | "fmin v20.4s, v20.4s, v15.4s\n" |
8166 | "fmin v21.4s, v21.4s, v15.4s\n" |
8167 | "fmin v22.4s, v22.4s, v15.4s\n" |
8168 | "fmin v23.4s, v23.4s, v15.4s\n" |
8169 | "fmin v24.4s, v24.4s, v15.4s\n" |
8170 | "fmin v25.4s, v25.4s, v15.4s\n" |
8171 | "fmin v26.4s, v26.4s, v15.4s\n" |
8172 | "fmin v27.4s, v27.4s, v15.4s\n" |
8173 | "fmin v28.4s, v28.4s, v15.4s\n" |
8174 | "fmin v29.4s, v29.4s, v15.4s\n" |
8175 | "fmin v30.4s, v30.4s, v15.4s\n" |
8176 | "fmin v31.4s, v31.4s, v15.4s\n" |
8177 | |
// Compute how much of the 8x8 block of destination float values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, some 8x8 blocks along the boundaries do not fit entirely.
8183 | "sub w1, %w[dst_rows], %w[row]\n" |
8184 | "sub w2, %w[dst_cols], %w[col]\n" |
8185 | "mov w3, #8\n" |
8186 | "cmp w1, #8\n" |
8187 | // Compute w1 = how many rows of the 8x8 block fit |
8188 | "csel w1, w1, w3, le\n" |
8189 | "cmp w2, #8\n" |
8190 | // Compute w2 = how many cols of the 8x8 block fit |
8191 | "csel w2, w2, w3, le\n" |
8192 | |
8193 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
8194 | "cmp w1, w3\n" |
8195 | "ccmp w2, w3, 0, eq\n" |
8196 | // Yes, all of the 8x8 block fits, go to fast path. |
8197 | "beq 30f\n" |
8198 | // Not all of the 8x8 block fits. |
8199 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
8200 | "mov x3, %[dst_tmp_buf]\n" |
8201 | "mov x4, #32\n" |
8202 | "b 31f\n" |
8203 | "30:\n" |
8204 | // Yes, all of the 8x8 block fits. |
8205 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
8206 | "mov x3, %[dst_ptr]\n" |
8207 | "mov x4, x11\n" |
8208 | "31:\n" |
8209 | |
// Write our float values to the destination described by
8211 | // (x3 address, x4 stride). |
8212 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8213 | "str q16, [x3, #0]\n" |
8214 | "str q17, [x3, #16]\n" |
8215 | "add x3, x3, x4\n" |
8216 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8217 | RUY_MAKE_ZERO(v16) |
8218 | RUY_MAKE_ZERO(v17) |
8219 | "str q18, [x3, #0]\n" |
8220 | "str q19, [x3, #16]\n" |
8221 | "add x3, x3, x4\n" |
8222 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8223 | RUY_MAKE_ZERO(v18) |
8224 | RUY_MAKE_ZERO(v19) |
8225 | "str q20, [x3, #0]\n" |
8226 | "str q21, [x3, #16]\n" |
8227 | "add x3, x3, x4\n" |
8228 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8229 | RUY_MAKE_ZERO(v20) |
8230 | RUY_MAKE_ZERO(v21) |
8231 | "str q22, [x3, #0]\n" |
8232 | "str q23, [x3, #16]\n" |
8233 | "add x3, x3, x4\n" |
8234 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8235 | RUY_MAKE_ZERO(v22) |
8236 | RUY_MAKE_ZERO(v23) |
8237 | "str q24, [x3, #0]\n" |
8238 | "str q25, [x3, #16]\n" |
8239 | "add x3, x3, x4\n" |
8240 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8241 | RUY_MAKE_ZERO(v24) |
8242 | RUY_MAKE_ZERO(v25) |
8243 | "str q26, [x3, #0]\n" |
8244 | "str q27, [x3, #16]\n" |
8245 | "add x3, x3, x4\n" |
8246 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8247 | RUY_MAKE_ZERO(v26) |
8248 | RUY_MAKE_ZERO(v27) |
8249 | "str q28, [x3, #0]\n" |
8250 | "str q29, [x3, #16]\n" |
8251 | "add x3, x3, x4\n" |
8252 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8253 | RUY_MAKE_ZERO(v28) |
8254 | RUY_MAKE_ZERO(v29) |
8255 | "str q30, [x3, #0]\n" |
8256 | "str q31, [x3, #16]\n" |
8257 | RUY_MAKE_ZERO(v30) |
8258 | RUY_MAKE_ZERO(v31) |
8259 | |
8260 | // If all of the 8x8 block fits, we just finished writing it to the |
8261 | // destination, so we skip the next part. |
8262 | "beq 41f\n" |
8263 | // Not all of the 8x8 block fits in the destination matrix. We just |
8264 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
8265 | // it to copy into the destination matrix the part that fits. |
8266 | "mov x3, %[dst_tmp_buf]\n" |
8267 | "mov x4, %[dst_ptr]\n" |
8268 | "mov w6, #0\n" |
8269 | "50:\n" |
8270 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
8271 | "mov w5, #0\n" |
8272 | "51:\n" |
8273 | "ldr w7, [x3, x5, lsl #2]\n" |
8274 | "str w7, [x4, x5, lsl #2]\n" |
8275 | "add w5, w5, #1\n" |
8276 | "cmp w5, w1\n" |
8277 | "blt 51b\n" |
8278 | "add w6, w6, #1\n" |
8279 | "add x3, x3, #32\n" |
8280 | "add x4, x4, x11\n" |
8281 | "cmp w6, w2\n" |
8282 | "blt 50b\n" |
8283 | "41:\n" |
8284 | "add %[dst_ptr], %[dst_ptr], #32\n" |
8285 | // At this point we have completely finished writing values to the |
8286 | // destination matrix for the current block. |
8287 | |
8288 | // Reload some params --- we had used x5 -- x7 for a few other things |
8289 | // since the last time we had loaded them. |
8290 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
8291 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
8292 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
8293 | |
8294 | // Move to the next block of the destination matrix, for the next iter |
8295 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
8296 | // been updated earlier. |
8297 | // Have we reached the end row? |
8298 | "cmp %w[row], w7\n" |
8299 | "beq 20f\n" // yes, end row. |
8300 | // Not end row. Move to the next row. |
8301 | "add %w[row], %w[row], #8\n" |
8302 | "b 21f\n" |
8303 | "20:\n" |
8304 | // Was already at end row. |
8305 | "mov %w[row], w6\n" // Move back to first row. |
8306 | "add %w[col], %w[col], #8\n" // Move to the next column. |
8307 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
8308 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
8309 | "21:\n" |
8310 | |
8311 | // Main loop exit condition: have we hit the end column? |
8312 | "cmp %w[col], w8\n" |
8313 | |
8314 | // w1 is the number of levels of depth that we have already loaded |
8315 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
8316 | // above, this is currently 1. |
8317 | "mov w1, #1\n" |
8318 | |
8319 | "ble 1b\n" |
8320 | |
8321 | // clang-format on |
8322 | |
8323 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
8324 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
8325 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
8326 | : [ params ] "r" (¶ms), [dst_rows] "r" (params.dst_rows), |
8327 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf) |
8328 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
8329 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
8330 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
8331 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
8332 | } |
8333 | |
8334 | // A fork of the standard float kernel where we omit the manual loop unrolling |
8335 | // to recover performance on the X1. For now, the X1 core is the only CPU that |
8336 | // uses this kernel. |
8337 | void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params) { |
8338 | CheckOffsetsInKernelParamsFloat(params); |
8339 | profiler::ScopeLabel label("Kernel (kNeon) X1" ); |
8340 | |
8341 | const float* lhs_col_ptr = params.lhs_base_ptr; |
8342 | const float* rhs_col_ptr = params.rhs_base_ptr; |
8343 | const float* lhs_ptr = lhs_col_ptr; |
8344 | const float* rhs_ptr = rhs_col_ptr; |
8345 | float* dst_col_ptr = params.dst_base_ptr; |
8346 | float* dst_ptr = dst_col_ptr; |
8347 | int row = params.start_row; |
8348 | int col = params.start_col; |
8349 | |
8350 | // The asm kernel below has the following NEON register allocation: |
8351 | // |
8352 | // v16 -- v31 are accumulators. |
// During accumulation, v0 -- v4 are used to load data from LHS and RHS.
// v0 and v1 are used to load an 8x1 block of LHS, and v2 and v3 (with v4
// as a staging register for the next RHS half) are used to load a 1x8
// block of RHS, like this:
8356 | // |
8357 | // RHS 1x8 block |
8358 | // /-----------------------------------------| |
8359 | // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| |
8360 | // \-----------------------------------------/ |
8361 | // LHS 8x1 block |
8362 | // /---------------------\ /-----------------------------------------| |
8363 | // | v0.s[0] | |v16.s[0] ... v30.s[0]| |
8364 | // | ... | | ... ... | |
8365 | // | v0.s[3] | |v16.s[3] ... v30.s[3]| |
8366 | // | v1.s[0] | |v17.s[0] ... v31.s[0]| |
8367 | // | ... | | ... ... | |
8368 | // | v1.s[3] | |v17.s[3] ... v31.s[3]| |
8369 | // \---------------------/ \-----------------------------------------/ |
8370 | // accumulators 8x8 block |
8371 | // |
// Unlike the standard float kernel, this fork has no RUY_OPT_MAX_STREAMING
// 4x-unrolled part; as noted above, the manual loop unrolling is omitted
// here.
//
// v5 -- v7 are unused, and v8 -- v15 are used for loading parameters used
// for the post-accumulation part of the kernel.
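//
// For reference, here is a minimal scalar sketch of what this kernel
// computes for one 8x8 destination block (assuming the packed panel layout
// implied by the ld1 loads below, where each depth level contributes 8
// contiguous LHS floats and 8 contiguous RHS floats; names are
// illustrative, not actual ruy identifiers):
//
//   float acc[8][8] = {};                      // acc[col][row], v16 -- v31
//   for (int d = 0; d < depth; ++d) {
//     for (int col = 0; col < 8; ++col) {
//       for (int row = 0; row < 8; ++row) {
//         acc[col][row] += lhs_panel[8 * d + row] * rhs_panel[8 * d + col];
//       }
//     }
//   }
//   // ...followed by an optional bias addition (per row or per column,
//   // depending on the channel-dimension flag), clamping to
//   // [clamp_min, clamp_max], and the store of the 8x8 block.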
8379 | asm volatile( |
8380 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
8381 | |
8382 | // clang-format off |
8383 | |
8384 | // Load some parameters into registers. |
8385 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
8386 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
8387 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
8388 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
8389 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
8390 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
8391 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
8392 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
8393 | |
8394 | // Load the first 32 bytes of LHS and RHS data. |
8395 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
8396 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
8397 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
8398 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
8399 | |
8400 | // Clear accumulators. |
8401 | RUY_MAKE_ZERO(v16) |
8402 | RUY_MAKE_ZERO(v17) |
8403 | RUY_MAKE_ZERO(v18) |
8404 | RUY_MAKE_ZERO(v19) |
8405 | RUY_MAKE_ZERO(v20) |
8406 | RUY_MAKE_ZERO(v21) |
8407 | RUY_MAKE_ZERO(v22) |
8408 | RUY_MAKE_ZERO(v23) |
8409 | RUY_MAKE_ZERO(v24) |
8410 | RUY_MAKE_ZERO(v25) |
8411 | RUY_MAKE_ZERO(v26) |
8412 | RUY_MAKE_ZERO(v27) |
8413 | RUY_MAKE_ZERO(v28) |
8414 | RUY_MAKE_ZERO(v29) |
8415 | RUY_MAKE_ZERO(v30) |
8416 | RUY_MAKE_ZERO(v31) |
8417 | |
8418 | // w1 is the number of levels of depth that we have already loaded |
8419 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
8420 | // above, this is currently 1. |
8421 | "mov w1, #1\n" |
8422 | |
8423 | // Main loop of the whole GEMM, over rows and columns of the |
8424 | // destination matrix. |
8425 | "1:\n" |
8426 | |
8427 | "fmla v16.4s, v0.4s, v2.s[0]\n" |
8428 | "fmla v18.4s, v0.4s, v2.s[1]\n" |
8429 | "fmla v20.4s, v0.4s, v2.s[2]\n" |
8430 | "fmla v22.4s, v0.4s, v2.s[3]\n" |
8431 | |
8432 | // Accumulation loop |
8433 | "cmp w1, w12\n" |
8434 | "beq 79f\n" |
8435 | |
8436 | "2:\n" |
8437 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
8438 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
8439 | "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" |
8440 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
8441 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
8442 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
8443 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
8444 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
8445 | "add w1, w1, #1\n" |
8446 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
8447 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
8448 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
8449 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
8450 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
8451 | "cmp w1, w12\n" |
8452 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
8453 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
8454 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
8455 | "fmla v16.4s, v0.4s, v4.s[0]\n" |
8456 | "fmla v18.4s, v0.4s, v4.s[1]\n" |
8457 | "mov v2.16b, v4.16b\n" |
8458 | "fmla v20.4s, v0.4s, v4.s[2]\n" |
8459 | "fmla v22.4s, v0.4s, v4.s[3]\n" |
8460 | "blt 2b\n" |
8461 | |
8462 | "79:\n" |
8463 | |
8464 | // End of the inner loop on depth. Now perform the remaining |
8465 | // multiply-adds of the last level of depth, for which the LHS |
8466 | // and RHS data is already loaded. |
8467 | |
8468 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
8469 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
8470 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
8471 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
8472 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
8473 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
8474 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
8475 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
8476 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
8477 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
8478 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
8479 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
8480 | |
// End of accumulation. The registers v16 -- v31 contain the final
// float accumulator values of the current 8x8 destination block.
// We now have to compute the final float values from these
// accumulators, and advance to the next 8x8 block. We intertwine
8485 | // these two aspects whenever possible for optimal pipelining, both |
8486 | // at the data flow level (prefetch data for next block as early as |
8487 | // possible) and instruction pipelining level (some of the next-block |
8488 | // work can dual-issue with some of the final work on the current |
8489 | // block). |
8490 | |
8491 | // Logic to advance to the next block in preparation for the next |
8492 | // iteration of the main loop. For now, we only want to compute |
8493 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
8494 | // not yet ready to update the values of row and col, as we still need |
8495 | // the current values for the rest of the work on the current block. |
8496 | |
8497 | "cmp %w[row], w7\n" // Have we finished the last row? |
8498 | "bge 4f\n" // If finished last row, go to 4 |
8499 | // Not finished last row: then advance to next row. |
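// (x9 holds lhs_stride in bytes; adding 8 * lhs_stride moves lhs_col_ptr
// to the next block of 8 rows of packed LHS data.)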
8500 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
8501 | "b 5f\n" |
8502 | "4:\n" // Finished last row... |
8503 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
8504 | // Now we need to advance to the next column. If we already |
8505 | // finished the last column, then in principle we are done, however |
8506 | // we can't just return here, as we need to allow the end work of the |
8507 | // current block to complete. The good news is that at this point it |
8508 | // doesn't matter what data we load for the next column, since |
8509 | // we will exit from the main loop below before actually storing |
8510 | // anything computed from that data. |
8511 | "cmp %w[col], w8\n" // Have we finished the last column? |
8512 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
8513 | // Not finished last column: then advance to next column. |
8514 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
8515 | "5:\n" |
8516 | |
8517 | // Set the LHS and RHS data pointers to the start of the columns just |
8518 | // computed. |
8519 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
8520 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
8521 | |
8522 | // Load some parameters needed for the end work on current block. |
8523 | "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
8524 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
8525 | |
8526 | // Determine the channel index. |
8527 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
8528 | "csel w3, %w[row], %w[col], eq\n" |
8529 | |
8530 | // Offset the bias pointer as needed given the current row, col. |
8531 | "add x5, x1, x3, lsl #2\n" |
8532 | |
8533 | // If there is no bias, use no offset, just address the passed zero |
8534 | // data. |
8535 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
8536 | "csel x1, x1, x5, eq\n" |
8537 | |
8538 | // Load 8 bias values. |
8539 | "ld1 {v14.4s}, [x1], #16\n" |
8540 | "ld1 {v15.4s}, [x1]\n" |
8541 | |
8542 | // Now that we know what LHS and RHS data the next iteration of the |
8543 | // main loop will need to load, we start loading the first 32 bytes of |
8544 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
8545 | // in the rest of the work on the current block. |
8546 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
8547 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
8548 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
8549 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
8550 | |
8551 | // Perform the bias-addition. |
8552 | // Jump based on channel dimension. |
8553 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
8554 | "bne 6f\n" |
8555 | // Case where channels are rows |
8556 | "fadd v16.4s, v16.4s, v14.4s\n" |
8557 | "fadd v17.4s, v17.4s, v15.4s\n" |
8558 | "fadd v18.4s, v18.4s, v14.4s\n" |
8559 | "fadd v19.4s, v19.4s, v15.4s\n" |
8560 | "fadd v20.4s, v20.4s, v14.4s\n" |
8561 | "fadd v21.4s, v21.4s, v15.4s\n" |
8562 | "fadd v22.4s, v22.4s, v14.4s\n" |
8563 | "fadd v23.4s, v23.4s, v15.4s\n" |
8564 | "fadd v24.4s, v24.4s, v14.4s\n" |
8565 | "fadd v25.4s, v25.4s, v15.4s\n" |
8566 | "fadd v26.4s, v26.4s, v14.4s\n" |
8567 | "fadd v27.4s, v27.4s, v15.4s\n" |
8568 | "fadd v28.4s, v28.4s, v14.4s\n" |
8569 | "fadd v29.4s, v29.4s, v15.4s\n" |
8570 | "fadd v30.4s, v30.4s, v14.4s\n" |
8571 | "fadd v31.4s, v31.4s, v15.4s\n" |
8572 | "b 7f\n" |
8573 | |
8574 | "6:\n" |
8575 | // Case where channels are columns |
8576 | "dup v8.4s, v14.s[0]\n" |
8577 | "dup v9.4s, v14.s[1]\n" |
8578 | "dup v10.4s, v14.s[2]\n" |
8579 | "dup v11.4s, v14.s[3]\n" |
8580 | "dup v12.4s, v15.s[0]\n" |
8581 | "dup v13.4s, v15.s[1]\n" |
8582 | "dup v14.4s, v15.s[2]\n" |
8583 | "dup v15.4s, v15.s[3]\n" |
8584 | "fadd v16.4s, v16.4s, v8.4s\n" |
8585 | "fadd v17.4s, v17.4s, v8.4s\n" |
8586 | "fadd v18.4s, v18.4s, v9.4s\n" |
8587 | "fadd v19.4s, v19.4s, v9.4s\n" |
8588 | "fadd v20.4s, v20.4s, v10.4s\n" |
8589 | "fadd v21.4s, v21.4s, v10.4s\n" |
8590 | "fadd v22.4s, v22.4s, v11.4s\n" |
8591 | "fadd v23.4s, v23.4s, v11.4s\n" |
8592 | "fadd v24.4s, v24.4s, v12.4s\n" |
8593 | "fadd v25.4s, v25.4s, v12.4s\n" |
8594 | "fadd v26.4s, v26.4s, v13.4s\n" |
8595 | "fadd v27.4s, v27.4s, v13.4s\n" |
8596 | "fadd v28.4s, v28.4s, v14.4s\n" |
8597 | "fadd v29.4s, v29.4s, v14.4s\n" |
8598 | "fadd v30.4s, v30.4s, v15.4s\n" |
8599 | "fadd v31.4s, v31.4s, v15.4s\n" |
8600 | "7:\n" |
8601 | |
8602 | // Load the clamp_min, clamp_max bounds |
8603 | "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
8604 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
8605 | "dup v14.4s, w2\n" // clamp_min |
8606 | "dup v15.4s, w3\n" // clamp_max |
8607 | |
8608 | // Apply the clamp_min bound |
8609 | "fmax v16.4s, v16.4s, v14.4s\n" |
8610 | "fmax v17.4s, v17.4s, v14.4s\n" |
8611 | "fmax v18.4s, v18.4s, v14.4s\n" |
8612 | "fmax v19.4s, v19.4s, v14.4s\n" |
8613 | "fmax v20.4s, v20.4s, v14.4s\n" |
8614 | "fmax v21.4s, v21.4s, v14.4s\n" |
8615 | "fmax v22.4s, v22.4s, v14.4s\n" |
8616 | "fmax v23.4s, v23.4s, v14.4s\n" |
8617 | "fmax v24.4s, v24.4s, v14.4s\n" |
8618 | "fmax v25.4s, v25.4s, v14.4s\n" |
8619 | "fmax v26.4s, v26.4s, v14.4s\n" |
8620 | "fmax v27.4s, v27.4s, v14.4s\n" |
8621 | "fmax v28.4s, v28.4s, v14.4s\n" |
8622 | "fmax v29.4s, v29.4s, v14.4s\n" |
8623 | "fmax v30.4s, v30.4s, v14.4s\n" |
8624 | "fmax v31.4s, v31.4s, v14.4s\n" |
8625 | |
8626 | // Apply the clamp_max bound |
8627 | "fmin v16.4s, v16.4s, v15.4s\n" |
8628 | "fmin v17.4s, v17.4s, v15.4s\n" |
8629 | "fmin v18.4s, v18.4s, v15.4s\n" |
8630 | "fmin v19.4s, v19.4s, v15.4s\n" |
8631 | "fmin v20.4s, v20.4s, v15.4s\n" |
8632 | "fmin v21.4s, v21.4s, v15.4s\n" |
8633 | "fmin v22.4s, v22.4s, v15.4s\n" |
8634 | "fmin v23.4s, v23.4s, v15.4s\n" |
8635 | "fmin v24.4s, v24.4s, v15.4s\n" |
8636 | "fmin v25.4s, v25.4s, v15.4s\n" |
8637 | "fmin v26.4s, v26.4s, v15.4s\n" |
8638 | "fmin v27.4s, v27.4s, v15.4s\n" |
8639 | "fmin v28.4s, v28.4s, v15.4s\n" |
8640 | "fmin v29.4s, v29.4s, v15.4s\n" |
8641 | "fmin v30.4s, v30.4s, v15.4s\n" |
8642 | "fmin v31.4s, v31.4s, v15.4s\n" |
8643 | |
// Compute how much of the 8x8 block of destination float values we have
// just computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
8649 | "sub w1, %w[dst_rows], %w[row]\n" |
8650 | "sub w2, %w[dst_cols], %w[col]\n" |
8651 | "mov w3, #8\n" |
8652 | "cmp w1, #8\n" |
8653 | // Compute w1 = how many rows of the 8x8 block fit |
8654 | "csel w1, w1, w3, le\n" |
8655 | "cmp w2, #8\n" |
8656 | // Compute w2 = how many cols of the 8x8 block fit |
8657 | "csel w2, w2, w3, le\n" |
8658 | |
8659 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
8660 | "cmp w1, w3\n" |
8661 | "ccmp w2, w3, 0, eq\n" |
8662 | // Yes, all of the 8x8 block fits, go to fast path. |
8663 | "beq 30f\n" |
8664 | // Not all of the 8x8 block fits. |
8665 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
8666 | "mov x3, %[dst_tmp_buf]\n" |
8667 | "mov x4, #32\n" |
8668 | "b 31f\n" |
8669 | "30:\n" |
8670 | // Yes, all of the 8x8 block fits. |
8671 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
8672 | "mov x3, %[dst_ptr]\n" |
8673 | "mov x4, x11\n" |
8674 | "31:\n" |
8675 | |
// Write our float values to the destination described by
8677 | // (x3 address, x4 stride). |
8678 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8679 | "str q16, [x3, #0]\n" |
8680 | "str q17, [x3, #16]\n" |
8681 | "add x3, x3, x4\n" |
8682 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8683 | RUY_MAKE_ZERO(v16) |
8684 | RUY_MAKE_ZERO(v17) |
8685 | "str q18, [x3, #0]\n" |
8686 | "str q19, [x3, #16]\n" |
8687 | "add x3, x3, x4\n" |
8688 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8689 | RUY_MAKE_ZERO(v18) |
8690 | RUY_MAKE_ZERO(v19) |
8691 | "str q20, [x3, #0]\n" |
8692 | "str q21, [x3, #16]\n" |
8693 | "add x3, x3, x4\n" |
8694 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8695 | RUY_MAKE_ZERO(v20) |
8696 | RUY_MAKE_ZERO(v21) |
8697 | "str q22, [x3, #0]\n" |
8698 | "str q23, [x3, #16]\n" |
8699 | "add x3, x3, x4\n" |
8700 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8701 | RUY_MAKE_ZERO(v22) |
8702 | RUY_MAKE_ZERO(v23) |
8703 | "str q24, [x3, #0]\n" |
8704 | "str q25, [x3, #16]\n" |
8705 | "add x3, x3, x4\n" |
8706 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8707 | RUY_MAKE_ZERO(v24) |
8708 | RUY_MAKE_ZERO(v25) |
8709 | "str q26, [x3, #0]\n" |
8710 | "str q27, [x3, #16]\n" |
8711 | "add x3, x3, x4\n" |
8712 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8713 | RUY_MAKE_ZERO(v26) |
8714 | RUY_MAKE_ZERO(v27) |
8715 | "str q28, [x3, #0]\n" |
8716 | "str q29, [x3, #16]\n" |
8717 | "add x3, x3, x4\n" |
8718 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
8719 | RUY_MAKE_ZERO(v28) |
8720 | RUY_MAKE_ZERO(v29) |
8721 | "str q30, [x3, #0]\n" |
8722 | "str q31, [x3, #16]\n" |
8723 | RUY_MAKE_ZERO(v30) |
8724 | RUY_MAKE_ZERO(v31) |
8725 | |
8726 | // If all of the 8x8 block fits, we just finished writing it to the |
8727 | // destination, so we skip the next part. |
8728 | "beq 41f\n" |
8729 | // Not all of the 8x8 block fits in the destination matrix. We just |
8730 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
8731 | // it to copy into the destination matrix the part that fits. |
8732 | "mov x3, %[dst_tmp_buf]\n" |
8733 | "mov x4, %[dst_ptr]\n" |
8734 | "mov w6, #0\n" |
8735 | "50:\n" |
8736 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
8737 | "mov w5, #0\n" |
8738 | "51:\n" |
8739 | "ldr w7, [x3, x5, lsl #2]\n" |
8740 | "str w7, [x4, x5, lsl #2]\n" |
8741 | "add w5, w5, #1\n" |
8742 | "cmp w5, w1\n" |
8743 | "blt 51b\n" |
8744 | "add w6, w6, #1\n" |
8745 | "add x3, x3, #32\n" |
8746 | "add x4, x4, x11\n" |
8747 | "cmp w6, w2\n" |
8748 | "blt 50b\n" |
8749 | "41:\n" |
8750 | "add %[dst_ptr], %[dst_ptr], #32\n" |
8751 | // At this point we have completely finished writing values to the |
8752 | // destination matrix for the current block. |
8753 | |
8754 | // Reload some params --- we had used x5 -- x7 for a few other things |
8755 | // since the last time we had loaded them. |
8756 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
8757 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
8758 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
8759 | |
8760 | // Move to the next block of the destination matrix, for the next iter |
8761 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
8762 | // been updated earlier. |
8763 | // Have we reached the end row? |
8764 | "cmp %w[row], w7\n" |
8765 | "beq 20f\n" // yes, end row. |
8766 | // Not end row. Move to the next row. |
8767 | "add %w[row], %w[row], #8\n" |
8768 | "b 21f\n" |
8769 | "20:\n" |
8770 | // Was already at end row. |
8771 | "mov %w[row], w6\n" // Move back to first row. |
8772 | "add %w[col], %w[col], #8\n" // Move to the next column. |
8773 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
8774 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
8775 | "21:\n" |
8776 | |
8777 | // Main loop exit condition: have we hit the end column? |
8778 | "cmp %w[col], w8\n" |
8779 | |
8780 | // w1 is the number of levels of depth that we have already loaded |
8781 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
8782 | // above, this is currently 1. |
8783 | "mov w1, #1\n" |
8784 | |
8785 | "ble 1b\n" |
8786 | |
8787 | // clang-format on |
8788 | |
8789 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
8790 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
8791 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
8792 | : [ params ] "r" (¶ms), [dst_rows] "r" (params.dst_rows), |
8793 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf) |
8794 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
8795 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
8796 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
8797 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
8798 | } |
8799 | |
8800 | // Variant of KernelFloatNeon tuned for in-order CPUs that do not |
8801 | // support dotprod (while dotprod by itself is not relevant to floating-point, |
8802 | // this additional bit of information that we have about the target happens to |
8803 | // be useful here). |
8804 | // |
8805 | // So a typical target CPU here would be ARM Cortex-A53 or the original |
8806 | // Cortex-A55. |
8807 | // |
8808 | // This kernel is similar to and inspired by gemmlowp's |
// NEON_64bit_GEMM_Float32_WithScalar_A53,
8810 | // which was contributed by David Mansell with very helpful |
8811 | // comments. Specifically, see this comment about tuning for Cortex-A53: |
8812 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 |
8813 | void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params) { |
8814 | profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)" ); |
8815 | |
8816 | CheckOffsetsInKernelParamsFloat(params); |
8817 | |
8818 | const float* lhs_col_ptr = params.lhs_base_ptr; |
8819 | const float* rhs_col_ptr = params.rhs_base_ptr; |
8820 | const float* lhs_ptr = lhs_col_ptr; |
8821 | const float* rhs_ptr = rhs_col_ptr; |
8822 | float* dst_col_ptr = params.dst_base_ptr; |
8823 | float* dst_ptr = dst_col_ptr; |
8824 | int row = params.start_row; |
8825 | int col = params.start_col; |
8826 | |
8827 | // The asm kernel below has the following NEON register allocation: |
8828 | // |
8829 | // v16 -- v31 are accumulators. |
// During accumulation, v0 -- v4 are used to load data from LHS and RHS.
8831 | // |
8832 | // RHS 1x8 block |
8833 | // /-----------------------------------------| |
8834 | // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| |
8835 | // \-----------------------------------------/ |
8836 | // LHS 8x1 block |
8837 | // /---------------------\ /-----------------------------------------| |
8838 | // | v0.s[0] | |v16.s[0] ... v30.s[0]| |
8839 | // | ... | | ... ... | |
8840 | // | v0.s[3] | |v16.s[3] ... v30.s[3]| |
8841 | // | v1.s[0] | |v17.s[0] ... v31.s[0]| |
8842 | // | ... | | ... ... | |
8843 | // | v1.s[3] | |v17.s[3] ... v31.s[3]| |
8844 | // \---------------------/ \-----------------------------------------/ |
8845 | // accumulators 8x8 block |
8846 | // |
8847 | // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because |
8848 | // we did not observe a benefit of such partial unrolling on in-order CPUs. |
8849 | // |
// v5 -- v7 are unused, and v8 -- v15 are used for loading parameters used
// for the post-accumulation part of the kernel.
8852 | asm volatile( |
8853 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
8854 | |
8855 | // clang-format off |
8856 | |
8857 | // Load some parameters into registers. |
8858 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
8859 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
8860 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
8861 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
8862 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
8863 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
8864 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
8865 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
8866 | |
8867 | |
8868 | // Clear accumulators. |
8869 | RUY_MAKE_ZERO(v16) |
8870 | // Load the first 32 bytes of LHS and RHS data. |
8871 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
8872 | RUY_MAKE_ZERO(v17) |
8873 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
8874 | RUY_MAKE_ZERO(v18) |
8875 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
8876 | RUY_MAKE_ZERO(v19) |
8877 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
8878 | RUY_MAKE_ZERO(v20) |
8879 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n" ) |
8880 | RUY_MAKE_ZERO(v21) |
8881 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n" ) |
8882 | RUY_MAKE_ZERO(v22) |
8883 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n" ) |
8884 | RUY_MAKE_ZERO(v23) |
8885 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n" ) |
8886 | RUY_MAKE_ZERO(v24) |
8887 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n" ) |
8888 | RUY_MAKE_ZERO(v25) |
8889 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n" ) |
8890 | RUY_MAKE_ZERO(v26) |
8891 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n" ) |
8892 | RUY_MAKE_ZERO(v27) |
8893 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n" ) |
8894 | RUY_MAKE_ZERO(v28) |
8895 | RUY_MAKE_ZERO(v29) |
8896 | RUY_MAKE_ZERO(v30) |
8897 | RUY_MAKE_ZERO(v31) |
8898 | |
8899 | // w1 is the number of levels of depth that remain to load |
8900 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
8901 | // above, this is currently depth - 1. |
8902 | "sub w1, w12, #1\n" |
8903 | |
8904 | // Main loop of the whole GEMM, over rows and columns of the |
8905 | // destination matrix. |
8906 | "1:\n" |
8907 | |
8908 | "cmp w1, #0\n" |
8909 | "fmla v16.4s, v0.4s, v2.s[0]\n" |
8910 | "fmla v18.4s, v0.4s, v2.s[1]\n" |
8911 | "fmla v20.4s, v0.4s, v2.s[2]\n" |
8912 | "fmla v22.4s, v0.4s, v2.s[3]\n" |
8913 | |
8914 | // Accumulation loop |
8915 | "beq 79f\n" |
8916 | |
8917 | "2:\n" |
8918 | |
8919 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
8920 | "ldr x2, [%[lhs_ptr], #8]\n" |
8921 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
8922 | "ldr x3, [%[lhs_ptr], #24]\n" |
8923 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
8924 | "ldr x5, [%[rhs_ptr], #24]\n" |
8925 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
8926 | "ldr x4, [%[rhs_ptr], #8]\n" |
8927 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
8928 | "subs w1, w1, #1\n" |
8929 | "ldr d0, [%[lhs_ptr]], #32\n" |
8930 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
8931 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
8932 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
8933 | "ins v0.d[1], x2\n" |
8934 | "ldr d3, [%[rhs_ptr], #16]\n" |
8935 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
8936 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
8937 | "ins v3.d[1], x5\n" |
8938 | "ldr d4, [%[rhs_ptr]], #32\n" |
8939 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
8940 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
8941 | "fmla v16.4s, v0.4s, v4.s[0]\n" |
8942 | "ins v4.d[1], x4\n" |
8943 | "ldr d1, [%[lhs_ptr], #-16]\n" |
8944 | "fmla v18.4s, v0.4s, v4.s[1]\n" |
8945 | "fmla v20.4s, v0.4s, v4.s[2]\n" |
8946 | "ins v1.d[1], x3\n" |
8947 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n" ) |
8948 | "mov v2.16b, v4.16b\n" |
8949 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n" ) |
8950 | "fmla v22.4s, v0.4s, v4.s[3]\n" |
8951 | "bne 2b\n" |
8952 | |
8953 | "79:\n" |
8954 | |
8955 | // End of the inner loop on depth. Now perform the remaining |
8956 | // multiply-adds of the last level of depth, for which the LHS |
8957 | // and RHS data is already loaded. |
8958 | |
8959 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
8960 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
8961 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
8962 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
8963 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
8964 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
8965 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
8966 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
8967 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
8968 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
8969 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
8970 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
8971 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
8972 | |
// End of accumulation. The registers v16 -- v31 contain the final
// float accumulator values of the current 8x8 destination block.
// We now have to compute the final float values from these
// accumulators, and advance to the next 8x8 block. We intertwine
8977 | // these two aspects whenever possible for optimal pipelining, both |
8978 | // at the data flow level (prefetch data for next block as early as |
8979 | // possible) and instruction pipelining level (some of the next-block |
8980 | // work can dual-issue with some of the final work on the current |
8981 | // block). |
8982 | |
8983 | // Logic to advance to the next block in preparation for the next |
8984 | // iteration of the main loop. For now, we only want to compute |
8985 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
8986 | // not yet ready to update the values of row and col, as we still need |
8987 | // the current values for the rest of the work on the current block. |
8988 | |
8989 | "cmp %w[row], w7\n" // Have we finished the last row? |
8990 | "bge 4f\n" // If finished last row, go to 4 |
8991 | // Not finished last row: then advance to next row. |
8992 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
8993 | "b 5f\n" |
8994 | "4:\n" // Finished last row... |
8995 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
8996 | // Now we need to advance to the next column. If we already |
8997 | // finished the last column, then in principle we are done, however |
8998 | // we can't just return here, as we need to allow the end work of the |
8999 | // current block to complete. The good news is that at this point it |
9000 | // doesn't matter what data we load for the next column, since |
9001 | // we will exit from the main loop below before actually storing |
9002 | // anything computed from that data. |
9003 | "cmp %w[col], w8\n" // Have we finished the last column? |
9004 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
9005 | // Not finished last column: then advance to next column. |
9006 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
9007 | "5:\n" |
9008 | |
9009 | // Set the LHS and RHS data pointers to the start of the columns just |
9010 | // computed. |
9011 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
9012 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
9013 | |
9014 | // Load some parameters needed for the end work on current block. |
9015 | "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
9016 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
9017 | |
9018 | // Determine the channel index. |
9019 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
9020 | "csel w3, %w[row], %w[col], eq\n" |
9021 | |
9022 | // Offset the bias pointer as needed given the current row, col. |
9023 | "add x5, x1, x3, lsl #2\n" |
9024 | |
9025 | // If there is no bias, use no offset, just address the passed zero |
9026 | // data. |
9027 | |
9028 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
9029 | "csel x1, x1, x5, eq\n" |
9030 | |
9031 | // Load 8 bias values. |
9032 | "ld1 {v14.4s}, [x1], #16\n" |
9033 | "ld1 {v15.4s}, [x1]\n" |
9034 | |
9035 | // Now that we know what LHS and RHS data the next iteration of the |
9036 | // main loop will need to load, we start loading the first 32 bytes of |
9037 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
9038 | // in the rest of the work on the current block. |
9039 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
9040 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
9041 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
9042 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
9043 | |
9044 | // Perform the bias-addition. |
9045 | // Jump based on channel dimension. |
9046 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
9047 | "bne 6f\n" |
9048 | // Case where channels are rows |
9049 | "fadd v16.4s, v16.4s, v14.4s\n" |
9050 | "fadd v17.4s, v17.4s, v15.4s\n" |
9051 | "fadd v18.4s, v18.4s, v14.4s\n" |
9052 | "fadd v19.4s, v19.4s, v15.4s\n" |
9053 | "fadd v20.4s, v20.4s, v14.4s\n" |
9054 | "fadd v21.4s, v21.4s, v15.4s\n" |
9055 | "fadd v22.4s, v22.4s, v14.4s\n" |
9056 | "fadd v23.4s, v23.4s, v15.4s\n" |
9057 | "fadd v24.4s, v24.4s, v14.4s\n" |
9058 | "fadd v25.4s, v25.4s, v15.4s\n" |
9059 | "fadd v26.4s, v26.4s, v14.4s\n" |
9060 | "fadd v27.4s, v27.4s, v15.4s\n" |
9061 | "fadd v28.4s, v28.4s, v14.4s\n" |
9062 | "fadd v29.4s, v29.4s, v15.4s\n" |
9063 | "fadd v30.4s, v30.4s, v14.4s\n" |
9064 | "fadd v31.4s, v31.4s, v15.4s\n" |
9065 | "b 7f\n" |
9066 | |
9067 | "6:\n" |
9068 | // Case where channels are columns |
9069 | "dup v8.4s, v14.s[0]\n" |
9070 | "dup v9.4s, v14.s[1]\n" |
9071 | "fadd v16.4s, v16.4s, v8.4s\n" |
9072 | "dup v10.4s, v14.s[2]\n" |
9073 | "fadd v17.4s, v17.4s, v8.4s\n" |
9074 | "dup v11.4s, v14.s[3]\n" |
9075 | "fadd v18.4s, v18.4s, v9.4s\n" |
9076 | "dup v12.4s, v15.s[0]\n" |
9077 | "fadd v19.4s, v19.4s, v9.4s\n" |
9078 | "dup v13.4s, v15.s[1]\n" |
9079 | "fadd v20.4s, v20.4s, v10.4s\n" |
9080 | "dup v14.4s, v15.s[2]\n" |
9081 | "fadd v21.4s, v21.4s, v10.4s\n" |
9082 | "dup v15.4s, v15.s[3]\n" |
9083 | "fadd v22.4s, v22.4s, v11.4s\n" |
9084 | "fadd v23.4s, v23.4s, v11.4s\n" |
9085 | "fadd v24.4s, v24.4s, v12.4s\n" |
9086 | "fadd v25.4s, v25.4s, v12.4s\n" |
9087 | "fadd v26.4s, v26.4s, v13.4s\n" |
9088 | "fadd v27.4s, v27.4s, v13.4s\n" |
9089 | "fadd v28.4s, v28.4s, v14.4s\n" |
9090 | "fadd v29.4s, v29.4s, v14.4s\n" |
9091 | "fadd v30.4s, v30.4s, v15.4s\n" |
9092 | "fadd v31.4s, v31.4s, v15.4s\n" |
9093 | "7:\n" |
9094 | |
9095 | // Load the clamp_min, clamp_max bounds |
9096 | "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
9097 | "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
9098 | "dup v14.4s, w2\n" // clamp_min |
9099 | "dup v15.4s, w3\n" // clamp_max |
9100 | |
9101 | // Apply the clamp_min bound |
9102 | "fmax v16.4s, v16.4s, v14.4s\n" |
9103 | "fmax v17.4s, v17.4s, v14.4s\n" |
9104 | "fmax v18.4s, v18.4s, v14.4s\n" |
9105 | "fmax v19.4s, v19.4s, v14.4s\n" |
9106 | "fmax v20.4s, v20.4s, v14.4s\n" |
9107 | "fmax v21.4s, v21.4s, v14.4s\n" |
9108 | "fmax v22.4s, v22.4s, v14.4s\n" |
9109 | "fmax v23.4s, v23.4s, v14.4s\n" |
9110 | "fmax v24.4s, v24.4s, v14.4s\n" |
9111 | "fmax v25.4s, v25.4s, v14.4s\n" |
9112 | "fmax v26.4s, v26.4s, v14.4s\n" |
9113 | "fmax v27.4s, v27.4s, v14.4s\n" |
9114 | "fmax v28.4s, v28.4s, v14.4s\n" |
9115 | "fmax v29.4s, v29.4s, v14.4s\n" |
9116 | "fmax v30.4s, v30.4s, v14.4s\n" |
9117 | "fmax v31.4s, v31.4s, v14.4s\n" |
9118 | |
9119 | // Apply the clamp_max bound |
9120 | "fmin v16.4s, v16.4s, v15.4s\n" |
9121 | "fmin v17.4s, v17.4s, v15.4s\n" |
9122 | "fmin v18.4s, v18.4s, v15.4s\n" |
9123 | "fmin v19.4s, v19.4s, v15.4s\n" |
9124 | "fmin v20.4s, v20.4s, v15.4s\n" |
9125 | "fmin v21.4s, v21.4s, v15.4s\n" |
9126 | "fmin v22.4s, v22.4s, v15.4s\n" |
9127 | "fmin v23.4s, v23.4s, v15.4s\n" |
9128 | "fmin v24.4s, v24.4s, v15.4s\n" |
9129 | "fmin v25.4s, v25.4s, v15.4s\n" |
9130 | "fmin v26.4s, v26.4s, v15.4s\n" |
9131 | "fmin v27.4s, v27.4s, v15.4s\n" |
9132 | "fmin v28.4s, v28.4s, v15.4s\n" |
9133 | "fmin v29.4s, v29.4s, v15.4s\n" |
9134 | "fmin v30.4s, v30.4s, v15.4s\n" |
9135 | "fmin v31.4s, v31.4s, v15.4s\n" |
9136 | |
// Compute how much of the 8x8 block of destination float values we have
// just computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 8x8, there are some 8x8 blocks along the boundaries that do
// not fit entirely.
9142 | "sub w1, %w[dst_rows], %w[row]\n" |
9143 | "sub w2, %w[dst_cols], %w[col]\n" |
9144 | "mov w3, #8\n" |
9145 | "cmp w1, #8\n" |
9146 | // Compute w1 = how many rows of the 8x8 block fit |
9147 | "csel w1, w1, w3, le\n" |
9148 | "cmp w2, #8\n" |
9149 | // Compute w2 = how many cols of the 8x8 block fit |
9150 | "csel w2, w2, w3, le\n" |
9151 | |
9152 | // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. |
9153 | "cmp w1, w3\n" |
9154 | "ccmp w2, w3, 0, eq\n" |
9155 | // Yes, all of the 8x8 block fits, go to fast path. |
9156 | "beq 30f\n" |
9157 | // Not all of the 8x8 block fits. |
9158 | // Set (x3 address, x4 stride) to write to dst_tmp_buf |
9159 | "mov x3, %[dst_tmp_buf]\n" |
9160 | "mov x4, #32\n" |
9161 | "b 31f\n" |
9162 | "30:\n" |
9163 | // Yes, all of the 8x8 block fits. |
9164 | // Set (x3 address, x4 stride) to write directly to destination matrix. |
9165 | "mov x3, %[dst_ptr]\n" |
9166 | "mov x4, x11\n" |
9167 | "31:\n" |
9168 | |
// Write our float values to the destination described by
9170 | // (x3 address, x4 stride). |
9171 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9172 | "str q16, [x3, #0]\n" |
9173 | "str q17, [x3, #16]\n" |
9174 | "add x3, x3, x4\n" |
9175 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9176 | RUY_MAKE_ZERO(v16) |
9177 | RUY_MAKE_ZERO(v17) |
9178 | "str q18, [x3, #0]\n" |
9179 | "str q19, [x3, #16]\n" |
9180 | "add x3, x3, x4\n" |
9181 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9182 | RUY_MAKE_ZERO(v18) |
9183 | RUY_MAKE_ZERO(v19) |
9184 | "str q20, [x3, #0]\n" |
9185 | "str q21, [x3, #16]\n" |
9186 | "add x3, x3, x4\n" |
9187 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9188 | RUY_MAKE_ZERO(v20) |
9189 | RUY_MAKE_ZERO(v21) |
9190 | "str q22, [x3, #0]\n" |
9191 | "str q23, [x3, #16]\n" |
9192 | "add x3, x3, x4\n" |
9193 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9194 | RUY_MAKE_ZERO(v22) |
9195 | RUY_MAKE_ZERO(v23) |
9196 | "str q24, [x3, #0]\n" |
9197 | "str q25, [x3, #16]\n" |
9198 | "add x3, x3, x4\n" |
9199 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9200 | RUY_MAKE_ZERO(v24) |
9201 | RUY_MAKE_ZERO(v25) |
9202 | "str q26, [x3, #0]\n" |
9203 | "str q27, [x3, #16]\n" |
9204 | "add x3, x3, x4\n" |
9205 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9206 | RUY_MAKE_ZERO(v26) |
9207 | RUY_MAKE_ZERO(v27) |
9208 | "str q28, [x3, #0]\n" |
9209 | "str q29, [x3, #16]\n" |
9210 | "add x3, x3, x4\n" |
9211 | RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n" ) |
9212 | RUY_MAKE_ZERO(v28) |
9213 | RUY_MAKE_ZERO(v29) |
9214 | "str q30, [x3, #0]\n" |
9215 | "str q31, [x3, #16]\n" |
9216 | RUY_MAKE_ZERO(v30) |
9217 | RUY_MAKE_ZERO(v31) |
9218 | |
9219 | // If all of the 8x8 block fits, we just finished writing it to the |
9220 | // destination, so we skip the next part. |
9221 | "beq 41f\n" |
9222 | // Not all of the 8x8 block fits in the destination matrix. We just |
9223 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
9224 | // it to copy into the destination matrix the part that fits. |
9225 | "mov x3, %[dst_tmp_buf]\n" |
9226 | "mov x4, %[dst_ptr]\n" |
9227 | "mov w6, #0\n" |
9228 | "50:\n" |
9229 | RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n" ) |
9230 | "mov w5, #0\n" |
9231 | "51:\n" |
9232 | "ldr w7, [x3, x5, lsl #2]\n" |
9233 | "str w7, [x4, x5, lsl #2]\n" |
9234 | "add w5, w5, #1\n" |
9235 | "cmp w5, w1\n" |
9236 | "blt 51b\n" |
9237 | "add w6, w6, #1\n" |
9238 | "add x3, x3, #32\n" |
9239 | "add x4, x4, x11\n" |
9240 | "cmp w6, w2\n" |
9241 | "blt 50b\n" |
9242 | "41:\n" |
9243 | "add %[dst_ptr], %[dst_ptr], #32\n" |
9244 | // At this point we have completely finished writing values to the |
9245 | // destination matrix for the current block. |
9246 | |
9247 | // Reload some params --- we had used x5 -- x7 for a few other things |
9248 | // since the last time we had loaded them. |
9249 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
9250 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
9251 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
9252 | |
9253 | // Move to the next block of the destination matrix, for the next iter |
9254 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
9255 | // been updated earlier. |
9256 | // Have we reached the end row? |
9257 | "cmp %w[row], w7\n" |
9258 | "beq 20f\n" // yes, end row. |
9259 | // Not end row. Move to the next row. |
9260 | "add %w[row], %w[row], #8\n" |
9261 | "b 21f\n" |
9262 | "20:\n" |
9263 | // Was already at end row. |
9264 | "mov %w[row], w6\n" // Move back to first row. |
9265 | "add %w[col], %w[col], #8\n" // Move to the next column. |
9266 | "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" |
9267 | "mov %[dst_ptr], %[dst_col_ptr]\n" |
9268 | "21:\n" |
9269 | |
9270 | // Main loop exit condition: have we hit the end column? |
9271 | "cmp %w[col], w8\n" |
9272 | |
9273 | // w1 is the number of levels of depth that remain to load |
9274 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
9275 | // above, this is currently depth - 1. |
9276 | "sub w1, w12, #1\n" |
9277 | |
9278 | "ble 1b\n" |
9279 | |
9280 | // clang-format on |
9281 | |
9282 | : [ lhs_col_ptr ] "+r" (lhs_col_ptr), [rhs_col_ptr] "+r" (rhs_col_ptr), |
9283 | [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), |
9284 | [dst_col_ptr] "+r" (dst_col_ptr), [dst_ptr] "+r" (dst_ptr), [row] "+r" (row), [col] "+r" (col) |
9285 | : [ params ] "r" (¶ms), [dst_rows] "r" (params.dst_rows), |
9286 | [dst_cols] "r" (params.dst_cols), [dst_tmp_buf] "r" (params.dst_tmp_buf) |
9287 | : "x1" , "x2" , "x3" , "x4" , "x5" , "x6" , "x7" , "x8" , "x9" , "x10" , "x11" , "x12" , "x13" , "cc" , |
9288 | "memory" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
9289 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , |
9290 | "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
9291 | } |
9292 | |
9293 | // Variant of KernelFloatNeonA55ish tuned for in-order CPUs that do |
9294 | // support dotprod (while dotprod by itself is not relevant to floating-point, |
9295 | // this additional bit of information that we have about the target happens to |
9296 | // be useful here). |
9297 | // |
9298 | // So a typical target CPU here would be ARM Cortex-A55r1. |
9299 | // |
9300 | // This kernel is similar to and inspired by gemmlowp's |
// NEON_64bit_GEMM_Float32_WithScalar_A55r1,
9302 | // which was contributed by David Mansell with very helpful |
9303 | // comments. Specifically, see this comment about tuning for Cortex-A55r1: |
9304 | // https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 |
9305 | void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params) { |
9306 | profiler::ScopeLabel label( |
9307 | "Kernel (kNeonDotprod, optimized for in-order cores)" ); |
9308 | |
9309 | CheckOffsetsInKernelParamsFloat(params); |
9310 | |
9311 | const float* lhs_col_ptr = params.lhs_base_ptr; |
9312 | const float* rhs_col_ptr = params.rhs_base_ptr; |
9313 | const float* lhs_ptr = lhs_col_ptr; |
9314 | const float* rhs_ptr = rhs_col_ptr; |
9315 | float* dst_col_ptr = params.dst_base_ptr; |
9316 | float* dst_ptr = dst_col_ptr; |
9317 | int row = params.start_row; |
9318 | int col = params.start_col; |
9319 | |
9320 | // The asm kernel below has the following NEON register allocation: |
9321 | // |
9322 | // v16 -- v31 are accumulators. |
// During accumulation, v0 -- v4 are used to load data from LHS and RHS.
9324 | // |
9325 | // RHS 1x8 block |
9326 | // /-----------------------------------------| |
9327 | // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| |
9328 | // \-----------------------------------------/ |
9329 | // LHS 8x1 block |
9330 | // /---------------------\ /-----------------------------------------| |
9331 | // | v0.s[0] | |v16.s[0] ... v30.s[0]| |
9332 | // | ... | | ... ... | |
9333 | // | v0.s[3] | |v16.s[3] ... v30.s[3]| |
9334 | // | v1.s[0] | |v17.s[0] ... v31.s[0]| |
9335 | // | ... | | ... ... | |
9336 | // | v1.s[3] | |v17.s[3] ... v31.s[3]| |
9337 | // \---------------------/ \-----------------------------------------/ |
9338 | // accumulators 8x8 block |
9339 | // |
9340 | // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because |
9341 | // we did not observe a benefit of such partial unrolling on in-order CPUs. |
9342 | // |
// v5 -- v7 are unused, and v8 -- v15 are used for loading parameters used
// for the post-accumulation part of the kernel.
9345 | asm volatile( |
9346 | #define RUY_MAKE_ZERO(reg) "movi " #reg ".4s, #0\n" |
9347 | |
9348 | // clang-format off |
9349 | |
9350 | // Load some parameters into registers. |
9351 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
9352 | "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
9353 | "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
9354 | "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
9355 | "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
9356 | "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
9357 | "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
9358 | "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
9359 | |
9360 | |
9361 | // Clear accumulators. |
9362 | RUY_MAKE_ZERO(v16) |
9363 | // Load the first 32 bytes of LHS and RHS data. |
9364 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
9365 | RUY_MAKE_ZERO(v17) |
9366 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
9367 | RUY_MAKE_ZERO(v18) |
9368 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
9369 | RUY_MAKE_ZERO(v19) |
9370 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
9371 | RUY_MAKE_ZERO(v20) |
9372 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n" ) |
9373 | RUY_MAKE_ZERO(v21) |
9374 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n" ) |
9375 | RUY_MAKE_ZERO(v22) |
9376 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n" ) |
9377 | RUY_MAKE_ZERO(v23) |
9378 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n" ) |
9379 | RUY_MAKE_ZERO(v24) |
9380 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n" ) |
9381 | RUY_MAKE_ZERO(v25) |
9382 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n" ) |
9383 | RUY_MAKE_ZERO(v26) |
9384 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n" ) |
9385 | RUY_MAKE_ZERO(v27) |
9386 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n" ) |
9387 | RUY_MAKE_ZERO(v28) |
9388 | RUY_MAKE_ZERO(v29) |
9389 | RUY_MAKE_ZERO(v30) |
9390 | RUY_MAKE_ZERO(v31) |
9391 | |
9392 | // w1 is the number of levels of depth that remain to load |
9393 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
9394 | // above, this is currently depth - 1. |
9395 | "sub w1, w12, #1\n" |
9396 | |
9397 | // Main loop of the whole GEMM, over rows and columns of the |
9398 | // destination matrix. |
9399 | "1:\n" |
9400 | |
9401 | "cmp w1, #0\n" |
9402 | "fmla v16.4s, v0.4s, v2.s[0]\n" |
9403 | "fmla v18.4s, v0.4s, v2.s[1]\n" |
9404 | "fmla v20.4s, v0.4s, v2.s[2]\n" |
9405 | "fmla v22.4s, v0.4s, v2.s[3]\n" |
9406 | |
9407 | // Accumulation loop |
9408 | "beq 79f\n" |
9409 | |
9410 | "2:\n" |
9411 | |
9412 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n" ) |
9413 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
9414 | "ldr x2, [%[lhs_ptr], #8]\n" |
9415 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
9416 | "ldr x3, [%[lhs_ptr], #24]\n" |
9417 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
9418 | "ldr x5, [%[rhs_ptr], #24]\n" |
9419 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
9420 | "ldr d0, [%[lhs_ptr]], #32\n" |
9421 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
9422 | "ldr x4, [%[rhs_ptr], #8]\n" |
9423 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
9424 | "subs w1, w1, #1\n" |
9425 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
9426 | "ins v0.d[1], x2\n" |
9427 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
9428 | "ldr d3, [%[rhs_ptr], #16]\n" |
9429 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
9430 | "ins v3.d[1], x5\n" |
9431 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
9432 | "ldr d4, [%[rhs_ptr]], #32\n" |
9433 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
9434 | "ins v4.d[1], x4\n" |
9435 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
9436 | RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n" ) |
9437 | "fmla v16.4s, v0.4s, v4.s[0]\n" |
9438 | "ldr d1, [%[lhs_ptr], #-16]\n" |
9439 | "fmla v18.4s, v0.4s, v4.s[1]\n" |
9440 | "ins v1.d[1], x3\n" |
9441 | "fmla v20.4s, v0.4s, v4.s[2]\n" |
9442 | "mov v2.16b, v4.16b\n" |
9443 | "fmla v22.4s, v0.4s, v4.s[3]\n" |
9444 | "bne 2b\n" |
9445 | |
9446 | "79:\n" |
9447 | |
9448 | // End of the inner loop on depth. Now perform the remaining |
9449 | // multiply-adds of the last level of depth, for which the LHS |
9450 | // and RHS data is already loaded. |
9451 | |
9452 | "fmla v24.4s, v0.4s, v3.s[0]\n" |
9453 | "fmla v26.4s, v0.4s, v3.s[1]\n" |
9454 | "fmla v28.4s, v0.4s, v3.s[2]\n" |
9455 | "fmla v30.4s, v0.4s, v3.s[3]\n" |
9456 | "fmla v25.4s, v1.4s, v3.s[0]\n" |
9457 | "fmla v27.4s, v1.4s, v3.s[1]\n" |
9458 | "fmla v29.4s, v1.4s, v3.s[2]\n" |
9459 | "fmla v31.4s, v1.4s, v3.s[3]\n" |
9460 | "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
9461 | "fmla v17.4s, v1.4s, v2.s[0]\n" |
9462 | "fmla v19.4s, v1.4s, v2.s[1]\n" |
9463 | "fmla v21.4s, v1.4s, v2.s[2]\n" |
9464 | "fmla v23.4s, v1.4s, v2.s[3]\n" |
9465 | |
// End of accumulation. The registers v16 -- v31 contain the final
// float accumulator values of the current 8x8 destination block.
// We now have to compute the final float values from these
// accumulators, and advance to the next 8x8 block. We intertwine
9470 | // these two aspects whenever possible for optimal pipelining, both |
9471 | // at the data flow level (prefetch data for next block as early as |
9472 | // possible) and instruction pipelining level (some of the next-block |
9473 | // work can dual-issue with some of the final work on the current |
9474 | // block). |
9475 | |
9476 | // Logic to advance to the next block in preparation for the next |
9477 | // iteration of the main loop. For now, we only want to compute |
9478 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
9479 | // not yet ready to update the values of row and col, as we still need |
9480 | // the current values for the rest of the work on the current block. |
9481 | |
9482 | "cmp %w[row], w7\n" // Have we finished the last row? |
9483 | "bge 4f\n" // If finished last row, go to 4 |
9484 | // Not finished last row: then advance to next row. |
9485 | "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" |
9486 | "b 5f\n" |
9487 | "4:\n" // Finished last row... |
9488 | "mov %[lhs_col_ptr], x5\n" // Go back to first row |
9489 | // Now we need to advance to the next column. If we already |
9490 | // finished the last column, then in principle we are done, however |
9491 | // we can't just return here, as we need to allow the end work of the |
9492 | // current block to complete. The good news is that at this point it |
9493 | // doesn't matter what data we load for the next column, since |
9494 | // we will exit from the main loop below before actually storing |
9495 | // anything computed from that data. |
9496 | "cmp %w[col], w8\n" // Have we finished the last column? |
9497 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
9498 | // Not finished last column: then advance to next column. |
9499 | "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" |
9500 | "5:\n" |
9501 | |
9502 | // Set the LHS and RHS data pointers to the start of the columns just |
9503 | // computed. |
9504 | "mov %[lhs_ptr], %[lhs_col_ptr]\n" |
9505 | "mov %[rhs_ptr], %[rhs_col_ptr]\n" |
9506 | |
9507 | // Load some parameters needed for the end work on current block. |
9508 | "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
9509 | "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
9510 | |
9511 | // Determine the channel index. |
9512 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
9513 | "csel w3, %w[row], %w[col], eq\n" |
9514 | |
9515 | // Offset the bias pointer as needed given the current row, col. |
9516 | "add x5, x1, x3, lsl #2\n" |
9517 | |
9518 | // If there is no bias, use no offset, just address the passed zero |
9519 | // data. |
9520 | |
9521 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
9522 | "csel x1, x1, x5, eq\n" |
9523 | |
9524 | // Load 8 bias values. |
9525 | "ld1 {v14.4s}, [x1], #16\n" |
9526 | "ld1 {v15.4s}, [x1]\n" |
9527 | |
9528 | // Now that we know what LHS and RHS data the next iteration of the |
9529 | // main loop will need to load, we start loading the first 32 bytes of |
9530 | // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore |
9531 | // in the rest of the work on the current block. |
9532 | "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" |
9533 | "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" |
9534 | "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" |
9535 | "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" |
9536 | |
9537 | // Perform the bias-addition. |
9538 | // Jump based on channel dimension. |
9539 | "tst w4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
9540 | "bne 6f\n" |
9541 | // Case where channels are rows |
9542 | "fadd v16.4s, v16.4s, v14.4s\n" |
9543 | "fadd v17.4s, v17.4s, v15.4s\n" |
9544 | "fadd v18.4s, v18.4s, v14.4s\n" |
9545 | "fadd v19.4s, v19.4s, v15.4s\n" |
9546 | "fadd v20.4s, v20.4s, v14.4s\n" |
9547 | "fadd v21.4s, v21.4s, v15.4s\n" |
9548 | "fadd v22.4s, v22.4s, v14.4s\n" |
9549 | "fadd v23.4s, v23.4s, v15.4s\n" |
9550 | "fadd v24.4s, v24.4s, v14.4s\n" |
9551 | "fadd v25.4s, v25.4s, v15.4s\n" |
9552 | "fadd v26.4s, v26.4s, v14.4s\n" |
9553 | "fadd v27.4s, v27.4s, v15.4s\n" |
9554 | "fadd v28.4s, v28.4s, v14.4s\n" |
9555 | "fadd v29.4s, v29.4s, v15.4s\n" |
9556 | "fadd v30.4s, v30.4s, v14.4s\n" |
9557 | "fadd v31.4s, v31.4s, v15.4s\n" |
9558 | "b 7f\n" |
9559 | |
9560 | "6:\n" |
9561 | // Case where channels are columns |
9562 | "dup v8.4s, v14.s[0]\n" |
9563 | "dup v9.4s, v14.s[1]\n" |
9564 | "fadd v16.4s, v16.4s, v8.4s\n" |
9565 | "dup v10.4s, v14.s[2]\n" |
9566 | "fadd v17.4s, v17.4s, v8.4s\n" |
9567 | "dup v11.4s, v14.s[3]\n" |
9568 | "fadd v18.4s, v18.4s, v9.4s\n" |
9569 | "dup v12.4s, v15.s[0]\n" |
9570 | "fadd v19.4s, v19.4s, v9.4s\n" |
9571 | "dup v13.4s, v15.s[1]\n" |
9572 | "fadd v20.4s, v20.4s, v10.4s\n" |
9573 | "dup v14.4s, v15.s[2]\n" |
9574 | "fadd v21.4s, v21.4s, v10.4s\n" |
9575 | "dup v15.4s, v15.s[3]\n" |
9576 | "fadd v22.4s, v22.4s, v11.4s\n" |
9577 | "fadd v23.4s, v23.4s, v11.4s\n" |
9578 | "fadd v24.4s, v24.4s, v12.4s\n" |
9579 | "fadd v25.4s, v25.4s, v12.4s\n" |
9580 | "fadd v26.4s, v26.4s, v13.4s\n" |
9581 | "fadd v27.4s, v27.4s, v13.4s\n" |
9582 | "fadd v28.4s, v28.4s, v14.4s\n" |
9583 | "fadd v29.4s, v29.4s, v14.4s\n" |
9584 | "fadd v30.4s, v30.4s, v15.4s\n" |
9585 | "fadd v31.4s, v31.4s, v15.4s\n" |
9586 | "7:\n" |

        // Load the clamp_min, clamp_max bounds
        "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
        "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
        "dup v14.4s, w2\n"  // clamp_min
        "dup v15.4s, w3\n"  // clamp_max

        // Apply the clamp_min bound
        "fmax v16.4s, v16.4s, v14.4s\n"
        "fmax v17.4s, v17.4s, v14.4s\n"
        "fmax v18.4s, v18.4s, v14.4s\n"
        "fmax v19.4s, v19.4s, v14.4s\n"
        "fmax v20.4s, v20.4s, v14.4s\n"
        "fmax v21.4s, v21.4s, v14.4s\n"
        "fmax v22.4s, v22.4s, v14.4s\n"
        "fmax v23.4s, v23.4s, v14.4s\n"
        "fmax v24.4s, v24.4s, v14.4s\n"
        "fmax v25.4s, v25.4s, v14.4s\n"
        "fmax v26.4s, v26.4s, v14.4s\n"
        "fmax v27.4s, v27.4s, v14.4s\n"
        "fmax v28.4s, v28.4s, v14.4s\n"
        "fmax v29.4s, v29.4s, v14.4s\n"
        "fmax v30.4s, v30.4s, v14.4s\n"
        "fmax v31.4s, v31.4s, v14.4s\n"

        // Apply the clamp_max bound
        "fmin v16.4s, v16.4s, v15.4s\n"
        "fmin v17.4s, v17.4s, v15.4s\n"
        "fmin v18.4s, v18.4s, v15.4s\n"
        "fmin v19.4s, v19.4s, v15.4s\n"
        "fmin v20.4s, v20.4s, v15.4s\n"
        "fmin v21.4s, v21.4s, v15.4s\n"
        "fmin v22.4s, v22.4s, v15.4s\n"
        "fmin v23.4s, v23.4s, v15.4s\n"
        "fmin v24.4s, v24.4s, v15.4s\n"
        "fmin v25.4s, v25.4s, v15.4s\n"
        "fmin v26.4s, v26.4s, v15.4s\n"
        "fmin v27.4s, v27.4s, v15.4s\n"
        "fmin v28.4s, v28.4s, v15.4s\n"
        "fmin v29.4s, v29.4s, v15.4s\n"
        "fmin v30.4s, v30.4s, v15.4s\n"
        "fmin v31.4s, v31.4s, v15.4s\n"
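        // Scalar equivalent of the clamping above (sketch): every accumulator
        // value x becomes fmin(fmax(x, clamp_min), clamp_max), with the bounds
        // broadcast into v14/v15 just before.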

        // Compute how much of the 8x8 block of destination float values
        // we have computed fits in the destination matrix. Typically, all of
        // it fits, but when the destination matrix shape is not a multiple
        // of 8x8, there are some 8x8 blocks along the boundaries that do
        // not fit entirely.
        "sub w1, %w[dst_rows], %w[row]\n"
        "sub w2, %w[dst_cols], %w[col]\n"
        "mov w3, #8\n"
        "cmp w1, #8\n"
        // Compute w1 = how many rows of the 8x8 block fit
        "csel w1, w1, w3, le\n"
        "cmp w2, #8\n"
        // Compute w2 = how many cols of the 8x8 block fit
        "csel w2, w2, w3, le\n"

        // Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
        "cmp w1, w3\n"
        "ccmp w2, w3, 0, eq\n"
        // Yes, all of the 8x8 block fits, go to fast path.
        "beq 30f\n"
        // Not all of the 8x8 block fits.
        // Set (x3 address, x4 stride) to write to dst_tmp_buf
        "mov x3, %[dst_tmp_buf]\n"
        "mov x4, #32\n"
        "b 31f\n"
        "30:\n"
        // Yes, all of the 8x8 block fits.
        // Set (x3 address, x4 stride) to write directly to destination matrix.
        "mov x3, %[dst_ptr]\n"
        "mov x4, x11\n"
        "31:\n"
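        // In scalar terms, the selection above is (illustrative sketch; the
        // names rows_fit/cols_fit stand for w1/w2):
        //   rows_fit = min(8, dst_rows - row);
        //   cols_fit = min(8, dst_cols - col);
        //   if (rows_fit == 8 && cols_fit == 8) {
        //     write directly to dst_ptr with the destination stride (x11);
        //   } else {
        //     write to dst_tmp_buf with a fixed 32-byte stride (8 floats per
        //     column) and copy the part that fits afterwards.
        //   }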

        // Write our float values to the destination described by
        // (x3 address, x4 stride).
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        "str q16, [x3, #0]\n"
        "str q17, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v16)
        RUY_MAKE_ZERO(v17)
        "str q18, [x3, #0]\n"
        "str q19, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v18)
        RUY_MAKE_ZERO(v19)
        "str q20, [x3, #0]\n"
        "str q21, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v20)
        RUY_MAKE_ZERO(v21)
        "str q22, [x3, #0]\n"
        "str q23, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v22)
        RUY_MAKE_ZERO(v23)
        "str q24, [x3, #0]\n"
        "str q25, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v24)
        RUY_MAKE_ZERO(v25)
        "str q26, [x3, #0]\n"
        "str q27, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v26)
        RUY_MAKE_ZERO(v27)
        "str q28, [x3, #0]\n"
        "str q29, [x3, #16]\n"
        "add x3, x3, x4\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n")
        RUY_MAKE_ZERO(v28)
        RUY_MAKE_ZERO(v29)
        "str q30, [x3, #0]\n"
        "str q31, [x3, #16]\n"
        RUY_MAKE_ZERO(v30)
        RUY_MAKE_ZERO(v31)
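        // Each column of the 8x8 float block is 8 values (32 bytes), stored as
        // two q registers at offsets 0 and 16 from x3 before x3 advances by the
        // stride in x4. The RUY_MAKE_ZERO lines clear each accumulator register
        // for the next block as soon as its value has been stored.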

        // If all of the 8x8 block fits, we just finished writing it to the
        // destination, so we skip the next part.
        "beq 41f\n"
        // Not all of the 8x8 block fits in the destination matrix. We just
        // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
        // it to copy into the destination matrix the part that fits.
        "mov x3, %[dst_tmp_buf]\n"
        "mov x4, %[dst_ptr]\n"
        "mov w6, #0\n"
        "50:\n"
        RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
        "mov w5, #0\n"
        "51:\n"
        "ldr w7, [x3, x5, lsl #2]\n"
        "str w7, [x4, x5, lsl #2]\n"
        "add w5, w5, #1\n"
        "cmp w5, w1\n"
        "blt 51b\n"
        "add w6, w6, #1\n"
        "add x3, x3, #32\n"
        "add x4, x4, x11\n"
        "cmp w6, w2\n"
        "blt 50b\n"
        "41:\n"
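        // The slow path above is, in scalar terms, roughly (illustrative
        // sketch; w1/w2 are the row/column counts that fit, x11 is the
        // destination stride in bytes, and each element is a 4-byte float):
        //   for (int c = 0; c < cols_fit; ++c)
        //     for (int r = 0; r < rows_fit; ++r)
        //       *(float*)(dst_ptr + c * dst_stride + r * 4) =
        //           *(float*)(dst_tmp_buf + c * 32 + r * 4);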
        "add %[dst_ptr], %[dst_ptr], #32\n"
        // At this point we have completely finished writing values to the
        // destination matrix for the current block.

        // Reload some params --- we had used x5 -- x7 for a few other things
        // since the last time we had loaded them.
        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"

        // Move to the next block of the destination matrix, for the next iter
        // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
        // been updated earlier.
        // Have we reached the end row?
        "cmp %w[row], w7\n"
        "beq 20f\n"  // yes, end row.
        // Not end row. Move to the next row.
        "add %w[row], %w[row], #8\n"
        "b 21f\n"
        "20:\n"
        // Was already at end row.
        "mov %w[row], w6\n"  // Move back to first row.
        "add %w[col], %w[col], #8\n"  // Move to the next column.
        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
        "mov %[dst_ptr], %[dst_col_ptr]\n"
        "21:\n"
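        // Block traversal (sketch): the kernel walks the destination in 8x8
        // blocks, column-major over blocks.
        //   if (row != last_row) { row += 8; }
        //   else { row = start_row; col += 8;
        //          dst_col_ptr += 8 * dst_stride; dst_ptr = dst_col_ptr; }
        // dst_ptr itself was already advanced by 32 bytes (8 floats) above for
        // the within-column case.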

        // Main loop exit condition: have we hit the end column?
        "cmp %w[col], w8\n"

        // w1 is the number of levels of depth that remain to load
        // LHS and RHS data for. Corresponding to the initial ld1 instructions
        // above, this is currently depth - 1.
        "sub w1, w12, #1\n"

        "ble 1b\n"

        // clang-format on

        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf)
        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
          "v26", "v27", "v28", "v29", "v30", "v31");
}
#undef RUY_OFFSET_BIAS
#undef RUY_OFFSET_FLAGS
#undef RUY_OFFSET_LHS_BASE_PTR
#undef RUY_OFFSET_CLAMP_MIN
#undef RUY_OFFSET_CLAMP_MAX
#undef RUY_OFFSET_START_ROW
#undef RUY_OFFSET_LAST_ROW
#undef RUY_OFFSET_LAST_COL
#undef RUY_OFFSET_LHS_STRIDE
#undef RUY_OFFSET_RHS_STRIDE
#undef RUY_OFFSET_DST_STRIDE
#undef RUY_OFFSET_DEPTH
#undef RUY_OFFSET_START_COL
#undef RUY_OFFSET_RHS_BASE_PTR
#undef RUY_OFFSET_DST_BASE_PTR

#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

}  // namespace ruy