1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX2_CONV_KERNEL_F32_HPP |
18 | #define CPU_X64_JIT_AVX2_CONV_KERNEL_F32_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | |
24 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
25 | #include "cpu/x64/jit_generator.hpp" |
26 | #include "cpu/x64/jit_primitive_conf.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | struct jit_avx2_conv_fwd_kernel_f32 : public jit_generator { |
34 | jit_avx2_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp, |
35 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
36 | |
37 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32) |
38 | |
39 | static status_t init_conf(jit_conv_conf_t &jcp, |
40 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
41 | const memory_desc_wrapper &weights_d, |
42 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); |
43 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
44 | const jit_conv_conf_t &jcp); |
45 | |
46 | jit_conv_conf_t jcp; |
47 | const primitive_attr_t &attr_; |
48 | |
49 | private: |
50 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx2>> |
51 | postops_injector_; |
52 | |
53 | constexpr static int isa_simd_width_ |
54 | = cpu_isa_traits<avx2>::vlen / sizeof(float); |
55 | using reg64_t = const Xbyak::Reg64; |
56 | reg64_t reg_input = rax; |
57 | reg64_t aux_reg_input = r8; |
58 | reg64_t reg_kernel = rdx; |
59 | reg64_t aux_reg_kernel = r9; |
60 | reg64_t reg_output = rsi; |
61 | reg64_t reg_bias = rbx; |
62 | |
63 | reg64_t aux_reg_inp_d = r11; |
64 | reg64_t aux_reg_ker_d = abi_not_param1; |
65 | |
66 | reg64_t reg_ki = rsi; |
67 | reg64_t kj = r10; |
68 | reg64_t oi_iter = r11; |
69 | reg64_t ki_iter = r12; |
70 | reg64_t reg_channel = ki_iter; |
71 | reg64_t reg_kh = abi_not_param1; |
72 | reg64_t reg_oc_blocks = r14; |
73 | reg64_t imm_addr64 = r15; |
74 | reg64_t reg_long_offt = r15; |
75 | Xbyak::Reg32 reg_ci_flag = r13d; |
76 | Xbyak::Reg32 reg_oc_flag = r14d; |
77 | |
78 | /* binary post-ops operand */ |
79 | reg64_t temp_offset_reg = r12; |
80 | |
81 | Xbyak::Ymm ytmp = Xbyak::Ymm(14); |
82 | |
83 | inline void oh_step_unroll_kw( |
84 | int ur_w, int pad_l, int pad_r, int oc_blocks); |
85 | inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); |
86 | void apply_postops(const int oc_blocks, const int ur_w, const int oc_tail); |
87 | inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); |
88 | inline void solve_common(int oc_blocks); |
89 | |
90 | inline dim_t filter_w_to_input(int ki, int oi = 0, int pad_l = 0) { |
91 | return ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l; |
92 | }; |
93 | inline dim_t filter_h_to_input(int ki) { |
94 | return ki * (jcp.dilate_h + 1) * jcp.iw; |
95 | }; |
96 | inline dim_t filter_d_to_input(int ki) { |
97 | return ki * (jcp.dilate_d + 1) * jcp.iw * jcp.ih; |
98 | }; |
99 | |
100 | inline dim_t get_input_offset(int i_ic, int i_iw) { |
101 | dim_t offset; |
102 | if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw, |
103 | format_tag::ncdhw)) { |
104 | offset = static_cast<dim_t>(i_ic) * jcp.id * jcp.ih * jcp.iw + i_iw; |
105 | } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc, |
106 | format_tag::ndhwc)) { |
107 | offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups + i_ic; |
108 | } else { |
109 | offset = static_cast<dim_t>(i_iw) * jcp.ic_block + i_ic; |
110 | } |
111 | return sizeof(float) * offset; |
112 | } |
113 | |
114 | inline dim_t get_output_offset(int i_oc_block, int i_ow) { |
115 | dim_t offset; |
116 | if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc, |
117 | format_tag::ndhwc)) { |
118 | offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups |
119 | + i_oc_block * jcp.oc_block; |
120 | } else { |
121 | offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow |
122 | * jcp.oc_block |
123 | + i_ow * jcp.oc_block; |
124 | } |
125 | return sizeof(float) * offset; |
126 | } |
127 | |
128 | inline dim_t get_kernel_offset(int i_oc_block, int ki, int i_ic) { |
129 | dim_t block_step_size = jcp.ic_block * jcp.oc_block; |
130 | dim_t ic_block_step_size = static_cast<dim_t>(jcp.kd) * jcp.kh * jcp.kw |
131 | * block_step_size; |
132 | dim_t oc_block_step_size |
133 | = static_cast<dim_t>(jcp.nb_ic) * ic_block_step_size; |
134 | dim_t offset = static_cast<dim_t>(i_oc_block) * oc_block_step_size |
135 | + ki * block_step_size + i_ic * jcp.oc_block; |
136 | return sizeof(float) * offset; |
137 | } |
138 | |
139 | inline bool is_src_layout_nxc() { |
140 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
141 | format_tag::nwc); |
142 | } |
143 | |
144 | void generate() override; |
145 | }; |
146 | |
147 | struct jit_avx2_conv_bwd_data_kernel_f32 : public jit_generator { |
148 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32) |
149 | |
150 | jit_avx2_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) |
151 | : jit_generator(jit_name()), jcp(ajcp) {} |
152 | |
153 | static status_t init_conf(jit_conv_conf_t &jcp, |
154 | const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, |
155 | const memory_desc_wrapper &weights_d, |
156 | const memory_desc_wrapper &diff_dst_d); |
157 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
158 | const jit_conv_conf_t &jcp); |
159 | |
160 | jit_conv_conf_t jcp; |
161 | |
162 | private: |
163 | using reg64_t = const Xbyak::Reg64; |
164 | |
165 | reg64_t reg_ddst = rax; |
166 | reg64_t aux_reg_ddst = r8; |
167 | reg64_t reg_kernel = rdx; |
168 | reg64_t aux_reg_kernel = r10; |
169 | reg64_t reg_dsrc = rsi; |
170 | reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only |
171 | reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5 |
172 | case only */ |
173 | |
174 | reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only |
175 | reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only |
176 | |
177 | reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only |
178 | reg64_t kj = r11; |
179 | reg64_t oi_iter = r12; |
180 | reg64_t reg_kh = r14; |
181 | reg64_t reg_channel = r13; // used in ndims < 5 case only |
182 | reg64_t reg_channel_work = r9; // used in ndims < 5 case only |
183 | reg64_t reg_long_offt = r15; |
184 | reg64_t reg_reduce_work = reg_long_offt; |
185 | Xbyak::Reg32 reg_ci_flag = r13d; // used for nxc tails |
186 | |
187 | inline void compute_loop(int ur_w, int l_overflow, int r_overflow); |
188 | |
189 | void generate() override; |
190 | |
191 | inline int get_iw_start(int ki, int l_overflow) { |
192 | int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w |
193 | + l_overflow * jcp.stride_w |
194 | - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); |
195 | while (res < 0) |
196 | res += jcp.stride_w; |
197 | |
198 | return res; |
199 | } |
200 | |
201 | inline int get_iw_end(int ur_w, int ki, int r_overflow) { |
202 | if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) |
203 | ur_w += nstl::min(0, jcp.r_pad); // remove negative padding |
204 | int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w |
205 | + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); |
206 | while (res < 0) |
207 | res += jcp.stride_w; |
208 | |
209 | return ur_w - res; |
210 | } |
211 | |
212 | inline dim_t filter_w_to_ddst(int ki, int oi = 0, int pad_l = 0) { |
213 | return (oi + pad_l - ki * (jcp.dilate_w + 1)) / jcp.stride_w; |
214 | } |
215 | |
216 | inline dim_t get_ddst_offset(int i_oc_block, int i_ow, int i_oc) { |
217 | dim_t offset; |
218 | if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc, |
219 | format_tag::ndhwc)) { |
220 | offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups |
221 | + i_oc_block * jcp.oc_block + i_oc; |
222 | } else { |
223 | offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow |
224 | * jcp.oc_block |
225 | + i_ow * jcp.oc_block + i_oc; |
226 | } |
227 | return sizeof(float) * offset; |
228 | } |
229 | |
230 | inline dim_t get_dsrc_offset(int i_ic_block, int i_iw) { |
231 | dim_t offset; |
232 | if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc, |
233 | format_tag::ndhwc)) { |
234 | offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups |
235 | + i_ic_block * jcp.ic_block; |
236 | } else { |
237 | offset = static_cast<dim_t>(i_ic_block) * jcp.id * jcp.ih * jcp.iw |
238 | * jcp.ic_block |
239 | + i_iw * jcp.ic_block; |
240 | } |
241 | return sizeof(float) * offset; |
242 | } |
243 | |
244 | inline dim_t get_kernel_offset( |
245 | int i_oc_block, int i_ic_block, int ki, int i_oc) { |
246 | dim_t block_step_size = jcp.ic_block * jcp.oc_block; |
247 | dim_t ic_block_step_size = static_cast<dim_t>(jcp.kd) * jcp.kh * jcp.kw |
248 | * block_step_size; |
249 | dim_t oc_block_step_size |
250 | = static_cast<dim_t>(jcp.nb_ic) * ic_block_step_size; |
251 | dim_t offset = static_cast<dim_t>(i_oc_block) * oc_block_step_size |
252 | + i_ic_block * ic_block_step_size + ki * block_step_size |
253 | + i_oc * jcp.ic_block; |
254 | return sizeof(float) * offset; |
255 | } |
256 | }; |
257 | |
258 | struct jit_avx2_conv_bwd_weights_kernel_f32 : public jit_generator { |
259 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32) |
260 | |
261 | jit_avx2_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp) |
262 | : jit_generator(jit_name()), jcp(ajcp) {} |
263 | |
264 | static status_t init_conf(jit_conv_conf_t &jcp, |
265 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
266 | const memory_desc_wrapper &diff_weights_d, |
267 | const memory_desc_wrapper &diff_dst_d); |
268 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
269 | const jit_conv_conf_t &jcp); |
270 | |
271 | jit_conv_conf_t jcp; |
272 | |
273 | private: |
274 | using reg64_t = const Xbyak::Reg64; |
275 | reg64_t reg_input = rax; |
276 | reg64_t reg_kernel = rdx; |
277 | reg64_t reg_output = rsi; |
278 | reg64_t b_ic = abi_not_param1; |
279 | reg64_t kj = r8; |
280 | reg64_t reg_kh = r9; |
281 | reg64_t reg_ur_w_trips = r10; |
282 | reg64_t reg_tmp = r11; |
283 | reg64_t reg_oj = r15; |
284 | reg64_t reg_ih_count = rbx; |
285 | reg64_t aux_reg_input = r12; |
286 | reg64_t aux_reg_kernel = r13; |
287 | reg64_t ki = r14; |
288 | reg64_t reg_long_offt = r11; |
289 | reg64_t reg_channel = reg_ih_count; // used for nxc tails |
290 | Xbyak::Reg32 reg_ci_flag = r9d; // used for nxc tails |
291 | |
292 | inline void od_step_comeback_pointers(); |
293 | inline void oh_step_comeback_pointers(); |
294 | inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, |
295 | int ic_block_step, int input_offset, int kernel_offset, |
296 | int output_offset); |
297 | inline void compute_oh_step_disp(); |
298 | inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); |
299 | inline void compute_oh_step_common(int ic_block_step, int max_ur_w); |
300 | inline void compute_oh_loop_common(); |
301 | |
302 | inline dim_t get_input_offset(int i_ic, int i_iw) { |
303 | dim_t offset; |
304 | if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw, |
305 | format_tag::ncdhw)) { |
306 | offset = static_cast<dim_t>(i_ic) * jcp.id * jcp.ih * jcp.iw + i_iw; |
307 | } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc, |
308 | format_tag::ndhwc)) { |
309 | offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups + i_ic; |
310 | } else { |
311 | offset = static_cast<dim_t>(i_iw) * jcp.ic_block + i_ic; |
312 | } |
313 | return sizeof(float) * offset; |
314 | } |
315 | |
316 | inline dim_t get_output_offset(int i_oc_block, int i_ow) { |
317 | dim_t offset; |
318 | if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc, |
319 | format_tag::ndhwc)) { |
320 | offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups |
321 | + i_oc_block * jcp.oc_block; |
322 | } else { |
323 | offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow |
324 | * jcp.oc_block |
325 | + i_ow * jcp.oc_block; |
326 | } |
327 | return sizeof(float) * offset; |
328 | } |
329 | |
330 | inline dim_t get_kernel_offset(int ki, int i_ic) { |
331 | dim_t block_step_size = jcp.ic_block * jcp.oc_block; |
332 | dim_t offset = static_cast<dim_t>(ki) * block_step_size |
333 | + i_ic * jcp.oc_block; |
334 | return sizeof(float) * offset; |
335 | } |
336 | void generate() override; |
337 | }; |
338 | |
339 | } // namespace x64 |
340 | } // namespace cpu |
341 | } // namespace impl |
342 | } // namespace dnnl |
343 | |
344 | #endif |
345 | |