/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/tag_traits.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/matmul/brgemm_matmul.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;

using namespace data_type;

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
    const auto src_dt = src_md_.data_type;
    const auto wei_dt = weights_md_.data_type;
    const auto dst_dt = dst_md_.data_type;

    const bool is_f32 = everyone_is(f32, src_dt, wei_dt, dst_dt);
    const bool is_int8 = one_of(src_dt, u8, s8) && wei_dt == s8
            && one_of(dst_dt, u8, s8, s32, f32, bf16);
    const bool is_bf16
            = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);
    const bool is_f16
            = everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);

    auto check_bias = [&]() -> bool {
        const auto bia_dt = weights_md(1)->data_type;
        // The condition passed to IMPLICATION must be an expression to work
        // around an ICE in GCC 7.4.
        const bool is_bia_dt_correct
                = IMPLICATION(is_int8 == true,
                          one_of(bia_dt, f32, s32, s8, u8, bf16))
                && IMPLICATION(!is_int8, one_of(bia_dt, f32, src_dt));
        return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN());
    };

    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == 1 << (dst_md()->ndims - 1));
            else
                ok = ok && (mask == 0);
        }
        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires scratchpad
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };

    auto check_attr_zero_points
            = [&]() -> bool { return attr()->zero_points_.common(); };

    const bool problem_dt_correct = is_int8 || is_bf16 || is_f32 || is_f16;
    bool ok = mayiuse(isa) && problem_dt_correct && !has_zero_dim_memory()
            && !has_runtime_dims_or_strides()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::scales_runtime
                            | primitive_attr_t::skip_mask_t::zero_points_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_dt)
            && attr()->post_ops_.check_sum_consistent_dt(dst_dt)
            && check_attr_scales() && check_attr_zero_points() && check_bias();
    if (!ok) return status::unimplemented;

    CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_,
            dst_md_, bias_md_, attr_));

    const float alpha = 1.0;
    const float beta = 1.0;
    const float beta_init = 0.0;

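    // Pre-initialize a BRGEMM descriptor for every combination of the
    // batch-size, beta-initialization and M/N/K tail flags; the matching
    // kernels are generated later in brgemm_matmul_t<isa>::init() and are
    // selected at execution time via get_brg_kernel_idx().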
    for_(int i_bs = 0; i_bs < 2; i_bs++)
    for_(int i_init = 0; i_init < 2; i_init++)
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for (int i_K = 0; i_K < 2; i_K++) {
        auto vbeta = (i_init) ? beta_init : beta;
        auto vM = (i_M) ? bgmmc_.M_tail : bgmmc_.M_blk;
        auto vN = (i_N) ? bgmmc_.N_tail : bgmmc_.N_blk;
        auto vK = (i_K) ? bgmmc_.K_tail : bgmmc_.K_blk;

        int bs = get_brg_batchsize(bgmmc_, i_bs, i_K);
        int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K);
        if (idx < 0) continue;
        brgemm_t &brg = brg_descs_[idx];
        auto LDA = i_K && bgmmc_.use_buffer_a_tail_only
                ? (dim_t)bgmmc_.wei_k_blk
                : bgmmc_.LDA;
        CHECK(brgemm_desc_init(&brg, isa, bgmmc_.brg_type, bgmmc_.src_dt,
                bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
                LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));

        auto LDD = bgmmc_.LDD;
        CHECK(brgemm_desc_set_postops(
                &brg, attr(), &dst_md_, LDD, bgmmc_.bia_dt));

        brgemm_attr_t brgattr;
        brgattr.generate_skip_accumulation
                = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1;
        const bool is_amx = is_superset(isa, avx512_core_amx);
        if (is_amx) {
            if (!brgattr.generate_skip_accumulation) {
                // TODO: uker doesn't yet support generate_skip_accumulation
                brgattr.use_uker = true;
                brgattr.use_interleave_stores = true;
            }
            brgattr.max_bs = bs;
            brgattr.wary_tail_read = false;

            // TODO: change expected sizes to local chunks wrt L2 blocking
            brgattr.hint_expected_A_size = vM * vK * bs;
            brgattr.hint_expected_B_size = vN * vK * bs;
            brgattr.hint_expected_C_size = vM * vN * bs;
            brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
            brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
        }

        CHECK(brgemm_desc_set_attr(&brg, brgattr));
        bgmmc_.wsp_tile_per_thr_bytes = nstl::max(
                brg.get_wsp_buffer_size(), bgmmc_.wsp_tile_per_thr_bytes);
    }

    auto scratchpad = scratchpad_registry().registrar();
    init_scratchpad(scratchpad, bgmmc_);
    book_precomputed_scales(scratchpad, attr()->scales_, N());

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::init(engine_t *engine) {
    for_(int i_bs = 0; i_bs < 2; i_bs++)
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = 0; i_init < 2; i_init++) {
        int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K);
        if (idx < 0) continue;

        brgemm_kernel_t *ker = nullptr;
        CHECK(brgemm_kernel_create(&ker, pd()->get_brg_desc(idx)));
        CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
        if (is_superset(isa, avx512_core_amx))
            CHECK(brgemm_init_tiles(
                    pd()->get_brg_desc(idx), &brg_kernel_palettes_[idx][0]));
    }

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    if (bgmmc.use_buffer_b)
        CHECK(create_brgemm_matmul_copy_b(copy_B_kernel_, &bgmmc));

    if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
        CHECK(create_brgemm_matmul_copy_a(copy_A_kernel_, &bgmmc));

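    // When the K dimension is split across threads, a 1D accumulator kernel
    // (f32 or s32, depending on acc_dt) is created to reduce the per-thread
    // partial results.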
    if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == f32) {
        CHECK(safe_ptr_assign(
                acc_ker_f32_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_f32_->create_kernel());
    } else if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == s32) {
        CHECK(safe_ptr_assign(
                acc_ker_s32_, new cpu_accumulator_1d_t<data_type::s32>()));
        CHECK(acc_ker_s32_->create_kernel());
    }

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);
    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);

    auto &scratchpad = ctx.get_scratchpad_grantor();
    const float *oscales = precompute_scales(
            scratchpad, src_scales, wei_scales, pd()->N(), pd()->attr());

    brg_matmul_exec_ctx_t brgmm_ctx(
            ctx, pd(), oscales, src_zero_point, wei_zero_point, dst_zero_point);

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const bool use_buffer_a
            = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
    const bool is_amx = is_superset(isa, avx512_core_amx);
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

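    // Threads form a 2D decomposition: nthr_bmn threads share the
    // (batch, M_chunk, N_chunk) work items while nthr_k threads split the
    // K chunks for parallel reduction; each thread below processes its own
    // slice of both dimensions.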
    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, start, end);
        int kc_start {0}, kc_end {bgmmc.K_chunks};
        if (brgmm_ctx.parallel_reduction_is_used())
            balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(),
                    ithr_k, kc_start, kc_end);

        if (is_amx) {
            const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
            amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]);
        }

        int b {0}, mc {0}, nc {0};
        nd_iterator_init(
                start, b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        while (start < end) {
            auto m_start = mc * bgmmc.M_chunk_size;
            auto m_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto n_start = nc * bgmmc.N_chunk_size;
            auto n_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for_(int kc = kc_start; kc < kc_end; kc++)
            for (int nb = n_start; nb < n_end; nb++) {
                if (bgmmc.use_buffer_b)
                    copy_b_chunk_in_buffer(brgmm_ctx, ithr, b, nb, kc);
                for (int mb = m_start; mb < m_end; mb++) {
                    if (use_buffer_a && nb == n_start)
                        copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc);
                    compute_kernel(
                            brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
        if (is_amx) { amx_tile_release(); }
    });

    maybe_reduce_partial_results_and_apply_postops(brgmm_ctx);

    return status::success;
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::compute_kernel(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
    const bool is_amx = is_superset(isa, avx512_core_amx);
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr);
    const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();

    const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr);
    const int m = m_blk_idx * bgmmc.M_blk;
    const int n = n_blk_idx * bgmmc.N_blk;
    const int k_blk_idx = k_chunk_idx * bgmmc.brgemm_batch_size;

    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    const bool is_last_K_chunk = brgmm_ctx.is_last_K_chunk(k_chunk_idx);

    const int remaining_k_blks
            = (bgmmc.use_buffer_a ? utils::rnd_up(bgmmc.K, bgmmc.K_blk)
                                  : bgmmc.K)
            - k_chunk_idx * bgmmc.K_chunk_elems;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const bool is_K_tail
            = is_last_K_chunk && (gemm_batch * bgmmc.K_blk) != remaining_k_blks;
    auto is_bs_tail = (gemm_batch != bgmmc.brgemm_batch_size);
    const int brg_ker_idx = pd()->get_brg_kernel_idx(
            is_bs_tail, do_init, is_M_tail, is_N_tail, false);
    const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
    auto ptr_D = brgmm_ctx.get_data_C_ptr(b_idx, m, n);
    auto ptr_C = (bgmmc.use_buffer_c)
            ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
            : ptr_D;

    const auto zp_comp_a
            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
    const auto zp_comp_b
            = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
    const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
    const auto &post_ops_binary_rhs_arg_vec
            = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();
    const bool post_ops_applicable = bgmmc.post_ops_applicable
            && (brgmm_ctx.get_num_threads_for_k() <= 1 || bgmmc.K_chunks == 1);

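    // Main part of the K chunk: run the brgemm kernel over the full K blocks.
    // A K tail, if present, is handled separately below by a dedicated tail
    // kernel.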
    if (gemm_batch > 0 && brg_ker_idx >= 0) {
        const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
        assert(brg_kernel != nullptr);

        const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);

        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        if (post_ops_applicable && is_last_K_chunk && !is_K_tail) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }

        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
    if (is_K_tail) {
        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, gemm_batch, 1, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        const bool use_init_ker = (do_init && gemm_batch == 0);
        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                false, use_init_ker, is_M_tail, is_N_tail, true);
        const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get();
        const bool is_tile_reconf_required
                = is_amx && bgmmc.K_tail != bgmmc.K_blk;
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
        if (post_ops_applicable) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
}

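// Reduces the per-K-thread partial accumulation buffers into a single buffer
// and, when post-ops are present, applies them while converting to the
// destination data type (accumulation itself is skipped at that point).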
template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
        const brg_matmul_exec_ctx_t &brgmm_ctx) const {
    if (!brgmm_ctx.parallel_reduction_is_used()) return;

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int nthr_k = brgmm_ctx.get_num_threads_for_k();
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;

        const int num_reduction_buffers = nstl::min(nthr_k, bgmmc.K_chunks);

        int bmn_start {0}, bmn_end {0};
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, bmn_start,
                bmn_end);
        balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end);

        int b {0}, mc {0}, nc {0};

        assert(bgmmc.batch == 1);
        nd_iterator_init(bmn_start + start, b, bgmmc.batch, mc, bgmmc.M_chunks,
                nc, bgmmc.N_chunks);
        while (start < end) {
            auto mb_start = mc * bgmmc.M_chunk_size;
            auto mb_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto nb_start = nc * bgmmc.N_chunk_size;
            auto nb_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for (int mb = mb_start; mb < mb_end; mb++) {
                const int curr_M_blk
                        = nstl::min(bgmmc.M - mb * bgmmc.M_blk, bgmmc.M_blk);
                const bool is_M_tail = curr_M_blk < bgmmc.M_blk;
                const int curr_N_chunk_size
                        = nstl::min(bgmmc.N, nb_end * bgmmc.N_blk)
                        - nb_start * bgmmc.N_blk;
                char *buf_reduced_base = brgmm_ctx.get_buf_C_par_reduction_ptr(
                        0, mb, nb_start);
                const size_t m_offset = bgmmc.LDC * bgmmc.acc_dt_sz;
                for (int r = 1; r < num_reduction_buffers; r++) {
                    const char *buf_to_reduce_base
                            = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                    r, mb, nb_start);
                    for (int m = 0; m < curr_M_blk; m++) {
                        accumulate(buf_reduced_base + m * m_offset,
                                buf_to_reduce_base + m * m_offset,
                                curr_N_chunk_size);
                    }
                }
                if (bgmmc.post_ops_applicable) {
                    for (int nb = nb_start; nb < nb_end; nb++) {
                        const bool is_N_tail
                                = (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk);
                        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                                false, false, is_M_tail, is_N_tail, false);
                        const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
                        const int m = mb * bgmmc.M_blk;
                        const int n = nb * bgmmc.N_blk;
                        const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
                        auto ptr_D = brgmm_ctx.get_data_C_ptr(b, m, n);
                        auto ptr_C = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                0, mb, nb);

                        // TODO: support reduction for zp/s8s8 compensations
                        // computed in copy routines
                        const auto zp_comp_a
                                = brgmm_ctx.get_zp_a_compensation_ptr(
                                        ithr, b, nb);
                        const auto zp_comp_b
                                = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                        ithr, mb);
                        const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
                        const auto &post_ops_binary_rhs_arg_vec
                                = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();

                        const size_t dst_row_logical_off = mb * bgmmc.M_blk;
                        const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                                ? b / bgmmc.batch_without_first_dim
                                : 0;
                        const size_t first_mb_matrix_addr_off
                                = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                                + (m * bgmmc.N + n);
                        // apply post-ops and convert to dst data type only
                        constexpr bool skip_accumulation = true;
                        const brgemm_post_ops_data_t post_ops_data {
                                static_cast<const void *>(ptr_bias),
                                brgmm_ctx.get_oscales_ptr(n),
                                post_ops_binary_rhs_arg_vec.data(),
                                static_cast<size_t>(n), dst_row_logical_off,
                                brgmm_ctx.get_data_C_ptr(0, 0, 0),
                                first_mb_matrix_addr_off,
                                static_cast<const void *>(zp_comp_a),
                                static_cast<const void *>(zp_comp_b),
                                static_cast<const void *>(zp_c_val_ptr),
                                skip_accumulation};

                        brgemm_kernel_execute_postops(brg_kernel, 0, nullptr,
                                (void *)ptr_C, (void *)ptr_D, post_ops_data,
                                nullptr);
                    }
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
    });
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t();
    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const int gemm_batch_iters = bgmmc.use_buffer_a_tail_only ? 0 : gemm_batch;

    const int m = m_blk_idx * bgmmc.M_blk;
    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    ctx.current_M_blk = is_M_tail ? bgmmc.M_tail : bgmmc.M_blk;
    ctx.zp_b_compensation_buffer_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_buffer_ptr(
                    ithr, m_blk_idx);
    ctx.zp_a_compensation_result_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
                    ithr, m_blk_idx);
    ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr();
    ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();

    for (int gb = 0; gb < gemm_batch_iters; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(ithr, m_blk_idx, gb);
        ctx.current_K_blk = nstl::min(bgmmc.K_blk, bgmmc.K);
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
    if (is_K_tail) {
        const auto K_tail = bgmmc.K % bgmmc.K_blk;
        const int k = k_start + gemm_batch * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(
                ithr, m_blk_idx, gemm_batch_iters);
        ctx.current_K_blk = K_tail;
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int n_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t();

    const int n = n_blk_idx * bgmmc.N_blk;
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
    ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
            ithr, b_idx, n_blk_idx);
    ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();

    int gb = 0;
    for (; gb < gemm_batch; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
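        // For pre-blocked B with the f16 ISA the generic copy kernel is not
        // used; the weights are up-converted from f16 to f32 directly into
        // the buffer instead.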
        if (bgmmc.blocked_B && isa == avx512_core_fp16) {
            cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
                    bgmmc.wei_n_blk * ctx.current_K_iters);
        } else {
            (*copy_B_kernel_)(&ctx);
        }
    }

    if (is_K_tail) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
        if (bgmmc.blocked_B && isa == avx512_core_fp16) {
            cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
                    bgmmc.wei_n_blk * ctx.current_K_iters);
        } else {
            (*copy_B_kernel_)(&ctx);
        }
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::accumulate(
        char *result_ptr, const char *reduce_ptr, size_t size) const {
    if (pd()->get_brgemm_matmul_conf().acc_dt == f32)
        acc_ker_f32_->accumulate(
                (float *)result_ptr, (const float *)reduce_ptr, size);
    else if (pd()->get_brgemm_matmul_conf().acc_dt == s32)
        acc_ker_s32_->accumulate(
                (int32_t *)result_ptr, (const int32_t *)reduce_ptr, size);
    else
        assert(!"unsupported accumulation data type");
}

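// Per-execution context: resolves all runtime pointers (tensors, scratchpad
// buffers, compensations, scales and zero points) as well as the thread
// decomposition used by execute_body().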
template <cpu_isa_t isa>
struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
    brg_matmul_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd,
            const float *oscales, int32_t src_zp, int32_t wei_zp,
            int32_t dst_zp)
        : bgmmc_(pd->get_brgemm_matmul_conf()) {

        data_A_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
        data_B_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
        data_C_ptr_ = CTX_OUT_MEM(char *, DNNL_ARG_DST);

        bias_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
        oscales_ptr_ = oscales;
        memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
        const auto &bgmmc = pd->get_brgemm_matmul_conf();

        batch_element_ptr_ = scratchpad.template get<brgemm_batch_element_t>(
                key_brgemm_primitive_batch);

        const bool use_buffer_a
                = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
        buf_A_ptr_ = (use_buffer_a)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
                : nullptr;

        buf_B_ptr_ = (bgmmc.use_buffer_b)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
                : nullptr;

        buf_C_ptr_ = (bgmmc.use_buffer_c)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
                : nullptr;

        is_amx_ = is_superset(isa, avx512_core_amx);
        wsp_tile_ptr_ = is_amx_
                ? ctx.get_scratchpad_grantor().template get<char>(
                        key_conv_amx_tile_buffer)
                : nullptr;

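        // s8s8 compensation values either live in the scratchpad (when B is
        // copied at runtime) or are read from the region appended to the
        // pre-packed weights right after the weights data.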
        const memory_desc_wrapper weights_d(pd->weights_md(0));
        const dim_t comp_offset = bgmmc_.b_dt_sz
                * (weights_d.size() - weights_d.additional_buffer_size());
        s8s8_compensation_ptr_ = (bgmmc.s8s8_compensation_required)
                ? ((bgmmc.use_buffer_b)
                                ? scratchpad.template get<int32_t>(
                                        key_brgemm_primitive_buffer_comp)
                                : const_cast<int32_t *>(
                                        reinterpret_cast<const int32_t *>(
                                                &data_B_ptr_[comp_offset])))
                : nullptr;
        assert(IMPLICATION(bgmmc.s8s8_compensation_required,
                bgmmc_.b_dt_sz == bgmmc_.tr_b_dt_sz));

        zero_point_a_compensations_ptr_ = bgmmc.has_zero_point_a
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_a)
                : nullptr;
        zero_point_b_compensations_ptr_ = bgmmc.has_zero_point_b
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_b)
                : nullptr;

        zero_point_a_negative_val_ = -src_zp;
        zero_point_b_negative_val_ = -wei_zp;
        zero_point_mixed_ab_compensation_component_
                = bgmmc.K * zero_point_a_negative_val_;

        zero_point_c_val_ = dst_zp;

        post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args(
                pd->attr()->post_ops_, ctx);
        base_brg_ker_idx_
                = pd->get_brg_kernel_idx(false, true, false, false, false);
        vnni_factor = data_type_vnni_granularity(bgmmc.wei_dt);

        reorder_zp_a_comp_ptr_ = nullptr;
        if (bgmmc_.has_zero_point_a && bgmmc_.blocked_B) {
            // Store the pointer to the compensation values computed during
            // reorder so that they can be scaled locally by the zp_a value
            // just before use in post-ops. A single global scaling pass before
            // the parallel section may add significant overhead for small
            // problems running in multithreaded execution mode.
            const size_t reorder_zp_a_comp_offset
                    = weights_d.size() - weights_d.additional_buffer_size();
            const size_t b_batch
                    = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1;
            const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
                    ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str
                    : 0;
            reorder_zp_a_comp_ptr_
                    = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
                            &data_B_ptr_[reorder_zp_a_comp_offset
                                    + s8s8_buffer_sz]));
        }

        // Set last_chunk_brgemm_batch_size_ to brgemm_batch_size
        // when K_tail = 0 and brgemm_batch_tail_size = 0
        last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_tail_size;
        if (bgmmc.K_tail == 0 && last_chunk_brgemm_batch_size_ == 0)
            last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_size;

        // parallelization
        parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks;

        // The number of threads available during primitive execution may
        // increase (e.g. with an Eigen threadpool implementation) or decrease
        // (e.g. with nested parallelism) compared to the number of threads
        // available during primitive creation. So we limit the total number
        // of threads to the minimum of these two values to prevent potential
        // OOM issues.
        nthr_ = nstl::min(dnnl_get_current_num_threads(), bgmmc.nthr);

        nthr_k_ = bgmmc.nthr_k > 0 && bgmmc.nthr_k <= nthr_ ? bgmmc.nthr_k : 1;
        nthr_bmn_ = nthr_ / nthr_k_;
        num_threads_used_ = nthr_k_ * nthr_bmn_;

        // If parallel_work_amount_ == 1 and parallel reduction is not used, we
        // limit the number of threads to 1, as parallel(1, ...) does not
        // create a parallel section at all. We do not limit the number of
        // threads for the case 1 < parallel_work_amount_ <
        // dnnl_get_max_threads() to avoid the potential overhead of spawning
        // a different number of OMP threads from layer to layer.
        if (parallel_work_amount_ == 1 && !parallel_reduction_is_used())
            nthr_ = nthr_bmn_ = nthr_k_ = 1;

        const bool need_to_calculate_compensation_for_a
                = bgmmc.has_zero_point_b;
        const bool need_to_calculate_compensation_for_b = !IMPLICATION(
                (bgmmc.has_zero_point_a || bgmmc.s8s8_compensation_required),
                bgmmc.blocked_B);
        const bool calculate_compensations_in_copy_routines
                = need_to_calculate_compensation_for_a
                || need_to_calculate_compensation_for_b;
        // Currently parallel reduction is supported only for non-batched
        // problems that do not compute any compensations in the copy
        // routines.
        assert(IMPLICATION(parallel_reduction_is_used(),
                bgmmc.batch == 1 && !calculate_compensations_in_copy_routines));
        MAYBE_UNUSED(need_to_calculate_compensation_for_a);
        MAYBE_UNUSED(need_to_calculate_compensation_for_b);
        MAYBE_UNUSED(calculate_compensations_in_copy_routines);
    }

    // NOTE: gb --> generalized batch, bb --> broadcast batch
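    // Maps a generalized (output-side) batch index to the corresponding
    // input-side batch index, collapsing the dimensions that are broadcast.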
    int get_bb_idx(int gb_idx, const brgemm_matmul_bcast_desc_t &bd) const {
        if (!bd.bcast_mask) // no broadcast
            return gb_idx;

        int gb_off_before_bcast = utils::rnd_dn(
                gb_idx, bd.first_bcast_dim_to_last_batch_dim_prod);
        int bb_idx = gb_off_before_bcast / (bd.bcast_dims_prod);

        dim_t cur_bcast_dims_prod = bd.bcast_dims_prod;
        int mask = 1 << (bgmmc_.batch_ndims - bd.first_bcast_dim - 1);
        for (int d = bd.first_bcast_dim; d < bd.last_bcast_dim; ++d) {
            if (bd.bcast_mask & mask) // broadcast
                cur_bcast_dims_prod /= bd.batch_dims[d];
            else {
                int cur_b = (gb_idx / bd.gb_off[d]) % bd.batch_dims[d];
                bb_idx += cur_b * (bd.gb_off[d] / cur_bcast_dims_prod);
            }
            mask >>= 1;
        }
        bb_idx += gb_idx % bd.gb_off[bd.last_bcast_dim];
        return bb_idx;
    }

    const char *get_data_A_ptr(int b, int m, int k) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_A_desc);
        return data_A_ptr_ + get_data_A_off(cur_b, m, k);
    }

    const char *get_data_B_ptr(int b, int k, int n) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc);
        return data_B_ptr_ + get_data_B_off(cur_b, k, n);
    }

    char *get_data_C_ptr(int b, int m, int n) const {
        return data_C_ptr_ + get_data_C_off(b, m, n);
    }

    brgemm_batch_element_t *get_batch_elem_ptr(int ithr) const {
        return batch_element_ptr_
                + ithr * bgmmc_.brgemm_batch_element_per_thr_sz;
    }

    void init_brgemm_batch_elements_values(int ithr, int brg_batch_start,
            int brg_batch_iters, int b_idx, int m_blk_idx, int k_blk_idx,
            int n_blk_idx) const {
        auto addr_batch = get_batch_elem_ptr(ithr);

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        for (int b_iter = 0; b_iter < brg_batch_iters; b_iter++) {
            const int brg_batch_idx = brg_batch_start + b_iter;
            const int k = (k_blk_idx + brg_batch_idx) * bgmmc_.K_blk;
            addr_batch[b_iter].ptr.A = bgmmc_.use_buffer_a
                    ? get_buf_A_ptr(ithr, m_blk_idx, brg_batch_idx)
                    : get_data_A_ptr(b_idx, m, k);
            addr_batch[b_iter].ptr.B = (bgmmc_.use_buffer_b)
                    ? get_buf_B_ptr(ithr, brg_batch_idx, n_blk_idx)
                    : get_data_B_ptr(b_idx, k, n);
        }
    }

    char *get_buf_A_ptr(int ithr, int m_blk_idx, int k_blk_idx) const {
        if (!bgmmc_.use_buffer_a && !bgmmc_.use_buffer_a_tail_only)
            return nullptr;

        const int k_blk_local = bgmmc_.use_buffer_a_tail_only ? 0 : k_blk_idx;
        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return buf_A_ptr_ + ithr * bgmmc_.buffer_a_per_thread_sz
                + m_blk_local * bgmmc_.buffer_a_chunk_shift_along_m
                + k_blk_local * bgmmc_.buffer_a_chunk_sz;
    }

    char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx) const {
        UNUSED(n_blk_idx);
        if (!bgmmc_.use_buffer_b) return nullptr;

        return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz
                + k_blk_idx * bgmmc_.buffer_b_chunk_sz;
    }

    char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const {
        if (!bgmmc_.use_buffer_c) return nullptr;

        if (bgmmc_.nthr_k > 1) {
            const int nthr_k = bgmmc_.nthr_k <= nthr_ ? bgmmc_.nthr_k : 1;
            const int nthr_bmn = nthr_ / nthr_k;
            const int ithr_k = ithr / nthr_bmn;
            return get_buf_C_par_reduction_ptr(ithr_k, m_blk_idx, n_blk_idx);
        }

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        const int buf_idx = bgmmc_.N_chunk_size * m_blk_local + n_blk_local;

        return buf_C_ptr_ + ithr * bgmmc_.buffer_c_per_thread_sz
                + buf_idx * bgmmc_.buffer_c_chunk_sz;
    }

    char *get_buf_C_par_reduction_ptr(
            int ithr_k, int m_blk_idx, int n_blk_idx) const {
        if (bgmmc_.nthr_k <= 1) return nullptr;

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        if (!bgmmc_.post_ops_applicable && ithr_k == 0)
            return get_data_C_ptr(0, m, n);

        int k_buf_idx = ithr_k - (!bgmmc_.post_ops_applicable ? 1 : 0);
        return buf_C_ptr_ + k_buf_idx * bgmmc_.buffer_c_per_thread_sz
                + get_data_C_off(0, m, n) * bgmmc_.acc_dt_sz / bgmmc_.c_dt_sz;
    }

    // Auxiliary functions that compute offsets from pre-calculated memory
    // strides for each tensor, providing a general solution for all supported
    // dimensions without significant overhead.
    dim_t get_data_A_off(int b, int m, int k) const {
        using namespace format_tag;
        if (bgmmc_.src_tag == acbd || bgmmc_.src_tag == adbc) {
            dim_t b_off = 0;
            if (!bgmmc_.bcast_A_desc.bcast_mask) { // no broadcast
                const dim_t batch_dim1 = bgmmc_.bcast_A_desc.batch_dims[1];
                b_off = bgmmc_.A_strides[2] * (b % batch_dim1)
                        + (b / batch_dim1) * bgmmc_.A_ptr_shift_b;
            } else {
                b_off = b * bgmmc_.A_ptr_shift_b;
            }
            return b_off + bgmmc_.A_strides[1] * m + bgmmc_.A_strides[0] * k;
        } else {
            return bgmmc_.A_strides[2] * b + bgmmc_.A_strides[1] * m
                    + bgmmc_.A_strides[0] * k;
        }
    }

    dim_t get_data_B_off(int b, int k, int n) const {
        using namespace format_tag;
        if (bgmmc_.wei_tag == acbd || bgmmc_.wei_tag == adbc) {
            dim_t b_off = 0;
            if (!bgmmc_.bcast_B_desc.bcast_mask) { // no broadcast
                const dim_t batch_dim1 = bgmmc_.bcast_B_desc.batch_dims[1];
                b_off = bgmmc_.B_strides[2] * (b % batch_dim1)
                        + (b / batch_dim1) * bgmmc_.B_ptr_shift_b;
            } else {
                b_off = b * bgmmc_.B_ptr_shift_b;
            }
            return b_off + bgmmc_.B_strides[1] * k + bgmmc_.B_strides[0] * n;
        } else {
            int dt_b_k_blk = bgmmc_.is_bf32
                    ? data_type_vnni_simd_elems<avx512_core>(f32)
                    : bgmmc_.wei_k_blk;
            int k_idx = bgmmc_.blocked_B ? k / dt_b_k_blk : k;
            int n_idx = bgmmc_.blocked_B ? n / bgmmc_.wei_n_blk : n;
            return bgmmc_.B_strides[2] * b + bgmmc_.B_strides[1] * k_idx
                    + bgmmc_.B_strides[0] * n_idx
                    + get_data_B_off_within_block(k, n);
        }
    }

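    // For blocked B, elements within a wei_k_blk x wei_n_blk block are packed
    // in VNNI layout: groups of vnni_factor consecutive K elements are
    // interleaved along N, which the offset formula below accounts for.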
    dim_t get_data_B_off_within_block(int k, int n) const {
        using namespace format_tag;

        if (!bgmmc_.blocked_B) return 0;

        int x0 = k % bgmmc_.wei_k_blk;
        int x1 = n % bgmmc_.wei_n_blk;
        dim_t offset = (x0 / vnni_factor) * vnni_factor * bgmmc_.wei_n_blk
                + x1 * vnni_factor + x0 % vnni_factor;
        return bgmmc_.b_dt_sz * offset;
    }

    dim_t get_data_C_off(int b, int m, int n) const {
        using namespace format_tag;
        assert(bgmmc_.dst_tag != adbc);
        if (bgmmc_.dst_tag == acbd) {
            const dim_t batch_dim1 = bgmmc_.bcast_A_desc.batch_dims[1];
            dim_t b_off = bgmmc_.C_strides[2] * (b % batch_dim1)
                    + (b / batch_dim1) * bgmmc_.C_ptr_shift_b;
            return b_off + bgmmc_.C_strides[1] * m + bgmmc_.C_strides[0] * n;
        } else {
            return bgmmc_.C_strides[2] * b + bgmmc_.C_strides[1] * m
                    + bgmmc_.C_strides[0] * n;
        }
    }

    const char *get_bias_ptr(int n) const {
        if (!bgmmc_.with_bias) return nullptr;

        return bias_ptr_ + n * bgmmc_.bias_dt_sz;
    }

    int32_t *get_s8s8_comp_ptr(int ithr, int b, int n_blk_idx) const {
        if (!bgmmc_.s8s8_compensation_required) return nullptr;

        const int n_blk_local = bgmmc_.use_buffer_b
                ? n_blk_idx % bgmmc_.N_chunk_size
                : n_blk_idx;
        return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
                + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str
                + n_blk_local * bgmmc_.s8s8_comp_n_str;
    }

    const float *get_oscales_ptr(int n) const {
        return oscales_ptr_ + bgmmc_.is_oscale_per_n * n;
    }

    const int32_t *get_zp_a_neg_val_ptr() const {
        return &zero_point_a_negative_val_;
    }

    const int32_t *get_zp_b_neg_val_ptr() const {
        return &zero_point_b_negative_val_;
    }

    const int32_t *get_zp_ab_mixed_comp_ptr() const {
        return &zero_point_mixed_ab_compensation_component_;
    }

    const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }

    int32_t *get_zp_a_compensation_ptr(
            int ithr, int b_idx, int n_blk_idx) const {
        if (!bgmmc_.has_zero_point_a) return nullptr;

        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        int32_t *zp_comp = zero_point_a_compensations_ptr_
                + ithr * bgmmc_.zp_a_comp_elems_per_thr
                + n_blk_local * bgmmc_.zp_a_comp_shift_n;

        if (bgmmc_.blocked_B) {
            // Scale the compensation values computed during reorder by the
            // zp_a value locally, just before use. A single global scaling
            // pass before the parallel section may add significant overhead
            // for small problems running in multithreaded execution mode.
            const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc)
                            * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk)
                    + n_blk_idx * bgmmc_.wei_n_blk;
            PRAGMA_OMP_SIMD()
            for (int b = 0; b < bgmmc_.wei_n_blk; b++)
                zp_comp[b] = -zero_point_a_negative_val_
                        * reorder_zp_a_comp_ptr_[base_offset + b];
        }
        return zp_comp;
    }

    int32_t *get_zp_b_compensation_result_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return zero_point_b_compensations_ptr_
                + ithr * bgmmc_.zp_b_comp_elems_per_thr
                + m_blk_local * bgmmc_.zp_b_comp_result_shift_m;
    }

    int32_t *get_zp_b_compensation_buffer_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return get_zp_b_compensation_result_ptr(ithr, 0)
                + bgmmc_.zp_b_comp_buffer_start
                + m_blk_local * bgmmc_.zp_b_comp_buffer_shift_m;
    }

    char *get_tile_workspace(int ithr) const {
        return is_amx_ ? wsp_tile_ptr_ + ithr * bgmmc_.wsp_tile_per_thr_bytes
                       : nullptr;
    }

    const std::vector<const void *> &get_post_ops_binary_rhs_arg_vec() const {
        return post_ops_binary_rhs_arg_vec_;
    }

    int get_base_brgemm_kernel_idx() const { return base_brg_ker_idx_; }

    bool is_last_K_chunk(int k_chunk_idx) const {
        return k_chunk_idx == bgmmc_.K_chunks - 1;
    }

    int get_brgemm_batch_size(int k_chunk_idx) const {
        return is_last_K_chunk(k_chunk_idx) ? last_chunk_brgemm_batch_size_
                                            : bgmmc_.brgemm_batch_size;
    }

    int get_parallel_work_amount() const { return parallel_work_amount_; }
    int get_num_threads_for_k() const { return nthr_k_; }
    bool parallel_reduction_is_used() const {
        return nthr_k_ > 1 && bgmmc_.K_chunks > 1;
    }
    int get_num_threads_for_bmn() const { return nthr_bmn_; }
    // ithr = ithr_k * nthr_bmn + ithr_bmn
    int get_thread_idx_for_k(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_k = ithr / nthr_bmn_;
        return ithr_k < bgmmc_.K_chunks ? ithr_k : -1;
    }
    int get_thread_idx_for_bmn(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_bmn = ithr % nthr_bmn_;
        return ithr_bmn < parallel_work_amount_ ? ithr_bmn : -1;
    }
    int get_num_threads_for_parallelization() const { return nthr_; }

private:
    bool is_amx_;
    const brgemm_matmul_conf_t &bgmmc_;
    const char *data_A_ptr_;
    const char *data_B_ptr_;
    char *data_C_ptr_;
    brgemm_batch_element_t *batch_element_ptr_;

    char *buf_A_ptr_;
    char *buf_B_ptr_;
    char *buf_C_ptr_;

    char *wsp_tile_ptr_;
    const char *bias_ptr_;
    const float *oscales_ptr_;
    int32_t *s8s8_compensation_ptr_;

    int32_t *zero_point_a_compensations_ptr_;
    int32_t *zero_point_b_compensations_ptr_;
    int32_t *reorder_zp_a_comp_ptr_;

    int32_t zero_point_a_negative_val_;
    int32_t zero_point_b_negative_val_;
    int32_t zero_point_mixed_ab_compensation_component_;
    int32_t zero_point_c_val_;
    std::vector<const void *> post_ops_binary_rhs_arg_vec_;

    int base_brg_ker_idx_;
    int vnni_factor;

    // parallelization parameters
    int parallel_work_amount_;
    int nthr_, nthr_k_, nthr_bmn_, num_threads_used_;
    int last_chunk_brgemm_batch_size_;
};

template struct brgemm_matmul_t<avx512_core_amx_fp16>;
template struct brgemm_matmul_t<avx512_core_amx>;
template struct brgemm_matmul_t<avx512_core_fp16>;
template struct brgemm_matmul_t<avx512_core_bf16>;
template struct brgemm_matmul_t<avx512_core_vnni>;
template struct brgemm_matmul_t<avx512_core>;

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl