1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/dnnl_thread.hpp"
18#include "common/memory_desc_wrapper.hpp"
19#include "cpu/cpu_primitive.hpp"
20#include "cpu/zero_point_utils.hpp"
21
22#include "cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp"
23
24#define GET_OFF(field) offsetof(jit_deconv_call_s, field)
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30
31using namespace dnnl::impl::status;
32using namespace dnnl::impl::memory_tracking::names;
33using namespace dnnl::impl::utils;
34using namespace Xbyak;
35
36using namespace nstl;
37
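// wht_blk_off(): weights block offset helper; the group index 'g' is passed
// to blk_off() only when the primitive actually has groups.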
38#define wht_blk_off(d, g, ...) \
39 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
40 : (d).blk_off(__VA_ARGS__))
41
42template <typename Vmm>
43jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::
44 jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
45 const primitive_attr_t &attr, const memory_desc_t &dst_md)
46 : jit_generator(jit_name())
47 , jcp(ajcp)
48 , attr_(attr)
49 , postops_injector_(nullptr) {
50
51 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
52 const std::size_t tail_size = jcp.is_depthwise
53 ? jcp.ngroups % jcp.ch_block
54 : jcp.oc_without_padding % jcp.oc_block;
55
56 static constexpr bool preserve_gpr = true;
57 static constexpr bool preserve_vmm = true;
58 static constexpr bool use_exact_tail_scalar_bcast = false;
59
60 const binary_injector::rhs_arg_static_params_t rhs_sp {
61 static_cast<size_t>(Xbyak::Xmm(31).getIdx()), this->r14,
62 this->r15, this->r13, preserve_gpr, preserve_vmm,
63 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
64 memory_desc_wrapper(dst_md), tail_size, ktail_mask,
65 use_exact_tail_scalar_bcast};
66 const binary_injector::static_params_t bsp {this->param1, rhs_sp};
67
68 postops_injector_ = utils::make_unique<
69 injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
70 this, jcp.post_ops, bsp);
71 }
72}
73
74template <typename Vmm>
75jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
76 Vmm>::~jit_avx512_core_x8s8s32x_deconv_fwd_kernel()
77 = default;
78
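/* init_conf() validates the problem (ISA, data types, memory formats) and
 * fills jit_conv_conf_t with the blocking, padding, zero-point and scale
 * parameters used by the kernel generator; any unsupported case returns
 * status::unimplemented. */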
79status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
80 jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
81 memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
82 const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr,
83 int nthreads) {
84 const memory_desc_wrapper src_d(&src_md);
85 const memory_desc_wrapper dst_d(&dst_md);
86 const memory_desc_wrapper weights_d(&weights_md);
87 const memory_desc_wrapper bias_d(&bias_md);
88
89 if (!(mayiuse(avx512_core)
90 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
91 && weights_d.data_type() == data_type::s8
92 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
93 data_type::s8, data_type::u8)))
94 return status::unimplemented;
95
96 jcp = zero<decltype(jcp)>();
97 jcp.nthr = nthreads;
98
99 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
100 jcp.signed_input = src_d.data_type() == data_type::s8;
101 const int ndims = jcp.ndims = dst_d.ndims();
102 const bool is_1d = ndims == 3;
103 const bool is_2d = ndims == 4;
104 const bool is_3d = ndims == 5;
105
106 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
107 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
108 jcp.ic = src_d.dims()[1] / jcp.ngroups;
109 jcp.id = is_3d ? src_d.dims()[2] : 1;
110 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
111 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
112 jcp.is_depthwise = true && with_groups
113 && utils::everyone_is(
114 1, jcp.ic_without_padding, jcp.oc_without_padding);
115
116 /* TODO: future work, on hold until depthwise specialized kernel is
117 * implemented. */
118 if (jcp.is_depthwise && (jcp.signed_input || is_3d))
119 return status::unimplemented;
120
121 if (!zero_points_valid(&attr)) return status::unimplemented;
122 jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC);
123 jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST);
124 jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC);
125
126 format_tag_t dat_tag = utils::pick(
127 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
128
129 if (src_d.format_kind() == format_kind::any) {
130 CHECK(memory_desc_init_by_tag(src_md, dat_tag));
131 jcp.src_tag = dat_tag;
132 } else {
133 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
134 }
135 if (jcp.src_tag != dat_tag) return status::unimplemented;
136
137 if (dst_d.format_kind() == format_kind::any) {
138 CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
139 jcp.dst_tag = dat_tag;
140 } else {
141 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
142 }
143 if (jcp.dst_tag != dat_tag) return status::unimplemented;
144
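    /* Choose the weights layout implied by the ic/oc blocking: 4i16o4i-style
     * tags for full 16-wide blocks, 2i8o4i / 4o4i for the reduced grouped
     * cases, and Goihw16g / Goiw16g for depthwise. For s8 src (non-depthwise)
     * the descriptor also requests the s8s8 compensation buffer and, without
     * VNNI, the 0.5 weights scale adjustment. When the user passed
     * format_kind::any the desired layout is set, otherwise it is verified. */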
145 auto set_or_check_wei_format = [&]() {
146 using namespace format_tag;
147 format_tag_t wei_tag;
148 if (jcp.ic_block == 16 || jcp.ch_block == 16) {
149 if (is_3d) {
150 wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
151 } else if (is_1d) {
152 wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
153 : OIw4i16o4i;
154 } else {
155 assert(is_2d);
156 wei_tag = with_groups
157 ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
158 : OIhw4i16o4i;
159 }
160 } else if (jcp.ic_block == 8) {
161 assert(with_groups);
162 wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i;
163 } else {
164 assert(with_groups && jcp.ic_block == 4);
165 wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
166 }
167
168 memory_desc_t want_wei_md = weights_md;
169 memory_desc_init_by_tag(want_wei_md, wei_tag);
170 if (jcp.signed_input && !jcp.is_depthwise) {
171 want_wei_md.extra.flags = 0
172 | memory_extra_flags::compensation_conv_s8s8
173 | memory_extra_flags::scale_adjust;
174 want_wei_md.extra.compensation_mask = (1 << 0)
175 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
176 want_wei_md.extra.scale_adjust
177 = mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
178 }
179 if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups);
180
181 if (weights_md.format_kind == format_kind::any) {
182 weights_md = want_wei_md;
183 return true;
184 }
185
186 return weights_md == want_wei_md;
187 };
188
189 jcp.with_bias = with_bias;
190 if (jcp.with_bias) {
191 if (bias_d.format_kind() == format_kind::any)
192 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
193 }
194
195 jcp.prop_kind = cd.prop_kind;
196 jcp.mb = src_d.dims()[0];
197 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
198 jcp.iw = src_d.dims()[ndims - 1];
199 jcp.od = is_3d ? dst_d.dims()[2] : 1;
200 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
201 jcp.ow = dst_d.dims()[ndims - 1];
202 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
203 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
204 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
205 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
206 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
207 jcp.l_pad = cd.padding[0][ndims - 3];
208 jcp.stride_d = is_3d ? cd.strides[0] : 1;
209 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
210 jcp.stride_w = cd.strides[ndims - 3];
211
212 if (jcp.is_depthwise) {
213 jcp.ch_block = 16;
214 jcp.oc_block = 1;
215 jcp.ic_block = 1;
216 } else {
217 jcp.ch_block = 1;
218 jcp.oc_block = 16;
219 jcp.ic_block = 16;
220
221 if (jcp.ngroups == 1) {
222 jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
223 jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
224 } else if (jcp.ngroups != 1
225 && ((jcp.ic % jcp.ic_block != 0)
226 || (jcp.oc % jcp.oc_block != 0))) {
            /* For grouped deconvolutions, oneDNN doesn't support padding.
                When the number of channels per group is not a multiple of 16:
                - Use Ymm when channels per group is a multiple of 8,
                - Use Xmm when channels per group is a multiple of 4,
                - Otherwise return unimplemented. */
232 jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4;
233 jcp.oc_block = jcp.ic_block;
234 }
235 if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
236 return status::unimplemented;
237 }
238
239 if (!set_or_check_wei_format()) return status::unimplemented;
240
241 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
242 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
243 jcp.dilate_w = cd.dilates[ndims - 3];
244
245 if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1)
246 || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
247 || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
248 return status::unimplemented;
249
250 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
251 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
252 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
253 jcp.r_pad = calculate_end_padding(
254 jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw);
255 jcp.b_pad = calculate_end_padding(
256 jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh);
257 jcp.back_pad = calculate_end_padding(
258 jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd);
259 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
260 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
261 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
262 if (kernel_outside_src) return status::unimplemented;
263
264 CHECK(attr.set_default_formats(&dst_md));
265 if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented;
266
267 const auto &p = attr.post_ops_;
268 const int eltwise_ind = p.find(primitive_kind::eltwise);
269 jcp.with_eltwise = eltwise_ind != -1;
270 if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
271 const int binary_ind = p.find(primitive_kind::binary);
272 jcp.with_binary = binary_ind != -1;
273
274 const int sum_ind = p.find(primitive_kind::sum);
275 jcp.with_sum = sum_ind != -1;
276
    // Save the post-ops descriptor for later use by the kernel generator.
278 jcp.post_ops = p;
279
280 jcp.has_vnni = mayiuse(avx512_core_vnni);
281
282 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
283 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
284 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
285 const int wei_scales_per_oc = 1 << (int)with_groups;
286 jcp.is_oc_scale = wei_scales.mask_ == wei_scales_per_oc;
287 jcp.dst_scale = !dst_scales.has_default_values();
288
    // Only common src/dst scales and common or per-oc-channel weight scales
    // are supported.
290 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_scales_per_oc)
291 && utils::everyone_is(src_scales.mask_, dst_scales.mask_, 0);
292 if (!scales_ok) return status::unimplemented;
293
294 jcp.dst_dt = dst_d.data_type();
295 jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
296 jcp.typesize_bia
297 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
298 jcp.typesize_in = types::data_type_size(src_d.data_type());
299 jcp.typesize_out = types::data_type_size(dst_d.data_type());
300
301 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
302 jcp.nb_oc = jcp.oc / jcp.oc_block;
303 jcp.nb_ic = jcp.ic / jcp.ic_block;
304
305 /* kernel blocking params */
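    // Number of vector registers available for accumulators and input
    // broadcasts; presumably two fewer without VNNI since that path also
    // needs vmm_one and vmm_tmp for the vpmaddubsw/vpmaddwd emulation
    // (see compute()).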
306 const int regs = jcp.has_vnni ? 30 : 28;
307 jcp.nb_ch_blocking = 1;
308 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
309 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
310 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
311 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
312 break;
313
314 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
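    // l_overflow - no. of spatial elements of weights standing out of src
    // spatial when computing the left-most (in w dim) output pixel (same
    // definition as in generate()).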
315 int l_overflow = max(
316 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
317
318 if (jcp.ow < jcp.ur_w) {
319 jcp.ur_w = jcp.ow;
320 jcp.ur_w_tail = 0;
321 } else {
322 for (; jcp.ur_w >= 1; jcp.ur_w--) {
            /* ur_w should be a multiple of stride_w in order
               to simplify the logic in get_ow_start and get_ow_end */
325 bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
326
            /* boundary conditions:
               these conditions ensure that all elements close to the boundary
               are computed in a single call of the compute loop */
330 bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
331 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
332 int r_overflow_no_tail = max(0,
333 ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)
334 - jcp.ur_w_tail)
335 / jcp.stride_w);
336 bool right_boundary_covered
337 = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
338
339 if (is_multiple_of_stride && left_boundary_covered
340 && right_boundary_covered)
341 break;
342 else if (jcp.ur_w == 1)
                /* The boundary conditions above also keep the calls to
                   icb_loop simple. If they cannot be satisfied, special
                   cases would have to be added to pick the correct
                   l_overflow/r_overflow values when different iterations of
                   the compute loop work on locations close to the boundary.
                   So, to keep the code simple, return unimplemented for the
                   extreme case where no suitable ur_w can be found.
                */
353 return status::unimplemented;
354 }
355 }
356
357 jcp.wei_adj_scale
358 = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
359 ? weights_d.extra().scale_adjust
360 : 1.f;
361
362 jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
363 return status::success;
364}
365
366bool _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
367 jit_conv_conf_t &jcp, primitive_attr_t &attr,
368 const memory_desc_wrapper &dst_d) {
369
370 using namespace injector;
371 const auto &post_ops = attr.post_ops_;
372 static constexpr bool sum_at_pos_0_only = true;
373 static constexpr bool sum_requires_scale_one = false;
374
375 return injector::post_ops_ok({avx512_core, {eltwise, binary, sum}, post_ops,
376 &dst_d, sum_at_pos_0_only, sum_requires_scale_one});
377}
378
379void _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
380 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
381 const primitive_attr_t &attr) {
382 const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
383 const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups;
384 const dim_t count = nstl::max<dim_t>(scales_count, 16);
385 scratchpad.book<float>(key_conv_adjusted_scales, count);
386
387 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) {
388 const dim_t zp_pad_comp_size
389 = static_cast<size_t>(jcp.oc_without_padding) * jcp.ngroups
390 * jcp.kd * jcp.kh * jcp.kw;
391 scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size);
392 }
393}
394
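/* Accumulate vreg_src * vreg_wei into vreg_acc. With VNNI this is a single
 * vpdpbusd. The depthwise path multiplies values already extended to 32 bits.
 * Otherwise the u8*s8 dot product is emulated with vpmaddubsw + vpmaddwd
 * against vmm_one (a vector of 16-bit ones) + vpaddd. */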
395template <typename Vmm>
396void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::compute(
397 const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src) {
398
399 if (jcp.has_vnni) {
400 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
401 } else if (jcp.is_depthwise) {
402 uni_vmovups(vmm_tmp, vreg_src);
403 uni_vpmulld(vmm_tmp, vmm_tmp, vreg_wei);
404 uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
405 } else {
406 uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
407 uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
408 uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
409 }
410}
411
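/* Returns a generator that hands out the vmm_inp(...) registers in
 * round-robin order; append_zp_src_pad_str_comp() uses it to hold the
 * per-ocb zero-point padding compensation in the input-broadcast registers. */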
412template <typename Vmm>
413std::function<Vmm()> jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
414 Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept {
415 const int start_vmm_idx = vmm_inp(0, jcp.nb_oc_blocking).getIdx();
416 const int end_vmm_idx = vmm_inp(ur_w - 1, jcp.nb_oc_blocking).getIdx() + 1;
417 int current_vmm_idx = start_vmm_idx;
418
419 return [=]() mutable {
420 const Vmm vmm {static_cast<int>(current_vmm_idx++)};
421
422 if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx;
423
424 return vmm;
425 };
426}
427
428template <typename Vmm>
429void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::apply_zp_src_pad_str_comp(
430 int ur_w, int l_overflow, int r_overflow, bool h_padded) {
431 Xbyak::Label end_zp_pad, no_tail;
432
    // Applied once per icb loop. The zero-point source stride/padding
    // compensation is precomputed as:
    //   zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
435 cmp(reg_icb, jcp.nb_ic);
436 jne(end_zp_pad, T_NEAR);
437
438 if (jcp.ngroups % jcp.ch_block || jcp.oc_without_padding % jcp.oc_block) {
439 if (jcp.is_depthwise)
440 cmp(reg_oc_blocks, jcp.nb_ch - 1);
441 else
442 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
443 jne(no_tail, T_NEAR);
444
445 static constexpr bool last_ocb = true;
446 append_zp_src_pad_str_comp(
447 ur_w, l_overflow, r_overflow, h_padded, last_ocb);
448 jmp(end_zp_pad, T_NEAR);
449 }
450
451 L(no_tail);
452 static constexpr bool last_ocb = false;
453
454 append_zp_src_pad_str_comp(
455 ur_w, l_overflow, r_overflow, h_padded, last_ocb);
456
457 L(end_zp_pad);
458}
459
460template <typename Vmm>
461void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
462 Vmm>::append_zp_src_pad_str_comp(int ur_w, int l_overflow,
463 int r_overflow, bool h_padded, bool last_oc_block) {
464
465 const auto &reg_zp_src_pad_comp = reg_scratch;
466 const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w);
467 bool base_comp_addr_loaded = false;
468
469 const auto load_base_zp_src_pad_comp_addr = [&]() {
470 if (!base_comp_addr_loaded) {
471 if (jcp.ndims == 5) mov(reg_scratch_preserved, reg_scratch);
472
473 if (jcp.ndims > 3)
474 mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr);
475 else
476 mov(reg_zp_src_pad_comp,
477 qword[param1 + GET_OFF(zp_src_pad_str_compensation)]);
478
479 base_comp_addr_loaded = true;
480 }
481 };
482
483 const auto load_zp_src_pad_comp = [&](const Vmm &zp_pad_comp_vmm,
484 const Xbyak::Address &comp_addr,
485 const int ocb) {
486 const bool is_last_ocb = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
487 const bool is_tail = is_last_ocb && get_tail_size() > 0;
488 if (is_tail)
489 vmovups(zp_pad_comp_vmm | ktail_mask | T_z, comp_addr);
490 else
491 vmovups(zp_pad_comp_vmm, comp_addr);
492 };
493
494 const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) {
495 const auto kw_offset = it_kw * jcp.oc_without_padding * jcp.ngroups;
496 const auto oc_offset = ocb * jcp.oc_block;
497
498 return (kw_offset + oc_offset) * sizeof(int32_t);
499 };
500
501 for (int it_kw = 0; it_kw < jcp.kw; ++it_kw) {
502 const int ow_start = get_ow_start(it_kw, l_overflow);
503 const int ow_end = get_ow_end(ur_w, it_kw, r_overflow);
504
505 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
506 Vmm zp_src_comp_pad_vmm; // will be assigned later
507 bool ocb_zp_loaded = false;
508
509 const auto zp_src_comp_pad_off
510 = get_zp_src_comp_pad_off(it_kw, ocb);
511
512 for (int it_ow = 0; it_ow < ur_w; ++it_ow) {
513
514 const bool inside_padded_area = h_padded
515 || !(it_ow >= ow_start && it_ow < ow_end
516 && ((it_ow + jcp.l_pad - it_kw) % jcp.stride_w
517 == 0));
518
519 if (inside_padded_area) {
520 load_base_zp_src_pad_comp_addr();
521
522 if (!ocb_zp_loaded) {
523 zp_src_comp_pad_vmm = get_next_comp_vmm();
524 const auto comp_addr = ptr[reg_zp_src_pad_comp
525 + zp_src_comp_pad_off];
526 load_zp_src_pad_comp(
527 zp_src_comp_pad_vmm, comp_addr, ocb);
528 ocb_zp_loaded = true;
529 }
530
531 const auto vmm_dst = vmm_out(it_ow, ocb);
532 uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm);
533 }
534 }
535 }
536 }
537
538 if (jcp.ndims > 3) {
539 if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr();
540
541 const auto kh_offset = jcp.kw * jcp.oc_without_padding * jcp.ngroups
542 * sizeof(int32_t);
543
544 add(reg_zp_src_pad_comp, kh_offset);
545 mov(zp_src_pad_comp_addr, reg_zp_src_pad_comp);
546 }
547
548 if (jcp.ndims == 5 && base_comp_addr_loaded)
549 mov(reg_scratch, reg_scratch_preserved);
550}
551
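/* Emit the multiply-accumulate code for one kernel-width pass over ur_w
 * output columns. With signed input or a src zero point the column loop runs
 * with unit stride over the whole [0, ur_w) range so that padded positions
 * and stride 'holes' still accumulate the compensation terms; otherwise only
 * columns mapping to real src pixels are visited. h_padded selects the
 * purely padded variant used for the overflow rows in kh_loop(). */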
552template <typename Vmm>
553void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::compute_ker(int ur_w,
554 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
555 bool h_padded) {
556
557 const bool signed_input_or_src_zp
558 = (jcp.signed_input || jcp.src_zero_point);
559
560 const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
561 const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp.stride_w;
562
563 auto src_offset = [=](int oj, int icb, int ki) {
564 return jcp.typesize_in
565 * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
566 * jcp.ngroups * jcp.ic_without_padding
567 + icb * 4);
568 };
569
570 auto kernel_offset = [=](int ocb, int icb, int ki) {
571 return jcp.typesize_in
572 * ((ocb * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
573 * ch_block_all
574 + icb * jcp.oc_block * ic_sub_step);
575 };
576
577 for (int ki = 0; ki < jcp.kw; ki++) {
578 int jj_start = get_ow_start(ki, l_overflow);
579 int jj_end = get_ow_end(ur_w, ki, r_overflow);
580
581 int _start = (signed_input_or_src_zp) ? 0 : jj_start;
582 int _end = (signed_input_or_src_zp) ? ur_w : jj_end;
583
584 int tail_size = jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
585 : jcp.ic_without_padding % 4;
586 int n_ic_blocks = jcp.is_depthwise
587 ? 1
588 : (last_ic_block_flag & ~no_last_block ? div_up(
589 jcp.ic_without_padding % jcp.ic_block, 4)
590 : jcp.ic_block / 4);
591
592 for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
593 if (h_padded == true) {
594 if (jcp.signed_input) {
595 /* fill padded area with shifted values */
596 const Vmm inp = vmm_inp(0, jcp.nb_oc_blocking);
597 vpxord(inp, inp, inp);
598 vpsubb(inp, inp, vmm_shift);
599 }
600 } else {
601
602 for (int jj = _start; jj < _end; jj += ur_w_stride) {
603
604 int aux_src_off = src_offset(jj, icb1, ki);
605
606 if (jj >= jj_start && jj < jj_end
607 && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
608 if (jcp.is_depthwise) {
609 Vmm vmm_src = vmm_inp(jj, jcp.nb_oc_blocking);
610 if (tail_size != 0) {
611 assert(jcp.nb_oc_blocking == 1);
612 vmm_src = vmm_src | ktail_mask | T_z;
613 }
614 vpmovzxbd(vmm_src,
615 EVEX_compress_addr(
616 aux_reg_src, aux_src_off));
617 } else if ((last_ic_block_flag & last_sp_block)
618 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
619 const Xmm xmm_tmp = Xmm(
620 vmm_inp(jj, jcp.nb_oc_blocking).getIdx());
621 for (int r = 0; r < tail_size; ++r)
622 vpinsrb(xmm_tmp, xmm_tmp,
623 ptr[aux_reg_src + aux_src_off + r], r);
624 vpbroadcastd(
625 vmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
626 } else {
627 vpbroadcastd(vmm_inp(jj, jcp.nb_oc_blocking),
628 EVEX_compress_addr(
629 aux_reg_src, aux_src_off));
630 }
631 if (jcp.signed_input)
632 vpsubb(vmm_inp(jj, jcp.nb_oc_blocking),
633 vmm_inp(jj, jcp.nb_oc_blocking), vmm_shift);
634 } else {
635 /* fill padded area with shifted values */
636 if (jcp.signed_input) {
637 const Vmm inp = vmm_inp(jj, jcp.nb_oc_blocking);
638 vpxord(inp, inp, inp);
639 vpsubb(inp, inp, vmm_shift);
640 }
641 }
642 }
643 }
644 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
645 int aux_filt_off = kernel_offset(ocb, icb1, ki);
646
647 if (_end - _start > 0) {
648 if (jcp.is_depthwise)
649 vpmovsxbd(vmm_wei,
650 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
651 else
652 vmovups(vmm_wei,
653 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
654 }
655 for (int jj = _start; jj < _end; jj += ur_w_stride) {
656 const bool jj_between_start_end
657 = jj >= jj_start && jj < jj_end;
658 const bool ki_applies_to_stride
659 = (jj + jcp.l_pad - ki) % jcp.stride_w == 0;
660 const bool inside_padded_area = h_padded
661 || !(jj_between_start_end && ki_applies_to_stride);
662 const auto vmm_dst = vmm_out(jj, ocb);
663 if (jcp.signed_input || !inside_padded_area) {
664 const Vmm inp = vmm_inp(
665 h_padded ? 0 : jj, jcp.nb_oc_blocking);
666 compute(vmm_dst, vmm_wei, inp);
667 }
668 }
669 }
670 }
671 }
672
673 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
674 apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded);
675}
676
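/* Loop over kernel height (and kernel depth for the 3D case). With signed
 * input or a src zero point the filter is additionally walked over padded
 * rows and stride 'holes' (kh_comp_loop / kd_comp_loop) so that the weight
 * compensation is accumulated for every filter element. */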
677template <typename Vmm>
678void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::kh_loop(int ur_w,
679 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
680
681 const bool signed_input_or_src_zp
682 = (jcp.signed_input || jcp.src_zero_point);
683
684 int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
685 int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
686 * jcp.ngroups * jcp.ic_without_padding;
687 int shift_src_id = jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * jcp.iw
688 * jcp.ngroups * jcp.ic_without_padding;
689 const int stride_h = signed_input_or_src_zp ? 1 : jcp.stride_h;
690 int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
691 const int stride_d = signed_input_or_src_zp ? 1 : jcp.stride_d;
692 int shift_filt_kd
693 = jcp.typesize_in * jcp.kw * ch_block_all * jcp.kh * stride_d;
694
695 Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop;
696 Label t_overflow_label, no_t_overflow_label, b_overflow_label,
697 no_b_overflow_label;
698 Label back_overflow_label, no_back_overflow_label, d_h_overflow_label,
699 front_overflow_label, no_front_overflow_label, d_h_overflow_label2;
700 if (jcp.ndims == 5) {
701 mov(aux_reg_filt_d, reg_filt);
702 mov(aux_reg_src_d, reg_src);
703
704 if (signed_input_or_src_zp) {
705 mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
706 cmp(reg_ki, 0);
707 je(no_back_overflow_label, T_NEAR);
708 L(back_overflow_label);
709 {
710 mov(aux_reg_filt, aux_reg_filt_d);
711 mov(reg_kh, jcp.kh);
712 L(d_h_overflow_label);
713 {
714 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
715 add(aux_reg_filt, shift_filt_kh);
716 dec(reg_kh);
717 jnz(d_h_overflow_label);
718 }
719
720 add(aux_reg_filt_d, shift_filt_kd);
721 dec(reg_ki);
722 jnz(back_overflow_label);
723 }
724 L(no_back_overflow_label);
725 }
726
727 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
728
729 if ((signed_input_or_src_zp) || (jcp.dilate_d >= jcp.id)
730 || ((!signed_input_or_src_zp)
731 && ((min(jcp.f_pad, jcp.back_pad) < 0)
732 || ((jcp.kd - 1) * (jcp.dilate_d + 1)
733 < nstl::max(
734 jcp.f_pad, jcp.back_pad))))) {
735 cmp(reg_ki, 0);
736 je(skip_kd_loop, T_NEAR);
737 }
738
739 L(kd_loop_label);
740 mov(aux_reg_src, aux_reg_src_d);
741 mov(aux_reg_filt, aux_reg_filt_d);
742 } else {
743 mov(aux_reg_src, reg_src);
744 mov(aux_reg_filt, reg_filt);
745 }
746
747 if (signed_input_or_src_zp && jcp.ndims > 3) {
748 /* Weights are transposed, so first compute 'bottom' padding. */
749 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
750 cmp(reg_overflow, 0);
751 je(no_b_overflow_label, T_NEAR);
752 L(b_overflow_label);
753 {
754 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
755
756 add(aux_reg_filt, shift_filt_kh);
757 dec(reg_overflow);
758 cmp(reg_overflow, 0);
759 jg(b_overflow_label, T_NEAR);
760 }
761 L(no_b_overflow_label);
762 }
763
764 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
765
766 if ((signed_input_or_src_zp) || (jcp.dilate_h >= jcp.ih)
767 || ((!signed_input_or_src_zp)
768 && ((min(jcp.t_pad, jcp.b_pad) < 0)
769 || ((jcp.kh - 1) * (jcp.dilate_h + 1)
770 < nstl::max(jcp.t_pad, jcp.b_pad))))) {
771 cmp(reg_kh, 0);
772 je(skip_kh_loop, T_NEAR);
773 }
774
775 L(kh_loop_label);
776 {
777 compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
778 sub(aux_reg_src, shift_src_ih);
779 add(aux_reg_filt, shift_filt_kh);
780 dec(reg_kh);
781
782 /* Insert weight compensation in stride 'holes' */
783 if (signed_input_or_src_zp && jcp.stride_h > 1) {
784 Label kh_comp_loop;
785
786 cmp(reg_kh, 0);
787 je(skip_kh_loop, T_NEAR);
788 mov(reg_comp_strides, jcp.stride_h - 1);
789 L(kh_comp_loop);
790 {
791 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
792 add(aux_reg_filt, shift_filt_kh);
793 dec(reg_comp_strides);
794 cmp(reg_comp_strides, 0);
795 jg(kh_comp_loop, T_NEAR);
796 }
797 }
798 cmp(reg_kh, 0);
799 jg(kh_loop_label, T_NEAR);
800 }
801 L(skip_kh_loop);
802 if (signed_input_or_src_zp && jcp.ndims > 3) {
803 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
804 cmp(reg_overflow, 0);
805 je(no_t_overflow_label, T_NEAR);
806 L(t_overflow_label);
807 {
808 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
809
810 add(aux_reg_filt, shift_filt_kh);
811 dec(reg_overflow);
812 cmp(reg_overflow, 0);
813 jg(t_overflow_label, T_NEAR);
814 }
815 L(no_t_overflow_label);
816 }
817
818 if (jcp.ndims == 5) {
819 sub(aux_reg_src_d, shift_src_id);
820 add(aux_reg_filt_d, shift_filt_kd);
821 dec(reg_ki);
822
823 /* Insert weight compensation in stride 'holes' */
824 if (signed_input_or_src_zp && jcp.stride_d > 1) {
825 Label kd_comp_loop, kd_kh_comp_loop;
826 cmp(reg_ki, 0);
827 jz(skip_kd_loop, T_NEAR);
828 mov(reg_comp_strides, jcp.stride_d - 1);
829 L(kd_comp_loop);
830 mov(aux_reg_filt, aux_reg_filt_d);
831 mov(reg_kh, jcp.kh);
832 L(kd_kh_comp_loop);
833 {
834 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
835 add(aux_reg_filt, shift_filt_kh);
836 dec(reg_kh);
837 jnz(kd_kh_comp_loop, T_NEAR);
838 }
839 add(aux_reg_filt_d, shift_filt_kd);
840 dec(reg_comp_strides);
841 jnz(kd_comp_loop);
842 }
843
844 cmp(reg_ki, 0);
845 jg(kd_loop_label, T_NEAR);
846 L(skip_kd_loop);
847 if (signed_input_or_src_zp) {
848 mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
849 cmp(reg_ki, 0);
850 jz(no_front_overflow_label, T_NEAR);
851 L(front_overflow_label);
852 {
853 mov(aux_reg_filt, aux_reg_filt_d);
854 mov(reg_kh, jcp.kh);
855 L(d_h_overflow_label2);
856 {
857 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
858 add(aux_reg_filt, shift_filt_kh);
859 dec(reg_kh);
860 jnz(d_h_overflow_label2);
861 }
862 add(aux_reg_filt_d, shift_filt_kd);
863 dec(reg_ki);
864 jnz(front_overflow_label);
865 }
866 L(no_front_overflow_label);
867 }
868 }
}

template <typename Vmm>
871int jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_tail_size() const
872 noexcept {
873 return jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
874 : jcp.oc_without_padding % jcp.oc_block;
875}
876
877template <typename Vmm>
878int jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_blocking_size() const
879 noexcept {
880 return jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
881}
882
883template <typename Vmm>
884void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::prepare_output(int ur_w) {
885 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
886 for (int ur = 0; ur < ur_w; ur++) {
887 const Vmm vmm = vmm_out(ur, ocb);
888 vpxord(vmm, vmm, vmm);
889 }
890 }
891 if (jcp.signed_input) {
892 xor_(reg_scratch, reg_scratch);
893 Reg8 _t8 = reg_scratch.cvt8();
894 mov(_t8, (int8_t)-128);
895 vpbroadcastb(vmm_shift, _t8);
896 }
897}
898
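/* Load 'op' into vmm_in and convert s8/u8/s32 data to f32; with mask_flag
 * set, ktail_mask zeroes the lanes beyond the channel tail. */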
899template <typename Vmm>
900void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::cvt2ps(
901 data_type_t type_in, Vmm vmm_in, const Operand &op, bool mask_flag) {
902 const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
903 switch (type_in) {
904 case data_type::f32:
905 case data_type::s32: vmovups(vmm, op); break;
906 case data_type::s8: vpmovsxbd(vmm, op); break;
907 case data_type::u8: vpmovzxbd(vmm, op); break;
908 default: assert(!"unsupported data type");
909 }
910 if (type_in != data_type::f32) vcvtdq2ps(vmm_in, vmm_in);
911}
912
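/* Finalize one ur_w block: apply the src zero-point compensation to the
 * int32 accumulators, convert to f32, add the s8s8 compensation, apply
 * scales and bias, run the post-ops injector, apply the dst scale and zero
 * point, saturate, and store with conversion to jcp.dst_dt. */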
913template <typename Vmm>
914void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::store_output(
915 int ur_w, bool last_oc_block) {
916 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
917 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
918
919 if (jcp.signed_input)
920 mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
921
922 if (jcp.src_zero_point) {
923 mov(reg_zp_src_, ptr[param1 + GET_OFF(src_zero_point)]);
924 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
925 }
926
927 if (jcp.src_zero_point) {
928 const auto &vmm_src_zp = vmm_tmp;
929 const auto &vmm_zp_comp = vmm_wei;
930 uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]);
931
932 const bool is_tail = get_tail_size() > 0;
933 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
934 const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
935 const bool is_last_ocb
936 = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
937 const auto vmm = is_last_ocb && is_tail
938 ? vmm_zp_comp | ktail_mask | T_z
939 : vmm_zp_comp;
940 vmovups(vmm, ptr[reg_zp_compensation + zp_offset]);
941
942 uni_vpmulld(vmm_zp_comp, vmm, vmm_src_zp);
943
944 for (int ur = 0; ur < ur_w; ur++) {
945 const auto vmm_dst = vmm_out(ur, ocb);
946 uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp);
947 }
948 }
949 }
950
951 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
952 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
953 int scale_offset
954 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
955
956 const Vmm vmm_bias = vmm_tmp;
957 if (jcp.with_bias) {
958 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
959 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
960 cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
961 }
962 if (jcp.signed_input) {
963 int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
964 auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
965 cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
966 }
967
968 for (int ur = 0; ur < ur_w; ur++) {
969 const Vmm vmm = vmm_out(ur, ocb);
970 vcvtdq2ps(vmm, vmm);
971 if (jcp.signed_input) vaddps(vmm, vmm, vmm_comp);
972 const Vmm mask_vmm = mask_flag ? vmm | ktail_mask | T_z : vmm;
973 vmulps(mask_vmm, vmm,
974 EVEX_compress_addr(reg_ptr_scales, scale_offset));
975 if (jcp.with_bias) vaddps(vmm, vmm, vmm_bias);
976 }
977 }
978 /* Do post-ops */
979 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
980 const auto &p = attr_.post_ops_;
981 const int sum_idx = p.find(primitive_kind::sum);
982 const float *p_sum_scale
983 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
984 if (p_sum_scale && *p_sum_scale != 1.f)
985 mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
986
987 const auto sum_injector = [&]() {
988 if (p_sum_scale) { // post_op: sum
989 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
990 const bool mask_flag
991 = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
992 for (int j = 0; j < ur_w; j++) {
993 int aux_output_offset = jcp.typesize_out
994 * (k * jcp.oc_block
995 + j * jcp.oc_without_padding
996 * jcp.ngroups);
997 auto addr = EVEX_compress_addr(
998 reg_dst, aux_output_offset);
999 const Vmm vmm = vmm_out(j, k);
1000 cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
1001 if (*p_sum_scale == 1.f)
1002 vaddps(vmm, vmm_prev_dst);
1003 else
1004 vfmadd231ps(vmm, vmm_prev_dst,
1005 zword_b[reg_ptr_sum_scale]);
1006 }
1007 }
1008 }
1009 };
1010
1011 if (p_sum_scale)
1012 postops_injector_->set_lambda_injector(
1013 primitive_kind::sum, sum_injector);
1014
1015 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1016 if (jcp.with_binary) {
1017 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1018 const bool mask_flag
1019 = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
1020 for (int ur = 0; ur < ur_w; ur++) {
1021 const int vmm_idx = vmm_out(ur, ocb).getIdx();
1022 const size_t aux_output_offset = jcp.typesize_out
1023 * (ocb * jcp.oc_block
1024 + ur * jcp.oc_without_padding
1025 * jcp.ngroups);
1026
1027 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
1028 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
1029 vmm_idx, aux_output_offset);
1030 if (mask_flag)
1031 rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1032 }
1033 }
1034 }
1035 const int nb_oc_block
1036 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
1037 postops_injector_->compute_vector_range(
1038 0, nb_oc_block * ur_w, rhs_arg_params);
1039 }
1040
1041 if (jcp.dst_scale) {
1042 mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]);
1043 uni_vmovups(vmm_dst_scale, ptr[reg_ptr_dst_scales]);
1044
1045 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1046 const bool mask_flag
1047 = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
1048 for (int ur = 0; ur < ur_w; ur++) {
1049 const auto vmm = vmm_out(ur, ocb);
1050 const Vmm mask_vmm = mask_flag ? vmm | ktail_mask | T_z : vmm;
1051 uni_vmulps(mask_vmm, vmm, vmm_dst_scale);
1052 }
1053 }
1054 }
1055 if (jcp.dst_zero_point) {
1056 mov(reg_zp_dst_, ptr[param1 + GET_OFF(dst_zero_point)]);
1057 const auto &vmm_zp_dst = vmm_tmp;
1058 uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]);
1059 vcvtdq2ps(vmm_zp_dst, vmm_zp_dst);
1060
1061 for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
1062 for (int ur = 0; ur < ur_w; ur++) {
1063 const auto vmm_dst = vmm_out(ur, ocb);
1064 uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst);
1065 }
1066 }
1067
    // Properly saturate the accumulators for integer data types.
    // There is no need to saturate on the lower bound for signed integer
    // types: the conversion to int returns INT_MIN, and proper saturation
    // then happens when the data is stored.
1073 if (jcp.dst_dt == data_type::u8) {
1074 vpxord(vmm_zero, vmm_zero, vmm_zero);
1075 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1076 for (int ur = 0; ur < ur_w; ur++) {
1077 const Vmm vmm = vmm_out(ur, ocb);
1078 vmaxps(vmm, vmm_zero, vmm);
1079 }
1080 }
1081 }
1082
1083 if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1084 float saturation_ubound = types::max_value<float>(jcp.dst_dt);
1085 Xmm xmm_saturation(vmm_saturation.getIdx());
1086 mov(reg_ptr_saturation_ubound, float2int(saturation_ubound));
1087 vmovq(xmm_saturation, reg_ptr_saturation_ubound);
1088 vbroadcastss(vmm_saturation, xmm_saturation);
1089
1090 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1091 for (int ur = 0; ur < ur_w; ur++) {
1092 const Vmm vmm = vmm_out(ur, ocb);
1093 vminps(vmm, vmm, vmm_saturation);
1094 }
1095 }
1096 }
1097
1098 if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1099 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1100 for (int ur = 0; ur < ur_w; ur++) {
1101 const Vmm vmm = vmm_out(ur, ocb);
1102 vcvtps2dq(vmm, vmm);
1103 }
1104 }
1105 }
1106
    /* write the result registers out to dst */
1108 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
1109 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
1110 for (int ur = 0; ur < ur_w; ur++) {
1111 int aux_dst_off = jcp.typesize_out
1112 * (ur * jcp.ngroups * jcp.oc_without_padding
1113 + ocb * jcp.oc_block);
1114 auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
1115
1116 const Vmm vmm = vmm_out(ur, ocb);
1117 const Vmm r_vmm = mask_flag ? vmm | ktail_mask : vmm;
1118 switch (jcp.dst_dt) {
1119 case data_type::f32:
1120 case data_type::s32: vmovups(addr, r_vmm); break;
1121 case data_type::s8: vpmovsdb(addr, r_vmm); break;
1122 case data_type::u8: vpmovusdb(addr, r_vmm); break;
1123 default: assert(!"unknown dst_dt");
1124 }
1125 }
1126 }
1127}
1128
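/* Accumulate over all input-channel blocks for one block of ur_w output
 * columns, then store the result; the last icb may take the ic-tail path and
 * the last oc block uses the masked store variant. */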
1129template <typename Vmm>
1130void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::icb_loop(
1131 int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
1132
1133 int shift_src_icb = jcp.typesize_in * jcp.ic_block;
1134 const size_t shift_filt_icb = (size_t)jcp.typesize_in * jcp.kd * jcp.kh
1135 * jcp.kw * jcp.ic_block * jcp.oc_block;
1136
1137 prepare_output(ur_w);
1138
1139 Label skip_icb_loop, icb_loop_label;
1140
1141 mov(reg_icb, jcp.nb_ic);
1142
1143 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) {
1144 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
1145 if (jcp.ndims > 3) {
1146 mov(reg_scratch,
1147 qword[param1 + GET_OFF(zp_src_pad_str_compensation)]);
1148 mov(zp_src_pad_comp_addr, reg_scratch);
1149 }
1150 }
1151
1152 L(icb_loop_label);
1153 {
1154
1155 if (jcp.ic_without_padding != jcp.ic) {
1156 Label common_ker, end_ker;
1157 cmp(reg_icb, 1);
1158 jg(common_ker, T_NEAR);
1159
1160 kh_loop(ur_w, l_overflow, r_overflow,
1161 is_last_sp_block ? last_sp_block : last_ic_block);
1162 jmp(end_ker, T_NEAR);
1163
1164 L(common_ker);
1165 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
1166
1167 L(end_ker);
1168 } else {
1169 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
1170 }
1171
1172 add(reg_src, shift_src_icb);
1173 safe_add(reg_filt, shift_filt_icb, reg_ker_long_offt);
1174 dec(reg_icb);
1175 cmp(reg_icb, 0);
1176 jg(icb_loop_label, T_NEAR);
1177 }
1178
    /* rewind the src and weights pointers back to the start of the icb range */
1180 sub(reg_src, jcp.nb_ic * shift_src_icb);
1181 safe_sub(reg_filt, jcp.nb_ic * shift_filt_icb, reg_ker_long_offt);
1182 L(skip_icb_loop);
1183
1184 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
1185 Label common_store, end_store;
1186 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
1187 if (jcp.is_depthwise)
1188 cmp(reg_oc_blocks, jcp.nb_ch - 1);
1189 else
1190 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
1191 jne(common_store, T_NEAR);
1192
1193 store_output(ur_w, true);
1194 jmp(end_store, T_NEAR);
1195
1196 L(common_store);
1197 store_output(ur_w, false);
1198
1199 L(end_store);
1200
1201 } else {
1202 store_output(ur_w, false);
1203 }
1204}
1205
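/* Classify the jcp.ow / jcp.ur_w blocks into leading blocks with left
 * overflow, an overflow-free middle region (emitted as a runtime loop in
 * generate()), and trailing blocks that either have right overflow or whose
 * 4-byte input broadcast would read past the end of src and therefore need
 * careful processing. */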
1206template <typename Vmm>
1207ur_w_blks_params_t
1208jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_ur_w_blks_params() {
1209 const int n_ur_blocks = jcp.ow / jcp.ur_w;
1210
1211 ur_w_blks_params_t ur_w_blks_params;
1212 int num_blks_to_process_sp_carefully = 0;
1213 int idx_last_non_zero_l_overflow_blk = -1;
1214 int idx_first_non_zero_r_overflow_blk = n_ur_blocks;
1215
1216 static constexpr int src_pixels_loaded_for_bcast = 4;
1217 const auto ic_mod = jcp.ic_without_padding % src_pixels_loaded_for_bcast;
1218 for (int blk_idx = 0; blk_idx < n_ur_blocks; blk_idx++) {
1219 const int first_blk_dst_elem = blk_idx * jcp.ur_w;
1220 const int last_dst_blk_elem = first_blk_dst_elem + jcp.ur_w - 1;
1221
1222 const int last_blk_src_idx = nstl::min(
1223 jcp.iw - 1, (last_dst_blk_elem + jcp.l_pad) / jcp.stride_w);
1224 const bool is_out_of_src_pixels_scope
1225 = ((jcp.iw - 1 - last_blk_src_idx) * jcp.ic_without_padding
1226 + ic_mod
1227 < src_pixels_loaded_for_bcast);
1228
1229 const bool process_sp_carefully
1230 = (ic_mod != 0) && is_out_of_src_pixels_scope;
1231 const int curr_l_overflow = nstl::max(0,
1232 ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad
1233 - first_blk_dst_elem)
1234 / jcp.stride_w);
1235 const int curr_r_overflow = nstl::max(0,
1236 (last_dst_blk_elem + jcp.l_pad) / jcp.stride_w - (jcp.iw - 1));
1237
1238 ur_w_blks_params.blks_params.emplace_back(
1239 curr_l_overflow, curr_r_overflow, process_sp_carefully);
1240
1241 num_blks_to_process_sp_carefully
1242 += static_cast<int>(process_sp_carefully);
1243 if (curr_l_overflow > 0) idx_last_non_zero_l_overflow_blk = blk_idx;
1244 if (curr_r_overflow > 0 && idx_first_non_zero_r_overflow_blk > blk_idx)
1245 idx_first_non_zero_r_overflow_blk = blk_idx;
1246 }
1247 idx_first_non_zero_r_overflow_blk
1248 = nstl::max(idx_first_non_zero_r_overflow_blk,
1249 idx_last_non_zero_l_overflow_blk + 1);
1250 // limit num_r_overflow_blks and num_blks_to_process_last_sp_carefully so that:
1251 // n_ur_blocks >= num_l_overflow_blks + max(num_r_overflow_blks, num_blks_to_process_last_sp_carefully)
1252 ur_w_blks_params.num_pre_blks
1253 = nstl::max(0, idx_last_non_zero_l_overflow_blk + 1);
1254 const int num_r_overflow_blks = idx_first_non_zero_r_overflow_blk
1255 <= idx_last_non_zero_l_overflow_blk
1256 ? n_ur_blocks - ur_w_blks_params.num_pre_blks
1257 : n_ur_blocks - idx_first_non_zero_r_overflow_blk;
1258 num_blks_to_process_sp_carefully
1259 = ur_w_blks_params.num_pre_blks + num_blks_to_process_sp_carefully
1260 < n_ur_blocks
1261 ? num_blks_to_process_sp_carefully
1262 : n_ur_blocks - ur_w_blks_params.num_pre_blks;
1263 ur_w_blks_params.num_post_blks
1264 = nstl::max(num_r_overflow_blks, num_blks_to_process_sp_carefully);
1265
1266 return ur_w_blks_params;
1267}
1268
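/* Kernel entry point: set up vmm_one and the channel tail mask, then emit
 * icb_loop for the leading overflow blocks, a runtime loop over the
 * overflow-free middle blocks, the trailing blocks, and finally the ur_w
 * tail. */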
1269template <typename Vmm>
1270void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::generate() {
1271 preamble();
1272
1273 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1274 sub(rsp, reserved_stack_size_);
1275
1276 xor_(reg_scratch, reg_scratch);
1277 Reg16 _t = reg_scratch.cvt16();
1278 mov(_t, 0x1);
1279 vpbroadcastw(vmm_one, _t);
1280
1281 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
1282 int tail_size = jcp.is_depthwise
1283 ? jcp.ngroups % jcp.ch_block
1284 : jcp.oc_without_padding % jcp.oc_block;
1285 int mask = (1 << tail_size) - 1;
1286 Reg32 regw_tmp = reg_nur_w.cvt32();
1287 Label skip_tail_mask;
1288 if (jcp.is_depthwise) {
1289 kxnorw(ktail_mask, ktail_mask, ktail_mask);
1290 cmp(dword[param1 + GET_OFF(oc_blocks)], jcp.nb_ch - 1);
1291 jne(skip_tail_mask, T_NEAR);
1292 }
1293 mov(regw_tmp, mask);
1294 kmovw(ktail_mask, regw_tmp);
1295 L(skip_tail_mask);
1296 }
1297
1298 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1299 mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
1300 mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
1301
1302 int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
1303 * jcp.oc_without_padding;
1304 int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
1305 * jcp.ic_without_padding;
1306
1307 const auto ur_w_blks_params = get_ur_w_blks_params();
1308 const int nur_w = jcp.ow / jcp.ur_w - ur_w_blks_params.num_pre_blks
1309 - ur_w_blks_params.num_post_blks;
1310
1311 const auto &blks_params = ur_w_blks_params.blks_params;
1312 const auto num_pre_blks = ur_w_blks_params.num_pre_blks;
1313 const auto num_post_blks = ur_w_blks_params.num_post_blks;
1314
1315 for (int i = 0; i < num_pre_blks; i++) {
1316 const bool blk_process_carefully = blks_params[i].process_sp_carefully;
1317 const int blk_l_overflow = blks_params[i].l_overflow;
1318 const int blk_r_overflow = blks_params[i].r_overflow;
1319
1320 icb_loop(jcp.ur_w, blk_l_overflow, blk_r_overflow,
1321 blk_process_carefully);
1322 add(reg_src, src_shift);
1323 add(reg_dst, dst_shift);
1324 }
1325
1326 if (nur_w > 0) {
1327 xor_(reg_nur_w, reg_nur_w);
1328 Label ow_loop_label;
1329 L(ow_loop_label);
1330 {
1331 icb_loop(jcp.ur_w, 0, 0, false);
1332 add(reg_src, src_shift);
1333 add(reg_dst, dst_shift);
1334 inc(reg_nur_w);
1335 cmp(reg_nur_w, nur_w);
1336 jl(ow_loop_label, T_NEAR);
1337 }
1338 }
1339
1340 if (num_post_blks > 0) {
1341 const auto blks_params_size = blks_params.size();
1342 const auto start_blk_idx = blks_params_size - num_post_blks;
1343 for (size_t i = start_blk_idx; i < blks_params_size; i++) {
1344 const bool blk_process_carefully
1345 = blks_params[i].process_sp_carefully;
1346 const int blk_l_overflow = blks_params[i].l_overflow;
1347 const int blk_r_overflow = blks_params[i].r_overflow;
1348
1349 icb_loop(jcp.ur_w, blk_l_overflow, blk_r_overflow,
1350 blk_process_carefully);
1351 add(reg_src, src_shift);
1352 add(reg_dst, dst_shift);
1353 }
1354 }
1355
1356 if (jcp.ur_w_tail != 0) {
1357 // l_overflow - no. of spatial elements of weights standing out of src spatial
1358 // when computing the left-most (in w dim) output pixel
1359 int l_overflow = 0;
1360 if (jcp.ur_w == jcp.ow)
1361 l_overflow = max(0,
1362 ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad)
1363 / jcp.stride_w);
        // r_overflow - no. of spatial elements of weights standing out of src spatial
        // when computing the right-most (in w dim) output pixel
1366 const int r_overflow = max(0,
1367 ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
1368 / jcp.stride_w);
1369
1370 icb_loop(jcp.ur_w_tail, l_overflow, r_overflow, true);
1371 }
1372
1373 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1374 add(rsp, reserved_stack_size_);
1375
1376 postamble();
1377
1378 if (jcp.with_eltwise) postops_injector_->prepare_table();
1379}
1380
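/* Precompute the combined src * wei scales into the scratchpad, folding in
 * 1 / wei_adj_scale to undo the 0.5 weights scale adjustment requested for
 * signed input on systems without VNNI (see set_or_check_wei_format()). */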
1381const float *jit_avx512_core_x8s8s32x_deconvolution_fwd_t::adjust_oscales(
1382 const memory_tracking::grantor_t &scratchpad, const float *src_scales,
1383 const float *wei_scales) const {
1384 auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
1385 int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
1386 float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
1387 ? 1.f / pd()->jcp_.wei_adj_scale
1388 : 1.0f;
1389 if (wei_mask == 0) {
1390 utils::array_set(
1391 loc_scales, src_scales[0] * wei_scales[0] * factor, 16);
1392 } else {
1393 for (dim_t c = 0; c < pd()->OC(); c++)
1394 loc_scales[c] = src_scales[0] * wei_scales[c] * factor;
1395 }
1396 return loc_scales;
1397}
1398
1399status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d(
1400 const exec_ctx_t &ctx) const {
1401 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1402 const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
1403 const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1404 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1405 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
1406 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
1407
1408 const memory_desc_wrapper src_d(pd()->src_md());
1409 const memory_desc_wrapper dst_d(pd()->dst_md());
1410 const memory_desc_wrapper weights_d(pd()->weights_md(0));
1411 const memory_desc_wrapper bias_d(pd()->weights_md(1));
1412
1413 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1414 const auto &jcp = pd()->jcp_;
1415
1416 auto scratchpad = ctx.get_scratchpad_grantor();
1417 int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);
1418
1419 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1420 zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
1421 weights_d, weights, zp_src, zp_src_comp_scratch,
1422 zp_src_pad_comp_kernel_.get());
1423
1424 const auto post_ops_binary_rhs_arg_vec
1425 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
1426
1427 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1428 const int nb_groups = jcp.nb_ch;
1429
1430 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
1431 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
1432 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
1433
1434 const float *oscales = adjust_oscales(
1435 ctx.get_scratchpad_grantor(), src_scales, wei_scales);
1436
1437 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1438 auto w = const_cast<int8_t *>(weights);
1439 int32_t *compensation = (jcp.signed_input)
1440 ? reinterpret_cast<int32_t *>(&w[offset])
1441 : nullptr;
1442 const int32_t *zp_compensation = jcp.src_zero_point
1443 ? get_src_zp_comp_from_wei(
1444 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1445 : nullptr;
1446
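    /* Parallelize over the (mb, group, oc-chunk) space; each work item fills
       one jit_deconv_call_s and invokes the kernel, which computes the whole
       output width for that chunk. */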
1447 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1448 int start {0}, end {0};
1449 int work_amount = jcp.mb * nb_groups * oc_chunks;
1450 balance211(work_amount, nthr, ithr, start, end);
1451
1452 auto p = jit_deconv_call_s();
1453
1454 int n {0}, g {0}, occ {0};
1455 if (jcp.loop_order == loop_ngc)
1456 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
1457 else if (jcp.loop_order == loop_cgn)
1458 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
1459 else
1460 assert(!"unsupported loop order");
1461 while (start < end) {
1462
1463 int ocb = occ * jcp.nb_oc_blocking;
1464 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
1465 int g_ic = g * jcp.ch_block * jcp.ic;
1466
1467 p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
1468 p.src = src + src_d.blk_off(n, g_ic);
1469 p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
1470 p.bias = jcp.with_bias
1471 ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
1472 : nullptr;
1473 p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
1474 p.scales = &oscales[jcp.is_oc_scale * g_oc];
1475 p.dst_scale = dst_scales;
1476 p.t_overflow = 0;
1477 p.b_overflow = 0;
1478 p.kh_padding = jcp.kh;
1479 p.oc_blocks = jcp.is_depthwise ? g : ocb;
1480 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
1481 p.oc_l_off = g_oc;
1482 p.zp_compensation
1483 = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
1484 p.zp_src_pad_str_compensation
1485 = jcp.src_zero_point ? zp_src_comp_scratch + g_oc : nullptr;
1486 p.src_zero_point = zp_src;
1487 p.dst_zero_point = zp_dst;
1488 p.dst_orig = dst;
1489 (*kernel_)(&p);
1490
1491 ++start;
1492 if (jcp.loop_order == loop_ngc)
1493 nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
1494 else if (jcp.loop_order == loop_cgn)
1495 nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
1496 else
1497 assert(!"unsupported loop order");
1498 }
1499 });
1500 return status::success;
1501}
1502
1503status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d(
1504 const exec_ctx_t &ctx) const {
1505 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1506 const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
1507 const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1508 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1509 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
1510 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
1511
1512 const auto &jcp = pd()->jcp_;
1513
1514 const memory_desc_wrapper src_d(pd()->src_md());
1515 const memory_desc_wrapper dst_d(pd()->dst_md());
1516 const memory_desc_wrapper weights_d(pd()->weights_md(0));
1517 const memory_desc_wrapper bias_d(pd()->weights_md(1));
1518 const auto post_ops_binary_rhs_arg_vec
1519 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
1520
1521 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1522
1523 auto scratchpad = ctx.get_scratchpad_grantor();
1524 int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);
1525
1526 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1527 zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
1528 weights_d, weights, zp_src, zp_src_comp_scratch,
1529 zp_src_pad_comp_kernel_.get());
1530
1531 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1532 int nb_groups = jcp.nb_ch;
1533
1534 size_t src_h_stride = src_d.blk_off(0, 0, 1);
1535 size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
1536 size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
1537
1538 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
1539 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
1540 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
1541
1542 const float *oscales = adjust_oscales(
1543 ctx.get_scratchpad_grantor(), src_scales, wei_scales);
1544
1545 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1546 auto w = const_cast<int8_t *>(weights);
1547 int32_t *compensation = (jcp.signed_input)
1548 ? reinterpret_cast<int32_t *>(&w[offset])
1549 : nullptr;
1550 const int32_t *zp_compensation = jcp.src_zero_point
1551 ? get_src_zp_comp_from_wei(
1552 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1553 : nullptr;
1554
1555 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1556 int start {0}, end {0};
1557 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
1558 balance211(work_amount, nthr, ithr, start, end);
1559
1560 auto p = jit_deconv_call_s();
1561
        /* loop order = ngc or cgn */
1563 int n {0}, g {0}, occ {0}, oh_s {0};
1564 if (jcp.loop_order == loop_ngc)
1565 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
1566 oh_s, jcp.oh);
1567 else if (jcp.loop_order == loop_cgn)
1568 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
1569 oh_s, jcp.oh);
1570 else
1571 assert(!"unsupported loop order");
1572 while (start < end) {
1573
1574 int ocb = occ * jcp.nb_oc_blocking;
1575 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
1576 int g_ic = g * jcp.ch_block * jcp.ic;
1577 int work_rem = end - start;
1578 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
1579
1580 auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
1581 auto src_w = src + src_d.blk_off(n, g_ic);
1582 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
1583 auto bias_w = jcp.with_bias
1584 ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
1585 : nullptr;
1586 int32_t *compensation_w
1587 = (jcp.signed_input) ? compensation + g_oc : nullptr;
1588
1589 auto scales = &oscales[jcp.is_oc_scale * g_oc];
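            /* For each output row determine which kh taps contribute
               (kh_lo, kh_len) and the input row the kernel starts from
               (ih_max); the dilated case uses div_up to account for the
               'holes' in the filter. */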
1590 for (int oj = oh_s; oj < oh_e; oj++) {
1591 int ih_max = 0, kh_lo = 0, kh_len = 0;
1592 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
1593 /* dilation */
1594 int dilate_h = jcp.dilate_h + 1;
1595 // Note: use div_up to account for "holes" in filter
1596 int o_t_overflow = div_up(
1597 max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
1598 dilate_h);
1599 int o_b_overflow
1600 = div_up(max(0,
1601 (jcp.kh - 1) * dilate_h + 1
1602 - jcp.oh + oj - jcp.b_pad),
1603 dilate_h);
1604 kh_len = jcp.kh - o_t_overflow - o_b_overflow;
1605 kh_lo = o_b_overflow;
1606 ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
1607 } else {
1608 int o_t_overflow = max(
1609 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
1610 int o_b_overflow = max(0,
1611 ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
1612 / jcp.stride_h);
1613 int overflow_kh_hi = jcp.kh - 1
1614 - modulo(jcp.oh + jcp.b_pad - (oj + 1),
1615 jcp.stride_h);
1616 int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
1617
1618 kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
1619 + 1 - o_t_overflow - o_b_overflow;
1620 kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
1621 ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
1622 }
1623
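                // Skip the kh_lo leading filter rows when no compensation is
                // needed (unsigned src, no src zero point); otherwise the
                // kernel is given the full filter.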
1624 int wei_stride = (!jcp.signed_input && !jcp.src_zero_point)
1625 ? kh_lo * wht_kh_stride
1626 : 0;
1627 p.src = src_w + ih_max * src_h_stride;
1628 p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
1629 p.filt = wht_w + wei_stride;
1630 p.bias = bias_w;
1631 p.compensation = compensation_w;
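                // t_overflow / b_overflow: filter rows falling into the top /
                // bottom padding for this output row; kh_padding is the
                // number of rows actually applied.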
1632 p.t_overflow = jcp.dilate_h > 0
1633 ? jcp.kh - kh_len - kh_lo
1634 : max(0,
1635 jcp.kh
1636 - (kh_lo
1637 + max(0, kh_len - 1)
1638 * jcp.stride_h
1639 + 1));
1640 p.b_overflow = kh_lo;
1641 p.kh_padding = kh_len;
1642 p.scales = scales;
1643 p.dst_scale = dst_scales;
1644 p.oc_blocks = jcp.is_depthwise ? g : ocb;
1645 p.post_ops_binary_rhs_arg_vec
1646 = post_ops_binary_rhs_arg_vec.data();
1647 p.oc_l_off = g_oc;
1648 p.zp_compensation
1649 = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
1650 p.zp_src_pad_str_compensation = jcp.src_zero_point
1651 ? zp_src_comp_scratch + g_oc
1652 : nullptr;
1653 p.src_zero_point = zp_src;
1654 p.dst_zero_point = zp_dst;
1655 p.dst_orig = dst;
1656
1657 (*kernel_)(&p);
1658 }
1659 if (jcp.loop_order == loop_ngc)
1660 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
1661 oc_chunks, oh_s, jcp.oh);
1662 else if (jcp.loop_order == loop_cgn)
1663 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
1664 jcp.mb, oh_s, jcp.oh);
1665 else
1666 assert(!"unsupported loop order");
1667 }
1668 });
1669 return status::success;
1670}
1671
1672status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_3d(
1673 const exec_ctx_t &ctx) const {
1674 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1675 const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
1676 const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1677 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1678 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
1679 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
1680
1681 const auto &jcp = pd()->jcp_;
1682
1683 const memory_desc_wrapper src_d(pd()->src_md());
1684 const memory_desc_wrapper dst_d(pd()->dst_md());
1685 const memory_desc_wrapper weights_d(pd()->weights_md(0));
1686 const memory_desc_wrapper bias_d(pd()->weights_md(1));
1687 const auto post_ops_binary_rhs_arg_vec
1688 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
1689
1690 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1691
1692 auto scratchpad = ctx.get_scratchpad_grantor();
1693 int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);
1694
1695 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1696 zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
1697 weights_d, weights, zp_src, zp_src_comp_scratch,
1698 zp_src_pad_comp_kernel_.get());
1699
1700 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1701 int nb_groups = jcp.nb_ch;
1702
1703 size_t src_d_stride = src_d.blk_off(0, 0, 1);
1704 size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
1705 size_t dst_d_stride = dst_d.blk_off(0, 0, 1);
1706 size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
1707 size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
1708 size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
1709
1710 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
1711 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
1712 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
1713
    const float *oscales
            = adjust_oscales(scratchpad, src_scales, wei_scales);
1716
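    // Same layout as in the 2D case: compensation for signed input is
    // appended past the weights data.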
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1718 auto w = const_cast<int8_t *>(weights);
1719 int32_t *compensation = (jcp.signed_input)
1720 ? reinterpret_cast<int32_t *>(&w[offset])
1721 : nullptr;
1722 const int32_t *zp_compensation = jcp.src_zero_point
1723 ? get_src_zp_comp_from_wei(
1724 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1725 : nullptr;
1726
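    // Same parallelization as the 2D case, with the output depth (od) added
    // to the work amount.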
1727 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1728 int start {0}, end {0};
1729 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh;
1730 balance211(work_amount, nthr, ithr, start, end);
1731
1732 auto p = jit_deconv_call_s();
1733
        /* loop order: ngc or cgn */
1735 int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0};
1736 if (jcp.loop_order == loop_ngc)
1737 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
1738 od_s, jcp.od, oh_s, jcp.oh);
1739 else if (jcp.loop_order == loop_cgn)
1740 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
1741 od_s, jcp.od, oh_s, jcp.oh);
1742 else
1743 assert(!"unsupported loop order");
1744 while (start < end) {
1745
1746 int ocb = occ * jcp.nb_oc_blocking;
1747 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
1748 int g_ic = g * jcp.ch_block * jcp.ic;
1749 int work_rem = end - start;
1750 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
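            // Compute the contributing kernel depth range
            // [kd_lo, kd_lo + kd_len) and the starting source depth
            // (input_d_s), mirroring the per-row kh computation below.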
            int input_d_s = 0, kd_len = 0, kd_lo = 0;
1753
1754 if (jcp.dilate_d != 0 && jcp.stride_d == 1) {
1755 /* dilation */
1756 int dilate_d = jcp.dilate_d + 1;
1757 // Note: use div_up to account for "holes" in filter
                int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad),
                        dilate_d);
                int d_back_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.od
                                                 + od_s - jcp.back_pad),
                                dilate_d);
1766 kd_len = jcp.kd - d_t_overflow - d_back_overflow;
1767 kd_lo = d_back_overflow;
1768 input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d;
1769 } else {
1770 int d_t_overflow = max(
1771 0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d);
1772 int d_back_overflow = max(0,
1773 ((od_s + jcp.kd) - (jcp.od + jcp.back_pad))
1774 / jcp.stride_d);
1775 int overflow_kd_hi = jcp.kd - 1
1776 - modulo(jcp.od + jcp.back_pad - (od_s + 1),
1777 jcp.stride_d);
1778 int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d;
1779
1780 kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
1781 - d_t_overflow - d_back_overflow;
1782 kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d;
1783 input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d;
1784 }
1785
1786 auto dst_w = dst
1787 + dst_dt_size
1788 * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride);
1789 auto src_w
1790 = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride;
1791 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
1792 + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo)
1793 * wht_kd_stride;
1794 auto bias_w = jcp.with_bias
1795 ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
1796 : nullptr;
1797 int32_t *compensation_w
1798 = (jcp.signed_input) ? compensation + g_oc : nullptr;
1799
1800 auto scales = &oscales[jcp.is_oc_scale * g_oc];
1801
1802 for (int oj = oh_s; oj < oh_e; oj++) {
1803 int ih_max = 0, kh_lo = 0, kh_len = 0;
1804 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
1805 /* dilation */
1806 int dilate_h = jcp.dilate_h + 1;
1807 // Note: use div_up to account for "holes" in filter
1808 int o_t_overflow = div_up(
1809 max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
1810 dilate_h);
1811 int o_b_overflow
1812 = div_up(max(0,
1813 (jcp.kh - 1) * dilate_h + 1
1814 - jcp.oh + oj - jcp.b_pad),
1815 dilate_h);
1816 kh_len = jcp.kh - o_t_overflow - o_b_overflow;
1817 kh_lo = o_b_overflow;
1818 ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
1819 } else {
1820 int o_t_overflow = max(
1821 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
1822 int o_b_overflow = max(0,
1823 ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
1824 / jcp.stride_h);
1825 int overflow_kh_hi = jcp.kh - 1
1826 - modulo(jcp.oh + jcp.b_pad - (oj + 1),
1827 jcp.stride_h);
1828 int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
1829
1830 kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
1831 + 1 - o_t_overflow - o_b_overflow;
1832 kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
1833 ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
1834 }
1835
1836 int wei_stride = (!jcp.signed_input && !jcp.src_zero_point)
1837 ? kh_lo * wht_kh_stride
1838 : 0;
1839 p.src = src_w + ih_max * src_h_stride;
1840 p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
1841 p.filt = wht_w + wei_stride;
1842 p.bias = bias_w;
1843 p.compensation = compensation_w;
1844 /* Note: Currently this kernel doesn't support dilations and
1845 strides together */
1846 p.t_overflow = jcp.dilate_h > 0
1847 ? jcp.kh - kh_len - kh_lo
1848 : max(0,
1849 jcp.kh
1850 - (kh_lo
1851 + max(0, kh_len - 1)
1852 * jcp.stride_h
1853 + 1));
1854 p.b_overflow = kh_lo;
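                // f_overflow / back_overflow: depth analogues of t_overflow /
                // b_overflow (front / back padding); kd_padding is the number
                // of kernel depth elements applied.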
1855 p.f_overflow = jcp.dilate_d > 0
1856 ? jcp.kd - kd_len - kd_lo
1857 : max(0,
1858 jcp.kd
1859 - (kd_lo
1860 + max(0, kd_len - 1)
1861 * jcp.stride_d
1862 + 1));
1863 p.back_overflow = kd_lo;
1864 p.kh_padding = kh_len;
1865 p.kd_padding = kd_len;
1866 p.scales = scales;
1867 p.dst_scale = dst_scales;
1868 p.oc_blocks = jcp.is_depthwise ? g : ocb;
1869 p.post_ops_binary_rhs_arg_vec
1870 = post_ops_binary_rhs_arg_vec.data();
1871 p.oc_l_off = g_oc;
1872 p.zp_compensation
1873 = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
1874 p.zp_src_pad_str_compensation = jcp.src_zero_point
1875 ? zp_src_comp_scratch + g_oc
1876 : nullptr;
1877 p.src_zero_point = zp_src;
1878 p.dst_zero_point = zp_dst;
1879 p.dst_orig = dst;
1880 (*kernel_)(&p);
1881 }
1882 if (jcp.loop_order == loop_ngc)
1883 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
1884 oc_chunks, od_s, jcp.od, oh_s, jcp.oh);
1885 else if (jcp.loop_order == loop_cgn)
1886 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
1887 jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
1888 else
1889 assert(!"unsupported loop order");
1890 }
1891 });
1892 return status::success;
1893}
1894
1895template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Zmm>;
1896template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Ymm>;
1897template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Xmm>;
1898
1899} // namespace x64
1900} // namespace cpu
1901} // namespace impl
1902} // namespace dnnl
1903