/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/brgemm/brgemm.hpp"
#include "cpu/x64/brgemm/brgemm_utils.hpp"

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;

using namespace prop_kind;
using namespace data_type;
using namespace brgemm_utils;

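// Illustrative call sequence (a minimal sketch with hypothetical arguments,
// not taken from any particular driver). A BRGEMM kernel is typically
// configured and run as follows:
//
//     brgemm_t brg;
//     CHECK(brgemm_desc_init(&brg, isa, batch_kind, dt_a, dt_b,
//             /*transA=*/false, /*transB=*/false, brgemm_row_major,
//             /*alpha=*/1.0f, /*beta=*/0.0f, LDA, LDB, LDC, M, N, K,
//             /*strides=*/nullptr));
//     CHECK(brgemm_desc_set_postops(&brg, attr, dst_md, LDD, dt_bias));
//     CHECK(brgemm_desc_set_attr(&brg, brgattr));
//     brgemm_kernel_t *kernel = nullptr;
//     CHECK(brgemm_kernel_create(&kernel, brg));
//     brgemm_kernel_execute(kernel, bs, batch, ptr_C, scratch);
//     CHECK(brgemm_kernel_destroy(kernel));
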
void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
        const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
    brgemm_kernel_params_t brgemm_p;

    brgemm_p.batch = batch;
    brgemm_p.ptr_A = nullptr;
    brgemm_p.ptr_B = nullptr;
    brgemm_p.ptr_C = ptr_C;
    brgemm_p.ptr_D = ptr_C;
    brgemm_p.ptr_buf = scratch;
    brgemm_p.ptr_bias = nullptr;
    brgemm_p.do_post_ops = 0;
    brgemm_p.do_apply_comp = 0;
    brgemm_p.skip_accm = 0;
    brgemm_p.BS = bs;

    assert(brg_kernel);

    (*brg_kernel)(&brgemm_p);
}

void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
        const void *addr_A, const void *addr_B,
        const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
    brgemm_kernel_params_t brgemm_p;

    brgemm_p.batch = batch;
    brgemm_p.ptr_A = addr_A;
    brgemm_p.ptr_B = addr_B;
    brgemm_p.ptr_C = ptr_C;
    brgemm_p.ptr_D = ptr_C;
    brgemm_p.ptr_buf = scratch;
    brgemm_p.ptr_bias = nullptr;
    brgemm_p.do_post_ops = 0;
    brgemm_p.do_apply_comp = 0;
    brgemm_p.skip_accm = 0;
    brgemm_p.BS = bs;
    assert(brg_kernel);
    (*brg_kernel)(&brgemm_p);
}

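// The *_postops variants below additionally write the final (post-processed)
// result to a separate destination buffer ptr_D and forward post-op related
// data (bias, scales, zero-point compensations, binary post-op arguments)
// from post_ops_data to the kernel.
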
void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
        const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
        const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
    brgemm_kernel_params_t brgemm_p;

    brgemm_p.batch = batch;
    brgemm_p.ptr_A = nullptr;
    brgemm_p.ptr_B = nullptr;
    brgemm_p.ptr_C = ptr_C;
    brgemm_p.ptr_D = ptr_D;
    brgemm_p.ptr_buf = scratch;
    brgemm_p.ptr_bias = post_ops_data.bias;
    brgemm_p.ptr_scales = post_ops_data.scales;
    brgemm_p.do_post_ops
            = (post_ops_data.do_only_comp || post_ops_data.do_only_zp_a_val)
            ? 0
            : 1;
    brgemm_p.do_apply_comp = post_ops_data.do_only_zp_a_val ? 0 : 1;
    brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
    brgemm_p.BS = bs;
    brgemm_p.zp_a_val = post_ops_data.zp_a_val;
    brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
    brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
    brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
    brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
    brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
    brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
    brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
    brgemm_p.c_zp_values = post_ops_data.c_zp_values;
    assert(brg_kernel);
    (*brg_kernel)(&brgemm_p);
}

void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
        const void *addr_A, const void *addr_B,
        const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
        const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
    brgemm_kernel_params_t brgemm_p;

    brgemm_p.batch = batch;
    brgemm_p.ptr_A = addr_A;
    brgemm_p.ptr_B = addr_B;
    brgemm_p.ptr_C = ptr_C;
    brgemm_p.ptr_D = ptr_D;
    brgemm_p.ptr_buf = scratch;
    brgemm_p.ptr_bias = post_ops_data.bias;
    brgemm_p.ptr_scales = post_ops_data.scales;
    brgemm_p.do_post_ops
            = (post_ops_data.do_only_comp || post_ops_data.do_only_zp_a_val)
            ? 0
            : 1;
    brgemm_p.do_apply_comp = post_ops_data.do_only_zp_a_val ? 0 : 1;
    brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
    brgemm_p.BS = bs;
    brgemm_p.zp_a_val = post_ops_data.zp_a_val;
    brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
    brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
    brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
    brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
    brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
    assert(brg_kernel);
    (*brg_kernel)(&brgemm_p);
}

status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, bool transB,
        brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
        dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides) {
    /*
    m - number of rows of the matrix op(A) and number of rows of the matrix C
    n - number of columns of the matrix op(B) and number of columns of the matrix C
    k - number of columns of the matrix op(A) and number of rows of the matrix op(B)

    For row-major layout:
        A: lda * m, LDA - lda must be at least max(1, k)
        B: ldb * k, LDB - ldb must be at least max(1, n)
        C: ldc * m, LDC - ldc must be at least max(1, n)

    For column-major layout:
        A: lda * k, LDA - lda must be at least max(1, m)
        B: ldb * n, LDB - ldb must be at least max(1, k)
        C: ldc * n, LDC - ldc must be at least max(1, m)
    */
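    // Example (row-major layout, hypothetical sizes): with M = 16, N = 64,
    // K = 32, valid leading dimensions are LDA >= 32, LDB >= 64, LDC >= 64.
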
    if (brg == nullptr) return status::invalid_arguments;
    if (transA || transB) return status::unimplemented;

    brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
            beta, LDA, LDB, LDC, M, N, K, strides);

    if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
    const bool ldx_check = (brg->is_row_major())
            ? (LDA < K)
            : (LDA < M || LDB < K || LDC < M);
    if (ldx_check) return status::invalid_arguments;

    if (utils::everyone_is(
                false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16))
        return status::unimplemented;

    CHECK(brgemm_blocking(brg));

    return status::success;
}

status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
        float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
        const brgemm_strides_t *strides) {

    if (brg == nullptr) return status::invalid_arguments;
    if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
        return status::unimplemented;

    brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
            beta, LDA, LDC, M, N, strides);

    const bool ldx_check = (LDA < N || LDC < N);
    if (ldx_check) return status::invalid_arguments;

    if (utils::everyone_is(
                false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16))
        return status::unimplemented;

    CHECK(brdgmm_blocking(brg));

    return status::success;
}

status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
        const memory_desc_t *dst_md, int LDD, impl::data_type_t dt_bias) {
    if (!brg || !dst_md) return status::invalid_arguments;

    brg->attr = attr;
    brg->dst_md = dst_md;

    brg->with_bias = dt_bias != data_type::undef;
    brg->dt_bias = dt_bias;
    brg->typesize_bias = (dt_bias == data_type::undef)
            ? 0
            : types::data_type_size(brg->dt_bias);

    brg->LDD = LDD;
    const auto dt_d = dst_md->data_type;

    // check that bias and output data type are supported by isa
    if (!IMPLICATION(one_of(data_type::bf16, dt_bias, dt_d),
                is_superset(brg->isa_impl, avx512_core)
                        || is_superset(brg->isa_impl, avx2_vnni_2)))
        return status::unimplemented;
    if (!IMPLICATION(one_of(data_type::f16, dt_bias, dt_d),
                is_superset(brg->isa_impl, avx512_core_fp16)
                        || is_superset(brg->isa_impl, avx2_vnni_2)))
        return status::unimplemented;
    // check that combination of data types is allowed
    if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
            && (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
                    data_type::f32, data_type::bf16))
            && (!one_of(dt_bias, data_type::undef, data_type::u8,
                    data_type::s8, data_type::s32, data_type::f32,
                    data_type::bf16)))
        return status::unimplemented;
    if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
            && (!one_of(dt_d, data_type::bf16, data_type::f32))
            && (!one_of(dt_bias, data_type::undef, data_type::bf16,
                    data_type::f32)))
        return status::unimplemented;
    if ((brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32)
            && (!one_of(dt_d, data_type::f32))
            && (!one_of(dt_bias, data_type::undef, data_type::f32)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_f16,
                one_of(dt_d, data_type::f32, data_type::f16)
                        && one_of(dt_bias, data_type::undef, data_type::f32,
                                data_type::f16)))
        return status::unimplemented;

    brg->dt_d = dt_d;
    brg->typesize_D = types::data_type_size(brg->dt_d);

    if (!IMPLICATION(
                brg->is_int8 && brg->dt_d == bf16, mayiuse(avx512_core_vnni)))
        return status::unimplemented;

    if (brg->is_int8 && brg->dt_d == bf16)
        brg->is_bf16_emu = !mayiuse(avx512_core_bf16);

    // Rerun blocking heuristic due to reduced zmm register count
    if (brg->is_bf16_emu && brg->is_dgmm) CHECK(brdgmm_blocking(brg));

    if (!brg->attr) return status::success;

    using namespace injector;

    const auto &post_ops = brg->attr->post_ops_;
    const memory_desc_wrapper dst_d(dst_md);

    const auto binary_ind = post_ops.find(primitive_kind::binary);
    brg->with_binary = binary_ind != -1;

    // NOTE: Using brg->isa_impl here is a bit dangerous as it can change
    // before kernel creation, so there is no guarantee that the isa checked
    // here matches the isa used at kernel creation time. For now this can
    // only happen for bf32, where the isa during this check is avx512_core
    // and the isa at kernel creation time is avx512_core_amx_bf16. It just so
    // happens that the behavior of `post_ops_ok` is identical for those two
    // isas, but there is no guarantee that this will always be the case.
    if ((brg->with_binary && !dst_md)
            || !injector::post_ops_ok(
                    post_ops_ok_args_t(brg->isa_impl, {sum, eltwise, binary},
                            post_ops, &dst_d, false /*sum_at_pos_0_only*/,
                            false /*sum_requires_scale_one*/,
                            false /*sum_requires_zp_zero*/,
                            {broadcasting_strategy_t::per_oc,
                                    broadcasting_strategy_t::scalar,
                                    broadcasting_strategy_t::per_mb_spatial,
                                    broadcasting_strategy_t::per_mb_w,
                                    broadcasting_strategy_t::per_w,
                                    broadcasting_strategy_t::no_broadcast})))
        return status::unimplemented;

    const auto sum_idx = post_ops.find(primitive_kind::sum);
    const bool with_sum = sum_idx != -1;
    brg->with_sum = with_sum;
    brg->sum_scale = with_sum ? post_ops.entry_[sum_idx].sum.scale : 0;
    brg->sum_zp = with_sum ? post_ops.entry_[sum_idx].sum.zero_point : 0;
    const auto sum_dt
            = with_sum ? post_ops.entry_[sum_idx].sum.dt : data_type::undef;
    brg->sum_dt = sum_dt != data_type::undef ? sum_dt : dt_d;

    const auto eltwise_ind = post_ops.find(primitive_kind::eltwise);
    brg->with_eltwise = eltwise_ind != -1;

    const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS);
    brg->with_scales = !src_scales.has_default_values()
            || !wei_scales.has_default_values();
    if (brg->with_scales) {
        // Note: the current version supports only two different output scale
        // types:
        //     1) common (mask_ = 0)
        //     2) per_n_dim_scale - broadcast across the n dimension;
        //        for convolution and inner product primitives it corresponds
        //        to "per_oc", i.e. mask_ = 1 << 1; for matmul it corresponds
        //        to mask_ = 1 << (ndims - 1), where ndims is the number of
        //        dimensions of the original matmul problem.
        // So if wei_scales.mask_ != 0 (not common), it is assumed here that
        // the scale type is per_n_dim_scale and that the driver calling the
        // brgemm kernel has checked that the mask is correct for this case.
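        // For example, for a 2D matmul (ndims = 2) the per_n_dim_scale mask
        // is 1 << (2 - 1) = 2.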
        brg->is_oc_scale = wei_scales.mask_ != 0;
    }
    const bool scales_ok = src_scales.mask_ == 0
            && attr->scales_.has_default_values(
                    {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS});
    if (!scales_ok) return status::unimplemented;

    auto init_zp_type
            = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t {
        auto zero_points = attr->zero_points_;

        // common zero point type is supported for now
        if (!zero_points.common(mem_arg)) return status::unimplemented;

        zp_type = zero_points.has_default_values(mem_arg)
                ? brgemm_broadcast_t::none
                : brgemm_broadcast_t::per_tensor;
        return status::success;
    };

    CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC));
    CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS));
    CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST));

    // src zero points require an additional register in the brgemm kernel
    if (brg->zp_type_a != brgemm_broadcast_t::none
            || (brg->is_bf16_emu && !brg->is_dgmm))
        CHECK(brgemm_blocking(brg));

    return status::success;
}

status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
    if (brg == nullptr) return status::invalid_arguments;

    // negative padding is not supported
    if (brgattr.max_top_vpad < 0 || brgattr.max_bottom_vpad < 0)
        return status::unimplemented;

    if (!brg->is_dgmm) {
        // virtual padding size is restricted by MAX_VPAD value
        if (brgattr.max_top_vpad > brgemm_t::MAX_VPAD
                || brgattr.max_bottom_vpad > brgemm_t::MAX_VPAD)
            return status::unimplemented;

        // virtual padding is restricted by bd_block size due to
        // brgemm_kernel implementation. TODO: remove this restriction
        if (brgattr.max_top_vpad > brg->bd_block
                || brgattr.max_bottom_vpad > brg->bd_block)
            return status::unimplemented;
    }

    // virtual padding is supported for "brgemm_row_major" layout
    // TODO: remove this restriction
    if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
            && brg->layout != brgemm_row_major)
        return status::unimplemented;

    brg->brgattr = brgattr;

    if (brgattr.fpmath_mode != fpmath_mode::strict) maybe_try_bf32(brg);

    const bool hint_blocking_set
            = (brgattr.hint_bd_block != 0 || brgattr.hint_bd_block2 != 0
                    || brgattr.hint_ld_block != 0 || brgattr.hint_ld_block2 != 0
                    || brgattr.hint_load_nt_A != brgemm_hint_nt_undef
                    || brgattr.hint_load_nt_B != brgemm_hint_nt_undef);
    if (brg->is_bf16_tmm || hint_blocking_set || brgattr.bd_mask_level
            || brgattr.fpmath_mode != fpmath_mode::strict) {
        if (brg->is_dgmm)
            CHECK(brdgmm_blocking(brg));
        else
            CHECK(brgemm_blocking(brg));
    }

    brg->LDA2 = (brgattr.LDA2 != 0) ? brgattr.LDA2 : brg->LDA;
    brg->LDB2 = (brgattr.LDB2 != 0) ? brgattr.LDB2 : brg->LDB;
    brg->LDC2_M = (brgattr.LDC2_M != 0) ? brgattr.LDC2_M : brg->LDC;
    brg->LDC2_N = (brgattr.LDC2_N != 0) ? brgattr.LDC2_N : brg->ld_block;

    brg->is_blocked = (brg->LDA2 != brg->LDA || brg->LDB2 != brg->LDB
            || brg->LDC2_M != brg->LDC || brg->LDC2_N != brg->ld_block);

    if (!IMPLICATION(brg->is_blocked, brg->layout == brgemm_row_major))
        return status::invalid_arguments;

    // virtual padding is not supported for "amx"
    if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
            && (brg->is_tmm))
        return status::unimplemented;

    brg->prfA = brgattr.hint_prfA;
    brg->prfB = brgattr.hint_prfB;
    brg->prfC = brgattr.hint_prfC;

    if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf1
            && brg->prfC.dist1 < 0)
        brg->prfC.dist1 = 0;
    if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf2
            && brg->prfC.dist2 < 0)
        brg->prfC.dist2 = 0;

    return status::success;
}

status_t brgemm_kernel_create(
        brgemm_kernel_t **brg_kernel, const brgemm_t &brg) {
    if (!brg_kernel) return status::invalid_arguments;
    *brg_kernel = nullptr;

    if (brg.is_dgmm) {
#define CASE(isa) \
    case isa: \
        CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel, \
                new brdgmm_kernel_t<isa, typename cpu_isa_traits<isa>::Vmm>( \
                        brg))); \
        break
        switch (brg.isa_impl) {
            CASE(avx512_core_fp16);
            CASE(avx512_core_bf16);
            CASE(avx512_core_vnni);
            CASE(avx512_core);
            CASE(avx2_vnni_2);
            CASE(avx2);
            default: return status::unimplemented;
        }
#undef CASE
    } else if (can_dispatch_uker(&brg)) {
        CHECK(safe_ptr_assign<brgemm_kernel_t>(
                *brg_kernel, new brgemm_amx_uker_t(brg)));
    } else {
        if (brg.is_tmm) {
            if (brg.is_f16_tmm) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_amx_fp16,
                                Xbyak::Tmm>(brg)));
            } else {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_amx, Xbyak::Tmm>(
                                brg)));
            }
        } else if (brg.is_zmm) {
            // isa specific instantiations are required because
            // post-ops require template isa param.
            if (brg.isa_impl == avx512_core_fp16) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_fp16,
                                Xbyak::Zmm>(brg)));
            } else if (brg.isa_impl == avx512_core_bf16) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_bf16,
                                Xbyak::Zmm>(brg)));
            } else if (brg.isa_impl == avx512_core_vnni) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_vnni,
                                Xbyak::Zmm>(brg)));
            } else {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core, Xbyak::Zmm>(
                                brg)));
            }
        } else if (brg.is_ymm) {
            if (brg.isa_impl == avx2) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2, Xbyak::Ymm>(brg)));
            } else if (brg.isa_impl == avx2_vnni) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2_vnni, Xbyak::Ymm>(
                                brg)));
            } else if (brg.isa_impl == avx2_vnni_2) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2_vnni_2, Xbyak::Ymm>(
                                brg)));
            }
        }
    }
    if (!(*brg_kernel)) return status::unimplemented;
    return (*brg_kernel)->create_kernel();
}

status_t brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel) {
    delete brg_kernel;
    return status::success;
}

status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
    constexpr int max_palette_size_in_bytes = 64;

    if (!brg.is_tmm) return status::unimplemented;

    // TODO: Add support of tail processing by reduction dimension
    auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
    if (brg.is_bf32) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);

    palette_config_t *buff = (palette_config_t *)(palette);

    char *_tc = (char *)(buff);
    for (int i = 0; i < max_palette_size_in_bytes; i++)
        _tc[i] = 0;

    const int typesize_A = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_A;
    const int typesize_B = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_B;

    const int rd_step = 4 / typesize_A;

    const auto Ac = typesize_A * rd_block;

    const auto Br = (brg.typesize_C != 0) ? Ac / brg.typesize_C : 0;

    if (brg.ldb_tail && (brg.ld_block2 > 1)) return status::unimplemented;
    if (brg.get_num_A_tiles() + brg.get_num_B_tiles()
                    + brg.get_bd_block2() * brg.get_ld_block2()
            > brgemm_t::AMX_TILES_NUM)
        return status::unimplemented;

    // Due to interleaving tileload/tmul we don't support blocking 1x6 and 6x1
    // TODO: update gemm_microkernel_amx to support such blocking
    if (brg.get_bd_block2() >= 6 || brg.get_num_C_tiles() >= 6)
        return status::unimplemented;

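    // Tile geometry (descriptive note): each A tile is configured with
    // bd_block (or bdb_tail) rows of Ac = typesize_A * rd_block bytes; each
    // B tile with Br rows of Bc = ld_block (or ldb_tail) * typesize_B *
    // rd_step bytes, i.e. rd_step reduce-dimension elements interleaved per
    // row (vnni-style packing); each C accumulator tile with bd_block (or
    // bdb_tail) rows of ld_block (or ldb_tail) * typesize_C bytes.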
    for (int m = 0; m < brg.get_num_A_tiles(); m++) {
        const bool is_bd_tail
                = (brg.bdb_tail && m == (brg.get_num_A_tiles() - 1));
        const auto A_tensor = brg.get_A_tensor(m, is_bd_tail);
        const auto Ar = is_bd_tail ? brg.bdb_tail : brg.bd_block;
        tc_configure_tile(buff, A_tensor, Ar, Ac);
    }

    for (int n = 0; n < brg.get_num_B_tiles(); n++) {
        const bool is_ld_tail
                = (brg.ldb_tail && n == (brg.get_num_B_tiles() - 1));
        const auto B_tensor = brg.get_B_tensor(n, is_ld_tail);
        const auto Bc = (is_ld_tail ? brg.ldb_tail : brg.ld_block) * typesize_B
                * rd_step;
        tc_configure_tile(buff, B_tensor, Br, Bc);
    }

    for (int m = 0; m < brg.get_bd_block2(); m++) {
        const bool is_bd_tail
                = (brg.bdb_tail && m == (brg.get_bd_block2() - 1));
        const auto Cr = is_bd_tail ? brg.bdb_tail : brg.bd_block;
        for (int n = 0; n < brg.get_ld_block2(); n++) {
            const bool is_ld_tail
                    = (brg.ldb_tail && n == (brg.get_ld_block2() - 1));
            const auto Cc = (is_ld_tail ? brg.ldb_tail : brg.ld_block)
                    * brg.typesize_C;
            const auto C_tensor
                    = brg.get_C_tensor(m, n, is_bd_tail, is_ld_tail);
            tc_configure_tile(buff, C_tensor, Cr, Cc);
        }
    }

    buff->palette_id = amx::get_target_palette();

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s