/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))

void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
                               const std::int8_t*, int, int, int, std::int8_t*,
                               std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int,
                                int, int, std::int16_t*, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
                                float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
                               int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx512 =
    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
             std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
using PackImpl16bitAvx512 =
    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
             std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>;

namespace {

template <typename PackImplAvx512, typename Scalar>
inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
                           Scalar* packed_ptr, int chunked_row_mask) {
  using Layout = typename PackImplAvx512::Layout;
  static constexpr int kHalfLayoutCols =
      PackImplAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                        // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2;
  // This routine fills half blocks, and typically fills the second halves.
  // Thus packed_ptr is already offset by 8 * 4.
  for (int k = 0; k < non_trailing_blocks; ++k) {
    for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
      packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
    }
  }
}

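// The helpers below gather data from two source columns at once: LoaduTwo
// concatenates two unaligned 256-bit loads into one 512-bit register
// (addr_lo in the lower half, addr_hi in the upper half), and the
// MaskLoaduTwo overloads do the same with masked loads, substituting the
// caller-provided default (the zero point) for elements past the end of the
// source column.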
template <typename Scalar>
inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) {
  __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
  return _mm512_inserti32x8(
      lower_filled,
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
}

inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
                            const std::int8_t* addr_lo,
                            const std::int8_t* addr_hi) {
  const __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
  return _mm512_inserti32x8(
      lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
      1);
}

inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
                            const std::int16_t* addr_lo,
                            const std::int16_t* addr_hi) {
  const __m512i lower_filled = _mm512_castsi256_si512(
      _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo));
  return _mm512_inserti32x8(
      lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi),
      1);
}

inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr,
                               std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx512::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
  // We process 8 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have kHalfLayoutCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: kHalfLayoutCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m512i ones_16bit = _mm512_set1_epi16(1);
  __m512i sums_8x2_32bit = _mm512_set1_epi32(0);

  // The overall packing effectively pads the source rows to
  // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
  // only pack for (src_rows + 31) & ~31. When there is an incomplete
  // destination block, this is stored into trailing_buf instead of packed_ptr.
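  // For example, with src_rows == 70 the k loop visits k == 0 and k == 64.
  // At k == 64 only m == 0 has data left (6 rows); it is packed into
  // trailing_buf and later copied back to cover destination rows 64..71
  // (destination rows are padded to a multiple of Layout::kRows == 4).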
  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
    // m: {0, 1} for 2 chunks of rows.
    for (int m = 0; m < 2; ++m) {
      // Available source rows.
      // If this is less than 0 (for m=1), we skip, having filled trailing
      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
      // exactly to the end of the column in the packed buffer.
      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
      // Effectively,
      // available_src_rows = std::max(0, std::min(32, src_rows - k - 32 * m));
      // but we treat each case separately.
      if (available_src_rows >= kNumChunkedSrcRows) {
        // i: chunks, s: Layout::Rows.
        if (sums_ptr) {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

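          // The unpack/shuffle sequence below transposes the loaded data so
          // that each 256-bit half of r0..r3 holds one 4-row chunk across the
          // 8 source columns handled by this call, matching the col-major
          // 4x16 packed layout (the other 8 columns of each block are written
          // by a second HalfPack call at a 32-byte offset).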
          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

          __m512i sums_8x4_16bit;
          sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
          sums_8x4_16bit =
              _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
          // The sums have been performed across columns, and now we have
          // 4x16-bit sums packed together. We use madd for pairwise 32-bit
          // sums.
          const __m512i sums_8x2_32bit_new =
              _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        } else {
          __m512i t0, t1, t2, t3;
          __m512i r0, r1, r2, r3;
          const __m512i input_xor_v = _mm512_set1_epi8(input_xor);

          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);

          r0 = _mm512_unpacklo_epi32(t0, t1);
          r2 = _mm512_unpackhi_epi32(t0, t1);
          r1 = _mm512_unpacklo_epi32(t2, t3);
          r3 = _mm512_unpackhi_epi32(t2, t3);

          t0 = _mm512_unpacklo_epi64(r0, r1);
          t2 = _mm512_unpackhi_epi64(r0, r1);
          t1 = _mm512_unpacklo_epi64(r2, r3);
          t3 = _mm512_unpackhi_epi64(r2, r3);

          r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
          r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
          r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
          r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

          r0 = _mm512_xor_si512(r0, input_xor_v);
          r1 = _mm512_xor_si512(r1, input_xor_v);
          r2 = _mm512_xor_si512(r2, input_xor_v);
          r3 = _mm512_xor_si512(r3, input_xor_v);

          const __m256i r0_0 = _mm512_castsi512_si256(r0);
          const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
          const __m256i r1_0 = _mm512_castsi512_si256(r1);
          const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
          const __m256i r2_0 = _mm512_castsi512_si256(r2);
          const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
          const __m256i r3_0 = _mm512_castsi512_si256(r3);
          const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
          _mm256_storeu_si256(
              reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
        }
      } else if (available_src_rows > 0) {
        RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
        const __mmask32 row_mask =
            (static_cast<std::uint64_t>(1) << available_src_rows) - 1;
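        // row_mask has one bit set per remaining source row, so the masked
        // loads below read exactly available_src_rows bytes per column and
        // fill the rest with the zero point.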

        // We do not care what goes into the trailing buffer, but we want
        // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
        //
        // We compensate for padding-with-zero_point by initializing the
        // summations with the compensating offset, effectively
        // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
        //    4 * (8 - ((available_src_rows + 3) >> 2)).
        //
        // Note that (zero_point ^ input_xor) is performed in 8-bits and then
        // cast.
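        // For example, if available_src_rows == 10 then
        // (10 + 3) >> 2 == 3 row chunks are kept, so the adjustment removes
        // the contribution of the other 8 - 3 == 5 chunks, i.e. 20 values of
        // (zero_point ^ input_xor) per column.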
        sums_adjustment += -(zero_point ^ input_xor) * 4 *
                           (8 - ((available_src_rows + 3) >> 2));

        __m512i t0, t1, t2, t3;
        __m512i r0, r1, r2, r3;
        const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
        const __m256i zero_point_v = _mm256_set1_epi8(zero_point);

        t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_epi32(t0, t1);
        r2 = _mm512_unpackhi_epi32(t0, t1);
        r1 = _mm512_unpacklo_epi32(t2, t3);
        r3 = _mm512_unpackhi_epi32(t2, t3);

        t0 = _mm512_unpacklo_epi64(r0, r1);
        t2 = _mm512_unpackhi_epi64(r0, r1);
        t1 = _mm512_unpacklo_epi64(r2, r3);
        t3 = _mm512_unpackhi_epi64(r2, r3);

        r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);

        r0 = _mm512_xor_si512(r0, input_xor_v);
        r1 = _mm512_xor_si512(r1, input_xor_v);
        r2 = _mm512_xor_si512(r2, input_xor_v);
        r3 = _mm512_xor_si512(r3, input_xor_v);

        const __m256i r0_0 = _mm512_castsi512_si256(r0);
        const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
        const __m256i r1_0 = _mm512_castsi512_si256(r1);
        const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
        const __m256i r2_0 = _mm512_castsi512_si256(r2);
        const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
        const __m256i r3_0 = _mm512_castsi512_si256(r3);
        const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);

        __m512i sums_8x4_16bit;
        sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
        sums_8x4_16bit =
            _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
        // The sums have been performed across columns, and now we have
        // 4x16-bit sums packed together. We use madd for pairwise 32-bit
        // sums.
        const __m512i sums_8x2_32bit_new =
            _mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
        sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);

        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
      }

      packed_ptr += 16 * kNumChunkedSrcRows;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m512i idx =
        _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);

    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
    const __m512i sums_2x8_32bit =
        _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
    sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

inline void HalfPack16bitAvx512(const std::int16_t* src_ptr,
                                const std::int16_t* zerobuf, int src_stride,
                                int remaining_src_cols, int src_rows,
                                std::int16_t* packed_ptr,
                                std::int32_t* sums_ptr,
                                std::int16_t* trailing_buf) {
  using Layout = PackImpl16bitAvx512::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
  // We process 4 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 4;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;

  const std::int16_t* src_ptr0 = src_ptr;
  const std::int16_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int16_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int16_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int16_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int16_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int16_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int16_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have kHalfLayoutCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int16_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: kHalfLayoutCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m512i ones_16bit = _mm512_set1_epi16(1);
  __m512i sums_8x2_32bit = _mm512_set1_epi32(0);

  // The overall packing effectively pads the source rows to
  // (src_rows + 31) & ~31. The iteration over k may skip when m=1, and then we
  // only pack for (src_rows + 15) & ~15. When there is an incomplete
  // destination block, this is stored into trailing_buf instead of packed_ptr.
  for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
    // m: {0, 1} for 2 chunks of rows.
    for (int m = 0; m < 2; ++m) {
      const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;

      // Available source rows.
      // If this is less than 0 (for m=1), we skip, having filled trailing
      // buffer for m=0. Also, if source rows is zero on m=1, then we filled
      // exactly to the end of the column in the packed buffer.
      if (available_src_rows > 0) {
        __m512i t0, t1, t2, t3;
        __m512i r0, r1, r2, r3;
        std::int16_t* dst_ptr = packed_ptr;

        if (available_src_rows >= kNumChunkedSrcRows) {
          t0 = LoaduTwo(src_ptr0, src_ptr4);
          t1 = LoaduTwo(src_ptr1, src_ptr5);
          t2 = LoaduTwo(src_ptr2, src_ptr6);
          t3 = LoaduTwo(src_ptr3, src_ptr7);
        } else {
          RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
          // We do not care what goes into the trailing buffer, but we want
          // in_data[...] == zero_point for irrelevant values in the summation.
          //
          // We compensate for padding-with-zero_point by initializing the
          // summations with the compensating offset.
          sums_adjustment +=
              -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2));

          const __m256i zero_point_v = _mm256_set1_epi16(zero_point);
          const __mmask32 row_mask =
              (static_cast<std::uint64_t>(1) << available_src_rows) - 1;

          t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
          t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
          t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
          t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
          dst_ptr = trailing_buf;
        }

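        // The 64-bit unpack/blend/permute steps below rearrange the data so
        // that each 512-bit register written out holds one 4-row chunk across
        // the 8 source columns handled by this call, i.e. half of a 4x16
        // col-major block.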
        r0 = _mm512_unpacklo_epi64(t0, t1);
        r2 = _mm512_unpackhi_epi64(t0, t1);
        r1 = _mm512_unpacklo_epi64(t2, t3);
        r3 = _mm512_unpackhi_epi64(t2, t3);

        r1 = _mm512_permutex_epi64(r1, 0x4e);
        r3 = _mm512_permutex_epi64(r3, 0x4e);

        t0 = _mm512_mask_blend_epi64(0xcc, r0, r1);
        t1 = _mm512_mask_blend_epi64(0x33, r0, r1);
        t2 = _mm512_mask_blend_epi64(0xcc, r2, r3);
        t3 = _mm512_mask_blend_epi64(0x33, r2, r3);

        t1 = _mm512_permutex_epi64(t1, 0x4e);
        t3 = _mm512_permutex_epi64(t3, 0x4e);

        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4),
                            t0);
        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4),
                            t1);
        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4),
                            t2);
        _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4),
                            t3);

        if (sums_ptr) {
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
                                            _mm512_madd_epi16(t0, ones_16bit));
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
                                            _mm512_madd_epi16(t1, ones_16bit));
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
                                            _mm512_madd_epi16(t2, ones_16bit));
          sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
                                            _mm512_madd_epi16(t3, ones_16bit));
        }
      }

      packed_ptr += 16 * kNumChunkedSrcRows;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m512i idx =
        _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);

    const __m512i sums_2x8_32bit =
        _mm512_permutexvar_epi32(idx, sums_8x2_32bit);
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
    sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
  const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
  return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}

inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
                           const float* addr_hi) {
  const __m512 lower_filled =
      _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
  return _mm512_insertf32x8(lower_filled,
                            _mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}

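// The Psx2 helpers reinterpret the vectors as doubles so that unpacklo/hi
// operate on 64-bit lanes, i.e. on pairs of floats, which is the granularity
// needed by the float transpose below.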
inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}

inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
  return _mm512_castpd_ps(
      _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}

inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr,
                                float* trailing_buf) {
  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += 16) {
    for (int m = 0; m < 2; ++m) {
      const int available_src_rows = src_rows - k - 8 * m;
      // Effectively,
      // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m));
      // but treat each case separately.
      if (available_src_rows > 7) {
        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = LoaduTwo(src_ptr0, src_ptr4);
        t1 = LoaduTwo(src_ptr1, src_ptr5);
        t2 = LoaduTwo(src_ptr2, src_ptr6);
        t3 = LoaduTwo(src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
        _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
      } else if (available_src_rows > 0) {
        const __mmask8 row_mask =
            (static_cast<std::uint32_t>(1) << available_src_rows) - 1;

        __m512 t0, t1, t2, t3;
        __m512 r0, r1, r2, r3;

        t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
        t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
        t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
        t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);

        r0 = _mm512_unpacklo_ps(t0, t1);
        r2 = _mm512_unpackhi_ps(t0, t1);
        r1 = _mm512_unpacklo_ps(t2, t3);
        r3 = _mm512_unpackhi_ps(t2, t3);

        t0 = Mm512UnpackloPsx2(r0, r1);
        t2 = Mm512UnpackhiPsx2(r0, r1);
        t1 = Mm512UnpackloPsx2(r2, r3);
        t3 = Mm512UnpackhiPsx2(r2, r3);

        r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
        r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
        r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
        r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);

        _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
        _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
        _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
        _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
        _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
        _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
        _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
        // Do not store _mm512_extractf32x8_ps(r3, 1).
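        // (trailing_buf only holds 7 rows of 16 floats, and at most 7
        // trailing rows are ever copied out, so the 8th row chunk is neither
        // needed nor safe to write.)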
      }

      packed_ptr += 16 * 8;
      src_ptr0 += src_inc0;
      src_ptr1 += src_inc1;
      src_ptr2 += src_inc2;
      src_ptr3 += src_inc3;
      src_ptr4 += src_inc4;
      src_ptr5 += src_inc5;
      src_ptr6 += src_inc6;
      src_ptr7 += src_inc7;
    }
  }
}

inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
  const int non_trailing_rows = src_rows & ~7;
  for (int k = 0; k < non_trailing_rows; ++k) {
    for (int j = 0; j < 8; ++j) {
      packed_ptr[j] = 0.0f;
    }
    packed_ptr += 16;
  }
}

}  // namespace.

void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr,
                               std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 8bit");

  using Layout = PackImpl8bitAvx512::Layout;
  constexpr int kHalfBlockOffset = 32;
  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
  static constexpr int kHalfLayoutCols =
      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                            // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
  // We process 8 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;

  // Each packed block is 4*16, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;

  std::int32_t* second_sums_ptr =
      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
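  // The 16-column destination block is packed as two interleaved 8-column
  // halves: within every 4x16 block, the second half's data starts
  // kHalfBlockOffset (32) elements in, and its sums go to second_sums_ptr.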
  if (remaining_src_cols > kHalfLayoutCols) {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
                       zerobuf, src_stride,
                       remaining_src_cols - kHalfLayoutCols, src_rows,
                       packed_ptr + kHalfBlockOffset, second_sums_ptr,
                       trailing_buf + kHalfBlockOffset);
  } else {
    HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
                       remaining_src_cols, src_rows, packed_ptr, sums_ptr,
                       trailing_buf);
    ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
        src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset,
        kChunkedRowMask);
    // The kernel may not need the second half-blocks sums to be set.
    if (second_sums_ptr) {
      for (int i = 0; i < kHalfLayoutCols; ++i) {
        second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
      }
    }
  }
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there is trailing data to copy out of the trailing buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to next highest multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
                                const std::int16_t* zerobuf, int src_stride,
                                int remaining_src_cols, int src_rows,
                                std::int16_t* packed_ptr,
                                std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 16bit");

  using Layout = PackImpl16bitAvx512::Layout;
  constexpr int kHalfBlockOffset = 32;
  RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
  static constexpr int kHalfLayoutCols =
      PackImpl16bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
                                             // block.
  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
  RUY_DCHECK_EQ(Layout::kCols, 16);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
  // We process 4 of these chunks at a time, padding short input chunks.
  constexpr int kNumRowChunks = 4;

  // Each packed block is 4*16, and there are normally 4. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int16_t trailing_buf[kTrailingBufSize] = {0};
  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;

  std::int32_t* second_sums_ptr =
      sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
  if (remaining_src_cols > kHalfLayoutCols) {
    HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, sums_ptr, trailing_buf);
    HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf,
                        src_stride, remaining_src_cols - kHalfLayoutCols,
                        src_rows, packed_ptr + kHalfBlockOffset,
                        second_sums_ptr, trailing_buf + kHalfBlockOffset);
  } else {
    HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, sums_ptr, trailing_buf);
    ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
        src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask);
    // The kernel may not need the second half-blocks sums to be set.
    if (second_sums_ptr) {
      for (int i = 0; i < kHalfLayoutCols; ++i) {
        second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3);
      }
    }
  }
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there is trailing data to copy out of the trailing buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to next highest multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int16_t));
  }
}

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx512 float");
  float trailing_buf[7 * 16];
  if (remaining_src_cols > 8) {
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
                        remaining_src_cols - 8, src_rows, packed_ptr + 8,
                        trailing_buf + 8);
  } else {
    memset(trailing_buf, 0, sizeof(trailing_buf));
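    // The memset is needed only in this branch: a single HalfPackFloatAvx512
    // call writes just the first 8 floats of each 16-float trailing row, so
    // the untouched upper half must already be zero (padding).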
    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
                        src_rows, packed_ptr, trailing_buf);
    ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
  }
  const int trailing_rows = src_rows & 7;
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~7;
    memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
           16 * trailing_rows * sizeof(float));
  }
}

void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 16; col += 16) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x16 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      val1 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      val2 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      val3 = _mm_loadu_si128(
          reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m256i val16_0 = _mm256_cvtepi8_epi16(val0);
    __m256i val16_1 = _mm256_cvtepi8_epi16(val1);
    __m256i val16_2 = _mm256_cvtepi8_epi16(val2);
    __m256i val16_3 = _mm256_cvtepi8_epi16(val3);
    __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
                                         _mm256_add_epi16(val16_2, val16_3));
    __m512i sum =
        _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
    sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
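    // The zip/unpack sequence below interleaves the four row vectors byte by
    // byte so that each group of 4 consecutive packed bytes holds the 4 row
    // values of one source column, i.e. the packed data becomes column-major
    // in 4-row groups.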
    auto zip = [](__m128i x, __m128i y) {
      auto perm_64_0_64_0 = [](__m128i x) {
        return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
                                        _mm256_castsi128_si256(x));
      };
      return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
    };
    __m256i t2_val0 = zip(val0, val1);
    __m256i t2_val1 = zip(val2, val3);
    __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
    __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
                     _mm256_extractf128_si256(t4_val0, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
                     _mm256_extractf128_si256(t4_val1, 0));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
                     _mm256_extractf128_si256(t4_val0, 1));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
                     _mm256_extractf128_si256(t4_val1, 1));
    src_ptr += 16;
    packed_ptr += packed_stride * 16;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)

}  // namespace ruy