1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/brgemm/brgemm_utils.hpp"
18#include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp"
19
20#include "cpu/x64/cpu_isa_traits.hpp"
21
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33using namespace dnnl::impl::utils;
34
// Identifiers for the AMX blocking decomposition strategies tried by
// brgemm_blocking(). The iteration there starts at decomposition_2x2 and
// increments until 'undefined', so declaration order defines try order.
enum {
    decomposition_2x2 = 101,
    decomposition_3x1_3,
    decomposition_3x1_2,
    undefined,
};
41
42impl::data_type_t get_accum_datatype(brgemm_t *brg) {
43 // this assert should check if 'init_kernel_datatype()' was previously
44 // called.
45 assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
46 return brg->is_int8 ? data_type::s32 : data_type::f32;
47}
48
// Derives the kernel datatype flags (is_int8/is_bf16/is_f32/is_f16) from the
// source (dt_a) and weights (dt_b) datatypes. On exit at least one flag is
// set (checked by the trailing assert).
void init_kernel_datatype(
        brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) {
    assert(dt_a != data_type::undef && dt_b != data_type::undef);
    // int8 kernel: any (u8|s8) x (u8|s8) combination.
    brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
            && utils::one_of(dt_b, data_type::u8, data_type::s8);
    brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
    brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
    // NOTE(review): unlike bf16/f32, f16 is flagged when *either* input is
    // f16 (i.e. mixed f16/f32 configurations) -- presumably intentional to
    // cover pre-converted f16 paths; confirm against callers.
    brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
    assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
}
59
60void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
61 float beta, const brgemm_strides_t *strides) {
62 brg->beta = beta;
63 brg->alpha = alpha;
64 brg->type = type;
65 brg->with_bias = false;
66 brg->with_eltwise = false;
67 brg->with_sum = false;
68 brg->sum_scale = 0;
69 brg->sum_zp = 0;
70 brg->with_scales = false;
71
72 if (strides != nullptr) {
73 brg->stride_a = strides->stride_a;
74 brg->stride_b = strides->stride_b;
75 } else {
76 brg->stride_a = brg->stride_b = 0;
77 }
78}
79
80namespace brgemm_utils {
81
82bool can_dispatch_uker(const brgemm_t *brg) {
83 return brg->is_tmm && brg->type == brgemm_addr && brg->brgattr.use_uker
84 && !brg->brgattr.generate_skip_accumulation;
85}
86
87void maybe_try_bf32(brgemm_t *brg) {
88 const bool try_bf32 = brg->is_f32
89 && brg->brgattr.fpmath_mode == fpmath_mode::bf16
90 && utils::one_of(brg->isa_user, isa_undef, avx512_core_amx)
91 && mayiuse(avx512_core_amx);
92 if (try_bf32) {
93 const bool is_tmm = brg->is_tmm;
94 brg->is_tmm = true;
95 if (can_dispatch_uker(brg) /*Requires is_amx to be true*/) {
96 brg->is_bf32 = true;
97 } else {
98 brg->is_bf32 = false;
99 // Restore
100 brg->is_tmm = is_tmm;
101 }
102 }
103}
104
105void set_isa_impl(brgemm_t *brg) {
106 auto is_isa_ok = [&](cpu_isa_t isa) {
107 return mayiuse(isa) &&
108 // maybe IMPLICATION(brg->isa_user != isa_undef,
109 // is_superset(brg->isa_user, isa)), but the API is not clear.
110 one_of(brg->isa_user, isa_undef, isa);
111 };
112
113 if (brg->is_bf32) {
114 brg->isa_impl = avx512_core_amx;
115 } else if (brg->is_f32) {
116 brg->isa_impl = utils::map(true, isa_undef,
117 is_isa_ok(avx512_core) || is_isa_ok(avx512_core_amx) /*bf32*/,
118 avx512_core, is_isa_ok(avx2), avx2,
119 // Allow avx512_core_fp16 isa in case of a f16 primitive that
120 // is implemented using pre-conversion of inputs to f32.
121 // This is needed to support f16 binary post-ops.
122 is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2),
123 avx2);
124 } else if (brg->is_bf16) {
125 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
126 avx512_core_amx, is_isa_ok(avx512_core_bf16), avx512_core_bf16,
127 is_isa_ok(avx2_vnni_2), avx2_vnni_2);
128 } else if (brg->is_f16) {
129 brg->isa_impl
130 = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx_fp16),
131 avx512_core_amx_fp16, is_isa_ok(avx512_core_fp16),
132 avx512_core_fp16, is_isa_ok(avx2_vnni_2), avx2_vnni_2);
133 } else if (brg->is_int8) {
134 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
135 avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
136 is_isa_ok(avx2_vnni), avx2_vnni);
137 }
138}
139
140void set_brg_vmm(brgemm_t *brg) {
141 brg->is_tmm = brg->is_int8_tmm || brg->is_bf16_tmm || brg->is_f16_tmm
142 || brg->is_bf32;
143 brg->is_zmm = !brg->is_tmm && mayiuse(avx512_core)
144 && is_superset(brg->isa_impl, avx512_core);
145 brg->is_ymm
146 = !brg->is_zmm && mayiuse(avx2) && is_superset(brg->isa_impl, avx2);
147}
148
// Computes the blocking parameters (bd/ld/rd block sizes, counts and tails)
// for a brgemm kernel, dispatching between the vector-register (zmm/ymm)
// path and the AMX (tmm) path. Returns status::unimplemented when no isa or
// blocking configuration fits.
status_t brgemm_blocking(brgemm_t *brg) {

    set_isa_impl(brg);
    if (brg->isa_impl == isa_undef) return status::unimplemented;
    set_brg_vmm(brg);
    // At least one register-kind flag must be set to proceed.
    if (!(brg->is_tmm || brg->is_zmm || brg->is_ymm))
        return status::unimplemented;

    if (!brg->is_tmm) {
        // --- Vector-register (zmm/ymm) blocking ---
        // One load-dim block spans one full simd register.
        const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8;
        brg->ld_block = simd_w;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        brg->ld_block2 = 4; // (M < 9) ? 2 : 4 | TODO - fix this for INT8
        brg->ldb2 = brg->ldb / brg->ld_block2;
        brg->ldb2_tail = brg->ldb % brg->ld_block2;

        // Shrink ld_block2 to the tail when there is no full ld2 block.
        if (brg->ldb2 == 0) brg->ld_block2 = nstl::max(1, brg->ldb2_tail);
        brg->embd_bcst = brg->is_f32
                && (brg->ldb2_tail <= 1 && brg->ldb2 == 0)
                /*only avx512 or more can bcast*/
                && is_superset(brg->isa_impl, avx512_core);

        const auto ld_block
                = (brg->ldb2 != 0) ? brg->ld_block2 : brg->ldb2_tail;
        const auto adj_ld_block = (ld_block == 0) ? (ld_block + 1) : ld_block;

        // Register budget for the accumulator block: total isa registers
        // minus the load and broadcast registers, further reduced below for
        // s8s8/zero-point compensation, padded zp computation, non-trivial
        // beta and bf16 emulation.
        const int max_isa_regs
                = is_superset(brg->isa_impl, avx512_core) ? 32 : 16;
        const int max_bcst_regs = 1;
        const bool req_compensation = brg->req_s8s8_compensation
                || brg->zp_type_a != brgemm_broadcast_t::none;
        const bool req_zp_a_comp_pads
                = (brg->req_cal_comp_pads || brg->brgattr.max_top_vpad > 0
                          || brg->brgattr.max_bottom_vpad > 0)
                && brg->zp_type_a != brgemm_broadcast_t::none;
        const auto max_regs = max_isa_regs - (adj_ld_block + max_bcst_regs);
        auto max_block
                = (brg->embd_bcst ? max_regs - 4
                                  : ((brg->beta == 1.f || brg->beta == 0.f)
                                                  ? max_regs
                                                  : max_regs - 1));
        max_block -= req_compensation;
        max_block -= req_zp_a_comp_pads;
        if (!is_superset(brg->isa_impl, avx512_core) && brg->ldb_tail > 0
                && brg->beta != 0.f)
            --max_block;
        if (req_zp_a_comp_pads) max_block = nstl::min(max_block, max_regs - 5);
        if (brg->is_bf16_emu)
            max_block
                    = nstl::min(max_block, 28); // bf16_emu only for avx512_core
        max_block /= adj_ld_block;
        // Search for the bd_block that maximizes a heuristic efficiency:
        // the product of how evenly bd_block divides bcast_dim and a
        // microkernel register-utilization term, subject to the A-block
        // footprint fitting the per-core L1 cache.
        const int min_block = 1;
        float best_bd_block_eff = 0.f;
        brg->bd_block = 1;
        for (int bd_block = max_block; bd_block >= min_block; bd_block--) {
            const auto bd_block_disb = static_cast<float>(brg->bcast_dim)
                    / rnd_up(brg->bcast_dim, bd_block);
            const auto brgemm_microkernel_eff
                    = (static_cast<float>(adj_ld_block) * bd_block)
                    / (((adj_ld_block) + bd_block) * max_block);
            const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff;

            float block_foot_print = static_cast<float>(brg->typesize_A)
                    * (bd_block * brg->reduce_dim);
            if (block_foot_print <= static_cast<float>(
                        platform::get_per_core_cache_size(1))
                    && (bd_block_eff > best_bd_block_eff)) {
                brg->bd_block = bd_block;
                best_bd_block_eff = bd_block_eff;
            }
        }
        brg->bdb = brg->bcast_dim / brg->bd_block;
        brg->bdb_tail = brg->bcast_dim % brg->bd_block;

        // Reduce-dim blocking: unroll by 4 in units of the vnni granularity.
        // f16 on avx512_core_fp16 is computed without vnni packing, hence
        // granularity 1 there.
        const int rd_unroll = 4;
        const int vnni_granularity
                = (brg->is_f16 && brg->isa_impl == avx512_core_fp16)
                ? 1
                : data_type_vnni_granularity(brg->dt_a);
        brg->rd_block = rd_unroll * vnni_granularity;
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        brg->is_M_tail = false;
    } else {
        // Blocking configuration for AMX
        const int max_width = 16, min_width = 1;
        brg->ld_block = 16;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        // Chooses bd_block honoring a level-2 bd_mask: rows skipped by the
        // mask consume no block. Picks the bd_block minimizing the number of
        // blocks (bdb). Returns false when no such mask applies.
        auto find_bd_block_for_bd_mask = [&]() {
            const auto bd_mask_size = brg->bcast_dim;
            if (brg->brgattr.bd_mask_level != 2 || bd_mask_size == 0)
                return false;

            const auto sm_buffer = brg->brgattr.bd_mask;
            auto min_bdb = INT_MAX;
            const auto start_bd_block = nstl::min(max_width, brg->bcast_dim);
            auto best_bd_block = start_bd_block;
            for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) {
                auto bdb = 0;
                for (int i = 0; i < bd_mask_size;) {
                    if (brg->brgattr.bd_mask_level == 2 && sm_buffer[i] == 0) {
                        i++;
                    } else {
                        i += bd_block;
                        if (i > brg->bcast_dim) {
                            // bcast_dim not divided by bd_block
                            bdb = INT_MAX;
                        } else
                            bdb++;
                    }
                }
                if (bdb < min_bdb) {
                    min_bdb = bdb;
                    best_bd_block = bd_block;
                }
            }
            brg->bd_block = best_bd_block;
            brg->bdb_tail = 0;
            brg->bdb = min_bdb;
            return true;
        };

        // Derives ld_block2 (and ldb2/ldb2_tail) from the chosen bd_block2,
        // then re-adjusts bd_block2 when the load dimension cannot be
        // double-blocked.
        auto set_decomposition_by_ld = [&]() {
            if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) {
                if (brg->ldb % 3 == 0)
                    brg->ld_block2 = 3;
                else if (brg->ldb % 2 == 0)
                    brg->ld_block2 = 2;
                else
                    brg->ld_block2 = 1;
            } else {
                brg->ld_block2
                        = (brg->ldb > 0 && brg->ldb % 2 == 0
                                  && brg->ldb_tail == 0 && brg->bd_block2 < 3)
                        ? 2
                        : 1;
            }
            brg->ldb2 = brg->ldb / brg->ld_block2;
            brg->ldb2_tail = brg->ldb % brg->ld_block2;

            // Re-adjust the bd_block2 if possible
            if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) {
                brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = (brg->bd_block2 == 1)
                        ? brg->bdb
                        : brg->bdb % brg->bd_block2;
            }
        };

        // Tries a (width_step x 1) decomposition; applicable only when
        // bcast_dim falls strictly between (width_step-1) and width_step
        // full tile widths and there is no ldb tail.
        auto try_3x1_decomposition = [&](int width_step) {
            brg->is_M_tail = false;
            if (brg->bcast_dim > (width_step - 1) * max_width
                    && brg->bcast_dim < width_step * max_width
                    && brg->ldb_tail == 0) {
                if (!find_bd_block_for_bd_mask()) {
                    brg->bd_block = max_width;
                    brg->bdb = div_up(brg->bcast_dim, brg->bd_block);
                    brg->bdb_tail = brg->bcast_dim % brg->bd_block;
                    brg->is_M_tail = true;
                }
                brg->bd_block2 = width_step;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = brg->bdb % brg->bd_block2;
                set_decomposition_by_ld();
                return true;
            }
            return false;
        };

        // Tries the default 2x2 decomposition. Returns false (so that the
        // 3x1 variants are tried next) when the result degenerates to
        // single blocks or a too-small bd_block.
        auto try_2x2_decomposition = [&]() {
            if (!find_bd_block_for_bd_mask()) {
                // Prefer a bd_block that divides bcast_dim exactly ...
                for (int m_block = max_width; m_block >= min_width; m_block--) {
                    if (brg->bcast_dim % m_block == 0) {
                        brg->bd_block = m_block;
                        break;
                    }
                }
                // ... otherwise pick the block with the largest (or zero)
                // tail.
                if (brg->bd_block == 1) {
                    brg->bd_block = nstl::min(max_width, brg->bcast_dim);
                    brg->bdb_tail = brg->bcast_dim % max_width;
                    for (int i = max_width; i >= min_width; i--) {
                        const auto i_tail = brg->bcast_dim % i;
                        if (i_tail > brg->bdb_tail || i_tail == 0) {
                            brg->bd_block = i;
                            brg->bdb_tail = i_tail;
                            if (i_tail == 0) break;
                        }
                    }
                }
                brg->bdb = brg->bcast_dim / brg->bd_block;
                brg->bdb_tail = brg->bcast_dim % brg->bd_block;
            }

            brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1;
            brg->bdb2 = brg->bdb / brg->bd_block2;
            brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb
                                                   : brg->bdb % brg->bd_block2;

            brg->is_M_tail = false;

            set_decomposition_by_ld();

            return !(brg->ld_block2 == 1 || brg->bd_block2 == 1
                    || brg->bd_block < 8);
        };

        // Try the decompositions in enum order until one succeeds; fall
        // back to 2x2 when none was accepted.
        bool is_decomposition_defined = false;
        for (int i = decomposition_2x2; i != undefined; i++) {
            switch (i) {
                case decomposition_2x2:
                    is_decomposition_defined = try_2x2_decomposition();
                    break;
                case decomposition_3x1_3:
                    is_decomposition_defined = try_3x1_decomposition(3);
                    break;
                case decomposition_3x1_2:
                    is_decomposition_defined = try_3x1_decomposition(2);
                    break;
                default: assert(!"invalid value"); break;
            };
            if (is_decomposition_defined) break;
        }
        if (!is_decomposition_defined) try_2x2_decomposition();

        // Non-temporal loads: enabled per hinted innermost loop, and only
        // when the hinted working set exceeds the per-core L1 cache.
        const bool try_load_nt_A = (brg->brgattr.hint_innermost_loop
                == brgemm_bd_loop_innermost);
        const bool try_load_nt_B = (brg->brgattr.hint_innermost_loop
                == brgemm_ld_loop_innermost);
        const bool try_load_nt
                = (static_cast<size_t>(brg->typesize_A)
                                  * brg->brgattr.hint_expected_A_size
                          + static_cast<size_t>(brg->typesize_B)
                                  * brg->brgattr.hint_expected_B_size
                          + static_cast<size_t>(brg->typesize_C)
                                  * brg->brgattr.hint_expected_C_size)
                >= platform::get_per_core_cache_size(1);
        brg->load_nt_A = try_load_nt_A && try_load_nt;
        brg->load_nt_B = try_load_nt_B && try_load_nt;

        // Helpers to recompute derived blocking values; a zero argument
        // (unset hint) leaves the current values untouched.
        auto recalc_bd_block = [&](int new_bd_block) {
            if (new_bd_block == 0) return;
            brg->bd_block = new_bd_block;
            brg->bdb = div_up(brg->bcast_dim, brg->bd_block);
            brg->bdb_tail = brg->bcast_dim % brg->bd_block;
            brg->is_M_tail = (brg->bdb_tail != 0);
        };

        auto recalc_bd_block2 = [&](int new_bd_block2) {
            if (new_bd_block2 == 0) return;
            brg->bd_block2 = new_bd_block2;
            // The tail block is processed separately, outside bd_block2.
            if (brg->bdb_tail && brg->bd_block2 > 1) brg->bd_block2--;
            auto full_bd_blocks = brg->bdb - (brg->bdb_tail != 0 ? 1 : 0);
            brg->bdb2 = full_bd_blocks / brg->bd_block2;
            brg->bdb2_tail = full_bd_blocks % brg->bd_block2;
        };

        auto recalc_ld_block = [&](int new_ld_block) {
            if (new_ld_block == 0) return;
            brg->ld_block = new_ld_block;
            brg->ldb = brg->load_dim / brg->ld_block;
            brg->ldb_tail = brg->load_dim % brg->ld_block;
        };

        auto recalc_ld_block2 = [&](int new_ld_block2) {
            if (new_ld_block2 == 0) return;
            brg->ld_block2 = new_ld_block2;
            if (brg->ldb_tail && brg->ld_block2 > 1) brg->ld_block2--;
            brg->ldb2 = brg->ldb / brg->ld_block2;
            brg->ldb2_tail = brg->ldb % brg->ld_block2;
        };

        recalc_bd_block2(brg->bd_block2);
        recalc_ld_block2(brg->ld_block2);

        // check hints for blocking parameters
        recalc_bd_block(brg->brgattr.hint_bd_block);
        recalc_bd_block2(brg->brgattr.hint_bd_block2);
        recalc_ld_block(brg->brgattr.hint_ld_block);
        recalc_ld_block2(brg->brgattr.hint_ld_block2);

        if (brg->brgattr.hint_load_nt_A != brgemm_hint_nt_undef)
            brg->load_nt_A
                    = (brg->brgattr.hint_load_nt_A == brgemm_hint_nt_true);
        if (brg->brgattr.hint_load_nt_B != brgemm_hint_nt_undef)
            brg->load_nt_B
                    = (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true);

        // Reduce-dim blocking: prefer the largest block (up to the tile
        // limit for the datatype) that divides reduce_dim exactly.
        const auto max_rd_block
                = (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 32
                                                                        : 64;
        const auto rd_block_step
                = (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 2 : 4;
        // TODO: if rd_block calculated is very small then maybe it makes
        // sense to use 1x2 or 2x1 blocking with supporting rd_block
        // and rdb_tail
        brg->rd_block = rd_block_step;
        for (int i = max_rd_block; i > 0; i -= rd_block_step) {
            if (brg->reduce_dim % i == 0) {
                brg->rd_block = i;
                break;
            }
        }
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        // Remove these guards in the future (add tail processing by reduction
        // dimension)
        if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail, brg->is_bf32))
            return status::unimplemented;
        if (!IMPLICATION(
                    (brg->rdb_tail
                            % ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
                            != 0,
                    brg->is_bf32))
            return status::unimplemented;

        //TODO: check this condition
        brg->interleave_tilestores_ = brg->beta == 0
                        && (brg->brgattr.use_interleave_stores
                                && (brg->bd_block2 * brg->ld_block2 == 4)
                                && !brg->brgattr.var_bs)
                ? true
                : false;
    }

    return status::success;
}
482
// Computes the blocking parameters for the diagonal (brdgmm) kernel. There
// is no reduce dimension here: blocking splits M and N across the vector
// registers available for accumulation.
status_t brdgmm_blocking(brgemm_t *brg) {

    if (brg->isa_impl == isa_undef) return status::unimplemented;

    // The fast int8 vnni path permutes the destination, which costs one
    // extra vmm (hence int, added to aux_vregs below).
    const int requires_permute_dst_vmm = brg->isa_impl == avx512_core_vnni
            && jit_brdgmm_kernel_base_t<avx512_core_vnni,
                    Xbyak::Zmm>::is_fast_vnni_int8(*brg);
    const int max_vregs = isa_num_vregs(brg->isa_impl);
    // At least 2 auxiliary vmms; bf16 emulation needs 4.
    const int aux_vregs
            = nstl::max(brg->is_bf16_emu * 4, 2) + requires_permute_dst_vmm;
    const int max_acc_vmms = max_vregs - aux_vregs;
    const int simd_w = isa_max_vlen(brg->isa_impl) / brg->typesize_C;
    const bool is_avx2_vnni_2_xf16
            = (brg->is_bf16 || brg->is_f16) && brg->isa_impl == avx2_vnni_2;

    auto &M = brg->bcast_dim;
    auto &N = brg->load_dim;

    // In current implementation of dgmm, there is no reduce dim.
    // The aliases below map the generic brgemm blocking fields onto the
    // m/n naming used by the dgmm kernel.
    auto &m_block1 = brg->bd_block;
    auto &nb_m_block1 = brg->bdb;
    auto &m_block1_tail = brg->bdb_tail;
    auto &m_block2 = brg->bd_block2;
    auto &nb_m_block2 = brg->bdb2;
    auto &m_block2_tail = brg->bdb2_tail;

    auto &n_block1 = brg->ld_block;
    auto &nb_n_block1 = brg->ldb;
    auto &n_block1_tail = brg->ldb_tail;
    auto &n_block2 = brg->ld_block2;
    auto &nb_n_block2 = brg->ldb2;
    auto &n_block2_tail = brg->ldb2_tail;

    // begin blocking
    // for avx2_vnni_2_xf16, instead of processing a n_block1 at once, it is
    // processed as even/odd pair.
    const int n_block1_num_steps = is_avx2_vnni_2_xf16 ? 2 : 1;
    n_block1 = n_block1_num_steps * simd_w;
    nb_n_block1 = div_up(N, n_block1);
    n_block1_tail = N % n_block1;

    // Up to 4 vmms are grouped into one n_block2.
    const int max_n_block2_vmms = 4;
    const int max_n_block2 = max_n_block2_vmms / n_block1_num_steps;
    n_block2 = nstl::min(max_n_block2, nb_n_block1);
    nb_n_block2 = div_up(nb_n_block1, n_block2);
    n_block2_tail = nb_n_block1 % n_block2;

    // One row of M per m_block1; m_block2 takes whatever accumulator vmms
    // remain after the n blocking.
    m_block1 = 1;
    nb_m_block1 = M / m_block1;
    m_block1_tail = M % m_block1;
    m_block2 = nstl::min(
            nb_m_block1, max_acc_vmms / (n_block2 * n_block1_num_steps));
    nb_m_block2 = div_up(nb_m_block1, m_block2);
    m_block2_tail = nb_m_block1 % m_block2;

    return status::success;
}
540
// Initializes a brgemm descriptor from the user-facing problem description:
// common fields, datatypes, isa selection, leading dimensions and problem
// sizes. For column-major layouts (!is_row_major()) the roles of A and B
// (and correspondingly M/N, LDA/LDB) are swapped so the kernel always
// computes in a row-major view.
void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
        impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
        float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
        dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) {

    init_common_conf(brg, type, alpha, beta, strides);

    brg->layout = layout;

    // Swap A/B for column-major layouts (see function comment).
    brg->dt_a = brg->is_row_major() ? dt_a : dt_b;
    brg->dt_b = brg->is_row_major() ? dt_b : dt_a;
    init_kernel_datatype(brg, brg->dt_a, brg->dt_b);

    // Accumulation, destination and bias all use the accumulator datatype
    // by default.
    brg->dt_c = get_accum_datatype(brg);
    brg->dt_d = brg->dt_c;
    brg->dt_bias = brg->dt_c;

    brg->typesize_A = types::data_type_size(brg->dt_a);
    brg->typesize_B = types::data_type_size(brg->dt_b);
    brg->typesize_C = types::data_type_size(brg->dt_c);
    brg->typesize_D = types::data_type_size(brg->dt_d);

    // Resolve the implementation isa and the derived AMX/bf32 flags.
    brg->isa_user = isa;
    set_isa_impl(brg);
    brg->is_int8_tmm = brg->is_int8 && brg->isa_impl == avx512_core_amx;
    brg->is_bf16_tmm = brg->is_bf16 && brg->isa_impl == avx512_core_amx;
    brg->is_f16_tmm = brg->is_f16 && brg->isa_impl == avx512_core_amx_fp16;
    brg->is_bf32 = is_bf32
            && utils::one_of(brg->isa_user, isa_undef, avx512_core_amx)
            && mayiuse(avx512_core_amx);
    set_brg_vmm(brg); // TODO: Investigate if it is really needed here.
    // s8s8 compensation is only needed on non-AMX int8 paths with signed A.
    brg->req_s8s8_compensation
            = brg->is_int8 && !brg->is_int8_tmm && brg->dt_a == data_type::s8;

    brg->LDA = (brg->is_row_major()) ? static_cast<int>(LDA)
                                     : static_cast<int>(LDB);
    brg->LDB = (brg->is_row_major()) ? static_cast<int>(LDB)
                                     : static_cast<int>(LDA);
    brg->LDC = static_cast<int>(LDC);
    brg->LDD = static_cast<int>(LDC);

    brg->bcast_dim
            = (brg->is_row_major()) ? static_cast<int>(M) : static_cast<int>(N);
    brg->load_dim
            = (brg->is_row_major()) ? static_cast<int>(N) : static_cast<int>(M);
    brg->reduce_dim = static_cast<int>(K);

    // Second-level bd blocking is computed later by brgemm_blocking().
    brg->bd_block2 = 0;
    brg->bdb2 = 0;
    brg->bdb2_tail = 0;

    // f16 on avx512_core_fp16 keeps B in plain (non-vnni) layout.
    const bool is_b_in_vnni_format = !(
            brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16);
    brg->ld_step
            = is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;

    // Without a native vnni compute instruction the kernel steps through
    // the reduce dim one element at a time.
    const bool has_no_vnni_compute_instruction
            = (brg->is_f16
                      && one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16))
            || (brg->is_bf16 && brg->isa_impl == avx2_vnni_2);
    brg->rd_step = has_no_vnni_compute_instruction
            ? 1
            : data_type_vnni_granularity(brg->dt_b);
}
605
606void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
607 impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
608 float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
609 const brgemm_strides_t *strides) {
610
611 init_common_conf(brg, type, alpha, beta, strides);
612
613 brg->layout = layout;
614
615 brg->dt_a = dt_a;
616 brg->dt_b = dt_b;
617 init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
618
619 brg->dt_c = get_accum_datatype(brg);
620 brg->dt_d = brg->dt_c;
621 brg->dt_bias = brg->dt_c;
622
623 brg->typesize_A = types::data_type_size(brg->dt_a);
624 brg->typesize_B = types::data_type_size(brg->dt_b);
625 brg->typesize_C = types::data_type_size(brg->dt_c);
626 brg->typesize_D = types::data_type_size(brg->dt_d);
627
628 brg->isa_user = isa;
629 auto is_isa_ok = [&](cpu_isa_t isa) {
630 return mayiuse(isa) && one_of(brg->isa_user, isa_undef, isa);
631 };
632
633 if (brg->is_f32) {
634 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core),
635 avx512_core, is_isa_ok(avx2), avx2);
636 } else if (brg->is_bf16) {
637 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_bf16),
638 avx512_core_bf16, is_isa_ok(avx2_vnni_2), avx2_vnni_2);
639 } else if (brg->is_f16) {
640 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_fp16),
641 avx512_core_fp16, is_isa_ok(avx2_vnni_2), avx2_vnni_2);
642 } else if (brg->is_int8) {
643 brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_vnni),
644 avx512_core_vnni, is_isa_ok(avx2_vnni), avx2_vnni);
645 }
646
647 brg->is_bf16_tmm = brg->is_bf16 && mayiuse(avx512_core_amx);
648 brg->is_dgmm = true;
649
650 brg->LDA = static_cast<int>(LDA);
651 brg->LDC = static_cast<int>(LDC);
652 brg->LDD = static_cast<int>(LDC);
653
654 brg->bcast_dim = M;
655 brg->load_dim = N;
656}
657
658} // namespace brgemm_utils
659} // namespace x64
660} // namespace cpu
661} // namespace impl
662} // namespace dnnl
663
664//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
665