1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "ruy/pack_arm.h" |
17 | |
18 | #include <cstdint> |
19 | |
20 | #include "ruy/asm_helpers.h" |
21 | #include "ruy/opt_set.h" |
22 | #include "ruy/pack_common.h" |
23 | #include "ruy/platform.h" |
24 | #include "ruy/profiler/instrumentation.h" |
25 | |
26 | #if RUY_PLATFORM_NEON |
27 | #include <arm_neon.h> |
28 | #endif |
29 | |
30 | namespace ruy { |
31 | |
32 | #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) |
33 | |
34 | void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1, |
35 | const void* src_ptr2, const void* src_ptr3, |
36 | int src_inc0, int src_inc1, int src_inc2, |
37 | int src_inc3, int src_rows, int src_zero_point, |
38 | std::int8_t* packed_ptr, std::int32_t* sums_ptr, |
39 | int input_xor) { |
40 | profiler::ScopeLabel label("Pack (kNeon)" ); |
41 | asm volatile( |
42 | // clang-format off |
43 | // v26 will be the vector to XOR input values with to perform |
44 | // any input data type conversion (e.g. uint8 to int8). |
45 | "dup v26.16b, %w[input_xor]\n" |
46 | // w1 will be the number of rows already loaded. |
47 | "mov w1, #0\n" |
      // v28--v31 will be used to accumulate the sums.
49 | "movi v28.4s, #0\n" |
50 | "movi v29.4s, #0\n" |
51 | "movi v30.4s, #0\n" |
52 | "movi v31.4s, #0\n" |
53 | // Let w2 be `rows` rounded down to multiple of 16. |
54 | "ands w2, %w[rows], #-16\n" |
55 | // If there are no full blocks of 16 rows to process, jump to the |
56 | // code handling the last < 16 rows. |
57 | "beq 3f\n" |
58 | // Load the first block of 16 rows. |
59 | "add w1, w1, #16\n" |
60 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
61 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
      // Check if this was the only full block of 16 rows to load.
63 | "cmp w1, w2\n" |
64 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
65 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
66 | // In that case, jump to the code handling the last loaded block of |
67 | // 16 rows. |
68 | "beq 2f\n" |
69 | // Main loop processing blocks of 16 rows. |
70 | "1:\n" |
71 | // Load the next 16 rows, interleaved with the XOR input type |
72 | // conversion (e.g. uint8->int8) on the already loaded inputs. |
73 | "add w1, w1, #16\n" |
74 | "eor v4.16b, v0.16b, v26.16b\n" |
75 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
76 | "eor v5.16b, v1.16b, v26.16b\n" |
77 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
78 | "eor v6.16b, v2.16b, v26.16b\n" |
79 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
80 | "eor v7.16b, v3.16b, v26.16b\n" |
81 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
82 | // Compute the sums, interleaved with storing to the packed matrix. |
83 | "saddlp v16.8h, v4.16b\n" |
84 | "str q4, [%[packed_ptr], #0]\n" |
85 | "saddlp v17.8h, v5.16b\n" |
86 | "str q5, [%[packed_ptr], #16]\n" |
87 | "saddlp v18.8h, v6.16b\n" |
88 | "str q6, [%[packed_ptr], #32]\n" |
89 | "saddlp v19.8h, v7.16b\n" |
90 | "str q7, [%[packed_ptr], #48]\n" |
91 | "sadalp v28.4s, v16.8h\n" |
92 | // Was this the last block of 16 rows to load? |
93 | "cmp w1, w2\n" |
94 | "sadalp v29.4s, v17.8h\n" |
95 | "add %[packed_ptr], %[packed_ptr], #64\n" |
96 | "sadalp v30.4s, v18.8h\n" |
97 | "sadalp v31.4s, v19.8h\n" |
98 | // End of main loop on blocks of 16 rows. |
99 | "bne 1b\n" |
100 | |
101 | // Code handling the last already-loaded block of 16 rows. |
102 | "2:\n" |
103 | |
104 | // Process the last loaded full 16x4 block. |
105 | "eor v4.16b, v0.16b, v26.16b\n" |
106 | "eor v5.16b, v1.16b, v26.16b\n" |
107 | "eor v6.16b, v2.16b, v26.16b\n" |
108 | "eor v7.16b, v3.16b, v26.16b\n" |
109 | |
110 | "saddlp v16.8h, v4.16b\n" |
111 | "str q4, [%[packed_ptr], #0]\n" |
112 | "saddlp v17.8h, v5.16b\n" |
113 | "str q5, [%[packed_ptr], #16]\n" |
114 | "saddlp v18.8h, v6.16b\n" |
115 | "str q6, [%[packed_ptr], #32]\n" |
116 | "saddlp v19.8h, v7.16b\n" |
117 | "str q7, [%[packed_ptr], #48]\n" |
118 | "sadalp v28.4s, v16.8h\n" |
119 | "sadalp v29.4s, v17.8h\n" |
120 | "sadalp v30.4s, v18.8h\n" |
121 | "sadalp v31.4s, v19.8h\n" |
122 | |
123 | "add %[packed_ptr], %[packed_ptr], #64\n" |
124 | |
125 | // End of code handling full blocks of 16 rows. |
126 | // Now we handle any remaining rows. |
127 | "3:\n" |
128 | // Let w2 be the number of rows left to handle. |
129 | "ands w2, %w[rows], #15\n" |
130 | // If w2==0, there are no remaining rows, jump to the end. |
131 | "beq 4f\n" |
132 | // Zero out a 16x4 block in registers, which we'll partially overwrite |
133 | // with any remaining rows. |
134 | "dup v0.16b, %w[src_zero_point]\n" |
135 | "dup v1.16b, %w[src_zero_point]\n" |
136 | "dup v2.16b, %w[src_zero_point]\n" |
137 | "dup v3.16b, %w[src_zero_point]\n" |
138 | #define RUY_LOAD_ONE_ROW(R) \ |
139 | "cmp w2, #" #R "\n" \ |
140 | "beq 5f\n" \ |
141 | "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ |
142 | "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ |
143 | "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ |
144 | "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" |
145 | |
146 | RUY_LOAD_ONE_ROW(0) |
147 | RUY_LOAD_ONE_ROW(1) |
148 | RUY_LOAD_ONE_ROW(2) |
149 | RUY_LOAD_ONE_ROW(3) |
150 | RUY_LOAD_ONE_ROW(4) |
151 | RUY_LOAD_ONE_ROW(5) |
152 | RUY_LOAD_ONE_ROW(6) |
153 | RUY_LOAD_ONE_ROW(7) |
154 | RUY_LOAD_ONE_ROW(8) |
155 | RUY_LOAD_ONE_ROW(9) |
156 | RUY_LOAD_ONE_ROW(10) |
157 | RUY_LOAD_ONE_ROW(11) |
158 | RUY_LOAD_ONE_ROW(12) |
159 | RUY_LOAD_ONE_ROW(13) |
160 | RUY_LOAD_ONE_ROW(14) |
161 | // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. |
162 | #undef RUY_LOAD_ONE_ROW |
163 | "5:\n" |
164 | |
165 | // Process the last zero-padded 16x4 block. |
166 | "eor v4.16b, v0.16b, v26.16b\n" |
167 | "eor v5.16b, v1.16b, v26.16b\n" |
168 | "eor v6.16b, v2.16b, v26.16b\n" |
169 | "eor v7.16b, v3.16b, v26.16b\n" |
170 | |
171 | "saddlp v16.8h, v4.16b\n" |
172 | "saddlp v17.8h, v5.16b\n" |
173 | "saddlp v18.8h, v6.16b\n" |
174 | "saddlp v19.8h, v7.16b\n" |
175 | "sadalp v28.4s, v16.8h\n" |
176 | "sadalp v29.4s, v17.8h\n" |
177 | "sadalp v30.4s, v18.8h\n" |
178 | "sadalp v31.4s, v19.8h\n" |
179 | |
180 | "str q4, [%[packed_ptr], #0]\n" |
181 | "str q5, [%[packed_ptr], #16]\n" |
182 | "str q6, [%[packed_ptr], #32]\n" |
183 | "str q7, [%[packed_ptr], #48]\n" |
184 | "add %[packed_ptr], %[packed_ptr], #64\n" |
185 | |
186 | "4:\n" |
187 | |
188 | // Horizontal reduction of the registers used to accumulate sums. |
189 | "addp v28.4s, v28.4s, v29.4s\n" |
190 | "addp v30.4s, v30.4s, v31.4s\n" |
191 | "addp v28.4s, v28.4s, v30.4s\n" |
192 | |
193 | // Store the sums. |
194 | "cmp %[sums_ptr], #0\n" |
195 | "beq 6f\n" |
196 | "st1 {v28.4s}, [%[sums_ptr]], #16\n" |
197 | "6:\n" |
198 | // clang-format on |
199 | |
200 | : [src_ptr0] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), |
201 | [src_ptr2] "+r" (src_ptr2), [src_ptr3] "+r" (src_ptr3), |
202 | [packed_ptr] "+r" (packed_ptr), [sums_ptr] "+r" (sums_ptr) |
203 | : [src_inc0] "r" (static_cast<std::int64_t>(src_inc0)), |
204 | [src_inc1] "r" (static_cast<std::int64_t>(src_inc1)), |
205 | [src_inc2] "r" (static_cast<std::int64_t>(src_inc2)), |
206 | [src_inc3] "r" (static_cast<std::int64_t>(src_inc3)), |
207 | [rows] "r" (src_rows), [src_zero_point] "r" (src_zero_point), |
208 | [input_xor] "r" (input_xor) |
209 | : "cc" , "memory" , "x1" , "x2" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , |
210 | "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , "v16" , |
211 | "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , "v26" , |
212 | "v27" , "v28" , "v29" , "v30" , "v31" ); |
213 | } |
214 | #endif |
215 | |
216 | #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) |
217 | |
218 | #define RUY_OFFSET_SRC_PTR0 0 |
219 | #define RUY_OFFSET_SRC_PTR1 4 |
220 | #define RUY_OFFSET_SRC_PTR2 8 |
221 | #define RUY_OFFSET_SRC_PTR3 12 |
222 | #define RUY_OFFSET_SUMS_PTR 16 |
223 | #define RUY_OFFSET_PACKED_PTR 20 |
224 | #define RUY_OFFSET_SRC_INC0 24 |
225 | #define RUY_OFFSET_SRC_INC1 28 |
226 | #define RUY_OFFSET_SRC_INC2 32 |
227 | #define RUY_OFFSET_SRC_INC3 36 |
228 | #define RUY_OFFSET_SRC_ROWS 40 |
229 | #define RUY_OFFSET_SRC_ZERO_POINT 44 |
230 | #define RUY_OFFSET_INPUT_XOR 48 |
231 | |
232 | template <typename Params> |
233 | void CheckOffsetsInPackParams8bit(const Params&) { |
  static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, "");
  static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, "");
  static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, "");
  static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, "");
  static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, "");
  static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, "");
  static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, "");
  static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, "");
  static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, "");
  static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, "");
  static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, "");
  static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT,
                "");
  static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, "");
248 | } |
249 | |
250 | // No attempt made at making this code efficient on A55-ish cores yet. |
251 | void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params) { |
252 | CheckOffsetsInPackParams8bit(params); |
253 | profiler::ScopeLabel label("Pack (kNeon)" ); |
254 | const void* src_ptr0 = params.src_ptr0; |
255 | const void* src_ptr1 = params.src_ptr1; |
256 | const void* src_ptr2 = params.src_ptr2; |
257 | const void* src_ptr3 = params.src_ptr3; |
258 | const int src_inc0 = params.src_inc0; |
259 | const int src_inc1 = params.src_inc1; |
260 | const int src_inc2 = params.src_inc2; |
261 | const int src_inc3 = params.src_inc3; |
262 | const std::int8_t* packed_ptr = params.packed_ptr; |
263 | |
264 | asm volatile( |
265 | // clang-format off |
266 | |
267 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" |
268 | "vdup.8 q11, r2\n" |
269 | "mov r1, #0\n" |
270 | // Zero-out the accumulators |
271 | "vmov.i32 q12, #0\n" |
272 | "vmov.i32 q13, #0\n" |
273 | "vmov.i32 q14, #0\n" |
274 | "vmov.i32 q15, #0\n" |
275 | |
276 | // Round down src_rows to nearest multiple of 16. |
277 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" |
278 | "and r2, r3, #-16\n" |
279 | "cmp r1, r2\n" |
280 | "beq 3f\n" |
281 | |
282 | "1:\n" |
283 | "add r1, r1, #16\n" |
284 | /* Load q0 */ |
285 | "vld1.8 {d0, d1}, [%[src_ptr0]]\n" |
286 | "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" |
287 | RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n" ) |
288 | |
289 | /* Load q1 */ |
290 | "vld1.8 {d2, d3}, [%[src_ptr1]]\n" |
291 | "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" |
292 | RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n" ) |
293 | |
294 | "veor.8 q4, q0, q11\n" |
295 | "veor.8 q5, q1, q11\n" |
296 | |
      // Pairwise add into 16-bit accumulators.
298 | "vpaddl.s8 q8, q4\n" |
299 | "vpaddl.s8 q9, q5\n" |
300 | |
301 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
302 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
303 | |
304 | // Pairwise add accumulate into 32b accumulators. |
305 | // q12 and q13 contain 4x32b accumulators |
306 | "vpadal.s16 q12, q8\n" |
307 | "vpadal.s16 q13, q9\n" |
308 | |
309 | // Now do the same for src_ptr2 and src_ptr3. |
310 | "vld1.8 {d0, d1}, [%[src_ptr2]]\n" |
311 | "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" |
312 | RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n" ) |
313 | |
314 | "vld1.8 {d2, d3}, [%[src_ptr3]]\n" |
315 | "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" |
316 | RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n" ) |
317 | |
318 | "veor.8 q4, q0, q11\n" |
319 | "veor.8 q5, q1, q11\n" |
320 | |
321 | "vpaddl.s8 q8, q4\n" |
322 | "vpaddl.s8 q9, q5\n" |
323 | |
324 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
325 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
326 | |
327 | // Pairwise add accumulate into 32b accumulators. |
328 | // q14 and q15 contain 4x32b accumulators |
329 | "vpadal.s16 q14, q8\n" |
330 | "vpadal.s16 q15, q9\n" |
331 | |
332 | "cmp r1, r2\n" |
333 | "bne 1b\n" |
334 | |
335 | "3:\n" |
336 | |
337 | // Now pack the last (num_rows % 16) rows. |
338 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" |
339 | "ands r2, r3, #15\n" |
340 | "beq 4f\n" |
341 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" |
342 | "vdup.8 q0, r3\n" |
343 | "vdup.8 q1, r3\n" |
344 | |
345 | // First, read/accumulate/write for src_ptr0 and src_ptr1. |
346 | #define RUY_LOAD_ONE_ROW1(I, R) \ |
347 | "cmp r2, #" #I "\n" \ |
348 | "beq 5f\n" \ |
349 | "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ |
350 | "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ |
351 | |
352 | RUY_LOAD_ONE_ROW1(0, 0) |
353 | RUY_LOAD_ONE_ROW1(1, 1) |
354 | RUY_LOAD_ONE_ROW1(2, 2) |
355 | RUY_LOAD_ONE_ROW1(3, 3) |
356 | RUY_LOAD_ONE_ROW1(4, 4) |
357 | RUY_LOAD_ONE_ROW1(5, 5) |
358 | RUY_LOAD_ONE_ROW1(6, 6) |
359 | RUY_LOAD_ONE_ROW1(7, 7) |
360 | #undef RUY_LOAD_ONE_ROW1 |
361 | |
362 | #define RUY_LOAD_ONE_ROW2(I, R) \ |
363 | "cmp r2, #" #I "\n" \ |
364 | "beq 5f\n" \ |
365 | "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ |
366 | "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ |
367 | |
368 | RUY_LOAD_ONE_ROW2(8, 0) |
369 | RUY_LOAD_ONE_ROW2(9, 1) |
370 | RUY_LOAD_ONE_ROW2(10, 2) |
371 | RUY_LOAD_ONE_ROW2(11, 3) |
372 | RUY_LOAD_ONE_ROW2(12, 4) |
373 | RUY_LOAD_ONE_ROW2(13, 5) |
374 | RUY_LOAD_ONE_ROW2(14, 6) |
375 | RUY_LOAD_ONE_ROW2(15, 7) |
376 | #undef RUY_LOAD_ONE_ROW2 |
377 | |
378 | "5:\n" |
379 | |
380 | "veor.16 q4, q0, q11\n" |
381 | "veor.16 q5, q1, q11\n" |
382 | |
383 | "vpaddl.s8 q8, q4\n" |
384 | "vpaddl.s8 q9, q5\n" |
385 | |
386 | // Pairwise add accumulate to 4x32b accumulators. |
387 | "vpadal.s16 q12, q8\n" |
388 | "vpadal.s16 q13, q9\n" |
389 | |
390 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
391 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
392 | |
      // Reset to src_zero_point for src_ptr2 and src_ptr3.
394 | "vdup.8 q0, r3\n" |
395 | "vdup.8 q1, r3\n" |
396 | |
397 | // Next, read/accumulate/write for src_ptr2 and src_ptr3. |
398 | #define RUY_LOAD_ONE_ROW1(I, R) \ |
399 | "cmp r2, #" #I "\n" \ |
400 | "beq 5f\n" \ |
401 | "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ |
402 | "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ |
403 | |
404 | RUY_LOAD_ONE_ROW1(0, 0) |
405 | RUY_LOAD_ONE_ROW1(1, 1) |
406 | RUY_LOAD_ONE_ROW1(2, 2) |
407 | RUY_LOAD_ONE_ROW1(3, 3) |
408 | RUY_LOAD_ONE_ROW1(4, 4) |
409 | RUY_LOAD_ONE_ROW1(5, 5) |
410 | RUY_LOAD_ONE_ROW1(6, 6) |
411 | RUY_LOAD_ONE_ROW1(7, 7) |
412 | #undef RUY_LOAD_ONE_ROW1 |
413 | |
414 | #define RUY_LOAD_ONE_ROW2(I, R) \ |
415 | "cmp r2, #" #I "\n" \ |
416 | "beq 5f\n" \ |
417 | "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ |
418 | "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ |
419 | |
420 | RUY_LOAD_ONE_ROW2(8, 0) |
421 | RUY_LOAD_ONE_ROW2(9, 1) |
422 | RUY_LOAD_ONE_ROW2(10, 2) |
423 | RUY_LOAD_ONE_ROW2(11, 3) |
424 | RUY_LOAD_ONE_ROW2(12, 4) |
425 | RUY_LOAD_ONE_ROW2(13, 5) |
426 | RUY_LOAD_ONE_ROW2(14, 6) |
427 | RUY_LOAD_ONE_ROW2(15, 7) |
428 | #undef RUY_LOAD_ONE_ROW2 |
429 | |
430 | "5:\n" |
431 | |
432 | "veor.16 q4, q0, q11\n" |
433 | "veor.16 q5, q1, q11\n" |
434 | |
435 | "vpaddl.s8 q8, q4\n" |
436 | "vpaddl.s8 q9, q5\n" |
437 | |
438 | // Pairwise add accumulate to 4x32b accumulators. |
439 | "vpadal.s16 q14, q8\n" |
440 | "vpadal.s16 q15, q9\n" |
441 | |
442 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
443 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
444 | |
445 | "4:\n" |
446 | // Pairwise add 32-bit accumulators |
447 | "vpadd.i32 d24, d24, d25\n" |
448 | "vpadd.i32 d26, d26, d27\n" |
449 | "vpadd.i32 d28, d28, d29\n" |
450 | "vpadd.i32 d30, d30, d31\n" |
451 | // Final 32-bit values per row |
452 | "vpadd.i32 d25, d24, d26\n" |
453 | "vpadd.i32 d27, d28, d30\n" |
454 | |
455 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" |
456 | "cmp r3, #0\n" |
457 | "beq 6f\n" |
458 | "vst1.32 {d25}, [r3]!\n" |
459 | "vst1.32 {d27}, [r3]!\n" |
460 | "6:\n" |
461 | // clang-format on |
462 | |
463 | : [ src_ptr0 ] "+r" (src_ptr0), [ src_ptr1 ] "+r" (src_ptr1), |
464 | [ src_ptr2 ] "+r" (src_ptr2), [ src_ptr3 ] "+r" (src_ptr3) |
465 | : [ src_inc0 ] "r" (src_inc0), [ src_inc1 ] "r" (src_inc1), |
466 | [ src_inc2 ] "r" (src_inc2), [ src_inc3 ] "r" (src_inc3), |
467 | [ packed_ptr ] "r" (packed_ptr), [ params ] "r" (¶ms) |
468 | : "cc" , "memory" , "r1" , "r2" , "r3" , "q0" , "q1" , "q2" , "q3" , |
469 | "q4" , "q5" , "q6" , "q7" , "q8" , "q9" , "q10" , "q11" , "q12" , "q13" ); |
470 | } |
471 | |
472 | // Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. |
473 | // No attempt made at making this code efficient on in-order cores yet. |
474 | // This version differs from the above in that we only handle two columns |
475 | // at a time. |
476 | void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params) { |
477 | CheckOffsetsInPackParams8bit(params); |
478 | profiler::ScopeLabel label("Pack (kNeon)" ); |
479 | const void* src_ptr0 = params.src_ptr0; |
480 | const void* src_ptr1 = params.src_ptr1; |
481 | const int src_inc0 = params.src_inc0; |
482 | const int src_inc1 = params.src_inc1; |
483 | const std::int8_t* packed_ptr = params.packed_ptr; |
484 | |
485 | asm volatile( |
486 | // clang-format off |
487 | |
488 | "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" |
489 | "vdup.8 q11, r2\n" |
490 | "mov r1, #0\n" |
491 | // Zero-out the accumulators |
492 | "vmov.i32 q12, #0\n" |
493 | "vmov.i32 q13, #0\n" |
494 | |
495 | // Round down src_rows to nearest multiple of 16. |
496 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" |
497 | "and r2, r3, #-16\n" |
498 | "cmp r1, r2\n" |
499 | "beq 3f\n" |
500 | |
501 | "1:\n" |
502 | "add r1, r1, #16\n" |
503 | /* Load q0 */ |
504 | "vld1.8 {d0, d1}, [%[src_ptr0]]\n" |
505 | "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" |
506 | |
507 | /* Load q1 */ |
508 | "vld1.8 {d2, d3}, [%[src_ptr1]]\n" |
509 | "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" |
510 | |
511 | "veor.8 q4, q0, q11\n" |
512 | "veor.8 q5, q1, q11\n" |
513 | |
      // Pairwise add into 16-bit accumulators.
515 | "vpaddl.s8 q8, q4\n" |
516 | "vpaddl.s8 q9, q5\n" |
517 | |
518 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
519 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
520 | |
521 | // Pairwise add accumulate into 32b accumulators. |
522 | // q12 and q13 contain 4x32b accumulators |
523 | "vpadal.s16 q12, q8\n" |
524 | "vpadal.s16 q13, q9\n" |
525 | |
526 | "cmp r1, r2\n" |
527 | |
528 | "bne 1b\n" |
529 | |
530 | "3:\n" |
531 | |
532 | // Now pack the last (num_rows % 16) rows. |
533 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" |
534 | "ands r2, r3, #15\n" |
535 | "beq 4f\n" |
536 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" |
537 | "vdup.8 q0, r3\n" |
538 | "vdup.8 q1, r3\n" |
539 | |
540 | // Read/accumulate/write for src_ptr0 and src_ptr1. |
541 | #define RUY_LOAD_ONE_ROW1(I, R) \ |
542 | "cmp r2, #" #I "\n" \ |
543 | "beq 5f\n" \ |
544 | "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ |
545 | "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ |
546 | |
547 | RUY_LOAD_ONE_ROW1(0, 0) |
548 | RUY_LOAD_ONE_ROW1(1, 1) |
549 | RUY_LOAD_ONE_ROW1(2, 2) |
550 | RUY_LOAD_ONE_ROW1(3, 3) |
551 | RUY_LOAD_ONE_ROW1(4, 4) |
552 | RUY_LOAD_ONE_ROW1(5, 5) |
553 | RUY_LOAD_ONE_ROW1(6, 6) |
554 | RUY_LOAD_ONE_ROW1(7, 7) |
555 | #undef RUY_LOAD_ONE_ROW1 |
556 | |
557 | #define RUY_LOAD_ONE_ROW2(I, R) \ |
558 | "cmp r2, #" #I "\n" \ |
559 | "beq 5f\n" \ |
560 | "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ |
561 | "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ |
562 | |
563 | RUY_LOAD_ONE_ROW2(8, 0) |
564 | RUY_LOAD_ONE_ROW2(9, 1) |
565 | RUY_LOAD_ONE_ROW2(10, 2) |
566 | RUY_LOAD_ONE_ROW2(11, 3) |
567 | RUY_LOAD_ONE_ROW2(12, 4) |
568 | RUY_LOAD_ONE_ROW2(13, 5) |
569 | RUY_LOAD_ONE_ROW2(14, 6) |
570 | RUY_LOAD_ONE_ROW2(15, 7) |
571 | #undef RUY_LOAD_ONE_ROW2 |
572 | |
573 | "5:\n" |
574 | |
575 | "veor.16 q4, q0, q11\n" |
576 | "veor.16 q5, q1, q11\n" |
577 | |
578 | "vpaddl.s8 q8, q4\n" |
579 | "vpaddl.s8 q9, q5\n" |
580 | |
581 | |
582 | // Pairwise add accumulate to 4x32b accumulators. |
583 | "vpadal.s16 q12, q8\n" |
584 | "vpadal.s16 q13, q9\n" |
585 | |
586 | "vst1.32 {q4}, [%[packed_ptr]]!\n" |
587 | "vst1.32 {q5}, [%[packed_ptr]]!\n" |
588 | |
589 | "4:\n" |
590 | |
591 | // Pairwise add 32-bit accumulators |
592 | "vpadd.i32 d24, d24, d25\n" |
593 | "vpadd.i32 d26, d26, d27\n" |
594 | // Final 32-bit values per row |
595 | "vpadd.i32 d25, d24, d26\n" |
596 | |
597 | "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" |
598 | "cmp r3, #0\n" |
599 | "beq 6f\n" |
600 | "vst1.32 {d25}, [r3]!\n" |
601 | "6:\n" |
602 | // clang-format on |
603 | |
604 | : [ src_ptr0 ] "+r" (src_ptr0), [ src_ptr1 ] "+r" (src_ptr1) |
605 | : [ src_inc0 ] "r" (src_inc0), [ src_inc1 ] "r" (src_inc1), |
606 | [ packed_ptr ] "r" (packed_ptr), [ params ] "r" (¶ms) |
607 | : "cc" , "memory" , "r1" , "r2" , "r3" , "q0" , "q1" , "q2" , "q3" , |
608 | "q4" , "q5" , "q6" , "q7" , "q8" , "q9" , "q10" , "q11" , "q12" , "q13" ); |
609 | } |
610 | |
611 | #undef RUY_OFFSET_SRC_PTR0 |
612 | #undef RUY_OFFSET_SRC_PTR1 |
613 | #undef RUY_OFFSET_SRC_PTR2 |
#undef RUY_OFFSET_SRC_PTR3
615 | #undef RUY_OFFSET_SUMS_PTR |
#undef RUY_OFFSET_PACKED_PTR
617 | #undef RUY_OFFSET_SRC_INC0 |
618 | #undef RUY_OFFSET_SRC_INC1 |
619 | #undef RUY_OFFSET_SRC_INC2 |
620 | #undef RUY_OFFSET_SRC_INC3 |
621 | #undef RUY_OFFSET_SRC_ROWS |
622 | #undef RUY_OFFSET_SRC_ZERO_POINT |
623 | #undef RUY_OFFSET_INPUT_XOR |
624 | |
625 | #endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) |
626 | |
627 | #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) |
628 | |
629 | void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1, |
630 | const void* src_ptr2, const void* src_ptr3, |
631 | int src_inc0, int src_inc1, int src_inc2, |
632 | int src_inc3, int src_rows, |
633 | int src_zero_point, std::int8_t* packed_ptr, |
634 | std::int32_t* sums_ptr, int input_xor) { |
635 | profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)" ); |
636 | asm volatile( |
637 | // clang-format off |
638 | // v26 will be the vector to XOR input values with to perform |
639 | // any input data type conversion (e.g. uint8 to int8). |
640 | "dup v26.16b, %w[input_xor]\n" |
641 | // w1 will be the number of rows already loaded. |
642 | "mov w1, #0\n" |
      // v28--v31 will be used to accumulate the sums.
644 | "movi v28.4s, #0\n" |
645 | "movi v29.4s, #0\n" |
646 | "movi v30.4s, #0\n" |
647 | "movi v31.4s, #0\n" |
648 | // Let w2 be `rows` rounded down to multiple of 16. |
649 | "ands w2, %w[rows], #-16\n" |
650 | // If there are no full blocks of 16 rows to process, jump to the |
651 | // code handling the last < 16 rows. |
652 | "beq 3f\n" |
653 | // Load the first block of 16 rows. |
654 | "add w1, w1, #16\n" |
655 | "ldr x10, [%[src_ptr0], #8]\n" |
656 | "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" |
657 | "ldr x11, [%[src_ptr1], #8]\n" |
658 | "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" |
659 | "ldr x12, [%[src_ptr2], #8]\n" |
660 | "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" |
661 | "ldr x13, [%[src_ptr3], #8]\n" |
662 | "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" |
      // Check if this was the only full block of 16 rows to load.
664 | "cmp w1, w2\n" |
665 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n" ) |
666 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n" ) |
667 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n" ) |
668 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n" ) |
669 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n" ) |
670 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n" ) |
671 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n" ) |
672 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n" ) |
673 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n" ) |
674 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n" ) |
675 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n" ) |
676 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n" ) |
677 | // In that case, jump to the code handling the last loaded block of |
678 | // 16 rows. |
679 | "beq 2f\n" |
680 | // Main loop processing blocks of 16 rows. |
681 | "1:\n" |
682 | // Load the next 16 rows, interleaved with the XOR input type |
683 | // conversion (e.g. uint8->int8) on the already loaded inputs. |
684 | "add w1, w1, #16\n" |
685 | "ins v0.d[1], x10\n" |
686 | "ldr x10, [%[src_ptr0], #8]\n" |
687 | "ins v1.d[1], x11\n" |
688 | "ldr x11, [%[src_ptr1], #8]\n" |
689 | "ins v2.d[1], x12\n" |
690 | "ldr x12, [%[src_ptr2], #8]\n" |
691 | "ins v3.d[1], x13\n" |
692 | "ldr x13, [%[src_ptr3], #8]\n" |
693 | "eor v4.16b, v0.16b, v26.16b\n" |
694 | "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" |
695 | "eor v5.16b, v1.16b, v26.16b\n" |
696 | "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" |
697 | "eor v6.16b, v2.16b, v26.16b\n" |
698 | "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" |
699 | "eor v7.16b, v3.16b, v26.16b\n" |
700 | "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" |
701 | // Compute the sums, interleaved with storing to the packed matrix. |
702 | "saddlp v16.8h, v4.16b\n" |
703 | "str q4, [%[packed_ptr], #0]\n" |
704 | "saddlp v17.8h, v5.16b\n" |
705 | "str q5, [%[packed_ptr], #16]\n" |
706 | "saddlp v18.8h, v6.16b\n" |
707 | "str q6, [%[packed_ptr], #32]\n" |
708 | "saddlp v19.8h, v7.16b\n" |
709 | "str q7, [%[packed_ptr], #48]\n" |
710 | "sadalp v28.4s, v16.8h\n" |
711 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n" ) |
712 | // Was this the last block of 16 rows to load? |
713 | "cmp w1, w2\n" |
714 | "sadalp v29.4s, v17.8h\n" |
715 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n" ) |
716 | "add %[packed_ptr], %[packed_ptr], #64\n" |
717 | "sadalp v30.4s, v18.8h\n" |
718 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n" ) |
719 | "sadalp v31.4s, v19.8h\n" |
720 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n" ) |
721 | // End of main loop on blocks of 16 rows. |
722 | "bne 1b\n" |
723 | |
724 | // Code handling the last already-loaded block of 16 rows. |
725 | "2:\n" |
726 | // Process the last loaded full 16x4 block. |
727 | "ins v0.d[1], x10\n" |
728 | "ins v1.d[1], x11\n" |
729 | "ins v2.d[1], x12\n" |
730 | "ins v3.d[1], x13\n" |
731 | "eor v4.16b, v0.16b, v26.16b\n" |
732 | "eor v5.16b, v1.16b, v26.16b\n" |
733 | "eor v6.16b, v2.16b, v26.16b\n" |
734 | "eor v7.16b, v3.16b, v26.16b\n" |
735 | |
736 | "saddlp v16.8h, v4.16b\n" |
737 | "str q4, [%[packed_ptr], #0]\n" |
738 | "saddlp v17.8h, v5.16b\n" |
739 | "str q5, [%[packed_ptr], #16]\n" |
740 | "saddlp v18.8h, v6.16b\n" |
741 | "str q6, [%[packed_ptr], #32]\n" |
742 | "saddlp v19.8h, v7.16b\n" |
743 | "str q7, [%[packed_ptr], #48]\n" |
744 | "sadalp v28.4s, v16.8h\n" |
745 | "sadalp v29.4s, v17.8h\n" |
746 | "sadalp v30.4s, v18.8h\n" |
747 | "sadalp v31.4s, v19.8h\n" |
748 | |
749 | "add %[packed_ptr], %[packed_ptr], #64\n" |
750 | |
751 | // End of code handling full blocks of 16 rows. |
752 | // Now we handle any remaining rows. |
753 | "3:\n" |
754 | // Let w2 be the number of rows left to handle. |
755 | "ands w2, %w[rows], #15\n" |
756 | // If w2==0, there are no remaining rows, jump to the end. |
757 | "beq 4f\n" |
758 | // Zero out a 16x4 block in registers, which we'll partially overwrite |
759 | // with any remaining rows. |
760 | "dup v0.16b, %w[src_zero_point]\n" |
761 | "dup v1.16b, %w[src_zero_point]\n" |
762 | "dup v2.16b, %w[src_zero_point]\n" |
763 | "dup v3.16b, %w[src_zero_point]\n" |
764 | #define RUY_LOAD_ONE_ROW(R) \ |
765 | "cmp w2, #" #R "\n" \ |
766 | "beq 5f\n" \ |
767 | "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ |
768 | "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ |
769 | "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ |
770 | "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" |
771 | |
772 | RUY_LOAD_ONE_ROW(0) |
773 | RUY_LOAD_ONE_ROW(1) |
774 | RUY_LOAD_ONE_ROW(2) |
775 | RUY_LOAD_ONE_ROW(3) |
776 | RUY_LOAD_ONE_ROW(4) |
777 | RUY_LOAD_ONE_ROW(5) |
778 | RUY_LOAD_ONE_ROW(6) |
779 | RUY_LOAD_ONE_ROW(7) |
780 | RUY_LOAD_ONE_ROW(8) |
781 | RUY_LOAD_ONE_ROW(9) |
782 | RUY_LOAD_ONE_ROW(10) |
783 | RUY_LOAD_ONE_ROW(11) |
784 | RUY_LOAD_ONE_ROW(12) |
785 | RUY_LOAD_ONE_ROW(13) |
786 | RUY_LOAD_ONE_ROW(14) |
787 | // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. |
788 | #undef RUY_LOAD_ONE_ROW |
789 | "5:\n" |
790 | |
791 | // Process the last zero-padded 16x4 block. |
792 | "eor v4.16b, v0.16b, v26.16b\n" |
793 | "eor v5.16b, v1.16b, v26.16b\n" |
794 | "eor v6.16b, v2.16b, v26.16b\n" |
795 | "eor v7.16b, v3.16b, v26.16b\n" |
796 | |
797 | "saddlp v16.8h, v4.16b\n" |
798 | "saddlp v17.8h, v5.16b\n" |
799 | "saddlp v18.8h, v6.16b\n" |
800 | "saddlp v19.8h, v7.16b\n" |
801 | "sadalp v28.4s, v16.8h\n" |
802 | "sadalp v29.4s, v17.8h\n" |
803 | "sadalp v30.4s, v18.8h\n" |
804 | "sadalp v31.4s, v19.8h\n" |
805 | |
806 | "str q4, [%[packed_ptr], #0]\n" |
807 | "str q5, [%[packed_ptr], #16]\n" |
808 | "str q6, [%[packed_ptr], #32]\n" |
809 | "str q7, [%[packed_ptr], #48]\n" |
810 | "add %[packed_ptr], %[packed_ptr], #64\n" |
811 | |
812 | "4:\n" |
813 | |
814 | // Horizontal reduction of the registers used to accumulate sums. |
815 | "addp v28.4s, v28.4s, v29.4s\n" |
816 | "addp v30.4s, v30.4s, v31.4s\n" |
817 | "addp v28.4s, v28.4s, v30.4s\n" |
818 | |
819 | // Store the sums. |
820 | "cmp %[sums_ptr], #0\n" |
821 | "beq 6f\n" |
822 | "st1 {v28.4s}, [%[sums_ptr]], #16\n" |
823 | "6:\n" |
824 | // clang-format on |
825 | |
826 | : [ src_ptr0 ] "+r" (src_ptr0), [ src_ptr1 ] "+r" (src_ptr1), |
827 | [ src_ptr2 ] "+r" (src_ptr2), [ src_ptr3 ] "+r" (src_ptr3), |
828 | [ packed_ptr ] "+r" (packed_ptr), [ sums_ptr ] "+r" (sums_ptr) |
829 | : [ src_inc0 ] "r" (static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r" (static_cast<std::int64_t>(src_inc1)), |
830 | [ src_inc2 ] "r" (static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r" (static_cast<std::int64_t>(src_inc3)), |
831 | [ rows ] "r" (src_rows), |
832 | [ src_zero_point ] "r" (src_zero_point), |
833 | [input_xor] "r" (input_xor) |
834 | : "cc" , "memory" , "x1" , "x2" , "x10" , "x11" , "x12" , "x13" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , |
835 | "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , |
836 | "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , |
837 | "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
838 | } |
839 | |
840 | void Pack8bitColMajorForNeonDotprodA55ish( |
841 | const void* src_ptr0, const void* src_ptr1, const void* src_ptr2, |
842 | const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2, |
843 | int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr, |
844 | std::int32_t* sums_ptr, int input_xor) { |
  profiler::ScopeLabel label(
      "Pack (kNeonDotprod, optimized for in-order cores)");
847 | asm volatile( |
848 | // clang-format off |
849 | // v26 will be the vector to XOR input values with to perform |
850 | // any input data type conversion (e.g. uint8 to int8). |
851 | "dup v26.16b, %w[input_xor]\n" |
852 | // v27 will be filled with 1's. It will be used as an operand |
853 | // to SDOT to compute the sums. |
854 | "mov w1, #1\n" |
855 | "dup v27.16b, w1\n" |
856 | // w1 will be the number of rows already loaded. |
857 | "mov w1, #0\n" |
      // v28--v31 will be used to accumulate the sums.
859 | "movi v28.4s, #0\n" |
860 | "movi v29.4s, #0\n" |
861 | "movi v30.4s, #0\n" |
862 | "movi v31.4s, #0\n" |
863 | |
864 | // Let w2 be `rows` rounded down to multiple of 16. |
865 | "ands w2, %w[rows], #-16\n" |
866 | // If there are no full blocks of 16 rows to process, jump to the |
867 | // code handling the last < 16 rows. |
868 | "beq 3f\n" |
869 | // Load the first block of 16 rows. |
870 | "add w1, w1, #16\n" |
871 | "ldr x10, [%[src_ptr0], #8]\n" |
872 | "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" |
873 | "ldr x11, [%[src_ptr1], #8]\n" |
874 | "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" |
875 | "ldr x12, [%[src_ptr2], #8]\n" |
876 | "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" |
877 | "ldr x13, [%[src_ptr3], #8]\n" |
878 | "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" |
      // Check if this was the only full block of 16 rows to load.
880 | "cmp w1, w2\n" |
881 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n" ) |
882 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n" ) |
883 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n" ) |
884 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n" ) |
885 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n" ) |
886 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n" ) |
887 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n" ) |
888 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n" ) |
889 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n" ) |
890 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n" ) |
891 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n" ) |
892 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n" ) |
893 | // In that case, jump to the code handling the last loaded block of |
894 | // 16 rows. |
895 | "beq 2f\n" |
896 | |
897 | // Main loop processing blocks of 16 rows. |
898 | "1:\n" |
899 | "add w1, w1, #16\n" |
900 | // Prepare the already-loaded 16 rows by inserting the parts |
901 | // loaded into general purpose registers x10--x13 into the |
902 | // NEON registers v0--v3 where the other parts had already been |
903 | // loaded. |
904 | "ins v0.d[1], x10\n" |
905 | "ldr x10, [%[src_ptr0], #8]\n" |
906 | "ins v1.d[1], x11\n" |
907 | "ldr x11, [%[src_ptr1], #8]\n" |
908 | "ins v2.d[1], x12\n" |
909 | "ldr x12, [%[src_ptr2], #8]\n" |
910 | "ins v3.d[1], x13\n" |
911 | "ldr x13, [%[src_ptr3], #8]\n" |
912 | |
913 | // Load the next 16 rows and, interleaved with that, |
914 | // perform the input type conversion (e.g. uint8->int8) on the |
915 | // current 16 rows. |
916 | "eor v4.16b, v0.16b, v26.16b\n" |
917 | "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" |
918 | "eor v5.16b, v1.16b, v26.16b\n" |
919 | "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" |
920 | "eor v6.16b, v2.16b, v26.16b\n" |
921 | "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" |
922 | "eor v7.16b, v3.16b, v26.16b\n" |
923 | "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" |
924 | |
925 | // Transposition of 4x4 blocks, part 1 |
926 | "trn1 v16.4s, v4.4s, v5.4s\n" |
927 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n" ) |
928 | "trn2 v17.4s, v4.4s, v5.4s\n" |
929 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n" ) |
930 | "trn1 v18.4s, v6.4s, v7.4s\n" |
931 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n" ) |
932 | "trn2 v19.4s, v6.4s, v7.4s\n" |
933 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n" ) |
934 | |
935 | // Transposition of 4x4 blocks, part 2 |
936 | "trn1 v20.2d, v16.2d, v18.2d\n" |
937 | "trn2 v22.2d, v16.2d, v18.2d\n" |
938 | "trn1 v21.2d, v17.2d, v19.2d\n" |
939 | "trn2 v23.2d, v17.2d, v19.2d\n" |
940 | "cmp w1, w2\n" |
941 | |
942 | // Store the block to the packed matrix and, interleaved with |
943 | // that, compute sums using sdot instructions. |
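      // The SDOT instructions are emitted as raw ".word" encodings so that
      // this file still assembles with toolchains whose assemblers do not
      // know the dotprod extension; the mnemonic in each trailing comment is
      // what the encoding stands for.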
944 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
945 | "str q20, [%[packed_ptr], #0]\n" |
946 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
947 | "str q21, [%[packed_ptr], #32]\n" |
948 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
949 | "str q22, [%[packed_ptr], #64]\n" |
950 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
951 | "str q23, [%[packed_ptr], #96]\n" |
952 | "add %[packed_ptr], %[packed_ptr], #128\n" |
953 | // End of main loop on blocks of 16 rows. |
954 | "bne 1b\n" |
955 | |
956 | // Code handling the last already-loaded block of 16 rows. |
957 | "2:\n" |
958 | // Process the last loaded full 16x4 block. |
959 | "ins v0.d[1], x10\n" |
960 | "ins v1.d[1], x11\n" |
961 | "ins v2.d[1], x12\n" |
962 | "ins v3.d[1], x13\n" |
963 | "eor v0.16b, v0.16b, v26.16b\n" |
964 | "eor v1.16b, v1.16b, v26.16b\n" |
965 | "eor v2.16b, v2.16b, v26.16b\n" |
966 | "eor v3.16b, v3.16b, v26.16b\n" |
967 | |
968 | "trn1 v16.4s, v0.4s, v1.4s\n" |
969 | "trn2 v17.4s, v0.4s, v1.4s\n" |
970 | "trn1 v18.4s, v2.4s, v3.4s\n" |
971 | "trn2 v19.4s, v2.4s, v3.4s\n" |
972 | |
973 | "trn1 v20.2d, v16.2d, v18.2d\n" |
974 | "trn2 v22.2d, v16.2d, v18.2d\n" |
975 | "trn1 v21.2d, v17.2d, v19.2d\n" |
976 | "trn2 v23.2d, v17.2d, v19.2d\n" |
977 | |
978 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
979 | "str q20, [%[packed_ptr], #0]\n" |
980 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
981 | "str q21, [%[packed_ptr], #32]\n" |
982 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
983 | "str q22, [%[packed_ptr], #64]\n" |
984 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
985 | "str q23, [%[packed_ptr], #96]\n" |
986 | "add %[packed_ptr], %[packed_ptr], #128\n" |
987 | |
988 | // End of code handling full blocks of 16 rows. |
989 | // Now we handle any remaining rows. |
990 | "3:\n" |
991 | // Let w2 be the number of rows left to handle. |
992 | "ands w2, %w[rows], #15\n" |
993 | // If w2==0, there are no remaining rows, jump to the end. |
994 | "beq 4f\n" |
995 | // Zero out a 16x4 block in registers, which we'll partially overwrite |
996 | // with any remaining rows. |
997 | "dup v0.16b, %w[src_zero_point]\n" |
998 | "dup v1.16b, %w[src_zero_point]\n" |
999 | "dup v2.16b, %w[src_zero_point]\n" |
1000 | "dup v3.16b, %w[src_zero_point]\n" |
1001 | #define RUY_LOAD_ONE_ROW(R) \ |
1002 | "cmp w2, #" #R "\n" \ |
1003 | "beq 5f\n" \ |
1004 | "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ |
1005 | "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ |
1006 | "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ |
1007 | "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" |
1008 | |
1009 | RUY_LOAD_ONE_ROW(0) |
1010 | RUY_LOAD_ONE_ROW(1) |
1011 | RUY_LOAD_ONE_ROW(2) |
1012 | RUY_LOAD_ONE_ROW(3) |
1013 | RUY_LOAD_ONE_ROW(4) |
1014 | RUY_LOAD_ONE_ROW(5) |
1015 | RUY_LOAD_ONE_ROW(6) |
1016 | RUY_LOAD_ONE_ROW(7) |
1017 | RUY_LOAD_ONE_ROW(8) |
1018 | RUY_LOAD_ONE_ROW(9) |
1019 | RUY_LOAD_ONE_ROW(10) |
1020 | RUY_LOAD_ONE_ROW(11) |
1021 | RUY_LOAD_ONE_ROW(12) |
1022 | RUY_LOAD_ONE_ROW(13) |
1023 | RUY_LOAD_ONE_ROW(14) |
1024 | // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. |
1025 | #undef RUY_LOAD_ONE_ROW |
1026 | |
1027 | "5:\n" |
1028 | // Process the last zero-padded 16x4 block. |
1029 | "eor v0.16b, v0.16b, v26.16b\n" |
1030 | "eor v1.16b, v1.16b, v26.16b\n" |
1031 | "eor v2.16b, v2.16b, v26.16b\n" |
1032 | "eor v3.16b, v3.16b, v26.16b\n" |
1033 | |
1034 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1035 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1036 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1037 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1038 | |
1039 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1040 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1041 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1042 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1043 | |
1044 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1045 | "str q20, [%[packed_ptr], #0]\n" |
1046 | "cmp w2, #4\n" |
1047 | "ble 4f\n" |
1048 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1049 | "str q21, [%[packed_ptr], #32]\n" |
1050 | "cmp w2, #8\n" |
1051 | "ble 4f\n" |
1052 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1053 | "str q22, [%[packed_ptr], #64]\n" |
1054 | "cmp w2, #12\n" |
1055 | "ble 4f\n" |
1056 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1057 | "str q23, [%[packed_ptr], #96]\n" |
1058 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1059 | |
1060 | "4:\n" |
1061 | |
1062 | // Reduction of the registers used to accumulate sums. |
1063 | "add v28.4s, v28.4s, v29.4s\n" |
1064 | "add v30.4s, v30.4s, v31.4s\n" |
1065 | "add v28.4s, v28.4s, v30.4s\n" |
1066 | |
1067 | // Store the sums. |
1068 | "cmp %[sums_ptr], #0\n" |
1069 | "beq 6f\n" |
1070 | "st1 {v28.4s}, [%[sums_ptr]], #16\n" |
1071 | "6:\n" |
1072 | // clang-format on |
1073 | |
1074 | : [ src_ptr0 ] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), [src_ptr2] "+r" (src_ptr2), |
1075 | [src_ptr3] "+r" (src_ptr3), [packed_ptr] "+r" (packed_ptr), [sums_ptr] "+r" (sums_ptr) |
1076 | : [ src_inc0 ] "r" (static_cast<std::int64_t>(src_inc0)), [ src_inc1 ] "r" (static_cast<std::int64_t>(src_inc1)), |
1077 | [ src_inc2 ] "r" (static_cast<std::int64_t>(src_inc2)), [ src_inc3 ] "r" (static_cast<std::int64_t>(src_inc3)), |
1078 | [rows] "r" (src_rows), |
1079 | [src_zero_point] "r" (static_cast<int>(src_zero_point)), |
1080 | [input_xor] "r" (input_xor) |
1081 | : "cc" , "memory" , "x1" , "x2" , "x10" , "x11" , "x12" , "x13" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , |
1082 | "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , |
1083 | "v22" , "v23" , "v24" , "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
1084 | } |
1085 | |
1086 | void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, |
1087 | const void* src_ptr2, const void* src_ptr3, |
1088 | int src_inc0, int src_inc1, int src_inc2, |
1089 | int src_inc3, int src_rows, |
1090 | int src_zero_point, std::int8_t* packed_ptr, |
1091 | std::int32_t* sums_ptr, int input_xor) { |
1092 | profiler::ScopeLabel label("Pack (kNeonDotprod)" ); |
1093 | asm volatile( |
1094 | // clang-format off |
1095 | // v26 will be the vector to XOR input values with to perform |
1096 | // any input data type conversion (e.g. uint8 to int8). |
1097 | "dup v26.16b, %w[input_xor]\n" |
1098 | // v27 will be filled with 1's. It will be used as an operand |
1099 | // to SDOT to compute the sums. |
1100 | "mov w1, #1\n" |
1101 | "dup v27.16b, w1\n" |
1102 | // w1 will be the number of rows already loaded. |
1103 | "mov w1, #0\n" |
      // v28--v31 will be used to accumulate the sums.
1105 | "movi v28.4s, #0\n" |
1106 | "movi v29.4s, #0\n" |
1107 | "movi v30.4s, #0\n" |
1108 | "movi v31.4s, #0\n" |
1109 | |
1110 | // 4x partially unrolled code processing blocks of 64 rows. |
1111 | // Read the original loop below first, it has more comments. |
1112 | #if RUY_OPT(MAX_STREAMING) |
1113 | // Let w2 be `rows` rounded down to multiple of 64. |
1114 | // Each iteration of this 4x partially unrolled loop handles |
1115 | // 64 rows. |
1116 | "ands w2, %w[rows], #-64\n" |
1117 | // If there are no full blocks of 64 rows to process, jump to |
1118 | // the main loop below handling 16 rows per iteration. |
1119 | "beq 9f\n" |
1120 | // Load the first block of 64 rows. |
1121 | "add w1, w1, #64\n" |
1122 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1123 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1124 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1125 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1126 | "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1127 | "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1128 | "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1129 | "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1130 | "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1131 | "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1132 | "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1133 | "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1134 | "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1135 | "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1136 | // Was that the last full block of 64 rows to load? |
1137 | "cmp w1, w2\n" |
1138 | "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1139 | "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1140 | // Then jump to the end of the 64-rows-at-a-time code. |
1141 | "beq 8f\n" |
1142 | |
1143 | // Start of the main 4x partially unrolled loop. |
1144 | "7:\n" |
1145 | // Process rows 0 -- 15 out of 64. |
1146 | "eor v0.16b, v0.16b, v26.16b\n" |
1147 | "eor v1.16b, v1.16b, v26.16b\n" |
1148 | "eor v2.16b, v2.16b, v26.16b\n" |
1149 | "eor v3.16b, v3.16b, v26.16b\n" |
1150 | |
1151 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1152 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1153 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1154 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1155 | |
1156 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1157 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1158 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1159 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1160 | "add w1, w1, #16\n" |
1161 | |
1162 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1163 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1164 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1165 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1166 | |
1167 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1168 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1169 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1170 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1171 | |
1172 | "str q20, [%[packed_ptr], #0]\n" |
1173 | "str q21, [%[packed_ptr], #32]\n" |
1174 | "str q22, [%[packed_ptr], #64]\n" |
1175 | "str q23, [%[packed_ptr], #96]\n" |
1176 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1177 | |
1178 | // Process rows 16 -- 31 out of 64. |
1179 | "eor v4.16b, v4.16b, v26.16b\n" |
1180 | "eor v5.16b, v5.16b, v26.16b\n" |
1181 | "eor v6.16b, v6.16b, v26.16b\n" |
1182 | "eor v7.16b, v7.16b, v26.16b\n" |
1183 | |
1184 | "trn1 v16.4s, v4.4s, v5.4s\n" |
1185 | "trn2 v17.4s, v4.4s, v5.4s\n" |
1186 | "trn1 v18.4s, v6.4s, v7.4s\n" |
1187 | "trn2 v19.4s, v6.4s, v7.4s\n" |
1188 | |
1189 | "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1190 | "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1191 | "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1192 | "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1193 | "add w1, w1, #16\n" |
1194 | |
1195 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1196 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1197 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1198 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1199 | |
1200 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1201 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1202 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1203 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1204 | |
1205 | "str q20, [%[packed_ptr], #0]\n" |
1206 | "str q21, [%[packed_ptr], #32]\n" |
1207 | "str q22, [%[packed_ptr], #64]\n" |
1208 | "str q23, [%[packed_ptr], #96]\n" |
1209 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1210 | |
1211 | // Process rows 32 -- 47 out of 64. |
1212 | "eor v8.16b, v8.16b, v26.16b\n" |
1213 | "eor v9.16b, v9.16b, v26.16b\n" |
1214 | "eor v10.16b, v10.16b, v26.16b\n" |
1215 | "eor v11.16b, v11.16b, v26.16b\n" |
1216 | |
1217 | "trn1 v16.4s, v8.4s, v9.4s\n" |
1218 | "trn2 v17.4s, v8.4s, v9.4s\n" |
1219 | "trn1 v18.4s, v10.4s, v11.4s\n" |
1220 | "trn2 v19.4s, v10.4s, v11.4s\n" |
1221 | |
1222 | "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1223 | "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1224 | "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1225 | "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1226 | "add w1, w1, #16\n" |
1227 | |
1228 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1229 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1230 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1231 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1232 | |
1233 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1234 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1235 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1236 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1237 | |
1238 | "str q20, [%[packed_ptr], #0]\n" |
1239 | "str q21, [%[packed_ptr], #32]\n" |
1240 | "str q22, [%[packed_ptr], #64]\n" |
1241 | "str q23, [%[packed_ptr], #96]\n" |
1242 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1243 | |
1244 | // Process rows 48 -- 63 out of 64. |
1245 | "eor v12.16b, v12.16b, v26.16b\n" |
1246 | "eor v13.16b, v13.16b, v26.16b\n" |
1247 | "eor v14.16b, v14.16b, v26.16b\n" |
1248 | "eor v15.16b, v15.16b, v26.16b\n" |
1249 | |
1250 | "trn1 v16.4s, v12.4s, v13.4s\n" |
1251 | "trn2 v17.4s, v12.4s, v13.4s\n" |
1252 | "trn1 v18.4s, v14.4s, v15.4s\n" |
1253 | "trn2 v19.4s, v14.4s, v15.4s\n" |
1254 | |
1255 | "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1256 | "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1257 | "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1258 | "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1259 | "add w1, w1, #16\n" |
1260 | |
1261 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1262 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1263 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1264 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1265 | |
1266 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1267 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1268 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1269 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1270 | |
1271 | "cmp w1, w2\n" |
1272 | "str q20, [%[packed_ptr], #0]\n" |
1273 | "str q21, [%[packed_ptr], #32]\n" |
1274 | "str q22, [%[packed_ptr], #64]\n" |
1275 | "str q23, [%[packed_ptr], #96]\n" |
1276 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1277 | |
1278 | // End of main 4x partially unrolled loop. |
1279 | "bne 7b\n" |
1280 | |
1281 | // Last part of the 4x partially unrolled code: |
1282 | // handle the last already-loaded 64 rows. |
1283 | "8:\n" |
1284 | |
1285 | // Process rows 0 -- 15 out of 64. |
1286 | "eor v0.16b, v0.16b, v26.16b\n" |
1287 | "eor v1.16b, v1.16b, v26.16b\n" |
1288 | "eor v2.16b, v2.16b, v26.16b\n" |
1289 | "eor v3.16b, v3.16b, v26.16b\n" |
1290 | |
1291 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1292 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1293 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1294 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1295 | |
1296 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1297 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1298 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1299 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1300 | |
1301 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1302 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1303 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1304 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1305 | |
1306 | "str q20, [%[packed_ptr], #0]\n" |
1307 | "str q21, [%[packed_ptr], #32]\n" |
1308 | "str q22, [%[packed_ptr], #64]\n" |
1309 | "str q23, [%[packed_ptr], #96]\n" |
1310 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1311 | |
1312 | // Process rows 16 -- 31 out of 64. |
1313 | "eor v4.16b, v4.16b, v26.16b\n" |
1314 | "eor v5.16b, v5.16b, v26.16b\n" |
1315 | "eor v6.16b, v6.16b, v26.16b\n" |
1316 | "eor v7.16b, v7.16b, v26.16b\n" |
1317 | |
1318 | "trn1 v16.4s, v4.4s, v5.4s\n" |
1319 | "trn2 v17.4s, v4.4s, v5.4s\n" |
1320 | "trn1 v18.4s, v6.4s, v7.4s\n" |
1321 | "trn2 v19.4s, v6.4s, v7.4s\n" |
1322 | |
1323 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1324 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1325 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1326 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1327 | |
1328 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1329 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1330 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1331 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1332 | |
1333 | "str q20, [%[packed_ptr], #0]\n" |
1334 | "str q21, [%[packed_ptr], #32]\n" |
1335 | "str q22, [%[packed_ptr], #64]\n" |
1336 | "str q23, [%[packed_ptr], #96]\n" |
1337 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1338 | |
1339 | // Process rows 32 -- 47 out of 64. |
1340 | "eor v8.16b, v8.16b, v26.16b\n" |
1341 | "eor v9.16b, v9.16b, v26.16b\n" |
1342 | "eor v10.16b, v10.16b, v26.16b\n" |
1343 | "eor v11.16b, v11.16b, v26.16b\n" |
1344 | |
1345 | "trn1 v16.4s, v8.4s, v9.4s\n" |
1346 | "trn2 v17.4s, v8.4s, v9.4s\n" |
1347 | "trn1 v18.4s, v10.4s, v11.4s\n" |
1348 | "trn2 v19.4s, v10.4s, v11.4s\n" |
1349 | |
1350 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1351 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1352 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1353 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1354 | |
1355 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1356 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1357 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1358 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1359 | |
1360 | "str q20, [%[packed_ptr], #0]\n" |
1361 | "str q21, [%[packed_ptr], #32]\n" |
1362 | "str q22, [%[packed_ptr], #64]\n" |
1363 | "str q23, [%[packed_ptr], #96]\n" |
1364 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1365 | |
1366 | // Process rows 48 -- 63 out of 64. |
1367 | "eor v12.16b, v12.16b, v26.16b\n" |
1368 | "eor v13.16b, v13.16b, v26.16b\n" |
1369 | "eor v14.16b, v14.16b, v26.16b\n" |
1370 | "eor v15.16b, v15.16b, v26.16b\n" |
1371 | |
1372 | "trn1 v16.4s, v12.4s, v13.4s\n" |
1373 | "trn2 v17.4s, v12.4s, v13.4s\n" |
1374 | "trn1 v18.4s, v14.4s, v15.4s\n" |
1375 | "trn2 v19.4s, v14.4s, v15.4s\n" |
1376 | |
1377 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1378 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1379 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1380 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1381 | |
1382 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1383 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1384 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1385 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1386 | |
1387 | "str q20, [%[packed_ptr], #0]\n" |
1388 | "str q21, [%[packed_ptr], #32]\n" |
1389 | "str q22, [%[packed_ptr], #64]\n" |
1390 | "str q23, [%[packed_ptr], #96]\n" |
1391 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1392 | |
1393 | "9:\n" |
1394 | #endif // #if RUY_OPT(MAX_STREAMING) |
1395 | // End of 4x partially unrolled code processing blocks of 64 rows. |
1396 | |
1397 | // Main part of the code, processing blocks of 16 rows. |
1398 | |
1399 | // Let w2 be `rows` rounded down to multiple of 16. |
1400 | "and w2, %w[rows], #-16\n" |
1401 | // If there are no full blocks of 16 rows to process, jump to the |
1402 | // code handling the last < 16 rows. |
1403 | "cmp w1, w2\n" |
1404 | "beq 3f\n" |
1405 | |
1406 | // Load the first block of 16 rows. |
1407 | "add w1, w1, #16\n" |
1408 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1409 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
      // Check if this was the only full block of 16 rows to load.
1411 | "cmp w1, w2\n" |
1412 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1413 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1414 | // In that case, jump to the code handling the last loaded block of |
1415 | // 16 rows. |
1416 | "beq 2f\n" |
1417 | // Main loop processing blocks of 16 rows. |
1418 | "1:\n" |
1419 | // Input type conversion (e.g. uint8->int8). |
1420 | "eor v0.16b, v0.16b, v26.16b\n" |
1421 | "eor v1.16b, v1.16b, v26.16b\n" |
1422 | "eor v2.16b, v2.16b, v26.16b\n" |
1423 | "eor v3.16b, v3.16b, v26.16b\n" |
1424 | // Transposition of 4x4 blocks, part 1 |
1425 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1426 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1427 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1428 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1429 | // Load the next 16 rows |
1430 | "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" |
1431 | "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" |
1432 | "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" |
1433 | "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" |
1434 | "add w1, w1, #16\n" |
1435 | // Transposition of 4x4 blocks, part 2 |
1436 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1437 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1438 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1439 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1440 | // Compute sums using sdot instructions. |
1441 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1442 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1443 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1444 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1445 | // Store the block to the packed matrix. |
1446 | "str q20, [%[packed_ptr], #0]\n" |
1447 | "str q21, [%[packed_ptr], #32]\n" |
1448 | "cmp w1, w2\n" |
1449 | "str q22, [%[packed_ptr], #64]\n" |
1450 | "str q23, [%[packed_ptr], #96]\n" |
1451 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1452 | // End of main loop on blocks of 16 rows. |
1453 | "bne 1b\n" |
1454 | |
1455 | // Code handling the last already-loaded block of 16 rows. |
1456 | "2:\n" |
1457 | |
1458 | // Process the last loaded full 16x4 block. |
1459 | "eor v0.16b, v0.16b, v26.16b\n" |
1460 | "eor v1.16b, v1.16b, v26.16b\n" |
1461 | "eor v2.16b, v2.16b, v26.16b\n" |
1462 | "eor v3.16b, v3.16b, v26.16b\n" |
1463 | |
1464 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1465 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1466 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1467 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1468 | |
1469 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1470 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1471 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1472 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1473 | |
1474 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1475 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1476 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1477 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1478 | |
1479 | "str q20, [%[packed_ptr], #0]\n" |
1480 | "str q21, [%[packed_ptr], #32]\n" |
1481 | "str q22, [%[packed_ptr], #64]\n" |
1482 | "str q23, [%[packed_ptr], #96]\n" |
1483 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1484 | |
1485 | // End of code handling full blocks of 16 rows. |
1486 | // Now we handle any remaining rows. |
1487 | "3:\n" |
1488 | // Let w2 be the number of rows left to handle. |
1489 | "ands w2, %w[rows], #15\n" |
1490 | // If w2==0, there are no remaining rows, jump to the end. |
1491 | "beq 4f\n" |
// Initialize a 16x4 block in registers to the zero-point value; we'll
// partially overwrite it with any remaining rows.
1494 | "dup v0.16b, %w[src_zero_point]\n" |
1495 | "dup v1.16b, %w[src_zero_point]\n" |
1496 | "dup v2.16b, %w[src_zero_point]\n" |
1497 | "dup v3.16b, %w[src_zero_point]\n" |
1498 | #define RUY_LOAD_ONE_ROW(R) \ |
1499 | "cmp w2, #" #R "\n" \ |
1500 | "beq 5f\n" \ |
1501 | "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ |
1502 | "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ |
1503 | "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ |
1504 | "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" |
1505 | |
1506 | RUY_LOAD_ONE_ROW(0) |
1507 | RUY_LOAD_ONE_ROW(1) |
1508 | RUY_LOAD_ONE_ROW(2) |
1509 | RUY_LOAD_ONE_ROW(3) |
1510 | RUY_LOAD_ONE_ROW(4) |
1511 | RUY_LOAD_ONE_ROW(5) |
1512 | RUY_LOAD_ONE_ROW(6) |
1513 | RUY_LOAD_ONE_ROW(7) |
1514 | RUY_LOAD_ONE_ROW(8) |
1515 | RUY_LOAD_ONE_ROW(9) |
1516 | RUY_LOAD_ONE_ROW(10) |
1517 | RUY_LOAD_ONE_ROW(11) |
1518 | RUY_LOAD_ONE_ROW(12) |
1519 | RUY_LOAD_ONE_ROW(13) |
1520 | RUY_LOAD_ONE_ROW(14) |
1521 | // Here we know that w2==15, so RUY_LOAD_ONE_ROW(15) would be a no-op. |
1522 | #undef RUY_LOAD_ONE_ROW |
1523 | |
1524 | "5:\n" |
1525 | // Process the last zero-padded 16x4 block. |
1526 | "eor v0.16b, v0.16b, v26.16b\n" |
1527 | "eor v1.16b, v1.16b, v26.16b\n" |
1528 | "eor v2.16b, v2.16b, v26.16b\n" |
1529 | "eor v3.16b, v3.16b, v26.16b\n" |
1530 | |
1531 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1532 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1533 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1534 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1535 | |
1536 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1537 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1538 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1539 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1540 | |
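// Compute the sums and store the packed data, but only for the 4-row
// groups that actually contain leftover rows: after each group, w2 is
// compared against 4, 8, 12 and the remaining groups are skipped once we
// are past the last leftover row.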
1541 | ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" |
1542 | "str q20, [%[packed_ptr], #0]\n" |
1543 | "cmp w2, #4\n" |
1544 | "ble 4f\n" |
1545 | ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" |
1546 | "str q21, [%[packed_ptr], #32]\n" |
1547 | "cmp w2, #8\n" |
1548 | "ble 4f\n" |
1549 | ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" |
1550 | "str q22, [%[packed_ptr], #64]\n" |
1551 | "cmp w2, #12\n" |
1552 | "ble 4f\n" |
1553 | ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" |
1554 | "str q23, [%[packed_ptr], #96]\n" |
1555 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1556 | |
1557 | "4:\n" |
1558 | |
1559 | // Reduction of the registers used to accumulate sums. |
1560 | "add v28.4s, v28.4s, v29.4s\n" |
1561 | "add v30.4s, v30.4s, v31.4s\n" |
1562 | "add v28.4s, v28.4s, v30.4s\n" |
1563 | |
1564 | // Store the sums. |
1565 | "cmp %[sums_ptr], #0\n" |
1566 | "beq 6f\n" |
1567 | "st1 {v28.4s}, [%[sums_ptr]], #16\n" |
1568 | "6:\n" |
1569 | // clang-format on |
1570 | |
1571 | : [src_ptr0] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), |
1572 | [src_ptr2] "+r" (src_ptr2), [src_ptr3] "+r" (src_ptr3), |
1573 | [packed_ptr] "+r" (packed_ptr), [sums_ptr] "+r" (sums_ptr) |
1574 | : [src_inc0] "r" (static_cast<std::int64_t>(src_inc0)), |
1575 | [src_inc1] "r" (static_cast<std::int64_t>(src_inc1)), |
1576 | [src_inc2] "r" (static_cast<std::int64_t>(src_inc2)), |
1577 | [src_inc3] "r" (static_cast<std::int64_t>(src_inc3)), |
1578 | [rows] "r" (src_rows), |
1579 | [src_zero_point] "r" (static_cast<int>(src_zero_point)), |
1580 | [input_xor] "r" (input_xor) |
1581 | : "cc" , "memory" , "x1" , "x2" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , |
1582 | "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , "v16" , |
1583 | "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , "v26" , |
1584 | "v27" , "v28" , "v29" , "v30" , "v31" ); |
1585 | } |
1586 | |
1587 | void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1, |
1588 | const void* src_ptr2, const void* src_ptr3, |
1589 | int src_inc0, int src_inc1, int src_inc2, |
1590 | int src_inc3, int src_cols, |
1591 | int src_zero_point, std::int8_t* packed_ptr, |
1592 | int packed_stride, std::int32_t* sums_ptr, |
1593 | int input_xor) { |
1594 | profiler::ScopeLabel label("Pack (kNeonDotprod, from row-major)" ); |
1595 | asm volatile( |
1596 | // clang-format off |
1597 | // Prefetch data. This was tuned on Cortex-A55-rev1 cores. |
1598 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0]]\n" ) |
1599 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1]]\n" ) |
1600 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2]]\n" ) |
1601 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3]]\n" ) |
1602 | // Let w0 = (number of columns to compute) - 8. |
1603 | "subs w0, %w[src_cols], 8\n" |
1604 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 64]\n" ) |
1605 | // Let v26 duplicate the input_xor value in all lanes. |
1606 | "dup v26.16b, %w[input_xor]\n" |
1607 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 64]\n" ) |
1608 | // Let v27 be 1 in all lanes. Used with sdot to compute sums. |
1609 | "movi v27.16b, 1\n" |
1610 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 64]\n" ) |
1611 | // If there isn't a full block of 8 columns to load from, jump to the |
1612 | // code after the loop handling leftovers. |
1613 | "blt 2f\n" |
1614 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 64]\n" ) |
1615 | // Main loop, each iteration handles a full block of 8 cols. |
1616 | "1:\n" |
1617 | // Load the 4x8 block from the source matrix, or zero if we're |
1618 | // past the bottom of the source matrix. |
1619 | "ld1 {v0.8b}, [%[src_ptr0]]\n" |
1620 | "ld1 {v1.8b}, [%[src_ptr1]]\n" |
1621 | "ld1 {v2.8b}, [%[src_ptr2]]\n" |
1622 | "ld1 {v3.8b}, [%[src_ptr3]]\n" |
1623 | // Load values from the sums buffer, and start the reordering |
1624 | // of the loaded 4x8 block by interleaving 8bit values. |
1625 | "zip1 v0.16b, v0.16b, v1.16b\n" |
1626 | "ldr q8, [%[sums_ptr], 0]\n" |
1627 | "zip1 v1.16b, v2.16b, v3.16b\n" |
1628 | "ldr q9, [%[sums_ptr], 16]\n" |
1629 | // Finish the reordering of the 4x8 block, putting it into |
1630 | // column-major order. |
1631 | "zip1 v2.8h, v0.8h, v1.8h\n" |
1632 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], 128]\n" ) |
1633 | "zip2 v3.8h, v0.8h, v1.8h\n" |
1634 | // Apply input_xor, i.e. convert source values from uint8 to int8 |
1635 | // if needed. |
1636 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], 128]\n" ) |
1637 | "eor v2.16b, v2.16b, v26.16b\n" |
1638 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], 128]\n" ) |
1639 | "eor v3.16b, v3.16b, v26.16b\n" |
1640 | // Update the sums. |
1641 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], 128]\n" ) |
1642 | ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n" |
1643 | ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n" |
1644 | // Store the column-major 4x8 block to the packed matrix, and |
1645 | // increment some source pointers. |
1646 | "str q2, [%[packed_ptr], 0]\n" |
1647 | "add %[src_ptr0], %[src_ptr0], %w[src_inc0], sxtw\n" |
1648 | "str q3, [%[packed_ptr], 16]\n" |
1649 | "add %[src_ptr1], %[src_ptr1], %w[src_inc1], sxtw\n" |
1650 | // Store the updated sums, and increment the remaining pointers |
1651 | // and the block_col loop index. |
1652 | "st1 {v8.4s}, [%[sums_ptr]], 16\n" |
1653 | "add %[packed_ptr], %[packed_ptr], %[packed_stride], lsl 3\n" |
1654 | "st1 {v9.4s}, [%[sums_ptr]], 16\n" |
1655 | // Advance by 8 columns and set the condition code. |
1656 | "subs w0, w0, 8\n" |
1657 | "add %[src_ptr2], %[src_ptr2], %w[src_inc2], sxtw\n" |
1658 | "add %[src_ptr3], %[src_ptr3], %w[src_inc3], sxtw\n" |
1659 | // End of the main loop. |
1660 | "bge 1b\n" |
1661 | |
1662 | "2:\n" |
1663 | // We add back 8 to w0 so that w0 is the number of columns remaining |
1664 | // to handle. |
1665 | "adds w0, w0, 8\n" |
1666 | // Nothing left? Then jump to the end. |
1667 | "beq 3f\n" |
// Here w0 is between 1 and 7. We initialize v0--v3 to the zero point ...
1669 | "dup v0.8b, %w[src_zero_point]\n" |
1670 | "dup v1.8b, %w[src_zero_point]\n" |
1671 | "dup v2.8b, %w[src_zero_point]\n" |
1672 | "dup v3.8b, %w[src_zero_point]\n" |
1673 | // ... and now we fill lanes one by one with leftover columns. |
1674 | #define RUY_LOAD_ONE_COL(C)\ |
1675 | "cmp w0, " #C "\n" \ |
1676 | "beq 4f\n" \ |
1677 | "ld1 { v0.b }[" #C "], [%[src_ptr0]], #1\n" \ |
1678 | "ld1 { v1.b }[" #C "], [%[src_ptr1]], #1\n" \ |
1679 | "ld1 { v2.b }[" #C "], [%[src_ptr2]], #1\n" \ |
1680 | "ld1 { v3.b }[" #C "], [%[src_ptr3]], #1\n" |
1681 | |
1682 | RUY_LOAD_ONE_COL(0) |
1683 | RUY_LOAD_ONE_COL(1) |
1684 | RUY_LOAD_ONE_COL(2) |
1685 | RUY_LOAD_ONE_COL(3) |
1686 | RUY_LOAD_ONE_COL(4) |
1687 | RUY_LOAD_ONE_COL(5) |
1688 | RUY_LOAD_ONE_COL(6) |
1689 | // Here we know that w0==7, so RUY_LOAD_ONE_COL(7) would be a no-op. |
1690 | #undef RUY_LOAD_ONE_COL |
1691 | |
1692 | "4:\n" |
// The leftover source data is loaded; now we can perform the
1694 | // computation as usual. |
1695 | // Load values from the sums buffer, and start the reordering |
1696 | // of the loaded 4x8 block by interleaving 8bit values. |
1697 | "zip1 v0.16b, v0.16b, v1.16b\n" |
1698 | "ldr q8, [%[sums_ptr], 0]\n" |
1699 | "zip1 v1.16b, v2.16b, v3.16b\n" |
1700 | "ldr q9, [%[sums_ptr], 16]\n" |
1701 | // Finish the reordering of the 4x8 block, putting it into |
1702 | // column-major order. |
1703 | "zip1 v2.8h, v0.8h, v1.8h\n" |
1704 | "zip2 v3.8h, v0.8h, v1.8h\n" |
1705 | // Apply input_xor, i.e. convert source values from uint8 to int8 |
1706 | // if needed. |
1707 | "eor v2.16b, v2.16b, v26.16b\n" |
1708 | "eor v3.16b, v3.16b, v26.16b\n" |
1709 | // Update the sums. |
1710 | ".word 0x4e9b9448 // sdot v8.4s, v2.16b, v27.16b\n" |
1711 | ".word 0x4e9b9469 // sdot v9.4s, v3.16b, v27.16b\n" |
1712 | // Store the column-major 4x8 block to the packed matrix, and |
1713 | // increment some source pointers. |
1714 | "str q2, [%[packed_ptr], 0]\n" |
1715 | "str q3, [%[packed_ptr], 16]\n" |
1716 | // Store the updated sums, and increment the remaining pointers |
1717 | // and the block_col loop index. |
1718 | "st1 {v8.4s}, [%[sums_ptr]], 16\n" |
1719 | "st1 {v9.4s}, [%[sums_ptr]], 16\n" |
1720 | |
1721 | // End label. |
1722 | "3:\n" |
1723 | // clang-format on |
1724 | : [packed_ptr] "+r" (packed_ptr), [sums_ptr] "+r" (sums_ptr), |
1725 | [src_ptr0] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), |
1726 | [src_ptr2] "+r" (src_ptr2), [src_ptr3] "+r" (src_ptr3) |
1727 | : [src_inc0] "r" (src_inc0), [src_inc1] "r" (src_inc1), |
1728 | [src_inc2] "r" (src_inc2), [src_inc3] "r" (src_inc3), |
1729 | [input_xor] "r" (input_xor), [src_zero_point] "r" (src_zero_point), |
1730 | [packed_stride] "r" (static_cast<std::int64_t>(packed_stride)), |
1731 | [src_cols] "r" (src_cols) |
1732 | : "cc" , "memory" , "x0" , "x1" , "x2" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , |
1733 | "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , "v16" , |
1734 | "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , "v23" , "v24" , "v25" , "v26" , |
1735 | "v27" , "v28" , "v29" , "v30" , "v31" ); |
1736 | } |
1737 | |
1738 | void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, |
1739 | const float* src_ptr2, const float* src_ptr3, |
1740 | int src_inc0, int src_inc1, int src_inc2, |
1741 | int src_inc3, int src_rows, float* packed_ptr) { |
1742 | profiler::ScopeLabel label("Pack (kNeon)" ); |
1743 | asm volatile( |
1744 | // clang-format off |
1745 | // w1 will be the number of rows already loaded. |
1746 | "mov w1, #0\n" |
1747 | // Let w2 be `rows` rounded down to multiple of 4. |
1748 | "ands w2, %w[rows], #-4\n" |
1749 | // If there are no full blocks of 4 rows to process, jump to the |
1750 | // code handling the last < 4 rows. |
1751 | "beq 3f\n" |
// Load the first block of 4 rows.
1753 | "add w1, w1, #4\n" |
1754 | "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" |
1755 | "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" |
1756 | // Check if these were the only full block of 4 rows to load. |
1757 | "cmp w1, w2\n" |
1758 | "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" |
1759 | "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" |
1760 | // In that case, jump to the code handling the last loaded block of |
1761 | // 4 rows. |
1762 | "beq 2f\n" |
1763 | // Main loop processing blocks of 4 rows. |
1764 | "1:\n" |
1765 | // Advance by 4 rows. |
1766 | "add w1, w1, #4\n" |
1767 | // Transposition of the already-loaded 4x4 block, part 1. |
1768 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1769 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1770 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1771 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1772 | // Load the next 4x4 block. |
1773 | "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" |
1774 | "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" |
1775 | "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" |
1776 | "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" |
1777 | // Transposition of the already-loaded 4x4 block, part 2. |
1778 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1779 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1780 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1781 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1782 | // Was this the last full 4x4 block to load? |
1783 | "cmp w1, w2\n" |
1784 | // Store the transposed 4x4 block. |
1785 | "str q20, [%[packed_ptr], #0]\n" |
1786 | "str q21, [%[packed_ptr], #32]\n" |
1787 | "str q22, [%[packed_ptr], #64]\n" |
1788 | "str q23, [%[packed_ptr], #96]\n" |
1789 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1790 | // End of main loop on 4x4 blocks. |
1791 | "bne 1b\n" |
1792 | |
1793 | // Code handling the last already-loaded 4x4 block. |
1794 | "2:\n" |
1795 | |
1796 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1797 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1798 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1799 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1800 | |
1801 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1802 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1803 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1804 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1805 | |
1806 | "str q20, [%[packed_ptr], #0]\n" |
1807 | "str q21, [%[packed_ptr], #32]\n" |
1808 | "str q22, [%[packed_ptr], #64]\n" |
1809 | "str q23, [%[packed_ptr], #96]\n" |
1810 | "add %[packed_ptr], %[packed_ptr], #128\n" |
1811 | |
1812 | // End of code handling full 4x4 blocks. |
1813 | // Now we handle any remaining rows. |
1814 | "3:\n" |
1815 | // Let w2 be the number of rows left to handle. |
1816 | "ands w2, %w[rows], #3\n" |
1817 | // If w2==0, there are no remaining rows, jump to the end. |
1818 | "beq 4f\n" |
1819 | // Zero out a 4x4 block in registers, which we'll partially overwrite |
1820 | // with any remaining rows. |
1821 | "movi v0.16b, #0\n" |
1822 | "movi v1.16b, #0\n" |
1823 | "movi v2.16b, #0\n" |
1824 | "movi v3.16b, #0\n" |
1825 | #define RUY_LOAD_ONE_ROW(R) \ |
1826 | "cmp w2, #" #R "\n" \ |
1827 | "beq 5f\n" \ |
1828 | "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ |
1829 | "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ |
1830 | "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ |
1831 | "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" |
1832 | |
1833 | RUY_LOAD_ONE_ROW(0) |
1834 | RUY_LOAD_ONE_ROW(1) |
1835 | RUY_LOAD_ONE_ROW(2) |
1836 | // Here we know that w2==3, so RUY_LOAD_ONE_ROW(3) would be a no-op. |
1837 | #undef RUY_LOAD_ONE_ROW |
1838 | "5:\n" |
1839 | |
1840 | // Transpose that last zero-padded 4x4 block. |
1841 | "trn1 v16.4s, v0.4s, v1.4s\n" |
1842 | "trn2 v17.4s, v0.4s, v1.4s\n" |
1843 | "trn1 v18.4s, v2.4s, v3.4s\n" |
1844 | "trn2 v19.4s, v2.4s, v3.4s\n" |
1845 | |
1846 | "trn1 v20.2d, v16.2d, v18.2d\n" |
1847 | "trn2 v22.2d, v16.2d, v18.2d\n" |
1848 | "trn1 v21.2d, v17.2d, v19.2d\n" |
1849 | "trn2 v23.2d, v17.2d, v19.2d\n" |
1850 | |
1851 | // Store that last zero-padded block to the packed matrix. |
1852 | "mov x1, #32\n" |
1853 | #define RUY_STORE_ONE_ROW(ROW, REGISTER) \ |
1854 | "cmp w2, #" #ROW "\n" \ |
1855 | "beq 4f\n" \ |
1856 | "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" |
1857 | |
1858 | RUY_STORE_ONE_ROW(0, v20) |
1859 | RUY_STORE_ONE_ROW(1, v21) |
1860 | RUY_STORE_ONE_ROW(2, v22) |
1861 | RUY_STORE_ONE_ROW(3, v23) |
1862 | |
1863 | #undef RUY_STORE_ONE_ROW |
1864 | |
1865 | "4:\n" |
1866 | |
1867 | // clang-format on |
1868 | |
1869 | : [src_ptr0] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), |
1870 | [src_ptr2] "+r" (src_ptr2), [src_ptr3] "+r" (src_ptr3), |
1871 | [packed_ptr] "+r" (packed_ptr) |
1872 | : [src_inc0] "r" (static_cast<std::int64_t>(src_inc0)), |
1873 | [src_inc1] "r" (static_cast<std::int64_t>(src_inc1)), |
1874 | [src_inc2] "r" (static_cast<std::int64_t>(src_inc2)), |
1875 | [src_inc3] "r" (static_cast<std::int64_t>(src_inc3)), |
1876 | [rows] "r" (src_rows) |
1877 | : "cc" , "memory" , "x1" , "x2" , "x10" , "x11" , "x12" , "x13" , "v0" , "v1" , |
1878 | "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , "v10" , "v11" , "v12" , |
1879 | "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , "v22" , |
1880 | "v23" , "v24" , "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
1881 | } |
1882 | #endif |
1883 | |
1884 | #if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM) |
1885 | void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1, |
1886 | const float* src_ptr2, const float* src_ptr3, |
1887 | int src_inc, int src_rows, float* packed_ptr, |
1888 | int output_stride) { |
1889 | profiler::ScopeLabel label("Pack (kNeon)" ); |
1890 | asm volatile( |
1891 | // clang-format off |
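// r1 will be the number of rows already loaded.
// Let r2 be `rows` rounded down to a multiple of 4; if there are no full
// blocks of 4 rows to process, jump to the code handling the leftover rows.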
1892 | "mov r1, #0\n" |
1893 | "and r2, %[rows], #-4\n" |
1894 | "cmp r1, r2\n" |
1895 | "beq 3f\n" |
1896 | #define RUY_LOAD_FOUR_BY_FOUR() \ |
1897 | /* Load q0 */ \ |
1898 | "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ |
1899 | /* if src_inc0 != 0, add 16 to src_ptr0 */ \ |
1900 | "and r3, %[src_inc], #1\n" \ |
1901 | "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ |
1902 | /* Load q1 */ \ |
1903 | "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ |
/* if src_inc1 != 0, add 16 to src_ptr1 */ \
1905 | "and r3, %[src_inc], #2\n" \ |
1906 | "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ |
1907 | /* Load q2 */ \ |
1908 | "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ |
/* if src_inc2 != 0, add 16 to src_ptr2 */ \
1910 | "and r3, %[src_inc], #4\n" \ |
1911 | "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ |
1912 | /* Load q3 */ \ |
1913 | "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ |
/* if src_inc3 != 0, add 16 to src_ptr3 */ \
1915 | "and r3, %[src_inc], #8\n" \ |
1916 | "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ |
1917 | |
1918 | RUY_LOAD_FOUR_BY_FOUR() |
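// We just loaded the first block of 4 rows; advance the row count and, if
// it was the only full block, skip straight to processing it as the last
// block.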
1919 | "add r1, r1, #4\n" |
1920 | "cmp r1, r2\n" |
1921 | |
1922 | "beq 2f\n" |
1923 | |
1924 | "1:\n" |
1925 | "add r1, r1, #4\n" |
1926 | |
1927 | // Transpose 4x4 matrix. |
1928 | "vzip.32 q0, q1\n" |
1929 | "vzip.32 q2, q3\n" |
1930 | |
1931 | "vtrn.32 q0, q2\n" |
1932 | "vtrn.32 q1, q3\n" |
1933 | |
1934 | "vzip.32 q0, q2\n" |
1935 | "vzip.32 q1, q3\n" |
1936 | |
1937 | "vmov q8, q0\n" |
1938 | "vmov q9, q1\n" |
1939 | "vmov q10, q2\n" |
1940 | "vmov q11, q3\n" |
1941 | |
1942 | RUY_LOAD_FOUR_BY_FOUR() |
1943 | #undef RUY_LOAD_FOUR_BY_FOUR |
1944 | |
1945 | #define RUY_STORE_FOUR_BY_FOUR() \ |
1946 | /* Store q8, q10, q9, q11 */ \ |
1947 | /* q8 = d16, d17 */ \ |
1948 | "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ |
1949 | /* q10 = d20, d21 */ \ |
1950 | "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ |
1951 | "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ |
1952 | /* q9 = d18, d19 */ \ |
1953 | "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ |
1954 | "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ |
1955 | /* q11 = d22, d23 */ \ |
1956 | "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ |
1957 | "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ |
1958 | "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ |
1959 | |
1960 | RUY_STORE_FOUR_BY_FOUR() |
1961 | "cmp r1, r2\n" |
1962 | |
1963 | "bne 1b\n" |
1964 | |
1965 | "2:\n" |
1966 | |
1967 | // Transpose 4x4 matrix. |
1968 | "vzip.32 q0, q1\n" |
1969 | "vzip.32 q2, q3\n" |
1970 | |
1971 | "vtrn.32 q0, q2\n" |
1972 | "vtrn.32 q1, q3\n" |
1973 | |
1974 | "vzip.32 q0, q2\n" |
1975 | "vzip.32 q1, q3\n" |
1976 | |
1977 | "vmov q8, q0\n" |
1978 | "vmov q9, q1\n" |
1979 | "vmov q10, q2\n" |
1980 | "vmov q11, q3\n" |
1981 | |
1982 | RUY_STORE_FOUR_BY_FOUR() |
1983 | #undef RUY_STORE_FOUR_BY_FOUR |
1984 | "3:\n" |
1985 | |
1986 | "ands r2, %[rows], #3\n" |
1987 | "beq 4f\n" |
1988 | "mov r0, #0\n" |
1989 | // Zero out q0 - q3 |
1990 | "vdup.32 q0, r0\n" |
1991 | "vdup.32 q1, r0\n" |
1992 | "vdup.32 q2, r0\n" |
1993 | "vdup.32 q3, r0\n" |
1994 | #define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ |
1995 | "cmp r2, #" #R "\n" \ |
1996 | "beq 5f\n" \ |
1997 | "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ |
1998 | "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ |
1999 | "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ |
2000 | "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" |
2001 | |
2002 | #define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ |
2003 | "cmp r2, #" #R "\n" \ |
2004 | "beq 5f\n" \ |
2005 | "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ |
2006 | "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ |
2007 | "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ |
2008 | "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" |
2009 | |
2010 | RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) |
2011 | RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) |
2012 | RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) |
2013 | RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) |
2014 | #undef RUY_LOAD_ONE_ROW_SECOND_HALF |
2015 | #undef RUY_LOAD_ONE_ROW_FIRST_HALF |
2016 | "5:\n" |
2017 | |
2018 | // Transpose 4x4 matrix. |
2019 | "vzip.32 q0, q1\n" |
2020 | "vzip.32 q2, q3\n" |
2021 | |
2022 | "vtrn.32 q0, q2\n" |
2023 | "vtrn.32 q1, q3\n" |
2024 | |
2025 | "vzip.32 q0, q2\n" |
2026 | "vzip.32 q1, q3\n" |
2027 | |
2028 | "vmov q8, q0\n" |
2029 | "vmov q9, q1\n" |
2030 | "vmov q10, q2\n" |
2031 | "vmov q11, q3\n" |
2032 | |
2033 | "mov r1, #32\n" |
2034 | |
2035 | #define RUY_STORE_ONE_ROW(ROW, REGISTER) \ |
2036 | "cmp r2, #" #ROW "\n" \ |
2037 | "beq 4f\n" \ |
2038 | "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ |
2039 | "add %[packed_ptr], %[packed_ptr], %[stride]\n" |
2040 | |
2041 | // Store q8 |
2042 | RUY_STORE_ONE_ROW(0, q8) |
2043 | // Store q10 |
2044 | RUY_STORE_ONE_ROW(1, q10) |
2045 | // Store q9 |
2046 | RUY_STORE_ONE_ROW(2, q9) |
2047 | // Store q11 |
2048 | RUY_STORE_ONE_ROW(3, q11) |
2049 | |
2050 | #undef RUY_STORE_ONE_ROW |
2051 | |
2052 | "4:\n" |
2053 | |
2054 | // clang-format on |
2055 | : [ src_ptr0 ] "+r" (src_ptr0), [ src_ptr1 ] "+r" (src_ptr1), |
2056 | [ src_ptr2 ] "+r" (src_ptr2), [ src_ptr3 ] "+r" (src_ptr3), |
2057 | [ packed_ptr ] "+r" (packed_ptr) |
2058 | : [ src_inc ] "r" (static_cast<std::int64_t>(src_inc)), |
2059 | [ rows ] "r" (src_rows), [ stride ] "r" (output_stride) |
2060 | : "cc" , "memory" , "r0" , "r1" , "r2" , "r3" , "q0" , "q1" , "q2" , "q3" , |
2061 | "q4" , "q5" , "q6" , "q7" , "q8" , "q9" , "q10" , "q11" ); |
2062 | } |
2063 | |
#endif // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
2065 | |
2066 | #if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) |
2067 | void PackFloatColMajorForNeonA55ish(const float* src_ptr0, |
2068 | const float* src_ptr1, |
2069 | const float* src_ptr2, |
2070 | const float* src_ptr3, int src_inc0, |
2071 | int src_inc1, int src_inc2, int src_inc3, |
2072 | int src_rows, float* packed_ptr) { |
2073 | profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)" ); |
2074 | |
2075 | asm volatile( |
2076 | // clang-format off |
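// w1 will be the number of rows already loaded. w2 below is `rows` rounded
// down to a multiple of 4; if there are no full blocks of 4 rows to
// process, we jump straight to the code handling the leftover rows.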
2077 | "mov w1, #0\n" |
2078 | |
2079 | "and w2, %w[rows], #-4\n" |
2080 | "cmp w1, w2\n" |
2081 | "beq 3f\n" |
2082 | "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" |
2083 | "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" |
2084 | "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" |
2085 | "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" |
2086 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n" ) |
2087 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n" ) |
2088 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n" ) |
2089 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n" ) |
2090 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n" ) |
2091 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n" ) |
2092 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n" ) |
2093 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n" ) |
2094 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n" ) |
2095 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n" ) |
2096 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n" ) |
2097 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n" ) |
2098 | "add w1, w1, #4\n" |
2099 | "cmp w1, w2\n" |
2100 | |
2101 | "beq 2f\n" |
2102 | |
2103 | "1:\n" |
2104 | "add w1, w1, #4\n" |
2105 | |
2106 | "ldr x10, [%[src_ptr0], #8]\n" |
2107 | "trn1 v16.4s, v0.4s, v1.4s\n" |
2108 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n" ) |
2109 | "ldr x11, [%[src_ptr1], #8]\n" |
2110 | "trn2 v17.4s, v0.4s, v1.4s\n" |
2111 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n" ) |
2112 | "ldr x12, [%[src_ptr2], #8]\n" |
2113 | "trn1 v18.4s, v2.4s, v3.4s\n" |
2114 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n" ) |
2115 | "ldr x13, [%[src_ptr3], #8]\n" |
2116 | "trn2 v19.4s, v2.4s, v3.4s\n" |
2117 | RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n" ) |
2118 | |
2119 | "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" |
2120 | "trn1 v20.2d, v16.2d, v18.2d\n" |
2121 | "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" |
2122 | "trn2 v22.2d, v16.2d, v18.2d\n" |
2123 | "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" |
2124 | "trn1 v21.2d, v17.2d, v19.2d\n" |
2125 | "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" |
2126 | "trn2 v23.2d, v17.2d, v19.2d\n" |
2127 | "cmp w1, w2\n" |
2128 | |
2129 | "ins v0.d[1], x10\n" |
2130 | "str q20, [%[packed_ptr], #0]\n" |
2131 | "ins v1.d[1], x11\n" |
2132 | "str q21, [%[packed_ptr], #32]\n" |
2133 | "ins v2.d[1], x12\n" |
2134 | "str q22, [%[packed_ptr], #64]\n" |
2135 | "ins v3.d[1], x13\n" |
2136 | "str q23, [%[packed_ptr], #96]\n" |
2137 | |
2138 | "add %[packed_ptr], %[packed_ptr], #128\n" |
2139 | |
2140 | "bne 1b\n" |
2141 | |
2142 | "2:\n" |
2143 | |
2144 | "trn1 v16.4s, v0.4s, v1.4s\n" |
2145 | "trn2 v17.4s, v0.4s, v1.4s\n" |
2146 | "trn1 v18.4s, v2.4s, v3.4s\n" |
2147 | "trn2 v19.4s, v2.4s, v3.4s\n" |
2148 | |
2149 | "trn1 v20.2d, v16.2d, v18.2d\n" |
2150 | "trn2 v22.2d, v16.2d, v18.2d\n" |
2151 | "trn1 v21.2d, v17.2d, v19.2d\n" |
2152 | "trn2 v23.2d, v17.2d, v19.2d\n" |
2153 | |
2154 | "str q20, [%[packed_ptr], #0]\n" |
2155 | "str q21, [%[packed_ptr], #32]\n" |
2156 | "str q22, [%[packed_ptr], #64]\n" |
2157 | "str q23, [%[packed_ptr], #96]\n" |
2158 | "add %[packed_ptr], %[packed_ptr], #128\n" |
2159 | |
2160 | "3:\n" |
2161 | |
2162 | "ands w2, %w[rows], #3\n" |
2163 | "beq 4f\n" |
2164 | "movi v0.16b, #0\n" |
2165 | "movi v1.16b, #0\n" |
2166 | "movi v2.16b, #0\n" |
2167 | "movi v3.16b, #0\n" |
2168 | #define RUY_LOAD_ONE_ROW(R) \ |
2169 | "cmp w2, #" #R "\n" \ |
2170 | "beq 5f\n" \ |
2171 | "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ |
2172 | "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ |
2173 | "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ |
2174 | "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" |
2175 | |
2176 | RUY_LOAD_ONE_ROW(0) |
2177 | RUY_LOAD_ONE_ROW(1) |
2178 | RUY_LOAD_ONE_ROW(2) |
2179 | RUY_LOAD_ONE_ROW(3) |
2180 | #undef RUY_LOAD_ONE_ROW |
2181 | "5:\n" |
2182 | |
2183 | "trn1 v16.4s, v0.4s, v1.4s\n" |
2184 | "trn2 v17.4s, v0.4s, v1.4s\n" |
2185 | "trn1 v18.4s, v2.4s, v3.4s\n" |
2186 | "trn2 v19.4s, v2.4s, v3.4s\n" |
2187 | |
2188 | "trn1 v20.2d, v16.2d, v18.2d\n" |
2189 | "trn2 v22.2d, v16.2d, v18.2d\n" |
2190 | "trn1 v21.2d, v17.2d, v19.2d\n" |
2191 | "trn2 v23.2d, v17.2d, v19.2d\n" |
2192 | |
2193 | "mov x1, #32\n" |
2194 | |
2195 | #define RUY_STORE_ONE_ROW(ROW, REGISTER) \ |
2196 | "cmp w2, #" #ROW "\n" \ |
2197 | "beq 4f\n" \ |
2198 | "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" |
2199 | |
2200 | RUY_STORE_ONE_ROW(0, v20) |
2201 | RUY_STORE_ONE_ROW(1, v21) |
2202 | RUY_STORE_ONE_ROW(2, v22) |
2203 | RUY_STORE_ONE_ROW(3, v23) |
2204 | |
2205 | #undef RUY_STORE_ONE_ROW |
2206 | |
2207 | "4:\n" |
2208 | |
2209 | // clang-format on |
2210 | |
2211 | : [ src_ptr0 ] "+r" (src_ptr0), [src_ptr1] "+r" (src_ptr1), [src_ptr2] "+r" (src_ptr2), |
2212 | [src_ptr3] "+r" (src_ptr3), [packed_ptr] "+r" (packed_ptr) |
2213 | : [ src_inc0 ] "r" (static_cast<std::int64_t>(src_inc0)), [src_inc1] "r" (static_cast<std::int64_t>(src_inc1)), [src_inc2] "r" (static_cast<std::int64_t>(src_inc2)), |
2214 | [src_inc3] "r" (static_cast<std::int64_t>(src_inc3)), [rows] "r" (src_rows) |
2215 | : "cc" , "memory" , "x1" , "x2" , "x10" , "x11" , "x12" , "x13" , "v0" , "v1" , "v2" , "v3" , "v4" , "v5" , "v6" , "v7" , "v8" , "v9" , |
2216 | "v10" , "v11" , "v12" , "v13" , "v14" , "v15" , "v16" , "v17" , "v18" , "v19" , "v20" , "v21" , |
2217 | "v22" , "v23" , "v24" , "v25" , "v26" , "v27" , "v28" , "v29" , "v30" , "v31" ); |
2218 | } |
2219 | #endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM) |
2220 | |
2221 | #if RUY_PLATFORM_NEON |
2222 | |
2223 | namespace { |
// transpose_*bit_vals are wrappers around ARM TRN1/TRN2 instructions, allowing
// us to use these instructions as we would in assembly --- this is one
// instance where assembly is more idiomatic than intrinsics.
//
// The way that TRN1 is exposed by the vtrn_* intrinsics makes its usage very
// cumbersome. The issue is that transposing groups of values has been exposed
// only as transposing values of a wider type, so this requires many
// vreinterpret's, and to make it worse, vtrn_* return NEON array types like
// int8x8x2_t for which vreinterpret's are not defined!
2233 | void transpose_8bit_vals(int8x8_t& a, int8x8_t& b) { |
2234 | int8x8x2_t t = vtrn_s8(a, b); |
2235 | a = t.val[0]; |
2236 | b = t.val[1]; |
2237 | } |
2238 | |
2239 | void transpose_16bit_vals(int8x8_t& a, int8x8_t& b) { |
2240 | int16x4x2_t t = vtrn_s16(vreinterpret_s16_s8(a), vreinterpret_s16_s8(b)); |
2241 | a = vreinterpret_s8_s16(t.val[0]); |
2242 | b = vreinterpret_s8_s16(t.val[1]); |
2243 | } |
2244 | |
2245 | void transpose_32bit_vals(int8x8_t& a, int8x8_t& b) { |
2246 | int32x2x2_t t = vtrn_s32(vreinterpret_s32_s8(a), vreinterpret_s32_s8(b)); |
2247 | a = vreinterpret_s8_s32(t.val[0]); |
2248 | b = vreinterpret_s8_s32(t.val[1]); |
2249 | } |
2250 | } // namespace |
2251 | |
2252 | void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride, |
2253 | int src_rows, int src_cols, int block_row, |
2254 | int start_col, int end_col, |
2255 | std::int8_t* packed_ptr, int packed_stride, |
2256 | int packed_zero_point, std::int32_t* sums, |
2257 | int input_xor, int kernel_cols) { |
2258 | profiler::ScopeLabel label("Pack (kNeon, from row-major)" ); |
2259 | |
2260 | int src_end_col = std::min(end_col, src_cols); |
2261 | int col = start_col; |
2262 | for (; col <= src_end_col - 8; col += 8) { |
2263 | // Each iteration of this loop handles 8 columns, and the kernel format |
2264 | // has 16 rows, so each iteration handles a 16x8 block. |
2265 | // |
2266 | // Since the source is row-major, handling 8 columns at a time means |
2267 | // loading only 8 bytes i.e. 64bit from each row. This may seem surprising |
2268 | // on 128bit SIMD like NEON. While we could handle 16 columns at a time, |
2269 | // we prefer to stick with 8 for the following reasons: |
2270 | // 1. The arithmetic (computing sums and transposing data) done on these |
2271 | // values is such that even though we initially start from 64bit vectors, |
2272 | // most of our NEON instructions are full 128bit instructions. For the |
2273 | // sums computation, that is because summing 8bit values requires |
2274 | // expansion to 16bit anyway. For the matrix transposition code, that is |
2275 | // because the ARM ZIP instructions take 64bit of data from two input |
2276 | // registers and zip it into a 128bit output. If we had 128bit of data |
// in each input register, we would need 2x more ARM NEON instructions
2278 | // to zip it. |
2279 | // 2. The main optimization target for this (ARM, 8bit, non-dotprod) |
2280 | // code path is in-order ARM cores such as the Cortex-A53, which prefer |
2281 | // 64bit loads anyway. |
2282 | // 3. Handling only 8 columns at a time limits the size of the final |
2283 | // leftover columns handled with slow scalar code. |
2284 | // |
// This code is not very optimized anyway, as evidenced by the facts that
2286 | // (1) it's written in intrinsics, (2) it's not using separate versions |
2287 | // tuned for different types of CPU cores. At the level of optimization that |
2288 | // it's working at, this seems like a fair compromise. If one wanted to |
2289 | // maximize performance at the cost of more code complexity/size, one could |
2290 | // have code handling 16 columns at a time (maybe limited to |
2291 | // Tuning::kGeneric), then 8, then 4 to minimize the amount of slow |
2292 | // leftovers. |
2293 | // |
2294 | // Load 8 sums in sums0, sums1. |
2295 | int32x4_t sums0 = vld1q_s32(sums + col); |
2296 | int32x4_t sums1 = vld1q_s32(sums + col + 4); |
// Load the 16x8 block (16 rows, 8 columns) from the source matrix.
2298 | // Each val* here is the data from one row. |
2299 | int8x8_t val0, val1, val2, val3, val4, val5, val6, val7, val8, val9, val10, |
2300 | val11, val12, val13, val14, val15; |
2301 | // Even though this function takes a uint8_t* src_ptr, that's only a |
2302 | // type-erased pointer (using uint8_t* so that pointer arithmetic is |
2303 | // allowed). The actual type may be either uint8_t or int8_t. The only |
2304 | // difference it makes is that if it's uint8_t then we need to flip the |
2305 | // sign bit. This is specified by the input_xor value (which is 0x80 if the |
2306 | // input data is uint8_t, and 0x0 otherwise). |
2307 | auto load_and_convert = [=](const std::uint8_t* from) { |
2308 | return vreinterpret_s8_u8(veor_u8(vdup_n_u8(input_xor), vld1_u8(from))); |
2309 | }; |
2310 | if (block_row <= src_rows - 16) { |
2311 | // Load data in the regular case: there are still 16 rows to be read from |
2312 | // the source matrix. |
2313 | val0 = load_and_convert(src_ptr + 0 * src_stride); |
2314 | val1 = load_and_convert(src_ptr + 1 * src_stride); |
2315 | val2 = load_and_convert(src_ptr + 2 * src_stride); |
2316 | val3 = load_and_convert(src_ptr + 3 * src_stride); |
2317 | val4 = load_and_convert(src_ptr + 4 * src_stride); |
2318 | val5 = load_and_convert(src_ptr + 5 * src_stride); |
2319 | val6 = load_and_convert(src_ptr + 6 * src_stride); |
2320 | val7 = load_and_convert(src_ptr + 7 * src_stride); |
2321 | val8 = load_and_convert(src_ptr + 8 * src_stride); |
2322 | val9 = load_and_convert(src_ptr + 9 * src_stride); |
2323 | val10 = load_and_convert(src_ptr + 10 * src_stride); |
2324 | val11 = load_and_convert(src_ptr + 11 * src_stride); |
2325 | val12 = load_and_convert(src_ptr + 12 * src_stride); |
2326 | val13 = load_and_convert(src_ptr + 13 * src_stride); |
2327 | val14 = load_and_convert(src_ptr + 14 * src_stride); |
2328 | val15 = load_and_convert(src_ptr + 15 * src_stride); |
2329 | } else { |
2330 | // Boundary case: there are fewer than 16 rows to be read from the source |
2331 | // matrix. We pad by the zero_point. |
2332 | val0 = vdup_n_s8(packed_zero_point); |
2333 | val1 = val0; |
2334 | val2 = val0; |
2335 | val3 = val0; |
2336 | val4 = val0; |
2337 | val5 = val0; |
2338 | val6 = val0; |
2339 | val7 = val0; |
2340 | val8 = val0; |
2341 | val9 = val0; |
2342 | val10 = val0; |
2343 | val11 = val0; |
2344 | val12 = val0; |
2345 | val13 = val0; |
2346 | val14 = val0; |
2347 | val15 = val0; |
2348 | if (block_row + 0 < src_rows) |
2349 | val0 = load_and_convert(src_ptr + 0 * src_stride); |
2350 | if (block_row + 1 < src_rows) |
2351 | val1 = load_and_convert(src_ptr + 1 * src_stride); |
2352 | if (block_row + 2 < src_rows) |
2353 | val2 = load_and_convert(src_ptr + 2 * src_stride); |
2354 | if (block_row + 3 < src_rows) |
2355 | val3 = load_and_convert(src_ptr + 3 * src_stride); |
2356 | if (block_row + 4 < src_rows) |
2357 | val4 = load_and_convert(src_ptr + 4 * src_stride); |
2358 | if (block_row + 5 < src_rows) |
2359 | val5 = load_and_convert(src_ptr + 5 * src_stride); |
2360 | if (block_row + 6 < src_rows) |
2361 | val6 = load_and_convert(src_ptr + 6 * src_stride); |
2362 | if (block_row + 7 < src_rows) |
2363 | val7 = load_and_convert(src_ptr + 7 * src_stride); |
2364 | if (block_row + 8 < src_rows) |
2365 | val8 = load_and_convert(src_ptr + 8 * src_stride); |
2366 | if (block_row + 9 < src_rows) |
2367 | val9 = load_and_convert(src_ptr + 9 * src_stride); |
2368 | if (block_row + 10 < src_rows) |
2369 | val10 = load_and_convert(src_ptr + 10 * src_stride); |
2370 | if (block_row + 11 < src_rows) |
2371 | val11 = load_and_convert(src_ptr + 11 * src_stride); |
2372 | if (block_row + 12 < src_rows) |
2373 | val12 = load_and_convert(src_ptr + 12 * src_stride); |
2374 | if (block_row + 13 < src_rows) |
2375 | val13 = load_and_convert(src_ptr + 13 * src_stride); |
2376 | if (block_row + 14 < src_rows) |
2377 | val14 = load_and_convert(src_ptr + 14 * src_stride); |
2378 | if (block_row + 15 < src_rows) |
2379 | val15 = load_and_convert(src_ptr + 15 * src_stride); |
2380 | } |
2381 | src_ptr += 8; |
2382 | // Compute sums. |
2383 | int16x8_t sums16_0 = vaddl_s8(val0, val1); |
2384 | int16x8_t sums16_1 = vaddl_s8(val2, val3); |
2385 | sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val4, val5)); |
2386 | sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val6, val7)); |
2387 | sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val8, val9)); |
2388 | sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val10, val11)); |
2389 | sums16_0 = vaddq_s16(sums16_0, vaddl_s8(val12, val13)); |
2390 | sums16_1 = vaddq_s16(sums16_1, vaddl_s8(val14, val15)); |
2391 | int16x8_t sums16 = vaddq_s16(sums16_0, sums16_1); |
2392 | sums0 = vaddw_s16(sums0, vget_low_s16(sums16)); |
2393 | sums1 = vaddw_s16(sums1, vget_high_s16(sums16)); |
2394 | // Store sums. |
2395 | vst1q_s32(sums + col, sums0); |
2396 | vst1q_s32(sums + col + 4, sums1); |
2397 | |
2398 | // Transpose the data, i.e. change the storage order of the |
2399 | // 16x8 block, to convert from the row-major source to the |
2400 | // column-major packed format. |
2401 | // |
2402 | // Before, for i in [0, 15], val<i> is the i-th row. |
2403 | // After, for i in [0, 7], { val<i> val<i+8> } is the i-th column. |
2404 | transpose_8bit_vals(val0, val1); |
2405 | transpose_8bit_vals(val2, val3); |
2406 | transpose_8bit_vals(val4, val5); |
2407 | transpose_8bit_vals(val6, val7); |
2408 | transpose_8bit_vals(val8, val9); |
2409 | transpose_8bit_vals(val10, val11); |
2410 | transpose_8bit_vals(val12, val13); |
2411 | transpose_8bit_vals(val14, val15); |
2412 | transpose_16bit_vals(val0, val2); |
2413 | transpose_16bit_vals(val1, val3); |
2414 | transpose_16bit_vals(val4, val6); |
2415 | transpose_16bit_vals(val5, val7); |
2416 | transpose_16bit_vals(val8, val10); |
2417 | transpose_16bit_vals(val9, val11); |
2418 | transpose_16bit_vals(val12, val14); |
2419 | transpose_16bit_vals(val13, val15); |
2420 | transpose_32bit_vals(val0, val4); |
2421 | transpose_32bit_vals(val1, val5); |
2422 | transpose_32bit_vals(val2, val6); |
2423 | transpose_32bit_vals(val3, val7); |
2424 | transpose_32bit_vals(val8, val12); |
2425 | transpose_32bit_vals(val9, val13); |
2426 | transpose_32bit_vals(val10, val14); |
2427 | transpose_32bit_vals(val11, val15); |
2428 | // Store to the packed_matrix. |
2429 | std::int8_t* dst_ptr = packed_ptr; |
2430 | vst1q_s8(dst_ptr, vcombine_s8(val0, val8)); |
2431 | vst1q_s8(dst_ptr + 16, vcombine_s8(val1, val9)); |
2432 | dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32; |
2433 | vst1q_s8(dst_ptr, vcombine_s8(val2, val10)); |
2434 | vst1q_s8(dst_ptr + 16, vcombine_s8(val3, val11)); |
2435 | packed_ptr += 4 * packed_stride; |
2436 | dst_ptr = packed_ptr; |
2437 | vst1q_s8(dst_ptr, vcombine_s8(val4, val12)); |
2438 | vst1q_s8(dst_ptr + 16, vcombine_s8(val5, val13)); |
2439 | dst_ptr += (kernel_cols == 2) ? 2 * packed_stride : 32; |
2440 | vst1q_s8(dst_ptr, vcombine_s8(val6, val14)); |
2441 | vst1q_s8(dst_ptr + 16, vcombine_s8(val7, val15)); |
2442 | packed_ptr += 4 * packed_stride; |
2443 | } |
2444 | // Handle remaining columns, not fitting in a full block of 8 columns, but |
// still true columns from the source matrix (as opposed to the final columns
2446 | // below). |
2447 | for (; col < src_end_col; col++) { |
2448 | std::int32_t accum = 0; |
2449 | std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16; |
2450 | for (int r = 0; r < 16; r++) { |
2451 | std::int8_t packed_val = (block_row + r < src_rows) |
2452 | ? (src_ptr[r * src_stride] ^ input_xor) |
2453 | : packed_zero_point; |
2454 | accum += packed_val; |
2455 | dst_ptr[r] = packed_val; |
2456 | } |
2457 | if (sums) { |
2458 | sums[col] += accum; |
2459 | } |
2460 | src_ptr++; |
2461 | if (((col + 1) & (kernel_cols - 1)) == 0) { |
2462 | packed_ptr += kernel_cols * packed_stride; |
2463 | } |
2464 | } |
2465 | // Handle the final columns of the packed matrix, beyond the last column of |
2466 | // the source matrix. The values here don't matter, we just want to avoid |
2467 | // leaving uninitialized data. Since the sums are already initialized above, |
2468 | // we don't need to do anything about them here. |
2469 | for (; col < end_col; col++) { |
2470 | std::int8_t* dst_ptr = packed_ptr + (col & (kernel_cols - 1)) * 16; |
2471 | std::memset(dst_ptr, 0, 16); |
2472 | if (((col + 1) & (kernel_cols - 1)) == 0) { |
2473 | packed_ptr += kernel_cols * packed_stride; |
2474 | } |
2475 | } |
2476 | } |
2477 | |
2478 | #endif |
2479 | |
2480 | } // namespace ruy |
2481 | |