1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <atomic>
#include <cassert>
#include <cstdint>
#include <memory>
#include <mutex>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/dnnl_traits.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/gemm/gemm_info.hpp"

#include "cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp"
#include "cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp"

#include "cpu/x64/gemm/bf16/common_s16.hpp"
#include "cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp"
#include "cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp"

#include "cpu/x64/gemm/f32/common_f32.hpp"
#include "cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp"
#include "cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp"
#include "cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp"
#include "cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp"

#include "cpu/x64/gemm/s8x8s32/common_u8.hpp"
#include "cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.hpp"
#include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp"
#include "cpu/x64/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8x8s32_kern.hpp"
48
49namespace dnnl {
50namespace impl {
51namespace cpu {
52namespace x64 {
53
54static inline int decode_trans(char trans) {
55 switch (trans) {
56 case 'T':
57 case 't': return do_trans;
58 case 'P':
59 case 'p': return packed;
60 default: return no_trans;
61 }
62}
63
64namespace {
65template <typename b_t> // XXX for float and bfloat
66void prepare_bo(int32_t &bo_gemm_info, const b_t *bo_orig) {
67 UNUSED(bo_orig);
68 bo_gemm_info = 0;
69}
70template <>
71void prepare_bo(int32_t &bo_gemm_info, const uint8_t *bo_orig) {
72 bo_gemm_info = bo_orig ? *bo_orig : 0;
73}
74template <>
75void prepare_bo(int32_t &bo_gemm_info, const int8_t *bo_orig) {
76 int bo_s32 = bo_orig ? *bo_orig : 0;
77 if (!mayiuse(avx512_core_amx)) bo_s32 += 128;
78 bo_gemm_info = bo_s32;
79}
80
81} // namespace
82
83template <typename a_t, typename b_t, typename c_t>
84gemm_info_t<a_t, b_t, c_t>::gemm_info_t(const char *transA, const char *transB,
85 const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k,
86 const float *alpha, const a_t *a, const dim_t *lda, const a_t *oa,
87 const b_t *b, const dim_t *ldb, const b_t *ob, const float *beta,
88 c_t *c, const dim_t *ldc, const c_t *oc, bool force_nocopy,
89 pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) {
90
91 this->transa = decode_trans(*transA);
92 this->transb = decode_trans(*transB);
93
94 this->m = *m;
95 this->n = *n;
96 this->k = *k;
97
98 this->a = a;
99 this->b = b;
100 this->c = c;
101
102 this->lda = lda ? *lda : 0;
103 this->ldb = ldb ? *ldb : 0;
104 this->ldc = ldc ? *ldc : 0;
105
106 this->ao = 0;
107 this->bo = 0;
108 this->co = nullptr;
109
110 this->alpha = alpha ? *alpha : 1.0f;
111 this->beta = beta ? *beta : 1.0f;
112
113 this->offsetc = offset_type::none;
114
115 this->packing = packing;
116 this->pack_dst = pack_dst;
117 this->measure_only
118 = measure_only && pack_dst && (packing != pack_type::none);
119
120 if (this->transa == packed) {
121 dim_t cols;
122
123 this->a_packed.reset(new gemm_pack_storage_t(a));
124 if (this->a_packed->get_nocopy(this->transa, this->lda, cols)) {
125 this->a = this->a_packed->template matrix<a_t>();
126 this->a_packed = nullptr;
127 }
128 }
129 if (this->transb == packed) {
130 dim_t rows;
131
132 this->b_packed.reset(new gemm_pack_storage_t(b));
133 if (this->b_packed->get_nocopy(this->transb, this->ldb, rows)) {
134 this->b = this->b_packed->template matrix<b_t>();
135 this->b_packed = nullptr;
136 }
137 }
138
139 constexpr bool is_int8 = utils::one_of(
140 data_traits<a_t>::data_type, data_type::s8, data_type::u8);
141 if (is_int8) this->ao = oa ? *oa : a_t(0);
142 prepare_bo<b_t>(this->bo, ob);
143
144 if (offsetC != nullptr) {
145 char offsetc = *offsetC;
146 if (offsetc == 'F' || offsetc == 'f') {
147 this->offsetc = offset_type::fixed;
148 } else if (offsetc == 'R' || offsetc == 'r') {
149 this->offsetc = offset_type::row;
150 } else { // offsetc == 'C' || offsetc == 'c'
151 this->offsetc = offset_type::column;
152 }
153 this->co = oc;
154 }
155
156 bool is_sgemm = data_traits<a_t>::data_type == data_type::f32;
157 bool is_gemv = this->m == 1 || this->n == 1;
158
159 // Copy-based sgemm doesn't support force-nocopy for ISAs older
160 // than Intel AVX.
161 this->force_nocopy = is_sgemm && force_nocopy && mayiuse(avx);
162
163 if (!this->force_nocopy || is_gemv) { this->jit_init(); }
164}
165
166static std::mutex kern_mutex;
167
168// copyA[trans][sum]
169template <typename a_t, typename b_t, typename c_t>
170typename gemm_info_t<a_t, b_t, c_t>::copy_a_fptr_t
171 gemm_info_t<a_t, b_t, c_t>::copy_a_kern[2][2]
172 = {{nullptr}};
173
174// copyB[trans][sum]
175template <typename a_t, typename b_t, typename c_t>
176typename gemm_info_t<a_t, b_t, c_t>::copy_b_fptr_t
177 gemm_info_t<a_t, b_t, c_t>::copy_b_kern[2][2]
178 = {{nullptr}};
179
180// kern[beta0][alpha1][col_off][row_off]
181template <typename a_t, typename b_t, typename c_t>
182typename gemm_info_t<a_t, b_t, c_t>::gemm_fptr_t
183 gemm_info_t<a_t, b_t, c_t>::kern[2][2][2][2]
184 = {{{{nullptr}}}};
185
186// gemv_kern[trans]
187template <typename a_t, typename b_t, typename c_t>
188typename gemm_info_t<a_t, b_t, c_t>::gemv_fptr_t
189 gemm_info_t<a_t, b_t, c_t>::gemv_kern[2]
190 = {nullptr};
191
192template <typename a_t, typename b_t, typename c_t>
193typename gemm_info_t<a_t, b_t, c_t>::gemv_s8s8s32_fptr_t
194 gemm_info_t<a_t, b_t, c_t>::gemv_s8s8s32_kern
195 = nullptr;
196template <typename a_t, typename b_t, typename c_t>
197typename gemm_info_t<a_t, b_t, c_t>::gemv_s8u8s32_fptr_t
198 gemm_info_t<a_t, b_t, c_t>::gemv_s8u8s32_kern
199 = nullptr;
200template <typename a_t, typename b_t, typename c_t>
201typename gemm_info_t<a_t, b_t, c_t>::gemv_u8s8s32_fptr_t
202 gemm_info_t<a_t, b_t, c_t>::gemv_u8s8s32_kern
203 = nullptr;
204
205template <typename a_t, typename b_t, typename c_t>
206void gemm_info_t<a_t, b_t, c_t>::jit_init(void) {
207
208 bool use_bf16_ymm = false;
209 // TODO: Add dispatching for 1-fma SKUs with support to bf16
210 // instructions for AMX kernel.
211 {
212 constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16;
213 const bool max_isa_supports_bf16_ymm
214 = mayiuse(avx512_core_bf16_ymm) && !mayiuse(avx512_core_amx);
215
216 use_bf16_ymm = is_bf16 && max_isa_supports_bf16_ymm;
217 }
218
219 switch (data_traits<a_t>::data_type) {
220 case data_type::s8:
221 if (mayiuse(avx512_core_amx)) {
222 this->um = 32;
223 this->un = 32;
224 this->uk = 64;
225 this->bm = 9984;
226 this->bn = 384;
227 this->bk = 1536;
228
229 this->bk_traditional = 0;
230 this->blocking_small_k = 0;
231 this->bn_small_k = 0;
232 } else if (mayiuse(avx512_core)) {
233 this->um = 48;
234 this->un = 8;
235 this->uk = 1;
236 this->bm = 9984;
237 this->bn = 384;
238 this->bk = mayiuse(avx512_core_vnni) ? 1536 : 768;
239
240 this->bk_traditional = 384;
241 this->blocking_small_k = 48;
242 this->bn_small_k = 24;
243 } else if (mayiuse(avx2)) {
244 this->um = mayiuse(avx2_vnni) ? 24 : 16;
245 this->un = 4;
246 this->uk = 1;
247 this->bm = 9984;
248 this->bn = mayiuse(avx2_vnni) ? 192 : 384;
249 this->bk = mayiuse(avx2_vnni) ? 768 : 384;
250
251 this->bk_traditional = 256;
252 this->blocking_small_k = 48;
253 this->bn_small_k = 24;
254 } else if (mayiuse(avx)) {
255 this->um = 16;
256 this->un = 2;
257 this->uk = 1;
258 this->bm = 4096;
259 this->bn = 256;
260 this->bk = 256;
261
262 this->bk_traditional = 256;
263 this->blocking_small_k = 48;
264 this->bn_small_k = 24;
265 } else if (mayiuse(sse41)) {
266 this->um = 16;
267 this->un = 2;
268 this->uk = 1;
269 this->bm = 4096;
270 this->bn = 256;
271 this->bk = 256;
272
273 this->bk_traditional = 256;
274 this->blocking_small_k = 48;
275 this->bn_small_k = 24;
276 }
277 break;
278
279 case data_type::bf16:
280 if (mayiuse(avx512_core_amx)) {
281 this->um = 32;
282 this->un = 32;
283 this->uk = 32;
284 this->bm = 9984;
285 this->bn = 384;
286 this->bk = 768;
287
288 this->bk_traditional = 0;
289 this->blocking_small_k = 0;
290 this->bn_small_k = 0;
291 } else if (mayiuse(avx512_core)) {
292 this->um = use_bf16_ymm ? 24 : 48;
293 this->un = 8;
294 this->uk = 1;
295 this->bm = 9984;
296 this->bn = 384;
297 this->bk = use_bf16_ymm ? 384 : 768;
298
299 this->bk_traditional = 384;
300 this->blocking_small_k = 48;
301 this->bn_small_k = 24;
302 }
303 break;
304
305 case data_type::f32:
306 if (mayiuse(avx512_core)) {
307 this->um = 48;
308 this->un = 8;
309 this->uk = 1;
310 this->bm = 9984;
311 this->bn = 384;
312 this->bk = 384;
313
314 this->bk_traditional = 384;
315 this->blocking_small_k = 48;
316 this->bn_small_k = 24;
317 } else if (mayiuse(avx2)) {
318 this->um = 24;
319 this->un = 4;
320 this->uk = 1;
321 this->bm = 10000;
322 this->bn = 384;
323 this->bk = 192;
324
325 this->bk_traditional = 256;
326 this->blocking_small_k = 48;
327 this->bn_small_k = 24;
328 } else if (mayiuse(avx)) {
329 this->um = 16;
330 this->un = 4;
331 this->uk = 1;
332 this->bm = 4096;
333 this->bn = 96;
334 this->bk = 256;
335
336 this->bk_traditional = 256;
337 this->blocking_small_k = 48;
338 this->bn_small_k = 24;
339 } else if (mayiuse(sse41)) {
340 this->um = 8;
341 this->un = 4;
342 this->uk = 1;
343 this->bm = 4096;
344 this->bn = 96;
345 this->bk = 256;
346
347 this->bk_traditional = 256;
348 this->blocking_small_k = 48;
349 this->bn_small_k = 24;
350 }
351 break;
352 default: assert(!"unsupported data type!");
353 }
354
355 // Note: um is fixed for a given set of data types and ISA.
356 const int um = this->um;
357
358 static std::once_flag initialized;
359 static std::atomic<dnnl_status_t> st(dnnl_success);
360 std::call_once(initialized, [&, um] {
361 const bool b_is_s8 = data_traits<b_t>::data_type == data_type::s8;
362 constexpr bool is_int8 = utils::one_of(
363 data_traits<a_t>::data_type, data_type::s8, data_type::u8);
364 constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16;
365 bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx);
366 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx);
367 bool is_amx = is_int8_amx || is_bf16_amx;
368
369 static maybe_unique_ptr<jit_generator> copy_a[2][2] = {{nullptr}};
370 static maybe_unique_ptr<jit_generator> copy_b[2][2] = {{nullptr}};
371
372 switch (data_traits<a_t>::data_type) {
373 case data_type::s8:
374 if (mayiuse(amx_int8)) {
375 for (int isTrans : {no_trans, do_trans}) {
376 copy_a[isTrans][no_sum].reset(
377 new jit_avx512_core_amx_copy_kern(
378 true, !isTrans, sizeof(a_t)));
379
380 copy_b[isTrans][no_sum].reset(
381 new jit_avx512_core_amx_copy_kern(
382 false, isTrans, sizeof(b_t)));
383 }
384 } else if (mayiuse(avx512_core)) {
385 copy_a[no_trans][no_sum].reset(
386 new jit_avx512_core_u8_copy_an_kern());
387 copy_a[do_trans][no_sum].reset(
388 new jit_avx512_core_u8_copy_at_kern());
389
390 copy_b[no_trans][no_sum].reset(
391 new jit_avx512_core_u8_copy_bn_kern(b_is_s8));
392 copy_b[do_trans][no_sum].reset(
393 new jit_avx512_core_u8_copy_bt_kern(b_is_s8));
394
395 copy_a[no_trans][do_sum].reset(
396 new jit_avx512_core_u8_copy_sum_an_kern());
397 copy_a[do_trans][do_sum].reset(
398 new jit_avx512_core_u8_copy_sum_at_kern());
399
400 copy_b[no_trans][do_sum].reset(
401 new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8));
402 copy_b[do_trans][do_sum].reset(
403 new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8));
404 } else if (mayiuse(avx2_vnni)) {
405 copy_a[no_trans][no_sum].reset(
406 new jit_avx2_vnni_u8_copy_an_kern());
407 copy_a[do_trans][no_sum].reset(
408 new jit_avx2_vnni_u8_copy_at_kern());
409
410 copy_b[no_trans][no_sum].reset(
411 new jit_avx2_vnni_u8_copy_bn_kern());
412 copy_b[do_trans][no_sum].reset(
413 new jit_avx2_vnni_u8_copy_bt_kern());
414
415 copy_a[no_trans][do_sum].reset(
416 new jit_avx2_vnni_u8_copy_sum_an_kern());
417 copy_a[do_trans][do_sum].reset(
418 new jit_avx2_vnni_u8_copy_sum_at_kern());
419
420 copy_b[no_trans][do_sum].reset(
421 new jit_avx2_vnni_u8_copy_sum_bn_kern());
422 copy_b[do_trans][do_sum].reset(
423 new jit_avx2_vnni_u8_copy_sum_bt_kern());
424 } else if (mayiuse(avx2)) {
425 copy_a[no_trans][no_sum].reset(
426 new jit_avx2_u8_copy_an_kern());
427 copy_a[do_trans][no_sum].reset(
428 new jit_avx2_u8_copy_at_kern());
429
430 copy_b[no_trans][no_sum].reset(
431 new jit_avx2_u8_copy_bn_kern());
432 copy_b[do_trans][no_sum].reset(
433 new jit_avx2_u8_copy_bt_kern());
434
435 copy_a[no_trans][do_sum].reset(
436 new jit_avx2_u8_copy_sum_an_kern());
437 copy_a[do_trans][do_sum].reset(
438 new jit_avx2_u8_copy_sum_at_kern());
439
440 copy_b[no_trans][do_sum].reset(
441 new jit_avx2_u8_copy_sum_bn_kern());
442 copy_b[do_trans][do_sum].reset(
443 new jit_avx2_u8_copy_sum_bt_kern());
444 } else if (mayiuse(avx)) {
445 copy_a[no_trans][no_sum].reset(
446 new jit_avx_u8_copy_an_kern());
447 copy_a[do_trans][no_sum].reset(
448 new jit_avx_u8_copy_at_kern());
449
450 copy_b[no_trans][no_sum].reset(
451 new jit_avx_u8_copy_bn_kern());
452 copy_b[do_trans][no_sum].reset(
453 new jit_avx_u8_copy_bt_kern());
454
455 copy_a[no_trans][do_sum].reset(
456 new jit_avx_u8_copy_sum_an_kern());
457 copy_a[do_trans][do_sum].reset(
458 new jit_avx_u8_copy_sum_at_kern());
459
460 copy_b[no_trans][do_sum].reset(
461 new jit_avx_u8_copy_sum_bn_kern());
462 copy_b[do_trans][do_sum].reset(
463 new jit_avx_u8_copy_sum_bt_kern());
464 } else if (mayiuse(sse41)) {
465 copy_a[no_trans][no_sum].reset(
466 new jit_sse41_u8_copy_an_kern());
467 copy_a[do_trans][no_sum].reset(
468 new jit_sse41_u8_copy_at_kern());
469
470 copy_b[no_trans][no_sum].reset(
471 new jit_sse41_u8_copy_bn_kern());
472 copy_b[do_trans][no_sum].reset(
473 new jit_sse41_u8_copy_bt_kern());
474
475 copy_a[no_trans][do_sum].reset(
476 new jit_sse41_u8_copy_sum_an_kern());
477 copy_a[do_trans][do_sum].reset(
478 new jit_sse41_u8_copy_sum_at_kern());
479
480 copy_b[no_trans][do_sum].reset(
481 new jit_sse41_u8_copy_sum_bn_kern());
482 copy_b[do_trans][do_sum].reset(
483 new jit_sse41_u8_copy_sum_bt_kern());
484 }
485 break;
486
487 case data_type::bf16:
488 if (mayiuse(amx_bf16)) {
489 for (int isTrans : {no_trans, do_trans}) {
490 copy_a[isTrans][no_sum].reset(
491 new jit_avx512_core_amx_copy_kern(
492 true, !isTrans, sizeof(a_t)));
493
494 copy_b[isTrans][no_sum].reset(
495 new jit_avx512_core_amx_copy_kern(
496 false, isTrans, sizeof(b_t)));
497 }
498 } else if (mayiuse(avx512_core) && !use_bf16_ymm) {
499 copy_a[no_trans][no_sum].reset(
500 new jit_avx512_core_s16_48x8_copy_an_kern());
501 copy_a[do_trans][no_sum].reset(
502 new jit_avx512_core_s16_48x8_copy_at_kern());
503
504 copy_b[no_trans][no_sum].reset(
505 new jit_avx512_core_s16_48x8_copy_bn_kern());
506 copy_b[do_trans][no_sum].reset(
507 new jit_avx512_core_s16_48x8_copy_bt_kern());
508 } else if (mayiuse(avx512_core) && use_bf16_ymm) {
509 copy_a[no_trans][no_sum].reset(
510 new jit_avx512_core_s16_24x8_copy_an_kern());
511 copy_a[do_trans][no_sum].reset(
512 new jit_avx512_core_s16_24x8_copy_at_kern());
513
514 copy_b[no_trans][no_sum].reset(
515 new jit_avx512_core_s16_24x8_copy_bn_kern());
516 copy_b[do_trans][no_sum].reset(
517 new jit_avx512_core_s16_24x8_copy_bt_kern());
518 }
519 break;
520
521 case data_type::f32:
522 if (mayiuse(avx512_core)) {
523 copy_a[no_trans][no_sum].reset(
524 new jit_avx512_core_f32_copy_an_kern());
525 copy_a[do_trans][no_sum].reset(
526 new jit_avx512_core_f32_copy_at_kern());
527
528 copy_b[no_trans][no_sum].reset(
529 new jit_avx512_core_f32_copy_bn_kern());
530 copy_b[do_trans][no_sum].reset(
531 new jit_avx512_core_f32_copy_bt_kern());
532 } else if (mayiuse(avx2)) {
533 copy_a[no_trans][no_sum].reset(
534 new jit_avx2_f32_copy_an_kern());
535 copy_a[do_trans][no_sum].reset(
536 new jit_avx2_f32_copy_at_kern());
537
538 copy_b[no_trans][no_sum].reset(
539 new jit_avx2_f32_copy_bn_kern());
540 copy_b[do_trans][no_sum].reset(
541 new jit_avx2_f32_copy_bt_kern());
542 } else if (mayiuse(avx)) {
543 copy_a[no_trans][no_sum].reset(
544 new jit_avx_f32_copy_an_kern());
545 copy_a[do_trans][no_sum].reset(
546 new jit_avx_f32_copy_at_kern());
547
548 copy_b[no_trans][no_sum].reset(
549 new jit_avx_f32_copy_bn_kern());
550 copy_b[do_trans][no_sum].reset(
551 new jit_avx_f32_copy_bt_kern());
552 } else if (mayiuse(sse41)) {
553 copy_a[no_trans][no_sum].reset(
554 new jit_sse41_f32_copy_an_kern());
555 copy_a[do_trans][no_sum].reset(
556 new jit_sse41_f32_copy_at_kern());
557
558 copy_b[no_trans][no_sum].reset(
559 new jit_sse41_f32_copy_bn_kern());
560 copy_b[do_trans][no_sum].reset(
561 new jit_sse41_f32_copy_bt_kern());
562 }
563 break;
564
565 default: break;
566 }
567
568 constexpr bool is_a_s8 = data_traits<a_t>::data_type == data_type::s8;
569 constexpr bool is_b_s8 = data_traits<b_t>::data_type == data_type::s8;
570 constexpr bool is_c_s32 = data_traits<c_t>::data_type == data_type::s32;
571
572 static maybe_unique_ptr<jit_generator> kernel[2][2][2][2]
573 = {{{{nullptr}}}};
574 switch (data_traits<a_t>::data_type) {
575 case data_type::s8:
576 if (mayiuse(avx512_core_amx)) {
577 for (int isBeta0 : {no_beta0, do_beta0}) {
578 kernel[isBeta0][do_alpha1][no_sum][no_sum].reset(
579 new jit_avx512_core_amx_gemm_kern(
580 is_a_s8, is_b_s8, is_c_s32, isBeta0));
581 }
582 } else if (mayiuse(avx512_core)) {
583 for (int isBeta0 : {no_beta0, do_beta0})
584 for (int doColSum : {no_sum, do_sum})
585 for (int doRowSum : {no_sum, do_sum}) {
586 kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset(
587 new jit_avx512_core_gemm_s8u8s32_kern(
588 isBeta0, doColSum, doRowSum));
589 }
590 } else if (mayiuse(avx2)) {
591 for (int isBeta0 : {no_beta0, do_beta0})
592 for (int doColSum : {no_sum, do_sum})
593 for (int doRowSum : {no_sum, do_sum}) {
594 kernel[isBeta0][do_alpha1][doColSum][doRowSum]
595 .reset(new jit_avx2_gemm_s8u8s32_kern(
596 isBeta0, doColSum, doRowSum,
597 um));
598 }
599 } else if (mayiuse(avx)) {
600 kernel[no_beta0][do_alpha1][no_sum][no_sum].reset(
601 new jit_avx_kernel_gemm_s8u8s32_kern());
602 kernel[no_beta0][do_alpha1][do_sum][no_sum].reset(
603 new jit_avx_kernel_c_gemm_s8u8s32_kern());
604 kernel[no_beta0][do_alpha1][no_sum][do_sum].reset(
605 new jit_avx_kernel_r_gemm_s8u8s32_kern());
606 kernel[no_beta0][do_alpha1][do_sum][do_sum].reset(
607 new jit_avx_kernel_b_gemm_s8u8s32_kern());
608
609 kernel[do_beta0][do_alpha1][no_sum][no_sum].reset(
610 new jit_avx_kernel_b0_gemm_s8u8s32_kern());
611 kernel[do_beta0][do_alpha1][do_sum][no_sum].reset(
612 new jit_avx_kernel_b0_c_gemm_s8u8s32_kern());
613 kernel[do_beta0][do_alpha1][no_sum][do_sum].reset(
614 new jit_avx_kernel_b0_r_gemm_s8u8s32_kern());
615 kernel[do_beta0][do_alpha1][do_sum][do_sum].reset(
616 new jit_avx_kernel_b0_b_gemm_s8u8s32_kern());
617 } else if (mayiuse(sse41)) {
618 kernel[no_beta0][do_alpha1][no_sum][no_sum].reset(
619 new jit_sse41_kernel_gemm_s8u8s32_kern());
620 kernel[no_beta0][do_alpha1][do_sum][no_sum].reset(
621 new jit_sse41_kernel_c_gemm_s8u8s32_kern());
622 kernel[no_beta0][do_alpha1][no_sum][do_sum].reset(
623 new jit_sse41_kernel_r_gemm_s8u8s32_kern());
624 kernel[no_beta0][do_alpha1][do_sum][do_sum].reset(
625 new jit_sse41_kernel_b_gemm_s8u8s32_kern());
626
627 kernel[do_beta0][do_alpha1][no_sum][no_sum].reset(
628 new jit_sse41_kernel_b0_gemm_s8u8s32_kern());
629 kernel[do_beta0][do_alpha1][do_sum][no_sum].reset(
630 new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern());
631 kernel[do_beta0][do_alpha1][no_sum][do_sum].reset(
632 new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern());
633 kernel[do_beta0][do_alpha1][do_sum][do_sum].reset(
634 new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern());
635 }
636 break;
637
638 case data_type::bf16:
639 if (mayiuse(avx512_core_amx)) {
640 for (int isBeta0 : {no_beta0, do_beta0}) {
641 kernel[isBeta0][do_alpha1][no_sum][no_sum].reset(
642 new jit_avx512_core_amx_gemm_kern(
643 is_a_s8, is_b_s8, is_c_s32, isBeta0));
644 }
645 } else if (mayiuse(avx512_core)) {
646 for (int isBeta0 : {no_beta0, do_beta0})
647 for (int isAlpha1 : {no_alpha1, do_alpha1}) {
648 kernel[isBeta0][isAlpha1][no_sum][no_sum].reset(
649 new jit_avx512_core_gemm_bf16bf16f32_kern(
650 isBeta0, isAlpha1, !use_bf16_ymm));
651 }
652 }
653 break;
654
655 case data_type::f32:
656 if (mayiuse(avx2)) {
657 for (int isBeta0 : {no_beta0, do_beta0}) {
658 kernel[isBeta0][do_alpha1][no_sum][no_sum].reset(
659 new jit_avx2_kernel_sgemm_kern(isBeta0));
660 }
661 } else if (mayiuse(avx)) {
662 kernel[no_beta0][do_alpha1][no_sum][no_sum].reset(
663 new jit_avx_kernel_sgemm_kern());
664 kernel[do_beta0][do_alpha1][no_sum][no_sum].reset(
665 new jit_avx_kernel_b0_sgemm_kern());
666 } else if (mayiuse(sse41)) {
667 kernel[no_beta0][do_alpha1][no_sum][no_sum].reset(
668 new jit_sse41_kernel_sgemm_kern());
669 kernel[do_beta0][do_alpha1][no_sum][no_sum].reset(
670 new jit_sse41_kernel_b0_sgemm_kern());
671 }
672 break;
673
674 default: break;
675 }
676
677 static maybe_unique_ptr<jit_generator> gemv_kernel[2] = {nullptr};
678 static maybe_unique_ptr<jit_generator> gemv_s8s8s32_kernel = nullptr;
679 static maybe_unique_ptr<jit_generator> gemv_s8u8s32_kernel = nullptr;
680 static maybe_unique_ptr<jit_generator> gemv_u8s8s32_kernel = nullptr;
681 switch (data_traits<a_t>::data_type) {
682 case data_type::s8:
683 if (mayiuse(avx512_core)) {
684 gemv_s8s8s32_kernel.reset(
685 new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8));
686 gemv_s8u8s32_kernel.reset(
687 new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8));
688 gemv_u8s8s32_kernel.reset(
689 new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8));
690 }
691 break;
692
693 case data_type::bf16:
694 if (mayiuse(avx512_core)) {
695 for (int isTrans : {no_trans, do_trans})
696 gemv_kernel[isTrans].reset(
697 new jit_avx512_core_gemv_bf16bf16f32_kern(
698 isTrans));
699 }
700 break;
701
702 case data_type::f32:
703 if (mayiuse(avx)) {
704 gemv_kernel[no_trans].reset(
705 new jit_sse41_gemv_n_f32_kern());
706 gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern());
707 } else if (mayiuse(sse41)) {
708 gemv_kernel[no_trans].reset(
709 new jit_sse41_gemv_n_f32_kern());
710 gemv_kernel[do_trans].reset(
711 new jit_sse41_gemv_t_f32_kern());
712 }
713 break;
714 default: assert(!"unsupported data type!");
715 }
716
717 // Set copy kernels function pointer table
718 for (int isTrans : {no_trans, do_trans})
719 for (int isSum : {no_sum, do_sum}) {
720 auto *p_copy_a = copy_a[isTrans][isSum].get();
721 if (p_copy_a != nullptr) {
722 st = p_copy_a->create_kernel();
723 if (st != dnnl_success) return;
724 copy_a_kern[isTrans][isSum]
725 = (copy_a_fptr_t)p_copy_a->jit_ker();
726 }
727 auto *p_copy_b = copy_b[isTrans][isSum].get();
728 if (p_copy_b != nullptr) {
729 st = p_copy_b->create_kernel();
730 if (st != dnnl_success) return;
731 copy_b_kern[isTrans][isSum]
732 = (copy_b_fptr_t)p_copy_b->jit_ker();
733 }
734 }
735
736 // AMX copy kernels don't support row/column sum. Use wrappers for now.
737 if (is_int8_amx) {
738 copy_a_kern[no_trans][do_sum] = &copy_a_sum_ref<no_trans>;
739 copy_a_kern[do_trans][do_sum] = &copy_a_sum_ref<do_trans>;
740 copy_b_kern[no_trans][do_sum] = &copy_b_sum_ref<no_trans>;
741 copy_b_kern[do_trans][do_sum] = &copy_b_sum_ref<do_trans>;
742 }
743
744 // Set compute kernel function pointer table
745 for (int isBeta0 : {no_beta0, do_beta0})
746 for (int isAlpha1 : {no_alpha1, do_alpha1})
747 for (int doColSum : {no_sum, do_sum})
748 for (int doRowSum : {no_sum, do_sum}) {
749 auto *p_kernel
750 = kernel[isBeta0][isAlpha1][doColSum][doRowSum]
751 .get();
752 if (p_kernel != nullptr) {
753 st = p_kernel->create_kernel();
754 if (st != dnnl_success) return;
755 kern[isBeta0][isAlpha1][doColSum][doRowSum]
756 = (gemm_fptr_t)p_kernel->jit_ker();
757 }
758 }
759 // Override compute kernel table with AMX kernels
760 if (is_amx) {
761 // AMX compute kernels don't support alpha scaling, row-offset or
762 // col-offset.
763 for (int isBeta0 : {no_beta0, do_beta0})
764 for (int isAlpha1 : {no_alpha1, do_alpha1})
765 for (int doColSum : {no_sum, do_sum})
766 for (int doRowSum : {no_sum, do_sum}) {
767 kern[isBeta0][isAlpha1][doColSum][doRowSum]
768 = kern[isBeta0][do_alpha1][no_sum][no_sum];
769 }
770 }
771
772 // Set gemv floating point kernels
773 if (utils::one_of(data_traits<a_t>::data_type, data_type::f32,
774 data_type::bf16)) {
775 for (int isTrans : {no_trans, do_trans}) {
776 auto *p_gemv_kernel = gemv_kernel[isTrans].get();
777 if (p_gemv_kernel != nullptr) {
778 st = p_gemv_kernel->create_kernel();
779 if (st != dnnl_success) return;
780 gemv_kern[isTrans] = (gemv_fptr_t)p_gemv_kernel->jit_ker();
781 }
782 }
783 }
784
785 // Set gemv integer gemm kernels
786 if (data_traits<a_t>::data_type == data_type::s8) {
787 if (gemv_s8s8s32_kernel != nullptr) {
788 auto *kern = gemv_s8s8s32_kernel.get();
789 st = kern->create_kernel();
790 if (st != dnnl_success) return;
791 gemv_s8s8s32_kern = (gemv_s8s8s32_fptr_t)kern->jit_ker();
792 }
793
794 if (gemv_s8u8s32_kernel != nullptr) {
795 auto *kern = gemv_s8u8s32_kernel.get();
796 st = kern->create_kernel();
797 if (st != dnnl_success) return;
798 gemv_s8u8s32_kern = (gemv_s8u8s32_fptr_t)kern->jit_ker();
799 }
800
801 if (gemv_u8s8s32_kernel != nullptr) {
802 auto *kern = gemv_u8s8s32_kernel.get();
803 st = kern->create_kernel();
804 if (st != dnnl_success) return;
805 gemv_u8s8s32_kern = (gemv_u8s8s32_fptr_t)kern->jit_ker();
806 }
807 }
808 });
809
810 if (st != dnnl_success) return;
811
812 int doSumA = this->bo != 0 ? do_sum : no_sum;
813 int doSumB = this->ao != 0 ? do_sum : no_sum;
814
815 int copy_trans_a = (this->transa == do_trans) ? do_trans : no_trans;
816 int copy_trans_b = (this->transb == do_trans) ? do_trans : no_trans;
817
818 constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16;
819 bool doAlpha1 = this->alpha != 1.0f && is_bf16 ? no_alpha1 : do_alpha1;
820
821 {
822 // NOTE: This lock may not be needed at all as writes to copy_a_kern
823 // (and others) are protected within std::call_once(). The lock is added
824 // only to fix warnings reported by clang TSAN about a data race in
825 // this code block.
826 std::lock_guard<std::mutex> g(kern_mutex);
827 this->copyA = copy_a_kern[copy_trans_a][doSumA];
828 this->copyB = copy_b_kern[copy_trans_b][doSumB];
829 for (int isBeta0 : {no_beta0, do_beta0})
830 for (int doColSum : {no_sum, do_sum})
831 for (int doRowSum : {no_sum, do_sum})
832 this->kernel[isBeta0][doColSum][doRowSum]
833 = kern[isBeta0][doAlpha1][doColSum][doRowSum];
834 for (int isTrans : {no_trans, do_trans})
835 this->gemv_kernel[isTrans] = gemv_kern[isTrans];
836 }
837
838 this->gemv_s8s8s32_kernel = nullptr;
839 this->gemv_s8u8s32_kernel = nullptr;
840 this->gemv_u8s8s32_kernel = nullptr;
841 if (data_traits<a_t>::data_type == data_type::s8) {
842 this->gemv_s8s8s32_kernel = gemv_s8s8s32_kern;
843 this->gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
844 this->gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
845 }
846}
847
848// Check if copy algorithm kernels were generated on supported ISAs.
849// Copy algorithm supported for:
850// s8 : Intel AVX512, Intel DL Boost
851// bf16 : Intel AVX512, Intel AVX512 BF16
852// f32 : Intel SSE4.1, Intel AVX, Intel AVX2, Intel AVX512
853template <typename a_t, typename b_t, typename c_t>
854bool gemm_info_t<a_t, b_t, c_t>::hasKernels(void) {
855
856 switch (data_traits<a_t>::data_type) {
857 case data_type::s8:
858 if (mayiuse(sse41)) {
859 for (int isBeta0 : {no_beta0, do_beta0})
860 for (int doColSum : {no_sum, do_sum})
861 for (int doRowSum : {no_sum, do_sum})
862 if (!this->kernel[isBeta0][doColSum][doRowSum])
863 return false;
864
865 if (!this->copyA || !this->copyB) return false;
866
867 if (mayiuse(avx512_core))
868 if (!this->gemv_s8u8s32_kernel || !this->gemv_u8s8s32_kernel
869 || !this->gemv_s8s8s32_kernel)
870 return false;
871 }
872 break;
873
874 case data_type::bf16:
875 if (mayiuse(avx512_core)) {
876 for (int isBeta0 : {no_beta0, do_beta0})
877 if (!this->kernel[isBeta0][no_sum][no_sum]) return false;
878
879 if (!this->copyA || !this->copyB) return false;
880
881 for (int isTrans : {no_trans, do_trans})
882 if (!this->gemv_kernel[isTrans]) return false;
883 }
884 break;
885
886 case data_type::f32:
887 if (mayiuse(sse41) && !this->force_nocopy) {
888 for (int isBeta0 : {no_beta0, do_beta0})
889 if (!this->kernel[isBeta0][no_sum][no_sum]) return false;
890
891 if (!this->copyA || !this->copyB) return false;
892
893 for (int isTrans : {no_trans, do_trans})
894 if (!this->gemv_kernel[isTrans]) return false;
895 }
896 break;
897 default: assert(!"unsupported data type!");
898 }
899
900 // All kernels necessary have been found or ISA is not supported.
901 return true;
902}
903
904// Override default blocking sizes with sizes specified in the gemm_threading_t
905// structure.
906template <typename a_t, typename b_t, typename c_t>
907void gemm_info_t<a_t, b_t, c_t>::update_blocking(
908 const gemm_threading_t &thread_info) {
909
910 if (thread_info.block_m > 0) this->bm = thread_info.block_m;
911 if (thread_info.block_n > 0) this->bn = thread_info.block_n;
912 if (thread_info.block_k > 0) this->bk = thread_info.block_k;
913}
914
915// Instantiate the gemm_info_t templates needed.
916template // For gemm_s8u8s32
917 struct gemm_info_t<int8_t, uint8_t, int32_t>;
918
919template // For gemm_s8s8s32
920 struct gemm_info_t<int8_t, int8_t, int32_t>;
921
922template // For gemm_bf16bf16f32
923 struct gemm_info_t<bfloat16_t, bfloat16_t, float>;
924
925template // For sgemm.
926 struct gemm_info_t<float, float, float>;
927
928} // namespace x64
929} // namespace cpu
930} // namespace impl
931} // namespace dnnl
932