1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/jit_avx512_common_convolution.hpp"
18#include "common/c_types_map.hpp"
19#include "common/dnnl_thread.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace cpu {
27namespace x64 {
28
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::memory_tracking::names;
31using namespace dnnl::impl::utils;
32
33using namespace nstl;
34
35using jit_conv_ker_t = void (*)(jit_conv_call_s *);
36
37inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p,
38 const void *src, const void *dst, const void *filt, const void *bias,
39 int channel, int kh_padding, int reduce_work, int load_work) {
40 p.src = src;
41 p.dst = dst;
42 p.filt = filt;
43 p.bias = bias;
44 p.channel = channel;
45 // non-positive value of kh_padding is allowed, in this case kernel must
46 // skip computation part and initialize output by zeroes
47 p.kh_padding = kh_padding;
48 p.reduce_work = reduce_work;
49 p.load_work = load_work;
50
51 ker(&p);
52}
53// The special case for the driver with iw-parallelization (BWD)
54inline void jit_conv_ker_pipeline_iw_thr(const jit_conv_ker_t ker,
55 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
56 const void *bias, int channel, int kh_padding, int iwb, int reduce_work,
57 int load_work) {
58 p.iwb = iwb;
59
60 jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding,
61 reduce_work, load_work);
62}
63
64inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker,
65 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
66 const void *bias, int channel, int kh_padding, int kd_padding,
67 int reduce_work, int load_work) {
68 p.src = src;
69 p.dst = dst;
70 p.filt = filt;
71 p.bias = bias;
72 p.channel = channel;
73 // non-positive value of both kd_padding and kh_padding is allowed, in this
74 // case kernel must skip computation part and initialize output by zeroes
75 p.kh_padding = kh_padding;
76 p.kd_padding = kd_padding;
77 p.reduce_work = reduce_work;
78 p.load_work = load_work;
79
80 ker(&p);
81}
82// The special case for the driver with ow-parallelization (FWD)
83inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
84 const void *src, const void *dst, const void *filt, const void *bias,
85 int channel, int kh_padding, int owb, int reduce_work, int load_work,
86 const void *post_ops_binary_rhs_arg_vec, int oc_l_off,
87 const void *dst_orig, int flags) {
88 p.owb = owb;
89 p.flags = flags;
90
91 p.oc_l_off = oc_l_off;
92 p.dst_orig = dst_orig;
93 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
94
95 jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding,
96 reduce_work, load_work);
97}
98
99// The special case for the driver with ow-parallelization (FWD)
100// TODO: implement it for BWD_D and BWD_W too
101inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker,
102 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
103 const void *bias, int channel, int kh_padding, int kd_padding, int owb,
104 int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec,
105 int oc_l_off, const void *dst_orig, int flags) {
106
107 p.oc_l_off = oc_l_off;
108 p.dst_orig = dst_orig;
109 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
110
111 p.owb = owb;
112 p.flags = flags;
113
114 jit_conv_3d_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding,
115 kd_padding, reduce_work, load_work);
116}
117
118inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker,
119 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
120 const void *bias, int channel, int kh_padding, size_t reduce_work,
121 size_t load_work) {
122 jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding,
123 reduce_work, load_work);
124}
125
126void jit_conv_2d_ker_bwd_w_pipeline(const jit_conv_ker_t ker,
127 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
128 const void *bias, int channel, int os_index_begin, int os_index_end,
129 int kh_padding /* kh_work_size */, size_t kh_offset, size_t reduce_work,
130 size_t load_work) {
131 p.src = src;
132 p.dst = dst;
133 p.filt = filt;
134 p.bias = bias;
135 p.channel = channel;
136 p.os_index_begin = os_index_begin;
137 p.os_index_end = os_index_end;
138 // non-positive value of kh_padding is allowed, in this case kernel must
139 // skip kw loop computation and initialize output by zeroes
140 p.kh_padding = kh_padding;
141 p.kh_offset = kh_offset;
142 p.reduce_work = reduce_work;
143 p.load_work = load_work;
144
145 ker(&p);
146}
147
148void jit_conv_3d_ker_bwd_w_pipeline(const jit_conv_ker_t ker,
149 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
150 const void *bias, int channel, int os_index_begin, int os_index_end,
151 int kd_padding /* kd_work_size */, size_t kd_offset, size_t reduce_work,
152 size_t load_work) {
153 p.src = src;
154 p.dst = dst;
155 p.filt = filt;
156 p.bias = bias;
157 p.channel = channel;
158 p.os_index_begin = os_index_begin;
159 p.os_index_end = os_index_end;
160 // non-positive value of kd_padding is allowed, in this case kernel must
161 // skip kh loop computation and initialize output by zeroes
162 p.kd_padding = kd_padding;
163 p.kd_offset = kd_offset;
164 p.reduce_work = reduce_work;
165 p.load_work = load_work;
166
167 ker(&p);
168}
169#define wht_blk_off(d, g, ...) \
170 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
171 : (d).blk_off(__VA_ARGS__))
172
173template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
174void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
175 dst_type>::prepare_padded_bias(const dst_data_t *&bias,
176 const memory_tracking::grantor_t &scratchpad) const {
177 if (!pd()->wants_padded_bias()) return;
178
179 auto padded_bias
180 = scratchpad.template get<dst_data_t>(key_conv_padded_bias);
181 utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
182 utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, (dst_data_t)0,
183 pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
184 bias = padded_bias;
185}
186
187template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
188void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
189 dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
190 const auto &jcp = pd()->jcp_;
191
192 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
193 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
194 auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
195 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
196 const auto post_ops_binary_rhs_arg_vec
197 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
198
199 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
200
201 const memory_desc_wrapper src_d(pd()->src_md());
202 const memory_desc_wrapper dst_d(pd()->dst_md());
203 const memory_desc_wrapper weights_d(pd()->weights_md(0));
204
205 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
206 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
207
208 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
209 int g_blocking = 1;
210 int nb_groups = jcp.ngroups / g_blocking;
211 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
212 int nthr = jcp.aligned_threads;
213
214 parallel(nthr, [&](const int ithr, const int nthr) {
215 int start {0}, end {0}, start_copy;
216 balance211(work_amount, nthr, ithr, start, end);
217 start_copy = start;
218
219 auto par_conv = jit_conv_call_s();
220 size_t src_c_stride = src_d.blk_off(0, 1);
221 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
222
223 for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
224 start = start_copy;
225 int n {0}, gg {0}, occ {0}, owb {0};
226
227 if (jcp.loop_order == loop_cwgn) {
228 int dummy {0};
229 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
230 nb_groups, n, jcp.mb, dummy, 1);
231 } else if (jcp.loop_order == loop_gncw) {
232 int dummy {0};
233 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
234 oc_chunks, owb, jcp.nb_ow, dummy, 1);
235 } else if (jcp.loop_order == loop_nhwcg) {
236 nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
237 oc_chunks, gg, nb_groups);
238 } else {
239 assert(!"unsupported loop order");
240 }
241
242 while (start < end) {
243 int ocb = occ * jcp.nb_oc_blocking;
244 int g = gg * g_blocking;
245 int g_ocb = g * jcp.nb_oc + ocb;
246 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
247
248 int ow_s = owb * jcp.ow_block;
249 int iw_s = ow_s * jcp.stride_w;
250 const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nwc;
251 const int oc_off_idx = is_dst_layout_nxc
252 ? g * jcp.oc + ocb * jcp.oc_block
253 : g_ocb;
254 auto dst_w = dst + dst_d.blk_off(n, oc_off_idx, ow_s);
255 const bool is_src_layout_nxc = jcp.src_tag == format_tag::nwc;
256 const int ic_off_idx = is_src_layout_nxc
257 ? g * jcp.ic + icb_l2 * jcp.ic_block
258 : g_icb + icb_l2;
259 auto src_w = src + src_d.blk_off(n, ic_off_idx, iw_s);
260 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
261 auto bias_w = bias ? bias
262 + oc_off_idx
263 * (is_dst_layout_nxc ? 1 : jcp.oc_block)
264 : nullptr;
265
266 int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
267 int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
268 const int oc_work = utils::this_block_size(ocb * jcp.oc_block,
269 jcp.oc_without_padding,
270 jcp.nb_oc_blocking * jcp.oc_block);
271
272 int ic_work = icb_step * jcp.ic_block;
273 const int oc_l_off
274 = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
275
276 for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
277 int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
278 int flags = 0;
279 if (icb == 0) flags |= FLAG_IC_FIRST;
280 if (icb + curr_nb_ic >= jcp.nb_ic) {
281 flags |= FLAG_IC_LAST;
282 ic_work = utils::this_block_size(icb * jcp.ic_block,
283 jcp.ic, icb_step * jcp.ic_block);
284 }
285 jit_conv_ker_pipeline_ow_thr(jit_ker, par_conv, src_w,
286 dst_w, wht_w, bias_w, icb, 1, owb, ic_work, oc_work,
287 post_ops_binary_rhs_arg_vec.data(), oc_l_off, dst,
288 flags);
289
290 src_w += src_c_stride;
291 wht_w += wht_ic_stride;
292 }
293 if (jcp.loop_order == loop_cwgn) {
294 int dummy {0};
295 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
296 gg, nb_groups, n, jcp.mb, dummy, 1);
297 } else if (jcp.loop_order == loop_gncw) {
298 int dummy {0};
299 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
300 oc_chunks, owb, jcp.nb_ow, dummy, 1);
301 } else if (jcp.loop_order == loop_nhwcg) {
302 ++start;
303 nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
304 gg, nb_groups);
305 } else {
306 assert(!"unsupported loop order");
307 }
308 }
309 }
310 });
311}
312
313template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
314void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
315 dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {
316
317 const auto &jcp = pd()->jcp_;
318 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
319 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
320 auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
321 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
322 const auto post_ops_binary_rhs_arg_vec
323 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
324
325 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
326
327 const memory_desc_wrapper src_d(pd()->src_md());
328 const memory_desc_wrapper dst_d(pd()->dst_md());
329 const memory_desc_wrapper weights_d(pd()->weights_md(0));
330
331 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
332 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
333
334 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
335 int g_blocking = 1;
336 int nb_groups = jcp.ngroups / g_blocking;
337 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
338 int nthr = jcp.aligned_threads;
339
340 parallel(nthr, [&](const int ithr, const int nthr) {
341 int start {0}, end {0}, start_copy;
342 balance211(work_amount, nthr, ithr, start, end);
343 start_copy = start;
344
345 auto par_conv = jit_conv_call_s();
346 size_t src_h_stride = src_d.blk_off(0, 0, 1);
347 size_t src_c_stride = src_d.blk_off(0, 1);
348 size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
349 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
350 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
351
352 for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
353 start = start_copy;
354 int n {0}, gg {0}, occ {0}, oh_s {0}, owb {0};
355
356 if (jcp.loop_order == loop_cwgn)
357 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
358 nb_groups, n, jcp.mb, oh_s, jcp.oh);
359 else if (jcp.loop_order == loop_gncw)
360 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
361 oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
362 else if (jcp.loop_order == loop_nhwcg)
363 nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
364 occ, oc_chunks, gg, nb_groups);
365 else
366 assert(!"unsupported loop order");
367
368 while (start < end) {
369 int ocb = occ * jcp.nb_oc_blocking;
370 int g = gg * g_blocking;
371 int g_ocb = g * jcp.nb_oc + ocb;
372 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
373
374 int work_rem = end - start;
375
376 int ow_s = owb * jcp.ow_block;
377 int iw_s = ow_s * jcp.stride_w;
378 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
379 if (jcp.loop_order == loop_nhwcg)
380 oh_e = oh_s + 1; //step instead
381
382 for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
383 int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;
384 const bool is_dst_layout_nxc
385 = jcp.dst_tag == format_tag::nhwc;
386 const int oc_off_idx = is_dst_layout_nxc
387 ? g * jcp.oc + ocb * jcp.oc_block
388 : g_ocb;
389 auto dst_w = dst + dst_d.blk_off(n, oc_off_idx, oh_b, ow_s);
390 const bool is_src_layout_nxc
391 = jcp.src_tag == format_tag::nhwc;
392 const int ic_off_idx = is_src_layout_nxc
393 ? g * jcp.ic + icb_l2 * jcp.ic_block
394 : g_icb + icb_l2;
395 auto src_w = src + src_d.blk_off(n, ic_off_idx, ih_b, iw_s);
396 auto wht_w
397 = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
398
399 int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
400 int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
401 auto bias_w = bias ? bias
402 + oc_off_idx
403 * (is_dst_layout_nxc ? 1
404 : jcp.oc_block)
405 : nullptr;
406 const int oc_work = utils::this_block_size(
407 ocb * jcp.oc_block, jcp.oc_without_padding,
408 jcp.nb_oc_blocking * jcp.oc_block);
409 const int oc_l_off = oc_off_idx
410 * (is_dst_layout_nxc ? 1 : jcp.oc_block);
411 int ic_work = icb_step * jcp.ic_block;
412 for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
413 int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
414 int flags = 0;
415 if (icb == 0) flags |= FLAG_IC_FIRST;
416 if (icb + curr_nb_ic >= jcp.nb_ic) {
417 flags |= FLAG_IC_LAST;
418 ic_work = utils::this_block_size(icb * jcp.ic_block,
419 jcp.ic, icb_step * jcp.ic_block);
420 }
421 auto src_c = src_w;
422 auto dst_c = dst_w;
423 for (int oj = oh_b, ij = ih_b;
424 oj < min(oh_e, oh_b + jcp.h_blocking);
425 ++oj, ij += jcp.stride_h) {
426 int dilate_h = jcp.dilate_h + 1;
427 int i_t_overflow = div_up(max(0, -ij), dilate_h);
428 int i_b_overflow = div_up(
429 max(0,
430 ij - jcp.ih
431 + (jcp.kh - 1) * dilate_h
432 + 1),
433 dilate_h);
434 int kh_padding = nstl::max(
435 0, jcp.kh - i_t_overflow - i_b_overflow);
436
437 auto aux_src = src_c
438 + i_t_overflow * dilate_h * src_h_stride;
439 auto aux_wht = wht_w + i_t_overflow * wht_h_stride;
440
441 jit_conv_ker_pipeline_ow_thr(jit_ker, par_conv,
442 aux_src, dst_c, aux_wht, bias_w, icb,
443 kh_padding, owb, ic_work, oc_work,
444 post_ops_binary_rhs_arg_vec.data(),
445 oc_l_off, dst, flags);
446
447 src_c += src_h_stride * jcp.stride_h;
448 dst_c += dst_h_stride;
449 }
450 src_w += src_c_stride;
451 wht_w += wht_ic_stride;
452 }
453 }
454
455 if (jcp.loop_order == loop_cwgn)
456 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
457 gg, nb_groups, n, jcp.mb, oh_s, jcp.oh);
458 else if (jcp.loop_order == loop_gncw)
459 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
460 oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
461 else if (jcp.loop_order == loop_nhwcg) {
462 ++start;
463 nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
464 occ, oc_chunks, gg, nb_groups);
465 } else
466 assert(!"unsupported loop order");
467 }
468 }
469 });
470}
471
472template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
473void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
474 dst_type>::execute_forward_3d(const exec_ctx_t &ctx) const {
475 const auto &jcp = pd()->jcp_;
476 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
477 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
478 auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
479 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
480 const auto post_ops_binary_rhs_arg_vec
481 = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
482
483 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
484
485 const memory_desc_wrapper src_d(pd()->src_md());
486 const memory_desc_wrapper dst_d(pd()->dst_md());
487 const memory_desc_wrapper weights_d(pd()->weights_md(0));
488
489 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
490 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
491
492 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
493 int g_blocking = 1;
494 int nb_groups = jcp.ngroups / g_blocking;
495 int work_amount
496 = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
497 int nthr = jcp.nthr;
498
499 parallel(nthr, [&](const int ithr, const int nthr) {
500 int start {0}, end {0}, start_copy;
501 balance211(work_amount, nthr, ithr, start, end);
502 start_copy = start;
503
504 auto par_conv = jit_conv_call_s();
505 size_t src_d_stride = src_d.blk_off(0, 0, 1);
506 size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
507 size_t src_c_stride = src_d.blk_off(0, 1);
508 size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
509 size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
510 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
511 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
512
513 for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
514 start = start_copy;
515 int n {0}, gg {0}, occ {0}, oh_s {0}, od_s {0}, owb {0};
516
517 if (jcp.loop_order == loop_cwgn)
518 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
519 nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
520 else if (jcp.loop_order == loop_gncw)
521 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
522 oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
523 else if (jcp.loop_order == loop_nhwcg)
524 nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
525 owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
526 else
527 assert(!"unsupported loop order");
528
529 while (start < end) {
530 int ocb = occ * jcp.nb_oc_blocking;
531 int g = gg * g_blocking;
532 int g_ocb = g * jcp.nb_oc + ocb;
533 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
534
535 int work_rem = end - start;
536 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
537 int ow_s = owb * jcp.ow_block;
538 int iw_s = ow_s * jcp.stride_w;
539 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
540 if (jcp.loop_order == loop_nhwcg)
541 oh_e = oh_s + 1; //step instead
542
543 int id_s = -jcp.f_pad + od_s * jcp.stride_d;
544
545 int dilate_d = jcp.dilate_d + 1;
546 int d_t_overflow = div_up(max(0, -id_s), dilate_d);
547 int d_b_overflow = div_up(
548 max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
549 dilate_d);
550 int kd_padding
551 = nstl::max(0, jcp.kd - d_t_overflow - d_b_overflow);
552 const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
553 const int oc_off_idx = is_dst_layout_nxc
554 ? g * jcp.oc + ocb * jcp.oc_block
555 : g_ocb;
556 auto dst_w
557 = dst + dst_d.blk_off(n, oc_off_idx, od_s, oh_s, ow_s);
558 const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
559 const int ic_off_idx = is_src_layout_nxc
560 ? g * jcp.ic + icb_l2 * jcp.ic_block
561 : g_icb + icb_l2;
562 auto src_w = src
563 + src_d.blk_off(n, ic_off_idx, id_s, ih_s, iw_s)
564 + d_t_overflow * dilate_d * src_d_stride;
565 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
566 + d_t_overflow * wht_d_stride;
567 auto bias_w = bias ? bias
568 + oc_off_idx
569 * (is_dst_layout_nxc ? 1 : jcp.oc_block)
570 : nullptr;
571
572 const int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
573 int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
574 const int oc_work = utils::this_block_size(ocb * jcp.oc_block,
575 jcp.oc_without_padding,
576 jcp.nb_oc_blocking * jcp.oc_block);
577
578 const int oc_l_off
579 = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
580 int ic_work = icb_step * jcp.ic_block;
581 for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
582 int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
583 int flags = 0;
584 if (icb == 0) flags |= FLAG_IC_FIRST;
585 if (icb + curr_nb_ic >= jcp.nb_ic) {
586 flags |= FLAG_IC_LAST;
587 ic_work = utils::this_block_size(icb * jcp.ic_block,
588 jcp.ic, icb_step * jcp.ic_block);
589 }
590 auto src_c = src_w;
591 auto dst_c = dst_w;
592 for (int oj = oh_s, ij = ih_s; oj < oh_e;
593 ++oj, ij += jcp.stride_h) {
594 int dilate_h = jcp.dilate_h + 1;
595 int i_t_overflow = div_up(max(0, -ij), dilate_h);
596 int i_b_overflow = div_up(
597 max(0,
598 ij - jcp.ih + (jcp.kh - 1) * dilate_h
599 + 1),
600 dilate_h);
601 int kh_padding = nstl::max(
602 0, jcp.kh - i_t_overflow - i_b_overflow);
603 jit_conv_3d_ker_pipeline_ow_thr(jit_ker, par_conv,
604 src_c + i_t_overflow * dilate_h * src_h_stride,
605 dst_c, wht_w + i_t_overflow * wht_h_stride,
606 bias_w, icb, kh_padding, kd_padding, owb,
607 ic_work, oc_work,
608 post_ops_binary_rhs_arg_vec.data(), oc_l_off,
609 dst, flags);
610
611 src_c += src_h_stride * jcp.stride_h;
612 dst_c += dst_h_stride;
613 }
614 src_w += src_c_stride;
615 wht_w += wht_ic_stride;
616 }
617
618 if (jcp.loop_order == loop_cwgn)
619 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
620 gg, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
621 jcp.oh);
622 else if (jcp.loop_order == loop_gncw)
623 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
624 oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
625 jcp.oh);
626 else if (jcp.loop_order == loop_nhwcg) {
627 ++start;
628 nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
629 jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
630 } else
631 assert(!"unsupported loop order");
632 }
633 }
634 });
635}
636
637template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
638
639template <data_type_t diff_dst_type, data_type_t wei_type,
640 data_type_t diff_src_type>
641void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
642 diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const {
643 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
644 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
645 auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
646
647 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
648 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
649 const memory_desc_wrapper weights_d(pd()->weights_md(0));
650
651 const auto &jcp = pd()->jcp_;
652 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
653
654 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
655 int g_blocking = 1;
656 int nb_groups = jcp.ngroups / g_blocking;
657 int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.nb_iw;
658 int nthr = jcp.nthr;
659
660 parallel(nthr, [&](const int ithr, const int nthr) {
661 int start {0}, end {0}, start_copy;
662 balance211(work_amount, nthr, ithr, start, end);
663 start_copy = start;
664
665 auto par_conv = jit_conv_call_s();
666 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
667 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
668
669 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
670 start = start_copy;
671 int n {0}, gg {0}, icc {0}, iwb {0};
672 if (jcp.loop_order == loop_cwgn) {
673 int dummy {0};
674 nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
675 nb_groups, n, jcp.mb, dummy, 1);
676 } else if (jcp.loop_order == loop_gncw) {
677 int dummy {0};
678 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
679 ic_chunks, iwb, jcp.nb_iw, dummy, 1);
680 } else if (jcp.loop_order == loop_nhwcg) {
681 nd_iterator_init(start, n, jcp.mb, iwb, jcp.nb_iw, icc,
682 ic_chunks, gg, nb_groups);
683 } else {
684 assert(!"unsupported loop order");
685 }
686
687 while (start < end) {
688 int icb = icc * jcp.nb_ic_blocking;
689 int g = gg * g_blocking;
690 int g_icb = g * jcp.nb_ic + icb;
691 int g_ocb = g * jcp.nb_oc;
692 int iw_s = iwb * jcp.iw_block;
693 int ow_s = iw_s / jcp.stride_w;
694
695 const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::nwc;
696 const int ic_off_idx = is_dsrc_layout_nxc
697 ? g * jcp.ic + icb * jcp.ic_block
698 : g_icb;
699 auto diff_src_w
700 = diff_src + diff_src_d.blk_off(n, ic_off_idx, iw_s);
701 const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nwc;
702 const int oc_off_idx = is_ddst_layout_nxc
703 ? g * jcp.oc + ocb_l2 * jcp.oc_block
704 : g_ocb + ocb_l2;
705 auto diff_dst_w
706 = diff_dst + diff_dst_d.blk_off(n, oc_off_idx, ow_s);
707 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
708
709 int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
710 int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
711 const int load_work = utils::this_block_size(icb * jcp.ic_block,
712 jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
713 int reduce_work = ocb_step * jcp.oc_block;
714 for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
715 int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
716 if (ocb + curr_nb_oc >= jcp.nb_oc) {
717 reduce_work = utils::this_block_size(ocb * jcp.oc_block,
718 jcp.oc, ocb_step * jcp.oc_block);
719 }
720
721 jit_conv_ker_pipeline_iw_thr(jit_ker, par_conv, diff_src_w,
722 diff_dst_w, wht_w, nullptr, ocb, 1, iwb,
723 reduce_work, load_work);
724 diff_dst_w += diff_dst_c_stride;
725 wht_w += wht_oc_stride;
726 }
727
728 if (jcp.loop_order == loop_cwgn) {
729 int dummy {0};
730 nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw,
731 gg, nb_groups, n, jcp.mb, dummy, 1);
732 } else if (jcp.loop_order == loop_gncw) {
733 int dummy {0};
734 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
735 ic_chunks, iwb, jcp.nb_iw, dummy, 1);
736 } else if (jcp.loop_order == loop_nhwcg) {
737 ++start;
738 nd_iterator_step(n, jcp.mb, iwb, jcp.nb_iw, icc, ic_chunks,
739 gg, nb_groups);
740 } else {
741 assert(!"unsupported loop order");
742 }
743 }
744 }
745 });
746}
747
748template <data_type_t diff_dst_type, data_type_t wei_type,
749 data_type_t diff_src_type>
750void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
751 diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const {
752 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
753 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
754 auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
755
756 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
757 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
758 const memory_desc_wrapper weights_d(pd()->weights_md(0));
759
760 const auto &jcp = pd()->jcp_;
761 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
762
763 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
764 int g_blocking = 1;
765 int nb_groups = jcp.ngroups / g_blocking;
766 int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw;
767 int nthr = jcp.nthr;
768
769 parallel(nthr, [&](const int ithr, const int nthr) {
770 int start {0}, end {0}, start_copy;
771 balance211(work_amount, nthr, ithr, start, end);
772 start_copy = start;
773
774 auto par_conv = jit_conv_call_s();
775 size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
776 size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
777 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
778 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
779 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
780
781 bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;
782
783 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
784 start = start_copy;
785 int n {0}, gg {0}, icc {0}, ih_s {0}, iwb {0};
786
787 if (jcp.loop_order == loop_cwgn) {
788 nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
789 nb_groups, n, jcp.mb, ih_s, jcp.ih);
790 } else if (jcp.loop_order == loop_gncw) {
791 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
792 ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
793 } else if (jcp.loop_order == loop_nhwcg) {
794 nd_iterator_init(start, n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
795 icc, ic_chunks, gg, nb_groups);
796 } else
797 assert(!"unsupported loop order");
798
799 while (start < end) {
800 int icb = icc * jcp.nb_ic_blocking;
801 int g = gg * g_blocking;
802 int g_icb = g * jcp.nb_ic + icb;
803 int g_ocb = g * jcp.nb_oc;
804
805 int work_rem = end - start;
806 int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
807 if (jcp.loop_order == loop_nhwcg)
808 ih_e = ih_s + 1; //step instead
809 int iw_s = iwb * jcp.iw_block;
810 int ow_s = iw_s / jcp.stride_w;
811 const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::nhwc;
812 const int ic_off_idx = is_dsrc_layout_nxc
813 ? g * jcp.ic + icb * jcp.ic_block
814 : g_icb;
815 auto diff_src_w
816 = diff_src + diff_src_d.blk_off(n, ic_off_idx, 0, iw_s);
817 const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
818 const int oc_off_idx = is_ddst_layout_nxc
819 ? g * jcp.oc + ocb_l2 * jcp.oc_block
820 : g_ocb + ocb_l2;
821 auto diff_dst_w
822 = diff_dst + diff_dst_d.blk_off(n, oc_off_idx, 0, ow_s);
823 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
824
825 int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
826 int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
827 const int load_work = utils::this_block_size(icb * jcp.ic_block,
828 jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
829 int reduce_work = ocb_step * jcp.oc_block;
830 for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
831 int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
832 if (ocb + curr_nb_oc >= jcp.nb_oc) {
833 reduce_work = utils::this_block_size(ocb * jcp.oc_block,
834 jcp.oc, ocb_step * jcp.oc_block);
835 }
836 for (int ij = ih_s; ij < ih_e; ++ij) {
837 int oj, k_len, k_lo;
838 if (is_fast_path) { // dilate == 0 && stride == 1
839 int i_t_overflow
840 = max(0, jcp.kh - 1 - ij - jcp.t_pad);
841 int i_b_overflow
842 = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
843 k_len = jcp.kh - i_t_overflow - i_b_overflow;
844 k_lo = i_b_overflow;
845 oj = ij + jcp.t_pad - i_b_overflow;
846 } else if (jcp.dilate_h != 0) { // stride == 1
847 int dilate_h = jcp.dilate_h + 1;
848 // Note: use div_up to account for "holes" in filter
849 int i_t_overflow
850 = div_up(max(0,
851 (jcp.kh - 1) * dilate_h
852 - ij - jcp.t_pad),
853 dilate_h);
854 int i_b_overflow = div_up(
855 max(0,
856 (jcp.kh - 1) * dilate_h + 1 - jcp.ih
857 + ij - jcp.b_pad),
858 dilate_h);
859 k_len = jcp.kh - i_t_overflow - i_b_overflow;
860 k_lo = i_b_overflow;
861 oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
862 } else { // dilate == 0
863 int i_t_overflow = max(0,
864 (jcp.kh - 1 - ij - jcp.t_pad)
865 / jcp.stride_h);
866 int i_b_overflow = max(0,
867 (jcp.kh - jcp.ih + ij - jcp.b_pad)
868 / jcp.stride_h);
869 int overflow_kh_hi = jcp.kh - 1
870 - modulo(jcp.ih - 1 + jcp.b_pad - ij,
871 jcp.stride_h);
872 int overflow_kh_lo
873 = (ij + jcp.t_pad) % jcp.stride_h;
874
875 k_len = (overflow_kh_hi - overflow_kh_lo)
876 / jcp.stride_h
877 + 1 - i_t_overflow - i_b_overflow;
878 k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
879 oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
880 }
881
882 jit_conv_ker_pipeline_iw_thr(jit_ker, par_conv,
883 diff_src_w + ij * diff_src_h_stride,
884 diff_dst_w + oj * diff_dst_h_stride,
885 wht_w + k_lo * wht_h_stride, nullptr, ocb,
886 k_len, iwb, reduce_work, load_work);
887 }
888 diff_dst_w += diff_dst_c_stride;
889 wht_w += wht_oc_stride;
890 }
891
892 if (jcp.loop_order == loop_cwgn) {
893 nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw,
894 gg, nb_groups, n, jcp.mb, ih_s, jcp.ih);
895 } else if (jcp.loop_order == loop_gncw) {
896 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
897 ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
898 } else if (jcp.loop_order == loop_nhwcg) {
899 ++start;
900 nd_iterator_step(n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
901 icc, ic_chunks, gg, nb_groups);
902 } else
903 assert(!"unsupported loop order");
904 }
905 }
906 });
907}
908
909template <data_type_t diff_dst_type, data_type_t wei_type,
910 data_type_t diff_src_type>
911void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
912 diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const {
913 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
914 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
915 auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
916
917 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
918 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
919 const memory_desc_wrapper weights_d(pd()->weights_md(0));
920
921 const auto &jcp = pd()->jcp_;
922 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
923
924 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
925 int g_blocking = 1;
926 int nb_groups = jcp.ngroups / g_blocking;
927 int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
928 int nthr = jcp.nthr;
929
930 parallel(nthr, [&](const int ithr, const int nthr) {
931 int start {0}, end {0}, start_copy;
932 balance211(work_amount, nthr, ithr, start, end);
933 start_copy = start;
934
935 auto par_conv = jit_conv_call_s();
936 size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
937 size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
938 size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
939 size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
940 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
941 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
942 size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
943 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
944
945 bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
946 bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;
947
948 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
949 start = start_copy;
950 int n {0}, gg {0}, icc {0}, ih_s {0}, id_s {0};
951 // Input width threading is not currently implemented for 3d, so it
952 // is not included in the iterator.
953 if (jcp.loop_order == loop_cwgn)
954 nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n,
955 jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
956 else if (jcp.loop_order == loop_gncw)
957 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
958 ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
959 else if (jcp.loop_order == loop_nhwcg)
960 nd_iterator_init(start, n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih,
961 icc, ic_chunks, gg, nb_groups);
962 else
963 assert(!"unsupported loop order");
964
965 while (start < end) {
966 int icb = icc * jcp.nb_ic_blocking;
967 int g = gg * g_blocking;
968 int g_icb = g * jcp.nb_ic + icb;
969 int g_ocb = g * jcp.nb_oc;
970
971 int work_rem = end - start;
972 int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
973 if (jcp.loop_order == loop_nhwcg)
974 ih_e = ih_s + 1; //step instead
975 int d_len = 0, d_lo = 0, d_oj = 0;
976 if (is_fast_path_d) { // dilate == 0 && stride == 1
977 int d_t_overflow = max(0, jcp.kd - 1 - id_s - jcp.f_pad);
978 int d_b_overflow
979 = max(0, jcp.kd - jcp.id + id_s - jcp.back_pad);
980 d_len = jcp.kd - d_t_overflow - d_b_overflow;
981 d_lo = d_b_overflow;
982 d_oj = id_s + jcp.f_pad - d_b_overflow;
983 } else if (jcp.dilate_d != 0) { // stride == 1
984 int dilate_d = jcp.dilate_d + 1;
985 // Note: use div_up to account for "holes" in filter
986 int d_t_overflow = div_up(
987 max(0, (jcp.kd - 1) * dilate_d - id_s - jcp.f_pad),
988 dilate_d);
989 int d_b_overflow = div_up(
990 max(0,
991 (jcp.kd - 1) * dilate_d + 1 - jcp.id + id_s
992 - jcp.back_pad),
993 dilate_d);
994 d_len = jcp.kd - d_t_overflow - d_b_overflow;
995 d_lo = d_b_overflow;
996 d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
997 } else { // dilate == 0
998 int d_t_overflow = max(
999 0, (jcp.kd - 1 - id_s - jcp.f_pad) / jcp.stride_d);
1000 int d_b_overflow = max(0,
1001 (jcp.kd - jcp.id + id_s - jcp.back_pad)
1002 / jcp.stride_d);
1003 int overflow_kd_hi = jcp.kd - 1
1004 - modulo(jcp.id - 1 + jcp.back_pad - id_s,
1005 jcp.stride_d);
1006 int overflow_kd_lo = (id_s + jcp.f_pad) % jcp.stride_d;
1007
1008 d_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
1009 - d_t_overflow - d_b_overflow;
1010 d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
1011 d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
1012 }
1013
1014 const bool is_dsrc_layout_nxc
1015 = jcp.src_tag == format_tag::ndhwc;
1016 const int ic_off_idx = is_dsrc_layout_nxc
1017 ? g * jcp.ic + icb * jcp.ic_block
1018 : g_icb;
1019 auto diff_src_w = diff_src + diff_src_d.blk_off(n, ic_off_idx)
1020 + id_s * diff_src_d_stride;
1021 const bool is_ddst_layout_nxc
1022 = jcp.dst_tag == format_tag::ndhwc;
1023 const int oc_off_idx = is_ddst_layout_nxc
1024 ? g * jcp.oc + ocb_l2 * jcp.oc_block
1025 : g_ocb + ocb_l2;
1026 auto diff_dst_w = diff_dst + diff_dst_d.blk_off(n, oc_off_idx)
1027 + d_oj * diff_dst_d_stride;
1028 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
1029 + d_lo * wht_d_stride;
1030
1031 int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
1032 int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
1033 const int load_work = utils::this_block_size(icb * jcp.ic_block,
1034 jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
1035 int reduce_work = ocb_step * jcp.oc_block;
1036 for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
1037 int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
1038 if (ocb + curr_nb_oc >= jcp.nb_oc) {
1039 reduce_work = utils::this_block_size(ocb * jcp.oc_block,
1040 jcp.oc, ocb_step * jcp.oc_block);
1041 }
1042 for (int ij = ih_s; ij < ih_e; ++ij) {
1043 int oj, k_len, k_lo;
1044 if (is_fast_path_h) { // dilate == 0 && stride == 1
1045 int i_t_overflow
1046 = max(0, jcp.kh - 1 - ij - jcp.t_pad);
1047 int i_b_overflow
1048 = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
1049 k_len = jcp.kh - i_t_overflow - i_b_overflow;
1050 k_lo = i_b_overflow;
1051 oj = ij + jcp.t_pad - i_b_overflow;
1052 } else if (jcp.dilate_h != 0) { // stride == 1
1053 int dilate_h = jcp.dilate_h + 1;
1054 // Note: use div_up to account for "holes" in filter
1055 int i_t_overflow
1056 = div_up(max(0,
1057 (jcp.kh - 1) * dilate_h
1058 - ij - jcp.t_pad),
1059 dilate_h);
1060 int i_b_overflow = div_up(
1061 max(0,
1062 (jcp.kh - 1) * dilate_h + 1 - jcp.ih
1063 + ij - jcp.b_pad),
1064 dilate_h);
1065 k_len = jcp.kh - i_t_overflow - i_b_overflow;
1066 k_lo = i_b_overflow;
1067 oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
1068 } else { // dilate == 0
1069 int i_t_overflow = max(0,
1070 (jcp.kh - 1 - ij - jcp.t_pad)
1071 / jcp.stride_h);
1072 int i_b_overflow = max(0,
1073 (jcp.kh - jcp.ih + ij - jcp.b_pad)
1074 / jcp.stride_h);
1075 int overflow_kh_hi = jcp.kh - 1
1076 - modulo(jcp.ih - 1 + jcp.b_pad - ij,
1077 jcp.stride_h);
1078 int overflow_kh_lo
1079 = (ij + jcp.t_pad) % jcp.stride_h;
1080
1081 k_len = (overflow_kh_hi - overflow_kh_lo)
1082 / jcp.stride_h
1083 + 1 - i_t_overflow - i_b_overflow;
1084 k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
1085 oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
1086 }
1087 assert(k_len >= 0);
1088
1089 jit_conv_3d_ker_pipeline(jit_ker, par_conv,
1090 diff_src_w + ij * diff_src_h_stride,
1091 diff_dst_w + oj * diff_dst_h_stride,
1092 wht_w + k_lo * wht_h_stride, nullptr, ocb,
1093 k_len, d_len, reduce_work, load_work);
1094 }
1095 diff_dst_w += diff_dst_c_stride;
1096 wht_w += wht_oc_stride;
1097 }
1098
1099 if (jcp.loop_order == loop_cwgn)
1100 nd_iterator_jump(start, end, icc, ic_chunks, gg, nb_groups,
1101 n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
1102 else if (jcp.loop_order == loop_gncw)
1103 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
1104 ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
1105 else if (jcp.loop_order == loop_nhwcg) {
1106 ++start;
1107 nd_iterator_step(n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
1108 ic_chunks, gg, nb_groups);
1109 } else
1110 assert(!"unsupported loop order");
1111 }
1112 }
1113 });
1114}
1115
1116template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
1117
1118template <data_type_t src_type, data_type_t diff_dst_type,
1119 data_type_t diff_weights_type>
1120status_t jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1121 diff_weights_type>::init(engine_t *engine) {
1122 const auto &j = pd()->jcp_;
1123
1124 nthr_ = j.nthr;
1125 nthr_mb_ = j.nthr_mb;
1126 nthr_g_ = j.nthr_g;
1127 nthr_oc_b_ = j.nthr_oc_b;
1128 nthr_ic_b_ = j.nthr_ic_b;
1129
1130 CHECK(safe_ptr_assign(
1131 kernel_, new jit_avx512_common_conv_bwd_weights_kernel_f32(j)));
1132 CHECK(kernel_->create_kernel());
1133
1134 if (nthr_mb_ > 1) {
1135 CHECK(safe_ptr_assign(
1136 acc_ker_, new cpu_accumulator_1d_t<diff_weights_type>()));
1137 CHECK(acc_ker_->create_kernel());
1138 }
1139
1140 CHECK(safe_ptr_assign(reducer_bias_,
1141 new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_)));
1142 CHECK(reducer_bias_->create_kernel());
1143 return status::success;
1144}
1145
1146template <data_type_t src_type, data_type_t diff_dst_type,
1147 data_type_t diff_weights_type>
1148struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1149 diff_weights_type>::thread_info_t {
1150 const src_data_t *src;
1151 const diff_dst_data_t *diff_dst;
1152 const diff_weights_data_t *diff_weights;
1153 diff_weights_data_t *diff_bias;
1154
1155 const memory_tracking::grantor_t scratchpad;
1156
1157 src_data_t *tr_src;
1158 simple_barrier::ctx_t *tr_src_bctx;
1159
1160 diff_dst_data_t *tr_diff_dst;
1161 simple_barrier::ctx_t *tr_diff_dst_bctx;
1162
1163 diff_weights_data_t *wei_bia_reduction;
1164 simple_barrier::ctx_t *wei_bia_reduction_bctx;
1165
1166 int ithr;
1167 int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
1168 int ithr_but_oc;
1169 int ithr_but_ic;
1170
1171 int img_start = 0, img_end = 0, img_work;
1172 int g_start = 0, g_end = 0, g_work;
1173 int oc_b_start = 0, oc_b_end = 0, oc_b_work;
1174 int ic_b_start = 0, ic_b_end = 0, ic_b_work;
1175
1176 thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
1177 const exec_ctx_t &ctx, int ithr)
1178 : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
1179 diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
1180 src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
1181 diff_weights
1182 = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS);
1183 const auto &jcp = self->kernel_->jcp;
1184 const bool is_bias_padded = self->pd()->with_bias()
1185 && jcp.oc_without_padding % jcp.oc_block != 0;
1186 diff_bias = is_bias_padded
1187 ? scratchpad.template get<diff_weights_data_t>(
1188 key_conv_padded_bias)
1189 : CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_BIAS);
1190
1191 tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
1192 tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1193 key_conv_tr_src_bctx);
1194
1195 tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
1196 key_conv_tr_diff_dst);
1197 tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1198 key_conv_tr_diff_dst_bctx);
1199
1200 wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
1201 key_conv_wei_bia_reduction);
1202 wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1203 key_conv_wei_bia_reduction_bctx);
1204
1205 ithr_ic_b = ithr % self->nthr_ic_b_;
1206 ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
1207 ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
1208 ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;
1209
1210 ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
1211 + ithr_ic_b;
1212
1213 ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
1214 + ithr_oc_b;
1215
1216 /* reduction dimension */
1217 int oh_reduce = jcp.harness == harness_2d_reduction ? jcp.oh : 1;
1218 balance211(jcp.mb * jcp.od * oh_reduce, self->nthr_mb_, ithr_mb,
1219 img_start, img_end);
1220 img_work = img_end - img_start;
1221
1222 /* independent dimensions */
1223 balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
1224 g_work = g_end - g_start;
1225
1226 balance211(
1227 jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
1228 oc_b_work = oc_b_end - oc_b_start;
1229
1230 balance211(
1231 jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
1232 ic_b_work = ic_b_end - ic_b_start;
1233 }
1234};
1235
1236template <data_type_t src_type, data_type_t diff_dst_type,
1237 data_type_t diff_weights_type>
1238void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1239 diff_weights_type>::compute_diff_weights_nxc(const thread_info_t *ti)
1240 const {
1241 const auto &jcp = kernel_->jcp;
1242
1243 const int wei_size
1244 = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;
1245 diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1246 ? (diff_weights_data_t *)ti->diff_weights
1247 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1248
1249 auto diff_weights_offset
1250 = [&](int g, int i_kd, int i_kh, int i_kw, int i_ic, int i_oc) {
1251 const int oc_block_size = 1;
1252 const int ic_block_size = jcp.oc_block * oc_block_size;
1253 const int kw_block_size = jcp.ic_block * ic_block_size;
1254 const int kh_block_size = jcp.kw * kw_block_size;
1255 const int kd_block_size = jcp.kh * kh_block_size;
1256 const int icb_block_size = jcp.kd * kd_block_size;
1257 const int ocb_block_size = jcp.nb_ic * icb_block_size;
1258 const int g_block_size = jcp.nb_oc * ocb_block_size;
1259
1260 int icb = i_ic / jcp.ic_block;
1261 int ocb = i_oc / jcp.oc_block;
1262 i_ic = i_ic % jcp.ic_block;
1263 i_oc = i_oc % jcp.oc_block;
1264
1265 return g * g_block_size + ocb * ocb_block_size
1266 + icb * icb_block_size + i_kd * kd_block_size
1267 + i_kh * kh_block_size + i_kw * kw_block_size
1268 + i_ic * ic_block_size + i_oc * oc_block_size;
1269 };
1270 auto src_offset
1271 = [&](int g, int i_mb, int i_id, int i_ih, int i_ic, int i_iw) {
1272 const int ic_block_size = 1;
1273 const int g_block_size = jcp.ic * ic_block_size;
1274 const int iw_block_size = jcp.ngroups * g_block_size;
1275 const int ih_block_size = jcp.iw * iw_block_size;
1276 const int id_block_size = jcp.ih * ih_block_size;
1277 const int mb_block_size = jcp.id * id_block_size;
1278
1279 return g * g_block_size + i_mb * mb_block_size
1280 + i_id * id_block_size + i_ih * ih_block_size
1281 + i_iw * iw_block_size + i_ic * ic_block_size;
1282 };
1283 auto diff_dst_offset
1284 = [&](int g, int i_mb, int i_od, int i_oh, int i_ow, int i_oc) {
1285 const int oc_block_size = 1;
1286 const int g_block_size = jcp.oc * oc_block_size;
1287 const int ow_block_size = jcp.ngroups * g_block_size;
1288 const int oh_block_size = jcp.ow * ow_block_size;
1289 const int od_block_size = jcp.oh * oh_block_size;
1290 const int mb_block_size = jcp.od * od_block_size;
1291
1292 return g * g_block_size + i_mb * mb_block_size
1293 + i_od * od_block_size + i_oh * oh_block_size
1294 + i_ow * ow_block_size + i_oc * oc_block_size;
1295 };
1296 auto zero_diff_weights = [&]() {
1297 PRAGMA_OMP_SIMD()
1298 for (dim_t i = 0; i < wei_size; i++)
1299 diff_wei[i] = 0;
1300 };
1301
1302 int kd_step = jcp.dilate_d + 1;
1303 int kh_step = jcp.dilate_h + 1;
1304 int stride_d = jcp.stride_d;
1305 int stride_h = jcp.stride_h;
1306 int f_pad = jcp.f_pad;
1307 int t_pad = jcp.t_pad;
1308
1309 dim_t work_amount = jcp.mb * jcp.od * jcp.oh * jcp.nb_ow;
1310 dim_t i_work {0}, i_work_end {0};
1311 balance211(work_amount, jcp.nthr_mb, ti->ithr_mb, i_work, i_work_end);
1312
1313 int i_mb {0}, i_od {0}, i_oh {0}, i_owb {0};
1314 nd_iterator_init(
1315 i_work, i_mb, jcp.mb, i_od, jcp.od, i_oh, jcp.oh, i_owb, jcp.nb_ow);
1316
1317 zero_diff_weights();
1318 while (i_work < i_work_end) {
1319 int kd_start = nstl::max(
1320 0, div_up(jcp.f_pad - jcp.stride_d * i_od, kd_step));
1321 int kd_end = nstl::min(
1322 jcp.kd - 1, (jcp.id - 1 + f_pad - stride_d * i_od) / kd_step);
1323 int i_id_base = stride_d * i_od - f_pad;
1324 int kh_start = nstl::max(
1325 0, div_up(jcp.t_pad - jcp.stride_h * i_oh, +kh_step));
1326 int kh_end = nstl::min(
1327 jcp.kh - 1, (jcp.ih - 1 + t_pad - stride_h * i_oh) / kh_step);
1328 int i_ih_base = jcp.stride_h * i_oh + -jcp.t_pad;
1329 int i_ow_base = i_owb * jcp.ow_block;
1330 int i_ow_end = nstl::min(jcp.ow, i_ow_base + jcp.ow_block);
1331
1332 // The kernel is small so these loops produce measurable overhead. Since
1333 // these are simple loops, the compiler will likely make the loops just
1334 // as well as we can with the jitted assembly, so there is not
1335 // necessarily a reason to move these loops into assembly. Avoid placing
1336 // computationally heavy operations within the loops.
1337 for_(int i_ow = i_ow_base; i_ow < i_ow_end; i_ow += jcp.ur_ow)
1338 for_(int i_oc = 0; i_oc < jcp.oc; i_oc += jcp.oc_block)
1339 for_(int g = 0; g < jcp.ngroups; g++)
1340 for_(int i_kd = kd_start; i_kd <= kd_end; i_kd++)
1341 for (int i_kh = kh_start; i_kh <= kh_end; i_kh++) {
1342 // Some Optimization Observations: It may be
1343 // worthwhile to move the kd and kh loops below the
1344 // icb loop in the kernel to further amortize the
1345 // ddst register loads. Alternatively, these
1346 // dimensions are independent on the weights kernel,
1347 // so can be used as a threading dimension that does
1348 // not require reduction.
1349
1350 // The compiler seems to do a good job at optimizing these
1351 // computations. The offset functions likely need to be located
1352 // so that they will be inlined.
1353 int i_iw = i_ow * jcp.stride_w - jcp.l_pad;
1354 int i_id = i_id_base + i_kd * kd_step;
1355 int i_ih = i_ih_base + i_kh * kh_step;
1356 int ddst_offset = diff_dst_offset(g, i_mb, i_od, i_oh, i_ow, i_oc);
1357 int s_off_base = src_offset(g, i_mb, i_id, i_ih, 0, i_iw);
1358 int dwei_off_base = diff_weights_offset(g, i_kd, i_kh, 0, 0, i_oc);
1359 // ensure all parameters are 64bit, to comply with windows kernel
1360 // param access where the params from 5th are passed using stack.
1361 (*kernel_)(&diff_wei[dwei_off_base], &ti->src[s_off_base],
1362 &ti->diff_dst[ddst_offset], (dim_t)i_iw, (dim_t)i_ow);
1363 }
1364 nd_iterator_step(
1365 i_mb, jcp.mb, i_od, jcp.od, i_oh, jcp.oh, i_owb, jcp.nb_ow);
1366 i_work++;
1367 }
1368}
1369
1370template <data_type_t src_type, data_type_t diff_dst_type,
1371 data_type_t diff_weights_type>
1372void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1373 diff_weights_type>::compute_diff_weights(const thread_info_t *ti)
1374 const {
1375 const memory_desc_wrapper src_d(pd()->src_md());
1376 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1377 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1378
1379 const auto &jcp = kernel_->jcp;
1380 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
1381 const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
1382 const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
1383 * jcp.kh * jcp.kw * jcp.kd;
1384
1385 diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1386 ? (diff_weights_data_t *)ti->diff_weights
1387 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1388
1389 const bool is_src_layout_nxc = utils::one_of(
1390 jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1391
1392 int ic_b_step = jcp.nb_ic_blocking_max;
1393 int icb_work = ti->ic_b_end - ti->ic_b_start;
1394 if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
1395 ic_b_step = utils::div_up(icb_work, 2);
1396
1397 for (int img = ti->img_start; img < ti->img_end; ++img) {
1398 auto p = jit_conv_call_s();
1399
1400 const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc);
1401 const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic);
1402 const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag,
1403 format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1404 for_(int g = ti->g_start; g < ti->g_end; ++g)
1405 for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
1406 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
1407 ic_b += ic_b_step) {
1408 const int _oc = g * jcp.nb_oc + oc_b;
1409 const int _ic = g * jcp.nb_ic + ic_b;
1410 const int ic_to_compute = this_block_size(
1411 ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block);
1412 const int oc_to_compute = this_block_size(
1413 oc_b * jcp.oc_block, max_oc, jcp.oc_block);
1414
1415 const int ic_off_idx = is_src_layout_nxc
1416 ? g * jcp.ic + ic_b * jcp.ic_block
1417 : _ic;
1418 const int oc_off_idx = is_ddst_layout_nxc
1419 ? g * jcp.oc + oc_b * jcp.oc_block
1420 : _oc;
1421
1422 jit_conv_ker_pipeline_bwd_w(jit_ker, p,
1423 &ti->src[src_d.blk_off(img, ic_off_idx)],
1424 &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)],
1425 diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
1426 nullptr, (img == ti->img_start), 0, ic_to_compute,
1427 oc_to_compute);
1428 }
1429 }
1430}
1431
1432template <data_type_t src_type, data_type_t diff_dst_type,
1433 data_type_t diff_weights_type>
1434void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1435 diff_weights_type>::compute_diff_weights_2d(const thread_info_t *ti)
1436 const {
1437 const memory_desc_wrapper src_d(pd()->src_md());
1438 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1439 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1440
1441 const auto &jcp = kernel_->jcp;
1442 const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
1443 const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
1444 const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
1445 * jcp.kh * jcp.kw;
1446
1447 diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1448 ? (diff_weights_data_t *)ti->diff_weights
1449 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1450 diff_weights_data_t *diff_bia = ti->ithr_mb == 0
1451 ? (diff_weights_data_t *)ti->diff_bias
1452 : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
1453 + (ti->ithr_mb - 1) * jcp.ngroups * padded_oc;
1454
1455 int img {0}, oh_s {0};
1456 int img_start = ti->img_start, img_end = ti->img_end;
1457 nd_iterator_init(img_start, img, jcp.mb, oh_s, jcp.oh);
1458 const int img_first = img;
1459
1460 int ic_b_step = jcp.nb_ic_blocking_max;
1461 int icb_work = ti->ic_b_end - ti->ic_b_start;
1462 if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
1463 ic_b_step = utils::div_up(icb_work, 2);
1464 while (img_start < img_end) {
1465 auto p = jit_conv_call_s();
1466
1467 int work_rem = img_end - img_start;
1468 const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
        const int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
        const int kh_top_overflow = nstl::max(0, -ih_s);
        const int kh_bottom_overflow = nstl::max(0, ih_s - jcp.ih + jcp.kh);
        int kh_padding = jcp.kh - kh_top_overflow - kh_bottom_overflow;
        int kh_padding_offset = nstl::min(jcp.kh - 1, kh_top_overflow) * jcp.kw
                * jcp.ic_block * jcp.oc_block * jcp.typesize_out;
        auto src_h = ti->src + src_d.blk_off(img, 0, ih_s + kh_top_overflow);
        auto diff_dst_h = ti->diff_dst + diff_dst_d.blk_off(img, 0, oh_s);

        const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
        const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
        const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc);
        const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic);
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int _oc = g * jcp.nb_oc + oc_b;
            const int _ic = g * jcp.nb_ic + ic_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, max_oc, jcp.oc_block);
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : _ic;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : _oc;
            auto src = src_h + src_d.blk_off(0, ic_off_idx);
            auto diff_dst = diff_dst_h + diff_dst_d.blk_off(0, oc_off_idx);
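            // diff_bias depends only on diff_dst, so it is accumulated just
            // for the first ic block; a non-zero flag tells the kernel to
            // skip the bias update.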
            p.flags = ic_b == 0 ? 0 : 1;
            jit_conv_2d_ker_bwd_w_pipeline(jit_ker, p, src, diff_dst,
                    diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                    diff_bia + _oc * jcp.oc_block, (img == img_first), oh_s,
                    oh_e, kh_padding, kh_padding_offset, ic_to_compute,
                    oc_to_compute);
        }
        nd_iterator_jump(img_start, img_end, img, jcp.mb, oh_s, jcp.oh);
    }
}

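// 3D harness driver: same structure as the 2D driver, but the reduction space
// is split over (mb, od) and the front/back padding of the filter depth is
// clipped instead of the filter height.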
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti)
        const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
    const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
    const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
            * jcp.kh * jcp.kw * jcp.kd;

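    // Same partial-accumulation scheme as in the 2D driver: the first
    // mb-thread writes into the user buffers, the rest into wei_bia_reduction.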
    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
            ? (diff_weights_data_t *)ti->diff_weights
            : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
    diff_weights_data_t *diff_bia = ti->ithr_mb == 0
            ? (diff_weights_data_t *)ti->diff_bias
            : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
                    + (ti->ithr_mb - 1) * jcp.ngroups * padded_oc;

    const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
    const int inp_mult = is_src_layout_nxc
            ? jcp.ngroups * jcp.ic
            : (jcp.is_1stconv ? 1 : jcp.ic_block);
    const int input_step = jcp.ih * jcp.iw * inp_mult;
    const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
    const int output_step = jcp.ow * jcp.oh
            * (is_ddst_layout_nxc ? jcp.ngroups * jcp.oc : jcp.oc_block);
    int img {0}, od_s {0};
    int img_start = ti->img_start, img_end = ti->img_end;
    nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
    const int img_first = img;

    int ic_b_step = jcp.nb_ic_blocking_max;
    int icb_work = ti->ic_b_end - ti->ic_b_start;
    if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
        ic_b_step = utils::div_up(icb_work, 2);

    while (img_start < img_end) {
        auto p = jit_conv_call_s();

        int work_rem = img_end - img_start;
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
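        // Clip the filter taps that fall into the front/back depth padding;
        // kd_pad_off skips the clipped filter planes when the kernel walks
        // the weights.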
        const int id_s = od_s * jcp.stride_d;
        const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
        const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
        const int kd_back_pad
                = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd);
        int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw
                * jcp.ic_block * jcp.oc_block * jcp.typesize_out;

        const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc);
        const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic);

        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int _oc = g * jcp.nb_oc + oc_b;
            const int _ic = g * jcp.nb_ic + ic_b;

            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, max_oc, jcp.oc_block);

            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : _ic;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : _oc;
            auto src = &ti->src[src_d.blk_off(img, ic_off_idx)
                    + ik_overlap * input_step];
            auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)
                    + od_s * output_step];
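            // 16 matches jcp.oc_block (the avx512 simd width); as in the 2D
            // driver, bias is accumulated only for the first ic block.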
            auto diff_bia_ptr = diff_bia ? diff_bia + _oc * 16 : nullptr;
            p.flags = ic_b == 0 ? 0 : 1;
            jit_conv_3d_ker_bwd_w_pipeline(jit_ker, p, src, dst,
                    diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                    diff_bia_ptr, (img == img_first), od_s, od_e,
                    jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off,
                    ic_to_compute, oc_to_compute);
        }
        nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
    }
}

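// Accumulate the per-mb-thread partial diff_weights copies into the user
// buffer. The (g, oc_b, ic_b * kh) space owned by this thread group is split
// evenly across the mb-threads with balance211().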
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
    const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
            * jcp.kh * jcp.kw;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    if (dnnl_thr_syncable())
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

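    // Each mb-thread takes its balance211() slice of the work and adds the
    // corresponding region from every other thread's partial copy.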
    const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
            const int kh = sub_ic_b_kh_start % jcp.kh;

            const int acc_size
                    = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                    * jcp.kw * jcp.ic_block * jcp.oc_block;

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);

            diff_weights_data_t *d
                    = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                    = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;

            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
    }
}

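// Same reduction as reduce_diff_weights(), but the work is decomposed over
// (g, oc_b, ic_b * kd) and each accumulated chunk covers full kh * kw planes.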
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti)
        const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block)
            * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw * jcp.kd;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    if (dnnl_thr_syncable())
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
            const int kd = sub_ic_b_kh_start % jcp.kd;

            const int acc_size
                    = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                    * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
            diff_weights_data_t *d
                    = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                    = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
    }
}

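// Compute diff_bias by summing diff_dst over the spatial dimensions into the
// per-thread buffers managed by reducer_bias_; used by the mb and nxc
// harnesses, where the convolution kernel does not produce the bias itself.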
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    auto rb = this->reducer_bias_.get();
    assert(nthr_ == rb->balancer().nthr_);

    const auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(ti->scratchpad, prefix_reducer_bia);

    const auto &jcp = kernel_->jcp;

    const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
    const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);

    if (b_njobs == 0) return;

    /* reduction dimension */
    int img_start {0}, img_end {0};
    balance211(jcp.mb, rb->balancer().nthr_per_group_,
            rb->balancer().id_in_group(ti->ithr), img_start, img_end);

    /* jobs */
    int g_start {0}, ocb_start {0};
    nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
    for (int img = img_start; img < img_end; ++img) {
        int g = g_start, ocb = ocb_start;
        for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
            const size_t _oc = g * jcp.nb_oc + ocb;
            const int max_oc
                    = this_block_size(ocb * jcp.oc_block, jcp.oc, jcp.oc_block);

            const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag,
                    format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : _oc;
            const diff_dst_data_t *d_dst
                    = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];
            diff_weights_data_t *d_bias
                    = rb->get_local_ptr(
                              ti->ithr, ti->diff_bias, reducer_bia_scratchpad)
                    + b_job_loc * rb->balancer().job_size_;

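            // Zero the local bias buffer on the first image of this thread's
            // chunk, then accumulate diff_dst over all spatial points of the
            // current image.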
            if (img == img_start)
                for (int o = 0; o < jcp.oc_block; ++o)
                    d_bias[o] = 0;
            for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < max_oc; ++o)
                    d_bias[o] += d_dst[o];
                d_dst += is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                            : jcp.oc_block;
            }

            nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
        }
    }

    if (dnnl_thr_syncable())
        rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
}

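// Add the per-mb-thread bias partial sums (stored after the weight copies in
// wei_bia_reduction) into diff_bias; used by the 2d and 3d harnesses, where
// the kernel accumulates the bias itself.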
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::reduce_diff_bias(const thread_info_t *ti) const {
    const auto &jcp = kernel_->jcp;

    const size_t wei_size = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block)
            * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw * jcp.kd;
    const int bia_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block);
    const diff_weights_data_t *diff_bias_ws
            = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;

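    // Wait until every mb-thread has finished writing its partial bias sums.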
    if (dnnl_thr_syncable() && nthr_mb_ > 1) dnnl_thr_barrier();

    if (ti->ithr == 0) {
        for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
            acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
            diff_bias_ws += bia_size;
        }
    }
}

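// Initialize the barrier context used for the weight/bias reduction and the
// scratchpad of the bias reducer.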
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx)
        const {
    auto scratchpad = ctx.get_scratchpad_grantor();

    if (dnnl_thr_syncable() && nthr_mb_ > 1) {
        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx));
    }

    const auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);
}

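// Entry point: each thread runs the harness-specific driver and then the
// per-thread partial results are reduced into the user buffers.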
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx)
        const {
    prepare_scratchpad_data(ctx);

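    // With a syncable threading runtime, compute and reduce run in a single
    // parallel region and synchronize through barriers.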
#if DNNL_THR_SYNC == 1
    parallel(nthr_, [&](const int ithr, const int nthr) {
        assert(nthr_ == nthr);

        thread_info_t thread_info(this, ctx, ithr);

        switch (pd()->jcp_.harness) {
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
                if (pd()->with_bias()) reduce_diff_bias(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
                if (pd()->with_bias()) reduce_diff_bias(&thread_info);
                break;
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            case harness_nxc:
                compute_diff_weights_nxc(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type");
        }
    });
#else
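    // Without a syncable runtime, the reduction runs in a second parallel
    // region instead of relying on barriers inside the first one.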
    parallel(nthr_, [&](const int ithr, const int nthr) {
        thread_info_t thread_info(this, ctx, ithr);
        switch (pd()->jcp_.harness) {
            case harness_nxc:
                compute_diff_weights_nxc(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                break;
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type");
        }
    });

    parallel(nthr_, [&](const int ithr, const int nthr) {
        thread_info_t thread_info(this, ctx, ithr);
        if (nthr_mb_ > 1) {
            switch (pd()->jcp_.harness) {
                case harness_mb_reduction:
                case harness_2d_reduction:
                    reduce_diff_weights(&thread_info);
                    break;
                case harness_nxc:
                case harness_3d_reduction:
                    reduce_diff_weights_3d(&thread_info);
                    break;
                default: assert(!"Invalid harness type");
            }
        }
        if (pd()->with_bias()) {
            switch (pd()->jcp_.harness) {
                case harness_2d_reduction:
                case harness_3d_reduction:
                    reduce_diff_bias(&thread_info);
                    break;
                case harness_nxc:
                case harness_mb_reduction: {
                    auto rb = this->reducer_bias_.get();
                    assert(nthr == rb->balancer().nthr_);
                    if (rb->balancer().ithr_njobs(ithr) == 0) return;
                    const auto reducer_bia_scratchpad
                            = memory_tracking::grantor_t(
                                    thread_info.scratchpad, prefix_reducer_bia);
                    rb->reduce_nolock(thread_info.ithr, thread_info.diff_bias,
                            reducer_bia_scratchpad);
                } break;
                default: assert(!"Invalid harness type");
            }
        }
    });
#endif

    /* TODO: put that into compute_diff_bias() */
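    // The bias was accumulated over the padded oc range; copy only the first
    // oc_without_padding values of each group into the user diff_bias buffer.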
    auto &jcp = pd()->jcp_;
    if (pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0) {
        auto diff_bias = ctx.get_scratchpad_grantor()
                                 .template get<const diff_weights_data_t>(
                                         key_conv_padded_bias);
        auto diff_bias_in
                = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_BIAS);
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s