1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl_types.h"
18
19#include "common/dnnl_thread.hpp"
20#include "common/dnnl_traits.hpp"
21
22#include "cpu/gemm/gemm.hpp"
23#include "cpu/gemm/gemm_pack.hpp"
24#include "cpu/gemm/os_blas.hpp"
25
26#include "cpu/x64/cpu_isa_traits.hpp"
27
28#include "cpu/x64/gemm/gemm_driver.hpp"
29#include "cpu/x64/gemm/gemm_utils.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36bool pack_sgemm_supported() {
37#if USE_MKL_PACKED_GEMM
38 return true;
39#else
40 return mayiuse(sse41);
41#endif
42}
43
44bool pack_gemm_bf16bf16f32_supported() {
45 return mayiuse(avx512_core);
46}
47
48#if USE_MKL_PACKED_GEMM
49static inline CBLAS_IDENTIFIER cblas_identifier(const char *identifier) {
50 return utils::one_of(*identifier, 'a', 'A') ? CblasAMatrix : CblasBMatrix;
51}
52
53static inline CBLAS_TRANSPOSE cblas_transpose(const char *trans) {
54 return utils::one_of(*trans, 'n', 'N') ? CblasNoTrans : CblasTrans;
55}
56
57static inline MKL_INT cblas_storage(const char *trans) {
58 switch (*trans) {
59 case 'N':
60 case 'n': return CblasNoTrans;
61 case 'T':
62 case 't': return CblasTrans;
63 default: return CblasPacked;
64 }
65}
66
67static inline CBLAS_OFFSET cblas_offset(const char *offset) {
68 switch (*offset) {
69 case 'R':
70 case 'r': return CblasRowOffset;
71 case 'C':
72 case 'c': return CblasColOffset;
73 default: return CblasFixOffset;
74 }
75}
76#endif
77
78#if !USE_MKL_PACKED_GEMM
79template <typename a_dt, typename b_dt>
80static inline bool use_reference_igemm(void) {
81 constexpr bool is_s8u8 = true
82 && data_traits<a_dt>::data_type == data_type::s8
83 && data_traits<b_dt>::data_type == data_type::u8;
84 if (is_s8u8)
85 return !mayiuse(sse41);
86 else
87 return !mayiuse(avx512_core);
88}
89
90#else
// MKL-backed build: the reference path is always selected by this helper,
// regardless of the type pair.
template <typename a_dt, typename b_dt>
static inline bool use_reference_igemm() {
    constexpr bool always_reference = true;
    return always_reference;
}
95#endif
96
97template <typename T>
98static bool is_good_ld(dim_t ld) {
99 static constexpr auto align = 64 / sizeof(T);
100 static constexpr auto no_align = 2048 / sizeof(T);
101
102 return ((ld % align) == 0) && ((ld % no_align) != 0);
103}
104
105static dnnl_status_t check_pack_get_size_input(const char *identifier,
106 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
107 const dim_t *K, const dim_t *lda, const dim_t *ldb) {
108
109 if (utils::any_null(identifier, transa, transb, M, N, K, lda, ldb))
110 return dnnl_invalid_arguments;
111
112 bool is_transa = utils::one_of(*transa, 'T', 't');
113 bool is_transb = utils::one_of(*transb, 'T', 't');
114
115 bool ok = true && utils::one_of(*transa, 'T', 't', 'N', 'n')
116 && utils::one_of(*transb, 'T', 't', 'N', 'n')
117 && utils::one_of(*identifier, 'A', 'a', 'B', 'b') && *M >= 0
118 && *N >= 0 && *K >= 0
119 && *lda >= nstl::max(dim_t(1), !is_transa ? *M : *K)
120 && *ldb >= nstl::max(dim_t(1), !is_transb ? *K : *N);
121
122 if (!ok) return dnnl_invalid_arguments;
123
124 return dnnl_success;
125}
126
127static dnnl_status_t check_pack_input(const char *identifier,
128 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
129 const dim_t *K, const float *alpha, const dim_t *lda, const dim_t *ldb,
130 const void *src, void *dst) {
131 if (utils::any_null(src, dst, alpha)) return dnnl_invalid_arguments;
132
133 return check_pack_get_size_input(
134 identifier, transa, transb, M, N, K, lda, ldb);
135}
136
137template <typename a_dt, typename b_dt, typename c_dt>
138static dnnl_status_t gemm_pack_driver(const char *identifier,
139 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
140 const dim_t *K, const float *alpha, const dim_t *lda, const dim_t *ldb,
141 const void *src, gemm_pack_storage_t *pack_dst, bool measure_only) {
142
143 a_dt oa = 0;
144 b_dt ob = 0;
145
146 const a_dt *a = nullptr;
147 const b_dt *b = nullptr;
148 pack_type packing;
149
150 if (utils::one_of(*identifier, 'a', 'A')) {
151 a = (const a_dt *)src;
152 packing = pack_type::pack_a;
153 } else {
154 b = (const b_dt *)src;
155 packing = pack_type::pack_b;
156 }
157
158 return gemm_driver<a_dt, b_dt, c_dt>(transa, transb, "N", M, N, K, alpha, a,
159 lda, &oa, b, ldb, &ob, nullptr, nullptr, nullptr, nullptr, false,
160 packing, pack_dst, measure_only);
161}
162
163dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa,
164 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
165 const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) {
166
167 if (!pack_sgemm_supported()) return dnnl_unimplemented;
168
169 dnnl_status_t result;
170 *size = 0;
171 if (pack) *pack = true;
172
173 result = check_pack_get_size_input(
174 identifier, transa, transb, M, N, K, lda, ldb);
175 if (result != dnnl_success) return result;
176
177#if USE_MKL_PACKED_GEMM
178 *size = cblas_sgemm_pack_get_size(cblas_identifier(identifier), *M, *N, *K);
179#else
180 bool do_a = utils::one_of(*identifier, 'a', 'A');
181 float alpha = 1.0f;
182 gemm_pack_storage_shell_t shell {dnnl_get_max_threads()};
183 if (!shell.get()) return dnnl_out_of_memory;
184
185 result = gemm_pack_driver<float, float, float>(identifier, transa, transb,
186 M, N, K, &alpha, lda, ldb, nullptr, &shell, true);
187 if (result != dnnl_success) return result;
188
189 *size = shell.size();
190 if (pack) {
191 *pack = !(shell.single_nocopy()
192 && utils::one_of(do_a ? *transa : *transb, 'n', 'N')
193 && is_good_ld<float>(do_a ? *lda : *ldb));
194 }
195#endif
196
197 return dnnl_success;
198}
199
200dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier,
201 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
202 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
203 bool *pack) {
204
205 if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented;
206
207 dnnl_status_t result;
208 *size = 0;
209 if (pack) *pack = true;
210
211 result = check_pack_get_size_input(
212 identifier, transa, transb, M, N, K, lda, ldb);
213 if (result != dnnl_success) return result;
214
215 float alpha = 1.0f;
216 gemm_pack_storage_shell_t shell {dnnl_get_max_threads()};
217 if (!shell.get()) return dnnl_out_of_memory;
218
219 result = gemm_pack_driver<bfloat16_t, bfloat16_t, float>(identifier, transa,
220 transb, M, N, K, &alpha, lda, ldb, nullptr, &shell, true);
221 if (result != dnnl_success) return result;
222
223 *size = shell.size();
224
225 return dnnl_success;
226}
227
228template <typename a_dt, typename b_dt>
229dnnl_status_t gemm_x8x8s32_pack_get_size(const char *identifier,
230 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
231 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
232 bool *pack) {
233
234 dnnl_status_t result;
235 *size = 0;
236 if (pack) *pack = true;
237
238 result = check_pack_get_size_input(
239 identifier, transa, transb, M, N, K, lda, ldb);
240 if (result != dnnl_success) return result;
241
242#if USE_MKL_PACKED_GEMM
243 constexpr bool is_s8u8 = true
244 && data_traits<a_dt>::data_type == data_type::s8
245 && data_traits<b_dt>::data_type == data_type::u8;
246
247 if (is_s8u8) {
248 *size = cblas_gemm_s8u8s32_pack_get_size(
249 cblas_identifier(identifier), *M, *N, *K);
250 return dnnl_success;
251 }
252#endif
253
254 bool do_a = utils::one_of(*identifier, 'a', 'A');
255 float alpha = 1.0f;
256 gemm_pack_storage_shell_t shell {dnnl_get_max_threads(), do_a, !do_a};
257 if (!shell.get()) return dnnl_out_of_memory;
258
259 if (!use_reference_igemm<a_dt, b_dt>()) {
260 result = gemm_pack_driver<a_dt, b_dt, int32_t>(identifier, transa,
261 transb, M, N, K, &alpha, lda, ldb, nullptr, &shell, true);
262 if (result != dnnl_success) return result;
263 } else {
264 auto rows = do_a ? *M : *K;
265 auto cols = do_a ? *K : *N;
266 if (do_a) {
267 gemm_utils::prep_gemm_pack<int8_t, int32_t>(
268 do_a, no_trans, rows, cols, &shell);
269 } else {
270 gemm_utils::prep_gemm_pack<uint8_t, int32_t>(
271 do_a, no_trans, rows, cols, &shell);
272 }
273 }
274
275 *size = shell.size();
276 if (pack) {
277 *pack = !(shell.single_nocopy()
278 && utils::one_of(do_a ? *transa : *transb, 'n', 'N')
279 && is_good_ld<float>(do_a ? *lda : *ldb));
280 }
281
282 return dnnl_success;
283}
284
285dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier,
286 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
287 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
288 bool *pack) {
289
290 return gemm_x8x8s32_pack_get_size<int8_t, uint8_t>(
291 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
292}
293
294dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier,
295 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
296 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
297 bool *pack) {
298
299 return gemm_x8x8s32_pack_get_size<int8_t, int8_t>(
300 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
301}
302
303dnnl_status_t sgemm_pack(const char *identifier, const char *transa,
304 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
305 const dim_t *lda, const dim_t *ldb, const float *src, float *dst) {
306 float one = 1.f, *alpha = &one;
307
308 if (!pack_sgemm_supported()) return dnnl_unimplemented;
309
310 auto result = check_pack_input(
311 identifier, transa, transb, M, N, K, alpha, lda, ldb, src, dst);
312 if (result != dnnl_success) return result;
313
314#if USE_MKL_PACKED_GEMM
315 auto cblas_id = cblas_identifier(identifier);
316 auto ld = (cblas_id == CblasAMatrix) ? *lda : *ldb;
317 auto trans = (cblas_id == CblasAMatrix) ? transa : transb;
318 cblas_sgemm_pack(CblasColMajor, cblas_id, cblas_transpose(trans), *M, *N,
319 *K, *alpha, src, ld, dst);
320 return dnnl_success;
321#else
322 gemm_pack_storage_t pack_dst(dst, false);
323
324 return gemm_pack_driver<float, float, float>(identifier, transa, transb, M,
325 N, K, alpha, lda, ldb, src, &pack_dst, false);
326#endif
327}
328
329dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa,
330 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
331 const dim_t *lda, const dim_t *ldb, const bfloat16_t *src,
332 bfloat16_t *dst) {
333 float one = 1.f, *alpha = &one;
334
335 if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented;
336
337 auto result = check_pack_input(
338 identifier, transa, transb, M, N, K, alpha, lda, ldb, src, dst);
339 if (result != dnnl_success) return result;
340
341 gemm_pack_storage_t pack_dst(dst, false);
342
343 return gemm_pack_driver<bfloat16_t, bfloat16_t, float>(identifier, transa,
344 transb, M, N, K, alpha, lda, ldb, src, &pack_dst, false);
345}
346
347template <typename a_dt, typename b_dt>
348dnnl_status_t gemm_x8x8s32_pack(const char *identifier, const char *transa,
349 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
350 const dim_t *lda, const dim_t *ldb, const void *src_void, void *dst) {
351
352 float alpha = 1.0f; // Not used with igemm.
353 auto result = check_pack_input(identifier, transa, transb, M, N, K, &alpha,
354 lda, ldb, src_void, dst);
355 if (result != dnnl_success) return result;
356
357#if USE_MKL_PACKED_GEMM
358 constexpr bool is_s8u8 = true
359 && data_traits<a_dt>::data_type == data_type::s8
360 && data_traits<b_dt>::data_type == data_type::u8;
361
362 if (is_s8u8) {
363 auto cblas_id = cblas_identifier(identifier);
364 auto ld = (cblas_id == CblasAMatrix) ? *lda : *ldb;
365 auto trans = (cblas_id == CblasAMatrix) ? transa : transb;
366 cblas_gemm_s8u8s32_pack(CblasColMajor, cblas_id, cblas_transpose(trans),
367 *M, *N, *K, src_void, ld, dst);
368 return dnnl_success;
369 }
370#endif
371 gemm_pack_storage_t pack_dst(dst, false);
372
373 if (!use_reference_igemm<a_dt, b_dt>()) {
374 return gemm_pack_driver<a_dt, b_dt, int32_t>(identifier, transa, transb,
375 M, N, K, &alpha, lda, ldb, src_void, &pack_dst, false);
376 } else {
377 bool do_a = utils::one_of(*identifier, 'a', 'A');
378 bool is_trans = utils::one_of(do_a ? *transa : *transb, 't', 'T');
379 auto ld = do_a ? *lda : *ldb;
380 auto rows = do_a ? *M : *K;
381 auto cols = do_a ? *K : *N;
382
383 if (do_a) {
384 gemm_utils::prep_gemm_pack<int8_t, int32_t>(
385 do_a, no_trans, rows, cols, &pack_dst);
386 auto src = reinterpret_cast<const int8_t *>(src_void);
387 return gemm_utils::pack_no_copy(
388 src, ld, rows, cols, is_trans, alpha, &pack_dst);
389 } else {
390 gemm_utils::prep_gemm_pack<uint8_t, int32_t>(
391 do_a, no_trans, rows, cols, &pack_dst);
392 auto src = reinterpret_cast<const uint8_t *>(src_void);
393 return gemm_utils::pack_no_copy(
394 src, ld, rows, cols, is_trans, alpha, &pack_dst);
395 }
396 }
397}
398
399dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa,
400 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
401 const dim_t *lda, const dim_t *ldb, const void *src, void *dst) {
402
403 return gemm_x8x8s32_pack<int8_t, uint8_t>(
404 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
405}
406
407dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa,
408 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
409 const dim_t *lda, const dim_t *ldb, const void *src, void *dst) {
410
411 return gemm_x8x8s32_pack<int8_t, int8_t>(
412 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
413}
414
415dnnl_status_t sgemm_compute(const char *transa, const char *transb,
416 const dim_t *M, const dim_t *N, const dim_t *K, const float *A,
417 const dim_t *lda, const float *B, const dim_t *ldb, const float *beta,
418 float *C, const dim_t *ldc) {
419
420#if USE_MKL_PACKED_GEMM
421 if (utils::any_null(transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc))
422 return dnnl_invalid_arguments;
423 cblas_sgemm_compute(CblasColMajor, cblas_storage(transa),
424 cblas_storage(transb), *M, *N, *K, A, *lda, B, *ldb, *beta, C,
425 *ldc);
426 return dnnl_success;
427#else
428 if (!pack_sgemm_supported()) return dnnl_unimplemented;
429
430 float one = 1.0f;
431
432 return extended_sgemm(
433 transa, transb, M, N, K, &one, A, lda, B, ldb, beta, C, ldc);
434#endif
435}
436
437dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb,
438 const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A,
439 const dim_t *lda, const bfloat16_t *B, const dim_t *ldb,
440 const float *beta, float *C, const dim_t *ldc) {
441
442 if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented;
443
444 float one = 1.0f;
445
446 return gemm_bf16bf16f32(
447 transa, transb, M, N, K, &one, A, lda, B, ldb, beta, C, ldc);
448}
449
450template <typename a_dt, typename b_dt>
451dnnl_status_t gemm_x8x8s32_compute(const char *transa, const char *transb,
452 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
453 const a_dt *A, const dim_t *lda, const b_dt *B, const dim_t *ldb,
454 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) {
455
456 const float one = 1.f, *alpha = &one;
457 const a_dt zero_a_dt = 0, *ao = &zero_a_dt;
458 const b_dt zero_b_dt = 0, *bo = &zero_b_dt;
459
460#if USE_MKL_PACKED_GEMM
461 constexpr bool is_s8u8 = true
462 && data_traits<a_dt>::data_type == data_type::s8
463 && data_traits<b_dt>::data_type == data_type::u8;
464
465 if (is_s8u8) {
466 if (utils::any_null(transa, transb, offsetc, M, N, K, alpha, A, lda, ao,
467 B, ldb, bo, beta, C, ldc, co))
468 return dnnl_invalid_arguments;
469 cblas_gemm_s8u8s32_compute(CblasColMajor, cblas_storage(transa),
470 cblas_storage(transb), cblas_offset(offsetc), *M, *N, *K,
471 *alpha, A, *lda, *ao, B, *ldb, *bo, *beta, C, *ldc, co);
472 return dnnl_success;
473 }
474#endif
475 auto lda_eff = *lda, ldb_eff = *ldb;
476 auto transa_eff = *transa, transb_eff = *transb;
477
478 if (!use_reference_igemm<a_dt, b_dt>()) {
479 return gemm_s8x8s32(&transa_eff, &transb_eff, offsetc, M, N, K, alpha,
480 A, &lda_eff, ao, B, &ldb_eff, bo, beta, C, ldc, co);
481 } else {
482 dim_t ld, td;
483
484 if (transa_eff == 'p' || transa_eff == 'P') {
485 gemm_pack_storage_t a_packed {A};
486 int trans;
487 if (!a_packed.get_nocopy(trans, ld, td))
488 return dnnl_invalid_arguments;
489 A = a_packed.matrix<a_dt>();
490 lda_eff = ld;
491 transa_eff = trans == no_trans ? 'N' : 'T';
492 }
493
494 if (transb_eff == 'p' || transb_eff == 'P') {
495 gemm_pack_storage_t b_packed {B};
496 int trans;
497 if (!b_packed.get_nocopy(trans, ld, td))
498 return dnnl_invalid_arguments;
499 B = b_packed.matrix<b_dt>();
500 ldb_eff = ld;
501 transb_eff = trans == no_trans ? 'N' : 'T';
502 }
503
504 return gemm_s8x8s32(&transa_eff, &transb_eff, offsetc, M, N, K, alpha,
505 A, &lda_eff, ao, B, &ldb_eff, bo, beta, C, ldc, co);
506 }
507}
508
509dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb,
510 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
511 const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb,
512 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) {
513
514 return gemm_x8x8s32_compute(
515 transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co);
516}
517
518dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb,
519 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
520 const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb,
521 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) {
522
523 return gemm_x8x8s32_compute(
524 transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co);
525}
526
527} // namespace x64
528} // namespace cpu
529} // namespace impl
530} // namespace dnnl
531