1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3* Copyright 2022 IBM Corporation
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#include "oneapi/dnnl/dnnl.h"
19
20#include "common/bfloat16.hpp"
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/dnnl_traits.hpp"
24#include "common/nstl.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/gemm/gemm.hpp"
28#include "cpu/gemm/gemm_msan_unpoison.hpp"
29#include "cpu/gemm/os_blas.hpp"
30
31#include "cpu/gemm/f32/ref_gemm_f32.hpp"
32#include "cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp"
33#include "cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp"
34
35#if DNNL_X64
36#include "cpu/x64/cpu_isa_traits.hpp"
37
38#include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp"
39#include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp"
40
41#include "cpu/x64/gemm/gemm_driver.hpp"
42
43using namespace dnnl::impl::cpu::x64;
44#elif DNNL_PPC64
45#include "cpu/ppc64/ppc64_gemm_driver.hpp"
46using namespace dnnl::impl::cpu::ppc64;
47#endif
48
49namespace dnnl {
50namespace impl {
51namespace cpu {
52
53dnnl_status_t check_gemm_input(const char *transa, const char *transb,
54 const dim_t *M, const dim_t *N, const dim_t *K, const void *A,
55 const dim_t *lda, const void *B, const dim_t *ldb, const void *C,
56 const dim_t *ldc, const float *alpha, const float *beta,
57 const bool with_bias) {
58 if (utils::any_null(
59 transa, transb, M, N, K, A, lda, B, ldb, C, ldc, alpha, beta))
60 return dnnl_invalid_arguments;
61 if (with_bias && *beta != 0) return dnnl_unimplemented;
62 bool consistency = true
63 && utils::one_of(*transa, 'T', 't', 'N', 'n', 'P', 'p')
64 && utils::one_of(*transb, 'T', 't', 'N', 'n', 'P', 'p') && *M >= 0
65 && *N >= 0 && *K >= 0;
66
67 if (!consistency) return dnnl_invalid_arguments;
68
69 bool is_packed_a = utils::one_of(*transa, 'P', 'p');
70 bool is_packed_b = utils::one_of(*transb, 'P', 'p');
71 bool is_trans_a = utils::one_of(*transa, 'T', 't');
72 bool is_trans_b = utils::one_of(*transb, 'T', 't');
73 dim_t nrow_a = is_trans_a ? *K : *M;
74 dim_t nrow_b = is_trans_b ? *N : *K;
75 consistency = true && (is_packed_a || *lda >= nstl::max(dim_t(1), nrow_a))
76 && (is_packed_b || *ldb >= nstl::max(dim_t(1), nrow_b))
77 && *ldc >= nstl::max(dim_t(1), *M);
78 if (!consistency) return dnnl_invalid_arguments;
79#if DNNL_PPC64
80#ifdef __MMA__
81 if (!(utils::one_of(*transa, 'n', 'N', 't', 'T')
82 && utils::one_of(*transb, 'n', 'N', 't', 'T')))
83 return dnnl_unimplemented;
84#endif
85#endif
86
87 return dnnl_success;
88}
89
90dnnl_status_t check_gemm_x8x8x32_input(const char *offsetc, const char *transa,
91 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
92 const void *A, const dim_t *lda, const void *B, const dim_t *ldb,
93 const void *C, const dim_t *ldc, const float *alpha, const float *beta,
94 const bool with_bias) {
95 if (offsetc == nullptr) return dnnl_invalid_arguments;
96 if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
97 return dnnl_invalid_arguments;
98
99 return check_gemm_input(transa, transb, M, N, K, A, lda, B, ldb, C, ldc,
100 alpha, beta, with_bias);
101}
102
103dnnl_status_t extended_sgemm(const char *transa, const char *transb,
104 const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha,
105 const float *A, const dim_t *lda, const float *B, const dim_t *ldb,
106 const float *beta, float *C, const dim_t *ldc, const float *bias,
107 const bool force_jit_nocopy_gemm) {
108 dnnl_status_t status = check_gemm_input(transa, transb, M, N, K, A, lda, B,
109 ldb, C, ldc, alpha, beta, bias != nullptr);
110 if (status != dnnl_success) return status;
111
112#ifdef USE_CBLAS
113 if (!force_jit_nocopy_gemm && utils::one_of(*transa, 'n', 'N', 't', 'T')
114 && utils::one_of(*transb, 'n', 'N', 't', 'T')) {
115 bool trA = *transa == 't' || *transa == 'T';
116 bool trB = *transb == 't' || *transb == 'T';
117 CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
118 CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
119 cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, *M, *N, *K, *alpha, A,
120 *lda, B, *ldb, *beta, C, *ldc);
121 if (bias) {
122 // Add bias if necessary (bias is applied to columns of C)
123 dim_t incx = 1, incy = 1;
124 parallel_nd(*N, [&](dim_t n) {
125 dim_t offset = n * (*ldc);
126 cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
127 });
128 }
129 msan_unpoison_matrix(C, *M, *N, *ldc, sizeof(*C));
130 return dnnl_success;
131 }
132#endif
133
134#if DNNL_X64
135 if (mayiuse(sse41)) {
136 float *dummy_ao = nullptr;
137 float *dummy_bo = nullptr;
138 return gemm_driver(transa, transb, bias ? "C" : nullptr, M, N, K, alpha,
139 A, lda, dummy_ao, B, ldb, dummy_bo, beta, C, ldc, bias,
140 force_jit_nocopy_gemm);
141 }
142#endif
143
144 return ref_gemm<float>(
145 transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
146}
147
148// Tries calling Intel MKL cblas_gemm_s8u8s32 if applicable and available
149dnnl_status_t try_cblas_gemm_s8u8s32(const char *transa, const char *transb,
150 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
151 const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
152 const uint8_t *B, const dim_t *LDB, const uint8_t *bo,
153 const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co) {
154#if USE_MKL_IGEMM
155 // cblas_gemm_s8u8s32 uses `+` to apply offsets,
156 // hence we need to inverse ao and b0.
157 if (*ao == -128 || *bo > 128) return dnnl_unimplemented;
158
159 assert(-127 <= *ao && *ao <= 127);
160 assert(*bo <= 128);
161
162 int8_t ao_s8 = -(*ao);
163 int8_t bo_s8 = (int8_t)(-(int32_t)*bo);
164
165 bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
166 bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
167 bool AisN = (*transa == 'N' || *transa == 'n');
168 bool BisN = (*transb == 'N' || *transb == 'n');
169
170 CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
171 CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
172 CBLAS_OFFSET Cblas_offsetc = OCisR
173 ? CblasRowOffset
174 : (OCisC ? CblasColOffset : CblasFixOffset);
175 cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, *M,
176 *N, *K, *alpha, A, *LDA, ao_s8, B, *LDB, bo_s8, *beta, C, *LDC, co);
177 msan_unpoison_matrix(C, *M, *N, *LDC, sizeof(*C));
178 return dnnl_success;
179#else
180 return dnnl_unimplemented;
181#endif
182}
183
184template <>
185dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
186 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
187 const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
188 const uint8_t *B, const dim_t *LDB, const uint8_t *bo,
189 const float *beta, int32_t *C, const dim_t *LDC, const int32_t *co) {
190 dnnl_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, M,
191 N, K, A, LDA, B, LDB, C, LDC, alpha, beta, false);
192 if (status != dnnl_success) return status;
193
194 if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;
195
196 status = try_cblas_gemm_s8u8s32(transa, transb, offsetc, M, N, K, alpha, A,
197 LDA, ao, B, LDB, bo, beta, C, LDC, co);
198 if (status == dnnl_success) return status;
199
200#if DNNL_X64
201 if (mayiuse(sse41))
202 return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
203 B, LDB, bo, beta, C, LDC, co, false);
204#elif DNNL_PPC64
205#ifdef __MMA__
206 int ATflag = (*transa == 'T') || (*transa == 't');
207 int BTflag = (*transb == 'T') || (*transb == 't');
208
209 return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha,
210 A, *LDA, ao, B, *LDB, bo, C, *beta, *LDC, co, 0);
211#endif
212#endif
213
214 return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
215 B, LDB, bo, beta, C, LDC, co);
216}
217
218template <>
219dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
220 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
221 const float *alpha, const int8_t *A, const dim_t *LDA, const int8_t *ao,
222 const int8_t *B, const dim_t *LDB, const int8_t *bo, const float *beta,
223 int32_t *C, const dim_t *LDC, const int32_t *co) {
224 dnnl_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, M,
225 N, K, A, LDA, B, LDB, C, LDC, alpha, beta, false);
226 if (status != dnnl_success) return status;
227
228 if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;
229
230#if DNNL_X64
231 bool use_jit = mayiuse(avx512_core);
232 bool use_s8u8 = true
233 && utils::everyone_is(0, *ao, *bo) // so far a requirement
234 && IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41));
235
236 if (use_jit)
237 return gemm_driver(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
238 B, LDB, bo, beta, C, LDC, co, false);
239 else if (use_s8u8)
240 return simple_gemm_s8s8s32(transa, transb, offsetc, M, N, K, alpha, A,
241 LDA, ao, B, LDB, bo, beta, C, LDC, co);
242#endif
243
244#if DNNL_PPC64
245#ifdef __MMA__
246 int ATflag = (*transa == 'T') || (*transa == 't');
247 int BTflag = (*transb == 'T') || (*transb == 't');
248
249 // Note please that the coercion of "B" and "bo" from int8_t to uint8_t is
250 // accompanied by the last parameter being set to "1" instead of "0", as
251 // in the analogous call in the previous routine above.
252 // This last parameter flags the fact of the coercion, so the called routine
253 // can process "B" and "bo" appropriately.
254
255 return cblas_gemm_s8x8s32_ppc64(ATflag, BTflag, offsetc, *M, *N, *K, *alpha,
256 A, *LDA, ao, (const uint8_t *)B, *LDB, (const uint8_t *)bo, C,
257 *beta, *LDC, co, 1);
258#endif
259#endif
260
261 return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, alpha, A, LDA, ao,
262 B, LDB, bo, beta, C, LDC, co);
263}
264
265dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
266 const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha,
267 const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B,
268 const dim_t *ldb, const float *beta, float *C, const dim_t *ldc) {
269 dnnl_status_t status = check_gemm_input(transa, transb, M, N, K, A, lda, B,
270 ldb, C, ldc, alpha, beta, false);
271 if (status != dnnl_success) return status;
272
273#if DNNL_X64
274 char *dummyOffsetC = nullptr;
275 bfloat16_t *dummy_ao = nullptr;
276 bfloat16_t *dummy_bo = nullptr;
277 float *dummy_co = nullptr;
278
279 if (mayiuse(avx512_core))
280 return gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha,
281 (const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B,
282 ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false);
283#elif DNNL_PPC64
284#if defined(USE_CBLAS) && defined(BLAS_HAS_SBGEMM) && defined(__MMA__)
285 bool trA = *transa == 't' || *transa == 'T';
286 bool trB = *transb == 't' || *transb == 'T';
287 CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
288 CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
289
290 cblas_sbgemm(CblasColMajor, Cblas_trA, Cblas_trB, *M, *N, *K, *alpha,
291 (const bfloat16 *)A, *lda, (const bfloat16 *)B, *ldb, *beta, C,
292 *ldc);
293 return dnnl_success;
294#endif
295#endif
296
297 return dnnl_unimplemented;
298}
299
300} // namespace cpu
301} // namespace impl
302} // namespace dnnl
303