1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP
18#define CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP
19
20#include <functional>
21#include <memory>
22
23#include "common/c_types_map.hpp"
24#include "common/primitive.hpp"
25
26#include "cpu/cpu_deconvolution_pd.hpp"
27
28#include "cpu/x64/jit_generator.hpp"
29#include "cpu/x64/jit_primitive_conf.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36namespace zp {
37class jit_uni_deconv_zp_pad_str_kernel_base_t;
38} // namespace zp
39
40namespace injector {
41template <cpu_isa_t isa, typename Vmm>
42class jit_uni_postops_injector_t;
43} // namespace injector
44
45using namespace Xbyak;
46
// JIT code generator for the x8s8s32x deconvolution forward kernel,
// parameterized by the target ISA and the vector register type Vmm.
// The machine code is emitted by generate() (defined out of line).
//
// NOTE: the declaration order of the data members below is load-bearing.
// Several members are initialized from members declared earlier (e.g.
// reg_icb_ = reg_bias_, vmm_saturation_ = vmm_zero_), and C++ initializes
// members in declaration order, so reordering them can read an
// uninitialized member.
template <cpu_isa_t isa, typename Vmm>
struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_deconv_fwd_kernel);

    _jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
    // Declared out of line: the post-ops injector type is only
    // forward-declared in this header, and std::unique_ptr needs the
    // complete type where it is destroyed.
    ~_jit_uni_x8s8s32x_deconv_fwd_kernel();

    // Kernel configuration (presumably a copy of ajcp -- confirm in .cpp).
    const jit_conv_conf_t jcp_;

private:
    // Emits post-op code; owned by the generator.
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>>
            postops_injector_;
    // Note: the alias is already const-qualified, so the extra `const` on
    // the members below is redundant but harmless.
    using reg64_t = const Xbyak::Reg64;

    // NOTE(review): presumably the number of int8 input channels consumed
    // per 32-bit accumulation step -- confirm against compute().
    static constexpr dim_t IC_SUB_STEP = 4;
    // Defaults to -1; presumably overridden by the constructor's
    // member-initializer list -- confirm in .cpp.
    const int ker_max_regs_ = -1;

    // Single-bit flags describing which kind of block is being computed.
    enum ker_block_t {
        no_last_block = 0x1U,
        last_ic_block = 0x2U,
        last_sp_block = 0x4U,
    };

    // General-purpose register assignment. Several names intentionally map
    // to the same physical register and must therefore have non-overlapping
    // live ranges in the generated code:
    //   r11            : aux_reg_src_, reg_ptr_sum_scale_, reg_zp_compensation_
    //   r14            : reg_ki_, reg_compensation_, reg_scratch_
    //   r15            : aux_reg_filt_d_, reg_ptr_sum_zp_, reg_ker_long_offt_,
    //                    reg_zp_dst_, reg_zp_src_
    //   rax            : reg_ptr_scales_, reg_ptr_saturation_ubound_,
    //                    reg_overflow_ (== reg_comp_strides_)
    //   rdx            : reg_bias_ (== reg_icb_)
    //   abi_not_param1 : reg_kh_, reg_ptr_dst_scales_
    /* data regs */
    const reg64_t reg_src_ = r8;
    const reg64_t reg_filt_ = r9;
    const reg64_t reg_dst_ = r10;
    const reg64_t param1_ = abi_param1;
    const reg64_t reg_kh_ = abi_not_param1;
    const reg64_t reg_ki_ = r14;

    const reg64_t reg_nur_w_ = rbx;
    const reg64_t reg_bias_ = rdx;
    const reg64_t reg_icb_ = reg_bias_; // must follow reg_bias_
    const reg64_t reg_ptr_scales_ = rax;
    const reg64_t reg_ptr_dst_scales_ = abi_not_param1;
    const reg64_t reg_ptr_saturation_ubound_ = rax;
    const reg64_t reg_oc_blocks_ = rsi;

    const reg64_t aux_reg_src_ = r11;
    const reg64_t aux_reg_filt_ = r12;

    const reg64_t aux_reg_src_d_ = r13;
    const reg64_t aux_reg_filt_d_ = r15;

    const reg64_t reg_compensation_ = r14;
    const reg64_t reg_scratch_ = r14;
    const reg64_t reg_ptr_sum_scale_ = r11;
    const reg64_t reg_ptr_sum_zp_ = r15;
    const reg64_t reg_overflow_ = rax;
    const reg64_t reg_comp_strides_ = reg_overflow_; // must follow reg_overflow_
    const reg64_t reg_ker_long_offt_ = r15;
    const reg64_t reg_zp_dst_ = r15;
    const reg64_t reg_zp_src_ = r15;
    const reg64_t reg_zp_compensation_ = r11;
    // Two 8-byte stack slots at [rsp] and [rsp + 8]; reserved_stack_size_
    // (16 bytes) covers both. Presumably reserved on kernel entry by
    // generate() -- confirm. The second slot's name suggests it spills
    // reg_scratch_.
    const Xbyak::Address zp_src_pad_comp_addr_ = ptr[rsp];
    const Xbyak::Address reg_scratch_preserved_ = ptr[rsp + 8];
    static constexpr int reserved_stack_size_ = 16;

    // Fixed vector registers Vmm(0)..Vmm(3). Several names alias the same
    // register: Vmm(0) is vmm_zero_/saturation/wei/scale/dst_scale/bias/
    // prev_dst, Vmm(1) is vmm_shift_/vmm_comp_, Vmm(3) is vmm_tmp_/vmm_sum_zp_.
    // NOTE(review): presumably vmm_out()/vmm_inp() hand out registers above
    // these -- confirm in .cpp.
    const Vmm vmm_tmp_ = Vmm(3);
    const Vmm vmm_one_ = Vmm(2);
    /* used during write-out section of store_output */
    const Vmm vmm_zero_ = Vmm(0);
    const Vmm vmm_saturation_ = vmm_zero_;
    const Vmm vmm_wei_ = vmm_zero_;
    const Vmm vmm_scale_ = vmm_zero_;
    const Vmm vmm_dst_scale_ = vmm_zero_;
    /* signed input */
    const Vmm vmm_shift_ = Vmm(1);
    const Vmm vmm_comp_ = Vmm(1);
    const Vmm vmm_bias_ = vmm_zero_;
    const Vmm vmm_prev_dst_ = vmm_zero_;
    const Vmm vmm_sum_zp_ = vmm_tmp_; // must follow vmm_tmp_

    // Register selection helpers for accumulators and inputs.
    Vmm vmm_out(int i_ur, int i_oc) const;
    Vmm vmm_inp(int i_ic, int nb_x_blocking) const;

    // Presumably the output-width range touched by kernel tap ki given the
    // left/right overflow, and blocking/tail sizes -- confirm in .cpp.
    int get_ow_start(int ki, int l_overflow) const noexcept;
    int get_ow_end(int ur_w, int ki, int r_overflow) const noexcept;
    int get_blocking_size() const noexcept;
    int get_tail_size() const noexcept;

    // Code-emission stages used by generate().
    void prepare_output(int ur_w);
    void apply_postops(int ur_w, bool last_oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void store_output(int ur_w, bool last_oc_block);
    void compute_ker(int ur_w, int l_overflow, int r_overflow,
            ker_block_t last_ic_block_flag, bool h_padded = false);
    void compute(const Vmm vreg_acc, const Vmm vreg_wei, const Vmm vreg_src);
    std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const
            noexcept;
    // Source zero-point compensation for padded/strided positions.
    void apply_zp_src_pad_str_comp(
            int ur_w, int l_overflow, int r_overflow, bool h_padded);
    void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow,
            bool h_padded, bool last_oc_block);
    void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);
    void generate() override;
    // Presumably loads `load_size` elements of `type_in` from [reg + offset]
    // into vmm_in as f32 -- confirm in .cpp.
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Reg64 reg,
            int offset, int load_size);
};
149
150template <cpu_isa_t isa>
151struct jit_uni_x8s8s32x_deconv_fwd_kernel {
152
153 jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
154 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
155
156 status_t create_kernel() { return kernel_->create_kernel(); }
157
158 ~jit_uni_x8s8s32x_deconv_fwd_kernel();
159
160 void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); }
161
162 static bool post_ops_ok(jit_conv_conf_t &jcp,
163 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
164
165 static status_t init_conf(jit_conv_conf_t &jcp,
166 const deconvolution_desc_t &cd, memory_desc_t &src_md,
167 memory_desc_t &weights_md, memory_desc_t &dst_md,
168 const bool with_bias, memory_desc_t &bias_md,
169 primitive_attr_t &attr, int nthreads);
170
171 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
172 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
173
174 using _jit_avx2_x8s8s32x_deconv_fwd_kernel
175 = _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>;
176
177private:
178 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_deconv_fwd_kernel);
179 std::unique_ptr<jit_generator> kernel_;
180};
181
182template <cpu_isa_t isa>
183struct jit_uni_x8s8s32x_deconvolution_fwd_t : public primitive_t {
184 struct pd_t : public cpu_deconvolution_fwd_pd_t {
185 using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;
186
187 DECLARE_COMMON_PD_T(
188 JIT_IMPL_NAME_HELPER("jit_uni_int8:",
189 isa == avx2 && jcp_.has_vnni ? avx2_vnni : isa, ""),
190 jit_uni_x8s8s32x_deconvolution_fwd_t);
191
192 status_t init(engine_t *engine);
193 jit_conv_conf_t jcp_;
194 };
195
196 jit_uni_x8s8s32x_deconvolution_fwd_t(const pd_t *apd);
197 ~jit_uni_x8s8s32x_deconvolution_fwd_t();
198
199 status_t init(engine_t *engine) override;
200 status_t execute(const exec_ctx_t &ctx) const override;
201
202private:
203 status_t execute_forward_1d(const exec_ctx_t &ctx) const;
204 status_t execute_forward_2d(const exec_ctx_t &ctx) const;
205 status_t execute_forward_3d(const exec_ctx_t &ctx) const;
206 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
207 const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad,
208 const float *src_scales, const float *wei_scales) const;
209 std::unique_ptr<jit_uni_x8s8s32x_deconv_fwd_kernel<isa>> kernel_;
210 std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t>
211 zp_src_pad_comp_kernel_;
212};
213
214} // namespace x64
215} // namespace cpu
216} // namespace impl
217} // namespace dnnl
218
219#endif
220
221// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
222