1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cstdint> |
18 | |
19 | #include "ruy/check_macros.h" |
20 | #include "ruy/kernel_x86.h" |
21 | #include "ruy/opt_set.h" |
22 | #include "ruy/platform.h" |
23 | #include "ruy/profiler/instrumentation.h" |
24 | |
25 | #if RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
26 | #include <immintrin.h> // IWYU pragma: keep |
27 | #endif |
28 | |
29 | namespace ruy { |
30 | |
31 | #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) |
32 | |
// Fallback stub compiled when this translation unit is built without
// AVX-512 + inline-asm/intrinsics support. Runtime CPU detection must
// never dispatch to this symbol; reaching it is a dispatch bug.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
37 | |
// Fallback stub for the single-column (GEMV) 8-bit kernel when AVX-512
// code generation is disabled for this build. Must never be reached.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
42 | |
// Fallback stub for the float kernel when AVX-512 code generation is
// disabled for this build. Must never be reached.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
47 | |
// Fallback stub for the single-column (GEMV) float kernel when AVX-512
// code generation is disabled for this build. Must never be reached.
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
52 | |
53 | #else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
54 | |
// AVX-512 kernel for quantized (8-bit LHS) matrix multiplication over
// 16x16 destination blocks.
//
// Walks the destination in 16-column by 16-row blocks. For each block it:
//   1. Builds initial int32 accumulators from prod_zp_depth, optional bias,
//      and the zero-point correction terms derived from lhs_sums/rhs_sums.
//   2. Runs the depth loop 4 levels at a time, widening packed 8-bit data
//      to 16-bit and accumulating with _mm512_madd_epi16.
//   3. Unless the destination is raw int32, requantizes via the fixed-point
//      multiplier + exponent (per-channel or uniform), adds dst_zero_point,
//      clamps, narrows, and stores in the destination element type.
// Partial edge blocks are handled with AVX-512 write masks.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 8-bit" );

  // Convert the byte stride into a stride in destination *elements*, since
  // the store loops below index typed pointers (tmp_ptr + j * dst_stride).
  std::int32_t dst_stride = 0;
  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
    dst_stride = params.dst_stride;
  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int16_t);
  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
    dst_stride = params.dst_stride / sizeof(std::int32_t);
  } else {
    RUY_DCHECK(false);
  }

  const void* rhs_col_ptr = params.rhs_base_ptr;
  void* dst_col_ptr = params.dst_base_ptr;

  // Outer loop over 16-wide column blocks of the destination.
  for (int col = params.start_col; col <= params.last_col; col += 16) {
    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
    void* dst_ptr = dst_col_ptr;

    // Zero-point correction: when the LHS has a nonzero zero point and RHS
    // column sums are available, each column's accumulators must be reduced
    // by lhs_zero_point * rhs_sums[col + j]. Precompute those 16 scalars
    // once per column block.
    const std::int32_t lhs_zero_point = params.lhs_zero_point;
    const bool has_rhs_sums_offsets =
        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
    std::int32_t rhs_sums_offsets[16];
    if (has_rhs_sums_offsets) {
      const __m512i rhs_sums_offset_v =
          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
                             _mm512_loadu_si512(&params.rhs_sums[col]));
      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
                          rhs_sums_offset_v);
    }

    // Inner loop over 16-tall row blocks of the destination.
    for (int row = params.start_row; row <= params.last_row; row += 16) {
      // "Channel" selects which dimension per-channel quantization params
      // (bias, multipliers) are indexed by: columns or rows.
      int channel =
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row;
      int multiplier_channel =
          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0;

      // How many rows/cols of this 16x16 block actually lie inside the
      // destination (edge blocks may be partial).
      const int residual_rows = std::min(params.dst_rows - row, 16);
      const int residual_cols = std::min(params.dst_cols - col, 16);

      // One 512-bit accumulator (16 int32 rows) per destination column.
      __m512i accum_data_v0;
      __m512i accum_data_v1;
      __m512i accum_data_v2;
      __m512i accum_data_v3;
      __m512i accum_data_v4;
      __m512i accum_data_v5;
      __m512i accum_data_v6;
      __m512i accum_data_v7;
      __m512i accum_data_v8;
      __m512i accum_data_v9;
      __m512i accum_data_va;
      __m512i accum_data_vb;
      __m512i accum_data_vc;
      __m512i accum_data_vd;
      __m512i accum_data_ve;
      __m512i accum_data_vf;

      // Mask with the low residual_rows bits set; used for masked stores
      // on partial row blocks.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // initial_accum_data will be the initialize of each of the
      // accum_data_* accumulator registers. We compute into it terms that are
      // identical across columns.
      __m512i initial_accum_data = _mm512_set1_epi32(params.prod_zp_depth);

      // In the channels-are-rows case, we can load bias here.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        initial_accum_data = _mm512_add_epi32(
            initial_accum_data,
            _mm512_loadu_si512(
                reinterpret_cast<const __m512i*>(params.bias + row)));
      }

      // Subtract rhs_zero_point * lhs_sums[row..row+15]; this term varies
      // by row but is shared across all 16 columns.
      const std::int32_t rhs_zero_point = params.rhs_zero_point;
      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
        const __m512i lhs_sums_offset =
            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
                               _mm512_loadu_si512(&params.lhs_sums[row]));
        initial_accum_data =
            _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
      }

      // Adjustments differing across columns.
      if (has_rhs_sums_offsets) {
        accum_data_v0 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
        accum_data_v1 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
        accum_data_v2 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
        accum_data_v3 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
        accum_data_v4 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
        accum_data_v5 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
        accum_data_v6 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
        accum_data_v7 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
        accum_data_v8 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
        accum_data_v9 = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
        accum_data_va = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
        accum_data_vb = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
        accum_data_vc = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
        accum_data_vd = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
        accum_data_ve = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
        accum_data_vf = _mm512_sub_epi32(
            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
      } else {
        accum_data_v0 = initial_accum_data;
        accum_data_v1 = initial_accum_data;
        accum_data_v2 = initial_accum_data;
        accum_data_v3 = initial_accum_data;
        accum_data_v4 = initial_accum_data;
        accum_data_v5 = initial_accum_data;
        accum_data_v6 = initial_accum_data;
        accum_data_v7 = initial_accum_data;
        accum_data_v8 = initial_accum_data;
        accum_data_v9 = initial_accum_data;
        accum_data_va = initial_accum_data;
        accum_data_vb = initial_accum_data;
        accum_data_vc = initial_accum_data;
        accum_data_vd = initial_accum_data;
        accum_data_ve = initial_accum_data;
        accum_data_vf = initial_accum_data;
      }

      // Finally, in the channels-are-columns case, load bias data here.
      // Each column's accumulator gets bias[col + j] broadcast to all 16
      // row lanes via a permute of the single bias vector.
      if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) &&
          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) {
        const __m512i bias_data = _mm512_loadu_si512(
            reinterpret_cast<const __m512i*>(params.bias + col));
        accum_data_v0 = _mm512_add_epi32(
            accum_data_v0,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(0), bias_data));
        accum_data_v1 = _mm512_add_epi32(
            accum_data_v1,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(1), bias_data));
        accum_data_v2 = _mm512_add_epi32(
            accum_data_v2,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(2), bias_data));
        accum_data_v3 = _mm512_add_epi32(
            accum_data_v3,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(3), bias_data));
        accum_data_v4 = _mm512_add_epi32(
            accum_data_v4,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(4), bias_data));
        accum_data_v5 = _mm512_add_epi32(
            accum_data_v5,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(5), bias_data));
        accum_data_v6 = _mm512_add_epi32(
            accum_data_v6,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(6), bias_data));
        accum_data_v7 = _mm512_add_epi32(
            accum_data_v7,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(7), bias_data));
        accum_data_v8 = _mm512_add_epi32(
            accum_data_v8,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(8), bias_data));
        accum_data_v9 = _mm512_add_epi32(
            accum_data_v9,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(9), bias_data));
        accum_data_va = _mm512_add_epi32(
            accum_data_va,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(10), bias_data));
        accum_data_vb = _mm512_add_epi32(
            accum_data_vb,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(11), bias_data));
        accum_data_vc = _mm512_add_epi32(
            accum_data_vc,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(12), bias_data));
        accum_data_vd = _mm512_add_epi32(
            accum_data_vd,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(13), bias_data));
        accum_data_ve = _mm512_add_epi32(
            accum_data_ve,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(14), bias_data));
        accum_data_vf = _mm512_add_epi32(
            accum_data_vf,
            _mm512_permutexvar_epi32(_mm512_set1_epi32(15), bias_data));
      }

      // Depth loop: consume 4 depth levels per iteration (the packed LHS
      // holds 16 rows x 4 depth = 64 int8 per 512-bit load).
      const std::int8_t* lhs_ptr = lhs_col_ptr;
      const void* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; d += 4) {
        const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
        __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);

        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
        std::int32_t rhs_data_buf[32];
        const std::int32_t* rhs_data =
            reinterpret_cast<const std::int32_t*>(rhs_ptr);
        if (params.rhs_scalar_size == 1) {
          // 8-bit RHS: sign-extend the 64 packed int8 values to int16 in
          // two halves, then spill to rhs_data_buf so scalar pairs can be
          // broadcast below. (16-bit RHS is already in this layout.)
          rhs_data = rhs_data_buf;
          const __m256i rhs_data_bottom_lane =
              _mm512_castsi512_si256(rhs_data_8bit);
          const __m256i rhs_data_top_lane =
              _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
          const __m512i rhs_16_bit_dup_low =
              _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
          const __m512i rhs_16_bit_dup_high =
              _mm512_cvtepi8_epi16(rhs_data_top_lane);
          // Now that we have cast the RHS data, we store it so that each value
          // can be separately loaded in the accumulation loop.
          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf),
                              rhs_16_bit_dup_low);
          _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16),
                              rhs_16_bit_dup_high);
        } else {
          RUY_DCHECK(params.rhs_scalar_size == 2);
        }

        // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
        const __m512i lhs_16_bit_low =
            _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
        // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
        const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
            _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));

        // For one destination column: broadcast the column's two 16-bit
        // RHS pairs and multiply-accumulate against both LHS halves.
        // _mm512_madd_epi16 sums adjacent 16-bit products into int32.
        auto process_column = [=](int col, __m512i& accum) {
          const __m512i rhs_16_bit_dup_low =
              _mm512_set1_epi32(rhs_data[2 * col]);
          const __m512i rhs_16_bit_dup_high =
              _mm512_set1_epi32(rhs_data[2 * col + 1]);

          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
          accum = _mm512_add_epi32(
              accum, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
        };
        process_column(0, accum_data_v0);
        process_column(1, accum_data_v1);
        process_column(2, accum_data_v2);
        process_column(3, accum_data_v3);
        process_column(4, accum_data_v4);
        process_column(5, accum_data_v5);
        process_column(6, accum_data_v6);
        process_column(7, accum_data_v7);
        process_column(8, accum_data_v8);
        process_column(9, accum_data_v9);
        process_column(10, accum_data_va);
        process_column(11, accum_data_vb);
        process_column(12, accum_data_vc);
        process_column(13, accum_data_vd);
        process_column(14, accum_data_ve);
        process_column(15, accum_data_vf);

        // Advance by 16 (block dim) x 4 (depth levels) scalars.
        lhs_ptr += 16 * 4;
        rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
                                           16 * 4 * params.rhs_scalar_size);
      }

      // Requantization: scale int32 accumulators down to the destination's
      // quantized domain, unless the caller wants raw int32 output.
      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
        // The non-per-channel case could equivalently be handled in the per-row
        // or per-column code path. The per-row code path is slightly more
        // efficient so we handle it there.
        const bool per_column_multiplier =
            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);

        __m512i m_vector;
        __m512i e_vector;
        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
        m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_fixedpoint + multiplier_channel));
        e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            params.multiplier_exponent + multiplier_channel));

        // Widen the fixed-point multipliers to 64-bit so the products of
        // 32-bit accumulators and 31-bit multipliers fit without overflow.
        const __m512i m_64bit_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
        const __m512i m_64bit_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));

        // Split the signed exponent into a non-negative left shift
        // (applied before multiplying) and a non-negative right shift
        // (applied after), per channel.
        const __m512i zero_vector = _mm512_setzero_epi32();
        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
        const __m512i final_right_shift = _mm512_set1_epi32(31);
        const __m512i right_shift_low =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
        const __m512i right_shift_high =
            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 0));
        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
            _mm512_extracti32x8_epi32(final_right_shift, 1));

        // A "half" added for rounding prior to truncation of 64-bit value.
        const __m512i offset_vector =
            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);

        // Rounding (round-half-away-from-zero style nudge) arithmetic right
        // shift of 64-bit lanes, with per-lane saturation of the overflow
        // case to 1 << (31 - exponent).
        auto rounding_right_shift = [=](__m512i& results,
                                        const __m512i& exponent) {
          // Construct the "nudge" value for each lane if the exponent is
          // greater than 0. Otherwise, the nudge is 0.
          const __m512i zeros = _mm512_setzero_si512();
          const auto mask_rightshift_gtz =
              _mm512_cmpgt_epi64_mask(exponent, zeros);
          const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
          __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
                                                one_shift_exp_minus1);
          // Calculate the shifted sum (results + nudge) >> exp.
          const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
          const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);

          // Identify overflow in each lane and create mask.
          const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
              _mm512_set1_epi64(1),
              _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
          const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
              results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
          // Fill results with either (results + nudge) >> exponent or
          // 1 << (31 - exp) in the case of overflow.
          results = _mm512_mask_mov_epi64(
              shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
        };

        if (per_column_multiplier) {
          // Per-column multipliers: each column's accumulator uses the
          // single multiplier/exponent for that column, broadcast to all
          // lanes. col < 8 selects the low 64-bit half-vectors.
          auto apply_multiplier = [=](__m512i& accum, int col) {
            __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
            // Apply the fixed-point part of the multiplier.
            __m512i left_shift_val =
                _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
            __m512i m_64bit_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
            __m512i offset_vector_val =
                _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
            __m512i final_right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals,
                col < 8 ? final_right_shift_low : final_right_shift_high);
            __m512i right_shift_val = _mm512_permutexvar_epi64(
                perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);

            accum = _mm512_sllv_epi32(accum, left_shift_val);
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_val);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_val);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_val);

            rounding_right_shift(scaled_v_low, right_shift_val);
            rounding_right_shift(scaled_v_high, right_shift_val);

            // Narrow the two 64-bit halves back to one 32-bit vector.
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0, 0);
          apply_multiplier(accum_data_v1, 1);
          apply_multiplier(accum_data_v2, 2);
          apply_multiplier(accum_data_v3, 3);
          apply_multiplier(accum_data_v4, 4);
          apply_multiplier(accum_data_v5, 5);
          apply_multiplier(accum_data_v6, 6);
          apply_multiplier(accum_data_v7, 7);
          apply_multiplier(accum_data_v8, 8);
          apply_multiplier(accum_data_v9, 9);
          apply_multiplier(accum_data_va, 10);
          apply_multiplier(accum_data_vb, 11);
          apply_multiplier(accum_data_vc, 12);
          apply_multiplier(accum_data_vd, 13);
          apply_multiplier(accum_data_ve, 14);
          apply_multiplier(accum_data_vf, 15);
        } else {  // not per-column, so per-row
          // Per-row (or uniform) multipliers: the same 16-lane multiplier
          // vectors apply to every column's accumulator.
          auto apply_multiplier = [=](__m512i& accum) {
            accum = _mm512_sllv_epi32(accum, left_shift);
            // Apply the fixed-point part of the multiplier.
            __m512i scaled_v_low = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
                m_64bit_low);
            __m512i scaled_v_high = _mm512_mul_epi32(
                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
                m_64bit_high);

            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);

            scaled_v_low =
                _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
            scaled_v_high =
                _mm512_srav_epi64(scaled_v_high, final_right_shift_high);

            rounding_right_shift(scaled_v_low, right_shift_low);
            rounding_right_shift(scaled_v_high, right_shift_high);
            // Narrow the two 64-bit halves back to one 32-bit vector.
            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
            accum = _mm512_inserti32x8(accum,
                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
          };
          apply_multiplier(accum_data_v0);
          apply_multiplier(accum_data_v1);
          apply_multiplier(accum_data_v2);
          apply_multiplier(accum_data_v3);
          apply_multiplier(accum_data_v4);
          apply_multiplier(accum_data_v5);
          apply_multiplier(accum_data_v6);
          apply_multiplier(accum_data_v7);
          apply_multiplier(accum_data_v8);
          apply_multiplier(accum_data_v9);
          apply_multiplier(accum_data_va);
          apply_multiplier(accum_data_vb);
          apply_multiplier(accum_data_vc);
          apply_multiplier(accum_data_vd);
          apply_multiplier(accum_data_ve);
          apply_multiplier(accum_data_vf);
        }

        // Re-center into the destination's quantized zero point.
        if (params.dst_zero_point != 0) {
          __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
          accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
          accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point);
          accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point);
          accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point);
          accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point);
          accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point);
          accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point);
          accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point);
          accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point);
          accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point);
          accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point);
          accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point);
          accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point);
          accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point);
          accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point);
          accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point);
        }
      }

      const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
      const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);

      const bool store_full_block =
          (residual_rows == 16) && (residual_cols == 16);

      // Gather the named accumulators into an array so the store loops
      // below can index columns uniformly.
      __m512i accum_data_v[16];

      // In most cases we would make this conditional on (!store_full_block) and
      // unwind the clamp-and-store loop, but the benefit appears small.
      {
        accum_data_v[0] = accum_data_v0;
        accum_data_v[1] = accum_data_v1;
        accum_data_v[2] = accum_data_v2;
        accum_data_v[3] = accum_data_v3;
        accum_data_v[4] = accum_data_v4;
        accum_data_v[5] = accum_data_v5;
        accum_data_v[6] = accum_data_v6;
        accum_data_v[7] = accum_data_v7;
        accum_data_v[8] = accum_data_v8;
        accum_data_v[9] = accum_data_v9;
        accum_data_v[10] = accum_data_va;
        accum_data_v[11] = accum_data_vb;
        accum_data_v[12] = accum_data_vc;
        accum_data_v[13] = accum_data_vd;
        accum_data_v[14] = accum_data_ve;
        accum_data_v[15] = accum_data_vf;
      }

      // Clamp to [clamp_min, clamp_max], narrow to the destination element
      // type, and store each column. Partial blocks use masked stores
      // (row_mask for rows, residual_cols bounding the column loop).
      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          // Note: residual_cols == 16 here (store_full_block), so this
          // bound is equivalent to the fixed 16 used in the int8 path.
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_storeu_si128(
                reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi8(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
                                 _mm512_cvtepi32_epi8(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
        const int block_col_offset = dst_stride;
        if (store_full_block) {
          for (int j = 0; j < 16; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset),
                _mm512_cvtepi32_epi16(result));
          }
        } else {
          for (int j = 0; j < residual_cols; ++j) {
            __m512i result = accum_data_v[j];
            result = _mm512_min_epi32(result, clamp_max_v);
            result = _mm512_max_epi32(result, clamp_min_v);
            _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
                                     _mm512_cvtepi32_epi16(result));
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
        // Raw int32 output: no clamping was requested/applied above.
        if (store_full_block) {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < 16; ++j) {
            _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]);
          }
        } else {
          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
          for (int j = 0; j < residual_cols; ++j) {
            _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
                                     accum_data_v[j]);
          }
        }
        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
      } else {
        RUY_DCHECK(false);
      }

      lhs_col_ptr += 16 * params.lhs_stride;
    }  // End row-block loop.

    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
                                     16 * params.dst_stride);
    rhs_col_ptr = static_cast<const void*>(
        static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride);
  }  // End col-block loop.
}  // NOLINT(readability/fn_size)
627 | |
628 | void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { |
629 | profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV" ); |
630 | |
631 | RUY_DCHECK_EQ(params.dst_cols, 1); |
632 | RUY_DCHECK_EQ(params.last_col, 0); |
633 | RUY_DCHECK_EQ(params.start_col, 0); |
634 | |
635 | int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; |
636 | |
637 | const void* rhs_col_ptr = params.rhs_base_ptr; |
638 | void* dst_col_ptr = params.dst_base_ptr; |
639 | const std::int32_t* bias_col_ptr = params.bias; |
640 | if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
641 | bias_col_ptr += params.start_row; |
642 | } |
643 | |
644 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
645 | void* dst_ptr = dst_col_ptr; |
646 | const std::int32_t* bias_ptr = bias_col_ptr; |
647 | |
648 | const std::int32_t lhs_zero_point = params.lhs_zero_point; |
649 | const bool has_rhs_sums_offsets = |
650 | (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
651 | std::int32_t rhs_sums_offsets[16]; |
652 | if (has_rhs_sums_offsets) { |
653 | const __m512i rhs_sums_offset_v = |
654 | _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), |
655 | _mm512_loadu_si512(¶ms.rhs_sums[0])); |
656 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), |
657 | rhs_sums_offset_v); |
658 | } |
659 | |
660 | for (int row = params.start_row; row <= params.last_row; row += 16) { |
661 | const int residual_rows = std::min(params.dst_rows - row, 16); |
662 | |
663 | __m512i accum_data_v0; |
664 | |
665 | // Initialize with bias. |
666 | const __mmask16 row_mask = |
667 | (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
668 | __m512i initial_accum_data = |
669 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(bias_ptr)); |
670 | bias_ptr += bias_ptr_block_increment; |
671 | |
672 | const std::int32_t rhs_zero_point = params.rhs_zero_point; |
673 | if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
674 | const __m512i lhs_sums_offset = |
675 | _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), |
676 | _mm512_loadu_si512(¶ms.lhs_sums[row])); |
677 | initial_accum_data = |
678 | _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); |
679 | } |
680 | |
681 | const std::int32_t prod_zp_depth = params.prod_zp_depth; |
682 | if (prod_zp_depth != 0) { |
683 | initial_accum_data = _mm512_add_epi32(initial_accum_data, |
684 | _mm512_set1_epi32(prod_zp_depth)); |
685 | } |
686 | |
687 | // Adjustments differing across columns. |
688 | if (has_rhs_sums_offsets) { |
689 | accum_data_v0 = _mm512_sub_epi32(initial_accum_data, |
690 | _mm512_set1_epi32(rhs_sums_offsets[0])); |
691 | } else { |
692 | accum_data_v0 = initial_accum_data; |
693 | } |
694 | |
695 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
696 | const void* rhs_ptr = rhs_col_ptr; |
697 | for (int d = 0; d < params.depth; d += 4) { |
698 | const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); |
699 | const std::int32_t* rhs_data = |
700 | reinterpret_cast<const std::int32_t*>(rhs_ptr); |
701 | |
702 | // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
703 | // For simplicity we load 4x the data that we need and process twice the |
704 | // data that we need and store only the data we need. |
705 | std::int32_t rhs_data_buf[2]; |
706 | if (params.rhs_scalar_size == 1) { |
707 | rhs_data = rhs_data_buf; |
708 | const __m128i rhs_data_8bit = |
709 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); |
710 | const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); |
711 | // Now that we have cast the RHS data, we store it so that each value |
712 | // can be separately loaded in the accumulation loop. |
713 | _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), |
714 | rhs_16_bit_dup); |
715 | } else { |
716 | RUY_DCHECK(params.rhs_scalar_size == 2); |
717 | } |
718 | |
719 | // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. |
720 | const __m512i lhs_16_bit_low = |
721 | _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); |
722 | // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. |
723 | const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( |
724 | _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); |
725 | |
726 | // Process column 0. |
727 | __m512i accum_v = accum_data_v0; |
728 | constexpr int index = 0; |
729 | |
730 | const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
731 | const __m512i rhs_16_bit_dup_high = |
732 | _mm512_set1_epi32(rhs_data[index + 1]); |
733 | |
734 | accum_v = _mm512_add_epi32( |
735 | accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
736 | accum_v = _mm512_add_epi32( |
737 | accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
738 | accum_data_v0 = accum_v; |
739 | |
740 | lhs_ptr += 16 * 4; |
741 | rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + |
742 | 16 * 4 * params.rhs_scalar_size); |
743 | } |
744 | |
745 | if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
746 | __m512i m_vector; |
747 | __m512i e_vector; |
748 | // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
749 | int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0; |
750 | m_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( |
751 | params.multiplier_fixedpoint + channel)); |
752 | e_vector = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( |
753 | params.multiplier_exponent + channel)); |
754 | |
755 | const __m512i m_64bit_low = |
756 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); |
757 | const __m512i m_64bit_high = |
758 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); |
759 | |
760 | const __m512i zero_vector = _mm512_setzero_epi32(); |
761 | const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); |
762 | const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); |
763 | const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); |
764 | const __m512i final_right_shift = _mm512_set1_epi32(31); |
765 | const __m512i right_shift_low = |
766 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)); |
767 | const __m512i right_shift_high = |
768 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)); |
769 | const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( |
770 | _mm512_extracti32x8_epi32(final_right_shift, 0)); |
771 | const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( |
772 | _mm512_extracti32x8_epi32(final_right_shift, 1)); |
773 | |
774 | // A "half" added for rounding prior to truncation of 64-bit value. |
775 | const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); |
776 | |
777 | auto rounding_right_shift = [=](__m512i& results, |
778 | const __m512i& exponent) { |
779 | // Construct the "nudge" value for each lane if the exponent is |
780 | // greater than 0. Otherwise, the nudge is 0. |
781 | const __m512i zeros = _mm512_setzero_si512(); |
782 | const auto mask_rightshift_gtz = |
783 | _mm512_cmpgt_epi64_mask(exponent, zeros); |
784 | const __m512i one_shift_exp_minus1 = |
785 | _mm512_sllv_epi64(_mm512_set1_epi64(1), |
786 | _mm512_sub_epi64(exponent, _mm512_set1_epi64(1))); |
787 | __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz, |
788 | one_shift_exp_minus1); |
789 | // Calculate the shifted sum (results + nudge) >> exp. |
790 | const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge); |
791 | const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent); |
792 | |
793 | // Identify overflow in each lane and create mask. |
794 | const __m512i one_shift_31minus_exp = _mm512_sllv_epi64( |
795 | _mm512_set1_epi64(1), |
796 | _mm512_sub_epi64(_mm512_set1_epi64(31), exponent)); |
797 | const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask( |
798 | results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge)); |
799 | // Fill results with either (results + nudge) >> exponent or |
800 | // 1 << (31 - exp) in the case of overflow. |
801 | results = _mm512_mask_mov_epi64( |
802 | shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp); |
803 | }; |
804 | |
805 | // Shift and round column 0. |
806 | accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); |
807 | // Apply the fixed-point part of the multiplier. |
808 | __m512i scaled_v_low = _mm512_mul_epi32( |
809 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), |
810 | m_64bit_low); |
811 | __m512i scaled_v_high = _mm512_mul_epi32( |
812 | _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), |
813 | m_64bit_high); |
814 | |
815 | scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector); |
816 | scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector); |
817 | |
818 | scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
819 | scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
820 | |
821 | rounding_right_shift(scaled_v_low, right_shift_low); |
822 | rounding_right_shift(scaled_v_high, right_shift_high); |
823 | |
824 | accum_data_v0 = |
825 | _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
826 | accum_data_v0 = _mm512_inserti32x8( |
827 | accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
828 | |
829 | if (params.dst_zero_point != 0) { |
830 | __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); |
831 | accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); |
832 | } |
833 | } |
834 | |
835 | const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); |
836 | const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); |
837 | |
838 | if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
839 | std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
840 | __m512i result = accum_data_v0; |
841 | result = _mm512_min_epi32(result, clamp_max_v); |
842 | result = _mm512_max_epi32(result, clamp_min_v); |
843 | _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); |
844 | dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16); |
845 | } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
846 | std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
847 | __m512i result = accum_data_v0; |
848 | result = _mm512_min_epi32(result, clamp_max_v); |
849 | result = _mm512_max_epi32(result, clamp_min_v); |
850 | _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); |
851 | dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16); |
852 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
853 | std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
854 | __m512i result = accum_data_v0; |
855 | result = _mm512_min_epi32(result, clamp_max_v); |
856 | result = _mm512_max_epi32(result, clamp_min_v); |
857 | _mm256_mask_storeu_epi16(tmp_ptr, row_mask, |
858 | _mm512_cvtepi32_epi16(result)); |
859 | dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16); |
860 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
861 | std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
862 | _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); |
863 | dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16); |
864 | } else { |
865 | RUY_DCHECK(false); |
866 | } |
867 | |
868 | lhs_col_ptr += 16 * params.lhs_stride; |
869 | } // End row-block loop. |
870 | } // NOLINT(readability/fn_size) |
871 | |
// Float kernel for a destination block of up to 16x16 results:
//   dst = clamp(lhs * rhs + bias, clamp_min, clamp_max)
// The destination is traversed in 16-column by 16-row blocks. Full blocks
// take the fully unrolled path below; leftover rows are handled with masked
// stores, and leftover columns by the generic loop at the end.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
  profiler::ScopeLabel label("Kernel kAvx512 float" );

  // As parameters are defined, we need to scale by sizeof(float).
  // (The strides in params are in bytes; these are in float elements.)
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;

  // Multiplier for bias indexing: with no bias flag it is 0, so every block
  // re-reads the same 16 values from params.bias (presumably a zeros buffer
  // set up by the caller -- TODO confirm).
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // One-past-the-last destination row/col, clamped to the matrix bounds.
  const int end_row = std::min(params.dst_rows, params.last_row + 16);
  const int end_col = std::min(params.dst_cols, params.last_col + 16);

  // "Adjusted" base pointers: subtracting the start offsets up front lets
  // the loops below index directly with absolute row/col values.
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
  // Selects whether the (per-channel) bias is indexed by column or by row.
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  // Main loop over full 16-column blocks.
  int col = params.start_col;
  for (; col <= end_col - 16; col += 16) {
    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    // Full 16-row blocks within this column block.
    int row = params.start_row;
    for (; row <= end_row - 16; row += 16) {
      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Process block in two halves, split by columns: each half computes
      // 8 destination columns (accum_data_v0..v7), 16 rows per register.
#pragma unroll(1)
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias: either one broadcast per-column bias value
        // per accumulator, or a single per-row bias vector shared by all.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        // The packed RHS holds 16 floats per depth step; this half reads
        // entries [8*mmm, 8*mmm+8) of each step.
        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        // Depth loop, with the last iteration peeled off below so the
        // stores can live in the same scope as the final FMAs.
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;

          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        { // nested extra blocks lead to measurable speed gains
          // Final depth iteration: FMA, then clamp and store each of the
          // 8 destination columns of this half-block.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
          }
        }
      }
    }  // End row-block loop.

    // The unrolling within this conditional may be somewhat pointless. It
    // depends on the kinds of models.
    // Residual rows (1..15) of a full column block: identical computation,
    // but destination stores are masked to the valid rows.
    if (row < end_row) {
      const int residual_rows = end_row - row;

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Mask selecting the low residual_rows lanes of each 16-float store.
      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        __m512 accum_data_v0;
        __m512 accum_data_v1;
        __m512 accum_data_v2;
        __m512 accum_data_v3;
        __m512 accum_data_v4;
        __m512 accum_data_v5;
        __m512 accum_data_v6;
        __m512 accum_data_v7;

        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);

          accum_data_v0 = initial_accum_data;
          accum_data_v1 = initial_accum_data;
          accum_data_v2 = initial_accum_data;
          accum_data_v3 = initial_accum_data;
          accum_data_v4 = initial_accum_data;
          accum_data_v5 = initial_accum_data;
          accum_data_v6 = initial_accum_data;
          accum_data_v7 = initial_accum_data;
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < (params.depth - 1); ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          lhs_ptr += 16;
          rhs_ptr += 16;
          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
          // so if given an rvalue.
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
        }
        {
          // Final depth iteration, peeled as in the full-block path; stores
          // use row_mask instead of full-width stores.
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;
          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
                                          accum_data_v0);
          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
                                          accum_data_v1);
          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
                                          accum_data_v2);
          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
                                          accum_data_v3);
          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
                                          accum_data_v4);
          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
                                          accum_data_v5);
          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
                                          accum_data_v6);
          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
                                          accum_data_v7);
          {
            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask,
                                  accum_data_v0);
            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask,
                                  accum_data_v1);
            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask,
                                  accum_data_v2);
            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask,
                                  accum_data_v3);
            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask,
                                  accum_data_v4);
            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask,
                                  accum_data_v5);
            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask,
                                  accum_data_v6);
            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask,
                                  accum_data_v7);
          }
        }
      }  // Inner half-block loop.
    }  // Residual rows, main col-block loop.
  }  // End col-block loop.

  // Residual columns (1..15): generic path with an accumulator array instead
  // of unrolled registers, since the number of valid columns varies.
  if (col < end_col) {
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 16);

    __m512 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 16) {
      const int residual_rows = std::min(end_row - row, 16);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      const __mmask16 row_mask =
          (static_cast<std::uint32_t>(1) << residual_rows) - 1;

      // Process block in two halves, split by columns.
      for (int mmm = 0; mmm < 2; ++mmm) {
        // Initialize with bias.
        if (channel_dimension_is_col) {
          const float* bias_elem_ptr =
              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
          }
        } else {
          const __m512 initial_accum_data =
              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
          for (int j = 0; j < 8; ++j) {
            accum_data_v[j] = initial_accum_data;
          }
        }

        const float* lhs_ptr = lhs_col_ptr;
        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
        for (int d = 0; d < params.depth; ++d) {
          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
          const float* rhs_data = rhs_ptr;

          for (int j = 0; j < 8; ++j) {
            const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]);
            accum_data_v[j] =
                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
          }
          lhs_ptr += 16;
          rhs_ptr += 16;
        }

        // Columns actually valid in this half (at most 8).
        const int residual_cols = std::min(end_col - col - 8 * mmm, 8);

        if (residual_rows == 16) {
          if (residual_cols == 8) {
            for (int j = 0; j < 8; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          } else {
            for (int j = 0; j < residual_cols; ++j) {
              float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
              accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
              accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
              _mm512_storeu_ps(block_ptr, accum_data_v[j]);
            }
          }
        } else {
          // Partial row block: masked store per valid column.
          for (int j = 0; j < residual_cols; ++j) {
            float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride;
            accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v);
            accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v);
            _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]);
          }
        }
      }  // Inner half-block loop.
    }  // End row-block loop.
  }  // Residual cols.
}
1237 | |
1238 | void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { |
1239 | profiler::ScopeLabel label("Kernel kAvx512 float GEMV" ); |
1240 | |
1241 | RUY_DCHECK_EQ(params.dst_cols, 1); |
1242 | RUY_DCHECK_EQ(params.last_col, 0); |
1243 | RUY_DCHECK_EQ(params.start_col, 0); |
1244 | |
1245 | // As parameters are defined, we need to scale by sizeof(float). |
1246 | const std::int64_t lhs_stride = params.lhs_stride >> 2; |
1247 | |
1248 | int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; |
1249 | const int end_row = std::min(params.dst_rows, params.last_row + 16); |
1250 | |
1251 | float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; |
1252 | const float* adj_lhs_col_ptr = |
1253 | params.lhs_base_ptr - params.start_row * lhs_stride; |
1254 | const float* bias_col_ptr = params.bias; |
1255 | |
1256 | const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); |
1257 | const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); |
1258 | |
1259 | __m512 accum_data_v; |
1260 | |
1261 | const float* rhs_col_ptr = params.rhs_base_ptr; |
1262 | float* dst_col_ptr = adj_dst_col_ptr; |
1263 | |
1264 | int row = params.start_row; |
1265 | for (; row <= end_row - 16; row += 16) { |
1266 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
1267 | float* dst_ptr = dst_col_ptr + row; |
1268 | const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
1269 | |
1270 | // Initialize with bias. |
1271 | accum_data_v = _mm512_loadu_ps(bias_ptr); |
1272 | |
1273 | const float* lhs_ptr = lhs_col_ptr; |
1274 | const float* rhs_ptr = rhs_col_ptr; |
1275 | for (int d = 0; d < params.depth; ++d) { |
1276 | const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
1277 | const float rhs_data = *rhs_ptr; |
1278 | |
1279 | const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); |
1280 | accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); |
1281 | lhs_ptr += 16; |
1282 | rhs_ptr += 16; |
1283 | } |
1284 | |
1285 | accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); |
1286 | accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); |
1287 | _mm512_storeu_ps(dst_ptr, accum_data_v); |
1288 | } // End row-block loop. |
1289 | |
1290 | if (row < end_row) { |
1291 | const int residual_rows = end_row - row; |
1292 | RUY_CHECK_GE(residual_rows, 1); |
1293 | RUY_CHECK_LT(residual_rows, 16); |
1294 | |
1295 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
1296 | float* dst_ptr = dst_col_ptr + row; |
1297 | const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
1298 | |
1299 | // Initialize with bias. |
1300 | const __mmask16 row_mask = |
1301 | (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
1302 | accum_data_v = _mm512_loadu_ps(bias_ptr); |
1303 | |
1304 | const float* lhs_ptr = lhs_col_ptr; |
1305 | const float* rhs_ptr = rhs_col_ptr; |
1306 | for (int d = 0; d < params.depth; ++d) { |
1307 | const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
1308 | const float rhs_data = *rhs_ptr; |
1309 | |
1310 | const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); |
1311 | accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); |
1312 | lhs_ptr += 16; |
1313 | rhs_ptr += 16; |
1314 | } |
1315 | |
1316 | accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); |
1317 | accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); |
1318 | _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); |
1319 | } // End handling of residual rows. |
1320 | } |
1321 | |
1322 | #endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
1323 | |
1324 | } // namespace ruy |
1325 | |