1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "ruy/kernel_arm.h" |
17 | #include "ruy/opt_set.h" |
18 | #include "ruy/platform.h" |
19 | #include "ruy/profiler/instrumentation.h" |
20 | |
21 | namespace ruy { |
22 | |
23 | #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) |
24 | |
25 | #define RUY_ASM_LABEL_STORE_UINT8 91 |
26 | #define RUY_ASM_LABEL_STORE_INT8 92 |
27 | #define RUY_ASM_LABEL_STORE_INT16 93 |
28 | #define RUY_ASM_LABEL_STORE_INT32 94 |
29 | #define RUY_ASM_LABEL_AFTER_STORE 99 |
30 | |
31 | #define RUY_OFFSET_LHS_BASE_PTR 0 |
32 | #define RUY_OFFSET_RHS_BASE_PTR 4 |
33 | #define RUY_OFFSET_DST_BASE_PTR 8 |
34 | #define RUY_OFFSET_BIAS 12 |
35 | #define RUY_OFFSET_START_ROW 16 |
36 | #define RUY_OFFSET_START_COL 20 |
37 | #define RUY_OFFSET_LAST_ROW 24 |
38 | #define RUY_OFFSET_LAST_COL 28 |
39 | #define RUY_OFFSET_DST_ROWS 32 |
40 | #define RUY_OFFSET_DST_COLS 36 |
41 | #define RUY_OFFSET_LHS_STRIDE 40 |
42 | #define RUY_OFFSET_RHS_STRIDE 44 |
43 | #define RUY_OFFSET_DST_STRIDE 48 |
44 | #define RUY_OFFSET_DEPTH 52 |
45 | #define RUY_OFFSET_CLAMP_MIN 56 |
46 | #define RUY_OFFSET_CLAMP_MAX 60 |
47 | #define RUY_OFFSET_FLAGS 64 |
48 | |
49 | #define RUY_STACK_OFFSET_SIZE 96 |
50 | #define RUY_STACK_OFFSET_DST_COL_PTR 0 |
51 | #define RUY_STACK_OFFSET_DST_PTR 16 |
52 | #define RUY_STACK_OFFSET_ROW 32 |
53 | #define RUY_STACK_OFFSET_COL 48 |
54 | #define RUY_STACK_OFFSET_LHS_COL_PTR 64 |
55 | #define RUY_STACK_OFFSET_RHS_COL_PTR 80 |
56 | |
// Statically verifies that the RUY_OFFSET_* byte offsets hard-coded into the
// inline assembly kernel below agree with the actual layout of the kernel
// params struct. The asm reads fields as "[%[params], #offset]", so a layout
// mismatch would silently read the wrong field; these asserts turn such drift
// into a compile-time error instead.
// NOTE(review): dst_cols has no matching static_assert here even though
// RUY_OFFSET_DST_COLS is defined above and used by the asm (see the
// "ldr r2, [%[params], #RUY_OFFSET_DST_COLS]" load) -- presumably an
// oversight; confirm against the 64-bit kernel's checks and add one.
template <typename Params>
void CheckOffsetsInKernelParamsFloat32(const Params&) {
  // One assert per field: pins the field's offsetof() to its asm constant.
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "" );
  static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "" );
  static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "" );
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "" );
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "" );
  static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "" );
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "" );
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "" );
  static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "" );
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "" );
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "" );
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "" );
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "" );
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "" );
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "" );
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "" );
}
76 | |
77 | // Float kernel for ARM32 out-of-order cores. |
78 | // Just like Float 64 version, except accumulate in to 8x4 block to only |
79 | // use 16 128-bit NEON registers. This is a "first pass" kernel and not |
80 | // tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. |
81 | void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) { |
82 | CheckOffsetsInKernelParamsFloat32(params); |
83 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
84 | |
85 | const float* lhs_ptr = params.lhs_base_ptr; |
86 | const float* rhs_ptr = params.rhs_base_ptr; |
87 | // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are |
88 | // each composed of two 64-bit "d" registers. The asm kernel below has the |
89 | // following NEON register allocation: |
90 | // Registers q3 -- q10 are accumulators. During accumulation, |
91 | // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1 |
92 | // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block |
93 | // of RHS, like this: |
94 | |
95 | // Register layout in "q" registers: |
96 | // RHS 1x4 block |
97 | // /--------------------------| |
98 | // |q2.s[0] ... q2.s[3] | |
99 | // \--------------------------/ |
100 | // LHS 8x1 block |
101 | // /---------------------\ /--------------------------| |
102 | // | q0.s[0] | | q3.s[0] ... q9.s[0] | |
103 | // | ... | | ... ... | |
104 | // | q0.s[3] | | q3.s[3] q9.s[3] | |
105 | // | q1.s[0] | | q4.s[0] q10.s[0] | |
106 | // | ... | | ... ... ... | |
107 | // | q1.s[3] | | q4.s[3] .. q10.s[3] | |
108 | // \---------------------/ \--------------------------/ |
109 | // accumulators 8x4 block |
110 | // q11, q14, q15 currently unused. q12 and q13 are used to load |
111 | // parameters used for the post-accumulation part of the kernel. |
112 | // For completeness, here is the register layout in "d" registers: |
113 | // RHS 1x4 block |
114 | // /--------------------------| |
115 | // |d4[0] ... d5[1] | |
116 | // \--------------------------/ |
117 | // LHS 8x1 block |
118 | // /---------------------\ /--------------------------| |
119 | // | d0[0] | | d6[0] ... d18[0] | |
120 | // | ... | | ... ... | |
121 | // | d1[1] | | d7[1] d19[1] | |
122 | // | d2[0] | | d8[0] d20[0] | |
123 | // | ... | | ... ... ... | |
124 | // | d3[1] | | d9[1] ... d21[1] | |
125 | // \---------------------/ \--------------------------/ |
126 | // accumulators 8x4 block |
127 | asm volatile( |
128 | #define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" |
129 | |
130 | // clang-format off |
131 | |
132 | // Load the first 32 bytes of LHS and RHS data. |
133 | // Load q0, q1 |
134 | "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" |
135 | "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" |
136 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
137 | // Load q2 |
138 | "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" |
139 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
140 | |
141 | "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
142 | |
143 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
144 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
145 | |
146 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
147 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
148 | |
149 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
150 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
151 | |
152 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
153 | "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
154 | |
155 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
156 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
157 | |
158 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
159 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
160 | // Clear accumulators. |
161 | RUY_MAKE_ZERO(q3) |
162 | RUY_MAKE_ZERO(q4) |
163 | RUY_MAKE_ZERO(q5) |
164 | RUY_MAKE_ZERO(q6) |
165 | RUY_MAKE_ZERO(q7) |
166 | RUY_MAKE_ZERO(q8) |
167 | RUY_MAKE_ZERO(q9) |
168 | RUY_MAKE_ZERO(q10) |
169 | |
170 | // r1 is the number of levels of depth that we have already loaded |
171 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
172 | // above, this is currently 1. |
173 | "mov r1, #1\n" |
174 | |
175 | // Main loop of the whole GEMM, over rows and columns of the |
176 | // destination matrix. |
177 | "1:\n" |
178 | |
179 | // Accumulation loop |
180 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
181 | "cmp r1, r2\n" |
182 | "beq 79f\n" |
183 | |
184 | "2:\n" |
185 | |
186 | "vmla.f32 q3, q0, d4[0]\n" |
187 | "vmla.f32 q5, q0, d4[1]\n" |
188 | "vmla.f32 q7, q0, d5[0]\n" |
189 | "vmla.f32 q9, q0, d5[1]\n" |
190 | "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS |
191 | |
192 | "vmla.f32 q4, q1, d4[0]\n" |
193 | "vmla.f32 q6, q1, d4[1]\n" |
194 | "vmla.f32 q8, q1, d5[0]\n" |
195 | "vmla.f32 q10, q1, d5[1]\n" |
196 | "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
197 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
198 | "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS |
199 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
200 | |
201 | "add r1, r1, #1\n" |
202 | "cmp r1, r2\n" |
203 | |
204 | "blt 2b\n" |
205 | |
206 | "79:\n" |
207 | |
208 | // End of the inner loop on depth. Now perform the remaining |
209 | // multiply-adds of the last level of depth, for which the LHS |
210 | // and RHS data is already loaded. |
211 | |
212 | "vmla.f32 q3, q0, d4[0]\n" |
213 | "vmla.f32 q5, q0, d4[1]\n" |
214 | "vmla.f32 q7, q0, d5[0]\n" |
215 | "vmla.f32 q9, q0, d5[1]\n" |
216 | |
217 | "vmla.f32 q4, q1, d4[0]\n" |
218 | "vmla.f32 q6, q1, d4[1]\n" |
219 | "vmla.f32 q8, q1, d5[0]\n" |
220 | "vmla.f32 q10, q1, d5[1]\n" |
221 | |
222 | // End of accumulation. The registers q3 -- q10 contain the final |
223 | // float32 accumulator values of the current 8x8 destination block. |
224 | // We now have to compute the final values from these accumulators |
225 | // and advance to the next 8x8 block. We intertwine |
226 | // these two aspects whenever possible for optimal pipelining, both |
227 | // at the data flow level (prefetch data for next block as early as |
228 | // possible) and instruction pipelining level (some of the next-block |
229 | // work can dual-issue with some of the final work on the current |
230 | // block). |
231 | |
232 | // Logic to advance to the next block in preparation for the next |
233 | // iteration of the main loop. For now, we only want to compute |
234 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
235 | // not yet ready to update the values of row and col, as we still need |
236 | // the current values for the rest of the work on the current block. |
237 | |
238 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
239 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
240 | "cmp r1, r3\n" // Have we finished the last row? |
241 | |
242 | "bge 4f\n" // If finished last row, go to 4 |
243 | // Not finished last row: then advance to next row. |
244 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
245 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
246 | "add r4, r4, r1, lsl #3\n" |
247 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
248 | "b 5f\n" |
249 | "4:\n" // Finished last row... |
250 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
251 | // Go back to first row |
252 | "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
253 | // Now we need to advance to the next column. If we already |
254 | // finished the last column, then in principle we are done, however |
255 | // we can't just return here, as we need to allow the end work of the |
256 | // current block to complete. The good news is that at this point it |
257 | // doesn't matter what data we load for the next column, since |
258 | // we will exit from the main loop below before actually storing |
259 | // anything computed from that data. |
260 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
261 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
262 | "cmp r8, r4\n" // Have we finished the last column? |
263 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
264 | // Not finished last column: then advance to next column. |
265 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
266 | "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
267 | "add r10, r10, r1, lsl #2\n" |
268 | "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
269 | "5:\n" |
270 | |
271 | // Set the LHS and RHS data pointers to the start of the columns just |
272 | // computed. |
273 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
274 | "mov %[lhs_ptr], r4\n" |
275 | "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
276 | "mov %[rhs_ptr], r5\n" |
277 | |
278 | // Load some parameters needed for the end work on current block. |
279 | "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
280 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
281 | |
282 | // Let r8 be stack offset of the row or column variable, whichever |
283 | // is the channel index. |
284 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
285 | "bne 1000f\n" |
286 | "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" |
287 | "b 1001f\n" |
288 | "1000:\n" |
289 | "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" |
290 | "1001:\n" |
291 | // Let r8 be the channel index. |
292 | "ldr r8, [sp, r8]\n" |
293 | // Compute the bias pointer, by conditionally using the channel index |
294 | // (r8) as offset into bias buffer (r1). |
295 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
296 | "beq 1002f\n" |
297 | "add r1, r1, r8, lsl #2\n" |
298 | "1002:\n" |
299 | |
300 | // Load 4 bias values. When the channel dimension is rows, we will load |
301 | // another 4 bias values just before performing the bias addition below, |
302 | // as this kernel has a 8x4 rectangular shape. |
303 | "vld1.32 {d24, d25}, [r1]!\n" |
304 | |
305 | // Now that we know what LHS and RHS data the next iteration of the |
306 | // main loop will need to load, we start loading the first 32 bytes of |
307 | // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore |
308 | // in the rest of the work on the current block. |
309 | // Load q0, q1 |
310 | "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" |
311 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
312 | // Load q2 |
313 | "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" |
314 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
315 | |
316 | // Perform the bias-addition. |
317 | // Jump based on channel dimension. |
318 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
319 | "bne 6f\n" |
320 | // Case where channels are rows. |
321 | // Load the remaining 4 bias values, since we're on the width-8 side |
322 | // of this 8x4 kernel. |
323 | "vld1.32 {d26, d27}, [r1]\n" |
324 | "vadd.f32 q3, q3, q12\n" |
325 | "vadd.f32 q5, q5, q12\n" |
326 | "vadd.f32 q7, q7, q12\n" |
327 | "vadd.f32 q9, q9, q12\n" |
328 | "vadd.f32 q4, q4, q13\n" |
329 | "vadd.f32 q6, q6, q13\n" |
330 | "vadd.f32 q8, q8, q13\n" |
331 | "vadd.f32 q10, q10, q13\n" |
332 | "b 7f\n" |
333 | |
334 | "6:\n" |
335 | // Case where channels are columns. |
336 | "vdup.32 q11, d24[0]\n" |
337 | "vdup.32 q13, d24[1]\n" |
338 | "vdup.32 q14, d25[0]\n" |
339 | "vdup.32 q15, d25[1]\n" |
340 | "vadd.f32 q3, q3, q11\n" |
341 | "vadd.f32 q4, q4, q11\n" |
342 | "vadd.f32 q5, q5, q13\n" |
343 | "vadd.f32 q6, q6, q13\n" |
344 | "vadd.f32 q7, q7, q14\n" |
345 | "vadd.f32 q8, q8, q14\n" |
346 | "vadd.f32 q9, q9, q15\n" |
347 | "vadd.f32 q10, q10, q15\n" |
348 | "7:\n" |
349 | |
350 | // Load the clamp_min, clamp_max bounds |
351 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
352 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
353 | "vdup.32 q12, r2\n" // clamp_min |
354 | "vdup.32 q13, r3\n" // clamp_max |
355 | |
356 | // Apply the clamp_min bound |
357 | "vmax.f32 q3, q3, q12\n" |
358 | "vmax.f32 q4, q4, q12\n" |
359 | "vmax.f32 q5, q5, q12\n" |
360 | "vmax.f32 q6, q6, q12\n" |
361 | "vmax.f32 q7, q7, q12\n" |
362 | "vmax.f32 q8, q8, q12\n" |
363 | "vmax.f32 q9, q9, q12\n" |
364 | "vmax.f32 q10, q10, q12\n" |
365 | |
366 | // Apply the clamp_max bound |
367 | "vmin.f32 q3, q3, q13\n" |
368 | "vmin.f32 q4, q4, q13\n" |
369 | "vmin.f32 q5, q5, q13\n" |
370 | "vmin.f32 q6, q6, q13\n" |
371 | "vmin.f32 q7, q7, q13\n" |
372 | "vmin.f32 q8, q8, q13\n" |
373 | "vmin.f32 q9, q9, q13\n" |
374 | "vmin.f32 q10, q10, q13\n" |
375 | |
376 | // Compute how much of the 8x4 block of destination values that |
377 | // we have computed, fit in the destination matrix. Typically, all of |
378 | // it fits, but when the destination matrix shape is not a multiple |
379 | // of 8x4, there are some 8x8 blocks along the boundaries that do |
380 | // not fit entirely. |
381 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
382 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
383 | "sub r1, r1, r8\n" |
384 | |
385 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
386 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
387 | "sub r2, r2, r4\n" |
388 | "mov r3, #8\n" |
389 | "mov r5, #4\n" |
390 | "cmp r1, #8\n" |
391 | // Compute r1 = how many rows of the 8x4 block fit |
392 | "it gt\n" |
393 | "movgt r1, r3\n" |
394 | "cmp r2, #4\n" |
395 | // Compute r2 = how many cols of the 8x4 block fit |
396 | "it gt\n" |
397 | "movgt r2, r5\n" |
398 | |
399 | // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. |
400 | "cmp r1, r3\n" |
401 | "it eq\n" |
402 | "cmpeq r2, r5\n" |
403 | // Yes, all of the 8x4 block fits, go to fast path. |
404 | "beq 30f\n" |
405 | // Not all of the 8x4 block fits. |
406 | // Set (r3 address, r4 stride) to write to dst_tmp_buf |
407 | "mov r3, %[dst_tmp_buf]\n" |
408 | "mov r4, #32\n" |
409 | "b 31f\n" |
410 | "30:\n" |
411 | // Yes, all of the 8x4 block fits. |
412 | // Set (r3 address, r4 stride) to write directly to destination matrix. |
413 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
414 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
415 | "mov r4, r5\n" |
416 | "31:\n" |
417 | |
418 | // Write our float values to the destination described by |
419 | // (r3 address, r4 stride) |
420 | "vst1.32 {d6, d7, d8, d9}, [r3]\n" |
421 | "add r3, r3, r4\n" |
422 | RUY_MAKE_ZERO(q3) |
423 | RUY_MAKE_ZERO(q4) |
424 | "vst1.32 {d10, d11, d12, d13}, [r3]\n" |
425 | "add r3, r3, r4\n" |
426 | RUY_MAKE_ZERO(q5) |
427 | RUY_MAKE_ZERO(q6) |
428 | "vst1.32 {d14, d15, d16, d17}, [r3]\n" |
429 | "add r3, r3, r4\n" |
430 | RUY_MAKE_ZERO(q7) |
431 | RUY_MAKE_ZERO(q8) |
432 | "vst1.32 {d18, d19, d20, d21}, [r3]\n" |
433 | "add r3, r3, r4\n" |
434 | RUY_MAKE_ZERO(q9) |
435 | RUY_MAKE_ZERO(q10) |
436 | |
437 | // If all of the 8x4 block fits, we just finished writing it to the |
438 | // destination, so we skip the next part. |
439 | "beq 41f\n" |
440 | // Not all of the 8x8 block fits in the destination matrix. We just |
441 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
442 | // it to copy into the destination matrix the part that fits. |
443 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
444 | "mov r3, %[dst_tmp_buf]\n" |
445 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
446 | "mov r6, #0\n" |
447 | "50:\n" |
448 | "mov r5, #0\n" |
449 | "51:\n" |
450 | "ldr r10, [r3, r5, lsl #2]\n" |
451 | "str r10, [r4, r5, lsl #2]\n" |
452 | "add r5, r5, #1\n" |
453 | "cmp r5, r1\n" |
454 | "blt 51b\n" |
455 | "add r6, r6, #1\n" |
456 | "add r3, r3, #32\n" |
457 | "add r4, r4, r8\n" |
458 | // r2 = how many cols of the 8x4 block fit |
459 | "cmp r6, r2\n" |
460 | "blt 50b\n" |
461 | "41:\n" |
462 | // Load dst_ptr, increment, and write back. |
463 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
464 | "add r4, r4, #32\n" |
465 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
466 | // At this point we have completely finished writing values to the |
467 | // destination matrix for the current block. |
468 | |
469 | // Reload some params --- we had used r3, r5, r10 for a few other things |
470 | // since the last time we had loaded them. |
471 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
472 | "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
473 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
474 | |
475 | // Move to the next block of the destination matrix, for the next iter |
476 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
477 | // been updated earlier. |
478 | // Have we reached the end row? |
479 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
480 | "cmp r8, r3\n" |
481 | |
482 | "beq 20f\n" // yes, end row. |
483 | // Not end row. Move to the next row. |
484 | "add r8, r8, #8\n" |
485 | // Store new value of row |
486 | "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
487 | |
488 | "b 21f\n" |
489 | "20:\n" |
490 | // Was already at end row. |
491 | // Move back to first row. |
492 | "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
493 | // Move to the next column. |
494 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
495 | "add r4, r4, #4\n" |
496 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
497 | |
498 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
499 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
500 | // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns) |
501 | "add r1, r1, r8, lsl #2\n" |
502 | // Store dst_col_ptr |
503 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
504 | // Store dst_ptr |
505 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
506 | "21:\n" |
507 | |
508 | // Main loop exit condition: have we hit the end column? |
509 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
510 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
511 | "cmp r8, r4\n" |
512 | |
513 | // r1 is the number of levels of depth that we have already loaded |
514 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
515 | // above, this is currently 1. |
516 | "mov r1, #1\n" |
517 | |
518 | "ble 1b\n" |
519 | |
520 | // Restore stack pointer. |
521 | "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
522 | |
523 | // clang-format on |
524 | : [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr) |
525 | : [ params ] "r" (¶ms), [dst_tmp_buf] "r" (params.dst_tmp_buf) |
526 | // Clobber list must specify q registers (and not their constituent |
527 | // d registers). There is a (currently unexplained) slowdown if |
528 | // d registers are listed in the clobbers list. |
529 | : "r0" , "r1" , "r2" , "r3" , "r4" , "r5" , "r6" , "r8" , "r10" , "cc" , |
530 | "memory" , "q0" , "q1" , "q2" , "q3" , "q4" , "q5" , "q6" , "q7" , "q8" , |
531 | "q9" , "q10" , "q12" , "q13" ); |
532 | } |
533 | |
534 | #undef RUY_MAKE_ZERO |
535 | #undef RUY_STACK_OFFSET_SIZE |
536 | #undef RUY_STACK_OFFSET_DST_COL_PTR |
537 | #undef RUY_STACK_OFFSET_DST_PTR |
538 | #undef RUY_STACK_OFFSET_ROW |
539 | #undef RUY_STACK_OFFSET_COL |
540 | #undef RUY_STACK_OFFSET_LHS_COL_PTR |
541 | #undef RUY_STACK_OFFSET_RHS_COL_PTR |
542 | |
543 | #undef RUY_OFFSET_LHS_BASE_PTR |
544 | #undef RUY_OFFSET_RHS_BASE_PTR |
545 | #undef RUY_OFFSET_DST_BASE_PTR |
546 | #undef RUY_OFFSET_BIAS |
547 | #undef RUY_OFFSET_START_ROW |
548 | #undef RUY_OFFSET_START_COL |
549 | #undef RUY_OFFSET_LAST_ROW |
550 | #undef RUY_OFFSET_LAST_COL |
551 | #undef RUY_OFFSET_DST_ROWS |
552 | #undef RUY_OFFSET_DST_COLS |
553 | #undef RUY_OFFSET_LHS_STRIDE |
554 | #undef RUY_OFFSET_RHS_STRIDE |
555 | #undef RUY_OFFSET_DST_STRIDE |
556 | #undef RUY_OFFSET_DEPTH |
557 | #undef RUY_OFFSET_CLAMP_MIN |
558 | #undef RUY_OFFSET_CLAMP_MAX |
559 | #undef RUY_OFFSET_FLAGS |
560 | |
561 | #define RUY_OFFSET_BIAS 0 |
562 | #define RUY_OFFSET_LHS_SUMS 4 |
563 | #define RUY_OFFSET_RHS_SUMS 8 |
564 | #define RUY_OFFSET_LHS_BASE_PTR 12 |
565 | #define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 |
566 | #define RUY_OFFSET_MULTIPLIER_EXPONENT 20 |
567 | #define RUY_OFFSET_RHS_BASE_PTR 24 |
568 | #define RUY_OFFSET_DST_BASE_PTR 28 |
569 | #define RUY_OFFSET_LHS_ZERO_POINT 32 |
570 | #define RUY_OFFSET_RHS_ZERO_POINT 36 |
571 | #define RUY_OFFSET_DST_ZERO_POINT 40 |
572 | #define RUY_OFFSET_PROD_ZP_DEPTH 44 |
573 | #define RUY_OFFSET_START_ROW 48 |
574 | #define RUY_OFFSET_START_COL 52 |
575 | #define RUY_OFFSET_LAST_ROW 56 |
576 | #define RUY_OFFSET_LAST_COL 60 |
577 | #define RUY_OFFSET_DST_ROWS 64 |
578 | #define RUY_OFFSET_DST_COLS 68 |
579 | #define RUY_OFFSET_LHS_STRIDE 72 |
580 | #define RUY_OFFSET_RHS_STRIDE 76 |
581 | #define RUY_OFFSET_DST_STRIDE 80 |
582 | #define RUY_OFFSET_DEPTH 84 |
583 | #define RUY_OFFSET_CLAMP_MIN 88 |
584 | #define RUY_OFFSET_CLAMP_MAX 92 |
585 | #define RUY_OFFSET_FLAGS 96 |
586 | #define RUY_OFFSET_DST_TYPE_ID 97 |
587 | |
588 | #define RUY_STACK_OFFSET_SIZE 96 |
589 | #define RUY_STACK_OFFSET_DST_COL_PTR 0 |
590 | #define RUY_STACK_OFFSET_DST_PTR 16 |
591 | #define RUY_STACK_OFFSET_ROW 32 |
592 | #define RUY_STACK_OFFSET_COL 48 |
593 | #define RUY_STACK_OFFSET_LHS_COL_PTR 64 |
594 | #define RUY_STACK_OFFSET_RHS_COL_PTR 80 |
595 | |
// Statically verifies that the RUY_OFFSET_* byte offsets hard-coded into the
// 8-bit inline assembly kernel agree with the actual layout of the kernel
// params struct. The asm reads fields as "[%[params], #offset]", so any
// layout drift must be turned into a compile error here.
// NOTE(review): several offsets defined above have no matching static_assert
// (rhs_base_ptr, dst_base_ptr, start_col, dst_rows, dst_cols, dst_type_id)
// -- presumably an oversight; confirm against the 64-bit kernel's checks.
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
  // One assert per field: pins the field's offsetof() to its asm constant.
  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
                "" );
  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
                "" );
  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
                "" );
  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
                "" );
  static_assert(offsetof(Params, multiplier_fixedpoint) ==
                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
                "" );
  static_assert(
      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
      "" );
  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "" );
  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "" );
  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "" );
  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "" );
  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "" );
  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "" );
  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "" );
  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "" );
  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "" );
  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "" );
  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "" );
  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "" );
  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "" );
  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "" );
}
627 | |
628 | // Fast-int8 kernel, ported from ARM 64 version. |
629 | // Relevant target CPUs for this kernel include Krait 400 and A9, |
630 | // since these are 32-bit, out-of-order CPUs. |
631 | void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { |
632 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
633 | |
634 | CheckOffsetsInKernelParams8bit(params); |
635 | |
636 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
637 | const std::int8_t* rhs_col_ptr = |
638 | static_cast<const int8_t*>(params.rhs_base_ptr); |
639 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
640 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
641 | |
642 | // The asm kernel below has the following NEON register allocation: |
643 | // |
644 | // q6 - q13 are 128-bit (4x32b) accumulators. |
645 | // During accumulation, d0 -- d7 are used to load int8 data from LHS and |
646 | // d8 -- d11 from RHS: |
647 | // int8 RHS 16x2 block |
648 | // /-----------------------------| |
649 | // |d8.b[0-7] ..... d10.b[0-7]| |
650 | // | ... ... | |
651 | // |d9.b[0-7] ..... d11.b[0-7]| |
652 | // \-----------------------------/ |
653 | // int8 LHS 4x16 block |
654 | // /------------------------\ /-----------------------------| |
655 | // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | |
656 | // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | |
657 | // (Reload d0, d1, d2, d3) |
658 | // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | |
659 | // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | |
660 | // \------------------------/ \-----------------------------/ |
661 | // 128-bit accumulators 4x2 block |
662 | // |
663 | // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING |
664 | // optimization for this kernel. |
665 | asm volatile( |
666 | #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" |
667 | |
668 | // clang-format off |
669 | |
670 | // Load the first 64 bytes of LHS and RHS data. |
671 | "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
672 | // Clear accumulators. |
673 | RUY_MAKE_ZERO(q6) |
674 | RUY_MAKE_ZERO(q7) |
675 | "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
676 | RUY_MAKE_ZERO(q8) |
677 | RUY_MAKE_ZERO(q9) |
678 | RUY_MAKE_ZERO(q10) |
679 | "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
680 | RUY_MAKE_ZERO(q11) |
681 | "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" |
682 | |
683 | "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
684 | |
685 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
686 | RUY_MAKE_ZERO(q12) |
687 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
688 | |
689 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
690 | RUY_MAKE_ZERO(q13) |
691 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
692 | |
693 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
694 | RUY_MAKE_ZERO(q14) |
695 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
696 | |
697 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
698 | RUY_MAKE_ZERO(q15) |
699 | "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
700 | |
701 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
702 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
703 | |
704 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
705 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
706 | |
707 | |
708 | // r1 is the number of levels of depth that we have already loaded |
709 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
710 | // above, this is currently 16. |
711 | "mov r1, #16\n" |
712 | |
713 | // Main loop of the whole GEMM, over rows and columns of the |
714 | // destination matrix. |
715 | "1:\n" |
716 | |
717 | // r1 is how many levels of depth we have already loaded |
718 | // data for, r10 is the total depth. |
719 | "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
720 | "cmp r1, r10\n" |
721 | "beq 79f\n" |
722 | |
723 | "2:\n" |
724 | |
725 | // Mult, mult-acc in to q14, q15, q2, q3 |
726 | "vmull.s8 q14, d0, d8\n" |
727 | "vmull.s8 q2, d0, d10\n" |
728 | |
729 | "vmull.s8 q15, d2, d8\n" |
730 | "vmull.s8 q3, d2, d10\n" |
731 | |
732 | "vmlal.s8 q14, d1, d9\n" |
733 | "vmlal.s8 q2, d1, d11\n" |
734 | "vmlal.s8 q15, d3, d9\n" |
735 | "vmlal.s8 q3, d3, d11\n" |
736 | "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
737 | |
738 | // Then pairwise accumulate in to q6, q7, q10, q11 |
739 | "vpadal.s16 q6, q14\n" |
740 | "vpadal.s16 q7, q15\n" |
741 | "vpadal.s16 q10, q2\n" |
742 | "vpadal.s16 q11, q3\n" |
743 | |
744 | // Mult, mult-acc in to q14, q15, q2, q3 |
745 | "vmull.s8 q14, d0, d8\n" |
746 | "vmull.s8 q2, d0, d10\n" |
747 | |
748 | "vmull.s8 q15, d2, d8\n" |
749 | "vmull.s8 q3, d2, d10\n" |
750 | |
751 | "vmlal.s8 q14, d1, d9\n" |
752 | "vmlal.s8 q2, d1, d11\n" |
753 | "vmlal.s8 q15, d3, d9\n" |
754 | "vmlal.s8 q3, d3, d11\n" |
755 | "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
756 | |
757 | // Then pairwise accumulate in to q8, q9, q12, q13 |
758 | "vpadal.s16 q8, q14\n" |
759 | "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" |
760 | "vpadal.s16 q9, q15\n" |
761 | "vpadal.s16 q12, q2\n" |
762 | "vpadal.s16 q13, q3\n" |
763 | |
764 | // Prefetch the next 64 bytes of LHS and RHS data. |
765 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
766 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
767 | |
768 | // Each iteration of this loop advances by 16 levels of depth. |
769 | "add r1, r1, #16\n" |
770 | |
771 | // Loop termination condition |
772 | "cmp r1, r10\n" |
773 | |
774 | "blt 2b\n" |
775 | |
776 | "79:\n" |
777 | |
778 | // Mult, mult-acc in to q14, q15, q2, q3 |
779 | "vmull.s8 q14, d0, d8\n" |
780 | "vmull.s8 q2, d0, d10\n" |
781 | |
782 | "vmull.s8 q15, d2, d8\n" |
783 | "vmull.s8 q3, d2, d10\n" |
784 | |
785 | "vmlal.s8 q14, d1, d9\n" |
786 | "vmlal.s8 q2, d1, d11\n" |
787 | "vmlal.s8 q15, d3, d9\n" |
788 | "vmlal.s8 q3, d3, d11\n" |
789 | "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS |
790 | |
791 | // Then pairwise accumulate in to q6, q7, q10, q11 |
792 | "vpadal.s16 q6, q14\n" |
793 | "vpadal.s16 q7, q15\n" |
794 | "vpadal.s16 q10, q2\n" |
795 | "vpadal.s16 q11, q3\n" |
796 | |
797 | // Mult, mult-acc in to q14, q15, q2, q3 |
798 | "vmull.s8 q14, d0, d8\n" |
799 | "vmull.s8 q2, d0, d10\n" |
800 | |
801 | "vmull.s8 q15, d2, d8\n" |
802 | "vmull.s8 q3, d2, d10\n" |
803 | |
804 | "vmlal.s8 q14, d1, d9\n" |
805 | "vmlal.s8 q2, d1, d11\n" |
806 | "vmlal.s8 q15, d3, d9\n" |
807 | "vmlal.s8 q3, d3, d11\n" |
808 | |
809 | // Then pairwise accumulate in to q8, q9, q12, q13 |
810 | "vpadal.s16 q8, q14\n" |
811 | "vpadal.s16 q9, q15\n" |
812 | "vpadal.s16 q12, q2\n" |
813 | "vpadal.s16 q13, q3\n" |
814 | |
815 | |
816 | // All accumulation over depth done. q6 - q13 contain the 4x32b |
817 | // accumulators for the 4x2 final matrix. |
818 | // We now have to compute the final 8-bit values from these int32 |
819 | // accumulators, and advance to the next 4x2 block. We intertwine |
820 | // these two aspects whenever possible for optimal pipelining, both |
821 | // at the data flow level (prefetch data for next block as early as |
822 | // possible) and instruction pipelining level (some of the next-block |
823 | // work can dual-issue with some of the final work on the current |
824 | // block). |
825 | |
826 | // q6-q13 now contain 4 x 32b |
827 | "vpadd.i32 d0, d12, d13\n" |
828 | "vpadd.i32 d1, d14, d15\n" |
829 | "vpadd.i32 d2, d16, d17\n" |
830 | "vpadd.i32 d3, d18, d19\n" |
831 | "vpadd.i32 d4, d20, d21\n" |
832 | "vpadd.i32 d5, d22, d23\n" |
833 | "vpadd.i32 d6, d24, d25\n" |
834 | "vpadd.i32 d7, d26, d27\n" |
835 | |
836 | // d0-d7 each contain 2 x 32b accumulators. |
837 | // Need to add pairwise to get 1 x 32b for each of the 4x2 entries |
838 | // of destination, (Four 'd' registers total) |
839 | "vpadd.i32 d28, d0, d1\n" |
840 | "vpadd.i32 d29, d2, d3\n" |
841 | "vpadd.i32 d30, d4, d5\n" |
842 | "vpadd.i32 d31, d6, d7\n" |
843 | |
// Now d28 -- d31 have the 1 x 32b accumulators for the 4x2 entries.
845 | |
846 | // Logic to advance to the next block in preparation for the next |
847 | // iteration of the main loop. For now, we only want to compute |
848 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
849 | // not yet ready to update the values of row and col, as we still need |
850 | // the current values for the rest of the work on the current block. |
851 | |
852 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
853 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
854 | "cmp r1, r3\n" // Have we finished the last row? |
855 | |
856 | "bge 4f\n" // If finished last row, go to 4 |
857 | // Not finished last row: then advance to next row. |
858 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
859 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
860 | "add r4, r4, r1, lsl #2\n" |
861 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
862 | "b 5f\n" |
863 | "4:\n" // Finished last row... |
864 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
865 | // Go back to first row |
866 | "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
867 | |
868 | // Now we need to advance to the next column. If we already |
869 | // finished the last column, then in principle we are done, however |
870 | // we can't just return here, as we need to allow the end work of the |
871 | // current block to complete. The good news is that at this point it |
872 | // doesn't matter what data we load for the next column, since |
873 | // we will exit from the main loop below before actually storing |
874 | // anything computed from that data. |
875 | |
876 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
877 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
878 | "cmp r8, r4\n" // Have we finished the last column? |
879 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
880 | // Not finished last column: then advance to next column. |
881 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
882 | "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
883 | "add r10, r10, r1, lsl #1\n" |
884 | "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
885 | "5:\n" |
886 | |
887 | // Set the LHS and RHS data pointers to the start of the columns just |
888 | // computed. |
889 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
890 | "mov %[lhs_ptr], r4\n" |
891 | "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
892 | "mov %[rhs_ptr], r5\n" |
893 | |
894 | // Now we load: bias data, LHS sums data, RHS sums data. |
895 | |
896 | // First, load the base pointers from the params. |
897 | "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
898 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
899 | |
900 | // Let r8 be stack offset of the row or column variable, whichever |
901 | // is the channel index. |
902 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
903 | "bne 1000f\n" |
904 | "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n" |
905 | "b 1001f\n" |
906 | "1000:\n" |
907 | "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n" |
908 | "1001:\n" |
909 | |
910 | // Let r8 be the channel index. |
911 | "ldr r8, [sp, r8]\n" |
912 | // Compute the bias pointer, by conditionally using the channel index |
913 | // (r8) as offset into bias buffer (r1). |
914 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
915 | "beq 1002f\n" |
916 | "add r1, r1, r8, lsl #2\n" |
917 | "1002:\n" |
918 | |
919 | // Load 2 bias values. When the channel dimension is rows, we will load |
920 | // another 2 bias values just before performing the bias addition below, |
921 | // as this kernel has a 4x2 rectangular shape. |
922 | "vld1.32 {d24}, [r1]!\n" |
923 | |
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into q0 -- q1 (LHS) and q4 -- q5 (RHS), as we
// don't need those registers anymore in the rest of the work on the
// current block.
928 | "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" |
929 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
930 | "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" |
931 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
932 | |
933 | // Add to the bias values the product |
934 | // (depth * lhs_zero_point * rhs_zero_point), |
935 | // See the term NZ1Z2 in equation (7) in |
936 | // https://arxiv.org/pdf/1712.05877.pdf |
937 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
938 | "vdup.32 q9, r3\n" |
939 | "vadd.i32 d24, d24, d18\n" |
940 | |
941 | // Perform the bias-addition (per the above, we have just folded into |
942 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
943 | // Jump based on channel dimension. |
944 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
945 | "bne 6f\n" |
946 | // Case where channels are rows. |
947 | // Load the remaining 2 bias values, since we're on the width-4 side |
948 | // of this 4x2 kernel. |
949 | "vld1.32 {d25}, [r1]\n" |
950 | "vadd.i32 d25, d25, d19\n" |
951 | "vadd.i32 q14, q14, q12\n" |
952 | "vadd.i32 q15, q15, q12\n" |
953 | "b 7f\n" |
954 | |
955 | "6:\n" |
956 | // Case where channels are columns. |
957 | "vdup.32 q10, d24[0]\n" |
958 | "vdup.32 q11, d24[1]\n" |
959 | "vadd.i32 q14, q14, q10\n" |
960 | "vadd.i32 q15, q15, q11\n" |
961 | "7:\n" |
962 | |
963 | // LHS/RHS zero points |
964 | // Has RHS sums |
965 | "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
966 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
967 | "beq 401f\n" |
968 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
969 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
970 | // Offset by current col * number of bytes per value |
971 | "add r3, r3, r4, lsl #2\n" |
972 | "vld1.32 { d12 }, [r3]\n" |
973 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
974 | "vdup.32 q10, r5\n" // create lhs_zero_point_vec |
975 | // Subtract rhs_sums * lhs_zero_point, per |
976 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
977 | "vmls.i32 q14, q10, d12[0]\n" |
978 | "vmls.i32 q15, q10, d12[1]\n" |
979 | "401:\n" |
980 | |
981 | // Has LHS sums |
982 | "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
983 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
984 | "beq 402f\n" |
985 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
986 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
987 | // Offset by current row * number of bytes per value |
988 | "add r2, r2, r4, lsl #2\n" |
989 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
990 | |
991 | // Load 4 lhs_sums values. |
992 | "vld1.32 {d22, d23}, [r2]\n" |
993 | "vdup.32 d13, r5\n" // rhs_zero_point |
994 | |
995 | // Compute lhs_sums * rhs_zero_point. |
996 | "vmul.i32 q11, q11, d13[1]\n" |
997 | // Subtract lhs_sums * rhs_zero_point, per |
998 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
999 | "vsub.s32 q14, q14, q11\n" |
1000 | "vsub.s32 q15, q15, q11\n" |
1001 | |
1002 | // If the destination is int32, it means the user asks for the raw |
1003 | // accumulators, no need for us to downquantize the value. |
1004 | "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
1005 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
1006 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
1007 | |
1008 | "402:\n" |
1009 | |
1010 | // At this point we have computed the final int32 values. Now we |
1011 | // start down-quantizing them to obtain the final 8bit values from them. |
1012 | |
1013 | // As part of this down-quantization, our int32 values will be |
1014 | // multiplied by a multiplier that has a fixed-point component and an |
1015 | // exponent component. |
1016 | |
1017 | // Compute the data pointers for the multiplier data |
1018 | // r1 = exponent part |
1019 | // r2 = fixedpoint part |
1020 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
1021 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
1022 | // r6 has flags, r8 has channel index |
1023 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
1024 | "beq 1003f\n" |
1025 | "add r1, r1, r8, lsl #2\n" |
1026 | "add r2, r2, r8, lsl #2\n" |
1027 | "1003:\n" |
1028 | |
1029 | // Load the first 2 values of multiplier exponent and fixedpoint data |
1030 | // Since this kernel is rectangular 4x2, we will only conditionally load |
1031 | // 2 more values below. |
1032 | "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent |
1033 | "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint |
1034 | |
1035 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n" |
1036 | "vmvn.i32 q8, #0\n" |
1037 | "bne 8f\n" |
1038 | // Case where channels are rows. |
1039 | // Load the remaining 2 bias values, since we're on the width-4 side |
1040 | // of this 4x2 kernel. |
1041 | "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent |
1042 | "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint |
1043 | "vmin.s32 q11, q10, q8\n" |
1044 | "vsub.s32 q10, q10, q11\n" |
1045 | |
1046 | // Apply the positive exponent part of the multiplier. |
1047 | "vshl.s32 q14, q14, q10\n" |
1048 | "vshl.s32 q15, q15, q10\n" |
1049 | |
1050 | // Apply the fixed-point part of the multiplier. |
1051 | "vqdmulh.s32 q14, q14, q6\n" |
1052 | "vqdmulh.s32 q15, q15, q6\n" |
1053 | |
1054 | // Apply the negative exponent part of the multiplier. |
1055 | "vrshl.s32 q14, q14, q11\n" |
1056 | "vrshl.s32 q15, q15, q11\n" |
1057 | "b 9f\n" |
1058 | |
1059 | "8:\n" |
1060 | // Case where channels are columns. |
1061 | "vmin.s32 d22, d20, d16\n" |
1062 | "vsub.s32 d20, d20, d22\n" |
1063 | |
1064 | // Apply the positive exponent part of the multiplier. |
1065 | "vdup.32 q12, d20[0]\n" |
1066 | "vdup.32 q13, d20[1]\n" |
1067 | "vshl.s32 q14, q14, q12\n" |
1068 | "vshl.s32 q15, q15, q13\n" |
1069 | |
1070 | // Apply the fixed-point part of the multiplier. |
1071 | "vqdmulh.s32 q14, q14, d12[0]\n" |
1072 | "vqdmulh.s32 q15, q15, d12[1]\n" |
1073 | |
1074 | // Apply the negative exponent part of the multiplier. |
1075 | "vdup.32 q12, d22[0]\n" |
1076 | "vdup.32 q13, d22[1]\n" |
1077 | "vrshl.s32 q14, q14, q12\n" |
1078 | "vrshl.s32 q15, q15, q13\n" |
1079 | |
1080 | "9:\n" |
1081 | |
1082 | "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
1083 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
1084 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
1085 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
1086 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
1087 | |
1088 | // Store uint8 values: |
1089 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
1090 | |
1091 | // Cast-and-saturate from int32 to int16 |
1092 | // After this, all values for output are in q14. |
1093 | "vqmovn.s32 d28, q14\n" |
1094 | "vqmovn.s32 d29, q15\n" |
1095 | |
// At this point, d12 -- d27, d30, d31 aren't used anymore for the
1097 | // current block, so we can start clearing these accumulators for the |
1098 | // next block (next iteration of the main loop). |
1099 | RUY_MAKE_ZERO(q6) |
1100 | RUY_MAKE_ZERO(q7) |
1101 | RUY_MAKE_ZERO(q8) |
1102 | RUY_MAKE_ZERO(q9) |
1103 | RUY_MAKE_ZERO(q10) |
1104 | RUY_MAKE_ZERO(q11) |
1105 | RUY_MAKE_ZERO(q12) |
1106 | RUY_MAKE_ZERO(q13) |
1107 | RUY_MAKE_ZERO(q15) |
1108 | |
1109 | // Load the destination zero point into each of the 8 16-bit slots |
1110 | // in a q register. |
1111 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
1112 | "vdup.16 q13, r4\n" // dst_zero_point |
1113 | |
1114 | // Add the destination zero point |
1115 | "vqadd.s16 q14, q14, q13\n" |
1116 | |
1117 | // Cast-and-saturate from int16 to uint8 |
1118 | // Now all 8 1-byte values are in d30. |
1119 | "vqmovun.s16 d30, q14\n" |
1120 | |
1121 | // Load the clamp_min, clamp_max bounds |
1122 | "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1123 | "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1124 | "vdup.8 d28, r2\n" // clamp_min |
1125 | "vdup.8 d29, r3\n" // clamp_max |
1126 | |
1127 | // Apply the clamp_min bound |
1128 | "vmax.u8 d30, d30, d28\n" |
1129 | // Apply the clamp_max bound |
1130 | "vmin.u8 d30, d30, d29\n" |
1131 | |
1132 | // Compute how much of the 4x2 block of destination 8bit values that |
1133 | // we have computed, fit in the destination matrix. Typically, all of |
1134 | // it fits, but when the destination matrix shape is not a multiple |
1135 | // of 4x2, there are some 4x2 blocks along the boundaries that do |
1136 | // not fit entirely. |
1137 | |
1138 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
1139 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1140 | "sub r1, r1, r8\n" |
1141 | |
1142 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
1143 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1144 | "sub r2, r2, r4\n" |
1145 | "mov r3, #4\n" |
1146 | "mov r5, #2\n" |
1147 | "cmp r1, #4\n" |
1148 | // Compute r1 = how many rows of the 4x2 block fit |
1149 | "it gt\n" |
1150 | "movgt r1, r3\n" |
1151 | |
1152 | "cmp r2, #2\n" |
1153 | // Compute r2 = how many cols of the 4x2 block fit |
1154 | "it gt\n" |
1155 | "movgt r2, r5\n" |
1156 | |
1157 | // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
1158 | "cmp r1, r3\n" |
1159 | "it eq\n" |
1160 | "cmpeq r2, r5\n" |
1161 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1162 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1163 | // Yes, all of the 4x2 block fits, go to fast path. |
1164 | "beq 30f\n" |
1165 | // Not all of the 4x2 block fits. |
1166 | // Store to dst_tmp_buf |
1167 | // Set r3 address to write to dst_tmp_buf. |
1168 | "mov r3, %[dst_tmp_buf]\n" |
1169 | "vst1.8 {d30}, [r3]\n" |
1170 | |
1171 | // Slow loop copying from dst_tmp_buf to dst. |
1172 | "mov r6, #0\n" |
1173 | "50:\n" |
1174 | "mov r8, #0\n" |
1175 | "51:\n" |
1176 | "ldrb r10, [r3, r8]\n" |
1177 | "strb r10, [r4, r8]\n" |
1178 | "add r8, r8, #1\n" |
1179 | "cmp r8, r1\n" |
1180 | "blt 51b\n" |
1181 | "add r6, r6, #1\n" |
1182 | "add r3, r3, #4\n" |
1183 | "add r4, r4, r5\n" |
1184 | "cmp r6, r2\n" |
1185 | "blt 50b\n" |
1186 | "b 31f\n" |
1187 | "30:\n" |
1188 | // Yes, all of the 4x2 block fits. |
1189 | // r3 address, r5 stride |
1190 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1191 | "mov r4, r3\n" |
1192 | "mov r6, #1\n" |
1193 | |
1194 | "vst1.32 {d30[0]}, [r3]\n" |
1195 | "add r4, r4, r5\n" |
1196 | "mov r3, r4\n" |
1197 | "vst1.32 {d30[1]}, [r3]\n" |
1198 | |
1199 | "31:\n" |
1200 | |
1201 | // Load dst_ptr, increment, and write back. |
1202 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1203 | "add r4, r4, #4\n" |
1204 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1205 | |
1206 | RUY_MAKE_ZERO(q13) |
1207 | RUY_MAKE_ZERO(q14) |
1208 | RUY_MAKE_ZERO(q15) |
1209 | |
1210 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1211 | |
1212 | // Store int8 values: |
1213 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
1214 | |
1215 | // Cast-and-saturate from int32 to int16 |
1216 | // After this, all values for output are in q14. |
1217 | "vqmovn.s32 d28, q14\n" |
1218 | "vqmovn.s32 d29, q15\n" |
1219 | |
// At this point, d12 -- d27, d30, d31 aren't used anymore for the
1221 | // current block, so we can start clearing these accumulators for the |
1222 | // next block (next iteration of the main loop). |
1223 | RUY_MAKE_ZERO(q6) |
1224 | RUY_MAKE_ZERO(q7) |
1225 | RUY_MAKE_ZERO(q8) |
1226 | RUY_MAKE_ZERO(q9) |
1227 | RUY_MAKE_ZERO(q10) |
1228 | RUY_MAKE_ZERO(q11) |
1229 | RUY_MAKE_ZERO(q12) |
1230 | RUY_MAKE_ZERO(q13) |
1231 | RUY_MAKE_ZERO(q15) |
1232 | |
1233 | // Load the destination zero point into each of the 8 16-bit slots |
1234 | // in a q register. |
1235 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
1236 | "vdup.16 q13, r4\n" // dst_zero_point |
1237 | |
1238 | // Add the destination zero point |
1239 | "vqadd.s16 q14, q14, q13\n" |
1240 | |
1241 | // Cast-and-saturate from int16 to int8 |
1242 | // Now all 8 1-byte values are in d30. |
1243 | "vqmovn.s16 d30, q14\n" |
1244 | |
1245 | // Load the clamp_min, clamp_max bounds |
1246 | "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1247 | "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1248 | "vdup.8 d28, r2\n" // clamp_min |
1249 | "vdup.8 d29, r3\n" // clamp_max |
1250 | |
1251 | // Apply the clamp_min bound |
1252 | "vmax.s8 d30, d30, d28\n" |
1253 | // Apply the clamp_max bound |
1254 | "vmin.s8 d30, d30, d29\n" |
1255 | |
1256 | // Compute how much of the 4x2 block of destination 8bit values that |
1257 | // we have computed, fit in the destination matrix. Typically, all of |
1258 | // it fits, but when the destination matrix shape is not a multiple |
1259 | // of 4x2, there are some 4x2 blocks along the boundaries that do |
1260 | // not fit entirely. |
1261 | |
1262 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
1263 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1264 | "sub r1, r1, r8\n" |
1265 | |
1266 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
1267 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1268 | "sub r2, r2, r4\n" |
1269 | "mov r3, #4\n" |
1270 | "mov r5, #2\n" |
1271 | "cmp r1, #4\n" |
1272 | // Compute r1 = how many rows of the 4x2 block fit |
1273 | "it gt\n" |
1274 | "movgt r1, r3\n" |
1275 | |
1276 | "cmp r2, #2\n" |
1277 | // Compute r2 = how many cols of the 4x2 block fit |
1278 | "it gt\n" |
1279 | "movgt r2, r5\n" |
1280 | |
1281 | // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
1282 | "cmp r1, r3\n" |
1283 | "it eq\n" |
1284 | "cmpeq r2, r5\n" |
1285 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1286 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1287 | // Yes, all of the 4x2 block fits, go to fast path. |
1288 | "beq 30f\n" |
1289 | // Not all of the 4x2 block fits. |
1290 | // Store to dst_tmp_buf |
1291 | // Set r3 address to write to dst_tmp_buf. |
1292 | "mov r3, %[dst_tmp_buf]\n" |
1293 | "vst1.8 {d30}, [r3]\n" |
1294 | |
1295 | // Slow loop copying from dst_tmp_buf to dst. |
1296 | "mov r6, #0\n" |
1297 | "50:\n" |
1298 | "mov r8, #0\n" |
1299 | "51:\n" |
1300 | "ldrb r10, [r3, r8]\n" |
1301 | "strb r10, [r4, r8]\n" |
1302 | "add r8, r8, #1\n" |
1303 | "cmp r8, r1\n" |
1304 | "blt 51b\n" |
1305 | "add r6, r6, #1\n" |
1306 | "add r3, r3, #4\n" |
1307 | "add r4, r4, r5\n" |
1308 | "cmp r6, r2\n" |
1309 | "blt 50b\n" |
1310 | "b 31f\n" |
1311 | "30:\n" |
1312 | // Yes, all of the 4x2 block fits. |
1313 | // r3 address, r5 stride |
1314 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1315 | "mov r4, r3\n" |
1316 | "mov r6, #1\n" |
1317 | |
1318 | "vst1.32 {d30[0]}, [r3]\n" |
1319 | "add r4, r4, r5\n" |
1320 | "mov r3, r4\n" |
1321 | "vst1.32 {d30[1]}, [r3]\n" |
1322 | |
1323 | "31:\n" |
1324 | |
1325 | // Load dst_ptr, increment, and write back. |
1326 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1327 | "add r4, r4, #4\n" |
1328 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1329 | |
1330 | RUY_MAKE_ZERO(q13) |
1331 | RUY_MAKE_ZERO(q14) |
1332 | RUY_MAKE_ZERO(q15) |
1333 | |
1334 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1335 | |
1336 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
1337 | |
1338 | // Load the destination zero point into each of the 4 32-bit slots |
1339 | // in a q register. |
1340 | "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
1341 | "vdup.32 q13, r4\n" // dst_zero_point |
1342 | // Add the destination zero point |
1343 | "vadd.s32 q14, q14, q13\n" |
1344 | "vadd.s32 q15, q15, q13\n" |
1345 | |
1346 | // Cast-and-saturate from int32 to int16 |
1347 | // After this, all values for output are in q14. |
1348 | "vqmovn.s32 d28, q14\n" |
1349 | "vqmovn.s32 d29, q15\n" |
1350 | |
// At this point, q6 -- q11 and q15 aren't used anymore for the current
// block, so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
1354 | RUY_MAKE_ZERO(q6) |
1355 | RUY_MAKE_ZERO(q7) |
1356 | RUY_MAKE_ZERO(q8) |
1357 | RUY_MAKE_ZERO(q9) |
1358 | RUY_MAKE_ZERO(q10) |
1359 | RUY_MAKE_ZERO(q11) |
1360 | RUY_MAKE_ZERO(q15) |
1361 | |
1362 | // Load the clamp_min, clamp_max bounds |
1363 | "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
1364 | "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
1365 | "vdup.16 q12, r2\n" // clamp_min |
1366 | "vdup.16 q13, r3\n" // clamp_max |
1367 | |
1368 | // Apply the clamp_min bound |
1369 | "vmax.s16 q14, q14, q12\n" |
1370 | // Apply the clamp_max bound |
1371 | "vmin.s16 q14, q14, q13\n" |
1372 | |
1373 | RUY_MAKE_ZERO(q12) |
1374 | RUY_MAKE_ZERO(q13) |
1375 | |
1376 | // Compute how much of the 4x2 block of destination 16-bit values that |
1377 | // we have computed, fit in the destination matrix. Typically, all of |
1378 | // it fits, but when the destination matrix shape is not a multiple |
1379 | // of 4x2, there are some 4x2 blocks along the boundaries that do |
1380 | // not fit entirely. |
1381 | |
1382 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
1383 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1384 | "sub r1, r1, r8\n" |
1385 | |
1386 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
1387 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1388 | "sub r2, r2, r4\n" |
1389 | "mov r3, #4\n" |
1390 | "mov r5, #2\n" |
1391 | "cmp r1, #4\n" |
1392 | // Compute r1 = how many rows of the 4x2 block fit |
1393 | "it gt\n" |
1394 | "movgt r1, r3\n" |
1395 | |
1396 | "cmp r2, #2\n" |
1397 | // Compute r2 = how many cols of the 4x2 block fit |
1398 | "it gt\n" |
1399 | "movgt r2, r5\n" |
1400 | |
1401 | // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
1402 | "cmp r1, r3\n" |
1403 | "it eq\n" |
1404 | "cmpeq r2, r5\n" |
1405 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1406 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1407 | // Yes, all of the 4x2 block fits, go to fast path. |
1408 | "beq 30f\n" |
1409 | // Not all of the 4x2 block fits. |
1410 | // Store to dst_tmp_buf |
1411 | // Set r3 address to write to dst_tmp_buf. |
1412 | "mov r3, %[dst_tmp_buf]\n" |
1413 | "vst1.16 {q14}, [r3]\n" |
1414 | |
1415 | // Slow loop copying from dst_tmp_buf to dst. |
1416 | "mov r6, #0\n" |
1417 | "50:\n" |
1418 | "mov r8, #0\n" |
1419 | "51:\n" |
1420 | // Shift of offset register for half-word loads not allowed in A32, |
1421 | // so we shift, load/store, then shift back r8. |
1422 | "lsl r8, r8, #1\n" |
1423 | "ldrh r10, [r3, r8]\n" |
1424 | "strh r10, [r4, r8]\n" |
1425 | "lsr r8, r8, #1\n" |
1426 | "add r8, r8, #1\n" |
1427 | "cmp r8, r1\n" |
1428 | "blt 51b\n" |
1429 | "add r6, r6, #1\n" |
1430 | "add r3, r3, #8\n" |
1431 | "add r4, r4, r5\n" |
1432 | "cmp r6, r2\n" |
1433 | "blt 50b\n" |
1434 | "b 31f\n" |
1435 | "30:\n" |
1436 | // Yes, all of the 4x2 block fits. |
1437 | // r3 address, r5 stride |
1438 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1439 | "mov r4, r3\n" |
1440 | "mov r6, #2\n" |
1441 | |
1442 | "vst1.16 {d28[0]}, [r3], r6\n" |
1443 | "add r4, r4, r5\n" |
1444 | "vst1.16 {d28[1]}, [r3], r6\n" |
1445 | "vst1.16 {d28[2]}, [r3], r6\n" |
1446 | "vst1.16 {d28[3]}, [r3], r6\n" |
1447 | "mov r3, r4\n" |
1448 | "vst1.16 {d29[0]}, [r3], r6\n" |
1449 | "vst1.16 {d29[1]}, [r3], r6\n" |
1450 | "vst1.16 {d29[2]}, [r3], r6\n" |
1451 | "vst1.16 {d29[3]}, [r3], r6\n" |
1452 | "31:\n" |
1453 | |
1454 | // Load dst_ptr, increment, and write back. |
1455 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1456 | "add r4, r4, #8\n" |
1457 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1458 | |
1459 | RUY_MAKE_ZERO(q14) |
1460 | |
1461 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1462 | |
1463 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
1464 | |
1465 | // Since the store type is the same as the accum type, no need for |
1466 | // downcast. There's also no need for clamp by min/max. |
1467 | |
// At this point, q6 -- q13 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
1471 | // Clear accumulators. |
1472 | RUY_MAKE_ZERO(q6) |
1473 | RUY_MAKE_ZERO(q7) |
1474 | RUY_MAKE_ZERO(q8) |
1475 | RUY_MAKE_ZERO(q9) |
1476 | RUY_MAKE_ZERO(q10) |
1477 | RUY_MAKE_ZERO(q11) |
1478 | RUY_MAKE_ZERO(q12) |
1479 | RUY_MAKE_ZERO(q13) |
1480 | |
1481 | // Compute how much of the 4x2 block of destination 32 bit values that |
1482 | // we have computed, fit in the destination matrix. Typically, all of |
1483 | // it fits, but when the destination matrix shape is not a multiple |
// of 4x2, there are some 4x2 blocks along the boundaries that do
1485 | // not fit entirely. |
1486 | |
1487 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
1488 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1489 | "sub r1, r1, r8\n" |
1490 | |
1491 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
1492 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1493 | "sub r2, r2, r4\n" |
1494 | "mov r3, #4\n" |
1495 | "mov r5, #2\n" |
1496 | "cmp r1, #4\n" |
1497 | // Compute r1 = how many rows of the 4x2 block fit |
1498 | "it gt\n" |
1499 | "movgt r1, r3\n" |
1500 | |
1501 | "cmp r2, #2\n" |
1502 | // Compute r2 = how many cols of the 4x2 block fit |
1503 | "it gt\n" |
1504 | "movgt r2, r5\n" |
1505 | |
1506 | // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. |
1507 | "cmp r1, r3\n" |
1508 | "it eq\n" |
1509 | "cmpeq r2, r5\n" |
1510 | // Yes, all of the 4x2 block fits, go to fast path. |
1511 | "beq 30f\n" |
1512 | // Not all of the 4x2 block fits. |
1513 | // Set (r3 address, r4 stride) to write to dst_tmp_buf |
1514 | "mov r3, %[dst_tmp_buf]\n" |
1515 | "mov r4, #16\n" |
1516 | "b 31f\n" |
1517 | |
1518 | "30:\n" |
1519 | // Yes, all of the 4x2 block fits. |
1520 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1521 | // r3 address, r4 stride |
1522 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1523 | "mov r4, r5\n" |
1524 | |
1525 | "31:\n" |
1526 | |
1527 | "vst1.32 {d28, d29}, [r3]\n" |
1528 | "add r3, r3, r4\n" |
1529 | "vst1.32 {d30, d31}, [r3]\n" |
1530 | |
1531 | // If all of the 4x2 block fits, we just finished writing it to the |
1532 | // destination, so we skip the next part. |
1533 | "beq 41f\n" |
1534 | // Not all of the 4x2 block fits in the destination matrix. We just |
1535 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
1536 | // it to copy into the destination matrix the part that fits. |
1537 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1538 | "mov r3, %[dst_tmp_buf]\n" |
1539 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1540 | "mov r6, #0\n" |
1541 | "50:\n" |
1542 | "mov r5, #0\n" |
1543 | "51:\n" |
1544 | "ldr r10, [r3, r5, lsl #2]\n" |
1545 | "str r10, [r4, r5, lsl #2]\n" |
1546 | "add r5, r5, #1\n" |
1547 | "cmp r5, r1\n" |
1548 | "blt 51b\n" |
1549 | "add r6, r6, #1\n" |
1550 | "add r3, r3, #16\n" |
1551 | "add r4, r4, r8\n" |
// r2 = how many cols of the 4x2 block fit
1553 | "cmp r6, r2\n" |
1554 | "blt 50b\n" |
1555 | |
1556 | "41:\n" |
1557 | // Load dst_ptr, increment, and write back. |
1558 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1559 | "add r4, r4, #16\n" |
1560 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1561 | |
1562 | RUY_MAKE_ZERO(q10) |
1563 | RUY_MAKE_ZERO(q11) |
1564 | |
1565 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
1566 | |
1567 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
1568 | |
// Reload some params --- we had used r3, r5 and r6 for a few other
// things since the last time we had loaded them.
1571 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1572 | "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1573 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1574 | |
1575 | // Move to the next block of the destination matrix, for the next iter |
1576 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
1577 | // been updated earlier. |
1578 | // Have we reached the end row? |
1579 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1580 | "cmp r8, r3\n" |
1581 | |
1582 | "beq 20f\n" // yes, end row. |
1583 | // Not end row. Move to the next row. |
1584 | "add r8, r8, #4\n" |
1585 | // Store new value of row |
1586 | "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1587 | |
1588 | "b 21f\n" |
1589 | "20:\n" |
1590 | // Was already at end row. |
1591 | // Move back to first row. |
1592 | "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1593 | // Move to the next column. |
1594 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1595 | "add r4, r4, #2\n" |
1596 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1597 | |
1598 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
1599 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
1600 | // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns) |
1601 | "add r1, r1, r8, lsl #1\n" |
1602 | // Store dst_col_ptr |
1603 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
1604 | // Store dst_ptr |
1605 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1606 | "21:\n" |
1607 | |
1608 | // Main loop exit condition: have we hit the end column? |
1609 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
1610 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1611 | "cmp r8, r4\n" |
1612 | |
1613 | // r1 is the number of levels of depth that we have already loaded
1614 | // LHS and RHS data for. Corresponding to the initial vld1 instructions
1615 | // above, this is currently 16. |
1616 | "mov r1, #16\n" |
1617 | |
1618 | "ble 1b\n" |
1619 | |
1620 | // Restore stack pointer. |
1621 | "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
1622 | |
1623 | // clang-format on |
1624 | |
1625 | : [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr) |
1626 | : [ params ] "r" (&params), [dst_tmp_buf] "r" (params.dst_tmp_buf)
1627 | : "r0" , "r1" , "r2" , "r3" , "r4" , "r5" , "r6" , "r8" , "r10" , "cc" , |
1628 | // Clobber list must specify q registers (and not their constituent |
1629 | // d registers). There is a (currently unexplained) slowdown if |
1630 | // d registers are listed in the clobbers list. |
1631 | "memory" , "q0" , "q1" , "q2" , "q3" , "q4" , "q5" , "q6" , "q7" , "q8" , |
1632 | "q9" , "q10" , "q12" , "q13" , "q14" , "q15" ); |
1633 | } |
1634 | |
1635 | // Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
1636 | // is still packed as if it had two columns, so the second column is skipped.
1637 | void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) { |
1638 | profiler::ScopeLabel label("Kernel (kNeon)" ); |
1639 | |
1640 | CheckOffsetsInKernelParams8bit(params); |
1641 | |
1642 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
1643 | const std::int8_t* rhs_col_ptr = |
1644 | static_cast<const int8_t*>(params.rhs_base_ptr); |
1645 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
1646 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
1647 | |
1648 | RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)); |
1649 | |
1650 | // The asm kernel below has the following NEON register allocation: |
1651 | // |
1652 | // q6 - q13 are 128-bit (4x32b) accumulators. |
1653 | // During accumulation, d0 -- d7 are used to load int8 data from LHS and |
1654 | // d8 -- d11 from RHS: |
1655 | // int8 RHS 16x1 block |
1656 | // /------------| |
1657 | // | d8.b[0] | |
1658 | // | ... | |
1659 | // | d8.b[7] | |
1660 | // | d9.b[0] | |
1661 | // | ... | |
1662 | // | d9.b[7] | |
1663 | // \------------/ |
1664 | // int8 LHS 4x16 block |
1665 | // /-----------------------------------------\ /------------| |
1666 | // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | |
1667 | // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | |
1668 | // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | |
1669 | // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | |
1670 | // \-----------------------------------------/ \------------/ |
1671 | // 128-bit accumulators 4x1 block |
1672 | // |
1673 | // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING |
1674 | // optimization for this kernel. |
1675 | asm volatile( |
1676 | #define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" |
1677 | |
1678 | // clang-format off |
1679 | |
1680 | // Load the first 64 bytes of LHS and RHS data. |
1681 | "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
1682 | "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
1683 | "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
1684 | "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
1685 | "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
1686 | // Skip the other column and advance the pointer. |
1687 | "add %[rhs_ptr], %[rhs_ptr], #16\n" |
1688 | |
1689 | "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
1690 | |
1691 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
1692 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
1693 | |
1694 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" |
1695 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
1696 | |
1697 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
1698 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1699 | |
1700 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" |
1701 | "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1702 | |
1703 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1704 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
1705 | |
1706 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" |
1707 | "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
1708 | |
1709 | // Clear accumulators. |
1710 | RUY_MAKE_ZERO(q6) |
1711 | RUY_MAKE_ZERO(q7) |
1712 | RUY_MAKE_ZERO(q8) |
1713 | RUY_MAKE_ZERO(q9) |
1714 | RUY_MAKE_ZERO(q10) |
1715 | RUY_MAKE_ZERO(q11) |
1716 | RUY_MAKE_ZERO(q12) |
1717 | RUY_MAKE_ZERO(q13) |
1718 | RUY_MAKE_ZERO(q14) |
1719 | RUY_MAKE_ZERO(q15) |
1720 | |
1721 | // r1 is the number of levels of depth that we have already loaded |
1722 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
1723 | // above, this is currently 16. |
1724 | "mov r1, #16\n" |
1725 | |
1726 | // Main loop of the whole GEMM, over rows and columns of the |
1727 | // destination matrix. |
1728 | "1:\n" |
1729 | |
1730 | // r1 is how many levels of depth we have already loaded |
1731 | // data for, r10 is the total depth. |
1732 | "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" |
1733 | "cmp r1, r10\n" |
1734 | "beq 79f\n" |
1735 | |
1736 | "2:\n" |
1737 | |
1738 | // Mult, mult-acc in to q14, q15 |
1739 | "vmull.s8 q14, d0, d8\n" |
1740 | "vmull.s8 q15, d2, d8\n" |
1741 | "vmlal.s8 q14, d1, d9\n" |
1742 | "vmlal.s8 q15, d3, d9\n" |
1743 | |
1744 | // Then pairwise accumulate in to q6, q7 |
1745 | "vpadal.s16 q6, q14\n" |
1746 | "vpadal.s16 q7, q15\n" |
1747 | |
1748 | // Mult, mult-acc in to q14, q15 |
1749 | "vmull.s8 q14, d4, d8\n" |
1750 | "vmull.s8 q15, d6, d8\n" |
1751 | "vmlal.s8 q14, d5, d9\n" |
1752 | "vmlal.s8 q15, d7, d9\n" |
1753 | |
1754 | // Then pairwise accumulate in to q8, q9 |
1755 | "vpadal.s16 q8, q14\n" |
1756 | "vpadal.s16 q9, q15\n" |
1757 | |
1758 | |
1759 | // Load the next 64 bytes of LHS and RHS data. |
1760 | "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
1761 | "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
1762 | "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
1763 | "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
1764 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
1765 | "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
1766 | // Skip the other column and advance the pointer. |
1767 | "add %[rhs_ptr], %[rhs_ptr], #16\n" |
1768 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
1769 | |
1770 | // Each iteration of this loop advances by 16 levels of depth. |
1771 | "add r1, r1, #16\n" |
1772 | |
1773 | // Loop termination condition |
1774 | "cmp r1, r10\n" |
1775 | |
1776 | "blt 2b\n" |
1777 | |
1778 | "79:\n" |
1779 | |
1780 | // Mult, mult-acc in to q14, q15 |
1781 | "vmull.s8 q14, d0, d8\n" |
1782 | "vmull.s8 q15, d2, d8\n" |
1783 | "vmlal.s8 q14, d1, d9\n" |
1784 | "vmlal.s8 q15, d3, d9\n" |
1785 | |
1786 | // Then pairwise accumulate in to q6, q7 |
1787 | "vpadal.s16 q6, q14\n" |
1788 | "vpadal.s16 q7, q15\n" |
1789 | |
1790 | // Mult, mult-acc in to q14, q15 |
1791 | "vmull.s8 q14, d4, d8\n" |
1792 | "vmull.s8 q15, d6, d8\n" |
1793 | "vmlal.s8 q14, d5, d9\n" |
1794 | "vmlal.s8 q15, d7, d9\n" |
1795 | |
1796 | // Then pairwise accumulate in to q8, q9 |
1797 | "vpadal.s16 q8, q14\n" |
1798 | "vpadal.s16 q9, q15\n" |
1799 | |
1800 | // All accumulation over depth done. q6 - q9 contain the 4x32b |
1801 | // accumulators for the 4x1 final matrix. |
1802 | // We now have to compute the final 8-bit values from these int32 |
1803 | // accumulators, and advance to the next 4x2 block. We intertwine |
1804 | // these two aspects whenever possible for optimal pipelining, both |
1805 | // at the data flow level (prefetch data for next block as early as |
1806 | // possible) and instruction pipelining level (some of the next-block |
1807 | // work can dual-issue with some of the final work on the current |
1808 | // block). |
1809 | |
1810 | // q6-q9 now contain 4 x 32b |
1811 | "vpadd.i32 d0, d12, d13\n" |
1812 | "vpadd.i32 d1, d14, d15\n" |
1813 | "vpadd.i32 d2, d16, d17\n" |
1814 | "vpadd.i32 d3, d18, d19\n" |
1815 | |
1816 | // d0-d4 each contain 2 x 32b accumulators. |
1817 | // Need to add pairwise to get 1 x 32b for each of the 4x1 entries |
1818 | // of destination, (Four 'd' registers total) |
1819 | "vpadd.i32 d28, d0, d1\n" |
1820 | "vpadd.i32 d29, d2, d3\n" |
1821 | |
1822 | // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. |
1823 | |
1824 | // Logic to advance to the next block in preparation for the next |
1825 | // iteration of the main loop. For now, we only want to compute |
1826 | // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are |
1827 | // not yet ready to update the values of row and col, as we still need |
1828 | // the current values for the rest of the work on the current block. |
1829 | |
1830 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
1831 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1832 | "cmp r1, r3\n" // Have we finished the last row? |
1833 | |
1834 | "bge 4f\n" // If finished last row, go to 4 |
1835 | // Not finished last row: then advance to next row. |
1836 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" |
1837 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
1838 | "add r4, r4, r1, lsl #2\n" |
1839 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
1840 | "b 5f\n" |
1841 | "4:\n" // Finished last row... |
1842 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
1843 | // Go back to first row |
1844 | "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
1845 | |
1846 | // Now we need to advance to the next column. If we already |
1847 | // finished the last column, then in principle we are done, however |
1848 | // we can't just return here, as we need to allow the end work of the |
1849 | // current block to complete. The good news is that at this point it |
1850 | // doesn't matter what data we load for the next column, since |
1851 | // we will exit from the main loop below before actually storing |
1852 | // anything computed from that data. |
1853 | |
1854 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
1855 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1856 | "cmp r8, r4\n" // Have we finished the last column? |
1857 | "bge 5f\n" // If yes, just carry on without updating the column pointer. |
1858 | // Not finished last column: then advance to next column. |
1859 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" |
1860 | "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
1861 | "add r10, r10, r1, lsl #1\n" |
1862 | "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
1863 | "5:\n" |
1864 | |
1865 | // Set the LHS and RHS data pointers to the start of the columns just |
1866 | // computed. |
1867 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" |
1868 | "mov %[lhs_ptr], r4\n" |
1869 | "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" |
1870 | "mov %[rhs_ptr], r5\n" |
1871 | |
1872 | // Now we load: bias data, LHS sums data, RHS sums data. |
1873 | |
1874 | // First, load the base pointers from the params. |
1875 | "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
1876 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" |
1877 | |
1878 | // Offset these base pointers as needed given the current row, col. |
1879 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1880 | |
1881 | "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" |
1882 | "beq 1000f\n" |
1883 | "add r1, r1, r8, lsl #2\n" |
1884 | "1000:\n" |
1885 | |
1886 | // Load 4 bias values. |
1887 | "vld1.32 {d24, d25}, [r1]\n" |
1888 | |
1889 | // Now that we know what LHS and RHS data the next iteration of the
1890 | // main loop will need to load, we start loading the first bytes of
1891 | // each of LHS and RHS, into d0 -- d9, as we don't need those registers
1892 | // anymore in the rest of the work on the current block.
1893 | "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" |
1894 | "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" |
1895 | "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" |
1896 | "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" |
1897 | RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n" ) |
1898 | "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" |
1899 | // Skip the other column and advance the pointer. |
1900 | "add %[rhs_ptr], %[rhs_ptr], #16\n" |
1901 | RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n" ) |
1902 | |
1903 | // Add to the bias values the product |
1904 | // (depth * lhs_zero_point * rhs_zero_point), |
1905 | // See the term NZ1Z2 in equation (7) in |
1906 | // https://arxiv.org/pdf/1712.05877.pdf |
1907 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" |
1908 | "vdup.32 q9, r3\n" |
1909 | "vadd.i32 q12, q12, q9\n" |
1910 | |
1911 | // Perform the bias-addition (per the above, we have just folded into |
1912 | // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) |
1913 | "vadd.i32 q14, q14, q12\n" |
1914 | |
1915 | // LHS/RHS zero points |
1916 | // Has RHS sums |
1917 | "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
1918 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" |
1919 | "beq 401f\n" |
1920 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" |
1921 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
1922 | // Offset by current col * number of bytes per value |
1923 | "add r3, r3, r4, lsl #2\n" |
1924 | "vld1.32 { d12 }, [r3]\n" |
1925 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" |
1926 | "vdup.32 q10, r5\n" // create lhs_zero_point_vec |
1927 | // Subtract rhs_sums * lhs_zero_point, per |
1928 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
1929 | "vmls.i32 q14, q10, d12[0]\n" |
1930 | "401:\n" |
1931 | |
1932 | // Has LHS sums |
1933 | "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" |
1934 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" |
1935 | "beq 402f\n" |
1936 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" |
1937 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1938 | // Offset by current row * number of bytes per value |
1939 | "add r2, r2, r4, lsl #2\n" |
1940 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" |
1941 | |
1942 | // Load 4 lhs_sums values. |
1943 | "vld1.32 {d22, d23}, [r2]\n" |
1944 | "vdup.32 d13, r5\n" // rhs_zero_point |
1945 | |
1946 | // Compute lhs_sums * rhs_zero_point. |
1947 | "vmul.i32 q11, q11, d13[1]\n" |
1948 | // Subtract lhs_sums * rhs_zero_point, per |
1949 | // equation (7) in https://arxiv.org/pdf/1712.05877.pdf |
1950 | "vsub.s32 q14, q14, q11\n" |
1951 | |
1952 | // If the destination is int32, it means the user asks for the raw |
1953 | // accumulators, no need for us to downquantize the value. |
1954 | "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
1955 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" |
1956 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" |
1957 | |
1958 | "402:\n" |
1959 | |
1960 | // At this point we have computed the final int32 values. Now we |
1961 | // start down-quantizing them to obtain the final 8bit values from them. |
1962 | |
1963 | // As part of this down-quantization, our int32 values will be |
1964 | // multiplied by a multiplier that has a fixed-point component and an |
1965 | // exponent component. |
1966 | |
1967 | // Load the exponent part of the multiplier.
1968 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" |
1969 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
1970 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
1971 | "beq 1001f\n" |
1972 | "add r1, r1, r4, lsl #2\n" |
1973 | "1001:\n" |
1974 | |
1975 | "vld1.32 {q10}, [r1]\n" |
1976 | |
1977 | "vmvn.i32 q8, #0\n" |
1978 | "vmin.s32 q13, q10, q8\n" |
1979 | "vsub.s32 q12, q10, q13\n" |
1980 | |
1981 | // Apply the positive exponent part of the multiplier. |
1982 | "vshl.s32 q14, q14, q12\n" |
1983 | |
1984 | // Load fixed point part of the multiplier |
1985 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" |
1986 | // r6 has flags, r4 has row |
1987 | "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" |
1988 | "beq 1002f\n" |
1989 | "add r1, r1, r4, lsl #2\n" |
1990 | "1002:\n" |
1991 | "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint |
1992 | |
1993 | // Apply the fixed-point part of the multiplier. |
1994 | "vqdmulh.s32 q14, q14, q10\n" |
1995 | |
1996 | // Apply the negative exponent part of the multiplier. |
1997 | "vrshl.s32 q14, q14, q13\n" |
1998 | |
1999 | "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" |
2000 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" |
2001 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" |
2002 | "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" |
2003 | "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" |
2004 | |
2005 | // Store uint8 values: |
2006 | RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" |
2007 | |
2008 | // Cast-and-saturate from int32 to int16 |
2009 | // After this, all values for output are in d28. |
2010 | "vqmovn.s32 d28, q14\n" |
2011 | |
2012 | // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the |
2013 | // current block, so we can start clearing these accumulators for the |
2014 | // next block (next iteration of the main loop). |
2015 | RUY_MAKE_ZERO(q6) |
2016 | RUY_MAKE_ZERO(q7) |
2017 | RUY_MAKE_ZERO(q8) |
2018 | RUY_MAKE_ZERO(q9) |
2019 | RUY_MAKE_ZERO(q10) |
2020 | RUY_MAKE_ZERO(q11) |
2021 | RUY_MAKE_ZERO(q12) |
2022 | RUY_MAKE_ZERO(q13) |
2023 | RUY_MAKE_ZERO(q15) |
2024 | |
2025 | // Load the destination zero point into each of the 8 16-bit slots |
2026 | // in a q register. |
2027 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
2028 | "vdup.16 q13, r4\n" // dst_zero_point |
2029 | |
2030 | // Add the destination zero point |
2031 | "vqadd.s16 q14, q14, q13\n" |
2032 | |
2033 | // Cast-and-saturate from int16 to uint8 |
2034 | "vqmovun.s16 d30, q14\n" |
2035 | // At this point, we only need 4 8-bit values in the lower half |
2036 | // of d30. |
2037 | |
2038 | |
2039 | // Load the clamp_min, clamp_max bounds |
2040 | "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2041 | "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2042 | "vdup.8 d28, r2\n" // clamp_min |
2043 | "vdup.8 d29, r3\n" // clamp_max |
2044 | |
2045 | // Apply the clamp_min bound |
2046 | "vmax.u8 d30, d30, d28\n" |
2047 | // Apply the clamp_max bound |
2048 | "vmin.u8 d30, d30, d29\n" |
2049 | |
2050 | // Compute how much of the 4x1 block of destination 8bit values that |
2051 | // we have computed, fit in the destination matrix. Typically, all of |
2052 | // it fits, but when the destination matrix shape is not a multiple |
2053 | // of 4x1, there are some 4x1 blocks along the boundaries that do |
2054 | // not fit entirely. |
2055 | |
2056 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
2057 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2058 | "sub r1, r1, r8\n" |
2059 | |
2060 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
2061 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2062 | "sub r2, r2, r4\n" |
2063 | "mov r3, #4\n" |
2064 | "mov r5, #2\n" |
2065 | "cmp r1, #4\n" |
2066 | // Compute r1 = how many rows of the 4x1 block fit |
2067 | "it gt\n" |
2068 | "movgt r1, r3\n" |
2069 | |
2070 | // Test if r1==4, i.e. if all of the 4x1 block fits. |
2071 | "cmp r1, r3\n" |
2072 | |
2073 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2074 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2075 | // Yes, all of the 4x1 block fits, go to fast path. |
2076 | "beq 30f\n" |
2077 | // Not all of the 4x1 block fits. |
2078 | // Store to dst_tmp_buf |
2079 | // Set r3 address to write to dst_tmp_buf. |
2080 | "mov r3, %[dst_tmp_buf]\n" |
2081 | "vst1.8 {d30}, [r3]\n" |
2082 | |
2083 | // Slow loop copying from dst_tmp_buf to dst. |
2084 | "50:\n" |
2085 | "mov r8, #0\n" |
2086 | "51:\n" |
2087 | "ldrb r10, [r3, r8]\n" |
2088 | "strb r10, [r4, r8]\n" |
2089 | "add r8, r8, #1\n" |
2090 | "cmp r8, r1\n" |
2091 | "blt 51b\n" |
2092 | "b 31f\n" |
2093 | "30:\n" |
2094 | // Yes, all of the 4x1 block fits. |
2095 | // r3 address, r5 stride |
2096 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2097 | "mov r4, r3\n" |
2098 | "mov r6, #1\n" |
2099 | |
2100 | "vst1.8 {d30[0]}, [r3], r6\n" |
2101 | "vst1.8 {d30[1]}, [r3], r6\n" |
2102 | "vst1.8 {d30[2]}, [r3], r6\n" |
2103 | "vst1.8 {d30[3]}, [r3], r6\n" |
2104 | "31:\n" |
2105 | |
2106 | // Load dst_ptr, increment, and write back. |
2107 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2108 | "add r4, r4, #4\n" |
2109 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2110 | |
2111 | RUY_MAKE_ZERO(q13) |
2112 | RUY_MAKE_ZERO(q14) |
2113 | RUY_MAKE_ZERO(q15) |
2114 | |
2115 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2116 | |
2117 | // Store int8 values: |
2118 | RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" |
2119 | |
2120 | // Cast-and-saturate from int32 to int16 |
2121 | // After this, all values for output are in d28. |
2122 | "vqmovn.s32 d28, q14\n" |
2123 | |
2124 | // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the |
2125 | // current block, so we can start clearing these accumulators for the |
2126 | // next block (next iteration of the main loop). |
2127 | RUY_MAKE_ZERO(q6) |
2128 | RUY_MAKE_ZERO(q7) |
2129 | RUY_MAKE_ZERO(q8) |
2130 | RUY_MAKE_ZERO(q9) |
2131 | RUY_MAKE_ZERO(q10) |
2132 | RUY_MAKE_ZERO(q11) |
2133 | RUY_MAKE_ZERO(q12) |
2134 | RUY_MAKE_ZERO(q13) |
2135 | RUY_MAKE_ZERO(q15) |
2136 | |
2137 | // Load the destination zero point into each of the 8 16-bit slots |
2138 | // in a q register. |
2139 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
2140 | "vdup.16 q13, r4\n" // dst_zero_point |
2141 | |
2142 | // Add the destination zero point |
2143 | "vqadd.s16 q14, q14, q13\n" |
2144 | |
2145 | // Cast-and-saturate from int16 to int8 |
2146 | "vqmovn.s16 d30, q14\n" |
2147 | // At this point, we only need 4 8-bit values in the lower half |
2148 | // of d30. |
2149 | |
2150 | // Load the clamp_min, clamp_max bounds |
2151 | "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2152 | "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2153 | "vdup.8 d28, r2\n" // clamp_min |
2154 | "vdup.8 d29, r3\n" // clamp_max |
2155 | |
2156 | // Apply the clamp_min bound |
2157 | "vmax.s8 d30, d30, d28\n" |
2158 | // Apply the clamp_max bound |
2159 | "vmin.s8 d30, d30, d29\n" |
2160 | |
2161 | // Compute how much of the 4x1 block of destination 8bit values that |
2162 | // we have computed, fit in the destination matrix. Typically, all of |
2163 | // it fits, but when the destination matrix shape is not a multiple |
2164 | // of 4x1, there are some 4x1 blocks along the boundaries that do
2165 | // not fit entirely. |
2166 | |
2167 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
2168 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2169 | "sub r1, r1, r8\n" |
2170 | |
2171 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
2172 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2173 | "sub r2, r2, r4\n" |
2174 | "mov r3, #4\n" |
2175 | "mov r5, #2\n" |
2176 | "cmp r1, #4\n" |
2177 | // Compute r1 = how many rows of the 4x2 block fit |
2178 | "it gt\n" |
2179 | "movgt r1, r3\n" |
2180 | |
2181 | // Test if r1==4 i.e. if all of the 4x1 block fits. |
2182 | "cmp r1, r3\n" |
2183 | |
2184 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2185 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2186 | // Yes, all of the 4x2 block fits, go to fast path. |
2187 | "beq 30f\n" |
2188 | // Not all of the 4x2 block fits. |
2189 | // Store to dst_tmp_buf |
2190 | // Set r3 address to write to dst_tmp_buf. |
2191 | "mov r3, %[dst_tmp_buf]\n" |
2192 | "vst1.8 {d30}, [r3]\n" |
2193 | |
2194 | // Slow loop copying from dst_tmp_buf to dst. |
2195 | "50:\n" |
2196 | "mov r8, #0\n" |
2197 | "51:\n" |
2198 | "ldrb r10, [r3, r8]\n" |
2199 | "strb r10, [r4, r8]\n" |
2200 | "add r8, r8, #1\n" |
2201 | "cmp r8, r1\n" |
2202 | "blt 51b\n" |
2203 | "b 31f\n" |
2204 | "30:\n" |
2205 | // Yes, all of the 4x1 block fits. |
2206 | // r3 address, r5 stride |
2207 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2208 | "mov r4, r3\n" |
2209 | "mov r6, #1\n" |
2210 | |
2211 | "vst1.8 {d30[0]}, [r3], r6\n" |
2212 | "vst1.8 {d30[1]}, [r3], r6\n" |
2213 | "vst1.8 {d30[2]}, [r3], r6\n" |
2214 | "vst1.8 {d30[3]}, [r3], r6\n" |
2215 | "31:\n" |
2216 | |
2217 | // Load dst_ptr, increment, and write back. |
2218 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2219 | "add r4, r4, #4\n" |
2220 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2221 | |
2222 | RUY_MAKE_ZERO(q13) |
2223 | RUY_MAKE_ZERO(q14) |
2224 | RUY_MAKE_ZERO(q15) |
2225 | |
2226 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2227 | |
2228 | RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" |
2229 | |
2230 | // Load the destination zero point into each of the 4 32-bit slots |
2231 | // in a q register. |
2232 | "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" |
2233 | "vdup.32 q13, r4\n" // dst_zero_point |
2234 | // Add the destination zero point |
2235 | "vadd.s32 q14, q14, q13\n" |
2236 | //"vadd.s32 q15, q15, q13\n" |
2237 | |
2238 | // Cast-and-saturate from int32 to int16 |
2239 | // After this, all values for output are in d28. |
2240 | "vqmovn.s32 d28, q14\n" |
2241 | |
2242 | // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the |
2243 | // so we can start clearing these accumulators for the next block |
2244 | // (next iteration of the main loop). |
2245 | RUY_MAKE_ZERO(q6) |
2246 | RUY_MAKE_ZERO(q7) |
2247 | RUY_MAKE_ZERO(q8) |
2248 | RUY_MAKE_ZERO(q9) |
2249 | RUY_MAKE_ZERO(q10) |
2250 | RUY_MAKE_ZERO(q11) |
2251 | RUY_MAKE_ZERO(q15) |
2252 | |
2253 | // Load the clamp_min, clamp_max bounds |
2254 | "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" |
2255 | "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" |
2256 | "vdup.16 d24, r2\n" // clamp_min |
2257 | "vdup.16 d26, r3\n" // clamp_max |
2258 | |
2259 | // Apply the clamp_min bound |
2260 | "vmax.s16 d28, d28, d24\n" |
2261 | // Apply the clamp_max bound |
2262 | "vmin.s16 d28, d28, d26\n" |
2263 | |
2264 | RUY_MAKE_ZERO(q12) |
2265 | RUY_MAKE_ZERO(q13) |
2266 | |
2267 | // Compute how much of the 4x1 block of destination 16-bit values that |
2268 | // we have computed, fit in the destination matrix. Typically, all of |
2269 | // it fits, but when the destination matrix shape is not a multiple |
2270 | // of 4x1, there are some 4x1 blocks along the boundaries that do |
2271 | // not fit entirely. |
2272 | |
2273 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
2274 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2275 | "sub r1, r1, r8\n" |
2276 | |
2277 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
2278 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2279 | "sub r2, r2, r4\n" |
2280 | "mov r3, #4\n" |
2281 | "mov r5, #2\n" |
2282 | "cmp r1, #4\n" |
2283 | // Compute r1 = how many rows of the 4x1 block fit |
2284 | "it gt\n" |
2285 | "movgt r1, r3\n" |
2286 | |
2287 | // Test if r1==4, i.e. if all of the 4x1 block fits. |
2288 | "cmp r1, r3\n" |
2289 | |
2290 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2291 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2292 | // Yes, all of the 4x1 block fits, go to fast path. |
2293 | "beq 30f\n" |
2294 | // Not all of the 4x1 block fits. |
2295 | // Store to dst_tmp_buf |
2296 | // Set r3 address to write to dst_tmp_buf. |
2297 | "mov r3, %[dst_tmp_buf]\n" |
2298 | "vst1.16 {d28}, [r3]\n" |
2299 | |
2300 | // Slow loop copying from dst_tmp_buf to dst. |
2301 | "50:\n" |
2302 | "mov r8, #0\n" |
2303 | "51:\n" |
2304 | // Shift of offset register for half-word loads not allowed in A32, |
2305 | // so we shift, load/store, then shift back r8. |
2306 | "lsl r8, r8, #1\n" |
2307 | "ldrh r10, [r3, r8]\n" |
2308 | "strh r10, [r4, r8]\n" |
2309 | "lsr r8, r8, #1\n" |
2310 | "add r8, r8, #1\n" |
2311 | "cmp r8, r1\n" |
2312 | "blt 51b\n" |
2313 | "b 31f\n" |
2314 | "30:\n" |
2315 | // Yes, all of the 4x1 block fits. |
2316 | // r3 address, r5 stride |
2317 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2318 | "mov r4, r3\n" |
2319 | "mov r6, #2\n" |
2320 | |
2321 | "vst1.16 {d28[0]}, [r3], r6\n" |
2322 | "vst1.16 {d28[1]}, [r3], r6\n" |
2323 | "vst1.16 {d28[2]}, [r3], r6\n" |
2324 | "vst1.16 {d28[3]}, [r3], r6\n" |
2325 | "31:\n" |
2326 | |
2327 | // Load dst_ptr, increment, and write back. |
2328 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2329 | "add r4, r4, #8\n" |
2330 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2331 | |
2332 | RUY_MAKE_ZERO(q14) |
2333 | |
2334 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2335 | |
2336 | RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" |
2337 | |
2338 | // Since the store type is the same as the accum type, no need for |
2339 | // downcast. There's also no need for clamp by min/max. |
2340 | |
2341 | // At this point, the accumulators q6 -- q13 aren't used anymore for the
2342 | // current block, so we can start clearing them for the next block
2343 | // (next iteration of the main loop).
2344 | // Clear accumulators. |
2345 | RUY_MAKE_ZERO(q6) |
2346 | RUY_MAKE_ZERO(q7) |
2347 | RUY_MAKE_ZERO(q8) |
2348 | RUY_MAKE_ZERO(q9) |
2349 | RUY_MAKE_ZERO(q10) |
2350 | RUY_MAKE_ZERO(q11) |
2351 | RUY_MAKE_ZERO(q12) |
2352 | RUY_MAKE_ZERO(q13) |
2353 | |
2354 | // Compute how much of the 4x1 block of destination 32 bit values that |
2355 | // we have computed, fit in the destination matrix. Typically, all of |
2356 | // it fits, but when the destination matrix shape is not a multiple |
2357 | // of 4x1, there are some 4x1 blocks along the boundaries that do
2358 | // not fit entirely. |
2359 | |
2360 | "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" |
2361 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2362 | "sub r1, r1, r8\n" |
2363 | |
2364 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" |
2365 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2366 | "sub r2, r2, r4\n" |
2367 | "mov r3, #4\n" |
2368 | "mov r5, #2\n" |
2369 | "cmp r1, #4\n" |
2370 | // Compute r1 = how many rows of the 4x2 block fit |
2371 | "it gt\n" |
2372 | "movgt r1, r3\n" |
2373 | |
2374 | // Test if r1==4, i.e. if all of the 4x1 block fits. |
2375 | "cmp r1, r3\n" |
2376 | |
2377 | // Yes, all of the 4x1 block fits, go to fast path. |
2378 | "beq 30f\n" |
2379 | // Not all of the 4x1 block fits. |
2380 | // Set (r3 address, r4 stride) to write to dst_tmp_buf |
2381 | "mov r3, %[dst_tmp_buf]\n" |
2382 | "mov r4, #16\n" |
2383 | "b 31f\n" |
2384 | |
2385 | "30:\n" |
2386 | // Yes, all of the 4x1 block fits. |
2387 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2388 | // r3 address, r4 stride |
2389 | "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2390 | "mov r4, r5\n" |
2391 | |
2392 | "31:\n" |
2393 | |
2394 | "vst1.32 {d28, d29}, [r3]\n" |
2395 | |
2396 | // If all of the 4x1 block fits, we just finished writing it to the |
2397 | // destination, so we skip the next part. |
2398 | "beq 41f\n" |
2399 | // Not all of the 4x1 block fits in the destination matrix. We just |
2400 | // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over |
2401 | // it to copy into the destination matrix the part that fits. |
2402 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2403 | "mov r3, %[dst_tmp_buf]\n" |
2404 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2405 | "50:\n" |
2406 | "mov r5, #0\n" |
2407 | "51:\n" |
2408 | "ldr r10, [r3, r5, lsl #2]\n" |
2409 | "str r10, [r4, r5, lsl #2]\n" |
2410 | "add r5, r5, #1\n" |
2411 | "cmp r5, r1\n" |
2412 | "blt 51b\n" |
2413 | |
2414 | "41:\n" |
2415 | // Load dst_ptr, increment, and write back. |
2416 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2417 | "add r4, r4, #16\n" |
2418 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2419 | |
2420 | RUY_MAKE_ZERO(q10) |
2421 | RUY_MAKE_ZERO(q11) |
2422 | |
2423 | "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" |
2424 | |
2425 | RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" |
2426 | |
2427 | // Reload some params --- we had used x5 -- x7 for a few other things |
2428 | // since the last time we had loaded them. |
2429 | "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" |
2430 | "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" |
2431 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" |
2432 | |
2433 | // Move to the next block of the destination matrix, for the next iter |
2434 | // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already |
2435 | // been updated earlier. |
2436 | // Have we reached the end row? |
2437 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2438 | "cmp r8, r3\n" |
2439 | |
2440 | "beq 20f\n" // yes, end row. |
2441 | // Not end row. Move to the next row. |
2442 | "add r8, r8, #4\n" |
2443 | // Store new value of row |
2444 | "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2445 | |
2446 | "b 21f\n" |
2447 | "20:\n" |
2448 | // Was already at end row. |
2449 | // Move back to first row. |
2450 | "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" |
2451 | // Move to the next column. |
2452 | "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2453 | "add r4, r4, #2\n" |
2454 | "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2455 | |
2456 | "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" |
2457 | "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
2458 | // Increment dst_col_ptr by dst_stride (i.e. 1 column) |
2459 | "add r1, r1, r8\n" |
2460 | // Store dst_col_ptr |
2461 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" |
2462 | // Store dst_ptr |
2463 | "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" |
2464 | "21:\n" |
2465 | |
2466 | // Main loop exit condition: have we hit the end column? |
2467 | "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" |
2468 | "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" |
2469 | "cmp r8, r4\n" |
2470 | |
2471 | // w1 is the number of levels of depth that we have already loaded |
2472 | // LHS and RHS data for. Corresponding to the initial ld1 instructions |
2473 | // above, this is currently 16. |
2474 | "mov r1, #16\n" |
2475 | |
2476 | "ble 1b\n" |
2477 | |
2478 | // Restore stack pointer. |
2479 | "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" |
2480 | |
2481 | // clang-format on |
2482 | |
2483 | : [lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr) |
2484 | : [ params ] "r" (¶ms), [dst_tmp_buf] "r" (params.dst_tmp_buf) |
2485 | : "r0" , "r1" , "r2" , "r3" , "r4" , "r5" , "r6" , "r8" , "r10" , "cc" , |
2486 | // Clobber list must specify q registers (and not their constituent |
2487 | // d registers). There is a (currently unexplained) slowdown if |
2488 | // d registers are listed in the clobbers list. |
2489 | "memory" , "q0" , "q1" , "q2" , "q3" , "q4" , "q5" , "q6" , "q7" , "q8" , |
2490 | "q9" , "q10" , "q12" , "q13" , "q14" , "q15" ); |
2491 | } |
2492 | |
2493 | #undef RUY_OFFSET_BIAS |
2494 | #undef RUY_OFFSET_LHS_SUMS |
2495 | #undef RUY_OFFSET_RHS_SUMS |
2496 | #undef RUY_OFFSET_LHS_BASE_PTR |
2497 | #undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT |
2498 | #undef RUY_OFFSET_MULTIPLIER_EXPONENT |
2499 | #undef RUY_OFFSET_RHS_BASE_PTR |
2500 | #undef RUY_OFFSET_DST_BASE_PTR |
2501 | #undef RUY_OFFSET_LHS_ZERO_POINT |
2502 | #undef RUY_OFFSET_RHS_ZERO_POINT |
2503 | #undef RUY_OFFSET_DST_ZERO_POINT |
2504 | #undef RUY_OFFSET_PROD_ZP_DEPTH |
2505 | #undef RUY_OFFSET_START_ROW |
2506 | #undef RUY_OFFSET_START_COL |
2507 | #undef RUY_OFFSET_LAST_ROW |
2508 | #undef RUY_OFFSET_LAST_COL |
2509 | #undef RUY_OFFSET_DST_ROWS |
2510 | #undef RUY_OFFSET_DST_COLS |
2511 | #undef RUY_OFFSET_LHS_STRIDE |
2512 | #undef RUY_OFFSET_RHS_STRIDE |
2513 | #undef RUY_OFFSET_DST_STRIDE |
2514 | #undef RUY_OFFSET_DEPTH |
2515 | #undef RUY_OFFSET_CLAMP_MIN |
2516 | #undef RUY_OFFSET_CLAMP_MAX |
2517 | #undef RUY_OFFSET_FLAGS |
2518 | #undef RUY_OFFSET_DST_TYPE_ID |
2519 | |
2520 | #undef RUY_STACK_OFFSET_SIZE |
2521 | #undef RUY_STACK_OFFSET_DST_COL_PTR |
2522 | #undef RUY_STACK_OFFSET_DST_PTR |
2523 | #undef RUY_STACK_OFFSET_ROW |
2524 | #undef RUY_STACK_OFFSET_COL |
2525 | #undef RUY_STACK_OFFSET_LHS_COL_PTR |
2526 | #undef RUY_STACK_OFFSET_RHS_COL_PTR |
2527 | |
#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
2529 | } // namespace ruy |
2530 | |