/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace eltwise_injector {

bool is_isa_supported(cpu_isa_t isa) {
    return is_superset(isa, sse41);
}

bool is_alg_supported(alg_kind_t alg) {
    using namespace alg_kind;
    return utils::one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
            eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
            eltwise_soft_relu, eltwise_logistic, eltwise_mish, eltwise_exp,
            eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_hardswish,
            eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2,
            eltwise_pow, eltwise_gelu_erf, eltwise_round,
            eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd,
            eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd,
            eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd,
            eltwise_clip_v2_use_dst_for_bwd);
}

bool is_supported(cpu_isa_t isa, alg_kind_t alg) {
    return is_isa_supported(isa) && is_alg_supported(alg);
}

} // namespace eltwise_injector

using namespace Xbyak;

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble(
        const injector_utils::vmm_index_set_t &vmm_idxs) {
    using namespace alg_kind;
    using namespace Xbyak::util;
    preserved_vecs_count = 0;
    vecs_to_preserve = aux_vecs_count();
    const auto start_idx = *(vmm_idxs.begin());
    const auto end_idx = *(vmm_idxs.rbegin()) + 1;
    start_idx_tail = vmm_idxs.begin();

    // For avx we need a register to save the upper part of Ymm
    preserve_vec_for_avx = isa == avx
            && utils::one_of(alg_, eltwise_tanh, eltwise_elu, eltwise_abs,
                    eltwise_soft_relu, eltwise_mish, eltwise_logistic,
                    eltwise_exp, eltwise_gelu_tanh, eltwise_swish,
                    eltwise_gelu_erf, eltwise_tanh_use_dst_for_bwd,
                    eltwise_elu_use_dst_for_bwd,
                    eltwise_logistic_use_dst_for_bwd,
                    eltwise_exp_use_dst_for_bwd);
    if (preserve_vec_for_avx) vecs_to_preserve++;

    // For sse41 the mask register has to be Xmm(0)
    if (isa == sse41 && vecs_to_preserve > 0) {
        size_t idx = 0;
        assert(idx < start_idx);
        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
        if (preserved_vecs_count >= vecs_to_preserve) break;
        if (start_idx <= idx && idx < end_idx) continue;

        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
    for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
        preserved_vec_idxs[preserved_vecs_count++] = *start_idx_tail;
        ++start_idx_tail;
    }

    assert(preserved_vecs_count == vecs_to_preserve);

    // Same logic, but for allocating gprs
    size_t preserved_gprs_count = 0;
    for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
        int _idx = Operand::R15 - gpr_idx; // we allocate from the end
        if (preserved_gprs_count < aux_gprs_count()
                && !utils::one_of(_idx, p_table.getIdx(), Operand::RSP))
            preserved_gpr_idxs[preserved_gprs_count++] = _idx;
    }
    assert(preserved_gprs_count == aux_gprs_count());

    if (save_state_) {
        if (preserve_p_table_) h->push(p_table);
        for (size_t i = 0; i < preserved_gprs_count; ++i)
            h->push(Reg64(preserved_gpr_idxs[i]));

        if (preserve_vmm_) {
            if (preserved_vecs_count)
                h->sub(h->rsp, preserved_vecs_count * vlen);

            for (size_t i = 0; i < preserved_vecs_count; ++i)
                h->uni_vmovups(
                        h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
        }
        load_table_addr();
    }

    assign_regs();
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble_tail(
        const injector_utils::vmm_index_set_iterator_t start_idx_it) {
    size_t tail_vecs_to_preserve = std::distance(start_idx_it, start_idx_tail);
    if (tail_vecs_to_preserve == 0) return;

    const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;

    if (save_state_) {
        if (idx_off) h->add(h->rsp, idx_off * vlen);

        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
                    h->ptr[h->rsp + i * vlen]);
    }

    for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
        preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;

    if (save_state_ && preserve_vmm_) {
        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen],
                    Vmm(preserved_vec_idxs[idx_off + i]));

        if (idx_off) h->sub(h->rsp, idx_off * vlen);
    }

    assign_regs();
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_postamble() {
    using namespace Xbyak::util;
    if (!save_state_) return;

    if (preserve_vmm_) {
        for (size_t i = 0; i < preserved_vecs_count; ++i)
            h->uni_vmovups(
                    Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);

        if (preserved_vecs_count) h->add(h->rsp, preserved_vecs_count * vlen);
    }

    for (int i = aux_gprs_count() - 1; i >= 0; --i)
        h->pop(Reg64(preserved_gpr_idxs[i]));
    if (preserve_p_table_) h->pop(p_table);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::assign_regs() {
    vmm_mask = Vmm(preserved_vec_idxs[0]);
    vmm_aux0 = Vmm(preserved_vec_idxs[0]);
    vmm_aux1 = Vmm(preserved_vec_idxs[1]);
    vmm_aux2 = Vmm(preserved_vec_idxs[2]);
    vmm_aux3 = Vmm(preserved_vec_idxs[3]);
    vmm_aux4 = Vmm(preserved_vec_idxs[4]);
    if (preserve_vec_for_avx) {
        vmm_tmp = Vmm(preserved_vec_idxs[vecs_to_preserve - 1]);
        ymm_tmp = Ymm(preserved_vec_idxs[vecs_to_preserve - 1]);
        xmm_tmp = Xmm(preserved_vec_idxs[vecs_to_preserve - 1]);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::vec_shift(const Vmm &vmm_dst,
        const Vmm &vmm_src, bool shift_left, const int imm) {
    if (isa != avx) {
        if (shift_left)
            h->uni_vpslld(vmm_dst, vmm_src, imm);
        else
            h->uni_vpsrld(vmm_dst, vmm_src, imm);
    } else {
        // Declare appropriate vectors to use non-uni instructions
        Xmm xmm_dst = Xmm(vmm_dst.getIdx());
        Ymm ymm_dst = Ymm(vmm_dst.getIdx());
        Ymm ymm_src = Ymm(vmm_src.getIdx());
        if (vmm_dst.getIdx() != vmm_src.getIdx()) h->vmovups(ymm_dst, ymm_src);
        h->vextractf128(xmm_tmp, ymm_dst, 1);
        if (shift_left) {
            h->vpslld(xmm_dst, xmm_dst, imm);
            h->vpslld(xmm_tmp, xmm_tmp, imm);
        } else {
            h->vpsrld(xmm_dst, xmm_dst, imm);
            h->vpsrld(xmm_tmp, xmm_tmp, imm);
        }
        h->vinsertf128(ymm_dst, ymm_dst, xmm_tmp, 1);
    }
}

// Uses injector mask objects: k_mask (>= avx512_core) or vmm_mask (<= avx2).
// Stores a mask by applying cmpps on two inputs w/ a given predicate.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_cmp_mask(
        const Vmm &vmm_src, const Xbyak::Operand &compare_operand,
        int cmp_predicate) {
    if (is_avx512) {
        h->vcmpps(k_mask, vmm_src, compare_operand, cmp_predicate);
    } else {
        h->uni_vcmpps(vmm_mask, vmm_src, compare_operand, cmp_predicate);
    }
}

// Uses injector mask objects: k_mask (>= avx512_core) or vmm_mask (<= avx2).
// Blends the result of the second input into the first input w/ a stored mask.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::blend_with_mask(
        const Vmm &vmm_dst, const Xbyak::Operand &src) {
    if (is_avx512) {
        h->vblendmps(vmm_dst | k_mask, vmm_dst, src);
    } else {
        h->uni_vblendvps(vmm_dst, vmm_dst, src, vmm_mask);
    }
}

// Uses injector mask objects: k_mask (>= avx512_core) or vmm_mask (<= avx2).
// Tests a mask for all zeros. If all zeros occur, ZF is set to 1.
// Nicely combines with jump_if_zero (jz).
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::test_mask() {
    if (is_avx512) {
        h->kortestw(k_mask, k_mask);
    } else {
        h->uni_vtestps(vmm_mask, vmm_mask);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::exp_compute_vector_fwd(
        const Vmm &vmm_src) {
    // exp(x) =
    // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
    // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression

    // get mask of values lower than log(FLT_MIN) to zero them in the output
    compute_cmp_mask(vmm_src, table_val(exp_ln_flt_min_f), _cmp_lt_os);

    h->uni_vminps(vmm_src, vmm_src, table_val(exp_ln_flt_max_f));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(exp_ln_flt_min_f));
    h->uni_vmovups(vmm_aux1, vmm_src);

    // calculate exp(x)
    // fx = x * log2ef + 0.5
    h->uni_vmulps(vmm_src, vmm_src, table_val(exp_log2ef));
    h->uni_vaddps(vmm_src, vmm_src, table_val(half));

    // tmp = floorf(fx)
    h->uni_vroundps(vmm_aux2, vmm_src, _op_floor);

    // keep vmm_src = fx for further computations
    h->uni_vmovups(vmm_src, vmm_aux2);

    // x = x - fx * ln2
    h->uni_vfnmadd231ps(vmm_aux1, vmm_aux2, table_val(ln2f));

    // We do not compute 2^n directly, because n can reach 128, and 2^128 is
    // not representable in fp32. To get around this, instead of computing
    // 2^n * exp(r) we compute 2 * 2^(n-1) * exp(r), because both 2^127 and 2
    // are representable in fp32.
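    // A scalar sketch of the same range reduction (illustration only, in
    // plain C; `poly` stands for the exp_pol Horner evaluation below and is
    // a hypothetical helper, not part of the generated code):
    //   float n = floorf(x * log2ef + 0.5f); // fx, rounded down
    //   float r = x - n * ln2f;              // remainder
    //   float y = 2.0f * ldexpf(poly(r), (int)n - 1); // 2 * 2^(n-1) * exp(r)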
    // compute 2^(n-1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vcvtps2dq(vmm_aux2, vmm_src);
    if (isa != avx)
        h->uni_vpaddd(vmm_aux2, vmm_aux2, table_val(exponent_bias));
    else {
        Ymm ymm_aux2 = Ymm(vmm_aux2.getIdx());
        Xmm xmm_aux2 = Xmm(vmm_aux2.getIdx());
        h->vextractf128(xmm_tmp, ymm_aux2, 1);
        h->vpaddd(xmm_tmp, xmm_tmp, table_val(exponent_bias));
        h->vpaddd(xmm_aux2, xmm_aux2, table_val(exponent_bias));
        h->vinsertf128(ymm_aux2, ymm_aux2, xmm_tmp, 1);
    }
    vec_shift(vmm_aux2, vmm_aux2, true /*shift_left*/, n_mantissa_bits);
    // use vmm_src as tmp vmm_zero when applying mask
    h->uni_vxorps(vmm_src, vmm_src, vmm_src);
    // set zeroes at those points which were < log(FLT_MIN)
    blend_with_mask(vmm_aux2, vmm_src);

    // compute polynomial
    h->uni_vmovups(vmm_src, table_val(exp_pol, 4));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 3));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 2));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 1));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 0));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    // y = y * 2^n
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);
    h->uni_vmulps(vmm_src, vmm_src, table_val(two));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_aux1, vmm_src);
    compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os);
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    blend_with_mask(vmm_src, vmm_aux1);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_zero_ns_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmaxps(vmm_src, vmm_src, table_val(zero));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::elu_compute_vector_fwd(
        const Vmm &vmm_src) {
    // IMPORTANT: we use vmm_aux3 for the mask as exp_compute does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src);
    // compute exponent
    exp_compute_vector_fwd(vmm_src);

    // alpha * (exp(x) - 1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));

    // combine with mask
    compute_cmp_mask(vmm_aux3, table_val(zero), _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_aux3);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::tanh_compute_vector_fwd(
        const Vmm &vmm_src) {
    // we add a check as the avx2 code cannot be used for avx
    assert(IMPLICATION(isa == avx2, mayiuse(avx2)));

    using namespace Xbyak::util;
    const int XMM_float_lanes_count = 4;
    const int tanh_n_polynomials = 32;

    // register mapping
    // TODO: put sign on stack and alias zmm_table2 with vmm_sign to save a reg ?
    Vmm vmm_dst = vmm_aux1, vmm_src_shift = vmm_aux1, vmm_coeff = vmm_aux1,
        vmm_pol = vmm_aux2, vmm_indices = vmm_aux3, vmm_src_original = vmm_aux4,
        vmm_sign = vmm_aux4;
    Reg64 gpr_idx[XMM_float_lanes_count];

    if (isa == sse41 || isa == avx) {
        assert(aux_gprs_count() >= XMM_float_lanes_count);
        for (int i = 0; i < XMM_float_lanes_count; i++)
            gpr_idx[i] = Reg64(preserved_gpr_idxs[i]);
    }

    // We split the positive domain in 33 intervals:
    // a) [0; linear_ubound]: in this interval tanh(x) = x
    // b) [linear_ubound; 0x1.8p-12]: This interval spans part of a
    //    half binade
    // c) [0x1.8p-12; 0x1.0p-11], ..., [0x1.8p2; 0x1.0p3]:
    //    one interval for each half binade, there are 29 of those
    // d) [0x1.0p3; saturation_ubound]:
    //    This interval spans part of a half binade
    // e) [saturation_ubound; +inf[: in this interval, tanh(x) = 1
    // For b-d, we need 31 polynomials and will do a table lookup for those.
    // To simplify the logic, we will also put a) in the table.

    // The polynomials are of degree 6, so we need to gather 7 coefficients.
    // - sse4.1: we do it the naive way using vextract/vinsert.
    //           Here we will extract the indices in gpr only once and
    //           reuse them as there are only 4 of them.
    // - avx: we do the same as for sse4.1, but use half of the 64-bit
    //        registers to store the indices of the second half of YMM and
    //        half for the corresponding XMM. Halfway through the copy we
    //        exchange Xmm and the higher half of Ymm and we get the expected
    //        result.
    // - avx2: we use vpermps and blend for each coefficient.
    //         This needs an extra vmm to store the mask.
    // - avx512: because the table fits in 2 registers, we can use vpermi2d.
    auto coeffs_off = [&](int coeff_off, int off = 0) {
        return table_off(tanh_pol_table, coeff_off * tanh_n_polynomials + off);
    };
    auto coeffs_address = [&](int coeff_off, int off = 0) {
        return table_val(tanh_pol_table, coeff_off * tanh_n_polynomials + off);
    };
    auto gather_coefficient_init = [&](Vmm vmm_pol_idx, int nelems) {
        switch (isa) {
            case sse41:
                for (int i = 0; i < XMM_float_lanes_count; ++i)
                    h->pextrd(gpr_idx[i].cvt32(), vmm_pol_idx, i);
                break;
            case avx: {
                Xmm xmm_pol_idx = Xmm(vmm_pol_idx.getIdx());
                for (int i = 0; i < XMM_float_lanes_count; ++i)
                    h->vpextrd(gpr_idx[i].cvt32(), xmm_pol_idx, i);
            } break;
            case avx2_vnni_2:
            case avx2:
                // needed for gather instruction
                h->uni_vxorps(vmm_mask, vmm_mask, vmm_mask);
                break;
            case avx512_core_fp16:
            case avx512_core_bf16:
            case avx512_core: break;
            default: assert(!"unimplemented");
        }
    };
    auto gather_coefficient = [&](Vmm vmm_coeff, int coeff_idx,
                                      Vmm vmm_pol_idx) {
        switch (isa) {
            case sse41:
                for (int idx = 0; idx < 4; ++idx) {
                    Xbyak::Address coeff_addr
                            = ptr[p_table + coeffs_off(coeff_idx)
                                    + gpr_idx[idx] * sizeof(float)];
                    h->pinsrd(vmm_coeff, coeff_addr, idx);
                }
                break;
            case avx: {
                Xmm xmm_coeff = Xmm(vmm_coeff.getIdx());
                for (int idx = 0; idx < 4; ++idx) {
                    Xbyak::Address coeff_addr
                            = ptr[p_table + coeffs_off(coeff_idx)
                                    + gpr_idx[idx] * sizeof(float)];
                    h->vpinsrd(xmm_coeff, xmm_coeff, coeff_addr, idx);
                }
            } break;
            case avx2_vnni_2:
            case avx2: {
                Xbyak::Address idx_addr = ptr[p_table + coeffs_off(coeff_idx)
                        + vmm_pol_idx * sizeof(float)];
                // We set the mask to all ones to gather the full register.
                // This needs to be done after each gather, since the gather
                // instruction zeroes the mask on success.
                h->uni_vcmpps(vmm_mask, vmm_mask, vmm_mask, _cmp_eq_oq);
                h->vgatherdps(vmm_coeff, idx_addr, vmm_mask);
                break;
            }
            case avx512_core_fp16:
            case avx512_core_bf16:
            case avx512_core:
                // we use vpermt2ps to not override the indices;
                // this also allows saving a register for table loading
                {
                    Zmm zmm_coeff(vmm_coeff.getIdx());
                    Zmm zmm_pol_idx(vmm_pol_idx.getIdx());
                    h->uni_vmovups(zmm_coeff, coeffs_address(coeff_idx, 0));
                    h->vpermt2ps(zmm_coeff, zmm_pol_idx,
                            coeffs_address(coeff_idx, 16));
                    break;
                }
            default: assert(!"unimplemented");
        }
    };

    // because tanh(x) = -tanh(-x), we extract sign to make x positive
    // and reapply sign at the end
    h->uni_vmovups(vmm_src_original, vmm_src);
    h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));

    // We compute the indices for the table lookup
    h->uni_vmovups(vmm_indices, vmm_src);
    if (isa != avx)
        h->uni_vpsubd(vmm_indices, vmm_indices, table_val(tanh_idx_bias));
    else {
        Ymm ymm_indices = Ymm(vmm_indices.getIdx());
        Xmm xmm_indices = Xmm(vmm_indices.getIdx());
        h->vextractf128(xmm_tmp, ymm_indices, 1);
        h->vpsubd(xmm_tmp, xmm_tmp, table_val(tanh_idx_bias));
        h->vpsubd(xmm_indices, xmm_indices, table_val(tanh_idx_bias));
        h->vinsertf128(ymm_indices, ymm_indices, xmm_tmp, 1);
    }
    h->uni_vandps(vmm_indices, vmm_indices, table_val(tanh_idx_mask));
    vec_shift(vmm_indices, vmm_indices, false, 22);

    // we do the argument reduction
    h->uni_vmovups(vmm_src_shift, vmm_src);
    h->uni_vandps(vmm_src_shift, vmm_src_shift, table_val(tanh_idx_mask));
    h->uni_vsubps(vmm_src, vmm_src, vmm_src_shift);

    // we gather and evaluate the polynomials
    gather_coefficient_init(vmm_indices, vlen / sizeof(float));
    gather_coefficient(vmm_pol, 6, vmm_indices);
    for (int deg = 5; deg >= 0; --deg) {
        gather_coefficient(vmm_coeff, deg, vmm_indices);
        h->uni_vfmadd213ps(vmm_pol, vmm_src, vmm_coeff);
    }

    if (isa == avx) {
        Ymm ymm_indices = Ymm(vmm_indices.getIdx());
        Ymm ymm_pol = Ymm(vmm_pol.getIdx());
        Ymm ymm_src = Ymm(vmm_src.getIdx());
        Xmm xmm_src = Xmm(vmm_src.getIdx());
        Xmm xmm_coeff = Xmm(vmm_coeff.getIdx());

        h->vperm2f128(ymm_src, ymm_src, ymm_src, 1);
        h->vperm2f128(ymm_indices, ymm_indices, ymm_indices, 1);
        gather_coefficient_init(vmm_indices, vlen / sizeof(float));
        gather_coefficient(vmm_tmp, 6, vmm_indices);
        for (int deg = 5; deg >= 0; --deg) {
            gather_coefficient(vmm_coeff, deg, vmm_indices);
            h->vmulps(xmm_tmp, xmm_tmp, xmm_src);
            h->vaddps(xmm_tmp, xmm_tmp, xmm_coeff);
        }
        h->vinsertf128(ymm_pol, ymm_pol, xmm_tmp, 1);
    }

    // we restore src with cleared sign, and keep sign
    assert(vmm_sign.getIdx() == vmm_src_original.getIdx());
    h->uni_vmovups(vmm_src, vmm_src_original);
    h->uni_vandps(vmm_sign, vmm_sign, table_val(sign_mask));
    h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));

    // Now we blend the results
    // [saturation_ubound; +inf[ : we return +/- 1
    h->uni_vmovups(vmm_dst, table_val(one));
    // [linear_ubound; saturation_lbound] : we return +/- P(x)
    h->uni_vmovups(vmm_mask, table_val(tanh_saturation_lbound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_gt_os);
    blend_with_mask(vmm_dst, vmm_pol);
    // [0; linear_ubound] : we return x
    h->uni_vmovups(vmm_mask, table_val(tanh_linear_ubound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_gt_os);
    blend_with_mask(vmm_dst, vmm_src);

    // We reapply the sign and return
    h->uni_vxorps(vmm_dst, vmm_dst, vmm_sign);
    h->uni_vmovups(vmm_src, vmm_dst);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_aux0, vmm_src);

    // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vmovups(vmm_aux1, table_val(gelu_tanh_fitting_const));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
    h->uni_vmulps(vmm_src, vmm_src, table_val(gelu_tanh_sqrt_two_over_pi));

    // save x on stack as tanh uses vmm_aux0
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_aux0);

    // compute tanh(G(x))
    tanh_compute_vector_fwd(vmm_src);

    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    // compute 0.5 * x * (1 + tanh(G(x)))
    h->uni_vaddps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, table_val(half));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::square_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::abs_compute_vector_fwd(
        const Vmm &vmm_src) {
    // compute abs(x) = _mm_and_ps(x, 01111..111));
    h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::sqrt_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vsqrtps(vmm_src, vmm_src);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::linear_compute_vector_fwd(
        const Vmm &vmm_src) {
    // compute x = alpha * x + beta;
    h->uni_vmovups(vmm_aux0, table_val(alpha));
    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(beta));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::clip_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmaxps(vmm_src, vmm_src, table_val(alpha));
    h->uni_vminps(vmm_src, vmm_src, table_val(beta));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::mish_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Mish is computed here with an equation equivalent to
    // mish(x) = x * tanh(soft_relu(x)). Using the identity
    // tanh(x) = (e^x - e^-x) / (e^x + e^-x),
    // the equation for mish can be rewritten as:
    // mish(x) = x * ((e^x + 1)^2 - 1) / ((e^x + 1)^2 + 1).
    // This form was chosen because computing tanh requires more registers
    // than exp and more constants to be stored in memory, which would make
    // the algorithm slower.
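    // A scalar sketch of that formula (illustration only, plain C):
    //   float e = expf(fminf(x, fwd_mish_max_x_for_equation_f));
    //   float t = (e + 1.0f) * (e + 1.0f); // (e^x + 1)^2
    //   float y = x * (t - 1.0f) / (t + 1.0f);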
    // IMPORTANT: we use vmm_aux3 to save src as exp does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src); // vmm_aux3 = x

    h->uni_vminps(vmm_src, vmm_src, table_val(fwd_mish_max_x_for_equation_f));
    exp_compute_vector_fwd(vmm_src);

    // (e^x+1)^2
    h->uni_vaddps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);

    // save (e^x+1)^2 as it appears in both the denominator and the numerator
    h->uni_vmovups(vmm_aux1, vmm_src);

    // x * ((e^x + 1)^2 - 1) / ((e^x + 1)^2 + 1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux3);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::hardswish_compute_vector_fwd(
        const Vmm &vmm_src) {
    // result = x * hardsigmoid(x)
    h->uni_vmovups(vmm_aux0, vmm_src);
    hardsigmoid_compute_vector_fwd(vmm_src);
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::hardsigmoid_compute_vector_fwd(
        const Vmm &vmm_src) {
    // result = max(0, min(1, alpha * x + beta))
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    h->uni_vaddps(vmm_src, vmm_src, table_val(beta));
    h->uni_vminps(vmm_src, vmm_src, table_val(one));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(zero));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::soft_relu_compute_vector_fwd(
        const Vmm &vmm_src) {
    // alpha scaling
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));

    // ln(1 + exp(x)) =
    // = ln(1 + exp(n * ln(2) + r)) // divide x by ln(2) and get quot and rem
    // = ln(1 + 2^n * exp(r)) // simplify the exp(n*ln(2)) expression
    // = ln(2 ^ 0 + 2^n * exp(r)) // note 1 = 2^0
    // = ln(2 ^ (n - n) + 2^n * exp(r)) // 2^0 = 2^(n-n)
    // = ln(2 ^ n * (2^-n + exp(r))) // factorize with 2^n
    // = n * ln(2) + ln(2^-n + exp(r)) // take the 2^n factor out of the ln
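    // A scalar sketch of the decomposition (illustration only; exp_poly and
    // log1p_poly stand for the exp_pol and soft_relu_pol Horner evaluations
    // below and are hypothetical helpers):
    //   float n = floorf(x * log2ef + 0.5f);
    //   float r = x - n * ln2f;
    //   float v = ldexpf(1.0f, -(int)n) + exp_poly(r); // 2^-n + exp(r)
    //   int k; float m = frexpf(v, &k);                // v = 2^k * m
    //   float y = (n + k) * ln2f + log1p_poly(m - 1.0f); // ln(1 + exp(x))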
    // keep src for further computations
    h->uni_vmovups(vmm_aux2, vmm_src);

    h->uni_vminps(vmm_src, vmm_src, table_val(exp_ln_flt_max_f));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(exp_ln_flt_min_f));
    h->uni_vmovups(vmm_aux1, vmm_src);

    // calculate exp(x)
    // fx = x * log2ef + 0.5
    h->uni_vmulps(vmm_src, vmm_src, table_val(exp_log2ef));
    h->uni_vaddps(vmm_src, vmm_src, table_val(half));

    // tmp = floorf(fx)
    h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);

    // keep vmm_src = fx for further computations
    h->uni_vmovups(vmm_src, vmm_aux0);

    // x = x - fx * ln2
    h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(ln2f));
    h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
    // compute exponent polynomial
    h->uni_vmovups(vmm_aux3, table_val(exp_pol, 4));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 3));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 2));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 1));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 0));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(one));

    // We do not compute 2^-n directly, because n can reach 128, and 2^(-128)
    // is not representable in fp32. To get around this, instead of computing
    // 2^-n + exp(r) we compute (2^-(n-1) + 2*exp(r))/2, because both 2^(-127)
    // and 2 are representable in fp32.

    // compute 2^-(n-1)
    // vmm_src now represents n-1
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    if (is_avx512) {
        h->vmulps(vmm_aux1, vmm_src, table_val(minus_one));
        h->vcvtps2dq(vmm_aux1, vmm_aux1);
    } else if (isa == avx) {
        h->uni_vxorps(vmm_aux1, vmm_src, table_val(sign_mask));
        h->uni_vcvtps2dq(vmm_aux1, vmm_aux1);
    } else {
        h->uni_vcvtps2dq(vmm_aux1, vmm_src);
        h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(minus_one));
    }
    // restore vmm_src to n
    h->uni_vaddps(vmm_src, vmm_src, table_val(one));

    if (isa != avx)
        h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(exponent_bias));
    else {
        Ymm ymm_aux1 = Ymm(vmm_aux1.getIdx());
        Xmm xmm_aux1 = Xmm(vmm_aux1.getIdx());
        h->vextractf128(xmm_tmp, ymm_aux1, 1);
        h->vpaddd(xmm_tmp, xmm_tmp, table_val(exponent_bias));
        h->vpaddd(xmm_aux1, xmm_aux1, table_val(exponent_bias));
        h->vinsertf128(ymm_aux1, ymm_aux1, xmm_tmp, 1);
    }
    vec_shift(vmm_aux1, vmm_aux1, true /*shift_left*/, n_mantissa_bits);
    // calculate ln(1 + y)
    h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(two)); // 2*exp(r)
    h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); // 2^-(n-1) + 2*exp(r)
    h->uni_vdivps(
            vmm_aux3, vmm_aux3, table_val(two)); // (2^-(n-1) + 2*exp(r))/2
    // frexp()
    vec_shift(vmm_src, vmm_aux3, false /*shift_left*/, n_mantissa_bits);
    h->uni_vcvtdq2ps(vmm_src, vmm_src);
    // got n, where x = 2^n * y and y = 0.5 .. 1
    h->uni_vsubps(vmm_src, vmm_src, table_val(soft_relu_one_twenty_six));

    // and with mask (to get 0.5 * mantissa)
    h->uni_vandps(vmm_aux3, vmm_aux3, table_val(soft_relu_mantissa_sign_mask));
    // got y (mantissa); or with 0.5 so that 0.5 <= y < 1
    h->uni_vorps(vmm_aux3, vmm_aux3, table_val(half));
    // y = y - 1
    h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(one));

    // compute log1p polynomial
    h->uni_vmovups(vmm_aux1, table_val(soft_relu_pol, 8));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 7));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 6));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 5));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 4));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 3));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 2));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 1));
    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 0));
    // calculate ln(2) * n
    h->uni_vmulps(vmm_src, vmm_src, table_val(ln2f));
    h->uni_vaddps(vmm_src, vmm_src, vmm_aux1);
    h->uni_vaddps(vmm_src, vmm_src, vmm_aux0);

    // get vmm_mask = src > max logf
    // y = (x < max log f) ? soft_relu(x) : x
    compute_cmp_mask(vmm_aux2, table_val(exp_ln_flt_max_f), _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_aux2);
    if (alpha_ == 1.f) { // standard soft_relu case
        // Skip an instruction.
    } else if (alpha_ == -1) { // logsigmoid case
        h->uni_vmulps(vmm_src, vmm_src, table_val(minus_one));
    } else { // general case
        h->uni_vdivps(vmm_src, vmm_src, table_val(alpha));
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::logistic_compute_vector_fwd(
        const Vmm &vmm_src) {
    // To avoid the exp(x) overflow that happens for x > logf(FLT_MAX), negate
    // positive values, compute exp(x) for x <= 0 so that 0 <= exp(x) <= 1,
    // and restore the sign at the end. This is possible because the logistic
    // function is symmetric: sigmoid(x) = 1 - sigmoid(-x).
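    // A scalar sketch (illustration only, plain C):
    //   float e = expf(-fabsf(x));   // exp of the negated magnitude
    //   float y = e / (e + 1.0f);    // sigmoid(-|x|)
    //   if (x > 0.0f) y = 1.0f - y;  // restore via symmetry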
    // IMPORTANT: we use vmm_aux3 for the mask as exp_compute does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src);
    // we store the original sign and make x negative
    h->uni_vandps(vmm_aux3, vmm_aux3, table_val(sign_mask));
    h->uni_vorps(vmm_src, vmm_src, table_val(sign_mask));

    exp_compute_vector_fwd(vmm_src);
    // dup exp(x)
    h->uni_vmovups(vmm_aux1, vmm_src);
    // (exp(x) + 1)
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    // y = exp(x) / (exp(x) + 1)
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);

    // Now we have to apply the "symmetry" based on original sign
    h->uni_vmovups(vmm_aux2, table_val(one));
    h->uni_vsubps(vmm_aux2, vmm_aux2, vmm_src);
    if (is_avx512) {
        h->vptestmd(k_mask, vmm_aux3, vmm_aux3);
    } else {
        h->uni_vmovups(vmm_mask, vmm_aux3);
    }
    blend_with_mask(vmm_aux2, vmm_src);
    h->uni_vmovups(vmm_src, vmm_aux2);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Save src data on stack for later usage
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);
    // x*alpha
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    // sigmoid(x*alpha)
    logistic_compute_vector_fwd(vmm_src);
    // x*sigmoid(alpha*x)
    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_fwd(
        const Vmm &vmm_src) {
    // From J.-M. Muller and others, Handbook of Floating-Point Arithmetic, 2010
    // Here is a brief sketch of the mathematics used to approximate log(x):
    // log(x) = E * log(2) + log(y), where -log(2)/2 <= log(y) <= log(2)/2;
    // log(y) = log(1 + z) - log(r_i), where z = y * r_i - 1, r_i approximates
    //   1 / y, i is the index of one of the precomputed values;
    // log(1 + z) ~~ polynomial(z), =>
    // if (x is normal)
    //   log(x) ~~ E * log(2) + polynomial(z) - log(r_i),
    //   where log(r_i) is a table value.
    //
    // If (x == 0) result = -inf;
    // If (x < 0) result = qnan; (qnan value taken from table_val)
    // If (x == inf) result = inf;
    // If (x == qnan) result = qnan; (qnan value taken from src)
    // If (x == 1) result = 0;
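    // A scalar sketch of the normal-number path (illustration only; r and
    // log_r stand for the two gathered table entries, log1p_poly for the
    // log_pol Horner evaluation; top_mantissa_bits is a hypothetical helper):
    //   int E = ilogbf(x);               // exponent of x
    //   float y = x / ldexpf(1.0f, E);   // mantissa, 1 <= y < 2
    //   int i = top_mantissa_bits(y, 5); // selects r_i ~ 1 / y
    //   float z = y * r[i] - 1.0f;       // small relative error
    //   float res = E * ln2f + log1p_poly(z) - log_r[i];
    //   // (the JIT code additionally adjusts E to avoid cancellation)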
    // set unused register as tmp for avx
    if (isa == avx) {
        ymm_tmp = Ymm(vmm_aux0.getIdx());
        xmm_tmp = Xmm(vmm_aux0.getIdx());
    }

    // save source on stack to check neg and zero values at the end
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);

    // compute i
    const int approx_order = 5;
    vec_shift(vmm_aux1, vmm_src, false, n_mantissa_bits - approx_order);
    h->uni_vandps(vmm_aux1, vmm_aux1, table_val(log_five_bit_offset));
    vec_shift(vmm_aux1, vmm_aux1, true, 1); // multiply i by 2

    // compute anticancellation i
    vec_shift(vmm_aux2, vmm_aux1, false, approx_order);

    // get E, don't care about sign as only positive numbers are considered
    vec_shift(vmm_aux3, vmm_src, false, n_mantissa_bits);
    if (isa != avx)
        h->uni_vpaddd(vmm_aux3, vmm_aux3, vmm_aux2);
    else {
        Ymm ymm_aux2 = Ymm(vmm_aux2.getIdx());
        Ymm ymm_aux3 = Ymm(vmm_aux3.getIdx());
        Xmm xmm_aux2 = Xmm(vmm_aux2.getIdx());
        Xmm xmm_aux3 = Xmm(vmm_aux3.getIdx());
        h->vextractf128(xmm_tmp, ymm_aux3, 1);
        h->vpaddd(xmm_aux3, xmm_aux3, xmm_aux2);
        h->vperm2f128(ymm_aux2, ymm_aux2, ymm_aux2, 1);
        h->vpaddd(xmm_tmp, xmm_tmp, xmm_aux2);
        h->vperm2f128(ymm_aux2, ymm_aux2, ymm_aux2, 1);
        h->vinsertf128(ymm_aux3, ymm_aux3, xmm_tmp, 1);
    }
    h->uni_vcvtdq2ps(vmm_aux3, vmm_aux3);

    // get m (mantissa)
    h->uni_vxorps(vmm_aux2, vmm_aux2, table_val(exponent_bias));
    vec_shift(vmm_aux2, vmm_aux2, true, n_mantissa_bits);
    h->uni_vandps(vmm_src, vmm_src, table_val(log_mantissa_mask));
    h->uni_vorps(vmm_src, vmm_src, vmm_aux2);

    // First, adjust indices for the table structure, which broadcasts
    // elements, by multiplying by simd_w
    const int simd_w = math::ilog2q(
            vlen / sizeof(float)); // equal to 2/3/4 for xmm/ymm/zmm
    vec_shift(vmm_aux1, vmm_aux1, true, simd_w);

    const auto it = entry_map_.find(log_predefined_vals);
    assert(it != entry_map_.end());
    const auto table_start_idx = (*it).second.off;

    auto gather_table_values = [&](const Vmm &vmm_dst, const Vmm &vmm_idxs,
                                       size_t offt = 0) {
        Xbyak::Address table_idx = h->ptr[p_table + table_start_idx + offt
                + vmm_idxs * sizeof(float)];
        if (is_avx512) {
            h->kmovw(k_mask, table_val(log_full_k_reg_mask));
            h->vgatherdps(vmm_dst | k_mask, table_idx);
        } else if (utils::one_of(isa, avx2, avx2_vnni_2)) {
            h->uni_vmovups(vmm_mask, table_val(sign_mask));
            h->vgatherdps(vmm_dst, table_idx, vmm_mask);
        } else if (isa == avx || isa == sse41) {
            Xbyak::Reg64 reg_tmp
                    = p_table.getIdx() != h->r9.getIdx() ? h->r9 : h->r10;

            const int gpr_size = 8;
            // save reg_tmp state as we are not allowed to spoil it
            h->sub(h->rsp, gpr_size);
            h->mov(h->ptr[h->rsp], reg_tmp);

            // The rest of the code puts the indices on the stack, fetches a
            // table value based on each index, replaces the index with that
            // value, and, finally, moves the fetched values into the vector
            // register.
            h->sub(h->rsp, vlen);
            h->uni_vmovups(h->ptr[h->rsp], vmm_idxs);

            for (size_t i = 0; i < vlen / sizeof(float); ++i) {
                h->mov(reg_tmp.cvt32(), h->ptr[h->rsp + i * sizeof(float)]);
                h->shl(reg_tmp.cvt32(), 2); // multiply by simd_w
                table_idx = h->ptr[p_table + table_start_idx + offt + reg_tmp];
                h->mov(reg_tmp.cvt32(), table_idx);
                h->mov(h->ptr[h->rsp + i * sizeof(float)], reg_tmp.cvt32());
            }

            h->uni_vmovups(vmm_dst, h->ptr[h->rsp]);
            h->add(h->rsp, vlen);
            // restore GPR state
            h->mov(reg_tmp, h->ptr[h->rsp]);
            h->add(h->rsp, gpr_size);
        }
    };

    // get r_i, same as table(i)
    gather_table_values(vmm_aux2, vmm_aux1, 0);

    // compute relative error (rel_err = m * r_i - 1)
    h->uni_vfmsub213ps(vmm_aux2, vmm_src, table_val(one));

    // compute polynomial(rel_err)
    h->uni_vmovups(vmm_src, table_val(log_pol, 3));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 2));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 1));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 0));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);

    // get log(r_i) = table(i+1)
    gather_table_values(vmm_aux2, vmm_aux1, vlen);

    // compute partial result (pres = E * ln(2) - log(r_i))
    h->uni_vfmadd231ps(vmm_aux2, vmm_aux3, table_val(ln2f));

    // compute (result = polynomial + pres) w/ TwoSum algorithm
    // TODO: restore this instead of version below when asserts are gone
    // h->uni_vaddps(vmm_aux1, vmm_src, vmm_aux2); // res_hi = pol + pres
    // h->uni_vsubps(vmm_aux3, vmm_aux1, vmm_aux2); // res_lo = res_hi - pres
    // h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); // res_lo = res_lo - pol
    // h->uni_vaddps(vmm_src, vmm_aux1, vmm_aux3); // res_hi = pol + pres

    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux2); // res_hi = pol + pres
    h->uni_vmovups(vmm_aux3, vmm_aux1);
    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_aux2); // res_lo = res_hi - pres
    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); // res_lo = res_lo - pol
    h->uni_vmovups(vmm_src, vmm_aux1);
    h->uni_vaddps(vmm_src, vmm_src, vmm_aux3); // res_hi = pol + pres

    // Check original source for zero and neg values. Skip blend w/ extreme
    // values if all src values were positive.
    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    Xbyak::Label end_log_zero_label;
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os);
    test_mask();
    h->jz(end_log_zero_label);

    // Blend extreme values into src if we reach here.
    // First zero for -inf values...
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_eq_oq);
    blend_with_mask(vmm_src, table_val(log_minus_inf));

    // ...then negative for qnan values.
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_lt_os);
    blend_with_mask(vmm_src, table_val(log_qnan));

    h->L(end_log_zero_label);

    // Leave inf values same as in src.
    compute_cmp_mask(vmm_aux1, table_val(log_inf), _cmp_eq_oq);
    Xbyak::Label end_log_inf_label;
    test_mask();
    h->jz(end_log_inf_label);
    blend_with_mask(vmm_src, table_val(log_inf));
    h->L(end_log_inf_label);

    // Detect qnans if src != src and blend with qnans.
    compute_cmp_mask(vmm_aux1, vmm_aux1, _cmp_neq_uq);
    Xbyak::Label end_log_nan_label;
    test_mask();
    h->jz(end_log_nan_label);
    blend_with_mask(vmm_src, vmm_aux1);
    h->L(end_log_nan_label);

    // Detect ones and blend with zeros.
    compute_cmp_mask(vmm_aux1, table_val(one), _cmp_eq_oq);
    Xbyak::Label end_log_one_label;
    test_mask();
    h->jz(end_log_one_label);
    blend_with_mask(vmm_src, table_val(zero));
    h->L(end_log_one_label);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_fwd(
        const Vmm &vmm_src) {
    // dispatch between special cases.
    if (beta_ == -1) { // alpha / x
        h->uni_vmovups(vmm_aux0, table_val(alpha));
        h->uni_vdivps(vmm_src, vmm_aux0, vmm_src, vmm_aux0);
    } else if (beta_ == 0) { // alpha
        h->uni_vmovups(vmm_src, table_val(alpha));
    } else if (beta_ == 0.5) { // alpha * sqrt(x)
        sqrt_compute_vector_fwd(vmm_src);
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else if (beta_ == 1) { // alpha * x
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else if (beta_ == 2) { // alpha * x^2
        square_compute_vector_fwd(vmm_src);
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else { // general path
        // caller obligation to save gprs as callee may use them
        size_t gpr_size = 8;
        Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax,
                h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
        size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

        h->sub(h->rsp, n_gprs_to_save * gpr_size);
        for (size_t i = 0; i < n_gprs_to_save; ++i)
            h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

        // caller obligation to save k-regs as callee may use them
        size_t n_k_regs_to_save = 8;
        if (is_avx512) {
            h->sub(h->rsp, n_k_regs_to_save * k_mask_size);
            for (size_t i = 0; i < n_k_regs_to_save; ++i) {
                if (mayiuse(avx512_core))
                    h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(i));
                else
                    h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(i));
            }
        }

        // 1. Caller obligation to save vector registers as callee may use them.
        // 2. Additionally save space for vmm_src, to put the answer in-place on
        //    this space, and space for beta.
        // 3. There is an implicit assumption that the host code uses the same
        //    `isa` as the injector. Once the assumption is wrong, `vecs_count`
        //    and `vlen` should be replaced with `host_isa::vlen` and
        //    `host_isa::vecs_count`.
        h->sub(h->rsp, (vecs_count + 2) * vlen);
        for (size_t i = 2; i < vecs_count + 2; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(i - 2));
        h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_src); // src
        h->uni_vmovups(vmm_src, table_val(beta));
        h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_src); // beta

        // save function address in gpr to pass it in the call instruction
        h->mov(h->rbp, reinterpret_cast<uintptr_t>(powf));

        // The 64-bit Windows ABI requires the caller to allocate 32 bytes of
        // a so-called "shadow space" for the callee. It also requires that
        // the stack be 16-byte aligned before the call instruction is issued.
        // In order to allocate the shadow space and ensure the 16-byte
        // alignment of the stack, we may actually need to allocate 40 bytes
        // (32 bytes for the "shadow space" + 8 bytes to align the stack) if
        // the stack pointer is not currently 16-byte aligned.
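        // For example, if rsp % 16 == 8 before the adjustment below, rbx
        // becomes 8 + 0x20 = 40, so after `sub rsp, rbx` the stack is again
        // 16-byte aligned with the 32-byte shadow space on top.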
        // align stack on 16-byte as ABI requires
        h->mov(h->rbx, h->rsp);
        // Get alignment offset.
        h->and_(h->rbx, 0xf);
        h->add(h->rbx, 0x20);
        h->sub(h->rsp, h->rbx);

        // Take src, apply powf on it, and replace the value on the stack
        // with dst.
        Xmm xmm0 = Xmm(0), xmm1 = Xmm(1);
        for (size_t i = 0; i < vlen / sizeof(float); ++i) {
            const Address &source = h->ptr[h->rsp + h->rbx + i * sizeof(float)];
            h->uni_vmovss(xmm0, source);
            h->uni_vmovss(xmm1, h->ptr[h->rsp + h->rbx + vlen]); // beta
            h->uni_vzeroupper(); // eliminate performance penalties on avx
            h->call(h->rbp);
            // eliminate performance penalties on sse isa
            if (isa == sse41) h->uni_vzeroupper();
            h->uni_vmovss(source, xmm0);
        }

        h->add(h->rsp, h->rbx);

        // restore vector registers
        for (size_t i = vecs_count + 1; i >= 2; --i)
            h->uni_vmovups(Vmm(i - 2), h->ptr[h->rsp + i * vlen]);
        h->uni_vmovups(vmm_src, h->ptr[h->rsp + 0 * vlen]);
        h->add(h->rsp, (vecs_count + 2) * vlen);

        // restore k registers
        if (is_avx512) {
            for (int i = n_k_regs_to_save - 1; i >= 0; --i) {
                if (mayiuse(avx512_core))
                    h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
                else
                    h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
            }
            h->add(h->rsp, n_k_regs_to_save * k_mask_size);
        }

        // restore gpr registers
        for (int i = n_gprs_to_save - 1; i >= 0; --i)
            h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
        h->add(h->rsp, n_gprs_to_save * gpr_size);

        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa,
        Wmm>::gelu_erf_minimax_approx_compute_vector_fwd(const Vmm &vmm_src) {
    using namespace Xbyak::util;

    // TODO: consider enabling for lower ISA
    if (!is_superset(isa, avx512_core)) return;

    // register mapping
    Vmm vmm_pol = vmm_aux1, vmm_src_square = vmm_aux2, vmm_src_half = vmm_aux3,
        vmm_src_positive = vmm_aux4;

    h->uni_vmulps(vmm_src_square, vmm_src, vmm_src);
    h->uni_vmovups(vmm_src_positive, vmm_src);
    h->uni_vandps(vmm_src_positive, vmm_src_positive, table_val(positive_mask));

    h->uni_vmulps(vmm_src_half, vmm_src, table_val(half));
    // compute P(x^2)
    h->uni_vmovups(vmm_pol, table_val(gelu_erf_minimax_pol, 14));
    // TODO: consider reducing latency by splitting into partial sums, for
    // example by using an x^4 polynomial
    for (int deg = 13; deg >= 0; --deg) {
        h->uni_vfmadd213ps(
                vmm_pol, vmm_src_square, table_val(gelu_erf_minimax_pol, deg));
    }

    // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2)
    h->uni_vfmadd213ps(vmm_pol, vmm_src, table_val(one));
    // move instead of first blend_with_mask?
    h->uni_vmulps(vmm_pol, vmm_pol, vmm_src_half);
    // Now we blend the results
    // [saturation_ubound; +inf] : we return x
    // [-inf; neg_saturation_ubound] : we return 0.0f
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_neg_saturation_ubound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_ge_os);
    blend_with_mask(vmm_src, table_val(zero));
    // [neg_saturation_ubound; -linear_ubound] or
    // [linear_ubound; saturation_lbound] : we return P(x)
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_saturation_lbound));
    compute_cmp_mask(vmm_mask, vmm_src_positive, _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_pol);
    // [-linear_ubound; linear_ubound] : we return 0.5f * x
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_linear_ubound));
    compute_cmp_mask(vmm_mask, vmm_src_positive, _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_src_half);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_fwd(
        const Vmm &vmm_src) {
    if (is_superset(isa, avx512_core)) {
        gelu_erf_minimax_approx_compute_vector_fwd(vmm_src);
        return;
    }

    // Here we approximate erf(x) using the expression by
    // Abramowitz and Stegun from ``Handbook of Mathematical
    // Functions''
    // NOTE: The performance of this kernel can be further improved
    // with a minimax polynomial expansion, thereby avoiding division
    // and exp. However, so far, this has cost larger accuracy
    // differences with respect to glibc erf based GELU, in particular
    // ~1.0e-5 -- 1.0e-3 absolute error at s = -5.
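    // A scalar sketch of the expression computed below (illustration only;
    // p and r0..r4 stand for the Abramowitz-Stegun table constants, horner
    // for their Horner evaluation; both are hypothetical helpers):
    //   float x = s / sqrtf(2.0f);
    //   float t = 1.0f / (p * fabsf(x) + 1.0f);
    //   float r = horner(r4, r3, r2, r1, r0, t); // degree-4 polynomial in t
    //   float erf = copysignf(1.0f - r * t * expf(-x * x), s);
    //   float y = 0.5f * s * (1.0f + erf);       // GELU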
    // use vmm_aux3 to store original src.
    h->uni_vmovups(vmm_aux3, vmm_src);

    // x = s / sqrt(2)
    h->uni_vmulps(vmm_src, vmm_src,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_two));

    // abs(x)
    h->uni_vmovups(vmm_aux4, vmm_src);
    abs_compute_vector_fwd(vmm_aux4);

    // t = 1 / (p*x + 1)
    h->uni_vmovups(
            vmm_aux2, table_val(gelu_erf_Abramowitz_Stegun_approx_const));
    h->uni_vfmadd213ps(vmm_aux2, vmm_aux4, table_val(one));
    h->uni_vmovups(vmm_aux4, table_val(one));
    h->uni_vdivps(vmm_aux4, vmm_aux4, vmm_aux2);

    // -exp(-x*x)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));
    exp_compute_vector_fwd(vmm_src); // pollutes aux1, aux2
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));

    // get sign
    h->uni_vmovups(vmm_aux0, vmm_aux3);
    h->uni_vandps(vmm_aux0, vmm_aux0, table_val(sign_mask));

    // -exp(-x*x)*t
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux4);

    // compute polynomial r
    h->uni_vmovups(vmm_aux1, table_val(gelu_erf_Abramowitz_Stegun_pol, 4));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 3));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 2));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 1));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 0));

    // erf = sign * (1 - r * t * exp(-x*x))
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vxorps(vmm_src, vmm_src, vmm_aux0);

    // S = 0.5 * s
    h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(half));
    // GELU = 0.5 * s * (1 + erf) = S + S * erf
    h->uni_vfmadd213ps(vmm_src, vmm_aux3, vmm_aux3);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_compute_vector_bwd(
        const Vmm &vmm_src) {
    // invariant to whether `s` or `d` is passed.
    // get mask of `s` > 0
    compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os);
    // fill with alpha, then blend with 1.f
    h->uni_vmovups(vmm_src, table_val(alpha));
    blend_with_mask(vmm_src, table_val(one));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::elu_compute_vector_bwd(
        const Vmm &vmm_src) {
    if (!use_dst_) {
        // R = exp(s)
        exp_compute_vector_fwd(vmm_src);
        // after exponentiation, get mask by comparing with exp(0)=1.f, not 0.f
        compute_cmp_mask(vmm_src, table_val(one), _cmp_gt_os);
        // R * alpha, then blend with 1.f
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else {
        // get mask of `d` > 0
        compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os);
        // R = `d` + alpha, then blend with 1.f
        h->uni_vaddps(vmm_src, vmm_src, table_val(alpha));
    }
    blend_with_mask(vmm_src, table_val(one));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::tanh_compute_vector_bwd(
        const Vmm &vmm_src) {
    // res = 1 - d^2 = 1 - tanh^2(s)
    if (!use_dst_) tanh_compute_vector_fwd(vmm_src);
    h->uni_vmovups(vmm_aux0, table_val(one));
    h->uni_vfnmadd231ps(vmm_aux0, vmm_src, vmm_src);
    h->uni_vmovups(vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_bwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_aux0, vmm_src);

    // compute G1(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x^2)
    // compute G2(x) = sqrt_root_two_over_pi * x * (1 + 3 * fitting_const * x^2)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);

    // keep G2 in a separate register
    h->uni_vmovups(vmm_aux2, table_val(gelu_tanh_fitting_const_times_three));
    h->uni_vfmadd213ps(vmm_aux2, vmm_src, table_val(one));

    h->uni_vmovups(vmm_aux1, table_val(gelu_tanh_fitting_const));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(gelu_tanh_sqrt_two_over_pi));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux0);

    // save G2 on stack as tanh uses all available registers
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_aux2);

    // T = tanh(G1(x))
    tanh_compute_vector_fwd(vmm_src);

    h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    // compute 0.5 * (1 + T) * (1 + G2 * (1 - T))
    if (isa == sse41 || isa == avx) {
        h->uni_vmovups(vmm_aux3, table_val(one));
        h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
        h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux3);
        h->uni_vaddps(vmm_src, vmm_src, table_val(one));
        h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_src);
        h->uni_vaddps(vmm_src, vmm_src, vmm_aux2);
    } else {
        // 1) R = G2 * (1 - T) = G2 - G2 * T
        h->uni_vfnmadd231ps(vmm_aux2, vmm_aux2, vmm_src);
        // 2) Q = 1 + T
        h->uni_vaddps(vmm_src, vmm_src, table_val(one));
        // 3) res = Q * (1 + R) = Q + Q * R
        h->uni_vfmadd231ps(vmm_src, vmm_src, vmm_aux2);
    }
    h->uni_vmulps(vmm_src, vmm_src, table_val(half));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::square_compute_vector_bwd(
        const Vmm &vmm_src) {
    // res = 2 * s
    h->uni_vmulps(vmm_src, vmm_src, table_val(two));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::abs_compute_vector_bwd(
        const Vmm &vmm_src) {
    // replace positive values with 1.f
    compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os);
    blend_with_mask(vmm_src, table_val(one));
    // replace negative values with -1.f
    compute_cmp_mask(vmm_src, table_val(zero), _cmp_lt_os);
    blend_with_mask(vmm_src, table_val(minus_one));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::sqrt_compute_vector_bwd(
        const Vmm &vmm_src) {
    // res = 0.5 / d = 0.5 / sqrt(s)
    if (!use_dst_) sqrt_compute_vector_fwd(vmm_src);
    h->uni_vmovups(vmm_aux0, table_val(half));
    // h->uni_vdivps(vmm_src, vmm_aux0, vmm_src); // bless sse41
    h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src);
    h->uni_vmovups(vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::linear_compute_vector_bwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_src, table_val(alpha));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::soft_relu_compute_vector_bwd(
        const Vmm &vmm_src) {
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    logistic_compute_vector_fwd(vmm_src);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::mish_compute_vector_bwd(
        const Vmm &vmm_src) {
    // IMPORTANT: we use vmm_aux3 to save src as exp does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src); // vmm_aux3 = x

    h->uni_vminps(vmm_src, vmm_src, table_val(bwd_mish_max_x_for_equation_f));
    exp_compute_vector_fwd(vmm_src);
    h->uni_vmovups(vmm_aux2, vmm_src); // vmm_aux2 = e^x

    // e^3x + 4*e^2x
    h->uni_vmulps(vmm_src, vmm_src, vmm_src); // e^2x
    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two));
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two)); // 4*e^2x
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, vmm_aux1);

    // e^3x + 4*e^2x + 4*e^x*(x+1.5)
    h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(one)); // vmm_aux3 = x + 1
    h->uni_vmovups(vmm_aux1, vmm_aux3);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(half));
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two));
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two));
    h->uni_vfmadd231ps(vmm_src, vmm_aux1, vmm_aux2);

    // omega = e^3x + 4*e^2x + 4*e^x*(x+1.5) + 4*(x+1)
    h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(two));
    h->uni_vfmadd231ps(vmm_src, vmm_aux3, table_val(two));

    // delta = (e^x+1)^2 + 1
    h->uni_vmovups(vmm_aux1, vmm_aux2);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux1);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux1);

    // e^x * omega / delta^2
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
}

template <cpu_isa_t isa, typename Wmm>
1430void jit_uni_eltwise_injector_f32<isa, Wmm>::logistic_compute_vector_bwd(
1431 const Vmm &vmm_src) {
1432 // res = d * (1 - d) = d - d * d; d = logistic(s)
1433 if (!use_dst_) logistic_compute_vector_fwd(vmm_src);
1434 // h->uni_vfnmadd231ps(vmm_src, vmm_src, vmm_src); // bless sse41
    h->uni_vmovups(vmm_aux0, table_val(one));
    h->uni_vsubps(vmm_aux0, vmm_aux0, vmm_src);
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::exp_compute_vector_bwd(
        const Vmm &vmm_src) {
    if (!use_dst_) exp_compute_vector_fwd(vmm_src);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_bwd(
        const Vmm &vmm_src) {
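    // Reference: d/ds [s * logistic(alpha * s)] = Q * (1 + R * (1 - Q)),
    // where R = alpha * s and Q = logistic(R).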
    // R = alpha * s
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    // Save R on stack for later usage
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);
    // Q = sigmoid(alpha * s)
    logistic_compute_vector_fwd(vmm_src);
    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);
    // compute Q * (1 + R * (1 - Q))
    if (utils::one_of(isa, sse41, avx)) {
        h->uni_vmovups(vmm_aux1, table_val(one));
        h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_src);
        h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux0);
        h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
        h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
    } else {
        // T = R * (1 - Q) = R - R * Q
        h->uni_vfnmadd231ps(vmm_aux0, vmm_aux0, vmm_src);
        // Q * (1 + T) = Q + Q * T
        h->uni_vfmadd231ps(vmm_src, vmm_src, vmm_aux0);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_bwd(
        const Vmm &vmm_src) {
    // res = 1 / s
    h->uni_vmovups(vmm_aux0, table_val(one));
    // h->uni_vdivps(vmm_src, vmm_aux0, vmm_src) would be the direct form,
    // but its two-operand sse41 emulation overwrites the divisor when the
    // destination aliases it, hence the aux register round-trip below.
    h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src);
    h->uni_vmovups(vmm_src, vmm_aux0);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::clip_compute_vector_bwd(
        const Vmm &vmm_src) {
    using namespace alg_kind;

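    // Reference: the derivative is 1.f inside the clip range and 0.f
    // outside; eltwise_clip keeps s == beta inside (1.f on (alpha, beta]),
    // while eltwise_clip_v2 excludes it (1.f on (alpha, beta)).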
    // set result to 1.f
    h->uni_vmovups(vmm_aux1, table_val(one));
    const auto cmp_flag = alg_ == eltwise_clip ? _cmp_gt_os : _cmp_ge_os;
    // get mask of values > beta (or >= beta) and blend with 0.f
    compute_cmp_mask(vmm_src, table_val(beta), cmp_flag);
    blend_with_mask(vmm_aux1, table_val(zero));
    // get mask of values <= alpha and blend with 0.f
    compute_cmp_mask(vmm_src, table_val(alpha), _cmp_le_os);
    blend_with_mask(vmm_aux1, table_val(zero));
    h->uni_vmovups(vmm_src, vmm_aux1);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_bwd(
        const Vmm &vmm_src) {
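    // Reference: d/ds [alpha * s^beta] = alpha * beta * s^(beta - 1).
    // beta in {0, 0.5, 1} admits a cheap closed form; the generic case
    // below reuses the forward pow computation.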
    // dispatch some special cases.
    if (beta_ == 0) { // zero
        h->uni_vmovups(vmm_src, table_val(zero));
    } else if (beta_ == 0.5) { // 0.5 * alpha / sqrt(s)
        sqrt_compute_vector_bwd(vmm_src);
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else if (beta_ == 1) { // alpha
        h->uni_vmovups(vmm_src, table_val(alpha));
    } else {
        // Save `s` on stack for later usage
        h->sub(h->rsp, vlen);
        h->uni_vmovups(h->ptr[h->rsp], vmm_src);
        // R = alpha * pow(s, beta)
        pow_compute_vector_fwd(vmm_src);
        // Restore `s` from stack
        h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
        h->add(h->rsp, vlen);
        // Save mask of zero elements to convert them into zeros at the end
        if (beta_ >= 1) compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_eq_oq);
        // res = alpha * beta * pow(s, beta - 1) = beta * R / s;
        h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
        h->uni_vmulps(vmm_src, vmm_src, table_val(beta));

        // When `s` is zero, beta < 1 yields NaN since `s` appears in the
        // denominator, while beta >= 1 should yield zero.
        if (beta_ >= 1) blend_with_mask(vmm_src, table_val(zero));
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_bwd(
        const Vmm &vmm_src) {
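    // Scalar reference, with R = s / sqrt(2) and Q = exp(-R*R):
    //   gelu'(s) = 0.5 * (1 + erf(R)) + R / sqrt(pi) * Q
    //            = T + 0.5 + 0.5 * erf(R), where T = R / sqrt(pi) * Q;
    // erf(R) is evaluated via the Abramowitz-Stegun approximation below.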
    // R = s / sqrt(2)
    h->uni_vmulps(vmm_src, vmm_src,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_two));

    // Save R on stack for later usage
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);

    // Q = exp(-R*R)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));
    exp_compute_vector_fwd(vmm_src);

    // T = R / sqrt(pi) * Q
    h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
    h->uni_vmulps(vmm_aux2, vmm_aux2,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi));
    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_src);

    // -Q
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));

    // get sign
    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->uni_vandps(vmm_aux0, vmm_aux0, table_val(sign_mask));

    // abs(x)
    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);
    abs_compute_vector_fwd(vmm_aux1);

    // W = 1 / (p * |R| + 1)
    h->uni_vmovups(
            vmm_aux3, table_val(gelu_erf_Abramowitz_Stegun_approx_const));
    h->uni_vmovups(vmm_aux4, table_val(one));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, vmm_aux4);
    h->uni_vdivps(vmm_aux4, vmm_aux4, vmm_aux3);

    // Q * W
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux4);

    // compute polynomial r
    h->uni_vmovups(vmm_aux1, table_val(gelu_erf_Abramowitz_Stegun_pol, 4));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 3));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 2));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 1));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 0));

    // erf(R) = sign * (1 - r * W * exp(-R*R))
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vxorps(vmm_src, vmm_src, vmm_aux0);

    // P = T + 0.5
    h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(half));
    // res = P + 0.5 * erf
    h->uni_vfmadd231ps(vmm_aux2, vmm_src, table_val(half));
    h->uni_vmovups(vmm_src, vmm_aux2);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::hardswish_compute_vector_bwd(
        const Vmm &vmm_src) {
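    // Reference: f(s) = s * clip(alpha * s + beta, 0, 1), hence
    //   f'(s) = 0                    if alpha * s + beta <= 0
    //   f'(s) = 2 * alpha * s + beta if 0 < alpha * s + beta < 1
    //   f'(s) = 1                    if alpha * s + beta >= 1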
    // Get mask for 0 < alpha * x + beta < 1
    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(alpha));
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(beta));
    // Form a derivative value
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    h->uni_vaddps(vmm_src, vmm_src, vmm_aux1);

    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os);
    blend_with_mask(vmm_src, table_val(zero));
    compute_cmp_mask(vmm_aux1, table_val(one), _cmp_ge_os);
    blend_with_mask(vmm_src, table_val(one));
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::hardsigmoid_compute_vector_bwd(
        const Vmm &vmm_src) {
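    // Reference: f(s) = clip(alpha * s + beta, 0, 1), hence f'(s) = alpha
    // strictly inside the clip range and 0 outside.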
    // Get mask for 0 < alpha * x + beta < 1
    // and zero out values that fall outside this range.
    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(alpha));
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(beta));

    h->uni_vmovups(vmm_src, table_val(one));
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os);
    blend_with_mask(vmm_src, table_val(zero));
    compute_cmp_mask(vmm_aux1, table_val(one), _cmp_ge_os);
    blend_with_mask(vmm_src, table_val(zero));
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
}

template <cpu_isa_t isa, typename Wmm>
size_t jit_uni_eltwise_injector_f32<isa, Wmm>::aux_gprs_count() {
    using namespace alg_kind;
    switch (alg_) {
        case eltwise_tanh_use_dst_for_bwd:
        case eltwise_tanh:
        case eltwise_gelu_tanh: return isa == sse41 || isa == avx ? 4 : 0;
        default: return 0;
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::round_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vroundps(vmm_src, vmm_src, _op_mxcsr);
}

template <cpu_isa_t isa, typename Wmm>
size_t jit_uni_eltwise_injector_f32<isa, Wmm>::aux_vecs_count() {
    using namespace alg_kind;
    if (is_fwd_) {
        switch (alg_) {
            case eltwise_relu_use_dst_for_bwd:
            case eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
            case eltwise_elu_use_dst_for_bwd:
            case eltwise_elu: return 4;
            case eltwise_tanh_use_dst_for_bwd:
            case eltwise_tanh: return 5;
            case eltwise_square: return 0;
            case eltwise_abs: return 0;
            case eltwise_sqrt_use_dst_for_bwd:
            case eltwise_sqrt: return 0;
            case eltwise_linear: return 1;
            case eltwise_soft_relu: return 4;
            case eltwise_mish: return 4;
            case eltwise_logistic_use_dst_for_bwd:
            case eltwise_logistic: return 4;
            case eltwise_exp_use_dst_for_bwd:
            case eltwise_exp: return 3;
            case eltwise_gelu_tanh: return 5;
            case eltwise_swish: return 4;
            case eltwise_log: return 5;
            case eltwise_clip:
            case eltwise_clip_v2_use_dst_for_bwd:
            case eltwise_clip_v2: return 0;
            case eltwise_pow: return 2;
            case eltwise_gelu_erf: return 5;
            case eltwise_round: return 0;
            case eltwise_hardswish: return 1;
            case eltwise_hardsigmoid: return 0;
            default: assert(!"unsupported eltwise algorithm");
        }
    } else {
        switch (alg_) {
            case eltwise_relu_use_dst_for_bwd:
            case eltwise_relu: return 1;
            case eltwise_elu_use_dst_for_bwd: return 1;
            case eltwise_elu: return 3;
            case eltwise_tanh_use_dst_for_bwd: return 1;
            case eltwise_tanh: return 5;
            case eltwise_square: return 0;
            case eltwise_abs: return 0;
            case eltwise_sqrt_use_dst_for_bwd:
            case eltwise_sqrt: return 1;
            case eltwise_linear: return 0;
            case eltwise_soft_relu: return 4;
            case eltwise_mish: return 4;
            case eltwise_logistic_use_dst_for_bwd: return 1;
            case eltwise_logistic: return 4;
            case eltwise_exp_use_dst_for_bwd: return 0;
            case eltwise_exp: return 3;
            case eltwise_gelu_tanh: return 5;
            case eltwise_swish: return 4;
            case eltwise_log: return 1;
            case eltwise_clip:
            case eltwise_clip_v2_use_dst_for_bwd:
            case eltwise_clip_v2: return 2;
            case eltwise_pow: return 2;
            case eltwise_gelu_erf: return 5;
            case eltwise_hardswish: return 2;
            case eltwise_hardsigmoid: return 2;
            default: assert(!"unsupported eltwise algorithm");
        }
    }

    return 0;
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_body(
        const injector_utils::vmm_index_set_iterator_t &start_idx_it,
        const injector_utils::vmm_index_set_iterator_t &end_idx_it) {
    using namespace alg_kind;
    std::for_each(start_idx_it, end_idx_it, [&](size_t idx) {
        if (is_fwd_) {
            switch (alg_) {
                case eltwise_relu_use_dst_for_bwd:
                case eltwise_relu:
                    if (alpha_ == 0.f)
                        relu_zero_ns_compute_vector_fwd(Vmm(idx));
                    else
                        relu_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_elu_use_dst_for_bwd:
                case eltwise_elu: elu_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_tanh_use_dst_for_bwd:
                case eltwise_tanh: tanh_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_square: square_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_abs: abs_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_sqrt_use_dst_for_bwd:
                case eltwise_sqrt: sqrt_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_swish: swish_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_linear: linear_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_soft_relu:
                    soft_relu_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_mish: mish_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_logistic_use_dst_for_bwd:
                case eltwise_logistic:
                    logistic_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_exp_use_dst_for_bwd:
                case eltwise_exp: exp_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_gelu_tanh:
                    gelu_tanh_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_log: log_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_clip:
                case eltwise_clip_v2_use_dst_for_bwd:
                case eltwise_clip_v2: clip_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_pow: pow_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_gelu_erf:
                    gelu_erf_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_round: round_compute_vector_fwd(Vmm(idx)); break;
                case eltwise_hardswish:
                    hardswish_compute_vector_fwd(Vmm(idx));
                    break;
                case eltwise_hardsigmoid:
                    hardsigmoid_compute_vector_fwd(Vmm(idx));
                    break;
                default: assert(!"unsupported eltwise algorithm");
            }
        } else {
            switch (alg_) {
                case eltwise_relu_use_dst_for_bwd:
                case eltwise_relu: relu_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_elu_use_dst_for_bwd:
                case eltwise_elu: elu_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_tanh_use_dst_for_bwd:
                case eltwise_tanh: tanh_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_square: square_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_abs: abs_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_sqrt_use_dst_for_bwd:
                case eltwise_sqrt: sqrt_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_linear: linear_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_soft_relu:
                    soft_relu_compute_vector_bwd(Vmm(idx));
                    break;
                case eltwise_mish: mish_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_logistic_use_dst_for_bwd:
                case eltwise_logistic:
                    logistic_compute_vector_bwd(Vmm(idx));
                    break;
                case eltwise_exp_use_dst_for_bwd:
                case eltwise_exp: exp_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_gelu_tanh:
                    gelu_tanh_compute_vector_bwd(Vmm(idx));
                    break;
                case eltwise_swish: swish_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_log: log_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_clip:
                case eltwise_clip_v2_use_dst_for_bwd:
                case eltwise_clip_v2: clip_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_pow: pow_compute_vector_bwd(Vmm(idx)); break;
                case eltwise_gelu_erf:
                    gelu_erf_compute_vector_bwd(Vmm(idx));
                    break;
                case eltwise_hardswish:
                    hardswish_compute_vector_bwd(Vmm(idx));
                    break;
                case eltwise_hardsigmoid:
                    hardsigmoid_compute_vector_bwd(Vmm(idx));
                    break;
                default: assert(!"unsupported eltwise algorithm");
            }
        }
        if (scale_ != 1.f) {
            h->uni_vmulps(Vmm(idx), Vmm(idx), table_val(scale));
        }
    });
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_vector_range(
        size_t start_idx, size_t end_idx) {
    injector_utils::vmm_index_set_t vmm_idxs;
    for (size_t i = start_idx; i < end_idx; i++)
        vmm_idxs.emplace(i);
    compute_vector_range(vmm_idxs);
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_vector_range(
        const injector_utils::vmm_index_set_t &vmm_idxs) {
    const auto &start_idx_it = vmm_idxs.begin();
    const auto &end_idx_it = vmm_idxs.end();
    assert(*start_idx_it < *vmm_idxs.rbegin() + 1
            && *vmm_idxs.rbegin() <= vecs_count);

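    // Two passes: injector_preamble may borrow aux registers from the head
    // of the requested range (indices before start_idx_tail), so the tail
    // part [start_idx_tail, end) is computed first; injector_preamble_tail
    // then re-assigns the preserved registers so the head part can be
    // computed as well.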
    injector_preamble(vmm_idxs);
    compute_body(start_idx_tail, end_idx_it);
    injector_preamble_tail(start_idx_it);
    compute_body(start_idx_it, start_idx_tail);
    injector_postamble();
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::prepare_table(bool gen_table) {
    if (!gen_table) return;

    h->align(64);
    h->L(l_table);

    // Assumption: entries can be inserted with dd, so they should be 4 bytes.
    assert(sizeof(table_entry_val_t) == 4);

    // Assumption: iterating over entry_map_ here visits entries in the same
    // order as when the offsets were set. We verify that in asserts.
#ifndef NDEBUG
    size_t off = 0;
    key_t curr_key = undef_key;
    int key_occurrences = 0;
#endif

    // Run through the map and insert values stored there
    for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) {
        const auto &te = (*it).second; // get map entry for a given key
        const auto len = te.bcast ? vlen : sizeof(table_entry_val_t);
        for (size_t d = 0; d < len; d += sizeof(table_entry_val_t))
            h->dd(te.val);

#ifndef NDEBUG
        // we check that the precomputed offsets match the registered ones
        const auto &key = (*it).first; // get map entry key
        if (key != curr_key) {
            curr_key = key;
            key_occurrences = 0;
        }
        key_occurrences++;
        auto expected_off = table_off(key, key_occurrences - 1);
        assert(off == expected_off);
        MAYBE_UNUSED(expected_off);
        off += len;
#endif
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::register_table_entries() {
    // This function is responsible for picking all the constants needed by
    // a given algorithm, computing the right offsets for them to be used
    // in table_val(), and storing their hexadecimal values, which are
    // finally consumed in prepare_table(). We rely on the fact that the
    // map iteration order is deterministic for a fixed map.

    // common values used in several algorithms
    static const table_t common_values {{zero, {0x00000000, true}},
            {half, {0x3f000000, true}}, {one, {0x3f800000, true}},
            {two, {0x40000000, true}}, {minus_one, {0xbf800000, true}},
            {minus_two, {0xc0000000, true}}, {ln2f, {0x3f317218, true}},
            {positive_mask, {0x7fffffff, true}},
            {sign_mask, {0x80000000, true}},
            {exponent_bias, {0x0000007f, true}}};

    // exp(x) constants
    static const table_t exp_consts {{exp_log2ef, {0x3fb8aa3b, true}},
            {exp_ln_flt_max_f, {0x42b17218, true}},
            {exp_ln_flt_min_f, {0xc2aeac50, true}}};

    // exp(x) polynomial approximation
    static const table_t exp_polynomial {
            // p0 = 1.0f
            {exp_pol, {0x3f7ffffb, true}}, // p1 = 0.999999701f
            {exp_pol, {0x3efffee3, true}}, // p2 = 0.499991506f
            {exp_pol, {0x3e2aad40, true}}, // p3 = 0.166676521f
            {exp_pol, {0x3d2b9d0d, true}}, // p4 = 0.0418978221f
            {exp_pol, {0x3c07cfce, true}} // p5 = 0.00828929059f
    };

    // mish(x) constants
    static const table_t mish_consts {
            {fwd_mish_max_x_for_equation_f, {0x42317217, true}},
            {bwd_mish_max_x_for_equation_f, {0x41b17217, true}}};

    // tanh(x) constants for four interval approximation
    static const table_t tanh_consts {{tanh_idx_bias, {0x39800000, true}},
            {tanh_idx_mask, {0xffc00000, true}},
            {tanh_linear_ubound, {0x39ddb3d7, true}},
            {tanh_saturation_lbound, {0x41102cb3, true}}};

    // tanh(x) polynomial approximation
    // For each coefficient, there are 32 entries
    static const table_t tanh_polynomial_table {
            // coefficients of degree 0
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0x39bfffff, false}},
            {tanh_pol_table, {0x39ffffff, false}},
            {tanh_pol_table, {0x3a3ffffe, false}},
            {tanh_pol_table, {0x3a7ffffb, false}},
            {tanh_pol_table, {0x3abffff7, false}},
            {tanh_pol_table, {0x3affffeb, false}},
            {tanh_pol_table, {0x3b3fffdc, false}},
            {tanh_pol_table, {0x3b7fffab, false}},
            {tanh_pol_table, {0x3bbfff70, false}},
            {tanh_pol_table, {0x3bfffeab, false}},
            {tanh_pol_table, {0x3c3ffdc0, false}},
            {tanh_pol_table, {0x3c7ffaab, false}},
            {tanh_pol_table, {0x3cbff701, false}},
            {tanh_pol_table, {0x3cffeaad, false}},
            {tanh_pol_table, {0x3d3fdc08, false}},
            {tanh_pol_table, {0x3d7faacd, false}},
            {tanh_pol_table, {0x3dbf7081, false}},
            {tanh_pol_table, {0x3dfeacc9, false}},
            {tanh_pol_table, {0x3e3dc7fd, false}},
            {tanh_pol_table, {0x3e7acbf5, false}},
            {tanh_pol_table, {0x3eb77a9f, false}},
            {tanh_pol_table, {0x3eec9a9f, false}},
            {tanh_pol_table, {0x3f22991f, false}},
            {tanh_pol_table, {0x3f42f7d6, false}},
            {tanh_pol_table, {0x3f67b7cc, false}},
            {tanh_pol_table, {0x3f76ca83, false}},
            {tanh_pol_table, {0x3f7ebbe9, false}},
            {tanh_pol_table, {0x3f7fd40c, false}},
            {tanh_pol_table, {0x3f7fff32, false}},
            {tanh_pol_table, {0x3f7ffffc, false}},
            {tanh_pol_table, {0x3f800000, false}},
            // coefficients of degree 1
            {tanh_pol_table, {0x3f800000, false}},
            {tanh_pol_table, {0x3f800018, false}},
            {tanh_pol_table, {0x3f7fffe8, false}},
            {tanh_pol_table, {0x3f7fffda, false}},
            {tanh_pol_table, {0x3f7fffdc, false}},
            {tanh_pol_table, {0x3f7fffdc, false}},
            {tanh_pol_table, {0x3f7fffac, false}},
            {tanh_pol_table, {0x3f7fff70, false}},
            {tanh_pol_table, {0x3f7ffeec, false}},
            {tanh_pol_table, {0x3f7ffdc0, false}},
            {tanh_pol_table, {0x3f7ffbed, false}},
            {tanh_pol_table, {0x3f7ff704, false}},
            {tanh_pol_table, {0x3f7feff5, false}},
            {tanh_pol_table, {0x3f7fdbca, false}},
            {tanh_pol_table, {0x3f7fbfff, false}},
            {tanh_pol_table, {0x3f7f7041, false}},
            {tanh_pol_table, {0x3f7f009b, false}},
            {tanh_pol_table, {0x3f7dc36c, false}},
            {tanh_pol_table, {0x3f7c0aa8, false}},
            {tanh_pol_table, {0x3f7734b8, false}},
            {tanh_pol_table, {0x3f70a4de, false}},
            {tanh_pol_table, {0x3f5f1fd8, false}},
            {tanh_pol_table, {0x3f495493, false}},
            {tanh_pol_table, {0x3f18b9ec, false}},
            {tanh_pol_table, {0x3ed706cb, false}},
            {tanh_pol_table, {0x3e390b06, false}},
            {tanh_pol_table, {0x3d90b11f, false}},
            {tanh_pol_table, {0x3c21a053, false}},
            {tanh_pol_table, {0x3aaf7fdb, false}},
            {tanh_pol_table, {0x37ccc1a3, false}},
            {tanh_pol_table, {0x355c6733, false}},
            {tanh_pol_table, {0x00000000, false}},
            // coefficients of degree 2
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0xbe4e0ff1, false}},
            {tanh_pol_table, {0x3d25b1b1, false}},
            {tanh_pol_table, {0x3d6b6dab, false}},
            {tanh_pol_table, {0x3c9fb1d5, false}},
            {tanh_pol_table, {0xbabff06f, false}},
            {tanh_pol_table, {0x3c07b3f6, false}},
            {tanh_pol_table, {0xbb3fc1bc, false}},
            {tanh_pol_table, {0x3a9f5921, false}},
            {tanh_pol_table, {0xbbbf06f2, false}},
            {tanh_pol_table, {0xbbb0f402, false}},
            {tanh_pol_table, {0xbc47db9e, false}},
            {tanh_pol_table, {0xbc73d5e7, false}},
            {tanh_pol_table, {0xbca25bda, false}},
            {tanh_pol_table, {0xbcfca780, false}},
            {tanh_pol_table, {0xbd40e07c, false}},
            {tanh_pol_table, {0xbd7dab03, false}},
            {tanh_pol_table, {0xbdbe4a0f, false}},
            {tanh_pol_table, {0xbdfb14a5, false}},
            {tanh_pol_table, {0xbe36cc8d, false}},
            {tanh_pol_table, {0xbe6bd102, false}},
            {tanh_pol_table, {0xbe9fe7c5, false}},
            {tanh_pol_table, {0xbeba0f10, false}},
            {tanh_pol_table, {0xbec206a8, false}},
            {tanh_pol_table, {0xbea3c388, false}},
            {tanh_pol_table, {0xbe277d62, false}},
            {tanh_pol_table, {0xbd8b7960, false}},
            {tanh_pol_table, {0xbc209f49, false}},
            {tanh_pol_table, {0xbaad44ca, false}},
            {tanh_pol_table, {0xb7c6eeac, false}},
            {tanh_pol_table, {0xb663aa41, false}},
            {tanh_pol_table, {0x00000000, false}},
            // coefficients of degree 3
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0x45b3ae96, false}},
            {tanh_pol_table, {0xc414eb20, false}},
            {tanh_pol_table, {0xc450e02e, false}},
            {tanh_pol_table, {0xc3152b4e, false}},
            {tanh_pol_table, {0xbead2f56, false}},
            {tanh_pol_table, {0xc2162e02, false}},
            {tanh_pol_table, {0xbeb4bd5a, false}},
            {tanh_pol_table, {0xc11a59a4, false}},
            {tanh_pol_table, {0xbed2f507, false}},
            {tanh_pol_table, {0xc020d32c, false}},
            {tanh_pol_table, {0x3dd0f506, false}},
            {tanh_pol_table, {0xbf2a75e2, false}},
            {tanh_pol_table, {0xbff950e3, false}},
            {tanh_pol_table, {0xbed47334, false}},
            {tanh_pol_table, {0xbe809b8c, false}},
            {tanh_pol_table, {0xbeb64532, false}},
            {tanh_pol_table, {0xbe961a5b, false}},
            {tanh_pol_table, {0xbe9b63ac, false}},
            {tanh_pol_table, {0xbea0d4b2, false}},
            {tanh_pol_table, {0xbe828a77, false}},
            {tanh_pol_table, {0xbe378612, false}},
            {tanh_pol_table, {0xbdc20908, false}},
            {tanh_pol_table, {0x3d2d3957, false}},
            {tanh_pol_table, {0x3dd46e89, false}},
            {tanh_pol_table, {0x3db3f629, false}},
            {tanh_pol_table, {0x3d2c5e7b, false}},
            {tanh_pol_table, {0x3bd20403, false}},
            {tanh_pol_table, {0x3a59dfae, false}},
            {tanh_pol_table, {0x3770af45, false}},
            {tanh_pol_table, {0x372cc014, false}},
            {tanh_pol_table, {0x00000000, false}},
            // coefficients of degree 4
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0xcc981a1b, false}},
            {tanh_pol_table, {0x4a7edd3d, false}},
            {tanh_pol_table, {0x4ab1007c, false}},
            {tanh_pol_table, {0x48fedd9c, false}},
            {tanh_pol_table, {0x41a557b5, false}},
            {tanh_pol_table, {0x477ee32a, false}},
            {tanh_pol_table, {0x422557f5, false}},
            {tanh_pol_table, {0x45ff3ce4, false}},
            {tanh_pol_table, {0x42a55641, false}},
            {tanh_pol_table, {0x446e0867, false}},
            {tanh_pol_table, {0xc33dc19a, false}},
            {tanh_pol_table, {0x42915214, false}},
            {tanh_pol_table, {0x43af4fad, false}},
            {tanh_pol_table, {0x4110fe88, false}},
            {tanh_pol_table, {0xc1099b75, false}},
            {tanh_pol_table, {0x3fc8a8dc, false}},
            {tanh_pol_table, {0xbfbeaef5, false}},
            {tanh_pol_table, {0xbe365aad, false}},
            {tanh_pol_table, {0x3f4d9652, false}},
            {tanh_pol_table, {0x3ddfa08f, false}},
            {tanh_pol_table, {0x3e34e9b8, false}},
            {tanh_pol_table, {0x3e2d07a6, false}},
            {tanh_pol_table, {0x3dc63567, false}},
            {tanh_pol_table, {0x3cdaeb78, false}},
            {tanh_pol_table, {0xbcd17537, false}},
            {tanh_pol_table, {0xbc92829c, false}},
            {tanh_pol_table, {0xbb43ab99, false}},
            {tanh_pol_table, {0xb9b471dd, false}},
            {tanh_pol_table, {0xb6baad5a, false}},
            {tanh_pol_table, {0xb78bafc7, false}},
            {tanh_pol_table, {0x00000000, false}},
            // coefficients of degree 5
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0x52f688d5, false}},
            {tanh_pol_table, {0xd0505c72, false}},
            {tanh_pol_table, {0xd08f98e3, false}},
            {tanh_pol_table, {0xce505cc9, false}},
            {tanh_pol_table, {0xc7162b8a, false}},
            {tanh_pol_table, {0xcc5061d6, false}},
            {tanh_pol_table, {0xc7162bdf, false}},
            {tanh_pol_table, {0xca50b37f, false}},
            {tanh_pol_table, {0xc7162a3a, false}},
            {tanh_pol_table, {0xc8422086, false}},
            {tanh_pol_table, {0x471a714e, false}},
            {tanh_pol_table, {0xc5ece1f1, false}},
            {tanh_pol_table, {0xc70e3d90, false}},
            {tanh_pol_table, {0xc3eba94a, false}},
            {tanh_pol_table, {0x43e0c424, false}},
            {tanh_pol_table, {0xc21f4552, false}},
            {tanh_pol_table, {0x42217cc8, false}},
            {tanh_pol_table, {0x405e7dc4, false}},
            {tanh_pol_table, {0xc10dd401, false}},
            {tanh_pol_table, {0x3e96b602, false}},
            {tanh_pol_table, {0xbd1a6d2f, false}},
            {tanh_pol_table, {0xbd393883, false}},
            {tanh_pol_table, {0xbd674682, false}},
            {tanh_pol_table, {0xbd310016, false}},
            {tanh_pol_table, {0xb961e269, false}},
            {tanh_pol_table, {0x3ba32495, false}},
            {tanh_pol_table, {0x3a7680d5, false}},
            {tanh_pol_table, {0x38b3173c, false}},
            {tanh_pol_table, {0x35a9deea, false}},
            {tanh_pol_table, {0x375c3f2a, false}},
            {tanh_pol_table, {0x00000000, false}},
            // coefficients of degree 6
            {tanh_pol_table, {0x00000000, false}},
            {tanh_pol_table, {0xd8995ed1, false}},
            {tanh_pol_table, {0x558285ea, false}},
            {tanh_pol_table, {0x55b2cd69, false}},
            {tanh_pol_table, {0x53028625, false}},
            {tanh_pol_table, {0x4bc9991f, false}},
            {tanh_pol_table, {0x5082898a, false}},
            {tanh_pol_table, {0x4b4999b3, false}},
            {tanh_pol_table, {0x4e02c07c, false}},
            {tanh_pol_table, {0x4ac99764, false}},
            {tanh_pol_table, {0x4b72c822, false}},
            {tanh_pol_table, {0xca40c0e1, false}},
            {tanh_pol_table, {0x489413e4, false}},
            {tanh_pol_table, {0x49b12224, false}},
            {tanh_pol_table, {0x46134c4e, false}},
            {tanh_pol_table, {0xc60c2d57, false}},
            {tanh_pol_table, {0x43c83910, false}},
            {tanh_pol_table, {0xc3c872d1, false}},
            {tanh_pol_table, {0xc186bc9e, false}},
            {tanh_pol_table, {0x42325bc3, false}},
            {tanh_pol_table, {0xbf2ffa4a, false}},
            {tanh_pol_table, {0x3d9a203c, false}},
            {tanh_pol_table, {0xbc545a43, false}},
            {tanh_pol_table, {0xbae08fee, false}},
            {tanh_pol_table, {0x3c80225d, false}},
            {tanh_pol_table, {0x3b1fd1df, false}},
            {tanh_pol_table, {0xba36b9d1, false}},
            {tanh_pol_table, {0xb91de544, false}},
            {tanh_pol_table, {0xb71f100f, false}},
            {tanh_pol_table, {0xb408e2ed, false}},
            {tanh_pol_table, {0xb685fec8, false}},
            {tanh_pol_table, {0x00000000, false}},
    };

    // soft_relu(x) constants
    static const table_t soft_relu_consts {
            {soft_relu_one_twenty_six, {0x42fc0000, true}},
            {soft_relu_mantissa_sign_mask, {0x807fffff, true}},
    };

    // soft_relu ln(1 + x) polynomial approximation
    static const table_t soft_relu_polynomial {
            {soft_relu_pol, {0xb2b4637d, true}}, // p0 = 0.0000000244f
            {soft_relu_pol, {0x3f7fff8e, true}}, // p1 = 0.9999976971f
            {soft_relu_pol, {0xbf001759, true}}, // p2 = -0.5002478215f
            {soft_relu_pol, {0x3ea70608, true}}, // p3 = 0.3272714505f
            {soft_relu_pol, {0xbea3d7bf, true}}, // p4 = -0.3153830071f
            {soft_relu_pol, {0xbe361d04, true}}, // p5 = -0.1701777461f
            {soft_relu_pol, {0xbfa8f1e6, true}}, // p6 = -1.3254635147f
            {soft_relu_pol, {0xbfe1e812, true}}, // p7 = -1.7971917960f
            {soft_relu_pol, {0xbfc4d30e, true}}, // p8 = -1.5652673123f
    };

    // gelu_tanh(x) constants (formula defined)
    static const table_t gelu_tanh_consts {
            {gelu_tanh_fitting_const, {0x3d372713, true}},
            {gelu_tanh_fitting_const_times_three, {0x3e095d4f, true}},
            {gelu_tanh_sqrt_two_over_pi, {0x3f4c422a, true}},
    };

    // gelu_erf(x) constants for approximation based on Abramowitz and Stegun
    // algorithm (formula defined)
    static const table_t gelu_erf_Abramowitz_Stegun_consts {
            {gelu_erf_Abramowitz_Stegun_approx_const, {0x3ea7ba05, true}},
            {gelu_erf_Abramowitz_Stegun_one_over_sqrt_two, {0x3f3504f3, true}},
            {gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi, {0x3f106eba, true}},
    };

    // gelu_erf(x) polynomial approximation based on Abramowitz and Stegun
    // algorithm
    static const table_t gelu_erf_Abramowitz_Stegun_polynomial {
            // p1 = 0.254829592f
            {gelu_erf_Abramowitz_Stegun_pol, {0x3e827906, true}},
            // p2 = -0.284496736f
            {gelu_erf_Abramowitz_Stegun_pol, {0xbe91a98e, true}},
            // p3 = 1.421413741f
            {gelu_erf_Abramowitz_Stegun_pol, {0x3fb5f0e3, true}},
            // p4 = -1.453152027f
            {gelu_erf_Abramowitz_Stegun_pol, {0xbfba00e3, true}},
            // p5 = 1.061405429f
            {gelu_erf_Abramowitz_Stegun_pol, {0x3f87dc22, true}},
    };

    // gelu_erf(x) constants for direct erf approximation (formula defined)
    static const table_t gelu_erf_minimax_consts {
            // x <= -0x1.4p+2 -> return 0.0f
            {gelu_erf_minimax_neg_saturation_ubound, {0xc0a00000, true}},
            // |x| <= 0x1.0p-24 -> return 0.5f * x
            {gelu_erf_minimax_linear_ubound, {0x33800000, true}},
            // x >= 0x1.4p+2 -> return x
            {gelu_erf_minimax_saturation_lbound, {0x40a00000, true}}};

    // gelu_erf(x) polynomial for direct erf approximation (formula defined)
    static const table_t gelu_erf_minimax_polynomial {
            {gelu_erf_minimax_pol, {0x3f4c4228, true}}, // p0 = 0x1.98845p-1
            {gelu_erf_minimax_pol, {0xbe082bc7, true}}, // p1 = -0x1.10578ep-3
            {gelu_erf_minimax_pol, {0x3ca3621f, true}}, // p2 = 0x1.46c43ep-6
            {gelu_erf_minimax_pol, {0xbb1b7399, true}}, // p3 = -0x1.36e732p-9
            {gelu_erf_minimax_pol, {0x3970b255, true}}, // p4 = 0x1.e164aap-13
            {gelu_erf_minimax_pol, {0xb79b0914, true}}, // p5 = -0x1.361228p-16
            {gelu_erf_minimax_pol, {0x35a776e9, true}}, // p6 = 0x1.4eedd2p-20
            {gelu_erf_minimax_pol, {0xb3969b11, true}}, // p7 = -0x1.2d3622p-24
            {gelu_erf_minimax_pol, {0x315d4a4f, true}}, // p8 = 0x1.ba949ep-29
            {gelu_erf_minimax_pol, {0xaf013b2c, true}}, // p9 = -0x1.027658p-33
            {gelu_erf_minimax_pol, {0x2c67ddb2, true}}, // p10 = 0x1.cfbb64p-39
            {gelu_erf_minimax_pol, {0xa998c963, true}}, // p11 = -0x1.3192c6p-44
            {gelu_erf_minimax_pol, {0x268a7927, true}}, // p12 = 0x1.14f24ep-50
            {gelu_erf_minimax_pol, {0xa3198977, true}}, // p13 = -0x1.3312eep-57
            {gelu_erf_minimax_pol, {0x1f1c83fd, true}}, // p14 = 0x1.3907fap-65
    };

    // log(x) constants
    static const table_t log_consts {
            {log_inf, {0x7f800000, true}},
            {log_minus_inf, {0xff800000, true}},
            {log_qnan, {0x7fc00000, true}},
            {log_mantissa_mask, {0x007fffff, true}},
            {log_full_k_reg_mask, {0x0000ffff, true}},
            {log_five_bit_offset, {0x0000001f, true}},
    };

    // log(x) polynomial approximation
    static const table_t log_polynomial {
            {log_pol, {0xbf000000, true}}, // p1 = -0.5f
            {log_pol, {0x3eaaaaab, true}}, // p2 = 0.333333343f
            {log_pol, {0xbe8004ab, true}}, // p3 = -0.250035613f
            {log_pol, {0x3e4cc8a3, true}}, // p4 = 0.199984118f
    };

    // log(x) pre-defined values. First goes index, then val[index].
    static const table_t log_predefined_values {
            {log_predefined_vals, {0x3f800000, true}}, // 0: 1
            {log_predefined_vals,
                    {0xc2b00f34, true}}, // 1: -88.029693603515625
            {log_predefined_vals, {0x3f780000, true}}, // 2: 0.96875
            {log_predefined_vals,
                    {0xc2affef2, true}}, // 3: -87.9979400634765625
            {log_predefined_vals, {0x3f700000, true}}, // 4: 0.9375
            {log_predefined_vals,
                    {0xc2afee29, true}}, // 5: -87.9651565551757812
            {log_predefined_vals, {0x3f680000, true}}, // 6: 0.90625
            {log_predefined_vals,
                    {0xc2afdccd, true}}, // 7: -87.9312515258789062
            {log_predefined_vals, {0x3f600000, true}}, // 8: 0.875
            {log_predefined_vals,
                    {0xc2afcad6, true}}, // 9: -87.8961639404296875
            {log_predefined_vals, {0x3f580000, true}}, // 10: 0.84375
            {log_predefined_vals,
                    {0xc2afb837, true}}, // 11: -87.859794616699218
            {log_predefined_vals, {0x3f580000, true}}, // 12: 0.84375
            {log_predefined_vals,
                    {0xc2afb837, true}}, // 13: -87.859794616699218
            {log_predefined_vals, {0x3f500000, true}}, // 14: 0.8125
            {log_predefined_vals,
                    {0xc2afa4e4, true}}, // 15: -87.822052001953125
            {log_predefined_vals, {0x3f480000, true}}, // 16: 0.78125
            {log_predefined_vals,
                    {0xc2af90cf, true}}, // 17: -87.782829284667968
            {log_predefined_vals, {0x3f480000, true}}, // 18: 0.78125
            {log_predefined_vals,
                    {0xc2af90cf, true}}, // 19: -87.782829284667968
            {log_predefined_vals, {0x3f400000, true}}, // 20: 0.75
            {log_predefined_vals,
                    {0xc2af7be9, true}}, // 21: -87.742012023925781
            {log_predefined_vals, {0x3f400000, true}}, // 22: 0.75
            {log_predefined_vals,
                    {0xc2af7be9, true}}, // 23: -87.742012023925781
            {log_predefined_vals, {0x3f380000, true}}, // 24: 0.71875
            {log_predefined_vals,
                    {0xc2af661e, true}}, // 25: -87.699447631835937
            {log_predefined_vals, {0x3f380000, true}}, // 26: 0.71875
            {log_predefined_vals,
                    {0xc2af661e, true}}, // 27: -87.699447631835937
            {log_predefined_vals, {0x3f300000, true}}, // 28: 0.6875
            {log_predefined_vals,
                    {0xc2af4f5c, true}}, // 29: -87.654998779296875
            {log_predefined_vals, {0x3f300000, true}}, // 30: 0.6875
            {log_predefined_vals,
                    {0xc2af4f5c, true}}, // 31: -87.654998779296875
            {log_predefined_vals, {0x3fa80000, true}}, // 32: 1.3125
            {log_predefined_vals,
                    {0xc2b09a6f, true}}, // 33: -88.301628112792968
            {log_predefined_vals, {0x3fa80000, true}}, // 34: 1.3125
            {log_predefined_vals,
                    {0xc2b09a6f, true}}, // 35: -88.301628112792968
            {log_predefined_vals, {0x3fa00000, true}}, // 36: 1.25
            {log_predefined_vals,
                    {0xc2b08174, true}}, // 37: -88.252838134765625
            {log_predefined_vals, {0x3fa00000, true}}, // 38: 1.25
            {log_predefined_vals,
                    {0xc2b08174, true}}, // 39: -88.252838134765625
            {log_predefined_vals, {0x3fa00000, true}}, // 40: 1.25
            {log_predefined_vals,
                    {0xc2b08174, true}}, // 41: -88.252838134765625
            {log_predefined_vals, {0x3f980000, true}}, // 42: 1.1875
            {log_predefined_vals,
                    {0xc2b06731, true}}, // 43: -88.201545715332031
            {log_predefined_vals, {0x3f980000, true}}, // 44: 1.1875
            {log_predefined_vals,
                    {0xc2b06731, true}}, // 45: -88.201545715332031
            {log_predefined_vals, {0x3f900000, true}}, // 46: 1.125
            {log_predefined_vals,
                    {0xc2b04b82, true}}, // 47: -88.147476196289062
            {log_predefined_vals, {0x3f900000, true}}, // 48: 1.125
            {log_predefined_vals,
                    {0xc2b04b82, true}}, // 49: -88.147476196289062
            {log_predefined_vals, {0x3f900000, true}}, // 50: 1.125
            {log_predefined_vals,
                    {0xc2b04b82, true}}, // 51: -88.147476196289062
            {log_predefined_vals, {0x3f900000, true}}, // 52: 1.125
            {log_predefined_vals,
                    {0xc2b04b82, true}}, // 53: -88.147476196289062
            {log_predefined_vals, {0x3f880000, true}}, // 54: 1.0625
            {log_predefined_vals,
                    {0xc2b02e3e, true}}, // 55: -88.090316772460937
            {log_predefined_vals, {0x3f880000, true}}, // 56: 1.0625
            {log_predefined_vals,
                    {0xc2b02e3e, true}}, // 57: -88.090316772460937
            {log_predefined_vals, {0x3f880000, true}}, // 58: 1.0625
            {log_predefined_vals,
                    {0xc2b02e3e, true}}, // 59: -88.090316772460937
            {log_predefined_vals, {0x3f800000, true}}, // 60: 1
            {log_predefined_vals,
                    {0xc2b00f34, true}}, // 61: -88.029693603515625
            {log_predefined_vals, {0x3f800000, true}}, // 62: 1
            {log_predefined_vals,
                    {0xc2b00f34, true}}, // 63: -88.029693603515625
    };

    // This object decides which constants and polynomials to include for a
    // given algorithm.
    struct need_t {
        need_t(alg_kind_t alg) {
            using namespace alg_kind;
            switch (alg) {
                case eltwise_elu_use_dst_for_bwd:
                case eltwise_elu:
                case eltwise_exp_use_dst_for_bwd:
                case eltwise_exp:
                case eltwise_logistic_use_dst_for_bwd:
                case eltwise_logistic:
                case eltwise_swish: exp_ = true; break;
                case eltwise_gelu_erf: gelu_erf_ = true; break;
                case eltwise_gelu_tanh: gelu_tanh_ = true; break;
                case eltwise_log: log_ = true; break;
                case eltwise_soft_relu: soft_relu_ = true; break;
                case eltwise_mish: mish_ = true; break;
                case eltwise_tanh_use_dst_for_bwd:
                case eltwise_tanh: tanh_ = true; break;
                default: break;
            }
        }

        bool exp_ = false;
        bool mish_ = false;
        bool tanh_ = false;
        bool soft_relu_ = false;
        bool gelu_tanh_ = false;
        bool gelu_erf_ = false;
        bool log_ = false;

        bool exp() const { return exp_ || soft_relu_ || gelu_erf_ || mish_; }
        bool mish() const { return mish_; }
        bool tanh() const { return tanh_ || gelu_tanh_; }
        bool soft_relu() const { return soft_relu_; }
        bool gelu_tanh() const { return gelu_tanh_; }
        bool gelu_erf() const { return gelu_erf_; }
        bool log() const { return log_; }
    };

    need_t need(alg_);

    auto push_arg_entry_of = [&](const key_t key, const table_entry_val_t val,
                                     const bool broadcast) {
        mapped_table_entry_t te {0, val, broadcast};
        entry_map_.insert(std::make_pair(key, te));
    };

    auto push_entries_of = [&](const table_t &t) {
        for (auto it = t.begin(); it != t.end(); it++) {
            auto key = (*it).first;
            auto te = (*it).second; // copy values from table
            push_arg_entry_of(key, te.val, te.bcast);
        }
    };
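    // Note: entry_map_ is a multimap, so entries sharing a key (e.g.
    // polynomial coefficients) keep their insertion order; table_val(key, i)
    // relies on that order to address the i-th entry of a key.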

    push_arg_entry_of(scale, float2int(scale_), true);
    push_arg_entry_of(alpha, float2int(alpha_), true);
    push_arg_entry_of(beta, float2int(beta_), true);
    push_entries_of(common_values);
    if (need.exp()) push_entries_of(exp_consts);
    if (need.exp()) push_entries_of(exp_polynomial);
    if (need.mish()) push_entries_of(mish_consts);
    if (need.tanh()) push_entries_of(tanh_consts);
    if (need.tanh()) push_entries_of(tanh_polynomial_table);
    if (need.soft_relu()) push_entries_of(soft_relu_consts);
    if (need.soft_relu()) push_entries_of(soft_relu_polynomial);
    if (need.gelu_tanh()) push_entries_of(gelu_tanh_consts);
    if (need.gelu_erf()) push_entries_of(gelu_erf_Abramowitz_Stegun_consts);
    if (need.gelu_erf()) push_entries_of(gelu_erf_Abramowitz_Stegun_polynomial);
    if (need.gelu_erf() && is_superset(isa, avx512_core))
        push_entries_of(gelu_erf_minimax_consts);
    if (need.gelu_erf() && is_superset(isa, avx512_core))
        push_entries_of(gelu_erf_minimax_polynomial);

    if (need.log()) push_entries_of(log_consts);
    if (need.log()) push_entries_of(log_polynomial);
    if (need.log()) push_entries_of(log_predefined_values);

    // Now that the entries are registered, we set the offsets. No entries
    // should be registered after this point: this guarantees that
    // prepare_table() injects the table entries in the same order.
    size_t off = 0;
    for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) {
        auto &te = (*it).second;
        te.off = off;
        off += te.bcast ? vlen : sizeof(table_entry_val_t);
    }
}

template struct jit_uni_eltwise_injector_f32<avx512_core_fp16>;
template struct jit_uni_eltwise_injector_f32<avx512_core_fp16, Xbyak::Ymm>;
template struct jit_uni_eltwise_injector_f32<avx512_core_fp16, Xbyak::Xmm>;
template struct jit_uni_eltwise_injector_f32<avx512_core_bf16>;
template struct jit_uni_eltwise_injector_f32<avx512_core>;
template struct jit_uni_eltwise_injector_f32<avx512_core, Ymm>;
template struct jit_uni_eltwise_injector_f32<avx512_core, Xmm>;
template struct jit_uni_eltwise_injector_f32<avx2_vnni_2>;
template struct jit_uni_eltwise_injector_f32<avx2>;
template struct jit_uni_eltwise_injector_f32<avx2, Xmm>;
template struct jit_uni_eltwise_injector_f32<avx>;
template struct jit_uni_eltwise_injector_f32<avx, Xmm>;
template struct jit_uni_eltwise_injector_f32<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl