1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | #include "common/memory_desc_wrapper.hpp" |
19 | #include "cpu/cpu_primitive.hpp" |
20 | #include "cpu/zero_point_utils.hpp" |
21 | |
22 | #include "cpu/x64/jit_avx512_core_x8s8s32x_deconvolution.hpp" |
23 | |
24 | #define GET_OFF(field) offsetof(jit_deconv_call_s, field) |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | using namespace Xbyak; |
using namespace nstl;
37 | |
38 | #define wht_blk_off(d, g, ...) \ |
39 | (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \ |
40 | : (d).blk_off(__VA_ARGS__)) |
41 | |
42 | template <typename Vmm> |
43 | jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>:: |
44 | jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
45 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
46 | : jit_generator(jit_name()) |
47 | , jcp(ajcp) |
48 | , attr_(attr) |
49 | , postops_injector_(nullptr) { |
50 | |
51 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
52 | const std::size_t tail_size = jcp.is_depthwise |
53 | ? jcp.ngroups % jcp.ch_block |
54 | : jcp.oc_without_padding % jcp.oc_block; |
55 | |
56 | static constexpr bool preserve_gpr = true; |
57 | static constexpr bool preserve_vmm = true; |
58 | static constexpr bool use_exact_tail_scalar_bcast = false; |
59 | |
60 | const binary_injector::rhs_arg_static_params_t rhs_sp { |
61 | static_cast<size_t>(Xbyak::Xmm(31).getIdx()), this->r14, |
62 | this->r15, this->r13, preserve_gpr, preserve_vmm, |
63 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
64 | memory_desc_wrapper(dst_md), tail_size, ktail_mask, |
65 | use_exact_tail_scalar_bcast}; |
66 | const binary_injector::static_params_t bsp {this->param1, rhs_sp}; |
67 | |
68 | postops_injector_ = utils::make_unique< |
69 | injector::jit_uni_postops_injector_t<avx512_core, Vmm>>( |
70 | this, jcp.post_ops, bsp); |
71 | } |
72 | } |
73 | |
74 | template <typename Vmm> |
75 | jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
76 | Vmm>::~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() |
77 | = default; |
78 | |
79 | status_t _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( |
80 | jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, |
81 | memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, |
82 | const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr, |
83 | int nthreads) { |
84 | const memory_desc_wrapper src_d(&src_md); |
85 | const memory_desc_wrapper dst_d(&dst_md); |
86 | const memory_desc_wrapper weights_d(&weights_md); |
87 | const memory_desc_wrapper bias_d(&bias_md); |
88 | |
89 | if (!(mayiuse(avx512_core) |
90 | && one_of(src_d.data_type(), data_type::u8, data_type::s8) |
91 | && weights_d.data_type() == data_type::s8 |
92 | && one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
93 | data_type::s8, data_type::u8))) |
94 | return status::unimplemented; |
95 | |
96 | jcp = zero<decltype(jcp)>(); |
97 | jcp.nthr = nthreads; |
98 | |
99 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
100 | jcp.signed_input = src_d.data_type() == data_type::s8; |
101 | const int ndims = jcp.ndims = dst_d.ndims(); |
102 | const bool is_1d = ndims == 3; |
103 | const bool is_2d = ndims == 4; |
104 | const bool is_3d = ndims == 5; |
105 | |
106 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
107 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
108 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
109 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
110 | jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; |
111 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
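    // Depthwise <=> grouped with exactly one input and one output channel
    // per group.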
    jcp.is_depthwise = with_groups
            && utils::everyone_is(
                    1, jcp.ic_without_padding, jcp.oc_without_padding);
115 | |
116 | /* TODO: future work, on hold until depthwise specialized kernel is |
117 | * implemented. */ |
118 | if (jcp.is_depthwise && (jcp.signed_input || is_3d)) |
119 | return status::unimplemented; |
120 | |
121 | if (!zero_points_valid(&attr)) return status::unimplemented; |
122 | jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC); |
123 | jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST); |
124 | jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC); |
125 | |
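    // Only plain channels-last activation layouts are supported:
    // nwc, nhwc or ndhwc.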
126 | format_tag_t dat_tag = utils::pick( |
127 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
128 | |
129 | if (src_d.format_kind() == format_kind::any) { |
130 | CHECK(memory_desc_init_by_tag(src_md, dat_tag)); |
131 | jcp.src_tag = dat_tag; |
132 | } else { |
133 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
134 | } |
135 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
136 | |
137 | if (dst_d.format_kind() == format_kind::any) { |
138 | CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); |
139 | jcp.dst_tag = dat_tag; |
140 | } else { |
141 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
142 | } |
143 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
144 | |
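    // Choose (or validate) a blocked weights layout matching the ic/oc
    // blocking. For signed src (s8s8) the weights additionally carry an
    // appended compensation buffer and, on pre-VNNI hardware, a 0.5 scale
    // adjustment that protects vpmaddubsw from saturating.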
145 | auto set_or_check_wei_format = [&]() { |
146 | using namespace format_tag; |
147 | format_tag_t wei_tag; |
148 | if (jcp.ic_block == 16 || jcp.ch_block == 16) { |
149 | if (is_3d) { |
150 | wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i; |
151 | } else if (is_1d) { |
152 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i |
153 | : OIw4i16o4i; |
154 | } else { |
155 | assert(is_2d); |
156 | wei_tag = with_groups |
157 | ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i |
158 | : OIhw4i16o4i; |
159 | } |
160 | } else if (jcp.ic_block == 8) { |
161 | assert(with_groups); |
162 | wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i; |
163 | } else { |
164 | assert(with_groups && jcp.ic_block == 4); |
165 | wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i; |
166 | } |
167 | |
168 | memory_desc_t want_wei_md = weights_md; |
169 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
170 | if (jcp.signed_input && !jcp.is_depthwise) { |
171 | want_wei_md.extra.flags = 0 |
172 | | memory_extra_flags::compensation_conv_s8s8 |
173 | | memory_extra_flags::scale_adjust; |
174 | want_wei_md.extra.compensation_mask = (1 << 0) |
175 | + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); |
176 | want_wei_md.extra.scale_adjust |
177 | = mayiuse(avx512_core_vnni) ? 1.f : 0.5f; |
178 | } |
179 | if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups); |
180 | |
181 | if (weights_md.format_kind == format_kind::any) { |
182 | weights_md = want_wei_md; |
183 | return true; |
184 | } |
185 | |
186 | return weights_md == want_wei_md; |
187 | }; |
188 | |
189 | jcp.with_bias = with_bias; |
190 | if (jcp.with_bias) { |
191 | if (bias_d.format_kind() == format_kind::any) |
192 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
193 | } |
194 | |
195 | jcp.prop_kind = cd.prop_kind; |
196 | jcp.mb = src_d.dims()[0]; |
197 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
198 | jcp.iw = src_d.dims()[ndims - 1]; |
199 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
200 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
201 | jcp.ow = dst_d.dims()[ndims - 1]; |
202 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
203 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
204 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
205 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
206 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
207 | jcp.l_pad = cd.padding[0][ndims - 3]; |
208 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
209 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
210 | jcp.stride_w = cd.strides[ndims - 3]; |
211 | |
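    /* SIMD blocking: depthwise processes 16 channels at once; otherwise use
     * 16-wide oc/ic blocks, shrinking to 8 (Ymm) or 4 (Xmm) for grouped cases
     * where channels per group are not a multiple of 16. */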
212 | if (jcp.is_depthwise) { |
213 | jcp.ch_block = 16; |
214 | jcp.oc_block = 1; |
215 | jcp.ic_block = 1; |
216 | } else { |
217 | jcp.ch_block = 1; |
218 | jcp.oc_block = 16; |
219 | jcp.ic_block = 16; |
220 | |
221 | if (jcp.ngroups == 1) { |
222 | jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); |
223 | jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); |
224 | } else if (jcp.ngroups != 1 |
225 | && ((jcp.ic % jcp.ic_block != 0) |
226 | || (jcp.oc % jcp.oc_block != 0))) { |
            /* For grouped deconvolutions, oneDNN doesn't support padding.
               When the number of channels per group is not a multiple of 16:
                - use Ymm when it is a multiple of 8,
                - use Xmm when it is a multiple of 4,
                - otherwise return unimplemented. */
232 | jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4; |
233 | jcp.oc_block = jcp.ic_block; |
234 | } |
235 | if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0) |
236 | return status::unimplemented; |
237 | } |
238 | |
239 | if (!set_or_check_wei_format()) return status::unimplemented; |
240 | |
241 | jcp.dilate_d = is_3d ? cd.dilates[0] : 0; |
242 | jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; |
243 | jcp.dilate_w = cd.dilates[ndims - 3]; |
244 | |
245 | if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1) |
246 | || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) |
247 | || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) |
248 | return status::unimplemented; |
249 | |
250 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
251 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
252 | int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
253 | jcp.r_pad = calculate_end_padding( |
254 | jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw); |
255 | jcp.b_pad = calculate_end_padding( |
256 | jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh); |
257 | jcp.back_pad = calculate_end_padding( |
258 | jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd); |
    bool kernel_outside_src = ext_kw <= jcp.l_pad || ext_kw <= jcp.r_pad
            || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
262 | if (kernel_outside_src) return status::unimplemented; |
263 | |
264 | CHECK(attr.set_default_formats(&dst_md)); |
265 | if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented; |
266 | |
267 | const auto &p = attr.post_ops_; |
268 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
269 | jcp.with_eltwise = eltwise_ind != -1; |
270 | if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; |
271 | const int binary_ind = p.find(primitive_kind::binary); |
272 | jcp.with_binary = binary_ind != -1; |
273 | |
274 | const int sum_ind = p.find(primitive_kind::sum); |
275 | jcp.with_sum = sum_ind != -1; |
276 | |
    // save post_ops desc for further usage
278 | jcp.post_ops = p; |
279 | |
280 | jcp.has_vnni = mayiuse(avx512_core_vnni); |
281 | |
282 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
283 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
284 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
285 | const int wei_scales_per_oc = 1 << (int)with_groups; |
286 | jcp.is_oc_scale = wei_scales.mask_ == wei_scales_per_oc; |
287 | jcp.dst_scale = !dst_scales.has_default_values(); |
288 | |
289 | // only common and per-oc-channel scales are supported |
290 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_scales_per_oc) |
291 | && utils::everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
292 | if (!scales_ok) return status::unimplemented; |
293 | |
294 | jcp.dst_dt = dst_d.data_type(); |
295 | jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; |
296 | jcp.typesize_bia |
297 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
298 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
299 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
300 | |
301 | jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); |
302 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
303 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
304 | |
305 | /* kernel blocking params */ |
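    // Budget of Vmm registers for accumulators and broadcast inputs; the
    // non-VNNI path also reserves vmm_tmp and vmm_one for the multiply-add
    // emulation, leaving 28 instead of 30.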
306 | const int regs = jcp.has_vnni ? 30 : 28; |
307 | jcp.nb_ch_blocking = 1; |
308 | jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); |
309 | for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) |
310 | if (jcp.nb_oc % jcp.nb_oc_blocking == 0 |
311 | && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) |
312 | break; |
313 | |
314 | jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); |
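    // l_overflow: number of kernel elements extending past the left edge of
    // src when computing the left-most output pixel (see also generate()).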
315 | int l_overflow = max( |
316 | 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); |
317 | |
318 | if (jcp.ow < jcp.ur_w) { |
319 | jcp.ur_w = jcp.ow; |
320 | jcp.ur_w_tail = 0; |
321 | } else { |
322 | for (; jcp.ur_w >= 1; jcp.ur_w--) { |
            /* ur_w should be a multiple of stride_w in order
               to simplify logic for get_ow_start and get_ow_end */
            bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
326 | |
327 | /* boundary conditions: |
328 | These conditions ensure all elements close to boundary |
329 | are computed in a single call of compute loop */ |
330 | bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w; |
331 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
332 | int r_overflow_no_tail = max(0, |
333 | ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad) |
334 | - jcp.ur_w_tail) |
335 | / jcp.stride_w); |
336 | bool right_boundary_covered |
337 | = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; |
338 | |
339 | if (is_multiple_of_stride && left_boundary_covered |
340 | && right_boundary_covered) |
341 | break; |
342 | else if (jcp.ur_w == 1) |
                /* The boundary conditions above also keep the calls to
                   icb_loop simple. If they are not satisfied, special cases
                   would be needed to use the correct l_overflow/r_overflow
                   values when different iterations of the compute loop work
                   on locations close to the boundary. To keep the code
                   simple, return unimplemented for the extreme case where
                   no good ur_w can be found. */
353 | return status::unimplemented; |
354 | } |
355 | } |
356 | |
357 | jcp.wei_adj_scale |
358 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
359 | ? weights_d.extra().scale_adjust |
360 | : 1.f; |
361 | |
362 | jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; |
363 | return status::success; |
364 | } |
365 | |
366 | bool _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok( |
367 | jit_conv_conf_t &jcp, primitive_attr_t &attr, |
368 | const memory_desc_wrapper &dst_d) { |
369 | |
370 | using namespace injector; |
371 | const auto &post_ops = attr.post_ops_; |
372 | static constexpr bool sum_at_pos_0_only = true; |
373 | static constexpr bool sum_requires_scale_one = false; |
374 | |
375 | return injector::post_ops_ok({avx512_core, {eltwise, binary, sum}, post_ops, |
376 | &dst_d, sum_at_pos_0_only, sum_requires_scale_one}); |
377 | } |
378 | |
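// Scratchpad: adjusted output scales (at least 16 floats, so vector loads of
// a broadcast common scale stay in bounds) plus, when needed, the zero-point
// source stride-padding compensation of oc_without_padding * ngroups * kd *
// kh * kw int32 elements.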
379 | void _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( |
380 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
381 | const primitive_attr_t &attr) { |
382 | const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
383 | const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
384 | const dim_t count = nstl::max<dim_t>(scales_count, 16); |
385 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
386 | |
387 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) { |
388 | const dim_t zp_pad_comp_size |
389 | = static_cast<size_t>(jcp.oc_without_padding) * jcp.ngroups |
390 | * jcp.kd * jcp.kh * jcp.kw; |
391 | scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size); |
392 | } |
393 | } |
394 | |
395 | template <typename Vmm> |
396 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::compute( |
397 | const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src) { |
398 | |
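    // VNNI: vpdpbusd multiplies u8 src by s8 wei and accumulates 4-wide dot
    // products into s32 directly. Without VNNI, emulate it with vpmaddubsw
    // (u8*s8 -> pairwise s16) followed by vpmaddwd against a vector of ones.
    // Depthwise data is already zero/sign-extended to s32, so a plain 32-bit
    // multiply-add suffices.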
399 | if (jcp.has_vnni) { |
400 | vpdpbusd(vreg_acc, vreg_src, vreg_wei); |
401 | } else if (jcp.is_depthwise) { |
402 | uni_vmovups(vmm_tmp, vreg_src); |
403 | uni_vpmulld(vmm_tmp, vmm_tmp, vreg_wei); |
404 | uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp); |
405 | } else { |
406 | uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei); |
407 | uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one); |
408 | uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp); |
409 | } |
410 | } |
411 | |
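// Returns a generator that hands out the input Vmm registers for this ur_w in
// round-robin fashion; used to buffer zero-point padding-compensation loads.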
412 | template <typename Vmm> |
413 | std::function<Vmm()> jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
414 | Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept { |
415 | const int start_vmm_idx = vmm_inp(0, jcp.nb_oc_blocking).getIdx(); |
416 | const int end_vmm_idx = vmm_inp(ur_w - 1, jcp.nb_oc_blocking).getIdx() + 1; |
417 | int current_vmm_idx = start_vmm_idx; |
418 | |
419 | return [=]() mutable { |
420 | const Vmm vmm {static_cast<int>(current_vmm_idx++)}; |
421 | |
422 | if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx; |
423 | |
424 | return vmm; |
425 | }; |
426 | } |
427 | |
428 | template <typename Vmm> |
429 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::apply_zp_src_pad_str_comp( |
430 | int ur_w, int l_overflow, int r_overflow, bool h_padded) { |
431 | Xbyak::Label end_zp_pad, no_tail; |
432 | |
    // Applied once per icb loop. The zp src stride-padding compensation is
    // computed as: zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
435 | cmp(reg_icb, jcp.nb_ic); |
436 | jne(end_zp_pad, T_NEAR); |
437 | |
438 | if (jcp.ngroups % jcp.ch_block || jcp.oc_without_padding % jcp.oc_block) { |
439 | if (jcp.is_depthwise) |
440 | cmp(reg_oc_blocks, jcp.nb_ch - 1); |
441 | else |
442 | cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); |
443 | jne(no_tail, T_NEAR); |
444 | |
445 | static constexpr bool last_ocb = true; |
446 | append_zp_src_pad_str_comp( |
447 | ur_w, l_overflow, r_overflow, h_padded, last_ocb); |
448 | jmp(end_zp_pad, T_NEAR); |
449 | } |
450 | |
451 | L(no_tail); |
452 | static constexpr bool last_ocb = false; |
453 | |
454 | append_zp_src_pad_str_comp( |
455 | ur_w, l_overflow, r_overflow, h_padded, last_ocb); |
456 | |
457 | L(end_zp_pad); |
458 | } |
459 | |
460 | template <typename Vmm> |
461 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
462 | Vmm>::append_zp_src_pad_str_comp(int ur_w, int l_overflow, |
463 | int r_overflow, bool h_padded, bool last_oc_block) { |
464 | |
465 | const auto ®_zp_src_pad_comp = reg_scratch; |
466 | const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w); |
467 | bool base_comp_addr_loaded = false; |
468 | |
469 | const auto load_base_zp_src_pad_comp_addr = [&]() { |
470 | if (!base_comp_addr_loaded) { |
471 | if (jcp.ndims == 5) mov(reg_scratch_preserved, reg_scratch); |
472 | |
473 | if (jcp.ndims > 3) |
474 | mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr); |
475 | else |
476 | mov(reg_zp_src_pad_comp, |
477 | qword[param1 + GET_OFF(zp_src_pad_str_compensation)]); |
478 | |
479 | base_comp_addr_loaded = true; |
480 | } |
481 | }; |
482 | |
483 | const auto load_zp_src_pad_comp = [&](const Vmm &zp_pad_comp_vmm, |
484 | const Xbyak::Address &comp_addr, |
485 | const int ocb) { |
486 | const bool is_last_ocb = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
487 | const bool is_tail = is_last_ocb && get_tail_size() > 0; |
488 | if (is_tail) |
489 | vmovups(zp_pad_comp_vmm | ktail_mask | T_z, comp_addr); |
490 | else |
491 | vmovups(zp_pad_comp_vmm, comp_addr); |
492 | }; |
493 | |
494 | const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) { |
495 | const auto kw_offset = it_kw * jcp.oc_without_padding * jcp.ngroups; |
496 | const auto oc_offset = ocb * jcp.oc_block; |
497 | |
498 | return (kw_offset + oc_offset) * sizeof(int32_t); |
499 | }; |
500 | |
501 | for (int it_kw = 0; it_kw < jcp.kw; ++it_kw) { |
502 | const int ow_start = get_ow_start(it_kw, l_overflow); |
503 | const int ow_end = get_ow_end(ur_w, it_kw, r_overflow); |
504 | |
505 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
506 | Vmm zp_src_comp_pad_vmm; // will be assigned later |
507 | bool ocb_zp_loaded = false; |
508 | |
509 | const auto zp_src_comp_pad_off |
510 | = get_zp_src_comp_pad_off(it_kw, ocb); |
511 | |
512 | for (int it_ow = 0; it_ow < ur_w; ++it_ow) { |
513 | |
514 | const bool inside_padded_area = h_padded |
515 | || !(it_ow >= ow_start && it_ow < ow_end |
516 | && ((it_ow + jcp.l_pad - it_kw) % jcp.stride_w |
517 | == 0)); |
518 | |
519 | if (inside_padded_area) { |
520 | load_base_zp_src_pad_comp_addr(); |
521 | |
522 | if (!ocb_zp_loaded) { |
523 | zp_src_comp_pad_vmm = get_next_comp_vmm(); |
524 | const auto comp_addr = ptr[reg_zp_src_pad_comp |
525 | + zp_src_comp_pad_off]; |
526 | load_zp_src_pad_comp( |
527 | zp_src_comp_pad_vmm, comp_addr, ocb); |
528 | ocb_zp_loaded = true; |
529 | } |
530 | |
531 | const auto vmm_dst = vmm_out(it_ow, ocb); |
532 | uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm); |
533 | } |
534 | } |
535 | } |
536 | } |
537 | |
538 | if (jcp.ndims > 3) { |
539 | if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr(); |
540 | |
541 | const auto kh_offset = jcp.kw * jcp.oc_without_padding * jcp.ngroups |
542 | * sizeof(int32_t); |
543 | |
544 | add(reg_zp_src_pad_comp, kh_offset); |
545 | mov(zp_src_pad_comp_addr, reg_zp_src_pad_comp); |
546 | } |
547 | |
548 | if (jcp.ndims == 5 && base_comp_addr_loaded) |
549 | mov(reg_scratch, reg_scratch_preserved); |
550 | } |
551 | |
552 | template <typename Vmm> |
553 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::compute_ker(int ur_w, |
554 | int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, |
555 | bool h_padded) { |
556 | |
557 | const bool signed_input_or_src_zp |
558 | = (jcp.signed_input || jcp.src_zero_point); |
559 | |
560 | const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; |
561 | const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp.stride_w; |
562 | |
563 | auto src_offset = [=](int oj, int icb, int ki) { |
564 | return jcp.typesize_in |
565 | * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) |
566 | * jcp.ngroups * jcp.ic_without_padding |
567 | + icb * 4); |
568 | }; |
569 | |
570 | auto kernel_offset = [=](int ocb, int icb, int ki) { |
571 | return jcp.typesize_in |
572 | * ((ocb * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki) |
573 | * ch_block_all |
574 | + icb * jcp.oc_block * ic_sub_step); |
575 | }; |
576 | |
577 | for (int ki = 0; ki < jcp.kw; ki++) { |
578 | int jj_start = get_ow_start(ki, l_overflow); |
579 | int jj_end = get_ow_end(ur_w, ki, r_overflow); |
580 | |
581 | int _start = (signed_input_or_src_zp) ? 0 : jj_start; |
582 | int _end = (signed_input_or_src_zp) ? ur_w : jj_end; |
583 | |
584 | int tail_size = jcp.is_depthwise ? jcp.ngroups % jcp.ch_block |
585 | : jcp.ic_without_padding % 4; |
586 | int n_ic_blocks = jcp.is_depthwise |
587 | ? 1 |
588 | : (last_ic_block_flag & ~no_last_block ? div_up( |
589 | jcp.ic_without_padding % jcp.ic_block, 4) |
590 | : jcp.ic_block / 4); |
591 | |
592 | for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { |
            if (h_padded) {
594 | if (jcp.signed_input) { |
595 | /* fill padded area with shifted values */ |
596 | const Vmm inp = vmm_inp(0, jcp.nb_oc_blocking); |
597 | vpxord(inp, inp, inp); |
598 | vpsubb(inp, inp, vmm_shift); |
599 | } |
600 | } else { |
601 | |
602 | for (int jj = _start; jj < _end; jj += ur_w_stride) { |
603 | |
604 | int aux_src_off = src_offset(jj, icb1, ki); |
605 | |
606 | if (jj >= jj_start && jj < jj_end |
607 | && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) { |
608 | if (jcp.is_depthwise) { |
609 | Vmm vmm_src = vmm_inp(jj, jcp.nb_oc_blocking); |
610 | if (tail_size != 0) { |
611 | assert(jcp.nb_oc_blocking == 1); |
612 | vmm_src = vmm_src | ktail_mask | T_z; |
613 | } |
614 | vpmovzxbd(vmm_src, |
615 | EVEX_compress_addr( |
616 | aux_reg_src, aux_src_off)); |
617 | } else if ((last_ic_block_flag & last_sp_block) |
618 | && tail_size != 0 && icb1 == n_ic_blocks - 1) { |
619 | const Xmm xmm_tmp = Xmm( |
620 | vmm_inp(jj, jcp.nb_oc_blocking).getIdx()); |
621 | for (int r = 0; r < tail_size; ++r) |
622 | vpinsrb(xmm_tmp, xmm_tmp, |
623 | ptr[aux_reg_src + aux_src_off + r], r); |
624 | vpbroadcastd( |
625 | vmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp); |
626 | } else { |
627 | vpbroadcastd(vmm_inp(jj, jcp.nb_oc_blocking), |
628 | EVEX_compress_addr( |
629 | aux_reg_src, aux_src_off)); |
630 | } |
631 | if (jcp.signed_input) |
632 | vpsubb(vmm_inp(jj, jcp.nb_oc_blocking), |
633 | vmm_inp(jj, jcp.nb_oc_blocking), vmm_shift); |
634 | } else { |
635 | /* fill padded area with shifted values */ |
636 | if (jcp.signed_input) { |
637 | const Vmm inp = vmm_inp(jj, jcp.nb_oc_blocking); |
638 | vpxord(inp, inp, inp); |
639 | vpsubb(inp, inp, vmm_shift); |
640 | } |
641 | } |
642 | } |
643 | } |
644 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
645 | int aux_filt_off = kernel_offset(ocb, icb1, ki); |
646 | |
647 | if (_end - _start > 0) { |
648 | if (jcp.is_depthwise) |
649 | vpmovsxbd(vmm_wei, |
650 | EVEX_compress_addr(aux_reg_filt, aux_filt_off)); |
651 | else |
652 | vmovups(vmm_wei, |
653 | EVEX_compress_addr(aux_reg_filt, aux_filt_off)); |
654 | } |
655 | for (int jj = _start; jj < _end; jj += ur_w_stride) { |
656 | const bool jj_between_start_end |
657 | = jj >= jj_start && jj < jj_end; |
658 | const bool ki_applies_to_stride |
659 | = (jj + jcp.l_pad - ki) % jcp.stride_w == 0; |
660 | const bool inside_padded_area = h_padded |
661 | || !(jj_between_start_end && ki_applies_to_stride); |
662 | const auto vmm_dst = vmm_out(jj, ocb); |
663 | if (jcp.signed_input || !inside_padded_area) { |
664 | const Vmm inp = vmm_inp( |
665 | h_padded ? 0 : jj, jcp.nb_oc_blocking); |
666 | compute(vmm_dst, vmm_wei, inp); |
667 | } |
668 | } |
669 | } |
670 | } |
671 | } |
672 | |
673 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
674 | apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded); |
675 | } |
676 | |
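// kh/kd loop. With signed input or a src zero point, rows/planes of the
// filter that only overlap padding still contribute a compensation term, so
// the corresponding overflow regions are walked with h_padded = true.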
677 | template <typename Vmm> |
678 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::kh_loop(int ur_w, |
679 | int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { |
680 | |
681 | const bool signed_input_or_src_zp |
682 | = (jcp.signed_input || jcp.src_zero_point); |
683 | |
684 | int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block; |
685 | int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw |
686 | * jcp.ngroups * jcp.ic_without_padding; |
687 | int shift_src_id = jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * jcp.iw |
688 | * jcp.ngroups * jcp.ic_without_padding; |
689 | const int stride_h = signed_input_or_src_zp ? 1 : jcp.stride_h; |
690 | int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h; |
691 | const int stride_d = signed_input_or_src_zp ? 1 : jcp.stride_d; |
692 | int shift_filt_kd |
693 | = jcp.typesize_in * jcp.kw * ch_block_all * jcp.kh * stride_d; |
694 | |
695 | Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop; |
696 | Label t_overflow_label, no_t_overflow_label, b_overflow_label, |
697 | no_b_overflow_label; |
698 | Label back_overflow_label, no_back_overflow_label, d_h_overflow_label, |
699 | front_overflow_label, no_front_overflow_label, d_h_overflow_label2; |
700 | if (jcp.ndims == 5) { |
701 | mov(aux_reg_filt_d, reg_filt); |
702 | mov(aux_reg_src_d, reg_src); |
703 | |
704 | if (signed_input_or_src_zp) { |
705 | mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]); |
706 | cmp(reg_ki, 0); |
707 | je(no_back_overflow_label, T_NEAR); |
708 | L(back_overflow_label); |
709 | { |
710 | mov(aux_reg_filt, aux_reg_filt_d); |
711 | mov(reg_kh, jcp.kh); |
712 | L(d_h_overflow_label); |
713 | { |
714 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
715 | add(aux_reg_filt, shift_filt_kh); |
716 | dec(reg_kh); |
717 | jnz(d_h_overflow_label); |
718 | } |
719 | |
720 | add(aux_reg_filt_d, shift_filt_kd); |
721 | dec(reg_ki); |
722 | jnz(back_overflow_label); |
723 | } |
724 | L(no_back_overflow_label); |
725 | } |
726 | |
727 | mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]); |
728 | |
729 | if ((signed_input_or_src_zp) || (jcp.dilate_d >= jcp.id) |
730 | || ((!signed_input_or_src_zp) |
731 | && ((min(jcp.f_pad, jcp.back_pad) < 0) |
732 | || ((jcp.kd - 1) * (jcp.dilate_d + 1) |
733 | < nstl::max( |
734 | jcp.f_pad, jcp.back_pad))))) { |
735 | cmp(reg_ki, 0); |
736 | je(skip_kd_loop, T_NEAR); |
737 | } |
738 | |
739 | L(kd_loop_label); |
740 | mov(aux_reg_src, aux_reg_src_d); |
741 | mov(aux_reg_filt, aux_reg_filt_d); |
742 | } else { |
743 | mov(aux_reg_src, reg_src); |
744 | mov(aux_reg_filt, reg_filt); |
745 | } |
746 | |
747 | if (signed_input_or_src_zp && jcp.ndims > 3) { |
748 | /* Weights are transposed, so first compute 'bottom' padding. */ |
749 | mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); |
750 | cmp(reg_overflow, 0); |
751 | je(no_b_overflow_label, T_NEAR); |
752 | L(b_overflow_label); |
753 | { |
754 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
755 | |
756 | add(aux_reg_filt, shift_filt_kh); |
757 | dec(reg_overflow); |
758 | cmp(reg_overflow, 0); |
759 | jg(b_overflow_label, T_NEAR); |
760 | } |
761 | L(no_b_overflow_label); |
762 | } |
763 | |
764 | mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]); |
765 | |
766 | if ((signed_input_or_src_zp) || (jcp.dilate_h >= jcp.ih) |
767 | || ((!signed_input_or_src_zp) |
768 | && ((min(jcp.t_pad, jcp.b_pad) < 0) |
769 | || ((jcp.kh - 1) * (jcp.dilate_h + 1) |
770 | < nstl::max(jcp.t_pad, jcp.b_pad))))) { |
771 | cmp(reg_kh, 0); |
772 | je(skip_kh_loop, T_NEAR); |
773 | } |
774 | |
775 | L(kh_loop_label); |
776 | { |
777 | compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); |
778 | sub(aux_reg_src, shift_src_ih); |
779 | add(aux_reg_filt, shift_filt_kh); |
780 | dec(reg_kh); |
781 | |
782 | /* Insert weight compensation in stride 'holes' */ |
783 | if (signed_input_or_src_zp && jcp.stride_h > 1) { |
784 | Label kh_comp_loop; |
785 | |
786 | cmp(reg_kh, 0); |
787 | je(skip_kh_loop, T_NEAR); |
788 | mov(reg_comp_strides, jcp.stride_h - 1); |
789 | L(kh_comp_loop); |
790 | { |
791 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
792 | add(aux_reg_filt, shift_filt_kh); |
793 | dec(reg_comp_strides); |
794 | cmp(reg_comp_strides, 0); |
795 | jg(kh_comp_loop, T_NEAR); |
796 | } |
797 | } |
798 | cmp(reg_kh, 0); |
799 | jg(kh_loop_label, T_NEAR); |
800 | } |
801 | L(skip_kh_loop); |
802 | if (signed_input_or_src_zp && jcp.ndims > 3) { |
803 | mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); |
804 | cmp(reg_overflow, 0); |
805 | je(no_t_overflow_label, T_NEAR); |
806 | L(t_overflow_label); |
807 | { |
808 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
809 | |
810 | add(aux_reg_filt, shift_filt_kh); |
811 | dec(reg_overflow); |
812 | cmp(reg_overflow, 0); |
813 | jg(t_overflow_label, T_NEAR); |
814 | } |
815 | L(no_t_overflow_label); |
816 | } |
817 | |
818 | if (jcp.ndims == 5) { |
819 | sub(aux_reg_src_d, shift_src_id); |
820 | add(aux_reg_filt_d, shift_filt_kd); |
821 | dec(reg_ki); |
822 | |
823 | /* Insert weight compensation in stride 'holes' */ |
824 | if (signed_input_or_src_zp && jcp.stride_d > 1) { |
825 | Label kd_comp_loop, kd_kh_comp_loop; |
826 | cmp(reg_ki, 0); |
827 | jz(skip_kd_loop, T_NEAR); |
828 | mov(reg_comp_strides, jcp.stride_d - 1); |
829 | L(kd_comp_loop); |
830 | mov(aux_reg_filt, aux_reg_filt_d); |
831 | mov(reg_kh, jcp.kh); |
832 | L(kd_kh_comp_loop); |
833 | { |
834 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
835 | add(aux_reg_filt, shift_filt_kh); |
836 | dec(reg_kh); |
837 | jnz(kd_kh_comp_loop, T_NEAR); |
838 | } |
839 | add(aux_reg_filt_d, shift_filt_kd); |
840 | dec(reg_comp_strides); |
841 | jnz(kd_comp_loop); |
842 | } |
843 | |
844 | cmp(reg_ki, 0); |
845 | jg(kd_loop_label, T_NEAR); |
846 | L(skip_kd_loop); |
847 | if (signed_input_or_src_zp) { |
848 | mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]); |
849 | cmp(reg_ki, 0); |
850 | jz(no_front_overflow_label, T_NEAR); |
851 | L(front_overflow_label); |
852 | { |
853 | mov(aux_reg_filt, aux_reg_filt_d); |
854 | mov(reg_kh, jcp.kh); |
855 | L(d_h_overflow_label2); |
856 | { |
857 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
858 | add(aux_reg_filt, shift_filt_kh); |
859 | dec(reg_kh); |
860 | jnz(d_h_overflow_label2); |
861 | } |
862 | add(aux_reg_filt_d, shift_filt_kd); |
863 | dec(reg_ki); |
864 | jnz(front_overflow_label); |
865 | } |
866 | L(no_front_overflow_label); |
867 | } |
868 | } |
869 | } |

template <typename Vmm>
871 | int jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_tail_size() const |
872 | noexcept { |
873 | return jcp.is_depthwise ? jcp.ngroups % jcp.ch_block |
874 | : jcp.oc_without_padding % jcp.oc_block; |
875 | } |
876 | |
877 | template <typename Vmm> |
878 | int jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_blocking_size() const |
879 | noexcept { |
880 | return jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; |
881 | } |
882 | |
883 | template <typename Vmm> |
884 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::prepare_output(int ur_w) { |
885 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
886 | for (int ur = 0; ur < ur_w; ur++) { |
887 | const Vmm vmm = vmm_out(ur, ocb); |
888 | vpxord(vmm, vmm, vmm); |
889 | } |
890 | } |
891 | if (jcp.signed_input) { |
892 | xor_(reg_scratch, reg_scratch); |
893 | Reg8 _t8 = reg_scratch.cvt8(); |
894 | mov(_t8, (int8_t)-128); |
895 | vpbroadcastb(vmm_shift, _t8); |
896 | } |
897 | } |
898 | |
899 | template <typename Vmm> |
900 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::cvt2ps( |
901 | data_type_t type_in, Vmm vmm_in, const Operand &op, bool mask_flag) { |
902 | const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in; |
903 | switch (type_in) { |
904 | case data_type::f32: |
905 | case data_type::s32: vmovups(vmm, op); break; |
906 | case data_type::s8: vpmovsxbd(vmm, op); break; |
907 | case data_type::u8: vpmovzxbd(vmm, op); break; |
        default: assert(!"unsupported data type");
909 | } |
910 | if (type_in != data_type::f32) vcvtdq2ps(vmm_in, vmm_in); |
911 | } |
912 | |
913 | template <typename Vmm> |
914 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::store_output( |
915 | int ur_w, bool last_oc_block) { |
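    // Epilogue: add src zero-point compensation, convert accumulators to f32,
    // apply the s8s8 shift compensation, scales and bias, run post-ops, then
    // dst scale, dst zero point, saturation and the final store.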
916 | mov(reg_bias, ptr[param1 + GET_OFF(bias)]); |
917 | mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]); |
918 | |
919 | if (jcp.signed_input) |
920 | mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]); |
921 | |
922 | if (jcp.src_zero_point) { |
923 | mov(reg_zp_src_, ptr[param1 + GET_OFF(src_zero_point)]); |
924 | mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]); |
925 | } |
926 | |
927 | if (jcp.src_zero_point) { |
928 | const auto &vmm_src_zp = vmm_tmp; |
929 | const auto &vmm_zp_comp = vmm_wei; |
930 | uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]); |
931 | |
932 | const bool is_tail = get_tail_size() > 0; |
933 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
934 | const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block; |
935 | const bool is_last_ocb |
936 | = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
937 | const auto vmm = is_last_ocb && is_tail |
938 | ? vmm_zp_comp | ktail_mask | T_z |
939 | : vmm_zp_comp; |
940 | vmovups(vmm, ptr[reg_zp_compensation + zp_offset]); |
941 | |
942 | uni_vpmulld(vmm_zp_comp, vmm, vmm_src_zp); |
943 | |
944 | for (int ur = 0; ur < ur_w; ur++) { |
945 | const auto vmm_dst = vmm_out(ur, ocb); |
946 | uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp); |
947 | } |
948 | } |
949 | } |
950 | |
951 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
952 | const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
953 | int scale_offset |
954 | = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block); |
955 | |
956 | const Vmm vmm_bias = vmm_tmp; |
957 | if (jcp.with_bias) { |
958 | int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block; |
959 | auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset); |
960 | cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag); |
961 | } |
962 | if (jcp.signed_input) { |
963 | int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block; |
964 | auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset); |
965 | cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag); |
966 | } |
967 | |
968 | for (int ur = 0; ur < ur_w; ur++) { |
969 | const Vmm vmm = vmm_out(ur, ocb); |
970 | vcvtdq2ps(vmm, vmm); |
971 | if (jcp.signed_input) vaddps(vmm, vmm, vmm_comp); |
972 | const Vmm mask_vmm = mask_flag ? vmm | ktail_mask | T_z : vmm; |
973 | vmulps(mask_vmm, vmm, |
974 | EVEX_compress_addr(reg_ptr_scales, scale_offset)); |
975 | if (jcp.with_bias) vaddps(vmm, vmm, vmm_bias); |
976 | } |
977 | } |
978 | /* Do post-ops */ |
979 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
980 | const auto &p = attr_.post_ops_; |
981 | const int sum_idx = p.find(primitive_kind::sum); |
982 | const float *p_sum_scale |
983 | = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; |
984 | if (p_sum_scale && *p_sum_scale != 1.f) |
985 | mov(reg_ptr_sum_scale, (size_t)p_sum_scale); |
986 | |
987 | const auto sum_injector = [&]() { |
988 | if (p_sum_scale) { // post_op: sum |
989 | for (int k = 0; k < jcp.nb_oc_blocking; k++) { |
990 | const bool mask_flag |
991 | = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1; |
992 | for (int j = 0; j < ur_w; j++) { |
993 | int aux_output_offset = jcp.typesize_out |
994 | * (k * jcp.oc_block |
995 | + j * jcp.oc_without_padding |
996 | * jcp.ngroups); |
997 | auto addr = EVEX_compress_addr( |
998 | reg_dst, aux_output_offset); |
999 | const Vmm vmm = vmm_out(j, k); |
1000 | cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag); |
1001 | if (*p_sum_scale == 1.f) |
1002 | vaddps(vmm, vmm_prev_dst); |
1003 | else |
1004 | vfmadd231ps(vmm, vmm_prev_dst, |
1005 | zword_b[reg_ptr_sum_scale]); |
1006 | } |
1007 | } |
1008 | } |
1009 | }; |
1010 | |
1011 | if (p_sum_scale) |
1012 | postops_injector_->set_lambda_injector( |
1013 | primitive_kind::sum, sum_injector); |
1014 | |
1015 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
1016 | if (jcp.with_binary) { |
1017 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1018 | const bool mask_flag |
1019 | = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
1020 | for (int ur = 0; ur < ur_w; ur++) { |
1021 | const int vmm_idx = vmm_out(ur, ocb).getIdx(); |
1022 | const size_t aux_output_offset = jcp.typesize_out |
1023 | * (ocb * jcp.oc_block |
1024 | + ur * jcp.oc_without_padding |
1025 | * jcp.ngroups); |
1026 | |
1027 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); |
1028 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
1029 | vmm_idx, aux_output_offset); |
1030 | if (mask_flag) |
1031 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
1032 | } |
1033 | } |
1034 | } |
1035 | const int nb_oc_block |
1036 | = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; |
1037 | postops_injector_->compute_vector_range( |
1038 | 0, nb_oc_block * ur_w, rhs_arg_params); |
1039 | } |
1040 | |
1041 | if (jcp.dst_scale) { |
1042 | mov(reg_ptr_dst_scales, ptr[param1 + GET_OFF(dst_scale)]); |
1043 | uni_vmovups(vmm_dst_scale, ptr[reg_ptr_dst_scales]); |
1044 | |
1045 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1046 | const bool mask_flag |
1047 | = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
1048 | for (int ur = 0; ur < ur_w; ur++) { |
1049 | const auto vmm = vmm_out(ur, ocb); |
1050 | const Vmm mask_vmm = mask_flag ? vmm | ktail_mask | T_z : vmm; |
1051 | uni_vmulps(mask_vmm, vmm, vmm_dst_scale); |
1052 | } |
1053 | } |
1054 | } |
1055 | if (jcp.dst_zero_point) { |
1056 | mov(reg_zp_dst_, ptr[param1 + GET_OFF(dst_zero_point)]); |
1057 | const auto &vmm_zp_dst = vmm_tmp; |
1058 | uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]); |
1059 | vcvtdq2ps(vmm_zp_dst, vmm_zp_dst); |
1060 | |
1061 | for_(int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) |
1062 | for (int ur = 0; ur < ur_w; ur++) { |
1063 | const auto vmm_dst = vmm_out(ur, ocb); |
1064 | uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst); |
1065 | } |
1066 | } |
1067 | |
    // Properly saturate the accumulators for integer datatypes.
    // No need to saturate on the lower bound for signed integer types, as
    // the conversion to int would return INT_MIN, and then proper
    // saturation will happen when storing data.
1073 | if (jcp.dst_dt == data_type::u8) { |
1074 | vpxord(vmm_zero, vmm_zero, vmm_zero); |
1075 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1076 | for (int ur = 0; ur < ur_w; ur++) { |
1077 | const Vmm vmm = vmm_out(ur, ocb); |
1078 | vmaxps(vmm, vmm_zero, vmm); |
1079 | } |
1080 | } |
1081 | } |
1082 | |
1083 | if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) { |
1084 | float saturation_ubound = types::max_value<float>(jcp.dst_dt); |
1085 | Xmm xmm_saturation(vmm_saturation.getIdx()); |
1086 | mov(reg_ptr_saturation_ubound, float2int(saturation_ubound)); |
1087 | vmovq(xmm_saturation, reg_ptr_saturation_ubound); |
1088 | vbroadcastss(vmm_saturation, xmm_saturation); |
1089 | |
1090 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1091 | for (int ur = 0; ur < ur_w; ur++) { |
1092 | const Vmm vmm = vmm_out(ur, ocb); |
1093 | vminps(vmm, vmm, vmm_saturation); |
1094 | } |
1095 | } |
1096 | } |
1097 | |
1098 | if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) { |
1099 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1100 | for (int ur = 0; ur < ur_w; ur++) { |
1101 | const Vmm vmm = vmm_out(ur, ocb); |
1102 | vcvtps2dq(vmm, vmm); |
1103 | } |
1104 | } |
1105 | } |
1106 | |
1107 | /* write out register to output_addr */ |
1108 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) { |
1109 | const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1; |
1110 | for (int ur = 0; ur < ur_w; ur++) { |
1111 | int aux_dst_off = jcp.typesize_out |
1112 | * (ur * jcp.ngroups * jcp.oc_without_padding |
1113 | + ocb * jcp.oc_block); |
1114 | auto addr = EVEX_compress_addr(reg_dst, aux_dst_off); |
1115 | |
1116 | const Vmm vmm = vmm_out(ur, ocb); |
1117 | const Vmm r_vmm = mask_flag ? vmm | ktail_mask : vmm; |
1118 | switch (jcp.dst_dt) { |
1119 | case data_type::f32: |
1120 | case data_type::s32: vmovups(addr, r_vmm); break; |
1121 | case data_type::s8: vpmovsdb(addr, r_vmm); break; |
1122 | case data_type::u8: vpmovusdb(addr, r_vmm); break; |
            default: assert(!"unknown dst_dt");
1124 | } |
1125 | } |
1126 | } |
1127 | } |
1128 | |
1129 | template <typename Vmm> |
1130 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::icb_loop( |
1131 | int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { |
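    // Accumulate over all input-channel blocks, then store once; the last icb
    // needs tail handling when ic_without_padding is not a multiple of
    // ic_block.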
1132 | |
1133 | int shift_src_icb = jcp.typesize_in * jcp.ic_block; |
1134 | const size_t shift_filt_icb = (size_t)jcp.typesize_in * jcp.kd * jcp.kh |
1135 | * jcp.kw * jcp.ic_block * jcp.oc_block; |
1136 | |
1137 | prepare_output(ur_w); |
1138 | |
1139 | Label skip_icb_loop, icb_loop_label; |
1140 | |
1141 | mov(reg_icb, jcp.nb_ic); |
1142 | |
1143 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) { |
1144 | mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); |
1145 | if (jcp.ndims > 3) { |
1146 | mov(reg_scratch, |
1147 | qword[param1 + GET_OFF(zp_src_pad_str_compensation)]); |
1148 | mov(zp_src_pad_comp_addr, reg_scratch); |
1149 | } |
1150 | } |
1151 | |
1152 | L(icb_loop_label); |
1153 | { |
1154 | |
1155 | if (jcp.ic_without_padding != jcp.ic) { |
1156 | Label common_ker, end_ker; |
1157 | cmp(reg_icb, 1); |
1158 | jg(common_ker, T_NEAR); |
1159 | |
1160 | kh_loop(ur_w, l_overflow, r_overflow, |
1161 | is_last_sp_block ? last_sp_block : last_ic_block); |
1162 | jmp(end_ker, T_NEAR); |
1163 | |
1164 | L(common_ker); |
1165 | kh_loop(ur_w, l_overflow, r_overflow, no_last_block); |
1166 | |
1167 | L(end_ker); |
1168 | } else { |
1169 | kh_loop(ur_w, l_overflow, r_overflow, no_last_block); |
1170 | } |
1171 | |
1172 | add(reg_src, shift_src_icb); |
1173 | safe_add(reg_filt, shift_filt_icb, reg_ker_long_offt); |
1174 | dec(reg_icb); |
1175 | cmp(reg_icb, 0); |
1176 | jg(icb_loop_label, T_NEAR); |
1177 | } |
1178 | |
1179 | /* come-back pointers */ |
1180 | sub(reg_src, jcp.nb_ic * shift_src_icb); |
1181 | safe_sub(reg_filt, jcp.nb_ic * shift_filt_icb, reg_ker_long_offt); |
1182 | L(skip_icb_loop); |
1183 | |
1184 | if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { |
1185 | Label common_store, end_store; |
1186 | mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); |
1187 | if (jcp.is_depthwise) |
1188 | cmp(reg_oc_blocks, jcp.nb_ch - 1); |
1189 | else |
1190 | cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); |
1191 | jne(common_store, T_NEAR); |
1192 | |
1193 | store_output(ur_w, true); |
1194 | jmp(end_store, T_NEAR); |
1195 | |
1196 | L(common_store); |
1197 | store_output(ur_w, false); |
1198 | |
1199 | L(end_store); |
1200 | |
1201 | } else { |
1202 | store_output(ur_w, false); |
1203 | } |
1204 | } |
1205 | |
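// Partition the ow/ur_w blocks into num_pre_blks with left overflow, a middle
// run with no overflow (emitted as a single runtime loop), and num_post_blks
// with right overflow and/or src tail pixels that must be handled carefully
// because fewer than the 4 src bytes needed for a broadcast load remain.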
1206 | template <typename Vmm> |
1207 | ur_w_blks_params_t |
1208 | jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::get_ur_w_blks_params() { |
1209 | const int n_ur_blocks = jcp.ow / jcp.ur_w; |
1210 | |
1211 | ur_w_blks_params_t ur_w_blks_params; |
1212 | int num_blks_to_process_sp_carefully = 0; |
1213 | int idx_last_non_zero_l_overflow_blk = -1; |
1214 | int idx_first_non_zero_r_overflow_blk = n_ur_blocks; |
1215 | |
1216 | static constexpr int src_pixels_loaded_for_bcast = 4; |
1217 | const auto ic_mod = jcp.ic_without_padding % src_pixels_loaded_for_bcast; |
1218 | for (int blk_idx = 0; blk_idx < n_ur_blocks; blk_idx++) { |
1219 | const int first_blk_dst_elem = blk_idx * jcp.ur_w; |
1220 | const int last_dst_blk_elem = first_blk_dst_elem + jcp.ur_w - 1; |
1221 | |
1222 | const int last_blk_src_idx = nstl::min( |
1223 | jcp.iw - 1, (last_dst_blk_elem + jcp.l_pad) / jcp.stride_w); |
1224 | const bool is_out_of_src_pixels_scope |
1225 | = ((jcp.iw - 1 - last_blk_src_idx) * jcp.ic_without_padding |
1226 | + ic_mod |
1227 | < src_pixels_loaded_for_bcast); |
1228 | |
1229 | const bool process_sp_carefully |
1230 | = (ic_mod != 0) && is_out_of_src_pixels_scope; |
1231 | const int curr_l_overflow = nstl::max(0, |
1232 | ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad |
1233 | - first_blk_dst_elem) |
1234 | / jcp.stride_w); |
1235 | const int curr_r_overflow = nstl::max(0, |
1236 | (last_dst_blk_elem + jcp.l_pad) / jcp.stride_w - (jcp.iw - 1)); |
1237 | |
1238 | ur_w_blks_params.blks_params.emplace_back( |
1239 | curr_l_overflow, curr_r_overflow, process_sp_carefully); |
1240 | |
1241 | num_blks_to_process_sp_carefully |
1242 | += static_cast<int>(process_sp_carefully); |
1243 | if (curr_l_overflow > 0) idx_last_non_zero_l_overflow_blk = blk_idx; |
1244 | if (curr_r_overflow > 0 && idx_first_non_zero_r_overflow_blk > blk_idx) |
1245 | idx_first_non_zero_r_overflow_blk = blk_idx; |
1246 | } |
1247 | idx_first_non_zero_r_overflow_blk |
1248 | = nstl::max(idx_first_non_zero_r_overflow_blk, |
1249 | idx_last_non_zero_l_overflow_blk + 1); |
1250 | // limit num_r_overflow_blks and num_blks_to_process_last_sp_carefully so that: |
1251 | // n_ur_blocks >= num_l_overflow_blks + max(num_r_overflow_blks, num_blks_to_process_last_sp_carefully) |
1252 | ur_w_blks_params.num_pre_blks |
1253 | = nstl::max(0, idx_last_non_zero_l_overflow_blk + 1); |
1254 | const int num_r_overflow_blks = idx_first_non_zero_r_overflow_blk |
1255 | <= idx_last_non_zero_l_overflow_blk |
1256 | ? n_ur_blocks - ur_w_blks_params.num_pre_blks |
1257 | : n_ur_blocks - idx_first_non_zero_r_overflow_blk; |
1258 | num_blks_to_process_sp_carefully |
1259 | = ur_w_blks_params.num_pre_blks + num_blks_to_process_sp_carefully |
1260 | < n_ur_blocks |
1261 | ? num_blks_to_process_sp_carefully |
1262 | : n_ur_blocks - ur_w_blks_params.num_pre_blks; |
1263 | ur_w_blks_params.num_post_blks |
1264 | = nstl::max(num_r_overflow_blks, num_blks_to_process_sp_carefully); |
1265 | |
1266 | return ur_w_blks_params; |
1267 | } |
1268 | |
1269 | template <typename Vmm> |
1270 | void jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Vmm>::generate() { |
1271 | preamble(); |
1272 | |
1273 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1274 | sub(rsp, reserved_stack_size_); |
1275 | |
1276 | xor_(reg_scratch, reg_scratch); |
1277 | Reg16 _t = reg_scratch.cvt16(); |
1278 | mov(_t, 0x1); |
1279 | vpbroadcastw(vmm_one, _t); |
1280 | |
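    // Build ktail_mask for the oc (or group-channel) tail. For depthwise the
    // mask defaults to all-ones and is narrowed only when processing the last
    // channel block.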
1281 | if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { |
1282 | int tail_size = jcp.is_depthwise |
1283 | ? jcp.ngroups % jcp.ch_block |
1284 | : jcp.oc_without_padding % jcp.oc_block; |
1285 | int mask = (1 << tail_size) - 1; |
1286 | Reg32 regw_tmp = reg_nur_w.cvt32(); |
1287 | Label skip_tail_mask; |
1288 | if (jcp.is_depthwise) { |
1289 | kxnorw(ktail_mask, ktail_mask, ktail_mask); |
1290 | cmp(dword[param1 + GET_OFF(oc_blocks)], jcp.nb_ch - 1); |
1291 | jne(skip_tail_mask, T_NEAR); |
1292 | } |
1293 | mov(regw_tmp, mask); |
1294 | kmovw(ktail_mask, regw_tmp); |
1295 | L(skip_tail_mask); |
1296 | } |
1297 | |
1298 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1299 | mov(reg_filt, ptr[param1 + GET_OFF(filt)]); |
1300 | mov(reg_dst, ptr[param1 + GET_OFF(dst)]); |
1301 | |
1302 | int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups |
1303 | * jcp.oc_without_padding; |
1304 | int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups |
1305 | * jcp.ic_without_padding; |
1306 | |
1307 | const auto ur_w_blks_params = get_ur_w_blks_params(); |
1308 | const int nur_w = jcp.ow / jcp.ur_w - ur_w_blks_params.num_pre_blks |
1309 | - ur_w_blks_params.num_post_blks; |
1310 | |
1311 | const auto &blks_params = ur_w_blks_params.blks_params; |
1312 | const auto num_pre_blks = ur_w_blks_params.num_pre_blks; |
1313 | const auto num_post_blks = ur_w_blks_params.num_post_blks; |
1314 | |
1315 | for (int i = 0; i < num_pre_blks; i++) { |
1316 | const bool blk_process_carefully = blks_params[i].process_sp_carefully; |
1317 | const int blk_l_overflow = blks_params[i].l_overflow; |
1318 | const int blk_r_overflow = blks_params[i].r_overflow; |
1319 | |
1320 | icb_loop(jcp.ur_w, blk_l_overflow, blk_r_overflow, |
1321 | blk_process_carefully); |
1322 | add(reg_src, src_shift); |
1323 | add(reg_dst, dst_shift); |
1324 | } |
1325 | |
1326 | if (nur_w > 0) { |
1327 | xor_(reg_nur_w, reg_nur_w); |
1328 | Label ow_loop_label; |
1329 | L(ow_loop_label); |
1330 | { |
1331 | icb_loop(jcp.ur_w, 0, 0, false); |
1332 | add(reg_src, src_shift); |
1333 | add(reg_dst, dst_shift); |
1334 | inc(reg_nur_w); |
1335 | cmp(reg_nur_w, nur_w); |
1336 | jl(ow_loop_label, T_NEAR); |
1337 | } |
1338 | } |
1339 | |
1340 | if (num_post_blks > 0) { |
1341 | const auto blks_params_size = blks_params.size(); |
1342 | const auto start_blk_idx = blks_params_size - num_post_blks; |
1343 | for (size_t i = start_blk_idx; i < blks_params_size; i++) { |
1344 | const bool blk_process_carefully |
1345 | = blks_params[i].process_sp_carefully; |
1346 | const int blk_l_overflow = blks_params[i].l_overflow; |
1347 | const int blk_r_overflow = blks_params[i].r_overflow; |
1348 | |
1349 | icb_loop(jcp.ur_w, blk_l_overflow, blk_r_overflow, |
1350 | blk_process_carefully); |
1351 | add(reg_src, src_shift); |
1352 | add(reg_dst, dst_shift); |
1353 | } |
1354 | } |
1355 | |
1356 | if (jcp.ur_w_tail != 0) { |
1357 | // l_overflow - no. of spatial elements of weights standing out of src spatial |
1358 | // when computing the left-most (in w dim) output pixel |
1359 | int l_overflow = 0; |
1360 | if (jcp.ur_w == jcp.ow) |
1361 | l_overflow = max(0, |
1362 | ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) |
1363 | / jcp.stride_w); |
        // r_overflow - no. of spatial elements of weights standing out of src spatial
1365 | // when computing the right-most (in w dim) output pixel |
1366 | const int r_overflow = max(0, |
1367 | ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)) |
1368 | / jcp.stride_w); |
1369 | |
1370 | icb_loop(jcp.ur_w_tail, l_overflow, r_overflow, true); |
1371 | } |
1372 | |
1373 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1374 | add(rsp, reserved_stack_size_); |
1375 | |
1376 | postamble(); |
1377 | |
1378 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
1379 | } |
1380 | |
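// Fold src and weights scales into a single per-oc (or common) scale vector.
// On pre-VNNI s8s8, weights were pre-scaled by wei_adj_scale (0.5) at reorder
// time, so divide it back out here.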
1381 | const float *jit_avx512_core_x8s8s32x_deconvolution_fwd_t::adjust_oscales( |
1382 | const memory_tracking::grantor_t &scratchpad, const float *src_scales, |
1383 | const float *wei_scales) const { |
1384 | auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales); |
1385 | int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
1386 | float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) |
1387 | ? 1.f / pd()->jcp_.wei_adj_scale |
1388 | : 1.0f; |
1389 | if (wei_mask == 0) { |
1390 | utils::array_set( |
1391 | loc_scales, src_scales[0] * wei_scales[0] * factor, 16); |
1392 | } else { |
1393 | for (dim_t c = 0; c < pd()->OC(); c++) |
1394 | loc_scales[c] = src_scales[0] * wei_scales[c] * factor; |
1395 | } |
1396 | return loc_scales; |
1397 | } |
1398 | |
1399 | status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_1d( |
1400 | const exec_ctx_t &ctx) const { |
1401 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1402 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1403 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1404 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1405 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1406 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1407 | |
1408 | const memory_desc_wrapper src_d(pd()->src_md()); |
1409 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1410 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1411 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1412 | |
1413 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1414 | const auto &jcp = pd()->jcp_; |
1415 | |
1416 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1417 | int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1418 | |
1419 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1420 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1421 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1422 | zp_src_pad_comp_kernel_.get()); |
1423 | |
1424 | const auto post_ops_binary_rhs_arg_vec |
1425 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1426 | |
1427 | const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
1428 | const int nb_groups = jcp.nb_ch; |
1429 | |
1430 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1431 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1432 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1433 | |
1434 | const float *oscales = adjust_oscales( |
1435 | ctx.get_scratchpad_grantor(), src_scales, wei_scales); |
1436 | |
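    // The s8s8 compensation (and src zero-point compensation) buffers are
    // appended after the weights data; 'offset' locates that extra space.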
1437 | const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); |
1438 | auto w = const_cast<int8_t *>(weights); |
1439 | int32_t *compensation = (jcp.signed_input) |
1440 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1441 | : nullptr; |
1442 | const int32_t *zp_compensation = jcp.src_zero_point |
1443 | ? get_src_zp_comp_from_wei( |
1444 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1445 | : nullptr; |
1446 | |
1447 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1448 | int start {0}, end {0}; |
1449 | int work_amount = jcp.mb * nb_groups * oc_chunks; |
1450 | balance211(work_amount, nthr, ithr, start, end); |
1451 | |
1452 | auto p = jit_deconv_call_s(); |
1453 | |
1454 | int n {0}, g {0}, occ {0}; |
1455 | if (jcp.loop_order == loop_ngc) |
1456 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); |
1457 | else if (jcp.loop_order == loop_cgn) |
1458 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); |
1459 | else |
            assert(!"unsupported loop order");
1461 | while (start < end) { |
1462 | |
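            // Translate the flattened indices into absolute output- and
            // input-channel offsets for this group and oc block.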
1463 | int ocb = occ * jcp.nb_oc_blocking; |
1464 | int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1465 | int g_ic = g * jcp.ch_block * jcp.ic; |
1466 | |
1467 | p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc); |
1468 | p.src = src + src_d.blk_off(n, g_ic); |
1469 | p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); |
1470 | p.bias = jcp.with_bias |
1471 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1472 | : nullptr; |
1473 | p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1474 | p.scales = &oscales[jcp.is_oc_scale * g_oc]; |
1475 | p.dst_scale = dst_scales; |
1476 | p.t_overflow = 0; |
1477 | p.b_overflow = 0; |
1478 | p.kh_padding = jcp.kh; |
1479 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1480 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); |
1481 | p.oc_l_off = g_oc; |
1482 | p.zp_compensation |
1483 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1484 | p.zp_src_pad_str_compensation |
1485 | = jcp.src_zero_point ? zp_src_comp_scratch + g_oc : nullptr; |
1486 | p.src_zero_point = zp_src; |
1487 | p.dst_zero_point = zp_dst; |
1488 | p.dst_orig = dst; |
1489 | (*kernel_)(&p); |
1490 | |
1491 | ++start; |
1492 | if (jcp.loop_order == loop_ngc) |
1493 | nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); |
1494 | else if (jcp.loop_order == loop_cgn) |
1495 | nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); |
1496 | else |
                assert(!"unsupported loop order");
1498 | } |
1499 | }); |
1500 | return status::success; |
1501 | } |
1502 | |
1503 | status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_2d( |
1504 | const exec_ctx_t &ctx) const { |
1505 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1506 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1507 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1508 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1509 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1510 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1511 | |
1512 | const auto &jcp = pd()->jcp_; |
1513 | |
1514 | const memory_desc_wrapper src_d(pd()->src_md()); |
1515 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1516 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1517 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1518 | const auto post_ops_binary_rhs_arg_vec |
1519 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1520 | |
1521 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1522 | |
1523 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1524 | int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1525 | |
1526 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1527 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1528 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1529 | zp_src_pad_comp_kernel_.get()); |
1530 | |
    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;
1533 | |
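    // Element strides for stepping one spatial row in src/dst and one
    // filter row in weights; dst offsets are scaled by dst_dt_size at the
    // point of use.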
1534 | size_t src_h_stride = src_d.blk_off(0, 0, 1); |
1535 | size_t dst_h_stride = dst_d.blk_off(0, 0, 1); |
1536 | size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); |
1537 | |
1538 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1539 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1540 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1541 | |
    const float *oscales = adjust_oscales(scratchpad, src_scales, wei_scales);
1544 | |
1545 | const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); |
1546 | auto w = const_cast<int8_t *>(weights); |
1547 | int32_t *compensation = (jcp.signed_input) |
1548 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1549 | : nullptr; |
1550 | const int32_t *zp_compensation = jcp.src_zero_point |
1551 | ? get_src_zp_comp_from_wei( |
1552 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1553 | : nullptr; |
1554 | |
1555 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1556 | int start {0}, end {0}; |
1557 | int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; |
1558 | balance211(work_amount, nthr, ithr, start, end); |
1559 | |
1560 | auto p = jit_deconv_call_s(); |
1561 | |
        /* loop order: cgn or ngc */
1563 | int n {0}, g {0}, occ {0}, oh_s {0}; |
1564 | if (jcp.loop_order == loop_ngc) |
1565 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, |
1566 | oh_s, jcp.oh); |
1567 | else if (jcp.loop_order == loop_cgn) |
1568 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, |
1569 | oh_s, jcp.oh); |
1570 | else |
            assert(!"unsupported loop order");
1572 | while (start < end) { |
1573 | |
1574 | int ocb = occ * jcp.nb_oc_blocking; |
1575 | int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1576 | int g_ic = g * jcp.ch_block * jcp.ic; |
1577 | int work_rem = end - start; |
            int oh_e = min(oh_s + work_rem, jcp.oh);
1579 | |
1580 | auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc); |
1581 | auto src_w = src + src_d.blk_off(n, g_ic); |
1582 | auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); |
1583 | auto bias_w = jcp.with_bias |
1584 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1585 | : nullptr; |
1586 | int32_t *compensation_w |
1587 | = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1588 | |
1589 | auto scales = &oscales[jcp.is_oc_scale * g_oc]; |
1590 | for (int oj = oh_s; oj < oh_e; oj++) { |
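                // Determine which filter rows contribute to output row oj:
                // kh_lo is the first contributing row, kh_len the count,
                // and ih_max the corresponding top-most input row.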
1591 | int ih_max = 0, kh_lo = 0, kh_len = 0; |
1592 | if (jcp.dilate_h != 0 && jcp.stride_h == 1) { |
1593 | /* dilation */ |
1594 | int dilate_h = jcp.dilate_h + 1; |
1595 | // Note: use div_up to account for "holes" in filter |
1596 | int o_t_overflow = div_up( |
1597 | max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), |
1598 | dilate_h); |
1599 | int o_b_overflow |
1600 | = div_up(max(0, |
1601 | (jcp.kh - 1) * dilate_h + 1 |
1602 | - jcp.oh + oj - jcp.b_pad), |
1603 | dilate_h); |
1604 | kh_len = jcp.kh - o_t_overflow - o_b_overflow; |
1605 | kh_lo = o_b_overflow; |
1606 | ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; |
1607 | } else { |
1608 | int o_t_overflow = max( |
1609 | 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); |
1610 | int o_b_overflow = max(0, |
1611 | ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) |
1612 | / jcp.stride_h); |
1613 | int overflow_kh_hi = jcp.kh - 1 |
1614 | - modulo(jcp.oh + jcp.b_pad - (oj + 1), |
1615 | jcp.stride_h); |
1616 | int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; |
1617 | |
1618 | kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h |
1619 | + 1 - o_t_overflow - o_b_overflow; |
1620 | kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; |
1621 | ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; |
1622 | } |
1623 | |
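                // With signed input or source zero-points the kernel walks
                // the skipped filter rows itself to apply compensation, so
                // the filter pointer stays at row 0.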
1624 | int wei_stride = (!jcp.signed_input && !jcp.src_zero_point) |
1625 | ? kh_lo * wht_kh_stride |
1626 | : 0; |
1627 | p.src = src_w + ih_max * src_h_stride; |
1628 | p.dst = dst_w + dst_dt_size * oj * dst_h_stride; |
1629 | p.filt = wht_w + wei_stride; |
1630 | p.bias = bias_w; |
1631 | p.compensation = compensation_w; |
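                // t_overflow/b_overflow count the filter rows falling into
                // top/bottom padding for this output row.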
1632 | p.t_overflow = jcp.dilate_h > 0 |
1633 | ? jcp.kh - kh_len - kh_lo |
1634 | : max(0, |
1635 | jcp.kh |
1636 | - (kh_lo |
1637 | + max(0, kh_len - 1) |
1638 | * jcp.stride_h |
1639 | + 1)); |
1640 | p.b_overflow = kh_lo; |
1641 | p.kh_padding = kh_len; |
1642 | p.scales = scales; |
1643 | p.dst_scale = dst_scales; |
1644 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1645 | p.post_ops_binary_rhs_arg_vec |
1646 | = post_ops_binary_rhs_arg_vec.data(); |
1647 | p.oc_l_off = g_oc; |
1648 | p.zp_compensation |
1649 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1650 | p.zp_src_pad_str_compensation = jcp.src_zero_point |
1651 | ? zp_src_comp_scratch + g_oc |
1652 | : nullptr; |
1653 | p.src_zero_point = zp_src; |
1654 | p.dst_zero_point = zp_dst; |
1655 | p.dst_orig = dst; |
1656 | |
1657 | (*kernel_)(&p); |
1658 | } |
1659 | if (jcp.loop_order == loop_ngc) |
1660 | nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, |
1661 | oc_chunks, oh_s, jcp.oh); |
1662 | else if (jcp.loop_order == loop_cgn) |
1663 | nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, |
1664 | jcp.mb, oh_s, jcp.oh); |
1665 | else |
                assert(!"unsupported loop order");
1667 | } |
1668 | }); |
1669 | return status::success; |
1670 | } |
1671 | |
1672 | status_t jit_avx512_core_x8s8s32x_deconvolution_fwd_t::execute_forward_3d( |
1673 | const exec_ctx_t &ctx) const { |
1674 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1675 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1676 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1677 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1678 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1679 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1680 | |
1681 | const auto &jcp = pd()->jcp_; |
1682 | |
1683 | const memory_desc_wrapper src_d(pd()->src_md()); |
1684 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1685 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1686 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1687 | const auto post_ops_binary_rhs_arg_vec |
1688 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1689 | |
1690 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1691 | |
1692 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1693 | int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1694 | |
1695 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1696 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1697 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1698 | zp_src_pad_comp_kernel_.get()); |
1699 | |
    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;
1702 | |
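    // The 3D case adds depth strides: blk_off(0, 0, 1) steps one element
    // along d, blk_off(0, 0, 0, 1) one element along h.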
1703 | size_t src_d_stride = src_d.blk_off(0, 0, 1); |
1704 | size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); |
1705 | size_t dst_d_stride = dst_d.blk_off(0, 0, 1); |
1706 | size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); |
1707 | size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1); |
1708 | size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); |
1709 | |
1710 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1711 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1712 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1713 | |
    const float *oscales = adjust_oscales(scratchpad, src_scales, wei_scales);
1716 | |
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1718 | auto w = const_cast<int8_t *>(weights); |
1719 | int32_t *compensation = (jcp.signed_input) |
1720 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1721 | : nullptr; |
1722 | const int32_t *zp_compensation = jcp.src_zero_point |
1723 | ? get_src_zp_comp_from_wei( |
1724 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1725 | : nullptr; |
1726 | |
1727 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1728 | int start {0}, end {0}; |
1729 | int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh; |
1730 | balance211(work_amount, nthr, ithr, start, end); |
1731 | |
1732 | auto p = jit_deconv_call_s(); |
1733 | |
        /* loop order: cgn or ngc */
1735 | int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}; |
1736 | if (jcp.loop_order == loop_ngc) |
1737 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, |
1738 | od_s, jcp.od, oh_s, jcp.oh); |
1739 | else if (jcp.loop_order == loop_cgn) |
1740 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, |
1741 | od_s, jcp.od, oh_s, jcp.oh); |
1742 | else |
            assert(!"unsupported loop order");
1744 | while (start < end) { |
1745 | |
1746 | int ocb = occ * jcp.nb_oc_blocking; |
1747 | int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1748 | int g_ic = g * jcp.ch_block * jcp.ic; |
1749 | int work_rem = end - start; |
            int oh_e = min(oh_s + work_rem, jcp.oh);
1751 | int input_d_s = 0, kd_len = 0, kd_lo = 0; |
1752 | int d_t_overflow, d_back_overflow; |
1753 | |
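            // Same contributing-range computation as for the height axis,
            // applied to depth: f_pad/back_pad take the role of t_pad/b_pad.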
1754 | if (jcp.dilate_d != 0 && jcp.stride_d == 1) { |
1755 | /* dilation */ |
1756 | int dilate_d = jcp.dilate_d + 1; |
1757 | // Note: use div_up to account for "holes" in filter |
1758 | d_t_overflow = div_up( |
1759 | max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad), |
1760 | dilate_d); |
1761 | d_back_overflow |
1762 | = div_up(max(0, |
1763 | (jcp.kd - 1) * dilate_d + 1 - jcp.od |
1764 | + od_s - jcp.back_pad), |
1765 | dilate_d); |
1766 | kd_len = jcp.kd - d_t_overflow - d_back_overflow; |
1767 | kd_lo = d_back_overflow; |
1768 | input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d; |
1769 | } else { |
                d_t_overflow = max(
                        0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d);
                d_back_overflow = max(0,
                        ((od_s + jcp.kd) - (jcp.od + jcp.back_pad))
                                / jcp.stride_d);
1775 | int overflow_kd_hi = jcp.kd - 1 |
1776 | - modulo(jcp.od + jcp.back_pad - (od_s + 1), |
1777 | jcp.stride_d); |
1778 | int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d; |
1779 | |
1780 | kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1 |
1781 | - d_t_overflow - d_back_overflow; |
1782 | kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d; |
1783 | input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d; |
1784 | } |
1785 | |
1786 | auto dst_w = dst |
1787 | + dst_dt_size |
1788 | * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride); |
1789 | auto src_w |
1790 | = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride; |
1791 | auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0) |
1792 | + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo) |
1793 | * wht_kd_stride; |
1794 | auto bias_w = jcp.with_bias |
1795 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1796 | : nullptr; |
1797 | int32_t *compensation_w |
1798 | = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1799 | |
1800 | auto scales = &oscales[jcp.is_oc_scale * g_oc]; |
1801 | |
1802 | for (int oj = oh_s; oj < oh_e; oj++) { |
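                    // Height handling mirrors execute_forward_2d above.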
1803 | int ih_max = 0, kh_lo = 0, kh_len = 0; |
1804 | if (jcp.dilate_h != 0 && jcp.stride_h == 1) { |
1805 | /* dilation */ |
1806 | int dilate_h = jcp.dilate_h + 1; |
1807 | // Note: use div_up to account for "holes" in filter |
1808 | int o_t_overflow = div_up( |
1809 | max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), |
1810 | dilate_h); |
1811 | int o_b_overflow |
1812 | = div_up(max(0, |
1813 | (jcp.kh - 1) * dilate_h + 1 |
1814 | - jcp.oh + oj - jcp.b_pad), |
1815 | dilate_h); |
1816 | kh_len = jcp.kh - o_t_overflow - o_b_overflow; |
1817 | kh_lo = o_b_overflow; |
1818 | ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; |
1819 | } else { |
1820 | int o_t_overflow = max( |
1821 | 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); |
1822 | int o_b_overflow = max(0, |
1823 | ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) |
1824 | / jcp.stride_h); |
1825 | int overflow_kh_hi = jcp.kh - 1 |
1826 | - modulo(jcp.oh + jcp.b_pad - (oj + 1), |
1827 | jcp.stride_h); |
1828 | int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; |
1829 | |
1830 | kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h |
1831 | + 1 - o_t_overflow - o_b_overflow; |
1832 | kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; |
1833 | ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; |
1834 | } |
1835 | |
1836 | int wei_stride = (!jcp.signed_input && !jcp.src_zero_point) |
1837 | ? kh_lo * wht_kh_stride |
1838 | : 0; |
1839 | p.src = src_w + ih_max * src_h_stride; |
1840 | p.dst = dst_w + dst_dt_size * oj * dst_h_stride; |
1841 | p.filt = wht_w + wei_stride; |
1842 | p.bias = bias_w; |
1843 | p.compensation = compensation_w; |
1844 | /* Note: Currently this kernel doesn't support dilations and |
1845 | strides together */ |
1846 | p.t_overflow = jcp.dilate_h > 0 |
1847 | ? jcp.kh - kh_len - kh_lo |
1848 | : max(0, |
1849 | jcp.kh |
1850 | - (kh_lo |
1851 | + max(0, kh_len - 1) |
1852 | * jcp.stride_h |
1853 | + 1)); |
1854 | p.b_overflow = kh_lo; |
1855 | p.f_overflow = jcp.dilate_d > 0 |
1856 | ? jcp.kd - kd_len - kd_lo |
1857 | : max(0, |
1858 | jcp.kd |
1859 | - (kd_lo |
1860 | + max(0, kd_len - 1) |
1861 | * jcp.stride_d |
1862 | + 1)); |
1863 | p.back_overflow = kd_lo; |
1864 | p.kh_padding = kh_len; |
1865 | p.kd_padding = kd_len; |
1866 | p.scales = scales; |
1867 | p.dst_scale = dst_scales; |
1868 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1869 | p.post_ops_binary_rhs_arg_vec |
1870 | = post_ops_binary_rhs_arg_vec.data(); |
1871 | p.oc_l_off = g_oc; |
1872 | p.zp_compensation |
1873 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1874 | p.zp_src_pad_str_compensation = jcp.src_zero_point |
1875 | ? zp_src_comp_scratch + g_oc |
1876 | : nullptr; |
1877 | p.src_zero_point = zp_src; |
1878 | p.dst_zero_point = zp_dst; |
1879 | p.dst_orig = dst; |
1880 | (*kernel_)(&p); |
1881 | } |
1882 | if (jcp.loop_order == loop_ngc) |
1883 | nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, |
1884 | oc_chunks, od_s, jcp.od, oh_s, jcp.oh); |
1885 | else if (jcp.loop_order == loop_cgn) |
1886 | nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, |
1887 | jcp.mb, od_s, jcp.od, oh_s, jcp.oh); |
1888 | else |
                assert(!"unsupported loop order");
1890 | } |
1891 | }); |
1892 | return status::success; |
1893 | } |
1894 | |
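// Explicit instantiations for each supported vector register width; the
// narrower Ymm/Xmm variants serve shapes whose channel blocks do not fill
// a full zmm register.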
1895 | template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Zmm>; |
1896 | template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Ymm>; |
1897 | template struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel<Xbyak::Xmm>; |
1898 | |
1899 | } // namespace x64 |
1900 | } // namespace cpu |
1901 | } // namespace impl |
1902 | } // namespace dnnl |
1903 | |