/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstdint>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"

#include "cpu/platform.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/gemm/gemm_info.hpp"
#include "cpu/x64/gemm/gemm_utils.hpp"
#include "cpu/x64/gemm/gemv_driver.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

// gemv kernel for the case when A is non-transposed, incy == 1, and X has any
// stride.
template <typename a_t, typename b_t, typename c_t>
static inline void gemv_n_kernel(const dim_t m, const dim_t n, float alpha,
        const a_t *__restrict a, const dim_t lda, const b_t *__restrict x,
        const dim_t incx, c_t *__restrict y, const dim_t incy,
        const gemm_info_t<a_t, b_t, c_t> *arg) {
    assert(incy == 1);

    auto gemv_n_kern = arg->gemv_kernel[no_trans];
    if (gemv_n_kern) {
        gemv_n_kern(&m, &n, &alpha, a, &lda, x, &incx, y, &incy);
    } else {
        if (incx == 1) {
            for (dim_t i = 0; i < n; i++) {
                PRAGMA_OMP_SIMD()
                for (dim_t j = 0; j < m; j++) {
                    y[j] += alpha * x[i] * a[j + i * lda];
                }
            }
        } else {
            dim_t idx = incx < 0 ? (1 - n) * incx : 0;
            for (dim_t i = 0; i < n; i++) {
                PRAGMA_OMP_SIMD()
                for (dim_t j = 0; j < m; j++) {
                    y[j] += alpha * x[idx] * a[j + i * lda];
                }
                idx += incx;
            }
        }
    }
}

// gemv kernel for the case when A is transposed, incx == 1, and Y has any
// stride.
template <typename a_t, typename b_t, typename c_t>
static inline void gemv_t_kernel(const dim_t m, const dim_t n, float alpha,
        const a_t *__restrict a, const dim_t lda, const b_t *__restrict x,
        const dim_t incx, c_t *__restrict y, const dim_t incy,
        const gemm_info_t<a_t, b_t, c_t> *arg) {
    assert(incx == 1);

    auto gemv_t_kern = arg->gemv_kernel[do_trans];
    if (gemv_t_kern) {
        gemv_t_kern(&m, &n, &alpha, a, &lda, x, &incx, y, &incy);
    } else {
        if (incy == 1) {
            for (dim_t i = 0; i < n; i++) {
                c_t temp = (c_t)0;
                for (dim_t j = 0; j < m; j++) {
                    temp += x[j] * a[j + i * lda];
                }
                y[i] += temp * alpha;
            }
        } else {
            dim_t idy = incy < 0 ? (1 - n) * incy : 0;
            for (dim_t i = 0; i < n; i++) {
                c_t temp = (c_t)0;
                for (dim_t j = 0; j < m; j++) {
                    temp += x[j] * a[j + i * lda];
                }
                y[idy] += temp * alpha;

                idy += incy;
            }
        }
    }
}

#define M_BLK 512
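// Single-threaded gemv driver: scales (or zeroes) Y by beta, then computes
// Y += alpha * op(A) * X by dispatching to the unit-stride kernels above.
// When the strided vector (Y in the non-transposed case, X in the transposed
// case) prevents a direct call, the work is blocked in chunks of at most
// M_BLK elements staged through a small temporary buffer so the unit-stride
// kernels can still be used.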
template <typename a_t, typename b_t, typename c_t>
static inline void gemv_kernel_driver(const int trans, const dim_t m,
        const dim_t n, const float alpha, const a_t *a, const dim_t lda,
        const b_t *x, const dim_t incx, const float beta, c_t *y,
        const dim_t incy, const gemm_info_t<a_t, b_t, c_t> *arg) {
    // Set dimensions of X and Y vectors based on transpose type.
    dim_t x_dim = trans == no_trans ? n : m;
    dim_t y_dim = trans == no_trans ? m : n;

    if (y_dim <= 0) return;

    // Set the indices for y and x vectors based on incx/incy.
    dim_t idx_x = incx < 0 ? (1 - x_dim) * incx : 0;
    dim_t idx_y = incy < 0 ? (1 - y_dim) * incy : 0;

    // Scale the Y vector.
    if (beta != 1.0f) {
        if (incy == 1) {
            if (beta == 0.0f) {
                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < y_dim; i++) {
                    y[i] = (c_t)0.0f;
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (dim_t i = 0; i < y_dim; i++) {
                    y[i] *= beta;
                }
            }
        } else {
            if (beta == 0.0f) {
                for (dim_t i = 0, inc = idx_y; i < y_dim; i++) {
                    y[inc] = (c_t)0.0f;
                    inc += incy;
                }
            } else {
                for (dim_t i = 0, inc = idx_y; i < y_dim; i++) {
                    y[inc] *= beta;
                    inc += incy;
                }
            }
        }
    }

    if (x_dim <= 0 || alpha == 0.0f) return;

    if (trans == no_trans) { // A is not transposed.
        if (incy == 1) {
            gemv_n_kernel(m, n, alpha, a, lda, x, incx, y, incy, arg);
        } else {
            // Allocate temporary buffer for y vector.
#if !defined(_MSC_VER)
            c_t ytmp[M_BLK];
#else
            c_t *ytmp = (c_t *)_alloca(sizeof(*ytmp) * M_BLK);
#endif

            dim_t m_blk = 0;
            for (dim_t i = 0; i < m; i += m_blk) {
                m_blk = m - i;
                if (m_blk > M_BLK) m_blk = M_BLK;

                PRAGMA_OMP_SIMD()
                for (dim_t j = 0; j < m_blk; j++)
                    ytmp[j] = (c_t)0.0;

                // Call unit-stride kernel.
                gemv_n_kernel(m_blk, n, alpha, a, lda, x, incx, ytmp, 1, arg);

                // Add matrix-vector result back to y vector.
                for (dim_t j = 0, inc = idx_y; j < m_blk; j++) {
                    y[inc] += ytmp[j];
                    inc += incy;
                }
                a += m_blk;
                y += m_blk * incy;
            }
        }
    } else { // Matrix A is transposed.
        if (incx == 1) {
            gemv_t_kernel(m, n, alpha, a, lda, x, incx, y, incy, arg);
        } else {
            // Allocate temporary buffer for x vector.
#if !defined(_MSC_VER)
            b_t xtmp[M_BLK];
#else
            b_t *xtmp = (b_t *)_alloca(sizeof(*xtmp) * M_BLK);
#endif
            dim_t m_blk = 0;
            for (dim_t i = 0; i < m; i += m_blk) {
                m_blk = m - i;
                if (m_blk > M_BLK) m_blk = M_BLK;

                // Copy a block of x vector to temporary buffer.
                for (dim_t j = 0, inc = idx_x; j < m_blk; j++) {
                    xtmp[j] = x[inc];
                    inc += incx;
                }

                // Call unit-stride kernel.
                gemv_t_kernel(m_blk, n, alpha, a, lda, xtmp, 1, y, incy, arg);

                a += m_blk;
                x += m_blk * incx;
            }
        }
    }
}
#undef M_BLK

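// Empirically chosen thresholds used by the threading heuristic below.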
#define M_MIN 128
#define N_MIN 128
#define BAND_MIN 32
#define MN_MIN_N 1536
#define MN_MIN_T 2048
#define M_LARGE 20000
#define N_LARGE 20000
#define M_SMALL 200
#define N_SMALL 200
#define CONST1_AVX2 288
#define CONST2_AVX2 41700
#define MIN_WIDTH 32
// Check whether threading is beneficial and return the suggested thread count.
template <typename a_t>
static inline int thread_checker(
        int nthr, const dim_t m, const dim_t n, int trans) {
    constexpr bool is_f32
            = utils::one_of(data_traits<a_t>::data_type, data_type::f32);

    if (is_f32) {
        // Threshold based on performance measurement with warm and cold cache
        // to decide when threading is beneficial.
        if (mayiuse(avx2)) {
            if (m * n + CONST1_AVX2 * n < CONST2_AVX2) { return 1; }
        } else {
            if (m < M_MIN && n < N_MIN) {
                // Execute in sequential mode for small n and m.
                return 1;
            }
        }

        if (m >= M_LARGE && n <= N_SMALL) {
            // Execute in parallel mode.
            return nthr;
        }

        dim_t bandt = n / nthr; // size per thread.

        if (nthr <= 12 && bandt < BAND_MIN) {
            if (m * bandt < MN_MIN_T) { return 1; }
        } else if (nthr <= 12 && m * bandt < 2 * MN_MIN_T) {
            return 1;
        } else if (nthr > 12 && bandt * m < 2 * MN_MIN_T) {
            if (bandt == 0) {
                return 1;
            } else {
                return nstl::min(nstl::max(n * m / (2 * MN_MIN_N), dim_t(1)),
                        dim_t(nthr));
            }
        }
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        if (is_f32) {
            static const bool is_avx2 = mayiuse(avx2) && !mayiuse(avx512_core);
            static auto l2_cache_per_thread
                    = platform::get_per_core_cache_size(2);
            static int n_cores_per_socket
                    = static_cast<int>(platform::get_num_cores());
            auto l2_cache_socket = l2_cache_per_thread * n_cores_per_socket;
            auto problem_memory_footprint = m * n * sizeof(float);

            if (is_avx2) {
                // Somehow it seems beneficial to split the job into bigger
                // pieces. Use L2 per-core cache size as a deal-breaker.
                int use_n_threads = utils::div_up(
                        problem_memory_footprint, l2_cache_per_thread);
                return nstl::min(nthr, use_n_threads);
            }
            if (l2_cache_socket > problem_memory_footprint) {
                return nstl::min(nthr, n_cores_per_socket);
            }
        }
#endif

    } else {
        if (trans) {
            if (MIN_WIDTH * nthr > m) nthr = utils::div_up(m, MIN_WIDTH);
        } else {
            if (MIN_WIDTH * nthr > n) nthr = utils::div_up(n, MIN_WIDTH);
        }
    }

    return nthr;
}
#undef M_MIN
#undef N_MIN
#undef BAND_MIN
#undef MN_MIN_N
#undef MN_MIN_T
#undef M_LARGE
#undef N_LARGE
#undef M_SMALL
#undef N_SMALL
#undef CONST1_AVX2
#undef CONST2_AVX2
#undef MIN_WIDTH

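// Partition the range [0, m) across nthr threads and return this thread's
// offset and size. For f32, when a base address is provided, partitions are
// aligned to cache-line boundaries so that threads do not share cache lines.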
template <typename T>
static inline void part_1d(const dim_t m, const int ithr, const int nthr,
        T *addr, dim_t &off, dim_t &size) {
    constexpr bool is_f32
            = utils::one_of(data_traits<T>::data_type, data_type::f32);

    if (ithr >= nthr) {
        size = 0;
        off = 0;
        return;
    }

    if (is_f32) {
        if (addr == nullptr) {
            dim_t xthr = m % nthr;
            dim_t width = m / nthr;

            if (ithr < xthr) {
                size = width + 1;
                off = ithr * size;
            } else {
                size = width;
                off = m - (nthr - ithr) * size;
            }
        } else {
            // Align thread partitions to cache lines to avoid false sharing.
            enum { CACHE_LINE_SIZE = 64 };

            // Find the offset of addr within its cache line (in elements).
            dim_t cache_off = (size_t)addr % CACHE_LINE_SIZE / sizeof(*addr);

            // Find the partition size; it needs to be a multiple of a cache
            // line.
            dim_t align = CACHE_LINE_SIZE / sizeof(*addr);
            dim_t width
                    = utils::rnd_up(utils::div_up(m + cache_off, nthr), align);

            if (width > m + cache_off) width = m + cache_off;

            if (ithr == 0) {
                // First thread is sacrificed to align against cache.
                size = width - cache_off;
                off = 0;
            } else {
                size = width;
                off = ithr * width - cache_off;
            }
        }
    } else {
        size = utils::div_up(m, nthr);
        off = ithr * size;
    }

    if (off > m) off = m;
    if (off + size > m) size = m - off;
}

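// Accumulate the per-thread partial results stored in ybuf back into y.
// Each thread reduces its own partition of the m-sized range.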
template <typename c_t>
void sum_ybufs(
        int ithr, int nthr, dim_t m, c_t *y, dim_t incy, c_t *ybuf, int nbufs) {
    if (incy < 0) y += (-m + 1) * incy;

    dim_t off_m = 0;
    dim_t thread_m = 0;

    // Reduction in each thread.
    part_1d(m, ithr, nthr, (c_t *)nullptr, off_m, thread_m);
    if (incy == 1)
        for (int buf_id = 0; buf_id < nbufs; buf_id++) {
            PRAGMA_OMP_SIMD()
            for (dim_t i = off_m; i < off_m + thread_m; i++)
                y[i] += ybuf[i + buf_id * m];
        }
    else
        for (int buf_id = 0; buf_id < nbufs; buf_id++)
            for (dim_t i = off_m; i < off_m + thread_m; i++)
                y[i * incy] += ybuf[i + buf_id * m];
}

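// Multi-threaded gemv driver: decides how many threads to use, partitions the
// problem along m or n, and calls gemv_kernel_driver on each thread. For the
// cases that need it (bf16, or short-and-fat f32 in the non-transposed path),
// partial results go to per-thread y buffers that are reduced afterwards.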
template <typename a_t, typename b_t, typename c_t>
static inline void gemv_threading_driver(const int trans, const dim_t m,
        const dim_t n, const float alpha, const a_t *a, const dim_t lda,
        const b_t *x, const dim_t incx, const float beta, c_t *y,
        const dim_t incy, const gemm_info_t<a_t, b_t, c_t> *arg) {
    constexpr bool is_f32
            = utils::one_of(data_traits<a_t>::data_type, data_type::f32);
    constexpr bool is_bf16
            = utils::one_of(data_traits<a_t>::data_type, data_type::bf16);

    // Quick return if possible.
    if (m <= 0 || n <= 0) return;

    auto nthr_max = dnnl_get_current_num_threads();
    auto nthr_goal = thread_checker<a_t>(nthr_max, m, n, trans);

    if (nthr_goal == 1) {
        gemv_kernel_driver(
                trans, m, n, alpha, a, lda, x, incx, beta, y, incy, arg);
        return;
    }

    enum { M_MIN = 500, N_MIN = 128 };
    bool is_short_fat = m <= nthr_goal * M_MIN && n >= nthr_goal * N_MIN;

    bool use_y_buf = trans == no_trans && (is_bf16 || (is_f32 && is_short_fat));
    bool is_syncable = dnnl_thr_syncable();

    c_t *ybuf = nullptr;
    if (use_y_buf)
        ybuf = (c_t *)malloc(sizeof(*ybuf) * m * (nthr_goal - 1), PAGE_4K);

    // Always use the maximum number of threads to avoid OMP overhead that can
    // occur due to changing thread counts.
    auto nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal;
    int nbufs_used = 0;
    parallel(nthr_spawn, [&](int ithr, int nthr) {
        int nthr_eff = nstl::min(nthr_goal, nthr);

        dim_t thread_m = m, off_m = 0;
        dim_t thread_n = n, off_n = 0;
        dim_t band = 1;

        // Default effective values.
        auto a_eff = a;
        auto x_eff = x;
        auto y_eff = y;
        auto incy_eff = incy;
        auto beta_eff = beta;

        if (trans == do_trans) {
            part_1d(n, ithr, nthr_eff, (c_t *)nullptr, off_n, thread_n);
            a_eff += off_m + off_n * lda;
            y_eff += off_n * incy;
            if (incy < 0) y_eff += (-n + thread_n) * incy;
            band = thread_n;
        } else if (ybuf) {
            // Non-transpose for short and fat matrix sizes.
            part_1d(n, ithr, nthr_eff, (c_t *)nullptr, off_n, thread_n);
            a_eff += off_m + off_n * lda;
            x_eff += off_n * incx;
            if (incx < 0) x_eff += (-n + thread_n) * incx;
            if (ithr != 0) {
                y_eff = ybuf + m * (ithr - 1);
                incy_eff = 1;
                beta_eff = 0.0;
            } else {
                // Record the number of used buffers to perform the reduction
                // later.
                nbufs_used = nthr_eff - 1;
            }
        } else {
            // Non-transpose for other matrix sizes.
            // Fallback for no_trans with no extra buffer.
            part_1d(m, ithr, nthr_eff, y, off_m, thread_m);
            a_eff += off_m + off_n * lda;
            y_eff += off_m * incy;
            if (incy < 0) y_eff += (-m + thread_m) * incy;
            band = thread_m;
        }

        // Buffers for y need to be set to zero for reduction case.
        assert(IMPLICATION(ybuf, band > 0));

        if (band > 0 && ithr < nthr_eff)
            gemv_kernel_driver(trans, thread_m, thread_n, alpha, a_eff, lda,
                    x_eff, incx, beta_eff, y_eff, incy_eff, arg);

        // Do reduction for multiple buffers if needed.
        if (is_syncable && ybuf) {
            dnnl_thr_barrier();

            sum_ybufs(ithr, nthr_eff, m, y, incy, ybuf, nbufs_used);
        }
    });

    // Reduce on y after each gemv computation is done.
    if (!is_syncable && ybuf) {
        parallel(nthr_spawn, [&](int ithr, int nthr) {
            sum_ybufs(ithr, nthr, m, y, incy, ybuf, nbufs_used);
        });
    }

    free(ybuf);
}

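// Integer gemm variants do not use the gemv driver; always report
// unimplemented.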
template <>
dnnl_status_t jump_to_gemv(const gemm_info_t<int8_t, uint8_t, int32_t> *arg) {
    return dnnl_unimplemented;
}

template <>
dnnl_status_t jump_to_gemv(const gemm_info_t<int8_t, int8_t, int32_t> *arg) {
    return dnnl_unimplemented;
}

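// If the gemm call is really a gemv (m == 1 or n == 1), dispatch it to the
// gemv threading driver (or prepare the pack storage when packing is
// requested) and return dnnl_success; otherwise return dnnl_unimplemented.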
template <typename a_t, typename b_t, typename c_t>
dnnl_status_t jump_to_gemv(const gemm_info_t<a_t, b_t, c_t> *arg) {
    int transa = arg->transa;
    int transb = arg->transb;

    dim_t m = arg->m;
    dim_t n = arg->n;
    dim_t k = arg->k;

    dim_t lda = arg->lda;
    dim_t ldb = arg->ldb;
    dim_t ldc = arg->ldc;

    float alpha = arg->alpha;
    float beta = arg->beta;

    const a_t *a = arg->a;
    const b_t *b = arg->b;
    c_t *c = arg->c;

    if (k == 0) return dnnl_success;

    auto packing = (arg->packing != pack_type::none);
    auto do_a = (arg->packing == pack_type::pack_a);
    gemm_pack_storage_t *pack_dst = arg->pack_dst;

    if (n == 1 && (transa == do_trans || packing)) {
        if (!packing) {
            gemv_threading_driver(do_trans, k, m, alpha, a, lda, b,
                    transb == no_trans ? 1 : ldb, beta, c, 1, arg);
        } else {
            if (do_a) {
                gemm_utils::prep_gemm_pack<a_t, c_t>(
                        do_a, do_trans, m, k, pack_dst);
            } else {
                gemm_utils::prep_gemm_pack<b_t, c_t>(
                        do_a, no_trans, k, n, pack_dst);
            }

            if (arg->measure_only) return dnnl_success;

            if (do_a) {
                gemm_utils::pack_no_copy(a, lda, m, k, transa, alpha, pack_dst);
            } else {
                gemm_utils::pack_no_copy(b, ldb, k, n, transb, alpha, pack_dst);
            }
        }
        return dnnl_success;
    } else if (n == 1 && transa == no_trans && !packing) {
        gemv_threading_driver(no_trans, m, k, alpha, a, lda, b,
                transb == no_trans ? 1 : ldb, beta, c, 1, arg);
        return dnnl_success;
    }

    if (m == 1 && (transb == no_trans || packing)) {
        if (!packing) {
            gemv_threading_driver(do_trans, k, n, alpha, b, ldb, a,
                    transa == no_trans ? lda : 1, beta, c, ldc, arg);
        } else {
            if (do_a) {
                gemm_utils::prep_gemm_pack<a_t, c_t>(
                        do_a, do_trans, m, k, pack_dst);
            } else {
                gemm_utils::prep_gemm_pack<b_t, c_t>(
                        do_a, no_trans, k, n, pack_dst);
            }

            if (arg->measure_only) return dnnl_success;

            if (do_a) {
                gemm_utils::pack_no_copy(a, lda, m, k, transa, alpha, pack_dst);
            } else {
                gemm_utils::pack_no_copy(b, ldb, k, n, transb, alpha, pack_dst);
            }
        }
        return dnnl_success;
    } else if (m == 1 && transb == do_trans && !packing) {
        gemv_threading_driver(no_trans, n, k, alpha, b, ldb, a,
                transa == no_trans ? lda : 1, beta, c, ldc, arg);
        return dnnl_success;
    }

    return dnnl_unimplemented;
}

template // Instantiate gemv_f32
        dnnl_status_t
        jump_to_gemv<float, float, float>(
                const gemm_info_t<float, float, float> *arg);
template // Instantiate gemv_bf16bf16f32
        dnnl_status_t
        jump_to_gemv<bfloat16_t, bfloat16_t, float>(
                const gemm_info_t<bfloat16_t, bfloat16_t, float> *arg);

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl