1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/utils.hpp" |
20 | #include "cpu/x64/jit_primitive_conf.hpp" |
21 | #include <type_traits> |
22 | |
23 | #include "jit_uni_deconv_zp_pad_str_kernel.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | namespace zp { |
30 | |
// Base kernel ctor: caches the convolution config and computes the channel
// tail size — remainder of groups over ch_block for depthwise kernels,
// otherwise remainder of unpadded output channels over oc_block.
jit_uni_deconv_zp_pad_str_kernel_base_t::
        jit_uni_deconv_zp_pad_str_kernel_base_t(const jit_conv_conf_t &jcp)
    : jit_generator(jit_name())
    , jcp_(jcp)
    , tail_size_(jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
                                  : jcp.oc_without_padding % jcp.oc_block) {}
37 | |
38 | size_t jit_uni_deconv_zp_pad_str_kernel_base_t::reserve_vmm() { |
39 | return number_reserved_vmms_++; |
40 | } |
41 | |
// Top-level JIT code emission: standard prologue, load call-parameter
// addresses, initialize constants and the accumulator, accumulate weights,
// multiply by the source zero-point, store the compensation, epilogue.
void jit_uni_deconv_zp_pad_str_kernel_base_t::generate() {
    preamble();
    load_addresses();
    init();
    compute();
    apply_zero_point();
    store_result();
    postamble();
}
51 | |
// Emits the accumulation loop over all input-channel blocks for one
// (group, oc block, kd, kh, kw) point; each compute_step handles one
// 4-input-channel slice of the weights.
void jit_uni_deconv_zp_pad_str_kernel_base_t::compute() {

    // element stride between consecutive input-channel blocks in the
    // weights buffer
    const dim_t outer_icb_step = jcp_.kd * jcp_.kh * jcp_.kw * jcp_.ic_block
            * jcp_.oc_block * jcp_.ch_block;
    // within a block, weights come in groups of 4 input channels interleaved
    // with the oc block (vnni-friendly layout), hence the factor of 4
    const dim_t inner_icb_step = jcp_.oc_block * jcp_.ch_block * 4;
    const bool ic_tail_exists = jcp_.ic_without_padding % jcp_.ic_block;

    for (dim_t icb = 0; icb < jcp_.nb_ic; ++icb) {
        const bool is_last_icb = icb == jcp_.nb_ic - 1;

        // depthwise handles one channel per lane (single step); otherwise
        // the last block may carry fewer than ic_block channels, rounded up
        // to whole groups of 4
        const int n_inner_ic_blk = jcp_.is_depthwise
                ? 1
                : (is_last_icb && ic_tail_exists ? utils::div_up(
                           jcp_.ic_without_padding % jcp_.ic_block, 4)
                                                 : (jcp_.ic_block / 4));

        const dim_t outer_wei_offset = icb * outer_icb_step;

        for (int inner_icb = 0; inner_icb < n_inner_ic_blk; inner_icb++) {
            const dim_t inner_wei_offset
                    = outer_wei_offset + inner_icb * inner_icb_step;

            compute_step(inner_wei_offset);
        }
    }
}
78 | |
// Derived kernel ctor: reserves the permanently-used vector registers.
// Registers a given configuration never touches (e.g. vmm_tmp_ when vnni is
// available, the "ones" constants for depthwise) are aliased to index 0
// instead of consuming a slot; current_vmm_ marks the first index available
// for temporaries.
template <cpu_isa_t isa, typename Vmm>
jit_uni_deconv_zp_pad_str_kernel_t<isa,
        Vmm>::jit_uni_deconv_zp_pad_str_kernel_t(const jit_conv_conf_t &jcp)
    : jit_uni_deconv_zp_pad_str_kernel_base_t(jcp)
    , result_acc_(reserve_vmm())
    , vmm_tmp_((jcp.has_vnni || jcp.is_depthwise) ? 0 : reserve_vmm())
    , vmm_one_bytes_(jcp.is_depthwise ? 0 : reserve_vmm())
    , vmm_one_words_((jcp.has_vnni || jcp.is_depthwise) ? 0 : reserve_vmm())
    , current_vmm_(number_reserved_vmms_) {}
88 | |
// Emits initialization code: zero the accumulator, set up the tail opmask
// (zmm kernels only), and materialize the broadcast constants (bytes of 0x01,
// words of 0x0001) used by the s8 multiply-accumulate emulation.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::init() {
    uni_vpxor(result_acc_, result_acc_, result_acc_);

    if (std::is_same<Vmm, Xbyak::Zmm>::value) {
        // opmask with the low tail_size_ bits set; used by the tail store
        const int mask = (1 << tail_size_) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp_.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask_, regw_tmp);
    }

    if (!jcp_.is_depthwise) {

        const auto reg32_scratch = reg_tmp_.cvt32();
        // fill register byte ones
        const Xbyak::Xmm xmm_one {vmm_one_bytes_.getIdx()};

        mov(reg32_scratch, 0x1010101);
        if (isa == sse41)
            movd(xmm_one, reg32_scratch);
        else
            vmovd(xmm_one, reg32_scratch);
        uni_vbroadcastss(vmm_one_bytes_, xmm_one);

        if (!jcp_.has_vnni) {
            // vector of 16-bit ones for the vpmaddwd step of the vnni
            // emulation sequence in compute_step
            const Xbyak::Xmm xmm_one_words
                    = Xbyak::Xmm(vmm_one_words_.getIdx());
            mov(reg_tmp_, 0x10001);
            uni_vmovq(xmm_one_words, reg_tmp_);
            uni_vpbroadcastd(vmm_one_words_, xmm_one_words);
        }
    }
}
122 | |
123 | template <cpu_isa_t isa, typename Vmm> |
124 | Vmm jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::get_next_vmm() { |
125 | static constexpr int max_v_regs = cpu_isa_traits<isa>::n_vregs; |
126 | |
127 | const Vmm vmm {static_cast<int>(current_vmm_++)}; |
128 | |
129 | if (current_vmm_ == max_v_regs) current_vmm_ = number_reserved_vmms_; |
130 | |
131 | return vmm; |
132 | } |
133 | |
// Emits one accumulation step: load a weights vector at icb_offset and add
// it into result_acc_ (for non-depthwise, summing over groups of 4 adjacent
// s8 input channels per dword lane).
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::compute_step(
        const dim_t icb_offset) {
    const auto wei_vmm = get_next_vmm();

    if (jcp_.is_depthwise)
        // depthwise: widen s8 weights to s32 lanes on load
        uni_vpmovsxbd(wei_vmm, ptr[reg_wei_ + icb_offset]);
    else
        uni_vmovups(wei_vmm, ptr[reg_wei_ + icb_offset]);

    if (jcp_.is_depthwise)
        uni_vpaddd(result_acc_, result_acc_, wei_vmm);
    else if (jcp_.has_vnni)
        // acc += dot(u8 ones, s8 weights) over groups of 4 bytes per lane
        vpdpbusd(result_acc_, vmm_one_bytes_, wei_vmm,
                is_superset(isa, avx512_core) ? Xbyak::EvexEncoding
                                              : Xbyak::VexEncoding);
    else {
        // vnni emulation: 1*w byte pairs -> words, then words*1 -> dwords
        uni_vpmaddubsw(vmm_tmp_, vmm_one_bytes_, wei_vmm);
        uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_words_);
        uni_vpaddd(result_acc_, result_acc_, vmm_tmp_);
    }
}
156 | |
// Tail-store helper dispatched on ISA. Primary template handles ISAs below
// avx512_core (T is std::true_type from the (isa < avx512_core) test):
// no opmask registers are available, so the tail is stored byte-wise.
template <cpu_isa_t isa, typename Vmm,
        typename T = std::integral_constant<bool, (isa < avx512_core)>>
struct helper_store_t {
    static void store(jit_generator *gen, const Vmm &vmm,
            const Xbyak::Reg64 &reg_dst, const size_t size,
            const Xbyak::Opmask &opmask) {
        // opmask is unused below avx512_core
        gen->store_bytes(vmm, reg_dst, 0, size);
    }
};
166 | |
// (isa < avx512_core) is false for avx512_core and wider ISAs, so
// std::false_type selects this specialization, which uses an opmask-guarded
// vector store for the tail.
using isa_at_least_avx512_core = std::false_type;
template <cpu_isa_t isa, typename Vmm>
struct helper_store_t<isa, Vmm, isa_at_least_avx512_core> {
    static void store(jit_generator *gen, const Vmm &vmm,
            const Xbyak::Reg64 &reg_dst, const size_t size,
            const Xbyak::Opmask &opmask) {
        using namespace Xbyak::util;
        // size is implied by the mask; only lanes enabled in opmask are
        // written
        gen->vmovups(gen->ptr[reg_dst], vmm | opmask);
    }
};
177 | |
// Emits the final store of the compensation values. When a channel tail
// exists, a runtime flag (reg_last_oc_block_) selects between the
// masked/partial tail store and the full vector store.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::store_result() {

    Xbyak::Label store_no_tail, end;

    if (tail_size_) {
        cmp(reg_last_oc_block_, 0);
        je(store_no_tail, T_NEAR);
        helper_store_t<isa, Vmm>::store(this, result_acc_, reg_dst_,
                tail_size_ * sizeof(int32_t), ktail_mask_);
        jmp(end, T_NEAR);
    }

    L(store_no_tail);
    { uni_vmovups(ptr[reg_dst_], result_acc_); }

    L(end);
}
196 | |
// Emits code multiplying the accumulated weight sums by the broadcast source
// zero-point value, yielding the padding/stride compensation terms.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_deconv_zp_pad_str_kernel_t<isa, Vmm>::apply_zero_point() {
    const auto zp_src_vmm = get_next_vmm();
    uni_vbroadcastss(zp_src_vmm, ptr[reg_src_zp_]);
    uni_vpmulld(result_acc_, result_acc_, zp_src_vmm);
}
203 | |
204 | #define PARAM_OFF(x) offsetof(jit_uni_deconv_zp_pad_str_call_params_t, x) |
205 | |
// Emits loads of the call-parameter pointers into dedicated registers;
// abi_param1 points to a jit_uni_deconv_zp_pad_str_call_params_t.
void jit_uni_deconv_zp_pad_str_kernel_base_t::load_addresses() {

    mov(reg_src_zp_, ptr[abi_param1 + PARAM_OFF(src_zero_point)]);
    mov(reg_wei_, ptr[abi_param1 + PARAM_OFF(wei)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst_scratchpad)]);
    if (tail_size_)
        // flag only needed when a tail store may be selected at runtime
        mov(reg_last_oc_block_, ptr[abi_param1 + PARAM_OFF(last_oc_block)]);
}
214 | |
215 | #undef PARAM_OFF |
216 | |
217 | template <cpu_isa_t isa, |
218 | typename T = std::integral_constant<bool, (isa < avx512_core)>> |
219 | struct helper_create_deconv_ker_t { |
220 | static jit_uni_deconv_zp_pad_str_kernel_base_t * |
221 | create_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp) { |
222 | |
223 | const int ch_block = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block; |
224 | switch (ch_block) { |
225 | case 8: |
226 | if (isa == avx2) { |
227 | return new jit_uni_deconv_zp_pad_str_kernel_t<avx2, |
228 | Xbyak::Ymm>(jcp); |
229 | } else |
230 | assert(!"invalid channel blocking for current ISA" ); |
231 | case 4: |
232 | return new jit_uni_deconv_zp_pad_str_kernel_t<isa, Xbyak::Xmm>( |
233 | jcp); |
234 | default: assert(!"invalid channel blocking" ); |
235 | } |
236 | |
237 | return nullptr; |
238 | } |
239 | }; |
240 | |
// Kernel factory specialization for avx512_core and wider ISAs: selects a
// zmm, ymm or xmm kernel depending on the channel block size. Returns a
// heap-allocated kernel (caller owns it), or nullptr on an invalid config.
template <cpu_isa_t isa>
struct helper_create_deconv_ker_t<isa, isa_at_least_avx512_core> {
    static jit_uni_deconv_zp_pad_str_kernel_base_t *
    create_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp) {
        const int ch_block = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block;
        switch (ch_block) {
            case 16:
                return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
                        Xbyak::Zmm>(jcp);
            case 8:
                return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
                        Xbyak::Ymm>(jcp);
            case 4:
                return new jit_uni_deconv_zp_pad_str_kernel_t<avx512_core,
                        Xbyak::Xmm>(jcp);
            default: assert(!"invalid channel blocking" );
        }

        return nullptr;
    }
};
262 | |
// Public factory entry point: forwards to the ISA-dispatched helper above.
template <cpu_isa_t isa>
jit_uni_deconv_zp_pad_str_kernel_base_t *create_deconv_zp_pad_str_comp_ker(
        const jit_conv_conf_t &jcp) {

    return helper_create_deconv_ker_t<isa>::create_deconv_zp_pad_str_comp_ker(
            jcp);
}
270 | |
271 | #define wht_blk_off(d, g, ...) \ |
272 | (with_groups ? (d).blk_off((g), __VA_ARGS__) : (d).blk_off(__VA_ARGS__)) |
273 | |
274 | static dim_t wei_off(const memory_desc_wrapper &wei_d, const bool with_groups, |
275 | const dim_t ch_b, const dim_t oc_b, const dim_t d, const dim_t h, |
276 | const dim_t w) { |
277 | |
278 | const auto ndims = wei_d.ndims() - (with_groups ? 1 : 0); |
279 | |
280 | switch (ndims) { |
281 | case 5: return wht_blk_off(wei_d, ch_b, oc_b, 0, d, h, w); |
282 | case 4: return wht_blk_off(wei_d, ch_b, oc_b, 0, h, w); |
283 | case 3: return wht_blk_off(wei_d, ch_b, oc_b, 0, w); |
284 | default: assert("Unsupported ndims!" ); |
285 | } |
286 | |
287 | return 0; |
288 | } |
289 | |
290 | static dim_t dst_off(const jit_conv_conf_t &jcp, const dim_t ndims, |
291 | const dim_t g, const dim_t oc, const dim_t d, const dim_t h, |
292 | const dim_t w) { |
293 | |
294 | const auto &G = jcp.ngroups; |
295 | const auto &OC = jcp.oc_without_padding; |
296 | const auto &OW = jcp.kw; |
297 | const auto &OH = jcp.kh; |
298 | |
299 | dim_t offset = w; |
300 | |
301 | if (ndims == 5) |
302 | offset += d * OH * OW + h * OW; |
303 | else if (ndims == 4) |
304 | offset += h * OW; |
305 | |
306 | if (G == 1) return offset * OC + oc; |
307 | |
308 | return (offset * OC * G) + g * OC + oc; |
309 | } |
310 | |
// Drives the JIT kernel over every (channel block, oc block, kd, kh, kw)
// point, writing the per-position zero-point compensation into the dst
// scratchpad. Work is split across threads with balance211 and iterated in
// the order given by jcp.loop_order.
void compute_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp,
        const bool with_groups, const memory_desc_wrapper &wei_d,
        const int8_t *wei, const int32_t *src_zp, int32_t *dst,
        jit_uni_deconv_zp_pad_str_kernel_base_t *ker) {

    using namespace dnnl::impl::utils;
    const auto work_amount = jcp.nb_ch * jcp.nb_oc * jcp.kw * jcp.kd * jcp.kh;
    /*
     * Heuristics for parallel computation usage - cost of threads creation
     * may exceed the computation time which leads to performance drop
     */
    static constexpr int parallelization_ratio_thr = 5;
    const int nthrs = (work_amount / jcp.nthr) > parallelization_ratio_thr
            ? jcp.nthr
            : 1;

    parallel(nthrs, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        // decompose this thread's linear start index into loop coordinates
        int ch_b {0}, oc_b {0}, d {0}, h {0}, w {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, ch_b, jcp.nb_ch, oc_b, jcp.nb_oc, d, jcp.kd,
                    h, jcp.kh, w, jcp.kw);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, oc_b, jcp.nb_oc, ch_b, jcp.nb_ch, d, jcp.kd,
                    h, jcp.kh, w, jcp.kw);

        for (auto iwork = start; iwork < end; ++iwork) {
            jit_uni_deconv_zp_pad_str_call_params_t params;
            const auto oc = oc_b * jcp.oc_block;
            const auto g = ch_b * jcp.ch_block;
            params.wei = wei + wei_off(wei_d, with_groups, ch_b, oc_b, d, h, w);
            params.src_zero_point = src_zp;
            // lets the kernel pick the tail store on the last channel block
            params.last_oc_block = jcp.is_depthwise ? ch_b == jcp.nb_ch - 1
                                                    : oc_b == jcp.nb_oc - 1;
            params.dst_scratchpad = dst
                    + dst_off(jcp, wei_d.ndims() - (with_groups ? 1 : 0), g, oc,
                            d, h, w);

            (*ker)(&params);

            // advance loop coordinates in the same order used for init
            if (jcp.loop_order == loop_ngc)
                nd_iterator_step(ch_b, jcp.nb_ch, oc_b, jcp.nb_oc, d, jcp.kd, h,
                        jcp.kh, w, jcp.kw);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_step(oc_b, jcp.nb_oc, ch_b, jcp.nb_ch, d, jcp.kd, h,
                        jcp.kh, w, jcp.kw);
            else
                assert(!"unsupported loop order" );
        }
    });
}
364 | |
365 | static bool stride_exists(const jit_conv_conf_t &jcp) noexcept { |
366 | return jcp.stride_d > 1 || jcp.stride_w > 1 || jcp.stride_h > 1; |
367 | } |
368 | |
369 | static bool padding_exists(const jit_conv_conf_t &jcp) noexcept { |
370 | const auto dd = jcp.dilate_d + 1; |
371 | const auto dh = jcp.dilate_h + 1; |
372 | const auto dw = jcp.dilate_w + 1; |
373 | return jcp.kw - jcp.l_pad / dw - 1 || jcp.kw - jcp.r_pad / dw - 1 |
374 | || jcp.kh - jcp.t_pad / dh - 1 || jcp.kh - jcp.b_pad / dh - 1 |
375 | || jcp.kd - jcp.f_pad / dd - 1 || jcp.kd - jcp.back_pad / dd - 1; |
376 | } |
377 | |
378 | bool should_calculate_deconv_zp_src_pad_str_comp( |
379 | const jit_conv_conf_t &jcp) noexcept { |
380 | return jcp.src_zero_point && (stride_exists(jcp) || padding_exists(jcp)); |
381 | } |
382 | |
// Explicit instantiations of the factory for every supported ISA family.
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<sse41>(const jit_conv_conf_t &jcp);
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<avx2>(const jit_conv_conf_t &jcp);
template jit_uni_deconv_zp_pad_str_kernel_base_t *
create_deconv_zp_pad_str_comp_ker<avx512_core>(const jit_conv_conf_t &jcp);
389 | |
390 | } // namespace zp |
391 | } // namespace x64 |
392 | } // namespace cpu |
393 | } // namespace impl |
394 | } // namespace dnnl |
395 | |