/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
7 | #define FBGEMM_EXPORTS |
8 | |
9 | #include "fbgemm/FbgemmEmbedding.h" |
10 | |
11 | #include <asmjit/asmjit.h> |
12 | #include <cpuinfo.h> |
13 | #include <cassert> |
14 | #include <cmath> |
15 | #include <iostream> |
16 | #include <map> |
17 | #include <mutex> |
18 | #include <string> |
19 | #include <tuple> |
20 | #include "./CodeCache.h" |
21 | #include "./MaskAvx2.h" |
22 | #include "./RefImplementations.h" |
23 | #include "fbgemm/SimdUtils.h" |
24 | #include "fbgemm/Types.h" |
25 | |
26 | using namespace std; |
27 | |
28 | namespace fbgemm { |
29 | |
30 | namespace { |
31 | |
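// Integer ceiling division, e.g., ceil_div(10, 4) == 3; used below to round
// up packed-row byte counts and vector-register counts.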
32 | template <typename T> |
33 | T ceil_div(T a, T b) { |
34 | return (a + b - 1) / b; |
35 | } |
36 | |
37 | namespace x86 = asmjit::x86; |
38 | |
39 | template < |
40 | typename indxType, |
41 | typename offsetType, |
42 | typename outType, |
43 | bool ROWWISE_SPARSE> |
44 | class ReturnFunctionSignature {}; |
45 | |
46 | template <typename indxType, typename offsetType, typename outType> |
47 | class ReturnFunctionSignature<indxType, offsetType, outType, false> { |
48 | public: |
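  // The generated kernel returns true on success and false if an index is
  // out of bounds or the number of indices consumed doesn't match index_size.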
49 | using jit_embedding_kernel = bool (*)( |
50 | int64_t output_size, |
51 | int64_t index_size, |
52 | int64_t data_size, |
53 | const uint8_t* input, |
54 | const indxType* indices, |
55 | const offsetType* offsets_or_lengths, |
56 | const float* weights, |
57 | outType* out, |
58 | const int* mask); |
59 | }; |
60 | |
61 | template <typename indxType, typename offsetType, typename outType> |
62 | class ReturnFunctionSignature<indxType, offsetType, outType, true> { |
63 | public: |
64 | using jit_embedding_kernel = bool (*)( |
65 | int64_t output_size, |
66 | int64_t index_size, |
67 | int64_t uncompressed_data_size, |
68 | // int64_t compressed_data_size, |
69 | const uint8_t* input, |
70 | const indxType* indices, |
71 | const offsetType* offsets_or_lengths, |
72 | const float* weights, |
73 | outType* out, |
74 | const int32_t* compressed_indices_table, |
75 | const int* mask); |
76 | }; |
77 | |
78 | template < |
79 | typename indxType, |
80 | typename offsetType, |
81 | typename outType, |
82 | inst_set_t instSet, |
83 | bool ROWWISE_SPARSE = false, |
84 | bool THREAD_LOCAL = false> |
85 | class GenEmbeddingSpMDMNBitLookup { |
86 | public: |
87 | GenEmbeddingSpMDMNBitLookup() {} |
88 | typename ReturnFunctionSignature< |
89 | indxType, |
90 | offsetType, |
91 | outType, |
92 | ROWWISE_SPARSE>::jit_embedding_kernel |
93 | getOrCreate( |
94 | int bit_rate, |
95 | int block_size, |
96 | bool has_weight, |
97 | bool is_weight_positional, |
98 | bool normalize_by_lengths, |
99 | int prefetch, |
100 | bool use_offsets, |
101 | int output_stride, |
102 | int input_stride, |
103 | bool scale_bias_last); |
104 | |
105 | private: |
106 | static asmjit::JitRuntime& runtime() { |
    static asmjit::JitRuntime rt; ///< JIT runtime for asmjit;
                                  // depends on other static
                                  // variables, so it is function-local
                                  // to prevent the static
                                  // initialization order fiasco.
111 | return rt; |
112 | } |
113 | |
  static mutex rtMutex_; ///< Controls access to runtime().
115 | |
  // The hash depends on bit_rate, embedding dimension (block size), weighted
  // sls, positional weights, normalize by lengths, prefetch distance,
  // use_offsets, output_stride, input_stride, and scale_bias_last.
119 | static CodeCache< |
120 | tuple<int, int, bool, bool, bool, int, bool, int, int, bool>, |
121 | typename ReturnFunctionSignature< |
122 | indxType, |
123 | offsetType, |
124 | outType, |
125 | ROWWISE_SPARSE>::jit_embedding_kernel, |
126 | THREAD_LOCAL> |
127 | codeCache_; ///< JIT Code Cache for reuse. |
}; // GenEmbeddingSpMDMNBitLookup
129 | |
130 | template < |
131 | typename indxType, |
132 | typename offsetType, |
133 | typename outType, |
134 | inst_set_t instSet, |
135 | bool ROWWISE_SPARSE, |
136 | bool THREAD_LOCAL> |
137 | mutex GenEmbeddingSpMDMNBitLookup< |
138 | indxType, |
139 | offsetType, |
140 | outType, |
141 | instSet, |
142 | ROWWISE_SPARSE, |
143 | THREAD_LOCAL>::rtMutex_; |
144 | |
145 | template < |
146 | typename indxType, |
147 | typename offsetType, |
148 | typename outType, |
149 | inst_set_t instSet, |
150 | bool ROWWISE_SPARSE, |
151 | bool THREAD_LOCAL> |
152 | CodeCache< |
153 | tuple<int, int, bool, bool, bool, int, bool, int, int, bool>, |
154 | typename ReturnFunctionSignature< |
155 | indxType, |
156 | offsetType, |
157 | outType, |
158 | ROWWISE_SPARSE>::jit_embedding_kernel, |
159 | THREAD_LOCAL> |
160 | GenEmbeddingSpMDMNBitLookup< |
161 | indxType, |
162 | offsetType, |
163 | outType, |
164 | instSet, |
165 | ROWWISE_SPARSE, |
166 | THREAD_LOCAL>::codeCache_; |
167 | |
168 | template < |
169 | typename indxType, |
170 | typename offsetType, |
171 | typename outType, |
172 | inst_set_t instSet, |
173 | bool ROWWISE_SPARSE, |
174 | bool THREAD_LOCAL> |
175 | typename ReturnFunctionSignature< |
176 | indxType, |
177 | offsetType, |
178 | outType, |
179 | ROWWISE_SPARSE>::jit_embedding_kernel |
180 | GenEmbeddingSpMDMNBitLookup< |
181 | indxType, |
182 | offsetType, |
183 | outType, |
184 | instSet, |
185 | ROWWISE_SPARSE, |
186 | THREAD_LOCAL>:: |
187 | getOrCreate( |
188 | int bit_rate, |
189 | int block_size, |
190 | bool has_weight, |
191 | bool is_weight_positional, |
192 | bool normalize_by_lengths, |
193 | int prefetch, |
194 | bool use_offsets, |
195 | int output_stride, |
196 | int input_stride, |
197 | bool scale_bias_last) { |
198 | tuple<int, int, bool, bool, bool, int, bool, int, int, bool> kernelSig = |
199 | make_tuple( |
200 | bit_rate, |
201 | block_size, |
202 | has_weight, |
203 | is_weight_positional, |
204 | normalize_by_lengths, |
205 | prefetch, |
206 | use_offsets, |
207 | output_stride, |
208 | input_stride, |
209 | scale_bias_last); |
210 | |
211 | return codeCache_.getOrCreate( |
212 | kernelSig, |
213 | [&]() -> typename ReturnFunctionSignature< |
214 | indxType, |
215 | offsetType, |
216 | outType, |
217 | ROWWISE_SPARSE>::jit_embedding_kernel { |
218 | // TODO: Make this tunable |
219 | int pref_dist = prefetch; |
220 | bool areIndices64b = is_same<indxType, int64_t>::value; |
221 | |
222 | asmjit::CodeHolder code; |
223 | code.init(runtime().environment()); |
224 | x86::Assembler assembler(&code); |
225 | x86::Emitter* a = assembler.as<x86::Emitter>(); |
226 | #if defined(FBGEMM_LOG_CODE) |
        string filename = "embeddinglookup_" + to_string(bit_rate) + "bit";
        filename += "_emd_dim_" + to_string(block_size);
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        if (has_weight) {
          filename += "_hasweight";
        }
        if (normalize_by_lengths) {
          filename += "_normalize_by_lengths";
        }
        if (!use_offsets) {
          filename += "_use_lengths";
        }
        if (ROWWISE_SPARSE) {
          filename += "_rowwise_sparse";
        }
        if (!scale_bias_last) {
          filename += "_scale_bias_first";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
251 | asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile); |
252 | code.setLogger(codeLogger); |
253 | #endif |
254 | // arguments to the function created |
255 | x86::Gp output_size = a->zdi(); |
256 | // index_size will be overwritten to hold the end address of indices |
257 | x86::Gp index_size = a->zsi(); |
258 | x86::Gp data_size = a->zdx(); |
259 | x86::Gp input = a->zcx(); |
260 | int reg_id = 8; |
261 | x86::Gp indices = a->gpz(reg_id); // 8 |
262 | ++reg_id; |
263 | x86::Gp lengths = a->gpz(reg_id); // 9 |
264 | ++reg_id; |
265 | x86::Gp weights = a->gpz(reg_id); // 10 |
266 | ++reg_id; |
267 | x86::Gp out = a->gpz(reg_id); // 11 |
268 | |
269 | x86::Gp compressed_indices_table; |
270 | if (ROWWISE_SPARSE) { |
271 | ++reg_id; |
272 | compressed_indices_table = a->gpz(reg_id); // 12 |
273 | } |
274 | |
275 | ++reg_id; |
276 | x86::Gp scratchReg1_ = a->gpz(reg_id); // 12 or 13 |
277 | |
278 | ++reg_id; |
279 | x86::Gpd lengths_R_ = a->gpz(reg_id).r32(); // 13 or 14 |
280 | |
281 | ++reg_id; |
282 | x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15 |
283 | x86::Gp scratchReg3_; |
284 | if (instSet == inst_set_t::avx2) { |
285 | scratchReg3_ = a->zax(); |
286 | } |
287 | |
288 | asmjit::FuncDetail func; |
289 | |
290 | if (ROWWISE_SPARSE) { |
291 | func.init( |
292 | asmjit::FuncSignatureT< |
293 | bool, |
294 | int64_t, // output_size |
295 | int64_t, // index_size |
296 | int64_t, // uncompressed_data_size |
297 | const uint8_t*, // input uint8_t or float |
298 | const indxType*, // indices |
299 | const offsetType*, // offsets or lengths |
300 | const float*, // weights |
301 | float*, // out |
302 | const int32_t* /* compressed_indices_table */, |
303 | const int* /* mask */>(asmjit::CallConvId::kHost), |
304 | a->environment()); |
305 | } else { |
306 | func.init( |
307 | asmjit::FuncSignatureT< |
308 | bool, |
309 | int64_t, // output_size |
310 | int64_t, // index_size |
311 | int64_t, // data_size |
312 | const uint8_t*, // input uint8_t or float |
313 | const indxType*, // indices |
314 | const offsetType*, // offsets or lengths |
315 | const float*, // weights |
316 | float*, // out |
317 | const int* /* mask */>(asmjit::CallConvId::kHost), |
318 | a->environment()); |
319 | } |
320 | |
321 | asmjit::FuncFrame frame; |
322 | frame.init(func); |
323 | |
324 | frame.setDirtyRegs( |
325 | asmjit::RegGroup::kVec, |
326 | asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | |
327 | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | |
328 | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | |
329 | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); |
330 | |
331 | frame.setDirtyRegs( |
332 | asmjit::RegGroup::kGp, |
333 | reg_id == 15 |
334 | ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
335 | : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); |
336 | |
337 | asmjit::FuncArgsAssignment args(&func); |
338 | if (ROWWISE_SPARSE) { |
339 | args.assignAll( |
340 | output_size, |
341 | index_size, |
342 | data_size, |
343 | input, |
344 | indices, |
345 | lengths, |
346 | weights, |
347 | out, |
348 | compressed_indices_table, |
349 | scratchReg1_); |
350 | } else { |
351 | args.assignAll( |
352 | output_size, |
353 | index_size, |
354 | data_size, |
355 | input, |
356 | indices, |
357 | lengths, |
358 | weights, |
359 | out, |
360 | scratchReg1_); |
361 | } |
362 | |
363 | args.updateFuncFrame(frame); |
364 | frame.finalize(); |
365 | |
366 | a->emitProlog(frame); |
367 | a->emitArgsAssignment(frame, args); |
368 | |
369 | constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS; |
370 | constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS; |
371 | int unroll_factor = NUM_VEC_REG; |
372 | |
373 | typedef typename simd_info<instSet>::vec_reg_t vec_reg_t; |
374 | |
375 | int num_vec_regs_per_block = ceil_div(block_size, vlen); |
376 | const int remainder = block_size % vlen; |
377 | |
        // Compute a remainder for the vector load.
        // Since every row is followed by 2 fp16 values (scale and bias),
        // we don't need a mask at bit-rate granularity, only at 32-bit
        // granularity.
382 | int num_elem_per_32bit = 32 / bit_rate; |
383 | // multiply by 4 because we're handling 4 vlen per iteration |
384 | int num_of_32bit_per_vload = vlen * 4 / num_elem_per_32bit; |
385 | int remainder_32bit_granularity = |
386 | ceil_div(block_size, num_elem_per_32bit) % num_of_32bit_per_vload; |
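        // Example: with bit_rate == 4 and AVX512 (vlen == 16),
        // num_elem_per_32bit == 8 and num_of_32bit_per_vload == 16 * 4 / 8
        // == 8; block_size == 100 packs into ceil_div(100, 8) == 13 32-bit
        // words, so remainder_32bit_granularity == 13 % 8 == 5.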
387 | |
388 | vec_reg_t scale_vreg; // holds scale |
389 | vec_reg_t bias_vreg; // holds bias |
390 | vec_reg_t w_vreg; // for weighted sls -- weights |
391 | vec_reg_t |
392 | vlen_inv_vreg; // used for normalize by lengths -- 1/ lengths[i] |
393 | vec_reg_t src_vreg; // for holding embedding value temporarily |
394 | x86::Ymm mask_vreg; // mask for avx2 |
        x86::Xmm mask2_vreg; // mask for 32-bit-granularity loads (avx2)
        x86::Xmm mask_fp16_vreg; // mask for fp16 remainder stores (avx2)
397 | |
398 | // We need 2 vec registers for 1. scale 2. bias |
399 | --unroll_factor; |
400 | scale_vreg = vec_reg_t(unroll_factor); |
401 | --unroll_factor; |
402 | bias_vreg = vec_reg_t(unroll_factor); |
403 | |
404 | --unroll_factor; |
405 | src_vreg = vec_reg_t(unroll_factor); |
406 | // temporary register for bit manipulation instructions |
407 | --unroll_factor; |
408 | vec_reg_t temp_vreg = vec_reg_t(unroll_factor); |
409 | vec_reg_t temp2_vreg; |
410 | |
411 | --unroll_factor; |
412 | temp2_vreg = vec_reg_t(unroll_factor); |
413 | |
414 | // Create a mask that extracts lower bit_rate bits from each 8-bit block |
415 | --unroll_factor; |
        vec_reg_t extract_mask_vreg = vec_reg_t(unroll_factor);
417 | a->lea( |
418 | x86::rsp, |
419 | x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t)))); |
420 | if (bit_rate == 4) { |
421 | a->mov(x86::word_ptr(x86::rsp), 0x0f0f); |
422 | a->vpbroadcastw(extract_mask_vreg, x86::word_ptr(x86::rsp)); |
423 | } else { |
424 | a->mov(x86::dword_ptr(x86::rsp), 0x03030303); |
425 | a->vpbroadcastd(extract_mask_vreg, x86::dword_ptr(x86::rsp)); |
426 | } |
427 | a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t))); |
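        // For bit_rate == 4, every 16-bit lane of extract_mask_vreg now
        // holds 0x0f0f, so vpand keeps the low nibble of each byte; for
        // bit_rate == 2, every 32-bit lane holds 0x03030303, keeping the
        // low 2 bits of each byte.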
428 | |
429 | if (has_weight) { |
430 | --unroll_factor; |
431 | w_vreg = vec_reg_t(unroll_factor); |
432 | } |
433 | |
434 | if (remainder && instSet == inst_set_t::avx2) { |
435 | // AVX512 doesn't need to use vector register for masking |
436 | --unroll_factor; |
437 | mask_vreg = x86::ymm(unroll_factor); |
438 | if (remainder > 1 && std::is_same<outType, float16>::value) { |
439 | --unroll_factor; |
440 | mask_fp16_vreg = x86::xmm(unroll_factor); |
441 | } |
442 | } |
443 | |
444 | // Creating a mask for vector load |
445 | if (remainder_32bit_granularity && instSet == inst_set_t::avx2) { |
446 | // AVX512 doesn't need to use vector register for masking |
447 | --unroll_factor; |
448 | mask2_vreg = x86::xmm(unroll_factor); |
449 | } |
450 | |
451 | if (normalize_by_lengths) { |
452 | --unroll_factor; |
453 | vlen_inv_vreg = vec_reg_t(unroll_factor); |
454 | } |
455 | |
        // Make unroll_factor a multiple of 4 because the main compute loop
        // below handles 4 vector registers per iteration.
457 | unroll_factor = unroll_factor / 4 * 4; |
458 | |
459 | if (remainder) { |
460 | if (instSet == inst_set_t::avx2) { |
461 | a->vmovups( |
462 | mask_vreg, |
463 | x86::ymmword_ptr( |
464 | scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); |
465 | if (std::is_same<outType, float16>::value) { |
466 | if (remainder > 1) { |
467 | a->vmovups( |
468 | mask_fp16_vreg, |
469 | x86::xmmword_ptr( |
470 | scratchReg1_, |
471 | (vlen - remainder / 2) * sizeof(int32_t))); |
472 | } |
473 | // We need to keep using the stack during the main loop |
474 | a->lea( |
475 | x86::rsp, |
476 | x86::dword_ptr( |
477 | x86::rsp, static_cast<int32_t>(-vlen * sizeof(int32_t)))); |
478 | } |
479 | } else { |
480 | a->mov(scratchReg1_, (1 << remainder) - 1); |
481 | a->kmovw(x86::k(1), scratchReg1_); |
482 | } |
483 | } |
484 | |
485 | if (remainder_32bit_granularity) { |
486 | if (instSet == inst_set_t::avx2) { |
487 | a->lea( |
488 | x86::rsp, |
489 | x86::dword_ptr( |
490 | x86::rsp, (int32_t)(-(vlen / 2) * sizeof(int32_t)))); |
491 | for (int i = 0; i < remainder_32bit_granularity; i++) { |
492 | a->mov(x86::dword_ptr(x86::rsp, i * sizeof(int32_t)), -1); |
493 | } |
494 | for (int i = remainder_32bit_granularity; i < vlen / 2; i++) { |
495 | a->mov(x86::dword_ptr(x86::rsp, i * sizeof(int32_t)), 0); |
496 | } |
497 | a->vmovups(mask2_vreg, x86::dword_ptr(x86::rsp)); |
498 | a->lea( |
499 | x86::rsp, |
500 | x86::dword_ptr( |
501 | x86::rsp, (int32_t)((vlen / 2) * sizeof(int32_t)))); |
502 | } else { |
503 | a->mov(scratchReg1_, (1 << remainder_32bit_granularity) - 1); |
504 | a->kmovw(x86::k(2), scratchReg1_); |
505 | } |
506 | } |
507 | |
508 | // Compute the end address of indices |
509 | a->lea( |
510 | index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2)); |
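        // (scale 3 multiplies by sizeof(int64_t) == 8; scale 2 by
        // sizeof(int32_t) == 4)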
511 | |
512 | asmjit::Label exit = a->newLabel(); |
513 | asmjit::Label error = a->newLabel(); |
514 | asmjit::Label LoopRangeIndexBegin = a->newLabel(); |
515 | asmjit::Label LoopRangeIndexEnd = a->newLabel(); |
516 | |
517 | // rangeIndex loop begins (iterate output_size times) |
518 | a->bind(LoopRangeIndexBegin); |
519 | a->dec(output_size); |
520 | a->jl(LoopRangeIndexEnd); |
521 | |
522 | if (normalize_by_lengths) { |
523 | asmjit::Label IfLengthsBegin = a->newLabel(); |
524 | asmjit::Label IfLengthsEnd = a->newLabel(); |
525 | a->bind(IfLengthsBegin); |
526 | if (use_offsets) { |
527 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
528 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
529 | } else { |
530 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
531 | } |
532 | a->cmp(lengths_R_, 1); |
533 | // Initialize vlen_inv as 0 in case lengths is 0 |
534 | a->vxorps(vlen_inv_vreg, vlen_inv_vreg, vlen_inv_vreg); |
535 | a->jl(IfLengthsEnd); |
536 | |
537 | vec_reg_t temp_vreg0(0); |
538 | if (instSet == inst_set_t::avx2) { |
539 | a->mov(scratchReg1_, 1); |
540 | a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_); |
541 | a->cvtsi2ss(temp_vreg0.xmm(), lengths_R_); |
542 | a->divss(vlen_inv_vreg.xmm(), temp_vreg0.xmm()); |
543 | a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm()); |
544 | } else { |
545 | a->mov(scratchReg1_, 1); |
546 | a->cvtsi2ss(temp_vreg0.xmm(), scratchReg1_); |
547 | a->vpbroadcastd(vlen_inv_vreg, temp_vreg0.xmm()); |
548 | a->vpbroadcastd(temp_vreg0, lengths_R_); |
549 | a->vcvtdq2ps(temp_vreg0, temp_vreg0); |
550 | a->vdivps(vlen_inv_vreg, vlen_inv_vreg, temp_vreg0); |
551 | } |
552 | a->bind(IfLengthsEnd); |
553 | } |
554 | |
555 | for (int vec_idx = 0; vec_idx < num_vec_regs_per_block; |
556 | vec_idx += unroll_factor) { |
557 | int cur_unroll_factor = |
558 | std::min(unroll_factor, num_vec_regs_per_block - vec_idx); |
559 | |
560 | // Initialize output regs |
561 | for (int v = 0; v < cur_unroll_factor; ++v) { |
562 | vec_reg_t out_vreg = vec_reg_t(v); |
563 | a->vxorps(out_vreg, out_vreg, out_vreg); |
564 | } |
565 | |
566 | if (use_offsets) { |
567 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
568 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
569 | } else { |
570 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
571 | } |
572 | |
573 | // Array out of bound check |
574 | a->lea( |
575 | scratchReg1_, |
576 | x86::ptr(indices, lengths_R_, areIndices64b ? 3 : 2)); |
577 | a->cmp(scratchReg1_, index_size); |
578 | a->jg(error); |
579 | |
580 | asmjit::Label LoopDataIndexBegin = a->newLabel(); |
581 | asmjit::Label LoopDataIndexEnd = a->newLabel(); |
582 | asmjit::Label ValidIndexLabel = a->newLabel(); |
583 | |
584 | // dataIndex loop begins (iterate lengths_R_ times) |
585 | a->bind(LoopDataIndexBegin); |
586 | a->dec(lengths_R_); |
587 | a->jl(LoopDataIndexEnd); |
588 | |
589 | // Array out of bound check |
590 | if (areIndices64b) { |
591 | a->mov(scratchReg1_, x86::qword_ptr(indices)); |
592 | } else { |
593 | a->mov(scratchReg1_.r32(), x86::dword_ptr(indices)); |
594 | } |
595 | if (!scale_bias_last) { |
596 | // When scale_bias_last == false, assume this is for table batched |
597 | // embedding (TBE) that can get -1 for pruned rows. |
598 | if (areIndices64b) { |
599 | a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1)); |
600 | } else { |
601 | a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1)); |
602 | } |
603 | a->jne(ValidIndexLabel); |
604 | a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType))); |
605 | if (has_weight) { |
606 | a->add(weights, static_cast<asmjit::Imm>(sizeof(float))); |
607 | } |
608 | a->jmp(LoopDataIndexBegin); |
609 | a->bind(ValidIndexLabel); |
610 | } |
611 | // A trick to check x >= data_size or x < 0 in one shot by treating |
612 | // scratchReg1_ as if it has unsigned value |
613 | // (https://stackoverflow.com/a/34072155). |
614 | a->cmp(scratchReg1_, data_size); |
615 | a->jae(error); |
616 | |
617 | if (ROWWISE_SPARSE) { |
618 | a->mov( |
619 | scratchReg1_.r32(), |
620 | x86::dword_ptr( |
621 | compressed_indices_table, |
622 | scratchReg1_, |
623 | 2)); // use of 2 is to multiply by 4 |
624 | } |
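          // compressed_indices_table maps an uncompressed row id to its
          // compressed row id, or -1 for a pruned row; the -1 case is
          // skipped just before the imul by fused_block_size below.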
625 | |
626 | int num_elem_per_byte = 8 / bit_rate; |
627 | int fused_block_size = input_stride; |
628 | if (pref_dist) { |
629 | asmjit::Label pref_dist_reset_start = a->newLabel(); |
630 | asmjit::Label pref_dist_reset_end = a->newLabel(); |
631 | // out of bound handling for prefetch |
632 | a->lea( |
633 | scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType))); |
634 | a->cmp(scratchReg2_, index_size); |
635 | a->jge(pref_dist_reset_start); |
636 | |
637 | if (areIndices64b) { |
638 | a->mov( |
639 | scratchReg2_, |
640 | x86::qword_ptr(indices, pref_dist * sizeof(indxType))); |
641 | } else { |
642 | a->mov( |
643 | scratchReg2_.r32(), |
644 | x86::dword_ptr(indices, pref_dist * sizeof(indxType))); |
645 | } |
646 | |
647 | a->jmp(pref_dist_reset_end); |
648 | |
649 | a->bind(pref_dist_reset_start); |
            // Out of bounds, so just prefetch the current row;
            // this could be improved by prefetching the farthest valid row.
652 | if (areIndices64b) { |
653 | a->mov(scratchReg2_, x86::qword_ptr(indices)); |
654 | } else { |
655 | a->mov(scratchReg2_.r32(), x86::dword_ptr(indices)); |
656 | } |
657 | |
658 | a->bind(pref_dist_reset_end); |
659 | if (ROWWISE_SPARSE) { |
660 | asmjit::Label rowwise_sparse_pref_corner_case_begin = |
661 | a->newLabel(); |
662 | asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel(); |
663 | a->cmp(scratchReg2_, data_size); |
664 | a->jae(rowwise_sparse_pref_corner_case_begin); |
665 | |
666 | a->mov( |
667 | scratchReg2_.r32(), |
668 | x86::dword_ptr( |
669 | compressed_indices_table, |
670 | scratchReg2_, |
671 | 2)); // use of 2 is to multiply by 4 |
672 | a->test(scratchReg2_.r32(), scratchReg2_.r32()); |
673 | // Check negative |
674 | a->jns(rowwise_sparse_pref_corner_case_end); |
675 | |
676 | a->bind(rowwise_sparse_pref_corner_case_begin); |
677 | // For corner case, just set prefetch row id to 0. |
678 | a->xor_(scratchReg2_.r32(), scratchReg2_.r32()); |
679 | a->bind(rowwise_sparse_pref_corner_case_end); |
680 | } |
          // This has to be fused_block_size (the full per-row byte stride).
682 | a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size)); |
683 | } |
684 | |
685 | a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType))); |
686 | |
687 | if (has_weight) { |
688 | a->vbroadcastss(w_vreg, x86::dword_ptr(weights)); |
689 | a->add(weights, static_cast<asmjit::Imm>(sizeof(float))); |
690 | } |
691 | |
692 | if (ROWWISE_SPARSE) { |
693 | a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1)); |
694 | a->je(LoopDataIndexBegin); |
695 | } |
696 | |
697 | a->imul(scratchReg1_, static_cast<asmjit::Imm>(fused_block_size)); |
698 | |
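          // Each row stores ceil_div(block_size, num_elem_per_byte)
          // quantized bytes plus one fp16 scale and one fp16 bias;
          // scale_bias_last selects whether the fp16 pair follows (true) or
          // precedes (false) the quantized data.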
699 | // broadcast the scale |
700 | x86::Mem scale_src, bias_src; |
701 | int scale_offset = |
702 | scale_bias_last ? ceil_div(block_size, num_elem_per_byte) : 0; |
703 | scale_src = x86::word_ptr(input, scratchReg1_, 0, scale_offset); |
704 | bias_src = x86::word_ptr( |
705 | input, scratchReg1_, 0, scale_offset + sizeof(float16)); |
706 | a->vpbroadcastw(scale_vreg.half(), scale_src); |
707 | a->vpbroadcastw(bias_vreg.half(), bias_src); |
708 | a->vcvtph2ps(scale_vreg, scale_vreg.half()); |
709 | a->vcvtph2ps(bias_vreg, bias_vreg.half()); |
710 | constexpr unsigned int CACHE_LINE_LEN = 64; |
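          // If only the fp16 scale/bias spill past the last full cache line
          // of the row, the strided prefetches in the main loop would miss
          // that line, so prefetch it explicitly.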
711 | if (pref_dist && fused_block_size % CACHE_LINE_LEN > 0 && |
712 | fused_block_size % CACHE_LINE_LEN <= 2 * sizeof(float16)) { |
713 | a->prefetcht0(x86::dword_ptr( |
714 | input, |
715 | scratchReg2_, |
716 | 0, |
717 | fused_block_size / CACHE_LINE_LEN * CACHE_LINE_LEN)); |
718 | } |
719 | |
720 | if (has_weight) { |
721 | a->vmulps(scale_vreg, scale_vreg, w_vreg); |
722 | a->vmulps(bias_vreg, bias_vreg, w_vreg); |
723 | } |
724 | |
725 | // The main computation |
726 | // Handling 4 vector registers per iteration because |
727 | // 1) when bit_rate == 4, we get zmm from ymm load via vpmovzxbw |
728 | // (epu8->epi16), and then get 4 zmms from each 128-bit portion of |
729 | // zmm via vpmovsxbd (epi8->epi32). |
730 | // 2) when bit_rate == 2, we get zmm from xmm load via vpmovzxbd |
731 | // (epu8->epi32), and then get 4 zmms from each 128-bit portion of |
732 | // zmm via vpmovsxbd (epi8->epi32). |
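          // Example for bit_rate == 2: a source byte 0b11100100 packs
          // elements (e0, e1, e2, e3) = (0, 1, 2, 3). After vpmovzxbd, the
          // shifts by 18, 12, and 6 below, the ORs, and the AND with
          // 0x03030303, the 32-bit lane holds 0x03020100 -- one element per
          // byte, ready for vpmovsxbd (int8 -> int32).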
733 | int src_addr_offset = scale_bias_last ? 0 : 2 * sizeof(float16); |
734 | for (int v = 0; v < cur_unroll_factor; v += 4) { |
735 | int bytes_per_vload = (vlen / num_elem_per_byte) * sizeof(uint8_t); |
736 | auto src_addr = x86::dword_ptr( |
737 | input, |
738 | scratchReg1_, |
739 | 0, |
740 | src_addr_offset + (vec_idx + v) * bytes_per_vload); |
741 | |
742 | if (bit_rate == 4) { |
743 | if (num_vec_regs_per_block - (vec_idx + v) < 4 && |
744 | remainder_32bit_granularity) { |
745 | if (instSet == inst_set_t::avx512) { |
746 | a->k(x86::k(2)).vmovups(src_vreg.ymm(), src_addr); |
747 | } else { |
748 | a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr); |
749 | } |
750 | a->vpmovzxbw(src_vreg, src_vreg.half()); |
751 | } else { |
752 | a->vpmovzxbw(src_vreg, src_addr); |
753 | } |
754 | a->vpslld(temp_vreg, src_vreg, asmjit::Imm(4)); |
755 | if (instSet == inst_set_t::avx512) { |
756 | a->vpord(src_vreg, src_vreg, temp_vreg); |
757 | a->vpandd(src_vreg, src_vreg, extract_mask_vreg); |
758 | } else { |
759 | a->vpor(src_vreg.ymm(), src_vreg.ymm(), temp_vreg.ymm()); |
760 | a->vpand( |
761 | src_vreg.ymm(), src_vreg.ymm(), extract_mask_vreg.ymm()); |
762 | } |
763 | } else { |
764 | if (num_vec_regs_per_block - (vec_idx + v) < 4 && |
765 | remainder_32bit_granularity) { |
766 | if (instSet == inst_set_t::avx512) { |
767 | a->k(x86::k(2)).vmovups(src_vreg.xmm(), src_addr); |
768 | a->vpmovzxbd(src_vreg, src_vreg.xmm()); |
769 | } else { |
770 | a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr); |
771 | a->vpmovzxbd(src_vreg, src_vreg.xmm()); |
772 | } |
773 | } else { |
774 | a->vpmovzxbd(src_vreg, src_addr); |
775 | } |
776 | a->vpslld(temp_vreg, src_vreg, 2 * 8 + 2); |
777 | a->vpslld(temp2_vreg, src_vreg, 8 + 4); |
778 | if (instSet == inst_set_t::avx512) { |
779 | a->vpord(temp_vreg, temp_vreg, temp2_vreg); |
780 | } else { |
781 | a->vpor(temp_vreg.ymm(), temp_vreg.ymm(), temp2_vreg.ymm()); |
782 | } |
783 | a->vpslld(temp2_vreg, src_vreg, 6); |
784 | if (instSet == inst_set_t::avx512) { |
785 | a->vpord(temp_vreg, temp_vreg, temp2_vreg); |
786 | a->vpord(src_vreg, temp_vreg, src_vreg); |
787 | a->vpandd(src_vreg, src_vreg, extract_mask_vreg); |
788 | } else { |
789 | a->vpor(temp_vreg.ymm(), temp_vreg.ymm(), temp2_vreg.ymm()); |
790 | a->vpor(src_vreg.ymm(), temp_vreg.ymm(), src_vreg.ymm()); |
791 | a->vpand( |
792 | src_vreg.ymm(), src_vreg.ymm(), extract_mask_vreg.ymm()); |
793 | } |
794 | } |
795 | |
796 | // AVX2: For the following loop, operations on src_vreg impact the |
797 | // next iteration. For i = 0, we make a copy. i = 1 just right |
798 | // shifts and uses it. i = 2 we extract upper 128 bits from the copy |
799 | // to src_vreg and use it. i = 3 just right shifts it and uses it. |
800 | for (int i = 0; |
801 | i < std::min(4, num_vec_regs_per_block - (vec_idx + v)); |
802 | ++i) { |
803 | vec_reg_t out_vreg = vec_reg_t(v + i); |
804 | if (i == 0) { |
805 | a->vpmovsxbd(temp_vreg, src_vreg.xmm()); |
806 | // this is only needed for avx2 |
807 | if (instSet == inst_set_t::avx2) { |
808 | a->vmovups(temp2_vreg, src_vreg); |
809 | } |
810 | } else { |
811 | if (instSet == inst_set_t::avx512) { |
                // We could've used avx512_ymm for a clock frequency
                // advantage if there were an instruction to extract a
                // 64-bit portion from a YMM as an XMM register.
815 | a->vextracti32x4(temp_vreg.xmm(), src_vreg, asmjit::Imm(i)); |
816 | a->vpmovsxbd(temp_vreg, temp_vreg.xmm()); |
817 | } else { |
818 | if (i == 1) { |
819 | a->vpsrldq(src_vreg, src_vreg, asmjit::Imm(8)); |
820 | } else if (i == 2) { |
821 | a->vextractf128( |
822 | src_vreg.xmm(), temp2_vreg.ymm(), asmjit::Imm(i >> 1)); |
823 | } else { |
824 | a->vpsrldq(src_vreg, src_vreg, asmjit::Imm(8)); |
825 | } |
826 | a->vpmovsxbd(temp_vreg, src_vreg.xmm()); |
827 | } // avx2 |
828 | } // i > 0 |
829 | a->vcvtdq2ps(temp_vreg, temp_vreg); |
830 | a->vaddps(out_vreg, out_vreg, bias_vreg); |
831 | a->vfmadd231ps(out_vreg, temp_vreg, scale_vreg); |
832 | } // for each i |
833 | |
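            // Prefetch the next row once per cache line of loads; e.g.,
            // AVX2 with bit_rate == 4 loads 4 bytes per vload, so
            // 64 / 4 == 16 vloads share a cache line and only the first
            // issues a prefetch.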
834 | int vload_per_cache_line = CACHE_LINE_LEN / bytes_per_vload; |
835 | int v_aligned = ceil_div(vec_idx + v, 4) * 4; |
836 | if (pref_dist && v_aligned % vload_per_cache_line == 0) { |
837 | a->prefetcht0(x86::dword_ptr( |
838 | input, scratchReg2_, 0, v_aligned * bytes_per_vload)); |
839 | } |
840 | } |
841 | |
842 | a->jmp(LoopDataIndexBegin); |
843 | a->bind(LoopDataIndexEnd); |
844 | |
          // Write the accumulated out_vreg results back to memory
847 | for (int v = 0; v < cur_unroll_factor; ++v) { |
848 | auto dst_addr = |
849 | x86::dword_ptr(out, (vec_idx + v) * vlen * sizeof(outType)); |
850 | vec_reg_t out_vreg = vec_reg_t(v); |
851 | |
852 | if (normalize_by_lengths) { |
853 | a->vmulps(out_vreg, out_vreg, vlen_inv_vreg); |
854 | } |
855 | |
856 | if (std::is_same<outType, float>::value) { |
857 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
858 | if (instSet == inst_set_t::avx512) { |
859 | a->k(x86::k(1)).vmovups(dst_addr, out_vreg); |
860 | } else { |
861 | a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm()); |
862 | } |
863 | } else { |
864 | a->vmovups(dst_addr, out_vreg); |
865 | } |
866 | } else { |
867 | // fp16 output |
868 | if (instSet == inst_set_t::avx2) { |
869 | // round nearest with no exception |
870 | a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8); |
871 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
872 | if (remainder > 1) { |
873 | a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); |
874 | } |
875 | if (remainder % 2 != 0) { |
876 | a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm()); |
877 | a->mov( |
878 | scratchReg1_.r16(), |
879 | x86::word_ptr( |
880 | x86::rsp, (remainder - 1) * sizeof(outType))); |
881 | a->mov( |
882 | x86::word_ptr( |
883 | out, |
884 | ((vec_idx + v) * vlen + (remainder - 1)) * |
885 | sizeof(outType)), |
886 | scratchReg1_.r16()); |
887 | } |
888 | } else { |
889 | a->vmovups(dst_addr, out_vreg.xmm()); |
890 | } |
891 | } else { |
892 | if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { |
893 | a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); |
894 | } else { |
895 | a->vcvtps2ph(dst_addr, out_vreg, 8); |
896 | } |
897 | } |
898 | } |
899 | } |
900 | |
901 | if (vec_idx + unroll_factor < num_vec_regs_per_block || |
902 | (has_weight && is_weight_positional)) { |
903 | // Reset lengths_R_, indices, weights to run the dataIndex loop |
904 | // again |
905 | if (use_offsets) { |
906 | a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); |
907 | a->sub(lengths_R_, x86::dword_ptr(lengths)); |
908 | } else { |
909 | a->mov(lengths_R_, x86::dword_ptr(lengths)); |
910 | } |
911 | |
912 | if (has_weight) { |
913 | a->imul( |
914 | scratchReg1_, |
915 | lengths_R_, |
916 | static_cast<asmjit::Imm>(sizeof(float))); |
917 | a->sub(weights, scratchReg1_); |
918 | |
919 | if (vec_idx + unroll_factor < num_vec_regs_per_block) { |
920 | a->imul( |
921 | scratchReg1_, |
922 | static_cast<asmjit::Imm>(sizeof(indxType) / sizeof(float))); |
923 | a->sub(indices, scratchReg1_); |
924 | } |
925 | } else { |
926 | a->imul( |
927 | scratchReg1_, |
928 | lengths_R_, |
929 | static_cast<asmjit::Imm>(sizeof(indxType))); |
930 | a->sub(indices, scratchReg1_); |
931 | } |
932 | } |
933 | } |
934 | |
935 | a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType))); |
936 | a->add(out, static_cast<asmjit::Imm>(output_stride * sizeof(outType))); |
937 | |
938 | a->jmp(LoopRangeIndexBegin); |
939 | a->bind(LoopRangeIndexEnd); |
940 | |
941 | a->cmp(indices, index_size); |
942 | a->jne(error); |
943 | a->mov(x86::eax, true); |
944 | a->jmp(exit); |
945 | a->bind(error); |
946 | a->mov(x86::eax, false); |
947 | a->bind(exit); |
948 | |
949 | if (remainder && instSet == inst_set_t::avx2 && |
950 | std::is_same<outType, float16>::value) { |
951 | a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t))); |
952 | } |
953 | |
954 | a->emitEpilog(frame); |
955 | |
956 | // jit_fused8bitembedding_kernel fn; |
957 | typename ReturnFunctionSignature< |
958 | indxType, |
959 | offsetType, |
960 | outType, |
961 | ROWWISE_SPARSE>::jit_embedding_kernel fn; |
962 | asmjit::Error err; |
963 | { |
964 | unique_lock<mutex> lock(rtMutex_); |
965 | err = runtime().add(&fn, &code); |
966 | } |
967 | if (err) { |
968 | cout << "Error: in fn add" << endl; |
969 | return nullptr; |
970 | } |
971 | |
972 | #if defined(FBGEMM_LOG_CODE) |
973 | fclose(codeLogFile); |
974 | delete codeLogger; |
975 | #endif |
976 | return fn; |
977 | }); |
978 | } |
979 | |
980 | } // namespace |
981 | |
982 | template < |
983 | typename indxType, |
984 | typename offsetType, |
985 | typename outType, |
986 | bool THREAD_LOCAL> |
987 | typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>:: |
988 | Type |
989 | GenerateEmbeddingSpMDMNBitWithStrides( |
990 | int bit_rate, |
991 | const int64_t block_size, |
992 | bool has_weight, |
993 | bool normalize_by_lengths, |
994 | int prefetch, |
995 | bool is_weight_positional, |
996 | bool use_offsets, |
997 | int64_t output_stride /*=-1*/, |
998 | int64_t input_stride /*=-1*/, |
999 | bool scale_bias_last /*=true*/) { |
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1001 | |
1002 | if (!cpuinfo_initialize()) { |
    throw runtime_error("Failed to initialize cpuinfo!");
1004 | } |
1005 | if (output_stride == -1) { |
1006 | output_stride = block_size; |
1007 | } |
1008 | if (input_stride == -1) { |
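    // Default stride: packed row bytes plus the fp16 scale and bias, e.g.,
    // bit_rate == 4 with block_size == 100 gives 50 + 4 == 54 bytes.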
1009 | int64_t num_elem_per_byte = 8 / bit_rate; |
1010 | input_stride = |
1011 | ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(float16); |
1012 | } |
1013 | if (fbgemmHasAvx512Support()) { |
1014 | static GenEmbeddingSpMDMNBitLookup< |
1015 | indxType, |
1016 | offsetType, |
1017 | outType, |
1018 | inst_set_t::avx512, |
1019 | /*ROWWISE_SPARSE=*/false, |
1020 | THREAD_LOCAL> |
1021 | kernel_generator; |
1022 | const auto original_func = kernel_generator.getOrCreate( |
1023 | bit_rate, |
1024 | block_size, |
1025 | has_weight, |
1026 | is_weight_positional, |
1027 | normalize_by_lengths, |
1028 | prefetch, |
1029 | use_offsets, |
1030 | output_stride, |
1031 | input_stride, |
1032 | scale_bias_last); |
1033 | return [=](int64_t output_size, |
1034 | int64_t index_size, |
1035 | int64_t data_size, |
1036 | const uint8_t* input, |
1037 | const indxType* indices, |
1038 | const offsetType* offsets_or_lengths, |
1039 | const float* weights, |
1040 | outType* out) { |
1041 | return original_func( |
1042 | output_size, |
1043 | index_size, |
1044 | data_size, |
1045 | input, |
1046 | indices, |
1047 | offsets_or_lengths, |
1048 | weights, |
1049 | out, |
1050 | nullptr /* mask not used in avx512 */); |
1051 | }; |
1052 | } else if (fbgemmHasAvx2Support()) { |
1053 | static GenEmbeddingSpMDMNBitLookup< |
1054 | indxType, |
1055 | offsetType, |
1056 | outType, |
1057 | inst_set_t::avx2, |
1058 | /*ROWWISE_SPARSE=*/false, |
1059 | THREAD_LOCAL> |
1060 | kernel_generator; |
1061 | const auto original_func = kernel_generator.getOrCreate( |
1062 | bit_rate, |
1063 | block_size, |
1064 | has_weight, |
1065 | is_weight_positional, |
1066 | normalize_by_lengths, |
1067 | prefetch, |
1068 | use_offsets, |
1069 | output_stride, |
1070 | input_stride, |
1071 | scale_bias_last); |
1072 | return [=](int64_t output_size, |
1073 | int64_t index_size, |
1074 | int64_t data_size, |
1075 | const uint8_t* input, |
1076 | const indxType* indices, |
1077 | const offsetType* offsets_or_lengths, |
1078 | const float* weights, |
1079 | outType* out) { |
1080 | return original_func( |
1081 | output_size, |
1082 | index_size, |
1083 | data_size, |
1084 | input, |
1085 | indices, |
1086 | offsets_or_lengths, |
1087 | weights, |
1088 | out, |
1089 | internal::avx2_ps_or_epi32_combined_mask); |
1090 | }; |
1091 | } else { |
1092 | #ifdef VLOG |
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1094 | #endif |
1095 | return [=](int64_t output_size, |
1096 | int64_t index_size, |
1097 | int64_t data_size, |
1098 | const uint8_t* input, |
1099 | const indxType* indices, |
1100 | const offsetType* offsets_or_lengths, |
1101 | const float* weights, |
1102 | outType* out) { |
1103 | return EmbeddingSpMDMNBit_ref( |
1104 | bit_rate, |
1105 | block_size, |
1106 | output_size, |
1107 | index_size, |
1108 | data_size, |
1109 | input, |
1110 | indices, |
1111 | offsets_or_lengths, |
1112 | weights, |
1113 | normalize_by_lengths, |
1114 | out, |
1115 | is_weight_positional, |
1116 | use_offsets, |
1117 | output_stride, |
1118 | input_stride, |
1119 | scale_bias_last); |
1120 | }; |
1121 | } |
1122 | } |
1123 | |
1124 | template <typename IndexType, typename OffsetType, typename OutType> |
1125 | FBGEMM_API typename EmbeddingSpMDMKernelSignature< |
1126 | std::uint8_t, |
1127 | IndexType, |
1128 | OffsetType, |
1129 | OutType>::Type |
1130 | GenerateEmbeddingSpMDMNBit( |
1131 | int bit_rate, |
1132 | const std::int64_t block_size, |
1133 | bool has_weight, |
1134 | bool normalize_by_lengths, |
1135 | int prefetch, |
1136 | bool is_weight_positional, |
1137 | bool use_offsets) { |
1138 | return GenerateEmbeddingSpMDMNBitWithStrides<IndexType, OffsetType, OutType>( |
1139 | bit_rate, |
1140 | block_size, |
1141 | has_weight, |
1142 | normalize_by_lengths, |
1143 | prefetch, |
1144 | is_weight_positional, |
1145 | use_offsets); |
1146 | } |
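
// Illustrative caller-side sketch (names such as `fused_4bit_table`,
// `indices`, and `offsets` are hypothetical; buffer sizing is the caller's
// responsibility):
//
//   auto kernel = GenerateEmbeddingSpMDMNBit<int64_t, int32_t, float>(
//       /*bit_rate=*/4, /*block_size=*/64, /*has_weight=*/false,
//       /*normalize_by_lengths=*/false, /*prefetch=*/16,
//       /*is_weight_positional=*/false, /*use_offsets=*/true);
//   bool ok = kernel(
//       output_size, index_size, data_size, fused_4bit_table,
//       indices, offsets, /*weights=*/nullptr, output);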
1147 | |
1148 | template <typename indxType, typename offsetType> |
1149 | typename EmbeddingSpMDMRowWiseSparseKernelSignature< |
1150 | uint8_t, |
1151 | indxType, |
1152 | offsetType>::Type |
1153 | GenerateEmbeddingSpMDMNBitRowWiseSparse( |
1154 | int bit_rate, |
1155 | const int64_t block_size, |
1156 | bool has_weight, |
1157 | bool normalize_by_lengths, |
1158 | int prefetch, |
1159 | bool is_weight_positional, |
1160 | bool use_offsets) { |
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
1162 | |
1163 | if (!cpuinfo_initialize()) { |
    throw runtime_error("Failed to initialize cpuinfo!");
1165 | } |
1166 | int64_t num_elem_per_byte = 8 / bit_rate; |
1167 | int64_t input_stride = |
1168 | ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(float16); |
1169 | if (fbgemmHasAvx512Support()) { |
1170 | static GenEmbeddingSpMDMNBitLookup< |
1171 | indxType, |
1172 | offsetType, |
1173 | /*outType=*/float, |
1174 | inst_set_t::avx512, |
1175 | /*rowwise_sparse=*/true> |
1176 | kernel_generator; |
1177 | const auto original_func = kernel_generator.getOrCreate( |
1178 | bit_rate, |
1179 | block_size, |
1180 | has_weight, |
1181 | is_weight_positional, |
1182 | normalize_by_lengths, |
1183 | prefetch, |
1184 | use_offsets, |
1185 | /*output_stride=*/block_size, |
1186 | input_stride, |
1187 | /*scale_bias_last=*/true); |
1188 | return [=](int64_t output_size, |
1189 | int64_t index_size, |
1190 | int64_t uncompressed_data_size, |
1191 | const uint8_t* input, |
1192 | const indxType* indices, |
1193 | const offsetType* offsets_or_lengths, |
1194 | const float* weights, |
1195 | float* out, |
1196 | const int32_t* compressed_indices_table) { |
1197 | return original_func( |
1198 | output_size, |
1199 | index_size, |
1200 | uncompressed_data_size, |
1201 | input, |
1202 | indices, |
1203 | offsets_or_lengths, |
1204 | weights, |
1205 | out, |
1206 | compressed_indices_table, |
1207 | nullptr /* mask not used in avx512 */); |
1208 | }; |
1209 | } else if (fbgemmHasAvx2Support()) { |
1210 | static GenEmbeddingSpMDMNBitLookup< |
1211 | indxType, |
1212 | offsetType, |
1213 | /*outType=*/float, |
1214 | inst_set_t::avx2, |
1215 | /*rowwise_sparse=*/true> |
1216 | kernel_generator; |
1217 | const auto original_func = kernel_generator.getOrCreate( |
1218 | bit_rate, |
1219 | block_size, |
1220 | has_weight, |
1221 | is_weight_positional, |
1222 | normalize_by_lengths, |
1223 | prefetch, |
1224 | use_offsets, |
1225 | /*output_stride=*/block_size, |
1226 | input_stride, |
1227 | /*scale_bias_last=*/true); |
1228 | return [=](int64_t output_size, |
1229 | int64_t index_size, |
1230 | int64_t uncompressed_data_size, |
1231 | const uint8_t* input, |
1232 | const indxType* indices, |
1233 | const offsetType* offsets_or_lengths, |
1234 | const float* weights, |
1235 | float* out, |
1236 | const int32_t* compressed_indices_table) { |
1237 | return original_func( |
1238 | output_size, |
1239 | index_size, |
1240 | uncompressed_data_size, |
1241 | input, |
1242 | indices, |
1243 | offsets_or_lengths, |
1244 | weights, |
1245 | out, |
1246 | compressed_indices_table, |
1247 | internal::avx2_ps_or_epi32_combined_mask); |
1248 | }; |
1249 | } else { |
1250 | #ifdef VLOG |
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
1252 | #endif |
1253 | return [=](int64_t output_size, |
1254 | int64_t index_size, |
1255 | int64_t uncompressed_data_size, |
1256 | const uint8_t* input, |
1257 | const indxType* indices, |
1258 | const offsetType* offsets_or_lengths, |
1259 | const float* weights, |
1260 | float* out, |
1261 | const int32_t* compressed_indices_table) { |
1262 | return EmbeddingSpMDMNBitRowWiseSparse_ref( |
1263 | bit_rate, |
1264 | block_size, |
1265 | output_size, |
1266 | index_size, |
1267 | uncompressed_data_size, |
1268 | // compressed_data_size, |
1269 | input, |
1270 | indices, |
1271 | compressed_indices_table, |
1272 | offsets_or_lengths, |
1273 | weights, |
1274 | normalize_by_lengths, |
1275 | out, |
1276 | is_weight_positional, |
1277 | use_offsets); |
1278 | }; |
1279 | } |
1280 | } |
1281 | |
1282 | #define INSTANTIATE_SPMDM_BASE( \ |
1283 | INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ |
1284 | template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ |
1285 | uint8_t, \ |
1286 | INDEX_TYPE, \ |
1287 | OFFSET_TYPE, \ |
1288 | OUT_TYPE>::Type \ |
1289 | GenerateEmbeddingSpMDMNBitWithStrides< \ |
1290 | INDEX_TYPE, \ |
1291 | OFFSET_TYPE, \ |
1292 | OUT_TYPE, \ |
1293 | THREAD_LOCAL>( \ |
1294 | int bit_rate, \ |
1295 | const int64_t block_size, \ |
1296 | bool has_weight, \ |
1297 | bool normalize_by_lengths, \ |
1298 | int prefetch, \ |
1299 | bool is_weight_positional, \ |
1300 | bool use_offsets, \ |
1301 | int64_t output_stride, \ |
1302 | int64_t input_stride, \ |
1303 | bool scale_bias_last); |
1304 | |
1305 | #define INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ |
1306 | INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ |
1307 | INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \ |
1308 | template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ |
1309 | uint8_t, \ |
1310 | INDEX_TYPE, \ |
1311 | OFFSET_TYPE, \ |
1312 | OUT_TYPE>::Type \ |
1313 | GenerateEmbeddingSpMDMNBit<INDEX_TYPE, OFFSET_TYPE, OUT_TYPE>( \ |
1314 | int bit_rate, \ |
1315 | const int64_t block_size, \ |
1316 | bool has_weight, \ |
1317 | bool normalize_by_lengths, \ |
1318 | int prefetch, \ |
1319 | bool is_weight_positional, \ |
1320 | bool use_offsets); |
1321 | |
1322 | #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ |
1323 | INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \ |
1324 | INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float16) \ |
1325 | template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ |
1326 | uint8_t, \ |
1327 | INDEX_TYPE, \ |
1328 | OFFSET_TYPE>::Type \ |
1329 | GenerateEmbeddingSpMDMNBitRowWiseSparse<INDEX_TYPE, OFFSET_TYPE>( \ |
1330 | int bit_rate, \ |
1331 | const int64_t block_size, \ |
1332 | bool has_weight, \ |
1333 | bool normalize_by_lengths, \ |
1334 | int prefetch, \ |
1335 | bool is_weight_positional, \ |
1336 | bool use_offsets); |
1337 | |
1338 | #define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \ |
1339 | INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \ |
1340 | INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t) |
1341 | |
1342 | INSTANTIATE_SPMDM_OFFSET_T(int32_t) |
1343 | INSTANTIATE_SPMDM_OFFSET_T(int64_t) |
1344 | |
1345 | #undef INSTANTIATE_SPMDM_OFFSET_T |
1346 | #undef INSTANTIATE_SPMDM_OUT_T |
1347 | #undef INSTANTIATE_SPMDM_THREAD_LOCAL |
1348 | #undef INSTANTIATE_SPMDM_BASE |
1349 | |
1350 | } // namespace fbgemm |
1351 | |