/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl_types.h"

#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/f32/gemm_utils_f32.hpp"
#include "cpu/gemm/f32/ref_gemm_f32.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace gemm_utils;

namespace {

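// Packs an unroll_factor<data_t>::m x K panel of op(A) into the contiguous
// workspace ws so the inner kernel can read A with a fixed, transposition-free
// access pattern.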
template <typename data_t>
void copy_A(
        bool isTransA, dim_t K, const data_t *A, const dim_t lda, data_t *ws) {
    for (dim_t k = 0; k < K; k++) {
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) {
            ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
        }
        ws += unroll_factor<data_t>::m;
    }
}

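// Computes a single unroll_factor<data_t>::m x unroll_factor<data_t>::n tile
// of C = alpha * op(A) * op(B) + beta * C, accumulating the full K extent in a
// small local array so each element of C is written exactly once.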
template <typename data_t, bool isTransA, bool isTransB>
void kernel_mxn(dim_t K, const data_t *A, const dim_t lda, const data_t *B,
        const dim_t ldb, data_t *C, const dim_t ldc, const data_t alpha,
        const data_t beta) {
    data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n]
            = {static_cast<data_t>(0.)};
    for (dim_t k = 0; k < K; k++) {
        for (dim_t j = 0; j < unroll_factor<data_t>::n; j++) {
            data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
            PRAGMA_OMP_SIMD()
            for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) {
                data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
                c[i + unroll_factor<data_t>::m * j] += a * b;
            }
        }
    }
    for (dim_t j = 0; j < unroll_factor<data_t>::n; j++) {
        PRAGMA_OMP_SIMD()
        for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) {
            C[i + j * ldc] = (beta == static_cast<data_t>(0.))
                    ? alpha * c[i + unroll_factor<data_t>::m * j]
                    : alpha * c[i + unroll_factor<data_t>::m * j]
                            + beta * C[i + j * ldc];
        }
    }
}

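// Covers an M x N x K cache block with fully unrolled kernel_mxn tiles, then
// finishes the leftover rows and columns with plain scalar loops. With do_copy
// set, each panel of A is packed once per row-block and reused across all
// column tiles.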
template <typename data_t, bool isTransA, bool isTransB>
void block_ker(const dim_t M, const dim_t N, const dim_t K, const data_t *A,
        const dim_t lda, const data_t *B, const dim_t ldb, data_t *C,
        const dim_t ldc, const data_t alpha, const data_t beta, data_t *ws,
        bool do_copy) {
    dim_t Nu = rnd_dn(N, unroll_factor<data_t>::n);
    dim_t Mu = rnd_dn(M, unroll_factor<data_t>::m);
    for (dim_t i = 0; i < Mu; i += unroll_factor<data_t>::m) {
        for (dim_t j = 0; j < Nu; j += unroll_factor<data_t>::n) {
            const data_t *b = isTransB ? &B[j] : &B[j * ldb];
            const data_t *a = isTransA ? &A[i * lda] : &A[i];
            if (do_copy) {
                if (j == 0) { copy_A<data_t>(isTransA, K, a, lda, ws); }
                kernel_mxn<data_t, false, isTransB>(K, ws,
                        unroll_factor<data_t>::m, b, ldb, &C[i + j * ldc], ldc,
                        alpha, beta);
            } else {
                kernel_mxn<data_t, isTransA, isTransB>(
                        K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
            }
        }
    }
    // tail processing
    for (dim_t i = 0; i < M; i++) {
        for (dim_t j = Nu; j < N; j++) {
            data_t c = beta == static_cast<data_t>(0.) ? static_cast<data_t>(0.)
                                                       : beta * C[i + j * ldc];
            for (dim_t p = 0; p < K; p++) {
                data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
                data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
                c += alpha * a * b;
            }
            C[i + j * ldc] = c;
        }
    }
    for (dim_t i = Mu; i < M; i++) {
        for (dim_t j = 0; j < Nu; j++) {
            data_t c = beta == static_cast<data_t>(0.) ? static_cast<data_t>(0.)
                                                       : beta * C[i + j * ldc];
            for (dim_t p = 0; p < K; p++) {
                data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
                data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
                c += alpha * a * b;
            }
            C[i + j * ldc] = c;
        }
    }
}

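// Serial blocked GEMM over the BM x BN x BK cache blocks given by gemm_traits.
// beta is applied only while processing the first K block; later K blocks
// accumulate with an effective beta of 1 so C is scaled exactly once.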
template <typename data_t, bool isTransA, bool isTransB>
void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const data_t alpha,
        const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
        const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
        data_t *ws) {
    constexpr dim_t BM = gemm_traits<data_t, isTransA, isTransB>::BM;
    constexpr dim_t BN = gemm_traits<data_t, isTransA, isTransB>::BN;
    constexpr dim_t BK = gemm_traits<data_t, isTransA, isTransB>::BK;

    const data_t *curA;
    const data_t *curB;
    data_t *curC;

    if ((M <= 0) || (N <= 0)) return;

    // Quick exit: nothing to accumulate, so just zero or scale C.
    if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
        dim_t MN = N * M;
        if (beta == static_cast<data_t>(0.)) {
            for (dim_t j = 0; j < MN; j++)
                C[j] = static_cast<data_t>(0.);
        } else if (beta != static_cast<data_t>(1.)) {
            for (dim_t j = 0; j < MN; j++)
                C[j] *= beta;
        }
        return;
    }

    for (dim_t Bk = 0; Bk < K; Bk += BK) {
        dim_t kb = nstl::min(K - Bk, BK);
        for (dim_t Bm = 0; Bm < M; Bm += BM) {
            dim_t mb = nstl::min(M - Bm, BM);
            for (dim_t Bn = 0; Bn < N; Bn += BN) {
                dim_t nb = nstl::min(N - Bn, BN);
                curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
                curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
                curC = C + Bm + Bn * ldc;
                if (Bk == 0) {
                    block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
                            curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
                } else {
                    block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
                            curB, ldb, curC, ldc, alpha,
                            static_cast<data_t>(1.0), ws, do_copy);
                }
            }
        }
    }
}

} // namespace

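// Reference column-major GEMM driver with BLAS-style pointer arguments:
// C = alpha * op(A) * op(B) + beta * C, plus an optional per-row bias.
// Work is partitioned across threads over M, N, and K; threads that share an
// (m, n) block but own different K ranges accumulate into private buffers
// that are reduced into C afterwards.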
template <typename data_t>
dnnl_status_t ref_gemm(const char *transa_, const char *transb_,
        const dim_t *M_, const dim_t *N_, const dim_t *K_, const data_t *alpha_,
        const data_t *A, const dim_t *lda_, const data_t *B, const dim_t *ldb_,
        const data_t *beta_, data_t *C, const dim_t *ldc_, const data_t *bias) {

    if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T')
                && utils::one_of(*transb_, 'n', 'N', 't', 'T')))
        return dnnl_unimplemented;

    bool isTransA = (*transa_ == 'T' || *transa_ == 't');
    bool isTransB = (*transb_ == 'T' || *transb_ == 't');
    const dim_t M = *M_, N = *N_, K = *K_;
    const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
    const data_t alpha = *alpha_, beta = *beta_;

    // early out and avoid division by zero in partitioning
    if (utils::one_of(0, M, N)) return dnnl_success;

    int max_nthr = dnnl_get_current_num_threads();
    int nthr_m, nthr_n, nthr_k;
    dim_t MB, NB, KB;
    // thread balancing over M, N, K & size of blocking dimensions
    calc_nthr_nocopy_avx(
            M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
    assert(IMPLICATION(!dnnl_thr_syncable(), nthr_k == 1));

    data_t *c_buffers = nullptr;
    data_t *ws_buffers = nullptr;
    if (nthr_k > 1) {
        // Private C accumulation buffers for every K-thread but the first;
        // fall back to a single K-thread if the allocation fails.
        c_buffers = (data_t *)malloc(
                sizeof(*c_buffers) * nthr_m * nthr_n * (nthr_k - 1) * MB * NB,
                PAGE_4K);
        if (!c_buffers) {
            nthr_k = 1;
            KB = K;
        }
    }

    bool do_copy = (NB / unroll_factor<data_t>::n > 3);
    const int nthr_mn = nthr_m * nthr_n;
    const int nthr_to_use = nthr_mn * nthr_k;
    const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
    const size_t ws_size_per_thr
            = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
    if (do_copy) {
        ws_buffers = (data_t *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K);
        if (!ws_buffers) do_copy = false;
    }

    // Maps a thread index to its [from, to) range along a dimension of size N
    // partitioned in chunks of NB.
    auto get_thr_block = [&](dim_t &from, dim_t &to, dim_t &myN, dim_t NB,
                                 dim_t N, int ithr) {
        from = NB * (ithr);
        to = NB * (ithr + 1);
        if (to > N) to = N;
        myN = to - from;
    };

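    // Main compute: each thread runs a serial blocked GEMM on its own
    // (m, n, k) sub-problem. K-threads other than the first write to a
    // private buffer with beta = 0; those partial products are reduced
    // into C below.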
    parallel(nthr_to_use, [&](int ithr, int nthr) {
        assert(nthr_to_use == nthr);
        MAYBE_UNUSED(nthr);

        int ithr_mn = ithr % nthr_mn;
        int ithr_m = ithr_mn % nthr_m;
        int ithr_n = ithr_mn / nthr_m;
        int ithr_k = ithr / nthr_mn;

        int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);

        data_t *ws = do_copy
                ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
                : nullptr;

        dim_t m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
              k_from = 0, k_to = 0, myK = 0;

        get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
        get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
        get_thr_block(k_from, k_to, myK, KB, K, ithr_k);

        if (myM > 0 && myN > 0) {
            data_t myBeta, *myC;
            dim_t ld;
            if (ithr_k == 0) {
                myC = &(C[m_from + n_from * ldc]);
                myBeta = beta;
                ld = ldc;
            } else {
                myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
                myBeta = 0.0f;
                ld = MB;
            }
            const data_t *myA = isTransA ? &(A[k_from + m_from * lda])
                                         : &(A[m_from + k_from * lda]);
            const data_t *myB = isTransB ? &(B[n_from + k_from * ldb])
                                         : &(B[k_from + n_from * ldb]);

            if (!isTransA) {
                if (!isTransB) {
                    gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
                            lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
                } else {
                    gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
                            lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
                }
            } else {
                if (!isTransB) {
                    gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
                            lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
                } else {
                    gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
                            lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
                }
            }
        }
    });

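    // Reduction over K: add each K-thread's private partial result back into
    // the corresponding (m, n) block of C. Every thread sums a disjoint slice
    // of columns, so no synchronization is needed.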
    if (nthr_k > 1) {
        parallel(nthr_to_use, [&](int ithr, int nthr) {
            assert(nthr_to_use == nthr);
            MAYBE_UNUSED(nthr);

            int ithr_mn = ithr % nthr_mn;
            int ithr_m = ithr_mn % nthr_m;
            int ithr_k = ithr / nthr_mn;
            int ithr_n = ithr_mn / nthr_m;

            dim_t n_from = 0, n_to = 0, myN = 0;
            dim_t m_from = 0, m_to = 0, myM = 0;

            int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);

            get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
            get_thr_block(m_from, m_to, myM, MB, M, ithr_m);

            // sum matrices partitioned along K dimension
            dim_t offset = 0, block = 0;
            gemm_utils::partition_unit_diff(
                    ithr_k, nthr_k, myN, &offset, &block);
            for (int ik = 1; ik < nthr_k; ++ik) {
                data_t *myC = c_buffers
                        + MB * ((dim_t)NB * (cbase + ik - 1) + offset);

                gemm_utils::sum_two_matrices(myM, block, myC, MB,
                        &C[m_from + (n_from + offset) * ldc], ldc);
            }
        });
    }

    if (bias) {
        // Bias is indexed along M: add bias[j] to every element of row j.
        parallel_nd(N, M, [&](dim_t i, dim_t j) { C[i * ldc + j] += bias[j]; });
    }

    free(ws_buffers);
    free(c_buffers);

    return dnnl_success;
}

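// Illustrative call with hypothetical sizes (inside the library this routine
// is normally reached through the internal GEMM dispatch rather than invoked
// directly):
//
//     dim_t M = 64, N = 32, K = 128;
//     float alpha = 1.0f, beta = 0.0f;
//     // A is M x K, B is K x N, C is M x N, all column-major.
//     ref_gemm<float>("N", "N", &M, &N, &K, &alpha, A, &M, B, &K, &beta, C,
//             &M, nullptr /* no bias */);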
template dnnl_status_t ref_gemm<float>(const char *transa_, const char *transb_,
        const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_,
        const float *A, const dim_t *lda_, const float *B, const dim_t *ldb_,
        const float *beta_, float *C, const dim_t *ldc_, const float *bias);

template dnnl_status_t ref_gemm<double>(const char *transa_,
        const char *transb_, const dim_t *M_, const dim_t *N_, const dim_t *K_,
        const double *alpha_, const double *A, const dim_t *lda_,
        const double *B, const dim_t *ldb_, const double *beta_, double *C,
        const dim_t *ldc_, const double *bias);

} // namespace cpu
} // namespace impl
} // namespace dnnl