1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/platform.hpp"
18
19#include "cpu/gemm/gemm_pack.hpp"
20
21#if DNNL_X64
22#include "cpu/x64/gemm/gemm_pack.hpp"
23#endif
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28
// Returns x64::pack_sgemm_supported() on x64 builds. No other backend is
// wired up here, so every other architecture reports false.
bool pack_sgemm_supported() {
#if DNNL_X64
    return x64::pack_sgemm_supported();
#else
    // #else (instead of falling through after #endif) keeps x64 builds free
    // of an unreachable `return`.
    return false;
#endif
}
// Returns x64::pack_gemm_bf16bf16f32_supported() on x64 builds. No other
// backend is wired up here, so every other architecture reports false.
bool pack_gemm_bf16bf16f32_supported() {
#if DNNL_X64
    return x64::pack_gemm_bf16bf16f32_supported();
#else
    // #else (instead of falling through after #endif) keeps x64 builds free
    // of an unreachable `return`.
    return false;
#endif
}
41
42dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa,
43 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
44 const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) {
45#if DNNL_X64
46 return x64::sgemm_pack_get_size(
47 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
48#endif
49 return dnnl_unimplemented;
50}
51
52dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier,
53 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
54 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
55 bool *pack) {
56#if DNNL_X64
57 return x64::gemm_bf16bf16f32_pack_get_size(
58 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
59#endif
60 return dnnl_unimplemented;
61}
62
63dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier,
64 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
65 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
66 bool *pack) {
67#if DNNL_X64
68 return x64::gemm_s8u8s32_pack_get_size(
69 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
70#endif
71 return dnnl_unimplemented;
72}
73
74dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier,
75 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
76 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
77 bool *pack) {
78#if DNNL_X64
79 return x64::gemm_s8s8s32_pack_get_size(
80 identifier, transa, transb, M, N, K, lda, ldb, size, pack);
81#endif
82 return dnnl_unimplemented;
83}
84
85dnnl_status_t sgemm_pack(const char *identifier, const char *transa,
86 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
87 const dim_t *lda, const dim_t *ldb, const float *src, float *dst) {
88#if DNNL_X64
89 return x64::sgemm_pack(
90 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
91#endif
92 return dnnl_unimplemented;
93}
94
95dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa,
96 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
97 const dim_t *lda, const dim_t *ldb, const bfloat16_t *src,
98 bfloat16_t *dst) {
99#if DNNL_X64
100 return x64::gemm_bf16bf16f32_pack(
101 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
102#endif
103 return dnnl_unimplemented;
104}
105
106dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa,
107 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
108 const dim_t *lda, const dim_t *ldb, const void *src, void *dst) {
109#if DNNL_X64
110 return x64::gemm_s8u8s32_pack(
111 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
112#endif
113 return dnnl_unimplemented;
114}
115
116dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa,
117 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
118 const dim_t *lda, const dim_t *ldb, const void *src, void *dst) {
119#if DNNL_X64
120 return x64::gemm_s8s8s32_pack(
121 identifier, transa, transb, M, N, K, lda, ldb, src, dst);
122#endif
123 return dnnl_unimplemented;
124}
125
126dnnl_status_t sgemm_compute(const char *transa, const char *transb,
127 const dim_t *M, const dim_t *N, const dim_t *K, const float *A,
128 const dim_t *lda, const float *B, const dim_t *ldb, const float *beta,
129 float *C, const dim_t *ldc) {
130#if DNNL_X64
131 return x64::sgemm_compute(
132 transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc);
133#endif
134 return dnnl_unimplemented;
135}
136
137dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb,
138 const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A,
139 const dim_t *lda, const bfloat16_t *B, const dim_t *ldb,
140 const float *beta, float *C, const dim_t *ldc) {
141#if DNNL_X64
142 return x64::gemm_bf16bf16f32_compute(
143 transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc);
144#endif
145 return dnnl_unimplemented;
146}
147
148dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb,
149 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
150 const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb,
151 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) {
152#if DNNL_X64
153 return x64::gemm_s8u8s32_compute(
154 transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co);
155#endif
156 return dnnl_unimplemented;
157}
158
159dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb,
160 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
161 const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb,
162 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) {
163#if DNNL_X64
164 return x64::gemm_s8s8s32_compute(
165 transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co);
166#endif
167 return dnnl_unimplemented;
168}
169
170} // namespace cpu
171} // namespace impl
172} // namespace dnnl
173