1 | /* Copyright 2020 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cstdint> |
18 | #include <cstring> |
19 | |
20 | #include "ruy/check_macros.h" |
21 | #include "ruy/kernel_common.h" |
22 | #include "ruy/kernel_x86.h" |
23 | #include "ruy/opt_set.h" |
24 | #include "ruy/platform.h" |
25 | #include "ruy/profiler/instrumentation.h" |
26 | |
27 | #if RUY_PLATFORM_AVX && RUY_OPT(ASM) |
28 | #include <immintrin.h> // IWYU pragma: keep |
29 | #endif |
30 | |
31 | namespace ruy { |
32 | |
33 | #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) |
34 | |
35 | void Kernel8bitAvx(const KernelParams8bit<8, 8>&) { |
36 | // CPU-ID-based checks should disable the path that would reach this point. |
37 | RUY_DCHECK(false); |
38 | } |
39 | |
40 | void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>&) { |
41 | // CPU-ID-based checks should disable the path that would reach this point. |
42 | RUY_DCHECK(false); |
43 | } |
44 | |
45 | void KernelFloatAvx(const KernelParamsFloat<8, 8>&) { |
46 | // CPU-ID-based checks should disable the path that would reach this point. |
47 | RUY_DCHECK(false); |
48 | } |
49 | |
50 | void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>&) { |
51 | // CPU-ID-based checks should disable the path that would reach this point. |
52 | RUY_DCHECK(false); |
53 | } |
54 | |
55 | #else // RUY_PLATFORM_AVX && RUY_OPT(ASM) |
56 | |
57 | static constexpr int kAvx8bitBlockSize = 8; |
58 | static constexpr int kAvx8bitInnerSize = 4; |
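// The 8-bit kernels below compute the destination in 8x8 blocks
// (kAvx8bitBlockSize) and consume the packed operands 4 levels of depth at a
// time (kAvx8bitInnerSize).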
59 | |
60 | namespace { |
61 | namespace intrin_utils { |
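// AVX (unlike AVX2) has no 256-bit integer arithmetic, so the helpers in this
// namespace emulate the corresponding AVX2 operations by splitting each
// __m256i into two 128-bit halves, operating on them with SSE integer
// intrinsics, and reassembling the result with _mm256_set_m128i.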
62 | |
63 | template <> |
64 | inline __m256i mm256_shuffle_epi8<Path::kAvx>(const __m256i& a, |
65 | const __m256i& b) { |
66 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
67 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
68 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
69 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
70 | __m128i dst_lo = _mm_shuffle_epi8(a_lo, b_lo); |
71 | __m128i dst_hi = _mm_shuffle_epi8(a_hi, b_hi); |
72 | return _mm256_set_m128i(dst_hi, dst_lo); |
73 | } |
74 | |
75 | template <> |
inline __m128i mm256_extracti128_si256<Path::kAvx>(const __m256i& a,
                                                   const int imm) {
78 | switch (imm) { |
79 | case 0: |
80 | return _mm256_extractf128_si256(a, 0); |
81 | case 1: |
82 | return _mm256_extractf128_si256(a, 1); |
83 | default: |
84 | RUY_DCHECK_LT(imm, 2); |
85 | return _mm_setzero_si128(); |
86 | } |
87 | } |
88 | |
89 | template <Path path> |
90 | inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { |
  // Take the upper 64 bits of a and put them in the lower 64 bits of 'hi'.
92 | __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); |
93 | return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); |
94 | } |
95 | |
96 | template <Path path> |
97 | inline __m256i mm256_cvtepi32_epi64(const __m128i& a) { |
98 | // sign extend the 32-bit values in the lower 64 bits of a. |
99 | __m128i lo = _mm_cvtepi32_epi64(a); |
100 | __m128i hi = _mm_cvtepi32_epi64(_mm_unpackhi_epi64(a, _mm_setzero_si128())); |
101 | return _mm256_set_m128i(hi, lo); |
102 | } |
103 | |
104 | inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, |
105 | const int imm) { |
106 | __m128i tmp = _mm_setzero_si128(); |
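  // The low 2 bits of imm select one 128-bit half of {a, b} (0/1: low/high
  // half of a; 2/3: low/high half of b); bit 3 requests a zeroed result,
  // mirroring the immediate encoding of _mm256_permute2f128_si256.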
107 | if (!(imm & 8)) { |
108 | switch (imm & 3) { |
109 | case 0: |
110 | return _mm256_extractf128_si256(a, 0); |
111 | case 1: |
112 | return _mm256_extractf128_si256(a, 1); |
113 | case 2: |
114 | return _mm256_extractf128_si256(b, 0); |
115 | case 3: |
116 | return _mm256_extractf128_si256(b, 1); |
117 | } |
118 | } |
119 | return tmp; |
120 | } |
121 | |
122 | template <Path path> |
123 | inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, |
124 | const int imm) { |
125 | const int lo_imm = imm & 15; |
126 | __m128i lo = mm_permute_helper(a, b, lo_imm); |
127 | const int hi_imm = (imm >> 4) & 15; |
128 | __m128i hi = mm_permute_helper(a, b, hi_imm); |
129 | return _mm256_set_m128i(hi, lo); |
130 | } |
131 | |
132 | template <Path path> |
133 | inline __m256i mm256_max_epi32(const __m256i& a, const __m256i& b) { |
134 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
135 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
136 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
137 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
138 | __m128i lo = _mm_max_epi32(a_lo, b_lo); |
139 | __m128i hi = _mm_max_epi32(a_hi, b_hi); |
140 | return _mm256_set_m128i(hi, lo); |
141 | } |
142 | |
143 | template <Path path> |
144 | inline __m256i mm256_min_epi32(const __m256i& a, const __m256i& b) { |
145 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
146 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
147 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
148 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
149 | __m128i lo = _mm_min_epi32(a_lo, b_lo); |
150 | __m128i hi = _mm_min_epi32(a_hi, b_hi); |
151 | return _mm256_set_m128i(hi, lo); |
152 | } |
153 | |
154 | template <Path path> |
155 | inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { |
156 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
157 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
158 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
159 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
160 | __m128i lo = _mm_add_epi32(a_lo, b_lo); |
161 | __m128i hi = _mm_add_epi32(a_hi, b_hi); |
162 | return _mm256_set_m128i(hi, lo); |
163 | } |
164 | |
165 | template <Path path> |
166 | inline __m256i mm256_add_epi64(const __m256i& a, const __m256i& b) { |
167 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
168 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
169 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
170 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
171 | __m128i lo = _mm_add_epi64(a_lo, b_lo); |
172 | __m128i hi = _mm_add_epi64(a_hi, b_hi); |
173 | return _mm256_set_m128i(hi, lo); |
174 | } |
175 | |
176 | template <Path path> |
177 | inline __m256i mm256_slli_epi64(const __m256i& a, int imm) { |
178 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
179 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
180 | __m128i lo = _mm_slli_epi64(a_lo, imm); |
181 | __m128i hi = _mm_slli_epi64(a_hi, imm); |
182 | return _mm256_set_m128i(hi, lo); |
183 | } |
184 | |
185 | template <Path path> |
186 | inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) { |
187 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
188 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
189 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
190 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
191 | __m128i lo = _mm_mullo_epi32(a_lo, b_lo); |
192 | __m128i hi = _mm_mullo_epi32(a_hi, b_hi); |
193 | return _mm256_set_m128i(hi, lo); |
194 | } |
195 | |
196 | // Defined as a macro since `imm` must be an immediate. |
197 | #define BlendM128_epi32(a, b, imm) \ |
198 | _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm)) |
199 | |
200 | // Defined as a macro since `imm` must be an immediate. |
201 | #define BlendM128_epi64(a, b, imm) \ |
202 | _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm)) |
203 | |
204 | // Defined as a macro since `imm` must be an immediate. |
205 | #define mm256_blend_epi32(ans, a, b, imm) \ |
206 | __m128i a_lo = _mm256_extractf128_si256(a, 0); \ |
207 | __m128i a_hi = _mm256_extractf128_si256(a, 1); \ |
208 | __m128i b_lo = _mm256_extractf128_si256(b, 0); \ |
209 | __m128i b_hi = _mm256_extractf128_si256(b, 1); \ |
210 | __m128i lo = BlendM128_epi32(a_lo, b_lo, imm & 0xe); \ |
211 | __m128i hi = BlendM128_epi32(a_hi, b_hi, imm >> 4); \ |
212 | ans = _mm256_set_m128i(hi, lo); |
213 | |
214 | #define mm256_shuffle_epi32(ans, a, a_lo, a_hi, imm) \ |
215 | a_lo = _mm256_extractf128_si256(a, 0); \ |
216 | a_hi = _mm256_extractf128_si256(a, 1); \ |
217 | ans = _mm256_set_m128i(_mm_shuffle_epi32(a_hi, imm), \ |
218 | _mm_shuffle_epi32(a_lo, imm)); |
219 | |
220 | template <Path path> |
221 | inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { |
222 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
223 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
224 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
225 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
226 | __m128i lo = _mm_madd_epi16(a_lo, b_lo); |
227 | __m128i hi = _mm_madd_epi16(a_hi, b_hi); |
228 | return _mm256_set_m128i(hi, lo); |
229 | } |
230 | |
231 | inline __m128i mm_srlv_epi64(const __m128i& a, const __m128i& b) { |
232 | // shift both elements of a by lower 64bits of b. |
233 | __m128i res_lo = _mm_srl_epi64(a, b); |
234 | // shift both elements of a by upper 64bits of b. |
235 | __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); |
236 | __m128i res_hi = _mm_srl_epi64(a, hi_count); |
  // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
238 | // 1. Swap the upper and lower 64 bits of res_hi |
239 | __m128i tmp_hi = |
240 | _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); |
241 | // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. |
242 | return _mm_unpacklo_epi64(res_lo, tmp_hi); |
243 | } |
244 | |
245 | template <Path path> |
246 | inline __m256i mm256_srlv_epi64(const __m256i& a, const __m256i& b) { |
247 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
248 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
249 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
250 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
251 | __m128i lo = mm_srlv_epi64(a_lo, b_lo); |
252 | __m128i hi = mm_srlv_epi64(a_hi, b_hi); |
253 | return _mm256_set_m128i(hi, lo); |
254 | } |
255 | |
256 | template <Path path> |
257 | inline __m128i mm_sllv_epi64(const __m128i& a, const __m128i& b) { |
258 | // shift both elements of a by lower 64bits of b. |
259 | __m128i res_lo = _mm_sll_epi64(a, b); |
260 | // shift both elements of a by upper 64bits of b. |
261 | __m128i hi_count = _mm_unpackhi_epi64(b, _mm_setzero_si128()); |
262 | __m128i res_hi = _mm_sll_epi64(a, hi_count); |
  // Take the lower 64 bits of res_lo and the upper 64 bits of res_hi:
264 | // 1. Swap the upper and lower 64 bits of res_hi |
265 | __m128i tmp_hi = |
266 | _mm_castpd_si128(_mm_permute_pd(_mm_castsi128_pd(res_hi), 1)); |
267 | // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi. |
268 | return _mm_unpacklo_epi64(res_lo, tmp_hi); |
269 | } |
270 | |
271 | template <Path path> |
272 | inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) { |
273 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
274 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
275 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
276 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
277 | __m128i lo = mm_sllv_epi64<path>(a_lo, b_lo); |
278 | __m128i hi = mm_sllv_epi64<path>(a_hi, b_hi); |
279 | return _mm256_set_m128i(hi, lo); |
280 | } |
281 | |
282 | #define PermuteM128_epi32(a, imm) \ |
283 | _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm)); |
284 | |
285 | inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) { |
286 | // shift all elements of a by first 32bits of b. |
287 | __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1)); |
288 | |
289 | // put bits 32-63 of b in the first slot. |
290 | __m128i tmp1 = PermuteM128_epi32(b, 1); |
291 | // put bits 32-63 of a in the first slot. |
292 | __m128i a1 = PermuteM128_epi32(a, 1); |
  // shift all elements of a1 by the second 32 bits of b.
294 | __m128i res1 = |
295 | _mm_sll_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1)); |
296 | |
297 | // put bits 64-95 of b in the first slot. |
298 | __m128i tmp2 = PermuteM128_epi32(b, 2); |
299 | // shift all elements of a by third 32bits of b. |
300 | __m128i res2 = |
301 | _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1)); |
302 | |
303 | // put bits 96-127 of b in the first slot. |
304 | __m128i tmp3 = PermuteM128_epi32(b, 3); |
305 | // put bits 96-127 of a in the third slot. |
306 | __m128i a3 = PermuteM128_epi32(a, 48); |
307 | // shift all elements of a3 by fourth 32bits of b. |
308 | __m128i res3 = |
309 | _mm_sll_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1)); |
310 | |
311 | // Take bits 0-31 of res0, bits 0-31 of res1, |
312 | // bits 64-95 of res2, and bits 64-95 of res3. |
313 | // res0 _ _ _ 0 |
314 | // res1 _ _ _ 1 |
315 | // res2 _ 2 _ _ |
316 | // res3 _ 3 _ _ |
317 | // f_01 _ _ 1 0 |
318 | // f_23 _ _ 3 2 |
319 | __m128i f_01 = _mm_unpacklo_epi32(res0, res1); |
320 | __m128i f_23 = _mm_unpackhi_epi32(res2, res3); |
  // The lower 64 bits of f_01 and the lower 64 bits of f_23.
322 | return _mm_unpacklo_epi64(f_01, f_23); |
323 | } |
324 | |
325 | template <Path path> |
326 | inline __m256i mm256_sllv_epi32(const __m256i& a, const __m256i& b) { |
327 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
328 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
329 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
330 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
331 | __m128i lo = mm_sllv_epi32(a_lo, b_lo); |
332 | __m128i hi = mm_sllv_epi32(a_hi, b_hi); |
333 | return _mm256_set_m128i(hi, lo); |
334 | } |
335 | |
336 | template <Path path> |
337 | inline __m256i mm256_sub_epi32(const __m256i& a, const __m256i& b) { |
338 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
339 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
340 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
341 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
342 | __m128i lo = _mm_sub_epi32(a_lo, b_lo); |
343 | __m128i hi = _mm_sub_epi32(a_hi, b_hi); |
344 | return _mm256_set_m128i(hi, lo); |
345 | } |
346 | |
347 | template <Path path> |
348 | inline __m256i mm256_mul_epi32(const __m256i& a, const __m256i& b) { |
349 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
350 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
351 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
352 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
353 | __m128i lo = _mm_mul_epi32(a_lo, b_lo); |
354 | __m128i hi = _mm_mul_epi32(a_hi, b_hi); |
355 | return _mm256_set_m128i(hi, lo); |
356 | } |
357 | |
358 | // Perform the equivalent of mm256_permutevar8x32 with |
359 | // a second argument of {7, 5, 3, 1, 6, 4, 2, 0} |
360 | template <Path path> |
361 | inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { |
362 | // a_lo = 3 2 1 0 |
363 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
364 | // a_hi = 7 6 5 4 |
365 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
366 | // shuffle a_lo to get 3 1 2 0 |
367 | __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); |
368 | // shuffle a_hi to get 7 5 6 4 |
369 | __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); |
  // unpack lo 64 of tmp_lo and tmp_hi to get 6 4 2 0
371 | __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); |
  // unpack hi 64 of tmp_lo and tmp_hi to get 7 5 3 1
373 | __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); |
374 | return _mm256_set_m128i(res_hi, res_lo); |
375 | } |
376 | |
377 | template <Path path> |
378 | inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) { |
379 | const __m256i bias0 = _mm256_set1_epi32(*(bias + offset)); |
380 | return mm256_add_epi32<path>(a, bias0); |
381 | } |
382 | |
inline __m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
                                  const __m256i& mask) {
385 | __m256 result = |
386 | _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), |
387 | _mm256_castsi256_ps(mask)); |
388 | return _mm256_castps_si256(result); |
389 | } |
390 | |
391 | template <Path path> |
392 | inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) { |
393 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
394 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
395 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
396 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
397 | __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo); |
398 | __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi); |
399 | return _mm256_set_m128i(hi, lo); |
400 | } |
401 | |
402 | template <Path path> |
403 | inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) { |
404 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
405 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
406 | |
407 | __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0)); |
408 | __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1)); |
409 | __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2)); |
410 | __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3)); |
411 | __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4)); |
412 | __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5)); |
413 | __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6)); |
414 | __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7)); |
415 | |
416 | // get element 0 from r0, element 1 from r1 |
417 | __m128i r01 = BlendM128_epi32(r0, r1, 2); |
418 | // get element 2 from r2, element 3 from r3 |
419 | __m128i r23 = BlendM128_epi32(r2, r3, 8); |
420 | // get element 0 from r4, element 1 from r5 |
421 | __m128i r45 = BlendM128_epi32(r4, r5, 2); |
422 | // get element 2 from r6, element 3 from r7 |
423 | __m128i r67 = BlendM128_epi32(r6, r7, 8); |
424 | // get lower 64 bits of r01, upper 64 bits of r23 |
425 | __m128i r0123 = BlendM128_epi64(r01, r23, 2); |
426 | // get lower 64 bits of r45, upper 64 bits of r67 |
427 | __m128i r4567 = BlendM128_epi64(r45, r67, 2); |
428 | return _mm256_set_m128i(r4567, r0123); |
429 | } |
430 | |
431 | // AVX doesn't have fused multiply-add so we define an inline function to be |
432 | // used in the common code following. |
433 | template <> |
434 | inline __m256 MulAdd<Path::kAvx>(const __m256& a, const __m256& b, |
435 | const __m256& c) { |
436 | const __m256 prod = _mm256_mul_ps(a, b); |
437 | return _mm256_add_ps(prod, c); |
438 | } |
439 | |
440 | } // namespace intrin_utils |
441 | } // namespace |
442 | |
443 | template <Path path> |
444 | void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx 8-bit");
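  // Shuffle control used below: within each 128-bit lane, gather the
  // even-positioned 16-bit pairs (bytes 0,1, 4,5, ...) into the low 8 bytes
  // and the odd-positioned pairs (bytes 2,3, 6,7, ...) into the high 8 bytes,
  // so that each half can later be sign-extended and fed to madd_epi16.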
446 | const std::int8_t splitter_idx_data[32] = { |
447 | 0, 1, 4, 5, 8, 9, 12, 13, // |
448 | 2, 3, 6, 7, 10, 11, 14, 15, // |
449 | 0, 1, 4, 5, 8, 9, 12, 13, // |
450 | 2, 3, 6, 7, 10, 11, 14, 15 // |
451 | }; |
452 | |
453 | std::int32_t dst_stride = 0; |
454 | if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || |
455 | (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { |
456 | dst_stride = params.dst_stride; |
457 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
458 | dst_stride = params.dst_stride / sizeof(std::int16_t); |
459 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
460 | dst_stride = params.dst_stride / sizeof(std::int32_t); |
461 | } else { |
462 | RUY_DCHECK(false); |
463 | } |
464 | |
465 | const std::int8_t* rhs_col_ptr = |
466 | static_cast<const int8_t*>(params.rhs_base_ptr); |
467 | void* dst_col_ptr = params.dst_base_ptr; |
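
  // Traverse the destination in 8x8 blocks: the outer loop walks column
  // blocks, the inner loop walks row blocks, and within each block the depth
  // dimension is reduced 4 levels at a time.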
468 | |
469 | for (int col = params.start_col; col <= params.last_col; |
470 | col += kAvx8bitBlockSize) { |
471 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
472 | void* dst_ptr = dst_col_ptr; |
473 | |
474 | const std::int32_t lhs_zero_point = params.lhs_zero_point; |
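    // When the LHS zero point is nonzero, the term
    // lhs_zero_point * (sum of the RHS column) must be subtracted from every
    // accumulator of the corresponding destination column. Precompute the
    // eight per-column offsets here, once per column block.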
475 | const bool has_rhs_sums_offsets = |
476 | (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
477 | std::int32_t rhs_sums_offsets[8]; |
478 | if (has_rhs_sums_offsets) { |
479 | const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( |
480 | _mm256_set1_epi32(lhs_zero_point), |
481 | _mm256_loadu_si256( |
482 | reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); |
483 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
484 | rhs_sums_offset_v); |
485 | } |
486 | |
487 | for (int row = params.start_row; row <= params.last_row; |
488 | row += kAvx8bitBlockSize) { |
489 | int channel = |
490 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) ? col : row; |
491 | int multiplier_channel = |
492 | (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? channel : 0; |
493 | const int residual_rows = |
494 | std::min(params.dst_rows - row, kAvx8bitBlockSize); |
495 | const int residual_cols = |
496 | std::min(params.dst_cols - col, kAvx8bitBlockSize); |
497 | |
498 | const __m256i splitter_idx = _mm256_loadu_si256( |
499 | reinterpret_cast<__m256i const*>(splitter_idx_data)); |
500 | |
501 | __m256i accum_data_v0; |
502 | __m256i accum_data_v1; |
503 | __m256i accum_data_v2; |
504 | __m256i accum_data_v3; |
505 | __m256i accum_data_v4; |
506 | __m256i accum_data_v5; |
507 | __m256i accum_data_v6; |
508 | __m256i accum_data_v7; |
509 | |
      // initial_accum_data will be the initial value of each of the
      // accum_data_* accumulator registers. We compute into it the terms that
      // are identical across columns.
513 | __m128i initial_accum_data_lo = _mm_set1_epi32(params.prod_zp_depth); |
514 | __m128i initial_accum_data_hi = _mm_set1_epi32(params.prod_zp_depth); |
515 | |
516 | // In the channels-are-rows case, we can load bias here. |
517 | if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
518 | !(params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
519 | initial_accum_data_lo = _mm_add_epi32( |
520 | initial_accum_data_lo, |
521 | _mm_loadu_si128( |
522 | reinterpret_cast<const __m128i*>(params.bias + row))); |
523 | initial_accum_data_hi = _mm_add_epi32( |
524 | initial_accum_data_hi, |
525 | _mm_loadu_si128( |
526 | reinterpret_cast<const __m128i*>(params.bias + row + 4))); |
527 | } |
528 | |
529 | // Adjustments common across columns. |
530 | const std::int32_t rhs_zero_point = params.rhs_zero_point; |
531 | if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
532 | const __m128i rhs_zp = _mm_set1_epi32(rhs_zero_point); |
533 | const __m128i lhs_sums_offset_lo = _mm_mullo_epi32( |
534 | rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( |
535 | ¶ms.lhs_sums[row]))); |
536 | const __m128i lhs_sums_offset_hi = _mm_mullo_epi32( |
537 | rhs_zp, _mm_loadu_si128(reinterpret_cast<__m128i const*>( |
538 | ¶ms.lhs_sums[row + 4]))); |
539 | |
540 | initial_accum_data_lo = |
541 | _mm_sub_epi32(initial_accum_data_lo, lhs_sums_offset_lo); |
542 | initial_accum_data_hi = |
543 | _mm_sub_epi32(initial_accum_data_hi, lhs_sums_offset_hi); |
544 | } |
545 | |
546 | // Adjustments differing across columns. |
547 | if (has_rhs_sums_offsets) { |
548 | __m256i initial_accum_data = |
549 | _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); |
550 | |
551 | accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( |
552 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); |
553 | accum_data_v1 = intrin_utils::mm256_sub_epi32<path>( |
554 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); |
555 | accum_data_v2 = intrin_utils::mm256_sub_epi32<path>( |
556 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); |
557 | accum_data_v3 = intrin_utils::mm256_sub_epi32<path>( |
558 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); |
559 | accum_data_v4 = intrin_utils::mm256_sub_epi32<path>( |
560 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); |
561 | accum_data_v5 = intrin_utils::mm256_sub_epi32<path>( |
562 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); |
563 | accum_data_v6 = intrin_utils::mm256_sub_epi32<path>( |
564 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); |
565 | accum_data_v7 = intrin_utils::mm256_sub_epi32<path>( |
566 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); |
567 | } else { |
568 | __m256i initial_accum_data = |
569 | _mm256_set_m128i(initial_accum_data_hi, initial_accum_data_lo); |
570 | accum_data_v0 = initial_accum_data; |
571 | accum_data_v1 = initial_accum_data; |
572 | accum_data_v2 = initial_accum_data; |
573 | accum_data_v3 = initial_accum_data; |
574 | accum_data_v4 = initial_accum_data; |
575 | accum_data_v5 = initial_accum_data; |
576 | accum_data_v6 = initial_accum_data; |
577 | accum_data_v7 = initial_accum_data; |
578 | } |
579 | |
580 | // Finally, in the channels-are-columns case, load bias data here. |
581 | if ((params.flags & RUY_ASM_FLAG_HAS_BIAS) && |
582 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL)) { |
583 | accum_data_v0 = intrin_utils::AddBiasEpi32<path>(accum_data_v0, |
584 | params.bias + col, 0); |
585 | accum_data_v1 = intrin_utils::AddBiasEpi32<path>(accum_data_v1, |
586 | params.bias + col, 1); |
587 | accum_data_v2 = intrin_utils::AddBiasEpi32<path>(accum_data_v2, |
588 | params.bias + col, 2); |
589 | accum_data_v3 = intrin_utils::AddBiasEpi32<path>(accum_data_v3, |
590 | params.bias + col, 3); |
591 | accum_data_v4 = intrin_utils::AddBiasEpi32<path>(accum_data_v4, |
592 | params.bias + col, 4); |
593 | accum_data_v5 = intrin_utils::AddBiasEpi32<path>(accum_data_v5, |
594 | params.bias + col, 5); |
595 | accum_data_v6 = intrin_utils::AddBiasEpi32<path>(accum_data_v6, |
596 | params.bias + col, 6); |
597 | accum_data_v7 = intrin_utils::AddBiasEpi32<path>(accum_data_v7, |
598 | params.bias + col, 7); |
599 | } |
600 | |
601 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
602 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
603 | for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
604 | const __m256i lhs_data = |
605 | _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
606 | const __m256i rhs_data_8bit = |
607 | _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); |
608 | |
609 | // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
610 | std::int32_t rhs_data[16]; |
611 | const __m128i rhs_data_bottom_lane = |
612 | _mm256_castsi256_si128(rhs_data_8bit); |
613 | const __m128i rhs_data_top_lane = |
614 | _mm256_extractf128_si256(rhs_data_8bit, 1); |
615 | const __m256i rhs_16_bit_dup_low = |
616 | intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_bottom_lane); |
617 | const __m256i rhs_16_bit_dup_high = |
618 | intrin_utils::mm256_cvtepi8_epi16<path>(rhs_data_top_lane); |
619 | // Now that we have cast the RHS data, we store it so that each value |
620 | // can be separately loaded in the accumulation loop. |
621 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), |
622 | rhs_16_bit_dup_low); |
623 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), |
624 | rhs_16_bit_dup_high); |
625 | |
626 | // NOTE: There may be opportunities for permuting the data in the |
627 | // packing code instead of here. |
628 | const __m256i lhs_data_split = |
629 | intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); |
630 | const __m256i lhs_data_split_expand_bottom = |
631 | intrin_utils::mm256_cvtepi8_epi16<path>( |
632 | _mm256_extractf128_si256(lhs_data_split, 0)); |
633 | const __m256i lhs_data_split_expand_top = |
634 | intrin_utils::mm256_cvtepi8_epi16<path>( |
635 | _mm256_extractf128_si256(lhs_data_split, 1)); |
636 | |
637 | // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
638 | const __m256i lhs_16_bit_low = |
639 | intrin_utils::mm256_permute2x128_si256<path>( |
640 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
641 | // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
642 | const __m256i lhs_16_bit_high = |
643 | intrin_utils::mm256_permute2x128_si256<path>( |
644 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
645 | |
646 | __m256i rhs0 = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>( |
647 | rhs_data)); // Load [0 1 2 3 4 5 6 7] |
648 | __m256i rhs1 = _mm256_lddqu_si256( |
649 | reinterpret_cast<const __m256i*>(rhs_data + 8)); // Load [8 - 15] |
650 | __m256i rhs0_3 = |
651 | _mm256_permute2f128_si256(rhs0, rhs0, 0); // [0 1 2 3 0 1 2 3] |
652 | __m256i rhs4_7 = |
653 | _mm256_permute2f128_si256(rhs0, rhs0, 0x11); // [4 5 6 7 4 5 6 7] |
654 | __m256i rhs8_11 = |
655 | _mm256_permute2f128_si256(rhs1, rhs1, 0); // [8 9 10 11 8 9 10 11] |
656 | __m256i rhs12_15 = |
657 | _mm256_permute2f128_si256(rhs1, rhs1, 17); // [12 - 15, 12 - 15] |
658 | |
659 | auto process_column = [=](__m256i& rhs_dup_lo, __m256i& rhs_dup_hi, |
660 | __m256i& accum) { |
661 | // Perform mul-adds on low and high components of accum separately. |
662 | __m128i accum_lo = _mm256_extractf128_si256(accum, 0); |
663 | __m128i accum_hi = _mm256_extractf128_si256(accum, 1); |
664 | |
665 | __m128i lhs_lo_0 = _mm256_extractf128_si256(lhs_16_bit_low, 0); |
666 | __m128i lhs_lo_1 = _mm256_extractf128_si256(lhs_16_bit_low, 1); |
667 | __m128i rhs_dup_lo_0 = _mm256_extractf128_si256(rhs_dup_lo, 0); |
668 | __m128i rhs_dup_lo_1 = _mm256_extractf128_si256(rhs_dup_lo, 1); |
669 | __m128i lo_0 = _mm_madd_epi16(lhs_lo_0, rhs_dup_lo_0); |
670 | __m128i lo_1 = _mm_madd_epi16(lhs_lo_1, rhs_dup_lo_1); |
671 | |
672 | accum_lo = _mm_add_epi32(accum_lo, lo_0); |
673 | accum_hi = _mm_add_epi32(accum_hi, lo_1); |
674 | |
675 | __m128i lhs_hi_0 = _mm256_extractf128_si256(lhs_16_bit_high, 0); |
676 | __m128i lhs_hi_1 = _mm256_extractf128_si256(lhs_16_bit_high, 1); |
677 | __m128i rhs_dup_hi_0 = _mm256_extractf128_si256(rhs_dup_hi, 0); |
678 | __m128i rhs_dup_hi_1 = _mm256_extractf128_si256(rhs_dup_hi, 1); |
679 | __m128i hi_0 = _mm_madd_epi16(lhs_hi_0, rhs_dup_hi_0); |
680 | __m128i hi_1 = _mm_madd_epi16(lhs_hi_1, rhs_dup_hi_1); |
681 | |
682 | accum_lo = _mm_add_epi32(accum_lo, hi_0); |
683 | accum_hi = _mm_add_epi32(accum_hi, hi_1); |
684 | accum = _mm256_set_m128i(accum_hi, accum_lo); |
685 | }; |
686 | __m256i tmp0, tmp1, tmp2, tmp3; |
687 | __m128i lo0, lo1, hi0, hi1; |
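        // Broadcast each 32-bit element (one pair of 16-bit RHS values) of
        // the duplicated RHS vectors across a full register and accumulate it
        // against the split LHS data for the corresponding destination
        // column.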
688 | mm256_shuffle_epi32(tmp0, rhs0_3, lo0, hi0, 0); |
689 | mm256_shuffle_epi32(tmp1, rhs0_3, lo1, hi1, 0x55); |
690 | process_column(tmp0, tmp1, accum_data_v0); |
691 | mm256_shuffle_epi32(tmp2, rhs0_3, lo0, hi0, 0xaa); |
692 | mm256_shuffle_epi32(tmp3, rhs0_3, lo1, hi1, 0xff); |
693 | process_column(tmp2, tmp3, accum_data_v1); |
694 | |
695 | mm256_shuffle_epi32(tmp0, rhs4_7, lo0, hi0, 0); |
696 | mm256_shuffle_epi32(tmp1, rhs4_7, lo1, hi1, 0x55); |
697 | process_column(tmp0, tmp1, accum_data_v2); |
698 | mm256_shuffle_epi32(tmp2, rhs4_7, lo0, hi0, 0xaa); |
699 | mm256_shuffle_epi32(tmp3, rhs4_7, lo1, hi1, 0xff); |
700 | process_column(tmp2, tmp3, accum_data_v3); |
701 | |
702 | mm256_shuffle_epi32(tmp0, rhs8_11, lo0, hi0, 0); |
703 | mm256_shuffle_epi32(tmp1, rhs8_11, lo1, hi1, 0x55); |
704 | process_column(tmp0, tmp1, accum_data_v4); |
705 | mm256_shuffle_epi32(tmp2, rhs8_11, lo0, hi0, 0xaa); |
706 | mm256_shuffle_epi32(tmp3, rhs8_11, lo1, hi1, 0xff); |
707 | process_column(tmp2, tmp3, accum_data_v5); |
708 | |
709 | mm256_shuffle_epi32(tmp0, rhs12_15, lo0, hi0, 0); |
710 | mm256_shuffle_epi32(tmp1, rhs12_15, lo1, hi1, 0x55); |
711 | process_column(tmp0, tmp1, accum_data_v6); |
712 | mm256_shuffle_epi32(tmp2, rhs12_15, lo0, hi0, 0xaa); |
713 | mm256_shuffle_epi32(tmp3, rhs12_15, lo1, hi1, 0xff); |
714 | process_column(tmp2, tmp3, accum_data_v7); |
715 | |
716 | lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
717 | rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
718 | } |
719 | |
720 | if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
721 | __m256i m_vector; |
722 | __m256i e_vector; |
723 | // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
724 | m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
725 | params.multiplier_fixedpoint + multiplier_channel)); |
726 | e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
727 | params.multiplier_exponent + multiplier_channel)); |
728 | |
729 | const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( |
730 | _mm256_extractf128_si256(m_vector, 0)); |
731 | const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( |
732 | _mm256_extractf128_si256(m_vector, 1)); |
733 | |
734 | const __m256i zero_vector = _mm256_setzero_si256(); |
735 | const __m256i left_shift = |
736 | intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); |
737 | const __m256i neg_e_vector = |
738 | intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); |
739 | const __m256i right_shift = |
740 | intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); |
741 | const __m256i final_right_shift = _mm256_set1_epi32(31); |
742 | const __m256i final_right_shift_low = |
743 | intrin_utils::mm256_cvtepi32_epi64<path>( |
744 | _mm256_extractf128_si256(final_right_shift, 0)); |
745 | const __m256i final_right_shift_high = |
746 | intrin_utils::mm256_cvtepi32_epi64<path>( |
747 | _mm256_extractf128_si256(final_right_shift, 1)); |
748 | const __m256i convert_to_unsigned_64 = |
749 | _mm256_set1_epi64x(0x8000000000000000); |
750 | |
751 | __m256i post_scaling_offset = _mm256_setzero_si256(); |
752 | |
753 | // A "half" added for rounding prior to truncation of 64-bit value. |
754 | const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( |
755 | intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), |
756 | convert_to_unsigned_64); |
757 | |
758 | if (params.dst_zero_point) { |
759 | post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
760 | } |
761 | |
762 | // We cannot do |
763 | // |
764 | // scaled_v_low = |
765 | // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); |
766 | // scaled_v_high = |
767 | // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); |
768 | // |
769 | // since this instruction is not in AVX2. Instead we use |
770 | // _mm256_srlv_epi64, but this is an unsigned shift, so we applied |
771 | // offsets before (convert_to_unsigned_64) and after |
772 | // (convert_to_signed_halved). |
773 | // |
774 | // The overall process is, for 64-bit scaled accumulator: |
775 | // unsigned_accum = signed_accum + 1 << 63; |
776 | // unsigned_accum = (unsigned_accum >> right_shift) >> 31; |
777 | // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; |
778 | |
779 | // There are various ways to repack the results, in the absence of |
780 | // _mm256_cvtepi64_epi32() or anything like it. |
781 | // A. |
782 | // accum_data_v[j] = |
783 | // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), |
784 | // _mm256_extract_epi32(scaled_v_high, 4), |
785 | // _mm256_extract_epi32(scaled_v_high, 2), |
786 | // _mm256_extract_epi32(scaled_v_high, 0), |
787 | // _mm256_extract_epi32(scaled_v_low, 6), |
788 | // _mm256_extract_epi32(scaled_v_low, 4), |
789 | // _mm256_extract_epi32(scaled_v_low, 2), |
790 | // _mm256_extract_epi32(scaled_v_low, 0)); |
791 | // B. |
792 | // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); |
793 | // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); |
794 | // accum_data_v[j] = |
795 | // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), |
796 | // _mm256_extract_epi64(scaled_v_high, 0), |
797 | // _mm256_extract_epi64(scaled_v_low, 2), |
798 | // _mm256_extract_epi64(scaled_v_low, 0)); |
799 | // C. |
800 | // scaled_v_low = |
801 | // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); |
802 | // scaled_v_high = |
803 | // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); |
804 | // accum_data_v[j] = |
805 | // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); |
806 | // |
807 | // However, we choose the following because it uses two lighter |
808 | // instructions. The permutation does have a longer latency, but this |
809 | // loop can be unrolled. |
810 | // D. |
811 | // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); |
812 | // __m256i results = |
813 | // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); |
814 | // results = _mm256_permutevar8x32_epi32(results, repack_perm); |
815 | // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results, |
816 | // post_scaling_offset); |
817 | |
818 | // This multiplier code is complex and expensive enough on x86, that |
819 | // we prefer to implement the channels-are-columns case by transposing |
820 | // around it, rather than duplicate it (which would also require |
821 | // duplicating the above code computing the multiplier constants). |
822 | // This is one instance where channels-are-columns has lower performance |
823 | // than channels-are-rows. |
824 | const bool transpose_around_multiplier = |
825 | (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) && |
826 | (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL); |
827 | if (transpose_around_multiplier) { |
        // Transpose the 8x8 accumulator block. It will be un-transposed below,
        // after the multiplier implementation.
830 | intrin_utils::mm256_transpose8x8_epi32<path>( |
831 | &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
832 | &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
833 | } |
834 | |
835 | auto rounding_right_shift = [=](__m256i& results, |
836 | const __m256i& exponent) { |
837 | // Construct the "nudge" value for each lane if the exponent is |
838 | // greater than 0. Otherwise, the nudge is 0. |
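          // The net effect is division by 2^exponent with rounding to nearest
          // (ties rounded upward): e.g. results = 5, exponent = 1 gives
          // nudge = 1 and a shifted sum of (5 + 1) >> 1 = 3.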
839 | const __m256i zeros = _mm256_setzero_si256(); |
840 | const __m256i mask_rightshift_gtz = |
841 | intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros); |
842 | const __m256i one_shift_exp_minus1 = |
843 | intrin_utils::mm256_sllv_epi32<path>( |
844 | _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
845 | exponent, _mm256_set1_epi32(1))); |
846 | __m256i nudge = intrin_utils::mm256_blendv_epi32( |
847 | zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
848 | // Calculate the shifted sum (results + nudge) >> exp. |
849 | const __m256i r_plus_nudge = |
850 | intrin_utils::mm256_add_epi32<path>(results, nudge); |
851 | const __m256i shifted_sum = |
852 | intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent); |
853 | |
854 | // Identify overflow in each lane and create mask. |
855 | const __m256i one_shift_31minus_exp = |
856 | intrin_utils::mm256_sllv_epi32<path>( |
857 | _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
858 | _mm256_set1_epi32(31), exponent)); |
859 | const __m256i mask_num_plus_nudge_overflow = |
860 | intrin_utils::mm256_cmpgt_epi32<path>( |
861 | results, intrin_utils::mm256_sub_epi32<path>( |
862 | _mm256_set1_epi32(0x7fffffff), nudge)); |
863 | // Fill results with either (results + nudge) >> exponent or |
864 | // 1 << (31 - exp) in the case of overflow. |
865 | results = intrin_utils::mm256_blendv_epi32( |
866 | shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
867 | }; |
868 | |
869 | auto apply_multiplier = [=](__m256i& accum) { |
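          // Overall, per lane: shift left by the non-negative part of the
          // exponent, do a rounding doubling-high multiply by the fixed-point
          // multiplier ((x * m + 2^30) >> 31, implemented with unsigned
          // 64-bit shifts as described above), apply the rounding right shift
          // by the remaining exponent, and add the destination zero point.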
870 | __m256i shifted_accum = |
871 | intrin_utils::mm256_sllv_epi32<path>(accum, left_shift); |
872 | // Apply the fixed-point part of the multiplier. |
873 | __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( |
874 | intrin_utils::mm256_cvtepi32_epi64<path>( |
875 | _mm256_extractf128_si256(shifted_accum, 0)), |
876 | m_64bit_low); |
877 | __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( |
878 | intrin_utils::mm256_cvtepi32_epi64<path>( |
879 | _mm256_extractf128_si256(shifted_accum, 1)), |
880 | m_64bit_high); |
881 | scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, |
882 | offset_vector); |
883 | scaled_v_high = intrin_utils::mm256_add_epi64<path>( |
884 | scaled_v_high, offset_vector); |
885 | |
886 | scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( |
887 | scaled_v_low, final_right_shift_low); |
888 | scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( |
889 | scaled_v_high, final_right_shift_high); |
890 | scaled_v_high = |
891 | intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); |
892 | __m256i results; |
893 | mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); |
894 | // Permute results to this ordering of int32 elements |
895 | // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); |
896 | results = intrin_utils::PermuteEpi32EvenOdds<path>(results); |
897 | |
898 | rounding_right_shift(results, right_shift); |
899 | accum = |
900 | intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); |
901 | }; |
902 | apply_multiplier(accum_data_v0); |
903 | apply_multiplier(accum_data_v1); |
904 | apply_multiplier(accum_data_v2); |
905 | apply_multiplier(accum_data_v3); |
906 | apply_multiplier(accum_data_v4); |
907 | apply_multiplier(accum_data_v5); |
908 | apply_multiplier(accum_data_v6); |
909 | apply_multiplier(accum_data_v7); |
910 | // See above comment: here we transpose again to undo the transposition |
911 | // of the 8x8 block of accumulators used to implement the |
912 | // channels-are-columns case. |
913 | if (transpose_around_multiplier) { |
914 | intrin_utils::mm256_transpose8x8_epi32<path>( |
915 | &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3, |
916 | &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7); |
917 | } |
918 | } |
919 | const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
920 | const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
921 | const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && |
922 | (residual_cols == kAvx8bitBlockSize); |
923 | |
924 | __m256i accum_data_v[kAvx8bitBlockSize]; |
925 | if (!store_full_block) { |
926 | accum_data_v[0] = accum_data_v0; |
927 | accum_data_v[1] = accum_data_v1; |
928 | accum_data_v[2] = accum_data_v2; |
929 | accum_data_v[3] = accum_data_v3; |
930 | accum_data_v[4] = accum_data_v4; |
931 | accum_data_v[5] = accum_data_v5; |
932 | accum_data_v[6] = accum_data_v6; |
933 | accum_data_v[7] = accum_data_v7; |
934 | } |
935 | |
936 | if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
937 | std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
938 | if (store_full_block) { |
939 | accum_data_v0 = |
940 | intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
941 | accum_data_v0 = |
942 | intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
943 | accum_data_v1 = |
944 | intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
945 | accum_data_v1 = |
946 | intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
947 | accum_data_v2 = |
948 | intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
949 | accum_data_v2 = |
950 | intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
951 | accum_data_v3 = |
952 | intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
953 | accum_data_v3 = |
954 | intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
955 | accum_data_v4 = |
956 | intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
957 | accum_data_v4 = |
958 | intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
959 | accum_data_v5 = |
960 | intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
961 | accum_data_v5 = |
962 | intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
963 | accum_data_v6 = |
964 | intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
965 | accum_data_v6 = |
966 | intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
967 | accum_data_v7 = |
968 | intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
969 | accum_data_v7 = |
970 | intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
971 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
972 | &tmp_ptr[0 * dst_stride], accum_data_v0); |
973 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
974 | &tmp_ptr[1 * dst_stride], accum_data_v1); |
975 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
976 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
977 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
978 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
979 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
980 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
981 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
982 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
983 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
984 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
985 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
986 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
987 | } else { |
988 | for (int j = 0; j < residual_cols; ++j) { |
989 | __m256i result = accum_data_v[j]; |
990 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
991 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
992 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
993 | tmp_ptr, residual_rows, result); |
994 | tmp_ptr += dst_stride; |
995 | } |
996 | } |
997 | dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
998 | kAvx8bitBlockSize); |
999 | } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
1000 | std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
1001 | if (store_full_block) { |
1002 | accum_data_v0 = |
1003 | intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
1004 | accum_data_v0 = |
1005 | intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
1006 | accum_data_v1 = |
1007 | intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
1008 | accum_data_v1 = |
1009 | intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
1010 | accum_data_v2 = |
1011 | intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
1012 | accum_data_v2 = |
1013 | intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
1014 | accum_data_v3 = |
1015 | intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
1016 | accum_data_v3 = |
1017 | intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
1018 | accum_data_v4 = |
1019 | intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
1020 | accum_data_v4 = |
1021 | intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
1022 | accum_data_v5 = |
1023 | intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
1024 | accum_data_v5 = |
1025 | intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
1026 | accum_data_v6 = |
1027 | intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
1028 | accum_data_v6 = |
1029 | intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
1030 | accum_data_v7 = |
1031 | intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
1032 | accum_data_v7 = |
1033 | intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
1034 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[0], |
1035 | accum_data_v0); |
1036 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>(&tmp_ptr[dst_stride], |
1037 | accum_data_v1); |
1038 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1039 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
1040 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1041 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
1042 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1043 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
1044 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1045 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
1046 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1047 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
1048 | intrin_utils::mm256_storeu_cvtepi32_epi8<path>( |
1049 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
1050 | } else { |
1051 | for (int j = 0; j < residual_cols; ++j) { |
1052 | __m256i result = accum_data_v[j]; |
1053 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
1054 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
1055 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
1056 | tmp_ptr, residual_rows, result); |
1057 | tmp_ptr += dst_stride; |
1058 | } |
1059 | } |
1060 | dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
1061 | kAvx8bitBlockSize); |
1062 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
1063 | std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
1064 | if (store_full_block) { |
1065 | accum_data_v0 = |
1066 | intrin_utils::mm256_min_epi32<path>(accum_data_v0, clamp_max_v); |
1067 | accum_data_v0 = |
1068 | intrin_utils::mm256_max_epi32<path>(accum_data_v0, clamp_min_v); |
1069 | accum_data_v1 = |
1070 | intrin_utils::mm256_min_epi32<path>(accum_data_v1, clamp_max_v); |
1071 | accum_data_v1 = |
1072 | intrin_utils::mm256_max_epi32<path>(accum_data_v1, clamp_min_v); |
1073 | accum_data_v2 = |
1074 | intrin_utils::mm256_min_epi32<path>(accum_data_v2, clamp_max_v); |
1075 | accum_data_v2 = |
1076 | intrin_utils::mm256_max_epi32<path>(accum_data_v2, clamp_min_v); |
1077 | accum_data_v3 = |
1078 | intrin_utils::mm256_min_epi32<path>(accum_data_v3, clamp_max_v); |
1079 | accum_data_v3 = |
1080 | intrin_utils::mm256_max_epi32<path>(accum_data_v3, clamp_min_v); |
1081 | accum_data_v4 = |
1082 | intrin_utils::mm256_min_epi32<path>(accum_data_v4, clamp_max_v); |
1083 | accum_data_v4 = |
1084 | intrin_utils::mm256_max_epi32<path>(accum_data_v4, clamp_min_v); |
1085 | accum_data_v5 = |
1086 | intrin_utils::mm256_min_epi32<path>(accum_data_v5, clamp_max_v); |
1087 | accum_data_v5 = |
1088 | intrin_utils::mm256_max_epi32<path>(accum_data_v5, clamp_min_v); |
1089 | accum_data_v6 = |
1090 | intrin_utils::mm256_min_epi32<path>(accum_data_v6, clamp_max_v); |
1091 | accum_data_v6 = |
1092 | intrin_utils::mm256_max_epi32<path>(accum_data_v6, clamp_min_v); |
1093 | accum_data_v7 = |
1094 | intrin_utils::mm256_min_epi32<path>(accum_data_v7, clamp_max_v); |
1095 | accum_data_v7 = |
1096 | intrin_utils::mm256_max_epi32<path>(accum_data_v7, clamp_min_v); |
1097 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[0], |
1098 | accum_data_v0); |
1099 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>(&tmp_ptr[dst_stride], |
1100 | accum_data_v1); |
1101 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1102 | &tmp_ptr[2 * dst_stride], accum_data_v2); |
1103 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1104 | &tmp_ptr[3 * dst_stride], accum_data_v3); |
1105 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1106 | &tmp_ptr[4 * dst_stride], accum_data_v4); |
1107 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1108 | &tmp_ptr[5 * dst_stride], accum_data_v5); |
1109 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1110 | &tmp_ptr[6 * dst_stride], accum_data_v6); |
1111 | intrin_utils::mm256_storeu_cvtepi32_epi16<path>( |
1112 | &tmp_ptr[7 * dst_stride], accum_data_v7); |
1113 | } else { |
1114 | for (int j = 0; j < residual_cols; ++j) { |
1115 | __m256i result = accum_data_v[j]; |
1116 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
1117 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
1118 | intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>( |
1119 | tmp_ptr, residual_rows, result); |
1120 | tmp_ptr += dst_stride; |
1121 | } |
1122 | } |
1123 | dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
1124 | kAvx8bitBlockSize); |
1125 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
1126 | if (store_full_block) { |
1127 | std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
1128 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[0], accum_data_v0); |
1129 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[dst_stride], |
1130 | accum_data_v1); |
1131 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[2 * dst_stride], |
1132 | accum_data_v2); |
1133 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[3 * dst_stride], |
1134 | accum_data_v3); |
1135 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[4 * dst_stride], |
1136 | accum_data_v4); |
1137 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[5 * dst_stride], |
1138 | accum_data_v5); |
1139 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[6 * dst_stride], |
1140 | accum_data_v6); |
1141 | intrin_utils::mm256_storeu_epi32<path>(&tmp_ptr[7 * dst_stride], |
1142 | accum_data_v7); |
1143 | } else { |
1144 | std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
1145 | for (int j = 0; j < residual_cols; ++j) { |
1146 | intrin_utils::mm256_n_storeu_epi32<path>( |
1147 | dst_block_ptr, residual_rows, accum_data_v[j]); |
1148 | dst_block_ptr += dst_stride; |
1149 | } |
1150 | } |
1151 | dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
1152 | kAvx8bitBlockSize); |
1153 | } else { |
1154 | RUY_DCHECK(false); |
1155 | } |
1156 | |
1157 | lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
1158 | } // End row-block loop. |
1159 | |
1160 | dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
1161 | kAvx8bitBlockSize * params.dst_stride); |
1162 | rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; |
1163 | } // End col-block loop. |
1164 | } // NOLINT(readability/fn_size) |
1165 | |
1166 | void Kernel8bitAvx(const KernelParams8bit<8, 8>& params) { |
1167 | Kernel8bitAvxImpl<Path::kAvx>(params); |
1168 | } |
1169 | |
1170 | template <Path path> |
1171 | void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) { |
  profiler::ScopeLabel label("Kernel kAvx 8-bit GEMV");
1173 | |
1174 | RUY_DCHECK_EQ(params.dst_cols, 1); |
1175 | RUY_DCHECK_EQ(params.last_col, 0); |
1176 | RUY_DCHECK_EQ(params.start_col, 0); |
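
  // GEMV variant of Kernel8bitAvxImpl above: the destination has a single
  // column, so a single accumulator vector per 8-row block suffices.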
1177 | |
1178 | const std::int8_t splitter_idx_data[32] = { |
1179 | 0, 1, 4, 5, 8, 9, 12, 13, // |
1180 | 2, 3, 6, 7, 10, 11, 14, 15, // |
1181 | 0, 1, 4, 5, 8, 9, 12, 13, // |
1182 | 2, 3, 6, 7, 10, 11, 14, 15 // |
1183 | }; |
1184 | |
1185 | int bias_ptr_block_increment = |
1186 | params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; |
1187 | |
1188 | const std::int8_t* rhs_col_ptr = |
1189 | static_cast<const int8_t*>(params.rhs_base_ptr); |
1190 | void* dst_col_ptr = params.dst_base_ptr; |
1191 | const std::int32_t* bias_col_ptr = params.bias; |
1192 | if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
1193 | bias_col_ptr += params.start_row; |
1194 | } |
1195 | |
1196 | const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
1197 | void* dst_ptr = dst_col_ptr; |
1198 | const std::int32_t* bias_ptr = bias_col_ptr; |
1199 | |
1200 | const std::int32_t lhs_zero_point = params.lhs_zero_point; |
1201 | const bool has_rhs_sums_offsets = |
1202 | (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
1203 | std::int32_t rhs_sums_offsets[8]; |
1204 | if (has_rhs_sums_offsets) { |
1205 | const __m256i rhs_sums_offset_v = intrin_utils::mm256_mullo_epi32<path>( |
1206 | _mm256_set1_epi32(lhs_zero_point), |
1207 | _mm256_loadu_si256( |
1208 | reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); |
1209 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), |
1210 | rhs_sums_offset_v); |
1211 | } |
1212 | |
1213 | for (int row = params.start_row; row <= params.last_row; |
1214 | row += kAvx8bitBlockSize) { |
1215 | const int residual_rows = |
1216 | std::min(params.dst_rows - row, kAvx8bitBlockSize); |
1217 | |
1218 | const __m256i splitter_idx = |
1219 | _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); |
1220 | |
1221 | __m256i accum_data_v0; |
1222 | |
1223 | // Initialize with bias. |
1224 | __m256i initial_accum_data = |
1225 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias_ptr)); |
1226 | bias_ptr += bias_ptr_block_increment; |
1227 | |
1228 | // Adjustments common across columns. |
1229 | const std::int32_t rhs_zero_point = params.rhs_zero_point; |
1230 | if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
1231 | const __m256i lhs_sums_offset = intrin_utils::mm256_mullo_epi32<path>( |
1232 | _mm256_set1_epi32(rhs_zero_point), |
1233 | _mm256_loadu_si256( |
1234 | reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); |
1235 | initial_accum_data = intrin_utils::mm256_sub_epi32<path>( |
1236 | initial_accum_data, lhs_sums_offset); |
1237 | } |
1238 | const std::int32_t prod_zp_depth = params.prod_zp_depth; |
1239 | if (prod_zp_depth) { |
1240 | initial_accum_data = intrin_utils::mm256_add_epi32<path>( |
1241 | initial_accum_data, _mm256_set1_epi32(prod_zp_depth)); |
1242 | } |
1243 | |
1244 | // Adjustments differing across columns. |
1245 | if (has_rhs_sums_offsets) { |
1246 | accum_data_v0 = intrin_utils::mm256_sub_epi32<path>( |
1247 | initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); |
1248 | } else { |
1249 | accum_data_v0 = initial_accum_data; |
1250 | } |
1251 | |
1252 | const std::int8_t* lhs_ptr = lhs_col_ptr; |
1253 | const std::int8_t* rhs_ptr = rhs_col_ptr; |
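    // Per depth step, the packed LHS advances by one 8-row x 4-depth block
    // (32 bytes) and the packed RHS by one full 8-column x 4-depth block, but
    // only the 4 bytes of column 0 are loaded in this single-column kernel.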
1254 | for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { |
1255 | const __m256i lhs_data = |
1256 | _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); |
1257 | const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); |
1258 | |
1259 | // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
1260 | // For simplicity we load 4x the data that we need and process twice the |
1261 | // data that we need and store only the data we need. |
1262 | std::int32_t rhs_data[2]; |
1263 | const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); |
1264 | // Now that we have cast the RHS data, we store it so that each value |
1265 | // can be separately loaded in the accumulation loop. |
1266 | _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); |
1267 | |
1268 | // NOTE: There may be opportunities for permuting the data in the packing |
1269 | // code instead of here. |
1270 | const __m256i lhs_data_split = |
1271 | intrin_utils::mm256_shuffle_epi8<path>(lhs_data, splitter_idx); |
1272 | const __m256i lhs_data_split_expand_bottom = |
1273 | intrin_utils::mm256_cvtepi8_epi16<path>( |
1274 | _mm256_extractf128_si256(lhs_data_split, 0)); |
1275 | const __m256i lhs_data_split_expand_top = |
1276 | intrin_utils::mm256_cvtepi8_epi16<path>( |
1277 | _mm256_extractf128_si256(lhs_data_split, 1)); |
1278 | |
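      // permute2x128 control 0x20 concatenates the low 128-bit halves of its
      // two inputs, 0x31 the high halves.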
1279 | // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. |
1280 | const __m256i lhs_16_bit_low = |
1281 | intrin_utils::mm256_permute2x128_si256<path>( |
1282 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); |
1283 | // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. |
1284 | const __m256i lhs_16_bit_high = |
1285 | intrin_utils::mm256_permute2x128_si256<path>( |
1286 | lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); |
1287 | // Accumulate for column 0. |
1288 | const std::int32_t low_rhs_value = rhs_data[0]; |
1289 | const std::int32_t high_rhs_value = rhs_data[1]; |
1290 | |
1291 | const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); |
1292 | const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); |
1293 | |
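      // The madd multiplies adjacent 16-bit pairs and sums each pair into a
      // 32-bit lane, so each accumulator lane gains
      //   lhs[d0]*rhs[d0] + lhs[d1]*rhs[d1]   from the "low" pair and
      //   lhs[d2]*rhs[d2] + lhs[d3]*rhs[d3]   from the "high" pair,
      // covering all 4 depth levels of this step for each of the 8 rows.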
1294 | accum_data_v0 = intrin_utils::mm256_add_epi32<path>( |
1295 | accum_data_v0, intrin_utils::mm256_madd_epi16<path>( |
1296 | lhs_16_bit_low, rhs_16_bit_dup_low)); |
1297 | accum_data_v0 = intrin_utils::mm256_add_epi32<path>( |
1298 | accum_data_v0, intrin_utils::mm256_madd_epi16<path>( |
1299 | lhs_16_bit_high, rhs_16_bit_dup_high)); |
1300 | |
1301 | lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
1302 | rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; |
1303 | } |
1304 | |
1305 | if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
1306 | __m256i m_vector; |
1307 | __m256i e_vector; |
1308 | // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
1309 | int channel = (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) ? row : 0; |
1310 | m_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1311 | params.multiplier_fixedpoint + channel)); |
1312 | e_vector = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1313 | params.multiplier_exponent + channel)); |
1314 | |
1315 | const __m256i m_64bit_low = intrin_utils::mm256_cvtepi32_epi64<path>( |
1316 | _mm256_extractf128_si256(m_vector, 0)); |
1317 | const __m256i m_64bit_high = intrin_utils::mm256_cvtepi32_epi64<path>( |
1318 | _mm256_extractf128_si256(m_vector, 1)); |
1319 | |
1320 | const __m256i zero_vector = _mm256_setzero_si256(); |
1321 | const __m256i left_shift = |
1322 | intrin_utils::mm256_max_epi32<path>(e_vector, zero_vector); |
1323 | const __m256i neg_e_vector = |
1324 | intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector); |
1325 | const __m256i right_shift = |
1326 | intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector); |
1327 | const __m256i final_right_shift = _mm256_set1_epi32(31); |
1328 | const __m256i final_right_shift_low = |
1329 | intrin_utils::mm256_cvtepi32_epi64<path>( |
1330 | _mm256_extractf128_si256(final_right_shift, 0)); |
1331 | const __m256i final_right_shift_high = |
1332 | intrin_utils::mm256_cvtepi32_epi64<path>( |
1333 | _mm256_extractf128_si256(final_right_shift, 1)); |
1334 | const __m256i convert_to_unsigned_64 = |
1335 | _mm256_set1_epi64x(0x8000000000000000); |
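      // Adding 2^63 biases each signed 64-bit product into unsigned range so
      // that the logical 64-bit shift used below (AVX provides no 64-bit
      // arithmetic shift) produces the same low 32 bits as an arithmetic shift
      // would; after shifting by 31 the bias only affects bits >= 32, which
      // are discarded when the results are blended back together.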
1336 | |
1337 | __m256i post_scaling_offset = _mm256_setzero_si256(); |
1338 | |
1339 | // A "half" added for rounding prior to truncation of 64-bit value. |
1340 | const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>( |
1341 | intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30), |
1342 | convert_to_unsigned_64); |
1343 | |
1344 | if (params.dst_zero_point) { |
1345 | post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point); |
1346 | } |
1347 | |
1348 | // See GEMM version for details of this process. |
1349 | { |
1350 | __m256i shifted_accum = |
1351 | intrin_utils::mm256_sllv_epi32<path>(accum_data_v0, left_shift); |
1352 | // Apply the fixed-point part of the multiplier. |
1353 | __m256i scaled_v_low = intrin_utils::mm256_mul_epi32<path>( |
1354 | intrin_utils::mm256_cvtepi32_epi64<path>( |
1355 | _mm256_extractf128_si256(shifted_accum, 0)), |
1356 | m_64bit_low); |
1357 | __m256i scaled_v_high = intrin_utils::mm256_mul_epi32<path>( |
1358 | intrin_utils::mm256_cvtepi32_epi64<path>( |
1359 | _mm256_extractf128_si256(shifted_accum, 1)), |
1360 | m_64bit_high); |
1361 | |
1362 | scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low, |
1363 | offset_vector); |
1364 | scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high, |
1365 | offset_vector); |
1366 | |
1367 | scaled_v_low = intrin_utils::mm256_srlv_epi64<path>( |
1368 | scaled_v_low, final_right_shift_low); |
1369 | scaled_v_high = intrin_utils::mm256_srlv_epi64<path>( |
1370 | scaled_v_high, final_right_shift_high); |
1371 | |
1372 | scaled_v_high = intrin_utils::mm256_slli_epi64<path>(scaled_v_high, 32); |
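        // scaled_v_low holds the results for accumulator lanes 0-3 in the low
        // 32 bits of each 64-bit lane; after the shift, scaled_v_high holds
        // lanes 4-7 in the high 32 bits. Blending with mask 0xaa (odd int32
        // lanes taken from the second source) interleaves all eight 32-bit
        // results into one vector; the permute below restores row order 0..7.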
1373 | __m256i results; |
1374 | mm256_blend_epi32(results, scaled_v_low, scaled_v_high, 0xaa); |
1375 | // Permute results to this ordering of int32 elements |
1376 | // lo->hi (0, 2, 4, 6, 1, 3, 5, 7); |
1377 | results = intrin_utils::PermuteEpi32EvenOdds<path>(results); |
1378 | |
1379 | // Now perform the Rounding Right Shift. |
1380 | // First, construct the "nudge" value for each lane if the exponent is |
1381 | // greater than 0. Otherwise, the nudge is 0. |
1382 | const __m256i zeros = _mm256_setzero_si256(); |
1383 | const __m256i mask_rightshift_gtz = |
1384 | intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros); |
1385 | const __m256i one_shift_exp_minus1 = |
1386 | intrin_utils::mm256_sllv_epi32<path>( |
1387 | _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
1388 | right_shift, _mm256_set1_epi32(1))); |
1389 | __m256i nudge = intrin_utils::mm256_blendv_epi32( |
1390 | zeros, one_shift_exp_minus1, mask_rightshift_gtz); |
1391 | // Calculate the shifted sum (results + nudge) >> exp. |
1392 | const __m256i r_plus_nudge = |
1393 | intrin_utils::mm256_add_epi32<path>(results, nudge); |
1394 | const __m256i shifted_sum = |
1395 | intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift); |
1396 | |
1397 | // Identify overflow in each lane and create mask. |
1398 | const __m256i one_shift_31minus_exp = |
1399 | intrin_utils::mm256_sllv_epi32<path>( |
1400 | _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>( |
1401 | _mm256_set1_epi32(31), right_shift)); |
1402 | const __m256i mask_num_plus_nudge_overflow = |
1403 | intrin_utils::mm256_cmpgt_epi32<path>( |
1404 | results, intrin_utils::mm256_sub_epi32<path>( |
1405 | _mm256_set1_epi32(0x7fffffff), nudge)); |
1406 | // Fill results with either (results + nudge) >> exponent or |
1407 | // 1 << (31 - exp) in the case of overflow. |
1408 | results = intrin_utils::mm256_blendv_epi32( |
1409 | shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow); |
1410 | accum_data_v0 = |
1411 | intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset); |
1412 | } |
1413 | } |
1414 | const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); |
1415 | const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); |
1416 | |
1417 | if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
1418 | std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
1419 | __m256i result = accum_data_v0; |
1420 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
1421 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
1422 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
1423 | result); |
1424 | dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + |
1425 | kAvx8bitBlockSize); |
1426 | } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
1427 | std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
1428 | __m256i result = accum_data_v0; |
1429 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
1430 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
1431 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(tmp_ptr, residual_rows, |
1432 | result); |
1433 | dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + |
1434 | kAvx8bitBlockSize); |
1435 | } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
1436 | std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
1437 | __m256i result = accum_data_v0; |
1438 | result = intrin_utils::mm256_min_epi32<path>(result, clamp_max_v); |
1439 | result = intrin_utils::mm256_max_epi32<path>(result, clamp_min_v); |
1440 | intrin_utils::mm256_n_storeu_cvtepi32_epi16<path>(tmp_ptr, residual_rows, |
1441 | result); |
1442 | dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + |
1443 | kAvx8bitBlockSize); |
1444 | } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
1445 | std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr); |
1446 | intrin_utils::mm256_n_storeu_epi32<path>(dst_block_ptr, residual_rows, |
1447 | accum_data_v0); |
1448 | dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + |
1449 | kAvx8bitBlockSize); |
1450 | } else { |
1451 | RUY_DCHECK(false); |
1452 | } |
1453 | |
1454 | lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; |
1455 | } // End row-block loop. |
1456 | |
1457 | dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
1458 | kAvx8bitBlockSize * params.dst_stride); |
1459 | rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; |
1460 | } // NOLINT(readability/fn_size) |
1461 | |
1462 | void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params) { |
1463 | Kernel8bitAvxSingleColImpl<Path::kAvx>(params); |
1464 | } |
1465 | |
1466 | void KernelFloatAvx(const KernelParamsFloat<8, 8>& params) { |
1467 | profiler::ScopeLabel label("Kernel kAvx float" ); |
1468 | KernelFloatAvxCommon<Path::kAvx>(params); |
1469 | } |
1470 | |
1471 | void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params) { |
1472 | profiler::ScopeLabel label("Kernel kAvx float GEMV" ); |
1473 | KernelFloatAvxCommonSingleCol<Path::kAvx>(params); |
1474 | } |
1475 | |
1476 | #endif // RUY_PLATFORM_AVX && RUY_OPT(ASM) |
1477 | |
1478 | } // namespace ruy |
1479 | |