1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | |
29 | namespace eltwise_injector { |
30 | |
31 | bool is_isa_supported(cpu_isa_t isa) { |
32 | return is_superset(isa, sse41); |
33 | } |
34 | |
35 | bool is_alg_supported(alg_kind_t alg) { |
36 | using namespace alg_kind; |
37 | return utils::one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu, |
38 | eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear, |
39 | eltwise_soft_relu, eltwise_logistic, eltwise_mish, eltwise_exp, |
40 | eltwise_gelu_tanh, eltwise_hardsigmoid, eltwise_hardswish, |
41 | eltwise_swish, eltwise_log, eltwise_clip, eltwise_clip_v2, |
42 | eltwise_pow, eltwise_gelu_erf, eltwise_round, |
43 | eltwise_relu_use_dst_for_bwd, eltwise_tanh_use_dst_for_bwd, |
44 | eltwise_elu_use_dst_for_bwd, eltwise_sqrt_use_dst_for_bwd, |
45 | eltwise_logistic_use_dst_for_bwd, eltwise_exp_use_dst_for_bwd, |
46 | eltwise_clip_v2_use_dst_for_bwd); |
47 | } |
48 | |
49 | bool is_supported(cpu_isa_t isa, alg_kind_t alg) { |
50 | return is_isa_supported(isa) && is_alg_supported(alg); |
51 | } |
52 | |
53 | } // namespace eltwise_injector |
54 | |
55 | using namespace Xbyak; |
56 | |
// Picks the auxiliary vector registers and GPRs the injector will use while
// processing the input registers listed in `vmm_idxs`. When save_state_ is
// set, it also pushes p_table (optional) and the chosen GPRs, spills the
// chosen vectors (optional), and loads the table address. Finally it binds
// the vmm_aux*/vmm_mask/*_tmp members via assign_regs().
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble(
        const injector_utils::vmm_index_set_t &vmm_idxs) {
    using namespace alg_kind;
    using namespace Xbyak::util;
    preserved_vecs_count = 0;
    vecs_to_preserve = aux_vecs_count();
    // The input registers are treated as spanning [start_idx, end_idx):
    // begin() is read as the smallest index and rbegin() as the largest.
    // vmm_idxs must not be empty, since both iterators are dereferenced.
    // NOTE(review): assumes vmm_index_set_t is an ordered set; confirm.
    const auto start_idx = *(vmm_idxs.begin());
    const auto end_idx = *(vmm_idxs.rbegin()) + 1;
    // Advanced once per input register borrowed as an aux register (see the
    // fallback loop below); injector_preamble_tail() counts those.
    start_idx_tail = vmm_idxs.begin();

    // For avx we need a register to save the upper part of Ymm
    preserve_vec_for_avx = isa == avx
            && utils::one_of(alg_, eltwise_tanh, eltwise_elu, eltwise_abs,
                    eltwise_soft_relu, eltwise_mish, eltwise_logistic,
                    eltwise_exp, eltwise_gelu_tanh, eltwise_swish,
                    eltwise_gelu_erf, eltwise_tanh_use_dst_for_bwd,
                    eltwise_elu_use_dst_for_bwd,
                    eltwise_logistic_use_dst_for_bwd,
                    eltwise_exp_use_dst_for_bwd);
    if (preserve_vec_for_avx) vecs_to_preserve++;

    // For sse41 mask register has to be Xmm(0)
    // assign_regs() maps vmm_mask to preserved_vec_idxs[0], so index 0 is
    // reserved first. The assert requires that it is not an input register.
    if (isa == sse41 && vecs_to_preserve > 0) {
        size_t idx = 0;
        assert(idx < start_idx);
        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    // First choice: registers outside the input range. The loop starts at
    // preserved_vecs_count, so Xmm(0) is not picked twice on sse41.
    for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
        if (preserved_vecs_count >= vecs_to_preserve) break;
        if (start_idx <= idx && idx < end_idx) continue;

        preserved_vec_idxs[preserved_vecs_count++] = idx;
    }

    // Fallback: borrow the leading input registers for whatever is still
    // missing. start_idx_tail ends up just past the borrowed ones.
    size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
    for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
        preserved_vec_idxs[preserved_vecs_count++] = *start_idx_tail;
        ++start_idx_tail;
    }

    assert(preserved_vecs_count == vecs_to_preserve);

    // Same logic but to allocate gprs: walk indices from R15 down to 0,
    // skipping p_table's register and RSP.
    size_t preserved_gprs_count = 0;
    for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) {
        int _idx = Operand::R15 - gpr_idx; // we allocate from the end
        if (preserved_gprs_count < aux_gprs_count()
                && !utils::one_of(_idx, p_table.getIdx(), Operand::RSP))
            preserved_gpr_idxs[preserved_gprs_count++] = _idx;
    }
    assert(preserved_gprs_count == aux_gprs_count());

    // Save order: p_table, then the GPRs, then the vectors. Vector i is
    // spilled to [rsp + i * vlen]. injector_postamble() reverses this.
    if (save_state_) {
        if (preserve_p_table_) h->push(p_table);
        for (size_t i = 0; i < preserved_gprs_count; ++i)
            h->push(Reg64(preserved_gpr_idxs[i]));

        if (preserve_vmm_) {
            if (preserved_vecs_count)
                h->sub(h->rsp, preserved_vecs_count * vlen);

            for (size_t i = 0; i < preserved_vecs_count; ++i)
                h->uni_vmovups(
                        h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
        }
        load_table_addr();
    }

    assign_regs();
}
129 | |
130 | template <cpu_isa_t isa, typename Wmm> |
131 | void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble_tail( |
132 | const injector_utils::vmm_index_set_iterator_t start_idx_it) { |
133 | size_t tail_vecs_to_preserve = std::distance(start_idx_it, start_idx_tail); |
134 | if (tail_vecs_to_preserve == 0) return; |
135 | |
136 | const int idx_off = vecs_to_preserve - tail_vecs_to_preserve; |
137 | |
138 | if (save_state_) { |
139 | if (idx_off) h->add(h->rsp, idx_off * vlen); |
140 | |
141 | for (size_t i = 0; i < tail_vecs_to_preserve; ++i) |
142 | h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]), |
143 | h->ptr[h->rsp + i * vlen]); |
144 | } |
145 | |
146 | for (size_t i = 0; i < tail_vecs_to_preserve; ++i) |
147 | preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve; |
148 | |
149 | if (save_state_ && preserve_vmm_) { |
150 | for (size_t i = 0; i < tail_vecs_to_preserve; ++i) |
151 | h->uni_vmovups(h->ptr[h->rsp + i * vlen], |
152 | Vmm(preserved_vec_idxs[idx_off + i])); |
153 | |
154 | if (idx_off) h->sub(h->rsp, idx_off * vlen); |
155 | } |
156 | |
157 | assign_regs(); |
158 | } |
159 | |
160 | template <cpu_isa_t isa, typename Wmm> |
161 | void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_postamble() { |
162 | using namespace Xbyak::util; |
163 | if (!save_state_) return; |
164 | |
165 | if (preserve_vmm_) { |
166 | for (size_t i = 0; i < preserved_vecs_count; ++i) |
167 | h->uni_vmovups( |
168 | Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]); |
169 | |
170 | if (preserved_vecs_count) h->add(h->rsp, preserved_vecs_count * vlen); |
171 | } |
172 | |
173 | for (int i = aux_gprs_count() - 1; i >= 0; --i) |
174 | h->pop(Reg64(preserved_gpr_idxs[i])); |
175 | if (preserve_p_table_) h->pop(p_table); |
176 | } |
177 | |
178 | template <cpu_isa_t isa, typename Wmm> |
179 | void jit_uni_eltwise_injector_f32<isa, Wmm>::assign_regs() { |
180 | vmm_mask = Vmm(preserved_vec_idxs[0]); |
181 | vmm_aux0 = Vmm(preserved_vec_idxs[0]); |
182 | vmm_aux1 = Vmm(preserved_vec_idxs[1]); |
183 | vmm_aux2 = Vmm(preserved_vec_idxs[2]); |
184 | vmm_aux3 = Vmm(preserved_vec_idxs[3]); |
185 | vmm_aux4 = Vmm(preserved_vec_idxs[4]); |
186 | if (preserve_vec_for_avx) { |
187 | vmm_tmp = Vmm(preserved_vec_idxs[vecs_to_preserve - 1]); |
188 | ymm_tmp = Ymm(preserved_vec_idxs[vecs_to_preserve - 1]); |
189 | xmm_tmp = Xmm(preserved_vec_idxs[vecs_to_preserve - 1]); |
190 | } |
191 | } |
192 | |
193 | template <cpu_isa_t isa, typename Wmm> |
194 | void jit_uni_eltwise_injector_f32<isa, Wmm>::vec_shift(const Vmm &vmm_dst, |
195 | const Vmm &vmm_src, bool shift_left, const int imm) { |
196 | if (isa != avx) { |
197 | if (shift_left) |
198 | h->uni_vpslld(vmm_dst, vmm_src, imm); |
199 | else |
200 | h->uni_vpsrld(vmm_dst, vmm_src, imm); |
201 | } else { |
202 | // Declare appropriate vectors to use non-uni instructions |
203 | Xmm xmm_dst = Xmm(vmm_dst.getIdx()); |
204 | Ymm ymm_dst = Ymm(vmm_dst.getIdx()); |
205 | Ymm ymm_src = Ymm(vmm_src.getIdx()); |
206 | if (vmm_dst.getIdx() != vmm_src.getIdx()) h->vmovups(ymm_dst, ymm_src); |
207 | h->vextractf128(xmm_tmp, ymm_dst, 1); |
208 | if (shift_left) { |
209 | h->vpslld(xmm_dst, xmm_dst, imm); |
210 | h->vpslld(xmm_tmp, xmm_tmp, imm); |
211 | } else { |
212 | h->vpsrld(xmm_dst, xmm_dst, imm); |
213 | h->vpsrld(xmm_tmp, xmm_tmp, imm); |
214 | } |
215 | h->vinsertf128(ymm_dst, ymm_dst, xmm_tmp, 1); |
216 | } |
217 | } |
218 | |
219 | // Uses injector masks objects: k_mask (>= avx512_core) or vmm_mask (<= avx2). |
220 | // Stores a mask by applying cmpps on two inputs w/ a given predicate. |
221 | template <cpu_isa_t isa, typename Wmm> |
222 | void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_cmp_mask( |
223 | const Vmm &vmm_src, const Xbyak::Operand &compare_operand, |
224 | int cmp_predicate) { |
225 | if (is_avx512) { |
226 | h->vcmpps(k_mask, vmm_src, compare_operand, cmp_predicate); |
227 | } else { |
228 | h->uni_vcmpps(vmm_mask, vmm_src, compare_operand, cmp_predicate); |
229 | } |
230 | } |
231 | |
232 | // Uses injector masks objects: k_mask (>= avx512_core) or vmm_mask (<= avx2). |
233 | // Blends a result of second input into a first input w/ a stored mask. |
234 | template <cpu_isa_t isa, typename Wmm> |
235 | void jit_uni_eltwise_injector_f32<isa, Wmm>::blend_with_mask( |
236 | const Vmm &vmm_dst, const Xbyak::Operand &src) { |
237 | if (is_avx512) { |
238 | h->vblendmps(vmm_dst | k_mask, vmm_dst, src); |
239 | } else { |
240 | h->uni_vblendvps(vmm_dst, vmm_dst, src, vmm_mask); |
241 | } |
242 | } |
243 | |
244 | // Uses injector masks objects: k_mask (>= avx512_core) or vmm_mask (<= avx2). |
245 | // Tests a mask for all zeros. If all zeroes occur, set ZF = 1. |
246 | // Nicely combines with jump_if_zero (jz). |
247 | template <cpu_isa_t isa, typename Wmm> |
248 | void jit_uni_eltwise_injector_f32<isa, Wmm>::test_mask() { |
249 | if (is_avx512) { |
250 | h->kortestw(k_mask, k_mask); |
251 | } else { |
252 | h->uni_vtestps(vmm_mask, vmm_mask); |
253 | } |
254 | } |
255 | |
// Emits exp(x) in place on vmm_src. Lanes with x < ln(FLT_MIN) produce 0.
// Writes vmm_aux1, vmm_aux2 and the injector mask. On avx it also uses
// xmm_tmp, both directly and through vec_shift().
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::exp_compute_vector_fwd(
        const Vmm &vmm_src) {
    // exp(x) =
    // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
    // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression

    // get mask of values lower than log(FLT_MIN) to zero them in the output
    compute_cmp_mask(vmm_src, table_val(exp_ln_flt_min_f), _cmp_lt_os);

    // clamp x to [ln(FLT_MIN), ln(FLT_MAX)] and keep a copy in vmm_aux1
    h->uni_vminps(vmm_src, vmm_src, table_val(exp_ln_flt_max_f));
    h->uni_vmaxps(vmm_src, vmm_src, table_val(exp_ln_flt_min_f));
    h->uni_vmovups(vmm_aux1, vmm_src);

    // calculate exp(x)
    // fx = x * log2ef + 0.5
    h->uni_vmulps(vmm_src, vmm_src, table_val(exp_log2ef));
    h->uni_vaddps(vmm_src, vmm_src, table_val(half));

    // tmp = floorf(fx), i.e. n = x / ln(2) rounded to nearest
    h->uni_vroundps(vmm_aux2, vmm_src, _op_floor);

    // keep vmm_src = fx for further computations
    h->uni_vmovups(vmm_src, vmm_aux2);

    // x = x - fx * ln2 (vmm_aux1 now holds the remainder r)
    h->uni_vfnmadd231ps(vmm_aux1, vmm_aux2, table_val(ln2f));

    // We do not count 2^n here, because n can reach 128 and 2^128 is not
    // representable by fp32, so to get around this problem, instead of computing
    // 2^n * exp(r) will be counted 2*2^(n-1)*exp(r), because 2^127
    // and 2 are numbers representable in fp32.

    // compute 2^(n-1): convert n-1 to int, add the exponent bias and shift it
    // into the exponent field to build the float directly
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vcvtps2dq(vmm_aux2, vmm_src);
    if (isa != avx)
        h->uni_vpaddd(vmm_aux2, vmm_aux2, table_val(exponent_bias));
    else {
        // avx has no 256-bit vpaddd: add the bias to each 128-bit half
        Ymm ymm_aux2 = Ymm(vmm_aux2.getIdx());
        Xmm xmm_aux2 = Xmm(vmm_aux2.getIdx());
        h->vextractf128(xmm_tmp, ymm_aux2, 1);
        h->vpaddd(xmm_tmp, xmm_tmp, table_val(exponent_bias));
        h->vpaddd(xmm_aux2, xmm_aux2, table_val(exponent_bias));
        h->vinsertf128(ymm_aux2, ymm_aux2, xmm_tmp, 1);
    }
    vec_shift(vmm_aux2, vmm_aux2, true /*shift_left*/, n_mantissa_bits);
    // use vmm_src as tmp vmm_zero when applying mask
    h->uni_vxorps(vmm_src, vmm_src, vmm_src);
    // set zeroes at those points which were < log(FLT_MIN)
    blend_with_mask(vmm_aux2, vmm_src);

    // compute polynomial approximation of exp(r) (Horner scheme)
    h->uni_vmovups(vmm_src, table_val(exp_pol, 4));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 3));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 2));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 1));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(exp_pol, 0));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    // y = y * 2^n, computed as (y * 2^(n-1)) * 2
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);
    h->uni_vmulps(vmm_src, vmm_src, table_val(two));
}
319 | |
// Leaky ReLU in place on vmm_src: x > 0 ? x : alpha * x.
// Keeps the original x in vmm_aux1 and uses the injector mask for x > 0.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_aux1, vmm_src);
    compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os);
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    // positive lanes take x back from vmm_aux1
    blend_with_mask(vmm_src, vmm_aux1);
}
328 | |
329 | template <cpu_isa_t isa, typename Wmm> |
330 | void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_zero_ns_compute_vector_fwd( |
331 | const Vmm &vmm_src) { |
332 | h->uni_vmaxps(vmm_src, vmm_src, table_val(zero)); |
333 | } |
334 | |
// ELU in place on vmm_src: x > 0 ? x : alpha * (exp(x) - 1).
// Uses vmm_aux3 plus everything exp_compute_vector_fwd() writes.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::elu_compute_vector_fwd(
        const Vmm &vmm_src) {
    // IMPORTANT: vmm_aux3 keeps a copy of x because exp_compute_vector_fwd
    // does not touch vmm_aux3.
    h->uni_vmovups(vmm_aux3, vmm_src);
    // compute exponent
    exp_compute_vector_fwd(vmm_src);

    // alpha * (exp(x) - 1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));

    // combine with mask: lanes where x > 0 take x back from vmm_aux3
    compute_cmp_mask(vmm_aux3, table_val(zero), _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_aux3);
}
351 | |
// Emits tanh(x) in place on vmm_src. It uses a piecewise degree-6
// polynomial approximation; each lane's coefficients are selected by a
// table index derived from the bits of |x| (see the interval split below).
// Uses vmm_aux1..vmm_aux4 and vmm_mask. sse41/avx also use the first four
// preserved GPRs, and avx uses vmm_tmp/xmm_tmp.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::tanh_compute_vector_fwd(
        const Vmm &vmm_src) {
    // we add a check as the avx2 code cannot be used for avx
    assert(IMPLICATION(isa == avx2, mayiuse(avx2)));

    using namespace Xbyak::util;
    const int XMM_float_lanes_count = 4;
    const int tanh_n_polynomials = 32;

    // register mapping
    // Names mapped to the same aux register are never live at the same time,
    // e.g. vmm_src_original is reused as vmm_sign once src is restored.
    // TODO: put sign on stack and alias zmm_table2 with vmm_sign to save a reg ?
    Vmm vmm_dst = vmm_aux1, vmm_src_shift = vmm_aux1, vmm_coeff = vmm_aux1,
        vmm_pol = vmm_aux2, vmm_indices = vmm_aux3, vmm_src_original = vmm_aux4,
        vmm_sign = vmm_aux4;
    Reg64 gpr_idx[XMM_float_lanes_count];

    // sse41/avx gather through GPRs: one GPR per float lane of an Xmm.
    if (isa == sse41 || isa == avx) {
        assert(aux_gprs_count() >= XMM_float_lanes_count);
        for (int i = 0; i < XMM_float_lanes_count; i++)
            gpr_idx[i] = Reg64(preserved_gpr_idxs[i]);
    }

    // We split the positive domain in 33 intervals:
    // a) [0; linear_ubound]: in this interval tanh(x) = x
    // b) [linear_ubound; 0x1.8p-12]: This interval spans part of a
    //    half binade
    // c) [0x1.8p-12; 0x1.0p-11], ..., [0x1.8p2; 0x1.0p3]:
    //    one interval for each half binade, there are 29 of those
    // d) [0x1.0p3; saturation_ubound]:
    //    This interval spans part of a half binade
    // e) [0x1.205966p3; +inf[: in this interval, tanh(x) = 1
    // For b-d, we need 31 polynomials and will do a table lookup for those.
    // To simplify the logic, we will also put a) in the table.

    // The polynomials are of degree 6, so we need to gather 7 coefficients.
    // - sse4.1: we do it the naive way using pextrd/pinsrd.
    //   Here we will extract the indices in gpr only once and
    //   reuse them as there are only 4 of them.
    // - avx: same as sse4.1, done twice. First for the low 128-bit half;
    //   then, after swapping the halves of src and indices with vperm2f128,
    //   for the high half, whose result is inserted into the top of vmm_pol.
    // - avx2: we gather each coefficient with vgatherdps, which uses
    //   vmm_mask as its lane-enable mask.
    // - avx512: because the table fits in 2 registers, we can use vpermt2ps.
    auto coeffs_off = [&](int coeff_off, int off = 0) {
        return table_off(tanh_pol_table, coeff_off * tanh_n_polynomials + off);
    };
    auto coeffs_address = [&](int coeff_off, int off = 0) {
        return table_val(tanh_pol_table, coeff_off * tanh_n_polynomials + off);
    };
    // Per-ISA setup before a series of gathers (nelems is currently unused).
    auto gather_coefficient_init = [&](Vmm vmm_pol_idx, int nelems) {
        switch (isa) {
            case sse41:
                for (int i = 0; i < XMM_float_lanes_count; ++i)
                    h->pextrd(gpr_idx[i].cvt32(), vmm_pol_idx, i);
                break;
            case avx: {
                Xmm xmm_pol_idx = Xmm(vmm_pol_idx.getIdx());
                for (int i = 0; i < XMM_float_lanes_count; ++i)
                    h->vpextrd(gpr_idx[i].cvt32(), xmm_pol_idx, i);
            } break;
            case avx2_vnni_2:
            case avx2:
                // needed for gather instruction
                h->uni_vxorps(vmm_mask, vmm_mask, vmm_mask);
                break;
            case avx512_core_fp16:
            case avx512_core_bf16:
            case avx512_core: break;
            default: assert(!"unimplemented" );
        }
    };
    // Loads coefficient `coeff_idx` of each lane's polynomial into vmm_coeff.
    auto gather_coefficient = [&](Vmm vmm_coeff, int coeff_idx,
                                      Vmm vmm_pol_idx) {
        switch (isa) {
            case sse41:
                for (int idx = 0; idx < 4; ++idx) {
                    Xbyak::Address coeff_addr
                            = ptr[p_table + coeffs_off(coeff_idx)
                                    + gpr_idx[idx] * sizeof(float)];
                    h->pinsrd(vmm_coeff, coeff_addr, idx);
                }
                break;
            case avx: {
                Xmm xmm_coeff = Xmm(vmm_coeff.getIdx());
                for (int idx = 0; idx < 4; ++idx) {
                    Xbyak::Address coeff_addr
                            = ptr[p_table + coeffs_off(coeff_idx)
                                    + gpr_idx[idx] * sizeof(float)];
                    h->vpinsrd(xmm_coeff, xmm_coeff, coeff_addr, idx);
                }
            } break;
            // use gather instruction
            case avx2_vnni_2:
            case avx2: {
                Xbyak::Address idx_addr = ptr[p_table + coeffs_off(coeff_idx)
                        + vmm_pol_idx * sizeof(float)];
                // we set the mask to all ones to gather full
                // register. needs to be done after each gather
                // since the gather instructions zeros the mask if
                // successful
                h->uni_vcmpps(vmm_mask, vmm_mask, vmm_mask, _cmp_eq_oq);
                h->vgatherdps(vmm_coeff, idx_addr, vmm_mask);
                break;
            }
            case avx512_core_fp16:
            case avx512_core_bf16:
            case avx512_core:
                // we use vpermt2ps to not override the indices
                // this also enables to save a register for table loading
                {
                    Zmm zmm_coeff(vmm_coeff.getIdx());
                    Zmm zmm_pol_idx(vmm_pol_idx.getIdx());
                    h->uni_vmovups(zmm_coeff, coeffs_address(coeff_idx, 0));
                    h->vpermt2ps(zmm_coeff, zmm_pol_idx,
                            coeffs_address(coeff_idx, 16));
                    break;
                }
            default: assert(!"unimplemented" );
        }
    };

    // because tanh(x) = -tanh(-x), we extract sign to make x positive
    // and reapply sign at the end
    h->uni_vmovups(vmm_src_original, vmm_src);
    h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));

    // We compute the indices for the table lookup: integer-subtract
    // tanh_idx_bias from the bits of |x|, mask with tanh_idx_mask, then
    // shift right by 22.
    h->uni_vmovups(vmm_indices, vmm_src);
    if (isa != avx)
        h->uni_vpsubd(vmm_indices, vmm_indices, table_val(tanh_idx_bias));
    else {
        // avx has no 256-bit vpsubd: process each 128-bit half separately
        Ymm ymm_indices = Ymm(vmm_indices.getIdx());
        Xmm xmm_indices = Xmm(vmm_indices.getIdx());
        h->vextractf128(xmm_tmp, ymm_indices, 1);
        h->vpsubd(xmm_tmp, xmm_tmp, table_val(tanh_idx_bias));
        h->vpsubd(xmm_indices, xmm_indices, table_val(tanh_idx_bias));
        h->vinsertf128(ymm_indices, ymm_indices, xmm_tmp, 1);
    }
    h->uni_vandps(vmm_indices, vmm_indices, table_val(tanh_idx_mask));
    vec_shift(vmm_indices, vmm_indices, false, 22);

    // we do the argument reduction: subtract |x| & tanh_idx_mask from |x|
    // (presumably the start of the lane's interval; TODO confirm)
    h->uni_vmovups(vmm_src_shift, vmm_src);
    h->uni_vandps(vmm_src_shift, vmm_src_shift, table_val(tanh_idx_mask));
    h->uni_vsubps(vmm_src, vmm_src, vmm_src_shift);

    // we gather and evaluate the polynomials (Horner scheme, degree 6 to 0)
    gather_coefficient_init(vmm_indices, vlen / sizeof(float));
    gather_coefficient(vmm_pol, 6, vmm_indices);
    for (int deg = 5; deg >= 0; --deg) {
        gather_coefficient(vmm_coeff, deg, vmm_indices);
        h->uni_vfmadd213ps(vmm_pol, vmm_src, vmm_coeff);
    }

    // avx: the gathers above only fill the low 128 bits. Swap the halves of
    // src and indices, evaluate the (now low) upper half in xmm_tmp, and
    // insert it as the upper half of vmm_pol.
    if (isa == avx) {
        Ymm ymm_indices = Ymm(vmm_indices.getIdx());
        Ymm ymm_pol = Ymm(vmm_pol.getIdx());
        Ymm ymm_src = Ymm(vmm_src.getIdx());
        Xmm xmm_src = Xmm(vmm_src.getIdx());
        Xmm xmm_coeff = Xmm(vmm_coeff.getIdx());

        h->vperm2f128(ymm_src, ymm_src, ymm_src, 1);
        h->vperm2f128(ymm_indices, ymm_indices, ymm_indices, 1);
        gather_coefficient_init(vmm_indices, vlen / sizeof(float));
        gather_coefficient(vmm_tmp, 6, vmm_indices);
        for (int deg = 5; deg >= 0; --deg) {
            gather_coefficient(vmm_coeff, deg, vmm_indices);
            h->vmulps(xmm_tmp, xmm_tmp, xmm_src);
            h->vaddps(xmm_tmp, xmm_tmp, xmm_coeff);
        }
        h->vinsertf128(ymm_pol, ymm_pol, xmm_tmp, 1);
    }

    // we restore src with cleared sign, and keep sign
    assert(vmm_sign.getIdx() == vmm_src_original.getIdx());
    h->uni_vmovups(vmm_src, vmm_src_original);
    h->uni_vandps(vmm_sign, vmm_sign, table_val(sign_mask));
    h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask));

    // Now we blend the results
    // [saturation_ubound; +inf[ : we return +/- 1
    h->uni_vmovups(vmm_dst, table_val(one));
    // [linear_ubound; saturation_lbound] : we return +/- P(x)
    h->uni_vmovups(vmm_mask, table_val(tanh_saturation_lbound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_gt_os);
    blend_with_mask(vmm_dst, vmm_pol);
    // [0; linear_ubound] : we return x
    h->uni_vmovups(vmm_mask, table_val(tanh_linear_ubound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_gt_os);
    blend_with_mask(vmm_dst, vmm_src);

    // We reapply the sign and return
    h->uni_vxorps(vmm_dst, vmm_dst, vmm_sign);
    h->uni_vmovups(vmm_src, vmm_dst);
}
550 | |
// gelu_tanh(x) = 0.5 * x * (1 + tanh(G(x))), in place on vmm_src.
// x is held in vmm_aux0. It is spilled to the stack (vlen bytes) around the
// tanh call because tanh_compute_vector_fwd() writes vmm_mask, and
// assign_regs() maps vmm_mask to the same register as vmm_aux0.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_fwd(
        const Vmm &vmm_src) {
    h->uni_vmovups(vmm_aux0, vmm_src);

    // compute G(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x * x)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vmovups(vmm_aux1, table_val(gelu_tanh_fitting_const));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
    h->uni_vmulps(vmm_src, vmm_src, table_val(gelu_tanh_sqrt_two_over_pi));

    // save x on stack as tanh uses vmm_aux0
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_aux0);

    // compute tanh(G(x))
    tanh_compute_vector_fwd(vmm_src);

    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    // compute 0.5 * x * (1 + tanh(G(x)))
    h->uni_vaddps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, table_val(half));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}
578 | |
579 | template <cpu_isa_t isa, typename Wmm> |
580 | void jit_uni_eltwise_injector_f32<isa, Wmm>::square_compute_vector_fwd( |
581 | const Vmm &vmm_src) { |
582 | h->uni_vmulps(vmm_src, vmm_src, vmm_src); |
583 | } |
584 | |
585 | template <cpu_isa_t isa, typename Wmm> |
586 | void jit_uni_eltwise_injector_f32<isa, Wmm>::abs_compute_vector_fwd( |
587 | const Vmm &vmm_src) { |
588 | // compute abs(x) = _mm_and_ps(x, 01111..111)); |
589 | h->uni_vandps(vmm_src, vmm_src, table_val(positive_mask)); |
590 | } |
591 | |
592 | template <cpu_isa_t isa, typename Wmm> |
593 | void jit_uni_eltwise_injector_f32<isa, Wmm>::sqrt_compute_vector_fwd( |
594 | const Vmm &vmm_src) { |
595 | h->uni_vsqrtps(vmm_src, vmm_src); |
596 | } |
597 | |
598 | template <cpu_isa_t isa, typename Wmm> |
599 | void jit_uni_eltwise_injector_f32<isa, Wmm>::linear_compute_vector_fwd( |
600 | const Vmm &vmm_src) { |
601 | // compute x = alpha * x + beta; |
602 | h->uni_vmovups(vmm_aux0, table_val(alpha)); |
603 | h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(beta)); |
604 | } |
605 | |
606 | template <cpu_isa_t isa, typename Wmm> |
607 | void jit_uni_eltwise_injector_f32<isa, Wmm>::clip_compute_vector_fwd( |
608 | const Vmm &vmm_src) { |
609 | h->uni_vmaxps(vmm_src, vmm_src, table_val(alpha)); |
610 | h->uni_vminps(vmm_src, vmm_src, table_val(beta)); |
611 | } |
612 | |
// mish(x) = x * tanh(soft_relu(x)), in place on vmm_src, evaluated through
// the exp-only identity described below. Uses vmm_aux1 and vmm_aux3 plus
// everything exp_compute_vector_fwd() writes.
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::mish_compute_vector_fwd(
        const Vmm &vmm_src) {
    // An equation other than mish(x) = x*tanh(srelu(x)) was used
    // to calculate mish, but it should be remembered that it is equivalent
    // equation, it uses the following rule:
    // tanh(x) = (e^x - e^-x) / (e^x + e^-x),
    // hence the equation for mish can take the form:
    // mish(x) = x * ((e^x + 1)^2 - 1)/((e^x + 1)^2 + 1).
    // This option was chosen because computing tanh requires more registers
    // than exp, and also requires more constants to be stored in memory,
    // making the algorithm slower.

    // IMPORTANT: we use vmm_aux3 to save src as exp does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src); // vmm_aux3 = x

    // clamp x from above before exp. NOTE(review): presumably this keeps
    // (e^x + 1)^2 finite; confirm against the constant's value.
    h->uni_vminps(vmm_src, vmm_src, table_val(fwd_mish_max_x_for_equation_f));
    exp_compute_vector_fwd(vmm_src);

    // (e^x+1)^2
    h->uni_vaddps(vmm_src, vmm_src, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);

    // save (e^x+1)^2 as it appears in both the denominator and the numerator
    h->uni_vmovups(vmm_aux1, vmm_src);

    // x * ((e^x + 1)^2 - 1) / ((e^x + 1)^2 + 1)
    h->uni_vsubps(vmm_src, vmm_src, table_val(one));
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux3);
}
645 | |
646 | template <cpu_isa_t isa, typename Wmm> |
647 | void jit_uni_eltwise_injector_f32<isa, Wmm>::hardswish_compute_vector_fwd( |
648 | const Vmm &vmm_src) { |
649 | // result = x * hardsigmoid(x) |
650 | h->uni_vmovups(vmm_aux0, vmm_src); |
651 | hardsigmoid_compute_vector_fwd(vmm_src); |
652 | h->uni_vmulps(vmm_src, vmm_src, vmm_aux0); |
653 | } |
654 | |
655 | template <cpu_isa_t isa, typename Wmm> |
656 | void jit_uni_eltwise_injector_f32<isa, Wmm>::hardsigmoid_compute_vector_fwd( |
657 | const Vmm &vmm_src) { |
658 | // result = max(0, min(1, alpha * x + beta)) |
659 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
660 | h->uni_vaddps(vmm_src, vmm_src, table_val(beta)); |
661 | h->uni_vminps(vmm_src, vmm_src, table_val(one)); |
662 | h->uni_vmaxps(vmm_src, vmm_src, table_val(zero)); |
663 | } |
664 | |
665 | template <cpu_isa_t isa, typename Wmm> |
666 | void jit_uni_eltwise_injector_f32<isa, Wmm>::soft_relu_compute_vector_fwd( |
667 | const Vmm &vmm_src) { |
668 | // alpha scaling |
669 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
670 | |
671 | // ln(1 + exp(x)) = |
672 | // = ln(1 + exp(n * ln(2) + r)) // divide x by ln(2) and get quot and rem |
673 | // = ln(1 + 2^n * exp(r)) // simplify the exp(n*ln(2)) expression |
674 | // = ln(2 ^ 0 + 2^n * exp(r)) // note 1 = 2^0 |
675 | // = ln(2 ^ (n - n) + 2^n * exp(r)) // 2^0 = 2^(n-n) |
676 | // = ln(2 ^ n * (2^-n + exp(r))) // factorize with 2^n |
677 | // = n * ln(2) + ln(2^-n + exp(r)) // take the 2^n factor out of the ln |
678 | |
679 | // keep src for further computations |
680 | h->uni_vmovups(vmm_aux2, vmm_src); |
681 | |
682 | h->uni_vminps(vmm_src, vmm_src, table_val(exp_ln_flt_max_f)); |
683 | h->uni_vmaxps(vmm_src, vmm_src, table_val(exp_ln_flt_min_f)); |
684 | h->uni_vmovups(vmm_aux1, vmm_src); |
685 | |
686 | // calculate exp(x) |
687 | // fx = x * log2ef + 0.5 |
688 | h->uni_vmulps(vmm_src, vmm_src, table_val(exp_log2ef)); |
689 | h->uni_vaddps(vmm_src, vmm_src, table_val(half)); |
690 | |
691 | // tmp = floorf(fx) |
692 | h->uni_vroundps(vmm_aux0, vmm_src, _op_floor); |
693 | |
694 | // keep vmm_src = fx for further computations |
695 | h->uni_vmovups(vmm_src, vmm_aux0); |
696 | |
697 | // x = x - fx * ln2 |
698 | h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(ln2f)); |
699 | h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0); |
700 | // compute exponent polynomial |
701 | h->uni_vmovups(vmm_aux3, table_val(exp_pol, 4)); |
702 | h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 3)); |
703 | h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 2)); |
704 | h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 1)); |
705 | h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(exp_pol, 0)); |
706 | h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(one)); |
707 | |
708 | // We do not count 2^-n here, because n can reach 128 and 2^(-128) is not |
709 | // representable by fp32, so to get around this problem, instead of computing |
710 | // 2^-n + exp(r) will be counted (2^-(n-1) + 2*exp(r))/2, because 2^(-127) |
711 | // and 2 are numbers representable in fp32. |
712 | |
713 | // compute 2^-(n-1) |
714 | // vmm_src now represents n-1 |
715 | h->uni_vsubps(vmm_src, vmm_src, table_val(one)); |
716 | if (is_avx512) { |
717 | h->vmulps(vmm_aux1, vmm_src, table_val(minus_one)); |
718 | h->vcvtps2dq(vmm_aux1, vmm_aux1); |
719 | } else if (isa == avx) { |
720 | h->uni_vxorps(vmm_aux1, vmm_src, table_val(sign_mask)); |
721 | h->uni_vcvtps2dq(vmm_aux1, vmm_aux1); |
722 | } else { |
723 | h->uni_vcvtps2dq(vmm_aux1, vmm_src); |
724 | h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(minus_one)); |
725 | } |
726 | // restore vmm_src to n |
727 | h->uni_vaddps(vmm_src, vmm_src, table_val(one)); |
728 | |
729 | if (isa != avx) |
730 | h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(exponent_bias)); |
731 | else { |
732 | Ymm ymm_aux1 = Ymm(vmm_aux1.getIdx()); |
733 | Xmm xmm_aux1 = Xmm(vmm_aux1.getIdx()); |
734 | h->vextractf128(xmm_tmp, ymm_aux1, 1); |
735 | h->vpaddd(xmm_tmp, xmm_tmp, table_val(exponent_bias)); |
736 | h->vpaddd(xmm_aux1, xmm_aux1, table_val(exponent_bias)); |
737 | h->vinsertf128(ymm_aux1, ymm_aux1, xmm_tmp, 1); |
738 | } |
739 | vec_shift(vmm_aux1, vmm_aux1, true /*shift_left*/, n_mantissa_bits); |
740 | // calculate ln(1 + y) |
741 | h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(two)); // 2*exp(r) |
742 | h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1); // 2^-(n-1) + 2*exp(r) |
743 | h->uni_vdivps( |
744 | vmm_aux3, vmm_aux3, table_val(two)); // (2^-(n-1) + 2*exp(r))/2 |
745 | // frexp() |
746 | vec_shift(vmm_src, vmm_aux3, false /*shift_left*/, n_mantissa_bits); |
747 | h->uni_vcvtdq2ps(vmm_src, vmm_src); |
748 | // got n. where n is x = 2^n * y. y = 0.5 .. 1 |
749 | h->uni_vsubps(vmm_src, vmm_src, table_val(soft_relu_one_twenty_six)); |
750 | |
751 | // and with mask (to get 0.5 * mantissa) |
752 | h->uni_vandps(vmm_aux3, vmm_aux3, table_val(soft_relu_mantissa_sign_mask)); |
753 | // got y. (mantisa) 0.5 < y < 1 (or with (to get 0.5 * mantissa)) |
754 | h->uni_vorps(vmm_aux3, vmm_aux3, table_val(half)); |
755 | // y = y - 1 |
756 | h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(one)); |
757 | |
758 | // compute log1p polynomial |
759 | h->uni_vmovups(vmm_aux1, table_val(soft_relu_pol, 8)); |
760 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 7)); |
761 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 6)); |
762 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 5)); |
763 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 4)); |
764 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 3)); |
765 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 2)); |
766 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 1)); |
767 | h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(soft_relu_pol, 0)); |
768 | //calculate ln(2) * n |
769 | h->uni_vmulps(vmm_src, vmm_src, table_val(ln2f)); |
770 | h->uni_vaddps(vmm_src, vmm_src, vmm_aux1); |
771 | h->uni_vaddps(vmm_src, vmm_src, vmm_aux0); |
772 | |
773 | // get vmm_mask = src > max logf |
774 | // y = (x < max log f) ? soft_relu(x) : x |
775 | compute_cmp_mask(vmm_aux2, table_val(exp_ln_flt_max_f), _cmp_gt_os); |
776 | blend_with_mask(vmm_src, vmm_aux2); |
777 | if (alpha_ == 1.f) { // standard soft_relu case |
778 | // Skip an instruction. |
779 | } else if (alpha_ == -1) { // logsigmoid case |
780 | h->uni_vmulps(vmm_src, vmm_src, table_val(minus_one)); |
781 | } else { // General case. |
782 | h->uni_vdivps(vmm_src, vmm_src, table_val(alpha)); |
783 | } |
784 | } |
785 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::logistic_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Emits code that replaces every lane of vmm_src with
    // logistic(x) = exp(x) / (exp(x) + 1).
    // Clobbers vmm_aux1, vmm_aux2, vmm_aux3 and the blend mask (k_mask on
    // avx512, vmm_mask otherwise), in addition to whatever
    // exp_compute_vector_fwd itself uses.
    //
    // To avoid the exp(x) overflow that happens at x > logf(FLT_MAX), every
    // input is forced negative, exp(x) is computed for x <= 0 (so
    // 0 <= exp(x) <= 1), and the original sign is restored at the end. This
    // works because logistic is symmetric: logistic(x) = 1 - logistic(-x).

    // IMPORTANT: we use vmm_aux3 for the mask as exp_compute does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src);
    // we store the original sign and make x negative
    // (vmm_aux3 keeps only the sign bits; vmm_src gets its sign bit forced on)
    h->uni_vandps(vmm_aux3, vmm_aux3, table_val(sign_mask));
    h->uni_vorps(vmm_src, vmm_src, table_val(sign_mask));

    exp_compute_vector_fwd(vmm_src);
    // dup exp(x)
    h->uni_vmovups(vmm_aux1, vmm_src);
    // (exp(x) + 1)
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    // y = exp(x) / (exp(x) + 1)
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);

    // Now we have to apply the "symmetry" based on original sign:
    // vmm_aux2 = 1 - y is the answer for originally non-negative lanes.
    h->uni_vmovups(vmm_aux2, table_val(one));
    h->uni_vsubps(vmm_aux2, vmm_aux2, vmm_src);
    // Build the blend mask from the saved sign bits: on avx512 vptestmd sets
    // a k bit for every lane whose saved sign bit is non-zero; on other ISAs
    // the sign-bit vector itself is used as the mask.
    if (is_avx512) {
        h->vptestmd(k_mask, vmm_aux3, vmm_aux3);
    } else {
        h->uni_vmovups(vmm_mask, vmm_aux3);
    }
    // Originally-negative lanes take y; all other lanes keep 1 - y.
    blend_with_mask(vmm_aux2, vmm_src);
    h->uni_vmovups(vmm_src, vmm_aux2);
}
818 | |
819 | template <cpu_isa_t isa, typename Wmm> |
820 | void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_fwd( |
821 | const Vmm &vmm_src) { |
822 | // Save src data on stack for later usage |
823 | h->sub(h->rsp, vlen); |
824 | h->uni_vmovups(h->ptr[h->rsp], vmm_src); |
825 | // x*alpha |
826 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
827 | // sigmoid(x*alpha) |
828 | logistic_compute_vector_fwd(vmm_src); |
829 | // x*sigmoid(alpha*x) |
830 | h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]); |
831 | h->add(h->rsp, vlen); |
832 | h->uni_vmulps(vmm_src, vmm_src, vmm_aux0); |
833 | } |
834 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Emits code that replaces every lane of vmm_src with its natural log.
    // Uses vmm_aux0 (as the avx scratch), vmm_aux1..vmm_aux3, the blend mask
    // and a vlen-sized stack slot that is released before returning.
    //
    // From J.-M. Muller and others, Handbook of Floating-Point Arithmetic, 2010
    // Here is a brief mathematics to approximate log(x):
    // log(x) = E * log(2) + log(y), where -log(2)/2 <= log(y) <= log(2)/2;
    // log(y) = log(1 + z) - log(r_i), where z = y * r_i - 1, r_i approximates
    //   1 / y, i is index of one of precomputed values;
    // log(1 + z) ~~ polynomial(z), =>
    // if (x is normal)
    //     log(x) ~~ E * log(2) + polynomial(z) - log(r_i),
    // where log(r_i) is table value.
    //
    // If (x == 0) result = -inf;
    // If (x < 0) result = qnan; (qnan value taken from table_val)
    // If (x == inf) result = inf;
    // If (x == qnan) result = qnan; (qnan value taken from src)
    // If (x == 1) result = 0;

    // set unused register as tmp for avx
    // (re-points the ymm_tmp/xmm_tmp scratch aliases, declared outside this
    // function, at vmm_aux0, which this routine does not otherwise use)
    if (isa == avx) {
        ymm_tmp = Ymm(vmm_aux0.getIdx());
        xmm_tmp = Xmm(vmm_aux0.getIdx());
    }

    // save source on stack to check neg and zero values at the end
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);

    // compute i
    // (i = the approx_order leading mantissa bits: shift them down, mask with
    // log_five_bit_offset, presumably 0x1f; then double i because the table
    // interleaves r_i and log(r_i), see the two gathers below)
    const int approx_order = 5;
    vec_shift(vmm_aux1, vmm_src, false, n_mantissa_bits - approx_order);
    h->uni_vandps(vmm_aux1, vmm_aux1, table_val(log_five_bit_offset));
    vec_shift(vmm_aux1, vmm_aux1, true, 1); // multiply i by 2

    // compute anticancellation i
    // (vmm_aux2 = (2 * i) >> approx_order, i.e. 1 for indices in the upper
    // half of the table and 0 otherwise; such lanes get E + 1 below and a
    // mantissa placed in [0.5, 1) instead of [1, 2))
    vec_shift(vmm_aux2, vmm_aux1, false, approx_order);

    // get E, don't care about sign as only positive numbers are considered
    // NOTE(review): this is the raw (biased) exponent field; presumably the
    // bias is folded into the log(r_i) table values -- confirm against
    // log_predefined_vals.
    vec_shift(vmm_aux3, vmm_src, false, n_mantissa_bits);
    if (isa != avx)
        h->uni_vpaddd(vmm_aux3, vmm_aux3, vmm_aux2);
    else {
        // AVX1 has no 256-bit integer add: add each 128-bit half separately.
        // The upper half of vmm_aux3 is saved first because the VEX-encoded
        // 128-bit vpaddd zeroes the upper half of its destination; vmm_aux2's
        // halves are swapped to reach its upper half and then swapped back,
        // since vmm_aux2 is still needed for the mantissa.
        Ymm ymm_aux2 = Ymm(vmm_aux2.getIdx());
        Ymm ymm_aux3 = Ymm(vmm_aux3.getIdx());
        Xmm xmm_aux2 = Xmm(vmm_aux2.getIdx());
        Xmm xmm_aux3 = Xmm(vmm_aux3.getIdx());
        h->vextractf128(xmm_tmp, ymm_aux3, 1);
        h->vpaddd(xmm_aux3, xmm_aux3, xmm_aux2);
        h->vperm2f128(ymm_aux2, ymm_aux2, ymm_aux2, 1);
        h->vpaddd(xmm_tmp, xmm_tmp, xmm_aux2);
        h->vperm2f128(ymm_aux2, ymm_aux2, ymm_aux2, 1);
        h->vinsertf128(ymm_aux3, ymm_aux3, xmm_tmp, 1);
    }
    h->uni_vcvtdq2ps(vmm_aux3, vmm_aux3);

    // get m (mantissa)
    // (new exponent field = exponent_bias ^ anticancellation bit, combined
    // with the original mantissa bits selected by log_mantissa_mask)
    h->uni_vxorps(vmm_aux2, vmm_aux2, table_val(exponent_bias));
    vec_shift(vmm_aux2, vmm_aux2, true, n_mantissa_bits);
    h->uni_vandps(vmm_src, vmm_src, table_val(log_mantissa_mask));
    h->uni_vorps(vmm_src, vmm_src, vmm_aux2);

    // At first, adjust indices for table structure which broadcasts elements
    // by multiplying by simd_w
    // (each table entry occupies a full vector, so the element index is
    // scaled by the lane count; simd_w holds log2 of that count)
    const int simd_w = math::ilog2q(
            vlen / sizeof(float)); // equal to 2/3/4 for xmm/ymm/zmm
    vec_shift(vmm_aux1, vmm_aux1, true, simd_w);

    const auto it = entry_map_.find(log_predefined_vals);
    assert(it != entry_map_.end());
    const auto table_start_idx = (*it).second.off;

    // Loads vmm_dst[k] = table[vmm_idxs[k]] (float elements) from the
    // log_predefined_vals region, starting `offt` bytes into it.
    auto gather_table_values = [&](const Vmm &vmm_dst, const Vmm &vmm_idxs,
                                       size_t offt = 0) {
        Xbyak::Address table_idx = h->ptr[p_table + table_start_idx + offt
                + vmm_idxs * sizeof(float)];
        if (is_avx512) {
            // presumably log_full_k_reg_mask enables every lane
            h->kmovw(k_mask, table_val(log_full_k_reg_mask));
            h->vgatherdps(vmm_dst | k_mask, table_idx);
        } else if (utils::one_of(isa, avx2, avx2_vnni_2)) {
            // sign bit set in every lane -> gather all lanes (vgatherdps
            // consumes and clears this mask register)
            h->uni_vmovups(vmm_mask, table_val(sign_mask));
            h->vgatherdps(vmm_dst, table_idx, vmm_mask);
        } else if (isa == avx || isa == sse41) {
            // No gather instruction: emulate it with scalar loads.
            Xbyak::Reg64 reg_tmp
                    = p_table.getIdx() != h->r9.getIdx() ? h->r9 : h->r10;

            const int gpr_size = 8;
            // save reg_tmp state as we are not allowed to spoil it.
            h->sub(h->rsp, gpr_size);
            h->mov(h->ptr[h->rsp], reg_tmp);

            // rest of code puts indices on stack, fetching a table number based
            // on an index, replaces index with the value, and, finally, moves
            // fetched values into vector register.
            h->sub(h->rsp, vlen);
            h->uni_vmovups(h->ptr[h->rsp], vmm_idxs);

            for (size_t i = 0; i < vlen / sizeof(float); ++i) {
                h->mov(reg_tmp.cvt32(), h->ptr[h->rsp + i * sizeof(float)]);
                // scale the element index to a byte offset (x sizeof(float)),
                // matching the `* sizeof(float)` scale of the gather address
                h->shl(reg_tmp.cvt32(), 2);
                table_idx = h->ptr[p_table + table_start_idx + offt + reg_tmp];
                h->mov(reg_tmp.cvt32(), table_idx);
                h->mov(h->ptr[h->rsp + i * sizeof(float)], reg_tmp.cvt32());
            }

            h->uni_vmovups(vmm_dst, h->ptr[h->rsp]);
            h->add(h->rsp, vlen);
            // restore GPR state
            h->mov(reg_tmp, h->ptr[h->rsp]);
            h->add(h->rsp, gpr_size);
        }
    };

    // get r_i, same as table(i)
    gather_table_values(vmm_aux2, vmm_aux1, 0);

    // compute relative error (rel_err = m * r_i - 1)
    h->uni_vfmsub213ps(vmm_aux2, vmm_src, table_val(one));

    // compute polynomial(rel_err)
    // (Horner: rel_err * (1 + rel_err * (p0 + rel_err * (p1 + rel_err *
    // (p2 + rel_err * p3)))))
    h->uni_vmovups(vmm_src, table_val(log_pol, 3));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 2));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 1));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(log_pol, 0));
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, table_val(one));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);

    // get log(r_i) = table(i+1)
    // (the next entry is one vector further, hence the vlen byte offset)
    gather_table_values(vmm_aux2, vmm_aux1, vlen);

    // compute partial result (pres = E * ln(2) - log(r_i))
    h->uni_vfmadd231ps(vmm_aux2, vmm_aux3, table_val(ln2f));

    // compute (result = polynomial + pres) w/ TwoSum algorithm
    // TODO: restore this instead of version below when asserts are gone
    // h->uni_vaddps(vmm_aux1, vmm_src, vmm_aux2); // res_hi = pol + pres
    // h->uni_vsubps(vmm_aux3, vmm_aux1, vmm_aux2); // res_lo = res_hi - pres
    // h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); // res_lo = res_lo - pol
    // h->uni_vaddps(vmm_src, vmm_aux1, vmm_aux3); // res_hi = pol + pres
    //
    // NOTE(review): res_lo below is (res_hi - pres) - pol, which is the
    // negated Fast2Sum error term, yet it is added back rather than
    // subtracted. Since |res_lo| is at most half an ulp of res_hi, the final
    // add almost always rounds back to res_hi either way -- confirm intent.

    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux2); // res_hi = pol + pres
    h->uni_vmovups(vmm_aux3, vmm_aux1);
    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_aux2); // res_lo = res_hi - pres
    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); // res_lo = res_lo - pol
    h->uni_vmovups(vmm_src, vmm_aux1);
    h->uni_vaddps(vmm_src, vmm_src, vmm_aux3); // res = res_hi + res_lo

    // Check original source for zero and neg values. skip blend w/ extreme
    // values if all src values were positive.
    // (the saved source is reloaded into vmm_aux1 and the stack slot freed)
    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    Xbyak::Label end_log_zero_label;
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os);
    test_mask();
    h->jz(end_log_zero_label);

    // Blend extreme values into src if reach here.
    // First zero for -inf values...
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_eq_oq);
    blend_with_mask(vmm_src, table_val(log_minus_inf));

    // ...then negative for qnan values.
    compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_lt_os);
    blend_with_mask(vmm_src, table_val(log_qnan));

    h->L(end_log_zero_label);

    // Leave inf values same as in src.
    compute_cmp_mask(vmm_aux1, table_val(log_inf), _cmp_eq_oq);
    Xbyak::Label end_log_inf_label;
    test_mask();
    h->jz(end_log_inf_label);
    blend_with_mask(vmm_src, table_val(log_inf));
    h->L(end_log_inf_label);

    // Detect qnans if src != src and blend with qnans.
    compute_cmp_mask(vmm_aux1, vmm_aux1, _cmp_neq_uq);
    Xbyak::Label end_log_nan_label;
    test_mask();
    h->jz(end_log_nan_label);
    blend_with_mask(vmm_src, vmm_aux1);
    h->L(end_log_nan_label);

    // Detect ones and blend with zeros.
    compute_cmp_mask(vmm_aux1, table_val(one), _cmp_eq_oq);
    Xbyak::Label end_log_one_label;
    test_mask();
    h->jz(end_log_one_label);
    blend_with_mask(vmm_src, table_val(zero));
    h->L(end_log_one_label);
}
1028 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Emits code computing alpha * x^beta in place in vmm_src.
    // beta in {-1, 0, 0.5, 1, 2} is handled inline. Any other beta falls back
    // to calling the C library powf() once per lane, which requires spilling
    // GPRs, k-registers (avx512) and all vector registers and aligning the
    // stack for the call.

    // dispatch between special cases.
    if (beta_ == -1) { // alpha / x
        h->uni_vmovups(vmm_aux0, table_val(alpha));
        h->uni_vdivps(vmm_src, vmm_aux0, vmm_src, vmm_aux0);
    } else if (beta_ == 0) { // alpha
        h->uni_vmovups(vmm_src, table_val(alpha));
    } else if (beta_ == 0.5) { // alpha * sqrt(x)
        sqrt_compute_vector_fwd(vmm_src);
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else if (beta_ == 1) { // alpha * x
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else if (beta_ == 2) { // alpha * x^2
        square_compute_vector_fwd(vmm_src);
        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    } else { // general path
        // caller obligation to save gprs as callee may use them
        // (rbp and rbx are in this list too, which is what allows them to be
        // used as scratch below)
        size_t gpr_size = 8;
        Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax,
                h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
        size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

        h->sub(h->rsp, n_gprs_to_save * gpr_size);
        for (size_t i = 0; i < n_gprs_to_save; ++i)
            h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

        // caller obligation to save k-regs as callee may use them
        size_t n_k_regs_to_save = 8;
        if (is_avx512) {
            h->sub(h->rsp, n_k_regs_to_save * k_mask_size);
            for (size_t i = 0; i < n_k_regs_to_save; ++i) {
                if (mayiuse(avx512_core))
                    h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(i));
                else
                    h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(i));
            }
        }

        // 1. Caller obligation to save vector registers as callee may use them.
        // 2. Additionally save space for vmm_src, to put the answer in-place on
        //    this space and space for beta.
        // 3. There is an implicit assumption that the host code uses the same
        //    `isa` as the injector. Once the assumption is wrong, `vecs_count`
        //    and `vlen` should be replaced with `host_isa::vecs_count` and
        //    `host_isa::vlen`.
        // Resulting layout from rsp: [0, vlen) src lanes (overwritten with the
        // results), [vlen, 2 * vlen) broadcast beta, then Vmm(0) ..
        // Vmm(vecs_count - 1).
        h->sub(h->rsp, (vecs_count + 2) * vlen);
        for (size_t i = 2; i < vecs_count + 2; ++i)
            h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(i - 2));
        h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_src); // src
        h->uni_vmovups(vmm_src, table_val(beta));
        h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_src); // beta

        // save function address in gpr to pass it to the call instruction
        h->mov(h->rbp, reinterpret_cast<uintptr_t>(powf));

        // The 64-bit Windows ABI requires the caller to allocate 32 bytes of
        // a so called "shadow space" for the callee. It also requires that
        // the stack be 16 byte aligned before the call instruction is issued.
        // In order to allocate the shadow space and ensure the 16-byte alignment
        // of the stack we may actually need to allocate 40 bytes (32 bytes for
        // the "shadow space" + 8 bytes to align the stack) if the stack
        // pointer is not currently 16 byte aligned.

        // align stack on 16-byte as ABI requires
        // rbx = (rsp & 0xf) + 0x20; subtracting it leaves rsp 16-byte aligned
        // with 32 bytes reserved, and rsp + rbx addresses the src/beta area.
        h->mov(h->rbx, h->rsp);
        // Get alignment offset.
        h->and_(h->rbx, 0xf);
        h->add(h->rbx, 0x20);
        h->sub(h->rsp, h->rbx);

        // Take src, apply powf on it and replace value on a stack with dst.
        // xmm0/xmm1 carry the two float arguments and xmm0 the float result.
        // rbx and rbp are callee-saved in both the SysV and Windows x64 ABIs,
        // so they survive each powf call.
        Xmm xmm0 = Xmm(0), xmm1 = Xmm(1);
        for (size_t i = 0; i < vlen / sizeof(float); ++i) {
            const Address &source = h->ptr[h->rsp + h->rbx + i * sizeof(float)];
            h->uni_vmovss(xmm0, source);
            h->uni_vmovss(xmm1, h->ptr[h->rsp + h->rbx + vlen]); // beta
            h->uni_vzeroupper(); // eliminate performance penalties on avx
            h->call(h->rbp);
            // eliminate performance penalties on sse isa
            if (isa == sse41) h->uni_vzeroupper();
            h->uni_vmovss(source, xmm0);
        }

        // undo the alignment + shadow-space adjustment
        h->add(h->rsp, h->rbx);

        // restore vector registers
        for (size_t i = vecs_count + 1; i >= 2; --i)
            h->uni_vmovups(Vmm(i - 2), h->ptr[h->rsp + i * vlen]);
        h->uni_vmovups(vmm_src, h->ptr[h->rsp + 0 * vlen]);
        h->add(h->rsp, (vecs_count + 2) * vlen);

        // restore k registers
        if (is_avx512) {
            for (int i = n_k_regs_to_save - 1; i >= 0; --i) {
                if (mayiuse(avx512_core))
                    h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
                else
                    h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
            }
            h->add(h->rsp, n_k_regs_to_save * k_mask_size);
        }

        // restore gpr registers
        for (int i = n_gprs_to_save - 1; i >= 0; --i)
            h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
        h->add(h->rsp, n_gprs_to_save * gpr_size);

        h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    }
}
1141 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa,
        Wmm>::gelu_erf_minimax_approx_compute_vector_fwd(const Vmm &vmm_src) {
    // Emits GELU (erf flavor) in place in vmm_src via a minimax polynomial:
    // 0.5 * x * (1 + x * P(x^2)) in the mid range, with piecewise overrides
    // near zero and in the tails (see the blend table below).
    // Emits nothing (vmm_src unchanged) for ISAs below avx512_core.
    // Clobbers vmm_aux1..vmm_aux4 and vmm_mask (used here as scratch for the
    // comparison bounds) plus whatever mask compute_cmp_mask writes.
    using namespace Xbyak::util;

    // TODO: consider enabling for lower ISA
    if (!is_superset(isa, avx512_core)) return;

    // register mapping
    Vmm vmm_pol = vmm_aux1, vmm_src_square = vmm_aux2, vmm_src_half = vmm_aux3,
        vmm_src_positive = vmm_aux4;

    h->uni_vmulps(vmm_src_square, vmm_src, vmm_src);
    // |x|: clear the sign bit with positive_mask
    h->uni_vmovups(vmm_src_positive, vmm_src);
    h->uni_vandps(vmm_src_positive, vmm_src_positive, table_val(positive_mask));

    h->uni_vmulps(vmm_src_half, vmm_src, table_val(half));
    // compute P(x^2)
    h->uni_vmovups(vmm_pol, table_val(gelu_erf_minimax_pol, 14));
    // TODO: consider reducing latency by splitting into partial sums, for
    // example by using x^4 polynomial
    for (int deg = 13; deg >= 0; --deg) {
        h->uni_vfmadd213ps(
                vmm_pol, vmm_src_square, table_val(gelu_erf_minimax_pol, deg));
    }

    // 1.0f + erf(x * inv_sqrt2) = 1.0f + x * P(x^2)
    h->uni_vfmadd213ps(vmm_pol, vmm_src, table_val(one));
    // move instead first blend_with_mask?
    h->uni_vmulps(vmm_pol, vmm_pol, vmm_src_half);
    // Now we blend the results; later blends override earlier ones:
    //   x <= neg_saturation_ubound       : 0.0f
    //   |x| < saturation_lbound          : 0.5 * x * (1 + x * P(x^2))
    //   |x| < linear_ubound              : 0.5 * x
    //   lanes matched by none of these   : x (left unchanged), which covers
    //                                      x >= saturation_lbound
    // In each compare below the bound sits in vmm_mask (first operand) and is
    // compared against x or |x|.
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_neg_saturation_ubound));
    compute_cmp_mask(vmm_mask, vmm_src, _cmp_ge_os);
    blend_with_mask(vmm_src, table_val(zero));
    // saturation_lbound > |x| : we return the polynomial result
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_saturation_lbound));
    compute_cmp_mask(vmm_mask, vmm_src_positive, _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_pol);
    // linear_ubound > |x| : we return 0.5f * x
    h->uni_vmovups(vmm_mask, table_val(gelu_erf_minimax_linear_ubound));
    compute_cmp_mask(vmm_mask, vmm_src_positive, _cmp_gt_os);
    blend_with_mask(vmm_src, vmm_src_half);
}
1188 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_fwd(
        const Vmm &vmm_src) {
    // Emits GELU(s) = 0.5 * s * (1 + erf(s / sqrt(2))) in place in vmm_src.
    // avx512_core and above use the minimax variant; other ISAs use the
    // Abramowitz-Stegun erf approximation below, which clobbers
    // vmm_aux0..vmm_aux4 (and whatever exp_compute_vector_fwd uses).
    if (is_superset(isa, avx512_core)) {
        gelu_erf_minimax_approx_compute_vector_fwd(vmm_src);
        return;
    }

    // Here we approximate erf(x) using the expression by
    // Abramowitz and Stegun from ``Handbook of Mathematical
    // Functions''
    // NOTE: The performance of this kernel can be further improved
    // with a minimax polynomial expansion, thereby avoiding division
    // and exp. However, so far, this has cost larger accuracy
    // differences with respect to glibc erf based GELU, in particular
    // ~1.0e-5 -- 1.0e-3 absolute error at s = -5.

    // use vmm_aux3 to store original src.
    h->uni_vmovups(vmm_aux3, vmm_src);

    // x = s / sqrt(2)
    h->uni_vmulps(vmm_src, vmm_src,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_two));

    // abs(x)
    h->uni_vmovups(vmm_aux4, vmm_src);
    abs_compute_vector_fwd(vmm_aux4);

    // t = 1 / (p*x + 1)
    h->uni_vmovups(
            vmm_aux2, table_val(gelu_erf_Abramowitz_Stegun_approx_const));
    h->uni_vfmadd213ps(vmm_aux2, vmm_aux4, table_val(one));
    h->uni_vmovups(vmm_aux4, table_val(one));
    h->uni_vdivps(vmm_aux4, vmm_aux4, vmm_aux2);

    // -exp(-x*x)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));
    exp_compute_vector_fwd(vmm_src); // pollutes aux1, aux2
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));

    // get sign
    h->uni_vmovups(vmm_aux0, vmm_aux3);
    h->uni_vandps(vmm_aux0, vmm_aux0, table_val(sign_mask));

    // -exp(-x*x)*t
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux4);

    // compute polynomial r
    // (Horner in t: p0 + t * (p1 + t * (p2 + t * (p3 + t * p4))))
    h->uni_vmovups(vmm_aux1, table_val(gelu_erf_Abramowitz_Stegun_pol, 4));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 3));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 2));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 1));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 0));

    // erf = sign * (1 - r * t * exp(-x*x))
    // (the xor with the saved sign bit applies the sign)
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vxorps(vmm_src, vmm_src, vmm_aux0);

    // S = 0.5 * s
    h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(half));
    // GELU = 0.5 * s * (1 + erf) = S + S * erf
    h->uni_vfmadd213ps(vmm_src, vmm_aux3, vmm_aux3);
}
1257 | |
1258 | template <cpu_isa_t isa, typename Wmm> |
1259 | void jit_uni_eltwise_injector_f32<isa, Wmm>::relu_compute_vector_bwd( |
1260 | const Vmm &vmm_src) { |
1261 | // invariant to whether `s` or `d` is passed. |
1262 | // get mask of `s` > 0 |
1263 | compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os); |
1264 | // fill with alpha, then blend with 1.f |
1265 | h->uni_vmovups(vmm_src, table_val(alpha)); |
1266 | blend_with_mask(vmm_src, table_val(one)); |
1267 | } |
1268 | |
1269 | template <cpu_isa_t isa, typename Wmm> |
1270 | void jit_uni_eltwise_injector_f32<isa, Wmm>::elu_compute_vector_bwd( |
1271 | const Vmm &vmm_src) { |
1272 | if (!use_dst_) { |
1273 | // R = exp(s) |
1274 | exp_compute_vector_fwd(vmm_src); |
1275 | // after exponentiation, get mask by comparing with exp(0)=1.f, not 0.f |
1276 | compute_cmp_mask(vmm_src, table_val(one), _cmp_gt_os); |
1277 | // R * alpha, then blend with 1.f |
1278 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
1279 | } else { |
1280 | // get mask of `d` > 0 |
1281 | compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os); |
1282 | // R = `d` + alpha, then blend with 1.f |
1283 | h->uni_vaddps(vmm_src, vmm_src, table_val(alpha)); |
1284 | } |
1285 | blend_with_mask(vmm_src, table_val(one)); |
1286 | } |
1287 | |
1288 | template <cpu_isa_t isa, typename Wmm> |
1289 | void jit_uni_eltwise_injector_f32<isa, Wmm>::tanh_compute_vector_bwd( |
1290 | const Vmm &vmm_src) { |
1291 | // res = 1 - d^2 = 1 - tanh^2(s) |
1292 | if (!use_dst_) tanh_compute_vector_fwd(vmm_src); |
1293 | h->uni_vmovups(vmm_aux0, table_val(one)); |
1294 | h->uni_vfnmadd231ps(vmm_aux0, vmm_src, vmm_src); |
1295 | h->uni_vmovups(vmm_src, vmm_aux0); |
1296 | } |
1297 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_bwd(
        const Vmm &vmm_src) {
    // Replaces vmm_src (holding x) with the gelu_tanh derivative
    // 0.5 * (1 + T) * (1 + G2 * (1 - T)), where T = tanh(G1(x)).
    // Clobbers vmm_aux0..vmm_aux3 (plus whatever tanh_compute_vector_fwd
    // uses) and temporarily uses a vlen-sized stack slot.
    h->uni_vmovups(vmm_aux0, vmm_src);

    // compute G1(x) = sqrt_root_two_over_pi * x * (1 + fitting_const * x^2)
    // compute G2(x) = sqrt_root_two_over_pi * x * (1 + 3 * fitting_const * x^2)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);

    // keep G2 in a separate register
    h->uni_vmovups(vmm_aux2, table_val(gelu_tanh_fitting_const_times_three));
    h->uni_vfmadd213ps(vmm_aux2, vmm_src, table_val(one));

    h->uni_vmovups(vmm_aux1, table_val(gelu_tanh_fitting_const));
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    // vmm_aux0 = sqrt(2 / pi) * x, shared by G1 and G2
    h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(gelu_tanh_sqrt_two_over_pi));
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux0);

    // save G2 on stack as tanh uses all available registers
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_aux2);

    // T = tanh(G1(x))
    tanh_compute_vector_fwd(vmm_src);

    h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);

    // compute 0.5 * (1 + T) * (1 + G2 * (1 - T))
    // The sse41/avx path avoids the FMA forms whose destination aliases a
    // source operand (presumably unsupported by the uni_ FMA emulation on
    // these ISAs -- confirm); both paths compute the same expression.
    if (isa == sse41 || isa == avx) {
        h->uni_vmovups(vmm_aux3, table_val(one));
        h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src); // 1 - T
        h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux3); // R = G2 * (1 - T)
        h->uni_vaddps(vmm_src, vmm_src, table_val(one)); // Q = 1 + T
        h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_src); // Q * R
        h->uni_vaddps(vmm_src, vmm_src, vmm_aux2); // Q + Q * R
    } else {
        // 1) R = G2 * (1 - T) = G2 - G2 * T
        h->uni_vfnmadd231ps(vmm_aux2, vmm_aux2, vmm_src);
        // 2) Q = 1 + T
        h->uni_vaddps(vmm_src, vmm_src, table_val(one));
        // 3) res = Q * (1 + R) = Q + Q * R
        h->uni_vfmadd231ps(vmm_src, vmm_src, vmm_aux2);
    }
    h->uni_vmulps(vmm_src, vmm_src, table_val(half));
}
1345 | |
1346 | template <cpu_isa_t isa, typename Wmm> |
1347 | void jit_uni_eltwise_injector_f32<isa, Wmm>::square_compute_vector_bwd( |
1348 | const Vmm &vmm_src) { |
1349 | // res = 2 * s |
1350 | h->uni_vmulps(vmm_src, vmm_src, table_val(two)); |
1351 | } |
1352 | |
1353 | template <cpu_isa_t isa, typename Wmm> |
1354 | void jit_uni_eltwise_injector_f32<isa, Wmm>::abs_compute_vector_bwd( |
1355 | const Vmm &vmm_src) { |
1356 | // replace positive values with 1.f |
1357 | compute_cmp_mask(vmm_src, table_val(zero), _cmp_gt_os); |
1358 | blend_with_mask(vmm_src, table_val(one)); |
1359 | // replace negative values with -1.f |
1360 | compute_cmp_mask(vmm_src, table_val(zero), _cmp_lt_os); |
1361 | blend_with_mask(vmm_src, table_val(minus_one)); |
1362 | } |
1363 | |
1364 | template <cpu_isa_t isa, typename Wmm> |
1365 | void jit_uni_eltwise_injector_f32<isa, Wmm>::sqrt_compute_vector_bwd( |
1366 | const Vmm &vmm_src) { |
1367 | // res = 0.5 / d = 0.5 / sqrt(s) |
1368 | if (!use_dst_) sqrt_compute_vector_fwd(vmm_src); |
1369 | h->uni_vmovups(vmm_aux0, table_val(half)); |
1370 | // h->uni_vdivps(vmm_src, vmm_aux0, vmm_src); // bless sse41 |
1371 | h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src); |
1372 | h->uni_vmovups(vmm_src, vmm_aux0); |
1373 | } |
1374 | |
1375 | template <cpu_isa_t isa, typename Wmm> |
1376 | void jit_uni_eltwise_injector_f32<isa, Wmm>::linear_compute_vector_bwd( |
1377 | const Vmm &vmm_src) { |
1378 | h->uni_vmovups(vmm_src, table_val(alpha)); |
1379 | } |
1380 | |
1381 | template <cpu_isa_t isa, typename Wmm> |
1382 | void jit_uni_eltwise_injector_f32<isa, Wmm>::soft_relu_compute_vector_bwd( |
1383 | const Vmm &vmm_src) { |
1384 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
1385 | logistic_compute_vector_fwd(vmm_src); |
1386 | } |
1387 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::mish_compute_vector_bwd(
        const Vmm &vmm_src) {
    // Replaces vmm_src (holding x) with the mish derivative
    //   e^x * omega / delta^2, where
    //   omega = e^3x + 4*e^2x + 4*e^x*(x+1.5) + 4*(x+1),
    //   delta = (e^x+1)^2 + 1.
    // Clobbers vmm_aux1..vmm_aux3 plus whatever exp_compute_vector_fwd uses.
    // IMPORTANT: we use vmm_aux3 to save src as exp does not use it.
    h->uni_vmovups(vmm_aux3, vmm_src); // vmm_aux3 = x

    // Only the argument of exp is clamped to bwd_mish_max_x_for_equation_f
    // (presumably to keep the e^x powers below from overflowing); the
    // polynomial terms in x use the unclamped value saved in vmm_aux3.
    h->uni_vminps(vmm_src, vmm_src, table_val(bwd_mish_max_x_for_equation_f));
    exp_compute_vector_fwd(vmm_src);
    h->uni_vmovups(vmm_aux2, vmm_src); // vmm_aux2 = e^x

    // e^3x + 4*e^2x
    h->uni_vmulps(vmm_src, vmm_src, vmm_src); // e^2x
    h->uni_vmovups(vmm_aux1, vmm_src);
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two));
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two)); // 4*e^2x
    h->uni_vfmadd213ps(vmm_src, vmm_aux2, vmm_aux1); // e^2x * e^x + 4*e^2x

    // e^3x + 4*e^2x + 4*e^x*(x+1.5)
    h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(one)); // vmm_aux3 = x + 1
    h->uni_vmovups(vmm_aux1, vmm_aux3);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(half)); // x + 1.5
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two));
    h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(two)); // 4*(x+1.5)
    h->uni_vfmadd231ps(vmm_src, vmm_aux1, vmm_aux2);

    // omega = e^3x + 4*e^2x + 4*e^x*(x+1.5) + 4*(x+1)
    h->uni_vmulps(vmm_aux3, vmm_aux3, table_val(two)); // 2*(x+1)
    h->uni_vfmadd231ps(vmm_src, vmm_aux3, table_val(two));

    // delta = (e^x+1)^2 + 1
    h->uni_vmovups(vmm_aux1, vmm_aux2);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux1);
    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one));
    h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux1); // delta^2

    // e^x * omega / delta^2
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux2);
    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
}
1428 | |
1429 | template <cpu_isa_t isa, typename Wmm> |
1430 | void jit_uni_eltwise_injector_f32<isa, Wmm>::logistic_compute_vector_bwd( |
1431 | const Vmm &vmm_src) { |
1432 | // res = d * (1 - d) = d - d * d; d = logistic(s) |
1433 | if (!use_dst_) logistic_compute_vector_fwd(vmm_src); |
1434 | // h->uni_vfnmadd231ps(vmm_src, vmm_src, vmm_src); // bless sse41 |
1435 | h->uni_vmovups(vmm_aux0, table_val(one)); |
1436 | h->uni_vsubps(vmm_aux0, vmm_aux0, vmm_src); |
1437 | h->uni_vmulps(vmm_src, vmm_src, vmm_aux0); |
1438 | } |
1439 | |
1440 | template <cpu_isa_t isa, typename Wmm> |
1441 | void jit_uni_eltwise_injector_f32<isa, Wmm>::exp_compute_vector_bwd( |
1442 | const Vmm &vmm_src) { |
1443 | if (!use_dst_) exp_compute_vector_fwd(vmm_src); |
1444 | } |
1445 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_bwd(
        const Vmm &vmm_src) {
    // Replaces vmm_src (holding s) with the swish derivative
    // Q * (1 + R * (1 - Q)), where R = alpha * s and Q = logistic(R).
    // Clobbers vmm_aux0, vmm_aux1 (pre-FMA path) and whatever the logistic
    // emitter uses; temporarily uses a vlen-sized stack slot.
    // R = alpha * s
    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
    // Save R on stack for later usage
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);
    // Q = sigmoid(alpha * s)
    logistic_compute_vector_fwd(vmm_src);
    // vmm_aux0 = R, reloaded after the logistic clobbered the aux registers
    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);
    // compute Q * (1 + R * (1 - Q))
    // The sse41/avx path avoids FMA forms whose destination aliases a source
    // operand (presumably unsupported by the uni_ FMA emulation on these
    // ISAs -- confirm); both paths compute the same expression.
    if (utils::one_of(isa, sse41, avx)) {
        h->uni_vmovups(vmm_aux1, table_val(one));
        h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_src); // 1 - Q
        h->uni_vmulps(vmm_aux1, vmm_aux1, vmm_aux0); // R * (1 - Q)
        h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(one)); // 1 + R * (1 - Q)
        h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
    } else {
        // T = R * (1 - Q) = R - R * Q
        h->uni_vfnmadd231ps(vmm_aux0, vmm_aux0, vmm_src);
        // Q * (1 + T) = Q + Q * T
        h->uni_vfmadd231ps(vmm_src, vmm_src, vmm_aux0);
    }
}
1472 | |
1473 | template <cpu_isa_t isa, typename Wmm> |
1474 | void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_bwd( |
1475 | const Vmm &vmm_src) { |
1476 | // res = 1 / s |
1477 | h->uni_vmovups(vmm_aux0, table_val(one)); |
1478 | // h->uni_vdivps(vmm_src, vmm_aux0, vmm_src); // bless sse41 |
1479 | h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src); |
1480 | h->uni_vmovups(vmm_src, vmm_aux0); |
1481 | } |
1482 | |
1483 | template <cpu_isa_t isa, typename Wmm> |
1484 | void jit_uni_eltwise_injector_f32<isa, Wmm>::clip_compute_vector_bwd( |
1485 | const Vmm &vmm_src) { |
1486 | using namespace alg_kind; |
1487 | |
1488 | // set result with 1.f |
1489 | h->uni_vmovups(vmm_aux1, table_val(one)); |
1490 | const auto cmp_flag = alg_ == eltwise_clip ? _cmp_gt_os : _cmp_ge_os; |
1491 | // get mask of values > beta (or >= beta) and blend with 0.f |
1492 | compute_cmp_mask(vmm_src, table_val(beta), cmp_flag); |
1493 | blend_with_mask(vmm_aux1, table_val(zero)); |
1494 | // get mask of values <= alpha and blend with 0.f |
1495 | compute_cmp_mask(vmm_src, table_val(alpha), _cmp_le_os); |
1496 | blend_with_mask(vmm_aux1, table_val(zero)); |
1497 | h->uni_vmovups(vmm_src, vmm_aux1); |
1498 | } |
1499 | |
1500 | template <cpu_isa_t isa, typename Wmm> |
1501 | void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_bwd( |
1502 | const Vmm &vmm_src) { |
1503 | // dispatch some special cases. |
1504 | if (beta_ == 0) { // zero |
1505 | h->uni_vmovups(vmm_src, table_val(zero)); |
1506 | } else if (beta_ == 0.5) { // 0.5 * alpha / sqrt(s) |
1507 | sqrt_compute_vector_bwd(vmm_src); |
1508 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
1509 | } else if (beta_ == 1) { // alpha |
1510 | h->uni_vmovups(vmm_src, table_val(alpha)); |
1511 | } else { |
1512 | // Save `s` on stack for later usage |
1513 | h->sub(h->rsp, vlen); |
1514 | h->uni_vmovups(h->ptr[h->rsp], vmm_src); |
1515 | // R = alpha * pow(s, beta) |
1516 | pow_compute_vector_fwd(vmm_src); |
1517 | // Restore `s` from stack |
1518 | h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]); |
1519 | h->add(h->rsp, vlen); |
1520 | // Save mask of zero elements to convert them into zeros at the end |
1521 | if (beta_ >= 1) compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_eq_oq); |
1522 | // res = alpha * beta * pow(s, beta - 1) = beta * R / s; |
1523 | h->uni_vdivps(vmm_src, vmm_src, vmm_aux1); |
1524 | h->uni_vmulps(vmm_src, vmm_src, table_val(beta)); |
1525 | |
1526 | // beta < 1 leads to NaN as `s` appears in denominator, but beta >= 1 |
1527 | // should lead to zero, when `s` is zero. |
1528 | if (beta_ >= 1) blend_with_mask(vmm_src, table_val(zero)); |
1529 | } |
1530 | } |
1531 | |
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_bwd(
        const Vmm &vmm_src) {
    // Gradient of gelu_erf(s) = 0.5 * s * (1 + erf(s / sqrt(2))):
    //   res = 0.5 * (1 + erf(R)) + R / sqrt(pi) * exp(-R^2),  R = s / sqrt(2)
    // erf(R) uses the Abramowitz-Stegun approximation
    //   erf(x) ~= sign(x) * (1 - r(t) * t * exp(-x^2)),  t = 1 / (1 + p * |x|)
    // where r is a degree-4 polynomial in t (coefficients pol[0..4]).
    // Scratch: vmm_aux0..vmm_aux4 plus one vlen-byte stack slot holding R.

    // R = s / sqrt(2)
    h->uni_vmulps(vmm_src, vmm_src,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_two));

    // Save R on stack; it is reloaded three times after the exp call
    h->sub(h->rsp, vlen);
    h->uni_vmovups(h->ptr[h->rsp], vmm_src);

    // Q = exp(-R*R)  (negation by flipping the sign bit)
    h->uni_vmulps(vmm_src, vmm_src, vmm_src);
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));
    exp_compute_vector_fwd(vmm_src);

    // T = R / sqrt(pi) * Q   (the s * pdf term of the gradient)
    h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
    h->uni_vmulps(vmm_aux2, vmm_aux2,
            table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi));
    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_src);

    // -Q
    h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));

    // sign bit of R, used to restore the sign of erf at the end
    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
    h->uni_vandps(vmm_aux0, vmm_aux0, table_val(sign_mask));

    // |R|; the stack slot is released here (last reload of R)
    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
    h->add(h->rsp, vlen);
    abs_compute_vector_fwd(vmm_aux1);

    // t = 1 / (p * |R| + 1)   (kept in vmm_aux4)
    h->uni_vmovups(
            vmm_aux3, table_val(gelu_erf_Abramowitz_Stegun_approx_const));
    h->uni_vmovups(vmm_aux4, table_val(one));
    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, vmm_aux4);
    h->uni_vdivps(vmm_aux4, vmm_aux4, vmm_aux3);

    // -Q * t
    h->uni_vmulps(vmm_src, vmm_src, vmm_aux4);

    // r(t) via Horner's scheme, highest coefficient first (into vmm_aux1)
    h->uni_vmovups(vmm_aux1, table_val(gelu_erf_Abramowitz_Stegun_pol, 4));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 3));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 2));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 1));
    h->uni_vfmadd213ps(
            vmm_aux1, vmm_aux4, table_val(gelu_erf_Abramowitz_Stegun_pol, 0));

    // erf(R) = sign(R) * (1 - r * t * Q)
    h->uni_vfmadd213ps(vmm_src, vmm_aux1, table_val(one));
    h->uni_vxorps(vmm_src, vmm_src, vmm_aux0);

    // P = T + 0.5
    h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(half));
    // res = P + 0.5 * erf
    h->uni_vfmadd231ps(vmm_aux2, vmm_src, table_val(half));
    h->uni_vmovups(vmm_src, vmm_aux2);
}
1597 | |
1598 | template <cpu_isa_t isa, typename Wmm> |
1599 | void jit_uni_eltwise_injector_f32<isa, Wmm>::hardswish_compute_vector_bwd( |
1600 | const Vmm &vmm_src) { |
1601 | // Get mask for 0 < alpha * x + beta < 1 |
1602 | h->uni_vmovups(vmm_aux1, vmm_src); |
1603 | h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(alpha)); |
1604 | h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(beta)); |
1605 | // Form a derivative value |
1606 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
1607 | h->uni_vaddps(vmm_src, vmm_src, vmm_aux1); |
1608 | |
1609 | compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os); |
1610 | blend_with_mask(vmm_src, table_val(zero)); |
1611 | compute_cmp_mask(vmm_aux1, table_val(one), _cmp_ge_os); |
1612 | blend_with_mask(vmm_src, table_val(one)); |
1613 | } |
1614 | |
1615 | template <cpu_isa_t isa, typename Wmm> |
1616 | void jit_uni_eltwise_injector_f32<isa, Wmm>::hardsigmoid_compute_vector_bwd( |
1617 | const Vmm &vmm_src) { |
1618 | // Get mask for 0 < alpha * x + beta < 1 |
1619 | // Zero rest values. |
1620 | h->uni_vmovups(vmm_aux1, vmm_src); |
1621 | h->uni_vmulps(vmm_aux1, vmm_aux1, table_val(alpha)); |
1622 | h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(beta)); |
1623 | |
1624 | h->uni_vmovups(vmm_src, table_val(one)); |
1625 | compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os); |
1626 | blend_with_mask(vmm_src, table_val(zero)); |
1627 | compute_cmp_mask(vmm_aux1, table_val(one), _cmp_ge_os); |
1628 | blend_with_mask(vmm_src, table_val(zero)); |
1629 | h->uni_vmulps(vmm_src, vmm_src, table_val(alpha)); |
1630 | } |
1631 | |
1632 | template <cpu_isa_t isa, typename Wmm> |
1633 | size_t jit_uni_eltwise_injector_f32<isa, Wmm>::aux_gprs_count() { |
1634 | using namespace alg_kind; |
1635 | switch (alg_) { |
1636 | case eltwise_tanh_use_dst_for_bwd: |
1637 | case eltwise_tanh: |
1638 | case eltwise_gelu_tanh: return isa == sse41 || isa == avx ? 4 : 0; |
1639 | default: return 0; |
1640 | } |
1641 | return 0; |
1642 | }; |
1643 | |
1644 | template <cpu_isa_t isa, typename Wmm> |
1645 | void jit_uni_eltwise_injector_f32<isa, Wmm>::round_compute_vector_fwd( |
1646 | const Vmm &vmm_src) { |
1647 | h->uni_vroundps(vmm_src, vmm_src, _op_mxcsr); |
1648 | } |
1649 | |
1650 | template <cpu_isa_t isa, typename Wmm> |
1651 | size_t jit_uni_eltwise_injector_f32<isa, Wmm>::aux_vecs_count() { |
1652 | using namespace alg_kind; |
1653 | if (is_fwd_) { |
1654 | switch (alg_) { |
1655 | case eltwise_relu_use_dst_for_bwd: |
1656 | case eltwise_relu: return (alpha_ == 0.f) ? 0 : 2; |
1657 | case eltwise_elu_use_dst_for_bwd: |
1658 | case eltwise_elu: return 4; |
1659 | case eltwise_tanh_use_dst_for_bwd: |
1660 | case eltwise_tanh: return 5; |
1661 | case eltwise_square: return 0; |
1662 | case eltwise_abs: return 0; |
1663 | case eltwise_sqrt_use_dst_for_bwd: |
1664 | case eltwise_sqrt: return 0; |
1665 | case eltwise_linear: return 1; |
1666 | case eltwise_soft_relu: return 4; |
1667 | case eltwise_mish: return 4; |
1668 | case eltwise_logistic_use_dst_for_bwd: |
1669 | case eltwise_logistic: return 4; |
1670 | case eltwise_exp_use_dst_for_bwd: |
1671 | case eltwise_exp: return 3; |
1672 | case eltwise_gelu_tanh: return 5; |
1673 | case eltwise_swish: return 4; |
1674 | case eltwise_log: return 5; |
1675 | case eltwise_clip: |
1676 | case eltwise_clip_v2_use_dst_for_bwd: |
1677 | case eltwise_clip_v2: return 0; |
1678 | case eltwise_pow: return 2; |
1679 | case eltwise_gelu_erf: return 5; |
1680 | case eltwise_round: return 0; |
1681 | case eltwise_hardswish: return 1; |
1682 | case eltwise_hardsigmoid: return 0; |
1683 | default: assert(!"unsupported eltwise algorithm" ); |
1684 | } |
1685 | } else { |
1686 | switch (alg_) { |
1687 | case eltwise_relu_use_dst_for_bwd: |
1688 | case eltwise_relu: return 1; |
1689 | case eltwise_elu_use_dst_for_bwd: return 1; |
1690 | case eltwise_elu: return 3; |
1691 | case eltwise_tanh_use_dst_for_bwd: return 1; |
1692 | case eltwise_tanh: return 5; |
1693 | case eltwise_square: return 0; |
1694 | case eltwise_abs: return 0; |
1695 | case eltwise_sqrt_use_dst_for_bwd: |
1696 | case eltwise_sqrt: return 1; |
1697 | case eltwise_linear: return 0; |
1698 | case eltwise_soft_relu: return 4; |
1699 | case eltwise_mish: return 4; |
1700 | case eltwise_logistic_use_dst_for_bwd: return 1; |
1701 | case eltwise_logistic: return 4; |
1702 | case eltwise_exp_use_dst_for_bwd: return 0; |
1703 | case eltwise_exp: return 3; |
1704 | case eltwise_gelu_tanh: return 5; |
1705 | case eltwise_swish: return 4; |
1706 | case eltwise_log: return 1; |
1707 | case eltwise_clip: |
1708 | case eltwise_clip_v2_use_dst_for_bwd: |
1709 | case eltwise_clip_v2: return 2; |
1710 | case eltwise_pow: return 2; |
1711 | case eltwise_gelu_erf: return 5; |
1712 | case eltwise_hardswish: return 2; |
1713 | case eltwise_hardsigmoid: return 2; |
1714 | default: assert(!"unsupported eltwise algorithm" ); |
1715 | } |
1716 | } |
1717 | |
1718 | return 0; |
1719 | } |
1720 | |
1721 | template <cpu_isa_t isa, typename Wmm> |
1722 | void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_body( |
1723 | const injector_utils::vmm_index_set_iterator_t &start_idx_it, |
1724 | const injector_utils::vmm_index_set_iterator_t &end_idx_it) { |
1725 | using namespace alg_kind; |
1726 | std::for_each(start_idx_it, end_idx_it, [&](size_t idx) { |
1727 | if (is_fwd_) { |
1728 | switch (alg_) { |
1729 | case eltwise_relu_use_dst_for_bwd: |
1730 | case eltwise_relu: |
1731 | if (alpha_ == 0.f) |
1732 | relu_zero_ns_compute_vector_fwd(Vmm(idx)); |
1733 | else |
1734 | relu_compute_vector_fwd(Vmm(idx)); |
1735 | break; |
1736 | case eltwise_elu_use_dst_for_bwd: |
1737 | case eltwise_elu: elu_compute_vector_fwd(Vmm(idx)); break; |
1738 | case eltwise_tanh_use_dst_for_bwd: |
1739 | case eltwise_tanh: tanh_compute_vector_fwd(Vmm(idx)); break; |
1740 | case eltwise_square: square_compute_vector_fwd(Vmm(idx)); break; |
1741 | case eltwise_abs: abs_compute_vector_fwd(Vmm(idx)); break; |
1742 | case eltwise_sqrt_use_dst_for_bwd: |
1743 | case eltwise_sqrt: sqrt_compute_vector_fwd(Vmm(idx)); break; |
1744 | case eltwise_swish: swish_compute_vector_fwd(Vmm(idx)); break; |
1745 | case eltwise_linear: linear_compute_vector_fwd(Vmm(idx)); break; |
1746 | case eltwise_soft_relu: |
1747 | soft_relu_compute_vector_fwd(Vmm(idx)); |
1748 | break; |
1749 | case eltwise_mish: mish_compute_vector_fwd(Vmm(idx)); break; |
1750 | case eltwise_logistic_use_dst_for_bwd: |
1751 | case eltwise_logistic: |
1752 | logistic_compute_vector_fwd(Vmm(idx)); |
1753 | break; |
1754 | case eltwise_exp_use_dst_for_bwd: |
1755 | case eltwise_exp: exp_compute_vector_fwd(Vmm(idx)); break; |
1756 | case eltwise_gelu_tanh: |
1757 | gelu_tanh_compute_vector_fwd(Vmm(idx)); |
1758 | break; |
1759 | case eltwise_log: log_compute_vector_fwd(Vmm(idx)); break; |
1760 | case eltwise_clip: |
1761 | case eltwise_clip_v2_use_dst_for_bwd: |
1762 | case eltwise_clip_v2: clip_compute_vector_fwd(Vmm(idx)); break; |
1763 | case eltwise_pow: pow_compute_vector_fwd(Vmm(idx)); break; |
1764 | case eltwise_gelu_erf: |
1765 | gelu_erf_compute_vector_fwd(Vmm(idx)); |
1766 | break; |
1767 | case eltwise_round: round_compute_vector_fwd(Vmm(idx)); break; |
1768 | case eltwise_hardswish: |
1769 | hardswish_compute_vector_fwd(Vmm(idx)); |
1770 | break; |
1771 | case eltwise_hardsigmoid: |
1772 | hardsigmoid_compute_vector_fwd(Vmm(idx)); |
1773 | break; |
1774 | default: assert(!"unsupported eltwise algorithm" ); |
1775 | } |
1776 | } else { |
1777 | switch (alg_) { |
1778 | case eltwise_relu_use_dst_for_bwd: |
1779 | case eltwise_relu: relu_compute_vector_bwd(Vmm(idx)); break; |
1780 | case eltwise_elu_use_dst_for_bwd: |
1781 | case eltwise_elu: elu_compute_vector_bwd(Vmm(idx)); break; |
1782 | case eltwise_tanh_use_dst_for_bwd: |
1783 | case eltwise_tanh: tanh_compute_vector_bwd(Vmm(idx)); break; |
1784 | case eltwise_square: square_compute_vector_bwd(Vmm(idx)); break; |
1785 | case eltwise_abs: abs_compute_vector_bwd(Vmm(idx)); break; |
1786 | case eltwise_sqrt_use_dst_for_bwd: |
1787 | case eltwise_sqrt: sqrt_compute_vector_bwd(Vmm(idx)); break; |
1788 | case eltwise_linear: linear_compute_vector_bwd(Vmm(idx)); break; |
1789 | case eltwise_soft_relu: |
1790 | soft_relu_compute_vector_bwd(Vmm(idx)); |
1791 | break; |
1792 | case eltwise_mish: mish_compute_vector_bwd(Vmm(idx)); break; |
1793 | case eltwise_logistic_use_dst_for_bwd: |
1794 | case eltwise_logistic: |
1795 | logistic_compute_vector_bwd(Vmm(idx)); |
1796 | break; |
1797 | case eltwise_exp_use_dst_for_bwd: |
1798 | case eltwise_exp: exp_compute_vector_bwd(Vmm(idx)); break; |
1799 | case eltwise_gelu_tanh: |
1800 | gelu_tanh_compute_vector_bwd(Vmm(idx)); |
1801 | break; |
1802 | case eltwise_swish: swish_compute_vector_bwd(Vmm(idx)); break; |
1803 | case eltwise_log: log_compute_vector_bwd(Vmm(idx)); break; |
1804 | case eltwise_clip: |
1805 | case eltwise_clip_v2_use_dst_for_bwd: |
1806 | case eltwise_clip_v2: clip_compute_vector_bwd(Vmm(idx)); break; |
1807 | case eltwise_pow: pow_compute_vector_bwd(Vmm(idx)); break; |
1808 | case eltwise_gelu_erf: |
1809 | gelu_erf_compute_vector_bwd(Vmm(idx)); |
1810 | break; |
1811 | case eltwise_hardswish: |
1812 | hardswish_compute_vector_bwd(Vmm(idx)); |
1813 | break; |
1814 | case eltwise_hardsigmoid: |
1815 | hardsigmoid_compute_vector_bwd(Vmm(idx)); |
1816 | break; |
1817 | default: assert(!"unsupported eltwise algorithm" ); |
1818 | } |
1819 | } |
1820 | if (scale_ != 1.f) { |
1821 | h->uni_vmulps(Vmm(idx), Vmm(idx), table_val(scale)); |
1822 | } |
1823 | }); |
1824 | } |
1825 | |
1826 | template <cpu_isa_t isa, typename Wmm> |
1827 | void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_vector_range( |
1828 | size_t start_idx, size_t end_idx) { |
1829 | injector_utils::vmm_index_set_t vmm_idxs; |
1830 | for (size_t i = start_idx; i < end_idx; i++) |
1831 | vmm_idxs.emplace(i); |
1832 | compute_vector_range(vmm_idxs); |
1833 | } |
1834 | |
1835 | template <cpu_isa_t isa, typename Wmm> |
1836 | void jit_uni_eltwise_injector_f32<isa, Wmm>::compute_vector_range( |
1837 | const injector_utils::vmm_index_set_t &vmm_idxs) { |
1838 | const auto &start_idx_it = vmm_idxs.begin(); |
1839 | const auto &end_idx_it = vmm_idxs.end(); |
1840 | assert(*start_idx_it < *vmm_idxs.rbegin() + 1 |
1841 | && *vmm_idxs.rbegin() <= vecs_count); |
1842 | |
1843 | injector_preamble(vmm_idxs); |
1844 | compute_body(start_idx_tail, end_idx_it); |
1845 | injector_preamble_tail(start_idx_it); |
1846 | compute_body(start_idx_it, start_idx_tail); |
1847 | injector_postamble(); |
1848 | } |
1849 | |
1850 | template <cpu_isa_t isa, typename Wmm> |
1851 | void jit_uni_eltwise_injector_f32<isa, Wmm>::prepare_table(bool gen_table) { |
1852 | if (!gen_table) return; |
1853 | |
1854 | h->align(64); |
1855 | h->L(l_table); |
1856 | |
1857 | // Assumption: entries can be inserted with dd, so they should be 4 bytes. |
1858 | assert(sizeof(table_entry_val_t) == 4); |
1859 | |
1860 | // Assumption: iterating on entry_map_ here has the same order as |
1861 | // when we set the offsets. We verify that in asserts. |
1862 | // table_entry_val_t is assumed to be 32 bits |
1863 | #ifndef NDEBUG |
1864 | size_t off = 0; |
1865 | key_t curr_key = undef_key; |
1866 | int key_occurences = 0; |
1867 | #endif |
1868 | |
1869 | // Run through the map and insert values stored there |
1870 | for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { |
1871 | const auto &te = (*it).second; // get map entry for a given key |
1872 | const auto len = te.bcast ? vlen : sizeof(table_entry_val_t); |
1873 | for (size_t d = 0; d < len; d += sizeof(table_entry_val_t)) |
1874 | h->dd(te.val); |
1875 | |
1876 | #ifndef NDEBUG |
1877 | // we check that the precomputed offsets match the registered ones |
1878 | const auto &key = (*it).first; // get map entry key |
1879 | if (key != curr_key) { |
1880 | curr_key = key; |
1881 | key_occurences = 0; |
1882 | } |
1883 | key_occurences++; |
1884 | auto expected_off = table_off(key, key_occurences - 1); |
1885 | assert(off == expected_off); |
1886 | MAYBE_UNUSED(expected_off); |
1887 | off += len; |
1888 | #endif |
1889 | } |
1890 | } |
1891 | |
1892 | template <cpu_isa_t isa, typename Wmm> |
1893 | void jit_uni_eltwise_injector_f32<isa, Wmm>::register_table_entries() { |
1894 | // This function is responsible to pick all necessary constants |
1895 | // for a given algorithm, compute right offset for them to be used |
1896 | // in table_val() and save the hexadecimal value of them, which |
1897 | // will be finally used in prepare_table(). We rely on fact that |
1898 | // the map iterator order is deterministic for a fixed map. |
1899 | |
1900 | // common values used in several algorithms |
1901 | static const table_t common_values {{zero, {0x00000000, true}}, |
1902 | {half, {0x3f000000, true}}, {one, {0x3f800000, true}}, |
1903 | {two, {0x40000000, true}}, {minus_one, {0xbf800000, true}}, |
1904 | {minus_two, {0xc0000000, true}}, {ln2f, {0x3f317218, true}}, |
1905 | {positive_mask, {0x7fffffff, true}}, |
1906 | {sign_mask, {0x80000000, true}}, |
1907 | {exponent_bias, {0x0000007f, true}}}; |
1908 | |
1909 | // exp(x) constants |
1910 | static const table_t exp_consts {{exp_log2ef, {0x3fb8aa3b, true}}, |
1911 | {exp_ln_flt_max_f, {0x42b17218, true}}, |
1912 | {exp_ln_flt_min_f, {0xc2aeac50, true}}}; |
1913 | |
1914 | // exp(x) polynomial approximation |
1915 | static const table_t exp_polynomial { |
1916 | // p0 = 1.0f |
1917 | {exp_pol, {0x3f7ffffb, true}}, // p1 = 0.999999701f |
1918 | {exp_pol, {0x3efffee3, true}}, // p2 = 0.499991506f |
1919 | {exp_pol, {0x3e2aad40, true}}, // p3 = 0.166676521f |
1920 | {exp_pol, {0x3d2b9d0d, true}}, // p4 = 0.0418978221f |
1921 | {exp_pol, {0x3c07cfce, true}} // p5 = 0.00828929059f |
1922 | }; |
1923 | |
1924 | // mish(x) constants |
1925 | static const table_t mish_consts { |
1926 | {fwd_mish_max_x_for_equation_f, {0x42317217, true}}, |
1927 | {bwd_mish_max_x_for_equation_f, {0x41b17217, true}}}; |
1928 | |
1929 | // tanh(x) constants for four interval approximation |
1930 | static const table_t tanh_consts {{tanh_idx_bias, {0x39800000, true}}, |
1931 | {tanh_idx_mask, {0xffc00000, true}}, |
1932 | {tanh_linear_ubound, {0x39ddb3d7, true}}, |
1933 | {tanh_saturation_lbound, {0x41102cb3, true}}}; |
1934 | |
1935 | // tanh(x) polynomial approximation |
1936 | // For each coefficient, there is 32 entries |
1937 | static const table_t tanh_polynomial_table { |
1938 | // coefficients of degree 0 |
1939 | {tanh_pol_table, {0x00000000, false}}, |
1940 | {tanh_pol_table, {0x39bfffff, false}}, |
1941 | {tanh_pol_table, {0x39ffffff, false}}, |
1942 | {tanh_pol_table, {0x3a3ffffe, false}}, |
1943 | {tanh_pol_table, {0x3a7ffffb, false}}, |
1944 | {tanh_pol_table, {0x3abffff7, false}}, |
1945 | {tanh_pol_table, {0x3affffeb, false}}, |
1946 | {tanh_pol_table, {0x3b3fffdc, false}}, |
1947 | {tanh_pol_table, {0x3b7fffab, false}}, |
1948 | {tanh_pol_table, {0x3bbfff70, false}}, |
1949 | {tanh_pol_table, {0x3bfffeab, false}}, |
1950 | {tanh_pol_table, {0x3c3ffdc0, false}}, |
1951 | {tanh_pol_table, {0x3c7ffaab, false}}, |
1952 | {tanh_pol_table, {0x3cbff701, false}}, |
1953 | {tanh_pol_table, {0x3cffeaad, false}}, |
1954 | {tanh_pol_table, {0x3d3fdc08, false}}, |
1955 | {tanh_pol_table, {0x3d7faacd, false}}, |
1956 | {tanh_pol_table, {0x3dbf7081, false}}, |
1957 | {tanh_pol_table, {0x3dfeacc9, false}}, |
1958 | {tanh_pol_table, {0x3e3dc7fd, false}}, |
1959 | {tanh_pol_table, {0x3e7acbf5, false}}, |
1960 | {tanh_pol_table, {0x3eb77a9f, false}}, |
1961 | {tanh_pol_table, {0x3eec9a9f, false}}, |
1962 | {tanh_pol_table, {0x3f22991f, false}}, |
1963 | {tanh_pol_table, {0x3f42f7d6, false}}, |
1964 | {tanh_pol_table, {0x3f67b7cc, false}}, |
1965 | {tanh_pol_table, {0x3f76ca83, false}}, |
1966 | {tanh_pol_table, {0x3f7ebbe9, false}}, |
1967 | {tanh_pol_table, {0x3f7fd40c, false}}, |
1968 | {tanh_pol_table, {0x3f7fff32, false}}, |
1969 | {tanh_pol_table, {0x3f7ffffc, false}}, |
1970 | {tanh_pol_table, {0x3f800000, false}}, |
1971 | // coefficients of degree 1 |
1972 | {tanh_pol_table, {0x3f800000, false}}, |
1973 | {tanh_pol_table, {0x3f800018, false}}, |
1974 | {tanh_pol_table, {0x3f7fffe8, false}}, |
1975 | {tanh_pol_table, {0x3f7fffda, false}}, |
1976 | {tanh_pol_table, {0x3f7fffdc, false}}, |
1977 | {tanh_pol_table, {0x3f7fffdc, false}}, |
1978 | {tanh_pol_table, {0x3f7fffac, false}}, |
1979 | {tanh_pol_table, {0x3f7fff70, false}}, |
1980 | {tanh_pol_table, {0x3f7ffeec, false}}, |
1981 | {tanh_pol_table, {0x3f7ffdc0, false}}, |
1982 | {tanh_pol_table, {0x3f7ffbed, false}}, |
1983 | {tanh_pol_table, {0x3f7ff704, false}}, |
1984 | {tanh_pol_table, {0x3f7feff5, false}}, |
1985 | {tanh_pol_table, {0x3f7fdbca, false}}, |
1986 | {tanh_pol_table, {0x3f7fbfff, false}}, |
1987 | {tanh_pol_table, {0x3f7f7041, false}}, |
1988 | {tanh_pol_table, {0x3f7f009b, false}}, |
1989 | {tanh_pol_table, {0x3f7dc36c, false}}, |
1990 | {tanh_pol_table, {0x3f7c0aa8, false}}, |
1991 | {tanh_pol_table, {0x3f7734b8, false}}, |
1992 | {tanh_pol_table, {0x3f70a4de, false}}, |
1993 | {tanh_pol_table, {0x3f5f1fd8, false}}, |
1994 | {tanh_pol_table, {0x3f495493, false}}, |
1995 | {tanh_pol_table, {0x3f18b9ec, false}}, |
1996 | {tanh_pol_table, {0x3ed706cb, false}}, |
1997 | {tanh_pol_table, {0x3e390b06, false}}, |
1998 | {tanh_pol_table, {0x3d90b11f, false}}, |
1999 | {tanh_pol_table, {0x3c21a053, false}}, |
2000 | {tanh_pol_table, {0x3aaf7fdb, false}}, |
2001 | {tanh_pol_table, {0x37ccc1a3, false}}, |
2002 | {tanh_pol_table, {0x355c6733, false}}, |
2003 | {tanh_pol_table, {0x00000000, false}}, |
2004 | // coefficients of degree 2 |
2005 | {tanh_pol_table, {0x00000000, false}}, |
2006 | {tanh_pol_table, {0xbe4e0ff1, false}}, |
2007 | {tanh_pol_table, {0x3d25b1b1, false}}, |
2008 | {tanh_pol_table, {0x3d6b6dab, false}}, |
2009 | {tanh_pol_table, {0x3c9fb1d5, false}}, |
2010 | {tanh_pol_table, {0xbabff06f, false}}, |
2011 | {tanh_pol_table, {0x3c07b3f6, false}}, |
2012 | {tanh_pol_table, {0xbb3fc1bc, false}}, |
2013 | {tanh_pol_table, {0x3a9f5921, false}}, |
2014 | {tanh_pol_table, {0xbbbf06f2, false}}, |
2015 | {tanh_pol_table, {0xbbb0f402, false}}, |
2016 | {tanh_pol_table, {0xbc47db9e, false}}, |
2017 | {tanh_pol_table, {0xbc73d5e7, false}}, |
2018 | {tanh_pol_table, {0xbca25bda, false}}, |
2019 | {tanh_pol_table, {0xbcfca780, false}}, |
2020 | {tanh_pol_table, {0xbd40e07c, false}}, |
2021 | {tanh_pol_table, {0xbd7dab03, false}}, |
2022 | {tanh_pol_table, {0xbdbe4a0f, false}}, |
2023 | {tanh_pol_table, {0xbdfb14a5, false}}, |
2024 | {tanh_pol_table, {0xbe36cc8d, false}}, |
2025 | {tanh_pol_table, {0xbe6bd102, false}}, |
2026 | {tanh_pol_table, {0xbe9fe7c5, false}}, |
2027 | {tanh_pol_table, {0xbeba0f10, false}}, |
2028 | {tanh_pol_table, {0xbec206a8, false}}, |
2029 | {tanh_pol_table, {0xbea3c388, false}}, |
2030 | {tanh_pol_table, {0xbe277d62, false}}, |
2031 | {tanh_pol_table, {0xbd8b7960, false}}, |
2032 | {tanh_pol_table, {0xbc209f49, false}}, |
2033 | {tanh_pol_table, {0xbaad44ca, false}}, |
2034 | {tanh_pol_table, {0xb7c6eeac, false}}, |
2035 | {tanh_pol_table, {0xb663aa41, false}}, |
2036 | {tanh_pol_table, {0x00000000, false}}, |
2037 | // coefficients of degree 3 |
2038 | {tanh_pol_table, {0x00000000, false}}, |
2039 | {tanh_pol_table, {0x45b3ae96, false}}, |
2040 | {tanh_pol_table, {0xc414eb20, false}}, |
2041 | {tanh_pol_table, {0xc450e02e, false}}, |
2042 | {tanh_pol_table, {0xc3152b4e, false}}, |
2043 | {tanh_pol_table, {0xbead2f56, false}}, |
2044 | {tanh_pol_table, {0xc2162e02, false}}, |
2045 | {tanh_pol_table, {0xbeb4bd5a, false}}, |
2046 | {tanh_pol_table, {0xc11a59a4, false}}, |
2047 | {tanh_pol_table, {0xbed2f507, false}}, |
2048 | {tanh_pol_table, {0xc020d32c, false}}, |
2049 | {tanh_pol_table, {0x3dd0f506, false}}, |
2050 | {tanh_pol_table, {0xbf2a75e2, false}}, |
2051 | {tanh_pol_table, {0xbff950e3, false}}, |
2052 | {tanh_pol_table, {0xbed47334, false}}, |
2053 | {tanh_pol_table, {0xbe809b8c, false}}, |
2054 | {tanh_pol_table, {0xbeb64532, false}}, |
2055 | {tanh_pol_table, {0xbe961a5b, false}}, |
2056 | {tanh_pol_table, {0xbe9b63ac, false}}, |
2057 | {tanh_pol_table, {0xbea0d4b2, false}}, |
2058 | {tanh_pol_table, {0xbe828a77, false}}, |
2059 | {tanh_pol_table, {0xbe378612, false}}, |
2060 | {tanh_pol_table, {0xbdc20908, false}}, |
2061 | {tanh_pol_table, {0x3d2d3957, false}}, |
2062 | {tanh_pol_table, {0x3dd46e89, false}}, |
2063 | {tanh_pol_table, {0x3db3f629, false}}, |
2064 | {tanh_pol_table, {0x3d2c5e7b, false}}, |
2065 | {tanh_pol_table, {0x3bd20403, false}}, |
2066 | {tanh_pol_table, {0x3a59dfae, false}}, |
2067 | {tanh_pol_table, {0x3770af45, false}}, |
2068 | {tanh_pol_table, {0x372cc014, false}}, |
2069 | {tanh_pol_table, {0x00000000, false}}, |
2070 | // coefficients of degree 4 |
2071 | {tanh_pol_table, {0x00000000, false}}, |
2072 | {tanh_pol_table, {0xcc981a1b, false}}, |
2073 | {tanh_pol_table, {0x4a7edd3d, false}}, |
2074 | {tanh_pol_table, {0x4ab1007c, false}}, |
2075 | {tanh_pol_table, {0x48fedd9c, false}}, |
2076 | {tanh_pol_table, {0x41a557b5, false}}, |
2077 | {tanh_pol_table, {0x477ee32a, false}}, |
2078 | {tanh_pol_table, {0x422557f5, false}}, |
2079 | {tanh_pol_table, {0x45ff3ce4, false}}, |
2080 | {tanh_pol_table, {0x42a55641, false}}, |
2081 | {tanh_pol_table, {0x446e0867, false}}, |
2082 | {tanh_pol_table, {0xc33dc19a, false}}, |
2083 | {tanh_pol_table, {0x42915214, false}}, |
2084 | {tanh_pol_table, {0x43af4fad, false}}, |
2085 | {tanh_pol_table, {0x4110fe88, false}}, |
2086 | {tanh_pol_table, {0xc1099b75, false}}, |
2087 | {tanh_pol_table, {0x3fc8a8dc, false}}, |
2088 | {tanh_pol_table, {0xbfbeaef5, false}}, |
2089 | {tanh_pol_table, {0xbe365aad, false}}, |
2090 | {tanh_pol_table, {0x3f4d9652, false}}, |
2091 | {tanh_pol_table, {0x3ddfa08f, false}}, |
2092 | {tanh_pol_table, {0x3e34e9b8, false}}, |
2093 | {tanh_pol_table, {0x3e2d07a6, false}}, |
2094 | {tanh_pol_table, {0x3dc63567, false}}, |
2095 | {tanh_pol_table, {0x3cdaeb78, false}}, |
2096 | {tanh_pol_table, {0xbcd17537, false}}, |
2097 | {tanh_pol_table, {0xbc92829c, false}}, |
2098 | {tanh_pol_table, {0xbb43ab99, false}}, |
2099 | {tanh_pol_table, {0xb9b471dd, false}}, |
2100 | {tanh_pol_table, {0xb6baad5a, false}}, |
2101 | {tanh_pol_table, {0xb78bafc7, false}}, |
2102 | {tanh_pol_table, {0x00000000, false}}, |
2103 | // coefficients of degree 5 |
2104 | {tanh_pol_table, {0x00000000, false}}, |
2105 | {tanh_pol_table, {0x52f688d5, false}}, |
2106 | {tanh_pol_table, {0xd0505c72, false}}, |
2107 | {tanh_pol_table, {0xd08f98e3, false}}, |
2108 | {tanh_pol_table, {0xce505cc9, false}}, |
2109 | {tanh_pol_table, {0xc7162b8a, false}}, |
2110 | {tanh_pol_table, {0xcc5061d6, false}}, |
2111 | {tanh_pol_table, {0xc7162bdf, false}}, |
2112 | {tanh_pol_table, {0xca50b37f, false}}, |
2113 | {tanh_pol_table, {0xc7162a3a, false}}, |
2114 | {tanh_pol_table, {0xc8422086, false}}, |
2115 | {tanh_pol_table, {0x471a714e, false}}, |
2116 | {tanh_pol_table, {0xc5ece1f1, false}}, |
2117 | {tanh_pol_table, {0xc70e3d90, false}}, |
2118 | {tanh_pol_table, {0xc3eba94a, false}}, |
2119 | {tanh_pol_table, {0x43e0c424, false}}, |
2120 | {tanh_pol_table, {0xc21f4552, false}}, |
2121 | {tanh_pol_table, {0x42217cc8, false}}, |
2122 | {tanh_pol_table, {0x405e7dc4, false}}, |
2123 | {tanh_pol_table, {0xc10dd401, false}}, |
2124 | {tanh_pol_table, {0x3e96b602, false}}, |
2125 | {tanh_pol_table, {0xbd1a6d2f, false}}, |
2126 | {tanh_pol_table, {0xbd393883, false}}, |
2127 | {tanh_pol_table, {0xbd674682, false}}, |
2128 | {tanh_pol_table, {0xbd310016, false}}, |
2129 | {tanh_pol_table, {0xb961e269, false}}, |
2130 | {tanh_pol_table, {0x3ba32495, false}}, |
2131 | {tanh_pol_table, {0x3a7680d5, false}}, |
2132 | {tanh_pol_table, {0x38b3173c, false}}, |
2133 | {tanh_pol_table, {0x35a9deea, false}}, |
2134 | {tanh_pol_table, {0x375c3f2a, false}}, |
2135 | {tanh_pol_table, {0x00000000, false}}, |
2136 | // coefficients of degree 6 |
2137 | {tanh_pol_table, {0x00000000, false}}, |
2138 | {tanh_pol_table, {0xd8995ed1, false}}, |
2139 | {tanh_pol_table, {0x558285ea, false}}, |
2140 | {tanh_pol_table, {0x55b2cd69, false}}, |
2141 | {tanh_pol_table, {0x53028625, false}}, |
2142 | {tanh_pol_table, {0x4bc9991f, false}}, |
2143 | {tanh_pol_table, {0x5082898a, false}}, |
2144 | {tanh_pol_table, {0x4b4999b3, false}}, |
2145 | {tanh_pol_table, {0x4e02c07c, false}}, |
2146 | {tanh_pol_table, {0x4ac99764, false}}, |
2147 | {tanh_pol_table, {0x4b72c822, false}}, |
2148 | {tanh_pol_table, {0xca40c0e1, false}}, |
2149 | {tanh_pol_table, {0x489413e4, false}}, |
2150 | {tanh_pol_table, {0x49b12224, false}}, |
2151 | {tanh_pol_table, {0x46134c4e, false}}, |
2152 | {tanh_pol_table, {0xc60c2d57, false}}, |
2153 | {tanh_pol_table, {0x43c83910, false}}, |
2154 | {tanh_pol_table, {0xc3c872d1, false}}, |
2155 | {tanh_pol_table, {0xc186bc9e, false}}, |
2156 | {tanh_pol_table, {0x42325bc3, false}}, |
2157 | {tanh_pol_table, {0xbf2ffa4a, false}}, |
2158 | {tanh_pol_table, {0x3d9a203c, false}}, |
2159 | {tanh_pol_table, {0xbc545a43, false}}, |
2160 | {tanh_pol_table, {0xbae08fee, false}}, |
2161 | {tanh_pol_table, {0x3c80225d, false}}, |
2162 | {tanh_pol_table, {0x3b1fd1df, false}}, |
2163 | {tanh_pol_table, {0xba36b9d1, false}}, |
2164 | {tanh_pol_table, {0xb91de544, false}}, |
2165 | {tanh_pol_table, {0xb71f100f, false}}, |
2166 | {tanh_pol_table, {0xb408e2ed, false}}, |
2167 | {tanh_pol_table, {0xb685fec8, false}}, |
2168 | {tanh_pol_table, {0x00000000, false}}, |
2169 | }; |
2170 | |
2171 | // soft_relu(x) constants |
2172 | static const table_t soft_relu_consts { |
2173 | {soft_relu_one_twenty_six, {0x42fc0000, true}}, |
2174 | {soft_relu_mantissa_sign_mask, {0x807fffff, true}}, |
2175 | }; |
2176 | |
2177 | // soft_relu ln(1 + x) polynomial approximation |
2178 | static const table_t soft_relu_polynomial { |
2179 | {soft_relu_pol, {0xb2b4637d, true}}, // p0 = 0.0000000244f |
2180 | {soft_relu_pol, {0x3f7fff8e, true}}, // p1 = 0.9999976971f |
2181 | {soft_relu_pol, {0xbf001759, true}}, // p2 = -0.5002478215f |
2182 | {soft_relu_pol, {0x3ea70608, true}}, // p3 = 0.3272714505f |
2183 | {soft_relu_pol, {0xbea3d7bf, true}}, // p4 = -0.3153830071f |
2184 | {soft_relu_pol, {0xbe361d04, true}}, // p5 = -0.1701777461f |
2185 | {soft_relu_pol, {0xbfa8f1e6, true}}, // p6 = -1.3254635147f |
2186 | {soft_relu_pol, {0xbfe1e812, true}}, // p7 = -1.7971917960f |
2187 | {soft_relu_pol, {0xbfc4d30e, true}}, // p8 = -1.5652673123f |
2188 | }; |
2189 | |
2190 | // gelu_tanh(x) constants (formula defined) |
static const table_t gelu_tanh_consts {
{gelu_tanh_fitting_const, {0x3d372713, true}}, // 0.044715f
{gelu_tanh_fitting_const_times_three, {0x3e095d4f, true}}, // 3 * 0.044715f = 0.134145f
{gelu_tanh_sqrt_two_over_pi, {0x3f4c422a, true}}, // sqrt(2 / pi) ~ 0.7978846f
};
2196 | |
2197 | // gelu_erf(x) constants for approximation based on Abramowitz and Stegun |
2198 | // algorithm (formula defined) |
static const table_t gelu_erf_Abramowitz_Stegun_consts {
{gelu_erf_Abramowitz_Stegun_approx_const, {0x3ea7ba05, true}}, // ~ 0.3275911f
{gelu_erf_Abramowitz_Stegun_one_over_sqrt_two, {0x3f3504f3, true}}, // 1 / sqrt(2) ~ 0.70710677f
{gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi, {0x3f106eba, true}}, // 1 / sqrt(pi) ~ 0.5641896f
};
2204 | |
2205 | // gelu_erf(x) polynomial approximation based on Abramowitz and Stegun |
2206 | // algorithm |
2207 | static const table_t gelu_erf_Abramowitz_Stegun_polynomial { |
2208 | // p1 = 0.254829592f |
2209 | {gelu_erf_Abramowitz_Stegun_pol, {0x3e827906, true}}, |
2210 | // p2 = -0.284496736f |
2211 | {gelu_erf_Abramowitz_Stegun_pol, {0xbe91a98e, true}}, |
2212 | // p3 = 1.421413741f |
2213 | {gelu_erf_Abramowitz_Stegun_pol, {0x3fb5f0e3, true}}, |
2214 | // p4 = -1.453152027f |
2215 | {gelu_erf_Abramowitz_Stegun_pol, {0xbfba00e3, true}}, |
2216 | // p5 = 1.061405429f |
2217 | {gelu_erf_Abramowitz_Stegun_pol, {0x3f87dc22, true}}, |
2218 | }; |
2219 | |
2220 | // gelu_erf(x) constants for direct erf approximation (formula defined) |
2221 | static const table_t gelu_erf_minimax_consts { |
2222 | // x <= -0x1.4p+2 -> return 0.0f |
2223 | {gelu_erf_minimax_neg_saturation_ubound, {0xc0a00000, true}}, |
2224 | // |x| <= 0x1.0p-24 -> return 0.5f * x |
2225 | {gelu_erf_minimax_linear_ubound, {0x33800000, true}}, |
2226 | // x >= 0x1.4p+2 -> return x |
2227 | {gelu_erf_minimax_saturation_lbound, {0x40a00000, true}}}; |
2228 | |
2229 | // gelu_erf(x) polynomial for direct erf approximation (formula defined) |
2230 | static const table_t gelu_erf_minimax_polynomial { |
2231 | {gelu_erf_minimax_pol, {0x3f4c4228, true}}, // p0 = 0x1.98845p-1 |
2232 | {gelu_erf_minimax_pol, {0xbe082bc7, true}}, // p1 = -0x1.10578ep-3 |
2233 | {gelu_erf_minimax_pol, {0x3ca3621f, true}}, // p2 = 0x1.46c43ep-6 |
2234 | {gelu_erf_minimax_pol, {0xbb1b7399, true}}, // p3 = -0x1.36e732p-9 |
2235 | {gelu_erf_minimax_pol, {0x3970b255, true}}, // p4 = 0x1.e164aap-13 |
2236 | {gelu_erf_minimax_pol, {0xb79b0914, true}}, // p5 = -0x1.361228p-16 |
2237 | {gelu_erf_minimax_pol, {0x35a776e9, true}}, // p6 = 0x1.4eedd2p-20 |
2238 | {gelu_erf_minimax_pol, {0xb3969b11, true}}, // p7 = -0x1.2d3622p-24 |
2239 | {gelu_erf_minimax_pol, {0x315d4a4f, true}}, // p8 = 0x1.ba949ep-29 |
2240 | {gelu_erf_minimax_pol, {0xaf013b2c, true}}, // p9 = -0x1.027658p-33 |
2241 | {gelu_erf_minimax_pol, {0x2c67ddb2, true}}, // p10 = 0x1.cfbb64p-39 |
2242 | {gelu_erf_minimax_pol, {0xa998c963, true}}, // p11 = -0x1.3192c6p-44 |
2243 | {gelu_erf_minimax_pol, {0x268a7927, true}}, // p12 = 0x1.14f24ep-50 |
2244 | {gelu_erf_minimax_pol, {0xa3198977, true}}, // p13 = -0x1.3312eep-57 |
2245 | {gelu_erf_minimax_pol, {0x1f1c83fd, true}}, // p14 = 0x1.3907fap-65 |
2246 | }; |
2247 | |
2248 | // log(x) constants |
static const table_t log_consts {
{log_inf, {0x7f800000, true}}, // +inf (f32 bit pattern)
{log_minus_inf, {0xff800000, true}}, // -inf (f32 bit pattern)
{log_qnan, {0x7fc00000, true}}, // quiet NaN (f32 bit pattern)
{log_mantissa_mask, {0x007fffff, true}}, // the 23 mantissa bits of an f32
// Low 16 bits set; presumably an all-lanes opmask for 16 f32 lanes -- TODO confirm.
{log_full_k_reg_mask, {0x0000ffff, true}},
{log_five_bit_offset, {0x0000001f, true}}, // 31 (low 5 bits set)
};
2257 | |
2258 | // log(x) polynomial approximation |
2259 | static const table_t log_polynomial { |
2260 | {log_pol, {0xbf000000, true}}, // p1 = -0.5f |
2261 | {log_pol, {0x3eaaaaab, true}}, // p2 = 0.333333343f |
2262 | {log_pol, {0xbe8004ab, true}}, // p3 = -0.250035613f |
2263 | {log_pol, {0x3e4cc8a3, true}}, // p4 = 0.199984118f |
2264 | }; |
2265 | |
// log(x) pre-defined values. First goes index, then val[index].
2267 | static const table_t log_predefined_values { |
2268 | {log_predefined_vals, {0x3f800000, true}}, // 0: 1 |
2269 | {log_predefined_vals, |
2270 | {0xc2b00f34, true}}, // 1: -88.029693603515625 |
2271 | {log_predefined_vals, {0x3f780000, true}}, // 2: 0.96875 |
2272 | {log_predefined_vals, |
2273 | {0xc2affef2, true}}, // 3: -87.9979400634765625 |
2274 | {log_predefined_vals, {0x3f700000, true}}, // 4: 0.9375 |
2275 | {log_predefined_vals, |
2276 | {0xc2afee29, true}}, // 5: -87.9651565551757812 |
2277 | {log_predefined_vals, {0x3f680000, true}}, // 6: 0.90625 |
2278 | {log_predefined_vals, |
2279 | {0xc2afdccd, true}}, // 7: -87.9312515258789062 |
2280 | {log_predefined_vals, {0x3f600000, true}}, // 8: 0.875 |
2281 | {log_predefined_vals, |
2282 | {0xc2afcad6, true}}, // 9: -87.8961639404296875 |
2283 | {log_predefined_vals, {0x3f580000, true}}, // 10: 0.84375 |
2284 | {log_predefined_vals, |
2285 | {0xc2afb837, true}}, // 11: -87.859794616699218 |
2286 | {log_predefined_vals, {0x3f580000, true}}, // 12: 0.84375 |
2287 | {log_predefined_vals, |
2288 | {0xc2afb837, true}}, // 13: -87.859794616699218 |
2289 | {log_predefined_vals, {0x3f500000, true}}, // 14: 0.8125 |
2290 | {log_predefined_vals, |
2291 | {0xc2afa4e4, true}}, // 15: -87.822052001953125 |
2292 | {log_predefined_vals, {0x3f480000, true}}, // 16: 0.78125 |
2293 | {log_predefined_vals, |
2294 | {0xc2af90cf, true}}, // 17: -87.782829284667968 |
2295 | {log_predefined_vals, {0x3f480000, true}}, // 18: 0.78125 |
2296 | {log_predefined_vals, |
2297 | {0xc2af90cf, true}}, // 19: -87.782829284667968 |
2298 | {log_predefined_vals, {0x3f400000, true}}, // 20: 0.75 |
2299 | {log_predefined_vals, |
2300 | {0xc2af7be9, true}}, // 21: -87.742012023925781 |
2301 | {log_predefined_vals, {0x3f400000, true}}, // 22: 0.75 |
2302 | {log_predefined_vals, |
2303 | {0xc2af7be9, true}}, // 23: -87.742012023925781 |
2304 | {log_predefined_vals, {0x3f380000, true}}, // 24: 0.71875 |
2305 | {log_predefined_vals, |
2306 | {0xc2af661e, true}}, // 25: -87.699447631835937 |
2307 | {log_predefined_vals, {0x3f380000, true}}, // 26: 0.71875 |
2308 | {log_predefined_vals, |
2309 | {0xc2af661e, true}}, // 27: -87.699447631835937 |
2310 | {log_predefined_vals, {0x3f300000, true}}, // 28: 0.6875 |
2311 | {log_predefined_vals, |
2312 | {0xc2af4f5c, true}}, // 29: -87.654998779296875 |
2313 | {log_predefined_vals, {0x3f300000, true}}, // 30: 0.6875 |
2314 | {log_predefined_vals, |
2315 | {0xc2af4f5c, true}}, // 31: -87.654998779296875 |
2316 | {log_predefined_vals, {0x3fa80000, true}}, // 32: 1.3125 |
2317 | {log_predefined_vals, |
2318 | {0xc2b09a6f, true}}, // 33: -88.301628112792968 |
2319 | {log_predefined_vals, {0x3fa80000, true}}, // 34: 1.3125 |
2320 | {log_predefined_vals, |
2321 | {0xc2b09a6f, true}}, // 35: -88.301628112792968 |
2322 | {log_predefined_vals, {0x3fa00000, true}}, // 36: 1.25 |
2323 | {log_predefined_vals, |
2324 | {0xc2b08174, true}}, // 37: -88.252838134765625 |
2325 | {log_predefined_vals, {0x3fa00000, true}}, // 38: 1.25 |
2326 | {log_predefined_vals, |
2327 | {0xc2b08174, true}}, // 39: -88.252838134765625 |
2328 | {log_predefined_vals, {0x3fa00000, true}}, // 40: 1.25 |
2329 | {log_predefined_vals, |
2330 | {0xc2b08174, true}}, // 41: -88.252838134765625 |
2331 | {log_predefined_vals, {0x3f980000, true}}, // 42: 1.1875 |
2332 | {log_predefined_vals, |
2333 | {0xc2b06731, true}}, // 43: -88.201545715332031 |
2334 | {log_predefined_vals, {0x3f980000, true}}, // 44: 1.1875 |
2335 | {log_predefined_vals, |
2336 | {0xc2b06731, true}}, // 45: -88.201545715332031 |
2337 | {log_predefined_vals, {0x3f900000, true}}, // 46: 1.125 |
2338 | {log_predefined_vals, |
2339 | {0xc2b04b82, true}}, // 47: -88.147476196289062 |
2340 | {log_predefined_vals, {0x3f900000, true}}, // 48: 1.125 |
2341 | {log_predefined_vals, |
2342 | {0xc2b04b82, true}}, // 49: -88.147476196289062 |
2343 | {log_predefined_vals, {0x3f900000, true}}, // 50: 1.125 |
2344 | {log_predefined_vals, |
2345 | {0xc2b04b82, true}}, // 51: -88.147476196289062 |
2346 | {log_predefined_vals, {0x3f900000, true}}, // 52: 1.125 |
2347 | {log_predefined_vals, |
2348 | {0xc2b04b82, true}}, // 53: -88.147476196289062 |
2349 | {log_predefined_vals, {0x3f880000, true}}, // 54: 1.0625 |
2350 | {log_predefined_vals, |
2351 | {0xc2b02e3e, true}}, // 55: -88.090316772460937 |
2352 | {log_predefined_vals, {0x3f880000, true}}, // 56: 1.0625 |
2353 | {log_predefined_vals, |
2354 | {0xc2b02e3e, true}}, // 57: -88.090316772460937 |
2355 | {log_predefined_vals, {0x3f880000, true}}, // 58: 1.0625 |
2356 | {log_predefined_vals, |
2357 | {0xc2b02e3e, true}}, // 59: -88.090316772460937 |
2358 | {log_predefined_vals, {0x3f800000, true}}, // 60: 1 |
2359 | {log_predefined_vals, |
2360 | {0xc2b00f34, true}}, // 61: -88.029693603515625 |
2361 | {log_predefined_vals, {0x3f800000, true}}, // 62: 1 |
2362 | {log_predefined_vals, |
2363 | {0xc2b00f34, true}}, // 63: -88.029693603515625 |
2364 | }; |
2365 | |
// This object determines which constants and polynomials to include.
2367 | struct need_t { |
2368 | need_t(alg_kind_t alg) { |
2369 | using namespace alg_kind; |
2370 | switch (alg) { |
2371 | case eltwise_elu_use_dst_for_bwd: |
2372 | case eltwise_elu: |
2373 | case eltwise_exp_use_dst_for_bwd: |
2374 | case eltwise_exp: |
2375 | case eltwise_logistic_use_dst_for_bwd: |
2376 | case eltwise_logistic: |
2377 | case eltwise_swish: exp_ = true; break; |
2378 | case eltwise_gelu_erf: gelu_erf_ = true; break; |
2379 | case eltwise_gelu_tanh: gelu_tanh_ = true; break; |
2380 | case eltwise_log: log_ = true; break; |
2381 | case eltwise_soft_relu: soft_relu_ = true; break; |
2382 | case eltwise_mish: mish_ = true; break; |
2383 | case eltwise_tanh_use_dst_for_bwd: |
2384 | case eltwise_tanh: tanh_ = true; break; |
2385 | default: break; |
2386 | } |
2387 | } |
2388 | |
2389 | bool exp_ = false; |
2390 | bool mish_ = false; |
2391 | bool tanh_ = false; |
2392 | bool soft_relu_ = false; |
2393 | bool gelu_tanh_ = false; |
2394 | bool gelu_erf_ = false; |
2395 | bool log_ = false; |
2396 | |
2397 | bool exp() const { return exp_ || soft_relu_ || gelu_erf_ || mish_; } |
2398 | bool mish() const { return mish_; } |
2399 | bool tanh() const { return tanh_ || gelu_tanh_; } |
2400 | bool soft_relu() const { return soft_relu_; } |
2401 | bool gelu_tanh() const { return gelu_tanh_; } |
2402 | bool gelu_erf() const { return gelu_erf_; } |
2403 | bool log() const { return log_; } |
2404 | }; |
2405 | |
2406 | need_t need(alg_); |
2407 | |
2408 | auto push_arg_entry_of = [&](const key_t key, const table_entry_val_t val, |
2409 | const bool broadcast) { |
2410 | mapped_table_entry_t te {0, val, broadcast}; |
2411 | entry_map_.insert(std::make_pair(key, te)); |
2412 | }; |
2413 | |
2414 | auto push_entries_of = [&](const table_t &t) { |
2415 | for (auto it = t.begin(); it != t.end(); it++) { |
2416 | auto key = (*it).first; |
2417 | auto te = (*it).second; // copy values from table |
2418 | push_arg_entry_of(key, te.val, te.bcast); |
2419 | } |
2420 | }; |
2421 | |
2422 | push_arg_entry_of(scale, float2int(scale_), true); |
2423 | push_arg_entry_of(alpha, float2int(alpha_), true); |
2424 | push_arg_entry_of(beta, float2int(beta_), true); |
2425 | push_entries_of(common_values); |
2426 | if (need.exp()) push_entries_of(exp_consts); |
2427 | if (need.exp()) push_entries_of(exp_polynomial); |
2428 | if (need.mish()) push_entries_of(mish_consts); |
2429 | if (need.tanh()) push_entries_of(tanh_consts); |
2430 | if (need.tanh()) push_entries_of(tanh_polynomial_table); |
2431 | if (need.soft_relu()) push_entries_of(soft_relu_consts); |
2432 | if (need.soft_relu()) push_entries_of(soft_relu_polynomial); |
2433 | if (need.gelu_tanh()) push_entries_of(gelu_tanh_consts); |
2434 | if (need.gelu_erf()) push_entries_of(gelu_erf_Abramowitz_Stegun_consts); |
2435 | if (need.gelu_erf()) push_entries_of(gelu_erf_Abramowitz_Stegun_polynomial); |
2436 | if (need.gelu_erf() && is_superset(isa, avx512_core)) |
2437 | push_entries_of(gelu_erf_minimax_consts); |
2438 | if (need.gelu_erf() && is_superset(isa, avx512_core)) |
2439 | push_entries_of(gelu_erf_minimax_polynomial); |
2440 | |
2441 | if (need.log()) push_entries_of(log_consts); |
2442 | if (need.log()) push_entries_of(log_polynomial); |
2443 | if (need.log()) push_entries_of(log_predefined_values); |
2444 | |
2445 | // Now that we registered the entries, we set the offsets. No |
2446 | // entries should be registered after this point. This allows to |
2447 | // expect the same order when injecting the table entries in |
2448 | // prepare_table. |
2449 | size_t off = 0; |
2450 | for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { |
2451 | auto &te = (*it).second; |
2452 | te.off = off; |
2453 | off += te.bcast ? vlen : sizeof(table_entry_val_t); |
2454 | } |
2455 | } |
2456 | |
2457 | template struct jit_uni_eltwise_injector_f32<avx512_core_fp16>; |
2458 | template struct jit_uni_eltwise_injector_f32<avx512_core_fp16, Xbyak::Ymm>; |
2459 | template struct jit_uni_eltwise_injector_f32<avx512_core_fp16, Xbyak::Xmm>; |
2460 | template struct jit_uni_eltwise_injector_f32<avx512_core_bf16>; |
2461 | template struct jit_uni_eltwise_injector_f32<avx512_core>; |
2462 | template struct jit_uni_eltwise_injector_f32<avx512_core, Ymm>; |
2463 | template struct jit_uni_eltwise_injector_f32<avx512_core, Xmm>; |
2464 | template struct jit_uni_eltwise_injector_f32<avx2_vnni_2>; |
2465 | template struct jit_uni_eltwise_injector_f32<avx2>; |
2466 | template struct jit_uni_eltwise_injector_f32<avx2, Xmm>; |
2467 | template struct jit_uni_eltwise_injector_f32<avx>; |
2468 | template struct jit_uni_eltwise_injector_f32<avx, Xmm>; |
2469 | template struct jit_uni_eltwise_injector_f32<sse41>; |
2470 | |
2471 | } // namespace x64 |
2472 | } // namespace cpu |
2473 | } // namespace impl |
2474 | } // namespace dnnl |
2475 | |