/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/jit_brgemm_conv.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

using namespace jit_avx512_core_brgemm_conv_trans_kernel;
using namespace jit_avx512_core_brgemm_conv_comp_pad_kernel;

#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)
44
45template <cpu_isa_t isa, bool use_inversion>
46status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
47 engine_t *engine) {
48 using namespace data_type;
49 using namespace utils;
50
51 const auto src_type = src_md(0)->data_type;
52 const auto wei_type = weights_md(0)->data_type;
53 const auto dst_type = dst_md(0)->data_type;
54 const bool is_int8 = one_of(src_type, u8, s8);
55
56 using skip_mask_t = primitive_attr_t::skip_mask_t;
57 auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
58 | skip_mask_t::zero_points_runtime;
59 if (is_int8) skip_mask |= skip_mask_t::scales_runtime;
60
61 bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
62 && IMPLICATION(is_int8,
63 one_of(bias_md_.data_type, data_type::undef, f32, s32, s8,
64 u8))
65 && IMPLICATION(!is_int8,
66 one_of(bias_md_.data_type, data_type::undef, f32, src_type))
67 && attr()->has_default_values(skip_mask, dst_type)
68 && attr()->post_ops_.check_sum_consistent_dt(dst_type)
69 && !has_zero_dim_memory() && zero_points_ok() && arg_scales_ok();
70 if (!ok) return status::unimplemented;
71 const auto is_amx = brgemm_convolution_utils::is_amx(isa);
72
73 CHECK(brgemm_convolution_utils::init_conf(jcp_, isa, *desc(), src_md_,
74 weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()));
75
76 const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);
77
    // 1. Use the unrolled kernel for exec_trans only, to avoid creating a lot
    //    of kernels for each kw range
    // 2. For exec_trans the blocking over kw is always KW
81 assert(IMPLICATION(jcp_.use_uker, is_amx && jcp_.exec_type == exec_trans));
82 assert(IMPLICATION(jcp_.use_interleave_stores, jcp_.use_uker));
83
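    // batchsizes[] maps an effective brgemm batch size (the number of kernel
    // elements kd_l * kh_l * kw processed in one call) to a dense index;
    // batch sizes that never occur stay at -1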
84 batchsizes.resize(jcp_.max_batch + 1);
85 for (int i = 0; i <= jcp_.max_batch; i++)
86 batchsizes[i] = -1;
87
88 first_bs = 0;
89 bs_c = 0;
90
91 const auto SD = jcp_.stride_d;
92 const auto FP = jcp_.f_pad;
93 const auto DD = jcp_.dilate_d + 1;
94 const auto KD = jcp_.kd;
95 const auto ID = jcp_.id;
96
97 const auto SH = jcp_.stride_h;
98 const auto TP = jcp_.t_pad;
99 const auto DH = jcp_.dilate_h + 1;
100 const auto KH = jcp_.kh;
101 const auto KW = jcp_.kw;
102 const auto IH = jcp_.ih;
103
104 const auto KD_BLOCK = jcp_.kd_block;
105 const auto KH_BLOCK = jcp_.kh_block;
106 const auto KW_BLOCK = jcp_.kw_block;
107
108 if (jcp_.use_uker) {
109
110 assert(KD % KD_BLOCK == 0);
111 assert(KH % KH_BLOCK == 0);
112
113 for (int iod = 0; iod < jcp_.od; iod++) {
114 const int iid = iod * SD - FP;
115 const int kd_s = div_up(max(0, -iid), DD);
116 const int kd_f
117 = KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD);
118 const auto kd_l = nstl::min(KD_BLOCK, kd_f - kd_s);
119 for (int ioh = 0; ioh < jcp_.oh; ioh++) {
120
121 const auto iih = ioh * SH - TP;
122 const auto kh_s
123 = jcp_.is_os_blocking ? 0 : div_up(max(0, -iih), DH);
124 const auto kh_f
125 = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH);
126 const auto kh_l = nstl::min(KH_BLOCK, kh_f - kh_s);
127 const auto bs = kd_l * kh_l * jcp_.kw;
128 if (bs < 0) continue;
129
130 if (batchsizes[bs] == -1) {
131 batchsizes[bs] = bs_c;
132 if (first_bs == 0) first_bs = bs;
133 bs_c++;
134 }
135 }
136 }
137 } else {
138 batchsizes[jcp_.max_batch] = bs_c;
139 first_bs = jcp_.max_batch;
140 bs_c++;
141 }
142
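    // brgemm descriptors are indexed by (batch size, M, init/accumulate,
    // N tail, K tail), hence bs_c * adj_M * 2 * 2 * 2 entries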
143 brgs_sz_ = bs_c * adj_M * 2 * 2 * 2;
144 brgs_.resize(brgs_sz_);
145 bd_masks.resize(brgs_sz_);
146
147 const float alpha = 1.0;
148 const float beta = 1.0;
149
150 const auto &p = attr()->post_ops_;
151 const int sum_idx = p.find(primitive_kind::sum);
152 with_sum = (sum_idx != -1);
153
154 // os_blocking is supported for exec_trans only
155 assert(IMPLICATION(jcp_.exec_type != exec_trans, !jcp_.is_os_blocking));
156 assert(IMPLICATION(jcp_.is_os_blocking,
157 jcp_.os_block % jcp_.ow == 0 && jcp_.os_block / jcp_.ow <= jcp_.oh
158 && jcp_.os_block / jcp_.ow == jcp_.oh_block));
159
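    // maybe_M_mask builds the bd_mask used by the uker: within an os block
    // the rows of each ow block are enabled (1, until vM rows are covered)
    // while the oskip rows between them are masked out (0)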
160 auto maybe_M_mask = [&](int brg_idx, brgemm_attr_t &brgattr, int vM,
161 int vbrgM) {
162 if (!jcp_.use_M_mask) return;
163 auto sm_size = vbrgM;
164 bd_masks[brg_idx] = std::make_shared<std::vector<char>>();
165 bd_masks[brg_idx]->resize(sm_size);
166 char *bd_mask = bd_masks[brg_idx]->data();
167 if (jcp_.is_os_blocking) {
168 int ibrgM = 0;
169 int iM = 0;
170 for (int hh = 0; hh < jcp_.oh_block; hh++) {
171 auto M_mask = (iM >= vM) ? 0 : 1;
172 for (int ww = 0; ww < jcp_.ow_block && ibrgM < sm_size;
173 ww++, ibrgM++, iM += M_mask) {
174 bd_mask[ibrgM] = M_mask;
175 }
176 for (int kk = 0; kk < jcp_.oskip && ibrgM < sm_size;
177 kk++, ibrgM++) {
178 bd_mask[ibrgM] = 0;
179 }
180 }
181 for (; ibrgM < sm_size; ibrgM++) {
182 bd_mask[ibrgM] = 0;
183 }
184 } else {
185 for (int ibrgM = 0; ibrgM < sm_size; ibrgM++) {
186 bd_mask[ibrgM] = 1;
187 }
188 }
189 brgattr.bd_mask = bd_mask;
190 };
191
192 ic_chunks = div_up(jcp_.nb_ic, jcp_.nb_ic_blocking);
193 need_postwork = jcp_.with_bias || jcp_.with_eltwise || jcp_.with_binary
194 || (one_of(src_type, u8, s8) && wei_type == s8) // oscales needed
195 || (jcp_.dst_dt != jcp_.acc_dt) || jcp_.with_sum || jcp_.use_M_mask
196 || jcp_.src_zero_point || jcp_.dst_zero_point;
197
198 const auto M_end = nstl::max(jcp_.M, jcp_.M_tail);
199
200 int K_begin = 0;
201 int K_end = (jcp_.K_tail == 0) ? 1 : 2;
202
203 int i_init_begin = (jcp_.K_tail == 0 && jcp_.exec_type == exec_trans
204 && div_up(jcp_.nb_ic, jcp_.nb_ic_blocking) == 1
205 && KD_BLOCK == KD && KH_BLOCK == KH)
206 ? 1
207 : 0;
208 int i_init_end = 2;
209
210 for (int i = 0; i < M_end; i++) {
211 auto vM = i + 1;
        // initialize only the needed brgemm descriptors
213 if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M
214 && vM != jcp_.M_tail)
215 continue;
216 for (int bs = 0; bs <= jcp_.max_batch; bs++) {
217 if (batchsizes[bs] == -1) continue;
218 for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
219 for_(int i_N = 0; i_N < 2; i_N++)
220 for (int i_K = K_begin; i_K < K_end; i_K++) {
221 auto vbeta = (i_init) ? 0 : beta;
222 auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
223 auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
224 auto vbrgM = jcp_.use_M_mask
225 ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
226 : vM;
227 auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K);
                // if the brgemm_t descriptor was already created, skip this iteration
229 if (brgs_[brg_idx] != nullptr) continue;
230 brgs_[brg_idx] = std::make_shared<brgemm_t>();
231 brgemm_t *brg = brgs_[brg_idx].get();
232 if (vN == 0 || vK == 0) continue;
233 brgemm_strides_t brg_strides;
234 brg_strides.stride_a = jcp_.brg_stride_a;
235 brg_strides.stride_b = jcp_.brg_stride_b;
236 brg->req_cal_comp_pads = jcp_.req_brg_comp_pad
237 && (jcp_.src_zero_point || jcp_.s8s8_avx512);
238 const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
239 ? &brg_strides
240 : nullptr;
241 CHECK(brgemm_desc_init(brg, isa, jcp_.brg_type, src_type,
242 wei_type, false, false, brgemm_row_major, alpha, vbeta,
243 jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK,
244 strides_ptr));
245
246 brgemm_attr_t brgattr;
247 brgattr.use_uker = jcp_.use_uker;
248 brgattr.use_interleave_stores = jcp_.use_interleave_stores;
249 brgattr.hint_prefetching = jcp_.hint_prefetching;
250 brgattr.max_bs = bs;
251 brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
252 ? brgemm_bd_loop_innermost
253 : brgemm_ld_loop_innermost;
254 if (jcp_.amx_tile_load_xx) {
                    // assuming a 2x2 decomposition in the amx brgemm kernel
                    // and input overlap by kw
257 const auto bd_blocking = 2 * jcp_.amx_h;
258 const auto ld_blocking = 2 * 16;
259 brgattr.hint_expected_A_size = bd_blocking * jcp_.K
260 * jcp_.kd_block * jcp_.kh_block;
261 brgattr.hint_expected_B_size = ld_blocking * jcp_.K
262 * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
263 brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
264 } else {
265 brgattr.hint_expected_A_size = 0;
266 brgattr.hint_expected_B_size = 0;
267 brgattr.hint_expected_C_size = 0;
268 }
269
270 brgattr.wary_tail_read = false;
271 maybe_M_mask(brg_idx, brgattr, vM, vbrgM);
272 brgattr.bd_mask_level = jcp_.use_M_mask;
273
274 if (is_amx) {
275 brgattr.max_top_vpad = 0;
276 brgattr.max_bottom_vpad = 0;
277 } else {
278 brgattr.max_top_vpad = jcp_.max_vpad;
279 brgattr.max_bottom_vpad = jcp_.max_vpad;
280 }
281 brgattr.fpmath_mode = attr()->fpmath_mode_;
282
                // if post-ops are needed and there are no intermediate
                // calculations (like ic_chunks > 1 or blocking by kernel),
                // the brgemm kernel does not need a code path without post-ops
286 if (need_postwork && ic_chunks == 1 && KD_BLOCK == KD
287 && KH_BLOCK == KH && KW_BLOCK == KW)
288 brgattr.postops_only = true;
289
290 CHECK(brgemm_desc_set_attr(brg, brgattr));
291
292 auto LDD = jcp_.oc_without_padding;
293 brg->with_sum = with_sum;
294 CHECK(brgemm_desc_set_postops(
295 brg, attr(), &dst_md_, LDD, jcp_.bia_dt));
296 jcp_.amx_buf_size_per_thread
297 = nstl::max(brg->get_wsp_buffer_size(),
298 jcp_.amx_buf_size_per_thread);
299 }
300 }
301 }
302
303 brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_);
304 auto scratchpad = scratchpad_registry().registrar();
305 brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
306 if (jcp_.with_scales)
307 book_precomputed_scales(scratchpad, attr()->scales_, OC());
308
309 return status::success;
310}
311
312template <cpu_isa_t isa, bool use_inversion>
313brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_convolution_fwd_t(
314 const pd_t *apd)
315 : primitive_t(apd), bias_d(pd()->weights_md(1)) {}
316
317template <cpu_isa_t isa, bool use_inversion>
318void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range(
319 int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const {
    // This function is needed for exec_base only
321 const auto _pd = pd();
322 const auto &jcp = _pd->jcp_;
    // TODO: calculate these values directly instead of looping over kw
324
325 const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
326 const auto M = is_ow_tail ? jcp.ow_tail : jcp.ow_block;
327 kw_s = kw_full_s = kw_full_f = kw_f = -1;
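    // [kw_s, kw_f) is the range of kw values producing at least one valid
    // output point in this ow block; [kw_full_s, kw_full_f) is the sub-range
    // of kw values covering the whole block without padding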
328 for (int kw = 0; kw < jcp.kw; kw++) {
329 int ow_s {0}, ow_f {0};
330 get_ow_range(ow, kw, ow_s, ow_f);
331 if (ow_s < ow_f) {
332 if (kw_s == -1) kw_s = kw;
333 kw_f = kw + 1;
334 if (ow_f - ow_s == M) {
335 if (kw_full_s == -1) kw_full_s = kw;
336 kw_full_f = kw + 1;
337 }
338 }
339 }
340 if (kw_f == -1) {
341 kw_s = 0;
342 kw_f = 0;
343 }
344 if (kw_full_f == -1) kw_full_s = kw_full_f = kw_f;
345}
346
347template <cpu_isa_t isa, bool use_inversion>
348void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range(
349 int ow, int kw, int &ow_s, int &ow_f) const {
    // This function is needed for exec_base only
351 const auto _pd = pd();
352 const auto &jcp = _pd->jcp_;
353
354 const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
355 const auto M = is_ow_tail ? jcp.ow_tail : jcp.ow_block;
356
357 const auto IW = jcp.iw;
358 const auto SW = jcp.stride_w;
359 const auto LP = jcp.l_pad;
360 const auto DW = jcp.dilate_w + 1;
361
362 const auto iiw = ow * SW - LP;
363 auto iw_lp = iiw + kw * DW;
364 const auto iw_rp = iw_lp + (M - 1) * SW - IW + 1;
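    // iw_lp is the leftmost input column read by this kw tap for the first
    // output point of the block; iw_rp is how far the last output point's
    // read extends past the right input border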
365 ow_s = ow;
366
367 int ker_idx = 0;
368 if (iw_lp < 0) {
369 iw_lp = nstl::abs(iw_lp);
370 ker_idx += div_up(iw_lp, SW);
371 ow_s += ker_idx;
372 }
373 if (iw_rp > 0) ker_idx += div_up(iw_rp, SW);
374 ow_f = ow_s + (M - ker_idx);
375 ow_s = nstl::min(ow_s, ow + M);
376 ow_f = nstl::min(nstl::max(ow_f, ow_s), ow + M);
377}
378
379template <cpu_isa_t isa, bool use_inversion>
380status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel(
381 int bs, int M, int i_N, int i_K, int i_init) {
382 if (M <= 0) return status::success;
383 const auto _pd = pd();
384 const auto &jcp = _pd->jcp_;
385 const auto &brgs = _pd->brgs_;
386
387 auto N = (i_N) ? jcp.N_tail : jcp.N;
388 auto K = (i_K) ? jcp.K_tail : jcp.K;
389 if (N <= 0 || K <= 0) return status::success;
390 auto brg_idx = _pd->get_brg_idx(bs, M - 1, i_init, i_N, i_K);
391 auto brg = brgs[brg_idx];
392 if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0
393 && brg->reduce_dim > 0) {
394 brgemm_kernel_t *brg_kernel = nullptr;
395 CHECK(brgemm_kernel_create(&brg_kernel, *brg));
396 CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
397 if (is_amx) {
398 CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0]));
399 }
400 }
401 return status::success;
402}
403
404template <cpu_isa_t isa, bool use_inversion>
405status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
406 brgemm_t *bcfg, int ker_idx, bool is_init) {
407 if (!bcfg) return status::success;
408 const auto _pd = pd();
409 const auto &jcp = _pd->jcp_;
410
411 bcfg->LDD = (is_init && jcp.use_buffer) ? jcp.LDC : jcp.LDD;
412 bcfg->dt_c = (!is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // inp
413 bcfg->dt_d = (is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // out
414 bcfg->alpha
415 = (!is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer)) ? 1 : 0;
416 bcfg->beta = is_init ? 0 : 1;
417 CHECK(safe_ptr_assign(kernels_po_[ker_idx],
418 new jit_brgemm_kernel_post_ops<isa>(jcp, *bcfg, *_pd->attr())));
419 kernels_po_[ker_idx]->create_kernel();
420 return status::success;
421}
422
423template <cpu_isa_t isa, bool use_inversion>
424void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels(
425 int i_N, int init_bcast_dim, int po_bcast_dim) {
426 const auto _pd = pd();
427 const auto &jcp = _pd->jcp_;
428 const auto &brgs = _pd->brgs_;
429
430 auto N = (i_N) ? jcp.N_tail : jcp.N;
431 if (N <= 0) return;
432 auto i_K = (jcp.K_tail > 0);
433
434 if (init_bcast_dim > 0) {
435 auto brg_idx = _pd->get_brg_idx(
436 _pd->first_bs, init_bcast_dim - 1, 0, i_N, i_K);
437 if (brgs[brg_idx]) {
438 auto init_cfg = *(brgs[brg_idx].get());
439 auto ker_init_idx = get_ker_po_idx(init_bcast_dim - 1, false, i_N);
440 if (init_cfg.load_dim > 0 && kernels_po_[ker_init_idx] == nullptr) {
441 init_cfg.bcast_dim = init_bcast_dim;
442 add_po_kernel(&init_cfg, ker_init_idx, true);
443 }
444 }
445 }
446
447 if ((_pd->need_postwork || jcp.use_buffer) && po_bcast_dim > 0) {
448 auto brg_idx = _pd->get_brg_idx(
449 _pd->first_bs, po_bcast_dim - 1, 0, i_N, i_K);
450 if (brgs[brg_idx]) {
451 auto po_cfg = *(brgs[brg_idx].get());
452 auto ker_po_idx = get_ker_po_idx(po_bcast_dim - 1, true, i_N);
453 if (po_cfg.load_dim > 0 && kernels_po_[ker_po_idx] == nullptr) {
454 po_cfg.bcast_dim = po_bcast_dim;
455 add_po_kernel(&po_cfg, ker_po_idx, false);
456 }
457 }
458 }
459}
460template <cpu_isa_t isa, bool use_inversion>
461int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx(
462 const int kd_b, const int kd_e, const int kh_b, const int kh_e,
463 const int kw_b, const int kw_e) const {
464 const auto _pd = pd();
465 const auto &jcp = _pd->jcp_;
466
467 if (!jcp.req_cal_comp_pad) return 0;
468
469 assert(kd_e > kd_b && kh_e > kh_b);
470 for (int k = 0; k < jcp.ker_ranges_size; k++) {
471 if (kd_b == kd_bs[k] && kd_e == kd_es[k] && kh_b == kh_bs[k]
472 && kh_e == kh_es[k] && kw_b == kw_bs[k] && kw_e == kw_es[k]) {
473 return k;
474 }
475 }
476
477 return -1;
478}
479
480template <cpu_isa_t isa, bool use_inversion>
481int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(const int g,
482 const int ocb, const int ow, const int kd_b, const int kd_e,
483 const int kh_b, const int kh_e, const int kw_b, const int kw_e) const {
484 const auto _pd = pd();
485 const auto &jcp = _pd->jcp_;
486
487 if (!jcp.src_zero_point && !jcp.s8s8_avx512) return 0;
488
489 const auto comp_idx = get_comp_ker_idx(kd_b, kd_e, kh_b, kh_e, kw_b, kw_e);
490 assert(IMPLICATION(jcp.req_cal_comp_pad, comp_idx >= 0));
491
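    // with pre-calculated padding compensation the buffer is indexed by
    // (g, ocb, comp_idx); otherwise there is a single entry per (g, ocb)
    // output channel block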
492 return jcp.req_cal_comp_pad
493 ? g * comp_ocb_sz + ocb * comp_ker_sz + comp_idx * comp_kw_sz
494 : (g * jcp.nb_oc + ocb) * jcp.oc_block;
495}
496
497template <cpu_isa_t isa, bool use_inversion>
498status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {
499
500 const auto _pd = pd();
501 const auto &jcp = _pd->jcp_;
502
503 bia_dsz = jcp.bia_dsz;
504 acc_dsz = jcp.acc_dsz;
505 src_dsz = jcp.src_dsz;
506 wei_dsz = jcp.wei_dsz;
507 dst_dsz = jcp.dst_dsz;
508
509 auto ndims = _pd->ndims();
510 if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");
511
512 KD = ndims_pick(jcp.kd, 1, 1);
513 KH = ndims_pick(jcp.kh, jcp.kh, 1);
514 KW = jcp.kw;
515
516 EXT_KD = ndims_pick(jcp.ext_kd, 1, 1);
517 EXT_KH = ndims_pick(jcp.ext_kh, jcp.ext_kh, 1);
518 EXT_KW = jcp.ext_kw;
519
520 IDP = ndims_pick(jcp.idp, 1, 1);
521 IHP = ndims_pick(jcp.ihp, jcp.ihp, 1);
522 IWP = jcp.iwp;
523
524 KS = KD * KH * KW;
525 KD_BLOCK = ndims_pick(jcp.kd_block, 1, 1);
526 KH_BLOCK = ndims_pick(jcp.kh_block, jcp.kh_block, 1);
527 KW_BLOCK = jcp.kw_block;
528 KD_BLOCK_PAD = ndims_pick(jcp.kd_block_pad, 1, 1);
529 KH_BLOCK_PAD = ndims_pick(jcp.kh_block_pad, jcp.kh_block_pad, 1);
530 ID = ndims_pick(jcp.id, 1, 1);
531 IH = ndims_pick(jcp.ih, jcp.ih, 1);
532 IW = jcp.iw;
533 OD = ndims_pick(jcp.od, 1, 1);
534 OH = ndims_pick(jcp.oh, jcp.oh, 1);
535 OW = jcp.ow;
536 SD = ndims_pick(jcp.stride_d, 1, 1);
537 SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
538 SW = jcp.stride_w;
539 FP = ndims_pick(jcp.f_pad, 0, 0);
540 TP = ndims_pick(jcp.t_pad, jcp.t_pad, 0);
541 LP = jcp.l_pad;
542 DD = ndims_pick(jcp.dilate_d, 0, 0) + 1;
543 DH = ndims_pick(jcp.dilate_h, jcp.dilate_h, 0) + 1;
544 DW = jcp.dilate_w + 1;
545
546 // const variables used for address calculations
547 src_w_sz = static_cast<dim_t>(IW) * jcp.ngroups * jcp.ic_without_padding;
548 src_h_sz = IH * src_w_sz;
549 src_d_sz = ID * src_h_sz;
550 dst_w_sz = static_cast<dim_t>(OW) * jcp.oc_without_padding;
551 dst_h_sz = OH * dst_w_sz;
552 dst_d_sz = OD * dst_h_sz;
553
554 wei_ic_sz = static_cast<dim_t>(jcp.icp) * jcp.oc_block;
555 wei_kw_sz = KW * wei_ic_sz;
556 wei_kh_sz = KH * wei_kw_sz;
557 wei_kd_sz = KD * wei_kh_sz;
558 wei_ocb_sz = jcp.nb_oc * wei_kd_sz;
559
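    // compensation buffers are laid out as [g][ocb][kernel range][oc_block]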
560 comp_kw_sz = static_cast<dim_t>(jcp.oc_block);
561 comp_ker_sz = jcp.ker_ranges_size * comp_kw_sz;
562 comp_ocb_sz = jcp.nb_oc * comp_ker_sz;
563
564 need_compensation
565 = (jcp.src_zero_point || jcp.s8s8_avx512) && !jcp.req_brg_comp_pad;
566
567 // ---- Initialize arrays ---------------------
568 brg_kernels_.resize(_pd->brgs_sz_);
569 brg_kernel_palettes_.resize(_pd->brgs_sz_);
570
571 for (int i = 0; i < _pd->brgs_sz_; i++)
572 brg_kernels_[i] = nullptr;
573
574 int num_po_kernels = nstl::max(jcp.M, jcp.M_tail);
575 kernels_po_.resize(num_po_kernels * 2 * 2);
576 for (int i = 0; i < num_po_kernels; i++) {
577 for_(int i_init = 0; i_init < 2; i_init++)
578 for (int i_N = 0; i_N < 2; i_N++)
579 kernels_po_[get_ker_po_idx(i, i_init, i_N)] = nullptr;
580 }
581
582 if (jcp.exec_type == exec_trans) {
583 CHECK(safe_ptr_assign(copy_to_pbuffer_,
584 new jit_avx512_core_brgemm_conv_trans_kernel_t(jcp)));
585 CHECK(copy_to_pbuffer_->create_kernel());
586 }
587 if (jcp.copy_block_only) {
588 assert(jcp.exec_type == exec_trans && "Missing copy kernel");
589 const auto iw_block = copy_to_pbuffer_->dst_w(jcp.ow_block);
590 const auto ih_block = get_inp_size(IHP, jcp.oh_block, KH, SH, DH - 1);
591 const auto id_block = get_inp_size(IDP, jcp.od_block, KD, SD, DD - 1);
592
593 pbuf_w_sz = (dim_t)jcp.ic_block * jcp.kh_sets * jcp.kw_sets * iw_block;
594 pbuf_h_sz = pbuf_w_sz * ih_block;
595 pbuf_d_sz = pbuf_h_sz * id_block;
596
597 } else {
598 pbuf_w_sz = (dim_t)jcp.ic_block * jcp.kh_sets * jcp.kw_sets * jcp.iwp;
599 pbuf_h_sz = pbuf_w_sz * jcp.ihp;
600 pbuf_d_sz = pbuf_h_sz * jcp.idp;
601 }
602
603 if (jcp.req_cal_comp_pad) {
604 CHECK(safe_ptr_assign(comp_vpad_pbuffer_,
605 new jit_avx512_core_brgemm_conv_comp_pad_kernel_t(jcp)));
606 CHECK(comp_vpad_pbuffer_->create_kernel());
607 }
608
609 is_amx = brgemm_convolution_utils::is_amx(isa);
610
    // TODO: this is needed only if the d/h padding is greater than kd/kh
612 int M_begin = 0;
613 int M_end = (jcp.M_tail == jcp.M) ? 1 : 2;
614 int N_begin = 0;
615 int N_end = (jcp.N_tail == jcp.N) ? 1 : 2;
616 int K_begin = 0;
617 int K_end = (jcp.K_tail == 0) ? 1 : 2;
618 int i_init_begin = (jcp.K_tail == 0 && jcp.exec_type == exec_trans
619 && div_up(jcp.nb_ic, jcp.nb_ic_blocking) == 1
620 && KD_BLOCK == KD && KH_BLOCK == KH)
621 ? 1
622 : 0;
623 int i_init_end = 2;
624
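    // create only the brgemm kernel variants that can be dispatched at
    // execution time: full and tail M/N/K and both accumulation modes
    // (the i_init == 1 variants were created with beta = 0, i.e. C is
    // initialized rather than accumulated into)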
625 for (int bs = 0; bs <= jcp.max_batch; bs++) {
626 if (_pd->batchsizes[bs] == -1) continue;
627
628 for_(int i_N = N_begin; i_N < N_end; i_N++)
629 for_(int i_M = M_begin; i_M < M_end; i_M++)
630 for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
631 for (int i_K = K_begin; i_K < K_end; i_K++) {
632 auto M = (i_M) ? jcp.M_tail : jcp.M;
633 if (M <= 0) continue;
634 add_brg_kernel(bs, M, i_N, i_K, i_init);
635 }
636 }
637
638 for_(int i_N = N_begin; i_N < N_end; i_N++)
639 for (int i_M = M_begin; i_M < M_end; i_M++) {
        // initialize "init" and "po" kernels for cases when we never call
        // brgemm kernels, e.g. for d/h padded areas
        // TODO: do this only if d/h padding > kd/kh
643 if (IMPLICATION(jcp.exec_type == exec_trans,
644 jcp.od > jcp.id || jcp.oh > jcp.ih)) {
645 auto M = (i_M) ? jcp.M_tail : jcp.M;
646 add_po_kernels(i_N, M, M);
647 }
648 }
649
650 if (jcp.exec_type == exec_base) {
        // create brgemm kernels for ow_blocks with padded areas and
        // apply post-ops on the final iteration over kw to padded areas in the ow_block
653 int kw_s {0}, kw_full_s {0}, kw_full_f {0}, kw_f {0}, ow_s {0},
654 ow_f {0};
655 for (int ow = 0; ow < OW; ow += jcp.ow_block) {
656 get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
657 for (int kw = kw_s; kw < kw_f; kw++) {
658 get_ow_range(ow, kw, ow_s, ow_f);
659 if (ow_f - ow_s <= 0) continue;
660
661 auto M = ow_f - ow_s;
662 if (M <= 0) continue;
663 for (int bs = 0; bs <= jcp.max_batch; bs++) {
664 if (_pd->batchsizes[bs] == -1) continue;
665 for_(int i_init = 0; i_init < 2; i_init++)
666 for_(int i_N = 0; i_N < 2; i_N++)
667 for (int i_K = 0; i_K < 2; i_K++) {
668 add_brg_kernel(bs, M, i_N, i_K, i_init);
669 }
670 }
671 }
672
673 bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
674 for_(int i_N = 0; i_N < 2; i_N++)
675 for (int i_side = 0; i_side < 2; i_side++) {
676 auto M = is_ow_tail ? jcp.M_tail : jcp.M;
677 if (M <= 0) continue;
678 get_ow_range(ow, kw_s, ow_s, ow_f);
679 const auto init_bcast_dim
680 = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
681 get_ow_range(ow, kw_f - 1, ow_s, ow_f);
682 const auto po_bcast_dim
683 = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
684 add_po_kernels(i_N, init_bcast_dim, po_bcast_dim);
685 }
686
687 if (kw_f == jcp.kw && kw_s == 0) break;
688 }
689
690 for (int ow = (jcp.nb_ow - 1) * jcp.ow_block; ow >= 0;
691 ow -= jcp.ow_block) {
692 get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
693 for (int kw = kw_s; kw < kw_f; kw++) {
694 get_ow_range(ow, kw, ow_s, ow_f);
695 if (ow_f - ow_s <= 0) continue;
696
697 auto M = ow_f - ow_s;
698 if (M <= 0) continue;
699 for (int bs = 0; bs <= jcp.max_batch; bs++) {
700 if (_pd->batchsizes[bs] == -1) continue;
701 for_(int i_init = 0; i_init < 2; i_init++)
702 for_(int i_N = 0; i_N < 2; i_N++)
703 for (int i_K = 0; i_K < 2; i_K++) {
704 add_brg_kernel(bs, M, i_N, i_K, i_init);
705 }
706 }
707 }
708
709 bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
710
711 for_(int i_N = 0; i_N < 2; i_N++)
712 for (int i_side = 0; i_side < 2; i_side++) {
713 auto M = is_ow_tail ? jcp.M_tail : jcp.M;
714 if (M <= 0) continue;
715 get_ow_range(ow, kw_s, ow_s, ow_f);
716 const auto init_bcast_dim
717 = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
718 get_ow_range(ow, kw_f - 1, ow_s, ow_f);
719 const auto po_bcast_dim
720 = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
721 add_po_kernels(i_N, init_bcast_dim, po_bcast_dim);
722 }
723
724 if (kw_f == jcp.kw && kw_s == 0) break;
725 }
726 }
727
728 // pre-calculated values
729 if (jcp.exec_type == exec_vpad) {
730 owb_kw_top_vpads.resize(jcp.nb_ow * jcp.kw);
731 owb_kw_bottom_vpads.resize(jcp.nb_ow * jcp.kw);
732
733 for (int owb = 0; owb < jcp.nb_ow; owb++) {
734 const int ow = owb * jcp.ow_block;
735 const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
736 const int ow_b {ow}, ow_e {ow + (is_ow_tail ? jcp.M_tail : jcp.M)};
737 const auto ow_l = ow_e - ow_b;
738 MAYBE_UNUSED(ow_l);
739 assert(0 <= ow_l && ow_l <= jcp.ow_block);
740 const auto iiw_b = ow_b * SW - LP;
741 const auto iiw_e = (ow_e - 1) * SW - LP + 1;
742 const auto iiw_l = iiw_e - iiw_b;
743 for (int kw = 0; kw < KW; kw++) {
744 const auto iw = iiw_b + kw * DW;
745 const auto top_vpad = iw >= 0 ? 0 : div_up(abs(iw), SW);
746 const auto bottom_vpad
747 = iw + iiw_l <= IW ? 0 : div_up(iw + iiw_l - IW, SW);
748 assert(top_vpad == 0 || bottom_vpad == 0);
749 owb_kw_top_vpads[owb * KW + kw] = top_vpad;
750 owb_kw_bottom_vpads[owb * KW + kw] = bottom_vpad;
751 }
752 }
753 }
754
    // pre-calculate unique kernel combinations
756 if (jcp.req_cal_comp_pad) {
757 std::set<std::vector<int>> unique_kernels;
758 size_t k = 0;
759 kd_bs.resize(jcp.ker_ranges_size);
760 kd_es.resize(jcp.ker_ranges_size);
761 kh_bs.resize(jcp.ker_ranges_size);
762 kh_es.resize(jcp.ker_ranges_size);
763 kw_bs.resize(jcp.ker_ranges_size);
764 kw_es.resize(jcp.ker_ranges_size);
765
766 const auto update_kernels = [&](int kd_b, int kd_e, int kh_b, int kh_e,
767 int kw_b, int kw_e) {
768 unique_kernels.insert({kd_b, kd_e, kh_b, kh_e, kw_b, kw_e});
769 if (k == unique_kernels.size()) return;
770 kd_bs[k] = kd_b;
771 kd_es[k] = kd_e;
772 kh_bs[k] = kh_b;
773 kh_es[k] = kh_e;
774 kw_bs[k] = kw_b;
775 kw_es[k] = kw_e;
776 k++;
777 assert(k <= static_cast<size_t>(jcp.ker_ranges_size));
778 };
779
780 for_(int odb = 0; odb < jcp.nb_od; odb++)
781 for_(int ohb = 0; ohb < jcp.nb_oh; ohb++)
782 for (int owb = 0; owb < jcp.nb_ow; owb++) {
783 auto od_begin = odb * jcp.od_block;
784 auto od_end = nstl::min(OD, od_begin + jcp.od_block);
785 auto oh_begin = ohb * jcp.oh_block;
786 auto oh_end = jcp.is_os_blocking
787 ? oh_begin + 1
788 : nstl::min(OH, oh_begin + jcp.oh_block);
789 for_(int od = od_begin; od < od_end; od++)
790 for (int oh = oh_begin; oh < oh_end; oh++) {
791 int kw_s {0}, kw_full_s {0}, kw_f {0}, kw_full_f {0};
792 const int ow = owb * jcp.ow_block;
793 const int iid = ndims_pick(od * SD - FP, 0, 0);
794 const int kd_s
795 = ndims_pick(div_up(nstl::max(0, -iid), DD), 0, 0);
796 const int kd_f = ndims_pick(
797 KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD),
798 1, 1);
799 const int iih = ndims_pick(oh * SH - TP, oh * SH - TP, 0);
800 const auto kh_s_ = div_up(max(0, -iih), DH);
801 const auto kh_s = ndims_pick(kh_s_, kh_s_, 0);
802 const auto kh_f_
803 = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH);
804 const auto kh_f = ndims_pick(kh_f_, kh_f_, 1);
805 get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
806 if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
807 if (jcp.exec_type == exec_vpad) {
808 update_kernels(kd_s, kd_f, kh_s, kh_f, 0, KW);
809 } else if (jcp.exec_type == exec_base) {
810 if (kw_s < kw_full_s) {
811 for (auto kw = kw_s; kw < kw_full_s; kw++) {
812 update_kernels(
813 kd_s, kd_f, kh_s, kh_f, kw, kw + 1);
814 }
815 }
816 if (kw_full_s < kw_full_f) {
817 for (auto kw = kw_full_s; kw < kw_full_f;
818 kw += KW_BLOCK) {
819 const auto kw_e
820 = nstl::min(kw_full_f, kw + KW_BLOCK);
821 update_kernels(
822 kd_s, kd_f, kh_s, kh_f, kw, kw_e);
823 }
824 }
825 if (kw_full_f < kw_f) {
826 for (auto kw = kw_full_f; kw < kw_f; kw++) {
827 update_kernels(
828 kd_s, kd_f, kh_s, kh_f, kw, kw + 1);
829 }
830 }
831 }
832 }
833 }
834 }
835 ker_vpad_sz = k;
836 }
837
838 return status::success;
839}
840template <cpu_isa_t isa, bool use_inversion>
841struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
842 brgemm_thread_ctx_t(brgemm_exec_ctx_t &brgemm_ctx_, int ithr_,
843 brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
844 char *wsp_tile_)
845 : brgemm_ctx(brgemm_ctx_)
846 , ithr(ithr_)
847 , brg_batch(brg_batch_)
848 , c_buffer(c_buffer_)
849 , wsp_tile(wsp_tile_) {}
850
851 brgemm_exec_ctx_t &brgemm_ctx;
852 int ithr;
853 brgemm_batch_element_t *__restrict brg_batch;
854 char *c_buffer;
855 char *wsp_tile;
856 S_t cur_palette;
857 int g, n, ocb;
858 int od, odb, oh, ohb, owb;
859 int icc;
860 const float *oscales {nullptr};
861 int32_t src_zp_vals;
862 int32_t *src_zp_comp_ptr;
863 int32_t *dst_zp_vals;
864 int32_t *s8s8_comp_ptr;
865};
866
867template <cpu_isa_t isa, bool use_inversion>
868status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute(
869 const exec_ctx_t &ctx) const {
870 const auto _pd = pd();
871 const auto &jcp = _pd->jcp_;
872
873 DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
874 DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);
875
876 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
877 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
878
879 const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
880 src_scales, wei_scales, _pd->OC(), _pd->attr());
881
882 brgemm_exec_ctx_t brgemm_ctx(ctx, _pd);
883
884 const char *const __restrict src = brgemm_ctx.src;
885 const char *const __restrict wei = brgemm_ctx.weights;
886 const memory_desc_wrapper weights_d(pd()->weights_md(0));
887
888 const auto extra_data_offset
889 = weights_d.size() - weights_d.additional_buffer_size();
890 auto w = const_cast<char *>(brgemm_ctx.weights);
891 const auto s8s8_comp_offset = jcp.req_cal_comp_pad
892 ? jcp.ngroups * jcp.nb_oc * jcp.kd * jcp.kh * jcp.kw * jcp.oc_block
893 : jcp.ngroups * jcp.nb_oc * jcp.oc_block;
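    // s8s8 compensation and src zero-point compensation are stored
    // back-to-back in the extra data area appended to the weights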
894 int32_t *s8s8_compensation = jcp.s8s8_avx512
895 ? reinterpret_cast<int32_t *>(w + extra_data_offset)
896 : nullptr;
897 int32_t *zp_compensation = jcp.src_zero_point
898 ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
899 + (jcp.s8s8_avx512 ? s8s8_comp_offset : 0)
900 : nullptr;
901
902 const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
903 brgemm_batch_element_t *const __restrict brg_batch_global
904 = (jcp.brg_type == brgemm_strd && jcp.exec_type != exec_vpad)
905 ? nullptr
906 : scratchpad.template get<brgemm_batch_element_t>(
907 key_brgemm_primitive_batch);
908 char *const __restrict c_buffer_global = (jcp.use_buffer)
909 ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
910 : nullptr;
911
912 auto inp_p_buffer = (jcp.exec_type == exec_trans)
913 ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer)
914 : nullptr;
915 auto inp_p_buffer_mask = (jcp.exec_type == exec_trans)
916 ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask)
917 : nullptr;
918 int32_t *src_zp_comp_base = jcp.src_zero_point
919 ? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>(
920 key_brgemm_primitive_zp_comp_a)
921 : zp_compensation)
922 : nullptr;
923 int32_t *s8s8_comp_base = jcp.s8s8_avx512
924 ? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>(
925 key_brgemm_primitive_buffer_comp)
926 : s8s8_compensation)
927 : nullptr;
928 const auto dst_zp_vals = jcp.dst_zero_point ? &dst_zero_point : nullptr;
929 const auto src_zp_vals = src_zero_point;
930
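    // pre-compute padding compensation (zero-point and s8s8), if required,
    // before the main parallel loop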
931 cal_compensation(wei, src_zp_comp_base, s8s8_comp_base);
932
933 char *const wsp_tile_global = is_amx
934 ? scratchpad.template get<char>(key_conv_amx_tile_buffer)
935 : nullptr;
936
937 // --------------- Parallel section ------------------------------
938 const dim_t work_amount = static_cast<dim_t>(jcp.mb) * jcp.ngroups
939 * jcp.nb_oc * jcp.nb_od * jcp.nb_oh * jcp.nb_ow;
    // TODO: consider making the loop over icc the innermost one, because in
    // the current implementation, if the buffer is used, we accumulate only
    // one row in it; alternatively, set ic_chunks = 1 when use_buffer is set,
    // or (more general) increase the buffer size to store several rows
944
945 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
946 if (ithr >= work_amount) return;
947
948 brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global
949 + static_cast<size_t>(ithr) * jcp.adjusted_batch_size;
950 char *const __restrict c_buffer = (jcp.use_buffer)
951 ? c_buffer_global + ithr * acc_dsz * jcp.buffer_size
952 : nullptr;
953 char *inp_buffer = (jcp.exec_type == exec_trans)
954 ? inp_p_buffer + src_dsz * ithr * jcp.inp_buffer_size
955 : nullptr;
956 if (is_amx) {
            // Workaround: on some machines a tile load may SEGFAULT
            // if the page has not been touched beforehand
959 for (dim_t i = 0; i < jcp.inp_buffer_size;
960 i += brgemm_convolution_utils::P4K)
961 inp_buffer[i] = 0;
962 }
963
964 uint8_t *__restrict inp_buffer_mask = (jcp.exec_type == exec_trans)
965 ? inp_p_buffer_mask + ithr * jcp.inp_buffer_mask_size
966 : nullptr;
967
968 char *const wsp_tile = is_amx
969 ? wsp_tile_global + ithr * jcp.amx_buf_size_per_thread
970 : nullptr;
971 dim_t start {0}, end {0};
972 balance211(work_amount, nthr, ithr, start, end);
973 int n {0}, g {0}, ocb {0}, odb {0}, ohb {0}, owb {0};
974 if (jcp.loop_order == loop_ndhwgc)
975 nd_iterator_init(start, n, jcp.mb, odb, jcp.nb_od, ohb, jcp.nb_oh,
976 owb, jcp.nb_ow, g, jcp.ngroups, ocb, jcp.nb_oc);
977 else if (jcp.loop_order == loop_ngcdhw)
978 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc,
979 odb, jcp.nb_od, ohb, jcp.nb_oh, owb, jcp.nb_ow);
980 else
981 assert(!"Unknown loop order");
982
983 brgemm_thread_ctx_t btc(
984 brgemm_ctx, ithr, brg_batch, c_buffer, wsp_tile);
985 std::memset(btc.cur_palette.a, 0, AMX_PALETTE_SIZE);
986
987 int last_n = -1;
988 int last_g = -1;
989 int last_icc = -1;
990 int last_odb = -1;
991 int last_ohb = -1;
992 int last_owb = -1;
993 for (auto work = start; work < end; work++) {
994 btc.g = g;
995 btc.n = n;
996 btc.ocb = ocb;
997 btc.odb = odb;
998 btc.ohb = ohb;
999 btc.owb = owb;
1000 btc.oscales = oscales;
1001 btc.src_zp_vals = src_zp_vals;
1002 btc.dst_zp_vals = jcp.dst_zero_point ? dst_zp_vals : nullptr;
1003 btc.src_zp_comp_ptr
1004 = jcp.src_zero_point ? src_zp_comp_base : nullptr;
1005 btc.s8s8_comp_ptr = jcp.s8s8_avx512 ? s8s8_comp_base : nullptr;
1006
1007 if (jcp.exec_type == exec_trans && (last_n != n || last_g != g)) {
1008 if (!jcp.copy_block_only)
1009 std::memset(
1010 inp_buffer_mask, false, jcp.inp_buffer_mask_size);
1011 }
1012 auto od_begin = odb * jcp.od_block;
1013 auto od_end = nstl::min(OD, od_begin + jcp.od_block);
1014 auto oh_begin = ohb * jcp.oh_block;
            // if is_os_blocking is true we do only one iteration of the loop
            // over oh and process the entire oh block in a single kernel call
1017 auto oh_end = jcp.is_os_blocking
1018 ? oh_begin + 1
1019 : nstl::min(OH, oh_begin + jcp.oh_block);
1020 for_(int od = od_begin; od < od_end; od++)
1021 for (int oh = oh_begin; oh < oh_end; oh++) {
1022 for (int icc = 0; icc < _pd->ic_chunks; icc++) {
1023 btc.od = od;
1024 btc.oh = oh;
1025 btc.icc = icc;
1026
1027 if (jcp.exec_type == exec_base) {
1028 ker_base(btc);
1029 } else if (jcp.exec_type == exec_trans) {
1030 maybe_conv_inp(ithr, src, inp_buffer, inp_buffer_mask,
1031 g, n, icc, odb, ohb, owb, last_g, last_n,
1032 last_icc, last_odb, last_ohb, last_owb);
1033 ker_trans(btc, inp_buffer);
1034 } else if (jcp.exec_type == exec_vpad) {
1035 ker_vpad(btc);
1036 } else
1037 assert(!"Unknown exec type");
1038 last_n = n;
1039 last_g = g;
1040 last_icc = icc;
1041 last_odb = odb;
1042 last_ohb = ohb;
1043 last_owb = owb;
1044 }
1045 }
1046 if (jcp.loop_order == loop_ndhwgc)
1047 nd_iterator_step(n, jcp.mb, odb, jcp.nb_od, ohb, jcp.nb_oh, owb,
1048 jcp.nb_ow, g, jcp.ngroups, ocb, jcp.nb_oc);
1049 else if (jcp.loop_order == loop_ngcdhw)
1050 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, odb,
1051 jcp.nb_od, ohb, jcp.nb_oh, owb, jcp.nb_ow);
1052 else
1053 assert(!"Unknown loop order");
1054 }
1055 if (is_amx) { amx_tile_release(); }
1056 });
1057
1058 if (_pd->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx);
1059
1060 return status::success;
1061}
1062
1063template <cpu_isa_t isa, bool use_inversion>
1064status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation(
1065 const char *__restrict weights, int32_t *src_zp_buffer,
1066 int32_t *s8s8_comp_buffer) const {
1067 const auto _pd = pd();
1068 const auto &jcp = _pd->jcp_;
1069
1070 if (!jcp.req_cal_comp_pad) return status::success;
1071
1072 if (jcp.src_zero_point)
1073 std::memset(src_zp_buffer, 0, sizeof(int32_t) * jcp.comp_a_buffer_size);
1074 if (jcp.s8s8_avx512)
1075 std::memset(s8s8_comp_buffer, 0,
1076 sizeof(int32_t) * jcp.s8s8_comp_buffer_size);
1077
1078 const auto work_amount
1079 = static_cast<dim_t>(jcp.ngroups) * jcp.nb_oc * ker_vpad_sz;
1080 const auto is_small_shape = work_amount <= jcp.nthr
1081 && (work_amount * jcp.oc_block * jcp.icp
1082 <= platform::get_per_core_cache_size(1));
1083 const int nthr = is_small_shape ? 1 : jcp.nthr;
1084
1085 parallel(nthr, [&](const int ithr, const int nthr) {
1086 if (ithr >= work_amount) return;
1087
1088 dim_t start {0}, end {0};
1089 int g {0}, ocb {0}, k {0};
1090 balance211(work_amount, nthr, ithr, start, end);
1091 nd_iterator_init(start, g, jcp.ngroups, ocb, jcp.nb_oc, k, ker_vpad_sz);
1092 for (auto work = start; work < end; work++) {
1093 const dim_t kd_b {kd_bs[k]}, kd_e {kd_es[k]}, kh_b {kh_bs[k]},
1094 kh_e {kh_es[k]}, kw_b {kw_bs[k]}, kw_e {kw_es[k]};
1095 assert(kd_e > kd_b && kh_e > kh_b && kw_e > kw_b);
1096
1097 const auto buffer_offs
1098 = g * comp_ocb_sz + ocb * comp_ker_sz + k * comp_kw_sz;
1099 const auto wei_offs = (g * jcp.nb_oc + ocb) * wei_kd_sz
1100 + kd_b * wei_kh_sz + kh_b * wei_kw_sz + kw_b * wei_ic_sz;
1101
1102 jit_brgemm_conv_comp_pad_call_s p;
1103
1104 p.kd_l = kd_e - kd_b;
1105 p.kh_l = kh_e - kh_b;
1106 p.kw_l = kw_e - kw_b;
1107 p.ptr_in = &weights[wei_offs];
1108 p.ptr_zp_out = jcp.src_zero_point ? &src_zp_buffer[buffer_offs]
1109 : nullptr;
1110 p.ptr_cp_out = jcp.s8s8_avx512 ? &s8s8_comp_buffer[buffer_offs]
1111 : nullptr;
1112 (*comp_vpad_pbuffer_)(&p);
1113
1114 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, k, ker_vpad_sz);
1115 }
1116 });
1117 return status::success;
1118}
1119
1120template <cpu_isa_t isa, bool use_inversion>
1121void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
1122 char *dst_base, char *dst, char *c_buffer, const char *bias_w, int od,
1123 int oh, int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f,
1124 int kd_l, int kh_l, const void *post_ops_binary_rhs_arg_vec,
1125 const float *oscales, int32_t src_zp_vals, int32_t *src_zp_ptr,
1126 int32_t *dst_zp_ptr, int32_t *s8s8_compensation, bool maybe_do_init,
1127 bool do_postwork, bool do_post_comp) const {
1128
1129 const auto _pd = pd();
1130 const auto &jcp = _pd->jcp_;
1131
1132 const auto do_init
1133 = maybe_do_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
1134 if (!do_init && !do_postwork) return;
1135
1136 assert(!jcp.is_os_blocking);
1137
1138 const bool is_ow_tail = (OW - ow < jcp.ow_block);
1139
1140 const auto M = is_ow_tail ? jcp.M_tail : jcp.M;
1141 const auto kdh_l = kd_l * kh_l;
1142 const auto ow_s = (kdh_l <= 0) ? ow : ker_ow_s;
1143 const auto ow_f = (kdh_l <= 0) ? ow : ker_ow_f;
1144 assert(ow <= ow_s && ow_s <= ow_f && ow_f <= ow + M);
1145
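    // the left ([ow, ow_s)) and right ([ow_f, ow + M)) parts of the ow block
    // that are not covered by the brgemm call are initialized and/or
    // post-processed by the dedicated outwork kernels below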
1146 brgemm_kernel_post_ops_t p;
1147 if (do_postwork) {
1148 p.ptr_bias = (void *)(bias_w);
1149 p.ptr_scales = (void *)(&oscales[jcp.is_oc_scale * g_oc]);
1150 p.ptr_binary_post_ops_rhs = post_ops_binary_rhs_arg_vec;
1151 p.dst_orig = dst;
1152 p.c_zp_values = dst_zp_ptr;
1153 p.a_comp_val = src_zp_vals;
1154 }
1155
1156 auto call_outwork_ker = [&](bool is_postwork, bool has_postcomp,
1157 int ow_pw_s, int ow_pw_l) {
1158 auto ker_po_idx = get_ker_po_idx(ow_pw_l - 1, is_postwork, is_oc_tail);
1159 const auto outwork_ker = kernels_po_[ker_po_idx].get();
1160 assert(outwork_ker != nullptr && ow_pw_l == outwork_ker->brg.bcast_dim);
1161 if (is_postwork) {
1162 p.apply_comp = has_postcomp;
1163 p.a_zp_compensation = has_postcomp && jcp.src_zero_point
1164 ? &src_zp_ptr[ow_pw_s * jcp.LDB]
1165 : src_zp_ptr;
1166 p.s8s8_compensation = has_postcomp && jcp.s8s8_avx512
1167 ? &s8s8_compensation[ow_pw_s * jcp.LDB]
1168 : s8s8_compensation;
1169
1170 p.ptr_out = dst_base
1171 + dst_dsz
1172 * (od * dst_h_sz + oh * dst_w_sz
1173 + ow_pw_s * jcp.oc_without_padding);
1174 p.ptr_in = static_cast<void *>(jcp.use_buffer
1175 ? (c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC)
1176 : p.ptr_out);
1177 } else {
1178 p.apply_comp = has_postcomp;
1179 char *const ptr_Cz = jcp.use_buffer
1180 ? (c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC)
1181 : dst_base
1182 + dst_dsz
1183 * (od * dst_h_sz + oh * dst_w_sz
1184 + ow_pw_s * jcp.oc_without_padding);
1185 p.ptr_out = static_cast<void *>(ptr_Cz);
1186 }
1187 (*outwork_ker)(&p);
1188 };
1189
1190 if (ow < ow_s) {
1191 // left side
1192 const auto ow_pw_l = ow_s - ow;
1193 if (do_init) call_outwork_ker(false, false, ow, ow_pw_l);
1194 if (do_postwork) call_outwork_ker(true, do_post_comp, ow, ow_pw_l);
1195 }
1196 if (ow_f < ow + M) {
1197 // right side
1198 const auto ow_pw_l = ow + M - ow_f;
1199 if (do_init) call_outwork_ker(false, false, ow_f, ow_pw_l);
1200 if (do_postwork) call_outwork_ker(true, do_post_comp, ow_f, ow_pw_l);
1201 }
1202}
1203
1204template <cpu_isa_t isa, bool use_inversion>
1205void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel(
1206 brgemm_thread_ctx_t &btc, int brg_idx, int batch_size, char *ptr_C,
1207 char *ptr_D, const char *bias_w, int g_oc, bool do_postops,
1208 const void *binary_post_ops_rhs, int32_t src_zp_vals,
1209 int32_t *src_zp_ptr, int32_t *dst_zp_ptr, int32_t *s8s8_comp,
1210 bool do_only_comp) const {
1211
1212 const auto _pd = pd();
1213 const auto &jcp = _pd->jcp_;
1214
1215 const auto brg_ker = brg_kernels_[brg_idx].get();
1216 assert(brg_ker != nullptr);
1217
1218 // TODO: avoid costly tile reconfigurations
1219 if (is_amx) {
1220 if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
1221 AMX_PALETTE_SIZE)
1222 != 0) {
1223 amx_tile_configure(brg_kernel_palettes_[brg_idx].a);
1224 std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
1225 AMX_PALETTE_SIZE);
1226 }
1227 }
1228
1229 const auto do_only_pass_comp = !do_postops && jcp.src_zero_point
1230 && (jcp.req_brg_comp_pad || jcp.max_vpad > 0);
1231 const auto maybe_do_postops
1232 = one_of(true, do_postops, do_only_comp, do_only_pass_comp);
1233 if (maybe_do_postops) {
1234 const brgemm_post_ops_data_t post_ops_data {
1235 static_cast<const char *>(bias_w),
1236 &btc.oscales[jcp.is_oc_scale * g_oc], binary_post_ops_rhs,
1237 static_cast<size_t>(g_oc), 0, btc.brgemm_ctx.dst, 0,
1238 static_cast<void *>(src_zp_ptr), nullptr,
1239 static_cast<void *>(dst_zp_ptr), false, src_zp_vals,
1240 do_only_comp, do_only_pass_comp};
1241
1242 void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
1243 : static_cast<void *>(s8s8_comp);
1244
1245 if (do_postops)
1246 brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch,
1247 ptr_C, ptr_D, post_ops_data, scratch);
1248 else
1249 brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch,
1250 ptr_C, ptr_C, post_ops_data, scratch);
1251 } else
1252 brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C,
1253 static_cast<void *>(btc.wsp_tile));
1254}
1255
1256template <cpu_isa_t isa, bool use_inversion>
1257void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
1258 const char *__restrict src, char *__restrict inp_buffer,
1259 uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb,
1260 int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb,
1261 int last_ohb, int last_owb) const {
1262
1263 const auto _pd = pd();
1264 const auto &jcp = _pd->jcp_;
1265 const auto icb = icc * jcp.nb_ic_blocking;
1266
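// bmask marks which (icb, odb, ohb, owb) blocks have already been copied into
// the persistent input buffer, so the copy can be skipped when a block is
// reused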
1267#define bmask(icb, odb, ohb, owb) \
1268 inp_buffer_mask[(((icb)*jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \
1269 * jcp.nb_ow \
1270 + (owb)]
1271
1272 if (jcp.copy_block_only) {
1273 if (last_g == g && last_n == n && last_icc == icc && last_odb == odb
1274 && last_ohb == ohb && last_owb == owb)
1275 return;
1276 } else {
1277 if (bmask(icb, odb, ohb, owb)) return;
1278 }
1279
1280 auto cp = jit_brgemm_conv_trans_kernel_call_s();
1281
1282 const auto prev_odb = (jcp.copy_block_only || odb == 0
1283 || bmask(icb, odb - 1, ohb, owb) == 0)
1284 ? false
1285 : true;
1286
1287 const auto prev_ohb = (jcp.copy_block_only || ohb == 0
1288 || bmask(icb, odb, ohb - 1, owb) == 0)
1289 ? false
1290 : true;
1291
1292 const auto prev_odb_ohb
1293 = (jcp.copy_block_only
1294 || (odb > 0 && ohb > 0
1295 && bmask(icb, odb - 1, ohb - 1, owb) == 0))
1296 ? false
1297 : true;
1298
1299 const auto ic = icb * jcp.ic_block;
1300 const auto g_ic = g * jcp.ic + ic;
1301 const auto oh = ohb * jcp.oh_block;
1302 const auto ow = owb * jcp.ow_block;
1303 const auto iw = nstl::max(0, ow * SW - LP);
1304
1305 int id_start {0}, id_end {0}, ih_start {0}, ih_end {0};
1306 int virt_id_start {0}, virt_id_end {0}, virt_ih_start {0}, virt_ih_end {0};
1307
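    // get_start_end computes the input range [start, end) that must be copied
    // for this output block, skipping rows already copied for the previous
    // block when possible; the virt_* values also cover the padded area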
1308 auto get_start_end = [](int &start, int &end, int &virt_start,
1309 int &virt_end, int b, int bs, int i, int o,
1310 int s, int p, int k, int d, bool prev) {
1311 const auto o_b = saturate(0, o, b * bs);
1312 const auto prev_o_b = saturate(0, o, (b - 1) * bs);
1313 const auto virt_cur_start = o_b * s - p;
1314 const auto cur_start = saturate(0, i, virt_cur_start);
1315 const auto virt_prev_start = prev_o_b * s - p;
1316 const auto i_bs = get_inp_size(i, bs, k, s, d);
1317 const auto virt_i_bs = calculate_end_padding(
1318 0, bs, 0, s, calculate_extended_filter_size(k, d));
1319 const auto virt_prev_end = prev ? virt_prev_start + virt_i_bs : -p;
1320 const auto prev_end = prev ? saturate(0, i, virt_prev_end) : 0;
1321 virt_start = nstl::max(virt_prev_end, virt_cur_start);
1322 start = nstl::max(prev_end, cur_start);
1323 virt_end = virt_cur_start + virt_i_bs;
1324 end = saturate(0, i, cur_start + i_bs);
1325 };
1326 get_start_end(id_start, id_end, virt_id_start, virt_id_end, odb,
1327 jcp.od_block, nstl::min(ID, IDP - FP), OD, SD, FP, KD, DD - 1,
1328 prev_odb && prev_odb_ohb);
1329 get_start_end(ih_start, ih_end, virt_ih_start, virt_ih_end, ohb,
1330 jcp.oh_block, nstl::min(IH, IHP - TP), OH, SH, TP, KH, DH - 1,
1331 prev_ohb && prev_odb_ohb);
1332
1333 // how many real data rows to copy (including padding)
1334 const auto rows_to_copy = ih_end - ih_start;
1335 cp.owb = owb;
1336 cp.ic = ic;
1337 const auto iw_buf = jcp.copy_block_only ? 0 : (ow * SW);
1338 dim_t inp_offset_start, out_offset_start;
1339
1340 for (int kh = 0; kh < jcp.kh_sets; kh++) {
1341 if (jcp.kh_sets > 1) {
1342 assert(!jcp.is_os_blocking);
1343 const auto ih_s = oh * SH + kh * DH - TP;
1344 const auto ih_f = (oh + jcp.oh_block - 1) * SH + kh * DH - TP + 1;
1345
1346 cp.t_pad = max(0, -ih_s);
1347 cp.b_pad = max(0, ih_f - jcp.ih);
1348 cp.h_count = max(0, jcp.oh_block);
1349 const auto ih_buf = (jcp.copy_block_only ? 0 : ih_start) + TP;
1350
1351 inp_offset_start = static_cast<dim_t>(n) * src_d_sz
1352 + max(ih_s, ih_start) * src_w_sz
1353 + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
1354
1355 // inp_buffer has physical padding
1356 out_offset_start = (jcp.copy_block_only ? 0
1357 : static_cast<dim_t>(icb)
1358 * pbuf_d_sz)
1359 + ih_buf * pbuf_w_sz
1360 + (iw_buf * jcp.kh_sets + kh) * jcp.kw_sets * jcp.ic_block;
1361 } else {
            // For os_blocking:
            // We have to zero the top and bottom padding now, taking into
            // account that the batch size is always the same (kh_s is 0 for
            // os_blocking).
            // TODO: extend M_mask (it may differ for different kh) to avoid
            // copying top/bottom padded rows and extra calculations in the
            // kernel; also, for convolutions with pw == 0 the copy routine
            // may not be needed
1368 cp.t_pad = jcp.is_os_blocking ? max(0, -virt_ih_start) : 0;
1369 cp.b_pad = jcp.is_os_blocking ? max(0, virt_ih_end - IH) : 0;
1370 cp.h_count = max(0, rows_to_copy) + cp.t_pad + cp.b_pad;
1371 const auto ih_buf
1372 = (jcp.copy_block_only ? 0 : ih_start) + TP - cp.t_pad;
1373
1374 inp_offset_start = static_cast<dim_t>(n) * src_d_sz
1375 + ih_start * src_w_sz
1376 + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
1377
1378 // inp_buffer has physical padding
1379 out_offset_start = (jcp.copy_block_only ? 0
1380 : static_cast<dim_t>(icb)
1381 * pbuf_d_sz)
1382 + ih_buf * pbuf_w_sz
1383 + iw_buf * jcp.ic_block * jcp.kh_sets * jcp.kw_sets;
1384 }
1385
1386 for (int id = id_start; id < id_end; id++) {
1387 const auto inp_offset = inp_offset_start + id * src_h_sz;
1388 const auto id_buf = id - (jcp.copy_block_only ? id_start : 0) + FP;
1389 const auto out_offset = out_offset_start + id_buf * pbuf_h_sz;
1390 cp.src = src + src_dsz * inp_offset;
1391 cp.dst = inp_buffer + src_dsz * out_offset;
1392 (*copy_to_pbuffer_)(&cp);
1393 }
1394 }
1395 if (!jcp.copy_block_only) bmask(icb, odb, ohb, owb) = 1;
1396
1397#undef bmask
1398}
1399
1400#define BRGEMM_CONV_KER_HEADER \
1401 const char *const __restrict src = btc.brgemm_ctx.src; \
1402 const char *const __restrict weights = btc.brgemm_ctx.weights; \
1403 const char *const __restrict bias = btc.brgemm_ctx.bias; \
1404 char *const __restrict dst = btc.brgemm_ctx.dst; \
1405 const std::vector<const void *> &post_ops_binary_rhs_arg_vec \
1406 = btc.brgemm_ctx.post_ops_binary_rhs_arg_vec; \
1407 const int oc = btc.ocb * jcp.oc_block; \
1408 const int g_oc = btc.g * jcp.oc + oc; \
1409 const int icb = btc.icc * jcp.nb_ic_blocking; \
1410 const int ic = icb * jcp.ic_block; \
1411 const int g_ic = btc.g * jcp.ic + ic; \
1412 const int ow = btc.owb * jcp.ow_block; \
1413 const int oh = btc.ohb * jcp.oh_block; \
1414 const int iid = ndims_pick(btc.od * SD - FP, 0, 0); \
1415 const int kd_s = ndims_pick(div_up(max(0, -iid), DD), 0, 0); \
1416 const int kd_f = ndims_pick( \
1417 KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD), 1, 1); \
1418 const auto kd_l = kd_f - kd_s; \
1419 const auto iih = ndims_pick(btc.oh * SH - TP, btc.oh * SH - TP, 0); \
1420 const auto kh_s_ = div_up(max(0, -iih), DH); \
1421 const auto kh_s = jcp.is_os_blocking ? 0 : ndims_pick(kh_s_, kh_s_, 0); \
1422 const auto kh_f_ = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH); \
1423 const auto kh_f = ndims_pick(kh_f_, kh_f_, 1); \
1424 const auto kh_l = kh_f - kh_s; \
1425 const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block); \
1426 const bool is_ic_tail = (btc.icc == _pd->ic_chunks - 1 \
1427 && ((jcp.ic - ic) % jcp.ic_block != 0)); \
1428 const bool is_ow_tail = (OW - ow < jcp.ow_block); \
1429 const bool is_oh_tail = (OH - oh < jcp.oh_block); \
1430 const char *const __restrict bias_w \
1431 = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr; \
1432 const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb) \
1433 - (is_ic_tail ? 1 : 0); \
1434 char *const __restrict dst_base \
1435 = dst + dst_dsz * (btc.n * dst_d_sz + g_oc); \
1436 char *ptr_C; \
1437 char *ptr_D; \
1438 int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0), iiw_b(0);
1439
1440template <cpu_isa_t isa, bool use_inversion>
1441void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
1442 brgemm_thread_ctx_t &btc) const {
1443
1444 const auto _pd = pd();
1445 const auto &jcp = _pd->jcp_;
1446 auto ndims = _pd->ndims();
1447
1448 BRGEMM_CONV_KER_HEADER;
1449 MAYBE_UNUSED(is_ow_tail);
1450 MAYBE_UNUSED(is_oh_tail);
1451
1452 int kw_s {0}, kw_full_s {0}, kw_f {0}, kw_full_f {0}, kw_b(0), kw_e(0);
1453
1454 get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
1455
1456 const auto src_base = src + src_dsz * (btc.n * src_d_sz + g_ic);
1457 const auto wei_base
1458 = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);
1459
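    // call_brgemm fills brg_batch with the A/B pointers for every (kd, kh, kw)
    // element of the current kernel block and then invokes the brgemm kernel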
1460 const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
1461 int32_t *src_zp, int32_t *s8s8_comp,
1462 bool do_postops, bool do_only_comp) {
1463 if (k_l <= 0) return;
1464
1465 for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
1466 const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
1467 const auto src_ic = ic_off;
1468 const auto wei_ic = ic + ic_off;
1469 const auto n_icb_off = i_icb * k_l;
1470 const auto src_base_ic = src_base + src_dsz * src_ic;
1471 const auto wei_base_ic = wei_base + wei_dsz * wei_ic * jcp.oc_block;
1472
1473 auto k = 0;
1474 for (int kd = kd_b; kd < kd_e; kd++) {
1475 const auto id = iid + kd * DD;
1476 const auto src_base_kd = src_base_ic + src_dsz * id * src_h_sz;
1477 const auto wei_base_kd = wei_base_ic
1478 + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
1479 for (int kh = kh_b; kh < kh_e; kh++) {
1480 const auto ih = iih + kh * DH;
1481 const auto src_base_kh
1482 = src_base_kd + src_dsz * ih * src_w_sz;
1483 const auto wei_base_kh = wei_base_kd
1484 + wei_dsz * maybe_invert(kh, KH) * wei_kw_sz;
1485 for (int kw = kw_b; kw < kw_e; kw++) {
1486 const auto iw = iiw_b + kw * DW;
1487 btc.brg_batch[n_icb_off + k].ptr.A = src_base_kh
1488 + src_dsz * iw * jcp.ngroups
1489 * jcp.ic_without_padding;
1490 btc.brg_batch[n_icb_off + k].vvpad.top = 0;
1491 btc.brg_batch[n_icb_off + k].vvpad.bottom = 0;
1492 // general wei layout is gOdhwI<block_o><block_i>
1493 btc.brg_batch[n_icb_off + k].ptr.B = wei_base_kh
1494 + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;
1495 k++;
1496 }
1497 }
1498 }
1499 }
1500 call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
1501 bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
1502 btc.src_zp_vals, src_zp, btc.dst_zp_vals, s8s8_comp,
1503 do_only_comp);
1504 };
1505
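    // kdhw_loop processes one (kd, kh, kw) block: it determines the valid ow
    // range, sets up the C/D pointers, runs brgemm for the full and tail ic
    // blocks, and handles the padded borders via perform_outwork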
    const auto kdhw_loop = [&]() {
        if (kw_e - kw_b <= 0) return;
        int ow_b {0}, ow_e {0};
        get_ow_range(ow, kw_b, ow_b, ow_e);

        const auto do_init
                = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s && kw_b == kw_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f && kw_e == kw_f;
        const auto do_only_comp = need_compensation && kd_e == kd_f
                && kh_e == kh_f && kw_e != kw_f
                && btc.icc == (_pd->ic_chunks - 1);
        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (kh_e - kh_b) * (kw_e - kw_b);
        iiw_b = ow_b * SW - LP;
        ptr_D = dst_base
                + dst_dsz
                        * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                                + ow_b * jcp.oc_without_padding);
        ptr_C = (jcp.use_buffer)
                ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                : static_cast<char *>(ptr_D);

        const auto ow_l = ow_e - ow_b;
        assert(0 <= ow_l && ow_l <= jcp.ow_block);

        const auto comp_ker_offs = get_comp_offset(
                btc.g, btc.ocb, ow_b, kd_s, kd_f, kh_s, kh_f, kw_b, kw_e);

        const auto ker_i = ow_l - 1;
        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        if (ow_l > 0 && k_l > 0) {
            if (nb_ic_b > 0) {
                const auto brg_idx = kernel_idx[do_init][false];
                call_brgemm(brg_idx, 0, nb_ic_b,
                        jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_ker_offs]
                                           : nullptr,
                        jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_ker_offs]
                                        : nullptr,
                        do_postwork && !is_ic_tail, do_only_comp);
            }

            if (is_ic_tail) {
                const auto use_init_ker = (do_init && nb_ic_b == 0);
                const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
                call_brgemm(brg_ic_tail_idx, nb_ic_b, 1,
                        jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_ker_offs]
                                           : nullptr,
                        jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_ker_offs]
                                        : nullptr,
                        do_postwork, do_only_comp);
            }
        }

        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow_b, ow_e, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    };

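    // The kw range is processed in up to three regions: kw values whose input
    // falls into the left padding, kw values covering a full ow_block, and kw
    // values falling into the right padding. Padded kw values are handled one
    // at a time (their valid ow range differs per kw) with the *_BLOCK_PAD
    // blocking, while the middle region is blocked by KW_BLOCK.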
    if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
        // kw values with left padding
        if (kw_s < kw_full_s) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK_PAD) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK_PAD);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK_PAD) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK_PAD);
                    for (auto kw = kw_s; kw < kw_full_s; kw++) {
                        kw_b = kw;
                        kw_e = kw + 1;
                        kdhw_loop();
                    }
                }
            }
        }

        // kw values covering full ow_block
        if (kw_full_s < kw_full_f) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                    for (kw_b = kw_full_s; kw_b < kw_full_f; kw_b += KW_BLOCK) {
                        kw_e = nstl::min(kw_full_f, kw_b + KW_BLOCK);
                        kdhw_loop();
                    }
                }
            }
        }

        // kw values with right padding
        if (kw_full_f < kw_f) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK_PAD) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK_PAD);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK_PAD) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK_PAD);
                    for (int kw = kw_full_f; kw < kw_f; kw++) {
                        kw_b = kw;
                        kw_e = kw + 1;
                        kdhw_loop();
                    }
                }
            }
        }
    } else {
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}

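// ker_trans computes one output block for exec_trans: the source was copied
// beforehand into a thread-local inp_buffer (layout Cdhw<ic_block>c), so the
// brgemm batch walks that buffer instead of the user src. Note that FP/TP/LP
// are added to the buffer indices below, i.e. the spatial padding is expected
// to be materialized in the buffer by the copy (trans) kernel.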
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
        brgemm_thread_ctx_t &btc, char *inp_buffer) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    BRGEMM_CONV_KER_HEADER;
    MAYBE_UNUSED(g_ic);
    MAYBE_UNUSED(src);

    const auto wei_base
            = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);
    const int ow_b {ow},
            ow_e {ow + (is_ow_tail ? jcp.ow % jcp.ow_block : jcp.ow_block)};
    const int oh_b {oh},
            oh_e {oh + (is_oh_tail ? jcp.oh % jcp.oh_block : jcp.oh_block)};
    iiw_b = ow_b * SW - LP;
    ptr_D = dst_base
            + dst_dsz
                    * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                            + ow_b * jcp.oc_without_padding);
    ptr_C = (jcp.use_buffer) ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                             : static_cast<char *>(ptr_D);

    const auto ow_l = ow_e - ow_b;
    const auto oh_l = oh_e - oh_b;
    assert(0 <= ow_l && ow_l <= jcp.ow_block && 0 <= oh_l
            && oh_l <= jcp.oh_block);

    const auto ker_i = (jcp.is_os_blocking ? oh_l * ow_l : ow_l) - 1;

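    // call_brgemm fills btc.brg_batch with A pointers into inp_buffer and B
    // pointers into the weights for every (kd, kh, kw) tap of the block. When
    // jcp.kh_sets / jcp.kw_sets > 1 the respective loop collapses to a single
    // step (see kh_ee and kw_e), presumably because the copy kernel already
    // interleaves those taps in the buffer (note the jcp.kh_sets * jcp.kw_sets
    // factor in the A stride).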
    const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     bool do_postops) {
        if (k_l <= 0) return;

        const auto kh_ee = jcp.kh_sets > 1 ? kh_b + 1 : kh_e;
        const auto kw_e = jcp.kw_sets > 1 ? 1 : KW;
        const auto pbuf_base = inp_buffer
                + src_dsz
                        * ((jcp.copy_block_only
                                        ? 0
                                        : ((icb + ic_block_s) * pbuf_d_sz)));
        const auto iid_shift = jcp.copy_block_only
                ? nstl::max(0, btc.odb * jcp.od_block * SD - FP)
                : 0;
        const auto iih_shift = jcp.copy_block_only
                ? nstl::max(0, btc.ohb * jcp.oh_block * SH - TP)
                : 0;
        const auto iiw_shift
                = jcp.copy_block_only ? (btc.owb * jcp.ow_block * SW) : 0;

        for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
            const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
            const auto wei_ic = ic + ic_off;
            const auto n_icb_off = i_icb * k_l;
            const auto pbuf_base_ic = pbuf_base
                    + src_dsz
                            * ((jcp.copy_block_only ? 0 : (i_icb * pbuf_d_sz)));
            const auto wei_base_ic = wei_base + wei_dsz * wei_ic * jcp.oc_block;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                const auto id = iid - iid_shift + kd * DD + FP;
                const auto pbuf_base_kd
                        = pbuf_base_ic + src_dsz * id * pbuf_h_sz;
                const auto wei_base_kd = wei_base_ic
                        + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
                for (int kh = kh_b; kh < kh_ee; kh++) {
                    const auto ih = jcp.kh_sets > 1
                            ? (iih + 2 * TP)
                            : (iih - iih_shift + kh * DH + TP);
                    const auto pbuf_base_kh
                            = pbuf_base_kd + src_dsz * ih * pbuf_w_sz;
                    const auto wei_base_kh = wei_base_kd
                            + wei_dsz
                                    * ((jcp.kh_sets > 1 ? 0
                                                        : maybe_invert(kh, KH))
                                            * wei_kw_sz);
                    for (int kw = 0; kw < kw_e; kw++) {
                        const auto iw = iiw_b - iiw_shift + kw * DW + LP;
                        // inp_buffer layout is Cdhw<ic_block>c
                        btc.brg_batch[n_icb_off + k].ptr.A = pbuf_base_kh
                                + src_dsz * iw * jcp.ic_block * jcp.kh_sets
                                        * jcp.kw_sets;
                        btc.brg_batch[n_icb_off + k].vvpad.top = 0;
                        btc.brg_batch[n_icb_off + k].vvpad.bottom = 0;
                        // general wei layout is gOdhwI<block_o><block_i>
                        btc.brg_batch[n_icb_off + k].ptr.B = wei_base_kh
                                + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;
                        k++;
                    }
                }
            }
        }

        call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
                bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, false);
    };

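    // Unlike ker_base, kdhw_loop here never splits the kw range: the whole KW
    // extent is always part of the batch (W padding is already accounted for
    // in inp_buffer), so only kd and kh are blocked.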
    const auto kdhw_loop = [&]() {
        const auto do_init = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f;
        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (jcp.kh_sets > 1 ? 1 : (kh_e - kh_b))
                * (jcp.kw_sets > 1 ? 1 : KW);

        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        if (nb_ic_b > 0) {
            const auto brg_idx = kernel_idx[do_init][false];
            call_brgemm(brg_idx, 0, nb_ic_b, do_postwork && !is_ic_tail);
        }

        if (is_ic_tail) {
            const auto use_init_ker = (do_init && nb_ic_b == 0);
            const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
            call_brgemm(brg_ic_tail_idx, nb_ic_b, 1, do_postwork);
        }
    };

    if (kd_f > kd_s && kh_f > kh_s) {
        // kw values covering full ow_block
        for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
            kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
            for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                kdhw_loop();
            }
        }
    } else {
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}

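// ker_vpad computes one output block for the virtual-padding execution mode:
// the batch reads the user src directly, and each batch element carries per-kw
// top/bottom virtual-padding counts (precomputed per owb in owb_kw_top_vpads /
// owb_kw_bottom_vpads) that tell the brgemm kernel how many leading/trailing
// rows of A are out of bounds and must be treated as zeros.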
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
        brgemm_thread_ctx_t &btc) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    BRGEMM_CONV_KER_HEADER;
    MAYBE_UNUSED(is_oh_tail);

    const char *const __restrict src_base
            = src + src_dsz * (btc.n * src_d_sz + g_ic);

    const char *const __restrict wei_base
            = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);

    const int ow_b {ow}, ow_e {ow + (is_ow_tail ? jcp.M_tail : jcp.M)};
    iiw_b = ow_b * SW - LP;
    ptr_D = dst_base
            + dst_dsz
                    * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                            + ow_b * jcp.oc_without_padding);
    ptr_C = (jcp.use_buffer) ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                             : static_cast<char *>(ptr_D);

    const auto ow_l = ow_e - ow_b;
    assert(0 <= ow_l && ow_l <= jcp.ow_block);
    const auto ker_i = ow_l - 1;
    const dim_t *const __restrict kw_top_vpads
            = owb_kw_top_vpads.data() + btc.owb * KW;
    const dim_t *const __restrict kw_bottom_vpads
            = owb_kw_bottom_vpads.data() + btc.owb * KW;

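    // call_brgemm fills the batch with direct pointers into src and weights
    // (no intermediate buffer) and, when the shape has any virtual padding
    // (jcp.max_vpad), copies the per-kw top/bottom vpad counts into the batch
    // elements.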
    const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     int32_t *src_zp, int32_t *s8s8_comp,
                                     bool do_postops) {
        for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
            const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
            const auto src_ic = ic_off;
            const auto wei_ic = ic + ic_off;
            const auto n_icb_off = i_icb * k_l;
            const char *const __restrict src_base_ic
                    = src_base + src_dsz * src_ic;
            const char *const __restrict wei_base_ic
                    = wei_base + wei_dsz * wei_ic * jcp.oc_block;
            brgemm_batch_element_t *const __restrict icb_batch
                    = btc.brg_batch + n_icb_off;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                const auto id = iid + kd * DD;
                const char *const __restrict src_base_kd
                        = src_base_ic + src_dsz * id * src_h_sz;
                const char *const __restrict wei_base_kd = wei_base_ic
                        + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
                for (int kh = kh_b; kh < kh_e; kh++) {
                    const auto ih = iih + kh * DH;
                    const char *const __restrict src_base_kh
                            = src_base_kd + src_dsz * ih * src_w_sz;
                    const char *const __restrict wei_base_kh = wei_base_kd
                            + wei_dsz * maybe_invert(kh, KH) * wei_kw_sz;
                    for (int kw = 0; kw < KW; kw++) {
                        const auto iw = iiw_b + kw * DW;
                        const auto ptr_A = src_base_kh
                                + static_cast<ptrdiff_t>(src_dsz) * iw
                                        * jcp.ngroups * jcp.ic_without_padding;
                        if (jcp.max_vpad) {
                            icb_batch[k].vvpad.top = kw_top_vpads[kw];
                            icb_batch[k].vvpad.bottom = kw_bottom_vpads[kw];
                        }
                        // general wei layout is gOdhwI<block_o><block_i>
                        const auto ptr_B = wei_base_kh
                                + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;

                        icb_batch[k].ptr.A = ptr_A;
                        icb_batch[k].ptr.B = ptr_B;

                        k++;
                    }
                }
            }
        }

        call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
                bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
                btc.src_zp_vals, src_zp, btc.dst_zp_vals, s8s8_comp, false);
    };

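    // kdhw_loop blocks only kd/kh (the batch always spans the full KW) and
    // passes the zero-point / s8s8 compensation buffers at the offset computed
    // for the whole [0, KW) range of the current kd/kh block.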
    const auto kdhw_loop = [&]() {
        const auto do_init = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f;

        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (kh_e - kh_b) * KW;
        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        const auto comp_offs = get_comp_offset(
                btc.g, btc.ocb, ow, kd_b, kd_e, kh_b, kh_e, 0, KW);

        if (nb_ic_b > 0) {
            const auto brg_idx = kernel_idx[do_init][false];
            call_brgemm(brg_idx, 0, nb_ic_b,
                    jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_offs]
                                       : nullptr,
                    jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_offs] : nullptr,
                    do_postwork && !is_ic_tail);
        }

        if (is_ic_tail) {
            const auto use_init_ker = (do_init && nb_ic_b == 0);
            const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
            call_brgemm(brg_ic_tail_idx, nb_ic_b, 1,
                    jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_offs]
                                       : nullptr,
                    jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_offs] : nullptr,
                    do_postwork);
        }
    };

    if (kd_f > kd_s && kh_f > kh_s) {
        // kw values covering full ow_block
        for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
            kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
            for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                kdhw_loop();
            }
        }
    } else {
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}

#undef BRGEMM_CONV_KER_HEADER

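// Explicit instantiations for the supported ISAs; the <isa, true> variants
// enable use_inversion, which spatially inverts the weights via maybe_invert()
// (as used for deconvolution).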
template struct brgemm_convolution_fwd_t<avx2>;
template struct brgemm_convolution_fwd_t<avx2_vnni_2>;
template struct brgemm_convolution_fwd_t<avx2_vnni_2, true>;
template struct brgemm_convolution_fwd_t<avx512_core>;
template struct brgemm_convolution_fwd_t<avx512_core, true>;
template struct brgemm_convolution_fwd_t<avx512_core_vnni>;
template struct brgemm_convolution_fwd_t<avx512_core_bf16>;
template struct brgemm_convolution_fwd_t<avx512_core_bf16, true>;
template struct brgemm_convolution_fwd_t<avx512_core_fp16>;
template struct brgemm_convolution_fwd_t<avx512_core_fp16, true>;
template struct brgemm_convolution_fwd_t<avx512_core_amx>;
template struct brgemm_convolution_fwd_t<avx512_core_amx, true>;
template struct brgemm_convolution_fwd_t<avx512_core_amx_fp16>;
template struct brgemm_convolution_fwd_t<avx512_core_amx_fp16, true>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s