1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | |
19 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
20 | #include "cpu/x64/jit_uni_binary_kernel.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
27 | #define PARAM_OFF(x) offsetof(jit_binary_call_s, x) |
28 | |
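// Broadcasting strategies supported for binary post-ops in this kernel;
// passed to the binary injector to validate rhs-operand layouts.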
29 | static bcast_set_t get_supported_postops_bcast_strategies() { |
30 | return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
31 | broadcasting_strategy_t::per_oc_spatial, |
32 | broadcasting_strategy_t::no_broadcast}; |
33 | } |
34 | |
35 | binary_kernel_t::binary_kernel_t(const size_t vlen, const binary_pd_t *pd, |
36 | const jit_binary_conf_t conf, const char *name, bool tail_kernel) |
37 | : jit_generator(name) |
38 | , vlen_(vlen) |
39 | , simd_w_(vlen / sizeof(float)) |
40 | , pd_(pd) |
41 | , conf_(conf) |
42 | , is_tail_kernel_(tail_kernel) |
43 | , is_src1_outer_dims_tail_( |
44 | conf_.is_src_different_layouts && conf_.outer_dims % simd_w_) |
45 | , tail_size_(get_tail_size()) |
46 | , padding_tail_size_( |
47 | pd->src_md(0)->padded_dims[1] - pd->src_md(0)->dims[1]) {} |
48 | |
49 | size_t binary_kernel_t::get_tail_size() const { |
50 | memory_desc_wrapper src0_d(pd_->src_md(0)); |
51 | const auto &dims = src0_d.dims(); |
52 | const auto &ndims = src0_d.ndims(); |
53 | |
54 | dim_t nelems = 0; |
55 | |
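    // Pick the element count whose remainder modulo simd_w_ forms the tail;
    // it depends on the layout/broadcast configuration: the whole tensor,
    // a single outer-dims slice, the channel dim, a single batch, or the
    // spatial dims, as selected below.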
56 | if (ndims == 1) |
57 | nelems = dims[0]; |
58 | else if (is_src1_outer_dims_tail_) |
59 | nelems = conf_.outer_dims; |
60 | else if (!conf_.is_i8 && conf_.op_type == op_t::c_blocked |
61 | && (is_tail_kernel_ || conf_.bcast_type == bcast_t::per_w)) |
62 | nelems = dims[1]; |
63 | else if (conf_.bcast_type == bcast_t::none |
64 | && !conf_.postops_per_oc_broadcast_exists) |
65 | nelems = src0_d.nelems(true); |
66 | else if (conf_.bcast_type == bcast_t::per_batch |
67 | && !conf_.postops_per_oc_broadcast_exists) |
68 | nelems = src0_d.nelems(true) / dims[0]; |
69 | else { |
70 | if (conf_.op_type == op_t::n_spatial_c) |
71 | nelems = dims[1]; |
72 | else if (conf_.op_type == op_t::n_c_spatial && ndims >= 3) |
73 | nelems = conf_.bcast_type == bcast_t::per_w |
74 | ? utils::array_product( |
75 | dims + (ndims - conf_.not_bcasted_sp_dims), |
76 | conf_.not_bcasted_sp_dims) |
77 | : utils::array_product(dims + 2, ndims - 2); |
78 | } |
    // the remainder is taken in f32 elements: even for bfloat16 we load
    // 16 elements per vector, not 32.
80 | return nelems % simd_w_; |
81 | } |
82 | |
83 | template <cpu_isa_t isa, typename Vmm> |
84 | jit_uni_binary_kernel_t<isa, Vmm>::jit_uni_binary_kernel_t( |
85 | const binary_pd_t *pd, const jit_binary_conf_t conf, bool tail_kernel) |
86 | : binary_kernel_t(vreg_traits<Vmm>::vlen, pd, conf, jit_name(), tail_kernel) |
87 | , offt_src0_(vlen_ / ((conf_.is_bf16 || conf_.is_f16) ? 2 : 1)) |
88 | , offt_src1_(conf_.use_stride_src1 ? offt_src0_ : 0) |
89 | , io_(this, isa, {conf_.src0_type, conf_.src1_type, conf_.dst_type}, |
            {false}, // io_conf_t: non-temporal stores disabled
91 | io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_, |
92 | vmm_tail_vmask_.getIdx(), reg_tmp_}, |
93 | io::io_emu_bf16_conf_t {vreg_bf16_emu_1_, vreg_bf16_emu_2_, |
94 | vreg_bf16_emu_3_, reg_tmp_, vreg_bf16_emu_4_}, |
95 | create_saturation_vmm_map(), |
96 | io::io_gather_conf_t {simd_w_, full_mask_, |
97 | vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_, |
98 | vmm_tmp_gather_.getIdx()}) { |
99 | init(); |
100 | } |
101 | |
102 | template <cpu_isa_t isa, typename Vmm> |
103 | std::map<data_type_t, io::io_saturation_conf_t> |
104 | jit_uni_binary_kernel_t<isa, Vmm>::create_saturation_vmm_map() const { |
105 | |
106 | std::map<data_type_t, io::io_saturation_conf_t> saturation_map {}; |
107 | |
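    // saturation is needed only for int8 destinations: f32 results are
    // clamped to the dst type range before the down-converting store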
108 | if (conf_.is_i8) |
109 | saturation_map.emplace(conf_.dst_type, |
110 | io::io_saturation_conf_t {vreg_zero_.getIdx(), |
111 | vreg_saturation_ubound_.getIdx(), reg_tmp_}); |
112 | |
113 | return saturation_map; |
114 | } |
115 | |
116 | template <cpu_isa_t isa, typename Vmm> |
117 | void jit_uni_binary_kernel_t<isa, Vmm>::init() { |
118 | if (conf_.with_postops) init_post_ops_injector(); |
119 | } |
120 | |
121 | template <cpu_isa_t isa, typename Vmm> |
122 | void jit_uni_binary_kernel_t<isa, Vmm>::init_post_ops_injector() { |
123 | const memory_desc_wrapper dst_d(pd_->dst_md(0)); |
124 | const auto &po = pd_->attr()->post_ops_; |
125 | |
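    // injector static params: the eltwise injector saves/restores state
    // around injections; for the binary injector, Vmm(10) is reserved as
    // the rhs helper and r13 caches the rhs address (assumed from the
    // rhs_arg_static_params_t signature)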
126 | const eltwise_injector::static_params_t esp(true /*save_state*/, |
127 | reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/, |
128 | false /*use_dst*/); |
129 | const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_, |
130 | reg_elt_inj_table_, r13, true /*preserve gpr*/, |
131 | true /*preserve vmm*/, PARAM_OFF(post_ops_binary_rhs_arg_vec), |
132 | PARAM_OFF(dst_orig), dst_d, tail_size_, tail_opmask_, |
133 | false /*use_exact_tail_scalar_bcast*/}; |
134 | const binary_injector::static_params_t bsp(this->param1, |
135 | get_supported_postops_bcast_strategies(), rhs_arg_bsp); |
136 | |
137 | postops_injector_ = utils::make_unique< |
138 | injector::jit_uni_postops_injector_t<inject_isa, Vmm>>( |
139 | this, po, bsp, esp); |
140 | } |
141 | |
142 | template <cpu_isa_t isa, typename Vmm> |
143 | void jit_uni_binary_kernel_t<isa, Vmm>::apply_postops(int unroll, bool tail) { |
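    // sum post-op: reload the current dst values and accumulate them into
    // the src0 accumulators as dst * sum_scale via fused multiply-add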
144 | const auto sum_injector = [&]() { |
145 | for (int i = 0; i < unroll; i++) { |
146 | const int offt = simd_w_ * i; |
147 | const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_); |
148 | const Vmm vreg_tmp = conf_.is_src_different_layouts |
149 | ? vmm_gathered_src_ |
150 | : Vmm(unroll + i + vmm_start_idx_); |
151 | io_.at(conf_.dst_type) |
152 | ->load(dst_ptr(offt |
153 | * types::data_type_size(conf_.dst_type)), |
154 | vreg_tmp, tail); |
155 | uni_vfmadd231ps(vreg_tmp_src0, vreg_tmp, vreg_sum_scale_); |
156 | } |
157 | }; |
158 | |
159 | if (conf_.do_sum) |
160 | postops_injector_->set_lambda_injector( |
161 | primitive_kind::sum, sum_injector); |
162 | |
163 | if (conf_.with_binary) { |
164 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
165 | const Reg64 ®_offt_dst |
166 | = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_; |
167 | |
168 | const injector_utils::register_preserve_guard_t register_guard { |
169 | this, {reg_tmp1_}}; |
170 | |
171 | mov(reg_tmp1_, reg_dst_); |
172 | add(reg_tmp1_, reg_offt_dst); |
173 | |
174 | for (int vmm_idx = 1; vmm_idx < unroll + vmm_start_idx_; vmm_idx++) { |
175 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_tmp1_); |
176 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, |
177 | (vmm_idx - vmm_start_idx_) * simd_w_ |
178 | * types::data_type_size(conf_.dst_type)); |
179 | if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
180 | } |
181 | postops_injector_->compute_vector_range( |
182 | 1, unroll + vmm_start_idx_, rhs_arg_params); |
183 | } else |
184 | postops_injector_->compute_vector_range(1, unroll + vmm_start_idx_); |
185 | } |
186 | |
187 | template <cpu_isa_t isa, typename Vmm> |
188 | void jit_uni_binary_kernel_t<isa, Vmm>::load_kernel_params() { |
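    // all arguments arrive through the jit_binary_call_s struct referenced
    // by reg_param_ (see PARAM_OFF); sum_scale is broadcast into a vreg here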
189 | mov(reg_tmp_, float2int(conf_.sum_scale)); |
190 | uni_vmovq(xreg_sum_scale_, reg_tmp_); |
191 | uni_vbroadcastss(vreg_sum_scale_, xreg_sum_scale_); |
192 | if (is_src1_outer_dims_tail_) |
193 | mov(reg_outer_dims_range_, |
194 | ptr[reg_param_ + PARAM_OFF(spat_offt_count)]); |
195 | else |
196 | mov(reg_reverse_spat_offt_, |
197 | ptr[reg_param_ + PARAM_OFF(spat_offt_count)]); |
198 | mov(reg_src0_, ptr[reg_param_ + PARAM_OFF(src0)]); |
199 | mov(reg_src1_, ptr[reg_param_ + PARAM_OFF(src1)]); |
200 | mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]); |
201 | if (conf_.is_src_different_layouts) { |
202 | mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(indices)]); |
203 | uni_vmovdqu(vmm_indices_, ptr[reg_tmp_]); |
204 | |
205 | mov(reg_src1_stride_range_, |
206 | ptr[reg_param_ + PARAM_OFF(src1_stride_range)]); |
207 | mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_); |
208 | } |
209 | if (conf_.do_scale_src0) |
210 | mov(reg_scales_src0_, ptr[reg_param_ + PARAM_OFF(scales_src0)]); |
211 | if (conf_.do_scale_src1) |
212 | mov(reg_scales_src1_, ptr[reg_param_ + PARAM_OFF(scales_src1)]); |
213 | } |
214 | |
215 | template <cpu_isa_t isa, typename Vmm> |
216 | Address jit_uni_binary_kernel_t<isa, Vmm>::src0_ptr(size_t offt) { |
217 | return vmmword[reg_src0_ + reg_offt_src0_ + offt]; |
218 | } |
219 | |
220 | template <cpu_isa_t isa, typename Vmm> |
221 | Address jit_uni_binary_kernel_t<isa, Vmm>::src1_ptr(size_t offt) { |
222 | return vmmword[reg_src1_ + reg_offt_src1_ + offt]; |
223 | } |
224 | |
225 | template <cpu_isa_t isa, typename Vmm> |
226 | Address jit_uni_binary_kernel_t<isa, Vmm>::dst_ptr(size_t offt) { |
227 | const Reg64 ®_offt_dst = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_; |
228 | return vmmword[reg_dst_ + reg_offt_dst + offt]; |
229 | } |
230 | |
231 | template <cpu_isa_t isa, typename Vmm> |
232 | unsigned int jit_uni_binary_kernel_t<isa, Vmm>::cmp_predicate(alg_kind_t alg) { |
233 | using namespace alg_kind; |
234 | switch (alg) { |
235 | case binary_ge: return _cmp_nlt_us; |
236 | case binary_gt: return _cmp_nle_us; |
237 | case binary_le: return _cmp_le_os; |
238 | case binary_lt: return _cmp_lt_os; |
239 | case binary_eq: return _cmp_eq_oq; |
240 | case binary_ne: return _cmp_neq_uq; |
        default: assert(!"unsupported operation!"); return -1;
242 | } |
243 | } |
244 | |
245 | template <cpu_isa_t isa, typename Vmm> |
246 | void jit_uni_binary_kernel_t<isa, Vmm>::perform_op( |
247 | const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1) { |
248 | using namespace alg_kind; |
249 | const auto alg = pd_->desc()->alg_kind; |
250 | const bool cmp_op = utils::one_of(alg, alg_kind::binary_ge, |
251 | alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt, |
252 | alg_kind::binary_eq, alg_kind::binary_ne); |
253 | if (conf_.do_scale_src0) uni_vmulps(v0, v0, s_src0); |
254 | if (conf_.do_scale_src1 && offt_src1_ != 0 && !conf_.broadcast_src1_value) |
255 | uni_vmulps(v1, v1, s_src1); |
256 | |
257 | if (alg == binary_add) |
258 | uni_vaddps(v0, v0, v1); |
259 | else if (alg == binary_mul) |
260 | uni_vmulps(v0, v0, v1); |
261 | else if (alg == binary_max) |
262 | uni_vmaxps(v0, v0, v1); |
263 | else if (alg == binary_min) |
264 | uni_vminps(v0, v0, v1); |
265 | else if (alg == binary_div) |
266 | uni_vdivps(v0, v0, v1); |
267 | else if (alg == binary_sub) |
268 | uni_vsubps(v0, v0, v1); |
269 | else if (cmp_op) { |
270 | const unsigned int predicate = cmp_predicate(alg); |
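        // produce 1.0f where the predicate holds, 0.0f elsewhere: AVX-512
        // selects 1.0f under the mask with zeroing; below AVX-512, vcmpps
        // yields an all-ones (NaN) pattern that min with 1.0f maps to 1.0f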
271 | if (is_avx512) { |
272 | vcmpps(cmp_mask, v0, v1, predicate); |
273 | vmovups(v0 | cmp_mask | T_z, vreg_one_); |
274 | } else { |
275 | uni_vcmpps(v0, v0, v1, predicate); |
276 | uni_vminps(v0, v0, vreg_one_); |
277 | } |
278 | } else |
        assert(!"unsupported operation!");
280 | } |
281 | |
282 | template <cpu_isa_t isa, typename Vmm> |
283 | void jit_uni_binary_kernel_t<isa, Vmm>::prepare_isa_kernel() { |
284 | if (conf_.is_bf16) io_.init_bf16(); |
285 | if (tail_size_ > 0) io_.prepare_tail_mask(); |
286 | if (conf_.is_src_different_layouts && is_superset(isa, avx2)) { |
287 | io_.init_full_mask(); |
288 | io_.prepare_full_mask(); |
289 | } |
290 | } |
291 | |
292 | template <cpu_isa_t isa, typename Vmm> |
293 | void jit_uni_binary_kernel_t<isa, Vmm>::compute_bcast(bool tail) { |
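    // when src1 is a single value (or needs no per-element stride), it is
    // broadcast/loaded into vreg_bcast_src1_ once and reused for the call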
294 | if (conf_.broadcast_src1_value) { |
295 | if (conf_.is_i8) |
296 | uni_vpxor(xreg_bcast_src1_, xreg_bcast_src1_, xreg_bcast_src1_); |
297 | io_.at(conf_.src1_type)->broadcast(src1_ptr(), vreg_bcast_src1_); |
298 | } else if (!conf_.is_i8 && offt_src1_ == 0) { |
299 | io_.at(conf_.src1_type)->load(src1_ptr(), vreg_bcast_src1_, tail); |
300 | } |
301 | } |
302 | |
303 | template <cpu_isa_t isa, typename Vmm> |
304 | void jit_uni_binary_kernel_t<isa, Vmm>::load_src1( |
305 | const Vmm &vreg_src1, const int offt, bool tail) { |
306 | if (conf_.is_src_different_layouts) { |
        // for different src layouts, gather elements through strided
        // indices; once the current stride range is exhausted, the base
        // address is restored and advanced by one element (see below)
310 | io_.at(conf_.src1_type) |
311 | ->gather(reg_src1_, vmm_indices_, vreg_src1, tail); |
        // gather reads its base address from a register rather than from
        // an address operand, so advance reg_src1_ directly instead of
        // keeping the offset in a second register
315 | add(reg_src1_, |
316 | types::data_type_size(conf_.src1_type) * conf_.src1_stride |
317 | * simd_w_); |
318 | sub(reg_reverse_src1_stride_range_, |
319 | types::data_type_size(conf_.src1_type) * conf_.src1_stride |
320 | * simd_w_); |
321 | |
        Label src1_stride_range_not_exceed;
323 | |
324 | cmp(reg_reverse_src1_stride_range_, 0); |
325 | jg(src1_stride_range_not_exceed, T_NEAR); |
326 | { |
327 | pop(reg_src1_); |
328 | add(reg_src1_, types::data_type_size(conf_.src1_type)); |
329 | push(reg_src1_); |
330 | mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_); |
331 | } |
332 | L(src1_stride_range_not_exceed); |
333 | } else |
334 | io_.at(conf_.src1_type) |
335 | ->load(src1_ptr(offt * types::data_type_size(conf_.src1_type)), |
336 | vreg_src1, tail); |
337 | } |
338 | |
339 | template <cpu_isa_t isa, typename Vmm> |
340 | void jit_uni_binary_kernel_t<isa, Vmm>::compute_dst(int unroll, bool tail) { |
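    // for each unrolled vector: load src0 (and src1 unless broadcast),
    // apply the binary op with optional input scales, then run post-ops
    // and store, zero-padding the dst tail in the dedicated tail kernel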
341 | for (int i = 0; i < unroll; i++) { |
342 | const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_); |
343 | const Vmm vreg_tmp = conf_.is_src_different_layouts |
344 | ? vmm_gathered_src_ |
345 | : Vmm(unroll + i + vmm_start_idx_); |
346 | const Vmm vreg_tmp_src1 = offt_src1_ ? vreg_tmp : vreg_bcast_src1_; |
347 | const int offt = simd_w_ * i; |
348 | io_.at(conf_.src0_type) |
349 | ->load(src0_ptr(offt * types::data_type_size(conf_.src0_type)), |
350 | vreg_tmp_src0, tail); |
351 | if (offt_src1_) load_src1(vreg_tmp_src1, offt, tail); |
352 | |
        // copy src1 out of the broadcast vreg so the input scale is not
        // applied to it more than once per kernel call; not needed for
        // different layouts
355 | if (!conf_.is_src_different_layouts) |
356 | uni_vmovups(vreg_tmp, vreg_tmp_src1); |
357 | perform_op( |
358 | vreg_tmp_src0, vreg_tmp, vreg_scales_src0_, vreg_scales_src1_); |
359 | } |
360 | |
361 | if (postops_injector_) apply_postops(unroll, tail); |
362 | |
363 | for (int i = 0; i < unroll; i++) { |
364 | const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_); |
365 | const int offt = simd_w_ * i; |
366 | const auto dt_size = types::data_type_size(conf_.dst_type); |
367 | |
368 | if (is_tail_kernel_ && padding_tail_size_) { |
369 | // apply zero-padding |
370 | Label end; |
371 | auto off_base = 0; |
372 | auto zero_pad_left = padding_tail_size_; |
373 | |
            // in-place data is assumed to be already zero-padded
375 | cmp(reg_src0_, reg_dst_); |
376 | je(end, T_NEAR); |
377 | |
378 | if (zero_pad_left >= simd_w_ - tail_size_) { |
379 | vxorps(vreg_zero_, vreg_zero_, vreg_zero_); |
380 | if (is_avx512) |
381 | uni_vmovups(vreg_zero_ | tail_opmask_, vreg_tmp_src0); |
382 | else |
383 | uni_vblendvps(vreg_zero_, vreg_zero_, vreg_tmp_src0, |
384 | vmm_tail_vmask_); |
385 | io_.at(conf_.dst_type) |
386 | ->store(vreg_zero_, dst_ptr(offt * dt_size), false); |
387 | off_base = simd_w_ * dt_size; |
388 | zero_pad_left -= simd_w_ - tail_size_; |
389 | } else { |
390 | io_.at(conf_.dst_type) |
391 | ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), true); |
392 | off_base = tail_size_ * dt_size; |
393 | } |
394 | |
395 | if (zero_pad_left) { |
396 | push(abi_param1); |
397 | const Reg32 ®_zero = eax; |
398 | const Reg64 ®_ptr = rdi; |
399 | const Reg64 ®_counter = rcx; |
400 | const auto off_start = off_base; |
401 | const auto off_end = off_start + zero_pad_left * dt_size; |
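                // zero-fill the remaining [off_start, off_end) bytes with
                // rep stosb; it clobbers rdi/rcx/eax, hence abi_param1 is
                // saved above as it may alias one of them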
402 | xor_(reg_zero, reg_zero); |
403 | lea(reg_ptr, |
404 | ptr[dst_ptr(offt * dt_size).getRegExp() |
405 | + RegExp(off_start)]); |
406 | mov(reg_counter, off_end - off_start); |
407 | rep(); |
408 | stosb(); |
409 | pop(abi_param1); |
410 | } |
411 | L(end); |
412 | } else |
413 | io_.at(conf_.dst_type) |
414 | ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), tail); |
415 | } |
416 | } |
417 | |
418 | template <cpu_isa_t isa, typename Vmm> |
419 | void jit_uni_binary_kernel_t<isa, Vmm>::forward() { |
420 | Label unroll_loop, unroll_loop_tail, nelems_tail, end; |
421 | |
422 | const auto src0_type_size = types::data_type_size(conf_.src0_type); |
423 | const auto src1_type_size = types::data_type_size(conf_.src1_type); |
424 | const auto dst_type_size = types::data_type_size(conf_.dst_type); |
425 | |
426 | if (conf_.is_src_different_layouts) push(reg_src1_); |
427 | |
    // with an outer-dims tail, forward() is called in a loop, so offset
    // initialization is done once in forward_over_outer_dims() instead
429 | if (!is_src1_outer_dims_tail_) { |
430 | if (conf_.is_i8) { |
431 | uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_); |
432 | io_.init_saturate_f32({conf_.dst_type}); |
433 | xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst |
434 | } |
435 | |
436 | xor_(reg_offt_src0_, |
437 | reg_offt_src0_); // offt_src0 to get addr of src0/dst |
438 | if (!conf_.is_src_different_layouts) |
439 | xor_(reg_offt_src1_, |
440 | reg_offt_src1_); // offt_src1 to get addr of src1 |
441 | if (conf_.use_stride_rhs_postops && !conf_.is_i8) |
442 | xor_(reg_off_rhs_postops_, reg_off_rhs_postops_); |
443 | } |
444 | const auto alg = pd_->desc()->alg_kind; |
445 | |
446 | if (utils::one_of(alg, alg_kind::binary_ge, alg_kind::binary_gt, |
447 | alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq, |
448 | alg_kind::binary_ne)) { |
449 | Xmm xreg_one = Xmm(vreg_one_.getIdx()); |
450 | mov(reg_tmp_, float2int(1)); |
451 | uni_vmovq(xreg_one, reg_tmp_); |
452 | uni_vbroadcastss(vreg_one_, xreg_one); |
453 | } |
454 | |
    compute_bcast(false); // broadcast/load the src1 vreg only once per kernel call
456 | |
    // used in the c_blocked strategy for the last block when a tail exists
458 | const bool treat_each_compute_step_as_tail |
459 | = !conf_.is_i8 && is_tail_kernel_ && tail_size_; |
460 | |
461 | if (conf_.do_scale_src0) |
462 | uni_vbroadcastss(vreg_scales_src0_, ptr[reg_scales_src0_]); |
463 | if (conf_.do_scale_src1) { |
464 | uni_vbroadcastss(vreg_scales_src1_, ptr[reg_scales_src1_]); |
465 | if (conf_.broadcast_src1_value || offt_src1_ == 0) |
466 | uni_vmulps(vreg_bcast_src1_, vreg_bcast_src1_, vreg_scales_src1_); |
467 | } |
468 | |
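    // Loop structure: unroll_regs_ vectors per main-loop iteration, then
    // one vector at a time, then a masked tail of the remaining elements;
    // reg_reverse_spat_offt_ tracks the remaining dst bytes.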
469 | L(unroll_loop); |
470 | { |
471 | const size_t offt = unroll_regs_ * simd_w_; |
472 | cmp(reg_reverse_spat_offt_, offt * dst_type_size); |
473 | jl(unroll_loop_tail, T_NEAR); |
474 | |
475 | compute_dst(unroll_regs_, treat_each_compute_step_as_tail); |
476 | sub(reg_reverse_spat_offt_, offt * dst_type_size); |
477 | add(reg_offt_src0_, offt * src0_type_size); |
478 | if (conf_.is_i8) { |
479 | if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts) |
480 | add(reg_offt_src1_, offt * src1_type_size); |
481 | add(reg_offt_dst_, offt); |
482 | } else { |
483 | if (conf_.use_stride_src1 && !conf_.is_src_different_layouts) |
484 | add(reg_offt_src1_, offt * src1_type_size); |
485 | if (conf_.use_stride_rhs_postops) add(reg_off_rhs_postops_, offt); |
486 | } |
487 | jmp(unroll_loop); |
488 | } |
489 | |
490 | L(unroll_loop_tail); |
491 | { |
492 | cmp(reg_reverse_spat_offt_, simd_w_ * dst_type_size); |
493 | jl(nelems_tail, T_NEAR); |
494 | |
495 | compute_dst(1, treat_each_compute_step_as_tail); |
496 | sub(reg_reverse_spat_offt_, simd_w_ * dst_type_size); |
497 | add(reg_offt_src0_, simd_w_ * src0_type_size); |
498 | if (conf_.is_i8) { |
499 | if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts) |
500 | add(reg_offt_src1_, simd_w_ * src1_type_size); |
501 | add(reg_offt_dst_, simd_w_); |
502 | } else { |
503 | if (conf_.use_stride_src1 && !conf_.is_src_different_layouts) |
504 | add(reg_offt_src1_, simd_w_ * src1_type_size); |
505 | if (conf_.use_stride_rhs_postops) |
506 | add(reg_off_rhs_postops_, simd_w_); |
507 | } |
508 | |
509 | jmp(unroll_loop_tail); |
510 | } |
511 | |
512 | L(nelems_tail); |
513 | { |
514 | cmp(reg_reverse_spat_offt_, 1); |
515 | jl(end, T_NEAR); |
516 | |
517 | compute_dst(1, true); |
        // advance the offsets when forward() is invoked per outer-dims slice
519 | if (is_src1_outer_dims_tail_) { |
520 | add(reg_offt_src0_, tail_size_ * src0_type_size); |
521 | if (conf_.is_i8) |
522 | add(reg_offt_dst_, tail_size_); |
523 | else { |
524 | if (conf_.use_stride_rhs_postops) |
525 | add(reg_off_rhs_postops_, tail_size_); |
526 | } |
527 | } |
528 | } |
529 | |
530 | L(end); |
531 | if (conf_.is_src_different_layouts) pop(reg_src1_); |
532 | } |
533 | |
534 | template <cpu_isa_t isa, typename Vmm> |
535 | void jit_uni_binary_kernel_t<isa, Vmm>::forward_over_outer_dims() { |
536 | const auto outer_dims_size |
537 | = conf_.outer_dims * types::data_type_size(conf_.dst_type); |
538 | |
539 | if (conf_.is_i8) { |
540 | uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_); |
541 | io_.init_saturate_f32({conf_.dst_type}); |
542 | xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst |
543 | } |
544 | |
545 | xor_(reg_offt_src0_, |
546 | reg_offt_src0_); // offt_src0 to get addr of src0/dst |
547 | if (conf_.use_stride_rhs_postops && !conf_.is_i8) |
548 | xor_(reg_off_rhs_postops_, reg_off_rhs_postops_); |
549 | |
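    // each iteration runs forward() over one outer-dims slice of
    // outer_dims_size bytes; reg_outer_dims_range_ holds the bytes left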
550 | Label c_loop; |
551 | L(c_loop); |
552 | { |
553 | mov(reg_reverse_spat_offt_, outer_dims_size); |
554 | forward(); |
555 | sub(reg_outer_dims_range_, outer_dims_size); |
556 | cmp(reg_outer_dims_range_, 0); |
557 | jg(c_loop); |
558 | } |
559 | } |
560 | |
561 | template <cpu_isa_t isa, typename Vmm> |
562 | void jit_uni_binary_kernel_t<isa, Vmm>::generate() { |
563 | preamble(); |
564 | load_kernel_params(); |
565 | prepare_isa_kernel(); |
    // if the outer dims size is not aligned to simd_w, iterate over the
    // outer dims to avoid modifying the gather indices
568 | if (is_src1_outer_dims_tail_) |
569 | forward_over_outer_dims(); |
570 | else |
571 | forward(); |
572 | postamble(); |
573 | |
574 | if ((conf_.with_eltwise || conf_.is_i8) && postops_injector_) |
575 | postops_injector_->prepare_table(); |
576 | } |
577 | |
578 | #undef PARAM_OFF |
579 | |
580 | template struct jit_uni_binary_kernel_t<avx512_core_fp16, Zmm>; |
581 | template struct jit_uni_binary_kernel_t<avx512_core_fp16, Ymm>; |
582 | template struct jit_uni_binary_kernel_t<avx512_core_fp16, Xmm>; |
583 | template struct jit_uni_binary_kernel_t<avx512_core_bf16, Zmm>; |
584 | template struct jit_uni_binary_kernel_t<avx512_core_bf16, Ymm>; |
585 | template struct jit_uni_binary_kernel_t<avx512_core_bf16, Xmm>; |
586 | template struct jit_uni_binary_kernel_t<avx512_core, Zmm>; |
587 | template struct jit_uni_binary_kernel_t<avx512_core, Ymm>; |
588 | template struct jit_uni_binary_kernel_t<avx512_core, Xmm>; |
589 | template struct jit_uni_binary_kernel_t<avx2, Ymm>; |
590 | template struct jit_uni_binary_kernel_t<avx2, Xmm>; |
591 | template struct jit_uni_binary_kernel_t<sse41, Xmm>; |
592 | |
593 | } // namespace x64 |
594 | } // namespace cpu |
595 | } // namespace impl |
596 | } // namespace dnnl |
597 | |