1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | |
9 | #include "fbgemm/FbgemmEmbedding.h" |
10 | |
11 | #include <asmjit/asmjit.h> |
12 | #include <cpuinfo.h> |
13 | #include <cassert> |
14 | #include <cmath> |
15 | #include <iostream> |
16 | #include <map> |
17 | #include <mutex> |
18 | #include <string> |
19 | #include <tuple> |
20 | #include "./CodeCache.h" |
21 | #include "./MaskAvx2.h" |
22 | #include "./RefImplementations.h" |
23 | #include "fbgemm/FbgemmConvert.h" |
24 | #include "fbgemm/SimdUtils.h" |
25 | |
26 | namespace fbgemm { |
27 | |
28 | namespace { |
29 | |
30 | namespace x86 = asmjit::x86; |
31 | |
32 | template < |
33 | typename inType, |
34 | typename indxType, |
35 | typename offsetType, |
36 | typename outType, |
37 | bool ROWWISE_SPARSE> |
38 | class ReturnFunctionSignature {}; |
39 | |
40 | template < |
41 | typename inType, |
42 | typename indxType, |
43 | typename offsetType, |
44 | typename outType> |
45 | class ReturnFunctionSignature<inType, indxType, offsetType, outType, false> { |
46 | public: |
47 | using jit_embedding_kernel = bool (*)( |
48 | int64_t output_size, |
49 | int64_t index_size, |
50 | int64_t data_size, |
51 | const inType* input, |
52 | const indxType* indices, |
53 | const offsetType* offsets_or_lengths, |
54 | const float* weights, |
55 | outType* out, |
56 | const int* mask); |
57 | }; |
58 | |
59 | template < |
60 | typename inType, |
61 | typename indxType, |
62 | typename offsetType, |
63 | typename outType> |
64 | class ReturnFunctionSignature<inType, indxType, offsetType, outType, true> { |
65 | public: |
66 | using jit_embedding_kernel = bool (*)( |
67 | int64_t output_size, |
68 | int64_t index_size, |
69 | int64_t uncompressed_data_size, |
70 | // int64_t compressed_data_size, |
71 | const inType* input, |
72 | const indxType* indices, |
73 | const offsetType* offsets_or_lengths, |
74 | const float* weights, |
75 | outType* out, |
76 | const int32_t* compressed_indices_table, |
77 | const int* mask); |
78 | }; |
79 | |
80 | template < |
81 | typename inType, |
82 | typename indxType, |
83 | typename offsetType, |
84 | typename outType, |
85 | inst_set_t instSet, |
86 | bool ROWWISE_SPARSE = false, |
87 | bool THREAD_LOCAL = false> |
88 | class GenEmbeddingSpMDMLookup { |
89 | public: |
90 | GenEmbeddingSpMDMLookup() {} |
91 | typename ReturnFunctionSignature< |
92 | inType, |
93 | indxType, |
94 | offsetType, |
95 | outType, |
96 | ROWWISE_SPARSE>::jit_embedding_kernel |
97 | getOrCreate( |
98 | int block_size, |
99 | bool has_weight, |
100 | bool is_weight_positional, |
101 | bool normalize_by_lengths, |
102 | int prefetch, |
103 | bool use_offsets, |
104 | int output_stride, |
105 | int input_stride, |
106 | bool scale_bias_last, |
107 | bool isbf16); |
108 | |
109 | private: |
110 | static asmjit::JitRuntime& runtime() { |
static asmjit::JitRuntime rt; //< JIT runtime for asmjit;
                              // depends on other static
                              // variables. Required to prevent
                              // the initialization order fiasco.
115 | return rt; |
116 | } |
117 | |
static std::mutex rtMutex_; ///< Controls access to runtime().
119 | |
// The hash depends on embedding dimension (block size), weighted sls,
// positional weights, normalize by lengths, prefetch distance, use_offsets,
// output_stride, input_stride, scale_bias_last, and isbf16
123 | static CodeCache< |
124 | std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool>, |
125 | typename ReturnFunctionSignature< |
126 | inType, |
127 | indxType, |
128 | offsetType, |
129 | outType, |
130 | ROWWISE_SPARSE>::jit_embedding_kernel, |
131 | THREAD_LOCAL> |
132 | codeCache_; ///< JIT Code Cache for reuse. |
133 | }; // GenEmbeddingSpmDMLookup |
134 | |
135 | template < |
136 | typename inType, |
137 | typename indxType, |
138 | typename offsetType, |
139 | typename outType, |
140 | inst_set_t instSet, |
141 | bool ROWWISE_SPARSE, |
142 | bool THREAD_LOCAL> |
143 | std::mutex GenEmbeddingSpMDMLookup< |
144 | inType, |
145 | indxType, |
146 | offsetType, |
147 | outType, |
148 | instSet, |
149 | ROWWISE_SPARSE, |
150 | THREAD_LOCAL>::rtMutex_; |
151 | |
152 | template < |
153 | typename inType, |
154 | typename indxType, |
155 | typename offsetType, |
156 | typename outType, |
157 | inst_set_t instSet, |
158 | bool ROWWISE_SPARSE, |
159 | bool THREAD_LOCAL> |
160 | CodeCache< |
161 | std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool>, |
162 | typename ReturnFunctionSignature< |
163 | inType, |
164 | indxType, |
165 | offsetType, |
166 | outType, |
167 | ROWWISE_SPARSE>::jit_embedding_kernel, |
168 | THREAD_LOCAL> |
169 | GenEmbeddingSpMDMLookup< |
170 | inType, |
171 | indxType, |
172 | offsetType, |
173 | outType, |
174 | instSet, |
175 | ROWWISE_SPARSE, |
176 | THREAD_LOCAL>::codeCache_; |
177 | |
178 | template < |
179 | typename inType, |
180 | typename indxType, |
181 | typename offsetType, |
182 | typename outType, |
183 | inst_set_t instSet, |
184 | bool ROWWISE_SPARSE, |
185 | bool THREAD_LOCAL> |
186 | typename ReturnFunctionSignature< |
187 | inType, |
188 | indxType, |
189 | offsetType, |
190 | outType, |
191 | ROWWISE_SPARSE>::jit_embedding_kernel |
192 | GenEmbeddingSpMDMLookup< |
193 | inType, |
194 | indxType, |
195 | offsetType, |
196 | outType, |
197 | instSet, |
198 | ROWWISE_SPARSE, |
199 | THREAD_LOCAL>:: |
200 | getOrCreate( |
201 | int block_size, |
202 | bool has_weight, |
203 | bool is_weight_positional, |
204 | bool normalize_by_lengths, |
205 | int prefetch, |
206 | bool use_offsets, |
207 | int output_stride, |
208 | int input_stride, |
209 | bool scale_bias_last, |
210 | bool isbf16) { |
211 | std::tuple<int, bool, bool, bool, int, bool, int, int, bool, bool> kernelSig = |
212 | std::make_tuple( |
213 | block_size, |
214 | has_weight, |
215 | is_weight_positional, |
216 | normalize_by_lengths, |
217 | prefetch, |
218 | use_offsets, |
219 | output_stride, |
220 | input_stride, |
221 | scale_bias_last, |
222 | isbf16); |
223 | |
224 | return codeCache_.getOrCreate( |
225 | kernelSig, |
226 | [&]() -> typename ReturnFunctionSignature< |
227 | inType, |
228 | indxType, |
229 | offsetType, |
230 | outType, |
231 | ROWWISE_SPARSE>::jit_embedding_kernel { |
232 | bool is8bit = std::is_same<inType, uint8_t>::value; |
233 | bool is16bit = std::is_same<inType, uint16_t>::value; |
234 | bool is16bitout = std::is_same<outType, uint16_t>::value; |
235 | bool isbf16out = isbf16; |
236 | bool isfp16 = is16bit && !isbf16; |
237 | bool isfp16out = is16bitout && !isbf16out; |
238 | |
239 | // TODO: Make this tunable |
240 | int pref_dist = prefetch; |
241 | bool areIndices64b = std::is_same<indxType, int64_t>::value; |
242 | |
243 | asmjit::CodeHolder code; |
244 | code.init(runtime().environment()); |
245 | x86::Assembler assembler(&code); |
246 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
247 | #if defined(FBGEMM_LOG_CODE) |
std::string filename = "embeddinglookup";
if (is8bit) {
  filename += "_8bit";
} else if (isfp16) {
  filename += "_fp16";
} else if (isbf16) {
  filename += "_bf16";
}
if (isbf16out) {
  filename += "_bf16_out";
} else if (isfp16out) {
  filename += "_fp16_out";
}
filename += "_emb_dim_" + std::to_string(block_size);
filename += areIndices64b ? "_64bit" : "_32bit";
filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
if (prefetch) {
  filename += "_prefetch";
}
if (has_weight) {
  filename += "_hasweight";
}
if (normalize_by_lengths) {
  filename += "_normalize_by_lengths";
}
if (!use_offsets) {
  filename += "_use_lengths";
}
if (ROWWISE_SPARSE) {
  filename += "_rowwise_sparse";
}
filename += "_out_stride_" + std::to_string(output_stride);
if (!scale_bias_last) {
  filename += "_scale_bias_first";
}
filename += ".txt";
FILE* codeLogFile = fopen(filename.c_str(), "w");
285 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile); |
286 | code.setLogger(codeLogger); |
287 | #endif |
288 | // arguments to the function created |
289 | x86::Gp output_size = a->zdi(); |
290 | // index_size will be overwritten to hold the end address of indices |
291 | x86::Gp index_size = a->zsi(); |
292 | x86::Gp data_size = a->zdx(); |
293 | x86::Gp input = a->zcx(); |
294 | int reg_id = 8; |
295 | x86::Gp indices = a->gpz(reg_id); // 8 |
296 | ++reg_id; |
297 | x86::Gp lengths = a->gpz(reg_id); // 9 |
298 | ++reg_id; |
299 | x86::Gp weights = a->gpz(reg_id); // 10 |
300 | ++reg_id; |
301 | x86::Gp out = a->gpz(reg_id); // 11 |
302 | |
303 | x86::Gp compressed_indices_table; |
304 | if (ROWWISE_SPARSE) { |
305 | ++reg_id; |
306 | compressed_indices_table = a->gpz(reg_id); // 12 |
307 | } |
308 | ++reg_id; |
309 | x86::Gp scratchReg1_ = a->gpz(reg_id); // 12 or 13, also for mask |
310 | |
311 | ++reg_id; |
312 | x86::Gpd lengths_R_ = a->gpz(reg_id).r32(); // 13 or 14 |
313 | ++reg_id; |
314 | x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15 |
315 | |
316 | asmjit::FuncDetail func; |
317 | |
318 | if (ROWWISE_SPARSE) { |
319 | func.init( |
320 | asmjit::FuncSignatureT< |
321 | bool, |
322 | int64_t, // output_size |
323 | int64_t, // index_size |
324 | int64_t, // uncompressed_data_size |
325 | const inType*, // input uint8_t or float |
326 | const indxType*, // indices |
327 | const offsetType*, // offsets or lengths |
328 | const float*, // weights |
329 | outType*, // out |
330 | const int32_t*, // compressed_indices_table and then mask |
331 | const int*>(asmjit::CallConvId::kHost), |
332 | a->environment()); |
333 | } else { |
334 | func.init( |
335 | asmjit::FuncSignatureT< |
336 | bool, |
337 | int64_t, // output_size |
338 | int64_t, // index_size |
339 | int64_t, // data_size |
340 | const inType*, // input uint8_t or float |
341 | const indxType*, // indices |
342 | const offsetType*, // offsets or lengths |
343 | const float*, // weights |
344 | outType*, // out and then mask |
345 | const int*>(asmjit::CallConvId::kHost), |
346 | a->environment()); |
347 | } |
348 | |
349 | asmjit::FuncFrame frame; |
350 | frame.init(func); |
351 | |
352 | if (instSet == inst_set_t::avx2) { |
353 | frame.setDirtyRegs( |
354 | asmjit::RegGroup::kVec, |
355 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
356 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); |
357 | } else { |
358 | frame.setDirtyRegs( |
359 | asmjit::RegGroup::kVec, |
360 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
361 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | |
362 | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
363 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); |
364 | } |
365 | |
366 | frame.setDirtyRegs( |
367 | asmjit::RegGroup::kGp, |
368 | reg_id == 15 |
369 | ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
370 | : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); |
371 | |
372 | asmjit::FuncArgsAssignment args(&func); |
373 | if (ROWWISE_SPARSE) { |
374 | args.assignAll( |
375 | output_size, |
376 | index_size, |
377 | data_size, |
378 | input, |
379 | indices, |
380 | lengths, |
381 | weights, |
382 | out, |
383 | compressed_indices_table, |
384 | scratchReg1_); |
385 | } else { |
386 | args.assignAll( |
387 | output_size, |
388 | index_size, |
389 | data_size, |
390 | input, |
391 | indices, |
392 | lengths, |
393 | weights, |
394 | out, |
395 | scratchReg1_); |
396 | } |
397 | |
398 | args.updateFuncFrame(frame); |
399 | frame.finalize(); |
400 | |
401 | a->emitProlog(frame); |
402 | a->emitArgsAssignment(frame, args); |
403 | |
404 | constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS; |
405 | constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS; |
406 | int unroll_factor = NUM_VEC_REG; |
407 | |
408 | typedef typename simd_info<instSet>::vec_reg_t vec_reg_t; |
409 | |
410 | int num_vec_regs_per_block = (block_size + vlen - 1) / vlen; |
411 | int remainder = block_size % vlen; |
412 | |
413 | vec_reg_t scale_vreg; // holds scale |
414 | vec_reg_t bias_vreg; // holds bias |
415 | vec_reg_t w_vreg; // for weighted sls -- weights |
vec_reg_t
    vlen_inv_vreg; // for normalize_by_lengths -- holds 1 / lengths[i]
418 | vec_reg_t src_vreg; // for holding embedding value temporarily |
419 | x86::Ymm mask_vreg; // mask for avx2 |
420 | x86::Xmm mask_fp16_vreg; // mask for loading fp16 in avx2 |
vec_reg_t ones_vreg; // rounding bias 2^15 for fp32 -> bf16 conversion
422 | |
423 | if (is8bit) { |
// We need 2 vec registers: one for the scale, one for the bias
425 | --unroll_factor; |
426 | scale_vreg = vec_reg_t(unroll_factor); |
427 | --unroll_factor; |
428 | bias_vreg = vec_reg_t(unroll_factor); |
429 | } |
430 | |
431 | if (isbf16out) { |
432 | --unroll_factor; |
433 | ones_vreg = vec_reg_t(unroll_factor); |
434 | a->mov(scratchReg2_, 1 << 15); |
435 | a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0); |
436 | a->vpbroadcastd(ones_vreg, ones_vreg.xmm()); |
437 | } |
438 | |
439 | if (is8bit || is16bit || (remainder && instSet == inst_set_t::avx2)) { |
440 | --unroll_factor; |
441 | src_vreg = vec_reg_t(unroll_factor); |
442 | } |
443 | |
444 | if (has_weight) { |
445 | --unroll_factor; |
446 | w_vreg = vec_reg_t(unroll_factor); |
447 | } |
448 | |
449 | if (remainder && instSet == inst_set_t::avx2) { |
450 | // AVX512 doesn't need to use vector register for masking |
451 | --unroll_factor; |
452 | mask_vreg = x86::ymm(unroll_factor); |
453 | if (remainder > 1 && (is16bit || isbf16out || isfp16out)) { |
454 | --unroll_factor; |
455 | mask_fp16_vreg = x86::xmm(unroll_factor); |
456 | } |
457 | } |
458 | |
459 | if (normalize_by_lengths) { |
460 | --unroll_factor; |
461 | vlen_inv_vreg = vec_reg_t(unroll_factor); |
462 | } |
463 | |
464 | if (remainder) { |
465 | if (instSet == inst_set_t::avx2) { |
466 | a->vmovups( |
467 | mask_vreg, |
468 | x86::ymmword_ptr( |
469 | scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); |
470 | if (is16bit || isbf16out || isfp16out) { |
471 | if (remainder > 1) { |
472 | a->vmovups( |
473 | mask_fp16_vreg, |
474 | x86::xmmword_ptr( |
475 | scratchReg1_, |
476 | (vlen - remainder / 2) * sizeof(int32_t))); |
477 | } |
// Reserve stack space; the main loop spills and reloads the
// 16-bit remainder elements through this scratch area
479 | a->lea( |
480 | x86::rsp, |
481 | x86::dword_ptr( |
482 | x86::rsp, static_cast<int32_t>(-vlen * sizeof(int32_t)))); |
483 | } |
484 | } else { |
485 | a->mov(scratchReg1_, (1 << remainder) - 1); |
486 | a->kmovw(x86::k(1), scratchReg1_); |
487 | } |
488 | } |
489 | |
490 | // Compute the end address of indices |
491 | a->lea( |
492 | index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2)); |
493 | |
494 | asmjit::Label exit = a->newLabel(); |
495 | asmjit::Label error = a->newLabel(); |
496 | asmjit::Label LoopRangeIndexBegin = a->newLabel(); |
497 | asmjit::Label LoopRangeIndexEnd = a->newLabel(); |
498 | |
499 | // rangeIndex loop begins (iterate output_size times) |
500 | a->bind(LoopRangeIndexBegin); |
501 | a->dec(output_size); |
502 | a->jl(LoopRangeIndexEnd); |
503 | |
504 | if (normalize_by_lengths) { |
505 | asmjit::Label IfLengthsBegin = a->newLabel(); |
506 | asmjit::Label IfLengthsEnd = a->newLabel(); |
507 | a->bind(IfLengthsBegin); |
508 | if (use_offsets) { |
509 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
510 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
511 | } else { |
512 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
513 | } |
514 | a->cmp(lengths_R_, 1); |
515 | // Initialize vlen_inv as 0 in case lengths is 0 |
516 | a->vxorps(vlen_inv_vreg, vlen_inv_vreg, vlen_inv_vreg); |
517 | a->jl(IfLengthsEnd); |
518 | |
// OK to clobber vreg0 here: it is only used as out_vreg in the main loop
520 | vec_reg_t temp_vreg(0); |
521 | if (instSet == inst_set_t::avx2) { |
522 | a->mov(scratchReg1_, 1); |
523 | a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_); |
524 | a->cvtsi2ss(temp_vreg.xmm(), lengths_R_); |
525 | a->divss(vlen_inv_vreg.xmm(), temp_vreg.xmm()); |
526 | a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm()); |
527 | } else { // avx512 |
528 | a->mov(scratchReg1_, 1); |
529 | a->cvtsi2ss(temp_vreg.xmm(), scratchReg1_); |
530 | a->vpbroadcastd(vlen_inv_vreg, temp_vreg.xmm()); |
531 | a->vpbroadcastd(temp_vreg, lengths_R_); |
532 | a->vcvtdq2ps(temp_vreg, temp_vreg); |
533 | a->vdivps(vlen_inv_vreg, vlen_inv_vreg, temp_vreg); |
534 | } |
535 | a->bind(IfLengthsEnd); |
536 | } |
537 | |
538 | for (int vec_idx = 0; vec_idx < num_vec_regs_per_block; |
539 | vec_idx += unroll_factor) { |
540 | int cur_unroll_factor = |
541 | std::min(unroll_factor, num_vec_regs_per_block - vec_idx); |
542 | |
543 | // Initialize output regs |
544 | for (int v = 0; v < cur_unroll_factor; ++v) { |
545 | vec_reg_t out_vreg = vec_reg_t(v); |
546 | a->vxorps(out_vreg, out_vreg, out_vreg); |
547 | } |
548 | |
549 | if (use_offsets) { |
550 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
551 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
552 | } else { |
553 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
554 | } |
555 | |
// Array out-of-bounds check
557 | a->lea( |
558 | scratchReg1_, |
559 | x86::ptr(indices, lengths_R_, areIndices64b ? 3 : 2)); |
560 | a->cmp(scratchReg1_, index_size); |
561 | a->jg(error); |
562 | |
563 | asmjit::Label LoopDataIndexBegin = a->newLabel(); |
564 | asmjit::Label LoopDataIndexEnd = a->newLabel(); |
565 | asmjit::Label ValidIndexLabel = a->newLabel(); |
566 | |
567 | // dataIndex loop begins (iterate lengths_R_ times) |
568 | a->bind(LoopDataIndexBegin); |
569 | a->dec(lengths_R_); |
570 | a->jl(LoopDataIndexEnd); |
571 | |
// Array out-of-bounds check
573 | if (areIndices64b) { |
574 | a->mov(scratchReg1_, x86::qword_ptr(indices)); |
575 | } else { |
576 | a->mov(scratchReg1_.r32(), x86::dword_ptr(indices)); |
577 | } |
578 | if (!scale_bias_last) { |
579 | // When scale_bias_last == false, assume this is for table batched |
580 | // embedding (TBE) that can get -1 for pruned rows. |
581 | if (areIndices64b) { |
582 | a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1)); |
583 | } else { |
584 | a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1)); |
585 | } |
586 | a->jne(ValidIndexLabel); |
587 | a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType))); |
588 | a->jmp(LoopDataIndexBegin); |
589 | a->bind(ValidIndexLabel); |
590 | } |
591 | // A trick to check x >= data_size or x < 0 in one shot by treating |
592 | // scratchReg1_ as if it has unsigned value |
593 | // (https://stackoverflow.com/a/34072155). |
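// Worked example: with data_size = 100, an index of -1 compares as
// 0xFFFF...FFFF unsigned, so both -1 and 100 take the jae branch to
// error below, while 0..99 fall through.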
594 | a->cmp(scratchReg1_, data_size); |
595 | a->jae(error); |
596 | |
597 | if (ROWWISE_SPARSE) { |
598 | a->mov( |
599 | scratchReg1_.r32(), |
600 | x86::dword_ptr( |
601 | compressed_indices_table, |
602 | scratchReg1_, |
603 | 2)); // use of 2 is to multiply by 4 |
604 | } |
605 | |
606 | int fused_block_size = input_stride * sizeof(inType); |
607 | |
608 | if (pref_dist) { |
609 | asmjit::Label pref_dist_reset_start = a->newLabel(); |
610 | asmjit::Label pref_dist_reset_end = a->newLabel(); |
// out-of-bounds handling for the prefetch index
612 | a->lea( |
613 | scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType))); |
614 | a->cmp(scratchReg2_, index_size); |
615 | a->jge(pref_dist_reset_start); |
616 | |
617 | if (areIndices64b) { |
618 | a->mov( |
619 | scratchReg2_, |
620 | x86::qword_ptr(indices, pref_dist * sizeof(indxType))); |
621 | } else { |
622 | a->mov( |
623 | scratchReg2_.r32(), |
624 | x86::dword_ptr(indices, pref_dist * sizeof(indxType))); |
625 | } |
626 | |
627 | a->jmp(pref_dist_reset_end); |
628 | |
629 | a->bind(pref_dist_reset_start); |
// Prefetch index is out of bounds, so fall back to the current row.
// This could be improved by prefetching the farthest in-bounds row.
632 | if (areIndices64b) { |
633 | a->mov(scratchReg2_, x86::qword_ptr(indices)); |
634 | } else { |
635 | a->mov(scratchReg2_.r32(), x86::dword_ptr(indices)); |
636 | } |
637 | |
638 | a->bind(pref_dist_reset_end); |
639 | if (ROWWISE_SPARSE) { |
640 | asmjit::Label rowwise_sparse_pref_corner_case_begin = |
641 | a->newLabel(); |
642 | asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel(); |
643 | a->cmp(scratchReg2_, data_size); |
644 | a->jae(rowwise_sparse_pref_corner_case_begin); |
645 | |
646 | a->mov( |
647 | scratchReg2_.r32(), |
648 | x86::dword_ptr( |
649 | compressed_indices_table, |
650 | scratchReg2_, |
651 | 2)); // use of 2 is to multiply by 4 |
652 | a->test(scratchReg2_.r32(), scratchReg2_.r32()); |
653 | // Check negative |
654 | a->jns(rowwise_sparse_pref_corner_case_end); |
655 | |
656 | a->bind(rowwise_sparse_pref_corner_case_begin); |
657 | // For corner case, just set prefetch row id to 0. |
658 | a->xor_(scratchReg2_.r32(), scratchReg2_.r32()); |
659 | a->bind(rowwise_sparse_pref_corner_case_end); |
660 | } |
661 | a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size)); |
662 | } |
663 | |
664 | a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType))); |
665 | |
666 | if (has_weight) { |
667 | a->vbroadcastss(w_vreg, x86::dword_ptr(weights)); |
668 | a->add(weights, static_cast<asmjit::Imm>(sizeof(float))); |
669 | } |
670 | |
671 | if (ROWWISE_SPARSE) { |
672 | a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1)); |
673 | a->je(LoopDataIndexBegin); |
674 | } |
675 | |
676 | a->imul(scratchReg1_, static_cast<asmjit::Imm>(fused_block_size)); |
677 | |
// broadcast the scale and bias
679 | x86::Mem scale_src, bias_src; |
680 | constexpr unsigned int CACHE_LINE_LEN = 64; |
681 | if (is8bit) { |
682 | if (scale_bias_last) { |
683 | scale_src = x86::dword_ptr( |
684 | input, scratchReg1_, 0, block_size * sizeof(uint8_t)); |
685 | bias_src = x86::dword_ptr( |
686 | input, |
687 | scratchReg1_, |
688 | 0, |
689 | block_size * sizeof(uint8_t) + sizeof(float)); |
690 | a->vbroadcastss(scale_vreg, scale_src); |
691 | a->vbroadcastss(bias_vreg, bias_src); |
692 | } else { |
693 | scale_src = x86::word_ptr(input, scratchReg1_); |
694 | bias_src = |
695 | x86::word_ptr(input, scratchReg1_, 0, sizeof(uint16_t)); |
696 | a->vpbroadcastw(scale_vreg.half(), scale_src); |
697 | a->vpbroadcastw(bias_vreg.half(), bias_src); |
698 | a->vcvtph2ps(scale_vreg, scale_vreg.half()); |
699 | a->vcvtph2ps(bias_vreg, bias_vreg.half()); |
700 | } |
701 | |
702 | if (pref_dist && fused_block_size % CACHE_LINE_LEN > 0 && |
703 | fused_block_size % CACHE_LINE_LEN <= 2 * sizeof(float)) { |
704 | a->prefetcht0(x86::dword_ptr( |
705 | input, |
706 | scratchReg2_, |
707 | 0, |
708 | fused_block_size / CACHE_LINE_LEN * CACHE_LINE_LEN)); |
709 | } |
710 | } |
711 | |
712 | if (has_weight && is8bit) { |
713 | a->vmulps(scale_vreg, scale_vreg, w_vreg); |
714 | a->vmulps(bias_vreg, bias_vreg, w_vreg); |
715 | } |
716 | |
717 | // The main computation |
718 | int src_addr_offset = |
719 | is8bit && !scale_bias_last ? 2 * sizeof(uint16_t) : 0; |
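// Row layout assumed here: when scale_bias_last == false the row begins
// with an fp16 scale and an fp16 bias (2 * sizeof(uint16_t) = 4 bytes)
// followed by the quantized data, hence the nonzero src_addr_offset;
// when scale_bias_last == true the quantized data starts at byte 0 and
// the fp32 scale/bias pair trails the row.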
720 | for (int v = 0; v < cur_unroll_factor; ++v) { |
721 | constexpr int BYTES_PER_VLOAD = vlen * sizeof(inType); |
722 | auto src_addr = x86::dword_ptr( |
723 | input, |
724 | scratchReg1_, |
725 | 0, |
726 | src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD); |
727 | vec_reg_t out_vreg = vec_reg_t(v); |
728 | |
// For 8-bit SLS: zero-extend unsigned 8-bit to 32-bit int, convert
// to float, then multiply by the scale and add the bias
731 | if (is8bit) { |
732 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && |
733 | instSet == inst_set_t::avx512) { |
734 | a->k(x86::k(1)).z().vpmovzxbd(src_vreg, src_addr); |
735 | } else { |
// We don't use a mask for AVX2 since the trailing scale and bias
// (2 floats = 8 bytes) act as padding; the full vector load
// therefore never reads out-of-bounds data
739 | a->vpmovzxbd(src_vreg, src_addr); |
740 | } |
741 | a->vcvtdq2ps(src_vreg, src_vreg); |
742 | a->vaddps(out_vreg, out_vreg, bias_vreg); |
743 | a->vfmadd231ps(out_vreg, src_vreg, scale_vreg); |
744 | } else if (is16bit) { |
745 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
746 | if (instSet == inst_set_t::avx2) { |
747 | if (remainder % 2 == 0) { |
748 | a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); |
749 | } else { |
750 | a->vpbroadcastw( |
751 | src_vreg.xmm(), |
752 | x86::word_ptr( |
753 | input, |
754 | scratchReg1_, |
755 | 0, |
756 | src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD + |
757 | (remainder - 1) * sizeof(inType))); |
758 | if (remainder > 1) { |
// AVX2 can't mask loads at 16-bit granularity, so we assemble
// the remainder on the stack and reload it.
// First spill the broadcasted last 16-bit element
762 | a->vmovups(x86::xmmword_ptr(x86::rsp), src_vreg.xmm()); |
763 | // Mask store the remaining 16-bit elements |
764 | a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); |
765 | a->vmaskmovps( |
766 | x86::xmmword_ptr(x86::rsp), |
767 | mask_fp16_vreg, |
768 | src_vreg.xmm()); |
769 | // Load combined 16-bit elements |
770 | a->vmovups(src_vreg.xmm(), x86::xmmword_ptr(x86::rsp)); |
771 | } // remainder > 1 |
772 | } // remainder % 2 |
773 | if (isfp16) { |
774 | a->vcvtph2ps(src_vreg.ymm(), src_vreg.xmm()); |
775 | } else if (isbf16) { |
776 | // bf16 |
777 | a->vpmovzxwd(src_vreg.ymm(), src_vreg.xmm()); |
778 | a->vpslld(src_vreg.ymm(), src_vreg.ymm(), 16); |
779 | } |
780 | } else { |
781 | // avx512 |
782 | if (isfp16) { |
783 | a->k(x86::k(1)).z().vcvtph2ps(src_vreg, src_addr); |
784 | } else if (isbf16) { |
785 | // bf16 |
786 | a->k(x86::k(1)).z().vpmovzxwd(src_vreg, src_addr); |
787 | a->k(x86::k(1)).z().vpslld(src_vreg, src_vreg, 16); |
788 | } |
789 | } |
790 | } else { |
791 | // no remainder |
792 | if (isfp16) { |
793 | a->vcvtph2ps(src_vreg, src_addr); |
794 | } else if (isbf16) { |
795 | // bf16 |
796 | a->vpmovzxwd(src_vreg, src_addr); |
797 | a->vpslld(src_vreg, src_vreg, 16); |
798 | } |
799 | } |
800 | if (has_weight) { |
801 | a->vfmadd231ps(out_vreg, w_vreg, src_vreg); |
802 | } else { |
803 | a->vaddps(out_vreg, out_vreg, src_vreg); |
804 | } |
805 | } else { |
806 | // This part for FP32 SLS |
807 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && |
808 | instSet == inst_set_t::avx2) { |
809 | a->vmaskmovps(src_vreg.ymm(), mask_vreg.ymm(), src_addr); |
810 | } |
811 | if (has_weight) { |
812 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
813 | if (instSet == inst_set_t::avx2) { |
814 | a->vfmadd231ps(out_vreg, w_vreg, src_vreg); |
815 | } else { |
816 | a->k(x86::k(1)).vfmadd231ps(out_vreg, w_vreg, src_addr); |
817 | } |
818 | } else { |
819 | a->vfmadd231ps(out_vreg, w_vreg, src_addr); |
820 | } |
821 | } else { |
822 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
823 | if (instSet == inst_set_t::avx2) { |
824 | a->vaddps(out_vreg, out_vreg, src_vreg); |
825 | } else { |
826 | a->k(x86::k(1)).vaddps(out_vreg, out_vreg, src_addr); |
827 | } |
828 | } else { |
829 | a->vaddps(out_vreg, out_vreg, src_addr); |
830 | } |
831 | } |
832 | } |
833 | |
834 | constexpr int VLOAD_PER_CACHE_LINE = |
835 | CACHE_LINE_LEN / BYTES_PER_VLOAD; |
836 | if (pref_dist && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) { |
837 | a->prefetcht0(x86::dword_ptr( |
838 | input, scratchReg2_, 0, (vec_idx + v) * BYTES_PER_VLOAD)); |
839 | } |
840 | } |
841 | |
842 | a->jmp(LoopDataIndexBegin); |
843 | a->bind(LoopDataIndexEnd); |
844 | |
// This loop writes the accumulated results (out_vreg) back to memory
847 | for (int v = 0; v < cur_unroll_factor; ++v) { |
848 | auto dst_addr = |
849 | x86::dword_ptr(out, (vec_idx + v) * vlen * sizeof(outType)); |
850 | vec_reg_t out_vreg = vec_reg_t(v); |
851 | |
852 | if (normalize_by_lengths) { |
853 | a->vmulps(out_vreg, out_vreg, vlen_inv_vreg); |
854 | } |
855 | |
856 | if (std::is_same<outType, float>::value) { |
857 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
858 | if (instSet == inst_set_t::avx2) { |
859 | a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm()); |
860 | } else { |
861 | a->k(x86::k(1)).vmovups(dst_addr, out_vreg); |
862 | } |
863 | } else { |
864 | a->vmovups(dst_addr, out_vreg); |
865 | } |
866 | } else { |
867 | // fp16/bf16 output |
868 | if (instSet == inst_set_t::avx2) { |
// round to nearest with no exceptions raised
if (isfp16out) {
  // imm 8 == _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC
  a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);
} else if (isbf16out) {
  // Add 0x8000 so truncating the low 16 bits rounds to nearest
  a->vpaddd(out_vreg, out_vreg, ones_vreg);
  a->vpsrld(out_vreg, out_vreg, 16);
  // Pack dwords to words; vpermq with 0xd8 then gathers the two
  // valid quadwords into the low 128 bits after the in-lane pack
  a->vpackusdw(out_vreg, out_vreg, out_vreg);
  a->vpermq(out_vreg, out_vreg, 0xd8);
}
878 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
879 | if (remainder > 1) { |
880 | a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); |
881 | } |
882 | if (remainder % 2 != 0) { |
883 | a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm()); |
884 | a->mov( |
885 | scratchReg1_.r16(), |
886 | x86::word_ptr( |
887 | x86::rsp, (remainder - 1) * sizeof(outType))); |
888 | a->mov( |
889 | x86::word_ptr( |
890 | out, |
891 | ((vec_idx + v) * vlen + (remainder - 1)) * |
892 | sizeof(outType)), |
893 | scratchReg1_.r16()); |
894 | } |
895 | } else { |
896 | a->vmovups(dst_addr, out_vreg.xmm()); |
897 | } |
898 | } else { |
899 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
900 | if (isfp16out) { |
901 | a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); |
902 | } else if (isbf16out) { |
903 | // bf16 |
904 | a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); |
905 | a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); |
906 | a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); |
907 | } |
908 | } else { |
909 | if (isfp16out) { |
910 | a->vcvtps2ph(dst_addr, out_vreg, 8); |
911 | } else if (isbf16out) { |
912 | // bf16 |
913 | a->vpaddd(out_vreg, out_vreg, ones_vreg); |
914 | a->vpsrld(out_vreg, out_vreg, 16); |
915 | a->vpmovdw(dst_addr, out_vreg); |
916 | } |
917 | } |
918 | } |
919 | } |
920 | } |
921 | |
922 | if (vec_idx + unroll_factor < num_vec_regs_per_block || |
923 | (has_weight && is_weight_positional)) { |
924 | // Reset lengths_R_, indices, weights to run the dataIndex loop |
925 | // again |
926 | if (use_offsets) { |
927 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
928 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
929 | } else { |
930 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
931 | } |
932 | |
933 | if (has_weight) { |
934 | a->imul( |
935 | scratchReg1_, |
936 | lengths_R_, |
937 | static_cast<asmjit::Imm>(sizeof(float))); |
938 | a->sub(weights, scratchReg1_); |
939 | |
940 | if (vec_idx + unroll_factor < num_vec_regs_per_block) { |
941 | a->imul( |
942 | scratchReg1_, |
943 | static_cast<asmjit::Imm>(sizeof(indxType) / sizeof(float))); |
944 | a->sub(indices, scratchReg1_); |
945 | } |
946 | } else { |
947 | a->imul( |
948 | scratchReg1_, |
949 | lengths_R_, |
950 | static_cast<asmjit::Imm>(sizeof(indxType))); |
951 | a->sub(indices, scratchReg1_); |
952 | } |
953 | } |
954 | } |
955 | |
956 | a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType))); |
957 | a->add(out, static_cast<asmjit::Imm>(output_stride * sizeof(outType))); |
958 | |
959 | a->jmp(LoopRangeIndexBegin); |
960 | a->bind(LoopRangeIndexEnd); |
961 | |
962 | a->cmp(indices, index_size); |
963 | a->jne(error); |
964 | a->mov(x86::eax, true); |
965 | a->jmp(exit); |
966 | a->bind(error); |
967 | a->mov(x86::eax, false); |
968 | a->bind(exit); |
969 | |
970 | if (remainder && instSet == inst_set_t::avx2 && |
971 | (is16bit || isbf16out || isfp16out)) { |
972 | a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t))); |
973 | } |
974 | |
975 | a->emitEpilog(frame); |
976 | |
978 | typename ReturnFunctionSignature< |
979 | inType, |
980 | indxType, |
981 | offsetType, |
982 | outType, |
983 | ROWWISE_SPARSE>::jit_embedding_kernel fn; |
984 | asmjit::Error err; |
985 | { |
986 | std::unique_lock<std::mutex> lock(rtMutex_); |
987 | err = runtime().add(&fn, &code); |
988 | } |
if (err) {
  std::cout << "Error: in fn add" << std::endl;
#if defined(FBGEMM_LOG_CODE)
  fclose(codeLogFile);
  delete codeLogger;
#endif
  return nullptr;
}
993 | |
994 | #if defined(FBGEMM_LOG_CODE) |
995 | fclose(codeLogFile); |
996 | delete codeLogger; |
997 | #endif |
998 | return fn; |
999 | }); |
1000 | } |
1001 | |
1002 | } // namespace |
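
// A minimal usage sketch of the generator below (illustrative values;
// `table`, `indices`, `offsets`, and `out` are hypothetical caller-owned
// buffers, and weights may be null for an unweighted sum):
//
//   auto kernel =
//       GenerateEmbeddingSpMDMWithStrides<float, int64_t, int32_t, float>(
//           /*block_size=*/64,
//           /*has_weight=*/false,
//           /*normalize_by_lengths=*/false,
//           /*prefetch=*/16,
//           /*is_weight_positional=*/false,
//           /*use_offsets=*/true);
//   bool ok = kernel(
//       output_size, index_size, data_size, table, indices, offsets,
//       /*weights=*/nullptr, out);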
1003 | |
1004 | template < |
1005 | typename inType, |
1006 | typename indxType, |
1007 | typename offsetType, |
1008 | typename outType, |
1009 | bool THREAD_LOCAL> |
1010 | typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>:: |
1011 | Type |
1012 | GenerateEmbeddingSpMDMWithStrides( |
1013 | const int64_t block_size, |
1014 | bool has_weight, |
1015 | bool normalize_by_lengths, |
1016 | int prefetch, |
1017 | bool is_weight_positional, |
1018 | bool use_offsets, |
1019 | int64_t output_stride /*=-1*/, |
1020 | int64_t input_stride /*=-1*/, |
1021 | bool scale_bias_last /*=true*/, |
1022 | bool no_bag /*=false*/, |
1023 | bool isbf16 /*=false*/) { |
1024 | if (!cpuinfo_initialize()) { |
throw std::runtime_error("Failed to initialize cpuinfo!");
1026 | } |
1027 | if (output_stride == -1) { |
1028 | output_stride = block_size; |
1029 | } |
1030 | if (input_stride == -1) { |
1031 | if (std::is_same<inType, uint8_t>::value) { |
1032 | const auto scale_bias_offset = |
1033 | 2 * (scale_bias_last ? sizeof(float) : sizeof(uint16_t)); |
1034 | input_stride = block_size + scale_bias_offset; |
1035 | } else { |
1036 | input_stride = block_size; |
1037 | } |
1038 | } |
1039 | const inst_set_t isa = fbgemmInstructionSet(); |
if (no_bag) {
1041 | return [=](int64_t output_size, |
1042 | int64_t index_size, |
1043 | int64_t data_size, |
1044 | const inType* input, |
1045 | const indxType* indices, |
1046 | const offsetType* offsets_or_lengths, |
1047 | const float* weights, |
1048 | outType* out) { |
1049 | return EmbeddingSpMDM_ref( |
1050 | block_size, |
1051 | output_size, |
1052 | index_size, |
1053 | data_size, |
1054 | input, |
1055 | indices, |
1056 | offsets_or_lengths, |
1057 | weights, |
1058 | normalize_by_lengths, |
1059 | out, |
1060 | is_weight_positional, |
1061 | use_offsets, |
1062 | output_stride, |
1063 | input_stride, |
1064 | scale_bias_last, |
1065 | no_bag, |
1066 | isbf16); |
1067 | }; |
1068 | } |
1069 | |
1070 | if ((std::is_same<inType, float>::value || |
1071 | std::is_same<inType, uint16_t>::value) && |
1072 | block_size == 1 && isYmm(isa) && output_stride == block_size && |
1073 | input_stride == block_size && std::is_same<outType, float>::value) { |
1074 | return |
1075 | [=](int64_t output_size, |
1076 | int64_t index_size, |
1077 | int64_t data_size, |
1078 | const inType* input, |
1079 | const indxType* indices, |
1080 | const offsetType* offsets_or_lengths, |
1081 | const float* weights, // optional, can be null for non-weighted sum |
1082 | outType* out) { |
1083 | return internal::EmbeddingSpMDMBlockSize1_( |
1084 | output_size, |
1085 | index_size, |
1086 | data_size, |
1087 | input, |
1088 | indices, |
1089 | offsets_or_lengths, |
1090 | weights, |
1091 | normalize_by_lengths, |
1092 | reinterpret_cast<float*>(out), |
1093 | is_weight_positional, |
1094 | use_offsets, |
1095 | isbf16); |
1096 | }; |
1097 | } else if (isZmm(isa)) { |
1098 | static GenEmbeddingSpMDMLookup< |
1099 | inType, |
1100 | indxType, |
1101 | offsetType, |
1102 | outType, |
1103 | inst_set_t::avx512, |
1104 | /*ROWWISE_SPARSE=*/false, |
1105 | THREAD_LOCAL> |
1106 | kernel_generator; |
1107 | const auto original_func = kernel_generator.getOrCreate( |
1108 | block_size, |
1109 | has_weight, |
1110 | is_weight_positional, |
1111 | normalize_by_lengths, |
1112 | prefetch, |
1113 | use_offsets, |
1114 | output_stride, |
1115 | input_stride, |
1116 | scale_bias_last, |
1117 | isbf16); |
1118 | return [=](int64_t output_size, |
1119 | int64_t index_size, |
1120 | int64_t data_size, |
1121 | const inType* input, |
1122 | const indxType* indices, |
1123 | const offsetType* offsets_or_lengths, |
1124 | const float* weights, |
1125 | outType* out) { |
1126 | return original_func( |
1127 | output_size, |
1128 | index_size, |
1129 | data_size, |
1130 | input, |
1131 | indices, |
1132 | offsets_or_lengths, |
1133 | weights, |
1134 | out, |
1135 | nullptr /* mask not used in avx512 */); |
1136 | }; |
1137 | } else if (isYmm(isa)) { |
1138 | static GenEmbeddingSpMDMLookup< |
1139 | inType, |
1140 | indxType, |
1141 | offsetType, |
1142 | outType, |
1143 | inst_set_t::avx2, |
1144 | /*ROWWISE_SPARSE=*/false, |
1145 | THREAD_LOCAL> |
1146 | kernel_generator; |
1147 | const auto original_func = kernel_generator.getOrCreate( |
1148 | block_size, |
1149 | has_weight, |
1150 | is_weight_positional, |
1151 | normalize_by_lengths, |
1152 | prefetch, |
1153 | use_offsets, |
1154 | output_stride, |
1155 | input_stride, |
1156 | scale_bias_last, |
1157 | isbf16); |
1158 | return [=](int64_t output_size, |
1159 | int64_t index_size, |
1160 | int64_t data_size, |
1161 | const inType* input, |
1162 | const indxType* indices, |
1163 | const offsetType* offsets_or_lengths, |
1164 | const float* weights, |
1165 | outType* out) { |
1166 | return original_func( |
1167 | output_size, |
1168 | index_size, |
1169 | data_size, |
1170 | input, |
1171 | indices, |
1172 | offsets_or_lengths, |
1173 | weights, |
1174 | out, |
1175 | internal::avx2_ps_or_epi32_combined_mask); |
1176 | }; |
1177 | } else { |
1178 | #ifdef VLOG |
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1180 | #endif |
1181 | return [=](int64_t output_size, |
1182 | int64_t index_size, |
1183 | int64_t data_size, |
1184 | const inType* input, |
1185 | const indxType* indices, |
1186 | const offsetType* offsets_or_lengths, |
1187 | const float* weights, |
1188 | outType* out) { |
1189 | return EmbeddingSpMDM_ref( |
1190 | block_size, |
1191 | output_size, |
1192 | index_size, |
1193 | data_size, |
1194 | input, |
1195 | indices, |
1196 | offsets_or_lengths, |
1197 | weights, |
1198 | normalize_by_lengths, |
1199 | out, |
1200 | is_weight_positional, |
1201 | use_offsets, |
1202 | output_stride, |
1203 | input_stride, |
1204 | scale_bias_last, |
1205 | no_bag, |
1206 | isbf16); |
1207 | }; |
1208 | } |
1209 | } |
1210 | |
1211 | template < |
1212 | typename inType, |
1213 | typename indxType, |
1214 | typename offsetType, |
1215 | typename outType, |
1216 | bool THREAD_LOCAL> |
1217 | typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>:: |
1218 | Type |
1219 | GenerateEmbeddingSpMDM( |
1220 | const int64_t block_size, |
1221 | bool has_weight, |
1222 | bool normalize_by_lengths, |
1223 | int prefetch, |
1224 | bool is_weight_positional, |
1225 | bool use_offsets, |
1226 | bool isbf16) { |
1227 | return GenerateEmbeddingSpMDMWithStrides< |
1228 | inType, |
1229 | indxType, |
1230 | offsetType, |
1231 | outType, |
1232 | THREAD_LOCAL>( |
1233 | block_size, |
1234 | has_weight, |
1235 | normalize_by_lengths, |
1236 | prefetch, |
1237 | is_weight_positional, |
1238 | use_offsets, |
1239 | /*output_stride=*/-1, |
1240 | /*input_stride=*/-1, |
1241 | /*scale_bias_last=*/true, |
1242 | /*no_bag=*/false, |
1243 | isbf16); |
1244 | } |
1245 | |
1246 | template <typename indxType, typename offsetType, typename outType> |
1247 | typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>:: |
1248 | Type |
1249 | GenerateEmbeddingSpMDMFP8WithStrides( |
1250 | const int64_t block_size, |
1251 | bool normalize_by_lengths, |
1252 | bool is_weight_positional, |
1253 | bool use_offsets, |
1254 | int64_t output_stride /*=-1*/, |
1255 | int64_t input_stride /*=-1*/, |
1256 | int exponent_bits, |
1257 | int exponent_bias) { |
1258 | if (output_stride == -1) { |
1259 | output_stride = block_size; |
1260 | } |
1261 | if (input_stride == -1) { |
1262 | input_stride = block_size; |
1263 | } |
// Only the reference implementation is available for FP8 embeddings
1265 | return [=](int64_t output_size, |
1266 | int64_t index_size, |
1267 | int64_t data_size, |
1268 | const uint8_t* input, |
1269 | const indxType* indices, |
1270 | const offsetType* offsets_or_lengths, |
1271 | const float* weights, |
1272 | outType* out) { |
1273 | return EmbeddingSpMDMFP8_ref( |
1274 | block_size, |
1275 | output_size, |
1276 | index_size, |
1277 | data_size, |
1278 | input, |
1279 | indices, |
1280 | offsets_or_lengths, |
1281 | weights, |
1282 | normalize_by_lengths, |
1283 | out, |
1284 | is_weight_positional, |
1285 | use_offsets, |
1286 | output_stride, |
1287 | input_stride, |
1288 | exponent_bits, |
1289 | exponent_bias); |
1290 | }; |
1291 | } |
1292 | |
1293 | template <typename inType, typename indxType, typename offsetType> |
1294 | typename EmbeddingSpMDMRowWiseSparseKernelSignature< |
1295 | inType, |
1296 | indxType, |
1297 | offsetType>::Type |
1298 | GenerateEmbeddingSpMDMRowWiseSparse( |
1299 | const int64_t block_size, |
1300 | bool has_weight, |
1301 | bool normalize_by_lengths, |
1302 | int prefetch, |
1303 | bool is_weight_positional, |
1304 | bool use_offsets) { |
1305 | if (!cpuinfo_initialize()) { |
throw std::runtime_error("Failed to initialize cpuinfo!");
1307 | } |
1308 | int64_t input_stride = block_size; |
1309 | if (std::is_same<inType, uint8_t>::value) { |
1310 | const auto scale_bias_offset = 2 * sizeof(float); |
1311 | input_stride = block_size + scale_bias_offset; |
1312 | } |
1313 | inst_set_t isa = fbgemmInstructionSet(); |
1314 | if (isZmm(isa)) { |
1315 | static GenEmbeddingSpMDMLookup< |
1316 | inType, |
1317 | indxType, |
1318 | offsetType, |
1319 | /*outType=*/float, |
1320 | inst_set_t::avx512, |
1321 | /*rowwise_sparse=*/true> |
1322 | kernel_generator; |
1323 | const auto original_func = kernel_generator.getOrCreate( |
1324 | block_size, |
1325 | has_weight, |
1326 | is_weight_positional, |
1327 | normalize_by_lengths, |
1328 | prefetch, |
1329 | use_offsets, |
1330 | /*output_stride=*/block_size, |
1331 | input_stride, |
1332 | /*scale_bias_last=*/true, |
1333 | /*isbf16=*/false); |
1334 | return [=](int64_t output_size, |
1335 | int64_t index_size, |
1336 | int64_t uncompressed_data_size, |
1337 | const inType* input, |
1338 | const indxType* indices, |
1339 | const offsetType* offsets_or_lengths, |
1340 | const float* weights, |
1341 | float* out, |
1342 | const int32_t* compressed_indices_table) { |
1343 | return original_func( |
1344 | output_size, |
1345 | index_size, |
1346 | uncompressed_data_size, |
1347 | input, |
1348 | indices, |
1349 | offsets_or_lengths, |
1350 | weights, |
1351 | out, |
1352 | compressed_indices_table, |
1353 | nullptr /* mask not used in avx512 */); |
1354 | }; |
1355 | } else if (isYmm(isa)) { |
1356 | static GenEmbeddingSpMDMLookup< |
1357 | inType, |
1358 | indxType, |
1359 | offsetType, |
1360 | /*outType=*/float, |
1361 | inst_set_t::avx2, |
1362 | /*rowwise_sparse=*/true> |
1363 | kernel_generator; |
1364 | const auto original_func = kernel_generator.getOrCreate( |
1365 | block_size, |
1366 | has_weight, |
1367 | is_weight_positional, |
1368 | normalize_by_lengths, |
1369 | prefetch, |
1370 | use_offsets, |
1371 | /*output_stride=*/block_size, |
1372 | input_stride, |
1373 | /*scale_bias_last=*/true, |
1374 | /*isbf16=*/false); |
1375 | return [=](int64_t output_size, |
1376 | int64_t index_size, |
1377 | int64_t uncompressed_data_size, |
1378 | const inType* input, |
1379 | const indxType* indices, |
1380 | const offsetType* offsets_or_lengths, |
1381 | const float* weights, |
1382 | float* out, |
1383 | const int32_t* compressed_indices_table) { |
1384 | return original_func( |
1385 | output_size, |
1386 | index_size, |
1387 | uncompressed_data_size, |
1388 | input, |
1389 | indices, |
1390 | offsets_or_lengths, |
1391 | weights, |
1392 | out, |
1393 | compressed_indices_table, |
1394 | internal::avx2_ps_or_epi32_combined_mask); |
1395 | }; |
1396 | } else { |
1397 | #ifdef VLOG |
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1399 | #endif |
1400 | return |
1401 | [=](int64_t output_size, |
1402 | int64_t index_size, |
1403 | int64_t uncompressed_data_size, |
1404 | const inType* input, |
1405 | const indxType* indices, |
1406 | const offsetType* offsets_or_lengths, |
1407 | const float* weights, // optional, can be null for non-weighted sum |
1408 | float* out, |
1409 | const int32_t* compressed_indices_table) { |
1410 | return EmbeddingSpMDMRowWiseSparse_ref( |
1411 | block_size, |
1412 | output_size, |
1413 | index_size, |
1414 | uncompressed_data_size, |
1415 | // compressed_data_size, |
1416 | input, |
1417 | indices, |
1418 | compressed_indices_table, |
1419 | offsets_or_lengths, |
1420 | weights, |
1421 | normalize_by_lengths, |
1422 | out, |
1423 | is_weight_positional, |
1424 | use_offsets); |
1425 | }; |
1426 | } |
1427 | } |
1428 | |
1429 | #define INSTANTIATE_SPMDM_BASE( \ |
1430 | IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ |
1431 | template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ |
1432 | IN_TYPE, \ |
1433 | INDEX_TYPE, \ |
1434 | OFFSET_TYPE, \ |
1435 | OUT_TYPE>::Type \ |
1436 | GenerateEmbeddingSpMDMWithStrides< \ |
1437 | IN_TYPE, \ |
1438 | INDEX_TYPE, \ |
1439 | OFFSET_TYPE, \ |
1440 | OUT_TYPE, \ |
1441 | THREAD_LOCAL>( \ |
1442 | const int64_t block_size, \ |
1443 | bool has_weight, \ |
1444 | bool normalize_by_lengths, \ |
1445 | int prefetch, \ |
1446 | bool is_weight_positional, \ |
1447 | bool use_offsets, \ |
1448 | int64_t output_stride, \ |
1449 | int64_t input_stride, \ |
1450 | bool scale_bias_last, \ |
1451 | bool no_bag, \ |
1452 | bool isbf16); |
1453 | |
1454 | #define INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
1455 | template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ |
1456 | uint8_t, \ |
1457 | INDEX_TYPE, \ |
1458 | OFFSET_TYPE, \ |
1459 | OUT_TYPE>::Type \ |
1460 | GenerateEmbeddingSpMDMFP8WithStrides<INDEX_TYPE, OFFSET_TYPE, OUT_TYPE>( \ |
1461 | const int64_t block_size, \ |
1462 | bool normalize_by_lengths, \ |
1463 | bool is_weight_positional, \ |
1464 | bool use_offsets, \ |
1465 | int64_t output_stride, \ |
1466 | int64_t input_stride, \ |
1467 | int exponent_bits, \ |
1468 | int exponent_bias); |
1469 | |
1470 | #define INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ |
1471 | IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ |
1472 | template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ |
1473 | IN_TYPE, \ |
1474 | INDEX_TYPE, \ |
1475 | OFFSET_TYPE, \ |
1476 | OUT_TYPE>::Type \ |
1477 | GenerateEmbeddingSpMDM< \ |
1478 | IN_TYPE, \ |
1479 | INDEX_TYPE, \ |
1480 | OFFSET_TYPE, \ |
1481 | OUT_TYPE, \ |
1482 | THREAD_LOCAL>( \ |
1483 | const int64_t block_size, \ |
1484 | bool has_weight, \ |
1485 | bool normalize_by_lengths, \ |
1486 | int prefetch, \ |
1487 | bool is_weight_positional, \ |
1488 | bool use_offsets, \ |
1489 | bool isbf16); |
1490 | |
1491 | #define INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ |
1492 | template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ |
1493 | IN_TYPE, \ |
1494 | INDEX_TYPE, \ |
1495 | OFFSET_TYPE>::Type \ |
1496 | GenerateEmbeddingSpMDMRowWiseSparse<IN_TYPE, INDEX_TYPE, OFFSET_TYPE>( \ |
1497 | const int64_t block_size, \ |
1498 | bool has_weight, \ |
1499 | bool normalize_by_lengths, \ |
1500 | int prefetch, \ |
1501 | bool is_weight_positional, \ |
1502 | bool use_offsets); |
1503 | |
1504 | #define INSTANTIATE_SPMDMFP8_BASE_uint8_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
1505 | INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) |
1506 | #define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) |
1507 | #define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) |
1508 | |
1509 | #define INSTANTIATE_SPMDM_THREAD_LOCAL( \ |
1510 | IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
1511 | INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \ |
1512 | INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ |
1513 | INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ |
1514 | IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \ |
1515 | INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ |
1516 | IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ |
1517 | INSTANTIATE_SPMDMFP8_BASE_##IN_TYPE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) |
1518 | |
1519 | #define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ |
1520 | INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \ |
1521 | INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \ |
1522 | INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) |
1523 | |
1524 | #define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \ |
1525 | INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int32_t) \ |
1526 | INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int64_t) |
1527 | |
1528 | #define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \ |
1529 | INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int32_t) \ |
1530 | INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int64_t) |
1531 | |
1532 | INSTANTIATE_SPMDM_INDEX_T(float) |
1533 | INSTANTIATE_SPMDM_INDEX_T(uint16_t) |
1534 | INSTANTIATE_SPMDM_INDEX_T(uint8_t) |
1535 | |
1536 | #undef INSTANTIATE_SPMDM_INDEX_T |
1537 | #undef INSTANTIATE_SPMDM_OFFSET_T |
1538 | #undef INSTANTIATE_SPMDM_OUT_T |
1539 | #undef INSTANTIATE_SPMDM_THREAD_LOCAL |
1540 | #undef INSTANTIATE_SPMDM_BASE |
1541 | #undef INSTANTIATE_SPMDMFP8_BASE |
1542 | #undef INSTANTIATE_SPMDM_NOSTRIDE_BASE |
1543 | #undef INSTANTIATE_SPMDM_ROWWISE_BASE |
1544 | |
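// compressed_indices_remap remaps each index through
// compressed_indices_mapping, dropping pruned entries (mapping value -1)
// and adjusting offsets (and weights, when present) to match. A tiny
// illustrative example with hypothetical values:
//
//   indices = {4, 1, 2}, compressed_indices_mapping = {0, -1, 1, -1, 2}
//   -> out_indices = {2, 1} (rows 4 and 2 remapped; row 1 dropped) and
//      out_offsets shrunk to match.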
1545 | template <typename IndexType> |
1546 | void compressed_indices_remap( |
1547 | std::int32_t offsets_len, |
1548 | const IndexType* indices, |
1549 | const int32_t* compressed_indices_mapping, |
1550 | const IndexType* offsets, |
const float* weights, // optional, can be null
1552 | IndexType* out_indices, |
1553 | IndexType* out_offsets, |
1554 | float* out_weights) { |
1555 | if (!cpuinfo_initialize()) { |
throw std::runtime_error("Failed to initialize cpuinfo!");
1557 | } |
1558 | |
1559 | const inst_set_t isa = fbgemmInstructionSet(); |
1560 | if (isZmm(isa)) { |
1561 | #ifndef __HIP_PLATFORM_HCC__ |
1562 | if (weights == nullptr) { |
1563 | internal::compressed_indices_remap_avx512<IndexType, false>( |
1564 | offsets_len, |
1565 | indices, |
1566 | compressed_indices_mapping, |
1567 | offsets, |
1568 | weights, |
1569 | out_indices, |
1570 | out_offsets, |
1571 | out_weights); |
1572 | } else { |
1573 | internal::compressed_indices_remap_avx512<IndexType, true>( |
1574 | offsets_len, |
1575 | indices, |
1576 | compressed_indices_mapping, |
1577 | offsets, |
1578 | weights, |
1579 | out_indices, |
1580 | out_offsets, |
1581 | out_weights); |
1582 | } |
1583 | #endif |
1584 | } else { |
1585 | compressed_indices_remap_ref<IndexType>( |
1586 | offsets_len, |
1587 | indices, |
1588 | compressed_indices_mapping, |
1589 | offsets, |
1590 | weights, |
1591 | out_indices, |
1592 | out_offsets, |
1593 | out_weights); |
1594 | } |
1595 | } |
1596 | |
1597 | #define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \ |
1598 | template FBGEMM_API void compressed_indices_remap( \ |
1599 | std::int32_t offsets_numel, \ |
1600 | const INDEX_TYPE* indices, \ |
1601 | const int32_t* compressed_indices_mapping, \ |
1602 | const INDEX_TYPE* offsets, \ |
1603 | const float* weights, \ |
1604 | INDEX_TYPE* out_indices, \ |
1605 | INDEX_TYPE* out_offsets, \ |
1606 | float* out_weights); |
1607 | |
1608 | INSTANTIATE_REMAP_BASE(int32_t) |
1609 | INSTANTIATE_REMAP_BASE(int64_t) |
1610 | |
1611 | #undef INSTANTIATE_REMAP_BASE |
1612 | |
1613 | } // namespace fbgemm |
1614 | |