1 | /******************************************************************************* |
2 | * Copyright 2016-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/jit_avx2_convolution.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | |
29 | using namespace dnnl::impl::status; |
30 | using namespace dnnl::impl::memory_tracking::names; |
31 | using namespace dnnl::impl::utils; |
32 | using namespace nstl; |
33 | |
// Dimensionality-agnostic offset into a src/dst-style tensor: collapses the
// 1D (ncw), 2D (nchw) and 3D (ncdhw) cases into one expression so the driver
// code below can be written once for all spatial ranks. Unused leading
// spatial indices (d, h) are simply ignored for lower-rank tensors.
#define src_blk_off(f, n, c, d, h, w) \
    (pd()->ndims() == 3) ? (f).blk_off(n, c, w) \
                         : (pd()->ndims() == 4) ? (f).blk_off(n, c, h, w) \
                                                : (f).blk_off(n, c, d, h, w)

// Weights offset helper: prepends the group index only for grouped
// convolutions (the groups dimension is absent from non-grouped weights).
#define wht_blk_off_(f, g, ...) \
    pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
// Weights offset for 1D/2D/3D convolutions; see src_blk_off above for the
// ndims dispatch scheme.
#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
    (pd()->ndims() == 3) \
            ? wht_blk_off_(f, g, oc, ic, kw) \
            : (pd()->ndims() == 4) ? wht_blk_off_(f, g, oc, ic, kh, kw) \
                                   : wht_blk_off_(f, g, oc, ic, kd, kh, kw)
46 | |
// Forward convolution driver. The iteration space (mb x groups x oc-block
// chunks x od x oh) is balanced across threads; ow is handled inside the JIT
// kernel. Each work unit calls the kernel once per input-channel block, with
// the same dst row revisited for every ic block (first/last ic blocks are
// flagged so the kernel knows when to apply bias and post-ops).
void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    const auto &jcp = kernel_->jcp;
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // Right-hand-side argument pointers for binary post-ops, resolved once
    // per execution and handed to the kernel through par_conv.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    // Total parallel work units shared by all threads.
    const size_t ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount
            = jcp.mb * jcp.ngroups * ocb_work * jcp.od * jcp.oh;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        // For physically blocked layouts (nC{w,hw,dhw}8c) channel offsets
        // passed to blk_off() are counted in blocks of ic_block/oc_block;
        // for plain/nxc layouts they are counted in individual channels.
        bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
        int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

        bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
                format_tag::nChw8c, format_tag::nCdhw8c);
        int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
        int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;
        // Bias is indexed per channel, so blocked _oc values (block indices)
        // must be scaled back up by oc_block.
        int oc_bias_scale = is_oc_physically_blocked ? jcp.oc_block : 1;

        // Outer loop over chunks of input-channel blocks; the same [start,
        // end) work range is replayed for every chunk.
        int icbb = 0;
        while (icbb < jcp.nb_ic) {
            int icb_step = jcp.nb_ic_blocking;
            int icb_step_rem = jcp.nb_ic - icbb;
            // Take the remainder in one go when it is below the max chunk.
            if (icb_step_rem < jcp.nb_ic_blocking_max) icb_step = icb_step_rem;

            size_t n {0}, g {0}, ocbb {0}, oh {0}, od {0};
            nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
                    od, jcp.od, oh, jcp.oh);
            for (size_t iwork = start; iwork < end; ++iwork) {
                int ocb = ocbb * jcp.nb_oc_blocking;
                int ocb_num = jcp.nb_oc_blocking;

                for (int icb = icbb; icb < icbb + icb_step; ++icb) {
                    auto par_conv = jit_conv_call_s();

                    // Input rows falling into the top / bottom padding for
                    // this output row (ij = first input row before padding
                    // adjustment).
                    const int ij = oh * jcp.stride_h;
                    const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
                    const int i_b_overflow
                            = nstl::max(jcp.ih,
                                      ij + (jcp.kh - 1) * (jcp.dilate_h + 1)
                                              - jcp.t_pad + 1)
                            - jcp.ih;

                    // Same for depth (f_pad = front padding).
                    const int dj = od * jcp.stride_d;
                    const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
                    const int d_b_overflow
                            = nstl::max(jcp.id,
                                      dj + (jcp.kd - 1) * (jcp.dilate_d + 1)
                                              - jcp.f_pad + 1)
                            - jcp.id;

                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    // First input row / depth actually read: skip the padded
                    // area, rounded up to the dilation grid.
                    const int ih = nstl::max(ij - jcp.t_pad
                                    + div_up(i_t_overflow, (jcp.dilate_h + 1))
                                            * (jcp.dilate_h + 1),
                            0);

                    const int id = nstl::max(dj - jcp.f_pad
                                    + div_up(d_t_overflow, (jcp.dilate_d + 1))
                                            * (jcp.dilate_d + 1),
                            0);

                    par_conv.src = &src[src_blk_off(src_d, n, _ic, id, ih, 0)];

                    par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];

                    // First filter row / depth that overlaps valid input.
                    const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
                    const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
                    par_conv.filt = &weights[wht_blk_off(
                            weights_d, g, ocb, icb, wd, wh, 0)];

                    if (icb == 0) {
                        // Bias is passed only with the first ic block.
                        if (bias)
                            par_conv.bias = &bias[bias_d.blk_off(
                                    _oc * oc_bias_scale)];

                        par_conv.flags |= FLAG_IC_FIRST;
                    }

                    // Mark the last ic block so the kernel can apply
                    // eltwise/binary post-ops on the finished accumulation.
                    if ((jcp.with_eltwise || jcp.with_binary)
                            && icb + 1 == jcp.nb_ic)
                        par_conv.flags |= FLAG_IC_LAST;

                    // Number of input channels in this block (handles the
                    // tail when ic is not a multiple of ic_block).
                    par_conv.reduce_work = this_block_size(
                            icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                    // Number of oc blocks in this chunk (tail-safe).
                    par_conv.oc_blocks
                            = nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;

                    if (ocbb == ocb_work - 1) par_conv.oc_flag |= FLAG_OC_LAST;

                    par_conv.kw_padding = 0;
                    // Filter rows / depths that overlap valid (non-padded)
                    // input, counted on the dilation grid.
                    const int kh_padding = jcp.kh
                            - div_up(i_t_overflow, (jcp.dilate_h + 1))
                            - div_up(i_b_overflow, (jcp.dilate_h + 1));
                    par_conv.kh_padding = nstl::max(0, kh_padding);

                    const int kd_padding = jcp.kd
                            - div_up(d_t_overflow, (jcp.dilate_d + 1))
                            - div_up(d_b_overflow, (jcp.dilate_d + 1));
                    par_conv.kd_padding = nstl::max(0, kd_padding);

                    // Logical oc offset (in channels) for binary post-ops.
                    par_conv.oc_l_off = _oc * oc_bias_scale;
                    par_conv.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    par_conv.dst_orig = dst;

                    (*kernel_)(&par_conv);
                }
                nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, od,
                        jcp.od, oh, jcp.oh);
            }
            icbb += icb_step;
        }
    };

    // When oc is not a multiple of oc_block, copy the user bias into a
    // zero-padded scratchpad buffer so the kernel can read whole blocks.
    if (pd()->wants_padded_bias()) {
        auto padded_bias = ctx.get_scratchpad_grantor().get<data_t>(
                key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, ker);

    // Zero dst's physical channel padding if the descriptor requires it.
    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}
192 | |
// Backward-by-data driver. Threads share (mb x groups x ic-block chunks x
// ih blocks); within a work unit the driver loops over oc chunks and the
// full input depth, mapping each diff_src row (ih, id) back to the diff_dst
// rows that contribute to it.
void jit_avx2_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = kernel_->jcp;

    int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
    // Start with one ih block covering the whole height; may be refined to
    // single rows below.
    int ih_block_size = jcp.ih;
    int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
    size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;

    const auto data_size = sizeof(data_t);
    const auto L2 = platform::get_per_core_cache_size(2) / data_size;
    // input + output + weights per iteration by nb_oc_blocking
    auto ic_chunk = jcp.nb_ic_blocking * jcp.ic_block;
    auto oc_chunk = jcp.nb_oc_blocking * jcp.oc_block;
    auto iter_data_amount = (size_t)jcp.id * jcp.ih * jcp.iw * ic_chunk
            + (size_t)jcp.od * jcp.oh * jcp.ow * oc_chunk
            + (size_t)jcp.kd * jcp.kh * jcp.kw * ic_chunk * oc_chunk;

    // Switch to row-granular ih blocking when there is too little
    // parallelism or the per-iteration working set exceeds L2.
    if (work_amount < (size_t)2 * jcp.nthr || iter_data_amount > L2) {
        ih_block_size = 1;
        num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
        work_amount *= num_ih_blocks;
    }

    // Effective filter extents including dilation.
    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    // Blocked layouts count channel offsets in blocks; plain/nxc layouts
    // count them in individual channels (see execute_forward).
    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    const bool is_ddst_layout_nxc = one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const int oc_step = is_ddst_layout_nxc ? jcp.nb_oc_blocking : 1;

    auto ker = [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n {0}, g {0}, icbb {0}, ihb {0};
        nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                num_ih_blocks);
        for (size_t iwork = start; iwork < end; ++iwork) {
            for_(int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
            for (int id = 0; id < jcp.id; ++id) {
                // oc-block tail for the last chunk.
                int cur_nb_oc = nstl::min(jcp.nb_oc - oc, jcp.nb_oc_blocking);

                auto par_conv = jit_conv_call_s();

                // Depth mapping: number of filter taps falling off the front
                // (d_t_overflow) / back (d_b_overflow) of diff_dst, and the
                // first contributing output depth od.
                int d_t_overflow, d_b_overflow, od;
                if (jcp.dilate_d != 0) { // stride == 1
                    const int dilate_d = jcp.dilate_d + 1;
                    d_t_overflow
                            = div_up(nstl::max(0, ext_kd - 1 - id - jcp.f_pad),
                                    dilate_d);
                    d_b_overflow = div_up(
                            nstl::max(0, ext_kd - jcp.id + id - jcp.back_pad),
                            dilate_d);
                    od = id + jcp.f_pad - d_b_overflow * dilate_d;
                } else {
                    d_t_overflow = nstl::max(0, jcp.kd - 1 - id - jcp.f_pad);
                    d_b_overflow = nstl::max(
                            0, jcp.kd - 1 - (jcp.id - 1 - id) - jcp.back_pad);
                    od = id + jcp.f_pad - d_b_overflow;
                }
                par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;

                int ih_start = ihb * ih_block_size;
                int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
                for (int ih = ih_start; ih < ih_end; ++ih) {

                    // Height mapping: k_lo = first contributing filter row,
                    // oh = corresponding first output row.
                    int k_lo, oh;
                    if (jcp.dilate_h != 0) { // stride == 1
                        const int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(
                                nstl::max(0, ext_kh - 1 - ih - jcp.t_pad),
                                dilate_h);
                        int i_b_overflow = div_up(
                                nstl::max(0, ext_kh - jcp.ih + ih - jcp.b_pad),
                                dilate_h);
                        par_conv.kh_padding
                                = jcp.kh - i_t_overflow - i_b_overflow;
                        k_lo = i_b_overflow;
                        oh = ih + jcp.t_pad - k_lo * dilate_h;
                    } else {
                        // Strided case: only filter rows kh with
                        // (ih + t_pad - kh) divisible by stride_h map onto an
                        // output row; count those that land inside diff_dst.
                        int i_t_overflow = nstl::max(0,
                                (jcp.kh - 1 - ih - jcp.t_pad) / jcp.stride_h);
                        int i_b_overflow = nstl::max(0,
                                (jcp.kh - jcp.ih + ih - jcp.b_pad)
                                        / jcp.stride_h);
                        int overflow_kh_hi = jcp.kh - 1
                                - modulo(jcp.ih - 1 + jcp.b_pad - ih,
                                        jcp.stride_h);
                        int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;

                        par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
                                        / jcp.stride_h
                                + 1 - i_t_overflow - i_b_overflow;

                        k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                        oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
                    }
                    par_conv.kw_padding = 0;

                    // Note: par_conv.src is diff_src (the output of this
                    // pass) and par_conv.dst is diff_dst (the input).
                    par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
                            g * g_ic_offset
                                    + jcp.nb_ic_blocking * icbb * icb_ic_scale,
                            id, ih, 0)];
                    par_conv.dst = &diff_dst[src_blk_off(diff_dst_d, n,
                            g * g_oc_offset + ocb_oc_scale * oc, od, oh, 0)];
                    par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
                            jcp.nb_ic_blocking * icbb, d_b_overflow, k_lo, 0)];

                    par_conv.src_prf = nullptr;
                    par_conv.dst_prf = nullptr;
                    par_conv.filt_prf = nullptr;
                    par_conv.channel = oc;
                    par_conv.ch_blocks = cur_nb_oc;

                    if (is_ddst_layout_nxc) {
                        // Channel tails in plain layout: load_work counts ic
                        // to write, reduce_work counts oc to accumulate over.
                        par_conv.load_work = this_block_size(
                                icbb * jcp.nb_ic_blocking * jcp.ic_block,
                                (size_t)jcp.ic,
                                jcp.nb_ic_blocking * jcp.ic_block);
                        par_conv.reduce_work
                                = this_block_size(oc * jcp.oc_block, jcp.oc,
                                        oc_step * jcp.oc_block);

                        // Signal a partial ic block to the kernel.
                        if (par_conv.load_work % jcp.ic_block > 0)
                            par_conv.flags |= FLAG_IC_LAST;
                    }

                    (*kernel_)(&par_conv);
                }
            }
            nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
                    num_ih_blocks);
        }
    };

    parallel(jcp.nthr, ker);
}
349 | |
// Backward-by-weights driver. The minibatch (x output depth) is the
// reduction dimension: each thread accumulates partial diff_weights /
// diff_bias into per-thread scratchpad buffers managed by the reducers
// (reducer_weights_ / reducer_bias_), which are then reduced into the user
// buffers — in-kernel when the threading runtime is syncable, otherwise in
// separate parallel passes via reduce_nolock.
void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;

    // With an oc tail, accumulate bias into a padded scratchpad buffer and
    // copy the valid part back to the user buffer at the end.
    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    // Reducers own the per-thread partial buffers and the balancing of
    // (jobs x reduction) across threads.
    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    // Blocked layouts count channel offsets in blocks; plain/nxc layouts
    // count them in individual channels (see execute_forward).
    bool is_ic_physically_blocked = one_of(jcp.src_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    int g_ic_offset = is_ic_physically_blocked ? jcp.nb_ic : jcp.ic;
    int icb_ic_scale = is_ic_physically_blocked ? 1 : jcp.ic_block;

    bool is_oc_physically_blocked = one_of(jcp.dst_tag, format_tag::nCw8c,
            format_tag::nChw8c, format_tag::nCdhw8c);
    bool is_ddst_layout_nxc = !is_oc_physically_blocked;
    int g_oc_offset = is_oc_physically_blocked ? jcp.nb_oc : jcp.oc;
    int ocb_oc_scale = is_oc_physically_blocked ? 1 : jcp.oc_block;

    // Accumulate partial diff_weights: each thread owns a slice of the
    // (g, ocb, icb) job space and a slice of the (mb, od) reduction space.
    auto ker = [&](int ithr, int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        const int w_njobs = rw->balancer().ithr_njobs(ithr);

        if (w_njobs == 0) return;

        /* reduction dimension */
        int img_od_start {0}, img_od_end {0}, img {0}, od_s {0};
        balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), img_od_start, img_od_end);

        int img_start = img_od_start, img_end = img_od_end;
        nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
        const int img_first = img;

        /* jobs */
        int g_start {0}, ocb_start {0}, icb_start {0};
        nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
                jcp.nb_oc, icb_start, jcp.nb_ic);

        while (img_start < img_end) {
            int g = g_start, ocb = ocb_start, icb = icb_start;

            // Output depths of this image handled in this pass.
            const int work_rem = img_end - img_start;
            const int od_e
                    = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
            const int id_s = od_s * jcp.stride_d;
            const int idp = jcp.id + jcp.f_pad + jcp.back_pad;

            // Skip depths whose input start lies past the last valid filter
            // placement.
            if (id_s < idp - jcp.back_pad - jcp.kd + 1)
                for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
                    const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;
                    const size_t _ic = g * g_ic_offset + icb * icb_ic_scale;

                    /* TODO: put dw <-- 0 in kernel */
                    // Zero this thread's local diff_weights slice before the
                    // first image is accumulated into it.
                    if (img == img_first)
                        array_set(rw->get_local_ptr(ithr, diff_weights,
                                          reducer_wei_scratchpad)
                                        + w_job_loc * rw->balancer().job_size_,
                                0, rw->balancer().job_size_);

                    for (int od = od_s; od < od_e; ++od) {
                        const int id = od * jcp.stride_d;
                        if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;

                        auto par_conv = jit_conv_call_s();
                        par_conv.src
                                = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
                        par_conv.dst = &diff_dst[src_blk_off(
                                diff_dst_d, img, _oc, od, 0, 0)];
                        // Accumulate into the thread-local slice, not the
                        // user buffer; the reducer combines slices later.
                        par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
                                               reducer_wei_scratchpad)
                                + w_job_loc * rw->balancer().job_size_;

                        if (ocb == jcp.nb_oc - 1)
                            par_conv.flags |= FLAG_OC_LAST;

                        // ic count for this block (handles the ic tail).
                        par_conv.channel = this_block_size(
                                icb * jcp.ic_block, jcp.ic, jcp.ic_block);

                        (*kernel_)(&par_conv);
                    }
                    nd_iterator_step(
                            g, jcp.ngroups, ocb, jcp.nb_oc, icb, jcp.nb_ic);
                }
            nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
        }

        // With a syncable runtime the reduction happens right here, inside
        // the same parallel region.
        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

    // Accumulate partial diff_bias: sum diff_dst over the whole spatial
    // volume for each (g, ocb) job, per thread-local buffer.
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const size_t _oc = g * g_oc_offset + ocb * ocb_oc_scale;

                const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

                // Zero the local accumulator before the first image.
                if (img == img_start)
                    for (int o = 0; o < jcp.oc_block; ++o)
                        d_bias[o] = 0.;

                // oc count for this block (handles the oc tail).
                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);

                for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    // Spatial stride differs by layout: nxc interleaves all
                    // channels per pixel, blocked keeps oc_block together.
                    d_dst += is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                                : jcp.oc_block;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    if (dnnl_thr_syncable()) {
        // Syncable runtime: run weights and bias accumulation (including
        // in-region reduction) in a single parallel section.
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
        // Non-syncable runtime: accumulate and reduce in separate parallel
        // passes, using the lock-free reduction.
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    // Copy the valid (unpadded) channels of each group's bias from the
    // padded scratchpad back to the user buffer.
    if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)) {
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g)
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
    }
}
555 | |
556 | } // namespace x64 |
557 | } // namespace cpu |
558 | } // namespace impl |
559 | } // namespace dnnl |
560 | |
561 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
562 | |