/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_DW_CONV_KERNEL_F32_HPP
#define CPU_X64_JIT_UNI_DW_CONV_KERNEL_F32_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa>
struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)

    jit_uni_dw_conv_fwd_kernel_f32(
            const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md);

    jit_conv_conf_t jcp;

private:
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
    using reg64_t = const Xbyak::Reg64;
    using mask_t = const Xbyak::Opmask;
    const Xbyak::AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
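    // Note: xword/yword/zword are Xbyak address frames; vmmword tags a memory
    // operand with the 128-, 256- or 512-bit width that matches the Vmm type
    // chosen above, so vmmword[addr] always moves one full vector.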
    const int vlen = cpu_isa_traits<isa>::vlen;

    // dw convolution
    reg64_t reg_input = r8;
    reg64_t aux_reg_input = r9;
    reg64_t reg_kernel = r10;
    reg64_t aux_reg_kernel = r11;
    reg64_t reg_ch_blocks = r12;
    reg64_t reg_output = r13;
    reg64_t reg_bias = r14;
    reg64_t reg_kh = r15;
    reg64_t iter_kh = rax;
    reg64_t reg_oi = rbx;
    reg64_t aux_reg_ch_blocks = rsi;
    // fused convolution
    reg64_t reg_input_buffer_ptr = rdx;
    reg64_t aux_reg_input_buffer_ptr = rbp;
    reg64_t reg_iw_offset = reg_input; // Hack: clear reg_input early in kernel

    reg64_t reg_tmp = reg_ch_blocks;
    reg64_t reg_tail = rax;
    mask_t k_oc_tail_mask = Xbyak::Opmask(2);

    inline void load_src(int ur_ch_blocks, int ur_w, bool is_ch_tail);
    inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r);
    inline void ow_loop(int ur_ch_blocks);
    inline void apply_filter_unrolled(
            int ur_ch_blocks, int ur_w, int pad_l, int pad_r, bool is_ch_tail);
    inline void apply_postops(
            const int ur_ch_blocks, const int ur_w, const bool is_ch_tail);
    inline void store_dst(int ur_ch_blocks, int ur_w, bool is_ch_tail);

    int max_repeats() { return jcp.isa == sse41 ? 2 : 1; }
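    // A "repeat" is a second pass over the same channel block: sse41 only has
    // 128-bit Xmm registers (4 f32 lanes), so an 8-channel block takes two
    // passes, while avx2/avx512_core cover a block with one wider vector.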

    inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
    inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
    inline int get_acc_reg_idx(int idx) {
        const int max_regs = jcp.isa == avx512_core ? 32 : 16;
        return idx
                + (max_regs - jcp.ur_w * jcp.nb_ch_blocking * max_repeats());
    }
    inline Vmm get_acc_reg(int idx) { return Vmm(get_acc_reg_idx(idx)); }
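    // Illustration with assumed (not header-mandated) values: for avx2
    // (max_regs = 16), ur_w = 4, nb_ch_blocking = 2 and one repeat, the
    // accumulators occupy vmm8..vmm15, i.e. the top of the register file,
    // leaving the low registers for the kernel and source operands.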

    void load_tail(
            Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int load_size);
    void add_tail_from_mem(Vmm &vmm_acc, Vmm &vmm_tmp, const Xbyak::Reg64 &reg,
            int64_t offset, int load_size);
    void store_tail(
            Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset, int store_size);

    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }
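    // Worked example (assumed values): pad_l = 1, ki = 0, stride_w = 1,
    // dilate_w = 0 gives div_up(1 - 0, 1) = 1, so kernel tap 0 first
    // contributes at output column 1; taps that fall entirely into the left
    // padding are skipped.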

    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }
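    // Worked example (assumed values): ur_w = 4, kw = 3, ki = 2, pad_r = 1,
    // stride_w = 1, dilate_w = 0 gives 4 - div_up(1 - 0, 1) = 3, so the
    // right-most kernel tap skips the last output column of the row block.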

    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_dst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
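    // 'nxc' (nwc/nhwc/ndhwc) is the channels-last layout; everything else is
    // assumed to be a blocked format (e.g. nChw8c/nChw16c), which uses
    // different address arithmetic in the generated code.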

    std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
            postops_injector_;

    void generate() override;
};

template <cpu_isa_t isa>
struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)

    jit_uni_dw_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
        : jit_generator(jit_name()), jcp(ajcp) {}
    jit_conv_conf_t jcp;

private:
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
    const int reg_repeats_ = (isa == sse41) ? 2 : 1;
    const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
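    // simd_w_ is the number of f32 lanes per vector: 4 (sse41), 8 (avx2) or
    // 16 (avx512_core); as in the forward kernel, sse41 needs two register
    // repeats to cover one channel block.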
    using reg64_t = const Xbyak::Reg64;

    inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
    inline Vmm get_ddst_reg(int idx) { return Vmm(idx + 1); }
    inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }

    reg64_t reg_ddst = rax;
    reg64_t aux_reg_ddst = r8;
    reg64_t aux1_reg_ddst = abi_not_param1;
    reg64_t reg_kernel = rdx;
    reg64_t aux_reg_kernel = r10;
    reg64_t aux1_reg_kernel = rbp;
    reg64_t reg_dsrc = rsi;

    reg64_t reg_ur_str_w = r9;
    reg64_t reg_ch_blocks = rbx;

    reg64_t iter_kh = r11;
    reg64_t iter_kw = r12;
    reg64_t reg_kh = r13;
    reg64_t reg_kw = r14;

    reg64_t aux_reg_ch_blocks = r15;
    reg64_t reg_tmp = r15;
    Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1);

    void load_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail);
    void store_vmm(Vmm &vmm, const Xbyak::Address &addr, bool tail);

    inline void ch_loop_body(int ur_ch_blocks, int unroll_w);
    inline void unroll_width_body(int ur_ch_blocks);
    inline void load_ddst(int ur_ch_blocks, int ur_str_w);
    inline void apply_filter(int ur_ch_blocks, int ur_str_w, bool is_last_ch);
    inline void store_dsrc(int ur_ch_blocks, int ur_str_w, bool is_last_ch);

    void generate() override;

    inline bool tail_simd_overlap(int reg_repeat) {
        return reg_repeat * simd_w_ >= jcp.ch_tail;
    }
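    // Worked example (assumed values): with simd_w_ = 4 and jcp.ch_tail = 6,
    // tail_simd_overlap(1) = (4 >= 6) = false but tail_simd_overlap(2) =
    // (8 >= 6) = true, i.e. the second repeat runs past the channel tail and
    // must be handled with masked/partial loads.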

    inline bool is_dsrc_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};

template <cpu_isa_t isa>
struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)

    jit_uni_dw_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp)
        : jit_generator(jit_name()), jcp(ajcp) {}

    jit_conv_conf_t jcp;

private:
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;

    const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    const int reg_repeats_ = (isa == sse41) ? 2 : 1;
    const int req_aux_vmm = isa == sse41 ? 1 : 0; // used for FMA operand

    const int max_unroll_w_ = 30;
    const int block_size_ = 15;
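    // Width-unrolling limits for the micro-kernel; the rationale is assumed
    // here: they bound the number of FMAs emitted per unrolled block, trading
    // generated-code size against register pressure.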

    const Xbyak::AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    /* The offset between the input and the accumulator registers is 3;
     * therefore, assume 'kw' is no larger than 3. */
    inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
    inline Vmm get_output_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx + 2 * jcp.kw * jcp.nb_ch_blocking
                : idx + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_input_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx + jcp.kw * jcp.nb_ch_blocking
                : idx + 4 * reg_repeats_ + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_acc_reg(int idx) {
        int vmm_idx = jcp.is_fast_depthwise
                ? idx
                : idx + 1 * reg_repeats_ + req_aux_vmm;
        return Vmm(vmm_idx);
    }
    inline Vmm get_aux_reg() { return Vmm(0); }
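    // Register layout implied by the getters above for the generic
    // (non-fast-depthwise) path: outputs start at vmm[req_aux_vmm],
    // accumulators at vmm[reg_repeats_ + req_aux_vmm], and inputs at
    // vmm[4 * reg_repeats_ + req_aux_vmm]; on sse41, vmm0 doubles as the
    // auxiliary FMA operand returned by get_aux_reg().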

    using reg64_t = const Xbyak::Reg64;
    reg64_t reg_tmp_input = r9;
    reg64_t reg_tmp_output = r10;
    reg64_t reg_tmp_filter = r13;
    reg64_t reg_kh_offset = rax;

    /* Parameter passed by the driver into the kernel. */
    Xbyak::Reg8 reg_exec_flags = bl;

    reg64_t reg_oh_worksize = r14;
    reg64_t reg_oh = rax;

    reg64_t reg_iter_ow_blk = r11;

    reg64_t reg_kh_aux = rsi;
    reg64_t reg_kh = rdx;

    /* Base addresses for convolution parameters. */
    reg64_t reg_input_baddr = r15;
    reg64_t reg_output_baddr = r12;
    reg64_t reg_filter_baddr = abi_not_param1;
    reg64_t reg_bias_baddr = r13;

    reg64_t reg_tmp = r8;

    Xbyak::Opmask k_ch_tail_mask = Xbyak::Opmask(1);

    void addps_xmm(Vmm &vmm_dst, Vmm &vmm_src, const Xbyak::Address &addr,
            bool compute_tail);
    void load_xmm(
            Vmm &vmm, const Xbyak::Address &addr, bool compute_tail = false);
    void store_xmm(
            Vmm &vmm, const Xbyak::Address &addr, bool compute_tail = false);

    void dispatch_ow_step_unroll(int unroll_w, int l_pad, int pad_offset,
            int ow_block, int nb_ch_blocking, bool is_last_ch);

    /* Micro-kernel JIT'ing: fuses the 'kw' and 'ow_block' loops into
     * unrolled FMAs. */
    void compute_unroll_ow_step(int unroll_w, int l_pad, int pad_offset,
            int ow_block, bool is_last_ch);

    /* Micro-kernel JIT'ing: fuses the 'kw', 'ow_block' and 'nb_ch_blocking'
     * loops into unrolled FMAs. */
    void compute_unroll_ow_step_nxc(int unroll_w, int l_pad, int pad_offset,
            int ow_block, int nb_ch_blocking, bool is_last_ch);

    /* JIT'ing of the outer loops for the micro-kernel -> {kh, oh_block} */
    void compute_kh_step(int unroll_w, int l_pad, int pad_offset, int ow_block,
            int nb_ch_blocking, bool is_last_ch);
    /* Channel loop for the 'nxc' format */
    void compute_ch_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);
    void compute_h_loop(int unroll_w, int l_pad, int pad_offset, int ow_block);

    /* Write the 'width' micro-kernel JITs; depending on the padding and
     * convolution size, write a micro-kernel for the left ow-block, the
     * middle ow-block(s), and the right ow-block. */
    void compute_ow_block_unroll();

    void deploy_zero_filter();
    void zero_filter_ch_loop();
    void zero_filter_kh_loop(int nb_ch_blocking = 1);
    void load_filter(int nb_ch_blocking, bool is_last_ch = false);
    void zero_filter();
    void load_bias(int nb_ch_blocking, bool is_last_ch);
    void zero_bias();
    void compute_bias_step_unroll(
            const int unroll_w, int nb_ch_blocking, bool is_last_ch);
    void compute_ch_loop_bias(bool do_load_bias);
    void deploy_ch_loop_bias();
    void compute_single_ch_block_bias();
    void compute_spatial_loop_bias(int nb_ch_blocking, bool is_last_ch);
    void store_filter(int nb_ch_blocking, bool is_last_ch = false);
    void store_bias(int nb_ch_blocking, bool is_last_ch);
    void compute_bias();
    void calculate_w_unrolling(
            int &unroll_trips, int &unroll_w, int &unroll_w_tail);

    void generate() override;

    inline bool is_layout_nxc() {
        return utils::everyone_is(
                true, is_src_layout_nxc(), is_ddst_layout_nxc());
    }
    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif