/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmEmbedding.h"

#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cmath>
#include <iostream>
#include <mutex>
#include <string>
#include <tuple>
#include "./CodeCache.h"
#include "./MaskAvx2.h"
#include "./RefImplementations.h"
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"

namespace fbgemm {

namespace {
namespace x86 = asmjit::x86;

template <typename indxType = std::int64_t>
class ReturnFunctionSignature {
 public:
  using jit_sparse_adagrad_kernel = int (*)(
      int num_rows, // number of rows reading
      std::uint64_t param_size, // total number of parameters
      float* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      float epsilon,
      float lr,
      const int* mask_avx2,
      float weight_decay,
      const double* counter,
      std::int64_t counter_halflife);
};

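// A rough scalar sketch of what the generated (non-rowwise) kernel does for
// each element j of an indexed row (row = idx * block_size); compare with the
// block_size == 1 reference SparseAdaGradBlockSize1_ further below:
//   gi = g[j] + adjusted_weight_decay * w[row + j]; // decay term only if enabled
//   h[row + j] += gi * gi;
//   w[row + j] += lr * gi / (sqrt(h[row + j]) + epsilon);
// The rowwise variant keeps a single momentum value per row instead; see
// genRowwiseSparseAdagrad.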
template <
    typename indxType = std::int64_t,
    inst_set_t instSet = inst_set_t::avx2>
class GenSparseAdagrad {
 public:
  GenSparseAdagrad() {}
  void genSparseAdagrad(
      x86::Emitter* a,
      int unroll_factor,
      int num_vec_regs_per_block,
      int remainder,
      int prefetch,
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg,
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  void genRowwiseSparseAdagrad(
      x86::Emitter* a,
      int block_size,
      int unroll_factor,
      int num_vec_regs_per_block,
      int remainder,
      int prefetch,
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg,
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
  getOrCreate(
      int block_size,
      int prefetch,
      bool rowwise,
      bool has_weight_decay);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; // JIT Runtime for asmjit
    return rt;
  }

  static std::mutex rtMutex_; ///< Controls access to runtime.

  // The hash depends on embedding dimension (block size), prefetch distance,
  // rowwise, and has_weight_decay
  static CodeCache<
      std::tuple<int, int, bool, bool>,
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.

  // These are registers we share across SparseAdagrad and RowwiseSparseAdagrad
  x86::Gp w;
  x86::Gp g;
  x86::Gp h;
  x86::Gp indices;
  x86::Gp base_offset;
  x86::Gp temp1_; // loop counter
  x86::Gp temp2_; // prefetch offset
  x86::Gp temp3_; // prefetch offset of moment in rowwise adagrad

  x86::KReg reduce_mask_avx512_;
}; // class GenSparseAdagrad

template <typename indxType, inst_set_t instSet>
std::mutex GenSparseAdagrad<indxType, instSet>::rtMutex_;

template <typename indxType, inst_set_t instSet>
CodeCache<
    std::tuple<int, int, bool, bool>,
    typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
    GenSparseAdagrad<indxType, instSet>::codeCache_;

template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genSparseAdagrad(
    x86::Emitter* a,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  // NOTE: temp_vreg is defined only when remainder is true and instSet == avx2
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);
      vec_reg_t g_vreg = vec_reg_t(v + cur_unroll_factor);

      if (prefetch && ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        // Intel SDE (wrongly) thinks prefetchwt1 is not available in BDW
        a->prefetchw(
            x86::dword_ptr(h, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));

        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto h_ptr = x86::dword_ptr(
          h, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(g_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            // TODO(@taiqing) use a vreg for weights to avoid duplicate indexing
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(g_vreg, temp_vreg, weight_decay_vreg);
          }
          a->vmulps(out_vreg, g_vreg, g_vreg);
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, h_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(h_ptr, mask_vreg, out_vreg.ymm());

          a->vsqrtps(out_vreg, out_vreg);
          a->vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->vmulps(g_vreg, lr_vreg, g_vreg);
          a->vdivps(out_vreg, g_vreg, out_vreg);

          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else if (instSet == inst_set_t::avx512) {
          a->k(x86::k(1)).vmovups(g_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(x86::k(1)).vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
          }
          a->k(x86::k(1)).vmulps(out_vreg, g_vreg, g_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, h_ptr);

          a->k(x86::k(1)).vmovups(h_ptr, out_vreg);

          a->k(x86::k(1)).vsqrtps(out_vreg, out_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->k(x86::k(1)).vmulps(g_vreg, lr_vreg, g_vreg);
          a->k(x86::k(1)).vdivps(out_vreg, g_vreg, out_vreg);

          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);

          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        a->vmovups(g_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
        }
        a->vmulps(out_vreg, g_vreg, g_vreg);
        a->vaddps(out_vreg, out_vreg, h_ptr);

        a->vmovups(h_ptr, out_vreg);

        a->vsqrtps(out_vreg, out_vreg);
        a->vaddps(out_vreg, out_vreg, epsilon_vreg);

        a->vmulps(g_vreg, lr_vreg, g_vreg);
        a->vdivps(out_vreg, g_vreg, out_vreg);

        a->vaddps(out_vreg, out_vreg, w_ptr);

        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}

template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genRowwiseSparseAdagrad(
    x86::Emitter* a,
    int block_size,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;

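  // In rough scalar terms, the code emitted below does the following for the
  // current row (row = idx * block_size, N = block_size):
  //   final_sum = sum_j (g[j] + weight_decay * w[row + j])^2 / N;
  //   h[idx] += final_sum;
  //   float_step = lr / (sqrt(h[idx]) + epsilon);
  //   w[row + j] += float_step * (g[j] + weight_decay * w[row + j]);
  // The weight_decay term is present only when has_weight_decay is true, and
  // the value passed in is already frequency-adjusted by getOrCreate.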
  // Reduce the unroll factor by 1 for partial sum
  --unroll_factor;
  vec_reg_t partial_sum_vreg = vec_reg_t(unroll_factor);

  if (prefetch) {
    a->prefetchw(x86::dword_ptr(h, temp3_));
  }

  bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
  auto indices_ptr = areIndices64b
      ? x86::qword_ptr(
            indices, temp1_, 3) // use of 3 is to multiply by 8 (int64_t)
      : x86::dword_ptr(
            indices, temp1_, 2); // use of 2 is to multiply by 4 (int32_t)
  if (has_weight_decay) {
    // set base_offset for fetching w in the calculation of gradient square sum
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(block_size * sizeof(float)));
  }

  // Even with avx512, we only need to use avx2 registers when computing
  // partial_sum because some instructions we're using like vhaddps
  // are only in avx2.
  constexpr int vlen_avx2 = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
  int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2;

  // Use YMM/XMMs with smaller ids for AVX2 specific instructions like vhaddps
  x86::Ymm partial_sum_vreg_avx2(0);
  x86::Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id());

  a->vxorps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  // TODO: need to do a tree-reduction to fully take advantage of unrolling
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
       vec_idx += unroll_factor - 1) {
    int cur_unroll_factor =
        std::min(unroll_factor - 1, num_vec_regs_per_block_avx2 - vec_idx);
    for (int v = 0; v < cur_unroll_factor; ++v) {
      x86::Ymm out_vreg = x86::Ymm(v + 1);
      if (has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen_avx2 * sizeof(float))) == 0)) {
        a->prefetchw(x86::dword_ptr(
            w, temp2_, 0, (vec_idx + v) * vlen_avx2 * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen_avx2 * sizeof(float));
      if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
          vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(out_vreg, temp_vreg, weight_decay_vreg);
          }
        } else {
          a->k(reduce_mask_avx512_).z().vmovups(out_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(reduce_mask_avx512_)
                .z()
                .vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          }
        }
      } else {
        a->vmovups(out_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
        }
      }
      a->vmulps(out_vreg, out_vreg, out_vreg);
      a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg);
    }
  }
  // Reduce sum to 1 value
  // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
  // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  x86::Xmm partial_sum_xmm1(1);

  //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
  a->movss(partial_sum_xmm1, partial_sum_xmm0);
  //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
  a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1);

  // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
  //     _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
  a->addss(partial_sum_xmm0, partial_sum_xmm1);

  // This fragment moves block size (N) to stack and bcasts it to xmm reg
  a->lea(
      x86::rsp,
      x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
  a->mov(x86::dword_ptr(x86::rsp), block_size);
  a->vbroadcastss(
      partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1
  a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1);
  a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));

  if (has_weight_decay) {
    // set base_offset for fetching h
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(sizeof(float)));
  }

  // final_sum /= N
  a->divss(partial_sum_xmm0, partial_sum_xmm1);
  // load h
  a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset));
  // *h + final_sum
  a->addss(partial_sum_xmm0, partial_sum_xmm1);
  // store h
  a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0);
  // sqrt(hi)
  a->sqrtss(partial_sum_xmm0, partial_sum_xmm0);
  // broadcast sqrt(hi) to all lanes of the ymm/zmm reg
  a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0);
  // float_step = lr / (sqrt(hi) + epsilon)
  a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg);
  a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg);
  // partial_sum_vreg now holds float_step

  // set base_offset for fetching w in updating weights
  a->imul(
      areIndices64b ? base_offset : base_offset.r32(),
      indices_ptr,
      static_cast<asmjit::Imm>(block_size * sizeof(float)));

  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);

      if (!has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
            // TODO(@taiqing): have vreg for weights
            a->vfmadd231ps(temp_vreg, weight_decay_vreg, out_vreg);
          }
          a->vmulps(temp_vreg, partial_sum_vreg, temp_vreg);

          a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, temp_vreg, out_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else {
          if (has_weight_decay) {
            a->k(x86::k(1)).vmovups(out_vreg, g_ptr);
            a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg);
          } else {
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr);
          }
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        if (has_weight_decay) {
          a->vmovups(out_vreg, g_ptr);
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          a->vmulps(out_vreg, partial_sum_vreg, out_vreg);
        } else {
          a->vmulps(out_vreg, partial_sum_vreg, g_ptr);
        }
        a->vaddps(out_vreg, out_vreg, w_ptr);
        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}

template <typename indxType, inst_set_t instSet>
typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
GenSparseAdagrad<indxType, instSet>::getOrCreate(
    int block_size,
    int prefetch,
    bool rowwise,
    bool has_weight_decay) {
  std::tuple<int, int, bool, bool> kernelSig =
      std::make_tuple(block_size, prefetch, rowwise, has_weight_decay);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() ->
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel {
        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
        bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
#if defined(FBGEMM_LOG_CODE)
        std::string filename = "SparseAdagrad";
        filename += "_emd_dim_" + std::to_string(block_size);
        if (rowwise) {
          filename += "_rowwise";
        }
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        if (has_weight_decay) {
          filename += "weight_decay";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif

        x86::Gpd num_rows = a->zdi().r32();
        x86::Gp param_size = a->zsi();
        w = a->zdx();
        g = a->zcx();
        h = a->gpz(8);
        indices = a->gpz(9);
        x86::Xmm epsilon(0);
        x86::Xmm lr(1);
        x86::Gp mask_avx2 = a->gpz(10);
        x86::Xmm weight_decay(2);
        x86::Gp counter = a->gpz(11);
        x86::Gp counter_halflife = a->gpz(12);

        // reuse mask_avx2 because mask_avx2 is used only at the beginning
        base_offset = a->gpz(10);
        temp1_ = a->gpz(13);
        temp2_ = a->gpz(14);
        temp3_ = a->gpz(15);

        asmjit::FuncDetail func;
        func.init(
            asmjit::FuncSignatureT<
                int, // return type
                int, // num rows
                std::uint64_t, // param_size
                float*, // w
                const float*, // g
                float*, // h
                const indxType*, // indices
                float, // epsilon
                float, // lr
                const int*, // mask_avx2
                float, // weight_decay
                const double*, // counter
                std::int64_t>(asmjit::CallConvId::kHost), // counter_halflife
            a->environment());

        asmjit::FuncFrame frame;
        frame.init(func);

        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

        asmjit::FuncArgsAssignment args(&func);
        args.assignAll(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            mask_avx2,
            weight_decay,
            counter,
            counter_halflife);

        args.updateFuncFrame(frame);
        frame.finalize();
        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
        int unroll_factor = NUM_VEC_REG;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        vec_reg_t epsilon_vreg;
        vec_reg_t lr_vreg;
        vec_reg_t weight_decay_vreg;
        vec_reg_t adjusted_weight_decay_vreg;
        x86::Ymm mask_vreg; // mask for avx2
        vec_reg_t
            temp_vreg; // temp vreg for avx2 to handle remainder computation

        --unroll_factor;
        epsilon_vreg = vec_reg_t(unroll_factor);
        --unroll_factor;
        lr_vreg = vec_reg_t(unroll_factor);
        if (has_weight_decay) {
          --unroll_factor;
          weight_decay_vreg = vec_reg_t(unroll_factor);
          --unroll_factor;
          adjusted_weight_decay_vreg = vec_reg_t(unroll_factor);
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            temp_vreg = vec_reg_t(unroll_factor);
          }

          // Creating masks for non multiples of vlen iterations
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            mask_vreg = x86::Ymm(unroll_factor);
            a->vmovups(mask_vreg, x86::dword_ptr(mask_avx2));
          } else {
            a->mov(temp1_, (1 << remainder) - 1);
            a->kmovw(x86::k(1), temp1_);
          }
        }
        // Need an extra mask for computing sum of gradients
        int remainder_avx2 =
            block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        if (remainder_avx2 && instSet == inst_set_t::avx512 && rowwise) {
          reduce_mask_avx512_ = x86::k(2);
          a->mov(temp1_, (1 << remainder_avx2) - 1);
          a->kmovw(reduce_mask_avx512_, temp1_);
        }

        if (!rowwise) {
          unroll_factor = unroll_factor / 2; // account for g_vreg
        }

        asmjit::Label exit = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        a->vpbroadcastd(epsilon_vreg, epsilon);
        a->vpbroadcastd(lr_vreg, lr);
        if (has_weight_decay) {
          a->vpbroadcastd(weight_decay_vreg, weight_decay);
        }

        a->xor_(temp1_, temp1_);

        a->bind(LoopRangeIndexBegin);
        a->cmp(temp1_.r32(), num_rows); // temp1_ is the loop trip counter
        a->jge(LoopRangeIndexEnd);

        auto indices_ptr = areIndices64b
            ? x86::qword_ptr(
                  indices, temp1_, 3) // use of 3 is to multiply by 8 (int64_t)
            : x86::dword_ptr(
                  indices, temp1_, 2); // use of 2 is to multiply by 4 (int32_t)
        a->imul(
            areIndices64b ? base_offset : base_offset.r32(),
            indices_ptr,
            static_cast<asmjit::Imm>(
                (rowwise ? 1 : block_size) * sizeof(float)));

        // Perform this check:
        // if ((offsetIdx + 1) * block_size > param_size) {
        //   return i;
        // }
        if (areIndices64b) {
          a->mov(temp2_, indices_ptr);
        } else {
          a->mov(temp2_.r32(), indices_ptr);
        }

        if (has_weight_decay) {
          // Check counter != nullptr && counter[idx] > 0
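          // When the check passes, weight decay is frequency-adjusted below:
          //   adjusted_weight_decay = weight_decay * counter_halflife / counter[idx]
          // matching the scalar SparseAdaGradBlockSize1_ reference; otherwise
          // the unadjusted weight_decay broadcast earlier is used.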
          a->vmovaps(adjusted_weight_decay_vreg, weight_decay_vreg);

          asmjit::Label skip_adjust_freq = a->newLabel();

          a->cmp(counter, 0);
          a->je(skip_adjust_freq);

          // temp3_ : counter[idx]
          a->mov(temp3_, x86::qword_ptr(counter, temp2_, 3));
          a->cmp(temp3_, 0);
          a->jle(skip_adjust_freq);

          // OK to use Xmm registers with small ids that are reserved for temp
          // values in the innermost loop.
          vec_reg_t counter_halflife_vreg(0);
          x86::Xmm counter_vreg(1);
          a->cvtsi2sd(counter_halflife_vreg.xmm(), counter_halflife);
          a->movq(counter_vreg, temp3_);
          a->divpd(counter_halflife_vreg.xmm(), counter_vreg);
          a->vcvtpd2ps(
              counter_halflife_vreg.xmm(), counter_halflife_vreg.ymm());
          a->vbroadcastss(counter_halflife_vreg, counter_halflife_vreg.xmm());
          a->vmulps(
              adjusted_weight_decay_vreg,
              adjusted_weight_decay_vreg,
              counter_halflife_vreg);

          a->bind(skip_adjust_freq);
        }

        a->inc(temp2_);
        a->imul(
            temp2_,
            static_cast<asmjit::Imm>(block_size)); // (offsetIdx+1)*blocksize
        a->cmp(temp2_, param_size);
        a->jg(exit);

        if (prefetch) {
          asmjit::Label pref_dist_reset_start = a->newLabel();
          asmjit::Label pref_dist_reset_end = a->newLabel();

          a->mov(temp2_, temp1_);
          a->add(temp2_, prefetch);
          a->cmp(temp2_.r32(), num_rows);
          a->jge(pref_dist_reset_start);

          auto pref_indices_ptr = areIndices64b
              ? x86::qword_ptr(indices, temp2_, 3)
              : x86::dword_ptr(indices, temp2_, 2);
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                pref_indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              pref_indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));

          a->jmp(pref_dist_reset_end);

          a->bind(pref_dist_reset_start);
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }

          a->bind(pref_dist_reset_end);
        } // prefetch

        if (rowwise) {
          genRowwiseSparseAdagrad(
              a,
              block_size,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        } else {
          genSparseAdagrad(
              a,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        }

        a->add(g, static_cast<asmjit::Imm>(block_size * sizeof(float)));
        a->inc(temp1_);
        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        a->bind(exit);
        a->mov(x86::eax, temp1_.r32());
        a->emitEpilog(frame);

        typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
            fn;
        asmjit::Error err;
        {
          std::unique_lock<std::mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          std::cout << "Error: in fn add" << std::endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
} // getOrCreate

// Specialization for block size 1 internally called by GenerateSparseAdaGrad
template <typename IndexType>
int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const IndexType* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife) {
  if (weight_decay != 0.0f) {
    for (int i = 0; i < num_rows; ++i) {
      IndexType idx = indices[i];
      if (idx >= static_cast<int64_t>(param_size)) {
        return i;
      }

      float freq = (counter && counter[idx] > 0)
          ? counter_halflife / counter[idx]
          : 1.0f;
      float gi = std::fma(freq * weight_decay, w[idx], g[i]);
      float hi = h[idx] = h[idx] + gi * gi;
      if (rowwise) {
        w[idx] += lr / (std::sqrt(hi) + epsilon) * gi;
      } else {
        w[idx] += lr * gi / (std::sqrt(hi) + epsilon);
      }
    }
  } else {
    for (int i = 0; i < num_rows; ++i) {
      IndexType idx = indices[i];
      if (idx >= static_cast<int64_t>(param_size)) {
        return i;
      }
      float gi = g[i];
      float hi = h[idx] = h[idx] + gi * gi;
      if (rowwise) {
        w[idx] += lr / (std::sqrt(hi) + epsilon) * gi;
      } else {
        w[idx] += lr * gi / (std::sqrt(hi) + epsilon);
      }
    }
  }
  return num_rows;
}

template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);

template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);

} // namespace

template <typename IndexType>
typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
    int block_size,
    bool rowwise,
    int prefetch,
    bool use_weight_decay) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }

  if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
    if (block_size == 1) {
      return [=](int num_rows, // number of rows reading
                 std::uint64_t param_size, // total number of parameters
                 float* w, // input/output parameters
                 const float* g, // input gradients
                 float* h, // input/output momentums
                 const IndexType* indices, // indices of each row
                 float epsilon,
                 float lr,
                 float weight_decay,
                 const double* counter,
                 std::int64_t counter_halflife) {
        return SparseAdaGradBlockSize1_(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            rowwise,
            weight_decay,
            counter,
            counter_halflife);
      };
    }
    static GenSparseAdagrad<IndexType, inst_set_t::avx2> kernel_generator;
    constexpr int VLEN = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
    const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
        [(VLEN - (block_size % VLEN)) % VLEN];
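    // mask_avx2 points into the precomputed mask table from MaskAvx2.h; the
    // selected entry is meant to enable the block_size % VLEN remainder lanes
    // and is loaded into mask_vreg for the vmaskmovps remainder handling in
    // the generated kernel.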
    const auto original_func = kernel_generator.getOrCreate(
        block_size, prefetch, rowwise, use_weight_decay);
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      return original_func(
          num_rows, // number of rows reading
          param_size, // total number of parameters
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          epsilon,
          lr,
          mask_avx2,
          weight_decay,
          counter,
          counter_halflife);
    };
  } else {
#ifdef VLOG
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      if (rowwise) {
        return rowwise_sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per row
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      } else {
        return sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per row
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      }
    };
  }
}
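
// Illustrative usage sketch (not part of the library): how a caller might
// generate and invoke a kernel. The variable names and parameter values below
// (num_rows, param_size, w, g, h, indices, ...) are hypothetical.
//
//   auto kernel = GenerateSparseAdaGrad<std::int64_t>(
//       /*block_size=*/64, /*rowwise=*/true, /*prefetch=*/16,
//       /*use_weight_decay=*/false);
//   int rows_done = kernel(
//       num_rows, param_size, w, g, h, indices,
//       /*epsilon=*/1e-8f, /*lr=*/0.05f, /*weight_decay=*/0.0f,
//       /*counter=*/nullptr, /*counter_halflife=*/0);
//   // rows_done < num_rows means an out-of-bounds index was encountered.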

template FBGEMM_API typename SparseAdaGradSignature<std::int64_t>::Type
GenerateSparseAdaGrad<std::int64_t>(
    int block_size, // number of parameters per row
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

template FBGEMM_API typename SparseAdaGradSignature<std::int32_t>::Type
GenerateSparseAdaGrad<std::int32_t>(
    int block_size, // number of parameters per row
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

} // namespace fbgemm