1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_AMX_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_AMX_CONV_KERNEL_HPP |
19 | |
#include <memory>
#include <queue>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
/* This struct computes the compensation for src_zero_point related to
 * padding */
// JIT kernel constructed for the avx512_core_amx ISA (see the jit_generator
// constructor arguments). Its output is written through reg_zp_pbuff; the
// exact buffer layout is defined by generate() in the .cpp.
// NOTE(review): presumably this buffer is the one consumed by the forward AMX
// kernel via reg_zero_point_pbuff — confirm against the .cpp.
struct jit_avx512_core_amx_compute_zp_pbuff_t : public jit_generator {

    using reg64_t = const Xbyak::Reg64;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_compute_zp_pbuff_t)

    // Copies the convolution configuration; code is emitted later through
    // the jit_generator interface (generate()).
    jit_avx512_core_amx_compute_zp_pbuff_t(const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
        , jcp(ajcp) {}

    // Exclusive upper bound on the accumulator zmm index produced by
    // zmm_out(); zmm30 and zmm31 are taken by zmm_permb and zmm_one below.
    static const int max_regs_ur = 30;

private:
    // Private copy of the convolution configuration.
    jit_conv_conf_t jcp;

    // Tells compute_ker()/kh_loop()/kd_loop() whether the input-channel block
    // being processed is the last one (which may need tail handling).
    typedef enum { no_last_block, last_ic_block } ic_block_t;
    // NOTE(review): presumably the int8 VNNI grouping of 4 input channels —
    // confirm in the .cpp.
    const int ic_inner_block = 4;

    // Labels for constant data emitted alongside the kernel code.
    Xbyak::Label permb_idx_label;
    Xbyak::Label ic_mask_label;

    const reg64_t reg_zp_pbuff = r8;
    const reg64_t reg_src_zero_point = r9;
    const reg64_t reg_filt = r10;
    const reg64_t aux_reg_filt = r11;
    const reg64_t aux_reg_filt_d = r15;

    const reg64_t reg_oc_blocks = r12;
    const reg64_t reg_icb = r13;
    const reg64_t reg_oi = r14;
    const reg64_t reg_kj = rax;
    const reg64_t reg_ki = rbx;
    // Shares rax with reg_kj: the two must never be live at the same time.
    const reg64_t reg_overflow = reg_kj;
    const reg64_t reg_scratch = rsi;

    // Reserved vector registers (outside the zmm_out() range).
    const Xbyak::Zmm zmm_one = Xbyak::Zmm(31);
    const Xbyak::Zmm zmm_permb = Xbyak::Zmm(30);

    const Xbyak::Opmask kmask_ic_block = Xbyak::Opmask(1);
    const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);

    // Code-generation helpers; the nesting (icb -> kd -> kh -> compute_ker)
    // is suggested by the names — implementations live in the .cpp.
    void prepare_output(int ur_w);
    void store_output(int ur_w, bool last_oc_block_flag);
    void compute_ker(int ur_w, int pad_l, int pad_r,
            ic_block_t last_ic_block_flag, bool padded);
    void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag,
            bool handle_h_pad);
    void kd_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag,
            bool handle_h_pad);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool handle_h_pad);
    void unroll_width(const bool h_padding);

    void generate() override;

    // Accumulator register for unrolled position i_ur and oc block i_oc:
    // index i_ur * nb_oc_blocking + i_oc, asserted to stay below max_regs_ur.
    Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
        int idx = i_ur * jcp.nb_oc_blocking + i_oc;
        assert(idx < max_regs_ur);
        return Xbyak::Zmm(idx);
    }
    // First unrolled output column for which filter tap `ki` does not fall
    // into the left padding:
    // max(0, ceil((pad_l - ki * (dilate_w + 1)) / stride_w)).
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }
    // One past the last unrolled output column for which filter tap `ki` does
    // not fall into the right padding: ur_w minus the number of trailing
    // columns that overlap pad_r.
    int get_ow_end(int ur_w, int ki, int pad_r) {
        int filter_overlap = pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
        return ur_w - nstl::max(0, utils::div_up(filter_overlap, jcp.stride_w));
    }
};
106 | |
107 | struct jit_avx512_core_amx_copy_to_wbuffer_t : public jit_generator { |
108 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_to_wbuffer_t) |
109 | |
110 | using reg64_t = Xbyak::Reg64; |
111 | |
112 | jit_avx512_core_amx_copy_to_wbuffer_t(const jit_conv_conf_t &ajcp) |
113 | : jit_generator( |
114 | jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
115 | , jcp(ajcp) {} |
116 | |
117 | private: |
118 | jit_conv_conf_t jcp; |
119 | |
120 | const reg64_t reg_src = rax; |
121 | const reg64_t reg_dst = rbx; |
122 | const reg64_t reg_tmp = rdx; |
123 | |
124 | const Xbyak::Opmask kmask_load = k2; |
125 | |
126 | const Xbyak::Zmm zmm_src = zmm0; |
127 | const Xbyak::Zmm zmm_dst = zmm1; |
128 | const Xbyak::Zmm zmm_idx = zmm2; |
129 | const Xbyak::Zmm zmm_zero = zmm3; |
130 | |
131 | void generate() override; |
132 | }; |
133 | |
/* JIT helper kernel (avx512_core_amx) that copies input rows from
 * reg_inp_ptr to reg_out_ptr ("pbuffer"). It has two code paths with
 * separate register sets: the reduced-lowering path
 * (copy_row_reduced_lowering) and the regular path (copy_row /
 * copy_row_body). */
struct jit_avx512_core_amx_copy_to_pbuffer_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_copy_to_pbuffer_t)

    using reg64_t = Xbyak::Reg64;

    // Keeps its own copy of the convolution configuration.
    jit_avx512_core_amx_copy_to_pbuffer_t(const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
        , jcp(ajcp) {}

private:
    jit_conv_conf_t jcp;

    const reg64_t reg_inp_ptr = r15;
    const reg64_t reg_out_ptr = r14;

    const reg64_t reg_aux_inp_ptr = r13;
    const reg64_t reg_aux_out_ptr = r12;

    const reg64_t reg_khp = r10;

    /* relow stuff */
    // NOTE(review): "relow" presumably abbreviates "reduced lowering" (see
    // copy_row_reduced_lowering) — confirm. The names suggest kh/kw padding
    // and top/bottom/left/right overflow counters.
    const reg64_t reg_kht = r11;
    const reg64_t reg_tov = r9;
    const reg64_t reg_bov = r8;
    const reg64_t reg_kwp = rax;
    // Shares r13 with reg_aux_inp_ptr: both must not be live simultaneously.
    const reg64_t reg_lov = reg_aux_inp_ptr;
    const reg64_t reg_rov = rbx;
    const reg64_t reg_save_out_ptr = rdx;
    const reg64_t reg_cnt = rbp;
    /* relow stuff */

    /* non-relow stuff */
    // These reuse the physical registers of the block above (r11, r8, rax,
    // rbx, rdx, rbp), so the two sets are presumably used only by mutually
    // exclusive code paths — confirm in the .cpp.
    const reg64_t reg_kdp = abi_not_param1;
    const reg64_t reg_kdc = rbp;
    const reg64_t reg_khc = r11;

    const reg64_t reg_kh_over = r8;
    const reg64_t reg_tover = rax;
    const reg64_t reg_bover = rbx;

    const reg64_t reg_owb = rdx;
    /* non-relow stuff */

    const reg64_t reg_tmp = rsi;

    const Xbyak::Opmask &ktail_mask = k2;

    // ymm_tmp and zmm_tmp name the same physical register (index 0).
    const Xbyak::Ymm &ymm_tmp = ymm0;
    const Xbyak::Zmm &zmm_tmp = zmm0;
    const Xbyak::Zmm &zmm_zero = zmm1;

    void generate() override;
    void copy_row(int icb);
    void copy_row_body(int lpad, int iw_len, int icb);
    void copy_row_reduced_lowering();
};
191 | |
192 | struct jit_avx512_core_amx_fwd_kernel_t : public jit_generator { |
193 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_fwd_kernel_t) |
194 | |
195 | jit_avx512_core_amx_fwd_kernel_t(const jit_conv_conf_t &ajcp, |
196 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
197 | |
198 | status_t create_kernel() override; |
199 | |
200 | static status_t init_conf(jit_conv_conf_t &jcp, |
201 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
202 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
203 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
204 | static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, |
205 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
206 | |
207 | inline int accum_with_upper_bound( |
208 | int upper_bound, int lower_value, int upper_value) { |
209 | return nstl::min(upper_bound, |
210 | nstl::min(upper_bound, lower_value) |
211 | + nstl::max(0, upper_bound - upper_value)); |
212 | } |
213 | |
214 | /* Calculate and store the limits relevant to 'ow_block'. These limits |
215 | * allow the driver code to determine which 'ow_block' is currently being |
216 | * executed. There can be at most 5 different 'ow_block', each corresponding |
217 | * to: |
218 | * - l_pad block |
219 | * - middle block (no padding) |
220 | * - middle & r_pad shift block |
221 | * - r_pad (full) block |
222 | * - r_pad tail |
223 | * */ |
224 | static void set_ow_blk_limits(jit_conv_conf_t &jcp); |
225 | |
226 | /* Calculate and store the limits relevant to each 'oh_block'. Each |
227 | * 'oh_block' size is 'nb_oh_blocking * oh_per_tile'. These limits allow |
228 | * the driver code to determine which 'oh_block' is currently being |
229 | * executed, and what is the oh value required to advance the limits index. |
230 | * |
231 | * There can be at most 6 different 'oh_blk', depending on the sizes of |
232 | * 't_pad_output', 'b_pad_output' and their overlap with |
233 | * 'nb_oh_blocking * oh_per_tile'. |
234 | * |
235 | * For example, given the following input dimensions of {height_size = 12, |
236 | * oh_blk_size = 2, top_padding = 5 (t_pad), bottom_padding = 2 (b_pad)}, |
237 | * the 4 output height blocks and limits are: |
238 | * |
239 | * H: _ H_blks:_ Limits: |
240 | * 0 | | 0|X| |
241 | * 1 | | |X|_4 |
242 | * 2 | | t_pad |
243 | * 3 | | _ |
244 | * 4 |_| 1|X| |
245 | * 5 | | |_|_5 |
246 | * 6 | | 2| | |
247 | * 7 | | |_|_9 |
248 | * 8 | | |
249 | * 9 |_| _ |
250 | * 10| | b_pad 3|X| |
251 | * 11|_| |X|_11 |
252 | * |
253 | * -where 'x' represents |
254 | * an 'h_blk' with output |
255 | * padding. |
256 | * */ |
257 | static void set_oh_blk_limits(jit_conv_conf_t &jcp); |
258 | |
259 | void tile_configure(char *tcfg_buff); |
260 | |
261 | jit_conv_conf_t jcp; |
262 | const primitive_attr_t &attr_; |
263 | |
264 | const jit_avx512_core_amx_copy_to_pbuffer_t ©_to_pbuffer() const { |
265 | return *copy_to_pbuffer_; |
266 | } |
267 | const jit_avx512_core_amx_copy_to_wbuffer_t ©_to_wbuffer() const { |
268 | return *copy_to_wbuffer_; |
269 | } |
270 | const jit_avx512_core_amx_compute_zp_pbuff_t &zp_pbuff_kernel() const { |
271 | return *zp_pbuff_kernel_; |
272 | } |
273 | |
274 | private: |
275 | constexpr static int isa_simd_width_ |
276 | = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
277 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> |
278 | postops_injector_; |
279 | std::unique_ptr<jit_avx512_core_amx_copy_to_pbuffer_t> copy_to_pbuffer_; |
280 | std::unique_ptr<jit_avx512_core_amx_copy_to_wbuffer_t> copy_to_wbuffer_; |
281 | std::unique_ptr<jit_avx512_core_amx_compute_zp_pbuff_t> zp_pbuff_kernel_; |
282 | |
283 | enum { |
284 | zmm_idx_limit_bf16 = 29, |
285 | zmm_idx_limit_int8 = 27, |
286 | }; |
287 | |
288 | int prv_width_ = 0; |
289 | int row_count_ = 0; |
290 | bool is_store_done_ = false; |
291 | bool is_buffer_empty_ = true; |
292 | |
293 | struct w_pad_output { |
294 | int l_pad_output; |
295 | int r_pad_output; |
296 | w_pad_output(int l_, int r_) : l_pad_output(l_), r_pad_output(r_) {} |
297 | }; |
298 | std::queue<w_pad_output> w_padding; |
299 | |
300 | /* data regs */ |
301 | const Xbyak::Reg64 reg_inp_ptr = r15; |
302 | const Xbyak::Reg64 reg_wei_ptr = r14; |
303 | const Xbyak::Reg64 reg_out_ptr = r13; |
304 | const Xbyak::Reg64 reg_wsp_ptr = r12; |
305 | |
306 | const Xbyak::Reg64 reg_kd = r9; |
307 | |
308 | const Xbyak::Reg64 reg_bias = r11; |
309 | const Xbyak::Reg64 reg_ptr_scales = r10; |
310 | const Xbyak::Reg64 reg_ptr_sum_scale = r9; |
311 | const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1; |
312 | const Xbyak::Reg64 reg_aux_saturation = reg_ptr_sum_scale; |
313 | |
314 | const Xbyak::Reg64 reg_inp_stride = rbx; |
315 | const Xbyak::Reg64 reg_wei_stride = rdx; |
316 | // zero-point computation |
317 | const Xbyak::Reg64 reg_zp_compensation = rax; |
318 | const Xbyak::Reg64 reg_src_zero_point = r8; |
319 | const Xbyak::Reg64 reg_zero_point_pbuff = rsi; |
320 | const Xbyak::Reg64 reg_dst_zero_point = abi_not_param1; |
321 | const Xbyak::Reg64 reg_dst_scale = reg_dst_zero_point; |
322 | |
323 | // rbp - reserved for EVEX compression |
324 | const Xbyak::Reg64 reg_last_h = abi_not_param1; |
325 | const Xbyak::Reg64 reg_jmp_blk = reg_last_h; |
326 | |
327 | // temporary, used in generate() function only |
328 | const Xbyak::Reg64 reg_oc_blocks = rax; |
329 | const Xbyak::Reg64 reg_tmp = r8; |
330 | |
331 | const Xbyak::Opmask &ktail_mask = k2; |
332 | |
333 | const Xbyak::Zmm &zmm_bias = zmm31; |
334 | const Xbyak::Zmm &zmm_saturation = zmm_bias; |
335 | const Xbyak::Zmm &zmm_zero = zmm30; |
336 | const Xbyak::Zmm &zmm_prev_dst = zmm29; |
337 | const Xbyak::Zmm &zmm_sum_zp = zmm26; |
338 | /* zero-point */ |
339 | const Xbyak::Zmm &zmm_zp = zmm29; |
340 | const Xbyak::Zmm &zmm_src_zp = zmm28; |
341 | const Xbyak::Zmm &zmm_dst_zp = zmm27; |
342 | /* dst scale */ |
343 | const Xbyak::Zmm &zmm_dst_scale = zmm25; |
344 | |
345 | const Xbyak::Reg64 bin_injector_helper_reg_1 = r14; |
346 | const Xbyak::Reg64 bin_injector_helper_reg_2 = r15; |
347 | const Xbyak::Reg64 bin_injector_helper_reg_3 = r11; |
348 | |
349 | // AUX: Steps, shifts and offsets |
350 | size_t get_inp_icb_step() const; |
351 | size_t get_wei_icb_step() const; |
352 | size_t get_inp_d_step() const; |
353 | size_t get_inp_h_step() const; |
354 | size_t get_wei_d_step() const; |
355 | size_t get_wei_h_step() const; |
356 | size_t get_out_ocb_offset(int ohb, int ocb, size_t typesize) const; |
357 | size_t get_out_row_offset(int ohb, int ocb, int j, size_t typesize) const; |
358 | size_t get_out_shift(int width, size_t typesize) const; |
359 | size_t get_wsp_ocb_offset(int ohb, int ocb) const; |
360 | size_t get_wsp_row_offset(int ohb, int ocb, int j) const; |
361 | size_t get_wsp_shift() const; |
362 | size_t get_wei_offset(int ocb, int kw) const; |
363 | size_t get_inp_shift() const; |
364 | size_t get_inp_offset(int ohb, int kw) const; |
365 | size_t get_zp_comp_offset(int ocb, int zp_h, int zp_w) const; |
366 | int get_zp_index_offset( |
367 | int index, int mid, int s_pad_output, int e_pad_output); |
368 | |
369 | int get_out_tensor(int h, int i, bool is_h_tail = false) const; |
370 | int get_inp_tensor(int h, bool is_h_tail = false) const; |
371 | int get_wei_tensor(int i) const; |
372 | |
373 | void prepare_output(int tail); |
374 | void init_runtime_counters(bool start_with_last_tile_block); |
375 | size_t reduce_to_block(const int block_size, const int pad_output); |
376 | size_t reduce_to_blocked_dims(const int dim_size, const int block_size, |
377 | const int s_pad_output, const int e_pad_output); |
378 | void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in, |
379 | const Xbyak::Operand &op, bool mask_flag = false); |
380 | Xbyak::Zmm zmm_out(const int idx) { |
381 | const int upper_limit = jcp.src_dt == data_type::bf16 |
382 | ? zmm_idx_limit_bf16 |
383 | : zmm_idx_limit_int8; |
384 | assert(upper_limit > idx); |
385 | MAYBE_UNUSED(upper_limit); |
386 | return Xbyak::Zmm(idx); |
387 | } |
388 | Xbyak::Ymm ymm_mask( |
389 | const Xbyak::Ymm &zmm_in, bool mask_flag, bool store = false); |
390 | Xbyak::Zmm zmm_mask( |
391 | const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false); |
392 | void apply_sum(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, |
393 | const int32_t *p_sum_zp, const Xbyak::Address &addr, |
394 | const bool mask_flag); |
395 | void apply_postops(const Xbyak::Zmm &zmm_out, const float *p_sum_scale, |
396 | const int32_t *p_sum_zp, const Xbyak::Address &addr, |
397 | const size_t off, const bool mask_flag); |
398 | inline void store_output_ymm_bf16( |
399 | const int idx, const Xbyak::Address &addr, const bool mask_flag); |
400 | void store_output_vector_bf16( |
401 | const Xbyak::Zmm &zmm_out, int ocb, int h, int w); |
402 | void store_output_vector_int8(const Xbyak::Zmm &zmm_out, int ocb, int h, |
403 | int w, const bool compute_zp, const int zp_h, const int zp_w); |
404 | void store_output_vector(const Xbyak::Zmm &zmm_out, int ocb, int h, int w, |
405 | const bool compute_zp = false, const int zp_h = 0, |
406 | const int zp_w = 0); |
407 | void store_output(int width, int tail, bool do_store, |
408 | const bool handle_h_block, const int t_pad_output, |
409 | const int b_pad_output, const int l_pad_output, |
410 | const int r_pad_output, const bool is_last_oh_block, |
411 | const bool zp_3d_pad = false); |
412 | void interleave_store(int width, int const t_pad_output, |
413 | int const b_pad_output, const bool zp_3d_pad = false); |
414 | void compute_icb_loop(int width, bool do_store, const bool handle_h_block, |
415 | const int t_pad_output, const int b_pad_output, |
416 | const int l_pad_output, const int r_pad_output, |
417 | const bool zp_3d_pad, const bool is_last_oh_block = false); |
418 | void dispatch_icb_loop(int width, bool do_store, const int l_pad_output, |
419 | const int r_pad_output, const bool zp_3d_pad); |
420 | void dispatch_zp_3d_compute(int width, bool do_store, |
421 | const int l_pad_output, const int r_pad_output); |
422 | void compute_ow_loop(); |
423 | |
424 | void generate() override; |
425 | }; |
426 | |
427 | struct jit_avx512_core_amx_bwd_data_copy_kernel_t : public jit_generator { |
428 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_data_copy_kernel_t) |
429 | |
430 | using reg64_t = Xbyak::Reg64; |
431 | |
432 | jit_avx512_core_amx_bwd_data_copy_kernel_t(jit_conv_conf_t ajcp) |
433 | : jit_generator( |
434 | jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
435 | , jcp(ajcp) {} |
436 | |
437 | private: |
438 | jit_conv_conf_t jcp; |
439 | |
440 | // pointers |
441 | const reg64_t reg_ptr_inp = r15; |
442 | const reg64_t reg_ptr_out = r14; |
443 | |
444 | // auxiliary pointers |
445 | const reg64_t reg_ptr_aux_inp_h = r13; |
446 | const reg64_t reg_ptr_aux_inp_w = r12; |
447 | const reg64_t reg_ptr_aux_out = r11; |
448 | |
449 | // variables |
450 | const reg64_t reg_khp = r10; // kh padding |
451 | const reg64_t reg_tov = r9; // top overflow |
452 | const reg64_t reg_bov = reg_tov; // bottom overflow |
453 | const reg64_t reg_kwp = rax; // kw padding |
454 | const reg64_t reg_lov = rbx; // left overflow |
455 | const reg64_t reg_rov = abi_not_param1; // right overflow |
456 | const reg64_t reg_kd = r8; // 3d filter |
457 | |
458 | // counters |
459 | const reg64_t reg_cnt_khp = rdx; |
460 | const reg64_t reg_cnt_tmp = rbp; |
461 | const reg64_t reg_cnt_ocb = rsi; |
462 | |
463 | const reg64_t reg_tmp = reg_cnt_tmp; |
464 | |
465 | const Xbyak::Opmask ktail_mask = k2; |
466 | |
467 | const Xbyak::Zmm zmm_tmp = zmm1; |
468 | const Xbyak::Zmm zmm_zero = zmm0; |
469 | |
470 | void generate() override; |
471 | void copy_row(bool is_masked); |
472 | void kd_loop(bool is_masked); |
473 | }; |
474 | |
475 | struct jit_avx512_core_amx_bwd_data_kernel_t : public jit_generator { |
476 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_data_kernel_t) |
477 | |
478 | jit_avx512_core_amx_bwd_data_kernel_t( |
479 | const jit_conv_conf_t ajcp, const primitive_attr_t &attr) |
480 | : jit_generator( |
481 | jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
482 | , jcp(ajcp) |
483 | , attr_(attr) |
484 | , eltwise_injector_(nullptr) { |
485 | if (jcp.with_eltwise) |
486 | eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_core>( |
487 | this, jcp.eltwise); |
488 | bwd_data_copy_kernel_ |
489 | = new jit_avx512_core_amx_bwd_data_copy_kernel_t(jcp); |
490 | } |
491 | status_t create_kernel() override { |
492 | CHECK(jit_generator::create_kernel()); |
493 | CHECK(bwd_data_copy_kernel_->create_kernel()); |
494 | return status::success; |
495 | } |
496 | ~jit_avx512_core_amx_bwd_data_kernel_t() { |
497 | delete eltwise_injector_; |
498 | delete bwd_data_copy_kernel_; |
499 | } |
500 | |
501 | static bool post_ops_ok(const jit_conv_conf_t &jcp, primitive_attr_t &attr); |
502 | |
503 | static status_t init_conf(jit_conv_conf_t &jcp, |
504 | const convolution_desc_t &cd, memory_desc_t &diff_src_pd, |
505 | memory_desc_t &weights_pd, memory_desc_t &diff_dst_pd, |
506 | memory_desc_t *bias_pd, primitive_attr_t &attr, int nthreads); |
507 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
508 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
509 | |
510 | void tile_configure(char *tcfg_buff); |
511 | |
512 | jit_conv_conf_t jcp; |
513 | const primitive_attr_t &attr_; |
514 | |
515 | const jit_avx512_core_amx_bwd_data_copy_kernel_t & |
516 | bwd_data_copy_kernel() const { |
517 | return *bwd_data_copy_kernel_; |
518 | } |
519 | |
520 | private: |
521 | jit_uni_eltwise_injector_f32<avx512_core> *eltwise_injector_; |
522 | jit_avx512_core_amx_bwd_data_copy_kernel_t *bwd_data_copy_kernel_; |
523 | |
524 | int prv_width_ = 0; |
525 | int row_count_ = 0; |
526 | bool is_store_done_ = false; |
527 | bool is_buffer_empty_ = true; |
528 | |
529 | bool is_store_done_save_ = false; |
530 | int prv_width_save_ = 0; |
531 | |
532 | /* data regs */ |
533 | const Xbyak::Reg64 reg_inp_ptr = r15; |
534 | const Xbyak::Reg64 reg_wei_ptr = r14; |
535 | const Xbyak::Reg64 reg_out_ptr = r13; |
536 | const Xbyak::Reg64 reg_wsp_ptr = r12; |
537 | |
538 | const Xbyak::Reg64 reg_bias = r11; |
539 | const Xbyak::Reg64 reg_ptr_scales = r10; |
540 | const Xbyak::Reg64 reg_ptr_dst_scales = r10; |
541 | const Xbyak::Reg64 reg_ptr_sum_scale = r9; |
542 | const Xbyak::Reg64 reg_ptr_sum_zp = abi_not_param1; |
543 | const Xbyak::Reg64 reg_aux_saturation = reg_ptr_sum_scale; |
544 | |
545 | const Xbyak::Reg64 reg_aux_inp_ptr = r8; |
546 | const Xbyak::Reg64 reg_inp_stride = rbx; |
547 | const Xbyak::Reg64 reg_wei_stride = rdx; |
548 | |
549 | // rbp - reserved for EVEX compression |
550 | const Xbyak::Reg64 reg_last_h = abi_not_param1; |
551 | const Xbyak::Reg64 reg_kd = rsi; |
552 | |
553 | // temporary, used in generate() function only |
554 | const Xbyak::Reg64 reg_ic_blocks = rax; |
555 | const Xbyak::Reg64 reg_tmp = reg_aux_inp_ptr; |
556 | |
557 | const Xbyak::Opmask ktail_mask = k2; |
558 | |
559 | const Xbyak::Zmm zmm_bias = zmm31; |
560 | const Xbyak::Zmm zmm_saturation = zmm_bias; |
561 | const Xbyak::Zmm zmm_zero = zmm30; |
562 | const Xbyak::Zmm zmm_prev_dst = zmm29; |
563 | const Xbyak::Zmm zmm_sum_zp = zmm28; |
564 | /* dst scale */ |
565 | const Xbyak::Zmm &zmm_dst_scale = zmm27; |
566 | |
567 | // AUX: Steps, shifts and offsets |
568 | size_t get_inp_ocb_step() const; |
569 | size_t get_inp_offset(int ihb, int kh, int kw) const; |
570 | size_t get_inp_shift() const; |
571 | size_t get_inp_d_step() const; |
572 | size_t get_out_icb_offset(int ihb, int icb) const; |
573 | size_t get_out_row_offset(int ihb, int icb, int j) const; |
574 | size_t get_out_shift(int width) const; |
575 | size_t get_wei_kh_step() const; |
576 | size_t get_wei_ocb_step() const; |
577 | size_t get_wei_offset(int icb, int kh, int kw) const; |
578 | size_t get_wei_d_step() const; |
579 | size_t get_wsp_icb_offset(int ihb, int icb) const; |
580 | size_t get_wsp_row_offset(int ihb, int icb, int j) const; |
581 | size_t get_wsp_shift() const; |
582 | |
583 | int get_out_tensor(int h, int i) const; |
584 | int get_inp_tensor(int h) const; |
585 | int get_wei_tensor(int i) const; |
586 | |
587 | inline bool gaps_in_store() { |
588 | const int gen_kd = (jcp.kd - 1) * (jcp.dilate_d + 1) + 1; |
589 | return gen_kd < jcp.stride_d || jcp.dilate_d > 0; |
590 | } |
591 | |
592 | void prepare_output(); |
593 | void init_runtime_counters(bool start_with_last_tile_block); |
594 | |
595 | bool maybe_eltwise(int position); |
596 | void cvt2ps(data_type_t type_in, const Xbyak::Zmm &ymm_in, |
597 | const Xbyak::Operand &op, bool mask_flag = false); |
598 | Xbyak::Ymm ymm_mask( |
599 | const Xbyak::Ymm &zmm_in, bool mask_flag, bool store = false); |
600 | Xbyak::Zmm zmm_mask( |
601 | const Xbyak::Zmm &zmm_in, bool mask_flag, bool store = false); |
602 | |
603 | void store_output_vector_bf16( |
604 | const Xbyak::Zmm &zmm_out, int icb, int ihb, int iw); |
605 | void store_output_vector_int8( |
606 | const Xbyak::Zmm &zmm_out, int icb, int ihb, int iw); |
607 | void store_output_vector( |
608 | const Xbyak::Zmm &zmm_out, int icb, int ih, int iw); |
609 | void store_output(int width, bool do_store); |
610 | void skipped_interleave_store(); |
611 | void interleave_store(int width); |
612 | void compute_ocb_loop(int width, bool do_interleave_store); |
613 | void compute_kd_loop(int width, bool do_store, bool handle_skipped_stores); |
614 | void compute_iw_loop(); |
615 | |
616 | void generate() override; |
617 | }; |
618 | |
619 | struct jit_avx512_core_amx_bwd_weights_kernel_t : public jit_generator { |
620 | |
621 | jit_avx512_core_amx_bwd_weights_kernel_t(const jit_conv_conf_t &ajcp) |
622 | : jit_generator( |
623 | jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
624 | , jcp(ajcp) {} |
625 | |
626 | ~jit_avx512_core_amx_bwd_weights_kernel_t() {} |
627 | |
628 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_weights_kernel_t) |
629 | |
630 | static status_t init_conf(jit_conv_conf_t &jcp, |
631 | const convolution_desc_t &cd, memory_desc_t &src_md, |
632 | memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md, |
633 | memory_desc_t &diff_dst_md, int nthreads); |
634 | static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, |
635 | const jit_conv_conf_t &jcp, memory_desc_t &src_md, |
636 | memory_desc_t &diff_weights_md, memory_desc_t &diff_dst_md); |
637 | |
638 | void tile_configure(char *tcfg_buff); |
639 | |
640 | const jit_conv_conf_t &jcp; |
641 | |
642 | private: |
643 | int get_wei_tensor(int ocb, int icb) const; |
644 | int get_src_tensor(int icb) const; |
645 | int get_ddst_tensor(int ocb) const; |
646 | |
647 | using reg64_t = const Xbyak::Reg64; |
648 | static const int max_ur_w; |
649 | |
650 | reg64_t param = abi_param1; |
651 | reg64_t reg_src = rax; |
652 | reg64_t reg_kernel = rdx; |
653 | reg64_t reg_ddst = rsi; |
654 | reg64_t b_ic = abi_not_param1; |
655 | reg64_t kj = r8; |
656 | reg64_t reg_kh = r9; |
657 | reg64_t reg_oj = r15; |
658 | reg64_t reg_tmp = r14; |
659 | reg64_t reg_ih_shift = reg_tmp; |
660 | reg64_t reg_long_offt = r14; |
661 | reg64_t reg_icb = rbx; |
662 | |
663 | reg64_t ki = r11; |
664 | reg64_t reg_oj_setup = r11; |
665 | reg64_t reg_kd_count = r12; |
666 | reg64_t reg_oi = r12; |
667 | reg64_t reg_d_index = r13; |
668 | reg64_t reg_src_d = r15; |
669 | reg64_t reg_ddst_d = rbx; |
670 | reg64_t aux_reg_src = r12; |
671 | reg64_t aux_reg_kernel = r13; |
672 | |
673 | reg64_t reg_b_stride = reg_icb; |
674 | reg64_t reg_a_stride = r10; |
675 | |
676 | Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0); |
677 | Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1); |
678 | Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2); |
679 | |
680 | enum { |
681 | full_spat_opt_working_set_size = 48 * 1024, |
682 | full_spat_max_working_set_size = 128 * 1024, |
683 | }; |
684 | |
685 | inline void maybe_zero_kernel(int nb_ic_blocking, int nb_oc_blocking); |
686 | inline void od_step_comeback_pointers(); |
687 | inline void oh_step_comeback_pointers(); |
688 | inline void compute_ic_loop( |
689 | int ic_block, int nb_ic_blocking, int nb_oc_blocking); |
690 | inline void compute_full_spat_loop(int nb_ic_blocking, int nb_oc_blocking); |
691 | inline void compute_oh_step_common(int nb_ic_blocking, int nb_oc_blocking); |
692 | inline void compute_loop(int nb_ic_blocking, int nb_oc_blocking); |
693 | inline void compute_oh_loop_common( |
694 | int nb_ic_blocking, int nb_oc_blocking, bool partial = false); |
695 | inline void compute_od_loop_common( |
696 | int nb_ic_blocking, int nb_oc_blocking, bool partial = false); |
697 | void compute_diff_bias_init(int ocb = 0); |
698 | void compute_diff_bias_row(bool is_partial, int ocb); |
699 | void maybe_compute_diff_bias(int nb_oc_blocking); |
700 | void may_be_set_oc_tail_mask(); |
701 | void may_be_reset_oc_tail_mask(); |
702 | |
703 | void generate() override; |
704 | |
705 | static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, |
706 | int &nthr_g, int &nthr_oc_b, int &nthr_ic_b); |
707 | |
708 | inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) { |
709 | return kw * (jcp.dilate_w + 1) + ow - pad_l; |
710 | } |
711 | inline dim_t filter_h_to_src(int kh) { return kh * (jcp.dilate_h + 1); } |
712 | inline dim_t filter_d_to_src(int kd) { |
713 | return kd * (jcp.dilate_d + 1) * jcp.ih; |
714 | } |
715 | |
716 | inline dim_t get_src_offset(dim_t ic_idx, dim_t w_idx, dim_t hd_idx = 0) { |
717 | return jcp.typesize_in |
718 | * (hd_idx * jcp.tr_iw * jcp.ic_block + jcp.tr_iw * ic_idx |
719 | + w_idx); |
720 | } |
721 | |
722 | inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) { |
723 | int ow_per_oc = 2; |
724 | dim_t w_off = w_idx / ow_per_oc * ow_per_oc * jcp.oc_block |
725 | + w_idx % ow_per_oc; |
726 | return jcp.typesize_in * (w_off + jcp.tr_ow * jcp.oc_block * hd_idx); |
727 | } |
728 | |
729 | inline dim_t get_kernel_offset(int ic_idx, dim_t ksp_idx) { |
730 | return jcp.typesize_out * jcp.oc_block |
731 | * (ksp_idx * jcp.ic_block + ic_idx); |
732 | } |
733 | inline dim_t get_full_kernel_offset(int ocb, int icb, int kh, int kw) { |
734 | return jcp.typesize_out |
735 | * (ocb * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block |
736 | * jcp.oc_block |
737 | + icb * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block |
738 | * jcp.oc_block |
739 | + kh * jcp.kw * jcp.ic_block * jcp.oc_block |
740 | + kw * jcp.ic_block * jcp.oc_block); |
741 | }; |
742 | |
743 | inline void setup_stack_space(); |
744 | int ic_block_step_stack_size = 0; |
745 | int stack_space_needed = 0; |
746 | int kd_count_offset = 0; |
747 | int src_d_offset = 0; |
748 | int ddst_d_offset = 0; |
749 | int d_index_offset = 0; |
750 | int ih_dilate_offset = 0; |
751 | int src_save_offset = 0; |
752 | int ddst_save_offset = 0; |
753 | }; |
754 | |
755 | struct jit_avx512_core_amx_bwd_bias_kernel_t : public jit_generator { |
756 | |
757 | jit_avx512_core_amx_bwd_bias_kernel_t(const jit_conv_conf_t &ajcp) |
758 | : jit_generator( |
759 | jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx) |
760 | , jcp(ajcp) {} |
761 | |
762 | ~jit_avx512_core_amx_bwd_bias_kernel_t() {} |
763 | |
764 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_bwd_bias_kernel_t) |
765 | |
766 | const jit_conv_conf_t &jcp; |
767 | |
768 | private: |
769 | using reg64_t = const Xbyak::Reg64; |
770 | |
771 | reg64_t param = abi_param1; |
772 | reg64_t reg_ddst = rsi; |
773 | reg64_t reg_oj = r15; |
774 | reg64_t reg_tmp = r14; |
775 | reg64_t reg_bias = r13; |
776 | reg64_t reg_initial = r12; |
777 | reg64_t reg_nrows = r11; |
778 | |
779 | Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0); |
780 | Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1); |
781 | Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2); |
782 | Xbyak::Ymm yreg_bias_acc0 = Xbyak::Ymm(0); |
783 | Xbyak::Ymm yreg_bias_acc1 = Xbyak::Ymm(3); |
784 | Xbyak::Ymm yreg_bias_ddst0 = Xbyak::Ymm(2); |
785 | Xbyak::Ymm yreg_bias_ddst1 = Xbyak::Ymm(4); |
786 | |
787 | void compute_diff_bias_row(int ocb); |
788 | void compute_diff_bias(int nb_oc_blocking); |
789 | |
790 | void generate() override; |
791 | |
792 | inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) { |
793 | int ow_per_oc = 2; |
794 | dim_t w_off = w_idx / ow_per_oc * ow_per_oc * jcp.oc_block |
795 | + w_idx % ow_per_oc; |
796 | return jcp.typesize_in * (w_off + jcp.tr_ow * jcp.oc_block * hd_idx); |
797 | } |
798 | }; |
799 | |
800 | } // namespace x64 |
801 | } // namespace cpu |
802 | } // namespace impl |
803 | } // namespace dnnl |
804 | |
805 | #endif |
806 | |