1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/dnnl_thread.hpp"
18#include "common/nstl.hpp"
19#include "common/utils.hpp"
20
21#include "cpu/cpu_primitive.hpp"
22#include "cpu/zero_point_utils.hpp"
23
24#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
25#include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp"
26#include "cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp"
27
28#define GET_OFF(field) offsetof(jit_deconv_call_s, field)
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35using namespace dnnl::impl::status;
36using namespace dnnl::impl::memory_tracking::names;
37using namespace dnnl::impl::utils;
38using namespace Xbyak;
39
40using namespace nstl;
41
42#define wht_blk_off(d, g, ...) \
43 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
44 : (d).blk_off(__VA_ARGS__))
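// wht_blk_off() computes a block offset into the weights tensor; the group
// index g is used only when the weights actually carry a groups dimension.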
45
46template <cpu_isa_t isa>
47status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(
48 jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
49 memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
50 const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr,
51 int nthreads) {
52 const memory_desc_wrapper src_d(&src_md);
53 const memory_desc_wrapper dst_d(&dst_md);
54 const memory_desc_wrapper weights_d(&weights_md);
55 const memory_desc_wrapper bias_d(&bias_md);
56
57 if (!(mayiuse(isa)
58 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
59 && weights_d.data_type() == data_type::s8
60 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
61 data_type::s8, data_type::u8)))
62 return status::unimplemented;
63
64 jcp = zero<decltype(jcp)>();
65 jcp.nthr = nthreads;
66
67 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
68 jcp.signed_input = src_d.data_type() == data_type::s8;
69 const int ndims = jcp.ndims = dst_d.ndims();
70 const bool is_1d = ndims == 3;
71 const bool is_2d = ndims == 4;
72 const bool is_3d = ndims == 5;
73 const bool is_avx2 = isa == avx2;
74
75 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
76 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
77 jcp.ic = src_d.dims()[1] / jcp.ngroups;
78 jcp.id = is_3d ? src_d.dims()[2] : 1;
79 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
80 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
81 jcp.is_depthwise = true && with_groups
82 && utils::everyone_is(
83 1, jcp.ic_without_padding, jcp.oc_without_padding);
84 jcp.has_vnni = mayiuse(avx2_vnni);
85
    /* TODO: future work, on hold until a specialized depthwise kernel is
     * implemented. */
88 if (jcp.is_depthwise && (jcp.signed_input || is_3d))
89 return status::unimplemented;
90
91 if (!zero_points_valid(&attr)) return status::unimplemented;
92 jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC);
93 jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST);
94 jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC);
95
96 format_tag_t dat_tag = utils::pick(
97 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
98
99 if (src_d.format_kind() == format_kind::any) {
100 CHECK(memory_desc_init_by_tag(src_md, dat_tag));
101 jcp.src_tag = dat_tag;
102 } else {
103 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
104 }
105 if (jcp.src_tag != dat_tag) return status::unimplemented;
106
107 if (dst_d.format_kind() == format_kind::any) {
108 CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
109 jcp.dst_tag = dat_tag;
110 } else {
111 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
112 }
113 if (jcp.dst_tag != dat_tag) return status::unimplemented;
114
115 auto set_or_check_wei_format = [&]() {
116 using namespace format_tag;
117 format_tag_t wei_tag;
118 if (jcp.ic_block == 8 || jcp.ch_block == 8) {
119 if (is_1d) {
120 wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i
121 : OIw2i8o4i;
122 } else if (is_2d) {
123 wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
124 : OIhw2i8o4i;
125 } else {
126 wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
127 }
128 } else {
129 if (is_avx2) {
130 assert(with_groups && jcp.ic_block == 4);
131 wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
132 } else {
133 if (is_1d) {
134 wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i
135 : OIw4o4i;
136 } else if (is_2d) {
137 wei_tag = with_groups
138 ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
139 : OIhw4o4i;
140 } else {
141 wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
142 }
143 }
144 }
145
146 memory_desc_t want_wei_md = weights_md;
147 memory_desc_init_by_tag(want_wei_md, wei_tag);
148 if (jcp.signed_input && !jcp.is_depthwise) {
149 want_wei_md.extra.flags = 0
150 | memory_extra_flags::compensation_conv_s8s8
151 | memory_extra_flags::scale_adjust;
152 want_wei_md.extra.compensation_mask = (1 << 0)
153 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
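            /* Without VNNI, s8 x s8 products are emulated with vpmaddubsw, so
             * the weights are pre-scaled by 0.5 at reorder time to avoid
             * intermediate saturation; the factor is undone later through
             * wei_adj_scale (see adjust_oscales()). */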
154 want_wei_md.extra.scale_adjust = jcp.has_vnni ? 1.f : 0.5f;
155 }
156 if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups);
157
158 if (weights_md.format_kind == format_kind::any) {
159 weights_md = want_wei_md;
160 return true;
161 }
162
163 return weights_md == want_wei_md;
164 };
165
166 jcp.with_bias = with_bias;
167 if (jcp.with_bias) {
168 if (bias_d.format_kind() == format_kind::any)
169 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
170 }
171
172 jcp.prop_kind = cd.prop_kind;
173 jcp.mb = src_d.dims()[0];
174 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
175 jcp.iw = src_d.dims()[ndims - 1];
176 jcp.od = is_3d ? dst_d.dims()[2] : 1;
177 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
178 jcp.ow = dst_d.dims()[ndims - 1];
179 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
180 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
181 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
182 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
183 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
184 jcp.l_pad = cd.padding[0][ndims - 3];
185 jcp.stride_d = is_3d ? cd.strides[0] : 1;
186 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
187 jcp.stride_w = cd.strides[ndims - 3];
188
189 if (jcp.is_depthwise) {
190 jcp.ch_block = is_avx2 ? 8 : 4;
191 jcp.oc_block = 1;
192 jcp.ic_block = 1;
193 } else {
194 jcp.ch_block = 1;
195 jcp.oc_block = is_avx2 ? 8 : 4;
196 jcp.ic_block = is_avx2 ? 8 : 4;
197
198 if (jcp.ngroups == 1) {
199 jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
200 jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
201 } else if (jcp.ngroups != 1
202 && ((jcp.ic % jcp.ic_block != 0)
203 || (jcp.oc % jcp.oc_block != 0))) {
            /* oneDNN doesn't pad channels for grouped convolution.
             * When the number of channels per group is not a multiple of 8 on
             * avx2:
             * - use the 4-wide Xmm blocking when it is a multiple of 4,
             * - otherwise return unimplemented. */
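            /* For example, on avx2 with 12 channels per group the 8-wide
             * blocking does not divide evenly, so the kernel falls back to the
             * 4-wide Xmm blocking; 6 channels per group would still fail the
             * divisibility check below and be rejected. */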
208 jcp.oc_block = jcp.ic_block = 4;
209 }
210 if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
211 return status::unimplemented;
212 }
213
214 if (!set_or_check_wei_format()) return status::unimplemented;
215
216 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
217 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
218 jcp.dilate_w = cd.dilates[ndims - 3];
219
220 if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1)
221 || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
222 || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
223 return status::unimplemented;
224
225 const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
226 const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
227 const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
228 jcp.r_pad = calculate_end_padding(
229 jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw);
230 jcp.b_pad = calculate_end_padding(
231 jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh);
232 jcp.back_pad = calculate_end_padding(
233 jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd);
234 const bool kernel_outside_src = false || ext_kw <= jcp.l_pad
235 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
236 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
237 if (kernel_outside_src) return status::unimplemented;
238
239 CHECK(attr.set_default_formats(&dst_md));
240 if (!post_ops_ok(jcp, dst_d, attr)) return status::unimplemented;
241
242 const auto &p = attr.post_ops_;
243 const int eltwise_ind = p.find(primitive_kind::eltwise);
244 jcp.with_eltwise = eltwise_ind != -1;
245 if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
246
247 const int binary_ind = p.find(primitive_kind::binary);
248 jcp.with_binary = binary_ind != -1;
249
250 const int sum_ind = p.find(primitive_kind::sum);
251 jcp.with_sum = sum_ind != -1;
252
253 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
254 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
255 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
256 const int wei_scales_per_oc = 1 << (int)with_groups;
257 jcp.is_oc_scale = wei_scales.mask_ == wei_scales_per_oc;
258 jcp.dst_scale = !dst_scales.has_default_values();
259
260 jcp.post_ops = p;
261
262 // only common and per-oc-channel scales are supported
263 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_scales_per_oc)
264 && utils::everyone_is(src_scales.mask_, dst_scales.mask_, 0);
265 if (!scales_ok) return status::unimplemented;
266
267 jcp.dst_dt = dst_d.data_type();
268 jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
269 jcp.typesize_bia
270 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
271 jcp.typesize_in = types::data_type_size(src_d.data_type());
272 jcp.typesize_out = types::data_type_size(dst_d.data_type());
273
274 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
275 jcp.nb_oc = jcp.oc / jcp.oc_block;
276 jcp.nb_ic = jcp.ic / jcp.ic_block;
277
278 /* kernel blocking params */
279 const int regs = jcp.has_vnni ? 14 : 12;
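    // Number of Vmm registers available for accumulators and broadcast
    // inputs; the non-VNNI multiply-add path needs extra temporaries
    // (vmm_one_, vmm_tmp_), which is why only 12 remain without VNNI.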
280
281 jcp.nb_ch_blocking = 1;
282 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
283 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
284 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
285 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
286 break;
287
288 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
289 const int l_overflow = max(
290 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
291
292 if (jcp.ow < jcp.ur_w) {
293 jcp.ur_w = jcp.ow;
294 jcp.ur_w_tail = 0;
295 } else {
296 for (; jcp.ur_w >= 1; jcp.ur_w--) {
            /* ur_w should be a multiple of stride_w to simplify the logic of
               get_ow_start and get_ow_end */
299 const bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
300
            /* boundary conditions:
               these conditions ensure that all elements close to a boundary
               are computed in a single call of the compute loop */
304 const bool left_boundary_covered
305 = jcp.ur_w >= l_overflow * jcp.stride_w;
306 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
307 const int r_overflow_no_tail = max(0,
308 ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)
309 - jcp.ur_w_tail)
310 / jcp.stride_w);
311 const bool right_boundary_covered
312 = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
313
314 if (is_multiple_of_stride && left_boundary_covered
315 && right_boundary_covered)
316 break;
317 else if (jcp.ur_w == 1)
                /* The boundary conditions above also keep the calls to
                   icb_loop simple: if they cannot be satisfied, special cases
                   would be needed so that each compute-loop iteration uses the
                   correct l_overflow/r_overflow values near the boundaries.
                   To keep the code simple, return unimplemented for the
                   extreme case where no suitable ur_w can be found.
                 */
328 return status::unimplemented;
329 }
330 }
331
332 jcp.wei_adj_scale
333 = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
334 ? weights_d.extra().scale_adjust
335 : 1.f;
336
337 jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
338 return status::success;
339}
340
341template <cpu_isa_t isa>
342jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::jit_uni_x8s8s32x_deconv_fwd_kernel(
343 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
344 const memory_desc_wrapper &dst_d)
345 : kernel_(nullptr) {
346
347 const int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
348 switch (ch_block) {
349 case 8:
350 if (isa == avx2) {
351 kernel_ = utils::make_unique<
352 _jit_avx2_x8s8s32x_deconv_fwd_kernel>(
353 ajcp, attr, dst_d);
354 return;
355 } else
356 assert(!"invalid channel blocking for current ISA");
357 case 4:
358 kernel_ = utils::make_unique<
359 _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Xbyak::Xmm>>(
360 ajcp, attr, dst_d);
361 return;
362 default: assert(!"invalid channel blocking");
363 }
364}
365
366template <cpu_isa_t isa>
367jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::~jit_uni_x8s8s32x_deconv_fwd_kernel()
368 = default;
369
370template <cpu_isa_t isa>
371void jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
372 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
373 const primitive_attr_t &attr) {
374 const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
375 const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups;
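    // Book at least one full vector worth of scales (8 floats) so the kernel
    // can always use full-width loads, even when a single common scale is
    // broadcast (see adjust_oscales()).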
376 dim_t count = nstl::max<dim_t>(scales_count, 8);
377 scratchpad.book<float>(key_conv_adjusted_scales, count);
378
379 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) {
380 const auto zp_pad_comp_size
381 = static_cast<size_t>(jcp.oc_without_padding) * jcp.ngroups
382 * jcp.kd * jcp.kh * jcp.kw;
383 scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size);
384 }
385}
386
387template <cpu_isa_t isa>
388bool jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::post_ops_ok(jit_conv_conf_t &jcp,
389 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
390 using namespace injector;
391
392 return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary},
393 attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/,
394 false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/,
395 {broadcasting_strategy_t::per_oc,
396 broadcasting_strategy_t::scalar}));
397}
398
399template <cpu_isa_t isa, typename Vmm>
400_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
401 Vmm>::_jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
402 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d)
403 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
404 , jcp_(ajcp)
405 , postops_injector_(nullptr)
406 , ker_max_regs_(jcp_.has_vnni ? 14 : 12) {
407
408 if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) {
409 const std::size_t tail_size = get_tail_size();
410
411 static constexpr bool preserve_gpr = true;
412 static constexpr bool preserve_vmm = true;
413 static constexpr bool use_exact_tail_scalar_bcast = false;
414 static constexpr size_t vmm_helper_idx = 15;
415
416 const binary_injector::rhs_arg_static_params_t rhs_sp {vmm_helper_idx,
417 this->r14, this->r15, this->r13, preserve_gpr, preserve_vmm,
418 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d,
419 tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast};
420 const binary_injector::static_params_t bsp {this->param1_, rhs_sp};
421
422 postops_injector_ = utils::make_unique<
423 injector::jit_uni_postops_injector_t<isa, Vmm>>(
424 this, jcp_.post_ops, bsp);
425 }
426}
427
428template <cpu_isa_t isa, typename Vmm>
429_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
430 Vmm>::~_jit_uni_x8s8s32x_deconv_fwd_kernel()
431 = default;
432
433template <cpu_isa_t isa, typename Vmm>
434Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_out(
435 int i_ur, int i_oc) const {
436 const int idx = i_ur * jcp_.nb_oc_blocking + i_oc;
437 assert(idx < ker_max_regs_);
438 /* remap the reg indices to avoid using xmm0 in eltwise injector */
439 return Vmm(15 - idx);
440}
441
442template <cpu_isa_t isa, typename Vmm>
443Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_inp(
444 int i_ic, int nb_x_blocking) const {
445 const int idx = i_ic + nb_x_blocking * jcp_.ur_w;
446 assert(idx < ker_max_regs_);
447 return Vmm(15 - idx);
448}
449
450template <cpu_isa_t isa, typename Vmm>
451int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_start(
452 int ki, int l_overflow) const noexcept {
453 int res = (jcp_.ow - 1 + jcp_.r_pad) % jcp_.stride_w
454 + l_overflow * jcp_.stride_w
455 - (jcp_.kw - 1 - ki) * (jcp_.dilate_w + 1);
456 while (res < 0)
457 res += jcp_.stride_w;
458 return res;
459}
460
461template <cpu_isa_t isa, typename Vmm>
462int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_end(
463 int ur_w, int ki, int r_overflow) const noexcept {
464 if (utils::one_of(ur_w, jcp_.ow, jcp_.ur_w_tail))
465 ur_w += nstl::min(0, jcp_.r_pad); // remove negative padding
466 int res = (ur_w - 1 + jcp_.l_pad) % jcp_.stride_w
467 + r_overflow * jcp_.stride_w - ki * (jcp_.dilate_w + 1);
468 while (res < 0)
469 res += jcp_.stride_w;
470 return ur_w - res;
471}
472
473template <cpu_isa_t isa, typename Vmm>
474int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_blocking_size() const
475 noexcept {
476 return jcp_.is_depthwise ? jcp_.ch_block : jcp_.oc_block;
477}
478
479template <cpu_isa_t isa, typename Vmm>
480int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_tail_size() const
481 noexcept {
482 return jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
483 : jcp_.oc_without_padding % jcp_.oc_block;
484}
485
486template <cpu_isa_t isa, typename Vmm>
487void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute(
488 const Vmm vreg_acc, const Vmm vreg_wei, const Vmm vreg_src) {
489
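    // Three multiply-accumulate flavors:
    // - VNNI: vpdpbusd fuses the u8 x s8 multiply and the s32 accumulation;
    // - depthwise: operands were widened to s32 on load, so a plain 32-bit
    //   multiply-add is used;
    // - generic: vpmaddubsw gives s16 pair-products, and vpmaddwd with
    //   vmm_one_ (packed 16-bit ones) reduces them to s32 before adding.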
490 if (jcp_.has_vnni) {
491 vpdpbusd(vreg_acc, vreg_src, vreg_wei, Xbyak::VexEncoding);
492 } else if (jcp_.is_depthwise) {
493 uni_vmovups(vmm_tmp_, vreg_src);
494 uni_vpmulld(vmm_tmp_, vmm_tmp_, vreg_wei);
495 uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
496 } else {
497 uni_vpmaddubsw(vmm_tmp_, vreg_src, vreg_wei);
498 uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_);
499 uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
500 }
501}
502
503template <cpu_isa_t isa, typename Vmm>
504std::function<Vmm()> _jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
505 Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept {
506
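    // Cycles over the Vmm registers normally used for broadcast inputs; while
    // the zero-point padding compensation runs they are free, so they can hold
    // compensation values without clobbering the vmm_out accumulators.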
507 const int start_vmm_idx = vmm_inp(ur_w - 1, jcp_.nb_oc_blocking).getIdx();
508 const int end_vmm_idx = vmm_inp(0, jcp_.nb_oc_blocking).getIdx() + 1;
509 int current_vmm_idx = start_vmm_idx;
510
511 return [=]() mutable {
512 const Vmm vmm {static_cast<int>(current_vmm_idx++)};
513
514 if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx;
515
516 return vmm;
517 };
518}
519
520template <cpu_isa_t isa, typename Vmm>
521void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_zp_src_pad_str_comp(
522 int ur_w, int l_overflow, int r_overflow, bool h_padded) {
523 Xbyak::Label end_zp_pad, no_tail;
524
    // Applied once per icb loop. The zp source stride/padding compensation is
    // calculated as: zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
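    // The per-(kd, kh, kw, oc) compensation values are precomputed from the
    // weights into scratchpad memory (key_deconv_zp) and are added below only
    // to output positions that fall into padded areas or stride 'holes'.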
527 cmp(reg_icb_, jcp_.nb_ic);
528 jne(end_zp_pad, T_NEAR);
529
530 if (jcp_.ngroups % jcp_.ch_block
531 || jcp_.oc_without_padding % jcp_.oc_block) {
532 if (jcp_.is_depthwise)
533 cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
534 else
535 cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
536 jne(no_tail, T_NEAR);
537
538 append_zp_src_pad_str_comp(
539 ur_w, l_overflow, r_overflow, h_padded, true /*last_oc_block*/);
540 jmp(end_zp_pad, T_NEAR);
541 }
542
543 L(no_tail);
544 append_zp_src_pad_str_comp(
545 ur_w, l_overflow, r_overflow, h_padded, false /*last_oc_block*/);
546
547 L(end_zp_pad);
548}
549
550template <cpu_isa_t isa, typename Vmm>
551void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::append_zp_src_pad_str_comp(
552 int ur_w, int l_overflow, int r_overflow, bool h_padded,
553 bool last_oc_block) {
554
555 const auto &reg_zp_src_pad_comp = reg_scratch_;
556 const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w);
557 bool base_comp_addr_loaded = false;
558
559 const auto load_base_zp_src_pad_comp_addr = [&]() {
560 if (!base_comp_addr_loaded) {
561 if (jcp_.ndims == 5) mov(reg_scratch_preserved_, reg_scratch_);
562
563 if (jcp_.ndims > 3)
564 mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr_);
565 else
566 mov(reg_zp_src_pad_comp,
567 qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);
568
569 base_comp_addr_loaded = true;
570 }
571 };
572
573 const auto load_zp_src_pad_comp = [&](const Vmm zp_pad_comp_vmm,
574 const Xbyak::Address &comp_addr,
575 const int ocb) {
576 const bool is_tail = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
577
578 if (is_tail)
579 load_data(data_type::s32, zp_pad_comp_vmm, comp_addr,
580 get_tail_size());
581 else
582 uni_vmovups(zp_pad_comp_vmm, comp_addr);
583 };
584
585 const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) {
586 const auto kw_offset = it_kw * jcp_.oc_without_padding * jcp_.ngroups;
587 const auto oc_offset = ocb * jcp_.oc_block;
588
589 return (kw_offset + oc_offset) * sizeof(int32_t);
590 };
591
592 for (int it_kw = 0; it_kw < jcp_.kw; ++it_kw) {
593 const int ow_start = get_ow_start(it_kw, l_overflow);
594 const int ow_end = get_ow_end(ur_w, it_kw, r_overflow);
595
596 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
597 Vmm zp_src_comp_pad_vmm; // will be assigned later
598 bool ocb_zp_loaded = false;
599
600 const auto zp_src_comp_pad_off
601 = get_zp_src_comp_pad_off(it_kw, ocb);
602
603 for (int it_ow = 0; it_ow < ur_w; ++it_ow) {
604
605 const bool inside_padded_area = h_padded
606 || !(it_ow >= ow_start && it_ow < ow_end
607 && ((it_ow + jcp_.l_pad - it_kw) % jcp_.stride_w
608 == 0));
609
610 if (inside_padded_area) {
611 load_base_zp_src_pad_comp_addr();
612
613 if (!ocb_zp_loaded) {
614 zp_src_comp_pad_vmm = get_next_comp_vmm();
615 const auto comp_addr = ptr[reg_zp_src_pad_comp
616 + zp_src_comp_pad_off];
617 load_zp_src_pad_comp(
618 zp_src_comp_pad_vmm, comp_addr, ocb);
619 ocb_zp_loaded = true;
620 }
621
622 const auto vmm_dst = vmm_out(it_ow, ocb);
623 uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm);
624 }
625 }
626 }
627 }
628
629 if (jcp_.ndims > 3) {
630 if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr();
631
632 const auto kh_offset = jcp_.kw * jcp_.oc_without_padding * jcp_.ngroups
633 * sizeof(int32_t);
634
635 add(reg_zp_src_pad_comp, kh_offset);
636 mov(zp_src_pad_comp_addr_, reg_zp_src_pad_comp);
637 }
638
639 if (jcp_.ndims == 5 && base_comp_addr_loaded)
640 mov(reg_scratch_, reg_scratch_preserved_);
641}
642
643template <cpu_isa_t isa, typename Vmm>
644void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute_ker(int ur_w,
645 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
646 bool h_padded) {
647
648 const bool signed_input_or_src_zp
649 = (jcp_.signed_input || jcp_.src_zero_point);
650
651 const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
652 const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp_.stride_w;
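    // With signed input or a source zero point every output column of the
    // block is visited so padded positions / stride 'holes' still receive
    // their compensation; otherwise only every stride_w-th column does work.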
653
654 const auto src_offset = [=](int oj, int icb, int ki) {
655 return jcp_.typesize_in
656 * (((oj + jcp_.l_pad - ki * (jcp_.dilate_w + 1))
657 / jcp_.stride_w)
658 * jcp_.ngroups * jcp_.ic_without_padding
659 + icb * 4);
660 };
661
662 const auto kernel_offset = [=](int ocb, int icb, int ki) {
663 return jcp_.typesize_in
664 * ((ocb * jcp_.nb_ic * jcp_.kd * jcp_.kh * jcp_.kw + ki)
665 * ch_block_all
666 + icb * jcp_.oc_block * IC_SUB_STEP);
667 };
668
669 for (int ki = 0; ki < jcp_.kw; ki++) {
670
671 const int jj_start = get_ow_start(ki, l_overflow);
672 const int jj_end = get_ow_end(ur_w, ki, r_overflow);
673
674 const int _start = (signed_input_or_src_zp) ? 0 : jj_start;
675 const int _end = (signed_input_or_src_zp) ? ur_w : jj_end;
676
677 const int tail_size = jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
678 : jcp_.ic_without_padding % 4;
679 const int n_ic_blocks = jcp_.is_depthwise
680 ? 1
681 : (last_ic_block_flag != no_last_block ? div_up(
682 jcp_.ic_without_padding % jcp_.ic_block, 4)
683 : jcp_.ic_block / 4);
684
685 for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
686 if (h_padded) {
687 /* fill padded area with shifted values */
688 if (jcp_.signed_input) {
689 const Vmm inp = vmm_inp(0, jcp_.nb_oc_blocking);
690 uni_vpxor(inp, inp, inp);
691 uni_vpsubb(inp, inp, vmm_shift_);
692 }
693 } else {
694
695 for (int jj = _start; jj < _end; jj += ur_w_stride) {
696
697 const int aux_src_off = src_offset(jj, icb1, ki);
698 const auto vmm_src = vmm_inp(jj, jcp_.nb_oc_blocking);
699
700 if (jj >= jj_start && jj < jj_end
701 && ((jj + jcp_.l_pad - ki) % jcp_.stride_w == 0)) {
702 if (jcp_.is_depthwise) {
703 if (tail_size != 0)
704 assert(jcp_.nb_oc_blocking == 1);
705 uni_vpxor(vmm_src, vmm_src, vmm_src);
706 const bool mask_flag
707 = last_ic_block_flag != no_last_block
708 && tail_size;
709 load_data(data_type::u8, vmm_src, aux_reg_src_,
710 aux_src_off,
711 mask_flag ? tail_size : jcp_.ch_block);
712 } else if ((last_ic_block_flag == last_sp_block)
713 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
714 const auto vmm_inp_tmp = Xmm(vmm_src.getIdx());
715 load_bytes(vmm_inp_tmp, aux_reg_src_, aux_src_off,
716 tail_size);
717 uni_vpbroadcastd(vmm_src, vmm_inp_tmp);
718 } else {
719 uni_vpbroadcastd(
720 vmm_src, ptr[aux_reg_src_ + aux_src_off]);
721 }
722 if (jcp_.signed_input)
723 uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
724 } else {
725 /* fill padded area with shifted values */
726 if (jcp_.signed_input) {
727 uni_vpxor(vmm_src, vmm_src, vmm_src);
728 uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
729 }
730 }
731 }
732 }
733 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
734 const int aux_filt_off = kernel_offset(ocb, icb1, ki);
735
736 if (_end - _start > 0) {
737 if (jcp_.is_depthwise) {
738 uni_vpmovsxbd(
739 vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
740 } else
741 uni_vmovups(
742 vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
743 }
744
745 for (int jj = _start; jj < _end; jj += ur_w_stride) {
746
747 const bool inside_padded_area = h_padded
748 || !(jj >= jj_start && jj < jj_end
749 && ((jj + jcp_.l_pad - ki) % jcp_.stride_w
750 == 0));
751 const auto vmm_dst = vmm_out(jj, ocb);
752 if (jcp_.signed_input || !inside_padded_area) {
753 const Vmm inp = vmm_inp(
754 h_padded ? 0 : jj, jcp_.nb_oc_blocking);
755 compute(vmm_dst, vmm_wei_, inp);
756 }
757 }
758 }
759 }
760 }
761
762 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
763 apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded);
764}
765
766template <cpu_isa_t isa, typename Vmm>
767void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::kh_loop(int ur_w,
768 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
769
770 const bool signed_input_or_src_zp
771 = (jcp_.signed_input || jcp_.src_zero_point);
772
773 const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
774 const int shift_src_ih = jcp_.typesize_in * (jcp_.dilate_h + 1) * jcp_.iw
775 * jcp_.ngroups * jcp_.ic_without_padding;
776 const int shift_src_id = jcp_.typesize_in * (jcp_.dilate_d + 1) * jcp_.ih
777 * jcp_.iw * jcp_.ngroups * jcp_.ic_without_padding;
778 const int stride_h = signed_input_or_src_zp ? 1 : jcp_.stride_h;
779 const int shift_filt_kh
780 = jcp_.typesize_in * jcp_.kw * ch_block_all * stride_h;
781 const int stride_d = signed_input_or_src_zp ? 1 : jcp_.stride_d;
782 const int shift_filt_kd
783 = jcp_.typesize_in * jcp_.kw * ch_block_all * jcp_.kh * stride_d;
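    // Note: the source pointer moves backwards (one dilated input row is
    // subtracted per filter row) while the filter pointer advances; with
    // signed input or a source zero point the filter is walked densely
    // (stride 1) so that the stride 'holes' can be compensated.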
784
785 Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop;
786 Label t_overflow_label, no_t_overflow_label, b_overflow_label,
787 no_b_overflow_label;
788 Label back_overflow_label, no_back_overflow_label, d_h_overflow_label,
789 front_overflow_label, no_front_overflow_label, d_h_overflow_label2;
790 if (jcp_.ndims == 5) {
791 mov(aux_reg_filt_d_, reg_filt_);
792 mov(aux_reg_src_d_, reg_src_);
793
794 if (signed_input_or_src_zp) {
795 mov(reg_ki_, ptr[param1_ + GET_OFF(back_overflow)]);
796 cmp(reg_ki_, 0);
797 je(no_back_overflow_label, T_NEAR);
798
799 L(back_overflow_label);
800 {
801 mov(aux_reg_filt_, aux_reg_filt_d_);
802 mov(reg_kh_, jcp_.kh);
803 L(d_h_overflow_label);
804 {
805 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
806 add(aux_reg_filt_, shift_filt_kh);
807 dec(reg_kh_);
808 jnz(d_h_overflow_label);
809 }
810
811 add(aux_reg_filt_d_, shift_filt_kd);
812 dec(reg_ki_);
813 jnz(back_overflow_label);
814 }
815 L(no_back_overflow_label);
816 }
817
818 mov(reg_ki_, ptr[param1_ + GET_OFF(kd_padding)]);
819
820 if ((signed_input_or_src_zp) || (jcp_.dilate_d >= jcp_.id)
821 || ((!signed_input_or_src_zp)
822 && ((min(jcp_.f_pad, jcp_.back_pad) < 0)
823 || ((jcp_.kd - 1) * (jcp_.dilate_d + 1)
824 < nstl::max(
825 jcp_.f_pad, jcp_.back_pad))))) {
826 cmp(reg_ki_, 0);
827 je(skip_kd_loop, T_NEAR);
828 }
829
830 L(kd_loop_label);
831 mov(aux_reg_src_, aux_reg_src_d_);
832 mov(aux_reg_filt_, aux_reg_filt_d_);
833 } else {
834 mov(aux_reg_src_, reg_src_);
835 mov(aux_reg_filt_, reg_filt_);
836 }
837
838 if (signed_input_or_src_zp && jcp_.ndims > 3) {
839 /* Weights are transposed, so first compute 'bottom' padding. */
840 mov(reg_overflow_, ptr[param1_ + GET_OFF(b_overflow)]);
841 cmp(reg_overflow_, 0);
842 je(no_b_overflow_label, T_NEAR);
843 L(b_overflow_label);
844 {
845 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
846
847 add(aux_reg_filt_, shift_filt_kh);
848 dec(reg_overflow_);
849 cmp(reg_overflow_, 0);
850 jg(b_overflow_label, T_NEAR);
851 }
852 L(no_b_overflow_label);
853 }
854
855 mov(reg_kh_, ptr[param1_ + GET_OFF(kh_padding)]);
856
857 if ((signed_input_or_src_zp) || (jcp_.dilate_h >= jcp_.ih)
858 || ((!signed_input_or_src_zp)
859 && ((min(jcp_.t_pad, jcp_.b_pad) < 0)
860 || ((jcp_.kh - 1) * (jcp_.dilate_h + 1)
861 < nstl::max(jcp_.t_pad, jcp_.b_pad))))) {
862 cmp(reg_kh_, 0);
863 je(skip_kh_loop, T_NEAR);
864 }
865
866 L(kh_loop_label);
867 {
868 compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
869 sub(aux_reg_src_, shift_src_ih);
870 add(aux_reg_filt_, shift_filt_kh);
871 dec(reg_kh_);
872
873 /* Insert weight compensation in stride 'holes' */
874 if (signed_input_or_src_zp && jcp_.stride_h > 1) {
875 Label kh_comp_loop;
876
877 cmp(reg_kh_, 0);
878 je(skip_kh_loop, T_NEAR);
879 mov(reg_comp_strides_, jcp_.stride_h - 1);
880 L(kh_comp_loop);
881 {
882 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
883 add(aux_reg_filt_, shift_filt_kh);
884 dec(reg_comp_strides_);
885 cmp(reg_comp_strides_, 0);
886 jg(kh_comp_loop, T_NEAR);
887 }
888 }
889 cmp(reg_kh_, 0);
890 jg(kh_loop_label, T_NEAR);
891 }
892 L(skip_kh_loop);
893 if (signed_input_or_src_zp && jcp_.ndims > 3) {
894 mov(reg_overflow_, ptr[param1_ + GET_OFF(t_overflow)]);
895 cmp(reg_overflow_, 0);
896 je(no_t_overflow_label, T_NEAR);
897 L(t_overflow_label);
898 {
899 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
900
901 add(aux_reg_filt_, shift_filt_kh);
902 dec(reg_overflow_);
903 cmp(reg_overflow_, 0);
904 jg(t_overflow_label, T_NEAR);
905 }
906 L(no_t_overflow_label);
907 }
908
909 if (jcp_.ndims == 5) {
910 sub(aux_reg_src_d_, shift_src_id);
911 add(aux_reg_filt_d_, shift_filt_kd);
912 dec(reg_ki_);
913
914 /* Insert weight compensation in stride 'holes' */
915 if (signed_input_or_src_zp && jcp_.stride_d > 1) {
916 Label kd_comp_loop, kd_kh_comp_loop;
917 cmp(reg_ki_, 0);
918 jz(skip_kd_loop, T_NEAR);
919 mov(reg_comp_strides_, jcp_.stride_d - 1);
920 L(kd_comp_loop);
921 mov(aux_reg_filt_, aux_reg_filt_d_);
922 mov(reg_kh_, jcp_.kh);
923 L(kd_kh_comp_loop);
924 {
925 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
926 add(aux_reg_filt_, shift_filt_kh);
927 dec(reg_kh_);
928 jnz(kd_kh_comp_loop, T_NEAR);
929 }
930 add(aux_reg_filt_d_, shift_filt_kd);
931 dec(reg_comp_strides_);
932 jnz(kd_comp_loop);
933 }
934
935 cmp(reg_ki_, 0);
936 jg(kd_loop_label, T_NEAR);
937 L(skip_kd_loop);
938 if (signed_input_or_src_zp) {
939 mov(reg_ki_, ptr[param1_ + GET_OFF(f_overflow)]);
940 cmp(reg_ki_, 0);
941 jz(no_front_overflow_label, T_NEAR);
942 L(front_overflow_label);
943 {
944 mov(aux_reg_filt_, aux_reg_filt_d_);
945 mov(reg_kh_, jcp_.kh);
946 L(d_h_overflow_label2);
947 {
948 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
949 add(aux_reg_filt_, shift_filt_kh);
950 dec(reg_kh_);
951 jnz(d_h_overflow_label2);
952 }
953 add(aux_reg_filt_d_, shift_filt_kd);
954 dec(reg_ki_);
955 jnz(front_overflow_label);
956 }
957 L(no_front_overflow_label);
958 }
959 }
960}
961
962template <cpu_isa_t isa, typename Vmm>
963void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) {
964 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
965 for (int ur = 0; ur < ur_w; ur++) {
966 const Vmm vmm = vmm_out(ur, ocb);
967 uni_vpxor(vmm, vmm, vmm);
968 }
969 }
970 if (jcp_.signed_input) {
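        // 0x80808080 broadcasts the byte value 128: subtracting it shifts s8
        // inputs into the u8 range expected by vpmaddubsw; the shift is later
        // undone via the precomputed compensation term.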
971 const auto xmm_shift = Xbyak::Xmm(vmm_shift_.getIdx());
972 mov(reg_scratch_, 0x80808080);
973 uni_vmovq(xmm_shift, reg_scratch_);
974 uni_vpbroadcastd(vmm_shift_, xmm_shift);
975 }
976}
977
978template <cpu_isa_t isa, typename Vmm>
979void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in,
980 const Vmm vmm_in, const Reg64 reg, int offset, int load_size) {
981
982 load_data(type_in, vmm_in, reg, offset, load_size);
983 if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in);
984}
985
986template <cpu_isa_t isa, typename Vmm>
987void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_postops(int ur_w,
988 bool last_oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) {
989 const auto sum_injector = [=]() {
990 if (p_sum_scale) { // post_op: sum
991 for (int k = 0; k < jcp_.nb_oc_blocking; k++) {
992 const bool mask_flag
993 = last_oc_block == 1 && k == jcp_.nb_oc_blocking - 1;
994 for (int j = 0; j < ur_w; j++) {
995 const int aux_output_offset = jcp_.typesize_out
996 * (k * jcp_.oc_block
997 + j * jcp_.oc_without_padding
998 * jcp_.ngroups);
999 cvt2ps(jcp_.dst_dt, vmm_prev_dst_, reg_dst_,
1000 aux_output_offset,
1001 mask_flag ? get_tail_size() : get_blocking_size());
1002 if (*p_sum_zp != 0) {
1003 uni_vbroadcastss(vmm_sum_zp_, ptr[reg_ptr_sum_zp_]);
1004 uni_vcvtdq2ps(vmm_sum_zp_, vmm_sum_zp_);
1005 uni_vsubps(vmm_prev_dst_, vmm_prev_dst_, vmm_sum_zp_);
1006 }
1007 const Vmm vmm = vmm_out(j, k);
1008 if (*p_sum_scale == 1.f)
1009 uni_vaddps(vmm, vmm, vmm_prev_dst_);
1010 else {
1011 uni_vbroadcastss(vmm_tmp_, ptr[reg_ptr_sum_scale_]);
1012 uni_vfmadd231ps(vmm, vmm_prev_dst_, vmm_tmp_);
1013 }
1014 }
1015 }
1016 }
1017 };
1018
1019 if (p_sum_scale)
1020 postops_injector_->set_lambda_injector(
1021 primitive_kind::sum, sum_injector);
1022
1023 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
1024 if (jcp_.with_binary) {
1025 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1026 const bool mask_flag
1027 = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1028 for (int ur = 0; ur < ur_w; ur++) {
1029 const int vmm_idx = vmm_out(ur, ocb).getIdx();
1030 const size_t aux_output_offset = jcp_.typesize_out
1031 * (ocb * jcp_.oc_block
1032 + ur * jcp_.oc_without_padding * jcp_.ngroups);
1033
1034 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst_);
1035 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
1036 vmm_idx, aux_output_offset);
1037 if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
1038 }
1039 }
1040 }
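    // vmm_out(ur, ocb) maps to register 15 - (ur * nb_oc_blocking + ocb), so
    // the live accumulators occupy the index range [16 - nb_oc_block * ur_w, 16).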
1041 const int nb_oc_block
1042 = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking;
1043 postops_injector_->compute_vector_range(
1044 16 - nb_oc_block * ur_w, 16, rhs_arg_params);
1045}
1046
1047template <cpu_isa_t isa, typename Vmm>
1048void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::store_output(
1049 int ur_w, bool last_oc_block) {
1050 mov(reg_bias_, ptr[param1_ + GET_OFF(bias)]);
1051 mov(reg_ptr_scales_, ptr[param1_ + GET_OFF(scales)]);
1052
1053 if (jcp_.signed_input)
1054 mov(reg_compensation_, ptr[param1_ + GET_OFF(compensation)]);
1055
1056 if (jcp_.src_zero_point) {
1057 mov(reg_zp_src_, ptr[param1_ + GET_OFF(src_zero_point)]);
1058 mov(reg_zp_compensation_, ptr[param1_ + GET_OFF(zp_compensation)]);
1059 }
1060
1061 const auto &p = jcp_.post_ops;
1062 const int sum_idx = p.find(primitive_kind::sum);
1063 const float *p_sum_scale
1064 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
1065 const int32_t *p_sum_zp
1066 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
1067
1068 if (jcp_.src_zero_point) {
1069 const auto &vmm_src_zp = vmm_tmp_;
1070 const auto &vmm_zp_comp = vmm_scale_;
1071 uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]);
1072
1073 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1074 const int zp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
1075 const bool mask_flag
1076 = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1077 const int load_size
1078 = mask_flag ? get_tail_size() : get_blocking_size();
1079 load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation_,
1080 zp_offset, load_size);
1081 uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_src_zp);
1082
1083 for (int ur = 0; ur < ur_w; ur++) {
1084 const auto vmm_dst = vmm_out(ur, ocb);
1085 uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp);
1086 }
1087 }
1088 }
1089
1090 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1091 const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1092 const int scale_offset
1093 = jcp_.is_oc_scale * (sizeof(float) * ocb * jcp_.oc_block);
1094
1095 const auto vmm_bias_ = vmm_tmp_;
1096 if (jcp_.with_bias) {
1097 int bias_offset = jcp_.typesize_bia * ocb * jcp_.oc_block;
1098 cvt2ps(jcp_.bia_dt, vmm_bias_, reg_bias_, bias_offset,
1099 mask_flag ? get_tail_size() : get_blocking_size());
1100 }
1101 if (jcp_.signed_input) {
1102 const int comp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
1103 cvt2ps(data_type::s32, vmm_comp_, reg_compensation_, comp_offset,
1104 mask_flag ? get_tail_size() : get_blocking_size());
1105 }
1106
        /* convert the accumulators to f32 and apply compensation, scales and bias */
1108 if (mask_flag) {
1109 load_data(data_type::f32, vmm_scale_, reg_ptr_scales_, scale_offset,
1110 get_tail_size());
1111 } else {
1112 uni_vmovups(vmm_scale_, ptr[reg_ptr_scales_ + scale_offset]);
1113 }
1114 for (int ur = 0; ur < ur_w; ur++) {
1115 const Vmm vmm = vmm_out(ur, ocb);
1116 uni_vcvtdq2ps(vmm, vmm);
1117 if (jcp_.signed_input) uni_vaddps(vmm, vmm, vmm_comp_);
1118 uni_vmulps(vmm, vmm, vmm_scale_);
1119 if (jcp_.with_bias) uni_vaddps(vmm, vmm, vmm_bias_);
1120 }
1121 }
1122
1123 if (p_sum_scale && *p_sum_scale != 1.f)
1124 mov(reg_ptr_sum_scale_, reinterpret_cast<size_t>(p_sum_scale));
1125 if (p_sum_zp && *p_sum_zp != 0) {
1126 mov(reg_ptr_sum_zp_, reinterpret_cast<size_t>(p_sum_zp));
1127 }
1128 if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum)
1129 apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp);
1130 if (jcp_.dst_scale) {
1131 mov(reg_ptr_dst_scales_, ptr[param1_ + GET_OFF(dst_scale)]);
1132 uni_vmovups(vmm_dst_scale_, ptr[reg_ptr_dst_scales_]);
1133
1134 for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++)
1135 for (int ur = 0; ur < ur_w; ur++) {
1136 const auto vmm = vmm_out(ur, ocb);
1137 uni_vmulps(vmm, vmm, vmm_dst_scale_);
1138 }
1139 }
1140 if (jcp_.dst_zero_point) {
1141 mov(reg_zp_dst_, ptr[param1_ + GET_OFF(dst_zero_point)]);
1142 const auto &vmm_zp_dst = vmm_tmp_;
1143 uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]);
1144 uni_vcvtdq2ps(vmm_zp_dst, vmm_zp_dst);
1145
1146 for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++)
1147 for (int ur = 0; ur < ur_w; ur++) {
1148 const auto vmm_dst = vmm_out(ur, ocb);
1149 uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst);
1150 }
1151 }
1152
1153 // Properly saturate the accumulators for integer datatypes
1154
1155 // No need to saturate on lower bound for signed integer types, as
1156 // the conversion to int would return INT_MIN, and then proper
1157 // saturation will happen when storing data
1158 if (jcp_.dst_dt == data_type::u8) {
1159 uni_vpxor(vmm_zero_, vmm_zero_, vmm_zero_);
1160 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1161 for (int ur = 0; ur < ur_w; ur++) {
1162 const Vmm vmm = vmm_out(ur, ocb);
1163 uni_vmaxps(vmm, vmm, vmm_zero_);
1164 }
1165 }
1166 }
1167
1168 if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1169 float saturation_ubound = types::max_value<float>(jcp_.dst_dt);
1170 const Xmm xmm_saturation(vmm_saturation_.getIdx());
1171 mov(reg_ptr_saturation_ubound_, float2int(saturation_ubound));
1172 uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound_);
1173 uni_vbroadcastss(vmm_saturation_, xmm_saturation);
1174
1175 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1176 for (int ur = 0; ur < ur_w; ur++) {
1177 const Vmm vmm = vmm_out(ur, ocb);
1178 uni_vminps(vmm, vmm, vmm_saturation_);
1179 }
1180 }
1181 }
1182
1183 if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1184 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1185 for (int ur = 0; ur < ur_w; ur++) {
1186 const Vmm vmm = vmm_out(ur, ocb);
1187 uni_vcvtps2dq(vmm, vmm);
1188 }
1189 }
1190 }
1191
1192 /* write out register to output_addr */
1193 for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1194 const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1195 for (int ur = 0; ur < ur_w; ur++) {
1196 const int aux_dst_off = jcp_.typesize_out
1197 * (ur * jcp_.ngroups * jcp_.oc_without_padding
1198 + ocb * jcp_.oc_block);
1199 const Vmm r_vmm = vmm_out(ur, ocb);
1200 store_data(jcp_.dst_dt, r_vmm, reg_dst_, aux_dst_off,
1201 mask_flag ? get_tail_size() : get_blocking_size());
1202 }
1203 }
1204}
1205
1206template <cpu_isa_t isa, typename Vmm>
1207void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::icb_loop(
1208 int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
1209
1210 const int shift_src_icb = jcp_.typesize_in * jcp_.ic_block;
1211 const size_t shift_filt_icb = (size_t)jcp_.typesize_in * jcp_.kd * jcp_.kh
1212 * jcp_.kw * jcp_.ic_block * jcp_.oc_block;
1213
1214 prepare_output(ur_w);
1215
1216 Label skip_icb_loop, icb_loop_label;
1217
1218 mov(reg_icb_, jcp_.nb_ic);
1219 mov(reg_oc_blocks_, ptr[param1_ + GET_OFF(oc_blocks)]);
1220
1221 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)
1222 && jcp_.ndims > 3) {
1223 mov(reg_scratch_,
1224 qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);
1225 mov(zp_src_pad_comp_addr_, reg_scratch_);
1226 }
1227
1228 L(icb_loop_label);
1229 {
1230 if (jcp_.ngroups % jcp_.ch_block != 0
1231 || jcp_.ic_without_padding != jcp_.ic) {
1232 Label common_ker, end_ker;
1233 if (jcp_.is_depthwise) {
1234 cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
1235 jne(common_ker, T_NEAR);
1236 } else {
1237 cmp(reg_icb_, 1);
1238 jg(common_ker, T_NEAR);
1239 }
1240
1241 kh_loop(ur_w, l_overflow, r_overflow, last_sp_block);
1242 jmp(end_ker, T_NEAR);
1243
1244 L(common_ker);
1245 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
1246
1247 L(end_ker);
1248 } else {
1249 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
1250 }
1251
1252 add(reg_src_, shift_src_icb);
1253 safe_add(reg_filt_, shift_filt_icb, reg_ker_long_offt_);
1254 dec(reg_icb_);
1255 cmp(reg_icb_, 0);
1256 jg(icb_loop_label, T_NEAR);
1257 }
1258
    /* restore the src/filt pointers */
1260 sub(reg_src_, jcp_.nb_ic * shift_src_icb);
1261 safe_sub(reg_filt_, jcp_.nb_ic * shift_filt_icb, reg_ker_long_offt_);
1262 L(skip_icb_loop);
1263
1264 if (jcp_.ngroups % jcp_.ch_block != 0
1265 || jcp_.oc_without_padding != jcp_.oc) {
1266 Label common_store, end_store;
1267 if (jcp_.is_depthwise)
1268 cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
1269 else
1270 cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
1271 jne(common_store, T_NEAR);
1272
1273 store_output(ur_w, true);
1274 jmp(end_store, T_NEAR);
1275
1276 L(common_store);
1277 store_output(ur_w, false);
1278
1279 L(end_store);
1280
1281 } else {
1282 store_output(ur_w, false);
1283 }
1284}
1285
1286template <cpu_isa_t isa, typename Vmm>
1287void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() {
1288 preamble();
1289
1290 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
1291 sub(rsp, reserved_stack_size_);
1292
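    // vmm_one_ holds packed 16-bit ones (0x0001 pairs); vpmaddwd multiplies by
    // it to horizontally add adjacent s16 products into s32 lanes.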
1293 const auto vmm_one_128 = Xbyak::Xmm(vmm_one_.getIdx());
1294 mov(reg_scratch_, 0x10001);
1295 uni_vmovq(vmm_one_128, reg_scratch_);
1296 uni_vpbroadcastd(vmm_one_, vmm_one_128);
1297
1298 mov(reg_src_, ptr[param1_ + GET_OFF(src)]);
1299 mov(reg_filt_, ptr[param1_ + GET_OFF(filt)]);
1300 mov(reg_dst_, ptr[param1_ + GET_OFF(dst)]);
1301
1302 const int dst_shift = jcp_.typesize_out * jcp_.ur_w * jcp_.ngroups
1303 * jcp_.oc_without_padding;
1304 const int src_shift = jcp_.typesize_in * (jcp_.ur_w / jcp_.stride_w)
1305 * jcp_.ngroups * jcp_.ic_without_padding;
1306
1307 const int l_overflow = max(0,
1308 ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - jcp_.l_pad) / jcp_.stride_w);
1309 const int r_overflow = max(0,
1310 ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - max(0, jcp_.r_pad))
1311 / jcp_.stride_w);
1312
1313 const int r_overflow1 = nstl::max(0,
1314 ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - nstl::max(0, jcp_.r_pad)
1315 - jcp_.ur_w_tail)
1316 / jcp_.stride_w);
1317 int nur_w = jcp_.ow / jcp_.ur_w;
1318 if (r_overflow1 > 0) nur_w--;
1319
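    // The output width is split into: an optional left-overflow block, a loop
    // over the remaining full ur_w blocks, an optional right-overflow block
    // and finally the ur_w_tail remainder.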
1320 if (jcp_.ur_w == jcp_.ow) {
1321 icb_loop(jcp_.ur_w, l_overflow, r_overflow, true);
1322 } else if (nur_w == 0) {
1323 icb_loop(jcp_.ur_w, l_overflow, r_overflow1, jcp_.ur_w_tail == 0);
1324 add(reg_src_, src_shift);
1325 add(reg_dst_, dst_shift);
1326 if (jcp_.ur_w_tail != 0) icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
1327 } else {
1328 xor_(reg_nur_w_, reg_nur_w_);
1329 if (l_overflow > 0) {
1330 icb_loop(jcp_.ur_w, l_overflow, 0, false);
1331 add(reg_src_, src_shift);
1332 add(reg_dst_, dst_shift);
1333 inc(reg_nur_w_);
1334 }
1335 if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
1336 Label ow_loop_label;
1337 L(ow_loop_label);
1338 {
1339 icb_loop(jcp_.ur_w, 0, 0, false);
1340 add(reg_src_, src_shift);
1341 add(reg_dst_, dst_shift);
1342 inc(reg_nur_w_);
1343 cmp(reg_nur_w_, nur_w);
1344 jl(ow_loop_label, T_NEAR);
1345 }
1346 }
1347 if (r_overflow1 > 0) {
1348 icb_loop(jcp_.ur_w, 0, r_overflow1, jcp_.ur_w_tail == 0);
1349 add(reg_src_, src_shift);
1350 add(reg_dst_, dst_shift);
1351 }
1352 if (jcp_.ur_w_tail != 0) {
1353 icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
1354 }
1355 }
1356
1357 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
1358 add(rsp, reserved_stack_size_);
1359
1360 postamble();
1361
1362 if (jcp_.with_eltwise) postops_injector_->prepare_table();
1363}
1364
1365template <cpu_isa_t isa>
1366jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::jit_uni_x8s8s32x_deconvolution_fwd_t(
1367 const pd_t *apd)
1368 : primitive_t(apd) {}
1369
1370template <cpu_isa_t isa>
1371jit_uni_x8s8s32x_deconvolution_fwd_t<
1372 isa>::~jit_uni_x8s8s32x_deconvolution_fwd_t()
1373 = default;
1374
1375template <cpu_isa_t isa>
1376status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::pd_t::init(
1377 engine_t *engine) {
1378 using namespace data_type;
1379 using skip_mask_t = primitive_attr_t::skip_mask_t;
1380 const bool ok = true && is_fwd()
1381 && (desc()->alg_kind & alg_kind::deconvolution_direct)
1382 && utils::one_of(src_md(0)->data_type, s8, u8)
1383 && weights_md(0)->data_type == s8
1384 && IMPLICATION(with_bias(),
1385 utils::one_of(weights_md(1)->data_type, f32, s32, s8, u8))
1386 && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
1387 && desc()->accum_data_type == s32
1388 && attr()->has_default_values(skip_mask_t::scales_runtime
1389 | skip_mask_t::post_ops | skip_mask_t::zero_points_runtime);
1390 if (!ok) return status::unimplemented;
1391
1392 CHECK(jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(jcp_, *desc(),
1393 src_md_, weights_md_, dst_md_, with_bias(), bias_md_, attr_,
1394 dnnl_get_max_threads()));
1395
1396 auto scratchpad = scratchpad_registry().registrar();
1397 jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
1398 scratchpad, jcp_, *attr());
1399
1400 return status::success;
1401}
1402
1403template <cpu_isa_t isa>
1404status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::init(engine_t *engine) {
1405 CHECK(safe_ptr_assign(kernel_,
1406 new jit_uni_x8s8s32x_deconv_fwd_kernel<isa>(pd()->jcp_,
1407 *pd()->attr(), memory_desc_wrapper(pd()->dst_md()))));
1408
1409 if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) {
1410 CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_,
1411 zp::create_deconv_zp_pad_str_comp_ker<isa>(pd()->jcp_)));
1412 const auto zp_kernel_status = zp_src_pad_comp_kernel_->create_kernel();
1413 if (zp_kernel_status != status::success) return zp_kernel_status;
1414 }
1415
1416 return kernel_->create_kernel();
1417}
1418
1419template <cpu_isa_t isa>
1420const float *jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::adjust_oscales(
1421 const memory_tracking::grantor_t &scratchpad, const float *src_scales,
1422 const float *wei_scales) const {
1423 auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
1424 int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
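    // The adjusted scale folds the source and weights scales together and, for
    // s8 sources on pre-VNNI ISAs, undoes the 0.5 weights pre-scaling
    // (wei_adj_scale) applied at reorder time.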
1425 float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
1426 ? 1.f / pd()->jcp_.wei_adj_scale
1427 : 1.0f;
1428 if (wei_mask == 0) {
1429 utils::array_set(loc_scales, src_scales[0] * wei_scales[0] * factor, 8);
1430 } else {
1431 for (dim_t c = 0; c < pd()->OC(); c++)
1432 loc_scales[c] = src_scales[0] * wei_scales[c] * factor;
1433 }
1434 return loc_scales;
1435}
1436
1437template <cpu_isa_t isa>
1438status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute(
1439 const exec_ctx_t &ctx) const {
1440 const auto &_pd = pd();
1441 const auto &ndims = _pd->ndims();
1442 if (ndims == 3)
1443 return execute_forward_1d(ctx);
1444 else if (ndims == 4)
1445 return execute_forward_2d(ctx);
1446 else if (ndims == 5)
1447 return execute_forward_3d(ctx);
1448 else
1449 return status::unimplemented;
1450 return status::success;
1451}
1452
1453template <cpu_isa_t isa>
1454status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_1d(
1455 const exec_ctx_t &ctx) const {
1456 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1457 const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
1458 const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1459 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1460 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
1461 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
1462
1463 const auto &jcp = pd()->jcp_;
1464
1465 const memory_desc_wrapper src_d(pd()->src_md());
1466 const memory_desc_wrapper dst_d(pd()->dst_md());
1467 const memory_desc_wrapper weights_d(pd()->weights_md(0));
1468 const memory_desc_wrapper bias_d(pd()->weights_md(1));
1469
1470 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1471
1472 const auto post_ops_binary_rhs_arg_vec
1473 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
1474 auto scratchpad = ctx.get_scratchpad_grantor();
1475 int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);
1476
1477 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1478 zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
1479 weights_d, weights, zp_src, zp_src_comp_scratch,
1480 zp_src_pad_comp_kernel_.get());
1481
1482 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1483 const int nb_groups = jcp.nb_ch;
1484
1485 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
1486 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
1487 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
1488
1489 const float *oscales = adjust_oscales(
1490 ctx.get_scratchpad_grantor(), src_scales, wei_scales);
1491
1492 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1493 auto w = const_cast<int8_t *>(weights);
1494 int32_t *compensation = (jcp.signed_input)
1495 ? reinterpret_cast<int32_t *>(&w[offset])
1496 : nullptr;
1497 const int32_t *zp_compensation = jcp.src_zero_point
1498 ? get_src_zp_comp_from_wei(
1499 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1500 : nullptr;
1501
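    // Parallelize over (mb, group, oc-chunk) triples; each kernel invocation
    // covers one chunk of nb_oc_blocking output-channel blocks across the
    // whole spatial range.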
1502 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1503 int start {0}, end {0};
1504 const int work_amount = jcp.mb * nb_groups * oc_chunks;
1505 balance211(work_amount, nthr, ithr, start, end);
1506
1507 auto p = jit_deconv_call_s();
1508
1509 int n {0}, g {0}, occ {0};
1510 if (jcp.loop_order == loop_ngc)
1511 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
1512 else if (jcp.loop_order == loop_cgn)
1513 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
1514 else
1515 assert(!"unsupported loop order");
1516 while (start < end) {
1517
1518 const int ocb = occ * jcp.nb_oc_blocking;
1519 const int g_oc
1520 = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
1521 const int g_ic = g * jcp.ch_block * jcp.ic;
1522
1523 p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
1524 p.src = src + src_d.blk_off(n, g_ic);
1525 p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
1526 p.bias = jcp.with_bias
1527 ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
1528 : nullptr;
1529 p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
1530 p.scales = &oscales[jcp.is_oc_scale * g_oc];
1531 p.dst_scale = dst_scales;
1532 p.t_overflow = 0;
1533 p.b_overflow = 0;
1534 p.kh_padding = jcp.kh;
1535 p.oc_blocks = jcp.is_depthwise ? g : ocb;
1536 p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
1537 p.oc_l_off = g_oc;
1538 p.zp_compensation
1539 = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
1540 p.zp_src_pad_str_compensation = zp_src_comp_scratch
1541 ? zp_src_comp_scratch + g_oc
1542 : nullptr;
1543 p.src_zero_point = zp_src;
1544 p.dst_zero_point = zp_dst;
1545 p.dst_orig = dst;
1546
1547 (*kernel_)(&p);
1548
1549 ++start;
1550 if (jcp.loop_order == loop_ngc)
1551 nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
1552 else if (jcp.loop_order == loop_cgn)
1553 nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
1554 else
1555 assert(!"unsupported loop order");
1556 }
1557 });
1558 return status::success;
1559}
1560
1561template <cpu_isa_t isa>
1562status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_2d(
1563 const exec_ctx_t &ctx) const {
1564 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1565 const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
1566 const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1567 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1568 DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
1569 DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
1570
1571 const memory_desc_wrapper src_d(pd()->src_md());
1572 const memory_desc_wrapper dst_d(pd()->dst_md());
1573 const memory_desc_wrapper weights_d(pd()->weights_md(0));
1574 const memory_desc_wrapper bias_d(pd()->weights_md(1));
1575
1576 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1577
1578 const auto &jcp = pd()->jcp_;
1579 const auto post_ops_binary_rhs_arg_vec
1580 = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
1581
1582 auto scratchpad = ctx.get_scratchpad_grantor();
1583 int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);
1584
1585 if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
1586 zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
1587 weights_d, weights, zp_src, zp_src_comp_scratch,
1588 zp_src_pad_comp_kernel_.get());
1589
1590 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1591 const int nb_groups = jcp.nb_ch;
1592
1593 const size_t src_h_stride = src_d.blk_off(0, 0, 1);
1594 const size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
1595 const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
1596
1597 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
1598 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
1599 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
1600
1601 const float *oscales = adjust_oscales(
1602 ctx.get_scratchpad_grantor(), src_scales, wei_scales);
1603
1604 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1605 auto w = const_cast<int8_t *>(weights);
1606 int32_t *compensation = (jcp.signed_input)
1607 ? reinterpret_cast<int32_t *>(&w[offset])
1608 : nullptr;
1609 const int32_t *zp_compensation = jcp.src_zero_point
1610 ? get_src_zp_comp_from_wei(
1611 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1612 : nullptr;
1613
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: cgn or ngc, as selected by jcp.loop_order */
        int n {0}, g {0}, occ {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

            const auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
            const auto src_w = src + src_d.blk_off(n, g_ic);
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];
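            // For each output row figure out which filter taps contribute
            // (kh_lo, kh_len) and the input row the kernel starts from
            // (ih_max), accounting for stride, dilation and top/bottom padding.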
            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

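                // With signed input or a source zero point the kernel needs
                // the full filter (padded taps are handled through the
                // overflow counters), so the filter pointer is not advanced;
                // otherwise it can start at the first contributing tap.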
                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
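                // t_overflow/b_overflow tell the kernel how many filter taps
                // fall into the top/bottom padding for this output row.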
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
                p.kh_padding = kh_len;
                p.scales = scales;
                p.dst_scale = dst_scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }
            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    const auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto &jcp = pd()->jcp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *const zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int &nb_groups = jcp.nb_ch;

    const size_t src_d_stride = src_d.blk_off(0, 0, 1);
    const size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
    const size_t dst_d_stride = dst_d.blk_off(0, 0, 1);
    const size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
    const size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
    const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const float *oscales = adjust_oscales(
            ctx.get_scratchpad_grantor(), src_scales, wei_scales);

    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

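    // Same scheme as execute_forward_2d, with the output depth added to the
    // parallel iteration space: (mb, group, oc chunk, od, oh).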
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: cgn or ngc, as selected by jcp.loop_order */
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    od_s, jcp.od, oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
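            // Work out which filter taps contribute along depth (kd_lo,
            // kd_len) and the input depth slice the kernel starts from
            // (input_d_s), analogous to the height handling further below.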
            int input_d_s = 0, kd_len = 0, kd_lo = 0;

            if (jcp.dilate_d != 0 && jcp.stride_d == 1) {
                /* dilation */
                const int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                const int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad),
                        dilate_d);
                const int d_back_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.od
                                                 + od_s - jcp.back_pad),
                                dilate_d);
                kd_len = jcp.kd - d_t_overflow - d_back_overflow;
                kd_lo = d_back_overflow;
                input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d;
            } else {
                const int d_t_overflow = max(
                        0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d);
                const int d_back_overflow = max(0,
                        ((od_s + jcp.kd) - (jcp.od + jcp.back_pad))
                                / jcp.stride_d);
                const int overflow_kd_hi = jcp.kd - 1
                        - modulo(jcp.od + jcp.back_pad - (od_s + 1),
                                jcp.stride_d);
                const int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d;

                kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                        - d_t_overflow - d_back_overflow;
                kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d;
                input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d;
            }

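            // Shift the source/destination pointers to the selected depth
            // slice; the filter pointer is advanced to kd_lo only when no
            // compensation over the skipped taps is required.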
            auto dst_w = dst
                    + dst_dt_size
                            * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride);
            const auto src_w
                    = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride;
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                    + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo)
                            * wht_kd_stride;
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];

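            // Height handling mirrors execute_forward_2d: compute kh_lo,
            // kh_len and ih_max for every output row.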
            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                /* Note: currently this kernel does not support dilation and
                   stride used together */
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
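                // f_overflow/back_overflow are the depth counterparts of
                // t_overflow/b_overflow (taps falling into front/back padding).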
                p.f_overflow = jcp.dilate_d > 0
                        ? jcp.kd - kd_len - kd_lo
                        : max(0,
                                jcp.kd
                                        - (kd_lo
                                                + max(0, kd_len - 1)
                                                        * jcp.stride_d
                                                + 1));
                p.back_overflow = kd_lo;
                p.kh_padding = kh_len;
                p.kd_padding = kd_len;
                p.scales = scales;
                p.dst_scale = dst_scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }

            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });

    return status::success;
}

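// Explicit instantiations for the ISA / vector-register combinations
// supported by this implementation.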
using namespace data_type;
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<avx2>;
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<sse41>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<sse41>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Xmm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<sse41, Xbyak::Xmm>;
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl