/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

17#ifndef CPU_GEMM_GEMM_PACK_HPP
18#define CPU_GEMM_GEMM_PACK_HPP
19
20#include "oneapi/dnnl/dnnl_config.h"
21#include "oneapi/dnnl/dnnl_types.h"
22
23#include "common/bfloat16.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28
29bool pack_sgemm_supported();
30bool pack_gemm_bf16bf16f32_supported();
31
32dnnl_status_t DNNL_API sgemm_pack_get_size(const char *identifier,
33 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
34 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
35 bool *pack = nullptr);
36
37dnnl_status_t DNNL_API gemm_bf16bf16f32_pack_get_size(const char *identifier,
38 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
39 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
40 bool *pack = nullptr);
41
42dnnl_status_t DNNL_API gemm_s8u8s32_pack_get_size(const char *identifier,
43 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
44 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
45 bool *pack = nullptr);
46
47dnnl_status_t DNNL_API gemm_s8s8s32_pack_get_size(const char *identifier,
48 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
49 const dim_t *K, const dim_t *lda, const dim_t *ldb, size_t *size,
50 bool *pack = nullptr);
51
52dnnl_status_t DNNL_API sgemm_pack(const char *identifier, const char *transa,
53 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
54 const dim_t *lda, const dim_t *ldb, const float *src, float *dst);
55
56dnnl_status_t DNNL_API gemm_bf16bf16f32_pack(const char *identifier,
57 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
58 const dim_t *K, const dim_t *lda, const dim_t *ldb,
59 const bfloat16_t *src, bfloat16_t *dst);
60
61dnnl_status_t DNNL_API gemm_s8u8s32_pack(const char *identifier,
62 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
63 const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src,
64 void *dst);
65
66dnnl_status_t DNNL_API gemm_s8s8s32_pack(const char *identifier,
67 const char *transa, const char *transb, const dim_t *M, const dim_t *N,
68 const dim_t *K, const dim_t *lda, const dim_t *ldb, const void *src,
69 void *dst);
70
71dnnl_status_t DNNL_API sgemm_compute(const char *transa, const char *transb,
72 const dim_t *M, const dim_t *N, const dim_t *K, const float *A,
73 const dim_t *lda, const float *B, const dim_t *ldb, const float *beta,
74 float *C, const dim_t *ldc);
75
76dnnl_status_t DNNL_API gemm_bf16bf16f32_compute(const char *transa,
77 const char *transb, const dim_t *M, const dim_t *N, const dim_t *K,
78 const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B,
79 const dim_t *ldb, const float *beta, float *C, const dim_t *ldc);
80
81dnnl_status_t DNNL_API gemm_s8u8s32_compute(const char *transa,
82 const char *transb, const char *offsetc, const dim_t *M, const dim_t *N,
83 const dim_t *K, const int8_t *A, const dim_t *lda, const uint8_t *B,
84 const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc,
85 const int32_t *co);
86
87dnnl_status_t DNNL_API gemm_s8s8s32_compute(const char *transa,
88 const char *transb, const char *offsetc, const dim_t *M, const dim_t *N,
89 const dim_t *K, const int8_t *A, const dim_t *lda, const int8_t *B,
90 const dim_t *ldb, const float *beta, int32_t *C, const dim_t *ldc,
91 const int32_t *co);
92
93} // namespace cpu
94} // namespace impl
95} // namespace dnnl
96
97#endif // CPU_GEMM_GEMM_PACK_HPP
98