1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef RUY_RUY_PACK_ARM_H_
17#define RUY_RUY_PACK_ARM_H_
18
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <type_traits>
22
23#include "ruy/asm_helpers.h"
24#include "ruy/check_macros.h"
25#include "ruy/mat.h"
26#include "ruy/opt_set.h"
27#include "ruy/pack_common.h"
28#include "ruy/path.h"
29#include "ruy/platform.h"
30#include "ruy/profiler/instrumentation.h"
31#include "ruy/tune.h"
32
33namespace ruy {
34
35#if RUY_PLATFORM_NEON
36RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon)
37RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod)
38
39RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 8)
40#if RUY_PLATFORM_NEON_32
41RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kNeon, 4)
42#endif
43
44template <>
45struct PackedTypeImpl<Path::kNeon, std::uint8_t> {
46 using Type = std::int8_t;
47};
48template <>
49struct PackedTypeImpl<Path::kNeonDotprod, std::uint8_t> {
50 using Type = std::int8_t;
51};
52#endif
53
54#if RUY_PLATFORM_NEON
55void Pack8bitRowMajorForNeon(const std::uint8_t* src_ptr, int src_stride,
56 int src_rows, int src_cols, int block_row,
57 int start_col, int end_col,
58 std::int8_t* packed_ptr, int packed_stride,
59 int packed_zero_point, std::int32_t* sums_ptr,
60 int input_xor, int kernel_cols);
61#endif
62
63#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
64
65void Pack8bitColMajorForNeon(const void* src_ptr0, const void* src_ptr1,
66 const void* src_ptr2, const void* src_ptr3,
67 int src_inc0, int src_inc1, int src_inc2,
68 int src_inc3, int src_rows, int src_zero_point,
69 std::int8_t* packed_ptr, std::int32_t* sums_ptr,
70 int input_xor);
71void Pack8bitColMajorForNeonA55ish(const void* src_ptr0, const void* src_ptr1,
72 const void* src_ptr2, const void* src_ptr3,
73 int src_inc0, int src_inc1, int src_inc2,
74 int src_inc3, int src_rows,
75 int src_zero_point, std::int8_t* packed_ptr,
76 std::int32_t* sums_ptr, int input_xor);
77void Pack8bitColMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
78 const void* src_ptr2, const void* src_ptr3,
79 int src_inc0, int src_inc1, int src_inc2,
80 int src_inc3, int src_rows,
81 int src_zero_point, std::int8_t* packed_ptr,
82 std::int32_t* sums_ptr, int input_xor);
83void Pack8bitColMajorForNeonDotprodA55ish(
84 const void* src_ptr0, const void* src_ptr1, const void* src_ptr2,
85 const void* src_ptr3, int src_inc0, int src_inc1, int src_inc2,
86 int src_inc3, int src_rows, int src_zero_point, std::int8_t* packed_ptr,
87 std::int32_t* sums_ptr, int input_xor);
88void Pack8bitRowMajorForNeonDotprod(const void* src_ptr0, const void* src_ptr1,
89 const void* src_ptr2, const void* src_ptr3,
90 int src_inc0, int src_inc1, int src_inc2,
91 int src_inc3, int src_cols,
92 int src_zero_point, std::int8_t* packed_ptr,
93 int packed_stride, std::int32_t* sums_ptr,
94 int input_xor);
95#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
96
97struct PackParams8bit {
98 const void* src_ptr0;
99 const void* src_ptr1;
100 const void* src_ptr2;
101 const void* src_ptr3;
102 const std::int32_t* sums_ptr;
103 const std::int8_t* packed_ptr;
104 int src_inc0;
105 int src_inc1;
106 int src_inc2;
107 int src_inc3;
108 int src_rows;
109 int src_zero_point;
110 int input_xor;
111};
112
113inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1,
114 const void* src_ptr2, const void* src_ptr3,
115 const std::int32_t* sums_ptr,
116 const std::int8_t* packed_ptr, int src_inc0,
117 int src_inc1, int src_inc2, int src_inc3,
118 int src_rows, int src_zero_point, int input_xor,
119 PackParams8bit* params) {
120 params->src_ptr0 = src_ptr0;
121 params->src_ptr1 = src_ptr1;
122 params->src_ptr2 = src_ptr2;
123 params->src_ptr3 = src_ptr3;
124 params->sums_ptr = sums_ptr;
125 params->packed_ptr = packed_ptr;
126 params->src_inc0 = src_inc0;
127 params->src_inc1 = src_inc1;
128 params->src_inc2 = src_inc2;
129 params->src_inc3 = src_inc3;
130 params->src_rows = src_rows;
131 params->src_zero_point = src_zero_point;
132 params->input_xor = input_xor;
133}
134
135void Pack8bitColMajorForNeon4Cols(const PackParams8bit& params);
136void Pack8bitColMajorForNeon2Cols(const PackParams8bit& params);
137
#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
139
140#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
141
142template <typename Scalar>
143struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
144 std::int8_t, std::int32_t, Order::kColMajor> {
145 static_assert(std::is_same<Scalar, std::int8_t>::value ||
146 std::is_same<Scalar, std::uint8_t>::value,
147 "");
148 static constexpr int kInputXor =
149 std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
150
151 static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
152 PMat<std::int8_t>* packed_matrix, int start_col,
153 int end_col) {
154 RUY_DCHECK(IsColMajor(src_matrix.layout));
155 RUY_DCHECK(IsColMajor(packed_matrix->layout));
156 RUY_DCHECK_EQ(start_col % 4, 0);
157 std::int32_t* sums = packed_matrix->sums;
158 Scalar zerobuf[16];
159 memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
160 for (int block_col = start_col; block_col < end_col; block_col += 4) {
161 int src_stride = src_matrix.layout.stride;
162 const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
163 const Scalar* src_ptr1 = src_ptr0 + src_stride;
164 const Scalar* src_ptr2 = src_ptr1 + src_stride;
165 const Scalar* src_ptr3 = src_ptr2 + src_stride;
166 int src_inc0 = 16;
167 int src_inc1 = 16;
168 int src_inc2 = 16;
169 int src_inc3 = 16;
170 if (block_col >= src_matrix.layout.cols - 3) {
171 if (block_col >= src_matrix.layout.cols - 0) {
172 src_ptr0 = zerobuf;
173 src_inc0 = 0;
174 }
175 if (block_col >= src_matrix.layout.cols - 1) {
176 src_ptr1 = zerobuf;
177 src_inc1 = 0;
178 }
179 if (block_col >= src_matrix.layout.cols - 2) {
180 src_ptr2 = zerobuf;
181 src_inc2 = 0;
182 }
183 if (block_col >= src_matrix.layout.cols - 3) {
184 src_ptr3 = zerobuf;
185 src_inc3 = 0;
186 }
187 }
188 std::int8_t* packed_ptr =
189 packed_matrix->data + packed_matrix->layout.stride * block_col;
190 std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
191#if RUY_PLATFORM_NEON_64
192 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
193 Pack8bitColMajorForNeonA55ish(
194 src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
195 src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
196 packed_ptr, sums_ptr, kInputXor);
197 } else {
198 Pack8bitColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
199 src_inc0, src_inc1, src_inc2, src_inc3,
200 src_matrix.layout.rows, src_matrix.zero_point,
201 packed_ptr, sums_ptr, kInputXor);
202 }
203#else
204 (void)tuning;
205 // We have a more limited set of general purpose registers in ARMv7, so
206 // we use the "params" struct technique from the kernel code to save
207 // registers.
208 PackParams8bit params;
209 MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr,
210 packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3,
211 src_matrix.layout.rows, src_matrix.zero_point,
212 kInputXor, &params);
213 Pack8bitColMajorForNeon4Cols(params);
214#endif // RUY_PLATFORM_NEON_64
215 }
216 }
217};
218
219#endif // (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) &&
220 // RUY_OPT(ASM)
221
222#if RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
223// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional
224// partial specialization for the RHS, which has a FixedKernelLayout with 2
225// columns.
226template <typename Scalar>
227struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
228 std::int8_t, std::int32_t, Order::kColMajor> {
229 static_assert(std::is_same<Scalar, std::int8_t>::value ||
230 std::is_same<Scalar, std::uint8_t>::value,
231 "");
232 static constexpr int kInputXor =
233 std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
234 static void Run(Tuning, const Mat<Scalar>& src_matrix,
235 PMat<std::int8_t>* packed_matrix, int start_col,
236 int end_col) {
237 RUY_DCHECK(IsColMajor(src_matrix.layout));
238 RUY_DCHECK(IsColMajor(packed_matrix->layout));
239 RUY_DCHECK_EQ(start_col % 2, 0);
240 std::int32_t* sums = packed_matrix->sums;
241 Scalar zerobuf[16];
242 memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
243 for (int block_col = start_col; block_col < end_col; block_col += 2) {
244 int src_stride = src_matrix.layout.stride;
245 const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
246 const Scalar* src_ptr1 = src_ptr0 + src_stride;
247 int src_inc0 = 16;
248 int src_inc1 = 16;
249 if (block_col >= src_matrix.layout.cols - 2) {
250 if (block_col >= src_matrix.layout.cols - 0) {
251 src_ptr0 = zerobuf;
252 src_inc0 = 0;
253 }
254 if (block_col >= src_matrix.layout.cols - 1) {
255 src_ptr1 = zerobuf;
256 src_inc1 = 0;
257 }
258 }
259 std::int8_t* packed_ptr =
260 packed_matrix->data + packed_matrix->layout.stride * block_col;
261 std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
262 PackParams8bit params;
263 MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr,
264 packed_ptr, src_inc0, src_inc1, -1, -1,
265 src_matrix.layout.rows, src_matrix.zero_point,
266 kInputXor, &params);
267 Pack8bitColMajorForNeon2Cols(params);
268 }
269 }
270};
#endif  // RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
272
273#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
274template <typename Scalar>
275struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
276 Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
277 static_assert(std::is_same<Scalar, std::int8_t>::value ||
278 std::is_same<Scalar, std::uint8_t>::value,
279 "");
280 static constexpr int kInputXor =
281 std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
282
283 static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
284 PMat<std::int8_t>* packed_matrix, int start_col,
285 int end_col) {
286 RUY_DCHECK(IsColMajor(src_matrix.layout));
287 RUY_DCHECK(IsColMajor(packed_matrix->layout));
288 RUY_DCHECK_EQ(start_col % 8, 0);
289 std::int32_t* sums = packed_matrix->sums;
290 Scalar zerobuf[16];
291 memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
292 for (int block_col = start_col; block_col < end_col; block_col += 4) {
293 int src_stride = src_matrix.layout.stride;
294 const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
295 const Scalar* src_ptr1 = src_ptr0 + src_stride;
296 const Scalar* src_ptr2 = src_ptr1 + src_stride;
297 const Scalar* src_ptr3 = src_ptr2 + src_stride;
298 std::int64_t src_inc0 = 16;
299 std::int64_t src_inc1 = 16;
300 std::int64_t src_inc2 = 16;
301 std::int64_t src_inc3 = 16;
302 if (block_col >= src_matrix.layout.cols - 3) {
303 if (block_col >= src_matrix.layout.cols - 0) {
304 src_ptr0 = zerobuf;
305 src_inc0 = 0;
306 }
307 if (block_col >= src_matrix.layout.cols - 1) {
308 src_ptr1 = zerobuf;
309 src_inc1 = 0;
310 }
311 if (block_col >= src_matrix.layout.cols - 2) {
312 src_ptr2 = zerobuf;
313 src_inc2 = 0;
314 }
315 if (block_col >= src_matrix.layout.cols - 3) {
316 src_ptr3 = zerobuf;
317 src_inc3 = 0;
318 }
319 }
320 std::int8_t* packed_ptr =
321 packed_matrix->data +
322 packed_matrix->layout.stride * (block_col & ~7) +
323 ((block_col & 4) * 4);
324 std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
325 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
326 Pack8bitColMajorForNeonDotprodA55ish(
327 src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
328 src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
329 packed_ptr, sums_ptr, kInputXor);
330 } else {
331 Pack8bitColMajorForNeonDotprod(
332 src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1,
333 src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point,
334 packed_ptr, sums_ptr, kInputXor);
335 }
336 }
337 }
338};
#endif  // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
340
341#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
342void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
343 const float* src_ptr2, const float* src_ptr3,
344 int src_inc0, int src_inc1, int src_inc2,
345 int src_inc3, int src_rows, float* packed_ptr);
346void PackFloatColMajorForNeonA55ish(const float* src_ptr0,
347 const float* src_ptr1,
348 const float* src_ptr2,
349 const float* src_ptr3, int src_inc0,
350 int src_inc1, int src_inc2, int src_inc3,
351 int src_rows, float* packed_ptr);
352
353#elif RUY_PLATFORM_NEON_32 && RUY_OPT(ASM)
354void PackFloatColMajorForNeon(const float* src_ptr0, const float* src_ptr1,
355 const float* src_ptr2, const float* src_ptr3,
356 int src_inc, int src_rows, float* packed_ptr,
357 int stride);
#endif  // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && RUY_OPT(ASM)
359
360#if (RUY_PLATFORM_NEON_32 || RUY_PLATFORM_NEON_64) && RUY_OPT(ASM)
361
362template <>
363struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
364 float, float, Order::kColMajor> {
365 static void Run(Tuning tuning, const Mat<float>& src_matrix,
366 PMat<float>* packed_matrix, int start_col, int end_col) {
367 RUY_DCHECK(IsColMajor(src_matrix.layout));
368 RUY_DCHECK(IsColMajor(packed_matrix->layout));
369 RUY_DCHECK_EQ(start_col % 8, 0);
370 const float zerobuf[4] = {0};
371 for (int block_col = start_col; block_col < end_col; block_col += 4) {
372 int src_stride = src_matrix.layout.stride;
373 const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
374 const float* src_ptr1 = src_ptr0 + src_stride;
375 const float* src_ptr2 = src_ptr1 + src_stride;
376 const float* src_ptr3 = src_ptr2 + src_stride;
377 std::int64_t src_inc0 = 16;
378 std::int64_t src_inc1 = 16;
379 std::int64_t src_inc2 = 16;
380 std::int64_t src_inc3 = 16;
381 if (block_col >= src_matrix.layout.cols - 3) {
382 if (block_col >= src_matrix.layout.cols - 0) {
383 src_ptr0 = zerobuf;
384 src_inc0 = 0;
385 }
386 if (block_col >= src_matrix.layout.cols - 1) {
387 src_ptr1 = zerobuf;
388 src_inc1 = 0;
389 }
390 if (block_col >= src_matrix.layout.cols - 2) {
391 src_ptr2 = zerobuf;
392 src_inc2 = 0;
393 }
394 if (block_col >= src_matrix.layout.cols - 3) {
395 src_ptr3 = zerobuf;
396 src_inc3 = 0;
397 }
398 }
399 float* packed_ptr = packed_matrix->data +
400 packed_matrix->layout.stride * (block_col & ~7) +
401 ((block_col & 4));
402#if RUY_PLATFORM_NEON_64
403 if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
404 PackFloatColMajorForNeonA55ish(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
405 src_inc0, src_inc1, src_inc2, src_inc3,
406 src_matrix.layout.rows, packed_ptr);
407 } else {
408 PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3,
409 src_inc0, src_inc1, src_inc2, src_inc3,
410 src_matrix.layout.rows, packed_ptr);
411 }
412#else
413 (void)tuning;
414 // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits of src_inc
415 // to save on registers (we have fewer general purpose registers in
416 // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four
417 // values that are each either 16 or 0 and use them directly. For the
418 // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should
419 // use the value 16 (bit is set) or 0 (bit is not set) for the
420 // respective increment value.
421 std::int64_t src_inc = 0;
422 src_inc += src_inc0 == 16 ? 1 : 0;
423 src_inc += src_inc1 == 16 ? 2 : 0;
424 src_inc += src_inc2 == 16 ? 4 : 0;
425 src_inc += src_inc3 == 16 ? 8 : 0;
426 const int kOutputStride = 32;
427 PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
428 src_matrix.layout.rows, packed_ptr,
429 kOutputStride);
430#endif // RUY_PLATFORM_NEON_64
431 }
432 }
433};
434
435#if RUY_PLATFORM_NEON_32
436// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional
437// specialization for a FixedKernelLayout with 4 columns.
438template <>
439struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
440 float, float, Order::kColMajor> {
441 static void Run(Tuning, const Mat<float>& src_matrix,
442 PMat<float>* packed_matrix, int start_col, int end_col) {
443 RUY_DCHECK(IsColMajor(src_matrix.layout));
444 RUY_DCHECK(IsColMajor(packed_matrix->layout));
445 RUY_DCHECK_EQ(start_col % 4, 0);
446 const float zerobuf[4] = {0};
447 for (int block_col = start_col; block_col < end_col; block_col += 4) {
448 int src_stride = src_matrix.layout.stride;
449 const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col;
450 const float* src_ptr1 = src_ptr0 + src_stride;
451 const float* src_ptr2 = src_ptr1 + src_stride;
452 const float* src_ptr3 = src_ptr2 + src_stride;
453 std::int64_t src_inc0 = 16;
454 std::int64_t src_inc1 = 16;
455 std::int64_t src_inc2 = 16;
456 std::int64_t src_inc3 = 16;
457 if (block_col >= src_matrix.layout.cols - 3) {
458 if (block_col >= src_matrix.layout.cols - 0) {
459 src_ptr0 = zerobuf;
460 src_inc0 = 0;
461 }
462 if (block_col >= src_matrix.layout.cols - 1) {
463 src_ptr1 = zerobuf;
464 src_inc1 = 0;
465 }
466 if (block_col >= src_matrix.layout.cols - 2) {
467 src_ptr2 = zerobuf;
468 src_inc2 = 0;
469 }
470 if (block_col >= src_matrix.layout.cols - 3) {
471 src_ptr3 = zerobuf;
472 src_inc3 = 0;
473 }
474 }
475 float* packed_ptr =
476 packed_matrix->data + packed_matrix->layout.stride * (block_col);
477 // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc
478 // to save registers.
479 std::int64_t src_inc = 0;
480 src_inc += src_inc0 == 16 ? 1 : 0;
481 src_inc += src_inc1 == 16 ? 2 : 0;
482 src_inc += src_inc2 == 16 ? 4 : 0;
483 src_inc += src_inc3 == 16 ? 8 : 0;
484 const int kOutputStride = 16;
485 PackFloatColMajorForNeon(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc,
486 src_matrix.layout.rows, packed_ptr,
487 kOutputStride);
488 }
489 }
490};
491#endif // (RUY_PLATFORM_NEON_32)
492#endif // (RUY_PLATFORM_NEON_64 || RUY_PLATFORM_NEON_32) && \
493 // RUY_OPT(ASM)
494
495#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
496
497template <typename Scalar>
498struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
499 Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
500 static_assert(std::is_same<Scalar, std::int8_t>::value ||
501 std::is_same<Scalar, std::uint8_t>::value,
502 "");
503 static constexpr int kInputXor =
504 std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
505
506 static void Run(Tuning, const Mat<Scalar>& src_matrix,
507 PMat<std::int8_t>* packed_matrix, int start_col,
508 int end_col) {
509 RUY_DCHECK(IsRowMajor(src_matrix.layout));
510 RUY_DCHECK(IsColMajor(packed_matrix->layout));
511 RUY_DCHECK_EQ(start_col % 8, 0);
512 std::int32_t* sums = packed_matrix->sums;
513 std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
514 Scalar zerobuf[8];
515 memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf));
516 int src_stride = src_matrix.layout.stride;
517 // As the source matrix is row-major and the destination packed matrix is
518 // column-major, there is no traversal order that will be optimal for both
519 // so we choose to favor the source matrix with a row-major traversal order.
520 // Loop over groups of 4 rows.
521 for (int block_row = 0; block_row < packed_matrix->layout.rows;
522 block_row += 4) {
523 // src_ptr[0-3] shall point to the positions in the 4 rows of the source
524 // matrix that we are loading from, and will be incremented by
525 // src_inc[0-3] after each 4x8 block is loaded.
526 // First we compute these src_ptr and src_inc values for the case where
527 // there are 4 rows left to be loaded from in the source matrix ...
528 const Scalar* src_ptr0 =
529 src_matrix.data.get() + src_stride * block_row + start_col;
530 const Scalar* src_ptr1 = src_ptr0 + src_stride;
531 const Scalar* src_ptr2 = src_ptr1 + src_stride;
532 const Scalar* src_ptr3 = src_ptr2 + src_stride;
533 std::int64_t src_inc0 = 8;
534 std::int64_t src_inc1 = 8;
535 std::int64_t src_inc2 = 8;
536 std::int64_t src_inc3 = 8;
537 // ... and now we adjust these values in case there are fewer than 4 rows
538 // left to load from in the source matrix. In that case, we set the
539 // corresponding src_ptr pointer to load from `zerobuf` and set src_inc
540 // to 0 to avoid overrunning that small buffer.
541 if (block_row >= src_matrix.layout.rows - 3) {
542 if (block_row >= src_matrix.layout.rows - 0) {
543 src_ptr0 = zerobuf;
544 src_inc0 = 0;
545 }
546 if (block_row >= src_matrix.layout.rows - 1) {
547 src_ptr1 = zerobuf;
548 src_inc1 = 0;
549 }
550 if (block_row >= src_matrix.layout.rows - 2) {
551 src_ptr2 = zerobuf;
552 src_inc2 = 0;
553 }
554 if (block_row >= src_matrix.layout.rows - 3) {
555 src_ptr3 = zerobuf;
556 src_inc3 = 0;
557 }
558 }
559 // Let src_cols be the number of source matrix columns to handle.
560 int src_cols = std::min(end_col, src_matrix.layout.cols) - start_col;
561 std::int8_t* packed_ptr = packed_matrix->data +
562 packed_matrix->layout.stride * start_col +
563 8 * block_row;
564 std::int32_t* sums_ptr = sums + start_col;
565 Pack8bitRowMajorForNeonDotprod(
566 src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, src_inc2,
567 src_inc3, src_cols, src_matrix.zero_point, packed_ptr,
568 packed_matrix->layout.stride, sums_ptr, kInputXor);
569 }
570 }
571};
572
573#endif // RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
574
575#if RUY_PLATFORM_NEON
576
577template <typename Scalar, int KernelCols>
578struct PackImpl<Path::kNeon,
579 FixedKernelLayout<Order::kColMajor, 16, KernelCols>, Scalar,
580 std::int8_t, std::int32_t, Order::kRowMajor> {
581 static void Run(Tuning, const Mat<Scalar>& src_matrix,
582 PMat<std::int8_t>* packed_matrix, int start_col,
583 int end_col) {
584 profiler::ScopeLabel label("Pack (KNeon, from row-major source)");
585 static constexpr int kInputXor =
586 std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
587 RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
588 RUY_DCHECK_EQ((end_col - start_col) % KernelCols, 0);
589 std::int32_t* sums = packed_matrix->sums;
590 std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
591 int block_row = 0;
592 for (; block_row < packed_matrix->layout.rows; block_row += 16) {
593 int src_stride = src_matrix.layout.stride;
594 int packed_stride = packed_matrix->layout.stride;
595 const Scalar* src_ptr =
596 src_matrix.data.get() + block_row * src_stride + start_col;
597 std::int8_t* packed_ptr = packed_matrix->data +
598 start_col * packed_stride +
599 block_row * KernelCols;
600
601 Pack8bitRowMajorForNeon(
602 reinterpret_cast<const std::uint8_t*>(src_ptr), src_stride,
603 src_matrix.layout.rows, src_matrix.layout.cols, block_row, start_col,
604 end_col, packed_ptr, packed_stride, packed_matrix->zero_point, sums,
605 kInputXor, KernelCols);
606 }
607 }
608};
609#endif
610
611} // namespace ruy
612
613#endif // RUY_RUY_PACK_ARM_H_
614