/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmEmbedding.h"

#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cassert>
#include <iostream>
#include <mutex>
#include <thread> // for std::this_thread::get_id() in rand_initialize()
#include "./CodeCache.h"
#include "./MaskAvx2.h"
#include "./RefImplementations.h"
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"

using namespace std;

namespace fbgemm {
namespace {
namespace x86 = asmjit::x86;

template <typename indxType, typename offsetType, typename dataType>
class ReturnFunctionSignature {
 public:
  using jit_sparse_adagrad_kernel = bool (*)(
      int64_t output_size,
      int64_t index_size,
      int64_t data_size, // number of rows in w
      dataType* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      const offsetType* offsets_or_lengths,
      float epsilon,
      float lr,
      uint32_t* rand_buffer);
};
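
// For reference, the generated kernel performs the following rowwise AdaGrad
// update for each output row (this is a summary of the JIT'ed code below, not
// a separate code path):
//
//   final_sum = sum_j g[j]^2 / block_size;
//   h[idx] += final_sum;
//   float_step = lr / (sqrt(h[idx]) + epsilon);
//   w[idx][j] += float_step * g[j];
//
// For float16 weights the last step is done in FP32 and converted back with
// truncation, optionally after adding a random value for stochastic rounding.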

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet = inst_set_t::avx2>
class GenRowWiseSparseAdagradFused {
 public:
  GenRowWiseSparseAdagradFused() {}

  typename ReturnFunctionSignature<indxType, offsetType, dataType>::
      jit_sparse_adagrad_kernel
      getOrCreate(
          const int* mask_avx2,
          int block_size,
          int prefetch,
          bool use_offsets,
          bool use_stochastic_rounding,
          int grad_stride);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; // JIT Runtime for asmjit
    return rt;
  }

  static mutex rtMutex_; ///< Controls access to runtime().

  // The hash depends on:
  // avx2 mask array, embedding dimension (block size), prefetch distance,
  // use_offsets and use_stochastic_rounding switches
  static CodeCache<
      tuple<const int*, int, int, bool, bool, int>,
      typename ReturnFunctionSignature<indxType, offsetType, dataType>::
          jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.
}; // class GenRowWiseSparseAdagradFused

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
mutex GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
    rtMutex_;

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
CodeCache<
    tuple<const int*, int, int, bool, bool, int>,
    typename ReturnFunctionSignature<indxType, offsetType, dataType>::
        jit_sparse_adagrad_kernel>
    GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
        codeCache_;

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
typename ReturnFunctionSignature<indxType, offsetType, dataType>::
    jit_sparse_adagrad_kernel
    GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
        getOrCreate(
            const int* mask_avx2, // runtime constant
            int block_size,
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride) {
  tuple<const int*, int, int, bool, bool, int> kernelSig = make_tuple(
      mask_avx2,
      block_size,
      prefetch,
      use_offsets,
      use_stochastic_rounding,
      grad_stride);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() -> typename ReturnFunctionSignature<
                   indxType,
                   offsetType,
                   dataType>::jit_sparse_adagrad_kernel {
        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
        bool areIndices64b = is_same<indxType, int64_t>::value;
        bool areWeightsFp16 = is_same<dataType, float16>::value;
#if defined(FBGEMM_LOG_CODE)
        string filename = "RowWiseSparseAdagradFused";
        filename += "_emd_dim_" + to_string(block_size);
        filename += "_wei_float";
        filename += areWeightsFp16 ? "16" : "32";
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif

        x86::Gp rand_buffer = a->zax();
        x86::Gp output_size = a->zdi();
        x86::Gp index_size = a->zsi();
        x86::Gp data_size = a->zdx();
        x86::Gp w = a->zcx();
        x86::Gp g = a->gpz(8);
        x86::Gp h = a->gpz(9);
        x86::Gp indices = a->gpz(10);
        x86::Gp lengths = a->gpz(11);
        x86::Xmm epsilon(0);
        x86::Xmm lr(1);
        x86::Gpd lengths_R = a->gpz(12).r32();
        x86::Gp scratchReg1 = a->gpz(13);
        x86::Gp scratchReg2 = a->gpz(14); // for prefetching

        asmjit::FuncDetail func;
        func.init(
            asmjit::FuncSignatureT<
                bool, // return type
                int64_t, // output_size
                int64_t, // index_size
                int64_t, // data_size
                dataType*, // w
                const float*, // g
                float*, // h
                const indxType*, // indices
                const int*, // lengths
                float, // epsilon
                float, // lr, then rand_buffer
                uint32_t*>(asmjit::CallConvId::kHost),
            a->environment());

        asmjit::FuncFrame frame;
        frame.init(func);

        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

        asmjit::FuncArgsAssignment args(&func);
        args.assignAll(
            output_size,
            index_size,
            data_size,
            w,
            g,
            h,
            indices,
            lengths,
            epsilon,
            lr,
            rand_buffer);

        args.updateFuncFrame(frame);
        frame.finalize();
        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        vec_reg_t src_vreg; // for holding embedding value temporarily
        x86::Ymm mask_vreg;

        // Reserve registers with small ids first because some of them are
        // used with AVX2-only instructions (e.g., vhaddps) that cannot
        // encode register ids >= 16.
        int first_available_vec_reg_id = 0;
        x86::Ymm partial_sum_vreg = x86::Ymm(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t float_step_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t epsilon_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t lr_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;

        a->vpbroadcastd(epsilon_vreg, epsilon);
        a->vpbroadcastd(lr_vreg, lr);

        // Reserve vector registers for random number generation:
        // S0...S3: global random buffer state
        // R: generated random number as uint32_t lanes
        // r0: extracted random byte (uint8_t) shifted to bits [5..12]
        // r1: temp
        vec_reg_t R_vreg, S0_vreg, S1_vreg, S2_vreg, S3_vreg, r0_vreg, r1_vreg;
        if (areWeightsFp16 && use_stochastic_rounding) {
          R_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S0_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S1_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S2_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S3_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          r0_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          r1_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;

          // Load random buffer for FP16 stochastic rounding
          if (instSet == inst_set_t::avx2) {
            a->vmovdqa(S0_vreg.ymm(), x86::dword_ptr(rand_buffer));
            a->vmovdqa(
                S1_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)));
            a->vmovdqa(
                S2_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)));
            a->vmovdqa(
                S3_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)));
          } else { // AVX512
            a->vmovdqa32(S0_vreg, x86::dword_ptr(rand_buffer));
            a->vmovdqa32(
                S1_vreg,
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)));
            a->vmovdqa32(
                S2_vreg,
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)));
            a->vmovdqa32(
                S3_vreg,
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)));
          }
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            src_vreg = vec_reg_t(first_available_vec_reg_id);
            ++first_available_vec_reg_id;

            mask_vreg = x86::Ymm(first_available_vec_reg_id);
            ++first_available_vec_reg_id;
            // Use scratchReg1 as temp
            a->mov(scratchReg1, asmjit::imm(mask_avx2));
            a->vmovups(
                mask_vreg,
                x86::ymmword_ptr(
                    scratchReg1, (vlen - remainder) % vlen * sizeof(int32_t)));
          } else {
            a->mov(scratchReg1, (1 << remainder) - 1);
            a->kmovw(x86::k(1), scratchReg1);
          }
        }
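        // Note on the AVX2 masked-remainder path above: mask_avx2 is assumed
        // to point at a combined mask table (vlen all-ones int32 entries
        // followed by vlen zero entries, as in
        // internal::avx2_ps_or_epi32_combined_mask), so loading a ymm at
        // element offset (vlen - remainder) yields a mask whose first
        // `remainder` lanes are active for vmaskmovps.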
        // Need an extra mask for computing the sum of gradients
        int remainder_avx2 =
            block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        x86::KReg reduce_mask_avx512;
        if (remainder_avx2 && instSet == inst_set_t::avx512) {
          reduce_mask_avx512 = x86::k(2);
          a->mov(scratchReg1, (1 << remainder_avx2) - 1);
          a->kmovw(reduce_mask_avx512, scratchReg1);
        }

        int unroll_factor = NUM_VEC_REG - first_available_vec_reg_id;

        // Compute the end address of indices
        a->imul(
            scratchReg1,
            index_size,
            static_cast<asmjit::Imm>(sizeof(indxType)));
        a->add(scratchReg1, indices);
        a->mov(index_size, scratchReg1);

        asmjit::Label exit = a->newLabel();
        asmjit::Label error = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        // rangeIndex loop begin (iterate output_size times)
        a->bind(LoopRangeIndexBegin);
        a->dec(output_size);
        a->jl(LoopRangeIndexEnd);

        // Compute the average of squared gradients.
        // Even with avx512, we only need avx2 registers when computing
        // partial_sum because some instructions we're using, like vhaddps,
        // have no AVX-512 counterpart.
        constexpr int vlen_avx2 =
            simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        int num_vec_regs_per_block_avx2 =
            (block_size + vlen_avx2 - 1) / vlen_avx2;

        a->vxorps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);

        // TODO: need to do a tree-reduction to fully take advantage of
        // unrolling
        for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
             vec_idx += unroll_factor) {
          int cur_unroll_factor =
              std::min(unroll_factor, num_vec_regs_per_block_avx2 - vec_idx);
          for (int v = 0; v < cur_unroll_factor; ++v) {
            x86::Ymm out_vreg = x86::Ymm(v + first_available_vec_reg_id);

            auto g_ptr =
                x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
            if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
                vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
              if (instSet == inst_set_t::avx2) {
                a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
              } else {
                a->k(reduce_mask_avx512).z().vmovups(out_vreg, g_ptr);
              }
            } else {
              a->vmovups(out_vreg, g_ptr);
            }
            a->vmulps(out_vreg, out_vreg, out_vreg);
            a->vaddps(partial_sum_vreg, partial_sum_vreg, out_vreg);
          }
        }
        // Reduce sum to 1 value
        // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
        // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
        // Use YMM/XMMs with smaller ids for AVX2-specific instructions like
        // vhaddps
        x86::Xmm partial_sum_xmm(partial_sum_vreg.id());
        x86::Xmm float_step_xmm(float_step_vreg.id());

        a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);
        a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);

        //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
        a->movss(float_step_xmm, partial_sum_xmm);
        //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
        a->vextractf128(partial_sum_xmm, partial_sum_vreg, 1);

        // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
        //     _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
        a->addss(partial_sum_xmm, float_step_xmm);

        // This fragment moves block size (N) to the stack and broadcasts it
        // to an xmm reg
        a->lea(
            x86::rsp,
            x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
        a->mov(x86::dword_ptr(x86::rsp), block_size);
        a->vbroadcastss(
            float_step_xmm,
            x86::dword_ptr(x86::rsp)); // broadcast N into float_step_xmm
        a->vcvtdq2ps(float_step_xmm, float_step_xmm);
        a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));

        // final_sum /= N
        a->divss(partial_sum_xmm, float_step_xmm);

        if (use_offsets) {
          a->mov(lengths_R, x86::dword_ptr(lengths, sizeof(offsetType)));
          a->sub(lengths_R, x86::dword_ptr(lengths));
        } else {
          a->mov(lengths_R, x86::dword_ptr(lengths));
        }

        // Array out-of-bounds check
        a->imul(
            scratchReg1, lengths_R, static_cast<asmjit::Imm>(sizeof(indxType)));

        a->add(scratchReg1, indices);
        a->cmp(scratchReg1, index_size);
        a->jg(error);

        asmjit::Label LoopDataIndexBegin = a->newLabel();
        asmjit::Label LoopDataIndexEnd = a->newLabel();

        // dataIndex loop begins (iterate lengths_R times)
        a->bind(LoopDataIndexBegin);
        a->dec(lengths_R);
        a->jl(LoopDataIndexEnd);

        // Array out-of-bounds check
        if (areIndices64b) {
          a->mov(scratchReg1, x86::qword_ptr(indices));
        } else {
          a->mov(scratchReg1.r32(), x86::dword_ptr(indices));
        }
        // A trick to check x >= data_size or x < 0 in one shot by treating
        // scratchReg1 as if it has an unsigned value
        // (https://stackoverflow.com/a/34072155).
        a->cmp(scratchReg1, data_size);
        a->jae(error);

        if (prefetch) {
          asmjit::Label pref_dist_reset_start = a->newLabel();
          asmjit::Label pref_dist_reset_end = a->newLabel();
          // out-of-bounds handling for prefetch
          a->mov(scratchReg2, indices);
          a->add(
              scratchReg2,
              static_cast<asmjit::Imm>(prefetch * sizeof(indxType)));
          a->cmp(scratchReg2, index_size);
          a->jge(pref_dist_reset_start);

          if (areIndices64b) {
            a->mov(
                scratchReg2,
                x86::qword_ptr(indices, prefetch * sizeof(indxType)));
          } else {
            a->mov(
                scratchReg2.r32(),
                x86::dword_ptr(indices, prefetch * sizeof(indxType)));
          }

          a->jmp(pref_dist_reset_end);

          a->bind(pref_dist_reset_start);
          // The prefetch index is out of bounds, so just prefetch the current
          // row. This could be improved by prefetching the farthest
          // in-bounds row instead.
          if (areIndices64b) {
            a->mov(scratchReg2, x86::qword_ptr(indices));
          } else {
            a->mov(scratchReg2.r32(), x86::dword_ptr(indices));
          }

          a->bind(pref_dist_reset_end);
        }

        a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));

        if (prefetch) {
          a->prefetchw(x86::dword_ptr(h, scratchReg2, 2));
        }
        // load h
        a->movss(float_step_xmm, x86::dword_ptr(h, scratchReg1, 2));
        // *h + final_sum
        a->addss(float_step_xmm, partial_sum_xmm);
        // store h
        a->movss(x86::dword_ptr(h, scratchReg1, 2), float_step_xmm);
        // sqrt(hi)
        a->sqrtss(float_step_xmm, float_step_xmm);
        // broadcast sqrt(hi) to all lanes of the ymm/zmm reg
        a->vpbroadcastd(float_step_vreg, float_step_xmm);
        // float_step = lr / (sqrt(hi) + epsilon)
        a->vaddps(float_step_vreg, float_step_vreg, epsilon_vreg);
        a->vdivps(float_step_vreg, lr_vreg, float_step_vreg);

        a->imul(scratchReg1, static_cast<asmjit::Imm>(block_size));
        if (prefetch) {
          a->imul(scratchReg2, static_cast<asmjit::Imm>(block_size));
        }

        for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
             vec_idx += unroll_factor) {
          int cur_unroll_factor =
              std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

          // The main computation
          for (int v = 0; v < cur_unroll_factor; ++v) {
            vec_reg_t out_vreg = vec_reg_t(v + first_available_vec_reg_id);

            auto g_ptr =
                x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
            if (!areWeightsFp16) { // float weights
              auto w_ptr = x86::dword_ptr(
                  w, scratchReg1, 2, (vec_idx + v) * vlen * sizeof(dataType));
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr);
                  a->vmulps(src_vreg, float_step_vreg, src_vreg);

                  a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
                  a->vaddps(out_vreg, src_vreg, out_vreg);

                  a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
                } else {
                  a->k(x86::k(1)).vmulps(out_vreg, float_step_vreg, g_ptr);
                  a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
                  a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
                }
              } else {
                a->vmulps(out_vreg, float_step_vreg, g_ptr);
                a->vaddps(out_vreg, out_vreg, w_ptr);
                a->vmovups(w_ptr, out_vreg);
              }
            } else { // float16 weights
              auto w_ptr = x86::word_ptr(
                  w, scratchReg1, 1, (vec_idx + v) * vlen * sizeof(dataType));

              if (use_stochastic_rounding) {
                // Each uint32 lane of R holds four 8-bit random bytes;
                // sr_idx in [0..3] selects which byte to extract.
                int sr_idx = (vec_idx + v) % 4;
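                // In scalar form, the extraction below computes
                //   r0 = ((R >> (8 * sr_idx)) & 0xFF) << 5;
                // (a restatement of the shift sequences in each branch, for
                // readability; it is not a separate code path).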

                if (sr_idx == 0) {
                  // Generate R every 4 iterations of the
                  // num_vec_regs_per_block loop. Each 8-bit byte in R
                  // (uint32_t) is used once. It is shifted to bits [5..12]
                  // and then added to the FP32 weights before the FP16
                  // conversion.
                  //
                  // The shifted 8-bit region
                  // +-------+--------+--------+--------+
                  // |       |        |   xxxxx|xxx     |
                  // 31      23       15       7        0
                  //
                  // Half floats have 10 bits of mantissa while floats have
                  // 23, so we shift the random bits into the low part of the
                  // FP32 mantissa that half floats cannot represent (the 13
                  // bits dropped by the conversion). This effectively adds a
                  // random variable in [0, 1) ULP before truncation.

                  // Random generator using xoshiro128++
                  // Ref: http://prng.di.unimi.it/xoshiro128plusplus.c
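                  //
                  // For reference, the scalar xoshiro128++ step that the
                  // vector instructions below reproduce per SIMD lane
                  // (rotl(x, k) == (x << k) | (x >> (32 - k))):
                  //   R  = rotl(S0 + S3, 7) + S0;
                  //   t  = S1 << 9;
                  //   S2 ^= S0; S3 ^= S1; S1 ^= S2; S0 ^= S3;
                  //   S2 ^= t;
                  //   S3 = rotl(S3, 11);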
                  a->vpaddd(r0_vreg, S0_vreg, S3_vreg);
                  a->vpslld(r1_vreg, r0_vreg, 7);
                  a->vpsrld(r0_vreg, r0_vreg, 25);
                  if (instSet == inst_set_t::avx2) {
                    a->vpor(R_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm());
                  } else {
                    a->vpord(R_vreg, r0_vreg, r1_vreg);
                  }
                  a->vpaddd(R_vreg, R_vreg, S0_vreg);

                  a->vpslld(r0_vreg, S1_vreg, 9);

                  if (instSet == inst_set_t::avx2) {
                    a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), S0_vreg.ymm());
                    a->vpxor(S3_vreg.ymm(), S3_vreg.ymm(), S1_vreg.ymm());
                    a->vpxor(S1_vreg.ymm(), S1_vreg.ymm(), S2_vreg.ymm());
                    a->vpxor(S0_vreg.ymm(), S0_vreg.ymm(), S3_vreg.ymm());

                    a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), r0_vreg.ymm());
                  } else {
                    a->vpxord(S2_vreg, S2_vreg, S0_vreg);
                    a->vpxord(S3_vreg, S3_vreg, S1_vreg);
                    a->vpxord(S1_vreg, S1_vreg, S2_vreg);
                    a->vpxord(S0_vreg, S0_vreg, S3_vreg);

                    a->vpxord(S2_vreg, S2_vreg, r0_vreg);
                  }
                  a->vpslld(r0_vreg, S3_vreg, 11);
                  a->vpsrld(r1_vreg, S3_vreg, 21);
                  if (instSet == inst_set_t::avx2) {
                    a->vpor(S3_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm());
                  } else {
                    a->vpord(S3_vreg, r0_vreg, r1_vreg);
                  }

                  // Extract byte 0 and shift to bits [5..12]
                  a->vpslld(r0_vreg, R_vreg, 24);
                  a->vpsrld(r0_vreg, r0_vreg, 19);
                } else if (sr_idx == 1) {
                  // Extract byte 1 and shift to bits [5..12]
                  a->vpsrld(r0_vreg, R_vreg, 8);
                  a->vpslld(r0_vreg, r0_vreg, 24);
                  a->vpsrld(r0_vreg, r0_vreg, 19);
                } else if (sr_idx == 2) {
                  // Extract byte 2 and shift to bits [5..12]
                  a->vpslld(r0_vreg, R_vreg, 8);
                  a->vpsrld(r0_vreg, r0_vreg, 24);
                  a->vpslld(r0_vreg, r0_vreg, 5);
                } else { // sr_idx == 3
                  // Extract byte 3 and shift to bits [5..12]
                  a->vpsrld(r0_vreg, R_vreg, 24);
                  a->vpslld(r0_vreg, r0_vreg, 5);
                }
              }

              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr);
                  // There is no AVX2 masked load/store for 16-bit elements,
                  // so copy through the stack with a scalar loop, temporarily
                  // reusing the GPR that holds h.
                  a->lea(x86::rsp, x86::ptr(x86::rsp, -8));
                  a->mov(x86::ptr(x86::rsp), h);
                  a->lea(
                      x86::rsp,
                      x86::ptr(
                          x86::rsp, static_cast<int>(-vlen * sizeof(float16))));
                  for (int r = 0; r < remainder; ++r) {
                    a->mov(
                        h.r16(),
                        x86::word_ptr(
                            w,
                            scratchReg1,
                            1,
                            ((vec_idx + v) * vlen + r) * sizeof(dataType)));
                    a->mov(x86::ptr(x86::rsp, sizeof(dataType) * r), h.r16());
                  }
                  a->vcvtph2ps(out_vreg, x86::word_ptr(x86::rsp));
                  a->vfmadd231ps(out_vreg, float_step_vreg, src_vreg);
                  if (use_stochastic_rounding) {
                    a->vpaddd(out_vreg, r0_vreg, out_vreg);
                  }
                  // Convert with truncation to counteract the randomly added
                  // part (stochastic rounding)
                  a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 11);
                  // Copy results back
                  for (int r = 0; r < remainder; ++r) {
                    a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r));
                    a->mov(
                        x86::word_ptr(
                            w,
                            scratchReg1,
                            1,
                            ((vec_idx + v) * vlen + r) * sizeof(dataType)),
                        h.r16());
                  }
                  a->lea(
                      x86::rsp,
                      x86::ptr(
                          x86::rsp, static_cast<int>(vlen * sizeof(float16))));
                  a->mov(h, x86::ptr(x86::rsp));
                  a->lea(x86::rsp, x86::ptr(x86::rsp, 8));
                } else {
                  a->k(x86::k(1)).vcvtph2ps(out_vreg, w_ptr);
                  a->k(x86::k(1)).vfmadd231ps(out_vreg, float_step_vreg, g_ptr);
                  if (use_stochastic_rounding) {
                    a->vpaddd(out_vreg, r0_vreg, out_vreg);
                  }
                  // Convert with truncation (stochastic rounding)
                  a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 11);
                }
              } else {
                a->vcvtph2ps(out_vreg, w_ptr);
                a->vfmadd231ps(out_vreg, float_step_vreg, g_ptr);
                if (use_stochastic_rounding) {
                  a->vpaddd(out_vreg, r0_vreg, out_vreg);
                }
                // Convert with truncation (stochastic rounding)
                a->vcvtps2ph(w_ptr, out_vreg, 11);
              }
            }

            constexpr int CACHE_LINE_LEN = 64;
            constexpr int BYTES_PER_VLOAD = vlen * sizeof(dataType);
            constexpr int VLOAD_PER_CACHE_LINE =
                CACHE_LINE_LEN / BYTES_PER_VLOAD;
            if (prefetch && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) {
              a->prefetchw(x86::dword_ptr(
                  w,
                  scratchReg2,
                  areWeightsFp16 ? 1 : 2,
                  (vec_idx + v) * BYTES_PER_VLOAD));
            }
          }
        }

        a->jmp(LoopDataIndexBegin);
        a->bind(LoopDataIndexEnd);

        a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType)));
        a->add(g, static_cast<asmjit::Imm>(grad_stride * sizeof(float)));

        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        a->cmp(indices, index_size);
        a->jne(error);
        a->mov(scratchReg1.r32(), 1);
        a->jmp(exit);
        a->bind(error);
        a->mov(scratchReg1.r32(), 0);
        a->bind(exit);

        if (areWeightsFp16 && use_stochastic_rounding) {
          if (instSet == inst_set_t::avx2) {
            a->vmovdqa(x86::dword_ptr(rand_buffer), S0_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)),
                S1_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)),
                S2_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)),
                S3_vreg.ymm());
          } else {
            a->vmovdqa32(x86::dword_ptr(rand_buffer), S0_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)),
                S1_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)),
                S2_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)),
                S3_vreg);
          }
        }

        a->mov(x86::eax, scratchReg1.r32());
        a->emitEpilog(frame);

        typename ReturnFunctionSignature<indxType, offsetType, dataType>::
            jit_sparse_adagrad_kernel fn;
        asmjit::Error err;
        {
          unique_lock<mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          cout << "Error: in fn add" << endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
} // getOrCreate

// Per-thread global buffer for random number generation, sized for the
// maximum vector width
constexpr size_t VLEN_MAX = simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS;
alignas(64) static thread_local uint32_t g_rnd128v_buffer[4 * VLEN_MAX];
static thread_local bool g_rnd128v_initialized = false;
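// g_rnd128v_buffer holds the xoshiro128++ state (the four vector state words
// S0..S3) that the generated kernel loads at entry and stores back at exit,
// giving every SIMD lane an independent random stream. It is seeded below
// with splitmix64, the seeding generator suggested by the xoshiro authors.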

void rand_initialize() {
  // Splitmix64: http://prng.di.unimi.it/splitmix64.c
  auto rnd128_init_next = [](uint64_t& x) {
    uint64_t z = (x += 0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
    return z ^ (z >> 31);
  };

  if (!g_rnd128v_initialized) {
    uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id());
    for (auto i = 0; i < 4; ++i) {
      g_rnd128v_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
      uint64_t h1 = g_rnd128v_buffer[i * VLEN_MAX];
      for (size_t v = 1; v < VLEN_MAX; ++v) {
        g_rnd128v_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
      }
    }
    g_rnd128v_initialized = true;
  }
}

} // namespace

template <typename IndexType, typename OffsetType, typename DataType>
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
    IndexType,
    OffsetType,
    DataType>::Type
GenerateRowWiseSparseAdaGradFused(
    int block_size, // number of parameters per row
    int prefetch,
    bool use_offsets,
    bool use_stochastic_rounding,
    int grad_stride) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  if (grad_stride == -1) {
    grad_stride = block_size;
  }

  // Use avx512 only for fp16 + stochastic rounding
  if (fbgemmHasAvx512Support() && std::is_same<DataType, float16>::value &&
      use_stochastic_rounding) {
    static GenRowWiseSparseAdagradFused<
        IndexType,
        OffsetType,
        DataType,
        inst_set_t::avx512>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        nullptr,
        block_size,
        prefetch,
        use_offsets,
        use_stochastic_rounding,
        grad_stride);
    const auto lambda_func = [=](int64_t output_size,
                                 int64_t index_size,
                                 int64_t data_size,
                                 DataType* w,
                                 const float* g,
                                 float* h,
                                 const IndexType* indices,
                                 const OffsetType* offsets_or_lengths,
                                 float epsilon,
                                 float lr) {
      // Initialize the random buffer on the first execution
      // TODO: JIT
      if (std::is_same<DataType, float16>::value && use_stochastic_rounding) {
        rand_initialize();
      }

      return original_func(
          output_size,
          index_size,
          data_size,
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          offsets_or_lengths,
          epsilon,
          lr,
          g_rnd128v_buffer);
    };
    return lambda_func;
  } else if (fbgemmHasAvx2Support()) {
    static GenRowWiseSparseAdagradFused<
        IndexType,
        OffsetType,
        DataType,
        inst_set_t::avx2>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        internal::avx2_ps_or_epi32_combined_mask,
        block_size,
        prefetch,
        use_offsets,
        use_stochastic_rounding,
        grad_stride);
    const auto lambda_func = [=](int64_t output_size,
                                 int64_t index_size,
                                 int64_t data_size,
                                 DataType* w,
                                 const float* g,
                                 float* h,
                                 const IndexType* indices,
                                 const OffsetType* offsets_or_lengths,
                                 float epsilon,
                                 float lr) {
      // Initialize the random buffer on the first execution
      // TODO: JIT
      if (std::is_same<DataType, float16>::value && use_stochastic_rounding) {
        rand_initialize();
      }

      return original_func(
          output_size,
          index_size,
          data_size,
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          offsets_or_lengths,
          epsilon,
          lr,
          g_rnd128v_buffer);
    };
    return lambda_func;
  } else {
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               DataType* w,
               const float* g,
               float* h,
               const IndexType* indices,
               const OffsetType* offsets_or_lengths,
               float epsilon,
               float lr) {
      return rowwise_sparse_adagrad_fused_ref(
          block_size,
          output_size,
          index_size,
          data_size,
          w,
          g,
          h,
          indices,
          offsets_or_lengths,
          epsilon,
          lr,
          use_offsets,
          use_stochastic_rounding,
          /*emu_vector_size=*/8,
          grad_stride);
    };
  }
}
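
// Example usage (a sketch; the caller-side buffer names are hypothetical):
//
//   auto kernel = GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float>(
//       /*block_size=*/64, /*prefetch=*/16, /*use_offsets=*/true,
//       /*use_stochastic_rounding=*/false, /*grad_stride=*/-1);
//   bool ok = kernel(
//       output_size, index_size, data_size, weights, grads, momentum,
//       indices, offsets, /*epsilon=*/1e-5f, /*lr=*/0.1f);
//   // ok == false indicates an out-of-bounds index or a mismatch between
//   // the lengths/offsets and index_size.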

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int32_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int32_t, float16>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float16>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float16>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float16>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float16>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float16>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float16>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float16>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

} // namespace fbgemm