1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <cstdint>
18
19#include "ruy/check_macros.h"
20#include "ruy/kernel_x86.h"
21#include "ruy/opt_set.h"
22#include "ruy/platform.h"
23#include "ruy/profiler/instrumentation.h"
24
25#if RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
26#include <immintrin.h> // IWYU pragma: keep
27#endif
28
29namespace ruy {
30
31#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
32
// Fallback stub compiled when AVX-512 codegen is unavailable for this build.
// Runtime CPU-ID dispatch must never select this path; reaching it is a bug.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
37
// Fallback stub for the single-column (GEMV) 8-bit kernel; see above.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
42
// Fallback stub for the float kernel; see above.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
47
// Fallback stub for the single-column (GEMV) float kernel; see above.
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
52
53#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
54
// AVX-512 8-bit matrix-multiply kernel over 16x16 destination tiles.
//
// Walks the destination in column-blocks then row-blocks of 16. For each
// tile it accumulates int32 dot products of packed int8 LHS against packed
// int8 (or pre-widened 16-bit) RHS, then either stores raw int32 results or
// applies the fixed-point output pipeline (per-channel or uniform multiplier,
// rounding right shift, zero point, clamp) and narrows to the destination
// type. Partial edge tiles are handled with masked stores.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit");

  // Destination column stride measured in elements of the destination scalar
  // type (params.dst_stride is in bytes).
  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  for (int col = params.start_col; col <= params.last_col; col += 16) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    // Precompute per-column zero-point corrections:
    // rhs_sums_offsets[c] = lhs_zero_point * rhs_sums[col + c].
    // Only needed when the LHS has a nonzero zero point.
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[16];
    if (has_rhs_sums_offsets) {
      const __m512i rhs_sums_offset_v =
          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                             _mm512_loadu_si512(&params.rhs_sums[col]));
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    for (int row = params.start_row; row <= params.last_row; row += 16) {
      // "Channel" selects which dimension per-channel quantization applies
      // along; with no per-channel multiplier, channel 0's values are used
      // uniformly.
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;

      // Number of valid rows/cols in this (possibly edge) tile.
      const int residual_rows = std::min(params.dst_rows - row, 16);
      const int residual_cols = std::min(params.dst_cols - col, 16);

      // One 512-bit accumulator (16 int32 rows) per destination column.
      __m512i accum_data_v0;
      __m512i accum_data_v1;
      __m512i accum_data_v2;
      __m512i accum_data_v3;
      __m512i accum_data_v4;
      __m512i accum_data_v5;
      __m512i accum_data_v6;
      __m512i accum_data_v7;
      __m512i accum_data_v8;
      __m512i accum_data_v9;
      __m512i accum_data_va;
      __m512i accum_data_vb;
      __m512i accum_data_vc;
      __m512i accum_data_vd;
      __m512i accum_data_ve;
      __m512i accum_data_vf;

      // Mask with the low residual_rows bits set, for masked edge stores.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // initial_accum_data will be the initialize of each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm512_add_epi32(
            initial_accum_data,
            _mm512_loadu_si512(
                reinterpret_cast<const __m512i*>(params.bias + row)));
      }

      // Subtract rhs_zero_point * lhs_sums[row..row+15] — the standard
      // quantized-GEMM zero-point correction along rows.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m512i lhs_sums_offset =
            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                               _mm512_loadu_si512(&params.lhs_sums[row]));
        initial_accum_data =
            _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
        accum_data_v8 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
        accum_data_v9 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
        accum_data_va = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
        accum_data_vb = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
        accum_data_vc = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
        accum_data_vd = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
        accum_data_ve = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
        accum_data_vf = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
        accum_data_v8 = initial_accum_data;
        accum_data_v9 = initial_accum_data;
        accum_data_va = initial_accum_data;
        accum_data_vb = initial_accum_data;
        accum_data_vc = initial_accum_data;
        accum_data_vd = initial_accum_data;
        accum_data_ve = initial_accum_data;
        accum_data_vf = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      // Each column's accumulator gets its own bias lane broadcast across
      // all 16 rows via permutexvar.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m512i bias_data = _mm512_loadu_si512(
            reinterpret_cast<const __m512i*>(params.bias + col));
        accum_data_v0 = _mm512_add_epi32(
            accum_data_v0,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
        accum_data_v1 = _mm512_add_epi32(
            accum_data_v1,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
        accum_data_v2 = _mm512_add_epi32(
            accum_data_v2,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
        accum_data_v3 = _mm512_add_epi32(
            accum_data_v3,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
        accum_data_v4 = _mm512_add_epi32(
            accum_data_v4,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
        accum_data_v5 = _mm512_add_epi32(
            accum_data_v5,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
        accum_data_v6 = _mm512_add_epi32(
            accum_data_v6,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
        accum_data_v7 = _mm512_add_epi32(
            accum_data_v7,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
        accum_data_v8 = _mm512_add_epi32(
            accum_data_v8,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
        accum_data_v9 = _mm512_add_epi32(
            accum_data_v9,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
        accum_data_va = _mm512_add_epi32(
            accum_data_va,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
        accum_data_vb = _mm512_add_epi32(
            accum_data_vb,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
        accum_data_vc = _mm512_add_epi32(
            accum_data_vc,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
        accum_data_vd = _mm512_add_epi32(
            accum_data_vd,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
        accum_data_ve = _mm512_add_epi32(
            accum_data_ve,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
        accum_data_vf = _mm512_add_epi32(
            accum_data_vf,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
      }

      // Accumulation loop over the depth dimension, 4 levels per iteration.
      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const void* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += 4) {
        const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
        __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data_buf[32];
        const std::int32_t* rhs_data =
            reinterpret_cast<const std::int32_t*>(rhs_ptr);
        if (params.rhs_scalar_size == 1) {
          // 8-bit RHS: widen to 16-bit into rhs_data_buf, then read pairs of
          // 16-bit values from there. 16-bit RHS is already in that layout.
          rhs_data = rhs_data_buf;
          const __m256i rhs_data_bottom_lane =
              _mm512_castsi512_si256(rhs_data_8bit);
          const __m256i rhs_data_top_lane =
              _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
          const __m512i rhs_16_bit_dup_low =
              _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
          const __m512i rhs_16_bit_dup_high =
              _mm512_cvtepi8_epi16(rhs_data_top_lane);
          // Now that we have cast the RHS data, we store it so that each value
          // can be separately loaded in the accumulation loop.
          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf),
                              rhs_16_bit_dup_low);
          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16),
                              rhs_16_bit_dup_high);
        } else {
          RUY_DCHECK(params.rhs_scalar_size == 2);
        }

        // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
        const __m512i lhs_16_bit_low =
            _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
        // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
        const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
            _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

        // Multiply-accumulate one RHS column: each madd_epi16 multiplies
        // pairs of 16-bit values and sums adjacent products into int32,
        // covering two depth levels; low + high together cover all four.
        auto process_column = [=](int col, __m512i& accum) {
          const __m512i rhs_16_bit_dup_low =
              _mm512_set1_epi32(rhs_data[2 * col]);
          const __m512i rhs_16_bit_dup_high =
              _mm512_set1_epi32(rhs_data[2 * col + 1]);

          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
        };
        process_column(0, accum_data_v0);
        process_column(1, accum_data_v1);
        process_column(2, accum_data_v2);
        process_column(3, accum_data_v3);
        process_column(4, accum_data_v4);
        process_column(5, accum_data_v5);
        process_column(6, accum_data_v6);
        process_column(7, accum_data_v7);
        process_column(8, accum_data_v8);
        process_column(9, accum_data_v9);
        process_column(10, accum_data_va);
        process_column(11, accum_data_vb);
        process_column(12, accum_data_vc);
        process_column(13, accum_data_vd);
        process_column(14, accum_data_ve);
        process_column(15, accum_data_vf);

        lhs_ptr += 16 * 4;
        rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
                                           16 * 4 * params.rhs_scalar_size);
      }

      // For non-int32 destinations, apply the quantized output pipeline:
      // fixed-point multiplier, rounding right shift, destination zero point.
      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        // The non-per-channel case could equivalently be handled in the per-row
        // or per-column code path. The per-row code path is slightly more
        // efficient so we handle it there.
        const bool per_column_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);

        __m512i m_vector;
        __m512i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_exponent + multiplier_channel));

        // 64-bit widened multipliers for the low/high 8 lanes.
        const __m512i m_64bit_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
        const __m512i m_64bit_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

        // Split the exponent into a left shift (positive part) and a right
        // shift (negative part); the fixed 31-bit shift undoes the Q31
        // fixed-point multiply.
        const __m512i zero_vector = _mm512_setzero_epi32();
        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
        const __m512i final_right_shift = _mm512_set1_epi32(31);
        const __m512i right_shift_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
        const __m512i right_shift_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 0));
        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 1));

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m512i offset_vector =
            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

        // Round-to-nearest arithmetic right shift of 64-bit lanes by a
        // per-lane exponent, saturating lanes that would overflow int32.
        auto rounding_right_shift = [=](__m512i& results,
                                        const __m512i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m512i zeros = _mm512_setzero_si512();
          const auto mask_rightshift_gtz =
              _mm512_cmpgt_epi64_mask(exponent, zeros);
          const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
          __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                                one_shift_exp_minus1);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
          const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
          const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
              results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = _mm512_mask_mov_epi64(
              shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
        };

        if (per_column_multiplier) {
          // Per-column multipliers: broadcast column col's multiplier,
          // shifts and offset to all lanes before applying the pipeline.
          auto apply_multiplier = [=](__m512i& accum, int col) {
            __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
            // Apply the fixed-point part of the multiplier.
            __m512i left_shift_val =
                _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
            __m512i m_64bit_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
            __m512i offset_vector_val =
                _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
            __m512i final_right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals,
                col < 8 ? final_right_shift_low : final_right_shift_high);
            __m512i right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);

            accum = _mm512_sllv_epi32(accum, left_shift_val);
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_val);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_val);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_val);

            rounding_right_shift(scaled_v_low, right_shift_val);
            rounding_right_shift(scaled_v_high, right_shift_val);

            // Narrow the two 8x64-bit halves back into one 16x32-bit vector.
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0, 0);
          apply_multiplier(accum_data_v1, 1);
          apply_multiplier(accum_data_v2, 2);
          apply_multiplier(accum_data_v3, 3);
          apply_multiplier(accum_data_v4, 4);
          apply_multiplier(accum_data_v5, 5);
          apply_multiplier(accum_data_v6, 6);
          apply_multiplier(accum_data_v7, 7);
          apply_multiplier(accum_data_v8, 8);
          apply_multiplier(accum_data_v9, 9);
          apply_multiplier(accum_data_va, 10);
          apply_multiplier(accum_data_vb, 11);
          apply_multiplier(accum_data_vc, 12);
          apply_multiplier(accum_data_vd, 13);
          apply_multiplier(accum_data_ve, 14);
          apply_multiplier(accum_data_vf, 15);
        } else {  // not per-column, so per-row
          auto apply_multiplier = [=](__m512i& accum) {
            accum = _mm512_sllv_epi32(accum, left_shift);
            // Apply the fixed-point part of the multiplier.
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_low);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_high);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

            rounding_right_shift(scaled_v_low, right_shift_low);
            rounding_right_shift(scaled_v_high, right_shift_high);
            // Narrow the two 8x64-bit halves back into one 16x32-bit vector.
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0);
          apply_multiplier(accum_data_v1);
          apply_multiplier(accum_data_v2);
          apply_multiplier(accum_data_v3);
          apply_multiplier(accum_data_v4);
          apply_multiplier(accum_data_v5);
          apply_multiplier(accum_data_v6);
          apply_multiplier(accum_data_v7);
          apply_multiplier(accum_data_v8);
          apply_multiplier(accum_data_v9);
          apply_multiplier(accum_data_va);
          apply_multiplier(accum_data_vb);
          apply_multiplier(accum_data_vc);
          apply_multiplier(accum_data_vd);
          apply_multiplier(accum_data_ve);
          apply_multiplier(accum_data_vf);
        }

        // Re-center onto the destination zero point before clamping.
        if (params.dst_zero_point != 0) {
          __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
          accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
          accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
          accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
          accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
          accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
          accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
          accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
          accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
          accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
          accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
          accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
          accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
          accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
          accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
          accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
          accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
        }
      }

      const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
      const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

      const bool store_full_block =
          (residual_rows == 16) && (residual_cols == 16);

      // Gather the named accumulators into an array so the store loops below
      // can index them by column.
      __m512i accum_data_v[16];

      // In most cases we would make this conditional on (!store_full_block) and
      // unwind the clamp-and-store loop, but the benefit appears small.
      {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
        accum_data_v[8] = accum_data_v8;
        accum_data_v[9] = accum_data_v9;
        accum_data_v[10] = accum_data_va;
        accum_data_v[11] = accum_data_vb;
        accum_data_v[12] = accum_data_vc;
        accum_data_v[13] = accum_data_vd;
        accum_data_v[14] = accum_data_ve;
        accum_data_v[15] = accum_data_vf;
      }

      // Clamp and narrow to the destination type; edge tiles use masked
      // stores bounded by residual_rows/residual_cols.
      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          // NOTE(review): this loop bounds by residual_cols where the int8
          // branch above uses the literal 16. Equivalent here, since
          // store_full_block implies residual_cols == 16, but consider
          // using 16 for consistency with the other branches.
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi16(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
                                     _mm512_cvtepi32_epi16(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        // Raw int32 destination: no clamping or narrowing applied.
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < 16; ++j) {
            _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
          }
        } else {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
                                     accum_data_v[j]);
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += 16 * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     16 * params.dst_stride);
    rhs_col_ptr = static_cast<const void*>(
        static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride);
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)
627
628void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
629 profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");
630
631 RUY_DCHECK_EQ(params.dst_cols, 1);
632 RUY_DCHECK_EQ(params.last_col, 0);
633 RUY_DCHECK_EQ(params.start_col, 0);
634
635 int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
636
637 const void* rhs_col_ptr = params.rhs_base_ptr;
638 void* dst_col_ptr = params.dst_base_ptr;
639 const std::int32_t* bias_col_ptr = params.bias;
640 if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
641 bias_col_ptr += params.start_row;
642 }
643
644 const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
645 void* dst_ptr = dst_col_ptr;
646 const std::int32_t* bias_ptr = bias_col_ptr;
647
648 const std::int32_t lhs_zero_point = params.lhs_zero_point;
649 const bool has_rhs_sums_offsets =
650 (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
651 std::int32_t rhs_sums_offsets[16];
652 if (has_rhs_sums_offsets) {
653 const __m512i rhs_sums_offset_v =
654 _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
655 _mm512_loadu_si512(&params.rhs_sums[0]));
656 _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
657 rhs_sums_offset_v);
658 }
659
660 for (int row = params.start_row; row <= params.last_row; row += 16) {
661 const int residual_rows = std::min(params.dst_rows - row, 16);
662
663 __m512i accum_data_v0;
664
665 // Initialize with bias.
666 const __mmask16 row_mask =
667 (static_cast<std::uint32_t>(1) << residual_rows) - 1;
668 __m512i initial_accum_data =
669 _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr));
670 bias_ptr += bias_ptr_block_increment;
671
672 const std::int32_t rhs_zero_point = params.rhs_zero_point;
673 if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
674 const __m512i lhs_sums_offset =
675 _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
676 _mm512_loadu_si512(&params.lhs_sums[row]));
677 initial_accum_data =
678 _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
679 }
680
681 const std::int32_t prod_zp_depth = params.prod_zp_depth;
682 if (prod_zp_depth != 0) {
683 initial_accum_data = _mm512_add_epi32(initial_accum_data,
684 _mm512_set1_epi32(prod_zp_depth));
685 }
686
687 // Adjustments differing across columns.
688 if (has_rhs_sums_offsets) {
689 accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
690 _mm512_set1_epi32(rhs_sums_offsets[0]));
691 } else {
692 accum_data_v0 = initial_accum_data;
693 }
694
695 const std::int8_t* lhs_ptr = lhs_col_ptr;
696 const void* rhs_ptr = rhs_col_ptr;
697 for (int d = 0; d < params.depth; d += 4) {
698 const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
699 const std::int32_t* rhs_data =
700 reinterpret_cast<const std::int32_t*>(rhs_ptr);
701
702 // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
703 // For simplicity we load 4x the data that we need and process twice the
704 // data that we need and store only the data we need.
705 std::int32_t rhs_data_buf[2];
706 if (params.rhs_scalar_size == 1) {
707 rhs_data = rhs_data_buf;
708 const __m128i rhs_data_8bit =
709 _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
710 const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
711 // Now that we have cast the RHS data, we store it so that each value
712 // can be separately loaded in the accumulation loop.
713 _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
714 rhs_16_bit_dup);
715 } else {
716 RUY_DCHECK(params.rhs_scalar_size == 2);
717 }
718
719 // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
720 const __m512i lhs_16_bit_low =
721 _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
722 // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
723 const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
724 _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
725
726 // Process column 0.
727 __m512i accum_v = accum_data_v0;
728 constexpr int index = 0;
729
730 const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
731 const __m512i rhs_16_bit_dup_high =
732 _mm512_set1_epi32(rhs_data[index + 1]);
733
734 accum_v = _mm512_add_epi32(
735 accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
736 accum_v = _mm512_add_epi32(
737 accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
738 accum_data_v0 = accum_v;
739
740 lhs_ptr += 16 * 4;
741 rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
742 16 * 4 * params.rhs_scalar_size);
743 }
744
745 if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
746 __m512i m_vector;
747 __m512i e_vector;
748 // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
749 int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0;
750 m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
751 params.multiplier_fixedpoint + channel));
752 e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
753 params.multiplier_exponent + channel));
754
755 const __m512i m_64bit_low =
756 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
757 const __m512i m_64bit_high =
758 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
759
760 const __m512i zero_vector = _mm512_setzero_epi32();
761 const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
762 const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
763 const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
764 const __m512i final_right_shift = _mm512_set1_epi32(31);
765 const __m512i right_shift_low =
766 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
767 const __m512i right_shift_high =
768 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
769 const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
770 _mm512_extracti32x8_epi32(final_right_shift, 0));
771 const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
772 _mm512_extracti32x8_epi32(final_right_shift, 1));
773
774 // A "half" added for rounding prior to truncation of 64-bit value.
775 const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
776
777 auto rounding_right_shift = [=](__m512i& results,
778 const __m512i& exponent) {
779 // Construct the "nudge" value for each lane if the exponent is
780 // greater than 0. Otherwise, the nudge is 0.
781 const __m512i zeros = _mm512_setzero_si512();
782 const auto mask_rightshift_gtz =
783 _mm512_cmpgt_epi64_mask(exponent, zeros);
784 const __m512i one_shift_exp_minus1 =
785 _mm512_sllv_epi64(_mm512_set1_epi64(1),
786 _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
787 __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
788 one_shift_exp_minus1);
789 // Calculate the shifted sum (results + nudge) >> exp.
790 const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
791 const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
792
793 // Identify overflow in each lane and create mask.
794 const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
795 _mm512_set1_epi64(1),
796 _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
797 const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
798 results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
799 // Fill results with either (results + nudge) >> exponent or
800 // 1 << (31 - exp) in the case of overflow.
801 results = _mm512_mask_mov_epi64(
802 shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
803 };
804
805 // Shift and round column 0.
806 accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
807 // Apply the fixed-point part of the multiplier.
808 __m512i scaled_v_low = _mm512_mul_epi32(
809 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
810 m_64bit_low);
811 __m512i scaled_v_high = _mm512_mul_epi32(
812 _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
813 m_64bit_high);
814
815 scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
816 scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
817
818 scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
819 scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
820
821 rounding_right_shift(scaled_v_low, right_shift_low);
822 rounding_right_shift(scaled_v_high, right_shift_high);
823
824 accum_data_v0 =
825 _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
826 accum_data_v0 = _mm512_inserti32x8(
827 accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
828
829 if (params.dst_zero_point != 0) {
830 __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
831 accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
832 }
833 }
834
835 const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
836 const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
837
838 if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
839 std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
840 __m512i result = accum_data_v0;
841 result = _mm512_min_epi32(result, clamp_max_v);
842 result = _mm512_max_epi32(result, clamp_min_v);
843 _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
844 dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
845 } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
846 std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
847 __m512i result = accum_data_v0;
848 result = _mm512_min_epi32(result, clamp_max_v);
849 result = _mm512_max_epi32(result, clamp_min_v);
850 _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
851 dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
852 } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
853 std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
854 __m512i result = accum_data_v0;
855 result = _mm512_min_epi32(result, clamp_max_v);
856 result = _mm512_max_epi32(result, clamp_min_v);
857 _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
858 _mm512_cvtepi32_epi16(result));
859 dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
860 } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
861 std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
862 _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
863 dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
864 } else {
865 RUY_DCHECK(false);
866 }
867
868 lhs_col_ptr += 16 * params.lhs_stride;
869 } // End row-block loop.
870} // NOLINT(readability/fn_size)
871
// AVX-512 float GEMM kernel. For each 16x16 destination tile it initializes
// eight 16-row accumulators per half-tile with bias, runs FMA accumulation
// over the depth dimension (a 16-row lhs vector times a broadcast rhs
// scalar per destination column), clamps to [clamp_min, clamp_max], and
// stores. Edge tiles use masked or partial stores.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float");

  // As parameters are defined, we need to scale by sizeof(float).
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  // With no bias (flag unset) the increment is 0, so every bias load below
  // reads from the start of params.bias.
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // Bounds of the region this invocation must produce, clamped to the
  // destination extents.
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  // Pre-subtract the start offsets so that indexing with absolute row/col
  // values below yields the correct addresses.
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
  // Whether the bias ("channel") dimension runs along columns (else rows).
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  int col = params.start_col;
  // Main loop over full 16-column blocks.
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    int row = params.start_row;
    // Full 16-row blocks within this column block.
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Process block in two halves, split by columns.
#pragma unroll(1)
      for (int mmm = 0; mmm < 2; ++mmm) {
        // accum_data_vN holds the 16 rows of destination column
        // (col + 8 * mmm + N).
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          // Per-column bias: broadcast one scalar per destination column.
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          // Per-row bias: one 16-row vector shared by all 8 columns.
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        // All but the last depth step; the last step is peeled below so its
        // pointer increments can be elided.
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;

          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        { // nested extra blocks lead to measurable speed gains
          // Peeled final depth step, fused with clamping and the stores.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            // Clamp and store each of the 8 destination columns of this
            // half-tile. Full 16-row stores: no masking needed here.
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }
    } // End row-block loop.

    // The unrolling within this conditional may be somewhat pointless. It
    // depends on the kinds of models.
    if (row < end_row) {
      // Residual rows (< 16) within a full column block: same computation,
      // but destination stores are masked to residual_rows lanes.
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Mask with the low residual_rows bits set, used by the masked stores.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          // Per-column bias: broadcast one scalar per destination column.
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          // Per-row bias: one 16-row vector shared by all 8 columns.
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        // All but the last depth step; the last step is peeled below.
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;
          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {
          // Peeled final depth step (no pointer increments needed), fused
          // with clamping and masked stores.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            // Clamp and store, writing only the residual_rows lanes.
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      } // Inner half-block loop.
    } // Residual rows, main col-block loop.
  } // End col-block loop.

  if (col < end_col) {
    // Residual columns (< 16): uses an accumulator array and plain loops
    // instead of the manually unrolled code above.
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    __m512 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Mask with the low residual_rows bits set, used when residual_rows
      // is less than 16.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
          }
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = initial_accum_data;
          }
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        // No peeled final iteration here: the edge path favors simplicity.
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        // Columns actually present in this half-tile (at most 8).
        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);

        if (residual_rows == 16) {
          // Full-height stores; only the column count may be partial.
          if (residual_cols == 8) {
            for (int j = 0; j < 8; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          } else {
            for (int j = 0; j < residual_cols; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          }
        } else {
          // Partial height: masked stores limited to residual_rows lanes.
          for (int j = 0; j < residual_cols; ++j) {
            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
          }
        }
      } // Inner half-block loop.
    } // End row-block loop.
  } // Residual cols.
}
1237
1238void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
1239 profiler::ScopeLabel label("Kernel kAvx512 float GEMV");
1240
1241 RUY_DCHECK_EQ(params.dst_cols, 1);
1242 RUY_DCHECK_EQ(params.last_col, 0);
1243 RUY_DCHECK_EQ(params.start_col, 0);
1244
1245 // As parameters are defined, we need to scale by sizeof(float).
1246 const std::int64_t lhs_stride = params.lhs_stride >> 2;
1247
1248 int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
1249 const int end_row = std::min(params.dst_rows, params.last_row + 16);
1250
1251 float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
1252 const float* adj_lhs_col_ptr =
1253 params.lhs_base_ptr - params.start_row * lhs_stride;
1254 const float* bias_col_ptr = params.bias;
1255
1256 const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
1257 const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
1258
1259 __m512 accum_data_v;
1260
1261 const float* rhs_col_ptr = params.rhs_base_ptr;
1262 float* dst_col_ptr = adj_dst_col_ptr;
1263
1264 int row = params.start_row;
1265 for (; row <= end_row - 16; row += 16) {
1266 const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1267 float* dst_ptr = dst_col_ptr + row;
1268 const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1269
1270 // Initialize with bias.
1271 accum_data_v = _mm512_loadu_ps(bias_ptr);
1272
1273 const float* lhs_ptr = lhs_col_ptr;
1274 const float* rhs_ptr = rhs_col_ptr;
1275 for (int d = 0; d < params.depth; ++d) {
1276 const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1277 const float rhs_data = *rhs_ptr;
1278
1279 const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1280 accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1281 lhs_ptr += 16;
1282 rhs_ptr += 16;
1283 }
1284
1285 accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1286 accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1287 _mm512_storeu_ps(dst_ptr, accum_data_v);
1288 } // End row-block loop.
1289
1290 if (row < end_row) {
1291 const int residual_rows = end_row - row;
1292 RUY_CHECK_GE(residual_rows, 1);
1293 RUY_CHECK_LT(residual_rows, 16);
1294
1295 const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
1296 float* dst_ptr = dst_col_ptr + row;
1297 const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
1298
1299 // Initialize with bias.
1300 const __mmask16 row_mask =
1301 (static_cast<std::uint32_t>(1) << residual_rows) - 1;
1302 accum_data_v = _mm512_loadu_ps(bias_ptr);
1303
1304 const float* lhs_ptr = lhs_col_ptr;
1305 const float* rhs_ptr = rhs_col_ptr;
1306 for (int d = 0; d < params.depth; ++d) {
1307 const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
1308 const float rhs_data = *rhs_ptr;
1309
1310 const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data);
1311 accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v);
1312 lhs_ptr += 16;
1313 rhs_ptr += 16;
1314 }
1315
1316 accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v);
1317 accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v);
1318 _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v);
1319 } // End handling of residual rows.
1320}
1321
1322#endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
1323
1324} // namespace ruy
1325