/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/jit_brgemm_conv_bwd_trans_kernel.hpp"
#include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace nstl;
using namespace data_type;

namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel {

#define GET_OFF(field) offsetof(jit_brgemm_conv_bwd_trans_kernel_call_s, field)

jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::
        jit_avx512_core_brgemm_conv_bwd_trans_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp, const char *name)
    : jit_generator(name), jcp(ajcp) {
    inp_dsz = jcp.src_dsz;
    oc_block_sz = inp_dsz * jcp.oc_block;
    dst_w_block = jcp.ow_block;
    dst_stride = jcp.owp;
    dst_w_offset = oc_block_sz;
    dst_h_offset = dst_stride * dst_w_offset;
    ow_size = inp_dsz * jcp.ngroups * jcp.oc_without_padding;
    VL = cpu_isa_traits<avx512_core>::vlen;
    n_vec = jcp.oc_block / jcp.simd_w;
    n_tail_vec = (jcp.oc_without_padding % jcp.oc_block) / jcp.simd_w;
}

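// Number of source positions (in ow space) needed to produce out_w
// destination positions, given the stride and the extended kernel width.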
int jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::inp_w(int out_w) const {
    const auto res = div_up(out_w + jcp.l_pad % jcp.stride_w, jcp.stride_w)
            + (jcp.ext_kw - 1 - jcp.l_pad % jcp.stride_w) / jcp.stride_w;
    return res;
}

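// Starting position in ow space of the data read for the iwb-th iw block;
// the result may be negative when the block overlaps the left padding.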
int jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::inp_w_start(int iwb) const {
    const auto sw = jcp.l_pad % jcp.stride_w;
    const auto kw = (jcp.kw - 1) % jcp.stride_w;
    const auto kw_x = (jcp.kw - 1) - nstl::modulo(kw - sw, jcp.stride_w);
    const auto ow = (iwb * jcp.iw_block + jcp.l_pad - kw_x * (jcp.dilate_w + 1))
            / jcp.stride_w;
    return ow;
}

// Use vmovdqu32/16/8 according to the element size of src_dt so that the
// tail masks (built in units of elements) select the right bytes.
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::load(
        const Xbyak::Xmm &x, const Xbyak::Address &addr) {
    switch (jcp.src_dt) {
        case f32:
        case s32: vmovdqu32(x, addr); break;
        case bf16:
        case f16: vmovdqu16(x, addr); break;
        case s8:
        case u8: vmovdqu8(x, addr); break;
        default: assert(!"Unknown type!");
    }
}

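// Store counterpart of load(): the instruction is chosen by src_dt so that
// masked stores write the correct number of tail elements.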
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::store(
        const Xbyak::Address &addr, const Xbyak::Xmm &x) {
    switch (jcp.src_dt) {
        case f32:
        case s32: vmovdqu32(addr, x); break;
        case bf16:
        case f16: vmovdqu16(addr, x); break;
        case s8:
        case u8: vmovdqu8(addr, x); break;
        default: assert(!"Unknown type!");
    }
}

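// Fill one oc block at aux_dst_ptr + dst_off with zeros. The last vector is
// masked when the oc block (or the oc tail) is not a multiple of simd_w.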
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::zero_oc_block(
        bool is_oc_tail, dim_t dst_off) {
    bool has_block_tail = (jcp.oc_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small oc efficiency
    auto nvec = is_oc_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++)
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_zero);
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;
    if (has_block_tail)
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
    else if (is_oc_tail)
        store(ptr[last_dst_off], zmm_zero);
}

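// Copy one oc block from aux_inp_ptr + inp_off to aux_dst_ptr + dst_off.
// When do_load is false the values already held in zmm_tmp are reused; the
// oc-tail and block-tail cases apply the corresponding opmasks to the last
// vector.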
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_oc_block(
        bool is_oc_tail, dim_t inp_off = 0, dim_t dst_off = 0,
        bool do_load = true) {
    bool has_block_tail = (jcp.oc_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small oc efficiency
    auto nvec = is_oc_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++) {
        if (do_load) load(zmm_tmp, ptr[aux_inp_ptr + inp_off + iv * VL]);
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_tmp);
    }
    const auto last_inp_off = aux_inp_ptr + inp_off + nvec * VL;
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;

    if (is_oc_tail) {
        auto zmm_tmp_mask = zmm_tmp | ktail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
        else
            store(ptr[last_dst_off], zmm_tmp);
    } else if (has_block_tail) {
        auto zmm_tmp_mask = zmm_tmp | kblock_tail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
    }
}

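// Kernel body: read the call arguments, prepare the tail opmasks, then for
// each oc block zero the rows covered by top/bottom padding and copy the
// remaining rows of the current iw block into the padded buffer.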
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::generate() {
    preamble();

    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_hc, ptr[param1 + GET_OFF(h_count)]);
    mov(reg_t_pad, ptr[param1 + GET_OFF(t_pad)]);
    mov(reg_b_pad, ptr[param1 + GET_OFF(b_pad)]);
    mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
    mov(reg_oc, ptr[param1 + GET_OFF(oc)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

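    // ktail_mask covers the partially filled simd vector at the end of the
    // oc tail (oc_without_padding % oc_block).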
    if (jcp.oc_without_padding % jcp.oc_block) {
        int tail_size = (jcp.oc_without_padding % jcp.oc_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

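    // kblock_tail_mask covers the remainder of an oc block that is not a
    // multiple of simd_w.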
    if (jcp.oc_block % jcp.simd_w) {
        int block_tail_size = jcp.oc_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

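    // Emit the body for a single oc block: zero the rows of top padding,
    // copy the rows between top and bottom padding, then zero the rows of
    // bottom padding. Row counts come from reg_hc/reg_t_pad/reg_b_pad.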
    auto ocb_loop_body = [&](bool is_oc_tail) {
        Xbyak::Label kh_label, no_kh_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do

        cmp(reg_t_pad, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(kh_over, reg_t_pad);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small oc
            for_(dim_t ow = 0; ow < dst_w_block; ow++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_oc_block(is_oc_tail, ow * dst_w_offset + kw * oc_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_hc, reg_t_pad);
        L(no_kh_tover_label);

        cmp(reg_hc, reg_b_pad);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_iw_block(is_oc_tail);
            auto inp_h_offset = jcp.ow * ow_size;

            add(aux_inp_ptr, inp_h_offset);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            cmp(reg_hc, reg_b_pad);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small oc
            for_(dim_t ow = 0; ow < dst_w_block; ow++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_oc_block(is_oc_tail, ow * dst_w_offset + kw * oc_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            jnz(kh_bover_label, T_NEAR);
        }
        L(no_kh_bover_label);

        // End of this oc block: advance base pointers to the next oc block
        auto inp_cb_offset = oc_block_sz;
        auto dst_cb_offset = jcp.ohp * dst_h_offset;

        add(inp_ptr, inp_cb_offset);
        add(dst_ptr, dst_cb_offset);
    };

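    // Unrolled loop over the oc blocks handled by this call; each iteration
    // selects the full-block or oc-tail body at runtime based on reg_oc.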
    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
        Xbyak::Label oc_tail_label, ocb_continue_label;
        add(reg_oc, jcp.oc_block);
        cmp(reg_oc, jcp.oc);
        jg(oc_tail_label, T_NEAR);

        ocb_loop_body(false);
        jmp(ocb_continue_label, T_NEAR);

        L(oc_tail_label);
        ocb_loop_body(true);

        L(ocb_continue_label);
    }

    postamble();
}

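// Copy one iw block of the current row. The possible positions of the block
// relative to the source width in ow space (fully left of it, overlapping its
// left edge, fully inside, overlapping its right edge, fully right of it) are
// enumerated at generation time and dispatched at runtime on reg_iwb.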
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_iw_block(
        bool is_oc_tail) {
    // Zero the jcp.l_ovf leading positions of the destination row; the data
    // written by copy_iw_block_body() starts at offset jcp.l_ovf.
    if (jcp.l_ovf > 0) {
        for (dim_t ind_w = 0; ind_w < jcp.l_ovf; ind_w++)
            zero_oc_block(is_oc_tail, ind_w * dst_w_offset);
    }

    Xbyak::Label copy_block_done_label;

    int start_first_zero_block = -1;
    int end_first_zero_block = -1;
    int start_first_partial_block = -1;
    int end_first_partial_block = -1;
    int start_full_block = -1;
    int end_full_block = -1;
    int start_last_partial_block = -1;
    int end_last_partial_block = -1;

    int iw_block_tail = jcp.iw % jcp.iw_block;

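    // Classify the iw blocks at generation time by the ow-space range they
    // read: fully out of range on the left, overlapping the left edge,
    // fully inside [0, ow), or overlapping the right edge.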
    for (int iwb = 0; iwb < jcp.nb_iw; iwb++) {
        const auto inp_block = inp_w(jcp.iw_block);
        const auto inp_start = inp_w_start(iwb);
        const auto inp_end = inp_start + inp_block;
        if (inp_start + inp_block < 0) {
            if (start_first_zero_block == -1) start_first_zero_block = iwb;
            end_first_zero_block = iwb;
        } else if (inp_start < 0) {
            if (start_first_partial_block == -1)
                start_first_partial_block = iwb;
            end_first_partial_block = iwb;
        } else if (inp_start < jcp.ow) {
            if (inp_end <= jcp.ow) {
                if (start_full_block == -1) start_full_block = iwb;
                end_full_block = iwb;
            } else {
                if (start_last_partial_block == -1)
                    start_last_partial_block = iwb;
                end_last_partial_block = iwb;
            }
        }
    }

    if (start_first_zero_block != -1) {
        Xbyak::Label skip_first_zero_blocks;
        cmp(reg_iwb, end_first_zero_block);
        jg(skip_first_zero_blocks, T_NEAR);
        // zero block
        copy_iw_block_body(0, jcp.iw_block, 0, is_oc_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_first_zero_blocks);
    }
    if (start_first_partial_block != -1) {
        for (int b = start_first_partial_block; b <= end_first_partial_block;
                b++) {
            int cur_iw_block = (b == jcp.nb_iw - 1 && iw_block_tail > 0)
                    ? iw_block_tail
                    : jcp.iw_block;
            const auto inp_block = inp_w(cur_iw_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = -inp_start;
            const auto block_len = nstl::min(jcp.ow, inp_end);
            Xbyak::Label skip_first_partial_block;
            cmp(reg_iwb, b);
            jne(skip_first_partial_block, T_NEAR);
            copy_iw_block_body(block_lpad, jcp.iw_block, block_len, is_oc_tail);
            jmp(copy_block_done_label, T_NEAR);
            L(skip_first_partial_block);
        }
    }
    if (start_full_block != -1) {
        Xbyak::Label skip_full_blocks;
        cmp(reg_iwb, end_full_block);
        jg(skip_full_blocks, T_NEAR);
        copy_iw_block_body(0, jcp.iw_block, inp_w(jcp.iw_block), is_oc_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_full_blocks);
    }
    if (start_last_partial_block != -1) {
        for (int b = start_last_partial_block; b <= end_last_partial_block;
                b++) {
            int cur_iw_block = (b == jcp.nb_iw - 1 && iw_block_tail > 0)
                    ? iw_block_tail
                    : jcp.iw_block;
            const auto inp_block = inp_w(cur_iw_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = 0;
            const auto block_len = nstl::min(jcp.ow, inp_end) - inp_start;
            Xbyak::Label skip_last_partial_block;
            cmp(reg_iwb, b);
            jne(skip_last_partial_block, T_NEAR);
            copy_iw_block_body(block_lpad, cur_iw_block, block_len, is_oc_tail);
            jmp(copy_block_done_label, T_NEAR);

            L(skip_last_partial_block);
        }
    }

    // If none of the cases above matched, the requested iwb lies among the
    // trailing zero blocks, so emit a fully zeroed block (TODO: check whether
    // this fallback is needed and whether it can be partial).
    copy_iw_block_body(0, jcp.iw_block, 0, is_oc_tail);

    L(copy_block_done_label);
}

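// Emit the copy of a single iw block row: for each of the dst_width
// positions either copy an oc block from the source (positions inside
// [0, ow_len)) or zero it (positions covered by padding), writing to the
// destination starting at offset jcp.l_ovf.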
void jit_avx512_core_brgemm_conv_bwd_trans_kernel_t::copy_iw_block_body(
        int lpad, int iw_len, int ow_len, bool is_oc_tail) {
    const auto dst_width = inp_w(iw_len) + lpad;
    for (dim_t ind_w = 0; ind_w < dst_width; ind_w++) {
        auto ow_idx = ind_w - lpad;
        auto dst_off = (ind_w + jcp.l_ovf) * dst_w_offset;
        if (ow_idx < 0 || ow_idx >= ow_len) {
            zero_oc_block(is_oc_tail, dst_off);
        } else {
            auto inp_off = ow_idx * ow_size;
            copy_oc_block(is_oc_tail, inp_off, dst_off, true);
        }
    }
}

} // namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s