/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
using namespace dnnl::impl::utils;

jit_avx512_dw_conv_fwd_kernel_bf16::jit_avx512_dw_conv_fwd_kernel_bf16(
        const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md)
    : jit_generator(jit_name()), jcp(ajcp) {
    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr size_t helper_vmm_idx = 31;
        static constexpr bool use_exact_tail_scalar_bcast = true;
        const size_t tail_size = jcp.oc_without_padding
                % (cpu_isa_traits<avx512_core>::vlen / sizeof(float));

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r12, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, k_oc_tail_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
    if (!isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
}

int jit_avx512_dw_conv_fwd_kernel_bf16::get_acc_reg_idx(int idx) const {
    assert(idx + acc_idx_start <= get_max_regs());
    return idx + acc_idx_start;
}

Xbyak::Zmm jit_avx512_dw_conv_fwd_kernel_bf16::get_acc_reg(int idx) {
    return Xbyak::Zmm(get_acc_reg_idx(idx));
}

void jit_avx512_dw_conv_fwd_kernel_bf16::load_src(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {

    const auto dst_layout_nxc = is_dst_layout_nxc();
    const auto ch_blk = jcp.ch_block;
    const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
    const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;

    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
        for (int ow = 0; ow < ur_w; ow++) {
            Zmm zmm_acc = get_acc_reg(ch * ur_w + ow);
            const Zmm zmm_acc_msk
                    = mask_flag ? zmm_acc | ktail_mask | T_z : zmm_acc;

            if (this->jcp.with_bias) {
                int b_off = ch * ch_blk;
                uni_vmovups(
                        zmm_acc_msk, vmmword[reg_bias + b_off * sizeof(float)]);
            } else {
                uni_vpxor(zmm_acc, zmm_acc, zmm_acc);
            }
            if (this->jcp.with_sum) {
                int o_off = ch * ocb_stride + ow * ow_stride;
                if (jcp.dst_dt == data_type::bf16) {
                    const Zmm zmm_prev_dst_msk = mask_flag
                            ? zmm_prev_dst | ktail_mask | T_z
                            : zmm_prev_dst;
                    vpmovzxwd(zmm_prev_dst_msk,
                            vmmword[reg_output + o_off * jcp.typesize_out]);
                    vpslld(zmm_prev_dst, zmm_prev_dst, 16);
                    vaddps(zmm_acc, zmm_prev_dst);
                } else {
                    uni_vaddps(zmm_acc_msk, zmm_acc_msk,
                            vmmword[reg_output + o_off * jcp.typesize_out]);
                }
            }
        }
    }
}
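
/* Editorial note (reference sketch, kept as a comment so the kernel itself is
 * unchanged): the bf16 branch of the sum post-op above rebuilds f32 values by
 * zero-extending each 16-bit element (vpmovzxwd) and shifting it into the
 * upper half of the 32-bit lane (vpslld by 16), since bf16 is exactly the
 * upper 16 bits of an IEEE f32. A scalar equivalent, with bf16_bits_to_f32 as
 * a hypothetical helper:
 *
 *     static inline float bf16_bits_to_f32(uint16_t bits) {
 *         uint32_t u = static_cast<uint32_t>(bits) << 16; // zero-extend+shift
 *         float f;
 *         std::memcpy(&f, &u, sizeof(f)); // needs <cstring>
 *         return f;
 *     }
 */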

void jit_avx512_dw_conv_fwd_kernel_bf16::apply_filter_unrolled(int ur_ch_blocks,
        int ur_w, int pad_l, int pad_r, bool last_ch_block_flag) {
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    const auto src_layout_nxc = is_src_layout_nxc();
    const auto iw_stride = src_layout_nxc ? jcp.ngroups : ch_blk;
    const auto ih_stride = jcp.iw * iw_stride;
    const auto icb_stride = src_layout_nxc
            ? ch_blk
            : (jcp.is_fused_conv ? 1 : jcp.ih) * jcp.iw * ch_blk;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label);
    {
        if (jcp.is_fused_conv) {
            mov(aux_reg_input, ptr[aux_reg_input_buffer_ptr]);
            add(aux_reg_input, reg_iw_offset);
        }
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
            for (int kw = 0; kw < jcp.kw; kw++) {
                int ker_off = ch * jcp.kh * jcp.kw * ch_blk + kw * ch_blk;
                const Zmm zmm_ker_reg_msk = mask_flag
                        ? zmm_ker_reg | ktail_mask | T_z
                        : zmm_ker_reg;
                vpmovzxwd(zmm_ker_reg_msk,
                        ptr[aux_reg_kernel + ker_off * jcp.typesize_in]);
                int ow_start = get_ow_start(kw, pad_l);
                int ow_end = get_ow_end(ur_w, kw, pad_r);
                for (int ow = ow_start; ow < ow_end; ow++) {
                    const Zmm zmm_src_reg_msk = mask_flag
                            ? zmm_src_reg | ktail_mask | T_z
                            : zmm_src_reg;
                    Zmm zmm_acc = get_acc_reg(ch * ur_w + ow);
                    int inp_off = ch * icb_stride
                            + (ow * stride_w - pad_l) * iw_stride
                            + kw * dilate_w * iw_stride;
                    /* zero-extend bf16 to packed 32-bit int */
                    vpmovzxwd(zmm_src_reg_msk,
                            ptr[aux_reg_input + inp_off * jcp.typesize_in]);
                    if (isa_has_bf16(jcp.isa))
                        vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg);
                    else
                        bf16_emu_->vdpbf16ps(zmm_acc, zmm_ker_reg, zmm_src_reg);
                }
            }
        }

        add(aux_reg_kernel, jcp.kw * ch_blk * jcp.typesize_in);
        if (jcp.is_fused_conv) {
            // Move to next row pointer in the buffer
            add(aux_reg_input_buffer_ptr, sizeof(void *));
        } else {
            add(aux_reg_input, ih_stride * dilate_h * jcp.typesize_in);
        }

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
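
/* Editorial note (reference sketch, comment only): for each channel block `ch`
 * and output column `ow`, the register blocking above accumulates the plain
 * depthwise convolution below, where acc/src/ker stand for the values held in
 * zmm_acc, zmm_src_reg and zmm_ker_reg:
 *
 *     for (int kh = 0; kh < kh_padding; ++kh)      // reg_kh iterations
 *         for (int kw = 0; kw < jcp.kw; ++kw)
 *             for (int ow = ow_start; ow < ow_end; ++ow) {
 *                 const int iw = ow * stride_w - pad_l + kw * dilate_w;
 *                 acc[ch][ow] += src[ch][kh][iw] * ker[ch][kh][kw];
 *             }
 *
 * vdpbf16ps multiplies the two bf16 halves of each 32-bit lane and accumulates
 * into f32; because the operands were loaded with vpmovzxwd, the upper bf16 of
 * every pair is zero, so each lane contributes exactly one product.
 */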

template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w,
        const bool mask_tail, const F &f) {
    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        const bool mask_flag = mask_tail && ch + 1 == ur_ch_blocks;
        for (int ow = 0; ow < ur_w; ow++)
            f(ch, ow, mask_flag);
    }
}
template <typename F>
static void iterate(const int ur_ch_blocks, const int ur_w, const F &f) {
    iterate(ur_ch_blocks, ur_w, false, f);
}

void jit_avx512_dw_conv_fwd_kernel_bf16::apply_postops(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {
    if (this->jcp.with_eltwise || this->jcp.with_binary) {

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto mask_tail = jcp.oc_without_padding % jcp.ch_block;
            const auto dst_layout_nxc = is_dst_layout_nxc();
            const auto ch_blk = jcp.ch_block;
            const auto ocb_stride
                    = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
            const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;
            const bool mask_tail_blocked_layout
                    = jcp.oc_without_padding % jcp.ch_block && !dst_layout_nxc;
            iterate(ur_ch_blocks, ur_w, mask_tail,
                    [&](int ch, int ow, int mask_flag) {
                        const size_t aux_output_l_off = jcp.typesize_out
                                * (ch * ocb_stride + ow * ow_stride);
                        const auto vmm_idx = get_acc_reg_idx(ch * ur_w + ow);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, reg_output);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            Label postops_done;
            if (mask_tail_blocked_layout) {
                Label postops_no_tail;
                mov(reg_tmp, ptr[param1 + GET_OFF(load_work)]);
                cmp(reg_tmp, jcp.nb_ch_blocking * jcp.ch_block);
                jge(postops_no_tail, T_NEAR);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params);
            } else if (last_ch_block_flag)
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
            else /* if (!last_ch_block_flag) */
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            iterate(ur_ch_blocks, ur_w, [&](int ch, int ow, int) {
                vmm_idxs.emplace(get_acc_reg_idx(ch * ur_w + ow));
            });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}

void jit_avx512_dw_conv_fwd_kernel_bf16::store_dst(
        int ur_ch_blocks, int ur_w, bool last_ch_block_flag) {

    const auto dst_layout_nxc = is_dst_layout_nxc();
    const auto ch_blk = jcp.ch_block;
    const auto ocb_stride = dst_layout_nxc ? ch_blk : jcp.oh * jcp.ow * ch_blk;
    const auto ow_stride = dst_layout_nxc ? jcp.ngroups : ch_blk;

    if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_->init_vcvtneps2bf16();

    if (dst_layout_nxc && jcp.dst_dt == data_type::bf16
            && isa_has_bf16(jcp.isa)) {
        for (int j = 0; j < ur_w; ++j) {
            int n_2bf2ps = (ur_ch_blocks / 2) * 2;
            int ch = 0;
            for (; ch < n_2bf2ps; ch += 2) {
                size_t aux_output_offset
                        = (size_t)ch * ocb_stride + j * ow_stride;
                auto addr = ptr[reg_output
                        + aux_output_offset * jcp.typesize_out];
                auto zmm_dst = get_acc_reg(ch * ur_w + j);
                vcvtne2ps2bf16(
                        zmm_dst, get_acc_reg((ch + 1) * ur_w + j), zmm_dst);
                bool mask_flag = last_ch_block_flag && ch + 2 == ur_ch_blocks;
                Zmm zmm_dst_msk = mask_flag ? zmm_dst | k_ch_tail_mask_extended
                                            : zmm_dst;
                vmovdqu16(addr, zmm_dst_msk);
            }
            /* Perform tail write for odd ch sizes */
            if (ch < ur_ch_blocks) {
                size_t aux_output_offset
                        = (size_t)ch * ocb_stride + j * ow_stride;
                auto addr = ptr[reg_output
                        + aux_output_offset * jcp.typesize_out];
                auto zmm_dst = get_acc_reg(ch * ur_w + j);
                auto ymm_dst = Ymm(zmm_dst.getIdx());
                vcvtneps2bf16(ymm_dst, zmm_dst);
                Ymm ymm_dst_msk
                        = last_ch_block_flag ? ymm_dst | ktail_mask : ymm_dst;
                vmovdqu16(addr, ymm_dst_msk);
            }
        }
    } else {
        // also used for the case when dst_layout_nxc && dst.dt == f32
        if (jcp.dst_dt == data_type::f32) {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
                for (int ow = 0; ow < ur_w; ow++) {
                    int o_off = ch * ocb_stride + ow * ow_stride;
                    Zmm zmm_dst = get_acc_reg(ch * ur_w + ow);
                    Zmm zmm_dst_msk
                            = mask_flag ? zmm_dst | ktail_mask : zmm_dst;
                    vmovups(vmmword[reg_output + o_off * jcp.typesize_out],
                            zmm_dst_msk);
                }
            }
        } else if (jcp.dst_dt == data_type::bf16) {
            if (isa_has_bf16(jcp.isa)) { // !dst_layout_nxc()
                assert(jcp.ngroups % jcp.ch_block == 0);
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    int n_2bf2ps = (ur_w / 2) * 2;
                    int j = 0;
                    for (; j < n_2bf2ps; j += 2) {
                        size_t aux_output_offset
                                = (size_t)ch * ocb_stride + j * ow_stride;
                        auto addr = ptr[reg_output
                                + aux_output_offset * jcp.typesize_out];
                        auto zmm_dst = get_acc_reg(ch * ur_w + j);
                        vcvtne2ps2bf16(zmm_dst, get_acc_reg(ch * ur_w + j + 1),
                                get_acc_reg(ch * ur_w + j));
                        vmovups(addr, zmm_dst);
                    }
                    /* Perform tail write for odd ur_w sizes */
                    if (j < ur_w) {
                        size_t aux_output_offset
                                = (size_t)ch * ocb_stride + j * ow_stride;
                        auto addr = ptr[reg_output
                                + aux_output_offset * jcp.typesize_out];
                        auto zmm_dst = get_acc_reg(ch * ur_w + j);
                        auto ymm_dst = Ymm(zmm_dst.getIdx());
                        vcvtneps2bf16(ymm_dst, zmm_dst);
                        vmovups(addr, ymm_dst);
                    }
                }
            } else {
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    bool mask_flag
                            = last_ch_block_flag && ch == ur_ch_blocks - 1;
                    for (int ow = 0; ow < ur_w; ow++) {
                        int o_off = ch * ocb_stride + ow * ow_stride;
                        Zmm zmm_dst = get_acc_reg(ch * ur_w + ow);

                        /* down-convert f32 output to bf16 */
                        auto ymm_dst = Ymm(zmm_dst.getIdx());
                        bf16_emu_->vcvtneps2bf16(ymm_dst, zmm_dst);

                        Ymm ymm_dst_msk
                                = mask_flag ? ymm_dst | ktail_mask : ymm_dst;
                        vmovdqu16(ptr[reg_output + o_off * jcp.typesize_out],
                                ymm_dst_msk);
                    }
                }
            }
        } else
            assert(!"unsupported destination type");
    }
}
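
/* Editorial note (reference sketch, comment only): in the nxc bf16 path above,
 * two f32 accumulators from consecutive channel blocks are packed into one zmm
 * of 32 bf16 values with vcvtne2ps2bf16 and written with a single (optionally
 * masked) vmovdqu16; an odd trailing block falls back to vcvtneps2bf16 into a
 * ymm. Conceptually, with f32_to_bf16 as a hypothetical round-to-nearest-even
 * helper:
 *
 *     for (int i = 0; i < 16; ++i) {
 *         out_bf16[i] = f32_to_bf16(acc_ch[i]);           // lower half: ch
 *         out_bf16[16 + i] = f32_to_bf16(acc_ch_next[i]); // upper half: ch + 1
 *     }
 */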

void jit_avx512_dw_conv_fwd_kernel_bf16::compute_loop(
        int ur_w, int ur_ch_blocks, int pad_l, int pad_r) {

    // The ch_loop currently happens only when the data layout is nxc. The
    // strides are calculated for this layout only.
    const size_t wei_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw
            * jcp.ch_block * jcp.typesize_in;
    const size_t inp_ch_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_in;
    const size_t out_ch_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * jcp.typesize_out;
    const size_t bias_stride
            = (size_t)jcp.nb_ch_blocking * jcp.ch_block * sizeof(float);

    auto compute = [&](int ur_ch_blocks, bool last_ch_block_flag = false) {
        if (jcp.is_fused_conv) {
            mov(aux_reg_input_buffer_ptr, reg_input_buffer_ptr);
        } else {
            mov(aux_reg_input, reg_input);
        }

        mov(aux_reg_kernel, reg_kernel);
        load_src(ur_ch_blocks, ur_w, last_ch_block_flag);
        apply_filter_unrolled(
                ur_ch_blocks, ur_w, pad_l, pad_r, last_ch_block_flag);
        apply_postops(ur_ch_blocks, ur_w, last_ch_block_flag);
        store_dst(ur_ch_blocks, ur_w, last_ch_block_flag);
    };

    const bool masked_ch_block_tail = jcp.oc % jcp.ch_block != 0;
    const bool ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;

    push(reg_ch_blocks);

    if (ch_loop) {
        Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
        const int nb_ch = jcp.oc / jcp.ch_block;
        const int nb_ch_blocking_tail
                = jcp.nb_ch - utils::rnd_dn(nb_ch, jcp.nb_ch_blocking);
        const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;

        push(reg_kernel);
        push(reg_input);
        push(reg_output);
        if (jcp.with_bias) push(reg_bias);

        if (nb_ch >= jcp.nb_ch_blocking) {
            if (nb_ch_blocking_tail) {
                cmp(reg_ch_blocks, ch_step);
                jl(ch_tail_label, T_NEAR);
            }

            L(ch_loop_label);
            {
                compute(jcp.nb_ch_blocking);
                add(reg_kernel, wei_ch_stride);
                add(reg_input, inp_ch_stride);
                add(reg_output, out_ch_stride);
                if (jcp.with_bias) add(reg_bias, bias_stride);
                sub(reg_ch_blocks, ch_step);
                cmp(reg_ch_blocks, ch_step);
                jge(ch_loop_label, T_NEAR);
            }
        }
        if (nb_ch_blocking_tail) {
            // ch work range [1, jcp.nb_ch_blocking * ch_block)
            L(ch_tail_label);
            cmp(reg_ch_blocks, 0);
            jle(skip_ch_tail_label, T_NEAR);
            compute(nb_ch_blocking_tail, masked_ch_block_tail);
            L(skip_ch_tail_label);
        }
        if (jcp.with_bias) pop(reg_bias);
        pop(reg_output);
        pop(reg_input);
        pop(reg_kernel);

    } else {
        compute(ur_ch_blocks, masked_ch_block_tail);
    }

    pop(reg_ch_blocks);
}

void jit_avx512_dw_conv_fwd_kernel_bf16::loop_ow(int ur_ch_blocks) {

    int iw = jcp.iw;
    int ow = jcp.ow;
    int kw = jcp.kw;
    int l_pad = jcp.l_pad;
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int stride_w = jcp.stride_w;

    const auto src_layout_nxc = is_src_layout_nxc();
    const auto dat_c_stride = src_layout_nxc ? jcp.ngroups : jcp.ch_block;
    size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * dat_c_stride;
    size_t out_shift = (size_t)jcp.typesize_out * ur_w * dat_c_stride;

    int inp_shift_pad
            = jcp.typesize_in * (ur_w * stride_w - l_pad) * dat_c_stride;

    int r_pad = nstl::max(0, jcp.r_pad);
    int n_oi = ow / ur_w;
    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
            calculate_extended_filter_size(kw, jcp.dilate_w));

    assert(jcp.nb_ow <= 1);

    if (r_pad1 > 0) n_oi--;
    xor_(reg_oi, reg_oi);
    if (ow == ur_w) {
        compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad);
    } else {
        if (n_oi == 0) {
            compute_loop(ur_w, ur_ch_blocks, l_pad, r_pad1);
            add(reg_input, inp_shift_pad);
            add(reg_output, out_shift);
            if (ur_w_tail != 0) {
                compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
            }
        } else {
            if (l_pad > 0) {
                compute_loop(ur_w, ur_ch_blocks, l_pad, 0);
                add(reg_input, inp_shift_pad);
                add(reg_output, out_shift);
                inc(reg_oi);
            }
            if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
                Label ow_loop_label;
                L(ow_loop_label);
                {
                    compute_loop(ur_w, ur_ch_blocks, 0, 0);
                    add(reg_input, inp_shift);
                    add(reg_output, out_shift);

                    inc(reg_oi);
                    cmp(reg_oi, n_oi);
                    jl(ow_loop_label, T_NEAR);
                }
            }
            if (r_pad1 > 0) {
                compute_loop(ur_w, ur_ch_blocks, 0, r_pad1);
                add(reg_input, inp_shift);
                add(reg_output, out_shift);
            }
            if (ur_w_tail != 0) {
                compute_loop(ur_w_tail, ur_ch_blocks, 0, r_pad);
            }
        }
    }
}
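
/* Editorial note: loop_ow() partitions the output row into at most four
 * regions and calls compute_loop() with the matching padding arguments: one
 * ur_w block that consumes the left padding (pad_l = l_pad), n_oi interior
 * ur_w blocks with no padding, one ur_w block that reaches the right edge
 * (pad_r = r_pad1), and a ur_w_tail remainder (pad_r = r_pad). For example
 * (values chosen purely for illustration), with ow = 17, ur_w = 4, l_pad = 1
 * and r_pad1 > 0: n_oi = 17 / 4 - 1 = 3, giving one left-pad block, two
 * interior blocks, one right-pad block and a ur_w_tail = 1 remainder, which
 * together cover all 17 output columns.
 */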

void jit_avx512_dw_conv_fwd_kernel_bf16::generate() {
    this->preamble();

    assert(mayiuse(avx512_core));
    if (jcp.is_fused_conv) {
        mov(reg_input_buffer_ptr, ptr[this->param1 + GET_OFF(src)]);
        /* In case of fused depthwise convolution, `param.src` is not a pointer
        to input, instead it points to a buffer containing pointers to
        consecutive rows of input in format Cwc with blocking nb_ch_blocking.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
        Traverse the data as
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row0 ...
            add(reg_input_buffer_ptr, sizeof(void*))
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row1 ...
            add(reg_input_buffer_ptr, sizeof(void*))
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row2 ...
        */
        xor_(reg_iw_offset, reg_iw_offset);
    } else {
        mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    }
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias) mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(load_work)]);

    Label ch_blocks_tail_label;
    Label exit_label;

    const int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
    const auto oc_tail = jcp.oc_without_padding % jcp.ch_block;
    if (oc_tail != 0) {
        // Note: is_src_layout_nxc() == true, otherwise channels are padded
        // Prepare masks for tailing
        const int oc_tail_shift
                = jcp.ch_block - jcp.oc_without_padding % jcp.ch_block;
        static constexpr auto zmm_16b_mask = ((1 << 16) - 1);

        // To account for the special store optimization, where two oc_blocks
        // are combined into a single write, extend the mask to 32 bits
        // (i.e. 32 bfloat16 elements)
        const bool need_extended_mask = jcp.dst_dt == data_type::bf16
                && isa_has_bf16(jcp.isa) && jcp.nb_ch_blocking > 1;
        if (need_extended_mask)
            kxnord(k_ch_tail_mask_extended, k_ch_tail_mask_extended,
                    k_ch_tail_mask_extended);

        Label done;
        mov(reg_tail, ptr[param1 + GET_OFF(load_work)]);
        cmp(reg_tail, jcp.nb_ch_blocking * jcp.ch_block);
        je(done, T_NEAR);
        Reg32 reg_tail_32 = reg_tail.cvt32();
        mov(reg_tail_32, zmm_16b_mask >> oc_tail_shift);
        kmovw(k_oc_tail_mask, reg_tail_32);
        if (need_extended_mask) {
            auto zmm_32b_mask = (1 << (oc_tail + jcp.ch_block)) - 1;
            mov(reg_tail_32, zmm_32b_mask);
            kmovd(k_ch_tail_mask_extended, reg_tail_32);
        }
        L(done);
    }

    if (is_src_layout_nxc()) {
        loop_ow(jcp.nb_ch);
    } else {
        cmp(reg_ch_blocks, (jcp.nb_ch_blocking - 1) * jcp.ch_block);
        jle(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);

        loop_ow(jcp.nb_ch_blocking); // channel main loop

        if (ch_blocks_tail) {
            jmp(exit_label, T_NEAR);
            L(ch_blocks_tail_label);

            loop_ow(ch_blocks_tail); // channel tail loop
        }

        L(exit_label);
    }

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
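
/* Editorial example of the tail-mask setup above (illustrative values): with
 * jcp.ch_block = 16 and jcp.oc_without_padding = 21, oc_tail = 5 and
 * oc_tail_shift = 11, so k_oc_tail_mask is loaded with
 *     0xffff >> 11 = 0b0000000000011111,
 * i.e. only the 5 valid channels of the last block are touched. When two
 * blocks are stored with one bf16 write (need_extended_mask), the 32-bit mask
 * (1 << (5 + 16)) - 1 covers the full first block plus the 5-channel tail of
 * the second.
 */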

inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::load_ddst(
        int ur_ch_blocks, int ur_str_w) {
    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        for (int w = 0; w < ur_str_w; w++) {
            Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w);
            uni_vpxor(zmm_acc, zmm_acc, zmm_acc);
        }
    }
}

inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::apply_filter(
        int ur_ch_blocks, int ur_str_w, bool last_ch_block_flag) {
    int kw = jcp.kw;
    int kh = jcp.kh;
    int ow = jcp.ow;
    int oh = jcp.oh;

    int ch_blk = jcp.ch_block;
    int stride_h = jcp.stride_h;
    int stride_w = jcp.stride_w;

    const bool ddst_layout_nxc = is_ddst_layout_nxc();
    const size_t ch_block_step = ch_blk * (ddst_layout_nxc ? 1 : oh * ow);
    const size_t sp_step = ddst_layout_nxc ? jcp.ngroups : ch_blk;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    cmp(reg_kw, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label);
    {
        mov(aux1_reg_ddst, aux_reg_ddst);
        mov(aux1_reg_kernel, aux_reg_kernel);

        mov(iter_kw, reg_kw);
        Label kw_label;
        L(kw_label);
        {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                const bool mask_flag
                        = last_ch_block_flag && ch == ur_ch_blocks - 1;
                int ker_off = ch * kh * kw * ch_blk;
                Zmm mm_zmm_ker // mm: maybe masked
                        = mask_flag ? zmm_ker_reg | k_ch_tail_mask | T_z
                                    : zmm_ker_reg;
                vpmovzxwd(mm_zmm_ker,
                        ptr[aux1_reg_kernel + ker_off * jcp.typesize_in]);

                for (int w = 0; w < ur_str_w; w++) {
                    size_t sp_offset = w * sp_step;
                    size_t ch_offset = ch * ch_block_step;
                    size_t ddst_off = sp_offset + ch_offset;
                    Zmm zmm_acc = get_acc_reg(ch * ur_str_w + w);
                    Zmm mm_zmm_dst // mm: maybe masked
                            = mask_flag ? zmm_dst_reg | k_ch_tail_mask | T_z
                                        : zmm_dst_reg;
                    vpmovzxwd(mm_zmm_dst,
                            ptr[aux1_reg_ddst + ddst_off * jcp.typesize_in]);

                    if (isa_has_bf16(jcp.isa))
                        vdpbf16ps(zmm_acc, mm_zmm_ker, mm_zmm_dst);
                    else
                        bf16_emu_->vdpbf16ps(zmm_acc, mm_zmm_dst, mm_zmm_ker);
                }
            }

            add(aux1_reg_kernel, ch_blk * stride_w * jcp.typesize_in);
            sub(aux1_reg_ddst, sp_step * jcp.typesize_in);

            sub(iter_kw, stride_w);
            cmp(iter_kw, 0);
            jg(kw_label, T_NEAR);
        }

        add(aux_reg_kernel, kw * ch_blk * stride_h * jcp.typesize_in);
        sub(aux_reg_ddst, ow * sp_step * jcp.typesize_in);

        sub(iter_kh, stride_h);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
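
/* Editorial note (reference sketch, comment only): the kw/kh loops above
 * implement the strided backward-data accumulation for one (ch, w) tile.
 * Ignoring the pointer bookkeeping done through aux1_reg_ddst and
 * aux1_reg_kernel, and with ddst_at/wei_at as placeholders for that
 * addressing, the computation is:
 *
 *     for (int kh = kh_count; kh > 0; kh -= stride_h)
 *         for (int kw = kw_count; kw > 0; kw -= stride_w)
 *             for (int w = 0; w < ur_str_w; ++w)
 *                 acc[ch][w] += ddst_at(ch, kh, kw, w) * wei_at(ch, kh, kw);
 *
 * kh_count/kw_count arrive in reg_kh/reg_kw (kh_padding/kw_padding in the call
 * params), so only filter taps that overlap valid ddst rows/columns are
 * visited.
 */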

inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::store_dsrc(
        int ur_ch_blocks, int ur_str_w, bool last_ch_block_flag) {
    int ch_blk = jcp.ch_block;
    int iw = jcp.iw;
    int ih = jcp.ih;
    int stride_w = jcp.stride_w;

    const auto dsrc_layout_nxc = is_dsrc_layout_nxc();
    const size_t ch_block_step = ch_blk * (dsrc_layout_nxc ? 1 : ih * iw);
    const size_t sp_step = dsrc_layout_nxc ? jcp.ngroups : ch_blk;

    if (jcp.dsrc_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_->init_vcvtneps2bf16();

    for (int ch = 0; ch < ur_ch_blocks; ch++) {
        const bool mask_flag = last_ch_block_flag && ch == ur_ch_blocks - 1;
        for (int w = 0; w < ur_str_w; w++) {
            size_t sp_offset = w * stride_w * sp_step;
            size_t ch_offset = ch * ch_block_step;
            int dsrc_off = sp_offset + ch_offset;
            auto zmm_dsrc = get_acc_reg(ch * ur_str_w + w);
            Zmm mm_zmm_dsrc // mm: maybe masked
                    = mask_flag ? zmm_dsrc | k_ch_tail_mask : zmm_dsrc;

            if (jcp.dsrc_dt == data_type::f32) {
                uni_vmovups(ptr[reg_dsrc + dsrc_off * jcp.typesize_out],
                        mm_zmm_dsrc);
            } else if (jcp.dsrc_dt == data_type::bf16) {
                auto ymm_dsrc = Ymm(zmm_dsrc.getIdx());
                Ymm mm_ymm_dsrc // mm: maybe masked
                        = mask_flag ? ymm_dsrc | k_ch_tail_mask : ymm_dsrc;

                if (isa_has_bf16(jcp.isa))
                    vcvtneps2bf16(mm_ymm_dsrc, mm_zmm_dsrc);
                else
                    bf16_emu_->vcvtneps2bf16(mm_ymm_dsrc, mm_zmm_dsrc);
                vmovdqu16(ptr[reg_dsrc + dsrc_off * jcp.typesize_out],
                        mm_ymm_dsrc);
            }
        }
    }
    /* Note: the current 'store_dsrc' is limited to storing 'ymm' outputs
     * because the implementation computes the convolution as a strided
     * backward pass. To increase store throughput by writing full 'zmm'
     * registers, changes are needed in both the JIT kernel and the driver
     * code. */
}

inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::ch_loop_body(
        int ur_ch_blocks, int unroll_w) {

    auto call_compute_body
            = [&](int ur_ch_blocks, int unroll_w, bool is_last_ch = false) {
                  mov(aux_reg_ddst, reg_ddst);
                  mov(aux_reg_kernel, reg_kernel);

                  load_ddst(ur_ch_blocks, unroll_w);
                  apply_filter(ur_ch_blocks, unroll_w, is_last_ch);
                  store_dsrc(ur_ch_blocks, unroll_w, is_last_ch);
              };

    const bool write_ch_loop = ur_ch_blocks > jcp.nb_ch_blocking;
    if (write_ch_loop) {
        assert(is_ddst_layout_nxc() && is_dsrc_layout_nxc());

        Label ch_loop_label, ch_tail_label, skip_ch_tail_label;
        const int nb_oc = jcp.oc / jcp.ch_block;
        const int ch_block_tail
                = jcp.nb_ch - (utils::rnd_dn(nb_oc, jcp.nb_ch_blocking));
        const int ch_step = jcp.nb_ch_blocking * jcp.ch_block;

        const size_t wei_ch_stride
                = (size_t)jcp.nb_ch_blocking * jcp.kh * jcp.kw * jcp.ch_block;
        const size_t data_ch_stride = (size_t)jcp.nb_ch_blocking * jcp.ch_block;

        mov(aux_reg_ch_blocks, reg_ch_blocks);
        push(reg_dsrc);
        push(reg_ddst);
        push(reg_kernel);

        if (nb_oc >= jcp.nb_ch_blocking) {
            if (ch_block_tail) {
                cmp(aux_reg_ch_blocks, jcp.nb_ch_blocking * jcp.ch_block);
                jl(ch_tail_label, T_NEAR);
            }

            L(ch_loop_label);
            {
                call_compute_body(jcp.nb_ch_blocking, unroll_w);

                add(reg_kernel, wei_ch_stride * jcp.typesize_in);
                add(reg_dsrc, data_ch_stride * jcp.typesize_out);
                add(reg_ddst, data_ch_stride * jcp.typesize_in);

                sub(aux_reg_ch_blocks, ch_step);
                cmp(aux_reg_ch_blocks, ch_step);
                jge(ch_loop_label, T_NEAR);
            }
        }

        if (ch_block_tail) {
            // ch work range [1, jcp.nb_ch_blocking * ch_block)
            L(ch_tail_label);
            cmp(aux_reg_ch_blocks, 0);
            jle(skip_ch_tail_label, T_NEAR);
            call_compute_body(ch_block_tail, unroll_w, jcp.ch_tail);
            L(skip_ch_tail_label);
        }

        pop(reg_kernel);
        pop(reg_ddst);
        pop(reg_dsrc);

    } else {
        call_compute_body(ur_ch_blocks, unroll_w, jcp.ch_tail);
    }
}

inline void jit_avx512_dw_conv_bwd_data_kernel_bf16::unroll_width_body(
        int ur_ch_blocks) {

    auto unroll_width_loop = [&](int unroll_w) {
        Label unroll_w_label, skip_compute_label;
        L(unroll_w_label);
        {
            const size_t ch_step = unroll_w
                    * (is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block);
            cmp(reg_ur_str_w, unroll_w);
            jl(skip_compute_label, T_NEAR);

            ch_loop_body(ur_ch_blocks, unroll_w);

            add(reg_dsrc, jcp.typesize_out * jcp.stride_w * ch_step);
            add(reg_ddst, jcp.typesize_in * ch_step);

            sub(reg_ur_str_w, unroll_w);
            jmp(unroll_w_label);
        }
        L(skip_compute_label);
    };

    unroll_width_loop(jcp.ur_w);

    unroll_width_loop(1);
}

void jit_avx512_dw_conv_bwd_data_kernel_bf16::generate() {
    assert(is_dsrc_layout_nxc() == is_ddst_layout_nxc());

    preamble();
    mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
    mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);

    if (is_dsrc_layout_nxc()) {
        if (jcp.ch_tail) {
            Label masking_done;
            const size_t channel_step = jcp.nb_ch_blocking * jcp.ch_block;
            kxnorw(k_ch_tail_mask, k_ch_tail_mask,
                    k_ch_tail_mask); // dummy mask all 1's
            cmp(reg_ch_blocks, channel_step);
            je(masking_done, T_NEAR);
            // Prepare masks for tail
            Reg32 reg_tmp_32 = reg_tmp.cvt32();
            mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
            kmovw(k_ch_tail_mask, reg_tmp_32);
            L(masking_done);
        }

        unroll_width_body(jcp.nb_ch);
    } else {
        auto ch_blocks_loop = [&](int ch_blocks) {
            Label skip_loop_label;
            cmp(reg_ch_blocks, ch_blocks * jcp.ch_block);
            jl(skip_loop_label, T_NEAR);
            unroll_width_body(ch_blocks);
            L(skip_loop_label);
        };

        ch_blocks_loop(jcp.nb_ch_blocking);

        int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
        if (ch_blocks_tail) { ch_blocks_loop(ch_blocks_tail); }
    }
    postamble();
}
#undef GET_OFF

#define GET_OFF(field) offsetof(jit_dw_conv_call_s, field)
void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_filter() {
    for (int i = 0; i < jcp.kw; ++i) {
        Zmm zmm_acc = get_acc_reg(i);
        uni_vpxor(zmm_acc, zmm_acc, zmm_acc);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::load_filter(bool is_last_ch) {
    for (int i = 0; i < jcp.kw; ++i) {
        int off_filter = i * jcp.ch_block;
        Zmm zmm_acc = get_acc_reg(i);
        Zmm m_zmm_acc = is_last_ch ? zmm_acc | k_ch_tail_mask | T_z : zmm_acc;
        vmovups(m_zmm_acc,
                vmmword[reg_tmp_filter + off_filter * jcp.typesize_out]);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_bias() {
    uni_vpxor(zmm_bias_reg, zmm_bias_reg, zmm_bias_reg);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::load_bias(bool is_last_ch) {
    Zmm m_zmm_bias_reg
            = is_last_ch ? zmm_bias_reg | k_ch_tail_mask | T_z : zmm_bias_reg;
    vmovups(m_zmm_bias_reg, vmmword[reg_bias_baddr]);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_step_unroll(
        int unroll_w, int l_pad, int pad_offset, int ow_block,
        bool is_last_ch) {

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const int iw_block = ow_block * jcp.stride_w;
    const int right_border = jcp.iw - iw_block;
    const int r_pad = jcp.r_pad;

    const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);

    /* preamble count for the number of cascaded LOAD + FMA operations */
    const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
    const bool is_last_block = (unroll_w + ow_block == jcp.ow);

    /* LOAD initial input registers, then cascade LOADs and FMAs */
    for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
        size_t off_output
                = static_cast<size_t>(i_ur * ch_step * jcp.typesize_in);
        Zmm m_zmm_out_reg
                = is_last_ch ? zmm_out_reg | k_ch_tail_mask | T_z : zmm_out_reg;
        vpmovzxwd(m_zmm_out_reg, ptr[reg_tmp_output + off_output]);
        if (i_ur == 0) {
            for (int c = 0; c < input_overlap; ++c) {
                int input_sp = c - pad_offset;
                if (input_sp < 0 && unroll_w == jcp.ow) continue;

                const bool over_steps_bdry = true && is_last_block
                        && (c - pad_offset + r_pad > right_border);
                if (over_steps_bdry) continue;

                size_t input_offset = static_cast<size_t>(
                        input_sp * ch_step * jcp.typesize_in);
                Zmm zmm_input = get_input_reg(c);
                Zmm m_zmm_input = is_last_ch ? zmm_input | k_ch_tail_mask | T_z
                                             : zmm_input;
                vpmovzxwd(m_zmm_input, ptr[reg_tmp_input + input_offset]);
            }
        } else {
            for (int c = 0; c < cascade_input; ++c) {
                int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
                int input_sp = overlap + c - pad_offset;
                if (input_sp < 0 || overlap + c + l_pad > right_border)
                    continue;

                const bool over_steps_bdry = true && is_last_block
                        && (overlap + c - pad_offset + r_pad > right_border);
                if (over_steps_bdry) continue;

                size_t input_offset = static_cast<size_t>(
                        input_sp * ch_step * jcp.typesize_in);
                Zmm zmm_input = get_input_reg(overlap + c);
                Zmm m_zmm_input = is_last_ch ? zmm_input | k_ch_tail_mask | T_z
                                             : zmm_input;
                vpmovzxwd(m_zmm_input, ptr[reg_tmp_input + input_offset]);
            }
        }

        for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
            int io_overlap = i_kw + (i_ur * jcp.stride_w);

            /* Don't apply FMAs that fall into the padded region */
            if (io_overlap - l_pad < 0
                    || io_overlap - jcp.l_pad >= right_border)
                continue;

            const bool over_steps_bdry = true && is_last_block
                    && (io_overlap - jcp.l_pad + jcp.r_pad > right_border);
            if (over_steps_bdry) continue;

            Zmm zmm_input = get_input_reg(io_overlap - l_pad);
            Zmm zmm_acc = get_acc_reg(i_kw);
            if (isa_has_bf16(jcp.isa))
                vdpbf16ps(zmm_acc, zmm_input, zmm_out_reg);
            else
                bf16_emu_->vdpbf16ps(zmm_acc, zmm_input, zmm_out_reg);
        }
    }
}
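
/* Editorial note (reference sketch, comment only): with the padding and
 * boundary guards stripped away, the cascaded LOAD + FMA pattern above
 * accumulates the weight-gradient update
 *
 *     for (int ow = 0; ow < unroll_w; ++ow)
 *         for (int kw = 0; kw < jcp.kw; ++kw)
 *             acc[kw] += input[ow * jcp.stride_w + kw - l_pad] * ddst[ow];
 *
 * (the JIT version additionally offsets the input index by pad_offset).
 * Input registers are reused across consecutive ow positions, the "cascade",
 * so each input element is loaded only once per unrolled block.
 */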

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_bias_step_unroll(
        const int unroll_w, bool is_last_ch) {

    const int ch_step = is_ddst_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    for (int i = 0; i < unroll_w; ++i) {
        size_t off_output = static_cast<size_t>(i * ch_step * jcp.typesize_in);
        /* bf16 output data requires conversion to f32 */
        Zmm m_zmm_out_reg
                = is_last_ch ? zmm_out_reg | k_ch_tail_mask | T_z : zmm_out_reg;
        vpmovzxwd(m_zmm_out_reg, ptr[reg_tmp_output + off_output]);
        vpslld(m_zmm_out_reg, m_zmm_out_reg, 0x10);
        vaddps(zmm_bias_reg, zmm_bias_reg, m_zmm_out_reg);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::store_filter(bool is_last_ch) {

    /* bf16: all data is stored as f32. The down-conversion to bf16 happens in
     * the reduction phase. */
    for (int i = 0; i < jcp.kw; ++i) {
        int off_filter = i * jcp.ch_block;
        Zmm zmm_acc = get_acc_reg(i);
        Zmm m_zmm_acc = is_last_ch ? zmm_acc | k_ch_tail_mask : zmm_acc;
        vmovups(vmmword[reg_tmp_filter + off_filter * jcp.typesize_out],
                m_zmm_acc);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::store_bias(bool is_last_ch) {
    Zmm m_zmm_bias_reg
            = is_last_ch ? zmm_bias_reg | k_ch_tail_mask : zmm_bias_reg;
    vmovups(vmmword[reg_bias_baddr], m_zmm_bias_reg);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_spatial_loop_bias(
        bool is_last_ch) {
    Label oh_label;
    Label ow_blk_label;

    const int unroll_w = nstl::min(max_unroll_w_, jcp.ow);
    const int unroll_w_trips = jcp.ow / unroll_w;
    const int tail_w = jcp.ow > max_unroll_w_ ? jcp.ow % max_unroll_w_ : 0;

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t ch_offset = ch_step * jcp.typesize_in;

    mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
    mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);

    mov(reg_tmp_output, reg_output_baddr);
    L(oh_label);
    {

        mov(reg_iter_ow_blk, unroll_w_trips);
        L(ow_blk_label);
        {
            compute_bias_step_unroll(unroll_w, is_last_ch);
            add(reg_tmp_output, unroll_w * ch_offset);

            dec(reg_iter_ow_blk);
            cmp(reg_iter_ow_blk, 0);
            jg(ow_blk_label, T_NEAR);
        }

        if (tail_w > 0) {
            compute_bias_step_unroll(tail_w, is_last_ch);
            add(reg_tmp_output, tail_w * ch_offset);
        }

        inc(reg_oh);
        cmp(reg_oh, reg_oh_worksize);
        jl(oh_label, T_NEAR);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::
        compute_single_ch_block_bias() {

    auto write_compute_bias = [&](bool masked_ch_tail) {
        Label skip_load_bias;

        mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
        and_(reg_exec_flags, FLAG_ZERO_BIAS);
        test(reg_exec_flags, reg_exec_flags);
        jne(skip_load_bias);

        load_bias(masked_ch_tail);

        L(skip_load_bias);
        compute_spatial_loop_bias(masked_ch_tail);

        store_bias(masked_ch_tail);
    };

    Label skip_masked_bias_label, done_bias_label;

    zero_bias();

    bool do_bias_ch_tail = jcp.ch_tail > 0;
    if (do_bias_ch_tail) {
        // test last channel
        mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
        and_(reg_exec_flags, FLAG_OC_LAST);
        test(reg_exec_flags, reg_exec_flags);
        jz(skip_masked_bias_label, T_NEAR);

        write_compute_bias(true);

        jmp(done_bias_label, T_NEAR);
        L(skip_masked_bias_label);
    }

    write_compute_bias(false);

    L(done_bias_label);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ch_loop_bias(
        bool do_load_bias) {

    assert(is_ddst_layout_nxc());

    auto write_compute_bias = [&](bool masked_ch_tail) {
        if (do_load_bias)
            load_bias(masked_ch_tail);
        else
            zero_bias();
        compute_spatial_loop_bias(masked_ch_tail);
        store_bias(masked_ch_tail);
    };

    bool masked_ch_tail = jcp.ch_tail > 0;
    if (jcp.nb_ch > 1) {

        Label last_ch_block_label, ch_block_done_label;
        if (masked_ch_tail) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        write_compute_bias(false);

        if (masked_ch_tail) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            write_compute_bias(true);

            L(ch_block_done_label);
        }
    } else {
        write_compute_bias(masked_ch_tail);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::deploy_ch_loop_bias() {

    Label ch_loop_label, zero_bias_label, load_bias_done_label;

    mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_BIAS);
    test(reg_exec_flags, reg_exec_flags);
    jne(zero_bias_label, T_NEAR);

    compute_ch_loop_bias(true); // load_bias
    jmp(load_bias_done_label, T_NEAR);

    L(zero_bias_label);
    compute_ch_loop_bias(false); // zero_bias

    L(load_bias_done_label);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_bias() {

    mov(reg_bias_baddr, ptr[this->param1 + GET_OFF(bias)]);

    if (is_ddst_layout_nxc())
        deploy_ch_loop_bias();
    else
        compute_single_ch_block_bias();
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::zero_filter_kh_loop() {

    const size_t filter_offset_kw = jcp.kw * jcp.ch_block * jcp.typesize_out;
    const size_t filter_offset_kh = jcp.kh * filter_offset_kw;

    Label kh_loop_label;

    mov(reg_kh_aux, jcp.kh);
    L(kh_loop_label);
    {
        store_filter();

        add(reg_tmp_filter, filter_offset_kw);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Rewind the filter pointer */
    sub(reg_tmp_filter, filter_offset_kh);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::deploy_zero_filter() {

    Label skip_zeroing_label;

    mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_FILTER);
    test(reg_exec_flags, reg_exec_flags);
    je(skip_zeroing_label, T_NEAR);

    zero_filter();

    mov(reg_tmp_filter, reg_filter_baddr);
    zero_filter_kh_loop();

    L(skip_zeroing_label);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_kh_step(int unroll_w,
        int l_pad, int pad_offset, int ow_block, bool is_last_ch) {

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t input_offset = jcp.iw * ch_step * jcp.typesize_in;
    const size_t filter_offset = jcp.kw * jcp.ch_block * jcp.typesize_out;

    Label kh_loop_label, skip_loop_label;

    cmp(reg_kh, 0);
    je(skip_loop_label, T_NEAR);

    mov(reg_kh_aux, reg_kh);
    L(kh_loop_label);
    {
        load_filter();
        compute_ow_step_unroll(
                unroll_w, l_pad, pad_offset, ow_block, is_last_ch);
        store_filter();

        add(reg_tmp_filter, filter_offset);
        add(reg_tmp_input, input_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Rewind the input and filter pointers */
    Label kh_comeback_label;
    mov(reg_kh_aux, reg_kh);
    L(kh_comeback_label);
    {
        sub(reg_tmp_input, input_offset);
        sub(reg_tmp_filter, filter_offset);
        dec(reg_kh_aux);
        cmp(reg_kh_aux, 0);
        jg(kh_comeback_label, T_NEAR);
    }

    L(skip_loop_label);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ch_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    const bool masked_ch_tail = is_layout_nxc() && jcp.ch_tail > 0;
    bool write_channel_loop = is_layout_nxc() && jcp.nb_ch > 1;
    if (write_channel_loop) {
        Label last_ch_block_label, ch_block_done_label;
        if (masked_ch_tail) {
            mov(reg_exec_flags, ptr[this->param1 + GET_OFF(exec_flags)]);
            and_(reg_exec_flags, FLAG_OC_LAST);
            test(reg_exec_flags, reg_exec_flags);
            jnz(last_ch_block_label, T_NEAR);
        }

        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, false);

        if (masked_ch_tail) {
            jmp(ch_block_done_label, T_NEAR);

            L(last_ch_block_label);
            compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, true);
            L(ch_block_done_label);
        }
    } else {
        compute_kh_step(unroll_w, l_pad, pad_offset, ow_block, masked_ch_tail);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_h_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    mov(reg_tmp_output, reg_output_baddr);
    mov(reg_tmp_input, reg_input_baddr);
    mov(reg_tmp_filter, reg_filter_baddr);

    const int input_bottom_padding_overlap
            = div_up(jcp.ih + jcp.t_pad - (jcp.kh - 1), jcp.stride_h);

    const size_t ch_step = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t input_shift = jcp.typesize_in * jcp.iw * ch_step;
    const size_t output_shift = jcp.typesize_in * jcp.ow * ch_step;
    const size_t filter_shift = jcp.typesize_out * jcp.kw * jcp.ch_block;

    Label loop_begin_label, loop_end_label, common_block_label,
            top_padding_end_label, bottom_padding_end_label,
            bottom_padding_label;

    mov(reg_oh, ptr[this->param1 + GET_OFF(oh_index)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_count)]);

    // replacement for 'os_index_end'
    mov(reg_oh_worksize, ptr[this->param1 + GET_OFF(oh_count)]);

    cmp(reg_kh, 0);
    jle(loop_end_label, T_NEAR); // no iterations along kh
    cmp(reg_oh, reg_oh_worksize);
    jge(loop_end_label, T_NEAR); // no iterations along height dimension

    L(loop_begin_label);

    compute_ch_loop(unroll_w, l_pad, pad_offset, ow_block);

    /* Compute 'top' edge */
    if (jcp.t_pad > 0) {

        /* Check if within top padding region */
        cmp(reg_oh, div_up(jcp.t_pad, jcp.stride_h));
        jge(top_padding_end_label, T_NEAR);

        /* Increment step counter and adjust filter position */
        sub(reg_tmp_filter, filter_shift * jcp.stride_h);
        add(reg_kh, jcp.stride_h);

        /* Final number of kernel elements that overlap with input */
        const int inp_ker_overlap = nstl::min(jcp.kh, jcp.ih);
        cmp(reg_kh, inp_ker_overlap);
        jle(common_block_label, T_NEAR);

        /* Correct any excess shifts to kernel and input */
        if (jcp.t_pad <= jcp.oh * jcp.stride_h) {
            /* Filter has moved beyond padding (adjust for stride effects) */
            if (jcp.t_pad % jcp.stride_h != 0) {
                int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
                add(reg_tmp_filter, filter_shift * inp_corr);
                add(reg_tmp_input, input_shift * inp_corr);
            }
        } else {
            /* Filter still overlaps padding (complete reset) */
            sub(reg_tmp_filter,
                    (jcp.t_pad - jcp.oh * jcp.stride_h) * filter_shift);
        }

        /* Apply correction: reset 'reg_kh' to the value used outside of the
         * top-padding special cases (i.e. 'min(jcp.kh, jcp.ih)') */
        mov(reg_kh, inp_ker_overlap);
        jmp(common_block_label);

        L(top_padding_end_label);
    }

    /* Compute 'bottom' edge */
    if (jcp.b_pad > 0) {

        /* Check if within bottom padding region */
        cmp(reg_oh, input_bottom_padding_overlap - 1);
        jl(bottom_padding_end_label, T_NEAR);
        jg(bottom_padding_label, T_NEAR);

        /* Execute overlap correction between the filter and the initial
         * bottom padding region. */
        mov(reg_kh,
                jcp.ih + jcp.t_pad
                        - input_bottom_padding_overlap * jcp.stride_h);
        jmp(bottom_padding_end_label, T_NEAR);

        L(bottom_padding_label);
        sub(reg_kh, jcp.stride_h);
        cmp(reg_kh, 0);
        jle(loop_end_label, T_NEAR);

        L(bottom_padding_end_label);
    }

    /* Compute middle block */
    add(reg_tmp_input, input_shift * jcp.stride_h);

    /* Execute common block and loop */
    L(common_block_label);
    add(reg_tmp_output, output_shift);
    inc(reg_oh);
    cmp(reg_oh, reg_oh_worksize);
    jl(loop_begin_label, T_NEAR);

    L(loop_end_label);
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::calculate_w_unrolling(
        int &unroll_trips, int &unroll_w, int &unroll_w_tail) {

    const bool do_unroll_w = jcp.ow > max_unroll_w_;
    if (do_unroll_w) {
        unroll_w = nstl::min(block_size_, jcp.ow);
        unroll_trips = jcp.ow / unroll_w;
        /* calculate tail */
        unroll_w_tail = jcp.ow % unroll_w;
        /* Perform some rebalancing if the tail is too small */
        if ((unroll_w_tail == 0 && jcp.r_pad != 0)
                || (jcp.r_pad > 0 && jcp.r_pad >= unroll_w_tail)) {
            if (unroll_trips > 1) {
                unroll_w_tail += unroll_w;
                unroll_trips--;
            } else {
                /* Ideally, this case shouldn't happen */
                unroll_w_tail += (unroll_w - unroll_w / 2);
                unroll_w = unroll_w / 2;
            }
        }
    } else {
        unroll_w_tail = jcp.ow;
    }
}
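
/* Editorial example (values for illustration only; the real bounds come from
 * max_unroll_w_ and block_size_ in the kernel header): suppose jcp.ow = 32,
 * max_unroll_w_ = 30 and block_size_ = 15. Then unroll_w = 15,
 * unroll_trips = 2 and unroll_w_tail = 2. If jcp.r_pad >= 2 the tail cannot
 * absorb the right padding, so one trip is folded into the tail:
 * unroll_trips = 1 and unroll_w_tail = 17, preserving
 * unroll_trips * unroll_w + unroll_w_tail == jcp.ow.
 */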

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::compute_ow_block_unroll() {

    Label ow_blk_label; // for computing the middle block
    int pad_offset = 0;
    int l_pad = jcp.l_pad;
    int unroll_w_tail = 0;
    int unroll_w = 0;
    int unroll_trips = 0;
    calculate_w_unrolling(unroll_trips, unroll_w, unroll_w_tail);

    const size_t ch_offset = is_layout_nxc() ? jcp.ngroups : jcp.ch_block;
    const size_t data_offset
            = static_cast<size_t>(unroll_w * ch_offset * jcp.typesize_in);

    if (jcp.with_bias) compute_bias();

    /* Pass filter address, then offset for h_padding. */
    deploy_zero_filter();
    mov(reg_kh_offset, ptr[this->param1 + GET_OFF(filter_pad_off)]);
    add(reg_filter_baddr, reg_kh_offset);

    /* compute left padded block */
    const bool do_unroll_w = jcp.ow > max_unroll_w_;
    if (l_pad && do_unroll_w) {
        compute_h_loop(unroll_w, l_pad, 0, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
        unroll_trips--;
        pad_offset = l_pad;
        l_pad = 0;
    }

    /* Insert a loop over the 'ow' blocks when the middle block needs to
     * execute more than once */
    bool do_ow_blk_loop = unroll_trips > 1;
    if (do_ow_blk_loop) {
        mov(reg_iter_ow_blk, unroll_trips);
        L(ow_blk_label);
    }
    if (unroll_trips > 0) {
        compute_h_loop(unroll_w, l_pad, pad_offset, 0);
        add(reg_output_baddr, data_offset);
        add(reg_input_baddr, data_offset * jcp.stride_w);
    }
    if (do_ow_blk_loop) {
        dec(reg_iter_ow_blk);
        cmp(reg_iter_ow_blk, 0);
        jg(ow_blk_label, T_NEAR);
    }

    /* compute right padded block */
    if (unroll_w_tail) {
        compute_h_loop(
                unroll_w_tail, l_pad, pad_offset, jcp.ow - unroll_w_tail);
    }
}

void jit_avx512_dw_conv_bwd_weights_kernel_bf16::generate() {
    assert(is_src_layout_nxc() == is_ddst_layout_nxc());

    preamble();

    mov(reg_input_baddr, ptr[this->param1 + GET_OFF(input)]);
    mov(reg_output_baddr, ptr[this->param1 + GET_OFF(output)]);
    mov(reg_filter_baddr, ptr[this->param1 + GET_OFF(filter)]);

    bool set_kmask = jcp.ch_tail > 0 && (jcp.with_bias || is_layout_nxc());
    if (set_kmask) {
        // Prepare masks for tail
        Reg32 reg_tmp_32 = reg_tmp.cvt32();
        mov(reg_tmp_32, (1 << jcp.ch_tail) - 1);
        kmovw(k_ch_tail_mask, reg_tmp_32);
    }

    compute_ow_block_unroll();

    postamble();
}
#undef GET_OFF

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl