/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmEmbedding.h"

#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cassert>
#include <iostream>
#include <mutex>
#include <thread>
#include "./CodeCache.h"
#include "./MaskAvx2.h"
#include "./RefImplementations.h"
#include "fbgemm/SimdUtils.h"
#include "fbgemm/Utils.h"

using namespace std;

namespace fbgemm {
namespace {
namespace x86 = asmjit::x86;

template <typename indxType, typename offsetType, typename dataType>
class ReturnFunctionSignature {
 public:
  using jit_sparse_adagrad_kernel = bool (*)(
      int64_t output_size,
      int64_t index_size,
      int64_t data_size, // number of rows in w
      dataType* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      const offsetType* offsets_or_lengths,
      float epsilon,
      float lr,
      uint32_t* rand_buffer);
};
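
// For reference, a scalar sketch (not the JIT path) of what the generated
// kernel computes for one output row b; `offsets` and `b` are illustrative
// names assuming use_offsets == true:
//
//   float final_sum = 0.0f;
//   for (int j = 0; j < block_size; ++j) {
//     final_sum += g[j] * g[j];
//   }
//   final_sum /= block_size;
//   for (auto l = offsets[b]; l < offsets[b + 1]; ++l) {
//     auto idx = indices[l];
//     h[idx] += final_sum;
//     float float_step = lr / (std::sqrt(h[idx]) + epsilon);
//     for (int j = 0; j < block_size; ++j) {
//       w[idx * (int64_t)block_size + j] += g[j] * float_step;
//     }
//   }
//   g += grad_stride; // then advance to the next row's gradients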

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet = inst_set_t::avx2>
class GenRowWiseSparseAdagradFused {
 public:
  GenRowWiseSparseAdagradFused() {}

  typename ReturnFunctionSignature<indxType, offsetType, dataType>::
      jit_sparse_adagrad_kernel
      getOrCreate(
          const int* mask_avx2,
          int block_size,
          int prefetch,
          bool use_offsets,
          bool use_stochastic_rounding,
          int grad_stride);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; // JIT Runtime for asmjit
    return rt;
  }

  static mutex rtMutex_; ///< Controls access to runtime().

  // The hash depends on:
  // avx2 mask array, embedding dimension (block size), prefetch distance,
  // the use_offsets and use_stochastic_rounding switches, and grad_stride
  static CodeCache<
      tuple<const int*, int, int, bool, bool, int>,
      typename ReturnFunctionSignature<indxType, offsetType, dataType>::
          jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.
}; // class GenRowWiseSparseAdagradFused

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
mutex GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
    rtMutex_;

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
CodeCache<
    tuple<const int*, int, int, bool, bool, int>,
    typename ReturnFunctionSignature<indxType, offsetType, dataType>::
        jit_sparse_adagrad_kernel>
    GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
        codeCache_;

template <
    typename indxType,
    typename offsetType,
    typename dataType,
    inst_set_t instSet>
typename ReturnFunctionSignature<indxType, offsetType, dataType>::
    jit_sparse_adagrad_kernel
    GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
        getOrCreate(
            const int* mask_avx2, // runtime constant
            int block_size,
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride) {
  tuple<const int*, int, int, bool, bool, int> kernelSig = make_tuple(
      mask_avx2,
      block_size,
      prefetch,
      use_offsets,
      use_stochastic_rounding,
      grad_stride);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() -> typename ReturnFunctionSignature<
                indxType,
                offsetType,
                dataType>::jit_sparse_adagrad_kernel {
        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
        bool areIndices64b = is_same<indxType, int64_t>::value;
        bool areWeightsFp16 = is_same<dataType, float16>::value;
#if defined(FBGEMM_LOG_CODE)
        string filename = "RowWiseSparseAdagradFused";
        filename += "_emd_dim_" + to_string(block_size);
        filename += "_wei_float";
        filename += areWeightsFp16 ? "16" : "32";
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif

        x86::Gp rand_buffer = a->zax();
        x86::Gp output_size = a->zdi();
        x86::Gp index_size = a->zsi();
        x86::Gp data_size = a->zdx();
        x86::Gp w = a->zcx();
        x86::Gp g = a->gpz(8);
        x86::Gp h = a->gpz(9);
        x86::Gp indices = a->gpz(10);
        x86::Gp lengths = a->gpz(11);
        x86::Xmm epsilon(0);
        x86::Xmm lr(1);
        x86::Gpd lengths_R = a->gpz(12).r32();
        x86::Gp scratchReg1 = a->gpz(13);
        x86::Gp scratchReg2 = a->gpz(14); // for prefetching

        asmjit::FuncDetail func;
        func.init(
            asmjit::FuncSignatureT<
                bool, // return type
                int64_t, // output_size
                int64_t, // index_size
                int64_t, // data_size
                dataType*, // w
                const float*, // g
                float*, // h
                const indxType*, // indices
                const int*, // lengths
                float, // epsilon
                float, // lr, then rand_buffer
                uint32_t*>(asmjit::CallConvId::kHost),
            a->environment());

        asmjit::FuncFrame frame;
        frame.init(func);

        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

        asmjit::FuncArgsAssignment args(&func);
        args.assignAll(
            output_size,
            index_size,
            data_size,
            w,
            g,
            h,
            indices,
            lengths,
            epsilon,
            lr,
            rand_buffer);

        args.updateFuncFrame(frame);
        frame.finalize();
        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        vec_reg_t src_vreg; // for holding embedding value temporarily
        x86::Ymm mask_vreg;

        // Reserve registers with small ids first because some of them need to
        // be used with an instruction not supported in avx512 for which a big
        // register id won't work.
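        // (For example, vhaddps used below has no EVEX encoding, so it can
        // only address ymm0-ymm15 even when AVX512 is available.)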
        int first_available_vec_reg_id = 0;
        x86::Ymm partial_sum_vreg = x86::Ymm(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t float_step_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t epsilon_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;
        vec_reg_t lr_vreg = vec_reg_t(first_available_vec_reg_id);
        ++first_available_vec_reg_id;

        a->vpbroadcastd(epsilon_vreg, epsilon);
        a->vpbroadcastd(lr_vreg, lr);

        // Reserve vector registers for random number generation:
        // S0...S3: global random buffer state
        // R: generated random number in uint32_t
        // r0: extracted random byte (uint8_t) shifted to bits [5..13]
        // r1: temp
        vec_reg_t R_vreg, S0_vreg, S1_vreg, S2_vreg, S3_vreg, r0_vreg, r1_vreg;
        if (areWeightsFp16 && use_stochastic_rounding) {
          R_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S0_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S1_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S2_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          S3_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          r0_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;
          r1_vreg = vec_reg_t(first_available_vec_reg_id);
          first_available_vec_reg_id++;

          // Load random buffer for FP16 stochastic rounding
          if (instSet == inst_set_t::avx2) {
            a->vmovdqa(S0_vreg.ymm(), x86::dword_ptr(rand_buffer));
            a->vmovdqa(
                S1_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)));
            a->vmovdqa(
                S2_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)));
            a->vmovdqa(
                S3_vreg.ymm(),
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)));
          } else { // AVX512
            a->vmovdqa32(S0_vreg, x86::dword_ptr(rand_buffer));
            a->vmovdqa32(
                S1_vreg,
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)));
            a->vmovdqa32(
                S2_vreg,
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)));
            a->vmovdqa32(
                S3_vreg,
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)));
          }
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            src_vreg = vec_reg_t(first_available_vec_reg_id);
            ++first_available_vec_reg_id;

            mask_vreg = x86::Ymm(first_available_vec_reg_id);
            ++first_available_vec_reg_id;
            // Use scratchReg1 as temp
            a->mov(scratchReg1, asmjit::imm(mask_avx2));
            a->vmovups(
                mask_vreg,
                x86::ymmword_ptr(
                    scratchReg1, (vlen - remainder) % vlen * sizeof(int32_t)));
          } else {
            a->mov(scratchReg1, (1 << remainder) - 1);
            a->kmovw(x86::k(1), scratchReg1);
          }
        }
        // Need an extra mask for computing the sum of gradients
        int remainder_avx2 =
            block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        x86::KReg reduce_mask_avx512;
        if (remainder_avx2 && instSet == inst_set_t::avx512) {
          reduce_mask_avx512 = x86::k(2);
          a->mov(scratchReg1, (1 << remainder_avx2) - 1);
          a->kmovw(reduce_mask_avx512, scratchReg1);
        }

        int unroll_factor = NUM_VEC_REG - first_available_vec_reg_id;

        // Compute the end address of indices
        a->imul(
            scratchReg1,
            index_size,
            static_cast<asmjit::Imm>(sizeof(indxType)));
        a->add(scratchReg1, indices);
        a->mov(index_size, scratchReg1);

        asmjit::Label exit = a->newLabel();
        asmjit::Label error = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        // rangeIndex loop begin (iterate output_size times)
        a->bind(LoopRangeIndexBegin);
        a->dec(output_size);
        a->jl(LoopRangeIndexEnd);

        // Compute the average of squared gradients for the current row.
        // Even with avx512, we only need to use avx2 registers when computing
        // partial_sum because some instructions we're using, like vhaddps,
        // exist only up to avx2.
        constexpr int vlen_avx2 =
            simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        int num_vec_regs_per_block_avx2 =
            (block_size + vlen_avx2 - 1) / vlen_avx2;

        a->vxorps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);

        // TODO: need to do a tree-reduction to fully take advantage of
        // unrolling
        for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
             vec_idx += unroll_factor) {
          int cur_unroll_factor =
              std::min(unroll_factor, num_vec_regs_per_block_avx2 - vec_idx);
          for (int v = 0; v < cur_unroll_factor; ++v) {
            x86::Ymm out_vreg = x86::Ymm(v + first_available_vec_reg_id);

            auto g_ptr =
                x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
            if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
                vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
              if (instSet == inst_set_t::avx2) {
                a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
              } else {
                a->k(reduce_mask_avx512).z().vmovups(out_vreg, g_ptr);
              }
            } else {
              a->vmovups(out_vreg, g_ptr);
            }
            a->vmulps(out_vreg, out_vreg, out_vreg);
            a->vaddps(partial_sum_vreg, partial_sum_vreg, out_vreg);
          }
        }
        // Reduce sum to 1 value
        // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
        // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
        // Use YMM/XMMs with smaller ids for AVX2 specific instructions like
        // vhaddps
        x86::Xmm partial_sum_xmm(partial_sum_vreg.id());
        x86::Xmm float_step_xmm(float_step_vreg.id());
        a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);
        a->vhaddps(partial_sum_vreg, partial_sum_vreg, partial_sum_vreg);

        //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
        a->movss(float_step_xmm, partial_sum_xmm);
        //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
        a->vextractf128(partial_sum_xmm, partial_sum_vreg, 1);

        // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
        //     _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
        a->addss(partial_sum_xmm, float_step_xmm);

        // This fragment moves the block size (N) to the stack and broadcasts
        // it to an xmm reg
        a->lea(
            x86::rsp,
            x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
        a->mov(x86::dword_ptr(x86::rsp), block_size);
        a->vbroadcastss(
            float_step_xmm,
            x86::dword_ptr(x86::rsp)); // broadcast N into float_step_xmm
        a->vcvtdq2ps(float_step_xmm, float_step_xmm);
        a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));

        // final_sum /= N
        a->divss(partial_sum_xmm, float_step_xmm);
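        // Lane 0 of partial_sum_xmm now holds
        // final_sum = (sum of g[j]^2 over the row) / block_size.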

        if (use_offsets) {
          a->mov(lengths_R, x86::dword_ptr(lengths, sizeof(offsetType)));
          a->sub(lengths_R, x86::dword_ptr(lengths));
        } else {
          a->mov(lengths_R, x86::dword_ptr(lengths));
        }
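        // (With use_offsets, lengths[] actually holds output_size + 1
        // offsets, so the bag length computed above is
        // offsets[b + 1] - offsets[b]; otherwise lengths[b] is read
        // directly.)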

        // Array out of bound check
        a->imul(
            scratchReg1, lengths_R, static_cast<asmjit::Imm>(sizeof(indxType)));

        a->add(scratchReg1, indices);
        a->cmp(scratchReg1, index_size);
        a->jg(error);

        asmjit::Label LoopDataIndexBegin = a->newLabel();
        asmjit::Label LoopDataIndexEnd = a->newLabel();

        // dataIndex loop begins (iterate lengths_R times)
        a->bind(LoopDataIndexBegin);
        a->dec(lengths_R);
        a->jl(LoopDataIndexEnd);

        // Array out of bound check
        if (areIndices64b) {
          a->mov(scratchReg1, x86::qword_ptr(indices));
        } else {
          a->mov(scratchReg1.r32(), x86::dword_ptr(indices));
        }
        // A trick to check x >= data_size or x < 0 in one shot by treating
        // scratchReg1 as if it has an unsigned value
        // (https://stackoverflow.com/a/34072155).
        a->cmp(scratchReg1, data_size);
        a->jae(error);
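        // For example, with int64_t indices:
        //   (uint64_t)idx >= (uint64_t)data_size
        // is equivalent to idx < 0 || idx >= data_size, because a negative
        // idx wraps around to a huge unsigned value.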

        if (prefetch) {
          asmjit::Label pref_dist_reset_start = a->newLabel();
          asmjit::Label pref_dist_reset_end = a->newLabel();
          // out of bound handling for prefetch
          a->mov(scratchReg2, indices);
          a->add(
              scratchReg2,
              static_cast<asmjit::Imm>(prefetch * sizeof(indxType)));
          a->cmp(scratchReg2, index_size);
          a->jge(pref_dist_reset_start);

          if (areIndices64b) {
            a->mov(
                scratchReg2,
                x86::qword_ptr(indices, prefetch * sizeof(indxType)));
          } else {
            a->mov(
                scratchReg2.r32(),
                x86::dword_ptr(indices, prefetch * sizeof(indxType)));
          }

          a->jmp(pref_dist_reset_end);

          a->bind(pref_dist_reset_start);
          // things are not okay, so just prefetch the current row;
          // this could be improved to prefetch the max-distance row.
          if (areIndices64b) {
            a->mov(scratchReg2, x86::qword_ptr(indices));
          } else {
            a->mov(scratchReg2.r32(), x86::dword_ptr(indices));
          }

          a->bind(pref_dist_reset_end);
        }

        a->add(indices, static_cast<asmjit::Imm>(sizeof(indxType)));

        if (prefetch) {
          a->prefetchw(x86::dword_ptr(h, scratchReg2, 2));
        }
        // load h
        a->movss(float_step_xmm, x86::dword_ptr(h, scratchReg1, 2));
        // *h + final_sum
        a->addss(float_step_xmm, partial_sum_xmm);
        // store h
        a->movss(x86::dword_ptr(h, scratchReg1, 2), float_step_xmm);
        // sqrt(hi)
        a->sqrtss(float_step_xmm, float_step_xmm);
        // broadcast to all lanes of the ymm/zmm reg
        a->vpbroadcastd(float_step_vreg, float_step_xmm);
        // float_step = lr / (sqrt(hi) + epsilon)
        a->vaddps(float_step_vreg, float_step_vreg, epsilon_vreg);
        a->vdivps(float_step_vreg, lr_vreg, float_step_vreg);

        a->imul(scratchReg1, static_cast<asmjit::Imm>(block_size));
        if (prefetch) {
          a->imul(scratchReg2, static_cast<asmjit::Imm>(block_size));
        }

        for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
             vec_idx += unroll_factor) {
          int cur_unroll_factor =
              std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

          // The main computation
          for (int v = 0; v < cur_unroll_factor; ++v) {
            vec_reg_t out_vreg = vec_reg_t(v + first_available_vec_reg_id);

            auto g_ptr =
                x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
            if (!areWeightsFp16) { // float weights
              auto w_ptr = x86::dword_ptr(
                  w, scratchReg1, 2, (vec_idx + v) * vlen * sizeof(dataType));
              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr);
                  a->vmulps(src_vreg, float_step_vreg, src_vreg);

                  a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
                  a->vaddps(out_vreg, src_vreg, out_vreg);

                  a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
                } else {
                  a->k(x86::k(1)).vmulps(out_vreg, float_step_vreg, g_ptr);
                  a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
                  a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
                }
              } else {
                a->vmulps(out_vreg, float_step_vreg, g_ptr);
                a->vaddps(out_vreg, out_vreg, w_ptr);
                a->vmovups(w_ptr, out_vreg);
              }
            } else { // float16 weights
              auto w_ptr = x86::word_ptr(
                  w, scratchReg1, 1, (vec_idx + v) * vlen * sizeof(dataType));

              if (use_stochastic_rounding) {
                // Each uint32 in R holds four 8-bit random bytes;
                // sr_idx in [0..3] selects which byte is extracted.
                int sr_idx = (vec_idx + v) % 4;

                if (sr_idx == 0) {
                  // Generate a fresh R every 4 steps of the
                  // num_vec_regs_per_block loop. Each 8-bit byte in R
                  // (uint32_t) is used exactly once: it is shifted to
                  // bits [5..13] and then added to the FP32 weights
                  // before the FP16 conversion.
                  //
                  // The shifted 8-bit region
                  // +-------+--------+--------+--------+
                  // |       |        |   xxxxx|xxx     |
                  // 31      23       15       7        0
                  //
                  // Half floats have 10 bits of mantissa while floats have
                  // 23, so the FP16 conversion drops the low 13 bits of the
                  // FP32 mantissa. Adding the random byte just below that
                  // cutoff before truncating effectively adds a random
                  // variable in [0, 1) units of the FP16 LSB.
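                  //
                  // A scalar sketch of the rounding step (illustrative
                  // names; the JIT emits the vector form below, where
                  // fp32_to_fp16_rtz stands for vcvtps2ph with truncation):
                  //   uint32_t u =
                  //       bit_cast<uint32_t>(w_fp32) + (rand_byte << 5);
                  //   w_fp16 = fp32_to_fp16_rtz(bit_cast<float>(u));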

                  // Random number generator using xoshiro128++
                  // Ref: http://prng.di.unimi.it/xoshiro128plusplus.c
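                  // For reference, one scalar xoshiro128++ step, which the
                  // vector code below mirrors lane-wise (rotl is a 32-bit
                  // rotate-left):
                  //   R = rotl(S0 + S3, 7) + S0;
                  //   t = S1 << 9;
                  //   S2 ^= S0; S3 ^= S1; S1 ^= S2; S0 ^= S3;
                  //   S2 ^= t;
                  //   S3 = rotl(S3, 11);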
                  a->vpaddd(r0_vreg, S0_vreg, S3_vreg);
                  a->vpslld(r1_vreg, r0_vreg, 7);
                  a->vpsrld(r0_vreg, r0_vreg, 25);
                  if (instSet == inst_set_t::avx2) {
                    a->vpor(R_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm());
                  } else {
                    a->vpord(R_vreg, r0_vreg, r1_vreg);
                  }
                  a->vpaddd(R_vreg, R_vreg, S0_vreg);

                  a->vpslld(r0_vreg, S1_vreg, 9);

                  if (instSet == inst_set_t::avx2) {
                    a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), S0_vreg.ymm());
                    a->vpxor(S3_vreg.ymm(), S3_vreg.ymm(), S1_vreg.ymm());
                    a->vpxor(S1_vreg.ymm(), S1_vreg.ymm(), S2_vreg.ymm());
                    a->vpxor(S0_vreg.ymm(), S0_vreg.ymm(), S3_vreg.ymm());

                    a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), r0_vreg.ymm());
                  } else {
                    a->vpxord(S2_vreg, S2_vreg, S0_vreg);
                    a->vpxord(S3_vreg, S3_vreg, S1_vreg);
                    a->vpxord(S1_vreg, S1_vreg, S2_vreg);
                    a->vpxord(S0_vreg, S0_vreg, S3_vreg);

                    a->vpxord(S2_vreg, S2_vreg, r0_vreg);
                  }
                  a->vpslld(r0_vreg, S3_vreg, 11);
                  a->vpsrld(r1_vreg, S3_vreg, 21);
                  if (instSet == inst_set_t::avx2) {
                    a->vpor(S3_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm());
                  } else {
                    a->vpord(S3_vreg, r0_vreg, r1_vreg);
                  }

                  // Extract byte 0 and shift to bits [5..13]
                  a->vpslld(r0_vreg, R_vreg, 24);
                  a->vpsrld(r0_vreg, r0_vreg, 19);
                } else if (sr_idx == 1) {
                  // Extract byte 1 and shift to bits [5..13]
                  a->vpsrld(r0_vreg, R_vreg, 8);
                  a->vpslld(r0_vreg, r0_vreg, 24);
                  a->vpsrld(r0_vreg, r0_vreg, 19);
                } else if (sr_idx == 2) {
                  // Extract byte 2 and shift to bits [5..13]
                  a->vpslld(r0_vreg, R_vreg, 8);
                  a->vpsrld(r0_vreg, r0_vreg, 24);
                  a->vpslld(r0_vreg, r0_vreg, 5);
                } else { // sr_idx == 3
                  // Extract byte 3 and shift to bits [5..13]
                  a->vpsrld(r0_vreg, R_vreg, 24);
                  a->vpslld(r0_vreg, r0_vreg, 5);
                }
              }

              if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
                if (instSet == inst_set_t::avx2) {
                  a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr);
                  // There is no AVX2 mask load/store for 16-bit data, so
                  // copy the input to the stack with a scalar loop instead,
                  // reusing the GPR that holds h
                  a->lea(x86::rsp, x86::ptr(x86::rsp, -8));
                  a->mov(x86::ptr(x86::rsp), h);
                  a->lea(
                      x86::rsp,
                      x86::ptr(
                          x86::rsp, static_cast<int>(-vlen * sizeof(float16))));
                  for (int r = 0; r < remainder; ++r) {
                    a->mov(
                        h.r16(),
                        x86::word_ptr(
                            w,
                            scratchReg1,
                            1,
                            ((vec_idx + v) * vlen + r) * sizeof(dataType)));
                    a->mov(x86::ptr(x86::rsp, sizeof(dataType) * r), h.r16());
                  }
                  a->vcvtph2ps(out_vreg, x86::word_ptr(x86::rsp));
                  a->vfmadd231ps(out_vreg, float_step_vreg, src_vreg);
                  if (use_stochastic_rounding) {
                    a->vpaddd(out_vreg, r0_vreg, out_vreg);
                  }
                  // Truncate rounding to counteract the random bits we added
                  a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 11);
                  // Copy results back
                  for (int r = 0; r < remainder; ++r) {
                    a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r));
                    a->mov(
                        x86::word_ptr(
                            w,
                            scratchReg1,
                            1,
                            ((vec_idx + v) * vlen + r) * sizeof(dataType)),
                        h.r16());
                  }
                  a->lea(
                      x86::rsp,
                      x86::ptr(
                          x86::rsp, static_cast<int>(vlen * sizeof(float16))));
                  a->mov(h, x86::ptr(x86::rsp));
                  a->lea(x86::rsp, x86::ptr(x86::rsp, 8));
                } else {
                  a->k(x86::k(1)).vcvtph2ps(out_vreg, w_ptr);
                  a->k(x86::k(1)).vfmadd231ps(out_vreg, float_step_vreg, g_ptr);
                  if (use_stochastic_rounding) {
                    a->vpaddd(out_vreg, r0_vreg, out_vreg);
                  }
                  // Truncate rounding
                  a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 11);
                }
              } else {
                a->vcvtph2ps(out_vreg, w_ptr);
                a->vfmadd231ps(out_vreg, float_step_vreg, g_ptr);
                if (use_stochastic_rounding) {
                  a->vpaddd(out_vreg, r0_vreg, out_vreg);
                }
                // Truncate rounding
                a->vcvtps2ph(w_ptr, out_vreg, 11);
              }
            }

            constexpr int CACHE_LINE_LEN = 64;
            constexpr int BYTES_PER_VLOAD = vlen * sizeof(dataType);
            constexpr int VLOAD_PER_CACHE_LINE =
                CACHE_LINE_LEN / BYTES_PER_VLOAD;
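            // e.g. with AVX2 and fp32 weights, each load covers
            // 8 * 4 B = 32 B, so prefetch once every 2 loads; with fp16
            // weights it covers 16 B, so once every 4 loads.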
            if (prefetch && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) {
              a->prefetchw(x86::dword_ptr(
                  w,
                  scratchReg2,
                  areWeightsFp16 ? 1 : 2,
                  (vec_idx + v) * BYTES_PER_VLOAD));
            }
          }
        }

        a->jmp(LoopDataIndexBegin);
        a->bind(LoopDataIndexEnd);

        a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType)));
        a->add(g, static_cast<asmjit::Imm>(grad_stride * sizeof(float)));

        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        a->cmp(indices, index_size);
        a->jne(error);
        a->mov(scratchReg1.r32(), 1);
        a->jmp(exit);
        a->bind(error);
        a->mov(scratchReg1.r32(), 0);
        a->bind(exit);

        if (areWeightsFp16 && use_stochastic_rounding) {
          if (instSet == inst_set_t::avx2) {
            a->vmovdqa(x86::dword_ptr(rand_buffer), S0_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)),
                S1_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)),
                S2_vreg.ymm());
            a->vmovdqa(
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)),
                S3_vreg.ymm());
          } else {
            a->vmovdqa32(x86::dword_ptr(rand_buffer), S0_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 1 * vlen * sizeof(uint32_t)),
                S1_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 2 * vlen * sizeof(uint32_t)),
                S2_vreg);
            a->vmovdqa32(
                x86::dword_ptr(rand_buffer, 3 * vlen * sizeof(uint32_t)),
                S3_vreg);
          }
        }

        a->mov(x86::eax, scratchReg1.r32());
        a->emitEpilog(frame);

        typename ReturnFunctionSignature<indxType, offsetType, dataType>::
            jit_sparse_adagrad_kernel fn;
        asmjit::Error err;
        {
          unique_lock<mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          cout << "Error: asmjit failed to add the generated kernel" << endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
} // getOrCreate

// Per-thread global buffer for random number generation, sized for the
// largest vector width
constexpr size_t VLEN_MAX = simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS;
alignas(64) static thread_local uint32_t g_rnd128v_buffer[4 * VLEN_MAX];
static thread_local bool g_rnd128v_initialized = false;
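// The buffer holds the four xoshiro128++ state vectors S0..S3 in blocks of
// VLEN_MAX lanes each, so every SIMD lane advances its own independently
// seeded stream.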

void rand_initialize() {
  // Splitmix64: http://prng.di.unimi.it/splitmix64.c
  auto rnd128_init_next = [](uint64_t& x) {
    uint64_t z = (x += 0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
    return z ^ (z >> 31);
  };

  if (!g_rnd128v_initialized) {
    uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id());
    for (auto i = 0; i < 4; ++i) {
      g_rnd128v_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
      uint64_t h1 = g_rnd128v_buffer[i * VLEN_MAX];
      for (size_t v = 1; v < VLEN_MAX; ++v) {
        g_rnd128v_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
      }
    }
    g_rnd128v_initialized = true;
  }
}

} // namespace

template <typename IndexType, typename OffsetType, typename DataType>
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
    IndexType,
    OffsetType,
    DataType>::Type
GenerateRowWiseSparseAdaGradFused(
    int block_size, // number of parameters per row
    int prefetch,
    bool use_offsets,
    bool use_stochastic_rounding,
    int grad_stride) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }
  if (grad_stride == -1) {
    grad_stride = block_size;
  }

  // Use avx512 only for fp16 + stochastic rounding
  if (fbgemmHasAvx512Support() && std::is_same<DataType, float16>::value &&
      use_stochastic_rounding) {
    static GenRowWiseSparseAdagradFused<
        IndexType,
        OffsetType,
        DataType,
        inst_set_t::avx512>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        nullptr,
        block_size,
        prefetch,
        use_offsets,
        use_stochastic_rounding,
        grad_stride);
    const auto lambda_func = [=](int64_t output_size,
                                 int64_t index_size,
                                 int64_t data_size,
                                 DataType* w,
                                 const float* g,
                                 float* h,
                                 const IndexType* indices,
                                 const OffsetType* offsets_or_lengths,
                                 float epsilon,
                                 float lr) {
      // Initialize the random buffer on the first execution
      // TODO: JIT
      if (std::is_same<DataType, float16>::value && use_stochastic_rounding) {
        rand_initialize();
      }

      return original_func(
          output_size,
          index_size,
          data_size,
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          offsets_or_lengths,
          epsilon,
          lr,
          g_rnd128v_buffer);
    };
    return lambda_func;
  } else if (fbgemmHasAvx2Support()) {
    static GenRowWiseSparseAdagradFused<
        IndexType,
        OffsetType,
        DataType,
        inst_set_t::avx2>
        kernel_generator;
    const auto original_func = kernel_generator.getOrCreate(
        internal::avx2_ps_or_epi32_combined_mask,
        block_size,
        prefetch,
        use_offsets,
        use_stochastic_rounding,
        grad_stride);
    const auto lambda_func = [=](int64_t output_size,
                                 int64_t index_size,
                                 int64_t data_size,
                                 DataType* w,
                                 const float* g,
                                 float* h,
                                 const IndexType* indices,
                                 const OffsetType* offsets_or_lengths,
                                 float epsilon,
                                 float lr) {
      // Initialize the random buffer on the first execution
      // TODO: JIT
      if (std::is_same<DataType, float16>::value && use_stochastic_rounding) {
        rand_initialize();
      }

      return original_func(
          output_size,
          index_size,
          data_size,
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          offsets_or_lengths,
          epsilon,
          lr,
          g_rnd128v_buffer);
    };
    return lambda_func;
  } else {
    return [=](int64_t output_size,
               int64_t index_size,
               int64_t data_size,
               DataType* w,
               const float* g,
               float* h,
               const IndexType* indices,
               const OffsetType* offsets_or_lengths,
               float epsilon,
               float lr) {
      return rowwise_sparse_adagrad_fused_ref(
          block_size,
          output_size,
          index_size,
          data_size,
          w,
          g,
          h,
          indices,
          offsets_or_lengths,
          epsilon,
          lr,
          use_offsets,
          use_stochastic_rounding,
          /*emu_vector_size=*/8,
          grad_stride);
    };
  }
}
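
// Example usage (illustrative values; with use_offsets == true,
// offsets_or_lengths points at output_size + 1 offsets, and the returned
// callable yields false on an out-of-bounds index):
//
//   auto fn = GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float>(
//       /*block_size=*/64, /*prefetch=*/16, /*use_offsets=*/true,
//       /*use_stochastic_rounding=*/true, /*grad_stride=*/-1);
//   bool ok = fn(
//       output_size, index_size, data_size, w, g, h, indices, offsets,
//       /*epsilon=*/1e-5f, /*lr=*/0.1f);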

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int32_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float>::Type
    GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float>(
        int block_size, // number of parameters per row
        int prefetch,
        bool use_offsets,
        bool use_stochastic_rounding,
        int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int32_t, float16>::
        Type
        GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float16>(
            int block_size, // number of parameters per row
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float16>::
        Type
        GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float16>(
            int block_size, // number of parameters per row
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float16>::
        Type
        GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float16>(
            int block_size, // number of parameters per row
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride);

template FBGEMM_API
    typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float16>::
        Type
        GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float16>(
            int block_size, // number of parameters per row
            int prefetch,
            bool use_offsets,
            bool use_stochastic_rounding,
            int grad_stride);

} // namespace fbgemm