1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/cpu_primitive.hpp"
23
24#include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30
31using namespace dnnl::impl::status;
32using namespace dnnl::impl::memory_tracking::names;
33using namespace dnnl::impl::utils;
34
35using namespace nstl;
36
37#define wht_blk_off(d, g, ...) \
38 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
39 : (d).blk_off(__VA_ARGS__))
40
41const float *jit_avx512_core_x8s8s32x_convolution_fwd_t::adjust_oscales(
42 const memory_tracking::grantor_t &scratchpad, const float *src_scales,
43 const float *wei_scales) const {
44 auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
45 const float src_scale = src_scales[0];
46 const int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
47 float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
48 ? 1.f / pd()->jcp_.wei_adj_scale
49 : 1.f;
50 switch (wei_mask) {
51 case 0:
52 utils::array_set(loc_scales, src_scale * wei_scales[0] * factor,
53 pd()->jcp_.simd_w);
54 break;
55 default:
56 for (dim_t c = 0; c < pd()->OC(); c++)
57 loc_scales[c] = src_scale * wei_scales[c] * factor;
58 }
59 return loc_scales;
60}
61
// Forward pass for 1D (n, c, w) int8 convolution.
// Splits the work into (mb x group-chunk x oc-chunk x ow-block) items,
// balances them across threads, and invokes the JIT kernel once per item.
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_1d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;

    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    // Bias is optional; dst may be any supported data type, so offsets into
    // both buffers are computed in bytes.
    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei output scales, prepared in the scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended past the end of the weights blob:
    // [weights | s8s8 compensation (if signed input) | zp compensation].
    size_t extra_data_offset
            = weights_d.size() - weights_d.additional_buffer_size();
    // Number of compensation entries (one per output channel, incl. groups).
    size_t ch_offset = jcp.is_depthwise ? jcp.nb_ch * jcp.ch_block
                                        : jcp.ngroups * jcp.oc;
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
            : nullptr;
    // Zero-point compensation follows the s8s8 compensation when both exist.
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
                    + (jcp.signed_input ? ch_offset : 0)
            : nullptr;

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Decompose this thread's linear start index into the loop
        // counters, following the configured loop order.
        int n {0}, gg {0}, occ {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb);
                break;
            case loop_gncw:
                nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ,
                        oc_chunks, owb, jcp.nb_ow);
                break;
            case loop_nwcg:
                nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            // Translate block counters into element indices.
            int ocb = occ * jcp.nb_oc_blocking; // first oc block of the chunk
            int gb = gg * jcp.nb_ch_blocking; // first group block
            int g = gb * group_block; // first group
            int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block; // abs. oc index
            int g_ic = g * jcp.nb_ic * jcp.ic_block; // abs. ic index
            int ow_s = owb * jcp.ow_block; // first output column
            int iw_s = ow_s * jcp.stride_w; // matching input column

            p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                          : nullptr;
            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
            p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
            p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
            p.dst_scale = dst_scales;
            p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc, ow_s);
            p.src = src + src_d.blk_off(n, g_ic, iw_s);
            p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
            // is_oc_scale == 0 collapses the lookup to oscales[0].
            p.scales = &oscales[jcp.is_oc_scale * g_oc];
            p.oc_blocks = jcp.is_depthwise ? gb : ocb;
            // 1D case: no vertical dimension, so the full kernel "height" is
            // always applied and there is no top/bottom overflow.
            p.kh_padding = jcp.kh;
            p.t_overflow = 0;
            p.b_overflow = 0;
            p.owb = owb;

            p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.dst_orig = dst;
            (*kernel_)(&p);

            // Advance to the next work item in the same loop order.
            ++start;
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg,
                            nb_groups, n, jcp.mb);
                    break;
                case loop_gncw:
                    nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                            owb, jcp.nb_ow);
                    break;
                case loop_ngcw:
                    nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks,
                            owb, jcp.nb_ow);
                    break;
                case loop_nwcg:
                    nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
                            gg, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
195
// Forward pass for 2D (n, c, h, w) int8 convolution (non-depthwise; the
// depthwise path is execute_forward_2d_dw). Work items are
// (mb x group x oc-chunk x oh x ow-block); each item processes one or more
// consecutive output rows.
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    // Bias is optional; dst may be any supported data type, so offsets into
    // both buffers are computed in bytes.
    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // This path does not handle channel (depthwise-style) blocking.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    // Combined src*wei output scales, prepared in the scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended past the end of the weights blob:
    // [weights | s8s8 compensation (if signed input) | zp compensation].
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    // oc chunks are sized by the per-thread chunk; each chunk is further
    // split into nb_oc_blocking-sized pieces inside the work loop.
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Element strides (in logical elements) along the height axis.
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        // Decompose this thread's linear start index into loop counters
        // according to the configured loop order.
        int n {0}, g {0}, occ {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                        occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            // Walk the per-thread oc chunk in kernel-sized steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Process up to work_rem consecutive output rows, clamped to
                // the image height.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size * dst_d.blk_off(n, g_oc, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // Count kernel rows falling into top/bottom padding for
                    // this output row (accounting for dilation).
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // With compensation (signed input / src zero point) the
                    // kernel is given the full filter plus the overflow
                    // counts; otherwise skip the top-overflow filter rows.
                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : i_t_overflow * wht_h_stride;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.dst_scale = dst_scales;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.scales = scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.owb = owb;

                    p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel_)(&p);

                    // Move to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Advance past all rows just processed (jump), except for
            // nhwcg which steps one row at a time.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                            occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
364
// Forward pass for 2D depthwise int8 convolution. Channels coincide with
// groups here (ic_block == oc_block == 1), so the work is parallelized over
// (mb x oh x ow-block x group-chunk) with one kernel call per item.
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_2d_dw(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    // Bias is optional; dst may be any supported data type, so offsets into
    // both buffers are computed in bytes.
    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // Depthwise invariants: one input/output channel per group block.
    assert(jcp.ic_block == 1);
    assert(jcp.oc_block == 1);
    assert(jcp.nb_ic == 1);
    assert(jcp.nb_oc == 1);
    assert(jcp.nb_oc_blocking == 1);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    // Combined src*wei output scales, prepared in the scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended past the end of the weights blob:
    // [weights | s8s8 compensation (if signed input) | zp compensation].
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.nb_ch * jcp.ch_block : 0)
            : nullptr;
    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
    int group_block = jcp.ch_block;

    parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
            [&](dim_t n, dim_t oh_s, dim_t owb, dim_t gg) {
                auto p = jit_conv_call_s();

                // Element strides along the height axis.
                size_t src_h_stride = src_d.blk_off(0, 0, 1);
                size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

                int gb = gg * jcp.nb_ch_blocking; // first group block
                int g = gb * group_block; // first group (== channel)

                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;

                auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = jcp.signed_input ? compensation + g : nullptr;

                auto dst_w
                        = dst + dst_dt_size * dst_d.blk_off(n, g, oh_s, ow_s);
                auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
                auto wht_w = weights + wht_blk_off(weights_d, gb, 0);

                auto scales = &oscales[jcp.is_oc_scale * g];

                // Count kernel rows falling into top/bottom padding for this
                // output row (accounting for dilation).
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow
                        = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
                int i_b_overflow = nstl::min(jcp.kh,
                        div_up(max(0,
                                       ih_s - jcp.ih + (jcp.kh - 1) * dilate_h
                                               + 1),
                                dilate_h));
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);

                // With compensation (signed input / src zero point) the
                // kernel is given the full filter plus the overflow counts;
                // otherwise skip the top-overflow filter rows.
                size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                        ? 0
                        : i_t_overflow * wht_h_stride;
                p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                p.dst = dst_w;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g : nullptr;
                p.src_zero_point
                        = jcp.src_zero_point ? src_zero_point : nullptr;
                p.dst_zero_point
                        = jcp.dst_zero_point ? dst_zero_point : nullptr;
                p.dst_scale = dst_scales;
                p.oc_blocks = gb; // depthwise: oc blocks index group blocks
                p.kh_padding = kh_padding;
                p.scales = scales;
                p.t_overflow = i_t_overflow;
                p.b_overflow = i_b_overflow;
                p.owb = owb;

                p.oc_l_off = g * jcp.oc;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);
            });
    return status::success;
}
481
// Forward pass for 3D (n, c, d, h, w) int8 convolution. Same structure as
// the 2D path plus front/back (depth) padding-overflow handling; the depth
// offset is folded into the src/weights base pointers before the row loop.
status_t jit_avx512_core_x8s8s32x_convolution_fwd_t::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    // Bias is optional; dst may be any supported data type, so offsets into
    // both buffers are computed in bytes.
    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    // This path does not handle channel (depthwise-style) blocking.
    assert(jcp.ch_block == 1);
    assert(jcp.nb_ch_blocking == 1);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei output scales, prepared in the scratchpad.
    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    // Compensation data is appended past the end of the weights blob:
    // [weights | s8s8 compensation (if signed input) | zp compensation].
    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;
    // oc chunks are sized by the per-thread chunk; split further inside the
    // work loop.
    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
    int nb_groups = jcp.nb_ch;
    int work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_conv_call_s();

        // Element strides along the depth and height axes.
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        // Decompose this thread's linear start index into loop counters
        // according to the configured loop order.
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};
        switch (jcp.loop_order) {
            case loop_cwgn:
                nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_ngcw:
                nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                        owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
                break;
            case loop_nhwcg:
                nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh,
                        owb, jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                break;
            default: assert(!"unsupported loop order");
        }
        while (start < end) {
            // Walk the per-thread oc chunk in kernel-sized steps.
            for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
                    occ1 += jcp.nb_oc_blocking) {
                int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
                int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;

                int g_ic = g * jcp.nb_ic * jcp.ic_block;

                // Process up to work_rem consecutive output rows, clamped to
                // the image height.
                int work_rem = end - start;
                int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
                int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
                if (jcp.loop_order == loop_nhwcg)
                    oh_e = oh_s + 1; // step instead
                int ow_s = owb * jcp.ow_block;
                int iw_s = ow_s * jcp.stride_w;
                // Count kernel planes falling into front/back (depth)
                // padding for this output depth position.
                int id_s = -jcp.f_pad + od_s * jcp.stride_d;
                int dilate_d = jcp.dilate_d + 1;
                int d_f_overflow
                        = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
                int d_back_overflow = nstl::min(jcp.kd,
                        div_up(max(0,
                                       id_s - jcp.id + (jcp.kd - 1) * dilate_d
                                               + 1),
                                dilate_d));

                int kd_padding
                        = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);

                auto bias_w = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
                                   : nullptr;
                int32_t *compensation_w
                        = (jcp.signed_input) ? compensation + g_oc : nullptr;

                auto dst_w = dst
                        + dst_dt_size
                                * dst_d.blk_off(n, g_oc, od_s, oh_s, ow_s);
                // Fold the front-overflow depth offset into the base
                // pointers; with compensation the full filter is kept.
                auto src_w = src + src_d.blk_off(n, g_ic, id_s, ih_s, iw_s)
                        + d_f_overflow * dilate_d * src_d_stride;
                auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                        + ((jcp.signed_input || jcp.src_zero_point)
                                          ? 0
                                          : d_f_overflow)
                                * wht_d_stride;

                auto scales = &oscales[jcp.is_oc_scale * g_oc];

                for (int oj = oh_s, ij = ih_s; oj < oh_e;
                        ++oj, ij += jcp.stride_h) {
                    // Count kernel rows falling into top/bottom padding for
                    // this output row (accounting for dilation).
                    int dilate_h = jcp.dilate_h + 1;
                    int i_t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ij), dilate_h));
                    int i_b_overflow = nstl::min(jcp.kh,
                            div_up(max(0,
                                           ij - jcp.ih + (jcp.kh - 1) * dilate_h
                                                   + 1),
                                    dilate_h));
                    int kh_padding = nstl::max(
                            0, jcp.kh - i_t_overflow - i_b_overflow);

                    // With compensation (signed input / src zero point) the
                    // kernel is given the full filter plus the overflow
                    // counts; otherwise skip the top-overflow filter rows.
                    size_t wei_stride = (jcp.signed_input || jcp.src_zero_point)
                            ? 0
                            : wht_h_stride * i_t_overflow;
                    p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
                    p.dst = dst_w;
                    p.filt = wht_w + wei_stride;
                    p.bias = bias_w;
                    p.compensation = compensation_w;
                    p.zp_compensation = jcp.src_zero_point
                            ? zp_compensation + g_oc
                            : nullptr;
                    p.src_zero_point
                            = jcp.src_zero_point ? src_zero_point : nullptr;
                    p.dst_zero_point
                            = jcp.dst_zero_point ? dst_zero_point : nullptr;
                    p.dst_scale = dst_scales;
                    p.oc_blocks = ocb;
                    p.kh_padding = kh_padding;
                    p.kd_padding = kd_padding;
                    p.scales = scales;
                    p.t_overflow = i_t_overflow;
                    p.b_overflow = i_b_overflow;
                    p.f_overflow = d_f_overflow;
                    p.back_overflow = d_back_overflow;
                    p.owb = owb;

                    p.oc_l_off = (g * jcp.nb_oc + ocb) * jcp.oc_block;
                    p.post_ops_binary_rhs_arg_vec
                            = post_ops_binary_rhs_arg_vec.data();
                    p.dst_orig = dst;
                    (*kernel_)(&p);

                    // Move to the next output row.
                    src_w += src_h_stride * jcp.stride_h;
                    dst_w += dst_dt_size * dst_h_stride;
                }
            }
            // Advance past all rows just processed (jump), except for
            // nhwcg which steps one row at a time.
            switch (jcp.loop_order) {
                case loop_cwgn:
                    nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
                            g, nb_groups, n, jcp.mb, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_ngcw:
                    nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                            oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s,
                            jcp.oh);
                    break;
                case loop_nhwcg:
                    ++start;
                    nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                            jcp.nb_ow, occ, oc_chunks, g, nb_groups);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    });
    return status::success;
}
675
676} // namespace x64
677} // namespace cpu
678} // namespace impl
679} // namespace dnnl
680
681// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
682