1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_AMX_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_AMX_CONV_KERNEL_HPP
19
#include <memory>
#include <queue>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35/* This struct computes the compensation for src_zero_point related to
36 * padding */
37struct jit_avx512_core_amx_compute_zp_pbuff_t : public jit_generator {
38
39 using reg64_t = const Xbyak::Reg64;
40
41 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_compute_zp_pbuff_t)
42
43 jit_avx512_core_amx_compute_zp_pbuff_t(const jit_conv_conf_t &ajcp)
44 : jit_generator(
45 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
46 , jcp(ajcp) {}
47
48 static const int max_regs_ur = 30;
49
50private:
51 jit_conv_conf_t jcp;
52
53 typedef enum { no_last_block, last_ic_block } ic_block_t;
54 const int ic_inner_block = 4;
55
56 Xbyak::Label permb_idx_label;
57 Xbyak::Label ic_mask_label;
58
59 const reg64_t reg_zp_pbuff = r8;
60 const reg64_t reg_src_zero_point = r9;
61 const reg64_t reg_filt = r10;
62 const reg64_t aux_reg_filt = r11;
63 const reg64_t aux_reg_filt_d = r15;
64
65 const reg64_t reg_oc_blocks = r12;
66 const reg64_t reg_icb = r13;
67 const reg64_t reg_oi = r14;
68 const reg64_t reg_kj = rax;
69 const reg64_t reg_ki = rbx;
70 const reg64_t reg_overflow = reg_kj;
71 const reg64_t reg_scratch = rsi;
72
73 const Xbyak::Zmm zmm_one = Xbyak::Zmm(31);
74 const Xbyak::Zmm zmm_permb = Xbyak::Zmm(30);
75
76 const Xbyak::Opmask kmask_ic_block = Xbyak::Opmask(1);
77 const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
78
79 void prepare_output(int ur_w);
80 void store_output(int ur_w, bool last_oc_block_flag);
81 void compute_ker(int ur_w, int pad_l, int pad_r,
82 ic_block_t last_ic_block_flag, bool padded);
83 void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag,
84 bool handle_h_pad);
85 void kd_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag,
86 bool handle_h_pad);
87 void icb_loop(int ur_w, int pad_l, int pad_r, bool handle_h_pad);
88 void unroll_width(const bool h_padding);
89
90 void generate() override;
91
92 Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
93 int idx = i_ur * jcp.nb_oc_blocking + i_oc;
94 assert(idx < max_regs_ur);
95 return Xbyak::Zmm(idx);
96 }
97 int get_ow_start(int ki, int pad_l) {
98 return nstl::max(0,
99 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
100 }
101 int get_ow_end(int ur_w, int ki, int pad_r) {
102 int filter_overlap = pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
103 return ur_w - nstl::max(0, utils::div_up(filter_overlap, jcp.stride_w));
104 }
105};
106
107struct jit_avx512_core_amx_copy_to_wbuffer_t : public jit_generator {
108 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_to_wbuffer_t)
109
110 using reg64_t = Xbyak::Reg64;
111
112 jit_avx512_core_amx_copy_to_wbuffer_t(const jit_conv_conf_t &ajcp)
113 : jit_generator(
114 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
115 , jcp(ajcp) {}
116
117private:
118 jit_conv_conf_t jcp;
119
120 const reg64_t reg_src = rax;
121 const reg64_t reg_dst = rbx;
122 const reg64_t reg_tmp = rdx;
123
124 const Xbyak::Opmask kmask_load = k2;
125
126 const Xbyak::Zmm zmm_src = zmm0;
127 const Xbyak::Zmm zmm_dst = zmm1;
128 const Xbyak::Zmm zmm_idx = zmm2;
129 const Xbyak::Zmm zmm_zero = zmm3;
130
131 void generate() override;
132};
133
134struct jit_avx512_core_amx_copy_to_pbuffer_t : public jit_generator {
135 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_to_pbuffer_t)
136
137 using reg64_t = Xbyak::Reg64;
138
139 jit_avx512_core_amx_copy_to_pbuffer_t(const jit_conv_conf_t &ajcp)
140 : jit_generator(
141 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
142 , jcp(ajcp) {}
143
144private:
145 jit_conv_conf_t jcp;
146
147 const reg64_t reg_inp_ptr = r15;
148 const reg64_t reg_out_ptr = r14;
149
150 const reg64_t reg_aux_inp_ptr = r13;
151 const reg64_t reg_aux_out_ptr = r12;
152
153 const reg64_t reg_khp = r10;
154
155 /* relow stuff */
156 const reg64_t reg_kht = r11;
157 const reg64_t reg_tov = r9;
158 const reg64_t reg_bov = r8;
159 const reg64_t reg_kwp = rax;
160 const reg64_t reg_lov = reg_aux_inp_ptr;
161 const reg64_t reg_rov = rbx;
162 const reg64_t reg_save_out_ptr = rdx;
163 const reg64_t reg_cnt = rbp;
164 /* relow stuff */
165
166 /* non-relow stuff */
167 const reg64_t reg_kdp = abi_not_param1;
168 const reg64_t reg_kdc = rbp;
169 const reg64_t reg_khc = r11;
170
171 const reg64_t reg_kh_over = r8;
172 const reg64_t reg_tover = rax;
173 const reg64_t reg_bover = rbx;
174
175 const reg64_t reg_owb = rdx;
176 /* non-relow stuff */
177
178 const reg64_t reg_tmp = rsi;
179
180 const Xbyak::Opmask &ktail_mask = k2;
181
182 const Xbyak::Ymm &ymm_tmp = ymm0;
183 const Xbyak::Zmm &zmm_tmp = zmm0;
184 const Xbyak::Zmm &zmm_zero = zmm1;
185
186 void generate() override;
187 void copy_row(int icb);
188 void copy_row_body(int lpad, int iw_len, int icb);
189 void copy_row_reduced_lowering();
190};
191
192struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator {
193 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_fwd_kernel_t)
194
195 jit_avx512_core_amx_fwd_kernel_t(const jit_conv_conf_t &ajcp,
196 const primitive_attr_t &attr, const memory_desc_t &dst_md);
197
198 status_t create_kernel() override;
199
200 static status_t init_conf(jit_conv_conf_t &jcp,
201 const convolution_desc_t &cd, memory_desc_t &src_pd,
202 memory_desc_t &weights_pd, memory_desc_t &dst_pd,
203 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
204 static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
205 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
206
207 inline int accum_with_upper_bound(
208 int upper_bound, int lower_value, int upper_value) {
209 return nstl::min(upper_bound,
210 nstl::min(upper_bound, lower_value)
211 + nstl::max(0, upper_bound - upper_value));
212 }
213
214 /* Calculate and store the limits relevant to 'ow_block'. These limits
215 * allow the driver code to determine which 'ow_block' is currently being
216 * executed. There can be at most 5 different 'ow_block', each corresponding
217 * to:
218 * - l_pad block
219 * - middle block (no padding)
220 * - middle & r_pad shift block
221 * - r_pad (full) block
222 * - r_pad tail
223 * */
224 static void set_ow_blk_limits(jit_conv_conf_t &jcp);
225
226 /* Calculate and store the limits relevant to each 'oh_block'. Each
227 * 'oh_block' size is 'nb_oh_blocking * oh_per_tile'. These limits allow
228 * the driver code to determine which 'oh_block' is currently being
229 * executed, and what is the oh value required to advance the limits index.
230 *
231 * There can be at most 6 different 'oh_blk', depending on the sizes of
232 * 't_pad_output', 'b_pad_output' and their overlap with
233 * 'nb_oh_blocking * oh_per_tile'.
234 *
235 * For example, given the following input dimensions of {height_size = 12,
236 * oh_blk_size = 2, top_padding = 5 (t_pad), bottom_padding = 2 (b_pad)},
237 * the 4 output height blocks and limits are:
238 *
239 * H: _ H_blks:_ Limits:
240 * 0 | | 0|X|
241 * 1 | | |X|_4
242 * 2 | | t_pad
243 * 3 | | _
244 * 4 |_| 1|X|
245 * 5 | | |_|_5
246 * 6 | | 2| |
247 * 7 | | |_|_9
248 * 8 | |
249 * 9 |_| _
250 * 10| | b_pad 3|X|
251 * 11|_| |X|_11
252 *
253 * -where 'x' represents
254 * an 'h_blk' with output
255 * padding.
256 * */
257 static void set_oh_blk_limits(jit_conv_conf_t &jcp);
258
259 void tile_configure(char *tcfg_buff);
260
261 jit_conv_conf_t jcp;
262 const primitive_attr_t &attr_;
263
264 const jit_avx512_core_amx_copy_to_pbuffer_t &copy_to_pbuffer() const {
265 return *copy_to_pbuffer_;
266 }
267 const jit_avx512_core_amx_copy_to_wbuffer_t &copy_to_wbuffer() const {
268 return *copy_to_wbuffer_;
269 }
270 const jit_avx512_core_amx_compute_zp_pbuff_t &zp_pbuff_kernel() const {
271 return *zp_pbuff_kernel_;
272 }
273
274private:
275 constexpr static int isa_simd_width_
276 = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
277 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
278 postops_injector_;
279 std::unique_ptr<jit_avx512_core_amx_copy_to_pbuffer_t> copy_to_pbuffer_;
280 std::unique_ptr<jit_avx512_core_amx_copy_to_wbuffer_t> copy_to_wbuffer_;
281 std::unique_ptr<jit_avx512_core_amx_compute_zp_pbuff_t> zp_pbuff_kernel_;
282
283 enum {
284 zmm_idx_limit_bf16 = 29,
285 zmm_idx_limit_int8 = 27,
286 };
287
288 int prv_width_ = 0;
289 int row_count_ = 0;
290 bool is_store_done_ = false;
291 bool is_buffer_empty_ = true;
292
293 struct w_pad_output {
294 int l_pad_output;
295 int r_pad_output;
296 w_pad_output(int l_, int r_) : l_pad_output(l_), r_pad_output(r_) {}
297 };
298 std::queue<w_pad_output> w_padding;
299
300 /* data regs */
301 const Xbyak::Reg64 reg_inp_ptr = r15;
302 const Xbyak::Reg64 reg_wei_ptr = r14;
303 const Xbyak::Reg64 reg_out_ptr = r13;
304 const Xbyak::Reg64 reg_wsp_ptr = r12;
305
306 const Xbyak::Reg64 reg_kd = r9;
307
308 const Xbyak::Reg64 reg_bias = r11;
309 const Xbyak::Reg64 reg_ptr_scales = r10;
310 const Xbyak::Reg64 reg_ptr_sum_scale = r9;
311 const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1;
312 const Xbyak::Reg64 reg_aux_saturation = reg_ptr_sum_scale;
313
314 const Xbyak::Reg64 reg_inp_stride = rbx;
315 const Xbyak::Reg64 reg_wei_stride = rdx;
316 // zero-point computation
317 const Xbyak::Reg64 reg_zp_compensation = rax;
318 const Xbyak::Reg64 reg_src_zero_point = r8;
319 const Xbyak::Reg64 reg_zero_point_pbuff = rsi;
320 const Xbyak::Reg64 reg_dst_zero_point = abi_not_param1;
321 const Xbyak::Reg64 reg_dst_scale = reg_dst_zero_point;
322
323 // rbp - reserved for EVEX compression
324 const Xbyak::Reg64 reg_last_h = abi_not_param1;
325 const Xbyak::Reg64 reg_jmp_blk = reg_last_h;
326
327 // temporary, used in generate() function only
328 const Xbyak::Reg64 reg_oc_blocks = rax;
329 const Xbyak::Reg64 reg_tmp = r8;
330
331 const Xbyak::Opmask &ktail_mask = k2;
332
333 const Xbyak::Zmm &zmm_bias = zmm31;
334 const Xbyak::Zmm &zmm_saturation = zmm_bias;
335 const Xbyak::Zmm &zmm_zero = zmm30;
336 const Xbyak::Zmm &zmm_prev_dst = zmm29;
337 const Xbyak::Zmm &zmm_sum_zp = zmm26;
338 /* zero-point */
339 const Xbyak::Zmm &zmm_zp = zmm29;
340 const Xbyak::Zmm &zmm_src_zp = zmm28;
341 const Xbyak::Zmm &zmm_dst_zp = zmm27;
342 /* dst scale */
343 const Xbyak::Zmm &zmm_dst_scale = zmm25;
344
345 const Xbyak::Reg64 bin_injector_helper_reg_1 = r14;
346 const Xbyak::Reg64 bin_injector_helper_reg_2 = r15;
347 const Xbyak::Reg64 bin_injector_helper_reg_3 = r11;
348
349 // AUX: Steps, shifts and offsets
350 size_t get_inp_icb_step() const;
351 size_t get_wei_icb_step() const;
352 size_t get_inp_d_step() const;
353 size_t get_inp_h_step() const;
354 size_t get_wei_d_step() const;
355 size_t get_wei_h_step() const;
356 size_t get_out_ocb_offset(int ohb, int ocb, size_t typesize) const;
357 size_t get_out_row_offset(int ohb, int ocb, int j, size_t typesize) const;
358 size_t get_out_shift(int width, size_t typesize) const;
359 size_t get_wsp_ocb_offset(int ohb, int ocb) const;
360 size_t get_wsp_row_offset(int ohb, int ocb, int j) const;
361 size_t get_wsp_shift() const;
362 size_t get_wei_offset(int ocb, int kw) const;
363 size_t get_inp_shift() const;
364 size_t get_inp_offset(int ohb, int kw) const;
365 size_t get_zp_comp_offset(int ocb, int zp_h, int zp_w) const;
366 int get_zp_index_offset(
367 int index, int mid, int s_pad_output, int e_pad_output);
368
369 int get_out_tensor(int h, int i, bool is_h_tail = false) const;
370 int get_inp_tensor(int h, bool is_h_tail = false) const;
371 int get_wei_tensor(int i) const;
372
373 void prepare_output(int tail);
374 void init_runtime_counters(bool start_with_last_tile_block);
375 size_t reduce_to_block(const int block_size, const int pad_output);
376 size_t reduce_to_blocked_dims(const int dim_size, const int block_size,
377 const int s_pad_output, const int e_pad_output);
378 void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in,
379 const Xbyak::Operand &op, bool mask_flag = false);
380 Xbyak::Zmm zmm_out(const int idx) {
381 const int upper_limit = jcp.src_dt == data_type::bf16
382 ? zmm_idx_limit_bf16
383 : zmm_idx_limit_int8;
384 assert(upper_limit > idx);
385 MAYBE_UNUSED(upper_limit);
386 return Xbyak::Zmm(idx);
387 }
388 Xbyak::Ymm ymm_mask(
389 const Xbyak::Ymm &zmm_in, bool mask_flag, bool store = false);
390 Xbyak::Zmm zmm_mask(
391 const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false);
392 void apply_sum(const Xbyak::Zmm &zmm_out, const float *p_sum_scale,
393 const int32_t *p_sum_zp, const Xbyak::Address &addr,
394 const bool mask_flag);
395 void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale,
396 const int32_t *p_sum_zp, const Xbyak::Address &addr,
397 const size_t off, const bool mask_flag);
398 inline void store_output_ymm_bf16(
399 const int idx, const Xbyak::Address &addr, const bool mask_flag);
400 void store_output_vector_bf16(
401 const Xbyak::Zmm &zmm_out, int ocb, int h, int w);
402 void store_output_vector_int8(const Xbyak::Zmm &zmm_out, int ocb, int h,
403 int w, const bool compute_zp, const int zp_h, const int zp_w);
404 void store_output_vector(const Xbyak::Zmm &zmm_out, int ocb, int h, int w,
405 const bool compute_zp = false, const int zp_h = 0,
406 const int zp_w = 0);
407 void store_output(int width, int tail, bool do_store,
408 const bool handle_h_block, const int t_pad_output,
409 const int b_pad_output, const int l_pad_output,
410 const int r_pad_output, const bool is_last_oh_block,
411 const bool zp_3d_pad = false);
412 void interleave_store(int width, int const t_pad_output,
413 int const b_pad_output, const bool zp_3d_pad = false);
414 void compute_icb_loop(int width, bool do_store, const bool handle_h_block,
415 const int t_pad_output, const int b_pad_output,
416 const int l_pad_output, const int r_pad_output,
417 const bool zp_3d_pad, const bool is_last_oh_block = false);
418 void dispatch_icb_loop(int width, bool do_store, const int l_pad_output,
419 const int r_pad_output, const bool zp_3d_pad);
420 void dispatch_zp_3d_compute(int width, bool do_store,
421 const int l_pad_output, const int r_pad_output);
422 void compute_ow_loop();
423
424 void generate() override;
425};
426
427struct jit_avx512_core_amx_bwd_data_copy_kernel_t : public jit_generator {
428 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_data_copy_kernel_t)
429
430 using reg64_t = Xbyak::Reg64;
431
432 jit_avx512_core_amx_bwd_data_copy_kernel_t(jit_conv_conf_t ajcp)
433 : jit_generator(
434 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
435 , jcp(ajcp) {}
436
437private:
438 jit_conv_conf_t jcp;
439
440 // pointers
441 const reg64_t reg_ptr_inp = r15;
442 const reg64_t reg_ptr_out = r14;
443
444 // auxiliary pointers
445 const reg64_t reg_ptr_aux_inp_h = r13;
446 const reg64_t reg_ptr_aux_inp_w = r12;
447 const reg64_t reg_ptr_aux_out = r11;
448
449 // variables
450 const reg64_t reg_khp = r10; // kh padding
451 const reg64_t reg_tov = r9; // top overflow
452 const reg64_t reg_bov = reg_tov; // bottom overflow
453 const reg64_t reg_kwp = rax; // kw padding
454 const reg64_t reg_lov = rbx; // left overflow
455 const reg64_t reg_rov = abi_not_param1; // right overflow
456 const reg64_t reg_kd = r8; // 3d filter
457
458 // counters
459 const reg64_t reg_cnt_khp = rdx;
460 const reg64_t reg_cnt_tmp = rbp;
461 const reg64_t reg_cnt_ocb = rsi;
462
463 const reg64_t reg_tmp = reg_cnt_tmp;
464
465 const Xbyak::Opmask ktail_mask = k2;
466
467 const Xbyak::Zmm zmm_tmp = zmm1;
468 const Xbyak::Zmm zmm_zero = zmm0;
469
470 void generate() override;
471 void copy_row(bool is_masked);
472 void kd_loop(bool is_masked);
473};
474
475struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator {
476 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_data_kernel_t)
477
478 jit_avx512_core_amx_bwd_data_kernel_t(
479 const jit_conv_conf_t ajcp, const primitive_attr_t &attr)
480 : jit_generator(
481 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
482 , jcp(ajcp)
483 , attr_(attr)
484 , eltwise_injector_(nullptr) {
485 if (jcp.with_eltwise)
486 eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_core>(
487 this, jcp.eltwise);
488 bwd_data_copy_kernel_
489 = new jit_avx512_core_amx_bwd_data_copy_kernel_t(jcp);
490 }
491 status_t create_kernel() override {
492 CHECK(jit_generator::create_kernel());
493 CHECK(bwd_data_copy_kernel_->create_kernel());
494 return status::success;
495 }
496 ~jit_avx512_core_amx_bwd_data_kernel_t() {
497 delete eltwise_injector_;
498 delete bwd_data_copy_kernel_;
499 }
500
501 static bool post_ops_ok(const jit_conv_conf_t &jcp, primitive_attr_t &attr);
502
503 static status_t init_conf(jit_conv_conf_t &jcp,
504 const convolution_desc_t &cd, memory_desc_t &diff_src_pd,
505 memory_desc_t &weights_pd, memory_desc_t &diff_dst_pd,
506 memory_desc_t *bias_pd, primitive_attr_t &attr, int nthreads);
507 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
508 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
509
510 void tile_configure(char *tcfg_buff);
511
512 jit_conv_conf_t jcp;
513 const primitive_attr_t &attr_;
514
515 const jit_avx512_core_amx_bwd_data_copy_kernel_t &
516 bwd_data_copy_kernel() const {
517 return *bwd_data_copy_kernel_;
518 }
519
520private:
521 jit_uni_eltwise_injector_f32<avx512_core> *eltwise_injector_;
522 jit_avx512_core_amx_bwd_data_copy_kernel_t *bwd_data_copy_kernel_;
523
524 int prv_width_ = 0;
525 int row_count_ = 0;
526 bool is_store_done_ = false;
527 bool is_buffer_empty_ = true;
528
529 bool is_store_done_save_ = false;
530 int prv_width_save_ = 0;
531
532 /* data regs */
533 const Xbyak::Reg64 reg_inp_ptr = r15;
534 const Xbyak::Reg64 reg_wei_ptr = r14;
535 const Xbyak::Reg64 reg_out_ptr = r13;
536 const Xbyak::Reg64 reg_wsp_ptr = r12;
537
538 const Xbyak::Reg64 reg_bias = r11;
539 const Xbyak::Reg64 reg_ptr_scales = r10;
540 const Xbyak::Reg64 reg_ptr_dst_scales = r10;
541 const Xbyak::Reg64 reg_ptr_sum_scale = r9;
542 const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1;
543 const Xbyak::Reg64 reg_aux_saturation = reg_ptr_sum_scale;
544
545 const Xbyak::Reg64 reg_aux_inp_ptr = r8;
546 const Xbyak::Reg64 reg_inp_stride = rbx;
547 const Xbyak::Reg64 reg_wei_stride = rdx;
548
549 // rbp - reserved for EVEX compression
550 const Xbyak::Reg64 reg_last_h = abi_not_param1;
551 const Xbyak::Reg64 reg_kd = rsi;
552
553 // temporary, used in generate() function only
554 const Xbyak::Reg64 reg_ic_blocks = rax;
555 const Xbyak::Reg64 reg_tmp = reg_aux_inp_ptr;
556
557 const Xbyak::Opmask ktail_mask = k2;
558
559 const Xbyak::Zmm zmm_bias = zmm31;
560 const Xbyak::Zmm zmm_saturation = zmm_bias;
561 const Xbyak::Zmm zmm_zero = zmm30;
562 const Xbyak::Zmm zmm_prev_dst = zmm29;
563 const Xbyak::Zmm zmm_sum_zp = zmm28;
564 /* dst scale */
565 const Xbyak::Zmm &zmm_dst_scale = zmm27;
566
567 // AUX: Steps, shifts and offsets
568 size_t get_inp_ocb_step() const;
569 size_t get_inp_offset(int ihb, int kh, int kw) const;
570 size_t get_inp_shift() const;
571 size_t get_inp_d_step() const;
572 size_t get_out_icb_offset(int ihb, int icb) const;
573 size_t get_out_row_offset(int ihb, int icb, int j) const;
574 size_t get_out_shift(int width) const;
575 size_t get_wei_kh_step() const;
576 size_t get_wei_ocb_step() const;
577 size_t get_wei_offset(int icb, int kh, int kw) const;
578 size_t get_wei_d_step() const;
579 size_t get_wsp_icb_offset(int ihb, int icb) const;
580 size_t get_wsp_row_offset(int ihb, int icb, int j) const;
581 size_t get_wsp_shift() const;
582
583 int get_out_tensor(int h, int i) const;
584 int get_inp_tensor(int h) const;
585 int get_wei_tensor(int i) const;
586
587 inline bool gaps_in_store() {
588 const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1;
589 return gen_kd < jcp.stride_d || jcp.dilate_d > 0;
590 }
591
592 void prepare_output();
593 void init_runtime_counters(bool start_with_last_tile_block);
594
595 bool maybe_eltwise(int position);
596 void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in,
597 const Xbyak::Operand &op, bool mask_flag = false);
598 Xbyak::Ymm ymm_mask(
599 const Xbyak::Ymm &zmm_in, bool mask_flag, bool store = false);
600 Xbyak::Zmm zmm_mask(
601 const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false);
602
603 void store_output_vector_bf16(
604 const Xbyak::Zmm &zmm_out, int icb, int ihb, int iw);
605 void store_output_vector_int8(
606 const Xbyak::Zmm &zmm_out, int icb, int ihb, int iw);
607 void store_output_vector(
608 const Xbyak::Zmm &zmm_out, int icb, int ih, int iw);
609 void store_output(int width, bool do_store);
610 void skipped_interleave_store();
611 void interleave_store(int width);
612 void compute_ocb_loop(int width, bool do_interleave_store);
613 void compute_kd_loop(int width, bool do_store, bool handle_skipped_stores);
614 void compute_iw_loop();
615
616 void generate() override;
617};
618
619struct jit_avx512_core_amx_bwd_weights_kernel_t : public jit_generator {
620
621 jit_avx512_core_amx_bwd_weights_kernel_t(const jit_conv_conf_t &ajcp)
622 : jit_generator(
623 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
624 , jcp(ajcp) {}
625
626 ~jit_avx512_core_amx_bwd_weights_kernel_t() {}
627
628 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_weights_kernel_t)
629
630 static status_t init_conf(jit_conv_conf_t &jcp,
631 const convolution_desc_t &cd, memory_desc_t &src_md,
632 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
633 memory_desc_t &diff_dst_md, int nthreads);
634 static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
635 const jit_conv_conf_t &jcp, memory_desc_t &src_md,
636 memory_desc_t &diff_weights_md, memory_desc_t &diff_dst_md);
637
638 void tile_configure(char *tcfg_buff);
639
640 const jit_conv_conf_t &jcp;
641
642private:
643 int get_wei_tensor(int ocb, int icb) const;
644 int get_src_tensor(int icb) const;
645 int get_ddst_tensor(int ocb) const;
646
647 using reg64_t = const Xbyak::Reg64;
648 static const int max_ur_w;
649
650 reg64_t param = abi_param1;
651 reg64_t reg_src = rax;
652 reg64_t reg_kernel = rdx;
653 reg64_t reg_ddst = rsi;
654 reg64_t b_ic = abi_not_param1;
655 reg64_t kj = r8;
656 reg64_t reg_kh = r9;
657 reg64_t reg_oj = r15;
658 reg64_t reg_tmp = r14;
659 reg64_t reg_ih_shift = reg_tmp;
660 reg64_t reg_long_offt = r14;
661 reg64_t reg_icb = rbx;
662
663 reg64_t ki = r11;
664 reg64_t reg_oj_setup = r11;
665 reg64_t reg_kd_count = r12;
666 reg64_t reg_oi = r12;
667 reg64_t reg_d_index = r13;
668 reg64_t reg_src_d = r15;
669 reg64_t reg_ddst_d = rbx;
670 reg64_t aux_reg_src = r12;
671 reg64_t aux_reg_kernel = r13;
672
673 reg64_t reg_b_stride = reg_icb;
674 reg64_t reg_a_stride = r10;
675
676 Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0);
677 Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1);
678 Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2);
679
680 enum {
681 full_spat_opt_working_set_size = 48 * 1024,
682 full_spat_max_working_set_size = 128 * 1024,
683 };
684
685 inline void maybe_zero_kernel(int nb_ic_blocking, int nb_oc_blocking);
686 inline void od_step_comeback_pointers();
687 inline void oh_step_comeback_pointers();
688 inline void compute_ic_loop(
689 int ic_block, int nb_ic_blocking, int nb_oc_blocking);
690 inline void compute_full_spat_loop(int nb_ic_blocking, int nb_oc_blocking);
691 inline void compute_oh_step_common(int nb_ic_blocking, int nb_oc_blocking);
692 inline void compute_loop(int nb_ic_blocking, int nb_oc_blocking);
693 inline void compute_oh_loop_common(
694 int nb_ic_blocking, int nb_oc_blocking, bool partial = false);
695 inline void compute_od_loop_common(
696 int nb_ic_blocking, int nb_oc_blocking, bool partial = false);
697 void compute_diff_bias_init(int ocb = 0);
698 void compute_diff_bias_row(bool is_partial, int ocb);
699 void maybe_compute_diff_bias(int nb_oc_blocking);
700 void may_be_set_oc_tail_mask();
701 void may_be_reset_oc_tail_mask();
702
703 void generate() override;
704
705 static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
706 int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
707
708 inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) {
709 return kw * (jcp.dilate_w + 1) + ow - pad_l;
710 }
711 inline dim_t filter_h_to_src(int kh) { return kh * (jcp.dilate_h + 1); }
712 inline dim_t filter_d_to_src(int kd) {
713 return kd * (jcp.dilate_d + 1) * jcp.ih;
714 }
715
716 inline dim_t get_src_offset(dim_t ic_idx, dim_t w_idx, dim_t hd_idx = 0) {
717 return jcp.typesize_in
718 * (hd_idx * jcp.tr_iw * jcp.ic_block + jcp.tr_iw * ic_idx
719 + w_idx);
720 }
721
722 inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) {
723 int ow_per_oc = 2;
724 dim_t w_off = w_idx / ow_per_oc * ow_per_oc * jcp.oc_block
725 + w_idx % ow_per_oc;
726 return jcp.typesize_in * (w_off + jcp.tr_ow * jcp.oc_block * hd_idx);
727 }
728
729 inline dim_t get_kernel_offset(int ic_idx, dim_t ksp_idx) {
730 return jcp.typesize_out * jcp.oc_block
731 * (ksp_idx * jcp.ic_block + ic_idx);
732 }
733 inline dim_t get_full_kernel_offset(int ocb, int icb, int kh, int kw) {
734 return jcp.typesize_out
735 * (ocb * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block
736 * jcp.oc_block
737 + icb * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block
738 * jcp.oc_block
739 + kh * jcp.kw * jcp.ic_block * jcp.oc_block
740 + kw * jcp.ic_block * jcp.oc_block);
741 };
742
743 inline void setup_stack_space();
744 int ic_block_step_stack_size = 0;
745 int stack_space_needed = 0;
746 int kd_count_offset = 0;
747 int src_d_offset = 0;
748 int ddst_d_offset = 0;
749 int d_index_offset = 0;
750 int ih_dilate_offset = 0;
751 int src_save_offset = 0;
752 int ddst_save_offset = 0;
753};
754
755struct jit_avx512_core_amx_bwd_bias_kernel_t : public jit_generator {
756
757 jit_avx512_core_amx_bwd_bias_kernel_t(const jit_conv_conf_t &ajcp)
758 : jit_generator(
759 jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
760 , jcp(ajcp) {}
761
762 ~jit_avx512_core_amx_bwd_bias_kernel_t() {}
763
764 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_bias_kernel_t)
765
766 const jit_conv_conf_t &jcp;
767
768private:
769 using reg64_t = const Xbyak::Reg64;
770
771 reg64_t param = abi_param1;
772 reg64_t reg_ddst = rsi;
773 reg64_t reg_oj = r15;
774 reg64_t reg_tmp = r14;
775 reg64_t reg_bias = r13;
776 reg64_t reg_initial = r12;
777 reg64_t reg_nrows = r11;
778
779 Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0);
780 Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1);
781 Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2);
782 Xbyak::Ymm yreg_bias_acc0 = Xbyak::Ymm(0);
783 Xbyak::Ymm yreg_bias_acc1 = Xbyak::Ymm(3);
784 Xbyak::Ymm yreg_bias_ddst0 = Xbyak::Ymm(2);
785 Xbyak::Ymm yreg_bias_ddst1 = Xbyak::Ymm(4);
786
787 void compute_diff_bias_row(int ocb);
788 void compute_diff_bias(int nb_oc_blocking);
789
790 void generate() override;
791
792 inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) {
793 int ow_per_oc = 2;
794 dim_t w_off = w_idx / ow_per_oc * ow_per_oc * jcp.oc_block
795 + w_idx % ow_per_oc;
796 return jcp.typesize_in * (w_off + jcp.tr_ow * jcp.oc_block * hd_idx);
797 }
798};
799
800} // namespace x64
801} // namespace cpu
802} // namespace impl
803} // namespace dnnl
804
805#endif
806