1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | #include <cstring> |
18 | |
19 | #include "ruy/check_macros.h" |
20 | #include "ruy/opt_set.h" |
21 | #include "ruy/pack_x86.h" |
22 | #include "ruy/path.h" |
23 | #include "ruy/platform.h" |
24 | #include "ruy/profiler/instrumentation.h" |
25 | |
26 | #if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS) |
27 | #include <immintrin.h> // IWYU pragma: keep |
28 | #endif |
29 | |
30 | namespace ruy { |
31 | |
32 | #if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) |
33 | |
34 | void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t, |
35 | const std::int8_t*, int, int, int, std::int8_t*, |
36 | std::int32_t*) { |
37 | // CPU-ID-based checks should disable the path that would reach this point. |
38 | RUY_DCHECK(false); |
39 | } |
40 | |
41 | void PackFloatColMajorForAvx2(const float*, const float*, int, int, int, |
42 | float*) { |
43 | // CPU-ID-based checks should disable the path that would reach this point. |
44 | RUY_DCHECK(false); |
45 | } |
46 | |
47 | void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int, |
48 | int, int, int, int, int, int, std::int32_t*) { |
49 | RUY_DCHECK(false); |
50 | } |
51 | |
52 | #else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) |
53 | |
54 | // The first int8_t template parameter is arbitrary: this routine is common to |
55 | // all 8-bit source matrix types. |
56 | using PackImpl8bitAvx2 = |
57 | PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, |
58 | std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>; |
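// A note on the layout above: FixedKernelLayout<Order::kColMajor, 4, 8> means
// the packed data is written in cells of Layout::kRows x Layout::kCols = 4x8.
// For each group of 4 source rows, the 4 values of column 0 are stored
// contiguously, then the 4 values of column 1, and so on across 8 columns, as
// realized by the packer below.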
59 | |
60 | using PackImplFloatAvx2 = |
61 | PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, |
62 | float, float, Order::kColMajor>; |
63 | |
64 | namespace { |
65 | |
66 | inline void Pack8bitColMajorForAvx2Packer( |
67 | const std::int8_t* src_ptr, std::int8_t input_xor, |
68 | const std::int8_t* zerobuf, int src_stride, int remaining_src_cols, |
69 | int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr, |
70 | std::int8_t* trailing_buf) { |
71 | using Layout = PackImpl8bitAvx2::Layout; |
72 | RUY_DCHECK_EQ(Layout::kCols, 8); |
73 | RUY_DCHECK_EQ(Layout::kRows, 4); |
  // Each Layout::kRows group is 4 elements that are contiguous in the input
  // and stay contiguous in the packed output.
75 | // We process 8 of these chunks at a time, padding short input chunks. |
76 | constexpr int kNumRowChunks = 8; |
77 | constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; |
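  // With kNumRowChunks = 8 and Layout::kRows = 4, each iteration of the main
  // loop consumes kNumChunkedSrcRows = 32 source rows per column, i.e. one
  // 32-byte vector load per column.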
78 | |
79 | const std::int8_t* src_ptr0 = src_ptr; |
80 | const std::int8_t* src_ptr1 = src_ptr0 + src_stride; |
81 | const std::int8_t* src_ptr2 = src_ptr1 + src_stride; |
82 | const std::int8_t* src_ptr3 = src_ptr2 + src_stride; |
83 | const std::int8_t* src_ptr4 = src_ptr3 + src_stride; |
84 | const std::int8_t* src_ptr5 = src_ptr4 + src_stride; |
85 | const std::int8_t* src_ptr6 = src_ptr5 + src_stride; |
86 | const std::int8_t* src_ptr7 = src_ptr6 + src_stride; |
87 | std::int64_t src_inc0 = kNumChunkedSrcRows; |
88 | std::int64_t src_inc1 = kNumChunkedSrcRows; |
89 | std::int64_t src_inc2 = kNumChunkedSrcRows; |
90 | std::int64_t src_inc3 = kNumChunkedSrcRows; |
91 | std::int64_t src_inc4 = kNumChunkedSrcRows; |
92 | std::int64_t src_inc5 = kNumChunkedSrcRows; |
93 | std::int64_t src_inc6 = kNumChunkedSrcRows; |
94 | std::int64_t src_inc7 = kNumChunkedSrcRows; |
95 | // Handle cases where source does not have Layout::kCols (8) columns. |
96 | if (remaining_src_cols < 8) { |
97 | if (remaining_src_cols <= 0) { |
98 | src_ptr0 = zerobuf; |
99 | src_inc0 = 0; |
100 | } |
101 | if (remaining_src_cols <= 1) { |
102 | src_ptr1 = zerobuf; |
103 | src_inc1 = 0; |
104 | } |
105 | if (remaining_src_cols <= 2) { |
106 | src_ptr2 = zerobuf; |
107 | src_inc2 = 0; |
108 | } |
109 | if (remaining_src_cols <= 3) { |
110 | src_ptr3 = zerobuf; |
111 | src_inc3 = 0; |
112 | } |
113 | if (remaining_src_cols <= 4) { |
114 | src_ptr4 = zerobuf; |
115 | src_inc4 = 0; |
116 | } |
117 | if (remaining_src_cols <= 5) { |
118 | src_ptr5 = zerobuf; |
119 | src_inc5 = 0; |
120 | } |
121 | if (remaining_src_cols <= 6) { |
122 | src_ptr6 = zerobuf; |
123 | src_inc6 = 0; |
124 | } |
125 | src_ptr7 = zerobuf; |
126 | src_inc7 = 0; |
127 | } |
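  // Columns at or beyond remaining_src_cols are redirected to zerobuf (which
  // the caller fills with the zero point) and get a zero pointer increment,
  // so they pack as constant padding without extra branching in the loop.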
128 | |
129 | const std::int8_t zero_point = zerobuf[0]; |
130 | |
131 | if (sums_ptr) { |
132 | // i: Layout::kCols. |
133 | for (int i = 0; i < 8; ++i) { |
134 | sums_ptr[i] = 0; |
135 | } |
136 | } |
137 | std::int32_t sums_adjustment = 0; |
138 | const __m256i ones_16bit = _mm256_set1_epi16(1); |
139 | __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); |
140 | __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); |
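  // sums_4x2_32bit_lo accumulates two partial 32-bit sums for each of packed
  // columns 0..3, and sums_4x2_32bit_hi does the same for columns 4..7; they
  // are folded into sums_ptr after the loop.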
141 | |
  // The overall packing effectively pads the source rows to a multiple of
  // kNumChunkedSrcRows (32): the loop below consumes 32 source rows per
  // iteration. A final, incomplete chunk is stored into trailing_buf instead
  // of packed_ptr; the caller copies back only the rows it needs, padded to a
  // multiple of Layout::kRows.
146 | for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { |
    // Source rows still available at this chunk. This is always positive
    // because of the loop condition; when it is smaller than
    // kNumChunkedSrcRows we take the masked-load path below and store the
    // chunk into trailing_buf.
151 | const int available_src_rows = src_rows - k; |
    // Effectively,
    //   available rows = std::min(kNumChunkedSrcRows, src_rows - k);
    // treat the full-chunk and partial-chunk cases separately.
155 | if (available_src_rows >= kNumChunkedSrcRows) { |
156 | if (sums_ptr) { |
157 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
158 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
159 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
160 | |
161 | t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); |
162 | t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); |
163 | t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); |
164 | t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); |
165 | t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); |
166 | t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); |
167 | t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); |
168 | t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); |
169 | |
170 | r0 = _mm256_unpacklo_epi32(t0, t1); |
171 | r4 = _mm256_unpacklo_epi32(t4, t5); |
172 | r2 = _mm256_unpackhi_epi32(t0, t1); |
173 | r6 = _mm256_unpackhi_epi32(t4, t5); |
174 | r1 = _mm256_unpacklo_epi32(t2, t3); |
175 | r5 = _mm256_unpacklo_epi32(t6, t7); |
176 | r3 = _mm256_unpackhi_epi32(t2, t3); |
177 | r7 = _mm256_unpackhi_epi32(t6, t7); |
178 | |
179 | t0 = _mm256_unpacklo_epi64(r0, r1); |
180 | t4 = _mm256_unpacklo_epi64(r4, r5); |
181 | t2 = _mm256_unpackhi_epi64(r0, r1); |
182 | t6 = _mm256_unpackhi_epi64(r4, r5); |
183 | t1 = _mm256_unpacklo_epi64(r2, r3); |
184 | t5 = _mm256_unpacklo_epi64(r6, r7); |
185 | t3 = _mm256_unpackhi_epi64(r2, r3); |
186 | t7 = _mm256_unpackhi_epi64(r6, r7); |
187 | |
        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
        // For instance, (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX2 is built around 128-bit
        // lanes.
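        // In _mm256_permute2x128_si256, imm8 0x20 selects the low 128-bit
        // lane of each operand and 0x31 selects the high lane of each.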
193 | r0 = _mm256_permute2x128_si256(t0, t4, 0x20); |
194 | r4 = _mm256_permute2x128_si256(t1, t5, 0x20); |
195 | r1 = _mm256_permute2x128_si256(t0, t4, 0x31); |
196 | r5 = _mm256_permute2x128_si256(t1, t5, 0x31); |
197 | r2 = _mm256_permute2x128_si256(t2, t6, 0x20); |
198 | r6 = _mm256_permute2x128_si256(t3, t7, 0x20); |
199 | r3 = _mm256_permute2x128_si256(t2, t6, 0x31); |
200 | r7 = _mm256_permute2x128_si256(t3, t7, 0x31); |
201 | |
202 | r0 = _mm256_xor_si256(r0, input_xor_v); |
203 | r1 = _mm256_xor_si256(r1, input_xor_v); |
204 | r2 = _mm256_xor_si256(r2, input_xor_v); |
205 | r3 = _mm256_xor_si256(r3, input_xor_v); |
206 | r4 = _mm256_xor_si256(r4, input_xor_v); |
207 | r5 = _mm256_xor_si256(r5, input_xor_v); |
208 | r6 = _mm256_xor_si256(r6, input_xor_v); |
209 | r7 = _mm256_xor_si256(r7, input_xor_v); |
210 | |
211 | __m256i sums_4x4_16bit_lo; |
212 | sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); |
213 | sums_4x4_16bit_lo = |
214 | _mm256_add_epi16(sums_4x4_16bit_lo, |
215 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); |
216 | sums_4x4_16bit_lo = |
217 | _mm256_add_epi16(sums_4x4_16bit_lo, |
218 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); |
219 | sums_4x4_16bit_lo = |
220 | _mm256_add_epi16(sums_4x4_16bit_lo, |
221 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); |
222 | sums_4x4_16bit_lo = |
223 | _mm256_add_epi16(sums_4x4_16bit_lo, |
224 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); |
225 | sums_4x4_16bit_lo = |
226 | _mm256_add_epi16(sums_4x4_16bit_lo, |
227 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); |
228 | sums_4x4_16bit_lo = |
229 | _mm256_add_epi16(sums_4x4_16bit_lo, |
230 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); |
231 | sums_4x4_16bit_lo = |
232 | _mm256_add_epi16(sums_4x4_16bit_lo, |
233 | _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); |
234 | |
235 | // The sums have been performed across columns, and now we have 4x16-bit |
236 | // sums packed together. We use madd for pairwise 32-bit sums. |
237 | const __m256i sums_4x2_32bit_lo_new = |
238 | _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); |
239 | sums_4x2_32bit_lo = |
240 | _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); |
241 | |
242 | __m256i sums_4x4_16bit_hi; |
243 | sums_4x4_16bit_hi = |
244 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); |
245 | sums_4x4_16bit_hi = _mm256_add_epi16( |
246 | sums_4x4_16bit_hi, |
247 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); |
248 | sums_4x4_16bit_hi = _mm256_add_epi16( |
249 | sums_4x4_16bit_hi, |
250 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); |
251 | sums_4x4_16bit_hi = _mm256_add_epi16( |
252 | sums_4x4_16bit_hi, |
253 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); |
254 | sums_4x4_16bit_hi = _mm256_add_epi16( |
255 | sums_4x4_16bit_hi, |
256 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); |
257 | sums_4x4_16bit_hi = _mm256_add_epi16( |
258 | sums_4x4_16bit_hi, |
259 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); |
260 | sums_4x4_16bit_hi = _mm256_add_epi16( |
261 | sums_4x4_16bit_hi, |
262 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); |
263 | sums_4x4_16bit_hi = _mm256_add_epi16( |
264 | sums_4x4_16bit_hi, |
265 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); |
266 | |
267 | const __m256i sums_4x2_32bit_hi_new = |
268 | _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); |
269 | sums_4x2_32bit_hi = |
270 | _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); |
271 | |
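        // At this point r0, r2, r4, r6 hold rows k..k+3, k+4..k+7, k+8..k+11,
        // k+12..k+15 (each for all 8 columns) and r1, r3, r5, r7 hold rows
        // k+16..k+31; the staggered store offsets below write the eight
        // 32-byte chunks back in row order.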
272 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), |
273 | r0); |
274 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), |
275 | r4); |
276 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), |
277 | r1); |
278 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), |
279 | r5); |
280 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), |
281 | r2); |
282 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), |
283 | r6); |
284 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), |
285 | r3); |
286 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), |
287 | r7); |
288 | } else { |
289 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
290 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
291 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
292 | |
293 | t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); |
294 | t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); |
295 | t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); |
296 | t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); |
297 | t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); |
298 | t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); |
299 | t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); |
300 | t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); |
301 | |
302 | r0 = _mm256_unpacklo_epi32(t0, t1); |
303 | r4 = _mm256_unpacklo_epi32(t4, t5); |
304 | r2 = _mm256_unpackhi_epi32(t0, t1); |
305 | r6 = _mm256_unpackhi_epi32(t4, t5); |
306 | r1 = _mm256_unpacklo_epi32(t2, t3); |
307 | r5 = _mm256_unpacklo_epi32(t6, t7); |
308 | r3 = _mm256_unpackhi_epi32(t2, t3); |
309 | r7 = _mm256_unpackhi_epi32(t6, t7); |
310 | |
311 | t0 = _mm256_unpacklo_epi64(r0, r1); |
312 | t4 = _mm256_unpacklo_epi64(r4, r5); |
313 | t2 = _mm256_unpackhi_epi64(r0, r1); |
314 | t6 = _mm256_unpackhi_epi64(r4, r5); |
315 | t1 = _mm256_unpacklo_epi64(r2, r3); |
316 | t5 = _mm256_unpacklo_epi64(r6, r7); |
317 | t3 = _mm256_unpackhi_epi64(r2, r3); |
318 | t7 = _mm256_unpackhi_epi64(r6, r7); |
319 | |
        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
        // For instance, (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX2 is built around 128-bit
        // lanes.
325 | r0 = _mm256_permute2x128_si256(t0, t4, 0x20); |
326 | r4 = _mm256_permute2x128_si256(t1, t5, 0x20); |
327 | r1 = _mm256_permute2x128_si256(t0, t4, 0x31); |
328 | r5 = _mm256_permute2x128_si256(t1, t5, 0x31); |
329 | r2 = _mm256_permute2x128_si256(t2, t6, 0x20); |
330 | r6 = _mm256_permute2x128_si256(t3, t7, 0x20); |
331 | r3 = _mm256_permute2x128_si256(t2, t6, 0x31); |
332 | r7 = _mm256_permute2x128_si256(t3, t7, 0x31); |
333 | |
334 | r0 = _mm256_xor_si256(r0, input_xor_v); |
335 | r1 = _mm256_xor_si256(r1, input_xor_v); |
336 | r2 = _mm256_xor_si256(r2, input_xor_v); |
337 | r3 = _mm256_xor_si256(r3, input_xor_v); |
338 | r4 = _mm256_xor_si256(r4, input_xor_v); |
339 | r5 = _mm256_xor_si256(r5, input_xor_v); |
340 | r6 = _mm256_xor_si256(r6, input_xor_v); |
341 | r7 = _mm256_xor_si256(r7, input_xor_v); |
342 | |
343 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), |
344 | r0); |
345 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), |
346 | r4); |
347 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), |
348 | r1); |
349 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), |
350 | r5); |
351 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), |
352 | r2); |
353 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), |
354 | r6); |
355 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), |
356 | r3); |
357 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), |
358 | r7); |
359 | } |
360 | } else if (available_src_rows > 0) { |
361 | RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); |
362 | // We do not care what goes into the trailing buffer, but we want |
363 | // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. |
364 | // |
365 | // We compensate for padding-with-zero_point by initializing the |
366 | // summations with the compensating offset, effectively |
367 | // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * |
368 | // 4 * (8 - ((available_src_rows + 3) >> 2)). |
369 | // |
370 | // Note that (zero_point ^ input_xor) is performed in 8-bits and then |
371 | // cast. |
372 | sums_adjustment += |
373 | -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); |
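      // For example, with available_src_rows = 10: (10 + 3) >> 2 = 3 four-row
      // groups are actually needed, so the 8 - 3 = 5 all-padding groups
      // contribute -(zero_point ^ input_xor) * 4 * 5 to each column's sum
      // adjustment.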
374 | |
375 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
376 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
377 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
378 | |
379 | t0 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr0); |
380 | t4 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr4); |
381 | t1 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr1); |
382 | t5 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr5); |
383 | t2 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr2); |
384 | t6 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr6); |
385 | t3 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr3); |
386 | t7 = MaskLoadu<Path::kAvx2Fma>(available_src_rows, zero_point, src_ptr7); |
387 | |
388 | r0 = _mm256_unpacklo_epi32(t0, t1); |
389 | r4 = _mm256_unpacklo_epi32(t4, t5); |
390 | r2 = _mm256_unpackhi_epi32(t0, t1); |
391 | r6 = _mm256_unpackhi_epi32(t4, t5); |
392 | r1 = _mm256_unpacklo_epi32(t2, t3); |
393 | r5 = _mm256_unpacklo_epi32(t6, t7); |
394 | r3 = _mm256_unpackhi_epi32(t2, t3); |
395 | r7 = _mm256_unpackhi_epi32(t6, t7); |
396 | |
397 | t0 = _mm256_unpacklo_epi64(r0, r1); |
398 | t4 = _mm256_unpacklo_epi64(r4, r5); |
399 | t2 = _mm256_unpackhi_epi64(r0, r1); |
400 | t6 = _mm256_unpackhi_epi64(r4, r5); |
401 | t1 = _mm256_unpacklo_epi64(r2, r3); |
402 | t5 = _mm256_unpacklo_epi64(r6, r7); |
403 | t3 = _mm256_unpackhi_epi64(r2, r3); |
404 | t7 = _mm256_unpackhi_epi64(r6, r7); |
405 | |
      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX2 is built around 128-bit lanes.
411 | r0 = _mm256_permute2x128_si256(t0, t4, 0x20); |
412 | r4 = _mm256_permute2x128_si256(t1, t5, 0x20); |
413 | r1 = _mm256_permute2x128_si256(t0, t4, 0x31); |
414 | r5 = _mm256_permute2x128_si256(t1, t5, 0x31); |
415 | r2 = _mm256_permute2x128_si256(t2, t6, 0x20); |
416 | r6 = _mm256_permute2x128_si256(t3, t7, 0x20); |
417 | r3 = _mm256_permute2x128_si256(t2, t6, 0x31); |
418 | r7 = _mm256_permute2x128_si256(t3, t7, 0x31); |
419 | |
420 | r0 = _mm256_xor_si256(r0, input_xor_v); |
421 | r1 = _mm256_xor_si256(r1, input_xor_v); |
422 | r2 = _mm256_xor_si256(r2, input_xor_v); |
423 | r3 = _mm256_xor_si256(r3, input_xor_v); |
424 | r4 = _mm256_xor_si256(r4, input_xor_v); |
425 | r5 = _mm256_xor_si256(r5, input_xor_v); |
426 | r6 = _mm256_xor_si256(r6, input_xor_v); |
427 | r7 = _mm256_xor_si256(r7, input_xor_v); |
428 | |
429 | __m256i sums_4x4_16bit_lo; |
430 | sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); |
431 | sums_4x4_16bit_lo = _mm256_add_epi16( |
432 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); |
433 | sums_4x4_16bit_lo = _mm256_add_epi16( |
434 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); |
435 | sums_4x4_16bit_lo = _mm256_add_epi16( |
436 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); |
437 | sums_4x4_16bit_lo = _mm256_add_epi16( |
438 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); |
439 | sums_4x4_16bit_lo = _mm256_add_epi16( |
440 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); |
441 | sums_4x4_16bit_lo = _mm256_add_epi16( |
442 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); |
443 | sums_4x4_16bit_lo = _mm256_add_epi16( |
444 | sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); |
445 | |
446 | // The sums have been performed across columns, and now we have 4x16-bit |
447 | // sums packed together. We use madd for pairwise 32-bit sums. |
448 | const __m256i sums_4x2_32bit_lo_new = |
449 | _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); |
450 | sums_4x2_32bit_lo = |
451 | _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); |
452 | |
453 | __m256i sums_4x4_16bit_hi; |
454 | sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); |
455 | sums_4x4_16bit_hi = _mm256_add_epi16( |
456 | sums_4x4_16bit_hi, |
457 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); |
458 | sums_4x4_16bit_hi = _mm256_add_epi16( |
459 | sums_4x4_16bit_hi, |
460 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); |
461 | sums_4x4_16bit_hi = _mm256_add_epi16( |
462 | sums_4x4_16bit_hi, |
463 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); |
464 | sums_4x4_16bit_hi = _mm256_add_epi16( |
465 | sums_4x4_16bit_hi, |
466 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); |
467 | sums_4x4_16bit_hi = _mm256_add_epi16( |
468 | sums_4x4_16bit_hi, |
469 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); |
470 | sums_4x4_16bit_hi = _mm256_add_epi16( |
471 | sums_4x4_16bit_hi, |
472 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); |
473 | sums_4x4_16bit_hi = _mm256_add_epi16( |
474 | sums_4x4_16bit_hi, |
475 | _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); |
476 | |
477 | const __m256i sums_4x2_32bit_hi_new = |
478 | _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); |
479 | sums_4x2_32bit_hi = |
480 | _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); |
481 | |
482 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), |
483 | r0); |
484 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), |
485 | r4); |
486 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), |
487 | r1); |
488 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), |
489 | r5); |
490 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), |
491 | r2); |
492 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), |
493 | r6); |
494 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), |
495 | r3); |
496 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), |
497 | r7); |
498 | } |
499 | |
500 | packed_ptr += 8 * kNumChunkedSrcRows; |
501 | src_ptr0 += src_inc0; |
502 | src_ptr1 += src_inc1; |
503 | src_ptr2 += src_inc2; |
504 | src_ptr3 += src_inc3; |
505 | src_ptr4 += src_inc4; |
506 | src_ptr5 += src_inc5; |
507 | src_ptr6 += src_inc6; |
508 | src_ptr7 += src_inc7; |
509 | } |
510 | |
511 | if (sums_ptr) { |
512 | const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); |
513 | |
514 | __m256i sums = |
515 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); |
516 | const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); |
517 | |
    // We earlier used madd for pairwise 32-bit sums, and now we deinterleave
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
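    // Each of sums_4x2_32bit_{lo,hi} holds two partial sums per column
    // (columns 0..3 in lo, 4..7 in hi). The permutes below regroup them so
    // that sums_2x4_32bit_a and sums_2x4_32bit_b each hold one partial sum
    // for all 8 columns, ready to be added into sums.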
520 | const __m256i sums_2x4_32bit_lo = |
521 | _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); |
522 | const __m256i sums_2x4_32bit_hi = |
523 | _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); |
524 | const __m256i sums_2x4_32bit_a = |
525 | _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); |
526 | const __m256i sums_2x4_32bit_b = |
527 | _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); |
528 | sums = _mm256_add_epi32(sums, sums_adjustment_v); |
529 | sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); |
530 | sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); |
531 | |
532 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); |
533 | } |
534 | } |
535 | |
// Use the AVX2-specific intrinsic for greater-than comparison.
537 | template <> |
538 | inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a, |
539 | const __m256i& b) { |
540 | return _mm256_cmpgt_epi32(a, b); |
541 | } |
542 | |
}  // namespace
544 | |
545 | void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, |
546 | const std::int8_t* zerobuf, int src_stride, |
547 | int remaining_src_cols, int src_rows, |
548 | std::int8_t* packed_ptr, std::int32_t* sums_ptr) { |
  profiler::ScopeLabel label("Pack kAvx2Fma 8bit");
550 | |
551 | using Layout = PackImpl8bitAvx2::Layout; |
552 | RUY_DCHECK_EQ(Layout::kCols, 8); |
553 | RUY_DCHECK_EQ(Layout::kRows, 4); |
554 | |
  // Each Layout::kRows group is 4 elements that are contiguous in the input
  // and stay contiguous in the packed output.
556 | // We process 8 of these chunks at a time, padding short input chunks. |
557 | static constexpr int kNumRowChunks = 8; // Short input is padded. |
558 | |
559 | // Each packed block is 4*8, and there are normally 8. The trailing block is |
560 | // only slightly shorter. |
561 | constexpr int kTrailingBufSize = |
562 | kNumRowChunks * Layout::kCols * Layout::kRows; |
563 | std::int8_t trailing_buf[kTrailingBufSize]; |
564 | memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); |
565 | |
566 | Pack8bitColMajorForAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, |
567 | remaining_src_cols, src_rows, packed_ptr, |
568 | sums_ptr, trailing_buf); |
569 | |
570 | constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; |
571 | const bool trailing_data = (src_rows & kChunkedRowMask) > 0; |
  // If the number of source rows is not a multiple of kChunkedRowMask + 1,
  // there is data in the trailing buffer.
574 | if (trailing_data) { |
575 | const int non_trailing_rows = src_rows & ~kChunkedRowMask; |
576 | // Destination "rows" are padded to next highest multiple of Layout::kRows. |
577 | const int dst_rows = (src_rows + 3) & ~3; |
578 | const int trailing_rows = dst_rows - non_trailing_rows; |
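    // For example, with src_rows = 70: non_trailing_rows = 64 and
    // dst_rows = 72, so 8 trailing (partly padded) rows per column are copied
    // from trailing_buf.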
579 | memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, |
580 | Layout::kCols * trailing_rows * sizeof(std::int8_t)); |
581 | } |
582 | } |
583 | |
584 | void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf, |
585 | int src_stride, int remaining_src_cols, |
586 | int src_rows, float* packed_ptr) { |
  profiler::ScopeLabel label("Pack kAvx2Fma float");
588 | static constexpr int kPackCols = 8; // Source cols packed together. |
589 | static constexpr int kPackRows = 8; // Short input is padded. |
590 | float trailing_buf[(kPackRows - 1) * kPackCols]; |
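  // trailing_buf receives any final partial block of fewer than kPackRows
  // rows from the common packer; at most kPackRows - 1 such rows are copied
  // back below.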
591 | if (remaining_src_cols < 8) { |
592 | memset(trailing_buf, 0, sizeof(trailing_buf)); |
593 | } |
594 | PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx2, Path::kAvx2Fma>( |
595 | src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, |
596 | trailing_buf); |
597 | |
598 | const int trailing_rows = src_rows & (kPackRows - 1); |
599 | if (trailing_rows > 0) { |
600 | const int non_trailing_rows = src_rows & ~(kPackRows - 1); |
601 | memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, |
602 | kPackCols * trailing_rows * sizeof(float)); |
603 | } |
604 | } |
605 | |
606 | void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride, |
607 | int src_zero_point, std::int8_t* packed_ptr, |
608 | int packed_stride, int start_col, int end_col, |
609 | int src_cols, int block_row, int src_rows, |
610 | int input_xor, std::int32_t* sums) { |
611 | int col = start_col; |
612 | int src_end_col = std::min(end_col, src_cols); |
613 | |
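  // Each iteration of this loop handles 8 source columns: load a 4x8 block
  // starting at block_row, flip the sign bit if converting uint8 -> int8,
  // accumulate the 8 column sums, and store the block transposed into
  // column-of-4 packed order.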
614 | for (; col <= src_end_col - 8; col += 8) { |
615 | std::int8_t* dst_ptr = packed_ptr; |
616 | __m128i val0, val1, val2, val3; |
617 | __m128i input_xor_dup = _mm_set1_epi8(input_xor); |
618 | // Load a 4x8 block. |
619 | if (block_row + 4 <= src_rows) { |
620 | val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); |
621 | val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); |
622 | val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); |
623 | val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); |
624 | } else { |
625 | val0 = _mm_set1_epi8(src_zero_point); |
626 | val1 = val0; |
627 | val2 = val0; |
628 | val3 = val0; |
629 | if (block_row + 0 < src_rows) |
630 | val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); |
631 | if (block_row + 1 < src_rows) |
632 | val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); |
633 | if (block_row + 2 < src_rows) |
634 | val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); |
635 | if (block_row + 3 < src_rows) |
636 | val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); |
637 | } |
638 | // Maybe xor the sign bit to convert from uint8 to int8. |
639 | val0 = _mm_xor_si128(val0, input_xor_dup); |
640 | val1 = _mm_xor_si128(val1, input_xor_dup); |
641 | val2 = _mm_xor_si128(val2, input_xor_dup); |
642 | val3 = _mm_xor_si128(val3, input_xor_dup); |
643 | // Update the sums. |
644 | __m128i val16_0 = _mm_cvtepi8_epi16(val0); |
645 | __m128i val16_1 = _mm_cvtepi8_epi16(val1); |
646 | __m128i val16_2 = _mm_cvtepi8_epi16(val2); |
647 | __m128i val16_3 = _mm_cvtepi8_epi16(val3); |
648 | __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1), |
649 | _mm_add_epi16(val16_2, val16_3)); |
650 | __m256i sum = |
651 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col)); |
652 | sum = _mm256_add_epi32(sum, _mm256_cvtepi16_epi32(new_sum16)); |
653 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum); |
654 | // Perform the transposition of 4x4 blocks |
655 | __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1); |
656 | __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3); |
657 | __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1); |
658 | __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1); |
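    // t4_val0 now holds columns 0..3 (4 consecutive row values per column)
    // and t4_val1 holds columns 4..7.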
659 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0); |
660 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1); |
661 | src_ptr += 8; |
662 | packed_ptr += packed_stride * 8; |
663 | } |
664 | for (; col < src_end_col; col++) { |
665 | std::int32_t accum = 0; |
666 | for (int r = 0; r < 4; r++) { |
667 | std::int8_t packed_val; |
668 | if (block_row + r < src_rows) { |
669 | packed_val = input_xor ^ src_ptr[r * src_stride]; |
670 | } else { |
671 | packed_val = input_xor ^ src_zero_point; |
672 | } |
673 | accum += packed_val; |
674 | *packed_ptr++ = packed_val; |
675 | } |
676 | if (sums) { |
677 | sums[col] += accum; |
678 | } |
679 | src_ptr++; |
680 | } |
681 | for (; col < end_col; col++) { |
682 | std::memset(packed_ptr, 0, 4); |
683 | packed_ptr += 4; |
684 | } |
685 | } |
686 | |
#endif  // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
688 | |
689 | } // namespace ruy |
690 | |