/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))

void Pack8bitColMajorForAvx(const std::int8_t*, std::int8_t, const std::int8_t*,
                            int, int, int, std::int8_t*, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx(const float*, const float*, int, int, int,
                             float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx(const std::uint8_t*, int, int, std::int8_t*, int,
                            int, int, int, int, int, int, std::int32_t*) {
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx =
    PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t,
             std::int8_t, std::int32_t, Order::kColMajor>;

using PackImplFloatAvx =
    PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
             float, float, Order::kColMajor>;

namespace {
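
// Note: the kAvx path uses plain AVX, which provides 256-bit floating-point
// operations but lacks most 256-bit integer instructions (those arrived with
// AVX2). The helpers below emulate the like-named AVX2 integer intrinsics by
// splitting each __m256i into two 128-bit halves, operating on the halves with
// SSE instructions, and recombining the results with _mm256_set_m128i.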

// Perform the equivalent of mm256_permutevar8x32 with
// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
  // a_lo = 3 2 1 0
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  // a_hi = 7 6 5 4
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  // shuffle a_lo to get 3 1 2 0
  __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
  // shuffle a_hi to get 7 5 6 4
  __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
  __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
  __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
  return _mm256_set_m128i(res_hi, res_lo);
}
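// For example, if the 32-bit elements of `a` are (e7, e6, e5, e4, e3, e2, e1,
// e0) from most to least significant, the result is (e7, e5, e3, e1, e6, e4,
// e2, e0): even-indexed elements gathered into the low 128-bit lane and
// odd-indexed elements into the high lane. This is the de-interleaving applied
// to the madd-produced pairwise sums near the end of the 8-bit packing routine.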

inline __m128i mm256_extracti128_si256(const __m256i& a, const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extractf128_si256(a, 0);
    case 1:
      return _mm256_extractf128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
  // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
}

inline __m256i mm256_cvtepi16_epi32(const __m128i& a) {
  // Take the upper 64 bits of a and put in the first 64 bits of 'hi'
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi16_epi32(hi), _mm_cvtepi16_epi32(a));
}

inline __m256i mm256_xor_si256(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_xor_si128(a_lo, b_lo);
  __m128i hi = _mm_xor_si128(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpacklo_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpacklo_epi32(a_lo, b_lo);
  __m128i hi = _mm_unpacklo_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpacklo_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpacklo_epi64(a_lo, b_lo);
  __m128i hi = _mm_unpacklo_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpackhi_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpackhi_epi32(a_lo, b_lo);
  __m128i hi = _mm_unpackhi_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_unpackhi_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_unpackhi_epi64(a_lo, b_lo);
  __m128i hi = _mm_unpackhi_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi32(a_lo, b_lo);
  __m128i hi = _mm_add_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_add_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi16(a_lo, b_lo);
  __m128i hi = _mm_add_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_madd_epi16(a_lo, b_lo);
  __m128i hi = _mm_madd_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
                                 const int imm) {
  __m128i tmp = _mm_setzero_si128();
  if (!(imm & 8)) {
    switch (imm & 3) {
      case 0:
        return _mm256_extractf128_si256(a, 0);
      case 1:
        return _mm256_extractf128_si256(a, 1);
      case 2:
        return _mm256_extractf128_si256(b, 0);
      case 3:
        return _mm256_extractf128_si256(b, 1);
    }
  }
  return tmp;
}

inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
                                        const int imm) {
  const int lo_imm = imm & 15;
  __m128i lo = mm_permute_helper(a, b, lo_imm);
  const int hi_imm = (imm >> 4) & 15;
  __m128i hi = mm_permute_helper(a, b, hi_imm);
  return _mm256_set_m128i(hi, lo);
}
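
// For reference, the `imm` encoding above matches _mm256_permute2x128_si256:
// bits [1:0] select the source of the result's low 128-bit lane (0 = a.lo,
// 1 = a.hi, 2 = b.lo, 3 = b.hi), bits [5:4] do the same for the high lane, and
// setting bit 3 or bit 7 zeroes the corresponding lane instead. In particular,
// imm = 0x20 gives {low: a.lo, high: b.lo} and imm = 0x31 gives
// {low: a.hi, high: b.hi}, the two combinations used by the packing code below.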

inline void Pack8bitColMajorForAvxPacker(const std::int8_t* src_ptr,
                                         std::int8_t input_xor,
                                         const std::int8_t* zerobuf,
                                         int src_stride, int remaining_src_cols,
                                         int src_rows, std::int8_t* packed_ptr,
                                         std::int32_t* sums_ptr,
                                         std::int8_t* trailing_buf) {
  using Layout = PackImpl8bitAvx::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);
  // Each group of Layout::kRows (4) source elements is contiguous in the input
  // and stays contiguous in the packed output. We process 8 of these chunks at
  // a time, padding short input chunks.
  constexpr int kNumRowChunks = 8;
  constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
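
  // With this layout, each iteration of the main loop below consumes up to
  // kNumChunkedSrcRows (32) rows from each of the 8 source columns and emits
  // Layout::kCols * kNumChunkedSrcRows = 256 bytes of packed data, which is
  // why packed_ptr advances by 8 * kNumChunkedSrcRows per iteration.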

  const std::int8_t* src_ptr0 = src_ptr;
  const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
  const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
  const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
  const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
  const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
  const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
  const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = kNumChunkedSrcRows;
  std::int64_t src_inc1 = kNumChunkedSrcRows;
  std::int64_t src_inc2 = kNumChunkedSrcRows;
  std::int64_t src_inc3 = kNumChunkedSrcRows;
  std::int64_t src_inc4 = kNumChunkedSrcRows;
  std::int64_t src_inc5 = kNumChunkedSrcRows;
  std::int64_t src_inc6 = kNumChunkedSrcRows;
  std::int64_t src_inc7 = kNumChunkedSrcRows;
  // Handle cases where source does not have Layout::kCols (8) columns.
  if (remaining_src_cols < 8) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  const std::int8_t zero_point = zerobuf[0];

  if (sums_ptr) {
    // i: Layout::kCols.
    for (int i = 0; i < 8; ++i) {
      sums_ptr[i] = 0;
    }
  }
  std::int32_t sums_adjustment = 0;
  const __m256i ones_16bit = _mm256_set1_epi16(1);
  __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
  __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);

  // The overall packing effectively pads the source rows up to a multiple of
  // kNumChunkedSrcRows (32). Whenever the last chunk of a column is
  // incomplete, it is stored into trailing_buf instead of packed_ptr, and the
  // caller copies the valid portion back into the packed buffer afterwards.
  for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
    // Available source rows.
    // If this is less than kNumChunkedSrcRows, this chunk is incomplete and is
    // written to trailing_buf rather than packed_ptr.
    const int available_src_rows = src_rows - k;
    // Effectively,
    // available rows = std::max(0, std::min(kNumChunkedSrcRows, src_rows - k));
    // treat each case separately.
    if (available_src_rows >= kNumChunkedSrcRows) {
      if (sums_ptr) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = mm256_unpacklo_epi32(t0, t1);
        r4 = mm256_unpacklo_epi32(t4, t5);
        r2 = mm256_unpackhi_epi32(t0, t1);
        r6 = mm256_unpackhi_epi32(t4, t5);
        r1 = mm256_unpacklo_epi32(t2, t3);
        r5 = mm256_unpacklo_epi32(t6, t7);
        r3 = mm256_unpackhi_epi32(t2, t3);
        r7 = mm256_unpackhi_epi32(t6, t7);

        t0 = mm256_unpacklo_epi64(r0, r1);
        t4 = mm256_unpacklo_epi64(r4, r5);
        t2 = mm256_unpackhi_epi64(r0, r1);
        t6 = mm256_unpackhi_epi64(r4, r5);
        t1 = mm256_unpacklo_epi64(r2, r3);
        t5 = mm256_unpacklo_epi64(r6, r7);
        t3 = mm256_unpackhi_epi64(r2, r3);
        t7 = mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128-bit), operating *between* AVX lanes.
        // For instance (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around 128-bit
        // lanes.
        r0 = mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = mm256_xor_si256(r0, input_xor_v);
        r1 = mm256_xor_si256(r1, input_xor_v);
        r2 = mm256_xor_si256(r2, input_xor_v);
        r3 = mm256_xor_si256(r3, input_xor_v);
        r4 = mm256_xor_si256(r4, input_xor_v);
        r5 = mm256_xor_si256(r5, input_xor_v);
        r6 = mm256_xor_si256(r6, input_xor_v);
        r7 = mm256_xor_si256(r7, input_xor_v);

        __m256i sums_4x4_16bit_lo;
        sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
        sums_4x4_16bit_lo = mm256_add_epi16(
            sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

        // The sums have been performed across columns, and now we have 4x16-bit
        // sums packed together. We use madd for pairwise 32-bit sums.
        const __m256i sums_4x2_32bit_lo_new =
            mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
        sums_4x2_32bit_lo =
            mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

        __m256i sums_4x4_16bit_hi;
        sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
        sums_4x4_16bit_hi = mm256_add_epi16(
            sums_4x4_16bit_hi,
            mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));

        const __m256i sums_4x2_32bit_hi_new =
            mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
        sums_4x2_32bit_hi =
            mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      } else {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        __m256i r0, r1, r2, r3, r4, r5, r6, r7;
        const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

        t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
        t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
        t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
        t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5));
        t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2));
        t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6));
        t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3));
        t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7));

        r0 = mm256_unpacklo_epi32(t0, t1);
        r4 = mm256_unpacklo_epi32(t4, t5);
        r2 = mm256_unpackhi_epi32(t0, t1);
        r6 = mm256_unpackhi_epi32(t4, t5);
        r1 = mm256_unpacklo_epi32(t2, t3);
        r5 = mm256_unpacklo_epi32(t6, t7);
        r3 = mm256_unpackhi_epi32(t2, t3);
        r7 = mm256_unpackhi_epi32(t6, t7);

        t0 = mm256_unpacklo_epi64(r0, r1);
        t4 = mm256_unpacklo_epi64(r4, r5);
        t2 = mm256_unpackhi_epi64(r0, r1);
        t6 = mm256_unpackhi_epi64(r4, r5);
        t1 = mm256_unpacklo_epi64(r2, r3);
        t5 = mm256_unpacklo_epi64(r6, r7);
        t3 = mm256_unpackhi_epi64(r2, r3);
        t7 = mm256_unpackhi_epi64(r6, r7);

        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128-bit), operating *between* AVX lanes.
        // For instance (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around 128-bit
        // lanes.
        r0 = mm256_permute2x128_si256(t0, t4, 0x20);
        r4 = mm256_permute2x128_si256(t1, t5, 0x20);
        r1 = mm256_permute2x128_si256(t0, t4, 0x31);
        r5 = mm256_permute2x128_si256(t1, t5, 0x31);
        r2 = mm256_permute2x128_si256(t2, t6, 0x20);
        r6 = mm256_permute2x128_si256(t3, t7, 0x20);
        r3 = mm256_permute2x128_si256(t2, t6, 0x31);
        r7 = mm256_permute2x128_si256(t3, t7, 0x31);

        r0 = mm256_xor_si256(r0, input_xor_v);
        r1 = mm256_xor_si256(r1, input_xor_v);
        r2 = mm256_xor_si256(r2, input_xor_v);
        r3 = mm256_xor_si256(r3, input_xor_v);
        r4 = mm256_xor_si256(r4, input_xor_v);
        r5 = mm256_xor_si256(r5, input_xor_v);
        r6 = mm256_xor_si256(r6, input_xor_v);
        r7 = mm256_xor_si256(r7, input_xor_v);

        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4),
                            r0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4),
                            r4);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4),
                            r1);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4),
                            r5);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4),
                            r2);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4),
                            r6);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4),
                            r3);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4),
                            r7);
      }
    } else if (available_src_rows > 0) {
      RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
      // We do not care what goes into the trailing buffer, but we want
      // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
      //
      // We compensate for padding-with-zero_point by initializing the
      // summations with the compensating offset, effectively
      // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) *
      //     4 * (8 - ((available_src_rows + 3) >> 2)).
      //
      // Note that (zero_point ^ input_xor) is performed in 8-bits and then
      // cast.
      sums_adjustment +=
          -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2));
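      // As a concrete example: if available_src_rows is 20, then
      // (available_src_rows + 3) >> 2 == 5 of the 8 four-row groups per column
      // hold real data, so 3 groups (12 elements per column) are padding. Each
      // padded element contributes (zero_point ^ input_xor) to the column sum,
      // and the adjustment above subtracts exactly 4 * 3 such contributions.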

      __m256i t0, t1, t2, t3, t4, t5, t6, t7;
      __m256i r0, r1, r2, r3, r4, r5, r6, r7;
      const __m256i input_xor_v = _mm256_set1_epi8(input_xor);

      t0 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr0);
      t4 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr4);
      t1 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr1);
      t5 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr5);
      t2 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr2);
      t6 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr6);
      t3 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr3);
      t7 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr7);

      r0 = mm256_unpacklo_epi32(t0, t1);
      r4 = mm256_unpacklo_epi32(t4, t5);
      r2 = mm256_unpackhi_epi32(t0, t1);
      r6 = mm256_unpackhi_epi32(t4, t5);
      r1 = mm256_unpacklo_epi32(t2, t3);
      r5 = mm256_unpacklo_epi32(t6, t7);
      r3 = mm256_unpackhi_epi32(t2, t3);
      r7 = mm256_unpackhi_epi32(t6, t7);

      t0 = mm256_unpacklo_epi64(r0, r1);
      t4 = mm256_unpacklo_epi64(r4, r5);
      t2 = mm256_unpackhi_epi64(r0, r1);
      t6 = mm256_unpackhi_epi64(r4, r5);
      t1 = mm256_unpacklo_epi64(r2, r3);
      t5 = mm256_unpacklo_epi64(r6, r7);
      t3 = mm256_unpackhi_epi64(r2, r3);
      t7 = mm256_unpackhi_epi64(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around 128-bit lanes.
      r0 = mm256_permute2x128_si256(t0, t4, 0x20);
      r4 = mm256_permute2x128_si256(t1, t5, 0x20);
      r1 = mm256_permute2x128_si256(t0, t4, 0x31);
      r5 = mm256_permute2x128_si256(t1, t5, 0x31);
      r2 = mm256_permute2x128_si256(t2, t6, 0x20);
      r6 = mm256_permute2x128_si256(t3, t7, 0x20);
      r3 = mm256_permute2x128_si256(t2, t6, 0x31);
      r7 = mm256_permute2x128_si256(t3, t7, 0x31);

      r0 = mm256_xor_si256(r0, input_xor_v);
      r1 = mm256_xor_si256(r1, input_xor_v);
      r2 = mm256_xor_si256(r2, input_xor_v);
      r3 = mm256_xor_si256(r3, input_xor_v);
      r4 = mm256_xor_si256(r4, input_xor_v);
      r5 = mm256_xor_si256(r5, input_xor_v);
      r6 = mm256_xor_si256(r6, input_xor_v);
      r7 = mm256_xor_si256(r7, input_xor_v);

      __m256i sums_4x4_16bit_lo;
      sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6)));
      sums_4x4_16bit_lo = mm256_add_epi16(
          sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7)));

      // The sums have been performed across columns, and now we have 4x16-bit
      // sums packed together. We use madd for pairwise 32-bit sums.
      const __m256i sums_4x2_32bit_lo_new =
          mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
      sums_4x2_32bit_lo =
          mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);

      __m256i sums_4x4_16bit_hi;
      sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1)));
      sums_4x4_16bit_hi =
          mm256_add_epi16(sums_4x4_16bit_hi,
                          mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1)));

      const __m256i sums_4x2_32bit_hi_new =
          mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
      sums_4x2_32bit_hi =
          mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);

      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
                          r0);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
                          r4);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
                          r1);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
                          r5);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
                          r2);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
                          r6);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
                          r3);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
                          r7);
    }

    packed_ptr += 8 * kNumChunkedSrcRows;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }

  if (sums_ptr) {
    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);

    __m256i sums =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));

    // We earlier used madd for pairwise 32-bit sums, and now we de-interleave
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
    const __m256i sums_2x4_32bit_lo = PermuteEpi32EvenOdds(sums_4x2_32bit_lo);
    const __m256i sums_2x4_32bit_hi = PermuteEpi32EvenOdds(sums_4x2_32bit_hi);
    const __m256i sums_2x4_32bit_a =
        mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20);
    const __m256i sums_2x4_32bit_b =
        mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31);
    sums = mm256_add_epi32(sums, sums_adjustment_v);
    sums = mm256_add_epi32(sums, sums_2x4_32bit_a);
    sums = mm256_add_epi32(sums, sums_2x4_32bit_b);

    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
  }
}

// Use a generic AVX intrinsic for greater-than comparison.
template <>
inline __m256i CompareGreaterThan<Path::kAvx>(const __m256i& a,
                                              const __m256i& b) {
  constexpr int kGreaterThanSignalling = 14;
  const __m256 v = _mm256_cmp_ps(_mm256_cvtepi32_ps(a), _mm256_cvtepi32_ps(b),
                                 kGreaterThanSignalling);
  return _mm256_cvtps_epi32(v);
}

}  // namespace.

void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
                            const std::int8_t* zerobuf, int src_stride,
                            int remaining_src_cols, int src_rows,
                            std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
  profiler::ScopeLabel label("Pack kAvx 8bit");

  using Layout = PackImpl8bitAvx::Layout;
  RUY_DCHECK_EQ(Layout::kCols, 8);
  RUY_DCHECK_EQ(Layout::kRows, 4);

  // Each group of Layout::kRows (4) source elements is contiguous in the input
  // and stays contiguous in the packed output. We process 8 of these chunks at
  // a time, padding short input chunks.
  static constexpr int kNumRowChunks = 8;  // Short input is padded.

  // Each packed block is 4*8, and there are normally 8. The trailing block is
  // only slightly shorter.
  constexpr int kTrailingBufSize =
      kNumRowChunks * Layout::kCols * Layout::kRows;
  std::int8_t trailing_buf[kTrailingBufSize];
  memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));

  Pack8bitColMajorForAvxPacker(src_ptr, input_xor, zerobuf, src_stride,
                               remaining_src_cols, src_rows, packed_ptr,
                               sums_ptr, trailing_buf);

  constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
  const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
  // If the number of source rows is not a multiple of kNumRowChunks *
  // Layout::kRows, there is data in the trailing buffer that must be copied
  // into the packed buffer.
  if (trailing_data) {
    const int non_trailing_rows = src_rows & ~kChunkedRowMask;
    // Destination "rows" are padded to the next multiple of Layout::kRows.
    const int dst_rows = (src_rows + 3) & ~3;
    const int trailing_rows = dst_rows - non_trailing_rows;
    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
           Layout::kCols * trailing_rows * sizeof(std::int8_t));
  }
}

void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
                             int src_stride, int remaining_src_cols,
                             int src_rows, float* packed_ptr) {
  profiler::ScopeLabel label("Pack kAvx float");
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.
  float trailing_buf[(kPackRows - 1) * kPackCols];
  if (remaining_src_cols < 8) {
    memset(trailing_buf, 0, sizeof(trailing_buf));
  }
  PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx, Path::kAvx>(
      src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr,
      trailing_buf);

  const int trailing_rows = src_rows & (kPackRows - 1);
  if (trailing_rows > 0) {
    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
           kPackCols * trailing_rows * sizeof(float));
  }
}

void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
                            int src_zero_point, std::int8_t* packed_ptr,
                            int packed_stride, int start_col, int end_col,
                            int src_cols, int block_row, int src_rows,
                            int input_xor, std::int32_t* sums) {
  int col = start_col;
  int src_end_col = std::min(end_col, src_cols);

  for (; col <= src_end_col - 8; col += 8) {
    std::int8_t* dst_ptr = packed_ptr;
    __m128i val0, val1, val2, val3;
    __m128i input_xor_dup = _mm_set1_epi8(input_xor);
    // Load a 4x8 block.
    if (block_row + 4 <= src_rows) {
      val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    } else {
      val0 = _mm_set1_epi8(src_zero_point);
      val1 = val0;
      val2 = val0;
      val3 = val0;
      if (block_row + 0 < src_rows)
        val0 = _mm_loadu_si64(src_ptr + 0 * src_stride);
      if (block_row + 1 < src_rows)
        val1 = _mm_loadu_si64(src_ptr + 1 * src_stride);
      if (block_row + 2 < src_rows)
        val2 = _mm_loadu_si64(src_ptr + 2 * src_stride);
      if (block_row + 3 < src_rows)
        val3 = _mm_loadu_si64(src_ptr + 3 * src_stride);
    }
    // Maybe xor the sign bit to convert from uint8 to int8.
    val0 = _mm_xor_si128(val0, input_xor_dup);
    val1 = _mm_xor_si128(val1, input_xor_dup);
    val2 = _mm_xor_si128(val2, input_xor_dup);
    val3 = _mm_xor_si128(val3, input_xor_dup);
    // Update the sums.
    __m128i val16_0 = _mm_cvtepi8_epi16(val0);
    __m128i val16_1 = _mm_cvtepi8_epi16(val1);
    __m128i val16_2 = _mm_cvtepi8_epi16(val2);
    __m128i val16_3 = _mm_cvtepi8_epi16(val3);
    __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1),
                                      _mm_add_epi16(val16_2, val16_3));
    __m256i sum =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col));
    sum = mm256_add_epi32(sum, mm256_cvtepi16_epi32(new_sum16));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum);
    // Perform the transposition of 4x4 blocks.
    __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1);
    __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3);
    __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1);
    __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1);
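    // After the two unpack steps, t4_val0 holds columns 0..3 of this 4x8 block
    // and t4_val1 holds columns 4..7, with each column's four row values laid
    // out as 4 consecutive bytes, matching the 4x8 column-major cell layout
    // used for the packed matrix.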
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1);
    src_ptr += 8;
    packed_ptr += packed_stride * 8;
  }
  for (; col < src_end_col; col++) {
    std::int32_t accum = 0;
    for (int r = 0; r < 4; r++) {
      std::int8_t packed_val;
      if (block_row + r < src_rows) {
        packed_val = input_xor ^ src_ptr[r * src_stride];
      } else {
        packed_val = input_xor ^ src_zero_point;
      }
      accum += packed_val;
      *packed_ptr++ = packed_val;
    }
    if (sums) {
      sums[col] += accum;
    }
    src_ptr++;
  }
  for (; col < end_col; col++) {
    std::memset(packed_ptr, 0, 4);
    packed_ptr += 4;
  }
}

#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

}  // namespace ruy