/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx2_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

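// Helper macro: pick the memory descriptor offset overload that matches the
// tensor dimensionality (3D -> (n, c, w), 4D -> (n, c, h, w),
// 5D -> (n, c, d, h, w)).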
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
/* convolution forward */

void jit_avx2_1x1_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = kernel_->jcp;
    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

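    // When the number of output channels is not a multiple of the block size,
    // copy the bias into a zero-padded scratchpad buffer so the kernel can
    // safely read a full block.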
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });

    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
}

void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr,
        const int nthr, const data_t *src, const data_t *weights,
        const data_t *bias, const data_t *weights_dw, const data_t *bias_dw,
        data_t *dst, const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
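    // rtus ("reduce to unit stride") scratchpad: for strided 1x1 convolutions
    // the rtus driver first gathers the strided source rows into a dense
    // per-thread buffer so the kernel itself only sees unit strides.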
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    const int ndims = dst_d.ndims();

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    auto p = jit_1x1_conv_call_s();
    auto rp = rtus_driver_t<avx2>::call_params_t();

    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare variables needed for the fused dw conv.
    data_t *pbuf;
    size_t row_offset;
    const int nb_buffer = jcp.nb_load_blocking;
    auto jcp_dw = pd()->jcp_dw_;
    std::vector<data_t *> addrs;
    jit_generator *dw_jit_ker = nullptr;

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

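    // init_bcast maps a linear work index onto (mb, group, spatial block)
    // coordinates and computes the spatial ("bcast") chunk the kernel will
    // process in this iteration.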
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);

        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int os_2d = os % (jcp.oh * jcp.ow);
        od = os / (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        // The binary post-op injector may overwrite zero-padded areas, so
        // proper output masking needs to be performed based on the exact
        // number of channels.
        const auto oc = jcp.with_binary ? jcp.oc_without_padding : jcp.oc;
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, oc, load_step * jcp.oc_block);
    };

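    // ker_1x1 runs the 1x1 JIT kernel for one (output block, input block)
    // pair. When a depthwise convolution is fused, the 1x1 result is written
    // into the intermediate row buffer (pbuf) instead of the final dst.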
    auto ker_1x1 = [&](int ocb, int icb, int ocb_start, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;

        p.output_data = jcp.with_dw_conv
                ? pbuf + (oh % jcp_dw->kh) * row_offset
                : &dst[data_blk_off(dst_d, n, oc_off_idx, od, oh, ow)];
        p.bias_data
                = &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)];

        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking * jcp.ic_block);
        rp.icb = p.reduce_dim;

        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;

        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);

            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }

            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        p.oc_l_off = ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

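    // conv_1x1 walks the assigned [bcast_start, bcast_end) x
    // [ocb_start, ocb_end) tile: for every spatial chunk and output-channel
    // block it accumulates over all input-channel blocks.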
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n {0}, g {0}, bcast_step, od, oh, ow, id, ih, iw;
            init_bcast(
                    iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    ker_1x1(ocb, icb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                }
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    };

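    // ker_dw applies the fused depthwise kernel to one output row (dw_oh),
    // reading the kh source rows previously produced by the 1x1 convolution
    // from the ring buffer and accounting for top/bottom padding overflow.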
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = nstl::max(dw_oh * jcp_dw->stride_h - jcp_dw->t_pad, 0);

        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw->kh) * row_offset;

        const ptrdiff_t wch_stride = (is_src_layout_nxc ? 1 : jcp_dw->iw)
                * jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        const auto ocb_end = ocb_start + load_step;
        const int dil_h = jcp_dw->dilate_h + 1;
        const int str_h = jcp_dw->stride_h;
        const int ch_num = jcp_dw->nb_ch_blocking;
        const int ow = 0;
        const int kw = 0;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw->nb_ch_blocking) {

            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw->t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw->ih,
                              (int)(dw_oh * str_h + (jcp_dw->kh - 1) * dil_h
                                      - jcp_dw->t_pad + 1))
                    - jcp_dw->ih;

            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw->kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw->ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw->ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw->nb_ch) - ch)
                    * jcp_dw->ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*dw_jit_ker)(&par_conv_dw);

            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += wch_stride;
        }
    };

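    // conv_dw drives the fused 1x1 + depthwise pipeline: for each depthwise
    // output row it first computes the required 1x1 rows into the kh-row
    // ring buffer (pbuf) and then immediately consumes them with ker_dw.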
    auto conv_dw = [&]() {
        // Set variables
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto dw_conv_buffer
                = dw_scratchpad.get<data_t>(key_fusion_inout_buffer);
        dw_jit_ker = kernel_dw_avx2 ? kernel_dw_avx2->ker()
                                    : kernel_dw_sse41->ker();

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, 1);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // Convert the depthwise spatial index to the 1x1 spatial
                // index (needed when jcp.oh != jcp_dw->oh).
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        int start {0}, end {0};
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        balance211(work_amount, nthr, ithr, start, end);
        conv_1x1(start, end, 0, jcp.nb_load);
    }
}

/* convolution backward w.r.t. data */

void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? ctx.get_scratchpad_grantor().get<data_t>(key_conv_rtus_space)
            : nullptr;

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);
    const int ndims = diff_dst_d.ndims();

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

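    // Per-thread worker: the outer loop is over diff_src channel blocks
    // (load dimension), the middle loop over the spatial/minibatch chunks
    // assigned to this thread, and the inner loop reduces over diff_dst
    // channel blocks.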
    auto ker = [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        int load_step = 0;
        for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
            load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                    jcp.nb_load_blocking_max);

            p.load_dim = this_block_size(
                    icb * jcp.ic_block, jcp.ic, load_step * jcp.ic_block);
            rp.icb = p.load_dim;

            int bcast_step;
            for (int iwork = start; iwork < end; iwork += bcast_step) {
                int n {0}, g {0}, osb {0};
                nd_iterator_init(
                        iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast);

                bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                        jcp.nb_bcast_blocking_max);
                bcast_step = nstl::min(bcast_step, end - iwork);

                const int os = osb * os_block;
                p.bcast_dim
                        = this_block_size(os, jcp.os, bcast_step * os_block);
                rp.os = p.bcast_dim;

                const int od = os / (jcp.oh * jcp.ow);
                const int os_2d = os % (jcp.oh * jcp.ow);
                const int oh = os_2d / jcp.ow;
                const int ow = os_2d % jcp.ow;
                const int id = od * stride_d;
                const int ih = oh * stride_h;
                const int iw = ow * stride_w;
                rp.iw_start = iw;

                const bool is_dsrc_layout_nxc = utils::one_of(jcp.src_tag,
                        format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g * nb_ic + icb;
                rp.src = diff_src
                        + data_blk_off(diff_src_d, n, ic_off_idx, id, ih, iw);
                if (pd()->rtus_.reduce_src_) {
                    rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_;
                    p.output_data = rp.ws;
                } else
                    p.output_data = rp.src;

                for (int ocb = 0; ocb < jcp.nb_reduce;
                        ocb += jcp.nb_reduce_blocking) {
                    const bool is_ddst_layout_nxc
                            = utils::one_of(jcp.dst_tag, format_tag::nwc,
                                    format_tag::nhwc, format_tag::ndhwc);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g * jcp.oc + ocb * jcp.oc_block
                            : g * nb_oc + ocb;
                    size_t diff_dst_off = data_blk_off(
                            diff_dst_d, n, oc_off_idx, od, oh, ow);
                    p.bcast_data = &diff_dst[diff_dst_off];

                    p.load_data = &weights[pd()->with_groups()
                                    ? weights_d.blk_off(g, ocb, icb)
                                    : weights_d.blk_off(ocb, icb)];

                    p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                    p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
                            nb_oc_blocking * jcp.oc_block);

                    (*kernel_)(&p);
                }

                if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
            }
        }
    };

    parallel(jcp.nthr, ker);
}

/* convolution backward w.r.t. weights */

status_t jit_avx2_1x1_convolution_bwd_weights_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx2_1x1_conv_kernel_f32(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
    CHECK(kernel_->create_kernel());

    CHECK(safe_ptr_assign(reducer_weights_,
            new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_)));
    CHECK(reducer_weights_->create_kernel());

    CHECK(safe_ptr_assign(reducer_bias_,
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
    if (pd()->with_bias()) {
        assert(reducer_weights_->balancer().nthr_
                == reducer_bias_->balancer().nthr_);
        CHECK(reducer_bias_->create_kernel());
    }

    CHECK(init_rtus_driver<avx2>(this));
    return status::success;
}

void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    const int ndims = diff_dst_d.ndims();
    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;
    const int bcast_work = div_up(nb_ic, nb_ic_blocking);

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;
    const int load_work = div_up(nb_oc, nb_oc_blocking);

    const int sp_dim = jcp.reduce_dim;
    const int mb_sp_work = jcp.mb * sp_dim;

    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_ddst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

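    // oc_ic_sp_loop accumulates one weights-diff block: for every
    // (oc block, ic block) pair it reduces over the spatial range of one
    // image, writing partial results to the destination buffer chosen by the
    // caller (diff_weights directly or a per-thread reducer scratchpad).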
    auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
                                 data_t *store_to, size_t store_to_ld,
                                 const data_t *diff_dst, const data_t *src,
                                 int ithr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        p.output_stride = store_to_ld * sizeof(float);

        int oc_b_step = 0;
        for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
            oc_b_step = step(nb_oc_blocking, nb_oc_blocking - oc_b,
                    jcp.nb_load_blocking_max);
            p.load_dim = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, oc_b_step * jcp.oc_block);

            int ic_b_step = 0;
            for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
                ic_b_step = step(nb_ic_blocking, nb_ic_blocking - ic_b,
                        jcp.nb_bcast_blocking_max);
                p.bcast_dim = this_block_size(
                        ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
                rp.icb = p.bcast_dim;

                p.output_data = store_to + oc_b * store_to_ld
                        + ic_b * jcp.ic_block * jcp.oc_block;

                /* spatial reduction */
                int sp_step = 0;
                for (int sp = sp_start; sp < sp_end; sp += sp_step) {
                    sp_step = step(jcp.nb_reduce_blocking, sp_end - sp,
                            jcp.nb_reduce_blocking_max);
                    p.reduce_dim = sp_step * jcp.reduce_block;
                    rp.os = p.reduce_dim;

                    p.first_last_flag = sp == sp_start && first_image
                            ? FLAG_REDUCE_FIRST
                            : 0;

                    p.load_data = diff_dst
                            + (oc_b * jcp.reduce_dim + sp)
                                    * (is_ddst_layout_nxc ? jcp.oc
                                                          : jcp.oc_block);

                    if (pd()->rtus_.reduce_src_) {
                        const int od = sp / (jcp.oh * jcp.ow);
                        const int sp_2d = sp % (jcp.oh * jcp.ow);
                        const int oh = sp_2d / jcp.ow;
                        const int ow = sp_2d % jcp.ow;

                        const int id = od * stride_d;
                        const int ih = oh * stride_h;
                        const int iw = ow * stride_w;
                        rp.iw_start = iw;

                        rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_
                                + (ic_b * jcp.is + sp) * jcp.ic_block;
                        size_t src_offset
                                = iw * src_d.blocking_desc().strides[ndims - 1];
                        if (ndims > 3)
                            src_offset += ih
                                    * src_d.blocking_desc().strides[ndims - 2];
                        if (ndims == 5)
                            src_offset += id
                                    * src_d.blocking_desc().strides[ndims - 3];

                        rp.src = src + src_offset;
                        if (oc_b == 0) (*rtus_driver_)(&rp);

                        p.bcast_data = rp.ws;
                    } else
                        p.bcast_data = src
                                + (ic_b * jcp.reduce_dim + sp)
                                        * (is_src_layout_nxc ? jcp.ic
                                                             : jcp.ic_block);

                    (*kernel_)(&p);
                }
            }
        }
    };

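    // maybe_zero_icpad zeroes the weights that correspond to the padded tail
    // of the input-channel dimension so that padded entries never carry stale
    // values into the final diff_weights.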
    auto maybe_zero_icpad = [&](const int g_start, const int g_end,
                                    const int ocb_start, const int ocb_end) {
        // write zeros to IC padded region.
        const int ic_tail = jcp.ic_without_padding % jcp.ic_block;
        if (is_ddst_layout_nxc && ic_tail != 0) {
            for_(int g = g_start; g < g_end; ++g)
            for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) {
                const int z_icb = nb_ic - 1;
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, z_ocb, z_icb)
                        : diff_weights_d.blk_off(z_ocb, z_icb);
                data_t *z_wei = diff_weights + off + ic_tail * jcp.oc_block;
                const int zero_work
                        = (nb_ic * jcp.ic_block - jcp.ic_without_padding)
                        * jcp.oc_block;
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < zero_work; ++o) {
                    z_wei[o] = 0;
                }
            }
        }
    };

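    // Per-thread weights-diff worker: each thread owns a set of independent
    // (group, oc block, ic block) jobs and accumulates its share of the
    // (minibatch x spatial) reduction for each of them; partial results are
    // later combined by the weights reducer.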
    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_njobs = rw->balancer().ithr_njobs(ithr);
        if (w_njobs == 0) return;

        /* setup: independent work (oc, ic) */
        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        int g {0}, load_i {0}, bcast_i {0};
        nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
                bcast_i, bcast_work);

        /* setup: reduction work (mb, sp) */
        int mb_sp_start {0}, mb_sp_end {0};
        balance211(mb_sp_work, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
        int img_start {0}, sp_start {0};
        nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);

        /* independent work */
        for (int iwork = 0; iwork < w_njobs; ++iwork) {
            const int oc_b = nb_oc_blocking * load_i;
            const int ic_b = nb_ic_blocking * bcast_i;

            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * nb_oc + oc_b;
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * nb_ic + ic_b;

            data_t *store_to;
            size_t store_to_ld;

            if (rw->balancer().nthr_per_group_ == 1) {
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, oc_b, ic_b)
                        : diff_weights_d.blk_off(oc_b, ic_b);
                store_to = &diff_weights[off];
                store_to_ld = rnd_up(jcp.ic, jcp.ic_block) * jcp.oc_block;
            } else {
                const size_t off = (size_t)iwork * rw->balancer().job_size_;
                store_to
                        = rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
                store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
            }

            /* reduction work */
            int img = img_start;
            int sp = sp_start;
            int sp_step = 0;
            for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) {
                sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);

                const bool first_image = img == img_start;
                if (is_ddst_layout_nxc && first_image
                        && rw->balancer().nthr_per_group_ > 1) {
                    // Zero-pad the scratchpad when nthr > 1 (since most threads
                    // write to scratchpad) so that zero-padding is maintained
                    // for the final output after reduction
                    array_set(rw->get_local_ptr(ithr, reducer_wei_scratchpad)
                                    + iwork * rw->balancer().job_size_,
                            0, rw->balancer().job_size_);
                }
                oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
                        store_to_ld,
                        &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)],
                        &src[src_d.blk_off(img, ic_off_idx)], ithr);

                sp = 0;
                img += 1;
            }

            if (rw->balancer().nthr_per_group_ == 1
                    && bcast_i + 1 >= bcast_work)
                maybe_zero_icpad(g, g + 1, oc_b,
                        nstl::min(nb_oc, oc_b + nb_oc_blocking));

            nd_iterator_step(
                    g, jcp.ngroups, load_i, load_work, bcast_i, bcast_work);
        }

        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

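    // Per-thread bias-diff worker: for each assigned (group, oc block) job it
    // sums diff_dst over the spatial dimension of its share of the minibatch;
    // partial sums are later combined by the bias reducer.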
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g * nb_oc + ocb;

                const data_t *d_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                const int spatial_shift
                        = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block;
                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);
                for (int hw = 0; hw < jcp.os; ++hw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += spatial_shift;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    if (dnnl_thr_syncable()) {
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (is_bias_padded) {
        assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1));
        const int padded_stride = utils::rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl