1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl_types.h" |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | |
22 | #include "cpu/gemm/gemm.hpp" |
23 | #include "cpu/gemm/gemm_pack.hpp" |
24 | #include "cpu/gemm/os_blas.hpp" |
25 | |
26 | #include "cpu/x64/cpu_isa_traits.hpp" |
27 | |
28 | #include "cpu/x64/gemm/gemm_driver.hpp" |
29 | #include "cpu/x64/gemm/gemm_utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | bool pack_sgemm_supported() { |
37 | #if USE_MKL_PACKED_GEMM |
38 | return true; |
39 | #else |
40 | return mayiuse(sse41); |
41 | #endif |
42 | } |
43 | |
44 | bool pack_gemm_bf16bf16f32_supported() { |
45 | return mayiuse(avx512_core); |
46 | } |
47 | |
#if USE_MKL_PACKED_GEMM
// Translate oneDNN-style single-character flags into MKL CBLAS enumerators.

/// 'a'/'A' selects matrix A; any other character selects matrix B.
static inline CBLAS_IDENTIFIER cblas_identifier(const char *identifier) {
    const bool is_a = utils::one_of(*identifier, 'a', 'A');
    return is_a ? CblasAMatrix : CblasBMatrix;
}

/// 'n'/'N' means no transpose; any other character means transposed.
static inline CBLAS_TRANSPOSE cblas_transpose(const char *trans) {
    if (utils::one_of(*trans, 'n', 'N')) return CblasNoTrans;
    return CblasTrans;
}

/// Like cblas_transpose, but any character other than N/n/T/t marks the
/// operand as already packed (CblasPacked).
static inline MKL_INT cblas_storage(const char *trans) {
    const char flag = *trans;
    if (flag == 'N' || flag == 'n') return CblasNoTrans;
    if (flag == 'T' || flag == 't') return CblasTrans;
    return CblasPacked;
}

/// 'R'/'r' -> row offset, 'C'/'c' -> column offset, anything else -> fixed.
static inline CBLAS_OFFSET cblas_offset(const char *offset) {
    const char flag = *offset;
    if (flag == 'R' || flag == 'r') return CblasRowOffset;
    if (flag == 'C' || flag == 'c') return CblasColOffset;
    return CblasFixOffset;
}
#endif
77 | |
78 | #if !USE_MKL_PACKED_GEMM |
79 | template <typename a_dt, typename b_dt> |
80 | static inline bool use_reference_igemm(void) { |
81 | constexpr bool is_s8u8 = true |
82 | && data_traits<a_dt>::data_type == data_type::s8 |
83 | && data_traits<b_dt>::data_type == data_type::u8; |
84 | if (is_s8u8) |
85 | return !mayiuse(sse41); |
86 | else |
87 | return !mayiuse(avx512_core); |
88 | } |
89 | |
90 | #else |
91 | template <typename a_dt, typename b_dt> |
92 | static inline bool use_reference_igemm(void) { |
93 | return true; |
94 | } |
95 | #endif |
96 | |
97 | template <typename T> |
98 | static bool is_good_ld(dim_t ld) { |
99 | static constexpr auto align = 64 / sizeof(T); |
100 | static constexpr auto no_align = 2048 / sizeof(T); |
101 | |
102 | return ((ld % align) == 0) && ((ld % no_align) != 0); |
103 | } |
104 | |
105 | static dnnl_status_t check_pack_get_size_input(const char *identifier, |
106 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
107 | const dim_t *K, const dim_t *lda, const dim_t *ldb) { |
108 | |
109 | if (utils::any_null(identifier, transa, transb, M, N, K, lda, ldb)) |
110 | return dnnl_invalid_arguments; |
111 | |
112 | bool is_transa = utils::one_of(*transa, 'T', 't'); |
113 | bool is_transb = utils::one_of(*transb, 'T', 't'); |
114 | |
115 | bool ok = true && utils::one_of(*transa, 'T', 't', 'N', 'n') |
116 | && utils::one_of(*transb, 'T', 't', 'N', 'n') |
117 | && utils::one_of(*identifier, 'A', 'a', 'B', 'b') && *M >= 0 |
118 | && *N >= 0 && *K >= 0 |
119 | && *lda >= nstl::max(dim_t(1), !is_transa ? *M : *K) |
120 | && *ldb >= nstl::max(dim_t(1), !is_transb ? *K : *N); |
121 | |
122 | if (!ok) return dnnl_invalid_arguments; |
123 | |
124 | return dnnl_success; |
125 | } |
126 | |
127 | static dnnl_status_t check_pack_input(const char *identifier, |
128 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
129 | const dim_t *K, const float *alpha, const dim_t *lda, const dim_t *ldb, |
130 | const void *src, void *dst) { |
131 | if (utils::any_null(src, dst, alpha)) return dnnl_invalid_arguments; |
132 | |
133 | return check_pack_get_size_input( |
134 | identifier, transa, transb, M, N, K, lda, ldb); |
135 | } |
136 | |
137 | template <typename a_dt, typename b_dt, typename c_dt> |
138 | static dnnl_status_t gemm_pack_driver(const char *identifier, |
139 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
140 | const dim_t *K, const float *alpha, const dim_t *lda, const dim_t *ldb, |
141 | const void *src, gemm_pack_storage_t *pack_dst, bool measure_only) { |
142 | |
143 | a_dt oa = 0; |
144 | b_dt ob = 0; |
145 | |
146 | const a_dt *a = nullptr; |
147 | const b_dt *b = nullptr; |
148 | pack_type packing; |
149 | |
150 | if (utils::one_of(*identifier, 'a', 'A')) { |
151 | a = (const a_dt *)src; |
152 | packing = pack_type::pack_a; |
153 | } else { |
154 | b = (const b_dt *)src; |
155 | packing = pack_type::pack_b; |
156 | } |
157 | |
158 | return gemm_driver<a_dt, b_dt, c_dt>(transa, transb, "N" , M, N, K, alpha, a, |
159 | lda, &oa, b, ldb, &ob, nullptr, nullptr, nullptr, nullptr, false, |
160 | packing, pack_dst, measure_only); |
161 | } |
162 | |
163 | dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa, |
164 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
165 | const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack) { |
166 | |
167 | if (!pack_sgemm_supported()) return dnnl_unimplemented; |
168 | |
169 | dnnl_status_t result; |
170 | *size = 0; |
171 | if (pack) *pack = true; |
172 | |
173 | result = check_pack_get_size_input( |
174 | identifier, transa, transb, M, N, K, lda, ldb); |
175 | if (result != dnnl_success) return result; |
176 | |
177 | #if USE_MKL_PACKED_GEMM |
178 | *size = cblas_sgemm_pack_get_size(cblas_identifier(identifier), *M, *N, *K); |
179 | #else |
180 | bool do_a = utils::one_of(*identifier, 'a', 'A'); |
181 | float alpha = 1.0f; |
182 | gemm_pack_storage_shell_t shell {dnnl_get_max_threads()}; |
183 | if (!shell.get()) return dnnl_out_of_memory; |
184 | |
185 | result = gemm_pack_driver<float, float, float>(identifier, transa, transb, |
186 | M, N, K, &alpha, lda, ldb, nullptr, &shell, true); |
187 | if (result != dnnl_success) return result; |
188 | |
189 | *size = shell.size(); |
190 | if (pack) { |
191 | *pack = !(shell.single_nocopy() |
192 | && utils::one_of(do_a ? *transa : *transb, 'n', 'N') |
193 | && is_good_ld<float>(do_a ? *lda : *ldb)); |
194 | } |
195 | #endif |
196 | |
197 | return dnnl_success; |
198 | } |
199 | |
200 | dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier, |
201 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
202 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
203 | bool *pack) { |
204 | |
205 | if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented; |
206 | |
207 | dnnl_status_t result; |
208 | *size = 0; |
209 | if (pack) *pack = true; |
210 | |
211 | result = check_pack_get_size_input( |
212 | identifier, transa, transb, M, N, K, lda, ldb); |
213 | if (result != dnnl_success) return result; |
214 | |
215 | float alpha = 1.0f; |
216 | gemm_pack_storage_shell_t shell {dnnl_get_max_threads()}; |
217 | if (!shell.get()) return dnnl_out_of_memory; |
218 | |
219 | result = gemm_pack_driver<bfloat16_t, bfloat16_t, float>(identifier, transa, |
220 | transb, M, N, K, &alpha, lda, ldb, nullptr, &shell, true); |
221 | if (result != dnnl_success) return result; |
222 | |
223 | *size = shell.size(); |
224 | |
225 | return dnnl_success; |
226 | } |
227 | |
228 | template <typename a_dt, typename b_dt> |
229 | dnnl_status_t gemm_x8x8s32_pack_get_size(const char *identifier, |
230 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
231 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
232 | bool *pack) { |
233 | |
234 | dnnl_status_t result; |
235 | *size = 0; |
236 | if (pack) *pack = true; |
237 | |
238 | result = check_pack_get_size_input( |
239 | identifier, transa, transb, M, N, K, lda, ldb); |
240 | if (result != dnnl_success) return result; |
241 | |
242 | #if USE_MKL_PACKED_GEMM |
243 | constexpr bool is_s8u8 = true |
244 | && data_traits<a_dt>::data_type == data_type::s8 |
245 | && data_traits<b_dt>::data_type == data_type::u8; |
246 | |
247 | if (is_s8u8) { |
248 | *size = cblas_gemm_s8u8s32_pack_get_size( |
249 | cblas_identifier(identifier), *M, *N, *K); |
250 | return dnnl_success; |
251 | } |
252 | #endif |
253 | |
254 | bool do_a = utils::one_of(*identifier, 'a', 'A'); |
255 | float alpha = 1.0f; |
256 | gemm_pack_storage_shell_t shell {dnnl_get_max_threads(), do_a, !do_a}; |
257 | if (!shell.get()) return dnnl_out_of_memory; |
258 | |
259 | if (!use_reference_igemm<a_dt, b_dt>()) { |
260 | result = gemm_pack_driver<a_dt, b_dt, int32_t>(identifier, transa, |
261 | transb, M, N, K, &alpha, lda, ldb, nullptr, &shell, true); |
262 | if (result != dnnl_success) return result; |
263 | } else { |
264 | auto rows = do_a ? *M : *K; |
265 | auto cols = do_a ? *K : *N; |
266 | if (do_a) { |
267 | gemm_utils::prep_gemm_pack<int8_t, int32_t>( |
268 | do_a, no_trans, rows, cols, &shell); |
269 | } else { |
270 | gemm_utils::prep_gemm_pack<uint8_t, int32_t>( |
271 | do_a, no_trans, rows, cols, &shell); |
272 | } |
273 | } |
274 | |
275 | *size = shell.size(); |
276 | if (pack) { |
277 | *pack = !(shell.single_nocopy() |
278 | && utils::one_of(do_a ? *transa : *transb, 'n', 'N') |
279 | && is_good_ld<float>(do_a ? *lda : *ldb)); |
280 | } |
281 | |
282 | return dnnl_success; |
283 | } |
284 | |
285 | dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier, |
286 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
287 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
288 | bool *pack) { |
289 | |
290 | return gemm_x8x8s32_pack_get_size<int8_t, uint8_t>( |
291 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
292 | } |
293 | |
294 | dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier, |
295 | const char *transa, const char *transb, const dim_t *M, const dim_t *N, |
296 | const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size, |
297 | bool *pack) { |
298 | |
299 | return gemm_x8x8s32_pack_get_size<int8_t, int8_t>( |
300 | identifier, transa, transb, M, N, K, lda, ldb, size, pack); |
301 | } |
302 | |
303 | dnnl_status_t sgemm_pack(const char *identifier, const char *transa, |
304 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
305 | const dim_t *lda, const dim_t *ldb, const float *src, float *dst) { |
306 | float one = 1.f, *alpha = &one; |
307 | |
308 | if (!pack_sgemm_supported()) return dnnl_unimplemented; |
309 | |
310 | auto result = check_pack_input( |
311 | identifier, transa, transb, M, N, K, alpha, lda, ldb, src, dst); |
312 | if (result != dnnl_success) return result; |
313 | |
314 | #if USE_MKL_PACKED_GEMM |
315 | auto cblas_id = cblas_identifier(identifier); |
316 | auto ld = (cblas_id == CblasAMatrix) ? *lda : *ldb; |
317 | auto trans = (cblas_id == CblasAMatrix) ? transa : transb; |
318 | cblas_sgemm_pack(CblasColMajor, cblas_id, cblas_transpose(trans), *M, *N, |
319 | *K, *alpha, src, ld, dst); |
320 | return dnnl_success; |
321 | #else |
322 | gemm_pack_storage_t pack_dst(dst, false); |
323 | |
324 | return gemm_pack_driver<float, float, float>(identifier, transa, transb, M, |
325 | N, K, alpha, lda, ldb, src, &pack_dst, false); |
326 | #endif |
327 | } |
328 | |
329 | dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa, |
330 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
331 | const dim_t *lda, const dim_t *ldb, const bfloat16_t *src, |
332 | bfloat16_t *dst) { |
333 | float one = 1.f, *alpha = &one; |
334 | |
335 | if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented; |
336 | |
337 | auto result = check_pack_input( |
338 | identifier, transa, transb, M, N, K, alpha, lda, ldb, src, dst); |
339 | if (result != dnnl_success) return result; |
340 | |
341 | gemm_pack_storage_t pack_dst(dst, false); |
342 | |
343 | return gemm_pack_driver<bfloat16_t, bfloat16_t, float>(identifier, transa, |
344 | transb, M, N, K, alpha, lda, ldb, src, &pack_dst, false); |
345 | } |
346 | |
347 | template <typename a_dt, typename b_dt> |
348 | dnnl_status_t gemm_x8x8s32_pack(const char *identifier, const char *transa, |
349 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
350 | const dim_t *lda, const dim_t *ldb, const void *src_void, void *dst) { |
351 | |
352 | float alpha = 1.0f; // Not used with igemm. |
353 | auto result = check_pack_input(identifier, transa, transb, M, N, K, &alpha, |
354 | lda, ldb, src_void, dst); |
355 | if (result != dnnl_success) return result; |
356 | |
357 | #if USE_MKL_PACKED_GEMM |
358 | constexpr bool is_s8u8 = true |
359 | && data_traits<a_dt>::data_type == data_type::s8 |
360 | && data_traits<b_dt>::data_type == data_type::u8; |
361 | |
362 | if (is_s8u8) { |
363 | auto cblas_id = cblas_identifier(identifier); |
364 | auto ld = (cblas_id == CblasAMatrix) ? *lda : *ldb; |
365 | auto trans = (cblas_id == CblasAMatrix) ? transa : transb; |
366 | cblas_gemm_s8u8s32_pack(CblasColMajor, cblas_id, cblas_transpose(trans), |
367 | *M, *N, *K, src_void, ld, dst); |
368 | return dnnl_success; |
369 | } |
370 | #endif |
371 | gemm_pack_storage_t pack_dst(dst, false); |
372 | |
373 | if (!use_reference_igemm<a_dt, b_dt>()) { |
374 | return gemm_pack_driver<a_dt, b_dt, int32_t>(identifier, transa, transb, |
375 | M, N, K, &alpha, lda, ldb, src_void, &pack_dst, false); |
376 | } else { |
377 | bool do_a = utils::one_of(*identifier, 'a', 'A'); |
378 | bool is_trans = utils::one_of(do_a ? *transa : *transb, 't', 'T'); |
379 | auto ld = do_a ? *lda : *ldb; |
380 | auto rows = do_a ? *M : *K; |
381 | auto cols = do_a ? *K : *N; |
382 | |
383 | if (do_a) { |
384 | gemm_utils::prep_gemm_pack<int8_t, int32_t>( |
385 | do_a, no_trans, rows, cols, &pack_dst); |
386 | auto src = reinterpret_cast<const int8_t *>(src_void); |
387 | return gemm_utils::pack_no_copy( |
388 | src, ld, rows, cols, is_trans, alpha, &pack_dst); |
389 | } else { |
390 | gemm_utils::prep_gemm_pack<uint8_t, int32_t>( |
391 | do_a, no_trans, rows, cols, &pack_dst); |
392 | auto src = reinterpret_cast<const uint8_t *>(src_void); |
393 | return gemm_utils::pack_no_copy( |
394 | src, ld, rows, cols, is_trans, alpha, &pack_dst); |
395 | } |
396 | } |
397 | } |
398 | |
399 | dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa, |
400 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
401 | const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { |
402 | |
403 | return gemm_x8x8s32_pack<int8_t, uint8_t>( |
404 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
405 | } |
406 | |
407 | dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa, |
408 | const char *transb, const dim_t *M, const dim_t *N, const dim_t *K, |
409 | const dim_t *lda, const dim_t *ldb, const void *src, void *dst) { |
410 | |
411 | return gemm_x8x8s32_pack<int8_t, int8_t>( |
412 | identifier, transa, transb, M, N, K, lda, ldb, src, dst); |
413 | } |
414 | |
415 | dnnl_status_t sgemm_compute(const char *transa, const char *transb, |
416 | const dim_t *M, const dim_t *N, const dim_t *K, const float *A, |
417 | const dim_t *lda, const float *B, const dim_t *ldb, const float *beta, |
418 | float *C, const dim_t *ldc) { |
419 | |
420 | #if USE_MKL_PACKED_GEMM |
421 | if (utils::any_null(transa, transb, M, N, K, A, lda, B, ldb, beta, C, ldc)) |
422 | return dnnl_invalid_arguments; |
423 | cblas_sgemm_compute(CblasColMajor, cblas_storage(transa), |
424 | cblas_storage(transb), *M, *N, *K, A, *lda, B, *ldb, *beta, C, |
425 | *ldc); |
426 | return dnnl_success; |
427 | #else |
428 | if (!pack_sgemm_supported()) return dnnl_unimplemented; |
429 | |
430 | float one = 1.0f; |
431 | |
432 | return extended_sgemm( |
433 | transa, transb, M, N, K, &one, A, lda, B, ldb, beta, C, ldc); |
434 | #endif |
435 | } |
436 | |
437 | dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb, |
438 | const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A, |
439 | const dim_t *lda, const bfloat16_t *B, const dim_t *ldb, |
440 | const float *beta, float *C, const dim_t *ldc) { |
441 | |
442 | if (!pack_gemm_bf16bf16f32_supported()) return dnnl_unimplemented; |
443 | |
444 | float one = 1.0f; |
445 | |
446 | return gemm_bf16bf16f32( |
447 | transa, transb, M, N, K, &one, A, lda, B, ldb, beta, C, ldc); |
448 | } |
449 | |
450 | template <typename a_dt, typename b_dt> |
451 | dnnl_status_t gemm_x8x8s32_compute(const char *transa, const char *transb, |
452 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
453 | const a_dt *A, const dim_t *lda, const b_dt *B, const dim_t *ldb, |
454 | const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { |
455 | |
456 | const float one = 1.f, *alpha = &one; |
457 | const a_dt zero_a_dt = 0, *ao = &zero_a_dt; |
458 | const b_dt zero_b_dt = 0, *bo = &zero_b_dt; |
459 | |
460 | #if USE_MKL_PACKED_GEMM |
461 | constexpr bool is_s8u8 = true |
462 | && data_traits<a_dt>::data_type == data_type::s8 |
463 | && data_traits<b_dt>::data_type == data_type::u8; |
464 | |
465 | if (is_s8u8) { |
466 | if (utils::any_null(transa, transb, offsetc, M, N, K, alpha, A, lda, ao, |
467 | B, ldb, bo, beta, C, ldc, co)) |
468 | return dnnl_invalid_arguments; |
469 | cblas_gemm_s8u8s32_compute(CblasColMajor, cblas_storage(transa), |
470 | cblas_storage(transb), cblas_offset(offsetc), *M, *N, *K, |
471 | *alpha, A, *lda, *ao, B, *ldb, *bo, *beta, C, *ldc, co); |
472 | return dnnl_success; |
473 | } |
474 | #endif |
475 | auto lda_eff = *lda, ldb_eff = *ldb; |
476 | auto transa_eff = *transa, transb_eff = *transb; |
477 | |
478 | if (!use_reference_igemm<a_dt, b_dt>()) { |
479 | return gemm_s8x8s32(&transa_eff, &transb_eff, offsetc, M, N, K, alpha, |
480 | A, &lda_eff, ao, B, &ldb_eff, bo, beta, C, ldc, co); |
481 | } else { |
482 | dim_t ld, td; |
483 | |
484 | if (transa_eff == 'p' || transa_eff == 'P') { |
485 | gemm_pack_storage_t a_packed {A}; |
486 | int trans; |
487 | if (!a_packed.get_nocopy(trans, ld, td)) |
488 | return dnnl_invalid_arguments; |
489 | A = a_packed.matrix<a_dt>(); |
490 | lda_eff = ld; |
491 | transa_eff = trans == no_trans ? 'N' : 'T'; |
492 | } |
493 | |
494 | if (transb_eff == 'p' || transb_eff == 'P') { |
495 | gemm_pack_storage_t b_packed {B}; |
496 | int trans; |
497 | if (!b_packed.get_nocopy(trans, ld, td)) |
498 | return dnnl_invalid_arguments; |
499 | B = b_packed.matrix<b_dt>(); |
500 | ldb_eff = ld; |
501 | transb_eff = trans == no_trans ? 'N' : 'T'; |
502 | } |
503 | |
504 | return gemm_s8x8s32(&transa_eff, &transb_eff, offsetc, M, N, K, alpha, |
505 | A, &lda_eff, ao, B, &ldb_eff, bo, beta, C, ldc, co); |
506 | } |
507 | } |
508 | |
509 | dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb, |
510 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
511 | const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb, |
512 | const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { |
513 | |
514 | return gemm_x8x8s32_compute( |
515 | transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); |
516 | } |
517 | |
518 | dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb, |
519 | const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K, |
520 | const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb, |
521 | const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co) { |
522 | |
523 | return gemm_x8x8s32_compute( |
524 | transa, transb, offsetc, M, N, K, A, lda, B, ldb, beta, C, ldc, co); |
525 | } |
526 | |
527 | } // namespace x64 |
528 | } // namespace cpu |
529 | } // namespace impl |
530 | } // namespace dnnl |
531 | |