1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_avx512_common_convolution.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | |
29 | using namespace dnnl::impl::status; |
30 | using namespace dnnl::impl::memory_tracking::names; |
31 | using namespace dnnl::impl::utils; |
32 | |
33 | using namespace nstl; |
34 | |
35 | using jit_conv_ker_t = void (*)(jit_conv_call_s *); |
36 | |
37 | inline void jit_conv_ker_pipeline(const jit_conv_ker_t ker, jit_conv_call_s &p, |
38 | const void *src, const void *dst, const void *filt, const void *bias, |
39 | int channel, int kh_padding, int reduce_work, int load_work) { |
40 | p.src = src; |
41 | p.dst = dst; |
42 | p.filt = filt; |
43 | p.bias = bias; |
44 | p.channel = channel; |
45 | // non-positive value of kh_padding is allowed, in this case kernel must |
46 | // skip computation part and initialize output by zeroes |
47 | p.kh_padding = kh_padding; |
48 | p.reduce_work = reduce_work; |
49 | p.load_work = load_work; |
50 | |
51 | ker(&p); |
52 | } |
53 | // The special case for the driver with iw-parallelization (BWD) |
54 | inline void jit_conv_ker_pipeline_iw_thr(const jit_conv_ker_t ker, |
55 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
56 | const void *bias, int channel, int kh_padding, int iwb, int reduce_work, |
57 | int load_work) { |
58 | p.iwb = iwb; |
59 | |
60 | jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, |
61 | reduce_work, load_work); |
62 | } |
63 | |
64 | inline void jit_conv_3d_ker_pipeline(const jit_conv_ker_t ker, |
65 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
66 | const void *bias, int channel, int kh_padding, int kd_padding, |
67 | int reduce_work, int load_work) { |
68 | p.src = src; |
69 | p.dst = dst; |
70 | p.filt = filt; |
71 | p.bias = bias; |
72 | p.channel = channel; |
73 | // non-positive value of both kd_padding and kh_padding is allowed, in this |
74 | // case kernel must skip computation part and initialize output by zeroes |
75 | p.kh_padding = kh_padding; |
76 | p.kd_padding = kd_padding; |
77 | p.reduce_work = reduce_work; |
78 | p.load_work = load_work; |
79 | |
80 | ker(&p); |
81 | } |
82 | // The special case for the driver with ow-parallelization (FWD) |
83 | inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p, |
84 | const void *src, const void *dst, const void *filt, const void *bias, |
85 | int channel, int kh_padding, int owb, int reduce_work, int load_work, |
86 | const void *post_ops_binary_rhs_arg_vec, int oc_l_off, |
87 | const void *dst_orig, int flags) { |
88 | p.owb = owb; |
89 | p.flags = flags; |
90 | |
91 | p.oc_l_off = oc_l_off; |
92 | p.dst_orig = dst_orig; |
93 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
94 | |
95 | jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, |
96 | reduce_work, load_work); |
97 | } |
98 | |
99 | // The special case for the driver with ow-parallelization (FWD) |
100 | // TODO: implement it for BWD_D and BWD_W too |
101 | inline void jit_conv_3d_ker_pipeline_ow_thr(const jit_conv_ker_t ker, |
102 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
103 | const void *bias, int channel, int kh_padding, int kd_padding, int owb, |
104 | int reduce_work, int load_work, const void *post_ops_binary_rhs_arg_vec, |
105 | int oc_l_off, const void *dst_orig, int flags) { |
106 | |
107 | p.oc_l_off = oc_l_off; |
108 | p.dst_orig = dst_orig; |
109 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
110 | |
111 | p.owb = owb; |
112 | p.flags = flags; |
113 | |
114 | jit_conv_3d_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, |
115 | kd_padding, reduce_work, load_work); |
116 | } |
117 | |
118 | inline void jit_conv_ker_pipeline_bwd_w(const jit_conv_ker_t ker, |
119 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
120 | const void *bias, int channel, int kh_padding, size_t reduce_work, |
121 | size_t load_work) { |
122 | jit_conv_ker_pipeline(ker, p, src, dst, filt, bias, channel, kh_padding, |
123 | reduce_work, load_work); |
124 | } |
125 | |
126 | void jit_conv_2d_ker_bwd_w_pipeline(const jit_conv_ker_t ker, |
127 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
128 | const void *bias, int channel, int os_index_begin, int os_index_end, |
129 | int kh_padding /* kh_work_size */, size_t kh_offset, size_t reduce_work, |
130 | size_t load_work) { |
131 | p.src = src; |
132 | p.dst = dst; |
133 | p.filt = filt; |
134 | p.bias = bias; |
135 | p.channel = channel; |
136 | p.os_index_begin = os_index_begin; |
137 | p.os_index_end = os_index_end; |
138 | // non-positive value of kh_padding is allowed, in this case kernel must |
139 | // skip kw loop computation and initialize output by zeroes |
140 | p.kh_padding = kh_padding; |
141 | p.kh_offset = kh_offset; |
142 | p.reduce_work = reduce_work; |
143 | p.load_work = load_work; |
144 | |
145 | ker(&p); |
146 | } |
147 | |
148 | void jit_conv_3d_ker_bwd_w_pipeline(const jit_conv_ker_t ker, |
149 | jit_conv_call_s &p, const void *src, const void *dst, const void *filt, |
150 | const void *bias, int channel, int os_index_begin, int os_index_end, |
151 | int kd_padding /* kd_work_size */, size_t kd_offset, size_t reduce_work, |
152 | size_t load_work) { |
153 | p.src = src; |
154 | p.dst = dst; |
155 | p.filt = filt; |
156 | p.bias = bias; |
157 | p.channel = channel; |
158 | p.os_index_begin = os_index_begin; |
159 | p.os_index_end = os_index_end; |
160 | // non-positive value of kd_padding is allowed, in this case kernel must |
161 | // skip kh loop computation and initialize output by zeroes |
162 | p.kd_padding = kd_padding; |
163 | p.kd_offset = kd_offset; |
164 | p.reduce_work = reduce_work; |
165 | p.load_work = load_work; |
166 | |
167 | ker(&p); |
168 | } |
// Weights offset helper: prepends the group index (g) to blk_off() only when
// the primitive descriptor reports grouped weights; otherwise g is dropped.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
172 | |
173 | template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type> |
174 | void jit_avx512_common_convolution_fwd_t<src_type, wei_type, |
175 | dst_type>::prepare_padded_bias(const dst_data_t *&bias, |
176 | const memory_tracking::grantor_t &scratchpad) const { |
177 | if (!pd()->wants_padded_bias()) return; |
178 | |
179 | auto padded_bias |
180 | = scratchpad.template get<dst_data_t>(key_conv_padded_bias); |
181 | utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding); |
182 | utils::array_set(padded_bias + pd()->jcp_.oc_without_padding, (dst_data_t)0, |
183 | pd()->jcp_.oc - pd()->jcp_.oc_without_padding); |
184 | bias = padded_bias; |
185 | } |
186 | |
// 1D forward convolution driver: flattens (mb, groups, oc-chunks, ow-blocks)
// into one work amount, splits it across threads with balance211, and calls
// the JIT kernel once per input-channel block.
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;

    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
    int nthr = jcp.aligned_threads;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // Outer loop over chunks of jcp.nb_ic_L2 input-channel blocks; each
        // thread re-walks its [start, end) work range for every chunk.
        for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n {0}, gg {0}, occ {0}, owb {0};

            // Decompose the flat work index into (occ, owb, gg, n) according
            // to the configured loop order; `dummy` fills the unused (1D has
            // no oh) trailing dimension.
            if (jcp.loop_order == loop_cwgn) {
                int dummy {0};
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gncw) {
                int dummy {0};
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, dummy, 1);
            } else if (jcp.loop_order == loop_nhwcg) {
                nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
            } else {
                assert(!"unsupported loop order" );
            }

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g = gg * g_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // For plain (nwc) layouts channels are indexed in elements;
                // for blocked layouts in channel blocks.
                const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nwc;
                const int oc_off_idx = is_dst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g_ocb;
                auto dst_w = dst + dst_d.blk_off(n, oc_off_idx, ow_s);
                const bool is_src_layout_nxc = jcp.src_tag == format_tag::nwc;
                const int ic_off_idx = is_src_layout_nxc
                        ? g * jcp.ic + icb_l2 * jcp.ic_block
                        : g_icb + icb_l2;
                auto src_w = src + src_d.blk_off(n, ic_off_idx, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
                auto bias_w = bias ? bias
                                + oc_off_idx
                                        * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                                   : nullptr;

                // nxc src: one kernel call covers the whole L2 ic chunk;
                // blocked src: one call per ic block.
                int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
                int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
                // Effective oc count for this call (trims the padded tail).
                const int oc_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);

                int ic_work = icb_step * jcp.ic_block;
                const int oc_l_off
                        = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);

                for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
                    int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
                    // FLAG_IC_FIRST/FLAG_IC_LAST tell the kernel when to
                    // initialize the accumulator and when to apply post-ops.
                    int flags = 0;
                    if (icb == 0) flags |= FLAG_IC_FIRST;
                    if (icb + curr_nb_ic >= jcp.nb_ic) {
                        flags |= FLAG_IC_LAST;
                        ic_work = utils::this_block_size(icb * jcp.ic_block,
                                jcp.ic, icb_step * jcp.ic_block);
                    }
                    // kh_padding is 1 here — presumably kh == 1 for the 1D
                    // case; confirm against the kernel configuration.
                    jit_conv_ker_pipeline_ow_thr(jit_ker, par_conv, src_w,
                            dst_w, wht_w, bias_w, icb, 1, owb, ic_work, oc_work,
                            post_ops_binary_rhs_arg_vec.data(), oc_l_off, dst,
                            flags);

                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }
                // Advance the nd-iterator to the next work item.
                if (jcp.loop_order == loop_cwgn) {
                    int dummy {0};
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            gg, nb_groups, n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gncw) {
                    int dummy {0};
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                            oc_chunks, owb, jcp.nb_ow, dummy, 1);
                } else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
                            gg, nb_groups);
                } else {
                    assert(!"unsupported loop order" );
                }
            }
        }
    });
}
312 | |
// 2D forward convolution driver: flattens (mb, groups, oc-chunks, oh,
// ow-blocks) into one work amount, splits it across threads, and invokes the
// JIT kernel per (ic block, output row), trimming kh against top/bottom
// padding on each row.
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {

    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
    int nthr = jcp.aligned_threads;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        // Element strides along h and channel-block dims, in layout units.
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // Outer loop over chunks of jcp.nb_ic_L2 input-channel blocks; each
        // thread re-walks its [start, end) work range for every chunk.
        for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n {0}, gg {0}, occ {0}, oh_s {0}, owb {0};

            // Decompose the flat work index per the configured loop order.
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg)
                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                        occ, oc_chunks, gg, nb_groups);
            else
                assert(!"unsupported loop order" );

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g = gg * g_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;

                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // Consume the remaining oh rows of this work item, but not
                // past jcp.oh.
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // nhwcg iterates one oh row per step

                // Process rows in groups of jcp.h_blocking.
                for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
                    // Top input row for this block (may be negative in the
                    // top-padding region).
                    int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;
                    const bool is_dst_layout_nxc
                            = jcp.dst_tag == format_tag::nhwc;
                    const int oc_off_idx = is_dst_layout_nxc
                            ? g * jcp.oc + ocb * jcp.oc_block
                            : g_ocb;
                    auto dst_w = dst + dst_d.blk_off(n, oc_off_idx, oh_b, ow_s);
                    const bool is_src_layout_nxc
                            = jcp.src_tag == format_tag::nhwc;
                    const int ic_off_idx = is_src_layout_nxc
                            ? g * jcp.ic + icb_l2 * jcp.ic_block
                            : g_icb + icb_l2;
                    auto src_w = src + src_d.blk_off(n, ic_off_idx, ih_b, iw_s);
                    auto wht_w
                            = weights + wht_blk_off(weights_d, g, ocb, icb_l2);

                    // nxc src: one kernel call covers the whole L2 ic chunk;
                    // blocked src: one call per ic block.
                    int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
                    int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
                    auto bias_w = bias ? bias
                                    + oc_off_idx
                                            * (is_dst_layout_nxc ? 1
                                                                 : jcp.oc_block)
                                       : nullptr;
                    // Effective oc count for this call (trims padded tail).
                    const int oc_work = utils::this_block_size(
                            ocb * jcp.oc_block, jcp.oc_without_padding,
                            jcp.nb_oc_blocking * jcp.oc_block);
                    const int oc_l_off = oc_off_idx
                            * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                    int ic_work = icb_step * jcp.ic_block;
                    for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
                        int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
                        // FLAG_IC_FIRST/FLAG_IC_LAST tell the kernel when to
                        // initialize the accumulator / apply post-ops.
                        int flags = 0;
                        if (icb == 0) flags |= FLAG_IC_FIRST;
                        if (icb + curr_nb_ic >= jcp.nb_ic) {
                            flags |= FLAG_IC_LAST;
                            ic_work = utils::this_block_size(icb * jcp.ic_block,
                                    jcp.ic, icb_step * jcp.ic_block);
                        }
                        auto src_c = src_w;
                        auto dst_c = dst_w;
                        for (int oj = oh_b, ij = ih_b;
                                oj < min(oh_e, oh_b + jcp.h_blocking);
                                ++oj, ij += jcp.stride_h) {
                            // Count kh taps that fall into the top/bottom
                            // padding and shrink the kernel height range.
                            int dilate_h = jcp.dilate_h + 1;
                            int i_t_overflow = div_up(max(0, -ij), dilate_h);
                            int i_b_overflow = div_up(
                                    max(0,
                                            ij - jcp.ih
                                                    + (jcp.kh - 1) * dilate_h
                                                    + 1),
                                    dilate_h);
                            int kh_padding = nstl::max(
                                    0, jcp.kh - i_t_overflow - i_b_overflow);

                            // Skip the src rows / weight rows cut off at the
                            // top so src and weights stay aligned.
                            auto aux_src = src_c
                                    + i_t_overflow * dilate_h * src_h_stride;
                            auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                            jit_conv_ker_pipeline_ow_thr(jit_ker, par_conv,
                                    aux_src, dst_c, aux_wht, bias_w, icb,
                                    kh_padding, owb, ic_work, oc_work,
                                    post_ops_binary_rhs_arg_vec.data(),
                                    oc_l_off, dst, flags);

                            src_c += src_h_stride * jcp.stride_h;
                            dst_c += dst_h_stride;
                        }
                        src_w += src_c_stride;
                        wht_w += wht_ic_stride;
                    }
                }

                // Advance the nd-iterator to the next work item.
                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            gg, nb_groups, n, jcp.mb, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                            occ, oc_chunks, gg, nb_groups);
                } else
                    assert(!"unsupported loop order" );
            }
        }
    });
}
471 | |
// 3D forward convolution driver: flattens (mb, groups, oc-chunks, od, oh,
// ow-blocks) into one work amount, splits it across threads, and invokes the
// JIT kernel per (ic block, output row), trimming both kd and kh against the
// front/back and top/bottom padding.
template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward_3d(const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;
    int nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        // Element strides along d, h and channel-block dims, layout units.
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t src_c_stride = src_d.blk_off(0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);

        // Outer loop over chunks of jcp.nb_ic_L2 input-channel blocks; each
        // thread re-walks its [start, end) work range for every chunk.
        for (int icb_l2 = 0; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
            start = start_copy;
            int n {0}, gg {0}, occ {0}, oh_s {0}, od_s {0}, owb {0};

            // Decompose the flat work index per the configured loop order.
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg)
                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
                        owb, jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
            else
                assert(!"unsupported loop order" );

            while (start < end) {
                int ocb = occ * jcp.nb_oc_blocking;
                int g = gg * g_blocking;
                int g_ocb = g * jcp.nb_oc + ocb;
                int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;

                int work_rem = end - start;
                // Top input row (may be negative inside the top padding).
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // nhwcg iterates one oh row per step

                // Front input plane (may be negative inside front padding).
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;

                // Trim the kernel depth range against front/back padding.
                int dilate_d = jcp.dilate_d + 1;
                int d_t_overflow = div_up(max(0, -id_s), dilate_d);
                int d_b_overflow = div_up(
                        max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                        dilate_d);
                int kd_padding
                        = nstl::max(0, jcp.kd - d_t_overflow - d_b_overflow);
                const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
                const int oc_off_idx = is_dst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g_ocb;
                auto dst_w
                        = dst + dst_d.blk_off(n, oc_off_idx, od_s, oh_s, ow_s);
                const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
                const int ic_off_idx = is_src_layout_nxc
                        ? g * jcp.ic + icb_l2 * jcp.ic_block
                        : g_icb + icb_l2;
                // Skip the src planes / weight planes cut off at the front so
                // src and weights stay aligned.
                auto src_w = src
                        + src_d.blk_off(n, ic_off_idx, id_s, ih_s, iw_s)
                        + d_t_overflow * dilate_d * src_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
                        + d_t_overflow * wht_d_stride;
                auto bias_w = bias ? bias
                                + oc_off_idx
                                        * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                                   : nullptr;

                // nxc src: one kernel call covers the whole L2 ic chunk;
                // blocked src: one call per ic block.
                const int icb_step = is_src_layout_nxc ? jcp.nb_ic_L2 : 1;
                int icb_end = min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
                // Effective oc count for this call (trims the padded tail).
                const int oc_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);

                const int oc_l_off
                        = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                int ic_work = icb_step * jcp.ic_block;
                for (int icb = icb_l2; icb < icb_end; icb += icb_step) {
                    int curr_nb_ic = nstl::min(icb_step, icb_end - icb);
                    // FLAG_IC_FIRST/FLAG_IC_LAST tell the kernel when to
                    // initialize the accumulator / apply post-ops.
                    int flags = 0;
                    if (icb == 0) flags |= FLAG_IC_FIRST;
                    if (icb + curr_nb_ic >= jcp.nb_ic) {
                        flags |= FLAG_IC_LAST;
                        ic_work = utils::this_block_size(icb * jcp.ic_block,
                                jcp.ic, icb_step * jcp.ic_block);
                    }
                    auto src_c = src_w;
                    auto dst_c = dst_w;
                    for (int oj = oh_s, ij = ih_s; oj < oh_e;
                            ++oj, ij += jcp.stride_h) {
                        // Trim the kernel height range against the top/bottom
                        // padding for this output row.
                        int dilate_h = jcp.dilate_h + 1;
                        int i_t_overflow = div_up(max(0, -ij), dilate_h);
                        int i_b_overflow = div_up(
                                max(0,
                                        ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                + 1),
                                dilate_h);
                        int kh_padding = nstl::max(
                                0, jcp.kh - i_t_overflow - i_b_overflow);
                        jit_conv_3d_ker_pipeline_ow_thr(jit_ker, par_conv,
                                src_c + i_t_overflow * dilate_h * src_h_stride,
                                dst_c, wht_w + i_t_overflow * wht_h_stride,
                                bias_w, icb, kh_padding, kd_padding, owb,
                                ic_work, oc_work,
                                post_ops_binary_rhs_arg_vec.data(), oc_l_off,
                                dst, flags);

                        src_c += src_h_stride * jcp.stride_h;
                        dst_c += dst_h_stride;
                    }
                    src_w += src_c_stride;
                    wht_w += wht_ic_stride;
                }

                // Advance the nd-iterator to the next work item.
                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            gg, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
                            jcp.oh);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                            jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
                } else
                    assert(!"unsupported loop order" );
            }
        }
    });
}
636 | |
// Explicit instantiation for the f32 configuration (remaining template
// arguments take their declared defaults).
template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
638 | |
// 1D backward-data driver: flattens (groups, mb, ic-chunks, iw-blocks) into
// one work amount, splits it across threads, and calls the JIT kernel once
// per output-channel block, accumulating into diff_src.
template <data_type_t diff_dst_type, data_type_t wei_type,
        data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;
    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();

    int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.nb_iw;
    int nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        // Outer loop over chunks of jcp.nb_oc_L2 output-channel blocks; each
        // thread re-walks its [start, end) work range for every chunk.
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n {0}, gg {0}, icc {0}, iwb {0};
            // Decompose the flat work index into (icc, iwb, gg, n) per the
            // configured loop order; `dummy` fills the unused (1D) dimension.
            if (jcp.loop_order == loop_cwgn) {
                int dummy {0};
                nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                        nb_groups, n, jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gncw) {
                int dummy {0};
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, iwb, jcp.nb_iw, dummy, 1);
            } else if (jcp.loop_order == loop_nhwcg) {
                nd_iterator_init(start, n, jcp.mb, iwb, jcp.nb_iw, icc,
                        ic_chunks, gg, nb_groups);
            } else {
                assert(!"unsupported loop order" );
            }

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g = gg * g_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;
                int iw_s = iwb * jcp.iw_block;
                // Corresponding start in the (strided) diff_dst domain.
                int ow_s = iw_s / jcp.stride_w;

                // For plain (nwc) layouts channels are indexed in elements;
                // for blocked layouts in channel blocks.
                const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::nwc;
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g_icb;
                auto diff_src_w
                        = diff_src + diff_src_d.blk_off(n, ic_off_idx, iw_s);
                const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nwc;
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb_l2 * jcp.oc_block
                        : g_ocb + ocb_l2;
                auto diff_dst_w
                        = diff_dst + diff_dst_d.blk_off(n, oc_off_idx, ow_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                // nxc ddst: one kernel call covers the whole L2 oc chunk;
                // blocked ddst: one call per oc block.
                int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
                int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
                // Effective ic count written by this call (trims ic tail).
                const int load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                int reduce_work = ocb_step * jcp.oc_block;
                for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
                    int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
                    // Trim the oc reduction tail on the last oc block.
                    if (ocb + curr_nb_oc >= jcp.nb_oc) {
                        reduce_work = utils::this_block_size(ocb * jcp.oc_block,
                                jcp.oc, ocb_step * jcp.oc_block);
                    }

                    // kh_padding is 1 here — presumably kh == 1 for the 1D
                    // case; no bias in backward-data (nullptr).
                    jit_conv_ker_pipeline_iw_thr(jit_ker, par_conv, diff_src_w,
                            diff_dst_w, wht_w, nullptr, ocb, 1, iwb,
                            reduce_work, load_work);
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                // Advance the nd-iterator to the next work item.
                if (jcp.loop_order == loop_cwgn) {
                    int dummy {0};
                    nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw,
                            gg, nb_groups, n, jcp.mb, dummy, 1);
                } else if (jcp.loop_order == loop_gncw) {
                    int dummy {0};
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                            ic_chunks, iwb, jcp.nb_iw, dummy, 1);
                } else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, iwb, jcp.nb_iw, icc, ic_chunks,
                            gg, nb_groups);
                } else {
                    assert(!"unsupported loop order" );
                }
            }
        }
    });
}
747 | |
template <data_type_t diff_dst_type, data_type_t wei_type,
        data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const {
    // Backward-data driver for 2d convolutions: computes diff_src from
    // diff_dst and weights. This function partitions the work across threads
    // and, for every diff_src row, figures out which kernel rows and which
    // diff_dst row contribute; the jitted kernel does the actual math.
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;
    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();

    // Work decomposes into groups x minibatch x ic chunks x input rows x
    // iw blocks; an ic chunk is jcp.nb_ic_blocking ic blocks.
    int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw;
    int nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        // Element strides used for pointer arithmetic below.
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;

        // Outer blocking over output channels (jcp.nb_oc_L2 blocks at a
        // time); the thread's whole work slice is replayed for each chunk.
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n {0}, gg {0}, icc {0}, ih_s {0}, iwb {0};

            // Decompose the linear work index per the configured loop order.
            if (jcp.loop_order == loop_cwgn) {
                nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                        nb_groups, n, jcp.mb, ih_s, jcp.ih);
            } else if (jcp.loop_order == loop_gncw) {
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
            } else if (jcp.loop_order == loop_nhwcg) {
                nd_iterator_init(start, n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
                        icc, ic_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order" );

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g = gg * g_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                // diff_src rows [ih_s, ih_e) handled in this iteration; for
                // nhwcg order only one row is taken per step.
                int work_rem = end - start;
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    ih_e = ih_s + 1; //step instead
                int iw_s = iwb * jcp.iw_block;
                int ow_s = iw_s / jcp.stride_w;
                // Channel index meaning differs between nxc and blocked
                // layouts, hence the two indexing schemes below.
                const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::nhwc;
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g_icb;
                auto diff_src_w
                        = diff_src + diff_src_d.blk_off(n, ic_off_idx, 0, iw_s);
                const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb_l2 * jcp.oc_block
                        : g_ocb + ocb_l2;
                auto diff_dst_w
                        = diff_dst + diff_dst_d.blk_off(n, oc_off_idx, 0, ow_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);

                // For nxc layouts the kernel reduces over the whole L2 chunk
                // of oc blocks per call; otherwise one oc block per call.
                int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
                int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
                const int load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                int reduce_work = ocb_step * jcp.oc_block;
                for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
                    int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
                    if (ocb + curr_nb_oc >= jcp.nb_oc) {
                        // Tail oc block: clip the reduction length to the
                        // real (possibly unpadded) channel count.
                        reduce_work = utils::this_block_size(ocb * jcp.oc_block,
                                jcp.oc, ocb_step * jcp.oc_block);
                    }
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        // For diff_src row ij find the contributing kernel
                        // rows [k_lo, k_lo + k_len) and the first diff_dst
                        // row oj; padding at top/bottom shrinks the range.
                        int oj, k_len, k_lo;
                        if (is_fast_path) { // dilate == 0 && stride == 1
                            int i_t_overflow
                                    = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                            int i_b_overflow
                                    = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                    = div_up(max(0,
                                                     (jcp.kh - 1) * dilate_h
                                                             - ij - jcp.t_pad),
                                            dilate_h);
                            int i_b_overflow = div_up(
                                    max(0,
                                            (jcp.kh - 1) * dilate_h + 1 - jcp.ih
                                                    + ij - jcp.b_pad),
                                    dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            int i_t_overflow = max(0,
                                    (jcp.kh - 1 - ij - jcp.t_pad)
                                            / jcp.stride_h);
                            int i_b_overflow = max(0,
                                    (jcp.kh - jcp.ih + ij - jcp.b_pad)
                                            / jcp.stride_h);
                            // With stride > 1 only every stride_h-th kernel
                            // row touches row ij; find the hi/lo rows that do.
                            int overflow_kh_hi = jcp.kh - 1
                                    - modulo(jcp.ih - 1 + jcp.b_pad - ij,
                                            jcp.stride_h);
                            int overflow_kh_lo
                                    = (ij + jcp.t_pad) % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                            / jcp.stride_h
                                    + 1 - i_t_overflow - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }

                        jit_conv_ker_pipeline_iw_thr(jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride, nullptr, ocb,
                                k_len, iwb, reduce_work, load_work);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                // Advance the multi-dimensional iterator to the next item.
                if (jcp.loop_order == loop_cwgn) {
                    nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw,
                            gg, nb_groups, n, jcp.mb, ih_s, jcp.ih);
                } else if (jcp.loop_order == loop_gncw) {
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                            ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
                } else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
                            icc, ic_chunks, gg, nb_groups);
                } else
                    assert(!"unsupported loop order" );
            }
        }
    });
}
908 | |
template <data_type_t diff_dst_type, data_type_t wei_type,
        data_type_t diff_src_type>
void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const {
    // Backward-data driver for 3d convolutions. Same structure as the 2d
    // version with an extra depth dimension: the contributing kernel depth
    // range is computed once per work item, the height range once per row.
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;
    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();

    int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
    int nthr = jcp.nthr;

    parallel(nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0}, start_copy;
        balance211(work_amount, nthr, ithr, start, end);
        start_copy = start;

        auto par_conv = jit_conv_call_s();
        // Element strides for height/depth/channel pointer arithmetic.
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
        size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
        size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);

        bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
        bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;

        // L2 blocking over output channels, as in the 2d flavor.
        for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
            start = start_copy;
            int n {0}, gg {0}, icc {0}, ih_s {0}, id_s {0};
            // Input width threading is not currently implemented for 3d, so it
            // is not included in the iterator.
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n,
                        jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_nhwcg)
                nd_iterator_init(start, n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih,
                        icc, ic_chunks, gg, nb_groups);
            else
                assert(!"unsupported loop order" );

            while (start < end) {
                int icb = icc * jcp.nb_ic_blocking;
                int g = gg * g_blocking;
                int g_icb = g * jcp.nb_ic + icb;
                int g_ocb = g * jcp.nb_oc;

                int work_rem = end - start;
                int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    ih_e = ih_s + 1; //step instead
                // Depth: kernel slices [d_lo, d_lo + d_len) contribute to
                // input depth id_s; d_oj is the first diff_dst depth index.
                int d_len = 0, d_lo = 0, d_oj = 0;
                if (is_fast_path_d) { // dilate == 0 && stride == 1
                    int d_t_overflow = max(0, jcp.kd - 1 - id_s - jcp.f_pad);
                    int d_b_overflow
                            = max(0, jcp.kd - jcp.id + id_s - jcp.back_pad);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow;
                } else if (jcp.dilate_d != 0) { // stride == 1
                    int dilate_d = jcp.dilate_d + 1;
                    // Note: use div_up to account for "holes" in filter
                    int d_t_overflow = div_up(
                            max(0, (jcp.kd - 1) * dilate_d - id_s - jcp.f_pad),
                            dilate_d);
                    int d_b_overflow = div_up(
                            max(0,
                                    (jcp.kd - 1) * dilate_d + 1 - jcp.id + id_s
                                            - jcp.back_pad),
                            dilate_d);
                    d_len = jcp.kd - d_t_overflow - d_b_overflow;
                    d_lo = d_b_overflow;
                    d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
                } else { // dilate == 0
                    int d_t_overflow = max(
                            0, (jcp.kd - 1 - id_s - jcp.f_pad) / jcp.stride_d);
                    int d_b_overflow = max(0,
                            (jcp.kd - jcp.id + id_s - jcp.back_pad)
                                    / jcp.stride_d);
                    // With stride > 1 only every stride_d-th kernel slice
                    // touches depth id_s; clip to the slices that do.
                    int overflow_kd_hi = jcp.kd - 1
                            - modulo(jcp.id - 1 + jcp.back_pad - id_s,
                                    jcp.stride_d);
                    int overflow_kd_lo = (id_s + jcp.f_pad) % jcp.stride_d;

                    d_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                            - d_t_overflow - d_b_overflow;
                    d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
                    d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
                }

                const bool is_dsrc_layout_nxc
                        = jcp.src_tag == format_tag::ndhwc;
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g_icb;
                auto diff_src_w = diff_src + diff_src_d.blk_off(n, ic_off_idx)
                        + id_s * diff_src_d_stride;
                const bool is_ddst_layout_nxc
                        = jcp.dst_tag == format_tag::ndhwc;
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb_l2 * jcp.oc_block
                        : g_ocb + ocb_l2;
                auto diff_dst_w = diff_dst + diff_dst_d.blk_off(n, oc_off_idx)
                        + d_oj * diff_dst_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
                        + d_lo * wht_d_stride;

                // nxc: whole L2 chunk of oc blocks per kernel call;
                // blocked: one oc block per call.
                int ocb_step = is_ddst_layout_nxc ? jcp.nb_oc_L2 : 1;
                int ocb_end = min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2);
                const int load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                int reduce_work = ocb_step * jcp.oc_block;
                for (int ocb = ocb_l2; ocb < ocb_end; ocb += ocb_step) {
                    int curr_nb_oc = nstl::min(ocb_step, ocb_end - ocb);
                    if (ocb + curr_nb_oc >= jcp.nb_oc) {
                        // Tail oc block: clip reduction to the real oc count.
                        reduce_work = utils::this_block_size(ocb * jcp.oc_block,
                                jcp.oc, ocb_step * jcp.oc_block);
                    }
                    for (int ij = ih_s; ij < ih_e; ++ij) {
                        // Height analogue of the depth computation above.
                        int oj, k_len, k_lo;
                        if (is_fast_path_h) { // dilate == 0 && stride == 1
                            int i_t_overflow
                                    = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                            int i_b_overflow
                                    = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow;
                        } else if (jcp.dilate_h != 0) { // stride == 1
                            int dilate_h = jcp.dilate_h + 1;
                            // Note: use div_up to account for "holes" in filter
                            int i_t_overflow
                                    = div_up(max(0,
                                                     (jcp.kh - 1) * dilate_h
                                                             - ij - jcp.t_pad),
                                            dilate_h);
                            int i_b_overflow = div_up(
                                    max(0,
                                            (jcp.kh - 1) * dilate_h + 1 - jcp.ih
                                                    + ij - jcp.b_pad),
                                    dilate_h);
                            k_len = jcp.kh - i_t_overflow - i_b_overflow;
                            k_lo = i_b_overflow;
                            oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                        } else { // dilate == 0
                            int i_t_overflow = max(0,
                                    (jcp.kh - 1 - ij - jcp.t_pad)
                                            / jcp.stride_h);
                            int i_b_overflow = max(0,
                                    (jcp.kh - jcp.ih + ij - jcp.b_pad)
                                            / jcp.stride_h);
                            int overflow_kh_hi = jcp.kh - 1
                                    - modulo(jcp.ih - 1 + jcp.b_pad - ij,
                                            jcp.stride_h);
                            int overflow_kh_lo
                                    = (ij + jcp.t_pad) % jcp.stride_h;

                            k_len = (overflow_kh_hi - overflow_kh_lo)
                                            / jcp.stride_h
                                    + 1 - i_t_overflow - i_b_overflow;
                            k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                            oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                        }
                        assert(k_len >= 0);

                        jit_conv_3d_ker_pipeline(jit_ker, par_conv,
                                diff_src_w + ij * diff_src_h_stride,
                                diff_dst_w + oj * diff_dst_h_stride,
                                wht_w + k_lo * wht_h_stride, nullptr, ocb,
                                k_len, d_len, reduce_work, load_work);
                    }
                    diff_dst_w += diff_dst_c_stride;
                    wht_w += wht_oc_stride;
                }

                if (jcp.loop_order == loop_cwgn)
                    nd_iterator_jump(start, end, icc, ic_chunks, gg, nb_groups,
                            n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
                else if (jcp.loop_order == loop_gncw)
                    nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                            ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
                else if (jcp.loop_order == loop_nhwcg) {
                    ++start;
                    nd_iterator_step(n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
                            ic_chunks, gg, nb_groups);
                } else
                    assert(!"unsupported loop order" );
            }
        }
    });
}
1115 | |
1116 | template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>; |
1117 | |
1118 | template <data_type_t src_type, data_type_t diff_dst_type, |
1119 | data_type_t diff_weights_type> |
1120 | status_t jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1121 | diff_weights_type>::init(engine_t *engine) { |
1122 | const auto &j = pd()->jcp_; |
1123 | |
1124 | nthr_ = j.nthr; |
1125 | nthr_mb_ = j.nthr_mb; |
1126 | nthr_g_ = j.nthr_g; |
1127 | nthr_oc_b_ = j.nthr_oc_b; |
1128 | nthr_ic_b_ = j.nthr_ic_b; |
1129 | |
1130 | CHECK(safe_ptr_assign( |
1131 | kernel_, new jit_avx512_common_conv_bwd_weights_kernel_f32(j))); |
1132 | CHECK(kernel_->create_kernel()); |
1133 | |
1134 | if (nthr_mb_ > 1) { |
1135 | CHECK(safe_ptr_assign( |
1136 | acc_ker_, new cpu_accumulator_1d_t<diff_weights_type>())); |
1137 | CHECK(acc_ker_->create_kernel()); |
1138 | } |
1139 | |
1140 | CHECK(safe_ptr_assign(reducer_bias_, |
1141 | new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_))); |
1142 | CHECK(reducer_bias_->create_kernel()); |
1143 | return status::success; |
1144 | } |
1145 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::thread_info_t {
    // Per-thread view of the backward-weights computation: raw tensor
    // pointers, scratchpad buffers, this thread's coordinates in the
    // (mb, g, oc_b, ic_b) thread grid, and its balanced work ranges.
    const src_data_t *src;
    const diff_dst_data_t *diff_dst;
    const diff_weights_data_t *diff_weights;
    diff_weights_data_t *diff_bias;

    const memory_tracking::grantor_t scratchpad;

    // Scratch buffers (with barriers) for transposed src / diff_dst copies.
    src_data_t *tr_src;
    simple_barrier::ctx_t *tr_src_bctx;

    diff_dst_data_t *tr_diff_dst;
    simple_barrier::ctx_t *tr_diff_dst_bctx;

    // Scratch (with barrier) for reducing per-thread weight/bias partials.
    diff_weights_data_t *wei_bia_reduction;
    simple_barrier::ctx_t *wei_bia_reduction_bctx;

    int ithr;
    // Coordinates of this thread in the thread grid.
    int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
    // Linear thread ids over all grid dimensions except oc_b / except ic_b.
    int ithr_but_oc;
    int ithr_but_ic;

    // Half-open work ranges assigned to this thread in each dimension.
    int img_start = 0, img_end = 0, img_work;
    int g_start = 0, g_end = 0, g_work;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work;

    thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
        diff_weights
                = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_WEIGHTS);
        const auto &jcp = self->kernel_->jcp;
        // When oc is not a multiple of the block size, bias is accumulated
        // into a padded scratch buffer instead of the user's memory.
        const bool is_bias_padded = self->pd()->with_bias()
                && jcp.oc_without_padding % jcp.oc_block != 0;
        diff_bias = is_bias_padded
                ? scratchpad.template get<diff_weights_data_t>(
                        key_conv_padded_bias)
                : CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_BIAS);

        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx);

        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                key_conv_tr_diff_dst);
        tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_diff_dst_bctx);

        wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
                key_conv_wei_bia_reduction);
        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx);

        // Unflatten the linear thread id into (ic_b, oc_b, g, mb)
        // coordinates; ic_b varies fastest.
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
                + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
                + ithr_oc_b;

        /* reduction dimension */
        int oh_reduce = jcp.harness == harness_2d_reduction ? jcp.oh : 1;
        balance211(jcp.mb * jcp.od * oh_reduce, self->nthr_mb_, ithr_mb,
                img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(
                jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(
                jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
        ic_b_work = ic_b_end - ic_b_start;
    }
};
1235 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::compute_diff_weights_nxc(const thread_info_t *ti)
        const {
    // Backward-weights driver for plain channels-last (nxc) layouts: the
    // ow/oc/g/kd/kh loops live here in C++ and the jitted kernel handles
    // the innermost accumulation.
    const auto &jcp = kernel_->jcp;

    // Size of one full diff_weights tensor; threads other than ithr_mb == 0
    // accumulate into private back-to-back slices of the reduction scratch.
    const int wei_size
            = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;
    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
            ? (diff_weights_data_t *)ti->diff_weights
            : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    // Element offset into blocked diff_weights laid out as
    // [g][ocb][icb][kd][kh][kw][ic-in-block][oc-in-block].
    auto diff_weights_offset
            = [&](int g, int i_kd, int i_kh, int i_kw, int i_ic, int i_oc) {
                  const int oc_block_size = 1;
                  const int ic_block_size = jcp.oc_block * oc_block_size;
                  const int kw_block_size = jcp.ic_block * ic_block_size;
                  const int kh_block_size = jcp.kw * kw_block_size;
                  const int kd_block_size = jcp.kh * kh_block_size;
                  const int icb_block_size = jcp.kd * kd_block_size;
                  const int ocb_block_size = jcp.nb_ic * icb_block_size;
                  const int g_block_size = jcp.nb_oc * ocb_block_size;

                  int icb = i_ic / jcp.ic_block;
                  int ocb = i_oc / jcp.oc_block;
                  i_ic = i_ic % jcp.ic_block;
                  i_oc = i_oc % jcp.oc_block;

                  return g * g_block_size + ocb * ocb_block_size
                          + icb * icb_block_size + i_kd * kd_block_size
                          + i_kh * kh_block_size + i_kw * kw_block_size
                          + i_ic * ic_block_size + i_oc * oc_block_size;
              };
    // Element offset into src laid out as [mb][id][ih][iw][g][ic].
    auto src_offset
            = [&](int g, int i_mb, int i_id, int i_ih, int i_ic, int i_iw) {
                  const int ic_block_size = 1;
                  const int g_block_size = jcp.ic * ic_block_size;
                  const int iw_block_size = jcp.ngroups * g_block_size;
                  const int ih_block_size = jcp.iw * iw_block_size;
                  const int id_block_size = jcp.ih * ih_block_size;
                  const int mb_block_size = jcp.id * id_block_size;

                  return g * g_block_size + i_mb * mb_block_size
                          + i_id * id_block_size + i_ih * ih_block_size
                          + i_iw * iw_block_size + i_ic * ic_block_size;
              };
    // Element offset into diff_dst laid out as [mb][od][oh][ow][g][oc].
    auto diff_dst_offset
            = [&](int g, int i_mb, int i_od, int i_oh, int i_ow, int i_oc) {
                  const int oc_block_size = 1;
                  const int g_block_size = jcp.oc * oc_block_size;
                  const int ow_block_size = jcp.ngroups * g_block_size;
                  const int oh_block_size = jcp.ow * ow_block_size;
                  const int od_block_size = jcp.oh * oh_block_size;
                  const int mb_block_size = jcp.od * od_block_size;

                  return g * g_block_size + i_mb * mb_block_size
                          + i_od * od_block_size + i_oh * oh_block_size
                          + i_ow * ow_block_size + i_oc * oc_block_size;
              };
    // The kernel accumulates, so the destination must start out zeroed.
    auto zero_diff_weights = [&]() {
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < wei_size; i++)
            diff_wei[i] = 0;
    };

    int kd_step = jcp.dilate_d + 1;
    int kh_step = jcp.dilate_h + 1;
    int stride_d = jcp.stride_d;
    int stride_h = jcp.stride_h;
    int f_pad = jcp.f_pad;
    int t_pad = jcp.t_pad;

    // The reduction work (mb, od, oh, ow-block) is split across nthr_mb.
    dim_t work_amount = jcp.mb * jcp.od * jcp.oh * jcp.nb_ow;
    dim_t i_work {0}, i_work_end {0};
    balance211(work_amount, jcp.nthr_mb, ti->ithr_mb, i_work, i_work_end);

    int i_mb {0}, i_od {0}, i_oh {0}, i_owb {0};
    nd_iterator_init(
            i_work, i_mb, jcp.mb, i_od, jcp.od, i_oh, jcp.oh, i_owb, jcp.nb_ow);

    zero_diff_weights();
    while (i_work < i_work_end) {
        // Clip the kernel tap ranges so only in-image (id, ih) positions are
        // visited for this (i_od, i_oh); padding removes taps at the edges.
        int kd_start = nstl::max(
                0, div_up(jcp.f_pad - jcp.stride_d * i_od, kd_step));
        int kd_end = nstl::min(
                jcp.kd - 1, (jcp.id - 1 + f_pad - stride_d * i_od) / kd_step);
        int i_id_base = stride_d * i_od - f_pad;
        int kh_start = nstl::max(
                0, div_up(jcp.t_pad - jcp.stride_h * i_oh, +kh_step));
        int kh_end = nstl::min(
                jcp.kh - 1, (jcp.ih - 1 + t_pad - stride_h * i_oh) / kh_step);
        int i_ih_base = jcp.stride_h * i_oh + -jcp.t_pad;
        int i_ow_base = i_owb * jcp.ow_block;
        int i_ow_end = nstl::min(jcp.ow, i_ow_base + jcp.ow_block);

        // The kernel is small so these loops produce measurable overhead. Since
        // these are simple loops, the compiler will likely make the loops just
        // as well as we can with the jitted assembly, so there is not
        // necessarily a reason to move these loops into assembly. Avoid placing
        // computationally heavy operations within the loops.
        for_(int i_ow = i_ow_base; i_ow < i_ow_end; i_ow += jcp.ur_ow)
        for_(int i_oc = 0; i_oc < jcp.oc; i_oc += jcp.oc_block)
        for_(int g = 0; g < jcp.ngroups; g++)
        for_(int i_kd = kd_start; i_kd <= kd_end; i_kd++)
        for (int i_kh = kh_start; i_kh <= kh_end; i_kh++) {
            // Some Optimization Observations: It may be
            // worthwhile to move the kd and kh loops below the
            // icb loop in the kernel to further amortize the
            // ddst register loads. Alternatively, these
            // dimensions are independent on the weights kernel,
            // so can be used as a threading dimension that does
            // not require reduction.

            // The compiler seems to do a good job at optimizing these
            // computations. The offset functions likely need to be located
            // so that they will be inlined.
            int i_iw = i_ow * jcp.stride_w - jcp.l_pad;
            int i_id = i_id_base + i_kd * kd_step;
            int i_ih = i_ih_base + i_kh * kh_step;
            int ddst_offset = diff_dst_offset(g, i_mb, i_od, i_oh, i_ow, i_oc);
            int s_off_base = src_offset(g, i_mb, i_id, i_ih, 0, i_iw);
            int dwei_off_base = diff_weights_offset(g, i_kd, i_kh, 0, 0, i_oc);
            // ensure all parameters are 64bit, to comply with windows kernel
            // param access where the params from 5th are passed using stack.
            (*kernel_)(&diff_wei[dwei_off_base], &ti->src[s_off_base],
                    &ti->diff_dst[ddst_offset], (dim_t)i_iw, (dim_t)i_ow);
        }
        nd_iterator_step(
                i_mb, jcp.mb, i_od, jcp.od, i_oh, jcp.oh, i_owb, jcp.nb_ow);
        i_work++;
    }
}
1369 | |
1370 | template <data_type_t src_type, data_type_t diff_dst_type, |
1371 | data_type_t diff_weights_type> |
1372 | void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1373 | diff_weights_type>::compute_diff_weights(const thread_info_t *ti) |
1374 | const { |
1375 | const memory_desc_wrapper src_d(pd()->src_md()); |
1376 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1377 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1378 | |
1379 | const auto &jcp = kernel_->jcp; |
1380 | const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker(); |
1381 | const int padded_oc = rnd_up(jcp.oc, jcp.oc_block); |
1382 | const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block) |
1383 | * jcp.kh * jcp.kw * jcp.kd; |
1384 | |
1385 | diff_weights_data_t *diff_wei = ti->ithr_mb == 0 |
1386 | ? (diff_weights_data_t *)ti->diff_weights |
1387 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
1388 | |
1389 | const bool is_src_layout_nxc = utils::one_of( |
1390 | jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1391 | |
1392 | int ic_b_step = jcp.nb_ic_blocking_max; |
1393 | int icb_work = ti->ic_b_end - ti->ic_b_start; |
1394 | if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step) |
1395 | ic_b_step = utils::div_up(icb_work, 2); |
1396 | |
1397 | for (int img = ti->img_start; img < ti->img_end; ++img) { |
1398 | auto p = jit_conv_call_s(); |
1399 | |
1400 | const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc); |
1401 | const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic); |
1402 | const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag, |
1403 | format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1404 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
1405 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) |
1406 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
1407 | ic_b += ic_b_step) { |
1408 | const int _oc = g * jcp.nb_oc + oc_b; |
1409 | const int _ic = g * jcp.nb_ic + ic_b; |
1410 | const int ic_to_compute = this_block_size( |
1411 | ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block); |
1412 | const int oc_to_compute = this_block_size( |
1413 | oc_b * jcp.oc_block, max_oc, jcp.oc_block); |
1414 | |
1415 | const int ic_off_idx = is_src_layout_nxc |
1416 | ? g * jcp.ic + ic_b * jcp.ic_block |
1417 | : _ic; |
1418 | const int oc_off_idx = is_ddst_layout_nxc |
1419 | ? g * jcp.oc + oc_b * jcp.oc_block |
1420 | : _oc; |
1421 | |
1422 | jit_conv_ker_pipeline_bwd_w(jit_ker, p, |
1423 | &ti->src[src_d.blk_off(img, ic_off_idx)], |
1424 | &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)], |
1425 | diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), |
1426 | nullptr, (img == ti->img_start), 0, ic_to_compute, |
1427 | oc_to_compute); |
1428 | } |
1429 | } |
1430 | } |
1431 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::compute_diff_weights_2d(const thread_info_t *ti)
        const {
    // 2d backward-weights driver with oh-level reduction threading: each
    // thread covers a (mb, oh) range plus its own (g, oc_b, ic_b) slice.
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker();
    const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
    const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
            * jcp.kh * jcp.kw;

    // Reduction thread 0 writes user memory directly; other threads write
    // private scratch slices. Bias slices live after all weight slices in
    // the same scratch buffer.
    diff_weights_data_t *diff_wei = ti->ithr_mb == 0
            ? (diff_weights_data_t *)ti->diff_weights
            : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
    diff_weights_data_t *diff_bia = ti->ithr_mb == 0
            ? (diff_weights_data_t *)ti->diff_bias
            : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
                    + (ti->ithr_mb - 1) * jcp.ngroups * padded_oc;

    int img {0}, oh_s {0};
    int img_start = ti->img_start, img_end = ti->img_end;
    nd_iterator_init(img_start, img, jcp.mb, oh_s, jcp.oh);
    const int img_first = img;

    // When the remaining icb work is only slightly larger than one step,
    // split it evenly in two so the tail is balanced.
    int ic_b_step = jcp.nb_ic_blocking_max;
    int icb_work = ti->ic_b_end - ti->ic_b_start;
    if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
        ic_b_step = utils::div_up(icb_work, 2);
    while (img_start < img_end) {
        auto p = jit_conv_call_s();

        // Output rows [oh_s, oh_e) of image img processed this iteration.
        int work_rem = img_end - img_start;
        const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
        // First input row for oh_s and the amount of kh that falls outside
        // the image; the kernel gets the clipped kh count plus a byte
        // offset (typesize_out) to the first in-range weights row.
        const int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
        const int kh_top_overflow = nstl::max(0, -ih_s);
        const int kh_bottom_overflow = nstl::max(0, ih_s - jcp.ih + jcp.kh);
        int kh_padding = jcp.kh - kh_top_overflow - kh_bottom_overflow;
        int kh_padding_offset = nstl::min(jcp.kh - 1, kh_top_overflow) * jcp.kw
                * jcp.ic_block * jcp.oc_block * jcp.typesize_out;
        auto src_h = ti->src + src_d.blk_off(img, 0, ih_s + kh_top_overflow);
        auto diff_dst_h = ti->diff_dst + diff_dst_d.blk_off(img, 0, oh_s);

        const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
        const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
        const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc);
        const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic);
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int _oc = g * jcp.nb_oc + oc_b;
            const int _ic = g * jcp.nb_ic + ic_b;
            // Tail-aware channel counts for this kernel call.
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, max_oc, jcp.oc_block);
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : _ic;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : _oc;
            auto src = src_h + src_d.blk_off(0, ic_off_idx);
            auto diff_dst = diff_dst_h + diff_dst_d.blk_off(0, oc_off_idx);
            // Nonzero flag marks non-first ic blocks. NOTE(review):
            // presumably tells the kernel to skip per-oc work (e.g. bias)
            // already done for ic_b == 0 -- confirm against the kernel.
            p.flags = ic_b == 0 ? 0 : 1;
            jit_conv_2d_ker_bwd_w_pipeline(jit_ker, p, src, diff_dst,
                    diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
                    diff_bia + _oc * jcp.oc_block, (img == img_first), oh_s,
                    oh_e, kh_padding, kh_padding_offset, ic_to_compute,
                    oc_to_compute);
        }
        nd_iterator_jump(img_start, img_end, img, jcp.mb, oh_s, jcp.oh);
    }
}
1510 | |
1511 | template <data_type_t src_type, data_type_t diff_dst_type, |
1512 | data_type_t diff_weights_type> |
1513 | void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1514 | diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) |
1515 | const { |
1516 | const memory_desc_wrapper src_d(pd()->src_md()); |
1517 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1518 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1519 | |
1520 | const auto &jcp = kernel_->jcp; |
1521 | const jit_conv_ker_t jit_ker = (decltype(jit_ker))kernel_->jit_ker(); |
1522 | const int padded_oc = rnd_up(jcp.oc, jcp.oc_block); |
1523 | const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block) |
1524 | * jcp.kh * jcp.kw * jcp.kd; |
1525 | |
1526 | diff_weights_data_t *diff_wei = ti->ithr_mb == 0 |
1527 | ? (diff_weights_data_t *)ti->diff_weights |
1528 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
1529 | diff_weights_data_t *diff_bia = ti->ithr_mb == 0 |
1530 | ? (diff_weights_data_t *)ti->diff_bias |
1531 | : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size |
1532 | + (ti->ithr_mb - 1) * jcp.ngroups * padded_oc; |
1533 | |
1534 | const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc; |
1535 | const int inp_mult = is_src_layout_nxc |
1536 | ? jcp.ngroups * jcp.ic |
1537 | : (jcp.is_1stconv ? 1 : jcp.ic_block); |
1538 | const int input_step = jcp.ih * jcp.iw * inp_mult; |
1539 | const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc; |
1540 | const int output_step = jcp.ow * jcp.oh |
1541 | * (is_ddst_layout_nxc ? jcp.ngroups * jcp.oc : jcp.oc_block); |
1542 | int img {0}, od_s {0}; |
1543 | int img_start = ti->img_start, img_end = ti->img_end; |
1544 | nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od); |
1545 | const int img_first = img; |
1546 | |
1547 | int ic_b_step = jcp.nb_ic_blocking_max; |
1548 | int icb_work = ti->ic_b_end - ti->ic_b_start; |
1549 | if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step) |
1550 | ic_b_step = utils::div_up(icb_work, 2); |
1551 | |
1552 | while (img_start < img_end) { |
1553 | auto p = jit_conv_call_s(); |
1554 | |
1555 | int work_rem = img_end - img_start; |
1556 | const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem; |
1557 | const int id_s = od_s * jcp.stride_d; |
1558 | const int ik_overlap = nstl::max(0, id_s - jcp.f_pad); |
1559 | const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s); |
1560 | const int kd_back_pad |
1561 | = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd); |
1562 | int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw |
1563 | * jcp.ic_block * jcp.oc_block * jcp.typesize_out; |
1564 | |
1565 | const int max_oc = nstl::min(ti->oc_b_end * jcp.oc_block, jcp.oc); |
1566 | const int max_ic = nstl::min(ti->ic_b_end * jcp.ic_block, jcp.ic); |
1567 | |
1568 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
1569 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) |
1570 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
1571 | ic_b += ic_b_step) { |
1572 | const int _oc = g * jcp.nb_oc + oc_b; |
1573 | const int _ic = g * jcp.nb_ic + ic_b; |
1574 | |
1575 | const int ic_to_compute = this_block_size( |
1576 | ic_b * jcp.ic_block, max_ic, ic_b_step * jcp.ic_block); |
1577 | const int oc_to_compute = this_block_size( |
1578 | oc_b * jcp.oc_block, max_oc, jcp.oc_block); |
1579 | |
1580 | const int ic_off_idx = is_src_layout_nxc |
1581 | ? g * jcp.ic + ic_b * jcp.ic_block |
1582 | : _ic; |
1583 | const int oc_off_idx = is_ddst_layout_nxc |
1584 | ? g * jcp.oc + oc_b * jcp.oc_block |
1585 | : _oc; |
1586 | auto src = &ti->src[src_d.blk_off(img, ic_off_idx) |
1587 | + ik_overlap * input_step]; |
1588 | auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx) |
1589 | + od_s * output_step]; |
1590 | auto diff_bia_ptr = diff_bia ? diff_bia + _oc * 16 : nullptr; |
1591 | p.flags = ic_b == 0 ? 0 : 1; |
1592 | jit_conv_3d_ker_bwd_w_pipeline(jit_ker, p, src, dst, |
1593 | diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b), |
1594 | diff_bia_ptr, (img == img_first), od_s, od_e, |
1595 | jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off, |
1596 | ic_to_compute, oc_to_compute); |
1597 | } |
1598 | nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od); |
1599 | } |
1600 | } |
1601 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
    // Sum the per-thread weight-gradient scratch buffers (filled by the
    // compute_diff_weights* passes) into the final diff_weights tensor.
    // The reduction work is re-balanced across all mb-threads at
    // (g, oc_b, ic_b*kh) granularity so that each thread accumulates a
    // contiguous slice from every partial buffer.
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int padded_oc = rnd_up(jcp.oc, jcp.oc_block);
    // 2D/mb variant of the buffer size: kh*kw spatial taps, no kd factor.
    const int wei_size = jcp.ngroups * padded_oc * rnd_up(jcp.ic, jcp.ic_block)
            * jcp.kh * jcp.kw;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    // All producer threads must have finished writing their scratch slots
    // before any thread starts reading them.
    if (dnnl_thr_syncable())
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    // Innermost work dimension is ic blocks x kh rows — slices along it are
    // contiguous in the blocked weights layout.
    const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        // Decompose the flat work index into (g, oc_b, ic_b*kh) coordinates
        // relative to this thread's sub-range.
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
            const int kh = sub_ic_b_kh_start % jcp.kh;

            // Accumulate as many consecutive kh units as fit before either
            // this thread's range ends or the (ic_b, kh) dimension wraps;
            // each unit covers kw * ic_block * oc_block elements.
            const int acc_size
                    = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                    * jcp.kw * jcp.ic_block * jcp.oc_block;

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);

            diff_weights_data_t *d
                    = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                    = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;

            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
    }
}
1653 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti)
        const {
    // 3D counterpart of reduce_diff_weights(): sums the per-thread
    // weight-gradient scratch buffers into the final diff_weights tensor,
    // splitting the work at kd granularity instead of kh.
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    // Scratch slot size includes the kd dimension, unlike the 2d variant.
    const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block)
            * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw * jcp.kd;

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    if (dnnl_thr_syncable())
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    // NOTE: despite the "_kh_" in the name, this counts ic blocks x kd taps
    // here (the name is kept parallel to the 2d routine).
    const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
            const int kd = sub_ic_b_kh_start % jcp.kd;

            // One kd unit covers a full kh*kw plane of an ic/oc block pair,
            // hence the extra jcp.kh factor compared to the 2d routine.
            const int acc_size
                    = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
                    * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
            diff_weights_data_t *d
                    = (diff_weights_data_t *)ti->diff_weights + off;
            diff_weights_data_t *s
                    = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
            acc_ker_->accumulate(d, s, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
    }
}
1703 | |
1704 | template <data_type_t src_type, data_type_t diff_dst_type, |
1705 | data_type_t diff_weights_type> |
1706 | void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1707 | diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const { |
1708 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1709 | |
1710 | auto rb = this->reducer_bias_.get(); |
1711 | assert(nthr_ == rb->balancer().nthr_); |
1712 | |
1713 | const auto reducer_bia_scratchpad |
1714 | = memory_tracking::grantor_t(ti->scratchpad, prefix_reducer_bia); |
1715 | |
1716 | const auto &jcp = kernel_->jcp; |
1717 | |
1718 | const int b_job_start = rb->balancer().ithr_job_off(ti->ithr); |
1719 | const int b_njobs = rb->balancer().ithr_njobs(ti->ithr); |
1720 | |
1721 | if (b_njobs == 0) return; |
1722 | |
1723 | /* reduction dimension */ |
1724 | int img_start {0}, img_end {0}; |
1725 | balance211(jcp.mb, rb->balancer().nthr_per_group_, |
1726 | rb->balancer().id_in_group(ti->ithr), img_start, img_end); |
1727 | |
1728 | /* jobs */ |
1729 | int g_start {0}, ocb_start {0}; |
1730 | nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc); |
1731 | for (int img = img_start; img < img_end; ++img) { |
1732 | int g = g_start, ocb = ocb_start; |
1733 | for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) { |
1734 | const size_t _oc = g * jcp.nb_oc + ocb; |
1735 | const int max_oc |
1736 | = this_block_size(ocb * jcp.oc_block, jcp.oc, jcp.oc_block); |
1737 | |
1738 | const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag, |
1739 | format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1740 | const int oc_off_idx = is_ddst_layout_nxc |
1741 | ? g * jcp.oc + ocb * jcp.oc_block |
1742 | : _oc; |
1743 | const diff_dst_data_t *d_dst |
1744 | = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)]; |
1745 | diff_weights_data_t *d_bias |
1746 | = rb->get_local_ptr( |
1747 | ti->ithr, ti->diff_bias, reducer_bia_scratchpad) |
1748 | + b_job_loc * rb->balancer().job_size_; |
1749 | |
1750 | if (img == img_start) |
1751 | for (int o = 0; o < jcp.oc_block; ++o) |
1752 | d_bias[o] = 0; |
1753 | for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) { |
1754 | PRAGMA_OMP_SIMD() |
1755 | for (int o = 0; o < max_oc; ++o) |
1756 | d_bias[o] += d_dst[o]; |
1757 | d_dst += is_ddst_layout_nxc ? jcp.ngroups * jcp.oc |
1758 | : jcp.oc_block; |
1759 | } |
1760 | |
1761 | nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc); |
1762 | } |
1763 | } |
1764 | |
1765 | if (dnnl_thr_syncable()) |
1766 | rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad); |
1767 | } |
1768 | |
1769 | template <data_type_t src_type, data_type_t diff_dst_type, |
1770 | data_type_t diff_weights_type> |
1771 | void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1772 | diff_weights_type>::reduce_diff_bias(const thread_info_t *ti) const { |
1773 | const auto &jcp = kernel_->jcp; |
1774 | |
1775 | const size_t wei_size = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) |
1776 | * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw * jcp.kd; |
1777 | const int bia_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block); |
1778 | const diff_weights_data_t *diff_bias_ws |
1779 | = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size; |
1780 | |
1781 | if (dnnl_thr_syncable() && nthr_mb_ > 1) dnnl_thr_barrier(); |
1782 | |
1783 | if (ti->ithr == 0) { |
1784 | for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) { |
1785 | acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size); |
1786 | diff_bias_ws += bia_size; |
1787 | } |
1788 | } |
1789 | } |
1790 | |
1791 | template <data_type_t src_type, data_type_t diff_dst_type, |
1792 | data_type_t diff_weights_type> |
1793 | void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type, |
1794 | diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx) |
1795 | const { |
1796 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1797 | |
1798 | if (dnnl_thr_syncable() && nthr_mb_ > 1) { |
1799 | simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( |
1800 | key_conv_wei_bia_reduction_bctx)); |
1801 | } |
1802 | |
1803 | const auto reducer_bia_scratchpad |
1804 | = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia); |
1805 | auto rb = this->reducer_bias_.get(); |
1806 | rb->init(reducer_bia_scratchpad); |
1807 | } |
1808 | |
template <data_type_t src_type, data_type_t diff_dst_type,
        data_type_t diff_weights_type>
void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
        diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx)
        const {
    // Top-level driver: dispatch per-harness compute + reduction over all
    // threads, then copy the padded bias tail back to the user buffer.
    prepare_scratchpad_data(ctx);

#if DNNL_THR_SYNC == 1
    // Synchronizable threading: compute and reduce inside a single parallel
    // region, relying on the barriers inside the reduce_* helpers.
    parallel(nthr_, [&](const int ithr, const int nthr) {
        assert(nthr_ == nthr);

        thread_info_t thread_info(this, ctx, ithr);

        switch (pd()->jcp_.harness) {
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
                if (pd()->with_bias()) reduce_diff_bias(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
                if (pd()->with_bias()) reduce_diff_bias(&thread_info);
                break;
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            case harness_nxc:
                compute_diff_weights_nxc(&thread_info);
                if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type" );
        }
    });
#else
    // Non-syncable threading: no in-region barriers are available, so the
    // compute pass and the reduction pass run in two separate parallel
    // regions (the region boundary acts as the synchronization point).
    parallel(nthr_, [&](const int ithr, const int nthr) {
        thread_info_t thread_info(this, ctx, ithr);
        switch (pd()->jcp_.harness) {
            case harness_nxc:
                compute_diff_weights_nxc(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                break;
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (pd()->with_bias()) compute_diff_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type" );
        }
    });

    // Second pass: reduce the per-thread partial results.
    parallel(nthr_, [&](const int ithr, const int nthr) {
        thread_info_t thread_info(this, ctx, ithr);
        if (nthr_mb_ > 1) {
            switch (pd()->jcp_.harness) {
                case harness_mb_reduction:
                case harness_2d_reduction:
                    reduce_diff_weights(&thread_info);
                    break;
                case harness_nxc:
                case harness_3d_reduction:
                    reduce_diff_weights_3d(&thread_info);
                    break;
                default: assert(!"Invalid harness type" );
            }
        }
        if (pd()->with_bias()) {
            switch (pd()->jcp_.harness) {
                case harness_2d_reduction:
                case harness_3d_reduction:
                    reduce_diff_bias(&thread_info);
                    break;
                case harness_nxc:
                case harness_mb_reduction: {
                    // These harnesses use the generic bias reducer; its
                    // lock-free reduction is safe here because the compute
                    // pass already completed in the previous region.
                    auto rb = this->reducer_bias_.get();
                    assert(nthr == rb->balancer().nthr_);
                    if (rb->balancer().ithr_njobs(ithr) == 0) return;
                    const auto reducer_bia_scratchpad
                            = memory_tracking::grantor_t(
                                    thread_info.scratchpad, prefix_reducer_bia);
                    rb->reduce_nolock(thread_info.ithr, thread_info.diff_bias,
                            reducer_bia_scratchpad);
                } break;
                default: assert(!"Invalid harness type" );
            }
        }
    });
#endif

    /* TODO: put that into compute_diff_bias() */
    // If oc is not a multiple of oc_block, the bias was accumulated into a
    // padded scratch buffer; copy only the valid channels per group back
    // into the user diff_bias.
    auto &jcp = pd()->jcp_;
    if (pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0) {
        auto diff_bias = ctx.get_scratchpad_grantor()
                                 .template get<const diff_weights_data_t>(
                                         key_conv_padded_bias);
        auto diff_bias_in
                = CTX_OUT_MEM(diff_weights_data_t *, DNNL_ARG_DIFF_BIAS);
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}
1922 | |
// Explicit instantiation for the all-f32 configuration (src, diff_dst and
// diff_weights data types default to f32).
template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;
1924 | |
1925 | } // namespace x64 |
1926 | } // namespace cpu |
1927 | } // namespace impl |
1928 | } // namespace dnnl |
1929 | |
1930 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1931 | |