/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp"

#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::prop_kind;

#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
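
// The data_blk_off macro above computes an activation offset for 3D (ncw),
// 4D (nchw), or 5D (ncdhw) tensors from a single call site; it relies on an
// `ndims` variable being in scope at the point of expansion. For ndims == 4,
// for example, the (d) argument is simply discarded and the offset resolves
// to (f).blk_off(n, c, h, w).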

namespace {
/* TODO: investigate why the common balance2D defined in
 * common/dnnl_thread.hpp is not used here. */
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, T nx, T &nx_start,
        T &nx_end, T nx_divider) {
    const T grp_size = utils::div_up(nthr, nx_divider);
    const T grp_count = utils::div_up(nthr, grp_size);

    T grp = ithr / grp_size;
    T grp_ithr = ithr % grp_size;
    T grp_nthr = grp_size;
    T first_grps = nthr % grp_count;
    if (first_grps > 0 && grp >= first_grps) {
        ithr -= first_grps * grp_size;
        grp_nthr--;
        grp = ithr / grp_nthr + first_grps;
        grp_ithr = ithr % grp_nthr;
    }
    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
} // namespace
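
// balance2D() splits a 2D (ny x nx) iteration space across nthr threads:
// threads are first partitioned into at most nx_divider groups, each group
// takes a contiguous chunk of nx, and the threads within a group split ny
// via balance211(). Worked example (values illustrative only): nthr = 8,
// nx_divider = 2 gives grp_size = 4 and grp_count = 2, so threads 0-3 share
// the first half of nx and threads 4-7 the second half, each quartet
// splitting ny four ways. When nthr is not a multiple of grp_count, the
// leading `first_grps` groups keep the full grp_size and the remaining
// groups run with one thread fewer.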

/* convolution forward */
template <data_type_t dst_type>
void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const dw_wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_ != nullptr
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;
    if (pd()->wants_padded_bias()) {
        const size_t bia_dt_size = pd()->jcp_.typesize_bia;
        auto padded_bias = scratchpad.template get<char>(key_conv_padded_bias);
        utils::array_copy(
                padded_bias, bias, bia_dt_size * jcp.oc_without_padding);
        utils::array_set(padded_bias + bia_dt_size * jcp.oc_without_padding,
                0.f, bia_dt_size * (jcp.oc - jcp.oc_without_padding));
        bias = padded_bias;
    }

    float *bias_dw = nullptr;
    if (pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)->data_type
            == data_type::bf16) {
        auto jcp_dw = pd()->jcp_dw_;
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto bias_in = CTX_IN_MEM(
                const src_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
        bias_dw = dw_scratchpad.template get<float>(
                key_conv_bias_bf16_convert_wsp);
        cvt_bfloat16_to_float(bias_dw, bias_in, jcp_dw->oc_without_padding);
        utils::array_set(bias_dw + jcp_dw->oc_without_padding, 0.f,
                jcp_dw->oc - jcp_dw->oc_without_padding);
    } else {
        auto bias_in = CTX_IN_MEM(
                const float *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
        bias_dw = const_cast<float *>(bias_in);
    }
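
    // Note on the block above: the fused depthwise kernel consumes its bias
    // in f32, so a bf16 dw bias is up-converted into a scratchpad buffer
    // (with the padded tail zeroed), while an f32 dw bias is used in place.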

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

template <data_type_t dst_type>
void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward_thr(
        const int ithr, const int nthr, const src_data_t *src,
        const wei_data_t *weights, const char *bias,
        const dw_wei_data_t *weights_dw, const float *bias_dw, const char *dst,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
            : nullptr;
    float *store_buffer = scratchpad.template get<float>(key_conv_store_wsp);

    const int ndims = src_d.ndims();
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_core>::call_params_t();

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;
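
    // With a fused depthwise convolution the overrides above drive the 1x1
    // kernel one output row at a time (os_block == jcp.ow, one row per bcast
    // block) so that rows can be handed to the dw kernel as soon as a full
    // dw receptive field of jcp_dw->kh rows is available.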

    // Variables needed for the fused dw conv.
    dst_data_t *pbuf; // intermediate row buffer for the bf16->bf16 fusion
    size_t row_offset;
    const auto jcp_dw = pd()->jcp_dw_;
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<decltype(pbuf)> addrs;
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };
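
    // step() picks the next blocking step: take default_step while at least
    // tail_step iterations remain, otherwise consume the remainder in one go
    // (e.g. default_step = 4, tail_step = 6, remaining = 5 yields 5, which
    // avoids leaving a tiny tail block behind).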

    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        const int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        const auto max_oc
                = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding);
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block);
    };

    auto init_reduce = [&](int icb) {
        const int nb_ic_blocking_step
                = nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block);
        rp.icb = p.reduce_dim;
    };
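
    // FLAG_REDUCE_FIRST and FLAG_REDUCE_LAST mark the first and last IC
    // chunk of the accumulation chain so the kernel knows when to initialize
    // the accumulator and when to finalize the result; a single chunk may
    // carry both flags when the reduction fits in one step.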

    auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;
        const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow);

        void *output_data = jcp.with_dw_conv
                ? (void *)(pbuf + (oh % jcp_dw->kh) * row_offset)
                : (void *)(&dst[dst_off * dst_d.data_type_size()]);
        p.output_data = output_data;

        p.bias_data = &bias[jcp.typesize_bia * oc_off_idx
                * (is_dst_layout_nxc ? 1 : jcp.oc_block)];
        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);
            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        const size_t grp_count = utils::div_up(
                jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count));
        const size_t max_load_per_thread = is_dst_layout_nxc
                ? jcp.load_dim
                : rnd_up((jcp.load_dim / grp_count), jcp.load_block);
        const size_t str_size = jcp.bcast_dim * max_load_per_thread;
        p.store_buffer = store_buffer + ithr * str_size
                + data_blk_off(dst_d, 0, 0, od, oh, ow);

        p.dst_l_off = dst_off;
        p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };
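
    // Notes on ker_1x1: input rows are gathered into the per-thread rtus
    // ("reduce to unit stride") workspace only once per bcast block (when
    // ocb == ocb_start) and then reused for every output-channel block. The
    // f32 store buffer is sized per thread as bcast_dim times the largest
    // load chunk a thread may own, so partial IC accumulations never alias
    // across threads.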

    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };
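
    // The two supported loop orders differ only in nesting: loop_lbr iterates
    // load (oc) blocks outermost with bcast (spatial) blocks inside, favoring
    // weight reuse, while loop_blr nests the other way around, favoring
    // activation reuse; the reduce (ic) loop is innermost in both cases.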

    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = nstl::max(dw_oh * jcp_dw->stride_h - jcp_dw->t_pad, 0);

        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw->iw)
                * jcp_dw->nb_ch_blocking * jcp_dw->ch_block;

        const int dil_h = jcp_dw->dilate_h + 1;
        const int str_h = jcp_dw->stride_h;
        const int ch_num = jcp_dw->nb_ch_blocking;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw->nb_ch_blocking) {

            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw->t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw->ih,
                              (int)(dw_oh * str_h + (jcp_dw->kh - 1) * dil_h
                                      - jcp_dw->t_pad + 1))
                    - jcp_dw->ih;

            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw->kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            const int ow = 0;
            const int kw = 0;
            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw->ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[(dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step)
                            * dst_d.data_type_size()];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw->ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw->nb_ch) - ch)
                    * jcp_dw->ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += wch_stride;
        }
    };
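
    // ker_dw consumes rows from pbuf, which acts as a ring buffer of
    // jcp_dw->kh 1x1 output rows: row r of the 1x1 convolution lives at
    // pbuf + (r % jcp_dw->kh) * row_offset, so addrs[] collects the kh input
    // row pointers needed for one dw output row, and the t_pad/b_pad
    // overflow logic trims kernel taps that fall outside the image.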

    auto conv_dw = [&]() {
        // Set up the per-thread intermediate buffer.
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        const auto dw_conv_buffer
                = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer);

        const auto dw_conv_buffer_size_
                = jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n {0}, g {0}, oh_dw {0};
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over the mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw spatial to 1x1 spatial conversion (jcp.oh may differ
                // from jcp_dw->oh)
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };
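
    // conv_dw interleaves the two convolutions row by row: for each dw
    // output row it first runs conv_1x1 over exactly the 1x1 rows that the
    // dw receptive field still misses (tracked by oh_1x1), then immediately
    // applies the dw kernel while those rows are hot in cache. The
    // intermediate buffer never holds more than jcp_dw->kh rows per thread.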

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
                ocb_start, ocb_end, jcp.load_grp_count);

        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}

REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_fwd_t<
        data_type::f32>);
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_fwd_t<
        data_type::bf16>);

template <data_type_t diff_src_type>
void jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
    auto scratchpad = ctx.get_scratchpad_grantor();
    const auto &jcp = kernel_->jcp;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);
        execute_backward_data_thr(
                ithr, nthr, diff_dst, weights, diff_src, scratchpad);
    });
}

template <data_type_t diff_src_type>
void jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        diff_src_type>::execute_backward_data_thr(const int ithr,
        const int nthr, const diff_dst_data_t *diff_dst,
        const wei_data_t *weights, diff_src_data_t *diff_src,
        const memory_tracking::grantor_t &scratchpad) const {

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;

    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.template get<diff_src_data_t>(key_conv_rtus_space)
            : nullptr;
    float *store_buffer = scratchpad.template get<float>(key_conv_store_wsp);
    const int ndims = diff_src_d.ndims();
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_core>::call_params_t();
    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0};
    balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
            icb_start, icb_end, jcp.load_grp_count);

    auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, int &od,
                              int &oh, int &ow, int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast);
        bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                jcp.nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        const int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int icb, int &load_step) {
        load_step = step(
                jcp.nb_load_blocking, icb_end - icb, jcp.nb_load_blocking_max);
        const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic);
        p.load_dim = this_block_size(
                icb * jcp.ic_block, max_ic, load_step * jcp.ic_block);
        rp.icb = p.load_dim;
    };

    auto init_reduce = [&](int ocb) {
        const int nb_oc_blocking_step
                = nstl::min(ocb + nb_oc_blocking, nb_oc) - ocb;
        p.first_last_flag = 0 | (ocb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (ocb + nb_oc_blocking_step >= nb_oc ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                ocb * jcp.oc_block, jcp.oc, nb_oc_blocking_step * jcp.oc_block);
    };

    auto inner_ker = [&](int icb, int ocb, int n, int g, int od, int oh, int ow,
                             int id, int ih, int iw) {
        const bool is_dsrc_layout_nxc = utils::one_of(jcp.src_tag,
                format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
        const int ic_off_idx = is_dsrc_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;
        const size_t diff_src_off
                = data_blk_off(diff_src_d, n, ic_off_idx, id, ih, iw);

        rp.src = diff_src + diff_src_off;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_;
            p.output_data = rp.ws;
        } else
            p.output_data = rp.src;
        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag,
                format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
        const int oc_off_idx = is_ddst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;
        p.bcast_data = diff_dst
                + data_blk_off(diff_dst_d, n, oc_off_idx, od, oh, ow);

        const size_t grp_count = utils::div_up(
                jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count));
        const size_t max_load_per_thread = is_dsrc_layout_nxc
                ? jcp.load_dim
                : rnd_up((jcp.load_dim / grp_count), jcp.load_block);
        const size_t str_size = jcp.bcast_dim * max_load_per_thread;
        p.store_buffer = store_buffer + ithr * str_size
                + data_blk_off(diff_src_d, 0, 0, id, ih, iw);
        (*kernel_)(&p);
        if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
    };
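
    // In the backward-by-data pass the roles are swapped relative to forward:
    // diff_dst is broadcast, IC blocks are loaded, and OC is the reduction
    // dimension. With a strided convolution (rtus_.reduce_src_) the kernel
    // writes into the dense per-thread workspace first, and the rtus driver
    // then scatters the result back into the strided diff_src.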

    if (jcp.loop_order == loop_lbr) {
        int icb = icb_start;
        while (icb < icb_end) {
            int load_step;
            init_load(icb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, n, g, bcast_step, od, oh, ow, id, ih, iw);
                for (int ocb = 0; ocb < nb_oc; ocb += nb_oc_blocking) {
                    init_reduce(ocb);
                    inner_ker(icb, ocb, n, g, od, oh, ow, id, ih, iw);
                }
                iwork += bcast_step;
            }
            icb += load_step;
        }
    } else {
        assert(!"unsupported loop order");
    }
}

REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        data_type::f32>);
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        data_type::bf16>);

/* convolution backward wrt weights */

#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
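
// wht_blk_off() dispatches between the grouped weights layout, which carries
// a leading groups dimension, and the non-grouped one, so callers below can
// compute weight offsets without branching on pd()->with_groups() themselves.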

template <data_type_t diff_weights_type>
status_t
jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<diff_weights_type>::init(
        engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx512_core_bf16_1x1_conv_kernel(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));

    CHECK(safe_ptr_assign(
            acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
    CHECK(kernel_->create_kernel());
    CHECK(acc_ker_->create_kernel());

    if (!pd()->jcp_.uses_permw_transposition) {
        const bool is_src_layout_nxc = utils::one_of(pd()->jcp_.src_tag,
                format_tag::ndhwc, format_tag::nhwc, format_tag::nwc);
        const bool is_ddst_layout_nxc = utils::one_of(pd()->jcp_.dst_tag,
                format_tag::ndhwc, format_tag::nhwc, format_tag::nwc);
        if (!is_src_layout_nxc || !is_ddst_layout_nxc) {
            CHECK(safe_ptr_assign(tr_reorder_,
                    new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t()));
            CHECK(tr_reorder_->create_kernel());
        }
        if (is_src_layout_nxc) {
            int ic = pd()->jcp_.ic * pd()->jcp_.ngroups;
            CHECK(safe_ptr_assign(tr_reorder_nhwc_src_,
                    new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(ic)));
            CHECK(tr_reorder_nhwc_src_->create_kernel());
        }
        if (is_ddst_layout_nxc) {
            int oc = pd()->jcp_.oc * pd()->jcp_.ngroups;
            CHECK(safe_ptr_assign(tr_reorder_nhwc_ddst_,
                    new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(oc)));
            CHECK(tr_reorder_nhwc_ddst_->create_kernel());
        }
    }

    CHECK(init_rtus_driver<avx512_core>(this));
    return status::success;
}
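
// When the permw-based transposition is not used, init() above creates
// dedicated reorder kernels that repack src / diff_dst so that consecutive
// pairs of bf16 elements along the reduction dimension are interleaved
// (s16c -> S16c2s), matching what the bf16 dot-product sequence consumes;
// the nhwc variants take the plain-layout channel stride as a constructor
// argument.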

template <data_type_t diff_weights_type>
void jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<diff_weights_type>::
        execute_backward_weights(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const auto &jcp = pd()->jcp_;

    float *diff_bias = nullptr;
    if (jcp.with_bias && pd()->jcp_.bia_dt == data_type::f32) {
        diff_bias = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block
                ? scratchpad.template get<float>(key_conv_padded_bias)
                : CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    auto rtus_space = scratchpad.template get<src_data_t>(key_conv_rtus_space);
    auto wei_reduction = scratchpad.template get<float>(key_conv_wei_reduction);

    auto tr_src_buffer = !jcp.uses_permw_transposition
            ? scratchpad.template get<src_data_t>(key_conv_tr_src)
            : nullptr;
    auto tr_diff_buffer = !jcp.uses_permw_transposition
            ? scratchpad.template get<diff_dst_data_t>(key_conv_tr_diff_dst)
            : nullptr;

    const int ndims = src_d.ndims();
    const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block)
            * rnd_up(jcp.ic, jcp.ic_block);
    const int n_wei_buffers
            = jcp.dst_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
    auto bia_reduction = wei_reduction + n_wei_buffers * wei_size;

    simple_barrier::ctx_t reduction_barrier;
    if (dnnl_thr_syncable()) simple_barrier::ctx_init(&reduction_barrier);

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;

    const int sp_nb = jcp.nb_reduce;
    const int mb_sp_work = jcp.mb * sp_nb;

    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    const bool is_ddst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    auto maybe_zero_icpad = [&](const int g_start, const int g_end,
                                    const int ocb_start, const int ocb_end) {
        // Write zeros to the IC-padded region.
        const int ic_tail = jcp.ic_without_padding % jcp.ic_block;
        if (ic_tail != 0) {
            for_(int g = g_start; g < g_end; ++g)
            for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) {
                const int z_icb = jcp.nb_bcast - 1;
                const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb)
                        + ic_tail * jcp.oc_block;
                diff_wei_data_t *z_wei = diff_weights + off;
                const int zero_work
                        = (jcp.nb_bcast * jcp.ic_block - jcp.ic_without_padding)
                        * jcp.oc_block;
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < zero_work; ++o) {
                    z_wei[o] = 0;
                }
            }
        }
    };
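
    // maybe_zero_icpad() clears the IC tail of the last IC block in
    // diff_weights: the padded weight elements past jcp.ic_without_padding
    // are never written by any kernel, so without this they would carry
    // stale scratchpad values into the final result.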

    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);

        const int ithr_ic_b = ithr % jcp.nthr_ic_b;
        const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
        const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
        const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g;

        /* reduction dimension */
        int mb_sp_b_start {0}, mb_sp_b_end {0};
        balance211(
                mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end);

        /* independent dimensions */
        int g_start {0}, oc_b_start {0}, ic_b_start {0};
        int g_end {0}, oc_b_end {0}, ic_b_end {0};

        balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
        balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end);
        balance211(
                jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end);

        float *diff_wei;
        if (diff_weights_type == data_type::bf16) {
            diff_wei = wei_reduction + (ithr_mb)*wei_size;
        } else {
            diff_wei = ithr_mb == 0
                    ? (float *)diff_weights
                    : (float *)wei_reduction + (ithr_mb - 1) * wei_size;
        }

        float *diff_bia = nullptr;
        if (jcp.with_bias) {
            const int bias_size = jcp.ngroups * jcp.nb_load * jcp.oc_block;
            if (jcp.bia_dt == data_type::bf16) {
                diff_bia = bia_reduction + (ithr_mb)*bias_size;
            } else {
                diff_bia = ithr_mb == 0
                        ? (float *)diff_bias
                        : (float *)bia_reduction + (ithr_mb - 1) * bias_size;
            }
        }

        int sp_b_step = 0;
        for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
                mb_sp_b += sp_b_step) {
            int img {0}, sp_b {0};
            nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
            sp_b_step = step(jcp.nb_reduce_blocking,
                    nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
                    jcp.nb_reduce_blocking_max);

            for (int g = g_start; g < g_end; ++g) {
                int load_step = 0;
                int bcast_step = 0;
                for (int ic_b = ic_b_start; ic_b < ic_b_end;
                        ic_b += bcast_step) {
                    bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
                            jcp.nb_bcast_blocking_max);
                    for (int oc_b = oc_b_start; oc_b < oc_b_end;
                            oc_b += load_step) {
                        load_step = step(nb_oc_blocking, oc_b_end - oc_b,
                                jcp.nb_load_blocking_max);

                        float *store_to;

                        const size_t off
                                = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                        store_to = diff_wei + off;

                        const bool is_src_layout_nxc
                                = utils::one_of(jcp.src_tag, format_tag::nwc,
                                        format_tag::nhwc, format_tag::ndhwc);
                        const int ic_off_idx = is_src_layout_nxc
                                ? g * jcp.ic + ic_b * jcp.ic_block
                                : g * nb_ic + ic_b;
                        const src_data_t *diff_src
                                = &src[src_d.blk_off(img, ic_off_idx)];
                        const int oc_off_idx = is_ddst_layout_nxc
                                ? g * jcp.oc + oc_b * jcp.oc_block
                                : g * nb_oc + oc_b;
                        const diff_dst_data_t *pdiff_dst
                                = &diff_dst[diff_dst_d.blk_off(
                                        img, oc_off_idx)];
                        const src_data_t *local_src = diff_src;

                        auto p = jit_1x1_conv_call_s();
                        auto rp = rtus_driver_t<avx512_core>::call_params_t();

                        p.output_stride = utils::rnd_up(jcp.ic, jcp.oc_block)
                                * jcp.oc_block * jcp.typesize_out;

                        p.load_dim = this_block_size(oc_b * jcp.oc_block,
                                jcp.oc, load_step * jcp.oc_block);

                        p.bcast_dim = this_block_size(ic_b * jcp.ic_block,
                                jcp.ic, bcast_step * jcp.ic_block);
                        rp.icb = p.bcast_dim;
                        p.output_data = store_to;

                        p.reduce_dim = sp_b_step * jcp.reduce_block;
                        if (!jcp.uses_permw_transposition)
                            p.reduce_dim = nstl::min(p.reduce_dim,
                                    (size_t)jcp.reduce_dim
                                            - sp_b * jcp.reduce_block);

                        rp.os = p.reduce_dim;

                        p.first_last_flag = 0
                                | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST
                                                            : 0)
                                | (ic_b == 0 ? FLAG_COMPUTE_BIAS : 0);

                        const int sp = sp_b * jcp.reduce_block;
                        const int oc_mult = is_ddst_layout_nxc
                                ? jcp.ngroups * jcp.oc
                                : jcp.oc_block;
                        p.load_data = pdiff_dst + sp * oc_mult;

                        if (pd()->rtus_.reduce_src_) {
                            const int oh = sp / jcp.ow;
                            const int ow = sp % jcp.ow;

                            const int ih = oh * stride_h;
                            const int iw = ow * stride_w;
                            rp.iw_start = iw;

                            rp.ws = rtus_space
                                    + ithr * pd()->rtus_.space_per_thread_
                                    + sp * jcp.ic_block;

                            if (ndims == 3)
                                rp.src = local_src
                                        + iw * src_d.blocking_desc().strides[2];
                            else
                                rp.src = local_src
                                        + ih * src_d.blocking_desc().strides[2]
                                        + iw * src_d.blocking_desc().strides[3];
                            (*rtus_driver_)(&rp);

                            p.bcast_data = rp.ws;
                        } else {
                            const int ic_mult = is_src_layout_nxc
                                    ? jcp.ngroups * jcp.ic
                                    : jcp.ic_block;
                            p.bcast_data = local_src + sp * ic_mult;
                        }
                        if (!jcp.uses_permw_transposition) {
                            bf16_support::jit_call_t ptr;
                            ptr.nelems = p.reduce_dim;
                            const int thr_src_block_size
                                    = rnd_up(jcp.reduce_dim, 2) * jcp.ic_block
                                    * jcp.nb_bcast_blocking_max;
                            src_data_t *tr_src
                                    = &tr_src_buffer[ithr * thr_src_block_size];
                            for (int bs = 0; bs < bcast_step; bs++) {
                                const size_t src_off = bs * jcp.ic_block
                                        * (is_src_layout_nxc ? 1
                                                             : jcp.reduce_dim);
                                const size_t src_tr_off = bs
                                        * rnd_up(jcp.reduce_dim, 2)
                                        * jcp.ic_block;
                                src_data_t *curr_inp = &(
                                        (src_data_t *)p.bcast_data)[src_off];
                                src_data_t *curr_out = &tr_src[src_tr_off];
                                const int ch_work = nstl::min<int>(
                                        p.bcast_dim - bs * jcp.bcast_block, 16);
                                assert(ch_work <= 16);
                                ptr.mask = (1 << ch_work) - 1;
                                ptr.inp = (void *)curr_inp;
                                ptr.out = (void *)curr_out;
                                if (is_src_layout_nxc)
                                    (*tr_reorder_nhwc_src_)(&ptr);
                                else
                                    (*tr_reorder_)(&ptr);
                            }

                            p.bcast_data = (void *)tr_src;
                            const int thr_dst_block_size
                                    = rnd_up(jcp.reduce_dim, 2) * jcp.oc_block
                                    * jcp.nb_load_blocking_max;
                            diff_dst_data_t *tr_diff_dst = &tr_diff_buffer[ithr
                                    * thr_dst_block_size];
                            for (int ls = 0; ls < load_step; ls++) {
                                const size_t ddst_off = ls * jcp.oc_block
                                        * (is_ddst_layout_nxc ? 1 : jcp.os);
                                const size_t ddst_tr_off = ls
                                        * rnd_up(jcp.reduce_dim, 2)
                                        * jcp.oc_block;
                                diff_dst_data_t *curr_inp
                                        = &((diff_dst_data_t *)
                                                        p.load_data)[ddst_off];
                                diff_dst_data_t *curr_out
                                        = &tr_diff_dst[ddst_tr_off];
                                const int ch_work = nstl::min<int>(
                                        p.load_dim - ls * jcp.load_block, 16);
                                ptr.mask = (1 << ch_work) - 1;
                                ptr.inp = (void *)curr_inp;
                                ptr.out = (void *)curr_out;
                                if (is_ddst_layout_nxc)
                                    (*tr_reorder_nhwc_ddst_)(&ptr);
                                else
                                    (*tr_reorder_)(&ptr);
                            }
                            p.load_data = (void *)tr_diff_dst;
                        } // if (!jcp.uses_permw_transposition)
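
                        // The reorders above repack the bcast and load data
                        // so that pairs of consecutive reduce-dim elements
                        // are interleaved (reduce_dim rounded up to an even
                        // count), which is the layout the bf16 FMA sequence
                        // consumes; the mask covers a partial tail of up to
                        // 16 rows per reorder call.
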
                        p.bias_data = diff_bia
                                ? &diff_bia[oc_off_idx
                                        * (is_ddst_layout_nxc ? 1
                                                              : jcp.oc_block)]
                                : nullptr;
                        (*kernel_)(&p);
                    }
                }
            }
        }
    };

    auto ker_reduce_and_convert_diff_wei_bia = [&](const int ithr,
                                                       const int nthr) {
        assert(nthr == jcp.nthr);

        const int ithr_ic_b = ithr % jcp.nthr_ic_b;
        const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
        const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
        const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g;

        /* independent dimensions */
        int g_start {0}, oc_b_start {0}, ic_b_start {0};
        int g_end {0}, oc_b_end {0}, ic_b_end {0};

        balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
        balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end);
        balance211(
                jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end);

        const int g_work = g_end - g_start;
        const int oc_b_work = oc_b_end - oc_b_start;
        const int ic_b_work = ic_b_end - ic_b_start;

        const int _start_nthr_mb = 1;
        const bool is_bf16_out = diff_weights_type == data_type::bf16;
        const bool is_bf16_bias
                = jcp.with_bias && jcp.bia_dt == data_type::bf16;
        /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */
        if (jcp.nthr_mb > _start_nthr_mb) {
            if (dnnl_thr_syncable())
                simple_barrier::barrier(&reduction_barrier, jcp.nthr);
            const int work = g_work * oc_b_work * ic_b_work;
            int start {0}, end {0};
            balance211(work, jcp.nthr_mb, ithr_mb, start, end);
            if (start == end) return;

            for (int thr_mb = _start_nthr_mb; thr_mb < jcp.nthr_mb; ++thr_mb) {
                int w = start;
                int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0};
                nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
                        oc_b_work, sub_ic_b_start, ic_b_work);
                while (w < end) {
                    const int g = g_start + sub_g_start;
                    const int oc_b = oc_b_start + sub_oc_b_start;
                    const int ic_b = ic_b_start + sub_ic_b_start;
                    const int ic_to_accumulate
                            = nstl::min(end - w, ic_b_work - sub_ic_b_start)
                            * jcp.ic_block;
                    const int acc_size
                            = this_block_size(ic_b * jcp.ic_block,
                                      jcp.ic_without_padding, ic_to_accumulate)
                            * jcp.oc_block;

                    const size_t off
                            = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                    float *wei_reduced = is_bf16_out
                            ? wei_reduction + off
                            : (float *)diff_weights + off;

                    const int thr_mb_buffer_idx
                            = is_bf16_out ? thr_mb : thr_mb - 1;
                    float *wei_to_reduce = wei_reduction
                            + thr_mb_buffer_idx * wei_size + off;
                    if (is_bf16_out && thr_mb == jcp.nthr_mb - 1)
                        // the last iteration for bfloat16 requires conversion
                        // and store to the diff_weights array
                        add_floats_and_cvt_to_bfloat16(
                                (bfloat16_t *)(diff_weights + off), wei_reduced,
                                wei_to_reduce, acc_size);
                    else
                        acc_ker_->accumulate(
                                wei_reduced, wei_to_reduce, acc_size);

                    nd_iterator_jump(w, end, sub_g_start, g_work,
                            sub_oc_b_start, oc_b_work, sub_ic_b_start,
                            ic_b_work);
                }

                if (jcp.with_bias && ithr_ic_b == 0 && ic_b_work > 0
                        && ithr_mb == 0) {
                    for (int g = g_start; g < g_end; g++) {
                        float *bias_reduced
                                = is_bf16_bias ? bia_reduction : diff_bias;
                        const int thr_mb_buffer_idx
                                = is_bf16_bias ? thr_mb : thr_mb - 1;
                        const int bias_buf_size
                                = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block);
                        float *bias_to_reduce = bia_reduction
                                + thr_mb_buffer_idx * bias_buf_size;
                        const size_t acc_size
                                = this_block_size(oc_b_start * jcp.oc_block,
                                        jcp.oc_without_padding,
                                        (oc_b_end - oc_b_start) * jcp.oc_block);
                        const int idx = g * rnd_up(jcp.oc, jcp.oc_block)
                                + oc_b_start * jcp.oc_block;
                        if (is_bf16_bias && thr_mb == jcp.nthr_mb - 1) {
                            // the last iteration for bfloat16 requires
                            // conversion and store to the diff_bias array
                            const int diff_bias_idx
                                    = g * jcp.oc_without_padding
                                    + oc_b_start * jcp.oc_block;
                            bfloat16_t *diff_bias_result
                                    = CTX_OUT_MEM(
                                              bfloat16_t *, DNNL_ARG_DIFF_BIAS)
                                    + diff_bias_idx;
                            add_floats_and_cvt_to_bfloat16(diff_bias_result,
                                    &bias_reduced[idx], &bias_to_reduce[idx],
                                    acc_size);
                        } else {
                            acc_ker_->accumulate(&bias_reduced[idx],
                                    &bias_to_reduce[idx], acc_size);
                        }
                    }
                }
            }
        } else {
            if (is_bf16_out) {
                const auto ic_work = nstl::min(jcp.ic, ic_b_end * jcp.ic_block)
                        - ic_b_start * jcp.ic_block;
                for_(int g = g_start; g < g_end; g++)
                for (int oc_b = oc_b_start; oc_b < oc_b_end; oc_b++) {
                    const size_t acc_size = (size_t)ic_work * jcp.oc_block;
                    const size_t off
                            = wht_blk_off(diff_weights_d, g, oc_b, ic_b_start);

                    cvt_float_to_bfloat16((bfloat16_t *)(diff_weights + off),
                            (const float *)(wei_reduction + off), acc_size);
                }
            }

            if (is_bf16_bias && ithr_ic_b == 0 && ic_b_work > 0) {
                for (int g = g_start; g < g_end; g++) {
                    const int result_start_idx = g * jcp.oc_without_padding
                            + oc_b_start * jcp.oc_block;
                    const int buffer_start_idx
                            = g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b_start * jcp.oc_block;
                    const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                    oc_b_end * jcp.oc_block)
                            - oc_b_start * jcp.oc_block;
                    bfloat16_t *diff_bias_result
                            = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS)
                            + result_start_idx;
                    float *buffer = bia_reduction + buffer_start_idx;
                    cvt_float_to_bfloat16(diff_bias_result, buffer, acc_size);
                }
            }
        }
        if (ic_b_end >= jcp.nb_bcast) {
            maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end);
        }
    };
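
    // Reduction scheme: each reduction-dim thread slice accumulates into its
    // own f32 buffer (the user buffer itself for the first slice when the
    // weights are f32), and the lambda above then sums the remaining copies
    // linearly over thr_mb, fusing the final addition with the f32 -> bf16
    // down-conversion on the last pass when the weights themselves are bf16.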

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);
        ker(ithr, jcp.nthr);
        if (dnnl_thr_syncable())
            ker_reduce_and_convert_diff_wei_bia(ithr, jcp.nthr);
    });

    if (!dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            assert(nthr == jcp.nthr);
            ker_reduce_and_convert_diff_wei_bia(ithr, jcp.nthr);
        });
    }

    if (pd()->jcp_.bia_dt == data_type::f32
            && jcp.oc_without_padding % jcp.oc_block) {
        auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
        utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
    }
}

REG_AVX512_ISA(
        template struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<
                data_type::f32>);
REG_AVX512_ISA(
        template struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<
                data_type::bf16>);

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl