1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <unordered_set>
18
19#include "common/dnnl_thread.hpp"
20#include "cpu/platform.hpp"
21#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
22#include "cpu/x64/matmul/brgemm_matmul_utils.hpp"
23
24#include "cpu/binary_injector_utils.hpp"
25#include "cpu/matmul/matmul_utils.hpp"
26#include "oneapi/dnnl/dnnl_debug.h"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32namespace matmul {
33
34using namespace dnnl::impl::cpu::matmul;
35
36using namespace dnnl::impl::memory_tracking::names;
37using namespace dnnl::impl::utils;
38
39using namespace data_type;
40using namespace format_tag;
41
42int get_default_n_block(format_tag_t matrix_b_tag) {
43 // Note: consider using weights mem_descriptor 'inner_blks' to
44 // return B's inner block for non-default cases.
45 switch (matrix_b_tag) {
46 case aCB16b64c:
47 case aCB16b64c2b:
48 case aCB16b64c4b:
49 case BA16a64b4a:
50 case BA16a64b2a:
51 case BA16a64b: return 64;
52 case aCB16b48c:
53 case aCB16b48c2b:
54 case aCB16b48c4b:
55 case BA16a48b:
56 case BA16a48b2a:
57 case BA16a48b4a: return 48;
58 case aCB16b32c:
59 case aCB16b32c2b:
60 case aCB16b32c4b:
61 case BA16a32b:
62 case BA16a32b2a:
63 case BA16a32b4a: return 32;
64 case aCB16b16c:
65 case aCB16b16c2b:
66 case aCB16b16c4b:
67 case BA16a16b:
68 case BA16a16b2a:
69 case BA16a16b4a: return 16;
70 default: return 64;
71 }
72}
73
// TODO: add support of post-ops with multiple binary and eltwise execution
// Checks whether the attribute's post-ops chain (sum / eltwise / binary) is
// supported by this implementation for the given destination descriptor.
// Returns true only when all binary broadcast strategies present are usable
// for the destination's number of dimensions.
bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;
    const auto ndims = dst_d.ndims();

    // Detect which binary post-op broadcast strategies appear in the chain
    // so the ndims-specific restrictions below can be applied.
    bool is_binary_po_per_oc_sp_bcast {};
    bool is_binary_po_channel_bcast {};
    bool is_binary_po_per_mb_w_bcast {};
    bool is_binary_po_per_w_bcast {};
    std::tie(is_binary_po_per_oc_sp_bcast, is_binary_po_channel_bcast,
            is_binary_po_per_mb_w_bcast, is_binary_po_per_w_bcast)
            = binary_injector_utils::bcast_strategies_present_tup(
                    post_ops.entry_, dst_d,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial,
                    broadcasting_strategy_t::per_mb_w,
                    broadcasting_strategy_t::per_w);
    // per_oc_spatial is limited to problems with fewer than 4 dims; the
    // remaining strategies are only supported for 3D/4D problems.
    const bool supported_binary_bcast
            = IMPLICATION(is_binary_po_per_oc_sp_bcast, ndims < 4)
            && IMPLICATION(
                    is_binary_po_channel_bcast, utils::one_of(ndims, 3, 4))
            && IMPLICATION(
                    is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
            && IMPLICATION(
                    is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4));
    // Delegate per-entry validation (allowed kinds, sum constraints, and the
    // set of acceptable broadcast strategies) to the common injector helper.
    return supported_binary_bcast
            && injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
                    {sum, eltwise, binary}, post_ops, &dst_d,
                    false /*sum_at_pos_0_only*/,
                    false /*sum_requires_scale_one*/,
                    false /*sum_requires_zp_zero*/,
                    {broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::scalar,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast}));
}
116
117status_t check_isa_with_datatype(
118 const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
119 const bool ok = IMPLICATION(bm_conf_utils.is_f32(),
120 isa == avx512_core || bm_conf_utils.is_bf32())
121 && IMPLICATION(bm_conf_utils.is_int8(),
122 one_of(isa, avx512_core_amx, avx512_core_vnni))
123 && IMPLICATION(bm_conf_utils.is_bf16(),
124 one_of(isa, avx512_core_amx, avx512_core_bf16))
125 && IMPLICATION(bm_conf_utils.is_f16(),
126 one_of(isa, avx512_core_amx_fp16, avx512_core_fp16))
127 && IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
128 mayiuse(avx512_core_vnni));
129 return ok ? status::success : status::unimplemented;
130}
131
// Gathers datatype- and layout-related helper flags for the matmul config.
// The *_any_layout flags record which tensors came in with format_kind::any
// and may therefore be assigned an implementation-chosen memory tag.
brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
        brgemm_matmul_conf_t &bgmmc, const cpu_isa_t isa,
        const primitive_attr_t &attr, bool A_any_layout, bool B_any_layout,
        bool C_any_layout, bool bias_any_layout)
    : bgmmc(bgmmc)
    // Pure f32 problem: src, weights and dst are all f32.
    , f32_dt(utils::everyone_is(f32, bgmmc.src_dt, bgmmc.wei_dt, bgmmc.dst_dt))
    // bf16 inputs may produce bf16 or f32 output.
    , bf16_dt(utils::everyone_is(bf16, bgmmc.src_dt, bgmmc.wei_dt)
            && one_of(bgmmc.dst_dt, bf16, f32))
    // f16 inputs may produce f16 or f32 output.
    , f16_dt(utils::everyone_is(f16, bgmmc.src_dt, bgmmc.wei_dt)
            && one_of(bgmmc.dst_dt, f16, f32))
    // int8: u8/s8 src with s8 weights; several destination types allowed.
    , int8_dt(utils::one_of(bgmmc.src_dt, u8, s8) && bgmmc.wei_dt == s8
            && one_of(bgmmc.dst_dt, u8, s8, s32, f32, bf16))
    // bf32: f32 data computed internally in bf16, enabled by fpmath_mode on
    // AMX-capable hardware only.
    , bf32_dt(f32_dt && attr.fpmath_mode_ == fpmath_mode::bf16
            && isa == avx512_core_amx)
    , A_any_layout(A_any_layout)
    , B_any_layout(B_any_layout)
    , C_any_layout(C_any_layout)
    , bias_any_layout(bias_any_layout)
    // Plain (row-major) and transposed format tags for the problem's ndims.
    , plain_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ab, abc, abcd, abcde,
              abcdef, abcdefg, abcdefgh, abcdefghi, abcdefghij, abcdefghijk,
              abcdefghijkl))
    , transposed_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ba, acb, abdc,
              abced, abcdfe, abcdegf, abcdefhg, abcdefgih, abcdefghji,
              abcdefghikj, abcdefghijlk))
    // Candidate blocked B layouts for each supported inner N-block width.
    , blocked_64n_B_layout_tag(pick_blocked_B_layout(64))
    , blocked_48n_B_layout_tag(pick_blocked_B_layout(48))
    , blocked_32n_B_layout_tag(pick_blocked_B_layout(32))
    , blocked_16n_B_layout_tag(pick_blocked_B_layout(16))
    // Blocked B layouts are usable only when all widths resolved to a tag.
    , blocked_B_layouts_allowed(!utils::one_of(format_tag::undef,
              blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
              blocked_32n_B_layout_tag, blocked_16n_B_layout_tag))
    // A user-fixed (non-any) B layout pins the N block size.
    , n_blk_fixed((!B_any_layout) && blocked_B_layouts_allowed) {
    assert(int8_dt || bf16_dt || f16_dt || f32_dt || bf32_dt);
}
166
// Chooses a weights tag (for format_kind::any) or validates the user-given
// one. When 'init_n_tag' is true the default N block drives the layout
// choice; otherwise the previously selected bgmmc.N_blk is reused.
status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(
        memory_desc_t &B_md, bool init_n_tag) const {

    if (B_any_layout) {
        const int default_n_block = init_n_tag
                ? get_default_n_block(format_tag::undef)
                : bgmmc.N_blk;
        bgmmc.wei_tag = blocked_B_layouts_allowed
                ? this->pick_blocked_B_layout(default_n_block)
                : plain_tensor_layout_tag;
        if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented;

        CHECK(memory_desc_init_by_tag(B_md, bgmmc.wei_tag));
        // Cache byte strides for the innermost (up to 3) dimensions of B.
        const int dmax = nstl::min(bgmmc.ndims, 3);
        const memory_desc_wrapper B_d(&B_md);
        for (int d = 0; d < dmax; d++) {
            int dim = bgmmc.ndims - 1 - d;
            bgmmc.B_strides[d]
                    = bgmmc.b_dt_sz * B_d.blocking_desc().strides[dim];
        }
    } else {
        // User-fixed layout: accept plain/transposed tags, and, when blocked
        // B layouts are possible, any of the supported blocked tags.
        bgmmc.wei_tag = blocked_B_layouts_allowed
                ? memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, blocked_64n_B_layout_tag,
                        blocked_48n_B_layout_tag, blocked_32n_B_layout_tag,
                        blocked_16n_B_layout_tag)
                : memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, acbd, adbc);

        // For cases when the weights tensor is transposed but has
        // 'dim_size == 1', we can ignore transposition and compute as a plain
        // format tensor. This removes the need of allocating a scratchpad for
        // copy_B.
        if (transposed_tensor_layout_tag == bgmmc.wei_tag) {
            memory_desc_t B_md_plain;
            const status_t status
                    = memory_desc_init_by_tag(B_md_plain, B_md.ndims, B_md.dims,
                            B_md.data_type, plain_tensor_layout_tag);
            if (status != status::success) return status;
            if (B_md_plain == B_md) bgmmc.wei_tag = plain_tensor_layout_tag;
        }

        if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented;
    }

    return status::success;
}
214
215status_t brgemm_matmul_conf_utils_t::update_and_check_B_tag(
216 memory_desc_t &B_md, int n_blk_size) const {
217
218 if (n_blk_fixed && n_blk_size != bgmmc.wei_n_blk)
219 return status::unimplemented;
220
221 if (!(B_any_layout && blocked_B_layouts_allowed)) return status::success;
222
223 return set_or_check_B_tag(B_md, false);
224}
225
// Chooses (for format_kind::any) or validates the src (A), dst (C) and bias
// memory tags. Only plain layouts (plus acbd/adbc batch permutations for
// some datatypes) are accepted for A and C.
status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
        memory_desc_t &C_md, memory_desc_t &bias_md) const {
    if (A_any_layout) {
        const format_tag_t desired_A_tag = plain_tensor_layout_tag;
        CHECK(memory_desc_init_by_tag(A_md, desired_A_tag));
        bgmmc.src_tag = desired_A_tag;
    } else {
        // Transposed A (and adbc) is only handled for the floating-point
        // configurations; int8 accepts plain/acbd only.
        bgmmc.src_tag = (this->is_bf16() || this->is_f32() || this->is_bf32()
                || this->is_f16())
                ? memory_desc_matches_one_of_tag(A_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, acbd, adbc)
                : memory_desc_matches_one_of_tag(
                        A_md, plain_tensor_layout_tag, acbd);
    }

    if (C_any_layout) {
        const format_tag_t desired_C_tag = plain_tensor_layout_tag;
        CHECK(memory_desc_init_by_tag(C_md, desired_C_tag));
        bgmmc.dst_tag = desired_C_tag;
    } else {
        bgmmc.dst_tag = memory_desc_matches_one_of_tag(
                C_md, plain_tensor_layout_tag, acbd);
    }

    // Unrecognized src or dst layout: not supported.
    if (one_of(format_tag::undef, bgmmc.src_tag, bgmmc.dst_tag))
        return status::unimplemented;

    if (bgmmc.with_bias && bias_any_layout)
        CHECK(memory_desc_init_by_tag(bias_md, plain_tensor_layout_tag));

    return status::success;
}
258
259status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
260
261 memory_desc_t want_B_md = B_md;
262 // Set bits for all dimensions except k dimension
263 const int compensation_mask
264 = ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
265 if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
266 want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
267 want_B_md.extra.compensation_mask = compensation_mask;
268 }
269 if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
270 want_B_md.extra.flags
271 |= memory_extra_flags::compensation_conv_asymmetric_src;
272 want_B_md.extra.asymm_compensation_mask = compensation_mask;
273 }
274
275 if (B_any_layout) {
276 B_md = want_B_md;
277 return status::success;
278 }
279
280 return B_md == want_B_md ? status::success : status::unimplemented;
281}
282
// Returns the blocked B (weights) format tag matching the requested inner
// N-block width for the current datatype, or format_tag::undef when no
// blocked layout applies (e.g. problems with more than 3 dims).
format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
        int n_blk) const {

    if (bgmmc.ndims > 3) return format_tag::undef;
    // int8 layouts pack 4 K-elements ('4b'/'4a' suffix) for VNNI/AMX.
    if (this->is_int8()) switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
            default: return format_tag::undef;
        }

    // bf16 (and f16 on AMX-fp16) layouts pack 2 K-elements ('2b'/'2a').
    if (this->is_bf16()
            || (this->is_f16() && bgmmc.isa == avx512_core_amx_fp16))
        switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
            default: return format_tag::undef;
        }
    // Note: bf32 assumes f32 blocking
    if (this->is_f32() || this->is_bf32() || this->is_f16()) switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
            default: return format_tag::undef;
        }
    return format_tag::undef;
}
314
315brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) {
316 return attr.zero_points_.has_default_values(arg)
317 ? brgemm_broadcast_t::none
318 : brgemm_broadcast_t::per_tensor;
319}
320
// Candidate blocking configuration for the AMX matmul path. Extends the
// base config with derived, per-candidate values (thread split, block and
// chunk sizes, buffer requirements) plus an efficiency score used by the
// search in compute_blocking_heuristic_amx() to pick the best variant.
struct matmul_amx_blocking_params_t : public brgemm_matmul_conf_t {
    // Empty candidate: score 0.0f so any evaluated blocking beats it.
    matmul_amx_blocking_params_t()
        : nthr_k_(0)
        , nthr_mnb_(0)
        , nthr_(0)
        , n_blk_(0)
        , n_chunk_size_(0)
        , n_chunk_elems_(0)
        , m_blk_(0)
        , m_chunk_size_(0)
        , m_chunk_elems_(0)
        , k_blk_(0)
        , k_chunk_size_(0)
        , k_chunk_elems_(0)
        , current_lda_(0)
        , need_buf_c_(false)
        , blocking_chunk_mem_size_(0)
        , efficiency_score_(0.0f) {}

    // Seeds the candidate from an existing configuration; derived fields
    // (chunk element counts, thread split) are computed from it.
    matmul_amx_blocking_params_t(const brgemm_matmul_conf_t &bgmmc)
        : brgemm_matmul_conf_t(bgmmc)
        , nthr_k_(nstl::max(nthr_k, 1))
        , nthr_mnb_(nthr / nthr_k_)
        , nthr_(nthr_mnb_ * nthr_k_)
        , n_blk_(N_blk)
        , n_chunk_size_(N_chunk_size)
        , n_chunk_elems_(n_blk_ * n_chunk_size_)
        , m_blk_(M_blk)
        , m_chunk_size_(M_chunk_size)
        , m_chunk_elems_(m_blk_ * m_chunk_size_)
        , k_blk_(K_blk)
        , k_chunk_size_(brgemm_batch_size)
        , k_chunk_elems_(k_blk_ * k_chunk_size_)
        , current_lda_(LDA)
        , need_buf_c_(use_buffer_c)
        , blocking_chunk_mem_size_(0)
        , efficiency_score_(0.0f) {}

    // Recomputes all derived fields and the efficiency score for the given
    // blocking parameters.
    void set_blocking_parameters(int nthr_k, int n_blk, int n_chunk_size,
            int m_blk, int m_chunk_size);
    // Writes the winning candidate's parameters back into the config.
    void update_configuration(brgemm_matmul_conf_t &bgmmc) const;
    // Higher score == better candidate.
    float get_blocking_scores() const { return efficiency_score_; }

    static size_t L2_threshold();

private:
    // num threads for parallelism wrt k dimension
    int nthr_k_;
    // num threads for parallelism wrt m, n and batch dimensions
    int nthr_mnb_;
    int nthr_;
    dim_t n_blk_, n_chunk_size_, n_chunk_elems_;
    dim_t m_blk_, m_chunk_size_, m_chunk_elems_;
    dim_t k_blk_, k_chunk_size_, k_chunk_elems_;

    dim_t current_lda_;
    bool need_buf_c_;
    // Estimated memory footprint of one blocking chunk (bytes).
    size_t blocking_chunk_mem_size_;
    float efficiency_score_;

    void update_k_blocking_dependent_params();
    dim_t get_actual_lda();
    bool is_buffer_c_required();
    size_t calculate_chunk_memory_size();
    float get_thread_balance_scores();
    float get_copied_data_reusage_scores();
    float get_L2_utilization_scores() const;
    float calculate_blocking_scores();
};
390
// Candidate blocking configuration for the non-AMX (avx512) matmul path.
// The search in compute_blocking_heuristic_avx512() evaluates candidates by
// the imbalance score below (lower is better) and copies the winner back
// into the config via update_configuration().
struct matmul_avx512_blocking_params_t {
    // Immutable problem sizes: M x N x K GEMM repeated 'batch' times.
    struct matmul_params_t {

        matmul_params_t(int m, int n, int k, int od)
            : M(m), N(n), K(k), batch(od) {}

        const int M;
        const int N;
        const int K;
        const int batch;
    };

    matmul_avx512_blocking_params_t(const matmul_params_t &m, const int nthr)
        : mp(m)
        , m_chunks(1)
        , m_blk(1)
        , m_tail(0)
        , n_chunks(1)
        , n_blk(1)
        , n_tail(0)
        , batch_size(1)
        , k_blk(1)
        , k_tail(0)
        , nthr_k(1)
        , nthr(nthr) {}

    // Custom assignment: copies only the mutable blocking fields; the const
    // problem reference 'mp' and thread count 'nthr' stay untouched.
    matmul_avx512_blocking_params_t &operator=(
            const matmul_avx512_blocking_params_t &brgemm_params) {
        m_chunks = brgemm_params.m_chunks;
        m_blk = brgemm_params.m_blk;
        m_tail = brgemm_params.m_tail;
        n_chunks = brgemm_params.n_chunks;
        n_blk = brgemm_params.n_blk;
        n_tail = brgemm_params.n_tail;
        batch_size = brgemm_params.batch_size;
        k_blk = brgemm_params.k_blk;
        k_tail = brgemm_params.k_tail;
        nthr_k = brgemm_params.nthr_k;
        return *this;
    }

    const matmul_params_t &mp;
    int m_chunks, m_blk, m_tail;
    int n_chunks, n_blk, n_tail;
    int batch_size, k_blk, k_tail;
    int nthr_k;
    const int nthr;

    // Sets the blocking fields and derives the corresponding tail sizes.
    void update_params(int m_chunks_, int m_blk_, int n_chunks_, int n_blk_,
            int batch_size_, int k_blk_, int nthr_k_) {
        m_chunks = m_chunks_;
        m_blk = m_blk_;
        m_tail = mp.M % m_blk;
        n_chunks = n_chunks_;
        n_blk = n_blk_;
        n_tail = mp.N % n_blk;
        batch_size = batch_size_;
        k_blk = k_blk_;
        k_tail = mp.K % k_blk;
        nthr_k = nthr_k_;
    }

    // Fraction of a thread block that is wasted when 'work' units are
    // distributed over blocks of 'thread_block' (0 == perfectly balanced).
    float calculate_spatial_disbalance(size_t work, size_t thread_block) const {
        size_t mod = work % thread_block;
        size_t scalar = work < thread_block
                ? thread_block - mod
                : nstl::min(thread_block - mod, mod);
        return static_cast<float>(scalar) / thread_block;
    }

    // Averaged imbalance score over five components: parallel-work spread,
    // M-block rounding waste, N-chunk rounding waste, thread allocation
    // remainder, and K-split imbalance. Lower is better.
    float get_imbalance() const {
        const size_t cur_nthr = nthr / nthr_k;

        size_t parallel_work = get_parallel_work();
        const float parallel_work_disb
                = calculate_spatial_disbalance(parallel_work, cur_nthr);

        // Waste from rounding M up to a multiple of m_blk.
        int m_work = (m_blk * div_up(mp.M, m_blk)) % mp.M;
        const float m_blk_disbalance = static_cast<float>(m_work) / mp.M;

        // Waste from rounding the N-block count up to whole chunks.
        int num_n_blk = div_up(mp.N, n_blk);
        int par_n_chunks = div_up(num_n_blk, n_chunks);
        const float n_chunk_disbalance
                = (static_cast<float>(par_n_chunks) * n_chunks - num_n_blk)
                / num_n_blk;

        const float disbalance_nthr_k
                = calculate_spatial_disbalance(mp.K, nthr_k * k_blk);

        // Threads left idle when nthr is not divisible by nthr_k.
        const float thread_allocation_disb
                = (cur_nthr * nthr_k) != static_cast<size_t>(nthr)
                ? (static_cast<float>(nthr) - cur_nthr * nthr_k) / nthr
                : 0;

        const float score
                = (parallel_work_disb + m_blk_disbalance + n_chunk_disbalance
                          + thread_allocation_disb + disbalance_nthr_k)
                / 5;

        return score;
    }

    // Number of independent work items (M-chunks x N-chunks x batch).
    size_t get_parallel_work() const {
        int m_elems = div_up(mp.M, m_blk * m_chunks);
        int n_elems = div_up(mp.N, n_blk * n_chunks);
        return static_cast<size_t>(m_elems) * n_elems * mp.batch;
    }

    // Leading dimension of the (possibly copied) A buffer: round k_blk up to
    // a cache line and nudge large powers of two to reduce cache aliasing.
    inline dim_t get_actual_lda(bool use_buffer_a, dim_t a_dt_sz) const {
        if (!use_buffer_a) return mp.K;

        constexpr int bytes_in_cacheline = 64;
        const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
        dim_t lda = rnd_up(k_blk, elems_in_cacheline);
        const bool is_big_pow_2 = lda >= 512 && math::is_pow2(lda);
        if (is_big_pow_2) lda += elems_in_cacheline;
        return lda;
    }

    // A C accumulation buffer is needed when K is split across threads or
    // across multiple brgemm calls with a non-trivial final conversion.
    inline bool is_buffer_c_required(
            dim_t acc_dt, dim_t dst_dt, bool with_sum) const {
        const size_t k_chunk_elems = k_blk * batch_size;
        if (nthr_k > 1 && static_cast<size_t>(mp.K) > k_chunk_elems)
            return true;

        return ((acc_dt != dst_dt || with_sum)
                && (static_cast<size_t>(mp.K) > k_chunk_elems
                        || mp.K % k_blk > 0));
    }

    // Publishes this candidate's blocking into the configuration.
    void update_configuration(brgemm_matmul_conf_t &bgmmc) const {
        bgmmc.M_blk = m_blk;
        bgmmc.M_chunk_size = m_chunks;
        bgmmc.N_blk = n_blk;
        bgmmc.N_chunk_size = n_chunks;

        // K block must respect the datatype's K granularity requirement.
        bgmmc.K_blk = rnd_up(k_blk, bgmmc.required_k_granularity);
        bgmmc.brgemm_batch_size = batch_size;

        bgmmc.nthr_k = nthr_k;

        bgmmc.use_buffer_c = is_buffer_c_required(
                bgmmc.acc_dt, bgmmc.dst_dt, bgmmc.with_sum);
        bgmmc.LDA = (bgmmc.src_tag == acbd && !bgmmc.use_buffer_a
                        ? bgmmc.A_strides[1] / bgmmc.a_dt_sz
                        : get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz));
    }
};
539
540size_t matmul_amx_blocking_params_t::L2_threshold() {
541 return 3 * platform::get_per_core_cache_size(2) / 4;
542}
543
// Exhaustively searches AMX blocking candidates (K thread split, M/N block
// sizes, M/N chunk sizes) and keeps the highest-scoring one in
// 'best_blocking'. Scores come from matmul_amx_blocking_params_t.
void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils,
        matmul_amx_blocking_params_t &best_blocking) {

    matmul_amx_blocking_params_t current_blocking(bgmmc);

    // Splitting K is only worthwhile if each thread gets enough reduction
    // work.
    const int min_k_per_thread = 1024;
    const int max_k_parallel_work
            = div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
    const bool is_amx_xf16 = bgmmc.is_amx
            && (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
                    || bm_conf_utils.is_bf32());
    const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();

    // K-splitting is only explored for non-batched xf16 problems.
    const int max_nthr_k = is_amx_xf16 && bgmmc.batch == 1
            ? nstl::min(saturate(1, 7, bgmmc.nthr / 8), max_k_parallel_work)
            : 1;
    int iter = 0;
    for (int nthr_k = 1; nthr_k <= max_nthr_k; nthr_k++) {
        int num_M_blk = div_up(bgmmc.M, bgmmc.M_blk);
        int num_N_blk = div_up(bgmmc.N, bgmmc.N_blk);
        int k_parallel_work = nstl::min(max_k_parallel_work, nthr_k);
        int num_parallel_work
                = bgmmc.batch * num_M_blk * num_N_blk * k_parallel_work;
        // Parallelism-level flags steer the size of the search space below.
        const bool a_lot_of_parallel_work = num_parallel_work > 8 * bgmmc.nthr;
        const bool a_lot_of_parallel_work_lvl2
                = num_parallel_work > 16 * bgmmc.nthr;
        const bool low_parallelism
                = static_cast<float>(num_parallel_work) < 1.5f * bgmmc.nthr;
        const bool maybe_low_blocking
                = is_amx_int8 && bm_conf_utils.maybe_low_brg_blocking();
        // Shrink the minimum M block when more parallel work is needed.
        const int min_M_blk
                = (maybe_low_blocking || low_parallelism) && bgmmc.M_blk > 32
                ? div_up(bgmmc.M_blk, 2)
                : bgmmc.M_blk;
        // N block may only be reduced when the B layout does not fix it.
        const int min_N_blk = low_parallelism && is_amx_xf16
                        && !bm_conf_utils.check_n_blk_fixed()
                        && bgmmc.N_blk > 32
                ? 32
                : bgmmc.N_blk;
        const int desired_M_chunk = nstl::min(
                (bgmmc.use_buffer_b || a_lot_of_parallel_work ? 4 : 1),
                num_M_blk);
        const int desired_N_chunk = nstl::min(a_lot_of_parallel_work_lvl2
                        ? 6
                        : (bgmmc.use_buffer_a || a_lot_of_parallel_work ? 4
                                                                        : 1),
                num_N_blk);

        // Candidate M blocks: halvings down from M_blk ...
        std::unordered_set<int> mblk_candidates;
        for (int m_blk = bgmmc.M_blk; m_blk >= min_M_blk;
                m_blk = m_blk > 1 ? div_up(m_blk, 2) : m_blk - 1) {
            if (IMPLICATION(maybe_low_blocking, m_blk != bgmmc.M_blk))
                mblk_candidates.insert(m_blk);
        }

        if (bgmmc.M > 16) {
            // Add multiple of 16 M block sizes for consideration
            const int mul16_m_blk_max
                    = nstl::min(rnd_dn(static_cast<int>(bgmmc.M), 16), 64);
            const int mul16_m_blk_min = rnd_up(min_M_blk, 16);
            for (int m_blk = mul16_m_blk_max; m_blk >= mul16_m_blk_min;
                    m_blk -= 16) {
                mblk_candidates.insert(m_blk);
            }
        }

        // Score every combination and keep the best one.
        for_(int n_blk = bgmmc.N_blk; n_blk >= min_N_blk; n_blk -= 16)
        for_(int m_blk : mblk_candidates)
        for_(int n_ch_sz = desired_N_chunk; n_ch_sz >= 1; n_ch_sz--)
        for (int m_ch_sz = desired_M_chunk; m_ch_sz >= 1; m_ch_sz--, iter++) {
            current_blocking.set_blocking_parameters(
                    nthr_k, n_blk, n_ch_sz, m_blk, m_ch_sz);
            if (current_blocking.get_blocking_scores()
                    > best_blocking.get_blocking_scores())
                best_blocking = current_blocking;
        }
    }
}
623
// Searches avx512 blocking candidates (K thread split, N chunk size, M
// block size) minimizing the imbalance score; the winner is stored in
// 'best_blocking'. Returns the best score found (1.f == nothing viable).
float compute_blocking_heuristic_avx512(brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils,
        const matmul_avx512_blocking_params_t::matmul_params_t &matmul,
        matmul_avx512_blocking_params_t &best_blocking) {

    const int nthr = bgmmc.nthr;

    const int max_m_blk = nstl::min(256, matmul.M);
    int min_m_blk = nstl::min(32, matmul.M);

    int n_blk = bgmmc.N_blk;
    const int n_chunks = div_up(matmul.N, n_blk);
    // Larger N chunks are only explored when an A copy buffer is used
    // (better reuse of the copied data).
    const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
    const int n_chunks_start = nstl::min(max_n_chunks, div_up(matmul.N, n_blk));

    // Note: do not extend K_blk for 'bwd_w' cases
    const bool use_extended_k_blk = matmul.K > 1024
            && (!bm_conf_utils.check_is_transposed(bgmmc.src_tag));
    int default_k_blk = use_extended_k_blk ? 1024 : 512;
    int k_blk = nstl::min(matmul.K, default_k_blk);
    int start_nthr_k = 1;

    // for cases with low parallel work, reduce 'min_m_blk' to
    // increase potential parallelization balance.
    const size_t max_parallel = matmul.batch * n_chunks;
    const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
    if (low_parallel_work) {

        min_m_blk = nstl::min(matmul.M, 16);

        // 2nd level tuning for low parallel work cases:
        bool bwd_w_low_spatial_work
                = bm_conf_utils.check_is_transposed(bgmmc.src_tag)
                && matmul.M <= 512;
        bool low_spatial_work = matmul.M <= 40;
        if (low_spatial_work || bwd_w_low_spatial_work) {

            // Reduce n_blk size to increase parallel space
            // note: over reduction of n_blk size on 2d shapes when n_chunks == 1
            // showed significant performance degradation
            if (!bm_conf_utils.check_n_blk_fixed()
                    && IMPLICATION(n_chunks == 1, bgmmc.batch_ndims > 0))
                n_blk = nstl::min(matmul.N, 32);

            // force to plain B (wei) in small spatial size for FWD:
            // note: this showed significant performance gain in WnD shapes
            bool is_FWD = !(bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
                    || bm_conf_utils.check_is_transposed(bgmmc.src_tag));
            if (bgmmc.use_buffer_b && is_FWD) {
                bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b(false);
            }
        }

        // Parallelize across K for shapes with big 'K' dimension
        bool bwd_w_par_k_blk = bm_conf_utils.check_is_transposed(bgmmc.src_tag)
                && IMPLICATION(bm_conf_utils.is_bf16(), math::is_pow2(matmul.K))
                && matmul.K >= 2048;
        if (bwd_w_par_k_blk) {
            start_nthr_k = nstl::min(nthr, 4);
            assert(k_blk == nstl::min(matmul.K, 512));
        }
    }

    // Evaluate every candidate; keep the lowest imbalance score.
    float best_imbalance = 1.f; // reduce
    for_(int nthr_k = start_nthr_k; nthr_k >= 1; --nthr_k)
    for_(int n_chunk_size = n_chunks_start; n_chunk_size >= 1; --n_chunk_size)
    for (int m_blk = max_m_blk; m_blk >= min_m_blk; --m_blk) {

        matmul_avx512_blocking_params_t cur_params(matmul, nthr);
        cur_params.update_params(
                1, m_blk, n_chunk_size, n_blk, 1, k_blk, nthr_k);

        float cur_imbalance = cur_params.get_imbalance();
        if (cur_imbalance < best_imbalance) {
            best_imbalance = cur_imbalance;
            best_blocking = cur_params;
        }
    }
    return best_imbalance;
}
704
// Top-level blocking selection: dispatches to the AMX or avx512 heuristic
// and writes the winning parameters into 'bgmmc'. Returns unimplemented
// when no viable blocking is found.
status_t compute_blocking_heuristic(brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils) {

    bgmmc.N_blk = nstl::min(static_cast<dim_t>(bgmmc.wei_n_blk), bgmmc.N);

    bgmmc.M_chunk_size = bgmmc.N_chunk_size = 1;

    if (bgmmc.is_amx) {

        // Configure matrix sizes
        // Prefer an M block in [32, 64] that divides M evenly; otherwise
        // fall back to min(M, 64).
        const dim_t max_M = 64, min_M = 32;
        bgmmc.M_blk = 1;
        for (dim_t m_ = max_M; m_ >= min_M; m_--) {
            if (bgmmc.M % m_ == 0) {
                bgmmc.M_blk = m_;
                break;
            }
        }
        if (bgmmc.M_blk == 1) bgmmc.M_blk = nstl::min(bgmmc.M, max_M);

        // AMX BRGEMM kernel requires (K_brgemm % 64 == 0 || K_brgemm < 64)
        // for K_brgemm reduction value to avoid AMX tiles re-configuration.
        // To satisfy this condition K_tail value is fixed to K % wei_k_blk here.
        const bool fixed_K_tail_size
                = bgmmc.K % bgmmc.wei_k_blk > 0 && bgmmc.K > bgmmc.wei_k_blk;
        bgmmc.K_blk = bgmmc.K < bgmmc.wei_k_blk
                ? rnd_up(bgmmc.K, bgmmc.required_k_granularity)
                : fixed_K_tail_size ? bgmmc.wei_k_blk : bgmmc.K;
        bgmmc.brgemm_batch_size
                = nstl::max(bgmmc.K / bgmmc.K_blk, static_cast<dim_t>(1));

        // Default-constructed candidate has score 0: any valid blocking wins.
        matmul_amx_blocking_params_t best_blocking(bgmmc);

        compute_blocking_heuristic_amx(bgmmc, bm_conf_utils, best_blocking);

        if (best_blocking.get_blocking_scores() == 0.0f)
            return status::unimplemented;

        best_blocking.update_configuration(bgmmc);

    } else {
        // TODO:
        // *) adjust K_BLK using 'rnd_up(bgmmc.K, bgmmc.required_k_granularity)'
        // for non-f32 datatypes.
        // *) optimize param search complexity

        // Approach for selecting ideal 'blocking parameters':
        // M_blk:
        // - main param for having parallel_work optimally distributed.
        // - 'br_block' is a BRGeMM uKernel parameter derived from 'M_Blk',
        // however, there is no measured performance impact from small
        // variations in 'br_block' size.
        //
        // M_Chunks:
        // - no noticeable performance impact i.e. 'M_blk = M_Chunks * M_Blk';
        // with M_Chunks > 1', brgemm has the same performance results. Instead,
        // choose a larger 'M_blk'.
        //
        // N_blk:
        // - ideally 64 (from 'get_default_n_block()').
        // - can be reduced to 32 to improve performance for some shapes, as
        //   well as increasing parallelization search space.
        //
        // N_Chunks:
        // - No different as long as thread/work balance is the same.
        // - Note: for A_Transposed cases using A_buffer (i.e. bwd-w): select
        //   a higher count to increase performance -better for transposed data
        //   reuse.
        //
        // K_blk:
        // - block size variation '512 <= K_blk < 1024' has negligible
        //   performance difference. However, Some cases benefit from higher
        //   block size.
        // - can parallelize if not enough work; notice: requires reduction!
        //
        // Batch_Size:
        // - unused.

        const matmul_avx512_blocking_params_t::matmul_params_t matmul(
                bgmmc.M, bgmmc.N, bgmmc.K, bgmmc.batch);

        matmul_avx512_blocking_params_t best_blocking(matmul, bgmmc.nthr);

        const float best_imbalance = compute_blocking_heuristic_avx512(
                bgmmc, bm_conf_utils, matmul, best_blocking);

        // A best score of 1.f means the search never improved on the
        // sentinel: no acceptable blocking exists.
        if (best_imbalance == 1.f) return status::unimplemented;

        best_blocking.update_configuration(bgmmc);
    }

    return status::success;
}
798
799status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
800 const matmul_desc_t &mmd, memory_desc_t &src_md,
801 memory_desc_t &weights_md, memory_desc_t &dst_md,
802 memory_desc_t &bias_md, primitive_attr_t &attr) {
803 const memory_desc_wrapper src_d(&src_md);
804 const memory_desc_wrapper weights_d(&weights_md);
805 const memory_desc_wrapper dst_d(&dst_md);
806
807 bgmmc = zero<decltype(bgmmc)>();
808 bgmmc.isa = isa;
809 bgmmc.nthr = dnnl_get_max_threads();
810 bgmmc.brg_type = brgemm_addr;
811
812 bgmmc.src_dt = src_d.data_type();
813 bgmmc.dst_dt = dst_d.data_type();
814 bgmmc.wei_dt = weights_d.data_type();
815
816 bgmmc.with_bias = mmd.bias_desc.format_kind != format_kind::undef;
817 bgmmc.bia_dt = bgmmc.with_bias ? mmd.bias_desc.data_type : data_type::undef;
818 bgmmc.s8s8_compensation_required
819 = isa == avx512_core_vnni && bgmmc.src_dt == s8;
820 bgmmc.ndims = dst_d.ndims();
821
822 brgemm_matmul_conf_utils_t bm_conf_utils(bgmmc, isa, attr,
823 src_d.format_kind() == format_kind::any,
824 weights_d.format_kind() == format_kind::any,
825 dst_d.format_kind() == format_kind::any,
826 bias_md.format_kind == format_kind::any);
827
828 CHECK(check_isa_with_datatype(isa, bm_conf_utils));
829
830 bgmmc.is_amx = is_superset(isa, avx512_core_amx);
831 bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
832 bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);
833
834 bgmmc.is_bf32 = bm_conf_utils.is_bf32();
835
836 // Make BRGeMM compute MatMul as if it were in bfloat16, while down-convert
837 // happens during copy-buffer computations
838 if (bgmmc.is_bf32) {
839 bgmmc.src_dt = bf16;
840 bgmmc.wei_dt = bf16;
841 bgmmc.tr_a_dt_sz = types::data_type_size(bf16);
842 bgmmc.tr_b_dt_sz = types::data_type_size(bf16);
843 } else if (bm_conf_utils.is_f16() && bgmmc.isa == avx512_core_fp16) {
844 // Similar to bf32, convert input data before compute
845 bgmmc.src_dt = f32;
846 bgmmc.wei_dt = f32;
847 bgmmc.tr_a_dt_sz = types::data_type_size(f32);
848 bgmmc.tr_b_dt_sz = types::data_type_size(f32);
849 }
850
851 bgmmc.acc_dt = bm_conf_utils.is_int8() ? s32 : f32;
852
853 bgmmc.c_dt_sz = types::data_type_size(bgmmc.dst_dt);
854 bgmmc.acc_dt_sz = types::data_type_size(bgmmc.acc_dt);
855 if (bgmmc.with_bias) bgmmc.bias_dt_sz = types::data_type_size(bgmmc.bia_dt);
856
857 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
858 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
859 bgmmc.with_scales = !src_scales.has_default_values()
860 || !wei_scales.has_default_values();
861 if (bgmmc.with_scales) {
862 bgmmc.is_oscale_per_n = wei_scales.mask_ == 1 << (bgmmc.ndims - 1);
863
864 // only common and per-oc-channel scales are supported
865 const bool oscales_ok = wei_scales.mask_ == 0 || bgmmc.is_oscale_per_n;
866 if (!oscales_ok) return status::unimplemented;
867 }
868
869 const auto &p = attr.post_ops_;
870 bgmmc.with_sum = p.find(primitive_kind::sum) != -1;
871 const int eltwise_ind = p.find(primitive_kind::eltwise);
872 bgmmc.with_eltwise = eltwise_ind != -1;
873 const int binary_ind = p.find(primitive_kind::binary);
874 bgmmc.with_binary = binary_ind != -1;
875
876 if (!post_ops_ok(bgmmc, attr, dst_d)) return status::unimplemented;
877
878 bgmmc.src_zp_type = get_zp_type(attr, DNNL_ARG_SRC);
879 bgmmc.wei_zp_type = get_zp_type(attr, DNNL_ARG_WEIGHTS);
880 bgmmc.dst_zp_type = get_zp_type(attr, DNNL_ARG_DST);
881
882 if (!IMPLICATION(!bm_conf_utils.is_int8(),
883 everyone_is(brgemm_broadcast_t::none, bgmmc.src_zp_type,
884 bgmmc.wei_zp_type, bgmmc.dst_zp_type)))
885 return status::unimplemented;
886
887 matmul_helper_t helper(src_d, weights_d, dst_d);
888
889 bgmmc.batch_ndims = bgmmc.ndims - 2;
890 bgmmc.M = helper.M();
891 bgmmc.N = helper.N();
892 bgmmc.K = helper.K();
893 bgmmc.batch = helper.batch();
894 bgmmc.batch_without_first_dim
895 = bgmmc.batch_ndims > 1 ? helper.batch() / dst_d.dims()[0] : 0;
896
897 bgmmc.bcast_A_desc.set_params(
898 src_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
899 bgmmc.bcast_B_desc.set_params(
900 weights_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
901
902 // Dispatch small shapes to VNNI for better performance
903 const bool is_small_shapes = bgmmc.is_amx && bgmmc.ndims < 3
904 && ((bgmmc.M == 1 && bgmmc.K == 256)
905 || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
906 || bgmmc.K <= 16);
907 if (is_small_shapes) return status::unimplemented;
908
909 // required granularity for k dimension
910 bgmmc.required_k_granularity
911 = bgmmc.is_amx ? data_type_vnni_granularity(bgmmc.wei_dt) : 1;
912 if (bgmmc.required_k_granularity == 0) return status::unimplemented;
913 bgmmc.wei_k_blk = data_type_vnni_simd_elems<avx512_core>(bgmmc.wei_dt);
914
915 CHECK(bm_conf_utils.set_or_check_tags(src_md, dst_md, bias_md));
916 CHECK(bm_conf_utils.set_or_check_B_tag(weights_md));
917
918 bgmmc.req_wei_vnni_downconvert = bm_conf_utils.wei_down_convert_to_vnni();
919
920 CHECK(attr.set_default_formats(&dst_md));
921
922 bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);
923
924 bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
925 bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
926
927 bgmmc.transposed_A = (bm_conf_utils.check_is_transposed(bgmmc.src_tag)
928 || bgmmc.src_tag == adbc);
929 const bool lda_is_big_2pow
930 = (bm_conf_utils.is_bf16()
931 || (bgmmc.is_amx && bm_conf_utils.is_f16()))
932 && !bgmmc.transposed_A && math::is_pow2(bgmmc.K) && bgmmc.K >= 4096
933 && bgmmc.M >= 1024;
934 const bool is_copy_a_required
935 = (bgmmc.is_amx
936 && ((bgmmc.K % bgmmc.required_k_granularity != 0)
937 || bm_conf_utils.is_bf32()))
938 || (bm_conf_utils.is_f16() && isa == avx512_core_fp16)
939 || bgmmc.wei_zp_type != brgemm_broadcast_t::none
940 || bgmmc.transposed_A || lda_is_big_2pow;
941 bgmmc.use_buffer_a = is_copy_a_required;
942
943 // Supported computation with copy only part of A related to K_tail if
944 // is_copy_a_required == true, but the current performance measurements
945 // show worse performance for it in comparison with copy whole A approach
946 // (especially for big K sizes).
947 bgmmc.use_buffer_a_tail_only = false;
948
949 const int dmax = nstl::min(bgmmc.ndims, 3);
950 for (int d = 0; d < dmax; d++) {
951 int dim = bgmmc.ndims - 1 - d;
952 bgmmc.A_strides[d] = bgmmc.a_dt_sz * src_d.blocking_desc().strides[dim];
953 bgmmc.B_strides[d]
954 = bgmmc.b_dt_sz * weights_d.blocking_desc().strides[dim];
955 bgmmc.C_strides[d] = bgmmc.c_dt_sz * dst_d.blocking_desc().strides[dim];
956 }
957
958 // BF32 'Hint' Heuristic:
959 // Under the following conditions, F32 through AVX512_CORE performs better
960 // than using BF32 arithmetic.
961 if (bgmmc.is_bf32 && (bgmmc.M < 8)
962 && ((bgmmc.wei_tag == abcd) || bm_conf_utils.is_any_B_layout()))
963 return status::unimplemented;
964
965 // Heuristic tries to optimize the following parameters:
966 // - M_blk, M_Chunk
967 // - N_blk, N_Chunk
968 // - K_blk, batch_size
969 // - nthr_K
970 CHECK(compute_blocking_heuristic(bgmmc, bm_conf_utils));
971
972 if (bgmmc.wei_n_blk > bgmmc.N_blk
973 && IMPLICATION(
974 bgmmc.N == bgmmc.N_blk, bgmmc.N >= bgmmc.wei_n_blk)) {
975 bgmmc.wei_n_blk = bgmmc.N_blk;
976 CHECK(bm_conf_utils.update_and_check_B_tag(
977 weights_md, bgmmc.wei_n_blk));
978
979 bgmmc.req_wei_vnni_downconvert
980 = bm_conf_utils.wei_down_convert_to_vnni();
981 }
982
983 CHECK(bm_conf_utils.set_B_flags(weights_md));
984
985 bgmmc.M_tail = bgmmc.M % bgmmc.M_blk;
986 bgmmc.N_tail = bgmmc.N % bgmmc.N_blk;
987 bgmmc.K_tail = bgmmc.K > bgmmc.K_blk
988 ? rnd_up(bgmmc.K % bgmmc.K_blk, bgmmc.required_k_granularity)
989 : 0;
990
991 bgmmc.LDB = bm_conf_utils.get_actual_LDB();
992 bgmmc.LDD = bgmmc.dst_tag == acbd ? dst_d.blocking_desc().strides[2]
993 : bgmmc.N;
994 bgmmc.LDC
995 = bgmmc.use_buffer_c && bgmmc.nthr_k <= 1 ? bgmmc.N_blk : bgmmc.LDD;
996
997 init_aux_values(bgmmc, src_d, weights_d, dst_d);
998
999 return status::success;
1000}
1001
void init_aux_values(brgemm_matmul_conf_t &bgmmc,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
        const memory_desc_wrapper &dst_d) {

    // Derived chunk geometry: a "chunk" is chunk_size blocks processed
    // together; *_chunks is the number of chunks covering the dimension.
    bgmmc.M_chunk_elems = bgmmc.M_blk * bgmmc.M_chunk_size;
    bgmmc.N_chunk_elems = bgmmc.N_blk * bgmmc.N_chunk_size;
    bgmmc.K_chunk_elems = bgmmc.K_blk * bgmmc.brgemm_batch_size;
    bgmmc.M_chunks = div_up(bgmmc.M, bgmmc.M_chunk_elems);
    bgmmc.N_chunks = div_up(bgmmc.N, bgmmc.N_chunk_elems);
    bgmmc.K_chunks = div_up(bgmmc.K, bgmmc.K_chunk_elems);
    bgmmc.num_M_blocks = div_up(bgmmc.M, bgmmc.M_blk);
    bgmmc.num_N_blocks = div_up(bgmmc.N, bgmmc.N_blk);
    // Number of K blocks in the last (possibly partial) K chunk; used to
    // size the trailing brgemm batch.
    const int last_chunck_batch_size
            = (nstl::max(bgmmc.K, bgmmc.K_blk)
                      - (bgmmc.K_chunks - 1) * bgmmc.K_chunk_elems)
            / bgmmc.K_blk;
    bgmmc.brgemm_batch_tail_size
            = last_chunck_batch_size % bgmmc.brgemm_batch_size;

    // Accumulation (C) buffer sizing: with K-parallel reduction each thread
    // accumulates the full M panel; otherwise one M_blk x LDC tile per
    // chunk cell (M_chunk_size x N_chunk_size cells per thread).
    bgmmc.buffer_c_chunk_sz = bgmmc.acc_dt_sz * bgmmc.LDC
            * (bgmmc.nthr_k > 1 ? bgmmc.M : bgmmc.M_blk);
    bgmmc.buffer_c_per_thread_sz = bgmmc.buffer_c_chunk_sz
            * (bgmmc.nthr_k > 1 ? 1 : bgmmc.M_chunk_size * bgmmc.N_chunk_size);

    // Copy-buffer sizing for A; in tail-only mode only wei_k_blk columns
    // are copied, and chunks along M are not multiplied by the batch size.
    bgmmc.buffer_a_chunk_sz = bgmmc.tr_a_dt_sz * bgmmc.M_blk
            * (bgmmc.use_buffer_a_tail_only ? bgmmc.wei_k_blk : bgmmc.LDA);
    bgmmc.buffer_a_chunk_shift_along_m = bgmmc.buffer_a_chunk_sz
            * (bgmmc.use_buffer_a_tail_only ? 1 : bgmmc.brgemm_batch_size);
    bgmmc.buffer_a_per_thread_sz
            = bgmmc.buffer_a_chunk_shift_along_m * bgmmc.M_chunk_size;

    // Copy-buffer sizing for B: K is rounded up to the VNNI-blocked granule.
    bgmmc.buffer_b_chunk_sz = bgmmc.tr_b_dt_sz * bgmmc.LDB
            * rnd_up(bgmmc.K_blk, bgmmc.wei_k_blk);
    bgmmc.buffer_b_per_thread_sz
            = bgmmc.buffer_b_chunk_sz * bgmmc.brgemm_batch_size;

    // s8s8 compensation strides: when B is copied, compensation lives in a
    // per-thread scratch area; otherwise it is stored per N-block of B.
    bgmmc.s8s8_comp_ithr_str
            = bgmmc.use_buffer_b ? bgmmc.wei_n_blk * bgmmc.N_chunk_size : 0;
    bgmmc.s8s8_comp_b_str = bgmmc.use_buffer_b
            ? 0
            : div_up(bgmmc.N, bgmmc.wei_n_blk) * bgmmc.wei_n_blk;
    bgmmc.s8s8_comp_n_str = bgmmc.wei_n_blk;

    // Batch-dim pointer shifts and copy strides for A with non-plain
    // 4d layouts (acbd/adbc). The min() with stride[0] plus the f32
    // factor-of-2 handling mirrors the copy kernel's addressing.
    // NOTE(review): factor == 2 for f32 presumably matches the copy
    // routine's element-pair processing — confirm against the copy kernels.
    bgmmc.A_ptr_shift_b = 0;
    bgmmc.copy_A_src_stride = 0;
    if (bgmmc.src_tag == acbd || bgmmc.src_tag == adbc) {
        const dim_t factor = bgmmc.src_dt == f32 ? 2 : 1;
        const dim_t src_stride = bgmmc.src_tag == acbd ? bgmmc.A_strides[1]
                                                       : bgmmc.A_strides[0];
        bgmmc.copy_A_src_stride = nstl::min(src_d.blocking_desc().strides[0],
                                          src_stride / factor)
                * factor;
        // bcast_mask == 2 means the batch dim is broadcast for A.
        const dim_t bcast_shift_b = bgmmc.src_tag == acbd ? bgmmc.K : bgmmc.M;
        bgmmc.A_ptr_shift_b
                = (bgmmc.bcast_A_desc.bcast_mask == 2
                                  ? bcast_shift_b
                                  : src_d.blocking_desc().strides[0])
                * bgmmc.a_dt_sz;
    }

    // Same scheme for B with acbd/adbc weights layouts.
    bgmmc.B_ptr_shift_b = 0;
    bgmmc.copy_B_wei_stride = 0;
    if (one_of(bgmmc.wei_tag, acbd, adbc)) {
        const dim_t factor = bgmmc.wei_dt == f32 ? 2 : 1;
        const dim_t wei_stride = bgmmc.wei_tag == acbd ? bgmmc.B_strides[1]
                                                       : bgmmc.B_strides[0];
        bgmmc.copy_B_wei_stride = nstl::min(wei_d.blocking_desc().strides[0],
                                          wei_stride / factor)
                * factor;
        const dim_t bcast_shift_b = bgmmc.wei_tag == acbd ? bgmmc.N : bgmmc.K;
        bgmmc.B_ptr_shift_b
                = (bgmmc.bcast_B_desc.bcast_mask == 2
                                  ? bcast_shift_b
                                  : wei_d.blocking_desc().strides[0])
                * bgmmc.b_dt_sz;
    }

    // C batch shift only differs from the dense case for the acbd layout.
    bgmmc.C_ptr_shift_b = bgmmc.dst_tag == acbd
            ? dst_d.blocking_desc().strides[0] * bgmmc.c_dt_sz
            : 0;

    // Zero-point presence flags and whether any post-accumulation work
    // (post-ops, bias, scales, conversion, compensation) must run.
    bgmmc.has_zero_point_a = bgmmc.src_zp_type != brgemm_broadcast_t::none;
    bgmmc.has_zero_point_b = bgmmc.wei_zp_type != brgemm_broadcast_t::none;
    bgmmc.has_zero_point_c = bgmmc.dst_zp_type != brgemm_broadcast_t::none;
    bgmmc.post_ops_applicable = one_of(true, bgmmc.with_sum, bgmmc.with_bias,
            bgmmc.with_scales, bgmmc.with_eltwise, bgmmc.with_binary,
            bgmmc.acc_dt != bgmmc.dst_dt, bgmmc.s8s8_compensation_required,
            bgmmc.has_zero_point_a, bgmmc.has_zero_point_b,
            bgmmc.has_zero_point_c);

    // Per-thread zero-point compensation buffer layout for A.
    bgmmc.zp_a_comp_shift_n = bgmmc.wei_n_blk;
    bgmmc.zp_a_comp_elems_per_thr
            = bgmmc.N_chunk_size * bgmmc.zp_a_comp_shift_n;

    // Per-thread zero-point compensation buffer layout for B; the working
    // area is padded by s32 cacheline multiples of M_blk per M block.
    const int s32_elems_in_cacheline = 16;
    bgmmc.zp_b_comp_result_shift_m = bgmmc.M_blk;
    bgmmc.zp_b_comp_buffer_start
            = bgmmc.M_chunk_size * bgmmc.zp_b_comp_result_shift_m;
    bgmmc.zp_b_comp_buffer_shift_m = s32_elems_in_cacheline * bgmmc.M_blk;
    bgmmc.zp_b_comp_elems_per_thr = bgmmc.M_chunk_size
            * (bgmmc.zp_b_comp_result_shift_m + bgmmc.zp_b_comp_buffer_shift_m);

    // 16 batch elements per brgemm batch entry per thread.
    bgmmc.brgemm_batch_element_per_thr_sz = 16 * bgmmc.brgemm_batch_size;
}
1106
1107void init_scratchpad(memory_tracking::registrar_t &scratchpad,
1108 const brgemm_matmul_conf_t &bgmmc) {
1109 const size_t default_data_align = sizeof(char);
1110 if (bgmmc.brg_type == brgemm_addr)
1111 scratchpad.book(key_brgemm_primitive_batch,
1112 static_cast<size_t>(bgmmc.nthr)
1113 * bgmmc.brgemm_batch_element_per_thr_sz,
1114 sizeof(brgemm_batch_element_t), 64);
1115
1116 if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
1117 scratchpad.book(key_brgemm_primitive_buffer_a,
1118 bgmmc.nthr * bgmmc.buffer_a_per_thread_sz, default_data_align);
1119
1120 if (bgmmc.use_buffer_b) {
1121 scratchpad.book(key_brgemm_primitive_buffer_b,
1122 bgmmc.nthr * bgmmc.buffer_b_per_thread_sz, default_data_align);
1123
1124 if (bgmmc.s8s8_compensation_required && (!bgmmc.blocked_B))
1125 scratchpad.book(key_brgemm_primitive_buffer_comp,
1126 bgmmc.nthr * bgmmc.s8s8_comp_ithr_str,
1127 types::data_type_size(f32));
1128 }
1129
1130 if (bgmmc.use_buffer_c)
1131 scratchpad.book(key_brgemm_primitive_buffer,
1132 bgmmc.nthr * bgmmc.buffer_c_per_thread_sz, default_data_align);
1133
1134 if (bgmmc.has_zero_point_a) {
1135 const auto num_elems = bgmmc.nthr * bgmmc.zp_a_comp_elems_per_thr;
1136 scratchpad.book(key_brgemm_primitive_zp_comp_a, num_elems,
1137 types::data_type_size(s32));
1138 }
1139
1140 if (bgmmc.has_zero_point_b)
1141 scratchpad.book(key_brgemm_primitive_zp_comp_b,
1142 bgmmc.nthr * bgmmc.zp_b_comp_elems_per_thr,
1143 types::data_type_size(s32));
1144
1145 if (is_superset(bgmmc.isa, avx512_core_amx))
1146 scratchpad.book(key_conv_amx_tile_buffer,
1147 static_cast<size_t>(bgmmc.nthr) * bgmmc.wsp_tile_per_thr_bytes,
1148 default_data_align);
1149}
1150
// Refreshes values derived from the current K blocking choice: the number
// of K elements per chunk, the LDA actually used for A (the copy buffer may
// pad it), and whether an intermediate C buffer is needed. Must be called
// after k_blk_/k_chunk_size_ change, since is_buffer_c_required() reads
// k_chunk_elems_.
void matmul_amx_blocking_params_t::update_k_blocking_dependent_params() {
    k_chunk_elems_ = k_blk_ * k_chunk_size_;
    current_lda_ = get_actual_lda();
    need_buf_c_ = is_buffer_c_required();
}
1156
// Applies one candidate blocking configuration (threads over K, N/M block
// and chunk sizes), derives the K blocking from it, and computes the
// efficiency score used to compare candidates.
void matmul_amx_blocking_params_t::set_blocking_parameters(
        int nthr_k, int n_blk, int n_chunk_size, int m_blk, int m_chunk_size) {
    // Split threads between the K reduction and the MxN/batch space.
    nthr_k_ = nstl::max(1, nthr_k);
    nthr_mnb_ = nthr / nthr_k_;
    nthr_ = nthr_mnb_ * nthr_k_;
    n_blk_ = n_blk;
    n_chunk_size_ = n_chunk_size;
    m_blk_ = m_blk;
    m_chunk_size_ = m_chunk_size;
    // Degenerate candidate (any zero parameter): zero out K blocking and
    // give it the worst possible score.
    if (one_of(0, n_blk_, n_chunk_size_, m_blk_, m_chunk_size_)) {
        k_blk_ = k_chunk_size_ = k_chunk_elems_ = 0;
        efficiency_score_ = 0.0f;
        return;
    }

    n_chunk_elems_ = n_blk_ * n_chunk_size_;
    m_chunk_elems_ = m_blk_ * m_chunk_size_;

    if (K < wei_k_blk) {
        // Small K: a single K block (rounded up to the AMX granularity).
        k_blk_ = is_amx ? rnd_up(K, required_k_granularity) : K;
        k_chunk_size_ = 1;
    } else {
        // Start from the per-thread K share, capped by the VNNI block.
        dim_t k_per_thr = div_up(K, nthr_k_);
        k_blk_ = nstl::min(
                is_amx ? rnd_up(k_per_thr, required_k_granularity) : k_per_thr,
                static_cast<dim_t>(wei_k_blk));
        k_chunk_size_ = nstl::min(nstl::max(static_cast<dim_t>(1), K / k_blk_),
                div_up(k_per_thr, k_blk_));

        // Shrink the K chunk if its working set overflows the L2 budget;
        // the +0.6f rounds the scaled-down chunk count mostly upward.
        update_k_blocking_dependent_params();
        auto chunk_sz = calculate_chunk_memory_size();
        float k_div = (float)chunk_sz / L2_threshold();
        if (k_div > 1.0f)
            k_chunk_size_ = static_cast<int>(
                    static_cast<float>(k_chunk_size_) / k_div + 0.6f);

        // Merge chunks into a single bigger K block when the split is
        // exact, to reduce brgemm batch overhead; the second case keeps
        // one extra chunk for the tail when a single thread owns K.
        const dim_t current_k_tail = K % k_blk_;
        if (current_k_tail == 0 && K % (k_blk_ * k_chunk_size_) == 0) {
            k_blk_ *= k_chunk_size_;
            k_chunk_size_ = 1;
        } else if (nthr_k_ == 1
                && K == k_blk_ * k_chunk_size_ + current_k_tail) {
            k_blk_ *= k_chunk_size_;
            k_chunk_size_ = 2;
        }
    }

    // Recompute derived values for the final K blocking and score it.
    update_k_blocking_dependent_params();

    blocking_chunk_mem_size_ = calculate_chunk_memory_size();

    efficiency_score_ = calculate_blocking_scores();
}
1210
1211// returns score for current blocking parameters' values in range [0, 1]
1212// for parallel work over threads distribution score. Maximum scores - when
1213// all threads have the same work amount w/o tails
float matmul_amx_blocking_params_t::get_thread_balance_scores() {
    // Ratio of useful chunk work to the work rounded up to a multiple of
    // the MxN/batch thread count: 1.0 means perfectly even distribution
    // with no partial chunks.
    dim_t num_M_chunks = div_up(M, m_chunk_elems_);
    dim_t num_N_chunks = div_up(N, n_chunk_elems_);
    float mnb_parallel_score = batch * ((float)M / m_chunk_elems_)
            * ((float)N / n_chunk_elems_)
            / rnd_up(batch * num_M_chunks * num_N_chunks, nthr_mnb_)
            * nthr_mnb_;
    // Same idea for the K reduction, discounted by a fixed penalty since
    // parallel reduction adds buffer/accumulation overhead.
    float k_parallel_score = 1.0f;
    if (nthr_k_ > 1) {
        dim_t num_K_chunks = div_up(K, k_chunk_elems_);
        const float parallel_reduction_penalty = 0.8f;
        k_parallel_score = parallel_reduction_penalty
                * ((float)K / k_chunk_elems_) / rnd_up(num_K_chunks, nthr_k_)
                * nthr_k_;
    }

    return mnb_parallel_score * k_parallel_score / nthr;
}
1232
1233// returns score for current blocking parameters' values in range [0, 1]
1234// for copied data reusage
1235float matmul_amx_blocking_params_t::get_copied_data_reusage_scores() {
1236 const int desired_M_chunk = use_buffer_b
1237 ? nstl::min(4, rnd_up(static_cast<int>(M), m_blk_))
1238 : 1;
1239 const int desired_N_chunk = use_buffer_a
1240 ? nstl::min(4, rnd_up(static_cast<int>(N), n_blk_))
1241 : 1;
1242
1243 return 0.5f
1244 * (nstl::min((float)m_chunk_size_ / desired_M_chunk, 1.0f)
1245 + nstl::min((float)n_chunk_size_ / desired_N_chunk, 1.0f));
1246}
1247
1248// returns score for current blocking parameters' values in range [0, 1]
1249// for L2 utilization
1250float matmul_amx_blocking_params_t::get_L2_utilization_scores() const {
1251 const float relative_difference_with_L2
1252 = fabsf((float)L2_threshold() - blocking_chunk_mem_size_)
1253 / nstl::max(L2_threshold(), blocking_chunk_mem_size_);
1254 return 1.0f - relative_difference_with_L2;
1255}
1256
1257// returns score for current blocking parameters' values in range [0, 1]
1258// consists of 3 parts with its own weights:
1259// 1) parallel work over threads distribution score
1260// 2) L2 utilization score
1261// 3) copied data re-usage score
1262float matmul_amx_blocking_params_t::calculate_blocking_scores() {
1263 if (one_of(0, n_blk_, n_chunk_size_, m_blk_, m_chunk_size_, k_blk_,
1264 k_chunk_size_))
1265 return 0.0f;
1266
1267 const float nthr_coeff = nstl::min(nthr, 100);
1268 const float reusage_factor = 1.0f;
1269 const float balance_factor = (nthr_coeff - 1.0f) / nthr_coeff;
1270 const float cache_utilization_factor = 1.0f / nthr_coeff;
1271
1272 float scores = cache_utilization_factor * get_L2_utilization_scores()
1273 + reusage_factor * get_copied_data_reusage_scores();
1274 if (balance_factor > 0.0f)
1275 scores += balance_factor * get_thread_balance_scores();
1276 return scores
1277 / (reusage_factor + balance_factor + cache_utilization_factor);
1278}
1279
1280void matmul_amx_blocking_params_t::update_configuration(
1281 brgemm_matmul_conf_t &bgmmc) const {
1282 bgmmc.nthr_k = nthr_k_;
1283 bgmmc.M_blk = m_blk_;
1284 bgmmc.M_chunk_size = m_chunk_size_;
1285 bgmmc.N_blk = n_blk_;
1286 bgmmc.N_chunk_size = n_chunk_size_;
1287
1288 bgmmc.K_blk = k_blk_;
1289 bgmmc.brgemm_batch_size = k_chunk_size_;
1290
1291 bgmmc.use_buffer_c = need_buf_c_;
1292 bgmmc.LDA = current_lda_;
1293}
1294
1295dim_t matmul_amx_blocking_params_t::get_actual_lda() {
1296 if (!use_buffer_a) return src_tag == acbd ? A_strides[1] / a_dt_sz : K;
1297
1298 constexpr int bytes_in_cacheline = 64;
1299 const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
1300 dim_t lda = rnd_up(k_blk_, elems_in_cacheline);
1301 const bool is_big_2_pow = lda >= 512 && math::is_pow2(lda);
1302 if (is_big_2_pow) lda += elems_in_cacheline;
1303 return lda;
1304}
1305
1306bool matmul_amx_blocking_params_t::is_buffer_c_required() {
1307 if (nthr_k_ > 1 && K > k_chunk_elems_) return true;
1308
1309 return ((acc_dt != dst_dt || with_sum)
1310 && (K > k_chunk_elems_ || K % k_blk_ > 0));
1311}
1312
1313size_t matmul_amx_blocking_params_t::calculate_chunk_memory_size() {
1314 size_t A_chunk_sz = a_dt_sz * k_chunk_elems_ * m_chunk_elems_;
1315 size_t A_buf_sz = use_buffer_a
1316 ? tr_a_dt_sz * current_lda_ * k_chunk_size_ * m_chunk_elems_
1317 : 0;
1318 size_t B_chunk_sz = b_dt_sz * k_chunk_elems_ * n_chunk_elems_;
1319 size_t B_buf_sz = use_buffer_b ? tr_b_dt_sz * n_blk_ * k_chunk_elems_ : 0;
1320 size_t C_chunk_sz = c_dt_sz * m_chunk_elems_ * n_chunk_elems_;
1321 size_t C_buf_sz
1322 = need_buf_c_ ? acc_dt_sz * m_chunk_elems_ * n_chunk_elems_ : 0;
1323 return A_chunk_sz + A_buf_sz + B_chunk_sz + B_buf_sz + C_chunk_sz
1324 + C_buf_sz;
1325}
1326
1327} // namespace matmul
1328} // namespace x64
1329} // namespace cpu
1330} // namespace impl
1331} // namespace dnnl
1332