1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmEmbedding.h"
9
10#if defined(__x86_64__) || defined(__i386__) || \
11 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
12#include <immintrin.h>
13#endif
14#include <type_traits>
15
16namespace fbgemm {
17namespace internal {
18
// Maps an index type to the SIMD register types used throughout this file:
// - w_reg_t: float register holding one vector's worth of weights
//   (16 floats in __m512 for int32 indices, 8 floats in __m256 for int64
//   indices — matching how many indices fit in one __m512i).
// - mask_reg_t: AVX512 lane mask with one bit per index lane.
template <typename T>
struct reg_t;

template <>
struct reg_t<int32_t> {
  using w_reg_t = __m512; // 16 float weights per vector
  using mask_reg_t = __mmask16; // one bit per 32-bit index lane
};

template <>
struct reg_t<int64_t> {
  using w_reg_t = __m256; // 8 float weights per vector
  using mask_reg_t = __mmask8; // one bit per 64-bit index lane
};
33
// Number of indices processed per 512-bit vector: 16 for int32 indices.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static constexpr int get_vlen() {
  // 64 bytes per __m512i / 4 bytes per index
  return 64 / static_cast<int>(sizeof(int32_t));
}

// Number of indices processed per 512-bit vector: 8 for int64 indices.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static constexpr int get_vlen() {
  // 64 bytes per __m512i / 8 bytes per index
  return 64 / static_cast<int>(sizeof(int64_t));
}
47
// Unmasked 512-bit load of one full vector of indices (16 x int32).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i load(void const* addr) {
  const __m512i loaded_v = _mm512_loadu_si512(addr);
  return loaded_v;
}

// Unmasked 512-bit load of one full vector of indices (8 x int64).
// Same instruction as the int32 overload: both read 64 bytes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i load(void const* addr) {
  const __m512i loaded_v = _mm512_loadu_si512(addr);
  return loaded_v;
}
61
// Unmasked load of one full vector of float weights (16 floats) for
// 32-bit indices.
// Fixed: takes `float const*` like the int64_t overload below — the two
// overloads previously disagreed (`void const*` vs `float const*`) for no
// reason; all call sites pass float pointers, so this is source-compatible.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512 load_weights(float const* addr) {
  return _mm512_loadu_ps(addr);
}

// Unmasked load of one vector of float weights (8 floats) for 64-bit
// indices.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m256 load_weights(float const* addr) {
  return _mm256_loadu_ps(addr);
}
75
// Masked load of float weights for 32-bit indices; lanes not selected by
// mask_rem_v keep the (bit-cast) contents of src.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512
mask_load_weights(__m512i src, __mmask16 mask_rem_v, void const* addr) {
  const __m512 fallback_v = _mm512_castsi512_ps(src);
  return _mm512_mask_loadu_ps(fallback_v, mask_rem_v, addr);
}

// Masked load of float weights for 64-bit indices (8 lanes); inactive lanes
// keep the low 256 bits of src, bit-cast to float.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m256
mask_load_weights(__m512i src, __mmask8 mask_rem_v, void const* addr) {
  const __m256i src_lo_v = _mm512_castsi512_si256(src);
  const __m256 fallback_v = _mm256_castsi256_ps(src_lo_v);
  return _mm256_mask_loadu_ps(fallback_v, mask_rem_v, addr);
}
92
// Packs the weight lanes selected by compress_mask_v to the front of the
// vector (unselected slots become the bit-cast zero_v), then stores the low
// lanes selected by store_mask_v. 16-lane variant for 32-bit indices.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void mask_compress_and_store_weights(
    void* addr,
    __m512i zero_v,
    __mmask16 compress_mask_v,
    __mmask16 store_mask_v,
    __m512 src) {
  const __m512 fill_v = _mm512_castsi512_ps(zero_v);
  const __m512 packed_v = _mm512_mask_compress_ps(fill_v, compress_mask_v, src);
  _mm512_mask_storeu_ps(addr, store_mask_v, packed_v);
}

// 8-lane variant for 64-bit indices; only the low 256 bits of zero_v are
// used as the fill value.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void mask_compress_and_store_weights(
    void* addr,
    __m512i zero_v,
    __mmask8 compress_mask_v,
    __mmask8 store_mask_v,
    __m256 src) {
  const __m256 fill_v = _mm256_castsi256_ps(_mm512_castsi512_si256(zero_v));
  const __m256 packed_v = _mm256_mask_compress_ps(fill_v, compress_mask_v, src);
  _mm256_mask_storeu_ps(addr, store_mask_v, packed_v);
}
122
// Builds a lane mask with the low `rem` bits set; rem must be in [0, 16].
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __mmask16 mask_from_rem(int rem) {
  return static_cast<__mmask16>((1ULL << rem) - 1);
}

// Builds a lane mask with the low `rem` bits set; rem must be in [0, 8].
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __mmask8 mask_from_rem(int rem) {
  return static_cast<__mmask8>((1ULL << rem) - 1);
}
138
// Masked 512-bit load of 16 int32 values; inactive lanes keep zero_v.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i
mask_load(__m512i zero_v, __mmask16 mask_rem_v, void const* addr) {
  const __m512i loaded_v = _mm512_mask_loadu_epi32(zero_v, mask_rem_v, addr);
  return loaded_v;
}

// Masked 512-bit load of 8 int64 values; inactive lanes keep zero_v.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i
mask_load(__m512i zero_v, __mmask8 mask_rem_v, void const* addr) {
  const __m512i loaded_v = _mm512_mask_loadu_epi64(zero_v, mask_rem_v, addr);
  return loaded_v;
}
154
// Zero-masked load of 16 int32 values; inactive lanes are zeroed.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i maskz_load(__mmask16 mask_rem_v, void const* addr) {
  const __m512i loaded_v = _mm512_maskz_loadu_epi32(mask_rem_v, addr);
  return loaded_v;
}

// Zero-masked load of 8 int64 values; inactive lanes are zeroed.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i maskz_load(__mmask8 mask_rem_v, void const* addr) {
  const __m512i loaded_v = _mm512_maskz_loadu_epi64(mask_rem_v, addr);
  return loaded_v;
}
168
// Per-lane select (int32): lanes in mask_rem_v take a, the rest keep src.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i mask_mov(__m512i src, __mmask16 mask_rem_v, __m512i a) {
  const __m512i blended_v = _mm512_mask_mov_epi32(src, mask_rem_v, a);
  return blended_v;
}

// Per-lane select (int64): lanes in mask_rem_v take a, the rest keep src.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i mask_mov(__m512i src, __mmask8 mask_rem_v, __m512i a) {
  const __m512i blended_v = _mm512_mask_mov_epi64(src, mask_rem_v, a);
  return blended_v;
}
182
// Gathers 16 int32 mapping-table entries: addr[indices[l]] for each lane l
// (scale 4 = sizeof(int32_t)).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i gather(__m512i indices, const int32_t* addr) {
  const __m512i gathered_v = _mm512_i32gather_epi32(indices, addr, 4);
  return gathered_v;
}

// Gathers 8 int32 mapping-table entries at int64 indices, then widens them
// to int64. Sign extension keeps a pruned entry (-1) at -1.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i gather(__m512i indices, const int32_t* addr) {
  // ToDo: Change this _mm512_i64gather_epi64 once mapping table is 64-bit
  const __m256i gathered_32_v = _mm512_i64gather_epi32(indices, addr, 4);
  return _mm512_cvtepi32_epi64(gathered_32_v);
}
198
// Masked gather of 16 int32 mapping entries; inactive lanes keep src.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i mask_gather(
    __m512i src,
    __mmask16 mask_rem_v,
    __m512i indices,
    const int32_t* addr) {
  const __m512i gathered_v =
      _mm512_mask_i32gather_epi32(src, mask_rem_v, indices, addr, 4);
  return gathered_v;
}

// Masked gather of 8 int32 mapping entries at int64 indices, widened to
// int64; inactive lanes take values from the low 256 bits of src
// (interpreted as int32 before the widening step).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i mask_gather(
    __m512i src,
    __mmask8 mask_rem_v,
    __m512i indices,
    const int32_t* addr) {
  // ToDo: Change this _mm512_mask_i64gather_epi64 once mapping table is 64-bit
  const __m256i src_lo_v = _mm512_castsi512_si256(src);
  const __m256i gathered_32_v =
      _mm512_mask_i64gather_epi32(src_lo_v, mask_rem_v, indices, addr, 4);
  return _mm512_cvtepi32_epi64(gathered_32_v);
}
223
// Mask of lanes whose remapped index is >= 0, i.e. not pruned (int32).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __mmask16 gen_mask(__m512i indices, __m512i zero_v) {
  const __mmask16 keep_mask = _mm512_cmpge_epi32_mask(indices, zero_v);
  return keep_mask;
}

// Mask of lanes whose remapped index is >= 0, i.e. not pruned (int64).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __mmask8 gen_mask(__m512i indices, __m512i zero_v) {
  const __mmask8 keep_mask = _mm512_cmpge_epi64_mask(indices, zero_v);
  return keep_mask;
}
237
// Compress-stores the masked int32 lanes contiguously at addr; exactly
// popcount(mask) elements are written. Bit-cast to float lanes so the
// 32-bit values pass through unchanged.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void compress_store(void* addr, __mmask16 mask, __m512i src_v) {
  const __m512 src_ps_v = _mm512_castsi512_ps(src_v);
  _mm512_mask_compressstoreu_ps(addr, mask, src_ps_v);
}

// Compress-stores the masked int64 lanes contiguously at addr, via a
// bit-cast to double lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void compress_store(void* addr, __mmask8 mask, __m512i src_v) {
  const __m512d src_pd_v = _mm512_castsi512_pd(src_v);
  _mm512_mask_compressstoreu_pd(addr, mask, src_pd_v);
}
251
// Compress-stores the masked float weight lanes contiguously at addr
// (16-lane variant for 32-bit indices).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void compress_store_weights(
    void* addr,
    __mmask16 mask,
    __m512 src_v) {
  _mm512_mask_compressstoreu_ps(addr, mask, src_v);
}

// Compress-stores the masked float weight lanes contiguously at addr
// (8-lane variant for 64-bit indices).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void compress_store_weights(
    void* addr,
    __mmask8 mask,
    __m256 src_v) {
  _mm256_mask_compressstoreu_ps(addr, mask, src_v);
}
267
// Packs the masked int32 lanes to the front of the vector; unselected slots
// are filled from zero_v. Operates on float bit patterns, so lane values
// pass through unchanged.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512 compress(__m512i zero_v, __mmask16 mask, __m512i src_v) {
  const __m512 fill_v = _mm512_castsi512_ps(zero_v);
  const __m512 src_ps_v = _mm512_castsi512_ps(src_v);
  return _mm512_mask_compress_ps(fill_v, mask, src_ps_v);
}

// Packs the masked int64 lanes to the front of the vector; unselected slots
// are filled from zero_v. Operates on double bit patterns.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512d compress(__m512i zero_v, __mmask8 mask, __m512i src_v) {
  const __m512d fill_v = _mm512_castsi512_pd(zero_v);
  const __m512d src_pd_v = _mm512_castsi512_pd(src_v);
  return _mm512_mask_compress_pd(fill_v, mask, src_pd_v);
}
283
// Stores the masked lanes of a (bit-cast) float vector as int32 at addr.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void mask_store(void* addr, __mmask16 mask, __m512 src_v) {
  const __m512i src_epi_v = _mm512_castps_si512(src_v);
  _mm512_mask_storeu_epi32(addr, mask, src_epi_v);
}

// Stores the masked lanes of a (bit-cast) double vector as int64 at addr.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void mask_store(void* addr, __mmask8 mask, __m512d src_v) {
  const __m512i src_epi_v = _mm512_castpd_si512(src_v);
  _mm512_mask_storeu_epi64(addr, mask, src_epi_v);
}
297
// Copies len bytes from src to dest with 512-bit vector loads/stores; the
// tail (< 64 bytes) uses a masked byte load/store. Callers in this file copy
// forward within one buffer with dest <= src, which this forward loop
// handles correctly.
// Fixes: src is now const-correct, and the tail mask uses 1ULL — the old
// (long long)1 << rem left-shifted into the sign bit when rem == 63, which
// is undefined behavior before C++20.
static inline void mymemcpy(const char* src, char* dest, int len) {
  constexpr int VLEN = 64; // bytes per __m512i
  int i = 0;
  // Full 64-byte chunks.
  for (; i < len / VLEN * VLEN; i += VLEN) {
    auto src_v = _mm512_loadu_si512(src + i);
    _mm512_storeu_si512(dest + i, src_v);
  }
  const int rem = len - i; // 0..63 bytes left
  if (rem > 0) {
    __mmask64 mask_rem_v = (1ULL << rem) - 1;
    auto src_v = _mm512_maskz_loadu_epi8(mask_rem_v, src + i);
    _mm512_mask_storeu_epi8(dest + i, mask_rem_v, src_v);
  }
}
313
// Processes one vector's worth of indices for each of UNROLL segments:
// remaps indices through compressed_indices_mapping, drops entries that map
// to -1 (pruned), and appends the survivors — and, when HAS_WEIGHTS, their
// weights — to out_indices/out_weights at offsets[i] + count_indices[i].
// count_indices[i] is advanced by the survivor count so repeated calls
// append. Segment i is read starting at offsets[i] + ind_w_start_offsets[i].
// When USE_MASK is true only the low rem[i] lanes of segment i participate
// (rem[i] == 0 disables the segment); otherwise a full vector is assumed.
template <
    typename IndexType,
    bool HAS_WEIGHTS,
    int UNROLL = 8,
    bool USE_MASK = false>
static inline void compressed_indices_remap_avx512_helper(
    __m512i zero_v,
    __m512i minus1_v,
    const IndexType* offsets,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const float* weights,
    IndexType* out_indices,
    float* out_weights,
    IndexType* count_indices,
    const int32_t* rem,
    const int32_t* ind_w_start_offsets) {
  // Per-segment lane masks covering the low rem[i] lanes (only meaningful
  // when USE_MASK is true).
  typename reg_t<IndexType>::mask_reg_t mask_rem_v[UNROLL];
  for (int i = 0; i < UNROLL; ++i) {
    mask_rem_v[i] = mask_from_rem<IndexType>(rem[i]);
  }
  for (int i = 0; i < UNROLL; ++i) {
    // Load one vector of input indices for segment i.
    __m512i indices_v;
    if (USE_MASK) {
      indices_v = mask_load<IndexType>(
          zero_v,
          mask_rem_v[i],
          reinterpret_cast<void const*>(
              indices + offsets[i] + ind_w_start_offsets[i]));
    } else {
      indices_v = load<IndexType>(reinterpret_cast<void const*>(
          indices + offsets[i] + ind_w_start_offsets[i]));
    }

    // gather remapped indices from the mapping table
    __m512i remapped_indices_v;
    if (USE_MASK) {
      remapped_indices_v = mask_gather<IndexType>(
          zero_v, mask_rem_v[i], indices_v, compressed_indices_mapping);
      // mov -1 to not used places in the vector
      // (so inactive lanes are treated as pruned by gen_mask below)
      remapped_indices_v =
          mask_mov<IndexType>(minus1_v, mask_rem_v[i], remapped_indices_v);

    } else {
      remapped_indices_v =
          gather<IndexType>(indices_v, compressed_indices_mapping);
    }

    // Load the weights for the same lanes so they stay paired with their
    // indices through the compression below.
    typename reg_t<IndexType>::w_reg_t weights_v;
    if (HAS_WEIGHTS) {
      if (USE_MASK) {
        weights_v = mask_load_weights<IndexType>(
            zero_v,
            mask_rem_v[i],
            reinterpret_cast<void const*>(
                weights + offsets[i] + ind_w_start_offsets[i]));
      } else {
        weights_v = load_weights<IndexType>(
            weights + offsets[i] + ind_w_start_offsets[i]);
      }
    }

    // Now remove -1 from the remapped indices
    // (surviving lanes have their mask bit set).
    auto mask_indices_v = gen_mask<IndexType>(remapped_indices_v, zero_v);

    if (USE_MASK) {
      // Compact survivors to the front, then store the low rem[i] lanes.
      // Lanes beyond the survivor count hold fill values; they are
      // overwritten by later appends because count_indices only advances by
      // the survivor count.
      auto out_indices_v =
          compress<IndexType>(zero_v, mask_indices_v, remapped_indices_v);

      mask_store<IndexType>(
          reinterpret_cast<void*>(out_indices + offsets[i] + count_indices[i]),
          mask_rem_v[i],
          out_indices_v);
    } else {
      // Compress-store writes exactly popcount(mask_indices_v) elements.
      compress_store<IndexType>(
          reinterpret_cast<void*>(out_indices + offsets[i] + count_indices[i]),
          mask_indices_v,
          remapped_indices_v);
    }

    if (HAS_WEIGHTS) {
      // Same compaction applied to the weights keeps them aligned with
      // their surviving indices.
      if (USE_MASK) {
        mask_compress_and_store_weights<IndexType>(
            reinterpret_cast<void*>(
                out_weights + offsets[i] + count_indices[i]),
            zero_v,
            mask_indices_v,
            mask_rem_v[i],
            weights_v);
      } else {
        compress_store_weights<IndexType>(
            reinterpret_cast<void*>(
                out_weights + offsets[i] + count_indices[i]),
            mask_indices_v,
            weights_v);
      }
    }

    // Advance this segment's append cursor by the number of survivors.
    count_indices[i] += _mm_popcnt_u32(mask_indices_v);
  }
}
415
// Remaps each segment of `indices` (CSR layout given by `offsets`, which has
// offsets_len entries, i.e. offsets_len - 1 segments) through the 32-bit
// `compressed_indices_mapping` table. Entries mapping to -1 (pruned) are
// dropped and the survivors compacted per segment; when HAS_WEIGHTS, the
// matching `weights` entries are kept aligned with their indices.
// out_offsets receives the new segment boundaries (never larger than the
// input ones, since pruning only removes entries).
//
// Phase 1 processes UNROLL segments at a time, one vector per segment per
// pass; phase 2 handles the leftover segments serially; both phases write
// results at the ORIGINAL segment offsets. Phase 3 then moves each
// compacted segment to its final out_offset position.
template <typename IndexType, bool HAS_WEIGHTS>
void compressed_indices_remap_avx512(
    std::int32_t offsets_len,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null,
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights) {
  __m512i zero_v = _mm512_set1_epi32(0);
  __m512i minus1_v = _mm512_set1_epi32(-1);
  out_offsets[0] = offsets[0];
  constexpr int UNROLL = 8;
  constexpr int VLEN = get_vlen<IndexType>();
  int k = 1;
  // Phase 1: blocks of UNROLL segments (segments k-1 .. k+UNROLL-2).
  for (; k < (offsets_len - 1) / UNROLL * UNROLL; k += UNROLL) {
    // Unprocessed length of each segment in this block.
    int32_t len[UNROLL];
    int32_t rem[UNROLL];
    for (int l = 0; l < UNROLL; ++l) {
      len[l] = offsets[k + l] - offsets[k + l - 1];
    }
    // count of non-pruned indices
    IndexType count_indices[UNROLL] = {0};
    // read indices/weights starting at these offsets
    int32_t ind_w_start_offsets[UNROLL] = {0};
    __m256i vec_len_v = _mm256_set1_epi32(VLEN);
    // Segments with >= VLEN elements left get rem = VLEN and len -= VLEN;
    // the rest get rem = 0 (idle until the remainder pass below).
    __m256i len_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
    __mmask8 cmp_res_v = _mm256_cmpge_epi32_mask(len_v, vec_len_v);
    len_v = _mm256_mask_sub_epi32(len_v, cmp_res_v, len_v, vec_len_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(len), len_v);
    __m256i rem_v = _mm256_maskz_mov_epi32(cmp_res_v, vec_len_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), rem_v);
    int active_unrolls = _mm_popcnt_u32(cmp_res_v);

    // if we have any at least 1 full vector length work
    // take vector path
    while (active_unrolls > 0) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL,
          true>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);

      // Advance the read cursor by VLEN for the segments just processed.
      __m256i start_offsets_v = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(ind_w_start_offsets));
      start_offsets_v = _mm256_mask_add_epi32(
          start_offsets_v, cmp_res_v, start_offsets_v, vec_len_v);
      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(ind_w_start_offsets), start_offsets_v);

      // Recompute which segments still have a full vector of work left.
      len_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
      cmp_res_v = _mm256_cmpge_epi32_mask(len_v, vec_len_v);
      len_v = _mm256_mask_sub_epi32(len_v, cmp_res_v, len_v, vec_len_v);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(len), len_v);
      rem_v = _mm256_maskz_mov_epi32(cmp_res_v, vec_len_v);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), rem_v);
      active_unrolls = _mm_popcnt_u32(cmp_res_v);
    }

    // Now work on all the remainders: every segment has < VLEN elements
    // left, so use the leftover lengths directly as lane counts.
    __m256i len_rem_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), len_rem_v);
    compressed_indices_remap_avx512_helper<
        IndexType,
        HAS_WEIGHTS,
        UNROLL,
        true>(
        zero_v,
        minus1_v,
        offsets + k - 1,
        indices,
        compressed_indices_mapping,
        weights,
        out_indices,
        out_weights,
        count_indices,
        rem,
        ind_w_start_offsets);

    // update output offsets
    for (int l = 0; l < UNROLL; ++l) {
      out_offsets[k + l] = out_offsets[k + l - 1] + count_indices[l];
    }
  }

  // work on remaining offsets_len serially (phase 2: one segment at a time)
  constexpr int UNROLL_REM = 1;
  for (; k < offsets_len; ++k) {
    int32_t len[UNROLL_REM];
    int32_t rem[UNROLL_REM] = {0};
    for (int l = 0; l < UNROLL_REM; ++l) {
      len[l] = offsets[k + l] - offsets[k + l - 1];
    }
    IndexType count_indices[UNROLL_REM] = {0};
    int32_t ind_w_start_offsets[UNROLL_REM] = {0};
    int i = 0;
    // Full vectors of this segment (unmasked path).
    for (; i < len[0] / VLEN * VLEN; i += VLEN) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL_REM,
          false>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);
      ind_w_start_offsets[0] += VLEN;
    }
    // remainder (< VLEN elements, masked path)
    rem[0] = len[0] - i;
    if (rem[0] > 0) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL_REM,
          true>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);
    }

    for (int l = 0; l < UNROLL_REM; ++l) {
      out_offsets[k + l] = out_offsets[k + l - 1] + count_indices[l];
    }
  }

  // Results are stored at input offsets in output variables
  // copy results to right output locations
  // (forward in-buffer move; out_offsets[i] <= offsets[i] since pruning
  // never grows a segment, so the forward copy is overlap-safe)
  for (int i = 1; i < offsets_len; ++i) {
    int out_len = out_offsets[i] - out_offsets[i - 1];
    mymemcpy(
        reinterpret_cast<char*>(out_indices + offsets[i - 1]),
        reinterpret_cast<char*>(out_indices + out_offsets[i - 1]),
        out_len * sizeof(IndexType));
    if (HAS_WEIGHTS) {
      mymemcpy(
          reinterpret_cast<char*>(out_weights + offsets[i - 1]),
          reinterpret_cast<char*>(out_weights + out_offsets[i - 1]),
          out_len * sizeof(float));
    }
  }
}
586
// Explicit template instantiations for every supported (index type,
// weighted) combination, so the definitions above are emitted in this
// translation unit for external callers.
#define INSTANTIATE_REMAP_BASE(INDEX_TYPE, HAS_WEIGHTS)                   \
  template void compressed_indices_remap_avx512<INDEX_TYPE, HAS_WEIGHTS>( \
      std::int32_t offsets_numel,                                         \
      const INDEX_TYPE* indices,                                          \
      const int32_t* compressed_indices_mapping,                          \
      const INDEX_TYPE* offsets,                                          \
      const float* weights,                                               \
      INDEX_TYPE* out_indices,                                            \
      INDEX_TYPE* out_offsets,                                            \
      float* out_weights);

INSTANTIATE_REMAP_BASE(int32_t, true)
INSTANTIATE_REMAP_BASE(int32_t, false)
INSTANTIATE_REMAP_BASE(int64_t, true)
INSTANTIATE_REMAP_BASE(int64_t, false)

#undef INSTANTIATE_REMAP_BASE
604
605} // namespace internal
606} // namespace fbgemm
607