/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx2_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace nstl;

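// The helpers below pick the memory_desc_wrapper::blk_off() overload that
// matches the spatial rank of the problem: ndims() == 3 is 1D (w only),
// ndims() == 4 is 2D (h, w), and ndims() == 5 is 3D (d, h, w).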
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) ? (f).blk_off(n, c, w) \
                         : (pd()->ndims() == 4) ? (f).blk_off(n, c, h, w) \
                                                : (f).blk_off(n, c, d, h, w)

#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
            ? wht_blk_off_(f, g, oc, ic, kw) \
            : (pd()->ndims() == 4) ? wht_blk_off_(f, g, oc, ic, kh, kw) \
                                   : wht_blk_off_(f, g, oc, ic, kd, kh, kw)

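// Forward: the work is decomposed over (mb, groups, oc-block chunks, od, oh)
// and balanced across threads; each work item invokes the JIT kernel on a
// full output row (all ow positions at once).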
void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    const auto &jcp = kernel_->jcp;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount
            = jcp.mb * jcp.ngroups * ocb_work * jcp.od * jcp.oh;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

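        // Physically blocked layouts (nC{w,hw,dhw}8c) index blk_off() by
        // channel block, plain (nxc) layouts by channel; these offsets and
        // scales let the same indexing code handle both cases.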
        bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
        int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

        bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
        int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;
        int oc_bias_scale = is_oc_physically_blocked ? jcp.oc_block : 1;

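        // Input-channel blocks are processed in chunks of nb_ic_blocking.
        // Bias is passed to the kernel only for the first block (icb == 0,
        // FLAG_IC_FIRST), and eltwise/binary post-ops are applied only for
        // the last block (FLAG_IC_LAST).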
        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem;

            size_t n {0}, g {0}, ocbb {0}, oh {0}, od {0};
            nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                    od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

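                    // Count how many (dilated) kernel rows/slices fall into
                    // the top/front and bottom/back padding for this output
                    // position; they are skipped via kh_padding/kd_padding
                    // below.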
                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow
                            = nstl::max(jcp.ih,
                                      ij + (jcp.kh - 1) * (jcp.dilate_h + 1)
                                              - jcp.t_pad + 1)
                            - jcp.ih;

                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow
                            = nstl::max(jcp.id,
                                      dj + (jcp.kd - 1) * (jcp.dilate_d + 1)
                                              - jcp.f_pad + 1)
                            - jcp.id;

                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    const int ih = nstl::max(ij - jcp.t_pad
                                    + div_up(i_t_overflow, (jcp.dilate_h + 1))
                                            * (jcp.dilate_h + 1),
                            0);

                    const int id = nstl::max(dj - jcp.f_pad
                                    + div_up(d_t_overflow, (jcp.dilate_d + 1))
                                            * (jcp.dilate_d + 1),
                            0);

                    par_conv.src = &src[src_blk_off(src_d, n, _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(
                            weights_d, g, ocb, icb, wd, wh, 0)];

                    if (icb == 0) {
                        if (bias)
                            par_conv.bias = &bias[bias_d.blk_off(
                                    _oc * oc_bias_scale)];

                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    if ((jcp.with_eltwise || jcp.with_binary)
                            && icb + 1 == jcp.nb_ic)
                        par_conv.flags |= FLAG_IC_LAST;

                    par_conv.reduce_work = this_block_size(
                            icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                    par_conv.oc_blocks
                            = nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                    if (ocbb == ocb_work - 1) par_conv.oc_flag |= FLAG_OC_LAST;

                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh
                            - div_up(i_t_overflow, (jcp.dilate_h + 1))
                            - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                            - div_up(d_t_overflow, (jcp.dilate_d + 1))
                            - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    par_conv.oc_l_off = _oc * oc_bias_scale;
                    par_conv.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    par_conv.dst_orig = dst;

                    (*kernel_)(&par_conv);
                }
                nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, od,
                        jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

    if (pd()->wants_padded_bias()) {
        auto padded_bias = ctx.get_scratchpad_grantor().get<data_t>(
                key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, ker);

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

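// Backward by data: the work is decomposed over (mb, groups, ic-block chunks,
// ih blocks); for every work item the thread loops over all oc blocks and all
// id/ih positions in its block and lets the kernel compute full diff_src rows.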
void jit_avx2_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;

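    // Heuristic: start with one ih block covering the whole height. If there
    // is not enough work to keep all threads busy, or the per-iteration
    // working set (diff_src + diff_dst + weights slice) does not fit in L2,
    // switch to single-row ih blocks to create more, smaller work items.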
    const auto data_size = sizeof(data_t);
    const auto L2 = platform::get_per_core_cache_size(2) / data_size;
    // input + output + weights per iteration by nb_oc_blocking
    auto ic_chunk = jcp.nb_ic_blocking * jcp.ic_block;
    auto oc_chunk = jcp.nb_oc_blocking * jcp.oc_block;
    auto iter_data_amount = (size_t)jcp.id * jcp.ih * jcp.iw * ic_chunk
            + (size_t)jcp.od * jcp.oh * jcp.ow * oc_chunk
            + (size_t)jcp.kd * jcp.kh * jcp.kw * ic_chunk * oc_chunk;

    if (work_amount < (size_t)2 * jcp.nthr || iter_data_amount > L2) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    const bool is_ddst_layout_nxc = one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const int oc_step = is_ddst_layout_nxc ? jcp.nb_oc_blocking : 1;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n {0}, g {0}, icbb {0}, ihb {0};
        nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                num_ih_blocks);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for_(int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
            for (int id = 0; id < jcp.id; ++id) {
                int cur_nb_oc = nstl::min(jcp.nb_oc - oc, jcp.nb_oc_blocking);

                auto par_conv = jit_conv_call_s();

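                // Depth dimension: d_t_overflow / d_b_overflow count the
                // kernel taps that fall into the front/back padding for this
                // diff_src slice, and od is the matching diff_dst index. The
                // dilated branch assumes stride_d == 1 (see inline note).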
                int d_t_overflow, d_b_overflow, od;
                if (jcp.dilate_d != 0) { // stride == 1
                    const int dilate_d = jcp.dilate_d + 1;
                    d_t_overflow
                            = div_up(nstl::max(0, ext_kd - 1 - id - jcp.f_pad),
                                    dilate_d);
                    d_b_overflow = div_up(
                            nstl::max(0, ext_kd - jcp.id + id - jcp.back_pad),
                            dilate_d);
                    od = id + jcp.f_pad - d_b_overflow * dilate_d;
                } else {
                    d_t_overflow = nstl::max(0, jcp.kd - 1 - id - jcp.f_pad);
                    d_b_overflow = nstl::max(
                            0, jcp.kd - 1 - (jcp.id - 1 - id) - jcp.back_pad);
                    od = id + jcp.f_pad - d_b_overflow;
                }
                par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;

                int ih_start = ihb * ih_block_size;
                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                for (int ih = ih_start; ih < ih_end; ++ih) {

                    int k_lo, oh;
                    if (jcp.dilate_h != 0) { // stride == 1
                        const int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(
                                nstl::max(0, ext_kh - 1 - ih - jcp.t_pad),
                                dilate_h);
                        int i_b_overflow = div_up(
                                nstl::max(0, ext_kh - jcp.ih + ih - jcp.b_pad),
                                dilate_h);
                        par_conv.kh_padding
                                = jcp.kh - i_t_overflow - i_b_overflow;
                        k_lo = i_b_overflow;
                        oh = ih + jcp.t_pad - k_lo * dilate_h;
                    } else {
                        int i_t_overflow = nstl::max(0,
                                (jcp.kh - 1 - ih - jcp.t_pad) / jcp.stride_h);
                        int i_b_overflow = nstl::max(0,
                                (jcp.kh - jcp.ih + ih - jcp.b_pad)
                                        / jcp.stride_h);
                        int overflow_kh_hi = jcp.kh - 1
                                - modulo(jcp.ih - 1 + jcp.b_pad - ih,
                                        jcp.stride_h);
                        int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                        par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                                        / jcp.stride_h
                                + 1 - i_t_overflow - i_b_overflow;

                        k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                        oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
                    }
                    par_conv.kw_padding = 0;

                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                            g * g_ic_offset
                                    + jcp.nb_ic_blocking * icbb * icb_ic_scale,
                            id, ih, 0)];
                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, n,
                            g * g_oc_offset + ocb_oc_scale * oc, od, oh, 0)];
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                            jcp.nb_ic_blocking * icbb, d_b_overflow, k_lo, 0)];

                    par_conv.src_prf = nullptr;
                    par_conv.dst_prf = nullptr;
                    par_conv.filt_prf = nullptr;
                    par_conv.channel = oc;
                    par_conv.ch_blocks = cur_nb_oc;

                    if (is_ddst_layout_nxc) {
                        par_conv.load_work = this_block_size(
                                icbb * jcp.nb_ic_blocking * jcp.ic_block,
                                (size_t)jcp.ic,
                                jcp.nb_ic_blocking * jcp.ic_block);
                        par_conv.reduce_work
                                = this_block_size(oc * jcp.oc_block, jcp.oc,
                                        oc_step * jcp.oc_block);

                        if (par_conv.load_work % jcp.ic_block > 0)
                            par_conv.flags |= FLAG_IC_LAST;
                    }

                    (*kernel_)(&par_conv);
                }
            }
            nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                    num_ih_blocks);
        }
    };

    parallel(jcp.nthr, ker);
}

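// Backward by weights: each thread accumulates partial diff_weights and
// diff_bias for its share of the (mb, od) reduction dimension into
// per-thread buffers managed by reducer_weights_ / reducer_bias_, which are
// afterwards reduced into the user-visible buffers.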
void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;

    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    bool is_ddst_layout_nxc = !is_oc_physically_blocked;
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start {0}, img_od_end {0}, img {0}, od_s {0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start {0}, ocb_start {0}, icb_start {0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            const int work_rem = img_end - img_start;
            const int od_e
                    = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
                for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    /* TODO: put dw <-- 0 in kernel */
                    if (img == img_first)
                        array_set(rw->get_local_ptr(ithr, diff_weights,
                                          reducer_wei_scratchpad)
                                        + w_job_loc * rw->balancer().job_size_,
                                0, rw->balancer().job_size_);

                    for (int od = od_s; od < od_e; ++od) {
                        const int id = od * jcp.stride_d;
                        if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                        auto par_conv = jit_conv_call_s();
                        par_conv.src
                                = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                        par_conv.dst = &diff_dst[src_blk_off(
                                diff_dst_d, img, _oc, od, 0, 0)];
                        par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                                                reducer_wei_scratchpad)
                                + w_job_loc * rw->balancer().job_size_;

                        if (ocb == jcp.nb_oc - 1)
                            par_conv.flags |= FLAG_OC_LAST;

                        par_conv.channel = this_block_size(
                                icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                        (*kernel_)(&par_conv);
                    }
                    nd_iterator_step(
                            g, jcp.ngroups, ocb, jcp.nb_oc, icb, jcp.nb_ic);
                }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }

        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

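    // Bias reduction: sum diff_dst over the spatial dimensions into the
    // per-thread bias buffer for every (group, oc block) job owned by this
    // thread.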
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

                if (img == img_start)
                    for (int o = 0; o < jcp.oc_block; ++o)
                        d_bias[o] = 0.;

                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                                : jcp.oc_block;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

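    // With a syncable threading runtime the reduction of per-thread partial
    // results happens inside ker/ker_bias (rw->reduce / rb->reduce); with a
    // non-syncable runtime, separate parallel passes perform the reduction
    // via reduce_nolock() after all partials have been computed.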
    if (dnnl_thr_syncable()) {
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)) {
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g)
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s