1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/bfloat16.hpp"
18#include "common/c_types_map.hpp"
19#include "common/dnnl_thread.hpp"
20#include "common/math_utils.hpp"
21#include "common/nstl.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/platform.hpp"
26#include "cpu/x64/cpu_barrier.hpp"
27
28#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
29#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
30#include "cpu/x64/jit_avx512_core_bf16_conv_kernel.hpp"
31
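// GET_OFF(field) expands to the byte offset of `field` inside the run-time
// argument structure jit_conv_call_s, so the generated code can read kernel
// parameters via ptr[param1 + GET_OFF(field)].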
32#define GET_OFF(field) offsetof(jit_conv_call_s, field)
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39using namespace format_tag;
40using namespace dnnl::impl::memory_tracking::names;
41using namespace dnnl::impl::utils;
42using namespace Xbyak;
43
44namespace {
45
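// Spatial threshold: when both the iterated width and height are at most
// small_spatial, pick_loop_order() prefers the channel-first loop orders
// (loop_cwgn / loop_cgn) over the group/batch-first ones.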
46constexpr auto small_spatial = 14;
47
48inline void pick_loop_order(jit_conv_conf_t &jcp) {
49 using namespace prop_kind;
50 assert(one_of(
51 jcp.prop_kind, forward_training, forward_inference, backward_data));
52 auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
53 auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
54
55 if (utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
56 format_tag::nwc)
57 && jcp.ngroups > 1 && jcp.oc < 16) {
58 jcp.loop_order = loop_nhwcg;
59 } else if (jcp.prop_kind == backward_data) {
        // ow-threading is currently implemented for forward only
        // TODO: unify the fwd and bwd code paths once ow-threading is
        // supported for bwd (the previously meaningless switch was removed)
63 if (jcp.ndims < 5)
64 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
65 ? loop_cwgn
66 : loop_gncw;
67 else
68 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
69 ? loop_cgn
70 : loop_gnc;
71 } else {
72 jcp.loop_order = (w <= small_spatial && h <= small_spatial) ? loop_cwgn
73 : loop_gncw;
74 }
75}
76inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
    /* is 1D conv */
78 return (jcp.id == 1 && jcp.ih == 1 && jcp.kd == 1 && jcp.kh == 1);
79}
80inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
81 return (jcp.nb_ow > 1);
82}
83inline bool is_iw_threading_available(const jit_conv_conf_t &jcp) {
84 return one_of(jcp.ndims, 3, 4);
85}
86inline bool is_iw_threading_on(const jit_conv_conf_t &jcp) {
87 return (jcp.nb_iw > 1);
88}
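// A convolution is treated as a "1st convolution" (plain ncx/nxc source
// layout with few input channels) when ic < 16, there are no groups, and
// all source offsets fit into a signed 32-bit range.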
89inline bool is_1stconv(const jit_conv_conf_t &jcp) {
90 const bool no_big_offt = nstl::max<size_t>(jcp.ic, jcp.oc)
91 * nstl::max(jcp.typesize_in, jcp.typesize_out) * jcp.id
92 * jcp.ih * jcp.iw
93 < INT_MAX;
94 return jcp.ic < 16 && jcp.ngroups == 1 && no_big_offt;
95}
96} // namespace
97
98template <typename Vmm>
99_jit_avx512_core_bf16_fwd_kernel<Vmm>::_jit_avx512_core_bf16_fwd_kernel(
100 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
101 const memory_desc_t &dst_md)
102 : jit_generator(jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
103 , jcp(ajcp)
104 , attr_(attr) {
105 if (jcp.with_eltwise || jcp.with_binary) {
106 using namespace binary_injector;
107 static constexpr bool preserve_gpr = true;
108 static constexpr bool preserve_vmm = false;
109 static constexpr size_t helper_vmm_idx = 31;
110 const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
111 const size_t tail_size = oc_block_tail
112 ? oc_block_tail
113 : jcp.oc_without_padding % isa_simd_width_;
114 static constexpr bool use_exact_tail_scalar_bcast = true;
115
116 const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
117 r14, r15, r12, preserve_gpr, preserve_vmm,
118 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
119 memory_desc_wrapper(dst_md), tail_size, postops_mask,
120 use_exact_tail_scalar_bcast};
121 const static_params_t static_params {
122 this->param1, rhs_arg_static_params};
123
124 postops_injector_ = utils::make_unique<
125 injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
126 this, jcp.post_ops, static_params);
127 }
128 if (!isa_has_bf16(jcp.isa))
129 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
130 bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
131 bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
132}
133
134template <typename Vmm>
135void _jit_avx512_core_bf16_fwd_kernel<Vmm>::prepare_dst(int ur_w) {
136 for (int k = 0; k < jcp.nb_oc_blocking; k++)
137 for (int j = 0; j < ur_w; j++) {
138 Vmm vmm = vmm_dst(j, k);
139 vpxord(vmm, vmm, vmm);
140 }
141}
142
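// Accumulator tiling: the destination accumulator for output point i_ur and
// oc block i_oc lives at vmm index i_ur * nb_oc_blocking + i_oc, so indices
// 0 .. ur_w * nb_oc_blocking - 1 are reserved for accumulators (e.g. ur_w = 7
// with nb_oc_blocking = 4 occupies indices 0..27). The remaining indices are
// used for broadcast source values, weights, and bf16-emulation scratch.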
143template <typename Vmm>
144int _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst_idx(
145 int i_ur, int i_oc) const {
146 const int idx = i_ur * jcp.nb_oc_blocking + i_oc;
147 assert(idx < ker_reg_base_idx);
148 return idx;
149}
150
151template <typename Vmm>
152Vmm _jit_avx512_core_bf16_fwd_kernel<Vmm>::vmm_dst(int i_ur, int i_oc) const {
153 return Vmm(vmm_dst_idx(i_ur, i_oc));
154}
155
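// Helper that applies f(mask_flag, i_oc_block, i_ur) to every accumulator in
// the nb_oc_block x ur_w tile; mask_flag is raised for the last oc block when
// a tail is present, or for every block when force_masking is set.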
156template <typename F>
157static void iterate(const int nb_oc_block, const int ur_w, const bool mask_tail,
158 const bool force_masking, const F &f) {
159 for (int k = 0; k < nb_oc_block; k++) {
160 const bool mask_flag
161 = force_masking || (mask_tail && k + 1 == nb_oc_block);
162 for (int j = 0; j < ur_w; j++)
163 f(mask_flag, k, j);
164 }
165}
166template <typename F>
167static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
168 iterate(nb_oc_block, ur_w, false, false, f);
169}
170
171template <typename Vmm>
172void _jit_avx512_core_bf16_fwd_kernel<Vmm>::apply_postops(int ur_w) {
173 if (jcp.with_eltwise || jcp.with_binary) {
174 injector_utils::vmm_index_set_t vmm_idxs;
175 if (jcp.with_binary) {
176 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
177 rhs_arg_params_tail;
178 const auto mask_tail = jcp.oc_without_padding % jcp.simd_w;
179 const bool oc_blk_is_smaller_than_vmm
180 = jcp.oc_block < isa_simd_width_;
181 iterate(jcp.nb_oc_blocking, ur_w, mask_tail,
182 oc_blk_is_smaller_than_vmm,
183 [&](const bool mask_flag, const int k, const int j) {
184 const size_t aux_output_l_off = get_dst_offset(j, k);
185 const auto vmm_idx = vmm_dst_idx(j, k);
186 vmm_idxs.emplace(vmm_idx);
187
188 rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
189 vmm_idx, reg_dst);
190 rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
191 vmm_idx, aux_output_l_off);
192 if (mask_flag)
193 rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
194 });
195 rhs_arg_params = rhs_arg_params_tail;
196 rhs_arg_params.vmm_tail_idx_.clear();
197
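            // Whether this call processes a partial oc block is only known at
            // run time, so test the low bits of load_work and dispatch to the
            // masked (tail) or unmasked post-ops variant accordingly.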
198 Label postops_done;
199 if (mask_tail || oc_blk_is_smaller_than_vmm) {
200 Label postops_no_tail;
201 if (mask_tail) {
202 test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
203 jz(postops_no_tail, T_NEAR);
204 }
205 postops_injector_->compute_vector_range(
206 vmm_idxs, rhs_arg_params_tail);
207 jmp(postops_done, T_NEAR);
208 L(postops_no_tail);
209 }
210 postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
211 L(postops_done);
212
213 } else {
214 iterate(jcp.nb_oc_blocking, ur_w,
215 [&](const bool, const int k, const int j) {
216 vmm_idxs.emplace(vmm_dst_idx(j, k));
217 });
218 postops_injector_->compute_vector_range(vmm_idxs);
219 }
220 }
221}
222
223template <typename Vmm>
224void _jit_avx512_core_bf16_fwd_kernel<Vmm>::store_dst(int ur_w) {
225 Label store_label;
226 const int oc_tail = jcp.oc_tail;
227 if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
228
229 if (jcp.with_sum) {
230 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
231 for (int j = 0; j < ur_w; j++) {
232 // mask only needed for last oc_block
233 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
234 Vmm vmm = vmm_dst(j, k);
235 size_t aux_dst_offset = get_dst_offset(j, k);
236 if (jcp.dst_dt == data_type::bf16) {
237 vpmovzxwd(may_be_mask_vmm(vmm_prev_dst, mask_flag, true),
238 make_safe_addr(
239 reg_dst, aux_dst_offset, reg_long_offt));
240 vpslld(vmm_prev_dst, vmm_prev_dst, 16);
241 vaddps(vmm, vmm_prev_dst);
242 } else {
243 vaddps(may_be_mask_vmm(vmm, mask_flag, true),
244 make_safe_addr(
245 reg_dst, aux_dst_offset, reg_long_offt));
246 }
247 }
248 }
249 }
250
251 if (jcp.with_bias) {
252 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
253 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
254 int bias_offset = jcp.typesize_bia * k * jcp.oc_block;
255 for (int j = 0; j < ur_w; j++) {
256 // mask only needed for last oc_block
257 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking;
258 Vmm vmm = vmm_dst(j, k);
259 if (jcp.bia_dt == data_type::bf16) {
260 vpmovzxwd(may_be_mask_vmm(vmm_bias, mask_flag, true),
261 EVEX_compress_addr(reg_bias, bias_offset));
262 vpslld(vmm_bias, vmm_bias, 16);
263 vaddps(vmm, vmm_bias);
264 } else
265 vaddps(may_be_mask_vmm(vmm, mask_flag, true),
266 EVEX_compress_addr(reg_bias, bias_offset));
267 }
268 }
269 }
270
271 apply_postops(ur_w);
272
273 L(store_label);
274 if (jcp.dst_dt == data_type::f32) {
275 for (int k = 0; k < jcp.nb_oc_blocking; k++)
276 for (int j = 0; j < ur_w; j++) {
277 Vmm vmm = vmm_dst(j, k);
278 size_t aux_dst_offset = get_dst_offset(j, k);
279 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
280 // mask only needed for last oc_block
281 bool mask_flag = oc_tail && k + 1 == jcp.nb_oc_blocking
282 && is_dst_layout_nxc();
283 vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
284 }
285 } else if (jcp.dst_dt == data_type::bf16) {
286 if (isa_has_bf16(jcp.isa) && is_dst_layout_nxc()) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the OC dimension
289 for (int j = 0; j < ur_w; j++) {
290 int k = 0;
291 for (; k < rnd_dn(jcp.nb_oc_blocking, 2); k += 2) {
292 Vmm vmm = vmm_dst(j, k);
293 Vmm vmm_next = vmm_dst(j, k + 1);
294 size_t aux_dst_offset = get_dst_offset(j, k);
295 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
296 vcvtne2ps2bf16(vmm, vmm_next, vmm);
297 // mask only needed for last oc_block
298 bool mask_flag = oc_tail && k + 2 == jcp.nb_oc_blocking;
299 vmovdqu16(
300 addr, may_be_mask_vmm(vmm, mask_flag, false, true));
301 }
302 if (jcp.nb_oc_blocking % 2 != 0) {
303 Vmm vmm = vmm_dst(j, k);
304 auto vmm_down = Vmm_down_t(vmm.getIdx());
305 size_t aux_dst_offset = get_dst_offset(j, k);
306 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
307 vcvtneps2bf16(vmm_down, vmm);
308 // for xmm, upper half is zero after conversion to
309 // bf16, so mask always & mask for tails
310 bool mask_flag = jcp.simd_w == 4 || oc_tail;
311 vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
312 }
313 }
314 } else if (isa_has_bf16(jcp.isa) /* !is_dst_layout_nxc() */) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the WIDTH dimension
317 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
318 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
319 for (j = 0; j < n_2bf2ps; j += 2) {
320 size_t aux_dst_offset = get_dst_offset(j, k);
321 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
322
323 auto vmm_str = vmm_src(j, jcp.nb_oc_blocking);
324 vcvtne2ps2bf16(vmm_str, vmm_dst(j + 1, k), vmm_dst(j, k));
325 vmovups(addr, vmm_str);
326 }
327 if (j < ur_w) {
328 size_t aux_dst_offset = get_dst_offset(j, k);
329
330 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
331 auto vmm_down_str = vmm_src_down(j, jcp.nb_oc_blocking);
332 vcvtneps2bf16(vmm_down_str, vmm_dst(j, k));
333 // for xmm, upper half is zero after conversion to
334 // bf16, so mask always.
335 const bool mask_flag = jcp.simd_w == 4;
336 vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
337 }
338 }
339 } else {
340 for (int k = 0; k < jcp.nb_oc_blocking; k++)
341 for (int j = 0; j < ur_w; j++) {
342 Vmm vmm = vmm_dst(j, k);
343 size_t aux_dst_offset = get_dst_offset(j, k);
344 auto addr = EVEX_compress_addr(reg_dst, aux_dst_offset);
345 auto vmm_down = vmm_src_down(0, jcp.nb_oc_blocking);
346 bf16_emu_->vcvtneps2bf16(
347 Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
348 bool mask_flag = (oc_tail && k + 1 == jcp.nb_oc_blocking
349 && is_dst_layout_nxc())
350 // for xmm, upper half is zero after conversion to
351 // bf16, so mask always & mask for tails
352 || jcp.simd_w == 4;
353 vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
354 }
355 }
356 } else
357 assert(!"unsupported destination type");
358}
359
360template <typename Vmm>
361void _jit_avx512_core_bf16_fwd_kernel<Vmm>::compute_loop(
362 int ur_w, int pad_l, int pad_r) {
363 Label kh_label, kd_label;
364 const int ic_tail = jcp.ic_tail;
365 const int ic_step = 2;
366
    /* max_src_offset is explicitly used in the 1st convolution.
     * Set its value so that accessing the double-word memory
     * referenced by ptr[src_base + offset] is safe whenever
     * 0 <= offset < max_src_offset
     *
     * Note: the arguments pad_l and pad_r might not exactly match
     * jcp.l_pad and jcp.r_pad respectively, so this value needs to be
     * computed separately for each invocation of compute_loop.
     */
376 dim_t max_src_offset = 0;
377 if (jcp.is_1stconv || ic_tail) {
378 for (int ki = 0; ki < jcp.kw; ki++) {
379 int ow_fst = get_ow_start(ki, pad_l);
380 int ow_last = get_ow_end(ur_w, ki, pad_r) - 1;
381 if (ow_fst > ow_last) continue;
382 int ic_last = rnd_up(nstl::min(jcp.ic_block,
383 nstl::max(jcp.ic, ic_tail)),
384 ic_step)
385 - ic_step;
386
387 dim_t src_offset = get_src_offset(
388 ic_last, filter_w_to_src(ki, ow_last, pad_l));
389 if (src_offset > max_src_offset) max_src_offset = src_offset;
390 }
391 }
392
393 prepare_dst(ur_w);
394
395 Label skip_compute_loop;
396 if (jcp.ndims == 5) {
397 mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
398 if ((jcp.dilate_d >= jcp.id)
399 || (jcp.kd - 1) * (jcp.dilate_d + 1)
400 < nstl::max(jcp.f_pad, jcp.back_pad)) {
401 cmp(reg_kj, 0);
402 je(skip_compute_loop, T_NEAR);
403 }
404 }
405 mov(reg_kj, reg_kh);
406 if ((jcp.dilate_h >= jcp.ih)
407 || (jcp.kh - 1) * (jcp.dilate_h + 1)
408 < nstl::max(jcp.t_pad, jcp.b_pad)) {
409 cmp(reg_kj, 0);
410 je(skip_compute_loop, T_NEAR);
411 }
412
413 // IC loop
414 Label icb_label;
415 mov(reg_ic, jcp.ic);
416 L(icb_label);
417
418 if (jcp.ndims == 5) {
419 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
420 mov(ptr[rsp + off_reg_ker_], reg_ker);
421 mov(ptr[rsp + off_reg_src_], reg_src);
422
423 L(kd_label);
424 }
425
426 mov(aux_reg_src, reg_src);
427 mov(aux_reg_ker, reg_ker);
428
429 mov(reg_kj, reg_kh);
430
431 std::vector<Label> ic_tail_jmp(jcp.kw);
432 L(kh_label);
433 {
434 for (int ki = 0; ki < jcp.kw; ki++) {
435 int ow_start = get_ow_start(ki, pad_l);
436 int ow_end = get_ow_end(ur_w, ki, pad_r);
437 for (int ic = 0;
438 ic < rnd_up(nstl::min(jcp.ic_block, jcp.ic), ic_step);
439 ic += ic_step) {
440 if (ic_tail && ic == rnd_up(ic_tail, ic_step)) {
441 // insert this check at most once per icb, no more.
442 cmp(reg_ic, ic_tail);
443 je(ic_tail_jmp[ki], T_NEAR);
444 }
445 for (int oi = ow_start; oi < ow_end; oi++) {
446 dim_t src_offset = get_src_offset(
447 ic, filter_w_to_src(ki, oi, pad_l));
448 auto vmm_in = vmm_src(oi, jcp.nb_oc_blocking);
449 const auto addr_base = EVEX_compress_addr_safe(
450 aux_reg_src, src_offset, reg_long_offt);
451 const bool tail_load
452 = ic_tail && ic == rnd_dn(ic_tail, ic_step);
453 if (jcp.is_1stconv || tail_load) {
454 const bool need_single_load
455 = (ic + 1 == jcp.ic || ic + 1 == ic_tail);
456 const bool safe_overstep = (src_offset < max_src_offset)
457 && !is_src_layout_nxc();
458
459 /* For the comment below, let us define three words
460 * x_b = ptr[addr_base] and x_s = ptr[addr_strided]
461 * x_g = ptr[addr_base + 2]
462 *
463 * For single load case:
464 * Without overstep zmm_in register is loaded as
465 * [0, x_b, ..., 0, x_b, 0, x_b]
466 * On the other hand, "with overstep" zmm_in register
467 * is loaded as
468 * [x_g, x_b, ..., x_g, x_b, x_g, x_b]
469 * where x_g is a garbage word.
470 *
471 * Note:
472 * 1. In single load case with safe_overstep enabled,
473 * it is implicitly assumed that the element in zmm_wei
474 * register corresponding to the "garbage value x_g" in
475 * zmm_in register is zero.
                         * 2. There is a potential problem when x_g is
                         * either Inf or NaN, since it is multiplied by zero
                         * in the accumulation. But as x_g is a valid input
                         * for a different offset, one may assume that x_g is
                         * neither Inf nor NaN.
481 *
482 * For non single load case:
483 * zmm_in register is loaded as
484 * [x_s, x_b, ...., x_s, x_b, x_s, x_b]
485 */
486 if (tail_load) {
487 if (need_single_load) {
488 Label mask_load, load_done;
489 cmp(reg_ic, ic + ic_step);
490 jl(mask_load, T_NEAR);
491 vpbroadcastd(vmm_in, addr_base);
492 jmp(load_done, T_NEAR);
493 L(mask_load);
494 vpbroadcastw(vmm_in | odd_load_mask | T_z,
495 addr_base);
496 L(load_done);
497 } else {
498 vpbroadcastd(vmm_in, addr_base);
499 }
500 } else if (need_single_load && !safe_overstep)
501 vpbroadcastw(
502 vmm_in | odd_load_mask | T_z, addr_base);
503 else if (IMPLICATION(!is_src_layout_nxc(),
504 need_single_load && safe_overstep))
505 vpbroadcastd(vmm_in, addr_base);
506 else {
507 const auto addr_strided
508 = EVEX_compress_addr_safe(aux_reg_src,
509 src_offset + get_src_offset(1, 0),
510 reg_long_offt);
511 vpbroadcastd(vmm_in, addr_base);
512 vpbroadcastw(vmm_in | even_load_mask, addr_strided);
513 }
514 } else {
515 vpbroadcastd(vmm_in, addr_base);
516 }
517 }
518 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
519 auto wei_off = get_kernel_offset(kk, ic, ki);
520 vmovups(vmm_wei,
521 EVEX_compress_addr_safe(
522 aux_reg_ker, wei_off, reg_long_offt));
523 for (int oi = ow_start; oi < ow_end; oi++) {
524 auto acc = vmm_dst(oi, kk);
525 auto src = vmm_src(oi, jcp.nb_oc_blocking);
526 if (isa_has_bf16(jcp.isa)) {
527 vdpbf16ps(acc, vmm_wei, src);
528 } else
529 bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
530 Zmm(vmm_wei.getIdx()), Zmm(src.getIdx()));
531 }
532 }
533 }
534 L(ic_tail_jmp[ki]);
535 }
536 safe_add(aux_reg_ker, get_kernel_offset(0, 0, 0, 1), reg_long_offt);
537 safe_add(aux_reg_src, get_src_offset(0, filter_h_to_src(1)),
538 reg_long_offt);
539
540 dec(reg_kj);
541 cmp(reg_kj, 0);
542 jg(kh_label, T_NEAR);
543 }
544
545 if (jcp.ndims == 5) {
546 safe_add(reg_src, get_src_offset(0, filter_d_to_src(1)), reg_long_offt);
547 safe_add(reg_ker, get_kernel_offset(0, 0, 0, 0, 1), reg_long_offt);
548 dec(reg_ki);
549 cmp(reg_ki, 0);
550 jg(kd_label, T_NEAR);
551
552 mov(reg_ker, ptr[rsp + off_reg_ker_]);
553 mov(reg_src, ptr[rsp + off_reg_src_]);
554 }
555
556 // End of IC Loop
557 dim_t src_step = get_src_offset(jcp.ic_block, 0);
558 const size_t ker_step = get_kernel_offset(0, jcp.ic_block, 0);
559 safe_add(reg_src, src_step, reg_long_offt);
560 safe_add(reg_ker, ker_step, reg_long_offt);
561
562 sub(reg_ic, jcp.ic_block);
563 cmp(reg_ic, 0);
564 jg(icb_label, T_NEAR);
565
566 safe_sub(reg_src, src_step * jcp.nb_ic, reg_long_offt);
567 safe_sub(reg_ker, ker_step * jcp.nb_ic, reg_long_offt);
568
569 L(skip_compute_loop);
570 store_dst(ur_w);
571}
572
573template <typename Vmm>
574void _jit_avx512_core_bf16_fwd_kernel<Vmm>::generate() {
575 int iw = jcp.iw;
576 int ow = jcp.ow;
577 int ow_block = jcp.ow_block;
578 int nb_ow = jcp.nb_ow;
579 int kw = jcp.kw;
580 int l_pad = jcp.l_pad;
581 int ur_w = jcp.ur_w;
582 int ur_w_tail = jcp.ur_w_tail;
583 int stride_w = jcp.stride_w;
584
585 auto src_shift = get_src_offset(0, filter_w_to_src(0, ur_w));
586 auto dst_shift = get_dst_offset(ur_w, 0);
587
588 auto src_shift_pad = get_src_offset(0, filter_w_to_src(0, ur_w, l_pad));
589 auto src_shift_pad_second_block
590 = get_src_offset(0, filter_w_to_src(0, 0, l_pad));
591
592 preamble();
593 if (jcp.ndims == 5) sub(rsp, stack_space_needed_);
594
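    // odd_load_mask (0x5555...) selects the low 16-bit half of every dword
    // lane and even_load_mask (0xaaaa...) the high half; the 1st-conv and
    // ic-tail paths use them with vpbroadcastw to assemble bf16 pairs.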
595 if (jcp.is_1stconv || jcp.ic_tail) {
596 Xbyak::Reg64 reg_alt_mask = r8;
597 const auto odd_mask = size_t {0x5555555555555555};
598 const auto even_mask = size_t {0xaaaaaaaaaaaaaaaa};
599 mov(reg_alt_mask, odd_mask);
600 kmovq(odd_load_mask, reg_alt_mask);
601 mov(reg_alt_mask, even_mask);
602 kmovq(even_load_mask, reg_alt_mask);
603 }
604
605 if (jcp.simd_w == 4) {
606 auto reg_tail_32 = reg_oc.cvt32();
607 mov(reg_tail_32, (1 << jcp.simd_w) - 1);
608 kmovb(k_oc_tail_mask, reg_tail_32);
609 }
610
611 if (jcp.oc_tail) {
612 Label done;
613 // dummy mask all 1's
614 if (jcp.simd_w != 4) { // simd_w == 4, has its dummy mask set already
615 kxnord(k_oc_tail_mask, k_oc_tail_mask, k_oc_tail_mask);
616 }
617 // To account for special store optimization, where two oc_blocks are
618 // combined with one single write, extend the mask for 32bits (32 bf16s)
619 const bool need_extended_mask = jcp.dst_dt == data_type::bf16
620 && isa_has_bf16(jcp.isa) && jcp.nb_oc_blocking > 1;
621 if (need_extended_mask)
622 kxnord(k_oc_tail_mask_extended, k_oc_tail_mask_extended,
623 k_oc_tail_mask_extended);
624
625 test(byte[param1 + GET_OFF(load_work)], jcp.oc_block - 1);
626 jz(done, T_NEAR);
627 auto reg_tail_32 = reg_oc.cvt32();
628 mov(reg_tail_32, (1 << jcp.oc_tail) - 1);
629 kmovd(k_oc_tail_mask, reg_tail_32);
630 kmovd(postops_mask, reg_tail_32);
631 if (need_extended_mask) {
632 mov(reg_tail_32, (1 << (jcp.oc_tail + jcp.simd_w)) - 1);
633 kmovd(k_oc_tail_mask_extended, reg_tail_32);
634 }
635 L(done);
636 } else if (jcp.with_binary)
637 if (jcp.oc_block != isa_simd_width_) {
638 const int mask = (1 << jcp.oc_block) - 1;
639 const Reg32 regw_tmp = reg_oi.cvt32();
640 mov(regw_tmp, mask);
641 kmovd(postops_mask, regw_tmp);
642 }
643
644 mov(reg_src, ptr[param1 + GET_OFF(src)]);
645 mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
646 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
647 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
648
649 int r_pad = nstl::max(0, jcp.r_pad);
650 int n_oi = ow / ur_w;
651 int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
652 calculate_extended_filter_size(kw, jcp.dilate_w));
653
654 if (!is_ow_threading_on(jcp)) {
655 // ow is being processed as a whole - with left and right paddings
656 if (r_pad1 > 0) n_oi--;
657
658 xor_(reg_oi, reg_oi);
659 if (ow == ur_w) {
660 compute_loop(ur_w, l_pad, r_pad);
661 } else {
662 if (n_oi == 0) {
663 compute_loop(ur_w, l_pad, r_pad1);
664 add(reg_src, src_shift_pad);
665 add(reg_dst, dst_shift);
666 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
667 } else {
668 if (l_pad > 0) {
669 compute_loop(ur_w, l_pad, 0);
670 add(reg_src, src_shift_pad);
671 add(reg_dst, dst_shift);
672 inc(reg_oi);
673 }
674 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
675 Label ow_loop_label;
676 L(ow_loop_label);
677 {
678 compute_loop(ur_w, 0, 0);
679 add(reg_src, src_shift);
680 add(reg_dst, dst_shift);
681
682 inc(reg_oi);
683 cmp(reg_oi, n_oi);
684 jl(ow_loop_label, T_NEAR);
685 }
686 }
687 if (r_pad1 > 0) {
688 compute_loop(ur_w, 0, r_pad1);
689 add(reg_src, src_shift);
690 add(reg_dst, dst_shift);
691 }
692 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
693 }
694 }
695 } else {
        // Only a single ow block is processed per kernel call.
        // The block index is passed via the owb parameter,
        // and the padding handling depends on that index.
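        // Blocks are classified as first / middle / next-to-last / last;
        // right-padding handling is attached to exactly one of these classes
        // depending on the geometry, and only the first block handles l_pad.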
699
700 Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
701 Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
702
703 assert(ow_block % ur_w == 0);
704 int n_oi_not_last_ow_block = ow_block / ur_w;
        // to simplify the code (and general-purpose register usage),
        // the ow block size must be >= 2 * ur_w
707 assert(n_oi_not_last_ow_block > 1);
708 int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
709 int n_oi_first_ow_block = n_oi_not_last_ow_block;
710
711 int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;
712
713 // prepare right padding
714 bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
715 bool first_ow_block_padded
716 = next_last_ow_block_padded && jcp.nb_ow == 2;
717 bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
718
719 if (last_ow_block_padded)
720 n_oi_last_ow_block--;
721 else if (first_ow_block_padded)
722 n_oi_first_ow_block--;
723 else if (next_last_ow_block_padded)
724 n_oi_next_last_ow_block--;
725
726 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
727 cmp(reg_owb, 0); // is that the first ow-block ?
728 jg(middle_ow_blocks_label, T_NEAR);
729
730 // the first ow block, compute left padding
731
732 mov(reg_oi, n_oi_first_ow_block);
733 if (l_pad > 0) {
734 compute_loop(ur_w, l_pad, 0);
735 add(reg_src, src_shift_pad);
736 add(reg_dst, dst_shift);
737 dec(reg_oi);
738 }
739 jmp(oi_loop_label, T_NEAR);
740
741 // middle or last ow block entry
742
743 L(middle_ow_blocks_label);
744
745 if (l_pad > 0) {
            // only account for the left-padding shift here; no compute
747 add(reg_src, src_shift_pad_second_block);
748 }
749
750 // set number of iteration for oi-loop
751 cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
752 mov(reg_oi, n_oi_last_ow_block);
753 je(oi_loop_label, T_NEAR);
754 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
755 mov(reg_oi, n_oi_next_last_ow_block);
756 je(oi_loop_label, T_NEAR);
757 mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
758
759 // oi loop w/o padding
760 L(oi_loop_label);
761 L(oi_loop_start_label);
762 cmp(reg_oi, 0);
763 jle(oi_loop_end_label, T_NEAR);
764
765 compute_loop(ur_w, 0, 0);
766 add(reg_src, src_shift);
767 add(reg_dst, dst_shift);
768 dec(reg_oi);
769 jmp(oi_loop_start_label, T_NEAR);
770 L(oi_loop_end_label);
771
772 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
773
774 cmp(reg_owb, 0); // first ow-block ?
775 if (first_ow_block_padded) {
776 je(last_oi_label, T_NEAR);
777 } else {
778 je(end_label, T_NEAR);
779 }
780 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
781 jl(end_label, T_NEAR);
782 if (next_last_ow_block_padded) {
783 je(last_oi_label, T_NEAR);
784 } else {
785 je(end_label, T_NEAR);
786 }
        // this is the last ow-block
788 if (!last_ow_block_padded) { jmp(tail_label, T_NEAR); }
789
790 // last oi block with right padding
791 L(last_oi_label);
792 compute_loop(ur_w, 0, r_pad1);
793 add(reg_src, src_shift);
794 add(reg_dst, dst_shift);
795
796 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
797 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
798 jl(end_label, T_NEAR);
799
800 L(tail_label);
801 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_pad); }
802 L(end_label);
803 }
804
805 if (jcp.ndims == 5) add(rsp, stack_space_needed_);
806 postamble();
807
808 if (jcp.with_eltwise) postops_injector_->prepare_table();
809}
810
811void jit_avx512_core_bf16_fwd_kernel::init_scratchpad(
812 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
813 using namespace memory_tracking::names;
814 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
815 assert(jcp.ngroups == 1);
816 scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
817 }
818}
819
820status_t jit_avx512_core_bf16_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
821 const convolution_desc_t &cd, memory_desc_t &src_md,
822 memory_desc_t &weights_md, memory_desc_t &dst_md,
823 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
824
825 using namespace prop_kind;
826
827 const memory_desc_wrapper src_d(&src_md);
828 const memory_desc_wrapper weights_d(&weights_md);
829 const memory_desc_wrapper dst_d(&dst_md);
830 const memory_desc_wrapper bias_d(&bias_md);
831
832 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
833 int ndims = src_d.ndims();
834
835 jcp = zero<decltype(jcp)>();
836 jcp.nthr = nthreads;
837 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
838 : bf16_emulation_t::get_isa();
839 jcp.has_vnni = true;
840 jcp.ndims = ndims;
841 jcp.prop_kind = cd.prop_kind;
842 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
843 jcp.mb = src_d.dims()[0];
844 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
845 jcp.oc_without_padding = jcp.oc;
846 jcp.ic = src_d.dims()[1] / jcp.ngroups;
847 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
848 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
849 jcp.iw = src_d.dims()[ndims - 1];
850 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
851 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
852 jcp.ow = dst_d.dims()[ndims - 1];
853 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
854 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
855 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
856 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
857 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
858 jcp.l_pad = cd.padding[0][ndims - 3];
859 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
860 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
861 jcp.stride_w = cd.strides[ndims - 3];
862 jcp.dst_dt = dst_d.data_type();
863
864 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
865 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
866 jcp.dilate_w = cd.dilates[ndims - 3];
867
868 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
869
870 jcp.typesize_in = types::data_type_size(src_d.data_type());
871 jcp.typesize_out = types::data_type_size(dst_d.data_type());
872
873 jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
874 jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
875
876 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
877 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
878 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
879 jcp.r_pad = calculate_end_padding(
880 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
881 jcp.b_pad = calculate_end_padding(
882 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
883 jcp.back_pad = calculate_end_padding(
884 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
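    // Reject cases where the padding is at least as large as the extended
    // filter, i.e. some output points would not overlap the source at all.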
885 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
886 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
887 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
888 if (kernel_outside_src) return status::unimplemented;
889
890 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
891 const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
892 const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
893 const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
894 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
895 auto curr_src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c,
896 dat_tag_nCx8c, dat_tag_nCx4c, dat_tag_ncx);
897 auto curr_dst_tag = dst_d.matches_one_of_tag(
898 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
899 bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
900 src_d.format_kind() == format_kind::any)
901 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
902 dst_d.format_kind() == format_kind::any)
903 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
904 jcp.is_1stconv = is_1stconv(jcp);
905
906 const int regs = isa_has_bf16(jcp.isa) ? 31 /* expl_bcast case */ : 26;
907 const bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;
908
909 jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
910
911 const bool ok_to_try_lower_zmm = true
912 && IMPLICATION(is_data_layout_nxc,
913 jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
914 && jcp.ngroups > 1)
915 && !jcp.is_1stconv && !ok_to_pad_channels
916 && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
917
918 if (ok_to_try_lower_zmm) {
919 for (auto simd : {8, 4}) {
920 if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
921 jcp.simd_w = simd;
922 break;
923 }
924 }
925 }
926
927 jcp.oc_block = jcp.simd_w;
928 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
929
930 if (ok_to_pad_channels) {
931 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
932 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
933 }
934
935 if (!IMPLICATION(!is_data_layout_nxc,
936 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
937 return status::unimplemented;
938
939 format_tag_t src_tag, dst_tag, wei_tag;
940
941 if (jcp.simd_w == 8) {
942 assert(with_groups);
943 dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
944 wei_tag = pick(ndims - 3, gOIw4i8o2i, gOIhw4i8o2i, gOIdhw4i8o2i);
945 } else if (jcp.simd_w == 4) {
946 assert(with_groups);
947 dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
948 wei_tag = pick(ndims - 3, gOIw2i4o2i, gOIhw2i4o2i, gOIdhw2i4o2i);
949 } else if (jcp.is_1stconv) {
950 dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
951 src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_ncx;
952 wei_tag = pick(2 * ndims - 6 + with_groups, OwI16o2i, gOwI16o2i,
953 OhwI16o2i, gOhwI16o2i, OdhwI16o2i, gOdhwI16o2i);
954 } else {
955 dst_tag = src_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
956 wei_tag = pick(2 * ndims - 6 + with_groups, OIw8i16o2i, gOIw8i16o2i,
957 OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i, gOIdhw8i16o2i);
958 }
959
960 if (src_md.format_kind == format_kind::any)
961 CHECK(memory_desc_init_by_tag(src_md, src_tag));
962 else if (curr_src_tag != src_tag)
963 return status::unimplemented;
964 jcp.src_tag = src_tag;
965
966 if (dst_md.format_kind == format_kind::any)
967 CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
968 else if (curr_dst_tag != dst_tag)
969 return status::unimplemented;
970 jcp.dst_tag = dst_tag;
971
972 if (weights_md.format_kind == format_kind::any) {
973 CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
974 jcp.wei_tag = wei_tag;
975 } else {
976 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
977 if (jcp.wei_tag != wei_tag) return status::unimplemented;
978 }
979
980 if (jcp.with_bias) {
981 if (bias_d.format_kind() == format_kind::any)
982 CHECK(memory_desc_init_by_tag(bias_md, x));
983 }
984
985 jcp.aligned_threads = 0;
986
987 bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
988 && jcp.oc <= dst_d.padded_dims()[1]
989 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
990 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
991 if (!args_ok) return status::unimplemented;
992
993 const auto &post_ops = attr.post_ops_;
994 jcp.with_sum = post_ops.find(primitive_kind::sum) != -1;
995 const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
996 jcp.with_eltwise = eltwise_ind != -1;
997 if (jcp.with_eltwise) {
998 jcp.eltwise = post_ops.entry_[eltwise_ind].eltwise;
999 if (dst_d.data_type() == data_type::s32) return status::unimplemented;
1000 }
1001 const int binary_ind = post_ops.find(primitive_kind::binary);
1002 jcp.with_binary = binary_ind != -1;
1003
1004 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1005 if (is_data_layout_nxc)
1006 jcp.oc_tail = jcp.oc % jcp.simd_w;
1007 else
1008 jcp.oc_tail = jcp.with_binary ? jcp.oc_without_padding % jcp.simd_w : 0;
1009
1010 if (attr.set_default_formats(&dst_md) != status::success)
1011 return status::unimplemented;
1012
1013 jcp.post_ops = post_ops;
1014
1015 using namespace injector;
1016 static constexpr bool sum_at_pos_0_only = true;
1017 static constexpr bool sum_requires_scale_one = true;
1018 static constexpr bool sum_requires_zp_zero = true;
1019 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
1020 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
1021 sum_requires_zp_zero});
1022 if (!post_ops_ok_) return status::unimplemented;
1023
1024 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
1025 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
1026 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1027
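    // Register budget: each unrolled output point needs nb_oc_blocking
    // accumulators plus one broadcast register, hence ur_w is limited by
    // regs / (nb_oc_blocking + 1); regs is 31 with native bf16 and 26 when
    // five extra registers are needed for bf16 emulation.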
1028 jcp.kernel_kind = expl_bcast;
1029 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
1030 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
1031 int ur_w = regs / (jcp.nb_oc_blocking + 1);
1032 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
1033 && (jcp.l_pad <= ur_w
1034 && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
1035 break;
1036 }
1037
1038 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
1039 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
1040 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1041
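    // ow-threading: pick ow_block so that the per-iteration src and dst
    // chunks plus the weights chunk fit in roughly 5/8 of the per-core L1
    // cache, with a minimum of 2 * ur_w as required by generate().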
1042 jcp.ow_block = jcp.ow;
1043 if (is_ow_threading_available(jcp)) {
1044 const int L1_part = platform::get_per_core_cache_size(1) * 5 / 8;
1045 int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
1046 int size_dst_chunk = jcp.typesize_out * jcp.oc_block
1047 * jcp.nb_oc_blocking * jcp.ur_w;
1048 int size_wei_chunk = jcp.typesize_in * jcp.oc_block * jcp.ic_block
1049 * jcp.nb_oc_blocking * jcp.kw;
1050 int nurw = (L1_part - size_wei_chunk)
1051 / (size_dst_chunk + size_src_chunk);
1052 // current design of generate() requires ow_block >= 2 * ur_w
1053 jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
1054 }
1055 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1056
1057 int r_pad_no_tail = nstl::max(0,
1058 calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
1059 jcp.stride_w, ext_kw));
1060 if (jcp.l_pad > jcp.ur_w || r_pad_no_tail > jcp.ur_w)
1061 return status::unimplemented;
1062
    /* Adjust the thread decomposition to improve performance for small
     * problem sizes. The threshold L1_cache_size / factor and the factor
     * itself are empirical; simply cap the thread count at 4 for now.
     * TODO: add a get_thr_eff function to compute the optimal thread count. */
1068
1069 size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
1070 * jcp.kw * jcp.kd;
1071 size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
1072 * jcp.iw * jcp.id;
1073 size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
1074 * jcp.ow * jcp.od;
1075 size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
1076 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1077
1078 // The factor for 1d=1, 2d=2, 3d=4;
1079 int factor = nstl::max(1, (2 * (ndims - 3)));
1080 if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
1081 jcp.nthr = nstl::min(jcp.nthr, 4);
1082 }
1083
1084 pick_loop_order(jcp);
1085
1086 return status::success;
1087}
1088
1089template <typename Vmm>
1090void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::prepare_output(int ur_w) {
1091 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1092 for (int j = 0; j < ur_w; j++) {
1093 Vmm vmm = vmm_dsrc(j, k);
1094 vpxord(vmm, vmm, vmm);
1095 }
1096 }
1097}
1098
1099template <typename Vmm>
1100void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::store_output(int ur_w) {
1101 if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();
1102 const int ic_tail = jcp.ic_tail;
1103
1104 if (jcp.dst_dt == data_type::f32) {
1105 for (int k = 0; k < jcp.nb_ic_blocking; k++)
1106 for (int j = 0; j < ur_w; j++) {
1107 Vmm vmm = vmm_dsrc(j, k);
1108 size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1109 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1110 // mask only needed for last ic_block
1111 bool mask_flag = ic_tail && k + 1 == jcp.nb_ic_blocking
1112 && is_dsrc_layout_nxc();
1113 vmovups(addr, may_be_mask_vmm(vmm, mask_flag, false));
1114 }
1115 } else if (jcp.dst_dt == data_type::bf16) {
1116 if (isa_has_bf16(jcp.isa) && is_ddst_layout_nxc()) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the IC dimension
1119 for (int j = 0; j < ur_w; j++) {
1120 int k = 0;
1121 for (; k < rnd_dn(jcp.nb_ic_blocking, 2); k += 2) {
1122 Vmm vmm = vmm_dsrc(j, k);
1123 Vmm vmm_next = vmm_dsrc(j, k + 1);
1124 size_t aux_dsrc_offset = get_diff_src_offset(j, k);
1125 auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
1126 vcvtne2ps2bf16(vmm, vmm_next, vmm);
1127 bool mask_flag = ic_tail && k + 2 == jcp.nb_ic_blocking;
1128 vmovdqu16(
1129 addr, may_be_mask_vmm(vmm, mask_flag, false, true));
1130 }
1131 if (jcp.nb_ic_blocking % 2 != 0) {
1132 Vmm vmm = vmm_dsrc(j, k);
1133 auto vmm_down = Vmm_down_t(vmm.getIdx());
1134 size_t aux_dsrc_offset = get_diff_src_offset(j, k);
1135 auto addr = EVEX_compress_addr(reg_src, aux_dsrc_offset);
1136 vcvtneps2bf16(vmm_down, vmm);
1137 // for xmm, upper half is zero after conversion to
1138 // bf16, so mask always & mask for tails
1139 bool mask_flag = jcp.simd_w == 4 || ic_tail;
1140 vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
1141 }
1142 }
1143 } else if (isa_has_bf16(jcp.isa) /* && !is_ddst_layout_nxc() */) {
            // Optimization: use a single store instruction for each pair of
            // adjacent vectors along the WIDTH dimension
1146 int store_idx = 0;
1147 const int max_regs = 32;
1148 const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
1149 const int num_regs_available = max_regs - free_regs_start_idx;
1150 int reg_idx = 0;
1151 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1152 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
1153 for (j = 0; j < n_2bf2ps; j += 2) {
1154 reg_idx = free_regs_start_idx
1155 + store_idx % num_regs_available;
1156 assert(reg_idx < max_regs);
1157 size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1158 auto addr
1159 = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1160
1161 auto vmm_str = Vmm(reg_idx);
1162 vcvtne2ps2bf16(vmm_str, vmm_dsrc(j + 1, k), vmm_dsrc(j, k));
1163 vmovups(addr, vmm_str);
1164 store_idx++;
1165 }
1166 if (j < ur_w) {
1167 reg_idx = free_regs_start_idx
1168 + store_idx % num_regs_available;
1169 assert(reg_idx < max_regs);
1170
1171 size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1172 auto addr
1173 = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1174 auto vmm_down_str = Vmm_down_t(reg_idx);
1175 vcvtneps2bf16(vmm_down_str, vmm_dsrc(j, k));
1176 // for xmm, upper half is zero after conversion to
1177 // bf16, so mask always.
1178 bool mask_flag = jcp.simd_w == 4;
1179 vmovdqu16(addr, may_be_mask_vmm(vmm_down_str, mask_flag));
1180 store_idx++;
1181 }
1182 }
1183 } else {
1184 for (int k = 0; k < jcp.nb_ic_blocking; k++)
1185 for (int j = 0; j < ur_w; j++) {
1186 Vmm vmm = vmm_dsrc(j, k);
1187 size_t aux_diff_src_offset = get_diff_src_offset(j, k);
1188 auto addr
1189 = EVEX_compress_addr(reg_src, aux_diff_src_offset);
1190 auto vmm_down = vmm_ddst_down(0);
1191 bf16_emu_->vcvtneps2bf16(
1192 Ymm(vmm_down.getIdx()), Zmm(vmm.getIdx()));
1193 bool mask_flag = (ic_tail && k + 1 == jcp.nb_ic_blocking
1194 && is_dsrc_layout_nxc())
1195 // for xmm, upper half is zero after conversion to
1196 // bf16, so mask always & mask for tails
1197 || jcp.simd_w == 4;
1198 vmovdqu16(addr, may_be_mask_vmm(vmm_down, mask_flag));
1199 }
1200 }
1201 } else
1202 assert(!"unsupported diff_src type");
1203}
1204
1205template <typename Vmm>
1206void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::compute_loop(
1207 int ur_w, int l_overflow, int r_overflow) {
1208 int kw = jcp.kw;
1209 int dilate_w = jcp.dilate_w + 1;
1210 int stride_w = jcp.stride_w;
1211 int stride_h = jcp.stride_h;
1212 const int oc_tail = jcp.oc_tail;
1213 Label kh_label, skip_compute_label;
1214
1215 prepare_output(ur_w);
1216
1217 if (jcp.ndims == 5) {
1218 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1219 cmp(reg_ki, 0);
1220 jle(skip_compute_label, T_NEAR);
1221 }
1222
1223 cmp(reg_kh, 0);
1224 jle(skip_compute_label, T_NEAR);
1225
1226 // OC loop
1227 Label ocb_label;
1228 mov(reg_oc, jcp.oc);
1229 L(ocb_label);
1230
1231 if (jcp.ndims < 5) {
1232 mov(aux_reg_dst, reg_dst);
1233 mov(aux_reg_ker, reg_ker);
1234 }
1235 Label kd_label;
1236 if (jcp.ndims == 5) {
1237 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1238 mov(aux_reg_dst_d, reg_dst);
1239 mov(aux_reg_ker_d, reg_ker);
1240
1241 L(kd_label);
1242 mov(aux_reg_dst, aux_reg_dst_d);
1243 mov(aux_reg_ker, aux_reg_ker_d);
1244 }
1245
1246 std::vector<Label> oc_tail_jmp(jcp.kw);
1247 mov(reg_kj, reg_kh);
1248 L(kh_label);
1249 {
1250 for (int ki = 0; ki < kw; ki++) {
1251 int jj_start = get_iw_start(ki, l_overflow);
1252 int jj_end = get_iw_end(ur_w, ki, r_overflow);
1253 const int ref_jj_start
1254 = nstl::max(0, l_overflow - (kw - 1 - ki) * dilate_w);
1255 const int ref_jj_end
1256 = ur_w - nstl::max(0, r_overflow - ki * dilate_w);
1257 assert(IMPLICATION(stride_w == 1,
1258 jj_start == ref_jj_start && jj_end == ref_jj_end));
1259 UNUSED(ref_jj_start);
1260 UNUSED(ref_jj_end);
1261 const int oc_step = 2;
1262 for (int oc = 0;
1263 oc < rnd_up(nstl::min(jcp.oc_block, jcp.oc), oc_step);
1264 oc += oc_step) {
1265 if (oc_tail && oc == rnd_up(oc_tail, oc_step)) {
1266 cmp(reg_oc, oc_tail);
1267 je(oc_tail_jmp[ki], T_NEAR);
1268 }
1269 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1270 assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
1271 int ow_idx = (jj + jcp.l_pad - ki * dilate_w) / stride_w;
1272 auto aux_ddst_offset = get_diff_dst_offset(ow_idx, oc);
1273 auto ddst = vmm_ddst(jj / stride_w);
1274 const bool tail_load = oc_tail && oc == rnd_dn(oc_tail, 2);
1275 const bool need_single_load = oc + 1 == oc_tail;
1276
1277 if (tail_load && need_single_load) {
1278 Label mask_load, load_done;
1279 cmp(reg_oc, oc + 2);
1280 jl(mask_load, T_NEAR);
1281 vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1282 jmp(load_done, T_NEAR);
1283 L(mask_load);
1284 // We broadcast w here. As the weights are zero-padded
1285 // at oc + 1, vdpbf16ps({0, w}, {dst, dst}) is okay.
1286 vpbroadcastw(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1287 L(load_done);
1288 } else {
1289 vpbroadcastd(ddst, ptr[aux_reg_dst + aux_ddst_offset]);
1290 }
1291 }
1292 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1293 size_t aux_kernel_offset = get_kernel_offset(kk, oc, ki);
1294 vmovups(vmm_wei,
1295 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1296
1297 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
1298 auto ddst = vmm_ddst(jj / stride_w);
1299 auto acc = vmm_dsrc(jj, kk);
1300
1301 if (isa_has_bf16(jcp.isa)) {
1302 vdpbf16ps(acc, vmm_wei, ddst);
1303 } else
1304 bf16_emu_->vdpbf16ps(Zmm(acc.getIdx()),
1305 Zmm(vmm_wei.getIdx()), Zmm(ddst.getIdx()));
1306 }
1307 }
1308 }
1309 L(oc_tail_jmp[ki]);
1310 }
1311
1312 add(aux_reg_ker, get_kernel_offset(0, 0, 0, stride_h));
1313 sub(aux_reg_dst, get_diff_dst_offset(filter_h_to_dst(1), 0));
1314
1315 dec(reg_kj);
1316 cmp(reg_kj, 0);
1317 jg(kh_label, T_NEAR);
1318 }
1319
1320 if (jcp.ndims == 5) {
1321 sub(aux_reg_dst_d, get_diff_dst_offset(filter_d_to_dst(1), 0));
1322 add(aux_reg_ker_d, get_kernel_offset(0, 0, 0, 0, jcp.stride_d));
1323
1324 dec(reg_ki);
1325 cmp(reg_ki, 0);
1326 jg(kd_label, T_NEAR);
1327 }
1328
1329 // End of OC Loop
1330 auto diff_dst_step = get_diff_dst_offset(0, 0, 1);
1331 auto ker_step = get_kernel_offset(0, jcp.oc_block, 0);
1332 add(reg_dst, diff_dst_step);
1333 add(reg_ker, ker_step);
1334
1335 sub(reg_oc, jcp.oc_block);
1336 cmp(reg_oc, 0);
1337 jg(ocb_label, T_NEAR);
1338
1339 sub(reg_dst, diff_dst_step * jcp.nb_oc);
1340 sub(reg_ker, ker_step * jcp.nb_oc);
1341
1342 L(skip_compute_label);
1343 store_output(ur_w);
1344}
1345
1346template <typename Vmm>
1347void _jit_avx512_core_bf16_bwd_data_kernel<Vmm>::generate() {
1348 int iw = jcp.iw;
1349 int kw = jcp.kw;
1350 int ur_w = jcp.ur_w;
1351 int nb_iw = jcp.nb_iw;
1352 int iw_block = jcp.iw_block;
1353 int ur_w_tail = jcp.ur_w_tail;
1354 int dilate_w = jcp.dilate_w + 1;
1355 int stride_w = jcp.stride_w;
1356
1357 const auto dst_shift = get_diff_dst_offset(ur_w / stride_w, 0);
1358 const auto src_shift = get_diff_src_offset(ur_w, 0);
1359
1360 preamble();
1361
1362 if (jcp.simd_w == 4) {
1363 Reg32 reg_tail_32 = reg_oc.cvt32();
1364 mov(reg_tail_32, (1 << jcp.simd_w) - 1);
1365 kmovb(k_ic_tail_mask, reg_tail_32);
1366 }
1367
1368 if (jcp.ic_tail) {
1369 Label done;
1370 // dummy mask all 1's
1371 if (jcp.simd_w != 4)
1372 kxnord(k_ic_tail_mask, k_ic_tail_mask, k_ic_tail_mask);
1373 // To account for special store optimization, where two ic_blocks are
1374 // combined with one single write, extend the mask for 32bits (32 bf16s)
1375 const bool need_extended_mask
1376 = isa_has_bf16(jcp.isa) && jcp.nb_ic_blocking > 1;
1377 if (need_extended_mask)
1378 kxnord(k_ic_tail_mask_extended, k_ic_tail_mask_extended,
1379 k_ic_tail_mask_extended);
1380
1381 test(byte[param1 + GET_OFF(load_work)], jcp.ic_block - 1);
1382 jz(done, T_NEAR);
1383 Reg32 reg_tail_32 = reg_ic.cvt32();
1384 mov(reg_tail_32, (1 << jcp.ic_tail) - 1);
1385 kmovd(k_ic_tail_mask, reg_tail_32);
1386 if (need_extended_mask) {
1387 mov(reg_tail_32, (1 << (jcp.ic_tail + jcp.simd_w)) - 1);
1388 kmovd(k_ic_tail_mask_extended, reg_tail_32);
1389 }
1390 L(done);
1391 }
1392
1393 mov(reg_src, ptr[param + GET_OFF(src)]);
1394 mov(reg_dst, ptr[param + GET_OFF(dst)]);
1395 mov(reg_ker, ptr[param + GET_OFF(filt)]);
1396
1397 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
1398
1399 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
1400 int r_overflow = nstl::max(
1401 0, ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad)) / stride_w);
1402 int r_overflow1 = nstl::max(0,
1403 ((kw - 1) * dilate_w - nstl::max(0, jcp.r_pad + ur_w_tail))
1404 / stride_w);
1405
1406 int body_l_overflow = 0, body_r_overflow = 0;
1407 int n_oi = iw / ur_w;
1408 int head_n_oi = 0, body_n_oi = 0, pretail_n_oi = 0, tail_n_oi = 0;
1409 int head_thread = 0, pretail_thread = 0, tail_thread = 0;
1410 bool threaded = is_iw_threading_on(jcp);
1411 Label head_label, body_label, pretail_label, tail_label, end_label;
1412 assert(n_oi > 0);
1413
1414 if (r_overflow1 > 0) n_oi--;
1415 if (l_overflow > 0) n_oi--;
1416 if (n_oi < 0) {
1417 // l_overflow and r_overflow1 are handled in the same compute_loop.
1418 // Perform one iteration of body handling l_overflow and r_overflow1.
1419 body_l_overflow = l_overflow;
1420 body_r_overflow = r_overflow1;
1421 n_oi = 1;
1422 l_overflow = 0;
1423 r_overflow1 = 0;
1424 }
1425
1426 if (!threaded) {
1427 if (n_oi > 1) { mov(reg_oi, n_oi); }
1428 } else {
1429 // Setup for threaded code generation, and jump into the correct
1430 // portion of code for execution.
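        // The iw range is split into head (handles left overflow), body,
        // pretail (handles right overflow), and tail (ur_w_tail) sections;
        // reg_iwb selects the section this thread starts from.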
1431 head_thread = 0;
1432 tail_thread = nb_iw - 1;
1433 pretail_thread = tail_thread;
1434
1435 int base_n_oi = iw_block / ur_w;
1436 head_n_oi = l_overflow > 0 ? base_n_oi - 1 : base_n_oi;
1437 tail_n_oi = (iw - iw_block * (nb_iw - 1)) / ur_w;
1438 pretail_n_oi = tail_n_oi;
1439 if (r_overflow1 > 0) {
1440 if (tail_n_oi > 0) {
1441 pretail_n_oi--;
1442 tail_n_oi = pretail_n_oi;
1443 } else {
1444 // pretail_thread and tail_thread are different
1445 pretail_n_oi = base_n_oi - 1;
1446 pretail_thread = tail_thread - 1;
1447 }
1448 if (head_thread == pretail_thread) {
1449 head_n_oi--;
1450 pretail_n_oi = 0;
1451 tail_n_oi = 0;
1452 }
1453 }
1454 body_n_oi = (head_thread < pretail_thread - 1) ? base_n_oi : 0;
1455
        // n_oi determines how much control flow needs to be generated for the
        // body portion of the code. As such, n_oi must be set to the maximum
        // number of iterations used by the body code section.
1459 n_oi = nstl::max(body_n_oi, head_n_oi);
1460 n_oi = nstl::max(n_oi, pretail_n_oi);
1461
1462 assert(iw_block % ur_w == 0);
1463 mov(reg_iwb, ptr[param1 + GET_OFF(iwb)]);
1464
1465 if (head_n_oi != 0) mov(reg_oi, head_n_oi);
1466 cmp(reg_iwb, head_thread);
1467 je(head_label, T_NEAR);
1468
1469 cmp(reg_iwb, pretail_thread);
1470 if (pretail_n_oi == 0) {
1471 je(pretail_label, T_NEAR);
1472 } else {
1473 mov(reg_oi, pretail_n_oi);
1474 je(body_label, T_NEAR);
1475 }
1476 if (pretail_thread != tail_thread) {
1477 cmp(reg_iwb, tail_thread);
1478 je(tail_label, T_NEAR);
1479 }
1480 if (body_n_oi != 0) {
1481 mov(reg_oi, body_n_oi);
1482 jmp(body_label, T_NEAR);
1483 } else {
1484 jmp(end_label, T_NEAR);
1485 }
1486 }
1487 L(head_label);
1488 if (l_overflow > 0) {
1489 compute_loop(ur_w, l_overflow, 0);
1490 if (threaded && head_n_oi == 0 && head_thread != pretail_thread)
1491 jmp(end_label, T_NEAR);
1492 add(reg_src, src_shift);
1493 add(reg_dst, dst_shift);
1494 }
1495 L(body_label);
1496 if (n_oi > 0) {
1497 Label ow_loop_label;
1498 L(ow_loop_label);
1499 {
1500 compute_loop(ur_w, body_l_overflow, body_r_overflow);
1501 if (n_oi > 1 || r_overflow1 > 0 || ur_w_tail != 0) {
1502 add(reg_src, src_shift);
1503 add(reg_dst, dst_shift);
1504 }
1505 if (n_oi > 1) {
1506 sub(reg_oi, 1);
1507 jg(ow_loop_label, T_NEAR);
1508 }
1509 }
1510 }
1511 if (threaded) {
1512 cmp(reg_iwb, pretail_thread);
1513 jne(end_label, T_NEAR);
1514 }
1515 L(pretail_label);
1516 if (r_overflow1 > 0) {
1517 compute_loop(ur_w, 0, r_overflow1);
1518 if (ur_w_tail != 0) {
1519 if (threaded && tail_thread != pretail_thread)
1520 jmp(end_label, T_NEAR);
1521 else {
1522 add(reg_src, src_shift);
1523 add(reg_dst, dst_shift);
1524 }
1525 }
1526 }
1527 L(tail_label);
1528 if (ur_w_tail != 0) { compute_loop(ur_w_tail, 0, r_overflow); }
1529 L(end_label);
1530
1531 postamble();
1532}
1533
1534status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(jit_conv_conf_t &jcp,
1535 const convolution_desc_t &cd, memory_desc_t &diff_src_md,
1536 memory_desc_t &weights_md, memory_desc_t &diff_dst_md, int nthreads) {
1537
1538 const memory_desc_wrapper diff_src_d(&diff_src_md);
1539 const memory_desc_wrapper weights_d(&weights_md);
1540 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
1541
1542 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
1543 int ndims = diff_src_d.ndims();
1544
1545 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
1546 : bf16_emulation_t::get_isa();
1547 jcp.nthr = nthreads;
1548 jcp.has_vnni = true;
1549 jcp.ndims = ndims;
1550 jcp.prop_kind = cd.prop_kind;
1551
1552 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1553 jcp.mb = diff_src_d.dims()[0];
1554
1555 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1556 jcp.oc_without_padding = jcp.oc;
1557 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
1558
1559 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
1560 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
1561 jcp.iw = diff_src_d.dims()[ndims - 1];
1562 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1563 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
1564 jcp.ow = diff_dst_d.dims()[ndims - 1];
1565
1566 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1567 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1568 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1569
1570 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1571 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1572 jcp.l_pad = cd.padding[0][ndims - 3];
1573
1574 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1575 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1576 jcp.stride_w = cd.strides[ndims - 3];
1577
1578 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1579 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
1580 jcp.dilate_w = cd.dilates[ndims - 3];
1581 jcp.dst_dt = cd.diff_src_desc.data_type;
1582 jcp.nb_iw = 1;
1583 jcp.iw_block = jcp.iw;
1584
1585 /* Dilated convolutions supported with unit strides only */
1586 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
1587 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
1588 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
1589 return status::unimplemented;
1590
1591 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1592 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1593 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1594 jcp.r_pad = calculate_end_padding(
1595 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1596 jcp.b_pad = calculate_end_padding(
1597 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1598 jcp.back_pad = calculate_end_padding(
1599 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1600 bool kernel_outside_src = false || ext_kw <= jcp.l_pad
1601 || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
1602 || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
1603 if (kernel_outside_src) return status::unimplemented;
1604
1605 jcp.aligned_threads = 0;
1606
1607 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
1608 const auto dat_tag_nCx4c = pick(ndims - 3, nCw4c, nChw4c, nCdhw4c);
1609 const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
1610 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1611 auto curr_src_tag = diff_src_d.matches_one_of_tag(
1612 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1613 auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
1614 dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_nCx4c);
1615 bool is_data_layout_nxc
1616 = IMPLICATION(curr_src_tag != dat_tag_nxc,
1617 diff_src_d.format_kind() == format_kind::any)
1618 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
1619 diff_dst_d.format_kind() == format_kind::any)
1620 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
1621
1622 bool ok_to_pad_channels = jcp.ngroups == 1 && !is_data_layout_nxc;
1623
1624 jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
1625
1626 const bool ok_to_try_lower_zmm = true
1627 && IMPLICATION(is_data_layout_nxc,
1628 jcp.oc < jcp.simd_w && jcp.ic < jcp.simd_w
1629 && jcp.ngroups > 1)
1630 && !ok_to_pad_channels
1631 && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0);
1632
1633 if (ok_to_try_lower_zmm) {
1634 for (auto simd : {8, 4}) {
1635 if (jcp.ic % simd == 0 && jcp.oc % simd == 0) {
1636 jcp.simd_w = simd;
1637 break;
1638 }
1639 }
1640 }
1641
1642 jcp.oc_block = jcp.simd_w;
1643 jcp.ic_block = jcp.simd_w;
1644
1645 if (ok_to_pad_channels) {
1646 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1647 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1648 }
1649
1650 if (!IMPLICATION(!is_data_layout_nxc,
1651 jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0))
1652 return status::unimplemented;
1653 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.simd_w : 0;
1654 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.simd_w : 0;
1655
1656 format_tag_t wei_tag, dat_tag;
1657
1658 if (jcp.simd_w == 8) {
1659 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
1660 wei_tag = utils::pick(ndims - 3, gOIw4o8i2o, gOIhw4o8i2o, gOIdhw4o8i2o);
1661 } else if (jcp.simd_w == 4) {
1662 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx4c;
1663 wei_tag = utils::pick(ndims - 3, gOIw2o4i2o, gOIhw2o4i2o, gOIdhw2o4i2o);
1664 } else {
1665 dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
1666 wei_tag = pick(2 * ndims - 6 + with_groups, OIw8o16i2o, gOIw8o16i2o,
1667 OIhw8o16i2o, gOIhw8o16i2o, OIdhw8o16i2o, gOIdhw8o16i2o);
1668 }
1669
1670 if (diff_src_md.format_kind == format_kind::any) {
1671 CHECK(memory_desc_init_by_tag(diff_src_md, dat_tag));
1672 } else if (curr_src_tag != dat_tag)
1673 return status::unimplemented;
1674 jcp.src_tag = dat_tag;
1675
1676 if (diff_dst_md.format_kind == format_kind::any) {
1677 CHECK(memory_desc_init_by_tag(diff_dst_md, dat_tag));
1678 } else if (curr_dst_tag != dat_tag)
1679 return status::unimplemented;
1680 jcp.dst_tag = dat_tag;
1681
1682 if (weights_md.format_kind == format_kind::any) {
1683 CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
1684 jcp.wei_tag = wei_tag;
1685 } else {
1686 jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
1687 if (jcp.wei_tag != wei_tag) return status::unimplemented;
1688 }
1689
1690 bool args_ok = true && jcp.ic <= diff_src_d.padded_dims()[1]
1691 && jcp.oc <= diff_dst_d.padded_dims()[1]
1692 && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
1693 && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
1694 if (!args_ok) return status::unimplemented;
1695
1696 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
1697 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
1698
1699 jcp.ur_w = jcp.stride_w;
1700
1701 /* Maximum number of registers available for result accumulation and delta
1702 dst data. One additional register is reserved for weights data. */
    const int max_regs
            = isa_has_bf16(jcp.isa) ? 31 : 26; /* When bf16 instructions are
                                                  emulated, 5 additional
                                                  registers are reserved */
1707 int l_overflow = nstl::max(
1708 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
1709
1710 jcp.typesize_in = types::data_type_size(diff_dst_d.data_type());
1711 jcp.typesize_out = types::data_type_size(diff_src_d.data_type());
1712
    /* Find the blocking that maximizes the number of compute instructions
       per ur_w * nb_ic_blocking compute loop. The number of required
       registers is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w
       <= max_regs, and ur_w must be divisible by stride_w. */
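    /* Illustrative example (hypothetical values): with stride_w = 2 and
       nb_ic_blocking = 2, ur_w = 14 would need 14 * 2 + 14 / 2 = 35 > 31
       registers and is rejected by the search below, while ur_w = 12 needs
       12 * 2 + 12 / 2 = 30 registers and fits. */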
    if (jcp.stride_w + 1 > max_regs) /* The minimal possible register
                                        distribution exceeds max_regs */
1719 return status::unimplemented;
1720
1721 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1722 {
1723 jcp.kernel_kind = expl_bcast;
1724 int best_compute_pipeline_length = 0;
1725 const int max_ic_blocks = 4;
1726 for (int b = 1; b <= max_ic_blocks; b++) {
1727 if (jcp.nb_ic % b != 0) continue;
1728
1729 for (int u = jcp.stride_w; u * b + u / jcp.stride_w <= max_regs
1730 && u < jcp.iw + jcp.stride_w;
1731 u += jcp.stride_w) {
1732 int ur_w = nstl::min(u, jcp.iw);
1733 /* maximum 1 step with l_overflow so far */
1734 if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1735 continue;
1736 int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b;
1737 if (pipeline_length > best_compute_pipeline_length
1738 || (pipeline_length == best_compute_pipeline_length
1739 && jcp.ur_w < ur_w)) {
1740 jcp.ur_w = ur_w;
1741 jcp.nb_ic_blocking = b;
1742 best_compute_pipeline_length = pipeline_length;
1743 }
1744 }
1745 }
        if (best_compute_pipeline_length == 0) /* could not find an
                                                  appropriate blocking */
1748 return status::unimplemented;
1749 }
1750 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
1751
1752 if (is_iw_threading_available(jcp)) {
1753 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
1754 int work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
1755 float no_iw_block_eff
1756 = (float)work_units / rnd_up(work_units, jcp.nthr);
1757
1758 // current design of generate() requires iw_block >= 2 * ur_w
1759 const int min_iw_block = jcp.ur_w * 2;
1760 int iw_threads = jcp.nthr / math::gcd(work_units, jcp.nthr);
1761 int iw_block = nstl::max(min_iw_block,
1762 rnd_up(jcp.iw, jcp.ur_w * iw_threads) / iw_threads);
1763 int nb_iw = div_up(jcp.iw, iw_block);
1764
1765 float block_eff = (float)jcp.iw / rnd_up(jcp.iw, iw_block);
1766 work_units = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih * nb_iw;
1767 float work_eff = (float)work_units / rnd_up(work_units, jcp.nthr);
1768 float iw_block_eff = block_eff * work_eff;
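        // Illustrative example (hypothetical values): with jcp.iw = 112,
        // iw_block = 64 and jcp.nthr = 28, block_eff = 112 / 128 = 0.875;
        // if the recomputed work_units is 56, work_eff = 56 / 56 = 1.0,
        // giving iw_block_eff = 0.875.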
1769
1770 const int iw_thread_min_size = 16 * 128;
1771 const float iw_block_cost = 20.0;
1772 float block_overhead = nstl::max(0.0f, 1.0f - iw_block_cost / iw_block);
1773
1774 bool iw_thread_useful = no_iw_block_eff < block_overhead * iw_block_eff
1775 && jcp.ic_block * jcp.iw > iw_thread_min_size;
1776
1777 if (iw_thread_useful) {
1778 jcp.iw_block = iw_block;
1779 jcp.nb_iw = nb_iw;
1780 }
1781 }
1782
1783 if (l_overflow * jcp.stride_w > jcp.ur_w) return status::unimplemented;
1784 int r_overflow_no_tail = nstl::max(0,
1785 ((jcp.kw - 1) * (jcp.dilate_w + 1)
1786 - nstl::max(0, jcp.r_pad + jcp.ur_w_tail))
1787 / jcp.stride_w);
1788 bool tails_not_ok = false
1789 /* maximum 1 ur_w block with r_overflow so far */
1790 || r_overflow_no_tail * jcp.stride_w > jcp.ur_w
1791 /* ur_w must be a multiple of stride */
1792 || ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1793 /* r_pad must not extend beyond ur_w_tail */
1794 || ((jcp.iw > jcp.ur_w) && (jcp.r_pad + jcp.ur_w_tail < 0));
1795 if (tails_not_ok) return status::unimplemented;
1796
    /* Adjust the thread decomposition to improve performance for small
     * problem sizes. The threshold L1_cache_size / factor is empirical;
     * for now the thread count is simply capped at 4.
     * TODO: add a get_thr_eff function to compute the optimal thread count. */
1802 size_t wei_size = (size_t)sizeof(bfloat16_t) * jcp.ic * jcp.oc * jcp.kh
1803 * jcp.kw * jcp.kd;
1804 size_t inp_size = (size_t)jcp.typesize_in * jcp.mb * jcp.ic * jcp.ih
1805 * jcp.iw * jcp.id;
1806 size_t out_size = (size_t)jcp.typesize_out * jcp.mb * jcp.oc * jcp.oh
1807 * jcp.ow * jcp.od;
1808 size_t total_size = jcp.ngroups * (wei_size + inp_size + out_size);
1809 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1810
    // The factor is 1 for 1d, 2 for 2d, and 4 for 3d convolutions.
1812 int factor = nstl::max(1, (2 * (ndims - 3)));
1813 if (jcp.ngroups < jcp.nthr && total_size < L1_cache_size / factor) {
1814 jcp.nthr = nstl::min(jcp.nthr, 4);
1815 }
1816
1817 pick_loop_order(jcp);
1818
1819 return status::success;
1820}
1821
1822const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
1823
1824void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1825 od_step_comeback_pointers() {
1826 Label kd_comeback_label;
1827 mov(kj, reg_kd_count);
1828 L(kd_comeback_label);
1829 {
1830 sub(reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
1831 sub(reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
1832 dec(kj);
1833 cmp(kj, 0);
1834 jg(kd_comeback_label, T_NEAR);
1835 }
1836}
1837void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1838 oh_step_comeback_pointers() {
1839 Label kh_comeback_label;
1840 mov(kj, reg_kh);
1841 L(kh_comeback_label);
1842 {
1843 sub(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
1844 sub(reg_kernel, get_kernel_offset(0, jcp.kw));
1845 dec(kj);
1846 cmp(kj, 0);
1847 jg(kh_comeback_label, T_NEAR);
1848 }
1849}
1850
1851void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1852 compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r,
1853 int ic_block_step, int src_offset, int kernel_offset,
1854 int ddst_offset, bool is_tail) {
1855 assert(!is_src_layout_nxc() && !is_ddst_layout_nxc());
1856 int kw = jcp.kw;
1857 bool no_src_pad = jcp.is_1stconv && !jcp.transpose_src;
1858 const int ddst_zmm_base_idx = 24;
1859 const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1860 const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1861
1862 auto zmm_ker = [=](int i_kw, int i_ic) {
1863 return Zmm(i_kw * ic_block_step + i_ic);
1864 };
1865 auto zmm_ddst = [=](int i_iw) {
1866 // TODO: move reg calc to global member funcs
1867 return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
1868 };
1869
1870 auto ker_addr = [=](int i_kw, int i_ic) {
1871 auto local_offset = get_kernel_offset(i_ic, i_kw);
1872 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
1873 };
1874 auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
1875 bool vnni_bcast = false) {
1876 auto local_offset = get_src_offset(i_ic, i_iw);
1877 return EVEX_compress_addr(
1878 reg_src, local_offset + src_offset + extra_offset, vnni_bcast);
1879 };
1880 auto ddst_addr = [=](int i_ur) {
1881 auto ow_scale = 2;
1882 return EVEX_compress_addr(
1883 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
1884 };
1885
1886 for (int i_kw = 0; i_kw < kw; i_kw++)
1887 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1888 auto zmm = zmm_ker(i_kw, i_ic);
1889 vpxord(zmm, zmm, zmm);
1890 }
1891 assert(ur_w % 2 == 0);
1892 auto steps = ur_w / 2;
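    // Two adjacent ow positions are processed per step since vdpbf16ps
    // consumes pairs of bf16 elements and accumulates into f32.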
1893
1894 const int str_w = jcp.stride_w;
1895 const int underflow_boundary = -1;
1896 int i_iw_shift = jcp.tr_ow - ur_w - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
1897 const int overflow_boundary = jcp.iw - 1 - i_iw_shift;
1898
1899 for (int s = 0; s < str_w; s++) {
1900 const int kw_start = s;
1901 assert(jcp.tr_iw % str_w == 0);
1902 const int src_stride_w_shift = jcp.tr_iw / str_w;
1903 for (int i_ur = 0; i_ur < steps; i_ur++) {
1904 auto zmm = zmm_ddst(i_ur);
1905 vmovdqu16(zmm, ddst_addr(i_ur));
1906
1907 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1908 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1909 int i_iw = 2 * i_ur + (i_kw * (jcp.dilate_w + 1)) / str_w
1910 + s * src_stride_w_shift;
1911 bool underflow = false;
1912 bool overflow = false;
1913 if (no_src_pad) {
1914 i_iw = i_iw - pad_l;
1915 underflow = i_iw <= underflow_boundary;
1916 overflow = is_tail && i_iw >= overflow_boundary;
1917 }
1918
1919 auto src = Zmm(zmm_src_reg);
1920 auto acc = zmm_ker(i_kw, i_ic);
1921 auto ddst = zmm_ddst(i_ur);
1922 if (underflow || overflow || !isa_has_bf16(jcp.isa)) {
1923 assert(ddst != src);
1924 assert(acc != src);
1925 }
1926 assert(ddst != acc);
1927 if (underflow || overflow) {
1928 if (underflow && i_iw == underflow_boundary)
1929 vpbroadcastw(src | everyother_shift_mask | T_z,
1930 src_addr(i_iw + 1, i_ic, 0));
1931 else if (overflow && i_iw == overflow_boundary)
1932 vpbroadcastw(src | everyother_mask | T_z,
1933 src_addr(i_iw, i_ic, 0));
1934 else
1935 continue;
1936
1937 if (!isa_has_bf16(jcp.isa))
1938 bf16_emu_->vdpbf16ps(acc, ddst, src);
1939 else
1940 vdpbf16ps(acc, ddst, src);
1941 } else if (!isa_has_bf16(jcp.isa)) {
1942 vpbroadcastd(src, src_addr(i_iw, i_ic, 0));
1943 bf16_emu_->vdpbf16ps(acc, ddst, src);
1944 } else
1945 vdpbf16ps(acc, ddst, src_addr(i_iw, i_ic, 0, true));
1946 }
1947 }
1948 }
1949 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1950 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1951 auto addr = ker_addr(i_kw, i_ic);
1952 auto zmm = zmm_ker(i_kw, i_ic);
1953 vaddps(zmm, zmm, addr);
1954 vmovups(addr, zmm);
1955 }
1956 }
1957 }
1958}
1959
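// Size (in elements) of the row buffer used by the interleave data reorder.
// Illustrative example (hypothetical values): stride_w = 2, ur_w = 10 and
// kw = 3 cover 2 * 9 + 3 = 21 elements, rounded up to the 16-element
// reorder block, i.e. 32 elements.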
1960int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_w_reorder_size(
1961 int ur_w) const {
1962 const int reorder_block = 16;
1963 return rnd_up(jcp.stride_w * (ur_w - 1) + jcp.kw, reorder_block);
1964}
1965int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1966 interleave_w_reorder_bytes(int ur_w) {
1967 return 2 * jcp.typesize_in * interleave_w_reorder_size(ur_w);
1968}
1969int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::interleave_stack_size(
1970 int ur_w, int ic_block_step) {
1971 return ic_block_step * interleave_w_reorder_bytes(ur_w);
1972}
1973void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
1974 compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r,
1975 int ic_block_step, int src_offset, int kernel_offset,
1976 int ddst_offset, bool is_tail) {
1977 // Only supports nchw format src
1978 assert(jcp.is_1stconv && !jcp.transpose_src);
1979 int kw = jcp.kw;
1980 const int ddst_zmm_base_idx = 24;
1981 const int in_zmm_base_idx = 24;
1982 const int num_ddst_zmm_regs = !isa_has_bf16(jcp.isa) ? 2 : 4;
1983 //const int num_in_zmm_regs = 8;
1984 const int zmm_src_reg = ddst_zmm_base_idx + num_ddst_zmm_regs;
1985 const int reorder_block = 16;
1986 const int reorder_size = interleave_w_reorder_size(ur_w);
1987 const int reorder_bytes = interleave_w_reorder_bytes(ur_w);
1988 const int stack_size = interleave_stack_size(ur_w, ic_block_step);
1989 if (stack_size > ic_block_step_stack_size) {
1990 // This is a guard. Ideally it is never used, but is included to defend
1991 // against overlooked edge cases.
1992 assert(stack_size <= ic_block_step_stack_size);
1993 sub(rsp, stack_size - ic_block_step_stack_size);
1994 }
1995
1996 auto zmm_ker = [=](int i_kw, int i_ic) {
1997 return Zmm(i_kw * ic_block_step + i_ic);
1998 };
1999 auto zmm_ddst = [=](int i_iw) {
2000 return Zmm(ddst_zmm_base_idx + i_iw % num_ddst_zmm_regs);
2001 };
2002 auto zmm_in = [=](int i_iw, int i_ic, bool stride_reg) {
2003 int stride = stride_reg ? 1 : 0;
2004 return Zmm(in_zmm_base_idx + 4 * (i_ic % 2) + 2 * (i_iw % 2) + stride);
2005 };
2006
2007 auto ker_addr = [=](int i_kw, int i_ic) {
2008 auto local_offset = get_kernel_offset(i_ic, i_kw);
2009 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2010 };
2011 auto src_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
2012 bool vnni_bcast = false) {
2013 int local_offset = i_ic * reorder_bytes + 2 * jcp.typesize_in * i_iw;
2014 return EVEX_compress_addr(rsp, local_offset, vnni_bcast);
2015 };
2016 auto ddst_addr = [=](int i_ur) {
2017 auto ow_scale = 2;
2018 return EVEX_compress_addr(
2019 reg_ddst, get_ddst_offset(ow_scale * i_ur) + ddst_offset);
2020 };
2021 auto load_src_to_stack = [=](int i_iw, int i_ic, Opmask mask,
2022 bool mask_empty, Opmask stride_mask,
2023 bool stride_mask_empty) {
2024 auto local_offset = get_src_offset(i_ic, i_iw);
2025 int stack_offset
2026 = i_ic * reorder_bytes + 2 * jcp.typesize_in * (i_iw + pad_l);
2027
2028 auto zmm = zmm_in(i_iw, i_ic, false);
2029 auto zmm_stride = zmm_in(i_iw, i_ic, true);
2030 auto base_addr
2031 = EVEX_compress_addr(reg_src, local_offset + src_offset, false);
2032 auto stride_addr = EVEX_compress_addr(reg_src,
2033 local_offset + src_offset + get_src_offset(0, jcp.stride_w));
2034 auto stack_addr = EVEX_compress_addr(rsp, stack_offset);
2035 assert(IMPLICATION(mask_empty, stride_mask_empty));
2036 if (mask_empty) {
2037 vpxord(zmm, zmm, zmm);
2038 } else {
2039 vpmovzxwd(zmm | mask | T_z, base_addr);
2040 }
2041 if (!stride_mask_empty) {
2042 vpmovzxwd(zmm_stride | stride_mask | T_z, stride_addr);
2043 vpslld(zmm_stride, zmm_stride, 16);
2044 vpord(zmm, zmm, zmm_stride);
2045 }
2046 vmovdqu16(stack_addr, zmm);
2047 };
2048
2049 assert(ur_w % 2 == 0);
2050 auto steps = ur_w / 2;
2051
2052 const int str_w = jcp.stride_w;
2053 int i_iw_shift = str_w * (jcp.tr_ow - ur_w)
2054 - ((jcp.l_pad != pad_l) ? jcp.l_pad : 0);
2055 const int overflow_boundary
2056 = is_tail ? jcp.iw - i_iw_shift : str_w * (ur_w - 1) + kw - pad_l;
2057
2058 // Calculate padding required by the data reorder using 32 byte loads
2059 int reorder_overflow = reorder_size - pad_l - overflow_boundary;
2060 int reorder_stride_overflow = reorder_overflow + str_w;
2061 reorder_overflow = nstl::max(0, reorder_overflow);
2062 reorder_stride_overflow = nstl::max(0, reorder_stride_overflow);
2063 int reorder_pad_r = reorder_overflow % reorder_block;
2064 int reorder_stride_pad_r = reorder_stride_overflow % reorder_block;
2065 if (reorder_stride_overflow >= reorder_size && reorder_stride_pad_r == 0) {
2066 assert(reorder_stride_overflow == reorder_size);
2067 reorder_stride_pad_r = reorder_block;
2068 }
2069 reorder_overflow -= reorder_pad_r;
2070 reorder_stride_overflow -= reorder_stride_pad_r;
2071
2072 int pad_l_mask = (0xffff << pad_l) & 0xffff;
2073 int pad_l_mask_strided
2074 = (0xffff << (pad_l >= str_w ? (pad_l - str_w) : 0)) & 0xffff;
2075 int pad_r_mask = 0xffff >> reorder_pad_r;
2076 int pad_r_mask_strided = 0xffff >> (reorder_stride_pad_r);
2077 pad_r_mask = pad_r_mask & 0xffff;
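    // Illustrative example (hypothetical values): pad_l = 3 gives
    // pad_l_mask = 0xfff8, i.e. the three lowest (left-padded) elements of
    // the 16-element load are masked out.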
2078
    // Set up masks to load and reorder the data
2080 if (reorder_size - reorder_stride_overflow > reorder_block) {
2081 // Overflow and underflow happen in different data reorder rounds
2082 kxnorw(overflow_stride_mask, overflow_stride_mask,
2083 overflow_stride_mask);
2084 kshiftlw(underflow_mask, overflow_stride_mask, pad_l);
2085 kshiftlw(underflow_stride_mask, overflow_stride_mask,
2086 pad_l >= str_w ? pad_l - str_w : 0);
2087 kshiftrw(overflow_mask, overflow_stride_mask, reorder_pad_r);
2088 kshiftrw(overflow_stride_mask, overflow_stride_mask,
2089 reorder_stride_pad_r);
2090 } else if (reorder_size - reorder_overflow > reorder_block) {
2091 // Overflow and underflow happen in the same round for loading the data
2092 // at the stride offset.
2093 kxnorw(overflow_mask, overflow_mask, overflow_mask);
2094 kshiftlw(underflow_mask, overflow_mask, pad_l);
2095 kshiftrw(overflow_mask, overflow_mask, reorder_pad_r);
2096 mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2097 kmovw(underflow_stride_mask, reg_tmp.cvt32());
2098 } else {
2099 // Overflow and underflow happen in the same round for all data loads
2100 mov(reg_tmp.cvt32(), pad_l_mask & pad_r_mask);
2101 kmovw(underflow_mask, reg_tmp.cvt32());
2102 mov(reg_tmp.cvt32(), pad_l_mask_strided & pad_r_mask_strided);
2103 kmovw(underflow_stride_mask, reg_tmp.cvt32());
2104 }
2105
2106 // Load and reorder data to the stack
2107 int reorder_start = -pad_l;
2108 int reorder_end = reorder_size - pad_l;
2109 for (int i_iw = reorder_start; i_iw < reorder_end; i_iw += reorder_block) {
2110 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2111 Opmask mask, stride_mask;
2112 bool mask_empty, stride_mask_empty;
            // Performing this reorder on the stack may not always be optimal.
            // A couple of alternatives that reorder the data externally were
            // not explored due to time constraints: one is to transpose the
            // data as in the extern method, the other is to perform the same
            // interleave transform used here. The trade-off is that the
            // transpose method does not lend itself to SIMD instructions
            // (except possibly for some specific strides) since the data is
            // not blocked, while the transform used here does, but stores
            // twice as much data since most elements are duplicated.
2123
2124 if (i_iw == reorder_start) {
2125 mask = underflow_mask;
2126 mask_empty = false;
2127 if (pad_l_mask == 0) mask_empty = true;
2128 } else if (i_iw + reorder_overflow >= reorder_end) {
2129 mask_empty = true;
2130 } else if (i_iw + reorder_block + reorder_overflow >= reorder_end) {
2131 mask = overflow_mask;
2132 mask_empty = false;
2133 if (pad_r_mask == 0) mask_empty = true;
2134 } else {
2135 mask = m_ffffffff;
2136 mask_empty = false;
2137 }
2138 if (i_iw == reorder_start) {
2139 stride_mask = underflow_stride_mask;
2140 stride_mask_empty = false;
2141 if (pad_l_mask_strided == 0) mask_empty = true;
2142 } else if (i_iw + reorder_stride_overflow >= reorder_end) {
2143 stride_mask_empty = true;
2144 } else if (i_iw + reorder_block + reorder_stride_overflow
2145 >= reorder_end) {
2146 stride_mask = overflow_stride_mask;
2147 stride_mask_empty = false;
2148 if (pad_r_mask_strided == 0) mask_empty = true;
2149 } else {
2150 stride_mask = m_ffffffff;
2151 stride_mask_empty = false;
2152 }
2153 load_src_to_stack(i_iw, i_ic, mask, mask_empty, stride_mask,
2154 stride_mask_empty);
2155 }
2156 }
2157
2158 // Initialize kernel accumulators. It should sometimes be possible to skip
2159 // initializing and storing this data between calls to this function.
2160 for (int i_kw = 0; i_kw < kw; i_kw++)
2161 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2162 auto zmm = zmm_ker(i_kw, i_ic);
2163 vpxord(zmm, zmm, zmm);
2164 }
2165
    // Calculate this block's contribution
2167 for (int i_ur = 0; i_ur < steps; i_ur++) {
2168 auto zmm = zmm_ddst(i_ur);
2169 vmovdqu16(zmm, ddst_addr(i_ur));
2170
2171 for (int i_kw = 0; i_kw < kw; i_kw++) {
2172 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2173 int i_iw = 2 * i_ur * str_w + i_kw;
2174 auto acc = zmm_ker(i_kw, i_ic);
2175 auto ddst = zmm_ddst(i_ur);
2176
2177 const bool isa_supports_bf16 = isa_has_bf16(jcp.isa);
2178 auto src_stack_addr
2179 = src_addr(i_iw, i_ic, 0, isa_supports_bf16);
2180
2181 if (isa_supports_bf16)
2182 vdpbf16ps(acc, ddst, src_stack_addr);
2183 else {
2184 auto src = Zmm(zmm_src_reg);
2185 vpbroadcastd(src, src_stack_addr);
2186 bf16_emu_->vdpbf16ps(acc, ddst, src);
2187 }
2188 }
2189 }
2190 }
2191
2192 // Store kernel accumulators
2193 for (int i_kw = 0; i_kw < kw; i_kw++) {
2194 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2195 auto addr = ker_addr(i_kw, i_ic);
2196 auto zmm = zmm_ker(i_kw, i_ic);
2197 vaddps(zmm, zmm, addr);
2198 vmovups(addr, zmm);
2199 }
2200 }
2201
2202 if (stack_size > ic_block_step_stack_size) {
2203 // This is a guard. Ideally it is never used, but is included to defend
2204 // against overlooked edge cases.
2205 add(rsp, stack_size - ic_block_step_stack_size);
2206 }
2207}
2208
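// Reorders up to ur_w rows of src into the interleaved vnni layout in the
// stack buffer using vpermw, so that the compute_ic_block_step_vpermw*
// kernels can feed vdpbf16ps directly from that buffer.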
2209void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2210 convert_src_to_vnni_format(
2211 int ur_w, int pad_l, int pad_r, int src_offset) {
2212 Reg64 reg_trans_tmp = r11;
2213 const int ic_tail = jcp.ic_tail;
2214 mov(EVEX_compress_addr(rsp, trans_tmp_offset), reg_trans_tmp);
2215
2216 mov(reg_trans_tmp, dst_prm_table);
2217 vmovups(get_perm_reg(), ptr[reg_trans_tmp]);
2218
2219 mov(reg_trans_tmp, EVEX_compress_addr(rsp, trans_tmp_offset));
2220 const int max_regs = 16;
2221 if (ic_tail) {
2222 Label skip_tail_mask;
2223 cmp(reg_icb, jcp.simd_w);
2224 jge(skip_tail_mask);
2225 kandd(m_0000ffff, m_0000ffff, m_0000_ic_tail);
2226 kandd(m_ffff0000, m_ffff0000, m_ic_tail_0000);
2227 L(skip_tail_mask);
2228 }
2229 for (int src_count = 0;
2230 sizeof_cacheline * src_count < permw_stack_size(ur_w);
2231 src_count++) {
2232 int i_ur = nstl::min(src_count, ur_w - 2);
2233 int i_kw = src_count - i_ur;
2234 int buffer_offset = permw_buffer_start + src_count * 64;
2235 auto bcast_values = Zmm(src_count % max_regs);
2236 bool check = check_borders(ur_w, pad_l, pad_r, i_ur, i_kw);
2237 if (check) {
2238 if (is_src_layout_nxc()) {
2239 int iw_1, iw_2;
2240 get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);
2241 if (iw_1 == -1)
2242 vxorpd(bcast_values, bcast_values, bcast_values);
2243 else {
2244 dim_t local_src_offset = src_offset
2245 + get_src_offset(
2246 0, filter_w_to_src(i_kw, i_ur, pad_l));
2247 vmovdqu16(bcast_values | m_0000ffff | T_z,
2248 ptr[reg_src + local_src_offset]);
2249 }
2250 if (iw_2 != -1) {
2251 dim_t local_src_offset = src_offset - 32
2252 + get_src_offset(
2253 0, filter_w_to_src(i_kw, i_ur + 1, pad_l));
2254 vmovdqu16(bcast_values | m_ffff0000,
2255 ptr[reg_src + local_src_offset]);
2256 }
2257 } else {
2258 Opmask load_mask;
2259 get_load_mask(ur_w, pad_l, pad_r, i_ur, i_kw, load_mask);
2260
2261 dim_t local_src_offset = src_offset
2262 + get_src_offset(0, filter_w_to_src(i_kw, i_ur, pad_l));
2263 vmovdqu16(bcast_values | load_mask | T_z,
2264 ptr[reg_src + local_src_offset]);
2265 }
2266 vpermw(bcast_values, get_perm_reg(), bcast_values);
2267 } else {
2268 vpxord(bcast_values, bcast_values, bcast_values);
2269 }
2270 vmovups(ptr[rsp + buffer_offset], bcast_values);
2271 }
2272 if (ic_tail) {
        // Reset the masks back to their default values
2274 kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2275 kshiftld(m_ffff0000, m_0000ffff, 16);
2276 }
2277}
2278
2279void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2280 may_be_set_oc_tail_mask() {
2281 if (jcp.oc_tail) {
2282 Label skip_tail_mask;
2283 cmp(dword[param + GET_OFF(load_work)], jcp.simd_w);
2284 jge(skip_tail_mask);
2285 kandd(m_0000ffff, m_0000ffff, m_0000_oc_tail);
2286 kandd(m_ffff0000, m_ffff0000, m_oc_tail_0000);
2287 L(skip_tail_mask);
2288 }
2289}
2290
2291void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2292 may_be_reset_oc_tail_mask() {
2293 if (jcp.oc_tail) {
        // Reset the masks back to their default values
2295 kxnorw(m_0000ffff, m_0000ffff, m_0000ffff);
2296 kshiftld(m_ffff0000, m_0000ffff, 16);
2297 }
2298}
2299
2300void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2301 compute_ic_block_step_vpermw_expl(int ur_w, int pad_l, int pad_r,
2302 int ic_block_step, int src_offset, int kernel_offset,
2303 int ddst_offset, bool is_tail) {
2304 assert(!jcp.is_1stconv); // This method does not support nchw data
2305 int kw = jcp.kw;
2306 int src_count = 0;
2307 int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2308 const int max_regs = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
2309 int src_pl_len = kw;
2310 const int diff_dst_pl_start_reg_idx = ic_block_step * (kw + src_pl_len);
2311 const int diff_dst_pl_len = max_regs - diff_dst_pl_start_reg_idx;
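    // Illustrative register partition (hypothetical values): with kw = 3 and
    // ic_block_step = 2, diff-weight accumulators occupy zmm0-zmm5, broadcast
    // src values occupy zmm6-zmm11, and the remaining registers up to
    // max_regs hold permuted diff_dst data.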
2312
2313 auto get_diff_wei_reg_idx
2314 = [=](int i_kw, int i_ic) { return i_kw * ic_block_step + i_ic; };
2315 auto get_src_reg_idx = [=](int i_iw, int i_ic) {
2316 return kw * ic_block_step + (i_iw % src_pl_len) * ic_block_step + i_ic;
2317 };
2318 auto get_diff_dst_reg_idx = [=](int i_ur) {
2319 return diff_dst_pl_start_reg_idx + (i_ur / 2) % diff_dst_pl_len;
2320 };
2321
2322 may_be_set_oc_tail_mask();
2323 auto load_dst = [=](int c) {
2324 bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2325 bool is_ddst_nxc = is_ddst_layout_nxc();
2326 auto offset = get_ddst_offset(c * 2) + ddst_offset;
2327
2328 Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2329 vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | load_mask | T_z,
2330 EVEX_compress_addr(reg_ddst, offset));
2331
2332 if (is_ddst_nxc && !is_tail) {
2333 offset += get_ddst_offset(1) - 32;
2334 vmovdqu16(Zmm(get_diff_dst_reg_idx(2 * c)) | m_ffff0000,
2335 EVEX_compress_addr(reg_ddst, offset));
2336 }
2337 vpermw(Zmm(get_diff_dst_reg_idx(2 * c)), get_perm_reg(),
2338 Zmm(get_diff_dst_reg_idx(2 * c)));
2339 };
2340
2341 for (int i_kw = 0; i_kw < kw; i_kw++)
2342 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2343 vpxord(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2344 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2345 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2346
2347 auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2348 int scale = 2 * jcp.typesize_in;
2349 return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2350 + jcp.typesize_in * 2
2351 * (ic_block_step_idx * ic_block_step + ic);
2352 };
2353 int src_count_last = 0;
2354 for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2355 if (i_ur == 0) {
2356 for (int dst_count = 0;
2357 dst_count < nstl::min(diff_dst_pl_len, div_up(ur_w, 2));
2358 dst_count++) {
2359 load_dst(dst_count);
2360 }
2361 for (src_count = 0; src_count < src_pl_len; src_count++) {
2362 int _i_ur = src_count / kw;
2363 int _i_kw = src_count % kw;
2364 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2365 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2366 vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2367 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2368 }
2369 }
2370 src_count_last = src_count;
2371 } else {
2372 int diff_dst_load_idx = i_ur + 2 * (diff_dst_pl_len - 1);
2373 if (diff_dst_load_idx < ur_w) load_dst(diff_dst_load_idx / 2);
2374 for (src_count = i_ur; src_count < i_ur + src_pl_len; src_count++) {
2375 if (src_count < src_count_last) continue;
2376 int _i_ur = (src_count - i_ur) / kw + i_ur;
2377 int _i_kw = (src_count - i_ur) % kw;
2378 if (check_borders(ur_w, pad_l, pad_r, _i_ur, _i_kw))
2379 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2380 vbroadcastss(Zmm(get_src_reg_idx(src_count, i_ic)),
2381 ptr[get_bcast_ptr(_i_ur, _i_kw, i_ic)]);
2382 }
2383 }
2384 src_count_last = src_count;
2385 }
2386 for (int i_kw = 0; i_kw < kw; i_kw++) {
2387 int i_iw = i_ur + i_kw;
2388 if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2389 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2390 if (!isa_has_bf16(jcp.isa)) {
2391 bf16_emu_->vdpbf16ps(
2392 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2393 Zmm(get_diff_dst_reg_idx(i_ur)),
2394 Zmm(get_src_reg_idx(i_iw, i_ic)));
2395 } else {
2396 vdpbf16ps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2397 Zmm(get_diff_dst_reg_idx(i_ur)),
2398 Zmm(get_src_reg_idx(i_iw, i_ic)));
2399 }
2400 }
2401 }
2402 }
2403 }
2404
2405 for (int i_kw = 0; i_kw < kw; i_kw++)
2406 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2407 auto l_offset = get_kernel_offset(i_ic, i_kw);
2408 vaddps(Zmm(get_diff_wei_reg_idx(i_kw, i_ic)),
2409 EVEX_compress_addr(reg_kernel, l_offset + kernel_offset));
2410 }
2411
2412 for (int i_kw = 0; i_kw < kw; i_kw++) {
2413 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2414 auto l_offset = get_kernel_offset(i_ic, i_kw);
2415 vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2416 Zmm(get_diff_wei_reg_idx(i_kw, i_ic)));
2417 }
2418 }
2419
2420 may_be_reset_oc_tail_mask();
2421}
2422
2423void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2424 compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
2425 int ic_block_step, int src_offset, int kernel_offset,
2426 int ddst_offset, bool is_tail) {
2427 assert(!jcp.is_1stconv); // This method does not support nchw data
2428 int kw = jcp.kw;
2429
2430 int dst_count = 0;
2431
2432 int ic_block_step_idx = src_offset / (jcp.typesize_in * ic_block_step);
2433
2434 int pipeline_length = (isa_has_bf16(jcp.isa))
2435 ? nstl::max(1, nstl::min(4, ur_w / 2))
2436 : 1;
2437 may_be_set_oc_tail_mask();
2438
2439 const int dst_off_reg = (!isa_has_bf16(jcp.isa)) ? 26 : 31;
2440 auto load_dst = [=](int c) {
2441 bool is_tail = ur_w % 2 && c * 2 + 2 >= ur_w;
2442 bool is_ddst_nxc = is_ddst_layout_nxc();
2443 auto offset = get_ddst_offset(2 * c) + ddst_offset;
2444
2445 Opmask load_mask = is_ddst_nxc || is_tail ? m_0000ffff : m_ffffffff;
2446 vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | load_mask | T_z,
2447 EVEX_compress_addr(reg_ddst, offset));
2448
2449 if (is_ddst_nxc && !is_tail) {
2450 offset += get_ddst_offset(1) - 32;
2451 vmovdqu16(Zmm(dst_off_reg - c % pipeline_length) | m_ffff0000,
2452 EVEX_compress_addr(reg_ddst, offset));
2453 }
2454 vpermw(Zmm(dst_off_reg - c % pipeline_length), get_perm_reg(),
2455 Zmm(dst_off_reg - c % pipeline_length));
2456 };
2457
2458 for (int i_kw = 0; i_kw < kw; i_kw++)
2459 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2460 vmovups(Zmm(i_kw * ic_block_step + i_ic),
2461 EVEX_compress_addr(reg_kernel,
2462 get_kernel_offset(i_ic, i_kw) + kernel_offset));
2463
2464 for (dst_count = 0; dst_count < pipeline_length; dst_count++) {
2465 load_dst(dst_count);
2466 }
2467 auto get_bcast_ptr = [=](int i_ur, int i_kw, int ic) {
2468 int scale = 2 * jcp.typesize_in;
2469 return rsp + b_ic * scale + permw_buffer_start + (i_ur + i_kw) * 64
2470 + jcp.typesize_in * 2
2471 * (ic_block_step_idx * ic_block_step + ic);
2472 };
2473
2474 for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
2475 for (int i_kw = 0; i_kw < kw; i_kw++) {
2476 if (check_borders(ur_w, pad_l, pad_r, i_ur, i_kw)) {
2477 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2478 if (!isa_has_bf16(jcp.isa)) {
2479 auto zmm_src = Zmm(28);
2480 vpbroadcastd(
2481 zmm_src, ptr[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2482 bf16_emu_->vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2483 Zmm(dst_off_reg - dst_count % pipeline_length),
2484 zmm_src);
2485 } else {
2486 vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
2487 Zmm(dst_off_reg - dst_count % pipeline_length),
2488 zword_b[get_bcast_ptr(i_ur, i_kw, i_ic)]);
2489 }
2490 }
2491 }
2492 }
2493 if (dst_count * 2 < ur_w) load_dst(dst_count);
2494 dst_count++;
2495 }
2496 for (int i_kw = 0; i_kw < kw; i_kw++) {
2497 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2498 auto l_offset = get_kernel_offset(i_ic, i_kw);
2499 vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
2500 Zmm(i_kw * ic_block_step + i_ic));
2501 }
2502 }
2503
2504 may_be_reset_oc_tail_mask();
2505}
2506
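// Broadcasts a bf16 value of 1 into vreg_bias_unit so that vdpbf16ps can be
// used to reduce pairs of diff_dst values into the f32 bias accumulator
// (acc += ddst_even * 1 + ddst_odd * 1 per lane).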
2507void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2508 compute_diff_bias_init() {
2509 auto reg_unit_val = reg_tmp.cvt16();
2510 mov(reg_unit_val, 0x3f80); // bf16 value of 1.
2511 vpbroadcastw(vreg_bias_unit, reg_unit_val);
2512
2513 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2514 vmovups(vreg_bias_acc, ptr[reg_tmp]);
2515
2516 if (jcp.uses_permw_transposition) {
2517 mov(reg_tmp, dst_prm_table);
2518 vmovups(get_perm_reg(), ptr[reg_tmp]);
2519 }
2520}
2521
2522void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_diff_bias_row(
2523 bool is_partial) {
2524 if (!jcp.with_bias) return;
2525 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2526 Label skip_label;
2527 test(reg_tmp, FLAG_IC_FIRST);
2528 jz(skip_label, T_NEAR);
2529
2530 may_be_set_oc_tail_mask();
2531
2532 if (is_partial) compute_diff_bias_init();
2533
2534 auto compute_step = [&](bool is_tail) {
2535 if (jcp.transpose_dst) {
2536 UNUSED(is_tail);
2537 vmovups(vreg_bias_ddst, ptr[reg_ddst]);
2538 } else {
2539 auto vreg_ddst_load = is_ddst_layout_nxc() || is_tail
2540 ? vreg_bias_ddst | m_0000ffff | T_z
2541 : vreg_bias_ddst;
2542 vmovdqu16(vreg_ddst_load, ptr[reg_ddst]);
2543 if (is_ddst_layout_nxc() && !is_tail) {
2544 const int shift_16_elems = 16 * jcp.typesize_in;
2545 vmovdqu16(vreg_bias_ddst | m_ffff0000,
2546 ptr[reg_ddst + get_ddst_offset(1) - shift_16_elems]);
2547 }
2548 vpermw(vreg_bias_ddst, get_perm_reg(), vreg_bias_ddst);
2549 }
2550 if (!isa_has_bf16(jcp.isa))
2551 bf16_emu_->vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2552 else
2553 vdpbf16ps(vreg_bias_acc, vreg_bias_ddst, vreg_bias_unit);
2554 };
2555
2556 Label ow_loop, ow_tail;
2557 int niters = jcp.tr_ow / 2;
2558 if (niters > 0) {
2559 mov(reg_tmp, jcp.tr_ow / 2);
2560 L(ow_loop);
2561 compute_step(false);
2562 add(reg_ddst, get_ddst_offset(2));
2563 sub(reg_tmp, 1);
2564 jnz(ow_loop, T_NEAR);
2565 }
2566 if (jcp.tr_ow % 2) compute_step(true);
2567
2568 if (niters > 0) sub(reg_ddst, get_ddst_offset(2 * niters));
2569
2570 if (is_partial) {
2571 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2572 vmovups(ptr[reg_tmp], vreg_bias_acc);
2573 }
2574
2575 may_be_reset_oc_tail_mask();
2576
2577 L(skip_label);
2578}
2579void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
2580 maybe_compute_diff_bias() {
    // In the harness_3d_reduction case diff_bias is computed for every ow
    // row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
2584 if (!jcp.with_bias || jcp.harness == harness_3d_reduction) return;
2585 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
2586
2587 Label skip_label;
2588 test(reg_tmp, FLAG_IC_FIRST);
2589 jz(skip_label, T_NEAR);
2590
2591 switch (jcp.harness) {
2592 case harness_2d_reduction:
2593 mov(reg_oj, ptr[param + GET_OFF(os_index_end)]);
2594 sub(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
2595 break;
2596 case harness_mb_reduction:
2597 case harness_compute_full_spatial: mov(reg_oj, jcp.oh); break;
2598 case harness_3d_reduction:
2599 default: assert(!"Invalid harness type");
2600 }
2601
2602 compute_diff_bias_init();
2603
2604 cmp(reg_oj, 0);
2605 jle(skip_label, T_NEAR); // nothing to do
2606 Label bias_loop;
2607 L(bias_loop);
2608 {
2609 compute_diff_bias_row(false);
2610 add(reg_ddst, get_ddst_offset(0, 1));
2611
2612 sub(reg_oj, 1);
2613 jnz(bias_loop, T_NEAR);
2614 }
2615
2616 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
2617 vmovups(ptr[reg_tmp], vreg_bias_acc);
2618
2619 // restore reg_ddst value
2620 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
2621
2622 L(skip_label);
2623}
2624
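// Dispatches to one of three ic_block_step implementations: the vpermw-based
// kernels when permw transposition is used, the interleave kernel for the
// strided first convolution with plain (nchw) src, and the extern kernel
// otherwise.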
2625void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
2626 int ur_w, int pad_l, int pad_r, int ic_block_step, int src_offset,
2627 int kernel_offset, int ddst_offset, bool is_tail) {
2628
2629 if (jcp.uses_permw_transposition)
2630 if (jcp.kernel_kind == expl_bcast)
2631 compute_ic_block_step_vpermw_expl(ur_w, pad_l, pad_r, ic_block_step,
2632 src_offset, kernel_offset, ddst_offset, is_tail);
2633 else
2634 compute_ic_block_step_vpermw(ur_w, pad_l, pad_r, ic_block_step,
2635 src_offset, kernel_offset, ddst_offset, is_tail);
2636 else if (jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
2637 compute_ic_block_step_interleave(ur_w, pad_l, pad_r, ic_block_step,
2638 src_offset, kernel_offset, ddst_offset, is_tail);
2639 else
2640 compute_ic_block_step_extern(ur_w, pad_l, pad_r, ic_block_step,
2641 src_offset, kernel_offset, ddst_offset, is_tail);
2642}
2643
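// Splits tr_ow into ur_w blocks so that the right-padded tail is handled by
// a large enough last block. Illustrative example (hypothetical values):
// tr_ow = 60 and max_ur_w = 28 give ur_w = 28, ur_w_trips = 2 and
// ur_w_tail = 4; if r_pad >= 4, the tail is merged with one full block,
// giving ur_w_trips = 1 and ur_w_tail = 32.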
2644void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::get_ur_w(
2645 int &ur_w, int &ur_w_tail, int &ur_w_trips) {
2646 if (jcp.tr_ow <= max_ur_w) {
2647 ur_w = jcp.tr_ow;
2648 ur_w_tail = 0;
2649 ur_w_trips = 1;
2650 return;
2651 }
2652
2653 int r_pad = 0;
2654 if (!jcp.transpose_src) {
2655 // If jcp.transpose_src, the buffer has physical padding
2656 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2657 r_pad = nstl::max(0,
2658 calculate_end_padding(
2659 jcp.l_pad, jcp.tr_ow, jcp.tr_iw, jcp.stride_w, ext_kw));
2660 }
2661 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2662 ur_w = max_ur_w;
2663 ur_w_trips = jcp.tr_ow / ur_w;
2664 ur_w_tail = jcp.tr_ow % ur_w;
2665 if ((ur_w_tail == 0 && jcp.r_pad != 0) || r_pad >= ur_w_tail) {
2666 if (ur_w_trips > 1) {
2667 ur_w_tail += ur_w;
2668 ur_w_trips--;
2669 } else {
2670 int ur_w_tail_total = ur_w + ur_w_tail;
2671 ur_w = (ur_w_tail_total % 4 == 0) ? ur_w_tail / 2
2672 : ur_w_tail / 2 + 1;
2673 ur_w_tail = ur_w_tail_total - ur_w;
2674 if (l_pad > ur_w / 2) {
2675 ur_w = (l_pad % 2 == 0) ? l_pad : l_pad + 1;
2676 ur_w_tail = ur_w_tail_total - ur_w;
2677 } else if (r_pad > ur_w_tail) {
2678 ur_w_tail = (r_pad % 2 == 0) ? r_pad : r_pad + 1;
2679 ur_w = ur_w_tail_total - ur_w_tail;
2680 }
2681 }
2682 }
2683}
2684
2685void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2686 compute_oh_step_unroll_ow_icblock(int ic_block_step) {
2687 Label kh_label, kd_label;
2688
2689 int ic_block = jcp.ic_block;
2690 int ic_tail = jcp.ic_tail;
2691 int ow = jcp.tr_ow;
2692 int r_pad = 0;
2693 int ur_w, ur_w_tail, ur_w_trips;
2694 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2695 assert(ur_w_tail == 0 && ur_w_trips == 1);
2696
2697 if (!jcp.transpose_src) {
2698 // If jcp.transpose_src, the buffer has physical padding
2699 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2700 int iw = jcp.tr_iw;
2701 r_pad = nstl::max(0,
2702 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2703 }
2704 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2705
2706 if (jcp.ndims == 5) {
2707 L(kd_label);
2708 mov(reg_src, aux_reg_src);
2709 mov(reg_kernel, aux_reg_kernel);
2710 }
2711
2712 mov(kj, reg_kh);
2713 L(kh_label);
2714 {
2715 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2716 // icb loop is supported for nxc layout only
2717 assert(IMPLICATION(generate_icb_loop,
2718 is_src_layout_nxc() && is_ddst_layout_nxc()));
2719 Label icb_block_label, icb_block_label_end;
2720 if (generate_icb_loop || ic_tail) {
2721 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2722 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2723 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2724 L(icb_block_label);
2725 }
2726
2727 if (jcp.uses_permw_transposition) {
2728 convert_src_to_vnni_format(ur_w, l_pad, r_pad, 0);
2729 xor_(b_ic, b_ic);
2730 }
2731
2732 const int ic_tail_loop_work = rnd_up(ic_tail, ic_block_step);
2733 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
2734 const int src_offset = get_src_offset(i_b_ic, 0);
2735 compute_ic_block_step(ur_w, l_pad, r_pad, ic_block_step, src_offset,
2736 get_kernel_offset(i_b_ic, 0), 0, true);
2737 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
            // We relax the boundary check on reg_icb, as the src has already
            // been converted to the vnni format with appropriate padding,
            // either through transpose_src or convert_src_to_vnni_format. We
            // can safely allow compute_ic_block_step to overstep, as it
            // operates on the buffer instead of on src.
2743 if (ic_tail && i_b_ic + ic_block_step == ic_tail_loop_work) {
2744 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2745 cmp(reg_icb, 0);
2746 jle(icb_block_label_end, T_NEAR);
2747 }
2748 }
2749 L(icb_block_label_end);
2750
2751 const auto src_icb_loop_shift_bytes = get_src_offset(ic_block, 0);
2752 const auto kernel_icb_loop_shift_bytes
2753 = get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw);
2754 if (generate_icb_loop) {
2755 add(reg_src, src_icb_loop_shift_bytes);
2756 safe_add(reg_kernel, kernel_icb_loop_shift_bytes, reg_long_offt);
2757
2758 assert(jcp.uses_permw_transposition);
2759 cmp(reg_icb, 0);
2760 jg(icb_block_label, T_NEAR);
2761 }
2762
2763 if (generate_icb_loop || ic_tail) {
2764 // restore pointers
2765 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2766 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2767 }
2768
2769 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2770 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2771 dec(kj);
2772 cmp(kj, 0);
2773 jg(kh_label, T_NEAR);
2774 }
2775
2776 if (jcp.ndims == 5) {
2777 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2778 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2779 dec(ki);
2780 cmp(ki, 0);
2781 jg(kd_label, T_NEAR);
2782 }
2783}
2784
2785void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::
2786 compute_oh_step_unroll_ow(int ic_block_step) {
2787 Label kh_label, ic_block_label, kd_label;
2788
2789 int ic_block = jcp.ic_block;
2790 const int ic_tail = jcp.ic_tail;
2791 int ow = jcp.tr_ow;
2792
2793 int r_pad = 0;
2794 int ur_w, ur_w_tail, ur_w_trips;
2795 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2796 assert(ur_w_tail == 0 && ur_w_trips == 1);
2797
2798 if (!jcp.transpose_src) {
2799 // If jcp.transpose_src, the buffer has physical padding
2800 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2801 int iw = jcp.tr_iw;
2802 r_pad = nstl::max(0,
2803 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2804 }
2805 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2806
2807 if (jcp.ndims == 5) {
2808 L(kd_label);
2809 mov(reg_src, aux_reg_src);
2810 mov(reg_kernel, aux_reg_kernel);
2811 }
2812
2813 mov(kj, reg_kh);
2814 L(kh_label);
2815 {
2816 size_t src_offset = get_src_offset(ic_block_step, 0);
2817
2818 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
2819 // icb loop is supported for nxc layout only
2820 assert(IMPLICATION(generate_icb_loop,
2821 is_src_layout_nxc() && is_ddst_layout_nxc()));
2822 Label icb_block_label, icb_block_label_end;
2823 if (generate_icb_loop || ic_tail) {
2824 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
2825 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
2826 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
2827 L(icb_block_label);
2828 }
2829
2830 xor_(b_ic, b_ic);
2831 if (jcp.uses_permw_transposition) {
2832 convert_src_to_vnni_format(ow, l_pad, r_pad, 0);
2833 xor_(b_ic, b_ic);
2834 }
2835
2836 L(ic_block_label);
2837 {
2838 compute_ic_block_step(
2839 ur_w, l_pad, r_pad, ic_block_step, 0, 0, 0, true);
2840 assert(jcp.ic_block % jcp.ic_block_step == 0);
2841 safe_add(reg_src, src_offset, reg_long_offt);
2842 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2843 add(b_ic, ic_block_step);
2844 if (generate_icb_loop || ic_tail) sub(reg_icb, ic_block_step);
            // We relax the boundary check on reg_icb, as the src has already
            // been converted to the vnni format with appropriate padding,
            // either through transpose_src or convert_src_to_vnni_format. We
            // can safely allow compute_ic_block_step to overstep, as it
            // operates on the buffer instead of on src.
2850 if (ic_tail) {
2851 assert(jcp.transpose_src || jcp.uses_permw_transposition);
2852 cmp(reg_icb, 0);
2853 jle(icb_block_label_end, T_NEAR);
2854 }
2855 cmp(b_ic, jcp.ic_block);
2856 jl(ic_block_label, T_NEAR);
2857 }
2858 L(icb_block_label_end);
2859
2860 if (jcp.uses_permw_transposition) {
2861 if (generate_icb_loop || ic_tail) {
                // subtract the pointer shift made within the ic block loop
                // and move to the next ic block
2864 safe_add(reg_kernel,
2865 get_kernel_offset(-ic_block, jcp.kd * jcp.kh * jcp.kw),
2866 reg_long_offt);
2867
2868 cmp(reg_icb, 0);
2869 jg(icb_block_label, T_NEAR);
2870 // restore pointers
2871 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2872 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2873
2874 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2875 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2876 } else {
2877 add(reg_src,
2878 get_src_offset(0, 0, filter_h_to_src(1))
2879 - jcp.typesize_in * ic_block);
2880 }
2881 } else if (ic_tail) {
2882 // restore pointers
2883 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
2884 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
2885
2886 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
2887 add(reg_kernel, get_kernel_offset(0, jcp.kw));
2888 } else if (jcp.is_1stconv && !jcp.transpose_src) {
2889 // Fixup reg_src to point to the correct location
2890 safe_add(reg_src,
2891 get_src_offset(0, 0, filter_h_to_src(1))
2892 - src_offset * (jcp.ic_block / ic_block_step),
2893 reg_long_offt);
2894 } else {
2895 if (jcp.dilate_h > 0)
2896 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
2897 }
2898 if (!generate_icb_loop && !ic_tail)
            // subtract the pointer shift made within the ic block loop
            // and move to the next kh index
2901 add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
2902 dec(kj);
2903 cmp(kj, 0);
2904 jg(kh_label, T_NEAR);
2905 }
2906 if (jcp.ndims == 5) {
2907 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
2908 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
2909 dec(ki);
2910 cmp(ki, 0);
2911 jg(kd_label, T_NEAR);
2912 }
2913}
2914
2915void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_common(
2916 int ic_block_step) {
2917 Label kh_label, ic_block_label, ow_block_label, kd_label;
2918
2919 int ic_block = jcp.ic_block;
2920 int ic_tail = jcp.ic_tail;
2921 int ow = jcp.tr_ow;
2922 int r_pad = 0;
2923 if (!jcp.transpose_src) {
2924 // If jcp.transpose_src, the buffer has physical padding
2925 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
2926 int iw = jcp.tr_iw;
2927 r_pad = nstl::max(0,
2928 calculate_end_padding(jcp.l_pad, ow, iw, jcp.stride_w, ext_kw));
2929 }
2930 int l_pad = (jcp.transpose_src) ? 0 : jcp.l_pad;
2931
2932 int ur_w, ur_w_trips, ur_w_tail;
2933 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
2934 assert(l_pad <= ur_w);
2935 assert(r_pad <= ur_w_tail);
2936
2937 auto src_comeback
2938 = get_src_offset(0, filter_w_to_src(0, ur_w_trips * ur_w, l_pad));
2939 auto ddst_comeback = get_ddst_offset(ur_w_trips * ur_w);
2940
2941 if (jcp.ndims == 5) {
2942 L(kd_label);
2943 mov(reg_src, aux_reg_src);
2944 mov(reg_kernel, aux_reg_kernel);
2945 }
2946
2947 bool use_kh_ic_ow_loop_order = !jcp.uses_permw_transposition;
2948 if (use_kh_ic_ow_loop_order) {
2949 assert(!jcp.uses_permw_transposition);
2950
2951 auto ic_loop = [=](int ic_block_step) {
2952 Label ow_block_label;
2953 // create a local copy
2954 int ur_w_blocks = ur_w_trips;
2955 auto src_offset = get_src_offset(ic_block_step, 0);
2956 if (l_pad != 0) {
2957 ur_w_blocks--;
2958 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
2959 add(reg_src,
2960 get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
2961 add(reg_ddst, get_ddst_offset(ur_w));
2962 }
2963
2964 if (ur_w_blocks > 0) {
2965 xor_(reg_ur_w_trips, reg_ur_w_trips);
2966 L(ow_block_label);
2967 {
2968 compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
2969 add(reg_src,
2970 get_src_offset(0, filter_w_to_src(0, ur_w, 0)));
2971 add(reg_ddst, get_ddst_offset(ur_w));
2972
2973 inc(reg_ur_w_trips);
2974 cmp(reg_ur_w_trips, ur_w_blocks);
2975 jl(ow_block_label, T_NEAR);
2976 }
2977 }
2978
2979 if (ur_w_tail > 0) {
2980 compute_ic_block_step(
2981 ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
2982 }
2983
2984 sub(reg_src, src_comeback);
2985 sub(reg_ddst, ddst_comeback);
2986
2987 safe_add(reg_src, src_offset, reg_long_offt);
2988 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
2989 };
2990
2991 mov(kj, reg_kh);
2992 L(kh_label);
2993 {
2994 Label ic_tail_label, skip_ic_tail_offset_compensation;
2995 if (ic_tail) {
                // Currently generate_icb_loop is not enabled here, implying
                // that at most one icb is processed.
2998 assert(jcp.nb_ic_blocking_max == 1);
2999 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3000 } else {
3001 mov(reg_icb, ic_block);
3002 }
3003
3004 L(ic_block_label);
3005 {
3006 ic_loop(ic_block_step);
3007 sub(reg_icb, ic_block_step);
                // We relax the boundary check on reg_icb, as the src has
                // already been converted to the vnni format with appropriate
                // padding, either through transpose_src or
                // convert_src_to_vnni_format. We can safely allow
                // compute_ic_block_step to overstep, as it operates on the
                // buffer instead of on src.
3013 if (ic_tail) {
3014 assert(jcp.transpose_src || jcp.uses_permw_transposition);
3015 }
3016 cmp(reg_icb, 0);
3017 jg(ic_block_label, T_NEAR);
3018 }
3019
3020 if (ic_tail) {
3021 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3022 cmp(reg_icb, jcp.simd_w);
3023 je(skip_ic_tail_offset_compensation);
3024 add(reg_kernel,
3025 get_kernel_offset(
3026 jcp.ic_block - rnd_up(ic_tail, ic_block_step),
3027 0));
3028 safe_add(reg_src,
3029 get_src_offset(0, 0, filter_h_to_src(1))
3030 - get_src_offset(
3031 rnd_up(ic_tail, ic_block_step), 0),
3032 reg_long_offt);
3033 L(skip_ic_tail_offset_compensation);
3034 }
3035 if (jcp.is_1stconv && !jcp.transpose_src) {
3036 // Fixup reg_src to point to the correct location
3037 auto src_offset = get_src_offset(ic_block_step, 0);
3038 safe_add(reg_src,
3039 get_src_offset(0, 0, filter_h_to_src(1))
3040 - src_offset * (jcp.ic_block / ic_block_step),
3041 reg_long_offt);
3042 } else if (jcp.dilate_h > 0) {
3043 add(reg_src, get_src_offset(0, 0, jcp.dilate_h));
3044 }
            // subtract the pointer shift made within the ic block loop
            // and move to the next kh index
3047 add(reg_kernel, get_kernel_offset(-ic_block, jcp.kw));
3048 dec(kj);
3049 cmp(kj, 0);
3050 jg(kh_label, T_NEAR);
3051 }
3052 } else {
3053 assert(!jcp.is_1stconv);
3054 auto src_icbstep_shift = get_src_offset(1, 0);
3055
3056 auto ic_loop = [=](int ic_block_step) {
3057 int ic_work = ic_block;
3058 Label ow_block_label, ic_block_label_padl, ic_block_label_general,
3059 ic_block_label_tail;
3060 int ur_w_blocks = ur_w_trips;
3061 if (l_pad != 0) {
3062 ur_w_blocks--;
3063 xor_(b_ic, b_ic);
3064 if (jcp.uses_permw_transposition) {
3065 convert_src_to_vnni_format(ur_w, l_pad, 0, 0);
3066 }
3067 L(ic_block_label_padl);
3068 {
3069 compute_ic_block_step(
3070 ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
3071 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3072 reg_long_offt);
3073 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3074
3075 add(b_ic, ic_block_step);
3076 cmp(b_ic, ic_work);
3077 jl(ic_block_label_padl, T_NEAR);
3078 }
3079 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3080 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3081 add(reg_src,
3082 get_src_offset(0, filter_w_to_src(0, ur_w, l_pad)));
3083 add(reg_ddst, get_ddst_offset(ur_w));
3084 }
3085
3086 if (ur_w_blocks > 0) {
3087 xor_(reg_ur_w_trips, reg_ur_w_trips);
3088 L(ow_block_label);
3089 {
3090 if (jcp.uses_permw_transposition) {
3091 convert_src_to_vnni_format(ur_w, 0, 0, 0);
3092 }
3093 xor_(b_ic, b_ic);
3094 L(ic_block_label_general);
3095 {
3096 compute_ic_block_step(
3097 ur_w, 0, 0, ic_block_step, 0, 0, 0);
3098 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3099 reg_long_offt);
3100 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3101
3102 add(b_ic, ic_block_step);
3103 cmp(b_ic, ic_work);
3104 jl(ic_block_label_general, T_NEAR);
3105 }
3106 safe_sub(reg_src, src_icbstep_shift * ic_work,
3107 reg_long_offt);
3108 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3109 add(reg_src, get_src_offset(0, filter_w_to_src(0, ur_w)));
3110 add(reg_ddst, get_ddst_offset(ur_w));
3111
3112 inc(reg_ur_w_trips);
3113 cmp(reg_ur_w_trips, ur_w_blocks);
3114 jl(ow_block_label, T_NEAR);
3115 }
3116 }
3117
3118 if (ur_w_tail > 0) {
3119 if (jcp.uses_permw_transposition) {
3120 convert_src_to_vnni_format(ur_w_tail, 0, r_pad, 0);
3121 }
3122 xor_(b_ic, b_ic);
3123 L(ic_block_label_tail);
3124 {
3125 compute_ic_block_step(
3126 ur_w_tail, 0, r_pad, ic_block_step, 0, 0, 0, true);
3127 safe_add(reg_src, src_icbstep_shift * ic_block_step,
3128 reg_long_offt);
3129 add(reg_kernel, get_kernel_offset(ic_block_step, 0));
3130
3131 add(b_ic, ic_block_step);
3132 cmp(b_ic, ic_work);
3133 jl(ic_block_label_tail, T_NEAR);
3134 }
3135 safe_sub(reg_src, src_icbstep_shift * ic_work, reg_long_offt);
3136 sub(reg_kernel, get_kernel_offset(ic_work, 0));
3137 }
3138
3139 sub(reg_src, src_comeback);
3140 sub(reg_ddst, ddst_comeback);
3141 };
3142
3143 mov(kj, reg_kh);
3144 L(kh_label);
3145 {
3146 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3147 // icb loop is supported for nxc layout only
3148 assert(IMPLICATION(generate_icb_loop,
3149 is_src_layout_nxc() && is_ddst_layout_nxc()));
3150 Label icb_block_label, icb_block_label_cb, ic_tail_loop_label;
3151
3152 if (generate_icb_loop) {
3153 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3154 mov(ptr[rsp + icb_loop_src_ptr], reg_src);
3155 }
3156 if (ic_tail || generate_icb_loop)
3157 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3158 L(icb_block_label);
3159
3160 ic_loop(ic_block_step);
3161
3162 if (generate_icb_loop) {
3163 add(reg_src, get_src_offset(ic_block, 0));
3164 safe_add(reg_kernel,
3165 get_kernel_offset(0, jcp.kd * jcp.kh * jcp.kw),
3166 reg_long_offt);
3167 sub(reg_icb, ic_block);
3168 cmp(reg_icb, 0);
3169 jg(icb_block_label, T_NEAR);
3170 }
3171
3172 if (generate_icb_loop) {
3173 // restore pointers
3174 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3175 mov(reg_src, ptr[rsp + icb_loop_src_ptr]);
3176 }
3177
3178 add(reg_src, get_src_offset(0, 0, filter_h_to_src(1)));
3179 add(reg_kernel, get_kernel_offset(0, jcp.kw));
3180 dec(kj);
3181 cmp(kj, 0);
3182 jg(kh_label, T_NEAR);
3183 }
3184 }
3185 if (jcp.ndims == 5) {
3186 add(aux_reg_src, get_src_offset(0, 0, filter_d_to_src(1)));
3187 add(aux_reg_kernel, get_kernel_offset(0, jcp.kh * jcp.kw));
3188 dec(ki);
3189 cmp(ki, 0);
3190 jg(kd_label, T_NEAR);
3191 }
3192}
3193
3194void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_oh_step_disp() {
3195 int ic_block_step = jcp.ic_block_step;
3196
3197 bool too_large_to_unroll = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
3198 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
3199
3200 int ow = jcp.tr_ow;
3201 if (jcp.ndims == 5) {
3202 /* NOTE: reg_kd_count = aux_reg_src = r12. The following order of
3203 * 'movs' must be guaranteed. */
3204 mov(ki, reg_kd_count);
3205 mov(EVEX_compress_addr(rsp, kd_count_offset), reg_kd_count);
3206 mov(aux_reg_src, reg_src);
3207 mov(aux_reg_kernel, reg_kernel);
3208 }
3209 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
3210 compute_oh_step_unroll_ow_icblock(ic_block_step);
3211 } else if (ow <= max_ur_w) {
3212 compute_oh_step_unroll_ow(ic_block_step);
3213 } else {
3214 compute_oh_step_common(ic_block_step);
3215 }
3216
    // In the harness_3d_reduction case diff_bias is computed for every ow
    // row separately, to stay aligned with the od loop in
    // compute_od_loop_common()
3220 if (jcp.harness == harness_3d_reduction) compute_diff_bias_row();
3221 if (jcp.ndims == 5) {
3222 mov(reg_src, aux_reg_src);
3223 mov(reg_kernel, aux_reg_kernel);
3224 mov(reg_kd_count, EVEX_compress_addr(rsp, kd_count_offset));
3225 od_step_comeback_pointers();
3226 } else {
3227 oh_step_comeback_pointers();
3228 }
3229}
3230
3231void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel() {
3232 if (jcp.harness == harness_compute_full_spatial && !jcp.with_bias) return;
3233 Label skip_zeroing, zeroing_loop;
3234
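    // Zeroing is requested via the 'channel' field of the call params;
    // skip it when the field is zero.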
3235 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3236 cmp(reg_tmp, 0);
3237 jz(skip_zeroing, T_NEAR);
3238
3239 Zmm zero = Zmm(0);
3240 vpxord(zero, zero, zero);
3241 if (jcp.with_bias) {
3242 Label skip_bias_zeroing;
3243 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3244 test(reg_tmp, FLAG_IC_FIRST);
3245 jz(skip_bias_zeroing, T_NEAR);
3246
3247 mov(reg_tmp, ptr[param + GET_OFF(bias)]);
3248 vmovups(ptr[reg_tmp], zero);
3249
3250 L(skip_bias_zeroing);
3251 if (jcp.harness == harness_compute_full_spatial)
3252 jmp(skip_zeroing, T_NEAR);
3253 }
3254
3255 const size_t kernel_block_bytes
3256 = get_kernel_offset(0, jcp.kw * jcp.kh * jcp.kd);
3257 Label icb_block_label, icb_block_label_cb;
3258
3259 const bool generate_icb_loop = jcp.nb_ic_blocking_max > 1;
3260 // icb loop is supported for nxc layout only
3261 assert(IMPLICATION(
3262 generate_icb_loop, is_src_layout_nxc() && is_ddst_layout_nxc()));
3263 if (generate_icb_loop) {
3264 mov(ptr[rsp + icb_loop_ker_ptr], reg_kernel);
3265 mov(reg_icb, ptr[param + GET_OFF(reduce_work)]);
3266 L(icb_block_label);
3267 }
3268
3269 xor_(reg_tmp, reg_tmp);
3270 L(zeroing_loop);
3271 {
3272 assert(get_kernel_offset(1, 0) == cpu_isa_traits<avx512_core>::vlen);
3273 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3274 vmovups(ptr[reg_kernel + reg_tmp + get_kernel_offset(ic1, 0)],
3275 zero);
3276 add(reg_tmp, get_kernel_offset(jcp.ic_block, 0));
3277 cmp(reg_tmp, kernel_block_bytes);
3278 jnz(zeroing_loop);
3279 }
3280
3281 if (generate_icb_loop) {
3282 add(reg_kernel, kernel_block_bytes);
3283 sub(reg_icb, jcp.ic_block);
3284 cmp(reg_icb, 0);
3285 jg(icb_block_label, T_NEAR);
3286 // restore pointer
3287 mov(reg_kernel, ptr[rsp + icb_loop_ker_ptr]);
3288 }
3289
3290 L(skip_zeroing);
3291}
3292
3293void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_oh_loop_common(
3294 bool is_partial) {
3295 int b_pad = jcp.b_pad;
3296 int t_pad = jcp.t_pad;
3297 bool is_dilated = jcp.dilate_h != 0;
3298 int dilate_h = jcp.dilate_h + 1;
3299 int stride_h = jcp.stride_h;
3300 auto filter_step_size = get_kernel_offset(0, jcp.kw);
3301 auto src_step_size = get_src_offset(0, 0, 1);
3302 auto ddst_step_size = get_ddst_offset(0, 1);
3303 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_label_end,
3304 oh_tpad_tail_label, oh_tpad_tail_label_end, oh_bpad_label,
3305 oh_bpad_label_end, oh_dilate_label_shift, oh_dilate_label_noshift,
3306 oh_dilate_label_end, oh_dilate_setup_label_shift,
3307 oh_dilate_setup_label_noshift, oh_dilate_setup_label_end;
3308
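    // Split the oh range into a top-padding head [0, oh_head_end), an optional
    // head-overflow region (filter taller than the src), the body, and a
    // bottom-padding tail.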
3309 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
3310 int oh_body_end = div_up(t_pad + jcp.ih - ext_kh + 1, stride_h);
3311 int oh_head_end = nstl::min(div_up(t_pad, stride_h), oh_body_end);
3312 int oh_head_overflow_end = div_up(t_pad, stride_h);
3313 int oh_tail_end = jcp.oh;
3314
3315 int body_src_start_offset = (stride_h - (t_pad % stride_h)) % stride_h;
3316 int ih_body_end
3317 = nstl::max(-t_pad + oh_body_end * stride_h, body_src_start_offset);
3318
3319 if (is_partial)
3320 mov(reg_oj, ptr[param + GET_OFF(os_index_begin)]);
3321 else
3322 xor_(reg_oj, reg_oj);
3323
3324 /* Compute 'top' edge */
3325 if (t_pad > 0) {
3326 if (is_partial) {
3327 cmp(reg_oj, oh_head_overflow_end);
3328 jge(oh_tpad_tail_label_end, T_NEAR);
3329 }
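        // For the first output row: 'underflow' counts filter rows cut off by
        // t_pad, 'overflow' counts filter rows falling past the src bottom,
        // and initial_kh is the number of rows that actually overlap the src.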
3330 const int overflow
3331 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3332 const int underflow = div_up(t_pad, dilate_h);
3333 const int initial_kh = jcp.kh - overflow - underflow;
3334
3335 // Setup reg_kh, reg_kernel, and reg_src
3336 mov(reg_kh, initial_kh);
3337 add(reg_kernel, filter_step_size * underflow);
3338 if (is_dilated) {
3339 const int tail = t_pad % dilate_h;
3340 const int shift = tail == 0 ? 0 : dilate_h - tail;
3341 mov(reg_ih_shift, shift);
3342 if (!is_partial) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3343 add(reg_src, src_step_size * shift);
3344 }
3345
3346 if (is_partial) {
3347 Label head_setup, head_setup_finish;
3348 cmp(reg_oj, 0);
3349 je(head_setup_finish, T_NEAR);
3350 mov(reg_oj_setup, reg_oj);
3351
3352 L(head_setup);
3353 if (is_dilated) {
3354 inc(reg_ih_shift);
3355 cmp(reg_ih_shift, dilate_h);
3356 jl(oh_dilate_setup_label_shift, T_NEAR);
3357 // unshift src as new kernel element enters
3358 sub(reg_src, src_step_size * (dilate_h - 1));
3359 xor_(reg_ih_shift, reg_ih_shift);
3360 }
3361 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3362 add(reg_kh, stride_h);
3363 sub(reg_kernel, filter_step_size * stride_h);
3364 if (is_dilated) {
3365 jmp(oh_dilate_setup_label_noshift, T_NEAR);
3366 L(oh_dilate_setup_label_shift);
3367 // shift src as old kernel element progresses
3368 add(reg_src, src_step_size * stride_h);
3369 L(oh_dilate_setup_label_noshift);
3370 }
3371 sub(reg_oj_setup, 1);
3372 jg(head_setup, T_NEAR);
3373 L(head_setup_finish);
3374
3375 if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3376 if (oh_head_end < oh_head_overflow_end) {
3377 cmp(reg_oj, oh_head_end);
3378 jge(oh_tpad_label_end, T_NEAR);
3379 }
3380 }
3381
        // Setup reg_kernel
        // If dilated, shift src ptr
        // Loop
3385 L(oh_tpad_label);
3386 compute_oh_step_disp();
3387 add(reg_ddst, ddst_step_size);
3388 if (is_dilated) {
3389 mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3390 inc(reg_ih_shift);
3391 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3392 cmp(reg_ih_shift, dilate_h);
3393 jl(oh_dilate_label_shift, T_NEAR);
3394 // unshift src as new kernel element enters
3395 sub(reg_src, src_step_size * (dilate_h - 1));
3396 xor_(reg_ih_shift, reg_ih_shift);
3397 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3398 }
3399 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3400 add(reg_kh, stride_h);
3401 sub(reg_kernel, filter_step_size * stride_h);
3402 if (is_dilated) {
3403 jmp(oh_dilate_label_noshift, T_NEAR);
3404 L(oh_dilate_label_shift);
3405 // shift src as old kernel element progresses
3406 add(reg_src, src_step_size * stride_h);
3407 L(oh_dilate_label_noshift);
3408 }
3409 inc(reg_oj);
3410
3411 if (is_partial) {
3412 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3413 jge(oh_bpad_label_end, T_NEAR);
3414 }
3415 cmp(reg_oj, oh_head_end);
3416 jl(oh_tpad_label, T_NEAR);
3417
3418 L(oh_tpad_label_end);
        // need a second loop to process the kernel if it is taller than the
        // src (does not apply to dilated cases, as they must have unit stride)
3421 if (oh_head_end < oh_head_overflow_end) {
3422 assert(!is_dilated);
3423
3424 cmp(reg_oj, oh_head_overflow_end);
3425 jge(oh_tpad_tail_label_end, T_NEAR);
3426
3427 mov(reg_kh, jcp.ih);
3428 L(oh_tpad_tail_label);
3429 {
3430 compute_oh_step_disp();
3431 add(reg_ddst, ddst_step_size);
3432 sub(reg_kernel, filter_step_size * stride_h);
3433
3434 inc(reg_oj);
3435
3436 if (is_partial) {
3437 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3438 jge(oh_bpad_label_end, T_NEAR);
3439 }
3440 cmp(reg_oj, oh_head_overflow_end);
3441 jl(oh_tpad_tail_label, T_NEAR);
3442 }
3443 }
3444 if (body_src_start_offset != 0) {
3445 add(reg_kernel, filter_step_size * body_src_start_offset);
3446 add(reg_src, src_step_size * body_src_start_offset);
3447 }
3448 L(oh_tpad_tail_label_end);
3449 }
3450
3451 if (is_partial) {
3452 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3453 jge(oh_bpad_label_end, T_NEAR);
3454 }
3455 cmp(reg_oj, oh_body_end);
3456 jge(oh_label_end, T_NEAR);
3457
3458 /* Compute middle block(s) */
3459 mov(reg_kh, jcp.kh);
3460 L(oh_label);
3461 {
3462 compute_oh_step_disp();
3463 add(reg_src, src_step_size * stride_h);
3464 add(reg_ddst, ddst_step_size);
3465
3466 inc(reg_oj);
3467
3468 if (is_partial) {
3469 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3470 jge(oh_bpad_label_end, T_NEAR);
3471 }
3472
3473 cmp(reg_oj, oh_body_end);
3474 jl(oh_label, T_NEAR);
3475 }
3476 L(oh_label_end);
3477
3478 /* Compute bottom edge */
3479 if (b_pad > 0) {
3480 if (is_partial) {
3481 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3482 jge(oh_bpad_label_end, T_NEAR);
3483 }
3484 cmp(reg_oj, jcp.oh);
3485 jge(oh_bpad_label_end, T_NEAR);
3486
3487 if (is_dilated) {
3488 // Assumes unit stride for dilations
3489 mov(reg_kh, jcp.kh - 1);
3490 xor_(reg_ih_shift, reg_ih_shift);
3491 } else {
3492 assert(jcp.dilate_h == 0);
3493 mov(reg_kh, jcp.ih - ih_body_end);
3494 }
3495 if (is_partial) {
3496 lea(reg_oj_setup,
3497 ptr[reg_oj - nstl::max(oh_body_end, oh_head_overflow_end)]);
3498 if (stride_h == 1 && !is_dilated) {
3499 sub(reg_kh, reg_oj_setup);
3500 } else {
3501 Label body_setup, body_setup_finish, dilate_skip;
3502 cmp(reg_oj_setup, 0);
3503 je(body_setup_finish, T_NEAR);
3504
3505 L(body_setup);
3506 if (is_dilated) {
3507 inc(reg_ih_shift);
3508 cmp(reg_ih_shift, dilate_h);
3509 jl(dilate_skip, T_NEAR);
3510 xor_(reg_ih_shift, reg_ih_shift);
3511 }
3512 sub(reg_kh, stride_h);
3513 L(dilate_skip);
3514 sub(reg_oj_setup, 1);
3515 jg(body_setup, T_NEAR);
3516 L(body_setup_finish);
3517 }
3518 }
3519
3520 if (is_dilated) mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3521 L(oh_bpad_label);
3522 {
3523 compute_oh_step_disp();
3524 add(reg_src, src_step_size * stride_h);
3525 add(reg_ddst, ddst_step_size);
3526
3527 if (is_dilated) {
3528 mov(reg_ih_shift, ptr[rsp + ih_dilate_shift]);
3529 inc(reg_ih_shift);
3530 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3531 cmp(reg_ih_shift, dilate_h);
3532 jl(oh_dilate_label_end, T_NEAR);
3533 xor_(reg_ih_shift, reg_ih_shift);
3534 mov(ptr[rsp + ih_dilate_shift], reg_ih_shift);
3535 }
3536 sub(reg_kh, stride_h);
3537 L(oh_dilate_label_end);
3538 inc(reg_oj);
3539 if (is_partial) {
3540 cmp(reg_oj, ptr[param + GET_OFF(os_index_end)]);
3541 jge(oh_bpad_label_end, T_NEAR);
3542 }
3543 cmp(reg_oj, oh_tail_end);
3544 jl(oh_bpad_label, T_NEAR);
3545 }
3546 }
3547 L(oh_bpad_label_end);
3548}
3549
3550void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_od_loop_common(
3551 bool is_partial) {
3552 assert(jcp.harness == harness_3d_reduction);
3553
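    // First output depth index whose filter window reaches into the back
    // padding.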
3554 const int src_backpad_overlap
3555 = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
3556
3557 const auto filter_shift = get_kernel_offset(0, jcp.kh * jcp.kw);
3558 const auto src_shift = get_src_offset(0, 0, jcp.ih);
3559 const auto ddst_shift = get_ddst_offset(0, jcp.oh);
3560
3561 const int kd_front_pad = nstl::max(0, jcp.f_pad);
3562 const int kd_back_pad = nstl::max(0, jcp.kd - jcp.f_pad - jcp.id);
3563
3564 Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
3565 backpad_end_label, backpad_label;
3566
3567 /* initially offset 'kd' by f_pad */
3568 mov(reg_src_d, ptr[param + GET_OFF(src)]);
3569 mov(reg_ddst_d, ptr[param + GET_OFF(dst)]);
3570
3571 if (is_partial) {
3572 add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3573 mov(reg_d_index, ptr[param + GET_OFF(os_index_begin)]);
3574 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3575 } else {
3576 const int kd_padding = jcp.kd - kd_front_pad - kd_back_pad;
3577 const int kd_offset = get_kernel_offset(
3578 0, nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw);
3579 add(reg_kernel, kd_offset);
3580 xor_(reg_d_index, reg_d_index);
3581 mov(reg_kd_count, kd_padding);
3582 }
3583
3584 cmp(reg_kd_count, 0);
3585 jle(loop_end_label, T_NEAR); // no iterations along kd
3586 if (is_partial)
3587 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3588 else
3589 cmp(reg_d_index, jcp.od);
3590 jge(loop_end_label, T_NEAR); // no iterations along depth dimension
3591
3592 L(d_loop_label);
3593
3594 mov(reg_src, reg_src_d);
3595 mov(reg_ddst, reg_ddst_d);
3596
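    // Save the depth-loop state on the stack across the call to
    // compute_oh_loop_common().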
3597 mov(EVEX_compress_addr(rsp, src_d_offset), reg_src_d);
3598 mov(EVEX_compress_addr(rsp, ddst_d_offset), reg_ddst_d);
3599 mov(EVEX_compress_addr(rsp, d_index_offset), reg_d_index);
3600
3601 compute_oh_loop_common();
3602
3603 mov(reg_src_d, EVEX_compress_addr(rsp, src_d_offset));
3604 mov(reg_ddst_d, EVEX_compress_addr(rsp, ddst_d_offset));
3605 mov(reg_d_index, EVEX_compress_addr(rsp, d_index_offset));
3606
3607 /* Compute 'front' edge */
3608 if (jcp.f_pad > 0) {
3609 /* Check if within fpad region */
3610 cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
3611 jge(fpad_end_label, T_NEAR);
3612
3613 /* Fpad steps */
3614 sub(reg_kernel, filter_shift * jcp.stride_d);
3615 add(reg_kd_count, jcp.stride_d);
3616
3617 /* Final number of kernel elements that overlap with src */
3618 const int src_ker_overlap = nstl::min(jcp.kd, jcp.id);
3619 cmp(reg_kd_count, src_ker_overlap);
3620 jle(common_block_label, T_NEAR);
3621
3622 /* Correct any excess shifts to kernel and src */
3623 if (jcp.f_pad <= jcp.od * jcp.stride_d) {
3624 /* Filter has moved beyond padding (adjust for stride effects) */
3625 if (jcp.f_pad % jcp.stride_d != 0) {
3626 int src_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
3627 add(reg_kernel, filter_shift * src_corr);
3628 add(reg_src_d, src_shift * src_corr);
3629 }
3630 } else {
3631 /* Filter still overlaps padding (complete reset) */
3632 sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
3633 }
3634
3635 /* Apply correction */
3636 mov(reg_kd_count, src_ker_overlap);
3637 jmp(common_block_label);
3638
3639 L(fpad_end_label);
3640 }
3641
3642 /* Compute bottom edge */
3643 if (jcp.back_pad > 0) {
3644
3645 /* Check if within back_pad region */
3646 cmp(reg_d_index, src_backpad_overlap - 1);
3647 jl(backpad_end_label, T_NEAR);
3648 jg(backpad_label, T_NEAR);
3649
3650 /* Execute overlap correction between the filter and the initial
3651 * back_pad region. */
3652 mov(reg_kd_count,
3653 jcp.id + jcp.f_pad - src_backpad_overlap * jcp.stride_d);
3654 jmp(backpad_end_label, T_NEAR);
3655
3656 L(backpad_label);
3657 sub(reg_kd_count, jcp.stride_d);
3658 cmp(reg_kd_count, 0);
3659 jle(loop_end_label, T_NEAR);
3660
3661 L(backpad_end_label);
3662 }
3663
3664 /* Compute middle block */
3665 add(reg_src_d, src_shift * jcp.stride_d);
3666
3667 /* Execute common block and loop */
3668 L(common_block_label);
3669 add(reg_ddst_d, ddst_shift);
3670 inc(reg_d_index);
3671 if (is_partial)
3672 cmp(reg_d_index, ptr[param + GET_OFF(os_index_end)]);
3673 else
3674 cmp(reg_d_index, jcp.od);
3675 jl(d_loop_label, T_NEAR);
3676
3677 L(loop_end_label);
3678}
3679
3680void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::
3681 compute_full_spat_loop() {
3682 // General code layout:
3683 //
3684 // Blocking over OH -- top level
3685 // (Reduces L2 pressure; not very useful right now)
    //  Loop over all KHxKW kernel elements -- emit_kh_kw_loop()
    //    Loop over OH block -- emit_h_loop()
    //      Loop over OW blocks -- emit_block()
    //      (Supports both fully unrolled and partially unrolled
    //      versions to reduce code size)
    //          Loop over OW block -- emit_step()
3692
3693 auto src_row_size = get_src_offset(0, 0, 1);
3694 auto ddst_row_size = get_ddst_offset(0, 1);
3695 auto row_size = src_row_size + ddst_row_size;
3696
3697 int h_block_size = jcp.oh;
3698 int h_last_block_size = h_block_size;
3699 int min_h_block_size = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
3700 auto working_set_size = row_size * h_block_size;
3701
3702 if (working_set_size > full_spat_max_working_set_size) {
3703 assert(full_spat_opt_working_set_size < full_spat_max_working_set_size);
3704
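        // Shrink h_block_size by dividing out its smallest factor (halving it
        // when prime) until the working set fits the optimal budget.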
3705 while (working_set_size > full_spat_opt_working_set_size
3706 && h_block_size >= min_h_block_size) {
3707 for (int i = 2; i <= h_block_size; i++)
3708 if (i == h_block_size)
3709 h_block_size = h_block_size / 2;
3710 else if (h_block_size % i == 0) {
3711 h_block_size = h_block_size / i;
3712 break;
3713 }
3714 working_set_size = row_size * h_block_size;
3715 }
3716 h_block_size = nstl::max(min_h_block_size, h_block_size);
3717 h_last_block_size = jcp.oh % h_block_size;
3718 if (h_last_block_size < jcp.b_pad) h_last_block_size += h_block_size;
3719 }
3720
3721 Opmask reg_h_block = k1;
3722 Reg64 reg_kh = rax;
3723 Reg64 reg_kw = rbx;
3724 Reg64 reg_tmp = abi_not_param1;
3725 Reg32 reg_tmp_w = reg_tmp.cvt32();
3726 Reg64 reg_ohs = rdx;
3727 Reg64 reg_ihs = rsi;
3728 Reg64 reg_h = r8;
3729 Reg64 reg_i = r9;
3730 Reg64 reg_j = r10;
3731
3732 Reg64 reg_src = r13;
3733 Reg64 reg_ddst = r14;
3734 Reg64 reg_ker = r15;
3735
3736 Reg64 reg_src_save = abi_param1;
3737 Reg64 reg_ddst_save = reg_tmp;
3738
3739 auto zmm_ddst = [&](int oi) { return Zmm(24 + oi % 8); };
3740 auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3741 auto src_addr = [&](int oi, int ic1) {
3742 return zword_b[reg_src + get_src_offset(ic1, oi)];
3743 };
3744 auto ddst_addr = [&](int oi) {
3745 auto ow_per_oc = 2;
3746 return ptr[reg_ddst + get_ddst_offset(ow_per_oc * oi)];
3747 };
3748 auto ker_addr
3749 = [&](int ic1) { return ptr[reg_ker + get_kernel_offset(ic1, 0)]; };
3750
3751 auto emit_block = [&]() {
3752 auto pad_ow = jcp.tr_ow;
3753 int ow_per_oc = 2;
3754 int def_step_size = 16;
3755 bool has_w_tail = pad_ow % def_step_size != 0;
3756 bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3757
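        // emit_step handles 'ur_ow' output columns: it loads ddst two columns
        // at a time (packed bf16 pairs) and accumulates ddst x broadcast src
        // into the per-ic kernel accumulators with vdpbf16ps.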
3758 auto emit_step = [&](int ur_ow, bool is_w_tail) {
3759 int tail_size = pad_ow % ur_ow;
3760 int this_ur_ow = (is_w_tail && tail_size) ? tail_size : ur_ow;
3761 auto numloads = 1;
3762
3763 assert(this_ur_ow % ow_per_oc == 0);
3764 int steps = this_ur_ow / ow_per_oc;
3765 for (int oi_base = 0; oi_base < steps; oi_base += numloads) {
3766 for (int oi_offset = 0; oi_offset < numloads; oi_offset++) {
3767 int oi = oi_base + oi_offset;
3768 if (oi < steps) {
3769 vmovups(zmm_ddst(oi), ddst_addr(oi));
3770 } else {
3771 auto zmm = zmm_ddst(oi);
3772 vpxord(zmm, zmm, zmm);
3773 }
3774 }
3775
3776 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3777 vdpbf16ps(zmm_ker(ic1), zmm_ddst(oi_base),
3778 src_addr(ow_per_oc * oi_base, ic1));
3779 }
3780 }
3781 };
3782
3783 if (full_w_unroll) {
3784 emit_step(pad_ow, true);
3785 } else {
3786 Label w_loop;
3787 int num_w_iters = pad_ow / def_step_size;
3788 mov(reg_i, num_w_iters);
3789 L(w_loop);
3790 {
3791 emit_step(def_step_size, false);
3792 add(reg_src, get_src_offset(0, def_step_size));
3793 add(reg_ddst, get_ddst_offset(def_step_size));
3794 sub(reg_i, 1);
3795 jnz(w_loop);
3796 }
3797 if (has_w_tail) { emit_step(def_step_size, true); }
3798 // reset reg_src and reg_ddst because emit_h_loop expects
3799 // unmodified pointers
3800 int w_offset = num_w_iters * def_step_size;
3801 sub(reg_src, get_src_offset(0, w_offset));
3802 sub(reg_ddst, get_ddst_offset(w_offset));
3803 }
3804 };
3805
3806 auto emit_h_loop = [&]() {
3807 Label h_loop, skip_h_loop;
3808 mov(reg_j, 1);
3809 cmp(reg_j, reg_h);
3810 je(skip_h_loop, T_NEAR);
3811 L(h_loop);
3812 {
3813 emit_block();
3814
3815 add(reg_src, get_src_offset(0, 0, 1));
3816 add(reg_ddst, get_ddst_offset(0, 1));
3817 add(reg_j, 1);
3818 cmp(reg_j, reg_h);
3819 jb(h_loop);
3820 }
3821 L(skip_h_loop);
3822
3823 emit_block();
3824 };
3825
3826 auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block) {
3827 xor_(reg_kh, reg_kh);
3828 Label kh_loop, kh_loop_end;
3829
3830 int oh_block_size = (is_last_block) ? h_last_block_size : h_block_size;
3831 // NB: this is correct because we only support t_pad = kh / 2 and thus
3832 // ih == oh
3833 int ih_block_size = oh_block_size
3834 + (!is_first_block + !is_last_block) * jcp.t_pad;
3835
3836 L(kh_loop);
3837 {
3838 if (is_first_block) {
3839 xor_(reg_tmp, reg_tmp);
3840 mov(reg_ohs, jcp.t_pad);
3841 sub(reg_ohs, reg_kh);
3842 cmovb(reg_ohs, reg_tmp);
3843
3844 mov(reg_ihs, reg_ohs);
3845 sub(reg_ihs, jcp.t_pad);
3846 add(reg_ihs, reg_kh);
3847 } else {
3848 xor_(reg_ohs, reg_ohs);
3849 mov(reg_ihs, reg_kh);
3850 }
3851
3852 mov(reg_tmp, oh_block_size);
3853 sub(reg_tmp, reg_ohs);
3854 mov(reg_h, ih_block_size);
3855 sub(reg_h, reg_ihs);
3856 cmp(reg_tmp, reg_h);
3857 cmovb(reg_h, reg_tmp);
3858
3859 Label kh_loop_work;
3860 cmp(reg_h, 0);
3861 jg(kh_loop_work, T_NEAR);
3862
            // empty h loop for this jcp.kh:
            // - zero this kernel block if requested
            // - advance the ker pointer
            // - jump to the end
3867 sub(reg_h, 1);
3868 Label skip_ker_zeroing;
3869
            // The lowest bit of the reg_ker ptr is set when the kernel
            // (diff_weights) block must be zero-initialized instead of
            // accumulated into. Those who have byte-aligned their data will
            // suffer the consequences :(
            // TODO: move the flag to a mask register? (Roma)
3874 test(reg_ker, 1);
3875 jz(skip_ker_zeroing, T_NEAR);
3876
3877 Label zeroing_loop;
3878 vpxord(zmm0, zmm0, zmm0);
3879 and_(reg_ker, ~1); // temporarily clear the zeroing flag
3880 mov(reg_tmp, jcp.kw);
3881 L(zeroing_loop);
3882 {
3883 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3884 vmovups(ker_addr(ic1), zmm0);
3885 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3886 sub(reg_tmp, 1);
3887 jnz(zeroing_loop, T_NEAR);
3888 }
3889 // restore the zeroing flag (it will be cleared after the end of
3890 // emit_kh_kw_loop, but we may need it until then)
3891 or_(reg_ker, 1);
3892 jmp(kh_loop_end, T_NEAR);
3893
3894 L(skip_ker_zeroing);
3895 add(reg_ker, get_kernel_offset(0, jcp.kw));
3896 jmp(kh_loop_end, T_NEAR);
3897
3898 L(kh_loop_work);
3899
3900 mul_by_const(reg_ihs, reg_tmp, get_src_offset(0, 0, 1));
3901 mul_by_const(reg_ohs, reg_tmp, get_ddst_offset(0, 1));
3902
3903 add(reg_src, reg_ihs);
3904 add(reg_ddst, reg_ohs);
3905
3906 Label kw_loop;
3907 xor_(reg_kw, reg_kw);
3908 L(kw_loop);
3909 {
3910 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3911 auto zmm = zmm_ker(ic1);
3912 vpxord(zmm, zmm, zmm);
3913 }
3914
3915 mov(reg_ddst_save, reg_ddst);
3916 mov(reg_src_save, reg_src);
3917 lea(reg_src, ptr[reg_src + reg_kw * jcp.typesize_in]);
3918
3919 emit_h_loop();
3920
3921 mov(reg_ddst, reg_ddst_save);
3922 mov(reg_src, reg_src_save);
3923
3924 Label do_store;
                // The lowest bit of the reg_ker ptr is set when the kernel
                // (diff_weights) block must be zero-initialized instead of
                // accumulated into. Those who have byte-aligned their data
                // will suffer the consequences :(
3928 mov(reg_tmp, reg_ker);
3929 and_(reg_ker, ~1);
3930 test(reg_tmp, 1);
3931 jnz(do_store, T_NEAR);
3932
3933 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3934 auto zmm = zmm_ker(ic1);
3935 vaddps(zmm, ker_addr(ic1));
3936 }
3937
3938 L(do_store);
3939 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3940 auto zmm = zmm_ker(ic1);
3941 vmovups(ker_addr(ic1), zmm);
3942 }
3943
3944 mov(reg_ker, reg_tmp);
3945 add(reg_ker, get_kernel_offset(jcp.ic_block, 0));
3946 add(reg_kw, 1);
3947 cmp(reg_kw, jcp.kw);
3948 jl(kw_loop);
3949 }
3950
3951 sub(reg_src, reg_ihs);
3952 sub(reg_ddst, reg_ohs);
3953
3954 L(kh_loop_end);
3955 add(reg_kh, 1);
3956 cmp(reg_kh, jcp.kh);
3957 jl(kh_loop);
3958 }
3959 };
3960
3961 mov(reg_src, ptr[param + GET_OFF(src)]);
3962 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
3963 mov(reg_ker, ptr[param + GET_OFF(filt)]);
3964 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3965 or_(reg_ker, reg_tmp);
3966
3967 bool single_kh_kw_loop = (h_last_block_size == jcp.oh);
3968
3969 auto src_row_step = get_src_offset(0, 0, 1);
3970 auto first_src_block_step = src_row_step * (h_block_size - jcp.t_pad);
3971 auto ddst_block_step = get_ddst_offset(0, h_block_size);
3972
3973 emit_kh_kw_loop(true, single_kh_kw_loop);
3974
3975 if (!single_kh_kw_loop) {
3976 auto ker_reset_offset = get_kernel_offset(0, jcp.kw * jcp.kh);
3977 sub(reg_ker, ker_reset_offset);
3978 and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
3979
3980 add(reg_src, first_src_block_step);
3981 add(reg_ddst, ddst_block_step);
3982
3983 int num_innermost_iters
3984 = (jcp.oh - h_last_block_size) / h_block_size - 1;
3985 if (num_innermost_iters > 0) {
3986 Label h_block_loop;
3987
3988 mov(reg_tmp_w, num_innermost_iters);
3989 kmovw(reg_h_block, reg_tmp_w);
3990 L(h_block_loop);
3991 {
3992 emit_kh_kw_loop(false, false);
3993 sub(reg_ker, ker_reset_offset);
3994 add(reg_src, src_row_step * h_block_size);
3995 add(reg_ddst, ddst_block_step);
3996
3997 kmovw(reg_tmp_w, reg_h_block);
3998 sub(reg_tmp_w, 1);
3999 kmovw(reg_h_block, reg_tmp_w);
4000 jnz(h_block_loop);
4001 }
4002 }
4003
4004 emit_kh_kw_loop(false, true);
4005 }
4006}
4007
4008void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 ::compute_loop() {
4009 Reg64 reg_mask_load = r11;
4010 if (jcp.uses_permw_transposition) {
4011
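        // Prepare the opmasks used by the permw transposition path: all 32
        // lanes, the low/high 16-lane halves, and the oc/ic tail masks.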
4012 mov(reg_mask_load.cvt32(), 0xffffffff);
4013 kmovd(m_ffffffff, reg_mask_load.cvt32());
4014
4015 mov(reg_mask_load.cvt32(), 0x0000ffff);
4016 kmovd(m_0000ffff, reg_mask_load.cvt32());
4017
4018 mov(reg_mask_load.cvt32(), 0xffff0000);
4019 kmovd(m_ffff0000, reg_mask_load.cvt32());
4020 const int oc_tail = jcp.oc_tail;
4021 if (oc_tail) {
4022 mov(reg_mask_load.cvt32(), (1 << oc_tail) - 1);
4023 kmovd(m_0000_oc_tail, reg_mask_load.cvt32());
4024 kshiftld(m_oc_tail_0000, m_0000_oc_tail, 16);
4025 }
4026 const int ic_tail = jcp.ic_tail;
4027 if (ic_tail) {
4028 mov(reg_mask_load.cvt32(), (1 << ic_tail) - 1);
4029 kmovd(m_0000_ic_tail, reg_mask_load.cvt32());
4030 kshiftld(m_ic_tail_0000, m_0000_ic_tail, 16);
4031 }
4032 } else if (jcp.is_1stconv && !jcp.transpose_src) {
4033 if (jcp.stride_w == 1) {
4034 int ieveryother_mask = 0x55555555;
4035 mov(reg_mask_load.cvt32(), ieveryother_mask);
4036 kmovd(everyother_mask, reg_mask_load.cvt32());
4037 kshiftld(everyother_shift_mask, everyother_mask, 1);
4038 } else {
4039 mov(reg_mask_load.cvt32(), 0xffffffff);
4040 kmovd(m_ffffffff, reg_mask_load.cvt32());
4041 }
4042 }
4043
4044 mov(reg_src, ptr[param + GET_OFF(src)]);
4045 mov(reg_ddst, ptr[param + GET_OFF(dst)]);
4046 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4047
4048 maybe_zero_kernel();
4049 maybe_compute_diff_bias();
4050
4051 switch (jcp.harness) {
4052 case harness_3d_reduction: compute_od_loop_common(true); break;
4053 case harness_2d_reduction: compute_oh_loop_common(true); break;
4054 case harness_mb_reduction: compute_oh_loop_common(); break;
4055 case harness_compute_full_spatial: compute_full_spat_loop(); break;
4056 default: assert(!"Invalid harness type");
4057 }
4058}
4059
4060void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::setup_stack_space() {
4061
4062 if ((jcp.is_1stconv && !jcp.transpose_src && jcp.stride_w > 1)
4063 || jcp.uses_permw_transposition) {
4064 int ur_w, ur_w_tail, ur_w_trips;
4065 get_ur_w(ur_w, ur_w_tail, ur_w_trips);
4066 ur_w = nstl::max(ur_w, ur_w_tail);
4067 ic_block_step_stack_size = jcp.uses_permw_transposition
4068 ? permw_stack_size(ur_w)
4069 : interleave_stack_size(ur_w, jcp.ic_block_step);
4070 } else
4071 ic_block_step_stack_size = extern_ic_block_step_stack_size;
4072
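    // Eight 8-byte scratch slots follow the ic_block_step buffer area on the
    // stack.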
4073 permw_buffer_start = 0;
4074 kd_count_offset = ic_block_step_stack_size;
4075 src_d_offset = ic_block_step_stack_size + 8;
4076 ddst_d_offset = ic_block_step_stack_size + 16;
4077 d_index_offset = ic_block_step_stack_size + 24;
4078 trans_tmp_offset = ic_block_step_stack_size + 32;
4079 ih_dilate_shift = ic_block_step_stack_size + 40;
4080 icb_loop_ker_ptr = ic_block_step_stack_size + 48;
4081 icb_loop_src_ptr = ic_block_step_stack_size + 56;
4082 stack_space_needed = ic_block_step_stack_size + 64;
4083}
4084
4085void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate() {
4086 preamble();
4087
4088 setup_stack_space();
4089
4090 sub(rsp, stack_space_needed);
4091
4092 compute_loop();
4093
4094 add(rsp, stack_space_needed);
4095
4096 postamble();
4097
4098 if (jcp.uses_permw_transposition) {
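        // Emit the permutation index table used by the permw-based
        // transposition; it interleaves the two 16-element halves
        // (0, 16, 1, 17, ...).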
4099 align(64);
4100 L(dst_prm_table);
4101 const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
4102 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
4103 29, 14, 30, 15, 31};
4104
4105 for (size_t i = 0; i < 32; ++i)
4106 dw(dst_prm_array[i]);
4107 }
4108}
4109
4110status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
4111 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4112 memory_desc_t &src_md, memory_desc_t &diff_weights_md,
4113 memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md, int nthreads) {
4114 const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
4115
4116 const memory_desc_wrapper src_d(&src_md);
4117 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
4118 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
4119 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
4120
4121 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4122 int ndims = src_d.ndims();
4123
4124 jcp = zero<decltype(jcp)>();
4125 jcp.nthr = nthreads;
4126 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
4127 : bf16_emulation_t::get_isa();
4128 jcp.has_vnni = true;
4129 jcp.ndims = ndims;
4130 jcp.prop_kind = cd.prop_kind;
4131
4132 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4133 jcp.mb = src_d.dims()[0];
4134
4135 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4136 jcp.oc_without_padding = jcp.oc;
4137 jcp.ic = src_d.dims()[1] / jcp.ngroups;
4138
4139 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4140 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
4141 jcp.iw = src_d.dims()[ndims - 1];
4142 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4143 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
4144 jcp.ow = diff_dst_d.dims()[ndims - 1];
4145
4146 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4147 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims - 2];
4148 jcp.kw = diff_weights_d.dims()[with_groups + ndims - 1];
4149
4150 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4151 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
4152 jcp.l_pad = cd.padding[0][ndims - 3];
4153
4154 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4155 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
4156 jcp.stride_w = cd.strides[ndims - 3];
4157
4158 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4159 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
4160 jcp.dilate_w = cd.dilates[ndims - 3];
4161
4162 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
4163 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
4164 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
4165
4166 bool ok = true
4167 // general condition to simplify dilations
4168 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4169 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4170 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4171 // special condition to simplify dilations in compute_oh_loop_common
4172 && IMPLICATION(jcp.dilate_h != 0, ext_kh <= jcp.ih);
4173 if (!ok) return status::unimplemented;
4174
4175 jcp.r_pad = nstl::max(0,
4176 calculate_end_padding(
4177 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw));
4178 jcp.b_pad = nstl::max(0,
4179 calculate_end_padding(
4180 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh));
4181 jcp.back_pad = nstl::max(0,
4182 calculate_end_padding(
4183 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd));
4184
4185 /* XXX: no support for padding when dilation_d > 0 */
4186 if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
4187 return status::unimplemented;
4188
4189 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4190 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4191 jcp.ohp = jcp.oh;
4192 jcp.owp = jcp.ow;
4193 jcp.aligned_threads = 0;
4194
4195 jcp.simd_w = simd_w;
4196 jcp.oc_block = simd_w;
4197 const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
4198 const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
4199 const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4200 auto curr_src_tag = src_d.matches_one_of_tag(
4201 dat_tag_nxc, dat_tag_nCx16c, dat_tag_ncx);
4202 auto curr_dst_tag
4203 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
4204 bool is_data_layout_nxc = IMPLICATION(curr_src_tag != dat_tag_nxc,
4205 src_d.format_kind() == format_kind::any)
4206 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
4207 diff_dst_d.format_kind() == format_kind::any)
4208 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
4209
4210 jcp.is_1stconv = is_1stconv(jcp);
4211
4212 bool ok_to_pad_channels
4213 = (jcp.ngroups == 1) && !jcp.is_1stconv && !is_data_layout_nxc;
4214
4215 if (ok_to_pad_channels) {
4216 jcp.oc = rnd_up(jcp.oc, simd_w);
4217 jcp.ic = rnd_up(jcp.ic, simd_w);
4218 }
4219
4220 auto src_tag = is_data_layout_nxc
4221 ? dat_tag_nxc
4222 : (jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c);
4223 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;
4224 auto wei_tag = jcp.is_1stconv
4225 ? pick(2 * ndims - 6 + with_groups, Owi16o, gOwi16o, Ohwi16o,
4226 gOhwi16o, Odhwi16o, gOdhwi16o)
4227 : pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
4228 OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
4229
4230 if (src_md.format_kind == format_kind::any) {
4231 CHECK(memory_desc_init_by_tag(src_md, src_tag));
4232 } else if (curr_src_tag != src_tag)
4233 return status::unimplemented;
4234 jcp.src_tag = src_tag;
4235
4236 if (diff_dst_md.format_kind == format_kind::any) {
4237 CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
4238 } else if (curr_dst_tag != dst_tag)
4239 return status::unimplemented;
4240 jcp.dst_tag = dst_tag;
4241
4242 if (diff_weights_md.format_kind == format_kind::any) {
4243 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
4244 jcp.wei_tag = wei_tag;
4245 } else {
4246 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
4247 if (jcp.wei_tag != wei_tag) return status::unimplemented;
4248 }
4249
4250 /* conditions on bias memory */
4251 jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
4252 if (jcp.with_bias) {
4253 if (diff_bias_d.format_kind() == format_kind::any)
4254 CHECK(memory_desc_init_by_tag(diff_bias_md, x));
4255 }
4256 jcp.bia_dt = jcp.with_bias ? diff_bias_d.data_type() : data_type::undef;
4257 jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
4258
4259 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
4260
4261 /* kernel applicability check wrt boundaries
4262 * the conditions are quite general across the kernels we have,
4263 * but ideally the check should belong to a specific kernel... */
4264 const int max_pad_h = ext_kh / 2;
4265 const bool boundaries_ok = true && jcp.l_pad < ext_kw && jcp.r_pad < ext_kw
4266 && jcp.t_pad <= max_pad_h && jcp.b_pad <= max_pad_h
4267 && jcp.f_pad < ext_kd && jcp.back_pad < ext_kd
4268 && IMPLICATION(jcp.is_1stconv && jcp.ow > max_ur_w,
4269 jcp.l_pad < max_ur_w && ext_kw <= jcp.ow);
4270 if (!boundaries_ok) return status::unimplemented;
4271
4272 const int max_kw = jcp.is_1stconv ? 24 : 14;
4273 /* yet another common check */
4274 if (jcp.kw > max_kw) return status::unimplemented;
4275
4276 jcp.wei_dt = diff_weights_d.data_type();
4277
4278 jcp.ic_block = jcp.is_1stconv ? jcp.ic : simd_w;
4279 if (ok_to_pad_channels) jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4280 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
4281 ok = true && one_of(ndims, 3, 4, 5)
4282 && everyone_is(
4283 data_type::bf16, src_d.data_type(), diff_dst_d.data_type())
4284 && one_of(diff_weights_d.data_type(), data_type::f32,
4285 data_type::bf16);
4286 if (!ok) return status::unimplemented;
4287
4288 jcp.ic_tail = is_data_layout_nxc ? jcp.ic % jcp.ic_block : 0;
4289 jcp.oc_tail = is_data_layout_nxc ? jcp.oc % jcp.oc_block : 0;
4290
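    // ic_block_step is the number of input channels processed per inner step;
    // it is chosen inversely to kw so that kw * ic_block_step stays bounded
    // (at most ~24).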
4291 if (jcp.is_1stconv) {
4292 jcp.ic_block_step = 24 / jcp.kw;
4293 while (jcp.ic_block % jcp.ic_block_step != 0)
4294 jcp.ic_block_step--;
4295 } else {
4296 jcp.ic_block_step
4297 = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : (jcp.kw <= 12 ? 2 : 1));
4298 }
4299
4300 // jcp.uses_permw_transposition = false shows better performance for
4301 // resnet50 v1.5 problems
4302 // jcp.uses_permw_transposition = true works better for 3d 1x1x1 problems
4303 const bool is_permw_applicable
4304 = !jcp.is_1stconv && jcp.stride_w == 1 && jcp.dilate_w == 0;
4305 const bool apply_permw_blocked = !is_data_layout_nxc && ndims == 5
4306 && jcp.kw == 1 && jcp.ic_block_step > 4;
4307 // Threshold is based on performance measurements
4308 const bool apply_permw_nxc = is_data_layout_nxc && ndims == 3
4309 && nstl::max(jcp.ic, jcp.oc) <= 32;
4310 jcp.uses_permw_transposition
4311 = is_permw_applicable && (apply_permw_blocked || apply_permw_nxc);
4312
4313 jcp.kernel_kind = embd_bcast;
4314 if (jcp.uses_permw_transposition && jcp.kw <= 3)
4315 jcp.kernel_kind = expl_bcast;
4316 if (jcp.uses_permw_transposition && jcp.kernel_kind == expl_bcast)
4317 jcp.ic_block_step = jcp.kw <= 3 ? 4 : (jcp.kw < 7 ? 2 : 1);
4318
4319 if (jcp.uses_permw_transposition) {
4320 jcp.transpose_src = false;
4321 jcp.transpose_dst = false;
4322 } else if (jcp.is_1stconv && IMPLICATION(is_data_layout_nxc, jcp.ic == 1)) {
4323 jcp.transpose_src = false;
4324 jcp.transpose_dst = true;
4325 } else {
4326 jcp.transpose_src = true;
4327 jcp.transpose_dst = true;
4328 }
4329
4330 const bool is_2d = (ndims == 4);
4331 const bool is_3d = (ndims == 5);
4332 jcp.typesize_in = sizeof(bfloat16_t);
4333 jcp.typesize_out = sizeof(float);
4334 const dim_t cache_l2
4335 = platform::get_per_core_cache_size(2) / jcp.typesize_out;
4336
    // Observation: given large 3D shapes with a large filter size, 1st nspc
    // bwd_w convolution benefits from non-temporal stores in the diff_dst
    // transformation but not so much from blocking w.r.t. the depth
    // dimension. In particular, it is optimized for i3D 1st convolutions.
4341 const bool nt_stores_ok = is_data_layout_nxc
4342 && dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= 2 * cache_l2
4343 && jcp.kd >= 6 && jcp.kh >= 6 && jcp.kw >= 6;
4344
    // Performance-wise, transposition of the diff_dst tensor is one of the
    // major bottlenecks in 1st convolution. Thus, for a large diff_dst size we
    // can potentially split the transposition further into smaller chunks to
    // achieve better cache reuse.
4349 const bool large_diff_dst_size
4350 = dim_t(jcp.oc) * jcp.od * jcp.oh * jcp.ow >= cache_l2;
4351
    // For a two-dimensional diff_dst tensor, blocking along height demands
    // non-trivial work along the width dimension. Similarly, for a
    // three-dimensional diff_dst tensor, enough work must be present in the
    // joint width-height dimension. Finally, there is no blocking along the
    // width dimension.
4356 const bool blocking_ok = large_diff_dst_size
4357 && IMPLICATION(is_2d, jcp.ow >= 124 && jcp.oh > 1)
4358 && IMPLICATION(is_3d, jcp.ow * jcp.oh >= 64 * 124 && jcp.od > 1)
4359 && (is_2d || is_3d);
4360
    // TODO: Find more shapes (especially 3D ones with large spatials) for
    // which local transposition will be beneficial. Furthermore, with TBB
    // threading more shapes can potentially benefit from spatial blocking.
4364 bool use_spatial_blocking = jcp.is_1stconv && !nt_stores_ok && blocking_ok;
4365 int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
4366 if (use_spatial_blocking) {
4367 // Default value, works best most of the times
4368 // TODO: For 3D shapes with intermediate sizes especially the ones not
4369 // belonging to the 1st convolution, we potentially have more scope
4370 // for optimization
4371 optimal_blk_size = 1;
4372
4373 // Diff_weights computation can be roughly broken down into
4374 // the following three steps
4375 // = [Src transform*] + [Diff_dst transform] + [Weights computation]
4376 //
4377 // where the bottleneck lies with diff_dst transform that spatial
4378 // blocking tries to mitigate by avoiding cache thrashing.
4379 // *note: Src transform may not always be needed.
4380 //
        // In an ideal scenario, optimal_blk_size would be an explicit
        // function of the form
        // optimal_blk_size = f(od, oh, ow, oc)
        //
        // though, owing to the lack of data points for 1st convolution shapes,
        // it is approximated by the constant 1, with a few exceptional cases
        // [found by manual optimization] as written below.
4388
4389 if (is_2d && utils::one_of(jcp.oh, 149, 300, 224, 512, 608)) {
4390 switch (jcp.oh) {
4391 case 149: optimal_blk_size = 10; break;
4392 case 224: optimal_blk_size = 56; break;
4393 case 300: optimal_blk_size = 30; break;
4394 case 512: optimal_blk_size = 8; break;
4395 case 608: optimal_blk_size = 10; break;
4396 }
4397 }
4398 }
4399
4400 jcp.global_transpose = dnnl_thr_syncable() && !use_spatial_blocking;
4401 jcp.use_nt_stores_ddst = jcp.global_transpose && nt_stores_ok;
4402 jcp.spatial_blk_size = optimal_blk_size;
4403
4404 const bool padding_ok = IMPLICATION(!jcp.transpose_src,
4405 jcp.l_pad < max_ur_w && jcp.r_pad < max_ur_w
4406 && ext_kw <= jcp.iw + 1);
4407 if (!padding_ok) return status::unimplemented;
4408
4409 const int tr_round = 2;
    // Logic for the tr_pad calculation: the transpose is used in the extern
    // kernel. There is a memory-usage optimization where physical padding is
    // shared between transpose buffers. When computing a row, data is read
    // from the src two elements at a time due to the bf16 broadcast. The
    // calculation starts at the beginning of the left padding and ends at the
    // end of the right padding. Because elements are read two at a time, we
    // may need r_pad + 1 padding on the right. As such, the shared padding is
    // the max of l_pad and r_pad + 1, rounded as necessary for the transpose
    // data format.
4418 int tr_pad = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round);
4419 jcp.tr_iw = jcp.transpose_src
4420 ? rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
4421 * jcp.stride_w
4422 : jcp.iw;
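    // e.g. iw = 17, stride_w = 1, l_pad = 1, r_pad = 1 gives
    // tr_pad = rnd_up(max(1, 2), 2) = 2 and
    // tr_iw = rnd_up(17 + 2, 2) * 1 = 20.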
4423
4424 jcp.tr_src_num_guard_elems = tr_pad; // upper bound
4425 jcp.tr_ow = jcp.transpose_dst ? rnd_up(jcp.ow, 2) : jcp.ow;
4426
4427 bool args_ok = true
4428 && IMPLICATION(!is_data_layout_nxc,
4429 jcp.ic % jcp.ic_block == 0 && jcp.oc % jcp.oc_block == 0)
4430 && jcp.ic <= src_d.padded_dims()[1]
4431 && jcp.oc <= diff_dst_d.padded_dims()[1]
4432 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
4433 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
4434 if (!args_ok) return status::unimplemented;
4435
4436 int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
4437 int out_row_size = jcp.oc_block * jcp.tr_ow * jcp.typesize_in;
4438 int full_spat_min_h_block_size
4439 = nstl::max(1, nstl::max(jcp.b_pad, jcp.t_pad));
4440 int full_spat_working_set_size
4441 = (inp_row_size + out_row_size) * full_spat_min_h_block_size;
4442 bool use_full_spat_loop = isa_has_bf16(jcp.isa) && jcp.ndims < 5
4443 && jcp.ih == jcp.oh && jcp.iw == jcp.ow
4444 && !one_of(1, jcp.kh, jcp.kw)
4445 && everyone_is(1, jcp.stride_h, jcp.stride_w)
4446 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
4447 && jcp.l_pad == jcp.kw / 2 && jcp.t_pad == jcp.kh / 2
4448 && !jcp.uses_permw_transposition && !jcp.is_1stconv
4449 && full_spat_working_set_size <= full_spat_opt_working_set_size
4450 && jcp.ic >= 128;
4451
4452 jcp.harness = ndims == 5
4453 ? harness_3d_reduction
4454 : (use_full_spat_loop ? harness_compute_full_spatial
4455 : (ndims == 4) ? harness_2d_reduction
4456 : harness_mb_reduction);
4457
4458 switch (jcp.harness) {
4459 case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
4460 case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
4461 case harness_compute_full_spatial:
4462 case harness_mb_reduction: jcp.nthr_mb_work = jcp.mb; break;
4463 default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
4464 }
4465 { // balancing
4466 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4467 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
4468 jcp.nthr = nthr;
4469 jcp.nthr_mb = nthr_mb;
4470 jcp.nthr_g = nthr_g;
4471 jcp.nthr_oc_b = nthr_oc_b;
4472 jcp.nthr_ic_b = nthr_ic_b;
4473
4474 // TODO: Optimize memory allocation when threaded on height and depth
4475 if (jcp.transpose_src) {
4476 jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
4477 jcp.tr_src_buf_count = jcp.global_transpose
4478 ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
4479 : jcp.nthr;
4480 }
4481 if (jcp.transpose_dst) {
4482 jcp.tr_diff_dst_buf_size
4483 = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
4484 jcp.tr_diff_dst_buf_count = jcp.global_transpose
4485 ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
4486 : jcp.nthr;
4487 }
4488 }
4489
4490 jcp.nb_ic_blocking_max = 1;
4491 if (is_data_layout_nxc && jcp.uses_permw_transposition
4492 && (jcp.ow > max_ur_w || jcp.ndims == 5))
4493 jcp.nb_ic_blocking_max = nstl::min(8, div_up(jcp.nb_ic, jcp.nthr_ic_b));
4494 return status::success;
4495}
4496
4497void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
4498 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4499
4500 if (!jcp.uses_permw_transposition) {
4501 // XXX: See the comment about tr_iw and guarding elements in
4502 // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
4503 const size_t tr_src_size = jcp.tr_src_buf_count * jcp.tr_src_buf_size
4504 + jcp.tr_src_num_guard_elems;
4505 scratchpad.book(key_conv_tr_src, tr_src_size, jcp.typesize_in);
4506
4507 /* prepare synchronization contexts */
4508 if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
4509 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
4510 scratchpad.book<simple_barrier::ctx_t>(
4511 key_conv_tr_src_bctx, tr_src_bctx_size);
4512 }
4513
4514 const size_t tr_diff_dst_size
4515 = jcp.tr_diff_dst_buf_count * jcp.tr_diff_dst_buf_size;
4516
4517 const size_t min_align = jcp.use_nt_stores_ddst ? 64 : jcp.typesize_in;
4518 scratchpad.book(key_conv_tr_diff_dst, tr_diff_dst_size, jcp.typesize_in,
4519 min_align);
4520
4521 /* prepare synchronization contexts */
4522 if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
4523 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
4524 scratchpad.book<simple_barrier::ctx_t>(
4525 key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
4526 }
4527 }
4528
4529 if (IMPLICATION(jcp.nthr_mb == 1,
4530 (jcp.with_bias && jcp.bia_dt == data_type::bf16)
4531 || jcp.wei_dt == data_type::bf16)) {
4532 const size_t wei_size = static_cast<size_t>(jcp.ngroups) * jcp.nb_oc
4533 * jcp.oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw
4534 * jcp.kd;
4535 const size_t bia_size
4536 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
4537
4538 const int num_wei_buffers
4539 = jcp.wei_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1;
4540 const int num_bia_buffers = jcp.with_bias
4541 ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb
4542 : jcp.nthr_mb - 1)
4543 : 0;
4544
4545 const size_t wei_bia_reduction_size
4546 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
4547
4548 scratchpad.book<float>(
4549 key_conv_wei_bia_reduction, wei_bia_reduction_size);
4550
4551 if (jcp.global_transpose)
4552 scratchpad.book<simple_barrier::ctx_t>(
4553 key_conv_wei_bia_reduction_bctx, 1);
4554 }
4555
4556 if (jcp.with_bias) {
4557 if ((jcp.oc_without_padding % jcp.oc_block != 0)
4558 && jcp.bia_dt == data_type::f32)
4559 scratchpad.book(key_conv_padded_bias,
4560 jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.typesize_bia);
4561 }
4562}
4563
4564void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance(
4565 const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4566 int &nthr_oc_b_, int &nthr_ic_b_) {
4567 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4568
4569 const int max_threads = dnnl_get_max_threads();
4570
4571 if (max_threads < j.ngroups) {
4572 /* simplification... fortunately it doesn't hurt much */
4573 nthr_ = nthr_g_ = max_threads;
4574 return;
4575 }
4576
4577 nthr_g_ = j.ngroups;
4578 const int nthr = max_threads / nthr_g_;
4579
4580 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per-thread memory cost (read/write). The high-level
         * optimizer tries to minimize memory consumption. A few notes:
         * (n1) if the weights tensor is smaller than the source and
         *      destination tensors, we apply the ratio of the source and
         *      destination tensor sizes to the weights one as a compensation
         *      coefficient to avoid parallelization across the batch size
         *      only; otherwise we apply an additional coefficient to the
         *      source component based on performance measurements
         * (n2) use scales based on the output-to-input channels ratio for the
         *      source and destination components to improve threading balance
         *      across input and output channels */
4592
4593 const dim_t src_type_size = 2;
4594 const dim_t wei_type_size = 4;
4595
4596 dim_t src_size
4597 = (dim_t)j.mb * j.ic * j.id * j.ih * j.tr_iw * src_type_size;
4598 dim_t dst_size
4599 = (dim_t)j.mb * j.oc * j.od * j.oh * j.tr_ow * src_type_size;
4600 dim_t wei_size
4601 = (dim_t)j.oc * j.ic * j.kd * j.kh * j.kw * wei_type_size;
4602
4603 float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
4604 float oi_channels_ratio = (float)j.nb_oc / j.nb_ic;
4605 auto get_src_coef = [=]() {
4606 float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
4607 if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
4608
4609 return src_coef;
4610 };
4611
4612 auto get_dst_coef
4613 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
4614
4615 auto get_wei_coef
4616 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
4617
4618 const float src_coef = get_src_coef();
4619 const float dst_coef = get_dst_coef();
4620 const float wei_coef = get_wei_coef();
4621
4622 float src_v = src_coef * div_up(j.nthr_mb_work, nthr_mb)
4623 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_ic, nthr_ic_b) * j.mb
4624 * j.ic_block * j.id * j.ih * j.tr_iw / j.nthr_mb_work
4625 / j.stride_d / j.stride_h / j.stride_w;
4626 float wei_v = wei_coef * div_up(j.ngroups, nthr_g_)
4627 * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b) * j.kh
4628 * j.kw * j.kd * j.ic_block * j.oc_block;
4629 float dst_v = dst_coef * div_up(j.nthr_mb_work, nthr_mb)
4630 * div_up(j.ngroups, nthr_g_) * div_up(j.nb_oc, nthr_oc_b) * j.mb
4631 * j.oc_block * j.od * j.oh * j.tr_ow / j.nthr_mb_work;
4632
4633 return src_v + dst_v + wei_v;
4634 };
4635
4636 float best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4637
4638 /* find the best thread distribution with lowest memory cost */
4639 const int nthr_mb_max = nstl::min(nthr, j.nthr_mb_work);
4640 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4641 const int nthr_par = nthr / nthr_mb;
4642 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4643 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4644 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4645
4646 float mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4647 if (mem_cost <= best_mem_cost) {
4648 best_mem_cost = mem_cost;
4649 nthr_mb_ = nthr_mb;
4650 nthr_oc_b_ = nthr_oc_b;
4651 nthr_ic_b_ = nthr_ic_b;
4652 }
4653 }
4654 }
4655
4656 if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
4657 nthr_mb_ = nstl::min(j.nthr_mb_work, nthr);
4658 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
4659
4660 assert(nthr_ <= max_threads);
4661}
4662
4663template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>;
4664template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>;
4665template struct _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>;
4666template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>;
4667template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>;
4668template struct _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>;
4669} // namespace x64
4670} // namespace cpu
4671} // namespace impl
4672} // namespace dnnl
4673// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
4674