1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | using namespace dnnl::impl::status; |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::utils; |
35 | |
// Compute the element offset into memory descriptor wrapper `f` for batch n,
// channel c and spatial point (d, h, w), dispatching on tensor rank:
// 3D (n, c, w), 4D (n, c, h, w) or 5D (n, c, d, h, w).
// NOTE: relies on a variable named `ndims` being in scope at the expansion
// site (see execute_forward_thr).
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))
40 | |
/* convolution forward */

// Entry point of the int8 1x1 convolution forward pass.
// Fetches all runtime tensors and attribute buffers from the execution
// context, folds the src/weights scales into a single "adjusted scales"
// array in scratchpad (and does the same for a fused depthwise convolution
// post-op, if any), then launches the per-thread worker in parallel.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    // Weights/bias of the fused depthwise convolution post-op (unused when
    // no dw conv is fused).
    auto weights_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    // Right-hand-side args for binary post-ops: one vector for the 1x1
    // kernel, a separate one for the fused dw kernel. The dw post-op
    // entries are indexed after the whole 1x1 chain, hence the
    // `entry_.size() + 1` base offset.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(
            dw_wei_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(
            dw_dst_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();

    // Fold src and weights scales into one array the kernel can load
    // directly; kept in scratchpad under key_conv_adjusted_scales.
    auto local_scales
            = scratchpad.template get<float>(key_conv_adjusted_scales);
    // For signed input on ISAs without VNNI, divide out wei_adj_scale —
    // presumably compensating a weights adjustment applied at preparation
    // time elsewhere in the implementation.
    const float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
            ? 1.f / pd()->jcp_.wei_adj_scale
            : 1.0f;
    int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
    if (wei_mask == 0) {
        // Common (per-tensor) weights scale: replicate into 8 slots so the
        // kernel can use a vector load regardless of the mask.
        utils::array_set(
                local_scales, src_scales[0] * wei_scales[0] * factor, 8);
    } else {
        // Per-output-channel weights scales.
        for (dim_t c = 0; c < pd()->OC(); c++)
            local_scales[c] = src_scales[0] * wei_scales[c] * factor;
    }

    const float *dw_oscales = nullptr;
    if (pd()->jcp_.with_dw_conv) {
        // Same scale folding for the fused depthwise conv; its scratchpad
        // is nested under the fusion prefix. The 1x1 dst scale becomes the
        // dw input scale, hence the division by dst_scales[0].
        auto jcp_dw = pd()->jcp_dw_;
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto attr_dw = pd()->dw_conv_pd_->attr();

        auto dw_local_scales
                = dw_scratchpad.template get<float>(key_conv_adjusted_scales);
        int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_;
        float factor = 1.f / jcp_dw->wei_adj_scale;
        if (wei_mask == 0) {
            utils::array_set(dw_local_scales,
                    dw_wei_scales[0] / dst_scales[0] * factor,
                    pd()->jcp_.ic_block);
        } else {
            for (dim_t c = 0; c < pd()->dw_conv_pd_->OC(); c++)
                dw_local_scales[c] = dw_wei_scales[c] / dst_scales[0] * factor;
        }
        dw_oscales = dw_local_scales;
    }
    // Run the convolution body over jcp_.nthr threads.
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, local_scales, dst_scales, dw_oscales, dw_dst_scales,
                src_zero_point, dst_zero_point, scratchpad,
                post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });
    return status::success;
}
118 | |
// Per-thread body of the forward pass. Each thread receives (via balance2D)
// a 2D slice of the work space: (mb * groups * spatial blocks) on one axis
// and output-channel blocks on the other, and drives the JIT 1x1 kernel
// over it in the layout-selected loop order. When a depthwise convolution
// is fused, 1x1 output rows are staged in a small per-thread buffer indexed
// modulo jcp_dw->kh (a ring of kh rows) and immediately consumed by the JIT
// dw kernel row by row.
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::execute_forward_thr(
        const int ithr, const int nthr, const char *src, const char *weights,
        const char *bias, const char *weights_dw, const char *bias_dw,
        char *dst, const float *oscales, const float *dst_scales,
        const float *dw_oscales, const float *dw_dst_scales,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    // Data-type sizes in bytes, used to scale raw char* offsets below.
    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

    // Scratch for the "reduce src" (rtus) path: strided 1x1 convolutions
    // first gather the needed src elements into a dense per-thread buffer.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<char>(key_conv_rtus_space)
            : nullptr;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const int ndims = dst_d.ndims();
    // Missing dimensions of lower-rank tensors get stride 1.
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

    // Compensation data lives in an additional buffer appended after the
    // weights payload: signed-input compensation first, then (offset by
    // ngroups * oc entries when both are present) the src zero-point
    // compensation.
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    // Call-parameter structs reused (and incrementally mutated) across all
    // kernel invocations made by this thread.
    auto p = jit_1x1_conv_call_s();

    auto rp = typename rtus_driver_t<isa>::call_params_t();
    const int nb_oc = jcp.nb_load;
    // override some constants for fused dw_conv:
    // with a fused dw conv the 1x1 bcast space is iterated one output row
    // (ow elements) at a time so rows can be handed to the dw kernel.
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare Variables needed for dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    const size_t dw_bia_dt_size = jcp_dw && jcp_dw->with_bias
            ? types::data_type_size(dw_pd->desc()->bias_desc.data_type)
            : 0;

    int32_t *compensation_dw {nullptr};

    if (jcp.with_dw_conv) {
        // dw compensation is likewise appended after the dw weights payload.
        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<char *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
    }

    char *pbuf {nullptr}; // per-thread staging buffer (1x1 output -> dw input)
    size_t row_offset {}; // byte distance between consecutive staged rows
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<char *> addrs; // per-tap row pointers handed to the dw kernel
    // End

    // Return `default_step`, or the whole remainder once it drops below the
    // tail threshold.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decode flattened bcast index `iwork` into (n, g, od/oh/ow), derive the
    // matching input coordinates from the strides, and fill the bcast-side
    // kernel/rtus parameters. `bcast_step` is clamped to the thread's range.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        const int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Pick the oc-block step for this chunk and maintain the kernel's
    // "last oc chunk" flag.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

    // The whole (unpadded) input-channel range is reduced in a single pass.
    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

    // Fill the remaining call params for one (oc block, spatial tile) pair
    // and invoke the JIT 1x1 kernel. With a fused dw conv the output is
    // routed to the staging buffer (row slot = oh modulo jcp_dw->kh)
    // instead of the real dst.
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g;

        const auto src_offset
                = data_blk_off(src_d, n, _icb * jcp.ic_block, id, ih, iw);
        const auto dst_offset
                = data_blk_off(dst_d, n, _ocb * jcp.oc_block, od, oh, ow);
        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : dst + dst_dt_size * dst_offset;

        const auto wei_offset = pd()->with_groups()
                ? weights_d.blk_off(g, ocb, icb)
                : weights_d.blk_off(ocb, icb);
        p.load_data = weights + wei_offset;
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        p.scales = &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        p.dst_scale = dst_scales;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space
                    + src_dt_size
                            * (ithr * pd()->rtus_.space_per_thread_
                                    + _icb * jcp.is * jcp.ic_block);
            // The rtus gather depends only on the spatial tile, so run it
            // once per tile: on the first oc block.
            if (ocb == ocb_start) {
                rp.src = src + src_dt_size * src_offset;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_dt_size * src_offset;

        p.oc_l_off = g * nb_oc + ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = jcp.with_dw_conv ? pbuf : dst;

        (*kernel_)(&p);
    };

    // Sweep the assigned [bcast_start, bcast_end) x [ocb_start, ocb_end)
    // ranges in the configured loop order (r = reduce, l = load/oc,
    // b = bcast/spatial; outermost letter first).
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, iw, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order" );
        }
    };

    // Run the JIT depthwise kernel for one dw output row `dw_oh`, reading
    // its kh input rows from the staging buffer.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        // Topmost 1x1 row feeding this dw row (negative inside top padding).
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

        // Resolve each filter tap to its (modulo-kh) slot in the buffer.
        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        // Filter rows clipped away by the top/bottom image borders.
        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        // For unsigned input, skip the top-clipped filter rows in the
        // weights as well (for signed input the factor is 0).
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = dst
                    + (dst_offset + jcp_dw->ch_block * ocb)
                            * jcp_dw->typesize_out;

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.dst_scale = dw_dst_scales;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance all tap pointers to the next channel chunk.
            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

    // Driver for the fused path: for each dw output row, produce only the
    // 1x1 rows not computed yet (tracked via oh_1x1), then immediately
    // consume the kh-row window with ker_dw.
    auto conv_dw = [&]() {
        auto &jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer = dw_scratchpad.get<char>(key_fusion_inout_buffer);

        // Each thread stages kh rows of ow * (nb_buffer oc blocks) elements.
        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        // Threads are balanced over dw output rows here, not 1x1 rows.
        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n {0}, g {0}, oh_dw {0};
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw->oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1 path: split work in 2D across threads (spatial x oc
        // chunks), then expand chunk indices back to oc-block units.
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}
495 | |
// Explicit instantiations for the ISAs this implementation supports.
template struct jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>;
template struct jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>;
498 | |
499 | } // namespace x64 |
500 | } // namespace cpu |
501 | } // namespace impl |
502 | } // namespace dnnl |
503 | |