/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/jit_brgemm_1x1_conv.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

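// Selects the value matching the convolution rank: v5 for 5D (ncdhw),
// v4 for 4D (nchw), v3 for 3D (ncw) problems.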
#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)

template <cpu_isa_t isa>
status_t brgemm_1x1_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using namespace utils;

    const auto src_type = src_md(0)->data_type;
    const auto wei_type = weights_md(0)->data_type;
    const auto dst_type = dst_md(0)->data_type;
    const bool is_int8 = one_of(src_type, u8, s8);

    using skip_mask_t = primitive_attr_t::skip_mask_t;
    auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
            | skip_mask_t::zero_points_runtime;
    if (one_of(src_type, u8, s8)) skip_mask |= skip_mask_t::scales_runtime;

    bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
            && expect_data_types(src_type, wei_type, data_type::undef,
                    dst_type, data_type::undef)
            && IMPLICATION(is_int8,
                    one_of(bias_md_.data_type, data_type::undef, f32, s32, s8,
                            u8))
            && IMPLICATION(!is_int8,
                    one_of(bias_md_.data_type, data_type::undef, f32,
                            src_type))
            && attr()->has_default_values(skip_mask, dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && !has_zero_dim_memory() && zero_points_ok() && arg_scales_ok();
    if (!ok) return status::unimplemented;

    CHECK(brgemm_convolution_utils::init_1x1_conf(jcp_, isa, *desc(), src_md_,
            weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()));

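    // brgs_ holds one brgemm descriptor per combination of four binary flags
    // (accumulator init vs. accumulate, and M/N/K full vs. tail blocks),
    // hence 2^4 = 16 slots addressed via get_brg_idx().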
    for (int i = 0; i < 16; i++)
        brgs_[i].bcast_dim = brgs_[i].load_dim = brgs_[i].reduce_dim = 0;

    const float alpha = 1.0;
    const float beta = 1.0;
    const auto &p = attr()->post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    with_sum = (sum_idx != -1);
    sum_scale = with_sum ? p.entry_[sum_idx].sum.scale : 0.0;

    ic_chunks = div_up(jcp_.nb_ic, jcp_.nb_ic_blocking);
    need_postwork = jcp_.with_bias || jcp_.with_eltwise || jcp_.with_binary
            || (one_of(src_type, u8, s8) && wei_type == s8) // oscales needed
            || (jcp_.dst_dt != jcp_.acc_dt) || jcp_.with_sum;

    int i_init_begin = (ic_chunks == 1) ? 1 : 0;
    int i_init_end = 2;

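    // Generate brgemm descriptors for every combination of full/tail M, N and
    // K blocks; when the reduction spans multiple ic chunks, an accumulating
    // variant (beta = 1) is generated in addition to the initializing one.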
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = i_init_begin; i_init < i_init_end; i_init++) {
        auto vbeta = (i_init) ? 0 : beta;
        auto vM = (i_M) ? jcp_.M_tail : jcp_.M;
        auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
        auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
        brgemm_t &brg = brgs_[get_brg_idx(i_init, i_M, i_N, i_K)];
        if (vM == 0 || vN == 0 || vK == 0) continue;
        brgemm_strides_t brg_strides;
        brg_strides.stride_a = jcp_.brg_stride_a;
        brg_strides.stride_b = jcp_.brg_stride_b;
        const auto strides_ptr
                = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr;
        CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, src_type, wei_type,
                false, false, brgemm_row_major, alpha, vbeta, jcp_.LDA,
                jcp_.LDB, jcp_.LDC, vM, vN, vK, strides_ptr));

        brgemm_attr_t brgattr;
        brgattr.max_bs = jcp_.gemm_batch_size;
        brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
                ? brgemm_bd_loop_innermost
                : brgemm_ld_loop_innermost;
        brgattr.max_top_vpad = jcp_.max_vpad;
        brgattr.max_bottom_vpad = jcp_.max_vpad;

        // assuming 2x2 decomposition in amx brgemm kernel
        const auto bd_blocking = 2 * jcp_.amx_h;
        brgattr.hint_expected_A_size = bd_blocking * vK;
        brgattr.hint_expected_B_size = vN * vK;
        brgattr.hint_expected_C_size = bd_blocking * vN;

        brgattr.wary_tail_read = false;
        brgattr.use_uker = jcp_.use_uker;
        brgattr.use_interleave_stores = brgattr.use_uker;
        brgattr.hint_prefetching = jcp_.hint_prefetching;
        brgattr.fpmath_mode = attr()->fpmath_mode_;
        // If post-ops are required and there is no intermediate accumulation
        // across ic chunks (ic_chunks == 1), the brgemm kernel does not need
        // a code path without post-ops.
        if (need_postwork && ic_chunks == 1) brgattr.postops_only = true;

        CHECK(brgemm_desc_set_attr(&brg, brgattr));
        auto LDD = jcp_.oc_without_padding;
        brg.with_sum = with_sum;
        CHECK(brgemm_desc_set_postops(
                &brg, attr(), &dst_md_, LDD, jcp_.bia_dt));
        jcp_.amx_buf_size_per_thread = nstl::max(
                brg.get_wsp_buffer_size(), jcp_.amx_buf_size_per_thread);
    }

    brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_);
    auto scratchpad = scratchpad_registry().registrar();
    brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
    if (jcp_.with_scales)
        book_precomputed_scales(scratchpad, attr()->scales_, OC());

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_1x1_convolution_fwd_t<isa>::init(engine_t *engine) {
    auto ndims = pd()->ndims();
    if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");

    const auto &jcp = pd()->jcp_;

    ID = ndims_pick(jcp.id, 1, 1);
    IH = ndims_pick(jcp.ih, jcp.ih, 1);
    IW = jcp.iw;

    OD = ndims_pick(jcp.od, 1, 1);
    OH = ndims_pick(jcp.oh, jcp.oh, 1);
    OW = jcp.ow;

    SD = ndims_pick(jcp.stride_d, 1, 1);
    SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
    SW = jcp.stride_w;

    bia_dsz = jcp.bia_dsz;
    acc_dsz = jcp.acc_dsz;
    src_dsz = jcp.src_dsz;
    wei_dsz = jcp.wei_dsz;

    // const variables used for address calculations
    src_w_sz = (dim_t)IW * jcp.ngroups * jcp.ic_without_padding;
    src_h_sz = IH * src_w_sz;
    src_d_sz = ID * src_h_sz;
    dst_w_sz = (dim_t)OW * jcp.oc_without_padding;
    dst_h_sz = OH * dst_w_sz;
    dst_d_sz = OD * dst_h_sz;

    const auto src_type = pd()->src_md(0)->data_type;

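    // Innermost ic granularity of the weights layout: vnni formats pack
    // several input channels per element (e.g. 2 for bf16, 4 for int8),
    // while f16 is handled here with granularity 1.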
    const auto last_ic_block
            = src_type == f16 ? 1 : data_type_vnni_granularity(src_type);

    wei_oc_sz = jcp.wei_plain ? jcp.oc : jcp.oc_block;
    wei_ic_sz = jcp.wei_plain
            ? (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc
            : (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc_block;
    wei_ocb_sz = jcp.wei_plain ? jcp.oc_block * last_ic_block
                               : jcp.nb_oc * wei_ic_sz;

    for (int i = 0; i < 16; i++)
        brg_kernels_[i] = nullptr;

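    // rtus ("reduce to unit stride"): for strided 1x1 convolutions a copy
    // kernel gathers the strided source points into a dense temporary buffer
    // so the brgemm kernel can consume unit-stride data.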
    if (jcp.is_rtus) {
        CHECK(safe_ptr_assign(rtus_kernel_,
                new jit_avx512_core_brgemm_conv_trans_kernel::
                        jit_avx512_core_brgemm_conv_rtus_kernel_t(jcp)));
        CHECK(rtus_kernel_->create_kernel());
    }
    int i_init_begin = (pd()->ic_chunks == 1) ? 1 : 0;
    int i_init_end = 2;

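    // Generate kernels only for the descriptors that were actually
    // initialized (non-zero dims). On AMX, tile configuration palettes are
    // deduplicated so kernels that share a palette can skip redundant tile
    // reconfigurations at execution time.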
    const bool is_amx = brgemm_convolution_utils::is_amx(isa);
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = i_init_begin; i_init < i_init_end; i_init++) {
        auto brg_idx = get_brg_idx(i_init, i_M, i_N, i_K);
        auto &brg = pd()->brgs_[brg_idx];
        if (brg.bcast_dim > 0 && brg.load_dim > 0 && brg.reduce_dim > 0
                && !brg_kernels_[brg_idx]) {
            brgemm_kernel_t *brg_kernel = nullptr;
            CHECK(brgemm_kernel_create(&brg_kernel, brg));
            CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
            if (is_amx) {
                amx_palette_t tmp;
                int &palette_idx = brg_kernel_palette_idx_[brg_idx];
                palette_idx = -1;
                CHECK(brgemm_init_tiles(brg, tmp.p));
                // check if it's in set of tile configs
                for (size_t i = 0; i < brg_kernel_palette_.size(); i++) {
                    const bool is_match = 0
                            == std::memcmp(brg_kernel_palette_[i].p, tmp.p,
                                    AMX_PALETTE_SIZE);
                    if (is_match) {
                        palette_idx = i;
                        break;
                    }
                }
                // add to set of tile configs if needed
                if (palette_idx == -1) {
                    palette_idx = brg_kernel_palette_.size();
                    brg_kernel_palette_.push_back(tmp);
                }
            }
        }
    }
    return status::success;
}

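// Copies the source points needed for the current os-block into the dense
// rtus buffer, unless the per-chunk mask indicates the block was already
// prepared by an earlier call on the same thread.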
template <cpu_isa_t isa>
void brgemm_1x1_convolution_fwd_t<isa>::maybe_rtus(int ithr,
        const char *__restrict src, char *__restrict inp_buffer,
        uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int od,
        int oh, int ow) const {
    const auto &jcp = pd()->jcp_;
    if (!jcp.is_rtus) return;
    assert(jcp.is_os_blocking);
    const size_t src_dt_size = jcp.src_dsz;

    const auto os = (od * OH + oh) * OW + ow;
    const auto osb = os / jcp.os_block;

    uint8_t *bmask = &inp_buffer_mask[icc * jcp.nb_os + osb];
    if (bmask && *bmask) return; // skip if already masked
    if (bmask) *bmask = 1; // set mask to skip next time

    const auto g_ic = g * jcp.ic_without_padding
            + icc * jcp.nb_ic_blocking * jcp.ic_block;

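    // Invokes the rtus copy kernel for a run of `nh` full output rows or `nw`
    // output points within a row, starting at (od, oh, ow), and advances the
    // destination pointer accordingly.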
    auto call_kernel = [&](int nh, int nw, int od, int oh, int ow) {
        assert(nh == 0 || (nw == 0 && ow == 0));
        if (utils::everyone_is(0, nh, nw)) return;
        const int id = od * jcp.stride_d;
        const int ih = oh * jcp.stride_h;
        const int iw = ow * jcp.stride_w;
        const auto inp_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
                + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
        auto p = jit_avx512_core_brgemm_conv_trans_kernel::
                jit_brgemm_conv_trans_kernel_call_s();
        p.h_count = nh;
        p.owb = nw;
        p.src = src + src_dt_size * inp_offset;
        p.dst = inp_buffer;
        (*rtus_kernel_)(&p);
        inp_buffer += src_dt_size * (nh * jcp.ow + nw) * jcp.LDA;
    };

    const bool is_os_tail = jcp.os - os < jcp.os_block;
    int count = is_os_tail ? jcp.M_tail : jcp.M;

    if (count < OW || ow > 0) {
        // copy to end of row
        const auto nw = nstl::min(count, OW - ow);
        call_kernel(0, nw, od, oh, ow);
        count -= nw;
        if (count == 0) return;
        ow = 0;
        oh = (oh + 1) % OH;
        if (oh == 0) od++;
    }

    while (od < OD) {
        // copy to end of column
        const auto nh = nstl::min(count / OW, OH - oh);
        call_kernel(nh, 0, od, oh, ow);
        count -= nh * OW;
        if (count == 0) return;
        oh = (oh + nh) % OH;
        if (oh == 0) od++;
        if (count < OW) {
            // copy partial row
            const auto nw = count;
            call_kernel(0, nw, od, oh, ow);
            return;
        }
    }
}

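// Runs the brgemm kernel(s) for one (mb, group, oc-block, spatial block,
// ic chunk) work item: fills the batch element pointers, selects the proper
// full/tail kernel variant and applies post-ops on the last ic chunk.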
template <cpu_isa_t isa>
void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
        const brgemm_exec_ctx_t &brgemm_ctx, int ithr,
        brgemm_batch_element_t *const __restrict brg_batch,
        char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
        int od, int oh, int ow, int icc, int *last_palette_idx,
        const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
        int32_t *dst_zp_vals, int32_t *s8s8_compensation) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    // weights_md(1) holds the bias descriptor; needed for the bias_w offset
    const memory_desc_wrapper bias_d(pd()->weights_md(1));
    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t wei_dt_size = types::data_type_size(weights_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const char *const __restrict src = brgemm_ctx.src;
    const char *const __restrict weights = brgemm_ctx.weights;
    const char *const __restrict bias = brgemm_ctx.bias;
    char *const __restrict dst = brgemm_ctx.dst;
    const std::vector<const void *> &post_ops_binary_rhs_arg_vec
            = brgemm_ctx.post_ops_binary_rhs_arg_vec;

    const auto &jcp = pd()->jcp_;
    auto ndims = pd()->ndims();

    const bool is_amx = brgemm_convolution_utils::is_amx(isa);
    char *const wsp_tile = is_amx
            ? brgemm_ctx.wsp_tile + ithr * jcp.amx_buf_size_per_thread
            : nullptr;

    const int id = ndims_pick(od * SD, 0, 0);
    const int ih = ndims_pick(oh * SH, oh * SH, 0);
    const int iw = ow * SW;

    const int oc = ocb * jcp.oc_block;
    const int g_oc = g * jcp.oc + oc;

    const int icb = icc * jcp.nb_ic_blocking;
    const int ic = icb * jcp.ic_block;
    const int g_ic = g * jcp.ic + ic;

    const bool kernel_init = (icc == 0);

    const auto os = (od * OH + oh) * OW + ow;

    const bool is_os_tail = jcp.is_os_blocking ? (jcp.os - os < jcp.os_block)
                                               : (OW - ow < jcp.ow_block);
    const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block);
    const bool is_ic_tail = (icc == pd()->ic_chunks - 1
            && ((jcp.ic - ic) % jcp.ic_block != 0));

    const auto src_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
            + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
    const auto src_base
            = jcp.is_rtus ? inp_buffer : src + src_dt_size * src_offset;
    const auto wei_offset = jcp.wei_plain ? g * wei_ic_sz + ocb * wei_ocb_sz
                                          : g * wei_ocb_sz + ocb * wei_ic_sz;
    const auto wei_base = weights + wei_dt_size * wei_offset;
    const auto ptr_D = dst
            + dst_dt_size
                    * (n * dst_d_sz + od * dst_h_sz + oh * dst_w_sz
                            + ow * jcp.oc_without_padding + g_oc);
    char *const ptr_C = (jcp.use_buffer) ? c_buffer : (char *)ptr_D;

    const auto bias_w
            = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr;
    const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb)
            - (is_ic_tail ? 1 : 0);

    const auto comp_offset = (g * jcp.nb_oc + ocb) * jcp.oc_block;
    int32_t *src_zp_comp_ptr
            = (jcp.src_zero_point && icc == pd()->ic_chunks - 1)
            ? &src_zp_comp[comp_offset]
            : nullptr;
    int32_t *s8s8_comp_ptr = (jcp.s8s8_avx512 && icc == pd()->ic_chunks - 1)
            ? &s8s8_compensation[comp_offset]
            : nullptr;

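    // Fills the batch of (A, B) pointer pairs for `n_ic_blocks` ic blocks
    // starting at `ic_block_s`, reconfigures AMX tiles only when the palette
    // differs from the last one used on this thread, and dispatches the
    // brgemm kernel with or without fused post-ops.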
    const auto call_brgemm = [=](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     bool do_postops) {
        for (int k = 0; k < n_ic_blocks; k++) {
            const auto ic_off = (ic_block_s + k) * jcp.ic_block;
            const auto src_ic = ic_off;
            const auto wei_ic = ic + ic_off;
            const auto ptr_A = src_base + src_dt_size * src_ic;
            const auto ptr_B = wei_base + wei_dt_size * wei_ic * wei_oc_sz;
            brg_batch[k].ptr.A = ptr_A;
            brg_batch[k].ptr.B = ptr_B;
            brg_batch[k].vvpad.top = 0;
            brg_batch[k].vvpad.bottom = 0;
        }

        // NOTE: avoid some costly tile reconfigurations here by keeping track
        // of the previous brg kernel tile configuration palette
        // TODO: adjust harness to avoid even more tile reconfigurations
        if (is_amx) {
            const int curr_palette_idx = brg_kernel_palette_idx_[brg_idx];
            if (curr_palette_idx != *last_palette_idx) {
                amx_tile_configure(brg_kernel_palette_[curr_palette_idx].p);
                *last_palette_idx = curr_palette_idx;
            }
        }

        const brgemm_kernel_t *brg_ker = brg_kernels_[brg_idx].get();
        if (do_postops) {
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(bias_w),
                    &oscales[jcp.is_oc_scale * g_oc],
                    post_ops_binary_rhs_arg_vec.data(),
                    static_cast<size_t>(g_oc), 0, dst, 0,
                    static_cast<void *>(src_zp_comp_ptr), nullptr,
                    static_cast<void *>(dst_zp_vals), false, src_zp_vals};

            void *scratch = is_amx ? static_cast<void *>(wsp_tile)
                                   : static_cast<void *>(s8s8_comp_ptr);
            brgemm_kernel_execute_postops(brg_ker, n_ic_blocks, brg_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            void *scratch = is_amx ? static_cast<void *>(wsp_tile)
                                   : static_cast<void *>(s8s8_comp_ptr);
            brgemm_kernel_execute(
                    brg_ker, n_ic_blocks, brg_batch, (void *)ptr_C, scratch);
        }
    };

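    // Post-processing (bias, scales, sum/eltwise/binary, zero points) is
    // applied only while handling the last ic chunk, once the accumulation
    // over all input channels is complete.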
    const auto do_post_work = (pd()->need_postwork || jcp.use_buffer)
            && icc == pd()->ic_chunks - 1;

    if (nb_ic_b > 0) {
        const auto brg_idx
                = get_brg_idx(kernel_init, is_os_tail, is_oc_tail, false);
        call_brgemm(brg_idx, 0, nb_ic_b, do_post_work && !is_ic_tail);
    }
    if (is_ic_tail) {
        const auto use_init_ker = (kernel_init && nb_ic_b == 0);
        const auto brg_idx
                = get_brg_idx(use_init_ker, is_os_tail, is_oc_tail, true);

        call_brgemm(brg_idx, nb_ic_b, 1, do_post_work);
    }
}

template <cpu_isa_t isa>
status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
        const exec_ctx_t &ctx) const {

    brgemm_exec_ctx_t brgemm_ctx(ctx, pd());

    const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = pd()->jcp_;
    const bool is_amx = brgemm_convolution_utils::is_amx(isa);
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);

    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);

    const auto extra_data_offset
            = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<char *>(brgemm_ctx.weights);
    int32_t *s8s8_compensation = (jcp.s8s8_avx512)
            ? reinterpret_cast<int32_t *>(w + extra_data_offset)
            : nullptr;
    int32_t *zp_compensation = (jcp.src_zero_point)
            ? reinterpret_cast<int32_t *>(&w[extra_data_offset])
                    + (jcp.s8s8_avx512 ? jcp.s8s8_comp_buffer_size : 0)
            : nullptr;
    int32_t *dst_zp_vals = jcp.dst_zero_point ? &dst_zero_point : nullptr;

    brgemm_batch_element_t *const brg_batch_global
            = (jcp.brg_type != brgemm_strd)
            ? scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch)
            : nullptr;
    char *const c_buffer_global = (jcp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;
    char *inp_buffer_base = (jcp.is_rtus)
            ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer)
            : nullptr;
    uint8_t *inp_buffer_mask_base = (jcp.is_rtus)
            ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask)
            : nullptr;

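    // Two parallel drivers below: one over (mb, os-chunks, groups, oc-blocks)
    // when the spatial range is processed as flattened os blocks, and one
    // over (mb, groups, oc-blocks, od, oh, ow-blocks) otherwise; the exact
    // nesting follows jcp.loop_order.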
    if (jcp.is_os_blocking) {
        const int os_chunks = div_up(jcp.nb_os, jcp.nb_os_blocking);
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_oc * os_chunks;

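// Expands to a parallel loop whose nd_iterator dimensions (and thus the loop
// nesting order) are given by the macro arguments; the body is identical for
// both supported loop orders.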
#define BRGC_WO(...) \
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
        if (ithr >= work_amount) return; \
        brgemm_batch_element_t *const brg_batch \
                = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
        char *const c_buffer = (jcp.use_buffer) \
                ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
                : nullptr; \
        char *inp_buffer = (jcp.is_rtus) \
                ? inp_buffer_base + ithr * src_dsz * jcp.inp_buffer_size \
                : nullptr; \
        uint8_t *__restrict inp_buffer_mask = (jcp.is_rtus) \
                ? inp_buffer_mask_base + ithr * jcp.inp_buffer_mask_size \
                : nullptr; \
        int last_n = -1; \
        int last_g = -1; \
        int last_palette_idx = -1; \
        int start {0}, end {0}; \
        balance211(work_amount, nthr, ithr, start, end); \
        int n {0}, g {0}, ocb {0}, oss {0}; \
        nd_iterator_init(start, __VA_ARGS__); \
        for (auto work = start; work < end; work++) { \
            if (jcp.is_rtus && (last_n != n || last_g != g)) \
                std::memset(inp_buffer_mask, 0, jcp.inp_buffer_mask_size); \
            const auto osb_start = oss * jcp.nb_os_blocking; \
            const auto osb_range \
                    = nstl::min(jcp.nb_os - osb_start, jcp.nb_os_blocking); \
            for (int osb = 0; osb < osb_range; osb++) { \
                const int os = (osb_start + osb) * jcp.os_block; \
                const int od = os / (OH * OW); \
                const int oh = (os % (OH * OW)) / OW; \
                const int ow = os % OW; \
                char *inp_buffer_sp = (jcp.is_rtus) \
                        ? inp_buffer + src_dsz * os * jcp.LDA \
                        : nullptr; \
                for (int icc = 0; icc < pd()->ic_chunks; icc++) { \
                    if (jcp.is_rtus) \
                        maybe_rtus(ithr, brgemm_ctx.src, inp_buffer_sp, \
                                inp_buffer_mask, g, n, icc, od, oh, ow); \
                    exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \
                            inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \
                            &last_palette_idx, oscales, src_zero_point, \
                            zp_compensation, dst_zp_vals, s8s8_compensation); \
                } \
            } \
            last_n = n; \
            last_g = g; \
            nd_iterator_step(__VA_ARGS__); \
        } \
        if (is_amx) amx_tile_release(); \
    });

        if (jcp.loop_order == loop_ndhwgc)
            BRGC_WO(n, jcp.mb, oss, os_chunks, g, jcp.ngroups, ocb, jcp.nb_oc)
        else if (jcp.loop_order == loop_ngcdhw)
            BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, oss, os_chunks)
        else
            assert(!"Unknown loop order");

#undef BRGC_WO

    } else {
        const int work_amount
                = jcp.mb * jcp.ngroups * jcp.nb_oc * OD * OH * jcp.nb_ow;

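// Same driver for the non-os-blocking case: the spatial range is iterated
// explicitly as (od, oh, ow-block) instead of flattened os blocks.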
#define BRGC_WO(...) \
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
        if (ithr >= work_amount) return; \
        brgemm_batch_element_t *const brg_batch \
                = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
        char *const c_buffer = (jcp.use_buffer) \
                ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
                : nullptr; \
        int last_palette_idx = -1; \
        int start {0}, end {0}; \
        balance211(work_amount, nthr, ithr, start, end); \
        int n {0}, g {0}, ocb {0}, od {0}, oh {0}, owb {0}; \
        nd_iterator_init(start, __VA_ARGS__); \
        for (auto work = start; work < end; work++) { \
            for (int icc = 0; icc < pd()->ic_chunks; icc++) { \
                const int ow = owb * jcp.ow_block; \
                exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \
                        ocb, od, oh, ow, icc, &last_palette_idx, oscales, \
                        src_zero_point, zp_compensation, dst_zp_vals, \
                        s8s8_compensation); \
            } \
            nd_iterator_step(__VA_ARGS__); \
        } \
        if (is_amx) amx_tile_release(); \
    });

        if (jcp.loop_order == loop_ndhwgc)
            BRGC_WO(n, jcp.mb, od, OD, oh, OH, owb, jcp.nb_ow, g, jcp.ngroups,
                    ocb, jcp.nb_oc)
        else if (jcp.loop_order == loop_ngcdhw)
            BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, od, OD, oh, OH,
                    owb, jcp.nb_ow)
        else
            assert(!"Unknown loop order");

#undef BRGC_WO
    }

    return status::success;
}

template struct brgemm_1x1_convolution_fwd_t<avx2>;
template struct brgemm_1x1_convolution_fwd_t<avx2_vnni_2>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_vnni>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_fp16>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_amx>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s