/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS

#include "fbgemm/FbgemmEmbedding.h"

#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cassert>
#include <cmath>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <tuple>
#include "./CodeCache.h"
#include "./MaskAvx2.h"
#include "./RefImplementations.h"
#include "fbgemm/FbgemmConvert.h"
#include "fbgemm/SimdUtils.h"

namespace fbgemm {

namespace {

namespace x86 = asmjit::x86;

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    bool ROWWISE_SPARSE>
class ReturnFunctionSignature {};

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType>
class ReturnFunctionSignature<inType, indxType, offsetType, outType, false> {
 public:
  using jit_embedding_kernel = bool (*)(
      int64_t output_size,
      int64_t index_size,
      int64_t data_size,
      const inType* input,
      const indxType* indices,
      const offsetType* offsets_or_lengths,
      const float* weights,
      outType* out,
      const int* mask);
};
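
// Rough semantics of a generated kernel (a sketch, not a formal spec): for
// each output row i and each column k in [0, block_size), it computes
//   out[i * output_stride + k] =
//       scale_i * sum_{j in bag i} w_j * row(indices[j])[k]
// where bag i is delimited by offsets_or_lengths, w_j is weights[j] (or 1.0f
// when weights == nullptr), scale_i is 1/length_i when normalize_by_lengths
// is set (1.0f otherwise), and row() applies the per-row scale/bias for
// uint8_t inputs. The kernel returns false if an index is out of bounds.
// `mask` supplies the AVX2 remainder mask (unused on AVX512).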

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType>
class ReturnFunctionSignature<inType, indxType, offsetType, outType, true> {
 public:
  using jit_embedding_kernel = bool (*)(
      int64_t output_size,
      int64_t index_size,
      int64_t uncompressed_data_size,
      // int64_t compressed_data_size,
      const inType* input,
      const indxType* indices,
      const offsetType* offsets_or_lengths,
      const float* weights,
      outType* out,
      const int32_t* compressed_indices_table,
      const int* mask);
};

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    inst_set_t instSet,
    bool ROWWISE_SPARSE = false,
    bool THREAD_LOCAL = false>
class GenEmbeddingSpMDMLookup {
 public:
  GenEmbeddingSpMDMLookup() {}
  typename ReturnFunctionSignature<
      inType,
      indxType,
      offsetType,
      outType,
      ROWWISE_SPARSE>::jit_embedding_kernel
  getOrCreate(
      int block_size,
      bool has_weight,
      bool is_weight_positional,
      bool normalize_by_lengths,
      int prefetch,
      bool use_offsets,
      int output_stride,
      int input_stride,
      bool scale_bias_last,
      bool isbf16);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
                                  // depends on other static
                                  // variables. Required to prevent
                                  // initialization order fiasco
    return rt;
  }

  static std::mutex rtMutex_; ///< Control access to runtime;

  // The hash depends on embedding dimension (block size), weighted sls,
  // positional weights, normalize by lengths, prefetch distance, use_offsets,
  // output_stride, input_stride, scale_bias_last, and isbf16
  static CodeCache<
      std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool>,
      typename ReturnFunctionSignature<
          inType,
          indxType,
          offsetType,
          outType,
          ROWWISE_SPARSE>::jit_embedding_kernel,
      THREAD_LOCAL>
      codeCache_; ///< JIT Code Cache for reuse.
}; // GenEmbeddingSpMDMLookup

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    inst_set_t instSet,
    bool ROWWISE_SPARSE,
    bool THREAD_LOCAL>
std::mutex GenEmbeddingSpMDMLookup<
    inType,
    indxType,
    offsetType,
    outType,
    instSet,
    ROWWISE_SPARSE,
    THREAD_LOCAL>::rtMutex_;

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    inst_set_t instSet,
    bool ROWWISE_SPARSE,
    bool THREAD_LOCAL>
CodeCache<
    std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool>,
    typename ReturnFunctionSignature<
        inType,
        indxType,
        offsetType,
        outType,
        ROWWISE_SPARSE>::jit_embedding_kernel,
    THREAD_LOCAL>
    GenEmbeddingSpMDMLookup<
        inType,
        indxType,
        offsetType,
        outType,
        instSet,
        ROWWISE_SPARSE,
        THREAD_LOCAL>::codeCache_;

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    inst_set_t instSet,
    bool ROWWISE_SPARSE,
    bool THREAD_LOCAL>
typename ReturnFunctionSignature<
    inType,
    indxType,
    offsetType,
    outType,
    ROWWISE_SPARSE>::jit_embedding_kernel
GenEmbeddingSpMDMLookup<
    inType,
    indxType,
    offsetType,
    outType,
    instSet,
    ROWWISE_SPARSE,
    THREAD_LOCAL>::
    getOrCreate(
        int block_size,
        bool has_weight,
        bool is_weight_positional,
        bool normalize_by_lengths,
        int prefetch,
        bool use_offsets,
        int output_stride,
        int input_stride,
        bool scale_bias_last,
        bool isbf16) {
  std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool> kernelSig =
      std::make_tuple(
          block_size,
          has_weight,
          is_weight_positional,
          normalize_by_lengths,
          prefetch,
          use_offsets,
          output_stride,
          input_stride,
          scale_bias_last,
          isbf16);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() -> typename ReturnFunctionSignature<
                   inType,
                   indxType,
                   offsetType,
                   outType,
                   ROWWISE_SPARSE>::jit_embedding_kernel {
        bool is8bit = std::is_same<inType, uint8_t>::value;
        bool is16bit = std::is_same<inType, uint16_t>::value;
        bool is16bitout = std::is_same<outType, uint16_t>::value;
        bool isbf16out = isbf16;
        bool isfp16 = is16bit && !isbf16;
        bool isfp16out = is16bitout && !isbf16out;

        // TODO: Make this tunable
        int pref_dist = prefetch;
        bool areIndices64b = std::is_same<indxType, int64_t>::value;

        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
        std::string filename = "embeddinglookup";
        if (is8bit) {
          filename += "_8bit";
        } else if (isfp16) {
          filename += "_fp16";
        } else if (isbf16) {
          filename += "_bf16";
        }
        if (isbf16out) {
          filename += "_bf16_out";
        } else if (isfp16out) {
          filename += "_fp16_out";
        }
        filename += "_emd_dim_" + std::to_string(block_size);
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        if (has_weight) {
          filename += "_hasweight";
        }
        if (normalize_by_lengths) {
          filename += "_normalize_by_lengths";
        }
        if (!use_offsets) {
          filename += "_use_lengths";
        }
        if (ROWWISE_SPARSE) {
          filename += "_rowwise_sparse";
        }
        filename += "_out_stride_" + std::to_string(output_stride);
        if (!scale_bias_last) {
          filename += "_scale_bias_first";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif
        // arguments to the function created
        x86::Gp output_size = a->zdi();
        // index_size will be overwritten to hold the end address of indices
        x86::Gp index_size = a->zsi();
        x86::Gp data_size = a->zdx();
        x86::Gp input = a->zcx();
        int reg_id = 8;
        x86::Gp indices = a->gpz(reg_id); // 8
        ++reg_id;
        x86::Gp lengths = a->gpz(reg_id); // 9
        ++reg_id;
        x86::Gp weights = a->gpz(reg_id); // 10
        ++reg_id;
        x86::Gp out = a->gpz(reg_id); // 11

        x86::Gp compressed_indices_table;
        if (ROWWISE_SPARSE) {
          ++reg_id;
          compressed_indices_table = a->gpz(reg_id); // 12
        }
        ++reg_id;
        x86::Gp scratchReg1_ = a->gpz(reg_id); // 12 or 13, also for mask

        ++reg_id;
        x86::Gpd lengths_R_ = a->gpz(reg_id).r32(); // 13 or 14
        ++reg_id;
        x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15

        asmjit::FuncDetail func;

        if (ROWWISE_SPARSE) {
          func.init(
              asmjit::FuncSignatureT<
                  bool,
                  int64_t, // output_size
                  int64_t, // index_size
                  int64_t, // uncompressed_data_size
                  const inType*, // input uint8_t or float
                  const indxType*, // indices
                  const offsetType*, // offsets or lengths
                  const float*, // weights
                  outType*, // out
                  const int32_t*, // compressed_indices_table and then mask
                  const int*>(asmjit::CallConvId::kHost),
              a->environment());
        } else {
          func.init(
              asmjit::FuncSignatureT<
                  bool,
                  int64_t, // output_size
                  int64_t, // index_size
                  int64_t, // data_size
                  const inType*, // input uint8_t or float
                  const indxType*, // indices
                  const offsetType*, // offsets or lengths
                  const float*, // weights
                  outType*, // out and then mask
                  const int*>(asmjit::CallConvId::kHost),
              a->environment());
        }

        asmjit::FuncFrame frame;
        frame.init(func);

        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            reg_id == 15
                ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
                : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

        asmjit::FuncArgsAssignment args(&func);
        if (ROWWISE_SPARSE) {
          args.assignAll(
              output_size,
              index_size,
              data_size,
              input,
              indices,
              lengths,
              weights,
              out,
              compressed_indices_table,
              scratchReg1_);
        } else {
          args.assignAll(
              output_size,
              index_size,
              data_size,
              input,
              indices,
              lengths,
              weights,
              out,
              scratchReg1_);
        }

        args.updateFuncFrame(frame);
        frame.finalize();

        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
        int unroll_factor = NUM_VEC_REG;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        vec_reg_t scale_vreg; // holds scale
        vec_reg_t bias_vreg; // holds bias
        vec_reg_t w_vreg; // for weighted sls -- weights
        vec_reg_t
            vlen_inv_vreg; // used for normalize by lengths -- 1 / lengths[i]
        vec_reg_t src_vreg; // for holding embedding value temporarily
        x86::Ymm mask_vreg; // mask for avx2
        x86::Xmm mask_fp16_vreg; // mask for loading fp16 in avx2
        vec_reg_t ones_vreg; // 2^15 for bf16_2_fp32_rn

        if (is8bit) {
          // We need 2 vec registers for 1. scale 2. bias
          --unroll_factor;
          scale_vreg = vec_reg_t(unroll_factor);
          --unroll_factor;
          bias_vreg = vec_reg_t(unroll_factor);
        }

        if (isbf16out) {
          --unroll_factor;
          ones_vreg = vec_reg_t(unroll_factor);
          a->mov(scratchReg2_, 1 << 15);
          a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0);
          a->vpbroadcastd(ones_vreg, ones_vreg.xmm());
        }
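        // ones_vreg now holds 0x8000 in each 32-bit lane; the fp32 -> bf16
        // conversion emitted below effectively computes
        //   bf16_bits = (fp32_bits + 0x8000) >> 16,
        // i.e. round-half-up to the nearest bf16. For example,
        // 0x3F80C000 + 0x8000 = 0x3F814000 -> 0x3F81 (rounded up), while
        // 0x3F804000 + 0x8000 = 0x3F80C000 -> 0x3F80 (rounded down).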

        if (is8bit || is16bit || (remainder && instSet == inst_set_t::avx2)) {
          --unroll_factor;
          src_vreg = vec_reg_t(unroll_factor);
        }

        if (has_weight) {
          --unroll_factor;
          w_vreg = vec_reg_t(unroll_factor);
        }

        if (remainder && instSet == inst_set_t::avx2) {
          // AVX512 doesn't need to use vector register for masking
          --unroll_factor;
          mask_vreg = x86::ymm(unroll_factor);
          if (remainder > 1 && (is16bit || isbf16out || isfp16out)) {
            --unroll_factor;
            mask_fp16_vreg = x86::xmm(unroll_factor);
          }
        }

        if (normalize_by_lengths) {
          --unroll_factor;
          vlen_inv_vreg = vec_reg_t(unroll_factor);
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            a->vmovups(
                mask_vreg,
                x86::ymmword_ptr(
                    scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t)));
            if (is16bit || isbf16out || isfp16out) {
              if (remainder > 1) {
                a->vmovups(
                    mask_fp16_vreg,
                    x86::xmmword_ptr(
                        scratchReg1_,
                        (vlen - remainder / 2) * sizeof(int32_t)));
              }
              // We need to keep using the stack during the main loop
              a->lea(
                  x86::rsp,
                  x86::dword_ptr(
                      x86::rsp, static_cast<int32_t>(-vlen * sizeof(int32_t))));
            }
          } else {
            a->mov(scratchReg1_, (1 << remainder) - 1);
            a->kmovw(x86::k(1), scratchReg1_);
          }
        }
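        // For instance, with 32-bit lanes and remainder == 3, the AVX512 path
        // sets k1 = 0b111 so masked ops touch only the first 3 elements of the
        // last vector, while the AVX2 path loads a per-element 0/-1 mask from
        // the table passed in through scratchReg1_.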

        // Compute the end address of indices
        a->lea(
            index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2));
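        // i.e. index_size now points one past the last index
        // (indices + index_size * sizeof(indxType)); the bounds checks below
        // compare running pointers against it.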

        asmjit::Label exit = a->newLabel();
        asmjit::Label error = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        // rangeIndex loop begins (iterate output_size times)
        a->bind(LoopRangeIndexBegin);
        a->dec(output_size);
        a->jl(LoopRangeIndexEnd);

        if (normalize_by_lengths) {
          asmjit::Label IfLengthsBegin = a->newLabel();
          asmjit::Label IfLengthsEnd = a->newLabel();
          a->bind(IfLengthsBegin);
          if (use_offsets) {
            a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
            a->sub(lengths_R_, x86::dword_ptr(lengths));
          } else {
            a->mov(lengths_R_, x86::dword_ptr(lengths));
          }
          a->cmp(lengths_R_, 1);
          // Initialize vlen_inv as 0 in case lengths is 0
          a->vxorps(vlen_inv_vreg, vlen_inv_vreg, vlen_inv_vreg);
          a->jl(IfLengthsEnd);

          // OK to clobber vreg0 here: it is reserved for out_vreg, which is
          // re-initialized at the start of the main loop
          vec_reg_t temp_vreg(0);
          if (instSet == inst_set_t::avx2) {
            a->mov(scratchReg1_, 1);
            a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_);
            a->cvtsi2ss(temp_vreg.xmm(), lengths_R_);
            a->divss(vlen_inv_vreg.xmm(), temp_vreg.xmm());
            a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm());
          } else { // avx512
            a->mov(scratchReg1_, 1);
            a->cvtsi2ss(temp_vreg.xmm(), scratchReg1_);
            a->vpbroadcastd(vlen_inv_vreg, temp_vreg.xmm());
            a->vpbroadcastd(temp_vreg, lengths_R_);
            a->vcvtdq2ps(temp_vreg, temp_vreg);
            a->vdivps(vlen_inv_vreg, vlen_inv_vreg, temp_vreg);
          }
          a->bind(IfLengthsEnd);
        }

        for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
             vec_idx += unroll_factor) {
          int cur_unroll_factor =
              std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

          // Initialize output regs
          for (int v = 0; v < cur_unroll_factor; ++v) {
            vec_reg_t out_vreg = vec_reg_t(v);
            a->vxorps(out_vreg, out_vreg, out_vreg);
          }

          if (use_offsets) {
            a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
            a->sub(lengths_R_, x86::dword_ptr(lengths));
          } else {
            a->mov(lengths_R_, x86::dword_ptr(lengths));
          }

          // Array out-of-bound check
          a->lea(
              scratchReg1_,
              x86::ptr(indices, lengths_R_, areIndices64b ? 3 : 2));
          a->cmp(scratchReg1_, index_size);
          a->jg(error);

          asmjit::Label LoopDataIndexBegin = a->newLabel();
          asmjit::Label LoopDataIndexEnd = a->newLabel();
          asmjit::Label ValidIndexLabel = a->newLabel();

          // dataIndex loop begins (iterate lengths_R_ times)
          a->bind(LoopDataIndexBegin);
          a->dec(lengths_R_);
          a->jl(LoopDataIndexEnd);

          // Array out-of-bound check
          if (areIndices64b) {
            a->mov(scratchReg1_, x86::qword_ptr(indices));
          } else {
            a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
          }
          if (!scale_bias_last) {
            // When scale_bias_last == false, assume this is for table batched
            // embedding (TBE) that can get -1 for pruned rows.
            if (areIndices64b) {
              a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1));
            } else {
              a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
            }
            a->jne(ValidIndexLabel);
            a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));
            a->jmp(LoopDataIndexBegin);
            a->bind(ValidIndexLabel);
          }
          // A trick to check x >= data_size or x < 0 in one shot by treating
          // scratchReg1_ as if it has unsigned value
          // (https://stackoverflow.com/a/34072155).
          a->cmp(scratchReg1_, data_size);
          a->jae(error);
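          // e.g. with data_size == 100: a raw index of -1 becomes
          // 0xFFFFFFFFFFFFFFFF when viewed as unsigned, so both -1 and 100
          // satisfy the unsigned comparison and jump to the error label.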

          if (ROWWISE_SPARSE) {
            a->mov(
                scratchReg1_.r32(),
                x86::dword_ptr(
                    compressed_indices_table,
                    scratchReg1_,
                    2)); // use of 2 is to multiply by 4
          }

          int fused_block_size = input_stride * sizeof(inType);

          if (pref_dist) {
            asmjit::Label pref_dist_reset_start = a->newLabel();
            asmjit::Label pref_dist_reset_end = a->newLabel();
            // out of bound handling for prefetch
            a->lea(
                scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType)));
            a->cmp(scratchReg2_, index_size);
            a->jge(pref_dist_reset_start);

            if (areIndices64b) {
              a->mov(
                  scratchReg2_,
                  x86::qword_ptr(indices, pref_dist * sizeof(indxType)));
            } else {
              a->mov(
                  scratchReg2_.r32(),
                  x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
            }

            a->jmp(pref_dist_reset_end);

            a->bind(pref_dist_reset_start);
            // The prefetch index is out of bounds, so just prefetch the
            // current row. This could be improved to prefetch the row at the
            // maximum valid distance instead.
            if (areIndices64b) {
              a->mov(scratchReg2_, x86::qword_ptr(indices));
            } else {
              a->mov(scratchReg2_.r32(), x86::dword_ptr(indices));
            }

            a->bind(pref_dist_reset_end);
            if (ROWWISE_SPARSE) {
              asmjit::Label rowwise_sparse_pref_corner_case_begin =
                  a->newLabel();
              asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel();
              a->cmp(scratchReg2_, data_size);
              a->jae(rowwise_sparse_pref_corner_case_begin);

              a->mov(
                  scratchReg2_.r32(),
                  x86::dword_ptr(
                      compressed_indices_table,
                      scratchReg2_,
                      2)); // use of 2 is to multiply by 4
              a->test(scratchReg2_.r32(), scratchReg2_.r32());
              // Check negative
              a->jns(rowwise_sparse_pref_corner_case_end);

              a->bind(rowwise_sparse_pref_corner_case_begin);
              // For corner case, just set prefetch row id to 0.
              a->xor_(scratchReg2_.r32(), scratchReg2_.r32());
              a->bind(rowwise_sparse_pref_corner_case_end);
            }
            a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size));
          }

          a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));

          if (has_weight) {
            a->vbroadcastss(w_vreg, x86::dword_ptr(weights));
            a->add(weights, static_cast<asmjit::Imm>(sizeof(float)));
          }

          if (ROWWISE_SPARSE) {
            a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
            a->je(LoopDataIndexBegin);
          }

          a->imul(scratchReg1_, static_cast<asmjit::Imm>(fused_block_size));

          // broadcast the scale
          x86::Mem scale_src, bias_src;
          constexpr unsigned int CACHE_LINE_LEN = 64;
          if (is8bit) {
            if (scale_bias_last) {
              scale_src = x86::dword_ptr(
                  input, scratchReg1_, 0, block_size * sizeof(uint8_t));
              bias_src = x86::dword_ptr(
                  input,
                  scratchReg1_,
                  0,
                  block_size * sizeof(uint8_t) + sizeof(float));
              a->vbroadcastss(scale_vreg, scale_src);
              a->vbroadcastss(bias_vreg, bias_src);
            } else {
              scale_src = x86::word_ptr(input, scratchReg1_);
              bias_src =
                  x86::word_ptr(input, scratchReg1_, 0, sizeof(uint16_t));
              a->vpbroadcastw(scale_vreg.half(), scale_src);
              a->vpbroadcastw(bias_vreg.half(), bias_src);
              a->vcvtph2ps(scale_vreg, scale_vreg.half());
              a->vcvtph2ps(bias_vreg, bias_vreg.half());
            }

            if (pref_dist && fused_block_size % CACHE_LINE_LEN > 0 &&
                fused_block_size % CACHE_LINE_LEN <= 2 * sizeof(float)) {
              a->prefetcht0(x86::dword_ptr(
                  input,
                  scratchReg2_,
                  0,
                  fused_block_size / CACHE_LINE_LEN * CACHE_LINE_LEN));
            }
          }

          if (has_weight && is8bit) {
            a->vmulps(scale_vreg, scale_vreg, w_vreg);
            a->vmulps(bias_vreg, bias_vreg, w_vreg);
          }

          // The main computation
          int src_addr_offset =
              is8bit && !scale_bias_last ? 2 * sizeof(uint16_t) : 0;
          for (int v = 0; v < cur_unroll_factor; ++v) {
            constexpr int BYTES_PER_VLOAD = vlen * sizeof(inType);
            auto src_addr = x86::dword_ptr(
                input,
                scratchReg1_,
                0,
                src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD);
            vec_reg_t out_vreg = vec_reg_t(v);

            // For 8-bit SLS, convert unsigned 8-bit to 32-bit int, then to
            // float; multiply by scale and then add bias
            if (is8bit) {
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1 &&
                  instSet == inst_set_t::avx512) {
                a->k(x86::k(1)).z().vpmovzxbd(src_vreg, src_addr);
              } else {
                // We don't use a mask for AVX2 since we can use the extra
                // "padding" of the 2 floats (= 8 chars) for scale and bias;
                // this ensures we never access out-of-bound data
                a->vpmovzxbd(src_vreg, src_addr);
              }
              a->vcvtdq2ps(src_vreg, src_vreg);
              a->vaddps(out_vreg, out_vreg, bias_vreg);
              a->vfmadd231ps(out_vreg, src_vreg, scale_vreg);
            } else if (is16bit) {
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  if (remainder % 2 == 0) {
                    a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr);
                  } else {
                    a->vpbroadcastw(
                        src_vreg.xmm(),
                        x86::word_ptr(
                            input,
                            scratchReg1_,
                            0,
                            src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD +
                                (remainder - 1) * sizeof(inType)));
                    if (remainder > 1) {
                      // AVX2 can't mask the last 16-bit element, so store the
                      // elements to the stack and reload the combined vector.
                      // First put broadcasted last 16-bit element
                      a->vmovups(x86::xmmword_ptr(x86::rsp), src_vreg.xmm());
                      // Mask store the remaining 16-bit elements
                      a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr);
                      a->vmaskmovps(
                          x86::xmmword_ptr(x86::rsp),
                          mask_fp16_vreg,
                          src_vreg.xmm());
                      // Load combined 16-bit elements
                      a->vmovups(src_vreg.xmm(), x86::xmmword_ptr(x86::rsp));
                    } // remainder > 1
                  } // remainder % 2
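                  // e.g. for an odd remainder of 5: element 4 is broadcast
                  // into src_vreg and spilled to the stack, elements 0-3 are
                  // mask-loaded and mask-stored over that spill (2 dword
                  // lanes), and the merged halves are reloaded from the stack
                  // before the fp16/bf16 conversion below.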
                  if (isfp16) {
                    a->vcvtph2ps(src_vreg.ymm(), src_vreg.xmm());
                  } else if (isbf16) {
                    // bf16
                    a->vpmovzxwd(src_vreg.ymm(), src_vreg.xmm());
                    a->vpslld(src_vreg.ymm(), src_vreg.ymm(), 16);
                  }
                } else {
                  // avx512
                  if (isfp16) {
                    a->k(x86::k(1)).z().vcvtph2ps(src_vreg, src_addr);
                  } else if (isbf16) {
                    // bf16
                    a->k(x86::k(1)).z().vpmovzxwd(src_vreg, src_addr);
                    a->k(x86::k(1)).z().vpslld(src_vreg, src_vreg, 16);
                  }
                }
              } else {
                // no remainder
                if (isfp16) {
                  a->vcvtph2ps(src_vreg, src_addr);
                } else if (isbf16) {
                  // bf16
                  a->vpmovzxwd(src_vreg, src_addr);
                  a->vpslld(src_vreg, src_vreg, 16);
                }
              }
              if (has_weight) {
                a->vfmadd231ps(out_vreg, w_vreg, src_vreg);
              } else {
                a->vaddps(out_vreg, out_vreg, src_vreg);
              }
            } else {
              // This part is for FP32 SLS
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1 &&
                  instSet == inst_set_t::avx2) {
                a->vmaskmovps(src_vreg.ymm(), mask_vreg.ymm(), src_addr);
              }
              if (has_weight) {
                if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                  if (instSet == inst_set_t::avx2) {
                    a->vfmadd231ps(out_vreg, w_vreg, src_vreg);
                  } else {
                    a->k(x86::k(1)).vfmadd231ps(out_vreg, w_vreg, src_addr);
                  }
                } else {
                  a->vfmadd231ps(out_vreg, w_vreg, src_addr);
                }
              } else {
                if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                  if (instSet == inst_set_t::avx2) {
                    a->vaddps(out_vreg, out_vreg, src_vreg);
                  } else {
                    a->k(x86::k(1)).vaddps(out_vreg, out_vreg, src_addr);
                  }
                } else {
                  a->vaddps(out_vreg, out_vreg, src_addr);
                }
              }
            }

            constexpr int VLOAD_PER_CACHE_LINE =
                CACHE_LINE_LEN / BYTES_PER_VLOAD;
            if (pref_dist && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) {
              a->prefetcht0(x86::dword_ptr(
                  input, scratchReg2_, 0, (vec_idx + v) * BYTES_PER_VLOAD));
            }
          }

          a->jmp(LoopDataIndexBegin);
          a->bind(LoopDataIndexEnd);

          // This loop writes out_vreg (the accumulated results) back to memory
          for (int v = 0; v < cur_unroll_factor; ++v) {
            auto dst_addr =
                x86::dword_ptr(out, (vec_idx + v) * vlen * sizeof(outType));
            vec_reg_t out_vreg = vec_reg_t(v);

            if (normalize_by_lengths) {
              a->vmulps(out_vreg, out_vreg, vlen_inv_vreg);
            }

            if (std::is_same<outType, float>::value) {
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm());
                } else {
                  a->k(x86::k(1)).vmovups(dst_addr, out_vreg);
                }
              } else {
                a->vmovups(dst_addr, out_vreg);
              }
            } else {
              // fp16/bf16 output
              if (instSet == inst_set_t::avx2) {
                // round to nearest with no exceptions
                if (isfp16out) {
                  a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);
                } else if (isbf16out) {
                  a->vpaddd(out_vreg, out_vreg, ones_vreg);
                  a->vpsrld(out_vreg, out_vreg, 16);
                  a->vpackusdw(out_vreg, out_vreg, out_vreg);
                  a->vpermq(out_vreg, out_vreg, 0xd8);
                }
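                // vpackusdw packs the dword results to words within each
                // 128-bit lane (duplicated), and vpermq with 0xd8 gathers the
                // two useful qwords into the low 128 bits so the xmm store or
                // mask store below writes the bf16 results in order.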
                if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                  if (remainder > 1) {
                    a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm());
                  }
                  if (remainder % 2 != 0) {
                    a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm());
                    a->mov(
                        scratchReg1_.r16(),
                        x86::word_ptr(
                            x86::rsp, (remainder - 1) * sizeof(outType)));
                    a->mov(
                        x86::word_ptr(
                            out,
                            ((vec_idx + v) * vlen + (remainder - 1)) *
                                sizeof(outType)),
                        scratchReg1_.r16());
                  }
                } else {
                  a->vmovups(dst_addr, out_vreg.xmm());
                }
              } else {
                if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                  if (isfp16out) {
                    a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8);
                  } else if (isbf16out) {
                    // bf16
                    a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg);
                    a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16);
                    a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg);
                  }
                } else {
                  if (isfp16out) {
                    a->vcvtps2ph(dst_addr, out_vreg, 8);
                  } else if (isbf16out) {
                    // bf16
                    a->vpaddd(out_vreg, out_vreg, ones_vreg);
                    a->vpsrld(out_vreg, out_vreg, 16);
                    a->vpmovdw(dst_addr, out_vreg);
                  }
                }
              }
            }
          }

          if (vec_idx + unroll_factor < num_vec_regs_per_block ||
              (has_weight && is_weight_positional)) {
            // Reset lengths_R_, indices, weights to run the dataIndex loop
            // again
            if (use_offsets) {
              a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
              a->sub(lengths_R_, x86::dword_ptr(lengths));
            } else {
              a->mov(lengths_R_, x86::dword_ptr(lengths));
            }

            if (has_weight) {
              a->imul(
                  scratchReg1_,
                  lengths_R_,
                  static_cast<asmjit::Imm>(sizeof(float)));
              a->sub(weights, scratchReg1_);

              if (vec_idx + unroll_factor < num_vec_regs_per_block) {
                a->imul(
                    scratchReg1_,
                    static_cast<asmjit::Imm>(sizeof(indxType) / sizeof(float)));
                a->sub(indices, scratchReg1_);
              }
            } else {
              a->imul(
                  scratchReg1_,
                  lengths_R_,
                  static_cast<asmjit::Imm>(sizeof(indxType)));
              a->sub(indices, scratchReg1_);
            }
          }
        }

        a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType)));
        a->add(out, static_cast<asmjit::Imm>(output_stride * sizeof(outType)));

        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        a->cmp(indices, index_size);
        a->jne(error);
        a->mov(x86::eax, true);
        a->jmp(exit);
        a->bind(error);
        a->mov(x86::eax, false);
        a->bind(exit);

        if (remainder && instSet == inst_set_t::avx2 &&
            (is16bit || isbf16out || isfp16out)) {
          a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t)));
        }

        a->emitEpilog(frame);

        // jit_fused8bitembedding_kernel fn;
        typename ReturnFunctionSignature<
            inType,
            indxType,
            offsetType,
            outType,
            ROWWISE_SPARSE>::jit_embedding_kernel fn;
        asmjit::Error err;
        {
          std::unique_lock<std::mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          std::cout << "Error: in fn add" << std::endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
}

} // namespace

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    bool THREAD_LOCAL>
typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
    Type
    GenerateEmbeddingSpMDMWithStrides(
        const int64_t block_size,
        bool has_weight,
        bool normalize_by_lengths,
        int prefetch,
        bool is_weight_positional,
        bool use_offsets,
        int64_t output_stride /*=-1*/,
        int64_t input_stride /*=-1*/,
        bool scale_bias_last /*=true*/,
        bool no_bag /*=false*/,
        bool isbf16 /*=false*/) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  if (output_stride == -1) {
    output_stride = block_size;
  }
  if (input_stride == -1) {
    if (std::is_same<inType, uint8_t>::value) {
      const auto scale_bias_offset =
          2 * (scale_bias_last ? sizeof(float) : sizeof(uint16_t));
      input_stride = block_size + scale_bias_offset;
    } else {
      input_stride = block_size;
    }
  }
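  // For example, a uint8_t row of block_size 64 with a trailing fp32
  // scale/bias pair gets input_stride = 64 + 2 * 4 = 72 bytes, while a row
  // that stores fp16 scale/bias up front gets 64 + 2 * 2 = 68 bytes.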
  const inst_set_t isa = fbgemmInstructionSet();
  if (no_bag == true) {
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               outType* out) {
      return EmbeddingSpMDM_ref(
          block_size,
          output_size,
          index_size,
          data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          normalize_by_lengths,
          out,
          is_weight_positional,
          use_offsets,
          output_stride,
          input_stride,
          scale_bias_last,
          no_bag,
          isbf16);
    };
  }

  if ((std::is_same<inType, float>::value ||
       std::is_same<inType, uint16_t>::value) &&
      block_size == 1 && isYmm(isa) && output_stride == block_size &&
      input_stride == block_size && std::is_same<outType, float>::value) {
    return
        [=](int64_t output_size,
            int64_t index_size,
            int64_t data_size,
            const inType* input,
            const indxType* indices,
            const offsetType* offsets_or_lengths,
            const float* weights, // optional, can be null for non-weighted sum
            outType* out) {
          return internal::EmbeddingSpMDMBlockSize1_(
              output_size,
              index_size,
              data_size,
              input,
              indices,
              offsets_or_lengths,
              weights,
              normalize_by_lengths,
              reinterpret_cast<float*>(out),
              is_weight_positional,
              use_offsets,
              isbf16);
        };
  } else if (isZmm(isa)) {
    static GenEmbeddingSpMDMLookup<
        inType,
        indxType,
        offsetType,
        outType,
        inst_set_t::avx512,
        /*ROWWISE_SPARSE=*/false,
        THREAD_LOCAL>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        block_size,
        has_weight,
        is_weight_positional,
        normalize_by_lengths,
        prefetch,
        use_offsets,
        output_stride,
        input_stride,
        scale_bias_last,
        isbf16);
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               outType* out) {
      return original_func(
          output_size,
          index_size,
          data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          out,
          nullptr /* mask not used in avx512 */);
    };
  } else if (isYmm(isa)) {
    static GenEmbeddingSpMDMLookup<
        inType,
        indxType,
        offsetType,
        outType,
        inst_set_t::avx2,
        /*ROWWISE_SPARSE=*/false,
        THREAD_LOCAL>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        block_size,
        has_weight,
        is_weight_positional,
        normalize_by_lengths,
        prefetch,
        use_offsets,
        output_stride,
        input_stride,
        scale_bias_last,
        isbf16);
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               outType* out) {
      return original_func(
          output_size,
          index_size,
          data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          out,
          internal::avx2_ps_or_epi32_combined_mask);
    };
  } else {
#ifdef VLOG
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               outType* out) {
      return EmbeddingSpMDM_ref(
          block_size,
          output_size,
          index_size,
          data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          normalize_by_lengths,
          out,
          is_weight_positional,
          use_offsets,
          output_stride,
          input_stride,
          scale_bias_last,
          no_bag,
          isbf16);
    };
  }
}

template <
    typename inType,
    typename indxType,
    typename offsetType,
    typename outType,
    bool THREAD_LOCAL>
typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
    Type
    GenerateEmbeddingSpMDM(
        const int64_t block_size,
        bool has_weight,
        bool normalize_by_lengths,
        int prefetch,
        bool is_weight_positional,
        bool use_offsets,
        bool isbf16) {
  return GenerateEmbeddingSpMDMWithStrides<
      inType,
      indxType,
      offsetType,
      outType,
      THREAD_LOCAL>(
      block_size,
      has_weight,
      normalize_by_lengths,
      prefetch,
      is_weight_positional,
      use_offsets,
      /*output_stride=*/-1,
      /*input_stride=*/-1,
      /*scale_bias_last=*/true,
      /*no_bag=*/false,
      isbf16);
}
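
// Usage sketch (illustrative only; num_bags, num_indices, num_rows, table,
// indices, offsets, and out are hypothetical caller-owned values/buffers):
//
//   auto kernel = fbgemm::GenerateEmbeddingSpMDM<float, int64_t, int32_t>(
//       /*block_size=*/64,
//       /*has_weight=*/false,
//       /*normalize_by_lengths=*/false,
//       /*prefetch=*/16,
//       /*is_weight_positional=*/false,
//       /*use_offsets=*/true,
//       /*isbf16=*/false);
//   bool ok = kernel(
//       /*output_size=*/num_bags,
//       /*index_size=*/num_indices,
//       /*data_size=*/num_rows,
//       table, indices, offsets, /*weights=*/nullptr, out);
//   // ok == false signals an out-of-bound index.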

template <typename indxType, typename offsetType, typename outType>
typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
    Type
    GenerateEmbeddingSpMDMFP8WithStrides(
        const int64_t block_size,
        bool normalize_by_lengths,
        bool is_weight_positional,
        bool use_offsets,
        int64_t output_stride /*=-1*/,
        int64_t input_stride /*=-1*/,
        int exponent_bits,
        int exponent_bias) {
  if (output_stride == -1) {
    output_stride = block_size;
  }
  if (input_stride == -1) {
    input_stride = block_size;
  }
  // There is only the reference implementation for FP8 embedding
  return [=](int64_t output_size,
             int64_t index_size,
             int64_t data_size,
             const uint8_t* input,
             const indxType* indices,
             const offsetType* offsets_or_lengths,
             const float* weights,
             outType* out) {
    return EmbeddingSpMDMFP8_ref(
        block_size,
        output_size,
        index_size,
        data_size,
        input,
        indices,
        offsets_or_lengths,
        weights,
        normalize_by_lengths,
        out,
        is_weight_positional,
        use_offsets,
        output_stride,
        input_stride,
        exponent_bits,
        exponent_bias);
  };
}

template <typename inType, typename indxType, typename offsetType>
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    inType,
    indxType,
    offsetType>::Type
GenerateEmbeddingSpMDMRowWiseSparse(
    const int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch,
    bool is_weight_positional,
    bool use_offsets) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  int64_t input_stride = block_size;
  if (std::is_same<inType, uint8_t>::value) {
    const auto scale_bias_offset = 2 * sizeof(float);
    input_stride = block_size + scale_bias_offset;
  }
  inst_set_t isa = fbgemmInstructionSet();
  if (isZmm(isa)) {
    static GenEmbeddingSpMDMLookup<
        inType,
        indxType,
        offsetType,
        /*outType=*/float,
        inst_set_t::avx512,
        /*rowwise_sparse=*/true>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        block_size,
        has_weight,
        is_weight_positional,
        normalize_by_lengths,
        prefetch,
        use_offsets,
        /*output_stride=*/block_size,
        input_stride,
        /*scale_bias_last=*/true,
        /*isbf16=*/false);
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t uncompressed_data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               float* out,
               const int32_t* compressed_indices_table) {
      return original_func(
          output_size,
          index_size,
          uncompressed_data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          out,
          compressed_indices_table,
          nullptr /* mask not used in avx512 */);
    };
  } else if (isYmm(isa)) {
    static GenEmbeddingSpMDMLookup<
        inType,
        indxType,
        offsetType,
        /*outType=*/float,
        inst_set_t::avx2,
        /*rowwise_sparse=*/true>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        block_size,
        has_weight,
        is_weight_positional,
        normalize_by_lengths,
        prefetch,
        use_offsets,
        /*output_stride=*/block_size,
        input_stride,
        /*scale_bias_last=*/true,
        /*isbf16=*/false);
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t uncompressed_data_size,
               const inType* input,
               const indxType* indices,
               const offsetType* offsets_or_lengths,
               const float* weights,
               float* out,
               const int32_t* compressed_indices_table) {
      return original_func(
          output_size,
          index_size,
          uncompressed_data_size,
          input,
          indices,
          offsets_or_lengths,
          weights,
          out,
          compressed_indices_table,
          internal::avx2_ps_or_epi32_combined_mask);
    };
  } else {
#ifdef VLOG
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
    return
        [=](int64_t output_size,
            int64_t index_size,
            int64_t uncompressed_data_size,
            const inType* input,
            const indxType* indices,
            const offsetType* offsets_or_lengths,
            const float* weights, // optional, can be null for non-weighted sum
            float* out,
            const int32_t* compressed_indices_table) {
          return EmbeddingSpMDMRowWiseSparse_ref(
              block_size,
              output_size,
              index_size,
              uncompressed_data_size,
              // compressed_data_size,
              input,
              indices,
              compressed_indices_table,
              offsets_or_lengths,
              weights,
              normalize_by_lengths,
              out,
              is_weight_positional,
              use_offsets);
        };
  }
}

#define INSTANTIATE_SPMDM_BASE(                                \
    IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL)  \
  template FBGEMM_API typename EmbeddingSpMDMKernelSignature<  \
      IN_TYPE,                                                 \
      INDEX_TYPE,                                              \
      OFFSET_TYPE,                                             \
      OUT_TYPE>::Type                                          \
  GenerateEmbeddingSpMDMWithStrides<                           \
      IN_TYPE,                                                 \
      INDEX_TYPE,                                              \
      OFFSET_TYPE,                                             \
      OUT_TYPE,                                                \
      THREAD_LOCAL>(                                           \
      const int64_t block_size,                                \
      bool has_weight,                                         \
      bool normalize_by_lengths,                               \
      int prefetch,                                            \
      bool is_weight_positional,                               \
      bool use_offsets,                                        \
      int64_t output_stride,                                   \
      int64_t input_stride,                                    \
      bool scale_bias_last,                                    \
      bool no_bag,                                             \
      bool isbf16);

#define INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)       \
  template FBGEMM_API typename EmbeddingSpMDMKernelSignature<              \
      uint8_t,                                                             \
      INDEX_TYPE,                                                          \
      OFFSET_TYPE,                                                         \
      OUT_TYPE>::Type                                                      \
  GenerateEmbeddingSpMDMFP8WithStrides<INDEX_TYPE, OFFSET_TYPE, OUT_TYPE>( \
      const int64_t block_size,                                            \
      bool normalize_by_lengths,                                           \
      bool is_weight_positional,                                           \
      bool use_offsets,                                                    \
      int64_t output_stride,                                               \
      int64_t input_stride,                                                \
      int exponent_bits,                                                   \
      int exponent_bias);

#define INSTANTIATE_SPMDM_NOSTRIDE_BASE(                       \
    IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL)  \
  template FBGEMM_API typename EmbeddingSpMDMKernelSignature<  \
      IN_TYPE,                                                 \
      INDEX_TYPE,                                              \
      OFFSET_TYPE,                                             \
      OUT_TYPE>::Type                                          \
  GenerateEmbeddingSpMDM<                                      \
      IN_TYPE,                                                 \
      INDEX_TYPE,                                              \
      OFFSET_TYPE,                                             \
      OUT_TYPE,                                                \
      THREAD_LOCAL>(                                           \
      const int64_t block_size,                                \
      bool has_weight,                                         \
      bool normalize_by_lengths,                               \
      int prefetch,                                            \
      bool is_weight_positional,                               \
      bool use_offsets,                                        \
      bool isbf16);

#define INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)    \
  template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
      IN_TYPE,                                                              \
      INDEX_TYPE,                                                           \
      OFFSET_TYPE>::Type                                                    \
  GenerateEmbeddingSpMDMRowWiseSparse<IN_TYPE, INDEX_TYPE, OFFSET_TYPE>(    \
      const int64_t block_size,                                             \
      bool has_weight,                                                      \
      bool normalize_by_lengths,                                            \
      int prefetch,                                                         \
      bool is_weight_positional,                                            \
      bool use_offsets);

#define INSTANTIATE_SPMDMFP8_BASE_uint8_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
#define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
#define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)

#define INSTANTIATE_SPMDM_THREAD_LOCAL(                                     \
    IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)                             \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true)  \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
  INSTANTIATE_SPMDM_NOSTRIDE_BASE(                                          \
      IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true)                     \
  INSTANTIATE_SPMDM_NOSTRIDE_BASE(                                          \
      IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false)                    \
  INSTANTIATE_SPMDMFP8_BASE_##IN_TYPE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)

#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)            \
  INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float)    \
  INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \
  INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)

#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int64_t)

#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE)     \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int64_t)

INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(uint16_t)
INSTANTIATE_SPMDM_INDEX_T(uint8_t)

#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_THREAD_LOCAL
#undef INSTANTIATE_SPMDM_BASE
#undef INSTANTIATE_SPMDMFP8_BASE
#undef INSTANTIATE_SPMDM_NOSTRIDE_BASE
#undef INSTANTIATE_SPMDM_ROWWISE_BASE

template <typename IndexType>
void compressed_indices_remap(
    std::int32_t offsets_len,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }

  const inst_set_t isa = fbgemmInstructionSet();
  if (isZmm(isa)) {
#ifndef __HIP_PLATFORM_HCC__
    if (weights == nullptr) {
      internal::compressed_indices_remap_avx512<IndexType, false>(
          offsets_len,
          indices,
          compressed_indices_mapping,
          offsets,
          weights,
          out_indices,
          out_offsets,
          out_weights);
    } else {
      internal::compressed_indices_remap_avx512<IndexType, true>(
          offsets_len,
          indices,
          compressed_indices_mapping,
          offsets,
          weights,
          out_indices,
          out_offsets,
          out_weights);
    }
#endif
  } else {
    compressed_indices_remap_ref<IndexType>(
        offsets_len,
        indices,
        compressed_indices_mapping,
        offsets,
        weights,
        out_indices,
        out_offsets,
        out_weights);
  }
}
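
// Assumed behavior sketch (hypothetical values; see the reference
// implementation for the authoritative semantics): with
//   indices = {3, 7, 1}, offsets = {0, 2, 3},
//   compressed_indices_mapping[1] = 0, [3] = 5, [7] = -1 (pruned),
// the pruned entry is dropped and the remaining indices are rewritten through
// the mapping, giving out_indices = {5, 0} and out_offsets = {0, 1, 2};
// weights, when provided, are compacted the same way into out_weights.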

#define INSTANTIATE_REMAP_BASE(INDEX_TYPE)              \
  template FBGEMM_API void compressed_indices_remap(    \
      std::int32_t offsets_numel,                       \
      const INDEX_TYPE* indices,                        \
      const int32_t* compressed_indices_mapping,        \
      const INDEX_TYPE* offsets,                        \
      const float* weights,                             \
      INDEX_TYPE* out_indices,                          \
      INDEX_TYPE* out_offsets,                          \
      float* out_weights);

INSTANTIATE_REMAP_BASE(int32_t)
INSTANTIATE_REMAP_BASE(int64_t)

#undef INSTANTIATE_REMAP_BASE

} // namespace fbgemm