/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cmath> // for std::ceil used below
#include <functional>

#include "common/dnnl_thread.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/x64/jit_uni_binary.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

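// Broadcasting strategies accepted for the binary post-ops of this primitive.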
static bcast_set_t get_supported_postops_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
            broadcasting_strategy_t::per_oc_spatial,
            broadcasting_strategy_t::no_broadcast};
}

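// Returns true when the layout comparison is not a blocker: either src1 is
// broadcast over some dim (its strides are then irrelevant) or both
// descriptors have identical strides. Returns false for genuinely different
// layouts (e.g. nchw vs nhwc).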
static bool compare_layouts(const memory_desc_wrapper &src0_md,
        const memory_desc_wrapper &src1_md) {
    const strides_t &strides0 = src0_md.blocking_desc().strides;
    const strides_t &strides1 = src1_md.blocking_desc().strides;
    const dims_t &dims0 = src0_md.dims();
    const dims_t &dims1 = src1_md.dims();
    const int ndims = src0_md.ndims();

    bool is_bcast = false;
    for (int d = 1; d < ndims; d++)
        is_bcast = is_bcast || dims0[d] != dims1[d];
    if (is_bcast) return true;

    bool same_layouts = true;
    for (int d = 0; d < ndims; ++d)
        same_layouts = same_layouts && strides0[d] == strides1[d];
    return same_layouts;
}

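// For the different-layouts case: returns the src1 stride along the dim that
// is dense (stride 1) in src0, falling back to the innermost src1 stride.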
static dim_t get_different_layout_stride(
        const strides_t &strides0, const strides_t &strides1, const int ndims) {
    for (int d = 0; d < ndims; d++)
        if (strides0[d] == 1) return strides1[d];
    return strides1[ndims - 1];
}

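// Product of the dims treated as "outer" when src0 and src1 have different
// plain layouts (e.g. nchw vs nhwc).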
static dim_t get_outer_dims_product(
        const strides_t &strides0, const dims_t &dims, const int ndims) {
    // nchw:nhwc->nchw
    if (strides0[1] == 1) return dims[1];
    // nhwc:nchw->nhwc
    else if (strides0[ndims - 1] == 1)
        return utils::array_product(dims + 2, ndims - 2);
    else
        return dims[ndims - 1];
}

using namespace data_type;

static bool data_type_supported(const data_type_t dtype) {
    return utils::one_of(dtype, f32, bf16, f16, s8, u8);
}

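// Pick the highest ISA available at runtime, checked from newest to oldest.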
static cpu_isa_t get_supported_isa() {
    if (mayiuse(avx512_core_fp16)) return avx512_core_fp16;
    if (mayiuse(avx512_core_bf16)) return avx512_core_bf16;
    if (mayiuse(avx512_core)) return avx512_core;
    if (mayiuse(avx2)) return avx2;
    if (mayiuse(sse41)) return sse41;

    return isa_undef;
}

static bool data_format_supported(
        const memory_desc_wrapper &mdw, const cpu_isa_t isa) {
    if (mdw.is_plain()) return true;
    const auto blk_size = mdw.blocking_desc().inner_blks[0];
    return (is_superset(isa, avx512_core) && utils::one_of(blk_size, 16, 8, 4))
            || (is_superset(isa, avx2) && utils::one_of(blk_size, 8, 4))
            || (is_superset(isa, sse41) && blk_size == 4);
}

status_t jit_uni_binary_t::pd_t::init(engine_t *engine) {
    using sm = primitive_attr_t::skip_mask_t;

    conf_.dst_type = dst_md()->data_type;
    conf_.src0_type = src_md(0)->data_type;
    conf_.src1_type = src_md(1)->data_type;

    memory_desc_wrapper dst_md_(dst_md());
    memory_desc_wrapper src0_md_(src_md(0));
    memory_desc_wrapper src1_md_(src_md(1));

    const auto &po = attr()->post_ops_;
    const int elt_idx = po.find(primitive_kind::eltwise);
    conf_.is_i8 = utils::one_of(conf_.dst_type, s8, u8);

    conf_.isa = get_supported_isa();

    bool ok = data_type_supported(conf_.dst_type)
            && data_type_supported(conf_.src0_type)
            && data_type_supported(conf_.src1_type)
            && data_format_supported(src0_md_, conf_.isa)
            && IMPLICATION(conf_.src0_type == bf16, mayiuse(avx512_core))
            && IMPLICATION(utils::one_of(f16, conf_.src0_type, conf_.src1_type,
                                   conf_.dst_type),
                    mayiuse(avx512_core_fp16))
            && set_default_params() == status::success && !has_zero_dim_memory()
            && IMPLICATION(!conf_.is_i8, src0_md_ == dst_md_) && is_applicable()
            && attr()->has_default_values(sm::post_ops | sm::scales_runtime)
            && attr_.set_default_formats(dst_md(0)) == status::success;
    if (!ok) return status::unimplemented;

    // All operations over blocking descriptors should have md initialized.
    conf_.is_src_different_layouts = !compare_layouts(src0_md_, src1_md_);
    ok = post_ops_ok(
                 attr(), src_md(0), dst_md(), conf_.is_src_different_layouts)
            && (conf_.is_i8 || elt_idx == -1
                    || IMPLICATION(!dst_md_.is_dense(),
                            cpu_eltwise_fwd_pd_t::eltwise_preserves_zero(
                                    po.entry_[elt_idx].eltwise)))
            && IMPLICATION((!attr()->scales_.has_default_values()),
                    check_scales_mask())
            && (conf_.is_i8
                    || IMPLICATION(!mayiuse(avx2),
                            src0_md_.consistent_with(src1_md_)
                                    || src0_md_.is_plain()));

    if (!ok) return status::unimplemented;

    conf_.postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    po, src0_md_, get_supported_postops_bcast_strategies());
    conf_.is_bf16 = conf_.dst_type == bf16;
    conf_.is_f16 = conf_.dst_type == f16;
    conf_.op_type = get_op_type(src0_md_);
    assert(conf_.op_type != op_t::none);
    conf_.do_scale_src0 = !attr()->scales_.get(DNNL_ARG_SRC_0).defined()
            || !attr()->scales_.get(DNNL_ARG_SRC_0).has_default_values();
    conf_.do_scale_src1 = !attr()->scales_.get(DNNL_ARG_SRC_1).defined()
            || !attr()->scales_.get(DNNL_ARG_SRC_1).has_default_values();
    const auto sum_idx = po.find(primitive_kind::sum);
    conf_.do_sum = sum_idx != -1 && po.entry_[sum_idx].sum.scale != 0.f;
    conf_.with_eltwise = po.find(primitive_kind::eltwise) != -1;
    conf_.with_binary = po.find(primitive_kind::binary) != -1;
    conf_.with_postops
            = conf_.with_binary || conf_.with_eltwise || conf_.do_sum;
    conf_.sum_scale = conf_.do_sum ? po.entry_[sum_idx].sum.scale : 0.f;
    const auto &bcast_dims = broadcast_dims();
    conf_.bcast_type = is_tensor_op() ? bcast_t::none
                                      : get_bcast_type(src1_md_, bcast_dims);
    conf_.broadcast_src1_value = (conf_.op_type == op_t::n_c_spatial
                                         && conf_.bcast_type == bcast_t::per_c)
            || (utils::one_of(conf_.op_type, op_t::n_spatial_c, op_t::c_blocked)
                    && conf_.bcast_type == bcast_t::per_w)
            || conf_.bcast_type == bcast_t::scalar;
    conf_.use_stride_src1 = !conf_.broadcast_src1_value
            && (utils::one_of(
                        conf_.bcast_type, bcast_t::none, bcast_t::per_batch)
                    || (conf_.op_type == op_t::n_spatial_c
                            && conf_.bcast_type == bcast_t::per_c)
                    || (conf_.op_type == op_t::n_c_spatial
                            && conf_.bcast_type == bcast_t::per_w));
    conf_.use_stride_rhs_postops = conf_.postops_per_oc_broadcast_exists
            && conf_.op_type == op_t::n_spatial_c;

    const auto ndims = src0_md_.ndims();
    if (conf_.is_src_different_layouts) {
        const auto &strides0 = src0_md_.blocking_desc().strides;
        const auto &strides1 = src1_md_.blocking_desc().strides;
        conf_.src1_stride
                = get_different_layout_stride(strides0, strides1, ndims);
        conf_.outer_dims
                = get_outer_dims_product(strides0, src0_md_.dims(), ndims);
    }
    if (conf_.bcast_type == bcast_t::per_w) {
        for (int d = 2; d < ndims; ++d)
            conf_.not_bcasted_sp_dims += !bcast_dims[d];
    }

    return status::success;
}

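// Classify the src0 layout: channel-blocked (e.g. nChw16c), channels-last
// (dense channels, nxc) or plain with dense spatial (ncx).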
op_t jit_uni_binary_t::pd_t::get_op_type(const memory_desc_wrapper &src0_d) {
    const auto &strides = src0_d.blocking_desc().strides;
    const auto ndims = src0_d.ndims();

    if (!src0_d.is_plain() && src0_d.blocking_desc().inner_idxs[0] == 1)
        return op_t::c_blocked;
    else if (strides[1] == 1)
        return op_t::n_spatial_c;
    else if (strides[0] >= strides[1]
            && IMPLICATION(ndims >= 3, strides[1] >= strides[2]))
        return op_t::n_c_spatial;
    return op_t::none;
}

bool jit_uni_binary_t::pd_t::is_only_dim0_bcasted(
        const dims_t &bcast_dims, const int ndims) {
    bool only_dim0_bcasted = true;
    for (int d = 1; d < ndims; d++)
        only_dim0_bcasted = only_dim0_bcasted && bcast_dims[d] == 0;
    return only_dim0_bcasted;
}

// non-blocked: nxc || ncx
bool jit_uni_binary_t::pd_t::is_format_non_blocked(
        const memory_desc_wrapper &mdw) const {
    const auto &dims = mdw.dims();
    const auto &strides = mdw.blocking_desc().strides;
    const auto &ndims = mdw.ndims();

    const bool is_ncx
            = IMPLICATION(strides[0] != 0,
                      strides[0] >= utils::array_product(dims + 1, ndims - 1))
            && IMPLICATION(ndims >= 3 && strides[1] != 0,
                    strides[1] >= utils::array_product(dims + 2, ndims - 2))
            && IMPLICATION(ndims >= 4 && strides[2] != 0,
                    strides[2] >= utils::array_product(dims + 3, ndims - 3))
            && IMPLICATION(ndims >= 5 && strides[3] != 0,
                    strides[3] >= utils::array_product(dims + 4, ndims - 4))
            && IMPLICATION(strides[ndims - 1] != 0, strides[ndims - 1] == 1);
    const bool is_nxc
            = IMPLICATION(strides[0] != 0,
                      strides[0] >= utils::array_product(dims + 1, ndims - 1))
            && IMPLICATION(ndims >= 3 && strides[2] != 0,
                    strides[2] >= dims[1]
                                    * utils::array_product(dims + 3, ndims - 3))
            && IMPLICATION(ndims >= 4 && strides[3] != 0,
                    strides[3] >= dims[1]
                                    * utils::array_product(dims + 4, ndims - 4))
            && IMPLICATION(ndims >= 5 && strides[4] != 0,
                    strides[4] >= dims[1]
                                    * utils::array_product(dims + 5, ndims - 5))
            && IMPLICATION(strides[1] != 0, strides[1] == 1);
    return is_nxc || is_ncx;
}

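// A broadcast_dims() entry of 1 marks a dim over which src1 is broadcast.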
bcast_t jit_uni_binary_t::pd_t::get_bcast_type(
        const memory_desc_wrapper &src1_d, const dims_t &bcast_dims) {
    if (src1_d.nelems() == 1)
        return bcast_t::scalar;
    else if (bcast_dims[1] == 1)
        return bcast_t::per_w;
    else if (is_only_dim0_bcasted(bcast_dims, src1_d.ndims()))
        return bcast_t::per_batch;
    else
        return bcast_t::per_c;
}

bool jit_uni_binary_t::pd_t::alg_preserves_zero() const {
    using namespace utils;
    using namespace alg_kind;
    return utils::one_of(desc()->alg_kind, binary_add, binary_max, binary_min,
            binary_mul, binary_sub, binary_ge, binary_gt, binary_le, binary_lt,
            binary_eq, binary_ne);
}

bool jit_uni_binary_t::pd_t::check_scales_mask() const {
    for (const auto &s : attr()->scales_.scales_) {
        if (s.second.mask_ != 0) return false;
    }
    return true;
}

bool jit_uni_binary_t::pd_t::is_bcast_pattern(const dims_t &bcast_dims,
        const dim_t ndims, const dim_t N_bcast, const dim_t C_bcast,
        const dim_t W_bcast) const {
    return bcast_dims[0] == N_bcast && bcast_dims[1] == C_bcast
            && bcast_dims[ndims - 1] == W_bcast;
}

bool jit_uni_binary_t::pd_t::is_bcast_allowed(const int ndims) const {
    // supported cases: NxCxDxHxW:{NxCx1x1x1,1xCx1x1x1,Nx1xDxHxW,Nx1x1xHxW,
    //                             Nx1x1x1xW,1xCxDxHxW,1x1xDxHxW,1x1x1xHxW,
    //                             1x1x1x1xW,1x1x1x1x1}
    const auto &bcast_dims = broadcast_dims();
    // Check that the broadcast spatial dims are contiguous and precede the
    // non-broadcast ones; next_bcast_expected == 1 means no non-broadcast
    // dim has been met yet.
    int next_bcast_expected = 1;
    bool sp_not_bcasted = true;
    bool ok = true;
    for (int d = 2; d < ndims; ++d) {
        if (bcast_dims[d] == 0)
            next_bcast_expected = 0;
        else
            sp_not_bcasted = false;
        ok = ok && bcast_dims[d] == next_bcast_expected;
    }

#define BCAST_PATTERN(N, C, W, condition) \
    (is_bcast_pattern(bcast_dims, ndims, N, C, W) && (condition))
    if (ndims > 2)
        ok = ok
                && (BCAST_PATTERN(0, 1, 0, true) || BCAST_PATTERN(1, 1, 0, true)
                        || BCAST_PATTERN(1, 0, 0, sp_not_bcasted)
                        || BCAST_PATTERN(0, 0, 1, !!next_bcast_expected)
                        || BCAST_PATTERN(1, 0, 1, !!next_bcast_expected)
                        || BCAST_PATTERN(1, 1, 1, !!next_bcast_expected));
#undef BCAST_PATTERN
    return ok;
}

// Check for different src formats with the same dims.
// Broadcast is accepted only when src0_dims == src1_dims (broadcast dims
// must match as 1 == 1).
bool jit_uni_binary_t::pd_t::is_different_layouts_allowed(
        const memory_desc_wrapper &src0_d,
        const memory_desc_wrapper &src1_d) const {
    const dims_t &src0_dims = src0_d.dims();
    const dims_t &src1_dims = src1_d.dims();
    const int ndims = src0_d.ndims();

    bool without_bcast = true;
    for (int d = 0; d < ndims; d++)
        without_bcast = without_bcast && src0_dims[d] == src1_dims[d];
    if (!without_bcast) return false;

    // allow nchw:nhwc and nhwc:nchw and disable for blocked layouts
    return src0_d.is_plain() && src1_d.is_plain()
            && is_format_non_blocked(src0_d) && is_format_non_blocked(src1_d);
}

bool jit_uni_binary_t::pd_t::is_applicable() {
    const memory_desc_wrapper src0_d(src_md(0));
    const memory_desc_wrapper src1_d(src_md(1));
    const memory_desc_wrapper dst_d(dst_md());
    const auto ndims = src0_d.ndims();

    // Check density first to avoid equally non-dense src0 and src1 passing
    // the next check.
    bool ok = src0_d.is_dense(true) && src1_d.is_dense(true)
            && dst_d.is_dense(true);
    if (!ok) return false;

    // TODO: fix the implementation for tensors with padding to work with any
    // block size. For now, return unimplemented if there is more than a
    // single level of blocking or `block size > 16`.
    const auto &blk_d = dst_d.blocking_desc();
    if (!dst_d.is_dense()
            && (blk_d.inner_nblks > 1 || blk_d.inner_blks[0] > 16))
        return false;

    const bool is_src_different_layouts = !compare_layouts(src0_d, src1_d);
    const bool different_layouts_allowed
            = is_different_layouts_allowed(src0_d, src1_d);
    if (!conf_.is_i8) {
        const bool has_padding = utils::one_of(true,
                src0_d.nelems(true) != src0_d.nelems(false),
                src1_d.nelems(true) != src1_d.nelems(false),
                dst_d.nelems(true) != dst_d.nelems(false));
        ok = IMPLICATION(has_padding, alg_preserves_zero());
        if (!ok) return false;

        // full tensor operation
        bool same_dims = true;
        const auto &src0_dims = src0_d.dims();
        const auto &src1_dims = src1_d.dims();
        for (int d = 0; d < ndims; d++)
            same_dims = same_dims && src0_dims[d] == src1_dims[d];
        if (same_dims
                && IMPLICATION(
                        is_src_different_layouts, different_layouts_allowed))
            return true;
    } else {
        const dim_t C = ndims >= 2 ? src0_d.dims()[1] : 1;
        const bool has_oc_tail = C != src0_d.padded_dims()[1];
        const bool has_outer_dims_tail = is_src_different_layouts
                && get_outer_dims_product(src0_d.blocking_desc().strides,
                        src0_d.dims(), src0_d.ndims());

        // Disable compare operations for a blocked tag with a tail.
        // Tail processing is not supported there, and the vcmpps instruction
        // overwrites the output vector.
        if (utils::one_of(desc()->alg_kind, alg_kind::binary_ge,
                    alg_kind::binary_gt, alg_kind::binary_le,
                    alg_kind::binary_lt, alg_kind::binary_eq,
                    alg_kind::binary_ne)
                && (has_oc_tail || has_outer_dims_tail))
            return false;

        // full tensor operation
        if (src0_d.similar_to(src1_d, true, false, 0)
                || different_layouts_allowed)
            return true;
        // source0 broadcast not supported
        if (!src0_d.similar_to(dst_d, true, false, 0)) return false;
    }
    // broadcast or different layouts operation
    if (!(is_bcast_allowed(ndims)
                && IMPLICATION(
                        is_src_different_layouts, different_layouts_allowed)))
        return false;

    // only nspc and ncsp formats are supported for bcast
    if (src0_d.is_plain() && src1_d.is_plain())
        return is_format_non_blocked(src0_d) && is_format_non_blocked(src1_d);

    // blocked formats
    if (!conf_.is_i8) {
        // check blocking_desc consistency
        const auto valid_bd = [&](const memory_desc_wrapper &mdw) {
            int blksize = 8;
            if (mayiuse(avx512_core)) blksize = 16;
            const auto &bd = mdw.blocking_desc();

            return bd.inner_nblks == 1 && bd.inner_blks[0] == blksize
                    && bd.inner_idxs[0] == 1;
        };

        return valid_bd(src0_d) && valid_bd(src1_d);
    } else {
        const auto &bd0 = src0_d.blocking_desc();
        const auto &bd1 = src1_d.blocking_desc();
        const auto &bcast_dims = broadcast_dims();
        // disable blocked tag for source1 when W is not broadcast
        return bd0.strides[1] == 1 && bd0.inner_nblks == 0
                && IMPLICATION(
                        bcast_dims[ndims - 1] == 0, bd1.inner_nblks == 0);
    }
}

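// Validate attached post-ops: sum must keep the dst data type and a zero
// zero-point, eltwise algs must be supported by the injector, and binary rhs
// args must use one of the supported broadcasting strategies.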
bool jit_uni_binary_t::post_ops_ok(const primitive_attr_t *attr,
        const memory_desc_wrapper &src0_d, const memory_desc_wrapper &dst_d,
        const bool is_src_different_layouts) {
    using namespace primitive_kind;

    const auto &p = attr->post_ops_;
    const auto is_eltwise = [&](int idx) {
        if (p.entry_[idx].is_eltwise()) {
            const auto alg = p.entry_[idx].eltwise.alg;
            return eltwise_injector::is_alg_supported(alg);
        }
        return false;
    };
    const auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
    const bool is_avx512_core = mayiuse(avx512_core);
    const bool is_avx512_core_fp16 = mayiuse(avx512_core_fp16);
    const bool is_i8 = utils::one_of(dst_d.data_type(), s8, u8);

    const auto supported_strategies = get_supported_postops_bcast_strategies();
    for (int i = 0; i < p.len(); i++) {
        if (p.contain(primitive_kind::sum, i)) {
            if (p.entry_[i].sum.zero_point != 0) return false;
            if (src0_d.data_type() != dst_d.data_type()) return false;
        } else if (is_binary(i)) {
            const auto &post_ops_mem = p.entry_[i].binary.src1_desc;
            const bool is_src1_bf16 = post_ops_mem.data_type == data_type::bf16;
            const bool is_src1_f16 = post_ops_mem.data_type == data_type::f16;
            if (is_i8 && (is_src1_bf16 || is_src1_f16)) return false;
            if (!IMPLICATION(is_src1_bf16, is_avx512_core)) return false;
            if (!IMPLICATION(is_src1_f16, is_avx512_core_fp16)) return false;
            if (get_rhs_arg_broadcasting_strategy(
                        post_ops_mem, dst_d, supported_strategies)
                    == broadcasting_strategy_t::no_broadcast) {
                const memory_desc_wrapper post_op_mem_d(post_ops_mem);
                if (!post_op_mem_d.similar_to(dst_d, true, false)) return false;
            }
        } else if (!is_eltwise(i)) {
            return false;
        }
    }

    const int vlen = is_avx512_core ? cpu_isa_traits<avx512_core>::vlen
                                    : cpu_isa_traits<avx2>::vlen;
    const bool postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    p, src0_d, supported_strategies);
    if (postops_per_oc_broadcast_exists && is_src_different_layouts)
        return false;
    const int blksize = vlen / sizeof(float);

    const bool blocked_format = !src0_d.is_plain() && src0_d.is_blocking_desc();

    if (postops_per_oc_broadcast_exists && blocked_format) {
        /*
         * Check blocking_desc consistency: when a per_oc broadcast exists
         * among the post-ops, the binary kernel does not support cases where
         * the blocked format size differs from vlen. Example: sse41 vlen is
         * 4 floats and the format is nChw8c - not supported; avx2 vlen is
         * 8 floats and the format is nChw8c - supported.
         */
        const auto blocking_desc = src0_d.blocking_desc();
        if (blocking_desc.inner_nblks != 1
                || blocking_desc.inner_blks[0] != blksize
                || blocking_desc.inner_idxs[0] != 1)
            return false;
    }

    const dim_t n_dims = src0_d.ndims();
    const dim_t &oc = n_dims >= 2 ? src0_d.dims()[1] : 1;

    /*
     * TODO: Remove limitation supporting tail with blocked format for i8i8
     */
    const bool blocked_tail = p.len() && blocked_format && oc % blksize;

    return binary_injector::binary_args_broadcast_supported(
                   p, src0_d, get_supported_postops_bcast_strategies())
            && IMPLICATION(
                    utils::one_of(src0_d.data_type(), s8, u8), !blocked_tail)
            && IMPLICATION(postops_per_oc_broadcast_exists,
                    binary_injector::all_binary_postop_rhs_per_oc_broadcast(p,
                            src0_d, supported_strategies,
                            [&src0_d](const memory_desc_wrapper &rhs_arg_md) {
                                return IMPLICATION(!mayiuse(avx2),
                                        src0_d.consistent_with(rhs_arg_md)
                                                || src0_d.is_plain());
                            }));
}

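// Instantiate the kernel for the chosen ISA and vector register type; the
// register width is matched to the blocked format's inner block (full width
// for plain layouts).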
binary_kernel_t *create_binary_kernel(
        const jit_uni_binary_t::pd_t *pd, bool tail_kernel) {
    const auto &conf = pd->get_conf();
    const memory_desc_wrapper src0_d(pd->src_md(0));
    // No support for different blocked memory layouts
    const auto blk_size = src0_d.blocking_desc().inner_blks[0];
    const auto is_plain_layout = src0_d.is_plain();
    switch (conf.isa) {
        case avx512_core_fp16: {
            if (blk_size == 16 || is_plain_layout) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Zmm>;
                return new kernel_t(pd, conf, tail_kernel);
            } else if (blk_size == 8) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Ymm>;
                return new kernel_t(pd, conf, tail_kernel);
            } else if (blk_size == 4) {
                using kernel_t
                        = jit_uni_binary_kernel_t<avx512_core_fp16, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel);
            }
            break;
        }
        case avx512_core_bf16: {
            if (blk_size == 16 || is_plain_layout) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Zmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 8) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Ymm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 4) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t = jit_uni_binary_kernel_t<avx512_core_bf16,
                            Xbyak::Xmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            }
            break;
        }
        case avx512_core: {
            if (blk_size == 16 || is_plain_layout) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Zmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 8) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Ymm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            } else if (blk_size == 4) {
                if (conf.is_i8) {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, false);
                } else {
                    using kernel_t
                            = jit_uni_binary_kernel_t<avx512_core, Xbyak::Xmm>;
                    return new kernel_t(pd, conf, tail_kernel);
                }
            }
            break;
        }
        case avx2: {
            if (blk_size == 8 || is_plain_layout) {
                using kernel_t = jit_uni_binary_kernel_t<avx2, Xbyak::Ymm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            } else if (blk_size == 4) {
                using kernel_t = jit_uni_binary_kernel_t<avx2, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            }
            break;
        }
        case sse41: {
            if (blk_size == 4 || is_plain_layout) {
                using kernel_t = jit_uni_binary_kernel_t<sse41, Xbyak::Xmm>;
                return new kernel_t(pd, conf, tail_kernel && !conf.is_i8);
            }
            break;
        }
        default: assert(!"Unsupported isa");
    }
    assert(!"Could not create binary kernel");
    return nullptr;
}

jit_uni_binary_t::jit_uni_binary_t(const pd_t *apd) : primitive_t(apd) {}

status_t jit_uni_binary_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            kernel_, create_binary_kernel(pd(), false /*tail_kernel*/)));

    if (utils::one_of(pd()->dst_md(0)->data_type, f32, bf16, f16)) {
        const memory_desc_wrapper src0_d(pd_->src_md(0));
        const auto &simd_w = kernel_->simd_w();
        const auto oc = src0_d.ndims() >= 2 ? src0_d.dims()[1] : 1;

        if (op_t::c_blocked == pd()->get_conf().op_type && oc % simd_w) {
            CHECK(safe_ptr_assign(kernel_tail_,
                    create_binary_kernel(pd(), true /*tail_kernel*/)));
            CHECK(kernel_tail_->create_kernel());
        }
    }

    return kernel_->create_kernel();
}

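// Dense full-tensor path (also scalar broadcast): both inputs are walked
// linearly; a separate branch handles src0/src1 with different plain layouts
// through an index table for the strided src1 accesses.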
void jit_uni_binary_t::execute_no_bcast_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const bcast_t bcast_type) const {
    const auto kernel = kernel_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());

    const auto &conf = pd()->get_conf();
    const bool is_src_different_layouts = conf.is_src_different_layouts;

    if (is_src_different_layouts) {
        std::vector<unsigned> indices;

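        // Byte offsets of simd_w consecutive strided src1 elements; the
        // kernel consumes them through p.indices.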
        const dim_t src1_different_layout_stride = conf.src1_stride;
        for (size_t i = 0; i < simd_w; i++)
            indices.push_back(
                    i * src1_different_layout_stride * src1_type_size);

        const dim_t batch = src0_d.dims()[0];
        const dim_t batch_stride = src1_d.blocking_desc().strides[0];
        const dim_t outer_dims = conf.outer_dims;
        const size_t src1_stride_range
                = outer_dims * src1_different_layout_stride * src1_type_size;

        const dim_t nelems_per_aligned_dims
                = src0_d.nelems(true) / (batch * outer_dims);
        const dim_t nelems0_simd = nelems_per_aligned_dims / simd_w;
        const dim_t nelems0_tail = nelems_per_aligned_dims % simd_w;
        const bool has_tail = nelems0_tail > 0;

        const int nthr = dnnl_get_current_num_threads();
        const dim_t thr_per_nelems_group = nstl::min(
                nstl::max(nthr / batch, (dim_t)1), nelems0_simd + has_tail);

        // Compute strategy:
        // Iterate over batch and over outer dims.
        // Divide the number of threads by the batch size, capping it at the
        // number of vector groups, so that work within a batch item is
        // parallelized when needed.
        parallel_nd(
                batch, thr_per_nelems_group, [&](dim_t b, dim_t nelems_group) {
                    dim_t start = 0, end = 0;
                    balance211(nelems0_simd + has_tail, thr_per_nelems_group,
                            nelems_group, start, end);
                    if (start >= end) return;

                    const bool ithr_does_tail = has_tail
                            && utils::one_of(nelems0_simd + has_tail, end, 0);
                    const dim_t n_simd_to_do
                            = (end - start - ithr_does_tail) * simd_w;
                    const dim_t tail_to_do = ithr_does_tail * nelems0_tail;
                    const size_t batch_off = batch_stride * b;

                    if (nelems0_simd != 0) {
                        start *= outer_dims;
                        end *= outer_dims;
                    }

                    start *= simd_w;
                    jit_binary_call_s p;
                    p.spat_offt_count = (n_simd_to_do + tail_to_do) * outer_dims
                            * dst_type_size;
                    p.src0 = src0 + (start + batch_off) * src0_type_size;
                    p.src1 = src1
                            + (start / outer_dims + batch_off) * src1_type_size;
                    p.dst = dst + (start + batch_off) * dst_type_size;
                    p.indices = &indices[0];
                    p.src1_stride_range = src1_stride_range;
                    p.scales_src0 = scale0;
                    p.scales_src1 = scale1;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel)(&p);
                });
    } else {
        const dim_t nelems0 = src0_d.nelems(true);
        const dim_t nelems0_simd = nelems0 / simd_w;
        const dim_t nelems0_tail = nelems0 % simd_w;
        const bool has_tail = nelems0_tail > 0;

        const bool point_broadcast = bcast_type == bcast_t::scalar;

        // Compute strategy:
        // Compute the number of vectors and divide it equally between all
        // threads. The thread that gets the last chunk also handles the tail
        // if present.
        parallel(0, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems0_simd + has_tail, nthr, ithr, start, end);
            if (start >= end) return;

            const bool ithr_does_tail
                    = has_tail && end == nelems0_simd + has_tail;
            const dim_t n_simd_to_do = (end - start - ithr_does_tail) * simd_w;
            const dim_t tail_to_do = ithr_does_tail * nelems0_tail;

            jit_binary_call_s p;
            p.spat_offt_count = (n_simd_to_do + tail_to_do) * dst_type_size;
            p.src0 = src0 + start * simd_w * src0_type_size;
            p.src1 = src1
                    + (point_broadcast ? 0 : (start * simd_w * src1_type_size));
            p.dst = dst + start * simd_w * dst_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

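// Per-batch broadcast path: src1 holds a single batch slice that is reused
// for every MB element of src0.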
void jit_uni_binary_t::execute_bcast_per_batch_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec) const {

    const auto kernel = kernel_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());

    const dim_t MB = src0_d.dims()[0];
    const dim_t nelems0_per_b = src0_d.nelems(true) / MB;
    const dim_t nelems0_simd = nelems0_per_b / simd_w;
    const dim_t nelems0_tail = nelems0_per_b % simd_w;
    const bool has_tail = nelems0_tail > 0;

    // Compute strategy:
    // Compute the number of vectors per batch and divide it equally between
    // all threads. The thread that gets the last chunk also handles the tail
    // if present.
    const dim_t nthr = nstl::min(
            nelems0_simd + has_tail, (dim_t)dnnl_get_current_num_threads());
    parallel_nd(MB, nthr, [&](dim_t b, dim_t ithr) {
        dim_t start = 0, end = 0;
        balance211(nelems0_simd + has_tail, nthr, ithr, start, end);
        if (start >= end) return;

        const bool ithr_does_tail = has_tail && end == nelems0_simd + has_tail;
        const dim_t n_simd_to_do = (end - start - ithr_does_tail) * simd_w;
        const dim_t tail_to_do = ithr_does_tail * nelems0_tail;

        jit_binary_call_s p;
        p.spat_offt_count = (n_simd_to_do + tail_to_do) * dst_type_size;
        const dim_t off = start * simd_w;
        p.src0 = src0 + (off + b * nelems0_per_b) * src0_type_size;
        p.src1 = src1 + off * src1_type_size;
        p.dst = dst + (off + b * nelems0_per_b) * dst_type_size;
        p.scales_src0 = scale0;
        p.scales_src1 = scale1;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
        p.dst_orig = dst;
        (*kernel)(&p);
    });
}

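// Per-channel path (also the fallback for cases with per-oc post-ops or a
// blocked oc tail): src1 offsets depend on the op_type layout and on the
// broadcast type.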
void jit_uni_binary_t::execute_bcast_per_c_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const op_t op_type, const bcast_t bcast_type,
        const bool blocked_oc_tail) const {
    const auto kernel = kernel_.get();
    const auto kernel_tail = kernel_tail_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const dim_t MB = dims[0];
    const dim_t C = ndims >= 2 ? dims[1] : 1;
    const dim_t SP = ndims >= 3 ? utils::array_product(dims + 2, ndims - 2) : 1;

    const auto &bcast_dims = pd()->broadcast_dims();

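    // Elements in one batch slice. For src1 the slice is zero when the batch
    // dim itself is broadcast (the same slice is then reused for every mb),
    // and equals the src0 slice when there is no broadcast at all.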
    const dim_t nelems_slice_src0
            = utils::array_product(src0_d.padded_dims() + 1, ndims - 1);
    const dim_t nelems_slice_src1 = bcast_type == bcast_t::none
            ? nelems_slice_src0
            : ((bcast_dims[0] == 0) ? utils::array_product(
                       src1_d.padded_dims() + 1, ndims - 1)
                                    : 0);

    if (op_type == op_t::c_blocked) {
        const dim_t C_blocks = std::ceil(
                static_cast<float>(src0_d.padded_dims()[1]) / simd_w);
        // Compute strategy:
        // Each block is individual; parallelize over MB and C_blocks safely.

        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_no_tail
                = [&](jit_binary_call_s *p, dim_t C_blk) { (*kernel)(p); };
        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_tail = [&](jit_binary_call_s *p, dim_t C_blk) {
                    if (C_blk == (C_blocks - 1))
                        (*kernel_tail)(p);
                    else
                        (*kernel)(p);
                };
        const auto &kernel_blocked = blocked_oc_tail ? kernel_blocked_tail
                                                     : kernel_blocked_no_tail;
        const auto src1_off = [&](dim_t mb, dim_t C_blk, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::scalar: return mb * nelems_slice_src1;
                case bcast_t::per_batch: return C_blk * SP * simd_w;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1 + C_blk * simd_w;
            }
        };

        parallel_nd(MB, C_blocks, [&](dim_t mb, dim_t C_blk) {
            jit_binary_call_s p;
            p.spat_offt_count = SP * simd_w * dst_type_size;
            const dim_t off = mb * nelems_slice_src0 + C_blk * SP * simd_w;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, C_blk, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            kernel_blocked(&p, C_blk);
        });
    } else if (op_type == op_t::n_spatial_c) {
        const auto src1_off = [&](dim_t mb, dim_t sp, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::per_batch: return sp * C;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1;
            }
        };

        // Compute strategy:
        // Each line of channels is individual; parallelize over MB and
        // spatial.
        parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
            jit_binary_call_s p;
            p.spat_offt_count = C * dst_type_size;
            const auto off = mb * nelems_slice_src0 + sp * C;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, sp, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    } else if (op_type == op_t::n_c_spatial) {
        const auto src1_off = [&](dim_t mb, dim_t c, dim_t off) -> dim_t {
            switch (bcast_type) {
                case bcast_t::scalar: return mb * nelems_slice_src1;
                case bcast_t::per_batch: return c * SP;
                case bcast_t::none: return off;
                default: return mb * nelems_slice_src1 + c;
            }
        };

        // Compute strategy:
        // Each line of spatial is individual; parallelize over MB and C.
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            jit_binary_call_s p;
            p.spat_offt_count = SP * dst_type_size;
            const auto off = mb * nelems_slice_src0 + c * SP;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            p.src1 = src1 + src1_off(mb, c, off) * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

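// Per-W path: src1 broadcasts over C and possibly some leading spatial dims;
// only the trailing non-broadcast spatial dims advance through src1.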
void jit_uni_binary_t::execute_bcast_per_w_strategy(const data_t *src0,
        const data_t *src1, data_t *dst, const float *scale0,
        const float *scale1,
        const std::vector<const void *> &post_ops_binary_rhs_arg_vec,
        const op_t op_type, const bool blocked_oc_tail) const {
    const auto kernel = kernel_.get();
    const auto kernel_tail = kernel_tail_.get();
    const auto &simd_w = kernel_->simd_w();

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const memory_desc_wrapper dst_d(pd()->dst_md(0));
    const int src0_type_size = types::data_type_size(src0_d.data_type());
    const int src1_type_size = types::data_type_size(src1_d.data_type());
    const int dst_type_size = types::data_type_size(dst_d.data_type());
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const auto &bcast_dims = pd()->broadcast_dims();

    const int not_bcasted_sp_dims = pd()->get_conf().not_bcasted_sp_dims;
    const dim_t MB = dims[0];
    // product of the trailing spatial dims that src1 does not broadcast over
    const dim_t SP_no_bcast = ndims >= 3
            ? utils::array_product(
                    dims + (ndims - not_bcasted_sp_dims), not_bcasted_sp_dims)
            : 1;
    const dim_t C = ndims >= 2 ? dims[1] : 1;
    const dim_t SP = ndims >= 3 ? utils::array_product(dims + 2, ndims - 2) : 1;
    // product of the spatial dims that src1 broadcasts over
    const dim_t N = SP / SP_no_bcast;

    const dim_t nelems_slice_src0
            = utils::array_product(src0_d.padded_dims() + 1, ndims - 1);

    if (op_type == op_t::c_blocked) {
        const dim_t C_blocks = std::ceil(
                static_cast<float>(src0_d.padded_dims()[1]) / simd_w);
        // Compute strategy:
        // Each line of channels is individual; parallelize over MB, C_blocks
        // and spatial (broadcast and non-broadcast spatial dims handled
        // separately).

        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_no_tail
                = [&](jit_binary_call_s *p, dim_t C_blk) { (*kernel)(p); };
        const std::function<void(jit_binary_call_s *, dim_t)>
                kernel_blocked_tail = [&](jit_binary_call_s *p, dim_t C_blk) {
                    if (C_blk == (C_blocks - 1))
                        (*kernel_tail)(p);
                    else
                        (*kernel)(p);
                };
        const auto &kernel_blocked = blocked_oc_tail ? kernel_blocked_tail
                                                     : kernel_blocked_no_tail;

        parallel_nd(MB, C_blocks, N, SP_no_bcast,
                [&](dim_t mb, dim_t C_blk, dim_t n, dim_t sp) {
                    jit_binary_call_s p;
                    p.spat_offt_count = simd_w * dst_type_size;
                    const auto off = mb * nelems_slice_src0
                            + simd_w * (C_blk * SP + n * SP_no_bcast + sp);
                    p.dst = dst + off * dst_type_size;
                    p.src0 = src0 + off * src0_type_size;
                    // check if mb is broadcast
                    const dim_t src1_off = bcast_dims[0] == 1
                            ? sp * simd_w
                            : (mb * SP_no_bcast + sp) * simd_w;
                    p.src1 = src1 + src1_off * src1_type_size;
                    p.scales_src0 = scale0;
                    p.scales_src1 = scale1;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    kernel_blocked(&p, C_blk);
                });
    } else if (op_type == op_t::n_spatial_c) {
        // Compute strategy:
        // Each line of channels is individual; parallelize over MB and
        // spatial (broadcast and non-broadcast spatial dims handled
        // separately).

        parallel_nd(MB, N, SP_no_bcast, [&](dim_t mb, dim_t n, dim_t sp) {
            jit_binary_call_s p;
            p.spat_offt_count = C * dst_type_size;
            const auto off
                    = mb * nelems_slice_src0 + n * SP_no_bcast * C + sp * C;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            const dim_t src1_off
                    = bcast_dims[0] == 1 ? sp : mb * SP_no_bcast + sp;
            p.src1 = src1 + src1_off * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    } else if (op_type == op_t::n_c_spatial) {
        // Compute strategy:
        // Each line of width is individual; parallelize over MB, C and the
        // broadcast spatial dims (excluding the non-broadcast ones). Use a
        // kernel which broadcasts the c_i value into a vector register.

        parallel_nd(MB, C, N, [&](dim_t mb, dim_t c, dim_t n) {
            jit_binary_call_s p;
            p.spat_offt_count = SP_no_bcast * dst_type_size;
            const auto off = mb * nelems_slice_src0 + c * N * SP_no_bcast
                    + n * SP_no_bcast;
            p.dst = dst + off * dst_type_size;
            p.src0 = src0 + off * src0_type_size;
            const dim_t src1_off = bcast_dims[0] == 1 ? 0 : mb * SP_no_bcast;
            p.src1 = src1 + src1_off * src1_type_size;
            p.scales_src0 = scale0;
            p.scales_src1 = scale1;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel)(&p);
        });
    }
}

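// Dispatch to a compute strategy based on the broadcast type, the presence
// of per-oc broadcast post-ops, and blocked oc-tail handling requirements.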
status_t jit_uni_binary_t::execute(const exec_ctx_t &ctx) const {
    const auto src0 = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC_0);
    const auto src1 = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC_1);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    const auto &post_ops = pd()->attr()->post_ops_;
    const auto &post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(post_ops, ctx);
    const float *scales[2];
    ASSIGN_ARG_SCALE_VALUE(scales[0], DNNL_ARG_SRC_0);
    ASSIGN_ARG_SCALE_VALUE(scales[1], DNNL_ARG_SRC_1);

    const memory_desc_wrapper src0_d(pd()->src_md(0));
    const memory_desc_wrapper src1_d(pd()->src_md(1));
    const auto ndims = src0_d.ndims();
    const auto &dims = src0_d.dims();
    const dim_t C = ndims >= 2 ? dims[1] : 0;

    const bool postops_per_oc_broadcast_exists
            = binary_injector::any_binary_postop_rhs_per_oc_broadcast(
                    post_ops, src0_d, get_supported_postops_bcast_strategies());
    const auto &bcast_type = pd()->get_conf().bcast_type;
    const bool point_broadcast = bcast_type == bcast_t::scalar;
    const auto &op_type = pd()->get_conf().op_type;
    const bool with_postops = !post_ops.entry_.empty();
    const auto &simd_w = kernel_->simd_w();
    const bool has_oc_tail = C % simd_w;
    const bool point_broadcast_no_oc_tail = point_broadcast && !has_oc_tail;
    const auto alg = pd()->desc()->alg_kind;
    // Use the strategy with kernel_tail for comparison ops (ge, gt, le, lt,
    // eq, ne) with an oc tail and blocked format, since vcmpps overwrites
    // the vector tail.
    const bool vector_overwrite = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    const bool blocked_oc_tail = op_type == op_t::c_blocked && has_oc_tail
            && (with_postops || point_broadcast || bcast_type == bcast_t::per_w
                    || vector_overwrite);

    if ((bcast_type == bcast_t::none || point_broadcast_no_oc_tail)
            && !postops_per_oc_broadcast_exists && !blocked_oc_tail)
        execute_no_bcast_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, bcast_type);
    else if (bcast_type == bcast_t::per_batch
            && !postops_per_oc_broadcast_exists && !blocked_oc_tail)
        execute_bcast_per_batch_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec);
    else if (bcast_type == bcast_t::per_w)
        execute_bcast_per_w_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, op_type, blocked_oc_tail);
    else
        execute_bcast_per_c_strategy(src0, src1, dst, scales[0], scales[1],
                post_ops_binary_rhs_arg_vec, op_type, bcast_type,
                blocked_oc_tail);

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl