/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ruy/pack_arm.h"

#include <cstdint>

#include "ruy/asm_helpers.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_NEON
#include <arm_neon.h>
#endif

namespace ruy {

#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
                             const void* src_ptr2, const void* src_ptr3,
                             int src_inc0, int src_inc1, int src_inc2,
                             int src_inc3, int src_rows, int src_zero_point,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                             int input_xor) {
  profiler::ScopeLabel label("Pack (kNeon)");
  asm volatile(
      // clang-format off
      // v26 will be the vector to XOR input values with to perform
      // any input data type conversion (e.g. uint8 to int8).
      "dup v26.16b, %w[input_xor]\n"
      // w1 will be the number of rows already loaded.
      "mov w1, #0\n"
      // v28--v31 will be used to accumulate the sums.
      "movi v28.4s, #0\n"
      "movi v29.4s, #0\n"
      "movi v30.4s, #0\n"
      "movi v31.4s, #0\n"
      // Let w2 be `rows` rounded down to a multiple of 16.
      "ands w2, %w[rows], #-16\n"
      // If there are no full blocks of 16 rows to process, jump to the
      // code handling the last < 16 rows.
      "beq 3f\n"
      // Load the first block of 16 rows.
      "add w1, w1, #16\n"
      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      // Check if this was the only full block of 16 rows to load.
      "cmp w1, w2\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      // In that case, jump to the code handling the last loaded block of
      // 16 rows.
      "beq 2f\n"
      // Main loop processing blocks of 16 rows.
      "1:\n"
      // Load the next 16 rows, interleaved with the XOR input type
      // conversion (e.g. uint8->int8) on the already loaded inputs.
      "add w1, w1, #16\n"
      "eor v4.16b, v0.16b, v26.16b\n"
      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "eor v7.16b, v3.16b, v26.16b\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      // Compute the sums, interleaved with storing to the packed matrix.
      "saddlp v16.8h, v4.16b\n"
      "str q4, [%[packed_ptr], #0]\n"
      "saddlp v17.8h, v5.16b\n"
      "str q5, [%[packed_ptr], #16]\n"
      "saddlp v18.8h, v6.16b\n"
      "str q6, [%[packed_ptr], #32]\n"
      "saddlp v19.8h, v7.16b\n"
      "str q7, [%[packed_ptr], #48]\n"
      "sadalp v28.4s, v16.8h\n"
      // Was this the last block of 16 rows to load?
      "cmp w1, w2\n"
      "sadalp v29.4s, v17.8h\n"
      "add %[packed_ptr], %[packed_ptr], #64\n"
      "sadalp v30.4s, v18.8h\n"
      "sadalp v31.4s, v19.8h\n"
      // End of main loop on blocks of 16 rows.
      "bne 1b\n"

      // Code handling the last already-loaded block of 16 rows.
      "2:\n"

      // Process the last loaded full 16x4 block.
      "eor v4.16b, v0.16b, v26.16b\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "eor v7.16b, v3.16b, v26.16b\n"

      "saddlp v16.8h, v4.16b\n"
      "str q4, [%[packed_ptr], #0]\n"
      "saddlp v17.8h, v5.16b\n"
      "str q5, [%[packed_ptr], #16]\n"
      "saddlp v18.8h, v6.16b\n"
      "str q6, [%[packed_ptr], #32]\n"
      "saddlp v19.8h, v7.16b\n"
      "str q7, [%[packed_ptr], #48]\n"
      "sadalp v28.4s, v16.8h\n"
      "sadalp v29.4s, v17.8h\n"
      "sadalp v30.4s, v18.8h\n"
      "sadalp v31.4s, v19.8h\n"

      "add %[packed_ptr], %[packed_ptr], #64\n"

      // End of code handling full blocks of 16 rows.
      // Now we handle any remaining rows.
      "3:\n"
      // Let w2 be the number of rows left to handle.
      "ands w2, %w[rows], #15\n"
      // If w2==0, there are no remaining rows, jump to the end.
      "beq 4f\n"
      // Zero out a 16x4 block in registers, which we'll partially overwrite
      // with any remaining rows.
      "dup v0.16b, %w[src_zero_point]\n"
      "dup v1.16b, %w[src_zero_point]\n"
      "dup v2.16b, %w[src_zero_point]\n"
      "dup v3.16b, %w[src_zero_point]\n"
#define RUY_LOAD_ONE_ROW(R)                   \
  "cmp w2, #" #R "\n"                         \
  "beq 5f\n"                                  \
  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"

      RUY_LOAD_ONE_ROW(0)
      RUY_LOAD_ONE_ROW(1)
      RUY_LOAD_ONE_ROW(2)
      RUY_LOAD_ONE_ROW(3)
      RUY_LOAD_ONE_ROW(4)
      RUY_LOAD_ONE_ROW(5)
      RUY_LOAD_ONE_ROW(6)
      RUY_LOAD_ONE_ROW(7)
      RUY_LOAD_ONE_ROW(8)
      RUY_LOAD_ONE_ROW(9)
      RUY_LOAD_ONE_ROW(10)
      RUY_LOAD_ONE_ROW(11)
      RUY_LOAD_ONE_ROW(12)
      RUY_LOAD_ONE_ROW(13)
      RUY_LOAD_ONE_ROW(14)
      // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
#undef RUY_LOAD_ONE_ROW
      "5:\n"

      // Process the last zero-padded 16x4 block.
      "eor v4.16b, v0.16b, v26.16b\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "eor v7.16b, v3.16b, v26.16b\n"

      "saddlp v16.8h, v4.16b\n"
      "saddlp v17.8h, v5.16b\n"
      "saddlp v18.8h, v6.16b\n"
      "saddlp v19.8h, v7.16b\n"
      "sadalp v28.4s, v16.8h\n"
      "sadalp v29.4s, v17.8h\n"
      "sadalp v30.4s, v18.8h\n"
      "sadalp v31.4s, v19.8h\n"

      "str q4, [%[packed_ptr], #0]\n"
      "str q5, [%[packed_ptr], #16]\n"
      "str q6, [%[packed_ptr], #32]\n"
      "str q7, [%[packed_ptr], #48]\n"
      "add %[packed_ptr], %[packed_ptr], #64\n"

      "4:\n"

      // Horizontal reduction of the registers used to accumulate sums.
      "addp v28.4s, v28.4s, v29.4s\n"
      "addp v30.4s, v30.4s, v31.4s\n"
      "addp v28.4s, v28.4s, v30.4s\n"

      // Store the sums.
      "cmp %[sums_ptr], #0\n"
      "beq 6f\n"
      "st1 {v28.4s}, [%[sums_ptr]], #16\n"
      "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows), [src_zero_point] "r"(src_zero_point),
        [input_xor] "r"(input_xor)
      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
}
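
// A reference sketch, for illustration only (not part of ruy's API), of what
// the kernel above computes. It assumes the four source pointers address four
// contiguous columns (in the kernel, a zero src_inc lets a pointer stand
// still, e.g. to replicate a trailing column). The columns are interleaved
// into 64-byte packed blocks of 16 rows each, every byte is XORed with
// input_xor (0x80 converts uint8 to int8; 0 is a no-op), the final partial
// block is padded with src_zero_point, and one int32 sum per column of the
// packed (post-XOR) values, padding included, is optionally written out.
void Pack8bitColMajorForNeonReferenceSketch(const std::uint8_t* src[4],
                                            int src_rows, int src_zero_point,
                                            int input_xor,
                                            std::int8_t* packed_ptr,
                                            std::int32_t* sums_ptr) {
  std::int32_t sums[4] = {0, 0, 0, 0};
  // Round the row count up to a multiple of 16, the packing block height.
  const int packed_rows = (src_rows + 15) & ~15;
  for (int block_start = 0; block_start < packed_rows; block_start += 16) {
    for (int col = 0; col < 4; ++col) {
      for (int i = 0; i < 16; ++i) {
        const int row = block_start + i;
        // Rows past the end of the source are padded with src_zero_point.
        const std::uint8_t raw =
            row < src_rows ? src[col][row]
                           : static_cast<std::uint8_t>(src_zero_point);
        const std::int8_t packed = static_cast<std::int8_t>(raw ^ input_xor);
        *packed_ptr++ = packed;
        sums[col] += packed;
      }
    }
  }
  if (sums_ptr) {
    for (int col = 0; col < 4; ++col) {
      sums_ptr[col] = sums[col];
    }
  }
}
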
#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)

#define RUY_OFFSET_SRC_PTR0 0
#define RUY_OFFSET_SRC_PTR1 4
#define RUY_OFFSET_SRC_PTR2 8
#define RUY_OFFSET_SRC_PTR3 12
#define RUY_OFFSET_SUMS_PTR 16
#define RUY_OFFSET_PACKED_PTR 20
#define RUY_OFFSET_SRC_INC0 24
#define RUY_OFFSET_SRC_INC1 28
#define RUY_OFFSET_SRC_INC2 32
#define RUY_OFFSET_SRC_INC3 36
#define RUY_OFFSET_SRC_ROWS 40
#define RUY_OFFSET_SRC_ZERO_POINT 44
#define RUY_OFFSET_INPUT_XOR 48
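
// These offsets describe the layout of PackParams8bit (defined in
// pack_common.h) on 32-bit ARM, where pointers and ints are both 4 bytes
// wide: six pointer fields at offsets 0..20 followed by seven int fields at
// offsets 24..48. CheckOffsetsInPackParams8bit below static_asserts that the
// offsets stay in sync with the struct, since the assembly accesses the
// fields through hardcoded "ldr rN, [%[params], #offset]" loads.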

template <typename Params>
void CheckOffsetsInPackParams8bit(const Params&) {
  static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, "");
  static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, "");
  static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, "");
  static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, "");
  static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, "");
  static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, "");
  static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, "");
  static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, "");
  static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, "");
  static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, "");
  static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, "");
  static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT,
                "");
  static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, "");
}

// No attempt made at making this code efficient on A55-ish cores yet.
void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params) {
  CheckOffsetsInPackParams8bit(params);
  profiler::ScopeLabel label("Pack (kNeon)");
  const void* src_ptr0 = params.src_ptr0;
  const void* src_ptr1 = params.src_ptr1;
  const void* src_ptr2 = params.src_ptr2;
  const void* src_ptr3 = params.src_ptr3;
  const int src_inc0 = params.src_inc0;
  const int src_inc1 = params.src_inc1;
  const int src_inc2 = params.src_inc2;
  const int src_inc3 = params.src_inc3;
  const std::int8_t* packed_ptr = params.packed_ptr;

  asm volatile(
      // clang-format off

      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
      "vdup.8 q11, r2\n"
      "mov r1, #0\n"
      // Zero out the accumulators.
      "vmov.i32 q12, #0\n"
      "vmov.i32 q13, #0\n"
      "vmov.i32 q14, #0\n"
      "vmov.i32 q15, #0\n"

      // Round down src_rows to the nearest multiple of 16.
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
      "and r2, r3, #-16\n"
      "cmp r1, r2\n"
      "beq 3f\n"

      "1:\n"
      "add r1, r1, #16\n"
      /* Load q0 */
      "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
      "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"
      RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n")

      /* Load q1 */
      "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
      "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"
      RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n")

      "veor.8 q4, q0, q11\n"
      "veor.8 q5, q1, q11\n"

      // Pairwise add into 16-bit accumulators.
      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      // Pairwise add-accumulate into 32-bit accumulators.
      // q12 and q13 contain 4x32-bit accumulators.
      "vpadal.s16 q12, q8\n"
      "vpadal.s16 q13, q9\n"

      // Now do the same for src_ptr2 and src_ptr3.
      "vld1.8 {d0, d1}, [%[src_ptr2]]\n"
      "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n"
      RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n")

      "vld1.8 {d2, d3}, [%[src_ptr3]]\n"
      "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n"
      RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n")

      "veor.8 q4, q0, q11\n"
      "veor.8 q5, q1, q11\n"

      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      // Pairwise add-accumulate into 32-bit accumulators.
      // q14 and q15 contain 4x32-bit accumulators.
      "vpadal.s16 q14, q8\n"
      "vpadal.s16 q15, q9\n"

      "cmp r1, r2\n"
      "bne 1b\n"

      "3:\n"

      // Now pack the last (src_rows % 16) rows.
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
      "ands r2, r3, #15\n"
      "beq 4f\n"
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
      "vdup.8 q0, r3\n"
      "vdup.8 q1, r3\n"

// First, read/accumulate/write for src_ptr0 and src_ptr1.
#define RUY_LOAD_ONE_ROW1(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
  "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n"

      RUY_LOAD_ONE_ROW1(0, 0)
      RUY_LOAD_ONE_ROW1(1, 1)
      RUY_LOAD_ONE_ROW1(2, 2)
      RUY_LOAD_ONE_ROW1(3, 3)
      RUY_LOAD_ONE_ROW1(4, 4)
      RUY_LOAD_ONE_ROW1(5, 5)
      RUY_LOAD_ONE_ROW1(6, 6)
      RUY_LOAD_ONE_ROW1(7, 7)
#undef RUY_LOAD_ONE_ROW1

#define RUY_LOAD_ONE_ROW2(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
  "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n"

      RUY_LOAD_ONE_ROW2(8, 0)
      RUY_LOAD_ONE_ROW2(9, 1)
      RUY_LOAD_ONE_ROW2(10, 2)
      RUY_LOAD_ONE_ROW2(11, 3)
      RUY_LOAD_ONE_ROW2(12, 4)
      RUY_LOAD_ONE_ROW2(13, 5)
      RUY_LOAD_ONE_ROW2(14, 6)
      RUY_LOAD_ONE_ROW2(15, 7)
#undef RUY_LOAD_ONE_ROW2

      "5:\n"

      "veor.16 q4, q0, q11\n"
      "veor.16 q5, q1, q11\n"

      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      // Pairwise add-accumulate into the 4x32-bit accumulators.
      "vpadal.s16 q12, q8\n"
      "vpadal.s16 q13, q9\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      // Reset to src_zero for src_ptr2 and src_ptr3.
      "vdup.8 q0, r3\n"
      "vdup.8 q1, r3\n"

// Next, read/accumulate/write for src_ptr2 and src_ptr3.
#define RUY_LOAD_ONE_ROW1(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \
  "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n"

      RUY_LOAD_ONE_ROW1(0, 0)
      RUY_LOAD_ONE_ROW1(1, 1)
      RUY_LOAD_ONE_ROW1(2, 2)
      RUY_LOAD_ONE_ROW1(3, 3)
      RUY_LOAD_ONE_ROW1(4, 4)
      RUY_LOAD_ONE_ROW1(5, 5)
      RUY_LOAD_ONE_ROW1(6, 6)
      RUY_LOAD_ONE_ROW1(7, 7)
#undef RUY_LOAD_ONE_ROW1

#define RUY_LOAD_ONE_ROW2(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \
  "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n"

      RUY_LOAD_ONE_ROW2(8, 0)
      RUY_LOAD_ONE_ROW2(9, 1)
      RUY_LOAD_ONE_ROW2(10, 2)
      RUY_LOAD_ONE_ROW2(11, 3)
      RUY_LOAD_ONE_ROW2(12, 4)
      RUY_LOAD_ONE_ROW2(13, 5)
      RUY_LOAD_ONE_ROW2(14, 6)
      RUY_LOAD_ONE_ROW2(15, 7)
#undef RUY_LOAD_ONE_ROW2

      "5:\n"

      "veor.16 q4, q0, q11\n"
      "veor.16 q5, q1, q11\n"

      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      // Pairwise add-accumulate into the 4x32-bit accumulators.
      "vpadal.s16 q14, q8\n"
      "vpadal.s16 q15, q9\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      "4:\n"
      // Pairwise add the 32-bit accumulators.
      "vpadd.i32 d24, d24, d25\n"
      "vpadd.i32 d26, d26, d27\n"
      "vpadd.i32 d28, d28, d29\n"
      "vpadd.i32 d30, d30, d31\n"
      // Final reduction: one 32-bit sum per source column.
      "vpadd.i32 d25, d24, d26\n"
      "vpadd.i32 d27, d28, d30\n"

      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
      "cmp r3, #0\n"
      "beq 6f\n"
      "vst1.32 {d25}, [r3]!\n"
      "vst1.32 {d27}, [r3]!\n"
      "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3)
      : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1),
        [src_inc2] "r"(src_inc2), [src_inc3] "r"(src_inc3),
        [packed_ptr] "r"(packed_ptr), [params] "r"(&params)
      : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
        "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13",
        "q14", "q15");
}
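
// A minimal sketch, for illustration only (not part of ruy's API), of the
// widening-sum idiom used by the kernels above and below: vpaddl.s8
// pairwise-adds 16 int8 lanes into 8 int16 lanes, then vpadal.s16
// pairwise-add-accumulates those into 4 int32 lanes, so the running sums
// live in lanes wide enough not to overflow.
inline int32x4_t AccumulatePackedSumsSketch(int32x4_t acc, int8x16_t packed) {
  const int16x8_t pairs = vpaddlq_s8(packed);  // vpaddl.s8 q8, q4
  return vpadalq_s16(acc, pairs);              // vpadal.s16 q12, q8
}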

// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9.
// No attempt made at making this code efficient on in-order cores yet.
// This version differs from the above in that we only handle two columns
// at a time.
void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params) {
  CheckOffsetsInPackParams8bit(params);
  profiler::ScopeLabel label("Pack (kNeon)");
  const void* src_ptr0 = params.src_ptr0;
  const void* src_ptr1 = params.src_ptr1;
  const int src_inc0 = params.src_inc0;
  const int src_inc1 = params.src_inc1;
  const std::int8_t* packed_ptr = params.packed_ptr;

  asm volatile(
      // clang-format off

      "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n"
      "vdup.8 q11, r2\n"
      "mov r1, #0\n"
      // Zero out the accumulators.
      "vmov.i32 q12, #0\n"
      "vmov.i32 q13, #0\n"

      // Round down src_rows to the nearest multiple of 16.
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
      "and r2, r3, #-16\n"
      "cmp r1, r2\n"
      "beq 3f\n"

      "1:\n"
      "add r1, r1, #16\n"
      /* Load q0 */
      "vld1.8 {d0, d1}, [%[src_ptr0]]\n"
      "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n"

      /* Load q1 */
      "vld1.8 {d2, d3}, [%[src_ptr1]]\n"
      "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n"

      "veor.8 q4, q0, q11\n"
      "veor.8 q5, q1, q11\n"

      // Pairwise add into 16-bit accumulators.
      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      // Pairwise add-accumulate into 32-bit accumulators.
      // q12 and q13 contain 4x32-bit accumulators.
      "vpadal.s16 q12, q8\n"
      "vpadal.s16 q13, q9\n"

      "cmp r1, r2\n"

      "bne 1b\n"

      "3:\n"

      // Now pack the last (src_rows % 16) rows.
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n"
      "ands r2, r3, #15\n"
      "beq 4f\n"
      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n"
      "vdup.8 q0, r3\n"
      "vdup.8 q1, r3\n"

// Read/accumulate/write for src_ptr0 and src_ptr1.
#define RUY_LOAD_ONE_ROW1(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \
  "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n"

      RUY_LOAD_ONE_ROW1(0, 0)
      RUY_LOAD_ONE_ROW1(1, 1)
      RUY_LOAD_ONE_ROW1(2, 2)
      RUY_LOAD_ONE_ROW1(3, 3)
      RUY_LOAD_ONE_ROW1(4, 4)
      RUY_LOAD_ONE_ROW1(5, 5)
      RUY_LOAD_ONE_ROW1(6, 6)
      RUY_LOAD_ONE_ROW1(7, 7)
#undef RUY_LOAD_ONE_ROW1

#define RUY_LOAD_ONE_ROW2(I, R)            \
  "cmp r2, #" #I "\n"                      \
  "beq 5f\n"                               \
  "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \
  "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n"

      RUY_LOAD_ONE_ROW2(8, 0)
      RUY_LOAD_ONE_ROW2(9, 1)
      RUY_LOAD_ONE_ROW2(10, 2)
      RUY_LOAD_ONE_ROW2(11, 3)
      RUY_LOAD_ONE_ROW2(12, 4)
      RUY_LOAD_ONE_ROW2(13, 5)
      RUY_LOAD_ONE_ROW2(14, 6)
      RUY_LOAD_ONE_ROW2(15, 7)
#undef RUY_LOAD_ONE_ROW2

      "5:\n"

      "veor.16 q4, q0, q11\n"
      "veor.16 q5, q1, q11\n"

      "vpaddl.s8 q8, q4\n"
      "vpaddl.s8 q9, q5\n"

      // Pairwise add-accumulate into the 4x32-bit accumulators.
      "vpadal.s16 q12, q8\n"
      "vpadal.s16 q13, q9\n"

      "vst1.32 {q4}, [%[packed_ptr]]!\n"
      "vst1.32 {q5}, [%[packed_ptr]]!\n"

      "4:\n"

      // Pairwise add the 32-bit accumulators.
      "vpadd.i32 d24, d24, d25\n"
      "vpadd.i32 d26, d26, d27\n"
      // Final reduction: one 32-bit sum per source column.
      "vpadd.i32 d25, d24, d26\n"

      "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n"
      "cmp r3, #0\n"
      "beq 6f\n"
      "vst1.32 {d25}, [r3]!\n"
      "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1)
      : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1),
        [packed_ptr] "r"(packed_ptr), [params] "r"(&params)
      : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
        "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13");
}

#undef RUY_OFFSET_SRC_PTR0
#undef RUY_OFFSET_SRC_PTR1
#undef RUY_OFFSET_SRC_PTR2
#undef RUY_OFFSET_SRC_PTR3
#undef RUY_OFFSET_SUMS_PTR
#undef RUY_OFFSET_PACKED_PTR
#undef RUY_OFFSET_SRC_INC0
#undef RUY_OFFSET_SRC_INC1
#undef RUY_OFFSET_SRC_INC2
#undef RUY_OFFSET_SRC_INC3
#undef RUY_OFFSET_SRC_ROWS
#undef RUY_OFFSET_SRC_ZERO_POINT
#undef RUY_OFFSET_INPUT_XOR

#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)

#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
                                   const void* src_ptr2, const void* src_ptr3,
                                   int src_inc0, int src_inc1, int src_inc2,
                                   int src_inc3, int src_rows,
                                   int src_zero_point, std::int8_t* packed_ptr,
                                   std::int32_t* sums_ptr, int input_xor) {
  profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");
  asm volatile(
      // clang-format off
      // v26 will be the vector to XOR input values with to perform
      // any input data type conversion (e.g. uint8 to int8).
      "dup v26.16b, %w[input_xor]\n"
      // w1 will be the number of rows already loaded.
      "mov w1, #0\n"
      // v28--v31 will be used to accumulate the sums.
      "movi v28.4s, #0\n"
      "movi v29.4s, #0\n"
      "movi v30.4s, #0\n"
      "movi v31.4s, #0\n"
      // Let w2 be `rows` rounded down to a multiple of 16.
      "ands w2, %w[rows], #-16\n"
      // If there are no full blocks of 16 rows to process, jump to the
      // code handling the last < 16 rows.
      "beq 3f\n"
      // Load the first block of 16 rows.
      "add w1, w1, #16\n"
      "ldr x10, [%[src_ptr0], #8]\n"
      "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
      "ldr x11, [%[src_ptr1], #8]\n"
      "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
      "ldr x12, [%[src_ptr2], #8]\n"
      "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
      "ldr x13, [%[src_ptr3], #8]\n"
      "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
      // Check if this was the only full block of 16 rows to load.
      "cmp w1, w2\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
      // In that case, jump to the code handling the last loaded block of
      // 16 rows.
      "beq 2f\n"
      // Main loop processing blocks of 16 rows.
      "1:\n"
      // Load the next 16 rows, interleaved with the XOR input type
      // conversion (e.g. uint8->int8) on the already loaded inputs.
      "add w1, w1, #16\n"
      "ins v0.d[1], x10\n"
      "ldr x10, [%[src_ptr0], #8]\n"
      "ins v1.d[1], x11\n"
      "ldr x11, [%[src_ptr1], #8]\n"
      "ins v2.d[1], x12\n"
      "ldr x12, [%[src_ptr2], #8]\n"
      "ins v3.d[1], x13\n"
      "ldr x13, [%[src_ptr3], #8]\n"
      "eor v4.16b, v0.16b, v26.16b\n"
      "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
      "eor v7.16b, v3.16b, v26.16b\n"
      "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
      // Compute the sums, interleaved with storing to the packed matrix.
      "saddlp v16.8h, v4.16b\n"
      "str q4, [%[packed_ptr], #0]\n"
      "saddlp v17.8h, v5.16b\n"
      "str q5, [%[packed_ptr], #16]\n"
      "saddlp v18.8h, v6.16b\n"
      "str q6, [%[packed_ptr], #32]\n"
      "saddlp v19.8h, v7.16b\n"
      "str q7, [%[packed_ptr], #48]\n"
      "sadalp v28.4s, v16.8h\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
      // Was this the last block of 16 rows to load?
      "cmp w1, w2\n"
      "sadalp v29.4s, v17.8h\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
      "add %[packed_ptr], %[packed_ptr], #64\n"
      "sadalp v30.4s, v18.8h\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
      "sadalp v31.4s, v19.8h\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")
      // End of main loop on blocks of 16 rows.
      "bne 1b\n"

      // Code handling the last already-loaded block of 16 rows.
      "2:\n"
      // Process the last loaded full 16x4 block.
      "ins v0.d[1], x10\n"
      "ins v1.d[1], x11\n"
      "ins v2.d[1], x12\n"
      "ins v3.d[1], x13\n"
      "eor v4.16b, v0.16b, v26.16b\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "eor v7.16b, v3.16b, v26.16b\n"

      "saddlp v16.8h, v4.16b\n"
      "str q4, [%[packed_ptr], #0]\n"
      "saddlp v17.8h, v5.16b\n"
      "str q5, [%[packed_ptr], #16]\n"
      "saddlp v18.8h, v6.16b\n"
      "str q6, [%[packed_ptr], #32]\n"
      "saddlp v19.8h, v7.16b\n"
      "str q7, [%[packed_ptr], #48]\n"
      "sadalp v28.4s, v16.8h\n"
      "sadalp v29.4s, v17.8h\n"
      "sadalp v30.4s, v18.8h\n"
      "sadalp v31.4s, v19.8h\n"

      "add %[packed_ptr], %[packed_ptr], #64\n"

      // End of code handling full blocks of 16 rows.
      // Now we handle any remaining rows.
      "3:\n"
      // Let w2 be the number of rows left to handle.
      "ands w2, %w[rows], #15\n"
      // If w2==0, there are no remaining rows, jump to the end.
      "beq 4f\n"
      // Zero out a 16x4 block in registers, which we'll partially overwrite
      // with any remaining rows.
      "dup v0.16b, %w[src_zero_point]\n"
      "dup v1.16b, %w[src_zero_point]\n"
      "dup v2.16b, %w[src_zero_point]\n"
      "dup v3.16b, %w[src_zero_point]\n"
#define RUY_LOAD_ONE_ROW(R)                   \
  "cmp w2, #" #R "\n"                         \
  "beq 5f\n"                                  \
  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"

      RUY_LOAD_ONE_ROW(0)
      RUY_LOAD_ONE_ROW(1)
      RUY_LOAD_ONE_ROW(2)
      RUY_LOAD_ONE_ROW(3)
      RUY_LOAD_ONE_ROW(4)
      RUY_LOAD_ONE_ROW(5)
      RUY_LOAD_ONE_ROW(6)
      RUY_LOAD_ONE_ROW(7)
      RUY_LOAD_ONE_ROW(8)
      RUY_LOAD_ONE_ROW(9)
      RUY_LOAD_ONE_ROW(10)
      RUY_LOAD_ONE_ROW(11)
      RUY_LOAD_ONE_ROW(12)
      RUY_LOAD_ONE_ROW(13)
      RUY_LOAD_ONE_ROW(14)
      // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
#undef RUY_LOAD_ONE_ROW
      "5:\n"

      // Process the last zero-padded 16x4 block.
      "eor v4.16b, v0.16b, v26.16b\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "eor v7.16b, v3.16b, v26.16b\n"

      "saddlp v16.8h, v4.16b\n"
      "saddlp v17.8h, v5.16b\n"
      "saddlp v18.8h, v6.16b\n"
      "saddlp v19.8h, v7.16b\n"
      "sadalp v28.4s, v16.8h\n"
      "sadalp v29.4s, v17.8h\n"
      "sadalp v30.4s, v18.8h\n"
      "sadalp v31.4s, v19.8h\n"

      "str q4, [%[packed_ptr], #0]\n"
      "str q5, [%[packed_ptr], #16]\n"
      "str q6, [%[packed_ptr], #32]\n"
      "str q7, [%[packed_ptr], #48]\n"
      "add %[packed_ptr], %[packed_ptr], #64\n"

      "4:\n"

      // Horizontal reduction of the registers used to accumulate sums.
      "addp v28.4s, v28.4s, v29.4s\n"
      "addp v30.4s, v30.4s, v31.4s\n"
      "addp v28.4s, v28.4s, v30.4s\n"

      // Store the sums.
      "cmp %[sums_ptr], #0\n"
      "beq 6f\n"
      "st1 {v28.4s}, [%[sums_ptr]], #16\n"
      "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows), [src_zero_point] "r"(src_zero_point),
        [input_xor] "r"(input_xor)
      : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
        "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
        "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
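
// A note on the split loads in the A55ish kernels above and below: each
// 16-byte row chunk is loaded as an 8-byte NEON load (ld1 {vN.8b}) plus a
// 64-bit GPR load (ldr x1N) that is merged later with "ins vN.d[1], x1N".
// On in-order cores such as the Cortex-A55, this presumably lets the two
// halves dual-issue alongside other work instead of a single 16-byte vector
// load monopolizing the load pipe (an assumed scheduling rationale).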

void Pack8bitColMajorForNeonDotprodA55ish(
    const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
    const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
    int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
    std::int32_t* sums_ptr, int input_xor) {
  profiler::ScopeLabel label(
      "Pack (kNeonDotprod, optimized for in-order cores)");
  asm volatile(
      // clang-format off
      // v26 will be the vector to XOR input values with to perform
      // any input data type conversion (e.g. uint8 to int8).
      "dup v26.16b, %w[input_xor]\n"
      // v27 will be filled with 1's. It will be used as an operand
      // to SDOT to compute the sums.
      "mov w1, #1\n"
      "dup v27.16b, w1\n"
      // w1 will be the number of rows already loaded.
      "mov w1, #0\n"
      // v28--v31 will be used to accumulate the sums.
      "movi v28.4s, #0\n"
      "movi v29.4s, #0\n"
      "movi v30.4s, #0\n"
      "movi v31.4s, #0\n"

      // Let w2 be `rows` rounded down to a multiple of 16.
      "ands w2, %w[rows], #-16\n"
      // If there are no full blocks of 16 rows to process, jump to the
      // code handling the last < 16 rows.
      "beq 3f\n"
      // Load the first block of 16 rows.
      "add w1, w1, #16\n"
      "ldr x10, [%[src_ptr0], #8]\n"
      "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
      "ldr x11, [%[src_ptr1], #8]\n"
      "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
      "ldr x12, [%[src_ptr2], #8]\n"
      "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
      "ldr x13, [%[src_ptr3], #8]\n"
      "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"
      // Check if this was the only full block of 16 rows to load.
      "cmp w1, w2\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
      // In that case, jump to the code handling the last loaded block of
      // 16 rows.
      "beq 2f\n"

      // Main loop processing blocks of 16 rows.
      "1:\n"
      "add w1, w1, #16\n"
      // Prepare the already-loaded 16 rows by inserting the parts
      // loaded into general purpose registers x10--x13 into the
      // NEON registers v0--v3 where the other parts had already been
      // loaded.
      "ins v0.d[1], x10\n"
      "ldr x10, [%[src_ptr0], #8]\n"
      "ins v1.d[1], x11\n"
      "ldr x11, [%[src_ptr1], #8]\n"
      "ins v2.d[1], x12\n"
      "ldr x12, [%[src_ptr2], #8]\n"
      "ins v3.d[1], x13\n"
      "ldr x13, [%[src_ptr3], #8]\n"

      // Load the next 16 rows and, interleaved with that,
      // perform the input type conversion (e.g. uint8->int8) on the
      // current 16 rows.
      "eor v4.16b, v0.16b, v26.16b\n"
      "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n"
      "eor v5.16b, v1.16b, v26.16b\n"
      "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n"
      "eor v6.16b, v2.16b, v26.16b\n"
      "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n"
      "eor v7.16b, v3.16b, v26.16b\n"
      "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n"

      // Transposition of 4x4 blocks, part 1
      "trn1 v16.4s, v4.4s, v5.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
      "trn2 v17.4s, v4.4s, v5.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
      "trn1 v18.4s, v6.4s, v7.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
      "trn2 v19.4s, v6.4s, v7.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")

      // Transposition of 4x4 blocks, part 2
      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"
      "cmp w1, w2\n"

      // Store the block to the packed matrix and, interleaved with
      // that, compute sums using sdot instructions.
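      // The sdot instructions below are emitted as raw .word encodings,
      // presumably so that this file assembles even when the toolchain does
      // not enable the dotprod ISA extension (an assumption; the encodings
      // match the sdot forms shown in the trailing comments).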
944 ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
945 "str q20, [%[packed_ptr], #0]\n"
946 ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
947 "str q21, [%[packed_ptr], #32]\n"
948 ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
949 "str q22, [%[packed_ptr], #64]\n"
950 ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
951 "str q23, [%[packed_ptr], #96]\n"
952 "add %[packed_ptr], %[packed_ptr], #128\n"
953 // End of main loop on blocks of 16 rows.
954 "bne 1b\n"
955
956 // Code handling the last already-loaded block of 16 rows.
957 "2:\n"
958 // Process the last loaded full 16x4 block.
959 "ins v0.d[1], x10\n"
960 "ins v1.d[1], x11\n"
961 "ins v2.d[1], x12\n"
962 "ins v3.d[1], x13\n"
963 "eor v0.16b, v0.16b, v26.16b\n"
964 "eor v1.16b, v1.16b, v26.16b\n"
965 "eor v2.16b, v2.16b, v26.16b\n"
966 "eor v3.16b, v3.16b, v26.16b\n"
967
968 "trn1 v16.4s, v0.4s, v1.4s\n"
969 "trn2 v17.4s, v0.4s, v1.4s\n"
970 "trn1 v18.4s, v2.4s, v3.4s\n"
971 "trn2 v19.4s, v2.4s, v3.4s\n"
972
973 "trn1 v20.2d, v16.2d, v18.2d\n"
974 "trn2 v22.2d, v16.2d, v18.2d\n"
975 "trn1 v21.2d, v17.2d, v19.2d\n"
976 "trn2 v23.2d, v17.2d, v19.2d\n"
977
978 ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
979 "str q20, [%[packed_ptr], #0]\n"
980 ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
981 "str q21, [%[packed_ptr], #32]\n"
982 ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
983 "str q22, [%[packed_ptr], #64]\n"
984 ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
985 "str q23, [%[packed_ptr], #96]\n"
986 "add %[packed_ptr], %[packed_ptr], #128\n"
987
988 // End of code handling full blocks of 16 rows.
989 // Now we handle any remaining rows.
990 "3:\n"
991 // Let w2 be the number of rows left to handle.
992 "ands w2, %w[rows], #15\n"
993 // If w2==0, there are no remaining rows, jump to the end.
994 "beq 4f\n"
995 // Zero out a 16x4 block in registers, which we'll partially overwrite
996 // with any remaining rows.
997 "dup v0.16b, %w[src_zero_point]\n"
998 "dup v1.16b, %w[src_zero_point]\n"
999 "dup v2.16b, %w[src_zero_point]\n"
1000 "dup v3.16b, %w[src_zero_point]\n"
1001#define RUY_LOAD_ONE_ROW(R) \
1002 "cmp w2, #" #R "\n" \
1003 "beq 5f\n" \
1004 "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
1005 "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
1006 "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
1007 "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"
1008
1009 RUY_LOAD_ONE_ROW(0)
1010 RUY_LOAD_ONE_ROW(1)
1011 RUY_LOAD_ONE_ROW(2)
1012 RUY_LOAD_ONE_ROW(3)
1013 RUY_LOAD_ONE_ROW(4)
1014 RUY_LOAD_ONE_ROW(5)
1015 RUY_LOAD_ONE_ROW(6)
1016 RUY_LOAD_ONE_ROW(7)
1017 RUY_LOAD_ONE_ROW(8)
1018 RUY_LOAD_ONE_ROW(9)
1019 RUY_LOAD_ONE_ROW(10)
1020 RUY_LOAD_ONE_ROW(11)
1021 RUY_LOAD_ONE_ROW(12)
1022 RUY_LOAD_ONE_ROW(13)
1023 RUY_LOAD_ONE_ROW(14)
1024 // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
1025#undef RUY_LOAD_ONE_ROW
1026
1027 "5:\n"
1028 // Process the last zero-padded 16x4 block.
1029 "eor v0.16b, v0.16b, v26.16b\n"
1030 "eor v1.16b, v1.16b, v26.16b\n"
1031 "eor v2.16b, v2.16b, v26.16b\n"
1032 "eor v3.16b, v3.16b, v26.16b\n"
1033
1034 "trn1 v16.4s, v0.4s, v1.4s\n"
1035 "trn2 v17.4s, v0.4s, v1.4s\n"
1036 "trn1 v18.4s, v2.4s, v3.4s\n"
1037 "trn2 v19.4s, v2.4s, v3.4s\n"
1038
1039 "trn1 v20.2d, v16.2d, v18.2d\n"
1040 "trn2 v22.2d, v16.2d, v18.2d\n"
1041 "trn1 v21.2d, v17.2d, v19.2d\n"
1042 "trn2 v23.2d, v17.2d, v19.2d\n"
1043
1044 ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
1045 "str q20, [%[packed_ptr], #0]\n"
1046 "cmp w2, #4\n"
1047 "ble 4f\n"
1048 ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
1049 "str q21, [%[packed_ptr], #32]\n"
1050 "cmp w2, #8\n"
1051 "ble 4f\n"
1052 ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
1053 "str q22, [%[packed_ptr], #64]\n"
1054 "cmp w2, #12\n"
1055 "ble 4f\n"
1056 ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
1057 "str q23, [%[packed_ptr], #96]\n"
1058 "add %[packed_ptr], %[packed_ptr], #128\n"
1059
1060 "4:\n"
1061
1062 // Reduction of the registers used to accumulate sums.
1063 "add v28.4s, v28.4s, v29.4s\n"
1064 "add v30.4s, v30.4s, v31.4s\n"
1065 "add v28.4s, v28.4s, v30.4s\n"
1066
1067 // Store the sums.
1068 "cmp %[sums_ptr], #0\n"
1069 "beq 6f\n"
1070 "st1 {v28.4s}, [%[sums_ptr]], #16\n"
1071 "6:\n"
1072 // clang-format on
1073
1074 : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2),
1075 [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
1076 : [ src_inc0 ] "r"(static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r"(static_cast<std::int64_t>(src_inc1)),
1077 [ src_inc2 ] "r"(static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r"(static_cast<std::int64_t>(src_inc3)),
1078 [rows] "r"(src_rows),
1079 [src_zero_point] "r"(static_cast<int>(src_zero_point)),
1080 [input_xor] "r"(input_xor)
1081 : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
1082 "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
1083 "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
1084}
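
// A scalar sketch, for illustration only (not part of ruy's API), of the
// packed layout produced by the trn1/trn2 sequences in the dotprod kernels:
// for a full 16x4 source block, each group of 4 consecutive rows becomes one
// 4-byte lane per column, so a single 16-byte packed vector feeds one sdot
// against all four columns. The 32-byte output stride leaves the upper 16
// bytes of each slot for four further columns, which (an assumption here) a
// separate invocation fills in when packing an 8-column block. The input_xor
// conversion is omitted for brevity.
void PackDotprodBlockLayoutSketch(const std::int8_t* col[4],
                                  std::int8_t out[128]) {
  for (int group = 0; group < 4; ++group) {
    // q20/q21/q22/q23 above hold groups 0..3 (rows 4*group .. 4*group+3).
    std::int8_t* dst = out + 32 * group;
    for (int c = 0; c < 4; ++c) {
      for (int r = 0; r < 4; ++r) {
        dst[4 * c + r] = col[c][4 * group + r];
      }
    }
  }
}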

void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
                                    const void* src_ptr2, const void* src_ptr3,
                                    int src_inc0, int src_inc1, int src_inc2,
                                    int src_inc3, int src_rows,
                                    int src_zero_point, std::int8_t* packed_ptr,
                                    std::int32_t* sums_ptr, int input_xor) {
  profiler::ScopeLabel label("Pack (kNeonDotprod)");
  asm volatile(
      // clang-format off
      // v26 will be the vector to XOR input values with to perform
      // any input data type conversion (e.g. uint8 to int8).
      "dup v26.16b, %w[input_xor]\n"
      // v27 will be filled with 1's. It will be used as an operand
      // to SDOT to compute the sums.
      "mov w1, #1\n"
      "dup v27.16b, w1\n"
      // w1 will be the number of rows already loaded.
      "mov w1, #0\n"
      // v28--v31 will be used to accumulate the sums.
      "movi v28.4s, #0\n"
      "movi v29.4s, #0\n"
      "movi v30.4s, #0\n"
      "movi v31.4s, #0\n"

      // 4x partially unrolled code processing blocks of 64 rows.
      // Read the original loop below first, it has more comments.
#if RUY_OPT(MAX_STREAMING)
      // Let w2 be `rows` rounded down to a multiple of 64.
      // Each iteration of this 4x partially unrolled loop handles
      // 64 rows.
      "ands w2, %w[rows], #-64\n"
      // If there are no full blocks of 64 rows to process, jump to
      // the main loop below handling 16 rows per iteration.
      "beq 9f\n"
      // Load the first block of 64 rows.
      "add w1, w1, #64\n"
      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
      // Was that the last full block of 64 rows to load?
      "cmp w1, w2\n"
      "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
      // Then jump to the end of the 64-rows-at-a-time code.
      "beq 8f\n"

      // Start of the main 4x partially unrolled loop.
      "7:\n"
      // Process rows 0 -- 15 out of 64.
      "eor v0.16b, v0.16b, v26.16b\n"
      "eor v1.16b, v1.16b, v26.16b\n"
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "add w1, w1, #16\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 16 -- 31 out of 64.
      "eor v4.16b, v4.16b, v26.16b\n"
      "eor v5.16b, v5.16b, v26.16b\n"
      "eor v6.16b, v6.16b, v26.16b\n"
      "eor v7.16b, v7.16b, v26.16b\n"

      "trn1 v16.4s, v4.4s, v5.4s\n"
      "trn2 v17.4s, v4.4s, v5.4s\n"
      "trn1 v18.4s, v6.4s, v7.4s\n"
      "trn2 v19.4s, v6.4s, v7.4s\n"

      "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "add w1, w1, #16\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 32 -- 47 out of 64.
      "eor v8.16b, v8.16b, v26.16b\n"
      "eor v9.16b, v9.16b, v26.16b\n"
      "eor v10.16b, v10.16b, v26.16b\n"
      "eor v11.16b, v11.16b, v26.16b\n"

      "trn1 v16.4s, v8.4s, v9.4s\n"
      "trn2 v17.4s, v8.4s, v9.4s\n"
      "trn1 v18.4s, v10.4s, v11.4s\n"
      "trn2 v19.4s, v10.4s, v11.4s\n"

      "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "add w1, w1, #16\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 48 -- 63 out of 64.
      "eor v12.16b, v12.16b, v26.16b\n"
      "eor v13.16b, v13.16b, v26.16b\n"
      "eor v14.16b, v14.16b, v26.16b\n"
      "eor v15.16b, v15.16b, v26.16b\n"

      "trn1 v16.4s, v12.4s, v13.4s\n"
      "trn2 v17.4s, v12.4s, v13.4s\n"
      "trn1 v18.4s, v14.4s, v15.4s\n"
      "trn2 v19.4s, v14.4s, v15.4s\n"

      "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "add w1, w1, #16\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "cmp w1, w2\n"
      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // End of main 4x partially unrolled loop.
      "bne 7b\n"

      // Last part of the 4x partially unrolled code:
      // handle the last already-loaded 64 rows.
      "8:\n"

      // Process rows 0 -- 15 out of 64.
      "eor v0.16b, v0.16b, v26.16b\n"
      "eor v1.16b, v1.16b, v26.16b\n"
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 16 -- 31 out of 64.
      "eor v4.16b, v4.16b, v26.16b\n"
      "eor v5.16b, v5.16b, v26.16b\n"
      "eor v6.16b, v6.16b, v26.16b\n"
      "eor v7.16b, v7.16b, v26.16b\n"

      "trn1 v16.4s, v4.4s, v5.4s\n"
      "trn2 v17.4s, v4.4s, v5.4s\n"
      "trn1 v18.4s, v6.4s, v7.4s\n"
      "trn2 v19.4s, v6.4s, v7.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 32 -- 47 out of 64.
      "eor v8.16b, v8.16b, v26.16b\n"
      "eor v9.16b, v9.16b, v26.16b\n"
      "eor v10.16b, v10.16b, v26.16b\n"
      "eor v11.16b, v11.16b, v26.16b\n"

      "trn1 v16.4s, v8.4s, v9.4s\n"
      "trn2 v17.4s, v8.4s, v9.4s\n"
      "trn1 v18.4s, v10.4s, v11.4s\n"
      "trn2 v19.4s, v10.4s, v11.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // Process rows 48 -- 63 out of 64.
      "eor v12.16b, v12.16b, v26.16b\n"
      "eor v13.16b, v13.16b, v26.16b\n"
      "eor v14.16b, v14.16b, v26.16b\n"
      "eor v15.16b, v15.16b, v26.16b\n"

      "trn1 v16.4s, v12.4s, v13.4s\n"
      "trn2 v17.4s, v12.4s, v13.4s\n"
      "trn1 v18.4s, v14.4s, v15.4s\n"
      "trn2 v19.4s, v14.4s, v15.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      "9:\n"
#endif  // #if RUY_OPT(MAX_STREAMING)
      // End of 4x partially unrolled code processing blocks of 64 rows.

      // Main part of the code, processing blocks of 16 rows.

      // Let w2 be `rows` rounded down to a multiple of 16.
      "and w2, %w[rows], #-16\n"
      // If there are no full blocks of 16 rows to process, jump to the
      // code handling the last < 16 rows.
      "cmp w1, w2\n"
      "beq 3f\n"

      // Load the first block of 16 rows.
      "add w1, w1, #16\n"
      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      // Check if this was the only full block of 16 rows to load.
      "cmp w1, w2\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      // In that case, jump to the code handling the last loaded block of
      // 16 rows.
      "beq 2f\n"
      // Main loop processing blocks of 16 rows.
      "1:\n"
      // Input type conversion (e.g. uint8->int8).
      "eor v0.16b, v0.16b, v26.16b\n"
      "eor v1.16b, v1.16b, v26.16b\n"
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"
      // Transposition of 4x4 blocks, part 1
      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"
      // Load the next 16 rows
      "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n"
      "add w1, w1, #16\n"
      // Transposition of 4x4 blocks, part 2
      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"
      // Compute sums using sdot instructions.
      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
      // Store the block to the packed matrix.
      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "cmp w1, w2\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"
      // End of main loop on blocks of 16 rows.
      "bne 1b\n"

      // Code handling the last already-loaded block of 16 rows.
      "2:\n"

      // Process the last loaded full 16x4 block.
      "eor v0.16b, v0.16b, v26.16b\n"
      "eor v1.16b, v1.16b, v26.16b\n"
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // End of code handling full blocks of 16 rows.
      // Now we handle any remaining rows.
      "3:\n"
      // Let w2 be the number of rows left to handle.
      "ands w2, %w[rows], #15\n"
      // If w2==0, there are no remaining rows, jump to the end.
      "beq 4f\n"
      // Zero out a 16x4 block in registers, which we'll partially overwrite
      // with any remaining rows.
      "dup v0.16b, %w[src_zero_point]\n"
      "dup v1.16b, %w[src_zero_point]\n"
      "dup v2.16b, %w[src_zero_point]\n"
      "dup v3.16b, %w[src_zero_point]\n"
#define RUY_LOAD_ONE_ROW(R)                   \
  "cmp w2, #" #R "\n"                         \
  "beq 5f\n"                                  \
  "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \
  "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \
  "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \
  "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n"

      RUY_LOAD_ONE_ROW(0)
      RUY_LOAD_ONE_ROW(1)
      RUY_LOAD_ONE_ROW(2)
      RUY_LOAD_ONE_ROW(3)
      RUY_LOAD_ONE_ROW(4)
      RUY_LOAD_ONE_ROW(5)
      RUY_LOAD_ONE_ROW(6)
      RUY_LOAD_ONE_ROW(7)
      RUY_LOAD_ONE_ROW(8)
      RUY_LOAD_ONE_ROW(9)
      RUY_LOAD_ONE_ROW(10)
      RUY_LOAD_ONE_ROW(11)
      RUY_LOAD_ONE_ROW(12)
      RUY_LOAD_ONE_ROW(13)
      RUY_LOAD_ONE_ROW(14)
      // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op.
#undef RUY_LOAD_ONE_ROW

      "5:\n"
      // Process the last zero-padded 16x4 block.
      "eor v0.16b, v0.16b, v26.16b\n"
      "eor v1.16b, v1.16b, v26.16b\n"
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n"
      "str q20, [%[packed_ptr], #0]\n"
      "cmp w2, #4\n"
      "ble 4f\n"
      ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n"
      "str q21, [%[packed_ptr], #32]\n"
      "cmp w2, #8\n"
      "ble 4f\n"
      ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n"
      "str q22, [%[packed_ptr], #64]\n"
      "cmp w2, #12\n"
      "ble 4f\n"
      ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      "4:\n"

      // Reduction of the registers used to accumulate sums.
      "add v28.4s, v28.4s, v29.4s\n"
      "add v30.4s, v30.4s, v31.4s\n"
      "add v28.4s, v28.4s, v30.4s\n"

      // Store the sums.
      "cmp %[sums_ptr], #0\n"
      "beq 6f\n"
      "st1 {v28.4s}, [%[sums_ptr]], #16\n"
      "6:\n"
      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows),
        [src_zero_point] "r"(static_cast<int>(src_zero_point)),
        [input_xor] "r"(input_xor)
      : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
        "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
}
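
// A scalar sketch, for illustration only (not part of ruy's API), of how the
// dotprod kernels above form their sums: "sdot vD.4s, vN.16b, v27.16b" with
// v27 filled with 1s adds, into 32-bit lane i of vD, the sum of the four
// int8 values in bytes 4*i..4*i+3 of vN. Since lane i of a packed vector
// holds four consecutive rows of column i (see the layout sketch above),
// this accumulates per-column sums as a side effect of packing.
inline void SdotWithOnesSketch(std::int32_t acc[4],
                               const std::int8_t packed[16]) {
  for (int lane = 0; lane < 4; ++lane) {
    for (int k = 0; k < 4; ++k) {
      acc[lane] += packed[4 * lane + k];  // each sdot product is value * 1
    }
  }
}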
1586
1587void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
1588 const void* src_ptr2, const void* src_ptr3,
1589 int src_inc0, int src_inc1, int src_inc2,
1590 int src_inc3, int src_cols,
1591 int src_zero_point, std::int8_t* packed_ptr,
1592 int packed_stride, std::int32_t* sums_ptr,
1593 int input_xor) {
1594 profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)");
1595 asm volatile(
1596 // clang-format off
1597 // Prefetch data. This was tuned on Cortex-A55-rev1 cores.
1598 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n")
1599 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1]]\n")
1600 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2]]\n")
1601 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3]]\n")
1602 // Let w0 = (number of columns to compute) - 8.
1603 "subs w0, %w[src_cols], 8\n"
1604 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 64]\n")
1605 // Let v26 duplicate the input_xor value in all lanes.
1606 "dup v26.16b, %w[input_xor]\n"
1607 RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 64]\n")
1608 // Let v27 be 1 in all lanes. Used with sdot to compute sums.
1609 "movi v27.16b, 1\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 64]\n")
      // If there isn't a full block of 8 columns to load from, jump to the
      // code after the loop handling leftovers.
      "blt 2f\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 64]\n")
      // Main loop, each iteration handles a full block of 8 cols.
      "1:\n"
      // Load the 4x8 block from the source matrix, or zero if we're
      // past the bottom of the source matrix.
      "ld1 {v0.8b}, [%[src_ptr0]]\n"
      "ld1 {v1.8b}, [%[src_ptr1]]\n"
      "ld1 {v2.8b}, [%[src_ptr2]]\n"
      "ld1 {v3.8b}, [%[src_ptr3]]\n"
      // Load values from the sums buffer, and start the reordering
      // of the loaded 4x8 block by interleaving 8bit values.
      "zip1 v0.16b, v0.16b, v1.16b\n"
      "ldr q8, [%[sums_ptr], 0]\n"
      "zip1 v1.16b, v2.16b, v3.16b\n"
      "ldr q9, [%[sums_ptr], 16]\n"
      // Finish the reordering of the 4x8 block, putting it into
      // column-major order.
      "zip1 v2.8h, v0.8h, v1.8h\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 128]\n")
      "zip2 v3.8h, v0.8h, v1.8h\n"
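      // v2 now holds columns 0-3 of the 4x8 block and v3 columns 4-7, each
      // column as 4 consecutive bytes.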
      // Apply input_xor, i.e. convert source values from uint8 to int8
      // if needed.
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 128]\n")
      "eor v2.16b, v2.16b, v26.16b\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 128]\n")
      "eor v3.16b, v3.16b, v26.16b\n"
      // Update the sums.
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 128]\n")
      ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n"
      ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n"
      // Store the column-major 4x8 block to the packed matrix, and
      // increment some source pointers.
      "str q2, [%[packed_ptr], 0]\n"
      "add %[src_ptr0], %[src_ptr0], %w[src_inc0], sxtw\n"
      "str q3, [%[packed_ptr], 16]\n"
      "add %[src_ptr1], %[src_ptr1], %w[src_inc1], sxtw\n"
      // Store the updated sums, and increment the remaining pointers
      // and the block_col loop index.
      "st1 {v8.4s}, [%[sums_ptr]], 16\n"
      "add %[packed_ptr], %[packed_ptr], %[packed_stride], lsl 3\n"
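      // (lsl 3: packed_ptr advances by 8 * packed_stride bytes, one
      // packed_stride for each of the 8 columns just handled.)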
      "st1 {v9.4s}, [%[sums_ptr]], 16\n"
      // Advance by 8 columns and set the condition code.
      "subs w0, w0, 8\n"
      "add %[src_ptr2], %[src_ptr2], %w[src_inc2], sxtw\n"
      "add %[src_ptr3], %[src_ptr3], %w[src_inc3], sxtw\n"
      // End of the main loop.
      "bge 1b\n"

      "2:\n"
      // We add back 8 to w0 so that w0 is the number of columns remaining
      // to handle.
      "adds w0, w0, 8\n"
      // Nothing left? Then jump to the end.
      "beq 3f\n"
      // Here w0 is between 1 and 7. We initialize v0--v3 to the zero point ...
      "dup v0.8b, %w[src_zero_point]\n"
      "dup v1.8b, %w[src_zero_point]\n"
      "dup v2.8b, %w[src_zero_point]\n"
      "dup v3.8b, %w[src_zero_point]\n"
      // ... and now we fill lanes one by one with leftover columns.
#define RUY_LOAD_ONE_COL(C)                    \
  "cmp w0, " #C "\n"                           \
  "beq 4f\n"                                   \
  "ld1 { v0.b }[" #C "], [%[src_ptr0]], #1\n"  \
  "ld1 { v1.b }[" #C "], [%[src_ptr1]], #1\n"  \
  "ld1 { v2.b }[" #C "], [%[src_ptr2]], #1\n"  \
  "ld1 { v3.b }[" #C "], [%[src_ptr3]], #1\n"

      RUY_LOAD_ONE_COL(0)
      RUY_LOAD_ONE_COL(1)
      RUY_LOAD_ONE_COL(2)
      RUY_LOAD_ONE_COL(3)
      RUY_LOAD_ONE_COL(4)
      RUY_LOAD_ONE_COL(5)
      RUY_LOAD_ONE_COL(6)
      // Here we know that w0==7, so RUY_LOAD_ONE_COL(7) would be a no-op.
#undef RUY_LOAD_ONE_COL

      "4:\n"
      // The leftover source data is now loaded; we can perform the
      // computation as usual.
      // Load values from the sums buffer, and start the reordering
      // of the loaded 4x8 block by interleaving 8bit values.
      "zip1 v0.16b, v0.16b, v1.16b\n"
      "ldr q8, [%[sums_ptr], 0]\n"
      "zip1 v1.16b, v2.16b, v3.16b\n"
      "ldr q9, [%[sums_ptr], 16]\n"
      // Finish the reordering of the 4x8 block, putting it into
      // column-major order.
      "zip1 v2.8h, v0.8h, v1.8h\n"
      "zip2 v3.8h, v0.8h, v1.8h\n"
      // Apply input_xor, i.e. convert source values from uint8 to int8
      // if needed.
      "eor v2.16b, v2.16b, v26.16b\n"
      "eor v3.16b, v3.16b, v26.16b\n"
      // Update the sums.
      ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n"
      ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n"
      // Store the column-major 4x8 block to the packed matrix.
      "str q2, [%[packed_ptr], 0]\n"
      "str q3, [%[packed_ptr], 16]\n"
      // Store the updated sums.
      "st1 {v8.4s}, [%[sums_ptr]], 16\n"
      "st1 {v9.4s}, [%[sums_ptr]], 16\n"

      // End label.
      "3:\n"
      // clang-format on
      : [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr),
        [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3)
      : [src_inc0] "r"(src_inc0), [src_inc1] "r"(src_inc1),
        [src_inc2] "r"(src_inc2), [src_inc3] "r"(src_inc3),
        [input_xor] "r"(input_xor), [src_zero_point] "r"(src_zero_point),
        [packed_stride] "r"(static_cast<std::int64_t>(packed_stride)),
        [src_cols] "r"(src_cols)
      : "cc", "memory", "x0", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5",
        "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
        "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
        "v27", "v28", "v29", "v30", "v31");
}

void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc0, int src_inc1, int src_inc2,
                              int src_inc3, int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack (kNeon)");
  asm volatile(
      // clang-format off
      // w1 will be the number of rows already loaded.
      "mov w1, #0\n"
      // Let w2 be `rows` rounded down to multiple of 4.
      "ands w2, %w[rows], #-4\n"
      // If there are no full blocks of 4 rows to process, jump to the
      // code handling the last < 4 rows.
      "beq 3f\n"
      // Load the first block of 4 rows.
      "add w1, w1, #4\n"
      "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
      // Check if these were the only full block of 4 rows to load.
      "cmp w1, w2\n"
      "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
      // In that case, jump to the code handling the last loaded block of
      // 4 rows.
      "beq 2f\n"
      // Main loop processing blocks of 4 rows.
      "1:\n"
      // Advance by 4 rows.
      "add w1, w1, #4\n"
      // Transposition of the already-loaded 4x4 block, part 1.
      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"
      // Load the next 4x4 block.
      "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
      // Transposition of the already-loaded 4x4 block, part 2.
      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"
      // Was this the last full 4x4 block to load?
      "cmp w1, w2\n"
      // Store the transposed 4x4 block.
      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"
      // End of main loop on 4x4 blocks.
      "bne 1b\n"

      // Code handling the last already-loaded 4x4 block.
      "2:\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      // End of code handling full 4x4 blocks.
      // Now we handle any remaining rows.
      "3:\n"
      // Let w2 be the number of rows left to handle.
      "ands w2, %w[rows], #3\n"
      // If w2==0, there are no remaining rows, jump to the end.
      "beq 4f\n"
      // Zero out a 4x4 block in registers, which we'll partially overwrite
      // with any remaining rows.
      "movi v0.16b, #0\n"
      "movi v1.16b, #0\n"
      "movi v2.16b, #0\n"
      "movi v3.16b, #0\n"
#define RUY_LOAD_ONE_ROW(R)                    \
  "cmp w2, #" #R "\n"                          \
  "beq 5f\n"                                   \
  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n"  \
  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n"  \
  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n"  \
  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"

      RUY_LOAD_ONE_ROW(0)
      RUY_LOAD_ONE_ROW(1)
      RUY_LOAD_ONE_ROW(2)
      // Here we know that w2==3, so RUY_LOAD_ONE_ROW(3) would be a no-op.
#undef RUY_LOAD_ONE_ROW
      "5:\n"

      // Transpose that last zero-padded 4x4 block.
      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      // Store that last zero-padded block to the packed matrix.
      "mov x1, #32\n"
#define RUY_STORE_ONE_ROW(ROW, REGISTER)  \
  "cmp w2, #" #ROW "\n"                   \
  "beq 4f\n"                              \
  "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"

      RUY_STORE_ONE_ROW(0, v20)
      RUY_STORE_ONE_ROW(1, v21)
      RUY_STORE_ONE_ROW(2, v22)
      RUY_STORE_ONE_ROW(3, v23)

#undef RUY_STORE_ONE_ROW

      "4:\n"

      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows)
      : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
        "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
        "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
                              const float* src_ptr2, const float* src_ptr3,
                              int src_inc, int src_rows, float* packed_ptr,
                              int output_stride) {
  profiler::ScopeLabel label("Pack (kNeon)");
  asm volatile(
      // clang-format off
      "mov r1, #0\n"
      "and r2, %[rows], #-4\n"
      "cmp r1, r2\n"
      "beq 3f\n"
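// RUY_LOAD_FOUR_BY_FOUR loads one 4x4 block, conditionally advancing each
// source pointer: bit N of src_inc tells whether src_ptr<N> should advance
// by 16 bytes, and masking out bit N then shifting it left by (4 - N)
// yields either 0 or 16 without needing a branch.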
#define RUY_LOAD_FOUR_BY_FOUR()               \
  /* Load q0 */                               \
  "vld1.32 {d0, d1}, [%[src_ptr0]]\n"         \
  /* if src_inc0 != 0, add 16 to src_ptr0 */  \
  "and r3, %[src_inc], #1\n"                  \
  "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\
  /* Load q1 */                               \
  "vld1.32 {d2, d3}, [%[src_ptr1]]\n"         \
  /* if src_inc1 != 0, add 16 to src_ptr1 */  \
  "and r3, %[src_inc], #2\n"                  \
  "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\
  /* Load q2 */                               \
  "vld1.32 {d4, d5}, [%[src_ptr2]]\n"         \
  /* if src_inc2 != 0, add 16 to src_ptr2 */  \
  "and r3, %[src_inc], #4\n"                  \
  "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\
  /* Load q3 */                               \
  "vld1.32 {d6, d7}, [%[src_ptr3]]\n"         \
  /* if src_inc3 != 0, add 16 to src_ptr3 */  \
  "and r3, %[src_inc], #8\n"                  \
  "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\

      RUY_LOAD_FOUR_BY_FOUR()
      "add r1, r1, #4\n"
      "cmp r1, r2\n"

      "beq 2f\n"

      "1:\n"
      "add r1, r1, #4\n"

      // Transpose 4x4 matrix.
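      // (vzip.32 interleaves the 32-bit lanes of its two operands in place;
      // vtrn.32 swaps their off-diagonal lanes. After these three steps,
      // q0, q2, q1, q3 hold columns 0, 1, 2, 3 of the original block, hence
      // the q8, q10, q9, q11 store order used below.)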
      "vzip.32 q0, q1\n"
      "vzip.32 q2, q3\n"

      "vtrn.32 q0, q2\n"
      "vtrn.32 q1, q3\n"

      "vzip.32 q0, q2\n"
      "vzip.32 q1, q3\n"

      "vmov q8, q0\n"
      "vmov q9, q1\n"
      "vmov q10, q2\n"
      "vmov q11, q3\n"

      RUY_LOAD_FOUR_BY_FOUR()
#undef RUY_LOAD_FOUR_BY_FOUR

#define RUY_STORE_FOUR_BY_FOUR()                  \
  /* Store q8, q10, q9, q11 */                    \
  /* q8 = d16, d17 */                             \
  "vst1.32 {d16, d17}, [%[packed_ptr]]\n"         \
  /* q10 = d20, d21 */                            \
  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
  "vst1.32 {d20, d21}, [%[packed_ptr]]\n"         \
  /* q9 = d18, d19 */                             \
  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
  "vst1.32 {d18, d19}, [%[packed_ptr]]\n"         \
  /* q11 = d22, d23 */                            \
  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \
  "vst1.32 {d22, d23}, [%[packed_ptr]]\n"         \
  "add %[packed_ptr], %[packed_ptr], %[stride]\n" \

      RUY_STORE_FOUR_BY_FOUR()
      "cmp r1, r2\n"

      "bne 1b\n"

      "2:\n"

      // Transpose 4x4 matrix.
      "vzip.32 q0, q1\n"
      "vzip.32 q2, q3\n"

      "vtrn.32 q0, q2\n"
      "vtrn.32 q1, q3\n"

      "vzip.32 q0, q2\n"
      "vzip.32 q1, q3\n"

      "vmov q8, q0\n"
      "vmov q9, q1\n"
      "vmov q10, q2\n"
      "vmov q11, q3\n"

      RUY_STORE_FOUR_BY_FOUR()
#undef RUY_STORE_FOUR_BY_FOUR
      "3:\n"

      "ands r2, %[rows], #3\n"
      "beq 4f\n"
      "mov r0, #0\n"
      // Zero out q0 - q3
      "vdup.32 q0, r0\n"
      "vdup.32 q1, r0\n"
      "vdup.32 q2, r0\n"
      "vdup.32 q3, r0\n"
#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I)    \
  "cmp r2, #" #R "\n"                        \
  "beq 5f\n"                                 \
  "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \
  "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \
  "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \
  "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n"

#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I)   \
  "cmp r2, #" #R "\n"                        \
  "beq 5f\n"                                 \
  "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \
  "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \
  "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \
  "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n"

      RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0)
      RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1)
      RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0)
      RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1)
#undef RUY_LOAD_ONE_ROW_SECOND_HALF
#undef RUY_LOAD_ONE_ROW_FIRST_HALF
      "5:\n"

      // Transpose 4x4 matrix.
      "vzip.32 q0, q1\n"
      "vzip.32 q2, q3\n"

      "vtrn.32 q0, q2\n"
      "vtrn.32 q1, q3\n"

      "vzip.32 q0, q2\n"
      "vzip.32 q1, q3\n"

      "vmov q8, q0\n"
      "vmov q9, q1\n"
      "vmov q10, q2\n"
      "vmov q11, q3\n"

      "mov r1, #32\n"

#define RUY_STORE_ONE_ROW(ROW, REGISTER)           \
  "cmp r2, #" #ROW "\n"                            \
  "beq 4f\n"                                       \
  "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n"     \
  "add %[packed_ptr], %[packed_ptr], %[stride]\n"

      // Store q8
      RUY_STORE_ONE_ROW(0, q8)
      // Store q10
      RUY_STORE_ONE_ROW(1, q10)
      // Store q9
      RUY_STORE_ONE_ROW(2, q9)
      // Store q11
      RUY_STORE_ONE_ROW(3, q11)

#undef RUY_STORE_ONE_ROW

      "4:\n"

      // clang-format on
      : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1),
        [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3),
        [ packed_ptr ] "+r"(packed_ptr)
      : [ src_inc ] "r"(static_cast<std::int32_t>(src_inc)),
        [ rows ] "r"(src_rows), [ stride ] "r"(output_stride)
      : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3",
        "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
}

#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)

#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
                                    const float* src_ptr1,
                                    const float* src_ptr2,
                                    const float* src_ptr3, int src_inc0,
                                    int src_inc1, int src_inc2, int src_inc3,
                                    int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)");

  asm volatile(
      // clang-format off
      "mov w1, #0\n"

      "and w2, %w[rows], #-4\n"
      "cmp w1, w2\n"
      "beq 3f\n"
      "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n"
      "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n"
      "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n"
      "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n")
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n")
      "add w1, w1, #4\n"
      "cmp w1, w2\n"

      "beq 2f\n"

      "1:\n"
      "add w1, w1, #4\n"

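      // Each 16-byte row is loaded as two 8-byte halves: a GPR ldr for the
      // high half and an 8-byte ld1 for the low half, merged by ins. Mixing
      // NEON and GPR loads this way helps dual-issue on in-order cores such
      // as the Cortex-A55.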
      "ldr x10, [%[src_ptr0], #8]\n"
      "trn1 v16.4s, v0.4s, v1.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n")
      "ldr x11, [%[src_ptr1], #8]\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n")
      "ldr x12, [%[src_ptr2], #8]\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n")
      "ldr x13, [%[src_ptr3], #8]\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"
      RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n")

      "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n"
      "trn1 v20.2d, v16.2d, v18.2d\n"
      "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"
      "cmp w1, w2\n"

      "ins v0.d[1], x10\n"
      "str q20, [%[packed_ptr], #0]\n"
      "ins v1.d[1], x11\n"
      "str q21, [%[packed_ptr], #32]\n"
      "ins v2.d[1], x12\n"
      "str q22, [%[packed_ptr], #64]\n"
      "ins v3.d[1], x13\n"
      "str q23, [%[packed_ptr], #96]\n"

      "add %[packed_ptr], %[packed_ptr], #128\n"

      "bne 1b\n"

      "2:\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      "str q20, [%[packed_ptr], #0]\n"
      "str q21, [%[packed_ptr], #32]\n"
      "str q22, [%[packed_ptr], #64]\n"
      "str q23, [%[packed_ptr], #96]\n"
      "add %[packed_ptr], %[packed_ptr], #128\n"

      "3:\n"

      "ands w2, %w[rows], #3\n"
      "beq 4f\n"
      "movi v0.16b, #0\n"
      "movi v1.16b, #0\n"
      "movi v2.16b, #0\n"
      "movi v3.16b, #0\n"
#define RUY_LOAD_ONE_ROW(R)                    \
  "cmp w2, #" #R "\n"                          \
  "beq 5f\n"                                   \
  "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n"  \
  "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n"  \
  "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n"  \
  "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n"

      RUY_LOAD_ONE_ROW(0)
      RUY_LOAD_ONE_ROW(1)
      RUY_LOAD_ONE_ROW(2)
      RUY_LOAD_ONE_ROW(3)
#undef RUY_LOAD_ONE_ROW
      "5:\n"

      "trn1 v16.4s, v0.4s, v1.4s\n"
      "trn2 v17.4s, v0.4s, v1.4s\n"
      "trn1 v18.4s, v2.4s, v3.4s\n"
      "trn2 v19.4s, v2.4s, v3.4s\n"

      "trn1 v20.2d, v16.2d, v18.2d\n"
      "trn2 v22.2d, v16.2d, v18.2d\n"
      "trn1 v21.2d, v17.2d, v19.2d\n"
      "trn2 v23.2d, v17.2d, v19.2d\n"

      "mov x1, #32\n"

#define RUY_STORE_ONE_ROW(ROW, REGISTER)  \
  "cmp w2, #" #ROW "\n"                   \
  "beq 4f\n"                              \
  "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n"

      RUY_STORE_ONE_ROW(0, v20)
      RUY_STORE_ONE_ROW(1, v21)
      RUY_STORE_ONE_ROW(2, v22)
      RUY_STORE_ONE_ROW(3, v23)

#undef RUY_STORE_ONE_ROW

      "4:\n"

      // clang-format on

      : [src_ptr0] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1),
        [src_ptr2] "+r"(src_ptr2), [src_ptr3] "+r"(src_ptr3),
        [packed_ptr] "+r"(packed_ptr)
      : [src_inc0] "r"(static_cast<std::int64_t>(src_inc0)),
        [src_inc1] "r"(static_cast<std::int64_t>(src_inc1)),
        [src_inc2] "r"(static_cast<std::int64_t>(src_inc2)),
        [src_inc3] "r"(static_cast<std::int64_t>(src_inc3)),
        [rows] "r"(src_rows)
      : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1",
        "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
        "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
        "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
}
#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)

#if RUY_PLATFORM_NEON

namespace {
// transpose_*bit_vals are wrappers around ARM TRN1 instructions, allowing
// us to use these instructions like we would in assembly --- this is one
// instance where assembly is more idiomatic than intrinsics.
//
// The way that TRN1 is exposed by vtrn_* intrinsics makes its usage very
// cumbersome. The issue is that transposing groups of values has been
// exposed only as transposing values of a wider type, so this requires many
// vreinterpret's, and to make it worse, vtrn_* return NEON array types like
// int8x8x2_t for which vreinterpret's are not defined!
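//
// For example, transpose_8bit_vals with a = (a0 .. a7) and b = (b0 .. b7)
// yields a = (a0 b0 a2 b2 a4 b4 a6 b6) and b = (a1 b1 a3 b3 a5 b5 a7 b7),
// i.e. each adjacent 2x2 block of bytes is transposed.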
void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) {
  int8x8x2_t t = vtrn_s8(a, b);
  a = t.val[0];
  b = t.val[1];
}

void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) {
  int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b));
  a = vreinterpret_s8_s16(t.val[0]);
  b = vreinterpret_s8_s16(t.val[1]);
}

void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) {
  int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b));
  a = vreinterpret_s8_s32(t.val[0]);
  b = vreinterpret_s8_s32(t.val[1]);
}
}  // namespace

void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
                             int src_rows, int src_cols, int block_row,
                             int start_col, int end_col,
                             std::int8_t* packed_ptr, int packed_stride,
                             int packed_zero_point, std::int32_t* sums,
                             int input_xor, int kernel_cols) {
  profiler::ScopeLabel label("Pack (kNeon, from row-major)");

  int src_end_col = std::min(end_col, src_cols);
  int col = start_col;
  for (; col <= src_end_col - 8; col += 8) {
    // Each iteration of this loop handles 8 columns, and the kernel format
    // has 16 rows, so each iteration handles a 16x8 block.
    //
    // Since the source is row-major, handling 8 columns at a time means
    // loading only 8 bytes, i.e. 64bit, from each row. This may seem
    // surprising on a 128bit SIMD ISA like NEON. While we could handle 16
    // columns at a time, we prefer to stick with 8 for the following reasons:
    // 1. The arithmetic (computing sums and transposing data) done on these
    // values is such that even though we initially start from 64bit vectors,
    // most of our NEON instructions are full 128bit instructions. For the
    // sums computation, that is because summing 8bit values requires
    // expansion to 16bit anyway. For the matrix transposition code, that is
    // because the ARM ZIP instructions take 64bit of data from two input
    // registers and zip it into a 128bit output. If we had 128bit of data
    // in each input register, we would need 2x more ARM NEON instructions
    // to zip it.
    // 2. The main optimization target for this (ARM, 8bit, non-dotprod)
    // code path is in-order ARM cores such as the Cortex-A53, which prefer
    // 64bit loads anyway.
    // 3. Handling only 8 columns at a time limits the size of the final
    // leftover columns handled with slow scalar code.
    //
    // This code is not very optimized anyway, as evidenced by the facts that
    // (1) it's written in intrinsics, (2) it's not using separate versions
    // tuned for different types of CPU cores. At the level of optimization
    // that it's working at, this seems like a fair compromise. If one wanted
    // to maximize performance at the cost of more code complexity/size, one
    // could have code handling 16 columns at a time (maybe limited to
    // Tuning::kGeneric), then 8, then 4 to minimize the amount of slow
    // leftovers.
    //
    // Load 8 sums in sums0, sums1.
    int32x4_t sums0 = vld1q_s32(sums + col);
    int32x4_t sums1 = vld1q_s32(sums + col + 4);
    // Load the 16x8 block from the source matrix.
    // Each val* here is the data from one row.
    int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10,
        val11, val12, val13, val14, val15;
    // Even though this function takes a uint8_t* src_ptr, that's only a
    // type-erased pointer (using uint8_t* so that pointer arithmetic is
    // allowed). The actual type may be either uint8_t or int8_t. The only
    // difference it makes is that if it's uint8_t then we need to flip the
    // sign bit. This is specified by the input_xor value (which is 0x80 if the
    // input data is uint8_t, and 0x0 otherwise).
    auto load_and_convert = [=](const std::uint8_t* from) {
      return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from)));
    };
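    // A scalar equivalent of load_and_convert (exposition only):
    //   for (int i = 0; i < 8; ++i) {
    //     out[i] = static_cast<std::int8_t>(from[i] ^ input_xor);
    //   }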
    if (block_row <= src_rows - 16) {
      // Load data in the regular case: there are still 16 rows to be read
      // from the source matrix.
      val0 = load_and_convert(src_ptr + 0 * src_stride);
      val1 = load_and_convert(src_ptr + 1 * src_stride);
      val2 = load_and_convert(src_ptr + 2 * src_stride);
      val3 = load_and_convert(src_ptr + 3 * src_stride);
      val4 = load_and_convert(src_ptr + 4 * src_stride);
      val5 = load_and_convert(src_ptr + 5 * src_stride);
      val6 = load_and_convert(src_ptr + 6 * src_stride);
      val7 = load_and_convert(src_ptr + 7 * src_stride);
      val8 = load_and_convert(src_ptr + 8 * src_stride);
      val9 = load_and_convert(src_ptr + 9 * src_stride);
      val10 = load_and_convert(src_ptr + 10 * src_stride);
      val11 = load_and_convert(src_ptr + 11 * src_stride);
      val12 = load_and_convert(src_ptr + 12 * src_stride);
      val13 = load_and_convert(src_ptr + 13 * src_stride);
      val14 = load_and_convert(src_ptr + 14 * src_stride);
      val15 = load_and_convert(src_ptr + 15 * src_stride);
    } else {
      // Boundary case: there are fewer than 16 rows to be read from the
      // source matrix. We pad by the zero_point.
      val0 = vdup_n_s8(packed_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      val4 = val0;
      val5 = val0;
      val6 = val0;
      val7 = val0;
      val8 = val0;
      val9 = val0;
      val10 = val0;
      val11 = val0;
      val12 = val0;
      val13 = val0;
      val14 = val0;
      val15 = val0;
      if (block_row + 0 < src_rows)
        val0 = load_and_convert(src_ptr + 0 * src_stride);
      if (block_row + 1 < src_rows)
        val1 = load_and_convert(src_ptr + 1 * src_stride);
      if (block_row + 2 < src_rows)
        val2 = load_and_convert(src_ptr + 2 * src_stride);
      if (block_row + 3 < src_rows)
        val3 = load_and_convert(src_ptr + 3 * src_stride);
      if (block_row + 4 < src_rows)
        val4 = load_and_convert(src_ptr + 4 * src_stride);
      if (block_row + 5 < src_rows)
        val5 = load_and_convert(src_ptr + 5 * src_stride);
      if (block_row + 6 < src_rows)
        val6 = load_and_convert(src_ptr + 6 * src_stride);
      if (block_row + 7 < src_rows)
        val7 = load_and_convert(src_ptr + 7 * src_stride);
      if (block_row + 8 < src_rows)
        val8 = load_and_convert(src_ptr + 8 * src_stride);
      if (block_row + 9 < src_rows)
        val9 = load_and_convert(src_ptr + 9 * src_stride);
      if (block_row + 10 < src_rows)
        val10 = load_and_convert(src_ptr + 10 * src_stride);
      if (block_row + 11 < src_rows)
        val11 = load_and_convert(src_ptr + 11 * src_stride);
      if (block_row + 12 < src_rows)
        val12 = load_and_convert(src_ptr + 12 * src_stride);
      if (block_row + 13 < src_rows)
        val13 = load_and_convert(src_ptr + 13 * src_stride);
      if (block_row + 14 < src_rows)
        val14 = load_and_convert(src_ptr + 14 * src_stride);
      if (block_row + 15 < src_rows)
        val15 = load_and_convert(src_ptr + 15 * src_stride);
    }
    src_ptr += 8;
    // Compute sums.
    int16x8_t sums16_0 = vaddl_s8(val0, val1);
    int16x8_t sums16_1 = vaddl_s8(val2, val3);
    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5));
    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7));
    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9));
    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11));
    sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13));
    sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15));
    int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1);
    sums0 = vaddw_s16(sums0, vget_low_s16(sums16));
    sums1 = vaddw_s16(sums1, vget_high_s16(sums16));
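    // This widening tree cannot overflow: the sum of 16 int8 values has
    // magnitude at most 16 * 128 = 2048, well within int16 range, before
    // being widened into the 32-bit accumulators.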
    // Store sums.
    vst1q_s32(sums + col, sums0);
    vst1q_s32(sums + col + 4, sums1);

    // Transpose the data, i.e. change the storage order of the
    // 16x8 block, to convert from the row-major source to the
    // column-major packed format.
    //
    // Before, for i in [0, 15], val<i> is the i-th row.
    // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column.
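    // The transposition proceeds as three rounds of pairwise transposes at
    // increasing granularity (8-bit, then 16-bit, then 32-bit), using the
    // helpers defined above.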
    transpose_8bit_vals(val0, val1);
    transpose_8bit_vals(val2, val3);
    transpose_8bit_vals(val4, val5);
    transpose_8bit_vals(val6, val7);
    transpose_8bit_vals(val8, val9);
    transpose_8bit_vals(val10, val11);
    transpose_8bit_vals(val12, val13);
    transpose_8bit_vals(val14, val15);
    transpose_16bit_vals(val0, val2);
    transpose_16bit_vals(val1, val3);
    transpose_16bit_vals(val4, val6);
    transpose_16bit_vals(val5, val7);
    transpose_16bit_vals(val8, val10);
    transpose_16bit_vals(val9, val11);
    transpose_16bit_vals(val12, val14);
    transpose_16bit_vals(val13, val15);
    transpose_32bit_vals(val0, val4);
    transpose_32bit_vals(val1, val5);
    transpose_32bit_vals(val2, val6);
    transpose_32bit_vals(val3, val7);
    transpose_32bit_vals(val8, val12);
    transpose_32bit_vals(val9, val13);
    transpose_32bit_vals(val10, val14);
    transpose_32bit_vals(val11, val15);
    // Store to the packed_matrix.
    std::int8_t* dst_ptr = packed_ptr;
    vst1q_s8(dst_ptr, vcombine_s8(val0, val8));
    vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9));
    dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
    vst1q_s8(dst_ptr, vcombine_s8(val2, val10));
    vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11));
    packed_ptr += 4 * packed_stride;
    dst_ptr = packed_ptr;
    vst1q_s8(dst_ptr, vcombine_s8(val4, val12));
    vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13));
    dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32;
    vst1q_s8(dst_ptr, vcombine_s8(val6, val14));
    vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15));
    packed_ptr += 4 * packed_stride;
  }
  // Handle remaining columns, not fitting in a full block of 8 columns, but
  // still actual columns from the source matrix (as opposed to the final
  // columns below).
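  // Note: the index arithmetic below assumes that kernel_cols is a power of
  // two, so that (col & (kernel_cols - 1)) equals col % kernel_cols.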
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
    for (int r = 0; r < 16; r++) {
      std::int8_t packed_val = (block_row + r < src_rows)
                                   ? (src_ptr[r * src_stride] ^ input_xor)
                                   : packed_zero_point;
      accum += packed_val;
      dst_ptr[r] = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
    if (((col + 1) & (kernel_cols - 1)) == 0) {
      packed_ptr += kernel_cols * packed_stride;
    }
  }
  // Handle the final columns of the packed matrix, beyond the last column of
  // the source matrix. The values here don't matter, we just want to avoid
  // leaving uninitialized data. Since the sums are already initialized above,
  // we don't need to do anything about them here.
  for (; col < end_col; col++) {
    std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16;
    std::memset(dst_ptr, 0, 16);
    if (((col + 1) & (kernel_cols - 1)) == 0) {
      packed_ptr += kernel_cols * packed_stride;
    }
  }
}

#endif  // RUY_PLATFORM_NEON

}  // namespace ruy