1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdint> |
18 | #include <memory> |
19 | #include <mutex> |
20 | |
21 | #include "oneapi/dnnl/dnnl_types.h" |
22 | |
23 | #include "common/bfloat16.hpp" |
24 | #include "common/dnnl_traits.hpp" |
25 | |
26 | #include "cpu/x64/cpu_isa_traits.hpp" |
27 | #include "cpu/x64/jit_generator.hpp" |
28 | |
29 | #include "cpu/x64/gemm/gemm_info.hpp" |
30 | |
31 | #include "cpu/x64/gemm/amx/jit_avx512_core_amx_copy_kern.hpp" |
32 | #include "cpu/x64/gemm/amx/jit_avx512_core_amx_gemm_kern.hpp" |
33 | |
34 | #include "cpu/x64/gemm/bf16/common_s16.hpp" |
35 | #include "cpu/x64/gemm/bf16/jit_avx512_core_gemm_bf16bf16f32_kern.hpp" |
36 | #include "cpu/x64/gemm/bf16/jit_avx512_core_gemv_bf16bf16f32_kern.hpp" |
37 | |
38 | #include "cpu/x64/gemm/f32/common_f32.hpp" |
39 | #include "cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.hpp" |
40 | #include "cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp" |
41 | #include "cpu/x64/gemm/f32/jit_sse41_gemv_n_f32_kern.hpp" |
42 | #include "cpu/x64/gemm/f32/jit_sse41_gemv_t_f32_kern.hpp" |
43 | |
44 | #include "cpu/x64/gemm/s8x8s32/common_u8.hpp" |
45 | #include "cpu/x64/gemm/s8x8s32/jit_avx2_gemm_s8u8s32_kern.hpp" |
46 | #include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp" |
47 | #include "cpu/x64/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8x8s32_kern.hpp" |
48 | |
49 | namespace dnnl { |
50 | namespace impl { |
51 | namespace cpu { |
52 | namespace x64 { |
53 | |
54 | static inline int decode_trans(char trans) { |
55 | switch (trans) { |
56 | case 'T': |
57 | case 't': return do_trans; |
58 | case 'P': |
59 | case 'p': return packed; |
60 | default: return no_trans; |
61 | } |
62 | } |
63 | |
64 | namespace { |
65 | template <typename b_t> // XXX for float and bfloat |
66 | void prepare_bo(int32_t &bo_gemm_info, const b_t *bo_orig) { |
67 | UNUSED(bo_orig); |
68 | bo_gemm_info = 0; |
69 | } |
70 | template <> |
71 | void prepare_bo(int32_t &bo_gemm_info, const uint8_t *bo_orig) { |
72 | bo_gemm_info = bo_orig ? *bo_orig : 0; |
73 | } |
74 | template <> |
75 | void prepare_bo(int32_t &bo_gemm_info, const int8_t *bo_orig) { |
76 | int bo_s32 = bo_orig ? *bo_orig : 0; |
77 | if (!mayiuse(avx512_core_amx)) bo_s32 += 128; |
78 | bo_gemm_info = bo_s32; |
79 | } |
80 | |
81 | } // namespace |
82 | |
83 | template <typename a_t, typename b_t, typename c_t> |
84 | gemm_info_t<a_t, b_t, c_t>::gemm_info_t(const char *transA, const char *transB, |
85 | const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k, |
86 | const float *alpha, const a_t *a, const dim_t *lda, const a_t *oa, |
87 | const b_t *b, const dim_t *ldb, const b_t *ob, const float *beta, |
88 | c_t *c, const dim_t *ldc, const c_t *oc, bool force_nocopy, |
89 | pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) { |
90 | |
91 | this->transa = decode_trans(*transA); |
92 | this->transb = decode_trans(*transB); |
93 | |
94 | this->m = *m; |
95 | this->n = *n; |
96 | this->k = *k; |
97 | |
98 | this->a = a; |
99 | this->b = b; |
100 | this->c = c; |
101 | |
102 | this->lda = lda ? *lda : 0; |
103 | this->ldb = ldb ? *ldb : 0; |
104 | this->ldc = ldc ? *ldc : 0; |
105 | |
106 | this->ao = 0; |
107 | this->bo = 0; |
108 | this->co = nullptr; |
109 | |
110 | this->alpha = alpha ? *alpha : 1.0f; |
111 | this->beta = beta ? *beta : 1.0f; |
112 | |
113 | this->offsetc = offset_type::none; |
114 | |
115 | this->packing = packing; |
116 | this->pack_dst = pack_dst; |
117 | this->measure_only |
118 | = measure_only && pack_dst && (packing != pack_type::none); |
119 | |
120 | if (this->transa == packed) { |
121 | dim_t cols; |
122 | |
123 | this->a_packed.reset(new gemm_pack_storage_t(a)); |
124 | if (this->a_packed->get_nocopy(this->transa, this->lda, cols)) { |
125 | this->a = this->a_packed->template matrix<a_t>(); |
126 | this->a_packed = nullptr; |
127 | } |
128 | } |
129 | if (this->transb == packed) { |
130 | dim_t rows; |
131 | |
132 | this->b_packed.reset(new gemm_pack_storage_t(b)); |
133 | if (this->b_packed->get_nocopy(this->transb, this->ldb, rows)) { |
134 | this->b = this->b_packed->template matrix<b_t>(); |
135 | this->b_packed = nullptr; |
136 | } |
137 | } |
138 | |
139 | constexpr bool is_int8 = utils::one_of( |
140 | data_traits<a_t>::data_type, data_type::s8, data_type::u8); |
141 | if (is_int8) this->ao = oa ? *oa : a_t(0); |
142 | prepare_bo<b_t>(this->bo, ob); |
143 | |
144 | if (offsetC != nullptr) { |
145 | char offsetc = *offsetC; |
146 | if (offsetc == 'F' || offsetc == 'f') { |
147 | this->offsetc = offset_type::fixed; |
148 | } else if (offsetc == 'R' || offsetc == 'r') { |
149 | this->offsetc = offset_type::row; |
150 | } else { // offsetc == 'C' || offsetc == 'c' |
151 | this->offsetc = offset_type::column; |
152 | } |
153 | this->co = oc; |
154 | } |
155 | |
156 | bool is_sgemm = data_traits<a_t>::data_type == data_type::f32; |
157 | bool is_gemv = this->m == 1 || this->n == 1; |
158 | |
159 | // Copy-based sgemm doesn't support force-nocopy for ISAs older |
160 | // than Intel AVX. |
161 | this->force_nocopy = is_sgemm && force_nocopy && mayiuse(avx); |
162 | |
163 | if (!this->force_nocopy || is_gemv) { this->jit_init(); } |
164 | } |
165 | |
166 | static std::mutex kern_mutex; |
167 | |
168 | // copyA[trans][sum] |
169 | template <typename a_t, typename b_t, typename c_t> |
170 | typename gemm_info_t<a_t, b_t, c_t>::copy_a_fptr_t |
171 | gemm_info_t<a_t, b_t, c_t>::copy_a_kern[2][2] |
172 | = {{nullptr}}; |
173 | |
174 | // copyB[trans][sum] |
175 | template <typename a_t, typename b_t, typename c_t> |
176 | typename gemm_info_t<a_t, b_t, c_t>::copy_b_fptr_t |
177 | gemm_info_t<a_t, b_t, c_t>::copy_b_kern[2][2] |
178 | = {{nullptr}}; |
179 | |
180 | // kern[beta0][alpha1][col_off][row_off] |
181 | template <typename a_t, typename b_t, typename c_t> |
182 | typename gemm_info_t<a_t, b_t, c_t>::gemm_fptr_t |
183 | gemm_info_t<a_t, b_t, c_t>::kern[2][2][2][2] |
184 | = {{{{nullptr}}}}; |
185 | |
186 | // gemv_kern[trans] |
187 | template <typename a_t, typename b_t, typename c_t> |
188 | typename gemm_info_t<a_t, b_t, c_t>::gemv_fptr_t |
189 | gemm_info_t<a_t, b_t, c_t>::gemv_kern[2] |
190 | = {nullptr}; |
191 | |
192 | template <typename a_t, typename b_t, typename c_t> |
193 | typename gemm_info_t<a_t, b_t, c_t>::gemv_s8s8s32_fptr_t |
194 | gemm_info_t<a_t, b_t, c_t>::gemv_s8s8s32_kern |
195 | = nullptr; |
196 | template <typename a_t, typename b_t, typename c_t> |
197 | typename gemm_info_t<a_t, b_t, c_t>::gemv_s8u8s32_fptr_t |
198 | gemm_info_t<a_t, b_t, c_t>::gemv_s8u8s32_kern |
199 | = nullptr; |
200 | template <typename a_t, typename b_t, typename c_t> |
201 | typename gemm_info_t<a_t, b_t, c_t>::gemv_u8s8s32_fptr_t |
202 | gemm_info_t<a_t, b_t, c_t>::gemv_u8s8s32_kern |
203 | = nullptr; |
204 | |
205 | template <typename a_t, typename b_t, typename c_t> |
206 | void gemm_info_t<a_t, b_t, c_t>::jit_init(void) { |
207 | |
208 | bool use_bf16_ymm = false; |
209 | // TODO: Add dispatching for 1-fma SKUs with support to bf16 |
210 | // instructions for AMX kernel. |
211 | { |
212 | constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16; |
213 | const bool max_isa_supports_bf16_ymm |
214 | = mayiuse(avx512_core_bf16_ymm) && !mayiuse(avx512_core_amx); |
215 | |
216 | use_bf16_ymm = is_bf16 && max_isa_supports_bf16_ymm; |
217 | } |
218 | |
219 | switch (data_traits<a_t>::data_type) { |
220 | case data_type::s8: |
221 | if (mayiuse(avx512_core_amx)) { |
222 | this->um = 32; |
223 | this->un = 32; |
224 | this->uk = 64; |
225 | this->bm = 9984; |
226 | this->bn = 384; |
227 | this->bk = 1536; |
228 | |
229 | this->bk_traditional = 0; |
230 | this->blocking_small_k = 0; |
231 | this->bn_small_k = 0; |
232 | } else if (mayiuse(avx512_core)) { |
233 | this->um = 48; |
234 | this->un = 8; |
235 | this->uk = 1; |
236 | this->bm = 9984; |
237 | this->bn = 384; |
238 | this->bk = mayiuse(avx512_core_vnni) ? 1536 : 768; |
239 | |
240 | this->bk_traditional = 384; |
241 | this->blocking_small_k = 48; |
242 | this->bn_small_k = 24; |
243 | } else if (mayiuse(avx2)) { |
244 | this->um = mayiuse(avx2_vnni) ? 24 : 16; |
245 | this->un = 4; |
246 | this->uk = 1; |
247 | this->bm = 9984; |
248 | this->bn = mayiuse(avx2_vnni) ? 192 : 384; |
249 | this->bk = mayiuse(avx2_vnni) ? 768 : 384; |
250 | |
251 | this->bk_traditional = 256; |
252 | this->blocking_small_k = 48; |
253 | this->bn_small_k = 24; |
254 | } else if (mayiuse(avx)) { |
255 | this->um = 16; |
256 | this->un = 2; |
257 | this->uk = 1; |
258 | this->bm = 4096; |
259 | this->bn = 256; |
260 | this->bk = 256; |
261 | |
262 | this->bk_traditional = 256; |
263 | this->blocking_small_k = 48; |
264 | this->bn_small_k = 24; |
265 | } else if (mayiuse(sse41)) { |
266 | this->um = 16; |
267 | this->un = 2; |
268 | this->uk = 1; |
269 | this->bm = 4096; |
270 | this->bn = 256; |
271 | this->bk = 256; |
272 | |
273 | this->bk_traditional = 256; |
274 | this->blocking_small_k = 48; |
275 | this->bn_small_k = 24; |
276 | } |
277 | break; |
278 | |
279 | case data_type::bf16: |
280 | if (mayiuse(avx512_core_amx)) { |
281 | this->um = 32; |
282 | this->un = 32; |
283 | this->uk = 32; |
284 | this->bm = 9984; |
285 | this->bn = 384; |
286 | this->bk = 768; |
287 | |
288 | this->bk_traditional = 0; |
289 | this->blocking_small_k = 0; |
290 | this->bn_small_k = 0; |
291 | } else if (mayiuse(avx512_core)) { |
292 | this->um = use_bf16_ymm ? 24 : 48; |
293 | this->un = 8; |
294 | this->uk = 1; |
295 | this->bm = 9984; |
296 | this->bn = 384; |
297 | this->bk = use_bf16_ymm ? 384 : 768; |
298 | |
299 | this->bk_traditional = 384; |
300 | this->blocking_small_k = 48; |
301 | this->bn_small_k = 24; |
302 | } |
303 | break; |
304 | |
305 | case data_type::f32: |
306 | if (mayiuse(avx512_core)) { |
307 | this->um = 48; |
308 | this->un = 8; |
309 | this->uk = 1; |
310 | this->bm = 9984; |
311 | this->bn = 384; |
312 | this->bk = 384; |
313 | |
314 | this->bk_traditional = 384; |
315 | this->blocking_small_k = 48; |
316 | this->bn_small_k = 24; |
317 | } else if (mayiuse(avx2)) { |
318 | this->um = 24; |
319 | this->un = 4; |
320 | this->uk = 1; |
321 | this->bm = 10000; |
322 | this->bn = 384; |
323 | this->bk = 192; |
324 | |
325 | this->bk_traditional = 256; |
326 | this->blocking_small_k = 48; |
327 | this->bn_small_k = 24; |
328 | } else if (mayiuse(avx)) { |
329 | this->um = 16; |
330 | this->un = 4; |
331 | this->uk = 1; |
332 | this->bm = 4096; |
333 | this->bn = 96; |
334 | this->bk = 256; |
335 | |
336 | this->bk_traditional = 256; |
337 | this->blocking_small_k = 48; |
338 | this->bn_small_k = 24; |
339 | } else if (mayiuse(sse41)) { |
340 | this->um = 8; |
341 | this->un = 4; |
342 | this->uk = 1; |
343 | this->bm = 4096; |
344 | this->bn = 96; |
345 | this->bk = 256; |
346 | |
347 | this->bk_traditional = 256; |
348 | this->blocking_small_k = 48; |
349 | this->bn_small_k = 24; |
350 | } |
351 | break; |
352 | default: assert(!"unsupported data type!" ); |
353 | } |
354 | |
355 | // Note: um is fixed for a given set of data types and ISA. |
356 | const int um = this->um; |
357 | |
358 | static std::once_flag initialized; |
359 | static std::atomic<dnnl_status_t> st(dnnl_success); |
360 | std::call_once(initialized, [&, um] { |
361 | const bool b_is_s8 = data_traits<b_t>::data_type == data_type::s8; |
362 | constexpr bool is_int8 = utils::one_of( |
363 | data_traits<a_t>::data_type, data_type::s8, data_type::u8); |
364 | constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16; |
365 | bool is_int8_amx = is_int8 && mayiuse(avx512_core_amx); |
366 | bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_amx); |
367 | bool is_amx = is_int8_amx || is_bf16_amx; |
368 | |
369 | static maybe_unique_ptr<jit_generator> copy_a[2][2] = {{nullptr}}; |
370 | static maybe_unique_ptr<jit_generator> copy_b[2][2] = {{nullptr}}; |
371 | |
372 | switch (data_traits<a_t>::data_type) { |
373 | case data_type::s8: |
374 | if (mayiuse(amx_int8)) { |
375 | for (int isTrans : {no_trans, do_trans}) { |
376 | copy_a[isTrans][no_sum].reset( |
377 | new jit_avx512_core_amx_copy_kern( |
378 | true, !isTrans, sizeof(a_t))); |
379 | |
380 | copy_b[isTrans][no_sum].reset( |
381 | new jit_avx512_core_amx_copy_kern( |
382 | false, isTrans, sizeof(b_t))); |
383 | } |
384 | } else if (mayiuse(avx512_core)) { |
385 | copy_a[no_trans][no_sum].reset( |
386 | new jit_avx512_core_u8_copy_an_kern()); |
387 | copy_a[do_trans][no_sum].reset( |
388 | new jit_avx512_core_u8_copy_at_kern()); |
389 | |
390 | copy_b[no_trans][no_sum].reset( |
391 | new jit_avx512_core_u8_copy_bn_kern(b_is_s8)); |
392 | copy_b[do_trans][no_sum].reset( |
393 | new jit_avx512_core_u8_copy_bt_kern(b_is_s8)); |
394 | |
395 | copy_a[no_trans][do_sum].reset( |
396 | new jit_avx512_core_u8_copy_sum_an_kern()); |
397 | copy_a[do_trans][do_sum].reset( |
398 | new jit_avx512_core_u8_copy_sum_at_kern()); |
399 | |
400 | copy_b[no_trans][do_sum].reset( |
401 | new jit_avx512_core_u8_copy_sum_bn_kern(b_is_s8)); |
402 | copy_b[do_trans][do_sum].reset( |
403 | new jit_avx512_core_u8_copy_sum_bt_kern(b_is_s8)); |
404 | } else if (mayiuse(avx2_vnni)) { |
405 | copy_a[no_trans][no_sum].reset( |
406 | new jit_avx2_vnni_u8_copy_an_kern()); |
407 | copy_a[do_trans][no_sum].reset( |
408 | new jit_avx2_vnni_u8_copy_at_kern()); |
409 | |
410 | copy_b[no_trans][no_sum].reset( |
411 | new jit_avx2_vnni_u8_copy_bn_kern()); |
412 | copy_b[do_trans][no_sum].reset( |
413 | new jit_avx2_vnni_u8_copy_bt_kern()); |
414 | |
415 | copy_a[no_trans][do_sum].reset( |
416 | new jit_avx2_vnni_u8_copy_sum_an_kern()); |
417 | copy_a[do_trans][do_sum].reset( |
418 | new jit_avx2_vnni_u8_copy_sum_at_kern()); |
419 | |
420 | copy_b[no_trans][do_sum].reset( |
421 | new jit_avx2_vnni_u8_copy_sum_bn_kern()); |
422 | copy_b[do_trans][do_sum].reset( |
423 | new jit_avx2_vnni_u8_copy_sum_bt_kern()); |
424 | } else if (mayiuse(avx2)) { |
425 | copy_a[no_trans][no_sum].reset( |
426 | new jit_avx2_u8_copy_an_kern()); |
427 | copy_a[do_trans][no_sum].reset( |
428 | new jit_avx2_u8_copy_at_kern()); |
429 | |
430 | copy_b[no_trans][no_sum].reset( |
431 | new jit_avx2_u8_copy_bn_kern()); |
432 | copy_b[do_trans][no_sum].reset( |
433 | new jit_avx2_u8_copy_bt_kern()); |
434 | |
435 | copy_a[no_trans][do_sum].reset( |
436 | new jit_avx2_u8_copy_sum_an_kern()); |
437 | copy_a[do_trans][do_sum].reset( |
438 | new jit_avx2_u8_copy_sum_at_kern()); |
439 | |
440 | copy_b[no_trans][do_sum].reset( |
441 | new jit_avx2_u8_copy_sum_bn_kern()); |
442 | copy_b[do_trans][do_sum].reset( |
443 | new jit_avx2_u8_copy_sum_bt_kern()); |
444 | } else if (mayiuse(avx)) { |
445 | copy_a[no_trans][no_sum].reset( |
446 | new jit_avx_u8_copy_an_kern()); |
447 | copy_a[do_trans][no_sum].reset( |
448 | new jit_avx_u8_copy_at_kern()); |
449 | |
450 | copy_b[no_trans][no_sum].reset( |
451 | new jit_avx_u8_copy_bn_kern()); |
452 | copy_b[do_trans][no_sum].reset( |
453 | new jit_avx_u8_copy_bt_kern()); |
454 | |
455 | copy_a[no_trans][do_sum].reset( |
456 | new jit_avx_u8_copy_sum_an_kern()); |
457 | copy_a[do_trans][do_sum].reset( |
458 | new jit_avx_u8_copy_sum_at_kern()); |
459 | |
460 | copy_b[no_trans][do_sum].reset( |
461 | new jit_avx_u8_copy_sum_bn_kern()); |
462 | copy_b[do_trans][do_sum].reset( |
463 | new jit_avx_u8_copy_sum_bt_kern()); |
464 | } else if (mayiuse(sse41)) { |
465 | copy_a[no_trans][no_sum].reset( |
466 | new jit_sse41_u8_copy_an_kern()); |
467 | copy_a[do_trans][no_sum].reset( |
468 | new jit_sse41_u8_copy_at_kern()); |
469 | |
470 | copy_b[no_trans][no_sum].reset( |
471 | new jit_sse41_u8_copy_bn_kern()); |
472 | copy_b[do_trans][no_sum].reset( |
473 | new jit_sse41_u8_copy_bt_kern()); |
474 | |
475 | copy_a[no_trans][do_sum].reset( |
476 | new jit_sse41_u8_copy_sum_an_kern()); |
477 | copy_a[do_trans][do_sum].reset( |
478 | new jit_sse41_u8_copy_sum_at_kern()); |
479 | |
480 | copy_b[no_trans][do_sum].reset( |
481 | new jit_sse41_u8_copy_sum_bn_kern()); |
482 | copy_b[do_trans][do_sum].reset( |
483 | new jit_sse41_u8_copy_sum_bt_kern()); |
484 | } |
485 | break; |
486 | |
487 | case data_type::bf16: |
488 | if (mayiuse(amx_bf16)) { |
489 | for (int isTrans : {no_trans, do_trans}) { |
490 | copy_a[isTrans][no_sum].reset( |
491 | new jit_avx512_core_amx_copy_kern( |
492 | true, !isTrans, sizeof(a_t))); |
493 | |
494 | copy_b[isTrans][no_sum].reset( |
495 | new jit_avx512_core_amx_copy_kern( |
496 | false, isTrans, sizeof(b_t))); |
497 | } |
498 | } else if (mayiuse(avx512_core) && !use_bf16_ymm) { |
499 | copy_a[no_trans][no_sum].reset( |
500 | new jit_avx512_core_s16_48x8_copy_an_kern()); |
501 | copy_a[do_trans][no_sum].reset( |
502 | new jit_avx512_core_s16_48x8_copy_at_kern()); |
503 | |
504 | copy_b[no_trans][no_sum].reset( |
505 | new jit_avx512_core_s16_48x8_copy_bn_kern()); |
506 | copy_b[do_trans][no_sum].reset( |
507 | new jit_avx512_core_s16_48x8_copy_bt_kern()); |
508 | } else if (mayiuse(avx512_core) && use_bf16_ymm) { |
509 | copy_a[no_trans][no_sum].reset( |
510 | new jit_avx512_core_s16_24x8_copy_an_kern()); |
511 | copy_a[do_trans][no_sum].reset( |
512 | new jit_avx512_core_s16_24x8_copy_at_kern()); |
513 | |
514 | copy_b[no_trans][no_sum].reset( |
515 | new jit_avx512_core_s16_24x8_copy_bn_kern()); |
516 | copy_b[do_trans][no_sum].reset( |
517 | new jit_avx512_core_s16_24x8_copy_bt_kern()); |
518 | } |
519 | break; |
520 | |
521 | case data_type::f32: |
522 | if (mayiuse(avx512_core)) { |
523 | copy_a[no_trans][no_sum].reset( |
524 | new jit_avx512_core_f32_copy_an_kern()); |
525 | copy_a[do_trans][no_sum].reset( |
526 | new jit_avx512_core_f32_copy_at_kern()); |
527 | |
528 | copy_b[no_trans][no_sum].reset( |
529 | new jit_avx512_core_f32_copy_bn_kern()); |
530 | copy_b[do_trans][no_sum].reset( |
531 | new jit_avx512_core_f32_copy_bt_kern()); |
532 | } else if (mayiuse(avx2)) { |
533 | copy_a[no_trans][no_sum].reset( |
534 | new jit_avx2_f32_copy_an_kern()); |
535 | copy_a[do_trans][no_sum].reset( |
536 | new jit_avx2_f32_copy_at_kern()); |
537 | |
538 | copy_b[no_trans][no_sum].reset( |
539 | new jit_avx2_f32_copy_bn_kern()); |
540 | copy_b[do_trans][no_sum].reset( |
541 | new jit_avx2_f32_copy_bt_kern()); |
542 | } else if (mayiuse(avx)) { |
543 | copy_a[no_trans][no_sum].reset( |
544 | new jit_avx_f32_copy_an_kern()); |
545 | copy_a[do_trans][no_sum].reset( |
546 | new jit_avx_f32_copy_at_kern()); |
547 | |
548 | copy_b[no_trans][no_sum].reset( |
549 | new jit_avx_f32_copy_bn_kern()); |
550 | copy_b[do_trans][no_sum].reset( |
551 | new jit_avx_f32_copy_bt_kern()); |
552 | } else if (mayiuse(sse41)) { |
553 | copy_a[no_trans][no_sum].reset( |
554 | new jit_sse41_f32_copy_an_kern()); |
555 | copy_a[do_trans][no_sum].reset( |
556 | new jit_sse41_f32_copy_at_kern()); |
557 | |
558 | copy_b[no_trans][no_sum].reset( |
559 | new jit_sse41_f32_copy_bn_kern()); |
560 | copy_b[do_trans][no_sum].reset( |
561 | new jit_sse41_f32_copy_bt_kern()); |
562 | } |
563 | break; |
564 | |
565 | default: break; |
566 | } |
567 | |
568 | constexpr bool is_a_s8 = data_traits<a_t>::data_type == data_type::s8; |
569 | constexpr bool is_b_s8 = data_traits<b_t>::data_type == data_type::s8; |
570 | constexpr bool is_c_s32 = data_traits<c_t>::data_type == data_type::s32; |
571 | |
572 | static maybe_unique_ptr<jit_generator> kernel[2][2][2][2] |
573 | = {{{{nullptr}}}}; |
574 | switch (data_traits<a_t>::data_type) { |
575 | case data_type::s8: |
576 | if (mayiuse(avx512_core_amx)) { |
577 | for (int isBeta0 : {no_beta0, do_beta0}) { |
578 | kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( |
579 | new jit_avx512_core_amx_gemm_kern( |
580 | is_a_s8, is_b_s8, is_c_s32, isBeta0)); |
581 | } |
582 | } else if (mayiuse(avx512_core)) { |
583 | for (int isBeta0 : {no_beta0, do_beta0}) |
584 | for (int doColSum : {no_sum, do_sum}) |
585 | for (int doRowSum : {no_sum, do_sum}) { |
586 | kernel[isBeta0][do_alpha1][doColSum][doRowSum].reset( |
587 | new jit_avx512_core_gemm_s8u8s32_kern( |
588 | isBeta0, doColSum, doRowSum)); |
589 | } |
590 | } else if (mayiuse(avx2)) { |
591 | for (int isBeta0 : {no_beta0, do_beta0}) |
592 | for (int doColSum : {no_sum, do_sum}) |
593 | for (int doRowSum : {no_sum, do_sum}) { |
594 | kernel[isBeta0][do_alpha1][doColSum][doRowSum] |
595 | .reset(new jit_avx2_gemm_s8u8s32_kern( |
596 | isBeta0, doColSum, doRowSum, |
597 | um)); |
598 | } |
599 | } else if (mayiuse(avx)) { |
600 | kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( |
601 | new jit_avx_kernel_gemm_s8u8s32_kern()); |
602 | kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( |
603 | new jit_avx_kernel_c_gemm_s8u8s32_kern()); |
604 | kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( |
605 | new jit_avx_kernel_r_gemm_s8u8s32_kern()); |
606 | kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( |
607 | new jit_avx_kernel_b_gemm_s8u8s32_kern()); |
608 | |
609 | kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( |
610 | new jit_avx_kernel_b0_gemm_s8u8s32_kern()); |
611 | kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( |
612 | new jit_avx_kernel_b0_c_gemm_s8u8s32_kern()); |
613 | kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( |
614 | new jit_avx_kernel_b0_r_gemm_s8u8s32_kern()); |
615 | kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( |
616 | new jit_avx_kernel_b0_b_gemm_s8u8s32_kern()); |
617 | } else if (mayiuse(sse41)) { |
618 | kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( |
619 | new jit_sse41_kernel_gemm_s8u8s32_kern()); |
620 | kernel[no_beta0][do_alpha1][do_sum][no_sum].reset( |
621 | new jit_sse41_kernel_c_gemm_s8u8s32_kern()); |
622 | kernel[no_beta0][do_alpha1][no_sum][do_sum].reset( |
623 | new jit_sse41_kernel_r_gemm_s8u8s32_kern()); |
624 | kernel[no_beta0][do_alpha1][do_sum][do_sum].reset( |
625 | new jit_sse41_kernel_b_gemm_s8u8s32_kern()); |
626 | |
627 | kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( |
628 | new jit_sse41_kernel_b0_gemm_s8u8s32_kern()); |
629 | kernel[do_beta0][do_alpha1][do_sum][no_sum].reset( |
630 | new jit_sse41_kernel_b0_c_gemm_s8u8s32_kern()); |
631 | kernel[do_beta0][do_alpha1][no_sum][do_sum].reset( |
632 | new jit_sse41_kernel_b0_r_gemm_s8u8s32_kern()); |
633 | kernel[do_beta0][do_alpha1][do_sum][do_sum].reset( |
634 | new jit_sse41_kernel_b0_b_gemm_s8u8s32_kern()); |
635 | } |
636 | break; |
637 | |
638 | case data_type::bf16: |
639 | if (mayiuse(avx512_core_amx)) { |
640 | for (int isBeta0 : {no_beta0, do_beta0}) { |
641 | kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( |
642 | new jit_avx512_core_amx_gemm_kern( |
643 | is_a_s8, is_b_s8, is_c_s32, isBeta0)); |
644 | } |
645 | } else if (mayiuse(avx512_core)) { |
646 | for (int isBeta0 : {no_beta0, do_beta0}) |
647 | for (int isAlpha1 : {no_alpha1, do_alpha1}) { |
648 | kernel[isBeta0][isAlpha1][no_sum][no_sum].reset( |
649 | new jit_avx512_core_gemm_bf16bf16f32_kern( |
650 | isBeta0, isAlpha1, !use_bf16_ymm)); |
651 | } |
652 | } |
653 | break; |
654 | |
655 | case data_type::f32: |
656 | if (mayiuse(avx2)) { |
657 | for (int isBeta0 : {no_beta0, do_beta0}) { |
658 | kernel[isBeta0][do_alpha1][no_sum][no_sum].reset( |
659 | new jit_avx2_kernel_sgemm_kern(isBeta0)); |
660 | } |
661 | } else if (mayiuse(avx)) { |
662 | kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( |
663 | new jit_avx_kernel_sgemm_kern()); |
664 | kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( |
665 | new jit_avx_kernel_b0_sgemm_kern()); |
666 | } else if (mayiuse(sse41)) { |
667 | kernel[no_beta0][do_alpha1][no_sum][no_sum].reset( |
668 | new jit_sse41_kernel_sgemm_kern()); |
669 | kernel[do_beta0][do_alpha1][no_sum][no_sum].reset( |
670 | new jit_sse41_kernel_b0_sgemm_kern()); |
671 | } |
672 | break; |
673 | |
674 | default: break; |
675 | } |
676 | |
677 | static maybe_unique_ptr<jit_generator> gemv_kernel[2] = {nullptr}; |
678 | static maybe_unique_ptr<jit_generator> gemv_s8s8s32_kernel = nullptr; |
679 | static maybe_unique_ptr<jit_generator> gemv_s8u8s32_kernel = nullptr; |
680 | static maybe_unique_ptr<jit_generator> gemv_u8s8s32_kernel = nullptr; |
681 | switch (data_traits<a_t>::data_type) { |
682 | case data_type::s8: |
683 | if (mayiuse(avx512_core)) { |
684 | gemv_s8s8s32_kernel.reset( |
685 | new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8s8)); |
686 | gemv_s8u8s32_kernel.reset( |
687 | new jit_avx512_core_gemv_s8x8s32_kern(ver_t::s8u8)); |
688 | gemv_u8s8s32_kernel.reset( |
689 | new jit_avx512_core_gemv_s8x8s32_kern(ver_t::u8s8)); |
690 | } |
691 | break; |
692 | |
693 | case data_type::bf16: |
694 | if (mayiuse(avx512_core)) { |
695 | for (int isTrans : {no_trans, do_trans}) |
696 | gemv_kernel[isTrans].reset( |
697 | new jit_avx512_core_gemv_bf16bf16f32_kern( |
698 | isTrans)); |
699 | } |
700 | break; |
701 | |
702 | case data_type::f32: |
703 | if (mayiuse(avx)) { |
704 | gemv_kernel[no_trans].reset( |
705 | new jit_sse41_gemv_n_f32_kern()); |
706 | gemv_kernel[do_trans].reset(new jit_avx_gemv_t_f32_kern()); |
707 | } else if (mayiuse(sse41)) { |
708 | gemv_kernel[no_trans].reset( |
709 | new jit_sse41_gemv_n_f32_kern()); |
710 | gemv_kernel[do_trans].reset( |
711 | new jit_sse41_gemv_t_f32_kern()); |
712 | } |
713 | break; |
714 | default: assert(!"unsupported data type!" ); |
715 | } |
716 | |
717 | // Set copy kernels function pointer table |
718 | for (int isTrans : {no_trans, do_trans}) |
719 | for (int isSum : {no_sum, do_sum}) { |
720 | auto *p_copy_a = copy_a[isTrans][isSum].get(); |
721 | if (p_copy_a != nullptr) { |
722 | st = p_copy_a->create_kernel(); |
723 | if (st != dnnl_success) return; |
724 | copy_a_kern[isTrans][isSum] |
725 | = (copy_a_fptr_t)p_copy_a->jit_ker(); |
726 | } |
727 | auto *p_copy_b = copy_b[isTrans][isSum].get(); |
728 | if (p_copy_b != nullptr) { |
729 | st = p_copy_b->create_kernel(); |
730 | if (st != dnnl_success) return; |
731 | copy_b_kern[isTrans][isSum] |
732 | = (copy_b_fptr_t)p_copy_b->jit_ker(); |
733 | } |
734 | } |
735 | |
736 | // AMX copy kernels don't support row/column sum. Use wrappers for now. |
737 | if (is_int8_amx) { |
738 | copy_a_kern[no_trans][do_sum] = ©_a_sum_ref<no_trans>; |
739 | copy_a_kern[do_trans][do_sum] = ©_a_sum_ref<do_trans>; |
740 | copy_b_kern[no_trans][do_sum] = ©_b_sum_ref<no_trans>; |
741 | copy_b_kern[do_trans][do_sum] = ©_b_sum_ref<do_trans>; |
742 | } |
743 | |
744 | // Set compute kernel function pointer table |
745 | for (int isBeta0 : {no_beta0, do_beta0}) |
746 | for (int isAlpha1 : {no_alpha1, do_alpha1}) |
747 | for (int doColSum : {no_sum, do_sum}) |
748 | for (int doRowSum : {no_sum, do_sum}) { |
749 | auto *p_kernel |
750 | = kernel[isBeta0][isAlpha1][doColSum][doRowSum] |
751 | .get(); |
752 | if (p_kernel != nullptr) { |
753 | st = p_kernel->create_kernel(); |
754 | if (st != dnnl_success) return; |
755 | kern[isBeta0][isAlpha1][doColSum][doRowSum] |
756 | = (gemm_fptr_t)p_kernel->jit_ker(); |
757 | } |
758 | } |
759 | // Override compute kernel table with AMX kernels |
760 | if (is_amx) { |
761 | // AMX compute kernels don't support alpha scaling, row-offset or |
762 | // col-offset. |
763 | for (int isBeta0 : {no_beta0, do_beta0}) |
764 | for (int isAlpha1 : {no_alpha1, do_alpha1}) |
765 | for (int doColSum : {no_sum, do_sum}) |
766 | for (int doRowSum : {no_sum, do_sum}) { |
767 | kern[isBeta0][isAlpha1][doColSum][doRowSum] |
768 | = kern[isBeta0][do_alpha1][no_sum][no_sum]; |
769 | } |
770 | } |
771 | |
772 | // Set gemv floating point kernels |
773 | if (utils::one_of(data_traits<a_t>::data_type, data_type::f32, |
774 | data_type::bf16)) { |
775 | for (int isTrans : {no_trans, do_trans}) { |
776 | auto *p_gemv_kernel = gemv_kernel[isTrans].get(); |
777 | if (p_gemv_kernel != nullptr) { |
778 | st = p_gemv_kernel->create_kernel(); |
779 | if (st != dnnl_success) return; |
780 | gemv_kern[isTrans] = (gemv_fptr_t)p_gemv_kernel->jit_ker(); |
781 | } |
782 | } |
783 | } |
784 | |
785 | // Set gemv integer gemm kernels |
786 | if (data_traits<a_t>::data_type == data_type::s8) { |
787 | if (gemv_s8s8s32_kernel != nullptr) { |
788 | auto *kern = gemv_s8s8s32_kernel.get(); |
789 | st = kern->create_kernel(); |
790 | if (st != dnnl_success) return; |
791 | gemv_s8s8s32_kern = (gemv_s8s8s32_fptr_t)kern->jit_ker(); |
792 | } |
793 | |
794 | if (gemv_s8u8s32_kernel != nullptr) { |
795 | auto *kern = gemv_s8u8s32_kernel.get(); |
796 | st = kern->create_kernel(); |
797 | if (st != dnnl_success) return; |
798 | gemv_s8u8s32_kern = (gemv_s8u8s32_fptr_t)kern->jit_ker(); |
799 | } |
800 | |
801 | if (gemv_u8s8s32_kernel != nullptr) { |
802 | auto *kern = gemv_u8s8s32_kernel.get(); |
803 | st = kern->create_kernel(); |
804 | if (st != dnnl_success) return; |
805 | gemv_u8s8s32_kern = (gemv_u8s8s32_fptr_t)kern->jit_ker(); |
806 | } |
807 | } |
808 | }); |
809 | |
810 | if (st != dnnl_success) return; |
811 | |
812 | int doSumA = this->bo != 0 ? do_sum : no_sum; |
813 | int doSumB = this->ao != 0 ? do_sum : no_sum; |
814 | |
815 | int copy_trans_a = (this->transa == do_trans) ? do_trans : no_trans; |
816 | int copy_trans_b = (this->transb == do_trans) ? do_trans : no_trans; |
817 | |
818 | constexpr bool is_bf16 = data_traits<a_t>::data_type == data_type::bf16; |
819 | bool doAlpha1 = this->alpha != 1.0f && is_bf16 ? no_alpha1 : do_alpha1; |
820 | |
821 | { |
822 | // NOTE: This lock may not be needed at all as writes to copy_a_kern |
823 | // (and others) are protected within std::call_once(). The lock is added |
824 | // only to fix warnings reported by clang TSAN about a data race in |
825 | // this code block. |
826 | std::lock_guard<std::mutex> g(kern_mutex); |
827 | this->copyA = copy_a_kern[copy_trans_a][doSumA]; |
828 | this->copyB = copy_b_kern[copy_trans_b][doSumB]; |
829 | for (int isBeta0 : {no_beta0, do_beta0}) |
830 | for (int doColSum : {no_sum, do_sum}) |
831 | for (int doRowSum : {no_sum, do_sum}) |
832 | this->kernel[isBeta0][doColSum][doRowSum] |
833 | = kern[isBeta0][doAlpha1][doColSum][doRowSum]; |
834 | for (int isTrans : {no_trans, do_trans}) |
835 | this->gemv_kernel[isTrans] = gemv_kern[isTrans]; |
836 | } |
837 | |
838 | this->gemv_s8s8s32_kernel = nullptr; |
839 | this->gemv_s8u8s32_kernel = nullptr; |
840 | this->gemv_u8s8s32_kernel = nullptr; |
841 | if (data_traits<a_t>::data_type == data_type::s8) { |
842 | this->gemv_s8s8s32_kernel = gemv_s8s8s32_kern; |
843 | this->gemv_s8u8s32_kernel = gemv_s8u8s32_kern; |
844 | this->gemv_u8s8s32_kernel = gemv_u8s8s32_kern; |
845 | } |
846 | } |
847 | |
848 | // Check if copy algorithm kernels were generated on supported ISAs. |
849 | // Copy algorithm supported for: |
850 | // s8 : Intel AVX512, Intel DL Boost |
851 | // bf16 : Intel AVX512, Intel AVX512 BF16 |
852 | // f32 : Intel SSE4.1, Intel AVX, Intel AVX2, Intel AVX512 |
853 | template <typename a_t, typename b_t, typename c_t> |
854 | bool gemm_info_t<a_t, b_t, c_t>::hasKernels(void) { |
855 | |
856 | switch (data_traits<a_t>::data_type) { |
857 | case data_type::s8: |
858 | if (mayiuse(sse41)) { |
859 | for (int isBeta0 : {no_beta0, do_beta0}) |
860 | for (int doColSum : {no_sum, do_sum}) |
861 | for (int doRowSum : {no_sum, do_sum}) |
862 | if (!this->kernel[isBeta0][doColSum][doRowSum]) |
863 | return false; |
864 | |
865 | if (!this->copyA || !this->copyB) return false; |
866 | |
867 | if (mayiuse(avx512_core)) |
868 | if (!this->gemv_s8u8s32_kernel || !this->gemv_u8s8s32_kernel |
869 | || !this->gemv_s8s8s32_kernel) |
870 | return false; |
871 | } |
872 | break; |
873 | |
874 | case data_type::bf16: |
875 | if (mayiuse(avx512_core)) { |
876 | for (int isBeta0 : {no_beta0, do_beta0}) |
877 | if (!this->kernel[isBeta0][no_sum][no_sum]) return false; |
878 | |
879 | if (!this->copyA || !this->copyB) return false; |
880 | |
881 | for (int isTrans : {no_trans, do_trans}) |
882 | if (!this->gemv_kernel[isTrans]) return false; |
883 | } |
884 | break; |
885 | |
886 | case data_type::f32: |
887 | if (mayiuse(sse41) && !this->force_nocopy) { |
888 | for (int isBeta0 : {no_beta0, do_beta0}) |
889 | if (!this->kernel[isBeta0][no_sum][no_sum]) return false; |
890 | |
891 | if (!this->copyA || !this->copyB) return false; |
892 | |
893 | for (int isTrans : {no_trans, do_trans}) |
894 | if (!this->gemv_kernel[isTrans]) return false; |
895 | } |
896 | break; |
897 | default: assert(!"unsupported data type!" ); |
898 | } |
899 | |
900 | // All kernels necessary have been found or ISA is not supported. |
901 | return true; |
902 | } |
903 | |
904 | // Override default blocking sizes with sizes specified in the gemm_threading_t |
905 | // structure. |
906 | template <typename a_t, typename b_t, typename c_t> |
907 | void gemm_info_t<a_t, b_t, c_t>::update_blocking( |
908 | const gemm_threading_t &thread_info) { |
909 | |
910 | if (thread_info.block_m > 0) this->bm = thread_info.block_m; |
911 | if (thread_info.block_n > 0) this->bn = thread_info.block_n; |
912 | if (thread_info.block_k > 0) this->bk = thread_info.block_k; |
913 | } |
914 | |
915 | // Instantiate the gemm_info_t templates needed. |
916 | template // For gemm_s8u8s32 |
917 | struct gemm_info_t<int8_t, uint8_t, int32_t>; |
918 | |
919 | template // For gemm_s8s8s32 |
920 | struct gemm_info_t<int8_t, int8_t, int32_t>; |
921 | |
922 | template // For gemm_bf16bf16f32 |
923 | struct gemm_info_t<bfloat16_t, bfloat16_t, float>; |
924 | |
925 | template // For sgemm. |
926 | struct gemm_info_t<float, float, float>; |
927 | |
928 | } // namespace x64 |
929 | } // namespace cpu |
930 | } // namespace impl |
931 | } // namespace dnnl |
932 | |