1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * Copyright 2022 Arm Ltd. and affiliates |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #ifndef CPU_GEMM_GEMM_HPP |
19 | #define CPU_GEMM_GEMM_HPP |
20 | |
21 | #include "oneapi/dnnl/dnnl_types.h" |
22 | |
23 | #include "common/bfloat16.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | |
27 | #include "cpu/gemm/os_blas.hpp" |
28 | |
29 | #if DNNL_X64 |
30 | #include "cpu/x64/cpu_isa_traits.hpp" |
31 | #endif |
32 | |
33 | #if DNNL_AARCH64 |
34 | #include "cpu/aarch64/cpu_isa_traits.hpp" |
35 | #endif |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | |
41 | dnnl_status_t extended_sgemm(const char *transa, const char *transb, |
42 | const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha, |
43 | const float *A, const dim_t *lda, const float *B, const dim_t *ldb, |
44 | const float *beta, float *C, const dim_t *ldc, |
45 | const float *bias = nullptr, bool force_jit_gemm = false); |
46 | |
47 | template <typename b_dt> |
48 | dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, |
49 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
50 | const float *alpha, const int8_t *A, const dim_t *lda, const int8_t *ao, |
51 | const b_dt *B, const dim_t *ldb, const b_dt *bo, const float *beta, |
52 | int32_t *c, const dim_t *ldc, const int32_t *co); |
53 | |
54 | dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, |
55 | const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha, |
56 | const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, |
57 | const dim_t *ldb, const float *beta, float *C, const dim_t *ldc); |
58 | |
59 | #if defined(USE_CBLAS) |
60 | #define GEMM_IMPL_STR "x64:gemm:blas" |
61 | #elif DNNL_X64 |
62 | #define GEMM_IMPL_STR "x64:gemm:jit" |
63 | #else |
64 | #define GEMM_IMPL_STR "gemm:ref" |
65 | #endif |
66 | |
67 | #if USE_MKL_IGEMM |
68 | #define IGEMM_S8U8S32_IMPL_STR "x64:gemm_s8u8s32:blas" |
69 | #define IGEMM_S8S8S32_IMPL_STR "x64:gemm_s8s8s32:blas" |
70 | #elif DNNL_X64 |
71 | #define IGEMM_S8U8S32_IMPL_STR "x64:gemm_s8u8s32:jit" |
72 | #define IGEMM_S8S8S32_IMPL_STR "x64:gemm_s8s8s32:jit" |
73 | #else |
74 | #define IGEMM_S8U8S32_IMPL_STR "gemm_s8u8s32:ref" |
75 | #define IGEMM_S8S8S32_IMPL_STR "gemm_s8s8s32:ref" |
76 | #endif |
77 | |
78 | #if !defined(USE_MKL_IGEMM) && defined(DNNL_X64) |
79 | #define IGEMM_S8U8S32_ISA_STR \ |
80 | JIT_IMPL_NAME_HELPER(IGEMM_S8U8S32_IMPL_STR ":", \ |
81 | mayiuse(avx512_core_vnni) \ |
82 | ? avx512_core_vnni \ |
83 | : (mayiuse(avx512_core) ? avx512_core : isa_undef), \ |
84 | "") |
85 | #else |
86 | #define IGEMM_S8U8S32_ISA_STR IGEMM_S8U8S32_IMPL_STR |
87 | #endif |
88 | |
89 | } // namespace cpu |
90 | } // namespace impl |
91 | } // namespace dnnl |
92 | |
93 | #endif // CPU_GEMM_GEMM_HPP |
94 | |