1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | |
8 | #define FBGEMM_EXPORTS |
9 | #include "./OptimizedKernelsAvx2.h" |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
14 | |
15 | namespace fbgemm { |
16 | |
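// Returns the sum of the `len` uint8_t elements of A as an int32_t.
// Uses AVX2 when available at compile time; otherwise falls back to a
// plain scalar loop.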
17 | int32_t reduceAvx2(const uint8_t* A, int len) { |
18 | int32_t row_sum = 0; |
19 | #if defined(__AVX2__) |
20 | __m256i sum_v = _mm256_setzero_si256(); |
21 | __m256i one_epi16_v = _mm256_set1_epi16(1); |
22 | __m256i one_epi8_v = _mm256_set1_epi8(1); |
23 | |
24 | int i; |
  // Vectorized loop: 32 bytes per iteration. maddubs(src, 1) sums adjacent
  // uint8 pairs into int16; madd(., 1) sums adjacent int16 pairs into int32.
26 | for (i = 0; i < len / 32 * 32; i += 32) { |
27 | __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); |
28 | sum_v = _mm256_add_epi32( |
29 | sum_v, |
30 | _mm256_madd_epi16( |
31 | _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v)); |
32 | } |
33 | |
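  // Horizontally reduce the eight int32 partial sums in sum_v.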
34 | alignas(64) int32_t temp[8]; |
35 | _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v); |
36 | for (int k = 0; k < 8; ++k) { |
37 | row_sum += temp[k]; |
38 | } |
39 | |
  // Scalar tail for the remaining (len % 32) elements
41 | for (; i < len; ++i) { |
42 | row_sum += A[i]; |
43 | } |
44 | |
45 | #else |
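  // Scalar fallback when compiled without AVX2 support.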
46 | for (int i = 0; i < len; ++i) { |
47 | row_sum += A[i]; |
48 | } |
49 | #endif |
50 | return row_sum; |
51 | } |
52 | |
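// Transposes an 8 x N block: src points to 8 rows of N uint8_t values with
// leading dimension ld_src; dst receives N rows of 8 values with leading
// dimension ld_dst, i.e. dst[j * ld_dst + i] = src[i * ld_src + j].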
53 | void transpose_8rows( |
54 | int N, |
55 | const uint8_t* src, |
56 | int ld_src, |
57 | uint8_t* dst, |
58 | int ld_dst) { |
59 | constexpr int M = 8; |
60 | int j; |
  // Vectorized loop: transpose one 8 x 32 block per iteration
62 | for (j = 0; j < N / 32 * 32; j += 32) { |
63 | // a : a0 a1 ... a31 |
64 | // b : b0 b1 ... b31 |
65 | // c : c0 c1 ... c31 |
66 | // d : d0 d1 ... d31 |
67 | __m256i a = _mm256_lddqu_si256( |
68 | reinterpret_cast<const __m256i*>(src + j + 0 * ld_src)); |
69 | __m256i b = _mm256_lddqu_si256( |
70 | reinterpret_cast<const __m256i*>(src + j + 1 * ld_src)); |
71 | __m256i c = _mm256_lddqu_si256( |
72 | reinterpret_cast<const __m256i*>(src + j + 2 * ld_src)); |
73 | __m256i d = _mm256_lddqu_si256( |
74 | reinterpret_cast<const __m256i*>(src + j + 3 * ld_src)); |
75 | __m256i e = _mm256_lddqu_si256( |
76 | reinterpret_cast<const __m256i*>(src + j + 4 * ld_src)); |
77 | __m256i f = _mm256_lddqu_si256( |
78 | reinterpret_cast<const __m256i*>(src + j + 5 * ld_src)); |
79 | __m256i g = _mm256_lddqu_si256( |
80 | reinterpret_cast<const __m256i*>(src + j + 6 * ld_src)); |
81 | __m256i h = _mm256_lddqu_si256( |
82 | reinterpret_cast<const __m256i*>(src + j + 7 * ld_src)); |
83 | |
    // byte-level interleaving of row pairs (within each 128-bit lane)
85 | // ab_lo : a0 b0 a1 b1 ... a7 b7 | a16 b16 ... a23 b23 |
86 | // ab_hi : a8 b8 a9 b9 ... a15 b15 | a24 b24 ... a31 b31 |
87 | // cd_lo : c0 d0 c1 d1 ... c7 d7 | c16 d16 ... c23 d23 |
88 | // cd_hi : c8 d8 c9 d9 ... c15 d15 | c24 d24 ... c31 d31 |
89 | __m256i ab_lo = _mm256_unpacklo_epi8(a, b); |
90 | __m256i ab_hi = _mm256_unpackhi_epi8(a, b); |
91 | __m256i cd_lo = _mm256_unpacklo_epi8(c, d); |
92 | __m256i cd_hi = _mm256_unpackhi_epi8(c, d); |
93 | __m256i ef_lo = _mm256_unpacklo_epi8(e, f); |
94 | __m256i ef_hi = _mm256_unpackhi_epi8(e, f); |
95 | __m256i gh_lo = _mm256_unpacklo_epi8(g, h); |
96 | __m256i gh_hi = _mm256_unpackhi_epi8(g, h); |
97 | |
98 | // 4-row interleaving but permuted at 128-bit granularity |
99 | // abcd0 : a0 b0 c0 d0 ... a-d3 | a-d16 ... a-d19 |
100 | // abcd1 : a4 b4 c4 d4 ... a-d7 | a-d20 ... a-d23 |
101 | // abcd2 : a8 b8 c8 d8 ... a-d11 | a-d24 ... a-d27 |
102 | // abcd3 : a12 b12 c12 d12 ... a-d15 | a-d28 ... a-d31 |
103 | __m256i abcd0 = _mm256_unpacklo_epi16(ab_lo, cd_lo); |
104 | __m256i abcd1 = _mm256_unpackhi_epi16(ab_lo, cd_lo); |
105 | __m256i abcd2 = _mm256_unpacklo_epi16(ab_hi, cd_hi); |
106 | __m256i abcd3 = _mm256_unpackhi_epi16(ab_hi, cd_hi); |
107 | __m256i efgh0 = _mm256_unpacklo_epi16(ef_lo, gh_lo); |
108 | __m256i efgh1 = _mm256_unpackhi_epi16(ef_lo, gh_lo); |
109 | __m256i efgh2 = _mm256_unpacklo_epi16(ef_hi, gh_hi); |
110 | __m256i efgh3 = _mm256_unpackhi_epi16(ef_hi, gh_hi); |
111 | |
    // 8-row interleaving: each 64-bit element of y0..y7 is now one fully
    // transposed column a_k b_k c_k d_k e_k f_k g_k h_k
113 | __m256i y0 = _mm256_unpacklo_epi32(abcd0, efgh0); |
114 | __m256i y1 = _mm256_unpackhi_epi32(abcd0, efgh0); |
115 | __m256i y2 = _mm256_unpacklo_epi32(abcd1, efgh1); |
116 | __m256i y3 = _mm256_unpackhi_epi32(abcd1, efgh1); |
117 | __m256i y4 = _mm256_unpacklo_epi32(abcd2, efgh2); |
118 | __m256i y5 = _mm256_unpackhi_epi32(abcd2, efgh2); |
119 | __m256i y6 = _mm256_unpacklo_epi32(abcd3, efgh3); |
120 | __m256i y7 = _mm256_unpackhi_epi32(abcd3, efgh3); |
121 | |
    // Stores: the unpacks interleave 128-bit lanes, so rows j+0..j+15 of dst
    // come from the low 128-bit lanes of y0..y7 and rows j+16..j+31 from the
    // high lanes
123 | _mm_storel_epi64( |
124 | reinterpret_cast<__m128i*>(dst + (j + 0) * ld_dst), |
125 | _mm256_castsi256_si128(y0)); |
126 | *reinterpret_cast<int64_t*>(dst + (j + 1) * ld_dst) = |
127 | _mm256_extract_epi64(y0, 1); |
128 | _mm_storel_epi64( |
129 | reinterpret_cast<__m128i*>(dst + (j + 2) * ld_dst), |
130 | _mm256_castsi256_si128(y1)); |
131 | *reinterpret_cast<int64_t*>(dst + (j + 3) * ld_dst) = |
132 | _mm256_extract_epi64(y1, 1); |
133 | _mm_storel_epi64( |
134 | reinterpret_cast<__m128i*>(dst + (j + 4) * ld_dst), |
135 | _mm256_castsi256_si128(y2)); |
136 | *reinterpret_cast<int64_t*>(dst + (j + 5) * ld_dst) = |
137 | _mm256_extract_epi64(y2, 1); |
138 | _mm_storel_epi64( |
139 | reinterpret_cast<__m128i*>(dst + (j + 6) * ld_dst), |
140 | _mm256_castsi256_si128(y3)); |
141 | *reinterpret_cast<int64_t*>(dst + (j + 7) * ld_dst) = |
142 | _mm256_extract_epi64(y3, 1); |
143 | _mm_storel_epi64( |
144 | reinterpret_cast<__m128i*>(dst + (j + 8) * ld_dst), |
145 | _mm256_castsi256_si128(y4)); |
146 | *reinterpret_cast<int64_t*>(dst + (j + 9) * ld_dst) = |
147 | _mm256_extract_epi64(y4, 1); |
148 | _mm_storel_epi64( |
149 | reinterpret_cast<__m128i*>(dst + (j + 10) * ld_dst), |
150 | _mm256_castsi256_si128(y5)); |
151 | *reinterpret_cast<int64_t*>(dst + (j + 11) * ld_dst) = |
152 | _mm256_extract_epi64(y5, 1); |
153 | _mm_storel_epi64( |
154 | reinterpret_cast<__m128i*>(dst + (j + 12) * ld_dst), |
155 | _mm256_castsi256_si128(y6)); |
156 | *reinterpret_cast<int64_t*>(dst + (j + 13) * ld_dst) = |
157 | _mm256_extract_epi64(y6, 1); |
158 | _mm_storel_epi64( |
159 | reinterpret_cast<__m128i*>(dst + (j + 14) * ld_dst), |
160 | _mm256_castsi256_si128(y7)); |
161 | *reinterpret_cast<int64_t*>(dst + (j + 15) * ld_dst) = |
162 | _mm256_extract_epi64(y7, 1); |
163 | *reinterpret_cast<int64_t*>(dst + (j + 16) * ld_dst) = |
164 | _mm256_extract_epi64(y0, 2); |
165 | *reinterpret_cast<int64_t*>(dst + (j + 17) * ld_dst) = |
166 | _mm256_extract_epi64(y0, 3); |
167 | *reinterpret_cast<int64_t*>(dst + (j + 18) * ld_dst) = |
168 | _mm256_extract_epi64(y1, 2); |
169 | *reinterpret_cast<int64_t*>(dst + (j + 19) * ld_dst) = |
170 | _mm256_extract_epi64(y1, 3); |
171 | *reinterpret_cast<int64_t*>(dst + (j + 20) * ld_dst) = |
172 | _mm256_extract_epi64(y2, 2); |
173 | *reinterpret_cast<int64_t*>(dst + (j + 21) * ld_dst) = |
174 | _mm256_extract_epi64(y2, 3); |
175 | *reinterpret_cast<int64_t*>(dst + (j + 22) * ld_dst) = |
176 | _mm256_extract_epi64(y3, 2); |
177 | *reinterpret_cast<int64_t*>(dst + (j + 23) * ld_dst) = |
178 | _mm256_extract_epi64(y3, 3); |
179 | *reinterpret_cast<int64_t*>(dst + (j + 24) * ld_dst) = |
180 | _mm256_extract_epi64(y4, 2); |
181 | *reinterpret_cast<int64_t*>(dst + (j + 25) * ld_dst) = |
182 | _mm256_extract_epi64(y4, 3); |
183 | *reinterpret_cast<int64_t*>(dst + (j + 26) * ld_dst) = |
184 | _mm256_extract_epi64(y5, 2); |
185 | *reinterpret_cast<int64_t*>(dst + (j + 27) * ld_dst) = |
186 | _mm256_extract_epi64(y5, 3); |
187 | *reinterpret_cast<int64_t*>(dst + (j + 28) * ld_dst) = |
188 | _mm256_extract_epi64(y6, 2); |
189 | *reinterpret_cast<int64_t*>(dst + (j + 29) * ld_dst) = |
190 | _mm256_extract_epi64(y6, 3); |
191 | *reinterpret_cast<int64_t*>(dst + (j + 30) * ld_dst) = |
192 | _mm256_extract_epi64(y7, 2); |
193 | *reinterpret_cast<int64_t*>(dst + (j + 31) * ld_dst) = |
194 | _mm256_extract_epi64(y7, 3); |
195 | } |
196 | |
  // scalar loop for the remaining (N % 32) columns
198 | for (; j < N; ++j) { |
199 | for (int i = 0; i < M; ++i) { |
200 | dst[j * ld_dst + i] = src[j + i * ld_src]; |
201 | } |
202 | } |
203 | } |
204 | |
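// Multiplies a sparse int8_t matrix, given in CSC form by (colptr, rowidx,
// values), with a dense block of uint8_t activations and accumulates into
// C_buffer. Each row r of activations is the 32 uint8_t values starting at
// A_buffer[r * 32]; for each of the N columns j,
// C_buffer[j * 32 .. j * 32 + 31] += sum over nonzeros k in column j of
// values[k] * A_buffer[rowidx[k] * 32 ..]. The aligned loads and stores
// below assume A_buffer and C_buffer are 32-byte aligned.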
205 | void spmdmKernelAvx2( |
206 | int N, |
207 | const uint8_t* A_buffer, |
208 | const int32_t* colptr, |
209 | const int8_t* values, |
210 | const int16_t* rowidx, |
211 | int32_t* C_buffer) { |
212 | for (int j = 0; j < N; ++j) { |
213 | int k = colptr[j]; |
214 | int k_end_aligned = colptr[j] + (colptr[j + 1] - colptr[j]) / 4 * 4; |
215 | |
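    // Process four nonzeros at a time: pack their four int8_t weights into a
    // broadcast 32-bit lane so a single maddubs can multiply them against
    // four interleaved activation bytes.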
216 | for (; k < k_end_aligned; k += 4) { |
217 | __m256i w = |
218 | _mm256_set1_epi32(*(reinterpret_cast<const int32_t*>(&values[k]))); |
219 | __m256i a[4]; |
220 | a[0] = _mm256_load_si256( |
221 | reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 0] * 32])); |
222 | a[1] = _mm256_load_si256( |
223 | reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 1] * 32])); |
224 | a[2] = _mm256_load_si256( |
225 | reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 2] * 32])); |
226 | a[3] = _mm256_load_si256( |
227 | reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 3] * 32])); |
228 | |
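      // Interleave bytes of the four activation rows so that each 32-bit
      // element holds one byte from each row, matching the four weights in w.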
229 | __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); |
230 | __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); |
231 | __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); |
232 | __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); |
233 | |
234 | a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); |
235 | a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); |
236 | a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); |
237 | a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); |
238 | |
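      // uint8 x int8 multiply: maddubs sums adjacent products into int16
      // (with signed saturation); madd by 1 then widens pairs into int32.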
239 | __m256i ab[4]; |
240 | ab[0] = _mm256_maddubs_epi16(a[0], w); |
241 | ab[1] = _mm256_maddubs_epi16(a[1], w); |
242 | ab[2] = _mm256_maddubs_epi16(a[2], w); |
243 | ab[3] = _mm256_maddubs_epi16(a[3], w); |
244 | |
245 | __m256i one = _mm256_set1_epi16(1); |
246 | ab[0] = _mm256_madd_epi16(ab[0], one); |
247 | ab[1] = _mm256_madd_epi16(ab[1], one); |
248 | ab[2] = _mm256_madd_epi16(ab[2], one); |
249 | ab[3] = _mm256_madd_epi16(ab[3], one); |
250 | |
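      // The epi16 unpacks interleaved 128-bit lanes; permute them back so
      // t[0..3] cover output elements 0-7, 8-15, 16-23, and 24-31 in order.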
251 | __m256i t[4]; |
252 | t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); |
253 | t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); |
254 | t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); |
255 | t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); |
256 | |
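      // Accumulate into the 32 int32_t outputs of column j.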
257 | _mm256_store_si256( |
258 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), |
259 | _mm256_add_epi32( |
260 | _mm256_load_si256( |
261 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 0 * 8])), |
262 | t[0])); |
263 | _mm256_store_si256( |
264 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), |
265 | _mm256_add_epi32( |
266 | _mm256_load_si256( |
267 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 1 * 8])), |
268 | t[1])); |
269 | _mm256_store_si256( |
270 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), |
271 | _mm256_add_epi32( |
272 | _mm256_load_si256( |
273 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 2 * 8])), |
274 | t[2])); |
275 | _mm256_store_si256( |
276 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), |
277 | _mm256_add_epi32( |
278 | _mm256_load_si256( |
279 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 3 * 8])), |
280 | t[3])); |
281 | } |
282 | |
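    // Epilogue for the last 1-3 nonzeros of the column: pack the remaining
    // weights into temp_w (unused bytes stay zero) and substitute zero
    // vectors for the missing activation rows.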
283 | int remainder = colptr[j + 1] - k; |
284 | if (remainder > 0) { |
285 | int32_t temp_w = 0; |
286 | for (int r = 0; r < remainder; ++r) { |
287 | (reinterpret_cast<int8_t*>(&temp_w))[r] = values[k + r]; |
288 | } |
289 | __m256i w = _mm256_set1_epi32(temp_w); |
290 | __m256i a[4]; |
291 | a[0] = _mm256_load_si256( |
292 | reinterpret_cast<const __m256i*>(&A_buffer[rowidx[k + 0] * 32])); |
293 | a[1] = remainder > 1 ? _mm256_load_si256(reinterpret_cast<const __m256i*>( |
294 | &A_buffer[rowidx[k + 1] * 32])) |
295 | : _mm256_setzero_si256(); |
296 | a[2] = remainder > 2 ? _mm256_load_si256(reinterpret_cast<const __m256i*>( |
297 | &A_buffer[rowidx[k + 2] * 32])) |
298 | : _mm256_setzero_si256(); |
299 | a[3] = _mm256_setzero_si256(); |
300 | |
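      // Same interleave / multiply-add / accumulate sequence as the main loop.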
301 | __m256i a01_lo = _mm256_unpacklo_epi8(a[0], a[1]); |
302 | __m256i a01_hi = _mm256_unpackhi_epi8(a[0], a[1]); |
303 | __m256i a23_lo = _mm256_unpacklo_epi8(a[2], a[3]); |
304 | __m256i a23_hi = _mm256_unpackhi_epi8(a[2], a[3]); |
305 | |
306 | a[0] = _mm256_unpacklo_epi16(a01_lo, a23_lo); |
307 | a[1] = _mm256_unpackhi_epi16(a01_lo, a23_lo); |
308 | a[2] = _mm256_unpacklo_epi16(a01_hi, a23_hi); |
309 | a[3] = _mm256_unpackhi_epi16(a01_hi, a23_hi); |
310 | |
311 | __m256i ab[4]; |
312 | ab[0] = _mm256_maddubs_epi16(a[0], w); |
313 | ab[1] = _mm256_maddubs_epi16(a[1], w); |
314 | ab[2] = _mm256_maddubs_epi16(a[2], w); |
315 | ab[3] = _mm256_maddubs_epi16(a[3], w); |
316 | |
317 | __m256i one = _mm256_set1_epi16(1); |
318 | ab[0] = _mm256_madd_epi16(ab[0], one); |
319 | ab[1] = _mm256_madd_epi16(ab[1], one); |
320 | ab[2] = _mm256_madd_epi16(ab[2], one); |
321 | ab[3] = _mm256_madd_epi16(ab[3], one); |
322 | |
323 | __m256i t[4]; |
324 | t[0] = _mm256_permute2f128_si256(ab[0], ab[1], 0x20); |
325 | t[1] = _mm256_permute2f128_si256(ab[2], ab[3], 0x20); |
326 | t[2] = _mm256_permute2f128_si256(ab[0], ab[1], 0x31); |
327 | t[3] = _mm256_permute2f128_si256(ab[2], ab[3], 0x31); |
328 | |
329 | _mm256_store_si256( |
330 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 0 * 8]), |
331 | _mm256_add_epi32( |
332 | _mm256_load_si256( |
333 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 0 * 8])), |
334 | t[0])); |
335 | _mm256_store_si256( |
336 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 1 * 8]), |
337 | _mm256_add_epi32( |
338 | _mm256_load_si256( |
339 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 1 * 8])), |
340 | t[1])); |
341 | _mm256_store_si256( |
342 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 2 * 8]), |
343 | _mm256_add_epi32( |
344 | _mm256_load_si256( |
345 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 2 * 8])), |
346 | t[2])); |
347 | _mm256_store_si256( |
348 | reinterpret_cast<__m256i*>(&C_buffer[j * 32 + 3 * 8]), |
349 | _mm256_add_epi32( |
350 | _mm256_load_si256( |
351 | reinterpret_cast<const __m256i*>(&C_buffer[j * 32 + 3 * 8])), |
352 | t[3])); |
353 | } |
354 | } // for each column of B |
355 | } |
356 | |
357 | } // namespace fbgemm |
358 | |