/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX && RUY_OPT(ASM))

void Kernel8bitAvx(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

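// Block-size constants for the 8-bit kernels below: each kernel invocation
// produces an 8x8 block of destination values, and each pass of the inner
// loop consumes 4 levels of depth from the packed LHS/RHS data.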
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;

namespace {
namespace intrin_utils {

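// A note on the helpers below: the kAvx path targets CPUs with AVX but not
// AVX2, so 256-bit integer operations are generally unavailable. Each
// mm256_* helper therefore emulates the corresponding 256-bit integer
// operation by splitting its operands into 128-bit halves with
// _mm256_extractf128_si256, processing the halves with SSE4 integer
// instructions, and recombining them with _mm256_set_m128i.
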
template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a,
                                              const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo);
  __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi);
  return _mm256_set_m128i(dst_hi, dst_lo);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
                                                   const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extractf128_si256(a, 0);
    case 1:
      return _mm256_extractf128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

template <Path path>
inline __m256i mm256_cvtepi8_epi16(const __m128i& a) {
  // Take the upper 64 bits of a and put them in the first 64 bits of 'hi'.
  __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128());
  return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a));
}

template <Path path>
inline __m256i mm256_cvtepi32_epi64(const __m128i& a) {
  // Sign-extend the 32-bit values in the lower 64 bits of a.
  __m128i lo = _mm_cvtepi32_epi64(a);
  __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128()));
  return _mm256_set_m128i(hi, lo);
}

inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b,
                                 const int imm) {
  __m128i tmp = _mm_setzero_si128();
  if (!(imm & 8)) {
    switch (imm & 3) {
      case 0:
        return _mm256_extractf128_si256(a, 0);
      case 1:
        return _mm256_extractf128_si256(a, 1);
      case 2:
        return _mm256_extractf128_si256(b, 0);
      case 3:
        return _mm256_extractf128_si256(b, 1);
    }
  }
  return tmp;
}
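
// Note on mm_permute_helper above: the 4-bit control follows the per-lane
// fields of _mm256_permute2f128_si256. If bit 3 is set the lane is zeroed;
// otherwise the low two bits select one of the four 128-bit source lanes
// (a.lo, a.hi, b.lo, b.hi).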

template <Path path>
inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b,
                                        const int imm) {
  const int lo_imm = imm & 15;
  __m128i lo = mm_permute_helper(a, b, lo_imm);
  const int hi_imm = (imm >> 4) & 15;
  __m128i hi = mm_permute_helper(a, b, hi_imm);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_max_epi32(a_lo, b_lo);
  __m128i hi = _mm_max_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_min_epi32(a_lo, b_lo);
  __m128i hi = _mm_min_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi32(a_lo, b_lo);
  __m128i hi = _mm_add_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_add_epi64(a_lo, b_lo);
  __m128i hi = _mm_add_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_slli_epi64(const __m256i& a, int imm) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i lo = _mm_slli_epi64(a_lo, imm);
  __m128i hi = _mm_slli_epi64(a_hi, imm);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_mullo_epi32(a_lo, b_lo);
  __m128i hi = _mm_mullo_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

// Defined as a macro since `imm` must be an immediate.
#define BlendM128_epi32(a, b, imm) \
  _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))

// Defined as a macro since `imm` must be an immediate.
#define BlendM128_epi64(a, b, imm) \
  _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
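
// In both blend macros above, bit i of `imm` selects 32-bit (respectively
// 64-bit) element i of the result from `b` when set and from `a` when
// cleared. For example, BlendM128_epi32(a, b, 0x2) keeps elements 0, 2 and 3
// from a and takes element 1 from b; this is how the per-element selections
// in mm256_srav_epi32 below are assembled.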

// Defined as a macro since `imm` must be an immediate.
#define mm256_blend_epi32(ans, a, b, imm)              \
  __m128i a_lo = _mm256_extractf128_si256(a, 0);       \
  __m128i a_hi = _mm256_extractf128_si256(a, 1);       \
  __m128i b_lo = _mm256_extractf128_si256(b, 0);       \
  __m128i b_hi = _mm256_extractf128_si256(b, 1);       \
  __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \
  __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4);  \
  ans = _mm256_set_m128i(hi, lo);

#define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm)   \
  a_lo = _mm256_extractf128_si256(a, 0);               \
  a_hi = _mm256_extractf128_si256(a, 1);               \
  ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \
                         _mm_shuffle_epi32(a_lo, imm));
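
// Note that mm256_shuffle_epi32 writes through `a_lo`/`a_hi`, which are
// caller-provided __m128i scratch variables. This lets the 8-bit kernel below
// reuse the same scratch variables across many expansions of the macro within
// a single scope.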

template <Path path>
inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_madd_epi16(a_lo, b_lo);
  __m128i hi = _mm_madd_epi16(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) {
  // Shift both elements of a by the lower 64 bits of b.
  __m128i res_lo = _mm_srl_epi64(a, b);
  // Shift both elements of a by the upper 64 bits of b.
  __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
  __m128i res_hi = _mm_srl_epi64(a, hi_count);
  // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
  // 1. Swap the upper and lower 64 bits of res_hi.
  __m128i tmp_hi =
      _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
  // 2. Take the lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
  return _mm_unpacklo_epi64(res_lo, tmp_hi);
}

template <Path path>
inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_srlv_epi64(a_lo, b_lo);
  __m128i hi = mm_srlv_epi64(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) {
  // Shift both elements of a by the lower 64 bits of b.
  __m128i res_lo = _mm_sll_epi64(a, b);
  // Shift both elements of a by the upper 64 bits of b.
  __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128());
  __m128i res_hi = _mm_sll_epi64(a, hi_count);
  // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
  // 1. Swap the upper and lower 64 bits of res_hi.
  __m128i tmp_hi =
      _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1));
  // 2. Take the lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
  return _mm_unpacklo_epi64(res_lo, tmp_hi);
}

template <Path path>
inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo);
  __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

#define PermuteM128_epi32(a, imm) \
  _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));

inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
  // Shift all elements of a by the first 32 bits of b.
  __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));

  // Put bits 32-63 of b in the first slot.
  __m128i tmp1 = PermuteM128_epi32(b, 1);
  // Put bits 32-63 of a in the first slot.
  __m128i a1 = PermuteM128_epi32(a, 1);
  // Shift all elements of a by the second 32 bits of b.
  __m128i res1 =
      _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));

  // Put bits 64-95 of b in the first slot.
  __m128i tmp2 = PermuteM128_epi32(b, 2);
  // Shift all elements of a by the third 32 bits of b.
  __m128i res2 =
      _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));

  // Put bits 96-127 of b in the first slot.
  __m128i tmp3 = PermuteM128_epi32(b, 3);
  // Put bits 96-127 of a in the third slot.
  __m128i a3 = PermuteM128_epi32(a, 48);
  // Shift all elements of a3 by the fourth 32 bits of b.
  __m128i res3 =
      _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));

  // Take bits 0-31 of res0, bits 0-31 of res1,
  // bits 64-95 of res2, and bits 64-95 of res3.
  // res0 _ _ _ 0
  // res1 _ _ _ 1
  // res2 _ 2 _ _
  // res3 _ 3 _ _
  // f_01 _ _ 1 0
  // f_23 _ _ 3 2
  __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
  __m128i f_23 = _mm_unpackhi_epi32(res2, res3);
  // The lower 64 bits of f_01 and the lower 64 bits of f_23.
  return _mm_unpacklo_epi64(f_01, f_23);
}
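
// mm_sllv_epi32 above emulates a per-element variable shift (in the spirit of
// AVX2's _mm_sllv_epi32): element i of the result is a[i] << b[i]. Plain AVX
// only provides _mm_sll_epi32, which shifts every element by one shared
// count, so the helper materializes each per-element count in the low slot
// and then blends the four partial results back together.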

template <Path path>
inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = mm_sllv_epi32(a_lo, b_lo);
  __m128i hi = mm_sllv_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_sub_epi32(a_lo, b_lo);
  __m128i hi = _mm_sub_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_mul_epi32(a_lo, b_lo);
  __m128i hi = _mm_mul_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

// Perform the equivalent of mm256_permutevar8x32 with
// a second argument of {7, 5, 3, 1, 6, 4, 2, 0}
template <Path path>
inline __m256i PermuteEpi32EvenOdds(const __m256i& a) {
  // a_lo = 3 2 1 0
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  // a_hi = 7 6 5 4
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  // shuffle a_lo to get 3 1 2 0
  __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8);
  // shuffle a_hi to get 7 5 6 4
  __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8);
  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
  __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi);
  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
  __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi);
  return _mm256_set_m128i(res_hi, res_lo);
}

template <Path path>
inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
  const __m256i bias0 = _mm256_set1_epi32(*(bias + offset));
  return mm256_add_epi32<path>(a, bias0);
}

__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

template <Path path>
inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);
  __m128i b_lo = _mm256_extractf128_si256(b, 0);
  __m128i b_hi = _mm256_extractf128_si256(b, 1);
  __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
  __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
  return _mm256_set_m128i(hi, lo);
}

template <Path path>
inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
  __m128i a_lo = _mm256_extractf128_si256(a, 0);
  __m128i a_hi = _mm256_extractf128_si256(a, 1);

  __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
  __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
  __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
  __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
  __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
  __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
  __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
  __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));

  // get element 0 from r0, element 1 from r1
  __m128i r01 = BlendM128_epi32(r0, r1, 2);
  // get element 2 from r2, element 3 from r3
  __m128i r23 = BlendM128_epi32(r2, r3, 8);
  // get element 0 from r4, element 1 from r5
  __m128i r45 = BlendM128_epi32(r4, r5, 2);
  // get element 2 from r6, element 3 from r7
  __m128i r67 = BlendM128_epi32(r6, r7, 8);
  // get lower 64 bits of r01, upper 64 bits of r23
  __m128i r0123 = BlendM128_epi64(r01, r23, 2);
  // get lower 64 bits of r45, upper 64 bits of r67
  __m128i r4567 = BlendM128_epi64(r45, r67, 2);
  return _mm256_set_m128i(r4567, r0123);
}

// AVX doesn't have fused multiply-add so we define an inline function to be
// used in the common code following.
template <>
inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b,
                                 const __m256& c) {
  const __m256 prod = _mm256_mul_ps(a, b);
  return _mm256_add_ps(prod, c);
}
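
// This MulAdd specialization is presumably what the shared float kernels
// (KernelFloatAvxCommon / KernelFloatAvxCommonSingleCol, invoked at the
// bottom of this file) pick up on the kAvx path, where true FMA instructions
// are not available.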

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx 8-bit");
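  // Within each 128-bit lane, the byte shuffle driven by splitter_idx_data
  // gathers bytes 0, 1, 4, 5, 8, 9, 12, 13 into the low half and bytes
  // 2, 3, 6, 7, 10, 11, 14, 15 into the high half. The kernel then
  // sign-extends each half to 16 bits so that LHS/RHS products can be
  // accumulated with 16-bit multiply-add (madd) instructions further below.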
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const std::int8_t* rhs_col_ptr =
      static_cast<const int8_t*>(params.rhs_base_ptr);
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

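    // Zero-point corrections, applied ahead of the inner loop: when the LHS
    // zero point is nonzero, each destination column needs an adjustment of
    // lhs_zero_point * (sum of the corresponding RHS column), which is what
    // rhs_sums_offsets holds below. The symmetric lhs_sums / rhs_zero_point
    // adjustment is applied per row further down.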
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data will be used to initialize each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth);
      __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data_lo = _mm_add_epi32(
            initial_accum_data_lo,
            _mm_loadu_si128(
                reinterpret_cast<const __m128i*>(params.bias + row)));
        initial_accum_data_hi = _mm_add_epi32(
            initial_accum_data_hi,
            _mm_loadu_si128(
                reinterpret_cast<const __m128i*>(params.bias + row + 4)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point);
        const __m128i lhs_sums_offset_lo = _mm_mullo_epi32(
            rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
                        &params.lhs_sums[row])));
        const __m128i lhs_sums_offset_hi = _mm_mullo_epi32(
            rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>(
                        &params.lhs_sums[row + 4])));

        initial_accum_data_lo =
            _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo);
        initial_accum_data_hi =
            _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        __m256i initial_accum_data =
            _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);

        accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = intrin_utils::mm256_sub_epi32<path>(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        __m256i initial_accum_data =
            _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo);
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0,
                                                         params.bias + col, 0);
        accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1,
                                                         params.bias + col, 1);
        accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2,
                                                         params.bias + col, 2);
        accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3,
                                                         params.bias + col, 3);
        accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4,
                                                         params.bias + col, 4);
        accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5,
                                                         params.bias + col, 5);
        accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6,
                                                         params.bias + col, 6);
        accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7,
                                                         params.bias + col, 7);
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const std::int8_t* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data[16];
        const __m128i rhs_data_bottom_lane =
            _mm256_castsi256_si128(rhs_data_8bit);
        const __m128i rhs_data_top_lane =
            _mm256_extractf128_si256(rhs_data_8bit, 1);
        const __m256i rhs_16_bit_dup_low =
            intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane);
        const __m256i rhs_16_bit_dup_high =
            intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
                            rhs_16_bit_dup_low);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
                            rhs_16_bit_dup_high);

        // NOTE: There may be opportunities for permuting the data in the
        // packing code instead of here.
        const __m256i lhs_data_split =
            intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            intrin_utils::mm256_cvtepi8_epi16<path>(
                _mm256_extractf128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            intrin_utils::mm256_cvtepi8_epi16<path>(
                _mm256_extractf128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low =
            intrin_utils::mm256_permute2x128_si256<path>(
                lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high =
            intrin_utils::mm256_permute2x128_si256<path>(
                lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          // Perform mul-adds on low and high components of accum separately.
          __m128i accum_lo = _mm256_extractf128_si256(accum, 0);
          __m128i accum_hi = _mm256_extractf128_si256(accum, 1);

          __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0);
          __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1);
          __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0);
          __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1);
          __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0);
          __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1);

          accum_lo = _mm_add_epi32(accum_lo, lo_0);
          accum_hi = _mm_add_epi32(accum_hi, lo_1);

          __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0);
          __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1);
          __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0);
          __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1);
          __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0);
          __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1);

          accum_lo = _mm_add_epi32(accum_lo, hi_0);
          accum_hi = _mm_add_epi32(accum_hi, hi_1);
          accum = _mm256_set_m128i(accum_hi, accum_lo);
        };
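        // Each mm256_shuffle_epi32 below broadcasts one 32-bit RHS entry (a
        // pair of 16-bit values) within each lane: shuffle immediates 0,
        // 0x55, 0xaa and 0xff replicate elements 0, 1, 2 and 3 respectively.
        // Combined with the rhs0_3/rhs4_7/rhs8_11/rhs12_15 lane duplicates
        // above, this feeds process_column the RHS pair belonging to each of
        // the 8 destination columns.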
        __m256i tmp0, tmp1, tmp2, tmp3;
        __m128i lo0, lo1, hi0, hi1;
        mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0);
        mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa);
        mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
            _mm256_extractf128_si256(m_vector, 0));
        const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
            _mm256_extractf128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift =
            intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
        const __m256i neg_e_vector =
            intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
        const __m256i right_shift =
            intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low =
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high =
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
            intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we applied
        // offsets before (convert_to_unsigned_64) and after
        // (convert_to_signed_halved).
        //
        // The overall process is, for 64-bit scaled accumulator:
        // unsigned_accum = signed_accum + 1 << 63;
        // unsigned_accum = (unsigned_accum >> right_shift) >> 31;
        // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
        //                                                       post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86 that
        // we prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicate it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower performance
        // than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 accumulators block. Will be un-transposed below
          // after the multiplier implementation.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
          const __m256i one_shift_exp_minus1 =
              intrin_utils::mm256_sllv_epi32<path>(
                  _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                            exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge =
              intrin_utils::mm256_add_epi32<path>(results, nudge);
          const __m256i shifted_sum =
              intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp =
              intrin_utils::mm256_sllv_epi32<path>(
                  _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                            _mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow =
              intrin_utils::mm256_cmpgt_epi32<path>(
                  results, intrin_utils::mm256_sub_epi32<path>(
                               _mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };
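
        // Taken together, apply_multiplier below computes, in effect, for
        // each 32-bit accumulator lane with fixed-point multiplier m and
        // exponent e:
        //
        //   x = accum << max(e, 0)                  // left_shift
        //   x = (x * m + (1 << 30)) >> 31           // rounded Q31 multiply
        //   x = RoundingRightShift(x, max(-e, 0))   // right_shift, above
        //   out = x + dst_zero_point
        //
        // As a small worked example of the rounding right shift alone:
        // results = 5 with exponent = 1 gives nudge = 1 << 0 = 1, and
        // (5 + 1) >> 1 = 3, i.e. 2.5 rounds up to 3.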

        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum =
              intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
              intrin_utils::mm256_cvtepi32_epi64<path>(
                  _mm256_extractf128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
              intrin_utils::mm256_cvtepi32_epi64<path>(
                  _mm256_extractf128_si256(shifted_accum, 1)),
              m_64bit_high);
          scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
                                                             offset_vector);
          scaled_v_high = intrin_utils::mm256_add_epi64<path>(
              scaled_v_high, offset_vector);

          scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
              scaled_v_low, final_right_shift_low);
          scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
              scaled_v_high, final_right_shift_high);
          scaled_v_high =
              intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
          __m256i results;
          mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
          // Permute results to this ordering of int32 elements
          // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
          results = intrin_utils::PermuteEpi32EvenOdds<path>(results);

          rounding_right_shift(results, right_shift);
          accum =
              intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
        };
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v);
          accum_data_v0 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v);
          accum_data_v1 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v);
          accum_data_v1 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v);
          accum_data_v2 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v);
          accum_data_v2 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v);
          accum_data_v3 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v);
          accum_data_v3 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v);
          accum_data_v4 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v);
          accum_data_v4 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v);
          accum_data_v5 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v);
          accum_data_v5 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v);
          accum_data_v6 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v);
          accum_data_v6 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v);
          accum_data_v7 =
              intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v);
          accum_data_v7 =
              intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
            result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvxImpl<Path::kAvx>(params);
}

template <Path path>
void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const std::int8_t* rhs_col_ptr =
      static_cast<const int8_t*>(params.rhs_base_ptr);
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx = _mm256_loadu_si256(
        reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data = intrin_utils::mm256_sub_epi32<path>(
          initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = intrin_utils::mm256_add_epi32<path>(
          initial_accum_data, _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = intrin_utils::mm256_sub_epi32<path>(
          initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const std::int8_t* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need and process twice the
      // data that we need and store only the data we need.
      std::int32_t rhs_data[2];
      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
      // Now that we have cast the RHS data, we store it so that each value
      // can be separately loaded in the accumulation loop.
      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          intrin_utils::mm256_cvtepi8_epi16<path>(
              _mm256_extractf128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          intrin_utils::mm256_cvtepi8_epi16<path>(
              _mm256_extractf128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low =
          intrin_utils::mm256_permute2x128_si256<path>(
              lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high =
          intrin_utils::mm256_permute2x128_si256<path>(
              lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
          accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
                             lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = intrin_utils::mm256_add_epi32<path>(
          accum_data_v0, intrin_utils::mm256_madd_epi16<path>(
                             lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
    }

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));

      const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>(
          _mm256_extractf128_si256(m_vector, 0));
      const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>(
          _mm256_extractf128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift =
          intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector);
      const __m256i neg_e_vector =
          intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
      const __m256i right_shift =
          intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          intrin_utils::mm256_cvtepi32_epi64<path>(
              _mm256_extractf128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          intrin_utils::mm256_cvtepi32_epi64<path>(
              _mm256_extractf128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();

      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
          intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum =
            intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>(
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>(
            intrin_utils::mm256_cvtepi32_epi64<path>(
                _mm256_extractf128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
                                                           offset_vector);
        scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
                                                            offset_vector);

        scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
            scaled_v_low, final_right_shift_low);
        scaled_v_high = intrin_utils::mm256_srlv_epi64<path>(
            scaled_v_high, final_right_shift_high);

        scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32);
        __m256i results;
        mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa);
        // Permute results to this ordering of int32 elements
        // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
        results = intrin_utils::PermuteEpi32EvenOdds<path>(results);

        // Now perform the Rounding Right Shift.
        // First, construct the "nudge" value for each lane if the exponent is
        // greater than 0. Otherwise, the nudge is 0.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
        const __m256i one_shift_exp_minus1 =
            intrin_utils::mm256_sllv_epi32<path>(
                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                          right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge =
            intrin_utils::mm256_add_epi32<path>(results, nudge);
        const __m256i shifted_sum =
            intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp =
            intrin_utils::mm256_sllv_epi32<path>(
                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
                                          _mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow =
            intrin_utils::mm256_cmpgt_epi32<path>(
                results, intrin_utils::mm256_sub_epi32<path>(
                             _mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        accum_data_v0 =
            intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v);
      result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
}  // NOLINT(readability/fn_size)

void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvxSingleColImpl<Path::kAvx>(params);
}

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx float");
  KernelFloatAvxCommon<Path::kAvx>(params);
}

void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx>(params);
}

#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)

}  // namespace ruy