1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP
19
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34struct jit_avx512_dw_conv_fwd_kernel_bf16 : public jit_generator {
35 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_fwd_kernel_bf16)
36
37 jit_avx512_dw_conv_fwd_kernel_bf16(
38 const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md);
39
40 jit_conv_conf_t jcp;
41
42private:
43 using reg64_t = const Xbyak::Reg64;
44 using mask_t = const Xbyak::Opmask;
45 const Xbyak::AddressFrame &vmmword = zword;
46
47 const int acc_idx_start = 2;
48 inline int get_max_regs() const { return isa_has_bf16(jcp.isa) ? 30 : 25; };
49
50 // dw convolution
51 reg64_t reg_input = r8;
52 reg64_t aux_reg_input = r9;
53 reg64_t reg_kernel = r10;
54 reg64_t aux_reg_kernel = r11;
55 reg64_t reg_ch_blocks = r12;
56 reg64_t reg_output = r13;
57 reg64_t reg_bias = r14;
58 reg64_t reg_kh = r15;
59 reg64_t iter_kh = rax;
60 reg64_t reg_oi = rbx;
61
62 reg64_t reg_tmp = reg_ch_blocks;
63
64 // fused convolution
65 reg64_t reg_input_buffer_ptr = rdx;
66 reg64_t aux_reg_input_buffer_ptr = rsi;
67 reg64_t reg_iw_offset = reg_input; //Hack: clear reg_input early in kernel
68 reg64_t reg_tail = rax;
69 mask_t k_oc_tail_mask = Xbyak::Opmask(2);
70 mask_t ktail_mask = k_oc_tail_mask;
71 mask_t k_ch_tail_mask_extended = Xbyak::Opmask(3);
72
73 Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0);
74 Xbyak::Zmm zmm_src_reg = Xbyak::Zmm(1);
75 Xbyak::Zmm zmm_prev_dst = Xbyak::Zmm(31);
76
77 /* Registers used for bfloat16 emulation */
78 Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
79 Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
80 Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
81 reg64_t bf16_emu_reserv_4 = abi_not_param1;
82 Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29);
83 Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30);
84
85 int get_acc_reg_idx(int idx) const;
86
87 Xbyak::Zmm get_acc_reg(int idx);
88
89 int get_ow_start(int ki, int pad_l) {
90 return nstl::max(0,
91 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
92 }
93
94 int get_ow_end(int ur_w, int ki, int pad_r) {
95 return ur_w
96 - nstl::max(0,
97 utils::div_up(
98 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
99 jcp.stride_w));
100 }
101
102 inline bool is_src_layout_nxc() {
103 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
104 format_tag::nwc);
105 }
106
107 inline bool is_dst_layout_nxc() {
108 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
109 format_tag::nwc);
110 }
111
112 inline void load_src(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
113 inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r);
114 inline void loop_ow(int ur_ch_blocks);
115 inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w, int pad_l,
116 int pad_r, bool last_ch_block_flag);
117 inline void apply_postops(
118 int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
119 inline void store_dst(int ur_ch_blocks, int ur_w, bool last_ch_block_flag);
120
121 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
122 postops_injector_;
123
124 std::unique_ptr<bf16_emulation_t> bf16_emu_;
125
126 void generate() override;
127};
128
129struct jit_avx512_dw_conv_bwd_data_kernel_bf16 : public jit_generator {
130 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_data_kernel_bf16)
131
132 jit_avx512_dw_conv_bwd_data_kernel_bf16(const jit_conv_conf_t &ajcp)
133 : jit_generator(jit_name()), jcp(ajcp), bf16_emu_(nullptr) {
134
135 if (!isa_has_bf16(jcp.isa))
136 bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
137 bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
138 bf16_emu_reserv_5, bf16_emu_reserv_6);
139 }
140
141 ~jit_avx512_dw_conv_bwd_data_kernel_bf16() { delete bf16_emu_; }
142
143 jit_conv_conf_t jcp;
144
145private:
146 using reg64_t = const Xbyak::Reg64;
147
148 const int acc_idx_start = 2;
149 inline int get_max_regs() { return isa_has_bf16(jcp.isa) ? 30 : 25; };
150
151 Xbyak::Zmm zmm_ker_reg = Xbyak::Zmm(0);
152 Xbyak::Zmm zmm_dst_reg = Xbyak::Zmm(1);
153
154 inline Xbyak::Zmm get_acc_reg(int idx) {
155 assert(idx + acc_idx_start <= get_max_regs());
156 return Xbyak::Zmm(idx + acc_idx_start);
157 }
158
159 reg64_t reg_ddst = rax;
160 reg64_t aux_reg_ddst = r8;
161 reg64_t aux1_reg_ddst = abi_not_param1;
162 reg64_t reg_kernel = rdx;
163 reg64_t aux_reg_kernel = r10;
164 reg64_t aux1_reg_kernel = rbp;
165 reg64_t reg_dsrc = rsi;
166
167 reg64_t reg_ur_str_w = r9;
168 reg64_t reg_ch_blocks = rbx;
169
170 reg64_t iter_kh = r11;
171 reg64_t iter_kw = r12;
172 reg64_t reg_kh = r13;
173 reg64_t reg_kw = r14;
174
175 reg64_t aux_reg_ch_blocks = r15;
176 reg64_t reg_tmp = r15;
177 Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1);
178
179 Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
180 Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
181 Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
182 reg64_t bf16_emu_reserv_4 = iter_kw;
183 Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29);
184 Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30);
185
186 bf16_emulation_t *bf16_emu_;
187
188 inline void ch_loop_body(int ur_ch_blocks, int unroll_w);
189 inline void unroll_width_body(int ur_ch_blocks);
190 inline void load_ddst(int ur_ch_blocks, int ur_str_w);
191 inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch);
192 inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch);
193
194 void generate() override;
195 inline bool is_dsrc_layout_nxc() {
196 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
197 format_tag::nwc);
198 }
199 inline bool is_ddst_layout_nxc() {
200 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
201 format_tag::nwc);
202 }
203};
204
205struct jit_avx512_dw_conv_bwd_weights_kernel_bf16 : public jit_generator {
206
207 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_dw_conv_bwd_weights_kernel_bf16)
208
209 jit_avx512_dw_conv_bwd_weights_kernel_bf16(const jit_conv_conf_t &ajcp)
210 : jit_generator(jit_name()), jcp(ajcp), bf16_emu_(nullptr) {
211
212 if (!isa_has_bf16(jcp.isa))
213 bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserv_1,
214 bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
215 bf16_emu_reserv_5, bf16_emu_reserv_6);
216 }
217
218 ~jit_avx512_dw_conv_bwd_weights_kernel_bf16() { delete bf16_emu_; }
219
220 jit_conv_conf_t jcp;
221
222private:
223 using reg64_t = const Xbyak::Reg64;
224 const Xbyak::AddressFrame &vmmword = zword;
225
226 const int max_unroll_w_ = 30;
227 const int block_size_ = 15;
228
229 const int idx_start = 2;
230 inline int get_max_regs() const { return isa_has_bf16(jcp.isa) ? 30 : 25; };
231
232 /* Offset between input and accummulators is 3, therefore, assume 'kw'
233 * is no larger than 3*/
234 Xbyak::Zmm zmm_bias_reg = Xbyak::Zmm(0);
235 Xbyak::Zmm zmm_out_reg = Xbyak::Zmm(1);
236
237 inline Xbyak::Zmm get_acc_reg(int idx) {
238 assert(idx + idx_start <= get_max_regs());
239 return Xbyak::Zmm(idx + idx_start);
240 }
241 inline Xbyak::Zmm get_input_reg(int idx) {
242 const int i_idx = idx_start + jcp.kw + idx % jcp.kw;
243 assert(i_idx <= get_max_regs());
244 return Xbyak::Zmm(i_idx);
245 }
246
247 reg64_t reg_tmp_input = r9;
248 reg64_t reg_tmp_output = r10;
249 reg64_t reg_tmp_filter = r13;
250 reg64_t reg_kh_offset = rax;
251
252 /* parameter passed by driver into kernel */
253 Xbyak::Reg8 reg_exec_flags = bl;
254 reg64_t reg_oh_worksize = r14;
255 reg64_t reg_oh = rax;
256 reg64_t reg_iter_ow_blk = r11;
257 reg64_t reg_kh_aux = rsi;
258 reg64_t reg_kh = rdx;
259
260 /* Base addresses for convolution parameters. */
261 reg64_t reg_input_baddr = r15;
262 reg64_t reg_output_baddr = r12;
263 reg64_t reg_filter_baddr = abi_not_param1;
264 reg64_t reg_bias_baddr = r13;
265
266 reg64_t reg_tmp = r8;
267
268 Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1);
269
270 /* Registers used for bfloat16 emulation */
271 Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
272 Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
273 Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
274 reg64_t bf16_emu_reserv_4 = r8;
275 Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(29);
276 Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(30);
277
278 bf16_emulation_t *bf16_emu_;
279
280 /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
281 */
282 void compute_ow_step_unroll(int unroll_w, int l_pad, int pad_offset,
283 int ow_block, bool is_last_ch);
284
285 /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
286 void compute_kh_step(int unroll_w, int l_pad, int pad_offset, int ow_block,
287 bool is_last_ch);
288 /* Channel loop for 'nxc' format */
289 void compute_ch_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);
290 void compute_h_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);
291
292 /* Write 'width' micro-kernel JITs; depending on the padding and convolution
293 * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
294 * right ow-block.*/
295 void compute_ow_block_unroll();
296 void deploy_zero_filter();
297 void zero_filter_kh_loop();
298 void load_filter(bool is_last_ch = false);
299 void zero_filter();
300 void load_bias(bool is_last_ch);
301 void zero_bias();
302 void compute_bias_step_unroll(const int unroll_w, bool is_last_ch);
303 void compute_ch_loop_bias(bool do_load_bias);
304 void deploy_ch_loop_bias();
305 void compute_single_ch_block_bias();
306 void compute_spatial_loop_bias(bool is_last_ch);
307 void store_filter(bool is_last_ch = false);
308 void store_bias(bool is_last_ch);
309 void compute_bias();
310 void calculate_w_unrolling(
311 int &unroll_trips, int &unroll_w, int &unroll_w_tail);
312
313 void generate() override;
314
315 inline bool is_layout_nxc() {
316 return utils::everyone_is(
317 true, is_src_layout_nxc(), is_ddst_layout_nxc());
318 }
319 inline bool is_src_layout_nxc() {
320 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
321 format_tag::nwc);
322 }
323 inline bool is_ddst_layout_nxc() {
324 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
325 format_tag::nwc);
326 }
327};
328
329} // namespace x64
330} // namespace cpu
331} // namespace impl
332} // namespace dnnl
333
#endif /* CPU_X64_JIT_AVX512_CORE_BF16_DW_CONV_KERNEL_HPP */
335