/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstdint>
#if defined(_MSC_VER)
#include <malloc.h>
#endif

#include "oneapi/dnnl/dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/dnnl_traits.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/f32/gemm_utils_f32.hpp"
#include "cpu/gemm/gemm_msan_unpoison.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/gemm/gemm_driver.hpp"
#include "cpu/x64/gemm/gemm_info.hpp"
#include "cpu/x64/gemm/gemm_partition.hpp"
#include "cpu/x64/gemm/gemm_threading.hpp"
#include "cpu/x64/gemm/gemm_utils.hpp"
#include "cpu/x64/gemm/gemv_driver.hpp"

#include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp"
#include "cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.hpp"
#include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp"

#include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemv_s8x8s32.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
#ifndef DNNL_WITH_SYCL
#define MAX_STACK_SZ (4 * PAGE_4K)
#else
#define MAX_STACK_SZ 0
#endif
58
59template <typename c_type>
60struct alignas(64) gemm_per_thread_t {
61 volatile int32_t result;
62 volatile int32_t compute_done;
63 int32_t thr_k_stride;
64 int32_t nthr_k;
65 dim_t ldc_local;
66 dim_t ldc_global;
67 c_type *c_local;
68 c_type *volatile c_global;
69 gemm_slice_t slice;
70};
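// Note: the 64-byte alignment above is presumably meant to keep each
// per-thread record on its own cache line, so the spin-wait flags (result,
// compute_done) written by one thread do not cause false sharing with the
// records of neighboring threads.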
71
72template <typename T>
73int get_vector_length() {
74 int v_bytes;
75
76 if (mayiuse(avx512_core))
77 v_bytes = cpu_isa_traits<avx512_core>::vlen;
78 else if (mayiuse(avx))
79 v_bytes = cpu_isa_traits<avx>::vlen;
80 else
81 v_bytes = cpu_isa_traits<sse41>::vlen;
82
83 return v_bytes / sizeof(T);
84}
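// Illustrative example: for c_type = float on an avx512_core machine the
// vector length is 64 bytes, so get_vector_length<float>() returns
// 64 / sizeof(float) = 16 elements; plain avx gives 32 / 4 = 8 and sse41
// gives 16 / 4 = 4.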
85
86template <typename c_type>
87static inline void round_to_nearest(c_type *rounded_val, double fp_val) {
88 if (fp_val >= 0.) {
89 fp_val += 0.5;
90 if (fp_val > INT32_MAX) { fp_val = INT32_MAX; }
91 } else {
92 fp_val -= 0.5;
93 if (fp_val < INT32_MIN) { fp_val = INT32_MIN; }
94 }
95 *rounded_val = (c_type)fp_val;
96}
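// Illustrative behavior: round_to_nearest() rounds half away from zero and
// saturates to the int32 range, e.g. 2.5 -> 3, -2.5 -> -3, 2.49 -> 2, and any
// value beyond INT32_MAX / INT32_MIN is clamped before the final cast.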
97
98template <typename c_type>
99static inline void add_results(const dim_t m, const dim_t n, const float alpha,
100 const float beta, const c_type *c_partial_sum, const dim_t ldcp,
101 c_type *c_data, const dim_t ldc, const c_type *co,
102 offset_type offsetc) {
103
104 constexpr bool is_int8 = data_traits<c_type>::data_type == data_type::s32;
105
106 for (dim_t j = 0; j < n; ++j) {
107 for (dim_t i = 0; i < m; ++i) {
108 c_type ctemp = c_partial_sum[i + j * ldcp];
109
110 if (alpha == 1.0f) {
111 if (beta == 0.0f) {
112 c_data[i + j * ldc] = ctemp;
113 } else {
114 if (is_int8) {
115 double c_float
116 = (double)beta * (double)c_data[i + j * ldc];
117 c_float += (double)ctemp;
118 round_to_nearest(&c_data[i + j * ldc], c_float);
119 } else {
120 c_data[i + j * ldc] *= beta;
121 c_data[i + j * ldc] += ctemp;
122 }
123 }
124 } else if (alpha == -1.0f) {
125 if (beta == 0.0f) {
126 c_data[i + j * ldc] = -ctemp;
127 } else {
128 if (is_int8) {
129 double c_float
130 = (double)beta * (double)c_data[i + j * ldc];
131 c_float -= (double)ctemp;
132 round_to_nearest(&c_data[i + j * ldc], c_float);
133 } else {
134 c_data[i + j * ldc] *= beta;
135 c_data[i + j * ldc] -= ctemp;
136 }
137 }
138 } else {
139 if (beta == 0.0f) {
140 if (is_int8) {
141 double c_float = alpha * (double)ctemp;
142 round_to_nearest(&c_data[i + j * ldc], c_float);
143 } else {
144 c_data[i + j * ldc] = alpha * ctemp;
145 }
146
147 } else {
148 if (is_int8) {
149 double c_float = alpha * (double)ctemp
150 + beta * (double)c_data[i + j * ldc];
151 round_to_nearest(&c_data[i + j * ldc], c_float);
152 } else {
153 c_data[i + j * ldc] *= beta;
154 c_data[i + j * ldc] += alpha * ctemp;
155 }
156 }
157 }
158
159 if (offsetc == offset_type::fixed) {
160 c_data[i + j * ldc] += co[0];
161 } else if (offsetc == offset_type::row) {
162 c_data[i + j * ldc] += co[j];
163 } else if (offsetc == offset_type::column) {
164 c_data[i + j * ldc] += co[i];
165 }
166 }
167 }
168}
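// In short, the loop above computes, per element,
//     C[i,j] = alpha * partial[i,j] + beta * C[i,j] + co_term
// where co_term is co[0], co[j] or co[i] for fixed/row/column offsets (0
// otherwise); for s32 accumulation the alpha/beta blend is carried out in
// double precision and rounded with round_to_nearest().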
169
170template <typename a_type, typename b_type, typename c_type>
171static inline dim_t get_k_padd(
172 int ithr, dim_t k, const gemm_info_t<a_type, b_type, c_type> *arg) {
173 if (arg->a_packed) {
174 dim_t block_m, block_k;
175 arg->a_packed->get_blocking(ithr, block_m, block_k);
176 return block_k;
177 } else if (arg->b_packed) {
178 dim_t block_n, block_k;
179 arg->b_packed->get_blocking(ithr, block_k, block_n);
180 return block_k;
181 } else {
182 dim_t k_padd = 0;
183
184 if (k <= arg->bk_traditional) {
185 k_padd = utils::rnd_up(k, arg->uk);
186 k_padd = nstl::max(dim_t(128), k_padd);
187 } else if (k < 2 * arg->bk)
188 k_padd = utils::rnd_up((k + 1) / 2, arg->uk);
189 else
190 k_padd = arg->bk;
191
192 return k_padd;
193 }
194}
195
196template <typename a_type, typename b_type, typename c_type>
197static inline dim_t get_m_padd(
198 int ithr, dim_t m, const gemm_info_t<a_type, b_type, c_type> *arg) {
199 if (arg->a_packed) {
200 dim_t block_m, block_k;
201 arg->a_packed->get_blocking(ithr, block_m, block_k);
202 return block_m;
203 } else
204 return utils::rnd_up(
205 nstl::min(nstl::max(m, arg->um), arg->bm), arg->um);
206}
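// Worked example with hypothetical blocking parameters (not taken from any
// particular gemm_info_t configuration): if arg->um = 48 and arg->bm = 192,
// get_m_padd() clamps m into [um, bm] and rounds up to a multiple of um,
// e.g. m = 1000 -> 192, m = 100 -> 144, m = 10 -> 48.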
207
208template <typename a_type, typename b_type, typename c_type>
209static inline dim_t get_m_padd_parallel_a(int ithr, dim_t m,
210 const gemm_info_t<a_type, b_type, c_type> *arg, int nthrs) {
211 auto m_padd = get_m_padd(ithr, m, arg);
212
213 if (!arg->a_packed) {
214 constexpr auto multiplier = 10;
215
216 m_padd *= nstl::min(nthrs, multiplier);
217 if (m_padd > m) m_padd = utils::rnd_up(m, arg->um);
218 }
219
220 return m_padd;
221}
222
223template <typename a_type, typename b_type, typename c_type>
224static inline dim_t get_n_padd(int ithr, dim_t n, dim_t k,
225 const gemm_info_t<a_type, b_type, c_type> *arg) {
226 if (arg->b_packed) {
227 dim_t block_n, block_k;
228 arg->b_packed->get_blocking(ithr, block_k, block_n);
229 return block_n;
230 } else {
231 auto bn = (k < arg->blocking_small_k) ? arg->bn_small_k : arg->bn;
232 return utils::rnd_up(nstl::min(nstl::max(n, arg->un), bn), arg->un);
233 }
234}
235
236static inline void *align(void *ptr, size_t alignment) {
237 return (void *)utils::rnd_up((uintptr_t)ptr, alignment);
238}
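// Example: align() rounds a pointer up to the next multiple of `alignment`,
// e.g. align((void *)0x1234, 4096) == (void *)0x2000; an already-aligned
// pointer is returned unchanged.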
239
240template <typename scale_t, typename mat_t>
241void scale_matrix(
242 dim_t m, dim_t n, scale_t alpha, mat_t *__restrict p_mat, dim_t ld) {
243 if (data_traits<mat_t>::data_type == data_type::f32) {
244 for (dim_t j = 0; j < n; j++) {
245 for (dim_t i = 0; i < m; i++) {
246 p_mat[i + j * ld] = (mat_t)((scale_t)p_mat[i + j * ld] * alpha);
247 }
248 }
249 }
250}
251
252template <typename mat_t>
253static void sum_matrices(dim_t m, dim_t n, mat_t *__restrict dst, dim_t ld_dst,
254 mat_t *__restrict src, dim_t ld_src) {
255
256 for (dim_t j = 0; j < n; j++) {
257 PRAGMA_OMP_SIMD()
258 for (int i = 0; i < m; i++)
259 dst[i + j * ld_dst] += src[i + j * ld_src];
260 }
261}
262
263template <typename c_type>
264static void sum_k_blocks(
265 int ithr, gemm_per_thread_t<c_type> *thread_arg, bool wait) {
266
267 auto m = thread_arg[ithr].slice.m;
268 auto n = thread_arg[ithr].slice.n;
269 auto ithr_k = thread_arg[ithr].slice.ithr_k;
270 auto nthr_k = thread_arg[ithr].nthr_k;
271 auto stride = thread_arg[ithr].thr_k_stride;
272 dim_t n0, nn;
273
274 partition_1d(ithr_k, nthr_k, n, n0, nn);
275
276 auto get_thread_arg = [&](int thr_k) -> gemm_per_thread_t<c_type> & {
277 return thread_arg[ithr + (thr_k - ithr_k) * stride];
278 };
279
280 auto wait_thread = [&](int thr_k) {
281 if (wait) {
282 auto &tk_arg = get_thread_arg(thr_k);
283 while (!tk_arg.compute_done) {}
284 }
285 };
286
287 auto add_thread_results = [&](int thr_k) {
288 auto &tk_arg = get_thread_arg(thr_k);
289
290 sum_matrices(m, nn, tk_arg.c_global + n0 * tk_arg.ldc_global,
291 tk_arg.ldc_global, tk_arg.c_local + n0 * tk_arg.ldc_local,
292 tk_arg.ldc_local);
293 };
294
295 // First accumulate this thread's results while they are in cache.
296 if (ithr_k > 0) {
297 wait_thread(0);
298 add_thread_results(ithr_k);
299 }
300
301 // Then accumulate the others.
302 for (int thr_k = 1; thr_k < nthr_k; thr_k++) {
303 if (thr_k != ithr_k) {
304 wait_thread(thr_k);
305 add_thread_results(thr_k);
306 }
307 }
308}
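// Ordering note: when ithr_k > 0, a thread accumulates its own local C block
// first, while it is likely still warm in cache, and only then walks the
// remaining k-slices in index order; with `wait` set it spin-waits on each
// producer's compute_done flag before accumulating that slice.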
309
310template <typename a_type, typename b_type, typename c_type>
311static dnnl_status_t pack_no_copy(gemm_info_t<a_type, b_type, c_type> *arg) {
312
313 if (arg->packing == pack_type::pack_a) {
314 return gemm_utils::pack_no_copy(arg->a, arg->lda, arg->m, arg->k,
315 arg->transa, arg->alpha, arg->pack_dst);
316 } else {
317 return gemm_utils::pack_no_copy(arg->b, arg->ldb, arg->k, arg->n,
318 arg->transb, arg->alpha, arg->pack_dst);
319 }
320}
321
322template <typename a_type, typename b_type, typename c_type>
323static dnnl_status_t gemm_packing_driver(int ithr, dim_t m, dim_t n, dim_t k,
324 const a_type *a, const b_type *b,
325 const gemm_info_t<a_type, b_type, c_type> *arg) {
326
327 if (m <= 0 || n <= 0) return dnnl_success;
328
329 gemm_pack_storage_t *pack_dst = arg->pack_dst;
330
331 if (!pack_dst->is_first_thread_in_slice(ithr)) return dnnl_success;
332
333 dim_t block_r, block_c;
334 pack_dst->get_blocking(ithr, block_r, block_c);
335
336 auto do_a = (arg->packing == pack_type::pack_a);
337 auto mn = do_a ? m : n;
338 auto mn_padd = do_a ? block_r : block_c;
339 auto k_padd = do_a ? block_c : block_r;
340 dim_t mn_stride, k_stride;
341
342 if (do_a) {
343 mn_stride = (arg->transa == no_trans) ? 1 : arg->lda;
344 k_stride = (arg->transa == no_trans) ? arg->lda : 1;
345 } else {
346 mn_stride = (arg->transb == no_trans) ? arg->ldb : 1;
347 k_stride = (arg->transb == no_trans) ? 1 : arg->ldb;
348 }
349
350 dim_t blk_k = 0;
351 for (dim_t Bk = 0; Bk < k; Bk += k_padd, blk_k++) {
352 dim_t nk = nstl::min(k - Bk, k_padd);
353
354 for (dim_t Bmn = 0; Bmn < mn; Bmn += mn_padd) {
355 dim_t nmn = nstl::min(mn - Bmn, mn_padd);
356
357 if (do_a) {
358 auto a_src = a + mn_stride * Bmn + k_stride * Bk;
359 auto a_dst = pack_dst->matrix<a_type>(ithr, Bmn, Bk);
360 auto a_row_sum = pack_dst->row_sums<c_type>(ithr, Bmn, blk_k);
361
362 arg->copyA(&nk, &nmn, a_src, &arg->lda, &arg->alpha, a_dst,
363 nullptr, nullptr, a_row_sum);
364 } else {
365 auto b_src = b + mn_stride * Bmn + k_stride * Bk;
366 auto b_dst = pack_dst->matrix<b_type>(ithr, Bk, Bmn);
367 auto b_col_sum = pack_dst->col_sums<c_type>(ithr, blk_k, Bmn);
368
369 arg->copyB(&nk, &nmn, b_src, &arg->ldb, &arg->alpha, b_dst,
370 nullptr, nullptr, b_col_sum);
371 }
372 }
373 }
374
375 return dnnl_success;
376}
377
378template <typename a_type, typename b_type, typename c_type>
379void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha,
380 const a_type *a, const b_type *b, float beta, c_type *c,
381 const dim_t ldc, const c_type *a_row_sum, const c_type *b_col_sum,
382 c_type *row_offset_ws, c_type *col_offset_ws, const c_type *co,
383 offset_type offsetc, const gemm_info_t<a_type, b_type, c_type> *arg) {
384
385 bool col_req = false;
386 bool row_req = false;
387
388 constexpr bool is_int8 = utils::one_of(
389 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
390 constexpr bool is_f32 = data_traits<a_type>::data_type == data_type::f32;
391 bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx);
392
393 dim_t m_stk = col_offset_ws ? 1 : m;
394 dim_t n_stk = row_offset_ws ? 1 : n;
395#if !defined(_MSC_VER)
396 c_type col_offset_stk[m_stk];
397 c_type row_offset_stk[n_stk];
398#else
399 c_type *col_offset_stk = nullptr;
400 if (!col_offset_ws)
401 col_offset_stk = (c_type *)_alloca(sizeof *col_offset_stk * m_stk);
402
403 c_type *row_offset_stk = nullptr;
404 if (!row_offset_ws)
405 row_offset_stk = (c_type *)_alloca(sizeof *row_offset_stk * n_stk);
406#endif
407
408 // Use the heap if already allocated and stack otherwise.
409 c_type *col_offset = col_offset_ws ? col_offset_ws : col_offset_stk;
410 c_type *row_offset = row_offset_ws ? row_offset_ws : row_offset_stk;
411
412 if (is_int8) {
413 c_type ao = arg->ao;
414 c_type bo = arg->bo;
415 c_type co_0 = offsetc == offset_type::none ? 0 : co[0];
416
417 if (bo != 0 || offsetc == offset_type::column) col_req = true;
418 if (ao != 0 || offsetc == offset_type::row) row_req = true;
419
        // Only one of the column or row offset buffers is needed, not both.
421 if ((ao != 0 && bo != 0)
422 || (offsetc == offset_type::fixed && co_0 != 0)) {
423 if (!col_req && !row_req) {
424 if (m <= n) {
425 col_req = true;
426 } else {
427 row_req = true;
428 }
429 }
430 }
431
432 if (col_req) {
433 for (dim_t i = 0; i < m; i++)
434 col_offset[i] = 0;
435
436 if (offsetc == offset_type::column) {
437 for (dim_t i = 0; i < m; i++)
438 col_offset[i] += co[i];
439 }
440
441 if (bo != 0 && a_row_sum) {
442 for (dim_t i = 0; i < m; i++)
443 col_offset[i] -= bo * a_row_sum[i];
444 }
445 }
446
447 if (row_req) {
448 for (dim_t i = 0; i < n; i++)
449 row_offset[i] = 0;
450
451 if (offsetc == offset_type::row) {
452 for (dim_t i = 0; i < n; i++)
453 row_offset[i] += co[i];
454 }
455
456 if (ao != 0 && b_col_sum) {
457 for (dim_t i = 0; i < n; i++)
458 row_offset[i] -= ao * b_col_sum[i];
459 }
460 }
461
462 if (offsetc == offset_type::fixed && co_0 != 0) {
463 if (col_req) {
464 for (dim_t i = 0; i < m; i++)
465 col_offset[i] += co_0;
466 } else {
467 for (dim_t i = 0; i < n; i++)
468 row_offset[i] += co_0;
469 }
470 }
471
472 if (ao != 0 && bo != 0) {
473 if (col_req) {
474 for (dim_t i = 0; i < m; i++)
475 col_offset[i] += (c_type)k * ao * bo;
476 } else {
477 for (dim_t i = 0; i < n; i++)
478 row_offset[i] += (c_type)k * ao * bo;
479 }
480 }
481 }
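    // Where the compensation terms above come from (sign convention inferred
    // from this code): if the kernel accumulates sum_k a[i,k] * b[k,j], then
    //     sum_k (a[i,k] - ao) * (b[k,j] - bo)
    //         = sum_k a[i,k] * b[k,j]
    //           - bo * a_row_sum[i] - ao * b_col_sum[j] + k * ao * bo,
    // which matches the -bo * a_row_sum, -ao * b_col_sum and +k * ao * bo
    // contributions folded into col_offset[] / row_offset[], plus the user
    // supplied C offset co.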
482
483 bool isBeta0 = beta == 0.0f;
484
485 /* Column and row offsets are ignored by non-integer compute kernels.
486 * Scaling is done only for bfloat16 kernels.
487 */
488 if (m > 0 && n > 0)
489 arg->kernel[isBeta0][col_req][row_req](
490 &m, &n, &k, &alpha, a, b, c, ldc, col_offset, row_offset);
491
492 msan_unpoison_matrix(c, m, n, ldc, sizeof(*c));
493
494 // sgemm kernels don't support bias yet.
495 if (is_f32) {
496 if (co && offsetc == offset_type::column) {
497 for (dim_t j = 0; j < n; j++) {
498 for (dim_t i = 0; i < m; i++) {
499 c[i + j * ldc] += co[i];
500 }
501 }
502 }
503 }
504
505 // AMX igemm kernels don't support row & col sums yet.
506 if (is_int8_amx) {
507 for (dim_t j = 0; j < n; j++) {
508 for (dim_t i = 0; i < m; i++) {
509 if (row_req) c[i + j * ldc] += row_offset[j];
510 if (col_req) c[i + j * ldc] += col_offset[i];
511 }
512 }
513 }
514}
515
516template <typename a_type, typename b_type, typename c_type>
517static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k,
518 const a_type *a, const b_type *b, float beta, c_type *c, dim_t ldc,
519 offset_type offsetc, const c_type *co,
520 const gemm_info_t<a_type, b_type, c_type> *arg) {
521
522 if (arg->packing != pack_type::none)
523 return gemm_packing_driver(ithr, m, n, k, a, b, arg);
524
525 if (m <= 0 || n <= 0) return dnnl_success;
526
527 dim_t lda = arg->lda;
528 dim_t ldb = arg->ldb;
529
530 float alpha = arg->alpha;
531
532 constexpr bool is_int8 = utils::one_of(
533 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
534 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
535 bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx);
536 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx);
537 bool is_amx = is_int8_amx || is_bf16_amx;
538
539 const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
540 const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
541
542 // Scaling C matrix.
543 if (!is_int8 && beta != 1.0f && beta != 0.0f) {
544 scale_matrix(m, n, beta, c, ldc);
545 beta = 1.0f;
546 }
547
548 // Quick exit for C = beta * C
549 if (!is_int8 && alpha == 0.0f) {
550 if (beta == 0.0f) scale_matrix(m, n, beta, c, ldc);
551
552 return dnnl_success;
553 }
554
555 // Get block sizes.
556 dim_t k_padd = get_k_padd(ithr, k, arg);
557 dim_t m_padd = get_m_padd(ithr, m, arg);
558 dim_t n_padd = get_n_padd(ithr, n, k, arg);
559
560 // Padding for temporary buffer for C
561 dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m_padd);
562
563 dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
564 dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
565 dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
566 dim_t strideBn = (arg->transb != no_trans) ? 1 : ldb;
567
568 size_t a_buf_nelems = m_padd * k_padd;
569 size_t b_buf_nelems = k_padd * n_padd;
570 // A and B buffers need more space due to zero-padding.
571 if (is_amx) {
572 a_buf_nelems = utils::rnd_up(m_padd, arg->um)
573 * utils::rnd_up(k_padd, arg->uk);
574 b_buf_nelems = utils::rnd_up(k_padd, arg->uk)
575 * utils::rnd_up(n_padd, arg->un);
576 }
577 size_t a_row_sum_nelems = m_padd;
578 size_t b_col_sum_nelems = n_padd;
579
580 if (a_packed) a_buf_nelems = a_row_sum_nelems = 0;
581 if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
582
583 size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
584 + b_buf_nelems * sizeof(*b) + PAGE_4K;
585
586 if (is_int8) {
587 mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K
588 + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
589 }
590
591 size_t col_offset_ws_nelems = arg->um;
592 size_t row_offset_ws_nelems = n_padd;
593 size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c);
594 const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ;
595 if (!use_stack) {
596 mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K;
597 mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K;
598 }
599
600 bool need_c_buffer
601 = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
602 // AMX bfloat16 kernels don't support alpha scaling yet,
603 // so we need to use accumulation buffer even if beta == 0.
604 || (is_bf16_amx && alpha != 1.0f);
605
606 if (need_c_buffer) {
607 size_t c_buf_nelems = ldc_buf * n_padd;
608 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
609 }
610
611 char *mem = nullptr;
612
613 if (mem_size > 0) {
614 mem = (char *)malloc(mem_size, 128);
615 if (!mem) return dnnl_out_of_memory;
616 }
617
618 a_type *bufferA = (a_type *)align(mem, PAGE_4K);
619 void *p_next_buf = bufferA + a_buf_nelems;
620
621 b_type *bufferB = (b_type *)align(p_next_buf, PAGE_4K);
622 p_next_buf = bufferB + b_buf_nelems;
623
624 c_type *a_row_sum = nullptr;
625 c_type *b_col_sum = nullptr;
626 if (is_int8) {
627 a_row_sum = (c_type *)align(p_next_buf, PAGE_4K);
628 p_next_buf = a_row_sum + a_row_sum_nelems;
629
630 b_col_sum = (c_type *)align(p_next_buf, PAGE_4K);
631 p_next_buf = b_col_sum + b_col_sum_nelems;
632 }
633
634 c_type *col_offset_ws = nullptr;
635 c_type *row_offset_ws = nullptr;
636 if (!use_stack) {
637 col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K);
638 p_next_buf = col_offset_ws + col_offset_ws_nelems;
639
640 row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K);
641 p_next_buf = row_offset_ws + row_offset_ws_nelems;
642 }
643
644 c_type *bufferC = nullptr;
645 if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K);
646
647 int a_block_copied = 0;
648 dim_t sizeM = 0;
649 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
650 sizeM = m - Bm;
651 if (sizeM > m_padd) sizeM = m_padd;
652
653 dim_t sizeK = 0;
654 dim_t blk_k = 0;
655 for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
656 sizeK = k - Bk;
657 if (sizeK > k_padd) sizeK = k_padd;
658
            // Scale C blocks by beta only for the first k-block.
660 auto beta_eff = (Bk == 0) ? beta : 1.0f;
661
            // Apply the C offset only to the last k-block of the partial sum.
663 auto offsetc_eff = offset_type::none;
664 if (Bk + sizeK == k) offsetc_eff = offsetc;
665
666 dim_t sizeN = 0;
667 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
668 sizeN = n - Bn;
669 if (sizeN > n_padd) sizeN = n_padd;
670
671 if (b_packed) {
672 bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
673 if (is_int8)
674 b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
675 } else {
676 const b_type *b_block = b + Bk * strideBm + Bn * strideBn;
677 const float one = 1.0f;
678
679 /* Column sum argument is ignored for non-integer kernels
680 * and scaling factor is ignored by 8-bit and 16-bit copy
681 * kernels.
682 */
683 arg->copyB(&sizeK, &sizeN, b_block, &ldb, &one, bufferB,
684 nullptr, nullptr, b_col_sum);
685 }
686
687 dim_t sizeUM = 0;
688 for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
689 sizeUM = sizeM - Um;
690 if (sizeUM > arg->um) sizeUM = arg->um;
691
692 /* Use the whole A buffer only if we have multiple B
693 * blocks for k-dimension, otherwise we are wasting cache
694 * to store B and C blocks.
695 */
696 dim_t Um_forA = 0;
697 if (sizeN < n) Um_forA = Um;
698
699 a_type *bufferA_eff = nullptr;
700 c_type *a_row_sum_eff = nullptr;
701
702 if (a_packed) {
703 Um_forA = Um;
704
                    // TODO: Can we simplify this?
706 dim_t buf_shift = 0;
707 if (is_amx)
708 buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
709 else
710 buf_shift = Um_forA * sizeK;
711
712 bufferA_eff = a_packed->matrix<a_type>(ithr, Bm, Bk)
713 + buf_shift;
714
715 if (is_int8)
716 a_row_sum_eff = a_packed->row_sums<c_type>(
717 ithr, Bm, blk_k)
718 + Um_forA;
719 } else {
                    // TODO: Can we simplify this?
721 dim_t buf_shift = 0;
722 if (is_amx)
723 buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
724 else
725 buf_shift = Um_forA * sizeK;
726
727 bufferA_eff = bufferA + buf_shift;
728 a_row_sum_eff
729 = a_row_sum ? a_row_sum + Um_forA : nullptr;
730
731 if (!a_block_copied) {
732 const a_type *a_block
733 = a + (Bm + Um) * strideAm + Bk * strideAn;
734
735 /* Row sum argument is ignored for non-integer
736 * kernels and scaling factor is ignored by 8-bit
737 * and 16-bit copy kernels.
738 */
739 arg->copyA(&sizeK, &sizeUM, a_block, &lda, &alpha,
740 bufferA_eff, nullptr, nullptr,
741 a_row_sum_eff);
742 }
743 }
744
745 c_type *c_block = c + (Bm + Um) + Bn * ldc;
746
747 dim_t co_stride = 0;
748 if (offsetc_eff == offset_type::row)
749 co_stride = Bn;
750 else if (offsetc_eff == offset_type::column)
751 co_stride = Bm + Um;
752
753 if (need_c_buffer) {
754 gemm_kernel(sizeUM, sizeN, sizeK, 1.0f, bufferA_eff,
755 bufferB, 0.0f, bufferC + Um, ldc_buf,
756 a_row_sum_eff, b_col_sum, row_offset_ws,
757 col_offset_ws, (c_type *)nullptr,
758 offset_type::none, arg);
759
760 /* Finish the block adding the necessary alpha, beta
761 * and offsets.
762 */
763 add_results(sizeUM, sizeN, alpha, beta_eff,
764 bufferC + Um, ldc_buf, c_block, ldc,
765 co + co_stride, offsetc_eff);
766 } else {
767 gemm_kernel(sizeUM, sizeN, sizeK, alpha, bufferA_eff,
768 bufferB, beta_eff, c_block, ldc, a_row_sum_eff,
769 b_col_sum, row_offset_ws, col_offset_ws,
770 co + co_stride, offsetc_eff, arg);
771 }
772 }
773 a_block_copied = 1;
774 }
775 a_block_copied = 0;
776 }
777 }
778
779 free(mem);
780
781 return dnnl_success;
782}
783
784template <typename a_type, typename b_type, typename c_type>
785static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m,
786 dim_t n, dim_t k, dim_t blk_k, dim_t Bk, const a_type *bufferA,
787 const b_type *b, float beta, c_type *c, offset_type offsetc,
788 const c_type *co, const c_type *a_row_sum,
789 const gemm_info_t<a_type, b_type, c_type> *arg) {
790
791 dim_t ldb = arg->ldb;
792 dim_t ldc = arg->ldc;
793
794 float alpha = arg->alpha;
795
796 const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
797
798 if (m <= 0 || n <= 0) { return dnnl_success; }
799
800 // Padding along N dimension.
801 dim_t n_padd = get_n_padd(ithr, n, k, arg);
802
803 // Padding for temporary buffer for C
804 dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m);
805
806 dim_t strideBn = (arg->transb != 0) ? 1 : ldb;
807
808 size_t b_buf_nelems = k * n_padd;
809 size_t b_col_sum_nelems = n_padd;
810 constexpr bool is_int8 = utils::one_of(
811 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
812 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
813 bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx);
814 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx);
815 bool is_amx = is_int8_amx || is_bf16_amx;
816
    // The B buffer needs more space due to zero-padding.
818 if (is_amx)
819 b_buf_nelems
820 = utils::rnd_up(k, arg->uk) * utils::rnd_up(n_padd, arg->un);
821
822 if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
823
824 size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K;
825
826 if (is_int8) { mem_size += b_col_sum_nelems * sizeof(*c) + PAGE_4K; }
827
828 size_t col_offset_ws_nelems = m;
829 size_t row_offset_ws_nelems = n_padd;
830 size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c);
831 const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ;
832 if (!use_stack) {
833 mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K;
834 mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K;
835 }
836
837 bool need_c_buffer
838 = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
839 // AMX bfloat16 kernels don't support alpha scaling yet,
840 // so we need to use accumulation buffer even if beta == 0.
841 || (is_bf16_amx && alpha != 1.0f);
842
843 if (need_c_buffer) {
844 size_t c_buf_nelems = ldc_buf * n_padd;
845 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
846 }
847
848 char *mem = nullptr;
849
850 if (mem_size > 0) {
851 mem = (char *)malloc(mem_size, 128);
852 if (!mem) return dnnl_out_of_memory;
853 }
854
855 b_type *bufferB = (b_type *)align(mem, PAGE_4K);
856 void *p_next_buf = bufferB + b_buf_nelems;
857
858 c_type *b_col_sum = nullptr;
859 if (is_int8) {
860 b_col_sum = (c_type *)align(p_next_buf, PAGE_4K);
861 p_next_buf = b_col_sum + b_col_sum_nelems;
862 }
863
864 c_type *col_offset_ws = nullptr;
865 c_type *row_offset_ws = nullptr;
866 if (!use_stack) {
867 col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K);
868 p_next_buf = col_offset_ws + col_offset_ws_nelems;
869
870 row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K);
871 p_next_buf = row_offset_ws + row_offset_ws_nelems;
872 }
873
874 c_type *bufferC = nullptr;
875 if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K);
876
877 dim_t sizeN = 0;
878 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
879 sizeN = n - Bn;
880 if (sizeN > n_padd) sizeN = n_padd;
881
882 if (b_packed) {
883 bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
884 if (is_int8)
885 b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
886 } else {
887 const b_type *b_block = b + Bn * strideBn;
888 const float one = 1.0f;
889
890 /* Column sum argument is ignored for non-integer kernels and
891 * scaling factor is ignored by 8-bit and 16-bit copy kernels.
892 */
893 arg->copyB(&k, &sizeN, b_block, &ldb, &one, bufferB, nullptr,
894 nullptr, b_col_sum);
895 }
896
897 dim_t co_stride = 0;
898 if (offsetc == offset_type::fixed) {
899 co_stride = 0;
900 } else if (offsetc == offset_type::row) {
901 co_stride = Bn;
902 } else if (offsetc == offset_type::column) {
903 co_stride = 0;
904 }
905
906 c_type *c_block = c + Bn * ldc;
907 if (need_c_buffer) {
908 gemm_kernel(m, sizeN, k, 1.0f, bufferA, bufferB, 0.0f, bufferC,
909 ldc_buf, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws,
910 (c_type *)nullptr, offset_type::none, arg);
911
912 // Finish the block adding the necessary alpha, beta and offsets.
913 add_results(m, sizeN, alpha, beta, bufferC, ldc_buf, c_block, ldc,
914 co + co_stride, offsetc);
915 } else {
916 gemm_kernel(m, sizeN, k, alpha, bufferA, bufferB, beta, c_block,
917 ldc, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws,
918 co + co_stride, offsetc, arg);
919 }
920 }
921
922 free(mem);
923
924 return dnnl_success;
925}
926
927static inline bool nocopy_checker_avx2(const int nthr, const int transa,
928 const int transb, const dim_t m, const dim_t n, const dim_t k,
929 const dim_t lda, const dim_t ldb, const dim_t ldc) {
930 static const dim_t BM_NOCOPY_AVX2 = 64;
931 static const dim_t MN_NOCOPY_AVX2 = 128;
932 static const dim_t N_TRANSB_PER_THR = 1;
933 static const dim_t K_TRANSB_PER_THR = 1;
934 static const dim_t N_NOTRANSB_PER_THR = 16;
935 static const dim_t K_NOTRANSB_PER_THR = 2;
936 static const double FORCE_NOCOPY_THRESH = 0.0038;
937
    // Crude threshold for nocopy kernels if copy overhead is significant.
939 if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH) { return true; }
940
941 if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
942
943 if (m >= nthr * 378 && k >= nthr * 378) return false;
944
945 if (transb == no_trans) {
946 if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
947 if (n <= nthr * N_NOTRANSB_PER_THR) return true;
948 if (k <= nthr * K_NOTRANSB_PER_THR) return true;
949 if (m <= BM_NOCOPY_AVX2 && n >= nthr * N_NOTRANSB_PER_THR) return true;
950 } else {
951 if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
952 if (n <= nthr * N_TRANSB_PER_THR) return true;
953 if (k <= nthr * K_TRANSB_PER_THR) return true;
954 }
955
956 return false;
957}
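// Worked example for the threshold above: a 500x500 problem gives
// 1.0 / 500 + 1.0 / 500 = 0.004 >= 0.0038, so copy overhead is considered too
// high and the nocopy path is chosen; at 600x600 the sum is ~0.0033 and the
// remaining shape-based heuristics decide instead.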
958
959static inline bool nocopy_checker_avx512(int nthr, const int transa,
960 const int transb, const dim_t m, const dim_t n, const dim_t k,
961 const dim_t lda, const dim_t ldb, const dim_t ldc) {
962 // Constants definition
963 static const dim_t BAD_LD_MULT = 256;
964 static const dim_t VERYBAD_LD_MULT = 1024;
965 static const dim_t M_TRANSB_PER_THR = 28;
966 static const dim_t N_TRANSB_PER_THR = 28;
967 static const dim_t K_TRANSB_PER_THR = 1;
968 static const dim_t MN_NOTRANSB_PER_THR = 28;
969 static const dim_t K_NOTRANSB_PER_THR = 1;
970 static const double FORCE_NOCOPY_THRESH = 0.00196;
971
972 bool is_NN = transa == no_trans && transb == no_trans;
973 bool is_NT = transa == no_trans && transb == do_trans;
974 bool is_TN = transa == do_trans && transb == no_trans;
975
976 bool is_lda_bad = lda % BAD_LD_MULT == 0;
977 bool is_ldb_bad = ldb % BAD_LD_MULT == 0;
978 bool is_ldc_bad = ldc % BAD_LD_MULT == 0;
979 bool is_ld_bad = is_lda_bad || is_ldb_bad || is_ldc_bad;
980
981 bool is_lda_verybad = lda % VERYBAD_LD_MULT == 0;
982
    // Copy-based performs better for the TN case with small n when running
    // single-threaded.
984 if (nthr == 1 && is_TN && m > 100
985 && ((m < 1200 && n < 200 && k < 1200)
986 || (is_lda_bad && is_ldb_bad)))
987 return false;
988
989 // Copy-based performs better for NN case on very bad leading dimension if
990 // each thread has enough work.
991 if (nthr <= 8 && is_NN && is_lda_verybad && k > 500 && n > 100)
992 return false;
993
994 // Crude threshold for nocopy kernels if copy overhead is significant.
995 if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH
996 && !(is_lda_verybad && is_NT)) {
997 return true;
998 }
999
1000 // Copy strategy usually performs better than nocopy on "bad" leading
1001 // dimensions.
1002 if (is_ld_bad) {
1003 bool use_copy_based = false;
1004
1005 if (m >= 32 && n > 16) use_copy_based = true;
1006
1007 // Nocopy outperforms copy-based in certain conditions.
1008 if (m >= 32 && n == 16
1009 && (k >= 6400 || transa == do_trans || m == 4096))
1010 use_copy_based = true;
1011
1012 if (use_copy_based) return false;
1013 }
1014
1015 if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
1016
1017 if (m >= nthr * 378 && k >= nthr * 378) return false;
1018
1019 if (transb == no_trans) {
1020 if (m <= nthr * MN_NOTRANSB_PER_THR) return true;
1021 if (n <= nthr * MN_NOTRANSB_PER_THR) return true;
1022 if (k <= nthr * K_NOTRANSB_PER_THR) return true;
1023 } else {
1024 if (m <= nthr * M_TRANSB_PER_THR && m >= n) return true;
1025 if (n <= nthr * N_TRANSB_PER_THR) return true;
1026 if (k <= nthr * K_TRANSB_PER_THR) return true;
1027 }
1028 return false;
1029}
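// Example: a leading dimension is treated as "bad" when it is a multiple of
// 256 elements (e.g. lda = 2048) and "very bad" at multiples of 1024; such
// strides tend to map rows to the same cache sets, which is presumably why
// the copy-based path is usually preferred for them.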
1030
1031template <typename a_type, typename b_type, typename c_type>
1032static inline bool nocopy_checker(
1033 int nthr, const gemm_info_t<a_type, b_type, c_type> *arg) {
1034
1035 if (data_traits<a_type>::data_type != data_type::f32) return false;
1036
1037 if (!mayiuse(avx)) return false;
1038
1039 if (arg->force_nocopy) return true;
1040
1041 auto m = arg->m, n = arg->n, k = arg->k;
1042 auto lda = arg->lda, ldb = arg->ldb, ldc = arg->ldc;
1043 auto transa = arg->transa, transb = arg->transb;
1044 auto packing = arg->packing;
1045
1046 if (packing != pack_type::none) ldc = 64;
1047
1048 if (arg->a_packed || arg->b_packed)
1049 return false;
1050 else if (mayiuse(avx512_core))
1051 return nocopy_checker_avx512(
1052 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1053 else
1054 return nocopy_checker_avx2(
1055 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1056}
1057
1058template <typename a_type, typename b_type, typename c_type>
1059static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn,
1060 gemm_threading_t &thread_info,
1061 const gemm_info_t<a_type, b_type, c_type> *arg) {
1062
1063 static constexpr dim_t N2D_MAX = 384;
1064 static constexpr dim_t M2D_MIN = 384;
1065
1066 constexpr bool is_int8 = utils::one_of(
1067 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1068 bool isSgemm = data_traits<a_type>::data_type == data_type::f32;
1069
1070 dim_t m = arg->m;
1071 dim_t n = arg->n;
1072 dim_t k = arg->k;
1073
1074 thread_info.nthrs_m = 0;
1075 thread_info.nthrs_n = 0;
1076 thread_info.nthrs_k = 0;
1077 thread_info.copy = copy_type::nonshared;
1078 thread_info.partition = partition_type::row_1d;
1079
1080 // TODO Check if we can use dynamic scheduling for sgemm.
1081 // TODO Check if we should use 3D blocking.
1082 thread_info.nthrs_k = 1;
1083 thread_info.thread_k = k;
1084
1085 bool condition_2D_bsrc = false;
1086 if (isSgemm) {
1087 // If m is large and n is small then do 1D partitioning for AVX2.
1088 if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN))
1089 condition_2D_bsrc = false;
1090 else
1091 condition_2D_bsrc
1092 = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2))
1093 && (m >= 2 * M2D_MIN);
1094 } else {
1095 int scale = mayiuse(avx512_core) ? nthrs : 20;
1096 condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n);
1097 }
1098
1099 // TODO Check if we should use k-partitioning.
1100
1101 int condition_1D_copya = false;
1102 if (mayiuse(avx512_core)) {
1103 const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68;
1104 if (m >= 1000 && (n >= nthrs * thresh)) {
1105 condition_2D_bsrc = false;
1106 condition_1D_copya = true;
1107 }
1108 } else {
1109 if (m >= 1000 && n >= 4000) {
1110 condition_2D_bsrc = false;
1111 condition_1D_copya = true;
1112 }
1113 }
1114
    // If the A or B offset is non-zero, we need to keep 1D_copya to reduce
    // update overhead.
    // TODO: the reason seems to be in the copy_sum_bx routines. At least,
    // after a simple optimization of copy_sum_ax for avx512, the similar
    // restriction on the B offset became unnecessary. Revisit.
1120 if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) {
1121 condition_2D_bsrc = false;
1122 condition_1D_copya = true;
1123 }
1124
1125 if (condition_2D_bsrc) {
1126 int nthrs_m = 1;
1127 int nthrs_n = nthrs;
1128
1129 if (isSgemm) {
1130 while ((nthrs_n % 2 == 0)
1131 && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2)
1132 && (m / nthrs_m >= 2 * M2D_MIN) && (nthrs_m < 4)) {
1133 nthrs_m *= 2;
1134 nthrs_n /= 2;
1135 }
1136
1137 thread_info.nthrs_m = nthrs_m;
1138 thread_info.nthrs_n = nthrs_n;
1139 thread_info.partition = partition_type::col_major_2d;
1140 } else {
1141 if (m == 800 && n == 300) {
1142 // TODO: Expand this branch to other problem sizes.
1143
1144 auto &thread_m = thread_info.thread_m;
1145 auto &thread_n = thread_info.thread_n;
1146
1147 const dim_t block_m = arg->um * 4;
1148 constexpr dim_t block_n = 64;
1149 constexpr dim_t small_m = 16;
1150 constexpr dim_t small_n = 2;
1151
1152 std::tie(nthrs_m, nthrs_n)
1153 = gemm_utils::calc_nthr_2d(nthrs, m, n, block_m,
1154 block_n, small_m, small_n, thread_m, thread_n);
1155
1156 thread_info.nthrs_m = nthrs_m;
1157 thread_info.nthrs_n = nthrs_n;
1158 thread_info.partition = partition_type::mnk_3d;
1159
1160 } else if ((n <= 64 || n >= 256)) {
1161 while (((nthrs_n > 1) && (n / nthrs_n < arg->un)
1162 && (m / nthrs_m >= 2 * arg->um)
1163 && mayiuse(avx512_core))
1164 || ((nthrs_n % 2 == 0)
1165 && (n / nthrs > N2D_MAX
1166 || n / nthrs_n <= N2D_MAX / 2)
1167 && (m / nthrs_m >= 2 * M2D_MIN)
1168 && (nthrs_m < 4))) {
1169 nthrs_m *= 2;
1170 nthrs_n /= 2;
1171 }
1172
1173 thread_info.nthrs_m = nthrs_m;
1174 thread_info.nthrs_n = nthrs_n;
1175 thread_info.partition = partition_type::col_major_2d;
1176 } else {
1177 // Use 3D decomposition from pack api without k-partitioning.
1178 set_thread_opts_pack(nthrs, thread_info, arg, false);
1179 }
1180 }
1181
1182 } else if (condition_1D_copya && dnnl_thr_syncable()) {
1183 // Use parallel copy A algorithm
1184 thread_info.copy = copy_type::shared_a;
1185 thread_info.partition = partition_type::col_1d;
1186 thread_info.nthrs_m = 1;
1187 thread_info.nthrs_n = nthrs_spawn; // Using all spawned threads.
1188 } else {
1189 auto veclen = get_vector_length<c_type>();
1190
1191 if (m > n && (m >= nthrs * veclen || n < nthrs)) {
1192 if (n <= 20 && is_int8) {
1193 // Use 3D decomposition forcing m-blocking only.
1194 set_thread_opts_pack(
1195 nthrs, thread_info, arg, false, true, false);
1196 } else {
1197 thread_info.partition = partition_type::row_1d;
1198 thread_info.nthrs_m = nthrs;
1199 thread_info.nthrs_n = 1;
1200 }
1201 } else {
1202 thread_info.partition = partition_type::col_1d;
1203 thread_info.nthrs_m = 1;
1204 thread_info.nthrs_n = nthrs;
1205 }
1206 }
1207}
1208
1209template <typename a_type, typename b_type, typename c_type>
1210static inline void set_thread_opts_pack(int nthrs,
1211 gemm_threading_t &thread_info,
1212 const gemm_info_t<a_type, b_type, c_type> *arg,
1213 bool do_k_blocking = true, bool do_m_blocking = true,
1214 bool do_n_blocking = true) {
1215
1216 constexpr bool is_int8 = utils::one_of(
1217 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1218 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1219
1220 bool do_m_blocking_only = do_m_blocking && !do_n_blocking;
1221
1222 auto m = arg->m, n = arg->n, k = arg->k;
1223
1224 auto &nthr_m = thread_info.nthrs_m;
1225 auto &nthr_n = thread_info.nthrs_n;
1226 auto &nthr_k = thread_info.nthrs_k;
1227 auto &thread_m = thread_info.thread_m;
1228 auto &thread_n = thread_info.thread_n;
1229 auto &thread_k = thread_info.thread_k;
1230 auto &block_m = thread_info.block_m;
1231 auto &block_n = thread_info.block_n;
1232 auto &block_k = thread_info.block_k;
1233
1234 constexpr auto MBLK = 64;
1235 constexpr auto NBLK = 64;
1236 auto KBLK = is_int8 ? 3072 : 256;
1237 KBLK = do_m_blocking_only && is_int8 ? 384 : KBLK;
1238
1239 nthr_m = nthr_n = nthr_k = 1;
1240 thread_info.copy = copy_type::nonshared;
1241 thread_info.partition = partition_type::mnk_3d;
1242
1243 auto choose_blocking
1244 = [](dim_t size_z, dim_t &thread_z, int &nthr_z, dim_t block_z_init,
1245 dim_t &block_z, dim_t block_align) {
1246 thread_z = utils::div_up(size_z, nthr_z);
1247 auto num_blk = utils::div_up(thread_z, block_z_init);
1248 block_z = utils::div_up(thread_z, num_blk);
1249 block_z = utils::rnd_up(block_z, block_align);
1250 thread_z = num_blk * block_z;
1251 if (thread_z * nthr_z > size_z)
1252 nthr_z = utils::div_up(size_z, thread_z);
1253 };
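    // Worked example for choose_blocking() with hypothetical inputs:
    // size_z = 1000, nthr_z = 4, block_z_init = 384, block_align = 48 gives
    // thread_z = ceil(1000 / 4) = 250, one block per thread,
    // block_z = rnd_up(250, 48) = 288 and thread_z = 288; since
    // 288 * 4 = 1152 > 1000, nthr_z is re-derived as ceil(1000 / 288) = 4,
    // so the last thread just gets a shorter tail.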
1254
1255 auto choose_m_blocking = [&]() {
1256 auto align = get_vector_length<c_type>();
1257 align = do_m_blocking_only ? arg->um : align;
1258 choose_blocking(m, thread_m, nthr_m, arg->bm, block_m, align);
1259 };
1260 auto choose_n_blocking = [&]() {
1261 choose_blocking(n, thread_n, nthr_n, arg->bn, block_n, arg->un);
1262 };
1263 auto choose_k_blocking = [&]() {
1264 auto align = nstl::max(arg->uk, dim_t(4));
1265 choose_blocking(k, thread_k, nthr_k, arg->bk, block_k, align);
1266 };
1267
1268 // Choose k blocking.
1269 if ((m / MBLK + n / NBLK) < nthrs && do_k_blocking) {
1270 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1271 if (nthrs % nk == 0) nthr_k = nk;
1272
1273 // Sacrifice one thread and try again if parallelism is too small in
1274 // n-dimension.
1275 if (nthr_k == 1 && nthrs > 1 && do_m_blocking_only) {
1276 nthrs--;
1277 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1278 if (nthrs % nk == 0) nthr_k = nk;
1279 }
1280
1281 // Allow up to 2 threads to be sacrificed for large k >> m, n.
1282 if (nthr_k < 4 && k >= m * 4 && k >= n * 4 && nthrs > 10 && is_bf16) {
1283 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1284 if (nthrs % nk <= 2) nthr_k = nk;
1285 }
1286 }
1287
1288 choose_k_blocking();
1289
1290 // Choose m/n blocking.
1291 auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um;
1292 min_mblk = do_m_blocking ? min_mblk : m;
1293 min_mblk = do_m_blocking_only ? arg->um : min_mblk;
1294 auto min_nblk = do_n_blocking ? NBLK / 2 : n;
1295
1296 std::tie(nthr_m, nthr_n) = partition_2d_minblk(m, n, MBLK, NBLK, min_mblk,
1297 min_nblk, arg->um, arg->un, nthrs / nthr_k,
1298 do_m_blocking && do_n_blocking && do_k_blocking);
1299
1300 auto nthr_m_init = nthr_m, nthr_n_init = nthr_n;
1301
1302 choose_m_blocking();
1303 choose_n_blocking();
1304
1305 if (is_int8 && do_m_blocking && do_n_blocking) {
1306 // If we lost a thread in one dimension because we padded the blocking
1307 // size, try to rebalance the other dimensions.
1308 if ((nthr_n != nthr_n_init)
1309 && ((nthr_m + 1) * nthr_n * nthr_k <= nthrs)) {
1310 nthr_m++;
1311 choose_m_blocking();
1312 }
1313
1314 if ((nthr_m != nthr_m_init)
1315 && (nthr_m * (nthr_n + 1) * nthr_k <= nthrs)) {
1316 nthr_n++;
1317 choose_n_blocking();
1318 }
1319 }
1320}
1321
1322template <typename a_type, typename b_type, typename c_type>
1323static inline int set_thread_opts(int nthrs, int nthrs_spawn,
1324 gemm_threading_t &thread_info,
1325 const gemm_info_t<a_type, b_type, c_type> *arg) {
1326
1327 thread_info.block_m = thread_info.block_n = thread_info.block_k = -1;
1328 thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1;
1329
1330 constexpr bool is_int8 = utils::one_of(
1331 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1332 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1333
1334 if (nocopy_checker(nthrs, arg)) {
1335 thread_info.copy = copy_type::no_copy;
1336 thread_info.partition = partition_type::mnk_3d;
1337 int nthrs_m = 0;
1338 int nthrs_n = 0;
1339 int nthrs_k = 0;
1340 dim_t BM = 0;
1341 dim_t BN = 0;
1342 dim_t BK = 0;
1343 auto m = arg->m, n = arg->n, k = arg->k;
1344
1345 if (mayiuse(avx512_core)) {
1346 cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs,
1347 &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1348 } else {
1349 cpu::gemm_utils::calc_nthr_nocopy_avx(m, n, k, nthrs, &nthrs_m,
1350 &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1351 }
1352
1353 // Block information is being ignored. We will create partitioning
1354 // later.
1355 thread_info.nthrs_m = nthrs_m;
1356 thread_info.nthrs_n = nthrs_n;
1357 thread_info.nthrs_k = nthrs_k;
1358 } else {
1359 if (arg->packing != pack_type::none && (is_int8 || is_bf16))
1360 set_thread_opts_pack(nthrs, thread_info, arg);
1361 else
1362 set_thread_opts_nopack(nthrs, nthrs_spawn, thread_info, arg);
1363 }
1364
1365 return thread_info.nthrs_m * thread_info.nthrs_n * thread_info.nthrs_k;
1366}
1367
1368template <typename a_type, typename b_type, typename c_type>
1369static inline std::tuple<const a_type *, const b_type *, c_type *,
1370 const c_type *>
1371decompose_matrices(const gemm_slice_t &slice,
1372 const gemm_info_t<a_type, b_type, c_type> *arg) {
1373
1374 dim_t stride_am = (arg->transa == no_trans) ? 1 : arg->lda;
1375 dim_t stride_ak = (arg->transa != no_trans) ? 1 : arg->lda;
1376 dim_t stride_bn = (arg->transb != no_trans) ? 1 : arg->ldb;
1377 dim_t stride_bk = (arg->transb == no_trans) ? 1 : arg->ldb;
1378
1379 auto a = arg->a;
1380 auto b = arg->b;
1381 auto c = arg->c;
1382 if (a) a += slice.off_m * stride_am + slice.off_k * stride_ak;
1383 if (b) b += slice.off_n * stride_bn + slice.off_k * stride_bk;
1384 if (c) c += slice.off_m + slice.off_n * arg->ldc;
1385
1386 dim_t co_stride;
1387 switch (arg->offsetc) {
1388 case offset_type::row: co_stride = slice.off_n; break;
1389 case offset_type::column: co_stride = slice.off_m; break;
1390 default: co_stride = 0; break;
1391 }
1392 auto co = arg->co;
1393 if (co) co += co_stride;
1394
1395 return std::make_tuple(a, b, c, co);
1396}
1397
1398template <typename a_type, typename b_type, typename c_type>
1399static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs,
1400 const dim_t m, const dim_t n, const dim_t k, const a_type *a,
1401 const b_type *b, float beta, c_type *c, dim_t ldc, offset_type offsetc,
1402 const c_type *co, const gemm_info_t<a_type, b_type, c_type> *arg,
1403 char **p_shared_mem) {
1404
1405 if (arg->packing != pack_type::none)
1406 return gemm_packing_driver(ithr, m, n, k, a, b, arg);
1407
1408 const dim_t lda = arg->lda;
1409 const dim_t ldb = arg->ldb;
1410 const dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
1411 const dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
1412 const dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
1413
1414 float alpha = arg->alpha;
1415
1416 constexpr bool is_int8 = utils::one_of(
1417 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1418 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1419 bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx);
1420 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx);
1421 bool is_amx = is_int8_amx || is_bf16_amx;
1422
1423 const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
1424
1425 // Scaling C matrix.
1426 if (!is_int8 && beta != 1.0f && beta != 0.0f) {
1427 scale_matrix(m, n, beta, c, ldc);
1428 beta = 1.0f;
1429 }
1430
1431 // Padding along M, K dimensions.
1432 dim_t m_padd = get_m_padd_parallel_a(ithr, m, arg, nthrs);
1433 dim_t k_padd = get_k_padd(ithr, k, arg);
1434
1435 size_t a_buf_nelems = m_padd * k_padd;
1436
1437 // A buffer needs more space due to zero-padding.
1438 if (is_amx)
1439 a_buf_nelems = utils::rnd_up(m_padd, arg->um)
1440 * utils::rnd_up(k_padd, arg->uk);
1441
1442 // Allocate shared memory for A and its row sum buffers in master thread.
1443 char *mem = nullptr;
1444 a_type *bufferA = nullptr;
1445 c_type *a_row_sum = nullptr;
1446
1447 if (!a_packed) {
1448 if (ithr == 0) { // If thread master
1449 size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K);
1450
1451 if (is_int8) {
1452 size_t a_row_sum_nelems = m_padd;
1453 mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K;
1454 }
1455
1456 *p_shared_mem = (char *)malloc(mem_size, 128);
1457 }
1458
1459 dnnl_thr_barrier();
1460
1461 mem = *p_shared_mem;
1462 bufferA = (a_type *)align(mem, PAGE_4K);
1463
1464 if (is_int8)
1465 a_row_sum = (c_type *)align(bufferA + a_buf_nelems, PAGE_4K);
1466
1467 if (!mem) return dnnl_out_of_memory;
1468 }
1469
1470 dnnl_status_t result = dnnl_success; // Return status
1471
1472 dim_t sizeK = 0;
1473 dim_t blk_k = 0;
1474 for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
1475 sizeK = k - Bk;
1476 if (sizeK > k_padd) sizeK = k_padd;
1477
1478 // Scale C blocks by beta only for the first term of partial sum.
1479 auto beta_eff = (Bk == 0) ? beta : 1.0f;
1480
1481 // Apply C offset for the last k-block of the partial sum.
1482 auto offsetc_eff = offset_type::none;
1483 if (Bk + sizeK == k) offsetc_eff = offsetc;
1484
1485 dim_t sizeM = 0;
1486 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
1487 sizeM = m - Bm;
1488 if (sizeM > m_padd) sizeM = m_padd;
1489
1490 if ((ithr < nthrs) && !a_packed) {
1491 dim_t band = (sizeM + nthrs - 1) / nthrs;
1492 band = utils::rnd_up(band, arg->um);
1493
1494 dim_t offset = band * ithr;
1495
1496 // If offset is too large don't use that thread for copying.
1497 if (offset >= sizeM) {
1498 offset = 0;
1499 band = 0;
1500 }
1501
1502 // Handle the tail of the copy.
1503 if (offset + band > sizeM) { band = sizeM - offset; }
1504
1505 if (band > 0) {
1506 const a_type *a_block
1507 = a + (Bm + offset) * strideAm + Bk * strideAn;
1508
1509 dim_t buf_shift = 0;
1510 if (is_amx)
1511 buf_shift = offset * utils::rnd_up(sizeK, arg->uk);
1512 else
1513 buf_shift = offset * sizeK;
1514
1515 /* Row sum argument is ignored for non-integer kernels and
1516 * scaling factor is ignored by 8-bit and 16-bit copy
1517 * kernels.
1518 */
1519 c_type *a_row_sum_eff
1520 = a_row_sum ? a_row_sum + offset : nullptr;
1521 arg->copyA(&sizeK, &band, a_block, &lda, &alpha,
1522 bufferA + buf_shift, nullptr, nullptr,
1523 a_row_sum_eff);
1524 }
1525 }
1526 if (!a_packed)
1527 dnnl_thr_barrier(); // Wait for finishing parallel copy.
1528
1529 const b_type *b_block = b + Bk * strideBm;
1530 c_type *c_block = c + Bm;
1531
1532 dim_t co_stride = 0;
1533 if (offsetc_eff == offset_type::fixed) {
1534 co_stride = 0;
1535 } else if (offsetc_eff == offset_type::row) {
1536 co_stride = 0;
1537 } else if (offsetc_eff == offset_type::column) {
1538 co_stride = Bm;
1539 }
1540
1541 auto bufferA_eff
1542 = a_packed ? a_packed->matrix<a_type>(0, Bm, Bk) : bufferA;
1543 auto a_row_sum_eff = a_packed
1544 ? a_packed->row_sums<c_type>(0, Bm, blk_k)
1545 : a_row_sum;
1546
1547 auto this_result = kernel_driver_parallel_acopiedbcopy(ithr, sizeM,
1548 n, sizeK, blk_k, Bk, bufferA_eff, b_block, beta_eff,
1549 c_block, offsetc_eff, co + co_stride, a_row_sum_eff, arg);
1550
1551 if (this_result != dnnl_success) result = this_result;
1552
1553 if (!a_packed)
1554 dnnl_thr_barrier(); // Wait for kernel computations to finish.
1555 }
1556 }
1557
1558 // Free memory allocated in master thread
1559 if (ithr == 0 && !a_packed) free(mem);
1560
1561 return result;
1562}
1563
1564template <typename T>
1565static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) {
1566
1567 const double omp_overhead_small_core = 3.0e+3;
1568 const double omp_intercept_big_core = 4.0e+3;
1569 const double omp_slope_big_core = 5.0e+2;
1570
1571 auto veclen = get_vector_length<T>();
1572 const double fp_per_cycle = 2.0 * 2.0 * veclen;
1573
1574 const bool is_f32 = data_traits<T>::data_type == data_type::f32;
1575
1576 const bool is_avx512 = mayiuse(avx512_core);
1577 const bool is_avx = mayiuse(avx);
1578 const bool is_only_avx2 = mayiuse(avx2) && !is_avx512;
1579
1580 // Some sgemm cases still benefit from using all threads.
1581 const bool use_all_threads = is_f32 && n > 50
1582 && ((is_avx && m <= 3) || (is_avx512 && m <= 10));
1583 if (use_all_threads) return;
1584
1585 if (is_only_avx2)
1586 if (m > 10 * n && n < *nthrs)
1587 if (m / *nthrs < veclen * 3)
1588 *nthrs = nstl::max(m / veclen / 3, dim_t(1));
1589
1590 double gemm_cycles = m * n * k / fp_per_cycle;
1591 gemm_cycles *= is_f32 ? 2.0 : 8.0;
1592
1593#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
1594 if (is_f32) {
1595 static auto l2_cache_per_thread = platform::get_per_core_cache_size(2);
1596 static int n_cores_per_socket
1597 = static_cast<int>(platform::get_num_cores());
1598 auto l2_cache_socket = l2_cache_per_thread * n_cores_per_socket;
1599 auto problem_memory_footprint = (m * n + m * k + n * k) * sizeof(float);
1600
1601 if (is_only_avx2) {
            // Splitting the job into bigger pieces seems to be beneficial
            // here; use the per-core L2 cache size as the cutoff.
1604 int use_n_threads = utils::div_up(
1605 problem_memory_footprint, l2_cache_per_thread);
1606 *nthrs = nstl::min(*nthrs, use_n_threads);
1607 return;
1608 }
1609 if (l2_cache_socket > problem_memory_footprint) {
1610 *nthrs = nstl::min(*nthrs, n_cores_per_socket);
1611 return;
1612 }
1613 }
1614#endif
1615
1616 int i = *nthrs;
1617
1618 // Use a different model for omp overheads if nthrs is <= 4
1619 if (*nthrs <= 4 && omp_overhead_small_core > 0) {
1620 double omp_cycles = omp_overhead_small_core;
1621 if (gemm_cycles < omp_cycles) {
1622 *nthrs = 1;
1623 return;
1624 } else {
1625 while (i > 1) {
1626 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
1627 --i;
1628 }
1629 }
1630 } else {
1631 if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
1632 *nthrs = 1;
1633 return;
1634 }
1635
        // Adaptive decrement to march faster.
1637 while (i > 1) {
1638 double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
1639 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
1640
1641 if (i < 10)
1642 i -= 2;
1643 else if (i < 30)
1644 i -= 4;
1645 else
1646 i -= 8;
1647 }
1648 }
1649
1650 if (i < 1) i = 1;
1651
1652 *nthrs = i;
1653}
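// Worked example using the constants above: for f32 on avx512_core,
// veclen = 16 so fp_per_cycle = 2.0 * 2.0 * 16 = 64. A 128x128x128 sgemm is
// estimated at 128^3 / 64 * 2.0 = 65536 cycles, far above the small-core OMP
// overhead of 3.0e+3, so the thread count is kept; a 32x32x32 problem
// estimates at 1024 cycles and collapses to a single thread when nthrs <= 4.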
1654
1655template <typename a_type, typename b_type, typename c_type>
1656static dnnl_status_t call_no_copy_sgemm(
1657 int nthrs, gemm_info_t<a_type, b_type, c_type> *arg) {
1658
1659 if (arg->packing == pack_type::none) {
1660 auto transa_char = (arg->transa != do_trans) ? "N" : "T";
1661 auto transb_char = (arg->transb != do_trans) ? "N" : "T";
1662
1663 if (mayiuse(avx512_core))
1664 return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char,
1665 &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a,
1666 &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta,
1667 (float *)arg->c, &arg->ldc, (float *)arg->co);
1668 else
1669 return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m,
1670 &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda,
1671 (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c,
1672 &arg->ldc, (float *)arg->co);
1673 } else
1674 return pack_no_copy(arg);
1675}
1676
1677template <typename a_type, typename b_type, typename c_type>
1678static dnnl_status_t gemm_threading_driver(
1679 gemm_info_t<a_type, b_type, c_type> *arg) {
1680
1681 auto packing = (arg->packing != pack_type::none);
1682 auto is_a_packed = (arg->transa == packed);
1683 auto is_b_packed = (arg->transb == packed);
1684 constexpr bool is_int8 = utils::one_of(
1685 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1686 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1687
1688 if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success;
1689
1690 if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg))
1691 return dnnl_success;
1692
1693 if (!is_a_packed && !is_b_packed
1694 && jump_to_gemm_smalln_tn(arg) == dnnl_success)
1695 return dnnl_success;
1696
1697 if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success)
1698 return dnnl_success;
1699
1700 if (is_a_packed && arg->bo != 0)
1701 if (!arg->a_packed->has_row_sums()) return dnnl_invalid_arguments;
1702
1703 if (is_b_packed && arg->ao != 0)
1704 if (!arg->b_packed->has_col_sums()) return dnnl_invalid_arguments;
1705
1706 auto nthr_max = dnnl_get_current_num_threads();
1707 int nthr_goal = nthr_max;
1708
1709 adjust_thread_count<c_type>(arg->m, arg->n, arg->k, &nthr_goal);
1710
1711 const gemm_threading_t *force_threading = nullptr;
1712 gemm_threading_t force_k_decomp;
1713
1714 // Initialize per-thread data.
1715 // Note: to support k blocking with non-packed GEMM, threading must be
1716 // chosen now and force_threading set.
1717 if (!packing) {
1718 // Override choice of thread count if data is pre-packed for a particular
1719 // number of threads.
1720 if (is_a_packed && is_b_packed)
1721 if (arg->a_packed->threading() != arg->b_packed->threading())
1722 return dnnl_invalid_arguments;
1723 if (is_a_packed)
1724 force_threading = &arg->a_packed->threading();
1725 else if (is_b_packed)
1726 force_threading = &arg->b_packed->threading();
1727 else if (arg->m <= 768 && arg->n <= 768 && arg->k >= 2048 && is_bf16) {
1728 // Try k-partitioning.
1729 set_thread_opts_pack(nthr_goal, force_k_decomp, arg);
1730
1731 // Decide partition type later if no partitions in k-dimension.
1732 if (force_k_decomp.nthrs_k > 1) force_threading = &force_k_decomp;
1733 } else if (arg->n <= 128 && arg->k >= 3072 && is_int8) {
1734 // Use k-partitioning if necessary.
1735 // Use 3D decomposition from pack api without n-partitioning.
1736 set_thread_opts_pack(
1737 nthr_goal, force_k_decomp, arg, true, true, false);
1738
1739 // Decide partition type later if no partitions in k-dimension.
1740 if (force_k_decomp.nthrs_k > 1 && force_k_decomp.nthrs_m > 1)
1741 force_threading = &force_k_decomp;
1742 }

        if (force_threading) {
            nthr_goal = force_threading->nthrs();
            arg->update_blocking(*force_threading);
        }
    } else {
        // Prepare packed data layout.
        gemm_pack_storage_t *pack_dst = arg->pack_dst;
        bool do_a = (arg->packing == pack_type::pack_a);

        pack_dst->which() = do_a ? matrix_id::a : matrix_id::b;
        pack_dst->setup(nthr_goal, do_a && is_int8, !do_a && is_int8);
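        // For int8, the packed copy of A additionally stores row sums and the
        // packed copy of B stores column sums (the second and third arguments
        // above); these are consumed later for offset (zero-point)
        // compensation.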

        auto &thread_info = pack_dst->threading();
        force_threading = &thread_info;

        nthr_goal = set_thread_opts(nthr_goal, nthr_max, thread_info, arg);
        arg->update_blocking(thread_info);

        if (thread_info.copy != copy_type::no_copy) {
            for (int ithr = 0; ithr < nthr_goal; ithr++) {
                if (!pack_dst->is_first_thread_in_slice(ithr)) continue;

                auto slice = thread_info.get_thread_slice(
                        ithr, arg->m, arg->n, arg->k);

                auto m = slice.m, n = slice.n, k = slice.k;

                auto m_padd = (thread_info.copy == copy_type::shared_a)
                        ? get_m_padd_parallel_a(
                                ithr, m, arg, thread_info.nthrs())
                        : get_m_padd(ithr, m, arg);
                auto n_padd = get_n_padd(ithr, n, k, arg);
                auto k_padd = get_k_padd(ithr, k, arg);

                do_a ? pack_dst->set_blocking(ithr, m, k, m_padd, k_padd)
                     : pack_dst->set_blocking(ithr, k, n, k_padd, n_padd);
            }
        } else {
            auto ld = do_a ? gemm_utils::get_ld_padd<a_type>(arg->m)
                           : gemm_utils::get_ld_padd<b_type>(arg->k);

            pack_dst->set_nocopy(0, no_trans, ld, do_a ? arg->k : arg->n);
        }

        do_a ? pack_dst->finalize<a_type, c_type>()
             : pack_dst->finalize<b_type, c_type>();

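        // In measure-only mode the caller just needs the size of the packed
        // buffer, which finalize() above has computed, so no data is actually
        // packed.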
        if (arg->measure_only) return dnnl_success;
    }

    if (nocopy_checker(nthr_goal, arg))
        return call_no_copy_sgemm(nthr_goal, arg);

    if (nthr_goal == 1)
        return gemm_kernel_driver(0, arg->m, arg->n, arg->k, arg->a, arg->b,
                arg->beta, arg->c, arg->ldc, arg->offsetc, arg->co, arg);

    bool k_blocking = force_threading && (force_threading->nthrs_k > 1);
    bool k_summing = k_blocking && !packing;
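    // k_blocking: the chosen decomposition splits the k dimension across
    // threads. k_summing: for non-packed GEMM this additionally requires
    // per-thread partial C buffers that are reduced into the user's C after
    // the compute (packing runs never touch C, so no summation is needed).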

    auto *thread_arg = (gemm_per_thread_t<c_type> *)malloc(
            sizeof(gemm_per_thread_t<c_type>) * nthr_max, PAGE_4K);

    if (!thread_arg) return dnnl_out_of_memory;

    dim_t max_mt = 0, max_nt = 0;
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        thread_arg[ithr].result = dnnl_success;
        thread_arg[ithr].compute_done = false;
        thread_arg[ithr].c_local = thread_arg[ithr].c_global = nullptr;
        thread_arg[ithr].ldc_global = arg->ldc;
        thread_arg[ithr].ldc_local = 0;

        if (force_threading) {
            thread_arg[ithr].slice = force_threading->get_thread_slice(
                    ithr, arg->m, arg->n, arg->k);
            thread_arg[ithr].nthr_k = force_threading->nthrs_k;
            thread_arg[ithr].thr_k_stride = force_threading->thr_k_stride();
            max_mt = nstl::max(max_mt, thread_arg[ithr].slice.m);
            max_nt = nstl::max(max_nt, thread_arg[ithr].slice.n);
        } else {
            thread_arg[ithr].slice = {0, 0, 0, 0, 0, 0, 0, 0, 0};
            thread_arg[ithr].nthr_k = 1;
            thread_arg[ithr].thr_k_stride = 0;
        }
    }

    // Create temporary C buffers for k blocking if needed.
    c_type *c_local_storage = nullptr;
    if (k_summing) {
        const dim_t BAD_LD_MULT = 256;
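        // If max_mt is a multiple of BAD_LD_MULT (e.g. 512), pad the leading
        // dimension of the per-thread C buffers to avoid cache aliasing
        // between their columns; otherwise max_mt is used as-is.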
        dim_t ldc_local = max_mt % BAD_LD_MULT
                ? max_mt
                : gemm_utils::get_ld_padd<c_type>(max_mt);
        dim_t c_local_stride = ldc_local * max_nt;
        c_local_storage = (c_type *)malloc(
                sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K);

        if (!c_local_storage) {
            free(thread_arg);
            return dnnl_out_of_memory;
        }

        for (int ithr = 0; ithr < nthr_goal; ithr++) {
            thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride;
            thread_arg[ithr].ldc_local = ldc_local;
        }
    }

    char *shared_mem = nullptr;

    // Always use the maximum number of threads to avoid OMP overhead that can
    // occur due to changing thread counts.
    int nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal;
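    // With a syncable runtime (e.g. OpenMP) all nthr_max threads are spawned
    // and any threads beyond the effective count simply idle; otherwise only
    // nthr_goal threads are requested.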

    parallel(nthr_spawn, [&](int ithr, int nthr) {
        int nthr_eff = force_threading ? nthr_goal : nstl::min(nthr_goal, nthr);

        if (nthr_eff == 1) {
            thread_arg[0].result = gemm_kernel_driver(0, arg->m, arg->n, arg->k,
                    arg->a, arg->b, arg->beta, arg->c, arg->ldc, arg->offsetc,
                    arg->co, arg);
        } else {
            gemm_threading_t thread_info;

            if (force_threading)
                thread_info = *force_threading;
            else {
                nthr_eff = set_thread_opts(nthr_eff, nthr, thread_info, arg);
                if (ithr < nthr_eff)
                    thread_arg[ithr].slice = thread_info.get_thread_slice(
                            ithr, arg->m, arg->n, arg->k);
            }

            for (; ithr < nthr_eff; ithr += nthr) {
                // Get submatrices and parameters for this thread's GEMM.
                const a_type *a = nullptr;
                const b_type *b = nullptr;
                c_type *c = nullptr;
                const c_type *co = nullptr;
                std::tie(a, b, c, co)
                        = decompose_matrices(thread_arg[ithr].slice, arg);

                auto m = thread_arg[ithr].slice.m;
                auto n = thread_arg[ithr].slice.n;
                auto k = thread_arg[ithr].slice.k;
                thread_arg[ithr].c_global = c;
                auto c_eff = c;
                auto ldc_eff = arg->ldc;
                auto beta_eff = arg->beta;
                auto offsetc_eff = arg->offsetc;

                // For all but the first k block: substitute the local C matrix
                // and disable post-ops.
                if (k_summing && thread_arg[ithr].slice.ithr_k > 0) {
                    c_eff = thread_arg[ithr].c_local;
                    ldc_eff = thread_arg[ithr].ldc_local;
                    beta_eff = 0;
                    offsetc_eff = offset_type::none;
                }
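                // The first k block (ithr_k == 0) writes directly to the
                // global C with the user's beta and C offset; the remaining
                // k blocks accumulate into their private buffers from zero
                // and are added into C later by sum_k_blocks().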

                // Dispatch appropriate GEMM driver.
                switch (thread_info.copy) {
                    case copy_type::shared_a:
                        thread_arg[ithr].result = parallel_a_copy(ithr,
                                nthr_eff, m, n, k, a, b, beta_eff, c_eff,
                                ldc_eff, offsetc_eff, co, arg, &shared_mem);
                        break;

                    default:
                    case copy_type::nonshared:
                        thread_arg[ithr].result = gemm_kernel_driver(ithr, m, n,
                                k, a, b, beta_eff, c_eff, ldc_eff, offsetc_eff,
                                co, arg);
                        break;

                    case copy_type::no_copy:
                        // This route is taken only if we realize we need the
                        // no-copy path after launching the parallel section,
                        // because fewer threads were spawned than expected.
                        assert(data_traits<a_type>::data_type
                                == data_type::f32);
                        assert(arg->packing == pack_type::none);

                        if (mayiuse(avx512_core)) {
                            avx512_common_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr);
                        } else {
                            avx_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr);
                        }
                        thread_arg[ithr].result = dnnl_success;
                        break;
                }

                // Sum thread results along the k dimension, parallelized in
                // the n dimension. To avoid deadlocks, results are summed
                // later if not all threads are running concurrently. We can
                // only detect if this is safe when using OpenMP.
#if DNNL_THR_SYNC == 1
                if (k_summing && (nthr >= nthr_eff)) {
                    thread_arg[ithr].compute_done = true;
                    sum_k_blocks(ithr, thread_arg, true);
                }
#endif
            }
        }
    });

    // Report the first per-thread failure, if any.
    dnnl_status_t result = dnnl_success;
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        if (thread_arg[ithr].result != dnnl_success) {
            result = static_cast<dnnl_status_t>(thread_arg[ithr].result);
            break;
        }
    }

    // Sum thread results along the k dimension if this wasn't done earlier.
    if (k_summing && !thread_arg[0].compute_done) {
        parallel(nthr_goal, [&](int ithr, int nthr) {
            for (; ithr < nthr_goal; ithr += nthr)
                sum_k_blocks(ithr, thread_arg, false);
        });
    }

    if (c_local_storage) dnnl::impl::free(c_local_storage);
    dnnl::impl::free(thread_arg);

    return result;
}

template <typename a_type, typename b_type, typename c_type>
dnnl_status_t gemm_driver(const char *transA, const char *transB,
        const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k,
        const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa,
        const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta,
        c_type *c, const dim_t *ldc, const c_type *oc, const bool force_nocopy,
        pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) {

    constexpr bool is_int8 = utils::one_of(
            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
    MAYBE_UNUSED(is_int8);

    // gemm_driver supports bfloat16 gemm for Intel AVX512 and
    // Intel AVX512 BF16.
    assert(IMPLICATION(data_traits<a_type>::data_type == data_type::bf16,
            mayiuse(avx512_core) && !force_nocopy));

    // gemm_driver supports 8-bit integer gemm for Intel AVX512, Intel AVX2,
    // Intel AVX, Intel SSE4.1, and Intel DL Boost.
    assert(IMPLICATION(is_int8, mayiuse(sse41)));

    // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX,
    // and Intel SSE4.1.
    assert(IMPLICATION(
            data_traits<a_type>::data_type == data_type::f32, mayiuse(sse41)));

    // 8-bit integer gemm doesn't support the nocopy kernels.
    assert(IMPLICATION(is_int8, !force_nocopy));

    // gemm_driver can only dispatch nocopy for avx and above.
    assert(IMPLICATION(force_nocopy, mayiuse(avx)));

    gemm_info_t<a_type, b_type, c_type> args(transA, transB, offsetC, m, n, k,
            alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy,
            packing, pack_dst, measure_only);

    // Check if copy algorithm kernels were generated on supported ISAs.
    if (!args.hasKernels()) return dnnl_unimplemented;

    return gemm_threading_driver(&args);
}
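
// Usage sketch (illustrative only; the local names below are hypothetical and
// this entry point is normally reached through the public sgemm/gemm_s8x8s32
// wrappers rather than called directly). For plain column-major f32 matrices,
// C = alpha * A * B + beta * C could be requested as:
//
//     dim_t m = 128, n = 64, k = 256, lda = m, ldb = k, ldc = m;
//     float alpha = 1.0f, beta = 0.0f;
//     gemm_driver<float, float, float>("N", "N", nullptr, &m, &n, &k, &alpha,
//             A, &lda, nullptr, B, &ldb, nullptr, &beta, C, &ldc, nullptr,
//             /*force_nocopy=*/false, pack_type::none, nullptr, false);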

template // Instantiate gemm_bf16bf16f32
        dnnl_status_t
        gemm_driver<bfloat16_t, bfloat16_t, float>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const bfloat16_t *a, const dim_t *lda, const bfloat16_t *oa,
                const bfloat16_t *b, const dim_t *ldb, const bfloat16_t *ob,
                const float *beta, float *c, const dim_t *ldc, const float *oc,
                const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8s8s32
        dnnl_status_t
        gemm_driver<int8_t, int8_t, int32_t>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const int8_t *a, const dim_t *lda, const int8_t *oa,
                const int8_t *b, const dim_t *ldb, const int8_t *ob,
                const float *beta, int32_t *c, const dim_t *ldc,
                const int32_t *oc, const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8u8s32
        dnnl_status_t
        gemm_driver<int8_t, uint8_t, int32_t>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const int8_t *a, const dim_t *lda, const int8_t *oa,
                const uint8_t *b, const dim_t *ldb, const uint8_t *ob,
                const float *beta, int32_t *c, const dim_t *ldc,
                const int32_t *oc, const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate sgemm
        dnnl_status_t
        gemm_driver<float, float, float>(const char *transA, const char *transB,
                const char *offsetC, const dim_t *m, const dim_t *n,
                const dim_t *k, const float *alpha, const float *a,
                const dim_t *lda, const float *oa, const float *b,
                const dim_t *ldb, const float *ob, const float *beta, float *c,
                const dim_t *ldc, const float *oc, const bool force_nocopy,
                pack_type packing, gemm_pack_storage_t *pack_dst,
                bool measure_only);

#undef MAX_STACK_SZ
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl