1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
19
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
// JIT code generator for the avx512_core int8 (x8s8s32x) forward convolution
// kernel. Templated on the vector register type Vmm (Xbyak::Zmm/Ymm/Xmm) so
// one implementation covers full and partial channel blockings; the
// non-template wrapper below selects the instantiation from jcp.
// NOTE(review): most members are named GPR/vector-register aliases referenced
// by the out-of-line method definitions in the matching .cpp — the aliasing
// is deliberate and usage-phase dependent (see the phase comments below).
template <typename Vmm>
struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)

    // Flag bit presumably passed through jit_conv_call_s to mark the first
    // load of a dst tile — TODO confirm against the .cpp / driver code.
    enum { STATE_FIRST_DST_LOAD = 0x1U };

    _jit_avx512_core_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    jit_conv_conf_t jcp; // convolution problem/blocking parameters
    const primitive_attr_t &attr_; // attributes: post-ops, scales, zero-points

private:
    // Number of f32 lanes in a full zmm register (64 bytes / 4 = 16).
    constexpr static int isa_simd_width_
            = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    // Register type one size below Vmm (Zmm -> Ymm, otherwise Xmm); used for
    // stores of down-converted (bf16) data, see store_bf16().
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    // Input channels consumed per int8 dot-product step (4 x s8 per s32 lane).
    const int ic_sub_step = 4;
    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;

    // Accumulators occupy vector register indices [0, base_idx); indices at
    // and above the base are reserved for helper registers. Which base
    // applies depends on depthwise / zero-point mode (see vmm_out_idx()).
    enum {
        typesize = sizeof(float),
        ker_reg_base_idx = 28,
        ker_dw_reg_base_idx = 30,
        ker_zp_reg_base_idx = 26,
    };
    // Which kind of tail (if any) the current block handles.
    typedef enum {
        no_last_block,
        last_ic_block,
        last_sp_block,
    } ic_block_t;

    /* data regs */
    // Many GPRs are aliased; each alias is live in a different kernel phase.
    const Xbyak::Reg64 reg_ptr_scales = rax;
    const Xbyak::Reg64 aux_reg_saturation = rax;
    const Xbyak::Reg64 reg_inp = r8;
    const Xbyak::Reg64 reg_ker = r9;
    const Xbyak::Reg64 reg_out = r10;
    const Xbyak::Reg64 aux_reg_inp = r11;
    const Xbyak::Reg64 reg_ptr_sum_scale = r11;
    const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1;
    const Xbyak::Reg64 aux_reg_ker = r12;
    const Xbyak::Reg64 reg_compensation = r14;
    const Xbyak::Reg64 aux_reg_inp_d = r13;
    const Xbyak::Reg64 aux_reg_ker_d = r15;
    const Xbyak::Reg64 reg_ker_long_offt = r13;
    // Using 3d regs as depthwise_3d is not yet supported
    const Xbyak::Reg64 reg_inp_buffer_ptr = aux_reg_inp_d;
    const Xbyak::Reg64 aux_reg_inp_buffer_ptr = aux_reg_ker_d;
    // zero-point computation
    const Xbyak::Reg64 reg_zp_compensation = aux_reg_inp;
    const Xbyak::Reg64 reg_src_zero_point = aux_reg_ker_d;
    const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point;

    // dst scale
    const Xbyak::Reg64 reg_dst_scale = reg_src_zero_point;

    /* counter regs */
    const Xbyak::Reg64 reg_oi = rbx;
    const Xbyak::Reg64 reg_bias = rdx;
    const Xbyak::Reg64 reg_oc_blocks = rsi;
    const Xbyak::Reg64 reg_owb = aux_reg_ker;
    const Xbyak::Reg64 reg_scratch = reg_compensation;
    const Xbyak::Reg64 reg_kj = reg_ptr_scales;
    const Xbyak::Reg64 reg_ki = reg_compensation;
    const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
    const Xbyak::Reg64 reg_icb = reg_bias;
    const Xbyak::Reg64 reg_jmp_tbl_base = reg_kj;

    /* binary post-op operand */
    const Xbyak::Reg64 temp_offset_reg = r12;

    // Opmask registers: channel-tail masking and post-op lane selection.
    const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
    const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
    const Xbyak::Opmask postops_mask = Xbyak::Opmask(4);
    const Xbyak::Opmask ktail_mask_extended = Xbyak::Opmask(5);

    // Helper vector registers 30/31 are reused across kernel phases; the
    // phases never overlap, so the aliases below are safe by construction.
    const Vmm vmm_wei = Vmm(31);
    /* used during bias section of store_output */
    const Vmm vmm_comp = Vmm(30); // only for signed input
    const Vmm vmm_bias = Vmm(31);
    /* used during post_op sum section of store_output */
    const Vmm vmm_prev_dst = Vmm(31);
    /* used during write-out section of store_output */
    const Vmm vmm_saturation = Vmm(30);
    const Vmm vmm_sum_zp = Vmm(30);
    const Vmm vmm_zero = Vmm(31);

    /* used in compute_ker (but set during prepare_output) */
    const Vmm vmm_shift = vmm_comp; // only for signed input
    /* used in compute_ker (but only for pre-VNNI machines) */
    const Vmm vmm_tmp = Vmm(28); // not used for depthwise
    const Vmm vmm_one
            = Vmm(29); // set at start of kernel, not used for depthwise.
    /* zero-point */
    const Vmm vmm_zp = Vmm(25);
    const Vmm vmm_zp_one = Vmm(26);
    const Vmm vmm_zp_tmp = vmm_zp;

    const Vmm vmm_dst_scale = Vmm(26);

    /* bf16 emulation */
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(30);
    // bf16_emu_reserv_5 not required when only computing vcvtneps2bf16()
    const Xbyak::Reg64 bf16_emu_scratch = aux_reg_ker;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    /* registers use only for depthwise
    groups are always blocked by 16(padded if needed),
    hence use only Zmm registers */
    const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
    Xbyak::Zmm zmm_tmp;
    Xbyak::Zmm zmm_src;
    Xbyak::Zmm zmm_shifted_zero;
    Xbyak::Zmm zmm_permute;

    // Accumulator register index for output pixel i_ur and oc/ch block i_oc.
    // The assert guards against the accumulator range colliding with the
    // helper registers reserved above the mode-dependent base index.
    int vmm_out_idx(int i_ur, int i_oc) {
        const int nb_x_blocking
                = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
        const int idx = i_ur * nb_x_blocking + i_oc;
        assert(idx < (jcp.is_depthwise
                        ? ker_dw_reg_base_idx
                        : jcp.src_zero_point ? ker_zp_reg_base_idx
                                             : ker_reg_base_idx));
        return idx;
    }

    // Accumulator register (Vmm flavor) for (i_ur, i_oc).
    Vmm vmm_out(int i_ur, int i_oc) { return Vmm(vmm_out_idx(i_ur, i_oc)); }
    // Same accumulator viewed as a full Zmm (depthwise path).
    Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
        int idx = vmm_out(i_ur, i_oc).getIdx();
        assert(idx
                < (jcp.is_depthwise ? ker_dw_reg_base_idx : ker_reg_base_idx));
        return Xbyak::Zmm(idx);
    }
    // Input broadcast register; placed directly above the accumulator range
    // (accumulators use indices [0, nb_x_blocking * ur_w)).
    Vmm vmm_inp(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm(idx);
    }
    // Zmm flavor of vmm_inp with the mode-dependent upper-bound check.
    Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
        const int idx = i_ic + nb_x_blocking * jcp.ur_w;
        const int max_idx = jcp.src_zero_point ? ker_zp_reg_base_idx
                                               : ker_dw_reg_base_idx;
        assert(idx < max_idx);
        MAYBE_UNUSED(max_idx);

        return Xbyak::Zmm(idx);
    }
    // First output column in [0, ur_w) whose receptive field at kernel
    // column ki does not fall into the left padding.
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }
    // One past the last output column whose receptive field at kernel
    // column ki stays clear of the right padding.
    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }

    // bf16 utils
    // First vector register index free for down-conversion scratch
    // (just above the accumulator range).
    int get_src_down_idx(int nb_x_blocking) {
        int idx = nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return idx;
    }
    // Attach the extended channel-tail mask when mask_flag is set.
    inline Vmm maybe_mask_vmm(Vmm vmm, bool mask_flag) {
        return mask_flag ? vmm | ktail_mask_extended : vmm;
    }
    // Attach the channel-tail mask to a down-sized register when needed.
    inline Vmm_down_t maybe_mask_vmm_down(Vmm_down_t vmm, bool mask_flag) {
        return (mask_flag) ? vmm | ktail_mask : vmm;
    }
    // Convert the f32 vector at vmm_dst_idx to bf16 into vmm_down_idx
    // (via bf16_emu_ on non-native hardware) and store it, masked for tails.
    inline void store_bf16(Xbyak::Address addr, int vmm_dst_idx,
            int vmm_down_idx, bool mask_flag) {
        auto vmm_down = Vmm_down_t(vmm_down_idx);
        bf16_emu_->vcvtneps2bf16(
                Xbyak::Ymm(vmm_down_idx), Xbyak::Zmm(vmm_dst_idx));

        // for xmm, upper half is zero after conversion to
        // bf16, so mask always & mask for tails
        vmovdqu16(addr,
                maybe_mask_vmm_down(vmm_down, mask_flag || jcp.simd_w == 4));
    }

    // Code-generation stages; bodies live in the companion .cpp.
    void prepare_output(int ur_w); // zero/init accumulators
    void apply_sum(int ur_w, bool last_oc_block_flag, const int nb_oc_block,
            const int oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp); // sum post-op
    void apply_postops(int ur_w, bool last_oc_block_flag, const int nb_oc_block,
            const int oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp); // full post-op chain
    void store_output(int ur_w, bool last_oc_block_flag);
    void compute_ker_dw(int ur_w, int pad_l, int pad_r,
            ic_block_t last_ic_block_flag, bool h_padded); // depthwise path
    void compute_ker(int ur_w, int pad_l, int pad_r,
            ic_block_t last_ic_block_flag, bool h_padded = false);
    void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
    void generate() override; // jit_generator entry point
    void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
            bool mask_flag); // load + convert an operand to f32
    Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
};
241
242struct jit_avx512_core_x8s8s32x_fwd_kernel {
243
244 jit_avx512_core_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp,
245 const primitive_attr_t &attr, const memory_desc_t &dst_md)
246 : kernel_(nullptr) {
247 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
248 switch (ch_block) {
249 case 16:
250 kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
251 ajcp, attr, dst_md);
252 return;
253 case 8:
254 kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
255 ajcp, attr, dst_md);
256 return;
257 case 4:
258 kernel_ = new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
259 ajcp, attr, dst_md);
260 return;
261 default: assert(!"invalid channel blocking");
262 }
263 }
264
265 status_t create_kernel() { return kernel_->create_kernel(); }
266
267 ~jit_avx512_core_x8s8s32x_fwd_kernel() { delete kernel_; }
268
269 static status_t init_conf(jit_conv_conf_t &jcp,
270 const convolution_desc_t &cd, memory_desc_t &src_pd,
271 memory_desc_t &weights_pd, memory_desc_t &dst_pd,
272 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
273 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
274 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
275 void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
276 const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }
277
278private:
279 jit_generator *kernel_;
280};
281
282} // namespace x64
283} // namespace cpu
284} // namespace impl
285} // namespace dnnl
286
287#endif
288