1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_ELTWISE_INJECTOR_HPP
18#define CPU_X64_JIT_UNI_ELTWISE_INJECTOR_HPP
19
#include <assert.h>
#include <cstdint>
#include <map>
#include <type_traits>
22
23#include "common/c_types_map.hpp"
24#include "common/primitive_attr.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27
28#include "cpu/x64/injectors/injector_utils.hpp"
29#include "cpu/x64/jit_generator.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36namespace eltwise_injector {
37struct static_params_t {
38
39 static_params_t(bool save_state = true,
40 Xbyak::Reg64 p_table = Xbyak::util::rax,
41 Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true,
42 bool use_dst = false, bool preserve_vmm = true,
43 bool preserve_p_table = true)
44 : save_state(save_state)
45 , p_table(p_table)
46 , k_mask(k_mask)
47 , is_fwd(is_fwd)
48 , use_dst(use_dst)
49 , preserve_vmm(preserve_vmm)
50 , preserve_p_table(preserve_p_table) {}
51
52 bool save_state;
53 Xbyak::Reg64 p_table;
54 Xbyak::Opmask k_mask;
55 bool is_fwd;
56 bool use_dst;
57 bool preserve_vmm;
58 bool preserve_p_table;
59};
60
61/*
62 * Checks if isa is supported by eltwise injector.
63 */
64bool is_isa_supported(cpu_isa_t isa);
65
66/*
67 * Checks if eltwise algorithm is supported by eltwise injector.
68 */
69bool is_alg_supported(alg_kind_t alg);
70
71/*
72 * Checks if eltwise injection for given args is supported.
73 */
74bool is_supported(cpu_isa_t isa, alg_kind_t alg);
75
76} // namespace eltwise_injector
77
78template <cpu_isa_t isa, typename Wmm = typename cpu_isa_traits<isa>::Vmm>
79struct jit_uni_eltwise_injector_f32 {
80 using Vmm = Wmm;
81
82 // Arguments description:
83 // host - jit generator which is filled with instructions
84 // alg, alpha, beta, scale - user eltwise arguments
85 // save_state - when true, preserves on stack vmm_aux registers preventing
86 // results spoiling. Restores them when done in injector_postamble().
87 // p_table - GPR where table label is stored to get access for pre-defined
88 // constants used in alg codes.
89 // k_mask - k_register to operate with masks in alg codes.
90 // is_fwd - when true, computes d = alg(s), otherwise, computes ds = alg'(s)
91 // - algorithm derivative.
92 // use_dst - defines whether source or destination point is passed to alg
93 // code. Depends on algorithm. See `_use_dst_for_bwd` algs definition.
94 jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
95 float alpha, float beta, float scale, bool save_state = true,
96 Xbyak::Reg64 p_table = Xbyak::util::rax,
97 Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true,
98 bool use_dst = false, bool preserve_vmm = true,
99 bool preserve_p_table = true)
100 : alg_(alg)
101 , alpha_(alpha)
102 , beta_(beta)
103 , scale_(scale)
104 , h(host)
105 , save_state_(save_state)
106 , p_table(p_table)
107 , k_mask(k_mask)
108 , is_fwd_(is_fwd)
109 , use_dst_(use_dst)
110 , preserve_vmm_(preserve_vmm)
111 , preserve_p_table_(preserve_p_table) {
112 assert(eltwise_injector::is_supported(isa, alg_));
113
114 register_table_entries();
115 }
116
117 jit_uni_eltwise_injector_f32(jit_generator *host,
118 const post_ops_t::entry_t::eltwise_t &eltwise,
119 bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
120 Xbyak::Opmask k_mask = Xbyak::Opmask(1), bool is_fwd = true,
121 bool use_dst = false, bool preserve_vmm = true,
122 bool preserve_p_table = true)
123 : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
124 eltwise.beta, eltwise.scale, save_state, p_table, k_mask,
125 is_fwd, use_dst, preserve_vmm, preserve_p_table) {}
126
127 void compute_vector_range(size_t start_idx, size_t end_idx);
128 void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs);
129 void compute_vector(size_t idx) { compute_vector_range({idx}); }
130 void prepare_table(bool gen_table = true);
131 void load_table_addr() { h->mov(p_table, l_table); }
132
133private:
134 const alg_kind_t alg_;
135 const float alpha_;
136 const float beta_;
137 const float scale_;
138
139 jit_generator *const h;
140
141 const bool save_state_;
142 const Xbyak::Reg64 p_table;
143 const Xbyak::Opmask k_mask;
144 const bool is_fwd_;
145 const bool use_dst_;
146 const bool preserve_vmm_;
147 const bool preserve_p_table_;
148
149 Xbyak::Label l_table;
150
151 // if only the injector was inherited from jit_generator...
152 enum {
153 _cmp_eq_oq = jit_generator::_cmp_eq_oq,
154 _cmp_neq_uq = jit_generator::_cmp_neq_uq,
155 _cmp_lt_os = jit_generator::_cmp_lt_os,
156 _cmp_le_os = jit_generator::_cmp_le_os,
157 _cmp_ge_os = jit_generator::_cmp_nlt_us,
158 _cmp_gt_os = jit_generator::_cmp_nle_us,
159 _op_floor = jit_generator::_op_floor,
160 _op_mxcsr = jit_generator::_op_mxcsr
161 };
162
163 const bool is_avx512 = is_superset(isa, avx512_core);
164
165 static constexpr size_t vlen = vreg_traits<Vmm>::vlen;
166 static constexpr size_t preserved_vecs_max = 6;
167 static constexpr size_t preserved_gprs_max = 5;
168 static constexpr size_t vecs_count = cpu_isa_traits<isa>::n_vregs;
169 static constexpr int n_mantissa_bits = 23;
170 static constexpr int k_mask_size = 8;
171
172 bool preserve_vec_for_avx = false;
173
174 size_t vecs_to_preserve = 0;
175 size_t preserved_vecs_count = 0;
176 size_t preserved_vec_idxs[preserved_vecs_max] = {0};
177 size_t preserved_gpr_idxs[preserved_gprs_max] = {0};
178 injector_utils::vmm_index_set_iterator_t start_idx_tail;
179
180 Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4, vmm_tmp;
181 Xbyak::Ymm ymm_tmp;
182 Xbyak::Xmm xmm_tmp;
183
184 size_t aux_vecs_count();
185 size_t aux_gprs_count();
186
187 void compute_body(
188 const injector_utils::vmm_index_set_iterator_t &start_idx_it,
189 const injector_utils::vmm_index_set_iterator_t &end_idx_it);
190 void injector_preamble(const injector_utils::vmm_index_set_t &vmm_idxs);
191 void injector_preamble_tail(
192 const injector_utils::vmm_index_set_iterator_t start_idx_it);
193 void injector_postamble();
194 void assign_regs();
195 void vec_shift(const Vmm &vmm_dst, const Vmm &vmm_src, bool shift_left,
196 const int imm);
197 void compute_cmp_mask(const Vmm &vmm_src,
198 const Xbyak::Operand &compare_operand, int cmp_predicate);
199 void blend_with_mask(const Vmm &vmm_dst, const Xbyak::Operand &src);
200 void test_mask();
201
202 void exp_compute_vector_fwd(const Vmm &vmm_src);
203 void relu_compute_vector_fwd(const Vmm &vmm_src);
204 void relu_zero_ns_compute_vector_fwd(const Vmm &vmm_src);
205 void elu_compute_vector_fwd(const Vmm &vmm_src);
206 void tanh_compute_vector_fwd(const Vmm &vmm_src);
207 void square_compute_vector_fwd(const Vmm &vmm_src);
208 void abs_compute_vector_fwd(const Vmm &vmm_src);
209 void sqrt_compute_vector_fwd(const Vmm &vmm_src);
210 void linear_compute_vector_fwd(const Vmm &vmm_src);
211 void soft_relu_compute_vector_fwd(const Vmm &vmm_src);
212 void mish_compute_vector_fwd(const Vmm &vmm_src);
213 void logistic_compute_vector_fwd(const Vmm &vmm_src);
214 void gelu_tanh_compute_vector_fwd(const Vmm &vmm_src);
215 void swish_compute_vector_fwd(const Vmm &vmm_src);
216 void log_compute_vector_fwd(const Vmm &vmm_src);
217 void clip_compute_vector_fwd(const Vmm &vmm_src);
218 void pow_compute_vector_fwd(const Vmm &vmm_src);
219 void gelu_erf_compute_vector_fwd(const Vmm &vmm_src);
220 void gelu_erf_minimax_approx_compute_vector_fwd(const Vmm &vmm_src);
221 void round_compute_vector_fwd(const Vmm &vmm_src);
222 void hardswish_compute_vector_fwd(const Vmm &vmm_src);
223 void hardsigmoid_compute_vector_fwd(const Vmm &vmm_src);
224
225 void exp_compute_vector_bwd(const Vmm &vmm_src);
226 void relu_compute_vector_bwd(const Vmm &vmm_src);
227 void elu_compute_vector_bwd(const Vmm &vmm_src);
228 void tanh_compute_vector_bwd(const Vmm &vmm_src);
229 void square_compute_vector_bwd(const Vmm &vmm_src);
230 void abs_compute_vector_bwd(const Vmm &vmm_src);
231 void sqrt_compute_vector_bwd(const Vmm &vmm_src);
232 void linear_compute_vector_bwd(const Vmm &vmm_src);
233 void soft_relu_compute_vector_bwd(const Vmm &vmm_src);
234 void logistic_compute_vector_bwd(const Vmm &vmm_src);
235 void mish_compute_vector_bwd(const Vmm &vmm_src);
236 void gelu_tanh_compute_vector_bwd(const Vmm &vmm_src);
237 void swish_compute_vector_bwd(const Vmm &vmm_src);
238 void log_compute_vector_bwd(const Vmm &vmm_src);
239 void clip_compute_vector_bwd(const Vmm &vmm_src);
240 void pow_compute_vector_bwd(const Vmm &vmm_src);
241 void gelu_erf_compute_vector_bwd(const Vmm &vmm_src);
242 void hardswish_compute_vector_bwd(const Vmm &vmm_src);
243 void hardsigmoid_compute_vector_bwd(const Vmm &vmm_src);
244
245 enum key_t {
246 scale = 0, // scale argument
247 alpha, // alpha argument
248 beta, // beta argument
249 zero, // 0.f
250 half, // 0.5f
251 one, // 1.f or mask for exponent bits
252 two, // 2.f
253 three, // 3.f
254 six, // 6.f
255 minus_one, // -1.f or changes sign to opposite
256 minus_two, // -2.f
257 minus_three, // -3.f
258 ln2f, // 0.69314718f
259 positive_mask, // changes sign to positive
260 sign_mask, // gets sign value
261 exponent_bias, // (127 = 2^7 - 1), gets exponent bits
262 exp_log2ef, // 1.44269502f - formula-based for approx
263 exp_ln_flt_max_f, // logf(FLT_MAX) - max normal value
264 exp_ln_flt_min_f, // logf(FLT_MIN) - min normal value
265 exp_pol, // see correspondent table for float values
266 // e^(2*x)+2*e^x+2 = FLT_MAX; x =~ 44.36141952603634
267 fwd_mish_max_x_for_equation_f,
268 // e^x(e^3x+4e^2x+e^x*(6+4*x)+4*(1+x)) = FLT_MAX; x =~ 22.18070976278534
269 bwd_mish_max_x_for_equation_f,
270 tanh_idx_bias, // bias applied during index computation
271 tanh_idx_mask, // mask applied to extract index
272 tanh_linear_ubound, // arg below which tanh(x) = x
273 tanh_saturation_lbound, // arg after which tanh(x) = 1.f
274 tanh_pol_table, // table of polynomial coefficients
275 soft_relu_one_twenty_six, // 126.f
276 soft_relu_mantissa_sign_mask, // mask for mantissa bits and sign
277 soft_relu_pol, // see correspondent table for float values
278 gelu_tanh_fitting_const, // 0.044715f
279 gelu_tanh_fitting_const_times_three, // 0.134145f
280 gelu_tanh_sqrt_two_over_pi, // sqrtf(2.f/pi) = 0.797884f
281 // 0.3275911f - implementation based for approx
282 gelu_erf_Abramowitz_Stegun_approx_const,
283 gelu_erf_Abramowitz_Stegun_one_over_sqrt_two, // 1.f / sqrtf(2.f)
284 // 1.f / sqrtf(pi) = 0.564190f
285 gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi,
286 // see correspondent table for float values
287 gelu_erf_Abramowitz_Stegun_pol,
288 gelu_erf_minimax_neg_saturation_ubound, // x <= arg => gelu_erf = 0.0f
289 // when |x| <= arg => gelu_erf = 0.5f * x
290 gelu_erf_minimax_linear_ubound,
291 gelu_erf_minimax_saturation_lbound, // x >= arg => gelu_erf = x
292 gelu_erf_minimax_pol, // see correspondent table for float values
293 log_inf, // inf
294 log_minus_inf, // -inf
295 log_qnan, // qnan
296 log_mantissa_mask, // gets mantissa bits
297 log_full_k_reg_mask, // sets k_register with all bits of 1
298 log_full_vector_reg_mask, // sets vector register will all bits of 1
299 log_five_bit_offset, // 5 bits off (31 = 2^5 - 1)
300 log_pol, // see correspondent table for float values
301 log_predefined_vals, // see correspondent table for float values
302 undef_key,
303 };
304
305 size_t table_off(key_t key, size_t key_off_val_shift = 0) {
306 // assumption: all table entries sharing the same key also
307 // share their broadcast property
308 // TODO: enforce through data structure
309 const auto it = entry_map_.find(key); // search an entry for a key
310 assert(it != entry_map_.end());
311 const auto &te = (*it).second;
312 const auto scale = te.bcast ? vlen : sizeof(table_entry_val_t);
313 return te.off + key_off_val_shift * scale;
314 }
315 Xbyak::Address table_val(key_t key, size_t key_off_val_shift = 0) {
316 auto off = table_off(key, key_off_val_shift);
317 return h->ptr[p_table + off];
318 }
319
320 // we accept only 32bit hexadecimal table values to avoid any rounding
321 using table_entry_val_t = uint32_t;
322 using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table
323 using table_entry_bcast_t = bool; // true => bcast value
324
325 struct table_entry_t {
326 table_entry_val_t val;
327 table_entry_bcast_t bcast;
328 };
329 struct mapped_table_entry_t {
330 table_entry_offset_t off;
331 table_entry_val_t val;
332 table_entry_bcast_t bcast;
333 };
334
335 using table_t = std::multimap<key_t, table_entry_t>;
336 using mapped_table_t = std::multimap<key_t, mapped_table_entry_t>;
337
338 void register_table_entries();
339 mapped_table_t entry_map_;
340};
341
342} // namespace x64
343} // namespace cpu
344} // namespace impl
345} // namespace dnnl
346
347#endif
348