1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <sstream>
18
19#include "oneapi/dnnl/dnnl.h"
20
21#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
22#include "oneapi/dnnl/dnnl_threadpool.hpp"
23#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
24#endif
25
26#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
27#include "cpu/gemm/gemm.hpp"
28#endif
29
30#include "common/bfloat16.hpp"
31#include "common/c_types_map.hpp"
32#include "common/dnnl_thread.hpp"
33#include "common/stack_checker.hpp"
34#include "common/verbose.hpp"
35
36using namespace dnnl::impl;
37
38#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
39namespace {
// Swaps the row/column meaning of an offset-C selector.
//
// Only the first character of `offC` is inspected:
//   'R' / 'r' -> "C"
//   'C' / 'c' -> "R"
// Any other first character, as well as a null pointer, is returned
// unchanged (the very same pointer that was passed in).
const char *c2f_offsetC(const char *offC) {
    if (offC == nullptr) return offC;

    switch (offC[0]) {
        case 'R':
        case 'r': return "C";
        case 'C':
        case 'c': return "R";
        default: return offC;
    }
}
47
48std::string get_descriptor(dim_t M, dim_t N, dim_t K) {
49 std::string s_ = std::to_string(M);
50 s_ += "x";
51 s_ += std::to_string(K);
52 s_ += ":";
53 s_ += std::to_string(K);
54 s_ += "x";
55 s_ += std::to_string(N);
56 return s_;
57}
58
59} // namespace
60#endif
61
// MAYBE_RUN_STACK_CHECKER(api_name, callee, args...)
//
// With DNNL_ENABLE_STACK_CHECKER defined, a stack_checker::stack_checker_t is
// constructed from the stringified `api_name` and its check() method receives
// the callee followed by its arguments. Otherwise `api_name` is ignored and
// the macro expands to a plain call `callee(args...)`.
#if defined(DNNL_ENABLE_STACK_CHECKER)
#define MAYBE_RUN_STACK_CHECKER(api_name, ...) \
    stack_checker::stack_checker_t(#api_name).check(__VA_ARGS__)
#else
#define MAYBE_RUN_STACK_CHECKER(api_name, callee, ...) callee(__VA_ARGS__)
#endif
68
// MAYBE_VERBOSE(status, sdt_, wdt_, ddt_, call-expression)
//
// Evaluates the call expression exactly once and assigns its value to
// `status`. When get_verbose() >= 1 the call is timed with get_msec() and one
// "onednn_verbose,..." line is printed to stdout with printf.
//
// Parameters:
//   status           - lvalue that receives the value of the call expression.
//   sdt_, wdt_, ddt_ - names of the source, weights and destination data
//                      types; streamed verbatim into the verbose line.
//   ...              - the call expression to run.
//
// The expansion also reads `transa`, `transb`, `M`, `N`, `K`, `lda`, `ldb`,
// `alpha` and `beta` from the enclosing scope, so the macro can only be used
// inside functions that declare all of these names.
//
// Leading dimensions are printed only when they differ from the dense value
// implied by the transposition flag; alpha is printed when it is not 1 and
// beta when it is not 0.
//
// The body is wrapped in do { ... } while (0) so that the macro followed by
// a semicolon forms a single statement and can be used safely in an
// unbraced if/else. Callers must terminate the invocation with ';'.
#define MAYBE_VERBOSE(status, sdt_, wdt_, ddt_, ...) \
    do { \
        if (get_verbose() >= 1) { \
            double start_ms = get_msec(); \
            status = __VA_ARGS__; \
            double duration_ms = get_msec() - start_ms; \
            std::stringstream ss; \
            ss << "onednn_verbose,"; \
            if (get_verbose_timestamp()) ss << start_ms << ","; \
            ss << "exec,cpu,gemm_api,,undef,"; \
            const bool is_src_ab = (transa == 'N' || transa == 'n'); \
            ss << "src_" << sdt_ << "::blocked:" << (is_src_ab ? "ab" : "ba") \
               << ":f0 "; \
            const bool is_wei_ab = (transb == 'N' || transb == 'n'); \
            ss << "wei_" << wdt_ << "::blocked:" << (is_wei_ab ? "ab" : "ba") \
               << ":f0 "; \
            ss << "dst_" << ddt_ << "::blocked:ab:f0,"; \
            if (is_src_ab && lda != K) ss << "lda:" << lda << " "; \
            if (!is_src_ab && lda != M) ss << "lda:" << lda << " "; \
            if (is_wei_ab && ldb != N) ss << "ldb:" << ldb << " "; \
            if (!is_wei_ab && ldb != K) ss << "ldb:" << ldb << " "; \
            if (alpha != 1.f) ss << "attr-oscale:common:" << alpha << " "; \
            if (beta != 0.f) ss << "attr-post-ops:sum:" << beta << " "; \
            ss << ",," << get_descriptor(M, N, K); \
            ss << "," << duration_ms; \
            printf("%s\n", ss.str().c_str()); \
        } else { \
            status = __VA_ARGS__; \
        } \
    } while (0)
97
98dnnl_status_t dnnl_sgemm(char transa, char transb, dim_t M, dim_t N, dim_t K,
99 float alpha, const float *A, dim_t lda, const float *B, const dim_t ldb,
100 float beta, float *C, dim_t ldc) {
101#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
102 status_t status = dnnl_success;
103 MAYBE_VERBOSE(status, "f32", "f32", "f32",
104 MAYBE_RUN_STACK_CHECKER(dnnl_sgemm, cpu::extended_sgemm, &transb,
105 &transa, &N, &M, &K, &alpha, B, &ldb, A, &lda, &beta, C,
106 &ldc, nullptr, false));
107 return status;
108#else
109 return dnnl::impl::status::unimplemented;
110#endif
111}
112
113dnnl_status_t dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dim_t M,
114 dim_t N, dim_t K, float alpha, const uint8_t *A, dim_t lda, uint8_t ao,
115 const int8_t *B, dim_t ldb, int8_t bo, float beta, int32_t *C,
116 dim_t ldc, const int32_t *co) {
117#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
118 status_t status = dnnl_success;
119 MAYBE_VERBOSE(status, "u8", "s8", "s32",
120 MAYBE_RUN_STACK_CHECKER(dnnl_gemm_u8s8s32,
121 cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
122 c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
123 &lda, &ao, &beta, C, &ldc, co));
124 return status;
125#else
126 return dnnl::impl::status::unimplemented;
127#endif
128}
129
130dnnl_status_t dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dim_t M,
131 dim_t N, dim_t K, float alpha, const int8_t *A, dim_t lda, int8_t ao,
132 const int8_t *B, dim_t ldb, int8_t bo, float beta, int32_t *C,
133 dim_t ldc, const int32_t *co) {
134#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
135 status_t status = dnnl_success;
136 MAYBE_VERBOSE(status, "s8", "s8", "s32",
137 MAYBE_RUN_STACK_CHECKER(dnnl_gemm_s8s8s32,
138 cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
139 c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
140 &lda, &ao, &beta, C, &ldc, co));
141 return status;
142#else
143 return dnnl::impl::status::unimplemented;
144#endif
145}
146
147extern "C" dnnl_status_t DNNL_API dnnl_gemm_bf16bf16f32(char transa,
148 char transb, dim_t M, dim_t N, dim_t K, float alpha,
149 const bfloat16_t *A, dim_t lda, const bfloat16_t *B, dim_t ldb,
150 float beta, float *C, dim_t ldc) {
151#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
152 status_t status = dnnl_success;
153 MAYBE_VERBOSE(status, "bf16", "bf16", "f32",
154 MAYBE_RUN_STACK_CHECKER(dnnl_gemm_bf16bf16f32,
155 cpu::gemm_bf16bf16f32, &transb, &transa, &N, &M, &K, &alpha,
156 B, &ldb, A, &lda, &beta, C, &ldc));
157 return status;
158#else
159 return dnnl::impl::status::unimplemented;
160#endif
161}
162
163#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
164dnnl_status_t dnnl_threadpool_interop_sgemm(char transa, char transb, dim_t M,
165 dim_t N, dim_t K, float alpha, const float *A, dim_t lda,
166 const float *B, const dim_t ldb, float beta, float *C, dim_t ldc,
167 void *th) {
168 threadpool_utils::activate_threadpool(
169 (dnnl::threadpool_interop::threadpool_iface *)th);
170 status_t status = dnnl_success;
171 MAYBE_VERBOSE(status, "f32", "f32", "f32",
172 MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_sgemm,
173 cpu::extended_sgemm, &transb, &transa, &N, &M, &K, &alpha,
174 B, &ldb, A, &lda, &beta, C, &ldc, nullptr, false));
175 threadpool_utils::deactivate_threadpool();
176 return status;
177}
178
179dnnl_status_t dnnl_threadpool_interop_gemm_u8s8s32(char transa, char transb,
180 char offsetc, dim_t M, dim_t N, dim_t K, float alpha, const uint8_t *A,
181 dim_t lda, uint8_t ao, const int8_t *B, dim_t ldb, int8_t bo,
182 float beta, int32_t *C, dim_t ldc, const int32_t *co, void *th) {
183 threadpool_utils::activate_threadpool(
184 (dnnl::threadpool_interop::threadpool_iface *)th);
185 status_t status = dnnl_success;
186 MAYBE_VERBOSE(status, "u8", "s8", "s32",
187 MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_u8s8s32,
188 cpu::gemm_s8x8s32<uint8_t>, &transb, &transa,
189 c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
190 &lda, &ao, &beta, C, &ldc, co));
191 threadpool_utils::deactivate_threadpool();
192 return status;
193}
194
195dnnl_status_t dnnl_threadpool_interop_gemm_s8s8s32(char transa, char transb,
196 char offsetc, dim_t M, dim_t N, dim_t K, float alpha, const int8_t *A,
197 dim_t lda, int8_t ao, const int8_t *B, dim_t ldb, int8_t bo, float beta,
198 int32_t *C, dim_t ldc, const int32_t *co, void *th) {
199 threadpool_utils::activate_threadpool(
200 (dnnl::threadpool_interop::threadpool_iface *)th);
201 status_t status = dnnl_success;
202 MAYBE_VERBOSE(status, "s8", "s8", "s32",
203 MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_s8s8s32,
204 cpu::gemm_s8x8s32<int8_t>, &transb, &transa,
205 c2f_offsetC(&offsetc), &N, &M, &K, &alpha, B, &ldb, &bo, A,
206 &lda, &ao, &beta, C, &ldc, co));
207 threadpool_utils::deactivate_threadpool();
208 return status;
209}
210
211extern "C" dnnl_status_t DNNL_API dnnl_threadpool_interop_gemm_bf16bf16f32(
212 char transa, char transb, dim_t M, dim_t N, dim_t K, float alpha,
213 const bfloat16_t *A, dim_t lda, const bfloat16_t *B, dim_t ldb,
214 float beta, float *C, dim_t ldc, void *th) {
215 threadpool_utils::activate_threadpool(
216 (dnnl::threadpool_interop::threadpool_iface *)th);
217 status_t status = dnnl_success;
218 MAYBE_VERBOSE(status, "bf16", "bf16", "f32",
219 MAYBE_RUN_STACK_CHECKER(dnnl_threadpool_interop_gemm_bf16bf16f32,
220 cpu::gemm_bf16bf16f32, &transb, &transa, &N, &M, &K, &alpha,
221 B, &ldb, A, &lda, &beta, C, &ldc));
222 threadpool_utils::deactivate_threadpool();
223 return status;
224}
225
226#undef MAYBE_VERBOSE
227
228#endif
229