1 | /******************************************************************************* |
2 | * Copyright 2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/platform.hpp" |
18 | |
19 | #include "cpu/gemm/gemm_pack.hpp" |
20 | |
21 | #if DNNL_X64 |
22 | #include "cpu/x64/gemm/gemm_pack.hpp" |
23 | #endif |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | |
// Architecture dispatch: yields x64::pack_sgemm_supported() when DNNL_X64 is
// non-zero, and false in every other build.
bool pack_sgemm_supported() {
#if DNNL_X64
    const bool is_supported = x64::pack_sgemm_supported();
#else
    const bool is_supported = false;
#endif
    return is_supported;
}
// Architecture dispatch: yields x64::pack_gemm_bf16bf16f32_supported() when
// DNNL_X64 is non-zero, and false in every other build.
bool pack_gemm_bf16bf16f32_supported() {
#if DNNL_X64
    const bool is_supported = x64::pack_gemm_bf16bf16f32_supported();
#else
    const bool is_supported = false;
#endif
    return is_supported;
}
41 | |
42 | dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa, |
43 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
44 | const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { |
45 | #if DNNL_X64 |
46 | return x64::sgemm_pack_get_size( |
47 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
48 | #endif |
49 | return dnnl_unimplemented; |
50 | } |
51 | |
52 | dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier, |
53 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
54 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
55 | bool *pack) { |
56 | #if DNNL_X64 |
57 | return x64::gemm_bf16bf16f32_pack_get_size( |
58 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
59 | #endif |
60 | return dnnl_unimplemented; |
61 | } |
62 | |
63 | dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier, |
64 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
65 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
66 | bool *pack) { |
67 | #if DNNL_X64 |
68 | return x64::gemm_s8u8s32_pack_get_size( |
69 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
70 | #endif |
71 | return dnnl_unimplemented; |
72 | } |
73 | |
74 | dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, |
75 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
76 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
77 | bool *pack) { |
78 | #if DNNL_X64 |
79 | return x64::gemm_s8s8s32_pack_get_size( |
80 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
81 | #endif |
82 | return dnnl_unimplemented; |
83 | } |
84 | |
85 | dnnl_status_t sgemm_pack(const char *identifier, const char *transa, |
86 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
87 | const dim_t *lda, const dim_t *ldb, const float *src, float *dst) { |
88 | #if DNNL_X64 |
89 | return x64::sgemm_pack( |
90 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
91 | #endif |
92 | return dnnl_unimplemented; |
93 | } |
94 | |
95 | dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, |
96 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
97 | const dim_t *lda, const dim_t *ldb, const bfloat16_t *src, |
98 | bfloat16_t *dst) { |
99 | #if DNNL_X64 |
100 | return x64::gemm_bf16bf16f32_pack( |
101 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
102 | #endif |
103 | return dnnl_unimplemented; |
104 | } |
105 | |
106 | dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa, |
107 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
108 | const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { |
109 | #if DNNL_X64 |
110 | return x64::gemm_s8u8s32_pack( |
111 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
112 | #endif |
113 | return dnnl_unimplemented; |
114 | } |
115 | |
116 | dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa, |
117 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
118 | const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { |
119 | #if DNNL_X64 |
120 | return x64::gemm_s8s8s32_pack( |
121 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
122 | #endif |
123 | return dnnl_unimplemented; |
124 | } |
125 | |
126 | dnnl_status_t sgemm_compute(const char *transa, const char *transb, |
127 | const dim_t *M, const dim_t *N, const dim_t *K, const float *A, |
128 | const dim_t *lda, const float *B, const dim_t *ldb, const float *beta, |
129 | float *C, const dim_t *ldc) { |
130 | #if DNNL_X64 |
131 | return x64::sgemm_compute( |
132 | transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); |
133 | #endif |
134 | return dnnl_unimplemented; |
135 | } |
136 | |
137 | dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb, |
138 | const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A, |
139 | const dim_t *lda, const bfloat16_t *B, const dim_t *ldb, |
140 | const float *beta, float *C, const dim_t *ldc) { |
141 | #if DNNL_X64 |
142 | return x64::gemm_bf16bf16f32_compute( |
143 | transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc); |
144 | #endif |
145 | return dnnl_unimplemented; |
146 | } |
147 | |
148 | dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb, |
149 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
150 | const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb, |
151 | const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { |
152 | #if DNNL_X64 |
153 | return x64::gemm_s8u8s32_compute( |
154 | transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); |
155 | #endif |
156 | return dnnl_unimplemented; |
157 | } |
158 | |
159 | dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb, |
160 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
161 | const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb, |
162 | const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { |
163 | #if DNNL_X64 |
164 | return x64::gemm_s8s8s32_compute( |
165 | transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); |
166 | #endif |
167 | return dnnl_unimplemented; |
168 | } |
169 | |
170 | } // namespace cpu |
171 | } // namespace impl |
172 | } // namespace dnnl |
173 | |