1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp" |
18 | #include "cpu/x64/jit_brgemm_conv_utils.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | |
25 | using namespace dnnl::impl::utils; |
26 | using namespace nstl; |
27 | using namespace data_type; |
28 | |
29 | namespace jit_avx512_core_brgemm_conv_trans_kernel { |
30 | |
31 | #define GET_OFF(field) offsetof(jit_brgemm_conv_trans_kernel_call_s, field) |
32 | |
jit_avx512_core_brgemm_conv_trans_kernel_t::
        jit_avx512_core_brgemm_conv_trans_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp, const char *name)
    : jit_generator(name), jcp(ajcp) {
    // Input element size in bytes and the byte size of one ic block.
    inp_dsz = jcp.src_dsz;
    ic_block_sz = inp_dsz * jcp.ic_block;
    // Destination width (in spatial points) of one transformed ow block.
    dst_w_block = dst_w(jcp.ow_block);
    // When only the current block is copied the destination row is exactly
    // the block width; otherwise the full padded input width is the stride.
    dst_stride = jcp.copy_block_only ? dst_w_block : jcp.iwp;
    // Byte offsets between adjacent destination w positions and h rows.
    dst_w_offset = jcp.kh_sets * jcp.kw_sets * ic_block_sz;
    dst_h_offset = dst_stride * dst_w_offset;
    // Bytes per input spatial point: all groups, unpadded channels.
    iw_size = inp_dsz * jcp.ngroups * jcp.ic_without_padding;
    VL = cpu_isa_traits<avx512_core>::vlen;
    // Number of full simd vectors in a full ic block / in the ic tail block.
    n_vec = jcp.ic_block / jcp.simd_w;
    n_tail_vec = (jcp.ic_without_padding % jcp.ic_block) / jcp.simd_w;
}
48 | |
49 | int get_inp_size(int dst_size, int ext_k, int stride, int dilate) { |
50 | const auto res = calculate_end_padding(0, dst_size, 0, stride, ext_k); |
51 | return res; |
52 | } |
53 | |
// First input coordinate of block 'b': the block's first output point
// (b * b_size) mapped through the stride, shifted left by the padding.
// May be negative when the block starts inside the left padding.
int get_inp_start(int b, int b_size, int stride, int pad) {
    const int first_out = b * b_size;
    return first_out * stride - pad;
}
57 | |
// Input width needed to compute 'out_w' output points with kernel width 'kw'
// at the convolution's stride/dilation.
int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w(int out_w, int kw) const {
    return get_inp_size(out_w, kw, jcp.stride_w, jcp.dilate_w);
}
61 | |
// Input width needed for 'out_w' outputs with the full (dilation-extended)
// kernel width jcp.ext_kw.
int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w(int out_w) const {
    return inp_w(out_w, jcp.ext_kw);
}
65 | |
66 | int jit_avx512_core_brgemm_conv_trans_kernel_t::dst_w(int out_w) const { |
67 | int res = 0; |
68 | if (jcp.kw_sets > 1) |
69 | res = get_inp_size(out_w, 1, 1, jcp.dilate_w); |
70 | else |
71 | res = get_inp_size(out_w, jcp.ext_kw, jcp.stride_w, jcp.dilate_w); |
72 | if (jcp.is_os_blocking) res = rnd_up(res, jcp.stride_w); |
73 | return res; |
74 | } |
75 | |
// First input column covered by the owb-th output-width block; negative when
// the block starts inside the left padding.
int jit_avx512_core_brgemm_conv_trans_kernel_t::inp_w_start(int owb) const {
    return get_inp_start(owb, jcp.ow_block, jcp.stride_w, jcp.l_pad);
}
79 | |
80 | // use different vmovdqu32/16/8 due to case when tail mask used |
81 | void jit_avx512_core_brgemm_conv_trans_kernel_t::load( |
82 | const Xbyak::Xmm &x, const Xbyak::Address &addr) { |
83 | if (one_of(jcp.src_dt, f32, s32)) |
84 | vmovdqu32(x, addr); |
85 | else if (one_of(jcp.src_dt, bf16, f16)) |
86 | vmovdqu16(x, addr); |
87 | else if (one_of(jcp.src_dt, s8, u8)) |
88 | vmovdqu8(x, addr); |
89 | else |
90 | assert(!"Unknown type!" ); |
91 | } |
92 | |
93 | void jit_avx512_core_brgemm_conv_trans_kernel_t::store( |
94 | const Xbyak::Address &addr, const Xbyak::Xmm &x) { |
95 | if (one_of(jcp.src_dt, f32, s32)) |
96 | vmovdqu32(addr, x); |
97 | else if (one_of(jcp.src_dt, bf16, f16)) |
98 | vmovdqu16(addr, x); |
99 | else if (one_of(jcp.src_dt, s8, u8)) |
100 | vmovdqu8(addr, x); |
101 | else |
102 | assert(!"Unknown type!" ); |
103 | } |
104 | |
// Emits stores of zmm_zero over one ic block at dst_off (used to fill padded
// spatial positions). 'is_ic_tail' selects the shorter vector count of the
// channel-tail block; zero-filling the padded channels past the tail is safe
// since zeros are being written anyway.
void jit_avx512_core_brgemm_conv_trans_kernel_t::zero_ic_block(
        bool is_ic_tail, dim_t dst_off) {
    // True when ic_block is not a multiple of simd_w, i.e. the last vector of
    // a block must be written under kblock_tail_mask.
    bool has_block_tail = (jcp.ic_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small ic efficiency
    auto nvec = is_ic_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++)
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_zero);
    // Address of the (possibly partial) last vector of the block.
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;
    if (is_ic_tail) {
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
        else
            store(ptr[last_dst_off], zmm_zero);
    } else if (has_block_tail)
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_zero);
}
122 | |
// Emits a copy of one ic block from aux_inp_ptr + inp_off to
// aux_dst_ptr + dst_off. For the channel tail the load is done under
// ktail_mask with T_z so lanes past the valid channels are zeroed before
// being stored. 'do_load' set to false reuses zmm_tmp from a previous call
// instead of reloading from memory.
void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ic_block(bool is_ic_tail,
        dim_t inp_off = 0, dim_t dst_off = 0, bool do_load = true) {
    // ic_block not a multiple of simd_w -> last store needs kblock_tail_mask.
    bool has_block_tail = (jcp.ic_block % jcp.simd_w);

    // TODO: use Xmm or Ymm moves for better small ic efficiency
    auto nvec = is_ic_tail ? n_tail_vec : n_vec;
    for (int iv = 0; iv < nvec; iv++) {
        if (do_load) load(zmm_tmp, ptr[aux_inp_ptr + inp_off + iv * VL]);
        store(ptr[aux_dst_ptr + dst_off + iv * VL], zmm_tmp);
    }
    // Addresses of the (possibly partial) last vector on both sides.
    const auto last_inp_off = aux_inp_ptr + inp_off + nvec * VL;
    const auto last_dst_off = aux_dst_ptr + dst_off + nvec * VL;

    if (is_ic_tail) {
        // Load only the valid tail lanes; T_z zeroes the rest of zmm_tmp.
        auto zmm_tmp_mask = zmm_tmp | ktail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        if (has_block_tail)
            store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
        else
            store(ptr[last_dst_off], zmm_tmp);
    } else if (has_block_tail) {
        auto zmm_tmp_mask = zmm_tmp | kblock_tail_mask | T_z;
        if (do_load) load(zmm_tmp_mask, ptr[last_inp_off]);
        store(ptr[last_dst_off] | kblock_tail_mask | T_z, zmm_tmp);
    }
}
149 | |
// Emits the full transform kernel: for each ic block, zero-fill the top
// padding rows, copy the real input rows (one ow block per row), then
// zero-fill the bottom padding rows.
void jit_avx512_core_brgemm_conv_trans_kernel_t::generate() {
    preamble();

    // Load kernel arguments from the call structure.
    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_hc, ptr[param1 + GET_OFF(h_count)]); // rows to process
    mov(reg_t_pad, ptr[param1 + GET_OFF(t_pad)]); // top padding rows
    mov(reg_b_pad, ptr[param1 + GET_OFF(b_pad)]); // bottom padding rows
    mov(reg_owb, ptr[param1 + GET_OFF(owb)]); // current ow block index
    mov(reg_ic, ptr[param1 + GET_OFF(ic)]); // starting input channel

    vpxord(zmm_zero, zmm_zero, zmm_zero);

    // ktail_mask: covers the simd tail of the ic *tail* block.
    if (jcp.ic_without_padding % jcp.ic_block) {
        int tail_size = (jcp.ic_without_padding % jcp.ic_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    // kblock_tail_mask: covers the simd tail of a *full* ic block.
    if (jcp.ic_block % jcp.simd_w) {
        int block_tail_size = jcp.ic_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

    // Emits the processing of one ic block (full or tail variant).
    // NOTE(review): reg_hc is consumed here (ends at 0) and is not reloaded
    // between icb iterations below — presumably callers rely on
    // nb_ic_blocking == 1 or on later blocks doing no row work; confirm.
    auto icb_loop_body = [&](bool is_ic_tail) {
        Xbyak::Label kh_label, no_kh_label;
        Xbyak::Label kh_tover_label, kh_bover_label;
        Xbyak::Label no_kh_tover_label, no_kh_bover_label;

        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR); // nothing to do

        // Zero-fill reg_t_pad rows of top padding.
        cmp(reg_t_pad, 0);
        jle(no_kh_tover_label, T_NEAR);

        mov(kh_over, reg_t_pad);
        L(kh_tover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for_(dim_t iw = 0; iw < dst_w_block; iw++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_ic_block(is_ic_tail, iw * dst_w_offset + kw * ic_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(kh_over);
            jnz(kh_tover_label, T_NEAR);
        }
        sub(reg_hc, reg_t_pad);
        L(no_kh_tover_label);

        // Copy real input rows until only bottom-padding rows remain.
        cmp(reg_hc, reg_b_pad);
        jle(no_kh_label, T_NEAR);

        L(kh_label);
        {
            copy_ow_block(is_ic_tail);
            auto inp_h_offset = jcp.iw * iw_size; // bytes per input row

            add(aux_inp_ptr, inp_h_offset);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            cmp(reg_hc, reg_b_pad);
            jg(kh_label, T_NEAR);
        }
        L(no_kh_label);

        // Zero-fill the remaining (bottom padding) rows, if any.
        cmp(reg_hc, 0);
        jle(no_kh_bover_label, T_NEAR);

        L(kh_bover_label);
        {
            // TODO: adjust step to improve zeroing efficiency for small ic
            for_(dim_t iw = 0; iw < dst_w_block; iw++)
            for (int kw = 0; kw < jcp.kw_sets; kw++)
                zero_ic_block(is_ic_tail, iw * dst_w_offset + kw * ic_block_sz);
            add(aux_dst_ptr, dst_h_offset);

            dec(reg_hc);
            jnz(kh_bover_label, T_NEAR);
        }
        L(no_kh_bover_label);

        // End IC Loop
        // Advance base pointers to the next ic block.
        auto inp_cb_offset = ic_block_sz;
        auto dst_cb_offset = jcp.ihp * dst_h_offset;

        add(inp_ptr, inp_cb_offset);
        add(dst_ptr, dst_cb_offset);
    };

    // Unrolled ic-block loop: at runtime each block is dispatched to the
    // full or the tail variant depending on the remaining channel count.
    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
        Xbyak::Label ic_tail_label, icb_continue_label;
        add(reg_ic, jcp.ic_block);
        cmp(reg_ic, jcp.ic);
        jg(ic_tail_label, T_NEAR);

        icb_loop_body(false);
        jmp(icb_continue_label, T_NEAR);

        L(ic_tail_label);
        icb_loop_body(true);

        L(icb_continue_label);
    }

    postamble();
}
264 | |
// Emits the width-direction copy for one row. At generation time each ow
// block index is classified by its input span [inp_start, inp_end) against
// the valid input range [0, adj_iw):
//   1) entirely inside the left padding  -> pure zero-fill,
//   2) partially inside the left padding -> copy with left zero-fill,
//   3) fully inside the input            -> plain copy,
//   4) crossing the right border         -> copy with right zero-fill,
//   5) beyond the input (trailing)       -> pure zero-fill (fall-through).
// The runtime block index in reg_owb selects which emitted variant runs.
void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ow_block(
        bool is_ic_tail) {
    // Single ow block: the whole width is handled with the real l_pad.
    if (jcp.nb_ow == 1) {
        copy_ow_block_body(jcp.l_pad, jcp.ow_block, jcp.iw, is_ic_tail);
        return;
    }

    Xbyak::Label copy_block_done_label;

    // Category boundaries (block indices), -1 == category not present.
    int start_first_zero_block = -1;
    int end_first_zero_block = -1;
    int start_first_partial_block = -1;
    int end_first_partial_block = -1;
    int start_full_block = -1;
    int end_full_block = -1;
    int start_last_partial_block = -1;
    int end_last_partial_block = -1;

    // Effective input width visible through the padded buffer.
    const auto adj_iw = nstl::min(jcp.iw, jcp.iwp - jcp.l_pad);

    int ow_block_tail = jcp.ow % jcp.ow_block;

    // Classify every ow block into one of the categories above.
    for (int owb = 0; owb < jcp.nb_ow; owb++) {
        const auto inp_block = inp_w(jcp.ow_block);
        const auto inp_start = inp_w_start(owb);
        const auto inp_end = inp_start + inp_block;
        if (inp_start + inp_block < 0) {
            if (start_first_zero_block == -1) start_first_zero_block = owb;
            end_first_zero_block = owb;
        } else if (inp_start < 0) {
            if (start_first_partial_block == -1)
                start_first_partial_block = owb;
            end_first_partial_block = owb;
        } else if (inp_start < adj_iw) {
            if (inp_end <= adj_iw) {
                if (start_full_block == -1) start_full_block = owb;
                end_full_block = owb;
            } else {
                if (start_last_partial_block == -1)
                    start_last_partial_block = owb;
                end_last_partial_block = owb;
            }
        }
    }

    if (start_first_zero_block != -1) {
        Xbyak::Label skip_first_zero_blocks;
        cmp(reg_owb, end_first_zero_block);
        jg(skip_first_zero_blocks, T_NEAR);
        // zero block
        copy_ow_block_body(0, jcp.ow_block, 0, is_ic_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_first_zero_blocks);
    }
    if (start_first_partial_block != -1) {
        // One emitted variant per left-partial block (the padding amount
        // differs per block index).
        for (int b = start_first_partial_block; b <= end_first_partial_block;
                b++) {
            int cur_ow_block = (b == jcp.nb_ow - 1 && ow_block_tail > 0)
                    ? ow_block_tail
                    : jcp.ow_block;
            const auto inp_block = inp_w(cur_ow_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = -inp_start;
            const auto block_len = nstl::min(adj_iw, inp_end);
            Xbyak::Label skip_first_partial_block;
            cmp(reg_owb, b);
            jne(skip_first_partial_block, T_NEAR);
            // NOTE(review): unlike the last-partial case below, ow_len is
            // jcp.ow_block here rather than cur_ow_block — presumably a
            // left-partial block can only be the tail block when nb_ow == 1
            // (handled by the early return above); confirm.
            copy_ow_block_body(block_lpad, jcp.ow_block, block_len, is_ic_tail);
            jmp(copy_block_done_label, T_NEAR);
            L(skip_first_partial_block);
        }
    }
    if (start_full_block != -1) {
        // All fully-interior blocks share one emitted body (no padding).
        Xbyak::Label skip_full_blocks;
        cmp(reg_owb, end_full_block);
        jg(skip_full_blocks, T_NEAR);
        copy_ow_block_body(0, jcp.ow_block, inp_w(jcp.ow_block), is_ic_tail);
        jmp(copy_block_done_label, T_NEAR);

        L(skip_full_blocks);
    }
    if (start_last_partial_block != -1) {
        // One emitted variant per right-partial block.
        for (int b = start_last_partial_block; b <= end_last_partial_block;
                b++) {
            int cur_ow_block = (b == jcp.nb_ow - 1 && ow_block_tail > 0)
                    ? ow_block_tail
                    : jcp.ow_block;
            const auto inp_block = inp_w(cur_ow_block);
            const auto inp_start = inp_w_start(b);
            const auto inp_end = inp_start + inp_block;
            const auto block_lpad = 0;
            const auto block_len = nstl::min(adj_iw, inp_end) - inp_start;
            Xbyak::Label skip_last_partial_block;
            cmp(reg_owb, b);
            jne(skip_last_partial_block, T_NEAR);
            copy_ow_block_body(block_lpad, cur_ow_block, block_len, is_ic_tail);
            jmp(copy_block_done_label, T_NEAR);

            L(skip_last_partial_block);
        }
    }

    // if not any above case then owb is among last zero blocks
    // check is this needed and check may it be partial
    copy_ow_block_body(0, jcp.ow_block, 0, is_ic_tail);

    L(copy_block_done_label);
}
375 | |
376 | void jit_avx512_core_brgemm_conv_trans_kernel_t::copy_ow_block_body( |
377 | int lpad, int ow_len, int iw_len, bool is_ic_tail) { |
378 | const auto dst_width = dst_w(ow_len); |
379 | const auto iw_stride = jcp.kw_sets > 1 ? jcp.stride_w : 1; |
380 | for_(int kw = 0; kw < jcp.kw_sets; kw++) |
381 | for (dim_t ind_w = 0; ind_w < dst_width; ind_w++) { |
382 | auto iw_idx = ind_w * iw_stride - lpad + kw * (jcp.dilate_w + 1); |
383 | auto dst_off = ind_w * dst_w_offset + kw * ic_block_sz; |
384 | if (iw_idx < 0 || iw_idx >= iw_len) { |
385 | // left or right padding |
386 | zero_ic_block(is_ic_tail, dst_off); |
387 | } else { |
388 | auto inp_off = iw_idx * iw_size; |
389 | copy_ic_block(is_ic_tail, inp_off, dst_off, true); |
390 | } |
391 | } |
392 | } |
393 | |
jit_avx512_core_brgemm_conv_rtus_kernel_t::
        jit_avx512_core_brgemm_conv_rtus_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp)
    : jit_avx512_core_brgemm_conv_trans_kernel_t(ajcp, jit_name()) {
    // The rtus copy writes LDA-strided rows, so override the block and row
    // byte sizes computed by the base trans-kernel constructor.
    ic_block_sz = inp_dsz * jcp.LDA; // output may or may not be zero padded
    dst_h_offset = jcp.iwp * ic_block_sz;
}
401 | |
// Emits the rtus (reduce-to-unit-stride) copy kernel: a leading run of
// single spatial elements followed by whole ow rows, gathering the strided
// input into a dense destination.
void jit_avx512_core_brgemm_conv_rtus_kernel_t::generate() {
    preamble();

    // Reuse the base-class argument registers under rtus-specific names:
    // h_count carries the whole-row count, owb the leading element count.
    const Xbyak::Reg64 &reg_khp = reg_hc;
    const Xbyak::Reg64 &reg_kwp = reg_owb;

    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(dst_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(reg_khp, ptr[param1 + GET_OFF(h_count)]);
    mov(reg_kwp, ptr[param1 + GET_OFF(owb)]);

    // Mask for the simd tail of the ic tail block.
    if (jcp.ic_without_padding % jcp.ic_block) {
        int tail_size = (jcp.ic_without_padding % jcp.ic_block) % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(ktail_mask, reg_tmp);
    }

    // Mask for the simd tail of a full ic block.
    if (jcp.ic_block % jcp.simd_w) {
        int block_tail_size = jcp.ic_block % jcp.simd_w;
        uint64_t mask = (UINT64_C(1) << block_tail_size) - 1;
        mov(reg_tmp, mask);
        kmovq(kblock_tail_mask, reg_tmp);
    }

    assert(jcp.nb_ic_blocking == 1 && "TODO: support multi-batch case" );

    for (int icb = 0; icb < jcp.nb_ic_blocking; icb++) {
        // Tail variant when this block extends past the real channel count.
        const bool is_ic_tail
                = (icb + 1) * jcp.ic_block > jcp.ic_without_padding;
        mov(aux_inp_ptr, inp_ptr);
        mov(aux_dst_ptr, dst_ptr);

        // Section 1: copy nw spatial elements in a row
        Xbyak::Label label_kwp_begin, label_kwp_end;
        cmp(reg_kwp, 0);
        jle(label_kwp_end, T_NEAR);
        L(label_kwp_begin);
        {
            copy_ic_block(is_ic_tail);

            // Input advances by stride_w points; output is densely packed.
            auto inp_w_step = jcp.stride_w * iw_size;
            auto out_w_step = ic_block_sz;
            add(aux_inp_ptr, inp_w_step);
            add(aux_dst_ptr, out_w_step);

            dec(reg_kwp);
            jnz(label_kwp_begin, T_NEAR);
        }
        L(label_kwp_end);

        // Section 2: copy nh whole rows of OW spatial elements
        Xbyak::Label label_khp_begin, label_khp_end;
        cmp(reg_khp, 0);
        jle(label_khp_end, T_NEAR);
        L(label_khp_begin);
        {
            // One full row: ow strided gathers, unrolled at generation time.
            for (int ow = 0; ow < jcp.ow; ow++) {
                auto inp_w_off = ow * jcp.stride_w * iw_size;
                auto out_w_off = ow * ic_block_sz;
                copy_ic_block(is_ic_tail, inp_w_off, out_w_off);
            }

            // Advance to the next (stride_h-strided) input row.
            auto inp_h_step = jcp.stride_h * jcp.iw * iw_size;
            auto out_h_step = jcp.ow * ic_block_sz;
            add(aux_inp_ptr, inp_h_step);
            add(aux_dst_ptr, out_h_step);

            dec(reg_khp);
            jnz(label_khp_begin, T_NEAR);
        }
        L(label_khp_end);
    }

    postamble();
}
478 | |
479 | } // namespace jit_avx512_core_brgemm_conv_trans_kernel |
480 | |
481 | } // namespace x64 |
482 | } // namespace cpu |
483 | } // namespace impl |
484 | } // namespace dnnl |
485 | |
486 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
487 | |