1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP |
19 | |
20 | #include <functional> |
21 | #include <vector> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/memory.hpp" |
26 | #include "common/nstl.hpp" |
27 | #include "common/primitive.hpp" |
28 | #include "common/type_helpers.hpp" |
29 | #include "common/utils.hpp" |
30 | |
31 | #include "cpu/cpu_deconvolution_pd.hpp" |
32 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
33 | #include "cpu/x64/jit_generator.hpp" |
34 | #include "cpu/x64/jit_primitive_conf.hpp" |
35 | #include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
// Flags describing which kind of block the kernel code is being emitted for.
// Each enumerator occupies its own bit, so values can be tested with bitwise
// operators.
enum ker_block_t {
    no_last_block = 1U << 0,
    last_ic_block = 1U << 1,
    last_sp_block = 1U << 2,
};
47 | |
// Describes how the output width is split into ur_w blocks: one entry per
// block plus counters for the blocks that need special handling.
struct ur_w_blks_params_t {
    // Parameters of a single ur_w block.
    struct single_ur_w_blk_params_t {
        // Number of spatial elements of weights standing out of the src
        // spatial when computing the first output pixel of the current block.
        int l_overflow;
        // Number of spatial elements of weights standing out of the src
        // spatial when computing the last output pixel of the current block.
        int r_overflow;
        // True when loading the last src spatial element needed for the last
        // dst spatial element of the block can't be done by fetching 4 src
        // spatial elements at once.
        bool process_sp_carefully;

        single_ur_w_blk_params_t(
                int l_overflow, int r_overflow, bool process_sp_carefully)
            : l_overflow {l_overflow}
            , r_overflow {r_overflow}
            , process_sp_carefully {process_sp_carefully} {}
    };

    // One entry per ur_w block.
    std::vector<single_ur_w_blk_params_t> blks_params;
    // Number of blocks with l_overflow > 0.
    int num_pre_blks;
    // Number of blocks with r_overflow > 0 or that need to be processed
    // carefully.
    int num_post_blks;
};
72 | |
73 | template <typename Vmm> |
74 | struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { |
75 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); |
76 | |
77 | jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
78 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
79 | ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel(); |
80 | |
81 | const jit_conv_conf_t &jcp; |
82 | const primitive_attr_t &attr_; |
83 | |
84 | private: |
85 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>> |
86 | postops_injector_; |
87 | |
88 | const int ic_sub_step = 4; |
89 | |
90 | /* data regs */ |
91 | const Xbyak::Reg64 reg_src = r8; |
92 | const Xbyak::Reg64 reg_filt = r9; |
93 | const Xbyak::Reg64 reg_dst = r10; |
94 | const Xbyak::Reg64 param1 = abi_param1; |
95 | const Xbyak::Reg64 reg_kh = abi_not_param1; |
96 | const Xbyak::Reg64 reg_ki = r14; |
97 | |
98 | const Xbyak::Reg64 reg_nur_w = rbx; |
99 | const Xbyak::Reg64 reg_bias = rdx; |
100 | const Xbyak::Reg64 reg_icb = reg_bias; |
101 | const Xbyak::Reg64 reg_ptr_scales = rax; |
102 | const Xbyak::Reg64 reg_ptr_dst_scales = rax; |
103 | const Xbyak::Reg64 reg_ptr_saturation_ubound = rax; |
104 | const Xbyak::Reg64 reg_oc_blocks = rsi; |
105 | |
106 | const Xbyak::Reg64 aux_reg_src = r11; |
107 | const Xbyak::Reg64 aux_reg_filt = r12; |
108 | |
109 | const Xbyak::Reg64 aux_reg_src_d = r13; |
110 | const Xbyak::Reg64 aux_reg_filt_d = r15; |
111 | |
112 | const Xbyak::Reg64 reg_compensation = r14; |
113 | const Xbyak::Reg64 reg_scratch = r14; |
114 | const Xbyak::Reg64 reg_ptr_sum_scale = r11; |
115 | const Xbyak::Reg64 reg_overflow = rax; |
116 | const Xbyak::Reg64 reg_comp_strides = reg_overflow; |
117 | const Xbyak::Reg64 reg_ker_long_offt = r15; |
118 | const Xbyak::Reg64 ®_zp_dst_ = r15; |
119 | const Xbyak::Reg64 ®_zp_src_ = r15; |
120 | const Xbyak::Reg64 ®_zp_compensation = r11; |
121 | static constexpr int reserved_stack_size_ = 16; |
122 | const Xbyak::Address zp_src_pad_comp_addr = ptr[rsp]; |
123 | const Xbyak::Address reg_scratch_preserved = ptr[rsp + 8]; |
124 | |
125 | Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); |
126 | const Vmm vmm_tmp = Vmm(28); |
127 | const Vmm vmm_one = Vmm(29); |
128 | /* used during write-out section of store_output */ |
129 | const Vmm vmm_zero = Vmm(31); |
130 | const Vmm vmm_saturation = Vmm(31); |
131 | const Vmm vmm_wei = Vmm(31); |
132 | |
133 | /* signed input */ |
134 | const Vmm vmm_shift = Vmm(30); |
135 | const Vmm vmm_comp = Vmm(30); |
136 | const Vmm vmm_bias = Vmm(31); |
137 | const Vmm vmm_dst_scale = Vmm(31); |
138 | const Vmm vmm_prev_dst = Vmm(31); |
139 | |
140 | Vmm vmm_out(int i_ur, int i_oc) { |
141 | int idx = i_ur * jcp.nb_oc_blocking + i_oc; |
142 | assert(idx < 31); |
143 | return Vmm(idx); |
144 | } |
145 | Vmm vmm_inp(int i_ic, int nb_x_blocking) const { |
146 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
147 | assert(idx < 31); |
148 | return Vmm(idx); |
149 | } |
150 | |
151 | int get_ow_start(int ki, int l_overflow) { |
152 | int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w |
153 | + l_overflow * jcp.stride_w |
154 | - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); |
155 | while (res < 0) |
156 | res += jcp.stride_w; |
157 | return res; |
158 | } |
159 | |
160 | int get_ow_end(int ur_w, int ki, int r_overflow) { |
161 | if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) |
162 | ur_w += nstl::min(0, jcp.r_pad); // remove negative padding |
163 | int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w |
164 | + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); |
165 | while (res < 0) |
166 | res += jcp.stride_w; |
167 | return ur_w - res; |
168 | } |
169 | |
170 | int get_blocking_size() const noexcept; |
171 | int get_tail_size() const noexcept; |
172 | void prepare_output(int ur_w); |
173 | void store_output(int ur_w, bool last_oc_block); |
174 | void compute(const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src); |
175 | std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const |
176 | noexcept; |
177 | void apply_zp_src_pad_str_comp( |
178 | int ur_w, int l_overflow, int r_overflow, bool h_padded); |
179 | void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow, |
180 | bool h_padded, bool last_oc_block); |
181 | void compute_ker(int ur_w, int l_overflow, int r_overflow, |
182 | ker_block_t last_ic_block_flag, bool h_padded = false); |
183 | void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); |
184 | void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); |
185 | |
186 | ur_w_blks_params_t get_ur_w_blks_params(); |
187 | |
188 | void generate() override; |
189 | void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, |
190 | bool mask_flag); |
191 | }; |
192 | |
193 | struct _jit_avx512_core_x8s8s32x_deconv_fwd_kernel { |
194 | _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
195 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
196 | : kernel_(nullptr) { |
197 | |
198 | int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; |
199 | switch (ch_block) { |
200 | case 16: |
201 | kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
202 | Xbyak::Zmm>(ajcp, attr, dst_md); |
203 | return; |
204 | case 8: |
205 | kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
206 | Xbyak::Ymm>(ajcp, attr, dst_md); |
207 | return; |
208 | case 4: |
209 | kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< |
210 | Xbyak::Xmm>(ajcp, attr, dst_md); |
211 | return; |
212 | default: assert(!"invalid channel blocking" ); |
213 | } |
214 | } |
215 | |
216 | status_t create_kernel() { return kernel_->create_kernel(); } |
217 | |
218 | ~_jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { delete kernel_; } |
219 | |
220 | void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); } |
221 | |
222 | static bool post_ops_ok(jit_conv_conf_t &jcp, primitive_attr_t &attr, |
223 | const memory_desc_wrapper &dst_d); |
224 | |
225 | static status_t init_conf(jit_conv_conf_t &jcp, |
226 | const deconvolution_desc_t &cd, memory_desc_t &src_md, |
227 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
228 | const bool with_bias, memory_desc_t &bias_md, |
229 | primitive_attr_t &attr, int nthreads); |
230 | |
231 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
232 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
233 | |
234 | private: |
235 | DNNL_DISALLOW_COPY_AND_ASSIGN(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel); |
236 | jit_generator *kernel_; |
237 | }; |
238 | |
239 | struct jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t { |
240 | struct pd_t : public cpu_deconvolution_fwd_pd_t { |
241 | using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; |
242 | |
243 | DECLARE_COMMON_PD_T( |
244 | JIT_IMPL_NAME_HELPER("jit_deconvolution:" , |
245 | ((jcp_.has_vnni) ? avx512_core_vnni : avx512_core), "" ), |
246 | jit_avx512_core_x8s8s32x_deconvolution_fwd_t); |
247 | |
248 | status_t init(engine_t *engine) { |
249 | using namespace data_type; |
250 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
251 | const bool ok = is_fwd() |
252 | && (desc()->alg_kind & alg_kind::deconvolution_direct) |
253 | && utils::one_of(src_md(0)->data_type, s8, u8) |
254 | && weights_md(0)->data_type == s8 |
255 | && IMPLICATION(with_bias(), |
256 | utils::one_of( |
257 | weights_md(1)->data_type, f32, s32, s8, u8)) |
258 | && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8) |
259 | && desc()->accum_data_type == s32 |
260 | && attr()->has_default_values(skip_mask_t::scales_runtime |
261 | | skip_mask_t::post_ops |
262 | | skip_mask_t::zero_points_runtime); |
263 | if (!ok) return status::unimplemented; |
264 | |
265 | CHECK(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(jcp_, |
266 | *desc(), src_md_, weights_md_, dst_md_, with_bias(), |
267 | bias_md_, attr_, dnnl_get_max_threads())); |
268 | |
269 | auto scratchpad = scratchpad_registry().registrar(); |
270 | _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( |
271 | scratchpad, jcp_, *attr()); |
272 | |
273 | return status::success; |
274 | } |
275 | |
276 | jit_conv_conf_t jcp_; |
277 | }; |
278 | |
279 | jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) |
280 | : primitive_t(apd) {} |
281 | |
282 | status_t init(engine_t *engine) override { |
283 | CHECK(safe_ptr_assign(kernel_, |
284 | new _jit_avx512_core_x8s8s32x_deconv_fwd_kernel( |
285 | pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); |
286 | |
287 | if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) { |
288 | CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_, |
289 | zp::create_deconv_zp_pad_str_comp_ker<avx512_core>( |
290 | pd()->jcp_))); |
291 | const auto zp_kernel_status |
292 | = zp_src_pad_comp_kernel_->create_kernel(); |
293 | if (zp_kernel_status != status::success) return zp_kernel_status; |
294 | } |
295 | |
296 | return kernel_->create_kernel(); |
297 | } |
298 | |
299 | status_t execute(const exec_ctx_t &ctx) const override { |
300 | auto ndims = pd()->ndims(); |
301 | if (ndims == 3) |
302 | return execute_forward_1d(ctx); |
303 | else if (ndims == 4) |
304 | return execute_forward_2d(ctx); |
305 | else if (ndims == 5) |
306 | return execute_forward_3d(ctx); |
307 | return status::runtime_error; |
308 | } |
309 | |
310 | private: |
311 | status_t execute_forward_1d(const exec_ctx_t &ctx) const; |
312 | status_t execute_forward_2d(const exec_ctx_t &ctx) const; |
313 | status_t execute_forward_3d(const exec_ctx_t &ctx) const; |
314 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
315 | std::unique_ptr<_jit_avx512_core_x8s8s32x_deconv_fwd_kernel> kernel_; |
316 | std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t> |
317 | zp_src_pad_comp_kernel_; |
318 | const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad, |
319 | const float *src_scales, const float *wei_scales) const; |
320 | }; |
321 | |
322 | } // namespace x64 |
323 | } // namespace cpu |
324 | } // namespace impl |
325 | } // namespace dnnl |
326 | |
327 | #endif |
328 | |
329 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
330 | |