1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8
9#include "fbgemm/FbgemmEmbedding.h"
10
11#include <asmjit/asmjit.h>
12#include <cpuinfo.h>
13#include <cassert>
14#include <cmath>
15#include <iostream>
16#include <map>
17#include <mutex>
18#include <string>
19#include <tuple>
20#include "./CodeCache.h"
21#include "./MaskAvx2.h"
22#include "./RefImplementations.h"
23#include "fbgemm/SimdUtils.h"
24#include "fbgemm/Types.h"
25
26using namespace std;
27
28namespace fbgemm {
29
30namespace {
31
32template <typename T>
33T ceil_div(T a, T b) {
34 return (a + b - 1) / b;
35}
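// Worked example (illustrative): ceil_div(13, 4) == 4 whereas 13 / 4 == 3.
// Below it is used to compute packed row sizes (bytes per row) and the number
// of vector registers needed per embedding block.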
36
37namespace x86 = asmjit::x86;
38
39template <
40 typename indxType,
41 typename offsetType,
42 typename outType,
43 bool ROWWISE_SPARSE>
44class ReturnFunctionSignature {};
45
46template <typename indxType, typename offsetType, typename outType>
47class ReturnFunctionSignature<indxType, offsetType, outType, false> {
48 public:
49 using jit_embedding_kernel = bool (*)(
50 int64_t output_size,
51 int64_t index_size,
52 int64_t data_size,
53 const uint8_t* input,
54 const indxType* indices,
55 const offsetType* offsets_or_lengths,
56 const float* weights,
57 outType* out,
58 const int* mask);
59};
60
61template <typename indxType, typename offsetType, typename outType>
62class ReturnFunctionSignature<indxType, offsetType, outType, true> {
63 public:
64 using jit_embedding_kernel = bool (*)(
65 int64_t output_size,
66 int64_t index_size,
67 int64_t uncompressed_data_size,
68 // int64_t compressed_data_size,
69 const uint8_t* input,
70 const indxType* indices,
71 const offsetType* offsets_or_lengths,
72 const float* weights,
73 outType* out,
74 const int32_t* compressed_indices_table,
75 const int* mask);
76};
77
78template <
79 typename indxType,
80 typename offsetType,
81 typename outType,
82 inst_set_t instSet,
83 bool ROWWISE_SPARSE = false,
84 bool THREAD_LOCAL = false>
85class GenEmbeddingSpMDMNBitLookup {
86 public:
87 GenEmbeddingSpMDMNBitLookup() {}
88 typename ReturnFunctionSignature<
89 indxType,
90 offsetType,
91 outType,
92 ROWWISE_SPARSE>::jit_embedding_kernel
93 getOrCreate(
94 int bit_rate,
95 int block_size,
96 bool has_weight,
97 bool is_weight_positional,
98 bool normalize_by_lengths,
99 int prefetch,
100 bool use_offsets,
101 int output_stride,
102 int input_stride,
103 bool scale_bias_last);
104
105 private:
106 static asmjit::JitRuntime& runtime() {
107 static asmjit::JitRuntime rt; ///< JIT runtime for asmjit; it depends on
108 // other static variables, so a
109 // function-local static is required to
110 // prevent the static initialization
111 // order fiasco.
111 return rt;
112 }
113
114 static mutex rtMutex_; ///< Controls access to runtime().
115
116 // The cache key depends on bit_rate, embedding dimension (block_size),
117 // weighted sls, positional weights, normalize by lengths, prefetch
118 // distance, use_offsets, output_stride, input_stride, and scale_bias_last.
119 static CodeCache<
120 tuple<int, int, bool, bool, bool, int, bool, int, int, bool>,
121 typename ReturnFunctionSignature<
122 indxType,
123 offsetType,
124 outType,
125 ROWWISE_SPARSE>::jit_embedding_kernel,
126 THREAD_LOCAL>
127 codeCache_; ///< JIT Code Cache for reuse.
128}; // GenEmbeddingSpMDMNBitLookup
129
130template <
131 typename indxType,
132 typename offsetType,
133 typename outType,
134 inst_set_t instSet,
135 bool ROWWISE_SPARSE,
136 bool THREAD_LOCAL>
137mutex GenEmbeddingSpMDMNBitLookup<
138 indxType,
139 offsetType,
140 outType,
141 instSet,
142 ROWWISE_SPARSE,
143 THREAD_LOCAL>::rtMutex_;
144
145template <
146 typename indxType,
147 typename offsetType,
148 typename outType,
149 inst_set_t instSet,
150 bool ROWWISE_SPARSE,
151 bool THREAD_LOCAL>
152CodeCache<
153 tuple<int, int, bool, bool, bool, int, bool, int, int, bool>,
154 typename ReturnFunctionSignature<
155 indxType,
156 offsetType,
157 outType,
158 ROWWISE_SPARSE>::jit_embedding_kernel,
159 THREAD_LOCAL>
160 GenEmbeddingSpMDMNBitLookup<
161 indxType,
162 offsetType,
163 outType,
164 instSet,
165 ROWWISE_SPARSE,
166 THREAD_LOCAL>::codeCache_;
167
168template <
169 typename indxType,
170 typename offsetType,
171 typename outType,
172 inst_set_t instSet,
173 bool ROWWISE_SPARSE,
174 bool THREAD_LOCAL>
175typename ReturnFunctionSignature<
176 indxType,
177 offsetType,
178 outType,
179 ROWWISE_SPARSE>::jit_embedding_kernel
180GenEmbeddingSpMDMNBitLookup<
181 indxType,
182 offsetType,
183 outType,
184 instSet,
185 ROWWISE_SPARSE,
186 THREAD_LOCAL>::
187 getOrCreate(
188 int bit_rate,
189 int block_size,
190 bool has_weight,
191 bool is_weight_positional,
192 bool normalize_by_lengths,
193 int prefetch,
194 bool use_offsets,
195 int output_stride,
196 int input_stride,
197 bool scale_bias_last) {
198 tuple<int, int, bool, bool, bool, int, bool, int, int, bool> kernelSig =
199 make_tuple(
200 bit_rate,
201 block_size,
202 has_weight,
203 is_weight_positional,
204 normalize_by_lengths,
205 prefetch,
206 use_offsets,
207 output_stride,
208 input_stride,
209 scale_bias_last);
210
211 return codeCache_.getOrCreate(
212 kernelSig,
213 [&]() -> typename ReturnFunctionSignature<
214 indxType,
215 offsetType,
216 outType,
217 ROWWISE_SPARSE>::jit_embedding_kernel {
218 // TODO: Make this tunable
219 int pref_dist = prefetch;
220 bool areIndices64b = is_same<indxType, int64_t>::value;
221
222 asmjit::CodeHolder code;
223 code.init(runtime().environment());
224 x86::Assembler assembler(&code);
225 x86::Emitter* a = assembler.as<x86::Emitter>();
226#if defined(FBGEMM_LOG_CODE)
227 string filename = "embeddinglookup_" + to_string(bit_rate) + "bit";
228 filename += "_emd_dim_" + to_string(block_size);
229 filename += areIndices64b ? "_64bit" : "_32bit";
230 filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
231 if (prefetch) {
232 filename += "_prefetch";
233 }
234 if (has_weight) {
235 filename += "_hasweight";
236 }
237 if (normalize_by_lengths) {
238 filename += "_normalize_by_lengths";
239 }
240 if (!use_offsets) {
241 filename += "_use_lengths";
242 }
243 if (ROWWISE_SPARSE) {
244 filename += "_rowwise_sparse";
245 }
246 if (!scale_bias_last) {
247 filename += "_scale_bias_first";
248 }
249 filename += ".txt";
250 FILE* codeLogFile = fopen(filename.c_str(), "w");
251 asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
252 code.setLogger(codeLogger);
253#endif
254 // arguments to the function created
255 x86::Gp output_size = a->zdi();
256 // index_size will be overwritten to hold the end address of indices
257 x86::Gp index_size = a->zsi();
258 x86::Gp data_size = a->zdx();
259 x86::Gp input = a->zcx();
260 int reg_id = 8;
261 x86::Gp indices = a->gpz(reg_id); // 8
262 ++reg_id;
263 x86::Gp lengths = a->gpz(reg_id); // 9
264 ++reg_id;
265 x86::Gp weights = a->gpz(reg_id); // 10
266 ++reg_id;
267 x86::Gp out = a->gpz(reg_id); // 11
268
269 x86::Gp compressed_indices_table;
270 if (ROWWISE_SPARSE) {
271 ++reg_id;
272 compressed_indices_table = a->gpz(reg_id); // 12
273 }
274
275 ++reg_id;
276 x86::Gp scratchReg1_ = a->gpz(reg_id); // 12 or 13
277
278 ++reg_id;
279 x86::Gpd lengths_R_ = a->gpz(reg_id).r32(); // 13 or 14
280
281 ++reg_id;
282 x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15
283 x86::Gp scratchReg3_;
284 if (instSet == inst_set_t::avx2) {
285 scratchReg3_ = a->zax();
286 }
287
288 asmjit::FuncDetail func;
289
290 if (ROWWISE_SPARSE) {
291 func.init(
292 asmjit::FuncSignatureT<
293 bool,
294 int64_t, // output_size
295 int64_t, // index_size
296 int64_t, // uncompressed_data_size
297 const uint8_t*, // input uint8_t or float
298 const indxType*, // indices
299 const offsetType*, // offsets or lengths
300 const float*, // weights
301 float*, // out
302 const int32_t* /* compressed_indices_table */,
303 const int* /* mask */>(asmjit::CallConvId::kHost),
304 a->environment());
305 } else {
306 func.init(
307 asmjit::FuncSignatureT<
308 bool,
309 int64_t, // output_size
310 int64_t, // index_size
311 int64_t, // data_size
312 const uint8_t*, // input uint8_t or float
313 const indxType*, // indices
314 const offsetType*, // offsets or lengths
315 const float*, // weights
316 float*, // out
317 const int* /* mask */>(asmjit::CallConvId::kHost),
318 a->environment());
319 }
320
321 asmjit::FuncFrame frame;
322 frame.init(func);
323
324 frame.setDirtyRegs(
325 asmjit::RegGroup::kVec,
326 asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
327 asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
328 asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
329 asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
330
331 frame.setDirtyRegs(
332 asmjit::RegGroup::kGp,
333 reg_id == 15
334 ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
335 : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
336
337 asmjit::FuncArgsAssignment args(&func);
338 if (ROWWISE_SPARSE) {
339 args.assignAll(
340 output_size,
341 index_size,
342 data_size,
343 input,
344 indices,
345 lengths,
346 weights,
347 out,
348 compressed_indices_table,
349 scratchReg1_);
350 } else {
351 args.assignAll(
352 output_size,
353 index_size,
354 data_size,
355 input,
356 indices,
357 lengths,
358 weights,
359 out,
360 scratchReg1_);
361 }
362
363 args.updateFuncFrame(frame);
364 frame.finalize();
365
366 a->emitProlog(frame);
367 a->emitArgsAssignment(frame, args);
368
369 constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
370 constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
371 int unroll_factor = NUM_VEC_REG;
372
373 typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
374
375 int num_vec_regs_per_block = ceil_div(block_size, vlen);
376 const int remainder = block_size % vlen;
377
378 // Compute a remainder for vector load
379 // Since every row is followed by 2 fp16 values (scale and bias), we
380 // luckily don't need a mask at bit-rate granularity, only at 32-bit
381 // granularity.
382 int num_elem_per_32bit = 32 / bit_rate;
383 // multiply by 4 because we process 4 * vlen elements per iteration
384 int num_of_32bit_per_vload = vlen * 4 / num_elem_per_32bit;
385 int remainder_32bit_granularity =
386 ceil_div(block_size, num_elem_per_32bit) % num_of_32bit_per_vload;
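// Worked example (illustrative, assuming AVX-512 so vlen == 16) with
// bit_rate == 4 and block_size == 100:
//   num_elem_per_32bit          = 32 / 4              = 8
//   num_of_32bit_per_vload      = 16 * 4 / 8          = 8
//   remainder_32bit_granularity = ceil_div(100, 8) % 8 = 13 % 8 = 5
// i.e. the last group of 4 vector registers only has 5 of its 8 32-bit words
// of packed input to load, so that load must be masked.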
387
388 vec_reg_t scale_vreg; // holds scale
389 vec_reg_t bias_vreg; // holds bias
390 vec_reg_t w_vreg; // for weighted sls -- weights
391 vec_reg_t
392 vlen_inv_vreg; // used for normalize by lengths -- 1/ lengths[i]
393 vec_reg_t src_vreg; // for holding embedding value temporarily
394 x86::Ymm mask_vreg; // mask for avx2
395 x86::Xmm mask2_vreg;
396 x86::Xmm mask_fp16_vreg;
397
398 // We need 2 vec registers: one for the scale and one for the bias
399 --unroll_factor;
400 scale_vreg = vec_reg_t(unroll_factor);
401 --unroll_factor;
402 bias_vreg = vec_reg_t(unroll_factor);
403
404 --unroll_factor;
405 src_vreg = vec_reg_t(unroll_factor);
406 // temporary register for bit manipulation instructions
407 --unroll_factor;
408 vec_reg_t temp_vreg = vec_reg_t(unroll_factor);
409 vec_reg_t temp2_vreg;
410
411 --unroll_factor;
412 temp2_vreg = vec_reg_t(unroll_factor);
413
414 // Create a mask that extracts lower bit_rate bits from each 8-bit block
415 --unroll_factor;
416 vec_reg_t extract_mask_vreg = vec_reg_t(unroll_factor);
417 a->lea(
418 x86::rsp,
419 x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
420 if (bit_rate == 4) {
421 a->mov(x86::word_ptr(x86::rsp), 0x0f0f);
422 a->vpbroadcastw(extract_mask_vreg, x86::word_ptr(x86::rsp));
423 } else {
424 a->mov(x86::dword_ptr(x86::rsp), 0x03030303);
425 a->vpbroadcastd(extract_mask_vreg, x86::dword_ptr(x86::rsp));
426 }
427 a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));
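// Illustrative: broadcasting 0x0f0f (bit_rate == 4) gives a per-byte mask of
// 0x0f, and broadcasting 0x03030303 (bit_rate == 2) a per-byte mask of 0x03,
// so the vpand in the main loop below keeps only the low bit_rate bits of
// each byte after the shift/OR unpacking.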
428
429 if (has_weight) {
430 --unroll_factor;
431 w_vreg = vec_reg_t(unroll_factor);
432 }
433
434 if (remainder && instSet == inst_set_t::avx2) {
435 // AVX512 doesn't need a vector register for masking (it uses k registers)
436 --unroll_factor;
437 mask_vreg = x86::ymm(unroll_factor);
438 if (remainder > 1 && std::is_same<outType, float16>::value) {
439 --unroll_factor;
440 mask_fp16_vreg = x86::xmm(unroll_factor);
441 }
442 }
443
444 // Creating a mask for vector load
445 if (remainder_32bit_granularity && instSet == inst_set_t::avx2) {
446 // AVX512 doesn't need a vector register for masking (it uses k registers)
447 --unroll_factor;
448 mask2_vreg = x86::xmm(unroll_factor);
449 }
450
451 if (normalize_by_lengths) {
452 --unroll_factor;
453 vlen_inv_vreg = vec_reg_t(unroll_factor);
454 }
455
456 // Make unroll_factor a multiple of 4
457 unroll_factor = unroll_factor / 4 * 4;
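// Rough register budget (illustrative): on AVX-512 we start from 32 vector
// registers and always reserve 6 (scale, bias, src, temp, temp2, extract
// mask); with has_weight and normalize_by_lengths that leaves
// 32 - 6 - 1 - 1 = 24 accumulators, already a multiple of 4. On AVX2 the
// budget starts at 16 and can shrink further because remainder handling may
// also claim mask_vreg, mask_fp16_vreg, and mask2_vreg. Rounding down to a
// multiple of 4 matches the 4-registers-per-load main loop below.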
458
459 if (remainder) {
460 if (instSet == inst_set_t::avx2) {
461 a->vmovups(
462 mask_vreg,
463 x86::ymmword_ptr(
464 scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t)));
465 if (std::is_same<outType, float16>::value) {
466 if (remainder > 1) {
467 a->vmovups(
468 mask_fp16_vreg,
469 x86::xmmword_ptr(
470 scratchReg1_,
471 (vlen - remainder / 2) * sizeof(int32_t)));
472 }
473 // We need to keep using the stack during the main loop
474 a->lea(
475 x86::rsp,
476 x86::dword_ptr(
477 x86::rsp, static_cast<int32_t>(-vlen * sizeof(int32_t))));
478 }
479 } else {
480 a->mov(scratchReg1_, (1 << remainder) - 1);
481 a->kmovw(x86::k(1), scratchReg1_);
482 }
483 }
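// How the AVX2 remainder mask works (illustrative; scratchReg1_ holds the
// caller-provided mask table, internal::avx2_ps_or_epi32_combined_mask, which
// is assumed to contain all-ones entries followed by zeros): with vlen == 8
// and remainder == 3 the ymm load starts at entry 8 - 3 == 5, so exactly the
// first 3 lanes of mask_vreg are all-ones and the vmaskmovps in the
// write-back loop stores only the 3 valid floats of the last vector.
// mask_fp16_vreg covers remainder / 2 dword lanes of packed fp16 output; an
// odd trailing fp16 element is stored through the stack instead.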
484
485 if (remainder_32bit_granularity) {
486 if (instSet == inst_set_t::avx2) {
487 a->lea(
488 x86::rsp,
489 x86::dword_ptr(
490 x86::rsp, (int32_t)(-(vlen / 2) * sizeof(int32_t))));
491 for (int i = 0; i < remainder_32bit_granularity; i++) {
492 a->mov(x86::dword_ptr(x86::rsp, i * sizeof(int32_t)), -1);
493 }
494 for (int i = remainder_32bit_granularity; i < vlen / 2; i++) {
495 a->mov(x86::dword_ptr(x86::rsp, i * sizeof(int32_t)), 0);
496 }
497 a->vmovups(mask2_vreg, x86::dword_ptr(x86::rsp));
498 a->lea(
499 x86::rsp,
500 x86::dword_ptr(
501 x86::rsp, (int32_t)((vlen / 2) * sizeof(int32_t))));
502 } else {
503 a->mov(scratchReg1_, (1 << remainder_32bit_granularity) - 1);
504 a->kmovw(x86::k(2), scratchReg1_);
505 }
506 }
507
508 // Compute the end address of indices
509 a->lea(
510 index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2));
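// Illustrative: the scale factor (shift of 3 for 64-bit indices, 2 for
// 32-bit) multiplies by sizeof(indxType), so index_size now holds
// indices + index_size * sizeof(indxType), i.e. one past the last valid
// index, which the bound checks below compare against.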
511
512 asmjit::Label exit = a->newLabel();
513 asmjit::Label error = a->newLabel();
514 asmjit::Label LoopRangeIndexBegin = a->newLabel();
515 asmjit::Label LoopRangeIndexEnd = a->newLabel();
516
517 // rangeIndex loop begins (iterate output_size times)
518 a->bind(LoopRangeIndexBegin);
519 a->dec(output_size);
520 a->jl(LoopRangeIndexEnd);
521
522 if (normalize_by_lengths) {
523 asmjit::Label IfLengthsBegin = a->newLabel();
524 asmjit::Label IfLengthsEnd = a->newLabel();
525 a->bind(IfLengthsBegin);
526 if (use_offsets) {
527 a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
528 a->sub(lengths_R_, x86::dword_ptr(lengths));
529 } else {
530 a->mov(lengths_R_, x86::dword_ptr(lengths));
531 }
532 a->cmp(lengths_R_, 1);
533 // Initialize vlen_inv as 0 in case lengths is 0
534 a->vxorps(vlen_inv_vreg, vlen_inv_vreg, vlen_inv_vreg);
535 a->jl(IfLengthsEnd);
536
537 vec_reg_t temp_vreg0(0);
538 if (instSet == inst_set_t::avx2) {
539 a->mov(scratchReg1_, 1);
540 a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_);
541 a->cvtsi2ss(temp_vreg0.xmm(), lengths_R_);
542 a->divss(vlen_inv_vreg.xmm(), temp_vreg0.xmm());
543 a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm());
544 } else {
545 a->mov(scratchReg1_, 1);
546 a->cvtsi2ss(temp_vreg0.xmm(), scratchReg1_);
547 a->vpbroadcastd(vlen_inv_vreg, temp_vreg0.xmm());
548 a->vpbroadcastd(temp_vreg0, lengths_R_);
549 a->vcvtdq2ps(temp_vreg0, temp_vreg0);
550 a->vdivps(vlen_inv_vreg, vlen_inv_vreg, temp_vreg0);
551 }
552 a->bind(IfLengthsEnd);
553 }
554
555 for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
556 vec_idx += unroll_factor) {
557 int cur_unroll_factor =
558 std::min(unroll_factor, num_vec_regs_per_block - vec_idx);
559
560 // Initialize output regs
561 for (int v = 0; v < cur_unroll_factor; ++v) {
562 vec_reg_t out_vreg = vec_reg_t(v);
563 a->vxorps(out_vreg, out_vreg, out_vreg);
564 }
565
566 if (use_offsets) {
567 a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
568 a->sub(lengths_R_, x86::dword_ptr(lengths));
569 } else {
570 a->mov(lengths_R_, x86::dword_ptr(lengths));
571 }
572
573 // Array out of bound check
574 a->lea(
575 scratchReg1_,
576 x86::ptr(indices, lengths_R_, areIndices64b ? 3 : 2));
577 a->cmp(scratchReg1_, index_size);
578 a->jg(error);
579
580 asmjit::Label LoopDataIndexBegin = a->newLabel();
581 asmjit::Label LoopDataIndexEnd = a->newLabel();
582 asmjit::Label ValidIndexLabel = a->newLabel();
583
584 // dataIndex loop begins (iterate lengths_R_ times)
585 a->bind(LoopDataIndexBegin);
586 a->dec(lengths_R_);
587 a->jl(LoopDataIndexEnd);
588
589 // Array out of bound check
590 if (areIndices64b) {
591 a->mov(scratchReg1_, x86::qword_ptr(indices));
592 } else {
593 a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
594 }
595 if (!scale_bias_last) {
596 // When scale_bias_last == false, assume this is for table batched
597 // embedding (TBE) that can get -1 for pruned rows.
598 if (areIndices64b) {
599 a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1));
600 } else {
601 a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
602 }
603 a->jne(ValidIndexLabel);
604 a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));
605 if (has_weight) {
606 a->add(weights, static_cast<asmjit::Imm>(sizeof(float)));
607 }
608 a->jmp(LoopDataIndexBegin);
609 a->bind(ValidIndexLabel);
610 }
611 // A trick to check x >= data_size or x < 0 in one shot by treating
612 // scratchReg1_ as if it held an unsigned value
613 // (https://stackoverflow.com/a/34072155).
614 a->cmp(scratchReg1_, data_size);
615 a->jae(error);
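// Illustrative (64-bit indices): with data_size == 1000, an index of -2
// compares as 0xFFFFFFFFFFFFFFFE when treated as unsigned, so the single jae
// above catches both index >= data_size and index < 0.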
616
617 if (ROWWISE_SPARSE) {
618 a->mov(
619 scratchReg1_.r32(),
620 x86::dword_ptr(
621 compressed_indices_table,
622 scratchReg1_,
623 2)); // use of 2 is to multiply by 4
624 }
625
626 int num_elem_per_byte = 8 / bit_rate;
627 int fused_block_size = input_stride;
628 if (pref_dist) {
629 asmjit::Label pref_dist_reset_start = a->newLabel();
630 asmjit::Label pref_dist_reset_end = a->newLabel();
631 // out of bound handling for prefetch
632 a->lea(
633 scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType)));
634 a->cmp(scratchReg2_, index_size);
635 a->jge(pref_dist_reset_start);
636
637 if (areIndices64b) {
638 a->mov(
639 scratchReg2_,
640 x86::qword_ptr(indices, pref_dist * sizeof(indxType)));
641 } else {
642 a->mov(
643 scratchReg2_.r32(),
644 x86::dword_ptr(indices, pref_dist * sizeof(indxType)));
645 }
646
647 a->jmp(pref_dist_reset_end);
648
649 a->bind(pref_dist_reset_start);
650 // Prefetch index is out of bounds; fall back to the current row.
651 // This could be improved by prefetching the farthest valid row instead.
652 if (areIndices64b) {
653 a->mov(scratchReg2_, x86::qword_ptr(indices));
654 } else {
655 a->mov(scratchReg2_.r32(), x86::dword_ptr(indices));
656 }
657
658 a->bind(pref_dist_reset_end);
659 if (ROWWISE_SPARSE) {
660 asmjit::Label rowwise_sparse_pref_corner_case_begin =
661 a->newLabel();
662 asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel();
663 a->cmp(scratchReg2_, data_size);
664 a->jae(rowwise_sparse_pref_corner_case_begin);
665
666 a->mov(
667 scratchReg2_.r32(),
668 x86::dword_ptr(
669 compressed_indices_table,
670 scratchReg2_,
671 2)); // use of 2 is to multiply by 4
672 a->test(scratchReg2_.r32(), scratchReg2_.r32());
673 // Check negative
674 a->jns(rowwise_sparse_pref_corner_case_end);
675
676 a->bind(rowwise_sparse_pref_corner_case_begin);
677 // For corner case, just set prefetch row id to 0.
678 a->xor_(scratchReg2_.r32(), scratchReg2_.r32());
679 a->bind(rowwise_sparse_pref_corner_case_end);
680 }
681 // This has to be fused_block_size
682 a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size));
683 }
684
685 a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));
686
687 if (has_weight) {
688 a->vbroadcastss(w_vreg, x86::dword_ptr(weights));
689 a->add(weights, static_cast<asmjit::Imm>(sizeof(float)));
690 }
691
692 if (ROWWISE_SPARSE) {
693 a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
694 a->je(LoopDataIndexBegin);
695 }
696
697 a->imul(scratchReg1_, static_cast<asmjit::Imm>(fused_block_size));
698
699 // broadcast the scale
700 x86::Mem scale_src, bias_src;
701 int scale_offset =
702 scale_bias_last ? ceil_div(block_size, num_elem_per_byte) : 0;
703 scale_src = x86::word_ptr(input, scratchReg1_, 0, scale_offset);
704 bias_src = x86::word_ptr(
705 input, scratchReg1_, 0, scale_offset + sizeof(float16));
706 a->vpbroadcastw(scale_vreg.half(), scale_src);
707 a->vpbroadcastw(bias_vreg.half(), bias_src);
708 a->vcvtph2ps(scale_vreg, scale_vreg.half());
709 a->vcvtph2ps(bias_vreg, bias_vreg.half());
710 constexpr unsigned int CACHE_LINE_LEN = 64;
711 if (pref_dist && fused_block_size % CACHE_LINE_LEN > 0 &&
712 fused_block_size % CACHE_LINE_LEN <= 2 * sizeof(float16)) {
713 a->prefetcht0(x86::dword_ptr(
714 input,
715 scratchReg2_,
716 0,
717 fused_block_size / CACHE_LINE_LEN * CACHE_LINE_LEN));
718 }
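// Illustrative: with bit_rate == 4 and block_size == 128 a row is 64 packed
// bytes plus 4 bytes of fp16 scale/bias, so fused_block_size == 68 and
// 68 % 64 == 4 <= 4. The prefetched row's scale/bias then sit alone in a
// second cache line at offset 64, which the per-vload prefetches in the main
// loop would miss, hence the extra prefetcht0 above.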
719
720 if (has_weight) {
721 a->vmulps(scale_vreg, scale_vreg, w_vreg);
722 a->vmulps(bias_vreg, bias_vreg, w_vreg);
723 }
724
725 // The main computation
726 // Handling 4 vector registers per iteration because
727 // 1) when bit_rate == 4, we get zmm from ymm load via vpmovzxbw
728 // (epu8->epi16), and then get 4 zmms from each 128-bit portion of
729 // zmm via vpmovsxbd (epi8->epi32).
730 // 2) when bit_rate == 2, we get zmm from xmm load via vpmovzxbd
731 // (epu8->epi32), and then get 4 zmms from each 128-bit portion of
732 // zmm via vpmovsxbd (epi8->epi32).
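// Worked example of the unpacking below (illustrative): for bit_rate == 2, a
// packed byte 0b11100100 (elements e0=0, e1=1, e2=2, e3=3) occupies the low 8
// bits of a 32-bit lane after vpmovzxbd. Shifting left by 18, 12, and 6 moves
// e3, e2, and e1 into bytes 3, 2, and 1; OR-ing with the original keeps e0 in
// byte 0; AND-ing with 0x03030303 clears the stray high bits, leaving
// 0x03020100, i.e. one element per byte, ready for vpmovsxbd. The
// bit_rate == 4 path does the same with a single shift by 4 and mask 0x0f0f.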
733 int src_addr_offset = scale_bias_last ? 0 : 2 * sizeof(float16);
734 for (int v = 0; v < cur_unroll_factor; v += 4) {
735 int bytes_per_vload = (vlen / num_elem_per_byte) * sizeof(uint8_t);
736 auto src_addr = x86::dword_ptr(
737 input,
738 scratchReg1_,
739 0,
740 src_addr_offset + (vec_idx + v) * bytes_per_vload);
741
742 if (bit_rate == 4) {
743 if (num_vec_regs_per_block - (vec_idx + v) < 4 &&
744 remainder_32bit_granularity) {
745 if (instSet == inst_set_t::avx512) {
746 a->k(x86::k(2)).vmovups(src_vreg.ymm(), src_addr);
747 } else {
748 a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr);
749 }
750 a->vpmovzxbw(src_vreg, src_vreg.half());
751 } else {
752 a->vpmovzxbw(src_vreg, src_addr);
753 }
754 a->vpslld(temp_vreg, src_vreg, asmjit::Imm(4));
755 if (instSet == inst_set_t::avx512) {
756 a->vpord(src_vreg, src_vreg, temp_vreg);
757 a->vpandd(src_vreg, src_vreg, extract_mask_vreg);
758 } else {
759 a->vpor(src_vreg.ymm(), src_vreg.ymm(), temp_vreg.ymm());
760 a->vpand(
761 src_vreg.ymm(), src_vreg.ymm(), extract_mask_vreg.ymm());
762 }
763 } else {
764 if (num_vec_regs_per_block - (vec_idx + v) < 4 &&
765 remainder_32bit_granularity) {
766 if (instSet == inst_set_t::avx512) {
767 a->k(x86::k(2)).vmovups(src_vreg.xmm(), src_addr);
768 a->vpmovzxbd(src_vreg, src_vreg.xmm());
769 } else {
770 a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr);
771 a->vpmovzxbd(src_vreg, src_vreg.xmm());
772 }
773 } else {
774 a->vpmovzxbd(src_vreg, src_addr);
775 }
776 a->vpslld(temp_vreg, src_vreg, 2 * 8 + 2);
777 a->vpslld(temp2_vreg, src_vreg, 8 + 4);
778 if (instSet == inst_set_t::avx512) {
779 a->vpord(temp_vreg, temp_vreg, temp2_vreg);
780 } else {
781 a->vpor(temp_vreg.ymm(), temp_vreg.ymm(), temp2_vreg.ymm());
782 }
783 a->vpslld(temp2_vreg, src_vreg, 6);
784 if (instSet == inst_set_t::avx512) {
785 a->vpord(temp_vreg, temp_vreg, temp2_vreg);
786 a->vpord(src_vreg, temp_vreg, src_vreg);
787 a->vpandd(src_vreg, src_vreg, extract_mask_vreg);
788 } else {
789 a->vpor(temp_vreg.ymm(), temp_vreg.ymm(), temp2_vreg.ymm());
790 a->vpor(src_vreg.ymm(), temp_vreg.ymm(), src_vreg.ymm());
791 a->vpand(
792 src_vreg.ymm(), src_vreg.ymm(), extract_mask_vreg.ymm());
793 }
794 }
795
796 // AVX2: For the following loop, operations on src_vreg impact the
797 // next iteration. For i = 0, we make a copy. i = 1 just right
798 // shifts and uses it. i = 2 we extract upper 128 bits from the copy
799 // to src_vreg and use it. i = 3 just right shifts it and uses it.
800 for (int i = 0;
801 i < std::min(4, num_vec_regs_per_block - (vec_idx + v));
802 ++i) {
803 vec_reg_t out_vreg = vec_reg_t(v + i);
804 if (i == 0) {
805 a->vpmovsxbd(temp_vreg, src_vreg.xmm());
806 // this is only needed for avx2
807 if (instSet == inst_set_t::avx2) {
808 a->vmovups(temp2_vreg, src_vreg);
809 }
810 } else {
811 if (instSet == inst_set_t::avx512) {
812 // We could've used avx512_ymm for a clock frequency advantage if
813 // there were an instruction to extract a 64-bit portion of a ymm
814 // into an xmm register.
815 a->vextracti32x4(temp_vreg.xmm(), src_vreg, asmjit::Imm(i));
816 a->vpmovsxbd(temp_vreg, temp_vreg.xmm());
817 } else {
818 if (i == 1) {
819 a->vpsrldq(src_vreg, src_vreg, asmjit::Imm(8));
820 } else if (i == 2) {
821 a->vextractf128(
822 src_vreg.xmm(), temp2_vreg.ymm(), asmjit::Imm(i >> 1));
823 } else {
824 a->vpsrldq(src_vreg, src_vreg, asmjit::Imm(8));
825 }
826 a->vpmovsxbd(temp_vreg, src_vreg.xmm());
827 } // avx2
828 } // i > 0
829 a->vcvtdq2ps(temp_vreg, temp_vreg);
830 a->vaddps(out_vreg, out_vreg, bias_vreg);
831 a->vfmadd231ps(out_vreg, temp_vreg, scale_vreg);
832 } // for each i
833
834 int vload_per_cache_line = CACHE_LINE_LEN / bytes_per_vload;
835 int v_aligned = ceil_div(vec_idx + v, 4) * 4;
836 if (pref_dist && v_aligned % vload_per_cache_line == 0) {
837 a->prefetcht0(x86::dword_ptr(
838 input, scratchReg2_, 0, v_aligned * bytes_per_vload));
839 }
840 }
841
842 a->jmp(LoopDataIndexBegin);
843 a->bind(LoopDataIndexEnd);
844
845 // This loop writes the accumulated results (out_vreg)
846 // back to memory.
847 for (int v = 0; v < cur_unroll_factor; ++v) {
848 auto dst_addr =
849 x86::dword_ptr(out, (vec_idx + v) * vlen * sizeof(outType));
850 vec_reg_t out_vreg = vec_reg_t(v);
851
852 if (normalize_by_lengths) {
853 a->vmulps(out_vreg, out_vreg, vlen_inv_vreg);
854 }
855
856 if (std::is_same<outType, float>::value) {
857 if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
858 if (instSet == inst_set_t::avx512) {
859 a->k(x86::k(1)).vmovups(dst_addr, out_vreg);
860 } else {
861 a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm());
862 }
863 } else {
864 a->vmovups(dst_addr, out_vreg);
865 }
866 } else {
867 // fp16 output
868 if (instSet == inst_set_t::avx2) {
869 // round nearest with no exception
870 a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);
871 if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
872 if (remainder > 1) {
873 a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm());
874 }
875 if (remainder % 2 != 0) {
876 a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm());
877 a->mov(
878 scratchReg1_.r16(),
879 x86::word_ptr(
880 x86::rsp, (remainder - 1) * sizeof(outType)));
881 a->mov(
882 x86::word_ptr(
883 out,
884 ((vec_idx + v) * vlen + (remainder - 1)) *
885 sizeof(outType)),
886 scratchReg1_.r16());
887 }
888 } else {
889 a->vmovups(dst_addr, out_vreg.xmm());
890 }
891 } else {
892 if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
893 a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8);
894 } else {
895 a->vcvtps2ph(dst_addr, out_vreg, 8);
896 }
897 }
898 }
899 }
900
901 if (vec_idx + unroll_factor < num_vec_regs_per_block ||
902 (has_weight && is_weight_positional)) {
903 // Reset lengths_R_, indices, weights to run the dataIndex loop
904 // again
905 if (use_offsets) {
906 a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType)));
907 a->sub(lengths_R_, x86::dword_ptr(lengths));
908 } else {
909 a->mov(lengths_R_, x86::dword_ptr(lengths));
910 }
911
912 if (has_weight) {
913 a->imul(
914 scratchReg1_,
915 lengths_R_,
916 static_cast<asmjit::Imm>(sizeof(float)));
917 a->sub(weights, scratchReg1_);
918
919 if (vec_idx + unroll_factor < num_vec_regs_per_block) {
920 a->imul(
921 scratchReg1_,
922 static_cast<asmjit::Imm>(sizeof(indxType) / sizeof(float)));
923 a->sub(indices, scratchReg1_);
924 }
925 } else {
926 a->imul(
927 scratchReg1_,
928 lengths_R_,
929 static_cast<asmjit::Imm>(sizeof(indxType)));
930 a->sub(indices, scratchReg1_);
931 }
932 }
933 }
934
935 a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType)));
936 a->add(out, static_cast<asmjit::Imm>(output_stride * sizeof(outType)));
937
938 a->jmp(LoopRangeIndexBegin);
939 a->bind(LoopRangeIndexEnd);
940
941 a->cmp(indices, index_size);
942 a->jne(error);
943 a->mov(x86::eax, true);
944 a->jmp(exit);
945 a->bind(error);
946 a->mov(x86::eax, false);
947 a->bind(exit);
948
949 if (remainder && instSet == inst_set_t::avx2 &&
950 std::is_same<outType, float16>::value) {
951 a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t)));
952 }
953
954 a->emitEpilog(frame);
955
957 typename ReturnFunctionSignature<
958 indxType,
959 offsetType,
960 outType,
961 ROWWISE_SPARSE>::jit_embedding_kernel fn;
962 asmjit::Error err;
963 {
964 unique_lock<mutex> lock(rtMutex_);
965 err = runtime().add(&fn, &code);
966 }
967 if (err) {
968 cout << "Error: in fn add" << endl;
969 return nullptr;
970 }
971
972#if defined(FBGEMM_LOG_CODE)
973 fclose(codeLogFile);
974 delete codeLogger;
975#endif
976 return fn;
977 });
978}
979
980} // namespace
981
982template <
983 typename indxType,
984 typename offsetType,
985 typename outType,
986 bool THREAD_LOCAL>
987typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
988 Type
989 GenerateEmbeddingSpMDMNBitWithStrides(
990 int bit_rate,
991 const int64_t block_size,
992 bool has_weight,
993 bool normalize_by_lengths,
994 int prefetch,
995 bool is_weight_positional,
996 bool use_offsets,
997 int64_t output_stride /*=-1*/,
998 int64_t input_stride /*=-1*/,
999 bool scale_bias_last /*=true*/) {
1000 assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1001
1002 if (!cpuinfo_initialize()) {
1003 throw runtime_error("Failed to initialize cpuinfo!");
1004 }
1005 if (output_stride == -1) {
1006 output_stride = block_size;
1007 }
1008 if (input_stride == -1) {
1009 int64_t num_elem_per_byte = 8 / bit_rate;
1010 input_stride =
1011 ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(float16);
1012 }
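// Illustrative: with bit_rate == 4 and block_size == 100 a row is
// ceil_div(100, 2) == 50 packed bytes plus 2 * sizeof(float16) == 4 bytes of
// scale and bias, so the default input_stride is 54 bytes.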
1013 if (fbgemmHasAvx512Support()) {
1014 static GenEmbeddingSpMDMNBitLookup<
1015 indxType,
1016 offsetType,
1017 outType,
1018 inst_set_t::avx512,
1019 /*ROWWISE_SPARSE=*/false,
1020 THREAD_LOCAL>
1021 kernel_generator;
1022 const auto original_func = kernel_generator.getOrCreate(
1023 bit_rate,
1024 block_size,
1025 has_weight,
1026 is_weight_positional,
1027 normalize_by_lengths,
1028 prefetch,
1029 use_offsets,
1030 output_stride,
1031 input_stride,
1032 scale_bias_last);
1033 return [=](int64_t output_size,
1034 int64_t index_size,
1035 int64_t data_size,
1036 const uint8_t* input,
1037 const indxType* indices,
1038 const offsetType* offsets_or_lengths,
1039 const float* weights,
1040 outType* out) {
1041 return original_func(
1042 output_size,
1043 index_size,
1044 data_size,
1045 input,
1046 indices,
1047 offsets_or_lengths,
1048 weights,
1049 out,
1050 nullptr /* mask not used in avx512 */);
1051 };
1052 } else if (fbgemmHasAvx2Support()) {
1053 static GenEmbeddingSpMDMNBitLookup<
1054 indxType,
1055 offsetType,
1056 outType,
1057 inst_set_t::avx2,
1058 /*ROWWISE_SPARSE=*/false,
1059 THREAD_LOCAL>
1060 kernel_generator;
1061 const auto original_func = kernel_generator.getOrCreate(
1062 bit_rate,
1063 block_size,
1064 has_weight,
1065 is_weight_positional,
1066 normalize_by_lengths,
1067 prefetch,
1068 use_offsets,
1069 output_stride,
1070 input_stride,
1071 scale_bias_last);
1072 return [=](int64_t output_size,
1073 int64_t index_size,
1074 int64_t data_size,
1075 const uint8_t* input,
1076 const indxType* indices,
1077 const offsetType* offsets_or_lengths,
1078 const float* weights,
1079 outType* out) {
1080 return original_func(
1081 output_size,
1082 index_size,
1083 data_size,
1084 input,
1085 indices,
1086 offsets_or_lengths,
1087 weights,
1088 out,
1089 internal::avx2_ps_or_epi32_combined_mask);
1090 };
1091 } else {
1092#ifdef VLOG
1093 VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1094#endif
1095 return [=](int64_t output_size,
1096 int64_t index_size,
1097 int64_t data_size,
1098 const uint8_t* input,
1099 const indxType* indices,
1100 const offsetType* offsets_or_lengths,
1101 const float* weights,
1102 outType* out) {
1103 return EmbeddingSpMDMNBit_ref(
1104 bit_rate,
1105 block_size,
1106 output_size,
1107 index_size,
1108 data_size,
1109 input,
1110 indices,
1111 offsets_or_lengths,
1112 weights,
1113 normalize_by_lengths,
1114 out,
1115 is_weight_positional,
1116 use_offsets,
1117 output_stride,
1118 input_stride,
1119 scale_bias_last);
1120 };
1121 }
1122}
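// A minimal usage sketch (illustrative only; the buffer names are
// hypothetical, and the fused table layout is assumed to be rows of packed
// n-bit data followed by an fp16 scale and bias):
//
//   auto kernel =
//       GenerateEmbeddingSpMDMNBitWithStrides<int64_t, int32_t, float>(
//           /*bit_rate=*/4,
//           /*block_size=*/embedding_dim,
//           /*has_weight=*/false,
//           /*normalize_by_lengths=*/false,
//           /*prefetch=*/16,
//           /*is_weight_positional=*/false,
//           /*use_offsets=*/true);
//   bool ok = kernel(
//       batch_size,      // output_size
//       total_indices,   // index_size
//       num_rows,        // data_size
//       fused_table,     // const uint8_t*
//       indices,         // const int64_t*
//       offsets,         // const int32_t*, batch_size + 1 entries
//       /*weights=*/nullptr,
//       output);         // float*, batch_size rows of output_stride floats
//   // ok is false if any index was out of bounds.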
1123
1124template <typename IndexType, typename OffsetType, typename OutType>
1125FBGEMM_API typename EmbeddingSpMDMKernelSignature<
1126 std::uint8_t,
1127 IndexType,
1128 OffsetType,
1129 OutType>::Type
1130GenerateEmbeddingSpMDMNBit(
1131 int bit_rate,
1132 const std::int64_t block_size,
1133 bool has_weight,
1134 bool normalize_by_lengths,
1135 int prefetch,
1136 bool is_weight_positional,
1137 bool use_offsets) {
1138 return GenerateEmbeddingSpMDMNBitWithStrides<IndexType, OffsetType, OutType>(
1139 bit_rate,
1140 block_size,
1141 has_weight,
1142 normalize_by_lengths,
1143 prefetch,
1144 is_weight_positional,
1145 use_offsets);
1146}
1147
1148template <typename indxType, typename offsetType>
1149typename EmbeddingSpMDMRowWiseSparseKernelSignature<
1150 uint8_t,
1151 indxType,
1152 offsetType>::Type
1153GenerateEmbeddingSpMDMNBitRowWiseSparse(
1154 int bit_rate,
1155 const int64_t block_size,
1156 bool has_weight,
1157 bool normalize_by_lengths,
1158 int prefetch,
1159 bool is_weight_positional,
1160 bool use_offsets) {
1161 assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1162
1163 if (!cpuinfo_initialize()) {
1164 throw runtime_error("Failed to initialize cpuinfo!");
1165 }
1166 int64_t num_elem_per_byte = 8 / bit_rate;
1167 int64_t input_stride =
1168 ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(float16);
1169 if (fbgemmHasAvx512Support()) {
1170 static GenEmbeddingSpMDMNBitLookup<
1171 indxType,
1172 offsetType,
1173 /*outType=*/float,
1174 inst_set_t::avx512,
1175 /*rowwise_sparse=*/true>
1176 kernel_generator;
1177 const auto original_func = kernel_generator.getOrCreate(
1178 bit_rate,
1179 block_size,
1180 has_weight,
1181 is_weight_positional,
1182 normalize_by_lengths,
1183 prefetch,
1184 use_offsets,
1185 /*output_stride=*/block_size,
1186 input_stride,
1187 /*scale_bias_last=*/true);
1188 return [=](int64_t output_size,
1189 int64_t index_size,
1190 int64_t uncompressed_data_size,
1191 const uint8_t* input,
1192 const indxType* indices,
1193 const offsetType* offsets_or_lengths,
1194 const float* weights,
1195 float* out,
1196 const int32_t* compressed_indices_table) {
1197 return original_func(
1198 output_size,
1199 index_size,
1200 uncompressed_data_size,
1201 input,
1202 indices,
1203 offsets_or_lengths,
1204 weights,
1205 out,
1206 compressed_indices_table,
1207 nullptr /* mask not used in avx512 */);
1208 };
1209 } else if (fbgemmHasAvx2Support()) {
1210 static GenEmbeddingSpMDMNBitLookup<
1211 indxType,
1212 offsetType,
1213 /*outType=*/float,
1214 inst_set_t::avx2,
1215 /*rowwise_sparse=*/true>
1216 kernel_generator;
1217 const auto original_func = kernel_generator.getOrCreate(
1218 bit_rate,
1219 block_size,
1220 has_weight,
1221 is_weight_positional,
1222 normalize_by_lengths,
1223 prefetch,
1224 use_offsets,
1225 /*output_stride=*/block_size,
1226 input_stride,
1227 /*scale_bias_last=*/true);
1228 return [=](int64_t output_size,
1229 int64_t index_size,
1230 int64_t uncompressed_data_size,
1231 const uint8_t* input,
1232 const indxType* indices,
1233 const offsetType* offsets_or_lengths,
1234 const float* weights,
1235 float* out,
1236 const int32_t* compressed_indices_table) {
1237 return original_func(
1238 output_size,
1239 index_size,
1240 uncompressed_data_size,
1241 input,
1242 indices,
1243 offsets_or_lengths,
1244 weights,
1245 out,
1246 compressed_indices_table,
1247 internal::avx2_ps_or_epi32_combined_mask);
1248 };
1249 } else {
1250#ifdef VLOG
1251 VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1252#endif
1253 return [=](int64_t output_size,
1254 int64_t index_size,
1255 int64_t uncompressed_data_size,
1256 const uint8_t* input,
1257 const indxType* indices,
1258 const offsetType* offsets_or_lengths,
1259 const float* weights,
1260 float* out,
1261 const int32_t* compressed_indices_table) {
1262 return EmbeddingSpMDMNBitRowWiseSparse_ref(
1263 bit_rate,
1264 block_size,
1265 output_size,
1266 index_size,
1267 uncompressed_data_size,
1268 // compressed_data_size,
1269 input,
1270 indices,
1271 compressed_indices_table,
1272 offsets_or_lengths,
1273 weights,
1274 normalize_by_lengths,
1275 out,
1276 is_weight_positional,
1277 use_offsets);
1278 };
1279 }
1280}
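// Illustrative semantics of compressed_indices_table (hypothetical values):
// for an uncompressed table of 6 rows where rows 1 and 4 were pruned, the
// caller would pass {0, -1, 1, 2, -1, 3}. The generated kernel maps each
// index through this table, skips entries that map to -1, and otherwise
// gathers from the compressed row.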
1281
1282#define INSTANTIATE_SPMDM_BASE( \
1283 INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \
1284 template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \
1285 uint8_t, \
1286 INDEX_TYPE, \
1287 OFFSET_TYPE, \
1288 OUT_TYPE>::Type \
1289 GenerateEmbeddingSpMDMNBitWithStrides< \
1290 INDEX_TYPE, \
1291 OFFSET_TYPE, \
1292 OUT_TYPE, \
1293 THREAD_LOCAL>( \
1294 int bit_rate, \
1295 const int64_t block_size, \
1296 bool has_weight, \
1297 bool normalize_by_lengths, \
1298 int prefetch, \
1299 bool is_weight_positional, \
1300 bool use_offsets, \
1301 int64_t output_stride, \
1302 int64_t input_stride, \
1303 bool scale_bias_last);
1304
1305#define INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
1306 INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
1307 INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \
1308 template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \
1309 uint8_t, \
1310 INDEX_TYPE, \
1311 OFFSET_TYPE, \
1312 OUT_TYPE>::Type \
1313 GenerateEmbeddingSpMDMNBit<INDEX_TYPE, OFFSET_TYPE, OUT_TYPE>( \
1314 int bit_rate, \
1315 const int64_t block_size, \
1316 bool has_weight, \
1317 bool normalize_by_lengths, \
1318 int prefetch, \
1319 bool is_weight_positional, \
1320 bool use_offsets);
1321
1322#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
1323 INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \
1324 INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float16) \
1325 template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
1326 uint8_t, \
1327 INDEX_TYPE, \
1328 OFFSET_TYPE>::Type \
1329 GenerateEmbeddingSpMDMNBitRowWiseSparse<INDEX_TYPE, OFFSET_TYPE>( \
1330 int bit_rate, \
1331 const int64_t block_size, \
1332 bool has_weight, \
1333 bool normalize_by_lengths, \
1334 int prefetch, \
1335 bool is_weight_positional, \
1336 bool use_offsets);
1337
1338#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
1339 INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
1340 INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t)
1341
1342INSTANTIATE_SPMDM_OFFSET_T(int32_t)
1343INSTANTIATE_SPMDM_OFFSET_T(int64_t)
1344
1345#undef INSTANTIATE_SPMDM_OFFSET_T
1346#undef INSTANTIATE_SPMDM_OUT_T
1347#undef INSTANTIATE_SPMDM_THREAD_LOCAL
1348#undef INSTANTIATE_SPMDM_BASE
1349
1350} // namespace fbgemm
1351