1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_brgemm_conv_bwd_trans_kernel.hpp" |
18 | #include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | |
25 | using namespace dnnl::impl::utils; |
26 | using namespace nstl; |
27 | using namespace data_type; |
28 | |
29 | namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel { |
30 | |
31 | #define GET_OFF(field) offsetof(jit_brgemm_conv_bwd_trans_kernel_call_s, field) |
32 | |
33 | jit_avx512_core_brgemm_conv_bwd_trans_kernel_t:: |
34 | jit_avx512_core_brgemm_conv_bwd_trans_kernel_t( |
35 | const jit_brgemm_conv_conf_t &ajcp, const char *name) |
36 | : jit_generator(name), jcp(ajcp) { |
37 | inp_dsz = jcp.src_dsz; |
38 | oc_block_sz = inp_dsz * jcp.oc_block; |
39 | dst_w_block = jcp.ow_block; |
40 | dst_stride = jcp.owp; |
41 | dst_w_offset = oc_block_sz; |
42 | dst_h_offset = dst_stride * dst_w_offset; |
43 | ow_size = inp_dsz * jcp.ngroups * jcp.oc_without_padding; |
44 | VL = cpu_isa_traits<avx512_core>::vlen; |
45 | n_vec = jcp.oc_block / jcp.simd_w; |
46 | n_tail_vec = (jcp.oc_without_padding % jcp.oc_block) / jcp.simd_w; |
47 | } |
48 | |
49 | int jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::inp_w(int out_w) const { |
50 | const auto res = div_up(out_w + jcp.l_pad % jcp.stride_w, jcp.stride_w) |
51 | + (jcp.ext_kw - 1 - jcp.l_pad % jcp.stride_w) / jcp.stride_w; |
52 | return res; |
53 | } |
54 | |
55 | int jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::inp_w_start(int iwb) const { |
56 | const auto sw = jcp.l_pad % jcp.stride_w; |
57 | const auto kw = (jcp.kw - 1) % jcp.stride_w; |
58 | const auto kw_x = (jcp.kw - 1) - nstl::modulo(kw - sw, jcp.stride_w); |
59 | const auto ow = (iwb * jcp.iw_block + jcp.l_pad - kw_x * (jcp.dilate_w + 1)) |
60 | / jcp.stride_w; |
61 | return ow; |
62 | } |
63 | |
64 | // use different vmovdqu32/16/8 due to case when tail mask used |
65 | void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::load( |
66 | const Xbyak::Xmm &x, const Xbyak::Address &addr) { |
67 | switch (jcp.src_dt) { |
68 | case f32: |
69 | case s32: vmovdqu32(x, addr); break; |
70 | case bf16: |
71 | case f16: vmovdqu16(x, addr); break; |
72 | case s8: |
73 | case u8: vmovdqu8(x, addr); break; |
74 | default: assert(!"Unknown type!" ); |
75 | } |
76 | } |
77 | |
78 | void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::store( |
79 | const Xbyak::Address &addr, const Xbyak::Xmm &x) { |
80 | switch (jcp.src_dt) { |
81 | case f32: |
82 | case s32: vmovdqu32(addr, x); break; |
83 | case bf16: |
84 | case f16: vmovdqu16(addr, x); break; |
85 | case s8: |
86 | case u8: vmovdqu8(addr, x); break; |
87 | default: assert(!"Unknown type!" ); |
88 | } |
89 | } |
90 | |
// Emits stores that zero one oc block at byte offset `dst_off` from
// aux_dst_ptr. `is_oc_tail` selects the shorter vector count of the last,
// possibly partial, oc block (jcp.oc_without_padding % jcp.oc_block).
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::zero_oc_block(
        bool is_oc_tail, dim_t dst_off) {
    // True when a full oc block is not a multiple of simd_w, so the last
    // store always needs the kblock_tail_mask.
    bool has_block_tail = (jcp.oc_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small oc efficiency
    auto nvec = is_oc_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++)
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_zero);
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;
    // Remainder: masked zero store when the block has a sub-vector tail;
    // for an oc tail without a block tail, one extra full vector of zeros
    // covers the rest of the (padded) block.
    if (has_block_tail)
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
    else if (is_oc_tail)
        store(ptr[last_dst_off], zmm_zero);
}
105 | |
// Emits loads/stores copying one oc block from aux_inp_ptr + inp_off to
// aux_dst_ptr + dst_off via zmm_tmp. When `do_load` is false the loads are
// skipped and whatever zmm_tmp already holds is stored — presumably used to
// duplicate a previously loaded block; TODO confirm against callers (the
// only visible call site passes do_load = true).
// NOTE(review): default arguments on an out-of-line member definition are
// only legal if the in-class declaration omits them — confirm the header.
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_oc_block(
        bool is_oc_tail, dim_t inp_off = 0, dim_t dst_off = 0,
        bool do_load = true) {
    // oc block itself has a sub-simd_w remainder -> kblock_tail_mask needed.
    bool has_block_tail = (jcp.oc_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small oc efficiency
    auto nvec = is_oc_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++) {
        if (do_load) load(zmm_tmp, ptr[aux_inp_ptr + inp_off + iv * VL]);
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_tmp);
    }
    // Addresses of the partial (tail) vector following the full vectors.
    const auto last_inp_off = aux_inp_ptr + inp_off + nvec * VL;
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;

    if (is_oc_tail) {
        // Masked load zeroes (T_z) the lanes beyond the oc tail so the
        // store writes zeros into the padded part of the block.
        auto zmm_tmp_mask = zmm_tmp | ktail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
        else
            store(ptr[last_dst_off], zmm_tmp);
    } else if (has_block_tail) {
        // Full oc block with a sub-vector remainder: both load and store
        // are masked with the block-tail mask.
        auto zmm_tmp_mask = zmm_tmp | kblock_tail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
    }
}
133 | |
// Emits the kernel body: for each of jcp.nb_oc_blocking oc blocks, copy
// reg_hc source rows into the zero-padded destination buffer, zeroing
// reg_t_pad rows at the top and the trailing (bottom-padding) rows at the
// end.
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::generate() {
    preamble();

    // Unpack the call arguments (jit_brgemm_conv_bwd_trans_kernel_call_s).
    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_hc, ptr[param1 + GET_OFF(h_count)]);
    mov(reg_t_pad, ptr[param1 + GET_OFF(t_pad)]);
    mov(reg_b_pad, ptr[param1 + GET_OFF(b_pad)]);
    mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
    mov(reg_oc, ptr[param1 + GET_OFF(oc)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    // k-mask for the sub-vector remainder of the last (partial) oc block.
    if (jcp.oc_without_padding % jcp.oc_block) {
        int tail_size = (jcp.oc_without_padding % jcp.oc_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    // k-mask for the sub-vector remainder of a full oc block (oc_block not
    // a multiple of simd_w).
    if (jcp.oc_block % jcp.simd_w) {
        int block_tail_size = jcp.oc_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

    // Emits the per-oc-block loop over rows: top-padding rows are zeroed,
    // the middle rows are copied, bottom-padding rows are zeroed.
    auto ocb_loop_body = [&](bool is_oc_tail) {
        Xbyak::Label kh_label, no_kh_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do

        // Zero reg_t_pad rows of top padding, if any.
        cmp(reg_t_pad, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(kh_over, reg_t_pad);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small oc
            for_(dim_t ow = 0; ow < dst_w_block; ow++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_oc_block(is_oc_tail, ow * dst_w_offset + kw * oc_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_hc, reg_t_pad);
        L(no_kh_tover_label);

        // Copy rows until only the reg_b_pad bottom-padding rows remain.
        cmp(reg_hc, reg_b_pad);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_iw_block(is_oc_tail);
            auto inp_h_offset = jcp.ow * ow_size; // bytes per source row

            add(aux_inp_ptr, inp_h_offset);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            cmp(reg_hc, reg_b_pad);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        // Zero the remaining (bottom-padding) rows, if any.
        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small oc
            for_(dim_t ow = 0; ow < dst_w_block; ow++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_oc_block(is_oc_tail, ow * dst_w_offset + kw * oc_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            jnz(kh_bover_label, T_NEAR);
        }
        L(no_kh_bover_label);

        // End of this oc block: advance base pointers to the next one.
        auto inp_cb_offset = oc_block_sz;
        auto dst_cb_offset = jcp.ohp * dst_h_offset;

        add(inp_ptr, inp_cb_offset);
        add(dst_ptr, dst_cb_offset);
    };

    // Unrolled loop over oc blocks; at runtime each block decides whether
    // it is a full block or the oc tail by comparing against jcp.oc.
    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
        Xbyak::Label oc_tail_label, ocb_continue_label;
        add(reg_oc, jcp.oc_block);
        cmp(reg_oc, jcp.oc);
        jg(oc_tail_label, T_NEAR);

        ocb_loop_body(false);
        jmp(ocb_continue_label, T_NEAR);

        L(oc_tail_label);
        ocb_loop_body(true);

        L(ocb_continue_label);
    }

    postamble();
}
248 | |
249 | void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_iw_block( |
250 | bool is_oc_tail) { |
251 | if (jcp.l_ovf > 0) { |
252 | for (dim_t ind_w = 0; ind_w < jcp.l_ovf; ind_w++) |
253 | zero_oc_block(is_oc_tail, (ind_w + jcp.l_ovf) * dst_w_offset); |
254 | } |
255 | |
256 | Xbyak::Label copy_block_done_label; |
257 | |
258 | int start_first_zero_block = -1; |
259 | int end_first_zero_block = -1; |
260 | int start_first_partial_block = -1; |
261 | int end_first_partial_block = -1; |
262 | int start_full_block = -1; |
263 | int end_full_block = -1; |
264 | int start_last_partial_block = -1; |
265 | int end_last_partial_block = -1; |
266 | |
267 | int iw_block_tail = jcp.iw % jcp.iw_block; |
268 | |
269 | for (int iwb = 0; iwb < jcp.nb_iw; iwb++) { |
270 | const auto inp_block = inp_w(jcp.iw_block); |
271 | const auto inp_start = inp_w_start(iwb); |
272 | const auto inp_end = inp_start + inp_block; |
273 | if (inp_start + inp_block < 0) { |
274 | if (start_first_zero_block == -1) start_first_zero_block = iwb; |
275 | end_first_zero_block = iwb; |
276 | } else if (inp_start < 0) { |
277 | if (start_first_partial_block == -1) |
278 | start_first_partial_block = iwb; |
279 | end_first_partial_block = iwb; |
280 | } else if (inp_start < jcp.ow) { |
281 | if (inp_end <= jcp.ow) { |
282 | if (start_full_block == -1) start_full_block = iwb; |
283 | end_full_block = iwb; |
284 | } else { |
285 | if (start_last_partial_block == -1) |
286 | start_last_partial_block = iwb; |
287 | end_last_partial_block = iwb; |
288 | } |
289 | } |
290 | } |
291 | |
292 | if (start_first_zero_block != -1) { |
293 | Xbyak::Label skip_first_zero_blocks; |
294 | cmp(reg_iwb, end_first_zero_block); |
295 | jg(skip_first_zero_blocks, T_NEAR); |
296 | // zero block |
297 | copy_iw_block_body(0, jcp.iw_block, 0, is_oc_tail); |
298 | jmp(copy_block_done_label, T_NEAR); |
299 | |
300 | L(skip_first_zero_blocks); |
301 | } |
302 | if (start_first_partial_block != -1) { |
303 | for (int b = start_first_partial_block; b <= end_first_partial_block; |
304 | b++) { |
305 | int cur_iw_block = (b == jcp.nb_iw - 1 && iw_block_tail > 0) |
306 | ? iw_block_tail |
307 | : jcp.iw_block; |
308 | const auto inp_block = inp_w(cur_iw_block); |
309 | const auto inp_start = inp_w_start(b); |
310 | const auto inp_end = inp_start + inp_block; |
311 | const auto block_lpad = -inp_start; |
312 | const auto block_len = nstl::min(jcp.ow, inp_end); |
313 | Xbyak::Label skip_first_partial_block; |
314 | cmp(reg_iwb, b); |
315 | jne(skip_first_partial_block, T_NEAR); |
316 | copy_iw_block_body(block_lpad, jcp.iw_block, block_len, is_oc_tail); |
317 | jmp(copy_block_done_label, T_NEAR); |
318 | L(skip_first_partial_block); |
319 | } |
320 | } |
321 | if (start_full_block != -1) { |
322 | Xbyak::Label skip_full_blocks; |
323 | cmp(reg_iwb, end_full_block); |
324 | jg(skip_full_blocks, T_NEAR); |
325 | copy_iw_block_body(0, jcp.iw_block, inp_w(jcp.iw_block), is_oc_tail); |
326 | jmp(copy_block_done_label, T_NEAR); |
327 | |
328 | L(skip_full_blocks); |
329 | } |
330 | if (start_last_partial_block != -1) { |
331 | for (int b = start_last_partial_block; b <= end_last_partial_block; |
332 | b++) { |
333 | int cur_iw_block = (b == jcp.nb_iw - 1 && iw_block_tail > 0) |
334 | ? iw_block_tail |
335 | : jcp.iw_block; |
336 | const auto inp_block = inp_w(cur_iw_block); |
337 | const auto inp_start = inp_w_start(b); |
338 | const auto inp_end = inp_start + inp_block; |
339 | const auto block_lpad = 0; |
340 | const auto block_len = nstl::min(jcp.ow, inp_end) - inp_start; |
341 | Xbyak::Label skip_last_partial_block; |
342 | cmp(reg_iwb, b); |
343 | jne(skip_last_partial_block, T_NEAR); |
344 | copy_iw_block_body(block_lpad, cur_iw_block, block_len, is_oc_tail); |
345 | jmp(copy_block_done_label, T_NEAR); |
346 | |
347 | L(skip_last_partial_block); |
348 | } |
349 | } |
350 | |
351 | // if there are no executed cases above then owb is among last zero blocks |
352 | // check if this is needed and if it is partial |
353 | copy_iw_block_body(0, jcp.iw_block, 0, is_oc_tail); |
354 | |
355 | L(copy_block_done_label); |
356 | } |
357 | |
// Emits one destination row of `dst_width` w-slots starting at slot
// jcp.l_ovf: slots whose source ow index (ind_w - lpad) lies outside
// [0, ow_len) are zero-filled (padding), the rest are copied from the
// source at ow_idx * ow_size. ow_len == 0 therefore produces an all-zero
// row.
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_iw_block_body(
        int lpad, int iw_len, int ow_len, bool is_oc_tail) {
    // Total slots to produce: the ow span needed for iw_len plus the
    // left-padding slots of this block.
    const auto dst_width = inp_w(iw_len) + lpad;
    for (dim_t ind_w = 0; ind_w < dst_width; ind_w++) {
        auto ow_idx = ind_w - lpad;
        // Destination is shifted past the left-overflow area (jcp.l_ovf).
        auto dst_off = (ind_w + jcp.l_ovf) * dst_w_offset;
        if (ow_idx < 0 || ow_idx >= ow_len) {
            zero_oc_block(is_oc_tail, dst_off);
        } else {
            auto inp_off = ow_idx * ow_size;
            copy_oc_block(is_oc_tail, inp_off, dst_off, true);
        }
    }
}
372 | |
373 | } // namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel |
374 | } // namespace x64 |
375 | } // namespace cpu |
376 | } // namespace impl |
377 | } // namespace dnnl |
378 | |
379 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
380 | |