1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | |
26 | #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | using namespace dnnl::impl::status; |
34 | using namespace dnnl::impl::memory_tracking::names; |
35 | using namespace dnnl::impl::utils; |
36 | |
37 | /* convolution forward */ |
38 | status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward( |
39 | const exec_ctx_t &ctx) const { |
40 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
41 | const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
42 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
43 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
44 | auto weights_dw = CTX_IN_MEM( |
45 | const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
46 | auto bias_dw = CTX_IN_MEM( |
47 | const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
48 | const auto post_ops_binary_rhs_arg_vec |
49 | = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); |
50 | const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_ |
51 | ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx, |
52 | pd()->jcp_.post_ops.entry_.size() + 1) |
53 | : std::vector<const void *> {}; |
54 | |
55 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
56 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
57 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
58 | |
59 | DEFINE_ARG_SCALES_BUFFER( |
60 | dw_wei_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
61 | DEFINE_ARG_SCALES_BUFFER( |
62 | dw_dst_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST); |
63 | |
64 | DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); |
65 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
66 | |
67 | auto scratchpad = ctx.get_scratchpad_grantor(); |
68 | |
69 | auto local_scales |
70 | = scratchpad.template get<float>(key_conv_adjusted_scales); |
71 | // Src scale is always a single value |
72 | float src_scale = src_scales[0]; |
73 | int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
74 | float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) |
75 | ? 1.f / pd()->jcp_.wei_adj_scale |
76 | : 1.f; |
77 | switch (wei_mask) { |
78 | case 0: |
79 | utils::array_set(local_scales, src_scale * wei_scales[0] * factor, |
80 | pd()->jcp_.ic_block); |
81 | break; |
82 | default: |
83 | for (dim_t c = 0; c < pd()->OC(); c++) |
84 | local_scales[c] = src_scale * wei_scales[c] * factor; |
85 | } |
86 | |
87 | const float *dw_oscales = nullptr; |
88 | if (pd()->jcp_.with_dw_conv) { |
89 | auto jcp_dw = pd()->jcp_dw_; |
90 | memory_tracking::grantor_t dw_scratchpad( |
91 | scratchpad, memory_tracking::names::prefix_fusion); |
92 | auto dw_local_scales |
93 | = dw_scratchpad.template get<float>(key_conv_adjusted_scales); |
94 | auto attr_dw = pd()->dw_conv_pd_->attr(); |
95 | int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
96 | dim_t count = wei_mask == 0 ? 1 : pd()->dw_conv_pd_->OC(); |
97 | float factor = 1.f / jcp_dw->wei_adj_scale; |
98 | if (count == 1) { |
99 | utils::array_set(dw_local_scales, |
100 | dw_wei_scales[0] / dst_scales[0] * factor, |
101 | pd()->jcp_.ic_block); |
102 | } else { |
103 | for (dim_t c = 0; c < count; c++) |
104 | dw_local_scales[c] = dw_wei_scales[c] / dst_scales[0] * factor; |
105 | } |
106 | dw_oscales = dw_local_scales; |
107 | } |
108 | parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { |
109 | execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, |
110 | dst, local_scales, dst_scales, dw_oscales, dw_dst_scales, |
111 | src_zero_point, dst_zero_point, scratchpad, |
112 | post_ops_binary_rhs_arg_vec.data(), |
113 | post_ops_binary_rhs_arg_vec_dw.data()); |
114 | }); |
115 | return status::success; |
116 | } |
117 | |
// Per-thread body of the int8 1x1 convolution, optionally fused with a
// depthwise convolution post-op.
//
// All raw buffers are passed as char* because the actual data types are only
// known at runtime through the memory descriptors.
//   ithr / nthr           - this thread's index and the total thread count
//   src, weights, bias    - 1x1 convolution inputs
//   weights_dw, bias_dw   - fused depthwise conv inputs (may be null)
//   dst                   - destination buffer (final output in both modes)
//   oscales, dst_scales   - adjusted output scales / dst scale for the 1x1 conv
//   dw_oscales, dw_dst_scales - same for the fused depthwise conv
//   src_zero_point, dst_zero_point - zero-point buffers (may be null)
//   scratchpad            - per-primitive scratchpad grantor
//   post_ops_binary_rhs_arg_vec[, _dw] - binary post-op rhs argument arrays
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
        const int ithr, const int nthr, const char *src, const char *weights,
        const char *bias, const char *weights_dw, const char *bias_dw,
        char *dst, const float *oscales, const float *dst_scales,
        const float *dw_oscales, const float *dw_dst_scales,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

    // Workspace used by the rtus driver to pre-pack (strided) source rows;
    // only allocated when reduce_src_ is set.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<char>(key_conv_rtus_space)
            : nullptr;

    // Total bcast-dimension work items: minibatch x groups x spatial blocks.
    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const bool is_2d = pd()->ndims() == 4;
    const bool is_3d = pd()->ndims() == 5;

    const int stride_d = pd()->KSD();
    const int stride_h = pd()->KSH();
    const int stride_w = pd()->KSW();

    // Compensation data (for signed input and/or src zero-points) is stored
    // in an additional buffer appended after the weights payload.
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    // Zero-point compensation follows the signed-input compensation (when
    // both are present).
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(w + offset)
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    // Kernel call parameters, filled incrementally by the init_* lambdas
    // below and consumed by ker_1x1.
    auto p = jit_1x1_conv_call_s();

    // rtus driver call parameters (source repacking).
    auto rp = rtus_driver_t<avx512_core>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare Variables needed for dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    size_t dw_bia_dt_size = 0;
    int32_t *compensation_dw {nullptr};
    if (jcp.with_dw_conv && jcp_dw) {
        if (jcp_dw->with_bias)
            dw_bia_dt_size
                    = types::data_type_size(dw_pd->desc()->bias_desc.data_type);

        // Depthwise weights carry their own trailing compensation buffer.
        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<char *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
        dw_oscales = dw_scratchpad.get<float>(key_conv_adjusted_scales);
    }

    // pbuf: per-thread intermediate buffer holding 1x1 output rows that feed
    // the fused dw conv; rows are reused cyclically modulo jcp_dw->kh.
    char *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<char *> addrs;
    // End

    // Blocking step: take default_step while at least tail_step work remains,
    // otherwise take whatever is left.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decode a flat bcast work index into (mb, group, spatial) coordinates
    // and set the bcast-related kernel/rtus parameters.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int depth_orthogonal_area = jcp.ow * jcp.oh;
        od = os / depth_orthogonal_area;
        oh = (os % depth_orthogonal_area) / jcp.ow;
        ow = (os % depth_orthogonal_area) % jcp.ow;

        // 1x1 conv: input coordinates are output coordinates times stride.
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Compute the load (output-channel) step and flag the last oc chunk so
    // the kernel can apply finalization (e.g. post-ops) on it.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

    // The whole reduce (input-channel) dimension is processed in one shot.
    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

    // Fill the remaining call parameters for one (ocb, spatial) tile and
    // invoke the 1x1 JIT kernel; runs the rtus driver first when the source
    // needs repacking (only once per spatial tile, at ocb == ocb_start).
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g * nb_ic + icb;

        const size_t dst_off = is_3d
                ? dst_d.blk_off(n, _ocb * jcp.oc_block, od, oh, ow)
                : is_2d ? dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow)
                        : dst_d.blk_off(n, _ocb * jcp.oc_block, ow);

        // When fused with dw conv, write into the cyclic row buffer instead
        // of the final destination.
        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : dst + dst_dt_size * dst_off;
        const auto wei_offset = pd()->with_groups()
                ? weights_d.blk_off(g, ocb, icb)
                : weights_d.blk_off(ocb, icb);
        p.load_data = weights + wei_offset;
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        // is_oc_scale selects between a single common scale and per-oc scales.
        p.scales = &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        p.dst_scale = dst_scales;
        const size_t src_off = is_3d
                ? src_d.blk_off(n, _icb * jcp.ic_block, id, ih, iw)
                : is_2d ? src_d.blk_off(n, _icb * jcp.ic_block, ih, iw)
                        : src_d.blk_off(n, _icb * jcp.ic_block, iw);
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space
                    + src_dt_size
                            * (ithr * pd()->rtus_.space_per_thread_
                                    + _icb * jcp.is * jcp.ic_block);
            if (ocb == ocb_start) {
                rp.src = src + src_dt_size * src_off;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_dt_size * src_off;

        p.dst_l_off = dst_off;
        p.oc_l_off = _ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    // Run the 1x1 convolution over [bcast_start, bcast_end) x
    // [ocb_start, ocb_end). The four loop orders differ only in the nesting
    // of the bcast/load loops and in where init_reduce is (re)issued.
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order" );
        }
    };

    // Run the fused depthwise kernel for one output row dw_oh, consuming
    // jcp_dw->kh rows of 1x1 output from the cyclic buffer pbuf.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

        // Gather the kh source-row pointers from the cyclic buffer.
        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        // Rows clipped by top/bottom padding reduce the effective kh.
        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = dst
                    + (dst_offset + jcp_dw->ch_block * ocb)
                            * jcp_dw->typesize_out;

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.dst_scale = dw_dst_scales;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance the row pointers to the next channel-block chunk.
            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

    // Fused path driver: interleaves 1x1 row production (conv_1x1 into pbuf)
    // with depthwise row consumption (ker_dw), row by row.
    auto conv_dw = [&]() {
        auto jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer = dw_scratchpad.get<char>(key_fusion_inout_buffer);

        // Per-thread slice of the intermediate buffer: kh rows of
        // ow x (nb_buffer * oc_block) elements.
        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                // Produce the 1x1 rows this dw output row depends on, then
                // consume them.
                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1 path: split work 2D over (bcast, load) among threads.
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}
498 | |
499 | } // namespace x64 |
500 | } // namespace cpu |
501 | } // namespace impl |
502 | } // namespace dnnl |
503 | |