1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX2_CONV_KERNEL_F32_HPP
18#define CPU_X64_JIT_AVX2_CONV_KERNEL_F32_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory.hpp"
22#include "common/memory_tracking.hpp"
23
24#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
25#include "cpu/x64/jit_generator.hpp"
26#include "cpu/x64/jit_primitive_conf.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33struct jit_avx2_conv_fwd_kernel_f32 : public jit_generator {
34 jit_avx2_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp,
35 const primitive_attr_t &attr, const memory_desc_t &dst_md);
36
37 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32)
38
39 static status_t init_conf(jit_conv_conf_t &jcp,
40 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
41 const memory_desc_wrapper &weights_d,
42 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
43 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
44 const jit_conv_conf_t &jcp);
45
46 jit_conv_conf_t jcp;
47 const primitive_attr_t &attr_;
48
49private:
50 std::unique_ptr<injector::jit_uni_postops_injector_t<avx2>>
51 postops_injector_;
52
53 constexpr static int isa_simd_width_
54 = cpu_isa_traits<avx2>::vlen / sizeof(float);
55 using reg64_t = const Xbyak::Reg64;
56 reg64_t reg_input = rax;
57 reg64_t aux_reg_input = r8;
58 reg64_t reg_kernel = rdx;
59 reg64_t aux_reg_kernel = r9;
60 reg64_t reg_output = rsi;
61 reg64_t reg_bias = rbx;
62
63 reg64_t aux_reg_inp_d = r11;
64 reg64_t aux_reg_ker_d = abi_not_param1;
65
66 reg64_t reg_ki = rsi;
67 reg64_t kj = r10;
68 reg64_t oi_iter = r11;
69 reg64_t ki_iter = r12;
70 reg64_t reg_channel = ki_iter;
71 reg64_t reg_kh = abi_not_param1;
72 reg64_t reg_oc_blocks = r14;
73 reg64_t imm_addr64 = r15;
74 reg64_t reg_long_offt = r15;
75 Xbyak::Reg32 reg_ci_flag = r13d;
76 Xbyak::Reg32 reg_oc_flag = r14d;
77
78 /* binary post-ops operand */
79 reg64_t temp_offset_reg = r12;
80
81 Xbyak::Ymm ytmp = Xbyak::Ymm(14);
82
83 inline void oh_step_unroll_kw(
84 int ur_w, int pad_l, int pad_r, int oc_blocks);
85 inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
86 void apply_postops(const int oc_blocks, const int ur_w, const int oc_tail);
87 inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
88 inline void solve_common(int oc_blocks);
89
90 inline dim_t filter_w_to_input(int ki, int oi = 0, int pad_l = 0) {
91 return ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l;
92 };
93 inline dim_t filter_h_to_input(int ki) {
94 return ki * (jcp.dilate_h + 1) * jcp.iw;
95 };
96 inline dim_t filter_d_to_input(int ki) {
97 return ki * (jcp.dilate_d + 1) * jcp.iw * jcp.ih;
98 };
99
100 inline dim_t get_input_offset(int i_ic, int i_iw) {
101 dim_t offset;
102 if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw,
103 format_tag::ncdhw)) {
104 offset = static_cast<dim_t>(i_ic) * jcp.id * jcp.ih * jcp.iw + i_iw;
105 } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc,
106 format_tag::ndhwc)) {
107 offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups + i_ic;
108 } else {
109 offset = static_cast<dim_t>(i_iw) * jcp.ic_block + i_ic;
110 }
111 return sizeof(float) * offset;
112 }
113
114 inline dim_t get_output_offset(int i_oc_block, int i_ow) {
115 dim_t offset;
116 if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc,
117 format_tag::ndhwc)) {
118 offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups
119 + i_oc_block * jcp.oc_block;
120 } else {
121 offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow
122 * jcp.oc_block
123 + i_ow * jcp.oc_block;
124 }
125 return sizeof(float) * offset;
126 }
127
128 inline dim_t get_kernel_offset(int i_oc_block, int ki, int i_ic) {
129 dim_t block_step_size = jcp.ic_block * jcp.oc_block;
130 dim_t ic_block_step_size = static_cast<dim_t>(jcp.kd) * jcp.kh * jcp.kw
131 * block_step_size;
132 dim_t oc_block_step_size
133 = static_cast<dim_t>(jcp.nb_ic) * ic_block_step_size;
134 dim_t offset = static_cast<dim_t>(i_oc_block) * oc_block_step_size
135 + ki * block_step_size + i_ic * jcp.oc_block;
136 return sizeof(float) * offset;
137 }
138
139 inline bool is_src_layout_nxc() {
140 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
141 format_tag::nwc);
142 }
143
144 void generate() override;
145};
146
147struct jit_avx2_conv_bwd_data_kernel_f32 : public jit_generator {
148 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32)
149
150 jit_avx2_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
151 : jit_generator(jit_name()), jcp(ajcp) {}
152
153 static status_t init_conf(jit_conv_conf_t &jcp,
154 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
155 const memory_desc_wrapper &weights_d,
156 const memory_desc_wrapper &diff_dst_d);
157 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
158 const jit_conv_conf_t &jcp);
159
160 jit_conv_conf_t jcp;
161
162private:
163 using reg64_t = const Xbyak::Reg64;
164
165 reg64_t reg_ddst = rax;
166 reg64_t aux_reg_ddst = r8;
167 reg64_t reg_kernel = rdx;
168 reg64_t aux_reg_kernel = r10;
169 reg64_t reg_dsrc = rsi;
170 reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only
171 reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5
172 case only */
173
174 reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only
175 reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only
176
177 reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only
178 reg64_t kj = r11;
179 reg64_t oi_iter = r12;
180 reg64_t reg_kh = r14;
181 reg64_t reg_channel = r13; // used in ndims < 5 case only
182 reg64_t reg_channel_work = r9; // used in ndims < 5 case only
183 reg64_t reg_long_offt = r15;
184 reg64_t reg_reduce_work = reg_long_offt;
185 Xbyak::Reg32 reg_ci_flag = r13d; // used for nxc tails
186
187 inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
188
189 void generate() override;
190
191 inline int get_iw_start(int ki, int l_overflow) {
192 int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
193 + l_overflow * jcp.stride_w
194 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
195 while (res < 0)
196 res += jcp.stride_w;
197
198 return res;
199 }
200
201 inline int get_iw_end(int ur_w, int ki, int r_overflow) {
202 if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
203 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
204 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
205 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
206 while (res < 0)
207 res += jcp.stride_w;
208
209 return ur_w - res;
210 }
211
212 inline dim_t filter_w_to_ddst(int ki, int oi = 0, int pad_l = 0) {
213 return (oi + pad_l - ki * (jcp.dilate_w + 1)) / jcp.stride_w;
214 }
215
216 inline dim_t get_ddst_offset(int i_oc_block, int i_ow, int i_oc) {
217 dim_t offset;
218 if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc,
219 format_tag::ndhwc)) {
220 offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups
221 + i_oc_block * jcp.oc_block + i_oc;
222 } else {
223 offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow
224 * jcp.oc_block
225 + i_ow * jcp.oc_block + i_oc;
226 }
227 return sizeof(float) * offset;
228 }
229
230 inline dim_t get_dsrc_offset(int i_ic_block, int i_iw) {
231 dim_t offset;
232 if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc,
233 format_tag::ndhwc)) {
234 offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups
235 + i_ic_block * jcp.ic_block;
236 } else {
237 offset = static_cast<dim_t>(i_ic_block) * jcp.id * jcp.ih * jcp.iw
238 * jcp.ic_block
239 + i_iw * jcp.ic_block;
240 }
241 return sizeof(float) * offset;
242 }
243
244 inline dim_t get_kernel_offset(
245 int i_oc_block, int i_ic_block, int ki, int i_oc) {
246 dim_t block_step_size = jcp.ic_block * jcp.oc_block;
247 dim_t ic_block_step_size = static_cast<dim_t>(jcp.kd) * jcp.kh * jcp.kw
248 * block_step_size;
249 dim_t oc_block_step_size
250 = static_cast<dim_t>(jcp.nb_ic) * ic_block_step_size;
251 dim_t offset = static_cast<dim_t>(i_oc_block) * oc_block_step_size
252 + i_ic_block * ic_block_step_size + ki * block_step_size
253 + i_oc * jcp.ic_block;
254 return sizeof(float) * offset;
255 }
256};
257
258struct jit_avx2_conv_bwd_weights_kernel_f32 : public jit_generator {
259 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32)
260
261 jit_avx2_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp)
262 : jit_generator(jit_name()), jcp(ajcp) {}
263
264 static status_t init_conf(jit_conv_conf_t &jcp,
265 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
266 const memory_desc_wrapper &diff_weights_d,
267 const memory_desc_wrapper &diff_dst_d);
268 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
269 const jit_conv_conf_t &jcp);
270
271 jit_conv_conf_t jcp;
272
273private:
274 using reg64_t = const Xbyak::Reg64;
275 reg64_t reg_input = rax;
276 reg64_t reg_kernel = rdx;
277 reg64_t reg_output = rsi;
278 reg64_t b_ic = abi_not_param1;
279 reg64_t kj = r8;
280 reg64_t reg_kh = r9;
281 reg64_t reg_ur_w_trips = r10;
282 reg64_t reg_tmp = r11;
283 reg64_t reg_oj = r15;
284 reg64_t reg_ih_count = rbx;
285 reg64_t aux_reg_input = r12;
286 reg64_t aux_reg_kernel = r13;
287 reg64_t ki = r14;
288 reg64_t reg_long_offt = r11;
289 reg64_t reg_channel = reg_ih_count; // used for nxc tails
290 Xbyak::Reg32 reg_ci_flag = r9d; // used for nxc tails
291
292 inline void od_step_comeback_pointers();
293 inline void oh_step_comeback_pointers();
294 inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
295 int ic_block_step, int input_offset, int kernel_offset,
296 int output_offset);
297 inline void compute_oh_step_disp();
298 inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
299 inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
300 inline void compute_oh_loop_common();
301
302 inline dim_t get_input_offset(int i_ic, int i_iw) {
303 dim_t offset;
304 if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw,
305 format_tag::ncdhw)) {
306 offset = static_cast<dim_t>(i_ic) * jcp.id * jcp.ih * jcp.iw + i_iw;
307 } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc,
308 format_tag::ndhwc)) {
309 offset = static_cast<dim_t>(i_iw) * jcp.ic * jcp.ngroups + i_ic;
310 } else {
311 offset = static_cast<dim_t>(i_iw) * jcp.ic_block + i_ic;
312 }
313 return sizeof(float) * offset;
314 }
315
316 inline dim_t get_output_offset(int i_oc_block, int i_ow) {
317 dim_t offset;
318 if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc,
319 format_tag::ndhwc)) {
320 offset = static_cast<dim_t>(i_ow) * jcp.oc * jcp.ngroups
321 + i_oc_block * jcp.oc_block;
322 } else {
323 offset = static_cast<dim_t>(i_oc_block) * jcp.od * jcp.oh * jcp.ow
324 * jcp.oc_block
325 + i_ow * jcp.oc_block;
326 }
327 return sizeof(float) * offset;
328 }
329
330 inline dim_t get_kernel_offset(int ki, int i_ic) {
331 dim_t block_step_size = jcp.ic_block * jcp.oc_block;
332 dim_t offset = static_cast<dim_t>(ki) * block_step_size
333 + i_ic * jcp.oc_block;
334 return sizeof(float) * offset;
335 }
336 void generate() override;
337};
338
339} // namespace x64
340} // namespace cpu
341} // namespace impl
342} // namespace dnnl
343
344#endif
345