1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_DW_CONV_KERNEL_F32_HPP |
18 | #define CPU_X64_JIT_UNI_DW_CONV_KERNEL_F32_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | |
23 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
24 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
25 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
26 | #include "cpu/x64/jit_generator.hpp" |
27 | #include "cpu/x64/jit_primitive_conf.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
/* JIT code generator for the f32 depthwise convolution forward kernel.
 * Templated on the target ISA (sse41 / avx2 / avx512_core); also serves the
 * fused dw-conv path, where the input is read from an intermediate buffer. */
template <cpu_isa_t isa>
struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)

    // 'ajcp' carries the pre-computed convolution configuration; 'dst_md' is
    // the destination descriptor (used when post-ops are attached).
    jit_uni_dw_conv_fwd_kernel_f32(
            const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md);

    jit_conv_conf_t jcp;

private:
    // Vector register type matching the ISA width:
    // xmm (sse41), ymm (avx2), zmm (otherwise, i.e. avx512).
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
    using reg64_t = const Xbyak::Reg64;
    using mask_t = const Xbyak::Opmask;
    // Address frame whose width matches Vmm for vector loads/stores.
    const Xbyak::AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
    const int vlen = cpu_isa_traits<isa>::vlen; // vector length, in bytes

    // dw convolution
    reg64_t reg_input = r8;
    reg64_t aux_reg_input = r9;
    reg64_t reg_kernel = r10;
    reg64_t aux_reg_kernel = r11;
    reg64_t reg_ch_blocks = r12;
    reg64_t reg_output = r13;
    reg64_t reg_bias = r14;
    reg64_t reg_kh = r15;
    reg64_t iter_kh = rax;
    reg64_t reg_oi = rbx;
    reg64_t aux_reg_ch_blocks = rsi;
    // fused convolution
    reg64_t reg_input_buffer_ptr = rdx;
    reg64_t aux_reg_input_buffer_ptr = rbp;
    reg64_t reg_iw_offset = reg_input; //Hack: clear reg_input early in kernel

    // NOTE: reg_tmp aliases reg_ch_blocks (r12) and reg_tail aliases iter_kh
    // (rax) — each pair must not be live at the same time.
    reg64_t reg_tmp = reg_ch_blocks;
    reg64_t reg_tail = rax;
    mask_t k_oc_tail_mask = Xbyak::Opmask(2); // opmask for channel-tail loads/stores

    // Load accumulators (bias or zero) before the filter loop.
    inline void load_src(int ur_ch_blocks, int ur_w, bool is_ch_tail);
    // Emit one unrolled compute step: load, filter, post-ops, store.
    inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r);
    // Loop over the output-width dimension with 'ur_w'-wide unrolled bodies.
    inline void ow_loop(int ur_ch_blocks);
    // Fully-unrolled FMA sequence over kh/kw for the current ur_w x ur_ch tile.
    inline void apply_filter_unrolled(
            int ur_ch_blocks, int ur_w, int pad_l, int pad_r, bool is_ch_tail);
    // Apply attached post-ops (via postops_injector_) to the accumulators.
    inline void apply_postops(
            const int ur_ch_blocks, const int ur_w, const bool is_ch_tail);
    // Write the accumulators back to the destination.
    inline void store_dst(int ur_ch_blocks, int ur_w, bool is_ch_tail);

    // sse41 xmm holds half the floats of ymm, so each block is computed in
    // two passes; wider ISAs need a single pass.
    int max_repeats() { return jcp.isa == sse41 ? 2 : 1; }

    // Register layout: vmm0 = kernel, vmm1.. = source, accumulators packed
    // at the top of the register file (32 vmms on avx512_core, 16 otherwise).
    inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
    inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
    inline int get_acc_reg_idx(int idx) {
        const int max_regs = jcp.isa == avx512_core ? 32 : 16;
        return idx + (max_regs - jcp.ur_w * jcp.nb_ch_blocking * max_repeats());
    }
    inline Vmm get_acc_reg(int idx) { return Vmm(get_acc_reg_idx(idx)); }

    // Partial-vector (channel tail) helpers: load/accumulate/store
    // 'load_size'/'store_size' elements at [reg + offset].
    void load_tail(
            Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size);
    void add_tail_from_mem(Vmm &vmm_acc, Vmm &vmm_tmp, const Xbyak::Reg64 &reg,
            int64_t offset, int load_size);
    void store_tail(
            Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size);

    // First output column for kernel column 'ki' that reads inside the input
    // (i.e. past the left padding).
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }

    // One-past-last output column (within ur_w) for kernel column 'ki' that
    // stays inside the input, given right padding.
    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }

    // Channels-last ('nxc') layout checks for source and destination.
    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_dst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }

    // Injects eltwise/binary post-op code into the generated kernel.
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
            postops_injector_;

    // Emits the whole kernel (jit_generator entry point).
    void generate() override;
};
126 | |
/* JIT code generator for the f32 depthwise convolution backward-data kernel:
 * computes diff_src from diff_dst and the weights. Templated on the ISA. */
template <cpu_isa_t isa>
struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)

    jit_uni_dw_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
        : jit_generator(jit_name()), jcp(ajcp) {}
    jit_conv_conf_t jcp;

private:
    // Vector register type matching the ISA width:
    // xmm (sse41), ymm (avx2), zmm (otherwise).
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
    // sse41 needs two passes per channel block (half-width registers).
    const int reg_repeats_ = (isa == sse41) ? 2 : 1;
    const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    using reg64_t = const Xbyak::Reg64;

    // Register layout: vmm0 = kernel, vmm1..3 = diff_dst, vmm4.. = accumulators.
    inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
    inline Vmm get_ddst_reg(int idx) { return Vmm(idx + 1); }
    inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }

    reg64_t reg_ddst = rax;
    reg64_t aux_reg_ddst = r8;
    reg64_t aux1_reg_ddst = abi_not_param1;
    reg64_t reg_kernel = rdx;
    reg64_t aux_reg_kernel = r10;
    reg64_t aux1_reg_kernel = rbp;
    reg64_t reg_dsrc = rsi;

    reg64_t reg_ur_str_w = r9;
    reg64_t reg_ch_blocks = rbx;

    reg64_t iter_kh = r11;
    reg64_t iter_kw = r12;
    reg64_t reg_kh = r13;
    reg64_t reg_kw = r14;

    // NOTE: aux_reg_ch_blocks and reg_tmp both map to r15 — they must not be
    // live at the same time.
    reg64_t aux_reg_ch_blocks = r15;
    reg64_t reg_tmp = r15;
    Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1); // opmask for channel tail

    // Masked/partial load and store helpers ('tail' selects the masked form).
    void load_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail);
    void store_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail);

    // Loop over channel blocks for one 'unroll_w'-wide strip.
    inline void ch_loop_body(int ur_ch_blocks, int unroll_w);
    // Loop over the width dimension in unrolled strips.
    inline void unroll_width_body(int ur_ch_blocks);
    // Load diff_dst into registers for the current tile.
    inline void load_ddst(int ur_ch_blocks, int ur_str_w);
    // Unrolled kh/kw FMA sequence accumulating into diff_src registers.
    inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch);
    // Store the accumulated diff_src tile.
    inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch);

    // Emits the whole kernel (jit_generator entry point).
    void generate() override;

    // True when the given sse41 repeat already covers (or passes) the channel
    // tail, i.e. its simd lanes overlap the tail region.
    inline bool tail_simd_overlap(int reg_repeat) {
        return reg_repeat * simd_w_ >= jcp.ch_tail;
    }

    // Channels-last ('nxc') layout checks for diff_src and diff_dst.
    inline bool is_dsrc_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};
190 | |
/* JIT code generator for the f32 depthwise convolution backward-weights
 * kernel: computes diff_weights (and optionally diff_bias) from src and
 * diff_dst. Templated on the ISA. */
template <cpu_isa_t isa>
struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)

    jit_uni_dw_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp)
        : jit_generator(jit_name()), jcp(ajcp) {}

    jit_conv_conf_t jcp;

private:
    // Vector register type matching the ISA width:
    // xmm (sse41), ymm (avx2), zmm (otherwise).
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;

    const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    // sse41 needs two passes per channel block (half-width registers).
    const int reg_repeats_ = (isa == sse41) ? 2 : 1;
    const int req_aux_vmm = isa == sse41 ? 1 : 0; // used for FMA operand

    // Width-unrolling limits for the micro-kernel.
    const int max_unroll_w_ = 30;
    const int block_size_ = 15;

    // Address frame whose width matches Vmm for vector loads/stores.
    const Xbyak::AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    /* Offset between input and accumulators is 3, therefore, assume 'kw'
     * is no larger than 3*/
    inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
    inline Vmm get_output_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx + 2 * jcp.kw * jcp.nb_ch_blocking
                : idx + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_input_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx + jcp.kw * jcp.nb_ch_blocking
                : idx + 4 * reg_repeats_ + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_acc_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx
                : idx + 1 * reg_repeats_ + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_aux_reg() { return Vmm(0); } // scratch (sse41 FMA operand)

    using reg64_t = const Xbyak::Reg64;
    reg64_t reg_tmp_input = r9;
    reg64_t reg_tmp_output = r10;
    // NOTE: reg_tmp_filter (r13) aliases reg_bias_baddr, and reg_kh_offset
    // (rax) aliases reg_oh — each pair must not be live at the same time.
    reg64_t reg_tmp_filter = r13;
    reg64_t reg_kh_offset = rax;

    /* parameter passed by driver into kernel */
    Xbyak::Reg8 reg_exec_flags = bl;

    reg64_t reg_oh_worksize = r14;
    reg64_t reg_oh = rax;

    reg64_t reg_iter_ow_blk = r11;

    reg64_t reg_kh_aux = rsi;
    reg64_t reg_kh = rdx;

    /* Base addresses for convolution parameters. */
    reg64_t reg_input_baddr = r15;
    reg64_t reg_output_baddr = r12;
    reg64_t reg_filter_baddr = abi_not_param1;
    reg64_t reg_bias_baddr = r13;

    reg64_t reg_tmp = r8;

    Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1); // opmask for channel tail

    // Accumulate from memory into vmm_dst ('compute_tail' selects masked form;
    // vmm_src serves as scratch where needed).
    void addps_xmm(Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr,
            bool compute_tail);
    // Masked/partial vector load and store helpers.
    void load_xmm(
            Vmm &vmm, const Xbyak::Address &addr, bool compute_tail = false);
    void store_xmm(
            Vmm &vmm, const Xbyak::Address &addr, bool compute_tail = false);

    // Select and emit the proper ow-step micro-kernel for the given unroll.
    void dispatch_ow_step_unroll(int unroll_w, int l_pad, int pad_offset,
            int ow_block, int nb_ch_blocking, bool is_last_ch);

    /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
     */
    void compute_unroll_ow_step(int unroll_w, int l_pad, int pad_offset,
            int ow_block, bool is_last_ch);

    /* Micro-kernel JIT'ing, fusing 'kw', 'ow_block' and 'nb_ch_blocking' loops
     * into unrolled FMAs. */
    void compute_unroll_ow_step_nxc(int unroll_w, int l_pad, int pad_offset,
            int ow_block, int nb_ch_blocking, bool is_last_ch);

    /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
    void compute_kh_step(int unroll_w, int l_pad, int pad_offset, int ow_block,
            int nb_ch_blocking, bool is_last_ch);
    /* Channel loop for 'nxc' format */
    void compute_ch_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);
    void compute_h_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);

    /* Write 'width' micro-kernel JITs; depending on the padding and convolution
     * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
     * right ow-block.*/
    void compute_ow_block_unroll();

    // Filter-zeroing pass (used before accumulation) and its loop bodies.
    void deploy_zero_filter();
    void zero_filter_ch_loop();
    void zero_filter_kh_loop(int nb_ch_blocking = 1);
    // Load / zero / store the filter accumulators.
    void load_filter(int nb_ch_blocking, bool is_last_ch = false);
    void zero_filter();
    // Bias computation: load/zero/accumulate/store diff_bias.
    void load_bias(int nb_ch_blocking, bool is_last_ch);
    void zero_bias();
    void compute_bias_step_unroll(
            const int unroll_w, int nb_ch_blocking, bool is_last_ch);
    void compute_ch_loop_bias(bool do_load_bias);
    void deploy_ch_loop_bias();
    void compute_single_ch_block_bias();
    void compute_spatial_loop_bias(int nb_ch_blocking, bool is_last_ch);
    void store_filter(int nb_ch_blocking, bool is_last_ch = false);
    void store_bias(int nb_ch_blocking, bool is_last_ch);
    void compute_bias();
    // Compute width-unrolling factors for the current configuration.
    void calculate_w_unrolling(
            int &unroll_trips, int &unroll_w, int &unroll_w_tail);

    // Emits the whole kernel (jit_generator entry point).
    void generate() override;

    // True only when both src and diff_dst use a channels-last layout.
    inline bool is_layout_nxc() {
        return utils::everyone_is(
                true, is_src_layout_nxc(), is_ddst_layout_nxc());
    }
    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};
331 | |
332 | } // namespace x64 |
333 | } // namespace cpu |
334 | } // namespace impl |
335 | } // namespace dnnl |
336 | |
337 | #endif |
338 | |