1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | #include "cpu/x64/jit_primitive_conf.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
// Edge length of the output tile handled per Winograd tile.
constexpr int tile_size = 4;
// alpha determines the output tile_size.
// NOTE(review): presumably alpha == tile_size + 3 - 1 for the 4x3 variant
// this header is named after -- confirm before deriving one from the other.
constexpr int alpha = 6;
// SIMD length used for vectorization.
constexpr int simd_w = 16;
35 | |
36 | struct _jit_avx512_core_f32_wino_conv_4x3_data_kernel : public jit_generator { |
37 | _jit_avx512_core_f32_wino_conv_4x3_data_kernel( |
38 | const jit_conv_winograd_conf_t &ajcp) |
39 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, false, avx512_core) |
40 | , jcp(ajcp) {} |
41 | |
42 | void generate() override { |
43 | { |
44 | const Xbyak::uint8 *addr = getCurr(); |
45 | this->weights_transform_data_ker_generate(); |
46 | weights_transform_data_ker |
47 | = (decltype(weights_transform_data_ker))addr; |
48 | register_jit_code(addr, getCurr() - addr); |
49 | } |
50 | { |
51 | align(); |
52 | const Xbyak::uint8 *addr = getCurr(); |
53 | this->input_transform_data_ker_generate(); |
54 | input_transform_data_ker = (decltype(input_transform_data_ker))addr; |
55 | register_jit_code(addr, getCurr() - addr); |
56 | } |
57 | { |
58 | align(); |
59 | const Xbyak::uint8 *addr = getCurr(); |
60 | this->output_transform_data_ker_generate(); |
61 | output_transform_data_ker |
62 | = (decltype(output_transform_data_ker))addr; |
63 | register_jit_code(addr, getCurr() - addr); |
64 | } |
65 | { |
66 | align(); |
67 | const Xbyak::uint8 *addr = getCurr(); |
68 | this->gemm_loop_generate(); |
69 | gemm_loop_ker = (decltype(gemm_loop_ker))addr; |
70 | register_jit_code(addr, getCurr() - addr); |
71 | } |
72 | } |
73 | |
74 | DECLARE_CPU_JIT_AUX_FUNCTIONS( |
75 | _jit_avx512_core_f32_wino_conv_4x3_data_kernel) |
76 | |
77 | static status_t init_conf_common(jit_conv_winograd_conf_t &jcp, |
78 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
79 | const memory_desc_wrapper &weights_d, |
80 | const memory_desc_wrapper &dst_d); |
81 | |
82 | static status_t init_conf_kernel( |
83 | jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK); |
84 | |
85 | jit_conv_winograd_conf_t jcp; |
86 | void (*gemm_loop_ker)(float *, const float *, const float *, const int); |
87 | void (*input_transform_data_ker)(jit_wino_transform_call_s *); |
88 | void (*output_transform_data_ker)(jit_wino_transform_call_s *); |
89 | void (*weights_transform_data_ker)(jit_wino_transform_call_s *); |
90 | |
91 | protected: |
92 | using reg64_t = const Xbyak::Reg64; |
93 | using reg32_t = const Xbyak::Reg32; |
94 | enum { typesize = sizeof(float) }; |
95 | |
96 | void gemm_loop_generate(); |
97 | void input_transform_data_ker_generate(); |
98 | void output_transform_data_ker_generate(); |
99 | void weights_transform_data_ker_generate(); |
100 | |
101 | /* registers used for GEMM */ |
102 | reg64_t reg_dstC = abi_param1; |
103 | reg64_t reg_srcA = abi_param2; |
104 | reg64_t reg_srcB = abi_param3; |
105 | reg64_t reg_is_beta_zero = abi_param4; |
106 | |
107 | reg64_t reg_dimM_block_loop_cnt = r10; |
108 | reg64_t reg_dimK_block_loop_cnt = r11; |
109 | |
110 | /* registers used for transforms*/ |
111 | reg64_t param = abi_param1; |
112 | |
113 | /* registers used for output_transform_data_ker */ |
114 | reg64_t oreg_temp = abi_not_param1; |
115 | reg64_t oreg_Ow = r9; |
116 | reg64_t oreg_src = r11; |
117 | reg64_t oreg_tile_block = r12; |
118 | reg64_t oreg_tile_block_ur = r13; |
119 | reg64_t oreg_nb_tile_block_ur = r14; |
120 | reg64_t oreg_O = r8; |
121 | reg64_t oreg_T = r10; |
122 | reg64_t oreg_dst = r11; |
123 | reg64_t oreg_ydim = r14; |
124 | reg64_t oreg_xdim = r15; |
125 | reg64_t oreg_out_j = r12; |
126 | reg64_t oreg_bias = rbx; |
127 | reg64_t imm_addr64 = rax; |
128 | |
129 | /* registers used for input_transform_data_ker */ |
130 | reg64_t ireg_temp = abi_not_param1; |
131 | reg64_t ireg_jtiles = rax; |
132 | reg64_t ireg_itiles = rbx; |
133 | reg64_t ireg_I = r8; |
134 | reg64_t ireg_src = r13; |
135 | reg64_t ireg_ydim = r14; |
136 | reg64_t ireg_xdim = r15; |
137 | reg64_t ireg_inp_j = r12; |
138 | reg64_t ireg_inp_i = rdx; |
139 | reg64_t ireg_mask_j = r11; |
140 | reg64_t ireg_mask = rsi; |
141 | reg32_t ireg_mask_32 = esi; |
142 | reg64_t ireg_zero = r9; |
143 | reg64_t ireg_Iw = r9; |
144 | reg64_t ireg_T = r10; |
145 | reg64_t ireg_tile_block = r12; |
146 | reg64_t ireg_tile_block_ur = r13; |
147 | reg64_t ireg_nb_tile_block_ur = r14; |
148 | reg64_t ireg_output = r15; |
149 | |
150 | /* registers used for wei transform */ |
151 | reg64_t wreg_temp = abi_not_param1; |
152 | reg64_t wreg_F = r8; |
153 | reg64_t wreg_src = r9; |
154 | reg64_t wreg_MT = r15; |
155 | reg64_t wreg_M = r14; |
156 | reg64_t wreg_dst = r10; |
157 | reg64_t wreg_dst_aux = r9; |
158 | reg64_t wreg_dst_idx = r8; |
159 | reg64_t wreg_Fw = r11; |
160 | reg64_t wreg_T = r12; |
161 | reg64_t wreg_cnt_j = rdx; |
162 | reg64_t wreg_F_aux = r14; |
163 | reg64_t wreg_Fw_aux = r15; |
164 | }; |
165 | |
166 | struct jit_avx512_core_f32_wino_conv_4x3_fwd_kernel |
167 | : _jit_avx512_core_f32_wino_conv_4x3_data_kernel { |
168 | using _jit_avx512_core_f32_wino_conv_4x3_data_kernel:: |
169 | _jit_avx512_core_f32_wino_conv_4x3_data_kernel; |
170 | |
171 | static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
172 | |
173 | static status_t init_conf(jit_conv_winograd_conf_t &jcp, |
174 | const convolution_desc_t &cd, const memory_desc_t &src_md, |
175 | memory_desc_t &weights_md, const memory_desc_t &dst_md, |
176 | const primitive_attr_t &attr); |
177 | }; |
178 | |
179 | struct jit_avx512_core_f32_wino_conv_4x3_bwd_data_kernel |
180 | : public _jit_avx512_core_f32_wino_conv_4x3_data_kernel { |
181 | using _jit_avx512_core_f32_wino_conv_4x3_data_kernel:: |
182 | _jit_avx512_core_f32_wino_conv_4x3_data_kernel; |
183 | |
184 | static status_t init_conf(jit_conv_winograd_conf_t &jcp, |
185 | const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d, |
186 | const memory_desc_wrapper &weights_d, |
187 | const memory_desc_wrapper &diff_dst_d); |
188 | }; |
189 | |
190 | struct jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel |
191 | : public jit_generator { |
192 | DECLARE_CPU_JIT_AUX_FUNCTIONS( |
193 | _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32) |
194 | |
195 | jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel( |
196 | const jit_conv_winograd_conf_t &ajcp) |
197 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, false, avx512_core) |
198 | , jcp(ajcp) {} |
199 | |
200 | void generate() override { |
201 | //******************* First iter kernel ********************// |
202 | { |
203 | const Xbyak::uint8 *addr = getCurr(); |
204 | this->gemm_loop_generate(true); |
205 | gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr; |
206 | register_jit_code(addr, getCurr() - addr); |
207 | } |
208 | { |
209 | align(); |
210 | const Xbyak::uint8 *addr = getCurr(); |
211 | this->src_transform_generate(); |
212 | src_transform = (decltype(src_transform))addr; |
213 | register_jit_code(addr, getCurr() - addr); |
214 | } |
215 | if (jcp.with_bias) { |
216 | align(); |
217 | const Xbyak::uint8 *addr = getCurr(); |
218 | this->diff_dst_transform_generate(true); |
219 | diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr; |
220 | register_jit_code(addr, getCurr() - addr); |
221 | } |
222 | { |
223 | align(); |
224 | const Xbyak::uint8 *addr = getCurr(); |
225 | this->diff_dst_transform_generate(false); |
226 | diff_dst_transform = (decltype(diff_dst_transform))addr; |
227 | register_jit_code(addr, getCurr() - addr); |
228 | } |
229 | if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) { |
230 | align(); |
231 | const Xbyak::uint8 *addr = getCurr(); |
232 | this->gemm_loop_generate(false); |
233 | gemm_loop_ker = (decltype(gemm_loop_ker))addr; |
234 | register_jit_code(addr, getCurr() - addr); |
235 | } |
236 | { |
237 | align(); |
238 | const Xbyak::uint8 *addr = getCurr(); |
239 | this->diff_weights_transform_generate(true); |
240 | diff_weights_transform = (decltype(diff_weights_transform))addr; |
241 | register_jit_code(addr, getCurr() - addr); |
242 | } |
243 | if (jcp.sched_policy == WSCHED_WEI_SDGtWo) { |
244 | align(); |
245 | const Xbyak::uint8 *addr = getCurr(); |
246 | this->diff_weights_transform_generate(false); |
247 | diff_weights_transform_accum |
248 | = (decltype(diff_weights_transform_accum))addr; |
249 | register_jit_code(addr, getCurr() - addr); |
250 | } |
251 | } |
252 | |
253 | static status_t init_conf(jit_conv_winograd_conf_t &jcp, |
254 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
255 | const memory_desc_wrapper &diff_dst_d, |
256 | const memory_desc_wrapper &diff_weights_d); |
257 | |
258 | jit_conv_winograd_conf_t jcp; |
259 | void (*gemm_loop_ker)(float *, const float *, const float *); |
260 | void (*gemm_loop_ker_first_iter)(float *, const float *, const float *); |
261 | void (*src_transform)(jit_wino_transform_call_s *); |
262 | void (*diff_dst_transform)(jit_wino_transform_call_s *); |
263 | void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *); |
264 | void (*diff_weights_transform)(jit_wino_transform_call_s *); |
265 | void (*diff_weights_transform_accum)(jit_wino_transform_call_s *); |
266 | |
267 | private: |
268 | using reg64_t = const Xbyak::Reg64; |
269 | using reg32_t = const Xbyak::Reg32; |
270 | enum { typesize = sizeof(float) }; |
271 | |
272 | void src_transform_generate(); |
273 | void diff_dst_transform_generate(bool with_bias); |
274 | void diff_weights_transform_generate(bool first_tile); |
275 | |
276 | /*registers common to transforms*/ |
277 | reg64_t reg_transp = abi_param1; |
278 | reg64_t reg_ti = rbx; |
279 | reg64_t reg_tj = abi_not_param1; |
280 | reg64_t reg_src = r8; |
281 | reg64_t reg_dst = r9; |
282 | reg64_t reg_G = rsi; /*TODO: check if this is ok*/ |
283 | reg64_t reg_temp = rsi; |
284 | |
285 | /*registers common to src/diff_dst transform*/ |
286 | reg64_t reg_I = r10; |
287 | reg64_t reg_ydim = r11; |
288 | reg64_t reg_xdim = r12; |
289 | reg64_t reg_src_offset = r13; |
290 | reg64_t reg_zero = r14; |
291 | reg64_t reg_tile_count = r15; |
292 | reg64_t reg_maski = rsi; |
293 | reg32_t reg_maski_32 = esi; |
294 | reg64_t reg_maskj = rdx; |
295 | |
296 | reg64_t reg_T = rax; |
297 | reg64_t reg_oc_ur = rax; |
298 | reg64_t reg_ic_simd = r14; |
299 | reg64_t reg_bias = r10; |
300 | |
301 | void gemm_loop_generate(bool is_first_tile); |
302 | |
303 | reg64_t reg_dstC = abi_param1; |
304 | reg64_t reg_srcA = abi_param2; |
305 | reg64_t reg_srcB = abi_param3; |
306 | |
307 | reg64_t reg_dimM_block_loop_cnt = r9; |
308 | reg64_t reg_dimN_block_loop_cnt = r10; |
309 | reg64_t reg_nb_dimN_bcast_ur = r11; |
310 | reg64_t reg_dimK_block_loop_cnt = r12; |
311 | }; |
312 | } // namespace x64 |
313 | } // namespace cpu |
314 | } // namespace impl |
315 | } // namespace dnnl |
316 | |
317 | #endif |
318 | |