1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP |
19 | |
#include <cassert>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator { |
35 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_fwd_kernel_bf16) |
36 | |
37 | jit_avx512_dw_conv_fwd_kernel_bf16( |
38 | const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md); |
39 | |
40 | jit_conv_conf_t jcp; |
41 | |
42 | private: |
43 | using reg64_t = const Xbyak::Reg64; |
44 | using mask_t = const Xbyak::Opmask; |
45 | const Xbyak::AddressFrame &vmmword = zword; |
46 | |
47 | const int acc_idx_start = 2; |
48 | inline int get_max_regs() const { return isa_has_bf16(jcp.isa) ? 30 : 25; }; |
49 | |
50 | // dw convolution |
51 | reg64_t reg_input = r8; |
52 | reg64_t aux_reg_input = r9; |
53 | reg64_t reg_kernel = r10; |
54 | reg64_t aux_reg_kernel = r11; |
55 | reg64_t reg_ch_blocks = r12; |
56 | reg64_t reg_output = r13; |
57 | reg64_t reg_bias = r14; |
58 | reg64_t reg_kh = r15; |
59 | reg64_t iter_kh = rax; |
60 | reg64_t reg_oi = rbx; |
61 | |
62 | reg64_t reg_tmp = reg_ch_blocks; |
63 | |
64 | // fused convolution |
65 | reg64_t reg_input_buffer_ptr = rdx; |
66 | reg64_t aux_reg_input_buffer_ptr = rsi; |
67 | reg64_t reg_iw_offset = reg_input; //Hack: clear reg_input early in kernel |
68 | reg64_t reg_tail = rax; |
69 | mask_t k_oc_tail_mask = Xbyak::Opmask(2); |
70 | mask_t ktail_mask = k_oc_tail_mask; |
71 | mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3); |
72 | |
73 | Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); |
74 | Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1); |
75 | Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31); |
76 | |
77 | /* Registers used for bfloat16 emulation */ |
78 | Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); |
79 | Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); |
80 | Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); |
81 | reg64_t bf16_emu_reserv_4 = abi_not_param1; |
82 | Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29); |
83 | Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30); |
84 | |
85 | int get_acc_reg_idx(int idx) const; |
86 | |
87 | Xbyak::Zmm get_acc_reg(int idx); |
88 | |
89 | int get_ow_start(int ki, int pad_l) { |
90 | return nstl::max(0, |
91 | utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); |
92 | } |
93 | |
94 | int get_ow_end(int ur_w, int ki, int pad_r) { |
95 | return ur_w |
96 | - nstl::max(0, |
97 | utils::div_up( |
98 | pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1), |
99 | jcp.stride_w)); |
100 | } |
101 | |
102 | inline bool is_src_layout_nxc() { |
103 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
104 | format_tag::nwc); |
105 | } |
106 | |
107 | inline bool is_dst_layout_nxc() { |
108 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
109 | format_tag::nwc); |
110 | } |
111 | |
112 | inline void load_src(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); |
113 | inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r); |
114 | inline void loop_ow(int ur_ch_blocks); |
115 | inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w, int pad_l, |
116 | int pad_r, bool last_ch_block_flag); |
117 | inline void apply_postops( |
118 | int ur_ch_blocks, int ur_w, bool last_ch_block_flag); |
119 | inline void store_dst(int ur_ch_blocks, int ur_w, bool last_ch_block_flag); |
120 | |
121 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> |
122 | postops_injector_; |
123 | |
124 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
125 | |
126 | void generate() override; |
127 | }; |
128 | |
129 | struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator { |
130 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_data_kernel_bf16) |
131 | |
132 | jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp) |
133 | : jit_generator(jit_name()), jcp(ajcp), bf16_emu_(nullptr) { |
134 | |
135 | if (!isa_has_bf16(jcp.isa)) |
136 | bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1, |
137 | bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4, |
138 | bf16_emu_reserv_5, bf16_emu_reserv_6); |
139 | } |
140 | |
141 | ~jit_avx512_dw_conv_bwd_data_kernel_bf16() { delete bf16_emu_; } |
142 | |
143 | jit_conv_conf_t jcp; |
144 | |
145 | private: |
146 | using reg64_t = const Xbyak::Reg64; |
147 | |
148 | const int acc_idx_start = 2; |
149 | inline int get_max_regs() { return isa_has_bf16(jcp.isa) ? 30 : 25; }; |
150 | |
151 | Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0); |
152 | Xbyak::Zmm zmm_dst_reg = Xbyak::Zmm(1); |
153 | |
154 | inline Xbyak::Zmm get_acc_reg(int idx) { |
155 | assert(idx + acc_idx_start <= get_max_regs()); |
156 | return Xbyak::Zmm(idx + acc_idx_start); |
157 | } |
158 | |
159 | reg64_t reg_ddst = rax; |
160 | reg64_t aux_reg_ddst = r8; |
161 | reg64_t aux1_reg_ddst = abi_not_param1; |
162 | reg64_t reg_kernel = rdx; |
163 | reg64_t aux_reg_kernel = r10; |
164 | reg64_t aux1_reg_kernel = rbp; |
165 | reg64_t reg_dsrc = rsi; |
166 | |
167 | reg64_t reg_ur_str_w = r9; |
168 | reg64_t reg_ch_blocks = rbx; |
169 | |
170 | reg64_t iter_kh = r11; |
171 | reg64_t iter_kw = r12; |
172 | reg64_t reg_kh = r13; |
173 | reg64_t reg_kw = r14; |
174 | |
175 | reg64_t aux_reg_ch_blocks = r15; |
176 | reg64_t reg_tmp = r15; |
177 | Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1); |
178 | |
179 | Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); |
180 | Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); |
181 | Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); |
182 | reg64_t bf16_emu_reserv_4 = iter_kw; |
183 | Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29); |
184 | Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30); |
185 | |
186 | bf16_emulation_t *bf16_emu_; |
187 | |
188 | inline void ch_loop_body(int ur_ch_blocks, int unroll_w); |
189 | inline void unroll_width_body(int ur_ch_blocks); |
190 | inline void load_ddst(int ur_ch_blocks, int ur_str_w); |
191 | inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch); |
192 | inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch); |
193 | |
194 | void generate() override; |
195 | inline bool is_dsrc_layout_nxc() { |
196 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
197 | format_tag::nwc); |
198 | } |
199 | inline bool is_ddst_layout_nxc() { |
200 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
201 | format_tag::nwc); |
202 | } |
203 | }; |
204 | |
205 | struct jit_avx512_dw_conv_bwd_weights_kernel_bf16 : public jit_generator { |
206 | |
207 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_weights_kernel_bf16) |
208 | |
209 | jit_avx512_dw_conv_bwd_weights_kernel_bf16(const jit_conv_conf_t &ajcp) |
210 | : jit_generator(jit_name()), jcp(ajcp), bf16_emu_(nullptr) { |
211 | |
212 | if (!isa_has_bf16(jcp.isa)) |
213 | bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1, |
214 | bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4, |
215 | bf16_emu_reserv_5, bf16_emu_reserv_6); |
216 | } |
217 | |
218 | ~jit_avx512_dw_conv_bwd_weights_kernel_bf16() { delete bf16_emu_; } |
219 | |
220 | jit_conv_conf_t jcp; |
221 | |
222 | private: |
223 | using reg64_t = const Xbyak::Reg64; |
224 | const Xbyak::AddressFrame &vmmword = zword; |
225 | |
226 | const int max_unroll_w_ = 30; |
227 | const int block_size_ = 15; |
228 | |
229 | const int idx_start = 2; |
230 | inline int get_max_regs() const { return isa_has_bf16(jcp.isa) ? 30 : 25; }; |
231 | |
232 | /* Offset between input and accummulators is 3, therefore, assume 'kw' |
233 | * is no larger than 3*/ |
234 | Xbyak::Zmm zmm_bias_reg = Xbyak::Zmm(0); |
235 | Xbyak::Zmm zmm_out_reg = Xbyak::Zmm(1); |
236 | |
237 | inline Xbyak::Zmm get_acc_reg(int idx) { |
238 | assert(idx + idx_start <= get_max_regs()); |
239 | return Xbyak::Zmm(idx + idx_start); |
240 | } |
241 | inline Xbyak::Zmm get_input_reg(int idx) { |
242 | const int i_idx = idx_start + jcp.kw + idx % jcp.kw; |
243 | assert(i_idx <= get_max_regs()); |
244 | return Xbyak::Zmm(i_idx); |
245 | } |
246 | |
247 | reg64_t reg_tmp_input = r9; |
248 | reg64_t reg_tmp_output = r10; |
249 | reg64_t reg_tmp_filter = r13; |
250 | reg64_t reg_kh_offset = rax; |
251 | |
252 | /* parameter passed by driver into kernel */ |
253 | Xbyak::Reg8 reg_exec_flags = bl; |
254 | reg64_t reg_oh_worksize = r14; |
255 | reg64_t reg_oh = rax; |
256 | reg64_t reg_iter_ow_blk = r11; |
257 | reg64_t reg_kh_aux = rsi; |
258 | reg64_t reg_kh = rdx; |
259 | |
260 | /* Base addresses for convolution parameters. */ |
261 | reg64_t reg_input_baddr = r15; |
262 | reg64_t reg_output_baddr = r12; |
263 | reg64_t reg_filter_baddr = abi_not_param1; |
264 | reg64_t reg_bias_baddr = r13; |
265 | |
266 | reg64_t reg_tmp = r8; |
267 | |
268 | Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1); |
269 | |
270 | /* Registers used for bfloat16 emulation */ |
271 | Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26); |
272 | Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27); |
273 | Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28); |
274 | reg64_t bf16_emu_reserv_4 = r8; |
275 | Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29); |
276 | Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30); |
277 | |
278 | bf16_emulation_t *bf16_emu_; |
279 | |
280 | /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs |
281 | */ |
282 | void compute_ow_step_unroll(int unroll_w, int l_pad, int pad_offset, |
283 | int ow_block, bool is_last_ch); |
284 | |
285 | /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ |
286 | void compute_kh_step(int unroll_w, int l_pad, int pad_offset, int ow_block, |
287 | bool is_last_ch); |
288 | /* Channel loop for 'nxc' format */ |
289 | void compute_ch_loop(int unroll_w, int l_pad, int pad_offset, int ow_block); |
290 | void compute_h_loop(int unroll_w, int l_pad, int pad_offset, int ow_block); |
291 | |
292 | /* Write 'width' micro-kernel JITs; depending on the padding and convolution |
293 | * size, write a micro-kernel for the left ow-block, middle ow-block(s), and |
294 | * right ow-block.*/ |
295 | void compute_ow_block_unroll(); |
296 | void deploy_zero_filter(); |
297 | void zero_filter_kh_loop(); |
298 | void load_filter(bool is_last_ch = false); |
299 | void zero_filter(); |
300 | void load_bias(bool is_last_ch); |
301 | void zero_bias(); |
302 | void compute_bias_step_unroll(const int unroll_w, bool is_last_ch); |
303 | void compute_ch_loop_bias(bool do_load_bias); |
304 | void deploy_ch_loop_bias(); |
305 | void compute_single_ch_block_bias(); |
306 | void compute_spatial_loop_bias(bool is_last_ch); |
307 | void store_filter(bool is_last_ch = false); |
308 | void store_bias(bool is_last_ch); |
309 | void compute_bias(); |
310 | void calculate_w_unrolling( |
311 | int &unroll_trips, int &unroll_w, int &unroll_w_tail); |
312 | |
313 | void generate() override; |
314 | |
315 | inline bool is_layout_nxc() { |
316 | return utils::everyone_is( |
317 | true, is_src_layout_nxc(), is_ddst_layout_nxc()); |
318 | } |
319 | inline bool is_src_layout_nxc() { |
320 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
321 | format_tag::nwc); |
322 | } |
323 | inline bool is_ddst_layout_nxc() { |
324 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
325 | format_tag::nwc); |
326 | } |
327 | }; |
328 | |
329 | } // namespace x64 |
330 | } // namespace cpu |
331 | } // namespace impl |
332 | } // namespace dnnl |
333 | |
#endif /* CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP */
335 | |