1 | /******************************************************************************* |
2 | * Copyright 2018-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl_types.h" |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/platform.hpp" |
24 | |
25 | #include "cpu/gemm/f32/gemm_utils_f32.hpp" |
26 | #include "cpu/gemm/f32/ref_gemm_f32.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | using namespace dnnl::impl::utils; |
33 | using namespace gemm_utils; |
34 | |
35 | namespace { |
36 | |
37 | template <typename data_t> |
38 | void copy_A( |
39 | bool isTransA, dim_t K, const data_t *A, const dim_t lda, data_t *ws) { |
40 | for (dim_t k = 0; k < K; k++) { |
41 | PRAGMA_OMP_SIMD() |
42 | for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) { |
43 | ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda]; |
44 | } |
45 | ws += unroll_factor<data_t>::m; |
46 | } |
47 | } |
48 | |
49 | template <typename data_t, bool isTransA, bool isTransB> |
50 | void kernel_mxn(dim_t K, const data_t *A, const dim_t lda, const data_t *B, |
51 | const dim_t ldb, data_t *C, const dim_t ldc, const data_t alpha, |
52 | const data_t beta) { |
53 | data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] |
54 | = {static_cast<data_t>(0.)}; |
55 | for (dim_t k = 0; k < K; k++) { |
56 | for (dim_t j = 0; j < unroll_factor<data_t>::n; j++) { |
57 | data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb]; |
58 | PRAGMA_OMP_SIMD() |
59 | for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) { |
60 | data_t a = isTransA ? A[i * lda + k] : A[i + lda * k]; |
61 | c[i + unroll_factor<data_t>::m * j] += a * b; |
62 | } |
63 | } |
64 | } |
65 | for (dim_t j = 0; j < unroll_factor<data_t>::n; j++) { |
66 | PRAGMA_OMP_SIMD() |
67 | for (dim_t i = 0; i < unroll_factor<data_t>::m; i++) { |
68 | C[i + j * ldc] = (beta == static_cast<data_t>(0.)) |
69 | ? alpha * c[i + unroll_factor<data_t>::m * j] |
70 | : alpha * c[i + unroll_factor<data_t>::m * j] |
71 | + beta * C[i + j * ldc]; |
72 | } |
73 | } |
74 | } |
75 | |
76 | template <typename data_t, bool isTransA, bool isTransB> |
77 | void block_ker(const dim_t M, const dim_t N, const dim_t K, const data_t *A, |
78 | const dim_t lda, const data_t *B, const dim_t ldb, data_t *C, |
79 | const dim_t ldc, const data_t alpha, const data_t beta, data_t *ws, |
80 | bool do_copy) { |
81 | dim_t Nu = rnd_dn(N, unroll_factor<data_t>::n); |
82 | dim_t Mu = rnd_dn(M, unroll_factor<data_t>::m); |
83 | for (dim_t i = 0; i < Mu; i += unroll_factor<data_t>::m) { |
84 | for (dim_t j = 0; j < Nu; j += unroll_factor<data_t>::n) { |
85 | const data_t *b = isTransB ? &B[j] : &B[j * ldb]; |
86 | const data_t *a = isTransA ? &A[i * lda] : &A[i]; |
87 | if (do_copy) { |
88 | if (j == 0) { copy_A<data_t>(isTransA, K, a, lda, ws); } |
89 | kernel_mxn<data_t, false, isTransB>(K, ws, |
90 | unroll_factor<data_t>::m, b, ldb, &C[i + j * ldc], ldc, |
91 | alpha, beta); |
92 | } else { |
93 | kernel_mxn<data_t, isTransA, isTransB>( |
94 | K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta); |
95 | } |
96 | } |
97 | } |
98 | // tail processing |
99 | for (dim_t i = 0; i < M; i++) { |
100 | for (dim_t j = Nu; j < N; j++) { |
101 | data_t c = beta == static_cast<data_t>(0.) ? static_cast<data_t>(0.) |
102 | : beta * C[i + j * ldc]; |
103 | for (dim_t p = 0; p < K; p++) { |
104 | data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; |
105 | data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; |
106 | c += alpha * a * b; |
107 | } |
108 | C[i + j * ldc] = c; |
109 | } |
110 | } |
111 | for (dim_t i = Mu; i < M; i++) { |
112 | for (dim_t j = 0; j < Nu; j++) { |
113 | data_t c = beta == static_cast<data_t>(0.) ? static_cast<data_t>(0.) |
114 | : beta * C[i + j * ldc]; |
115 | for (dim_t p = 0; p < K; p++) { |
116 | data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb]; |
117 | data_t a = isTransA ? A[p + i * lda] : A[i + p * lda]; |
118 | c += alpha * a * b; |
119 | } |
120 | C[i + j * ldc] = c; |
121 | } |
122 | } |
123 | } |
124 | |
125 | template <typename data_t, bool isTransA, bool isTransB> |
126 | void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const data_t alpha, |
127 | const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb, |
128 | const data_t beta, data_t *C, const dim_t ldc, bool do_copy, |
129 | data_t *ws) { |
130 | constexpr dim_t BM = gemm_traits<data_t, isTransA, isTransB>::BM; |
131 | constexpr dim_t BN = gemm_traits<data_t, isTransA, isTransB>::BN; |
132 | constexpr dim_t BK = gemm_traits<data_t, isTransA, isTransB>::BK; |
133 | |
134 | const data_t *curA; |
135 | const data_t *curB; |
136 | data_t *curC; |
137 | |
138 | if ((M <= 0) || (N <= 0)) return; |
139 | |
140 | if ((K <= 0) || (alpha == static_cast<data_t>(0))) { |
141 | dim_t MN = N * M; |
142 | if (beta == static_cast<data_t>(0.)) { |
143 | for (dim_t j = 0; j < MN; j++) |
144 | C[j] = static_cast<data_t>(0.); |
145 | } else if (beta != static_cast<data_t>(1.)) { |
146 | for (dim_t j = 0; j < MN; j++) |
147 | C[j] *= beta; |
148 | } |
149 | return; |
150 | } |
151 | |
152 | for (dim_t Bk = 0; Bk < K; Bk += BK) { |
153 | dim_t kb = nstl::min(K - Bk, BK); |
154 | for (dim_t Bm = 0; Bm < M; Bm += BM) { |
155 | dim_t mb = nstl::min(M - Bm, BM); |
156 | for (dim_t Bn = 0; Bn < N; Bn += BN) { |
157 | dim_t nb = nstl::min(N - Bn, BN); |
158 | curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda; |
159 | curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb; |
160 | curC = C + Bm + Bn * ldc; |
161 | if (Bk == 0) { |
162 | block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, |
163 | curB, ldb, curC, ldc, alpha, beta, ws, do_copy); |
164 | } else { |
165 | block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda, |
166 | curB, ldb, curC, ldc, alpha, |
167 | static_cast<data_t>(1.0), ws, do_copy); |
168 | } |
169 | } |
170 | } |
171 | } |
172 | } |
173 | |
174 | } // namespace |
175 | |
176 | template <typename data_t> |
177 | dnnl_status_t ref_gemm(const char *transa_, const char *transb_, |
178 | const dim_t *M_, const dim_t *N_, const dim_t *K_, const data_t *alpha_, |
179 | const data_t *A, const dim_t *lda_, const data_t *B, const dim_t *ldb_, |
180 | const data_t *beta_, data_t *C, const dim_t *ldc_, const data_t *bias) { |
181 | |
182 | if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T') |
183 | && utils::one_of(*transb_, 'n', 'N', 't', 'T'))) |
184 | return dnnl_unimplemented; |
185 | |
186 | bool isTransA = (*transa_ == 'T' || *transa_ == 't'); |
187 | bool isTransB = (*transb_ == 'T' || *transb_ == 't'); |
188 | const dim_t M = *M_, N = *N_, K = *K_; |
189 | const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_; |
190 | const data_t alpha = *alpha_, beta = *beta_; |
191 | |
192 | // early out and avoid division by zero in partitioning |
193 | if (utils::one_of(0, M, N)) return dnnl_success; |
194 | |
195 | int max_nthr = dnnl_get_current_num_threads(); |
196 | int nthr_m, nthr_n, nthr_k; |
197 | dim_t MB, NB, KB; |
198 | // thread balancing over M, N, K & size of blocking dimensions |
199 | calc_nthr_nocopy_avx( |
200 | M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB); |
201 | assert(IMPLICATION(!dnnl_thr_syncable(), nthr_k == 1)); |
202 | |
203 | data_t *c_buffers = nullptr; |
204 | data_t *ws_buffers = nullptr; |
205 | if (nthr_k > 1) { |
206 | c_buffers = (data_t *)malloc( |
207 | sizeof(*c_buffers) * nthr_m * nthr_n * (nthr_k - 1) * MB * NB, |
208 | PAGE_4K); |
209 | if (!c_buffers) { |
210 | nthr_k = 1; |
211 | KB = K; |
212 | } |
213 | } |
214 | |
215 | bool do_copy = (NB / unroll_factor<data_t>::n > 3); |
216 | const int nthr_mn = nthr_m * nthr_n; |
217 | const int nthr_to_use = nthr_mn * nthr_k; |
218 | const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m; |
219 | const size_t ws_size_per_thr |
220 | = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K); |
221 | if (do_copy) { |
222 | ws_buffers = (data_t *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K); |
223 | if (!ws_buffers) do_copy = false; |
224 | } |
225 | |
226 | auto get_thr_block = [&](dim_t &from, dim_t &to, dim_t &myN, dim_t NB, |
227 | dim_t N, int ithr) { |
228 | from = NB * (ithr); |
229 | to = NB * (ithr + 1); |
230 | if (to > N) to = N; |
231 | myN = to - from; |
232 | }; |
233 | |
234 | parallel(nthr_to_use, [&](int ithr, int nthr) { |
235 | assert(nthr_to_use == nthr); |
236 | MAYBE_UNUSED(nthr); |
237 | |
238 | int ithr_mn = ithr % nthr_mn; |
239 | int ithr_m = ithr_mn % nthr_m; |
240 | int ithr_n = ithr_mn / nthr_m; |
241 | int ithr_k = ithr / nthr_mn; |
242 | |
243 | int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); |
244 | |
245 | data_t *ws = do_copy |
246 | ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t) |
247 | : nullptr; |
248 | |
249 | dim_t m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0, |
250 | k_from = 0, k_to = 0, myK = 0; |
251 | |
252 | get_thr_block(m_from, m_to, myM, MB, M, ithr_m); |
253 | get_thr_block(n_from, n_to, myN, NB, N, ithr_n); |
254 | get_thr_block(k_from, k_to, myK, KB, K, ithr_k); |
255 | |
256 | if (myM > 0 && myN > 0) { |
257 | data_t myBeta, *myC; |
258 | dim_t ld; |
259 | if (ithr_k == 0) { |
260 | myC = &(C[m_from + n_from * ldc]); |
261 | myBeta = beta; |
262 | ld = ldc; |
263 | } else { |
264 | myC = c_buffers + MB * NB * (cbase + ithr_k - 1); |
265 | myBeta = 0.0f; |
266 | ld = MB; |
267 | } |
268 | const data_t *myA = isTransA ? &(A[k_from + m_from * lda]) |
269 | : &(A[m_from + k_from * lda]); |
270 | const data_t *myB = isTransB ? &(B[n_from + k_from * ldb]) |
271 | : &(B[k_from + n_from * ldb]); |
272 | |
273 | if (!isTransA) { |
274 | if (!isTransB) { |
275 | gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA, |
276 | lda, myB, ldb, myBeta, myC, ld, do_copy, ws); |
277 | } else { |
278 | gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA, |
279 | lda, myB, ldb, myBeta, myC, ld, do_copy, ws); |
280 | } |
281 | } else { |
282 | if (!isTransB) { |
283 | gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA, |
284 | lda, myB, ldb, myBeta, myC, ld, do_copy, ws); |
285 | } else { |
286 | gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA, |
287 | lda, myB, ldb, myBeta, myC, ld, do_copy, ws); |
288 | } |
289 | } |
290 | } |
291 | }); |
292 | |
293 | if (nthr_k > 1) { |
294 | parallel(nthr_to_use, [&](int ithr, int nthr) { |
295 | assert(nthr_to_use == nthr); |
296 | MAYBE_UNUSED(nthr); |
297 | |
298 | int ithr_mn = ithr % nthr_mn; |
299 | int ithr_m = ithr_mn % nthr_m; |
300 | int ithr_k = ithr / nthr_mn; |
301 | int ithr_n = ithr_mn / nthr_m; |
302 | |
303 | dim_t n_from = 0, n_to = 0, myN = 0; |
304 | dim_t m_from = 0, m_to = 0, myM = 0; |
305 | |
306 | int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1); |
307 | |
308 | get_thr_block(n_from, n_to, myN, NB, N, ithr_n); |
309 | get_thr_block(m_from, m_to, myM, MB, M, ithr_m); |
310 | |
311 | // sum matrices partitioned along K dimension |
312 | dim_t offset = 0, block = 0; |
313 | gemm_utils::partition_unit_diff( |
314 | ithr_k, nthr_k, myN, &offset, &block); |
315 | for (int ik = 1; ik < nthr_k; ++ik) { |
316 | data_t *myC = c_buffers |
317 | + MB * ((dim_t)NB * (cbase + ik - 1) + offset); |
318 | |
319 | gemm_utils::sum_two_matrices(myM, block, myC, MB, |
320 | &C[m_from + (n_from + offset) * ldc], ldc); |
321 | } |
322 | }); |
323 | } |
324 | |
325 | if (bias) { |
326 | parallel_nd(N, M, [&](dim_t i, dim_t j) { C[i * ldc + j] += bias[j]; }); |
327 | } |
328 | |
329 | free(ws_buffers); |
330 | free(c_buffers); |
331 | |
332 | return dnnl_success; |
333 | } |
334 | |
335 | template dnnl_status_t ref_gemm<float>(const char *transa_, const char *transb_, |
336 | const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_, |
337 | const float *A, const dim_t *lda_, const float *B, const dim_t *ldb_, |
338 | const float *beta_, float *C, const dim_t *ldc_, const float *bias); |
339 | |
340 | template dnnl_status_t ref_gemm<double>(const char *transa_, |
341 | const char *transb_, const dim_t *M_, const dim_t *N_, const dim_t *K_, |
342 | const double *alpha_, const double *A, const dim_t *lda_, |
343 | const double *B, const dim_t *ldb_, const double *beta_, double *C, |
344 | const dim_t *ldc_, const double *bias); |
345 | } // namespace cpu |
346 | } // namespace impl |
347 | } // namespace dnnl |
348 | |