1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | |
24 | #include "cpu/x64/jit_avx2_1x1_convolution.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | |
35 | #define data_blk_off(f, n, c, d, h, w) \ |
36 | ((ndims == 3) ? (f).blk_off(n, c, w) \ |
37 | : ((ndims == 4) ? (f).blk_off(n, c, h, w) \ |
38 | : (f).blk_off(n, c, d, h, w))) |
39 | /* convolution forward */ |
40 | |
41 | void jit_avx2_1x1_convolution_fwd_t::execute_forward( |
42 | const exec_ctx_t &ctx) const { |
43 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
44 | auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS); |
45 | auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS); |
46 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
47 | auto weights_dw = CTX_IN_MEM( |
48 | const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
49 | auto bias_dw = CTX_IN_MEM( |
50 | const data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
51 | const auto post_ops_binary_rhs_arg_vec |
52 | = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); |
53 | const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_ |
54 | ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx, |
55 | pd()->jcp_.post_ops.entry_.size() + 1) |
56 | : std::vector<const void *> {}; |
57 | |
58 | auto scratchpad = ctx.get_scratchpad_grantor(); |
59 | |
60 | const auto &jcp = kernel_->jcp; |
61 | // TODO (Roma): remove this restriction |
62 | assert(jcp.stride_w == 1 && jcp.stride_h == 1); |
63 | |
64 | if (pd()->wants_padded_bias()) { |
65 | auto padded_bias = scratchpad.get<data_t>(key_conv_padded_bias); |
66 | utils::array_copy(padded_bias, bias, jcp.oc_without_padding); |
67 | utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, |
68 | jcp.oc - jcp.oc_without_padding); |
69 | bias = padded_bias; |
70 | } |
71 | |
72 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
73 | execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, |
74 | dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), |
75 | post_ops_binary_rhs_arg_vec_dw.data()); |
76 | }); |
77 | |
78 | if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); |
79 | } |
80 | |
// Per-thread body of the forward pass. Each thread walks its assigned share
// of the (mb, group, spatial-block) x (output-channel-block) space. When a
// depthwise convolution post-op is fused, the 1x1 kernel writes rows into a
// small ring buffer which the depthwise kernel then consumes row by row.
void jit_avx2_1x1_convolution_fwd_t::execute_forward_thr(const int ithr,
        const int nthr, const data_t *src, const data_t *weights,
        const data_t *bias, const data_t *weights_dw, const data_t *bias_dw,
        data_t *dst, const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
    const memory_desc_wrapper dw_bias_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));

    const auto &jcp = kernel_->jcp;
    // Per-thread scratch used when the strided source must be gathered into
    // a dense buffer first (rtus driver); null when no densification needed.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    const int ndims = dst_d.ndims();

    // desc()->strides holds {[d,] [h,] w}; absent dims default to stride 1.
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    const int nb_ic_blocking = jcp.nb_reduce_blocking;

    auto p = jit_1x1_conv_call_s();
    auto rp = rtus_driver_t<avx2>::call_params_t();

    // Override some blocking constants for the fused dw_conv case: the 1x1
    // convolution is then driven one output row (jcp.ow) at a time so its
    // rows can be handed to the depthwise kernel.
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // State shared between conv_1x1 (producer) and ker_dw (consumer) when a
    // depthwise convolution is fused; initialized inside conv_dw().
    data_t *pbuf; // ring buffer of jcp_dw->kh intermediate output rows
    size_t row_offset; // element distance between consecutive rows in pbuf
    const int nb_buffer = jcp.nb_load_blocking;
    auto jcp_dw = pd()->jcp_dw_;
    std::vector<data_t *> addrs; // per-kh-row source pointers for ker_dw
    jit_generator *dw_jit_ker = nullptr;

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_dst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    // Take the default block size unless fewer than tail_step items remain,
    // in which case take the whole remainder at once.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decode the linear bcast-space index iwork into (n, g, spatial block),
    // derive output (od/oh/ow) and input (id/ih/iw) coordinates, and fill
    // the bcast-related fields of the kernel/rtus call params.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);

        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int os_2d = os % (jcp.oh * jcp.ow);
        od = os / (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;
        // 1x1 kernel: the input coordinate is output coordinate * stride.
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Compute the output-channel step for the block starting at ocb and the
    // matching kernel load dimension (masked to real channels).
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        // binary postop injector may override zero-padded areas, so proper
        // output masking needs to be performed based on the exact number of
        // channels
        const auto oc = jcp.with_binary ? jcp.oc_without_padding : jcp.oc;
        p.load_dim = this_block_size(
                ocb * jcp.oc_block, oc, load_step * jcp.oc_block);
    };

    // Invoke the 1x1 kernel for one (ocb, icb) block at the given spatial
    // coordinates; fills the remaining call-param fields of p (and rp).
    auto ker_1x1 = [&](int ocb, int icb, int ocb_start, int n, int g, int od,
                           int oh, int ow, int id, int ih, int iw) {
        const int oc_off_idx = is_dst_layout_nxc
                ? g * jcp.oc + ocb * jcp.oc_block
                : g * nb_oc + ocb;

        // With a fused dw conv the 1x1 result goes to the intermediate row
        // ring buffer (indexed by oh modulo the dw kernel height), not dst.
        p.output_data = jcp.with_dw_conv
                ? pbuf + (oh % jcp_dw->kh) * row_offset
                : &dst[data_blk_off(dst_d, n, oc_off_idx, od, oh, ow)];
        p.bias_data
                = &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)];

        // First/last flags over the reduce (input-channel) dimension control
        // accumulator initialization and the final post-op application.
        p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                | (icb + nb_ic_blocking >= nb_ic ? FLAG_REDUCE_LAST : 0);

        p.reduce_dim = this_block_size(
                icb * jcp.ic_block, jcp.ic, nb_ic_blocking * jcp.ic_block);
        rp.icb = p.reduce_dim;

        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];

        const int ic_off_idx = is_src_layout_nxc
                ? g * jcp.ic + icb * jcp.ic_block
                : g * nb_ic + icb;

        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + (is_src_layout_nxc ? ic_off_idx
                                         : jcp.is * ic_off_idx * jcp.ic_block);

            // The densified source depends only on the bcast coordinates and
            // icb, so the rtus driver runs once per bcast block (first ocb).
            if (ocb == ocb_start) {
                rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);
                (*rtus_driver_)(&rp);
            }

            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw);

        p.oc_l_off = ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    // Loop nest over a bcast range and an output-channel range, accumulating
    // over all reduce (input-channel) blocks for each combination.
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n {0}, g {0}, bcast_step, od, oh, ow, id, ih, iw;
            init_bcast(
                    iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    ker_1x1(ocb, icb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                }
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    };

    // Run the fused depthwise kernel for one dw output row dw_oh, reading
    // its jcp_dw->kh input rows from the intermediate ring buffer.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = nstl::max(dw_oh * jcp_dw->stride_h - jcp_dw->t_pad, 0);

        // Collect the kh row pointers in ring-buffer order.
        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1++) % jcp_dw->kh) * row_offset;

        const ptrdiff_t wch_stride = (is_src_layout_nxc ? 1 : jcp_dw->iw)
                * jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        const auto ocb_end = ocb_start + load_step;
        const int dil_h = jcp_dw->dilate_h + 1;
        const int str_h = jcp_dw->stride_h;
        const int ch_num = jcp_dw->nb_ch_blocking;
        const int ow = 0;
        const int kw = 0;

        for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw->nb_ch_blocking) {

            // Number of input rows of the dw filter window falling into the
            // top (i_t_overflow) / bottom (i_b_overflow) padding.
            const int i_t_overflow
                    = nstl::max(0, (int)(jcp_dw->t_pad - dw_oh * str_h));
            const int i_b_overflow
                    = nstl::max(jcp_dw->ih,
                              (int)(dw_oh * str_h + (jcp_dw->kh - 1) * dil_h
                                      - jcp_dw->t_pad + 1))
                    - jcp_dw->ih;

            // First valid filter row and the count of valid rows.
            const int kh = div_up(i_t_overflow, dil_h);
            const int kh_padding = jcp_dw->kh - div_up(i_t_overflow, dil_h)
                    - div_up(i_b_overflow, dil_h);

            jit_conv_call_s par_conv_dw;

            par_conv_dw.src = addrs.data();

            const size_t ch_step = is_dst_layout_nxc
                    ? jcp_dw->ch_block
                    : dst_d.blk_off(0, 1, 0, 0);
            par_conv_dw.dst
                    = &dst[dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step];

            par_conv_dw.filt
                    = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)];
            if (bias)
                par_conv_dw.bias
                        = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw->ch_block)];

            par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding);

            par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw->nb_ch) - ch)
                    * jcp_dw->ch_block;

            par_conv_dw.oc_l_off = ch * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*dw_jit_ker)(&par_conv_dw);

            // Advance the row pointers to the next channel block.
            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += wch_stride;
        }
    };

    // Fused 1x1 + depthwise driver: for every dw output row, first produce
    // exactly the 1x1 output rows the dw receptive field still needs, then
    // run the dw kernel on the buffered rows.
    auto conv_dw = [&]() {
        // Set variables shared with ker_1x1/ker_dw.
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto dw_conv_buffer
                = dw_scratchpad.get<data_t>(key_fusion_inout_buffer);
        dw_jit_ker = kernel_dw_avx2 ? kernel_dw_avx2->ker()
                                    : kernel_dw_sse41->ker();

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, 1);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // Translate dw spatial coordinates into the 1x1 bcast space
                // (needed because jcp.oh may differ from jcp_dw->oh).
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1 path: evenly split the (mb, group, spatial-block) space
        // and sweep all output-channel blocks within this thread's range.
        int start {0}, end {0};
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
        balance211(work_amount, nthr, ithr, start, end);
        conv_1x1(start, end, 0, jcp.nb_load);
    }
}
370 | |
/* convolution backward w.r.t. data */
372 | |
// Backward-by-data pass: the dst gradient plays the role of the kernel's
// bcast input and the weights are applied with swapped channel roles
// (load dim = input channels, reduce dim = output channels).
void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());

    const auto &jcp = kernel_->jcp;
    // Per-thread dense scratch used by the rtus driver when diff_src is
    // strided; null when not needed.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? ctx.get_scratchpad_grantor().get<data_t>(key_conv_rtus_space)
            : nullptr;

    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1);
    const int ndims = diff_dst_d.ndims();

    // desc()->strides holds {[d,] [h,] w}; absent dims default to stride 1.
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const int nb_ic = jcp.nb_load;
    const int nb_oc = jcp.nb_reduce;
    const int os_block = jcp.bcast_block;
    const int nb_oc_blocking = jcp.nb_reduce_blocking;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    // Take the default block size unless fewer than tail_step items remain,
    // in which case take the whole remainder at once.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Per-thread worker: sweep all input-channel (load) blocks, and within
    // each, this thread's share of the (mb, group, spatial) bcast space.
    auto ker = [&](const int ithr, const int nthr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        int load_step = 0;
        for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
            load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
                    jcp.nb_load_blocking_max);

            p.load_dim = this_block_size(
                    icb * jcp.ic_block, jcp.ic, load_step * jcp.ic_block);
            rp.icb = p.load_dim;

            int bcast_step;
            for (int iwork = start; iwork < end; iwork += bcast_step) {
                int n {0}, g {0}, osb {0};
                nd_iterator_init(
                        iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast);

                bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                        jcp.nb_bcast_blocking_max);
                bcast_step = nstl::min(bcast_step, end - iwork);

                const int os = osb * os_block;
                p.bcast_dim
                        = this_block_size(os, jcp.os, bcast_step * os_block);
                rp.os = p.bcast_dim;

                // Decode the flat spatial offset into (od, oh, ow) and map
                // to diff_src coordinates via the strides.
                const int od = os / (jcp.oh * jcp.ow);
                const int os_2d = os % (jcp.oh * jcp.ow);
                const int oh = os_2d / jcp.ow;
                const int ow = os_2d % jcp.ow;
                const int id = od * stride_d;
                const int ih = oh * stride_h;
                const int iw = ow * stride_w;
                rp.iw_start = iw;

                const bool is_dsrc_layout_nxc = utils::one_of(jcp.src_tag,
                        format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
                const int ic_off_idx = is_dsrc_layout_nxc
                        ? g * jcp.ic + icb * jcp.ic_block
                        : g * nb_ic + icb;
                rp.src = diff_src
                        + data_blk_off(diff_src_d, n, ic_off_idx, id, ih, iw);
                // With non-unit strides the kernel writes a dense scratch
                // block; the rtus driver scatters it into diff_src below.
                if (pd()->rtus_.reduce_src_) {
                    rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_;
                    p.output_data = rp.ws;
                } else
                    p.output_data = rp.src;

                // Accumulate over all output-channel (reduce) blocks.
                for (int ocb = 0; ocb < jcp.nb_reduce;
                        ocb += jcp.nb_reduce_blocking) {
                    const bool is_ddst_layout_nxc
                            = utils::one_of(jcp.dst_tag, format_tag::nwc,
                                    format_tag::nhwc, format_tag::ndhwc);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g * jcp.oc + ocb * jcp.oc_block
                            : g * nb_oc + ocb;
                    size_t diff_dst_off = data_blk_off(
                            diff_dst_d, n, oc_off_idx, od, oh, ow);
                    p.bcast_data = &diff_dst[diff_dst_off];

                    p.load_data = &weights[pd()->with_groups()
                                    ? weights_d.blk_off(g, ocb, icb)
                                    : weights_d.blk_off(ocb, icb)];

                    // The first reduce block initializes the accumulator.
                    p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;

                    p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
                            nb_oc_blocking * jcp.oc_block);

                    (*kernel_)(&p);
                }

                // Scatter the densely-computed block back to strided memory.
                if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp);
            }
        }
    };

    parallel(jcp.nthr, ker);
}
492 | |
/* convolution backward w.r.t. weights */
494 | |
// Create the JIT kernel, the weight/bias reducers, and the rtus driver.
// Returns success only if every component is constructed and jitted.
status_t jit_avx2_1x1_convolution_bwd_weights_t::init(engine_t *engine) {
    // Main 1x1 kernel producing partial weight diffs.
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx2_1x1_conv_kernel_f32(
                    pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
    CHECK(kernel_->create_kernel());

    // Reducer combining per-thread partial weight diffs.
    CHECK(safe_ptr_assign(reducer_weights_,
            new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_)));
    CHECK(reducer_weights_->create_kernel());

    // The bias reducer object is created unconditionally (it is initialized
    // in execute_backward_weights even when there is no bias), but its
    // kernel is only jitted when bias is actually present.
    CHECK(safe_ptr_assign(reducer_bias_,
            new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
    if (pd()->with_bias()) {
        assert(reducer_weights_->balancer().nthr_
                == reducer_bias_->balancer().nthr_);
        CHECK(reducer_bias_->create_kernel());
    }

    CHECK(init_rtus_driver<avx2>(this));
    return status::success;
}
516 | |
// Backward-by-weights pass: every thread computes partial weight (and,
// optionally, bias) diffs over its share of the (mb, spatial) reduction
// space; per-thread partial results are then combined by the reducers.
void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias_in = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));

    const auto &jcp = kernel_->jcp;
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<data_t>(key_conv_rtus_space)
            : nullptr;

    // When OC is not a multiple of the channel block, the bias diff is
    // accumulated into a padded scratch buffer and the real channels are
    // copied out to the user buffer at the very end.
    const bool is_bias_padded
            = pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0);

    data_t *diff_bias = is_bias_padded
            ? scratchpad.get<data_t>(key_conv_padded_bias)
            : diff_bias_in;

    auto reducer_bia_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_bia);
    auto rb = this->reducer_bias_.get();
    rb->init(reducer_bia_scratchpad);

    auto reducer_wei_scratchpad
            = memory_tracking::grantor_t(scratchpad, prefix_reducer_wei);
    auto rw = this->reducer_weights_.get();
    rw->init(reducer_wei_scratchpad);

    const int ndims = diff_dst_d.ndims();
    // TODO (Roma): remove this restriction
    assert(jcp.stride_w == 1 && jcp.stride_h == 1);

    // In this pass the kernel's "bcast" dimension is input channels and the
    // "load" dimension is output channels; reduction runs over (mb, sp).
    const int nb_ic = jcp.nb_bcast;
    const int nb_ic_blocking = jcp.nb_bcast_blocking;
    const int bcast_work = div_up(nb_ic, nb_ic_blocking);

    const int nb_oc = jcp.nb_load;
    const int nb_oc_blocking = jcp.nb_load_blocking;
    const int load_work = div_up(nb_oc, nb_oc_blocking);

    const int sp_dim = jcp.reduce_dim;
    const int mb_sp_work = jcp.mb * sp_dim;

    // desc()->strides holds {[d,] [h,] w}; absent dims default to stride 1.
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    const bool is_src_layout_nxc = utils::one_of(
            jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    const bool is_ddst_layout_nxc = utils::one_of(
            jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    // Take the default block size unless fewer than tail_step items remain,
    // in which case take the whole remainder at once.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Inner loop nest: accumulate the (oc, ic) weight-diff tile over the
    // spatial range [sp_start, sp_end) of one image into store_to (leading
    // dimension store_to_ld). first_image controls accumulator init.
    auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
                                 data_t *store_to, size_t store_to_ld,
                                 const data_t *diff_dst, const data_t *src,
                                 int ithr) {
        auto p = jit_1x1_conv_call_s();
        auto rp = rtus_driver_t<avx2>::call_params_t();

        p.output_stride = store_to_ld * sizeof(float);

        int oc_b_step = 0;
        for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
            oc_b_step = step(nb_oc_blocking, nb_oc_blocking - oc_b,
                    jcp.nb_load_blocking_max);
            p.load_dim = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, oc_b_step * jcp.oc_block);

            int ic_b_step = 0;
            for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
                ic_b_step = step(nb_ic_blocking, nb_ic_blocking - ic_b,
                        jcp.nb_bcast_blocking_max);
                p.bcast_dim = this_block_size(
                        ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
                rp.icb = p.bcast_dim;

                p.output_data = store_to + oc_b * store_to_ld
                        + ic_b * jcp.ic_block * jcp.oc_block;

                /* spatial reduction */
                int sp_step = 0;
                for (int sp = sp_start; sp < sp_end; sp += sp_step) {
                    sp_step = step(jcp.nb_reduce_blocking, sp_end - sp,
                            jcp.nb_reduce_blocking_max);
                    p.reduce_dim = sp_step * jcp.reduce_block;
                    rp.os = p.reduce_dim;

                    // Only the very first spatial chunk of the first image
                    // initializes the accumulator.
                    p.first_last_flag = sp == sp_start && first_image
                            ? FLAG_REDUCE_FIRST
                            : 0;

                    p.load_data = diff_dst
                            + (oc_b * jcp.reduce_dim + sp)
                                    * (is_ddst_layout_nxc ? jcp.oc
                                                          : jcp.oc_block);

                    if (pd()->rtus_.reduce_src_) {
                        // Densify the strided source chunk into the rtus
                        // scratch before feeding it to the kernel.
                        const int od = sp / (jcp.oh * jcp.ow);
                        const int sp_2d = sp % (jcp.oh * jcp.ow);
                        const int oh = sp_2d / jcp.ow;
                        const int ow = sp_2d % jcp.ow;

                        const int id = od * stride_d;
                        const int ih = oh * stride_h;
                        const int iw = ow * stride_w;
                        rp.iw_start = iw;

                        rp.ws = rtus_space
                                + ithr * pd()->rtus_.space_per_thread_
                                + (ic_b * jcp.is + sp) * jcp.ic_block;
                        size_t src_offset
                                = iw * src_d.blocking_desc().strides[ndims - 1];
                        if (ndims > 3)
                            src_offset += ih
                                    * src_d.blocking_desc().strides[ndims - 2];
                        if (ndims == 5)
                            src_offset += id
                                    * src_d.blocking_desc().strides[ndims - 3];

                        rp.src = src + src_offset;
                        // The gathered data does not depend on oc_b, so the
                        // driver runs only on the first oc iteration.
                        if (oc_b == 0) (*rtus_driver_)(&rp);

                        p.bcast_data = rp.ws;
                    } else
                        p.bcast_data = src
                                + (ic_b * jcp.reduce_dim + sp)
                                        * (is_src_layout_nxc ? jcp.ic
                                                             : jcp.ic_block);

                    (*kernel_)(&p);
                }
            }
        }
    };

    // Zero the weight-diff entries that correspond to the input-channel
    // padding tail so the padded region stays zero in the final output.
    auto maybe_zero_icpad = [&](const int g_start, const int g_end,
                                    const int ocb_start, const int ocb_end) {
        // write zeros to IC padded region.
        const int ic_tail = jcp.ic_without_padding % jcp.ic_block;
        if (is_ddst_layout_nxc && ic_tail != 0) {
            for_(int g = g_start; g < g_end; ++g)
            for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) {
                const int z_icb = nb_ic - 1;
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, z_ocb, z_icb)
                        : diff_weights_d.blk_off(z_ocb, z_icb);
                data_t *z_wei = diff_weights + off + ic_tail * jcp.oc_block;
                const int zero_work
                        = (nb_ic * jcp.ic_block - jcp.ic_without_padding)
                        * jcp.oc_block;
                PRAGMA_OMP_SIMD()
                for (int o = 0; o < zero_work; ++o) {
                    z_wei[o] = 0;
                }
            }
        }
    };

    // Per-thread weight-diff worker: results go either directly into
    // diff_weights (one thread per reduction group) or into the reducer
    // scratchpad for a later combine step.
    auto ker = [&](const int ithr, const int nthr) {
        assert(nthr == rw->balancer().nthr_);

        const int w_njobs = rw->balancer().ithr_njobs(ithr);
        if (w_njobs == 0) return;

        /* setup: independent work (oc, ic) */
        const int w_job_start = rw->balancer().ithr_job_off(ithr);
        int g {0}, load_i {0}, bcast_i {0};
        nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
                bcast_i, bcast_work);

        /* setup: reduction work (mb, sp) */
        int mb_sp_start {0}, mb_sp_end {0};
        balance211(mb_sp_work, rw->balancer().nthr_per_group_,
                rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
        int img_start {0}, sp_start {0};
        nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);

        /* independent work */
        for (int iwork = 0; iwork < w_njobs; ++iwork) {
            const int oc_b = nb_oc_blocking * load_i;
            const int ic_b = nb_ic_blocking * bcast_i;

            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * nb_oc + oc_b;
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * nb_ic + ic_b;

            data_t *store_to;
            size_t store_to_ld;

            if (rw->balancer().nthr_per_group_ == 1) {
                // Single thread per reduction group: accumulate straight
                // into the user diff_weights buffer.
                const size_t off = pd()->with_groups()
                        ? diff_weights_d.blk_off(g, oc_b, ic_b)
                        : diff_weights_d.blk_off(oc_b, ic_b);
                store_to = &diff_weights[off];
                store_to_ld = rnd_up(jcp.ic, jcp.ic_block) * jcp.oc_block;
            } else {
                // Several threads per group: accumulate a private partial
                // result that the reducer combines afterwards.
                const size_t off = (size_t)iwork * rw->balancer().job_size_;
                store_to
                        = rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
                store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
            }

            /* reduction work */
            int img = img_start;
            int sp = sp_start;
            int sp_step = 0;
            for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step) {
                sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);

                const bool first_image = img == img_start;
                if (is_ddst_layout_nxc && first_image
                        && rw->balancer().nthr_per_group_ > 1) {
                    // Zero-pad the scratchpad when nthr > 1 (since most threads
                    // write to scratchpad) so that zero-padding is maintained
                    // for the final output after reduction
                    array_set(rw->get_local_ptr(ithr, reducer_wei_scratchpad)
                                    + iwork * rw->balancer().job_size_,
                            0, rw->balancer().job_size_);
                }
                oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
                        store_to_ld,
                        &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)],
                        &src[src_d.blk_off(img, ic_off_idx)], ithr);

                // Subsequent images start at spatial position 0.
                sp = 0;
                img += 1;
            }

            // Only the direct-store path needs explicit tail zeroing, once
            // the last ic block of this (g, oc) job has been processed.
            if (rw->balancer().nthr_per_group_ == 1
                    && bcast_i + 1 >= bcast_work)
                maybe_zero_icpad(g, g + 1, oc_b,
                        nstl::min(nb_oc, oc_b + nb_oc_blocking));

            nd_iterator_step(
                    g, jcp.ngroups, load_i, load_work, bcast_i, bcast_work);
        }

        if (dnnl_thr_syncable())
            rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
    };

    // Per-thread bias-diff worker: sum diff_dst over all spatial positions
    // for each (group, output-channel-block) job assigned to this thread.
    auto ker_bias = [&](int ithr, int nthr) {
        assert(nthr == rb->balancer().nthr_);

        const int b_job_start = rb->balancer().ithr_job_off(ithr);
        const int b_njobs = rb->balancer().ithr_njobs(ithr);

        if (b_njobs == 0) return;

        /* reduction dimension */
        int img_start {0}, img_end {0};
        balance211(jcp.mb, rb->balancer().nthr_per_group_,
                rb->balancer().id_in_group(ithr), img_start, img_end);

        /* jobs */
        int g_start {0}, ocb_start {0};
        nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);

        for (int img = img_start; img < img_end; ++img) {
            int g = g_start, ocb = ocb_start;
            for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
                const int oc_off_idx = is_ddst_layout_nxc
                        ? g * jcp.oc + ocb * jcp.oc_block
                        : g * nb_oc + ocb;

                const data_t *d_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];
                data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
                                         reducer_bia_scratchpad)
                        + b_job_loc * rb->balancer().job_size_;

                // Zero the accumulator on the first image of this thread's
                // range. NOTE(review): the hardcoded 8 looks like the avx2
                // channel-block width — confirm it always equals
                // jcp.oc_block here.
                if (img == img_start)
                    for (int o = 0; o < 8; ++o)
                        d_bias[o] = 0.;

                const int spatial_shift
                        = is_ddst_layout_nxc ? jcp.oc : jcp.oc_block;
                const int max_oc = this_block_size(
                        ocb * jcp.oc_block, jcp.oc, jcp.oc_block);
                for (int hw = 0; hw < jcp.os; ++hw) {
                    PRAGMA_OMP_SIMD()
                    for (int o = 0; o < max_oc; ++o)
                        d_bias[o] += d_dst[o];
                    d_dst += spatial_shift;
                }

                nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
            }
        }

        if (dnnl_thr_syncable())
            rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
    };

    if (dnnl_thr_syncable()) {
        // Syncable threading: compute and reduction share one parallel
        // region (ker/ker_bias call reduce themselves).
        assert(IMPLICATION(pd()->with_bias(),
                rw->balancer().nthr_ == rb->balancer().nthr_));
        parallel(rw->balancer().nthr_, [&](const int ithr, const int nthr) {
            ker(ithr, nthr);
            if (pd()->with_bias()) ker_bias(ithr, nthr);
        });
    } else {
        // Non-syncable threading: run the compute step and the lock-free
        // reduction step in separate parallel regions.
        parallel(rw->balancer().nthr_,
                [&](int ithr, int nthr) { ker(ithr, nthr); });
        parallel(rw->balancer().nthr_, [&](int ithr, int nthr) {
            assert(nthr == rw->balancer().nthr_);
            MAYBE_UNUSED(nthr);
            if (rw->balancer().ithr_njobs(ithr) == 0) return;
            rw->reduce_nolock(ithr, diff_weights, reducer_wei_scratchpad);
        });
        if (pd()->with_bias()) {
            parallel(rb->balancer().nthr_,
                    [&](int ithr, int nthr) { ker_bias(ithr, nthr); });
            parallel(rb->balancer().nthr_, [&](int ithr, int nthr) {
                assert(nthr == rb->balancer().nthr_);
                MAYBE_UNUSED(nthr);
                if (rb->balancer().ithr_njobs(ithr) == 0) return;
                rb->reduce_nolock(ithr, diff_bias, reducer_bia_scratchpad);
            });
        }
    }

    /* TODO: put this in ker_bias */
    if (is_bias_padded) {
        // Copy the real (non-padded) channels from the padded scratch bias
        // into the user's diff_bias buffer, group by group.
        assert(IMPLICATION(!is_ddst_layout_nxc, jcp.ngroups == 1));
        const int padded_stride = utils::rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}
866 | |
867 | } // namespace x64 |
868 | } // namespace cpu |
869 | } // namespace impl |
870 | } // namespace dnnl |
871 | |