1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_KERNEL_X86_H_ |
17 | #define RUY_RUY_KERNEL_X86_H_ |
18 | |
#include <algorithm>
#include <cstdint>
#include <cstring>
21 | |
22 | #include "ruy/kernel_common.h" |
23 | #include "ruy/mat.h" |
24 | #include "ruy/mul_params.h" |
25 | #include "ruy/opt_set.h" |
26 | #include "ruy/path.h" |
27 | #include "ruy/platform.h" |
28 | #include "ruy/tune.h" |
29 | |
30 | namespace ruy { |
31 | |
32 | #if RUY_PLATFORM_X86 |
33 | |
34 | RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx) |
35 | RUY_INHERIT_KERNEL(Path::kAvx, Path::kAvx2Fma) |
36 | RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512) |
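// Each RUY_INHERIT_KERNEL(Parent, Child) (see kernel_common.h) makes the
// Child path reuse the Parent path's kernel for any scalar-type combination
// that is not explicitly specialized below.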
37 | |
38 | void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); |
39 | void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); |
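// The *SingleCol variants are matrix*vector kernels, used when the
// destination has a single column and the channel dimension is rows (see the
// dispatch in Run() below).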
40 | |
41 | template <typename DstScalar> |
struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
43 | static constexpr Path kPath = Path::kAvx512; |
44 | Tuning tuning = Tuning::kAuto; |
45 | using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
46 | using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
47 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
48 | void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, |
49 | const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, |
50 | int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { |
51 | KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; |
52 | MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, |
53 | end_col, dst, ¶ms); |
54 | if (dst->layout.cols == 1 && |
55 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
56 | Kernel8bitAvx512SingleCol(params); |
57 | } else { |
58 | Kernel8bitAvx512(params); |
59 | } |
60 | } |
61 | }; |
62 | |
63 | template <typename DstScalar> |
64 | struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t, |
65 | DstScalar> { |
66 | static constexpr Path kPath = Path::kAvx512; |
67 | Tuning tuning = Tuning::kAuto; |
68 | using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
69 | using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
70 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
71 | void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs, |
72 | const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, |
73 | int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { |
74 | KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; |
75 | MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, |
76 | end_col, dst, ¶ms); |
77 | if (dst->layout.cols == 1 && |
78 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
79 | Kernel8bitAvx512SingleCol(params); |
80 | } else { |
81 | Kernel8bitAvx512(params); |
82 | } |
83 | } |
84 | }; |
85 | |
86 | void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); |
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);
88 | |
89 | template <> |
90 | struct Kernel<Path::kAvx512, float, float, float, float> { |
91 | static constexpr Path kPath = Path::kAvx512; |
92 | Tuning tuning = Tuning::kAuto; |
93 | using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>; |
94 | using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>; |
95 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
96 | void Run(const PMat<float>& lhs, const PMat<float>& rhs, |
97 | const MulParams<float, float>& mul_params, int start_row, |
98 | int start_col, int end_row, int end_col, Mat<float>* dst) const { |
99 | KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; |
100 | MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, |
101 | end_col, dst, ¶ms); |
102 | if (dst->layout.cols == 1 && |
103 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
104 | KernelFloatAvx512SingleCol(params); |
105 | } else { |
106 | KernelFloatAvx512(params); |
107 | } |
108 | } |
109 | }; |
110 | |
111 | void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); |
112 | void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); |
113 | |
114 | template <typename DstScalar> |
115 | struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, |
116 | DstScalar> { |
117 | static constexpr Path kPath = Path::kAvx2Fma; |
118 | Tuning tuning = Tuning::kAuto; |
119 | using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
120 | using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
121 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
122 | void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, |
123 | const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, |
124 | int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { |
125 | KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; |
126 | MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, |
127 | end_col, dst, ¶ms); |
128 | if (dst->layout.cols == 1 && |
129 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
130 | Kernel8bitAvx2SingleCol(params); |
131 | } else { |
132 | Kernel8bitAvx2(params); |
133 | } |
134 | } |
135 | }; |
136 | |
137 | template <typename DstScalar> |
138 | struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, |
139 | DstScalar> { |
140 | static constexpr Path kPath = Path::kAvx2Fma; |
141 | Tuning tuning = Tuning::kAuto; |
142 | using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
143 | using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
144 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
145 | void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs, |
146 | const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, |
147 | int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { |
148 | KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; |
149 | MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, |
150 | end_col, dst, ¶ms); |
151 | if (dst->layout.cols == 1 && |
152 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
153 | Kernel8bitAvx2SingleCol(params); |
154 | } else { |
155 | Kernel8bitAvx2(params); |
156 | } |
157 | } |
158 | }; |
159 | |
160 | void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); |
161 | void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); |
162 | |
163 | template <> |
164 | struct Kernel<Path::kAvx2Fma, float, float, float, float> { |
165 | static constexpr Path kPath = Path::kAvx2Fma; |
166 | Tuning tuning = Tuning::kAuto; |
167 | using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
168 | using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
169 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
170 | void Run(const PMat<float>& lhs, const PMat<float>& rhs, |
171 | const MulParams<float, float>& mul_params, int start_row, |
172 | int start_col, int end_row, int end_col, Mat<float>* dst) const { |
173 | KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; |
174 | MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, |
175 | end_col, dst, ¶ms); |
176 | if (dst->layout.cols == 1 && |
177 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
178 | KernelFloatAvx2SingleCol(params); |
179 | } else { |
180 | KernelFloatAvx2(params); |
181 | } |
182 | } |
183 | }; |
184 | |
185 | void KernelFloatAvx(const KernelParamsFloat<8, 8>& params); |
186 | void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params); |
187 | |
188 | template <> |
189 | struct Kernel<Path::kAvx, float, float, float, float> { |
190 | static constexpr Path kPath = Path::kAvx; |
191 | Tuning tuning = Tuning::kAuto; |
192 | using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
193 | using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
194 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
195 | void Run(const PMat<float>& lhs, const PMat<float>& rhs, |
196 | const MulParams<float, float>& mul_params, int start_row, |
197 | int start_col, int end_row, int end_col, Mat<float>* dst) const { |
198 | KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params; |
199 | MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row, |
200 | end_col, dst, ¶ms); |
201 | if (dst->layout.cols == 1 && |
202 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
203 | KernelFloatAvxSingleCol(params); |
204 | } else { |
205 | KernelFloatAvx(params); |
206 | } |
207 | } |
208 | }; |
209 | |
210 | void Kernel8bitAvx(const KernelParams8bit<8, 8>& params); |
211 | void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params); |
212 | |
213 | template <typename DstScalar> |
214 | struct Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar> { |
  static constexpr Path kPath = Path::kAvx;
216 | Tuning tuning = Tuning::kAuto; |
217 | using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
218 | using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
219 | explicit Kernel(Tuning tuning_) : tuning(tuning_) {} |
220 | void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs, |
221 | const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, |
222 | int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { |
223 | KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; |
224 | MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, |
225 | end_col, dst, ¶ms); |
226 | if (dst->layout.cols == 1 && |
227 | mul_params.channel_dimension() == ChannelDimension::kRow) { |
228 | Kernel8bitAvxSingleCol(params); |
229 | } else { |
230 | Kernel8bitAvx(params); |
231 | } |
232 | } |
233 | }; |
234 | |
235 | #endif // RUY_PLATFORM_X86 |
236 | } // namespace ruy |
237 | |
238 | #if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)) |
239 | |
240 | #include <immintrin.h> // IWYU pragma: keep |
241 | |
242 | namespace ruy { |
243 | namespace { |
244 | namespace intrin_utils { |
245 | |
// Defined as a template so clang won't detect it as an unneeded
// definition.
248 | template <Path path> |
249 | inline float mm256_get1_ps(const __m256 a, int i) { |
250 | __m256i ai = _mm256_castps_si256(a); |
251 | int float_val_as_int; |
252 | switch (i) { |
253 | case 0: |
254 | float_val_as_int = _mm256_extract_epi32(ai, 0); |
255 | break; |
256 | case 1: |
257 | float_val_as_int = _mm256_extract_epi32(ai, 1); |
258 | break; |
259 | case 2: |
260 | float_val_as_int = _mm256_extract_epi32(ai, 2); |
261 | break; |
262 | case 3: |
263 | float_val_as_int = _mm256_extract_epi32(ai, 3); |
264 | break; |
265 | case 4: |
266 | float_val_as_int = _mm256_extract_epi32(ai, 4); |
267 | break; |
268 | case 5: |
269 | float_val_as_int = _mm256_extract_epi32(ai, 5); |
270 | break; |
271 | case 6: |
272 | float_val_as_int = _mm256_extract_epi32(ai, 6); |
273 | break; |
274 | case 7: |
275 | float_val_as_int = _mm256_extract_epi32(ai, 7); |
276 | break; |
277 | default: |
278 | RUY_DCHECK_LT(i, 8); |
      return 0.0f;
280 | } |
281 | float float_val; |
282 | std::memcpy(&float_val, &float_val_as_int, sizeof(float_val)); |
283 | return float_val; |
284 | } |
285 | |
// Defined as a template so clang won't detect it as an unneeded
// definition.
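// Stores the first `residual_rows` floats of `v` to `dst` one lane at a
// time: a scalar stand-in for a masked store.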
288 | template <Path path> |
289 | inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { |
290 | for (int i = 0; i < residual_rows; ++i) { |
291 | dst[i] = intrin_utils::mm256_get1_ps<path>(v, i); |
292 | } |
293 | } |
294 | |
295 | template <Path path> |
296 | inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) { |
297 | // Specializations added for AVX and AVX2FMA paths in their respective kernel |
298 | // files. |
299 | RUY_DCHECK(false); |
300 | return _mm256_set1_ps(0); |
301 | } |
302 | |
303 | template <Path path> |
304 | inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) { |
305 | // Specializations added for AVX and AVX2FMA paths in their respective kernel |
306 | // files. |
307 | RUY_DCHECK(false); |
308 | return _mm256_set1_epi32(0); |
309 | } |
310 | |
311 | // Polyfill for _mm_storeu_si16(dst, v). |
312 | template <Path path> |
313 | inline void mm_storeu_si16(void* dst, __m128i v) { |
314 | #if (defined __clang__) || (defined _MSC_VER) |
315 | _mm_storeu_si16(dst, v); |
316 | #else |
  // GCC 9 lacks support for _mm_storeu_si16.
318 | *static_cast<std::int16_t*>(dst) = _mm_extract_epi16(v, 0); |
319 | #endif |
320 | } |
321 | |
322 | // Polyfill for _mm_storeu_si32(dst, v). |
323 | template <Path path> |
324 | inline void mm_storeu_si32(void* dst, __m128i v) { |
325 | #if (defined __clang__) || (defined _MSC_VER) |
326 | _mm_storeu_si32(dst, v); |
327 | #else |
  // GCC 9 lacks support for _mm_storeu_si32.
329 | *static_cast<std::int32_t*>(dst) = _mm_extract_epi32(v, 0); |
330 | #endif |
331 | } |
332 | |
333 | // Polyfill for _mm_loadu_si32(src). |
334 | template <Path path> |
335 | inline __m128i mm_loadu_si32(const void* src) { |
336 | #if (defined __clang__) || (defined _MSC_VER) |
337 | return _mm_loadu_si32(src); |
338 | #else |
339 | // GCC 9 lacks support for _mm_loadu_si32. |
340 | __m128i res; |
341 | asm("movss %[src], %[res]" |
342 | : [res] "=x" (res) |
343 | : [src] "m" (*static_cast<const int*>(src))); |
344 | return res; |
345 | #endif |
346 | } |
347 | |
348 | template <Path path> |
349 | inline __m128i mm256_extracti128_si256(const __m256i&, const int) { |
350 | RUY_DCHECK(false); |
351 | return _mm_setzero_si128(); |
352 | } |
353 | |
354 | template <Path path> |
355 | inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, |
356 | const __m256i v) { |
357 | // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. |
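  // For example, if a 128-bit half of `v` holds the 32-bit values
  // {x0, x1, x2, x3}, the shuffle packs their low bytes into the first
  // 4 bytes of that half.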
358 | const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); |
359 | __m256i shuffled_v; |
360 | if (residual_rows > 1) { |
361 | // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 |
362 | // in each 128-bit lane. |
363 | shuffled_v = intrin_utils::mm256_shuffle_epi8<path>(v, repack_perm); |
364 | } |
365 | switch (residual_rows) { |
366 | case 0: |
367 | break; |
368 | case 1: |
369 | dst[0] = _mm256_extract_epi8(v, 0); |
370 | break; |
371 | case 2: |
372 | mm_storeu_si16<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
373 | break; |
374 | case 3: { |
375 | __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 0); |
376 | mm_storeu_si16<path>(dst, trailing_packed); |
377 | dst[2] = _mm_extract_epi8(trailing_packed, 2); |
378 | break; |
379 | } |
380 | case 4: |
381 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
382 | break; |
383 | case 5: |
384 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
385 | dst[4] = _mm256_extract_epi8(shuffled_v, 16); |
386 | break; |
387 | case 6: |
388 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
389 | mm_storeu_si16<path>(dst + 4, |
390 | mm256_extracti128_si256<path>(shuffled_v, 1)); |
391 | break; |
392 | case 7: { |
393 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
394 | __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); |
395 | mm_storeu_si16<path>(dst + 4, trailing_packed); |
396 | dst[6] = _mm_extract_epi8(trailing_packed, 2); |
397 | break; |
398 | } |
399 | case 8: |
400 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
401 | mm_storeu_si32<path>(dst + 4, |
402 | mm256_extracti128_si256<path>(shuffled_v, 1)); |
403 | break; |
404 | default: |
405 | RUY_DCHECK_LE(residual_rows, 8); |
406 | break; |
407 | } |
408 | } |
409 | |
410 | template <Path path> |
411 | inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) { |
412 | // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. |
413 | const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); |
414 | const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); |
415 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
416 | mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); |
417 | } |
418 | |
419 | template <Path path> |
420 | inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, |
421 | const __m256i v) { |
422 | intrin_utils::mm256_n_storeu_cvtepi32_epi8<path>( |
423 | reinterpret_cast<std::uint8_t*>(dst), residual_rows, v); |
424 | } |
425 | |
426 | template <Path path> |
427 | inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) { |
428 | // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. |
429 | const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); |
430 | const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); |
431 | mm_storeu_si32<path>(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
432 | mm_storeu_si32<path>(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); |
433 | } |
434 | |
435 | template <Path path> |
436 | inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, |
437 | const __m256i v) { |
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
440 | const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); |
441 | __m256i shuffled_v; |
442 | __m128i shuffled_v_low; |
443 | if (residual_rows > 1) { |
444 | shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); |
445 | shuffled_v_low = mm256_extracti128_si256<path>(shuffled_v, 0); |
446 | } else { |
447 | shuffled_v_low = mm256_extracti128_si256<path>(v, 0); |
448 | } |
449 | switch (residual_rows) { |
450 | case 0: |
451 | break; |
452 | case 1: |
453 | mm_storeu_si16<path>(dst, shuffled_v_low); |
454 | break; |
455 | case 2: |
456 | mm_storeu_si32<path>(dst, shuffled_v_low); |
457 | break; |
458 | case 3: { |
459 | mm_storeu_si32<path>(dst, shuffled_v_low); |
460 | dst[2] = _mm_extract_epi16(shuffled_v_low, 2); |
461 | break; |
462 | } |
463 | case 4: |
464 | _mm_storeu_si64(dst, shuffled_v_low); |
465 | break; |
466 | case 5: |
467 | _mm_storeu_si64(dst, shuffled_v_low); |
468 | dst[4] = _mm256_extract_epi16(shuffled_v, 8); |
469 | break; |
470 | case 6: |
471 | _mm_storeu_si64(dst, shuffled_v_low); |
472 | mm_storeu_si32<path>(dst + 4, |
473 | mm256_extracti128_si256<path>(shuffled_v, 1)); |
474 | break; |
475 | case 7: { |
476 | _mm_storeu_si64(dst, shuffled_v_low); |
477 | __m128i trailing_packed = mm256_extracti128_si256<path>(shuffled_v, 1); |
478 | mm_storeu_si32<path>(dst + 4, trailing_packed); |
479 | dst[6] = _mm_extract_epi16(trailing_packed, 2); |
480 | break; |
481 | } |
482 | case 8: |
483 | _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
484 | _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); |
485 | break; |
486 | default: |
487 | RUY_DCHECK_LE(residual_rows, 8); |
488 | break; |
489 | } |
490 | } |
491 | |
492 | template <Path path> |
493 | inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) { |
  // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively
  // truncating each 32-bit integer to 16 bits.
496 | const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); |
497 | const __m256i shuffled_v = mm256_shuffle_epi8<path>(v, repack_perm); |
498 | _mm_storeu_si64(dst, mm256_extracti128_si256<path>(shuffled_v, 0)); |
499 | _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(shuffled_v, 1)); |
500 | } |
501 | |
502 | template <Path path> |
503 | inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, |
504 | const __m256i v) { |
505 | const __m128i v_low = mm256_extracti128_si256<path>(v, 0); |
506 | switch (residual_rows) { |
507 | case 0: |
508 | break; |
509 | case 1: |
510 | mm_storeu_si32<path>(dst, v_low); |
511 | break; |
512 | case 2: |
513 | _mm_storeu_si64(dst, v_low); |
514 | break; |
515 | case 3: { |
516 | __m128i trailing_packed = v_low; |
517 | _mm_storeu_si64(dst, trailing_packed); |
518 | dst[2] = _mm_extract_epi32(trailing_packed, 2); |
519 | break; |
520 | } |
521 | case 4: |
522 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); |
523 | break; |
524 | case 5: |
525 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); |
526 | dst[4] = _mm256_extract_epi32(v, 4); |
527 | break; |
528 | case 6: |
529 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); |
530 | _mm_storeu_si64(dst + 4, mm256_extracti128_si256<path>(v, 1)); |
531 | break; |
532 | case 7: { |
533 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); |
534 | __m128i trailing_packed = mm256_extracti128_si256<path>(v, 1); |
535 | _mm_storeu_si64(dst + 4, trailing_packed); |
536 | dst[6] = _mm_extract_epi32(trailing_packed, 2); |
537 | break; |
538 | } |
539 | case 8: |
540 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); |
541 | break; |
542 | default: |
543 | RUY_DCHECK_LE(residual_rows, 8); |
544 | break; |
545 | } |
546 | } |
547 | |
548 | template <Path path> |
549 | inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) { |
550 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); |
551 | } |
552 | |
// Transpose an 8x8 matrix of floats.
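// Stage 1 (unpacklo/unpackhi) interleaves row pairs into 2x2 blocks, stage 2
// (shuffle_ps) assembles transposed 4x4 blocks within each 128-bit lane, and
// stage 3 (permute2f128) exchanges the lanes to complete the 8x8 transpose.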
554 | template <Path path> |
555 | void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3, |
556 | __m256* v4, __m256* v5, __m256* v6, __m256* v7) { |
557 | __m256 t2x2_0 = _mm256_unpacklo_ps(*v0, *v1); |
558 | __m256 t2x2_1 = _mm256_unpackhi_ps(*v0, *v1); |
559 | __m256 t2x2_2 = _mm256_unpacklo_ps(*v2, *v3); |
560 | __m256 t2x2_3 = _mm256_unpackhi_ps(*v2, *v3); |
561 | __m256 t2x2_4 = _mm256_unpacklo_ps(*v4, *v5); |
562 | __m256 t2x2_5 = _mm256_unpackhi_ps(*v4, *v5); |
563 | __m256 t2x2_6 = _mm256_unpacklo_ps(*v6, *v7); |
564 | __m256 t2x2_7 = _mm256_unpackhi_ps(*v6, *v7); |
565 | __m256 t4x4_0 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(1, 0, 1, 0)); |
566 | __m256 t4x4_1 = _mm256_shuffle_ps(t2x2_0, t2x2_2, _MM_SHUFFLE(3, 2, 3, 2)); |
567 | __m256 t4x4_2 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(1, 0, 1, 0)); |
568 | __m256 t4x4_3 = _mm256_shuffle_ps(t2x2_1, t2x2_3, _MM_SHUFFLE(3, 2, 3, 2)); |
569 | __m256 t4x4_4 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(1, 0, 1, 0)); |
570 | __m256 t4x4_5 = _mm256_shuffle_ps(t2x2_4, t2x2_6, _MM_SHUFFLE(3, 2, 3, 2)); |
571 | __m256 t4x4_6 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(1, 0, 1, 0)); |
572 | __m256 t4x4_7 = _mm256_shuffle_ps(t2x2_5, t2x2_7, _MM_SHUFFLE(3, 2, 3, 2)); |
573 | *v0 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x20); |
574 | *v1 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x20); |
575 | *v2 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x20); |
576 | *v3 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x20); |
577 | *v4 = _mm256_permute2f128_ps(t4x4_0, t4x4_4, 0x31); |
578 | *v5 = _mm256_permute2f128_ps(t4x4_1, t4x4_5, 0x31); |
579 | *v6 = _mm256_permute2f128_ps(t4x4_2, t4x4_6, 0x31); |
580 | *v7 = _mm256_permute2f128_ps(t4x4_3, t4x4_7, 0x31); |
581 | } |
582 | |
// Transpose an 8x8 matrix of int32s.
584 | template <Path path> |
585 | void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2, |
586 | __m256i* v3, __m256i* v4, __m256i* v5, |
587 | __m256i* v6, __m256i* v7) { |
588 | mm256_transpose8x8_ps<path>( |
589 | reinterpret_cast<__m256*>(v0), reinterpret_cast<__m256*>(v1), |
590 | reinterpret_cast<__m256*>(v2), reinterpret_cast<__m256*>(v3), |
591 | reinterpret_cast<__m256*>(v4), reinterpret_cast<__m256*>(v5), |
592 | reinterpret_cast<__m256*>(v6), reinterpret_cast<__m256*>(v7)); |
593 | } |
594 | |
595 | } // namespace intrin_utils |
596 | } // namespace |
597 | |
598 | template <Path path> |
599 | inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) { |
  // The strides in params are given in bytes; shift right by 2 to convert
  // them to float-element strides (sizeof(float) == 4).
601 | const std::int64_t lhs_stride = params.lhs_stride >> 2; |
602 | const std::int64_t dst_stride = params.dst_stride >> 2; |
603 | const std::int64_t rhs_stride = params.rhs_stride >> 2; |
604 | // |
605 | int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; |
  // AVX/AVX2 float block size = 8.
607 | const int end_row = std::min(params.dst_rows, params.last_row + 8); |
608 | const int end_col = std::min(params.dst_cols, params.last_col + 8); |
609 | // |
610 | const float* adj_rhs_col_ptr = |
611 | params.rhs_base_ptr - params.start_col * rhs_stride; |
612 | float* adj_dst_col_ptr = |
613 | params.dst_base_ptr - params.start_col * dst_stride - params.start_row; |
614 | const float* adj_lhs_col_ptr = |
615 | params.lhs_base_ptr - params.start_row * lhs_stride; |
616 | const float* bias_ptr = params.bias; |
617 | |
618 | const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); |
619 | const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); |
620 | const bool channel_dimension_is_col = |
621 | params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; |
622 | |
623 | int col = params.start_col; |
  // Loop over columns in steps of the float block size (8); the remainder
  // (fewer than 8 columns) is handled after this loop.
625 | for (; col <= end_col - 8; col += 8) { |
626 | __m256 accum_data_v[8]; |
627 | |
628 | const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; |
629 | float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; |
630 | |
631 | for (int row = params.start_row; row < end_row; row += 8) { |
632 | const int residual_rows = std::min(end_row - row, 8); |
633 | |
634 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
635 | float* dst_ptr = dst_col_ptr + row; |
636 | |
637 | // Initialize with bias. |
638 | if (channel_dimension_is_col) { |
639 | const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment; |
640 | for (int j = 0; j < 8; ++j) { |
641 | accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j); |
642 | } |
643 | } else { |
644 | const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment; |
645 | const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr); |
646 | |
647 | for (int j = 0; j < 8; ++j) { |
648 | accum_data_v[j] = initial_accum_data; |
649 | } |
650 | } |
651 | |
652 | const float* lhs_ptr = lhs_col_ptr; |
653 | const float* rhs_ptr = rhs_col_ptr; |
654 | for (int d = 0; d < params.depth; ++d) { |
655 | const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); |
656 | // Load 8 RHS values, then use permute instructions to broadcast each |
657 | // value to a register. _mm256_permute2f128_ps is slow on AMD. |
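        // The _mm256_permute_ps immediates 0, 85, 170, 255 are 0b00000000,
        // 0b01010101, 0b10101010, 0b11111111: each replicates one of the four
        // floats across every position of both 128-bit lanes.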
658 | __m256 rhs0_3 = |
659 | _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr)); |
660 | __m256 rhs4_7 = |
661 | _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4)); |
662 | |
663 | const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); |
664 | accum_data_v[0] = intrin_utils::MulAdd<path>( |
665 | lhs_data, dup_rhs_element_0, accum_data_v[0]); |
666 | |
667 | const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); |
668 | accum_data_v[1] = intrin_utils::MulAdd<path>( |
669 | lhs_data, dup_rhs_element_1, accum_data_v[1]); |
670 | |
671 | const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); |
672 | accum_data_v[2] = intrin_utils::MulAdd<path>( |
673 | lhs_data, dup_rhs_element_2, accum_data_v[2]); |
674 | |
675 | const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); |
676 | accum_data_v[3] = intrin_utils::MulAdd<path>( |
677 | lhs_data, dup_rhs_element_3, accum_data_v[3]); |
678 | |
679 | const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); |
680 | accum_data_v[4] = intrin_utils::MulAdd<path>( |
681 | lhs_data, dup_rhs_element_4, accum_data_v[4]); |
682 | |
683 | const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); |
684 | accum_data_v[5] = intrin_utils::MulAdd<path>( |
685 | lhs_data, dup_rhs_element_5, accum_data_v[5]); |
686 | |
687 | const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); |
688 | accum_data_v[6] = intrin_utils::MulAdd<path>( |
689 | lhs_data, dup_rhs_element_6, accum_data_v[6]); |
690 | |
691 | const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); |
692 | accum_data_v[7] = intrin_utils::MulAdd<path>( |
693 | lhs_data, dup_rhs_element_7, accum_data_v[7]); |
694 | |
695 | lhs_ptr += 8; |
696 | rhs_ptr += 8; |
697 | } |
698 | |
699 | if (residual_rows == 8) { |
700 | for (int j = 0; j < 8; ++j) { |
701 | float* block_ptr = dst_ptr + j * dst_stride; |
702 | accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); |
703 | accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); |
704 | _mm256_storeu_ps(block_ptr, accum_data_v[j]); |
705 | } |
706 | } else { |
707 | for (int j = 0; j < 8; ++j) { |
708 | float* block_ptr = dst_ptr + j * dst_stride; |
709 | accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); |
710 | accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); |
711 | intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows, |
712 | accum_data_v[j]); |
713 | } |
714 | } |
715 | } // End row-block loop. |
716 | } // End col-block loop. |
717 | |
718 | if (col < end_col) { |
719 | // Remaining cols in [0, float block size). |
720 | RUY_DCHECK_GE(end_col - col, 0); |
721 | RUY_DCHECK_LT(end_col - col, 8); |
722 | |
723 | __m256 accum_data_v[8]; |
724 | |
725 | const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; |
726 | float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; |
727 | const int residual_cols = std::min(end_col - col, 8); |
728 | |
729 | for (int row = params.start_row; row < end_row; row += 8) { |
730 | const int residual_rows = std::min(end_row - row, 8); |
731 | |
732 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
733 | float* dst_ptr = dst_col_ptr + row; |
734 | |
735 | // Initialize with bias. |
736 | if (channel_dimension_is_col) { |
737 | const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment; |
738 | for (int j = 0; j < 8; ++j) { |
739 | accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j); |
740 | } |
741 | } else { |
742 | const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment; |
743 | const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr); |
744 | |
745 | for (int j = 0; j < 8; ++j) { |
746 | accum_data_v[j] = initial_accum_data; |
747 | } |
748 | } |
749 | |
750 | const float* lhs_ptr = lhs_col_ptr; |
751 | const float* rhs_ptr = rhs_col_ptr; |
752 | for (int d = 0; d < params.depth; ++d) { |
753 | const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); |
754 | |
755 | __m256 rhs0_3 = |
756 | _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr)); |
757 | __m256 rhs4_7 = |
758 | _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4)); |
759 | |
760 | const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); |
761 | accum_data_v[0] = intrin_utils::MulAdd<path>( |
762 | lhs_data, dup_rhs_element_0, accum_data_v[0]); |
763 | |
764 | const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); |
765 | accum_data_v[1] = intrin_utils::MulAdd<path>( |
766 | lhs_data, dup_rhs_element_1, accum_data_v[1]); |
767 | |
768 | const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); |
769 | accum_data_v[2] = intrin_utils::MulAdd<path>( |
770 | lhs_data, dup_rhs_element_2, accum_data_v[2]); |
771 | |
772 | const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); |
773 | accum_data_v[3] = intrin_utils::MulAdd<path>( |
774 | lhs_data, dup_rhs_element_3, accum_data_v[3]); |
775 | |
776 | const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); |
777 | accum_data_v[4] = intrin_utils::MulAdd<path>( |
778 | lhs_data, dup_rhs_element_4, accum_data_v[4]); |
779 | |
780 | const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); |
781 | accum_data_v[5] = intrin_utils::MulAdd<path>( |
782 | lhs_data, dup_rhs_element_5, accum_data_v[5]); |
783 | |
784 | const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); |
785 | accum_data_v[6] = intrin_utils::MulAdd<path>( |
786 | lhs_data, dup_rhs_element_6, accum_data_v[6]); |
787 | |
788 | const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); |
789 | accum_data_v[7] = intrin_utils::MulAdd<path>( |
790 | lhs_data, dup_rhs_element_7, accum_data_v[7]); |
791 | |
792 | lhs_ptr += 8; |
793 | rhs_ptr += 8; |
794 | } |
795 | |
796 | for (int j = 0; j < residual_cols; ++j) { |
797 | float* block_ptr = dst_ptr + j * dst_stride; |
798 | accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); |
799 | accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); |
800 | intrin_utils::mm256_n_storeu_ps<path>(block_ptr, residual_rows, |
801 | accum_data_v[j]); |
802 | } |
803 | } // End row-block loop. |
804 | } // End col-block terminal conditional. |
805 | } |
806 | |
807 | template <Path path> |
808 | inline void KernelFloatAvxCommonSingleCol( |
809 | const KernelParamsFloat<8, 8>& params) { |
810 | RUY_DCHECK_EQ(params.dst_cols, 1); |
811 | RUY_DCHECK_EQ(params.last_col, 0); |
812 | RUY_DCHECK_EQ(params.start_col, 0); |
813 | |
  // The lhs stride in params is given in bytes; shift right by 2 to convert
  // it to a float-element stride (sizeof(float) == 4).
815 | const std::int64_t lhs_stride = params.lhs_stride >> 2; |
816 | // |
817 | int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; |
  // AVX/AVX2 float block size = 8.
819 | const int end_row = std::min(params.dst_rows, params.last_row + 8); |
820 | |
821 | float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; |
822 | const float* adj_lhs_col_ptr = |
823 | params.lhs_base_ptr - params.start_row * lhs_stride; |
824 | const float* bias_col_ptr = params.bias; |
825 | |
826 | const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); |
827 | const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); |
828 | |
829 | __m256 accum_data_v; |
830 | |
831 | const float* rhs_col_ptr = params.rhs_base_ptr; |
832 | float* dst_col_ptr = adj_dst_col_ptr; |
833 | |
834 | int row = params.start_row; |
835 | for (; row <= end_row - 8; row += 8) { |
836 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
837 | float* dst_ptr = dst_col_ptr + row; |
838 | const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
839 | |
840 | // Initialize with bias. |
841 | accum_data_v = _mm256_loadu_ps(bias_ptr); |
842 | |
843 | const float* lhs_ptr = lhs_col_ptr; |
844 | const float* rhs_ptr = rhs_col_ptr; |
845 | int d = 0; |
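    // Walk the depth dimension four slices at a time (each slice consumes 8
    // lhs floats and 1 rhs float); the loop below handles any remainder.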
846 | for (; d <= params.depth - 4; d += 4) { |
847 | const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); |
848 | const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); |
849 | accum_data_v = intrin_utils::MulAdd<path>(lhs_data_0, dup_rhs_element_0, |
850 | accum_data_v); |
851 | const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); |
852 | const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); |
853 | accum_data_v = intrin_utils::MulAdd<path>(lhs_data_1, dup_rhs_element_1, |
854 | accum_data_v); |
855 | |
856 | const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); |
857 | const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); |
858 | accum_data_v = intrin_utils::MulAdd<path>(lhs_data_2, dup_rhs_element_2, |
859 | accum_data_v); |
860 | const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); |
861 | const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); |
862 | accum_data_v = intrin_utils::MulAdd<path>(lhs_data_3, dup_rhs_element_3, |
863 | accum_data_v); |
864 | lhs_ptr += 32; // Loaded 8 * 4 floats. |
865 | rhs_ptr += 32; |
866 | } |
867 | for (; d < params.depth; ++d) { |
868 | const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); |
869 | const float* rhs_data = rhs_ptr; |
870 | |
871 | const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); |
872 | accum_data_v = |
873 | intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v); |
874 | lhs_ptr += 8; |
875 | rhs_ptr += 8; |
876 | } |
877 | |
878 | accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); |
879 | accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); |
880 | _mm256_storeu_ps(dst_ptr, accum_data_v); |
881 | } // End row-block loop. |
882 | |
883 | if (row < end_row) { |
884 | const int residual_rows = end_row - row; |
885 | RUY_CHECK_GE(residual_rows, 1); |
886 | RUY_CHECK_LT(residual_rows, 8); |
887 | |
888 | const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
889 | float* dst_ptr = dst_col_ptr + row; |
890 | const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
891 | |
892 | // Initialize with bias. |
893 | accum_data_v = _mm256_loadu_ps(bias_ptr); |
894 | |
895 | const float* lhs_ptr = lhs_col_ptr; |
896 | const float* rhs_ptr = rhs_col_ptr; |
897 | for (int d = 0; d < params.depth; ++d) { |
898 | const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); |
899 | const float* rhs_data = rhs_ptr; |
900 | |
901 | const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); |
902 | accum_data_v = |
903 | intrin_utils::MulAdd<path>(lhs_data, dup_rhs_element_j, accum_data_v); |
904 | lhs_ptr += 8; |
905 | rhs_ptr += 8; |
906 | } |
907 | |
908 | accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); |
909 | accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); |
910 | intrin_utils::mm256_n_storeu_ps<path>(dst_ptr, residual_rows, accum_data_v); |
911 | } // End handling of residual rows. |
912 | } |
913 | } // namespace ruy |
914 | #endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM) |
915 | |
916 | #endif // RUY_RUY_KERNEL_X86_H_ |
917 | |