1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * Copyright 2022 IBM Corporation |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include "oneapi/dnnl/dnnl.h" |
19 | |
20 | #include "common/bfloat16.hpp" |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/dnnl_traits.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/gemm/gemm.hpp" |
28 | #include "cpu/gemm/gemm_msan_unpoison.hpp" |
29 | #include "cpu/gemm/os_blas.hpp" |
30 | |
31 | #include "cpu/gemm/f32/ref_gemm_f32.hpp" |
32 | #include "cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp" |
33 | #include "cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp" |
34 | |
35 | #if DNNL_X64 |
36 | #include "cpu/x64/cpu_isa_traits.hpp" |
37 | |
38 | #include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp" |
39 | #include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp" |
40 | |
41 | #include "cpu/x64/gemm/gemm_driver.hpp" |
42 | |
43 | using namespace dnnl::impl::cpu::x64; |
44 | #elif DNNL_PPC64 |
45 | #include "cpu/ppc64/ppc64_gemm_driver.hpp" |
46 | using namespace dnnl::impl::cpu::ppc64; |
47 | #endif |
48 | |
49 | namespace dnnl { |
50 | namespace impl { |
51 | namespace cpu { |
52 | |
53 | dnnl_status_t check_gemm_input(const char *transa, const char *transb, |
54 | const dim_t *M, const dim_t *N, const dim_t *K, const void *A, |
55 | const dim_t *lda, const void *B, const dim_t *ldb, const void *C, |
56 | const dim_t *ldc, const float *alpha, const float *beta, |
57 | const bool with_bias) { |
58 | if (utils::any_null( |
59 | transa, transb, M, N, K, A, lda, B, ldb, C, ldc, alpha, beta)) |
60 | return dnnl_invalid_arguments; |
61 | if (with_bias && *beta != 0) return dnnl_unimplemented; |
62 | bool consistency = true |
63 | && utils::one_of(*transa, 'T', 't', 'N', 'n', 'P', 'p') |
64 | && utils::one_of(*transb, 'T', 't', 'N', 'n', 'P', 'p') && *M >= 0 |
65 | && *N >= 0 && *K >= 0; |
66 | |
67 | if (!consistency) return dnnl_invalid_arguments; |
68 | |
69 | bool is_packed_a = utils::one_of(*transa, 'P', 'p'); |
70 | bool is_packed_b = utils::one_of(*transb, 'P', 'p'); |
71 | bool is_trans_a = utils::one_of(*transa, 'T', 't'); |
72 | bool is_trans_b = utils::one_of(*transb, 'T', 't'); |
73 | dim_t nrow_a = is_trans_a ? *K : *M; |
74 | dim_t nrow_b = is_trans_b ? *N : *K; |
75 | consistency = true && (is_packed_a || *lda >= nstl::max(dim_t(1), nrow_a)) |
76 | && (is_packed_b || *ldb >= nstl::max(dim_t(1), nrow_b)) |
77 | && *ldc >= nstl::max(dim_t(1), *M); |
78 | if (!consistency) return dnnl_invalid_arguments; |
79 | #if DNNL_PPC64 |
80 | #ifdef __MMA__ |
81 | if (!(utils::one_of(*transa, 'n', 'N', 't', 'T') |
82 | && utils::one_of(*transb, 'n', 'N', 't', 'T'))) |
83 | return dnnl_unimplemented; |
84 | #endif |
85 | #endif |
86 | |
87 | return dnnl_success; |
88 | } |
89 | |
90 | dnnl_status_t check_gemm_x8x8x32_input(const char *offsetc, const char *transa, |
91 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
92 | const void *A, const dim_t *lda, const void *B, const dim_t *ldb, |
93 | const void *C, const dim_t *ldc, const float *alpha, const float *beta, |
94 | const bool with_bias) { |
95 | if (offsetc == nullptr) return dnnl_invalid_arguments; |
96 | if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) |
97 | return dnnl_invalid_arguments; |
98 | |
99 | return check_gemm_input(transa, transb, M, N, K, A, lda, B, ldb, C, ldc, |
100 | alpha, beta, with_bias); |
101 | } |
102 | |
103 | dnnl_status_t extended_sgemm(const char *transa, const char *transb, |
104 | const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha, |
105 | const float *A, const dim_t *lda, const float *B, const dim_t *ldb, |
106 | const float *beta, float *C, const dim_t *ldc, const float *bias, |
107 | const bool force_jit_nocopy_gemm) { |
108 | dnnl_status_t status = check_gemm_input(transa, transb, M, N, K, A, lda, B, |
109 | ldb, C, ldc, alpha, beta, bias != nullptr); |
110 | if (status != dnnl_success) return status; |
111 | |
112 | #ifdef USE_CBLAS |
113 | if (!force_jit_nocopy_gemm && utils::one_of(*transa, 'n', 'N', 't', 'T') |
114 | && utils::one_of(*transb, 'n', 'N', 't', 'T')) { |
115 | bool trA = *transa == 't' || *transa == 'T'; |
116 | bool trB = *transb == 't' || *transb == 'T'; |
117 | CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; |
118 | CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; |
119 | cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, *M, *N, *K, *alpha, A, |
120 | *lda, B, *ldb, *beta, C, *ldc); |
121 | if (bias) { |
122 | // Add bias if necessary (bias is applied to columns of C) |
123 | dim_t incx = 1, incy = 1; |
124 | parallel_nd(*N, [&](dim_t n) { |
125 | dim_t offset = n * (*ldc); |
126 | cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); |
127 | }); |
128 | } |
129 | msan_unpoison_matrix(C, *M, *N, *ldc, sizeof(*C)); |
130 | return dnnl_success; |
131 | } |
132 | #endif |
133 | |
134 | #if DNNL_X64 |
135 | if (mayiuse(sse41)) { |
136 | float *dummy_ao = nullptr; |
137 | float *dummy_bo = nullptr; |
138 | return gemm_driver(transa, transb, bias ? "C" : nullptr, M, N, K, alpha, |
139 | A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias, |
140 | force_jit_nocopy_gemm); |
141 | } |
142 | #endif |
143 | |
144 | return ref_gemm<float>( |
145 | transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); |
146 | } |
147 | |
148 | // Tries calling Intel MKL cblas_gemm_s8u8s32 if applicable and available |
149 | dnnl_status_t try_cblas_gemm_s8u8s32(const char *transa, const char *transb, |
150 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
151 | const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao, |
152 | const uint8_t *B, const dim_t *LDB, const uint8_t *bo, |
153 | const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co) { |
154 | #if USE_MKL_IGEMM |
155 | // cblas_gemm_s8u8s32 uses `+` to apply offsets, |
156 | // hence we need to inverse ao and b0. |
157 | if (*ao == -128 || *bo > 128) return dnnl_unimplemented; |
158 | |
159 | assert(-127 <= *ao && *ao <= 127); |
160 | assert(*bo <= 128); |
161 | |
162 | int8_t ao_s8 = -(*ao); |
163 | int8_t bo_s8 = (int8_t)(-(int32_t)*bo); |
164 | |
165 | bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); |
166 | bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); |
167 | bool AisN = (*transa == 'N' || *transa == 'n'); |
168 | bool BisN = (*transb == 'N' || *transb == 'n'); |
169 | |
170 | CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; |
171 | CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; |
172 | CBLAS_OFFSET Cblas_offsetc = OCisR |
173 | ? CblasRowOffset |
174 | : (OCisC ? CblasColOffset : CblasFixOffset); |
175 | cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, *M, |
176 | *N, *K, *alpha, A, *LDA, ao_s8, B, *LDB, bo_s8, *beta, C, *LDC, co); |
177 | msan_unpoison_matrix(C, *M, *N, *LDC, sizeof(*C)); |
178 | return dnnl_success; |
179 | #else |
180 | return dnnl_unimplemented; |
181 | #endif |
182 | } |
183 | |
184 | template <> |
185 | dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb, |
186 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
187 | const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao, |
188 | const uint8_t *B, const dim_t *LDB, const uint8_t *bo, |
189 | const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co) { |
190 | dnnl_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, M, |
191 | N, K, A, LDA, B, LDB, C, LDC, alpha, beta, false); |
192 | if (status != dnnl_success) return status; |
193 | |
194 | if (*M == 0 || *N == 0 || *K == 0) return dnnl_success; |
195 | |
196 | status = try_cblas_gemm_s8u8s32(transa, transb, offsetc, M, N, K, alpha, A, |
197 | LDA, ao, B, LDB, bo, beta, C, LDC, co); |
198 | if (status == dnnl_success) return status; |
199 | |
200 | #if DNNL_X64 |
201 | if (mayiuse(sse41)) |
202 | return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, |
203 | B, LDB, bo, beta, C, LDC, co, false); |
204 | #elif DNNL_PPC64 |
205 | #ifdef __MMA__ |
206 | int ATflag = (*transa == 'T') || (*transa == 't'); |
207 | int BTflag = (*transb == 'T') || (*transb == 't'); |
208 | |
209 | return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha, |
210 | A, *LDA, ao, B, *LDB, bo, C, *beta, *LDC, co, 0); |
211 | #endif |
212 | #endif |
213 | |
214 | return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao, |
215 | B, LDB, bo, beta, C, LDC, co); |
216 | } |
217 | |
// Integer GEMM with signed 8-bit A and signed 8-bit B, accumulating into
// int32 C; ao / bo are the A / B zero-point offsets and co the C offsets
// (mode chosen by offsetc).
//
// Dispatch order:
//   - x64: the GEMM driver on avx512_core. Otherwise, when both offsets are
//     zero (and, without MKL igemm, the CPU has SSE4.1), simple_gemm_s8s8s32.
//   - PPC64 with __MMA__: the POWER MMA kernel (this return is unconditional,
//     so the reference fallback below is unreachable in that build).
//   - otherwise: the reference implementation.
template <>
dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
        const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
        const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
        const int8_t *B, const dim_t *LDB, const int8_t *bo, const float *beta,
        int32_t *C, const dim_t *LDC, const int32_t *co) {
    dnnl_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, M,
            N, K, A, LDA, B, LDB, C, LDC, alpha, beta, false);
    if (status != dnnl_success) return status;

    // Empty problem: return without touching C.
    // NOTE(review): for K == 0 this also skips applying beta and the C
    // offsets, unlike reference BLAS which still computes C := beta * C —
    // confirm this is intended.
    if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;

#if DNNL_X64
    bool use_jit = mayiuse(avx512_core);
    // simple_gemm_s8s8s32 is only used with zero offsets. Without MKL igemm
    // it additionally needs SSE4.1. NOTE(review): presumably it re-expresses
    // the s8s8 problem via an s8u8 kernel (hence the name) — confirm.
    bool use_s8u8 = true
            && utils::everyone_is(0, *ao, *bo) // so far a requirement
            && IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41));

    if (use_jit)
        return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
                B, LDB, bo, beta, C, LDC, co, false);
    else if (use_s8u8)
        return simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, alpha, A,
                LDA, ao, B, LDB, bo, beta, C, LDC, co);
    // Neither condition held: fall through to the reference implementation.
#endif

#if DNNL_PPC64
#ifdef __MMA__
    int ATflag = (*transa == 'T') || (*transa == 't');
    int BTflag = (*transb == 'T') || (*transb == 't');

    // Note that the coercion of "B" and "bo" from int8_t to uint8_t is
    // accompanied by the last parameter being set to "1" instead of "0", as
    // in the analogous call in the int8_t x uint8_t specialization above.
    // This last parameter flags the fact of the coercion, so the called
    // routine can process "B" and "bo" appropriately.

    return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha,
            A, *LDA, ao, (const uint8_t *)B, *LDB, (const uint8_t *)bo, C,
            *beta, *LDC, co, 1);
#endif
#endif

    return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
            B, LDB, bo, beta, C, LDC, co);
}
264 | |
265 | dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb, |
266 | const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha, |
267 | const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, |
268 | const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) { |
269 | dnnl_status_t status = check_gemm_input(transa, transb, M, N, K, A, lda, B, |
270 | ldb, C, ldc, alpha, beta, false); |
271 | if (status != dnnl_success) return status; |
272 | |
273 | #if DNNL_X64 |
274 | char *dummyOffsetC = nullptr; |
275 | bfloat16_t *dummy_ao = nullptr; |
276 | bfloat16_t *dummy_bo = nullptr; |
277 | float *dummy_co = nullptr; |
278 | |
279 | if (mayiuse(avx512_core)) |
280 | return gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha, |
281 | (const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B, |
282 | ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false); |
283 | #elif DNNL_PPC64 |
284 | #if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__) |
285 | bool trA = *transa == 't' || *transa == 'T'; |
286 | bool trB = *transb == 't' || *transb == 'T'; |
287 | CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; |
288 | CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; |
289 | |
290 | cblas_sbgemm(CblasColMajor, Cblas_trA, Cblas_trB, *M, *N, *K, *alpha, |
291 | (const bfloat16 *)A, *lda, (const bfloat16 *)B, *ldb, *beta, C, |
292 | *ldc); |
293 | return dnnl_success; |
294 | #endif |
295 | #endif |
296 | |
297 | return dnnl_unimplemented; |
298 | } |
299 | |
300 | } // namespace cpu |
301 | } // namespace impl |
302 | } // namespace dnnl |
303 | |