1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | |
24 | #include "cpu/x64/jit_avx512_common_1x1_convolution.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | |
// Offset of element (n, c, d, h, w) inside memory descriptor `f`, dropping
// the spatial dimensions that do not exist for the current `ndims`
// (3D: n,c,w; 4D: n,c,h,w; 5D: n,c,d,h,w). Relies on an `ndims` variable
// being in scope at the expansion site.
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
39 | /* convolution forward */ |
40 | |
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward(const exec_ctx_t &ctx) const {
    const auto &jcp = kernel_->jcp;
    // Fetch all memory arguments from the execution context, including the
    // optional weights/bias of a fused depthwise-convolution post-op.
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const dst_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    // Right-hand-side pointers for binary post-ops. For the fused dw conv the
    // rhs arguments start after the 1x1 conv's own post-op entries, hence the
    // `entry_.size() + 1` offset.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_
            ? binary_injector::prepare_binary_args(
                    pd()->dw_conv_pd_->jcp_.post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    auto scratchpad = ctx.get_scratchpad_grantor();

    // When OC is padded up to the block size, copy the user bias into a
    // scratchpad buffer and zero the padded tail so the kernel can safely
    // read past oc_without_padding.
    if (pd()->wants_padded_bias()) {
        auto padded_bias
                = scratchpad.template get<dst_data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    // Each thread computes its own balanced share of the work; see
    // execute_forward_thr for the per-thread driver.
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });

    // Clear physical padding of dst if the kernel may have left it dirty.
    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}
80 | |
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward_thr(const int ithr, const int nthr,
        const src_data_t *src, const wei_data_t *weights,
        const dst_data_t *bias, const wei_data_t *weights_dw,
        const dst_data_t *bias_dw, dst_data_t *dst,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    // Per-thread driver for the 1x1 convolution. The work is a 3D blocked
    // iteration space: bcast (mb x groups x spatial blocks), load (OC
    // blocks), and reduce (IC blocks); the nesting order is chosen at
    // configuration time via jcp.loop_order.
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    // "Reduce src" (rtus) scratch: a per-thread dense copy of the strided
    // src rows so the kernel can read a contiguous bcast buffer.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
            : nullptr;

    const int ndims = src_d.ndims();
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    // Advance by default_step, but take the whole remainder in one go when it
    // is smaller than tail_step (avoids a tiny trailing block).
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_core>::call_params_t();

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    // override some constants for fused dw_conv
    // (with dw conv, the 1x1 bcast granularity becomes one output row so the
    // dw conv can consume complete rows from the intermediate buffer).
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    // Begin: declare Variables needed for dw conv.
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);
    dst_data_t *pbuf; // ring buffer of 1x1 output rows feeding the dw conv
    size_t row_offset; // element stride between consecutive rows in pbuf
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<dst_data_t *> addrs;
    // End

    // Decode a linear bcast work index into (n, g, os-block) coordinates,
    // derive the output/input spatial origin, and set the bcast-related
    // kernel and rtus call parameters.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Set the load-dimension (OC) parameters for the range
    // [ocb, ocb + load_step), clipped to the unpadded number of channels.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        const auto max_oc
                = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding);
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block);
    };

    // Set the reduce-dimension (IC) parameters plus the first/last flags the
    // kernel uses to decide when to initialize the accumulator and when to
    // apply post-ops.
    auto init_reduce = [&](int icb) {
        const int nb_ic_blocking_step
                = nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block);
        rp.icb = p.reduce_dim;
    };

    // Run the 1x1 kernel on one (ocb, icb, bcast) tile. When fused with a dw
    // conv, the output row goes into the intermediate ring buffer `pbuf`
    // (indexed modulo the dw kernel height) instead of dst.
    auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;
        const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow);

        p.output_data = jcp.with_dw_conv
                ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset
                : &dst[dst_off];
        p.bias_data = bias
                ? &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)]
                : nullptr;

        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];
        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);
            // The densified src is reused across ocb iterations — only run
            // the rtus driver for the first ocb of this tile.
            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        p.dst_l_off = dst_off;
        p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    // Drive ker_1x1 over the [bcast_start, bcast_end) x [ocb_start, ocb_end)
    // tile in the configured loop order: r = reduce (IC), l = load (OC),
    // b = bcast (mb/spatial) — e.g. loop_rlb nests reduce > load > bcast.
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;

        if (jcp.loop_order == loop_rlb) {
            for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                init_reduce(icb);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    int iwork = bcast_start;
                    while (iwork < bcast_end) {
                        int n {0}, g {0}, bcast_step {0}, od {0}, oh {0},
                                ow {0}, id {0}, ih {0}, iw {0};
                        init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh,
                                ow, id, ih, iw);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                        iwork += bcast_step;
                    }
                    ocb += load_step;
                }
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                init_reduce(icb);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    int ocb = ocb_start;
                    while (ocb < ocb_end) {
                        int load_step;
                        init_load(ocb, ocb_end, load_step);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                        ocb += load_step;
                    }
                    iwork += bcast_step;
                }
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order" );
        }
    };

    // Run the fused depthwise kernel for one dw output row `dw_oh`,
    // consuming up to jcp_dw.kh rows of 1x1 output from the ring buffer.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;
        int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0);

        // Collect the kh input-row addresses out of the modular ring buffer.
        for (int i = 0; i < jcp_dw.kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw.iw)
                * jcp_dw.nb_ch_blocking * jcp_dw.ch_block;
        const int dil_h = jcp_dw.dilate_h + 1;
        const int str_h = jcp_dw.stride_h;
        const int ch_num = jcp_dw.nb_ch_blocking;
        const int ow = 0;
        const int kw = 0;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) {

            // Filter rows that fall into the top/bottom padding for this
            // output row; kh/kh_padding restrict the kernel to valid rows.
            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw.ih,
                              (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h
                                      - jcp_dw.t_pad + 1))
                    - jcp_dw.ih;

            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw.ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            // NOTE(review): the guard tests the 1x1 `bias`, not `bias_dw` —
            // looks intentional (both exist together), but verify.
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch)
                    * jcp_dw.ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw.ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance to the next channel block inside each buffered row.
            for (int i = 0; i < jcp_dw.kh; ++i)
                addrs[i] += wch_stride;
        }
    };

    // Fused 1x1 + depthwise path: for every dw output row, first compute just
    // the 1x1 output rows it needs into the per-thread ring buffer, then run
    // the dw kernel on them. Rows already produced for the previous dw row
    // are skipped via oh_1x1 bookkeeping.
    auto conv_dw = [&]() {
        // Set variables
        auto dw_conv_buffer
                = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer);
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw.kh;
        addrs.resize(jcp_dw.kh);

        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n {0}, g {0}, oh_dw {0};
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw.oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1 path: split (mb x groups x spatial blocks) and OC blocks
        // across threads, then run the blocked driver on this thread's share.
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
                ocb_start, ocb_end, jcp.load_grp_count);

        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}
442 | |
// Explicit instantiation for the f32 forward primitive (registered only in
// builds that enable the AVX-512 ISA).
REG_AVX512_ISA(template struct jit_avx512_common_1x1_convolution_fwd_t<
        data_type::f32>);
/* convolution backward w.r.t. data */
446 | |
template <data_type_t diff_dst_type, data_type_t wei_type,
        data_type_t diff_src_type>
void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
    // Backward-by-data pass: diff_src = diff_dst (x) weights. Implemented
    // with the same 1x1 kernel, with the load/reduce roles swapped relative
    // to forward: load = IC (this pass's output channels), reduce = OC.
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;
    // Dense per-thread scratch for "reduce src" handling of strided cases.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? ctx.get_scratchpad_grantor().template get<diff_src_data_t>(
                    key_conv_rtus_space)
            : nullptr;

    const int ndims = diff_src_d.ndims();

    // Only unit strides are supported here; the stride_* values below are
    // still computed for spatial-offset arithmetic.
    assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    // Advance by default_step, or take the whole remainder when smaller than
    // tail_step (avoids a tiny trailing block).
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx512_core>::call_params_t();

        int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
                icb_start, icb_end, jcp.load_grp_count);

        // The loop order decides whether the OC (reduce) loop runs outermost
        // or innermost; the unused direction degenerates to a single trip.
        bool reduce_outer
                = (jcp.loop_order == loop_rbl || jcp.loop_order == loop_rlb);
        int nboc_outer = reduce_outer ? nb_oc : 1;
        int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;

        int nboc_inner = reduce_outer ? 1 : nb_oc;
        int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;
        const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic);

        for (int ocb_outer = 0; ocb_outer < nboc_outer;
                ocb_outer += ocb_outer_step) {
            size_t cur_ocb_outer
                    = nstl::min(ocb_outer + ocb_outer_step, nboc_outer)
                    - ocb_outer;

            int load_step = 0;
            for (int icb = icb_start; icb < icb_end; icb += load_step) {
                load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                        jcp.nb_load_blocking_max);

                p.load_dim = this_block_size(
                        icb * jcp.ic_block, max_ic, load_step * jcp.ic_block);
                rp.icb = p.load_dim;

                int bcast_step;
                for (int iwork = bcast_start; iwork < bcast_end;
                        iwork += bcast_step) {
                    int n {0}, g {0}, osb {0};
                    nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
                            jcp.nb_bcast);

                    bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                            jcp.nb_bcast_blocking_max);
                    bcast_step = nstl::min(bcast_step, bcast_end - iwork);

                    const int os = osb * os_block;
                    p.bcast_dim = this_block_size(
                            os, jcp.os, bcast_step * os_block);
                    rp.os = p.bcast_dim;

                    // Decompose the linear spatial index into (od, oh, ow)
                    // and map to input coordinates via the strides.
                    const int od = os / (jcp.oh * jcp.ow);
                    const int os_2d = os % (jcp.oh * jcp.ow);
                    const int oh = os_2d / jcp.ow;
                    const int ow = os_2d % jcp.ow;
                    const int id = od * stride_d;
                    const int ih = oh * stride_h;
                    const int iw = ow * stride_w;
                    rp.iw_start = iw;
                    const bool is_dsrc_layout_nxc
                            = utils::one_of(jcp.src_tag, format_tag::nwc,
                                    format_tag::nhwc, format_tag::ndhwc);
                    const int ic_off_idx = is_dsrc_layout_nxc
                            ? g * jcp.ic + icb * jcp.ic_block
                            : g * nb_ic + icb;
                    rp.src = diff_src
                            + data_blk_off(
                                    diff_src_d, n, ic_off_idx, id, ih, iw);
                    // With reduced src, the kernel writes into the dense ws
                    // buffer; the rtus driver (run after the OC loop below)
                    // then transfers it to the strided diff_src.
                    if (pd()->rtus_.reduce_src_) {
                        rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_;
                        p.output_data = rp.ws;
                    } else
                        p.output_data = rp.src;

                    for (int ocb_inner = 0; ocb_inner < nboc_inner;
                            ocb_inner += ocb_inner_step) {
                        int cur_ocb_inner
                                = nstl::min(ocb_inner + ocb_inner_step,
                                          nboc_inner)
                                - ocb_inner;

                        int ocb = reduce_outer ? ocb_outer : ocb_inner;
                        int nb_oc_blocking_step
                                = reduce_outer ? cur_ocb_outer : cur_ocb_inner;
                        const bool is_ddst_layout_nxc
                                = utils::one_of(jcp.dst_tag, format_tag::nwc,
                                        format_tag::nhwc, format_tag::ndhwc);
                        const int oc_off_idx = is_ddst_layout_nxc
                                ? g * jcp.oc + ocb * jcp.oc_block
                                : g * nb_oc + ocb;
                        size_t diff_dst_off = data_blk_off(
                                diff_dst_d, n, oc_off_idx, od, oh, ow);
                        p.bcast_data = &diff_dst[diff_dst_off];

                        p.load_data = &weights[pd()->with_groups()
                                        ? weights_d.blk_off(g, ocb, icb)
                                        : weights_d.blk_off(ocb, icb)];

                        // First OC block initializes the accumulator; later
                        // blocks accumulate into it.
                        p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                        p.reduce_dim = this_block_size(ocb * jcp.oc_block,
                                jcp.oc, nb_oc_blocking_step * jcp.oc_block);

                        (*kernel_)(&p);
                    }
                    if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
                }
            }
        }
    });
}
594 | |
// Explicit instantiation for the f32 backward-data primitive (registered only
// in builds that enable the AVX-512 ISA).
REG_AVX512_ISA(template struct jit_avx512_common_1x1_convolution_bwd_data_t<
        data_type::f32>);
597 | |
/* convolution backward w.r.t. weights */
599 | |
// Weights-offset helper: prepend the group index `g` only when the
// convolution uses groups; otherwise offset by the remaining indices alone.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
603 | |
status_t jit_avx512_common_1x1_convolution_bwd_weights_t ::init(
        engine_t *engine) {
    // Build all JIT kernels/helpers for the backward-weights pass; CHECK()
    // propagates the failing status immediately.
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx512_common_1x1_conv_kernel(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
    // f32 accumulator used to sum per-thread partial weight diffs.
    CHECK(safe_ptr_assign(
            acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
    // Reducer handling the bias diff accumulation.
    CHECK(safe_ptr_assign(reducer_bias_,
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
    // Generate the machine code for each of the above.
    CHECK(kernel_->create_kernel());
    CHECK(acc_ker_->create_kernel());
    CHECK(reducer_bias_->create_kernel());

    // "Reduce src" driver that densifies strided src for the 1x1 kernel.
    CHECK(init_rtus_driver<avx512_core>(this));
    return status::success;
}
620 | |
621 | void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights( |
622 | const exec_ctx_t &ctx) const { |
623 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
624 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
625 | auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS); |
626 | auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS); |
627 | |
628 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
629 | const memory_desc_wrapper src_d(pd()->src_md()); |
630 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
631 | |
632 | const auto &jcp = kernel_->jcp; |
633 | |
634 | const auto scratchpad = ctx.get_scratchpad_grantor(); |
635 | |
636 | auto rtus_space = pd()->rtus_.reduce_src_ |
637 | ? scratchpad.get<data_t>(key_conv_rtus_space) |
638 | : nullptr; |
639 | const bool is_bias_padded |
640 | = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0; |
641 | |
642 | data_t *diff_bias = is_bias_padded |
643 | ? scratchpad.get<data_t>(key_conv_padded_bias) |
644 | : diff_bias_in; |
645 | auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction); |
646 | |
647 | const int ndims = src_d.ndims(); |
648 | const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) |
649 | * rnd_up(jcp.ic, jcp.ic_block); |
650 | |
651 | simple_barrier::ctx_t reduction_barrier; |
652 | simple_barrier::ctx_init(&reduction_barrier); |
653 | |
654 | const auto reducer_bia_scratchpad |
655 | = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia); |
656 | auto rb = this->reducer_bias_.get(); |
657 | rb->init(reducer_bia_scratchpad); |
658 | |
659 | // TODO (Roma): remove this restriction |
660 | assert(jcp.stride_w == 1 && jcp.stride_h == 1); |
661 | |
662 | const int nb_ic = jcp.nb_bcast; |
663 | const int nb_ic_blocking = jcp.nb_bcast_blocking; |
664 | |
665 | const int nb_oc = jcp.nb_load; |
666 | const int nb_oc_blocking = jcp.nb_load_blocking; |
667 | |
668 | const int sp_nb = jcp.nb_reduce; |
669 | const int mb_sp_work = jcp.mb * sp_nb; |
670 | |
671 | const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; |
672 | const int stride_w = pd()->desc()->strides[ndims - 3]; |
673 | |
674 | auto step = [](int default_step, int remaining, int tail_step) { |
675 | assert(default_step <= tail_step); |
676 | return remaining < tail_step ? remaining : default_step; |
677 | }; |
678 | |
679 | const bool is_src_layout_nxc = utils::one_of( |
680 | jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
681 | const bool is_ddst_layout_nxc = utils::one_of( |
682 | jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
683 | |
684 | auto maybe_zero_icpad = [&](const int g_start, const int g_end, |
685 | const int ocb_start, const int ocb_end) { |
686 | // write zeros to IC padded region. |
687 | const int ic_tail = jcp.ic_without_padding % jcp.ic_block; |
688 | if (is_ddst_layout_nxc && ic_tail != 0) { |
689 | for_(int g = g_start; g < g_end; ++g) |
690 | for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) { |
691 | const int z_icb = nb_ic - 1; |
692 | const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb) |
693 | + ic_tail * jcp.oc_block; |
694 | data_t *z_wei = diff_weights + off; |
695 | const int zero_work |
696 | = (nb_ic * jcp.ic_block - jcp.ic_without_padding) |
697 | * jcp.oc_block; |
698 | PRAGMA_OMP_SIMD() |
699 | for (int o = 0; o < zero_work; ++o) { |
700 | z_wei[o] = 0; |
701 | } |
702 | } |
703 | } |
704 | }; |
705 | |
706 | auto ker = [&](const int ithr, const int nthr) { |
707 | assert(nthr == jcp.nthr); |
708 | |
709 | const int ithr_ic_b = ithr % jcp.nthr_ic_b; |
710 | const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; |
711 | const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; |
712 | const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; |
713 | |
714 | /* reduction dimension */ |
715 | int mb_sp_b_start {0}, mb_sp_b_end {0}; |
716 | balance211( |
717 | mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end); |
718 | |
719 | /* independent dimensions */ |
720 | int g_start {0}, oc_b_start {0}, ic_b_start {0}; |
721 | int g_end {0}, oc_b_end {0}, ic_b_end {0}; |
722 | |
723 | balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); |
724 | balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); |
725 | balance211( |
726 | jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); |
727 | |
728 | const int g_work = g_end - g_start; |
729 | const int oc_b_work = oc_b_end - oc_b_start; |
730 | const int ic_b_work = ic_b_end - ic_b_start; |
731 | const bool cache_aliasing |
732 | = (jcp.ic * jcp.ngroups * sizeof(float)) % 1024 == 0; |
733 | int reduce_step = jcp.nb_reduce_blocking; |
734 | int reduce_step_max = jcp.nb_reduce_blocking_max; |
735 | if (is_src_layout_nxc && cache_aliasing) { |
736 | // Experiments show 4 is a magic number with the tested shapes. |
737 | // TODO: maybe tune for shapes with sp_dim%4 != 0 |
738 | reduce_step = nstl::min(4, reduce_step); |
739 | reduce_step_max = reduce_step; |
740 | } |
741 | |
742 | data_t *diff_wei = ithr_mb == 0 |
743 | ? diff_weights |
744 | : wei_reduction + (ithr_mb - 1) * wei_size; |
745 | |
746 | int sp_b_step = 0; |
747 | for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; |
748 | mb_sp_b += sp_b_step) { |
749 | int img {0}, sp_b {0}; |
750 | nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); |
751 | sp_b_step = step(reduce_step, |
752 | nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), |
753 | reduce_step_max); |
754 | |
755 | for (int g = g_start; g < g_end; ++g) { |
756 | int load_step = 0; |
757 | int bcast_step = 0; |
758 | for (int ic_b = ic_b_start; ic_b < ic_b_end; |
759 | ic_b += bcast_step) { |
760 | if (is_src_layout_nxc && cache_aliasing) { |
761 | bcast_step = ic_b_work; |
762 | } else { |
763 | bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, |
764 | jcp.nb_bcast_blocking_max); |
765 | } |
766 | |
767 | for (int oc_b = oc_b_start; oc_b < oc_b_end; |
768 | oc_b += load_step) { |
769 | load_step = step(nb_oc_blocking, oc_b_end - oc_b, |
770 | jcp.nb_load_blocking_max); |
771 | const int _ic_b = g * nb_ic + ic_b; |
772 | const int oc_off_idx = is_ddst_layout_nxc |
773 | ? g * jcp.oc + oc_b * jcp.oc_block |
774 | : g * nb_oc + oc_b; |
775 | |
776 | data_t *store_to; |
777 | |
778 | const size_t off |
779 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
780 | store_to = diff_wei + off; |
781 | |
782 | const int ic_off_idx |
783 | = (is_src_layout_nxc ? jcp.ic_block : 1) |
784 | * _ic_b; |
785 | const data_t *diff_src |
786 | = &src[src_d.blk_off(img, ic_off_idx)]; |
787 | |
788 | int sp_b_end = sp_b + sp_b_step; |
789 | const data_t *pdiff_dst = &diff_dst[diff_dst_d.blk_off( |
790 | img, oc_off_idx)]; |
791 | const data_t *local_src = diff_src; |
792 | |
793 | auto p = jit_1x1_conv_call_s(); |
794 | auto rp = rtus_driver_t<avx512_core>::call_params_t(); |
795 | |
796 | p.output_stride = utils::rnd_up(jcp.ic, jcp.ic_block) |
797 | * jcp.oc_block * jcp.typesize_out; |
798 | |
799 | p.load_dim = this_block_size(oc_b * jcp.oc_block, |
800 | jcp.oc, load_step * jcp.oc_block); |
801 | |
802 | p.bcast_dim = this_block_size(ic_b * jcp.ic_block, |
803 | jcp.ic, bcast_step * jcp.ic_block); |
804 | rp.icb = p.bcast_dim; |
805 | p.output_data = store_to; |
806 | |
807 | p.reduce_dim = sp_b_step * jcp.reduce_block; |
808 | rp.os = p.reduce_dim; |
809 | |
810 | p.first_last_flag = 0 |
811 | | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST |
812 | : 0) |
813 | | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0); |
814 | |
815 | int sp = sp_b * jcp.reduce_block; |
816 | int oc_mult |
817 | = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block; |
818 | p.load_data = pdiff_dst + sp * oc_mult; |
819 | |
820 | if (pd()->rtus_.reduce_src_) { |
821 | const int oh = sp / jcp.ow; |
822 | const int ow = sp % jcp.ow; |
823 | |
824 | const int ih = oh * stride_h; |
825 | const int iw = ow * stride_w; |
826 | rp.iw_start = iw; |
827 | |
828 | rp.ws = rtus_space |
829 | + ithr * pd()->rtus_.space_per_thread_ |
830 | + sp * jcp.ic_block; |
831 | |
832 | if (ndims == 3) |
833 | rp.src = local_src |
834 | + iw * src_d.blocking_desc().strides[2]; |
835 | else |
836 | rp.src = local_src |
837 | + ih * src_d.blocking_desc().strides[2] |
838 | + iw * src_d.blocking_desc().strides[3]; |
839 | (*rtus_driver_)(&rp); |
840 | |
841 | p.bcast_data = rp.ws; |
842 | } else { |
843 | int ic_mult |
844 | = is_src_layout_nxc ? jcp.ic : jcp.ic_block; |
845 | p.bcast_data = local_src + sp * ic_mult; |
846 | } |
847 | |
848 | (*kernel_)(&p); |
849 | } |
850 | } |
851 | } |
852 | } |
853 | |
854 | if (ithr_mb == 0 && ic_b_end >= jcp.nb_bcast) { |
855 | maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end); |
856 | } |
857 | |
858 | /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */ |
859 | if (dnnl_thr_syncable() && jcp.nthr_mb > 1) { |
860 | simple_barrier::barrier(&reduction_barrier, jcp.nthr); |
861 | const int work = g_work * oc_b_work * ic_b_work; |
862 | int start {0}, end {0}; |
863 | balance211(work, jcp.nthr_mb, ithr_mb, start, end); |
864 | if (start == end) return; |
865 | |
866 | for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { |
867 | int w = start; |
868 | int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; |
869 | nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, |
870 | oc_b_work, sub_ic_b_start, ic_b_work); |
871 | while (w < end) { |
872 | const int g = g_start + sub_g_start; |
873 | const int oc_b = oc_b_start + sub_oc_b_start; |
874 | const int ic_b = ic_b_start + sub_ic_b_start; |
875 | const int ic_to_accumulate |
876 | = nstl::min(end - w, ic_b_work - sub_ic_b_start) |
877 | * jcp.ic_block; |
878 | const int acc_size |
879 | = this_block_size(ic_b * jcp.ic_block, |
880 | jcp.ic_without_padding, ic_to_accumulate) |
881 | * jcp.oc_block; |
882 | |
883 | const size_t off |
884 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
885 | data_t *d = diff_weights + off; |
886 | data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off; |
887 | |
888 | acc_ker_->accumulate(d, s, acc_size); |
889 | |
890 | nd_iterator_jump(w, end, sub_g_start, g_work, |
891 | sub_oc_b_start, oc_b_work, sub_ic_b_start, |
892 | ic_b_work); |
893 | } |
894 | } |
895 | } |
896 | }; |
897 | |
    // Per-thread bias reduction: accumulates diff_dst over the spatial and
    // minibatch dimensions into a thread-local slice of diff_bias.
    // Work distribution comes from the reducer's balancer: each thread owns a
    // contiguous range of flattened (group, oc-load-block) jobs and a slice
    // of the minibatch to reduce over.
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        // This thread's first job index and job count in the balancer's
        // flattened (ngroups x nb_load) job space.
        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return; // nothing assigned to this thread

        /* reduction dimension: split the minibatch among the threads of this
         * reduction group */
        int img_start {0}, img_end {0};

        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs: decompose the flat job offset into (group, oc block) */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                // For the channels-last (nxc) layout the channel offset is
                // expressed in elements; for blocked layouts, in blocks.
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g * jcp.nb_load + ocb;
                const data_t *d_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];

                // Thread-private accumulation buffer for this job; jobs are
                // laid out consecutively, job_size_ elements apart.
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;
                // Element distance between consecutive spatial points of the
                // same channel in diff_dst (all channels interleave for nxc).
                const int sp_shift = is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                                        : jcp.oc_block;
                // Number of valid (non-padded) channels in this oc block.
                const auto max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);
                // First image of this thread's range: zero the accumulator.
                // The constant 16 clears the whole physical block including
                // any padded tail — assumes jcp.oc_block == 16 on avx512;
                // TODO(review): confirm and prefer jcp.oc_block here.
                if (img == img_start)
                    for (int o = 0; o < 16; ++o)
                        d_bias[o] = 0.;

                // Accumulate diff_dst over all spatial points of this image.
                for (int os = 0; os < jcp.os; ++os) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += sp_shift;
                }

                // Advance to the next (group, oc block) job.
                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
            }
        }

        // With syncable threading the cross-thread reduction happens here;
        // otherwise the caller invokes reduce_nolock in a separate parallel
        // region (see the non-syncable branch below in this function).
        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };
951 | |
952 | if (dnnl_thr_syncable()) { |
953 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
954 | ker(ithr, jcp.nthr); |
955 | if (pd()->with_bias()) ker_bias(ithr, jcp.nthr); |
956 | }); |
957 | } else { |
958 | parallel(jcp.nthr, [&](int ithr, int nthr) { ker(ithr, nthr); }); |
959 | if (jcp.nthr_mb > 1) |
960 | parallel(jcp.nthr, [&](int ithr, int nthr) { |
961 | assert(nthr == jcp.nthr); |
962 | |
963 | const int ithr_ic_b = ithr % jcp.nthr_ic_b; |
964 | const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; |
965 | const int ithr_g |
966 | = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; |
967 | const int ithr_mb |
968 | = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; |
969 | |
970 | /* independent dimensions */ |
971 | int g_start {0}, oc_b_start {0}, ic_b_start {0}; |
972 | int g_end {0}, oc_b_end {0}, ic_b_end {0}; |
973 | |
974 | balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); |
975 | balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, |
976 | oc_b_end); |
977 | balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, |
978 | ic_b_end); |
979 | |
980 | const int g_work = g_end - g_start; |
981 | const int oc_b_work = oc_b_end - oc_b_start; |
982 | const int ic_b_work = ic_b_end - ic_b_start; |
983 | |
984 | const int work = g_work * oc_b_work * ic_b_work; |
985 | int start {0}, end {0}; |
986 | balance211(work, jcp.nthr_mb, ithr_mb, start, end); |
987 | if (start == end) return; |
988 | |
989 | for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) { |
990 | int w = start; |
991 | int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; |
992 | nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, |
993 | oc_b_work, sub_ic_b_start, ic_b_work); |
994 | while (w < end) { |
995 | const int g = g_start + sub_g_start; |
996 | const int oc_b = oc_b_start + sub_oc_b_start; |
997 | const int ic_b = ic_b_start + sub_ic_b_start; |
998 | const int ic_to_accumulate |
999 | = nstl::min(end - w, ic_b_work - sub_ic_b_start) |
1000 | * jcp.ic_block; |
1001 | const int acc_size |
1002 | = this_block_size(ic_b * jcp.ic_block, |
1003 | jcp.ic_without_padding, |
1004 | ic_to_accumulate) |
1005 | * jcp.oc_block; |
1006 | |
1007 | const size_t off |
1008 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
1009 | data_t *d = diff_weights + off; |
1010 | data_t *s |
1011 | = wei_reduction + (thr_mb - 1) * wei_size + off; |
1012 | |
1013 | acc_ker_->accumulate(d, s, acc_size); |
1014 | |
1015 | nd_iterator_jump(w, end, sub_g_start, g_work, |
1016 | sub_oc_b_start, oc_b_work, sub_ic_b_start, |
1017 | ic_b_work); |
1018 | } |
1019 | } |
1020 | }); |
1021 | if (pd()->with_bias()) { |
1022 | parallel(jcp.nthr, |
1023 | [&](int ithr, int nthr) { ker_bias(ithr, nthr); }); |
1024 | parallel(jcp.nthr, [&](int ithr, int nthr) { |
1025 | assert(nthr == rb->balancer().nthr_); |
1026 | MAYBE_UNUSED(nthr); |
1027 | if (rb->balancer().ithr_njobs(ithr) == 0) return; |
1028 | rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad); |
1029 | }); |
1030 | } |
1031 | } |
1032 | |
1033 | /* TODO: put this in ker_bias */ |
1034 | if (is_bias_padded) { |
1035 | assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1)); |
1036 | const int padded_stride = rnd_up(jcp.oc, jcp.oc_block); |
1037 | const int stride = jcp.oc_without_padding; |
1038 | for (int g = 0; g < jcp.ngroups; ++g) { |
1039 | utils::array_copy(diff_bias_in + g * stride, |
1040 | diff_bias + g * padded_stride, stride); |
1041 | } |
1042 | } |
1043 | } |
1044 | |
1045 | } // namespace x64 |
1046 | } // namespace cpu |
1047 | } // namespace impl |
1048 | } // namespace dnnl |
1049 | |