1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdint> |
18 | #if defined(_MSC_VER) |
19 | #include <malloc.h> |
20 | #endif |
21 | |
22 | #include "oneapi/dnnl/dnnl_types.h" |
23 | |
24 | #include "common/bfloat16.hpp" |
25 | #include "common/dnnl_traits.hpp" |
26 | #include "common/nstl.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/gemm/f32/gemm_utils_f32.hpp" |
32 | #include "cpu/gemm/gemm_msan_unpoison.hpp" |
33 | |
34 | #include "cpu/x64/jit_generator.hpp" |
35 | |
36 | #include "cpu/x64/gemm/gemm_driver.hpp" |
37 | #include "cpu/x64/gemm/gemm_info.hpp" |
38 | #include "cpu/x64/gemm/gemm_partition.hpp" |
39 | #include "cpu/x64/gemm/gemm_threading.hpp" |
40 | #include "cpu/x64/gemm/gemm_utils.hpp" |
41 | #include "cpu/x64/gemm/gemv_driver.hpp" |
42 | |
43 | #include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp" |
44 | #include "cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.hpp" |
45 | #include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp" |
46 | |
47 | #include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemv_s8x8s32.hpp" |
48 | |
49 | namespace dnnl { |
50 | namespace impl { |
51 | namespace cpu { |
52 | namespace x64 { |
53 | #ifndef DNNL_WITH_SYCL |
54 | #define MAX_STACK_SZ (4 * PAGE_4K) |
55 | #else |
56 | #define MAX_STACK_SZ 0 |
57 | #endif |
58 | |
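// Per-thread bookkeeping for k-partitioned GEMM: each thread accumulates into
// its own C buffer (c_local / ldc_local) and the partial sum is later reduced
// into c_global. The struct is cache-line aligned and the flags are volatile
// because peer threads poll compute_done during the reduction (sum_k_blocks).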
59 | template <typename c_type> |
60 | struct alignas(64) gemm_per_thread_t { |
61 | volatile int32_t result; |
62 | volatile int32_t compute_done; |
63 | int32_t thr_k_stride; |
64 | int32_t nthr_k; |
65 | dim_t ldc_local; |
66 | dim_t ldc_global; |
67 | c_type *c_local; |
68 | c_type *volatile c_global; |
69 | gemm_slice_t slice; |
70 | }; |
71 | |
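// Returns how many elements of type T fit in the widest SIMD register
// available at runtime (e.g. 16 floats with avx512_core, 8 with avx,
// 4 with sse41).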
72 | template <typename T> |
73 | int get_vector_length() { |
74 | int v_bytes; |
75 | |
76 | if (mayiuse(avx512_core)) |
77 | v_bytes = cpu_isa_traits<avx512_core>::vlen; |
78 | else if (mayiuse(avx)) |
79 | v_bytes = cpu_isa_traits<avx>::vlen; |
80 | else |
81 | v_bytes = cpu_isa_traits<sse41>::vlen; |
82 | |
83 | return v_bytes / sizeof(T); |
84 | } |
85 | |
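// Rounds a double to c_type using round-half-away-from-zero
// (e.g. 2.5 -> 3, -2.5 -> -3), saturating the result to the int32 range.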
86 | template <typename c_type> |
87 | static inline void round_to_nearest(c_type *rounded_val, double fp_val) { |
88 | if (fp_val >= 0.) { |
89 | fp_val += 0.5; |
90 | if (fp_val > INT32_MAX) { fp_val = INT32_MAX; } |
91 | } else { |
92 | fp_val -= 0.5; |
93 | if (fp_val < INT32_MIN) { fp_val = INT32_MIN; } |
94 | } |
95 | *rounded_val = (c_type)fp_val; |
96 | } |
97 | |
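// Folds a partial product into the destination C block:
//   C := alpha * c_partial_sum + beta * C, plus the fixed/row/column offset
// from co. For integer GEMM the update is computed in double precision and
// rounded with saturation to avoid intermediate overflow.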
98 | template <typename c_type> |
99 | static inline void add_results(const dim_t m, const dim_t n, const float alpha, |
100 | const float beta, const c_type *c_partial_sum, const dim_t ldcp, |
101 | c_type *c_data, const dim_t ldc, const c_type *co, |
102 | offset_type offsetc) { |
103 | |
104 | constexpr bool is_int8 = data_traits<c_type>::data_type == data_type::s32; |
105 | |
106 | for (dim_t j = 0; j < n; ++j) { |
107 | for (dim_t i = 0; i < m; ++i) { |
108 | c_type ctemp = c_partial_sum[i + j * ldcp]; |
109 | |
110 | if (alpha == 1.0f) { |
111 | if (beta == 0.0f) { |
112 | c_data[i + j * ldc] = ctemp; |
113 | } else { |
114 | if (is_int8) { |
115 | double c_float |
116 | = (double)beta * (double)c_data[i + j * ldc]; |
117 | c_float += (double)ctemp; |
118 | round_to_nearest(&c_data[i + j * ldc], c_float); |
119 | } else { |
120 | c_data[i + j * ldc] *= beta; |
121 | c_data[i + j * ldc] += ctemp; |
122 | } |
123 | } |
124 | } else if (alpha == -1.0f) { |
125 | if (beta == 0.0f) { |
126 | c_data[i + j * ldc] = -ctemp; |
127 | } else { |
128 | if (is_int8) { |
129 | double c_float |
130 | = (double)beta * (double)c_data[i + j * ldc]; |
131 | c_float -= (double)ctemp; |
132 | round_to_nearest(&c_data[i + j * ldc], c_float); |
133 | } else { |
134 | c_data[i + j * ldc] *= beta; |
135 | c_data[i + j * ldc] -= ctemp; |
136 | } |
137 | } |
138 | } else { |
139 | if (beta == 0.0f) { |
140 | if (is_int8) { |
141 | double c_float = alpha * (double)ctemp; |
142 | round_to_nearest(&c_data[i + j * ldc], c_float); |
143 | } else { |
144 | c_data[i + j * ldc] = alpha * ctemp; |
145 | } |
146 | |
147 | } else { |
148 | if (is_int8) { |
149 | double c_float = alpha * (double)ctemp |
150 | + beta * (double)c_data[i + j * ldc]; |
151 | round_to_nearest(&c_data[i + j * ldc], c_float); |
152 | } else { |
153 | c_data[i + j * ldc] *= beta; |
154 | c_data[i + j * ldc] += alpha * ctemp; |
155 | } |
156 | } |
157 | } |
158 | |
159 | if (offsetc == offset_type::fixed) { |
160 | c_data[i + j * ldc] += co[0]; |
161 | } else if (offsetc == offset_type::row) { |
162 | c_data[i + j * ldc] += co[j]; |
163 | } else if (offsetc == offset_type::column) { |
164 | c_data[i + j * ldc] += co[i]; |
165 | } |
166 | } |
167 | } |
168 | } |
169 | |
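// Effective K block size for this thread: taken from the pre-packed matrix
// blocking when A or B is packed, otherwise derived from the cache-blocking
// parameters (bk, bk_traditional) rounded up to the unroll factor uk.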
170 | template <typename a_type, typename b_type, typename c_type> |
171 | static inline dim_t get_k_padd( |
172 | int ithr, dim_t k, const gemm_info_t<a_type, b_type, c_type> *arg) { |
173 | if (arg->a_packed) { |
174 | dim_t block_m, block_k; |
175 | arg->a_packed->get_blocking(ithr, block_m, block_k); |
176 | return block_k; |
177 | } else if (arg->b_packed) { |
178 | dim_t block_n, block_k; |
179 | arg->b_packed->get_blocking(ithr, block_k, block_n); |
180 | return block_k; |
181 | } else { |
182 | dim_t k_padd = 0; |
183 | |
184 | if (k <= arg->bk_traditional) { |
185 | k_padd = utils::rnd_up(k, arg->uk); |
186 | k_padd = nstl::max(dim_t(128), k_padd); |
187 | } else if (k < 2 * arg->bk) |
188 | k_padd = utils::rnd_up((k + 1) / 2, arg->uk); |
189 | else |
190 | k_padd = arg->bk; |
191 | |
192 | return k_padd; |
193 | } |
194 | } |
195 | |
196 | template <typename a_type, typename b_type, typename c_type> |
197 | static inline dim_t get_m_padd( |
198 | int ithr, dim_t m, const gemm_info_t<a_type, b_type, c_type> *arg) { |
199 | if (arg->a_packed) { |
200 | dim_t block_m, block_k; |
201 | arg->a_packed->get_blocking(ithr, block_m, block_k); |
202 | return block_m; |
203 | } else |
204 | return utils::rnd_up( |
205 | nstl::min(nstl::max(m, arg->um), arg->bm), arg->um); |
206 | } |
207 | |
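// M block size used when A is copied cooperatively by several threads: the
// single-thread block is multiplied by min(nthrs, 10) so every thread gets a
// share of the copy work, capped at the full (rounded-up) m.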
208 | template <typename a_type, typename b_type, typename c_type> |
209 | static inline dim_t get_m_padd_parallel_a(int ithr, dim_t m, |
210 | const gemm_info_t<a_type, b_type, c_type> *arg, int nthrs) { |
211 | auto m_padd = get_m_padd(ithr, m, arg); |
212 | |
213 | if (!arg->a_packed) { |
214 | constexpr auto multiplier = 10; |
215 | |
216 | m_padd *= nstl::min(nthrs, multiplier); |
217 | if (m_padd > m) m_padd = utils::rnd_up(m, arg->um); |
218 | } |
219 | |
220 | return m_padd; |
221 | } |
222 | |
223 | template <typename a_type, typename b_type, typename c_type> |
224 | static inline dim_t get_n_padd(int ithr, dim_t n, dim_t k, |
225 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
226 | if (arg->b_packed) { |
227 | dim_t block_n, block_k; |
228 | arg->b_packed->get_blocking(ithr, block_k, block_n); |
229 | return block_n; |
230 | } else { |
231 | auto bn = (k < arg->blocking_small_k) ? arg->bn_small_k : arg->bn; |
232 | return utils::rnd_up(nstl::min(nstl::max(n, arg->un), bn), arg->un); |
233 | } |
234 | } |
235 | |
236 | static inline void *align(void *ptr, size_t alignment) { |
237 | return (void *)utils::rnd_up((uintptr_t)ptr, alignment); |
238 | } |
239 | |
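// In-place scaling of an m x n column-major block by alpha; effectively a
// no-op for non-f32 matrices.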
240 | template <typename scale_t, typename mat_t> |
241 | void scale_matrix( |
242 | dim_t m, dim_t n, scale_t alpha, mat_t *__restrict p_mat, dim_t ld) { |
243 | if (data_traits<mat_t>::data_type == data_type::f32) { |
244 | for (dim_t j = 0; j < n; j++) { |
245 | for (dim_t i = 0; i < m; i++) { |
246 | p_mat[i + j * ld] = (mat_t)((scale_t)p_mat[i + j * ld] * alpha); |
247 | } |
248 | } |
249 | } |
250 | } |
251 | |
252 | template <typename mat_t> |
253 | static void sum_matrices(dim_t m, dim_t n, mat_t *__restrict dst, dim_t ld_dst, |
254 | mat_t *__restrict src, dim_t ld_src) { |
255 | |
256 | for (dim_t j = 0; j < n; j++) { |
257 | PRAGMA_OMP_SIMD() |
258 | for (int i = 0; i < m; i++) |
259 | dst[i + j * ld_dst] += src[i + j * ld_src]; |
260 | } |
261 | } |
262 | |
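// Reduces the per-thread partial C blocks produced by k-partitioning: each
// participating thread takes its share of the columns (partition_1d over n)
// and adds every k-peer's c_local into c_global, starting with its own block
// while it is still in cache. With wait set, it spins on the peers'
// compute_done flags.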
263 | template <typename c_type> |
264 | static void sum_k_blocks( |
265 | int ithr, gemm_per_thread_t<c_type> *thread_arg, bool wait) { |
266 | |
267 | auto m = thread_arg[ithr].slice.m; |
268 | auto n = thread_arg[ithr].slice.n; |
269 | auto ithr_k = thread_arg[ithr].slice.ithr_k; |
270 | auto nthr_k = thread_arg[ithr].nthr_k; |
271 | auto stride = thread_arg[ithr].thr_k_stride; |
272 | dim_t n0, nn; |
273 | |
274 | partition_1d(ithr_k, nthr_k, n, n0, nn); |
275 | |
276 | auto get_thread_arg = [&](int thr_k) -> gemm_per_thread_t<c_type> & { |
277 | return thread_arg[ithr + (thr_k - ithr_k) * stride]; |
278 | }; |
279 | |
280 | auto wait_thread = [&](int thr_k) { |
281 | if (wait) { |
282 | auto &tk_arg = get_thread_arg(thr_k); |
283 | while (!tk_arg.compute_done) {} |
284 | } |
285 | }; |
286 | |
287 | auto add_thread_results = [&](int thr_k) { |
288 | auto &tk_arg = get_thread_arg(thr_k); |
289 | |
290 | sum_matrices(m, nn, tk_arg.c_global + n0 * tk_arg.ldc_global, |
291 | tk_arg.ldc_global, tk_arg.c_local + n0 * tk_arg.ldc_local, |
292 | tk_arg.ldc_local); |
293 | }; |
294 | |
295 | // First accumulate this thread's results while they are in cache. |
296 | if (ithr_k > 0) { |
297 | wait_thread(0); |
298 | add_thread_results(ithr_k); |
299 | } |
300 | |
301 | // Then accumulate the others. |
302 | for (int thr_k = 1; thr_k < nthr_k; thr_k++) { |
303 | if (thr_k != ithr_k) { |
304 | wait_thread(thr_k); |
305 | add_thread_results(thr_k); |
306 | } |
307 | } |
308 | } |
309 | |
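// Packing request served by the no-copy path: forwards A or B (depending on
// arg->packing) to gemm_utils::pack_no_copy with the matching dimensions,
// leading dimension, transpose flag and alpha.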
310 | template <typename a_type, typename b_type, typename c_type> |
311 | static dnnl_status_t pack_no_copy(gemm_info_t<a_type, b_type, c_type> *arg) { |
312 | |
313 | if (arg->packing == pack_type::pack_a) { |
314 | return gemm_utils::pack_no_copy(arg->a, arg->lda, arg->m, arg->k, |
315 | arg->transa, arg->alpha, arg->pack_dst); |
316 | } else { |
317 | return gemm_utils::pack_no_copy(arg->b, arg->ldb, arg->k, arg->n, |
318 | arg->transb, arg->alpha, arg->pack_dst); |
319 | } |
320 | } |
321 | |
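// Packing driver: the first thread of each slice copies its blocks of A or B
// (plus the row/column sums needed for integer GEMM) into the pack storage,
// using the same copy kernels as the compute path.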
322 | template <typename a_type, typename b_type, typename c_type> |
323 | static dnnl_status_t gemm_packing_driver(int ithr, dim_t m, dim_t n, dim_t k, |
324 | const a_type *a, const b_type *b, |
325 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
326 | |
327 | if (m <= 0 || n <= 0) return dnnl_success; |
328 | |
329 | gemm_pack_storage_t *pack_dst = arg->pack_dst; |
330 | |
331 | if (!pack_dst->is_first_thread_in_slice(ithr)) return dnnl_success; |
332 | |
333 | dim_t block_r, block_c; |
334 | pack_dst->get_blocking(ithr, block_r, block_c); |
335 | |
336 | auto do_a = (arg->packing == pack_type::pack_a); |
337 | auto mn = do_a ? m : n; |
338 | auto mn_padd = do_a ? block_r : block_c; |
339 | auto k_padd = do_a ? block_c : block_r; |
340 | dim_t mn_stride, k_stride; |
341 | |
342 | if (do_a) { |
343 | mn_stride = (arg->transa == no_trans) ? 1 : arg->lda; |
344 | k_stride = (arg->transa == no_trans) ? arg->lda : 1; |
345 | } else { |
346 | mn_stride = (arg->transb == no_trans) ? arg->ldb : 1; |
347 | k_stride = (arg->transb == no_trans) ? 1 : arg->ldb; |
348 | } |
349 | |
350 | dim_t blk_k = 0; |
351 | for (dim_t Bk = 0; Bk < k; Bk += k_padd, blk_k++) { |
352 | dim_t nk = nstl::min(k - Bk, k_padd); |
353 | |
354 | for (dim_t Bmn = 0; Bmn < mn; Bmn += mn_padd) { |
355 | dim_t nmn = nstl::min(mn - Bmn, mn_padd); |
356 | |
357 | if (do_a) { |
358 | auto a_src = a + mn_stride * Bmn + k_stride * Bk; |
359 | auto a_dst = pack_dst->matrix<a_type>(ithr, Bmn, Bk); |
360 | auto a_row_sum = pack_dst->row_sums<c_type>(ithr, Bmn, blk_k); |
361 | |
362 | arg->copyA(&nk, &nmn, a_src, &arg->lda, &arg->alpha, a_dst, |
363 | nullptr, nullptr, a_row_sum); |
364 | } else { |
365 | auto b_src = b + mn_stride * Bmn + k_stride * Bk; |
366 | auto b_dst = pack_dst->matrix<b_type>(ithr, Bk, Bmn); |
367 | auto b_col_sum = pack_dst->col_sums<c_type>(ithr, blk_k, Bmn); |
368 | |
369 | arg->copyB(&nk, &nmn, b_src, &arg->ldb, &arg->alpha, b_dst, |
370 | nullptr, nullptr, b_col_sum); |
371 | } |
372 | } |
373 | } |
374 | |
375 | return dnnl_success; |
376 | } |
377 | |
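// Computes one C block with the JIT kernel. For integer GEMM it first builds
// the column/row offset vectors (user C offsets plus the -bo * row_sum(A),
// -ao * col_sum(B) and k * ao * bo compensation terms), selects the kernel by
// (beta == 0, col_req, row_req), and finally applies the fix-ups the kernel
// does not handle itself: column bias for f32 and the offset vectors for AMX.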
378 | template <typename a_type, typename b_type, typename c_type> |
379 | void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha, |
380 | const a_type *a, const b_type *b, float beta, c_type *c, |
381 | const dim_t ldc, const c_type *a_row_sum, const c_type *b_col_sum, |
382 | c_type *row_offset_ws, c_type *col_offset_ws, const c_type *co, |
383 | offset_type offsetc, const gemm_info_t<a_type, b_type, c_type> *arg) { |
384 | |
385 | bool col_req = false; |
386 | bool row_req = false; |
387 | |
388 | constexpr bool is_int8 = utils::one_of( |
389 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
390 | constexpr bool is_f32 = data_traits<a_type>::data_type == data_type::f32; |
391 | bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); |
392 | |
393 | dim_t m_stk = col_offset_ws ? 1 : m; |
394 | dim_t n_stk = row_offset_ws ? 1 : n; |
395 | #if !defined(_MSC_VER) |
396 | c_type col_offset_stk[m_stk]; |
397 | c_type row_offset_stk[n_stk]; |
398 | #else |
399 | c_type *col_offset_stk = nullptr; |
400 | if (!col_offset_ws) |
401 | col_offset_stk = (c_type *)_alloca(sizeof *col_offset_stk * m_stk); |
402 | |
403 | c_type *row_offset_stk = nullptr; |
404 | if (!row_offset_ws) |
405 | row_offset_stk = (c_type *)_alloca(sizeof *row_offset_stk * n_stk); |
406 | #endif |
407 | |
408 | // Use the heap if already allocated and stack otherwise. |
409 | c_type *col_offset = col_offset_ws ? col_offset_ws : col_offset_stk; |
410 | c_type *row_offset = row_offset_ws ? row_offset_ws : row_offset_stk; |
411 | |
412 | if (is_int8) { |
413 | c_type ao = arg->ao; |
414 | c_type bo = arg->bo; |
415 | c_type co_0 = offsetc == offset_type::none ? 0 : co[0]; |
416 | |
417 | if (bo != 0 || offsetc == offset_type::column) col_req = true; |
418 | if (ao != 0 || offsetc == offset_type::row) row_req = true; |
419 | |
        // Only one of the column or row offset vectors is needed, not both.
421 | if ((ao != 0 && bo != 0) |
422 | || (offsetc == offset_type::fixed && co_0 != 0)) { |
423 | if (!col_req && !row_req) { |
424 | if (m <= n) { |
425 | col_req = true; |
426 | } else { |
427 | row_req = true; |
428 | } |
429 | } |
430 | } |
431 | |
432 | if (col_req) { |
433 | for (dim_t i = 0; i < m; i++) |
434 | col_offset[i] = 0; |
435 | |
436 | if (offsetc == offset_type::column) { |
437 | for (dim_t i = 0; i < m; i++) |
438 | col_offset[i] += co[i]; |
439 | } |
440 | |
441 | if (bo != 0 && a_row_sum) { |
442 | for (dim_t i = 0; i < m; i++) |
443 | col_offset[i] -= bo * a_row_sum[i]; |
444 | } |
445 | } |
446 | |
447 | if (row_req) { |
448 | for (dim_t i = 0; i < n; i++) |
449 | row_offset[i] = 0; |
450 | |
451 | if (offsetc == offset_type::row) { |
452 | for (dim_t i = 0; i < n; i++) |
453 | row_offset[i] += co[i]; |
454 | } |
455 | |
456 | if (ao != 0 && b_col_sum) { |
457 | for (dim_t i = 0; i < n; i++) |
458 | row_offset[i] -= ao * b_col_sum[i]; |
459 | } |
460 | } |
461 | |
462 | if (offsetc == offset_type::fixed && co_0 != 0) { |
463 | if (col_req) { |
464 | for (dim_t i = 0; i < m; i++) |
465 | col_offset[i] += co_0; |
466 | } else { |
467 | for (dim_t i = 0; i < n; i++) |
468 | row_offset[i] += co_0; |
469 | } |
470 | } |
471 | |
472 | if (ao != 0 && bo != 0) { |
473 | if (col_req) { |
474 | for (dim_t i = 0; i < m; i++) |
475 | col_offset[i] += (c_type)k * ao * bo; |
476 | } else { |
477 | for (dim_t i = 0; i < n; i++) |
478 | row_offset[i] += (c_type)k * ao * bo; |
479 | } |
480 | } |
481 | } |
482 | |
483 | bool isBeta0 = beta == 0.0f; |
484 | |
485 | /* Column and row offsets are ignored by non-integer compute kernels. |
486 | * Scaling is done only for bfloat16 kernels. |
487 | */ |
488 | if (m > 0 && n > 0) |
489 | arg->kernel[isBeta0][col_req][row_req]( |
490 | &m, &n, &k, &alpha, a, b, c, ldc, col_offset, row_offset); |
491 | |
492 | msan_unpoison_matrix(c, m, n, ldc, sizeof(*c)); |
493 | |
494 | // sgemm kernels don't support bias yet. |
495 | if (is_f32) { |
496 | if (co && offsetc == offset_type::column) { |
497 | for (dim_t j = 0; j < n; j++) { |
498 | for (dim_t i = 0; i < m; i++) { |
499 | c[i + j * ldc] += co[i]; |
500 | } |
501 | } |
502 | } |
503 | } |
504 | |
505 | // AMX igemm kernels don't support row & col sums yet. |
506 | if (is_int8_amx) { |
507 | for (dim_t j = 0; j < n; j++) { |
508 | for (dim_t i = 0; i < m; i++) { |
509 | if (row_req) c[i + j * ldc] += row_offset[j]; |
510 | if (col_req) c[i + j * ldc] += col_offset[i]; |
511 | } |
512 | } |
513 | } |
514 | } |
515 | |
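// Per-thread blocked GEMM: allocates aligned buffers for the copied A/B blocks
// (or reuses pre-packed data), then loops over M, K and N blocks calling
// gemm_kernel on each. When alpha/beta cannot be applied by the kernel itself
// (int8 with scaling, bf16 AMX with alpha != 1), partial results go through a
// temporary C buffer and are merged by add_results.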
516 | template <typename a_type, typename b_type, typename c_type> |
517 | static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k, |
518 | const a_type *a, const b_type *b, float beta, c_type *c, dim_t ldc, |
519 | offset_type offsetc, const c_type *co, |
520 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
521 | |
522 | if (arg->packing != pack_type::none) |
523 | return gemm_packing_driver(ithr, m, n, k, a, b, arg); |
524 | |
525 | if (m <= 0 || n <= 0) return dnnl_success; |
526 | |
527 | dim_t lda = arg->lda; |
528 | dim_t ldb = arg->ldb; |
529 | |
530 | float alpha = arg->alpha; |
531 | |
532 | constexpr bool is_int8 = utils::one_of( |
533 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
534 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
535 | bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); |
536 | bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); |
537 | bool is_amx = is_int8_amx || is_bf16_amx; |
538 | |
539 | const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed; |
540 | const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed; |
541 | |
542 | // Scaling C matrix. |
543 | if (!is_int8 && beta != 1.0f && beta != 0.0f) { |
544 | scale_matrix(m, n, beta, c, ldc); |
545 | beta = 1.0f; |
546 | } |
547 | |
548 | // Quick exit for C = beta * C |
549 | if (!is_int8 && alpha == 0.0f) { |
550 | if (beta == 0.0f) scale_matrix(m, n, beta, c, ldc); |
551 | |
552 | return dnnl_success; |
553 | } |
554 | |
555 | // Get block sizes. |
556 | dim_t k_padd = get_k_padd(ithr, k, arg); |
557 | dim_t m_padd = get_m_padd(ithr, m, arg); |
558 | dim_t n_padd = get_n_padd(ithr, n, k, arg); |
559 | |
    // Padded leading dimension for the temporary C buffer.
561 | dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m_padd); |
562 | |
563 | dim_t strideAm = (arg->transa == no_trans) ? 1 : lda; |
564 | dim_t strideAn = (arg->transa != no_trans) ? 1 : lda; |
565 | dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb; |
566 | dim_t strideBn = (arg->transb != no_trans) ? 1 : ldb; |
567 | |
568 | size_t a_buf_nelems = m_padd * k_padd; |
569 | size_t b_buf_nelems = k_padd * n_padd; |
570 | // A and B buffers need more space due to zero-padding. |
571 | if (is_amx) { |
572 | a_buf_nelems = utils::rnd_up(m_padd, arg->um) |
573 | * utils::rnd_up(k_padd, arg->uk); |
574 | b_buf_nelems = utils::rnd_up(k_padd, arg->uk) |
575 | * utils::rnd_up(n_padd, arg->un); |
576 | } |
577 | size_t a_row_sum_nelems = m_padd; |
578 | size_t b_col_sum_nelems = n_padd; |
579 | |
580 | if (a_packed) a_buf_nelems = a_row_sum_nelems = 0; |
581 | if (b_packed) b_buf_nelems = b_col_sum_nelems = 0; |
582 | |
583 | size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K |
584 | + b_buf_nelems * sizeof(*b) + PAGE_4K; |
585 | |
586 | if (is_int8) { |
587 | mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K |
588 | + b_col_sum_nelems * sizeof(*c) + PAGE_4K; |
589 | } |
590 | |
591 | size_t col_offset_ws_nelems = arg->um; |
592 | size_t row_offset_ws_nelems = n_padd; |
593 | size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c); |
594 | const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ; |
595 | if (!use_stack) { |
596 | mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K; |
597 | mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K; |
598 | } |
599 | |
600 | bool need_c_buffer |
601 | = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f))) |
            // AMX bfloat16 kernels don't support alpha scaling yet, so we
            // need to use an accumulation buffer even if beta == 0.
604 | || (is_bf16_amx && alpha != 1.0f); |
605 | |
606 | if (need_c_buffer) { |
607 | size_t c_buf_nelems = ldc_buf * n_padd; |
608 | mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; |
609 | } |
610 | |
611 | char *mem = nullptr; |
612 | |
613 | if (mem_size > 0) { |
614 | mem = (char *)malloc(mem_size, 128); |
615 | if (!mem) return dnnl_out_of_memory; |
616 | } |
617 | |
618 | a_type *bufferA = (a_type *)align(mem, PAGE_4K); |
619 | void *p_next_buf = bufferA + a_buf_nelems; |
620 | |
621 | b_type *bufferB = (b_type *)align(p_next_buf, PAGE_4K); |
622 | p_next_buf = bufferB + b_buf_nelems; |
623 | |
624 | c_type *a_row_sum = nullptr; |
625 | c_type *b_col_sum = nullptr; |
626 | if (is_int8) { |
627 | a_row_sum = (c_type *)align(p_next_buf, PAGE_4K); |
628 | p_next_buf = a_row_sum + a_row_sum_nelems; |
629 | |
630 | b_col_sum = (c_type *)align(p_next_buf, PAGE_4K); |
631 | p_next_buf = b_col_sum + b_col_sum_nelems; |
632 | } |
633 | |
634 | c_type *col_offset_ws = nullptr; |
635 | c_type *row_offset_ws = nullptr; |
636 | if (!use_stack) { |
637 | col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); |
638 | p_next_buf = col_offset_ws + col_offset_ws_nelems; |
639 | |
640 | row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); |
641 | p_next_buf = row_offset_ws + row_offset_ws_nelems; |
642 | } |
643 | |
644 | c_type *bufferC = nullptr; |
645 | if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K); |
646 | |
647 | int a_block_copied = 0; |
648 | dim_t sizeM = 0; |
649 | for (dim_t Bm = 0; Bm < m; Bm += sizeM) { |
650 | sizeM = m - Bm; |
651 | if (sizeM > m_padd) sizeM = m_padd; |
652 | |
653 | dim_t sizeK = 0; |
654 | dim_t blk_k = 0; |
655 | for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) { |
656 | sizeK = k - Bk; |
657 | if (sizeK > k_padd) sizeK = k_padd; |
658 | |
            // Scale C blocks by beta only for the first k-block of the
            // partial sum.
660 | auto beta_eff = (Bk == 0) ? beta : 1.0f; |
661 | |
            // Apply the C offset only for the last k-block of the partial sum.
663 | auto offsetc_eff = offset_type::none; |
664 | if (Bk + sizeK == k) offsetc_eff = offsetc; |
665 | |
666 | dim_t sizeN = 0; |
667 | for (dim_t Bn = 0; Bn < n; Bn += sizeN) { |
668 | sizeN = n - Bn; |
669 | if (sizeN > n_padd) sizeN = n_padd; |
670 | |
671 | if (b_packed) { |
672 | bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn); |
673 | if (is_int8) |
674 | b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn); |
675 | } else { |
676 | const b_type *b_block = b + Bk * strideBm + Bn * strideBn; |
677 | const float one = 1.0f; |
678 | |
679 | /* Column sum argument is ignored for non-integer kernels |
680 | * and scaling factor is ignored by 8-bit and 16-bit copy |
681 | * kernels. |
682 | */ |
683 | arg->copyB(&sizeK, &sizeN, b_block, &ldb, &one, bufferB, |
684 | nullptr, nullptr, b_col_sum); |
685 | } |
686 | |
687 | dim_t sizeUM = 0; |
688 | for (dim_t Um = 0; Um < sizeM; Um += sizeUM) { |
689 | sizeUM = sizeM - Um; |
690 | if (sizeUM > arg->um) sizeUM = arg->um; |
691 | |
                /* Use the whole A buffer only if the current k panel of B is
                 * split into multiple n blocks (so the copied A block gets
                 * reused); otherwise we would waste cache better spent on the
                 * B and C blocks.
                 */
696 | dim_t Um_forA = 0; |
697 | if (sizeN < n) Um_forA = Um; |
698 | |
699 | a_type *bufferA_eff = nullptr; |
700 | c_type *a_row_sum_eff = nullptr; |
701 | |
702 | if (a_packed) { |
703 | Um_forA = Um; |
704 | |
705 | // TODO Can we simplify this! |
706 | dim_t buf_shift = 0; |
707 | if (is_amx) |
708 | buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk); |
709 | else |
710 | buf_shift = Um_forA * sizeK; |
711 | |
712 | bufferA_eff = a_packed->matrix<a_type>(ithr, Bm, Bk) |
713 | + buf_shift; |
714 | |
715 | if (is_int8) |
716 | a_row_sum_eff = a_packed->row_sums<c_type>( |
717 | ithr, Bm, blk_k) |
718 | + Um_forA; |
719 | } else { |
720 | // TODO Can we simplify this! |
721 | dim_t buf_shift = 0; |
722 | if (is_amx) |
723 | buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk); |
724 | else |
725 | buf_shift = Um_forA * sizeK; |
726 | |
727 | bufferA_eff = bufferA + buf_shift; |
728 | a_row_sum_eff |
729 | = a_row_sum ? a_row_sum + Um_forA : nullptr; |
730 | |
731 | if (!a_block_copied) { |
732 | const a_type *a_block |
733 | = a + (Bm + Um) * strideAm + Bk * strideAn; |
734 | |
735 | /* Row sum argument is ignored for non-integer |
736 | * kernels and scaling factor is ignored by 8-bit |
737 | * and 16-bit copy kernels. |
738 | */ |
739 | arg->copyA(&sizeK, &sizeUM, a_block, &lda, &alpha, |
740 | bufferA_eff, nullptr, nullptr, |
741 | a_row_sum_eff); |
742 | } |
743 | } |
744 | |
745 | c_type *c_block = c + (Bm + Um) + Bn * ldc; |
746 | |
747 | dim_t co_stride = 0; |
748 | if (offsetc_eff == offset_type::row) |
749 | co_stride = Bn; |
750 | else if (offsetc_eff == offset_type::column) |
751 | co_stride = Bm + Um; |
752 | |
753 | if (need_c_buffer) { |
754 | gemm_kernel(sizeUM, sizeN, sizeK, 1.0f, bufferA_eff, |
755 | bufferB, 0.0f, bufferC + Um, ldc_buf, |
756 | a_row_sum_eff, b_col_sum, row_offset_ws, |
757 | col_offset_ws, (c_type *)nullptr, |
758 | offset_type::none, arg); |
759 | |
760 | /* Finish the block adding the necessary alpha, beta |
761 | * and offsets. |
762 | */ |
763 | add_results(sizeUM, sizeN, alpha, beta_eff, |
764 | bufferC + Um, ldc_buf, c_block, ldc, |
765 | co + co_stride, offsetc_eff); |
766 | } else { |
767 | gemm_kernel(sizeUM, sizeN, sizeK, alpha, bufferA_eff, |
768 | bufferB, beta_eff, c_block, ldc, a_row_sum_eff, |
769 | b_col_sum, row_offset_ws, col_offset_ws, |
770 | co + co_stride, offsetc_eff, arg); |
771 | } |
772 | } |
773 | a_block_copied = 1; |
774 | } |
775 | a_block_copied = 0; |
776 | } |
777 | } |
778 | |
779 | free(mem); |
780 | |
781 | return dnnl_success; |
782 | } |
783 | |
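// Kernel driver for the shared-A algorithm: A has already been copied into
// bufferA by parallel_a_copy, so only B blocks (and their column sums for
// integer GEMM) are packed here before calling gemm_kernel over the N blocks.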
784 | template <typename a_type, typename b_type, typename c_type> |
785 | static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m, |
786 | dim_t n, dim_t k, dim_t blk_k, dim_t Bk, const a_type *bufferA, |
787 | const b_type *b, float beta, c_type *c, offset_type offsetc, |
788 | const c_type *co, const c_type *a_row_sum, |
789 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
790 | |
791 | dim_t ldb = arg->ldb; |
792 | dim_t ldc = arg->ldc; |
793 | |
794 | float alpha = arg->alpha; |
795 | |
796 | const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed; |
797 | |
798 | if (m <= 0 || n <= 0) { return dnnl_success; } |
799 | |
800 | // Padding along N dimension. |
801 | dim_t n_padd = get_n_padd(ithr, n, k, arg); |
802 | |
    // Padded leading dimension for the temporary C buffer.
804 | dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m); |
805 | |
806 | dim_t strideBn = (arg->transb != 0) ? 1 : ldb; |
807 | |
808 | size_t b_buf_nelems = k * n_padd; |
809 | size_t b_col_sum_nelems = n_padd; |
810 | constexpr bool is_int8 = utils::one_of( |
811 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
812 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
813 | bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); |
814 | bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); |
815 | bool is_amx = is_int8_amx || is_bf16_amx; |
816 | |
    // B buffer needs more space due to zero-padding.
818 | if (is_amx) |
819 | b_buf_nelems |
820 | = utils::rnd_up(k, arg->uk) * utils::rnd_up(n_padd, arg->un); |
821 | |
822 | if (b_packed) b_buf_nelems = b_col_sum_nelems = 0; |
823 | |
824 | size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K; |
825 | |
826 | if (is_int8) { mem_size += b_col_sum_nelems * sizeof(*c) + PAGE_4K; } |
827 | |
828 | size_t col_offset_ws_nelems = m; |
829 | size_t row_offset_ws_nelems = n_padd; |
830 | size_t stk_sz = (col_offset_ws_nelems + row_offset_ws_nelems) * sizeof(*c); |
831 | const bool use_stack = is_int8 && stk_sz <= MAX_STACK_SZ; |
832 | if (!use_stack) { |
833 | mem_size += col_offset_ws_nelems * sizeof(*c) + PAGE_4K; |
834 | mem_size += row_offset_ws_nelems * sizeof(*c) + PAGE_4K; |
835 | } |
836 | |
837 | bool need_c_buffer |
838 | = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f))) |
            // AMX bfloat16 kernels don't support alpha scaling yet, so we
            // need to use an accumulation buffer even if beta == 0.
841 | || (is_bf16_amx && alpha != 1.0f); |
842 | |
843 | if (need_c_buffer) { |
844 | size_t c_buf_nelems = ldc_buf * n_padd; |
845 | mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K; |
846 | } |
847 | |
848 | char *mem = nullptr; |
849 | |
850 | if (mem_size > 0) { |
851 | mem = (char *)malloc(mem_size, 128); |
852 | if (!mem) return dnnl_out_of_memory; |
853 | } |
854 | |
855 | b_type *bufferB = (b_type *)align(mem, PAGE_4K); |
856 | void *p_next_buf = bufferB + b_buf_nelems; |
857 | |
858 | c_type *b_col_sum = nullptr; |
859 | if (is_int8) { |
860 | b_col_sum = (c_type *)align(p_next_buf, PAGE_4K); |
861 | p_next_buf = b_col_sum + b_col_sum_nelems; |
862 | } |
863 | |
864 | c_type *col_offset_ws = nullptr; |
865 | c_type *row_offset_ws = nullptr; |
866 | if (!use_stack) { |
867 | col_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); |
868 | p_next_buf = col_offset_ws + col_offset_ws_nelems; |
869 | |
870 | row_offset_ws = (c_type *)align(p_next_buf, PAGE_4K); |
871 | p_next_buf = row_offset_ws + row_offset_ws_nelems; |
872 | } |
873 | |
874 | c_type *bufferC = nullptr; |
875 | if (need_c_buffer) bufferC = (c_type *)align(p_next_buf, PAGE_4K); |
876 | |
877 | dim_t sizeN = 0; |
878 | for (dim_t Bn = 0; Bn < n; Bn += sizeN) { |
879 | sizeN = n - Bn; |
880 | if (sizeN > n_padd) sizeN = n_padd; |
881 | |
882 | if (b_packed) { |
883 | bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn); |
884 | if (is_int8) |
885 | b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn); |
886 | } else { |
887 | const b_type *b_block = b + Bn * strideBn; |
888 | const float one = 1.0f; |
889 | |
890 | /* Column sum argument is ignored for non-integer kernels and |
891 | * scaling factor is ignored by 8-bit and 16-bit copy kernels. |
892 | */ |
893 | arg->copyB(&k, &sizeN, b_block, &ldb, &one, bufferB, nullptr, |
894 | nullptr, b_col_sum); |
895 | } |
896 | |
897 | dim_t co_stride = 0; |
898 | if (offsetc == offset_type::fixed) { |
899 | co_stride = 0; |
900 | } else if (offsetc == offset_type::row) { |
901 | co_stride = Bn; |
902 | } else if (offsetc == offset_type::column) { |
903 | co_stride = 0; |
904 | } |
905 | |
906 | c_type *c_block = c + Bn * ldc; |
907 | if (need_c_buffer) { |
908 | gemm_kernel(m, sizeN, k, 1.0f, bufferA, bufferB, 0.0f, bufferC, |
909 | ldc_buf, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws, |
910 | (c_type *)nullptr, offset_type::none, arg); |
911 | |
912 | // Finish the block adding the necessary alpha, beta and offsets. |
913 | add_results(m, sizeN, alpha, beta, bufferC, ldc_buf, c_block, ldc, |
914 | co + co_stride, offsetc); |
915 | } else { |
916 | gemm_kernel(m, sizeN, k, alpha, bufferA, bufferB, beta, c_block, |
917 | ldc, a_row_sum, b_col_sum, row_offset_ws, col_offset_ws, |
918 | co + co_stride, offsetc, arg); |
919 | } |
920 | } |
921 | |
922 | free(mem); |
923 | |
924 | return dnnl_success; |
925 | } |
926 | |
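// Heuristic deciding whether the no-copy sgemm kernels should be preferred
// over the copy-based path on AVX2, based on problem shape and thread count.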
927 | static inline bool nocopy_checker_avx2(const int nthr, const int transa, |
928 | const int transb, const dim_t m, const dim_t n, const dim_t k, |
929 | const dim_t lda, const dim_t ldb, const dim_t ldc) { |
930 | static const dim_t BM_NOCOPY_AVX2 = 64; |
931 | static const dim_t MN_NOCOPY_AVX2 = 128; |
932 | static const dim_t N_TRANSB_PER_THR = 1; |
933 | static const dim_t K_TRANSB_PER_THR = 1; |
934 | static const dim_t N_NOTRANSB_PER_THR = 16; |
935 | static const dim_t K_NOTRANSB_PER_THR = 2; |
936 | static const double FORCE_NOCOPY_THRESH = 0.0038; |
937 | |
    // Crude threshold for nocopy kernels if copy overhead is significant.
939 | if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH) { return true; } |
940 | |
941 | if (m <= 378 && n <= 378 && k >= nthr * 378) return false; |
942 | |
943 | if (m >= nthr * 378 && k >= nthr * 378) return false; |
944 | |
945 | if (transb == no_trans) { |
946 | if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true; |
947 | if (n <= nthr * N_NOTRANSB_PER_THR) return true; |
948 | if (k <= nthr * K_NOTRANSB_PER_THR) return true; |
949 | if (m <= BM_NOCOPY_AVX2 && n >= nthr * N_NOTRANSB_PER_THR) return true; |
950 | } else { |
951 | if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true; |
952 | if (n <= nthr * N_TRANSB_PER_THR) return true; |
953 | if (k <= nthr * K_TRANSB_PER_THR) return true; |
954 | } |
955 | |
956 | return false; |
957 | } |
958 | |
959 | static inline bool nocopy_checker_avx512(int nthr, const int transa, |
960 | const int transb, const dim_t m, const dim_t n, const dim_t k, |
961 | const dim_t lda, const dim_t ldb, const dim_t ldc) { |
962 | // Constants definition |
963 | static const dim_t BAD_LD_MULT = 256; |
964 | static const dim_t VERYBAD_LD_MULT = 1024; |
965 | static const dim_t M_TRANSB_PER_THR = 28; |
966 | static const dim_t N_TRANSB_PER_THR = 28; |
967 | static const dim_t K_TRANSB_PER_THR = 1; |
968 | static const dim_t MN_NOTRANSB_PER_THR = 28; |
969 | static const dim_t K_NOTRANSB_PER_THR = 1; |
970 | static const double FORCE_NOCOPY_THRESH = 0.00196; |
971 | |
972 | bool is_NN = transa == no_trans && transb == no_trans; |
973 | bool is_NT = transa == no_trans && transb == do_trans; |
974 | bool is_TN = transa == do_trans && transb == no_trans; |
975 | |
976 | bool is_lda_bad = lda % BAD_LD_MULT == 0; |
977 | bool is_ldb_bad = ldb % BAD_LD_MULT == 0; |
978 | bool is_ldc_bad = ldc % BAD_LD_MULT == 0; |
979 | bool is_ld_bad = is_lda_bad || is_ldb_bad || is_ldc_bad; |
980 | |
981 | bool is_lda_verybad = lda % VERYBAD_LD_MULT == 0; |
982 | |
    // Copy-based performs better for the TN case with small N when running
    // sequentially.
984 | if (nthr == 1 && is_TN && m > 100 |
985 | && ((m < 1200 && n < 200 && k < 1200) |
986 | || (is_lda_bad && is_ldb_bad))) |
987 | return false; |
988 | |
    // Copy-based performs better for the NN case on a very bad leading
    // dimension if each thread has enough work.
991 | if (nthr <= 8 && is_NN && is_lda_verybad && k > 500 && n > 100) |
992 | return false; |
993 | |
994 | // Crude threshold for nocopy kernels if copy overhead is significant. |
995 | if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH |
996 | && !(is_lda_verybad && is_NT)) { |
997 | return true; |
998 | } |
999 | |
1000 | // Copy strategy usually performs better than nocopy on "bad" leading |
1001 | // dimensions. |
1002 | if (is_ld_bad) { |
1003 | bool use_copy_based = false; |
1004 | |
1005 | if (m >= 32 && n > 16) use_copy_based = true; |
1006 | |
        // Copy-based also wins over nocopy for these narrow-n shapes.
1008 | if (m >= 32 && n == 16 |
1009 | && (k >= 6400 || transa == do_trans || m == 4096)) |
1010 | use_copy_based = true; |
1011 | |
1012 | if (use_copy_based) return false; |
1013 | } |
1014 | |
1015 | if (m <= 378 && n <= 378 && k >= nthr * 378) return false; |
1016 | |
1017 | if (m >= nthr * 378 && k >= nthr * 378) return false; |
1018 | |
1019 | if (transb == no_trans) { |
1020 | if (m <= nthr * MN_NOTRANSB_PER_THR) return true; |
1021 | if (n <= nthr * MN_NOTRANSB_PER_THR) return true; |
1022 | if (k <= nthr * K_NOTRANSB_PER_THR) return true; |
1023 | } else { |
1024 | if (m <= nthr * M_TRANSB_PER_THR && m >= n) return true; |
1025 | if (n <= nthr * N_TRANSB_PER_THR) return true; |
1026 | if (k <= nthr * K_TRANSB_PER_THR) return true; |
1027 | } |
1028 | return false; |
1029 | } |
1030 | |
1031 | template <typename a_type, typename b_type, typename c_type> |
1032 | static inline bool nocopy_checker( |
1033 | int nthr, const gemm_info_t<a_type, b_type, c_type> *arg) { |
1034 | |
1035 | if (data_traits<a_type>::data_type != data_type::f32) return false; |
1036 | |
1037 | if (!mayiuse(avx)) return false; |
1038 | |
1039 | if (arg->force_nocopy) return true; |
1040 | |
1041 | auto m = arg->m, n = arg->n, k = arg->k; |
1042 | auto lda = arg->lda, ldb = arg->ldb, ldc = arg->ldc; |
1043 | auto transa = arg->transa, transb = arg->transb; |
1044 | auto packing = arg->packing; |
1045 | |
1046 | if (packing != pack_type::none) ldc = 64; |
1047 | |
1048 | if (arg->a_packed || arg->b_packed) |
1049 | return false; |
1050 | else if (mayiuse(avx512_core)) |
1051 | return nocopy_checker_avx512( |
1052 | nthr, transa, transb, m, n, k, lda, ldb, ldc); |
1053 | else |
1054 | return nocopy_checker_avx2( |
1055 | nthr, transa, transb, m, n, k, lda, ldb, ldc); |
1056 | } |
1057 | |
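// Chooses the threading layout for the conventional (non-packed) path: 1D row
// or column partitioning, 2D column-major partitioning, or a 3D decomposition,
// plus the copy strategy, depending on problem shape, data type and ISA.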
1058 | template <typename a_type, typename b_type, typename c_type> |
1059 | static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn, |
1060 | gemm_threading_t &thread_info, |
1061 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
1062 | |
1063 | static constexpr dim_t N2D_MAX = 384; |
1064 | static constexpr dim_t M2D_MIN = 384; |
1065 | |
1066 | constexpr bool is_int8 = utils::one_of( |
1067 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1068 | bool isSgemm = data_traits<a_type>::data_type == data_type::f32; |
1069 | |
1070 | dim_t m = arg->m; |
1071 | dim_t n = arg->n; |
1072 | dim_t k = arg->k; |
1073 | |
1074 | thread_info.nthrs_m = 0; |
1075 | thread_info.nthrs_n = 0; |
1076 | thread_info.nthrs_k = 0; |
1077 | thread_info.copy = copy_type::nonshared; |
1078 | thread_info.partition = partition_type::row_1d; |
1079 | |
1080 | // TODO Check if we can use dynamic scheduling for sgemm. |
1081 | // TODO Check if we should use 3D blocking. |
1082 | thread_info.nthrs_k = 1; |
1083 | thread_info.thread_k = k; |
1084 | |
1085 | bool condition_2D_bsrc = false; |
1086 | if (isSgemm) { |
1087 | // If m is large and n is small then do 1D partitioning for AVX2. |
1088 | if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN)) |
1089 | condition_2D_bsrc = false; |
1090 | else |
1091 | condition_2D_bsrc |
1092 | = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2)) |
1093 | && (m >= 2 * M2D_MIN); |
1094 | } else { |
1095 | int scale = mayiuse(avx512_core) ? nthrs : 20; |
1096 | condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n); |
1097 | } |
1098 | |
1099 | // TODO Check if we should use k-partitioning. |
1100 | |
1101 | int condition_1D_copya = false; |
1102 | if (mayiuse(avx512_core)) { |
1103 | const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68; |
1104 | if (m >= 1000 && (n >= nthrs * thresh)) { |
1105 | condition_2D_bsrc = false; |
1106 | condition_1D_copya = true; |
1107 | } |
1108 | } else { |
1109 | if (m >= 1000 && n >= 4000) { |
1110 | condition_2D_bsrc = false; |
1111 | condition_1D_copya = true; |
1112 | } |
1113 | } |
1114 | |
1115 | // If A or B offset is non-zero, we need to keep 1D_copya to reduce update |
1116 | // overhead. |
    // TODO: the reason seems to be in the copy_sum_bx routines. At least,
    // after a simple optimization of copy_sum_ax for avx512, a similar
    // restriction on offset B became unnecessary. Revisit.
1120 | if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) { |
1121 | condition_2D_bsrc = false; |
1122 | condition_1D_copya = true; |
1123 | } |
1124 | |
1125 | if (condition_2D_bsrc) { |
1126 | int nthrs_m = 1; |
1127 | int nthrs_n = nthrs; |
1128 | |
1129 | if (isSgemm) { |
1130 | while ((nthrs_n % 2 == 0) |
1131 | && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2) |
1132 | && (m / nthrs_m >= 2 * M2D_MIN) && (nthrs_m < 4)) { |
1133 | nthrs_m *= 2; |
1134 | nthrs_n /= 2; |
1135 | } |
1136 | |
1137 | thread_info.nthrs_m = nthrs_m; |
1138 | thread_info.nthrs_n = nthrs_n; |
1139 | thread_info.partition = partition_type::col_major_2d; |
1140 | } else { |
1141 | if (m == 800 && n == 300) { |
1142 | // TODO: Expand this branch to other problem sizes. |
1143 | |
1144 | auto &thread_m = thread_info.thread_m; |
1145 | auto &thread_n = thread_info.thread_n; |
1146 | |
1147 | const dim_t block_m = arg->um * 4; |
1148 | constexpr dim_t block_n = 64; |
1149 | constexpr dim_t small_m = 16; |
1150 | constexpr dim_t small_n = 2; |
1151 | |
1152 | std::tie(nthrs_m, nthrs_n) |
1153 | = gemm_utils::calc_nthr_2d(nthrs, m, n, block_m, |
1154 | block_n, small_m, small_n, thread_m, thread_n); |
1155 | |
1156 | thread_info.nthrs_m = nthrs_m; |
1157 | thread_info.nthrs_n = nthrs_n; |
1158 | thread_info.partition = partition_type::mnk_3d; |
1159 | |
1160 | } else if ((n <= 64 || n >= 256)) { |
1161 | while (((nthrs_n > 1) && (n / nthrs_n < arg->un) |
1162 | && (m / nthrs_m >= 2 * arg->um) |
1163 | && mayiuse(avx512_core)) |
1164 | || ((nthrs_n % 2 == 0) |
1165 | && (n / nthrs > N2D_MAX |
1166 | || n / nthrs_n <= N2D_MAX / 2) |
1167 | && (m / nthrs_m >= 2 * M2D_MIN) |
1168 | && (nthrs_m < 4))) { |
1169 | nthrs_m *= 2; |
1170 | nthrs_n /= 2; |
1171 | } |
1172 | |
1173 | thread_info.nthrs_m = nthrs_m; |
1174 | thread_info.nthrs_n = nthrs_n; |
1175 | thread_info.partition = partition_type::col_major_2d; |
1176 | } else { |
1177 | // Use 3D decomposition from pack api without k-partitioning. |
1178 | set_thread_opts_pack(nthrs, thread_info, arg, false); |
1179 | } |
1180 | } |
1181 | |
1182 | } else if (condition_1D_copya && dnnl_thr_syncable()) { |
1183 | // Use parallel copy A algorithm |
1184 | thread_info.copy = copy_type::shared_a; |
1185 | thread_info.partition = partition_type::col_1d; |
1186 | thread_info.nthrs_m = 1; |
1187 | thread_info.nthrs_n = nthrs_spawn; // Using all spawned threads. |
1188 | } else { |
1189 | auto veclen = get_vector_length<c_type>(); |
1190 | |
1191 | if (m > n && (m >= nthrs * veclen || n < nthrs)) { |
1192 | if (n <= 20 && is_int8) { |
1193 | // Use 3D decomposition forcing m-blocking only. |
1194 | set_thread_opts_pack( |
1195 | nthrs, thread_info, arg, false, true, false); |
1196 | } else { |
1197 | thread_info.partition = partition_type::row_1d; |
1198 | thread_info.nthrs_m = nthrs; |
1199 | thread_info.nthrs_n = 1; |
1200 | } |
1201 | } else { |
1202 | thread_info.partition = partition_type::col_1d; |
1203 | thread_info.nthrs_m = 1; |
1204 | thread_info.nthrs_n = nthrs; |
1205 | } |
1206 | } |
1207 | } |
1208 | |
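// Chooses a 3D (m/n/k) thread decomposition for the packed path: picks nthr_m,
// nthr_n, nthr_k and the per-thread block sizes, aligning blocks to the kernel
// unroll factors; for int8 it tries to rebalance the m/n threads if rounding
// the blocking lost one.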
1209 | template <typename a_type, typename b_type, typename c_type> |
1210 | static inline void set_thread_opts_pack(int nthrs, |
1211 | gemm_threading_t &thread_info, |
1212 | const gemm_info_t<a_type, b_type, c_type> *arg, |
1213 | bool do_k_blocking = true, bool do_m_blocking = true, |
1214 | bool do_n_blocking = true) { |
1215 | |
1216 | constexpr bool is_int8 = utils::one_of( |
1217 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1218 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
1219 | |
1220 | bool do_m_blocking_only = do_m_blocking && !do_n_blocking; |
1221 | |
1222 | auto m = arg->m, n = arg->n, k = arg->k; |
1223 | |
1224 | auto &nthr_m = thread_info.nthrs_m; |
1225 | auto &nthr_n = thread_info.nthrs_n; |
1226 | auto &nthr_k = thread_info.nthrs_k; |
1227 | auto &thread_m = thread_info.thread_m; |
1228 | auto &thread_n = thread_info.thread_n; |
1229 | auto &thread_k = thread_info.thread_k; |
1230 | auto &block_m = thread_info.block_m; |
1231 | auto &block_n = thread_info.block_n; |
1232 | auto &block_k = thread_info.block_k; |
1233 | |
1234 | constexpr auto MBLK = 64; |
1235 | constexpr auto NBLK = 64; |
1236 | auto KBLK = is_int8 ? 3072 : 256; |
1237 | KBLK = do_m_blocking_only && is_int8 ? 384 : KBLK; |
1238 | |
1239 | nthr_m = nthr_n = nthr_k = 1; |
1240 | thread_info.copy = copy_type::nonshared; |
1241 | thread_info.partition = partition_type::mnk_3d; |
1242 | |
1243 | auto choose_blocking |
1244 | = [](dim_t size_z, dim_t &thread_z, int &nthr_z, dim_t block_z_init, |
1245 | dim_t &block_z, dim_t block_align) { |
1246 | thread_z = utils::div_up(size_z, nthr_z); |
1247 | auto num_blk = utils::div_up(thread_z, block_z_init); |
1248 | block_z = utils::div_up(thread_z, num_blk); |
1249 | block_z = utils::rnd_up(block_z, block_align); |
1250 | thread_z = num_blk * block_z; |
1251 | if (thread_z * nthr_z > size_z) |
1252 | nthr_z = utils::div_up(size_z, thread_z); |
1253 | }; |
1254 | |
1255 | auto choose_m_blocking = [&]() { |
1256 | auto align = get_vector_length<c_type>(); |
1257 | align = do_m_blocking_only ? arg->um : align; |
1258 | choose_blocking(m, thread_m, nthr_m, arg->bm, block_m, align); |
1259 | }; |
1260 | auto choose_n_blocking = [&]() { |
1261 | choose_blocking(n, thread_n, nthr_n, arg->bn, block_n, arg->un); |
1262 | }; |
1263 | auto choose_k_blocking = [&]() { |
1264 | auto align = nstl::max(arg->uk, dim_t(4)); |
1265 | choose_blocking(k, thread_k, nthr_k, arg->bk, block_k, align); |
1266 | }; |
1267 | |
1268 | // Choose k blocking. |
1269 | if ((m / MBLK + n / NBLK) < nthrs && do_k_blocking) { |
1270 | for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) |
1271 | if (nthrs % nk == 0) nthr_k = nk; |
1272 | |
1273 | // Sacrifice one thread and try again if parallelism is too small in |
1274 | // n-dimension. |
1275 | if (nthr_k == 1 && nthrs > 1 && do_m_blocking_only) { |
1276 | nthrs--; |
1277 | for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) |
1278 | if (nthrs % nk == 0) nthr_k = nk; |
1279 | } |
1280 | |
1281 | // Allow up to 2 threads to be sacrificed for large k >> m, n. |
1282 | if (nthr_k < 4 && k >= m * 4 && k >= n * 4 && nthrs > 10 && is_bf16) { |
1283 | for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++) |
1284 | if (nthrs % nk <= 2) nthr_k = nk; |
1285 | } |
1286 | } |
1287 | |
1288 | choose_k_blocking(); |
1289 | |
1290 | // Choose m/n blocking. |
1291 | auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um; |
1292 | min_mblk = do_m_blocking ? min_mblk : m; |
1293 | min_mblk = do_m_blocking_only ? arg->um : min_mblk; |
1294 | auto min_nblk = do_n_blocking ? NBLK / 2 : n; |
1295 | |
1296 | std::tie(nthr_m, nthr_n) = partition_2d_minblk(m, n, MBLK, NBLK, min_mblk, |
1297 | min_nblk, arg->um, arg->un, nthrs / nthr_k, |
1298 | do_m_blocking && do_n_blocking && do_k_blocking); |
1299 | |
1300 | auto nthr_m_init = nthr_m, nthr_n_init = nthr_n; |
1301 | |
1302 | choose_m_blocking(); |
1303 | choose_n_blocking(); |
1304 | |
1305 | if (is_int8 && do_m_blocking && do_n_blocking) { |
1306 | // If we lost a thread in one dimension because we padded the blocking |
1307 | // size, try to rebalance the other dimensions. |
1308 | if ((nthr_n != nthr_n_init) |
1309 | && ((nthr_m + 1) * nthr_n * nthr_k <= nthrs)) { |
1310 | nthr_m++; |
1311 | choose_m_blocking(); |
1312 | } |
1313 | |
1314 | if ((nthr_m != nthr_m_init) |
1315 | && (nthr_m * (nthr_n + 1) * nthr_k <= nthrs)) { |
1316 | nthr_n++; |
1317 | choose_n_blocking(); |
1318 | } |
1319 | } |
1320 | } |
1321 | |
1322 | template <typename a_type, typename b_type, typename c_type> |
1323 | static inline int set_thread_opts(int nthrs, int nthrs_spawn, |
1324 | gemm_threading_t &thread_info, |
1325 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
1326 | |
1327 | thread_info.block_m = thread_info.block_n = thread_info.block_k = -1; |
1328 | thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1; |
1329 | |
1330 | constexpr bool is_int8 = utils::one_of( |
1331 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1332 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
1333 | |
1334 | if (nocopy_checker(nthrs, arg)) { |
1335 | thread_info.copy = copy_type::no_copy; |
1336 | thread_info.partition = partition_type::mnk_3d; |
1337 | int nthrs_m = 0; |
1338 | int nthrs_n = 0; |
1339 | int nthrs_k = 0; |
1340 | dim_t BM = 0; |
1341 | dim_t BN = 0; |
1342 | dim_t BK = 0; |
1343 | auto m = arg->m, n = arg->n, k = arg->k; |
1344 | |
1345 | if (mayiuse(avx512_core)) { |
1346 | cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs, |
1347 | &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK); |
1348 | } else { |
1349 | cpu::gemm_utils::calc_nthr_nocopy_avx(m, n, k, nthrs, &nthrs_m, |
1350 | &nthrs_n, &nthrs_k, &BM, &BN, &BK); |
1351 | } |
1352 | |
1353 | // Block information is being ignored. We will create partitioning |
1354 | // later. |
1355 | thread_info.nthrs_m = nthrs_m; |
1356 | thread_info.nthrs_n = nthrs_n; |
1357 | thread_info.nthrs_k = nthrs_k; |
1358 | } else { |
1359 | if (arg->packing != pack_type::none && (is_int8 || is_bf16)) |
1360 | set_thread_opts_pack(nthrs, thread_info, arg); |
1361 | else |
1362 | set_thread_opts_nopack(nthrs, nthrs_spawn, thread_info, arg); |
1363 | } |
1364 | |
1365 | return thread_info.nthrs_m * thread_info.nthrs_n * thread_info.nthrs_k; |
1366 | } |
1367 | |
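// Translates a thread's slice offsets into the corresponding A, B, C and C
// offset sub-matrix pointers, honoring the transpose flags and the C offset
// type.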
1368 | template <typename a_type, typename b_type, typename c_type> |
1369 | static inline std::tuple<const a_type *, const b_type *, c_type *, |
1370 | const c_type *> |
1371 | decompose_matrices(const gemm_slice_t &slice, |
1372 | const gemm_info_t<a_type, b_type, c_type> *arg) { |
1373 | |
1374 | dim_t stride_am = (arg->transa == no_trans) ? 1 : arg->lda; |
1375 | dim_t stride_ak = (arg->transa != no_trans) ? 1 : arg->lda; |
1376 | dim_t stride_bn = (arg->transb != no_trans) ? 1 : arg->ldb; |
1377 | dim_t stride_bk = (arg->transb == no_trans) ? 1 : arg->ldb; |
1378 | |
1379 | auto a = arg->a; |
1380 | auto b = arg->b; |
1381 | auto c = arg->c; |
1382 | if (a) a += slice.off_m * stride_am + slice.off_k * stride_ak; |
1383 | if (b) b += slice.off_n * stride_bn + slice.off_k * stride_bk; |
1384 | if (c) c += slice.off_m + slice.off_n * arg->ldc; |
1385 | |
1386 | dim_t co_stride; |
1387 | switch (arg->offsetc) { |
1388 | case offset_type::row: co_stride = slice.off_n; break; |
1389 | case offset_type::column: co_stride = slice.off_m; break; |
1390 | default: co_stride = 0; break; |
1391 | } |
1392 | auto co = arg->co; |
1393 | if (co) co += co_stride; |
1394 | |
1395 | return std::make_tuple(a, b, c, co); |
1396 | } |
1397 | |
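// Shared-A algorithm: unless A is pre-packed, the master thread allocates one
// shared buffer, all threads cooperatively copy each A block into it (one band
// per thread) and synchronize on barriers; every thread then runs
// kernel_driver_parallel_acopiedbcopy on its own portion of C.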
1398 | template <typename a_type, typename b_type, typename c_type> |
1399 | static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs, |
1400 | const dim_t m, const dim_t n, const dim_t k, const a_type *a, |
1401 | const b_type *b, float beta, c_type *c, dim_t ldc, offset_type offsetc, |
1402 | const c_type *co, const gemm_info_t<a_type, b_type, c_type> *arg, |
1403 | char **p_shared_mem) { |
1404 | |
1405 | if (arg->packing != pack_type::none) |
1406 | return gemm_packing_driver(ithr, m, n, k, a, b, arg); |
1407 | |
1408 | const dim_t lda = arg->lda; |
1409 | const dim_t ldb = arg->ldb; |
1410 | const dim_t strideAm = (arg->transa == no_trans) ? 1 : lda; |
1411 | const dim_t strideAn = (arg->transa != no_trans) ? 1 : lda; |
1412 | const dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb; |
1413 | |
1414 | float alpha = arg->alpha; |
1415 | |
1416 | constexpr bool is_int8 = utils::one_of( |
1417 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1418 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
1419 | bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); |
1420 | bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); |
1421 | bool is_amx = is_int8_amx || is_bf16_amx; |
1422 | |
1423 | const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed; |
1424 | |
1425 | // Scaling C matrix. |
1426 | if (!is_int8 && beta != 1.0f && beta != 0.0f) { |
1427 | scale_matrix(m, n, beta, c, ldc); |
1428 | beta = 1.0f; |
1429 | } |
1430 | |
1431 | // Padding along M, K dimensions. |
1432 | dim_t m_padd = get_m_padd_parallel_a(ithr, m, arg, nthrs); |
1433 | dim_t k_padd = get_k_padd(ithr, k, arg); |
1434 | |
1435 | size_t a_buf_nelems = m_padd * k_padd; |
1436 | |
1437 | // A buffer needs more space due to zero-padding. |
1438 | if (is_amx) |
1439 | a_buf_nelems = utils::rnd_up(m_padd, arg->um) |
1440 | * utils::rnd_up(k_padd, arg->uk); |
1441 | |
1442 | // Allocate shared memory for A and its row sum buffers in master thread. |
1443 | char *mem = nullptr; |
1444 | a_type *bufferA = nullptr; |
1445 | c_type *a_row_sum = nullptr; |
1446 | |
1447 | if (!a_packed) { |
        if (ithr == 0) { // If master thread
1449 | size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K); |
1450 | |
1451 | if (is_int8) { |
1452 | size_t a_row_sum_nelems = m_padd; |
1453 | mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K; |
1454 | } |
1455 | |
1456 | *p_shared_mem = (char *)malloc(mem_size, 128); |
1457 | } |
1458 | |
1459 | dnnl_thr_barrier(); |
1460 | |
1461 | mem = *p_shared_mem; |
1462 | bufferA = (a_type *)align(mem, PAGE_4K); |
1463 | |
1464 | if (is_int8) |
1465 | a_row_sum = (c_type *)align(bufferA + a_buf_nelems, PAGE_4K); |
1466 | |
1467 | if (!mem) return dnnl_out_of_memory; |
1468 | } |
1469 | |
1470 | dnnl_status_t result = dnnl_success; // Return status |
1471 | |
1472 | dim_t sizeK = 0; |
1473 | dim_t blk_k = 0; |
1474 | for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) { |
1475 | sizeK = k - Bk; |
1476 | if (sizeK > k_padd) sizeK = k_padd; |
1477 | |
1478 | // Scale C blocks by beta only for the first term of partial sum. |
1479 | auto beta_eff = (Bk == 0) ? beta : 1.0f; |
1480 | |
1481 | // Apply C offset for the last k-block of the partial sum. |
1482 | auto offsetc_eff = offset_type::none; |
1483 | if (Bk + sizeK == k) offsetc_eff = offsetc; |
1484 | |
1485 | dim_t sizeM = 0; |
1486 | for (dim_t Bm = 0; Bm < m; Bm += sizeM) { |
1487 | sizeM = m - Bm; |
1488 | if (sizeM > m_padd) sizeM = m_padd; |
1489 | |
1490 | if ((ithr < nthrs) && !a_packed) { |
1491 | dim_t band = (sizeM + nthrs - 1) / nthrs; |
1492 | band = utils::rnd_up(band, arg->um); |
1493 | |
1494 | dim_t offset = band * ithr; |
1495 | |
1496 | // If offset is too large don't use that thread for copying. |
1497 | if (offset >= sizeM) { |
1498 | offset = 0; |
1499 | band = 0; |
1500 | } |
1501 | |
1502 | // Handle the tail of the copy. |
1503 | if (offset + band > sizeM) { band = sizeM - offset; } |
1504 | |
1505 | if (band > 0) { |
1506 | const a_type *a_block |
1507 | = a + (Bm + offset) * strideAm + Bk * strideAn; |
1508 | |
1509 | dim_t buf_shift = 0; |
1510 | if (is_amx) |
1511 | buf_shift = offset * utils::rnd_up(sizeK, arg->uk); |
1512 | else |
1513 | buf_shift = offset * sizeK; |
1514 | |
1515 | /* Row sum argument is ignored for non-integer kernels and |
1516 | * scaling factor is ignored by 8-bit and 16-bit copy |
1517 | * kernels. |
1518 | */ |
1519 | c_type *a_row_sum_eff |
1520 | = a_row_sum ? a_row_sum + offset : nullptr; |
1521 | arg->copyA(&sizeK, &band, a_block, &lda, &alpha, |
1522 | bufferA + buf_shift, nullptr, nullptr, |
1523 | a_row_sum_eff); |
1524 | } |
1525 | } |
1526 | if (!a_packed) |
1527 | dnnl_thr_barrier(); // Wait for finishing parallel copy. |
1528 | |
1529 | const b_type *b_block = b + Bk * strideBm; |
1530 | c_type *c_block = c + Bm; |
1531 | |
1532 | dim_t co_stride = 0; |
1533 | if (offsetc_eff == offset_type::fixed) { |
1534 | co_stride = 0; |
1535 | } else if (offsetc_eff == offset_type::row) { |
1536 | co_stride = 0; |
1537 | } else if (offsetc_eff == offset_type::column) { |
1538 | co_stride = Bm; |
1539 | } |
1540 | |
1541 | auto bufferA_eff |
1542 | = a_packed ? a_packed->matrix<a_type>(0, Bm, Bk) : bufferA; |
1543 | auto a_row_sum_eff = a_packed |
1544 | ? a_packed->row_sums<c_type>(0, Bm, blk_k) |
1545 | : a_row_sum; |
1546 | |
1547 | auto this_result = kernel_driver_parallel_acopiedbcopy(ithr, sizeM, |
1548 | n, sizeK, blk_k, Bk, bufferA_eff, b_block, beta_eff, |
1549 | c_block, offsetc_eff, co + co_stride, a_row_sum_eff, arg); |
1550 | |
1551 | if (this_result != dnnl_success) result = this_result; |
1552 | |
1553 | if (!a_packed) |
1554 | dnnl_thr_barrier(); // Wait for kernel computations to finish. |
1555 | } |
1556 | } |
1557 | |
1558 | // Free memory allocated in master thread |
1559 | if (ithr == 0 && !a_packed) free(mem); |
1560 | |
1561 | return result; |
1562 | } |
1563 | |
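// Heuristically reduces the thread count for small problems: estimates the
// GEMM cost in cycles from m * n * k and the SIMD width, compares it against a
// simple per-thread OpenMP overhead model (plus an L2 footprint check under
// the threadpool runtime), and keeps only as many threads as the work can
// amortize.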
1564 | template <typename T> |
1565 | static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) { |
1566 | |
1567 | const double omp_overhead_small_core = 3.0e+3; |
1568 | const double omp_intercept_big_core = 4.0e+3; |
1569 | const double omp_slope_big_core = 5.0e+2; |
1570 | |
1571 | auto veclen = get_vector_length<T>(); |
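    // Rough peak-rate model: assume two FMA pipes, each delivering a multiply
    // and an add per vector lane per cycle.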
1572 | const double fp_per_cycle = 2.0 * 2.0 * veclen; |
1573 | |
1574 | const bool is_f32 = data_traits<T>::data_type == data_type::f32; |
1575 | |
1576 | const bool is_avx512 = mayiuse(avx512_core); |
1577 | const bool is_avx = mayiuse(avx); |
1578 | const bool is_only_avx2 = mayiuse(avx2) && !is_avx512; |
1579 | |
1580 | // Some sgemm cases still benefit from using all threads. |
1581 | const bool use_all_threads = is_f32 && n > 50 |
1582 | && ((is_avx && m <= 3) || (is_avx512 && m <= 10)); |
1583 | if (use_all_threads) return; |
1584 | |
1585 | if (is_only_avx2) |
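        // For tall, narrow problems, cap the thread count so that each
        // thread still keeps roughly three vector widths of rows.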
1586 | if (m > 10 * n && n < *nthrs) |
1587 | if (m / *nthrs < veclen * 3) |
1588 | *nthrs = nstl::max(m / veclen / 3, dim_t(1)); |
1589 | |
1590 | double gemm_cycles = m * n * k / fp_per_cycle; |
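    // The data-type factor below is an empirical correction; the non-f32
    // paths are assumed to have lower effective throughput.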
1591 | gemm_cycles *= is_f32 ? 2.0 : 8.0; |
1592 | |
1593 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
1594 | if (is_f32) { |
1595 | static auto l2_cache_per_thread = platform::get_per_core_cache_size(2); |
1596 | static int n_cores_per_socket |
1597 | = static_cast<int>(platform::get_num_cores()); |
1598 | auto l2_cache_socket = l2_cache_per_thread * n_cores_per_socket; |
1599 | auto problem_memory_footprint = (m * n + m * k + n * k) * sizeof(float); |
1600 | |
1601 | if (is_only_avx2) { |
            // Empirically, it seems beneficial to split the job into larger
            // pieces here; use the per-core L2 cache size as the cutoff.
1604 | int use_n_threads = utils::div_up( |
1605 | problem_memory_footprint, l2_cache_per_thread); |
1606 | *nthrs = nstl::min(*nthrs, use_n_threads); |
1607 | return; |
1608 | } |
1609 | if (l2_cache_socket > problem_memory_footprint) { |
1610 | *nthrs = nstl::min(*nthrs, n_cores_per_socket); |
1611 | return; |
1612 | } |
1613 | } |
1614 | #endif |
1615 | |
1616 | int i = *nthrs; |
1617 | |
    // Use a different model for OpenMP overheads when nthrs <= 4.
1619 | if (*nthrs <= 4 && omp_overhead_small_core > 0) { |
1620 | double omp_cycles = omp_overhead_small_core; |
1621 | if (gemm_cycles < omp_cycles) { |
1622 | *nthrs = 1; |
1623 | return; |
1624 | } else { |
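            // Drop threads while the modeled OpenMP overhead makes i threads
            // no faster than running single-threaded.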
1625 | while (i > 1) { |
1626 | if (omp_cycles * i < gemm_cycles * (i - 1)) break; |
1627 | --i; |
1628 | } |
1629 | } |
1630 | } else { |
1631 | if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) { |
1632 | *nthrs = 1; |
1633 | return; |
1634 | } |
1635 | |
        // Adaptive decrement to march through candidate thread counts faster.
1637 | while (i > 1) { |
1638 | double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core; |
1639 | if (omp_cycles * i < gemm_cycles * (i - 1)) break; |
1640 | |
1641 | if (i < 10) |
1642 | i -= 2; |
1643 | else if (i < 30) |
1644 | i -= 4; |
1645 | else |
1646 | i -= 8; |
1647 | } |
1648 | } |
1649 | |
1650 | if (i < 1) i = 1; |
1651 | |
1652 | *nthrs = i; |
1653 | } |
1654 | |
1655 | template <typename a_type, typename b_type, typename c_type> |
1656 | static dnnl_status_t call_no_copy_sgemm( |
1657 | int nthrs, gemm_info_t<a_type, b_type, c_type> *arg) { |
1658 | |
1659 | if (arg->packing == pack_type::none) { |
        auto transa_char = (arg->transa != do_trans) ? "N" : "T";
        auto transb_char = (arg->transb != do_trans) ? "N" : "T";
1662 | |
1663 | if (mayiuse(avx512_core)) |
1664 | return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char, |
1665 | &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a, |
1666 | &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta, |
1667 | (float *)arg->c, &arg->ldc, (float *)arg->co); |
1668 | else |
1669 | return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m, |
1670 | &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda, |
1671 | (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c, |
1672 | &arg->ldc, (float *)arg->co); |
1673 | } else |
1674 | return pack_no_copy(arg); |
1675 | } |
1676 | |
1677 | template <typename a_type, typename b_type, typename c_type> |
1678 | static dnnl_status_t gemm_threading_driver( |
1679 | gemm_info_t<a_type, b_type, c_type> *arg) { |
1680 | |
1681 | auto packing = (arg->packing != pack_type::none); |
1682 | auto is_a_packed = (arg->transa == packed); |
1683 | auto is_b_packed = (arg->transb == packed); |
1684 | constexpr bool is_int8 = utils::one_of( |
1685 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1686 | constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16; |
1687 | |
1688 | if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success; |
1689 | |
1690 | if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg)) |
1691 | return dnnl_success; |
1692 | |
1693 | if (!is_a_packed && !is_b_packed |
1694 | && jump_to_gemm_smalln_tn(arg) == dnnl_success) |
1695 | return dnnl_success; |
1696 | |
1697 | if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success) |
1698 | return dnnl_success; |
1699 | |
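    // Integer offsets need precomputed sums in the packed data: a nonzero
    // B offset requires row sums of A, and a nonzero A offset requires
    // column sums of B.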
1700 | if (is_a_packed && arg->bo != 0) |
1701 | if (!arg->a_packed->has_row_sums()) return dnnl_invalid_arguments; |
1702 | |
1703 | if (is_b_packed && arg->ao != 0) |
1704 | if (!arg->b_packed->has_col_sums()) return dnnl_invalid_arguments; |
1705 | |
1706 | auto nthr_max = dnnl_get_current_num_threads(); |
1707 | int nthr_goal = nthr_max; |
1708 | |
1709 | adjust_thread_count<c_type>(arg->m, arg->n, arg->k, &nthr_goal); |
1710 | |
1711 | const gemm_threading_t *force_threading = nullptr; |
1712 | gemm_threading_t force_k_decomp; |
1713 | |
1714 | // Initialize per-thread data. |
1715 | // Note: to support k blocking with non-packed GEMM, threading must be |
1716 | // chosen now and force_threading set. |
1717 | if (!packing) { |
1718 | // Override choice of thread count if data is pre-packed for a particular |
1719 | // number of threads. |
1720 | if (is_a_packed && is_b_packed) |
1721 | if (arg->a_packed->threading() != arg->b_packed->threading()) |
1722 | return dnnl_invalid_arguments; |
1723 | if (is_a_packed) |
1724 | force_threading = &arg->a_packed->threading(); |
1725 | else if (is_b_packed) |
1726 | force_threading = &arg->b_packed->threading(); |
1727 | else if (arg->m <= 768 && arg->n <= 768 && arg->k >= 2048 && is_bf16) { |
1728 | // Try k-partitioning. |
1729 | set_thread_opts_pack(nthr_goal, force_k_decomp, arg); |
1730 | |
1731 | // Decide partition type later if no partitions in k-dimension. |
1732 | if (force_k_decomp.nthrs_k > 1) force_threading = &force_k_decomp; |
1733 | } else if (arg->n <= 128 && arg->k >= 3072 && is_int8) { |
            // Use k-partitioning if necessary.
            // Use the 3D decomposition from the pack API, without
            // n-partitioning.
1736 | set_thread_opts_pack( |
1737 | nthr_goal, force_k_decomp, arg, true, true, false); |
1738 | |
1739 | // Decide partition type later if no partitions in k-dimension. |
1740 | if (force_k_decomp.nthrs_k > 1 && force_k_decomp.nthrs_m > 1) |
1741 | force_threading = &force_k_decomp; |
1742 | } |
1743 | |
1744 | if (force_threading) { |
1745 | nthr_goal = force_threading->nthrs(); |
1746 | arg->update_blocking(*force_threading); |
1747 | } |
1748 | } else { |
1749 | // Prepare packed data layout. |
1750 | gemm_pack_storage_t *pack_dst = arg->pack_dst; |
1751 | bool do_a = (arg->packing == pack_type::pack_a); |
1752 | |
1753 | pack_dst->which() = do_a ? matrix_id::a : matrix_id::b; |
1754 | pack_dst->setup(nthr_goal, do_a && is_int8, !do_a && is_int8); |
1755 | |
1756 | auto &thread_info = pack_dst->threading(); |
1757 | force_threading = &thread_info; |
1758 | |
1759 | nthr_goal = set_thread_opts(nthr_goal, nthr_max, thread_info, arg); |
1760 | arg->update_blocking(thread_info); |
1761 | |
1762 | if (thread_info.copy != copy_type::no_copy) { |
1763 | for (int ithr = 0; ithr < nthr_goal; ithr++) { |
1764 | if (!pack_dst->is_first_thread_in_slice(ithr)) continue; |
1765 | |
1766 | auto slice = thread_info.get_thread_slice( |
1767 | ithr, arg->m, arg->n, arg->k); |
1768 | |
1769 | auto m = slice.m, n = slice.n, k = slice.k; |
1770 | |
1771 | auto m_padd = (thread_info.copy == copy_type::shared_a) |
1772 | ? get_m_padd_parallel_a( |
1773 | ithr, m, arg, thread_info.nthrs()) |
1774 | : get_m_padd(ithr, m, arg); |
1775 | auto n_padd = get_n_padd(ithr, n, k, arg); |
1776 | auto k_padd = get_k_padd(ithr, k, arg); |
1777 | |
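                // Record this thread's blocking: (m x k) blocks when packing
                // A, (k x n) blocks when packing B.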
1778 | do_a ? pack_dst->set_blocking(ithr, m, k, m_padd, k_padd) |
1779 | : pack_dst->set_blocking(ithr, k, n, k_padd, n_padd); |
1780 | } |
1781 | } else { |
1782 | auto ld = do_a ? gemm_utils::get_ld_padd<a_type>(arg->m) |
1783 | : gemm_utils::get_ld_padd<b_type>(arg->k); |
1784 | |
1785 | pack_dst->set_nocopy(0, no_trans, ld, do_a ? arg->k : arg->n); |
1786 | } |
1787 | |
1788 | do_a ? pack_dst->finalize<a_type, c_type>() |
1789 | : pack_dst->finalize<b_type, c_type>(); |
1790 | |
1791 | if (arg->measure_only) return dnnl_success; |
1792 | } |
1793 | |
1794 | if (nocopy_checker(nthr_goal, arg)) |
1795 | return call_no_copy_sgemm(nthr_goal, arg); |
1796 | |
1797 | if (nthr_goal == 1) |
1798 | return gemm_kernel_driver(0, arg->m, arg->n, arg->k, arg->a, arg->b, |
1799 | arg->beta, arg->c, arg->ldc, arg->offsetc, arg->co, arg); |
1800 | |
1801 | bool k_blocking = force_threading && (force_threading->nthrs_k > 1); |
1802 | bool k_summing = k_blocking && !packing; |
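    // When summing over k, every k-slice except the first computes into a
    // thread-local C buffer (allocated below); the slices are then reduced
    // into the user's C.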
1803 | |
1804 | auto *thread_arg = (gemm_per_thread_t<c_type> *)malloc( |
1805 | sizeof(gemm_per_thread_t<c_type>) * nthr_max, PAGE_4K); |
1806 | |
1807 | if (!thread_arg) return dnnl_out_of_memory; |
1808 | |
1809 | dim_t max_mt = 0, max_nt = 0; |
1810 | for (int ithr = 0; ithr < nthr_max; ithr++) { |
1811 | thread_arg[ithr].result = dnnl_success; |
1812 | thread_arg[ithr].compute_done = false; |
1813 | thread_arg[ithr].c_local = thread_arg[ithr].c_global = nullptr; |
1814 | thread_arg[ithr].ldc_global = arg->ldc; |
1815 | thread_arg[ithr].ldc_local = 0; |
1816 | |
1817 | if (force_threading) { |
1818 | thread_arg[ithr].slice = force_threading->get_thread_slice( |
1819 | ithr, arg->m, arg->n, arg->k); |
1820 | thread_arg[ithr].nthr_k = force_threading->nthrs_k; |
1821 | thread_arg[ithr].thr_k_stride = force_threading->thr_k_stride(); |
1822 | max_mt = nstl::max(max_mt, thread_arg[ithr].slice.m); |
1823 | max_nt = nstl::max(max_nt, thread_arg[ithr].slice.n); |
1824 | } else { |
1825 | thread_arg[ithr].slice = {0, 0, 0, 0, 0, 0, 0, 0, 0}; |
1826 | thread_arg[ithr].nthr_k = 1; |
1827 | thread_arg[ithr].thr_k_stride = 0; |
1828 | } |
1829 | } |
1830 | |
1831 | // Create temporary C buffers for k blocking if needed. |
1832 | c_type *c_local_storage = nullptr; |
1833 | if (k_summing) { |
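        // Avoid a local leading dimension that is a multiple of 256 elements,
        // which can cause cache-associativity conflicts; pad it in that case.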
1834 | const dim_t BAD_LD_MULT = 256; |
1835 | dim_t ldc_local = max_mt % BAD_LD_MULT |
1836 | ? max_mt |
1837 | : gemm_utils::get_ld_padd<c_type>(max_mt); |
1838 | dim_t c_local_stride = ldc_local * max_nt; |
1839 | c_local_storage = (c_type *)malloc( |
1840 | sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K); |
1841 | |
1842 | if (!c_local_storage) { |
1843 | free(thread_arg); |
1844 | return dnnl_out_of_memory; |
1845 | } |
1846 | |
1847 | for (int ithr = 0; ithr < nthr_goal; ithr++) { |
1848 | thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride; |
1849 | thread_arg[ithr].ldc_local = ldc_local; |
1850 | } |
1851 | } |
1852 | |
1853 | char *shared_mem = nullptr; |
1854 | |
    // Always use the maximum number of threads to avoid the OpenMP overhead
    // that can occur due to changing thread counts.
1857 | int nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal; |
1858 | |
1859 | parallel(nthr_spawn, [&](int ithr, int nthr) { |
1860 | int nthr_eff = force_threading ? nthr_goal : nstl::min(nthr_goal, nthr); |
1861 | |
1862 | if (nthr_eff == 1) { |
1863 | thread_arg[0].result = gemm_kernel_driver(0, arg->m, arg->n, arg->k, |
1864 | arg->a, arg->b, arg->beta, arg->c, arg->ldc, arg->offsetc, |
1865 | arg->co, arg); |
1866 | } else { |
1867 | gemm_threading_t thread_info; |
1868 | |
1869 | if (force_threading) |
1870 | thread_info = *force_threading; |
1871 | else { |
1872 | nthr_eff = set_thread_opts(nthr_eff, nthr, thread_info, arg); |
1873 | if (ithr < nthr_eff) |
1874 | thread_arg[ithr].slice = thread_info.get_thread_slice( |
1875 | ithr, arg->m, arg->n, arg->k); |
1876 | } |
1877 | |
1878 | for (; ithr < nthr_eff; ithr += nthr) { |
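                // If fewer threads were spawned than planned, each spawned
                // thread covers several logical thread slots.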
1879 | // Get submatrices and parameters for this thread's GEMM. |
1880 | const a_type *a = nullptr; |
1881 | const b_type *b = nullptr; |
1882 | c_type *c = nullptr; |
1883 | const c_type *co = nullptr; |
1884 | std::tie(a, b, c, co) |
1885 | = decompose_matrices(thread_arg[ithr].slice, arg); |
1886 | |
1887 | auto m = thread_arg[ithr].slice.m; |
1888 | auto n = thread_arg[ithr].slice.n; |
1889 | auto k = thread_arg[ithr].slice.k; |
1890 | thread_arg[ithr].c_global = c; |
1891 | auto c_eff = c; |
1892 | auto ldc_eff = arg->ldc; |
1893 | auto beta_eff = arg->beta; |
1894 | auto offsetc_eff = arg->offsetc; |
1895 | |
1896 | // For all but first k block: substitute local C matrix and |
1897 | // disable postops. |
1898 | if (k_summing && thread_arg[ithr].slice.ithr_k > 0) { |
1899 | c_eff = thread_arg[ithr].c_local; |
1900 | ldc_eff = thread_arg[ithr].ldc_local; |
1901 | beta_eff = 0; |
1902 | offsetc_eff = offset_type::none; |
1903 | } |
1904 | |
1905 | // Dispatch appropriate GEMM driver. |
1906 | switch (thread_info.copy) { |
1907 | case copy_type::shared_a: |
1908 | thread_arg[ithr].result = parallel_a_copy(ithr, |
1909 | nthr_eff, m, n, k, a, b, beta_eff, c_eff, |
1910 | ldc_eff, offsetc_eff, co, arg, &shared_mem); |
1911 | break; |
1912 | |
1913 | default: |
1914 | case copy_type::nonshared: |
1915 | thread_arg[ithr].result = gemm_kernel_driver(ithr, m, n, |
1916 | k, a, b, beta_eff, c_eff, ldc_eff, offsetc_eff, |
1917 | co, arg); |
1918 | break; |
1919 | |
1920 | case copy_type::no_copy: |
                        // This route is taken only if we realize we need
                        // no-copy after launching the parallel section, due
                        // to fewer threads being spawned than expected.
1924 | assert(data_traits<a_type>::data_type |
1925 | == data_type::f32); |
1926 | assert(arg->packing == pack_type::none); |
1927 | |
1928 | if (mayiuse(avx512_core)) { |
1929 | avx512_common_gemm_f32::sgemm_nocopy_driver( |
                                arg->transa == no_trans ? "N" : "T",
                                arg->transb == no_trans ? "N" : "T", m, n,
1932 | k, &arg->alpha, (float *)a, arg->lda, |
1933 | (float *)b, arg->ldb, &beta_eff, |
1934 | (float *)c_eff, ldc_eff, nullptr); |
1935 | } else { |
1936 | avx_gemm_f32::sgemm_nocopy_driver( |
                                arg->transa == no_trans ? "N" : "T",
                                arg->transb == no_trans ? "N" : "T", m, n,
1939 | k, &arg->alpha, (float *)a, arg->lda, |
1940 | (float *)b, arg->ldb, &beta_eff, |
1941 | (float *)c_eff, ldc_eff, nullptr); |
1942 | } |
1943 | thread_arg[ithr].result = dnnl_success; |
1944 | break; |
1945 | } |
1946 | |
1947 | // Sum thread results along k dimension, parallelized in the n |
1948 | // dimension. To avoid deadlocks, results are summed later if |
1949 | // not all threads are running concurrently. We can only detect |
1950 | // if this is safe when using OpenMP. |
1951 | #if DNNL_THR_SYNC == 1 |
1952 | if (k_summing && (nthr >= nthr_eff)) { |
1953 | thread_arg[ithr].compute_done = true; |
1954 | sum_k_blocks(ithr, thread_arg, true); |
1955 | } |
1956 | #endif |
1957 | } |
1958 | } |
1959 | }); |
1960 | |
1961 | dnnl_status_t result = dnnl_success; // Initialize to success |
1962 | for (int ithr = 0; ithr < nthr_max; ithr++) { |
1963 | if (thread_arg[ithr].result != dnnl_success) { |
1964 | result = static_cast<dnnl_status_t>(thread_arg[ithr].result); |
1965 | break; |
1966 | } |
1967 | } |
1968 | |
1969 | // Sum thread results along k dimension if this wasn't done earlier. |
1970 | if (k_summing && !thread_arg[0].compute_done) { |
1971 | parallel(nthr_goal, [&](int ithr, int nthr) { |
1972 | for (; ithr < nthr_goal; ithr += nthr) |
1973 | sum_k_blocks(ithr, thread_arg, false); |
1974 | }); |
1975 | } |
1976 | |
1977 | if (c_local_storage) dnnl::impl::free(c_local_storage); |
1978 | dnnl::impl::free(thread_arg); |
1979 | |
1980 | return result; |
1981 | } |
1982 | |
1983 | template <typename a_type, typename b_type, typename c_type> |
1984 | dnnl_status_t gemm_driver(const char *transA, const char *transB, |
1985 | const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, |
1986 | const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa, |
1987 | const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta, |
1988 | c_type *c, const dim_t *ldc, const c_type *oc, const bool force_nocopy, |
1989 | pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) { |
1990 | |
1991 | constexpr bool is_int8 = utils::one_of( |
1992 | data_traits<a_type>::data_type, data_type::s8, data_type::u8); |
1993 | MAYBE_UNUSED(is_int8); |
1994 | |
1995 | // gemm_driver supports bfloat16 gemm for Intel AVX512 and |
1996 | // Intel AVX512 BF16. |
1997 | assert(IMPLICATION(data_traits<a_type>::data_type == data_type::bf16, |
1998 | mayiuse(avx512_core) && !force_nocopy)); |
1999 | |
    // gemm_driver supports 8-bit integer gemm for Intel AVX512, Intel AVX2,
    // Intel AVX, Intel SSE4.1, and Intel DL Boost.
2002 | assert(IMPLICATION(is_int8, mayiuse(sse41))); |
2003 | |
    // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX,
    // and Intel SSE4.1.
2006 | assert(IMPLICATION( |
2007 | data_traits<a_type>::data_type == data_type::f32, mayiuse(sse41))); |
2008 | |
2009 | // 8-bit integer gemm doesn't support nocopy kernels. |
2010 | assert(IMPLICATION(is_int8, !force_nocopy)); |
2011 | |
2012 | // gemm_driver can only dispatch nocopy for avx and above. |
2013 | assert(IMPLICATION(force_nocopy, mayiuse(avx))); |
2014 | |
2015 | gemm_info_t<a_type, b_type, c_type> args(transA, transB, offsetC, m, n, k, |
2016 | alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy, |
2017 | packing, pack_dst, measure_only); |
2018 | |
2019 | // Check if copy algorithm kernels were generated on supported ISAs. |
2020 | if (!args.hasKernels()) return dnnl_unimplemented; |
2021 | |
2022 | return gemm_threading_driver(&args); |
2023 | } |
2024 | |
2025 | template // Instantiate gemm_bf16bf16f32 |
2026 | dnnl_status_t |
2027 | gemm_driver<bfloat16_t, bfloat16_t, float>(const char *transA, |
2028 | const char *transB, const char *offsetC, const dim_t *m, |
2029 | const dim_t *n, const dim_t *k, const float *alpha, |
2030 | const bfloat16_t *a, const dim_t *lda, const bfloat16_t *oa, |
2031 | const bfloat16_t *b, const dim_t *ldb, const bfloat16_t *ob, |
2032 | const float *beta, float *c, const dim_t *ldc, const float *oc, |
2033 | const bool force_nocopy, pack_type packing, |
2034 | gemm_pack_storage_t *pack_dst, bool measure_only); |
2035 | |
2036 | template // Instantiate gemm_s8s8s32 |
2037 | dnnl_status_t |
2038 | gemm_driver<int8_t, int8_t, int32_t>(const char *transA, |
2039 | const char *transB, const char *offsetC, const dim_t *m, |
2040 | const dim_t *n, const dim_t *k, const float *alpha, |
2041 | const int8_t *a, const dim_t *lda, const int8_t *oa, |
2042 | const int8_t *b, const dim_t *ldb, const int8_t *ob, |
2043 | const float *beta, int32_t *c, const dim_t *ldc, |
2044 | const int32_t *oc, const bool force_nocopy, pack_type packing, |
2045 | gemm_pack_storage_t *pack_dst, bool measure_only); |
2046 | |
2047 | template // Instantiate gemm_s8u8s32 |
2048 | dnnl_status_t |
2049 | gemm_driver<int8_t, uint8_t, int32_t>(const char *transA, |
2050 | const char *transB, const char *offsetC, const dim_t *m, |
2051 | const dim_t *n, const dim_t *k, const float *alpha, |
2052 | const int8_t *a, const dim_t *lda, const int8_t *oa, |
2053 | const uint8_t *b, const dim_t *ldb, const uint8_t *ob, |
2054 | const float *beta, int32_t *c, const dim_t *ldc, |
2055 | const int32_t *oc, const bool force_nocopy, pack_type packing, |
2056 | gemm_pack_storage_t *pack_dst, bool measure_only); |
2057 | |
2058 | template // Instantiate sgemm |
2059 | dnnl_status_t |
2060 | gemm_driver<float, float, float>(const char *transA, const char *transB, |
2061 | const char *offsetC, const dim_t *m, const dim_t *n, |
2062 | const dim_t *k, const float *alpha, const float *a, |
2063 | const dim_t *lda, const float *oa, const float *b, |
2064 | const dim_t *ldb, const float *ob, const float *beta, float *c, |
2065 | const dim_t *ldc, const float *oc, const bool force_nocopy, |
2066 | pack_type packing, gemm_pack_storage_t *pack_dst, |
2067 | bool measure_only); |
2068 | |
2069 | #undef MAX_STACK_SZ |
2070 | } // namespace x64 |
2071 | } // namespace cpu |
2072 | } // namespace impl |
2073 | } // namespace dnnl |
2074 | |