1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_COMMON_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_COMMON_CONV_KERNEL_HPP |
19 | |
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | template <typename Vmm> |
33 | struct _jit_avx512_common_conv_fwd_kernel : public jit_generator { |
34 | |
35 | _jit_avx512_common_conv_fwd_kernel(const jit_conv_conf_t &ajcp, |
36 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
37 | |
38 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel) |
39 | |
40 | jit_conv_conf_t jcp; |
41 | const primitive_attr_t &attr_; |
42 | |
43 | private: |
44 | constexpr static int isa_simd_width_ |
45 | = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
46 | using reg64_t = const Xbyak::Reg64; |
47 | enum { |
48 | typesize = sizeof(float), |
49 | ker_reg_base_idx = 28, |
50 | }; |
51 | |
52 | reg64_t param = abi_param1; |
53 | reg64_t reg_inp = r8; |
54 | reg64_t reg_ker = r9; |
55 | reg64_t reg_out = r10; |
56 | |
57 | reg64_t reg_owb = r12; |
58 | |
59 | reg64_t aux_reg_inp = r14; |
60 | reg64_t aux_reg_ker = r15; |
61 | |
62 | reg64_t reg_channel = rsi; |
63 | reg64_t reg_bias = rdx; |
64 | |
65 | reg64_t aux_reg_ker_d = r9; |
66 | reg64_t aux_reg_inp_d = rbx; |
67 | reg64_t reg_ki = r10; |
68 | |
69 | reg64_t reg_kj = rax; |
70 | reg64_t reg_relu_ns = rax; |
71 | reg64_t reg_oi = rbx; |
72 | reg64_t reg_kh = abi_not_param1; |
73 | |
74 | reg64_t reg_tmp = rbp; |
75 | |
76 | reg64_t reg_long_offt = r11; |
77 | reg64_t reg_out_long_offt = r14; |
78 | reg64_t reg_ker_long_offt = r11; |
79 | reg64_t reg_tail = aux_reg_ker; |
80 | reg64_t reg_load_work = reg_tail; |
81 | |
82 | /* binary post-ops operand */ |
83 | reg64_t temp_offset_reg = r12; |
84 | |
85 | Xbyak::Opmask k_oc_tail_mask = Xbyak::Opmask(2); |
86 | const Xbyak::Opmask postops_mask = Xbyak::Opmask(3); |
87 | |
88 | inline Vmm vmm_ker(int i_ic) { |
89 | assert(i_ic < 4); |
90 | return Vmm(ker_reg_base_idx + i_ic); |
91 | } |
92 | |
93 | inline int vmm_out_idx(int i_ur, int i_oc) { |
94 | const int idx = i_ur * jcp.nb_oc_blocking + i_oc; |
95 | assert(idx < ker_reg_base_idx); |
96 | return idx; |
97 | } |
98 | |
99 | inline Vmm vmm_out(int i_ur, int i_oc) { |
100 | return Vmm(vmm_out_idx(i_ur, i_oc)); |
101 | } |
102 | |
103 | inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { |
104 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
105 | assert(idx < 31); |
106 | return Vmm(idx); |
107 | } |
108 | |
109 | Xbyak::Reg64 imm_addr64 = r15; |
110 | Vmm vmm_wei = Vmm(31); |
111 | |
112 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> |
113 | postops_injector_; |
114 | |
115 | inline void prepare_output(int ur_w); |
116 | inline void apply_postops(int ur_w); |
117 | inline void store_output(int ur_w); |
118 | inline void compute_loop_fma(int ur_w, int pad_l, int pad_r); |
119 | inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); |
120 | inline void compute_loop(int ur_w, int pad_l, int pad_r); |
121 | |
122 | void generate() override; |
123 | |
124 | inline size_t get_output_offset(int oi, int n_oc_block) { |
125 | const bool is_nxc_layout = is_dst_layout_nxc(); |
126 | size_t ow_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block; |
127 | size_t ocb_str = is_nxc_layout |
128 | ? jcp.oc_block |
129 | : (size_t)jcp.od * jcp.oh * jcp.ow * jcp.oc_block; |
130 | |
131 | return jcp.typesize_out * (n_oc_block * ocb_str + oi * ow_str); |
132 | } |
133 | |
134 | inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { |
135 | const bool is_nxc_layout = is_src_layout_nxc(); |
136 | size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic |
137 | : (!jcp.is_1stconv ? jcp.ic_block : 1); |
138 | size_t ic_str = !jcp.is_1stconv || is_nxc_layout |
139 | ? 1 |
140 | : (size_t)jcp.iw * jcp.ih * jcp.id; |
141 | size_t iw_idx = ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l; |
142 | |
143 | return jcp.typesize_in * (iw_idx * iw_str + ic * ic_str); |
144 | } |
145 | |
146 | inline int get_kernel_offset( |
147 | int ki, int ic, int n_oc_block, int ker_number) { |
148 | return jcp.typesize_in * jcp.oc_block |
149 | * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw |
150 | * jcp.kd |
151 | + (ic + ker_number) + ki * jcp.ic_block); |
152 | } |
153 | |
154 | inline int get_ow_start(int ki, int pad_l) { |
155 | return nstl::max(0, |
156 | utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); |
157 | } |
158 | |
159 | inline int get_ow_end(int ur_w, int ki, int pad_r) { |
160 | return ur_w |
161 | - nstl::max(0, |
162 | utils::div_up( |
163 | pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1), |
164 | jcp.stride_w)); |
165 | } |
166 | inline bool is_src_layout_nxc() { |
167 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
168 | format_tag::nwc); |
169 | } |
170 | inline bool is_dst_layout_nxc() { |
171 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
172 | format_tag::nwc); |
173 | } |
174 | }; |
175 | |
176 | struct jit_avx512_common_conv_fwd_kernel { |
177 | |
178 | jit_avx512_common_conv_fwd_kernel(const jit_conv_conf_t &ajcp, |
179 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
180 | : kernel_(nullptr) { |
181 | switch (ajcp.oc_block) { |
182 | case 16: |
183 | kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>( |
184 | ajcp, attr, dst_md); |
185 | return; |
186 | case 8: |
187 | kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Ymm>( |
188 | ajcp, attr, dst_md); |
189 | return; |
190 | case 4: |
191 | kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>( |
192 | ajcp, attr, dst_md); |
193 | return; |
194 | default: assert(!"invalid channel blocking" ); |
195 | } |
196 | } |
197 | |
198 | status_t create_kernel() { return kernel_->create_kernel(); } |
199 | |
200 | ~jit_avx512_common_conv_fwd_kernel() { delete kernel_; } |
201 | |
202 | enum { typesize = sizeof(float) }; |
203 | |
204 | static status_t init_conf(jit_conv_conf_t &jcp, |
205 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
206 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
207 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
208 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
209 | const jit_conv_conf_t &jcp); |
210 | |
211 | void operator()(jit_conv_call_s *p) const { (*kernel_)(p); } |
212 | |
213 | const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } |
214 | |
215 | private: |
216 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_common_conv_fwd_kernel); |
217 | jit_generator *kernel_; |
218 | }; |
219 | |
220 | template <typename Vmm> |
221 | struct _jit_avx512_common_conv_bwd_data_kernel_f32 : public jit_generator { |
222 | |
223 | _jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) |
224 | : jit_generator(jit_name()), jcp(ajcp) {} |
225 | |
226 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_bwd_data_kernel_f32) |
227 | jit_conv_conf_t jcp; |
228 | |
229 | private: |
230 | using reg64_t = const Xbyak::Reg64; |
231 | enum { |
232 | typesize = sizeof(float), |
233 | ker_reg_base_idx = 28, |
234 | }; |
235 | |
236 | reg64_t param = abi_param1; |
237 | reg64_t reg_dst = r8; |
238 | reg64_t reg_ker = r9; |
239 | reg64_t reg_src = r10; |
240 | |
241 | reg64_t reg_iwb = r14; |
242 | |
243 | reg64_t aux_reg_dst = r14; |
244 | reg64_t aux_reg_ker = r15; |
245 | |
246 | reg64_t aux_reg_dst_d = rbx; |
247 | reg64_t aux_reg_ker_d = r9; |
248 | reg64_t reg_ki = r10; |
249 | |
250 | reg64_t reg_kj = rax; |
251 | reg64_t reg_oi = rbx; |
252 | reg64_t reg_kh = abi_not_param1; |
253 | |
254 | reg64_t reg_channel = rsi; |
255 | |
256 | reg64_t reg_tmp = rbp; |
257 | reg64_t reg_long_offt = r14; |
258 | |
259 | reg64_t reg_tail = aux_reg_ker; |
260 | reg64_t reg_load_work = reg_tail; |
261 | |
262 | Xbyak::Opmask k_ic_tail_mask = Xbyak::Opmask(1); |
263 | |
264 | inline Vmm vmm_ker(int i_ic) { |
265 | assert(i_ic < 4); |
266 | return Vmm(ker_reg_base_idx + i_ic); |
267 | } |
268 | inline Vmm vmm_inp(int i_ic, int nb_x_blocking) { |
269 | int idx = i_ic + nb_x_blocking * jcp.ur_w; |
270 | assert(idx < 31); |
271 | return Vmm(idx); |
272 | } |
273 | inline Vmm vmm_out(int i_ur, int i_oc) { |
274 | int idx = i_ur + i_oc * jcp.ur_w; |
275 | assert(idx < ker_reg_base_idx); |
276 | return Vmm(idx); |
277 | } |
278 | |
279 | Vmm vmm_wei = Vmm(31); |
280 | |
281 | inline void prepare_output(int ur_w); |
282 | inline void store_output(int ur_w); |
283 | inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); |
284 | inline void compute_loop_fma_core( |
285 | int ur_w, int l_overflow, int r_overflow, int k_offset); |
286 | inline void compute_loop( |
287 | int ur_w, int l_overflow, int r_overflow, int k_offset = 0); |
288 | void generate() override; |
289 | |
290 | inline int get_iw_start(int ki, int l_overflow) { |
291 | int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w |
292 | + l_overflow * jcp.stride_w |
293 | - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); |
294 | while (res < 0) |
295 | res += jcp.stride_w; |
296 | |
297 | return res; |
298 | } |
299 | |
300 | inline int get_iw_end(int ur_w, int ki, int r_overflow) { |
301 | if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) |
302 | ur_w += nstl::min(0, jcp.r_pad); // remove negative padding |
303 | int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w |
304 | + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); |
305 | while (res < 0) |
306 | res += jcp.stride_w; |
307 | |
308 | return ur_w - res; |
309 | } |
310 | |
311 | inline size_t get_diff_src_offset(int iw, int icb) { |
312 | const bool is_nxc_layout = is_dsrc_layout_nxc(); |
313 | size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block; |
314 | size_t icb_str = is_nxc_layout |
315 | ? jcp.ic_block |
316 | : (size_t)jcp.id * jcp.ih * jcp.iw * jcp.ic_block; |
317 | |
318 | return typesize * (icb * icb_str + iw * iw_str); |
319 | } |
320 | |
321 | inline ptrdiff_t get_dst_offset(int iw, int oc, int kw) { |
322 | ptrdiff_t ow |
323 | = (iw + jcp.l_pad - kw * (jcp.dilate_w + 1)) / jcp.stride_w; |
324 | ptrdiff_t ow_str |
325 | = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block; |
326 | |
327 | return typesize * (ow * ow_str + oc); |
328 | }; |
329 | |
330 | inline bool is_dsrc_layout_nxc() { |
331 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
332 | format_tag::nwc); |
333 | } |
334 | inline bool is_ddst_layout_nxc() { |
335 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
336 | format_tag::nwc); |
337 | } |
338 | }; |
339 | |
340 | struct jit_avx512_common_conv_bwd_data_kernel_f32 { |
341 | |
342 | jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) |
343 | : kernel_(nullptr) { |
344 | switch (ajcp.ic_block) { |
345 | case 16: |
346 | kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< |
347 | Xbyak::Zmm>(ajcp); |
348 | return; |
349 | case 8: |
350 | kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< |
351 | Xbyak::Ymm>(ajcp); |
352 | return; |
353 | case 4: |
354 | kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32< |
355 | Xbyak::Xmm>(ajcp); |
356 | return; |
357 | default: assert(!"invalid channel blocking" ); |
358 | } |
359 | } |
360 | |
361 | status_t create_kernel() { return kernel_->create_kernel(); } |
362 | |
363 | ~jit_avx512_common_conv_bwd_data_kernel_f32() { delete kernel_; } |
364 | |
365 | enum { typesize = sizeof(float) }; |
366 | |
367 | static status_t init_conf(jit_conv_conf_t &jcp, |
368 | const convolution_desc_t &cd, memory_desc_t &diff_src_d, |
369 | memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads); |
370 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
371 | const jit_conv_conf_t &jcp); |
372 | |
373 | void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } |
374 | const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } |
375 | |
376 | private: |
377 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_common_conv_bwd_data_kernel_f32); |
378 | jit_generator *kernel_; |
379 | }; |
380 | |
381 | struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator { |
382 | |
383 | jit_avx512_common_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp) |
384 | : jit_generator(jit_name()), jcp(ajcp) {} |
385 | |
386 | void generate() override { |
387 | if (jcp.harness != harness_nxc) |
388 | generate_kernel(); |
389 | else |
390 | generate_microkernel(); |
391 | } |
392 | |
393 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32) |
394 | |
395 | static status_t init_conf(jit_conv_conf_t &jcp, |
396 | const convolution_desc_t &cd, memory_desc_t &src_md, |
397 | memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md, |
398 | memory_desc_t &diff_dst_md, int nthreads); |
399 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
400 | const jit_conv_conf_t &jcp); |
401 | |
402 | jit_conv_conf_t jcp; |
403 | |
404 | private: |
405 | using reg64_t = const Xbyak::Reg64; |
406 | enum { typesize = sizeof(float) }; |
407 | static const int max_ur_w; |
408 | static const int min_oh_reduce; |
409 | |
410 | reg64_t param = abi_param1; |
411 | reg64_t reg_input = rax; |
412 | reg64_t reg_kernel = rdx; |
413 | reg64_t reg_output = rsi; |
414 | reg64_t b_ic = abi_not_param1; |
415 | reg64_t kj = r8; |
416 | reg64_t reg_kh = r9; |
417 | reg64_t reg_ur_w_trips = r10; |
418 | reg64_t reg_oj = r15; |
419 | reg64_t reg_tmp = r14; |
420 | reg64_t reg_long_offt = r14; |
421 | reg64_t reg_icb = rbx; |
422 | |
423 | reg64_t ki = r11; |
424 | reg64_t reg_kd_count = r12; |
425 | reg64_t reg_oi = r12; |
426 | reg64_t reg_d_index = r13; |
427 | reg64_t reg_input_d = r15; |
428 | reg64_t reg_output_d = rbx; |
429 | reg64_t aux_reg_input = r12; |
430 | reg64_t aux_reg_kernel = r13; |
431 | reg64_t reg_bias = rbx; |
432 | reg64_t reg_oc_tail = r14; |
433 | |
434 | Xbyak::Opmask k_oc_mask = Xbyak::Opmask(2); |
435 | |
436 | inline void bias_kernel_2d(); |
437 | inline void bias_kernel_3d(); |
438 | inline void maybe_zero_kernel(); |
439 | inline void compute_oh_step_unroll_ow_icblock( |
440 | int ic_block_step, int max_ur_w); |
441 | inline void od_step_comeback_pointers(); |
442 | inline void oh_step_comeback_pointers(); |
443 | inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); |
444 | inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, |
445 | int ic_block_step, int input_offset, int kernel_offset, |
446 | int output_offset, bool input_wraparound = false); |
447 | inline void compute_ic_block_step_fma(int ur_w, int pad_l, int pad_r, |
448 | int ic_block_step, int input_offset, int kernel_offset, |
449 | int output_offset, bool input_wraparound); |
450 | inline void compute_ic_block_step_fma_expl(int ur_w, int pad_l, int pad_r, |
451 | int ic_block_step, int input_offset, int kernel_offset, |
452 | int output_offset, bool input_wraparound); |
453 | inline void compute_oh_step_common(int ic_block_step, int max_ur_w); |
454 | inline void compute_oh_step_disp(); |
455 | inline void compute_oh_loop_common(); |
456 | inline void compute_oh_loop_partial(); |
457 | inline void compute_od_loop_partial(); |
458 | |
459 | inline void compute_loop(); |
460 | inline bool is_src_layout_nxc() { |
461 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, |
462 | format_tag::nwc); |
463 | } |
464 | inline bool is_ddst_layout_nxc() { |
465 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, |
466 | format_tag::nwc); |
467 | } |
468 | |
469 | inline ptrdiff_t get_full_src_offset( |
470 | int i_iw, int i_ic, ptrdiff_t input_offset) { |
471 | const bool is_nxc_layout = is_src_layout_nxc(); |
472 | const size_t w_shift_st = (jcp.is_hw_transp ? jcp.iw : 1) |
473 | * (jcp.is_1stconv ? 1 : jcp.ic_block); |
474 | ptrdiff_t w_shift = is_nxc_layout ? jcp.ngroups * jcp.ic : w_shift_st; |
475 | ptrdiff_t ic_shift = jcp.is_1stconv && !is_nxc_layout |
476 | ? (ptrdiff_t)jcp.ih * jcp.iw * jcp.id |
477 | : 1; |
478 | |
479 | ptrdiff_t local_input_offset = i_iw * w_shift + i_ic * ic_shift; |
480 | return input_offset + typesize * local_input_offset; |
481 | }; |
482 | |
483 | inline int get_iw_idx(int ow, int kw, int l_pad) { |
484 | return ow * jcp.stride_w + kw * (jcp.dilate_w + 1) - l_pad; |
485 | } |
486 | |
487 | void generate_kernel(); |
488 | void generate_microkernel(); |
489 | |
490 | static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, |
491 | int &nthr_g, int &nthr_oc_b, int &nthr_ic_b, int nthreads); |
492 | }; |
493 | |
494 | } // namespace x64 |
495 | } // namespace cpu |
496 | } // namespace impl |
497 | } // namespace dnnl |
498 | |
499 | #endif |
500 | |