1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "ruy/kernel_arm.h"
17#include "ruy/opt_set.h"
18#include "ruy/platform.h"
19#include "ruy/profiler/instrumentation.h"
20
21namespace ruy {
22
23#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
24
25#define RUY_ASM_LABEL_STORE_UINT8 91
26#define RUY_ASM_LABEL_STORE_INT8 92
27#define RUY_ASM_LABEL_STORE_INT16 93
28#define RUY_ASM_LABEL_STORE_INT32 94
29#define RUY_ASM_LABEL_AFTER_STORE 99
30
31#define RUY_OFFSET_LHS_BASE_PTR 0
32#define RUY_OFFSET_RHS_BASE_PTR 4
33#define RUY_OFFSET_DST_BASE_PTR 8
34#define RUY_OFFSET_BIAS 12
35#define RUY_OFFSET_START_ROW 16
36#define RUY_OFFSET_START_COL 20
37#define RUY_OFFSET_LAST_ROW 24
38#define RUY_OFFSET_LAST_COL 28
39#define RUY_OFFSET_DST_ROWS 32
40#define RUY_OFFSET_DST_COLS 36
41#define RUY_OFFSET_LHS_STRIDE 40
42#define RUY_OFFSET_RHS_STRIDE 44
43#define RUY_OFFSET_DST_STRIDE 48
44#define RUY_OFFSET_DEPTH 52
45#define RUY_OFFSET_CLAMP_MIN 56
46#define RUY_OFFSET_CLAMP_MAX 60
47#define RUY_OFFSET_FLAGS 64
48
49#define RUY_STACK_OFFSET_SIZE 96
50#define RUY_STACK_OFFSET_DST_COL_PTR 0
51#define RUY_STACK_OFFSET_DST_PTR 16
52#define RUY_STACK_OFFSET_ROW 32
53#define RUY_STACK_OFFSET_COL 48
54#define RUY_STACK_OFFSET_LHS_COL_PTR 64
55#define RUY_STACK_OFFSET_RHS_COL_PTR 80
56
57template <typename Params>
58void CheckOffsetsInKernelParamsFloat32(const Params&) {
59 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
60 static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, "");
61 static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, "");
62 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
63 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
64 static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, "");
65 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
66 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
67 static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, "");
68 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
69 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
70 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
71 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
72 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
73 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
74 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
75}
76
77// Float kernel for ARM32 out-of-order cores.
78// Just like Float 64 version, except accumulate in to 8x4 block to only
79// use 16 128-bit NEON registers. This is a "first pass" kernel and not
80// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9.
81void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params) {
82 CheckOffsetsInKernelParamsFloat32(params);
83 profiler::ScopeLabel label("Kernel (kNeon)");
84
85 const float* lhs_ptr = params.lhs_base_ptr;
86 const float* rhs_ptr = params.rhs_base_ptr;
87 // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are
88 // each composed of two 64-bit "d" registers. The asm kernel below has the
89 // following NEON register allocation:
90 // Registers q3 -- q10 are accumulators. During accumulation,
91 // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. q0 and q1
92 // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block
93 // of RHS, like this:
94
95 // Register layout in "q" registers:
96 // RHS 1x4 block
97 // /--------------------------|
98 // |q2.s[0] ... q2.s[3] |
99 // \--------------------------/
100 // LHS 8x1 block
101 // /---------------------\ /--------------------------|
102 // | q0.s[0] | | q3.s[0] ... q9.s[0] |
103 // | ... | | ... ... |
104 // | q0.s[3] | | q3.s[3] q9.s[3] |
105 // | q1.s[0] | | q4.s[0] q10.s[0] |
106 // | ... | | ... ... ... |
107 // | q1.s[3] | | q4.s[3] .. q10.s[3] |
108 // \---------------------/ \--------------------------/
109 // accumulators 8x4 block
110 // q11, q14, q15 currently unused. q12 and q13 are used to load
111 // parameters used for the post-accumulation part of the kernel.
112 // For completeness, here is the register layout in "d" registers:
113 // RHS 1x4 block
114 // /--------------------------|
115 // |d4[0] ... d5[1] |
116 // \--------------------------/
117 // LHS 8x1 block
118 // /---------------------\ /--------------------------|
119 // | d0[0] | | d6[0] ... d18[0] |
120 // | ... | | ... ... |
121 // | d1[1] | | d7[1] d19[1] |
122 // | d2[0] | | d8[0] d20[0] |
123 // | ... | | ... ... ... |
124 // | d3[1] | | d9[1] ... d21[1] |
125 // \---------------------/ \--------------------------/
126 // accumulators 8x4 block
127 asm volatile(
128#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n"
129
130 // clang-format off
131
132 // Load the first 32 bytes of LHS and RHS data.
133 // Load q0, q1
134 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n"
135 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n"
136 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
137 // Load q2
138 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
139 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
140
141 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
142
143 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
144 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
145
146 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
147 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
148
149 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
150 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
151
152 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
153 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
154
155 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
156 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
157
158 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
159 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
160 // Clear accumulators.
161 RUY_MAKE_ZERO(q3)
162 RUY_MAKE_ZERO(q4)
163 RUY_MAKE_ZERO(q5)
164 RUY_MAKE_ZERO(q6)
165 RUY_MAKE_ZERO(q7)
166 RUY_MAKE_ZERO(q8)
167 RUY_MAKE_ZERO(q9)
168 RUY_MAKE_ZERO(q10)
169
170 // r1 is the number of levels of depth that we have already loaded
171 // LHS and RHS data for. Corresponding to the initial ld1 instructions
172 // above, this is currently 1.
173 "mov r1, #1\n"
174
175 // Main loop of the whole GEMM, over rows and columns of the
176 // destination matrix.
177 "1:\n"
178
179 // Accumulation loop
180 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
181 "cmp r1, r2\n"
182 "beq 79f\n"
183
184 "2:\n"
185
186 "vmla.f32 q3, q0, d4[0]\n"
187 "vmla.f32 q5, q0, d4[1]\n"
188 "vmla.f32 q7, q0, d5[0]\n"
189 "vmla.f32 q9, q0, d5[1]\n"
190 "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS
191
192 "vmla.f32 q4, q1, d4[0]\n"
193 "vmla.f32 q6, q1, d4[1]\n"
194 "vmla.f32 q8, q1, d5[0]\n"
195 "vmla.f32 q10, q1, d5[1]\n"
196 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
197 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
198 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS
199 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
200
201 "add r1, r1, #1\n"
202 "cmp r1, r2\n"
203
204 "blt 2b\n"
205
206 "79:\n"
207
208 // End of the inner loop on depth. Now perform the remaining
209 // multiply-adds of the last level of depth, for which the LHS
210 // and RHS data is already loaded.
211
212 "vmla.f32 q3, q0, d4[0]\n"
213 "vmla.f32 q5, q0, d4[1]\n"
214 "vmla.f32 q7, q0, d5[0]\n"
215 "vmla.f32 q9, q0, d5[1]\n"
216
217 "vmla.f32 q4, q1, d4[0]\n"
218 "vmla.f32 q6, q1, d4[1]\n"
219 "vmla.f32 q8, q1, d5[0]\n"
220 "vmla.f32 q10, q1, d5[1]\n"
221
222 // End of accumulation. The registers q3 -- q10 contain the final
223 // float32 accumulator values of the current 8x8 destination block.
224 // We now have to compute the final values from these accumulators
225 // and advance to the next 8x8 block. We intertwine
226 // these two aspects whenever possible for optimal pipelining, both
227 // at the data flow level (prefetch data for next block as early as
228 // possible) and instruction pipelining level (some of the next-block
229 // work can dual-issue with some of the final work on the current
230 // block).
231
232 // Logic to advance to the next block in preparation for the next
233 // iteration of the main loop. For now, we only want to compute
234 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
235 // not yet ready to update the values of row and col, as we still need
236 // the current values for the rest of the work on the current block.
237
238 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
239 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
240 "cmp r1, r3\n" // Have we finished the last row?
241
242 "bge 4f\n" // If finished last row, go to 4
243 // Not finished last row: then advance to next row.
244 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
245 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
246 "add r4, r4, r1, lsl #3\n"
247 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
248 "b 5f\n"
249 "4:\n" // Finished last row...
250 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
251 // Go back to first row
252 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
253 // Now we need to advance to the next column. If we already
254 // finished the last column, then in principle we are done, however
255 // we can't just return here, as we need to allow the end work of the
256 // current block to complete. The good news is that at this point it
257 // doesn't matter what data we load for the next column, since
258 // we will exit from the main loop below before actually storing
259 // anything computed from that data.
260 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
261 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
262 "cmp r8, r4\n" // Have we finished the last column?
263 "bge 5f\n" // If yes, just carry on without updating the column pointer.
264 // Not finished last column: then advance to next column.
265 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
266 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
267 "add r10, r10, r1, lsl #2\n"
268 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
269 "5:\n"
270
271 // Set the LHS and RHS data pointers to the start of the columns just
272 // computed.
273 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
274 "mov %[lhs_ptr], r4\n"
275 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
276 "mov %[rhs_ptr], r5\n"
277
278 // Load some parameters needed for the end work on current block.
279 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
280 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
281
282 // Let r8 be stack offset of the row or column variable, whichever
283 // is the channel index.
284 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
285 "bne 1000f\n"
286 "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
287 "b 1001f\n"
288 "1000:\n"
289 "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
290 "1001:\n"
291 // Let r8 be the channel index.
292 "ldr r8, [sp, r8]\n"
293 // Compute the bias pointer, by conditionally using the channel index
294 // (r8) as offset into bias buffer (r1).
295 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
296 "beq 1002f\n"
297 "add r1, r1, r8, lsl #2\n"
298 "1002:\n"
299
300 // Load 4 bias values. When the channel dimension is rows, we will load
301 // another 4 bias values just before performing the bias addition below,
302 // as this kernel has a 8x4 rectangular shape.
303 "vld1.32 {d24, d25}, [r1]!\n"
304
305 // Now that we know what LHS and RHS data the next iteration of the
306 // main loop will need to load, we start loading the first 32 bytes of
307 // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore
308 // in the rest of the work on the current block.
309 // Load q0, q1
310 "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
311 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
312 // Load q2
313 "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n"
314 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
315
316 // Perform the bias-addition.
317 // Jump based on channel dimension.
318 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
319 "bne 6f\n"
320 // Case where channels are rows.
321 // Load the remaining 4 bias values, since we're on the width-8 side
322 // of this 8x4 kernel.
323 "vld1.32 {d26, d27}, [r1]\n"
324 "vadd.f32 q3, q3, q12\n"
325 "vadd.f32 q5, q5, q12\n"
326 "vadd.f32 q7, q7, q12\n"
327 "vadd.f32 q9, q9, q12\n"
328 "vadd.f32 q4, q4, q13\n"
329 "vadd.f32 q6, q6, q13\n"
330 "vadd.f32 q8, q8, q13\n"
331 "vadd.f32 q10, q10, q13\n"
332 "b 7f\n"
333
334 "6:\n"
335 // Case where channels are columns.
336 "vdup.32 q11, d24[0]\n"
337 "vdup.32 q13, d24[1]\n"
338 "vdup.32 q14, d25[0]\n"
339 "vdup.32 q15, d25[1]\n"
340 "vadd.f32 q3, q3, q11\n"
341 "vadd.f32 q4, q4, q11\n"
342 "vadd.f32 q5, q5, q13\n"
343 "vadd.f32 q6, q6, q13\n"
344 "vadd.f32 q7, q7, q14\n"
345 "vadd.f32 q8, q8, q14\n"
346 "vadd.f32 q9, q9, q15\n"
347 "vadd.f32 q10, q10, q15\n"
348 "7:\n"
349
350 // Load the clamp_min, clamp_max bounds
351 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
352 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
353 "vdup.32 q12, r2\n" // clamp_min
354 "vdup.32 q13, r3\n" // clamp_max
355
356 // Apply the clamp_min bound
357 "vmax.f32 q3, q3, q12\n"
358 "vmax.f32 q4, q4, q12\n"
359 "vmax.f32 q5, q5, q12\n"
360 "vmax.f32 q6, q6, q12\n"
361 "vmax.f32 q7, q7, q12\n"
362 "vmax.f32 q8, q8, q12\n"
363 "vmax.f32 q9, q9, q12\n"
364 "vmax.f32 q10, q10, q12\n"
365
366 // Apply the clamp_max bound
367 "vmin.f32 q3, q3, q13\n"
368 "vmin.f32 q4, q4, q13\n"
369 "vmin.f32 q5, q5, q13\n"
370 "vmin.f32 q6, q6, q13\n"
371 "vmin.f32 q7, q7, q13\n"
372 "vmin.f32 q8, q8, q13\n"
373 "vmin.f32 q9, q9, q13\n"
374 "vmin.f32 q10, q10, q13\n"
375
376 // Compute how much of the 8x4 block of destination values that
377 // we have computed, fit in the destination matrix. Typically, all of
378 // it fits, but when the destination matrix shape is not a multiple
379 // of 8x4, there are some 8x8 blocks along the boundaries that do
380 // not fit entirely.
381 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
382 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
383 "sub r1, r1, r8\n"
384
385 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
386 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
387 "sub r2, r2, r4\n"
388 "mov r3, #8\n"
389 "mov r5, #4\n"
390 "cmp r1, #8\n"
391 // Compute r1 = how many rows of the 8x4 block fit
392 "it gt\n"
393 "movgt r1, r3\n"
394 "cmp r2, #4\n"
395 // Compute r2 = how many cols of the 8x4 block fit
396 "it gt\n"
397 "movgt r2, r5\n"
398
399 // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits.
400 "cmp r1, r3\n"
401 "it eq\n"
402 "cmpeq r2, r5\n"
403 // Yes, all of the 8x4 block fits, go to fast path.
404 "beq 30f\n"
405 // Not all of the 8x4 block fits.
406 // Set (r3 address, r4 stride) to write to dst_tmp_buf
407 "mov r3, %[dst_tmp_buf]\n"
408 "mov r4, #32\n"
409 "b 31f\n"
410 "30:\n"
411 // Yes, all of the 8x4 block fits.
412 // Set (r3 address, r4 stride) to write directly to destination matrix.
413 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
414 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
415 "mov r4, r5\n"
416 "31:\n"
417
418 // Write our float values to the destination described by
419 // (r3 address, r4 stride)
420 "vst1.32 {d6, d7, d8, d9}, [r3]\n"
421 "add r3, r3, r4\n"
422 RUY_MAKE_ZERO(q3)
423 RUY_MAKE_ZERO(q4)
424 "vst1.32 {d10, d11, d12, d13}, [r3]\n"
425 "add r3, r3, r4\n"
426 RUY_MAKE_ZERO(q5)
427 RUY_MAKE_ZERO(q6)
428 "vst1.32 {d14, d15, d16, d17}, [r3]\n"
429 "add r3, r3, r4\n"
430 RUY_MAKE_ZERO(q7)
431 RUY_MAKE_ZERO(q8)
432 "vst1.32 {d18, d19, d20, d21}, [r3]\n"
433 "add r3, r3, r4\n"
434 RUY_MAKE_ZERO(q9)
435 RUY_MAKE_ZERO(q10)
436
437 // If all of the 8x4 block fits, we just finished writing it to the
438 // destination, so we skip the next part.
439 "beq 41f\n"
440 // Not all of the 8x8 block fits in the destination matrix. We just
441 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
442 // it to copy into the destination matrix the part that fits.
443 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
444 "mov r3, %[dst_tmp_buf]\n"
445 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
446 "mov r6, #0\n"
447 "50:\n"
448 "mov r5, #0\n"
449 "51:\n"
450 "ldr r10, [r3, r5, lsl #2]\n"
451 "str r10, [r4, r5, lsl #2]\n"
452 "add r5, r5, #1\n"
453 "cmp r5, r1\n"
454 "blt 51b\n"
455 "add r6, r6, #1\n"
456 "add r3, r3, #32\n"
457 "add r4, r4, r8\n"
458 // r2 = how many cols of the 8x4 block fit
459 "cmp r6, r2\n"
460 "blt 50b\n"
461 "41:\n"
462 // Load dst_ptr, increment, and write back.
463 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
464 "add r4, r4, #32\n"
465 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
466 // At this point we have completely finished writing values to the
467 // destination matrix for the current block.
468
469 // Reload some params --- we had used r3, r5, r10 for a few other things
470 // since the last time we had loaded them.
471 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
472 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
473 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
474
475 // Move to the next block of the destination matrix, for the next iter
476 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
477 // been updated earlier.
478 // Have we reached the end row?
479 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
480 "cmp r8, r3\n"
481
482 "beq 20f\n" // yes, end row.
483 // Not end row. Move to the next row.
484 "add r8, r8, #8\n"
485 // Store new value of row
486 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
487
488 "b 21f\n"
489 "20:\n"
490 // Was already at end row.
491 // Move back to first row.
492 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
493 // Move to the next column.
494 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
495 "add r4, r4, #4\n"
496 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
497
498 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
499 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
500 // Increment dst_col_ptr by 4 * dst_stride (i.e. 4 columns)
501 "add r1, r1, r8, lsl #2\n"
502 // Store dst_col_ptr
503 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
504 // Store dst_ptr
505 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
506 "21:\n"
507
508 // Main loop exit condition: have we hit the end column?
509 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
510 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
511 "cmp r8, r4\n"
512
513 // r1 is the number of levels of depth that we have already loaded
514 // LHS and RHS data for. Corresponding to the initial ld1 instructions
515 // above, this is currently 1.
516 "mov r1, #1\n"
517
518 "ble 1b\n"
519
520 // Restore stack pointer.
521 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
522
523 // clang-format on
524 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
525 : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
526 // Clobber list must specify q registers (and not their constituent
527 // d registers). There is a (currently unexplained) slowdown if
528 // d registers are listed in the clobbers list.
529 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
530 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
531 "q9", "q10", "q12", "q13");
532}
533
534#undef RUY_MAKE_ZERO
535#undef RUY_STACK_OFFSET_SIZE
536#undef RUY_STACK_OFFSET_DST_COL_PTR
537#undef RUY_STACK_OFFSET_DST_PTR
538#undef RUY_STACK_OFFSET_ROW
539#undef RUY_STACK_OFFSET_COL
540#undef RUY_STACK_OFFSET_LHS_COL_PTR
541#undef RUY_STACK_OFFSET_RHS_COL_PTR
542
543#undef RUY_OFFSET_LHS_BASE_PTR
544#undef RUY_OFFSET_RHS_BASE_PTR
545#undef RUY_OFFSET_DST_BASE_PTR
546#undef RUY_OFFSET_BIAS
547#undef RUY_OFFSET_START_ROW
548#undef RUY_OFFSET_START_COL
549#undef RUY_OFFSET_LAST_ROW
550#undef RUY_OFFSET_LAST_COL
551#undef RUY_OFFSET_DST_ROWS
552#undef RUY_OFFSET_DST_COLS
553#undef RUY_OFFSET_LHS_STRIDE
554#undef RUY_OFFSET_RHS_STRIDE
555#undef RUY_OFFSET_DST_STRIDE
556#undef RUY_OFFSET_DEPTH
557#undef RUY_OFFSET_CLAMP_MIN
558#undef RUY_OFFSET_CLAMP_MAX
559#undef RUY_OFFSET_FLAGS
560
561#define RUY_OFFSET_BIAS 0
562#define RUY_OFFSET_LHS_SUMS 4
563#define RUY_OFFSET_RHS_SUMS 8
564#define RUY_OFFSET_LHS_BASE_PTR 12
565#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16
566#define RUY_OFFSET_MULTIPLIER_EXPONENT 20
567#define RUY_OFFSET_RHS_BASE_PTR 24
568#define RUY_OFFSET_DST_BASE_PTR 28
569#define RUY_OFFSET_LHS_ZERO_POINT 32
570#define RUY_OFFSET_RHS_ZERO_POINT 36
571#define RUY_OFFSET_DST_ZERO_POINT 40
572#define RUY_OFFSET_PROD_ZP_DEPTH 44
573#define RUY_OFFSET_START_ROW 48
574#define RUY_OFFSET_START_COL 52
575#define RUY_OFFSET_LAST_ROW 56
576#define RUY_OFFSET_LAST_COL 60
577#define RUY_OFFSET_DST_ROWS 64
578#define RUY_OFFSET_DST_COLS 68
579#define RUY_OFFSET_LHS_STRIDE 72
580#define RUY_OFFSET_RHS_STRIDE 76
581#define RUY_OFFSET_DST_STRIDE 80
582#define RUY_OFFSET_DEPTH 84
583#define RUY_OFFSET_CLAMP_MIN 88
584#define RUY_OFFSET_CLAMP_MAX 92
585#define RUY_OFFSET_FLAGS 96
586#define RUY_OFFSET_DST_TYPE_ID 97
587
588#define RUY_STACK_OFFSET_SIZE 96
589#define RUY_STACK_OFFSET_DST_COL_PTR 0
590#define RUY_STACK_OFFSET_DST_PTR 16
591#define RUY_STACK_OFFSET_ROW 32
592#define RUY_STACK_OFFSET_COL 48
593#define RUY_STACK_OFFSET_LHS_COL_PTR 64
594#define RUY_STACK_OFFSET_RHS_COL_PTR 80
595
596template <typename Params>
597void CheckOffsetsInKernelParams8bit(const Params&) {
598 static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
599 "");
600 static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
601 "");
602 static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
603 "");
604 static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
605 "");
606 static_assert(offsetof(Params, multiplier_fixedpoint) ==
607 RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
608 "");
609 static_assert(
610 offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
611 "");
612 static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
613 static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
614 static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
615 static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
616 static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
617 static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
618 static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
619 static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
620 static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
621 static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
622 static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
623 static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
624 static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
625 static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
626}
627
628// Fast-int8 kernel, ported from ARM 64 version.
629// Relevant target CPUs for this kernel include Krait 400 and A9,
630// since these are 32-bit, out-of-order CPUs.
631void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
632 profiler::ScopeLabel label("Kernel (kNeon)");
633
634 CheckOffsetsInKernelParams8bit(params);
635
636 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
637 const std::int8_t* rhs_col_ptr =
638 static_cast<const int8_t*>(params.rhs_base_ptr);
639 const std::int8_t* lhs_ptr = lhs_col_ptr;
640 const std::int8_t* rhs_ptr = rhs_col_ptr;
641
642 // The asm kernel below has the following NEON register allocation:
643 //
644 // q6 - q13 are 128-bit (4x32b) accumulators.
645 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
646 // d8 -- d11 from RHS:
647 // int8 RHS 16x2 block
648 // /-----------------------------|
649 // |d8.b[0-7] ..... d10.b[0-7]|
650 // | ... ... |
651 // |d9.b[0-7] ..... d11.b[0-7]|
652 // \-----------------------------/
653 // int8 LHS 4x16 block
654 // /------------------------\ /-----------------------------|
655 // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 |
656 // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 |
657 // (Reload d0, d1, d2, d3)
658 // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 |
659 // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 |
660 // \------------------------/ \-----------------------------/
661 // 128-bit accumulators 4x2 block
662 //
663 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
664 // optimization for this kernel.
665 asm volatile(
666#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
667
668 // clang-format off
669
670 // Load the first 64 bytes of LHS and RHS data.
671 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
672 // Clear accumulators.
673 RUY_MAKE_ZERO(q6)
674 RUY_MAKE_ZERO(q7)
675 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
676 RUY_MAKE_ZERO(q8)
677 RUY_MAKE_ZERO(q9)
678 RUY_MAKE_ZERO(q10)
679 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
680 RUY_MAKE_ZERO(q11)
681 "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n"
682
683 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
684
685 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
686 RUY_MAKE_ZERO(q12)
687 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
688
689 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
690 RUY_MAKE_ZERO(q13)
691 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
692
693 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
694 RUY_MAKE_ZERO(q14)
695 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
696
697 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
698 RUY_MAKE_ZERO(q15)
699 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
700
701 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
702 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
703
704 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
705 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
706
707
708 // r1 is the number of levels of depth that we have already loaded
709 // LHS and RHS data for. Corresponding to the initial ld1 instructions
710 // above, this is currently 16.
711 "mov r1, #16\n"
712
713 // Main loop of the whole GEMM, over rows and columns of the
714 // destination matrix.
715 "1:\n"
716
717 // r1 is how many levels of depth we have already loaded
718 // data for, r10 is the total depth.
719 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
720 "cmp r1, r10\n"
721 "beq 79f\n"
722
723 "2:\n"
724
725 // Mult, mult-acc in to q14, q15, q2, q3
726 "vmull.s8 q14, d0, d8\n"
727 "vmull.s8 q2, d0, d10\n"
728
729 "vmull.s8 q15, d2, d8\n"
730 "vmull.s8 q3, d2, d10\n"
731
732 "vmlal.s8 q14, d1, d9\n"
733 "vmlal.s8 q2, d1, d11\n"
734 "vmlal.s8 q15, d3, d9\n"
735 "vmlal.s8 q3, d3, d11\n"
736 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
737
738 // Then pairwise accumulate in to q6, q7, q10, q11
739 "vpadal.s16 q6, q14\n"
740 "vpadal.s16 q7, q15\n"
741 "vpadal.s16 q10, q2\n"
742 "vpadal.s16 q11, q3\n"
743
744 // Mult, mult-acc in to q14, q15, q2, q3
745 "vmull.s8 q14, d0, d8\n"
746 "vmull.s8 q2, d0, d10\n"
747
748 "vmull.s8 q15, d2, d8\n"
749 "vmull.s8 q3, d2, d10\n"
750
751 "vmlal.s8 q14, d1, d9\n"
752 "vmlal.s8 q2, d1, d11\n"
753 "vmlal.s8 q15, d3, d9\n"
754 "vmlal.s8 q3, d3, d11\n"
755 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
756
757 // Then pairwise accumulate in to q8, q9, q12, q13
758 "vpadal.s16 q8, q14\n"
759 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
760 "vpadal.s16 q9, q15\n"
761 "vpadal.s16 q12, q2\n"
762 "vpadal.s16 q13, q3\n"
763
764 // Prefetch the next 64 bytes of LHS and RHS data.
765 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
766 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
767
768 // Each iteration of this loop advances by 16 levels of depth.
769 "add r1, r1, #16\n"
770
771 // Loop termination condition
772 "cmp r1, r10\n"
773
774 "blt 2b\n"
775
776 "79:\n"
777
778 // Mult, mult-acc in to q14, q15, q2, q3
779 "vmull.s8 q14, d0, d8\n"
780 "vmull.s8 q2, d0, d10\n"
781
782 "vmull.s8 q15, d2, d8\n"
783 "vmull.s8 q3, d2, d10\n"
784
785 "vmlal.s8 q14, d1, d9\n"
786 "vmlal.s8 q2, d1, d11\n"
787 "vmlal.s8 q15, d3, d9\n"
788 "vmlal.s8 q3, d3, d11\n"
789 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS
790
791 // Then pairwise accumulate in to q6, q7, q10, q11
792 "vpadal.s16 q6, q14\n"
793 "vpadal.s16 q7, q15\n"
794 "vpadal.s16 q10, q2\n"
795 "vpadal.s16 q11, q3\n"
796
797 // Mult, mult-acc in to q14, q15, q2, q3
798 "vmull.s8 q14, d0, d8\n"
799 "vmull.s8 q2, d0, d10\n"
800
801 "vmull.s8 q15, d2, d8\n"
802 "vmull.s8 q3, d2, d10\n"
803
804 "vmlal.s8 q14, d1, d9\n"
805 "vmlal.s8 q2, d1, d11\n"
806 "vmlal.s8 q15, d3, d9\n"
807 "vmlal.s8 q3, d3, d11\n"
808
809 // Then pairwise accumulate in to q8, q9, q12, q13
810 "vpadal.s16 q8, q14\n"
811 "vpadal.s16 q9, q15\n"
812 "vpadal.s16 q12, q2\n"
813 "vpadal.s16 q13, q3\n"
814
815
816 // All accumulation over depth done. q6 - q13 contain the 4x32b
817 // accumulators for the 4x2 final matrix.
818 // We now have to compute the final 8-bit values from these int32
819 // accumulators, and advance to the next 4x2 block. We intertwine
820 // these two aspects whenever possible for optimal pipelining, both
821 // at the data flow level (prefetch data for next block as early as
822 // possible) and instruction pipelining level (some of the next-block
823 // work can dual-issue with some of the final work on the current
824 // block).
825
826 // q6-q13 now contain 4 x 32b
827 "vpadd.i32 d0, d12, d13\n"
828 "vpadd.i32 d1, d14, d15\n"
829 "vpadd.i32 d2, d16, d17\n"
830 "vpadd.i32 d3, d18, d19\n"
831 "vpadd.i32 d4, d20, d21\n"
832 "vpadd.i32 d5, d22, d23\n"
833 "vpadd.i32 d6, d24, d25\n"
834 "vpadd.i32 d7, d26, d27\n"
835
836 // d0-d7 each contain 2 x 32b accumulators.
837 // Need to add pairwise to get 1 x 32b for each of the 4x2 entries
838 // of destination, (Four 'd' registers total)
839 "vpadd.i32 d28, d0, d1\n"
840 "vpadd.i32 d29, d2, d3\n"
841 "vpadd.i32 d30, d4, d5\n"
842 "vpadd.i32 d31, d6, d7\n"
843
844 //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries
845
846 // Logic to advance to the next block in preparation for the next
847 // iteration of the main loop. For now, we only want to compute
848 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
849 // not yet ready to update the values of row and col, as we still need
850 // the current values for the rest of the work on the current block.
851
852 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
853 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
854 "cmp r1, r3\n" // Have we finished the last row?
855
856 "bge 4f\n" // If finished last row, go to 4
857 // Not finished last row: then advance to next row.
858 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
859 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
860 "add r4, r4, r1, lsl #2\n"
861 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
862 "b 5f\n"
863 "4:\n" // Finished last row...
864 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
865 // Go back to first row
866 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
867
868 // Now we need to advance to the next column. If we already
869 // finished the last column, then in principle we are done, however
870 // we can't just return here, as we need to allow the end work of the
871 // current block to complete. The good news is that at this point it
872 // doesn't matter what data we load for the next column, since
873 // we will exit from the main loop below before actually storing
874 // anything computed from that data.
875
876 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
877 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
878 "cmp r8, r4\n" // Have we finished the last column?
879 "bge 5f\n" // If yes, just carry on without updating the column pointer.
880 // Not finished last column: then advance to next column.
881 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
882 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
883 "add r10, r10, r1, lsl #1\n"
884 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
885 "5:\n"
886
887 // Set the LHS and RHS data pointers to the start of the columns just
888 // computed.
889 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
890 "mov %[lhs_ptr], r4\n"
891 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
892 "mov %[rhs_ptr], r5\n"
893
894 // Now we load: bias data, LHS sums data, RHS sums data.
895
896 // First, load the base pointers from the params.
897 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
898 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
899
900 // Let r8 be stack offset of the row or column variable, whichever
901 // is the channel index.
902 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
903 "bne 1000f\n"
904 "mov r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
905 "b 1001f\n"
906 "1000:\n"
907 "mov r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
908 "1001:\n"
909
910 // Let r8 be the channel index.
911 "ldr r8, [sp, r8]\n"
912 // Compute the bias pointer, by conditionally using the channel index
913 // (r8) as offset into bias buffer (r1).
914 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
915 "beq 1002f\n"
916 "add r1, r1, r8, lsl #2\n"
917 "1002:\n"
918
919 // Load 2 bias values. When the channel dimension is rows, we will load
920 // another 2 bias values just before performing the bias addition below,
921 // as this kernel has a 4x2 rectangular shape.
922 "vld1.32 {d24}, [r1]!\n"
923
924 // Now that we know what LHS and RHS data the next iteration of the
925 // main loop will need to load, we start loading the first 32 bytes of
926 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
927 // in the rest of the work on the current block.
928 "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n"
929 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
930 "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n"
931 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
932
933 // Add to the bias values the product
934 // (depth * lhs_zero_point * rhs_zero_point),
935 // See the term NZ1Z2 in equation (7) in
936 // https://arxiv.org/pdf/1712.05877.pdf
937 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
938 "vdup.32 q9, r3\n"
939 "vadd.i32 d24, d24, d18\n"
940
941 // Perform the bias-addition (per the above, we have just folded into
942 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
943 // Jump based on channel dimension.
944 "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
945 "bne 6f\n"
946 // Case where channels are rows.
947 // Load the remaining 2 bias values, since we're on the width-4 side
948 // of this 4x2 kernel.
949 "vld1.32 {d25}, [r1]\n"
950 "vadd.i32 d25, d25, d19\n"
951 "vadd.i32 q14, q14, q12\n"
952 "vadd.i32 q15, q15, q12\n"
953 "b 7f\n"
954
955 "6:\n"
956 // Case where channels are columns.
957 "vdup.32 q10, d24[0]\n"
958 "vdup.32 q11, d24[1]\n"
959 "vadd.i32 q14, q14, q10\n"
960 "vadd.i32 q15, q15, q11\n"
961 "7:\n"
962
963 // LHS/RHS zero points
964 // Has RHS sums
965 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
966 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
967 "beq 401f\n"
968 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
969 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
970 // Offset by current col * number of bytes per value
971 "add r3, r3, r4, lsl #2\n"
972 "vld1.32 { d12 }, [r3]\n"
973 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
974 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
975 // Subtract rhs_sums * lhs_zero_point, per
976 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
977 "vmls.i32 q14, q10, d12[0]\n"
978 "vmls.i32 q15, q10, d12[1]\n"
979 "401:\n"
980
981 // Has LHS sums
982 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
983 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
984 "beq 402f\n"
985 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
986 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
987 // Offset by current row * number of bytes per value
988 "add r2, r2, r4, lsl #2\n"
989 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
990
991 // Load 4 lhs_sums values.
992 "vld1.32 {d22, d23}, [r2]\n"
993 "vdup.32 d13, r5\n" // rhs_zero_point
994
995 // Compute lhs_sums * rhs_zero_point.
996 "vmul.i32 q11, q11, d13[1]\n"
997 // Subtract lhs_sums * rhs_zero_point, per
998 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
999 "vsub.s32 q14, q14, q11\n"
1000 "vsub.s32 q15, q15, q11\n"
1001
1002 // If the destination is int32, it means the user asks for the raw
1003 // accumulators, no need for us to downquantize the value.
1004 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1005 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1006 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1007
1008 "402:\n"
1009
1010 // At this point we have computed the final int32 values. Now we
1011 // start down-quantizing them to obtain the final 8bit values from them.
1012
1013 // As part of this down-quantization, our int32 values will be
1014 // multiplied by a multiplier that has a fixed-point component and an
1015 // exponent component.
1016
1017 // Compute the data pointers for the multiplier data
1018 // r1 = exponent part
1019 // r2 = fixedpoint part
1020 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1021 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1022 // r6 has flags, r8 has channel index
1023 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1024 "beq 1003f\n"
1025 "add r1, r1, r8, lsl #2\n"
1026 "add r2, r2, r8, lsl #2\n"
1027 "1003:\n"
1028
1029 // Load the first 2 values of multiplier exponent and fixedpoint data
1030 // Since this kernel is rectangular 4x2, we will only conditionally load
1031 // 2 more values below.
1032 "vld1.32 {d20}, [r1]!\n" // 2 values of multiplier_exponent
1033 "vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint
1034
1035 "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
1036 "vmvn.i32 q8, #0\n"
1037 "bne 8f\n"
1038 // Case where channels are rows.
1039 // Load the remaining 2 bias values, since we're on the width-4 side
1040 // of this 4x2 kernel.
1041 "vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent
1042 "vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint
1043 "vmin.s32 q11, q10, q8\n"
1044 "vsub.s32 q10, q10, q11\n"
1045
1046 // Apply the positive exponent part of the multiplier.
1047 "vshl.s32 q14, q14, q10\n"
1048 "vshl.s32 q15, q15, q10\n"
1049
1050 // Apply the fixed-point part of the multiplier.
1051 "vqdmulh.s32 q14, q14, q6\n"
1052 "vqdmulh.s32 q15, q15, q6\n"
1053
1054 // Apply the negative exponent part of the multiplier.
1055 "vrshl.s32 q14, q14, q11\n"
1056 "vrshl.s32 q15, q15, q11\n"
1057 "b 9f\n"
1058
1059 "8:\n"
1060 // Case where channels are columns.
1061 "vmin.s32 d22, d20, d16\n"
1062 "vsub.s32 d20, d20, d22\n"
1063
1064 // Apply the positive exponent part of the multiplier.
1065 "vdup.32 q12, d20[0]\n"
1066 "vdup.32 q13, d20[1]\n"
1067 "vshl.s32 q14, q14, q12\n"
1068 "vshl.s32 q15, q15, q13\n"
1069
1070 // Apply the fixed-point part of the multiplier.
1071 "vqdmulh.s32 q14, q14, d12[0]\n"
1072 "vqdmulh.s32 q15, q15, d12[1]\n"
1073
1074 // Apply the negative exponent part of the multiplier.
1075 "vdup.32 q12, d22[0]\n"
1076 "vdup.32 q13, d22[1]\n"
1077 "vrshl.s32 q14, q14, q12\n"
1078 "vrshl.s32 q15, q15, q13\n"
1079
1080 "9:\n"
1081
1082 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1083 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
1084 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
1085 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
1086 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
1087
1088 // Store uint8 values:
1089 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
1090
1091 // Cast-and-saturate from int32 to int16
1092 // After this, all values for output are in q14.
1093 "vqmovn.s32 d28, q14\n"
1094 "vqmovn.s32 d29, q15\n"
1095
1096 // At this point, d12 -- d26, d30, d31 aren't used anymore for the
1097 // current block, so we can start clearing these accumulators for the
1098 // next block (next iteration of the main loop).
1099 RUY_MAKE_ZERO(q6)
1100 RUY_MAKE_ZERO(q7)
1101 RUY_MAKE_ZERO(q8)
1102 RUY_MAKE_ZERO(q9)
1103 RUY_MAKE_ZERO(q10)
1104 RUY_MAKE_ZERO(q11)
1105 RUY_MAKE_ZERO(q12)
1106 RUY_MAKE_ZERO(q13)
1107 RUY_MAKE_ZERO(q15)
1108
1109 // Load the destination zero point into each of the 8 16-bit slots
1110 // in a q register.
1111 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1112 "vdup.16 q13, r4\n" // dst_zero_point
1113
1114 // Add the destination zero point
1115 "vqadd.s16 q14, q14, q13\n"
1116
1117 // Cast-and-saturate from int16 to uint8
1118 // Now all 8 1-byte values are in d30.
1119 "vqmovun.s16 d30, q14\n"
1120
1121 // Load the clamp_min, clamp_max bounds
1122 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1123 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1124 "vdup.8 d28, r2\n" // clamp_min
1125 "vdup.8 d29, r3\n" // clamp_max
1126
1127 // Apply the clamp_min bound
1128 "vmax.u8 d30, d30, d28\n"
1129 // Apply the clamp_max bound
1130 "vmin.u8 d30, d30, d29\n"
1131
1132 // Compute how much of the 4x2 block of destination 8bit values that
1133 // we have computed, fit in the destination matrix. Typically, all of
1134 // it fits, but when the destination matrix shape is not a multiple
1135 // of 4x2, there are some 4x2 blocks along the boundaries that do
1136 // not fit entirely.
1137
1138 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1139 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1140 "sub r1, r1, r8\n"
1141
1142 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1143 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1144 "sub r2, r2, r4\n"
1145 "mov r3, #4\n"
1146 "mov r5, #2\n"
1147 "cmp r1, #4\n"
1148 // Compute r1 = how many rows of the 4x2 block fit
1149 "it gt\n"
1150 "movgt r1, r3\n"
1151
1152 "cmp r2, #2\n"
1153 // Compute r2 = how many cols of the 4x2 block fit
1154 "it gt\n"
1155 "movgt r2, r5\n"
1156
1157 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1158 "cmp r1, r3\n"
1159 "it eq\n"
1160 "cmpeq r2, r5\n"
1161 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1162 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1163 // Yes, all of the 4x2 block fits, go to fast path.
1164 "beq 30f\n"
1165 // Not all of the 4x2 block fits.
1166 // Store to dst_tmp_buf
1167 // Set r3 address to write to dst_tmp_buf.
1168 "mov r3, %[dst_tmp_buf]\n"
1169 "vst1.8 {d30}, [r3]\n"
1170
1171 // Slow loop copying from dst_tmp_buf to dst.
1172 "mov r6, #0\n"
1173 "50:\n"
1174 "mov r8, #0\n"
1175 "51:\n"
1176 "ldrb r10, [r3, r8]\n"
1177 "strb r10, [r4, r8]\n"
1178 "add r8, r8, #1\n"
1179 "cmp r8, r1\n"
1180 "blt 51b\n"
1181 "add r6, r6, #1\n"
1182 "add r3, r3, #4\n"
1183 "add r4, r4, r5\n"
1184 "cmp r6, r2\n"
1185 "blt 50b\n"
1186 "b 31f\n"
1187 "30:\n"
1188 // Yes, all of the 4x2 block fits.
1189 // r3 address, r5 stride
1190 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1191 "mov r4, r3\n"
1192 "mov r6, #1\n"
1193
1194 "vst1.32 {d30[0]}, [r3]\n"
1195 "add r4, r4, r5\n"
1196 "mov r3, r4\n"
1197 "vst1.32 {d30[1]}, [r3]\n"
1198
1199 "31:\n"
1200
1201 // Load dst_ptr, increment, and write back.
1202 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1203 "add r4, r4, #4\n"
1204 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1205
1206 RUY_MAKE_ZERO(q13)
1207 RUY_MAKE_ZERO(q14)
1208 RUY_MAKE_ZERO(q15)
1209
1210 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1211
1212 // Store int8 values:
1213 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
1214
1215 // Cast-and-saturate from int32 to int16
1216 // After this, all values for output are in q14.
1217 "vqmovn.s32 d28, q14\n"
1218 "vqmovn.s32 d29, q15\n"
1219
1220 // At this point, d12 -- d26, d30, d31 aren't used anymore for the
1221 // current block, so we can start clearing these accumulators for the
1222 // next block (next iteration of the main loop).
1223 RUY_MAKE_ZERO(q6)
1224 RUY_MAKE_ZERO(q7)
1225 RUY_MAKE_ZERO(q8)
1226 RUY_MAKE_ZERO(q9)
1227 RUY_MAKE_ZERO(q10)
1228 RUY_MAKE_ZERO(q11)
1229 RUY_MAKE_ZERO(q12)
1230 RUY_MAKE_ZERO(q13)
1231 RUY_MAKE_ZERO(q15)
1232
1233 // Load the destination zero point into each of the 8 16-bit slots
1234 // in a q register.
1235 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1236 "vdup.16 q13, r4\n" // dst_zero_point
1237
1238 // Add the destination zero point
1239 "vqadd.s16 q14, q14, q13\n"
1240
1241 // Cast-and-saturate from int16 to int8
1242 // Now all 8 1-byte values are in d30.
1243 "vqmovn.s16 d30, q14\n"
1244
1245 // Load the clamp_min, clamp_max bounds
1246 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1247 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1248 "vdup.8 d28, r2\n" // clamp_min
1249 "vdup.8 d29, r3\n" // clamp_max
1250
1251 // Apply the clamp_min bound
1252 "vmax.s8 d30, d30, d28\n"
1253 // Apply the clamp_max bound
1254 "vmin.s8 d30, d30, d29\n"
1255
1256 // Compute how much of the 4x2 block of destination 8bit values that
1257 // we have computed, fit in the destination matrix. Typically, all of
1258 // it fits, but when the destination matrix shape is not a multiple
1259 // of 4x2, there are some 4x2 blocks along the boundaries that do
1260 // not fit entirely.
1261
1262 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1263 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1264 "sub r1, r1, r8\n"
1265
1266 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1267 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1268 "sub r2, r2, r4\n"
1269 "mov r3, #4\n"
1270 "mov r5, #2\n"
1271 "cmp r1, #4\n"
1272 // Compute r1 = how many rows of the 4x2 block fit
1273 "it gt\n"
1274 "movgt r1, r3\n"
1275
1276 "cmp r2, #2\n"
1277 // Compute r2 = how many cols of the 4x2 block fit
1278 "it gt\n"
1279 "movgt r2, r5\n"
1280
1281 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1282 "cmp r1, r3\n"
1283 "it eq\n"
1284 "cmpeq r2, r5\n"
1285 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1286 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1287 // Yes, all of the 4x2 block fits, go to fast path.
1288 "beq 30f\n"
1289 // Not all of the 4x2 block fits.
1290 // Store to dst_tmp_buf
1291 // Set r3 address to write to dst_tmp_buf.
1292 "mov r3, %[dst_tmp_buf]\n"
1293 "vst1.8 {d30}, [r3]\n"
1294
1295 // Slow loop copying from dst_tmp_buf to dst.
1296 "mov r6, #0\n"
1297 "50:\n"
1298 "mov r8, #0\n"
1299 "51:\n"
1300 "ldrb r10, [r3, r8]\n"
1301 "strb r10, [r4, r8]\n"
1302 "add r8, r8, #1\n"
1303 "cmp r8, r1\n"
1304 "blt 51b\n"
1305 "add r6, r6, #1\n"
1306 "add r3, r3, #4\n"
1307 "add r4, r4, r5\n"
1308 "cmp r6, r2\n"
1309 "blt 50b\n"
1310 "b 31f\n"
1311 "30:\n"
1312 // Yes, all of the 4x2 block fits.
1313 // r3 address, r5 stride
1314 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1315 "mov r4, r3\n"
1316 "mov r6, #1\n"
1317
1318 "vst1.32 {d30[0]}, [r3]\n"
1319 "add r4, r4, r5\n"
1320 "mov r3, r4\n"
1321 "vst1.32 {d30[1]}, [r3]\n"
1322
1323 "31:\n"
1324
1325 // Load dst_ptr, increment, and write back.
1326 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1327 "add r4, r4, #4\n"
1328 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1329
1330 RUY_MAKE_ZERO(q13)
1331 RUY_MAKE_ZERO(q14)
1332 RUY_MAKE_ZERO(q15)
1333
1334 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1335
1336 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
1337
1338 // Load the destination zero point into each of the 4 32-bit slots
1339 // in a q register.
1340 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
1341 "vdup.32 q13, r4\n" // dst_zero_point
1342 // Add the destination zero point
1343 "vadd.s32 q14, q14, q13\n"
1344 "vadd.s32 q15, q15, q13\n"
1345
1346 // Cast-and-saturate from int32 to int16
1347 // After this, all values for output are in q14.
1348 "vqmovn.s32 d28, q14\n"
1349 "vqmovn.s32 d29, q15\n"
1350
1351 // At this point, v18 -- v31 aren't used anymore for the current block,
1352 // so we can start clearing these accumulators for the next block
1353 // (next iteration of the main loop).
1354 RUY_MAKE_ZERO(q6)
1355 RUY_MAKE_ZERO(q7)
1356 RUY_MAKE_ZERO(q8)
1357 RUY_MAKE_ZERO(q9)
1358 RUY_MAKE_ZERO(q10)
1359 RUY_MAKE_ZERO(q11)
1360 RUY_MAKE_ZERO(q15)
1361
1362 // Load the clamp_min, clamp_max bounds
1363 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
1364 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
1365 "vdup.16 q12, r2\n" // clamp_min
1366 "vdup.16 q13, r3\n" // clamp_max
1367
1368 // Apply the clamp_min bound
1369 "vmax.s16 q14, q14, q12\n"
1370 // Apply the clamp_max bound
1371 "vmin.s16 q14, q14, q13\n"
1372
1373 RUY_MAKE_ZERO(q12)
1374 RUY_MAKE_ZERO(q13)
1375
1376 // Compute how much of the 4x2 block of destination 16-bit values that
1377 // we have computed, fit in the destination matrix. Typically, all of
1378 // it fits, but when the destination matrix shape is not a multiple
1379 // of 4x2, there are some 4x2 blocks along the boundaries that do
1380 // not fit entirely.
1381
1382 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1383 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1384 "sub r1, r1, r8\n"
1385
1386 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1387 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1388 "sub r2, r2, r4\n"
1389 "mov r3, #4\n"
1390 "mov r5, #2\n"
1391 "cmp r1, #4\n"
1392 // Compute r1 = how many rows of the 4x2 block fit
1393 "it gt\n"
1394 "movgt r1, r3\n"
1395
1396 "cmp r2, #2\n"
1397 // Compute r2 = how many cols of the 4x2 block fit
1398 "it gt\n"
1399 "movgt r2, r5\n"
1400
1401 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1402 "cmp r1, r3\n"
1403 "it eq\n"
1404 "cmpeq r2, r5\n"
1405 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1406 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1407 // Yes, all of the 4x2 block fits, go to fast path.
1408 "beq 30f\n"
1409 // Not all of the 4x2 block fits.
1410 // Store to dst_tmp_buf
1411 // Set r3 address to write to dst_tmp_buf.
1412 "mov r3, %[dst_tmp_buf]\n"
1413 "vst1.16 {q14}, [r3]\n"
1414
1415 // Slow loop copying from dst_tmp_buf to dst.
1416 "mov r6, #0\n"
1417 "50:\n"
1418 "mov r8, #0\n"
1419 "51:\n"
1420 // Shift of offset register for half-word loads not allowed in A32,
1421 // so we shift, load/store, then shift back r8.
1422 "lsl r8, r8, #1\n"
1423 "ldrh r10, [r3, r8]\n"
1424 "strh r10, [r4, r8]\n"
1425 "lsr r8, r8, #1\n"
1426 "add r8, r8, #1\n"
1427 "cmp r8, r1\n"
1428 "blt 51b\n"
1429 "add r6, r6, #1\n"
1430 "add r3, r3, #8\n"
1431 "add r4, r4, r5\n"
1432 "cmp r6, r2\n"
1433 "blt 50b\n"
1434 "b 31f\n"
1435 "30:\n"
1436 // Yes, all of the 4x2 block fits.
1437 // r3 address, r5 stride
1438 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1439 "mov r4, r3\n"
1440 "mov r6, #2\n"
1441
1442 "vst1.16 {d28[0]}, [r3], r6\n"
1443 "add r4, r4, r5\n"
1444 "vst1.16 {d28[1]}, [r3], r6\n"
1445 "vst1.16 {d28[2]}, [r3], r6\n"
1446 "vst1.16 {d28[3]}, [r3], r6\n"
1447 "mov r3, r4\n"
1448 "vst1.16 {d29[0]}, [r3], r6\n"
1449 "vst1.16 {d29[1]}, [r3], r6\n"
1450 "vst1.16 {d29[2]}, [r3], r6\n"
1451 "vst1.16 {d29[3]}, [r3], r6\n"
1452 "31:\n"
1453
1454 // Load dst_ptr, increment, and write back.
1455 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1456 "add r4, r4, #8\n"
1457 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1458
1459 RUY_MAKE_ZERO(q14)
1460
1461 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1462
1463 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
1464
1465 // Since the store type is the same as the accum type, no need for
1466 // downcast. There's also no need for clamp by min/max.
1467
1468 // At this point, v20 -- v31 aren't used anymore for the current block,
1469 // so we can start clearing these accumulators for the next block
1470 // (next iteration of the main loop).
1471 // Clear accumulators.
1472 RUY_MAKE_ZERO(q6)
1473 RUY_MAKE_ZERO(q7)
1474 RUY_MAKE_ZERO(q8)
1475 RUY_MAKE_ZERO(q9)
1476 RUY_MAKE_ZERO(q10)
1477 RUY_MAKE_ZERO(q11)
1478 RUY_MAKE_ZERO(q12)
1479 RUY_MAKE_ZERO(q13)
1480
1481 // Compute how much of the 4x2 block of destination 32 bit values that
1482 // we have computed, fit in the destination matrix. Typically, all of
1483 // it fits, but when the destination matrix shape is not a multiple
1484 // of 4x2, there are some 4x4 blocks along the boundaries that do
1485 // not fit entirely.
1486
1487 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
1488 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1489 "sub r1, r1, r8\n"
1490
1491 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
1492 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1493 "sub r2, r2, r4\n"
1494 "mov r3, #4\n"
1495 "mov r5, #2\n"
1496 "cmp r1, #4\n"
1497 // Compute r1 = how many rows of the 4x2 block fit
1498 "it gt\n"
1499 "movgt r1, r3\n"
1500
1501 "cmp r2, #2\n"
1502 // Compute r2 = how many cols of the 4x2 block fit
1503 "it gt\n"
1504 "movgt r2, r5\n"
1505
1506 // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits.
1507 "cmp r1, r3\n"
1508 "it eq\n"
1509 "cmpeq r2, r5\n"
1510 // Yes, all of the 4x2 block fits, go to fast path.
1511 "beq 30f\n"
1512 // Not all of the 4x2 block fits.
1513 // Set (r3 address, r4 stride) to write to dst_tmp_buf
1514 "mov r3, %[dst_tmp_buf]\n"
1515 "mov r4, #16\n"
1516 "b 31f\n"
1517
1518 "30:\n"
1519 // Yes, all of the 4x2 block fits.
1520 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1521 // r3 address, r4 stride
1522 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1523 "mov r4, r5\n"
1524
1525 "31:\n"
1526
1527 "vst1.32 {d28, d29}, [r3]\n"
1528 "add r3, r3, r4\n"
1529 "vst1.32 {d30, d31}, [r3]\n"
1530
1531 // If all of the 4x2 block fits, we just finished writing it to the
1532 // destination, so we skip the next part.
1533 "beq 41f\n"
1534 // Not all of the 4x2 block fits in the destination matrix. We just
1535 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
1536 // it to copy into the destination matrix the part that fits.
1537 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1538 "mov r3, %[dst_tmp_buf]\n"
1539 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1540 "mov r6, #0\n"
1541 "50:\n"
1542 "mov r5, #0\n"
1543 "51:\n"
1544 "ldr r10, [r3, r5, lsl #2]\n"
1545 "str r10, [r4, r5, lsl #2]\n"
1546 "add r5, r5, #1\n"
1547 "cmp r5, r1\n"
1548 "blt 51b\n"
1549 "add r6, r6, #1\n"
1550 "add r3, r3, #16\n"
1551 "add r4, r4, r8\n"
1552 // r2 = how many cols of the 8x4 block fit
1553 "cmp r6, r2\n"
1554 "blt 50b\n"
1555
1556 "41:\n"
1557 // Load dst_ptr, increment, and write back.
1558 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1559 "add r4, r4, #16\n"
1560 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1561
1562 RUY_MAKE_ZERO(q10)
1563 RUY_MAKE_ZERO(q11)
1564
1565 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
1566
1567 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
1568
1569 // Reload some params --- we had used x5 -- x7 for a few other things
1570 // since the last time we had loaded them.
1571 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1572 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1573 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1574
1575 // Move to the next block of the destination matrix, for the next iter
1576 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
1577 // been updated earlier.
1578 // Have we reached the end row?
1579 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1580 "cmp r8, r3\n"
1581
1582 "beq 20f\n" // yes, end row.
1583 // Not end row. Move to the next row.
1584 "add r8, r8, #4\n"
1585 // Store new value of row
1586 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1587
1588 "b 21f\n"
1589 "20:\n"
1590 // Was already at end row.
1591 // Move back to first row.
1592 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1593 // Move to the next column.
1594 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1595 "add r4, r4, #2\n"
1596 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1597
1598 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
1599 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1600 // Increment dst_col_ptr by 2 * dst_stride (i.e. 2 columns)
1601 "add r1, r1, r8, lsl #1\n"
1602 // Store dst_col_ptr
1603 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1604 // Store dst_ptr
1605 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1606 "21:\n"
1607
1608 // Main loop exit condition: have we hit the end column?
1609 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1610 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1611 "cmp r8, r4\n"
1612
1613 // w1 is the number of levels of depth that we have already loaded
1614 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1615 // above, this is currently 16.
1616 "mov r1, #16\n"
1617
1618 "ble 1b\n"
1619
1620 // Restore stack pointer.
1621 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1622
1623 // clang-format on
1624
1625 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
1626 : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
1627 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
1628 // Clobber list must specify q registers (and not their constituent
1629 // d registers). There is a (currently unexplained) slowdown if
1630 // d registers are listed in the clobbers list.
1631 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
1632 "q9", "q10", "q12", "q13", "q14", "q15");
1633}
1634
1635// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS
1636// is still packed as if it has two columns
1637void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
1638 profiler::ScopeLabel label("Kernel (kNeon)");
1639
1640 CheckOffsetsInKernelParams8bit(params);
1641
1642 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
1643 const std::int8_t* rhs_col_ptr =
1644 static_cast<const int8_t*>(params.rhs_base_ptr);
1645 const std::int8_t* lhs_ptr = lhs_col_ptr;
1646 const std::int8_t* rhs_ptr = rhs_col_ptr;
1647
1648 RUY_DCHECK(!(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL));
1649
1650 // The asm kernel below has the following NEON register allocation:
1651 //
1652 // q6 - q13 are 128-bit (4x32b) accumulators.
1653 // During accumulation, d0 -- d7 are used to load int8 data from LHS and
1654 // d8 -- d11 from RHS:
1655 // int8 RHS 16x1 block
1656 // /------------|
1657 // | d8.b[0] |
1658 // | ... |
1659 // | d8.b[7] |
1660 // | d9.b[0] |
1661 // | ... |
1662 // | d9.b[7] |
1663 // \------------/
1664 // int8 LHS 4x16 block
1665 // /-----------------------------------------\ /------------|
1666 // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 |
1667 // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 |
1668 // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 |
1669 // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 |
1670 // \-----------------------------------------/ \------------/
1671 // 128-bit accumulators 4x1 block
1672 //
1673 // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING
1674 // optimization for this kernel.
1675 asm volatile(
1676#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n"
1677
1678 // clang-format off
1679
1680 // Load the first 64 bytes of LHS and RHS data.
1681 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1682 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1683 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1684 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1685 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1686 // Skip the other column and advance the pointer.
1687 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1688
1689 "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
1690
1691 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1692 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
1693
1694 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n"
1695 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
1696
1697 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
1698 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1699
1700 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n"
1701 "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1702
1703 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1704 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1705
1706 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n"
1707 "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1708
1709 // Clear accumulators.
1710 RUY_MAKE_ZERO(q6)
1711 RUY_MAKE_ZERO(q7)
1712 RUY_MAKE_ZERO(q8)
1713 RUY_MAKE_ZERO(q9)
1714 RUY_MAKE_ZERO(q10)
1715 RUY_MAKE_ZERO(q11)
1716 RUY_MAKE_ZERO(q12)
1717 RUY_MAKE_ZERO(q13)
1718 RUY_MAKE_ZERO(q14)
1719 RUY_MAKE_ZERO(q15)
1720
1721 // r1 is the number of levels of depth that we have already loaded
1722 // LHS and RHS data for. Corresponding to the initial ld1 instructions
1723 // above, this is currently 16.
1724 "mov r1, #16\n"
1725
1726 // Main loop of the whole GEMM, over rows and columns of the
1727 // destination matrix.
1728 "1:\n"
1729
1730 // r1 is how many levels of depth we have already loaded
1731 // data for, r10 is the total depth.
1732 "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
1733 "cmp r1, r10\n"
1734 "beq 79f\n"
1735
1736 "2:\n"
1737
1738 // Mult, mult-acc in to q14, q15
1739 "vmull.s8 q14, d0, d8\n"
1740 "vmull.s8 q15, d2, d8\n"
1741 "vmlal.s8 q14, d1, d9\n"
1742 "vmlal.s8 q15, d3, d9\n"
1743
1744 // Then pairwise accumulate in to q6, q7
1745 "vpadal.s16 q6, q14\n"
1746 "vpadal.s16 q7, q15\n"
1747
1748 // Mult, mult-acc in to q14, q15
1749 "vmull.s8 q14, d4, d8\n"
1750 "vmull.s8 q15, d6, d8\n"
1751 "vmlal.s8 q14, d5, d9\n"
1752 "vmlal.s8 q15, d7, d9\n"
1753
1754 // Then pairwise accumulate in to q8, q9
1755 "vpadal.s16 q8, q14\n"
1756 "vpadal.s16 q9, q15\n"
1757
1758
1759 // Load the next 64 bytes of LHS and RHS data.
1760 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1761 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1762 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1763 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1764 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1765 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1766 // Skip the other column and advance the pointer.
1767 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1768 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1769
1770 // Each iteration of this loop advances by 16 levels of depth.
1771 "add r1, r1, #16\n"
1772
1773 // Loop termination condition
1774 "cmp r1, r10\n"
1775
1776 "blt 2b\n"
1777
1778 "79:\n"
1779
1780 // Mult, mult-acc in to q14, q15
1781 "vmull.s8 q14, d0, d8\n"
1782 "vmull.s8 q15, d2, d8\n"
1783 "vmlal.s8 q14, d1, d9\n"
1784 "vmlal.s8 q15, d3, d9\n"
1785
1786 // Then pairwise accumulate in to q6, q7
1787 "vpadal.s16 q6, q14\n"
1788 "vpadal.s16 q7, q15\n"
1789
1790 // Mult, mult-acc in to q14, q15
1791 "vmull.s8 q14, d4, d8\n"
1792 "vmull.s8 q15, d6, d8\n"
1793 "vmlal.s8 q14, d5, d9\n"
1794 "vmlal.s8 q15, d7, d9\n"
1795
1796 // Then pairwise accumulate in to q8, q9
1797 "vpadal.s16 q8, q14\n"
1798 "vpadal.s16 q9, q15\n"
1799
1800 // All accumulation over depth done. q6 - q9 contain the 4x32b
1801 // accumulators for the 4x1 final matrix.
1802 // We now have to compute the final 8-bit values from these int32
1803 // accumulators, and advance to the next 4x2 block. We intertwine
1804 // these two aspects whenever possible for optimal pipelining, both
1805 // at the data flow level (prefetch data for next block as early as
1806 // possible) and instruction pipelining level (some of the next-block
1807 // work can dual-issue with some of the final work on the current
1808 // block).
1809
1810 // q6-q9 now contain 4 x 32b
1811 "vpadd.i32 d0, d12, d13\n"
1812 "vpadd.i32 d1, d14, d15\n"
1813 "vpadd.i32 d2, d16, d17\n"
1814 "vpadd.i32 d3, d18, d19\n"
1815
1816 // d0-d4 each contain 2 x 32b accumulators.
1817 // Need to add pairwise to get 1 x 32b for each of the 4x1 entries
1818 // of destination, (Four 'd' registers total)
1819 "vpadd.i32 d28, d0, d1\n"
1820 "vpadd.i32 d29, d2, d3\n"
1821
1822 // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries.
1823
1824 // Logic to advance to the next block in preparation for the next
1825 // iteration of the main loop. For now, we only want to compute
1826 // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
1827 // not yet ready to update the values of row and col, as we still need
1828 // the current values for the rest of the work on the current block.
1829
1830 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
1831 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1832 "cmp r1, r3\n" // Have we finished the last row?
1833
1834 "bge 4f\n" // If finished last row, go to 4
1835 // Not finished last row: then advance to next row.
1836 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
1837 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1838 "add r4, r4, r1, lsl #2\n"
1839 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1840 "b 5f\n"
1841 "4:\n" // Finished last row...
1842 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
1843 // Go back to first row
1844 "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1845
1846 // Now we need to advance to the next column. If we already
1847 // finished the last column, then in principle we are done, however
1848 // we can't just return here, as we need to allow the end work of the
1849 // current block to complete. The good news is that at this point it
1850 // doesn't matter what data we load for the next column, since
1851 // we will exit from the main loop below before actually storing
1852 // anything computed from that data.
1853
1854 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
1855 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1856 "cmp r8, r4\n" // Have we finished the last column?
1857 "bge 5f\n" // If yes, just carry on without updating the column pointer.
1858 // Not finished last column: then advance to next column.
1859 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
1860 "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1861 "add r10, r10, r1, lsl #1\n"
1862 "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1863 "5:\n"
1864
1865 // Set the LHS and RHS data pointers to the start of the columns just
1866 // computed.
1867 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n"
1868 "mov %[lhs_ptr], r4\n"
1869 "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n"
1870 "mov %[rhs_ptr], r5\n"
1871
1872 // Now we load: bias data, LHS sums data, RHS sums data.
1873
1874 // First, load the base pointers from the params.
1875 "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1876 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
1877
1878 // Offset these base pointers as needed given the current row, col.
1879 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1880
1881 "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
1882 "beq 1000f\n"
1883 "add r1, r1, r8, lsl #2\n"
1884 "1000:\n"
1885
1886 // Load 4 bias values.
1887 "vld1.32 {d24, d25}, [r1]\n"
1888
1889 // Now that we know what LHS and RHS data the next iteration of the
1890 // main loop will need to load, we start loading the first 32 bytes of
1891 // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
1892 // in the rest of the work on the current block.
1893 "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n"
1894 "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n"
1895 "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n"
1896 "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n"
1897 RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n")
1898 "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n"
1899 // Skip the other column and advance the pointer.
1900 "add %[rhs_ptr], %[rhs_ptr], #16\n"
1901 RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n")
1902
1903 // Add to the bias values the product
1904 // (depth * lhs_zero_point * rhs_zero_point),
1905 // See the term NZ1Z2 in equation (7) in
1906 // https://arxiv.org/pdf/1712.05877.pdf
1907 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
1908 "vdup.32 q9, r3\n"
1909 "vadd.i32 q12, q12, q9\n"
1910
1911 // Perform the bias-addition (per the above, we have just folded into
1912 // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
1913 "vadd.i32 q14, q14, q12\n"
1914
1915 // LHS/RHS zero points
1916 // Has RHS sums
1917 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1918 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
1919 "beq 401f\n"
1920 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
1921 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
1922 // Offset by current col * number of bytes per value
1923 "add r3, r3, r4, lsl #2\n"
1924 "vld1.32 { d12 }, [r3]\n"
1925 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
1926 "vdup.32 q10, r5\n" // create lhs_zero_point_vec
1927 // Subtract rhs_sums * lhs_zero_point, per
1928 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1929 "vmls.i32 q14, q10, d12[0]\n"
1930 "401:\n"
1931
1932 // Has LHS sums
1933 "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
1934 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
1935 "beq 402f\n"
1936 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
1937 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1938 // Offset by current row * number of bytes per value
1939 "add r2, r2, r4, lsl #2\n"
1940 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
1941
1942 // Load 4 lhs_sums values.
1943 "vld1.32 {d22, d23}, [r2]\n"
1944 "vdup.32 d13, r5\n" // rhs_zero_point
1945
1946 // Compute lhs_sums * rhs_zero_point.
1947 "vmul.i32 q11, q11, d13[1]\n"
1948 // Subtract lhs_sums * rhs_zero_point, per
1949 // equation (7) in https://arxiv.org/pdf/1712.05877.pdf
1950 "vsub.s32 q14, q14, q11\n"
1951
1952 // If the destination is int32, it means the user asks for the raw
1953 // accumulators, no need for us to downquantize the value.
1954 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
1955 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
1956 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
1957
1958 "402:\n"
1959
1960 // At this point we have computed the final int32 values. Now we
1961 // start down-quantizing them to obtain the final 8bit values from them.
1962
1963 // As part of this down-quantization, our int32 values will be
1964 // multiplied by a multiplier that has a fixed-point component and an
1965 // exponent component.
1966
1967 //Load the exponent part of the multiplier.
1968 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
1969 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1970 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
1971 "beq 1001f\n"
1972 "add r1, r1, r4, lsl #2\n"
1973 "1001:\n"
1974
1975 "vld1.32 {q10}, [r1]\n"
1976
1977 "vmvn.i32 q8, #0\n"
1978 "vmin.s32 q13, q10, q8\n"
1979 "vsub.s32 q12, q10, q13\n"
1980
1981 // Apply the positive exponent part of the multiplier.
1982 "vshl.s32 q14, q14, q12\n"
1983
1984 // Load fixed point part of the multiplier
1985 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
1986 // r6 has flags, r4 has row
1987 "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
1988 "beq 1002f\n"
1989 "add r1, r1, r4, lsl #2\n"
1990 "1002:\n"
1991 "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
1992
1993 // Apply the fixed-point part of the multiplier.
1994 "vqdmulh.s32 q14, q14, q10\n"
1995
1996 // Apply the negative exponent part of the multiplier.
1997 "vrshl.s32 q14, q14, q13\n"
1998
1999 "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
2000 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
2001 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
2002 "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
2003 "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
2004
2005 // Store uint8 values:
2006 RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
2007
2008 // Cast-and-saturate from int32 to int16
2009 // After this, all values for output are in d28.
2010 "vqmovn.s32 d28, q14\n"
2011
2012 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2013 // current block, so we can start clearing these accumulators for the
2014 // next block (next iteration of the main loop).
2015 RUY_MAKE_ZERO(q6)
2016 RUY_MAKE_ZERO(q7)
2017 RUY_MAKE_ZERO(q8)
2018 RUY_MAKE_ZERO(q9)
2019 RUY_MAKE_ZERO(q10)
2020 RUY_MAKE_ZERO(q11)
2021 RUY_MAKE_ZERO(q12)
2022 RUY_MAKE_ZERO(q13)
2023 RUY_MAKE_ZERO(q15)
2024
2025 // Load the destination zero point into each of the 8 16-bit slots
2026 // in a q register.
2027 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2028 "vdup.16 q13, r4\n" // dst_zero_point
2029
2030 // Add the destination zero point
2031 "vqadd.s16 q14, q14, q13\n"
2032
2033 // Cast-and-saturate from int16 to uint8
2034 "vqmovun.s16 d30, q14\n"
2035 // At this point, we only need 4 8-bit values in the lower half
2036 // of d30.
2037
2038
2039 // Load the clamp_min, clamp_max bounds
2040 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2041 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2042 "vdup.8 d28, r2\n" // clamp_min
2043 "vdup.8 d29, r3\n" // clamp_max
2044
2045 // Apply the clamp_min bound
2046 "vmax.u8 d30, d30, d28\n"
2047 // Apply the clamp_max bound
2048 "vmin.u8 d30, d30, d29\n"
2049
2050 // Compute how much of the 4x1 block of destination 8bit values that
2051 // we have computed, fit in the destination matrix. Typically, all of
2052 // it fits, but when the destination matrix shape is not a multiple
2053 // of 4x1, there are some 4x1 blocks along the boundaries that do
2054 // not fit entirely.
2055
2056 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2057 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2058 "sub r1, r1, r8\n"
2059
2060 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2061 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2062 "sub r2, r2, r4\n"
2063 "mov r3, #4\n"
2064 "mov r5, #2\n"
2065 "cmp r1, #4\n"
2066 // Compute r1 = how many rows of the 4x1 block fit
2067 "it gt\n"
2068 "movgt r1, r3\n"
2069
2070 // Test if r1==4, i.e. if all of the 4x1 block fits.
2071 "cmp r1, r3\n"
2072
2073 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2074 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2075 // Yes, all of the 4x1 block fits, go to fast path.
2076 "beq 30f\n"
2077 // Not all of the 4x1 block fits.
2078 // Store to dst_tmp_buf
2079 // Set r3 address to write to dst_tmp_buf.
2080 "mov r3, %[dst_tmp_buf]\n"
2081 "vst1.8 {d30}, [r3]\n"
2082
2083 // Slow loop copying from dst_tmp_buf to dst.
2084 "50:\n"
2085 "mov r8, #0\n"
2086 "51:\n"
2087 "ldrb r10, [r3, r8]\n"
2088 "strb r10, [r4, r8]\n"
2089 "add r8, r8, #1\n"
2090 "cmp r8, r1\n"
2091 "blt 51b\n"
2092 "b 31f\n"
2093 "30:\n"
2094 // Yes, all of the 4x1 block fits.
2095 // r3 address, r5 stride
2096 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2097 "mov r4, r3\n"
2098 "mov r6, #1\n"
2099
2100 "vst1.8 {d30[0]}, [r3], r6\n"
2101 "vst1.8 {d30[1]}, [r3], r6\n"
2102 "vst1.8 {d30[2]}, [r3], r6\n"
2103 "vst1.8 {d30[3]}, [r3], r6\n"
2104 "31:\n"
2105
2106 // Load dst_ptr, increment, and write back.
2107 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2108 "add r4, r4, #4\n"
2109 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2110
2111 RUY_MAKE_ZERO(q13)
2112 RUY_MAKE_ZERO(q14)
2113 RUY_MAKE_ZERO(q15)
2114
2115 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2116
2117 // Store int8 values:
2118 RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
2119
2120 // Cast-and-saturate from int32 to int16
2121 // After this, all values for output are in d28.
2122 "vqmovn.s32 d28, q14\n"
2123
2124 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2125 // current block, so we can start clearing these accumulators for the
2126 // next block (next iteration of the main loop).
2127 RUY_MAKE_ZERO(q6)
2128 RUY_MAKE_ZERO(q7)
2129 RUY_MAKE_ZERO(q8)
2130 RUY_MAKE_ZERO(q9)
2131 RUY_MAKE_ZERO(q10)
2132 RUY_MAKE_ZERO(q11)
2133 RUY_MAKE_ZERO(q12)
2134 RUY_MAKE_ZERO(q13)
2135 RUY_MAKE_ZERO(q15)
2136
2137 // Load the destination zero point into each of the 8 16-bit slots
2138 // in a q register.
2139 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2140 "vdup.16 q13, r4\n" // dst_zero_point
2141
2142 // Add the destination zero point
2143 "vqadd.s16 q14, q14, q13\n"
2144
2145 // Cast-and-saturate from int16 to int8
2146 "vqmovn.s16 d30, q14\n"
2147 // At this point, we only need 4 8-bit values in the lower half
2148 // of d30.
2149
2150 // Load the clamp_min, clamp_max bounds
2151 "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2152 "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2153 "vdup.8 d28, r2\n" // clamp_min
2154 "vdup.8 d29, r3\n" // clamp_max
2155
2156 // Apply the clamp_min bound
2157 "vmax.s8 d30, d30, d28\n"
2158 // Apply the clamp_max bound
2159 "vmin.s8 d30, d30, d29\n"
2160
2161 // Compute how much of the 4x1 block of destination 8bit values that
2162 // we have computed, fit in the destination matrix. Typically, all of
2163 // it fits, but when the destination matrix shape is not a multiple
2164 // of 4x2, there are some 4x2 blocks along the boundaries that do
2165 // not fit entirely.
2166
2167 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2168 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2169 "sub r1, r1, r8\n"
2170
2171 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2172 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2173 "sub r2, r2, r4\n"
2174 "mov r3, #4\n"
2175 "mov r5, #2\n"
2176 "cmp r1, #4\n"
2177 // Compute r1 = how many rows of the 4x2 block fit
2178 "it gt\n"
2179 "movgt r1, r3\n"
2180
2181 // Test if r1==4 i.e. if all of the 4x1 block fits.
2182 "cmp r1, r3\n"
2183
2184 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2185 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2186 // Yes, all of the 4x2 block fits, go to fast path.
2187 "beq 30f\n"
2188 // Not all of the 4x2 block fits.
2189 // Store to dst_tmp_buf
2190 // Set r3 address to write to dst_tmp_buf.
2191 "mov r3, %[dst_tmp_buf]\n"
2192 "vst1.8 {d30}, [r3]\n"
2193
2194 // Slow loop copying from dst_tmp_buf to dst.
2195 "50:\n"
2196 "mov r8, #0\n"
2197 "51:\n"
2198 "ldrb r10, [r3, r8]\n"
2199 "strb r10, [r4, r8]\n"
2200 "add r8, r8, #1\n"
2201 "cmp r8, r1\n"
2202 "blt 51b\n"
2203 "b 31f\n"
2204 "30:\n"
2205 // Yes, all of the 4x1 block fits.
2206 // r3 address, r5 stride
2207 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2208 "mov r4, r3\n"
2209 "mov r6, #1\n"
2210
2211 "vst1.8 {d30[0]}, [r3], r6\n"
2212 "vst1.8 {d30[1]}, [r3], r6\n"
2213 "vst1.8 {d30[2]}, [r3], r6\n"
2214 "vst1.8 {d30[3]}, [r3], r6\n"
2215 "31:\n"
2216
2217 // Load dst_ptr, increment, and write back.
2218 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2219 "add r4, r4, #4\n"
2220 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2221
2222 RUY_MAKE_ZERO(q13)
2223 RUY_MAKE_ZERO(q14)
2224 RUY_MAKE_ZERO(q15)
2225
2226 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2227
2228 RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
2229
2230 // Load the destination zero point into each of the 4 32-bit slots
2231 // in a q register.
2232 "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
2233 "vdup.32 q13, r4\n" // dst_zero_point
2234 // Add the destination zero point
2235 "vadd.s32 q14, q14, q13\n"
2236 //"vadd.s32 q15, q15, q13\n"
2237
2238 // Cast-and-saturate from int32 to int16
2239 // After this, all values for output are in d28.
2240 "vqmovn.s32 d28, q14\n"
2241
2242 // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the
2243 // so we can start clearing these accumulators for the next block
2244 // (next iteration of the main loop).
2245 RUY_MAKE_ZERO(q6)
2246 RUY_MAKE_ZERO(q7)
2247 RUY_MAKE_ZERO(q8)
2248 RUY_MAKE_ZERO(q9)
2249 RUY_MAKE_ZERO(q10)
2250 RUY_MAKE_ZERO(q11)
2251 RUY_MAKE_ZERO(q15)
2252
2253 // Load the clamp_min, clamp_max bounds
2254 "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
2255 "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
2256 "vdup.16 d24, r2\n" // clamp_min
2257 "vdup.16 d26, r3\n" // clamp_max
2258
2259 // Apply the clamp_min bound
2260 "vmax.s16 d28, d28, d24\n"
2261 // Apply the clamp_max bound
2262 "vmin.s16 d28, d28, d26\n"
2263
2264 RUY_MAKE_ZERO(q12)
2265 RUY_MAKE_ZERO(q13)
2266
2267 // Compute how much of the 4x1 block of destination 16-bit values that
2268 // we have computed, fit in the destination matrix. Typically, all of
2269 // it fits, but when the destination matrix shape is not a multiple
2270 // of 4x1, there are some 4x1 blocks along the boundaries that do
2271 // not fit entirely.
2272
2273 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2274 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2275 "sub r1, r1, r8\n"
2276
2277 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2278 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2279 "sub r2, r2, r4\n"
2280 "mov r3, #4\n"
2281 "mov r5, #2\n"
2282 "cmp r1, #4\n"
2283 // Compute r1 = how many rows of the 4x1 block fit
2284 "it gt\n"
2285 "movgt r1, r3\n"
2286
2287 // Test if r1==4, i.e. if all of the 4x1 block fits.
2288 "cmp r1, r3\n"
2289
2290 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2291 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2292 // Yes, all of the 4x1 block fits, go to fast path.
2293 "beq 30f\n"
2294 // Not all of the 4x1 block fits.
2295 // Store to dst_tmp_buf
2296 // Set r3 address to write to dst_tmp_buf.
2297 "mov r3, %[dst_tmp_buf]\n"
2298 "vst1.16 {d28}, [r3]\n"
2299
2300 // Slow loop copying from dst_tmp_buf to dst.
2301 "50:\n"
2302 "mov r8, #0\n"
2303 "51:\n"
2304 // Shift of offset register for half-word loads not allowed in A32,
2305 // so we shift, load/store, then shift back r8.
2306 "lsl r8, r8, #1\n"
2307 "ldrh r10, [r3, r8]\n"
2308 "strh r10, [r4, r8]\n"
2309 "lsr r8, r8, #1\n"
2310 "add r8, r8, #1\n"
2311 "cmp r8, r1\n"
2312 "blt 51b\n"
2313 "b 31f\n"
2314 "30:\n"
2315 // Yes, all of the 4x1 block fits.
2316 // r3 address, r5 stride
2317 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2318 "mov r4, r3\n"
2319 "mov r6, #2\n"
2320
2321 "vst1.16 {d28[0]}, [r3], r6\n"
2322 "vst1.16 {d28[1]}, [r3], r6\n"
2323 "vst1.16 {d28[2]}, [r3], r6\n"
2324 "vst1.16 {d28[3]}, [r3], r6\n"
2325 "31:\n"
2326
2327 // Load dst_ptr, increment, and write back.
2328 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2329 "add r4, r4, #8\n"
2330 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2331
2332 RUY_MAKE_ZERO(q14)
2333
2334 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2335
2336 RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
2337
2338 // Since the store type is the same as the accum type, no need for
2339 // downcast. There's also no need for clamp by min/max.
2340
2341 // At this point, v20 -- v31 aren't used anymore for the current block,
2342 // so we can start clearing these accumulators for the next block
2343 // (next iteration of the main loop).
2344 // Clear accumulators.
2345 RUY_MAKE_ZERO(q6)
2346 RUY_MAKE_ZERO(q7)
2347 RUY_MAKE_ZERO(q8)
2348 RUY_MAKE_ZERO(q9)
2349 RUY_MAKE_ZERO(q10)
2350 RUY_MAKE_ZERO(q11)
2351 RUY_MAKE_ZERO(q12)
2352 RUY_MAKE_ZERO(q13)
2353
2354 // Compute how much of the 4x1 block of destination 32 bit values that
2355 // we have computed, fit in the destination matrix. Typically, all of
2356 // it fits, but when the destination matrix shape is not a multiple
2357 // of 4x2, there are some 4x4 blocks along the boundaries that do
2358 // not fit entirely.
2359
2360 "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n"
2361 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2362 "sub r1, r1, r8\n"
2363
2364 "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n"
2365 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2366 "sub r2, r2, r4\n"
2367 "mov r3, #4\n"
2368 "mov r5, #2\n"
2369 "cmp r1, #4\n"
2370 // Compute r1 = how many rows of the 4x2 block fit
2371 "it gt\n"
2372 "movgt r1, r3\n"
2373
2374 // Test if r1==4, i.e. if all of the 4x1 block fits.
2375 "cmp r1, r3\n"
2376
2377 // Yes, all of the 4x1 block fits, go to fast path.
2378 "beq 30f\n"
2379 // Not all of the 4x1 block fits.
2380 // Set (r3 address, r4 stride) to write to dst_tmp_buf
2381 "mov r3, %[dst_tmp_buf]\n"
2382 "mov r4, #16\n"
2383 "b 31f\n"
2384
2385 "30:\n"
2386 // Yes, all of the 4x1 block fits.
2387 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2388 // r3 address, r4 stride
2389 "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2390 "mov r4, r5\n"
2391
2392 "31:\n"
2393
2394 "vst1.32 {d28, d29}, [r3]\n"
2395
2396 // If all of the 4x1 block fits, we just finished writing it to the
2397 // destination, so we skip the next part.
2398 "beq 41f\n"
2399 // Not all of the 4x1 block fits in the destination matrix. We just
2400 // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
2401 // it to copy into the destination matrix the part that fits.
2402 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2403 "mov r3, %[dst_tmp_buf]\n"
2404 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2405 "50:\n"
2406 "mov r5, #0\n"
2407 "51:\n"
2408 "ldr r10, [r3, r5, lsl #2]\n"
2409 "str r10, [r4, r5, lsl #2]\n"
2410 "add r5, r5, #1\n"
2411 "cmp r5, r1\n"
2412 "blt 51b\n"
2413
2414 "41:\n"
2415 // Load dst_ptr, increment, and write back.
2416 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2417 "add r4, r4, #16\n"
2418 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2419
2420 RUY_MAKE_ZERO(q10)
2421 RUY_MAKE_ZERO(q11)
2422
2423 "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
2424
2425 RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
2426
2427 // Reload some params --- we had used x5 -- x7 for a few other things
2428 // since the last time we had loaded them.
2429 "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
2430 "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
2431 "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
2432
2433 // Move to the next block of the destination matrix, for the next iter
2434 // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
2435 // been updated earlier.
2436 // Have we reached the end row?
2437 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2438 "cmp r8, r3\n"
2439
2440 "beq 20f\n" // yes, end row.
2441 // Not end row. Move to the next row.
2442 "add r8, r8, #4\n"
2443 // Store new value of row
2444 "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2445
2446 "b 21f\n"
2447 "20:\n"
2448 // Was already at end row.
2449 // Move back to first row.
2450 "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
2451 // Move to the next column.
2452 "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2453 "add r4, r4, #2\n"
2454 "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2455
2456 "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
2457 "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2458 // Increment dst_col_ptr by dst_stride (i.e. 1 column)
2459 "add r1, r1, r8\n"
2460 // Store dst_col_ptr
2461 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n"
2462 // Store dst_ptr
2463 "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n"
2464 "21:\n"
2465
2466 // Main loop exit condition: have we hit the end column?
2467 "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
2468 "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n"
2469 "cmp r8, r4\n"
2470
2471 // w1 is the number of levels of depth that we have already loaded
2472 // LHS and RHS data for. Corresponding to the initial ld1 instructions
2473 // above, this is currently 16.
2474 "mov r1, #16\n"
2475
2476 "ble 1b\n"
2477
2478 // Restore stack pointer.
2479 "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n"
2480
2481 // clang-format on
2482
2483 : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr)
2484 : [ params ] "r"(&params), [dst_tmp_buf] "r"(params.dst_tmp_buf)
2485 : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc",
2486 // Clobber list must specify q registers (and not their constituent
2487 // d registers). There is a (currently unexplained) slowdown if
2488 // d registers are listed in the clobbers list.
2489 "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
2490 "q9", "q10", "q12", "q13", "q14", "q15");
2491}
2492
2493#undef RUY_OFFSET_BIAS
2494#undef RUY_OFFSET_LHS_SUMS
2495#undef RUY_OFFSET_RHS_SUMS
2496#undef RUY_OFFSET_LHS_BASE_PTR
2497#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT
2498#undef RUY_OFFSET_MULTIPLIER_EXPONENT
2499#undef RUY_OFFSET_RHS_BASE_PTR
2500#undef RUY_OFFSET_DST_BASE_PTR
2501#undef RUY_OFFSET_LHS_ZERO_POINT
2502#undef RUY_OFFSET_RHS_ZERO_POINT
2503#undef RUY_OFFSET_DST_ZERO_POINT
2504#undef RUY_OFFSET_PROD_ZP_DEPTH
2505#undef RUY_OFFSET_START_ROW
2506#undef RUY_OFFSET_START_COL
2507#undef RUY_OFFSET_LAST_ROW
2508#undef RUY_OFFSET_LAST_COL
2509#undef RUY_OFFSET_DST_ROWS
2510#undef RUY_OFFSET_DST_COLS
2511#undef RUY_OFFSET_LHS_STRIDE
2512#undef RUY_OFFSET_RHS_STRIDE
2513#undef RUY_OFFSET_DST_STRIDE
2514#undef RUY_OFFSET_DEPTH
2515#undef RUY_OFFSET_CLAMP_MIN
2516#undef RUY_OFFSET_CLAMP_MAX
2517#undef RUY_OFFSET_FLAGS
2518#undef RUY_OFFSET_DST_TYPE_ID
2519
2520#undef RUY_STACK_OFFSET_SIZE
2521#undef RUY_STACK_OFFSET_DST_COL_PTR
2522#undef RUY_STACK_OFFSET_DST_PTR
2523#undef RUY_STACK_OFFSET_ROW
2524#undef RUY_STACK_OFFSET_COL
2525#undef RUY_STACK_OFFSET_LHS_COL_PTR
2526#undef RUY_STACK_OFFSET_RHS_COL_PTR
2527
2528#endif // RUY_PLATFORM_NEON_32 && (RUY_OPT(ASM)
2529} // namespace ruy
2530