1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cassert>
18#include "common/dnnl_thread.hpp"
19#include "common/utils.hpp"
20#include "cpu/x64/jit_primitive_conf.hpp"
21#include <type_traits>
22
23#include "jit_uni_deconv_zp_pad_str_kernel.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29namespace zp {
30
// Base kernel ctor: captures the convolution config and precomputes the
// channel tail size (number of valid lanes in the last, partially filled
// block; 0 means the last block is full). Depthwise kernels block over
// groups, all other kernels block over output channels.
jit_uni_deconv_zp_pad_str_kernel_base_t::
        jit_uni_deconv_zp_pad_str_kernel_base_t(const jit_conv_conf_t &jcp)
    : jit_generator(jit_name())
    , jcp_(jcp)
    , tail_size_(jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
                                  : jcp.oc_without_padding % jcp.oc_block) {}
37
38size_t jit_uni_deconv_zp_pad_str_kernel_base_t::reserve_vmm() {
39 return number_reserved_vmms_++;
40}
41
// Top-level code generation: emits the whole kernel as a fixed pipeline.
// The order of the stages below is the order of the emitted machine code,
// so it must not be rearranged.
void jit_uni_deconv_zp_pad_str_kernel_base_t::generate() {
    preamble();
    load_addresses(); // read kernel arguments from abi_param1
    init(); // zero the accumulator, prepare masks/constants
    compute(); // accumulate weights over input-channel blocks
    apply_zero_point(); // multiply the accumulator by the src zero point
    store_result(); // write (possibly tail-masked) result to scratchpad
    postamble();
}
51
52void jit_uni_deconv_zp_pad_str_kernel_base_t::compute() {
53
54 const dim_t outer_icb_step = jcp_.kd * jcp_.kh * jcp_.kw * jcp_.ic_block
55 * jcp_.oc_block * jcp_.ch_block;
56 const dim_t inner_icb_step = jcp_.oc_block * jcp_.ch_block * 4;
57 const bool ic_tail_exists = jcp_.ic_without_padding % jcp_.ic_block;
58
59 for (dim_t icb = 0; icb < jcp_.nb_ic; ++icb) {
60 const bool is_last_icb = icb == jcp_.nb_ic - 1;
61
62 const int n_inner_ic_blk = jcp_.is_depthwise
63 ? 1
64 : (is_last_icb && ic_tail_exists ? utils::div_up(
65 jcp_.ic_without_padding % jcp_.ic_block, 4)
66 : (jcp_.ic_block / 4));
67
68 const dim_t outer_wei_offset = icb * outer_icb_step;
69
70 for (int inner_icb = 0; inner_icb < n_inner_ic_blk; inner_icb++) {
71 const dim_t inner_wei_offset
72 = outer_wei_offset + inner_icb * inner_icb_step;
73
74 compute_step(inner_wei_offset);
75 }
76 }
77}
78
// ISA-specialized ctor: reserves vector registers from the base-class pool.
// Registers that a given configuration never touches (e.g. the ones-vectors
// when VNNI or depthwise is used) are aliased to vmm 0 instead of consuming
// a slot, leaving more registers for get_next_vmm(). The init-list order
// matters: each reserve_vmm() call advances number_reserved_vmms_, and
// current_vmm_ must be initialized after all reservations are done.
template <cpu_isa_t isa, typename Vmm>
jit_uni_deconv_zp_pad_str_kernel_t<isa,
        Vmm>::jit_uni_deconv_zp_pad_str_kernel_t(const jit_conv_conf_t &jcp)
    : jit_uni_deconv_zp_pad_str_kernel_base_t(jcp)
    , result_acc_(reserve_vmm())
    , vmm_tmp_((jcp.has_vnni || jcp.is_depthwise) ? 0 : reserve_vmm())
    , vmm_one_bytes_(jcp.is_depthwise ? 0 : reserve_vmm())
    , vmm_one_words_((jcp.has_vnni || jcp.is_depthwise) ? 0 : reserve_vmm())
    , current_vmm_(number_reserved_vmms_) {}
88
// Emits kernel prologue computation: zeroes the accumulator, prepares the
// opmask for tail stores (Zmm/avx512 only) and materializes the constant
// ones-vectors used to widen int8 weights via multiply-add.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::init() {
    uni_vpxor(result_acc_, result_acc_, result_acc_);

    if (std::is_same<Vmm, Xbyak::Zmm>::value) {
        // Build a k-mask with the low tail_size_ bits set; used by
        // store_result() to write only the valid tail lanes.
        const int mask = (1 << tail_size_) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp_.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask_, regw_tmp);
    }

    if (!jcp_.is_depthwise) {

        const auto reg32_scratch = reg_tmp_.cvt32();
        // fill register byte ones
        const Xbyak::Xmm xmm_one {vmm_one_bytes_.getIdx()};

        // 0x01010101 broadcast across the register: a vector of byte 1s,
        // used as the "unsigned" operand of vpdpbusd / vpmaddubsw so the
        // result is a plain sum of the signed weight bytes.
        mov(reg32_scratch, 0x1010101);
        if (isa == sse41)
            movd(xmm_one, reg32_scratch);
        else
            vmovd(xmm_one, reg32_scratch);
        uni_vbroadcastss(vmm_one_bytes_, xmm_one);

        if (!jcp_.has_vnni) {
            // Without VNNI the dot product is emulated with
            // vpmaddubsw + vpmaddwd; the latter needs a vector of word 1s.
            const Xbyak::Xmm xmm_one_words
                    = Xbyak::Xmm(vmm_one_words_.getIdx());
            mov(reg_tmp_, 0x10001);
            uni_vmovq(xmm_one_words, reg_tmp_);
            uni_vpbroadcastd(vmm_one_words_, xmm_one_words);
        }
    }
}
122
123template <cpu_isa_t isa, typename Vmm>
124Vmm jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::get_next_vmm() {
125 static constexpr int max_v_regs = cpu_isa_traits<isa>::n_vregs;
126
127 const Vmm vmm {static_cast<int>(current_vmm_++)};
128
129 if (current_vmm_ == max_v_regs) current_vmm_ = number_reserved_vmms_;
130
131 return vmm;
132}
133
// Emits one accumulation step: loads a chunk of int8 weights at icb_offset
// and adds their per-lane sums into result_acc_, using the fastest
// widening path the configuration allows.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::compute_step(
        const dim_t icb_offset) {
    const auto wei_vmm = get_next_vmm();

    if (jcp_.is_depthwise)
        // One channel per lane: sign-extend bytes straight to dwords.
        uni_vpmovsxbd(wei_vmm, ptr[reg_wei_ + icb_offset]);
    else
        uni_vmovups(wei_vmm, ptr[reg_wei_ + icb_offset]);

    if (jcp_.is_depthwise)
        uni_vpaddd(result_acc_, result_acc_, wei_vmm);
    else if (jcp_.has_vnni)
        // acc += dot(1u8, wei_s8): sums groups of 4 weight bytes per dword.
        // Force the encoding explicitly — EVEX on avx512, VEX otherwise.
        vpdpbusd(result_acc_, vmm_one_bytes_, wei_vmm,
                is_superset(isa, avx512_core) ? Xbyak::EvexEncoding
                                              : Xbyak::VexEncoding);
    else {
        // VNNI emulation: bytes * 1 -> words, words * 1 -> dwords, then add.
        uni_vpmaddubsw(vmm_tmp_, vmm_one_bytes_, wei_vmm);
        uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_words_);
        uni_vpaddd(result_acc_, result_acc_, vmm_tmp_);
    }
}
156
// Tail-store helper, primary template: selected when isa < avx512_core
// (T deduced as std::true_type). Pre-avx512 ISAs have no opmask registers,
// so the partial store is done byte-wise; the opmask parameter is unused
// and exists only to keep the call signature uniform with the avx512
// specialization below.
template <cpu_isa_t isa, typename Vmm,
        typename T = std::integral_constant<bool, (isa < avx512_core)>>
struct helper_store_t {
    static void store(jit_generator *gen, const Vmm &vmm,
            const Xbyak::Reg64 &reg_dst, const size_t size,
            const Xbyak::Opmask &opmask) {
        gen->store_bytes(vmm, reg_dst, 0, size);
    }
};
166
// std::false_type matches the default argument (isa < avx512_core) being
// false, i.e. this alias names the "isa >= avx512_core" case.
using isa_at_least_avx512_core = std::false_type;
// Tail-store helper for avx512_core+: a single masked vmovups writes only
// the lanes enabled in the opmask (set up in init()); size is unused here.
template <cpu_isa_t isa, typename Vmm>
struct helper_store_t<isa, Vmm, isa_at_least_avx512_core> {
    static void store(jit_generator *gen, const Vmm &vmm,
            const Xbyak::Reg64 &reg_dst, const size_t size,
            const Xbyak::Opmask &opmask) {
        using namespace Xbyak::util;
        gen->vmovups(gen->ptr[reg_dst], vmm | opmask);
    }
};
177
// Emits the result store: when a channel tail exists and this invocation
// processes the last oc/group block (reg_last_oc_block_ != 0), only the
// valid tail lanes are written; otherwise a full-vector store is emitted.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::store_result() {

    Xbyak::Label store_no_tail, end;

    if (tail_size_) {
        // Runtime branch: the same kernel binary serves both full and
        // tail blocks, selected via the last_oc_block call parameter.
        cmp(reg_last_oc_block_, 0);
        je(store_no_tail, T_NEAR);
        helper_store_t<isa, Vmm>::store(this, result_acc_, reg_dst_,
                tail_size_ * sizeof(int32_t), ktail_mask_);
        jmp(end, T_NEAR);
    }

    L(store_no_tail);
    { uni_vmovups(ptr[reg_dst_], result_acc_); }

    L(end);
}
196
// Emits the final scaling: broadcasts the (single) source zero-point value
// and multiplies the accumulated weight sums by it lane-wise.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::apply_zero_point() {
    const auto zp_src_vmm = get_next_vmm();
    uni_vbroadcastss(zp_src_vmm, ptr[reg_src_zp_]);
    uni_vpmulld(result_acc_, result_acc_, zp_src_vmm);
}
203
#define PARAM_OFF(x) offsetof(jit_uni_deconv_zp_pad_str_call_params_t, x)

// Emits loads of the call parameters (passed as a struct pointer in
// abi_param1) into dedicated registers. reg_last_oc_block_ is only needed
// when a channel tail exists — see store_result().
void jit_uni_deconv_zp_pad_str_kernel_base_t::load_addresses() {

    mov(reg_src_zp_, ptr[abi_param1 + PARAM_OFF(src_zero_point)]);
    mov(reg_wei_, ptr[abi_param1 + PARAM_OFF(wei)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst_scratchpad)]);
    if (tail_size_)
        mov(reg_last_oc_block_, ptr[abi_param1 + PARAM_OFF(last_oc_block)]);
}

#undef PARAM_OFF
216
217template <cpu_isa_t isa,
218 typename T = std::integral_constant<bool, (isa < avx512_core)>>
219struct helper_create_deconv_ker_t {
220 static jit_uni_deconv_zp_pad_str_kernel_base_t *
221 create_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp) {
222
223 const int ch_block = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block;
224 switch (ch_block) {
225 case 8:
226 if (isa == avx2) {
227 return new jit_uni_deconv_zp_pad_str_kernel_t<avx2,
228 Xbyak::Ymm>(jcp);
229 } else
230 assert(!"invalid channel blocking for current ISA");
231 case 4:
232 return new jit_uni_deconv_zp_pad_str_kernel_t<isa, Xbyak::Xmm>(
233 jcp);
234 default: assert(!"invalid channel blocking");
235 }
236
237 return nullptr;
238 }
239};
240
241template <cpu_isa_t isa>
242struct helper_create_deconv_ker_t<isa, isa_at_least_avx512_core> {
243 static jit_uni_deconv_zp_pad_str_kernel_base_t *
244 create_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp) {
245 const int ch_block = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block;
246 switch (ch_block) {
247 case 16:
248 return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
249 Xbyak::Zmm>(jcp);
250 case 8:
251 return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
252 Xbyak::Ymm>(jcp);
253 case 4:
254 return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
255 Xbyak::Xmm>(jcp);
256 default: assert(!"invalid channel blocking");
257 }
258
259 return nullptr;
260 }
261};
262
// Public factory entry point: dispatches at compile time to the helper
// specialization matching the ISA tier (below vs. at-least avx512_core).
// Returns a heap-allocated kernel owned by the caller.
template <cpu_isa_t isa>
jit_uni_deconv_zp_pad_str_kernel_base_t *create_deconv_zp_pad_str_comp_ker(
        const jit_conv_conf_t &jcp) {

    return helper_create_deconv_ker_t<isa>::create_deconv_zp_pad_str_comp_ker(
            jcp);
}
270
271#define wht_blk_off(d, g, ...) \
272 (with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__))
273
274static dim_t wei_off(const memory_desc_wrapper &wei_d, const bool with_groups,
275 const dim_t ch_b, const dim_t oc_b, const dim_t d, const dim_t h,
276 const dim_t w) {
277
278 const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0);
279
280 switch (ndims) {
281 case 5: return wht_blk_off(wei_d, ch_b, oc_b, 0, d, h, w);
282 case 4: return wht_blk_off(wei_d, ch_b, oc_b, 0, h, w);
283 case 3: return wht_blk_off(wei_d, ch_b, oc_b, 0, w);
284 default: assert("Unsupported ndims!");
285 }
286
287 return 0;
288}
289
290static dim_t dst_off(const jit_conv_conf_t &jcp, const dim_t ndims,
291 const dim_t g, const dim_t oc, const dim_t d, const dim_t h,
292 const dim_t w) {
293
294 const auto &G = jcp.ngroups;
295 const auto &OC = jcp.oc_without_padding;
296 const auto &OW = jcp.kw;
297 const auto &OH = jcp.kh;
298
299 dim_t offset = w;
300
301 if (ndims == 5)
302 offset += d * OH * OW + h * OW;
303 else if (ndims == 4)
304 offset += h * OW;
305
306 if (G == 1) return offset * OC + oc;
307
308 return (offset * OC * G) + g * OC + oc;
309}
310
// Runs the JIT kernel over every (channel block, oc block, kernel point)
// combination, filling the zero-point pad/stride compensation scratchpad.
// Work is distributed across threads with balance211; iteration order
// follows jcp.loop_order so offsets advance coherently with the kernel's
// expected layout.
void compute_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp,
        const bool with_groups, const memory_desc_wrapper &wei_d,
        const int8_t *wei, const int32_t *src_zp, int32_t *dst,
        jit_uni_deconv_zp_pad_str_kernel_base_t *ker) {

    using namespace dnnl::impl::utils;
    const auto work_amount = jcp.nb_ch * jcp.nb_oc * jcp.kw * jcp.kd * jcp.kh;
    /*
     * Heuristics for parallel computation usage - cost of threads creation
     * may exceed the computation time which leads to performance drop
     */
    static constexpr int parallelization_ratio_thr = 5;
    const int nthrs = (work_amount / jcp.nthr) > parallelization_ratio_thr
            ? jcp.nthr
            : 1;

    parallel(nthrs, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        // Decompose this thread's linear work range into the 5D iteration
        // coordinates matching the configured loop order.
        int ch_b {0}, oc_b {0}, d {0}, h {0}, w {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, ch_b, jcp.nb_ch, oc_b, jcp.nb_oc, d, jcp.kd,
                    h, jcp.kh, w, jcp.kw);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, oc_b, jcp.nb_oc, ch_b, jcp.nb_ch, d, jcp.kd,
                    h, jcp.kh, w, jcp.kw);

        for (auto iwork = start; iwork < end; ++iwork) {
            jit_uni_deconv_zp_pad_str_call_params_t params;
            const auto oc = oc_b * jcp.oc_block;
            const auto g = ch_b * jcp.ch_block;
            params.wei = wei + wei_off(wei_d, with_groups, ch_b, oc_b, d, h, w);
            params.src_zero_point = src_zp;
            // Signals the kernel to use the tail-masked store for the last
            // (possibly partial) block — see store_result().
            params.last_oc_block = jcp.is_depthwise ? ch_b == jcp.nb_ch - 1
                                                    : oc_b == jcp.nb_oc - 1;
            params.dst_scratchpad = dst
                    + dst_off(jcp, wei_d.ndims() - (with_groups ? 1 : 0), g, oc,
                            d, h, w);

            (*ker)(&params);

            if (jcp.loop_order == loop_ngc)
                nd_iterator_step(ch_b, jcp.nb_ch, oc_b, jcp.nb_oc, d, jcp.kd, h,
                        jcp.kh, w, jcp.kw);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_step(oc_b, jcp.nb_oc, ch_b, jcp.nb_ch, d, jcp.kd, h,
                        jcp.kh, w, jcp.kw);
            else
                assert(!"unsupported loop order");
        }
    });
}
364
365static bool stride_exists(const jit_conv_conf_t &jcp) noexcept {
366 return jcp.stride_d > 1 || jcp.stride_w > 1 || jcp.stride_h > 1;
367}
368
369static bool padding_exists(const jit_conv_conf_t &jcp) noexcept {
370 const auto dd = jcp.dilate_d + 1;
371 const auto dh = jcp.dilate_h + 1;
372 const auto dw = jcp.dilate_w + 1;
373 return jcp.kw - jcp.l_pad / dw - 1 || jcp.kw - jcp.r_pad / dw - 1
374 || jcp.kh - jcp.t_pad / dh - 1 || jcp.kh - jcp.b_pad / dh - 1
375 || jcp.kd - jcp.f_pad / dd - 1 || jcp.kd - jcp.back_pad / dd - 1;
376}
377
378bool should_calculate_deconv_zp_src_pad_str_comp(
379 const jit_conv_conf_t &jcp) noexcept {
380 return jcp.src_zero_point && (stride_exists(jcp) || padding_exists(jcp));
381}
382
// Explicit instantiations for every ISA tier served by this factory.
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<sse41>(const jit_conv_conf_t &jcp);
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<avx2>(const jit_conv_conf_t &jcp);
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<avx512_core>(const jit_conv_conf_t &jcp);
389
390} // namespace zp
391} // namespace x64
392} // namespace cpu
393} // namespace impl
394} // namespace dnnl
395