/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_

#include <cstdint>
#include <cstring>

#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

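// Each RUY_INHERIT_KERNEL line lets the faster path fall back to the kernel
// of the next-slower path for any type combination that is not explicitly
// specialized below.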
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_KERNEL(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)

void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);

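// 8-bit kernel specializations for the AVX-512 path. The packed LHS and RHS
// use 4x16 column-major cells, so each call into the entry points above
// produces a 16x16 block of the destination; the SingleCol variant handles
// the single-destination-column (matrix*vector) case.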
template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

template <typename DstScalar>
struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx512SingleCol(params);
    } else {
      Kernel8bitAvx512(params);
    }
  }
};

void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);

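// Float kernel specialization for the AVX-512 path: 1x16 row-major packed
// cells and 16x16 destination blocks, with a SingleCol variant for the
// matrix*vector case.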
template <>
struct Kernel<Path::kAvx512, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx512;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx512SingleCol(params);
    } else {
      KernelFloatAvx512(params);
    }
  }
};

void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);

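// 8-bit kernel specializations for the AVX2+FMA path: 4x8 column-major packed
// cells, so each call covers an 8x8 destination block.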
template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

template <typename DstScalar>
struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvx2SingleCol(params);
    } else {
      Kernel8bitAvx2(params);
    }
  }
};

void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx2Fma;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvx2SingleCol(params);
    } else {
      KernelFloatAvx2(params);
    }
  }
};

void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);

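// Float kernel specialization for the plain AVX path. Same 8x8 blocking as
// the AVX2+FMA float kernel; the MulAdd<path> hook lets the AVX kernel file
// supply a multiply-add that does not require FMA.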
template <>
struct Kernel<Path::kAvx, float, float, float, float> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      KernelFloatAvxSingleCol(params);
    } else {
      KernelFloatAvx(params);
    }
  }
};

void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);

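// 8-bit kernel specialization for the plain AVX path, using the same 4x8
// packed layout and 8x8 destination blocking as the AVX2+FMA 8-bit kernels.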
template <typename DstScalar>
struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kAvx;
  Tuning tuning = Tuning::kAuto;
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitAvxSingleCol(params);
    } else {
      Kernel8bitAvx(params);
    }
  }
};

#endif  // RUY_PLATFORM_X86
}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {
namespace intrin_utils {

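// The helpers below are templated on Path so that each optimized translation
// unit gets its own instantiations, and so that a few of them (MulAdd,
// mm256_shuffle_epi8, mm256_extracti128_si256) can be specialized per path in
// the AVX and AVX2+FMA kernel files. Several others polyfill intrinsics that
// some compilers lack.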
// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) {
  __m256i ai = _mm256_castps_si256(a);
  int float_val_as_int;
  switch (i) {
    case 0:
      float_val_as_int = _mm256_extract_epi32(ai, 0);
      break;
    case 1:
      float_val_as_int = _mm256_extract_epi32(ai, 1);
      break;
    case 2:
      float_val_as_int = _mm256_extract_epi32(ai, 2);
      break;
    case 3:
      float_val_as_int = _mm256_extract_epi32(ai, 3);
      break;
    case 4:
      float_val_as_int = _mm256_extract_epi32(ai, 4);
      break;
    case 5:
      float_val_as_int = _mm256_extract_epi32(ai, 5);
      break;
    case 6:
      float_val_as_int = _mm256_extract_epi32(ai, 6);
      break;
    case 7:
      float_val_as_int = _mm256_extract_epi32(ai, 7);
      break;
    default:
      RUY_DCHECK_LT(i, 8);
      return .0f;
  }
  float float_val;
  std::memcpy(&float_val, &float_val_as_int, sizeof(float_val));
  return float_val;
}

// Defined as a template so clang won't detect it as an unneeded
// definition.
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  for (int i = 0; i < residual_rows; ++i) {
    dst[i] = intrin_utils::mm256_get1_ps<path>(v, i);
  }
}

template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_ps(0);
}

template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
  // Specializations added for AVX and AVX2FMA paths in their respective kernel
  // files.
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Polyfill for _mm_storeu_si16(dst, v).
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si16(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si16.
  *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0);
#endif
}

// Polyfill for _mm_storeu_si32(dst, v).
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) {
#if (defined __clang__) || (defined _MSC_VER)
  _mm_storeu_si32(dst, v);
#else
  // GCC 9 lacks support for _mm_storeu_si32.
  *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0);
#endif
}

// Polyfill for _mm_loadu_si32(src).
template <Path path>
inline __m128i mm_loadu_si32(const void* src) {
#if (defined __clang__) || (defined _MSC_VER)
  return _mm_loadu_si32(src);
#else
  // GCC 9 lacks support for _mm_loadu_si32.
  __m128i res;
  asm("movss %[src], %[res]"
      : [res] "=x"(res)
      : [src] "m"(*static_cast<const int*>(src)));
  return res;
#endif
}

template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
  RUY_DCHECK(false);
  return _mm_setzero_si128();
}

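// Narrows each 32-bit lane of v to 8 bits and stores only the first
// residual_rows bytes to dst (0 <= residual_rows <= 8).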
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
                                         const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  __m256i shuffled_v;
  if (residual_rows > 1) {
    // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4
    // in each 128-bit lane.
    shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      dst[0] = _mm256_extract_epi8(v, 0);
      break;
    case 2:
      mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 3: {
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0);
      mm_storeu_si16<path>(dst, trailing_packed);
      dst[2] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 4:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      break;
    case 5:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      dst[4] = _mm256_extract_epi8(shuffled_v, 16);
      break;
    case 6:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si16<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si16<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi8(trailing_packed, 2);
      break;
    }
    case 8:
      mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
                                         const __m256i v) {
  intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>(
      reinterpret_cast<std::uint8_t*>(dst), residual_rows, v);
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {
  // Select bytes 0, 4, 8, 12 within each lane, effectively truncating.
  const __m256i repack_perm = _mm256_set1_epi32(0x0c080400);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
                                          const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  __m256i shuffled_v;
  __m128i shuffled_v_low;
  if (residual_rows > 1) {
    shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
    shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0);
  } else {
    shuffled_v_low = mm256_extracti128_si256<path>(v, 0);
  }
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si16<path>(dst, shuffled_v_low);
      break;
    case 2:
      mm_storeu_si32<path>(dst, shuffled_v_low);
      break;
    case 3: {
      mm_storeu_si32<path>(dst, shuffled_v_low);
      dst[2] = _mm_extract_epi16(shuffled_v_low, 2);
      break;
    }
    case 4:
      _mm_storeu_si64(dst, shuffled_v_low);
      break;
    case 5:
      _mm_storeu_si64(dst, shuffled_v_low);
      dst[4] = _mm256_extract_epi16(shuffled_v, 8);
      break;
    case 6:
      _mm_storeu_si64(dst, shuffled_v_low);
      mm_storeu_si32<path>(dst + 4,
                           mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    case 7: {
      _mm_storeu_si64(dst, shuffled_v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1);
      mm_storeu_si32<path>(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi16(trailing_packed, 2);
      break;
    }
    case 8:
      _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
  const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100);
  const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm);
  _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0));
  _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1));
}

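// Stores the first residual_rows 32-bit lanes of v to dst
// (0 <= residual_rows <= 8), using progressively wider unaligned stores.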
template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
                                 const __m256i v) {
  const __m128i v_low = mm256_extracti128_si256<path>(v, 0);
  switch (residual_rows) {
    case 0:
      break;
    case 1:
      mm_storeu_si32<path>(dst, v_low);
      break;
    case 2:
      _mm_storeu_si64(dst, v_low);
      break;
    case 3: {
      __m128i trailing_packed = v_low;
      _mm_storeu_si64(dst, trailing_packed);
      dst[2] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 4:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      break;
    case 5:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      dst[4] = _mm256_extract_epi32(v, 4);
      break;
    case 6:
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1));
      break;
    case 7: {
      _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low);
      __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1);
      _mm_storeu_si64(dst + 4, trailing_packed);
      dst[6] = _mm_extract_epi32(trailing_packed, 2);
      break;
    }
    case 8:
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
      break;
    default:
      RUY_DCHECK_LE(residual_rows, 8);
      break;
  }
}

template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}

// Transpose an 8x8 matrix of floats.
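// Strategy: unpack row pairs into 2x2 blocks, shuffle those into transposed
// 4x4 blocks within each 128-bit lane, then use _mm256_permute2f128_ps to
// exchange the lower/upper lanes between registers.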
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
                           __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
  __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1);
  __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1);
  __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3);
  __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3);
  __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5);
  __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5);
  __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7);
  __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7);
  __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2));
  *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20);
  *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20);
  *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20);
  *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20);
  *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31);
  *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31);
  *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31);
  *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31);
}

// Transpose an 8x8 matrix of int32's.
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                              __m256i* v3, __m256i* v4, __m256i* v5,
                              __m256i* v6, __m256i* v7) {
  mm256_transpose8x8_ps<path>(
      reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1),
      reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3),
      reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5),
      reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7));
}

}  // namespace intrin_utils
}  // namespace

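// Float kernel body shared between the AVX and AVX2+FMA kernel
// implementations; the only path-dependent step is MulAdd<path>. It walks the
// destination in 8x8 blocks: each block is initialized from the bias,
// accumulated over the depth dimension with one rank-1 (outer-product) update
// per depth step, clamped to [clamp_min, clamp_max], and stored, with partial
// stores for edge blocks.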
template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
  // The strides in params are expressed in bytes; divide by sizeof(float) to
  // get strides in elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  const std::int64_t dst_stride = params.dst_stride >> 2;
  const std::int64_t rhs_stride = params.rhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX/AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);
  const int end_col = std::min(params.dst_cols, params.last_col + 8);
  //
  const float* adj_rhs_col_ptr =
      params.rhs_base_ptr - params.start_col * rhs_stride;
  float* adj_dst_col_ptr =
      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
  const bool channel_dimension_is_col =
      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;

  int col = params.start_col;
  // Loop through cols by the float block size, leaving the incomplete
  // remainder for the tail handling below.
  for (; col <= end_col - 8; col += 8) {
    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
        // Load 8 RHS values, then use permute instructions to broadcast each
        // value to a register. _mm256_permute2f128_ps is slow on AMD.
        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

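        // rhs0_3 and rhs4_7 hold the same 4 RHS values in both 128-bit lanes.
        // _mm256_permute_ps with immediates 0, 85, 170, 255 (0b00000000,
        // 0b01010101, 0b10101010, 0b11111111) replicates element 0, 1, 2 or 3
        // of each lane, so each dup_rhs_element_j below broadcasts one RHS
        // value across all 8 float lanes.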
        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      if (residual_rows == 8) {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
        }
      } else {
        for (int j = 0; j < 8; ++j) {
          float* block_ptr = dst_ptr + j * dst_stride;
          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
          intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                                accum_data_v[j]);
        }
      }
    }  // End row-block loop.
  }  // End col-block loop.

  if (col < end_col) {
    // Remaining cols in [0, float block size).
    RUY_DCHECK_GE(end_col - col, 0);
    RUY_DCHECK_LT(end_col - col, 8);

    __m256 accum_data_v[8];

    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
    const int residual_cols = std::min(end_col - col, 8);

    for (int row = params.start_row; row < end_row; row += 8) {
      const int residual_rows = std::min(end_row - row, 8);

      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
      float* dst_ptr = dst_col_ptr + row;

      // Initialize with bias.
      if (channel_dimension_is_col) {
        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
        }
      } else {
        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);

        for (int j = 0; j < 8; ++j) {
          accum_data_v[j] = initial_accum_data;
        }
      }

      const float* lhs_ptr = lhs_col_ptr;
      const float* rhs_ptr = rhs_col_ptr;
      for (int d = 0; d < params.depth; ++d) {
        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);

        __m256 rhs0_3 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
        __m256 rhs4_7 =
            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));

        const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
        accum_data_v[0] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_0, accum_data_v[0]);

        const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
        accum_data_v[1] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_1, accum_data_v[1]);

        const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
        accum_data_v[2] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_2, accum_data_v[2]);

        const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
        accum_data_v[3] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_3, accum_data_v[3]);

        const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
        accum_data_v[4] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_4, accum_data_v[4]);

        const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
        accum_data_v[5] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_5, accum_data_v[5]);

        const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
        accum_data_v[6] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_6, accum_data_v[6]);

        const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
        accum_data_v[7] = intrin_utils::MulAdd<path>(
            lhs_data, dup_rhs_element_7, accum_data_v[7]);

        lhs_ptr += 8;
        rhs_ptr += 8;
      }

      for (int j = 0; j < residual_cols; ++j) {
        float* block_ptr = dst_ptr + j * dst_stride;
        accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
        accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
        intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows,
                                              accum_data_v[j]);
      }
    }  // End row-block loop.
  }  // End col-block terminal conditional.
}

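// Single-destination-column (matrix*vector) variant of the shared float
// kernel: one accumulator register per 8-row block, with the depth loop
// unrolled by 4.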
template <Path path>
inline void KernelFloatAvxCommonSingleCol(
    const KernelParamsFloat<8, 8>& params) {
  RUY_DCHECK_EQ(params.dst_cols, 1);
  RUY_DCHECK_EQ(params.last_col, 0);
  RUY_DCHECK_EQ(params.start_col, 0);

  // The strides in params are expressed in bytes; divide by sizeof(float) to
  // get strides in elements.
  const std::int64_t lhs_stride = params.lhs_stride >> 2;
  //
  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
  // AVX/AVX2 float block size = 8.
  const int end_row = std::min(params.dst_rows, params.last_row + 8);

  float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row;
  const float* adj_lhs_col_ptr =
      params.lhs_base_ptr - params.start_row * lhs_stride;
  const float* bias_col_ptr = params.bias;

  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);

  __m256 accum_data_v;

  const float* rhs_col_ptr = params.rhs_base_ptr;
  float* dst_col_ptr = adj_dst_col_ptr;

  int row = params.start_row;
  for (; row <= end_row - 8; row += 8) {
    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    int d = 0;
    for (; d <= params.depth - 4; d += 4) {
      const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr);
      const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0,
                                                accum_data_v);
      const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]);
      const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1,
                                                accum_data_v);

      const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16);
      const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2,
                                                accum_data_v);
      const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]);
      const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24);
      accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3,
                                                accum_data_v);
      lhs_ptr += 32;  // Loaded 8 * 4 floats.
      rhs_ptr += 32;
    }
    for (; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    _mm256_storeu_ps(dst_ptr, accum_data_v);
  }  // End row-block loop.

  if (row < end_row) {
    const int residual_rows = end_row - row;
    RUY_CHECK_GE(residual_rows, 1);
    RUY_CHECK_LT(residual_rows, 8);

    const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
    float* dst_ptr = dst_col_ptr + row;
    const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;

    // Initialize with bias.
    accum_data_v = _mm256_loadu_ps(bias_ptr);

    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr;
    for (int d = 0; d < params.depth; ++d) {
      const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
      const float* rhs_data = rhs_ptr;

      const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]);
      accum_data_v =
          intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v);
      lhs_ptr += 8;
      rhs_ptr += 8;
    }

    accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v);
    accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v);
    intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v);
  }  // End handling of residual rows.
}
}  // namespace ruy
#endif  // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_KERNEL_X86_H_