1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/memory_tracking.hpp" |
20 | |
21 | #include "common/bfloat16.hpp" |
22 | |
23 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::status; |
31 | using namespace dnnl::impl::memory_tracking::names; |
32 | using namespace dnnl::impl::utils; |
33 | using namespace dnnl::impl::data_type; |
34 | |
35 | template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type> |
36 | void jit_uni_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward( |
37 | const exec_ctx_t &ctx) const { |
38 | const auto &jcp = pd()->jcp_; |
39 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
40 | auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); |
41 | auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); |
42 | const auto post_ops_binary_rhs_arg_vec |
43 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
44 | |
45 | const memory_desc_wrapper src_d(pd()->src_md()); |
46 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
47 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
48 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
49 | |
50 | f32_data_t *bias = nullptr; |
51 | if (pd()->desc()->bias_desc.data_type == bf16) { |
52 | auto bias_in = CTX_IN_MEM(const bf16_data_t *, DNNL_ARG_BIAS); |
53 | bias = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
54 | key_conv_bias_bf16_convert_wsp); |
55 | cvt_bfloat16_to_float(bias, bias_in, jcp.oc_without_padding); |
56 | utils::array_set(bias + jcp.oc_without_padding, 0.f, |
57 | jcp.oc - jcp.oc_without_padding); |
58 | } else { |
59 | auto bias_in = CTX_IN_MEM(const f32_data_t *, DNNL_ARG_BIAS); |
60 | if (pd()->wants_padded_bias()) { |
61 | auto padded_bias |
62 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
63 | key_conv_padded_bias); |
64 | utils::array_copy(padded_bias, bias_in, jcp.oc_without_padding); |
65 | utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, |
66 | jcp.oc - jcp.oc_without_padding); |
67 | bias = padded_bias; |
68 | } else |
69 | bias = const_cast<float *>(bias_in); |
70 | } |
71 | |
72 | const int dil_h = jcp.dilate_h + 1; |
73 | const int str_h = jcp.stride_h; |
74 | const int ch_step = jcp.nb_ch_blocking; |
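    // The jit kernel traverses the full output row internally, so the width
    // indices passed to blk_off() below stay fixed at zero.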
75 | const int ow = 0; |
76 | const int iw = 0; |
77 | const int kw = 0; |
78 | const int chb_work = utils::div_up(jcp.nb_ch, ch_step); |
79 | const auto is_src_layout_nxc = jcp.src_tag == format_tag::nhwc; |
80 | const auto is_dst_layout_nxc = jcp.dst_tag == format_tag::nhwc; |
81 | |
82 | const int work_amount = jcp.mb * chb_work * jcp.oh; |
83 | const auto nthr = jcp.nthr; |
84 | |
85 | parallel(nthr, [&](const int ithr, const int nthr) { |
86 | int start {0}, end {0}; |
87 | balance211(work_amount, nthr, ithr, start, end); |
88 | |
89 | int n {0}, chb {0}, oh {0}; |
90 | if (jcp.loop_order == loop_ngcw) |
91 | utils::nd_iterator_init( |
92 | start, n, jcp.mb, chb, chb_work, oh, jcp.oh); |
93 | else if (jcp.loop_order == loop_nhwcg) |
94 | utils::nd_iterator_init( |
95 | start, n, jcp.mb, oh, jcp.oh, chb, chb_work); |
96 | else |
97 | assert(!"unsupported loop order" ); |
98 | |
99 | auto iwork = start; |
100 | while (iwork < end) { |
101 | |
102 | int ch = chb * ch_step; |
103 | |
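            // Count the input rows that fall into the top/bottom padding for
            // this output row; they shrink the effective kernel height below.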
104 | const int i_t_overflow |
105 | = nstl::max(0, (int)(jcp.t_pad - oh * str_h)); |
106 | const int i_b_overflow |
107 | = nstl::max(jcp.ih, |
108 | (int)(oh * str_h + (jcp.kh - 1) * dil_h |
109 | - jcp.t_pad + 1)) |
110 | - jcp.ih; |
111 | |
112 | const int ih |
113 | = nstl::max((int)(oh * str_h - jcp.t_pad |
114 | + div_up(i_t_overflow, dil_h) * dil_h), |
115 | 0); |
116 | const int kh = div_up(i_t_overflow, dil_h); |
117 | const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h) |
118 | - div_up(i_b_overflow, dil_h); |
119 | |
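            // nxc (channels-last) layouts are indexed by raw channel, blocked
            // layouts by channel block.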
120 | const auto ic_off_idx = is_src_layout_nxc ? ch * jcp.ch_block : ch; |
121 | const auto oc_off_idx = is_dst_layout_nxc ? ch * jcp.ch_block : ch; |
122 | |
123 | auto par_conv = jit_conv_call_s(); |
124 | par_conv.src = jcp.is_fused_conv |
125 | ? src |
126 | : &src[src_d.blk_off(n, ic_off_idx, ih, iw)]; |
127 | par_conv.dst = &dst[dst_d.blk_off(n, oc_off_idx, oh, ow)]; |
128 | |
129 | par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)]; |
130 | if (bias) par_conv.bias = &bias[bias_d.blk_off(ch * jcp.ch_block)]; |
131 | |
132 | par_conv.kh_padding = (size_t)nstl::max(0, kh_padding); |
133 | |
134 | assert(IMPLICATION( |
135 | jcp.loop_order == loop_nhwcg, is_src_layout_nxc)); |
            // For is_src_layout_nxc, maximize the jit kernel's work along the
            // contiguous channel dimension.
137 | const int work_rem = end - iwork; |
138 | par_conv.load_work = utils::this_block_size(ch * jcp.ch_block, |
139 | jcp.oc_without_padding, |
140 | (is_src_layout_nxc ? work_rem * ch_step : ch_step) |
141 | * jcp.ch_block); |
142 | |
143 | par_conv.oc_l_off = ch * jcp.ch_block; |
144 | par_conv.post_ops_binary_rhs_arg_vec |
145 | = post_ops_binary_rhs_arg_vec.data(); |
146 | par_conv.dst_orig = dst; |
147 | (*kernel_)(&par_conv); |
148 | |
149 | if (jcp.loop_order == loop_ngcw) { |
150 | ++iwork; |
151 | utils::nd_iterator_step(n, jcp.mb, chb, chb_work, oh, jcp.oh); |
152 | } else if (jcp.loop_order == loop_nhwcg) { |
153 | utils::nd_iterator_jump( |
154 | iwork, end, n, jcp.mb, oh, jcp.oh, chb, chb_work); |
155 | } else |
156 | assert(!"unsupported loop order" ); |
157 | } |
158 | }); |
159 | |
160 | if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); |
161 | } |
162 | |
163 | REG_AVX512_ISA( |
164 | template struct jit_uni_dw_convolution_fwd_t<avx512_core, bf16, f32>); |
165 | REG_AVX512_ISA(template struct jit_uni_dw_convolution_fwd_t<avx512_core, bf16>); |
166 | REG_AVX512_ISA(template struct jit_uni_dw_convolution_fwd_t<avx512_core, f32>); |
167 | REG_AVX2_ISA(template struct jit_uni_dw_convolution_fwd_t<avx2, f32>); |
168 | REG_SSE41_ISA(template struct jit_uni_dw_convolution_fwd_t<sse41, f32>); |
169 | |
170 | template <cpu_isa_t isa, data_type_t diff_dst_type, data_type_t diff_src_type> |
171 | void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type, |
172 | diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const { |
173 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
174 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
175 | auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); |
176 | |
177 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
178 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
179 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
180 | |
181 | const auto &jcp = pd()->jcp_; |
182 | |
183 | auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih, |
184 | int i_t_overflow, int i_b_overflow, |
185 | int stride_off_h, int ch, int n, |
186 | int work_remaining) { |
187 | auto par_conv = jit_conv_call_s(); |
188 | const bool is_dsrc_layout_nxc |
189 | = utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc); |
190 | const bool is_ddst_layout_nxc |
191 | = utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc); |
192 | const int nb_ch_blocking = jcp.nb_ch_blocking; |
193 | |
194 | const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad)); |
195 | const int i_r_overflow |
196 | = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw) - jcp.r_pad)); |
197 | |
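        // Map this diff_src column back to the aligned diff_dst column and
        // the stride-phase offset used for the filter start.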
198 | int ow = iw + jcp.l_pad - i_r_overflow; |
199 | int stride_off_w = ow % jcp.stride_w; |
200 | ow /= jcp.stride_w; |
201 | |
202 | const int ic_offset = is_dsrc_layout_nxc ? ch * jcp.ch_block : ch; |
203 | par_conv.src = &diff_src[diff_src_d.blk_off(n, ic_offset, ih, iw)]; |
204 | const int oc_offset = is_ddst_layout_nxc ? ch * jcp.ch_block : ch; |
205 | par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, oc_offset, oh, ow)]; |
206 | par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, |
207 | i_b_overflow + stride_off_h, i_r_overflow + stride_off_w)]; |
208 | |
209 | par_conv.kh_padding = nstl::max( |
210 | 0, jcp.kh - i_t_overflow - i_b_overflow - stride_off_h); |
211 | par_conv.kw_padding = nstl::max( |
212 | 0, jcp.kw - i_l_overflow - i_r_overflow - stride_off_w); |
213 | |
214 | par_conv.ur_str_w = ur_str_w; |
215 | |
216 | const size_t ch_work = (is_ddst_layout_nxc ? work_remaining : 1) |
217 | * nb_ch_blocking * jcp.ch_block; |
218 | const size_t load_work |
219 | = utils::this_block_size(static_cast<size_t>(ch * jcp.ch_block), |
220 | static_cast<size_t>(jcp.oc), ch_work); |
221 | par_conv.ch_blocks = load_work; |
222 | |
223 | return par_conv; |
224 | }; |
225 | |
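    // aux_w bounds the stride-unrolled main loop: columns beyond it need
    // right-border handling.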
226 | const int aux_w |
227 | = nstl::min(jcp.iw, jcp.iw - jcp.kw + jcp.r_pad + jcp.stride_w); |
228 | const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); |
229 | const dim_t work_amount = jcp.mb * chb_work * jcp.ih; |
230 | const auto nthr = jcp.nthr; |
231 | |
232 | parallel(nthr, [&](const int ithr, const int nthr) { |
233 | dim_t start {0}, end {0}; |
234 | balance211(work_amount, nthr, ithr, start, end); |
235 | dim_t n {0}, chb {0}, ih {0}; |
236 | if (jcp.loop_order == loop_ngcw) |
237 | utils::nd_iterator_init( |
238 | start, n, jcp.mb, chb, chb_work, ih, jcp.ih); |
239 | else if (jcp.loop_order == loop_nhwcg) |
240 | utils::nd_iterator_init( |
241 | start, n, jcp.mb, ih, jcp.ih, chb, chb_work); |
242 | else |
243 | assert(!"unsupported loop order" ); |
244 | |
245 | auto iwork = start; |
246 | while (iwork < end) { |
247 | int ch = chb * jcp.nb_ch_blocking; |
248 | |
249 | const int work_rem = end - iwork; |
250 | const dim_t i_t_overflow |
251 | = nstl::max(dim_t(0), jcp.kh - 1 - ih - jcp.t_pad); |
252 | const dim_t i_b_overflow = nstl::max( |
253 | dim_t(0), jcp.kh - 1 - (jcp.ih - 1 - ih) - jcp.b_pad); |
254 | |
255 | int oh = ih + jcp.t_pad - i_b_overflow; |
256 | int stride_off_h = oh % jcp.stride_h; |
257 | oh /= jcp.stride_h; |
258 | |
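            // Each stride phase is processed in three steps: the left border,
            // a stride-unrolled main body, and the right border.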
259 | for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) { |
260 | // left border |
261 | int iw = i_str_w; |
262 | int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw); |
263 | int ur_str_w = 1; |
264 | for (; iw < l_border; iw += jcp.stride_w) { |
265 | jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, |
266 | ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n, |
267 | work_rem); |
268 | |
269 | (*kernel_)(&par_conv); |
270 | } |
271 | |
272 | // main loop |
273 | ur_str_w = (aux_w - iw) / jcp.stride_w; |
274 | if (ur_str_w > 0) { |
275 | jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, |
276 | ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n, |
277 | work_rem); |
278 | |
279 | (*kernel_)(&par_conv); |
280 | |
281 | iw += ur_str_w * jcp.stride_w; |
282 | } |
283 | |
284 | // right border |
285 | ur_str_w = 1; |
286 | for (; iw < jcp.iw; iw += jcp.stride_w) { |
287 | jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh, |
288 | ih, i_t_overflow, i_b_overflow, stride_off_h, ch, n, |
289 | work_rem); |
290 | |
291 | (*kernel_)(&par_conv); |
292 | } |
293 | } |
294 | if (jcp.loop_order == loop_ngcw) { |
295 | ++iwork; |
296 | utils::nd_iterator_step(n, jcp.mb, chb, chb_work, ih, jcp.ih); |
297 | } else if (jcp.loop_order == loop_nhwcg) { |
298 | utils::nd_iterator_jump( |
299 | iwork, end, n, jcp.mb, ih, jcp.ih, chb, chb_work); |
300 | } else |
301 | assert(!"unsupported loop order" ); |
302 | } |
303 | }); |
304 | } |
305 | |
306 | REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_data_t<avx512_core, |
307 | bf16, f32>); |
308 | REG_AVX512_ISA( |
309 | template struct jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16>); |
310 | REG_AVX512_ISA( |
311 | template struct jit_uni_dw_convolution_bwd_data_t<avx512_core, f32>); |
312 | REG_AVX2_ISA(template struct jit_uni_dw_convolution_bwd_data_t<avx2, f32>); |
313 | REG_SSE41_ISA(template struct jit_uni_dw_convolution_bwd_data_t<sse41, f32>); |
314 | |
315 | template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type> |
316 | jit_uni_dw_convolution_bwd_weights_t<isa, src_type, diff_weights_type>:: |
317 | jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd) |
318 | : primitive_t(apd), acc_ker_(nullptr), kernel_(nullptr) {} |
319 | |
320 | template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type> |
321 | void jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
322 | diff_weights_type>::execute_backward_weights_nxc(const exec_ctx_t &ctx) |
323 | const { |
324 | const auto &jcp = pd()->jcp_; |
325 | |
326 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
327 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
328 | auto diff_weights |
329 | = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
330 | |
331 | auto diff_wei_reduction_buffer |
332 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
333 | key_conv_wei_reduction); |
334 | auto diff_bias_reduction_buffer |
335 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
336 | key_conv_bia_reduction); |
337 | |
338 | auto diff_bias_f32_to_bf16_accum |
339 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
340 | key_conv_bias_bf16_convert_wsp); |
341 | float *diff_bias = jcp.bia_dt == bf16 |
342 | ? diff_bias_f32_to_bf16_accum |
343 | : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
344 | |
345 | const int ch_block = jcp.ch_block; |
346 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
347 | auto conv_params = jit_dw_conv_call_s(); |
348 | const int h_block_size = jcp.oh_blk_size; |
349 | |
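        // Threads form a 3D grid over (channel blocks, minibatch, oh blocks);
        // non-main threads accumulate partial results into scratch that a
        // separate reduction step combines.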
350 | const int ch_outer_blocks |
351 | = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking); |
352 | const int ithr_g = ithr % jcp.nthr_g; |
353 | int g_start {0}, g_end {0}; |
354 | balance211(ch_outer_blocks, jcp.nthr_g, ithr_g, g_start, g_end); |
355 | |
356 | const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; |
357 | int mb_start {0}, mb_end {0}; |
358 | balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); |
359 | |
360 | const int ithr_oh = (ithr / (jcp.nthr_mb * jcp.nthr_g)) % jcp.nthr_oh; |
361 | const int nb_oh = div_up(jcp.oh, jcp.oh_blk_size); |
362 | int nb_oh_start {0}, nb_oh_end {0}; |
363 | balance211(nb_oh, jcp.nthr_oh, ithr_oh, nb_oh_start, nb_oh_end); |
364 | |
365 | const size_t wei_size |
366 | = utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw; |
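        // Only the main thread may write the user's diff_weights, and only
        // for f32; every other thread gets its own slot in the reduction
        // scratch.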
367 | const bool main_thread = ithr_mb == 0 && ithr_oh == 0; |
368 | const int offset_wei_buffer = diff_weights_type == f32 ? 1 : 0; |
369 | const int ithr_block = ithr_mb * jcp.nthr_oh + ithr_oh; |
370 | f32_data_t *ithr_diff_weights |
371 | = (main_thread && diff_weights_type == f32) |
372 | ? (f32_data_t *)diff_weights |
373 | : diff_wei_reduction_buffer |
374 | + static_cast<size_t>( |
375 | (ithr_block - offset_wei_buffer) * wei_size); |
376 | |
377 | const size_t filter_g_step |
378 | = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block); |
379 | const size_t src_h_step = static_cast<size_t>(jcp.iw * jcp.ngroups); |
380 | const size_t ddst_h_step = static_cast<size_t>(jcp.ow * jcp.ngroups); |
381 | const size_t bias_size = static_cast<size_t>(jcp.ngroups); |
        f32_data_t *ithr_diff_bias = nullptr;
        if (main_thread)
            ithr_diff_bias = diff_bias;
        else if (diff_bias_reduction_buffer)
            ithr_diff_bias = diff_bias_reduction_buffer
                    + (ithr_block - 1) * bias_size;
387 | const int g_step = jcp.nb_ch_blocking; |
388 | for (int g_ = g_start; g_ < g_end; ++g_) { |
389 | const int g = g_ * jcp.nb_ch_blocking; |
390 | unsigned char last_g_flag |
391 | = (g + g_step) >= jcp.nb_ch ? FLAG_OC_LAST : 0; |
392 | unsigned char zero_filter_flag = FLAG_ZERO_FILTER; |
393 | unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0; |
394 | for (int mb = mb_start; mb < mb_end; mb++) { |
395 | for (int nb_oh = nb_oh_start; nb_oh < nb_oh_end; ++nb_oh) { |
396 | const int oh_s = nb_oh * h_block_size; |
397 | const int h_work = nstl::min(h_block_size, jcp.oh - oh_s); |
398 | const int oh_e = oh_s + h_work; |
399 | const int ih = -jcp.t_pad + oh_s * jcp.stride_h; |
400 | const int kh_top_overflow = nstl::max(0, -ih); |
401 | const int kh_bottom_overflow |
402 | = nstl::max(0, ih - jcp.ih + jcp.kh); |
403 | const int kh_padding_offset |
404 | = nstl::min(jcp.kh - 1, kh_top_overflow); |
405 | conv_params.kh_count |
406 | = jcp.kh - kh_top_overflow - kh_bottom_overflow; |
407 | conv_params.filter_pad_off |
408 | = static_cast<size_t>(kh_padding_offset * jcp.kw |
409 | * ch_block * jcp.typesize_out); |
410 | const size_t filter_g_offset |
411 | = static_cast<size_t>(g) * filter_g_step; |
412 | conv_params.filter = &ithr_diff_weights[filter_g_offset]; |
413 | |
414 | const size_t g_offset |
415 | = static_cast<size_t>(g * jcp.ch_block); |
416 | const size_t src_offset = static_cast<size_t>(mb * jcp.ih |
417 | + ih + kh_top_overflow) |
418 | * src_h_step; |
419 | conv_params.input = &src[src_offset + g_offset]; |
420 | const size_t diff_dst_off |
421 | = static_cast<size_t>(mb * jcp.oh + oh_s) |
422 | * ddst_h_step; |
423 | conv_params.output = &diff_dst[diff_dst_off + g_offset]; |
424 | conv_params.oh_index = oh_s; |
425 | conv_params.oh_count = oh_e; |
426 | if (jcp.with_bias) |
427 | conv_params.bias = &ithr_diff_bias[g_offset]; |
428 | |
429 | conv_params.exec_flags |
430 | = zero_filter_flag | zero_bias_flag | last_g_flag; |
431 | (*kernel_)(&conv_params); |
432 | |
433 | // flags are only needed during the first kernel call |
434 | zero_filter_flag &= ~FLAG_ZERO_FILTER; |
435 | zero_bias_flag &= ~FLAG_ZERO_BIAS; |
436 | } |
437 | } |
438 | } |
439 | }); |
440 | } |
441 | |
442 | template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type> |
443 | void jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
444 | diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx) |
445 | const { |
446 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
447 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
448 | auto diff_weights |
449 | = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
450 | |
451 | auto diff_wei_reduction_buf |
452 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
453 | key_conv_wei_reduction); |
454 | auto diff_bia_reduction_buf |
455 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
456 | key_conv_bia_reduction); |
457 | |
458 | const auto &jcp = pd()->jcp_; |
459 | |
460 | float *diff_bias = nullptr; |
461 | if (jcp.bia_dt == bf16) { |
462 | diff_bias = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
463 | key_conv_bias_bf16_convert_wsp); |
464 | } else { |
465 | diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
466 | } |
467 | |
468 | const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw; |
469 | const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; |
470 | |
471 | const int ch_block = jcp.ch_block; |
472 | |
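    // Fill the jit call descriptor for one (mb, group, oh-block) tile,
    // trimming kernel rows that fall into the top/bottom padding.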
473 | auto set_kernel_params |
474 | = [&](jit_dw_conv_call_s *conv_params, const int batch, |
475 | const int group, const int oh_start, const int work_size, |
476 | const unsigned char exec_flag, const size_t kh_padding, |
477 | const size_t filter_off) { |
478 | const int tpad_underflow_off = jcp.t_pad - filter_off; |
479 | |
480 | conv_params->exec_flags = exec_flag; |
481 | conv_params->kh_count = jcp.kh - kh_padding; |
482 | |
483 | const int oh_s = oh_start; |
484 | const int oh_e = oh_start + work_size; |
485 | const int ih_s = oh_s * jcp.stride_h; |
486 | |
487 | conv_params->filter_pad_off |
488 | = filter_off * jcp.kw * ch_block * jcp.typesize_out; |
489 | conv_params->oh_index = oh_s; |
490 | conv_params->oh_count = oh_e; |
491 | |
492 | size_t diff_dst_off |
493 | = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh |
494 | + oh_start) |
495 | * jcp.ow; |
496 | |
497 | size_t src_off |
498 | = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih |
499 | + ih_s - tpad_underflow_off) |
500 | * jcp.iw; |
501 | |
502 | conv_params->output = &diff_dst[diff_dst_off * ch_block]; |
503 | conv_params->input = &src[src_off * ch_block]; |
504 | }; |
505 | |
506 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
507 | assert(nthr == jcp.nthr); |
508 | |
509 | auto conv_params = jit_dw_conv_call_s(); |
510 | const int h_block_size = jcp.oh_blk_size; |
511 | const int nb_ch = jcp.nb_ch; |
512 | |
513 | /* assign iteration space to thread */ |
514 | const int ithr_g = ithr % jcp.nthr_g; |
515 | const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb; |
516 | |
517 | /* split dimensions */ |
518 | int g_start {0}, g_end {0}; |
519 | balance211(nb_ch, jcp.nthr_g, ithr_g, g_start, g_end); |
520 | |
521 | int mb_start {0}, mb_end {0}; |
522 | balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end); |
523 | |
524 | auto i_mb = diff_weights_type == bf16 ? ithr_mb : ithr_mb - 1; |
525 | f32_data_t *diff_wei = (ithr_mb == 0 && diff_weights_type == f32) |
526 | ? (f32_data_t *)diff_weights |
527 | : diff_wei_reduction_buf + i_mb * wei_size; |
528 | |
529 | auto diff_bia = ithr_mb == 0 |
530 | ? diff_bias |
531 | : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size; |
532 | |
533 | for (int g = g_start; g < g_end; ++g) { |
534 | unsigned char last_g_flag = g == nb_ch - 1 ? FLAG_OC_LAST : 0; |
535 | unsigned char zero_filter_flag = FLAG_ZERO_FILTER; |
536 | unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0; |
537 | |
538 | size_t diff_wei_off = g * jcp.kh * jcp.kw; |
539 | conv_params.filter = &diff_wei[diff_wei_off * ch_block]; |
540 | |
541 | if (jcp.with_bias) conv_params.bias = &diff_bia[g * ch_block]; |
542 | |
543 | for (int mb = mb_start; mb < mb_end; ++mb) { |
544 | int oh = 0; |
545 | while (oh < jcp.oh) { |
546 | const int h_work = nstl::min(h_block_size, jcp.oh - oh); |
547 | auto kh_t_padding = nstl::max(0, jcp.t_pad - oh); |
548 | auto kh_b_padding |
549 | = (oh * jcp.stride_h + jcp.kh > jcp.ih + jcp.t_pad) |
550 | ? nstl::max(jcp.b_pad - (h_work - 1), 0) |
551 | : 0; |
552 | |
553 | set_kernel_params(&conv_params, mb, g, oh, h_work, |
554 | zero_filter_flag | zero_bias_flag | last_g_flag, |
555 | kh_t_padding + kh_b_padding, kh_t_padding); |
556 | (*kernel_)(&conv_params); |
557 | |
558 | zero_bias_flag &= ~FLAG_ZERO_BIAS; |
559 | zero_filter_flag &= ~FLAG_ZERO_FILTER; |
560 | oh += h_work; |
561 | } |
562 | } |
563 | } |
564 | }); |
565 | } |
566 | |
/* TODO: A parallel reduction could improve performance; explore it if further
 * optimization is required.
 */
570 | template <> |
571 | void jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16>::execute_reduction( |
572 | const exec_ctx_t &ctx) const { |
573 | |
574 | const auto &jcp = pd()->jcp_; |
575 | assert(jcp.dwei_dt == bf16); |
576 | |
577 | auto diff_wei_reduction_buf |
578 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
579 | key_conv_wei_reduction); |
580 | auto diff_bia_reduction_buf |
581 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
582 | key_conv_bia_reduction); |
583 | auto diff_weights |
584 | = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
585 | auto diff_bias_f32_to_bf16_accum |
586 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
587 | key_conv_bias_bf16_convert_wsp); |
588 | float *diff_bias = jcp.bia_dt == bf16 |
589 | ? diff_bias_f32_to_bf16_accum |
590 | : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
591 | |
592 | const size_t wei_size |
593 | = utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw; |
594 | const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0; |
595 | const int ch_block = jcp.ch_block; |
596 | |
    /* Apply single-threaded 'mb' reduction to the bias partials */
598 | if (jcp.with_bias && jcp.nthr_mb > 1) { |
599 | for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { |
600 | size_t b_accum_offset = (thr_mb - 1) * bias_size; |
601 | const int bias_ch_tail = jcp.ch_tail; |
602 | const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch; |
603 | |
604 | for (int g = 0; g < nb_ch; ++g) { |
605 | /* Reduction on Bias */ |
606 | PRAGMA_OMP_SIMD() |
607 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
608 | size_t bias_offset = g * ch_block + g_block; |
609 | diff_bias[bias_offset] |
610 | += diff_bia_reduction_buf[b_accum_offset |
611 | + bias_offset]; |
612 | } |
613 | } |
614 | for (int g = 0; g < bias_ch_tail; ++g) { |
615 | size_t bias_offset = static_cast<size_t>(nb_ch * ch_block + g); |
616 | diff_bias[bias_offset] |
617 | += diff_bia_reduction_buf[b_accum_offset + bias_offset]; |
618 | } |
619 | } |
620 | } |
621 | if (jcp.bia_dt == bf16) { |
622 | auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS); |
623 | cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.oc_without_padding); |
624 | } |
    /* Apply single-threaded 'mb' reduction to the weights partials */
626 | if (jcp.nthr_mb > 1) { |
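        // Slot 1 of the scratch is folded in during the bf16 conversion
        // below, so explicit accumulation starts from slot 2.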
627 | for (int thr_mb = 2; thr_mb < jcp.nthr_mb; ++thr_mb) { |
628 | size_t mb_accum_offset = thr_mb * wei_size; |
629 | acc_ker_->accumulate(&diff_wei_reduction_buf[0], |
630 | &diff_wei_reduction_buf[mb_accum_offset], wei_size); |
631 | } |
632 | add_floats_and_cvt_to_bfloat16((bfloat16_t *)&(diff_weights[0]), |
633 | (float *)&diff_wei_reduction_buf[0], |
634 | (float *)&diff_wei_reduction_buf[wei_size], wei_size); |
635 | } else { |
636 | cvt_float_to_bfloat16((bfloat16_t *)&(diff_weights[0]), |
637 | (const float *)&(diff_wei_reduction_buf[0]), wei_size); |
638 | } |
639 | } |
640 | |
641 | template <> |
642 | void jit_uni_dw_convolution_bwd_weights_t<sse41, f32>::execute_reduction( |
643 | const exec_ctx_t &ctx) const { |
644 | |
645 | auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
646 | auto diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
647 | auto diff_wei_reduction_buffer |
648 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
649 | key_conv_wei_reduction); |
650 | auto diff_bias_reduction_buffer |
651 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
652 | key_conv_bia_reduction); |
653 | |
654 | const auto &jcp = pd()->jcp_; |
655 | |
656 | /* Apply single-threaded 'mb' reduction */ |
657 | for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { |
658 | const int ch_block = jcp.ch_block; |
659 | const size_t wei_size |
660 | = static_cast<size_t>(jcp.ngroups * jcp.kh * jcp.kw); |
661 | const size_t mb_accum_offset = (thr_mb - 1) * wei_size; |
662 | const size_t bias_size = jcp.ngroups; |
663 | const size_t b_accum_offset = (thr_mb - 1) * bias_size; |
664 | |
665 | const int bias_ch_tail = jcp.ch_tail; |
666 | const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch; |
667 | for (int g = 0; g < nb_ch; ++g) { |
668 | if (jcp.with_bias) { |
669 | PRAGMA_OMP_SIMD() |
670 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
671 | const size_t bias_offset |
672 | = static_cast<size_t>(g * ch_block + g_block); |
673 | diff_bias[bias_offset] |
674 | += diff_bias_reduction_buffer[b_accum_offset |
675 | + bias_offset]; |
676 | } |
677 | } |
678 | for_(int kh = 0; kh < jcp.kh; ++kh) |
679 | for (int kw = 0; kw < jcp.kw; ++kw) { |
680 | const size_t wei_sp_offset = (g * jcp.kh + kh) * jcp.kw + kw; |
681 | PRAGMA_OMP_SIMD() |
682 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
683 | const size_t wei_offset = static_cast<size_t>( |
684 | wei_sp_offset * ch_block + g_block); |
685 | diff_weights[wei_offset] |
686 | += diff_wei_reduction_buffer[mb_accum_offset |
687 | + wei_offset]; |
688 | } |
689 | } |
690 | } |
691 | // handle reduction for channel tail |
692 | if (jcp.with_bias) { |
693 | for (int g = 0; g < bias_ch_tail; ++g) { |
694 | const size_t bias_offset |
695 | = static_cast<size_t>(nb_ch * ch_block + g); |
696 | diff_bias[bias_offset] |
697 | += diff_bias_reduction_buffer[b_accum_offset |
698 | + bias_offset]; |
699 | } |
700 | } |
701 | if (bias_ch_tail > 0) { |
702 | for_(int kh = 0; kh < jcp.kh; ++kh) |
703 | for (int kw = 0; kw < jcp.kw; ++kw) { |
704 | const size_t wei_sp_offset = static_cast<size_t>( |
705 | ((nb_ch * jcp.kh + kh) * jcp.kw + kw) * ch_block); |
706 | for (int g = 0; g < bias_ch_tail; ++g) { |
707 | const size_t wei_offset = wei_sp_offset + g; |
708 | diff_weights[wei_offset] |
709 | += diff_wei_reduction_buffer[mb_accum_offset |
710 | + wei_offset]; |
711 | } |
712 | } |
713 | } |
714 | } |
715 | } |
716 | |
717 | template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type> |
718 | void jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
719 | diff_weights_type>::execute_reduction(const exec_ctx_t &ctx) const { |
720 | |
721 | const auto &jcp = pd()->jcp_; |
722 | assert(everyone_is(f32, diff_weights_type, jcp.dwei_dt)); |
723 | |
724 | auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
725 | auto diff_wei_reduction_buffer |
726 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
727 | key_conv_wei_reduction); |
728 | auto diff_bias_reduction_buffer |
729 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
730 | key_conv_bia_reduction); |
731 | auto diff_bias_f32_to_bf16_accum |
732 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
733 | key_conv_bias_bf16_convert_wsp); |
734 | float *diff_bias = jcp.bia_dt == bf16 |
735 | ? diff_bias_f32_to_bf16_accum |
736 | : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
737 | |
738 | /* Apply single-threaded 'mb' reduction */ |
739 | for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { |
740 | const int ch_block = jcp.ch_block; |
741 | const size_t wei_size |
742 | = static_cast<size_t>(jcp.ngroups * jcp.kh * jcp.kw); |
743 | const size_t mb_accum_offset = (thr_mb - 1) * wei_size; |
744 | const size_t bias_size = jcp.ngroups; |
745 | const size_t b_accum_offset = (thr_mb - 1) * bias_size; |
746 | |
747 | if (jcp.with_bias) { // Reduction on Bias: |
748 | const int bias_ch_tail = jcp.ch_tail; |
749 | const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch; |
750 | for (int g = 0; g < nb_ch; ++g) { |
751 | PRAGMA_OMP_SIMD() |
752 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
753 | const size_t bias_offset |
754 | = static_cast<size_t>(g * ch_block + g_block); |
755 | diff_bias[bias_offset] |
756 | += diff_bias_reduction_buffer[b_accum_offset |
757 | + bias_offset]; |
758 | } |
759 | } |
760 | // handle reduction for channel tail |
761 | for (int g = 0; g < bias_ch_tail; g++) { |
762 | const size_t bias_offset |
763 | = static_cast<size_t>(nb_ch * ch_block + g); |
764 | diff_bias[bias_offset] |
765 | += diff_bias_reduction_buffer[b_accum_offset |
766 | + bias_offset]; |
767 | } |
768 | } |
769 | acc_ker_->accumulate(&diff_weights[0], |
770 | &diff_wei_reduction_buffer[mb_accum_offset], wei_size); |
771 | } |
772 | |
773 | if (jcp.bia_dt == bf16) { |
774 | auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS); |
775 | cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.ngroups); |
776 | } |
777 | } |
778 | |
779 | template <cpu_isa_t isa, data_type_t src_type, data_type_t diff_weights_type> |
780 | void jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
781 | diff_weights_type>::execute_reduction_nxc(const exec_ctx_t &ctx) const { |
782 | |
783 | const auto &jcp = pd()->jcp_; |
784 | auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
785 | auto diff_wei_reduction_buffer |
786 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
787 | key_conv_wei_reduction); |
788 | auto diff_bia_reduction_buffer |
789 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
790 | key_conv_bia_reduction); |
791 | auto diff_bias_f32_to_bf16_accum |
792 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
793 | key_conv_bias_bf16_convert_wsp); |
794 | float *diff_bias = jcp.bia_dt == bf16 |
795 | ? diff_bias_f32_to_bf16_accum |
796 | : CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
797 | |
798 | const size_t wei_size = static_cast<size_t>( |
799 | utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw); |
800 | |
801 | // TODO: maybe add 'KH' as another parallel dimension to increase partition |
802 | // space |
803 | parallel_nd(jcp.nb_ch, [&](int NB_CH) { |
804 | const size_t nb_ch_step |
805 | = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block); |
806 | const size_t wei_offset = NB_CH * nb_ch_step; |
807 | |
808 | f32_data_t *ithr_diff_weights = diff_weights_type == f32 |
809 | ? (f32_data_t *)&diff_weights[wei_offset] |
810 | : &diff_wei_reduction_buffer[wei_offset]; |
811 | auto ithr_dwei_reduction_buff = &diff_wei_reduction_buffer[wei_offset]; |
812 | |
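        // Fold every other thread's partial copy of this channel block into
        // the f32 destination; for bf16, slot 0 of the scratch is itself the
        // accumulation target, so thread partials start at slot 1.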
813 | const int thr_work = jcp.nthr_mb * jcp.nthr_oh; |
814 | for (int ithr_reduction = 0; ithr_reduction < thr_work - 1; |
815 | ++ithr_reduction) { |
816 | const int mb_ithr = ithr_reduction % jcp.nthr_mb; |
817 | const int oh_ithr = (ithr_reduction / jcp.nthr_mb) % jcp.nthr_oh; |
818 | const size_t ithr_offset |
819 | = static_cast<size_t>(mb_ithr * jcp.nthr_oh + oh_ithr); |
820 | const int offset_wei_buffer = diff_weights_type == bf16 ? 1 : 0; |
821 | const size_t reduction_offset |
822 | = (ithr_offset + offset_wei_buffer) * wei_size; |
823 | const size_t reduction_size |
824 | = static_cast<size_t>(jcp.kh * jcp.kw * jcp.ch_block); |
825 | acc_ker_->accumulate(&ithr_diff_weights[0], |
826 | &ithr_dwei_reduction_buff[reduction_offset], |
827 | reduction_size); |
828 | |
829 | const bool compute_bias = jcp.with_bias; |
830 | const int ch_block = jcp.ch_block; |
831 | const size_t bias_size = jcp.ngroups; |
832 | const size_t bias_accum_offset = ithr_offset * bias_size; |
833 | if (compute_bias) { |
834 | const size_t nb_ch_offset = NB_CH * ch_block; |
835 | const int bias_ch_tail = jcp.ch_tail; |
836 | const bool compute_ch_tail |
837 | = (NB_CH == jcp.nb_ch - 1) && bias_ch_tail > 0; |
838 | if (!compute_ch_tail) { |
839 | PRAGMA_OMP_SIMD() |
840 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
841 | const size_t bias_offset |
842 | = static_cast<size_t>(nb_ch_offset + g_block); |
843 | diff_bias[bias_offset] |
844 | += diff_bia_reduction_buffer[bias_accum_offset |
845 | + bias_offset]; |
846 | } |
847 | } else { |
848 | // handle reduction for channel tail |
849 | for (int g = 0; g < bias_ch_tail; g++) { |
850 | const size_t bias_offset |
851 | = static_cast<size_t>(nb_ch_offset + g); |
852 | diff_bias[bias_offset] |
853 | += diff_bia_reduction_buffer[bias_accum_offset |
854 | + bias_offset]; |
855 | } |
856 | } |
857 | } |
858 | } |
859 | }); |
860 | |
861 | if (diff_weights_type == bf16) { |
862 | cvt_float_to_bfloat16((bfloat16_t *)&(diff_weights[0]), |
863 | (const float *)&(diff_wei_reduction_buffer[0]), wei_size); |
864 | } |
865 | |
866 | if (jcp.bia_dt == bf16) { |
867 | auto diff_bias_in = CTX_OUT_MEM(bf16_data_t *, DNNL_ARG_DIFF_BIAS); |
868 | cvt_float_to_bfloat16(diff_bias_in, diff_bias, jcp.oc_without_padding); |
869 | } |
870 | } |
871 | |
872 | template <> |
873 | void jit_uni_dw_convolution_bwd_weights_t<sse41, f32>::execute_reduction_nxc( |
874 | const exec_ctx_t &ctx) const { |
875 | |
876 | auto diff_weights = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
877 | auto diff_bias = CTX_OUT_MEM(f32_data_t *, DNNL_ARG_DIFF_BIAS); |
878 | |
879 | auto diff_wei_reduction_buffer |
880 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
881 | key_conv_wei_reduction); |
882 | auto diff_bia_reduction_buffer |
883 | = ctx.get_scratchpad_grantor().template get<f32_data_t>( |
884 | key_conv_bia_reduction); |
885 | |
886 | const auto &jcp = pd()->jcp_; |
887 | |
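    // The sse41 specialization reduces the per-thread partials with explicit
    // simd loops instead of acc_ker_.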
888 | const int thr_work = jcp.nthr_mb * jcp.nthr_oh; |
889 | int ithr_reduction = 1; |
890 | while (ithr_reduction < thr_work) { |
891 | const int mb_ithr = (ithr_reduction - 1) % jcp.nthr_mb; |
892 | const int oh_ithr = ((ithr_reduction - 1) / jcp.nthr_mb) % jcp.nthr_oh; |
893 | const size_t ithr_offset |
894 | = static_cast<size_t>(mb_ithr * jcp.nthr_oh + oh_ithr); |
895 | const size_t wei_size = static_cast<size_t>( |
896 | utils::rnd_up(jcp.ngroups, jcp.ch_block) * jcp.kh * jcp.kw); |
897 | const size_t reduction_offset = ithr_offset * wei_size; |
898 | |
899 | const int ch_block = jcp.ch_block; |
900 | const size_t bias_size = jcp.ngroups; |
901 | size_t b_accum_offset = ithr_offset * bias_size; |
902 | |
903 | const bool compute_bias = jcp.with_bias; |
904 | const int bias_ch_tail = jcp.ch_tail; |
905 | const int nb_ch = bias_ch_tail > 0 ? jcp.nb_ch - 1 : jcp.nb_ch; |
906 | for (int g = 0; g < nb_ch; ++g) { |
907 | if (compute_bias) { |
908 | PRAGMA_OMP_SIMD() |
909 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
910 | const size_t bias_offset |
911 | = static_cast<size_t>(g * ch_block + g_block); |
912 | diff_bias[bias_offset] |
913 | += diff_bia_reduction_buffer[b_accum_offset |
914 | + bias_offset]; |
915 | } |
916 | } |
917 | for_(int kh = 0; kh < jcp.kh; ++kh) |
918 | for (int kw = 0; kw < jcp.kw; ++kw) { |
919 | const size_t wei_sp_offset |
920 | = static_cast<size_t>((g * jcp.kh + kh) * jcp.kw + kw); |
921 | PRAGMA_OMP_SIMD() |
922 | for (int g_block = 0; g_block < ch_block; ++g_block) { |
923 | const size_t wei_offset = static_cast<size_t>( |
924 | wei_sp_offset * ch_block + g_block); |
925 | diff_weights[wei_offset] |
926 | += diff_wei_reduction_buffer[reduction_offset |
927 | + wei_offset]; |
928 | } |
929 | } |
930 | } |
931 | // handle reduction for channel tail |
932 | if (compute_bias) { |
933 | for (int g = 0; g < bias_ch_tail; ++g) { |
934 | const size_t bias_offset |
935 | = static_cast<size_t>(nb_ch * ch_block + g); |
936 | diff_bias[bias_offset] |
937 | += diff_bia_reduction_buffer[b_accum_offset |
938 | + bias_offset]; |
939 | } |
940 | } |
941 | if (bias_ch_tail > 0) { |
942 | for_(int kh = 0; kh < jcp.kh; ++kh) |
943 | for (int kw = 0; kw < jcp.kw; ++kw) { |
944 | const size_t wei_sp_offset = static_cast<size_t>( |
945 | (nb_ch * jcp.kh + kh) * jcp.kw + kw); |
946 | for (int g = 0; g < bias_ch_tail; ++g) { |
947 | const size_t wei_offset |
948 | = static_cast<size_t>(wei_sp_offset * ch_block + g); |
949 | diff_weights[wei_offset] |
950 | += diff_wei_reduction_buffer[reduction_offset |
951 | + wei_offset]; |
952 | } |
953 | } |
954 | } |
955 | |
956 | ithr_reduction++; |
957 | } |
958 | } |
959 | |
960 | REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core, |
961 | bf16>); |
962 | REG_AVX512_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core, |
963 | bf16, f32>); |
964 | REG_AVX512_ISA( |
965 | template struct jit_uni_dw_convolution_bwd_weights_t<avx512_core, f32>); |
966 | REG_AVX2_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<avx2, f32>); |
967 | REG_SSE41_ISA(template struct jit_uni_dw_convolution_bwd_weights_t<sse41, f32>); |
968 | } // namespace x64 |
969 | } // namespace cpu |
970 | } // namespace impl |
971 | } // namespace dnnl |
972 | |