/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <functional>

#include "common/dnnl_thread.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/x64/jit_uni_binary.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

static bcast_set_t get_supported_postops_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::per_oc_spatial,
            broadcasting_strategy_t::no_broadcast};
}

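// Returns true when the two sources can be treated as having matching
// layouts: either src1 is broadcast along some dim > 0 (strides are then not
// comparable) or both descriptors have identical strides. The result is false
// only for same-shaped tensors stored in different plain layouts, e.g. nchw
// vs. nhwc.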
static bool compare_layouts(const memory_desc_wrapper &src0_md,
        const memory_desc_wrapper &src1_md) {
    const strides_t &strides0 = src0_md.blocking_desc().strides;
    const strides_t &strides1 = src1_md.blocking_desc().strides;
    const dims_t &dims0 = src0_md.dims();
    const dims_t &dims1 = src1_md.dims();
    const int ndims = src0_md.ndims();

    bool is_bcast = false;
    for (int d = 1; d < ndims; d++)
        is_bcast = is_bcast || dims0[d] != dims1[d];
    if (is_bcast) return true;

    bool same_layouts = true;
    for (int d = 0; d < ndims; ++d)
        same_layouts = same_layouts && strides0[d] == strides1[d];
    return same_layouts;
}

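// For mixed plain layouts (e.g. an nchw src0 paired with an nhwc src1),
// returns the src1 stride along the dim that has unit stride in src0. For a
// 4D nchw tensor {N, C, H, W} the strides are {CHW, HW, W, 1} and for nhwc
// they are {HWC, 1, WC, C}, so an nchw:nhwc pair yields src1's stride of C
// over W.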
static dim_t get_different_layout_stride(
        const strides_t &strides0, const strides_t &strides1, const int ndims) {
    for (int d = 0; d < ndims; d++)
        if (strides0[d] == 1) return strides1[d];
    return strides1[ndims - 1];
}

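// Product of the "outer" dims used to partition work when src0 and src1 have
// different plain layouts: C when the channel dim of src0 has unit stride,
// the spatial size when the innermost dim of src0 has unit stride, and the
// last dim otherwise.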
static dim_t get_outer_dims_product(
        const strides_t &strides0, const dims_t &dims, const int ndims) {
    // nchw:nhwc->nchw
    if (strides0[1] == 1) return dims[1];
    // nhwc:nchw->nhwc
    else if (strides0[ndims - 1] == 1)
        return utils::array_product(dims + 2, ndims - 2);
    else
        return dims[ndims - 1];
}

using namespace data_type;

static bool data_type_supported(const data_type_t dtype) {
    return utils::one_of(dtype, f32, bf16, f16, s8, u8);
}

static cpu_isa_t get_supported_isa() {
    if (mayiuse(avx512_core_fp16)) return avx512_core_fp16;
    if (mayiuse(avx512_core_bf16)) return avx512_core_bf16;
    if (mayiuse(avx512_core)) return avx512_core;
    if (mayiuse(avx2)) return avx2;
    if (mayiuse(sse41)) return sse41;

    return isa_undef;
}

static bool data_format_supported(
        const memory_desc_wrapper &mdw, const cpu_isa_t isa) {
    if (mdw.is_plain()) return true;
    const auto blk_size = mdw.blocking_desc().inner_blks[0];
    return (is_superset(isa, avx512_core) && utils::one_of(blk_size, 16, 8, 4))
            || (is_superset(isa, avx2) && utils::one_of(blk_size, 8, 4))
            || (is_superset(isa, sse41) && blk_size == 4);
}

status_t jit_uni_binary_t::pd_t::init(engine_t *engine) {
    using sm = primitive_attr_t::skip_mask_t;

    conf_.dst_type = dst_md()->data_type;
    conf_.src0_type = src_md(0)->data_type;
    conf_.src1_type = src_md(1)->data_type;

    memory_desc_wrapper dst_md_(dst_md());
    memory_desc_wrapper src0_md_(src_md(0));
    memory_desc_wrapper src1_md_(src_md(1));

    const auto &po = attr()->post_ops_;
    const int elt_idx = po.find(primitive_kind::eltwise);
    conf_.is_i8 = utils::one_of(conf_.dst_type, s8, u8);

    conf_.isa = get_supported_isa();

    bool ok = data_type_supported(conf_.dst_type)
            && data_type_supported(conf_.src0_type)
            && data_type_supported(conf_.src1_type)
            && data_format_supported(src0_md_, conf_.isa)
            && IMPLICATION(conf_.src0_type == bf16, mayiuse(avx512_core))
            && IMPLICATION(utils::one_of(f16, conf_.src0_type, conf_.src1_type,
                                   conf_.dst_type),
                    mayiuse(avx512_core_fp16))
            && set_default_params() == status::success && !has_zero_dim_memory()
            && IMPLICATION(!conf_.is_i8, src0_md_ == dst_md_) && is_applicable()
            && attr()->has_default_values(sm::post_ops | sm::scales_runtime)
            && attr_.set_default_formats(dst_md(0)) == status::success;
    if (!ok) return status::unimplemented;

    // All operations over blocking descriptors below require initialized
    // memory descriptors.
    conf_.is_src_different_layouts = !compare_layouts(src0_md_, src1_md_);
    ok = post_ops_ok(
                 attr(), src_md(0), dst_md(), conf_.is_src_different_layouts)
            && (conf_.is_i8 || elt_idx == -1
                    || IMPLICATION(!dst_md_.is_dense(),
                            cpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
                                    po.entry_[elt_idx].eltwise)))
            && IMPLICATION((!attr()->scales_.has_default_values()),
                    check_scales_mask())
            && (conf_.is_i8
                    || IMPLICATION(!mayiuse(avx2),
                            src0_md_.consistent_with(src1_md_)
                                    || src0_md_.is_plain()));

    if (!ok) return status::unimplemented;

    conf_.postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    po, src0_md_, get_supported_postops_bcast_strategies());
    conf_.is_bf16 = conf_.dst_type == bf16;
    conf_.is_f16 = conf_.dst_type == f16;
    conf_.op_type = get_op_type(src0_md_);
    assert(conf_.op_type != op_t::none);
    conf_.do_scale_src0 = !attr()->scales_.get(DNNL_ARG_SRC_0).defined()
            || !attr()->scales_.get(DNNL_ARG_SRC_0).has_default_values();
    conf_.do_scale_src1 = !attr()->scales_.get(DNNL_ARG_SRC_1).defined()
            || !attr()->scales_.get(DNNL_ARG_SRC_1).has_default_values();
    const auto sum_idx = po.find(primitive_kind::sum);
    conf_.do_sum = sum_idx != -1 && po.entry_[sum_idx].sum.scale != 0.f;
    conf_.with_eltwise = po.find(primitive_kind::eltwise) != -1;
    conf_.with_binary = po.find(primitive_kind::binary) != -1;
    conf_.with_postops
            = conf_.with_binary || conf_.with_eltwise || conf_.do_sum;
    conf_.sum_scale = conf_.do_sum ? po.entry_[sum_idx].sum.scale : 0.f;
    const auto &bcast_dims = broadcast_dims();
    conf_.bcast_type = is_tensor_op() ? bcast_t::none
                                      : get_bcast_type(src1_md_, bcast_dims);
    conf_.broadcast_src1_value = (conf_.op_type == op_t::n_c_spatial
                                         && conf_.bcast_type == bcast_t::per_c)
            || (utils::one_of(conf_.op_type, op_t::n_spatial_c, op_t::c_blocked)
                    && conf_.bcast_type == bcast_t::per_w)
            || conf_.bcast_type == bcast_t::scalar;
    conf_.use_stride_src1 = !conf_.broadcast_src1_value
            && (utils::one_of(
                        conf_.bcast_type, bcast_t::none, bcast_t::per_batch)
                    || (conf_.op_type == op_t::n_spatial_c
                            && conf_.bcast_type == bcast_t::per_c)
                    || (conf_.op_type == op_t::n_c_spatial
                            && conf_.bcast_type == bcast_t::per_w));
    conf_.use_stride_rhs_postops = conf_.postops_per_oc_broadcast_exists
            && conf_.op_type == op_t::n_spatial_c;

    const auto ndims = src0_md_.ndims();
    if (conf_.is_src_different_layouts) {
        const auto &strides0 = src0_md_.blocking_desc().strides;
        const auto &strides1 = src1_md_.blocking_desc().strides;
        conf_.src1_stride
                = get_different_layout_stride(strides0, strides1, ndims);
        conf_.outer_dims
                = get_outer_dims_product(strides0, src0_md_.dims(), ndims);
    }
    if (conf_.bcast_type == bcast_t::per_w) {
        for (int d = 2; d < ndims; ++d)
            conf_.not_bcasted_sp_dims += !bcast_dims[d];
    }

    return status::success;
}

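// Classify the src0 layout: c_blocked for formats blocked over channels
// (e.g. nChw16c), n_spatial_c for channels-last plain formats (e.g. nhwc),
// and n_c_spatial for plain formats with channels before spatial (e.g. nchw).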
op_t jit_uni_binary_t::pd_t::get_op_type(const memory_desc_wrapper &src0_d) {
    const auto &strides = src0_d.blocking_desc().strides;
    const auto ndims = src0_d.ndims();

    if (!src0_d.is_plain() && src0_d.blocking_desc().inner_idxs[0] == 1)
        return op_t::c_blocked;
    else if (strides[1] == 1)
        return op_t::n_spatial_c;
    else if (strides[0] >= strides[1]
            && IMPLICATION(ndims >= 3, strides[1] >= strides[2]))
        return op_t::n_c_spatial;
    return op_t::none;
}

bool jit_uni_binary_t::pd_t::is_only_dim0_bcasted(
        const dims_t &bcast_dims, const int ndims) {
    bool only_dim0_bcasted = true;
    for (int d = 1; d < ndims; d++)
        only_dim0_bcasted = only_dim0_bcasted && bcast_dims[d] == 0;
    return only_dim0_bcasted;
}

// non-blocked: nxc || ncx
bool jit_uni_binary_t::pd_t::is_format_non_blocked(
        const memory_desc_wrapper &mdw) const {
    const auto &dims = mdw.dims();
    const auto &strides = mdw.blocking_desc().strides;
    const auto &ndims = mdw.ndims();

    const bool is_ncx
            = IMPLICATION(strides[0] != 0,
                      strides[0] >= utils::array_product(dims + 1, ndims - 1))
            && IMPLICATION(ndims >= 3 && strides[1] != 0,
                    strides[1] >= utils::array_product(dims + 2, ndims - 2))
            && IMPLICATION(ndims >= 4 && strides[2] != 0,
                    strides[2] >= utils::array_product(dims + 3, ndims - 3))
            && IMPLICATION(ndims >= 5 && strides[3] != 0,
                    strides[3] >= utils::array_product(dims + 4, ndims - 4))
            && IMPLICATION(strides[ndims - 1] != 0, strides[ndims - 1] == 1);
    const bool is_nxc
            = IMPLICATION(strides[0] != 0,
                      strides[0] >= utils::array_product(dims + 1, ndims - 1))
            && IMPLICATION(ndims >= 3 && strides[2] != 0,
                    strides[2] >= dims[1]
                                    * utils::array_product(dims + 3, ndims - 3))
            && IMPLICATION(ndims >= 4 && strides[3] != 0,
                    strides[3] >= dims[1]
                                    * utils::array_product(dims + 4, ndims - 4))
            && IMPLICATION(ndims >= 5 && strides[4] != 0,
                    strides[4] >= dims[1]
                                    * utils::array_product(dims + 5, ndims - 5))
            && IMPLICATION(strides[1] != 0, strides[1] == 1);
    return is_nxc || is_ncx;
}

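// Map the src1 broadcast pattern to a kernel strategy: scalar when src1 has a
// single element, per_w when the channel dim is broadcast, per_batch when only
// the batch dim is broadcast, and per_c otherwise.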
bcast_t jit_uni_binary_t::pd_t::get_bcast_type(
        const memory_desc_wrapper &src1_d, const dims_t &bcast_dims) {
    if (src1_d.nelems() == 1)
        return bcast_t::scalar;
    else if (bcast_dims[1] == 1)
        return bcast_t::per_w;
    else if (is_only_dim0_bcasted(bcast_dims, src1_d.ndims()))
        return bcast_t::per_batch;
    else
        return bcast_t::per_c;
}

bool jit_uni_binary_t::pd_t::alg_preserves_zero() const {
    using namespace utils;
    using namespace alg_kind;
    return utils::one_of(desc()->alg_kind, binary_add, binary_max, binary_min,
            binary_mul, binary_sub, binary_ge, binary_gt, binary_le, binary_lt,
            binary_eq, binary_ne);
}

bool jit_uni_binary_t::pd_t::check_scales_mask() const {
    for (const auto &s : attr()->scales_.scales_) {
        if (s.second.mask_ != 0) return false;
    }
    return true;
}

bool jit_uni_binary_t::pd_t::is_bcast_pattern(const dims_t &bcast_dims,
        const dim_t ndims, const dim_t N_bcast, const dim_t C_bcast,
        const dim_t W_bcast) const {
    return bcast_dims[0] == N_bcast && bcast_dims[1] == C_bcast
            && bcast_dims[ndims - 1] == W_bcast;
}

bool jit_uni_binary_t::pd_t::is_bcast_allowed(const int ndims) const {
    // supported cases: NxCxDxHxW:{NxCx1x1x1,1xCx1x1x1,Nx1xDxHxW,Nx1x1xHxW,
    //                             Nx1x1x1xW,1xCxDxHxW,1x1xDxHxW,1x1x1xHxW,
    //                             1x1x1x1xW,1x1x1x1x1}
    const auto &bcast_dims = broadcast_dims();
    // Check that the broadcast spatial dims form a contiguous run preceding
    // the non-broadcast ones; while next_bcast_expected == 1, no
    // non-broadcast dim has been met yet.
    int next_bcast_expected = 1;
    bool sp_not_bcasted = true;
    bool ok = true;
    for (int d = 2; d < ndims; ++d) {
        if (bcast_dims[d] == 0)
            next_bcast_expected = 0;
        else
            sp_not_bcasted = false;
        ok = ok && bcast_dims[d] == next_bcast_expected;
    }

#define BCAST_PATTERN(N, C, W, condition) \
    (is_bcast_pattern(bcast_dims, ndims, N, C, W) && (condition))
    if (ndims > 2)
        ok = ok
                && (BCAST_PATTERN(0, 1, 0, true) || BCAST_PATTERN(1, 1, 0, true)
                        || BCAST_PATTERN(1, 0, 0, sp_not_bcasted)
                        || BCAST_PATTERN(0, 0, 1, !!next_bcast_expected)
                        || BCAST_PATTERN(1, 0, 1, !!next_bcast_expected)
                        || BCAST_PATTERN(1, 1, 1, !!next_bcast_expected));
#undef BCAST_PATTERN
    return ok;
}

// Check for different src formats with the same dims.
// Broadcast is accepted only when src0_dims == src1_dims (i.e. 1 == 1).
bool jit_uni_binary_t::pd_t::is_different_layouts_allowed(
        const memory_desc_wrapper &src0_d,
        const memory_desc_wrapper &src1_d) const {
    const dims_t &src0_dims = src0_d.dims();
    const dims_t &src1_dims = src1_d.dims();
    const int ndims = src0_d.ndims();

    bool without_bcast = true;
    for (int d = 0; d < ndims; d++)
        without_bcast = without_bcast && src0_dims[d] == src1_dims[d];
    if (!without_bcast) return false;

    // allow nchw:nhwc and nhwc:nchw and disable for blocked layouts
    return src0_d.is_plain() && src1_d.is_plain()
            && is_format_non_blocked(src0_d) && is_format_non_blocked(src1_d);
}

bool jit_uni_binary_t::pd_t::is_applicable() {
    const memory_desc_wrapper src0_d(src_md(0));
    const memory_desc_wrapper src1_d(src_md(1));
    const memory_desc_wrapper dst_d(dst_md());
    const auto ndims = src0_d.ndims();

    // Check density first so that identical but non-dense src0 and src1 do
    // not pass the next check.
    bool ok = src0_d.is_dense(true) && src1_d.is_dense(true)
            && dst_d.is_dense(true);
    if (!ok) return false;

    // TODO: fix the implementation for tensors with padding to work with any
    // block size. For now, return unimplemented if there is more than a single
    // blocking or the block size exceeds 16.
    const auto &blk_d = dst_d.blocking_desc();
    if (!dst_d.is_dense()
            && (blk_d.inner_nblks > 1 || blk_d.inner_blks[0] > 16))
        return false;

    const bool is_src_different_layouts = !compare_layouts(src0_d, src1_d);
    const bool different_layouts_allowed
            = is_different_layouts_allowed(src0_d, src1_d);
    if (!conf_.is_i8) {
        const bool has_padding = utils::one_of(true,
                src0_d.nelems(true) != src0_d.nelems(false),
                src1_d.nelems(true) != src1_d.nelems(false),
                dst_d.nelems(true) != dst_d.nelems(false));
        ok = IMPLICATION(has_padding, alg_preserves_zero());
        if (!ok) return false;

        // full tensor operation
        bool same_dims = true;
        const auto &src0_dims = src0_d.dims();
        const auto &src1_dims = src1_d.dims();
        for (int d = 0; d < ndims; d++)
            same_dims = same_dims && src0_dims[d] == src1_dims[d];
        if (same_dims
                && IMPLICATION(
                        is_src_different_layouts, different_layouts_allowed))
            return true;
    } else {
        const dim_t C = ndims >= 2 ? src0_d.dims()[1] : 1;
        const bool has_oc_tail = C != src0_d.padded_dims()[1];
        const bool has_outer_dims_tail = is_src_different_layouts
                && get_outer_dims_product(src0_d.blocking_desc().strides,
                        src0_d.dims(), src0_d.ndims());

        // Disable compare operations for blocked tags with a tail: tail
        // processing is not supported and the vcmpps instruction overwrites
        // the output vector.
        if (utils::one_of(desc()->alg_kind, alg_kind::binary_ge,
                    alg_kind::binary_gt, alg_kind::binary_le,
                    alg_kind::binary_lt, alg_kind::binary_eq,
                    alg_kind::binary_ne)
                && (has_oc_tail || has_outer_dims_tail))
            return false;

        // full tensor operation
        if (src0_d.similar_to(src1_d, true, false, 0)
                || different_layouts_allowed)
            return true;
        // source0 broadcast not supported
        if (!src0_d.similar_to(dst_d, true, false, 0)) return false;
    }
    // broadcast or different layouts operation
    if (!(is_bcast_allowed(ndims)
                && IMPLICATION(
                        is_src_different_layouts, different_layouts_allowed)))
        return false;

    // only nspc and ncsp formats are supported for bcast
    if (src0_d.is_plain() && src1_d.is_plain())
        return is_format_non_blocked(src0_d) && is_format_non_blocked(src1_d);

    // blocked formats
    if (!conf_.is_i8) {
        // check blocking_desc consistency
        const auto valid_bd = [&](const memory_desc_wrapper &mdw) {
            int blksize = 8;
            if (mayiuse(avx512_core)) blksize = 16;
            const auto &bd = mdw.blocking_desc();

            return bd.inner_nblks == 1 && bd.inner_blks[0] == blksize
                    && bd.inner_idxs[0] == 1;
        };

        return valid_bd(src0_d) && valid_bd(src1_d);
    } else {
        const auto &bd0 = src0_d.blocking_desc();
        const auto &bd1 = src1_d.blocking_desc();
        const auto &bcast_dims = broadcast_dims();
        // disable blocked tag for source1 when W is not broadcast
        return bd0.strides[1] == 1 && bd0.inner_nblks == 0
                && IMPLICATION(
                        bcast_dims[ndims - 1] == 0, bd1.inner_nblks == 0);
    }
}

bool jit_uni_binary_t::post_ops_ok(const primitive_attr_t *attr,
        const memory_desc_wrapper &src0_d, const memory_desc_wrapper &dst_d,
        const bool is_src_different_layouts) {
    using namespace primitive_kind;

    const auto &p = attr->post_ops_;
    const auto is_eltwise = [&](int idx) {
        if (p.entry_[idx].is_eltwise()) {
            const auto alg = p.entry_[idx].eltwise.alg;
            return eltwise_injector::is_alg_supported(alg);
        }
        return false;
    };
    const auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
    const bool is_avx512_core = mayiuse(avx512_core);
    const bool is_avx512_core_fp16 = mayiuse(avx512_core_fp16);
    const bool is_i8 = utils::one_of(dst_d.data_type(), s8, u8);

    const auto supported_strategies = get_supported_postops_bcast_strategies();
    for (int i = 0; i < p.len(); i++) {
        if (p.contain(primitive_kind::sum, i)) {
            if (p.entry_[i].sum.zero_point != 0) return false;
            if (src0_d.data_type() != dst_d.data_type()) return false;
        } else if (is_binary(i)) {
            const auto &post_ops_mem = p.entry_[i].binary.src1_desc;
            const bool is_src1_bf16 = post_ops_mem.data_type == data_type::bf16;
            const bool is_src1_f16 = post_ops_mem.data_type == data_type::f16;
            if (is_i8 && (is_src1_bf16 || is_src1_f16)) return false;
            if (!IMPLICATION(is_src1_bf16, is_avx512_core)) return false;
            if (!IMPLICATION(is_src1_f16, is_avx512_core_fp16)) return false;
            if (get_rhs_arg_broadcasting_strategy(
                        post_ops_mem, dst_d, supported_strategies)
                    == broadcasting_strategy_t::no_broadcast) {
                const memory_desc_wrapper post_op_mem_d(post_ops_mem);
                if (!post_op_mem_d.similar_to(dst_d, true, false)) return false;
            }
        } else if (!is_eltwise(i)) {
            return false;
        }
    }

    const int vlen = is_avx512_core ? cpu_isa_traits<avx512_core>::vlen
                                    : cpu_isa_traits<avx2>::vlen;
    const bool postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    p, src0_d, supported_strategies);
    if (postops_per_oc_broadcast_exists && is_src_different_layouts)
        return false;
    const int blksize = vlen / sizeof(float);

    const bool blocked_format = !src0_d.is_plain() && src0_d.is_blocking_desc();

    if (postops_per_oc_broadcast_exists && blocked_format) {
        /*
         * Check blocking_desc consistency: when a per_oc broadcast exists
         * among the post-ops, the binary kernel does not support blocked
         * formats whose block size is smaller than vlen. Example: sse41
         * (vlen of 4 floats) with nChw8c is not supported, while avx2
         * (vlen of 8 floats) with nChw8c is supported.
         */
        const auto blocking_desc = src0_d.blocking_desc();
        if (blocking_desc.inner_nblks != 1
                || blocking_desc.inner_blks[0] != blksize
                || blocking_desc.inner_idxs[0] != 1)
            return false;
    }

    const dim_t n_dims = src0_d.ndims();
    const dim_t &oc = n_dims >= 2 ? src0_d.dims()[1] : 1;

    /*
     * TODO: Remove the limitation and support tails with blocked formats for
     * i8i8.
     */
    const bool blocked_tail = p.len() && blocked_format && oc % blksize;

    return binary_injector::binary_args_broadcast_supported(
                   p, src0_d, get_supported_postops_bcast_strategies())
            && IMPLICATION(
                    utils::one_of(src0_d.data_type(), s8, u8), !blocked_tail)
            && IMPLICATION(postops_per_oc_broadcast_exists,
                    binary_injector::all_binary_postop_rhs_per_oc_broadcast(p,
                            src0_d, supported_strategies,
                            [&src0_d](const memory_desc_wrapper &rhs_arg_md) {
                                return IMPLICATION(!mayiuse(avx2),
                                        src0_d.consistent_with(rhs_arg_md)
                                                || src0_d.is_plain());
                            }));
}

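// Kernel selection: the vector register width follows the channel block size
// (Zmm for blocks of 16, Ymm for 8, Xmm for 4), while plain layouts take the
// widest vectors the ISA offers. For i8 the avx512_core kernel is used even on
// the bf16/fp16 ISAs, and no separate tail kernel is created for it.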
binary_kernel_t *create_binary_kernel(
        const jit_uni_binary_t::pd_t *pd, bool tail_kernel) {
    const auto &conf = pd->get_conf();
    const memory_desc_wrapper src0_d(pd->src_md(0));
    // No support for different blocked memory layouts
    const auto blk_size = src0_d.blocking_desc().inner_blks[0];
    const auto is_plain_layout = src0_d.is_plain();
    switch (conf.isa) {
        case avx512_core_fp16: {
            if (blk_size == 16 || is_plain_layout) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Zmm>;
                return new kernel_t(pd, conf, tail_kernel);
            } else if (blk_size == 8) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Ymm>;
                return new kernel_t(pd, conf, tail_kernel);
            } else if (blk_size == 4) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel);
            }
            break;
        }
        case avx512_core_bf16: {
            if (blk_size == 16 || is_plain_layout) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Zmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 8) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Ymm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 4) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Xmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            }
            break;
        }
        case avx512_core: {
            if (blk_size == 16 || is_plain_layout) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 8) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 4) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            }
            break;
        }
        case avx2: {
            if (blk_size == 8 || is_plain_layout) {
                using kernel_t = jit_uni_binary_kernel_t<avx2, Xbyak::Ymm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            } else if (blk_size == 4) {
                using kernel_t = jit_uni_binary_kernel_t<avx2, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            }
            break;
        }
        case sse41: {
            if (blk_size == 4 || is_plain_layout) {
                using kernel_t = jit_uni_binary_kernel_t<sse41, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            }
            break;
        }
        default: assert(!"Not supported isa");
    }
    assert(!"Could not create binary kernel");
    return nullptr;
}

jit_uni_binary_t::jit_uni_binary_t(const pd_t *apd) : primitive_t(apd) {}

status_t jit_uni_binary_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            kernel_, create_binary_kernel(pd(), false /*tail_kernel*/)));

    if (utils::one_of(pd()->dst_md(0)->data_type, f32, bf16, f16)) {
        const memory_desc_wrapper src0_d(pd_->src_md(0));
        const auto &simd_w = kernel_->simd_w();
        const auto oc = src0_d.ndims() >= 2 ? src0_d.dims()[1] : 1;

        if (op_t::c_blocked == pd()->get_conf().op_type && oc % simd_w) {
            CHECK(safe_ptr_assign(kernel_tail_,
                    create_binary_kernel(pd(), true /*tail_kernel*/)));
            CHECK(kernel_tail_->create_kernel());
        }
    }

    return kernel_->create_kernel();
}

void jit_uni_binary_t::execute_no_bcast_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const bcast_t bcast_type) const {
    const auto kernel = kernel_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());

    const auto &conf = pd()->get_conf();
    const bool is_src_different_layouts = conf.is_src_different_layouts;

    if (is_src_different_layouts) {
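        // Byte offsets of simd_w consecutive src1 elements along its non-unit
        // stride; they are passed to the kernel so it can pick up src1 values
        // that are not contiguous in memory.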
        std::vector<unsigned> indices;

        const dim_t src1_different_layout_stride = conf.src1_stride;
        for (size_t i = 0; i < simd_w; i++)
            indices.push_back(
                    i * src1_different_layout_stride * src1_type_size);

        const dim_t batch = src0_d.dims()[0];
        const dim_t batch_stride = src1_d.blocking_desc().strides[0];
        const dim_t outer_dims = conf.outer_dims;
        const size_t src1_stride_range
                = outer_dims * src1_different_layout_stride * src1_type_size;

        const dim_t nelems_per_aligned_dims
                = src0_d.nelems(true) / (batch * outer_dims);
        const dim_t nelems0_simd = nelems_per_aligned_dims / simd_w;
        const dim_t nelems0_tail = nelems_per_aligned_dims % simd_w;
        const bool has_tail = nelems0_tail > 0;

        const int nthr = dnnl_get_current_num_threads();
        const dim_t thr_per_nelems_group = nstl::min(
                nstl::max(nthr / batch, (dim_t)1), nelems0_simd + has_tail);

        // Compute strategy:
        // Iterate over the batch and over the outer dims.
        // Divide the number of threads by the batch size and limit the result
        // by the number of outer_dims nelems so the work can also be
        // parallelized over them when needed.
        parallel_nd(
                batch, thr_per_nelems_group, [&](dim_t b, dim_t nelems_group) {
                    dim_t start = 0, end = 0;
                    balance211(nelems0_simd + has_tail, thr_per_nelems_group,
                            nelems_group, start, end);
                    if (start >= end) return;

                    const bool ithr_does_tail = has_tail
                            && utils::one_of(nelems0_simd + has_tail, end, 0);
                    const dim_t n_simd_to_do
                            = (end - start - ithr_does_tail) * simd_w;
                    const dim_t tail_to_do = ithr_does_tail * nelems0_tail;
                    const size_t batch_off = batch_stride * b;

                    if (nelems0_simd != 0) {
                        start *= outer_dims;
                        end *= outer_dims;
                    }

                    start *= simd_w;
                    jit_binary_call_s p;
                    p.spat_offt_count = (n_simd_to_do + tail_to_do) * outer_dims
                            * dst_type_size;
                    p.src0 = src0 + (start + batch_off) * src0_type_size;
                    p.src1 = src1
                            + (start / outer_dims + batch_off) * src1_type_size;
                    p.dst = dst + (start + batch_off) * dst_type_size;
                    p.indices = &indices[0];
                    p.src1_stride_range = src1_stride_range;
                    p.scales_src0 = scale0;
                    p.scales_src1 = scale1;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel)(&p);
                });
    } else {
        const dim_t nelems0 = src0_d.nelems(true);
        const dim_t nelems0_simd = nelems0 / simd_w;
        const dim_t nelems0_tail = nelems0 % simd_w;
        const bool has_tail = nelems0_tail > 0;

        const bool point_broadcast = bcast_type == bcast_t::scalar;

        // Compute strategy:
        // Compute number of vectors, divide it equally between all threads.
        // Last one will also handle a tail if present.
        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems0_simd + has_tail, nthr, ithr, start, end);
            if (start >= end) return;

            const bool ithr_does_tail
                    = has_tail && end == nelems0_simd + has_tail;
            const dim_t n_simd_to_do = (end - start - ithr_does_tail) * simd_w;
            const dim_t tail_to_do = ithr_does_tail * nelems0_tail;

            jit_binary_call_s p;
            p.spat_offt_count = (n_simd_to_do + tail_to_do) * dst_type_size;
            p.src0 = src0 + start * simd_w * src0_type_size;
            p.src1 = src1
                    + (point_broadcast ? 0 : (start * simd_w * src1_type_size));
            p.dst = dst + start * simd_w * dst_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

void jit_uni_binary_t::execute_bcast_per_batch_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec) const {

    const auto kernel = kernel_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());

    const dim_t MB = src0_d.dims()[0];
    const dim_t nelems0_per_b = src0_d.nelems(true) / MB;
    const dim_t nelems0_simd = nelems0_per_b / simd_w;
    const dim_t nelems0_tail = nelems0_per_b % simd_w;
    const bool has_tail = nelems0_tail > 0;

    // Compute strategy:
    // Compute number of vectors per batch, divide it equally between all
    // threads. Last one will also handle a tail if present.
    const dim_t nthr = nstl::min(
            nelems0_simd + has_tail, (dim_t)dnnl_get_current_num_threads());
    parallel_nd(MB, nthr, [&](dim_t b, dim_t ithr) {
        dim_t start = 0, end = 0;
        balance211(nelems0_simd + has_tail, nthr, ithr, start, end);
        if (start >= end) return;

        const bool ithr_does_tail = has_tail && end == nelems0_simd + has_tail;
        const dim_t n_simd_to_do = (end - start - ithr_does_tail) * simd_w;
        const dim_t tail_to_do = ithr_does_tail * nelems0_tail;

        jit_binary_call_s p;
        p.spat_offt_count = (n_simd_to_do + tail_to_do) * dst_type_size;
        const dim_t off = start * simd_w;
        p.src0 = src0 + (off + b * nelems0_per_b) * src0_type_size;
        p.src1 = src1 + off * src1_type_size;
        p.dst = dst + (off + b * nelems0_per_b) * dst_type_size;
        p.scales_src0 = scale0;
        p.scales_src1 = scale1;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        p.dst_orig = dst;
        (*kernel)(&p);
    });
}

void jit_uni_binary_t::execute_bcast_per_c_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const op_t op_type, const bcast_t bcast_type,
        const bool blocked_oc_tail) const {
    const auto kernel = kernel_.get();
    const auto kernel_tail = kernel_tail_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const dim_t MB = dims[0];
    const dim_t C = ndims >= 2 ? dims[1] : 1;
    const dim_t SP = ndims >= 3 ? utils::array_product(dims + 2, ndims - 2) : 1;

    const auto &bcast_dims = pd()->broadcast_dims();

    const dim_t nelems_slice_src0
            = utils::array_product(src0_d.padded_dims() + 1, ndims - 1);
    const dim_t nelems_slice_src1 = bcast_type == bcast_t::none
            ? nelems_slice_src0
            : ((bcast_dims[0] == 0) ? utils::array_product(
                       src1_d.padded_dims() + 1, ndims - 1)
                                    : 0);

    if (op_type == op_t::c_blocked) {
        const dim_t C_blocks = std::ceil(
                static_cast<float>(src0_d.padded_dims()[1]) / simd_w);
        // Compute strategy:
        // Each block is individual - parallel over MB and C_blocks safely.

        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_no_tail
                = [&](jit_binary_call_s *p, dim_t C_blk) { (*kernel)(p); };
        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_tail = [&](jit_binary_call_s *p, dim_t C_blk) {
                    if (C_blk == (C_blocks - 1))
                        (*kernel_tail)(p);
                    else
                        (*kernel)(p);
                };
        const auto &kernel_blocked = blocked_oc_tail ? kernel_blocked_tail
                                                     : kernel_blocked_no_tail;
        const auto src1_off = [&](dim_t mb, dim_t C_blk, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::scalar: return mb * nelems_slice_src1;
                case bcast_t::per_batch: return C_blk * SP * simd_w;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1 + C_blk * simd_w;
            }
        };

        parallel_nd(MB, C_blocks, [&](dim_t mb, dim_t C_blk) {
            jit_binary_call_s p;
            p.spat_offt_count = SP * simd_w * dst_type_size;
            const dim_t off = mb * nelems_slice_src0 + C_blk * SP * simd_w;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, C_blk, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            kernel_blocked(&p, C_blk);
        });
    } else if (op_type == op_t::n_spatial_c) {
        const auto src1_off = [&](dim_t mb, dim_t sp, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::per_batch: return sp * C;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1;
            }
        };

        // Compute strategy:
        // Each line of channels is individual, parallel over MB and spatial.
        parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
            jit_binary_call_s p;
            p.spat_offt_count = C * dst_type_size;
            const auto off = mb * nelems_slice_src0 + sp * C;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, sp, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    } else if (op_type == op_t::n_c_spatial) {
        const auto src1_off = [&](dim_t mb, dim_t c, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::scalar: return mb * nelems_slice_src1;
                case bcast_t::per_batch: return c * SP;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1 + c;
            }
        };

        // Compute strategy:
        // Each line of spatial is individual, parallel over MB and C.
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            jit_binary_call_s p;
            p.spat_offt_count = SP * dst_type_size;
            const auto off = mb * nelems_slice_src0 + c * SP;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, c, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

void jit_uni_binary_t::execute_bcast_per_w_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const op_t op_type, const bool blocked_oc_tail) const {
    const auto kernel = kernel_.get();
    const auto kernel_tail = kernel_tail_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const auto &bcast_dims = pd()->broadcast_dims();

    const int not_bcasted_sp_dims = pd()->get_conf().not_bcasted_sp_dims;
    const dim_t MB = dims[0];
    // array product of outer dimensions that are not broadcast
    const dim_t SP_no_bcast = ndims >= 3
            ? utils::array_product(
                    dims + (ndims - not_bcasted_sp_dims), not_bcasted_sp_dims)
            : 1;
    const dim_t C = ndims >= 2 ? dims[1] : 1;
    const dim_t SP = ndims >= 3 ? utils::array_product(dims + 2, ndims - 2) : 1;
    // product of the spatial dims that src1 broadcasts over, i.e. SP without
    // the non-broadcast dims
    const dim_t N = SP / SP_no_bcast;

    const dim_t nelems_slice_src0
            = utils::array_product(src0_d.padded_dims() + 1, ndims - 1);

    if (op_type == op_t::c_blocked) {
        const dim_t C_blocks = std::ceil(
                static_cast<float>(src0_d.padded_dims()[1]) / simd_w);
        // Compute strategy:
        // Each line of channels is individual, parallel over MB, C_blocks
        // and spatial (broadcasted and not broadcasted spatial dims
        // separately).

        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_no_tail
                = [&](jit_binary_call_s *p, dim_t C_blk) { (*kernel)(p); };
        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_tail = [&](jit_binary_call_s *p, dim_t C_blk) {
                    if (C_blk == (C_blocks - 1))
                        (*kernel_tail)(p);
                    else
                        (*kernel)(p);
                };
        const auto &kernel_blocked = blocked_oc_tail ? kernel_blocked_tail
                                                     : kernel_blocked_no_tail;

        parallel_nd(MB, C_blocks, N, SP_no_bcast,
                [&](dim_t mb, dim_t C_blk, dim_t n, dim_t sp) {
                    jit_binary_call_s p;
                    p.spat_offt_count = simd_w * dst_type_size;
                    const auto off = mb * nelems_slice_src0
                            + simd_w * (C_blk * SP + n * SP_no_bcast + sp);
                    p.dst = dst + off * dst_type_size;
                    p.src0 = src0 + off * src0_type_size;
                    // check if mb is broadcast
                    const dim_t src1_off = bcast_dims[0] == 1
                            ? sp * simd_w
                            : (mb * SP_no_bcast + sp) * simd_w;
                    p.src1 = src1 + src1_off * src1_type_size;
                    p.scales_src0 = scale0;
                    p.scales_src1 = scale1;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    kernel_blocked(&p, C_blk);
                });
    } else if (op_type == op_t::n_spatial_c) {
        // Compute strategy:
        // Each line of channels is individual, parallel over MB and spatial
        // (broadcasted and not broadcasted spatial dims separately).

        parallel_nd(MB, N, SP_no_bcast, [&](dim_t mb, dim_t n, dim_t sp) {
            jit_binary_call_s p;
            p.spat_offt_count = C * dst_type_size;
            const auto off
                    = mb * nelems_slice_src0 + n * SP_no_bcast * C + sp * C;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            const dim_t src1_off
                    = bcast_dims[0] == 1 ? sp : mb * SP_no_bcast + sp;
            p.src1 = src1 + src1_off * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    } else if (op_type == op_t::n_c_spatial) {
        // Compute strategy:
        // Each line of width is individual; parallelize over MB, C and the
        // broadcast spatial dims (excluding the non-broadcast ones). Use a
        // kernel which broadcasts the c_i value into a vector register.

        parallel_nd(MB, C, N, [&](dim_t mb, dim_t c, dim_t n) {
            jit_binary_call_s p;
            p.spat_offt_count = SP_no_bcast * dst_type_size;
            const auto off = mb * nelems_slice_src0 + c * N * SP_no_bcast
                    + n * SP_no_bcast;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            const dim_t src1_off = bcast_dims[0] == 1 ? 0 : mb * SP_no_bcast;
            p.src1 = src1 + src1_off * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

status_t jit_uni_binary_t::execute(const exec_ctx_t &ctx) const {
    const auto src0 = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC_0);
    const auto src1 = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC_1);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    const auto &post_ops = pd()->attr()->post_ops_;
    const auto &post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(post_ops, ctx);
    const float *scales[2];
    ASSIGN_ARG_SCALE_VALUE(scales[0], DNNL_ARG_SRC_0);
    ASSIGN_ARG_SCALE_VALUE(scales[1], DNNL_ARG_SRC_1);

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const dim_t C = ndims >= 2 ? dims[1] : 0;

    const bool postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    post_ops, src0_d, get_supported_postops_bcast_strategies());
    const auto &bcast_type = pd()->get_conf().bcast_type;
    const bool point_broadcast = bcast_type == bcast_t::scalar;
    const auto &op_type = pd()->get_conf().op_type;
    const bool with_postops = !post_ops.entry_.empty();
    const auto &simd_w = kernel_->simd_w();
    const bool has_oc_tail = C % simd_w;
    const bool point_broadcast_no_oc_tail = point_broadcast && !has_oc_tail;
    const auto alg = pd()->desc()->alg_kind;
    // Use the kernel_tail strategy for comparison ops (e.g. GreaterEqual) with
    // an oc tail and a blocked format, because vcmpps overwrites the vector
    // tail.
    const bool vector_overwrite = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    const bool blocked_oc_tail = op_type == op_t::c_blocked && has_oc_tail
            && (with_postops || point_broadcast || bcast_type == bcast_t::per_w
                    || vector_overwrite);

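    // Dispatch to the execution strategy matching the broadcast pattern: the
    // dense/no-broadcast path, the per-batch path, the per-W path, or the
    // generic per-C path, which also covers blocked oc tails and per_oc
    // post-op broadcasts.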
    if ((bcast_type == bcast_t::none || point_broadcast_no_oc_tail)
            && !postops_per_oc_broadcast_exists && !blocked_oc_tail)
        execute_no_bcast_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, bcast_type);
    else if (bcast_type == bcast_t::per_batch
            && !postops_per_oc_broadcast_exists && !blocked_oc_tail)
        execute_bcast_per_batch_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec);
    else if (bcast_type == bcast_t::per_w)
        execute_bcast_per_w_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, op_type, blocked_oc_tail);
    else
        execute_bcast_per_c_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, op_type, bcast_type,
                blocked_oc_tail);

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl