1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | #include <cstring> |
18 | |
19 | #include "ruy/check_macros.h" |
20 | #include "ruy/opt_set.h" |
21 | #include "ruy/pack_x86.h" |
22 | #include "ruy/path.h" |
23 | #include "ruy/platform.h" |
24 | #include "ruy/profiler/instrumentation.h" |
25 | |
26 | #if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS) |
27 | #include <immintrin.h> // IWYU pragma: keep |
28 | #endif |
29 | |
30 | namespace ruy { |
31 | |
32 | #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) |
33 | |
34 | void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t, |
35 | const std::int8_t*, int, int, int, std::int8_t*, |
36 | std::int32_t*) { |
37 | // CPU-ID-based checks should disable the path that would reach this point. |
38 | RUY_DCHECK(false); |
39 | } |
40 | |
41 | void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int, |
42 | int, int, std::int16_t*, std::int32_t*) { |
43 | // CPU-ID-based checks should disable the path that would reach this point. |
44 | RUY_DCHECK(false); |
45 | } |
46 | |
47 | void PackFloatColMajorForAvx512(const float*, const float*, int, int, int, |
48 | float*) { |
49 | // CPU-ID-based checks should disable the path that would reach this point. |
50 | RUY_DCHECK(false); |
51 | } |
52 | |
53 | void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int, |
54 | int, int, int, int, int, int, std::int32_t*) { |
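  // CPU-ID-based checks should disable the path that would reach this point.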
55 | RUY_DCHECK(false); |
56 | } |
57 | |
58 | #else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
59 | |
60 | // The first int8_t template parameter is arbitrary: this routine is common to |
61 | // all 8-bit source matrix types. |
62 | using PackImpl8bitAvx512 = |
63 | PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, |
64 | std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>; |
65 | using PackImpl16bitAvx512 = |
66 | PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>, |
67 | std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>; |
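// In both layouts above, FixedKernelLayout<Order::kColMajor, 4, 16> describes
// packed blocks of 4 rows (depth) by 16 columns in which each column's 4
// values are stored contiguously, matching the DCHECKs on Layout::kRows and
// Layout::kCols below.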
68 | |
69 | namespace { |
70 | |
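// Fills half-blocks (8 of the 16 columns of each packed block) with the packed
// zero point, for use when the source has no data for those columns.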
71 | template <typename PackImplAvx512, typename Scalar> |
72 | inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point, |
73 | Scalar* packed_ptr, int chunked_row_mask) { |
74 | using Layout = typename PackImplAvx512::Layout; |
75 | static constexpr int kHalfLayoutCols = |
76 | PackImplAvx512::kHalfLayoutCols; // Half the number of cols in a |
77 | // block. |
78 | RUY_DCHECK_EQ(kHalfLayoutCols, 8); |
79 | RUY_DCHECK_EQ(Layout::kCols, 16); |
80 | RUY_DCHECK_EQ(Layout::kRows, 4); |
81 | |
82 | const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2; |
83 | // This routine fills half blocks, and typically fills the second halves. |
84 | // Thus packed_ptr is already offset by 8 * 4. |
85 | for (int k = 0; k < non_trailing_blocks; ++k) { |
86 | for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { |
87 | packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; |
88 | } |
89 | } |
90 | } |
91 | |
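// Concatenates two unaligned 256-bit loads into one 512-bit vector: elements
// from addr_lo fill the lower 256 bits and elements from addr_hi the upper.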
92 | template <typename Scalar> |
93 | inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) { |
94 | __m512i lower_filled = _mm512_castsi256_si512( |
95 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo))); |
96 | return _mm512_inserti32x8( |
97 | lower_filled, |
98 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1); |
99 | } |
100 | |
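// As LoaduTwo, but loads only the lanes selected by row_mask and fills the
// remaining lanes with default_value_v (the zero point).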
101 | inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, |
102 | const std::int8_t* addr_lo, |
103 | const std::int8_t* addr_hi) { |
104 | const __m512i lower_filled = _mm512_castsi256_si512( |
105 | _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); |
106 | return _mm512_inserti32x8( |
107 | lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), |
108 | 1); |
109 | } |
110 | |
111 | inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, |
112 | const std::int16_t* addr_lo, |
113 | const std::int16_t* addr_hi) { |
114 | const __m512i lower_filled = _mm512_castsi256_si512( |
115 | _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo)); |
116 | return _mm512_inserti32x8( |
117 | lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi), |
118 | 1); |
119 | } |
120 | |
121 | inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, |
122 | std::int8_t input_xor, |
123 | const std::int8_t* zerobuf, int src_stride, |
124 | int remaining_src_cols, int src_rows, |
125 | std::int8_t* packed_ptr, std::int32_t* sums_ptr, |
126 | std::int8_t* trailing_buf) { |
127 | using Layout = PackImpl8bitAvx512::Layout; |
128 | RUY_DCHECK_EQ(Layout::kCols, 16); |
129 | RUY_DCHECK_EQ(Layout::kRows, 4); |
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 8 of these chunks at a
  // time, padding short input chunks.
132 | constexpr int kNumRowChunks = 8; |
133 | constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; |
134 | |
135 | const std::int8_t* src_ptr0 = src_ptr; |
136 | const std::int8_t* src_ptr1 = src_ptr0 + src_stride; |
137 | const std::int8_t* src_ptr2 = src_ptr1 + src_stride; |
138 | const std::int8_t* src_ptr3 = src_ptr2 + src_stride; |
139 | const std::int8_t* src_ptr4 = src_ptr3 + src_stride; |
140 | const std::int8_t* src_ptr5 = src_ptr4 + src_stride; |
141 | const std::int8_t* src_ptr6 = src_ptr5 + src_stride; |
142 | const std::int8_t* src_ptr7 = src_ptr6 + src_stride; |
143 | std::int64_t src_inc0 = kNumChunkedSrcRows; |
144 | std::int64_t src_inc1 = kNumChunkedSrcRows; |
145 | std::int64_t src_inc2 = kNumChunkedSrcRows; |
146 | std::int64_t src_inc3 = kNumChunkedSrcRows; |
147 | std::int64_t src_inc4 = kNumChunkedSrcRows; |
148 | std::int64_t src_inc5 = kNumChunkedSrcRows; |
149 | std::int64_t src_inc6 = kNumChunkedSrcRows; |
150 | std::int64_t src_inc7 = kNumChunkedSrcRows; |
151 | // Handle cases where source does not have kHalfLayoutCols (8) columns. |
152 | if (remaining_src_cols < 8) { |
153 | if (remaining_src_cols <= 0) { |
154 | src_ptr0 = zerobuf; |
155 | src_inc0 = 0; |
156 | } |
157 | if (remaining_src_cols <= 1) { |
158 | src_ptr1 = zerobuf; |
159 | src_inc1 = 0; |
160 | } |
161 | if (remaining_src_cols <= 2) { |
162 | src_ptr2 = zerobuf; |
163 | src_inc2 = 0; |
164 | } |
165 | if (remaining_src_cols <= 3) { |
166 | src_ptr3 = zerobuf; |
167 | src_inc3 = 0; |
168 | } |
169 | if (remaining_src_cols <= 4) { |
170 | src_ptr4 = zerobuf; |
171 | src_inc4 = 0; |
172 | } |
173 | if (remaining_src_cols <= 5) { |
174 | src_ptr5 = zerobuf; |
175 | src_inc5 = 0; |
176 | } |
177 | if (remaining_src_cols <= 6) { |
178 | src_ptr6 = zerobuf; |
179 | src_inc6 = 0; |
180 | } |
181 | src_ptr7 = zerobuf; |
182 | src_inc7 = 0; |
183 | } |
184 | |
185 | const std::int8_t zero_point = zerobuf[0]; |
186 | |
187 | if (sums_ptr) { |
188 | // i: kHalfLayoutCols. |
189 | for (int i = 0; i < 8; ++i) { |
190 | sums_ptr[i] = 0; |
191 | } |
192 | } |
193 | std::int32_t sums_adjustment = 0; |
194 | const __m512i ones_16bit = _mm512_set1_epi16(1); |
195 | __m512i sums_8x2_32bit = _mm512_set1_epi32(0); |
196 | |
  // The overall packing effectively pads the source rows to
  // (src_rows + 63) & ~63. The m=1 pass of the inner loop may be skipped, in
  // which case we only pack up to (src_rows + 31) & ~31. When there is an
  // incomplete destination block, it is stored into trailing_buf instead of
  // packed_ptr.
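  // Illustrative example: with src_rows == 70, the two full m-passes at k == 0
  // write rows 0..63 straight into the packed buffer; the partial pass at
  // k == 64 writes rows 64..69 plus zero-point padding into trailing_buf, from
  // which the caller later copies rows 64..71 into place.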
201 | for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { |
202 | // m: {0, 1} for 2 chunks of rows. |
203 | for (int m = 0; m < 2; ++m) { |
204 | // Available source rows. |
205 | // If this is less than 0 (for m=1), we skip, having filled trailing |
206 | // buffer for m=0. Also, if source rows is zero on m=1, then we filled |
207 | // exactly to the end of the column in the packed buffer. |
208 | const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; |
      // Effectively, available_src_rows is clamped to
      // [0, kNumChunkedSrcRows], but we treat each case separately.
212 | if (available_src_rows >= kNumChunkedSrcRows) { |
213 | // i: chunks, s: Layout::Rows. |
214 | if (sums_ptr) { |
215 | __m512i t0, t1, t2, t3; |
216 | __m512i r0, r1, r2, r3; |
217 | const __m512i input_xor_v = _mm512_set1_epi8(input_xor); |
218 | |
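        // The loads and unpack/shuffle sequence below transpose the 32x8 block
        // of source bytes so that each 256-bit result r*_* holds one 4-row by
        // 8-column packed chunk, with each column's 4 bytes contiguous.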
219 | t0 = LoaduTwo(src_ptr0, src_ptr4); |
220 | t1 = LoaduTwo(src_ptr1, src_ptr5); |
221 | t2 = LoaduTwo(src_ptr2, src_ptr6); |
222 | t3 = LoaduTwo(src_ptr3, src_ptr7); |
223 | |
224 | r0 = _mm512_unpacklo_epi32(t0, t1); |
225 | r2 = _mm512_unpackhi_epi32(t0, t1); |
226 | r1 = _mm512_unpacklo_epi32(t2, t3); |
227 | r3 = _mm512_unpackhi_epi32(t2, t3); |
228 | |
229 | t0 = _mm512_unpacklo_epi64(r0, r1); |
230 | t2 = _mm512_unpackhi_epi64(r0, r1); |
231 | t1 = _mm512_unpacklo_epi64(r2, r3); |
232 | t3 = _mm512_unpackhi_epi64(r2, r3); |
233 | |
234 | r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); |
235 | r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); |
236 | r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); |
237 | r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); |
238 | |
239 | r0 = _mm512_xor_si512(r0, input_xor_v); |
240 | r1 = _mm512_xor_si512(r1, input_xor_v); |
241 | r2 = _mm512_xor_si512(r2, input_xor_v); |
242 | r3 = _mm512_xor_si512(r3, input_xor_v); |
243 | |
244 | const __m256i r0_0 = _mm512_castsi512_si256(r0); |
245 | const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); |
246 | const __m256i r1_0 = _mm512_castsi512_si256(r1); |
247 | const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); |
248 | const __m256i r2_0 = _mm512_castsi512_si256(r2); |
249 | const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); |
250 | const __m256i r3_0 = _mm512_castsi512_si256(r3); |
251 | const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); |
252 | |
253 | __m512i sums_8x4_16bit; |
254 | sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); |
255 | sums_8x4_16bit = |
256 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); |
257 | sums_8x4_16bit = |
258 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); |
259 | sums_8x4_16bit = |
260 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); |
261 | sums_8x4_16bit = |
262 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); |
263 | sums_8x4_16bit = |
264 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); |
265 | sums_8x4_16bit = |
266 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); |
267 | sums_8x4_16bit = |
268 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); |
269 | // The sums have been performed across columns, and now we have |
270 | // 4x16-bit sums packed together. We use madd for pairwise 32-bit |
271 | // sums. |
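          // With ones_16bit all ones, each 32-bit lane of the madd result is
          // the sum of the two adjacent 16-bit lanes (2*i and 2*i+1).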
272 | const __m512i sums_8x2_32bit_new = |
273 | _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); |
274 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); |
275 | |
276 | _mm256_storeu_si256( |
277 | reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0); |
278 | _mm256_storeu_si256( |
279 | reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1); |
280 | _mm256_storeu_si256( |
281 | reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0); |
282 | _mm256_storeu_si256( |
283 | reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1); |
284 | _mm256_storeu_si256( |
285 | reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0); |
286 | _mm256_storeu_si256( |
287 | reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1); |
288 | _mm256_storeu_si256( |
289 | reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0); |
290 | _mm256_storeu_si256( |
291 | reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1); |
292 | } else { |
293 | __m512i t0, t1, t2, t3; |
294 | __m512i r0, r1, r2, r3; |
295 | const __m512i input_xor_v = _mm512_set1_epi8(input_xor); |
296 | |
297 | t0 = LoaduTwo(src_ptr0, src_ptr4); |
298 | t1 = LoaduTwo(src_ptr1, src_ptr5); |
299 | t2 = LoaduTwo(src_ptr2, src_ptr6); |
300 | t3 = LoaduTwo(src_ptr3, src_ptr7); |
301 | |
302 | r0 = _mm512_unpacklo_epi32(t0, t1); |
303 | r2 = _mm512_unpackhi_epi32(t0, t1); |
304 | r1 = _mm512_unpacklo_epi32(t2, t3); |
305 | r3 = _mm512_unpackhi_epi32(t2, t3); |
306 | |
307 | t0 = _mm512_unpacklo_epi64(r0, r1); |
308 | t2 = _mm512_unpackhi_epi64(r0, r1); |
309 | t1 = _mm512_unpacklo_epi64(r2, r3); |
310 | t3 = _mm512_unpackhi_epi64(r2, r3); |
311 | |
312 | r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); |
313 | r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); |
314 | r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); |
315 | r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); |
316 | |
317 | r0 = _mm512_xor_si512(r0, input_xor_v); |
318 | r1 = _mm512_xor_si512(r1, input_xor_v); |
319 | r2 = _mm512_xor_si512(r2, input_xor_v); |
320 | r3 = _mm512_xor_si512(r3, input_xor_v); |
321 | |
322 | const __m256i r0_0 = _mm512_castsi512_si256(r0); |
323 | const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); |
324 | const __m256i r1_0 = _mm512_castsi512_si256(r1); |
325 | const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); |
326 | const __m256i r2_0 = _mm512_castsi512_si256(r2); |
327 | const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); |
328 | const __m256i r3_0 = _mm512_castsi512_si256(r3); |
329 | const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); |
330 | _mm256_storeu_si256( |
331 | reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0); |
332 | _mm256_storeu_si256( |
333 | reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1); |
334 | _mm256_storeu_si256( |
335 | reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0); |
336 | _mm256_storeu_si256( |
337 | reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1); |
338 | _mm256_storeu_si256( |
339 | reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0); |
340 | _mm256_storeu_si256( |
341 | reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1); |
342 | _mm256_storeu_si256( |
343 | reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0); |
344 | _mm256_storeu_si256( |
345 | reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1); |
346 | } |
347 | } else if (available_src_rows > 0) { |
        RUY_DCHECK_LT(available_src_rows >> 2, kNumRowChunks);
349 | const __mmask32 row_mask = |
350 | (static_cast<std::uint64_t>(1) << available_src_rows) - 1; |
351 | |
352 | // We do not care what goes into the trailing buffer, but we want |
353 | // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. |
354 | // |
355 | // We compensate for padding-with-zero_point by initializing the |
356 | // summations with the compensating offset, effectively |
357 | // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * |
358 | // 4 * (8 - ((available_src_rows + 3) >> 2)). |
359 | // |
360 | // Note that (zero_point ^ input_xor) is performed in 8-bits and then |
361 | // cast. |
362 | sums_adjustment += -(zero_point ^ input_xor) * 4 * |
363 | (8 - ((available_src_rows + 3) >> 2)); |
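        // Illustrative example: with available_src_rows == 10, three 4-row
        // chunks contain data ((10 + 3) >> 2 == 3), so the remaining 5 chunks
        // contribute 5 * 4 padded elements per column, and the adjustment is
        // -(zero_point ^ input_xor) * 20 per column sum.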
364 | |
365 | __m512i t0, t1, t2, t3; |
366 | __m512i r0, r1, r2, r3; |
367 | const __m512i input_xor_v = _mm512_set1_epi8(input_xor); |
368 | const __m256i zero_point_v = _mm256_set1_epi8(zero_point); |
369 | |
370 | t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); |
371 | t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); |
372 | t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); |
373 | t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); |
374 | |
375 | r0 = _mm512_unpacklo_epi32(t0, t1); |
376 | r2 = _mm512_unpackhi_epi32(t0, t1); |
377 | r1 = _mm512_unpacklo_epi32(t2, t3); |
378 | r3 = _mm512_unpackhi_epi32(t2, t3); |
379 | |
380 | t0 = _mm512_unpacklo_epi64(r0, r1); |
381 | t2 = _mm512_unpackhi_epi64(r0, r1); |
382 | t1 = _mm512_unpacklo_epi64(r2, r3); |
383 | t3 = _mm512_unpackhi_epi64(r2, r3); |
384 | |
385 | r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); |
386 | r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); |
387 | r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); |
388 | r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); |
389 | |
390 | r0 = _mm512_xor_si512(r0, input_xor_v); |
391 | r1 = _mm512_xor_si512(r1, input_xor_v); |
392 | r2 = _mm512_xor_si512(r2, input_xor_v); |
393 | r3 = _mm512_xor_si512(r3, input_xor_v); |
394 | |
395 | const __m256i r0_0 = _mm512_castsi512_si256(r0); |
396 | const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); |
397 | const __m256i r1_0 = _mm512_castsi512_si256(r1); |
398 | const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); |
399 | const __m256i r2_0 = _mm512_castsi512_si256(r2); |
400 | const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); |
401 | const __m256i r3_0 = _mm512_castsi512_si256(r3); |
402 | const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); |
403 | |
404 | __m512i sums_8x4_16bit; |
405 | sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); |
406 | sums_8x4_16bit = |
407 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); |
408 | sums_8x4_16bit = |
409 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); |
410 | sums_8x4_16bit = |
411 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); |
412 | sums_8x4_16bit = |
413 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); |
414 | sums_8x4_16bit = |
415 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); |
416 | sums_8x4_16bit = |
417 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); |
418 | sums_8x4_16bit = |
419 | _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); |
420 | // The sums have been performed across columns, and now we have |
421 | // 4x16-bit sums packed together. We use madd for pairwise 32-bit |
422 | // sums. |
423 | const __m512i sums_8x2_32bit_new = |
424 | _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); |
425 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); |
426 | |
427 | _mm256_storeu_si256( |
428 | reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0); |
429 | _mm256_storeu_si256( |
430 | reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1); |
431 | _mm256_storeu_si256( |
432 | reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0); |
433 | _mm256_storeu_si256( |
434 | reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1); |
435 | _mm256_storeu_si256( |
436 | reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0); |
437 | _mm256_storeu_si256( |
438 | reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1); |
439 | _mm256_storeu_si256( |
440 | reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0); |
441 | _mm256_storeu_si256( |
442 | reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1); |
443 | } |
444 | |
445 | packed_ptr += 16 * kNumChunkedSrcRows; |
446 | src_ptr0 += src_inc0; |
447 | src_ptr1 += src_inc1; |
448 | src_ptr2 += src_inc2; |
449 | src_ptr3 += src_inc3; |
450 | src_ptr4 += src_inc4; |
451 | src_ptr5 += src_inc5; |
452 | src_ptr6 += src_inc6; |
453 | src_ptr7 += src_inc7; |
454 | } |
455 | } |
456 | |
457 | if (sums_ptr) { |
458 | const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); |
459 | |
460 | __m256i sums = |
461 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); |
462 | const __m512i idx = |
463 | _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); |
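    // idx routes the even-indexed 32-bit lanes of sums_8x2_32bit into the
    // lower 256 bits and the odd-indexed lanes into the upper 256 bits, so the
    // two halves can be added lane-wise into the 8 per-column sums.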
464 | |
    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
467 | const __m512i sums_2x8_32bit = |
468 | _mm512_permutexvar_epi32(idx, sums_8x2_32bit); |
469 | sums = _mm256_add_epi32(sums, sums_adjustment_v); |
470 | sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); |
471 | sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); |
472 | |
473 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); |
474 | } |
475 | } |
476 | |
477 | inline void HalfPack16bitAvx512(const std::int16_t* src_ptr, |
478 | const std::int16_t* zerobuf, int src_stride, |
479 | int remaining_src_cols, int src_rows, |
480 | std::int16_t* packed_ptr, |
481 | std::int32_t* sums_ptr, |
482 | std::int16_t* trailing_buf) { |
483 | using Layout = PackImpl16bitAvx512::Layout; |
484 | RUY_DCHECK_EQ(Layout::kCols, 16); |
485 | RUY_DCHECK_EQ(Layout::kRows, 4); |
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 4 of these chunks at a
  // time, padding short input chunks.
488 | constexpr int kNumRowChunks = 4; |
489 | constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; |
490 | |
491 | const std::int16_t* src_ptr0 = src_ptr; |
492 | const std::int16_t* src_ptr1 = src_ptr0 + src_stride; |
493 | const std::int16_t* src_ptr2 = src_ptr1 + src_stride; |
494 | const std::int16_t* src_ptr3 = src_ptr2 + src_stride; |
495 | const std::int16_t* src_ptr4 = src_ptr3 + src_stride; |
496 | const std::int16_t* src_ptr5 = src_ptr4 + src_stride; |
497 | const std::int16_t* src_ptr6 = src_ptr5 + src_stride; |
498 | const std::int16_t* src_ptr7 = src_ptr6 + src_stride; |
499 | std::int64_t src_inc0 = kNumChunkedSrcRows; |
500 | std::int64_t src_inc1 = kNumChunkedSrcRows; |
501 | std::int64_t src_inc2 = kNumChunkedSrcRows; |
502 | std::int64_t src_inc3 = kNumChunkedSrcRows; |
503 | std::int64_t src_inc4 = kNumChunkedSrcRows; |
504 | std::int64_t src_inc5 = kNumChunkedSrcRows; |
505 | std::int64_t src_inc6 = kNumChunkedSrcRows; |
506 | std::int64_t src_inc7 = kNumChunkedSrcRows; |
507 | // Handle cases where source does not have kHalfLayoutCols (8) columns. |
508 | if (remaining_src_cols < 8) { |
509 | if (remaining_src_cols <= 0) { |
510 | src_ptr0 = zerobuf; |
511 | src_inc0 = 0; |
512 | } |
513 | if (remaining_src_cols <= 1) { |
514 | src_ptr1 = zerobuf; |
515 | src_inc1 = 0; |
516 | } |
517 | if (remaining_src_cols <= 2) { |
518 | src_ptr2 = zerobuf; |
519 | src_inc2 = 0; |
520 | } |
521 | if (remaining_src_cols <= 3) { |
522 | src_ptr3 = zerobuf; |
523 | src_inc3 = 0; |
524 | } |
525 | if (remaining_src_cols <= 4) { |
526 | src_ptr4 = zerobuf; |
527 | src_inc4 = 0; |
528 | } |
529 | if (remaining_src_cols <= 5) { |
530 | src_ptr5 = zerobuf; |
531 | src_inc5 = 0; |
532 | } |
533 | if (remaining_src_cols <= 6) { |
534 | src_ptr6 = zerobuf; |
535 | src_inc6 = 0; |
536 | } |
537 | src_ptr7 = zerobuf; |
538 | src_inc7 = 0; |
539 | } |
540 | |
541 | const std::int16_t zero_point = zerobuf[0]; |
542 | |
543 | if (sums_ptr) { |
544 | // i: kHalfLayoutCols. |
545 | for (int i = 0; i < 8; ++i) { |
546 | sums_ptr[i] = 0; |
547 | } |
548 | } |
549 | std::int32_t sums_adjustment = 0; |
550 | const __m512i ones_16bit = _mm512_set1_epi16(1); |
551 | __m512i sums_8x2_32bit = _mm512_set1_epi32(0); |
552 | |
  // The overall packing effectively pads the source rows to
  // (src_rows + 31) & ~31. The m=1 pass of the inner loop may be skipped, in
  // which case we only pack up to (src_rows + 15) & ~15. When there is an
  // incomplete destination block, it is stored into trailing_buf instead of
  // packed_ptr.
557 | for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { |
558 | // m: {0, 1} for 2 chunks of rows. |
559 | for (int m = 0; m < 2; ++m) { |
560 | const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; |
561 | |
562 | // Available source rows. |
563 | // If this is less than 0 (for m=1), we skip, having filled trailing |
564 | // buffer for m=0. Also, if source rows is zero on m=1, then we filled |
565 | // exactly to the end of the column in the packed buffer. |
566 | if (available_src_rows > 0) { |
567 | __m512i t0, t1, t2, t3; |
568 | __m512i r0, r1, r2, r3; |
569 | std::int16_t* dst_ptr = packed_ptr; |
570 | |
571 | if (available_src_rows >= kNumChunkedSrcRows) { |
572 | t0 = LoaduTwo(src_ptr0, src_ptr4); |
573 | t1 = LoaduTwo(src_ptr1, src_ptr5); |
574 | t2 = LoaduTwo(src_ptr2, src_ptr6); |
575 | t3 = LoaduTwo(src_ptr3, src_ptr7); |
576 | } else { |
          RUY_DCHECK_LT(available_src_rows >> 2, kNumRowChunks);
578 | // We do not care what goes into the trailing buffer, but we want |
579 | // in_data[...] == zero_point for irrelevant values in the summation. |
580 | // |
581 | // We compensate for padding-with-zero_point by initializing the |
582 | // summations with the compensating offset. |
583 | sums_adjustment += |
584 | -(zero_point)*4 * (4 - ((available_src_rows + 3) >> 2)); |
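          // Illustrative example: with available_src_rows == 6, two 4-row
          // chunks contain data, so the remaining 2 chunks contribute 2 * 4
          // padded elements per column and the adjustment is -zero_point * 8.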
585 | |
586 | const __m256i zero_point_v = _mm256_set1_epi16(zero_point); |
587 | const __mmask32 row_mask = |
588 | (static_cast<std::uint64_t>(1) << available_src_rows) - 1; |
589 | |
590 | t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); |
591 | t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); |
592 | t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); |
593 | t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); |
594 | dst_ptr = trailing_buf; |
595 | } |
596 | |
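        // The unpack/permute/blend sequence below regroups the data so that
        // each 512-bit store writes one 4-row by 8-column packed block, with
        // each column's 4 int16 values contiguous.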
597 | r0 = _mm512_unpacklo_epi64(t0, t1); |
598 | r2 = _mm512_unpackhi_epi64(t0, t1); |
599 | r1 = _mm512_unpacklo_epi64(t2, t3); |
600 | r3 = _mm512_unpackhi_epi64(t2, t3); |
601 | |
602 | r1 = _mm512_permutex_epi64(r1, 0x4e); |
603 | r3 = _mm512_permutex_epi64(r3, 0x4e); |
604 | |
605 | t0 = _mm512_mask_blend_epi64(0xcc, r0, r1); |
606 | t1 = _mm512_mask_blend_epi64(0x33, r0, r1); |
607 | t2 = _mm512_mask_blend_epi64(0xcc, r2, r3); |
608 | t3 = _mm512_mask_blend_epi64(0x33, r2, r3); |
609 | |
610 | t1 = _mm512_permutex_epi64(t1, 0x4e); |
611 | t3 = _mm512_permutex_epi64(t3, 0x4e); |
612 | |
613 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4), |
614 | t0); |
615 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4), |
616 | t1); |
617 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4), |
618 | t2); |
619 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4), |
620 | t3); |
621 | |
622 | if (sums_ptr) { |
623 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, |
624 | _mm512_madd_epi16(t0, ones_16bit)); |
625 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, |
626 | _mm512_madd_epi16(t1, ones_16bit)); |
627 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, |
628 | _mm512_madd_epi16(t2, ones_16bit)); |
629 | sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, |
630 | _mm512_madd_epi16(t3, ones_16bit)); |
631 | } |
632 | } |
633 | |
634 | packed_ptr += 16 * kNumChunkedSrcRows; |
635 | src_ptr0 += src_inc0; |
636 | src_ptr1 += src_inc1; |
637 | src_ptr2 += src_inc2; |
638 | src_ptr3 += src_inc3; |
639 | src_ptr4 += src_inc4; |
640 | src_ptr5 += src_inc5; |
641 | src_ptr6 += src_inc6; |
642 | src_ptr7 += src_inc7; |
643 | } |
644 | } |
645 | |
646 | if (sums_ptr) { |
647 | const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); |
648 | |
649 | __m256i sums = |
650 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); |
651 | const __m512i idx = |
652 | _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); |
653 | |
654 | const __m512i sums_2x8_32bit = |
655 | _mm512_permutexvar_epi32(idx, sums_8x2_32bit); |
656 | sums = _mm256_add_epi32(sums, sums_adjustment_v); |
657 | sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); |
658 | sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); |
659 | |
660 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); |
661 | } |
662 | } |
663 | |
664 | inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { |
665 | const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); |
666 | return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); |
667 | } |
668 | |
669 | inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, |
670 | const float* addr_hi) { |
671 | const __m512 lower_filled = |
672 | _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); |
673 | return _mm512_insertf32x8(lower_filled, |
674 | _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); |
675 | } |
676 | |
677 | inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { |
678 | return _mm512_castpd_ps( |
679 | _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); |
680 | } |
681 | |
682 | inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { |
683 | return _mm512_castpd_ps( |
684 | _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); |
685 | } |
686 | |
687 | inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, |
688 | int src_stride, int remaining_src_cols, |
689 | int src_rows, float* packed_ptr, |
690 | float* trailing_buf) { |
691 | const float* src_ptr0 = src_ptr; |
692 | const float* src_ptr1 = src_ptr0 + src_stride; |
693 | const float* src_ptr2 = src_ptr1 + src_stride; |
694 | const float* src_ptr3 = src_ptr2 + src_stride; |
695 | const float* src_ptr4 = src_ptr3 + src_stride; |
696 | const float* src_ptr5 = src_ptr4 + src_stride; |
697 | const float* src_ptr6 = src_ptr5 + src_stride; |
698 | const float* src_ptr7 = src_ptr6 + src_stride; |
699 | std::int64_t src_inc0 = 8; |
700 | std::int64_t src_inc1 = 8; |
701 | std::int64_t src_inc2 = 8; |
702 | std::int64_t src_inc3 = 8; |
703 | std::int64_t src_inc4 = 8; |
704 | std::int64_t src_inc5 = 8; |
705 | std::int64_t src_inc6 = 8; |
706 | std::int64_t src_inc7 = 8; |
707 | if (remaining_src_cols < 8) { |
708 | if (remaining_src_cols <= 0) { |
709 | src_ptr0 = zerobuf; |
710 | src_inc0 = 0; |
711 | } |
712 | if (remaining_src_cols <= 1) { |
713 | src_ptr1 = zerobuf; |
714 | src_inc1 = 0; |
715 | } |
716 | if (remaining_src_cols <= 2) { |
717 | src_ptr2 = zerobuf; |
718 | src_inc2 = 0; |
719 | } |
720 | if (remaining_src_cols <= 3) { |
721 | src_ptr3 = zerobuf; |
722 | src_inc3 = 0; |
723 | } |
724 | if (remaining_src_cols <= 4) { |
725 | src_ptr4 = zerobuf; |
726 | src_inc4 = 0; |
727 | } |
728 | if (remaining_src_cols <= 5) { |
729 | src_ptr5 = zerobuf; |
730 | src_inc5 = 0; |
731 | } |
732 | if (remaining_src_cols <= 6) { |
733 | src_ptr6 = zerobuf; |
734 | src_inc6 = 0; |
735 | } |
736 | src_ptr7 = zerobuf; |
737 | src_inc7 = 0; |
738 | } |
739 | |
740 | for (int k = 0; k < src_rows; k += 16) { |
741 | for (int m = 0; m < 2; ++m) { |
742 | const int available_src_rows = src_rows - k - 8 * m; |
743 | // Effectively, |
744 | // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); |
745 | // but treat each case separately. |
746 | if (available_src_rows > 7) { |
747 | __m512 t0, t1, t2, t3; |
748 | __m512 r0, r1, r2, r3; |
749 | |
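        // The unpack/shuffle sequence below transposes the 8x8 block so that
        // each 256-bit half of r0..r3 holds one source row, i.e. one float
        // from each of the 8 columns, ready to be stored as a packed row.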
750 | t0 = LoaduTwo(src_ptr0, src_ptr4); |
751 | t1 = LoaduTwo(src_ptr1, src_ptr5); |
752 | t2 = LoaduTwo(src_ptr2, src_ptr6); |
753 | t3 = LoaduTwo(src_ptr3, src_ptr7); |
754 | |
755 | r0 = _mm512_unpacklo_ps(t0, t1); |
756 | r2 = _mm512_unpackhi_ps(t0, t1); |
757 | r1 = _mm512_unpacklo_ps(t2, t3); |
758 | r3 = _mm512_unpackhi_ps(t2, t3); |
759 | |
760 | t0 = Mm512UnpackloPsx2(r0, r1); |
761 | t2 = Mm512UnpackhiPsx2(r0, r1); |
762 | t1 = Mm512UnpackloPsx2(r2, r3); |
763 | t3 = Mm512UnpackhiPsx2(r2, r3); |
764 | |
765 | r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); |
766 | r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); |
767 | r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); |
768 | r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); |
769 | |
770 | _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); |
771 | _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); |
772 | _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); |
773 | _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); |
774 | _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); |
775 | _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); |
776 | _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); |
777 | _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); |
778 | } else if (available_src_rows > 0) { |
779 | const __mmask8 row_mask = |
780 | (static_cast<std::uint32_t>(1) << available_src_rows) - 1; |
781 | |
782 | __m512 t0, t1, t2, t3; |
783 | __m512 r0, r1, r2, r3; |
784 | |
785 | t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); |
786 | t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); |
787 | t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); |
788 | t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); |
789 | |
790 | r0 = _mm512_unpacklo_ps(t0, t1); |
791 | r2 = _mm512_unpackhi_ps(t0, t1); |
792 | r1 = _mm512_unpacklo_ps(t2, t3); |
793 | r3 = _mm512_unpackhi_ps(t2, t3); |
794 | |
795 | t0 = Mm512UnpackloPsx2(r0, r1); |
796 | t2 = Mm512UnpackhiPsx2(r0, r1); |
797 | t1 = Mm512UnpackloPsx2(r2, r3); |
798 | t3 = Mm512UnpackhiPsx2(r2, r3); |
799 | |
800 | r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); |
801 | r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); |
802 | r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); |
803 | r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); |
804 | |
805 | _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); |
806 | _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); |
807 | _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); |
808 | _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); |
809 | _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); |
810 | _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); |
811 | _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); |
812 | // Do not store _mm512_extractf32x8_ps(r3, 1). |
813 | } |
814 | |
815 | packed_ptr += 16 * 8; |
816 | src_ptr0 += src_inc0; |
817 | src_ptr1 += src_inc1; |
818 | src_ptr2 += src_inc2; |
819 | src_ptr3 += src_inc3; |
820 | src_ptr4 += src_inc4; |
821 | src_ptr5 += src_inc5; |
822 | src_ptr6 += src_inc6; |
823 | src_ptr7 += src_inc7; |
824 | } |
825 | } |
826 | } |
827 | |
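// Zeroes 8 of every 16 packed floats for the non-trailing rows; the caller
// passes packed_ptr + 8 so that the unused upper 8-column half is cleared.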
828 | inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { |
829 | const int non_trailing_rows = src_rows & ~7; |
830 | for (int k = 0; k < non_trailing_rows; ++k) { |
831 | for (int j = 0; j < 8; ++j) { |
832 | packed_ptr[j] = 0.0f; |
833 | } |
834 | packed_ptr += 16; |
835 | } |
836 | } |
837 | |
838 | } // namespace. |
839 | |
840 | void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr, |
841 | std::int8_t input_xor, |
842 | const std::int8_t* zerobuf, int src_stride, |
843 | int remaining_src_cols, int src_rows, |
844 | std::int8_t* packed_ptr, |
845 | std::int32_t* sums_ptr) { |
  profiler::ScopeLabel label("Pack kAvx512 8bit");
847 | |
848 | using Layout = PackImpl8bitAvx512::Layout; |
849 | constexpr int kHalfBlockOffset = 32; |
850 | RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); |
851 | static constexpr int kHalfLayoutCols = |
852 | PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a |
853 | // block. |
854 | RUY_DCHECK_EQ(kHalfLayoutCols, 8); |
855 | RUY_DCHECK_EQ(Layout::kCols, 16); |
856 | RUY_DCHECK_EQ(Layout::kRows, 4); |
857 | |
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 8 of these chunks at a
  // time, padding short input chunks.
860 | constexpr int kNumRowChunks = 8; |
861 | |
862 | // Each packed block is 4*16, and there are normally 8. The trailing block is |
863 | // only slightly shorter. |
864 | constexpr int kTrailingBufSize = |
865 | kNumRowChunks * Layout::kCols * Layout::kRows; |
866 | std::int8_t trailing_buf[kTrailingBufSize]; |
867 | memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); |
868 | constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; |
869 | |
870 | std::int32_t* second_sums_ptr = |
871 | sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; |
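  // Pack the two 8-column halves of each 16-column block separately; the
  // second half starts kHalfBlockOffset = 8 * 4 elements into each block.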
872 | if (remaining_src_cols > kHalfLayoutCols) { |
873 | HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, |
874 | remaining_src_cols, src_rows, packed_ptr, sums_ptr, |
875 | trailing_buf); |
876 | HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, |
877 | zerobuf, src_stride, |
878 | remaining_src_cols - kHalfLayoutCols, src_rows, |
879 | packed_ptr + kHalfBlockOffset, second_sums_ptr, |
880 | trailing_buf + kHalfBlockOffset); |
881 | } else { |
882 | HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, |
883 | remaining_src_cols, src_rows, packed_ptr, sums_ptr, |
884 | trailing_buf); |
885 | ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>( |
886 | src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset, |
887 | kChunkedRowMask); |
888 | // The kernel may not need the second half-blocks sums to be set. |
889 | if (second_sums_ptr) { |
890 | for (int i = 0; i < kHalfLayoutCols; ++i) { |
891 | second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); |
892 | } |
893 | } |
894 | } |
895 | const bool trailing_data = (src_rows & kChunkedRowMask) > 0; |
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there is data in the trailing buffer that still needs to be copied into
  // packed_ptr.
898 | if (trailing_data) { |
899 | const int non_trailing_rows = src_rows & ~kChunkedRowMask; |
900 | // Destination "rows" are padded to next highest multiple of Layout::kRows. |
901 | const int dst_rows = (src_rows + 3) & ~3; |
902 | const int trailing_rows = dst_rows - non_trailing_rows; |
903 | memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, |
904 | Layout::kCols * trailing_rows * sizeof(std::int8_t)); |
905 | } |
906 | } |
907 | |
908 | void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr, |
909 | const std::int16_t* zerobuf, int src_stride, |
910 | int remaining_src_cols, int src_rows, |
911 | std::int16_t* packed_ptr, |
912 | std::int32_t* sums_ptr) { |
  profiler::ScopeLabel label("Pack kAvx512 16bit");
914 | |
915 | using Layout = PackImpl16bitAvx512::Layout; |
916 | constexpr int kHalfBlockOffset = 32; |
917 | RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); |
918 | static constexpr int kHalfLayoutCols = |
919 | PackImpl16bitAvx512::kHalfLayoutCols; // Half the number of cols in a |
920 | // block. |
921 | RUY_DCHECK_EQ(kHalfLayoutCols, 8); |
922 | RUY_DCHECK_EQ(Layout::kCols, 16); |
923 | RUY_DCHECK_EQ(Layout::kRows, 4); |
924 | |
  // Each Layout::kRows chunk is 4 elements that are contiguous in the input
  // and contiguous in the packed output. We process 4 of these chunks at a
  // time, padding short input chunks.
927 | constexpr int kNumRowChunks = 4; |
928 | |
  // Each packed block is 4*16, and there are normally 4. The trailing block is
  // only slightly shorter.
931 | constexpr int kTrailingBufSize = |
932 | kNumRowChunks * Layout::kCols * Layout::kRows; |
933 | std::int16_t trailing_buf[kTrailingBufSize] = {0}; |
934 | constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; |
935 | |
936 | std::int32_t* second_sums_ptr = |
937 | sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; |
938 | if (remaining_src_cols > kHalfLayoutCols) { |
939 | HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, |
940 | src_rows, packed_ptr, sums_ptr, trailing_buf); |
941 | HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf, |
942 | src_stride, remaining_src_cols - kHalfLayoutCols, |
943 | src_rows, packed_ptr + kHalfBlockOffset, |
944 | second_sums_ptr, trailing_buf + kHalfBlockOffset); |
945 | } else { |
946 | HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, |
947 | src_rows, packed_ptr, sums_ptr, trailing_buf); |
948 | ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>( |
949 | src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask); |
950 | // The kernel may not need the second half-blocks sums to be set. |
951 | if (second_sums_ptr) { |
952 | for (int i = 0; i < kHalfLayoutCols; ++i) { |
953 | second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3); |
954 | } |
955 | } |
956 | } |
957 | const bool trailing_data = (src_rows & kChunkedRowMask) > 0; |
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there is data in the trailing buffer that still needs to be copied into
  // packed_ptr.
960 | if (trailing_data) { |
961 | const int non_trailing_rows = src_rows & ~kChunkedRowMask; |
962 | // Destination "rows" are padded to next highest multiple of Layout::kRows. |
963 | const int dst_rows = (src_rows + 3) & ~3; |
964 | const int trailing_rows = dst_rows - non_trailing_rows; |
965 | memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, |
966 | Layout::kCols * trailing_rows * sizeof(std::int16_t)); |
967 | } |
968 | } |
969 | |
970 | void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf, |
971 | int src_stride, int remaining_src_cols, |
972 | int src_rows, float* packed_ptr) { |
  profiler::ScopeLabel label("Pack kAvx512 float");
974 | float trailing_buf[7 * 16]; |
975 | if (remaining_src_cols > 8) { |
976 | HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, |
977 | src_rows, packed_ptr, trailing_buf); |
978 | HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, |
979 | remaining_src_cols - 8, src_rows, packed_ptr + 8, |
980 | trailing_buf + 8); |
981 | } else { |
982 | memset(trailing_buf, 0, sizeof(trailing_buf)); |
983 | HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, |
984 | src_rows, packed_ptr, trailing_buf); |
985 | ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); |
986 | } |
987 | const int trailing_rows = src_rows & 7; |
988 | if (trailing_rows > 0) { |
989 | const int non_trailing_rows = src_rows & ~7; |
990 | memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, |
991 | 16 * trailing_rows * sizeof(float)); |
992 | } |
993 | } |
994 | |
995 | void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride, |
996 | int src_zero_point, std::int8_t* packed_ptr, |
997 | int packed_stride, int start_col, int end_col, |
998 | int src_cols, int block_row, int src_rows, |
999 | int input_xor, std::int32_t* sums) { |
1000 | int col = start_col; |
1001 | int src_end_col = std::min(end_col, src_cols); |
1002 | |
1003 | for (; col <= src_end_col - 16; col += 16) { |
1004 | std::int8_t* dst_ptr = packed_ptr; |
1005 | __m128i val0, val1, val2, val3; |
1006 | __m128i input_xor_dup = _mm_set1_epi8(input_xor); |
1007 | // Load a 4x16 block. |
1008 | if (block_row + 4 <= src_rows) { |
1009 | val0 = _mm_loadu_si128( |
1010 | reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride)); |
1011 | val1 = _mm_loadu_si128( |
1012 | reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride)); |
1013 | val2 = _mm_loadu_si128( |
1014 | reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride)); |
1015 | val3 = _mm_loadu_si128( |
1016 | reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride)); |
1017 | } else { |
1018 | val0 = _mm_set1_epi8(src_zero_point); |
1019 | val1 = val0; |
1020 | val2 = val0; |
1021 | val3 = val0; |
1022 | if (block_row + 0 < src_rows) |
1023 | val0 = _mm_loadu_si128( |
1024 | reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride)); |
1025 | if (block_row + 1 < src_rows) |
1026 | val1 = _mm_loadu_si128( |
1027 | reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride)); |
1028 | if (block_row + 2 < src_rows) |
1029 | val2 = _mm_loadu_si128( |
1030 | reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride)); |
1031 | if (block_row + 3 < src_rows) |
1032 | val3 = _mm_loadu_si128( |
1033 | reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride)); |
1034 | } |
1035 | // Maybe xor the sign bit to convert from uint8 to int8. |
1036 | val0 = _mm_xor_si128(val0, input_xor_dup); |
1037 | val1 = _mm_xor_si128(val1, input_xor_dup); |
1038 | val2 = _mm_xor_si128(val2, input_xor_dup); |
1039 | val3 = _mm_xor_si128(val3, input_xor_dup); |
1040 | // Update the sums. |
1041 | __m256i val16_0 = _mm256_cvtepi8_epi16(val0); |
1042 | __m256i val16_1 = _mm256_cvtepi8_epi16(val1); |
1043 | __m256i val16_2 = _mm256_cvtepi8_epi16(val2); |
1044 | __m256i val16_3 = _mm256_cvtepi8_epi16(val3); |
1045 | __m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1), |
1046 | _mm256_add_epi16(val16_2, val16_3)); |
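    // Each 16-bit lane of new_sum16 is the sum over the 4 rows of one column;
    // it is widened to 32 bits and accumulated into sums[col .. col + 15].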
1047 | __m512i sum = |
1048 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col)); |
1049 | sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16)); |
1050 | _mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum); |
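    // zip(x, y) interleaves the 16 bytes of x and y as x0,y0,x1,y1,...,x15,y15
    // across a 256-bit value; together with the epi16 unpacks below, this
    // regroups the 4x16 row-major block so that each column's 4 values are
    // contiguous, and the four 16-byte stores write columns 0-3, 4-7, 8-11 and
    // 12-15 of the packed block.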
1051 | auto zip = [](__m128i x, __m128i y) { |
1052 | auto perm_64_0_64_0 = [](__m128i x) { |
1053 | return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3), |
1054 | _mm256_castsi128_si256(x)); |
1055 | }; |
1056 | return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y)); |
1057 | }; |
1058 | __m256i t2_val0 = zip(val0, val1); |
1059 | __m256i t2_val1 = zip(val2, val3); |
1060 | __m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1); |
1061 | __m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1); |
1062 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), |
1063 | _mm256_extractf128_si256(t4_val0, 0)); |
1064 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), |
1065 | _mm256_extractf128_si256(t4_val1, 0)); |
1066 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32), |
1067 | _mm256_extractf128_si256(t4_val0, 1)); |
1068 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48), |
1069 | _mm256_extractf128_si256(t4_val1, 1)); |
1070 | src_ptr += 16; |
1071 | packed_ptr += packed_stride * 16; |
1072 | } |
1073 | for (; col < src_end_col; col++) { |
1074 | std::int32_t accum = 0; |
1075 | for (int r = 0; r < 4; r++) { |
1076 | std::int8_t packed_val; |
1077 | if (block_row + r < src_rows) { |
1078 | packed_val = input_xor ^ src_ptr[r * src_stride]; |
1079 | } else { |
1080 | packed_val = input_xor ^ src_zero_point; |
1081 | } |
1082 | accum += packed_val; |
1083 | *packed_ptr++ = packed_val; |
1084 | } |
1085 | if (sums) { |
1086 | sums[col] += accum; |
1087 | } |
1088 | src_ptr++; |
1089 | } |
1090 | for (; col < end_col; col++) { |
1091 | std::memset(packed_ptr, 0, 4); |
1092 | packed_ptr += 4; |
1093 | } |
1094 | } |
1095 | |
#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
1097 | |
1098 | } // namespace ruy |
1099 | |