1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdint> |
18 | |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | #include "common/bfloat16.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/nstl.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | |
27 | #include "cpu/x64/cpu_isa_traits.hpp" |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | |
30 | #include "cpu/x64/gemm/gemm_info.hpp" |
31 | #include "cpu/x64/gemm/gemm_utils.hpp" |
32 | #include "cpu/x64/gemm/gemv_driver.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
// gemv kernel for the case where A is non-transposed, incy == 1, and x may have any stride.
40 | template <typename a_t, typename b_t, typename c_t> |
41 | static inline void gemv_n_kernel(const dim_t m, const dim_t n, float alpha, |
42 | const a_t *__restrict a, const dim_t lda, const b_t *__restrict x, |
43 | const dim_t incx, c_t *__restrict y, const dim_t incy, |
44 | const gemm_info_t<a_t, b_t, c_t> *arg) { |
45 | assert(incy == 1); |
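    // If a JIT kernel is available, use it; otherwise fall back to the
    // reference loops below, which compute y := y + alpha * A * x one column
    // at a time, updating y[0..m) with alpha * x[i] * A[:, i] for each
    // column i (A is column-major with leading dimension lda).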
46 | |
47 | auto gemv_n_kern = arg->gemv_kernel[no_trans]; |
48 | if (gemv_n_kern) { |
49 | gemv_n_kern(&m, &n, &alpha, a, &lda, x, &incx, y, &incy); |
50 | } else { |
51 | if (incx == 1) { |
52 | for (dim_t i = 0; i < n; i++) { |
53 | PRAGMA_OMP_SIMD() |
54 | for (dim_t j = 0; j < m; j++) { |
55 | y[j] += alpha * x[i] * a[j + i * lda]; |
56 | } |
57 | } |
58 | } else { |
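            // For negative incx, start from the last stored element of x
            // (standard BLAS convention) and walk backwards.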
59 | dim_t idx = incx < 0 ? (1 - n) * incx : 0; |
60 | for (dim_t i = 0; i < n; i++) { |
61 | PRAGMA_OMP_SIMD() |
62 | for (dim_t j = 0; j < m; j++) { |
63 | y[j] += alpha * x[idx] * a[j + i * lda]; |
64 | } |
65 | idx += incx; |
66 | } |
67 | } |
68 | } |
69 | } |
70 | |
// gemv kernel for the case where A is transposed, incx == 1, and y may have any stride.
72 | template <typename a_t, typename b_t, typename c_t> |
73 | static inline void gemv_t_kernel(const dim_t m, const dim_t n, float alpha, |
74 | const a_t *__restrict a, const dim_t lda, const b_t *__restrict x, |
75 | const dim_t incx, c_t *__restrict y, const dim_t incy, |
76 | const gemm_info_t<a_t, b_t, c_t> *arg) { |
77 | assert(incx == 1); |
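    // If a JIT kernel is available, use it; otherwise fall back to a
    // reference loop that computes y := y + alpha * A^T * x as one dot
    // product per column of A: y[i] += alpha * dot(A[:, i], x).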
78 | |
79 | auto gemv_t_kern = arg->gemv_kernel[do_trans]; |
80 | if (gemv_t_kern) { |
81 | gemv_t_kern(&m, &n, &alpha, a, &lda, x, &incx, y, &incy); |
82 | } else { |
83 | if (incy == 1) { |
84 | for (dim_t i = 0; i < n; i++) { |
85 | c_t temp = (c_t)0; |
86 | for (dim_t j = 0; j < m; j++) { |
87 | temp += x[j] * a[j + i * lda]; |
88 | } |
89 | y[i] += temp * alpha; |
90 | } |
91 | } else { |
92 | dim_t idy = incy < 0 ? (1 - n) * incy : 0; |
93 | for (dim_t i = 0; i < n; i++) { |
94 | c_t temp = (c_t)0; |
95 | for (dim_t j = 0; j < m; j++) { |
96 | temp += x[j] * a[j + i * lda]; |
97 | } |
98 | y[idy] += temp * alpha; |
99 | |
100 | idy += incy; |
101 | } |
102 | } |
103 | } |
104 | } |
105 | |
106 | #define M_BLK 512 |
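// M_BLK is the block size (in elements) used to stage a strided x or y
// vector through a contiguous buffer so the unit-stride kernels above can
// be called on each block.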
107 | template <typename a_t, typename b_t, typename c_t> |
108 | static inline void gemv_kernel_driver(const int trans, const dim_t m, |
109 | const dim_t n, const float alpha, const a_t *a, const dim_t lda, |
110 | const b_t *x, const dim_t incx, const float beta, c_t *y, |
111 | const dim_t incy, const gemm_info_t<a_t, b_t, c_t> *arg) { |
    // Set the dimensions of the x and y vectors based on the transpose type.
113 | dim_t x_dim = trans == no_trans ? n : m; |
114 | dim_t y_dim = trans == no_trans ? m : n; |
115 | |
116 | if (y_dim <= 0) return; |
117 | |
    // Set the starting indices for the x and y vectors based on incx/incy.
119 | dim_t idx_x = incx < 0 ? (1 - x_dim) * incx : 0; |
120 | dim_t idx_y = incy < 0 ? (1 - y_dim) * incy : 0; |
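    // BLAS convention for negative strides: the vector is traversed starting
    // from its last stored element. For example, x_dim = 4 and incx = -2 give
    // idx_x = (1 - 4) * (-2) = 6, so the logical elements x_1..x_4 are read
    // from x[6], x[4], x[2], x[0].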
121 | |
    // Scale the y vector.
123 | if (beta != 1.0f) { |
124 | if (incy == 1) { |
125 | if (beta == 0.0f) { |
126 | PRAGMA_OMP_SIMD() |
127 | for (dim_t i = 0; i < y_dim; i++) { |
128 | y[i] = (c_t)0.0f; |
129 | } |
130 | } else { |
131 | PRAGMA_OMP_SIMD() |
132 | for (dim_t i = 0; i < y_dim; i++) { |
133 | y[i] *= beta; |
134 | } |
135 | } |
136 | } else { |
137 | if (beta == 0.0f) { |
138 | for (dim_t i = 0, inc = idx_y; i < y_dim; i++) { |
139 | y[inc] = (c_t)0.0f; |
140 | inc += incy; |
141 | } |
142 | } else { |
143 | for (dim_t i = 0, inc = idx_y; i < y_dim; i++) { |
144 | y[inc] *= beta; |
145 | inc += incy; |
146 | } |
147 | } |
148 | } |
149 | } |
150 | |
151 | if (x_dim <= 0 || alpha == 0.0f) return; |
152 | |
    if (trans == no_trans) { // A is not transposed.
154 | if (incy == 1) { |
155 | gemv_n_kernel(m, n, alpha, a, lda, x, incx, y, incy, arg); |
156 | } else { |
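            // y has a non-unit stride: process m in blocks of at most M_BLK
            // rows, accumulate each block into the contiguous ytmp buffer
            // via the unit-stride kernel, then scatter the result into y.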
157 | // Allocate temporary buffer for y vector. |
158 | #if !defined(_MSC_VER) |
159 | c_t ytmp[M_BLK]; |
160 | #else |
161 | c_t *ytmp = (c_t *)_alloca(sizeof(*ytmp) * M_BLK); |
162 | #endif |
163 | |
164 | dim_t m_blk = 0; |
165 | for (dim_t i = 0; i < m; i += m_blk) { |
166 | m_blk = m - i; |
167 | if (m_blk > M_BLK) m_blk = M_BLK; |
168 | |
169 | PRAGMA_OMP_SIMD() |
170 | for (dim_t j = 0; j < m_blk; j++) |
171 | ytmp[j] = (c_t)0.0; |
172 | |
173 | // Call unit-stride kernel. |
174 | gemv_n_kernel(m_blk, n, alpha, a, lda, x, incx, ytmp, 1, arg); |
175 | |
176 | // Add matrix-vector result back to y vector. |
177 | for (dim_t j = 0, inc = idx_y; j < m_blk; j++) { |
178 | y[inc] += ytmp[j]; |
179 | inc += incy; |
180 | } |
181 | a += m_blk; |
182 | y += m_blk * incy; |
183 | } |
184 | } |
    } else { // Matrix A is transposed.
186 | if (incx == 1) { |
187 | gemv_t_kernel(m, n, alpha, a, lda, x, incx, y, incy, arg); |
188 | } else { |
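            // x has a non-unit stride: gather blocks of at most M_BLK
            // elements of x into the contiguous xtmp buffer and call the
            // unit-stride kernel block by block.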
189 | // Allocate temporary buffer for x vector. |
190 | #if !defined(_MSC_VER) |
191 | b_t xtmp[M_BLK]; |
192 | #else |
193 | b_t *xtmp = (b_t *)_alloca(sizeof(*xtmp) * M_BLK); |
194 | #endif |
195 | dim_t m_blk = 0; |
196 | for (dim_t i = 0; i < m; i += m_blk) { |
197 | m_blk = m - i; |
198 | if (m_blk > M_BLK) m_blk = M_BLK; |
199 | |
                // Copy a block of the x vector to the temporary buffer.
201 | for (dim_t j = 0, inc = idx_x; j < m_blk; j++) { |
202 | xtmp[j] = x[inc]; |
203 | inc += incx; |
204 | } |
205 | |
206 | // Call unit-stride kernel. |
207 | gemv_t_kernel(m_blk, n, alpha, a, lda, xtmp, 1, y, incy, arg); |
208 | |
209 | a += m_blk; |
210 | x += m_blk * incx; |
211 | } |
212 | } |
213 | } |
214 | } |
215 | #undef M_BLK |
216 | |
217 | #define M_MIN 128 |
218 | #define N_MIN 128 |
219 | #define BAND_MIN 32 |
220 | #define MN_MIN_N 1536 |
221 | #define MN_MIN_T 2048 |
222 | #define M_LARGE 20000 |
223 | #define N_LARGE 20000 |
224 | #define M_SMALL 200 |
225 | #define N_SMALL 200 |
226 | #define CONST1_AVX2 288 |
227 | #define CONST2_AVX2 41700 |
228 | #define MIN_WIDTH 32 |
// Decide how many threads are beneficial; returns 1 when threading is not.
230 | template <typename a_t> |
231 | static inline int thread_checker( |
232 | int nthr, const dim_t m, const dim_t n, int trans) { |
233 | constexpr bool is_f32 |
234 | = utils::one_of(data_traits<a_t>::data_type, data_type::f32); |
235 | |
236 | if (is_f32) { |
        // Thresholds based on performance measurements with warm and cold
        // caches, used to decide when threading is beneficial.
239 | if (mayiuse(avx2)) { |
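            // For example, a 64x64 f32 problem gives
            // 64 * 64 + 288 * 64 = 22528 < 41700 and stays sequential.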
240 | if (m * n + CONST1_AVX2 * n < CONST2_AVX2) { return 1; } |
241 | } else { |
242 | if (m < M_MIN && n < N_MIN) { |
243 | // Execute in sequential mode for small n and m. |
244 | return 1; |
245 | } |
246 | } |
247 | |
248 | if (m >= M_LARGE && n <= N_SMALL) { |
249 | // Execute in parallel mode. |
250 | return nthr; |
251 | } |
252 | |
253 | dim_t bandt = n / nthr; // size per thread. |
254 | |
255 | if (nthr <= 12 && bandt < BAND_MIN) { |
256 | if (m * bandt < MN_MIN_T) { return 1; } |
257 | } else if (nthr <= 12 && m * bandt < 2 * MN_MIN_T) { |
258 | return 1; |
259 | } else if (nthr > 12 && bandt * m < 2 * MN_MIN_T) { |
260 | if (bandt == 0) { |
261 | return 1; |
262 | } else { |
263 | return nstl::min(nstl::max(n * m / (2 * MN_MIN_N), dim_t(1)), |
264 | dim_t(nthr)); |
265 | } |
266 | } |
267 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
268 | if (is_f32) { |
269 | static const bool is_avx2 = mayiuse(avx2) && !mayiuse(avx512_core); |
270 | static auto l2_cache_per_thread |
271 | = platform::get_per_core_cache_size(2); |
272 | static int n_cores_per_socket |
273 | = static_cast<int>(platform::get_num_cores()); |
274 | auto l2_cache_socket = l2_cache_per_thread * n_cores_per_socket; |
275 | auto problem_memory_footprint = m * n * sizeof(float); |
276 | |
277 | if (is_avx2) { |
            // Empirically it is beneficial to split the job into bigger
            // pieces; use the per-core L2 cache size as the cutoff.
280 | int use_n_threads = utils::div_up( |
281 | problem_memory_footprint, l2_cache_per_thread); |
282 | return nstl::min(nthr, use_n_threads); |
283 | } |
284 | if (l2_cache_socket > problem_memory_footprint) { |
285 | return nstl::min(nthr, n_cores_per_socket); |
286 | } |
287 | } |
288 | #endif |
289 | |
290 | } else { |
291 | if (trans) { |
292 | if (MIN_WIDTH * nthr > m) nthr = utils::div_up(m, MIN_WIDTH); |
293 | } else { |
294 | if (MIN_WIDTH * nthr > n) nthr = utils::div_up(n, MIN_WIDTH); |
295 | } |
296 | } |
297 | |
298 | return nthr; |
299 | } |
300 | #undef M_MIN |
301 | #undef N_MIN |
302 | #undef BAND_MIN |
303 | #undef MN_MIN_N |
304 | #undef MN_MIN_T |
305 | #undef M_LARGE |
306 | #undef N_LARGE |
307 | #undef M_SMALL |
308 | #undef N_SMALL |
309 | #undef CONST1_AVX2 |
310 | #undef CONST2_AVX2 |
311 | #undef MIN_WIDTH |
312 | |
313 | template <typename T> |
314 | static inline void part_1d(const dim_t m, const int ithr, const int nthr, |
315 | T *addr, dim_t &off, dim_t &size) { |
316 | constexpr bool is_f32 |
317 | = utils::one_of(data_traits<T>::data_type, data_type::f32); |
318 | |
319 | if (ithr >= nthr) { |
320 | size = 0; |
321 | off = 0; |
322 | return; |
323 | } |
324 | |
325 | if (is_f32) { |
326 | if (addr == nullptr) { |
327 | dim_t xthr = m % nthr; |
328 | dim_t width = m / nthr; |
329 | |
330 | if (ithr < xthr) { |
331 | size = width + 1; |
332 | off = ithr * size; |
333 | } else { |
334 | size = width; |
335 | off = m - (nthr - ithr) * size; |
336 | } |
337 | } else { |
            // Avoid false sharing: align the per-thread partitions to
            // cache-line boundaries so that no two threads write to the
            // same cache line.
339 | enum { CACHE_LINE_SIZE = 64 }; |
340 | |
            // Find the offset of addr within its cache line, in elements.
342 | dim_t cache_off = (size_t)addr % CACHE_LINE_SIZE / sizeof(*addr); |
343 | |
            // Find the partition size; it needs to be a multiple of the
            // cache-line size.
345 | dim_t align = CACHE_LINE_SIZE / sizeof(*addr); |
346 | dim_t width |
347 | = utils::rnd_up(utils::div_up(m + cache_off, nthr), align); |
348 | |
349 | if (width > m + cache_off) width = m + cache_off; |
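
            // Example: for f32 (align = 16 elements), with addr 8 bytes into
            // a cache line (cache_off = 2), m = 100 and nthr = 4:
            // width = rnd_up(div_up(102, 4), 16) = 32, giving partitions
            // [0, 30), [30, 62), [62, 94), [94, 100); every partition after
            // the first starts on a cache-line boundary.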
350 | |
351 | if (ithr == 0) { |
                // The first thread takes a shorter partition so that the
                // remaining partitions start on cache-line boundaries.
353 | size = width - cache_off; |
354 | off = 0; |
355 | } else { |
356 | size = width; |
357 | off = ithr * width - cache_off; |
358 | } |
359 | } |
360 | } else { |
361 | size = utils::div_up(m, nthr); |
362 | off = ithr * size; |
363 | } |
364 | |
365 | if (off > m) off = m; |
366 | if (off + size > m) size = m - off; |
367 | } |
368 | |
369 | template <typename c_t> |
370 | void sum_ybufs( |
371 | int ithr, int nthr, dim_t m, c_t *y, dim_t incy, c_t *ybuf, int nbufs) { |
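    // Accumulate nbufs partial results, stored contiguously in ybuf (buffer
    // b occupies ybuf[b * m .. (b + 1) * m)), back into y. Each thread
    // reduces only its own 1D partition of the m dimension.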
372 | if (incy < 0) y += (-m + 1) * incy; |
373 | |
374 | dim_t off_m = 0; |
375 | dim_t thread_m = 0; |
376 | |
377 | // Reduction in each thread. |
378 | part_1d(m, ithr, nthr, (c_t *)nullptr, off_m, thread_m); |
379 | if (incy == 1) |
380 | for (int buf_id = 0; buf_id < nbufs; buf_id++) { |
381 | PRAGMA_OMP_SIMD() |
382 | for (dim_t i = off_m; i < off_m + thread_m; i++) |
383 | y[i] += ybuf[i + buf_id * m]; |
384 | } |
385 | else |
386 | for (int buf_id = 0; buf_id < nbufs; buf_id++) |
387 | for (dim_t i = off_m; i < off_m + thread_m; i++) |
388 | y[i * incy] += ybuf[i + buf_id * m]; |
389 | } |
390 | |
391 | template <typename a_t, typename b_t, typename c_t> |
392 | static inline void gemv_threading_driver(const int trans, const dim_t m, |
393 | const dim_t n, const float alpha, const a_t *a, const dim_t lda, |
394 | const b_t *x, const dim_t incx, const float beta, c_t *y, |
395 | const dim_t incy, const gemm_info_t<a_t, b_t, c_t> *arg) { |
396 | constexpr bool is_f32 |
397 | = utils::one_of(data_traits<a_t>::data_type, data_type::f32); |
398 | constexpr bool is_bf16 |
399 | = utils::one_of(data_traits<a_t>::data_type, data_type::bf16); |
400 | |
401 | // Quick return if possible. |
402 | if (m <= 0 || n <= 0) return; |
403 | |
404 | auto nthr_max = dnnl_get_current_num_threads(); |
405 | auto nthr_goal = thread_checker<a_t>(nthr_max, m, n, trans); |
406 | |
407 | if (nthr_goal == 1) { |
408 | gemv_kernel_driver( |
409 | trans, m, n, alpha, a, lda, x, incx, beta, y, incy, arg); |
410 | return; |
411 | } |
412 | |
413 | enum { M_MIN = 500, N_MIN = 128 }; |
414 | bool is_short_fat = m <= nthr_goal * M_MIN && n >= nthr_goal * N_MIN; |
415 | |
416 | bool use_y_buf = trans == no_trans && (is_bf16 || (is_f32 && is_short_fat)); |
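    // When use_y_buf is set, the n dimension is partitioned across threads,
    // so every thread except thread 0 produces a full-length partial y
    // vector; ybuf below provides up to nthr_goal - 1 such vectors, which
    // are summed back into y after the compute phase.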
417 | bool is_syncable = dnnl_thr_syncable(); |
418 | |
419 | c_t *ybuf = nullptr; |
420 | if (use_y_buf) |
421 | ybuf = (c_t *)malloc(sizeof(*ybuf) * m * (nthr_goal - 1), PAGE_4K); |
422 | |
    // Always use the maximum number of threads to avoid OMP overhead that can
    // occur due to changing thread counts.
425 | auto nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal; |
426 | int nbufs_used = 0; |
427 | parallel(nthr_spawn, [&](int ithr, int nthr) { |
428 | int nthr_eff = nstl::min(nthr_goal, nthr); |
429 | |
430 | dim_t thread_m = m, off_m = 0; |
431 | dim_t thread_n = n, off_n = 0; |
432 | dim_t band = 1; |
433 | |
434 | // Default effective values. |
435 | auto a_eff = a; |
436 | auto x_eff = x; |
437 | auto y_eff = y; |
438 | auto incy_eff = incy; |
439 | auto beta_eff = beta; |
440 | |
441 | if (trans == do_trans) { |
442 | part_1d(n, ithr, nthr_eff, (c_t *)nullptr, off_n, thread_n); |
443 | a_eff += off_m + off_n * lda; |
444 | y_eff += off_n * incy; |
445 | if (incy < 0) y_eff += (-n + thread_n) * incy; |
446 | band = thread_n; |
447 | } else if (ybuf) { |
            // Non-transpose case for short-and-fat matrix sizes.
449 | part_1d(n, ithr, nthr_eff, (c_t *)nullptr, off_n, thread_n); |
450 | a_eff += off_m + off_n * lda; |
451 | x_eff += off_n * incx; |
452 | if (incx < 0) x_eff += (-n + thread_n) * incx; |
453 | if (ithr != 0) { |
454 | y_eff = ybuf + m * (ithr - 1); |
455 | incy_eff = 1; |
456 | beta_eff = 0.0; |
457 | } else { |
                // Record the number of buffers used so the reduction can be
                // performed later.
459 | nbufs_used = nthr_eff - 1; |
460 | } |
461 | } else { |
            // Non-transpose fallback for other matrix sizes: partition m
            // directly on y, with no extra buffer.
464 | part_1d(m, ithr, nthr_eff, y, off_m, thread_m); |
465 | a_eff += off_m + off_n * lda; |
466 | y_eff += off_m * incy; |
467 | if (incy < 0) y_eff += (-m + thread_m) * incy; |
468 | band = thread_m; |
469 | } |
470 | |
        // The y buffers must be initialized (via beta_eff == 0 in the kernel
        // call) before the reduction; the assert guarantees the kernel runs
        // whenever ybuf is in use.
472 | assert(IMPLICATION(ybuf, band > 0)); |
473 | |
474 | if (band > 0 && ithr < nthr_eff) |
475 | gemv_kernel_driver(trans, thread_m, thread_n, alpha, a_eff, lda, |
476 | x_eff, incx, beta_eff, y_eff, incy_eff, arg); |
477 | |
478 | // Do reduction for multiple buffers if needed. |
479 | if (is_syncable && ybuf) { |
480 | dnnl_thr_barrier(); |
481 | |
482 | sum_ybufs(ithr, nthr_eff, m, y, incy, ybuf, nbufs_used); |
483 | } |
484 | }); |
485 | |
    // Reduce into y after all gemv computations are done, for runtimes
    // without barrier support.
487 | if (!is_syncable && ybuf) { |
488 | parallel(nthr_spawn, [&](int ithr, int nthr) { |
489 | sum_ybufs(ithr, nthr, m, y, incy, ybuf, nbufs_used); |
490 | }); |
491 | } |
492 | |
493 | free(ybuf); |
494 | } |
495 | |
496 | template <> |
497 | dnnl_status_t jump_to_gemv(const gemm_info_t<int8_t, uint8_t, int32_t> *arg) { |
498 | return dnnl_unimplemented; |
499 | } |
500 | |
501 | template <> |
502 | dnnl_status_t jump_to_gemv(const gemm_info_t<int8_t, int8_t, int32_t> *arg) { |
503 | return dnnl_unimplemented; |
504 | } |
505 | |
506 | template <typename a_t, typename b_t, typename c_t> |
507 | dnnl_status_t jump_to_gemv(const gemm_info_t<a_t, b_t, c_t> *arg) { |
508 | int transa = arg->transa; |
509 | int transb = arg->transb; |
510 | |
511 | dim_t m = arg->m; |
512 | dim_t n = arg->n; |
513 | dim_t k = arg->k; |
514 | |
515 | dim_t lda = arg->lda; |
516 | dim_t ldb = arg->ldb; |
517 | dim_t ldc = arg->ldc; |
518 | |
519 | float alpha = arg->alpha; |
520 | float beta = arg->beta; |
521 | |
522 | const a_t *a = arg->a; |
523 | const b_t *b = arg->b; |
524 | c_t *c = arg->c; |
525 | |
526 | if (k == 0) return dnnl_success; |
527 | |
528 | auto packing = (arg->packing != pack_type::none); |
529 | auto do_a = (arg->packing == pack_type::pack_a); |
530 | gemm_pack_storage_t *pack_dst = arg->pack_dst; |
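
    // A gemm with n == 1 or m == 1 degenerates to a gemv: the remaining
    // matrix operand multiplies the single-column (or single-row) operand
    // treated as a vector, and the cases below dispatch to the gemv driver
    // with the transpose flag and leading dimensions rearranged accordingly.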
531 | |
532 | if (n == 1 && (transa == do_trans || packing)) { |
533 | if (!packing) { |
534 | gemv_threading_driver(do_trans, k, m, alpha, a, lda, b, |
535 | transb == no_trans ? 1 : ldb, beta, c, 1, arg); |
536 | } else { |
537 | if (do_a) { |
538 | gemm_utils::prep_gemm_pack<a_t, c_t>( |
539 | do_a, do_trans, m, k, pack_dst); |
540 | } else { |
541 | gemm_utils::prep_gemm_pack<b_t, c_t>( |
542 | do_a, no_trans, k, n, pack_dst); |
543 | } |
544 | |
545 | if (arg->measure_only) return dnnl_success; |
546 | |
547 | if (do_a) { |
548 | gemm_utils::pack_no_copy(a, lda, m, k, transa, alpha, pack_dst); |
549 | } else { |
550 | gemm_utils::pack_no_copy(b, ldb, k, n, transb, alpha, pack_dst); |
551 | } |
552 | } |
553 | return dnnl_success; |
554 | } else if (n == 1 && transa == no_trans && !packing) { |
555 | gemv_threading_driver(no_trans, m, k, alpha, a, lda, b, |
556 | transb == no_trans ? 1 : ldb, beta, c, 1, arg); |
557 | return dnnl_success; |
558 | } |
559 | |
560 | if (m == 1 && (transb == no_trans || packing)) { |
561 | if (!packing) { |
562 | gemv_threading_driver(do_trans, k, n, alpha, b, ldb, a, |
563 | transa == no_trans ? lda : 1, beta, c, ldc, arg); |
564 | } else { |
565 | if (do_a) { |
566 | gemm_utils::prep_gemm_pack<a_t, c_t>( |
567 | do_a, do_trans, m, k, pack_dst); |
568 | } else { |
569 | gemm_utils::prep_gemm_pack<b_t, c_t>( |
570 | do_a, no_trans, k, n, pack_dst); |
571 | } |
572 | |
573 | if (arg->measure_only) return dnnl_success; |
574 | |
575 | if (do_a) { |
576 | gemm_utils::pack_no_copy(a, lda, m, k, transa, alpha, pack_dst); |
577 | } else { |
578 | gemm_utils::pack_no_copy(b, ldb, k, n, transb, alpha, pack_dst); |
579 | } |
580 | } |
581 | return dnnl_success; |
582 | } else if (m == 1 && transb == do_trans && !packing) { |
583 | gemv_threading_driver(no_trans, n, k, alpha, b, ldb, a, |
584 | transa == no_trans ? lda : 1, beta, c, ldc, arg); |
585 | return dnnl_success; |
586 | } |
587 | |
588 | return dnnl_unimplemented; |
589 | } |
590 | |
template // Instantiate gemv_f32
592 | dnnl_status_t |
593 | jump_to_gemv<float, float, float>( |
594 | const gemm_info_t<float, float, float> *arg); |
template // Instantiate gemv_bf16bf16f32
596 | dnnl_status_t |
597 | jump_to_gemv<bfloat16_t, bfloat16_t, float>( |
598 | const gemm_info_t<bfloat16_t, bfloat16_t, float> *arg); |
599 | |
600 | } // namespace x64 |
601 | } // namespace cpu |
602 | } // namespace impl |
603 | } // namespace dnnl |
604 | |