1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_X8S8S32X_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_UNI_X8S8S32X_CONV_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | |
23 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/jit_primitive_conf.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | template <cpu_isa_t isa, typename Vmm> |
33 | struct _jit_uni_x8s8s32x_fwd_kernel : public jit_generator { |
34 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_conv_fwd_ker_t_) |
35 | |
36 | _jit_uni_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp, |
37 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
38 | |
39 | jit_conv_conf_t jcp; |
40 | const primitive_attr_t &attr_; |
41 | |
42 | private: |
43 | constexpr static int isa_simd_width_ |
44 | = cpu_isa_traits<isa>::vlen / sizeof(float); |
45 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa>> |
46 | postops_injector_; |
47 | enum { |
48 | typesize = sizeof(float), |
49 | ic_sub_step = 4, |
50 | ker_zp_reg_base_idx = 9, |
51 | ker_reg_base_idx = 12, |
52 | ker_dw_reg_base_idx = 14, |
53 | ker_max_reg = 15, |
54 | }; |
55 | enum ic_block_t { |
56 | no_last_block, |
57 | last_ic_block, |
58 | last_sp_block, |
59 | }; |
60 | |
61 | /* data registers */ |
62 | const Xbyak::Reg64 reg_ptr_scales = rax; |
63 | const Xbyak::Reg64 reg_ptr_saturation_ubound = rax; |
64 | const Xbyak::Reg64 reg_inp = r8; |
65 | const Xbyak::Reg64 reg_ker = r9; |
66 | const Xbyak::Reg64 reg_out = r10; |
67 | const Xbyak::Reg64 aux_reg_inp = r11; |
68 | const Xbyak::Reg64 reg_ptr_sum_scale = r11; |
69 | const Xbyak::Reg64 reg_ptr_sum_zp = rdx; |
70 | const Xbyak::Reg64 aux_reg_ker = r12; |
71 | const Xbyak::Reg64 aux_reg_inp_d = r13; |
72 | const Xbyak::Reg64 reg_compensation = r14; |
73 | const Xbyak::Reg64 aux_reg_ker_d = r15; |
74 | const Xbyak::Reg64 reg_ker_long_offt = r13; |
75 | |
76 | /* counter regs */ |
77 | const Xbyak::Reg64 reg_oi = rbx; |
78 | const Xbyak::Reg64 reg_bias = rdx; |
79 | const Xbyak::Reg64 reg_oc_blocks = rsi; |
80 | const Xbyak::Reg64 reg_owb = aux_reg_ker; |
81 | const Xbyak::Reg64 reg_scratch = reg_compensation; |
82 | const Xbyak::Reg64 reg_ki = reg_compensation; |
83 | const Xbyak::Reg64 reg_kj = reg_ptr_scales; |
84 | const Xbyak::Reg64 reg_overflow = reg_ptr_scales; |
85 | const Xbyak::Reg64 reg_icb = reg_bias; |
86 | // Using 3d regs as depthwise3d is not yet supported |
87 | const Xbyak::Reg64 reg_inp_buffer_ptr = aux_reg_inp_d; |
88 | const Xbyak::Reg64 aux_reg_inp_buffer_ptr = aux_reg_ker_d; |
89 | const Xbyak::Reg64 reg_jmp_tbl_base = reg_kj; |
90 | // zero-point computation |
91 | const Xbyak::Reg64 reg_zp_compensation = aux_reg_inp; |
92 | const Xbyak::Reg64 reg_src_zero_point = aux_reg_ker_d; |
93 | const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point; |
94 | // dst scale |
95 | const Xbyak::Reg64 reg_dst_scale = reg_dst_zero_point; |
96 | |
97 | /* binary post-ops operand */ |
98 | const Xbyak::Reg64 temp_offset_reg = r12; |
99 | |
100 | const Vmm vmm_wei = Vmm(0); |
101 | /* used during bias/comp/scale section of store_output */ |
102 | const Vmm vmm_bias = Vmm(0); |
103 | const Vmm vmm_comp = Vmm(2); // only for signed input |
104 | const Vmm vmm_scale = Vmm(1); |
105 | /* used during post_op sum section of store_output */ |
106 | const Vmm vmm_prev_dst = Vmm(0); |
107 | /* used during write-out section of store_output */ |
108 | const Vmm vmm_zero = Vmm(0); |
109 | const Vmm vmm_saturation = Vmm(0); |
110 | /* used for zero-point */ |
111 | const Vmm vmm_zp = Vmm(6); |
112 | const Vmm vmm_zp_one = Vmm(5); |
113 | const Vmm vmm_zp_comp = vmm_zp_one; |
114 | const Vmm vmm_zp_dw_tmp = vmm_zp_one; |
115 | /* dst scale */ |
116 | const Vmm vmm_dst_scale = Vmm(5); |
117 | |
118 | /* used in compute_ker (but set during prepare_output) */ |
119 | const Vmm vmm_shift = Vmm(1); // only for signed input |
120 | /* used in compute_ker */ |
121 | const Vmm vmm_tmp = Vmm(2); // not used for depthwise |
122 | const Vmm vmm_one |
123 | = Vmm(3); // set at start of kernel, not used for depthwise. |
124 | /* used only for depthwise */ |
125 | Vmm vmm_dw_tmp; |
126 | Vmm vmm_dw_src; |
127 | |
128 | int vmm_out_idx(int i_ur, int i_oc) { |
129 | const int idx_limit = jcp.src_zero_point |
130 | ? ker_zp_reg_base_idx |
131 | : jcp.is_depthwise ? ker_dw_reg_base_idx - jcp.signed_input |
132 | : ker_reg_base_idx; |
133 | const int nb_x_blocking |
134 | = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; |
135 | const int idx = i_ur * nb_x_blocking + i_oc; |
136 | assert(idx < idx_limit); |
137 | MAYBE_UNUSED(idx_limit); |
138 | /* remap register indices from 4 to 15 |
139 | * to avoid passing xmm0 to comp*/ |
140 | return ker_max_reg - idx; |
141 | } |
142 | Vmm vmm_out(int i_ur, int i_oc) { |
143 | const int idx = vmm_out_idx(i_ur, i_oc); |
144 | return Vmm(idx); |
145 | } |
146 | Vmm vmm_inp(int i_ic, int nb_x_blocking) { |
147 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
148 | assert(idx < ker_max_reg); |
149 | return Vmm(ker_max_reg - idx); |
150 | } |
151 | int get_ow_start(int ki, int pad_l) { |
152 | return nstl::max(0, |
153 | utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); |
154 | } |
155 | int get_ow_end(int ur_w, int ki, int pad_r) { |
156 | return ur_w |
157 | - nstl::max(0, |
158 | utils::div_up( |
159 | pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1), |
160 | jcp.stride_w)); |
161 | } |
162 | int get_blocking_size() { |
163 | return jcp.is_depthwise ? jcp.ch_block : jcp.oc_block; |
164 | } |
165 | int get_tail_size() { |
166 | return jcp.is_depthwise ? jcp.ngroups % jcp.ch_block |
167 | : jcp.oc_without_padding % jcp.oc_block; |
168 | } |
169 | |
170 | void prepare_output(int ur_w); |
171 | void store_output(int ur_w, bool last_oc_block_flag); |
172 | void compute_ker_dw(int ur_w, int pad_l, int pad_r, |
173 | ic_block_t last_ic_block_flag, bool h_padded); |
174 | void compute_ker(int ur_w, int pad_l, int pad_r, |
175 | ic_block_t last_ic_block_flag, bool h_padded = false); |
176 | void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag); |
177 | void icb_loop(int ur_w, int pad_l, int pad_r, bool is_last_spatial_block); |
178 | void generate() override; |
179 | |
180 | void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Xbyak::Reg64 ®, |
181 | int offset, int load_size); |
182 | void apply_sum(const int nb_oc_block, const int ur_w, |
183 | const bool last_oc_block_flag, const int oc_block, |
184 | const float *p_sum_scale, const int32_t *p_sum_zp); |
185 | void apply_postops(const int nb_oc_block, const int ur_w, |
186 | const bool last_oc_block_flag, const int oc_block, |
187 | const float *p_sum_scale, const int32_t *p_sum_zp); |
188 | }; |
189 | |
190 | template <cpu_isa_t isa> |
191 | struct jit_uni_x8s8s32x_fwd_kernel { |
192 | |
193 | jit_uni_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp, |
194 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
195 | : kernel_(nullptr) { |
196 | int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; |
197 | switch (ch_block) { |
198 | case 8: |
199 | if (utils::one_of(isa, avx2)) { |
200 | kernel_ = new _jit_uni_x8s8s32x_fwd_kernel<isa, Xbyak::Ymm>( |
201 | ajcp, attr, dst_md); |
202 | } else |
203 | assert(!"invalid channel blocking for current ISA" ); |
204 | return; |
205 | case 4: |
206 | kernel_ = new _jit_uni_x8s8s32x_fwd_kernel<isa, Xbyak::Xmm>( |
207 | ajcp, attr, dst_md); |
208 | return; |
209 | default: assert(!"invalid channel blocking" ); |
210 | } |
211 | } |
212 | |
213 | status_t create_kernel() { return kernel_->create_kernel(); } |
214 | |
215 | ~jit_uni_x8s8s32x_fwd_kernel() { delete kernel_; } |
216 | |
217 | void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } |
218 | |
219 | static status_t init_conf(jit_conv_conf_t &jcp, |
220 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
221 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
222 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
223 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
224 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
225 | |
226 | void (*jit_ker)(jit_conv_call_s *); |
227 | |
228 | private: |
229 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_fwd_kernel); |
230 | jit_generator *kernel_; |
231 | }; |
232 | |
233 | } // namespace x64 |
234 | } // namespace cpu |
235 | } // namespace impl |
236 | } // namespace dnnl |
237 | |
238 | #endif |
239 | |