/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_PACK_X86_H_
#define RUY_RUY_PACK_X86_H_

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/check_macros.h"
#include "ruy/mat.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx)
RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma)
RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)

RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8)
RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16)

template <>
struct PackedTypeImpl<Path::kAvx, std::uint8_t> {
  using Type = std::int8_t;
};

template <>
struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
  using Type = std::int8_t;
};
template <>
struct PackedTypeImpl<Path::kAvx512, std::uint8_t> {
  using Type = std::int8_t;
};

// Note that source and zero buffers can be uint8 type, but in the packing
// function are reinterpreted as int8, and are XOR-ed with input_xor.
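// For example, a uint8 value v XOR-ed with input_xor == 0x80 becomes the
// int8 value v - 128 (200 ^ 0x80 == 72), which re-centers the uint8 range
// onto int8 so the same signed packing kernels serve both input types.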
void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
                             const std::int8_t* zerobuf, int src_stride,
                             int remaining_src_cols, int src_rows,
                             std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
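      // block_col & block_col_mask rounds block_col down to a multiple of
      // Layout::kCols (a power of two), i.e. to the start of the packed
      // block that this column group belongs to.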
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx2(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor,
                            const std::int8_t* zerobuf, int src_stride,
                            int remaining_src_cols, int src_rows,
                            std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[Layout::kCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           Layout::kCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf,
                             int src_stride, int remaining_src_cols,
                             int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
                float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols,
                              src_matrix.layout.rows, packed_ptr);
    }
  }
};

void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
                              int src_stride, int remaining_src_cols,
                              int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
                float, float, float, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX2 float)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols,
                               src_matrix.layout.rows, packed_ptr);
    }
  }
};

// Note that source and zero buffers can be uint8 type, but in the packing
// function are reinterpreted as int8, and are XOR-ed with input_xor.
void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
                               std::int8_t input_xor,
                               const std::int8_t* zerobuf, int src_stride,
                               int remaining_src_cols, int src_rows,
                               std::int8_t* packed_ptr, std::int32_t* sums_ptr);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kColMajor> {
  static_assert(std::is_same<Scalar, std::int8_t>::value ||
                    std::is_same<Scalar, std::uint8_t>::value,
                "");
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block.
  static constexpr std::int8_t kInputXor =
      std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;

  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 8-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
    std::int32_t* sums = packed_matrix->sums;
    Scalar zerobuf[kHalfLayoutCols * Layout::kRows];
    memset(zerobuf, packed_matrix->zero_point ^ kInputXor,
           kHalfLayoutCols * Layout::kRows * sizeof(Scalar));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      std::int8_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack8bitColMajorForAvx512(
          reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor,
          reinterpret_cast<const std::int8_t*>(zerobuf), src_stride,
          remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr);
    }
  }
};

void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
                                const std::int16_t* zerobuf, int src_stride,
                                int remaining_src_cols, int src_rows,
                                std::int16_t* packed_ptr,
                                std::int32_t* sums_ptr);

template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> {
  using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>;
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block.

  static void Run(Tuning, const Mat<std::int16_t>& src_matrix,
                  PMat<std::int16_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 16-bit)");

    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols);
    std::int32_t* sums = packed_matrix->sums;
    std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows];
    std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows,
              static_cast<int16_t>(packed_matrix->zero_point));
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      std::int32_t* sums_ptr = sums ? sums + block_col : nullptr;
      int src_stride = src_matrix.layout.stride;
      const std::int16_t* src_ptr =
          src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);
      std::int16_t* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride,
                                 remaining_src_cols, src_matrix.layout.rows,
                                 packed_ptr, sums_ptr);
    }
  }
};

void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
                                int src_stride, int remaining_src_cols,
                                int src_rows, float* packed_ptr);

template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
                float, float, float, Order::kColMajor> {
  static void Run(Tuning, const Mat<float>& src_matrix,
                  PMat<float>* packed_matrix, int start_col, int end_col) {
    profiler::ScopeLabel label("Pack (AVX-512 float)");
    using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
    RUY_DCHECK(IsColMajor(src_matrix.layout));
    RUY_DCHECK(IsColMajor(packed_matrix->layout));
    RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0);
    RUY_DCHECK_EQ(start_col % Layout::kCols, 0);
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remainder default inits to 0.0f.
    for (int block_col = start_col; block_col < end_col;
         block_col += Layout::kCols) {
      int src_stride = src_matrix.layout.stride;
      const float* src_ptr = src_matrix.data.get() + src_stride * block_col;
      int remaining_src_cols = src_matrix.layout.cols - block_col;

      static constexpr int block_col_mask = ~(Layout::kCols - 1);  // High bits.
      float* packed_ptr =
          packed_matrix->data +
          packed_matrix->layout.stride * (block_col & block_col_mask);
      PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride,
                                 remaining_src_cols, src_matrix.layout.rows,
                                 packed_ptr);
    }
  }
};

void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
                             int src_zero_point, std::int8_t* packed_ptr,
                             int packed_stride, int start_col, int end_col,
                             int src_cols, int block_row, int src_rows,
                             int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr),
                              src_stride, src_matrix.zero_point, packed_ptr,
                              packed_stride, start_col, end_col,
                              src_matrix.layout.cols, block_row,
                              src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride,
                            int src_zero_point, std::int8_t* packed_ptr,
                            int packed_stride, int start_col, int end_col,
                            int src_cols, int block_row, int src_rows,
                            int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
                std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (AVX 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 8;
      Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr),
                             src_stride, src_matrix.zero_point, packed_ptr,
                             packed_stride, start_col, end_col,
                             src_matrix.layout.cols, block_row,
                             src_matrix.layout.rows, kInputXor, sums);
    }
  }
};

void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
                               int src_zero_point, std::int8_t* packed_ptr,
                               int packed_stride, int start_col, int end_col,
                               int src_cols, int block_row, int src_rows,
                               int input_xor, std::int32_t* sums);

template <typename Scalar>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
                Scalar, std::int8_t, std::int32_t, Order::kRowMajor> {
  static void Run(Tuning, const Mat<Scalar>& src_matrix,
                  PMat<std::int8_t>* packed_matrix, int start_col,
                  int end_col) {
    profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)");
    RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor);
    static constexpr int kInputXor =
        std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
    std::int32_t* sums = packed_matrix->sums;
    std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col));
    int block_row = 0;
    for (; block_row < packed_matrix->layout.rows; block_row += 4) {
      int src_stride = src_matrix.layout.stride;
      int packed_stride = packed_matrix->layout.stride;
      const Scalar* src_ptr =
          src_matrix.data.get() + block_row * src_stride + start_col;
      std::int8_t* packed_ptr =
          packed_matrix->data + start_col * packed_stride + block_row * 16;
      Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr),
                                src_stride, src_matrix.zero_point, packed_ptr,
                                packed_stride, start_col, end_col,
                                src_matrix.layout.cols, block_row,
                                src_matrix.layout.rows, kInputXor, sums);
    }
  }
};
#endif  // RUY_PLATFORM_X86

}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {

template <Path path>
inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

template <Path path>
inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}

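// Generic fallback for CompareGreaterThan. Each Path is expected to supply
// its own specialization where this packer is instantiated (plain AVX has no
// 256-bit integer compare, so it has to be emulated there); reaching this
// unspecialized version would be a bug, hence the RUY_DCHECK below.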
template <Path path>
inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}

// Shared between AVX and AVX2+FMA.
template <Path path>
inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point,
                         const std::int8_t* addr) {
  RUY_DCHECK_LT(available_src_rows, 32);
  __m256i padded_data;

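  // For example, with available_src_rows == 20, the low 16 bytes are loaded
  // directly, the next 4 bytes are memcpy'd over a register pre-filled with
  // zero_point, and bytes 20..31 of the result therefore hold zero_point.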
  if (available_src_rows >= 16) {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr));
    memcpy(&load_hi, addr + 16, available_src_rows - 16);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  } else {
    __m128i load_hi = _mm_set1_epi8(zero_point);
    __m128i load_lo = load_hi;
    memcpy(&load_lo, addr, available_src_rows);
    padded_data = _mm256_set_m128i(load_hi, load_lo);
  }
  return padded_data;
}

}  // namespace.

template <typename PackImpl, Path path>
inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr,
                                                const float* zerobuf,
                                                int src_stride,
                                                int remaining_src_cols,
                                                int src_rows, float* packed_ptr,
                                                float* trailing_buf) {
  RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8);
  RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1);

  // This packing amounts to transposition of 8x8 blocks.
  static constexpr int kPackCols = 8;  // Source cols packed together.
  static constexpr int kPackRows = 8;  // Short input is padded.
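  // For a full 8x8 block the net effect is
  //   packed_ptr[i * kPackCols + j] == source row (k + i), packed column j,
  // i.e. each run of 8 consecutive packed floats holds one source row across
  // the 8 packed columns.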

  const float* src_ptr0 = src_ptr;
  const float* src_ptr1 = src_ptr0 + src_stride;
  const float* src_ptr2 = src_ptr1 + src_stride;
  const float* src_ptr3 = src_ptr2 + src_stride;
  const float* src_ptr4 = src_ptr3 + src_stride;
  const float* src_ptr5 = src_ptr4 + src_stride;
  const float* src_ptr6 = src_ptr5 + src_stride;
  const float* src_ptr7 = src_ptr6 + src_stride;
  std::int64_t src_inc0 = 8;
  std::int64_t src_inc1 = 8;
  std::int64_t src_inc2 = 8;
  std::int64_t src_inc3 = 8;
  std::int64_t src_inc4 = 8;
  std::int64_t src_inc5 = 8;
  std::int64_t src_inc6 = 8;
  std::int64_t src_inc7 = 8;
  // Handle cases where the source does not have kPackCols (8) columns.
  if (remaining_src_cols < kPackCols) {
    if (remaining_src_cols <= 0) {
      src_ptr0 = zerobuf;
      src_inc0 = 0;
    }
    if (remaining_src_cols <= 1) {
      src_ptr1 = zerobuf;
      src_inc1 = 0;
    }
    if (remaining_src_cols <= 2) {
      src_ptr2 = zerobuf;
      src_inc2 = 0;
    }
    if (remaining_src_cols <= 3) {
      src_ptr3 = zerobuf;
      src_inc3 = 0;
    }
    if (remaining_src_cols <= 4) {
      src_ptr4 = zerobuf;
      src_inc4 = 0;
    }
    if (remaining_src_cols <= 5) {
      src_ptr5 = zerobuf;
      src_inc5 = 0;
    }
    if (remaining_src_cols <= 6) {
      src_ptr6 = zerobuf;
      src_inc6 = 0;
    }
    src_ptr7 = zerobuf;
    src_inc7 = 0;
  }

  for (int k = 0; k < src_rows; k += kPackRows) {
    const int available_src_rows = src_rows - k;
    // Effectively,
    //   available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
    // but treat each case separately.
    if (available_src_rows >= kPackRows) {
      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_loadu_ps(src_ptr0);
      t4 = _mm256_loadu_ps(src_ptr4);
      t1 = _mm256_loadu_ps(src_ptr1);
      t5 = _mm256_loadu_ps(src_ptr5);
      t2 = _mm256_loadu_ps(src_ptr2);
      t6 = _mm256_loadu_ps(src_ptr6);
      t3 = _mm256_loadu_ps(src_ptr3);
      t7 = _mm256_loadu_ps(src_ptr7);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128-bit), operating *between* AVX lanes. For instance
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around 128-bit lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      r7 = _mm256_permute2f128_ps(t3, t7, 0x31);

      _mm256_storeu_ps(packed_ptr + 0 * 8, r0);
      _mm256_storeu_ps(packed_ptr + 2 * 8, r4);
      _mm256_storeu_ps(packed_ptr + 4 * 8, r1);
      _mm256_storeu_ps(packed_ptr + 6 * 8, r5);
      _mm256_storeu_ps(packed_ptr + 1 * 8, r2);
      _mm256_storeu_ps(packed_ptr + 3 * 8, r6);
      _mm256_storeu_ps(packed_ptr + 5 * 8, r3);
      _mm256_storeu_ps(packed_ptr + 7 * 8, r7);
    } else if (available_src_rows > 0) {
      const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
      const __m256i row_mask_v = CompareGreaterThan<path>(
          _mm256_set1_epi32(available_src_rows), series);

      __m256 t0, t1, t2, t3, t4, t5, t6, t7;
      __m256 r0, r1, r2, r3, r4, r5, r6, r7;

      t0 = _mm256_maskload_ps(src_ptr0, row_mask_v);
      t4 = _mm256_maskload_ps(src_ptr4, row_mask_v);
      t1 = _mm256_maskload_ps(src_ptr1, row_mask_v);
      t5 = _mm256_maskload_ps(src_ptr5, row_mask_v);
      t2 = _mm256_maskload_ps(src_ptr2, row_mask_v);
      t6 = _mm256_maskload_ps(src_ptr6, row_mask_v);
      t3 = _mm256_maskload_ps(src_ptr3, row_mask_v);
      t7 = _mm256_maskload_ps(src_ptr7, row_mask_v);

      r0 = _mm256_unpacklo_ps(t0, t1);
      r4 = _mm256_unpacklo_ps(t4, t5);
      r2 = _mm256_unpackhi_ps(t0, t1);
      r6 = _mm256_unpackhi_ps(t4, t5);
      r1 = _mm256_unpacklo_ps(t2, t3);
      r5 = _mm256_unpacklo_ps(t6, t7);
      r3 = _mm256_unpackhi_ps(t2, t3);
      r7 = _mm256_unpackhi_ps(t6, t7);

      t0 = Mm256UnpackloPsx2<path>(r0, r1);
      t4 = Mm256UnpackloPsx2<path>(r4, r5);
      t2 = Mm256UnpackhiPsx2<path>(r0, r1);
      t6 = Mm256UnpackhiPsx2<path>(r4, r5);
      t1 = Mm256UnpackloPsx2<path>(r2, r3);
      t5 = Mm256UnpackloPsx2<path>(r6, r7);
      t3 = Mm256UnpackhiPsx2<path>(r2, r3);
      t7 = Mm256UnpackhiPsx2<path>(r6, r7);

      // As above: interleave by 4 and then 8 bytes *within* lanes, then by
      // 16 bytes (128-bit) *between* AVX lanes.
      r0 = _mm256_permute2f128_ps(t0, t4, 0x20);
      r4 = _mm256_permute2f128_ps(t1, t5, 0x20);
      r1 = _mm256_permute2f128_ps(t0, t4, 0x31);
      r5 = _mm256_permute2f128_ps(t1, t5, 0x31);
      r2 = _mm256_permute2f128_ps(t2, t6, 0x20);
      r6 = _mm256_permute2f128_ps(t3, t7, 0x20);
      r3 = _mm256_permute2f128_ps(t2, t6, 0x31);
      // r7 is no longer needed.

      _mm256_storeu_ps(trailing_buf + 0 * 8, r0);
      _mm256_storeu_ps(trailing_buf + 2 * 8, r4);
      _mm256_storeu_ps(trailing_buf + 4 * 8, r1);
      _mm256_storeu_ps(trailing_buf + 6 * 8, r5);
      _mm256_storeu_ps(trailing_buf + 1 * 8, r2);
      _mm256_storeu_ps(trailing_buf + 3 * 8, r6);
      _mm256_storeu_ps(trailing_buf + 5 * 8, r3);
      // No store to (trailing_buf + 7 * 8): that space is not allocated.
    }

    packed_ptr += kPackRows * kPackCols;
    src_ptr0 += src_inc0;
    src_ptr1 += src_inc1;
    src_ptr2 += src_inc2;
    src_ptr3 += src_inc3;
    src_ptr4 += src_inc4;
    src_ptr5 += src_inc5;
    src_ptr6 += src_inc6;
    src_ptr7 += src_inc7;
  }
}

}  // namespace ruy
#endif  // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_PACK_X86_H_