1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmEmbedding.h" |
9 | |
10 | #include <asmjit/asmjit.h> |
11 | #include <cpuinfo.h> |
12 | #include <cmath> |
13 | #include <iostream> |
14 | #include <mutex> |
15 | #include <string> |
16 | #include <tuple> |
17 | #include "./CodeCache.h" |
18 | #include "./MaskAvx2.h" |
19 | #include "./RefImplementations.h" |
20 | #include "fbgemm/SimdUtils.h" |
21 | #include "fbgemm/Utils.h" |
22 | |
23 | namespace fbgemm { |
24 | |
25 | namespace { |
26 | namespace x86 = asmjit::x86; |
27 | |
// Holder for the function-pointer type of a JIT'ed SparseAdagrad kernel,
// parameterized on the index type (int64_t or int32_t).
template <typename indxType = std::int64_t>
class ReturnFunctionSignature {
 public:
  // The generated kernel returns num_rows when every row was processed, or
  // the index of the first row whose parameter slice would run past
  // param_size (see the bounds check emitted in getOrCreate).
  using jit_sparse_adagrad_kernel = int (*)(
      int num_rows, // number of rows reading
      std::uint64_t param_size, // total number of parameters
      float* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      float epsilon,
      float lr,
      const int* mask_avx2, // AVX2 remainder mask (ignored by AVX512 kernels)
      float weight_decay,
      const double* counter, // optional per-row update counters (may be null)
      std::int64_t counter_halflife);
};
45 | |
// JIT code generator for SparseAdagrad / RowwiseSparseAdagrad kernels.
// One instantiation per (index type, instruction set); generated kernels are
// cached by (block_size, prefetch, rowwise, has_weight_decay).
template <
    typename indxType = std::int64_t,
    inst_set_t instSet = inst_set_t::avx2>
class GenSparseAdagrad {
 public:
  GenSparseAdagrad() {}

  // Emits the per-row element-wise SparseAdagrad update loop into `a`.
  // Called once per row inside the row loop built by getOrCreate.
  void genSparseAdagrad(
      x86::Emitter* a,
      int unroll_factor, // number of vector registers available for unrolling
      int num_vec_regs_per_block, // ceil(block_size / vlen)
      int remainder, // block_size % vlen; nonzero selects masked tail ops
      int prefetch, // prefetch distance in rows; 0 disables prefetching
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg, // AVX2 remainder mask register
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  // Emits the per-row rowwise update: one shared Adagrad step per row derived
  // from the mean of squared gradients, then a scaled weight update.
  void genRowwiseSparseAdagrad(
      x86::Emitter* a,
      int block_size,
      int unroll_factor,
      int num_vec_regs_per_block,
      int remainder,
      int prefetch,
      typename simd_info<instSet>::vec_reg_t epsilon_vreg,
      typename simd_info<instSet>::vec_reg_t lr_vreg,
      x86::Ymm mask_vreg,
      typename simd_info<instSet>::vec_reg_t temp_vreg,
      typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
      bool has_weight_decay);

  // Returns a cached kernel for the given configuration, generating and
  // adding it to the JIT runtime on first use.
  typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
  getOrCreate(
      int block_size,
      int prefetch,
      bool rowwise,
      bool has_weight_decay);

 private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; // JIT Runtime for asmjit
    return rt;
  }

  static std::mutex rtMutex_; /// Controls access to runtime;

  // The hash depends on embedding dimension (block size), prefetch distance,
  // rowwise, and has_weight_decay
  static CodeCache<
      std::tuple<int, int, bool, bool>,
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.

  // These are registers we share across SparseAdagrad and RowwiseSparseAdagrad
  x86::Gp w;
  x86::Gp g;
  x86::Gp h;
  x86::Gp indices;
  x86::Gp base_offset; // byte offset of the current row inside w/h
  x86::Gp temp1_; // loop counter
  x86::Gp temp2_; // prefetch offset
  x86::Gp temp3_; // prefetch offset of moment in rowwise adagrad

  // AVX512 opmask used for the tail of the gradient-square-sum reduction.
  x86::KReg reduce_mask_avx512_;
}; // GenSparseAdagrad
113 | |
// Out-of-line definitions of the per-instantiation static members declared in
// GenSparseAdagrad (one mutex and one code cache per <indxType, instSet>).
template <typename indxType, inst_set_t instSet>
std::mutex GenSparseAdagrad<indxType, instSet>::rtMutex_;

template <typename indxType, inst_set_t instSet>
CodeCache<
    std::tuple<int, int, bool, bool>,
    typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel>
    GenSparseAdagrad<indxType, instSet>::codeCache_;
122 | |
// Emits the element-wise SparseAdagrad update for the current row:
//   if has_weight_decay: g += weight_decay * w
//   h += g * g                      (momentum update, stored back)
//   w += lr * g / (sqrt(h) + epsilon)
// The row is processed num_vec_regs_per_block vectors at a time, unrolled by
// unroll_factor. Each unrolled step uses a register pair: out_vreg for the
// accumulator and g_vreg for the (possibly decayed) gradient. When
// `remainder` is nonzero, the last vector of the row uses masked loads/stores:
// vmaskmovps + mask_vreg on AVX2, k(1)-predicated ops on AVX512. When
// `prefetch` is enabled, prefetchw of the next row's w/h cache lines is
// interleaved (temp2_ holds the prefetch row's byte offset).
template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genSparseAdagrad(
    x86::Emitter* a,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  // NOTE: temp_vreg is defined only when remainder is true and instSet == avx2
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);
      vec_reg_t g_vreg = vec_reg_t(v + cur_unroll_factor);

      // One prefetch per 64-byte cache line, not per vector.
      if (prefetch && ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        // Intel SDE (wrongly) thinks prefetchwt1 is not available in BDW
        a->prefetchw(
            x86::dword_ptr(h, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));

        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      // g is pre-advanced per row in the outer loop, so it is indexed without
      // base_offset; w and h are indexed off the row's base_offset.
      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto h_ptr = x86::dword_ptr(
          h, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        // Masked tail: only `remainder` lanes of the last vector are valid.
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(g_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            // TODO(@taiqing) use a vreg for weights to avoid duplicate indexing
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(g_vreg, temp_vreg, weight_decay_vreg);
          }
          a->vmulps(out_vreg, g_vreg, g_vreg);
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, h_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(h_ptr, mask_vreg, out_vreg.ymm());

          a->vsqrtps(out_vreg, out_vreg);
          a->vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->vmulps(g_vreg, lr_vreg, g_vreg);
          a->vdivps(out_vreg, g_vreg, out_vreg);

          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, out_vreg, temp_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else if (instSet == inst_set_t::avx512) {
          // k(1) was loaded with the remainder mask in getOrCreate.
          a->k(x86::k(1)).vmovups(g_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(x86::k(1)).vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
          }
          a->k(x86::k(1)).vmulps(out_vreg, g_vreg, g_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, h_ptr);

          a->k(x86::k(1)).vmovups(h_ptr, out_vreg);

          a->k(x86::k(1)).vsqrtps(out_vreg, out_vreg);
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, epsilon_vreg);

          a->k(x86::k(1)).vmulps(g_vreg, lr_vreg, g_vreg);
          a->k(x86::k(1)).vdivps(out_vreg, g_vreg, out_vreg);

          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);

          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        // Full-vector path: same sequence without masking.
        a->vmovups(g_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(g_vreg, weight_decay_vreg, w_ptr);
        }
        a->vmulps(out_vreg, g_vreg, g_vreg);
        a->vaddps(out_vreg, out_vreg, h_ptr);

        a->vmovups(h_ptr, out_vreg);

        a->vsqrtps(out_vreg, out_vreg);
        a->vaddps(out_vreg, out_vreg, epsilon_vreg);

        a->vmulps(g_vreg, lr_vreg, g_vreg);
        a->vdivps(out_vreg, g_vreg, out_vreg);

        a->vaddps(out_vreg, out_vreg, w_ptr);

        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}
229 | |
// Emits the rowwise SparseAdagrad update for the current row, in two passes:
//   1) Reduce sum((g + weight_decay*w)^2) over the row with AVX2 registers
//      (vhaddps is AVX2-only), divide by block_size, accumulate into the
//      row's single momentum h[idx], and broadcast
//      float_step = lr / (sqrt(h[idx]) + epsilon).
//   2) w += float_step * (g + weight_decay*w) element-wise over the row.
// base_offset is recomputed via imul on indices_ptr between phases because
// w is indexed by idx*block_size*4 bytes while h is indexed by idx*4 bytes.
template <typename indxType, inst_set_t instSet>
void GenSparseAdagrad<indxType, instSet>::genRowwiseSparseAdagrad(
    x86::Emitter* a,
    int block_size,
    int unroll_factor,
    int num_vec_regs_per_block,
    int remainder,
    int prefetch,
    typename simd_info<instSet>::vec_reg_t epsilon_vreg,
    typename simd_info<instSet>::vec_reg_t lr_vreg,
    x86::Ymm mask_vreg,
    typename simd_info<instSet>::vec_reg_t temp_vreg,
    typename simd_info<instSet>::vec_reg_t weight_decay_vreg,
    bool has_weight_decay) {
  typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;
  constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;

  // Reduce the unroll factor by 1 for partial sum
  --unroll_factor;
  vec_reg_t partial_sum_vreg = vec_reg_t(unroll_factor);

  if (prefetch) {
    // temp3_ holds the prefetch row's byte offset into h (set in getOrCreate).
    a->prefetchw(x86::dword_ptr(h, temp3_));
  }

  bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
  auto indices_ptr = areIndices64b
      ? x86::qword_ptr(
            indices, temp1_, 3) // use of 3 is to multiply by 8 (int64_t)
      : x86::dword_ptr(
            indices, temp1_, 2); // use of 2 is to multiply by 4 (int32_t)
  if (has_weight_decay) {
    // set base_offset for fetching w in the calculation of gradient square sum
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(block_size * sizeof(float)));
  }

  // Even with avx512, we only need to use avx2 registers when computing
  // partial_sum because some instructions we're using like vhaddps
  // are only in avx2.
  constexpr int vlen_avx2 = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
  int num_vec_regs_per_block_avx2 = (block_size + vlen_avx2 - 1) / vlen_avx2;

  // Use YMM/XMMs with smaller ids for AVX2 specific instructions like vhaddps
  x86::Ymm partial_sum_vreg_avx2(0);
  x86::Xmm partial_sum_xmm0(partial_sum_vreg_avx2.id());

  a->vxorps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  // TODO: need to do a tree-reduction to fully take advantage of unrolling
  // Pass 1: accumulate squared (decayed) gradients into partial_sum_vreg_avx2.
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block_avx2;
       vec_idx += unroll_factor - 1) {
    int cur_unroll_factor =
        std::min(unroll_factor - 1, num_vec_regs_per_block_avx2 - vec_idx);
    for (int v = 0; v < cur_unroll_factor; ++v) {
      // Ymm(0) is the accumulator, so per-step temporaries start at Ymm(1).
      x86::Ymm out_vreg = x86::Ymm(v + 1);
      if (has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen_avx2 * sizeof(float))) == 0)) {
        a->prefetchw(x86::dword_ptr(
            w, temp2_, 0, (vec_idx + v) * vlen_avx2 * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen_avx2 * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen_avx2 * sizeof(float));
      if (block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS &&
          vec_idx + v == num_vec_regs_per_block_avx2 - 1) {
        // Masked tail of the reduction (AVX2 remainder granularity).
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(out_vreg, mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(temp_vreg.ymm(), mask_vreg, w_ptr);
            a->vfmadd231ps(out_vreg, temp_vreg, weight_decay_vreg);
          }
        } else {
          // z() zeroes masked-out lanes so they don't pollute the sum.
          a->k(reduce_mask_avx512_).z().vmovups(out_vreg, g_ptr);
          if (has_weight_decay) {
            a->k(reduce_mask_avx512_)
                .z()
                .vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          }
        }
      } else {
        a->vmovups(out_vreg, g_ptr);
        if (has_weight_decay) {
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
        }
      }
      a->vmulps(out_vreg, out_vreg, out_vreg);
      a->vaddps(partial_sum_vreg_avx2, partial_sum_vreg_avx2, out_vreg);
    }
  }
  // Reduce sum to 1 value
  // __m256 partial_sum_2 = _mm256_hadd_ps(partial_sum, partial_sum);
  // __m256 partial_sum_3 = _mm256_hadd_ps(partial_sum_2, partial_sum_2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);
  a->vhaddps(
      partial_sum_vreg_avx2, partial_sum_vreg_avx2, partial_sum_vreg_avx2);

  x86::Xmm partial_sum_xmm1(1);

  //_mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3))
  a->movss(partial_sum_xmm1, partial_sum_xmm0);
  //_mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1))
  a->vextractf128(partial_sum_xmm0, partial_sum_vreg_avx2, 1);

  // final_sum = _mm_cvtss_f32(_mm256_castps256_ps128(partial_sum_3)) +
  //     _mm_cvtss_f32(_mm256_extractf128_ps(partial_sum_3, 1));
  a->addss(partial_sum_xmm0, partial_sum_xmm1);

  // This fragment moves block size (N) to stack and bcasts it to xmm reg
  a->lea(
      x86::rsp,
      x86::dword_ptr(x86::rsp, -1 * static_cast<int>(sizeof(int32_t))));
  a->mov(x86::dword_ptr(x86::rsp), block_size);
  a->vbroadcastss(
      partial_sum_xmm1, x86::dword_ptr(x86::rsp)); // N is partial_sum_xmm1
  a->vcvtdq2ps(partial_sum_xmm1, partial_sum_xmm1);
  a->lea(x86::rsp, x86::dword_ptr(x86::rsp, sizeof(int32_t)));

  if (has_weight_decay) {
    // set base_offset for fetching h
    a->imul(
        areIndices64b ? base_offset : base_offset.r32(),
        indices_ptr,
        static_cast<asmjit::Imm>(sizeof(float)));
  }

  // final_sum /= N
  a->divss(partial_sum_xmm0, partial_sum_xmm1);
  // load h
  a->movss(partial_sum_xmm1, x86::dword_ptr(h, base_offset));
  // *h + final_sum
  a->addss(partial_sum_xmm0, partial_sum_xmm1);
  // store h
  a->movss(x86::dword_ptr(h, base_offset), partial_sum_xmm0);
  // sqrt(hi)
  a->sqrtss(partial_sum_xmm0, partial_sum_xmm0);
  // bcast partial to all of ymm/zmm reg
  a->vpbroadcastd(partial_sum_vreg, partial_sum_xmm0);
  // lr / sqrt(hi) + epsilon
  a->vaddps(partial_sum_vreg, partial_sum_vreg, epsilon_vreg);
  a->vdivps(partial_sum_vreg, lr_vreg, partial_sum_vreg);
  // partial_sum_vreg now has float_step

  // set base_offset for fetching w in updating weights
  a->imul(
      areIndices64b ? base_offset : base_offset.r32(),
      indices_ptr,
      static_cast<asmjit::Imm>(block_size * sizeof(float)));

  // Pass 2: w += float_step * (g [+ weight_decay * w]).
  for (int vec_idx = 0; vec_idx < num_vec_regs_per_block;
       vec_idx += unroll_factor) {
    int cur_unroll_factor =
        std::min(unroll_factor, num_vec_regs_per_block - vec_idx);

    for (int v = 0; v < cur_unroll_factor; ++v) {
      vec_reg_t out_vreg = vec_reg_t(v);

      // With weight decay the w lines were already prefetched in pass 1.
      if (!has_weight_decay && prefetch &&
          ((vec_idx + v) % (64 / (vlen * sizeof(float))) == 0)) {
        a->prefetchw(
            x86::dword_ptr(w, temp2_, 0, (vec_idx + v) * vlen * sizeof(float)));
      }

      auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float));
      auto w_ptr = x86::dword_ptr(
          w, base_offset, 0, (vec_idx + v) * vlen * sizeof(float));
      if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
        if (instSet == inst_set_t::avx2) {
          a->vmaskmovps(temp_vreg.ymm(), mask_vreg, g_ptr);
          if (has_weight_decay) {
            a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
            // TODO(@taiqing): have vreg for weights
            a->vfmadd231ps(temp_vreg, weight_decay_vreg, out_vreg);
          }
          a->vmulps(temp_vreg, partial_sum_vreg, temp_vreg);

          a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr);
          a->vaddps(out_vreg, temp_vreg, out_vreg);

          a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm());
        } else {
          if (has_weight_decay) {
            a->k(x86::k(1)).vmovups(out_vreg, g_ptr);
            a->k(x86::k(1)).vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, out_vreg);
          } else {
            a->k(x86::k(1)).vmulps(out_vreg, partial_sum_vreg, g_ptr);
          }
          a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr);
          a->k(x86::k(1)).vmovups(w_ptr, out_vreg);
        }
      } else {
        if (has_weight_decay) {
          a->vmovups(out_vreg, g_ptr);
          a->vfmadd231ps(out_vreg, weight_decay_vreg, w_ptr);
          a->vmulps(out_vreg, partial_sum_vreg, out_vreg);
        } else {
          a->vmulps(out_vreg, partial_sum_vreg, g_ptr);
        }
        a->vaddps(out_vreg, out_vreg, w_ptr);
        a->vmovups(w_ptr, out_vreg);
      }
    }
  }
}
440 | |
// Looks up (or JIT-generates and caches) a SparseAdagrad kernel for the given
// (block_size, prefetch, rowwise, has_weight_decay) configuration. The
// generated function implements the row loop: for each row i it computes the
// row's byte offset from indices[i], bounds-checks the row against
// param_size (returning i on overflow), optionally scales weight decay by
// counter_halflife / counter[idx], emits optional prefetching for row
// i + prefetch, and then delegates the element-wise work to
// genSparseAdagrad or genRowwiseSparseAdagrad.
template <typename indxType, inst_set_t instSet>
typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
GenSparseAdagrad<indxType, instSet>::getOrCreate(
    int block_size,
    int prefetch,
    bool rowwise,
    bool has_weight_decay) {
  std::tuple<int, int, bool, bool> kernelSig =
      std::make_tuple(block_size, prefetch, rowwise, has_weight_decay);

  return codeCache_.getOrCreate(
      kernelSig,
      [&]() ->
      typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel {
        asmjit::CodeHolder code;
        code.init(runtime().environment());
        x86::Assembler assembler(&code);
        x86::Emitter* a = assembler.as<x86::Emitter>();
        bool areIndices64b = std::is_same<indxType, std::int64_t>::value;
#if defined(FBGEMM_LOG_CODE)
        // Log the generated assembly to a per-configuration file.
        // NOTE(review): codeLogFile/codeLogger are not released on the
        // error-return path below — confirm whether that leak is acceptable
        // for this debug-only build flag.
        std::string filename = "SparseAdagrad";
        filename += "_emd_dim_" + std::to_string(block_size);
        if (rowwise) {
          filename += "_rowwise";
        }
        filename += areIndices64b ? "_64bit" : "_32bit";
        filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2";
        if (prefetch) {
          filename += "_prefetch";
        }
        if (has_weight_decay) {
          filename += "weight_decay";
        }
        filename += ".txt";
        FILE* codeLogFile = fopen(filename.c_str(), "w");
        asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
        code.setLogger(codeLogger);
#endif

        // Map the function arguments onto registers (System V order).
        x86::Gpd num_rows = a->zdi().r32();
        x86::Gp param_size = a->zsi();
        w = a->zdx();
        g = a->zcx();
        h = a->gpz(8);
        indices = a->gpz(9);
        x86::Xmm epsilon(0);
        x86::Xmm lr(1);
        x86::Gp mask_avx2 = a->gpz(10);
        x86::Xmm weight_decay(2);
        x86::Gp counter = a->gpz(11);
        x86::Gp counter_halflife = a->gpz(12);

        // reuse mask_avx2 because mask_avx2 is used only at the beginning
        base_offset = a->gpz(10);
        temp1_ = a->gpz(13);
        temp2_ = a->gpz(14);
        temp3_ = a->gpz(15);

        asmjit::FuncDetail func;
        func.init(
            asmjit::FuncSignatureT<
                int, // return type
                int, // num rows
                std::uint64_t, // param_size
                float*, // w
                const float*, // g
                float*, // h
                const indxType*, // indices
                float, // epsilon
                float, // lr
                const int*, // mask_avx2
                float, // weight_decay
                const double*, // counter then counter_halflife
                std::int64_t>(asmjit::CallConvId::kHost),
            a->environment());

        asmjit::FuncFrame frame;
        frame.init(func);

        // Declare every vector register we may clobber so the prolog/epilog
        // save and restore callee-saved state correctly.
        if (instSet == inst_set_t::avx2) {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
        } else {
          frame.setDirtyRegs(
              asmjit::RegGroup::kVec,
              asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
                  asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
                  asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
                  asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
        }

        frame.setDirtyRegs(
            asmjit::RegGroup::kGp,
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

        asmjit::FuncArgsAssignment args(&func);
        args.assignAll(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            mask_avx2,
            weight_decay,
            counter,
            counter_halflife);

        args.updateFuncFrame(frame);
        frame.finalize();
        a->emitProlog(frame);
        a->emitArgsAssignment(frame, args);

        constexpr int vlen = simd_info<instSet>::WIDTH_32BIT_ELEMS;
        constexpr int NUM_VEC_REG = simd_info<instSet>::NUM_VEC_REGS;
        int unroll_factor = NUM_VEC_REG;

        typedef typename simd_info<instSet>::vec_reg_t vec_reg_t;

        int num_vec_regs_per_block = (block_size + vlen - 1) / vlen;
        int remainder = block_size % vlen;

        // Reserve the highest-numbered vector registers for loop-invariant
        // constants; whatever remains is used for unrolling.
        vec_reg_t epsilon_vreg;
        vec_reg_t lr_vreg;
        vec_reg_t weight_decay_vreg;
        vec_reg_t adjusted_weight_decay_vreg;
        x86::Ymm mask_vreg; // mask for avx2
        vec_reg_t
            temp_vreg; // temp vreg for avx2 to handle remainder computation

        --unroll_factor;
        epsilon_vreg = vec_reg_t(unroll_factor);
        --unroll_factor;
        lr_vreg = vec_reg_t(unroll_factor);
        if (has_weight_decay) {
          --unroll_factor;
          weight_decay_vreg = vec_reg_t(unroll_factor);
          --unroll_factor;
          adjusted_weight_decay_vreg = vec_reg_t(unroll_factor);
        }

        if (remainder) {
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            temp_vreg = vec_reg_t(unroll_factor);
          }

          // Creating masks for non multiples of vlen iterations
          if (instSet == inst_set_t::avx2) {
            --unroll_factor;
            mask_vreg = x86::Ymm(unroll_factor);
            a->vmovups(mask_vreg, x86::dword_ptr(mask_avx2));
          } else {
            // AVX512 uses an opmask instead of burning a vector register.
            a->mov(temp1_, (1 << remainder) - 1);
            a->kmovw(x86::k(1), temp1_);
          }
        }
        // Need an extra mask for computing sum of gradients
        int remainder_avx2 =
            block_size % simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
        if (remainder_avx2 && instSet == inst_set_t::avx512 && rowwise) {
          reduce_mask_avx512_ = x86::k(2);
          a->mov(temp1_, (1 << remainder_avx2) - 1);
          a->kmovw(reduce_mask_avx512_, temp1_);
        }

        if (!rowwise) {
          unroll_factor = unroll_factor / 2; // account for g_vreg
        }

        asmjit::Label exit = a->newLabel();
        asmjit::Label LoopRangeIndexBegin = a->newLabel();
        asmjit::Label LoopRangeIndexEnd = a->newLabel();

        // Broadcast the scalar hyper-parameters into vector registers.
        a->vpbroadcastd(epsilon_vreg, epsilon);
        a->vpbroadcastd(lr_vreg, lr);
        if (has_weight_decay) {
          a->vpbroadcastd(weight_decay_vreg, weight_decay);
        }

        a->xor_(temp1_, temp1_);

        a->bind(LoopRangeIndexBegin);
        a->cmp(temp1_.r32(), num_rows); // temp1_ is the loop trip counter
        a->jge(LoopRangeIndexEnd);

        auto indices_ptr = areIndices64b
            ? x86::qword_ptr(
                  indices, temp1_, 3) // use of 3 is to multiply by 8 (int64_t)
            : x86::dword_ptr(
                  indices, temp1_, 2); // use of 2 is to multiply by 4 (int32_t)
        // base_offset = indices[i] * row stride in bytes (1 float per row for
        // the rowwise momentum table, block_size floats otherwise).
        a->imul(
            areIndices64b ? base_offset : base_offset.r32(),
            indices_ptr,
            static_cast<asmjit::Imm>(
                (rowwise ? 1 : block_size) * sizeof(float)));

        // Perform this check
        // if (block_size + offsetIdx > param_size) {
        //   return i;
        // }
        if (areIndices64b) {
          a->mov(temp2_, indices_ptr);
        } else {
          a->mov(temp2_.r32(), indices_ptr);
        }

        if (has_weight_decay) {
          // Check counter != nullptr && counter[idx] > 0
          a->vmovaps(adjusted_weight_decay_vreg, weight_decay_vreg);

          asmjit::Label skip_adjust_freq = a->newLabel();

          a->cmp(counter, 0);
          a->je(skip_adjust_freq);

          // temp3_ : counter[idx]
          a->mov(temp3_, x86::qword_ptr(counter, temp2_, 3));
          a->cmp(temp3_, 0);
          a->jle(skip_adjust_freq);

          // adjusted_weight_decay *= counter_halflife / counter[idx],
          // computed in double precision then narrowed to float.
          // OK to use Xmm registers with small ids that are reserved for temp
          // values in the inner most loop.
          vec_reg_t counter_halflife_vreg(0);
          x86::Xmm counter_vreg(1);
          a->cvtsi2sd(counter_halflife_vreg.xmm(), counter_halflife);
          a->movq(counter_vreg, temp3_);
          a->divpd(counter_halflife_vreg.xmm(), counter_vreg);
          a->vcvtpd2ps(
              counter_halflife_vreg.xmm(), counter_halflife_vreg.ymm());
          a->vbroadcastss(counter_halflife_vreg, counter_halflife_vreg.xmm());
          a->vmulps(
              adjusted_weight_decay_vreg,
              adjusted_weight_decay_vreg,
              counter_halflife_vreg);

          a->bind(skip_adjust_freq);
        }

        a->inc(temp2_);
        a->imul(
            temp2_,
            static_cast<asmjit::Imm>(block_size)); //(offsetIdx+1)*blocksize
        a->cmp(temp2_, param_size);
        a->jg(exit);

        if (prefetch) {
          // Compute byte offsets for row i + prefetch; if that row is past
          // the end, fall back to prefetching the current row.
          asmjit::Label pref_dist_reset_start = a->newLabel();
          asmjit::Label pref_dist_reset_end = a->newLabel();

          a->mov(temp2_, temp1_);
          a->add(temp2_, prefetch);
          a->cmp(temp2_.r32(), num_rows);
          a->jge(pref_dist_reset_start);

          auto pref_indices_ptr = areIndices64b
              ? x86::qword_ptr(indices, temp2_, 3)
              : x86::dword_ptr(indices, temp2_, 2);
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                pref_indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              pref_indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));

          a->jmp(pref_dist_reset_end);

          a->bind(pref_dist_reset_start);
          a->imul(
              areIndices64b ? temp2_ : temp2_.r32(),
              indices_ptr,
              static_cast<asmjit::Imm>(block_size * sizeof(float)));
          if (rowwise) {
            a->imul(
                areIndices64b ? temp3_ : temp3_.r32(),
                indices_ptr,
                static_cast<asmjit::Imm>(sizeof(float)));
          }

          a->bind(pref_dist_reset_end);
        } // prefetch

        // Emit the element-wise body for this row.
        if (rowwise) {
          genRowwiseSparseAdagrad(
              a,
              block_size,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        } else {
          genSparseAdagrad(
              a,
              unroll_factor,
              num_vec_regs_per_block,
              remainder,
              prefetch,
              epsilon_vreg,
              lr_vreg,
              mask_vreg,
              temp_vreg,
              adjusted_weight_decay_vreg,
              has_weight_decay);
        }

        // Advance the gradient pointer to the next row and loop.
        a->add(g, static_cast<asmjit::Imm>(block_size * sizeof(float)));
        a->inc(temp1_);
        a->jmp(LoopRangeIndexBegin);
        a->bind(LoopRangeIndexEnd);

        // Return the number of rows processed (== num_rows on full success).
        a->bind(exit);
        a->mov(x86::eax, temp1_.r32());
        a->emitEpilog(frame);

        typename ReturnFunctionSignature<indxType>::jit_sparse_adagrad_kernel
            fn;
        asmjit::Error err;
        {
          // runtime() is shared across threads; serialize code addition.
          std::unique_lock<std::mutex> lock(rtMutex_);
          err = runtime().add(&fn, &code);
        }
        if (err) {
          std::cout << "Error: in fn add" << std::endl;
          return nullptr;
        }

#if defined(FBGEMM_LOG_CODE)
        fclose(codeLogFile);
        delete codeLogger;
#endif
        return fn;
      });
} // getOrCreate
788 | |
789 | // Specialization for block size 1 internally called by GenerateSparseAdaGrad |
// Specialization for block size 1 internally called by GenerateSparseAdaGrad.
// Scalar reference path: each "row" is a single parameter, so the rowwise and
// element-wise variants differ only in the rounding order of the update.
// Returns num_rows on success, or the index of the first row whose index
// falls outside [0, param_size).
template <typename IndexType>
int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const IndexType* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter, // optional per-index counters (may be null)
    std::int64_t counter_halflife) {
  // Accumulates the squared gradient into the momentum and applies the
  // Adagrad step for one parameter. The two expressions are algebraically
  // identical but kept distinct to preserve each variant's float rounding.
  const auto apply_update = [&](IndexType row, float grad) {
    const float moment = h[row] = h[row] + grad * grad;
    const float denom = std::sqrt(moment) + epsilon;
    if (rowwise) {
      w[row] += lr / denom * grad;
    } else {
      w[row] += lr * grad / denom;
    }
  };

  // The two loops are kept separate so the weight_decay == 0 path never
  // dereferences `counter` and skips the fma entirely.
  if (weight_decay != 0.0f) {
    for (int row = 0; row < num_rows; ++row) {
      const IndexType idx = indices[row];
      if (idx >= static_cast<int64_t>(param_size)) {
        return row;
      }

      // Scale decay by update frequency when counters are provided.
      const float freq = (counter && counter[idx] > 0)
          ? counter_halflife / counter[idx]
          : 1.0f;
      apply_update(idx, std::fma(freq * weight_decay, w[idx], g[row]));
    }
  } else {
    for (int row = 0; row < num_rows; ++row) {
      const IndexType idx = indices[row];
      if (idx >= static_cast<int64_t>(param_size)) {
        return row;
      }
      apply_update(idx, g[row]);
    }
  }
  return num_rows;
}
839 | |
// Explicit instantiations for the two supported index widths so the
// definitions are emitted in this translation unit.
template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);

template int SparseAdaGradBlockSize1_(
    int num_rows, // number of rows reading
    std::uint64_t param_size, // total number of parameters
    float* w, // input/output parameters
    const float* g, // input gradients
    float* h, // input/output momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    bool rowwise,
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife);
867 | |
868 | } // namespace |
869 | |
// Factory for a SparseAdaGrad kernel closure. Dispatch order:
//   * block_size == 1 with AVX2/AVX512 -> scalar SparseAdaGradBlockSize1_;
//   * AVX2/AVX512 -> JIT-generated kernel (the mask pointer for the
//     remainder lanes is bound into the returned closure);
//   * otherwise -> reference implementations.
// The returned closure matches SparseAdaGradSignature, which does not take
// the mask argument of the raw JIT kernel.
template <typename IndexType>
typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
    int block_size,
    bool rowwise,
    int prefetch,
    bool use_weight_decay) {
  if (!cpuinfo_initialize()) {
    throw std::runtime_error("Failed to initialize cpuinfo!");
  }

  // NOTE(review): AVX512-capable machines also take the AVX2 codegen path
  // here (kernel_generator is instantiated with inst_set_t::avx2) — confirm
  // this is intended.
  if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
    if (block_size == 1) {
      return [=](int num_rows, // number of rows reading
                 std::uint64_t param_size, // total number of parameters
                 float* w, // input/output parameters
                 const float* g, // input gradients
                 float* h, // input/output momentums
                 const IndexType* indices, // indices of each row
                 float epsilon,
                 float lr,
                 float weight_decay,
                 const double* counter,
                 std::int64_t counter_halflife) {
        return SparseAdaGradBlockSize1_(
            num_rows,
            param_size,
            w,
            g,
            h,
            indices,
            epsilon,
            lr,
            rowwise,
            weight_decay,
            counter,
            counter_halflife);
      };
    }
    // Shared generator instance; its code cache persists across calls.
    static GenSparseAdagrad<IndexType, inst_set_t::avx2> kernel_generator;
    constexpr int VLEN = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
    // Mask with (block_size % VLEN) leading active lanes for remainder
    // loads/stores; points at the all-active mask when there is no remainder.
    const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
        [(VLEN - (block_size % VLEN)) % VLEN];
    const auto original_func = kernel_generator.getOrCreate(
        block_size, prefetch, rowwise, use_weight_decay);
    // Bind the mask pointer so callers see the public (mask-free) signature.
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      return original_func(
          num_rows, // number of rows reading
          param_size, // total number of parameters
          w, // input/output parameters
          g, // input gradients
          h, // input/output momentums
          indices, // indices of each row
          epsilon,
          lr,
          mask_avx2,
          weight_decay,
          counter,
          counter_halflife);
    };
  } else {
#ifdef VLOG
    VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
    // Portable fallback: delegate to the reference implementations.
    return [=](int num_rows, // number of rows reading
               std::uint64_t param_size, // total number of parameters
               float* w, // input/output parameters
               const float* g, // input gradients
               float* h, // input/output momentums
               const IndexType* indices, // indices of each row
               float epsilon,
               float lr,
               float weight_decay,
               const double* counter,
               std::int64_t counter_halflife) {
      if (rowwise) {
        return rowwise_sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per rows
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      } else {
        return sparse_adagrad_ref(
            num_rows, // number of rows reading
            block_size, // number of parameters per rows
            param_size, // total number of parameters
            w, // input/output parameters
            g, // input gradients
            h, // input/output momentums
            indices,
            epsilon,
            lr,
            weight_decay,
            counter,
            counter_halflife);
      }
    };
  }
}
986 | |
// Explicit instantiations exported from the library for the two supported
// index widths.
template FBGEMM_API typename SparseAdaGradSignature<std::int64_t>::Type
GenerateSparseAdaGrad<std::int64_t>(
    int block_size, // number of parameters per rows
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

template FBGEMM_API typename SparseAdaGradSignature<std::int32_t>::Type
GenerateSparseAdaGrad<std::int32_t>(
    int block_size, // number of parameters per rows
    bool rowwise,
    int prefetch,
    bool use_weight_decay);
1000 | |
1001 | } // namespace fbgemm |
1002 | |