1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/cpu_primitive.hpp"
23
24#include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30
31using namespace dnnl::impl::status;
32using namespace dnnl::impl::memory_tracking::names;
33using namespace dnnl::impl::utils;
34using namespace data_type;
35
36using namespace nstl;
37
// Weights offset helper: prepends the group index to the logical offset when
// the primitive has groups, otherwise forwards the indices unchanged.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
41
42template <cpu_isa_t isa>
43const float *jit_uni_x8s8s32x_convolution_fwd_t<isa>::adjust_oscales(
44 const memory_tracking::grantor_t &scratchpad, const float *src_scales,
45 const float *wei_scales) const {
46 auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
47 int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
48 float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
49 ? 1.f / pd()->jcp_.wei_adj_scale
50 : 1.0f;
51 if (wei_mask == 0) {
52 utils::array_set(loc_scales, src_scales[0] * wei_scales[0] * factor, 8);
53 } else {
54 for (dim_t c = 0; c < pd()->OC(); c++)
55 loc_scales[c] = src_scales[0] * wei_scales[c] * factor;
56 }
57 return loc_scales;
58}
59
// Forward 2D int8 convolution driver: flattens (mb, groups, oc chunks, oh,
// ow blocks) into one work range, balances it across threads, and invokes
// the JIT kernel once per output row.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    // Runtime zero points for asymmetric int8 quantization.
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Non-depthwise path: channel blocking is expressed via oc/ic blocks.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src*wei scales into a per-channel table in scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Extra data appended past the weights payload: first the signed-input
    // compensation (ngroups * oc int32 values), then, when src zero points
    // are present, the zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    // Flat work amount: every (mb, group, oc chunk, oh row, ow block) tuple
    // is one unit of work.
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        // Decompose the flat [start, end) range following the loop order
        // picked at primitive-descriptor creation time.
        int n {0}, g {0}, occ {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                        occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            // A thread-level oc chunk is processed in kernel-sized
            // nb_oc_blocking steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Rows [oh_s, oh_e) belong to this thread's work slice.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // Count filter rows above (t) / below (b) the input for
                    // this output row; the kernel is told how many rows of
                    // the filter actually overlap the input.
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // With signed input or src zero points the kernel uses
                    // the full filter (padded rows are compensated), so the
                    // filter pointer must not skip the overflowed rows.
                    const size_t wei_stride
                            = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : i_t_overflow * wht_h_stride;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.scales = scales;
                    p.dst_scale = dst_scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.owb = owb;

                    p.oc_l_off = g_oc;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;

                    (*kernel_)(&p);
                    // Advance to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Move the flat counter to the next work item in loop order.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                            occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
229
// Forward 1D int8 convolution driver (also serves 1D depthwise): flattens
// (mb, group blocks, oc chunks, ow blocks) into one range and calls the JIT
// kernel once per work item; no height/depth padding bookkeeping is needed.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_1d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    // Runtime zero points for asymmetric int8 quantization.
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src*wei scales into a per-channel table in scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Extra data appended past the weights payload: signed-input
    // compensation first, then src-zero-point compensation. The number of
    // compensation entries differs between depthwise and regular layouts.
    size_t extra_data_offset
            = weights_d.size() - weights_d.additional_buffer_size();
    size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
                                        : jcp.ngroups * jcp.oc;
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
                    + (jcp.signed_input ? ch_offset : 0)
            : nullptr;

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Decompose the flat [start, end) range following the loop order
        // picked at primitive-descriptor creation time.
        int n {0}, gg {0}, occ {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb);
                break;
            case loop_gncw:
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ,
                        oc_chunks, owb, jcp.nb_ow);
                break;
            case loop_nwcg:
                nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            int ocb = occ * jcp.nb_oc_blocking;
            int gb = gg * jcp.nb_ch_blocking;
            int g = gb * group_block;
            int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
            int g_ic = g * jcp.nb_ic * jcp.ic_block;
            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

            p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                          : nullptr;
            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
            p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
            p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
            p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc, ow_s);
            p.src = src + src_d.blk_off(n, g_ic, iw_s);
            p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
            p.scales = &oscales[jcp.is_oc_scale * g_oc];
            p.dst_scale = dst_scales;
            // For depthwise the kernel indexes blocks by group block.
            p.oc_blocks = jcp.is_depthwise ? gb : ocb;
            // 1D: no vertical padding; the full kernel height (== 1 row
            // blocking) always applies.
            p.kh_padding = jcp.kh;
            p.t_overflow = 0;
            p.b_overflow = 0;
            p.owb = owb;

            p.oc_l_off = g_oc;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;

            (*kernel_)(&p);

            // Move the flat counter to the next work item in loop order.
            ++start;
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg,
                            nb_groups, n, jcp.mb);
                    break;
                case loop_gncw:
                    nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                            owb, jcp.nb_ow);
                    break;
                case loop_ngcw:
                    nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks,
                            owb, jcp.nb_ow);
                    break;
                case loop_nwcg:
                    nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
                            gg, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
363
// Forward 2D depthwise int8 convolution driver: one kernel call per
// (mb, oh row, ow block, group block) tuple, parallelized with parallel_nd.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_2d_dw(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    // Runtime zero points for asymmetric int8 quantization.
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Depthwise path: channels are blocked via ch_block only; ic/oc
    // blocking degenerates to one.
    assert(jcp.ic_block == 1);
    assert(jcp.oc_block == 1);
    assert(jcp.nb_ic == 1);
    assert(jcp.nb_oc == 1);
    assert(jcp.nb_oc_blocking == 1);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src*wei scales into a per-channel table in scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Extra data appended past the weights payload: signed-input
    // compensation (nb_ch * ch_block int32 values) first, then
    // src-zero-point compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
            : nullptr;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;

    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
            [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                auto p = jit_conv_call_s();

                size_t src_h_stride = src_d.blk_off(0, 0, 1);
                size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

                int gb = gg * jcp.nb_ch_blocking;
                int g = gb * group_block;

                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = jcp.signed_input ? compensation + g : nullptr;

                auto dst_w
                        = dst + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, gb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g];

                // Count filter rows above (t) / below (b) the input for
                // this output row; the kernel handles only the overlap.
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow
                        = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
                int i_b_overflow = nstl::min(jcp.kh,
                        div_up(max(0,
                                       ih_s - jcp.ih + (jcp.kh - 1) * dilate_h
                                               + 1),
                                dilate_h));
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);

                // With signed input or src zero points the kernel uses the
                // full filter (padded rows are compensated), so the filter
                // pointer must not skip the overflowed rows.
                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                        ? 0
                        : i_t_overflow * wht_h_stride;
                p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                p.dst = dst_w;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;
                // Depthwise: the kernel indexes by group block.
                p.oc_blocks = gb;
                p.kh_padding = kh_padding;
                p.scales = scales;
                p.dst_scale = dst_scales;
                p.t_overflow = i_t_overflow;
                p.b_overflow = i_b_overflow;
                p.owb = owb;

                p.oc_l_off = g;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);
            });
    return status::success;
}
481
// Forward 3D int8 convolution driver: like the 2D path but with an extra
// depth dimension; front/back filter overflow along depth is computed per
// work item, top/bottom overflow per output row.
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_convolution_fwd_t<isa>::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    // Runtime zero points for asymmetric int8 quantization.
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Non-depthwise path: channel blocking is expressed via oc/ic blocks.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src*wei scales into a per-channel table in scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Extra data appended past the weights payload: signed-input
    // compensation (ngroups * oc int32 values) first, then src-zero-point
    // compensation.
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    // Flat work amount over (mb, groups, oc chunks, od, oh, ow blocks).
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        // Decompose the flat [start, end) range following the loop order
        // picked at primitive-descriptor creation time.
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
                        owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            // A thread-level oc chunk is processed in kernel-sized
            // nb_oc_blocking steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Rows [oh_s, oh_e) belong to this thread's work slice.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // Filter planes falling in front of / behind the input
                // along depth for this output depth index.
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;
                int dilate_d = jcp.dilate_d + 1;
                int d_f_overflow
                        = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
                int d_back_overflow = nstl::min(jcp.kd,
                        div_up(max(0,
                                       id_s - jcp.id + (jcp.kd - 1) * dilate_d
                                               + 1),
                                dilate_d));

                int kd_padding
                        = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                const int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;

                auto dst_w = dst
                        + dst_dt_size
                                * dst_d.blk_off(n, g_oc, od_s, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                        + d_f_overflow * dilate_d * src_d_stride;
                // As with height: signed input / src zero points require
                // the full filter, so only advance past the front-overflow
                // planes in the plain case.
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                        + ((jcp.signed_input || jcp.src_zero_point)
                                          ? 0
                                          : d_f_overflow)
                                * wht_d_stride;

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // Count filter rows above (t) / below (b) the input for
                    // this output row.
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // Full filter for signed input / src zero points
                    // (padded rows are compensated by the kernel).
                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : wht_h_stride * i_t_overflow;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.kd_padding = kd_padding;
                    p.scales = scales;
                    p.dst_scale = dst_scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.f_overflow = d_f_overflow;
                    p.back_overflow = d_back_overflow;
                    p.owb = owb;

                    p.oc_l_off = g_oc;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;

                    (*kernel_)(&p);
                    // Advance to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Move the flat counter to the next work item in loop order.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                            jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
675
// Explicit instantiations for the ISAs supported by this implementation.
template struct jit_uni_x8s8s32x_convolution_fwd_t<sse41>;
template struct jit_uni_x8s8s32x_convolution_fwd_t<avx2>;
678
679} // namespace x64
680} // namespace cpu
681} // namespace impl
682} // namespace dnnl
683