/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_binary_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

#define PARAM_OFF(x) offsetof(jit_binary_call_s, x)

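// Broadcast strategies handled by the post-ops binary injector for this
// kernel.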
static bcast_set_t get_supported_postops_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::per_oc_spatial,
            broadcasting_strategy_t::no_broadcast};
}

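// Common kernel state shared by the ISA-specific kernels: simd_w_ is the
// vector width in f32 lanes, the tail kernel handles the remainder that does
// not fill a full vector, and padding_tail_size_ is the zero-padded part of
// the channel dimension.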
binary_kernel_t::binary_kernel_t(const size_t vlen, const binary_pd_t *pd,
        const jit_binary_conf_t conf, const char *name, bool tail_kernel)
    : jit_generator(name)
    , vlen_(vlen)
    , simd_w_(vlen / sizeof(float))
    , pd_(pd)
    , conf_(conf)
    , is_tail_kernel_(tail_kernel)
    , is_src1_outer_dims_tail_(
              conf_.is_src_different_layouts && conf_.outer_dims % simd_w_)
    , tail_size_(get_tail_size())
    , padding_tail_size_(
              pd->src_md(0)->padded_dims[1] - pd->src_md(0)->dims[1]) {}

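// Number of trailing elements that do not fill a full vector; the dimension
// the tail is taken from depends on the layout/broadcast configuration.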
size_t binary_kernel_t::get_tail_size() const {
    memory_desc_wrapper src0_d(pd_->src_md(0));
    const auto &dims = src0_d.dims();
    const auto &ndims = src0_d.ndims();

    dim_t nelems = 0;

    if (ndims == 1)
        nelems = dims[0];
    else if (is_src1_outer_dims_tail_)
        nelems = conf_.outer_dims;
    else if (!conf_.is_i8 && conf_.op_type == op_t::c_blocked
            && (is_tail_kernel_ || conf_.bcast_type == bcast_t::per_w))
        nelems = dims[1];
    else if (conf_.bcast_type == bcast_t::none
            && !conf_.postops_per_oc_broadcast_exists)
        nelems = src0_d.nelems(true);
    else if (conf_.bcast_type == bcast_t::per_batch
            && !conf_.postops_per_oc_broadcast_exists)
        nelems = src0_d.nelems(true) / dims[0];
    else {
        if (conf_.op_type == op_t::n_spatial_c)
            nelems = dims[1];
        else if (conf_.op_type == op_t::n_c_spatial && ndims >= 3)
            nelems = conf_.bcast_type == bcast_t::per_w
                    ? utils::array_product(
                            dims + (ndims - conf_.not_bcasted_sp_dims),
                            conf_.not_bcasted_sp_dims)
                    : utils::array_product(dims + 2, ndims - 2);
    }
    // simd_w_ is based on f32: even for bfloat16 the kernel still loads
    // simd_w_ elements per vector (16 on avx512_core), not 32.
    return nelems % simd_w_;
}

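// offt_src0_/offt_src1_ are the byte strides advanced per vector; they are
// halved for bf16/f16 since those elements are half the size of f32, and
// offt_src1_ is zero when src1 is not strided.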
template <cpu_isa_t isa, typename Vmm>
jit_uni_binary_kernel_t<isa, Vmm>::jit_uni_binary_kernel_t(
        const binary_pd_t *pd, const jit_binary_conf_t conf, bool tail_kernel)
    : binary_kernel_t(vreg_traits<Vmm>::vlen, pd, conf, jit_name(), tail_kernel)
    , offt_src0_(vlen_ / ((conf_.is_bf16 || conf_.is_f16) ? 2 : 1))
    , offt_src1_(conf_.use_stride_src1 ? offt_src0_ : 0)
    , io_(this, isa, {conf_.src0_type, conf_.src1_type, conf_.dst_type},
              {false},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      vmm_tail_vmask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vreg_bf16_emu_1_, vreg_bf16_emu_2_,
                      vreg_bf16_emu_3_, reg_tmp_, vreg_bf16_emu_4_},
              create_saturation_vmm_map(),
              io::io_gather_conf_t {simd_w_, full_mask_,
                      vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_,
                      vmm_tmp_gather_.getIdx()}) {
    init();
}

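// Saturation bounds are only needed for int8 destinations; the map tells the
// io helper which vmm registers hold the zero and upper-bound values.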
template <cpu_isa_t isa, typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_binary_kernel_t<isa, Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (conf_.is_i8)
        saturation_map.emplace(conf_.dst_type,
                io::io_saturation_conf_t {vreg_zero_.getIdx(),
                        vreg_saturation_ubound_.getIdx(), reg_tmp_});

    return saturation_map;
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::init() {
    if (conf_.with_postops) init_post_ops_injector();
}

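// Set up the post-ops injector with the eltwise state and the static
// parameters the binary injector needs to address its rhs arguments.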
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::init_post_ops_injector() {
    const memory_desc_wrapper dst_d(pd_->dst_md(0));
    const auto &po = pd_->attr()->post_ops_;

    const eltwise_injector::static_params_t esp(true /*save_state*/,
            reg_elt_inj_table_, elt_inj_opmask_, true /*is_fwd*/,
            false /*use_dst*/);
    const binary_injector::rhs_arg_static_params_t rhs_arg_bsp {10, reg_tmp_,
            reg_elt_inj_table_, r13, true /*preserve gpr*/,
            true /*preserve vmm*/, PARAM_OFF(post_ops_binary_rhs_arg_vec),
            PARAM_OFF(dst_orig), dst_d, tail_size_, tail_opmask_,
            false /*use_exact_tail_scalar_bcast*/};
    const binary_injector::static_params_t bsp(this->param1,
            get_supported_postops_bcast_strategies(), rhs_arg_bsp);

    postops_injector_ = utils::make_unique<
            injector::jit_uni_postops_injector_t<inject_isa, Vmm>>(
            this, po, bsp, esp);
}

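// Sum post-op is injected as dst * sum_scale accumulated into the result;
// binary post-ops additionally get the output address and per-vmm element
// offsets so their rhs operands can be broadcast correctly.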
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::apply_postops(int unroll, bool tail) {
    const auto sum_injector = [&]() {
        for (int i = 0; i < unroll; i++) {
            const int offt = simd_w_ * i;
            const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
            const Vmm vreg_tmp = conf_.is_src_different_layouts
                    ? vmm_gathered_src_
                    : Vmm(unroll + i + vmm_start_idx_);
            io_.at(conf_.dst_type)
                    ->load(dst_ptr(offt
                                   * types::data_type_size(conf_.dst_type)),
                            vreg_tmp, tail);
            uni_vfmadd231ps(vreg_tmp_src0, vreg_tmp, vreg_sum_scale_);
        }
    };

    if (conf_.do_sum)
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);

    if (conf_.with_binary) {
        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        const Reg64 &reg_offt_dst
                = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_;

        const injector_utils::register_preserve_guard_t register_guard {
                this, {reg_tmp1_}};

        mov(reg_tmp1_, reg_dst_);
        add(reg_tmp1_, reg_offt_dst);

        for (int vmm_idx = 1; vmm_idx < unroll + vmm_start_idx_; vmm_idx++) {
            rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_tmp1_);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx,
                    (vmm_idx - vmm_start_idx_) * simd_w_
                            * types::data_type_size(conf_.dst_type));
            if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
        }
        postops_injector_->compute_vector_range(
                1, unroll + vmm_start_idx_, rhs_arg_params);
    } else
        postops_injector_->compute_vector_range(1, unroll + vmm_start_idx_);
}

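// Load the runtime call arguments (jit_binary_call_s) from reg_param_ into
// dedicated registers before entering the main loop.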
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::load_kernel_params() {
    mov(reg_tmp_, float2int(conf_.sum_scale));
    uni_vmovq(xreg_sum_scale_, reg_tmp_);
    uni_vbroadcastss(vreg_sum_scale_, xreg_sum_scale_);
    if (is_src1_outer_dims_tail_)
        mov(reg_outer_dims_range_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    else
        mov(reg_reverse_spat_offt_,
                ptr[reg_param_ + PARAM_OFF(spat_offt_count)]);
    mov(reg_src0_, ptr[reg_param_ + PARAM_OFF(src0)]);
    mov(reg_src1_, ptr[reg_param_ + PARAM_OFF(src1)]);
    mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]);
    if (conf_.is_src_different_layouts) {
        mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(indices)]);
        uni_vmovdqu(vmm_indices_, ptr[reg_tmp_]);

        mov(reg_src1_stride_range_,
                ptr[reg_param_ + PARAM_OFF(src1_stride_range)]);
        mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
    }
    if (conf_.do_scale_src0)
        mov(reg_scales_src0_, ptr[reg_param_ + PARAM_OFF(scales_src0)]);
    if (conf_.do_scale_src1)
        mov(reg_scales_src1_, ptr[reg_param_ + PARAM_OFF(scales_src1)]);
}

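// Effective-address helpers; dst reuses the src0 offset register except in
// the int8 case, which keeps a separate dst offset.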
template <cpu_isa_t isa, typename Vmm>
Address jit_uni_binary_kernel_t<isa, Vmm>::src0_ptr(size_t offt) {
    return vmmword[reg_src0_ + reg_offt_src0_ + offt];
}

template <cpu_isa_t isa, typename Vmm>
Address jit_uni_binary_kernel_t<isa, Vmm>::src1_ptr(size_t offt) {
    return vmmword[reg_src1_ + reg_offt_src1_ + offt];
}

template <cpu_isa_t isa, typename Vmm>
Address jit_uni_binary_kernel_t<isa, Vmm>::dst_ptr(size_t offt) {
    const Reg64 &reg_offt_dst = conf_.is_i8 ? reg_offt_dst_ : reg_offt_src0_;
    return vmmword[reg_dst_ + reg_offt_dst + offt];
}

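// Map a comparison alg_kind to the immediate predicate expected by (v)cmpps.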
template <cpu_isa_t isa, typename Vmm>
unsigned int jit_uni_binary_kernel_t<isa, Vmm>::cmp_predicate(alg_kind_t alg) {
    using namespace alg_kind;
    switch (alg) {
        case binary_ge: return _cmp_nlt_us;
        case binary_gt: return _cmp_nle_us;
        case binary_le: return _cmp_le_os;
        case binary_lt: return _cmp_lt_os;
        case binary_eq: return _cmp_eq_oq;
        case binary_ne: return _cmp_neq_uq;
        default: assert(!"not supported operation!"); return -1;
    }
}

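// Apply the optional src scales, then the binary operation itself;
// comparison algs produce 1.0f where the predicate holds and 0.0f elsewhere.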
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::perform_op(
        const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1) {
    using namespace alg_kind;
    const auto alg = pd_->desc()->alg_kind;
    const bool cmp_op = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    if (conf_.do_scale_src0) uni_vmulps(v0, v0, s_src0);
    if (conf_.do_scale_src1 && offt_src1_ != 0 && !conf_.broadcast_src1_value)
        uni_vmulps(v1, v1, s_src1);

    if (alg == binary_add)
        uni_vaddps(v0, v0, v1);
    else if (alg == binary_mul)
        uni_vmulps(v0, v0, v1);
    else if (alg == binary_max)
        uni_vmaxps(v0, v0, v1);
    else if (alg == binary_min)
        uni_vminps(v0, v0, v1);
    else if (alg == binary_div)
        uni_vdivps(v0, v0, v1);
    else if (alg == binary_sub)
        uni_vsubps(v0, v0, v1);
    else if (cmp_op) {
        const unsigned int predicate = cmp_predicate(alg);
        if (is_avx512) {
            vcmpps(cmp_mask, v0, v1, predicate);
            vmovups(v0 | cmp_mask | T_z, vreg_one_);
        } else {
            uni_vcmpps(v0, v0, v1, predicate);
            uni_vminps(v0, v0, vreg_one_);
        }
    } else
        assert(!"not supported operation!");
}

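// One-time ISA setup: bf16 emulation state, the tail mask and, when src1 is
// gathered on avx2 and above, the full gather mask.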
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::prepare_isa_kernel() {
    if (conf_.is_bf16) io_.init_bf16();
    if (tail_size_ > 0) io_.prepare_tail_mask();
    if (conf_.is_src_different_layouts && is_superset(isa, avx2)) {
        io_.init_full_mask();
        io_.prepare_full_mask();
    }
}

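// Broadcast a scalar src1 value or, for non-int8, load src1 once when it has
// no per-vector stride, so it is not reloaded on every loop iteration.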
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::compute_bcast(bool tail) {
    if (conf_.broadcast_src1_value) {
        if (conf_.is_i8)
            uni_vpxor(xreg_bcast_src1_, xreg_bcast_src1_, xreg_bcast_src1_);
        io_.at(conf_.src1_type)->broadcast(src1_ptr(), vreg_bcast_src1_);
    } else if (!conf_.is_i8 && offt_src1_ == 0) {
        io_.at(conf_.src1_type)->load(src1_ptr(), vreg_bcast_src1_, tail);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::load_src1(
        const Vmm &vreg_src1, const int offt, bool tail) {
    if (conf_.is_src_different_layouts) {
        // When the layouts differ, gather src1 with strided indices. Once
        // the stride range is exhausted, the base offset is restored and
        // advanced to the next element.
        io_.at(conf_.src1_type)
                ->gather(reg_src1_, vmm_indices_, vreg_src1, tail);
        // The gather reads its address from a register rather than a memory
        // operand, so advance reg_src1_ directly instead of keeping the
        // offset in a second register.
        add(reg_src1_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);
        sub(reg_reverse_src1_stride_range_,
                types::data_type_size(conf_.src1_type) * conf_.src1_stride
                        * simd_w_);

        Label src1_stride_range_not_exceed, src1_C_tail_end;

        cmp(reg_reverse_src1_stride_range_, 0);
        jg(src1_stride_range_not_exceed, T_NEAR);
        {
            pop(reg_src1_);
            add(reg_src1_, types::data_type_size(conf_.src1_type));
            push(reg_src1_);
            mov(reg_reverse_src1_stride_range_, reg_src1_stride_range_);
        }
        L(src1_stride_range_not_exceed);
    } else
        io_.at(conf_.src1_type)
                ->load(src1_ptr(offt * types::data_type_size(conf_.src1_type)),
                        vreg_src1, tail);
}

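// One unroll step: load src0 (and src1 if it advances), apply the binary
// operation and post-ops, then store the result, zero-padding the blocked
// channel tail for out-of-place execution.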
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::compute_dst(int unroll, bool tail) {
    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const Vmm vreg_tmp = conf_.is_src_different_layouts
                ? vmm_gathered_src_
                : Vmm(unroll + i + vmm_start_idx_);
        const Vmm vreg_tmp_src1 = offt_src1_ ? vreg_tmp : vreg_bcast_src1_;
        const int offt = simd_w_ * i;
        io_.at(conf_.src0_type)
                ->load(src0_ptr(offt * types::data_type_size(conf_.src0_type)),
                        vreg_tmp_src0, tail);
        if (offt_src1_) load_src1(vreg_tmp_src1, offt, tail);

        // Work on a copy of src1 so the broadcasted register is not scaled
        // by the input scale more than once; with different layouts the
        // gathered data already sits in a scratch register.
        if (!conf_.is_src_different_layouts)
            uni_vmovups(vreg_tmp, vreg_tmp_src1);
        perform_op(
                vreg_tmp_src0, vreg_tmp, vreg_scales_src0_, vreg_scales_src1_);
    }

    if (postops_injector_) apply_postops(unroll, tail);

    for (int i = 0; i < unroll; i++) {
        const Vmm vreg_tmp_src0 = Vmm(i + vmm_start_idx_);
        const int offt = simd_w_ * i;
        const auto dt_size = types::data_type_size(conf_.dst_type);

        if (is_tail_kernel_ && padding_tail_size_) {
            // apply zero-padding
            Label end;
            auto off_base = 0;
            auto zero_pad_left = padding_tail_size_;

            // in-place data is assumed to be already zero-padded
            cmp(reg_src0_, reg_dst_);
            je(end, T_NEAR);

            if (zero_pad_left >= simd_w_ - tail_size_) {
                vxorps(vreg_zero_, vreg_zero_, vreg_zero_);
                if (is_avx512)
                    uni_vmovups(vreg_zero_ | tail_opmask_, vreg_tmp_src0);
                else
                    uni_vblendvps(vreg_zero_, vreg_zero_, vreg_tmp_src0,
                            vmm_tail_vmask_);
                io_.at(conf_.dst_type)
                        ->store(vreg_zero_, dst_ptr(offt * dt_size), false);
                off_base = simd_w_ * dt_size;
                zero_pad_left -= simd_w_ - tail_size_;
            } else {
                io_.at(conf_.dst_type)
                        ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), true);
                off_base = tail_size_ * dt_size;
            }

            if (zero_pad_left) {
                push(abi_param1);
                const Reg32 &reg_zero = eax;
                const Reg64 &reg_ptr = rdi;
                const Reg64 &reg_counter = rcx;
                const auto off_start = off_base;
                const auto off_end = off_start + zero_pad_left * dt_size;
                xor_(reg_zero, reg_zero);
                lea(reg_ptr,
                        ptr[dst_ptr(offt * dt_size).getRegExp()
                                + RegExp(off_start)]);
                mov(reg_counter, off_end - off_start);
                rep();
                stosb();
                pop(abi_param1);
            }
            L(end);
        } else
            io_.at(conf_.dst_type)
                    ->store(vreg_tmp_src0, dst_ptr(offt * dt_size), tail);
    }
}

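// Main driver: an unroll_regs_-wide unrolled loop, a single-vector loop and
// a masked tail, advancing the source/destination offsets by the respective
// data type sizes.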
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::forward() {
    Label unroll_loop, unroll_loop_tail, nelems_tail, end;

    const auto src0_type_size = types::data_type_size(conf_.src0_type);
    const auto src1_type_size = types::data_type_size(conf_.src1_type);
    const auto dst_type_size = types::data_type_size(conf_.dst_type);

    if (conf_.is_src_different_layouts) push(reg_src1_);

    // With an outer-dims tail, the offset registers are initialized once in
    // forward_over_outer_dims(), outside the outer-dims loop.
    if (!is_src1_outer_dims_tail_) {
        if (conf_.is_i8) {
            uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
            io_.init_saturate_f32({conf_.dst_type});
            xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
        }

        xor_(reg_offt_src0_,
                reg_offt_src0_); // offt_src0 to get addr of src0/dst
        if (!conf_.is_src_different_layouts)
            xor_(reg_offt_src1_,
                    reg_offt_src1_); // offt_src1 to get addr of src1
        if (conf_.use_stride_rhs_postops && !conf_.is_i8)
            xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);
    }
    const auto alg = pd_->desc()->alg_kind;

    if (utils::one_of(alg, alg_kind::binary_ge, alg_kind::binary_gt,
                alg_kind::binary_le, alg_kind::binary_lt, alg_kind::binary_eq,
                alg_kind::binary_ne)) {
        Xmm xreg_one = Xmm(vreg_one_.getIdx());
        mov(reg_tmp_, float2int(1));
        uni_vmovq(xreg_one, reg_tmp_);
        uni_vbroadcastss(vreg_one_, xreg_one);
    }

    compute_bcast(false); // broadcast/load the src1 vreg once per kernel call

    // used by the c_blocked strategy for the last block when a tail exists
    const bool treat_each_compute_step_as_tail
            = !conf_.is_i8 && is_tail_kernel_ && tail_size_;

    if (conf_.do_scale_src0)
        uni_vbroadcastss(vreg_scales_src0_, ptr[reg_scales_src0_]);
    if (conf_.do_scale_src1) {
        uni_vbroadcastss(vreg_scales_src1_, ptr[reg_scales_src1_]);
        if (conf_.broadcast_src1_value || offt_src1_ == 0)
            uni_vmulps(vreg_bcast_src1_, vreg_bcast_src1_, vreg_scales_src1_);
    }

    L(unroll_loop);
    {
        const size_t offt = unroll_regs_ * simd_w_;
        cmp(reg_reverse_spat_offt_, offt * dst_type_size);
        jl(unroll_loop_tail, T_NEAR);

        compute_dst(unroll_regs_, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, offt * dst_type_size);
        add(reg_offt_src0_, offt * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            add(reg_offt_dst_, offt);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, offt * src1_type_size);
            if (conf_.use_stride_rhs_postops) add(reg_off_rhs_postops_, offt);
        }
        jmp(unroll_loop);
    }

    L(unroll_loop_tail);
    {
        cmp(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        jl(nelems_tail, T_NEAR);

        compute_dst(1, treat_each_compute_step_as_tail);
        sub(reg_reverse_spat_offt_, simd_w_ * dst_type_size);
        add(reg_offt_src0_, simd_w_ * src0_type_size);
        if (conf_.is_i8) {
            if (!conf_.broadcast_src1_value && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            add(reg_offt_dst_, simd_w_);
        } else {
            if (conf_.use_stride_src1 && !conf_.is_src_different_layouts)
                add(reg_offt_src1_, simd_w_ * src1_type_size);
            if (conf_.use_stride_rhs_postops)
                add(reg_off_rhs_postops_, simd_w_);
        }

        jmp(unroll_loop_tail);
    }

    L(nelems_tail);
    {
        cmp(reg_reverse_spat_offt_, 1);
        jl(end, T_NEAR);

        compute_dst(1, true);
        // advance the offsets when forward() is called per outer-dims block
        if (is_src1_outer_dims_tail_) {
            add(reg_offt_src0_, tail_size_ * src0_type_size);
            if (conf_.is_i8)
                add(reg_offt_dst_, tail_size_);
            else {
                if (conf_.use_stride_rhs_postops)
                    add(reg_off_rhs_postops_, tail_size_);
            }
        }
    }

    L(end);
    if (conf_.is_src_different_layouts) pop(reg_src1_);
}

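// When src1 has a different layout and the outer dims are not divisible by
// simd_w_, call forward() once per outer-dims block so the gather indices
// stay valid.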
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::forward_over_outer_dims() {
    const auto outer_dims_size
            = conf_.outer_dims * types::data_type_size(conf_.dst_type);

    if (conf_.is_i8) {
        uni_vpxor(vreg_zero_, vreg_zero_, vreg_zero_);
        io_.init_saturate_f32({conf_.dst_type});
        xor_(reg_offt_dst_, reg_offt_dst_); // offt_dst to get addr of dst
    }

    xor_(reg_offt_src0_,
            reg_offt_src0_); // offt_src0 to get addr of src0/dst
    if (conf_.use_stride_rhs_postops && !conf_.is_i8)
        xor_(reg_off_rhs_postops_, reg_off_rhs_postops_);

    Label c_loop;
    L(c_loop);
    {
        mov(reg_reverse_spat_offt_, outer_dims_size);
        forward();
        sub(reg_outer_dims_range_, outer_dims_size);
        cmp(reg_outer_dims_range_, 0);
        jg(c_loop);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_kernel_t<isa, Vmm>::generate() {
    preamble();
    load_kernel_params();
    prepare_isa_kernel();
    // If the outer dims are not divisible by simd_w, iterate over them so
    // the gather indices do not have to be modified.
    if (is_src1_outer_dims_tail_)
        forward_over_outer_dims();
    else
        forward();
    postamble();

    if ((conf_.with_eltwise || conf_.is_i8) && postops_injector_)
        postops_injector_->prepare_table();
}

#undef PARAM_OFF

template struct jit_uni_binary_kernel_t<avx512_core_fp16, Zmm>;
template struct jit_uni_binary_kernel_t<avx512_core_fp16, Ymm>;
template struct jit_uni_binary_kernel_t<avx512_core_fp16, Xmm>;
template struct jit_uni_binary_kernel_t<avx512_core_bf16, Zmm>;
template struct jit_uni_binary_kernel_t<avx512_core_bf16, Ymm>;
template struct jit_uni_binary_kernel_t<avx512_core_bf16, Xmm>;
template struct jit_uni_binary_kernel_t<avx512_core, Zmm>;
template struct jit_uni_binary_kernel_t<avx512_core, Ymm>;
template struct jit_uni_binary_kernel_t<avx512_core, Xmm>;
template struct jit_uni_binary_kernel_t<avx2, Ymm>;
template struct jit_uni_binary_kernel_t<avx2, Xmm>;
template struct jit_uni_binary_kernel_t<sse41, Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl