/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_brdgmm_dw_conv.hpp"
#include <cpu/x64/cpu_isa_traits.hpp>

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

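// Initializes `md` with `tag_value` when the memory format was left as `any`
// (and `any_eligible` allows choosing it); otherwise verifies that the
// user-provided format matches `tag_value`.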
inline status_t init_tag(memory_desc_t &md, const memory_desc_wrapper &mdw,
        const format_tag_t tag_value, bool any_eligible) {

    format_tag_t tag;
    if (mdw.format_kind() == format_kind::any) {
        if (any_eligible) {
            CHECK(memory_desc_init_by_tag(md, tag_value));
            tag = tag_value;
        } else {
            tag = format_tag::undef;
        }
    } else {
        tag = mdw.matches_one_of_tag(tag_value);
    }

    if (tag != tag_value) return status::unimplemented;

    return status::success;
}

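// Checks that the attribute post-ops chain (sum, eltwise, binary) and the
// binary broadcasting strategies it uses are supported by the post-ops
// injector for the given destination layout.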
bool post_ops_ok(jit_brdgmm_conv_conf_t &jcp, const primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;

    return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
            {sum, eltwise, binary}, post_ops, &dst_d,
            false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
            false /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar,
                    broadcasting_strategy_t::no_broadcast}));
}

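// Returns the highest-priority ISA that matches the data type configuration
// and is available on the current machine, or isa_undef if none is.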
cpu_isa_t get_supported_isa(
        bool is_f32, bool is_int8, bool is_bf16, bool is_f16) {
    std::vector<cpu_isa_t> isa_list;
    if (is_f32) {
        // Note: avx2 support is temporarily disabled pending a performance
        // study.
        isa_list = {avx512_core /*, avx2*/};
    } else if (is_int8) {
        isa_list = {avx512_core_vnni};
    } else if (is_bf16) {
        isa_list = {avx512_core_bf16, avx2_vnni_2};
    } else if (is_f16) {
        isa_list = {avx512_core_fp16, avx2_vnni_2};
    }

    for (auto isa : isa_list) {
        if (mayiuse(isa)) return isa;
    }
    return isa_undef;
}

status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) {

    using skip_mask_t = primitive_attr_t::skip_mask_t;

    const auto &cd = *desc();
    const auto src_type = cd.src_desc.data_type;
    const auto wei_type = cd.weights_desc.data_type;
    const auto bia_type = cd.bias_desc.data_type;
    const auto dst_type = cd.dst_desc.data_type;

    // TODO: support s8s8 conv
    const bool is_f32 = everyone_is(f32, src_type, wei_type, dst_type);
    const bool is_int8 = one_of(src_type, u8) && wei_type == s8
            && one_of(dst_type, s32, f32, u8, s8, bf16);
    const bool is_bf16 = everyone_is(bf16, src_type, wei_type)
            && one_of(dst_type, bf16, f32);
    const bool is_f16 = everyone_is(f16, src_type, wei_type)
            && one_of(dst_type, f16, f32);
    const cpu_isa_t isa = get_supported_isa(is_f32, is_int8, is_bf16, is_f16);

    auto skip_mask = skip_mask_t::post_ops;
    if (is_int8) skip_mask |= skip_mask_t::scales_runtime;

    bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
            && one_of(true, is_f32, is_int8, is_bf16, is_f16)
            && (isa != isa_undef) && mayiuse(isa)
            && IMPLICATION(is_int8,
                    one_of(bia_type, data_type::undef, f32, s32, s8, u8))
            && IMPLICATION(!is_int8,
                    one_of(bia_type, data_type::undef, src_type, dst_type))
            && attr()->has_default_values(skip_mask) && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    auto &jcp = jcp_;

    const memory_desc_wrapper src_d(&src_md_);
    const memory_desc_wrapper weights_d(&weights_md_);
    const memory_desc_wrapper dst_d(&dst_md_);
    const memory_desc_wrapper bias_d(&bias_md_);

    const int ndims = src_d.ndims();
    // Currently this kernel only supports 2D convolutions.
    if (ndims != 4) return status::unimplemented;
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;
    // dilations are not supported
    if (cd.dilates[0] != 0 || cd.dilates[1] != 0) return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.kw);
    jcp.src_dt = cd.src_desc.data_type;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.wei_dt = cd.weights_desc.data_type;
    jcp.with_bias = with_bias();
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;

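    // This implementation handles depthwise convolutions only: exactly one
    // input and one output channel per group.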
    if (!(everyone_is(1, jcp.ic, jcp.oc))) return status::unimplemented;

    const auto def_data_tag = format_tag::nhwc;
    const bool any_eligible = (cd.prop_kind == prop_kind::forward_inference
            || is_int8 || is_f16 || (isa == avx2_vnni_2 && is_bf16));
    CHECK(init_tag(src_md_, src_d, def_data_tag, any_eligible));
    CHECK(init_tag(dst_md_, dst_d, def_data_tag, any_eligible));

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md_, format_tag::x));
    }

    CHECK(attr_.set_default_formats(dst_md()));
    if (!post_ops_ok(jcp, *attr(), dst_d)) return status::unimplemented;
    jcp.with_post_ops = attr()->post_ops_.len() > 0;

    jcp.isa = isa;
    jcp.nthr = dnnl_get_max_threads();
    jcp.src_dsz = types::data_type_size(jcp.src_dt);
    jcp.wei_dsz = types::data_type_size(jcp.wei_dt);
    jcp.bia_dsz
            = jcp.with_bias ? types::data_type_size(cd.bias_desc.data_type) : 0;
    jcp.dst_dsz = types::data_type_size(jcp.dst_dt);

    const auto &src_scales = attr_.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS);
    const auto &dst_scales = attr_.scales_.get(DNNL_ARG_DST);
    jcp.with_scale = !src_scales.has_default_values()
            || !wei_scales.has_default_values();
    const int wei_mask_per_oc = 1 << (int)with_groups;
    jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;

    // only common and per-oc-channel scales are supported
    const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
            && src_scales.mask_ == 0 && dst_scales.has_default_values();
    if (!scales_ok) return status::unimplemented;

    // brgemm_strd is only feasible when the kernel height is one (i.e., the
    // batch walks a single input row) and when there is no channel tail,
    // since the matrix B strides cannot be computed reliably without knowing
    // whether the channel blocking is 8 or 16.
    if (jcp.kh == 1 && jcp.ngroups % 16 == 0) {
        jcp.batch_kind = brgemm_strd;
    } else if ((jcp.mb * jcp.oh) % jcp.nthr != 0) {
        jcp.batch_kind = brgemm_offs;
    } else {
        jcp.batch_kind = brgemm_addr;
    }

    // Round the per-thread batch up to a 4K boundary to avoid concurrent
    // access to the same cache lines from different threads.
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jcp.adjusted_batch_size
            = div_up(rnd_up(jcp.kh * jcp.kw * sc_size, 4096), sc_size);
    CHECK(init_brdgmm_conf());
    CHECK(init_scratchpad());
    if (jcp.with_scale) {
        auto scratchpad = scratchpad_registry().registrar();
        book_precomputed_scales(scratchpad, attr_.scales_, OC());
    }

    return status::success;
}

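// Sets up the brgemm descriptors. Kernel 0 always covers a full output row;
// when full rows cannot be evenly distributed across threads, additional
// kernels are created for power-of-two multiples of ow_block, for the
// channel-block and ow tails, and for a partial channel block.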
status_t brdgmm_dw_convolution_fwd_t::pd_t::init_brdgmm_conf() {

    auto &jcp = jcp_;

    auto init_bcp = [&](int &idx, const int M, const int N) {
        const float alpha = 1.f;
        const float beta = 0.f;
        const int LDA = jcp.ngroups * jcp.stride_w;
        const int LDC = jcp.ngroups;
        const int LDD = jcp.ngroups;

        brgemm_attr_t brg_attr;
        brg_attr.max_bs = jcp.kw * jcp.kh;
        brg_attr.max_top_vpad = nstl::max(0, jcp.l_pad);
        brg_attr.max_bottom_vpad = nstl::max(0, jcp.r_pad);

        // only needed for strd batch_kind
        const brgemm_strides_t strides
                = {static_cast<dim_t>(jcp.src_dsz) * jcp.ngroups,
                        static_cast<dim_t>(jcp.wei_dsz) * jcp.ngroups};

        auto &bcp = bcps_[idx];
        CHECK(brdgmm_desc_init(&bcp, jcp.isa, jcp.batch_kind, jcp.src_dt,
                jcp.wei_dt, false /*transA*/, brgemm_row_major, alpha, beta,
                LDA, LDC, M, N, &strides));
        CHECK(brgemm_desc_set_attr(&bcp, brg_attr));
        CHECK(brgemm_desc_set_postops(&bcp, attr(), dst_md(), LDD, jcp.bia_dt));
        ++idx;
        return status::success;
    };

    bcps_.resize(1);
    jcp.ow_block = jcp.ow;
    jcp.nb_ow = 1;
    jcp.nb_ch_blocking = jcp.ngroups;
    jcp.chb_tail = 0;
    int ker_idx = 0;
    CHECK(init_bcp(ker_idx, jcp.ow, jcp.ngroups)); // default full row kernel.

    const auto &bcp_0 = bcps_[0];
    jcp.ch_block = bcp_0.ld_block;
    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);

    const auto wei_tag
            = jcp.ch_block == 16 ? format_tag::hwioG16g : format_tag::hwioG8g;
    const memory_desc_wrapper weights_d(&weights_md_);
    CHECK(init_tag(weights_md_, weights_d, wei_tag, true));

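    // When full output rows cannot be evenly distributed across the threads,
    // split the row into ow blocks and the groups into channel blocks so the
    // remaining work can be balanced.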
    if ((jcp.mb * jcp.oh) % jcp.nthr != 0) {
        // determine ow_block
        {
            const size_t work_amount = jcp.mb * jcp.oh * jcp.ow;
            if (work_amount % jcp.nthr == 0) {
                const size_t work_per_thr = div_up(work_amount, jcp.nthr);
                const size_t ow_tail_block
                        = (work_per_thr / jcp.nb_ch) % jcp.ow;
                if (ow_tail_block && (jcp.ow % ow_tail_block == 0))
                    jcp.ow_block = ow_tail_block;
                else {
                    jcp.ow_block = jcp.ow;
                }
            } else {
                const int max_ow_block = is_superset(jcp.isa, avx512_core)
                        ? 6
                        : bcp_0.bd_block2 /*TODO: Tune for avx2*/;
                jcp.ow_block = nstl::min(max_ow_block, jcp.ow);
            }
            jcp.ow_tail = jcp.ow % jcp.ow_block;
        }
        jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);

        // determine nb_ch_blocking
        {
            const size_t work_amount = jcp.mb * jcp.nb_ch * jcp.oh * jcp.nb_ow;
            if (work_amount % jcp.nthr == 0) {
                const size_t work_per_thr = div_up(work_amount, jcp.nthr);
                const size_t ch_tail_block = work_per_thr % jcp.nb_ch;
                if (ch_tail_block && (jcp.nb_ch % ch_tail_block == 0))
                    jcp.nb_ch_blocking = ch_tail_block * jcp.ch_block;
                else
                    jcp.nb_ch_blocking = jcp.ngroups;
            } else {
                const int max_ch_block2 = is_superset(jcp.isa, avx512_core)
                        ? 4
                        : bcp_0.ld_block2 /*TODO: Tune for avx2*/;
                jcp.nb_ch_blocking
                        = nstl::min(max_ch_block2 * jcp.ch_block, jcp.ngroups);
            }
            jcp.chb_tail = jcp.ngroups % jcp.nb_ch_blocking;
        }

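        // One kernel per power-of-two multiple of ow_block, so any remaining
        // chunk of a row can be covered by a small number of kernel calls.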
        const int n_owb_kernels = std::ceil(log2(jcp.nb_ow));
        const int num_kernels = 1 /*full ow*/ + n_owb_kernels
                + (jcp.chb_tail != 0) + (jcp.nb_ch_blocking != jcp.ngroups)
                + (jcp.ow_tail != 0);
        bcps_.resize(num_kernels);

        for (int i = 0; i < n_owb_kernels; ++i) {
            CHECK(init_bcp(ker_idx, jcp.ow_block * (1 << i), jcp.ngroups));
        }

        if (jcp.chb_tail) {
            jcp.chb_tail_idx = ker_idx;
            CHECK(init_bcp(ker_idx, jcp.ow_block, jcp.chb_tail));
        }

        if (jcp.ow_tail) {
            jcp.ow_tail_idx = ker_idx;
            CHECK(init_bcp(ker_idx, jcp.ow_tail, jcp.ngroups));
        }

        if (jcp.nb_ch_blocking != jcp.ngroups) {
            jcp.nb_ch_blocking_idx = ker_idx;
            CHECK(init_bcp(ker_idx, jcp.ow_block, jcp.nb_ch_blocking));
        }
        assert(num_kernels == ker_idx);
    }

    return status::success;
}

status_t brdgmm_dw_convolution_fwd_t::pd_t::init_scratchpad() {
    const auto &jcp = jcp_;
    auto scratchpad = scratchpad_registry().registrar();

    scratchpad.book(key_brgemm_primitive_batch,
            static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
            sizeof(brgemm_batch_element_t), 64);
    return status::success;
}

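// Generates one JIT kernel per brgemm descriptor, skipping descriptors with
// an empty (zero-sized) M x N problem.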
status_t brdgmm_dw_convolution_fwd_t::init(engine_t *engine) {
    const auto &bcps = pd()->bcps_;
    brdgmm_kernels_.resize(bcps.size());

    for (size_t idx = 0; idx < bcps.size(); ++idx) {
        const auto &bcp = bcps[idx];
        if (bcp.bcast_dim * bcp.load_dim /* M*N */ == 0) continue;
        brgemm_kernel_t *brg_kernel = nullptr;
        CHECK(brgemm_kernel_create(&brg_kernel, pd()->bcps_[idx]));
        CHECK(safe_ptr_assign(brdgmm_kernels_[idx], brg_kernel));
    }

    return status::success;
}

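// Execution is parallelized over (mb, oh, ow blocks, channel blocks). Each
// thread fills its own brgemm batch with one element per (kh, kw) tap and
// dispatches the kernel that matches the work it was assigned (full row,
// ow block, or one of the tail kernels).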
status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {

    const char *const __restrict src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const char *const __restrict weights
            = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    const char *const __restrict bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    char *const __restrict dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
    brgemm_batch_element_t *const __restrict brg_batch_global
            = scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch);
    const std::vector<const void *> post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    pd()->attr()->post_ops_, ctx);

    const auto &jcp = pd()->jcp_;

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);

    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    const int chb_step = jcp.nb_ch_blocking;
    const int chb_work = div_up(jcp.ngroups, chb_step);
    const int ow_step = jcp.ow_block;
    const int work_amount = jcp.mb * jcp.oh * jcp.nb_ow * chb_work;

    const size_t src_ch_stride = jcp.src_dsz;
    const size_t src_w_stride = jcp.ngroups * jcp.src_dsz;
    const size_t src_h_stride = jcp.ngroups * jcp.iw * jcp.src_dsz;
    const size_t src_mb_stride = jcp.ngroups * jcp.iw * jcp.ih * jcp.src_dsz;
    const size_t wei_ch_stride = jcp.wei_dsz;
    const size_t wei_w_stride = rnd_up(jcp.ngroups, jcp.ch_block) * jcp.wei_dsz;
    const size_t wei_h_stride = wei_w_stride * jcp.kw;
    const size_t dst_ch_stride = jcp.dst_dsz;
    const size_t dst_w_stride = jcp.ngroups * jcp.dst_dsz;
    const size_t dst_h_stride = jcp.ngroups * jcp.ow * jcp.dst_dsz;
    const size_t dst_mb_stride = jcp.ngroups * jcp.ow * jcp.oh * jcp.dst_dsz;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        int n {0}, chb {0}, oh {0}, owb {0};

        auto iwork = start;
        brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global
                + static_cast<size_t>(ithr) * jcp.adjusted_batch_size;
        const brgemm_kernel_t *kernel = nullptr;
        const brgemm_kernel_t *kernel_chb_tail
                = brdgmm_kernels_[jcp.chb_tail_idx].get();
        brgemm_post_ops_data_t post_ops_data;
        post_ops_data.binary_post_ops_rhs = post_ops_binary_rhs_arg_vec.data();
        post_ops_data.data_C_ptr_ = dst;

        while (iwork < end) {
            nd_iterator_init(iwork, n, jcp.mb, oh, jcp.oh, owb, jcp.nb_ow, chb,
                    chb_work);
            const bool is_m_tail = jcp.ow_tail != 0 && (owb + 1 == jcp.nb_ow);
            const bool is_n_tail = jcp.chb_tail != 0 && (chb + 1 == chb_work);
            if (is_m_tail && chb != 0) {
                // The tail ow_block is not split between threads to keep the
                // number of kernels small.
                utils::nd_iterator_jump(iwork, end, n, jcp.mb, oh, jcp.oh, owb,
                        jcp.nb_ow, chb, chb_work);
                continue;
            }

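            // Pick the kernel matching the remaining work: channel tail,
            // ow tail, partial channel block, full row, or the largest
            // power-of-two multiple of ow_block that fits.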
            const auto rem_work = end - iwork;
            const int rem_row_owb
                    = saturate(1, jcp.nb_ow - owb, rem_work / chb_work);
            int cur_n_owb = 1;
            int ker_idx = 0;
            if (is_n_tail) {
                ker_idx = jcp.chb_tail_idx;
            } else if (is_m_tail) {
                ker_idx = jcp.ow_tail_idx;
            } else if (chb != 0 || rem_work < chb_work) {
                ker_idx = jcp.nb_ch_blocking_idx;
            } else if (rem_row_owb == jcp.nb_ow) {
                ker_idx = 0;
                cur_n_owb = jcp.nb_ow;
            } else {
                // The ow_tail kernel is processed on its own, so subtract it
                // if it exists.
                const int log_rem_owb = log2(rem_row_owb
                        - (owb + rem_row_owb >= jcp.nb_ow)
                                * (jcp.ow_tail != 0));
                cur_n_owb = (1 << log_rem_owb);
                ker_idx = log_rem_owb + 1; // add 1 as the 0th is the full row
            }

            kernel = brdgmm_kernels_[ker_idx].get();

            int ch = chb * chb_step;
            const int ow = owb * ow_step;
            auto *ptr_A = src;
            auto *ptr_B = weights;
            int bs = 0;
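            // Fill the brgemm batch: one element per (kh, kw) tap that hits
            // the input. Out-of-bounds columns are expressed via virtual
            // padding (vvpad); out-of-bounds rows are skipped entirely.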
            for (int kh = 0; kh < jcp.kh; ++kh) {
                for (int kw = 0; kw < jcp.kw; ++kw) {
                    const int ih = (oh * jcp.stride_h - jcp.t_pad) + kh;
                    if (ih < 0 || ih >= jcp.ih) continue;
                    const int iw_s = ow * jcp.stride_w - jcp.l_pad + kw;
                    const int ow_e
                            = nstl::min(jcp.ow, ow + cur_n_owb * jcp.ow_block)
                            - 1;
                    const int iw_e = ow_e * jcp.stride_w - jcp.l_pad + kw;
                    auto &batch = brg_batch[bs];
                    batch.vvpad.top = nstl::max(0, div_up(-iw_s, jcp.stride_w));
                    batch.vvpad.bottom = nstl::max<dim_t>(
                            0, div_up(iw_e - (jcp.iw - 1), jcp.stride_w));
                    const dim_t offs_A = n * src_mb_stride + ih * src_h_stride
                            + iw_s * src_w_stride + ch * src_ch_stride;
                    const dim_t offs_B = kh * wei_h_stride + kw * wei_w_stride
                            + ch * wei_ch_stride;
                    if (jcp.batch_kind == brgemm_offs) {
                        batch.offset.A = offs_A;
                        batch.offset.B = offs_B;
                    } else if (jcp.batch_kind == brgemm_addr) {
                        batch.ptr.A = src + offs_A;
                        batch.ptr.B = weights + offs_B;
                    } else {
                        assert(jcp.batch_kind == brgemm_strd);
                        if (bs == 0) {
                            ptr_A = src + offs_A;
                            ptr_B = weights + offs_B;
                        }
                    }
                    ++bs;
                }
            }
            auto ptr_C = dst + n * dst_mb_stride + oh * dst_h_stride
                    + ow * dst_w_stride + ch * dst_ch_stride;
            const int rem_chb_work = chb_work - chb;
            int chb_loop_work = is_m_tail || (chb == 0 && rem_work >= chb_work)
                    ? 1 // Compute entire chb_work in single jit call
                    : nstl::min(rem_work, rem_chb_work);
            iwork += cur_n_owb * nstl::min(rem_work, rem_chb_work);

            while (chb_loop_work) {
                // The brgemm_offs and brgemm_strd batch kinds allow running
                // this loop without changing the brg_batch elements.
                assert(IMPLICATION(chb != 0,
                        one_of(jcp.batch_kind, brgemm_offs, brgemm_strd)));
                post_ops_data.bias = bias + ch * jcp.bia_dsz;
                post_ops_data.scales = &oscales[jcp.is_oc_scale * ch];
                post_ops_data.oc_logical_off = ch;
                brgemm_kernel_execute_postops(kernel, bs, ptr_A, ptr_B,
                        brg_batch, ptr_C, ptr_C, post_ops_data,
                        nullptr /*scratch*/);
                ++chb;
                if (jcp.chb_tail != 0 && chb + 1 == chb_work)
                    kernel = kernel_chb_tail;
                ch += chb_step;
                ptr_A += chb_step * src_ch_stride;
                ptr_B += chb_step * wei_ch_stride;
                ptr_C += chb_step * dst_ch_stride;
                --chb_loop_work;
            }
        }
    });
    return status::success;
}
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
543