1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21
22#include "cpu/x64/jit_generator.hpp"
23#include "cpu/x64/jit_primitive_conf.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
// Winograd F(4x4, 3x3) geometry (the "4x3" in the kernel names).
// tile_size: side length of the output tile produced by one Winograd tile.
constexpr int tile_size = 4;
// alpha: side length of the transformed tile, i.e. tile_size + (r - 1) with
// r = 3 for the 3x3 filter, which gives 6.
constexpr int alpha = tile_size + 3 - 1;
// simd_w: vector width in fp32 lanes used for vectorization (a 512-bit
// register holds 512 / 32 = 16 floats).
constexpr int simd_w = 512 / 32;
35
36struct _jit_avx512_core_f32_wino_conv_4x3_data_kernel : public jit_generator {
37 _jit_avx512_core_f32_wino_conv_4x3_data_kernel(
38 const jit_conv_winograd_conf_t &ajcp)
39 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, false, avx512_core)
40 , jcp(ajcp) {}
41
42 void generate() override {
43 {
44 const Xbyak::uint8 *addr = getCurr();
45 this->weights_transform_data_ker_generate();
46 weights_transform_data_ker
47 = (decltype(weights_transform_data_ker))addr;
48 register_jit_code(addr, getCurr() - addr);
49 }
50 {
51 align();
52 const Xbyak::uint8 *addr = getCurr();
53 this->input_transform_data_ker_generate();
54 input_transform_data_ker = (decltype(input_transform_data_ker))addr;
55 register_jit_code(addr, getCurr() - addr);
56 }
57 {
58 align();
59 const Xbyak::uint8 *addr = getCurr();
60 this->output_transform_data_ker_generate();
61 output_transform_data_ker
62 = (decltype(output_transform_data_ker))addr;
63 register_jit_code(addr, getCurr() - addr);
64 }
65 {
66 align();
67 const Xbyak::uint8 *addr = getCurr();
68 this->gemm_loop_generate();
69 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
70 register_jit_code(addr, getCurr() - addr);
71 }
72 }
73
74 DECLARE_CPU_JIT_AUX_FUNCTIONS(
75 _jit_avx512_core_f32_wino_conv_4x3_data_kernel)
76
77 static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
78 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
79 const memory_desc_wrapper &weights_d,
80 const memory_desc_wrapper &dst_d);
81
82 static status_t init_conf_kernel(
83 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
84
85 jit_conv_winograd_conf_t jcp;
86 void (*gemm_loop_ker)(float *, const float *, const float *, const int);
87 void (*input_transform_data_ker)(jit_wino_transform_call_s *);
88 void (*output_transform_data_ker)(jit_wino_transform_call_s *);
89 void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
90
91protected:
92 using reg64_t = const Xbyak::Reg64;
93 using reg32_t = const Xbyak::Reg32;
94 enum { typesize = sizeof(float) };
95
96 void gemm_loop_generate();
97 void input_transform_data_ker_generate();
98 void output_transform_data_ker_generate();
99 void weights_transform_data_ker_generate();
100
101 /* registers used for GEMM */
102 reg64_t reg_dstC = abi_param1;
103 reg64_t reg_srcA = abi_param2;
104 reg64_t reg_srcB = abi_param3;
105 reg64_t reg_is_beta_zero = abi_param4;
106
107 reg64_t reg_dimM_block_loop_cnt = r10;
108 reg64_t reg_dimK_block_loop_cnt = r11;
109
110 /* registers used for transforms*/
111 reg64_t param = abi_param1;
112
113 /* registers used for output_transform_data_ker */
114 reg64_t oreg_temp = abi_not_param1;
115 reg64_t oreg_Ow = r9;
116 reg64_t oreg_src = r11;
117 reg64_t oreg_tile_block = r12;
118 reg64_t oreg_tile_block_ur = r13;
119 reg64_t oreg_nb_tile_block_ur = r14;
120 reg64_t oreg_O = r8;
121 reg64_t oreg_T = r10;
122 reg64_t oreg_dst = r11;
123 reg64_t oreg_ydim = r14;
124 reg64_t oreg_xdim = r15;
125 reg64_t oreg_out_j = r12;
126 reg64_t oreg_bias = rbx;
127 reg64_t imm_addr64 = rax;
128
129 /* registers used for input_transform_data_ker */
130 reg64_t ireg_temp = abi_not_param1;
131 reg64_t ireg_jtiles = rax;
132 reg64_t ireg_itiles = rbx;
133 reg64_t ireg_I = r8;
134 reg64_t ireg_src = r13;
135 reg64_t ireg_ydim = r14;
136 reg64_t ireg_xdim = r15;
137 reg64_t ireg_inp_j = r12;
138 reg64_t ireg_inp_i = rdx;
139 reg64_t ireg_mask_j = r11;
140 reg64_t ireg_mask = rsi;
141 reg32_t ireg_mask_32 = esi;
142 reg64_t ireg_zero = r9;
143 reg64_t ireg_Iw = r9;
144 reg64_t ireg_T = r10;
145 reg64_t ireg_tile_block = r12;
146 reg64_t ireg_tile_block_ur = r13;
147 reg64_t ireg_nb_tile_block_ur = r14;
148 reg64_t ireg_output = r15;
149
150 /* registers used for wei transform */
151 reg64_t wreg_temp = abi_not_param1;
152 reg64_t wreg_F = r8;
153 reg64_t wreg_src = r9;
154 reg64_t wreg_MT = r15;
155 reg64_t wreg_M = r14;
156 reg64_t wreg_dst = r10;
157 reg64_t wreg_dst_aux = r9;
158 reg64_t wreg_dst_idx = r8;
159 reg64_t wreg_Fw = r11;
160 reg64_t wreg_T = r12;
161 reg64_t wreg_cnt_j = rdx;
162 reg64_t wreg_F_aux = r14;
163 reg64_t wreg_Fw_aux = r15;
164};
165
166struct jit_avx512_core_f32_wino_conv_4x3_fwd_kernel
167 : _jit_avx512_core_f32_wino_conv_4x3_data_kernel {
168 using _jit_avx512_core_f32_wino_conv_4x3_data_kernel::
169 _jit_avx512_core_f32_wino_conv_4x3_data_kernel;
170
171 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
172
173 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
174 const convolution_desc_t &cd, const memory_desc_t &src_md,
175 memory_desc_t &weights_md, const memory_desc_t &dst_md,
176 const primitive_attr_t &attr);
177};
178
179struct jit_avx512_core_f32_wino_conv_4x3_bwd_data_kernel
180 : public _jit_avx512_core_f32_wino_conv_4x3_data_kernel {
181 using _jit_avx512_core_f32_wino_conv_4x3_data_kernel::
182 _jit_avx512_core_f32_wino_conv_4x3_data_kernel;
183
184 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
185 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
186 const memory_desc_wrapper &weights_d,
187 const memory_desc_wrapper &diff_dst_d);
188};
189
190struct jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel
191 : public jit_generator {
192 DECLARE_CPU_JIT_AUX_FUNCTIONS(
193 _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
194
195 jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel(
196 const jit_conv_winograd_conf_t &ajcp)
197 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, false, avx512_core)
198 , jcp(ajcp) {}
199
200 void generate() override {
201 //******************* First iter kernel ********************//
202 {
203 const Xbyak::uint8 *addr = getCurr();
204 this->gemm_loop_generate(true);
205 gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
206 register_jit_code(addr, getCurr() - addr);
207 }
208 {
209 align();
210 const Xbyak::uint8 *addr = getCurr();
211 this->src_transform_generate();
212 src_transform = (decltype(src_transform))addr;
213 register_jit_code(addr, getCurr() - addr);
214 }
215 if (jcp.with_bias) {
216 align();
217 const Xbyak::uint8 *addr = getCurr();
218 this->diff_dst_transform_generate(true);
219 diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
220 register_jit_code(addr, getCurr() - addr);
221 }
222 {
223 align();
224 const Xbyak::uint8 *addr = getCurr();
225 this->diff_dst_transform_generate(false);
226 diff_dst_transform = (decltype(diff_dst_transform))addr;
227 register_jit_code(addr, getCurr() - addr);
228 }
229 if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
230 align();
231 const Xbyak::uint8 *addr = getCurr();
232 this->gemm_loop_generate(false);
233 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
234 register_jit_code(addr, getCurr() - addr);
235 }
236 {
237 align();
238 const Xbyak::uint8 *addr = getCurr();
239 this->diff_weights_transform_generate(true);
240 diff_weights_transform = (decltype(diff_weights_transform))addr;
241 register_jit_code(addr, getCurr() - addr);
242 }
243 if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
244 align();
245 const Xbyak::uint8 *addr = getCurr();
246 this->diff_weights_transform_generate(false);
247 diff_weights_transform_accum
248 = (decltype(diff_weights_transform_accum))addr;
249 register_jit_code(addr, getCurr() - addr);
250 }
251 }
252
253 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
254 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
255 const memory_desc_wrapper &diff_dst_d,
256 const memory_desc_wrapper &diff_weights_d);
257
258 jit_conv_winograd_conf_t jcp;
259 void (*gemm_loop_ker)(float *, const float *, const float *);
260 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
261 void (*src_transform)(jit_wino_transform_call_s *);
262 void (*diff_dst_transform)(jit_wino_transform_call_s *);
263 void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
264 void (*diff_weights_transform)(jit_wino_transform_call_s *);
265 void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
266
267private:
268 using reg64_t = const Xbyak::Reg64;
269 using reg32_t = const Xbyak::Reg32;
270 enum { typesize = sizeof(float) };
271
272 void src_transform_generate();
273 void diff_dst_transform_generate(bool with_bias);
274 void diff_weights_transform_generate(bool first_tile);
275
276 /*registers common to transforms*/
277 reg64_t reg_transp = abi_param1;
278 reg64_t reg_ti = rbx;
279 reg64_t reg_tj = abi_not_param1;
280 reg64_t reg_src = r8;
281 reg64_t reg_dst = r9;
282 reg64_t reg_G = rsi; /*TODO: check if this is ok*/
283 reg64_t reg_temp = rsi;
284
285 /*registers common to src/diff_dst transform*/
286 reg64_t reg_I = r10;
287 reg64_t reg_ydim = r11;
288 reg64_t reg_xdim = r12;
289 reg64_t reg_src_offset = r13;
290 reg64_t reg_zero = r14;
291 reg64_t reg_tile_count = r15;
292 reg64_t reg_maski = rsi;
293 reg32_t reg_maski_32 = esi;
294 reg64_t reg_maskj = rdx;
295
296 reg64_t reg_T = rax;
297 reg64_t reg_oc_ur = rax;
298 reg64_t reg_ic_simd = r14;
299 reg64_t reg_bias = r10;
300
301 void gemm_loop_generate(bool is_first_tile);
302
303 reg64_t reg_dstC = abi_param1;
304 reg64_t reg_srcA = abi_param2;
305 reg64_t reg_srcB = abi_param3;
306
307 reg64_t reg_dimM_block_loop_cnt = r9;
308 reg64_t reg_dimN_block_loop_cnt = r10;
309 reg64_t reg_nb_dimN_bcast_ur = r11;
310 reg64_t reg_dimK_block_loop_cnt = r12;
311};
312} // namespace x64
313} // namespace cpu
314} // namespace impl
315} // namespace dnnl
316
317#endif
318