1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP
18#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP
19
#include <functional>
#include <memory>
#include <vector>
22
23#include "common/c_types_map.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/memory.hpp"
26#include "common/nstl.hpp"
27#include "common/primitive.hpp"
28#include "common/type_helpers.hpp"
29#include "common/utils.hpp"
30
31#include "cpu/cpu_deconvolution_pd.hpp"
32#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
33#include "cpu/x64/jit_generator.hpp"
34#include "cpu/x64/jit_primitive_conf.hpp"
35#include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace cpu {
40namespace x64 {
41
// Tail-handling flags passed to the kernel generation loops. The values are
// distinct bits; each loop below receives one of them to indicate whether the
// current block is a full block or the input-channel / spatial tail block.
// (Plain 'enum' instead of the C-style 'typedef enum' — same type name,
// same values, idiomatic C++.)
enum ker_block_t {
    no_last_block = 0x1U, // full block, no tail processing needed
    last_ic_block = 0x2U, // last (possibly partial) input-channel block
    last_sp_block = 0x4U, // last (possibly partial) spatial block
};
47
// Describes how the output width is partitioned into ur_w-sized blocks for
// the deconvolution kernel, plus per-block left/right overflow information.
struct ur_w_blks_params_t {
    struct single_ur_w_blk_params_t {
        single_ur_w_blk_params_t(
                int l_overflow, int r_overflow, bool process_sp_carefully)
            : l_overflow(l_overflow)
            , r_overflow(r_overflow)
            , process_sp_carefully(process_sp_carefully) {}

        // l_overflow - no. of spatial elements of weights standing out of
        // src spatial when computing the 1st output pixel in the current blk
        int l_overflow;
        // r_overflow - no. of spatial elements of weights standing out of
        // src spatial when computing the last output pixel in the current blk
        int r_overflow;
        // process_sp_carefully - indicates if loading the last src sp
        // for computation of the last dst sp of the block can't be done
        // by fetching 4 src sp at once
        bool process_sp_carefully;
    };
    std::vector<single_ur_w_blk_params_t> blks_params;
    // Default-initialized so a freshly constructed object is well-defined
    // even before the block counts are filled in.
    int num_pre_blks = 0; // num of blocks with l_overflow>0
    int num_post_blks = 0; // num of blocks with r_overflow>0 or that need to
                           // be processed carefully
};
72
// JIT kernel generating the forward int8 deconvolution inner loops for a
// given vector register type Vmm (Zmm/Ymm/Xmm, chosen by channel blocking).
// NOTE: several GPRs and Vmm registers below intentionally alias the same
// physical register (e.g. r14, r15, rax, Vmm(30), Vmm(31)); they are used in
// disjoint phases of the generated code and must not be assumed live at the
// same time.
template <typename Vmm>
struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t);

    jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);
    ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel();

    const jit_conv_conf_t &jcp;
    const primitive_attr_t &attr_;

private:
    // Generates the post-ops (eltwise/binary/sum) tail of the kernel.
    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;

    // Input channels processed per inner step (4 int8 values packed per s32).
    const int ic_sub_step = 4;

    /* data regs */
    const Xbyak::Reg64 reg_src = r8;
    const Xbyak::Reg64 reg_filt = r9;
    const Xbyak::Reg64 reg_dst = r10;
    const Xbyak::Reg64 param1 = abi_param1;
    const Xbyak::Reg64 reg_kh = abi_not_param1;
    const Xbyak::Reg64 reg_ki = r14;

    const Xbyak::Reg64 reg_nur_w = rbx;
    const Xbyak::Reg64 reg_bias = rdx;
    // Aliases reg_bias: bias is no longer needed when icb counting starts.
    const Xbyak::Reg64 reg_icb = reg_bias;
    // The three rax aliases below are used in mutually exclusive phases of
    // store_output.
    const Xbyak::Reg64 reg_ptr_scales = rax;
    const Xbyak::Reg64 reg_ptr_dst_scales = rax;
    const Xbyak::Reg64 reg_ptr_saturation_ubound = rax;
    const Xbyak::Reg64 reg_oc_blocks = rsi;

    const Xbyak::Reg64 aux_reg_src = r11;
    const Xbyak::Reg64 aux_reg_filt = r12;

    const Xbyak::Reg64 aux_reg_src_d = r13;
    const Xbyak::Reg64 aux_reg_filt_d = r15;

    // r14 / r15 / r11 aliases: valid only in non-overlapping code regions.
    const Xbyak::Reg64 reg_compensation = r14;
    const Xbyak::Reg64 reg_scratch = r14;
    const Xbyak::Reg64 reg_ptr_sum_scale = r11;
    const Xbyak::Reg64 reg_overflow = rax;
    const Xbyak::Reg64 reg_comp_strides = reg_overflow;
    const Xbyak::Reg64 reg_ker_long_offt = r15;
    const Xbyak::Reg64 &reg_zp_dst_ = r15;
    const Xbyak::Reg64 &reg_zp_src_ = r15;
    const Xbyak::Reg64 &reg_zp_compensation = r11;
    // Two 8-byte stack slots reserved below rsp (see the addresses below).
    static constexpr int reserved_stack_size_ = 16;
    const Xbyak::Address zp_src_pad_comp_addr = ptr[rsp];
    const Xbyak::Address reg_scratch_preserved = ptr[rsp + 8];

    Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
    const Vmm vmm_tmp = Vmm(28);
    const Vmm vmm_one = Vmm(29);
    /* used during write-out section of store_output */
    // Vmm(31) is shared by all the roles below; they never overlap in time.
    const Vmm vmm_zero = Vmm(31);
    const Vmm vmm_saturation = Vmm(31);
    const Vmm vmm_wei = Vmm(31);

    /* signed input */
    const Vmm vmm_shift = Vmm(30);
    const Vmm vmm_comp = Vmm(30);
    const Vmm vmm_bias = Vmm(31);
    const Vmm vmm_dst_scale = Vmm(31);
    const Vmm vmm_prev_dst = Vmm(31);

    // Accumulator register for output row i_ur, oc block i_oc.
    Vmm vmm_out(int i_ur, int i_oc) {
        int idx = i_ur * jcp.nb_oc_blocking + i_oc;
        assert(idx < 31); // Vmm(31) is reserved for the aliases above
        return Vmm(idx);
    }
    // Input register: placed after the accumulator range.
    Vmm vmm_inp(int i_ic, int nb_x_blocking) const {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm(idx);
    }

    // First output column within the ur_w block that filter tap `ki`
    // contributes to, given the left overflow of the block.
    int get_ow_start(int ki, int l_overflow) {
        int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w
                + l_overflow * jcp.stride_w
                - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;
        return res;
    }

    // One past the last output column within the ur_w block that filter tap
    // `ki` contributes to, given the right overflow of the block.
    int get_ow_end(int ur_w, int ki, int r_overflow) {
        if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail))
            ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
        int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
                + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;
        return ur_w - res;
    }

    int get_blocking_size() const noexcept;
    int get_tail_size() const noexcept;
    void prepare_output(int ur_w);
    void store_output(int ur_w, bool last_oc_block);
    void compute(const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src);
    std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const
            noexcept;
    // Zero-point source padding compensation (see
    // jit_uni_deconv_zp_pad_str_kernel.hpp for the companion kernel).
    void apply_zp_src_pad_str_comp(
            int ur_w, int l_overflow, int r_overflow, bool h_padded);
    void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow,
            bool h_padded, bool last_oc_block);
    void compute_ker(int ur_w, int l_overflow, int r_overflow,
            ker_block_t last_ic_block_flag, bool h_padded = false);
    void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);

    ur_w_blks_params_t get_ur_w_blks_params();

    void generate() override;
    // Converts `op` of type `type_in` to f32 in vmm_in; `mask_flag` enables
    // the ktail_mask-guarded (partial) load.
    void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag);
};
192
193struct _jit_avx512_core_x8s8s32x_deconv_fwd_kernel {
194 _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
195 const primitive_attr_t &attr, const memory_desc_t &dst_md)
196 : kernel_(nullptr) {
197
198 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
199 switch (ch_block) {
200 case 16:
201 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
202 Xbyak::Zmm>(ajcp, attr, dst_md);
203 return;
204 case 8:
205 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
206 Xbyak::Ymm>(ajcp, attr, dst_md);
207 return;
208 case 4:
209 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
210 Xbyak::Xmm>(ajcp, attr, dst_md);
211 return;
212 default: assert(!"invalid channel blocking");
213 }
214 }
215
216 status_t create_kernel() { return kernel_->create_kernel(); }
217
218 ~_jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { delete kernel_; }
219
220 void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); }
221
222 static bool post_ops_ok(jit_conv_conf_t &jcp, primitive_attr_t &attr,
223 const memory_desc_wrapper &dst_d);
224
225 static status_t init_conf(jit_conv_conf_t &jcp,
226 const deconvolution_desc_t &cd, memory_desc_t &src_md,
227 memory_desc_t &weights_md, memory_desc_t &dst_md,
228 const bool with_bias, memory_desc_t &bias_md,
229 primitive_attr_t &attr, int nthreads);
230
231 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
232 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
233
234private:
235 DNNL_DISALLOW_COPY_AND_ASSIGN(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel);
236 jit_generator *kernel_;
237};
238
// Forward int8 deconvolution primitive: u8/s8 src, s8 weights, s32
// accumulation, f32/s32/s8/u8 dst. Dispatches to a 1d/2d/3d execute path
// based on the problem rank.
struct jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_deconvolution_fwd_pd_t {
        using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("jit_deconvolution:",
                        ((jcp_.has_vnni) ? avx512_core_vnni : avx512_core), ""),
                jit_avx512_core_x8s8s32x_deconvolution_fwd_t);

        // Checks data types / attributes, then fills jcp_ and books the
        // scratchpad. Returns status::unimplemented for unsupported cases.
        status_t init(engine_t *engine) {
            using namespace data_type;
            using skip_mask_t = primitive_attr_t::skip_mask_t;
            const bool ok = is_fwd()
                    // NOTE(review): bitwise '&' rather than '==' — this
                    // relies on alg_kind's bit layout; confirm no other
                    // deconvolution alg kind can satisfy this test.
                    && (desc()->alg_kind & alg_kind::deconvolution_direct)
                    && utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
                    && desc()->accum_data_type == s32
                    && attr()->has_default_values(skip_mask_t::scales_runtime
                            | skip_mask_t::post_ops
                            | skip_mask_t::zero_points_runtime);
            if (!ok) return status::unimplemented;

            // May adjust the memory descriptors (any/blocked formats).
            CHECK(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(jcp_,
                    *desc(), src_md_, weights_md_, dst_md_, with_bias(),
                    bias_md_, attr_, dnnl_get_max_threads()));

            auto scratchpad = scratchpad_registry().registrar();
            _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
                    scratchpad, jcp_, *attr());

            return status::success;
        }

        jit_conv_conf_t jcp_;
    };

    jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    // Creates the main kernel and, when src zero-point padding compensation
    // is required, the auxiliary compensation kernel.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));

        if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) {
            CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_,
                    zp::create_deconv_zp_pad_str_comp_ker<avx512_core>(
                            pd()->jcp_)));
            const auto zp_kernel_status
                    = zp_src_pad_comp_kernel_->create_kernel();
            if (zp_kernel_status != status::success) return zp_kernel_status;
        }

        return kernel_->create_kernel();
    }

    // ndims includes the minibatch and channel dims: 3 -> 1d, 4 -> 2d,
    // 5 -> 3d spatial.
    status_t execute(const exec_ctx_t &ctx) const override {
        auto ndims = pd()->ndims();
        if (ndims == 3)
            return execute_forward_1d(ctx);
        else if (ndims == 4)
            return execute_forward_2d(ctx);
        else if (ndims == 5)
            return execute_forward_3d(ctx);
        return status::runtime_error;
    }

private:
    status_t execute_forward_1d(const exec_ctx_t &ctx) const;
    status_t execute_forward_2d(const exec_ctx_t &ctx) const;
    status_t execute_forward_3d(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    std::unique_ptr<_jit_avx512_core_x8s8s32x_deconv_fwd_kernel> kernel_;
    std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t>
            zp_src_pad_comp_kernel_;
    // Combines src/wei scales into the oscales buffer in scratchpad.
    const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad,
            const float *src_scales, const float *wei_scales) const;
};
321
322} // namespace x64
323} // namespace cpu
324} // namespace impl
325} // namespace dnnl
326
327#endif
328
329// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
330