1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_ELTWISE_INJECTOR_HPP |
18 | #define CPU_X64_JIT_UNI_ELTWISE_INJECTOR_HPP |
19 | |
#include <assert.h>
#include <map>
#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/primitive_attr.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/jit_generator.hpp"
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | namespace eltwise_injector { |
37 | struct static_params_t { |
38 | |
39 | static_params_t(bool save_state = true, |
40 | Xbyak::Reg64 p_table = Xbyak::util::rax, |
41 | Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true, |
42 | bool use_dst = false, bool preserve_vmm = true, |
43 | bool preserve_p_table = true) |
44 | : save_state(save_state) |
45 | , p_table(p_table) |
46 | , k_mask(k_mask) |
47 | , is_fwd(is_fwd) |
48 | , use_dst(use_dst) |
49 | , preserve_vmm(preserve_vmm) |
50 | , preserve_p_table(preserve_p_table) {} |
51 | |
52 | bool save_state; |
53 | Xbyak::Reg64 p_table; |
54 | Xbyak::Opmask k_mask; |
55 | bool is_fwd; |
56 | bool use_dst; |
57 | bool preserve_vmm; |
58 | bool preserve_p_table; |
59 | }; |
60 | |
61 | /* |
62 | * Checks if isa is supported by eltwise injector. |
63 | */ |
64 | bool is_isa_supported(cpu_isa_t isa); |
65 | |
66 | /* |
67 | * Checks if eltwise algorithm is supported by eltwise injector. |
68 | */ |
69 | bool is_alg_supported(alg_kind_t alg); |
70 | |
71 | /* |
72 | * Checks if eltwise injection for given args is supported. |
73 | */ |
74 | bool is_supported(cpu_isa_t isa, alg_kind_t alg); |
75 | |
76 | } // namespace eltwise_injector |
77 | |
78 | template <cpu_isa_t isa, typename Wmm = typename cpu_isa_traits<isa>::Vmm> |
79 | struct jit_uni_eltwise_injector_f32 { |
80 | using Vmm = Wmm; |
81 | |
82 | // Arguments description: |
83 | // host - jit generator which is filled with instructions |
84 | // alg, alpha, beta, scale - user eltwise arguments |
85 | // save_state - when true, preserves on stack vmm_aux registers preventing |
86 | // results spoiling. Restores them when done in injector_postamble(). |
87 | // p_table - GPR where table label is stored to get access for pre-defined |
88 | // constants used in alg codes. |
89 | // k_mask - k_register to operate with masks in alg codes. |
90 | // is_fwd - when true, computes d = alg(s), otherwise, computes ds = alg'(s) |
91 | // - algorithm derivative. |
92 | // use_dst - defines whether source or destination point is passed to alg |
93 | // code. Depends on algorithm. See `_use_dst_for_bwd` algs definition. |
94 | jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, |
95 | float alpha, float beta, float scale, bool save_state = true, |
96 | Xbyak::Reg64 p_table = Xbyak::util::rax, |
97 | Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true, |
98 | bool use_dst = false, bool preserve_vmm = true, |
99 | bool preserve_p_table = true) |
100 | : alg_(alg) |
101 | , alpha_(alpha) |
102 | , beta_(beta) |
103 | , scale_(scale) |
104 | , h(host) |
105 | , save_state_(save_state) |
106 | , p_table(p_table) |
107 | , k_mask(k_mask) |
108 | , is_fwd_(is_fwd) |
109 | , use_dst_(use_dst) |
110 | , preserve_vmm_(preserve_vmm) |
111 | , preserve_p_table_(preserve_p_table) { |
112 | assert(eltwise_injector::is_supported(isa, alg_)); |
113 | |
114 | register_table_entries(); |
115 | } |
116 | |
117 | jit_uni_eltwise_injector_f32(jit_generator *host, |
118 | const post_ops_t::entry_t::eltwise_t &eltwise, |
119 | bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax, |
120 | Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true, |
121 | bool use_dst = false, bool preserve_vmm = true, |
122 | bool preserve_p_table = true) |
123 | : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, |
124 | eltwise.beta, eltwise.scale, save_state, p_table, k_mask, |
125 | is_fwd, use_dst, preserve_vmm, preserve_p_table) {} |
126 | |
127 | void compute_vector_range(size_t start_idx, size_t end_idx); |
128 | void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs); |
129 | void compute_vector(size_t idx) { compute_vector_range({idx}); } |
130 | void prepare_table(bool gen_table = true); |
131 | void load_table_addr() { h->mov(p_table, l_table); } |
132 | |
133 | private: |
134 | const alg_kind_t alg_; |
135 | const float alpha_; |
136 | const float beta_; |
137 | const float scale_; |
138 | |
139 | jit_generator *const h; |
140 | |
141 | const bool save_state_; |
142 | const Xbyak::Reg64 p_table; |
143 | const Xbyak::Opmask k_mask; |
144 | const bool is_fwd_; |
145 | const bool use_dst_; |
146 | const bool preserve_vmm_; |
147 | const bool preserve_p_table_; |
148 | |
149 | Xbyak::Label l_table; |
150 | |
151 | // if only the injector was inherited from jit_generator... |
152 | enum { |
153 | _cmp_eq_oq = jit_generator::_cmp_eq_oq, |
154 | _cmp_neq_uq = jit_generator::_cmp_neq_uq, |
155 | _cmp_lt_os = jit_generator::_cmp_lt_os, |
156 | _cmp_le_os = jit_generator::_cmp_le_os, |
157 | _cmp_ge_os = jit_generator::_cmp_nlt_us, |
158 | _cmp_gt_os = jit_generator::_cmp_nle_us, |
159 | _op_floor = jit_generator::_op_floor, |
160 | _op_mxcsr = jit_generator::_op_mxcsr |
161 | }; |
162 | |
163 | const bool is_avx512 = is_superset(isa, avx512_core); |
164 | |
165 | static constexpr size_t vlen = vreg_traits<Vmm>::vlen; |
166 | static constexpr size_t preserved_vecs_max = 6; |
167 | static constexpr size_t preserved_gprs_max = 5; |
168 | static constexpr size_t vecs_count = cpu_isa_traits<isa>::n_vregs; |
169 | static constexpr int n_mantissa_bits = 23; |
170 | static constexpr int k_mask_size = 8; |
171 | |
172 | bool preserve_vec_for_avx = false; |
173 | |
174 | size_t vecs_to_preserve = 0; |
175 | size_t preserved_vecs_count = 0; |
176 | size_t preserved_vec_idxs[preserved_vecs_max] = {0}; |
177 | size_t preserved_gpr_idxs[preserved_gprs_max] = {0}; |
178 | injector_utils::vmm_index_set_iterator_t start_idx_tail; |
179 | |
180 | Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4, vmm_tmp; |
181 | Xbyak::Ymm ymm_tmp; |
182 | Xbyak::Xmm xmm_tmp; |
183 | |
184 | size_t aux_vecs_count(); |
185 | size_t aux_gprs_count(); |
186 | |
187 | void compute_body( |
188 | const injector_utils::vmm_index_set_iterator_t &start_idx_it, |
189 | const injector_utils::vmm_index_set_iterator_t &end_idx_it); |
190 | void injector_preamble(const injector_utils::vmm_index_set_t &vmm_idxs); |
191 | void injector_preamble_tail( |
192 | const injector_utils::vmm_index_set_iterator_t start_idx_it); |
193 | void injector_postamble(); |
194 | void assign_regs(); |
195 | void vec_shift(const Vmm &vmm_dst, const Vmm &vmm_src, bool shift_left, |
196 | const int imm); |
197 | void compute_cmp_mask(const Vmm &vmm_src, |
198 | const Xbyak::Operand &compare_operand, int cmp_predicate); |
199 | void blend_with_mask(const Vmm &vmm_dst, const Xbyak::Operand &src); |
200 | void test_mask(); |
201 | |
202 | void exp_compute_vector_fwd(const Vmm &vmm_src); |
203 | void relu_compute_vector_fwd(const Vmm &vmm_src); |
204 | void relu_zero_ns_compute_vector_fwd(const Vmm &vmm_src); |
205 | void elu_compute_vector_fwd(const Vmm &vmm_src); |
206 | void tanh_compute_vector_fwd(const Vmm &vmm_src); |
207 | void square_compute_vector_fwd(const Vmm &vmm_src); |
208 | void abs_compute_vector_fwd(const Vmm &vmm_src); |
209 | void sqrt_compute_vector_fwd(const Vmm &vmm_src); |
210 | void linear_compute_vector_fwd(const Vmm &vmm_src); |
211 | void soft_relu_compute_vector_fwd(const Vmm &vmm_src); |
212 | void mish_compute_vector_fwd(const Vmm &vmm_src); |
213 | void logistic_compute_vector_fwd(const Vmm &vmm_src); |
214 | void gelu_tanh_compute_vector_fwd(const Vmm &vmm_src); |
215 | void swish_compute_vector_fwd(const Vmm &vmm_src); |
216 | void log_compute_vector_fwd(const Vmm &vmm_src); |
217 | void clip_compute_vector_fwd(const Vmm &vmm_src); |
218 | void pow_compute_vector_fwd(const Vmm &vmm_src); |
219 | void gelu_erf_compute_vector_fwd(const Vmm &vmm_src); |
220 | void gelu_erf_minimax_approx_compute_vector_fwd(const Vmm &vmm_src); |
221 | void round_compute_vector_fwd(const Vmm &vmm_src); |
222 | void hardswish_compute_vector_fwd(const Vmm &vmm_src); |
223 | void hardsigmoid_compute_vector_fwd(const Vmm &vmm_src); |
224 | |
225 | void exp_compute_vector_bwd(const Vmm &vmm_src); |
226 | void relu_compute_vector_bwd(const Vmm &vmm_src); |
227 | void elu_compute_vector_bwd(const Vmm &vmm_src); |
228 | void tanh_compute_vector_bwd(const Vmm &vmm_src); |
229 | void square_compute_vector_bwd(const Vmm &vmm_src); |
230 | void abs_compute_vector_bwd(const Vmm &vmm_src); |
231 | void sqrt_compute_vector_bwd(const Vmm &vmm_src); |
232 | void linear_compute_vector_bwd(const Vmm &vmm_src); |
233 | void soft_relu_compute_vector_bwd(const Vmm &vmm_src); |
234 | void logistic_compute_vector_bwd(const Vmm &vmm_src); |
235 | void mish_compute_vector_bwd(const Vmm &vmm_src); |
236 | void gelu_tanh_compute_vector_bwd(const Vmm &vmm_src); |
237 | void swish_compute_vector_bwd(const Vmm &vmm_src); |
238 | void log_compute_vector_bwd(const Vmm &vmm_src); |
239 | void clip_compute_vector_bwd(const Vmm &vmm_src); |
240 | void pow_compute_vector_bwd(const Vmm &vmm_src); |
241 | void gelu_erf_compute_vector_bwd(const Vmm &vmm_src); |
242 | void hardswish_compute_vector_bwd(const Vmm &vmm_src); |
243 | void hardsigmoid_compute_vector_bwd(const Vmm &vmm_src); |
244 | |
245 | enum key_t { |
246 | scale = 0, // scale argument |
247 | alpha, // alpha argument |
248 | beta, // beta argument |
249 | zero, // 0.f |
250 | half, // 0.5f |
251 | one, // 1.f or mask for exponent bits |
252 | two, // 2.f |
253 | three, // 3.f |
254 | six, // 6.f |
255 | minus_one, // -1.f or changes sign to opposite |
256 | minus_two, // -2.f |
257 | minus_three, // -3.f |
258 | ln2f, // 0.69314718f |
259 | positive_mask, // changes sign to positive |
260 | sign_mask, // gets sign value |
261 | exponent_bias, // (127 = 2^7 - 1), gets exponent bits |
262 | exp_log2ef, // 1.44269502f - formula-based for approx |
263 | exp_ln_flt_max_f, // logf(FLT_MAX) - max normal value |
264 | exp_ln_flt_min_f, // logf(FLT_MIN) - min normal value |
265 | exp_pol, // see correspondent table for float values |
266 | // e^(2*x)+2*e^x+2 = FLT_MAX; x =~ 44.36141952603634 |
267 | fwd_mish_max_x_for_equation_f, |
268 | // e^x(e^3x+4e^2x+e^x*(6+4*x)+4*(1+x)) = FLT_MAX; x =~ 22.18070976278534 |
269 | bwd_mish_max_x_for_equation_f, |
270 | tanh_idx_bias, // bias applied during index computation |
271 | tanh_idx_mask, // mask applied to extract index |
272 | tanh_linear_ubound, // arg below which tanh(x) = x |
273 | tanh_saturation_lbound, // arg after which tanh(x) = 1.f |
274 | tanh_pol_table, // table of polynomial coefficients |
275 | soft_relu_one_twenty_six, // 126.f |
276 | soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign |
277 | soft_relu_pol, // see correspondent table for float values |
278 | gelu_tanh_fitting_const, // 0.044715f |
279 | gelu_tanh_fitting_const_times_three, // 0.134145f |
280 | gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f |
281 | // 0.3275911f - implementation based for approx |
282 | gelu_erf_Abramowitz_Stegun_approx_const, |
283 | gelu_erf_Abramowitz_Stegun_one_over_sqrt_two, // 1.f / sqrtf(2.f) |
284 | // 1.f / sqrtf(pi) = 0.564190f |
285 | gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi, |
286 | // see correspondent table for float values |
287 | gelu_erf_Abramowitz_Stegun_pol, |
288 | gelu_erf_minimax_neg_saturation_ubound, // x <= arg => gelu_erf = 0.0f |
289 | // when |x| <= arg => gelu_erf = 0.5f * x |
290 | gelu_erf_minimax_linear_ubound, |
291 | gelu_erf_minimax_saturation_lbound, // x >= arg => gelu_erf = x |
292 | gelu_erf_minimax_pol, // see correspondent table for float values |
293 | log_inf, // inf |
294 | log_minus_inf, // -inf |
295 | log_qnan, // qnan |
296 | log_mantissa_mask, // gets mantissa bits |
297 | log_full_k_reg_mask, // sets k_register with all bits of 1 |
298 | log_full_vector_reg_mask, // sets vector register will all bits of 1 |
299 | log_five_bit_offset, // 5 bits off (31 = 2^5 - 1) |
300 | log_pol, // see correspondent table for float values |
301 | log_predefined_vals, // see correspondent table for float values |
302 | undef_key, |
303 | }; |
304 | |
305 | size_t table_off(key_t key, size_t key_off_val_shift = 0) { |
306 | // assumption: all table entries sharing the same key also |
307 | // share their broadcast property |
308 | // TODO: enforce through data structure |
309 | const auto it = entry_map_.find(key); // search an entry for a key |
310 | assert(it != entry_map_.end()); |
311 | const auto &te = (*it).second; |
312 | const auto scale = te.bcast ? vlen : sizeof(table_entry_val_t); |
313 | return te.off + key_off_val_shift * scale; |
314 | } |
315 | Xbyak::Address table_val(key_t key, size_t key_off_val_shift = 0) { |
316 | auto off = table_off(key, key_off_val_shift); |
317 | return h->ptr[p_table + off]; |
318 | } |
319 | |
320 | // we accept only 32bit hexadecimal table values to avoid any rounding |
321 | using table_entry_val_t = uint32_t; |
322 | using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table |
323 | using table_entry_bcast_t = bool; // true => bcast value |
324 | |
325 | struct table_entry_t { |
326 | table_entry_val_t val; |
327 | table_entry_bcast_t bcast; |
328 | }; |
329 | struct mapped_table_entry_t { |
330 | table_entry_offset_t off; |
331 | table_entry_val_t val; |
332 | table_entry_bcast_t bcast; |
333 | }; |
334 | |
335 | using table_t = std::multimap<key_t, table_entry_t>; |
336 | using mapped_table_t = std::multimap<key_t, mapped_table_entry_t>; |
337 | |
338 | void register_table_entries(); |
339 | mapped_table_t entry_map_; |
340 | }; |
341 | |
342 | } // namespace x64 |
343 | } // namespace cpu |
344 | } // namespace impl |
345 | } // namespace dnnl |
346 | |
347 | #endif |
348 | |