/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx512_common_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
/* convolution forward */

template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward(const exec_ctx_t &ctx) const {
    const auto &jcp = kernel_->jcp;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const dst_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->dw_conv_pd_
            ? binary_injector::prepare_binary_args(
                    pd()->dw_conv_pd_->jcp_.post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    auto scratchpad = ctx.get_scratchpad_grantor();

    if (pd()->wants_padded_bias()) {
        auto padded_bias
                = scratchpad.template get<dst_data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type,
        dst_type>::execute_forward_thr(const int ithr, const int nthr,
        const src_data_t *src, const wei_data_t *weights,
        const dst_data_t *bias, const wei_data_t *weights_dw,
        const dst_data_t *bias_dw, dst_data_t *dst,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
            : nullptr;

    const int ndims = src_d.ndims();
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

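    // step() clamps a blocking step to the remaining amount of work: it
    // returns the default step unless fewer than tail_step items remain.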
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_core>::call_params_t();

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    // Begin: declare Variables needed for dw conv.
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);
    dst_data_t *pbuf;
    size_t row_offset;
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<dst_data_t *> addrs;
    // End

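    // init_bcast() decodes a linear work index into (mb, group, spatial
    // block), computes the corresponding output/input coordinates, and sets
    // the bcast (spatial) dimension for the kernel call.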
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

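    // init_load() picks how many output-channel blocks to process per kernel
    // call and sets the load dimension, capped by the unpadded OC.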
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        const auto max_oc
                = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding);
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block);
    };

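    // init_reduce() sets the reduction (input-channel) range for the current
    // IC block and flags whether it is the first/last chunk of the reduction.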
    auto init_reduce = [&](int icb) {
        const int nb_ic_blocking_step
                = nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block);
        rp.icb = p.reduce_dim;
    };

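    // ker_1x1() issues a single 1x1 kernel call for the current
    // (mb, group, ocb, icb, spatial block): it computes the src/weights/dst
    // pointers, runs the rtus driver first when the strided source has to be
    // gathered into a dense buffer, and redirects the output into the row
    // buffer (pbuf) when a depthwise convolution is fused.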
    auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;
        const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow);

        p.output_data = jcp.with_dw_conv
                ? pbuf + (oh % pd()->dw_conv_pd_->jcp_.kh) * row_offset
                : &dst[dst_off];
        p.bias_data = bias
                ? &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)]
                : nullptr;

        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];
        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);
            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        p.dst_l_off = dst_off;
        p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

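    // conv_1x1() drives the 1x1 convolution over the assigned
    // [bcast_start, bcast_end) x [ocb_start, ocb_end) tile using the loop
    // order selected at configuration time
    // (r = reduce/IC, l = load/OC, b = bcast/spatial).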
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;

        if (jcp.loop_order == loop_rlb) {
            for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                init_reduce(icb);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    int iwork = bcast_start;
                    while (iwork < bcast_end) {
                        int n {0}, g {0}, bcast_step {0}, od {0}, oh {0},
                                ow {0}, id {0}, ih {0}, iw {0};
                        init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh,
                                ow, id, ih, iw);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                        iwork += bcast_step;
                    }
                    ocb += load_step;
                }
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                init_reduce(icb);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    int ocb = ocb_start;
                    while (ocb < ocb_end) {
                        int load_step;
                        init_load(ocb, ocb_end, load_step);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                        ocb += load_step;
                    }
                    iwork += bcast_step;
                }
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                        init_reduce(icb);
                        ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih,
                                iw);
                    }
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };

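    // ker_dw() runs the fused depthwise kernel for one output row, reading
    // its kh input rows from the circular row buffer filled by the 1x1 pass.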
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;
        int oh_1x1 = nstl::max(dw_oh * jcp_dw.stride_h - jcp_dw.t_pad, 0);

        for (int i = 0; i < jcp_dw.kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw.kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw.iw)
                * jcp_dw.nb_ch_blocking * jcp_dw.ch_block;
        const int dil_h = jcp_dw.dilate_h + 1;
        const int str_h = jcp_dw.stride_h;
        const int ch_num = jcp_dw.nb_ch_blocking;
        const int ow = 0;
        const int kw = 0;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw.nb_ch_blocking) {

            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw.t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw.ih,
                              (int)(dw_oh * str_h + (jcp_dw.kh - 1) * dil_h
                                      - jcp_dw.t_pad + 1))
                    - jcp_dw.ih;

            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw.kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw.ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw.ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw.nb_ch) - ch)
                    * jcp_dw.ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw.ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            for (int i = 0; i < jcp_dw.kh; ++i)
                addrs[i] += wch_stride;
        }
    };

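    // conv_dw() interleaves the two fused convolutions row by row: for each
    // depthwise output row it first computes only the 1x1 output rows that
    // row consumes (kept in a per-thread kh-row circular buffer) and then
    // calls ker_dw() to produce the final destination row.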
    auto conv_dw = [&]() {
        // Set variables
        auto dw_conv_buffer
                = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer);
        auto &jcp_dw = pd()->dw_conv_pd_->jcp_;

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw.kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw.kh;
        addrs.resize(jcp_dw.kh);

        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw.oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n {0}, g {0}, oh_dw {0};
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw.oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range = oh_dw * jcp_dw.stride_h - jcp_dw.t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw.kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // Convert the dw output row range to the corresponding 1x1
                // spatial (bcast) range; note jcp.oh may differ from jcp_dw.oh.
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {

        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
                ocb_start, ocb_end, jcp.load_grp_count);

        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}

REG_AVX512_ISA(template struct jit_avx512_common_1x1_convolution_fwd_t<
        data_type::f32>);
/* convolution backward wrt data */

template <data_type_t diff_dst_type, data_type_t wei_type,
        data_type_t diff_src_type>
void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
        diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? ctx.get_scratchpad_grantor().template get<diff_src_data_t>(
                    key_conv_rtus_space)
            : nullptr;

    const int ndims = diff_src_d.ndims();

    assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

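    // Threads split the work in 2D: (mb x groups x spatial blocks) along one
    // axis and IC (load) blocks along the other; OC is the reduction
    // dimension and is accumulated inside the kernel.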
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx512_core>::call_params_t();

        int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load,
                icb_start, icb_end, jcp.load_grp_count);

        bool reduce_outer
                = (jcp.loop_order == loop_rbl || jcp.loop_order == loop_rlb);
        int nboc_outer = reduce_outer ? nb_oc : 1;
        int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;

        int nboc_inner = reduce_outer ? 1 : nb_oc;
        int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;
        const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic);

        for (int ocb_outer = 0; ocb_outer < nboc_outer;
                ocb_outer += ocb_outer_step) {
            size_t cur_ocb_outer
                    = nstl::min(ocb_outer + ocb_outer_step, nboc_outer)
                    - ocb_outer;

            int load_step = 0;
            for (int icb = icb_start; icb < icb_end; icb += load_step) {
                load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                        jcp.nb_load_blocking_max);

                p.load_dim = this_block_size(
                        icb * jcp.ic_block, max_ic, load_step * jcp.ic_block);
                rp.icb = p.load_dim;

                int bcast_step;
                for (int iwork = bcast_start; iwork < bcast_end;
                        iwork += bcast_step) {
                    int n {0}, g {0}, osb {0};
                    nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
                            jcp.nb_bcast);

                    bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                            jcp.nb_bcast_blocking_max);
                    bcast_step = nstl::min(bcast_step, bcast_end - iwork);

                    const int os = osb * os_block;
                    p.bcast_dim = this_block_size(
                            os, jcp.os, bcast_step * os_block);
                    rp.os = p.bcast_dim;

                    const int od = os / (jcp.oh * jcp.ow);
                    const int os_2d = os % (jcp.oh * jcp.ow);
                    const int oh = os_2d / jcp.ow;
                    const int ow = os_2d % jcp.ow;
                    const int id = od * stride_d;
                    const int ih = oh * stride_h;
                    const int iw = ow * stride_w;
                    rp.iw_start = iw;
                    const bool is_dsrc_layout_nxc
                            = utils::one_of(jcp.src_tag, format_tag::nwc,
                                    format_tag::nhwc, format_tag::ndhwc);
                    const int ic_off_idx = is_dsrc_layout_nxc
                            ? g * jcp.ic + icb * jcp.ic_block
                            : g * nb_ic + icb;
                    rp.src = diff_src
                            + data_blk_off(
                                    diff_src_d, n, ic_off_idx, id, ih, iw);
                    if (pd()->rtus_.reduce_src_) {
                        rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_;
                        p.output_data = rp.ws;
                    } else
                        p.output_data = rp.src;

                    for (int ocb_inner = 0; ocb_inner < nboc_inner;
                            ocb_inner += ocb_inner_step) {
                        int cur_ocb_inner
                                = nstl::min(ocb_inner + ocb_inner_step,
                                          nboc_inner)
                                - ocb_inner;

                        int ocb = reduce_outer ? ocb_outer : ocb_inner;
                        int nb_oc_blocking_step
                                = reduce_outer ? cur_ocb_outer : cur_ocb_inner;
                        const bool is_ddst_layout_nxc
                                = utils::one_of(jcp.dst_tag, format_tag::nwc,
                                        format_tag::nhwc, format_tag::ndhwc);
                        const int oc_off_idx = is_ddst_layout_nxc
                                ? g * jcp.oc + ocb * jcp.oc_block
                                : g * nb_oc + ocb;
                        size_t diff_dst_off = data_blk_off(
                                diff_dst_d, n, oc_off_idx, od, oh, ow);
                        p.bcast_data = &diff_dst[diff_dst_off];

                        p.load_data = &weights[pd()->with_groups()
                                        ? weights_d.blk_off(g, ocb, icb)
                                        : weights_d.blk_off(ocb, icb)];

                        p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                        p.reduce_dim = this_block_size(ocb * jcp.oc_block,
                                jcp.oc, nb_oc_blocking_step * jcp.oc_block);

                        (*kernel_)(&p);
                    }
                    if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
                }
            }
        }
    });
}

REG_AVX512_ISA(template struct jit_avx512_common_1x1_convolution_bwd_data_t<
        data_type::f32>);

/* convolution backward wrt weights */

#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

status_t jit_avx512_common_1x1_convolution_bwd_weights_t ::init(
        engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx512_common_1x1_conv_kernel(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
    CHECK(safe_ptr_assign(
            acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
    CHECK(safe_ptr_assign(reducer_bias_,
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
    CHECK(kernel_->create_kernel());
    CHECK(acc_ker_->create_kernel());
    CHECK(reducer_bias_->create_kernel());

    CHECK(init_rtus_driver<avx512_core>(this));
    return status::success;
}

void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;

    const auto scratchpad = ctx.get_scratchpad_grantor();

    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;
    const bool is_bias_padded
            = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block != 0;

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;
    auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);

    const int ndims = src_d.ndims();
    const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block)
            * rnd_up(jcp.ic, jcp.ic_block);

    simple_barrier::ctx_t reduction_barrier;
    simple_barrier::ctx_init(&reduction_barrier);

    const auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;

    const int sp_nb = jcp.nb_reduce;
    const int mb_sp_work = jcp.mb * sp_nb;

    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_ddst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

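    // With an nxc diff_dst layout and an IC tail, the kernel never writes the
    // padded tail of the last IC block of diff_weights, so zero it explicitly.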
    auto maybe_zero_icpad = [&](const int g_start, const int g_end,
                                    const int ocb_start, const int ocb_end) {
        // write zeros to IC padded region.
        const int ic_tail = jcp.ic_without_padding % jcp.ic_block;
        if (is_ddst_layout_nxc && ic_tail != 0) {
            for_(int g = g_start; g < g_end; ++g)
            for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) {
                const int z_icb = nb_ic - 1;
                const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb)
                        + ic_tail * jcp.oc_block;
                data_t *z_wei = diff_weights + off;
                const int zero_work
                        = (nb_ic * jcp.ic_block - jcp.ic_without_padding)
                        * jcp.oc_block;
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < zero_work; ++o) {
                    z_wei[o] = 0;
                }
            }
        }
    };

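    // ker() is the main weights-gradient worker. Threads are laid out over
    // (mb x spatial, groups, OC blocks, IC blocks); threads with ithr_mb > 0
    // accumulate into a private wei_reduction buffer, and the per-thread
    // partial results are summed into diff_weights afterward.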
    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);

        const int ithr_ic_b = ithr % jcp.nthr_ic_b;
        const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
        const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
        const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g;

        /* reduction dimension */
        int mb_sp_b_start {0}, mb_sp_b_end {0};
        balance211(
                mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end);

        /* independent dimensions */
        int g_start {0}, oc_b_start {0}, ic_b_start {0};
        int g_end {0}, oc_b_end {0}, ic_b_end {0};

        balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
        balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end);
        balance211(
                jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end);

        const int g_work = g_end - g_start;
        const int oc_b_work = oc_b_end - oc_b_start;
        const int ic_b_work = ic_b_end - ic_b_start;
        const bool cache_aliasing
                = (jcp.ic * jcp.ngroups * sizeof(float)) % 1024 == 0;
        int reduce_step = jcp.nb_reduce_blocking;
        int reduce_step_max = jcp.nb_reduce_blocking_max;
        if (is_src_layout_nxc && cache_aliasing) {
            // Experiments show 4 is a magic number with the tested shapes.
            // TODO: maybe tune for shapes with sp_dim%4 != 0
            reduce_step = nstl::min(4, reduce_step);
            reduce_step_max = reduce_step;
        }

        data_t *diff_wei = ithr_mb == 0
                ? diff_weights
                : wei_reduction + (ithr_mb - 1) * wei_size;

        int sp_b_step = 0;
        for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
                mb_sp_b += sp_b_step) {
            int img {0}, sp_b {0};
            nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
            sp_b_step = step(reduce_step,
                    nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
                    reduce_step_max);

            for (int g = g_start; g < g_end; ++g) {
                int load_step = 0;
                int bcast_step = 0;
                for (int ic_b = ic_b_start; ic_b < ic_b_end;
                        ic_b += bcast_step) {
                    if (is_src_layout_nxc && cache_aliasing) {
                        bcast_step = ic_b_work;
                    } else {
                        bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
                                jcp.nb_bcast_blocking_max);
                    }

                    for (int oc_b = oc_b_start; oc_b < oc_b_end;
                            oc_b += load_step) {
                        load_step = step(nb_oc_blocking, oc_b_end - oc_b,
                                jcp.nb_load_blocking_max);
                        const int _ic_b = g * nb_ic + ic_b;
                        const int oc_off_idx = is_ddst_layout_nxc
                                ? g * jcp.oc + oc_b * jcp.oc_block
                                : g * nb_oc + oc_b;

                        data_t *store_to;

                        const size_t off
                                = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                        store_to = diff_wei + off;

                        const int ic_off_idx
                                = (is_src_layout_nxc ? jcp.ic_block : 1)
                                * _ic_b;
                        const data_t *diff_src
                                = &src[src_d.blk_off(img, ic_off_idx)];

                        int sp_b_end = sp_b + sp_b_step;
                        const data_t *pdiff_dst = &diff_dst[diff_dst_d.blk_off(
                                img, oc_off_idx)];
                        const data_t *local_src = diff_src;

                        auto p = jit_1x1_conv_call_s();
                        auto rp = rtus_driver_t<avx512_core>::call_params_t();

                        p.output_stride = utils::rnd_up(jcp.ic, jcp.ic_block)
                                * jcp.oc_block * jcp.typesize_out;

                        p.load_dim = this_block_size(oc_b * jcp.oc_block,
                                jcp.oc, load_step * jcp.oc_block);

                        p.bcast_dim = this_block_size(ic_b * jcp.ic_block,
                                jcp.ic, bcast_step * jcp.ic_block);
                        rp.icb = p.bcast_dim;
                        p.output_data = store_to;

                        p.reduce_dim = sp_b_step * jcp.reduce_block;
                        rp.os = p.reduce_dim;

                        p.first_last_flag = 0
                                | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST
                                                            : 0)
                                | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);

                        int sp = sp_b * jcp.reduce_block;
                        int oc_mult
                                = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block;
                        p.load_data = pdiff_dst + sp * oc_mult;

                        if (pd()->rtus_.reduce_src_) {
                            const int oh = sp / jcp.ow;
                            const int ow = sp % jcp.ow;

                            const int ih = oh * stride_h;
                            const int iw = ow * stride_w;
                            rp.iw_start = iw;

                            rp.ws = rtus_space
                                    + ithr * pd()->rtus_.space_per_thread_
                                    + sp * jcp.ic_block;

                            if (ndims == 3)
                                rp.src = local_src
                                        + iw * src_d.blocking_desc().strides[2];
                            else
                                rp.src = local_src
                                        + ih * src_d.blocking_desc().strides[2]
                                        + iw * src_d.blocking_desc().strides[3];
                            (*rtus_driver_)(&rp);

                            p.bcast_data = rp.ws;
                        } else {
                            int ic_mult
                                    = is_src_layout_nxc ? jcp.ic : jcp.ic_block;
                            p.bcast_data = local_src + sp * ic_mult;
                        }

                        (*kernel_)(&p);
                    }
                }
            }
        }

        if (ithr_mb == 0 && ic_b_end >= jcp.nb_bcast) {
            maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end);
        }

        /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
        if (dnnl_thr_syncable() && jcp.nthr_mb > 1) {
            simple_barrier::barrier(&reduction_barrier, jcp.nthr);
            const int work = g_work * oc_b_work * ic_b_work;
            int start {0}, end {0};
            balance211(work, jcp.nthr_mb, ithr_mb, start, end);
            if (start == end) return;

            for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
                int w = start;
                int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0};
                nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
                        oc_b_work, sub_ic_b_start, ic_b_work);
                while (w < end) {
                    const int g = g_start + sub_g_start;
                    const int oc_b = oc_b_start + sub_oc_b_start;
                    const int ic_b = ic_b_start + sub_ic_b_start;
                    const int ic_to_accumulate
                            = nstl::min(end - w, ic_b_work - sub_ic_b_start)
                            * jcp.ic_block;
                    const int acc_size
                            = this_block_size(ic_b * jcp.ic_block,
                                      jcp.ic_without_padding, ic_to_accumulate)
                            * jcp.oc_block;

                    const size_t off
                            = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                    data_t *d = diff_weights + off;
                    data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;

                    acc_ker_->accumulate(d, s, acc_size);

                    nd_iterator_jump(w, end, sub_g_start, g_work,
                            sub_oc_b_start, oc_b_work, sub_ic_b_start,
                            ic_b_work);
                }
            }
        }
    };

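    // ker_bias() accumulates diff_dst over the spatial dimension into
    // per-thread bias buffers managed by the reducer; partial results are
    // combined by rb->reduce() / rb->reduce_nolock().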
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};

        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(
                b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g * jcp.nb_load + ocb;
                const data_t *d_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];

                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;
                const int sp_shift = is_ddst_layout_nxc ? jcp.ngroups * jcp.oc
                                                        : jcp.oc_block;
                const auto max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);
                if (img == img_start)
                    for (int o = 0; o < 16; ++o)
                        d_bias[o] = 0.;

                for (int os = 0; os < jcp.os; ++os) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += sp_shift;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

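    // With a syncable threading runtime the weight and bias reductions run
    // inside ker()/ker_bias() behind barriers; otherwise the compute pass is
    // done first and the cross-thread reductions run in separate parallel
    // regions below.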
    if (dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            ker(ithr, jcp.nthr);
            if (pd()->with_bias()) ker_bias(ithr, jcp.nthr);
        });
    } else {
        parallel(jcp.nthr, [&](int ithr, int nthr) { ker(ithr, nthr); });
        if (jcp.nthr_mb > 1)
            parallel(jcp.nthr, [&](int ithr, int nthr) {
                assert(nthr == jcp.nthr);

                const int ithr_ic_b = ithr % jcp.nthr_ic_b;
                const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
                const int ithr_g
                        = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
                const int ithr_mb
                        = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g;

                /* independent dimensions */
                int g_start {0}, oc_b_start {0}, ic_b_start {0};
                int g_end {0}, oc_b_end {0}, ic_b_end {0};

                balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
                balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start,
                        oc_b_end);
                balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start,
                        ic_b_end);

                const int g_work = g_end - g_start;
                const int oc_b_work = oc_b_end - oc_b_start;
                const int ic_b_work = ic_b_end - ic_b_start;

                const int work = g_work * oc_b_work * ic_b_work;
                int start {0}, end {0};
                balance211(work, jcp.nthr_mb, ithr_mb, start, end);
                if (start == end) return;

                for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
                    int w = start;
                    int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0};
                    nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
                            oc_b_work, sub_ic_b_start, ic_b_work);
                    while (w < end) {
                        const int g = g_start + sub_g_start;
                        const int oc_b = oc_b_start + sub_oc_b_start;
                        const int ic_b = ic_b_start + sub_ic_b_start;
                        const int ic_to_accumulate
                                = nstl::min(end - w, ic_b_work - sub_ic_b_start)
                                * jcp.ic_block;
                        const int acc_size
                                = this_block_size(ic_b * jcp.ic_block,
                                          jcp.ic_without_padding,
                                          ic_to_accumulate)
                                * jcp.oc_block;

                        const size_t off
                                = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
                        data_t *d = diff_weights + off;
                        data_t *s
                                = wei_reduction + (thr_mb - 1) * wei_size + off;

                        acc_ker_->accumulate(d, s, acc_size);

                        nd_iterator_jump(w, end, sub_g_start, g_work,
                                sub_oc_b_start, oc_b_work, sub_ic_b_start,
                                ic_b_work);
                    }
                }
            });
        if (pd()->with_bias()) {
            parallel(jcp.nthr,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(jcp.nthr, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (is_bias_padded) {
        assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1));
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl