1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmEmbedding.h" |
9 | |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
14 | #include <type_traits> |
15 | |
16 | namespace fbgemm { |
17 | namespace internal { |
18 | |
// Maps an index type to the SIMD register types used for its float weights.
// A 512-bit vector holds 16 int32 indices (so weights use a full __m512 and
// a 16-bit lane mask) but only 8 int64 indices (so weights use __m256 and an
// 8-bit mask) — weight registers always have the same lane count as indices.
template <typename T>
struct reg_t;

template <>
struct reg_t<int32_t> {
  using w_reg_t = __m512;
  using mask_reg_t = __mmask16;
};

template <>
struct reg_t<int64_t> {
  using w_reg_t = __m256;
  using mask_reg_t = __mmask8;
};
33 | |
// Number of index elements per 512-bit vector: 16 lanes for int32_t.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static constexpr int get_vlen() {
  return 16;
}

// 8 lanes per 512-bit vector for int64_t.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static constexpr int get_vlen() {
  return 8;
}
47 | |
// Unmasked, unaligned load of one full 512-bit vector of indices
// (16 x int32_t for the int32_t instantiation).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i load(void const* addr) {
  return _mm512_loadu_si512(addr);
}

// Same 512-bit load for int64_t indices (8 x int64_t lanes).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i load(void const* addr) {
  return _mm512_loadu_si512(addr);
}
61 | |
// Unmasked load of one full vector of float weights (one weight per index
// lane). int32_t indices: 16 floats into a 512-bit register.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512 load_weights(void const* addr) {
  return _mm512_loadu_ps(addr);
}

// int64_t indices: 8 floats into a 256-bit register. Takes void const* so
// the signature matches the int32_t overload and the mask_load_weights
// overloads (the previous float const* parameter was the odd one out).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m256 load_weights(void const* addr) {
  return _mm256_loadu_ps(reinterpret_cast<float const*>(addr));
}
75 | |
// Masked load of float weights: lanes enabled in mask_rem_v are read from
// addr, disabled lanes keep the corresponding value of src (reinterpreted as
// float). int32_t indices: 16 float lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512
mask_load_weights(__m512i src, __mmask16 mask_rem_v, void const* addr) {
  return _mm512_mask_loadu_ps(_mm512_castsi512_ps(src), mask_rem_v, addr);
}

// int64_t indices: 8 float lanes; only the low 256 bits of src supply the
// fallback values for disabled lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m256
mask_load_weights(__m512i src, __mmask8 mask_rem_v, void const* addr) {
  return _mm256_mask_loadu_ps(
      _mm256_castsi256_ps(_mm512_castsi512_si256(src)), mask_rem_v, addr);
}
92 | |
// Compresses the weight lanes of src selected by compress_mask_v down to the
// low lanes (remaining lanes filled from zero_v), then stores the lanes
// selected by store_mask_v to addr. Used to append the weights of surviving
// (non-pruned) indices. int32_t indices: 16 float lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void mask_compress_and_store_weights(
    void* addr,
    __m512i zero_v,
    __mmask16 compress_mask_v,
    __mmask16 store_mask_v,
    __m512 src) {
  __m512 out_weights_v = _mm512_mask_compress_ps(
      _mm512_castsi512_ps(zero_v), compress_mask_v, src);
  _mm512_mask_storeu_ps(addr, store_mask_v, out_weights_v);
}

// int64_t indices: 8 float lanes; only the low 256 bits of zero_v are used
// as the fill value.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void mask_compress_and_store_weights(
    void* addr,
    __m512i zero_v,
    __mmask8 compress_mask_v,
    __mmask8 store_mask_v,
    __m256 src) {
  __m256 out_weights_v = _mm256_mask_compress_ps(
      _mm256_castsi256_ps(_mm512_castsi512_si256(zero_v)),
      compress_mask_v,
      src);
  _mm256_mask_storeu_ps(addr, store_mask_v, out_weights_v);
}
122 | |
// Builds the valid-lane mask for a partially filled vector: the low `rem`
// bits are set, all higher bits clear. Used on remainder iterations where
// only `rem` of the vector lanes hold real data.
// int32_t indices -> 16-lane (__mmask16) mask.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __mmask16 mask_from_rem(int rem) {
  const long long low_bits_set = (static_cast<long long>(1) << rem) - 1;
  return static_cast<__mmask16>(low_bits_set);
}

// int64_t indices -> 8-lane (__mmask8) mask.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __mmask8 mask_from_rem(int rem) {
  const long long low_bits_set = (static_cast<long long>(1) << rem) - 1;
  return static_cast<__mmask8>(low_bits_set);
}
138 | |
// Masked load of indices: enabled lanes come from addr, disabled lanes from
// zero_v. int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i
mask_load(__m512i zero_v, __mmask16 mask_rem_v, void const* addr) {
  return _mm512_mask_loadu_epi32(zero_v, mask_rem_v, addr);
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i
mask_load(__m512i zero_v, __mmask8 mask_rem_v, void const* addr) {
  return _mm512_mask_loadu_epi64(zero_v, mask_rem_v, addr);
}
154 | |
// Zero-masked load of indices: enabled lanes come from addr, disabled lanes
// are zeroed. int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i maskz_load(__mmask16 mask_rem_v, void const* addr) {
  return _mm512_maskz_loadu_epi32(mask_rem_v, addr);
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i maskz_load(__mmask8 mask_rem_v, void const* addr) {
  return _mm512_maskz_loadu_epi64(mask_rem_v, addr);
}
168 | |
// Per-lane blend: lanes enabled in mask_rem_v are taken from a, disabled
// lanes from src. int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i mask_mov(__m512i src, __mmask16 mask_rem_v, __m512i a) {
  return _mm512_mask_mov_epi32(src, mask_rem_v, a);
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i mask_mov(__m512i src, __mmask8 mask_rem_v, __m512i a) {
  return _mm512_mask_mov_epi64(src, mask_rem_v, a);
}
182 | |
// Gathers 32-bit entries of the (int32) mapping table at the given indices.
// int32_t: 16 gathers with 32-bit indices, scale 4 (sizeof(int32_t)).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i gather(__m512i indices, const int32_t* addr) {
  return _mm512_i32gather_epi32(indices, addr, 4);
}

// int64_t: gathers 8 x 32-bit entries through 64-bit indices, then
// sign-extends the results to 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i gather(__m512i indices, const int32_t* addr) {
  // ToDo: Change this _mm512_i64gather_epi64 once mapping table is 64-bit
  __m256i res_32 = _mm512_i64gather_epi32(indices, addr, 4);
  return _mm512_cvtepi32_epi64(res_32);
}
198 | |
// Masked gather from the (int32) mapping table: lanes enabled in mask_rem_v
// are gathered at `indices`, disabled lanes keep the value of src.
// int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512i mask_gather(
    __m512i src,
    __mmask16 mask_rem_v,
    __m512i indices,
    const int32_t* addr) {
  return _mm512_mask_i32gather_epi32(src, mask_rem_v, indices, addr, 4);
}

// int64_t: masked gather of 8 x 32-bit entries through 64-bit indices
// (fallback lanes from the low 256 bits of src), sign-extended to 64-bit.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512i mask_gather(
    __m512i src,
    __mmask8 mask_rem_v,
    __m512i indices,
    const int32_t* addr) {
  // ToDo: Change this _mm512_mask_i64gather_epi64 once mapping table is 64-bit
  __m256i res_32 = _mm512_mask_i64gather_epi32(
      _mm512_castsi512_si256(src), mask_rem_v, indices, addr, 4);
  return _mm512_cvtepi32_epi64(res_32);
}
223 | |
// Returns the mask of lanes whose (remapped) index is >= 0, i.e. the lanes
// that survived pruning (pruned entries are remapped to -1).
// int32_t: 16 x 32-bit compare.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __mmask16 gen_mask(__m512i indices, __m512i zero_v) {
  return _mm512_cmpge_epi32_mask(indices, zero_v);
}

// int64_t: 8 x 64-bit compare.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __mmask8 gen_mask(__m512i indices, __m512i zero_v) {
  return _mm512_cmpge_epi64_mask(indices, zero_v);
}
237 | |
// Compresses the index lanes selected by mask down to the low lanes and
// stores them contiguously at addr. The ps/pd intrinsics are used only for
// their 32-/64-bit element width; the data are integer indices.
// int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void compress_store(void* addr, __mmask16 mask, __m512i src_v) {
  _mm512_mask_compressstoreu_ps(addr, mask, _mm512_castsi512_ps(src_v));
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void compress_store(void* addr, __mmask8 mask, __m512i src_v) {
  _mm512_mask_compressstoreu_pd(addr, mask, _mm512_castsi512_pd(src_v));
}
251 | |
// Compresses the float weight lanes selected by mask down to the low lanes
// and stores them contiguously at addr. int32_t indices: 16 float lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void
compress_store_weights(void* addr, __mmask16 mask, __m512 src_v) {
  _mm512_mask_compressstoreu_ps(addr, mask, src_v);
}

// int64_t indices: 8 float lanes (256-bit weight register).
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void
compress_store_weights(void* addr, __mmask8 mask, __m256 src_v) {
  _mm256_mask_compressstoreu_ps(addr, mask, src_v);
}
267 | |
// In-register compress: lanes of src_v selected by mask are packed into the
// low lanes, remaining lanes are filled from zero_v. The ps/pd casts only
// pick the 32-/64-bit element width; the data are integer indices.
// int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline __m512 compress(__m512i zero_v, __mmask16 mask, __m512i src_v) {
  return _mm512_mask_compress_ps(
      _mm512_castsi512_ps(zero_v), mask, _mm512_castsi512_ps(src_v));
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline __m512d compress(__m512i zero_v, __mmask8 mask, __m512i src_v) {
  return _mm512_mask_compress_pd(
      _mm512_castsi512_pd(zero_v), mask, _mm512_castsi512_pd(src_v));
}
283 | |
// Masked store of index lanes (held in a float-typed register from
// compress()) to addr; only lanes enabled in mask are written.
// int32_t: 16 x 32-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int32_t>::value, int>::type = 0>
static inline void mask_store(void* addr, __mmask16 mask, __m512 src_v) {
  _mm512_mask_storeu_epi32(addr, mask, _mm512_castps_si512(src_v));
}

// int64_t: 8 x 64-bit lanes.
template <
    typename T,
    typename std::enable_if<std::is_same<T, int64_t>::value, int>::type = 0>
static inline void mask_store(void* addr, __mmask8 mask, __m512d src_v) {
  _mm512_mask_storeu_epi64(addr, mask, _mm512_castpd_si512(src_v));
}
297 | |
// Copies `len` bytes from src to dest: full 64-byte AVX512 copies, then one
// byte-granular masked copy for the 0..63-byte tail. The copy proceeds
// front-to-back, so overlapping ranges are only safe when dest <= src —
// which holds for the left-compaction done by compressed_indices_remap_avx512.
static inline void mymemcpy(const char* src, char* dest, int len) {
  constexpr int VLEN = 64; // bytes per 512-bit register
  int i = 0;
  for (; i < len / VLEN * VLEN; i += VLEN) {
    auto src_v = _mm512_loadu_si512(src + i);
    _mm512_storeu_si512(dest + i, src_v);
  }
  int rem = len - i;
  if (rem > 0) {
    // Unsigned shift: rem can be up to 63, and (1LL << 63) would rely on
    // implementation-defined signed conversion.
    __mmask64 mask_rem_v = (((unsigned long long)1) << rem) - 1;
    auto src_v = _mm512_maskz_loadu_epi8(mask_rem_v, src + i);
    _mm512_mask_storeu_epi8(dest + i, mask_rem_v, src_v);
  }
}
313 | |
// Core vectorized remap step. For each of the UNROLL segments, processes (up
// to) one vector of indices: loads them from
// indices[offsets[i] + ind_w_start_offsets[i]], gathers their remapped ids
// from compressed_indices_mapping, drops entries remapped to -1 (pruned),
// and appends the survivors — and their weights when HAS_WEIGHTS — at
// out_indices/out_weights[offsets[i] + count_indices[i]], bumping
// count_indices[i] by the survivor count.
//
// USE_MASK enables the remainder path: only the low rem[i] lanes of segment
// i are valid; invalid lanes are forced to -1 before compression so they are
// discarded (a segment with rem[i] == 0 is effectively a no-op).
template <
    typename IndexType,
    bool HAS_WEIGHTS,
    int UNROLL = 8,
    bool USE_MASK = false>
static inline void compressed_indices_remap_avx512_helper(
    __m512i zero_v,
    __m512i minus1_v,
    const IndexType* offsets,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const float* weights,
    IndexType* out_indices,
    float* out_weights,
    IndexType* count_indices,
    const int32_t* rem,
    const int32_t* ind_w_start_offsets) {
  // Build every segment's valid-lane mask up front.
  typename reg_t<IndexType>::mask_reg_t mask_rem_v[UNROLL];
  for (int i = 0; i < UNROLL; ++i) {
    mask_rem_v[i] = mask_from_rem<IndexType>(rem[i]);
  }
  for (int i = 0; i < UNROLL; ++i) {
    // Load one vector of input indices (masked on the remainder path).
    __m512i indices_v;
    if (USE_MASK) {
      indices_v = mask_load<IndexType>(
          zero_v,
          mask_rem_v[i],
          reinterpret_cast<void const*>(
              indices + offsets[i] + ind_w_start_offsets[i]));
    } else {
      indices_v = load<IndexType>(reinterpret_cast<void const*>(
          indices + offsets[i] + ind_w_start_offsets[i]));
    }

    // gather remapped indices from the mapping table
    __m512i remapped_indices_v;
    if (USE_MASK) {
      remapped_indices_v = mask_gather<IndexType>(
          zero_v, mask_rem_v[i], indices_v, compressed_indices_mapping);
      // mov -1 to not used places in the vector
      remapped_indices_v =
          mask_mov<IndexType>(minus1_v, mask_rem_v[i], remapped_indices_v);

    } else {
      remapped_indices_v =
          gather<IndexType>(indices_v, compressed_indices_mapping);
    }

    // Load the matching weights so they can be compacted the same way.
    typename reg_t<IndexType>::w_reg_t weights_v;
    if (HAS_WEIGHTS) {
      if (USE_MASK) {
        weights_v = mask_load_weights<IndexType>(
            zero_v,
            mask_rem_v[i],
            reinterpret_cast<void const*>(
                weights + offsets[i] + ind_w_start_offsets[i]));
      } else {
        weights_v = load_weights<IndexType>(
            weights + offsets[i] + ind_w_start_offsets[i]);
      }
    }

    // Now remove -1 from the remapped indices
    // (survivor mask: lanes whose remapped index is >= 0)
    auto mask_indices_v = gen_mask<IndexType>(remapped_indices_v, zero_v);

    if (USE_MASK) {
      // Compress survivors to the low lanes, then store up to rem[i] lanes.
      // Lanes past the survivor count hold fill values; they are either
      // overwritten by later appends or ignored by the final compaction.
      auto out_indices_v =
          compress<IndexType>(zero_v, mask_indices_v, remapped_indices_v);

      mask_store<IndexType>(
          reinterpret_cast<void*>(out_indices + offsets[i] + count_indices[i]),
          mask_rem_v[i],
          out_indices_v);
    } else {
      compress_store<IndexType>(
          reinterpret_cast<void*>(out_indices + offsets[i] + count_indices[i]),
          mask_indices_v,
          remapped_indices_v);
    }

    if (HAS_WEIGHTS) {
      if (USE_MASK) {
        mask_compress_and_store_weights<IndexType>(
            reinterpret_cast<void*>(
                out_weights + offsets[i] + count_indices[i]),
            zero_v,
            mask_indices_v,
            mask_rem_v[i],
            weights_v);
      } else {
        compress_store_weights<IndexType>(
            reinterpret_cast<void*>(
                out_weights + offsets[i] + count_indices[i]),
            mask_indices_v,
            weights_v);
      }
    }

    // Advance this segment's output cursor by the number of survivors.
    count_indices[i] += _mm_popcnt_u32(mask_indices_v);
  }
}
415 | |
// Remaps a CSR-like structure (offsets/indices, optional per-index weights)
// through a pruning table: compressed_indices_mapping holds the new row id
// for each original index, or -1 for pruned rows. Pruned entries are removed
// and the survivors compacted; new segment boundaries go to out_offsets
// (with out_offsets[0] = offsets[0]).
//
// Strategy: segments are processed UNROLL (=8) at a time. While at least one
// of the 8 still has a full vector of work, the masked helper runs with
// rem = VLEN for active segments and 0 for exhausted ones; a final masked
// call consumes each segment's tail. Leftover segments (offsets_len - 1 not
// a multiple of UNROLL) are handled one at a time. Intermediate results are
// written at the *input* offsets and compacted into place at the end.
template <typename IndexType, bool HAS_WEIGHTS>
void compressed_indices_remap_avx512(
    std::int32_t offsets_len,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null,
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights) {
  __m512i zero_v = _mm512_set1_epi32(0);
  __m512i minus1_v = _mm512_set1_epi32(-1);
  out_offsets[0] = offsets[0];
  constexpr int UNROLL = 8;
  constexpr int VLEN = get_vlen<IndexType>();
  int k = 1;
  for (; k < (offsets_len - 1) / UNROLL * UNROLL; k += UNROLL) {
    // Per-segment remaining lengths and per-call valid-lane counts.
    int32_t len[UNROLL];
    int32_t rem[UNROLL];
    for (int l = 0; l < UNROLL; ++l) {
      len[l] = offsets[k + l] - offsets[k + l - 1];
    }
    // count of non-pruned indices
    IndexType count_indices[UNROLL] = {0};
    // read indices/weights starting at these offsets
    int32_t ind_w_start_offsets[UNROLL] = {0};
    __m256i vec_len_v = _mm256_set1_epi32(VLEN);
    __m256i len_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
    // Mask of segments that still have at least one full vector of work.
    __mmask8 cmp_res_v = _mm256_cmpge_epi32_mask(len_v, vec_len_v);
    len_v = _mm256_mask_sub_epi32(len_v, cmp_res_v, len_v, vec_len_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(len), len_v);
    // rem[l] = VLEN for active segments, 0 for the others.
    __m256i rem_v = _mm256_maskz_mov_epi32(cmp_res_v, vec_len_v);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), rem_v);
    int active_unrolls = _mm_popcnt_u32(cmp_res_v);

    // if we have any at least 1 full vector length work
    // take vector path
    while (active_unrolls > 0) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL,
          true>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);

      // Advance the read cursors of the segments that consumed a full vector.
      __m256i start_offsets_v = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(ind_w_start_offsets));
      start_offsets_v = _mm256_mask_add_epi32(
          start_offsets_v, cmp_res_v, start_offsets_v, vec_len_v);
      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(ind_w_start_offsets), start_offsets_v);

      // Recompute which segments still have a full vector left.
      len_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
      cmp_res_v = _mm256_cmpge_epi32_mask(len_v, vec_len_v);
      len_v = _mm256_mask_sub_epi32(len_v, cmp_res_v, len_v, vec_len_v);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(len), len_v);
      rem_v = _mm256_maskz_mov_epi32(cmp_res_v, vec_len_v);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), rem_v);
      active_unrolls = _mm_popcnt_u32(cmp_res_v);
    }

    // Now work on all the remainders
    // (rem[l] = leftover length of segment l, each now < VLEN)
    __m256i len_rem_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(len));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rem), len_rem_v);
    compressed_indices_remap_avx512_helper<
        IndexType,
        HAS_WEIGHTS,
        UNROLL,
        true>(
        zero_v,
        minus1_v,
        offsets + k - 1,
        indices,
        compressed_indices_mapping,
        weights,
        out_indices,
        out_weights,
        count_indices,
        rem,
        ind_w_start_offsets);

    // update output offsets
    for (int l = 0; l < UNROLL; ++l) {
      out_offsets[k + l] = out_offsets[k + l - 1] + count_indices[l];
    }
  }

  // work on remaining offsets_len serially
  constexpr int UNROLL_REM = 1;
  for (; k < offsets_len; ++k) {
    int32_t len[UNROLL_REM];
    int32_t rem[UNROLL_REM] = {0};
    for (int l = 0; l < UNROLL_REM; ++l) {
      len[l] = offsets[k + l] - offsets[k + l - 1];
    }
    IndexType count_indices[UNROLL_REM] = {0};
    int32_t ind_w_start_offsets[UNROLL_REM] = {0};
    int i = 0;
    // Full-vector chunks of this single segment (unmasked helper).
    for (; i < len[0] / VLEN * VLEN; i += VLEN) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL_REM,
          false>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);
      ind_w_start_offsets[0] += VLEN;
    }
    // remainder
    rem[0] = len[0] - i;
    if (rem[0] > 0) {
      compressed_indices_remap_avx512_helper<
          IndexType,
          HAS_WEIGHTS,
          UNROLL_REM,
          true>(
          zero_v,
          minus1_v,
          offsets + k - 1,
          indices,
          compressed_indices_mapping,
          weights,
          out_indices,
          out_weights,
          count_indices,
          rem,
          ind_w_start_offsets);
    }

    for (int l = 0; l < UNROLL_REM; ++l) {
      out_offsets[k + l] = out_offsets[k + l - 1] + count_indices[l];
    }
  }

  // Results are stored at input offsets in output variables
  // copy results to right output locations
  // (out_offsets[i-1] <= offsets[i-1] since segments only shrink, so this
  // left-compaction is safe with mymemcpy's forward copy)
  for (int i = 1; i < offsets_len; ++i) {
    int out_len = out_offsets[i] - out_offsets[i - 1];
    mymemcpy(
        reinterpret_cast<char*>(out_indices + offsets[i - 1]),
        reinterpret_cast<char*>(out_indices + out_offsets[i - 1]),
        out_len * sizeof(IndexType));
    if (HAS_WEIGHTS) {
      mymemcpy(
          reinterpret_cast<char*>(out_weights + offsets[i - 1]),
          reinterpret_cast<char*>(out_weights + out_offsets[i - 1]),
          out_len * sizeof(float));
    }
  }
}
586 | |
// Explicit instantiations for the four supported combinations:
// {int32_t, int64_t} indices x {with, without} weights.
#define INSTANTIATE_REMAP_BASE(INDEX_TYPE, HAS_WEIGHTS)                    \
  template void compressed_indices_remap_avx512<INDEX_TYPE, HAS_WEIGHTS>( \
      std::int32_t offsets_numel,                                         \
      const INDEX_TYPE* indices,                                          \
      const int32_t* compressed_indices_mapping,                          \
      const INDEX_TYPE* offsets,                                          \
      const float* weights,                                               \
      INDEX_TYPE* out_indices,                                            \
      INDEX_TYPE* out_offsets,                                            \
      float* out_weights);

INSTANTIATE_REMAP_BASE(int32_t, true)
INSTANTIATE_REMAP_BASE(int32_t, false)
INSTANTIATE_REMAP_BASE(int64_t, true)
INSTANTIATE_REMAP_BASE(int64_t, false)

#undef INSTANTIATE_REMAP_BASE
604 | |
605 | } // namespace internal |
606 | } // namespace fbgemm |
607 | |