1 | /* Copyright 2020 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | #include <cstring> |
18 | |
19 | #include "ruy/check_macros.h" |
20 | #include "ruy/opt_set.h" |
21 | #include "ruy/pack_x86.h" |
22 | #include "ruy/path.h" |
23 | #include "ruy/platform.h" |
24 | #include "ruy/profiler/instrumentation.h" |
25 | |
26 | #if RUY_PLATFORM_AVX && RUY_OPT(INTRINSICS) |
27 | #include <immintrin.h> // IWYU pragma: keep |
28 | #endif |
29 | |
30 | namespace ruy { |
31 | |
32 | #if !(RUY_PLATFORM_AVX && RUY_OPT(ASM)) |
33 | |
34 | void Pack8bitColMajorForAvx(const std::int8_t*, std::int8_t, const std::int8_t*, |
35 | int, int, int, std::int8_t*, std::int32_t*) { |
36 | // CPU-ID-based checks should disable the path that would reach this point. |
37 | RUY_DCHECK(false); |
38 | } |
39 | |
40 | void PackFloatColMajorForAvx(const float*, const float*, int, int, int, |
41 | float*) { |
42 | // CPU-ID-based checks should disable the path that would reach this point. |
43 | RUY_DCHECK(false); |
44 | } |
45 | |
46 | void Pack8bitRowMajorForAvx(const std::uint8_t*, int, int, std::int8_t*, int, |
47 | int, int, int, int, int, int, std::int32_t*) { |
48 | RUY_DCHECK(false); |
49 | } |
50 | |
51 | #else // RUY_PLATFORM_AVX && RUY_OPT(ASM) |
52 | |
53 | // The first int8_t template parameter is arbitrary: this routine is common to |
54 | // all 8-bit source matrix types. |
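// The FixedKernelLayout<Order::kColMajor, 4, 8> parameter means the packed
// data is laid out in blocks of Layout::kRows = 4 rows by Layout::kCols = 8
// columns, matching the RUY_DCHECKs below.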
55 | using PackImpl8bitAvx = |
56 | PackImpl<Path::kAvx, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t, |
57 | std::int8_t, std::int32_t, Order::kColMajor>; |
58 | |
59 | using PackImplFloatAvx = |
60 | PackImpl<Path::kAvx, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, |
61 | float, float, Order::kColMajor>; |
62 | |
63 | namespace { |
64 | |
65 | // Perform the equivalent of mm256_permutevar8x32 with |
66 | // a second argument of {7, 5, 3, 1, 6, 4, 2, 0} |
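// For illustration, with 32-bit elements e0..e7 stored from low to high:
//   a      = {e0, e1, e2, e3, e4, e5, e6, e7}
//   result = {e0, e2, e4, e6, e1, e3, e5, e7}
// i.e. even-indexed elements are gathered into the low 128-bit half and
// odd-indexed elements into the high half.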
67 | inline __m256i PermuteEpi32EvenOdds(const __m256i& a) { |
68 | // a_lo = 3 2 1 0 |
69 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
70 | // a_hi = 7 6 5 4 |
71 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
72 | // shuffle a_lo to get 3 1 2 0 |
73 | __m128i tmp_lo = _mm_shuffle_epi32(a_lo, 0xd8); |
74 | // shuffle a_hi to get 7 5 6 4 |
75 | __m128i tmp_hi = _mm_shuffle_epi32(a_hi, 0xd8); |
76 | // unpack lo 64 of res_lo and res hi to get 6 4 2 0 |
77 | __m128i res_lo = _mm_unpacklo_epi64(tmp_lo, tmp_hi); |
78 | // unpack hi 64 of res_lo and res hi to get 7 5 3 1 |
79 | __m128i res_hi = _mm_unpackhi_epi64(tmp_lo, tmp_hi); |
80 | return _mm256_set_m128i(res_hi, res_lo); |
81 | } |
82 | |
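// The helpers below emulate selected 256-bit integer intrinsics (which
// require AVX2) using two 128-bit SSE operations per call, since the plain
// kAvx path cannot assume AVX2 is available.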
inline __m128i mm256_extracti128_si256(const __m256i& a, const int imm) {
84 | switch (imm) { |
85 | case 0: |
86 | return _mm256_extractf128_si256(a, 0); |
87 | case 1: |
88 | return _mm256_extractf128_si256(a, 1); |
89 | default: |
90 | RUY_DCHECK_LT(imm, 2); |
91 | return _mm_setzero_si128(); |
92 | } |
93 | } |
94 | |
95 | inline __m256i mm256_cvtepi8_epi16(const __m128i& a) { |
  // Take the upper 64 bits of 'a' and put them in the lower 64 bits of 'hi'.
97 | __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); |
98 | return _mm256_set_m128i(_mm_cvtepi8_epi16(hi), _mm_cvtepi8_epi16(a)); |
99 | } |
100 | |
101 | inline __m256i mm256_cvtepi16_epi32(const __m128i& a) { |
  // Take the upper 64 bits of 'a' and put them in the lower 64 bits of 'hi'.
103 | __m128i hi = _mm_unpackhi_epi64(a, _mm_setzero_si128()); |
104 | return _mm256_set_m128i(_mm_cvtepi16_epi32(hi), _mm_cvtepi16_epi32(a)); |
105 | } |
106 | |
107 | inline __m256i mm256_xor_si256(const __m256i& a, const __m256i& b) { |
108 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
109 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
110 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
111 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
112 | __m128i lo = _mm_xor_si128(a_lo, b_lo); |
113 | __m128i hi = _mm_xor_si128(a_hi, b_hi); |
114 | return _mm256_set_m128i(hi, lo); |
115 | } |
116 | |
117 | inline __m256i mm256_unpacklo_epi32(const __m256i& a, const __m256i& b) { |
118 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
119 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
120 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
121 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
122 | __m128i lo = _mm_unpacklo_epi32(a_lo, b_lo); |
123 | __m128i hi = _mm_unpacklo_epi32(a_hi, b_hi); |
124 | return _mm256_set_m128i(hi, lo); |
125 | } |
126 | |
127 | inline __m256i mm256_unpacklo_epi64(const __m256i& a, const __m256i& b) { |
128 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
129 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
130 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
131 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
132 | __m128i lo = _mm_unpacklo_epi64(a_lo, b_lo); |
133 | __m128i hi = _mm_unpacklo_epi64(a_hi, b_hi); |
134 | return _mm256_set_m128i(hi, lo); |
135 | } |
136 | |
137 | inline __m256i mm256_unpackhi_epi32(const __m256i& a, const __m256i& b) { |
138 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
139 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
140 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
141 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
142 | __m128i lo = _mm_unpackhi_epi32(a_lo, b_lo); |
143 | __m128i hi = _mm_unpackhi_epi32(a_hi, b_hi); |
144 | return _mm256_set_m128i(hi, lo); |
145 | } |
146 | |
147 | inline __m256i mm256_unpackhi_epi64(const __m256i& a, const __m256i& b) { |
148 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
149 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
150 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
151 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
152 | __m128i lo = _mm_unpackhi_epi64(a_lo, b_lo); |
153 | __m128i hi = _mm_unpackhi_epi64(a_hi, b_hi); |
154 | return _mm256_set_m128i(hi, lo); |
155 | } |
156 | |
157 | inline __m256i mm256_add_epi32(const __m256i& a, const __m256i& b) { |
158 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
159 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
160 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
161 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
162 | __m128i lo = _mm_add_epi32(a_lo, b_lo); |
163 | __m128i hi = _mm_add_epi32(a_hi, b_hi); |
164 | return _mm256_set_m128i(hi, lo); |
165 | } |
166 | |
167 | inline __m256i mm256_add_epi16(const __m256i& a, const __m256i& b) { |
168 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
169 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
170 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
171 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
172 | __m128i lo = _mm_add_epi16(a_lo, b_lo); |
173 | __m128i hi = _mm_add_epi16(a_hi, b_hi); |
174 | return _mm256_set_m128i(hi, lo); |
175 | } |
176 | |
177 | inline __m256i mm256_madd_epi16(const __m256i& a, const __m256i& b) { |
178 | __m128i a_lo = _mm256_extractf128_si256(a, 0); |
179 | __m128i a_hi = _mm256_extractf128_si256(a, 1); |
180 | __m128i b_lo = _mm256_extractf128_si256(b, 0); |
181 | __m128i b_hi = _mm256_extractf128_si256(b, 1); |
182 | __m128i lo = _mm_madd_epi16(a_lo, b_lo); |
183 | __m128i hi = _mm_madd_epi16(a_hi, b_hi); |
184 | return _mm256_set_m128i(hi, lo); |
185 | } |
186 | |
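// The next two helpers emulate the AVX2 _mm256_permute2x128_si256 intrinsic.
// Each nibble of 'imm' selects one 128-bit half of the result: bits 1:0
// choose among {a.lo, a.hi, b.lo, b.hi} and bit 3 zeroes the half. For the
// two selectors used in this file:
//   imm = 0x20 -> {lo: a.lo, hi: b.lo}
//   imm = 0x31 -> {lo: a.hi, hi: b.hi}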
187 | inline __m128i mm_permute_helper(const __m256i& a, const __m256i& b, |
188 | const int imm) { |
189 | __m128i tmp = _mm_setzero_si128(); |
190 | if (!(imm & 8)) { |
191 | switch (imm & 3) { |
192 | case 0: |
193 | return _mm256_extractf128_si256(a, 0); |
194 | case 1: |
195 | return _mm256_extractf128_si256(a, 1); |
196 | case 2: |
197 | return _mm256_extractf128_si256(b, 0); |
198 | case 3: |
199 | return _mm256_extractf128_si256(b, 1); |
200 | } |
201 | } |
202 | return tmp; |
203 | } |
204 | |
205 | inline __m256i mm256_permute2x128_si256(const __m256i& a, const __m256i& b, |
206 | const int imm) { |
207 | const int lo_imm = imm & 15; |
208 | __m128i lo = mm_permute_helper(a, b, lo_imm); |
209 | const int hi_imm = (imm >> 4) & 15; |
210 | __m128i hi = mm_permute_helper(a, b, hi_imm); |
211 | return _mm256_set_m128i(hi, lo); |
212 | } |
213 | |
214 | inline void Pack8bitColMajorForAvxPacker(const std::int8_t* src_ptr, |
215 | std::int8_t input_xor, |
216 | const std::int8_t* zerobuf, |
217 | int src_stride, int remaining_src_cols, |
218 | int src_rows, std::int8_t* packed_ptr, |
219 | std::int32_t* sums_ptr, |
220 | std::int8_t* trailing_buf) { |
221 | using Layout = PackImpl8bitAvx::Layout; |
222 | RUY_DCHECK_EQ(Layout::kCols, 8); |
223 | RUY_DCHECK_EQ(Layout::kRows, 4); |
  // Each Layout::kRows chunk is 4 contiguous input elements that remain
  // contiguous when packed.
225 | // We process 8 of these chunks at a time, padding short input chunks. |
226 | constexpr int kNumRowChunks = 8; |
227 | constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; |
228 | |
229 | const std::int8_t* src_ptr0 = src_ptr; |
230 | const std::int8_t* src_ptr1 = src_ptr0 + src_stride; |
231 | const std::int8_t* src_ptr2 = src_ptr1 + src_stride; |
232 | const std::int8_t* src_ptr3 = src_ptr2 + src_stride; |
233 | const std::int8_t* src_ptr4 = src_ptr3 + src_stride; |
234 | const std::int8_t* src_ptr5 = src_ptr4 + src_stride; |
235 | const std::int8_t* src_ptr6 = src_ptr5 + src_stride; |
236 | const std::int8_t* src_ptr7 = src_ptr6 + src_stride; |
237 | std::int64_t src_inc0 = kNumChunkedSrcRows; |
238 | std::int64_t src_inc1 = kNumChunkedSrcRows; |
239 | std::int64_t src_inc2 = kNumChunkedSrcRows; |
240 | std::int64_t src_inc3 = kNumChunkedSrcRows; |
241 | std::int64_t src_inc4 = kNumChunkedSrcRows; |
242 | std::int64_t src_inc5 = kNumChunkedSrcRows; |
243 | std::int64_t src_inc6 = kNumChunkedSrcRows; |
244 | std::int64_t src_inc7 = kNumChunkedSrcRows; |
245 | // Handle cases where source does not have Layout::kCols (8) columns. |
246 | if (remaining_src_cols < 8) { |
247 | if (remaining_src_cols <= 0) { |
248 | src_ptr0 = zerobuf; |
249 | src_inc0 = 0; |
250 | } |
251 | if (remaining_src_cols <= 1) { |
252 | src_ptr1 = zerobuf; |
253 | src_inc1 = 0; |
254 | } |
255 | if (remaining_src_cols <= 2) { |
256 | src_ptr2 = zerobuf; |
257 | src_inc2 = 0; |
258 | } |
259 | if (remaining_src_cols <= 3) { |
260 | src_ptr3 = zerobuf; |
261 | src_inc3 = 0; |
262 | } |
263 | if (remaining_src_cols <= 4) { |
264 | src_ptr4 = zerobuf; |
265 | src_inc4 = 0; |
266 | } |
267 | if (remaining_src_cols <= 5) { |
268 | src_ptr5 = zerobuf; |
269 | src_inc5 = 0; |
270 | } |
271 | if (remaining_src_cols <= 6) { |
272 | src_ptr6 = zerobuf; |
273 | src_inc6 = 0; |
274 | } |
275 | src_ptr7 = zerobuf; |
276 | src_inc7 = 0; |
277 | } |
278 | |
279 | const std::int8_t zero_point = zerobuf[0]; |
280 | |
281 | if (sums_ptr) { |
282 | // i: Layout::kCols. |
283 | for (int i = 0; i < 8; ++i) { |
284 | sums_ptr[i] = 0; |
285 | } |
286 | } |
287 | std::int32_t sums_adjustment = 0; |
288 | const __m256i ones_16bit = _mm256_set1_epi16(1); |
289 | __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); |
290 | __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); |
291 | |
  // The loop below processes the source in chunks of kNumChunkedSrcRows (32)
  // rows, so the packing effectively pads the source rows up to
  // (src_rows + 31) & ~31. When the last chunk is incomplete, it is stored
  // into trailing_buf instead of packed_ptr.
296 | for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { |
    // Available source rows.
    const int available_src_rows = src_rows - k;
    // Effectively,
    //   available rows =
    //       std::max(0, std::min(kNumChunkedSrcRows, src_rows - k));
    // and the full-chunk and partial-chunk cases are treated separately below.
305 | if (available_src_rows >= kNumChunkedSrcRows) { |
306 | if (sums_ptr) { |
307 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
308 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
309 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
310 | |
311 | t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); |
312 | t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); |
313 | t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); |
314 | t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); |
315 | t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); |
316 | t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); |
317 | t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); |
318 | t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); |
319 | |
320 | r0 = mm256_unpacklo_epi32(t0, t1); |
321 | r4 = mm256_unpacklo_epi32(t4, t5); |
322 | r2 = mm256_unpackhi_epi32(t0, t1); |
323 | r6 = mm256_unpackhi_epi32(t4, t5); |
324 | r1 = mm256_unpacklo_epi32(t2, t3); |
325 | r5 = mm256_unpacklo_epi32(t6, t7); |
326 | r3 = mm256_unpackhi_epi32(t2, t3); |
327 | r7 = mm256_unpackhi_epi32(t6, t7); |
328 | |
329 | t0 = mm256_unpacklo_epi64(r0, r1); |
330 | t4 = mm256_unpacklo_epi64(r4, r5); |
331 | t2 = mm256_unpackhi_epi64(r0, r1); |
332 | t6 = mm256_unpackhi_epi64(r4, r5); |
333 | t1 = mm256_unpacklo_epi64(r2, r3); |
334 | t5 = mm256_unpacklo_epi64(r6, r7); |
335 | t3 = mm256_unpackhi_epi64(r2, r3); |
336 | t7 = mm256_unpackhi_epi64(r6, r7); |
337 | |
        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
        // For instance, (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around 128-bit
        // lanes.
343 | r0 = mm256_permute2x128_si256(t0, t4, 0x20); |
344 | r4 = mm256_permute2x128_si256(t1, t5, 0x20); |
345 | r1 = mm256_permute2x128_si256(t0, t4, 0x31); |
346 | r5 = mm256_permute2x128_si256(t1, t5, 0x31); |
347 | r2 = mm256_permute2x128_si256(t2, t6, 0x20); |
348 | r6 = mm256_permute2x128_si256(t3, t7, 0x20); |
349 | r3 = mm256_permute2x128_si256(t2, t6, 0x31); |
350 | r7 = mm256_permute2x128_si256(t3, t7, 0x31); |
351 | |
352 | r0 = mm256_xor_si256(r0, input_xor_v); |
353 | r1 = mm256_xor_si256(r1, input_xor_v); |
354 | r2 = mm256_xor_si256(r2, input_xor_v); |
355 | r3 = mm256_xor_si256(r3, input_xor_v); |
356 | r4 = mm256_xor_si256(r4, input_xor_v); |
357 | r5 = mm256_xor_si256(r5, input_xor_v); |
358 | r6 = mm256_xor_si256(r6, input_xor_v); |
359 | r7 = mm256_xor_si256(r7, input_xor_v); |
360 | |
361 | __m256i sums_4x4_16bit_lo; |
362 | sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); |
363 | sums_4x4_16bit_lo = mm256_add_epi16( |
364 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); |
365 | sums_4x4_16bit_lo = mm256_add_epi16( |
366 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); |
367 | sums_4x4_16bit_lo = mm256_add_epi16( |
368 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); |
369 | sums_4x4_16bit_lo = mm256_add_epi16( |
370 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); |
371 | sums_4x4_16bit_lo = mm256_add_epi16( |
372 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); |
373 | sums_4x4_16bit_lo = mm256_add_epi16( |
374 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); |
375 | sums_4x4_16bit_lo = mm256_add_epi16( |
376 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); |
377 | |
378 | // The sums have been performed across columns, and now we have 4x16-bit |
379 | // sums packed together. We use madd for pairwise 32-bit sums. |
380 | const __m256i sums_4x2_32bit_lo_new = |
381 | mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); |
382 | sums_4x2_32bit_lo = |
383 | mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); |
384 | |
385 | __m256i sums_4x4_16bit_hi; |
386 | sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1)); |
387 | sums_4x4_16bit_hi = mm256_add_epi16( |
388 | sums_4x4_16bit_hi, |
389 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1))); |
390 | sums_4x4_16bit_hi = mm256_add_epi16( |
391 | sums_4x4_16bit_hi, |
392 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1))); |
393 | sums_4x4_16bit_hi = mm256_add_epi16( |
394 | sums_4x4_16bit_hi, |
395 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1))); |
396 | sums_4x4_16bit_hi = mm256_add_epi16( |
397 | sums_4x4_16bit_hi, |
398 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1))); |
399 | sums_4x4_16bit_hi = mm256_add_epi16( |
400 | sums_4x4_16bit_hi, |
401 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1))); |
402 | sums_4x4_16bit_hi = mm256_add_epi16( |
403 | sums_4x4_16bit_hi, |
404 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1))); |
405 | sums_4x4_16bit_hi = mm256_add_epi16( |
406 | sums_4x4_16bit_hi, |
407 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1))); |
408 | |
409 | const __m256i sums_4x2_32bit_hi_new = |
410 | mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); |
411 | sums_4x2_32bit_hi = |
412 | mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); |
413 | |
414 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), |
415 | r0); |
416 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), |
417 | r4); |
418 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), |
419 | r1); |
420 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), |
421 | r5); |
422 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), |
423 | r2); |
424 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), |
425 | r6); |
426 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), |
427 | r3); |
428 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), |
429 | r7); |
430 | } else { |
431 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
432 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
433 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
434 | |
435 | t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0)); |
436 | t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4)); |
437 | t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1)); |
438 | t5 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr5)); |
439 | t2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr2)); |
440 | t6 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr6)); |
441 | t3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr3)); |
442 | t7 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr7)); |
443 | |
444 | r0 = mm256_unpacklo_epi32(t0, t1); |
445 | r4 = mm256_unpacklo_epi32(t4, t5); |
446 | r2 = mm256_unpackhi_epi32(t0, t1); |
447 | r6 = mm256_unpackhi_epi32(t4, t5); |
448 | r1 = mm256_unpacklo_epi32(t2, t3); |
449 | r5 = mm256_unpacklo_epi32(t6, t7); |
450 | r3 = mm256_unpackhi_epi32(t2, t3); |
451 | r7 = mm256_unpackhi_epi32(t6, t7); |
452 | |
453 | t0 = mm256_unpacklo_epi64(r0, r1); |
454 | t4 = mm256_unpacklo_epi64(r4, r5); |
455 | t2 = mm256_unpackhi_epi64(r0, r1); |
456 | t6 = mm256_unpackhi_epi64(r4, r5); |
457 | t1 = mm256_unpacklo_epi64(r2, r3); |
458 | t5 = mm256_unpacklo_epi64(r6, r7); |
459 | t3 = mm256_unpackhi_epi64(r2, r3); |
460 | t7 = mm256_unpackhi_epi64(r6, r7); |
461 | |
        // The preceding sets of rearrangement operations interleaved by 4
        // bytes and then by 8 bytes *within* lanes. The following set
        // interleaves by 16 bytes (128 bits), operating *between* AVX lanes.
        // For instance, (t0, t4) are interleaved to create (r0, r1). This
        // complexity follows from the way that AVX is centered around 128-bit
        // lanes.
467 | r0 = mm256_permute2x128_si256(t0, t4, 0x20); |
468 | r4 = mm256_permute2x128_si256(t1, t5, 0x20); |
469 | r1 = mm256_permute2x128_si256(t0, t4, 0x31); |
470 | r5 = mm256_permute2x128_si256(t1, t5, 0x31); |
471 | r2 = mm256_permute2x128_si256(t2, t6, 0x20); |
472 | r6 = mm256_permute2x128_si256(t3, t7, 0x20); |
473 | r3 = mm256_permute2x128_si256(t2, t6, 0x31); |
474 | r7 = mm256_permute2x128_si256(t3, t7, 0x31); |
475 | |
476 | r0 = mm256_xor_si256(r0, input_xor_v); |
477 | r1 = mm256_xor_si256(r1, input_xor_v); |
478 | r2 = mm256_xor_si256(r2, input_xor_v); |
479 | r3 = mm256_xor_si256(r3, input_xor_v); |
480 | r4 = mm256_xor_si256(r4, input_xor_v); |
481 | r5 = mm256_xor_si256(r5, input_xor_v); |
482 | r6 = mm256_xor_si256(r6, input_xor_v); |
483 | r7 = mm256_xor_si256(r7, input_xor_v); |
484 | |
485 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), |
486 | r0); |
487 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), |
488 | r4); |
489 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), |
490 | r1); |
491 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), |
492 | r5); |
493 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), |
494 | r2); |
495 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), |
496 | r6); |
497 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), |
498 | r3); |
499 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), |
500 | r7); |
501 | } |
502 | } else if (available_src_rows > 0) { |
503 | RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); |
504 | // We do not care what goes into the trailing buffer, but we want |
505 | // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. |
506 | // |
507 | // We compensate for padding-with-zero_point by initializing the |
508 | // summations with the compensating offset, effectively |
509 | // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * |
510 | // 4 * (8 - ((available_src_rows + 3) >> 2)). |
511 | // |
512 | // Note that (zero_point ^ input_xor) is performed in 8-bits and then |
513 | // cast. |
514 | sums_adjustment += |
515 | -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); |
516 | |
517 | __m256i t0, t1, t2, t3, t4, t5, t6, t7; |
518 | __m256i r0, r1, r2, r3, r4, r5, r6, r7; |
519 | const __m256i input_xor_v = _mm256_set1_epi8(input_xor); |
520 | |
521 | t0 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr0); |
522 | t4 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr4); |
523 | t1 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr1); |
524 | t5 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr5); |
525 | t2 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr2); |
526 | t6 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr6); |
527 | t3 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr3); |
528 | t7 = MaskLoadu<Path::kAvx>(available_src_rows, zero_point, src_ptr7); |
529 | |
530 | r0 = mm256_unpacklo_epi32(t0, t1); |
531 | r4 = mm256_unpacklo_epi32(t4, t5); |
532 | r2 = mm256_unpackhi_epi32(t0, t1); |
533 | r6 = mm256_unpackhi_epi32(t4, t5); |
534 | r1 = mm256_unpacklo_epi32(t2, t3); |
535 | r5 = mm256_unpacklo_epi32(t6, t7); |
536 | r3 = mm256_unpackhi_epi32(t2, t3); |
537 | r7 = mm256_unpackhi_epi32(t6, t7); |
538 | |
539 | t0 = mm256_unpacklo_epi64(r0, r1); |
540 | t4 = mm256_unpacklo_epi64(r4, r5); |
541 | t2 = mm256_unpackhi_epi64(r0, r1); |
542 | t6 = mm256_unpackhi_epi64(r4, r5); |
543 | t1 = mm256_unpacklo_epi64(r2, r3); |
544 | t5 = mm256_unpacklo_epi64(r6, r7); |
545 | t3 = mm256_unpackhi_epi64(r2, r3); |
546 | t7 = mm256_unpackhi_epi64(r6, r7); |
547 | |
      // The preceding sets of rearrangement operations interleaved by 4 bytes
      // and then by 8 bytes *within* lanes. The following set interleaves by
      // 16 bytes (128 bits), operating *between* AVX lanes. For instance,
      // (t0, t4) are interleaved to create (r0, r1). This complexity follows
      // from the way that AVX is centered around 128-bit lanes.
553 | r0 = mm256_permute2x128_si256(t0, t4, 0x20); |
554 | r4 = mm256_permute2x128_si256(t1, t5, 0x20); |
555 | r1 = mm256_permute2x128_si256(t0, t4, 0x31); |
556 | r5 = mm256_permute2x128_si256(t1, t5, 0x31); |
557 | r2 = mm256_permute2x128_si256(t2, t6, 0x20); |
558 | r6 = mm256_permute2x128_si256(t3, t7, 0x20); |
559 | r3 = mm256_permute2x128_si256(t2, t6, 0x31); |
560 | r7 = mm256_permute2x128_si256(t3, t7, 0x31); |
561 | |
562 | r0 = mm256_xor_si256(r0, input_xor_v); |
563 | r1 = mm256_xor_si256(r1, input_xor_v); |
564 | r2 = mm256_xor_si256(r2, input_xor_v); |
565 | r3 = mm256_xor_si256(r3, input_xor_v); |
566 | r4 = mm256_xor_si256(r4, input_xor_v); |
567 | r5 = mm256_xor_si256(r5, input_xor_v); |
568 | r6 = mm256_xor_si256(r6, input_xor_v); |
569 | r7 = mm256_xor_si256(r7, input_xor_v); |
570 | |
571 | __m256i sums_4x4_16bit_lo; |
572 | sums_4x4_16bit_lo = mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); |
573 | sums_4x4_16bit_lo = mm256_add_epi16( |
574 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); |
575 | sums_4x4_16bit_lo = mm256_add_epi16( |
576 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); |
577 | sums_4x4_16bit_lo = mm256_add_epi16( |
578 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); |
579 | sums_4x4_16bit_lo = mm256_add_epi16( |
580 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); |
581 | sums_4x4_16bit_lo = mm256_add_epi16( |
582 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); |
583 | sums_4x4_16bit_lo = mm256_add_epi16( |
584 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); |
585 | sums_4x4_16bit_lo = mm256_add_epi16( |
586 | sums_4x4_16bit_lo, mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); |
587 | |
588 | // The sums have been performed across columns, and now we have 4x16-bit |
589 | // sums packed together. We use madd for pairwise 32-bit sums. |
590 | const __m256i sums_4x2_32bit_lo_new = |
591 | mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); |
592 | sums_4x2_32bit_lo = |
593 | mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); |
594 | |
595 | __m256i sums_4x4_16bit_hi; |
596 | sums_4x4_16bit_hi = mm256_cvtepi8_epi16(mm256_extracti128_si256(r0, 1)); |
597 | sums_4x4_16bit_hi = |
598 | mm256_add_epi16(sums_4x4_16bit_hi, |
599 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r1, 1))); |
600 | sums_4x4_16bit_hi = |
601 | mm256_add_epi16(sums_4x4_16bit_hi, |
602 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r2, 1))); |
603 | sums_4x4_16bit_hi = |
604 | mm256_add_epi16(sums_4x4_16bit_hi, |
605 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r3, 1))); |
606 | sums_4x4_16bit_hi = |
607 | mm256_add_epi16(sums_4x4_16bit_hi, |
608 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r4, 1))); |
609 | sums_4x4_16bit_hi = |
610 | mm256_add_epi16(sums_4x4_16bit_hi, |
611 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r5, 1))); |
612 | sums_4x4_16bit_hi = |
613 | mm256_add_epi16(sums_4x4_16bit_hi, |
614 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r6, 1))); |
615 | sums_4x4_16bit_hi = |
616 | mm256_add_epi16(sums_4x4_16bit_hi, |
617 | mm256_cvtepi8_epi16(mm256_extracti128_si256(r7, 1))); |
618 | |
619 | const __m256i sums_4x2_32bit_hi_new = |
620 | mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); |
621 | sums_4x2_32bit_hi = |
622 | mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); |
623 | |
624 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), |
625 | r0); |
626 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), |
627 | r4); |
628 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), |
629 | r1); |
630 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), |
631 | r5); |
632 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), |
633 | r2); |
634 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), |
635 | r6); |
636 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), |
637 | r3); |
638 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), |
639 | r7); |
640 | } |
641 | |
642 | packed_ptr += 8 * kNumChunkedSrcRows; |
643 | src_ptr0 += src_inc0; |
644 | src_ptr1 += src_inc1; |
645 | src_ptr2 += src_inc2; |
646 | src_ptr3 += src_inc3; |
647 | src_ptr4 += src_inc4; |
648 | src_ptr5 += src_inc5; |
649 | src_ptr6 += src_inc6; |
650 | src_ptr7 += src_inc7; |
651 | } |
652 | |
653 | if (sums_ptr) { |
654 | const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); |
655 | |
656 | __m256i sums = |
657 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr)); |
658 | |
    // We earlier used madd for pairwise 32-bit sums, and now we de-interleave
    // the neighbours, finishing up by adding them to the stored accumulated
    // sums.
661 | const __m256i sums_2x4_32bit_lo = PermuteEpi32EvenOdds(sums_4x2_32bit_lo); |
662 | const __m256i sums_2x4_32bit_hi = PermuteEpi32EvenOdds(sums_4x2_32bit_hi); |
663 | const __m256i sums_2x4_32bit_a = |
664 | mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); |
665 | const __m256i sums_2x4_32bit_b = |
666 | mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); |
667 | sums = mm256_add_epi32(sums, sums_adjustment_v); |
668 | sums = mm256_add_epi32(sums, sums_2x4_32bit_a); |
669 | sums = mm256_add_epi32(sums, sums_2x4_32bit_b); |
670 | |
671 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); |
672 | } |
673 | } |
674 | |
675 | // Use a generic AVX intrinsic for greater-than comparison. |
676 | template <> |
677 | inline __m256i CompareGreaterThan<Path::kAvx>(const __m256i& a, |
678 | const __m256i& b) { |
679 | constexpr int kGreaterThanSignalling = 14; |
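  // 14 is _CMP_GT_OS (greater-than, ordered, signaling). Note: the float
  // round trip is exact only for magnitudes within float's 24-bit mantissa,
  // which is expected to hold for the small counts this comparison is used
  // on.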
680 | const __m256 v = _mm256_cmp_ps(_mm256_cvtepi32_ps(a), _mm256_cvtepi32_ps(b), |
681 | kGreaterThanSignalling); |
682 | return _mm256_cvtps_epi32(v); |
683 | } |
684 | |
685 | } // namespace. |
686 | |
687 | void Pack8bitColMajorForAvx(const std::int8_t* src_ptr, std::int8_t input_xor, |
688 | const std::int8_t* zerobuf, int src_stride, |
689 | int remaining_src_cols, int src_rows, |
690 | std::int8_t* packed_ptr, std::int32_t* sums_ptr) { |
691 | profiler::ScopeLabel label("Pack kAvx 8bit" ); |
692 | |
693 | using Layout = PackImpl8bitAvx::Layout; |
694 | RUY_DCHECK_EQ(Layout::kCols, 8); |
695 | RUY_DCHECK_EQ(Layout::kRows, 4); |
696 | |
  // Each Layout::kRows chunk is 4 contiguous input elements that remain
  // contiguous when packed.
698 | // We process 8 of these chunks at a time, padding short input chunks. |
699 | static constexpr int kNumRowChunks = 8; // Short input is padded. |
700 | |
701 | // Each packed block is 4*8, and there are normally 8. The trailing block is |
702 | // only slightly shorter. |
703 | constexpr int kTrailingBufSize = |
704 | kNumRowChunks * Layout::kCols * Layout::kRows; |
705 | std::int8_t trailing_buf[kTrailingBufSize]; |
706 | memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); |
707 | |
708 | Pack8bitColMajorForAvxPacker(src_ptr, input_xor, zerobuf, src_stride, |
709 | remaining_src_cols, src_rows, packed_ptr, |
710 | sums_ptr, trailing_buf); |
711 | |
712 | constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; |
713 | const bool trailing_data = (src_rows & kChunkedRowMask) > 0; |
  // If the number of source rows is not a multiple of the chunked row count
  // (kNumRowChunks * Layout::kRows), there will be data in the trailing
  // buffer.
716 | if (trailing_data) { |
717 | const int non_trailing_rows = src_rows & ~kChunkedRowMask; |
718 | // Destination "rows" are padded to next highest multiple of Layout::kRows. |
719 | const int dst_rows = (src_rows + 3) & ~3; |
720 | const int trailing_rows = dst_rows - non_trailing_rows; |
721 | memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, |
722 | Layout::kCols * trailing_rows * sizeof(std::int8_t)); |
723 | } |
724 | } |
725 | |
726 | void PackFloatColMajorForAvx(const float* src_ptr, const float* zerobuf, |
727 | int src_stride, int remaining_src_cols, |
728 | int src_rows, float* packed_ptr) { |
729 | profiler::ScopeLabel label("Pack kAvx float" ); |
730 | static constexpr int kPackCols = 8; // Source cols packed together. |
731 | static constexpr int kPackRows = 8; // Short input is padded. |
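  // The trailing partial block holds at most (kPackRows - 1) rows, hence the
  // size of this buffer.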
732 | float trailing_buf[(kPackRows - 1) * kPackCols]; |
733 | if (remaining_src_cols < 8) { |
734 | memset(trailing_buf, 0, sizeof(trailing_buf)); |
735 | } |
736 | PackFloatColMajorForAvxCommonPacker<PackImplFloatAvx, Path::kAvx>( |
737 | src_ptr, zerobuf, src_stride, remaining_src_cols, src_rows, packed_ptr, |
738 | trailing_buf); |
739 | |
740 | const int trailing_rows = src_rows & (kPackRows - 1); |
741 | if (trailing_rows > 0) { |
742 | const int non_trailing_rows = src_rows & ~(kPackRows - 1); |
743 | memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, |
744 | kPackCols * trailing_rows * sizeof(float)); |
745 | } |
746 | } |
747 | |
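// Roughly: packs a block of 4 source rows (starting at block_row) for columns
// [start_col, end_col). Each source column becomes 4 consecutive packed
// bytes; values are XORed with input_xor (e.g. to convert uint8 to int8);
// rows at or beyond src_rows are padded with src_zero_point; columns at or
// beyond src_cols are zero-filled; and per-column sums are accumulated into
// 'sums'.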
748 | void Pack8bitRowMajorForAvx(const std::uint8_t* src_ptr, int src_stride, |
749 | int src_zero_point, std::int8_t* packed_ptr, |
750 | int packed_stride, int start_col, int end_col, |
751 | int src_cols, int block_row, int src_rows, |
752 | int input_xor, std::int32_t* sums) { |
753 | int col = start_col; |
754 | int src_end_col = std::min(end_col, src_cols); |
755 | |
756 | for (; col <= src_end_col - 8; col += 8) { |
757 | std::int8_t* dst_ptr = packed_ptr; |
758 | __m128i val0, val1, val2, val3; |
759 | __m128i input_xor_dup = _mm_set1_epi8(input_xor); |
760 | // Load a 4x8 block. |
761 | if (block_row + 4 <= src_rows) { |
762 | val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); |
763 | val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); |
764 | val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); |
765 | val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); |
766 | } else { |
767 | val0 = _mm_set1_epi8(src_zero_point); |
768 | val1 = val0; |
769 | val2 = val0; |
770 | val3 = val0; |
771 | if (block_row + 0 < src_rows) |
772 | val0 = _mm_loadu_si64(src_ptr + 0 * src_stride); |
773 | if (block_row + 1 < src_rows) |
774 | val1 = _mm_loadu_si64(src_ptr + 1 * src_stride); |
775 | if (block_row + 2 < src_rows) |
776 | val2 = _mm_loadu_si64(src_ptr + 2 * src_stride); |
777 | if (block_row + 3 < src_rows) |
778 | val3 = _mm_loadu_si64(src_ptr + 3 * src_stride); |
779 | } |
780 | // Maybe xor the sign bit to convert from uint8 to int8. |
781 | val0 = _mm_xor_si128(val0, input_xor_dup); |
782 | val1 = _mm_xor_si128(val1, input_xor_dup); |
783 | val2 = _mm_xor_si128(val2, input_xor_dup); |
784 | val3 = _mm_xor_si128(val3, input_xor_dup); |
785 | // Update the sums. |
786 | __m128i val16_0 = _mm_cvtepi8_epi16(val0); |
787 | __m128i val16_1 = _mm_cvtepi8_epi16(val1); |
788 | __m128i val16_2 = _mm_cvtepi8_epi16(val2); |
789 | __m128i val16_3 = _mm_cvtepi8_epi16(val3); |
790 | __m128i new_sum16 = _mm_add_epi16(_mm_add_epi16(val16_0, val16_1), |
791 | _mm_add_epi16(val16_2, val16_3)); |
792 | __m256i sum = |
793 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums + col)); |
794 | sum = mm256_add_epi32(sum, mm256_cvtepi16_epi32(new_sum16)); |
795 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums + col), sum); |
796 | // Perform the transposition of 4x4 blocks |
797 | __m128i t2_val0 = _mm_unpacklo_epi8(val0, val1); |
798 | __m128i t2_val1 = _mm_unpacklo_epi8(val2, val3); |
799 | __m128i t4_val0 = _mm_unpacklo_epi16(t2_val0, t2_val1); |
800 | __m128i t4_val1 = _mm_unpackhi_epi16(t2_val0, t2_val1); |
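    // t4_val0 now holds source columns 0..3 and t4_val1 columns 4..7 of this
    // 4x8 block, each column as 4 consecutive row values (the column-major,
    // 4-row packed layout).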
801 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), t4_val0); |
802 | _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16), t4_val1); |
803 | src_ptr += 8; |
804 | packed_ptr += packed_stride * 8; |
805 | } |
806 | for (; col < src_end_col; col++) { |
807 | std::int32_t accum = 0; |
808 | for (int r = 0; r < 4; r++) { |
809 | std::int8_t packed_val; |
810 | if (block_row + r < src_rows) { |
811 | packed_val = input_xor ^ src_ptr[r * src_stride]; |
812 | } else { |
813 | packed_val = input_xor ^ src_zero_point; |
814 | } |
815 | accum += packed_val; |
816 | *packed_ptr++ = packed_val; |
817 | } |
818 | if (sums) { |
819 | sums[col] += accum; |
820 | } |
821 | src_ptr++; |
822 | } |
823 | for (; col < end_col; col++) { |
824 | std::memset(packed_ptr, 0, 4); |
825 | packed_ptr += 4; |
826 | } |
827 | } |
828 | |
#endif  // RUY_PLATFORM_AVX && RUY_OPT(ASM)
830 | |
831 | } // namespace ruy |
832 | |