1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | #include "common/nstl.hpp" |
19 | #include "common/utils.hpp" |
20 | |
21 | #include "cpu/cpu_primitive.hpp" |
22 | #include "cpu/zero_point_utils.hpp" |
23 | |
24 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
25 | #include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp" |
26 | #include "cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp" |
27 | |
28 | #define GET_OFF(field) offsetof(jit_deconv_call_s, field) |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::status; |
36 | using namespace dnnl::impl::memory_tracking::names; |
37 | using namespace dnnl::impl::utils; |
38 | using namespace Xbyak; |
39 | |
40 | using namespace nstl; |
41 | |
42 | #define wht_blk_off(d, g, ...) \ |
43 | (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \ |
44 | : (d).blk_off(__VA_ARGS__)) |
45 | |
46 | template <cpu_isa_t isa> |
47 | status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf( |
48 | jit_conv_conf_t &jcp, const deconvolution_desc_t &cd, |
49 | memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, |
50 | const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr, |
51 | int nthreads) { |
52 | const memory_desc_wrapper src_d(&src_md); |
53 | const memory_desc_wrapper dst_d(&dst_md); |
54 | const memory_desc_wrapper weights_d(&weights_md); |
55 | const memory_desc_wrapper bias_d(&bias_md); |
56 | |
57 | if (!(mayiuse(isa) |
58 | && one_of(src_d.data_type(), data_type::u8, data_type::s8) |
59 | && weights_d.data_type() == data_type::s8 |
60 | && one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
61 | data_type::s8, data_type::u8))) |
62 | return status::unimplemented; |
63 | |
64 | jcp = zero<decltype(jcp)>(); |
65 | jcp.nthr = nthreads; |
66 | |
67 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
68 | jcp.signed_input = src_d.data_type() == data_type::s8; |
69 | const int ndims = jcp.ndims = dst_d.ndims(); |
70 | const bool is_1d = ndims == 3; |
71 | const bool is_2d = ndims == 4; |
72 | const bool is_3d = ndims == 5; |
73 | const bool is_avx2 = isa == avx2; |
74 | |
75 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
76 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
77 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
78 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
79 | jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; |
80 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
81 | jcp.is_depthwise = true && with_groups |
82 | && utils::everyone_is( |
83 | 1, jcp.ic_without_padding, jcp.oc_without_padding); |
84 | jcp.has_vnni = mayiuse(avx2_vnni); |
85 | |
    /* TODO: future work, on hold until a depthwise specialized kernel is
     * implemented. */
88 | if (jcp.is_depthwise && (jcp.signed_input || is_3d)) |
89 | return status::unimplemented; |
90 | |
91 | if (!zero_points_valid(&attr)) return status::unimplemented; |
92 | jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC); |
93 | jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST); |
94 | jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC); |
95 | |
96 | format_tag_t dat_tag = utils::pick( |
97 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
98 | |
99 | if (src_d.format_kind() == format_kind::any) { |
100 | CHECK(memory_desc_init_by_tag(src_md, dat_tag)); |
101 | jcp.src_tag = dat_tag; |
102 | } else { |
103 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
104 | } |
105 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
106 | |
107 | if (dst_d.format_kind() == format_kind::any) { |
108 | CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); |
109 | jcp.dst_tag = dat_tag; |
110 | } else { |
111 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
112 | } |
113 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
114 | |
115 | auto set_or_check_wei_format = [&]() { |
116 | using namespace format_tag; |
117 | format_tag_t wei_tag; |
118 | if (jcp.ic_block == 8 || jcp.ch_block == 8) { |
119 | if (is_1d) { |
120 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i |
121 | : OIw2i8o4i; |
122 | } else if (is_2d) { |
123 | wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i |
124 | : OIhw2i8o4i; |
125 | } else { |
126 | wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i; |
127 | } |
128 | } else { |
129 | if (is_avx2) { |
130 | assert(with_groups && jcp.ic_block == 4); |
131 | wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i; |
132 | } else { |
133 | if (is_1d) { |
134 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i |
135 | : OIw4o4i; |
136 | } else if (is_2d) { |
137 | wei_tag = with_groups |
138 | ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i |
139 | : OIhw4o4i; |
140 | } else { |
141 | wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i; |
142 | } |
143 | } |
144 | } |
145 | |
146 | memory_desc_t want_wei_md = weights_md; |
147 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
148 | if (jcp.signed_input && !jcp.is_depthwise) { |
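            /* A sketch of the s8s8 scheme (informational): without VNNI the
             * kernel emulates the u8*s8 dot product with vpmaddubsw, whose
             * intermediate s16 sums can saturate, so the reorder pre-scales
             * the weights by scale_adjust = 0.5 and adjust_oscales() divides
             * the factor back out at execution time. compensation_conv_s8s8
             * asks the reorder to precompute per-oc (and per-group)
             * compensation for the +128 source shift applied in
             * prepare_output(). */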
149 | want_wei_md.extra.flags = 0 |
150 | | memory_extra_flags::compensation_conv_s8s8 |
151 | | memory_extra_flags::scale_adjust; |
152 | want_wei_md.extra.compensation_mask = (1 << 0) |
153 | + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); |
154 | want_wei_md.extra.scale_adjust = jcp.has_vnni ? 1.f : 0.5f; |
155 | } |
156 | if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups); |
157 | |
158 | if (weights_md.format_kind == format_kind::any) { |
159 | weights_md = want_wei_md; |
160 | return true; |
161 | } |
162 | |
163 | return weights_md == want_wei_md; |
164 | }; |
165 | |
166 | jcp.with_bias = with_bias; |
167 | if (jcp.with_bias) { |
168 | if (bias_d.format_kind() == format_kind::any) |
169 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
170 | } |
171 | |
172 | jcp.prop_kind = cd.prop_kind; |
173 | jcp.mb = src_d.dims()[0]; |
174 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
175 | jcp.iw = src_d.dims()[ndims - 1]; |
176 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
177 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
178 | jcp.ow = dst_d.dims()[ndims - 1]; |
179 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
180 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
181 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
182 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
183 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
184 | jcp.l_pad = cd.padding[0][ndims - 3]; |
185 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
186 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
187 | jcp.stride_w = cd.strides[ndims - 3]; |
188 | |
189 | if (jcp.is_depthwise) { |
190 | jcp.ch_block = is_avx2 ? 8 : 4; |
191 | jcp.oc_block = 1; |
192 | jcp.ic_block = 1; |
193 | } else { |
194 | jcp.ch_block = 1; |
195 | jcp.oc_block = is_avx2 ? 8 : 4; |
196 | jcp.ic_block = is_avx2 ? 8 : 4; |
197 | |
198 | if (jcp.ngroups == 1) { |
199 | jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block); |
200 | jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block); |
201 | } else if (jcp.ngroups != 1 |
202 | && ((jcp.ic % jcp.ic_block != 0) |
203 | || (jcp.oc % jcp.oc_block != 0))) { |
            /* For grouped convolution, oneDNN doesn't support channel
             * padding. When the number of channels per group is not a
             * multiple of 8 on avx2:
             * - use Xmm when channels per group are a multiple of 4,
             * - otherwise return unimplemented. */
208 | jcp.oc_block = jcp.ic_block = 4; |
209 | } |
210 | if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0) |
211 | return status::unimplemented; |
212 | } |
213 | |
214 | if (!set_or_check_wei_format()) return status::unimplemented; |
215 | |
216 | jcp.dilate_d = is_3d ? cd.dilates[0] : 0; |
217 | jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; |
218 | jcp.dilate_w = cd.dilates[ndims - 3]; |
219 | |
220 | if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1) |
221 | || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1) |
222 | || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1)) |
223 | return status::unimplemented; |
224 | |
225 | const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
226 | const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
227 | const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
228 | jcp.r_pad = calculate_end_padding( |
229 | jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw); |
230 | jcp.b_pad = calculate_end_padding( |
231 | jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh); |
232 | jcp.back_pad = calculate_end_padding( |
233 | jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd); |
234 | const bool kernel_outside_src = false || ext_kw <= jcp.l_pad |
235 | || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad |
236 | || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad; |
237 | if (kernel_outside_src) return status::unimplemented; |
238 | |
239 | CHECK(attr.set_default_formats(&dst_md)); |
240 | if (!post_ops_ok(jcp, dst_d, attr)) return status::unimplemented; |
241 | |
242 | const auto &p = attr.post_ops_; |
243 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
244 | jcp.with_eltwise = eltwise_ind != -1; |
245 | if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise; |
246 | |
247 | const int binary_ind = p.find(primitive_kind::binary); |
248 | jcp.with_binary = binary_ind != -1; |
249 | |
250 | const int sum_ind = p.find(primitive_kind::sum); |
251 | jcp.with_sum = sum_ind != -1; |
252 | |
253 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
254 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
255 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
256 | const int wei_scales_per_oc = 1 << (int)with_groups; |
257 | jcp.is_oc_scale = wei_scales.mask_ == wei_scales_per_oc; |
258 | jcp.dst_scale = !dst_scales.has_default_values(); |
259 | |
260 | jcp.post_ops = p; |
261 | |
    // only common and per-output-channel scales are supported
263 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_scales_per_oc) |
264 | && utils::everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
265 | if (!scales_ok) return status::unimplemented; |
266 | |
267 | jcp.dst_dt = dst_d.data_type(); |
268 | jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef; |
269 | jcp.typesize_bia |
270 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
271 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
272 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
273 | |
274 | jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); |
275 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
276 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
277 | |
278 | /* kernel blocking params */ |
279 | const int regs = jcp.has_vnni ? 14 : 12; |
280 | |
281 | jcp.nb_ch_blocking = 1; |
282 | jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc); |
283 | for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) |
284 | if (jcp.nb_oc % jcp.nb_oc_blocking == 0 |
285 | && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1)) |
286 | break; |
287 | |
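    // Rough register budget: every unrolled output column keeps
    // nb_oc_blocking accumulators (see vmm_out) plus one input broadcast
    // register (see vmm_inp), so ur_w * (nb_oc_blocking + 1) <= regs must
    // hold; e.g. with vnni (regs = 14) and nb_oc_blocking = 4 this yields
    // ur_w = 14 / 5 = 2.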
288 | jcp.ur_w = regs / (jcp.nb_oc_blocking + 1); |
289 | const int l_overflow = max( |
290 | 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w); |
291 | |
292 | if (jcp.ow < jcp.ur_w) { |
293 | jcp.ur_w = jcp.ow; |
294 | jcp.ur_w_tail = 0; |
295 | } else { |
296 | for (; jcp.ur_w >= 1; jcp.ur_w--) { |
            /* ur_w should be a multiple of stride_w in order
               to simplify the logic of get_ow_start and get_ow_end */
299 | const bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0; |
300 | |
            /* boundary conditions:
               these conditions ensure that all elements close to the
               boundary are computed in a single call of the compute loop */
304 | const bool left_boundary_covered |
305 | = jcp.ur_w >= l_overflow * jcp.stride_w; |
306 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
307 | const int r_overflow_no_tail = max(0, |
308 | ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad) |
309 | - jcp.ur_w_tail) |
310 | / jcp.stride_w); |
311 | const bool right_boundary_covered |
312 | = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w; |
313 | |
314 | if (is_multiple_of_stride && left_boundary_covered |
315 | && right_boundary_covered) |
316 | break; |
317 | else if (jcp.ur_w == 1) |
                /* The boundary conditions above also keep the calls to
                   icb_loop simple. If they are not satisfied, special
                   cases would be needed to pick correct l_overflow and
                   r_overflow values when different iterations of the
                   compute loop work on locations close to the boundary.
                   To keep the code simple, return unimplemented for the
                   extreme case when a good ur_w cannot be found.
                 */
328 | return status::unimplemented; |
329 | } |
330 | } |
331 | |
332 | jcp.wei_adj_scale |
333 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
334 | ? weights_d.extra().scale_adjust |
335 | : 1.f; |
336 | |
337 | jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn; |
338 | return status::success; |
339 | } |
340 | |
341 | template <cpu_isa_t isa> |
342 | jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::jit_uni_x8s8s32x_deconv_fwd_kernel( |
343 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
344 | const memory_desc_wrapper &dst_d) |
345 | : kernel_(nullptr) { |
346 | |
347 | const int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; |
348 | switch (ch_block) { |
349 | case 8: |
350 | if (isa == avx2) { |
351 | kernel_ = utils::make_unique< |
352 | _jit_avx2_x8s8s32x_deconv_fwd_kernel>( |
353 | ajcp, attr, dst_d); |
354 | return; |
355 | } else |
356 | assert(!"invalid channel blocking for current ISA" ); |
357 | case 4: |
358 | kernel_ = utils::make_unique< |
359 | _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Xbyak::Xmm>>( |
360 | ajcp, attr, dst_d); |
361 | return; |
362 | default: assert(!"invalid channel blocking" ); |
363 | } |
364 | } |
365 | |
366 | template <cpu_isa_t isa> |
367 | jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::~jit_uni_x8s8s32x_deconv_fwd_kernel() |
368 | = default; |
369 | |
370 | template <cpu_isa_t isa> |
371 | void jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad( |
372 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
373 | const primitive_attr_t &attr) { |
374 | const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
375 | const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
376 | dim_t count = nstl::max<dim_t>(scales_count, 8); |
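    // At least 8 entries are booked: for the common-scale case (mask == 0)
    // adjust_oscales() below fills the table with utils::array_set(..., 8)
    // so a full 8-float row can always be loaded at once.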
377 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
378 | |
379 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) { |
380 | const auto zp_pad_comp_size |
381 | = static_cast<size_t>(jcp.oc_without_padding) * jcp.ngroups |
382 | * jcp.kd * jcp.kh * jcp.kw; |
383 | scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size); |
384 | } |
385 | } |
386 | |
387 | template <cpu_isa_t isa> |
388 | bool jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::post_ops_ok(jit_conv_conf_t &jcp, |
389 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) { |
390 | using namespace injector; |
391 | |
392 | return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary}, |
393 | attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/, |
394 | false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/, |
395 | {broadcasting_strategy_t::per_oc, |
396 | broadcasting_strategy_t::scalar})); |
397 | } |
398 | |
399 | template <cpu_isa_t isa, typename Vmm> |
400 | _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, |
401 | Vmm>::_jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
402 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) |
403 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa) |
404 | , jcp_(ajcp) |
405 | , postops_injector_(nullptr) |
406 | , ker_max_regs_(jcp_.has_vnni ? 14 : 12) { |
407 | |
408 | if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) { |
409 | const std::size_t tail_size = get_tail_size(); |
410 | |
411 | static constexpr bool preserve_gpr = true; |
412 | static constexpr bool preserve_vmm = true; |
413 | static constexpr bool use_exact_tail_scalar_bcast = false; |
414 | static constexpr size_t vmm_helper_idx = 15; |
415 | |
416 | const binary_injector::rhs_arg_static_params_t rhs_sp {vmm_helper_idx, |
417 | this->r14, this->r15, this->r13, preserve_gpr, preserve_vmm, |
418 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d, |
419 | tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast}; |
420 | const binary_injector::static_params_t bsp {this->param1_, rhs_sp}; |
421 | |
422 | postops_injector_ = utils::make_unique< |
423 | injector::jit_uni_postops_injector_t<isa, Vmm>>( |
424 | this, jcp_.post_ops, bsp); |
425 | } |
426 | } |
427 | |
428 | template <cpu_isa_t isa, typename Vmm> |
429 | _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, |
430 | Vmm>::~_jit_uni_x8s8s32x_deconv_fwd_kernel() |
431 | = default; |
432 | |
433 | template <cpu_isa_t isa, typename Vmm> |
434 | Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_out( |
435 | int i_ur, int i_oc) const { |
436 | const int idx = i_ur * jcp_.nb_oc_blocking + i_oc; |
437 | assert(idx < ker_max_regs_); |
438 | /* remap the reg indices to avoid using xmm0 in eltwise injector */ |
439 | return Vmm(15 - idx); |
440 | } |
441 | |
442 | template <cpu_isa_t isa, typename Vmm> |
443 | Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_inp( |
444 | int i_ic, int nb_x_blocking) const { |
445 | const int idx = i_ic + nb_x_blocking * jcp_.ur_w; |
446 | assert(idx < ker_max_regs_); |
447 | return Vmm(15 - idx); |
448 | } |
449 | |
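// Note on get_ow_start / get_ow_end: in the deconvolution index math a
// kernel tap ki touches output column ow only when (ow + l_pad - ki) is an
// exact multiple of stride_w (the stride "holes"). These helpers return,
// for a ur_w-wide window, the first and one-past-last columns that tap ki
// can touch; the while-loops just round up to the next admissible column.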
450 | template <cpu_isa_t isa, typename Vmm> |
451 | int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_start( |
452 | int ki, int l_overflow) const noexcept { |
453 | int res = (jcp_.ow - 1 + jcp_.r_pad) % jcp_.stride_w |
454 | + l_overflow * jcp_.stride_w |
455 | - (jcp_.kw - 1 - ki) * (jcp_.dilate_w + 1); |
456 | while (res < 0) |
457 | res += jcp_.stride_w; |
458 | return res; |
459 | } |
460 | |
461 | template <cpu_isa_t isa, typename Vmm> |
462 | int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_end( |
463 | int ur_w, int ki, int r_overflow) const noexcept { |
464 | if (utils::one_of(ur_w, jcp_.ow, jcp_.ur_w_tail)) |
465 | ur_w += nstl::min(0, jcp_.r_pad); // remove negative padding |
466 | int res = (ur_w - 1 + jcp_.l_pad) % jcp_.stride_w |
467 | + r_overflow * jcp_.stride_w - ki * (jcp_.dilate_w + 1); |
468 | while (res < 0) |
469 | res += jcp_.stride_w; |
470 | return ur_w - res; |
471 | } |
472 | |
473 | template <cpu_isa_t isa, typename Vmm> |
474 | int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_blocking_size() const |
475 | noexcept { |
476 | return jcp_.is_depthwise ? jcp_.ch_block : jcp_.oc_block; |
477 | } |
478 | |
479 | template <cpu_isa_t isa, typename Vmm> |
480 | int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_tail_size() const |
481 | noexcept { |
482 | return jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block |
483 | : jcp_.oc_without_padding % jcp_.oc_block; |
484 | } |
485 | |
486 | template <cpu_isa_t isa, typename Vmm> |
487 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute( |
488 | const Vmm vreg_acc, const Vmm vreg_wei, const Vmm vreg_src) { |
489 | |
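    // Three MAC flavors: vpdpbusd fuses a u8*s8 dot product over 4-byte
    // groups with s32 accumulation; the depthwise path multiplies s32
    // lanes directly; the generic fallback chains vpmaddubsw (u8*s8 to
    // pairwise, saturating s16) with vpmaddwd against a vector of ones
    // (s16 to pairwise s32); the saturation risk here is why weights are
    // halved when VNNI is unavailable (see init_conf()).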
490 | if (jcp_.has_vnni) { |
491 | vpdpbusd(vreg_acc, vreg_src, vreg_wei, Xbyak::VexEncoding); |
492 | } else if (jcp_.is_depthwise) { |
493 | uni_vmovups(vmm_tmp_, vreg_src); |
494 | uni_vpmulld(vmm_tmp_, vmm_tmp_, vreg_wei); |
495 | uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_); |
496 | } else { |
497 | uni_vpmaddubsw(vmm_tmp_, vreg_src, vreg_wei); |
498 | uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_); |
499 | uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_); |
500 | } |
501 | } |
502 | |
503 | template <cpu_isa_t isa, typename Vmm> |
504 | std::function<Vmm()> _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, |
505 | Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept { |
506 | |
507 | const int start_vmm_idx = vmm_inp(ur_w - 1, jcp_.nb_oc_blocking).getIdx(); |
508 | const int end_vmm_idx = vmm_inp(0, jcp_.nb_oc_blocking).getIdx() + 1; |
509 | int current_vmm_idx = start_vmm_idx; |
510 | |
511 | return [=]() mutable { |
512 | const Vmm vmm {static_cast<int>(current_vmm_idx++)}; |
513 | |
514 | if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx; |
515 | |
516 | return vmm; |
517 | }; |
518 | } |
519 | |
520 | template <cpu_isa_t isa, typename Vmm> |
521 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_zp_src_pad_str_comp( |
522 | int ur_w, int l_overflow, int r_overflow, bool h_padded) { |
523 | Xbyak::Label end_zp_pad, no_tail; |
524 | |
    // apply once per icb loop; the zp src stride/padding compensation is
    // computed as zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
527 | cmp(reg_icb_, jcp_.nb_ic); |
528 | jne(end_zp_pad, T_NEAR); |
529 | |
530 | if (jcp_.ngroups % jcp_.ch_block |
531 | || jcp_.oc_without_padding % jcp_.oc_block) { |
532 | if (jcp_.is_depthwise) |
533 | cmp(reg_oc_blocks_, jcp_.nb_ch - 1); |
534 | else |
535 | cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking); |
536 | jne(no_tail, T_NEAR); |
537 | |
538 | append_zp_src_pad_str_comp( |
539 | ur_w, l_overflow, r_overflow, h_padded, true /*last_oc_block*/); |
540 | jmp(end_zp_pad, T_NEAR); |
541 | } |
542 | |
543 | L(no_tail); |
544 | append_zp_src_pad_str_comp( |
545 | ur_w, l_overflow, r_overflow, h_padded, false /*last_oc_block*/); |
546 | |
547 | L(end_zp_pad); |
548 | } |
549 | |
550 | template <cpu_isa_t isa, typename Vmm> |
551 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::append_zp_src_pad_str_comp( |
552 | int ur_w, int l_overflow, int r_overflow, bool h_padded, |
553 | bool last_oc_block) { |
554 | |
555 | const auto ®_zp_src_pad_comp = reg_scratch_; |
556 | const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w); |
557 | bool base_comp_addr_loaded = false; |
558 | |
559 | const auto load_base_zp_src_pad_comp_addr = [&]() { |
560 | if (!base_comp_addr_loaded) { |
561 | if (jcp_.ndims == 5) mov(reg_scratch_preserved_, reg_scratch_); |
562 | |
563 | if (jcp_.ndims > 3) |
564 | mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr_); |
565 | else |
566 | mov(reg_zp_src_pad_comp, |
567 | qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]); |
568 | |
569 | base_comp_addr_loaded = true; |
570 | } |
571 | }; |
572 | |
573 | const auto load_zp_src_pad_comp = [&](const Vmm zp_pad_comp_vmm, |
574 | const Xbyak::Address &comp_addr, |
575 | const int ocb) { |
576 | const bool is_tail = last_oc_block && ocb == jcp_.nb_oc_blocking - 1; |
577 | |
578 | if (is_tail) |
579 | load_data(data_type::s32, zp_pad_comp_vmm, comp_addr, |
580 | get_tail_size()); |
581 | else |
582 | uni_vmovups(zp_pad_comp_vmm, comp_addr); |
583 | }; |
584 | |
585 | const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) { |
586 | const auto kw_offset = it_kw * jcp_.oc_without_padding * jcp_.ngroups; |
587 | const auto oc_offset = ocb * jcp_.oc_block; |
588 | |
589 | return (kw_offset + oc_offset) * sizeof(int32_t); |
590 | }; |
591 | |
592 | for (int it_kw = 0; it_kw < jcp_.kw; ++it_kw) { |
593 | const int ow_start = get_ow_start(it_kw, l_overflow); |
594 | const int ow_end = get_ow_end(ur_w, it_kw, r_overflow); |
595 | |
596 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
597 | Vmm zp_src_comp_pad_vmm; // will be assigned later |
598 | bool ocb_zp_loaded = false; |
599 | |
600 | const auto zp_src_comp_pad_off |
601 | = get_zp_src_comp_pad_off(it_kw, ocb); |
602 | |
603 | for (int it_ow = 0; it_ow < ur_w; ++it_ow) { |
604 | |
605 | const bool inside_padded_area = h_padded |
606 | || !(it_ow >= ow_start && it_ow < ow_end |
607 | && ((it_ow + jcp_.l_pad - it_kw) % jcp_.stride_w |
608 | == 0)); |
609 | |
610 | if (inside_padded_area) { |
611 | load_base_zp_src_pad_comp_addr(); |
612 | |
613 | if (!ocb_zp_loaded) { |
614 | zp_src_comp_pad_vmm = get_next_comp_vmm(); |
615 | const auto comp_addr = ptr[reg_zp_src_pad_comp |
616 | + zp_src_comp_pad_off]; |
617 | load_zp_src_pad_comp( |
618 | zp_src_comp_pad_vmm, comp_addr, ocb); |
619 | ocb_zp_loaded = true; |
620 | } |
621 | |
622 | const auto vmm_dst = vmm_out(it_ow, ocb); |
623 | uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm); |
624 | } |
625 | } |
626 | } |
627 | } |
628 | |
629 | if (jcp_.ndims > 3) { |
630 | if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr(); |
631 | |
632 | const auto kh_offset = jcp_.kw * jcp_.oc_without_padding * jcp_.ngroups |
633 | * sizeof(int32_t); |
634 | |
635 | add(reg_zp_src_pad_comp, kh_offset); |
636 | mov(zp_src_pad_comp_addr_, reg_zp_src_pad_comp); |
637 | } |
638 | |
639 | if (jcp_.ndims == 5 && base_comp_addr_loaded) |
640 | mov(reg_scratch_, reg_scratch_preserved_); |
641 | } |
642 | |
643 | template <cpu_isa_t isa, typename Vmm> |
644 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute_ker(int ur_w, |
645 | int l_overflow, int r_overflow, ker_block_t last_ic_block_flag, |
646 | bool h_padded) { |
647 | |
648 | const bool signed_input_or_src_zp |
649 | = (jcp_.signed_input || jcp_.src_zero_point); |
650 | |
651 | const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block; |
652 | const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp_.stride_w; |
653 | |
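    // Deconvolution reads the source as an upsampled signal: output column
    // oj under kernel tap ki maps to source column
    // (oj + l_pad - ki * (dilate_w + 1)) / stride_w, valid only when the
    // division is exact; the jj-loops below guard that condition.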
654 | const auto src_offset = [=](int oj, int icb, int ki) { |
655 | return jcp_.typesize_in |
656 | * (((oj + jcp_.l_pad - ki * (jcp_.dilate_w + 1)) |
657 | / jcp_.stride_w) |
658 | * jcp_.ngroups * jcp_.ic_without_padding |
659 | + icb * 4); |
660 | }; |
661 | |
662 | const auto kernel_offset = [=](int ocb, int icb, int ki) { |
663 | return jcp_.typesize_in |
664 | * ((ocb * jcp_.nb_ic * jcp_.kd * jcp_.kh * jcp_.kw + ki) |
665 | * ch_block_all |
666 | + icb * jcp_.oc_block * IC_SUB_STEP); |
667 | }; |
668 | |
669 | for (int ki = 0; ki < jcp_.kw; ki++) { |
670 | |
671 | const int jj_start = get_ow_start(ki, l_overflow); |
672 | const int jj_end = get_ow_end(ur_w, ki, r_overflow); |
673 | |
674 | const int _start = (signed_input_or_src_zp) ? 0 : jj_start; |
675 | const int _end = (signed_input_or_src_zp) ? ur_w : jj_end; |
676 | |
677 | const int tail_size = jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block |
678 | : jcp_.ic_without_padding % 4; |
679 | const int n_ic_blocks = jcp_.is_depthwise |
680 | ? 1 |
681 | : (last_ic_block_flag != no_last_block ? div_up( |
682 | jcp_.ic_without_padding % jcp_.ic_block, 4) |
683 | : jcp_.ic_block / 4); |
684 | |
685 | for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) { |
686 | if (h_padded) { |
687 | /* fill padded area with shifted values */ |
688 | if (jcp_.signed_input) { |
689 | const Vmm inp = vmm_inp(0, jcp_.nb_oc_blocking); |
690 | uni_vpxor(inp, inp, inp); |
691 | uni_vpsubb(inp, inp, vmm_shift_); |
692 | } |
693 | } else { |
694 | |
695 | for (int jj = _start; jj < _end; jj += ur_w_stride) { |
696 | |
697 | const int aux_src_off = src_offset(jj, icb1, ki); |
698 | const auto vmm_src = vmm_inp(jj, jcp_.nb_oc_blocking); |
699 | |
700 | if (jj >= jj_start && jj < jj_end |
701 | && ((jj + jcp_.l_pad - ki) % jcp_.stride_w == 0)) { |
702 | if (jcp_.is_depthwise) { |
703 | if (tail_size != 0) |
704 | assert(jcp_.nb_oc_blocking == 1); |
705 | uni_vpxor(vmm_src, vmm_src, vmm_src); |
706 | const bool mask_flag |
707 | = last_ic_block_flag != no_last_block |
708 | && tail_size; |
709 | load_data(data_type::u8, vmm_src, aux_reg_src_, |
710 | aux_src_off, |
711 | mask_flag ? tail_size : jcp_.ch_block); |
712 | } else if ((last_ic_block_flag == last_sp_block) |
713 | && tail_size != 0 && icb1 == n_ic_blocks - 1) { |
714 | const auto vmm_inp_tmp = Xmm(vmm_src.getIdx()); |
715 | load_bytes(vmm_inp_tmp, aux_reg_src_, aux_src_off, |
716 | tail_size); |
717 | uni_vpbroadcastd(vmm_src, vmm_inp_tmp); |
718 | } else { |
719 | uni_vpbroadcastd( |
720 | vmm_src, ptr[aux_reg_src_ + aux_src_off]); |
721 | } |
722 | if (jcp_.signed_input) |
723 | uni_vpsubb(vmm_src, vmm_src, vmm_shift_); |
724 | } else { |
725 | /* fill padded area with shifted values */ |
726 | if (jcp_.signed_input) { |
727 | uni_vpxor(vmm_src, vmm_src, vmm_src); |
728 | uni_vpsubb(vmm_src, vmm_src, vmm_shift_); |
729 | } |
730 | } |
731 | } |
732 | } |
733 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
734 | const int aux_filt_off = kernel_offset(ocb, icb1, ki); |
735 | |
736 | if (_end - _start > 0) { |
737 | if (jcp_.is_depthwise) { |
738 | uni_vpmovsxbd( |
739 | vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]); |
740 | } else |
741 | uni_vmovups( |
742 | vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]); |
743 | } |
744 | |
745 | for (int jj = _start; jj < _end; jj += ur_w_stride) { |
746 | |
747 | const bool inside_padded_area = h_padded |
748 | || !(jj >= jj_start && jj < jj_end |
749 | && ((jj + jcp_.l_pad - ki) % jcp_.stride_w |
750 | == 0)); |
751 | const auto vmm_dst = vmm_out(jj, ocb); |
752 | if (jcp_.signed_input || !inside_padded_area) { |
753 | const Vmm inp = vmm_inp( |
754 | h_padded ? 0 : jj, jcp_.nb_oc_blocking); |
755 | compute(vmm_dst, vmm_wei_, inp); |
756 | } |
757 | } |
758 | } |
759 | } |
760 | } |
761 | |
762 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)) |
763 | apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded); |
764 | } |
765 | |
766 | template <cpu_isa_t isa, typename Vmm> |
767 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::kh_loop(int ur_w, |
768 | int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) { |
769 | |
770 | const bool signed_input_or_src_zp |
771 | = (jcp_.signed_input || jcp_.src_zero_point); |
772 | |
773 | const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block; |
774 | const int shift_src_ih = jcp_.typesize_in * (jcp_.dilate_h + 1) * jcp_.iw |
775 | * jcp_.ngroups * jcp_.ic_without_padding; |
776 | const int shift_src_id = jcp_.typesize_in * (jcp_.dilate_d + 1) * jcp_.ih |
777 | * jcp_.iw * jcp_.ngroups * jcp_.ic_without_padding; |
778 | const int stride_h = signed_input_or_src_zp ? 1 : jcp_.stride_h; |
779 | const int shift_filt_kh |
780 | = jcp_.typesize_in * jcp_.kw * ch_block_all * stride_h; |
781 | const int stride_d = signed_input_or_src_zp ? 1 : jcp_.stride_d; |
782 | const int shift_filt_kd |
783 | = jcp_.typesize_in * jcp_.kw * ch_block_all * jcp_.kh * stride_d; |
784 | |
785 | Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop; |
786 | Label t_overflow_label, no_t_overflow_label, b_overflow_label, |
787 | no_b_overflow_label; |
788 | Label back_overflow_label, no_back_overflow_label, d_h_overflow_label, |
789 | front_overflow_label, no_front_overflow_label, d_h_overflow_label2; |
790 | if (jcp_.ndims == 5) { |
791 | mov(aux_reg_filt_d_, reg_filt_); |
792 | mov(aux_reg_src_d_, reg_src_); |
793 | |
794 | if (signed_input_or_src_zp) { |
795 | mov(reg_ki_, ptr[param1_ + GET_OFF(back_overflow)]); |
796 | cmp(reg_ki_, 0); |
797 | je(no_back_overflow_label, T_NEAR); |
798 | |
799 | L(back_overflow_label); |
800 | { |
801 | mov(aux_reg_filt_, aux_reg_filt_d_); |
802 | mov(reg_kh_, jcp_.kh); |
803 | L(d_h_overflow_label); |
804 | { |
805 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
806 | add(aux_reg_filt_, shift_filt_kh); |
807 | dec(reg_kh_); |
808 | jnz(d_h_overflow_label); |
809 | } |
810 | |
811 | add(aux_reg_filt_d_, shift_filt_kd); |
812 | dec(reg_ki_); |
813 | jnz(back_overflow_label); |
814 | } |
815 | L(no_back_overflow_label); |
816 | } |
817 | |
818 | mov(reg_ki_, ptr[param1_ + GET_OFF(kd_padding)]); |
819 | |
820 | if ((signed_input_or_src_zp) || (jcp_.dilate_d >= jcp_.id) |
821 | || ((!signed_input_or_src_zp) |
822 | && ((min(jcp_.f_pad, jcp_.back_pad) < 0) |
823 | || ((jcp_.kd - 1) * (jcp_.dilate_d + 1) |
824 | < nstl::max( |
825 | jcp_.f_pad, jcp_.back_pad))))) { |
826 | cmp(reg_ki_, 0); |
827 | je(skip_kd_loop, T_NEAR); |
828 | } |
829 | |
830 | L(kd_loop_label); |
831 | mov(aux_reg_src_, aux_reg_src_d_); |
832 | mov(aux_reg_filt_, aux_reg_filt_d_); |
833 | } else { |
834 | mov(aux_reg_src_, reg_src_); |
835 | mov(aux_reg_filt_, reg_filt_); |
836 | } |
837 | |
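    // With a signed source or a src zero-point the kernel cannot simply
    // skip out-of-bounds taps: the padded area must still contribute the
    // +128-shift / zero-point compensation terms. Hence the extra
    // compute_ker(..., h_padded = true) passes over the overflow regions.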
838 | if (signed_input_or_src_zp && jcp_.ndims > 3) { |
839 | /* Weights are transposed, so first compute 'bottom' padding. */ |
840 | mov(reg_overflow_, ptr[param1_ + GET_OFF(b_overflow)]); |
841 | cmp(reg_overflow_, 0); |
842 | je(no_b_overflow_label, T_NEAR); |
843 | L(b_overflow_label); |
844 | { |
845 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
846 | |
847 | add(aux_reg_filt_, shift_filt_kh); |
848 | dec(reg_overflow_); |
849 | cmp(reg_overflow_, 0); |
850 | jg(b_overflow_label, T_NEAR); |
851 | } |
852 | L(no_b_overflow_label); |
853 | } |
854 | |
855 | mov(reg_kh_, ptr[param1_ + GET_OFF(kh_padding)]); |
856 | |
857 | if ((signed_input_or_src_zp) || (jcp_.dilate_h >= jcp_.ih) |
858 | || ((!signed_input_or_src_zp) |
859 | && ((min(jcp_.t_pad, jcp_.b_pad) < 0) |
860 | || ((jcp_.kh - 1) * (jcp_.dilate_h + 1) |
861 | < nstl::max(jcp_.t_pad, jcp_.b_pad))))) { |
862 | cmp(reg_kh_, 0); |
863 | je(skip_kh_loop, T_NEAR); |
864 | } |
865 | |
866 | L(kh_loop_label); |
867 | { |
868 | compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false); |
869 | sub(aux_reg_src_, shift_src_ih); |
870 | add(aux_reg_filt_, shift_filt_kh); |
871 | dec(reg_kh_); |
872 | |
873 | /* Insert weight compensation in stride 'holes' */ |
874 | if (signed_input_or_src_zp && jcp_.stride_h > 1) { |
875 | Label kh_comp_loop; |
876 | |
877 | cmp(reg_kh_, 0); |
878 | je(skip_kh_loop, T_NEAR); |
879 | mov(reg_comp_strides_, jcp_.stride_h - 1); |
880 | L(kh_comp_loop); |
881 | { |
882 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
883 | add(aux_reg_filt_, shift_filt_kh); |
884 | dec(reg_comp_strides_); |
885 | cmp(reg_comp_strides_, 0); |
886 | jg(kh_comp_loop, T_NEAR); |
887 | } |
888 | } |
889 | cmp(reg_kh_, 0); |
890 | jg(kh_loop_label, T_NEAR); |
891 | } |
892 | L(skip_kh_loop); |
893 | if (signed_input_or_src_zp && jcp_.ndims > 3) { |
894 | mov(reg_overflow_, ptr[param1_ + GET_OFF(t_overflow)]); |
895 | cmp(reg_overflow_, 0); |
896 | je(no_t_overflow_label, T_NEAR); |
897 | L(t_overflow_label); |
898 | { |
899 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
900 | |
901 | add(aux_reg_filt_, shift_filt_kh); |
902 | dec(reg_overflow_); |
903 | cmp(reg_overflow_, 0); |
904 | jg(t_overflow_label, T_NEAR); |
905 | } |
906 | L(no_t_overflow_label); |
907 | } |
908 | |
909 | if (jcp_.ndims == 5) { |
910 | sub(aux_reg_src_d_, shift_src_id); |
911 | add(aux_reg_filt_d_, shift_filt_kd); |
912 | dec(reg_ki_); |
913 | |
914 | /* Insert weight compensation in stride 'holes' */ |
915 | if (signed_input_or_src_zp && jcp_.stride_d > 1) { |
916 | Label kd_comp_loop, kd_kh_comp_loop; |
917 | cmp(reg_ki_, 0); |
918 | jz(skip_kd_loop, T_NEAR); |
919 | mov(reg_comp_strides_, jcp_.stride_d - 1); |
920 | L(kd_comp_loop); |
921 | mov(aux_reg_filt_, aux_reg_filt_d_); |
922 | mov(reg_kh_, jcp_.kh); |
923 | L(kd_kh_comp_loop); |
924 | { |
925 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
926 | add(aux_reg_filt_, shift_filt_kh); |
927 | dec(reg_kh_); |
928 | jnz(kd_kh_comp_loop, T_NEAR); |
929 | } |
930 | add(aux_reg_filt_d_, shift_filt_kd); |
931 | dec(reg_comp_strides_); |
932 | jnz(kd_comp_loop); |
933 | } |
934 | |
935 | cmp(reg_ki_, 0); |
936 | jg(kd_loop_label, T_NEAR); |
937 | L(skip_kd_loop); |
938 | if (signed_input_or_src_zp) { |
939 | mov(reg_ki_, ptr[param1_ + GET_OFF(f_overflow)]); |
940 | cmp(reg_ki_, 0); |
941 | jz(no_front_overflow_label, T_NEAR); |
942 | L(front_overflow_label); |
943 | { |
944 | mov(aux_reg_filt_, aux_reg_filt_d_); |
945 | mov(reg_kh_, jcp_.kh); |
946 | L(d_h_overflow_label2); |
947 | { |
948 | compute_ker(ur_w, 0, 0, last_ic_block_flag, true); |
949 | add(aux_reg_filt_, shift_filt_kh); |
950 | dec(reg_kh_); |
951 | jnz(d_h_overflow_label2); |
952 | } |
953 | add(aux_reg_filt_d_, shift_filt_kd); |
954 | dec(reg_ki_); |
955 | jnz(front_overflow_label); |
956 | } |
957 | L(no_front_overflow_label); |
958 | } |
959 | } |
960 | } |
961 | |
962 | template <cpu_isa_t isa, typename Vmm> |
963 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) { |
964 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
965 | for (int ur = 0; ur < ur_w; ur++) { |
966 | const Vmm vmm = vmm_out(ur, ocb); |
967 | uni_vpxor(vmm, vmm, vmm); |
968 | } |
969 | } |
970 | if (jcp_.signed_input) { |
971 | const auto xmm_shift = Xbyak::Xmm(vmm_shift_.getIdx()); |
972 | mov(reg_scratch_, 0x80808080); |
973 | uni_vmovq(xmm_shift, reg_scratch_); |
974 | uni_vpbroadcastd(vmm_shift_, xmm_shift); |
975 | } |
976 | } |
977 | |
978 | template <cpu_isa_t isa, typename Vmm> |
979 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in, |
980 | const Vmm vmm_in, const Reg64 reg, int offset, int load_size) { |
981 | |
982 | load_data(type_in, vmm_in, reg, offset, load_size); |
983 | if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in); |
984 | } |
985 | |
986 | template <cpu_isa_t isa, typename Vmm> |
987 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_postops(int ur_w, |
988 | bool last_oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) { |
989 | const auto sum_injector = [=]() { |
990 | if (p_sum_scale) { // post_op: sum |
991 | for (int k = 0; k < jcp_.nb_oc_blocking; k++) { |
992 | const bool mask_flag |
993 | = last_oc_block == 1 && k == jcp_.nb_oc_blocking - 1; |
994 | for (int j = 0; j < ur_w; j++) { |
995 | const int aux_output_offset = jcp_.typesize_out |
996 | * (k * jcp_.oc_block |
997 | + j * jcp_.oc_without_padding |
998 | * jcp_.ngroups); |
999 | cvt2ps(jcp_.dst_dt, vmm_prev_dst_, reg_dst_, |
1000 | aux_output_offset, |
1001 | mask_flag ? get_tail_size() : get_blocking_size()); |
1002 | if (*p_sum_zp != 0) { |
1003 | uni_vbroadcastss(vmm_sum_zp_, ptr[reg_ptr_sum_zp_]); |
1004 | uni_vcvtdq2ps(vmm_sum_zp_, vmm_sum_zp_); |
1005 | uni_vsubps(vmm_prev_dst_, vmm_prev_dst_, vmm_sum_zp_); |
1006 | } |
1007 | const Vmm vmm = vmm_out(j, k); |
1008 | if (*p_sum_scale == 1.f) |
1009 | uni_vaddps(vmm, vmm, vmm_prev_dst_); |
1010 | else { |
1011 | uni_vbroadcastss(vmm_tmp_, ptr[reg_ptr_sum_scale_]); |
1012 | uni_vfmadd231ps(vmm, vmm_prev_dst_, vmm_tmp_); |
1013 | } |
1014 | } |
1015 | } |
1016 | } |
1017 | }; |
1018 | |
1019 | if (p_sum_scale) |
1020 | postops_injector_->set_lambda_injector( |
1021 | primitive_kind::sum, sum_injector); |
1022 | |
1023 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
1024 | if (jcp_.with_binary) { |
1025 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1026 | const bool mask_flag |
1027 | = last_oc_block && ocb == jcp_.nb_oc_blocking - 1; |
1028 | for (int ur = 0; ur < ur_w; ur++) { |
1029 | const int vmm_idx = vmm_out(ur, ocb).getIdx(); |
1030 | const size_t aux_output_offset = jcp_.typesize_out |
1031 | * (ocb * jcp_.oc_block |
1032 | + ur * jcp_.oc_without_padding * jcp_.ngroups); |
1033 | |
1034 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst_); |
1035 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
1036 | vmm_idx, aux_output_offset); |
1037 | if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
1038 | } |
1039 | } |
1040 | } |
1041 | const int nb_oc_block |
1042 | = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking; |
1043 | postops_injector_->compute_vector_range( |
1044 | 16 - nb_oc_block * ur_w, 16, rhs_arg_params); |
1045 | } |
1046 | |
1047 | template <cpu_isa_t isa, typename Vmm> |
1048 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::store_output( |
1049 | int ur_w, bool last_oc_block) { |
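    // Store pipeline, in order: add src zero-point compensation while still
    // in the integer domain, convert to f32, add the s8s8 compensation,
    // apply the combined src*wei scales and the bias, run post-ops, apply
    // the dst scale and dst zero-point, saturate to the destination range,
    // convert back to integers and store.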
1050 | mov(reg_bias_, ptr[param1_ + GET_OFF(bias)]); |
1051 | mov(reg_ptr_scales_, ptr[param1_ + GET_OFF(scales)]); |
1052 | |
1053 | if (jcp_.signed_input) |
1054 | mov(reg_compensation_, ptr[param1_ + GET_OFF(compensation)]); |
1055 | |
1056 | if (jcp_.src_zero_point) { |
1057 | mov(reg_zp_src_, ptr[param1_ + GET_OFF(src_zero_point)]); |
1058 | mov(reg_zp_compensation_, ptr[param1_ + GET_OFF(zp_compensation)]); |
1059 | } |
1060 | |
1061 | const auto &p = jcp_.post_ops; |
1062 | const int sum_idx = p.find(primitive_kind::sum); |
1063 | const float *p_sum_scale |
1064 | = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; |
1065 | const int32_t *p_sum_zp |
1066 | = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr; |
1067 | |
1068 | if (jcp_.src_zero_point) { |
1069 | const auto &vmm_src_zp = vmm_tmp_; |
1070 | const auto &vmm_zp_comp = vmm_scale_; |
1071 | uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]); |
1072 | |
1073 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1074 | const int zp_offset = sizeof(int32_t) * ocb * jcp_.oc_block; |
1075 | const bool mask_flag |
1076 | = last_oc_block && ocb == jcp_.nb_oc_blocking - 1; |
1077 | const int load_size |
1078 | = mask_flag ? get_tail_size() : get_blocking_size(); |
1079 | load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation_, |
1080 | zp_offset, load_size); |
1081 | uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_src_zp); |
1082 | |
1083 | for (int ur = 0; ur < ur_w; ur++) { |
1084 | const auto vmm_dst = vmm_out(ur, ocb); |
1085 | uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp); |
1086 | } |
1087 | } |
1088 | } |
1089 | |
1090 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1091 | const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1; |
1092 | const int scale_offset |
1093 | = jcp_.is_oc_scale * (sizeof(float) * ocb * jcp_.oc_block); |
1094 | |
1095 | const auto vmm_bias_ = vmm_tmp_; |
1096 | if (jcp_.with_bias) { |
1097 | int bias_offset = jcp_.typesize_bia * ocb * jcp_.oc_block; |
1098 | cvt2ps(jcp_.bia_dt, vmm_bias_, reg_bias_, bias_offset, |
1099 | mask_flag ? get_tail_size() : get_blocking_size()); |
1100 | } |
1101 | if (jcp_.signed_input) { |
1102 | const int comp_offset = sizeof(int32_t) * ocb * jcp_.oc_block; |
1103 | cvt2ps(data_type::s32, vmm_comp_, reg_compensation_, comp_offset, |
1104 | mask_flag ? get_tail_size() : get_blocking_size()); |
1105 | } |
1106 | |
        /* apply compensation, scales, and bias to the accumulators */
1108 | if (mask_flag) { |
1109 | load_data(data_type::f32, vmm_scale_, reg_ptr_scales_, scale_offset, |
1110 | get_tail_size()); |
1111 | } else { |
1112 | uni_vmovups(vmm_scale_, ptr[reg_ptr_scales_ + scale_offset]); |
1113 | } |
1114 | for (int ur = 0; ur < ur_w; ur++) { |
1115 | const Vmm vmm = vmm_out(ur, ocb); |
1116 | uni_vcvtdq2ps(vmm, vmm); |
1117 | if (jcp_.signed_input) uni_vaddps(vmm, vmm, vmm_comp_); |
1118 | uni_vmulps(vmm, vmm, vmm_scale_); |
1119 | if (jcp_.with_bias) uni_vaddps(vmm, vmm, vmm_bias_); |
1120 | } |
1121 | } |
1122 | |
1123 | if (p_sum_scale && *p_sum_scale != 1.f) |
1124 | mov(reg_ptr_sum_scale_, reinterpret_cast<size_t>(p_sum_scale)); |
1125 | if (p_sum_zp && *p_sum_zp != 0) { |
1126 | mov(reg_ptr_sum_zp_, reinterpret_cast<size_t>(p_sum_zp)); |
1127 | } |
1128 | if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) |
1129 | apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp); |
1130 | if (jcp_.dst_scale) { |
1131 | mov(reg_ptr_dst_scales_, ptr[param1_ + GET_OFF(dst_scale)]); |
1132 | uni_vmovups(vmm_dst_scale_, ptr[reg_ptr_dst_scales_]); |
1133 | |
1134 | for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) |
1135 | for (int ur = 0; ur < ur_w; ur++) { |
1136 | const auto vmm = vmm_out(ur, ocb); |
1137 | uni_vmulps(vmm, vmm, vmm_dst_scale_); |
1138 | } |
1139 | } |
1140 | if (jcp_.dst_zero_point) { |
1141 | mov(reg_zp_dst_, ptr[param1_ + GET_OFF(dst_zero_point)]); |
1142 | const auto &vmm_zp_dst = vmm_tmp_; |
1143 | uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]); |
1144 | uni_vcvtdq2ps(vmm_zp_dst, vmm_zp_dst); |
1145 | |
1146 | for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) |
1147 | for (int ur = 0; ur < ur_w; ur++) { |
1148 | const auto vmm_dst = vmm_out(ur, ocb); |
1149 | uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst); |
1150 | } |
1151 | } |
1152 | |
1153 | // Properly saturate the accumulators for integer datatypes |
1154 | |
1155 | // No need to saturate on lower bound for signed integer types, as |
1156 | // the conversion to int would return INT_MIN, and then proper |
1157 | // saturation will happen when storing data |
1158 | if (jcp_.dst_dt == data_type::u8) { |
1159 | uni_vpxor(vmm_zero_, vmm_zero_, vmm_zero_); |
1160 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1161 | for (int ur = 0; ur < ur_w; ur++) { |
1162 | const Vmm vmm = vmm_out(ur, ocb); |
1163 | uni_vmaxps(vmm, vmm, vmm_zero_); |
1164 | } |
1165 | } |
1166 | } |
1167 | |
1168 | if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) { |
1169 | float saturation_ubound = types::max_value<float>(jcp_.dst_dt); |
1170 | const Xmm xmm_saturation(vmm_saturation_.getIdx()); |
1171 | mov(reg_ptr_saturation_ubound_, float2int(saturation_ubound)); |
1172 | uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound_); |
1173 | uni_vbroadcastss(vmm_saturation_, xmm_saturation); |
1174 | |
1175 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1176 | for (int ur = 0; ur < ur_w; ur++) { |
1177 | const Vmm vmm = vmm_out(ur, ocb); |
1178 | uni_vminps(vmm, vmm, vmm_saturation_); |
1179 | } |
1180 | } |
1181 | } |
1182 | |
1183 | if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) { |
1184 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1185 | for (int ur = 0; ur < ur_w; ur++) { |
1186 | const Vmm vmm = vmm_out(ur, ocb); |
1187 | uni_vcvtps2dq(vmm, vmm); |
1188 | } |
1189 | } |
1190 | } |
1191 | |
    /* write out registers to the output address */
1193 | for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) { |
1194 | const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1; |
1195 | for (int ur = 0; ur < ur_w; ur++) { |
1196 | const int aux_dst_off = jcp_.typesize_out |
1197 | * (ur * jcp_.ngroups * jcp_.oc_without_padding |
1198 | + ocb * jcp_.oc_block); |
1199 | const Vmm r_vmm = vmm_out(ur, ocb); |
1200 | store_data(jcp_.dst_dt, r_vmm, reg_dst_, aux_dst_off, |
1201 | mask_flag ? get_tail_size() : get_blocking_size()); |
1202 | } |
1203 | } |
1204 | } |
1205 | |
1206 | template <cpu_isa_t isa, typename Vmm> |
1207 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::icb_loop( |
1208 | int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) { |
1209 | |
1210 | const int shift_src_icb = jcp_.typesize_in * jcp_.ic_block; |
1211 | const size_t shift_filt_icb = (size_t)jcp_.typesize_in * jcp_.kd * jcp_.kh |
1212 | * jcp_.kw * jcp_.ic_block * jcp_.oc_block; |
1213 | |
1214 | prepare_output(ur_w); |
1215 | |
1216 | Label skip_icb_loop, icb_loop_label; |
1217 | |
1218 | mov(reg_icb_, jcp_.nb_ic); |
1219 | mov(reg_oc_blocks_, ptr[param1_ + GET_OFF(oc_blocks)]); |
1220 | |
1221 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_) |
1222 | && jcp_.ndims > 3) { |
1223 | mov(reg_scratch_, |
1224 | qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]); |
1225 | mov(zp_src_pad_comp_addr_, reg_scratch_); |
1226 | } |
1227 | |
1228 | L(icb_loop_label); |
1229 | { |
1230 | if (jcp_.ngroups % jcp_.ch_block != 0 |
1231 | || jcp_.ic_without_padding != jcp_.ic) { |
1232 | Label common_ker, end_ker; |
1233 | if (jcp_.is_depthwise) { |
1234 | cmp(reg_oc_blocks_, jcp_.nb_ch - 1); |
1235 | jne(common_ker, T_NEAR); |
1236 | } else { |
1237 | cmp(reg_icb_, 1); |
1238 | jg(common_ker, T_NEAR); |
1239 | } |
1240 | |
1241 | kh_loop(ur_w, l_overflow, r_overflow, last_sp_block); |
1242 | jmp(end_ker, T_NEAR); |
1243 | |
1244 | L(common_ker); |
1245 | kh_loop(ur_w, l_overflow, r_overflow, no_last_block); |
1246 | |
1247 | L(end_ker); |
1248 | } else { |
1249 | kh_loop(ur_w, l_overflow, r_overflow, no_last_block); |
1250 | } |
1251 | |
1252 | add(reg_src_, shift_src_icb); |
1253 | safe_add(reg_filt_, shift_filt_icb, reg_ker_long_offt_); |
1254 | dec(reg_icb_); |
1255 | cmp(reg_icb_, 0); |
1256 | jg(icb_loop_label, T_NEAR); |
1257 | } |
1258 | |
    /* rewind the src and weights pointers */
1260 | sub(reg_src_, jcp_.nb_ic * shift_src_icb); |
1261 | safe_sub(reg_filt_, jcp_.nb_ic * shift_filt_icb, reg_ker_long_offt_); |
1262 | L(skip_icb_loop); |
1263 | |
1264 | if (jcp_.ngroups % jcp_.ch_block != 0 |
1265 | || jcp_.oc_without_padding != jcp_.oc) { |
1266 | Label common_store, end_store; |
1267 | if (jcp_.is_depthwise) |
1268 | cmp(reg_oc_blocks_, jcp_.nb_ch - 1); |
1269 | else |
1270 | cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking); |
1271 | jne(common_store, T_NEAR); |
1272 | |
1273 | store_output(ur_w, true); |
1274 | jmp(end_store, T_NEAR); |
1275 | |
1276 | L(common_store); |
1277 | store_output(ur_w, false); |
1278 | |
1279 | L(end_store); |
1280 | |
1281 | } else { |
1282 | store_output(ur_w, false); |
1283 | } |
1284 | } |
1285 | |
1286 | template <cpu_isa_t isa, typename Vmm> |
1287 | void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() { |
1288 | preamble(); |
1289 | |
1290 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)) |
1291 | sub(rsp, reserved_stack_size_); |
1292 | |
1293 | const auto vmm_one_128 = Xbyak::Xmm(vmm_one_.getIdx()); |
1294 | mov(reg_scratch_, 0x10001); |
1295 | uni_vmovq(vmm_one_128, reg_scratch_); |
1296 | uni_vpbroadcastd(vmm_one_, vmm_one_128); |
1297 | |
1298 | mov(reg_src_, ptr[param1_ + GET_OFF(src)]); |
1299 | mov(reg_filt_, ptr[param1_ + GET_OFF(filt)]); |
1300 | mov(reg_dst_, ptr[param1_ + GET_OFF(dst)]); |
1301 | |
1302 | const int dst_shift = jcp_.typesize_out * jcp_.ur_w * jcp_.ngroups |
1303 | * jcp_.oc_without_padding; |
1304 | const int src_shift = jcp_.typesize_in * (jcp_.ur_w / jcp_.stride_w) |
1305 | * jcp_.ngroups * jcp_.ic_without_padding; |
1306 | |
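    // l_overflow / r_overflow estimate how many leading / trailing output
    // columns involve kernel taps that fall outside the source, per the
    // deconvolution index math; r_overflow1 is the same quantity computed
    // for the block right before the ur_w tail.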
1307 | const int l_overflow = max(0, |
1308 | ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - jcp_.l_pad) / jcp_.stride_w); |
1309 | const int r_overflow = max(0, |
1310 | ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - max(0, jcp_.r_pad)) |
1311 | / jcp_.stride_w); |
1312 | |
1313 | const int r_overflow1 = nstl::max(0, |
1314 | ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - nstl::max(0, jcp_.r_pad) |
1315 | - jcp_.ur_w_tail) |
1316 | / jcp_.stride_w); |
1317 | int nur_w = jcp_.ow / jcp_.ur_w; |
1318 | if (r_overflow1 > 0) nur_w--; |
1319 | |
1320 | if (jcp_.ur_w == jcp_.ow) { |
1321 | icb_loop(jcp_.ur_w, l_overflow, r_overflow, true); |
1322 | } else if (nur_w == 0) { |
1323 | icb_loop(jcp_.ur_w, l_overflow, r_overflow1, jcp_.ur_w_tail == 0); |
1324 | add(reg_src_, src_shift); |
1325 | add(reg_dst_, dst_shift); |
1326 | if (jcp_.ur_w_tail != 0) icb_loop(jcp_.ur_w_tail, 0, r_overflow, true); |
1327 | } else { |
1328 | xor_(reg_nur_w_, reg_nur_w_); |
1329 | if (l_overflow > 0) { |
1330 | icb_loop(jcp_.ur_w, l_overflow, 0, false); |
1331 | add(reg_src_, src_shift); |
1332 | add(reg_dst_, dst_shift); |
1333 | inc(reg_nur_w_); |
1334 | } |
1335 | if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) { |
1336 | Label ow_loop_label; |
1337 | L(ow_loop_label); |
1338 | { |
1339 | icb_loop(jcp_.ur_w, 0, 0, false); |
1340 | add(reg_src_, src_shift); |
1341 | add(reg_dst_, dst_shift); |
1342 | inc(reg_nur_w_); |
1343 | cmp(reg_nur_w_, nur_w); |
1344 | jl(ow_loop_label, T_NEAR); |
1345 | } |
1346 | } |
1347 | if (r_overflow1 > 0) { |
1348 | icb_loop(jcp_.ur_w, 0, r_overflow1, jcp_.ur_w_tail == 0); |
1349 | add(reg_src_, src_shift); |
1350 | add(reg_dst_, dst_shift); |
1351 | } |
1352 | if (jcp_.ur_w_tail != 0) { |
1353 | icb_loop(jcp_.ur_w_tail, 0, r_overflow, true); |
1354 | } |
1355 | } |
1356 | |
1357 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)) |
1358 | add(rsp, reserved_stack_size_); |
1359 | |
1360 | postamble(); |
1361 | |
1362 | if (jcp_.with_eltwise) postops_injector_->prepare_table(); |
1363 | } |
1364 | |
1365 | template <cpu_isa_t isa> |
1366 | jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::jit_uni_x8s8s32x_deconvolution_fwd_t( |
1367 | const pd_t *apd) |
1368 | : primitive_t(apd) {} |
1369 | |
1370 | template <cpu_isa_t isa> |
1371 | jit_uni_x8s8s32x_deconvolution_fwd_t< |
1372 | isa>::~jit_uni_x8s8s32x_deconvolution_fwd_t() |
1373 | = default; |
1374 | |
1375 | template <cpu_isa_t isa> |
1376 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::pd_t::init( |
1377 | engine_t *engine) { |
1378 | using namespace data_type; |
1379 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
1380 | const bool ok = true && is_fwd() |
1381 | && (desc()->alg_kind & alg_kind::deconvolution_direct) |
1382 | && utils::one_of(src_md(0)->data_type, s8, u8) |
1383 | && weights_md(0)->data_type == s8 |
1384 | && IMPLICATION(with_bias(), |
1385 | utils::one_of(weights_md(1)->data_type, f32, s32, s8, u8)) |
1386 | && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8) |
1387 | && desc()->accum_data_type == s32 |
1388 | && attr()->has_default_values(skip_mask_t::scales_runtime |
1389 | | skip_mask_t::post_ops | skip_mask_t::zero_points_runtime); |
1390 | if (!ok) return status::unimplemented; |
1391 | |
1392 | CHECK(jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(jcp_, *desc(), |
1393 | src_md_, weights_md_, dst_md_, with_bias(), bias_md_, attr_, |
1394 | dnnl_get_max_threads())); |
1395 | |
1396 | auto scratchpad = scratchpad_registry().registrar(); |
1397 | jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad( |
1398 | scratchpad, jcp_, *attr()); |
1399 | |
1400 | return status::success; |
1401 | } |
1402 | |
1403 | template <cpu_isa_t isa> |
1404 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::init(engine_t *engine) { |
1405 | CHECK(safe_ptr_assign(kernel_, |
1406 | new jit_uni_x8s8s32x_deconv_fwd_kernel<isa>(pd()->jcp_, |
1407 | *pd()->attr(), memory_desc_wrapper(pd()->dst_md())))); |
1408 | |
1409 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) { |
1410 | CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_, |
1411 | zp::create_deconv_zp_pad_str_comp_ker<isa>(pd()->jcp_))); |
1412 | const auto zp_kernel_status = zp_src_pad_comp_kernel_->create_kernel(); |
1413 | if (zp_kernel_status != status::success) return zp_kernel_status; |
1414 | } |
1415 | |
1416 | return kernel_->create_kernel(); |
1417 | } |
1418 | |
1419 | template <cpu_isa_t isa> |
1420 | const float *jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::adjust_oscales( |
1421 | const memory_tracking::grantor_t &scratchpad, const float *src_scales, |
1422 | const float *wei_scales) const { |
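    // Fold the src and weights scales into a single table consumed by the
    // kernel; on the non-VNNI s8s8 path the weights were pre-scaled by
    // wei_adj_scale at reorder time (see init_conf()), so divide it back
    // out here.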
1423 | auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales); |
1424 | int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
1425 | float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni)) |
1426 | ? 1.f / pd()->jcp_.wei_adj_scale |
1427 | : 1.0f; |
1428 | if (wei_mask == 0) { |
1429 | utils::array_set(loc_scales, src_scales[0] * wei_scales[0] * factor, 8); |
1430 | } else { |
1431 | for (dim_t c = 0; c < pd()->OC(); c++) |
1432 | loc_scales[c] = src_scales[0] * wei_scales[c] * factor; |
1433 | } |
1434 | return loc_scales; |
1435 | } |
1436 | |
1437 | template <cpu_isa_t isa> |
1438 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute( |
1439 | const exec_ctx_t &ctx) const { |
1440 | const auto &_pd = pd(); |
1441 | const auto &ndims = _pd->ndims(); |
1442 | if (ndims == 3) |
1443 | return execute_forward_1d(ctx); |
1444 | else if (ndims == 4) |
1445 | return execute_forward_2d(ctx); |
1446 | else if (ndims == 5) |
1447 | return execute_forward_3d(ctx); |
1448 | else |
1449 | return status::unimplemented; |
1450 | return status::success; |
1451 | } |
1452 | |
1453 | template <cpu_isa_t isa> |
1454 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_1d( |
1455 | const exec_ctx_t &ctx) const { |
1456 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1457 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1458 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1459 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1460 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1461 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1462 | |
1463 | const auto &jcp = pd()->jcp_; |
1464 | |
1465 | const memory_desc_wrapper src_d(pd()->src_md()); |
1466 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1467 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1468 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1469 | |
1470 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1471 | |
1472 | const auto post_ops_binary_rhs_arg_vec |
1473 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1474 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1475 | int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1476 | |
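    /* Precompute the source zero-point compensation for the padded/strided
     * output border into the scratchpad before launching the main loop. */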
1477 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1478 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1479 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1480 | zp_src_pad_comp_kernel_.get()); |
1481 | |
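    // Work is partitioned over mb x groups x oc chunks, where one chunk
    // covers nb_oc_blocking blocks of oc_block channels.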
1482 | const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
1483 | const int nb_groups = jcp.nb_ch; |
1484 | |
1485 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1486 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1487 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1488 | |
    const float *oscales
            = adjust_oscales(scratchpad, src_scales, wei_scales);
1491 | |
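    /* The s8s8 compensation (and zero-point compensation, if present) is
     * stored in an extra buffer appended past the weights; recover its
     * offset from the descriptor sizes. */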
1492 | const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); |
1493 | auto w = const_cast<int8_t *>(weights); |
1494 | int32_t *compensation = (jcp.signed_input) |
1495 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1496 | : nullptr; |
1497 | const int32_t *zp_compensation = jcp.src_zero_point |
1498 | ? get_src_zp_comp_from_wei( |
1499 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1500 | : nullptr; |
1501 | |
1502 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1503 | int start {0}, end {0}; |
1504 | const int work_amount = jcp.mb * nb_groups * oc_chunks; |
1505 | balance211(work_amount, nthr, ithr, start, end); |
1506 | |
1507 | auto p = jit_deconv_call_s(); |
1508 | |
1509 | int n {0}, g {0}, occ {0}; |
1510 | if (jcp.loop_order == loop_ngc) |
1511 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks); |
1512 | else if (jcp.loop_order == loop_cgn) |
1513 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb); |
1514 | else |
            assert(!"unsupported loop order");
1516 | while (start < end) { |
1517 | |
1518 | const int ocb = occ * jcp.nb_oc_blocking; |
1519 | const int g_oc |
1520 | = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1521 | const int g_ic = g * jcp.ch_block * jcp.ic; |
1522 | |
1523 | p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc); |
1524 | p.src = src + src_d.blk_off(n, g_ic); |
1525 | p.filt = weights + wht_blk_off(weights_d, g, ocb, 0); |
1526 | p.bias = jcp.with_bias |
1527 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1528 | : nullptr; |
1529 | p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1530 | p.scales = &oscales[jcp.is_oc_scale * g_oc]; |
1531 | p.dst_scale = dst_scales; |
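            // 1d case: there is no height dimension, so no top/bottom
            // overflow is possible and the full filter height is used.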
1532 | p.t_overflow = 0; |
1533 | p.b_overflow = 0; |
1534 | p.kh_padding = jcp.kh; |
1535 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1536 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); |
1537 | p.oc_l_off = g_oc; |
1538 | p.zp_compensation |
1539 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1540 | p.zp_src_pad_str_compensation = zp_src_comp_scratch |
1541 | ? zp_src_comp_scratch + g_oc |
1542 | : nullptr; |
1543 | p.src_zero_point = zp_src; |
1544 | p.dst_zero_point = zp_dst; |
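            // dst_orig keeps the tensor base address so the binary post-ops
            // injector can recover the logical offset of the current block.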
1545 | p.dst_orig = dst; |
1546 | |
1547 | (*kernel_)(&p); |
1548 | |
1549 | ++start; |
1550 | if (jcp.loop_order == loop_ngc) |
1551 | nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks); |
1552 | else if (jcp.loop_order == loop_cgn) |
1553 | nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb); |
1554 | else |
                assert(!"unsupported loop order");
1556 | } |
1557 | }); |
1558 | return status::success; |
1559 | } |
1560 | |
1561 | template <cpu_isa_t isa> |
1562 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_2d( |
1563 | const exec_ctx_t &ctx) const { |
1564 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1565 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1566 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1567 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1568 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1569 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1570 | |
1571 | const memory_desc_wrapper src_d(pd()->src_md()); |
1572 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1573 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1574 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1575 | |
1576 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1577 | |
1578 | const auto &jcp = pd()->jcp_; |
1579 | const auto post_ops_binary_rhs_arg_vec |
1580 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1581 | |
1582 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1583 | int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1584 | |
1585 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1586 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1587 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1588 | zp_src_pad_comp_kernel_.get()); |
1589 | |
1590 | const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
1591 | const int nb_groups = jcp.nb_ch; |
1592 | |
1593 | const size_t src_h_stride = src_d.blk_off(0, 0, 1); |
1594 | const size_t dst_h_stride = dst_d.blk_off(0, 0, 1); |
1595 | const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1); |
1596 | |
1597 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1598 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1599 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1600 | |
    const float *oscales
            = adjust_oscales(scratchpad, src_scales, wei_scales);
1603 | |
1604 | const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); |
1605 | auto w = const_cast<int8_t *>(weights); |
1606 | int32_t *compensation = (jcp.signed_input) |
1607 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1608 | : nullptr; |
1609 | const int32_t *zp_compensation = jcp.src_zero_point |
1610 | ? get_src_zp_comp_from_wei( |
1611 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1612 | : nullptr; |
1613 | |
1614 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1615 | int start {0}, end {0}; |
1616 | const int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh; |
1617 | balance211(work_amount, nthr, ithr, start, end); |
1618 | |
1619 | auto p = jit_deconv_call_s(); |
1620 | |
        /* init the nd-iterator according to the selected loop order */
1622 | int n {0}, g {0}, occ {0}, oh_s {0}; |
1623 | if (jcp.loop_order == loop_ngc) |
1624 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, |
1625 | oh_s, jcp.oh); |
1626 | else if (jcp.loop_order == loop_cgn) |
1627 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, |
1628 | oh_s, jcp.oh); |
1629 | else |
            assert(!"unsupported loop order");
1631 | while (start < end) { |
1632 | |
1633 | const int ocb = occ * jcp.nb_oc_blocking; |
1634 | const int g_oc |
1635 | = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1636 | const int g_ic = g * jcp.ch_block * jcp.ic; |
1637 | const int work_rem = end - start; |
1638 | const int oh_e |
1639 | = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; |
1640 | |
1641 | const auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc); |
1642 | const auto src_w = src + src_d.blk_off(n, g_ic); |
1643 | const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0); |
1644 | const auto bias_w = jcp.with_bias |
1645 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1646 | : nullptr; |
1647 | const int32_t *compensation_w |
1648 | = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1649 | |
1650 | const auto scales = &oscales[jcp.is_oc_scale * g_oc]; |
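            /* For each output row compute the range of filter taps that hit
             * valid input rows: kh_lo is the first tap, kh_len the number of
             * taps, and ih_max the topmost input row used. With stride > 1
             * only every stride_h-th tap lines up with an input row, hence
             * the modulo arithmetic in the strided branch. */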
1651 | for (int oj = oh_s; oj < oh_e; oj++) { |
1652 | int ih_max = 0, kh_lo = 0, kh_len = 0; |
1653 | if (jcp.dilate_h != 0 && jcp.stride_h == 1) { |
1654 | /* dilation */ |
1655 | const int dilate_h = jcp.dilate_h + 1; |
1656 | // Note: use div_up to account for "holes" in filter |
1657 | const int o_t_overflow = div_up( |
1658 | max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), |
1659 | dilate_h); |
1660 | const int o_b_overflow |
1661 | = div_up(max(0, |
1662 | (jcp.kh - 1) * dilate_h + 1 |
1663 | - jcp.oh + oj - jcp.b_pad), |
1664 | dilate_h); |
1665 | kh_len = jcp.kh - o_t_overflow - o_b_overflow; |
1666 | kh_lo = o_b_overflow; |
1667 | ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; |
1668 | } else { |
1669 | const int o_t_overflow = max( |
1670 | 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); |
1671 | const int o_b_overflow = max(0, |
1672 | ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) |
1673 | / jcp.stride_h); |
1674 | const int overflow_kh_hi = jcp.kh - 1 |
1675 | - modulo(jcp.oh + jcp.b_pad - (oj + 1), |
1676 | jcp.stride_h); |
1677 | const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; |
1678 | |
1679 | kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h |
1680 | + 1 - o_t_overflow - o_b_overflow; |
1681 | kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; |
1682 | ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; |
1683 | } |
1684 | |
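                /* With signed input or a source zero point the kernel walks
                 * the full filter and handles the padded taps itself (needed
                 * for compensation), so filt stays at tap 0; otherwise jump
                 * straight to the first valid tap. */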
1685 | const int wei_stride |
1686 | = (!jcp.signed_input && !jcp.src_zero_point) |
1687 | ? kh_lo * wht_kh_stride |
1688 | : 0; |
1689 | p.src = src_w + ih_max * src_h_stride; |
1690 | p.dst = dst_w + dst_dt_size * oj * dst_h_stride; |
1691 | p.filt = wht_w + wei_stride; |
1692 | p.bias = bias_w; |
1693 | p.compensation = compensation_w; |
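                /* Taps outside the window of used taps are reported to the
                 * kernel as top/bottom overflow; dilation and stride > 1 are
                 * not supported together, hence two separate formulas. */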
1694 | p.t_overflow = jcp.dilate_h > 0 |
1695 | ? jcp.kh - kh_len - kh_lo |
1696 | : max(0, |
1697 | jcp.kh |
1698 | - (kh_lo |
1699 | + max(0, kh_len - 1) |
1700 | * jcp.stride_h |
1701 | + 1)); |
1702 | p.b_overflow = kh_lo; |
1703 | p.kh_padding = kh_len; |
1704 | p.scales = scales; |
1705 | p.dst_scale = dst_scales; |
1706 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1707 | p.post_ops_binary_rhs_arg_vec |
1708 | = post_ops_binary_rhs_arg_vec.data(); |
1709 | p.oc_l_off = g_oc; |
1710 | p.zp_compensation |
1711 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1712 | p.zp_src_pad_str_compensation = jcp.src_zero_point |
1713 | ? zp_src_comp_scratch + g_oc |
1714 | : nullptr; |
1715 | p.src_zero_point = zp_src; |
1716 | p.dst_zero_point = zp_dst; |
1717 | p.dst_orig = dst; |
1718 | |
1719 | (*kernel_)(&p); |
1720 | } |
1721 | if (jcp.loop_order == loop_ngc) |
1722 | nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, |
1723 | oc_chunks, oh_s, jcp.oh); |
1724 | else if (jcp.loop_order == loop_cgn) |
1725 | nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, |
1726 | jcp.mb, oh_s, jcp.oh); |
1727 | else |
                assert(!"unsupported loop order");
1729 | } |
1730 | }); |
1731 | return status::success; |
1732 | } |
1733 | |
1734 | template <cpu_isa_t isa> |
1735 | status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_3d( |
1736 | const exec_ctx_t &ctx) const { |
1737 | const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
1738 | const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
1739 | const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
1740 | const auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
1741 | DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC); |
1742 | DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST); |
1743 | |
1744 | const memory_desc_wrapper src_d(pd()->src_md()); |
1745 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
1746 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
1747 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
1748 | |
1749 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
1750 | |
1751 | const auto &jcp = pd()->jcp_; |
1752 | const auto post_ops_binary_rhs_arg_vec |
1753 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
1754 | |
1755 | const auto scratchpad = ctx.get_scratchpad_grantor(); |
1756 | int32_t *const zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp); |
1757 | |
1758 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) |
1759 | zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(), |
1760 | weights_d, weights, zp_src, zp_src_comp_scratch, |
1761 | zp_src_pad_comp_kernel_.get()); |
1762 | |
1763 | const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
    const int nb_groups = jcp.nb_ch;
1765 | |
1766 | const size_t src_d_stride = src_d.blk_off(0, 0, 1); |
1767 | const size_t src_h_stride = src_d.blk_off(0, 0, 0, 1); |
1768 | const size_t dst_d_stride = dst_d.blk_off(0, 0, 1); |
1769 | const size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1); |
1770 | const size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1); |
1771 | const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1); |
1772 | |
1773 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
1774 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
1775 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
1776 | |
    const float *oscales
            = adjust_oscales(scratchpad, src_scales, wei_scales);
1779 | |
1780 | const size_t offset = weights_d.size() - weights_d.additional_buffer_size(); |
1781 | auto w = const_cast<int8_t *>(weights); |
1782 | int32_t *compensation = (jcp.signed_input) |
1783 | ? reinterpret_cast<int32_t *>(&w[offset]) |
1784 | : nullptr; |
1785 | const int32_t *zp_compensation = jcp.src_zero_point |
1786 | ? get_src_zp_comp_from_wei( |
1787 | weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc) |
1788 | : nullptr; |
1789 | |
1790 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1791 | int start {0}, end {0}; |
        const int work_amount
                = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh;
1793 | balance211(work_amount, nthr, ithr, start, end); |
1794 | |
1795 | auto p = jit_deconv_call_s(); |
1796 | |
        /* init the nd-iterator according to the selected loop order */
1798 | int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0}; |
1799 | if (jcp.loop_order == loop_ngc) |
1800 | nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks, |
1801 | od_s, jcp.od, oh_s, jcp.oh); |
1802 | else if (jcp.loop_order == loop_cgn) |
1803 | nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb, |
1804 | od_s, jcp.od, oh_s, jcp.oh); |
1805 | else |
            assert(!"unsupported loop order");
1807 | while (start < end) { |
1808 | |
1809 | const int ocb = occ * jcp.nb_oc_blocking; |
1810 | const int g_oc |
1811 | = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block; |
1812 | const int g_ic = g * jcp.ch_block * jcp.ic; |
1813 | const int work_rem = end - start; |
1814 | const int oh_e |
1815 | = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem; |
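            /* Depth analogue of the per-row tap computation below: kd_lo and
             * kd_len select the filter taps along kd that hit valid input,
             * and input_d_s is the first input depth used. */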
1816 | int input_d_s = 0, kd_len = 0, kd_lo = 0; |
1818 | |
1819 | if (jcp.dilate_d != 0 && jcp.stride_d == 1) { |
1820 | /* dilation */ |
                const int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                const int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad),
                        dilate_d);
                const int d_back_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.od
                                                 + od_s - jcp.back_pad),
                                dilate_d);
1831 | kd_len = jcp.kd - d_t_overflow - d_back_overflow; |
1832 | kd_lo = d_back_overflow; |
1833 | input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d; |
1834 | } else { |
1835 | const int d_t_overflow = max( |
1836 | 0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d); |
1837 | const int d_back_overflow = max(0, |
1838 | ((od_s + jcp.kd) - (jcp.od + jcp.back_pad)) |
1839 | / jcp.stride_d); |
1840 | const int overflow_kd_hi = jcp.kd - 1 |
1841 | - modulo(jcp.od + jcp.back_pad - (od_s + 1), |
1842 | jcp.stride_d); |
1843 | const int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d; |
1844 | |
1845 | kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1 |
1846 | - d_t_overflow - d_back_overflow; |
1847 | kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d; |
1848 | input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d; |
1849 | } |
1850 | |
1851 | auto dst_w = dst |
1852 | + dst_dt_size |
1853 | * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride); |
1854 | const auto src_w |
1855 | = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride; |
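            /* As with the height taps: when compensation is involved the
             * kernel needs the full filter, so the kd offset is applied only
             * in the plain case. */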
1856 | const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0) |
1857 | + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo) |
1858 | * wht_kd_stride; |
1859 | const auto bias_w = jcp.with_bias |
1860 | ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) |
1861 | : nullptr; |
1862 | const int32_t *compensation_w |
1863 | = (jcp.signed_input) ? compensation + g_oc : nullptr; |
1864 | |
1865 | const auto scales = &oscales[jcp.is_oc_scale * g_oc]; |
1866 | |
1867 | for (int oj = oh_s; oj < oh_e; oj++) { |
1868 | int ih_max = 0, kh_lo = 0, kh_len = 0; |
1869 | if (jcp.dilate_h != 0 && jcp.stride_h == 1) { |
1870 | /* dilation */ |
1871 | const int dilate_h = jcp.dilate_h + 1; |
1872 | // Note: use div_up to account for "holes" in filter |
1873 | const int o_t_overflow = div_up( |
1874 | max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad), |
1875 | dilate_h); |
1876 | const int o_b_overflow |
1877 | = div_up(max(0, |
1878 | (jcp.kh - 1) * dilate_h + 1 |
1879 | - jcp.oh + oj - jcp.b_pad), |
1880 | dilate_h); |
1881 | kh_len = jcp.kh - o_t_overflow - o_b_overflow; |
1882 | kh_lo = o_b_overflow; |
1883 | ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h; |
1884 | } else { |
1885 | const int o_t_overflow = max( |
1886 | 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); |
1887 | const int o_b_overflow = max(0, |
1888 | ((oj + jcp.kh) - (jcp.oh + jcp.b_pad)) |
1889 | / jcp.stride_h); |
1890 | const int overflow_kh_hi = jcp.kh - 1 |
1891 | - modulo(jcp.oh + jcp.b_pad - (oj + 1), |
1892 | jcp.stride_h); |
1893 | const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h; |
1894 | |
1895 | kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h |
1896 | + 1 - o_t_overflow - o_b_overflow; |
1897 | kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h; |
1898 | ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h; |
1899 | } |
1900 | |
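                // Same first-valid-tap logic as in execute_forward_2d.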
1901 | const int wei_stride |
1902 | = (!jcp.signed_input && !jcp.src_zero_point) |
1903 | ? kh_lo * wht_kh_stride |
1904 | : 0; |
1905 | p.src = src_w + ih_max * src_h_stride; |
1906 | p.dst = dst_w + dst_dt_size * oj * dst_h_stride; |
1907 | p.filt = wht_w + wei_stride; |
1908 | p.bias = bias_w; |
1909 | p.compensation = compensation_w; |
1910 | /* Note: Currently this kernel doesn't support dilations and |
1911 | strides together */ |
1912 | p.t_overflow = jcp.dilate_h > 0 |
1913 | ? jcp.kh - kh_len - kh_lo |
1914 | : max(0, |
1915 | jcp.kh |
1916 | - (kh_lo |
1917 | + max(0, kh_len - 1) |
1918 | * jcp.stride_h |
1919 | + 1)); |
1920 | p.b_overflow = kh_lo; |
1921 | p.f_overflow = jcp.dilate_d > 0 |
1922 | ? jcp.kd - kd_len - kd_lo |
1923 | : max(0, |
1924 | jcp.kd |
1925 | - (kd_lo |
1926 | + max(0, kd_len - 1) |
1927 | * jcp.stride_d |
1928 | + 1)); |
1929 | p.back_overflow = kd_lo; |
1930 | p.kh_padding = kh_len; |
1931 | p.kd_padding = kd_len; |
1932 | p.scales = scales; |
1933 | p.dst_scale = dst_scales; |
1934 | p.oc_blocks = jcp.is_depthwise ? g : ocb; |
1935 | p.post_ops_binary_rhs_arg_vec |
1936 | = post_ops_binary_rhs_arg_vec.data(); |
1937 | p.oc_l_off = g_oc; |
1938 | p.zp_compensation |
1939 | = jcp.src_zero_point ? zp_compensation + g_oc : nullptr; |
1940 | p.zp_src_pad_str_compensation = jcp.src_zero_point |
1941 | ? zp_src_comp_scratch + g_oc |
1942 | : nullptr; |
1943 | p.src_zero_point = zp_src; |
1944 | p.dst_zero_point = zp_dst; |
1945 | p.dst_orig = dst; |
1946 | |
1947 | (*kernel_)(&p); |
1948 | } |
1949 | |
1950 | if (jcp.loop_order == loop_ngc) |
1951 | nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ, |
1952 | oc_chunks, od_s, jcp.od, oh_s, jcp.oh); |
1953 | else if (jcp.loop_order == loop_cgn) |
1954 | nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n, |
1955 | jcp.mb, od_s, jcp.od, oh_s, jcp.oh); |
1956 | else |
                assert(!"unsupported loop order");
1958 | } |
1959 | }); |
1960 | |
1961 | return status::success; |
1962 | } |
1963 | |
1965 | template struct jit_uni_x8s8s32x_deconvolution_fwd_t<avx2>; |
1966 | template struct jit_uni_x8s8s32x_deconvolution_fwd_t<sse41>; |
1967 | template struct jit_uni_x8s8s32x_deconv_fwd_kernel<avx2>; |
1968 | template struct jit_uni_x8s8s32x_deconv_fwd_kernel<sse41>; |
1969 | template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>; |
1970 | template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Xmm>; |
1971 | template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<sse41, Xbyak::Xmm>; |
1972 | } // namespace x64 |
1973 | } // namespace cpu |
1974 | } // namespace impl |
1975 | } // namespace dnnl |
1976 | |