1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_COMMON_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_COMMON_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32template <typename Vmm>
33struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
34
35 _jit_avx512_common_conv_fwd_kernel(const jit_conv_conf_t &ajcp,
36 const primitive_attr_t &attr, const memory_desc_t &dst_md);
37
38 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
39
40 jit_conv_conf_t jcp;
41 const primitive_attr_t &attr_;
42
43private:
44 constexpr static int isa_simd_width_
45 = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
46 using reg64_t = const Xbyak::Reg64;
47 enum {
48 typesize = sizeof(float),
49 ker_reg_base_idx = 28,
50 };
51
52 reg64_t param = abi_param1;
53 reg64_t reg_inp = r8;
54 reg64_t reg_ker = r9;
55 reg64_t reg_out = r10;
56
57 reg64_t reg_owb = r12;
58
59 reg64_t aux_reg_inp = r14;
60 reg64_t aux_reg_ker = r15;
61
62 reg64_t reg_channel = rsi;
63 reg64_t reg_bias = rdx;
64
65 reg64_t aux_reg_ker_d = r9;
66 reg64_t aux_reg_inp_d = rbx;
67 reg64_t reg_ki = r10;
68
69 reg64_t reg_kj = rax;
70 reg64_t reg_relu_ns = rax;
71 reg64_t reg_oi = rbx;
72 reg64_t reg_kh = abi_not_param1;
73
74 reg64_t reg_tmp = rbp;
75
76 reg64_t reg_long_offt = r11;
77 reg64_t reg_out_long_offt = r14;
78 reg64_t reg_ker_long_offt = r11;
79 reg64_t reg_tail = aux_reg_ker;
80 reg64_t reg_load_work = reg_tail;
81
82 /* binary post-ops operand */
83 reg64_t temp_offset_reg = r12;
84
85 Xbyak::Opmask k_oc_tail_mask = Xbyak::Opmask(2);
86 const Xbyak::Opmask postops_mask = Xbyak::Opmask(3);
87
88 inline Vmm vmm_ker(int i_ic) {
89 assert(i_ic < 4);
90 return Vmm(ker_reg_base_idx + i_ic);
91 }
92
93 inline int vmm_out_idx(int i_ur, int i_oc) {
94 const int idx = i_ur * jcp.nb_oc_blocking + i_oc;
95 assert(idx < ker_reg_base_idx);
96 return idx;
97 }
98
99 inline Vmm vmm_out(int i_ur, int i_oc) {
100 return Vmm(vmm_out_idx(i_ur, i_oc));
101 }
102
103 inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
104 int idx = i_ic + nb_x_blocking * jcp.ur_w;
105 assert(idx < 31);
106 return Vmm(idx);
107 }
108
109 Xbyak::Reg64 imm_addr64 = r15;
110 Vmm vmm_wei = Vmm(31);
111
112 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
113 postops_injector_;
114
115 inline void prepare_output(int ur_w);
116 inline void apply_postops(int ur_w);
117 inline void store_output(int ur_w);
118 inline void compute_loop_fma(int ur_w, int pad_l, int pad_r);
119 inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
120 inline void compute_loop(int ur_w, int pad_l, int pad_r);
121
122 void generate() override;
123
124 inline size_t get_output_offset(int oi, int n_oc_block) {
125 const bool is_nxc_layout = is_dst_layout_nxc();
126 size_t ow_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block;
127 size_t ocb_str = is_nxc_layout
128 ? jcp.oc_block
129 : (size_t)jcp.od * jcp.oh * jcp.ow * jcp.oc_block;
130
131 return jcp.typesize_out * (n_oc_block * ocb_str + oi * ow_str);
132 }
133
134 inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
135 const bool is_nxc_layout = is_src_layout_nxc();
136 size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic
137 : (!jcp.is_1stconv ? jcp.ic_block : 1);
138 size_t ic_str = !jcp.is_1stconv || is_nxc_layout
139 ? 1
140 : (size_t)jcp.iw * jcp.ih * jcp.id;
141 size_t iw_idx = ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l;
142
143 return jcp.typesize_in * (iw_idx * iw_str + ic * ic_str);
144 }
145
146 inline int get_kernel_offset(
147 int ki, int ic, int n_oc_block, int ker_number) {
148 return jcp.typesize_in * jcp.oc_block
149 * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw
150 * jcp.kd
151 + (ic + ker_number) + ki * jcp.ic_block);
152 }
153
154 inline int get_ow_start(int ki, int pad_l) {
155 return nstl::max(0,
156 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
157 }
158
159 inline int get_ow_end(int ur_w, int ki, int pad_r) {
160 return ur_w
161 - nstl::max(0,
162 utils::div_up(
163 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
164 jcp.stride_w));
165 }
166 inline bool is_src_layout_nxc() {
167 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
168 format_tag::nwc);
169 }
170 inline bool is_dst_layout_nxc() {
171 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
172 format_tag::nwc);
173 }
174};
175
176struct jit_avx512_common_conv_fwd_kernel {
177
178 jit_avx512_common_conv_fwd_kernel(const jit_conv_conf_t &ajcp,
179 const primitive_attr_t &attr, const memory_desc_t &dst_md)
180 : kernel_(nullptr) {
181 switch (ajcp.oc_block) {
182 case 16:
183 kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
184 ajcp, attr, dst_md);
185 return;
186 case 8:
187 kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Ymm>(
188 ajcp, attr, dst_md);
189 return;
190 case 4:
191 kernel_ = new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
192 ajcp, attr, dst_md);
193 return;
194 default: assert(!"invalid channel blocking");
195 }
196 }
197
198 status_t create_kernel() { return kernel_->create_kernel(); }
199
200 ~jit_avx512_common_conv_fwd_kernel() { delete kernel_; }
201
202 enum { typesize = sizeof(float) };
203
204 static status_t init_conf(jit_conv_conf_t &jcp,
205 const convolution_desc_t &cd, memory_desc_t &src_pd,
206 memory_desc_t &weights_pd, memory_desc_t &dst_pd,
207 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
208 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
209 const jit_conv_conf_t &jcp);
210
211 void operator()(jit_conv_call_s *p) const { (*kernel_)(p); }
212
213 const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }
214
215private:
216 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_common_conv_fwd_kernel);
217 jit_generator *kernel_;
218};
219
220template <typename Vmm>
221struct _jit_avx512_common_conv_bwd_data_kernel_f32 : public jit_generator {
222
223 _jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
224 : jit_generator(jit_name()), jcp(ajcp) {}
225
226 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_bwd_data_kernel_f32)
227 jit_conv_conf_t jcp;
228
229private:
230 using reg64_t = const Xbyak::Reg64;
231 enum {
232 typesize = sizeof(float),
233 ker_reg_base_idx = 28,
234 };
235
236 reg64_t param = abi_param1;
237 reg64_t reg_dst = r8;
238 reg64_t reg_ker = r9;
239 reg64_t reg_src = r10;
240
241 reg64_t reg_iwb = r14;
242
243 reg64_t aux_reg_dst = r14;
244 reg64_t aux_reg_ker = r15;
245
246 reg64_t aux_reg_dst_d = rbx;
247 reg64_t aux_reg_ker_d = r9;
248 reg64_t reg_ki = r10;
249
250 reg64_t reg_kj = rax;
251 reg64_t reg_oi = rbx;
252 reg64_t reg_kh = abi_not_param1;
253
254 reg64_t reg_channel = rsi;
255
256 reg64_t reg_tmp = rbp;
257 reg64_t reg_long_offt = r14;
258
259 reg64_t reg_tail = aux_reg_ker;
260 reg64_t reg_load_work = reg_tail;
261
262 Xbyak::Opmask k_ic_tail_mask = Xbyak::Opmask(1);
263
264 inline Vmm vmm_ker(int i_ic) {
265 assert(i_ic < 4);
266 return Vmm(ker_reg_base_idx + i_ic);
267 }
268 inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
269 int idx = i_ic + nb_x_blocking * jcp.ur_w;
270 assert(idx < 31);
271 return Vmm(idx);
272 }
273 inline Vmm vmm_out(int i_ur, int i_oc) {
274 int idx = i_ur + i_oc * jcp.ur_w;
275 assert(idx < ker_reg_base_idx);
276 return Vmm(idx);
277 }
278
279 Vmm vmm_wei = Vmm(31);
280
281 inline void prepare_output(int ur_w);
282 inline void store_output(int ur_w);
283 inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow);
284 inline void compute_loop_fma_core(
285 int ur_w, int l_overflow, int r_overflow, int k_offset);
286 inline void compute_loop(
287 int ur_w, int l_overflow, int r_overflow, int k_offset = 0);
288 void generate() override;
289
290 inline int get_iw_start(int ki, int l_overflow) {
291 int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
292 + l_overflow * jcp.stride_w
293 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
294 while (res < 0)
295 res += jcp.stride_w;
296
297 return res;
298 }
299
300 inline int get_iw_end(int ur_w, int ki, int r_overflow) {
301 if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
302 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
303 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
304 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
305 while (res < 0)
306 res += jcp.stride_w;
307
308 return ur_w - res;
309 }
310
311 inline size_t get_diff_src_offset(int iw, int icb) {
312 const bool is_nxc_layout = is_dsrc_layout_nxc();
313 size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block;
314 size_t icb_str = is_nxc_layout
315 ? jcp.ic_block
316 : (size_t)jcp.id * jcp.ih * jcp.iw * jcp.ic_block;
317
318 return typesize * (icb * icb_str + iw * iw_str);
319 }
320
321 inline ptrdiff_t get_dst_offset(int iw, int oc, int kw) {
322 ptrdiff_t ow
323 = (iw + jcp.l_pad - kw * (jcp.dilate_w + 1)) / jcp.stride_w;
324 ptrdiff_t ow_str
325 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
326
327 return typesize * (ow * ow_str + oc);
328 };
329
330 inline bool is_dsrc_layout_nxc() {
331 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
332 format_tag::nwc);
333 }
334 inline bool is_ddst_layout_nxc() {
335 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
336 format_tag::nwc);
337 }
338};
339
340struct jit_avx512_common_conv_bwd_data_kernel_f32 {
341
342 jit_avx512_common_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
343 : kernel_(nullptr) {
344 switch (ajcp.ic_block) {
345 case 16:
346 kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32<
347 Xbyak::Zmm>(ajcp);
348 return;
349 case 8:
350 kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32<
351 Xbyak::Ymm>(ajcp);
352 return;
353 case 4:
354 kernel_ = new _jit_avx512_common_conv_bwd_data_kernel_f32<
355 Xbyak::Xmm>(ajcp);
356 return;
357 default: assert(!"invalid channel blocking");
358 }
359 }
360
361 status_t create_kernel() { return kernel_->create_kernel(); }
362
363 ~jit_avx512_common_conv_bwd_data_kernel_f32() { delete kernel_; }
364
365 enum { typesize = sizeof(float) };
366
367 static status_t init_conf(jit_conv_conf_t &jcp,
368 const convolution_desc_t &cd, memory_desc_t &diff_src_d,
369 memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads);
370 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
371 const jit_conv_conf_t &jcp);
372
373 void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
374 const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }
375
376private:
377 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_common_conv_bwd_data_kernel_f32);
378 jit_generator *kernel_;
379};
380
381struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator {
382
383 jit_avx512_common_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp)
384 : jit_generator(jit_name()), jcp(ajcp) {}
385
386 void generate() override {
387 if (jcp.harness != harness_nxc)
388 generate_kernel();
389 else
390 generate_microkernel();
391 }
392
393 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32)
394
395 static status_t init_conf(jit_conv_conf_t &jcp,
396 const convolution_desc_t &cd, memory_desc_t &src_md,
397 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
398 memory_desc_t &diff_dst_md, int nthreads);
399 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
400 const jit_conv_conf_t &jcp);
401
402 jit_conv_conf_t jcp;
403
404private:
405 using reg64_t = const Xbyak::Reg64;
406 enum { typesize = sizeof(float) };
407 static const int max_ur_w;
408 static const int min_oh_reduce;
409
410 reg64_t param = abi_param1;
411 reg64_t reg_input = rax;
412 reg64_t reg_kernel = rdx;
413 reg64_t reg_output = rsi;
414 reg64_t b_ic = abi_not_param1;
415 reg64_t kj = r8;
416 reg64_t reg_kh = r9;
417 reg64_t reg_ur_w_trips = r10;
418 reg64_t reg_oj = r15;
419 reg64_t reg_tmp = r14;
420 reg64_t reg_long_offt = r14;
421 reg64_t reg_icb = rbx;
422
423 reg64_t ki = r11;
424 reg64_t reg_kd_count = r12;
425 reg64_t reg_oi = r12;
426 reg64_t reg_d_index = r13;
427 reg64_t reg_input_d = r15;
428 reg64_t reg_output_d = rbx;
429 reg64_t aux_reg_input = r12;
430 reg64_t aux_reg_kernel = r13;
431 reg64_t reg_bias = rbx;
432 reg64_t reg_oc_tail = r14;
433
434 Xbyak::Opmask k_oc_mask = Xbyak::Opmask(2);
435
436 inline void bias_kernel_2d();
437 inline void bias_kernel_3d();
438 inline void maybe_zero_kernel();
439 inline void compute_oh_step_unroll_ow_icblock(
440 int ic_block_step, int max_ur_w);
441 inline void od_step_comeback_pointers();
442 inline void oh_step_comeback_pointers();
443 inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
444 inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
445 int ic_block_step, int input_offset, int kernel_offset,
446 int output_offset, bool input_wraparound = false);
447 inline void compute_ic_block_step_fma(int ur_w, int pad_l, int pad_r,
448 int ic_block_step, int input_offset, int kernel_offset,
449 int output_offset, bool input_wraparound);
450 inline void compute_ic_block_step_fma_expl(int ur_w, int pad_l, int pad_r,
451 int ic_block_step, int input_offset, int kernel_offset,
452 int output_offset, bool input_wraparound);
453 inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
454 inline void compute_oh_step_disp();
455 inline void compute_oh_loop_common();
456 inline void compute_oh_loop_partial();
457 inline void compute_od_loop_partial();
458
459 inline void compute_loop();
460 inline bool is_src_layout_nxc() {
461 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
462 format_tag::nwc);
463 }
464 inline bool is_ddst_layout_nxc() {
465 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
466 format_tag::nwc);
467 }
468
469 inline ptrdiff_t get_full_src_offset(
470 int i_iw, int i_ic, ptrdiff_t input_offset) {
471 const bool is_nxc_layout = is_src_layout_nxc();
472 const size_t w_shift_st = (jcp.is_hw_transp ? jcp.iw : 1)
473 * (jcp.is_1stconv ? 1 : jcp.ic_block);
474 ptrdiff_t w_shift = is_nxc_layout ? jcp.ngroups * jcp.ic : w_shift_st;
475 ptrdiff_t ic_shift = jcp.is_1stconv && !is_nxc_layout
476 ? (ptrdiff_t)jcp.ih * jcp.iw * jcp.id
477 : 1;
478
479 ptrdiff_t local_input_offset = i_iw * w_shift + i_ic * ic_shift;
480 return input_offset + typesize * local_input_offset;
481 };
482
483 inline int get_iw_idx(int ow, int kw, int l_pad) {
484 return ow * jcp.stride_w + kw * (jcp.dilate_w + 1) - l_pad;
485 }
486
487 void generate_kernel();
488 void generate_microkernel();
489
490 static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
491 int &nthr_g, int &nthr_oc_b, int &nthr_ic_b, int nthreads);
492};
493
494} // namespace x64
495} // namespace cpu
496} // namespace impl
497} // namespace dnnl
498
499#endif
500