1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef RUY_RUY_PACK_X86_H_ |
17 | #define RUY_RUY_PACK_X86_H_ |
18 | |
19 | #include <algorithm> |
20 | #include <cstdint> |
21 | #include <cstring> |
22 | #include <type_traits> |
23 | |
24 | #include "ruy/check_macros.h" |
25 | #include "ruy/mat.h" |
26 | #include "ruy/opt_set.h" |
27 | #include "ruy/pack_common.h" |
28 | #include "ruy/path.h" |
29 | #include "ruy/platform.h" |
30 | #include "ruy/profiler/instrumentation.h" |
31 | #include "ruy/tune.h" |
32 | |
33 | namespace ruy { |
34 | |
35 | #if RUY_PLATFORM_X86 |
36 | |
37 | RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx) |
38 | RUY_INHERIT_PACK(Path::kAvx, Path::kAvx2Fma) |
39 | RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512) |
40 | |
41 | RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx2Fma, 8) |
42 | RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK(Path::kAvx512, 16) |
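
// The RUY_INHERIT_PACK chain above makes each x86 path fall back to the
// next-slower path's packing code wherever this header provides no
// specialization, and RUY_USE_MEMCPY_ROWMAJOR_FLOAT_PACK (see pack_common.h)
// lets a row-major float source whose row width matches the kernel layout
// (8 for AVX2+FMA, 16 for AVX-512) be packed with a plain per-row memcpy.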
43 | |
44 | template <> |
45 | struct PackedTypeImpl<Path::kAvx, std::uint8_t> { |
46 | using Type = std::int8_t; |
47 | }; |
48 | |
49 | template <> |
50 | struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> { |
51 | using Type = std::int8_t; |
52 | }; |
53 | template <> |
54 | struct PackedTypeImpl<Path::kAvx512, std::uint8_t> { |
55 | using Type = std::int8_t; |
56 | }; |
57 | |
// Note that the source and zero buffers may hold uint8 data, but the packing
// function reinterprets them as int8 and XORs them with input_xor (0x80 for
// uint8 inputs, which maps a uint8 value v to the int8 value v - 128).
60 | void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, |
61 | const std::int8_t* zerobuf, int src_stride, |
62 | int remaining_src_cols, int src_rows, |
63 | std::int8_t* packed_ptr, std::int32_t* sums_ptr); |
64 | |
65 | template <typename Scalar> |
66 | struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, |
67 | Scalar, std::int8_t, std::int32_t, Order::kColMajor> { |
68 | static_assert(std::is_same<Scalar, std::int8_t>::value || |
69 | std::is_same<Scalar, std::uint8_t>::value, |
70 | "" ); |
71 | using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
72 | static constexpr std::int8_t kInputXor = |
73 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
74 | |
75 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
76 | PMat<std::int8_t>* packed_matrix, int start_col, |
77 | int end_col) { |
78 | profiler::ScopeLabel label("Pack (AVX2 8-bit)" ); |
79 | |
80 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
81 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
82 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
83 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
84 | std::int32_t* sums = packed_matrix->sums; |
85 | Scalar zerobuf[Layout::kCols * Layout::kRows]; |
86 | memset(zerobuf, packed_matrix->zero_point ^ kInputXor, |
87 | Layout::kCols * Layout::kRows * sizeof(Scalar)); |
88 | for (int block_col = start_col; block_col < end_col; |
89 | block_col += Layout::kCols) { |
90 | std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; |
91 | int src_stride = src_matrix.layout.stride; |
92 | const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; |
93 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
94 | |
      // Mask off the low bits so that packed blocks are addressed by the
      // start of their 8-column group. This is a no-op here, since block_col
      // is always a multiple of kCols, which is a power of two.
      static constexpr int block_col_mask = ~(Layout::kCols - 1);
96 | std::int8_t* packed_ptr = |
97 | packed_matrix->data + |
98 | packed_matrix->layout.stride * (block_col & block_col_mask); |
99 | Pack8bitColMajorForAvx2( |
100 | reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, |
101 | reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, |
102 | remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); |
103 | } |
104 | } |
105 | }; |
106 | |
107 | void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor, |
108 | const std::int8_t* zerobuf, int src_stride, |
109 | int remaining_src_cols, int src_rows, |
110 | std::int8_t* packed_ptr, std::int32_t* sums_ptr); |
111 | |
112 | template <typename Scalar> |
113 | struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar, |
114 | std::int8_t, std::int32_t, Order::kColMajor> { |
115 | static_assert(std::is_same<Scalar, std::int8_t>::value || |
116 | std::is_same<Scalar, std::uint8_t>::value, |
117 | "" ); |
118 | using Layout = FixedKernelLayout<Order::kColMajor, 4, 8>; |
119 | static constexpr std::int8_t kInputXor = |
120 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
121 | |
122 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
123 | PMat<std::int8_t>* packed_matrix, int start_col, |
124 | int end_col) { |
125 | profiler::ScopeLabel label("Pack (AVX 8-bit)" ); |
126 | |
127 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
128 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
129 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
130 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
131 | std::int32_t* sums = packed_matrix->sums; |
132 | Scalar zerobuf[Layout::kCols * Layout::kRows]; |
133 | memset(zerobuf, packed_matrix->zero_point ^ kInputXor, |
134 | Layout::kCols * Layout::kRows * sizeof(Scalar)); |
135 | for (int block_col = start_col; block_col < end_col; |
136 | block_col += Layout::kCols) { |
137 | std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; |
138 | int src_stride = src_matrix.layout.stride; |
139 | const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; |
140 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
141 | |
142 | static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. |
143 | std::int8_t* packed_ptr = |
144 | packed_matrix->data + |
145 | packed_matrix->layout.stride * (block_col & block_col_mask); |
146 | Pack8bitColMajorForAvx( |
147 | reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, |
148 | reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, |
149 | remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); |
150 | } |
151 | } |
152 | }; |
153 | |
154 | void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf, |
155 | int src_stride, int remaining_src_cols, |
156 | int src_rows, float* packed_ptr); |
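
// Entry points such as PackFloatColMajorForAvx are defined in the
// corresponding per-path pack_*.cc files; the AVX and AVX2+FMA versions are
// expected to route through PackFloatColMajorForAvxCommonPacker, defined at
// the end of this header.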
157 | |
158 | template <> |
159 | struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, |
160 | float, float, Order::kColMajor> { |
161 | using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
162 | static void Run(Tuning, const Mat<float>& src_matrix, |
163 | PMat<float>* packed_matrix, int start_col, int end_col) { |
164 | profiler::ScopeLabel label("Pack (AVX float)" ); |
165 | |
166 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
167 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
168 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
169 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
172 | for (int block_col = start_col; block_col < end_col; |
173 | block_col += Layout::kCols) { |
174 | int src_stride = src_matrix.layout.stride; |
175 | const float* src_ptr = src_matrix.data.get() + src_stride * block_col; |
176 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
177 | |
178 | static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. |
179 | float* packed_ptr = |
180 | packed_matrix->data + |
181 | packed_matrix->layout.stride * (block_col & block_col_mask); |
182 | PackFloatColMajorForAvx(src_ptr, zerobuf, src_stride, remaining_src_cols, |
183 | src_matrix.layout.rows, packed_ptr); |
184 | } |
185 | } |
186 | }; |
187 | |
188 | void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf, |
189 | int src_stride, int remaining_src_cols, |
190 | int src_rows, float* packed_ptr); |
191 | |
192 | template <> |
193 | struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, |
194 | float, float, float, Order::kColMajor> { |
195 | using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>; |
196 | static void Run(Tuning, const Mat<float>& src_matrix, |
197 | PMat<float>* packed_matrix, int start_col, int end_col) { |
198 | profiler::ScopeLabel label("Pack (AVX2 float)" ); |
199 | |
200 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
201 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
202 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
203 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
206 | for (int block_col = start_col; block_col < end_col; |
207 | block_col += Layout::kCols) { |
208 | int src_stride = src_matrix.layout.stride; |
209 | const float* src_ptr = src_matrix.data.get() + src_stride * block_col; |
210 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
211 | |
212 | static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. |
213 | float* packed_ptr = |
214 | packed_matrix->data + |
215 | packed_matrix->layout.stride * (block_col & block_col_mask); |
216 | PackFloatColMajorForAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, |
217 | src_matrix.layout.rows, packed_ptr); |
218 | } |
219 | } |
220 | }; |
221 | |
// Note that the source and zero buffers may hold uint8 data, but the packing
// function reinterprets them as int8 and XORs them with input_xor (0x80 for
// uint8 inputs, which maps a uint8 value v to the int8 value v - 128).
224 | void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, |
225 | std::int8_t input_xor, |
226 | const std::int8_t* zerobuf, int src_stride, |
227 | int remaining_src_cols, int src_rows, |
228 | std::int8_t* packed_ptr, std::int32_t* sums_ptr); |
229 | |
230 | template <typename Scalar> |
231 | struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, |
232 | Scalar, std::int8_t, std::int32_t, Order::kColMajor> { |
233 | static_assert(std::is_same<Scalar, std::int8_t>::value || |
234 | std::is_same<Scalar, std::uint8_t>::value, |
235 | "" ); |
236 | using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block. The AVX-512 packer handles
          // the 16-column block as two 8-column halves, so zerobuf only
          // needs kHalfLayoutCols columns.
239 | static constexpr std::int8_t kInputXor = |
240 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
241 | |
242 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
243 | PMat<std::int8_t>* packed_matrix, int start_col, |
244 | int end_col) { |
245 | profiler::ScopeLabel label("Pack (AVX-512 8-bit)" ); |
246 | |
247 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
248 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
249 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
250 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
251 | RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); |
252 | std::int32_t* sums = packed_matrix->sums; |
253 | Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; |
254 | memset(zerobuf, packed_matrix->zero_point ^ kInputXor, |
255 | kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); |
256 | for (int block_col = start_col; block_col < end_col; |
257 | block_col += Layout::kCols) { |
258 | std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; |
259 | int src_stride = src_matrix.layout.stride; |
260 | const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; |
261 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
262 | |
263 | static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. |
264 | std::int8_t* packed_ptr = |
265 | packed_matrix->data + |
266 | packed_matrix->layout.stride * (block_col & block_col_mask); |
267 | Pack8bitColMajorForAvx512( |
268 | reinterpret_cast<const std::int8_t*>(src_ptr), kInputXor, |
269 | reinterpret_cast<const std::int8_t*>(zerobuf), src_stride, |
270 | remaining_src_cols, src_matrix.layout.rows, packed_ptr, sums_ptr); |
271 | } |
272 | } |
273 | }; |
274 | |
275 | void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, |
276 | const std::int16_t* zerobuf, int src_stride, |
277 | int remaining_src_cols, int src_rows, |
278 | std::int16_t* packed_ptr, |
279 | std::int32_t* sums_ptr); |
280 | |
281 | template <> |
282 | struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, |
283 | std::int16_t, std::int16_t, std::int32_t, Order::kColMajor> { |
284 | using Layout = FixedKernelLayout<Order::kColMajor, 4, 16>; |
  static constexpr int kHalfLayoutCols =
      8;  // Half the number of cols in a block, as in the 8-bit packer above.
287 | |
288 | static void Run(Tuning, const Mat<std::int16_t>& src_matrix, |
289 | PMat<std::int16_t>* packed_matrix, int start_col, |
290 | int end_col) { |
291 | profiler::ScopeLabel label("Pack (AVX-512 16-bit)" ); |
292 | |
293 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
294 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
295 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
296 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
297 | RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); |
298 | std::int32_t* sums = packed_matrix->sums; |
299 | std::int16_t zerobuf[kHalfLayoutCols * Layout::kRows]; |
    std::fill(zerobuf, zerobuf + kHalfLayoutCols * Layout::kRows,
              static_cast<std::int16_t>(packed_matrix->zero_point));
302 | for (int block_col = start_col; block_col < end_col; |
303 | block_col += Layout::kCols) { |
304 | std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; |
305 | int src_stride = src_matrix.layout.stride; |
306 | const std::int16_t* src_ptr = |
307 | src_matrix.data.get() + src_stride * block_col; |
308 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
309 | |
310 | static constexpr int block_col_mask = ~(Layout::kCols - 1); |
311 | std::int16_t* packed_ptr = |
312 | packed_matrix->data + |
313 | packed_matrix->layout.stride * (block_col & block_col_mask); |
314 | Pack16bitColMajorForAvx512(src_ptr, zerobuf, src_stride, |
315 | remaining_src_cols, src_matrix.layout.rows, |
316 | packed_ptr, sums_ptr); |
317 | } |
318 | } |
319 | }; |
320 | |
321 | void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, |
322 | int src_stride, int remaining_src_cols, |
323 | int src_rows, float* packed_ptr); |
324 | |
325 | template <> |
326 | struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>, |
327 | float, float, float, Order::kColMajor> { |
328 | static void Run(Tuning, const Mat<float>& src_matrix, |
329 | PMat<float>* packed_matrix, int start_col, int end_col) { |
330 | profiler::ScopeLabel label("Pack (AVX-512 float)" ); |
331 | using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>; |
332 | RUY_DCHECK(IsColMajor(src_matrix.layout)); |
333 | RUY_DCHECK(IsColMajor(packed_matrix->layout)); |
334 | RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); |
335 | RUY_DCHECK_EQ(start_col % Layout::kCols, 0); |
    const float zerobuf[Layout::kCols] = {
        0.0f};  // Remaining elements are zero-initialized.
338 | for (int block_col = start_col; block_col < end_col; |
339 | block_col += Layout::kCols) { |
340 | int src_stride = src_matrix.layout.stride; |
341 | const float* src_ptr = src_matrix.data.get() + src_stride * block_col; |
342 | int remaining_src_cols = src_matrix.layout.cols - block_col; |
343 | |
344 | static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. |
345 | float* packed_ptr = |
346 | packed_matrix->data + |
347 | packed_matrix->layout.stride * (block_col & block_col_mask); |
348 | PackFloatColMajorForAvx512(src_ptr, zerobuf, src_stride, |
349 | remaining_src_cols, src_matrix.layout.rows, |
350 | packed_ptr); |
351 | } |
352 | } |
353 | }; |
354 | |
355 | void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride, |
356 | int src_zero_point, std::int8_t* packed_ptr, |
357 | int packed_stride, int start_col, int end_col, |
358 | int src_cols, int block_row, int src_rows, |
359 | int input_xor, std::int32_t* sums); |
360 | |
361 | template <typename Scalar> |
362 | struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, |
363 | Scalar, std::int8_t, std::int32_t, Order::kRowMajor> { |
364 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
365 | PMat<std::int8_t>* packed_matrix, int start_col, |
366 | int end_col) { |
367 | profiler::ScopeLabel label("Pack (kAvx2Fma 8bit row-major)" ); |
368 | RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); |
369 | static constexpr int kInputXor = |
370 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
371 | std::int32_t* sums = packed_matrix->sums; |
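    // Row-major packing walks the matrix in 4-row strips and accumulates the
    // per-column sums incrementally across calls, so they are zeroed up
    // front.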
372 | std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); |
373 | int block_row = 0; |
374 | for (; block_row < packed_matrix->layout.rows; block_row += 4) { |
375 | int src_stride = src_matrix.layout.stride; |
376 | int packed_stride = packed_matrix->layout.stride; |
377 | const Scalar* src_ptr = |
378 | src_matrix.data.get() + block_row * src_stride + start_col; |
379 | std::int8_t* packed_ptr = |
380 | packed_matrix->data + start_col * packed_stride + block_row * 8; |
381 | Pack8bitRowMajorForAvx2(reinterpret_cast<const std::uint8_t*>(src_ptr), |
382 | src_stride, src_matrix.zero_point, packed_ptr, |
383 | packed_stride, start_col, end_col, |
384 | src_matrix.layout.cols, block_row, |
385 | src_matrix.layout.rows, kInputXor, sums); |
386 | } |
387 | } |
388 | }; |
389 | |
390 | void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride, |
391 | int src_zero_point, std::int8_t* packed_ptr, |
392 | int packed_stride, int start_col, int end_col, |
393 | int src_cols, int block_row, int src_rows, |
394 | int input_xor, std::int32_t* sums); |
395 | |
396 | template <typename Scalar> |
397 | struct PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar, |
398 | std::int8_t, std::int32_t, Order::kRowMajor> { |
399 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
400 | PMat<std::int8_t>* packed_matrix, int start_col, |
401 | int end_col) { |
402 | profiler::ScopeLabel label("Pack (AVX 8bit row-major)" ); |
403 | RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); |
404 | static constexpr int kInputXor = |
405 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
406 | std::int32_t* sums = packed_matrix->sums; |
407 | std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); |
408 | int block_row = 0; |
409 | for (; block_row < packed_matrix->layout.rows; block_row += 4) { |
410 | int src_stride = src_matrix.layout.stride; |
411 | int packed_stride = packed_matrix->layout.stride; |
412 | const Scalar* src_ptr = |
413 | src_matrix.data.get() + block_row * src_stride + start_col; |
414 | std::int8_t* packed_ptr = |
415 | packed_matrix->data + start_col * packed_stride + block_row * 8; |
416 | Pack8bitRowMajorForAvx(reinterpret_cast<const std::uint8_t*>(src_ptr), |
417 | src_stride, src_matrix.zero_point, packed_ptr, |
418 | packed_stride, start_col, end_col, |
419 | src_matrix.layout.cols, block_row, |
420 | src_matrix.layout.rows, kInputXor, sums); |
421 | } |
422 | } |
423 | }; |
424 | |
425 | void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride, |
426 | int src_zero_point, std::int8_t* packed_ptr, |
427 | int packed_stride, int start_col, int end_col, |
428 | int src_cols, int block_row, int src_rows, |
429 | int input_xor, std::int32_t* sums); |
430 | |
431 | template <typename Scalar> |
432 | struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, |
433 | Scalar, std::int8_t, std::int32_t, Order::kRowMajor> { |
434 | static void Run(Tuning, const Mat<Scalar>& src_matrix, |
435 | PMat<std::int8_t>* packed_matrix, int start_col, |
436 | int end_col) { |
437 | profiler::ScopeLabel label("Pack (kAvx512 8bit row-major)" ); |
438 | RUY_DCHECK_EQ(src_matrix.layout.order, Order::kRowMajor); |
439 | static constexpr int kInputXor = |
440 | std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80; |
441 | std::int32_t* sums = packed_matrix->sums; |
442 | std::memset(sums + start_col, 0, sizeof(sums[0]) * (end_col - start_col)); |
443 | int block_row = 0; |
444 | for (; block_row < packed_matrix->layout.rows; block_row += 4) { |
445 | int src_stride = src_matrix.layout.stride; |
446 | int packed_stride = packed_matrix->layout.stride; |
447 | const Scalar* src_ptr = |
448 | src_matrix.data.get() + block_row * src_stride + start_col; |
449 | std::int8_t* packed_ptr = |
450 | packed_matrix->data + start_col * packed_stride + block_row * 16; |
451 | Pack8bitRowMajorForAvx512(reinterpret_cast<const std::uint8_t*>(src_ptr), |
452 | src_stride, src_matrix.zero_point, packed_ptr, |
453 | packed_stride, start_col, end_col, |
454 | src_matrix.layout.cols, block_row, |
455 | src_matrix.layout.rows, kInputXor, sums); |
456 | } |
457 | } |
458 | }; |
459 | #endif // RUY_PLATFORM_X86 |
460 | |
461 | } // namespace ruy |
462 | |
463 | #if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)) |
464 | |
465 | #include <immintrin.h> // IWYU pragma: keep |
466 | |
467 | namespace ruy { |
468 | namespace { |
469 | |
// Interleaves the low 64-bit halves of each 128-bit lane of a and b,
// implemented by reinterpreting the floats as doubles.
template <Path path>
inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}
475 | |
// Interleaves the high 64-bit halves of each 128-bit lane of a and b,
// implemented by reinterpreting the floats as doubles.
template <Path path>
inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) {
  return _mm256_castpd_ps(
      _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b)));
}
481 | |
// Dummy placeholder that must never be called: each x86 path is expected to
// provide its own specialization (plain AVX lacks _mm256_cmpgt_epi32, which
// requires AVX2).
template <Path path>
inline __m256i CompareGreaterThan(const __m256i&, const __m256i&) {
  RUY_DCHECK(false);
  return _mm256_set1_epi32(0);
}
487 | |
488 | // Shared between AVX and AVX2+FMA. |
489 | template <Path path> |
490 | inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, |
491 | const std::int8_t* addr) { |
492 | RUY_DCHECK_LT(available_src_rows, 32); |
493 | __m256i padded_data; |
494 | |
495 | if (available_src_rows >= 16) { |
496 | __m128i load_hi = _mm_set1_epi8(zero_point); |
497 | __m128i load_lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(addr)); |
498 | memcpy(&load_hi, addr + 16, available_src_rows - 16); |
499 | padded_data = _mm256_set_m128i(load_hi, load_lo); |
500 | } else { |
501 | __m128i load_hi = _mm_set1_epi8(zero_point); |
502 | __m128i load_lo = load_hi; |
503 | memcpy(&load_lo, addr, available_src_rows); |
504 | padded_data = _mm256_set_m128i(load_hi, load_lo); |
505 | } |
506 | return padded_data; |
507 | } |
508 | |
}  // namespace
510 | |
511 | template <typename PackImpl, Path path> |
512 | inline void PackFloatColMajorForAvxCommonPacker(const float* src_ptr, |
513 | const float* zerobuf, |
514 | int src_stride, |
515 | int remaining_src_cols, |
516 | int src_rows, float* packed_ptr, |
517 | float* trailing_buf) { |
518 | RUY_DCHECK_EQ(PackImpl::Layout::kCols, 8); |
519 | RUY_DCHECK_EQ(PackImpl::Layout::kRows, 1); |
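
  // Full 8-row blocks are written to packed_ptr; the final partial block, if
  // any, goes to the caller-provided trailing_buf, which the caller is
  // expected to merge into the packed matrix afterwards.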
520 | |
521 | // This packing amounts to transposition of 8x8 blocks. |
522 | static constexpr int kPackCols = 8; // Source cols packed together. |
523 | static constexpr int kPackRows = 8; // Short input is padded. |
524 | |
525 | const float* src_ptr0 = src_ptr; |
526 | const float* src_ptr1 = src_ptr0 + src_stride; |
527 | const float* src_ptr2 = src_ptr1 + src_stride; |
528 | const float* src_ptr3 = src_ptr2 + src_stride; |
529 | const float* src_ptr4 = src_ptr3 + src_stride; |
530 | const float* src_ptr5 = src_ptr4 + src_stride; |
531 | const float* src_ptr6 = src_ptr5 + src_stride; |
532 | const float* src_ptr7 = src_ptr6 + src_stride; |
533 | std::int64_t src_inc0 = 8; |
534 | std::int64_t src_inc1 = 8; |
535 | std::int64_t src_inc2 = 8; |
536 | std::int64_t src_inc3 = 8; |
537 | std::int64_t src_inc4 = 8; |
538 | std::int64_t src_inc5 = 8; |
539 | std::int64_t src_inc6 = 8; |
540 | std::int64_t src_inc7 = 8; |
  // Handle cases where the source has fewer than kPackCols (8) columns.
542 | if (remaining_src_cols < kPackCols) { |
543 | if (remaining_src_cols <= 0) { |
544 | src_ptr0 = zerobuf; |
545 | src_inc0 = 0; |
546 | } |
547 | if (remaining_src_cols <= 1) { |
548 | src_ptr1 = zerobuf; |
549 | src_inc1 = 0; |
550 | } |
551 | if (remaining_src_cols <= 2) { |
552 | src_ptr2 = zerobuf; |
553 | src_inc2 = 0; |
554 | } |
555 | if (remaining_src_cols <= 3) { |
556 | src_ptr3 = zerobuf; |
557 | src_inc3 = 0; |
558 | } |
559 | if (remaining_src_cols <= 4) { |
560 | src_ptr4 = zerobuf; |
561 | src_inc4 = 0; |
562 | } |
563 | if (remaining_src_cols <= 5) { |
564 | src_ptr5 = zerobuf; |
565 | src_inc5 = 0; |
566 | } |
567 | if (remaining_src_cols <= 6) { |
568 | src_ptr6 = zerobuf; |
569 | src_inc6 = 0; |
570 | } |
571 | src_ptr7 = zerobuf; |
572 | src_inc7 = 0; |
573 | } |
574 | |
575 | for (int k = 0; k < src_rows; k += kPackRows) { |
576 | const int available_src_rows = src_rows - k; |
    // Effectively,
    // available_src_rows = std::max(0, std::min(kPackRows, src_rows - k));
    // but each case is treated separately.
580 | if (available_src_rows >= kPackRows) { |
581 | __m256 t0, t1, t2, t3, t4, t5, t6, t7; |
582 | __m256 r0, r1, r2, r3, r4, r5, r6, r7; |
583 | |
584 | t0 = _mm256_loadu_ps(src_ptr0); |
585 | t4 = _mm256_loadu_ps(src_ptr4); |
586 | t1 = _mm256_loadu_ps(src_ptr1); |
587 | t5 = _mm256_loadu_ps(src_ptr5); |
588 | t2 = _mm256_loadu_ps(src_ptr2); |
589 | t6 = _mm256_loadu_ps(src_ptr6); |
590 | t3 = _mm256_loadu_ps(src_ptr3); |
591 | t7 = _mm256_loadu_ps(src_ptr7); |
592 | |
593 | r0 = _mm256_unpacklo_ps(t0, t1); |
594 | r4 = _mm256_unpacklo_ps(t4, t5); |
595 | r2 = _mm256_unpackhi_ps(t0, t1); |
596 | r6 = _mm256_unpackhi_ps(t4, t5); |
597 | r1 = _mm256_unpacklo_ps(t2, t3); |
598 | r5 = _mm256_unpacklo_ps(t6, t7); |
599 | r3 = _mm256_unpackhi_ps(t2, t3); |
600 | r7 = _mm256_unpackhi_ps(t6, t7); |
601 | |
602 | t0 = Mm256UnpackloPsx2<path>(r0, r1); |
603 | t4 = Mm256UnpackloPsx2<path>(r4, r5); |
604 | t2 = Mm256UnpackhiPsx2<path>(r0, r1); |
605 | t6 = Mm256UnpackhiPsx2<path>(r4, r5); |
606 | t1 = Mm256UnpackloPsx2<path>(r2, r3); |
607 | t5 = Mm256UnpackloPsx2<path>(r6, r7); |
608 | t3 = Mm256UnpackhiPsx2<path>(r2, r3); |
609 | t7 = Mm256UnpackhiPsx2<path>(r6, r7); |
610 | |
      // The preceding sets of rearrangement operations interleaved by 4
      // bytes and then by 8 bytes *within* 128-bit lanes. The following set
      // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
      // For instance, (t0, t4) are interleaved to create (r0, r1). This
      // complexity follows from the way that AVX operations are centered on
      // 128-bit lanes.
616 | r0 = _mm256_permute2f128_ps(t0, t4, 0x20); |
617 | r4 = _mm256_permute2f128_ps(t1, t5, 0x20); |
618 | r1 = _mm256_permute2f128_ps(t0, t4, 0x31); |
619 | r5 = _mm256_permute2f128_ps(t1, t5, 0x31); |
620 | r2 = _mm256_permute2f128_ps(t2, t6, 0x20); |
621 | r6 = _mm256_permute2f128_ps(t3, t7, 0x20); |
622 | r3 = _mm256_permute2f128_ps(t2, t6, 0x31); |
623 | r7 = _mm256_permute2f128_ps(t3, t7, 0x31); |
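
      // Together with the lane permutes above, the shuffled store offsets
      // below complete the 8x8 transpose: each 8-float slot of the packed
      // block receives one source row.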
624 | |
625 | _mm256_storeu_ps(packed_ptr + 0 * 8, r0); |
626 | _mm256_storeu_ps(packed_ptr + 2 * 8, r4); |
627 | _mm256_storeu_ps(packed_ptr + 4 * 8, r1); |
628 | _mm256_storeu_ps(packed_ptr + 6 * 8, r5); |
629 | _mm256_storeu_ps(packed_ptr + 1 * 8, r2); |
630 | _mm256_storeu_ps(packed_ptr + 3 * 8, r6); |
631 | _mm256_storeu_ps(packed_ptr + 5 * 8, r3); |
632 | _mm256_storeu_ps(packed_ptr + 7 * 8, r7); |
633 | } else if (available_src_rows > 0) { |
634 | const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); |
635 | const __m256i row_mask_v = CompareGreaterThan<path>( |
636 | _mm256_set1_epi32(available_src_rows), series); |
637 | |
638 | __m256 t0, t1, t2, t3, t4, t5, t6, t7; |
639 | __m256 r0, r1, r2, r3, r4, r5, r6, r7; |
640 | |
641 | t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); |
642 | t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); |
643 | t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); |
644 | t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); |
645 | t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); |
646 | t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); |
647 | t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); |
648 | t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); |
649 | |
650 | r0 = _mm256_unpacklo_ps(t0, t1); |
651 | r4 = _mm256_unpacklo_ps(t4, t5); |
652 | r2 = _mm256_unpackhi_ps(t0, t1); |
653 | r6 = _mm256_unpackhi_ps(t4, t5); |
654 | r1 = _mm256_unpacklo_ps(t2, t3); |
655 | r5 = _mm256_unpacklo_ps(t6, t7); |
656 | r3 = _mm256_unpackhi_ps(t2, t3); |
657 | r7 = _mm256_unpackhi_ps(t6, t7); |
658 | |
659 | t0 = Mm256UnpackloPsx2<path>(r0, r1); |
660 | t4 = Mm256UnpackloPsx2<path>(r4, r5); |
661 | t2 = Mm256UnpackhiPsx2<path>(r0, r1); |
662 | t6 = Mm256UnpackhiPsx2<path>(r4, r5); |
663 | t1 = Mm256UnpackloPsx2<path>(r2, r3); |
664 | t5 = Mm256UnpackloPsx2<path>(r6, r7); |
665 | t3 = Mm256UnpackhiPsx2<path>(r2, r3); |
666 | t7 = Mm256UnpackhiPsx2<path>(r6, r7); |
667 | |
      // The preceding sets of rearrangement operations interleaved by 4
      // bytes and then by 8 bytes *within* 128-bit lanes. The following set
      // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
      // For instance, (t0, t4) are interleaved to create (r0, r1). This
      // complexity follows from the way that AVX operations are centered on
      // 128-bit lanes.
673 | r0 = _mm256_permute2f128_ps(t0, t4, 0x20); |
674 | r4 = _mm256_permute2f128_ps(t1, t5, 0x20); |
675 | r1 = _mm256_permute2f128_ps(t0, t4, 0x31); |
676 | r5 = _mm256_permute2f128_ps(t1, t5, 0x31); |
677 | r2 = _mm256_permute2f128_ps(t2, t6, 0x20); |
678 | r6 = _mm256_permute2f128_ps(t3, t7, 0x20); |
679 | r3 = _mm256_permute2f128_ps(t2, t6, 0x31); |
680 | // r7 no longer needed. |
681 | |
682 | _mm256_storeu_ps(trailing_buf + 0 * 8, r0); |
683 | _mm256_storeu_ps(trailing_buf + 2 * 8, r4); |
684 | _mm256_storeu_ps(trailing_buf + 4 * 8, r1); |
685 | _mm256_storeu_ps(trailing_buf + 6 * 8, r5); |
686 | _mm256_storeu_ps(trailing_buf + 1 * 8, r2); |
687 | _mm256_storeu_ps(trailing_buf + 3 * 8, r6); |
688 | _mm256_storeu_ps(trailing_buf + 5 * 8, r3); |
689 | // No store to (trailing_buf + 7 * 8), space not allocated. |
690 | } |
691 | |
692 | packed_ptr += kPackRows * kPackCols; |
693 | src_ptr0 += src_inc0; |
694 | src_ptr1 += src_inc1; |
695 | src_ptr2 += src_inc2; |
696 | src_ptr3 += src_inc3; |
697 | src_ptr4 += src_inc4; |
698 | src_ptr5 += src_inc5; |
699 | src_ptr6 += src_inc6; |
700 | src_ptr7 += src_inc7; |
701 | } |
702 | } |
703 | } // namespace ruy |
704 | #endif // (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM) |
705 | |
706 | #endif // RUY_RUY_PACK_X86_H_ |
707 | |