/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>  // for std::min
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t,
                             const std::int8_t*, int, int, int, std::int8_t*,
                             std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx2(const float*, const float*, int, int, int,
                              float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
                             int, int, int, int, int, int, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx2 =
    PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
             std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;

using PackImplFloatAvx2 =
    PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
             float, float, Order::kColMajor>;
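// In the two layouts above, the 8-bit path packs column-major 4x8 cells
// (4 depth elements per column, 8 columns at a time), while the float path
// packs row-major 1x8 cells (one element from each of 8 columns).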

namespace {

inline void Pack8bitColMajorForAvx2Packer(
    const std::int8_t* src_ptr, std::int8_t input_xor,
    const std::int8_t* zerobuf, int src_stride, int remaining_src_cols,
    int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr,
    std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx2::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 8 of these chunks at a
  // time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
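  // kNumChunkedSrcRows == 32: each iteration of the packing loop below
  // consumes 32 source rows (8 chunks of Layout::kRows == 4 each).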

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where the source has fewer than Layout::kCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: Layout::kCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m256i ones_16bit = _mm256_set1_epi16(1);
  __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
  __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
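  // 32-bit accumulators for the per-column sums: the *_lo accumulator covers
  // the low 128-bit half of each transposed chunk (columns 0-3), *_hi the
  // high half (columns 4-7). They are folded into sums_ptr after the loop.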

  // Each iteration of this loop over k handles kNumChunkedSrcRows (32) source
  // rows, so the packing effectively pads the source rows up to a multiple of
  // 32. When the final chunk is incomplete, it is stored into trailing_buf
  // instead of packed_ptr and the caller copies out only the rows it needs.
  for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
    // Available source rows in this chunk.
    const int available_src_rows = src_rows - k;
    // Effectively,
    //   available_src_rows = std::max(0, std::min(kNumChunkedSrcRows,
    //                                             src_rows - k));
    // and each case below is treated separately.
    if (available_src_rows >= kNumChunkedSrcRows) {
      if (sums_ptr) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = _mm256_unpacklo_epi32(t0, t1);
        r4 = _mm256_unpacklo_epi32(t4, t5);
        r2 = _mm256_unpackhi_epi32(t0, t1);
        r6 = _mm256_unpackhi_epi32(t4, t5);
        r1 = _mm256_unpacklo_epi32(t2, t3);
        r5 = _mm256_unpacklo_epi32(t6, t7);
        r3 = _mm256_unpackhi_epi32(t2, t3);
        r7 = _mm256_unpackhi_epi32(t6, t7);

        t0 = _mm256_unpacklo_epi64(r0, r1);
        t4 = _mm256_unpacklo_epi64(r4, r5);
        t2 = _mm256_unpackhi_epi64(r0, r1);
        t6 = _mm256_unpackhi_epi64(r4, r5);
        t1 = _mm256_unpacklo_epi64(r2, r3);
        t5 = _mm256_unpacklo_epi64(r6, r7);
        t3 = _mm256_unpackhi_epi64(r2, r3);
        t7 = _mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128-bit), operating *between* AVX lanes.
        // For instance (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around
        // 128-bit lanes.
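        // _mm256_permute2x128_si256(a, b, 0x20) yields [a.lane0, b.lane0];
        // immediate 0x31 yields [a.lane1, b.lane1].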
        r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = _mm256_xor_si256(r0, input_xor_v);
        r1 = _mm256_xor_si256(r1, input_xor_v);
        r2 = _mm256_xor_si256(r2, input_xor_v);
        r3 = _mm256_xor_si256(r3, input_xor_v);
        r4 = _mm256_xor_si256(r4, input_xor_v);
        r5 = _mm256_xor_si256(r5, input_xor_v);
        r6 = _mm256_xor_si256(r6, input_xor_v);
        r7 = _mm256_xor_si256(r7, input_xor_v);

        __m256i sums_4x4_16bit_lo;
        sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
        sums_4x4_16bit_lo =
            _mm256_add_epi16(sums_4x4_16bit_lo,
                             _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

        // The sums have been performed across columns, and now we have 4x16-bit
        // sums packed together. We use madd for pairwise 32-bit sums.
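        // _mm256_madd_epi16 with a vector of ones multiplies each 16-bit sum
        // by 1 and adds adjacent pairs: a horizontal, widening pairwise add
        // into 32-bit lanes.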
        const __m256i sums_4x2_32bit_lo_new =
            _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
        sums_4x2_32bit_lo =
            _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

        __m256i sums_4x4_16bit_hi;
        sums_4x4_16bit_hi =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
        sums_4x4_16bit_hi = _mm256_add_epi16(
            sums_4x4_16bit_hi,
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));

        const __m256i sums_4x2_32bit_hi_new =
            _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
        sums_4x2_32bit_hi =
            _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      } else {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = _mm256_unpacklo_epi32(t0, t1);
        r4 = _mm256_unpacklo_epi32(t4, t5);
        r2 = _mm256_unpackhi_epi32(t0, t1);
        r6 = _mm256_unpackhi_epi32(t4, t5);
        r1 = _mm256_unpacklo_epi32(t2, t3);
        r5 = _mm256_unpacklo_epi32(t6, t7);
        r3 = _mm256_unpackhi_epi32(t2, t3);
        r7 = _mm256_unpackhi_epi32(t6, t7);

        t0 = _mm256_unpacklo_epi64(r0, r1);
        t4 = _mm256_unpacklo_epi64(r4, r5);
        t2 = _mm256_unpackhi_epi64(r0, r1);
        t6 = _mm256_unpackhi_epi64(r4, r5);
        t1 = _mm256_unpacklo_epi64(r2, r3);
        t5 = _mm256_unpacklo_epi64(r6, r7);
        t3 = _mm256_unpackhi_epi64(r2, r3);
        t7 = _mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128-bit), operating *between* AVX lanes.
        // For instance (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around
        // 128-bit lanes.
        r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = _mm256_xor_si256(r0, input_xor_v);
        r1 = _mm256_xor_si256(r1, input_xor_v);
        r2 = _mm256_xor_si256(r2, input_xor_v);
        r3 = _mm256_xor_si256(r3, input_xor_v);
        r4 = _mm256_xor_si256(r4, input_xor_v);
        r5 = _mm256_xor_si256(r5, input_xor_v);
        r6 = _mm256_xor_si256(r6, input_xor_v);
        r7 = _mm256_xor_si256(r7, input_xor_v);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      }
    } else if (available_src_rows > 0) {
      RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
      // We do not care what goes into the trailing buffer, but we want
      // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
      //
      // We compensate for padding-with-zero_point by initializing the
      // summations with the compensating offset, effectively
      // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
      // 4 * (8 - ((available_src_rows + 3) >> 2)).
      //
      // Note that (zero_point ^ input_xor) is performed in 8-bits and then
      // cast.
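      //
      // For example, with available_src_rows == 20, ((20 + 3) >> 2) == 5
      // 4-row chunks hold real data, so 8 - 5 == 3 chunks (12 padded entries
      // per column) each contribute (zero_point ^ input_xor), which is
      // subtracted here.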
      sums_adjustment +=
          -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));

      __m256i t0, t1, t2, t3, t4, t5, t6, t7;
      __m256i r0, r1, r2, r3, r4, r5, r6, r7;
      const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

      t0 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr0);
      t4 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr4);
      t1 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr1);
      t5 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr5);
      t2 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr2);
      t6 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr6);
      t3 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr3);
      t7 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr7);

      r0 = _mm256_unpacklo_epi32(t0, t1);
      r4 = _mm256_unpacklo_epi32(t4, t5);
      r2 = _mm256_unpackhi_epi32(t0, t1);
      r6 = _mm256_unpackhi_epi32(t4, t5);
      r1 = _mm256_unpacklo_epi32(t2, t3);
      r5 = _mm256_unpacklo_epi32(t6, t7);
      r3 = _mm256_unpackhi_epi32(t2, t3);
      r7 = _mm256_unpackhi_epi32(t6, t7);

      t0 = _mm256_unpacklo_epi64(r0, r1);
      t4 = _mm256_unpacklo_epi64(r4, r5);
      t2 = _mm256_unpackhi_epi64(r0, r1);
      t6 = _mm256_unpackhi_epi64(r4, r5);
      t1 = _mm256_unpacklo_epi64(r2, r3);
      t5 = _mm256_unpacklo_epi64(r6, r7);
      t3 = _mm256_unpackhi_epi64(r2, r3);
      t7 = _mm256_unpackhi_epi64(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around 128-bit lanes.
      r0 = _mm256_permute2x128_si256(t0, t4, 0x20);
      r4 = _mm256_permute2x128_si256(t1, t5, 0x20);
      r1 = _mm256_permute2x128_si256(t0, t4, 0x31);
      r5 = _mm256_permute2x128_si256(t1, t5, 0x31);
      r2 = _mm256_permute2x128_si256(t2, t6, 0x20);
      r6 = _mm256_permute2x128_si256(t3, t7, 0x20);
      r3 = _mm256_permute2x128_si256(t2, t6, 0x31);
      r7 = _mm256_permute2x128_si256(t3, t7, 0x31);

      r0 = _mm256_xor_si256(r0, input_xor_v);
      r1 = _mm256_xor_si256(r1, input_xor_v);
      r2 = _mm256_xor_si256(r2, input_xor_v);
      r3 = _mm256_xor_si256(r3, input_xor_v);
      r4 = _mm256_xor_si256(r4, input_xor_v);
      r5 = _mm256_xor_si256(r5, input_xor_v);
      r6 = _mm256_xor_si256(r6, input_xor_v);
      r7 = _mm256_xor_si256(r7, input_xor_v);

      __m256i sums_4x4_16bit_lo;
      sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
      sums_4x4_16bit_lo = _mm256_add_epi16(
          sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

      // The sums have been performed across columns, and now we have 4x16-bit
      // sums packed together. We use madd for pairwise 32-bit sums.
      const __m256i sums_4x2_32bit_lo_new =
          _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
      sums_4x2_32bit_lo =
          _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

      __m256i sums_4x4_16bit_hi;
      sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
      sums_4x4_16bit_hi = _mm256_add_epi16(
          sums_4x4_16bit_hi,
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));

      const __m256i sums_4x2_32bit_hi_new =
          _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
      sums_4x2_32bit_hi =
          _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
                          r0);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
                          r4);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
                          r1);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
                          r5);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
                          r2);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
                          r6);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
                          r3);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
                          r7);
    }

    packed_ptr += 8 * kNumChunkedSrcRows;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
    const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

    // We earlier used madd for pairwise 32-bit sums, and now we deinterleave
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
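    // idx = (0, 2, 4, 6, 1, 3, 5, 7) in low-to-high element order, so
    // _mm256_permutevar8x32_epi32 gathers the even-indexed 32-bit partial
    // sums into the low 128-bit lane and the odd-indexed ones into the high
    // lane. The permute2x128 ops below then line up the first partial sum of
    // each of the 8 columns in sums_2x4_32bit_a and the second partial sum in
    // sums_2x4_32bit_b, so both can be added into sums.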
    const __m256i sums_2x4_32bit_lo =
        _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx);
    const __m256i sums_2x4_32bit_hi =
        _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx);
    const __m256i sums_2x4_32bit_a =
        _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
    const __m256i sums_2x4_32bit_b =
        _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
    sums = _mm256_add_epi32(sums, sums_adjustment_v);
    sums = _mm256_add_epi32(sums, sums_2x4_32bit_a);
    sums = _mm256_add_epi32(sums, sums_2x4_32bit_b);

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

// Use an AVX2-specific intrinsic for the greater-than comparison.
template <>
inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_cmpgt_epi32(a, b);
}

}  // namespace

void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
                             const std::int8_t* zerobuf, int src_stride,
                             int remaining_src_cols, int src_rows,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx2Fma 8bit");

  using Layout = PackImpl8bitAvx2::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 8 of these chunks at a
  // time, padding short input chunks.
  static constexpr int kNumRowChunks = 8;  // Short input is padded.

  // Each packed block is 4*8, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  Pack8bitColMajorForAvx2Packer(src_ptr, input_xor, zerobuf, src_stride,
                                remaining_src_cols, src_rows, packed_ptr,
                                sums_ptr, trailing_buf);

  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
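  // kChunkedRowMask == 31: a full iteration of the packer covers 32 source
  // rows.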
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of (kChunkedRowMask + 1),
  // there will be data in the trailing buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to the next highest multiple of
    // Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
                              int src_stride, int remaining_src_cols,
                              int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx2Fma float");
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.
  float trailing_buf[(kPackRows - 1) * kPackCols];
  if (remaining_src_cols < 8) {
    memset(trailing_buf, 0, sizeof(trailing_buf));
  }
  PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx2, Path::kAvx2Fma>(
      src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
      trailing_buf);

  const int trailing_rows = src_rows & (kPackRows - 1);
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
           kPackCols * trailing_rows * sizeof(float));
  }
}

void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
                             int src_zero_point, std::int8_t* packed_ptr,
                             int packed_stride, int start_col, int end_col,
                             int src_cols, int block_row, int src_rows,
                             int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 8; col += 8) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x8 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m128i val16_0 = _mm_cvtepi8_epi16(val0);
    __m128i val16_1 = _mm_cvtepi8_epi16(val1);
    __m128i val16_2 = _mm_cvtepi8_epi16(val2);
    __m128i val16_3 = _mm_cvtepi8_epi16(val3);
    __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
                                      _mm_add_epi16(val16_2, val16_3));
    __m256i sum =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
    sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
    // Perform the transposition of 4x4 blocks.
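    // Interleaving bytes of rows (0,1) and of rows (2,3), then interleaving
    // the resulting 16-bit pairs, leaves each column's 4 row values
    // contiguous: t4_val0 holds columns 0-3 and t4_val1 holds columns 4-7.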
    __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
    __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
    __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
    __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
    src_ptr += 8;
    packed_ptr += packed_stride * 8;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy