1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cstdint> |
18 | #include <cstring> |
19 | |
20 | #include "ruy/check_macros.h" |
21 | #include "ruy/kernel_common.h" |
22 | #include "ruy/kernel_x86.h" |
23 | #include "ruy/opt_set.h" |
24 | #include "ruy/platform.h" |
25 | #include "ruy/profiler/instrumentation.h" |
26 | |
27 | #if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) |
28 | #include <immintrin.h> // IWYU pragma: keep |
29 | #endif |
30 | |
31 | namespace ruy { |
32 | |
33 | #if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) |
34 | |
35 | void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) { |
36 | // CPU-ID-based checks should disable the path that would reach this point. |
37 | RUY_DCHECK(false); |
38 | } |
39 | |
40 | void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>&) { |
41 | // CPU-ID-based checks should disable the path that would reach this point. |
42 | RUY_DCHECK(false); |
43 | } |
44 | |
45 | void KernelFloatAvx2(const KernelParamsFloat<8, 8>&) { |
46 | // CPU-ID-based checks should disable the path that would reach this point. |
47 | RUY_DCHECK(false); |
48 | } |
49 | |
50 | void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) { |
51 | // CPU-ID-based checks should disable the path that would reach this point. |
52 | RUY_DCHECK(false); |
53 | } |
54 | |
55 | #else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) |
56 | |
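// The kernel computes an 8x8 block of destination values per iteration, with
// LHS and RHS data packed in groups of 4 along the depth dimension so that
// each _mm256_madd_epi16 can accumulate two depth levels at a time.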
57 | static constexpr int kAvx8bitBlockSize = 8; |
58 | static constexpr int kAvx8bitInnerSize = 4; |
59 | |
60 | namespace { |
61 | namespace intrin_utils { |
62 | |
63 | template <> |
64 | inline __m256i mm256_shuffle_epi8<Path::kAvx2Fma>(const __m256i& a, |
65 | const __m256i& b) { |
66 | return _mm256_shuffle_epi8(a, b); |
67 | } |
68 | |
69 | // Make an inline function for FMA so we can share the float kernels |
70 | // with non-FMA code. |
71 | template <> |
72 | inline __m256 MulAdd<Path::kAvx2Fma>(const __m256& a, const __m256& b, |
73 | const __m256& c) { |
74 | return _mm256_fmadd_ps(a, b, c); |
75 | } |
76 | |
77 | template <> |
inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
                                                       const int imm) {
80 | switch (imm) { |
81 | case 0: |
82 | return _mm256_extracti128_si256(a, 0); |
83 | case 1: |
84 | return _mm256_extracti128_si256(a, 1); |
85 | default: |
86 | RUY_DCHECK_LT(imm, 2); |
87 | return _mm_setzero_si128(); |
88 | } |
89 | } |
90 | |
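// AVX2 has no integer blendv at 32-bit granularity, so emulate one by
// bit-casting to float: _mm256_blendv_ps selects each lane based on the sign
// (high) bit of the corresponding mask lane.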
91 | __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b, |
92 | const __m256i& mask) { |
93 | __m256 result = |
94 | _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), |
95 | _mm256_castsi256_ps(mask)); |
96 | return _mm256_castps_si256(result); |
97 | } |
98 | |
99 | } // namespace intrin_utils |
100 | } // namespace |
101 | |
102 | template <Path path> |
103 | void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
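  // splitter_idx_data drives the _mm256_shuffle_epi8 in the depth loop: within
  // each 128-bit lane it moves bytes 0,1 of every 4-byte depth group (depth
  // levels 0 and 1) into the low half and bytes 2,3 (depth levels 2 and 3)
  // into the high half, ready to be sign-extended to 16 bits for
  // _mm256_madd_epi16.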
105 | const std::int8_t splitter_idx_data[32] = { |
106 | 0, 1, 4, 5, 8, 9, 12, 13, // |
107 | 2, 3, 6, 7, 10, 11, 14, 15, // |
108 | 0, 1, 4, 5, 8, 9, 12, 13, // |
109 | 2, 3, 6, 7, 10, 11, 14, 15 // |
110 | }; |
111 | |
112 | std::int32_t dst_stride = 0; |
113 | if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || |
114 | (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { |
115 | dst_stride = params.dst_stride; |
116 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
117 | dst_stride = params.dst_stride / sizeof(std::int16_t); |
118 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
119 | dst_stride = params.dst_stride / sizeof(std::int32_t); |
120 | } else { |
121 | RUY_DCHECK(false); |
122 | } |
123 | |
124 | const void* rhs_col_ptr = params.rhs_base_ptr; |
125 | void* dst_col_ptr = params.dst_base_ptr; |
126 | |
127 | for (int col = params.start_col; col <= params.last_col; |
128 | col += kAvx8bitBlockSize) { |
129 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
130 | void* dst_ptr = dst_col_ptr; |
131 | |
132 | const std::int32_t lhs_zero_point = params.lhs_zero_point; |
133 | const bool has_rhs_sums_offsets = |
134 | (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
135 | std::int32_t rhs_sums_offsets[8]; |
136 | if (has_rhs_sums_offsets) { |
137 | const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( |
138 | _mm256_set1_epi32(lhs_zero_point), |
139 | _mm256_loadu_si256( |
140 | reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); |
141 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
142 | rhs_sums_offset_v); |
143 | } |
144 | |
145 | for (int row = params.start_row; row <= params.last_row; |
146 | row += kAvx8bitBlockSize) { |
147 | int channel = |
148 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; |
149 | int multiplier_channel = |
150 | (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; |
151 | const int residual_rows = |
152 | std::min(params.dst_rows - row, kAvx8bitBlockSize); |
153 | const int residual_cols = |
154 | std::min(params.dst_cols - col, kAvx8bitBlockSize); |
155 | |
156 | const __m256i splitter_idx = _mm256_loadu_si256( |
157 | reinterpret_cast<__m256i const*>(splitter_idx_data)); |
158 | |
159 | __m256i accum_data_v0; |
160 | __m256i accum_data_v1; |
161 | __m256i accum_data_v2; |
162 | __m256i accum_data_v3; |
163 | __m256i accum_data_v4; |
164 | __m256i accum_data_v5; |
165 | __m256i accum_data_v6; |
166 | __m256i accum_data_v7; |
167 | |
      // initial_accum_data is the initial value for each of the accum_data_*
      // accumulator registers. We compute into it the terms that are
      // identical across columns.
171 | __m256i initial_accum_data = _mm256_set1_epi32(params.prod_zp_depth); |
172 | |
173 | // In the channels-are-rows case, we can load bias here. |
174 | if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
175 | !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
176 | initial_accum_data = _mm256_add_epi32( |
177 | initial_accum_data, |
178 | _mm256_loadu_si256( |
179 | reinterpret_cast<const __m256i*>(params.bias + row))); |
180 | } |
181 | |
182 | // Adjustments common across columns. |
183 | const std::int32_t rhs_zero_point = params.rhs_zero_point; |
184 | if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
185 | const __m256i lhs_sums_offset = _mm256_mullo_epi32( |
186 | _mm256_set1_epi32(rhs_zero_point), |
187 | _mm256_loadu_si256( |
188 | reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); |
189 | initial_accum_data = |
190 | _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); |
191 | } |
192 | |
193 | // Adjustments differing across columns. |
194 | if (has_rhs_sums_offsets) { |
195 | accum_data_v0 = _mm256_sub_epi32( |
196 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); |
197 | accum_data_v1 = _mm256_sub_epi32( |
198 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); |
199 | accum_data_v2 = _mm256_sub_epi32( |
200 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); |
201 | accum_data_v3 = _mm256_sub_epi32( |
202 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); |
203 | accum_data_v4 = _mm256_sub_epi32( |
204 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); |
205 | accum_data_v5 = _mm256_sub_epi32( |
206 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); |
207 | accum_data_v6 = _mm256_sub_epi32( |
208 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); |
209 | accum_data_v7 = _mm256_sub_epi32( |
210 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); |
211 | } else { |
212 | accum_data_v0 = initial_accum_data; |
213 | accum_data_v1 = initial_accum_data; |
214 | accum_data_v2 = initial_accum_data; |
215 | accum_data_v3 = initial_accum_data; |
216 | accum_data_v4 = initial_accum_data; |
217 | accum_data_v5 = initial_accum_data; |
218 | accum_data_v6 = initial_accum_data; |
219 | accum_data_v7 = initial_accum_data; |
220 | } |
221 | |
222 | // Finally, in the channels-are-columns case, load bias data here. |
223 | if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
224 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
225 | const __m256i bias_data = _mm256_loadu_si256( |
226 | reinterpret_cast<const __m256i*>(params.bias + col)); |
227 | accum_data_v0 = _mm256_add_epi32( |
228 | accum_data_v0, |
229 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(0))); |
230 | accum_data_v1 = _mm256_add_epi32( |
231 | accum_data_v1, |
232 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(1))); |
233 | accum_data_v2 = _mm256_add_epi32( |
234 | accum_data_v2, |
235 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(2))); |
236 | accum_data_v3 = _mm256_add_epi32( |
237 | accum_data_v3, |
238 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(3))); |
239 | accum_data_v4 = _mm256_add_epi32( |
240 | accum_data_v4, |
241 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(4))); |
242 | accum_data_v5 = _mm256_add_epi32( |
243 | accum_data_v5, |
244 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(5))); |
245 | accum_data_v6 = _mm256_add_epi32( |
246 | accum_data_v6, |
247 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(6))); |
248 | accum_data_v7 = _mm256_add_epi32( |
249 | accum_data_v7, |
250 | _mm256_permutevar8x32_epi32(bias_data, _mm256_set1_epi32(7))); |
251 | } |
252 | |
253 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
254 | const void* rhs_ptr = rhs_col_ptr; |
255 | for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
256 | const __m256i lhs_data = |
257 | _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
258 | const __m256i rhs_data_8bit = |
259 | _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); |
260 | |
261 | // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
262 | std::int32_t rhs_data_buf[16]; |
263 | const std::int32_t* rhs_data = |
264 | reinterpret_cast<const std::int32_t*>(rhs_ptr); |
265 | |
266 | if (params.rhs_scalar_size == 1) { |
267 | rhs_data = rhs_data_buf; |
268 | const __m128i rhs_data_bottom_lane = |
269 | _mm256_castsi256_si128(rhs_data_8bit); |
270 | const __m128i rhs_data_top_lane = |
271 | _mm256_extracti128_si256(rhs_data_8bit, 1); |
272 | const __m256i rhs_16_bit_dup_low = |
273 | _mm256_cvtepi8_epi16(rhs_data_bottom_lane); |
274 | const __m256i rhs_16_bit_dup_high = |
275 | _mm256_cvtepi8_epi16(rhs_data_top_lane); |
276 | // Now that we have cast the RHS data, we store it so that each value |
277 | // can be separately loaded in the accumulation loop. |
278 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf), |
279 | rhs_16_bit_dup_low); |
280 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8), |
281 | rhs_16_bit_dup_high); |
282 | } else { |
283 | RUY_DCHECK(params.rhs_scalar_size == 2); |
284 | } |
285 | |
286 | const __m256i lhs_data_split = |
287 | _mm256_shuffle_epi8(lhs_data, splitter_idx); |
288 | const __m256i lhs_data_split_expand_bottom = |
289 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); |
290 | const __m256i lhs_data_split_expand_top = |
291 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); |
292 | |
293 | // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
294 | const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( |
295 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
296 | // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
297 | const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( |
298 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
299 | |
300 | __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>( |
301 | rhs_data)); // Load [0 1 2 3 4 5 6 7] |
302 | __m256i rhs1 = _mm256_lddqu_si256( |
303 | reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15] |
304 | __m256i rhs0_3 = |
305 | _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3] |
306 | __m256i rhs4_7 = |
307 | _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7] |
308 | __m256i rhs8_11 = |
309 | _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11] |
310 | __m256i rhs12_15 = |
          _mm256_permute2f128_si256(rhs1, rhs1, 0x11);  // [12 - 15, 12 - 15]
312 | |
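        // _mm256_madd_epi16 multiplies corresponding 16-bit values and sums
        // adjacent pairs into 32-bit lanes, so each call accumulates two
        // depth levels; process_column applies it to the depth {0,1} and
        // depth {2,3} halves of one RHS column.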
313 | auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi, |
314 | __m256i& accum) { |
315 | accum = _mm256_add_epi32( |
316 | accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_lo)); |
317 | accum = _mm256_add_epi32( |
318 | accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_hi)); |
319 | }; |
320 | __m256i tmp0, tmp1, tmp2, tmp3; |
321 | tmp0 = _mm256_shuffle_epi32(rhs0_3, 0); |
322 | tmp1 = _mm256_shuffle_epi32(rhs0_3, 0x55); |
323 | process_column(tmp0, tmp1, accum_data_v0); |
324 | tmp2 = _mm256_shuffle_epi32(rhs0_3, 0xaa); |
325 | tmp3 = _mm256_shuffle_epi32(rhs0_3, 0xff); |
326 | process_column(tmp2, tmp3, accum_data_v1); |
327 | |
328 | tmp0 = _mm256_shuffle_epi32(rhs4_7, 0); |
329 | tmp1 = _mm256_shuffle_epi32(rhs4_7, 0x55); |
330 | process_column(tmp0, tmp1, accum_data_v2); |
331 | tmp2 = _mm256_shuffle_epi32(rhs4_7, 0xaa); |
332 | tmp3 = _mm256_shuffle_epi32(rhs4_7, 0xff); |
333 | process_column(tmp2, tmp3, accum_data_v3); |
334 | |
335 | tmp0 = _mm256_shuffle_epi32(rhs8_11, 0); |
336 | tmp1 = _mm256_shuffle_epi32(rhs8_11, 0x55); |
337 | process_column(tmp0, tmp1, accum_data_v4); |
338 | tmp2 = _mm256_shuffle_epi32(rhs8_11, 0xaa); |
339 | tmp3 = _mm256_shuffle_epi32(rhs8_11, 0xff); |
340 | process_column(tmp2, tmp3, accum_data_v5); |
341 | |
342 | tmp0 = _mm256_shuffle_epi32(rhs12_15, 0); |
343 | tmp1 = _mm256_shuffle_epi32(rhs12_15, 0x55); |
344 | process_column(tmp0, tmp1, accum_data_v6); |
345 | tmp2 = _mm256_shuffle_epi32(rhs12_15, 0xaa); |
346 | tmp3 = _mm256_shuffle_epi32(rhs12_15, 0xff); |
347 | process_column(tmp2, tmp3, accum_data_v7); |
348 | |
349 | lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
350 | rhs_ptr = static_cast<const void*>( |
351 | static_cast<const char*>(rhs_ptr) + |
352 | kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size); |
353 | } |
354 | |
355 | if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
356 | __m256i m_vector; |
357 | __m256i e_vector; |
358 | // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
359 | m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
360 | params.multiplier_fixedpoint + multiplier_channel)); |
361 | e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
362 | params.multiplier_exponent + multiplier_channel)); |
363 | |
364 | const __m256i m_64bit_low = |
365 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); |
366 | const __m256i m_64bit_high = |
367 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); |
368 | |
369 | const __m256i zero_vector = _mm256_setzero_si256(); |
370 | const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); |
371 | const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); |
372 | const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); |
373 | const __m256i final_right_shift = _mm256_set1_epi32(31); |
374 | const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( |
375 | _mm256_extracti128_si256(final_right_shift, 0)); |
376 | const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( |
377 | _mm256_extracti128_si256(final_right_shift, 1)); |
378 | const __m256i convert_to_unsigned_64 = |
379 | _mm256_set1_epi64x(0x8000000000000000); |
380 | |
381 | __m256i post_scaling_offset = _mm256_setzero_si256(); |
382 | // A "half" added for rounding prior to truncation of 64-bit value. |
383 | const __m256i offset_vector = _mm256_add_epi64( |
384 | _mm256_slli_epi64(_mm256_set1_epi64x(1), 30), |
385 | convert_to_unsigned_64); |
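      // (1 << 30) is half of the 2^31 divisor applied by final_right_shift;
      // adding convert_to_unsigned_64 biases values into unsigned range so
      // that the logical _mm256_srlv_epi64 below acts as an arithmetic shift
      // once the result is truncated back to 32 bits.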
386 | |
387 | if (params.dst_zero_point) { |
388 | post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
389 | } |
390 | |
391 | const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); |
392 | |
393 | // We cannot do |
394 | // |
395 | // scaled_v_low = |
396 | // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); |
397 | // scaled_v_high = |
398 | // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); |
399 | // |
400 | // since this instruction is not in AVX2. Instead we use |
401 | // _mm256_srlv_epi64, but this is an unsigned shift, so we applied |
402 | // offsets before (convert_to_unsigned_64) and after |
403 | // (convert_to_signed_halved). |
404 | // |
405 | // The overall process is, for 64-bit scaled accumulator: |
      // unsigned_accum = signed_accum + (1 << 63);
407 | // unsigned_accum = (unsigned_accum >> right_shift) >> 31; |
408 | // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; |
409 | |
410 | // There are various ways to repack the results, in the absence of |
411 | // _mm256_cvtepi64_epi32() or anything like it. |
412 | // A. |
413 | // accum_data_v[j] = |
414 | // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), |
415 | // _mm256_extract_epi32(scaled_v_high, 4), |
416 | // _mm256_extract_epi32(scaled_v_high, 2), |
417 | // _mm256_extract_epi32(scaled_v_high, 0), |
418 | // _mm256_extract_epi32(scaled_v_low, 6), |
419 | // _mm256_extract_epi32(scaled_v_low, 4), |
420 | // _mm256_extract_epi32(scaled_v_low, 2), |
421 | // _mm256_extract_epi32(scaled_v_low, 0)); |
422 | // B. |
423 | // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); |
424 | // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); |
425 | // accum_data_v[j] = |
426 | // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), |
427 | // _mm256_extract_epi64(scaled_v_high, 0), |
428 | // _mm256_extract_epi64(scaled_v_low, 2), |
429 | // _mm256_extract_epi64(scaled_v_low, 0)); |
430 | // C. |
431 | // scaled_v_low = |
432 | // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); |
433 | // scaled_v_high = |
434 | // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); |
435 | // accum_data_v[j] = |
436 | // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); |
437 | // |
438 | // However, we choose the following because it uses two lighter |
439 | // instructions. The permutation does have a longer latency, but this |
440 | // loop can be unrolled. |
441 | // D. |
442 | // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); |
443 | // __m256i results = |
444 | // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); |
445 | // results = _mm256_permutevar8x32_epi32(results, repack_perm); |
446 | // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset); |
447 | |
      // This multiplier code is complex and expensive enough on x86 that
449 | // we prefer to implement the channels-are-columns case by transposing |
450 | // around it, rather than duplicate it (which would also require |
451 | // duplicating the above code computing the multiplier constants). |
452 | // This is one instance where channels-are-columns has lower performance |
453 | // than channels-are-rows. |
454 | const bool transpose_around_multiplier = |
455 | (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) && |
456 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL); |
457 | if (transpose_around_multiplier) { |
458 | // Transpose the 8x8 accumulators block. Will be un-transposed below |
        // after the multiplier implementation.
460 | intrin_utils::mm256_transpose8x8_epi32<path>( |
461 | &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
462 | &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
463 | } |
464 | |
465 | auto rounding_right_shift = [=](__m256i& results, |
466 | const __m256i& exponent) { |
467 | // Construct the "nudge" value for each lane if the exponent is |
468 | // greater than 0. Otherwise, the nudge is 0. |
469 | const __m256i zeros = _mm256_setzero_si256(); |
470 | const __m256i mask_rightshift_gtz = |
471 | _mm256_cmpgt_epi32(exponent, zeros); |
472 | const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32( |
473 | _mm256_set1_epi32(1), |
474 | _mm256_sub_epi32(exponent, _mm256_set1_epi32(1))); |
475 | __m256i nudge = intrin_utils::mm256_blendv_epi32( |
476 | zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
477 | // Calculate the shifted sum (results + nudge) >> exp. |
478 | const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge); |
479 | const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent); |
480 | |
481 | // Identify overflow in each lane and create mask. |
482 | const __m256i one_shift_31minus_exp = _mm256_sllv_epi32( |
483 | _mm256_set1_epi32(1), |
484 | _mm256_sub_epi32(_mm256_set1_epi32(31), exponent)); |
485 | const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32( |
486 | results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge)); |
487 | // Fill results with either (results + nudge) >> exponent or |
488 | // 1 << (31 - exp) in the case of overflow. |
489 | results = intrin_utils::mm256_blendv_epi32( |
490 | shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
491 | }; |
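        // In scalar terms, each lane of the lambda above computes (a sketch,
        // for exponent > 0):
        //   nudge = 1 << (exponent - 1);
        //   result = (x > INT32_MAX - nudge) ? (1 << (31 - exponent))
        //                                    : ((x + nudge) >> exponent);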
492 | |
493 | auto apply_multiplier = [=](__m256i& accum) { |
494 | __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift); |
495 | // Apply the fixed-point part of the multiplier. |
496 | __m256i scaled_v_low = _mm256_mul_epi32( |
497 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), |
498 | m_64bit_low); |
499 | __m256i scaled_v_high = _mm256_mul_epi32( |
500 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), |
501 | m_64bit_high); |
502 | |
503 | scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector); |
504 | scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector); |
505 | |
506 | scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); |
507 | scaled_v_high = |
508 | _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); |
509 | |
510 | scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); |
511 | __m256i results = |
512 | _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); |
513 | results = _mm256_permutevar8x32_epi32(results, repack_perm); |
514 | // Now do a Rounding Right Shift. |
515 | rounding_right_shift(results, right_shift); |
516 | accum = _mm256_add_epi32(results, post_scaling_offset); |
517 | }; |
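        // In scalar terms, each 32-bit lane of apply_multiplier computes
        // roughly:
        //   scaled = (std::int64_t(accum << left_shift) * m + (1LL << 30)) >> 31;
        //   accum  = rounding_right_shift(scaled, right_shift) + dst_zero_point;
        // with the arithmetic 64-bit shift realized via the unsigned-offset
        // trick described above (AVX2 has no _mm256_srav_epi64).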
518 | apply_multiplier(accum_data_v0); |
519 | apply_multiplier(accum_data_v1); |
520 | apply_multiplier(accum_data_v2); |
521 | apply_multiplier(accum_data_v3); |
522 | apply_multiplier(accum_data_v4); |
523 | apply_multiplier(accum_data_v5); |
524 | apply_multiplier(accum_data_v6); |
525 | apply_multiplier(accum_data_v7); |
526 | // See above comment: here we transpose again to undo the transposition |
527 | // of the 8x8 block of accumulators used to implement the |
528 | // channels-are-columns case. |
529 | if (transpose_around_multiplier) { |
530 | intrin_utils::mm256_transpose8x8_epi32<path>( |
531 | &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
532 | &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
533 | } |
534 | } |
535 | const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
536 | const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
537 | const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && |
538 | (residual_cols == kAvx8bitBlockSize); |
539 | |
540 | __m256i accum_data_v[kAvx8bitBlockSize]; |
541 | if (!store_full_block) { |
542 | accum_data_v[0] = accum_data_v0; |
543 | accum_data_v[1] = accum_data_v1; |
544 | accum_data_v[2] = accum_data_v2; |
545 | accum_data_v[3] = accum_data_v3; |
546 | accum_data_v[4] = accum_data_v4; |
547 | accum_data_v[5] = accum_data_v5; |
548 | accum_data_v[6] = accum_data_v6; |
549 | accum_data_v[7] = accum_data_v7; |
550 | } |
551 | |
552 | if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
553 | std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
554 | if (store_full_block) { |
555 | accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); |
556 | accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); |
557 | accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); |
558 | accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); |
559 | accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); |
560 | accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); |
561 | accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); |
562 | accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); |
563 | accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); |
564 | accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); |
565 | accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); |
566 | accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); |
567 | accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); |
568 | accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); |
569 | accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); |
570 | accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); |
571 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
572 | &tmp_ptr[0 * dst_stride], accum_data_v0); |
573 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
574 | &tmp_ptr[1 * dst_stride], accum_data_v1); |
575 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
576 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
577 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
578 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
579 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
580 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
581 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
582 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
583 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
584 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
585 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
586 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
587 | } else { |
588 | for (int j = 0; j < residual_cols; ++j) { |
589 | __m256i result = accum_data_v[j]; |
590 | result = _mm256_min_epi32(result, clamp_max_v); |
591 | result = _mm256_max_epi32(result, clamp_min_v); |
592 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
593 | tmp_ptr, residual_rows, result); |
594 | tmp_ptr += dst_stride; |
595 | } |
596 | } |
597 | dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
598 | kAvx8bitBlockSize); |
599 | } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
600 | std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
601 | if (store_full_block) { |
602 | accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); |
603 | accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); |
604 | accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); |
605 | accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); |
606 | accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); |
607 | accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); |
608 | accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); |
609 | accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); |
610 | accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); |
611 | accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); |
612 | accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); |
613 | accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); |
614 | accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); |
615 | accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); |
616 | accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); |
617 | accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); |
618 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], |
619 | accum_data_v0); |
620 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], |
621 | accum_data_v1); |
622 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
623 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
624 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
625 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
626 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
627 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
628 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
629 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
630 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
631 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
632 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
633 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
634 | } else { |
635 | for (int j = 0; j < residual_cols; ++j) { |
636 | __m256i result = accum_data_v[j]; |
637 | result = _mm256_min_epi32(result, clamp_max_v); |
638 | result = _mm256_max_epi32(result, clamp_min_v); |
639 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
640 | tmp_ptr, residual_rows, result); |
641 | tmp_ptr += dst_stride; |
642 | } |
643 | } |
644 | dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
645 | kAvx8bitBlockSize); |
646 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
647 | std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
648 | if (store_full_block) { |
649 | accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); |
650 | accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); |
651 | accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); |
652 | accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); |
653 | accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); |
654 | accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); |
655 | accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); |
656 | accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); |
657 | accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); |
658 | accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); |
659 | accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); |
660 | accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); |
661 | accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); |
662 | accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); |
663 | accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); |
664 | accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); |
665 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], |
666 | accum_data_v0); |
667 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], |
668 | accum_data_v1); |
669 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
670 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
671 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
672 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
673 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
674 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
675 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
676 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
677 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
678 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
679 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
680 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
681 | } else { |
682 | for (int j = 0; j < residual_cols; ++j) { |
683 | __m256i result = accum_data_v[j]; |
684 | result = _mm256_min_epi32(result, clamp_max_v); |
685 | result = _mm256_max_epi32(result, clamp_min_v); |
686 | intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( |
687 | tmp_ptr, residual_rows, result); |
688 | tmp_ptr += dst_stride; |
689 | } |
690 | } |
691 | dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
692 | kAvx8bitBlockSize); |
693 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
694 | if (store_full_block) { |
695 | std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
696 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); |
697 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], |
698 | accum_data_v1); |
699 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], |
700 | accum_data_v2); |
701 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], |
702 | accum_data_v3); |
703 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], |
704 | accum_data_v4); |
705 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], |
706 | accum_data_v5); |
707 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], |
708 | accum_data_v6); |
709 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], |
710 | accum_data_v7); |
711 | } else { |
712 | std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
713 | for (int j = 0; j < residual_cols; ++j) { |
714 | intrin_utils::mm256_n_storeu_epi32<path>( |
715 | dst_block_ptr, residual_rows, accum_data_v[j]); |
716 | dst_block_ptr += dst_stride; |
717 | } |
718 | } |
719 | dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
720 | kAvx8bitBlockSize); |
721 | } else { |
722 | RUY_DCHECK(false); |
723 | } |
724 | |
725 | lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
726 | } // End row-block loop. |
727 | |
728 | dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
729 | kAvx8bitBlockSize * params.dst_stride); |
730 | rhs_col_ptr = |
731 | static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) + |
732 | kAvx8bitBlockSize * params.rhs_stride); |
733 | } // End col-block loop. |
734 | } // NOLINT(readability/fn_size) |
735 | |
736 | void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { |
737 | Kernel8bitAvx2Impl<Path::kAvx2Fma>(params); |
738 | } |
739 | |
740 | template <Path path> |
741 | void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");
743 | |
744 | RUY_DCHECK_EQ(params.dst_cols, 1); |
745 | RUY_DCHECK_EQ(params.last_col, 0); |
746 | RUY_DCHECK_EQ(params.start_col, 0); |
747 | |
748 | const std::int8_t splitter_idx_data[32] = { |
749 | 0, 1, 4, 5, 8, 9, 12, 13, // |
750 | 2, 3, 6, 7, 10, 11, 14, 15, // |
751 | 0, 1, 4, 5, 8, 9, 12, 13, // |
752 | 2, 3, 6, 7, 10, 11, 14, 15 // |
753 | }; |
754 | |
755 | int bias_ptr_block_increment = |
756 | params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; |
757 | |
758 | const void* rhs_col_ptr = params.rhs_base_ptr; |
759 | void* dst_col_ptr = params.dst_base_ptr; |
760 | const std::int32_t* bias_col_ptr = params.bias; |
761 | if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
762 | bias_col_ptr += params.start_row; |
763 | } |
764 | |
765 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
766 | void* dst_ptr = dst_col_ptr; |
767 | const std::int32_t* bias_ptr = bias_col_ptr; |
768 | |
769 | const std::int32_t lhs_zero_point = params.lhs_zero_point; |
770 | const bool has_rhs_sums_offsets = |
771 | (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
772 | std::int32_t rhs_sums_offsets[8]; |
773 | if (has_rhs_sums_offsets) { |
774 | const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( |
775 | _mm256_set1_epi32(lhs_zero_point), |
776 | _mm256_loadu_si256( |
777 | reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); |
778 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
779 | rhs_sums_offset_v); |
780 | } |
781 | |
782 | for (int row = params.start_row; row <= params.last_row; |
783 | row += kAvx8bitBlockSize) { |
784 | const int residual_rows = |
785 | std::min(params.dst_rows - row, kAvx8bitBlockSize); |
786 | |
787 | const __m256i splitter_idx = |
788 | _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); |
789 | |
790 | __m256i accum_data_v0; |
791 | |
792 | // Initialize with bias. |
793 | __m256i initial_accum_data = |
794 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr)); |
795 | bias_ptr += bias_ptr_block_increment; |
796 | |
797 | // Adjustments common across columns. |
798 | const std::int32_t rhs_zero_point = params.rhs_zero_point; |
799 | if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
800 | const __m256i lhs_sums_offset = _mm256_mullo_epi32( |
801 | _mm256_set1_epi32(rhs_zero_point), |
802 | _mm256_loadu_si256( |
803 | reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); |
804 | initial_accum_data = |
805 | _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); |
806 | } |
807 | const std::int32_t prod_zp_depth = params.prod_zp_depth; |
808 | if (prod_zp_depth) { |
809 | initial_accum_data = _mm256_add_epi32(initial_accum_data, |
810 | _mm256_set1_epi32(prod_zp_depth)); |
811 | } |
812 | |
813 | // Adjustments differing across columns. |
814 | if (has_rhs_sums_offsets) { |
815 | accum_data_v0 = _mm256_sub_epi32(initial_accum_data, |
816 | _mm256_set1_epi32(rhs_sums_offsets[0])); |
817 | } else { |
818 | accum_data_v0 = initial_accum_data; |
819 | } |
820 | |
821 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
822 | const void* rhs_ptr = rhs_col_ptr; |
823 | for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
824 | const __m256i lhs_data = |
825 | _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
826 | const std::int32_t* rhs_data = |
827 | reinterpret_cast<const std::int32_t*>(rhs_ptr); |
828 | |
829 | // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
      // For simplicity we load 4x the data that we need, process twice the
      // data that we need, and store only the data we need.
832 | std::int32_t rhs_data_buf[2]; |
833 | if (params.rhs_scalar_size == 1) { |
834 | rhs_data = rhs_data_buf; |
835 | const __m128i rhs_data_8bit = |
836 | intrin_utils::mm_loadu_si32<path>(rhs_ptr); |
837 | const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); |
838 | // Now that we have cast the RHS data, we store it so that each value |
839 | // can be separately loaded in the accumulation loop. |
840 | _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), |
841 | rhs_16_bit_dup); |
842 | } else { |
843 | RUY_DCHECK(params.rhs_scalar_size == 2); |
844 | } |
845 | |
846 | // NOTE: There may be opportunities for permuting the data in the packing |
847 | // code instead of here. |
848 | const __m256i lhs_data_split = |
849 | _mm256_shuffle_epi8(lhs_data, splitter_idx); |
850 | const __m256i lhs_data_split_expand_bottom = |
851 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); |
852 | const __m256i lhs_data_split_expand_top = |
853 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); |
854 | |
855 | // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
856 | const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( |
857 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
858 | // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
859 | const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( |
860 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
861 | // Accumulate for column 0. |
862 | const std::int32_t low_rhs_value = rhs_data[0]; |
863 | const std::int32_t high_rhs_value = rhs_data[1]; |
864 | |
865 | const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); |
866 | const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); |
867 | |
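      // Each broadcast 32-bit value holds one 16-bit depth pair, so the two
      // madds below accumulate depth levels {0,1} and {2,3} against all eight
      // LHS rows.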
868 | accum_data_v0 = _mm256_add_epi32( |
869 | accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
870 | accum_data_v0 = _mm256_add_epi32( |
871 | accum_data_v0, |
872 | _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
873 | |
874 | lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
875 | rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + |
876 | kAvx8bitBlockSize * kAvx8bitInnerSize * |
877 | params.rhs_scalar_size); |
878 | } |
879 | |
880 | if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
881 | __m256i m_vector; |
882 | __m256i e_vector; |
883 | // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
884 | int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0; |
885 | m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
886 | params.multiplier_fixedpoint + channel)); |
887 | e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
888 | params.multiplier_exponent + channel)); |
889 | |
890 | const __m256i m_64bit_low = |
891 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); |
892 | const __m256i m_64bit_high = |
893 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); |
894 | |
895 | const __m256i zero_vector = _mm256_setzero_si256(); |
896 | const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); |
897 | const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); |
898 | const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); |
899 | const __m256i final_right_shift = _mm256_set1_epi32(31); |
900 | const __m256i final_right_shift_low = |
901 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); |
902 | const __m256i final_right_shift_high = |
903 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); |
904 | const __m256i convert_to_unsigned_64 = |
905 | _mm256_set1_epi64x(0x8000000000000000); |
906 | |
907 | __m256i post_scaling_offset = _mm256_setzero_si256(); |
908 | // A "half" added for rounding prior to truncation of 64-bit value. |
909 | const __m256i offset_vector = _mm256_add_epi64( |
910 | _mm256_slli_epi64(_mm256_set1_epi64x(1), 30), |
911 | convert_to_unsigned_64); |
912 | |
913 | if (params.dst_zero_point) { |
914 | post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
915 | } |
916 | |
917 | const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); |
918 | |
919 | // See GEMM version for details of this process. |
920 | { |
921 | __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); |
922 | // Apply the fixed-point part of the multiplier. |
923 | __m256i scaled_v_low = _mm256_mul_epi32( |
924 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), |
925 | m_64bit_low); |
926 | __m256i scaled_v_high = _mm256_mul_epi32( |
927 | _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), |
928 | m_64bit_high); |
929 | |
930 | scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector); |
931 | scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector); |
932 | |
933 | scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); |
934 | scaled_v_high = |
935 | _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); |
936 | |
937 | scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); |
938 | __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); |
939 | results = _mm256_permutevar8x32_epi32(results, repack_perm); |
940 | |
941 | // Now do a Rounding Right Shift. |
942 | // First, construct the nudge value for each lane. |
943 | const __m256i zeros = _mm256_setzero_si256(); |
944 | const __m256i mask_rightshift_gtz = |
945 | _mm256_cmpgt_epi32(right_shift, zeros); |
946 | const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32( |
947 | _mm256_set1_epi32(1), |
948 | _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1))); |
949 | __m256i nudge = intrin_utils::mm256_blendv_epi32( |
950 | zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
951 | // Calculate the shifted sum (results + nudge) >> exp. |
952 | const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge); |
953 | const __m256i shifted_sum = |
954 | _mm256_srav_epi32(r_plus_nudge, right_shift); |
955 | |
956 | // Identify overflow in each lane and create mask. |
957 | const __m256i one_shift_31minus_exp = _mm256_sllv_epi32( |
958 | _mm256_set1_epi32(1), |
959 | _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift)); |
960 | const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32( |
961 | results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge)); |
962 | // Fill results with either (results + nudge) >> exponent or |
963 | // 1 << (31 - exp) in the case of overflow. |
964 | results = intrin_utils::mm256_blendv_epi32( |
965 | shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
966 | |
967 | accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset); |
968 | } |
969 | } |
970 | const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
971 | const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
972 | |
973 | if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
974 | std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
975 | __m256i result = accum_data_v0; |
976 | result = _mm256_min_epi32(result, clamp_max_v); |
977 | result = _mm256_max_epi32(result, clamp_min_v); |
978 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
979 | result); |
980 | dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
981 | kAvx8bitBlockSize); |
982 | } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
983 | std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
984 | __m256i result = accum_data_v0; |
985 | result = _mm256_min_epi32(result, clamp_max_v); |
986 | result = _mm256_max_epi32(result, clamp_min_v); |
987 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
988 | result); |
989 | dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
990 | kAvx8bitBlockSize); |
991 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
992 | std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
993 | __m256i result = accum_data_v0; |
994 | result = _mm256_min_epi32(result, clamp_max_v); |
995 | result = _mm256_max_epi32(result, clamp_min_v); |
996 | intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, |
997 | result); |
998 | dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
999 | kAvx8bitBlockSize); |
1000 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
1001 | std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
1002 | intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, |
1003 | accum_data_v0); |
1004 | dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
1005 | kAvx8bitBlockSize); |
1006 | } else { |
1007 | RUY_DCHECK(false); |
1008 | } |
1009 | |
1010 | lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
1011 | } // End row-block loop. |
1012 | |
1013 | dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
1014 | kAvx8bitBlockSize * params.dst_stride); |
1015 | rhs_col_ptr = static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) + |
1016 | kAvx8bitBlockSize * params.rhs_stride); |
1017 | } // NOLINT(readability/fn_size) |
1018 | |
1019 | void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { |
1020 | Kernel8bitAvx2SingleColImpl<Path::kAvx2Fma>(params); |
1021 | } |
1022 | |
1023 | void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx2Fma float");
1025 | KernelFloatAvxCommon<Path::kAvx2Fma>(params); |
1026 | } |
1027 | |
1028 | void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
1030 | KernelFloatAvxCommonSingleCol<Path::kAvx2Fma>(params); |
1031 | } |
1032 | |
1033 | #endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) |
1034 | |
1035 | } // namespace ruy |
1036 | |