/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <cstring>

#include "ruy/check_macros.h"
#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"

#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#include <immintrin.h>  // IWYU pragma: keep
#endif

namespace ruy {

#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))

void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

#else  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;
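// Rough picture implied by these constants: the kernel produces an 8x8
// destination block, and each pass of the depth loop below consumes 4 levels
// of depth at once, i.e. one 32-byte packed LHS chunk (8 rows x 4 depth) and
// the matching 8 columns x 4 depth of packed RHS.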

namespace {
namespace intrin_utils {

template <>
inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a,
                                                  const __m256i& b) {
  return _mm256_shuffle_epi8(a, b);
}

// Make an inline function for FMA so we can share the float kernels
// with non-FMA code.
template <>
inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b,
                                     const __m256& c) {
  return _mm256_fmadd_ps(a, b, c);
}

template <>
inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
                                                       const int imm) {
  switch (imm) {
    case 0:
      return _mm256_extracti128_si256(a, 0);
    case 1:
      return _mm256_extracti128_si256(a, 1);
    default:
      RUY_DCHECK_LT(imm, 2);
      return _mm_setzero_si128();
  }
}

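// AVX2 has no 32-bit integer blendv, so the helper below emulates one with the
// float blendv on bit-cast vectors; selection is keyed on the sign bit of each
// 32-bit lane of the mask.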
__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                           const __m256i& mask) {
  __m256 result =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  return _mm256_castps_si256(result);
}

}  // namespace intrin_utils
}  // namespace

template <Path path>
void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col;
       col += kAvx8bitBlockSize) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

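    // The zero-point corrections assembled below follow from expanding the
    // quantized product:
    //   sum_d (lhs - lhs_zp) * (rhs - rhs_zp)
    //     = sum_d lhs * rhs - lhs_zp * sum_d rhs - rhs_zp * sum_d lhs
    //       + depth * lhs_zp * rhs_zp.
    // params.prod_zp_depth supplies the last term; the rhs_sums and lhs_sums
    // terms are subtracted from the accumulators further below.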
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[8];
    if (has_rhs_sums_offsets) {
      const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
          _mm256_set1_epi32(lhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.rhs_sums[col])));
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row;
         row += kAvx8bitBlockSize) {
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;
      const int residual_rows =
          std::min(params.dst_rows - row, kAvx8bitBlockSize);
      const int residual_cols =
          std::min(params.dst_cols - col, kAvx8bitBlockSize);

      const __m256i splitter_idx = _mm256_loadu_si256(
          reinterpret_cast<__m256i const*>(splitter_idx_data));

      __m256i accum_data_v0;
      __m256i accum_data_v1;
      __m256i accum_data_v2;
      __m256i accum_data_v3;
      __m256i accum_data_v4;
      __m256i accum_data_v5;
      __m256i accum_data_v6;
      __m256i accum_data_v7;

      // initial_accum_data will be the initial value of each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm256_add_epi32(
            initial_accum_data,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(params.bias + row)));
      }

      // Adjustments common across columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m256i lhs_sums_offset = _mm256_mullo_epi32(
            _mm256_set1_epi32(rhs_zero_point),
            _mm256_loadu_si256(
                reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
        initial_accum_data =
            _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm256_sub_epi32(
            initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m256i bias_data = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(params.bias + col));
        accum_data_v0 = _mm256_add_epi32(
            accum_data_v0,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0)));
        accum_data_v1 = _mm256_add_epi32(
            accum_data_v1,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1)));
        accum_data_v2 = _mm256_add_epi32(
            accum_data_v2,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2)));
        accum_data_v3 = _mm256_add_epi32(
            accum_data_v3,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3)));
        accum_data_v4 = _mm256_add_epi32(
            accum_data_v4,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4)));
        accum_data_v5 = _mm256_add_epi32(
            accum_data_v5,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5)));
        accum_data_v6 = _mm256_add_epi32(
            accum_data_v6,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6)));
        accum_data_v7 = _mm256_add_epi32(
            accum_data_v7,
            _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7)));
      }

      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const void* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
        const __m256i lhs_data =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
        const __m256i rhs_data_8bit =
            _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data_buf[16];
        const std::int32_t* rhs_data =
            reinterpret_cast<const std::int32_t*>(rhs_ptr);

        if (params.rhs_scalar_size == 1) {
          rhs_data = rhs_data_buf;
          const __m128i rhs_data_bottom_lane =
              _mm256_castsi256_si128(rhs_data_8bit);
          const __m128i rhs_data_top_lane =
              _mm256_extracti128_si256(rhs_data_8bit, 1);
          const __m256i rhs_16_bit_dup_low =
              _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
          const __m256i rhs_16_bit_dup_high =
              _mm256_cvtepi8_epi16(rhs_data_top_lane);
          // Now that we have cast the RHS data, we store it so that each value
          // can be separately loaded in the accumulation loop.
          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf),
                              rhs_16_bit_dup_low);
          _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8),
                              rhs_16_bit_dup_high);
        } else {
          RUY_DCHECK(params.rhs_scalar_size == 2);
        }

        const __m256i lhs_data_split =
            _mm256_shuffle_epi8(lhs_data, splitter_idx);
        const __m256i lhs_data_split_expand_bottom =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
        const __m256i lhs_data_split_expand_top =
            _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

        // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
        const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
        // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
        const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
            lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);

        __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
            rhs_data));  // Load [0 1 2 3 4 5 6 7]
        __m256i rhs1 = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(rhs_data + 8));  // Load [8 - 15]
        __m256i rhs0_3 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0);  // [0 1 2 3 0 1 2 3]
        __m256i rhs4_7 =
            _mm256_permute2f128_si256(rhs0, rhs0, 0x11);  // [4 5 6 7 4 5 6 7]
        __m256i rhs8_11 =
            _mm256_permute2f128_si256(rhs1, rhs1, 0);  // [8 9 10 11 8 9 10 11]
        __m256i rhs12_15 =
            _mm256_permute2f128_si256(rhs1, rhs1, 17);  // [12 - 15, 12 - 15]

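        // Roughly, for column C, process_column below updates each of the 8
        // lanes of that column's accumulator as
        //   accum[r] += lhs[r][0] * rhs[0][C] + lhs[r][1] * rhs[1][C]
        //             + lhs[r][2] * rhs[2][C] + lhs[r][3] * rhs[3][C];
        // The (0, 1) depth pairs live in lhs_16_bit_low / rhs_dup_lo and the
        // (2, 3) pairs in lhs_16_bit_high / rhs_dup_hi; each pair is
        // multiplied and summed by one _mm256_madd_epi16.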
        auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi,
                                  __m256i& accum) {
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo));
          accum = _mm256_add_epi32(
              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi));
        };
        __m256i tmp0, tmp1, tmp2, tmp3;
        tmp0 = _mm256_shuffle_epi32(rhs0_3, 0);
        tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55);
        process_column(tmp0, tmp1, accum_data_v0);
        tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff);
        process_column(tmp2, tmp3, accum_data_v1);

        tmp0 = _mm256_shuffle_epi32(rhs4_7, 0);
        tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55);
        process_column(tmp0, tmp1, accum_data_v2);
        tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff);
        process_column(tmp2, tmp3, accum_data_v3);

        tmp0 = _mm256_shuffle_epi32(rhs8_11, 0);
        tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55);
        process_column(tmp0, tmp1, accum_data_v4);
        tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff);
        process_column(tmp2, tmp3, accum_data_v5);

        tmp0 = _mm256_shuffle_epi32(rhs12_15, 0);
        tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55);
        process_column(tmp0, tmp1, accum_data_v6);
        tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa);
        tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff);
        process_column(tmp2, tmp3, accum_data_v7);

        lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
        rhs_ptr = static_cast<const void*>(
            static_cast<const char*>(rhs_ptr) +
            kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size);
      }

      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        __m256i m_vector;
        __m256i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            params.multiplier_exponent + multiplier_channel));

        const __m256i m_64bit_low =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
        const __m256i m_64bit_high =
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

        const __m256i zero_vector = _mm256_setzero_si256();
        const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
        const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
        const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
        const __m256i final_right_shift = _mm256_set1_epi32(31);
        const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 0));
        const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
            _mm256_extracti128_si256(final_right_shift, 1));
        const __m256i convert_to_unsigned_64 =
            _mm256_set1_epi64x(0x8000000000000000);

        __m256i post_scaling_offset = _mm256_setzero_si256();
        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m256i offset_vector = _mm256_add_epi64(
            _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
            convert_to_unsigned_64);

        if (params.dst_zero_point) {
          post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
        }

        const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

        // We cannot do
        //
        // scaled_v_low =
        //     _mm256_srav_epi64(scaled_v_low, final_right_shift_low);
        // scaled_v_high =
        //     _mm256_srav_epi64(scaled_v_high, final_right_shift_high);
        //
        // since this instruction is not in AVX2. Instead we use
        // _mm256_srlv_epi64, but this is an unsigned shift, so we add an
        // offset beforehand (convert_to_unsigned_64); after the shift by 31
        // its contribution sits above bit 31 and is discarded when the low
        // 32 bits are repacked.
        //
        // The overall process, for each 64-bit scaled accumulator, is:
        // unsigned_accum = signed_accum + (1 << 63);
        // unsigned_accum = unsigned_accum >> 31;
        // signed_accum = low 32 bits of unsigned_accum;

        // There are various ways to repack the results, in the absence of
        // _mm256_cvtepi64_epi32() or anything like it.
        // A.
        // accum_data_v[j] =
        //     _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6),
        //                      _mm256_extract_epi32(scaled_v_high, 4),
        //                      _mm256_extract_epi32(scaled_v_high, 2),
        //                      _mm256_extract_epi32(scaled_v_high, 0),
        //                      _mm256_extract_epi32(scaled_v_low, 6),
        //                      _mm256_extract_epi32(scaled_v_low, 4),
        //                      _mm256_extract_epi32(scaled_v_low, 2),
        //                      _mm256_extract_epi32(scaled_v_low, 0));
        // B.
        // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8);
        // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8);
        // accum_data_v[j] =
        //     _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2),
        //                       _mm256_extract_epi64(scaled_v_high, 0),
        //                       _mm256_extract_epi64(scaled_v_low, 2),
        //                       _mm256_extract_epi64(scaled_v_low, 0));
        // C.
        // scaled_v_low =
        //     _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm);
        // scaled_v_high =
        //     _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm);
        // accum_data_v[j] =
        //     _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20);
        //
        // However, we choose the following because it uses two lighter
        // instructions. The permutation does have a longer latency, but this
        // loop can be unrolled.
        // D.
        // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        // __m256i results =
        //     _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        // results = _mm256_permutevar8x32_epi32(results, repack_perm);
        // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);

        // This multiplier code is complex and expensive enough on x86 that
        // we prefer to implement the channels-are-columns case by transposing
        // around it, rather than duplicating it (which would also require
        // duplicating the above code computing the multiplier constants).
        // This is one instance where channels-are-columns has lower
        // performance than channels-are-rows.
        const bool transpose_around_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
        if (transpose_around_multiplier) {
          // Transpose the 8x8 block of accumulators. It will be un-transposed
          // below, after the multiplier has been applied.
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }

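        // Per 32-bit lane, rounding_right_shift below is equivalent to the
        // scalar logic (with exp >= 0):
        //   nudge  = exp > 0 ? (1 << (exp - 1)) : 0;
        //   result = (x + nudge) >> exp;  // arithmetic shift
        //   if (x > INT32_MAX - nudge) result = 1 << (31 - exp);  // saturate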
        auto rounding_right_shift = [=](__m256i& results,
                                        const __m256i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m256i zeros = _mm256_setzero_si256();
          const __m256i mask_rightshift_gtz =
              _mm256_cmpgt_epi32(exponent, zeros);
          const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
          __m256i nudge = intrin_utils::mm256_blendv_epi32(
              zeros, one_shift_exp_minus1, mask_rightshift_gtz);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
          const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
              _mm256_set1_epi32(1),
              _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
          const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
              results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = intrin_utils::mm256_blendv_epi32(
              shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
        };

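        // Per lane, apply_multiplier below is roughly the scalar computation
        //   prod  = std::int64_t(accum << left_shift) * multiplier_fixedpoint;
        //   fixed = std::int32_t((prod + (1 << 30)) >> 31);  // round
        //   accum = RoundingRightShift(fixed, right_shift) + dst_zero_point;
        // with the unsigned-shift bias (convert_to_unsigned_64) cancelling
        // out when the low 32 bits are repacked.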
        auto apply_multiplier = [=](__m256i& accum) {
          __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
          // Apply the fixed-point part of the multiplier.
          __m256i scaled_v_low = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
              m_64bit_low);
          __m256i scaled_v_high = _mm256_mul_epi32(
              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
              m_64bit_high);

          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
          scaled_v_high =
              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
          __m256i results =
              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
          results = _mm256_permutevar8x32_epi32(results, repack_perm);
          // Now do a Rounding Right Shift.
          rounding_right_shift(results, right_shift);
          accum = _mm256_add_epi32(results, post_scaling_offset);
        };
        apply_multiplier(accum_data_v0);
        apply_multiplier(accum_data_v1);
        apply_multiplier(accum_data_v2);
        apply_multiplier(accum_data_v3);
        apply_multiplier(accum_data_v4);
        apply_multiplier(accum_data_v5);
        apply_multiplier(accum_data_v6);
        apply_multiplier(accum_data_v7);
        // See above comment: here we transpose again to undo the transposition
        // of the 8x8 block of accumulators used to implement the
        // channels-are-columns case.
        if (transpose_around_multiplier) {
          intrin_utils::mm256_transpose8x8_epi32<path>(
              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
        }
      }
      const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
      const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
      const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
                                    (residual_cols == kAvx8bitBlockSize);

      __m256i accum_data_v[kAvx8bitBlockSize];
      if (!store_full_block) {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
      }

      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[0 * dst_stride], accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[1 * dst_stride], accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0],
                                                         accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride],
                                                         accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi8<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        if (store_full_block) {
          accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v);
          accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v);
          accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v);
          accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v);
          accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v);
          accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v);
          accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v);
          accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v);
          accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v);
          accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v);
          accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v);
          accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v);
          accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v);
          accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v);
          accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v);
          accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0],
                                                          accum_data_v0);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride],
                                                          accum_data_v1);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[2 * dst_stride], accum_data_v2);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[3 * dst_stride], accum_data_v3);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[4 * dst_stride], accum_data_v4);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[5 * dst_stride], accum_data_v5);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[6 * dst_stride], accum_data_v6);
          intrin_utils::mm256_storeu_cvtepi32_epi16<path>(
              &tmp_ptr[7 * dst_stride], accum_data_v7);
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m256i result = accum_data_v[j];
            result = _mm256_min_epi32(result, clamp_max_v);
            result = _mm256_max_epi32(result, clamp_min_v);
            intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(
                tmp_ptr, residual_rows, result);
            tmp_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride],
                                                 accum_data_v1);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride],
                                                 accum_data_v2);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride],
                                                 accum_data_v3);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride],
                                                 accum_data_v4);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride],
                                                 accum_data_v5);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride],
                                                 accum_data_v6);
          intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride],
                                                 accum_data_v7);
        } else {
          std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            intrin_utils::mm256_n_storeu_epi32<path>(
                dst_block_ptr, residual_rows, accum_data_v[j]);
            dst_block_ptr += dst_stride;
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                     kAvx8bitBlockSize);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     kAvx8bitBlockSize * params.dst_stride);
    rhs_col_ptr =
        static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
                                 kAvx8bitBlockSize * params.rhs_stride);
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2Impl<Path::kAvx2Fma>(params);
}

template <Path path>
void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");

  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  const std::int8_t splitter_idx_data[32] = {
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15,  //
      0, 1, 4, 5, 8, 9, 12, 13,    //
      2, 3, 6, 7, 10, 11, 14, 15   //
  };

  int bias_ptr_block_increment =
      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;
  const std::int32_t* bias_col_ptr = params.bias;
  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
    bias_col_ptr += params.start_row;
  }

  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
  void* dst_ptr = dst_col_ptr;
  const std::int32_t* bias_ptr = bias_col_ptr;

  const std::int32_t lhs_zero_point = params.lhs_zero_point;
  const bool has_rhs_sums_offsets =
      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
  std::int32_t rhs_sums_offsets[8];
  if (has_rhs_sums_offsets) {
    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
        _mm256_set1_epi32(lhs_zero_point),
        _mm256_loadu_si256(
            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
                        rhs_sums_offset_v);
  }

  for (int row = params.start_row; row <= params.last_row;
       row += kAvx8bitBlockSize) {
    const int residual_rows =
        std::min(params.dst_rows - row, kAvx8bitBlockSize);

    const __m256i splitter_idx =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));

    __m256i accum_data_v0;

    // Initialize with bias.
    __m256i initial_accum_data =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr));
    bias_ptr += bias_ptr_block_increment;

    // Adjustments common across columns.
    const std::int32_t rhs_zero_point = params.rhs_zero_point;
    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
          _mm256_set1_epi32(rhs_zero_point),
          _mm256_loadu_si256(
              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
      initial_accum_data =
          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
    }
    const std::int32_t prod_zp_depth = params.prod_zp_depth;
    if (prod_zp_depth) {
      initial_accum_data = _mm256_add_epi32(initial_accum_data,
                                            _mm256_set1_epi32(prod_zp_depth));
    }

    // Adjustments differing across columns.
    if (has_rhs_sums_offsets) {
      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
    } else {
      accum_data_v0 = initial_accum_data;
    }

    const std::int8_t* lhs_ptr = lhs_col_ptr;
    const void* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
      const __m256i lhs_data =
          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
      const std::int32_t* rhs_data =
          reinterpret_cast<const std::int32_t*>(rhs_ptr);

      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
      // For simplicity we load 4x the data that we need, process twice the
      // data that we need, and store only the data we need.
      std::int32_t rhs_data_buf[2];
      if (params.rhs_scalar_size == 1) {
        rhs_data = rhs_data_buf;
        const __m128i rhs_data_8bit =
            intrin_utils::mm_loadu_si32<path>(rhs_ptr);
        const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
        // Now that we have cast the RHS data, we store it so that each value
        // can be separately loaded in the accumulation loop.
        _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
                        rhs_16_bit_dup);
      } else {
        RUY_DCHECK(params.rhs_scalar_size == 2);
      }
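      // In either case, rhs_data now points at the 4 16-bit RHS values stored
      // as two int32 words: rhs_data[0] packs depth levels (0, 1) and
      // rhs_data[1] packs (2, 3), matching the pair layout consumed by
      // _mm256_madd_epi16 below.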

      // NOTE: There may be opportunities for permuting the data in the packing
      // code instead of here.
      const __m256i lhs_data_split =
          _mm256_shuffle_epi8(lhs_data, splitter_idx);
      const __m256i lhs_data_split_expand_bottom =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
      const __m256i lhs_data_split_expand_top =
          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));

      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
      // Accumulate for column 0.
      const std::int32_t low_rhs_value = rhs_data[0];
      const std::int32_t high_rhs_value = rhs_data[1];

      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);

      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
      accum_data_v0 = _mm256_add_epi32(
          accum_data_v0,
          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));

      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
      rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
                                         kAvx8bitBlockSize * kAvx8bitInnerSize *
                                             params.rhs_scalar_size);
    }

    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
      __m256i m_vector;
      __m256i e_vector;
      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
      int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
      m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_fixedpoint + channel));
      e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          params.multiplier_exponent + channel));

      const __m256i m_64bit_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
      const __m256i m_64bit_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));

      const __m256i zero_vector = _mm256_setzero_si256();
      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
      const __m256i final_right_shift = _mm256_set1_epi32(31);
      const __m256i final_right_shift_low =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
      const __m256i final_right_shift_high =
          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
      const __m256i convert_to_unsigned_64 =
          _mm256_set1_epi64x(0x8000000000000000);

      __m256i post_scaling_offset = _mm256_setzero_si256();
      // A "half" added for rounding prior to truncation of 64-bit value.
      const __m256i offset_vector = _mm256_add_epi64(
          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
          convert_to_unsigned_64);

      if (params.dst_zero_point) {
        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
      }

      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

      // See GEMM version for details of this process.
      {
        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
        // Apply the fixed-point part of the multiplier.
        __m256i scaled_v_low = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
            m_64bit_low);
        __m256i scaled_v_high = _mm256_mul_epi32(
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
            m_64bit_high);

        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);

        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
        scaled_v_high =
            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);

        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
        results = _mm256_permutevar8x32_epi32(results, repack_perm);

        // Now do a Rounding Right Shift.
        // First, construct the nudge value for each lane.
        const __m256i zeros = _mm256_setzero_si256();
        const __m256i mask_rightshift_gtz =
            _mm256_cmpgt_epi32(right_shift, zeros);
        const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
        __m256i nudge = intrin_utils::mm256_blendv_epi32(
            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
        // Calculate the shifted sum (results + nudge) >> exp.
        const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
        const __m256i shifted_sum =
            _mm256_srav_epi32(r_plus_nudge, right_shift);

        // Identify overflow in each lane and create mask.
        const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
            _mm256_set1_epi32(1),
            _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
        const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
            results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
        // Fill results with either (results + nudge) >> exponent or
        // 1 << (31 - exp) in the case of overflow.
        results = intrin_utils::mm256_blendv_epi32(
            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);

        accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
      }
    }
    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);

    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows,
                                                       result);
      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
      __m256i result = accum_data_v0;
      result = _mm256_min_epi32(result, clamp_max_v);
      result = _mm256_max_epi32(result, clamp_min_v);
      intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows,
                                                        result);
      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
      intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows,
                                               accum_data_v0);
      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
                                   kAvx8bitBlockSize);
    } else {
      RUY_DCHECK(false);
    }

    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
  }  // End row-block loop.

  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                   kAvx8bitBlockSize * params.dst_stride);
  rhs_col_ptr =
      static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
                               kAvx8bitBlockSize * params.rhs_stride);
}  // NOLINT(readability/fn_size)

void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
  Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float");
  KernelFloatAvxCommon<Path::kAvx2Fma>(params);
}

void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
  profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
  KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params);
}

#endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)

}  // namespace ruy