1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_GEMM_GEMM_PACK_HPP
18#define CPU_X64_GEMM_GEMM_PACK_HPP
19
20#include "oneapi/dnnl/dnnl_config.h"
21#include "oneapi/dnnl/dnnl_types.h"
22
23#include "common/bfloat16.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
// Capability query for the f32 (sgemm) packed-GEMM entry points.
// Takes no arguments and returns a bool; only the declaration is in this
// header, the definition lives elsewhere.
// NOTE(review): presumably `true` means the sgemm_pack* / sgemm_compute
// routines can be used on the current CPU - confirm against the definition.
30bool pack_sgemm_supported();
// Capability query for the bf16 x bf16 -> f32 packed-GEMM entry points.
// Takes no arguments and returns a bool; only the declaration is in this
// header, the definition lives elsewhere.
// NOTE(review): presumably `true` means the gemm_bf16bf16f32_pack* /
// gemm_bf16bf16f32_compute routines are usable on the current CPU - confirm
// against the definition.
31bool pack_gemm_bf16bf16f32_supported();
32
// Size query for packing a matrix for f32 GEMM (declaration only).
//
// All arguments are passed by pointer:
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   size                       - size_t * output.
//   pack                       - bool * output.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` chooses which operand (A or B) is
// being packed, `*size` receives the required packed-buffer size in bytes and
// `*pack` reports whether packing should actually be performed - confirm
// against the implementation.
33dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa,
34 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
35 const dim_t *lda, const dim_t *ldb, size_t *size, bool *pack);
36
// Size query for packing a matrix for bf16 x bf16 -> f32 GEMM (declaration
// only).
//
// All arguments are passed by pointer:
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   size                       - size_t * output.
//   pack                       - bool * output.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` chooses which operand (A or B) is
// being packed, `*size` receives the required packed-buffer size in bytes and
// `*pack` reports whether packing should actually be performed - confirm
// against the implementation.
37dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier,
38 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
39 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
40 bool *pack);
41
// Size query for packing a matrix for s8 x u8 -> s32 GEMM (declaration only).
//
// All arguments are passed by pointer:
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   size                       - size_t * output.
//   pack                       - bool * output.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` chooses which operand (A or B) is
// being packed, `*size` receives the required packed-buffer size in bytes and
// `*pack` reports whether packing should actually be performed - confirm
// against the implementation.
42dnnl_status_t gemm_s8u8s32_pack_get_size(const char *identifier,
43 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
44 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
45 bool *pack);
46
// Size query for packing a matrix for s8 x s8 -> s32 GEMM (declaration only).
//
// All arguments are passed by pointer:
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   size                       - size_t * output.
//   pack                       - bool * output.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` chooses which operand (A or B) is
// being packed, `*size` receives the required packed-buffer size in bytes and
// `*pack` reports whether packing should actually be performed - confirm
// against the implementation.
47dnnl_status_t gemm_s8s8s32_pack_get_size(const char *identifier,
48 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
49 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
50 bool *pack);
51
// Packs an f32 matrix for later use by the f32 GEMM routines (declaration
// only).
//
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   src                        - read-only float input.
//   dst                        - writable float output buffer.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` selects whether `src` is the A or the
// B operand and `dst` must be at least as large as the size reported by
// sgemm_pack_get_size for the same arguments - confirm against the
// implementation.
52dnnl_status_t sgemm_pack(const char *identifier, const char *transa,
53 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
54 const dim_t *lda, const dim_t *ldb, const float *src, float *dst);
55
// Packs a bfloat16 matrix for later use by the bf16 x bf16 -> f32 GEMM
// routines (declaration only).
//
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   src                        - read-only bfloat16_t input.
//   dst                        - writable bfloat16_t output buffer.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` selects whether `src` is the A or the
// B operand and `dst` must be at least as large as the size reported by
// gemm_bf16bf16f32_pack_get_size for the same arguments - confirm against the
// implementation.
56dnnl_status_t gemm_bf16bf16f32_pack(const char *identifier, const char *transa,
57 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
58 const dim_t *lda, const dim_t *ldb, const bfloat16_t *src,
59 bfloat16_t *dst);
60
// Packs a matrix for later use by the s8 x u8 -> s32 GEMM routines
// (declaration only).
//
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   src                        - read-only untyped (const void *) input.
//   dst                        - writable untyped (void *) output buffer.
// Returns a dnnl_status_t.
//
// NOTE(review): src/dst are untyped here, presumably because the element
// type (int8_t for A, uint8_t for B) depends on which operand `identifier`
// selects - confirm against the implementation. `dst` is presumably sized via
// gemm_s8u8s32_pack_get_size.
61dnnl_status_t gemm_s8u8s32_pack(const char *identifier, const char *transa,
62 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
63 const dim_t *lda, const dim_t *ldb, const void *src, void *dst);
64
// Packs a matrix for later use by the s8 x s8 -> s32 GEMM routines
// (declaration only).
//
//   identifier, transa, transb - const char * selectors.
//   M, N, K                    - const dim_t * problem dimensions.
//   lda, ldb                   - const dim_t * leading dimensions.
//   src                        - read-only untyped (const void *) input.
//   dst                        - writable untyped (void *) output buffer.
// Returns a dnnl_status_t.
//
// NOTE(review): presumably `identifier` selects whether `src` is the A or the
// B operand and `dst` is sized via gemm_s8s8s32_pack_get_size - confirm
// against the implementation.
65dnnl_status_t gemm_s8s8s32_pack(const char *identifier, const char *transa,
66 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
67 const dim_t *lda, const dim_t *ldb, const void *src, void *dst);
68
// f32 GEMM compute entry point of the pack API (declaration only).
//
//   transa, transb - const char * selectors for A and B.
//   M, N, K        - const dim_t * problem dimensions.
//   A, lda         - read-only float operand and its leading dimension.
//   B, ldb         - read-only float operand and its leading dimension.
//   beta           - const float * scalar.
//   C, ldc         - writable float result matrix and its leading dimension.
// Returns a dnnl_status_t.
//
// Unlike a classic BLAS sgemm signature, no `alpha` parameter is present.
// NOTE(review): presumably computes C = op(A) * op(B) + beta * C, where A
// and/or B may be buffers produced by sgemm_pack; how a packed operand is
// signalled (likely through transa/transb) is not visible here - confirm
// against the implementation.
69dnnl_status_t sgemm_compute(const char *transa, const char *transb,
70 const dim_t *M, const dim_t *N, const dim_t *K, const float *A,
71 const dim_t *lda, const float *B, const dim_t *ldb, const float *beta,
72 float *C, const dim_t *ldc);
73
// bf16 x bf16 -> f32 GEMM compute entry point of the pack API (declaration
// only).
//
//   transa, transb - const char * selectors for A and B.
//   M, N, K        - const dim_t * problem dimensions.
//   A, lda         - read-only bfloat16_t operand and its leading dimension.
//   B, ldb         - read-only bfloat16_t operand and its leading dimension.
//   beta           - const float * scalar.
//   C, ldc         - writable float result matrix and its leading dimension.
// Returns a dnnl_status_t.
//
// No `alpha` parameter is present in this signature.
// NOTE(review): presumably computes C = op(A) * op(B) + beta * C, where A
// and/or B may be buffers produced by gemm_bf16bf16f32_pack; how a packed
// operand is signalled (likely through transa/transb) is not visible here -
// confirm against the implementation.
74dnnl_status_t gemm_bf16bf16f32_compute(const char *transa, const char *transb,
75 const dim_t *M, const dim_t *N, const dim_t *K, const bfloat16_t *A,
76 const dim_t *lda, const bfloat16_t *B, const dim_t *ldb,
77 const float *beta, float *C, const dim_t *ldc);
78
// s8 x u8 -> s32 GEMM compute entry point of the pack API (declaration only).
//
//   transa, transb - const char * selectors for A and B.
//   offsetc        - const char * selector related to the `co` offsets.
//   M, N, K        - const dim_t * problem dimensions.
//   A, lda         - read-only int8_t operand and its leading dimension.
//   B, ldb         - read-only uint8_t operand and its leading dimension.
//   beta           - const float * scalar.
//   C, ldc         - writable int32_t result matrix and its leading dimension.
//   co             - read-only int32_t offset data.
// Returns a dnnl_status_t.
//
// No `alpha` and no A/B zero-point parameters are present in this signature.
// NOTE(review): presumably computes C = op(A) * op(B) + beta * C + co, with
// `offsetc` choosing how `co` is broadcast (fixed / per-row / per-column),
// and A and/or B possibly being buffers produced by gemm_s8u8s32_pack -
// confirm against the implementation.
79dnnl_status_t gemm_s8u8s32_compute(const char *transa, const char *transb,
80 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
81 const int8_t *A, const dim_t *lda, const uint8_t *B, const dim_t *ldb,
82 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co);
83
// s8 x s8 -> s32 GEMM compute entry point of the pack API (declaration only).
//
//   transa, transb - const char * selectors for A and B.
//   offsetc        - const char * selector related to the `co` offsets.
//   M, N, K        - const dim_t * problem dimensions.
//   A, lda         - read-only int8_t operand and its leading dimension.
//   B, ldb         - read-only int8_t operand and its leading dimension.
//   beta           - const float * scalar.
//   C, ldc         - writable int32_t result matrix and its leading dimension.
//   co             - read-only int32_t offset data.
// Returns a dnnl_status_t.
//
// No `alpha` and no A/B zero-point parameters are present in this signature.
// NOTE(review): presumably computes C = op(A) * op(B) + beta * C + co, with
// `offsetc` choosing how `co` is broadcast (fixed / per-row / per-column),
// and A and/or B possibly being buffers produced by gemm_s8s8s32_pack -
// confirm against the implementation.
84dnnl_status_t gemm_s8s8s32_compute(const char *transa, const char *transb,
85 const char *offsetc, const dim_t *M, const dim_t *N, const dim_t *K,
86 const int8_t *A, const dim_t *lda, const int8_t *B, const dim_t *ldb,
87 const float *beta, int32_t *C, const dim_t *ldc, const int32_t *co);
88
89} // namespace x64
90} // namespace cpu
91} // namespace impl
92} // namespace dnnl
93
94#endif // CPU_X64_GEMM_GEMM_PACK_HPP
95