1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP |
19 | |
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
/* JIT code generator for the AVX-512 Core x8s8s32x forward convolution
 * kernel (u8/s8 source, s8 weights, s32 accumulation). The Vmm template
 * parameter (Xbyak::Zmm / Ymm / Xmm) selects the vector width used for a
 * given channel blocking; see the non-template dispatcher below. */
template <typename Vmm>
struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)

    // Bit flag communicated through the runtime call state marking the first
    // load of a destination tile. NOTE(review): meaning inferred from the
    // name -- confirm against the kernel implementation (.cpp).
    enum { STATE_FIRST_DST_LOAD = 0x1U };

    _jit_avx512_core_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    jit_conv_conf_t jcp; // compile-time convolution configuration
    const primitive_attr_t &attr_; // attributes: post-ops, scales, zero points

private:
    // Number of f32 lanes in a full zmm register (vlen bytes / sizeof(float)).
    constexpr static int isa_simd_width_
            = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    // One-step-narrower register type used for down-converted (e.g. bf16)
    // stores: zmm kernels store through ymm, ymm/xmm kernels through xmm.
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    // int8 dot-product instructions consume input channels in groups of 4.
    const int ic_sub_step = 4;
    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;

    // Register-budget boundaries: output accumulators may only use vmm
    // indices strictly below the applicable *_base_idx; the indices at and
    // above it are reserved for the named Vmm aliases further down.
    enum {
        typesize = sizeof(float),
        ker_reg_base_idx = 28, // common case: vmm28..vmm31 reserved
        ker_dw_reg_base_idx = 30, // depthwise: vmm30..vmm31 reserved
        ker_zp_reg_base_idx = 26, // src zero-point: vmm26..vmm31 reserved
    };
    // Tail-handling discriminator for the innermost compute loops.
    typedef enum {
        no_last_block,
        last_ic_block,
        last_sp_block,
    } ic_block_t;

    /* data regs */
    // NOTE: several GPRs below deliberately alias one another; each aliased
    // pair is used in kernel phases that do not overlap in time.
    const Xbyak::Reg64 reg_ptr_scales = rax;
    const Xbyak::Reg64 aux_reg_saturation = rax;
    const Xbyak::Reg64 reg_inp = r8;
    const Xbyak::Reg64 reg_ker = r9;
    const Xbyak::Reg64 reg_out = r10;
    const Xbyak::Reg64 aux_reg_inp = r11;
    const Xbyak::Reg64 reg_ptr_sum_scale = r11;
    const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1;
    const Xbyak::Reg64 aux_reg_ker = r12;
    const Xbyak::Reg64 reg_compensation = r14;
    const Xbyak::Reg64 aux_reg_inp_d = r13;
    const Xbyak::Reg64 aux_reg_ker_d = r15;
    const Xbyak::Reg64 reg_ker_long_offt = r13;
    // Using 3d regs as depthwise_3d is not yet supported
    const Xbyak::Reg64 reg_inp_buffer_ptr = aux_reg_inp_d;
    const Xbyak::Reg64 aux_reg_inp_buffer_ptr = aux_reg_ker_d;
    // zero-point computation
    const Xbyak::Reg64 reg_zp_compensation = aux_reg_inp;
    const Xbyak::Reg64 reg_src_zero_point = aux_reg_ker_d;
    const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point;

    // dst scale
    const Xbyak::Reg64 reg_dst_scale = reg_src_zero_point;

    /* counter regs */
    const Xbyak::Reg64 reg_oi = rbx;
    const Xbyak::Reg64 reg_bias = rdx;
    const Xbyak::Reg64 reg_oc_blocks = rsi;
    const Xbyak::Reg64 reg_owb = aux_reg_ker;
    const Xbyak::Reg64 reg_scratch = reg_compensation;
    const Xbyak::Reg64 reg_kj = reg_ptr_scales;
    const Xbyak::Reg64 reg_ki = reg_compensation;
    const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
    const Xbyak::Reg64 reg_icb = reg_bias;
    const Xbyak::Reg64 reg_jmp_tbl_base = reg_kj;

    /* binary post-op operand */
    const Xbyak::Reg64 temp_offset_reg = r12;

    // Opmask registers: ktail_mask covers the output-channel tail,
    // ktail_mask_extended is its widened variant for full-width vmm ops.
    const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
    const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
    const Xbyak::Opmask postops_mask = Xbyak::Opmask(4);
    const Xbyak::Opmask ktail_mask_extended = Xbyak::Opmask(5);

    // The high vmm indices are multiply-aliased; each alias is live only in
    // the kernel phase named by its comment, so the reuse is safe.
    const Vmm vmm_wei = Vmm(31);
    /* used during bias section of store_output */
    const Vmm vmm_comp = Vmm(30); // only for signed input
    const Vmm vmm_bias = Vmm(31);
    /* used during post_op sum section of store_output */
    const Vmm vmm_prev_dst = Vmm(31);
    /* used during write-out section of store_output */
    const Vmm vmm_saturation = Vmm(30);
    const Vmm vmm_sum_zp = Vmm(30);
    const Vmm vmm_zero = Vmm(31);

    /* used in compute_ker (but set during prepare_output) */
    const Vmm vmm_shift = vmm_comp; // only for signed input
    /* used in compute_ker (but only for pre-VNNI machines) */
    const Vmm vmm_tmp = Vmm(28); // not used for depthwise
    const Vmm vmm_one
            = Vmm(29); // set at start of kernel, not used for depthwise.
    /* zero-point */
    const Vmm vmm_zp = Vmm(25);
    const Vmm vmm_zp_one = Vmm(26);
    const Vmm vmm_zp_tmp = vmm_zp;

    const Vmm vmm_dst_scale = Vmm(26);

    /* bf16 emulation */
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(30);
    // bf16_emu_reserv_5 not required when only computing vcvtneps2bf16()
    const Xbyak::Reg64 bf16_emu_scratch = aux_reg_ker;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    /* registers use only for depthwise
       groups are always blocked by 16(padded if needed),
       hence use only Zmm registers */
    const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
    Xbyak::Zmm zmm_tmp;
    Xbyak::Zmm zmm_src;
    Xbyak::Zmm zmm_shifted_zero;
    Xbyak::Zmm zmm_permute;

    // Accumulator index for output row i_ur and oc/ch block i_oc; asserts
    // the index stays inside the reserved-register budget (enum above).
    int vmm_out_idx(int i_ur, int i_oc) {
        const int nb_x_blocking
                = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
        const int idx = i_ur * nb_x_blocking + i_oc;
        assert(idx < (jcp.is_depthwise
                        ? ker_dw_reg_base_idx
                        : jcp.src_zero_point ? ker_zp_reg_base_idx
                                             : ker_reg_base_idx));
        return idx;
    }

    // Accumulator register for (i_ur, i_oc) in the kernel's Vmm width.
    Vmm vmm_out(int i_ur, int i_oc) { return Vmm(vmm_out_idx(i_ur, i_oc)); }
    // Same accumulator viewed as a full zmm (depthwise path).
    Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
        int idx = vmm_out(i_ur, i_oc).getIdx();
        assert(idx
                < (jcp.is_depthwise ? ker_dw_reg_base_idx : ker_reg_base_idx));
        return Xbyak::Zmm(idx);
    }
    // Input register for i_ic, placed just above the accumulator range.
    Vmm vmm_inp(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm(idx);
    }
    Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
        const int idx = i_ic + nb_x_blocking * jcp.ur_w;
        const int max_idx = jcp.src_zero_point ? ker_zp_reg_base_idx
                                               : ker_dw_reg_base_idx;
        assert(idx < max_idx);
        MAYBE_UNUSED(max_idx);

        return Xbyak::Zmm(idx);
    }
    // First output column (within ur_w) whose input for kernel tap ki falls
    // inside the row, given left padding pad_l, dilation, and stride.
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }
    // One past the last output column whose input for tap ki falls inside
    // the row, given right padding pad_r.
    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }

    // bf16 utils
    // First free register index for the down-converted source (bf16 path).
    int get_src_down_idx(int nb_x_blocking) {
        int idx = nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return idx;
    }
    // Attach the (extended) tail opmask to vmm when mask_flag is set.
    inline Vmm maybe_mask_vmm(Vmm vmm, bool mask_flag) {
        return mask_flag ? vmm | ktail_mask_extended : vmm;
    }
    inline Vmm_down_t maybe_mask_vmm_down(Vmm_down_t vmm, bool mask_flag) {
        return (mask_flag) ? vmm | ktail_mask : vmm;
    }
    // Down-convert f32 in vmm_dst_idx to bf16 into vmm_down_idx and store to
    // addr, masking tails (and always masking for the 4-lane xmm case).
    inline void store_bf16(Xbyak::Address addr, int vmm_dst_idx,
            int vmm_down_idx, bool mask_flag) {
        auto vmm_down = Vmm_down_t(vmm_down_idx);
        bf16_emu_->vcvtneps2bf16(
                Xbyak::Ymm(vmm_down_idx), Xbyak::Zmm(vmm_dst_idx));

        // for xmm, upper half is zero after conversion to
        // bf16, so mask always & mask for tails
        vmovdqu16(addr,
                maybe_mask_vmm_down(vmm_down, mask_flag || jcp.simd_w == 4));
    }

    // Code-emission stages; bodies live in the .cpp translation unit.
    void prepare_output(int ur_w); // initialize accumulators for ur_w outputs
    void apply_sum(int ur_w, bool last_oc_block_flag, const int nb_oc_block,
            const int oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void apply_postops(int ur_w, bool last_oc_block_flag, const int nb_oc_block,
            const int oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void store_output(int ur_w, bool last_oc_block_flag);
    void compute_ker_dw(int ur_w, int pad_l, int pad_r,
            ic_block_t last_ic_block_flag, bool h_padded);
    void compute_ker(int ur_w, int pad_l, int pad_r,
            ic_block_t last_ic_block_flag, bool h_padded = false);
    void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
    void generate() override; // top-level: emits the whole kernel
    // Load op and convert from type_in to f32 into ymm_in (masked on tails).
    void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
            bool mask_flag);
    Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
};
241 | |
242 | struct jit_avx512_core_x8s8s32x_fwd_kernel { |
243 | |
244 | jit_avx512_core_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp, |
245 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
246 | : kernel_(nullptr) { |
247 | int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; |
248 | switch (ch_block) { |
249 | case 16: |
250 | kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>( |
251 | ajcp, attr, dst_md); |
252 | return; |
253 | case 8: |
254 | kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>( |
255 | ajcp, attr, dst_md); |
256 | return; |
257 | case 4: |
258 | kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>( |
259 | ajcp, attr, dst_md); |
260 | return; |
261 | default: assert(!"invalid channel blocking" ); |
262 | } |
263 | } |
264 | |
265 | status_t create_kernel() { return kernel_->create_kernel(); } |
266 | |
267 | ~jit_avx512_core_x8s8s32x_fwd_kernel() { delete kernel_; } |
268 | |
269 | static status_t init_conf(jit_conv_conf_t &jcp, |
270 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
271 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
272 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
273 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
274 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
275 | void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } |
276 | const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } |
277 | |
278 | private: |
279 | jit_generator *kernel_; |
280 | }; |
281 | |
282 | } // namespace x64 |
283 | } // namespace cpu |
284 | } // namespace impl |
285 | } // namespace dnnl |
286 | |
287 | #endif |
288 | |