/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp"
#include "cpu/x64/jit_brgemm_conv_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace nstl;
using namespace data_type;

namespace jit_avx512_core_brgemm_conv_trans_kernel {

#define GET_OFF(field) offsetof(jit_brgemm_conv_trans_kernel_call_s, field)
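
// GET_OFF(field) yields the byte offset of `field` inside
// jit_brgemm_conv_trans_kernel_call_s, so run-time arguments can be read as
// ptr[param1 + GET_OFF(...)] in generate(). A minimal usage sketch, assuming
// the usual jit_generator create_kernel()/operator() flow (the real driver
// lives in the brgemm convolution implementation; the argument values below
// are placeholders):
//
//     jit_avx512_core_brgemm_conv_trans_kernel_t copy_kernel(jcp, "copy");
//     copy_kernel.create_kernel();
//     jit_brgemm_conv_trans_kernel_call_s p;
//     p.src = src_base;   // first input row for this block
//     p.dst = trans_buf;  // padded copy buffer
//     p.h_count = h_count; p.t_pad = t_pad; p.b_pad = b_pad;
//     p.owb = owb; p.ic = ic;
//     copy_kernel(&p);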

jit_avx512_core_brgemm_conv_trans_kernel_t::
        jit_avx512_core_brgemm_conv_trans_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp, const char *name)
    : jit_generator(name), jcp(ajcp) {
    inp_dsz = jcp.src_dsz;
    ic_block_sz = inp_dsz * jcp.ic_block;
    dst_w_block = dst_w(jcp.ow_block);
    dst_stride = jcp.copy_block_only ? dst_w_block : jcp.iwp;
    dst_w_offset = jcp.kh_sets * jcp.kw_sets * ic_block_sz;
    dst_h_offset = dst_stride * dst_w_offset;
    iw_size = inp_dsz * jcp.ngroups * jcp.ic_without_padding;
    VL = cpu_isa_traits<avx512_core>::vlen;
    n_vec = jcp.ic_block / jcp.simd_w;
    n_tail_vec = (jcp.ic_without_padding % jcp.ic_block) / jcp.simd_w;
}

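// Number of input columns needed to produce `dst_size` output columns:
//     (dst_size - 1) * stride + ext_k
// via calculate_end_padding() with zero source size and zero start padding.
// `ext_k` is already the dilation-extended kernel size, which is why the
// `dilate` argument is unused here. Example: dst_size = 14, stride = 2,
// ext_k = 3 -> (14 - 1) * 2 + 3 = 29 input columns.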
int get_inp_size(int dst_size, int ext_k, int stride, int dilate) {
    return calculate_end_padding(0, dst_size, 0, stride, ext_k);
}

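// First input column read by block `b` of `b_size` outputs; the result is
// negative while the block still starts inside the left padding.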
int get_inp_start(int b, int b_size, int stride, int pad) {
    return b * b_size * stride - pad;
}

int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w(
        int out_w, int kw) const {
    return get_inp_size(out_w, kw, jcp.stride_w, jcp.dilate_w);
}

int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w(int out_w) const {
    return inp_w(out_w, jcp.ext_kw);
}

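// Width of one row of the copy buffer for an `out_w`-wide output block.
// With kw_sets > 1 each kernel column gets its own interleaved copy (see
// copy_ow_block_body), so a row needs only one element per output point
// (kernel 1, stride 1); otherwise it needs the full strided input footprint.
// For os-blocking the width is rounded up to a multiple of stride_w,
// presumably to keep row starts stride-aligned.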
int jit_avx512_core_brgemm_conv_trans_kernel_t::dst_w(int out_w) const {
    int res = 0;
    if (jcp.kw_sets > 1)
        res = get_inp_size(out_w, 1, 1, jcp.dilate_w);
    else
        res = get_inp_size(out_w, jcp.ext_kw, jcp.stride_w, jcp.dilate_w);
    if (jcp.is_os_blocking) res = rnd_up(res, jcp.stride_w);
    return res;
}

int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w_start(int owb) const {
    return get_inp_start(owb, jcp.ow_block, jcp.stride_w, jcp.l_pad);
}

// Use vmovdqu32/16/8 according to the element size so that opmask tail
// masking operates on whole elements of the source data type.
void jit_avx512_core_brgemm_conv_trans_kernel_t::load(
        const Xbyak::Xmm &x, const Xbyak::Address &addr) {
    if (one_of(jcp.src_dt, f32, s32))
        vmovdqu32(x, addr);
    else if (one_of(jcp.src_dt, bf16, f16))
        vmovdqu16(x, addr);
    else if (one_of(jcp.src_dt, s8, u8))
        vmovdqu8(x, addr);
    else
        assert(!"Unknown type!");
}

void jit_avx512_core_brgemm_conv_trans_kernel_t::store(
        const Xbyak::Address &addr, const Xbyak::Xmm &x) {
    if (one_of(jcp.src_dt, f32, s32))
        vmovdqu32(addr, x);
    else if (one_of(jcp.src_dt, bf16, f16))
        vmovdqu16(addr, x);
    else if (one_of(jcp.src_dt, s8, u8))
        vmovdqu8(addr, x);
    else
        assert(!"Unknown type!");
}

void jit_avx512_core_brgemm_conv_trans_kernel_t::zero_ic_block(
        bool is_ic_tail, dim_t dst_off) {
    bool has_block_tail = (jcp.ic_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small ic efficiency
    auto nvec = is_ic_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++)
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_zero);
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;
    if (is_ic_tail) {
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
        else
            store(ptr[last_dst_off], zmm_zero);
    } else if (has_block_tail)
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
}

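// Copy one IC block from the source row to the copy buffer. Full simd_w-wide
// zmm moves are emitted first; the remainder uses opmask registers:
// ktail_mask covers the last partial IC block of the tensor, and
// kblock_tail_mask covers the case where ic_block itself is not a multiple
// of simd_w. Example: ic_block = 16, simd_w = 16, ic_without_padding = 20
// gives n_vec = 1 and n_tail_vec = 0, so a tail block is copied entirely by
// one ktail_mask-ed move of 20 % 16 = 4 elements.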
void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ic_block(bool is_ic_tail,
        dim_t inp_off = 0, dim_t dst_off = 0, bool do_load = true) {
    bool has_block_tail = (jcp.ic_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small ic efficiency
    auto nvec = is_ic_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++) {
        if (do_load) load(zmm_tmp, ptr[aux_inp_ptr + inp_off + iv * VL]);
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_tmp);
    }
    const auto last_inp_off = aux_inp_ptr + inp_off + nvec * VL;
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;

    if (is_ic_tail) {
        auto zmm_tmp_mask = zmm_tmp | ktail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
        else
            store(ptr[last_dst_off], zmm_tmp);
    } else if (has_block_tail) {
        auto zmm_tmp_mask = zmm_tmp | kblock_tail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
    }
}

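// Kernel structure: after loading the run-time arguments and building the
// tail masks, each IC block is handled in three phases:
//   1) kh_tover: zero-fill t_pad rows of top padding,
//   2) kh_label: copy (h_count - t_pad - b_pad) input rows via copy_ow_block,
//   3) kh_bover: zero-fill the remaining b_pad rows of bottom padding.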
void jit_avx512_core_brgemm_conv_trans_kernel_t::generate() {
    preamble();

    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_hc, ptr[param1 + GET_OFF(h_count)]);
    mov(reg_t_pad, ptr[param1 + GET_OFF(t_pad)]);
    mov(reg_b_pad, ptr[param1 + GET_OFF(b_pad)]);
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
    mov(reg_ic, ptr[param1 + GET_OFF(ic)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    if (jcp.ic_without_padding % jcp.ic_block) {
        int tail_size = (jcp.ic_without_padding % jcp.ic_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }
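    // Example: ic_without_padding = 17, ic_block = 16, simd_w = 16 gives
    // tail_size = 1 and ktail_mask = 0b1, so masked moves touch a single
    // element of the last IC vector.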

    if (jcp.ic_block % jcp.simd_w) {
        int block_tail_size = jcp.ic_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

    auto icb_loop_body = [&](bool is_ic_tail) {
        Xbyak::Label kh_label, no_kh_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do

        cmp(reg_t_pad, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(kh_over, reg_t_pad);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for_(dim_t iw = 0; iw < dst_w_block; iw++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_ic_block(is_ic_tail, iw * dst_w_offset + kw * ic_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_hc, reg_t_pad);
        L(no_kh_tover_label);

        cmp(reg_hc, reg_b_pad);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_ow_block(is_ic_tail);
            auto inp_h_offset = jcp.iw * iw_size;

            add(aux_inp_ptr, inp_h_offset);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            cmp(reg_hc, reg_b_pad);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for_(dim_t iw = 0; iw < dst_w_block; iw++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_ic_block(is_ic_tail, iw * dst_w_offset + kw * ic_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            jnz(kh_bover_label, T_NEAR);
        }
        L(no_kh_bover_label);

        // end of IC block body: advance base pointers to the next IC block
        auto inp_cb_offset = ic_block_sz;
        auto dst_cb_offset = jcp.ihp * dst_h_offset;

        add(inp_ptr, inp_cb_offset);
        add(dst_ptr, dst_cb_offset);
    };

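    // The IC-block loop is fully unrolled at generation time; reg_ic
    // accumulates the number of channels covered so far and selects the
    // tail variant of the body once it exceeds jcp.ic.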
    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
        Xbyak::Label ic_tail_label, icb_continue_label;
        add(reg_ic, jcp.ic_block);
        cmp(reg_ic, jcp.ic);
        jg(ic_tail_label, T_NEAR);

        icb_loop_body(false);
        jmp(icb_continue_label, T_NEAR);

        L(ic_tail_label);
        icb_loop_body(true);

        L(icb_continue_label);
    }

    postamble();
}

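// Runtime dispatch over the OW block index in reg_owb. At generation time
// the blocks are classified into four ranges: fully inside the left padding
// (zero-fill only), straddling the left edge (partial copy), fully inside
// the input (plain copy), and running off the right edge (partial copy).
// reg_owb is compared against the precomputed range bounds, and exactly one
// copy_ow_block_body variant executes before jumping to
// copy_block_done_label.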
void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ow_block(
        bool is_ic_tail) {
    if (jcp.nb_ow == 1) {
        copy_ow_block_body(jcp.l_pad, jcp.ow_block, jcp.iw, is_ic_tail);
        return;
    }

    Xbyak::Label copy_block_done_label;

    int start_first_zero_block = -1;
    int end_first_zero_block = -1;
    int start_first_partial_block = -1;
    int end_first_partial_block = -1;
    int start_full_block = -1;
    int end_full_block = -1;
    int start_last_partial_block = -1;
    int end_last_partial_block = -1;

    const auto adj_iw = nstl::min(jcp.iw, jcp.iwp - jcp.l_pad);

    int ow_block_tail = jcp.ow % jcp.ow_block;

    for (int owb = 0; owb < jcp.nb_ow; owb++) {
        const auto inp_block = inp_w(jcp.ow_block);
        const auto inp_start = inp_w_start(owb);
        const auto inp_end = inp_start + inp_block;
        if (inp_end < 0) {
            if (start_first_zero_block == -1) start_first_zero_block = owb;
            end_first_zero_block = owb;
        } else if (inp_start < 0) {
            if (start_first_partial_block == -1)
                start_first_partial_block = owb;
            end_first_partial_block = owb;
        } else if (inp_start < adj_iw) {
            if (inp_end <= adj_iw) {
                if (start_full_block == -1) start_full_block = owb;
                end_full_block = owb;
            } else {
                if (start_last_partial_block == -1)
                    start_last_partial_block = owb;
                end_last_partial_block = owb;
            }
        }
    }

    if (start_first_zero_block != -1) {
        Xbyak::Label skip_first_zero_blocks;
        cmp(reg_owb, end_first_zero_block);
        jg(skip_first_zero_blocks, T_NEAR);
        // zero block
        copy_ow_block_body(0, jcp.ow_block, 0, is_ic_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_first_zero_blocks);
    }
    if (start_first_partial_block != -1) {
        for (int b = start_first_partial_block; b <= end_first_partial_block;
                b++) {
            int cur_ow_block = (b == jcp.nb_ow - 1 && ow_block_tail > 0)
                    ? ow_block_tail
                    : jcp.ow_block;
            const auto inp_block = inp_w(cur_ow_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = -inp_start;
            const auto block_len = nstl::min(adj_iw, inp_end);
            Xbyak::Label skip_first_partial_block;
            cmp(reg_owb, b);
            jne(skip_first_partial_block, T_NEAR);
            copy_ow_block_body(block_lpad, jcp.ow_block, block_len, is_ic_tail);
            jmp(copy_block_done_label, T_NEAR);
            L(skip_first_partial_block);
        }
    }
    if (start_full_block != -1) {
        Xbyak::Label skip_full_blocks;
        cmp(reg_owb, end_full_block);
        jg(skip_full_blocks, T_NEAR);
        copy_ow_block_body(0, jcp.ow_block, inp_w(jcp.ow_block), is_ic_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_full_blocks);
    }
    if (start_last_partial_block != -1) {
        for (int b = start_last_partial_block; b <= end_last_partial_block;
                b++) {
            int cur_ow_block = (b == jcp.nb_ow - 1 && ow_block_tail > 0)
                    ? ow_block_tail
                    : jcp.ow_block;
            const auto inp_block = inp_w(cur_ow_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = 0;
            const auto block_len = nstl::min(adj_iw, inp_end) - inp_start;
            Xbyak::Label skip_last_partial_block;
            cmp(reg_owb, b);
            jne(skip_last_partial_block, T_NEAR);
            copy_ow_block_body(block_lpad, cur_ow_block, block_len, is_ic_tail);
            jmp(copy_block_done_label, T_NEAR);

            L(skip_last_partial_block);
        }
    }

    // If none of the above ranges matched, reg_owb points at one of the
    // trailing zero blocks.
    // TODO: check whether this fallback is needed and whether such a block
    // can be partial.
    copy_ow_block_body(0, jcp.ow_block, 0, is_ic_tail);

    L(copy_block_done_label);
}

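// Emit the copy of one OW block: destination column ind_w (replicated for
// each of the kw_sets kernel columns) reads input column
//     iw_idx = ind_w * iw_stride - lpad + kw * (dilate_w + 1),
// and columns outside [0, iw_len) are zero-filled as left/right padding.
// Example: lpad = 1, stride_w = 1, kw_sets = 1: ind_w = 0 maps to
// iw_idx = -1 and is zeroed; ind_w = 1 reads input column 0.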
void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ow_block_body(
        int lpad, int ow_len, int iw_len, bool is_ic_tail) {
    const auto dst_width = dst_w(ow_len);
    const auto iw_stride = jcp.kw_sets > 1 ? jcp.stride_w : 1;
    for_(int kw = 0; kw < jcp.kw_sets; kw++)
    for (dim_t ind_w = 0; ind_w < dst_width; ind_w++) {
        auto iw_idx = ind_w * iw_stride - lpad + kw * (jcp.dilate_w + 1);
        auto dst_off = ind_w * dst_w_offset + kw * ic_block_sz;
        if (iw_idx < 0 || iw_idx >= iw_len) {
            // left or right padding
            zero_ic_block(is_ic_tail, dst_off);
        } else {
            auto inp_off = iw_idx * iw_size;
            copy_ic_block(is_ic_tail, inp_off, dst_off, true);
        }
    }
}

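// The rtus variant flattens a strided source into a dense, unit-stride
// buffer for the brgemm kernel ("rtus" follows the "reduce to unit stride"
// naming used by oneDNN's 1x1 convolution drivers; treat that expansion as
// an assumption here). Rows are jcp.LDA elements wide since the copy may or
// may not carry zero padding.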
jit_avx512_core_brgemm_conv_rtus_kernel_t::
        jit_avx512_core_brgemm_conv_rtus_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp)
    : jit_avx512_core_brgemm_conv_trans_kernel_t(ajcp, jit_name()) {
    ic_block_sz = inp_dsz * jcp.LDA; // output may or may not be zero padded
    dst_h_offset = jcp.iwp * ic_block_sz;
}

void jit_avx512_core_brgemm_conv_rtus_kernel_t::generate() {
    preamble();

    const Xbyak::Reg64 &reg_khp = reg_hc;
    const Xbyak::Reg64 &reg_kwp = reg_owb;

    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_khp, ptr[param1 + GET_OFF(h_count)]);
    mov(reg_kwp, ptr[param1 + GET_OFF(owb)]);

    if (jcp.ic_without_padding % jcp.ic_block) {
        int tail_size = (jcp.ic_without_padding % jcp.ic_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    if (jcp.ic_block % jcp.simd_w) {
        int block_tail_size = jcp.ic_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

    assert(jcp.nb_ic_blocking == 1 && "TODO: support multi-batch case");

    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
        const bool is_ic_tail
                = (icb + 1) * jcp.ic_block > jcp.ic_without_padding;
        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        // Section 1: copy nw spatial elements in a row
        Xbyak::Label label_kwp_begin, label_kwp_end;
        cmp(reg_kwp, 0);
        jle(label_kwp_end, T_NEAR);
        L(label_kwp_begin);
        {
            copy_ic_block(is_ic_tail);

            auto inp_w_step = jcp.stride_w * iw_size;
            auto out_w_step = ic_block_sz;
            add(aux_inp_ptr, inp_w_step);
            add(aux_dst_ptr, out_w_step);

            dec(reg_kwp);
            jnz(label_kwp_begin, T_NEAR);
        }
        L(label_kwp_end);

        // Section 2: copy nh whole rows of OW spatial elements
        Xbyak::Label label_khp_begin, label_khp_end;
        cmp(reg_khp, 0);
        jle(label_khp_end, T_NEAR);
        L(label_khp_begin);
        {
            for (int ow = 0; ow < jcp.ow; ow++) {
                auto inp_w_off = ow * jcp.stride_w * iw_size;
                auto out_w_off = ow * ic_block_sz;
                copy_ic_block(is_ic_tail, inp_w_off, out_w_off);
            }

            auto inp_h_step = jcp.stride_h * jcp.iw * iw_size;
            auto out_h_step = jcp.ow * ic_block_sz;
            add(aux_inp_ptr, inp_h_step);
            add(aux_dst_ptr, out_h_step);

            dec(reg_khp);
            jnz(label_khp_begin, T_NEAR);
        }
        L(label_khp_end);
    }

    postamble();
}


} // namespace jit_avx512_core_brgemm_conv_trans_kernel

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s