1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <memory>
17
18#include "common/c_types_map.hpp"
19#include "common/nstl.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22
23#include "cpu/platform.hpp"
24#include "cpu/x64/brgemm/brgemm.hpp"
25#include "cpu/x64/brgemm/brgemm_types.hpp"
26#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
27#include "cpu/x64/jit_generator.hpp"
28
29#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
30#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37using namespace dnnl::impl::utils;
38using namespace Xbyak;
39
40struct jit_brgemm_amx_uker_base_t : public jit_generator {
41 jit_brgemm_amx_uker_base_t(const brgemm_t &abrg)
42 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core)
43 , brg(abrg)
44 , postops_injector_(nullptr) {
45
46 if (brg.with_eltwise || brg.with_binary || brg.with_sum) {
47
48 static constexpr bool preserve_gpr = true;
            // zmm1 is not used for storing vectors here,
            // so there is no need to preserve vmm registers
51 static constexpr bool preserve_vmm = false;
52 static constexpr bool use_exact_tail_scalar_bcast = false;
53 const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md);
54
55 static const bcast_set_t enabled_bcast_strategy
56 = {broadcasting_strategy_t::scalar,
57 broadcasting_strategy_t::per_oc,
58 broadcasting_strategy_t::per_oc_spatial,
59 broadcasting_strategy_t::per_mb_spatial,
60 broadcasting_strategy_t::per_mb_w,
61 broadcasting_strategy_t::per_w,
62 broadcasting_strategy_t::no_broadcast};
63 const binary_injector::rhs_arg_static_params_t rhs_sp {
64 static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
65 this->r15, this->r13, preserve_gpr, preserve_vmm,
66 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
67 dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
68 ld_tail_mask, use_exact_tail_scalar_bcast};
69 const binary_injector::static_params_t bsp {
70 this->param1, enabled_bcast_strategy, rhs_sp};
71
72 eltwise_injector::static_params_t esp;
73 esp.preserve_vmm = preserve_vmm;
74 esp.preserve_p_table = false;
75
76 postops_injector_ = utils::make_unique<po_injector_t>(
77 this, brg.attr->post_ops_, bsp, esp);
78
79 using namespace dnnl::impl::cpu::binary_injector_utils;
80 std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
81 with_binary_channel_bcast_, with_binary_per_mb_w_bcast_,
82 with_binary_per_w_bcast_, with_binary_no_bcast_)
83 = bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
84 dst_md_wrapper, broadcasting_strategy_t::per_oc,
85 broadcasting_strategy_t::per_oc_spatial,
86 broadcasting_strategy_t::per_mb_spatial,
87 broadcasting_strategy_t::per_mb_w,
88 broadcasting_strategy_t::per_w,
89 broadcasting_strategy_t::no_broadcast);
90 handle_binary_po_offset_ = with_binary_per_oc_bcast_
91 || with_binary_per_oc_sp_bcast_
92 || with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
93 || with_binary_per_w_bcast_ || with_binary_no_bcast_;
94 }
95 use_ils_ = brg.brgattr.use_interleave_stores;
96 }
97
98 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_amx_uker_base_t)
99
100 brgemm_t brg;
101
102private:
103 static constexpr cpu_isa_t po_isa_ = avx512_core_fp16;
104 using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_>;
105 std::unique_ptr<po_injector_t> postops_injector_;
106
107 using reg64_t = const Xbyak::Reg64;
108 enum {
109 simd_w = 16,
110 zmm_width_in_bytes = cpu_isa_traits<avx512_core>::vlen,
111 tile_size = 1024
112 };
113
114 // Register decomposition
115 const reg64_t param1 = abi_param1;
116
117 const reg64_t reg_addr_batch = r13;
118 const reg64_t reg_aux1_batch = rbp;
119 const reg64_t reg_aux_A = r11;
120 const reg64_t reg_aux_B = r10;
121 const reg64_t reg_stride_lda = r14;
122 const reg64_t reg_stride_ldb = abi_not_param1;
123 const reg64_t reg_C = r15;
124 const reg64_t reg_D = r12;
125
126 const reg64_t reg_buf = r8;
127 const reg64_t reg_BS = rbx;
128 const reg64_t reg_BS_loop = r9;
129 const reg64_t reg_bias = rbx;
130 const reg64_t reg_scales = rbx;
131
132 const reg64_t reg_stride_ld_block = rdx;
133 const reg64_t reg_do_post_ops = rbx;
134 const reg64_t reg_tmp_gpr = rbx;
135 const reg64_t reg_ptr_sum_scale = rbx;
136
137 const reg64_t reg_zp_comp_a = rbx;
138 const reg64_t reg_aux_zp_comp_a = rbx;
139 const reg64_t reg_zp_comp_b = rbx;
140 const reg64_t reg_zp_c_values = rbx;
141 const reg64_t reg_ptr_sum_zp = r9;
142 const reg64_t reg_bf32_stride = rsi;
143
144 constexpr static int abi_param1_offs_ = 0;
145 constexpr static int reg_zp_comp_a_offs_ = 8;
146 constexpr static int reg_zp_comp_b_offs_ = 16;
147 constexpr static int reg_zp_c_values_offs_ = 24;
148 constexpr static int stack_space_needed_ = 32;
149
150 bool are_post_ops_applicable_ = false;
151 bool need_to_apply_alpha_beta_ = false;
152 bool may_load_accumulators_ = false;
153
154 bool handle_binary_po_offset_ = false;
155 bool with_binary_per_oc_bcast_ = false;
156 bool with_binary_per_oc_sp_bcast_ = false;
157 bool with_binary_channel_bcast_ = false;
158 bool with_binary_per_mb_w_bcast_ = false;
159 bool with_binary_per_w_bcast_ = false;
160 bool with_binary_no_bcast_ = false;
161 bool prepare_post_ops_registers_once_ = false;
162
163 char *bd_mask_buffer_ptr_ = nullptr;
164 std::vector<size_t> adj_bd_mask_buffer_;
165 size_t *adj_bd_mask_buffer_ptr_ = nullptr;
166 std::vector<size_t> skipped_bd_mask_buffer_;
167 size_t *skipped_bd_mask_buffer_ptr_ = nullptr;
168 palette_config_t palette_;
    // used to store offsets within the wsp buffer where the data is
    // transformed (down-converted), so it can be reused when needed.
171 std::unordered_map<std::string, size_t> transform_buf_map_A_;
172 std::unordered_map<std::string, size_t> transform_buf_map_B_;
173
174 size_t LDA_size_ = 0, LDA2_size_ = 0;
175 size_t LDB_size_ = 0, LDB2_size_ = 0;
176 size_t LDC_size_ = 0, LDC2_size_M_ = 0, LDC2_size_N_ = 0;
177 size_t LDD_size_ = 0;
178 size_t ld_block_B_size_ = 0;
179 size_t ld_block_C_size_ = 0;
180 size_t ld_block_D_size_ = 0;
181 size_t ld_block_bias_size_ = 0;
182 size_t ld_block_scales_size_ = 0;
183 size_t ld_block_zp_size_ = 0;
184
185 size_t ldb_tail_B_size_ = 0;
186 size_t ldb_tail_C_size_ = 0;
187 size_t ldb_tail_D_size_ = 0;
188 size_t ldb_tail_zp_size_ = 0;
189
190 enum matrix_kind_t { matrix_A, matrix_B, matrix_C, matrix_D };
191
    // Loops in the brgemm kernel are (the two outermost loops depend on the
    // loop order):
    //     by bd block2
    //         by ld block2
    //             by batch_size
    //                 by rd block
    //                     gemm_microkernel
    // The structures below (dim_iteration_t, bd_iteration_t, bs_iteration_t
    // and iteration_map_t) describe the structure of these loops and are used
    // for JIT code generation.
201 struct dim_iteration_t {
202 size_t idx = 0;
203 size_t pos = 0;
204 int block = 0;
205 int block2 = 0;
206 bool is_tail = false;
207 dim_iteration_t() = default;
208 dim_iteration_t(
209 size_t pos_, int block_, int block2_, bool is_tail_ = false)
210 : pos(pos_), block(block_), block2(block2_), is_tail(is_tail_) {}
211 };
212
213 struct bd_iteration_t : public dim_iteration_t {
        std::vector<size_t> bdb_pos;
215 bd_iteration_t() = default;
216 bd_iteration_t(
217 size_t pos_, int block_, int block2_, bool is_tail_ = false)
218 : dim_iteration_t(pos_, block_, block2_, is_tail_) {}
219 };
220
221 struct bs_iteration_t {
222 size_t idx = 0;
223 size_t pos = 0;
224 bool is_first = false;
225 bool is_last = false;
226 bs_iteration_t() = default;
227 bs_iteration_t(
228 size_t pos_, bool is_first_ = true, bool is_last_ = false)
229 : pos(pos_), is_first(is_first_), is_last(is_last_) {}
230 };
231
232 struct iteration_map_t {
233 std::vector<dim_iteration_t> ldis;
234 std::vector<bd_iteration_t> bdis;
235 std::vector<bs_iteration_t> bsis;
236 std::vector<dim_iteration_t> rdis;
237 bool is_last_ldi(const dim_iteration_t &ldi) const {
238 return (ldi.idx == ldis.size() - 1);
239 }
240 bool is_last_bdi(const dim_iteration_t &bdi) const {
241 return (bdi.idx == bdis.size() - 1);
242 }
243 bool is_last_rdi(const dim_iteration_t &rdi) const {
244 return (rdi.idx == rdis.size() - 1);
245 }
246 };
247
248 struct brgemm_iteration_t {
249 bd_iteration_t bdi;
250 dim_iteration_t ldi;
251 bs_iteration_t bsi;
252 dim_iteration_t rdi;
253 brgemm_iteration_t *prev_iter = nullptr;
254 bool apply_postops = false;
255 brgemm_iteration_t() = default;
256 };
257
258 struct prf_t {
259 brgemm_kernel_prefetching_t pft;
260 int dist = -1;
261 int vec = 0;
262 };
263
264 // iteration map
265 iteration_map_t imap_;
266
267 // interleave stores
268 bool use_ils_ = false;
269 bool was_prev_bi = false;
270 // saved parameters for storing
271 brgemm_iteration_t prev_bi_;
272 // current storing coordinates
273 int ils_vec_ = 0, ils_bdb_ = 0, ils_ldb_ = 0, ils_bd_start_ = 0;
274 int ils_bd_step_ = 3; // heuristic value
275 prf_t prf1A, prf2A, prf1B, prf2B, prf1C, prf2C;
276
277 bool dt_requires_saturation_ = false;
278
279 Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
280 Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3);
281 Xbyak::Opmask bf32_col_mask = Xbyak::Opmask(4);
282
283 // Zmm map below
284 const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm0; }
285 const Xbyak::Zmm &zmm_tmp_2() const noexcept { return this->zmm1; }
286 const Xbyak::Zmm &zmm_tmp_3() const noexcept { return this->zmm2; }
287
288 const Xbyak::Zmm zmm_bf32_pemute = zmm6;
289 const Xbyak::Zmm zmm_zp_comp_a = zmm6;
290 const Xbyak::Zmm zmm_zp_c = zmm7;
291 const Xbyak::Zmm zmm_lbound = zmm8;
292 const Xbyak::Zmm zmm_ubound = zmm9;
293
    // zmm_bias, zmm_scales and accm shouldn't overlap
295 Xbyak::Zmm accm(int bd) const {
296 assert(bd < 16);
297 return Xbyak::Zmm(31 - (bd % ils_bd_step_));
298 }
299
300 Xbyak::Zmm zmm_bias(int ldb) const {
301 assert(ldb < 5);
302 // zmm10 - zmm14
303 return Xbyak::Zmm(10 + ldb);
304 }
305
306 Xbyak::Zmm zmm_scales(int ldb) const {
307 assert(ldb < 5);
308 assert(ils_bd_step_ < 10);
309 // zmm15 - zmm19
310 return Xbyak::Zmm(15 + ldb);
311 }
312
313 Xbyak::Zmm zmm_mask(const Xbyak::Zmm zmm_in, bool mask_flag, bool store,
314 Xbyak::Opmask ktail_mask) const;
315 Xbyak::Ymm ymm_mask(const Xbyak::Ymm ymm_in, bool mask_flag, bool store,
316 Xbyak::Opmask ktail_mask) const;
317
318 void cvt2ps(data_type_t type_in, const Xbyak::Zmm zmm_in,
319 const Xbyak::Operand &op, bool mask_flag, bool store,
320 Xbyak::Opmask ktail_mask);
321
322 void read_params();
323 void load_accumulators(brgemm_iteration_t &bi);
324
325 void maybe_saturation(Xbyak::Zmm &zmm);
326 void apply_alpha_beta_to_vector(
327 const int idx, const Address &addr, bool is_ld_tail);
328 void apply_post_ops_to_range(brgemm_iteration_t &bi, int bd_start,
329 int bd_finish, int bd_inp_bdb, int ldb);
330 void store_vector_with_post_ops(const int idx, const Address &addr,
331 const int bd, const int ldb, bool is_ld_tail);
332 void prepare_post_ops_registers_ldb(brgemm_iteration_t &bi, int ldb);
333 void prepare_post_ops_registers(brgemm_iteration_t &bi);
334
335 bool bi_shift_output(
336 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
337 bool bi_shift_A(
338 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
339 bool bi_shift_B(
340 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi);
341
342 void uni_prefetch(const Address &addr, brgemm_kernel_prefetching_t pft);
343 void prefetch_CD_range(brgemm_iteration_t &bi,
344 brgemm_kernel_prefetching_t pft, int bd_start, int bd_finish,
345 int bd_inp_bdb, int ldb);
346 int calc_ops_CD(brgemm_iteration_t &bi) const noexcept;
347 void prefetch_CD(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
348 prf_t &prf, bool prefetch_all);
349
350 void prefetch_A(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
351 prf_t &prf, bool prefetch_all);
352 void prefetch_B(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi,
353 prf_t &prf, bool prefetch_all);
354 void prefetching(brgemm_iteration_t &bi, bool prefetch_all);
355
356 void process_output_range(brgemm_iteration_t &bi, int bd_start,
357 int bd_finish, int bd_inp_bdb, int bdb, int ldb);
358 void store_vector_without_post_ops(
359 const int idx, const Address &addr, bool is_ld_tail);
360 void store_vector(
361 brgemm_iteration_t &bi, const int idx, const int bd, const int ldb);
362
363 void interleave_store(brgemm_iteration_t &bi, bool store_all);
364
365 void store_accumulators(brgemm_iteration_t &bi);
366
367 void set_A_B_matrices(int bs);
368 void set_A_B_matrices();
369
370 void bf32_downconvert(int num_rows, int tile_num_col_bytes,
371 reg64_t reg_data, int offset, reg64_t reg_data_stride,
372 reg64_t reg_buf, bool is_rd_tail);
373
374 void bf32_downconvert_to_vnni(int num_rows, int tile_num_col_bytes,
375 reg64_t reg_data, int offset, reg64_t reg_data_stride,
376 reg64_t reg_buf, bool is_rd_tail);
377
378 void maybe_pre_process_data(brgemm_iteration_t &bi, const Tmm &t1,
379 reg64_t reg_base, size_t offset, reg64_t reg_stride,
380 matrix_kind_t mk);
381
382 void maybe_tileloadd_nt(
383 brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset);
384
385 void tdpbxxd(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx,
386 bool do_pre_tilestore, bool do_post_tilestore);
387
388 void gemm_microkernel_amx(brgemm_iteration_t &bi);
389
390 void rdb_loop(brgemm_iteration_t &bi);
391
392 void bs_loop_body(brgemm_iteration_t &bi);
393 void bs_loop(brgemm_iteration_t &bi);
394
395 void ldb_loop_body(brgemm_iteration_t &bi);
396 void ldb_loop(brgemm_iteration_t &bi);
397
398 void bdb_loop_body(brgemm_iteration_t &bi);
399 void bdb_loop(brgemm_iteration_t &bi);
400
401 void init(brgemm_iteration_t &bi);
402 void generate() override;
403
404 void prepare_bd_mask() noexcept;
405 int skipped_bd_mask(int inp_bd) noexcept;
406
407 bool get_store_by_vectors(bool apply_post_ops) const {
408 const bool need_to_apply_post_ops
409 = are_post_ops_applicable_ && apply_post_ops;
410 const auto store_by_vectors = need_to_apply_alpha_beta_
411 || need_to_apply_post_ops || brg.brgattr.bd_mask_level;
412 return store_by_vectors;
413 }
414
415 size_t A_offset(int bdb) const noexcept;
416
417 size_t B_offset(int ldb) const noexcept;
418 size_t C_offset(int bd, int ldb) const noexcept;
419 size_t C_block_offset(int bd, int ldb) const noexcept;
420 size_t D_offset(int bd, int ldb) const noexcept;
421
422 size_t lda() const noexcept;
423 size_t ldb() const noexcept;
424 size_t rdb_A_offset(const brgemm_iteration_t &bi) const noexcept;
425 size_t rdb_B_offset(const brgemm_iteration_t &bi) const noexcept;
426 size_t ldb_B_offset(const brgemm_iteration_t &bi) const noexcept;
427
428 size_t bias_offset(int ldb) const noexcept;
429
430 size_t scales_offset(int ldb) const noexcept;
431 size_t zp_comp_a_offset(int ldb) const noexcept;
432 size_t zp_comp_b_offset(int bd) const noexcept;
433 size_t zp_c_values_offset(brgemm_iteration_t &bi, int ldb) const noexcept;
434 int get_out_bd(int bd_inp_bdb, int bd) const;
435
436 void maybe_tilestore(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx,
437 bool do_pre_tilestore, bool do_post_tilestore);
438 int get_C_tensor(brgemm_iteration_t &bi, int m, int n) const noexcept;
439 void top_loop(brgemm_iteration_t &bi);
440
441 void fill_imap();
442};
443
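// bi_shift_output() advances the output (bd, ld) iteration by `shift`
// positions, following the innermost-loop hint: for brgemm_ld_loop_innermost
// the linear index is lidx = bdi.idx * ldis.size() + ldi.idx, for
// brgemm_bd_loop_innermost it is lidx = ldi.idx * bdis.size() + bdi.idx.
// It returns false when the shifted position falls outside the iteration map
// and is used to compute prefetch targets a few iterations ahead.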
444bool jit_brgemm_amx_uker_base_t::bi_shift_output(
445 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
446 res_bi = bi;
447 if (shift == 0) return true;
448
449 size_t lidx = 0;
450 size_t bd_idx = 0;
451 size_t ld_idx = 0;
452 if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost) {
453 lidx = bi.bdi.idx * imap_.ldis.size() + bi.ldi.idx;
454 lidx += shift;
455 bd_idx = lidx / imap_.ldis.size();
456 ld_idx = lidx % imap_.ldis.size();
457 } else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost) {
458 lidx = bi.ldi.idx * imap_.bdis.size() + bi.bdi.idx;
459 lidx += shift;
460 ld_idx = lidx / imap_.bdis.size();
461 bd_idx = lidx % imap_.bdis.size();
462 } else
463 assert(!"Unknown loop order!");
464 if (lidx >= imap_.ldis.size() * imap_.bdis.size()) return false;
465 res_bi.bdi = imap_.bdis[bd_idx];
466 res_bi.ldi = imap_.ldis[ld_idx];
467
468 return true;
469}
470
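// bi_shift_A() and bi_shift_B() do the same kind of shift for the (bd, rd)
// and (ld, rd) iteration spaces of the A and B tile loads, respectively.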
471bool jit_brgemm_amx_uker_base_t::bi_shift_A(
472 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
473 res_bi = bi;
474 auto lidx = bi.bdi.idx * imap_.rdis.size() + bi.rdi.idx;
475 lidx += shift;
476 if (lidx >= imap_.rdis.size() * imap_.bdis.size()) return false;
477
478 const auto bd_idx = lidx / imap_.rdis.size();
479 const auto rd_idx = lidx % imap_.rdis.size();
480
481 res_bi.bdi = imap_.bdis[bd_idx];
482 res_bi.rdi = imap_.rdis[rd_idx];
483
484 return true;
485}
486
487bool jit_brgemm_amx_uker_base_t::bi_shift_B(
488 brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) {
489 res_bi = bi;
490 auto lidx = bi.ldi.idx * imap_.rdis.size() + bi.rdi.idx;
491 lidx += shift;
492 if (lidx >= imap_.rdis.size() * imap_.ldis.size()) return false;
493
494 const auto ld_idx = lidx / imap_.rdis.size();
495 const auto rd_idx = lidx % imap_.rdis.size();
496
497 res_bi.ldi = imap_.ldis[ld_idx];
498 res_bi.rdi = imap_.rdis[rd_idx];
499
500 return true;
501}
502
503int jit_brgemm_amx_uker_base_t::get_C_tensor(
504 brgemm_iteration_t &bi, int m, int n) const noexcept {
505 return brg.get_C_tensor(m, n, bi.bdi.is_tail, bi.ldi.is_tail);
506}
507
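// prepare_bd_mask() builds two helper tables from brgattr.bd_mask:
// adj_bd_mask_buffer_ maps an input bd index to its output row index (a
// running prefix sum of the mask), and skipped_bd_mask_buffer_ maps an input
// bd index to the next index whose mask bit is set.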
508void jit_brgemm_amx_uker_base_t::prepare_bd_mask() noexcept {
509 if (!brg.brgattr.bd_mask_level) return;
510 bd_mask_buffer_ptr_ = brg.brgattr.bd_mask;
511 const auto bd_mask_size = brg.bcast_dim;
512 adj_bd_mask_buffer_.resize(bd_mask_size);
513 adj_bd_mask_buffer_ptr_ = adj_bd_mask_buffer_.data();
514 skipped_bd_mask_buffer_.resize(bd_mask_size);
515 skipped_bd_mask_buffer_ptr_ = skipped_bd_mask_buffer_.data();
516 if (!utils::any_null(bd_mask_buffer_ptr_, adj_bd_mask_buffer_ptr_)) {
517 int out_ibd = 0;
518 for (int i = 0; i < bd_mask_size; i++) {
519 adj_bd_mask_buffer_ptr_[i] = out_ibd;
520 out_ibd += bd_mask_buffer_ptr_[i];
521 skipped_bd_mask_buffer_ptr_[i] = i;
522 for (auto ii = i; ii < bd_mask_size; ii++) {
523 if (bd_mask_buffer_ptr_[ii]) {
524 skipped_bd_mask_buffer_ptr_[i] = ii;
525 break;
526 }
527 }
528 }
529 } else
        assert(!"bd mask buffer pointers must not be null");
531}
532
533int jit_brgemm_amx_uker_base_t::skipped_bd_mask(int inp_bd) noexcept {
534 if (brg.brgattr.bd_mask_level != 2)
535 return inp_bd;
536 else
537 return skipped_bd_mask_buffer_ptr_[inp_bd];
538}
539
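// The offset helpers below translate block/row indices into byte offsets
// using the leading-dimension and block sizes (LDA2_size_, ld_block_B_size_,
// etc.), which are expected to be precomputed from brg before code
// generation.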
540size_t jit_brgemm_amx_uker_base_t::A_offset(int bdb) const noexcept {
541 return bdb * LDA2_size_;
542}
543
544size_t jit_brgemm_amx_uker_base_t::B_offset(int ldb) const noexcept {
545 return (brg.is_blocked ? 1 : brg.rd_step) * ldb * ld_block_B_size_;
546}
547
548size_t jit_brgemm_amx_uker_base_t::C_offset(int bd, int ldb) const noexcept {
549 return bd * LDC_size_ + ldb * ld_block_C_size_;
550}
551
552size_t jit_brgemm_amx_uker_base_t::C_block_offset(int bd, int ldb) const
553 noexcept {
554 return (size_t)bd * LDC2_size_M_ + (size_t)ldb * LDC2_size_N_;
555}
556
557size_t jit_brgemm_amx_uker_base_t::D_offset(int bd, int ldb) const noexcept {
558 return bd * LDD_size_ + ldb * ld_block_D_size_;
559}
560
561size_t jit_brgemm_amx_uker_base_t::lda() const noexcept {
562 return LDA_size_;
563}
564
565size_t jit_brgemm_amx_uker_base_t::ldb() const noexcept {
566 return LDB_size_ * brg.rd_step;
567}
568
569size_t jit_brgemm_amx_uker_base_t::rdb_A_offset(
570 const brgemm_iteration_t &bi) const noexcept {
571 return bi.rdi.pos * brg.typesize_A * brg.rd_block;
572}
573
574size_t jit_brgemm_amx_uker_base_t::rdb_B_offset(
575 const brgemm_iteration_t &bi) const noexcept {
576 return bi.rdi.pos * brg.rd_block * LDB_size_;
577}
578
579size_t jit_brgemm_amx_uker_base_t::ldb_B_offset(
580 const brgemm_iteration_t &bi) const noexcept {
581 return bi.ldi.pos * ld_block_B_size_ * brg.ld_step;
582}
583
584size_t jit_brgemm_amx_uker_base_t::bias_offset(int ldb) const noexcept {
585 return ldb * ld_block_bias_size_;
586}
587
588size_t jit_brgemm_amx_uker_base_t::scales_offset(int ldb) const noexcept {
589 return brg.is_oc_scale * ldb * ld_block_scales_size_;
590}
591
592size_t jit_brgemm_amx_uker_base_t::zp_comp_a_offset(int ldb) const noexcept {
593 return ldb * ld_block_zp_size_;
594}
595
596size_t jit_brgemm_amx_uker_base_t::zp_comp_b_offset(int bd) const noexcept {
597 return sizeof(int32_t) * bd;
598}
599
600size_t jit_brgemm_amx_uker_base_t::zp_c_values_offset(
601 brgemm_iteration_t &bi, int ldb) const noexcept {
602 if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
603 return (bi.ldi.is_tail) ? ldb_tail_zp_size_
604 : (bi.ldi.pos + ldb) * ld_block_zp_size_;
605 }
606
607 return 0;
608}
609
610int jit_brgemm_amx_uker_base_t::get_out_bd(int bd_inp_bdb, int bd) const {
611 const auto bd_out_bd = bd_inp_bdb + bd;
612 if (brg.brgattr.bd_mask_level && !bd_mask_buffer_ptr_[bd_out_bd])
613 return -1;
614 else {
615 if (brg.brgattr.bd_mask_level)
616 return adj_bd_mask_buffer_ptr_[bd_out_bd];
617 else
618 return bd_out_bd;
619 }
620}
621
622Xbyak::Zmm jit_brgemm_amx_uker_base_t::zmm_mask(const Xbyak::Zmm zmm_in,
623 bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
624 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
625 : zmm_in;
626}
627
628Xbyak::Ymm jit_brgemm_amx_uker_base_t::ymm_mask(const Xbyak::Ymm ymm_in,
629 bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
630 return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z)
631 : ymm_in;
632}
633
634void jit_brgemm_amx_uker_base_t::cvt2ps(data_type_t type_in,
635 const Xbyak::Zmm zmm_in, const Xbyak::Operand &op, bool mask_flag,
636 bool store, Xbyak::Opmask ktail_mask) {
637 const Xbyak::Zmm zmm = zmm_mask(zmm_in, mask_flag, store, ktail_mask);
638 switch (type_in) {
639 case data_type::f32:
640 case data_type::s32: vmovups(zmm, op); break;
641 case data_type::bf16:
642 vpmovzxwd(zmm, op);
643 vpslld(zmm, zmm, 16);
644 break;
645 case data_type::f16: vcvtph2ps(zmm, op); break;
646 case data_type::s8: vpmovsxbd(zmm, op); break;
647 case data_type::u8: vpmovzxbd(zmm, op); break;
648 default: assert(!"unsupported data type");
649 }
650 if (types::is_integral_dt(type_in)) vcvtdq2ps(zmm_in, zmm_in);
651}
652
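// read_params() loads the kernel arguments (C/D pointers, batch size, batch
// pointer and the scratchpad buffer) from brgemm_kernel_params_t and spills
// the zero-point pointers into the stack slots reg_zp_comp_a_offs_ /
// reg_zp_comp_b_offs_ / reg_zp_c_values_offs_, because the gpr holding them
// (rbx) is shared with several other purposes.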
653void jit_brgemm_amx_uker_base_t::read_params() {
654 Label label_done;
655
656 mov(reg_C, ptr[param1 + GET_OFF(ptr_C)]);
657 mov(reg_D, ptr[param1 + GET_OFF(ptr_D)]);
658 mov(reg_BS, ptr[param1 + GET_OFF(BS)]);
659
660 mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]);
661
662 mov(reg_buf, ptr[param1 + GET_OFF(ptr_buf)]);
663
664 if (brg.zp_type_a != brgemm_broadcast_t::none) {
665 mov(reg_zp_comp_a, ptr[param1 + GET_OFF(a_zp_compensations)]);
666 mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a);
667 }
668
669 if (brg.zp_type_b != brgemm_broadcast_t::none) {
670 mov(reg_zp_comp_b, ptr[param1 + GET_OFF(b_zp_compensations)]);
671 mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b);
672 }
673
674 if (brg.zp_type_c != brgemm_broadcast_t::none) {
675 mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]);
676 mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
677 }
678}
679
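// load_accumulators() either loads C into the accumulator tiles (when
// may_load_accumulators_ is set) or zeroes them. With interleaved tilestores
// the tiles are zeroed only on the very first iteration, because
// maybe_tilestore() re-zeroes each tile right after storing it.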
680void jit_brgemm_amx_uker_base_t::load_accumulators(brgemm_iteration_t &bi) {
681 assert(IMPLICATION(bi.ldi.is_tail, bi.ldi.block2 == 1));
682 if (may_load_accumulators_) mov(reg_stride_ld_block, LDC_size_);
683
684 for (int bdb = 0; bdb < bi.bdi.block2; bdb++) {
685 const auto bd_out_bdb = get_out_bd(bi.bdi.bdb_pos[bdb], 0);
686 for (int ldb = 0; ldb < bi.ldi.block2; ldb++) {
687 if (may_load_accumulators_) {
688 const auto c_offset
689 = C_block_offset(bd_out_bdb, bi.ldi.pos + ldb);
690 tileloadd(Tmm(get_C_tensor(bi, bdb, ldb)),
691 ptr[reg_C + c_offset + reg_stride_ld_block]);
692 } else {
                // zero the tile unconditionally when tilestores are not
                // interleaved, otherwise only on the very first iteration
694 if (!brg.interleave_tilestores_
695 || everyone_is(0u, bi.bdi.idx, bi.ldi.idx))
696 tilezero(Tmm(get_C_tensor(bi, bdb, ldb)));
697 }
698 }
699 }
700}
701
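// apply_alpha_beta_to_vector() computes zmm = alpha * zmm + beta * C for a
// single accumulator vector. For int8 the accumulator is converted to f32
// first, except when beta == 1 and alpha == 1 allow adding C with a plain
// integer vpaddd.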
702void jit_brgemm_amx_uker_base_t::apply_alpha_beta_to_vector(
703 const int idx, const Address &addr, bool is_ld_tail) {
704 auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
705 auto zmm = Zmm(idx);
706 auto zmm_beta = zmm_tmp_1();
707 auto zmm_alpha = zmm_tmp_2();
708 auto zmm_prev_dst = zmm_tmp_3();
709
710 const bool apply_alpha = brg.alpha != 1.f;
711 const bool apply_beta = brg.beta != 0.f;
712 if (!apply_alpha && !apply_beta) return;
713
714 const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f);
715 const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
716
717 if (apply_beta && !use_vadd_for_beta) {
718 mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta)));
719 vmovq(Xmm(zmm_beta.getIdx()), reg_tmp_gpr);
720 vbroadcastss(zmm_beta, Xmm(zmm_beta.getIdx()));
721 }
722 if (apply_alpha) {
723 mov(reg_tmp_gpr, float2int(static_cast<float>(brg.alpha)));
724 vmovq(Xmm(zmm_alpha.getIdx()), reg_tmp_gpr);
725 vbroadcastss(zmm_alpha, Xmm(zmm_alpha.getIdx()));
726 }
727 if (dq2ps_required) vcvtdq2ps(zmm, zmm);
728 if (apply_alpha) vmulps(zmm, zmm, zmm_alpha);
729 if (apply_beta) {
730 if (use_vadd_for_beta) {
731 auto zmm_masked = zmm | k_mask | T_z;
732 if (brg.is_int8)
733 vpaddd(zmm_masked, zmm, addr);
734 else
735 vaddps(zmm_masked, zmm, addr);
736 } else {
737 cvt2ps(brg.dt_c, zmm_prev_dst, addr, true, false, k_mask);
738 vfmadd231ps(zmm, zmm_prev_dst, zmm_beta);
739 }
740 }
741}
742
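// apply_post_ops_to_range() runs the post-ops injector over the accumulator
// registers holding rows [bd_start, bd_finish) of one ldb block. For binary
// post-ops it fills the per-vector rhs_arg_params, and the sum post-op is
// handled by a lambda that reads the previous destination values from D.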
743void jit_brgemm_amx_uker_base_t::apply_post_ops_to_range(brgemm_iteration_t &bi,
744 int bd_start, int bd_finish, int bd_inp_bdb, int ldb) {
745 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
746
747 auto ldb_pos = bi.ldi.pos + ldb;
748 auto is_ld_tail = bi.ldi.is_tail;
749
750 if (brg.with_binary) {
751 if (handle_binary_po_offset_) {
752 for (int bd = bd_start; bd < bd_finish; bd++) {
753 // We have no way to tell the injector to skip some vectors.
754 // Therefore, we must set parameters correctly for all registers.
755 // TODO: Make it possible to specify "skipped" vectors to injector
756 const auto idx = accm(bd).getIdx();
757 if (is_ld_tail) rhs_arg_params.vmm_tail_idx_.emplace(idx);
758 rhs_arg_params.vmm_idx_to_out_reg.emplace(idx, reg_D);
759
760 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
761 if (bd_out_bd == -1) continue;
762
763 const auto d_offset = D_offset(bd_out_bd, ldb_pos);
764 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
765 idx, d_offset);
766 }
767 }
768 }
769
770 const auto sum_injector = [&] {
771 const float *p_sum_scale = &brg.sum_scale;
772 const int32_t *p_sum_zp = &brg.sum_zp;
773 const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
774 const bool p_sum_zp_reg_set = *p_sum_zp != 0;
775
776 {
777 if (p_sum_scale_reg_set)
778 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
779
780 const auto &zmm_sum_zp = zmm_tmp_2();
781 if (p_sum_zp_reg_set) {
782 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
783 vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
784 }
785
786 const auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
787 const auto zmm_prev_dst = Xbyak::Zmm(0);
788
789 for (int bd = bd_start; bd < bd_finish; bd++) {
790 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
791 if (bd_out_bd == -1) continue;
792
793 auto zmm = accm(bd);
794 const auto d_offset = D_offset(bd_out_bd, ldb_pos);
795 auto addr = EVEX_compress_addr(reg_D, d_offset);
796
797 cvt2ps(brg.sum_dt, zmm_prev_dst, addr, true, false, k_mask);
798 if (p_sum_zp_reg_set) vsubps(zmm_prev_dst, zmm_sum_zp);
799 if (!p_sum_scale_reg_set)
800 vaddps(zmm, zmm_prev_dst);
801 else
802 vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
803 }
804 }
805 };
806
807 if (brg.with_sum) {
808 postops_injector_->set_lambda_injector(
809 primitive_kind::sum, sum_injector);
810 }
811
    // Using knowledge of how "accm" assigns zmm registers.
    // TODO: make this code clearer
814 const auto finish_idx = accm(bd_start).getIdx() + 1;
815 const auto start_idx = accm(bd_finish - 1).getIdx();
816 postops_injector_->compute_vector_range(
817 start_idx, finish_idx, rhs_arg_params);
818}
819
820void jit_brgemm_amx_uker_base_t::maybe_saturation(Xbyak::Zmm &zmm) {
821 if (!dt_requires_saturation_) return;
822 saturate_f32(zmm, zmm_lbound, zmm_ubound, brg.dt_d);
823 vcvtps2dq(zmm, zmm);
824}
825
826void jit_brgemm_amx_uker_base_t::prepare_post_ops_registers_ldb(
827 brgemm_iteration_t &bi, int ldb) {
828 if (!bi.apply_postops) return;
829 auto k_mask = (!bi.ldi.is_tail) ? ld_full_mask : ld_tail_mask;
830
831 if (brg.zp_type_a != brgemm_broadcast_t::none) {
832 mov(reg_aux_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]);
833
834 const auto zp_comp_a_off = zp_comp_a_offset(bi.ldi.pos + ldb);
835 const auto zp_comp_a_addr
836 = EVEX_compress_addr(reg_aux_zp_comp_a, zp_comp_a_off);
837 cvt2ps(data_type::s32, zmm_zp_comp_a, zp_comp_a_addr, true, false,
838 k_mask);
839 }
840
841 if (brg.zp_type_c != brgemm_broadcast_t::none) {
842 mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
843 if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
844 vcvtdq2ps(zmm_zp_c, EVEX_compress_addr(reg_zp_c_values, 0, true));
845 }
846 if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
847 const auto zp_c_off = zp_c_values_offset(bi, ldb);
848 const auto zp_c_addr
849 = EVEX_compress_addr(reg_zp_c_values, zp_c_off);
850 cvt2ps(data_type::s32, zmm_zp_c, zp_c_addr, true, false, k_mask);
851 }
852 }
853}
854
855void jit_brgemm_amx_uker_base_t::prepare_post_ops_registers(
856 brgemm_iteration_t &bi) {
857 if (!bi.apply_postops) return;
858 dim_iteration_t &ldi = bi.ldi;
859 auto k_mask = (!ldi.is_tail) ? ld_full_mask : ld_tail_mask;
860 if (brg.with_scales) {
861 mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
862 for (int ldb = 0; ldb < ldi.block2; ldb++) {
863 auto scales_ptr = EVEX_compress_addr(
864 reg_scales, scales_offset(ldi.pos + ldb));
865 vmovups(zmm_scales(ldb) | k_mask | T_z, scales_ptr);
866 }
867 }
868
869 if (brg.with_bias) {
870 mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]);
871
872 for (int ldb = 0; ldb < ldi.block2; ldb++) {
873 auto ptr_bias
874 = EVEX_compress_addr(reg_bias, bias_offset(ldi.pos + ldb));
875 cvt2ps(brg.dt_bias, zmm_bias(ldb), ptr_bias, true, false, k_mask);
876 }
877 }
878}
879
880void jit_brgemm_amx_uker_base_t::uni_prefetch(
881 const Address &addr, brgemm_kernel_prefetching_t pft) {
882 if (pft == brgemm_kernel_prefetching_t::brgemm_prf1)
883 prefetcht1(addr);
884 else if (pft == brgemm_kernel_prefetching_t::brgemm_prf2)
885 prefetcht2(addr);
886}
887
888void jit_brgemm_amx_uker_base_t::prefetch_CD_range(brgemm_iteration_t &bi,
889 brgemm_kernel_prefetching_t pft, int bd_start, int bd_finish,
890 int bd_inp_bdb, int ldb) {
891 auto ldb_pos = bi.ldi.pos + ldb;
892 for (int bd = bd_start; bd < bd_finish; bd++) {
893 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
894 if (bd_out_bd == -1) continue;
895 if (bi.apply_postops) {
896 const auto d_offset = D_offset(bd_out_bd, ldb_pos);
897 auto ptr_D = EVEX_compress_addr(reg_D, d_offset);
898 uni_prefetch(ptr_D, pft);
899 } else if (are_post_ops_applicable_) {
900 const auto c_offset = C_offset(bd_out_bd, ldb_pos);
901 auto ptr_C = EVEX_compress_addr(reg_C, c_offset);
902 uni_prefetch(ptr_C, pft);
903 } else {
904 const auto d_offset = D_offset(bd_out_bd, ldb_pos);
905 auto ptr_D = EVEX_compress_addr(reg_D, d_offset);
906 uni_prefetch(ptr_D, pft);
907 }
908 }
909}
910
911int jit_brgemm_amx_uker_base_t::calc_ops_CD(brgemm_iteration_t &bi) const
912 noexcept {
913 return (brg.rdb + (brg.rdb_tail ? 1 : 0)) * bi.ldi.block2 * bi.bdi.block2
914 * (brg.brgattr.var_bs ? 1 : brg.brgattr.max_bs);
915}
916
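// prefetch_CD() spreads prefetches of the C/D rows of a future iteration
// (pfo_bi) across the compute of the current one: calc_ops_CD() estimates how
// many tdpbxxd calls are available, so roughly
// div_up(total_vectors, calc_ops) cache lines are prefetched per call.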
917void jit_brgemm_amx_uker_base_t::prefetch_CD(brgemm_iteration_t &bi,
918 brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) {
919
920 const auto calc_ops = calc_ops_CD(bi);
921 const auto bdb_row = pfo_bi.bdi.block * pfo_bi.ldi.block2;
922 const auto tot_vecs = pfo_bi.bdi.block2 * bdb_row;
923 const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0;
924
925 const auto nvecs = prefetch_all
926 ? tot_vecs
927 : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec);
928
929 const auto out_typesize
930 = (are_post_ops_applicable_ && !prev_bi_.apply_postops)
931 ? brg.typesize_C
932 : brg.typesize_D;
933 for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) {
934 const auto bdb = prf.vec / bdb_row;
935 const auto vec_in_bdb_row = prf.vec - bdb * bdb_row;
936 const auto ldb = vec_in_bdb_row / pfo_bi.bdi.block;
937 const auto bd = vec_in_bdb_row % pfo_bi.bdi.block;
938 // prefetch output cache lines only once
939 if ((pfo_bi.ldi.pos + ldb) % (4 / out_typesize) == 0) {
940 auto bd_inp_bdb = pfo_bi.bdi.bdb_pos[bdb];
941 prefetch_CD_range(pfo_bi, prf.pft, 0, 1, bd_inp_bdb + bd, ldb);
942 }
943 prf.vec++;
944 }
945}
946
947void jit_brgemm_amx_uker_base_t::prefetch_A(brgemm_iteration_t &bi,
948 brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) {
949
950 const auto calc_ops = bi.ldi.block2 * bi.bdi.block2;
951 const auto tot_vecs = pfo_bi.bdi.block2 * pfo_bi.bdi.block;
952 const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0;
953
954 const auto nvecs = prefetch_all
955 ? tot_vecs
956 : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec);
957
958 const auto rdb_A_off = rdb_A_offset(pfo_bi);
959
960 for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) {
961 const auto bdb = prf.vec / pfo_bi.bdi.block;
962 const auto bd = prf.vec % pfo_bi.bdi.block;
963 const auto bd_inp_bdb = pfo_bi.bdi.bdb_pos[bdb];
964
        // TODO: it looks like we have to prefetch each bs separately
966 const auto ptr_A = EVEX_compress_addr(
967 reg_aux_A, A_offset(bd_inp_bdb) + bd * LDA_size_ + rdb_A_off);
968 uni_prefetch(ptr_A, prf.pft);
969 prf.vec++;
970 }
971}
972
973void jit_brgemm_amx_uker_base_t::prefetch_B(brgemm_iteration_t &bi,
974 brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) {
975
976 const auto calc_ops = bi.ldi.block2 * bi.bdi.block2;
977 const auto tot_vecs = pfo_bi.ldi.block2 * pfo_bi.rdi.block;
978 const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0;
979
980 const auto nvecs = prefetch_all
981 ? tot_vecs
982 : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec);
983
    // TODO: check this addressing for correctness
985 const auto rdb_B_off = rdb_B_offset(pfo_bi) + ldb_B_offset(pfo_bi);
986
987 for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) {
988
989 const auto ldb = prf.vec / pfo_bi.rdi.block;
990 const auto rb = prf.vec % pfo_bi.rdi.block;
        // TODO: it looks like we have to prefetch each bs separately
992 const auto ptr_B = EVEX_compress_addr(
993 reg_aux_B, B_offset(ldb) + rdb_B_off + rb * LDB_size_);
994
995 uni_prefetch(ptr_B, prf.pft);
996 prf.vec++;
997 }
998}
999
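// prefetching() dispatches A/B/C prefetches according to the distances in
// brg.prfA / brg.prfB / brg.prfC. A distance is interpreted as a number of
// iterations ahead, and bi_shift_output() / bi_shift_A() / bi_shift_B() are
// used to find the iteration whose data should be prefetched.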
1000void jit_brgemm_amx_uker_base_t::prefetching(
1001 brgemm_iteration_t &bi, bool prefetch_all) {
    // for var_bs, prefetch only on the last iteration over bs
1003 if (brg.brgattr.var_bs && !bi.bsi.is_last) return;
1004 brgemm_iteration_t pfo_bi;
1005 if (brg.prfC.dist1 >= 0) {
1006 bool is_pfo_bi = false;
1007 brgemm_iteration_t pfo_bi;
1008 if (use_ils_ && get_store_by_vectors(bi.apply_postops)) {
1009 if (was_prev_bi && brg.prfC.dist1 == 0) {
1010 is_pfo_bi = true;
1011 pfo_bi = prev_bi_;
1012 } else if (brg.prfC.dist1 > 0) {
1013 is_pfo_bi = bi_shift_output(bi, brg.prfC.dist1 - 1, pfo_bi);
1014 }
1015 } else {
1016 is_pfo_bi = bi_shift_output(bi, brg.prfC.dist1, pfo_bi);
1017 }
1018 if (is_pfo_bi) prefetch_CD(bi, pfo_bi, prf1C, prefetch_all);
1019 }
1020 if (brg.prfC.dist2 >= 0) {
1021 bool is_pfo_bi = false;
1022 brgemm_iteration_t pfo_bi;
1023 if (use_ils_ && get_store_by_vectors(bi.apply_postops)) {
1024 if (was_prev_bi && brg.prfC.dist2 == 0) {
1025 is_pfo_bi = true;
1026 pfo_bi = prev_bi_;
1027 } else if (brg.prfC.dist2 > 0) {
1028 is_pfo_bi = bi_shift_output(bi, brg.prfC.dist2 - 1, pfo_bi);
1029 }
1030 } else {
1031 is_pfo_bi = bi_shift_output(bi, brg.prfC.dist2, pfo_bi);
1032 }
1033 if (is_pfo_bi) prefetch_CD(bi, pfo_bi, prf2C, prefetch_all);
1034 }
1035 if (brg.prfA.dist1 >= 0) {
1036 if (bi_shift_A(bi, brg.prfA.dist1, pfo_bi))
1037 prefetch_A(bi, pfo_bi, prf1A, prefetch_all);
1038 }
1039 if (brg.prfA.dist2 >= 0) {
1040 if (bi_shift_A(bi, brg.prfA.dist2, pfo_bi))
1041 prefetch_A(bi, pfo_bi, prf2A, prefetch_all);
1042 }
1043 if (brg.prfB.dist1 >= 0) {
1044 if (bi_shift_B(bi, brg.prfB.dist1, pfo_bi))
1045 prefetch_B(bi, pfo_bi, prf1B, prefetch_all);
1046 }
1047 if (brg.prfB.dist2 >= 0) {
1048 if (bi_shift_B(bi, brg.prfB.dist2, pfo_bi))
1049 prefetch_B(bi, pfo_bi, prf2B, prefetch_all);
1050 }
1051}
1052
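// process_output_range() loads the accumulators for rows
// [bd_start, bd_finish) of one (bdb, ldb) block from the wsp buffer and
// applies, in order: alpha/beta, zero-point compensations, scales, bias and
// the injector-based post-ops.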
1053void jit_brgemm_amx_uker_base_t::process_output_range(brgemm_iteration_t &bi,
1054 int bd_start, int bd_finish, int bd_inp_bdb, int bdb, int ldb) {
1055
1056 const auto wsp_offset = (use_ils_ || brg.interleave_tilestores_)
1057 ? (bdb * prev_bi_.ldi.block2 + ldb) * brg.bd_block
1058 * ld_block_C_size_
1059 : 0;
1060
1061 const auto k_mask = (!bi.ldi.is_tail) ? ld_full_mask : ld_tail_mask;
1062
    // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
    // accumulated values are already converted to ps in
    // apply_alpha_beta_to_vector()
1065 const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f;
1066 const bool beta_uses_vadd
1067 = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f);
1068 const bool dq2ps_required = brg.is_int8
1069 && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);
1070
1071 bool some_bd_mask = false;
1072 for (int bd = bd_start; bd < bd_finish; bd++) {
1073 auto zmm = accm(bd);
1074 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1075 if (bd_out_bd == -1) continue;
1076
1077 auto vreg_acc
1078 = bi.ldi.is_tail ? accm(bd) | ld_tail_mask | T_z : accm(bd);
1079 some_bd_mask = true;
1080
1081 const auto buf_offset = bd * ld_block_C_size_;
1082 vmovups(vreg_acc, ptr[reg_buf + buf_offset + wsp_offset]);
1083
1084 const auto c_offset = C_offset(bd_out_bd, bi.ldi.pos + ldb);
1085 const auto ptr_C = EVEX_compress_addr(reg_C, c_offset);
1086
1087 if (need_to_apply_alpha_beta_)
1088 apply_alpha_beta_to_vector(zmm.getIdx(), ptr_C, bi.ldi.is_tail);
1089
1090 if (!bi.apply_postops) continue;
1091
1092 if (dq2ps_required) vcvtdq2ps(zmm, zmm);
1093 }
1094
1095 if (!bi.apply_postops || !some_bd_mask) return;
1096
1097 if (brg.zp_type_a != brgemm_broadcast_t::none) {
1098 for (int bd = bd_start; bd < bd_finish; bd++) {
1099 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1100 if (bd_out_bd == -1) continue;
1101
1102 auto zmm = accm(bd);
1103 vaddps(zmm, zmm, zmm_zp_comp_a);
1104 }
1105 }
1106
1107 if (brg.zp_type_b != brgemm_broadcast_t::none) {
1108 mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
1109
1110 auto zmm_zp_comp_b = zmm_tmp_1();
1111 for (int bd = bd_start; bd < bd_finish; bd++) {
1112 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1113 if (bd_out_bd == -1) continue;
1114
1115 auto zmm = accm(bd);
1116
1117 const auto zp_comp_b_off = zp_comp_b_offset(bd_out_bd);
1118 vcvtdq2ps(zmm_zp_comp_b,
1119 EVEX_compress_addr(reg_zp_comp_b, zp_comp_b_off, true));
1120
1121 vaddps(zmm, zmm, zmm_zp_comp_b);
1122 }
1123 }
1124
1125 if (brg.with_scales) {
1126 for (int bd = bd_start; bd < bd_finish; bd++) {
1127 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1128 if (bd_out_bd == -1) continue;
1129
1130 auto zmm = accm(bd);
1131 const Xbyak::Zmm scaled_zmm = zmm_mask(zmm, true, false, k_mask);
1132 vmulps(scaled_zmm, scaled_zmm, zmm_scales(ldb));
1133 }
1134 }
1135
1136 if (brg.with_bias) {
1137 for (int bd = bd_start; bd < bd_finish; bd++) {
1138 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1139 if (bd_out_bd == -1) continue;
1140
1141 auto zmm = accm(bd);
1142 vaddps(zmm, zmm, zmm_bias(ldb));
1143 }
1144 }
1145
1146 if (postops_injector_) {
1147 apply_post_ops_to_range(bi, bd_start, bd_finish, bd_inp_bdb, ldb);
1148 }
1149}
1150
1151void jit_brgemm_amx_uker_base_t::store_vector_with_post_ops(const int idx,
1152 const Address &addr, const int bd, const int ldb, bool is_ld_tail) {
1153 auto zmm = Zmm(idx);
1154 auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
1155
1156 if (brg.zp_type_c != brgemm_broadcast_t::none) vaddps(zmm, zmm, zmm_zp_c);
1157
1158 maybe_saturation(zmm);
1159
1160 auto ymm = Xbyak::Ymm(idx);
1161 const Xbyak::Zmm r_zmm = zmm_mask(zmm, true, true, k_mask);
1162 const Xbyak::Ymm r_ymm = ymm_mask(ymm, true, true, k_mask);
1163
1164 switch (brg.dt_d) {
1165 case data_type::f32:
1166 case data_type::s32: vmovups(addr, r_zmm); break;
1167 case data_type::bf16:
1168 vcvtneps2bf16(ymm, zmm);
1169 vmovdqu16(addr, r_ymm);
1170 break;
1171 case data_type::f16:
1172 vcvtps2ph(ymm, zmm, _op_mxcsr);
1173 vmovdqu16(addr, r_ymm);
1174 break;
1175 case data_type::s8: vpmovsdb(addr, r_zmm); break;
1176 case data_type::u8: vpmovusdb(addr, r_zmm); break;
1177 default: assert(!"unknown dst_dt");
1178 }
1179}
1180
1181void jit_brgemm_amx_uker_base_t::store_vector_without_post_ops(
1182 const int idx, const Address &addr, bool is_ld_tail) {
1183 auto zmm = Zmm(idx);
1184
1185 maybe_saturation(zmm);
1186
1187 if (is_ld_tail)
1188 vmovups(addr | ld_tail_mask | T_z, zmm);
1189 else
1190 vmovups(addr, zmm);
1191}
1192
1193void jit_brgemm_amx_uker_base_t::store_vector(
1194 brgemm_iteration_t &bi, const int idx, const int bd, const int ldb) {
1195 auto ldb_pos = bi.ldi.pos + ldb;
1196 auto is_ld_tail = bi.ldi.is_tail;
1197 const auto c_offset = C_offset(bd, ldb_pos);
1198 const auto d_offset = D_offset(bd, ldb_pos);
1199
1200 auto ptr_C = EVEX_compress_addr(reg_C, c_offset);
1201 auto ptr_D = EVEX_compress_addr(reg_D, d_offset);
1202
1203 if (bi.apply_postops)
1204 store_vector_with_post_ops(idx, ptr_D, bd, ldb_pos, is_ld_tail);
1205 else if (are_post_ops_applicable_)
1206 store_vector_without_post_ops(idx, ptr_C, is_ld_tail);
1207 else
1208 store_vector_without_post_ops(idx, ptr_D, is_ld_tail);
1209}
1210
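// interleave_store() stores the accumulators of the previous brgemm iteration
// (prev_bi_) a few vectors at a time, so that the stores overlap with the
// compute of the current iteration; it is called from tdpbxxd(). With
// store_all == true everything that is still pending is flushed.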
1211void jit_brgemm_amx_uker_base_t::interleave_store(
1212 brgemm_iteration_t &bi, bool store_all) {
1213
1214 if (!use_ils_) return;
1215 if (!was_prev_bi) return;
1216 if (!get_store_by_vectors(prev_bi_.apply_postops)) return;
1217
1218 if (store_all) prefetching(prev_bi_, true);
1219
1220 const auto bd_inp_bdb = prev_bi_.bdi.pos;
1221
1222 auto cur_bdb = ils_bdb_;
1223 auto cur_ldb = ils_ldb_;
1224
1225 // if first block
1226 if (ils_vec_ == 0) {
1227 if (!prepare_post_ops_registers_once_) {
1228 prepare_post_ops_registers(prev_bi_);
1229 }
1230 prepare_post_ops_registers_ldb(prev_bi_, 0);
1231 ils_bd_start_ = 0;
1232 auto bd_finish = nstl::min(ils_bd_step_, prev_bi_.bdi.block);
1233 process_output_range(
1234 prev_bi_, 0, bd_finish, bd_inp_bdb, cur_bdb, cur_ldb);
1235 }
1236
1237 const auto calc_ops = calc_ops_CD(bi);
1238 const auto ils_store_ops
1239 = prev_bi_.ldi.block2 * prev_bi_.bdi.block2 * prev_bi_.bdi.block;
1240 const auto ils_vecs_per_store
1241 = (calc_ops) ? div_up(ils_store_ops, calc_ops) : 0;
1242
1243 // last bd_block may be bd_tail
1244 const auto bdb_row = prev_bi_.bdi.block * prev_bi_.ldi.block2;
1245 const auto total_vectors = prev_bi_.bdi.block2 * bdb_row;
1246 const auto nvecs = store_all ? total_vectors : ils_vecs_per_store;
1247 for (int vec = 0; vec < nvecs && ils_vec_ < total_vectors; vec++) {
1248 const auto bdb = ils_vec_ / bdb_row;
1249 const auto vec_in_bdb_row = ils_vec_ - bdb * bdb_row;
1250 const auto ldb = vec_in_bdb_row / prev_bi_.bdi.block;
1251 const auto bd = vec_in_bdb_row % prev_bi_.bdi.block;
1252
1253 auto bd_inp_bdb = prev_bi_.bdi.bdb_pos[bdb];
1254 if (ldb != cur_ldb) prepare_post_ops_registers_ldb(prev_bi_, ldb);
1255
1256 if (bdb != cur_bdb || ldb != cur_ldb
1257 || rnd_dn(bd, ils_bd_step_) != ils_bd_start_) {
1258 ils_bd_start_ = rnd_dn(bd, ils_bd_step_);
1259 auto bd_finish = nstl::min(
1260 ils_bd_start_ + ils_bd_step_, prev_bi_.bdi.block);
1261 process_output_range(
1262 prev_bi_, ils_bd_start_, bd_finish, bd_inp_bdb, bdb, ldb);
1263 }
1264
1265 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1266 if (bd_out_bd != -1) {
1267 auto vreg_acc = prev_bi_.ldi.is_tail ? accm(bd) | ld_tail_mask | T_z
1268 : accm(bd);
1269
1270 store_vector(prev_bi_, vreg_acc.getIdx(), bd_out_bd, ldb);
1271 }
1272 cur_bdb = bdb;
1273 cur_ldb = ldb;
1274 ils_vec_++;
1275 }
1276 ils_ldb_ = cur_ldb;
1277 ils_bdb_ = cur_bdb;
1278}
1279
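// store_accumulators() writes out the results of the just-computed iteration.
// Depending on the configuration it either stores the tiles directly to C or
// spills them to the wsp buffer and stores vector by vector (applying
// post-ops); with interleave stores enabled the tiles are still spilled here,
// but the per-vector processing is deferred to interleave_store().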
1280void jit_brgemm_amx_uker_base_t::store_accumulators(brgemm_iteration_t &bi) {
1281
1282 const auto store_by_vectors = get_store_by_vectors(bi.apply_postops);
1283
1284 if (store_by_vectors) {
1285 if (!brg.interleave_tilestores_)
1286 mov(reg_stride_ld_block, ld_block_C_size_);
1287 } else
1288 mov(reg_stride_ld_block, LDC_size_);
1289
1290 prev_bi_ = bi;
1291
1292 ils_vec_ = 0;
1293 ils_bdb_ = 0;
1294 ils_ldb_ = 0;
1295 was_prev_bi = true;
1296
1297 prf1C.vec = 0;
1298 prf2C.vec = 0;
1299
1300 if (store_by_vectors && !use_ils_ && !prepare_post_ops_registers_once_)
1301 prepare_post_ops_registers(bi);
1302
1303 for (int bdb = 0; bdb < bi.bdi.block2; bdb++) {
1304 const auto bd_inp_bdb = bi.bdi.bdb_pos[bdb];
1305
1306 for (int ldb = 0; ldb < bi.ldi.block2; ldb++) {
1307 if (store_by_vectors) {
1308 if (!brg.interleave_tilestores_) {
1309 const auto wsp_offset = use_ils_
1310 ? (bdb * bi.ldi.block2 + ldb) * brg.bd_block
1311 * ld_block_C_size_
1312 : 0;
1313 tilestored(ptr[reg_buf + reg_stride_ld_block + wsp_offset],
1314 Tmm(get_C_tensor(bi, bdb, ldb)));
1315 }
1316 if (use_ils_) continue;
1317
1318 prepare_post_ops_registers_ldb(bi, ldb);
1319
1320 for (int bd_step = 0; bd_step < bi.bdi.block;
1321 bd_step += ils_bd_step_) {
1322 auto bd_finish
1323 = nstl::min(bd_step + ils_bd_step_, bi.bdi.block);
1324 process_output_range(
1325 bi, bd_step, bd_finish, bd_inp_bdb, bdb, ldb);
1326
1327 for (int bd = bd_step; bd < bd_finish; bd++) {
1328 const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd);
1329 if (bd_out_bd == -1) continue;
1330
1331 auto vreg_acc = bi.ldi.is_tail
1332 ? accm(bd) | ld_tail_mask | T_z
1333 : accm(bd);
1334 store_vector(bi, vreg_acc.getIdx(), bd_out_bd, ldb);
1335 }
1336 }
1337 } else if (!brg.interleave_tilestores_) {
1338 const auto bd_out_bdb = get_out_bd(bd_inp_bdb, 0);
1339 const auto c_offset
1340 = C_block_offset(bd_out_bdb, bi.ldi.pos + ldb);
1341 tilestored(ptr[reg_C + reg_stride_ld_block + c_offset],
1342 Tmm(get_C_tensor(bi, bdb, ldb)));
1343 }
1344 }
1345 }
1346}
1347
1348void jit_brgemm_amx_uker_base_t::set_A_B_matrices(int bs) {
1349 assert(brg.type == brgemm_addr);
1350 if (brg.brgattr.max_bs == 1) return;
1351 auto batch_offset = (size_t)bs * sizeof(brgemm_batch_element_t);
1352 if (brg.layout == brgemm_row_major) {
1353 mov(reg_aux_A,
1354 EVEX_compress_addr(reg_addr_batch,
1355 batch_offset + GET_OFF_BATCH_ELEMENT(ptr.A)));
1356 mov(reg_aux_B,
1357 EVEX_compress_addr(reg_addr_batch,
1358 batch_offset + GET_OFF_BATCH_ELEMENT(ptr.B)));
1359 } else {
1360 mov(reg_aux_A,
1361 EVEX_compress_addr(reg_addr_batch,
1362 batch_offset + GET_OFF_BATCH_ELEMENT(ptr.B)));
1363 mov(reg_aux_B,
1364 EVEX_compress_addr(reg_addr_batch,
1365 batch_offset + GET_OFF_BATCH_ELEMENT(ptr.A)));
1366 }
1367}
1368
1369void jit_brgemm_amx_uker_base_t::set_A_B_matrices() {
1370 assert(brg.type == brgemm_addr);
1371 assert(brg.brgattr.var_bs);
1372
1373 if (brg.layout == brgemm_row_major) {
1374 mov(reg_aux_A, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
1375 mov(reg_aux_B, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
1376 } else {
1377 mov(reg_aux_A, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
1378 mov(reg_aux_B, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
1379 }
1380}
1381
1382void jit_brgemm_amx_uker_base_t::maybe_tileloadd_nt(
1383 brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset) {
1384
1385 const bool is_A = mk == matrix_kind_t::matrix_A;
1386 bool load_nt = is_A ? brg.load_nt_A : brg.load_nt_B;
1387
1388 auto t1 = Tmm(is_A ? brg.get_A_tensor(xdb, bi.bdi.is_tail)
1389 : brg.get_B_tensor(xdb, bi.ldi.is_tail));
1390 auto reg_base = is_A ? reg_aux_A : reg_aux_B;
1391 auto reg_stride = is_A ? reg_stride_lda : reg_stride_ldb;
1392
1393 if (brg.is_bf32)
1394 // try_load_nt is not supported in maybe_pre_process_data as there is
1395 // no guarantee that the data is cacheline aligned.
1396 maybe_pre_process_data(bi, t1, reg_base, offset, reg_stride, mk);
1397 else if (load_nt)
1398 tileloaddt1(t1, ptr[reg_base + offset + reg_stride]);
1399 else
1400 tileloadd(t1, ptr[reg_base + offset + reg_stride]);
1401}
1402
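// maybe_tilestore() is only active when tilestores are interleaved with the
// compute: it stores at most one accumulator tile per call, either a tile
// left over from the previous iteration ahead of the current block
// (do_pre_tilestore) or a freshly computed one right after it
// (do_post_tilestore), and re-zeroes the stored tile.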
1403void jit_brgemm_amx_uker_base_t::maybe_tilestore(brgemm_iteration_t &bi,
1404 int bdb_idx, int ldb_idx, bool do_pre_tilestore,
1405 bool do_post_tilestore) {
1406 auto current_tensor_idx = get_C_tensor(bi, bdb_idx, ldb_idx);
1407
1408 if (!brg.interleave_tilestores_) return;
1409 const auto current_tensor_number
1410 = current_tensor_idx - get_C_tensor(bi, 0, 0);
1411 const auto store_tensor_shift
1412 = do_pre_tilestore ? (bi.bdi.block2 == brg.bdb2_tail ? 2 : 1) : 0;
1413 const auto store_tensor_idx = current_tensor_idx + store_tensor_shift;
1414 const auto store_tensor_number = current_tensor_number + store_tensor_shift;
1415
1416 const auto max_store_tensor_number
1417 = prev_bi_.bdi.block2 * prev_bi_.ldi.block2;
1418 bool perform_store
1419 = (do_pre_tilestore
1420 && (store_tensor_number >= 2
1421 && store_tensor_number < max_store_tensor_number))
1422 || (do_post_tilestore && (store_tensor_number < 2));
1423
1424 if (perform_store) {
1425 if (do_pre_tilestore) {
1426 bdb_idx = store_tensor_idx / bi.ldi.block2;
1427 ldb_idx = store_tensor_idx % bi.ldi.block2;
1428 }
1429 const bool store_by_vectors = get_store_by_vectors(bi.apply_postops);
1430 Tmm acc = Tmm(store_tensor_idx);
1431 if (store_by_vectors) {
1432 const auto wsp_offset = (use_ils_ || brg.interleave_tilestores_)
1433 ? (bdb_idx * bi.ldi.block2 + ldb_idx) * brg.bd_block
1434 * ld_block_C_size_
1435 : 0;
1436 tilestored(ptr[reg_buf + reg_stride_ld_block + wsp_offset], acc);
1437 } else {
1438 const auto &store_inp_bdi
1439 = do_pre_tilestore ? prev_bi_.bdi : bi.bdi;
1440 const auto store_ldb_ind
1441 = do_pre_tilestore ? prev_bi_.ldi.pos : bi.ldi.pos;
1442 const auto bd_out_bdb
1443 = get_out_bd(store_inp_bdi.bdb_pos[bdb_idx], 0);
1444 const auto c_offset
1445 = C_block_offset(bd_out_bdb, store_ldb_ind + ldb_idx);
1446 tilestored(ptr[reg_C + reg_stride_ld_block + c_offset], acc);
1447 }
1448 tilezero(acc);
1449 }
1450}
1451
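// tdpbxxd() issues a single AMX dot-product instruction selected from the
// source data types and surrounds it with the interleaving hooks: prefetches
// and a possible tilestore before the compute, interleaved vector stores and
// a possible tilestore after it.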
1452void jit_brgemm_amx_uker_base_t::tdpbxxd(brgemm_iteration_t &bi, int bdb_idx,
1453 int ldb_idx, bool do_pre_tilestore, bool do_post_tilestore) {
1454 prefetching(bi, false);
1455 maybe_tilestore(bi, bdb_idx, ldb_idx, do_pre_tilestore, false);
1456
1457 const Tmm &x1 = Tmm(get_C_tensor(bi, bdb_idx, ldb_idx));
1458 const Tmm &x2 = Tmm(brg.get_A_tensor(bdb_idx, bi.bdi.is_tail));
1459 const Tmm &x3 = Tmm(brg.get_B_tensor(ldb_idx, bi.ldi.is_tail));
1460
1461 if (brg.is_bf32
1462 || (brg.dt_a == data_type::bf16 && brg.dt_b == data_type::bf16)) {
1463 tdpbf16ps(x1, x2, x3);
1464 } else if (brg.dt_a == data_type::f16 && brg.dt_b == data_type::f16) {
1465 tdpfp16ps(x1, x2, x3);
1466 } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8) {
1467 tdpbuud(x1, x2, x3);
1468 } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8) {
1469 tdpbusd(x1, x2, x3);
1470 } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8) {
1471 tdpbsud(x1, x2, x3);
1472 } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8) {
1473 tdpbssd(x1, x2, x3);
1474 } else {
1475 assert(!"unsupported combination");
1476 }
1477 interleave_store(bi, false);
1478 maybe_tilestore(bi, bdb_idx, ldb_idx, false, do_post_tilestore);
1479}
1480
// This method down-converts the data from f32 to bf16 and saves it at
// reg_buf. Generally used for matrix_A, where no vnni transformation of the
// data is needed.
1483void jit_brgemm_amx_uker_base_t::bf32_downconvert(int num_rows,
1484 int tile_num_col_bytes, reg64_t reg_data, int offset,
1485 reg64_t reg_data_stride, reg64_t reg_buf, bool is_rd_tail) {
1486 const auto rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block;
1487 const auto max_num_cols
1488 = nstl::min<int>(tile_num_col_bytes / sizeof(bfloat16_t), rd_block);
1489 const auto col_tail = max_num_cols % simd_w;
1490 auto zmm_1 = zmm_tmp_1();
1491 auto zmm_2 = zmm_tmp_2();
1492 auto zmm_2_masked = col_tail ? zmm_2 | bf32_col_mask | T_z : zmm_2;
1493
1494 assert(max_num_cols > 0);
1495
1496 if (col_tail) {
1497 const int tail_mask = (1 << col_tail) - 1;
1498 auto reg_tmp_32 = reg_tmp_gpr.cvt32();
1499 mov(reg_tmp_32, tail_mask);
1500 kmovw(bf32_col_mask, reg_tmp_32);
1501 }
1502
    // Note: reg_data_aux aliases the register used for the col_tail mask
    // setup, so the order is important
1504 const auto reg_data_aux = reg_tmp_gpr;
1505 lea(reg_data_aux, ptr[reg_data + offset]);
1506
1507 for (int r = 0; r < num_rows; ++r) {
1508 if (max_num_cols > 16) {
1509 vmovups(zmm_1, ptr[reg_data_aux]);
1510 vmovups(zmm_2_masked, ptr[reg_data_aux + zmm_width_in_bytes]);
1511 vcvtne2ps2bf16(zmm_1, zmm_2, zmm_1);
1512 // we assume enough padding space is available.
1513 vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_1);
1514 } else {
1515 auto ymm_1 = Ymm(zmm_1.getIdx());
1516 auto ymm_1_masked
1517 = max_num_cols == 16 ? ymm_1 : ymm_1 | bf32_col_mask | T_z;
1518 vcvtneps2bf16(ymm_1_masked, ptr[reg_data_aux]);
1519 vmovups(ptr[reg_buf + r * zmm_width_in_bytes], ymm_1);
1520 }
1521 add(reg_data_aux, reg_data_stride);
1522 }
1523}
1524
// This method down-converts and transforms the data from f32 to bf16_vnni
// format. Generally used for matrix_B.
1527void jit_brgemm_amx_uker_base_t::bf32_downconvert_to_vnni(int num_rows,
1528 int tile_num_col_bytes, reg64_t reg_data, int offset,
1529 reg64_t reg_data_stride, reg64_t reg_buf, bool is_rd_tail) {
1530 const auto num_cols_ele = tile_num_col_bytes / sizeof(bfloat16_t);
1531 const auto num_N = num_cols_ele / sizeof(bfloat16_t);
1532 const auto col_tail = num_N % simd_w;
1533 const auto zmm_1 = zmm_tmp_1();
1534 const auto zmm_2 = zmm_tmp_2();
1535
1536 assert(num_N > 0);
1537
1538 auto load = [&](Zmm zmm, Address addr) {
1539 if (col_tail)
1540 vmovups(zmm | bf32_col_mask | T_z, addr);
1541 else
1542 vmovups(zmm, addr);
1543 };
1544
1545 if (col_tail) {
1546 const int tail_mask = (1 << col_tail) - 1;
1547 auto reg_tmp_32 = reg_tmp_gpr.cvt32();
1548 mov(reg_tmp_32, tail_mask);
1549 kmovw(bf32_col_mask, reg_tmp_32);
1550 }
1551
    // Note: reg_data_aux aliases the register used for the col_tail mask
    // setup, so the order is important
1553 const auto reg_data_aux = reg_tmp_gpr;
1554 lea(reg_data_aux, ptr[reg_data + offset]);
1555
1556 const auto rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block;
1557 const int vnni_granularity
1558 = data_type_vnni_granularity(data_type_t::dnnl_bf16);
1559 const auto r_end
1560 = nstl::min(utils::div_up(rd_block, vnni_granularity), num_rows);
1561
1562 for (int r = 0; r < r_end; ++r) {
1563 load(zmm_1, ptr[reg_data_aux]);
1564
1565 if (r * vnni_granularity + 1 >= rd_block) {
1566 vpxord(zmm_2, zmm_2, zmm_2);
1567 } else {
1568 load(zmm_2, ptr[reg_data_aux + reg_data_stride]);
1569 }
1570
1571 vcvtne2ps2bf16(zmm_1, zmm_2, zmm_1);
1572 vpermw(zmm_1, zmm_bf32_pemute, zmm_1);
1573 vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_1);
1574 lea(reg_data_aux,
1575 ptr[reg_data_aux + vnni_granularity * reg_data_stride]);
1576 }
1577
    // zero the rest of the tile data
1579 if (r_end < num_rows) {
1580 vpxord(zmm_2, zmm_2, zmm_2);
1581 for (int r = r_end; r < num_rows; ++r)
1582 vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_2);
1583 }
1584}
1585
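// maybe_pre_process_data() implements the bf32 path: the f32 tile is
// down-converted to bf16 (matrix A) or to bf16 vnni layout (matrix B) into
// the scratchpad and then loaded with tileloadd. When the transformed tile
// will be reused (should_save_transform()), it is cached in
// transform_buf_map_A_/_B_ keyed by the batch position and the tile offset.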
void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi,
        const Tmm &t1, reg64_t reg_base, size_t offset, reg64_t reg_stride,
        matrix_kind_t mk) {

    auto should_save_transform = [&](matrix_kind_t mk) {
        // save the transformed tile only if it will be reused by the
        // surrounding blocking
        if (mk == matrix_A) {
            return brg.ldb + (brg.ldb_tail != 0) > brg.ld_block2;
        } else {
            return brg.bdb + (brg.bdb_tail != 0) > brg.bd_block2;
        }
    };

    const bool is_A = mk == matrix_A;
    auto &transform_buf = is_A ? transform_buf_map_A_ : transform_buf_map_B_;

    const auto transform_offset
            = use_ils_ ? brg.get_num_C_tiles() * tile_size : 0;
    const auto max_bdb2 = brg.bd_block2;
    const auto max_rdb = brg.rdb + (brg.rdb_tail != 0);
    const auto matrix_a_offset = transform_offset;
    const auto matrix_b_offset = transform_offset
            + tile_size
                    * (nstl::max<int>(should_save_transform(mk),
                            should_save_transform(matrix_A) * brg.brgattr.max_bs
                                    * max_bdb2 * max_rdb));
    const auto matrix_offset = is_A ? matrix_a_offset : matrix_b_offset;
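    // Implied scratch buffer layout (an inference from the offsets above):
    // when interleaved stores are used, the first brg.get_num_C_tiles() tiles
    // are reserved for C, followed by the region for cached A transforms
    // (up to max_bs * bd_block2 * rdb tiles when A transforms are saved) and
    // then the region for cached B transforms.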
    const std::string key
            = std::to_string(bi.bsi.pos) + "_" + std::to_string(offset);

    if (transform_buf.find(key) != transform_buf.end()) {
        // this tile was already transformed for the current batch position
        // and offset, so load it directly from the scratch buffer
        auto buf_idx = transform_buf[key];
        auto offt = matrix_offset + buf_idx * tile_size;
        tileloadd(t1, ptr[reg_buf + reg_bf32_stride + offt]);
        return;
    }

    auto buf_offt = matrix_offset;
    // remember the buffer slot of this transformation if it may be reused.
    if (should_save_transform(mk)) {
        auto buf_idx = transform_buf.size();
        buf_offt = matrix_offset + buf_idx * tile_size;
        transform_buf[key] = buf_idx;
    }

    if (buf_offt) add(reg_buf, buf_offt);
    mov(reg_bf32_stride, zmm_width_in_bytes);

    assert(t1.getIdx() >= 0 && t1.getIdx() < 16);
    const auto num_rows = palette_.rows[t1.getIdx()];
    const auto num_col_bytes = palette_.cols[t1.getIdx()];
    if (is_A) {
        bf32_downconvert(num_rows, num_col_bytes, reg_base, offset, reg_stride,
                reg_buf, bi.rdi.is_tail);
    } else {
        bf32_downconvert_to_vnni(num_rows, num_col_bytes, reg_base, offset,
                reg_stride, reg_buf, bi.rdi.is_tail);
    }

    // load into tmm from the transformed data.
    tileloadd(t1, ptr[reg_buf + reg_bf32_stride]);

    // restore the buffer pointer.
    if (buf_offt) sub(reg_buf, buf_offt);
}

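// Emits the AMX micro-kernel for one rd iteration: the A and B tiles of the
// current bd_block2 x ld_block2 block are loaded (through the bf32
// pre-processing above when needed) and multiplied with tdpbxxd. The last
// tdpbxxd of a bdb row is deferred until the next tile loads have been
// issued, presumably to overlap tile loads with compute. When
// interleave_tilestores_ is set, tile stores of C are folded into the same
// sequence via the do_pre_tilestore / do_post_tilestore flags.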
void jit_brgemm_amx_uker_base_t::gemm_microkernel_amx(brgemm_iteration_t &bi) {
    prf1A.vec = 0;
    prf2A.vec = 0;
    prf1B.vec = 0;
    prf2B.vec = 0;

    const auto store_by_vectors = get_store_by_vectors(bi.apply_postops);

    bool do_post_tilestore = (brg.interleave_tilestores_ && bi.bsi.is_last
            && imap_.is_last_rdi(bi.rdi));

    bool do_pre_tilestore = (brg.interleave_tilestores_ && bi.bsi.is_first
            && bi.rdi.pos == 0 && was_prev_bi);

    if (store_by_vectors)
        mov(reg_stride_ld_block, ld_block_C_size_);
    else
        mov(reg_stride_ld_block, LDC_size_);

    const auto rdb_A_off = rdb_A_offset(bi);
    const auto rdb_B_off = rdb_B_offset(bi) + ldb_B_offset(bi);

    for (int bdb = 0; bdb < bi.bdi.block2; bdb++) {
        const auto bd_inp_bdb = bi.bdi.bdb_pos[bdb];
        maybe_tileloadd_nt(bi, matrix_kind_t::matrix_A, bdb,
                rdb_A_off + A_offset(bd_inp_bdb));
        for (int ldb = 0; ldb < bi.ldi.block2; ldb++) {
            if (bdb == 0)
                maybe_tileloadd_nt(bi, matrix_kind_t::matrix_B, ldb,
                        rdb_B_off + B_offset(ldb));
            if (ldb == 0) {
                // the last tdpbxxd of the previous bdb row is issued only
                // after the A tile of the current bdb row has been loaded
                if (bdb > 0)
                    tdpbxxd(bi, bdb - 1, bi.ldi.block2 - 1, do_pre_tilestore,
                            do_post_tilestore);
            } else
                tdpbxxd(bi, bdb, ldb - 1, do_pre_tilestore, do_post_tilestore);
        }
    }
    // the last tdpbxxd of the block
    tdpbxxd(bi, bi.bdi.block2 - 1, bi.ldi.block2 - 1, do_pre_tilestore,
            do_post_tilestore);
}

void jit_brgemm_amx_uker_base_t::rdb_loop(brgemm_iteration_t &bi) {
    for (size_t irdi = 0; irdi < imap_.rdis.size(); irdi++) {
        bi.rdi = imap_.rdis[irdi];
        gemm_microkernel_amx(bi);
    }
}

void jit_brgemm_amx_uker_base_t::bs_loop_body(brgemm_iteration_t &bi) {
    if (brg.brgattr.var_bs) {
        set_A_B_matrices();
        add(reg_aux1_batch, sizeof(brgemm_batch_element_t));
        prefetcht0(ptr[reg_aux1_batch]);
    } else {
        set_A_B_matrices(bi.bsi.pos);
    }

    rdb_loop(bi);
}

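// Iterates over the batch (BS) dimension. With var_bs the batch size is only
// known at run time, so separate code paths are generated for the first,
// middle and last iterations; the is_first / is_last flags let the inner
// loops decide when interleaved tile stores may happen. For a fixed batch
// size the loop is fully unrolled up to brg.brgattr.max_bs. Accumulators are
// loaded before the batch loop and stored after it; alpha == 0 skips the
// multiplication entirely.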
void jit_brgemm_amx_uker_base_t::bs_loop(brgemm_iteration_t &bi) {

    load_accumulators(bi);

    if (brg.brgattr.var_bs) {
        if (brg.alpha != 0.f) {
            Label BS_loop_label, end_BS_loop_label, first_BS_loop_label,
                    last_BS_loop_label;

            mov(reg_BS_loop, reg_BS);
            cmp(reg_BS_loop, 0);
            jz(end_BS_loop_label, T_NEAR);

            mov(reg_aux1_batch, reg_addr_batch);
            // if BS > 1 take the general first/middle/last path; otherwise
            // fall through to the single-iteration case below
            cmp(reg_BS_loop, 1);
            jg(first_BS_loop_label, T_NEAR);

            bi.bsi = imap_.bsis[0];
            // only one BS iteration: it is both the first and the last one
            bi.bsi.is_first = true;
            bi.bsi.is_last = true;
            bs_loop_body(bi);
            jmp(end_BS_loop_label, T_NEAR);

            // first BS iteration
            L_aligned(first_BS_loop_label, 64);
            bi.bsi.is_first = true;
            bi.bsi.is_last = false;
            bs_loop_body(bi);

            dec(reg_BS_loop);
            cmp(reg_BS_loop, 1);
            je(last_BS_loop_label, T_NEAR);

            // middle BS iterations
            L_aligned(BS_loop_label, 64);
            {
                bi.bsi.is_first = false;
                bi.bsi.is_last = false;
                bs_loop_body(bi);
                dec(reg_BS_loop);
                cmp(reg_BS_loop, 1);
                jg(BS_loop_label, T_NEAR);
            }
            // last BS iteration
            L_aligned(last_BS_loop_label, 64);
            bi.bsi.is_first = false;
            bi.bsi.is_last = true;
            bs_loop_body(bi);

            L_aligned(end_BS_loop_label, 64);
        }
        store_accumulators(bi);
    } else {
        if (brg.alpha != 0.f) {
            for (int bs = 0; bs < brg.brgattr.max_bs; bs++) {
                bi.bsi = imap_.bsis[bs];
                bs_loop_body(bi);
            }
        }
        store_accumulators(bi);
    }
}

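// The bd and ld loops can be nested in either order: brgattr.hint_innermost_loop
// selects which of the two dimensions is iterated closest to the bs/rd loops,
// and the *_loop_body helpers below dispatch either to the other dimension's
// loop or directly to bs_loop accordingly.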
void jit_brgemm_amx_uker_base_t::ldb_loop_body(brgemm_iteration_t &bi) {
    if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost)
        bdb_loop(bi);
    else if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost)
        bs_loop(bi);
    else
        assert(!"Unknown loop order!");
}

void jit_brgemm_amx_uker_base_t::ldb_loop(brgemm_iteration_t &bi) {
    // Clear the transform cache for A: the cached data becomes invalid once
    // we move to the next bdb2 block.
    transform_buf_map_A_.clear();
    bi.ldi = imap_.ldis[0];
    for (size_t ildi = 0; ildi < imap_.ldis.size(); ildi++) {
        bi.ldi = imap_.ldis[ildi];
        ldb_loop_body(bi);
    }
}

void jit_brgemm_amx_uker_base_t::bdb_loop_body(brgemm_iteration_t &bi) {
    if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost)
        ldb_loop(bi);
    else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost)
        bs_loop(bi);
    else
        assert(!"Unknown loop order!");
}

void jit_brgemm_amx_uker_base_t::bdb_loop(brgemm_iteration_t &bi) {
    bi.bdi = imap_.bdis[0];

    for (size_t ibdi = 0; ibdi < imap_.bdis.size(); ibdi++) {
        bi.bdi = imap_.bdis[ibdi];
        bdb_loop_body(bi);
    }
}

void jit_brgemm_amx_uker_base_t::top_loop(brgemm_iteration_t &bi) {
    init(bi);
    if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost)
        bdb_loop(bi);
    else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost)
        ldb_loop(bi);
    else
        assert(!"Unknown loop order!");

    if (brg.interleave_tilestores_) {
        // store the remaining C tiles: the last iteration has no following
        // iteration to pre-store them
        bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2);
        for_(int bdb = 0; bdb < brg.bd_block2; bdb++)
        for (int ldb = 0; ldb < brg.ld_block2; ldb++) {
            maybe_tilestore(bi, bdb, ldb, true, false);
        }
    }

    interleave_store(bi, true);
}

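// Builds the iteration map (imap_): precomputed descriptors of all bd, ld, rd
// and bs iterations, including their tail variants, with bd positions already
// adjusted for the bd mask (skipped_bd_mask). The loops above simply walk
// these vectors, so every blocking parameter is known at code-generation time.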
void jit_brgemm_amx_uker_base_t::fill_imap() {
    imap_.bdis.clear();
    imap_.ldis.clear();
    imap_.rdis.clear();
    imap_.bsis.clear();
    size_t bdi_pos = skipped_bd_mask(0);
    auto bdi = bd_iteration_t(bdi_pos, brg.bd_block, brg.bd_block2);
    // full bd_block2 iterations
    for (int bdb2 = 0; bdb2 < brg.bdb2; bdb2++) {
        bdi.pos = bdi_pos;
        bdi.bdb_pos.clear();
        for (int bdb = 0; bdb < bdi.block2; bdb++) {
            bdi.bdb_pos.push_back(bdi_pos);
            bdi_pos += brg.bd_block;
            bdi_pos = skipped_bd_mask(bdi_pos);
        }
        bdi.idx = imap_.bdis.size();
        imap_.bdis.push_back(bdi);
    }
    // tail of the bd_block2 blocking
    if (brg.bdb2_tail > 0) {
        bdi.block2 = brg.bdb2_tail;
        bdi.pos = bdi_pos;
        bdi.bdb_pos.clear();
        for (int bdb = 0; bdb < bdi.block2; bdb++) {
            bdi.bdb_pos.push_back(bdi_pos);
            bdi_pos += brg.bd_block;
            bdi_pos = skipped_bd_mask(bdi_pos);
        }
        bdi.idx = imap_.bdis.size();
        imap_.bdis.push_back(bdi);
    }
    // bd_block tail
    if (brg.bdb_tail > 0) {
        bdi.block2 = 1;
        bdi.block = brg.bdb_tail;
        bdi.is_tail = true;
        bdi.pos = bdi_pos;
        bdi.bdb_pos.clear();
        bdi.bdb_pos.push_back(bdi_pos);
        bdi.idx = imap_.bdis.size();
        imap_.bdis.push_back(bdi);
    }

    auto ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2);
    for (int ldb2 = 0; ldb2 < brg.ldb2; ldb2++) {
        ldi.idx = imap_.ldis.size();
        imap_.ldis.push_back(ldi);
        ldi.pos += ldi.block2;
    }
    if (brg.ldb2_tail > 0) {
        ldi.block2 = brg.ldb2_tail;
        ldi.idx = imap_.ldis.size();
        imap_.ldis.push_back(ldi);
        ldi.pos += ldi.block2;
    }
    if (brg.ldb_tail > 0) {
        ldi.block2 = 1;
        ldi.block = brg.ldb_tail;
        ldi.is_tail = true;
        ldi.idx = imap_.ldis.size();
        imap_.ldis.push_back(ldi);
    }

    auto rdi = dim_iteration_t(0, brg.rd_block, 1);
    for (int rdb = 0; rdb < brg.rdb; rdb++) {
        rdi.idx = imap_.rdis.size();
        imap_.rdis.push_back(rdi);
        rdi.pos++;
    }
    if (brg.rdb_tail > 0) {
        rdi.block = brg.rdb_tail;
        rdi.is_tail = true;
        rdi.idx = imap_.rdis.size();
        imap_.rdis.push_back(rdi);
    }

    bs_iteration_t bsi;
    for (int bs = 0; bs < brg.brgattr.max_bs; bs++) {
        bsi.pos = bs;
        bsi.is_first = (bs == 0);
        bsi.is_last = (bs == brg.brgattr.max_bs - 1);
        bsi.idx = imap_.bsis.size();
        imap_.bsis.push_back(bsi);
    }
}

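// Per-call initialization emitted at the top of top_loop: loads the A/B base
// pointers when max_bs == 1, decides whether the post-ops registers can be
// prepared once per call (only a single ld iteration exists), determines
// whether the accumulated values need saturation for the destination data
// type, fills the iteration map and configures the prefetch descriptors.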
void jit_brgemm_amx_uker_base_t::init(brgemm_iteration_t &bi) {
    was_prev_bi = false;

    if (brg.brgattr.max_bs == 1) {
        if (brg.layout == brgemm_row_major) {
            mov(reg_aux_A,
                    EVEX_compress_addr(
                            reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.A)));
            mov(reg_aux_B,
                    EVEX_compress_addr(
                            reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.B)));
        } else {
            mov(reg_aux_A,
                    EVEX_compress_addr(
                            reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.B)));
            mov(reg_aux_B,
                    EVEX_compress_addr(
                            reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.A)));
        }
    }

    // For many primitives that use brgemm, brg.ldb2 is less than or equal
    // to 1, so the post-ops data can be read only once per brgemm call.
    if (brg.ldb2 > 1) {
        prepare_post_ops_registers_once_ = false;
    } else if (brg.ldb2 == 1) {
        if (brg.ldb2_tail == 0 && brg.ldb_tail == 0) {
            prepare_post_ops_registers_once_ = true;
            bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2);
            prepare_post_ops_registers(bi);
        }
    } else if (brg.ldb2_tail > 0) {
        if (brg.ldb_tail == 0) {
            prepare_post_ops_registers_once_ = true;
            bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ldb2_tail);
            prepare_post_ops_registers(bi);
        }
    } else {
        prepare_post_ops_registers_once_ = true;
        bi.ldi = dim_iteration_t(0, brg.ldb_tail, 1);
        bi.ldi.is_tail = true;
        prepare_post_ops_registers(bi);
    }
    if (bi.apply_postops)
        dt_requires_saturation_ = one_of(
                brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
    else {
        // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
        // accumulated values are converted to ps in apply_alpha_beta()
        const bool alpha_or_beta_applicable
                = brg.alpha != 1.0f || brg.beta != 0.f;
        const bool beta_uses_vadd = brg.beta == 1.f
                && IMPLICATION(brg.is_int8, brg.alpha == 1.0f);
        dt_requires_saturation_ = brg.is_int8
                && !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);
    }
    if (dt_requires_saturation_) {
        init_saturate_f32(
                zmm_lbound, zmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
    }

    fill_imap();
    prf1A.pft = brgemm_kernel_prefetching_t::brgemm_prf1;
    prf1A.dist = brg.prfA.dist1;
    prf2A.pft = brgemm_kernel_prefetching_t::brgemm_prf2;
    prf2A.dist = brg.prfA.dist2;
    prf1B.pft = brgemm_kernel_prefetching_t::brgemm_prf1;
    prf1B.dist = brg.prfB.dist1;
    prf2B.pft = brgemm_kernel_prefetching_t::brgemm_prf2;
    prf2B.dist = brg.prfB.dist2;
    prf1C.pft = brgemm_kernel_prefetching_t::brgemm_prf1;
    prf1C.dist = brg.prfC.dist1;
    prf2C.pft = brgemm_kernel_prefetching_t::brgemm_prf2;
    prf2C.dist = brg.prfC.dist2;
}

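// Kernel entry point: sets up the leading-dimension sizes and load masks and
// then emits up to two copies of the top-level loop nest (with and without
// post-ops applied); the run-time flag do_post_ops in brgemm_kernel_params_t
// selects which path is executed.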
void jit_brgemm_amx_uker_base_t::generate() {
    preamble();

    sub(rsp, stack_space_needed_);

    const auto full_mask = size_t {0xffffffffffffffff};
    const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
    LDA_size_ = brg.typesize_A * brg.LDA;
    LDB_size_ = brg.typesize_B * brg.LDB;
    LDC_size_ = brg.typesize_C * brg.LDC;
    LDD_size_ = brg.typesize_D * brg.LDD;

    LDA2_size_ = brg.typesize_A * brg.LDA2;
    LDB2_size_ = brg.typesize_B * brg.LDB2;
    LDC2_size_M_ = brg.typesize_C * brg.LDC2_M;
    LDC2_size_N_ = brg.typesize_C * brg.LDC2_N;

    ld_block_B_size_ = brg.typesize_B
            * ((brg.brgattr.LDB2 != 0) ? brg.brgattr.LDB2 : brg.ld_block);
    ld_block_C_size_ = brg.typesize_C * brg.ld_block;
    ld_block_D_size_ = brg.typesize_D * brg.ld_block;
    ld_block_bias_size_ = brg.typesize_bias * brg.ld_block;
    ld_block_scales_size_ = sizeof(float) * brg.ld_block;
    ld_block_zp_size_ = sizeof(int32_t) * brg.ld_block;
    ldb_tail_B_size_ = brg.typesize_B * brg.ldb_tail;
    ldb_tail_C_size_ = brg.typesize_C * brg.ldb_tail;
    ldb_tail_D_size_ = brg.typesize_D * brg.ldb_tail;
    ldb_tail_zp_size_ = sizeof(int32_t) * brg.ldb_tail;

    // If beta == 1 and the C data type matches the accumulator data type, it
    // is better to perform the addition by loading the tiles directly from C
    // instead of reading/writing C by vectors.
    may_load_accumulators_ = one_of(brg.alpha, 0, 1) && brg.beta == 1.f
            && brg.dt_c == brg.dt_d && !brg.is_bf32
            && IMPLICATION(
                    brg.is_f32 || brg.is_bf16, brg.dt_c == data_type::f32)
            && IMPLICATION(brg.is_int8, brg.dt_c == data_type::s32);
    need_to_apply_alpha_beta_
            = (brg.beta != 0.f && !may_load_accumulators_) || brg.alpha != 1.f;
    const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none,
            brg.zp_type_a, brg.zp_type_b, brg.zp_type_c);
    are_post_ops_applicable_ = one_of(true, brg.with_eltwise, brg.with_binary,
            brg.with_scales, brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c,
            has_zero_points);

    // For now, second-level blocking (and var_bs) is eligible only when C is
    // not stored by vectors, i.e. no post-ops, no alpha/beta and no bd mask.
    assert(IMPLICATION(are_post_ops_applicable_ || need_to_apply_alpha_beta_
                    || brg.brgattr.bd_mask_level,
            !brg.is_blocked && !brg.brgattr.var_bs));
    assert(IMPLICATION(brg.brgattr.var_bs, !brg.is_bf32));
    read_params();
    prepare_bd_mask();
    Label permute_index_table;
    if (brg.is_bf32) {
        brgemm_init_tiles(brg, (char *)(&palette_));
        mov(reg_tmp_gpr, permute_index_table);
        vmovups(zmm_bf32_pemute, ptr[reg_tmp_gpr]);
    }

    reg64_t reg_mask = rax;

    mov(reg_mask, full_mask);
    kmovq(ld_full_mask, reg_mask);
    mov(reg_mask, tail_mask);
    kmovq(ld_tail_mask, reg_mask);

    mov(reg_stride_lda, lda());
    mov(reg_stride_ldb, ldb());

    bool non_postops_generate
            = !are_post_ops_applicable_ || !brg.brgattr.postops_only;
    brgemm_iteration_t bi;

    Label label_to_ret;
    if (are_post_ops_applicable_) {
        Label label_store_without_post_ops;
        mov(reg_do_post_ops, ptr[param1 + GET_OFF(do_post_ops)]);
        cmp(reg_do_post_ops, 0);
        jz(label_store_without_post_ops, T_NEAR);
        bi.apply_postops = true;
        top_loop(bi);
        if (non_postops_generate) jmp(label_to_ret, T_NEAR);
        transform_buf_map_A_.clear();
        transform_buf_map_B_.clear();
        L(label_store_without_post_ops);
    }
    if (non_postops_generate) {
        bi.apply_postops = false;
        top_loop(bi);
    }
    L(label_to_ret);

    add(rsp, stack_space_needed_);

    postamble();

    if (brg.with_eltwise) postops_injector_->prepare_table();

    if (brg.is_bf32) {
        align(64);
        L(permute_index_table);
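        // Word-interleave pattern for the vpermw in bf32_downconvert_to_vnni:
        // it zips word lanes 0..15 (row 2r) with lanes 16..31 (row 2r + 1) so
        // that each dword of the result holds a vertical bf16 pair, i.e. the
        // vnni layout.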
        const uint16_t _idx[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6,
                22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30,
                15, 31};
        for (size_t i = 0; i < 32; ++i)
            dw(_idx[i]);
    }
}

brgemm_amx_uker_t::brgemm_amx_uker_t(const brgemm_t abrd) {
    brgemm_kernel_ = new jit_brgemm_amx_uker_base_t(abrd);
}

status_t brgemm_amx_uker_t::create_kernel() {
    return brgemm_kernel_->create_kernel();
}

void brgemm_amx_uker_t::operator()(brgemm_kernel_params_t *params) const {
    (*brgemm_kernel_)(params);
}

brgemm_amx_uker_t::~brgemm_amx_uker_t() {
    delete brgemm_kernel_;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s