1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP |
18 | #define CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP |
19 | |
20 | #include <functional> |
21 | #include <memory> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/primitive.hpp" |
25 | |
26 | #include "cpu/cpu_deconvolution_pd.hpp" |
27 | |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | #include "cpu/x64/jit_primitive_conf.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
// Forward declaration only: this header refers to the zero-point padding /
// stride compensation kernel base solely through a std::unique_ptr member, so
// the full class definition is not required here.
namespace zp {
class jit_uni_deconv_zp_pad_str_kernel_base_t;
} // namespace zp
39 | |
40 | namespace injector { |
41 | template <cpu_isa_t isa, typename Vmm> |
42 | class jit_uni_postops_injector_t; |
43 | } // namespace injector |
44 | |
45 | using namespace Xbyak; |
46 | |
47 | template <cpu_isa_t isa, typename Vmm> |
48 | struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator { |
49 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_deconv_fwd_kernel); |
50 | |
51 | _jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
52 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); |
53 | ~_jit_uni_x8s8s32x_deconv_fwd_kernel(); |
54 | |
55 | const jit_conv_conf_t jcp_; |
56 | |
57 | private: |
58 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>> |
59 | postops_injector_; |
60 | using reg64_t = const Xbyak::Reg64; |
61 | |
62 | static constexpr dim_t IC_SUB_STEP = 4; |
63 | const int ker_max_regs_ = -1; |
64 | |
65 | enum ker_block_t { |
66 | no_last_block = 0x1U, |
67 | last_ic_block = 0x2U, |
68 | last_sp_block = 0x4U, |
69 | }; |
70 | |
71 | /* data regs */ |
72 | const reg64_t reg_src_ = r8; |
73 | const reg64_t reg_filt_ = r9; |
74 | const reg64_t reg_dst_ = r10; |
75 | const reg64_t param1_ = abi_param1; |
76 | const reg64_t reg_kh_ = abi_not_param1; |
77 | const reg64_t reg_ki_ = r14; |
78 | |
79 | const reg64_t reg_nur_w_ = rbx; |
80 | const reg64_t reg_bias_ = rdx; |
81 | const reg64_t reg_icb_ = reg_bias_; |
82 | const reg64_t reg_ptr_scales_ = rax; |
83 | const reg64_t reg_ptr_dst_scales_ = abi_not_param1; |
84 | const reg64_t reg_ptr_saturation_ubound_ = rax; |
85 | const reg64_t reg_oc_blocks_ = rsi; |
86 | |
87 | const reg64_t aux_reg_src_ = r11; |
88 | const reg64_t aux_reg_filt_ = r12; |
89 | |
90 | const reg64_t aux_reg_src_d_ = r13; |
91 | const reg64_t aux_reg_filt_d_ = r15; |
92 | |
93 | const reg64_t reg_compensation_ = r14; |
94 | const reg64_t reg_scratch_ = r14; |
95 | const reg64_t reg_ptr_sum_scale_ = r11; |
96 | const reg64_t reg_ptr_sum_zp_ = r15; |
97 | const reg64_t reg_overflow_ = rax; |
98 | const reg64_t reg_comp_strides_ = reg_overflow_; |
99 | const reg64_t reg_ker_long_offt_ = r15; |
100 | const reg64_t reg_zp_dst_ = r15; |
101 | const reg64_t reg_zp_src_ = r15; |
102 | const reg64_t reg_zp_compensation_ = r11; |
103 | const Xbyak::Address zp_src_pad_comp_addr_ = ptr[rsp]; |
104 | const Xbyak::Address reg_scratch_preserved_ = ptr[rsp + 8]; |
105 | static constexpr int reserved_stack_size_ = 16; |
106 | |
107 | const Vmm vmm_tmp_ = Vmm(3); |
108 | const Vmm vmm_one_ = Vmm(2); |
109 | /* used during write-out section of store_output */ |
110 | const Vmm vmm_zero_ = Vmm(0); |
111 | const Vmm vmm_saturation_ = vmm_zero_; |
112 | const Vmm vmm_wei_ = vmm_zero_; |
113 | const Vmm vmm_scale_ = vmm_zero_; |
114 | const Vmm vmm_dst_scale_ = vmm_zero_; |
115 | /* signed input */ |
116 | const Vmm vmm_shift_ = Vmm(1); |
117 | const Vmm vmm_comp_ = Vmm(1); |
118 | const Vmm vmm_bias_ = vmm_zero_; |
119 | const Vmm vmm_prev_dst_ = vmm_zero_; |
120 | const Vmm vmm_sum_zp_ = vmm_tmp_; |
121 | |
122 | Vmm vmm_out(int i_ur, int i_oc) const; |
123 | Vmm vmm_inp(int i_ic, int nb_x_blocking) const; |
124 | |
125 | int get_ow_start(int ki, int l_overflow) const noexcept; |
126 | int get_ow_end(int ur_w, int ki, int r_overflow) const noexcept; |
127 | int get_blocking_size() const noexcept; |
128 | int get_tail_size() const noexcept; |
129 | |
130 | void prepare_output(int ur_w); |
131 | void apply_postops(int ur_w, bool last_oc_block, const float *p_sum_scale, |
132 | const int32_t *p_sum_zp); |
133 | void store_output(int ur_w, bool last_oc_block); |
134 | void compute_ker(int ur_w, int l_overflow, int r_overflow, |
135 | ker_block_t last_ic_block_flag, bool h_padded = false); |
136 | void compute(const Vmm vreg_acc, const Vmm vreg_wei, const Vmm vreg_src); |
137 | std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const |
138 | noexcept; |
139 | void apply_zp_src_pad_str_comp( |
140 | int ur_w, int l_overflow, int r_overflow, bool h_padded); |
141 | void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow, |
142 | bool h_padded, bool last_oc_block); |
143 | void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); |
144 | void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); |
145 | void generate() override; |
146 | void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Reg64 reg, |
147 | int offset, int load_size); |
148 | }; |
149 | |
150 | template <cpu_isa_t isa> |
151 | struct jit_uni_x8s8s32x_deconv_fwd_kernel { |
152 | |
153 | jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, |
154 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); |
155 | |
156 | status_t create_kernel() { return kernel_->create_kernel(); } |
157 | |
158 | ~jit_uni_x8s8s32x_deconv_fwd_kernel(); |
159 | |
160 | void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); } |
161 | |
162 | static bool post_ops_ok(jit_conv_conf_t &jcp, |
163 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); |
164 | |
165 | static status_t init_conf(jit_conv_conf_t &jcp, |
166 | const deconvolution_desc_t &cd, memory_desc_t &src_md, |
167 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
168 | const bool with_bias, memory_desc_t &bias_md, |
169 | primitive_attr_t &attr, int nthreads); |
170 | |
171 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
172 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
173 | |
174 | using _jit_avx2_x8s8s32x_deconv_fwd_kernel |
175 | = _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>; |
176 | |
177 | private: |
178 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_deconv_fwd_kernel); |
179 | std::unique_ptr<jit_generator> kernel_; |
180 | }; |
181 | |
182 | template <cpu_isa_t isa> |
183 | struct jit_uni_x8s8s32x_deconvolution_fwd_t : public primitive_t { |
184 | struct pd_t : public cpu_deconvolution_fwd_pd_t { |
185 | using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; |
186 | |
187 | DECLARE_COMMON_PD_T( |
188 | JIT_IMPL_NAME_HELPER("jit_uni_int8:" , |
189 | isa == avx2 && jcp_.has_vnni ? avx2_vnni : isa, "" ), |
190 | jit_uni_x8s8s32x_deconvolution_fwd_t); |
191 | |
192 | status_t init(engine_t *engine); |
193 | jit_conv_conf_t jcp_; |
194 | }; |
195 | |
196 | jit_uni_x8s8s32x_deconvolution_fwd_t(const pd_t *apd); |
197 | ~jit_uni_x8s8s32x_deconvolution_fwd_t(); |
198 | |
199 | status_t init(engine_t *engine) override; |
200 | status_t execute(const exec_ctx_t &ctx) const override; |
201 | |
202 | private: |
203 | status_t execute_forward_1d(const exec_ctx_t &ctx) const; |
204 | status_t execute_forward_2d(const exec_ctx_t &ctx) const; |
205 | status_t execute_forward_3d(const exec_ctx_t &ctx) const; |
206 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
207 | const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad, |
208 | const float *src_scales, const float *wei_scales) const; |
209 | std::unique_ptr<jit_uni_x8s8s32x_deconv_fwd_kernel<isa>> kernel_; |
210 | std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t> |
211 | zp_src_pad_comp_kernel_; |
212 | }; |
213 | |
214 | } // namespace x64 |
215 | } // namespace cpu |
216 | } // namespace impl |
217 | } // namespace dnnl |
218 | |
219 | #endif |
220 | |
221 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
222 | |