/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"

#include "common/bfloat16.hpp"

#include "cpu/x64/jit_uni_dw_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::data_type;

template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
void jit_uni_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

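    // The jit kernel consumes the bias in f32. A bf16 bias is therefore
    // converted into the f32 scratchpad first, and any padded output
    // channels past oc_without_padding are zero-filled.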
    f32_data_t *bias = nullptr;
    if (pd()->desc()->bias_desc.data_type == bf16) {
        auto bias_in = CTX_IN_MEM(const bf16_data_t *, DNNL_ARG_BIAS);
        bias = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                key_conv_bias_bf16_convert_wsp);
        cvt_bfloat16_to_float(bias, bias_in, jcp.oc_without_padding);
        utils::array_set(bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
    } else {
        auto bias_in = CTX_IN_MEM(const f32_data_t *, DNNL_ARG_BIAS);
        if (pd()->wants_padded_bias()) {
            auto padded_bias
                    = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                            key_conv_padded_bias);
            utils::array_copy(padded_bias, bias_in, jcp.oc_without_padding);
            utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                    jcp.oc - jcp.oc_without_padding);
            bias = padded_bias;
        } else
            bias = const_cast<float *>(bias_in);
    }

    const int dil_h = jcp.dilate_h + 1;
    const int str_h = jcp.stride_h;
    const int ch_step = jcp.nb_ch_blocking;
    const int ow = 0;
    const int iw = 0;
    const int kw = 0;
    const int chb_work = utils::div_up(jcp.nb_ch, ch_step);
    const auto is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
    const auto is_dst_layout_nxc = jcp.dst_tag == format_tag::nhwc;

    const int work_amount = jcp.mb * chb_work * jcp.oh;
    const auto nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        int n {0}, chb {0}, oh {0};
        if (jcp.loop_order == loop_ngcw)
            utils::nd_iterator_init(
                    start, n, jcp.mb, chb, chb_work, oh, jcp.oh);
        else if (jcp.loop_order == loop_nhwcg)
            utils::nd_iterator_init(
                    start, n, jcp.mb, oh, jcp.oh, chb, chb_work);
        else
            assert(!"unsupported loop order");

        auto iwork = start;
        while (iwork < end) {

            int ch = chb * ch_step;

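            // Work out how many filter rows fall into the top/bottom padding
            // for this output row so the kernel only reads valid input rows:
            // kh is the first valid filter row and kh_padding is the number
            // of rows that remain after clipping.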
            const int i_t_overflow
                    = nstl::max(0, (int)(jcp.t_pad - oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp.ih,
                              (int)(oh * str_h + (jcp.kh - 1) * dil_h
                                      - jcp.t_pad + 1))
                    - jcp.ih;

            const int ih
                    = nstl::max((int)(oh * str_h - jcp.t_pad
                                        + div_up(i_t_overflow, dil_h) * dil_h),
                            0);
            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            const auto ic_off_idx = is_src_layout_nxc ? ch * jcp.ch_block : ch;
            const auto oc_off_idx = is_dst_layout_nxc ? ch * jcp.ch_block : ch;

            auto par_conv = jit_conv_call_s();
            par_conv.src = jcp.is_fused_conv
                    ? src
                    : &src[src_d.blk_off(n, ic_off_idx, ih, iw)];
            par_conv.dst = &dst[dst_d.blk_off(n, oc_off_idx, oh, ow)];

            par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias) par_conv.bias = &bias[bias_d.blk_off(ch * jcp.ch_block)];

            par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);

            assert(IMPLICATION(
                    jcp.loop_order == loop_nhwcg, is_src_layout_nxc));
            // For is_src_layout_nxc maximize jit work along contiguous dim.
            const int work_rem = end - iwork;
            par_conv.load_work = utils::this_block_size(ch * jcp.ch_block,
                    jcp.oc_without_padding,
                    (is_src_layout_nxc ? work_rem * ch_step : ch_step)
                            * jcp.ch_block);

            par_conv.oc_l_off = ch * jcp.ch_block;
            par_conv.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec.data();
            par_conv.dst_orig = dst;
            (*kernel_)(&par_conv);

            if (jcp.loop_order == loop_ngcw) {
                ++iwork;
                utils::nd_iterator_step(n, jcp.mb, chb, chb_work, oh, jcp.oh);
            } else if (jcp.loop_order == loop_nhwcg) {
                utils::nd_iterator_jump(
                        iwork, end, n, jcp.mb, oh, jcp.oh, chb, chb_work);
            } else
                assert(!"unsupported loop order");
        }
    });

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

REG_AVX512_ISA(
        template struct jit_uni_dw_convolution_fwd_t<avx512_core, bf16, f32>);
REG_AVX512_ISA(template struct jit_uni_dw_convolution_fwd_t<avx512_core, bf16>);
REG_AVX512_ISA(template struct jit_uni_dw_convolution_fwd_t<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_convolution_fwd_t<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_convolution_fwd_t<sse41, f32>);

template <cpu_isa_t isa, data_type_t diff_dst_type, data_type_t diff_src_type>
void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
        diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;

    auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
                                 int i_t_overflow, int i_b_overflow,
                                 int stride_off_h, int ch, int n,
                                 int work_remaining) {
        auto par_conv = jit_conv_call_s();
        const bool is_dsrc_layout_nxc
                = utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc);
        const bool is_ddst_layout_nxc
                = utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc);
        const int nb_ch_blocking = jcp.nb_ch_blocking;

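        // Amount of filter overlap with the left/right padding for this
        // diff_src column; the corresponding diff_dst column and the
        // in-filter offsets are derived from it below.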
        const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad));
        const int i_r_overflow
                = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) - jcp.r_pad));

        int ow = iw + jcp.l_pad - i_r_overflow;
        int stride_off_w = ow % jcp.stride_w;
        ow /= jcp.stride_w;

        const int ic_offset = is_dsrc_layout_nxc ? ch * jcp.ch_block : ch;
        par_conv.src = &diff_src[diff_src_d.blk_off(n, ic_offset, ih, iw)];
        const int oc_offset = is_ddst_layout_nxc ? ch * jcp.ch_block : ch;
        par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, oc_offset, oh, ow)];
        par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0,
                i_b_overflow + stride_off_h, i_r_overflow + stride_off_w)];

        par_conv.kh_padding = nstl::max(
                0, jcp.kh - i_t_overflow - i_b_overflow - stride_off_h);
        par_conv.kw_padding = nstl::max(
                0, jcp.kw - i_l_overflow - i_r_overflow - stride_off_w);

        par_conv.ur_str_w = ur_str_w;

        const size_t ch_work = (is_ddst_layout_nxc ? work_remaining : 1)
                * nb_ch_blocking * jcp.ch_block;
        const size_t load_work
                = utils::this_block_size(static_cast<size_t>(ch * jcp.ch_block),
                        static_cast<size_t>(jcp.oc), ch_work);
        par_conv.ch_blocks = load_work;

        return par_conv;
    };

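    // aux_w bounds the stride-unrolled main loop over diff_src width;
    // columns beyond it overlap the right padding and are handled one
    // stride at a time in the border loop below.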
    const int aux_w
            = nstl::min(jcp.iw, jcp.iw - jcp.kw + jcp.r_pad + jcp.stride_w);
    const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
    const dim_t work_amount = jcp.mb * chb_work * jcp.ih;
    const auto nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        dim_t n {0}, chb {0}, ih {0};
        if (jcp.loop_order == loop_ngcw)
            utils::nd_iterator_init(
                    start, n, jcp.mb, chb, chb_work, ih, jcp.ih);
        else if (jcp.loop_order == loop_nhwcg)
            utils::nd_iterator_init(
                    start, n, jcp.mb, ih, jcp.ih, chb, chb_work);
        else
            assert(!"unsupported loop order");

        auto iwork = start;
        while (iwork < end) {
            int ch = chb * jcp.nb_ch_blocking;

            const int work_rem = end - iwork;
            const dim_t i_t_overflow
                    = nstl::max(dim_t(0), jcp.kh - 1 - ih - jcp.t_pad);
            const dim_t i_b_overflow = nstl::max(
                    dim_t(0), jcp.kh - 1 - (jcp.ih - 1 - ih) - jcp.b_pad);

            int oh = ih + jcp.t_pad - i_b_overflow;
            int stride_off_h = oh % jcp.stride_h;
            oh /= jcp.stride_h;

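            // Each pass handles one stride phase: diff_src columns that share
            // the same remainder modulo stride_w are updated together, split
            // into left-border, main (unrolled), and right-border parts.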
            for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) {
                // left border
                int iw = i_str_w;
                int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw);
                int ur_str_w = 1;
                for (; iw < l_border; iw += jcp.stride_w) {
                    jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                            ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n,
                            work_rem);

                    (*kernel_)(&par_conv);
                }

                // main loop
                ur_str_w = (aux_w - iw) / jcp.stride_w;
                if (ur_str_w > 0) {
                    jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                            ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n,
                            work_rem);

                    (*kernel_)(&par_conv);

                    iw += ur_str_w * jcp.stride_w;
                }

                // right border
                ur_str_w = 1;
                for (; iw < jcp.iw; iw += jcp.stride_w) {
                    jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                            ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n,
                            work_rem);

                    (*kernel_)(&par_conv);
                }
            }
            if (jcp.loop_order == loop_ngcw) {
                ++iwork;
                utils::nd_iterator_step(n, jcp.mb, chb, chb_work, ih, jcp.ih);
            } else if (jcp.loop_order == loop_nhwcg) {
                utils::nd_iterator_jump(
                        iwork, end, n, jcp.mb, ih, jcp.ih, chb, chb_work);
            } else
                assert(!"unsupported loop order");
        }
    });
}

REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_data_t<avx512_core,
        bf16, f32>);
REG_AVX512_ISA(
        template struct jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16>);
REG_AVX512_ISA(
        template struct jit_uni_dw_convolution_bwd_data_t<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_convolution_bwd_data_t<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_convolution_bwd_data_t<sse41, f32>);

template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
jit_uni_dw_convolution_bwd_weights_t<isa, src_type, diff_weights_type>::
        jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd)
    : primitive_t(apd), acc_ker_(nullptr), kernel_(nullptr) {}

template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
        diff_weights_type>::execute_backward_weights_nxc(const exec_ctx_t &ctx)
        const {
    const auto &jcp = pd()->jcp_;

    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights
            = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto diff_wei_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bias_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);

    auto diff_bias_f32_to_bf16_accum
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bias_bf16_convert_wsp);
    float *diff_bias = jcp.bia_dt == bf16
            ? diff_bias_f32_to_bf16_accum
            : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);

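    // Threads are partitioned over channel blocks (nthr_g), minibatch
    // (nthr_mb), and output-height blocks (nthr_oh). Only the "main" thread
    // of each channel block writes directly into the user f32 diff_weights /
    // diff_bias; the remaining threads accumulate into per-thread f32
    // reduction buffers that the nxc reduction step combines afterwards.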
    const int ch_block = jcp.ch_block;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        auto conv_params = jit_dw_conv_call_s();
        const int h_block_size = jcp.oh_blk_size;

        const int ch_outer_blocks
                = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
        const int ithr_g = ithr % jcp.nthr_g;
        int g_start {0}, g_end {0};
        balance211(ch_outer_blocks, jcp.nthr_g, ithr_g, g_start, g_end);

        const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
        int mb_start {0}, mb_end {0};
        balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);

        const int ithr_oh = (ithr / (jcp.nthr_mb * jcp.nthr_g)) % jcp.nthr_oh;
        const int nb_oh = div_up(jcp.oh, jcp.oh_blk_size);
        int nb_oh_start {0}, nb_oh_end {0};
        balance211(nb_oh, jcp.nthr_oh, ithr_oh, nb_oh_start, nb_oh_end);

        const size_t wei_size
                = utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw;
        const bool main_thread = ithr_mb == 0 && ithr_oh == 0;
        const int offset_wei_buffer = diff_weights_type == f32 ? 1 : 0;
        const int ithr_block = ithr_mb * jcp.nthr_oh + ithr_oh;
        f32_data_t *ithr_diff_weights
                = (main_thread && diff_weights_type == f32)
                ? (f32_data_t *)diff_weights
                : diff_wei_reduction_buffer
                        + static_cast<size_t>(
                                (ithr_block - offset_wei_buffer) * wei_size);

        const size_t filter_g_step
                = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block);
        const size_t src_h_step = static_cast<size_t>(jcp.iw * jcp.ngroups);
        const size_t ddst_h_step = static_cast<size_t>(jcp.ow * jcp.ngroups);
        const size_t bias_size = static_cast<size_t>(jcp.ngroups);
        auto ithr_diff_bias = main_thread ? diff_bias
                : diff_bias_reduction_buffer
                ? diff_bias_reduction_buffer + (ithr_block - 1) * bias_size
                : nullptr;
        const int g_step = jcp.nb_ch_blocking;
        for (int g_ = g_start; g_ < g_end; ++g_) {
            const int g = g_ * jcp.nb_ch_blocking;
            unsigned char last_g_flag
                    = (g + g_step) >= jcp.nb_ch ? FLAG_OC_LAST : 0;
            unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
            unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
            for (int mb = mb_start; mb < mb_end; mb++) {
                for (int nb_oh = nb_oh_start; nb_oh < nb_oh_end; ++nb_oh) {
                    const int oh_s = nb_oh * h_block_size;
                    const int h_work = nstl::min(h_block_size, jcp.oh - oh_s);
                    const int oh_e = oh_s + h_work;
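                    // Map this oh block back to input rows and drop the
                    // filter rows that fall into the top/bottom padding.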
                    const int ih = -jcp.t_pad + oh_s * jcp.stride_h;
                    const int kh_top_overflow = nstl::max(0, -ih);
                    const int kh_bottom_overflow
                            = nstl::max(0, ih - jcp.ih + jcp.kh);
                    const int kh_padding_offset
                            = nstl::min(jcp.kh - 1, kh_top_overflow);
                    conv_params.kh_count
                            = jcp.kh - kh_top_overflow - kh_bottom_overflow;
                    conv_params.filter_pad_off
                            = static_cast<size_t>(kh_padding_offset * jcp.kw
                                    * ch_block * jcp.typesize_out);
                    const size_t filter_g_offset
                            = static_cast<size_t>(g) * filter_g_step;
                    conv_params.filter = &ithr_diff_weights[filter_g_offset];

                    const size_t g_offset
                            = static_cast<size_t>(g * jcp.ch_block);
                    const size_t src_offset = static_cast<size_t>(mb * jcp.ih
                                                      + ih + kh_top_overflow)
                            * src_h_step;
                    conv_params.input = &src[src_offset + g_offset];
                    const size_t diff_dst_off
                            = static_cast<size_t>(mb * jcp.oh + oh_s)
                            * ddst_h_step;
                    conv_params.output = &diff_dst[diff_dst_off + g_offset];
                    conv_params.oh_index = oh_s;
                    conv_params.oh_count = oh_e;
                    if (jcp.with_bias)
                        conv_params.bias = &ithr_diff_bias[g_offset];

                    conv_params.exec_flags
                            = zero_filter_flag | zero_bias_flag | last_g_flag;
                    (*kernel_)(&conv_params);

                    // flags are only needed during the first kernel call
                    zero_filter_flag &= ~FLAG_ZERO_FILTER;
                    zero_bias_flag &= ~FLAG_ZERO_BIAS;
                }
            }
        }
    });
}

template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
        diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx)
        const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights
            = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto diff_wei_reduction_buf
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bia_reduction_buf
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);

    const auto &jcp = pd()->jcp_;

    float *diff_bias = nullptr;
    if (jcp.bia_dt == bf16) {
        diff_bias = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                key_conv_bias_bf16_convert_wsp);
    } else {
        diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);
    }

    const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
    const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;

    const int ch_block = jcp.ch_block;

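    // Fills the jit call structure for one (mb, group, oh-block) tile,
    // offsetting the filter and the input pointer by the rows that fall
    // into the top padding.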
    auto set_kernel_params
            = [&](jit_dw_conv_call_s *conv_params, const int batch,
                      const int group, const int oh_start, const int work_size,
                      const unsigned char exec_flag, const size_t kh_padding,
                      const size_t filter_off) {
                  const int tpad_underflow_off = jcp.t_pad - filter_off;

                  conv_params->exec_flags = exec_flag;
                  conv_params->kh_count = jcp.kh - kh_padding;

                  const int oh_s = oh_start;
                  const int oh_e = oh_start + work_size;
                  const int ih_s = oh_s * jcp.stride_h;

                  conv_params->filter_pad_off
                          = filter_off * jcp.kw * ch_block * jcp.typesize_out;
                  conv_params->oh_index = oh_s;
                  conv_params->oh_count = oh_e;

                  size_t diff_dst_off
                          = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
                                    + oh_start)
                          * jcp.ow;

                  size_t src_off
                          = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
                                    + ih_s - tpad_underflow_off)
                          * jcp.iw;

                  conv_params->output = &diff_dst[diff_dst_off * ch_block];
                  conv_params->input = &src[src_off * ch_block];
              };

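    // Threads split the work over channel blocks and minibatch. For f32
    // weights the first mb thread accumulates straight into diff_weights;
    // all other threads (and every thread for bf16 weights) write to the
    // f32 reduction buffers, which execute_reduction() folds together.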
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);

        auto conv_params = jit_dw_conv_call_s();
        const int h_block_size = jcp.oh_blk_size;
        const int nb_ch = jcp.nb_ch;

        /* assign iteration space to thread */
        const int ithr_g = ithr % jcp.nthr_g;
        const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;

        /* split dimensions */
        int g_start {0}, g_end {0};
        balance211(nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);

        int mb_start {0}, mb_end {0};
        balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);

        auto i_mb = diff_weights_type == bf16 ? ithr_mb : ithr_mb - 1;
        f32_data_t *diff_wei = (ithr_mb == 0 && diff_weights_type == f32)
                ? (f32_data_t *)diff_weights
                : diff_wei_reduction_buf + i_mb * wei_size;

        auto diff_bia = ithr_mb == 0
                ? diff_bias
                : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;

        for (int g = g_start; g < g_end; ++g) {
            unsigned char last_g_flag = g == nb_ch - 1 ? FLAG_OC_LAST : 0;
            unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
            unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;

            size_t diff_wei_off = g * jcp.kh * jcp.kw;
            conv_params.filter = &diff_wei[diff_wei_off * ch_block];

            if (jcp.with_bias) conv_params.bias = &diff_bia[g * ch_block];

            for (int mb = mb_start; mb < mb_end; ++mb) {
                int oh = 0;
                while (oh < jcp.oh) {
                    const int h_work = nstl::min(h_block_size, jcp.oh - oh);
                    auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
                    auto kh_b_padding
                            = (oh * jcp.stride_h + jcp.kh > jcp.ih + jcp.t_pad)
                            ? nstl::max(jcp.b_pad - (h_work - 1), 0)
                            : 0;

                    set_kernel_params(&conv_params, mb, g, oh, h_work,
                            zero_filter_flag | zero_bias_flag | last_g_flag,
                            kh_t_padding + kh_b_padding, kh_t_padding);
                    (*kernel_)(&conv_params);

                    zero_bias_flag &= ~FLAG_ZERO_BIAS;
                    zero_filter_flag &= ~FLAG_ZERO_FILTER;
                    oh += h_work;
                }
            }
        }
    });
}

/* TODO: Performing a Parallel Reduction could potentially improve performance;
 * this should be explored in the future if further optimizations are required.
 */
template <>
void jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16>::execute_reduction(
        const exec_ctx_t &ctx) const {

    const auto &jcp = pd()->jcp_;
    assert(jcp.dwei_dt == bf16);

    auto diff_wei_reduction_buf
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bia_reduction_buf
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);
    auto diff_weights
            = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_f32_to_bf16_accum
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bias_bf16_convert_wsp);
    float *diff_bias = jcp.bia_dt == bf16
            ? diff_bias_f32_to_bf16_accum
            : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);

    const size_t wei_size
            = utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw;
    const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
    const int ch_block = jcp.ch_block;

    /* Apply single-threaded 'mb' reduction on the bias */
    if (jcp.with_bias && jcp.nthr_mb > 1) {
        for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
            size_t b_accum_offset = (thr_mb - 1) * bias_size;
            const int bias_ch_tail = jcp.ch_tail;
            const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch;

            for (int g = 0; g < nb_ch; ++g) {
                /* Reduction on Bias */
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    size_t bias_offset = g * ch_block + g_block;
                    diff_bias[bias_offset]
                            += diff_bia_reduction_buf[b_accum_offset
                                    + bias_offset];
                }
            }
            for (int g = 0; g < bias_ch_tail; ++g) {
                size_t bias_offset = static_cast<size_t>(nb_ch * ch_block + g);
                diff_bias[bias_offset]
                        += diff_bia_reduction_buf[b_accum_offset + bias_offset];
            }
        }
    }
    if (jcp.bia_dt == bf16) {
        auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS);
        cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.oc_without_padding);
    }
    /* Apply single-threaded 'mb' reduction on the weights */
    if (jcp.nthr_mb > 1) {
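        // Per-thread f32 copies 2..nthr_mb-1 are first accumulated into
        // buffer 0; the final pass adds buffer 1 and converts the sum to
        // bf16 directly into diff_weights.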
        for (int thr_mb = 2; thr_mb < jcp.nthr_mb; ++thr_mb) {
            size_t mb_accum_offset = thr_mb * wei_size;
            acc_ker_->accumulate(&diff_wei_reduction_buf[0],
                    &diff_wei_reduction_buf[mb_accum_offset], wei_size);
        }
        add_floats_and_cvt_to_bfloat16((bfloat16_t *)&(diff_weights[0]),
                (float *)&diff_wei_reduction_buf[0],
                (float *)&diff_wei_reduction_buf[wei_size], wei_size);
    } else {
        cvt_float_to_bfloat16((bfloat16_t *)&(diff_weights[0]),
                (const float *)&(diff_wei_reduction_buf[0]), wei_size);
    }
}

template <>
void jit_uni_dw_convolution_bwd_weights_t<sse41, f32>::execute_reduction(
        const exec_ctx_t &ctx) const {

    auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);
    auto diff_wei_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bias_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);

    const auto &jcp = pd()->jcp_;

    /* Apply single-threaded 'mb' reduction */
    for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
        const int ch_block = jcp.ch_block;
        const size_t wei_size
                = static_cast<size_t>(jcp.ngroups * jcp.kh * jcp.kw);
        const size_t mb_accum_offset = (thr_mb - 1) * wei_size;
        const size_t bias_size = jcp.ngroups;
        const size_t b_accum_offset = (thr_mb - 1) * bias_size;

        const int bias_ch_tail = jcp.ch_tail;
        const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch;
        for (int g = 0; g < nb_ch; ++g) {
            if (jcp.with_bias) {
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t bias_offset
                            = static_cast<size_t>(g * ch_block + g_block);
                    diff_bias[bias_offset]
                            += diff_bias_reduction_buffer[b_accum_offset
                                    + bias_offset];
                }
            }
            for_(int kh = 0; kh < jcp.kh; ++kh)
            for (int kw = 0; kw < jcp.kw; ++kw) {
                const size_t wei_sp_offset = (g * jcp.kh + kh) * jcp.kw + kw;
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t wei_offset = static_cast<size_t>(
                            wei_sp_offset * ch_block + g_block);
                    diff_weights[wei_offset]
                            += diff_wei_reduction_buffer[mb_accum_offset
                                    + wei_offset];
                }
            }
        }
        // handle reduction for channel tail
        if (jcp.with_bias) {
            for (int g = 0; g < bias_ch_tail; ++g) {
                const size_t bias_offset
                        = static_cast<size_t>(nb_ch * ch_block + g);
                diff_bias[bias_offset]
                        += diff_bias_reduction_buffer[b_accum_offset
                                + bias_offset];
            }
        }
        if (bias_ch_tail > 0) {
            for_(int kh = 0; kh < jcp.kh; ++kh)
            for (int kw = 0; kw < jcp.kw; ++kw) {
                const size_t wei_sp_offset = static_cast<size_t>(
                        ((nb_ch * jcp.kh + kh) * jcp.kw + kw) * ch_block);
                for (int g = 0; g < bias_ch_tail; ++g) {
                    const size_t wei_offset = wei_sp_offset + g;
                    diff_weights[wei_offset]
                            += diff_wei_reduction_buffer[mb_accum_offset
                                    + wei_offset];
                }
            }
        }
    }
}

template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
        diff_weights_type>::execute_reduction(const exec_ctx_t &ctx) const {

    const auto &jcp = pd()->jcp_;
    assert(everyone_is(f32, diff_weights_type, jcp.dwei_dt));

    auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_wei_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bias_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);
    auto diff_bias_f32_to_bf16_accum
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bias_bf16_convert_wsp);
    float *diff_bias = jcp.bia_dt == bf16
            ? diff_bias_f32_to_bf16_accum
            : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);

    /* Apply single-threaded 'mb' reduction */
    for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
        const int ch_block = jcp.ch_block;
        const size_t wei_size
                = static_cast<size_t>(jcp.ngroups * jcp.kh * jcp.kw);
        const size_t mb_accum_offset = (thr_mb - 1) * wei_size;
        const size_t bias_size = jcp.ngroups;
        const size_t b_accum_offset = (thr_mb - 1) * bias_size;

        if (jcp.with_bias) { // Reduction on Bias:
            const int bias_ch_tail = jcp.ch_tail;
            const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch;
            for (int g = 0; g < nb_ch; ++g) {
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t bias_offset
                            = static_cast<size_t>(g * ch_block + g_block);
                    diff_bias[bias_offset]
                            += diff_bias_reduction_buffer[b_accum_offset
                                    + bias_offset];
                }
            }
            // handle reduction for channel tail
            for (int g = 0; g < bias_ch_tail; g++) {
                const size_t bias_offset
                        = static_cast<size_t>(nb_ch * ch_block + g);
                diff_bias[bias_offset]
                        += diff_bias_reduction_buffer[b_accum_offset
                                + bias_offset];
            }
        }
        acc_ker_->accumulate(&diff_weights[0],
                &diff_wei_reduction_buffer[mb_accum_offset], wei_size);
    }

    if (jcp.bia_dt == bf16) {
        auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS);
        cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.ngroups);
    }
}

template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type>
void jit_uni_dw_convolution_bwd_weights_t<isa, src_type,
        diff_weights_type>::execute_reduction_nxc(const exec_ctx_t &ctx) const {

    const auto &jcp = pd()->jcp_;
    auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_wei_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bia_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);
    auto diff_bias_f32_to_bf16_accum
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bias_bf16_convert_wsp);
    float *diff_bias = jcp.bia_dt == bf16
            ? diff_bias_f32_to_bf16_accum
            : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);

    const size_t wei_size = static_cast<size_t>(
            utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw);

    // TODO: maybe add 'KH' as another parallel dimension to increase partition
    // space
    parallel_nd(jcp.nb_ch, [&](int NB_CH) {
        const size_t nb_ch_step
                = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block);
        const size_t wei_offset = NB_CH * nb_ch_step;

        f32_data_t *ithr_diff_weights = diff_weights_type == f32
                ? (f32_data_t *)&diff_weights[wei_offset]
                : &diff_wei_reduction_buffer[wei_offset];
        auto ithr_dwei_reduction_buff = &diff_wei_reduction_buffer[wei_offset];

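        // Serially fold the copies produced by the (nthr_mb x nthr_oh)
        // reduction threads into this channel block's destination.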
        const int thr_work = jcp.nthr_mb * jcp.nthr_oh;
        for (int ithr_reduction = 0; ithr_reduction < thr_work - 1;
                ++ithr_reduction) {
            const int mb_ithr = ithr_reduction % jcp.nthr_mb;
            const int oh_ithr = (ithr_reduction / jcp.nthr_mb) % jcp.nthr_oh;
            const size_t ithr_offset
                    = static_cast<size_t>(mb_ithr * jcp.nthr_oh + oh_ithr);
            const int offset_wei_buffer = diff_weights_type == bf16 ? 1 : 0;
            const size_t reduction_offset
                    = (ithr_offset + offset_wei_buffer) * wei_size;
            const size_t reduction_size
                    = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block);
            acc_ker_->accumulate(&ithr_diff_weights[0],
                    &ithr_dwei_reduction_buff[reduction_offset],
                    reduction_size);

            const bool compute_bias = jcp.with_bias;
            const int ch_block = jcp.ch_block;
            const size_t bias_size = jcp.ngroups;
            const size_t bias_accum_offset = ithr_offset * bias_size;
            if (compute_bias) {
                const size_t nb_ch_offset = NB_CH * ch_block;
                const int bias_ch_tail = jcp.ch_tail;
                const bool compute_ch_tail
                        = (NB_CH == jcp.nb_ch - 1) && bias_ch_tail > 0;
                if (!compute_ch_tail) {
                    PRAGMA_OMP_SIMD()
                    for (int g_block = 0; g_block < ch_block; ++g_block) {
                        const size_t bias_offset
                                = static_cast<size_t>(nb_ch_offset + g_block);
                        diff_bias[bias_offset]
                                += diff_bia_reduction_buffer[bias_accum_offset
                                        + bias_offset];
                    }
                } else {
                    // handle reduction for channel tail
                    for (int g = 0; g < bias_ch_tail; g++) {
                        const size_t bias_offset
                                = static_cast<size_t>(nb_ch_offset + g);
                        diff_bias[bias_offset]
                                += diff_bia_reduction_buffer[bias_accum_offset
                                        + bias_offset];
                    }
                }
            }
        }
    });

    if (diff_weights_type == bf16) {
        cvt_float_to_bfloat16((bfloat16_t *)&(diff_weights[0]),
                (const float *)&(diff_wei_reduction_buffer[0]), wei_size);
    }

    if (jcp.bia_dt == bf16) {
        auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS);
        cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.oc_without_padding);
    }
}

template <>
void jit_uni_dw_convolution_bwd_weights_t<sse41, f32>::execute_reduction_nxc(
        const exec_ctx_t &ctx) const {

    auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS);

    auto diff_wei_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_wei_reduction);
    auto diff_bia_reduction_buffer
            = ctx.get_scratchpad_grantor().template get<f32_data_t>(
                    key_conv_bia_reduction);

    const auto &jcp = pd()->jcp_;

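    // The sse41 specialization performs the reduction over the
    // (nthr_mb x nthr_oh) thread copies with plain loops rather than the
    // acc_ker_ accumulation kernel used by the generic implementation.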
    const int thr_work = jcp.nthr_mb * jcp.nthr_oh;
    int ithr_reduction = 1;
    while (ithr_reduction < thr_work) {
        const int mb_ithr = (ithr_reduction - 1) % jcp.nthr_mb;
        const int oh_ithr = ((ithr_reduction - 1) / jcp.nthr_mb) % jcp.nthr_oh;
        const size_t ithr_offset
                = static_cast<size_t>(mb_ithr * jcp.nthr_oh + oh_ithr);
        const size_t wei_size = static_cast<size_t>(
                utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw);
        const size_t reduction_offset = ithr_offset * wei_size;

        const int ch_block = jcp.ch_block;
        const size_t bias_size = jcp.ngroups;
        size_t b_accum_offset = ithr_offset * bias_size;

        const bool compute_bias = jcp.with_bias;
        const int bias_ch_tail = jcp.ch_tail;
        const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch;
        for (int g = 0; g < nb_ch; ++g) {
            if (compute_bias) {
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t bias_offset
                            = static_cast<size_t>(g * ch_block + g_block);
                    diff_bias[bias_offset]
                            += diff_bia_reduction_buffer[b_accum_offset
                                    + bias_offset];
                }
            }
            for_(int kh = 0; kh < jcp.kh; ++kh)
            for (int kw = 0; kw < jcp.kw; ++kw) {
                const size_t wei_sp_offset
                        = static_cast<size_t>((g * jcp.kh + kh) * jcp.kw + kw);
                PRAGMA_OMP_SIMD()
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t wei_offset = static_cast<size_t>(
                            wei_sp_offset * ch_block + g_block);
                    diff_weights[wei_offset]
                            += diff_wei_reduction_buffer[reduction_offset
                                    + wei_offset];
                }
            }
        }
        // handle reduction for channel tail
        if (compute_bias) {
            for (int g = 0; g < bias_ch_tail; ++g) {
                const size_t bias_offset
                        = static_cast<size_t>(nb_ch * ch_block + g);
                diff_bias[bias_offset]
                        += diff_bia_reduction_buffer[b_accum_offset
                                + bias_offset];
            }
        }
        if (bias_ch_tail > 0) {
            for_(int kh = 0; kh < jcp.kh; ++kh)
            for (int kw = 0; kw < jcp.kw; ++kw) {
                const size_t wei_sp_offset = static_cast<size_t>(
                        (nb_ch * jcp.kh + kh) * jcp.kw + kw);
                for (int g = 0; g < bias_ch_tail; ++g) {
                    const size_t wei_offset
                            = static_cast<size_t>(wei_sp_offset * ch_block + g);
                    diff_weights[wei_offset]
                            += diff_wei_reduction_buffer[reduction_offset
                                    + wei_offset];
                }
            }
        }

        ithr_reduction++;
    }
}

REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core,
        bf16>);
REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core,
        bf16, f32>);
REG_AVX512_ISA(
        template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core, f32>);
REG_AVX2_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx2, f32>);
REG_SSE41_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<sse41, f32>);
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
972