/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace Xbyak;
template <cpu_isa_t isa, typename Wmm>
struct jit_brgemm_kernel_t : public jit_generator {
    jit_brgemm_kernel_t(const brgemm_t &abrg)
        : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, abrg.isa_impl)
        , brg(abrg)
        , postops_injector_(nullptr) {

        // The implementation uses the is_superset() and is_subset() utilities,
        // so isa_all and isa_undef must be avoided in these comparisons.
        assert(!utils::one_of(brg.isa_impl, isa_all, isa_undef));
        const int is_ldb2_tail = brg.ldb2_tail ? 1 : 0;
        const int is_ldb_tail = brg.ldb_tail ? 1 : 0;
        is_ldb_loop_ = brg.ldb2 + is_ldb2_tail + is_ldb_tail > 1;

        if (brg.with_eltwise || brg.with_binary || brg.with_sum) {

            static constexpr bool preserve_gpr = true;
            static constexpr bool preserve_vmm = true;
            static constexpr bool use_exact_tail_scalar_bcast = false;
            const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md);

            static const bcast_set_t enabled_bcast_strategy
                    = {broadcasting_strategy_t::scalar,
                            broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast};
            const binary_injector::rhs_arg_static_params_t rhs_sp {
                    static_cast<size_t>(Vmm(1).getIdx()), this->r14, this->r15,
                    this->r13, preserve_gpr, preserve_vmm,
                    GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
                    dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
                    ld_tail_mask, use_exact_tail_scalar_bcast};
            const binary_injector::static_params_t bsp {
                    this->param1, enabled_bcast_strategy, rhs_sp};

            postops_injector_ = utils::make_unique<po_injector_t>(
                    this, brg.attr->post_ops_, bsp);

            using namespace dnnl::impl::cpu::binary_injector_utils;
            std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
                    with_binary_channel_bcast_, with_binary_per_mb_w_bcast_,
                    with_binary_per_w_bcast_, with_binary_no_bcast_)
                    = bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
                            dst_md_wrapper, broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast);
            handle_binary_po_offset_ = with_binary_per_oc_bcast_
                    || with_binary_per_oc_sp_bcast_
                    || with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
                    || with_binary_per_w_bcast_ || with_binary_no_bcast_;
        }
        if (brg.is_bf16_emu)
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1(), bf16_emu_reserv_2(),
                    bf16_emu_reserv_3(), bf16_emu_scratch, bf16_emu_reserv_4(),
                    bf16_emu_reserv_4());
    }

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_t)

    brgemm_t brg;

private:
    using Vmm =
            typename utils::conditional<std::is_same<Wmm, Xbyak::Tmm>::value,
                    Xbyak::Zmm, Wmm>::type;
    using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t;
    static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core,
            avx512_core_amx_fp16, avx512_core_fp16, avx512_core_amx,
            avx512_core_fp16, avx512_core_fp16, avx512_core_fp16, avx2_vnni_2,
            avx2_vnni_2, avx2_vnni, avx2, avx2, avx2);
    using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_t, Vmm>;
    std::unique_ptr<po_injector_t> postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    using reg64_t = const Xbyak::Reg64;

    // Register decomposition
    const reg64_t param1 = abi_param1;

    const reg64_t reg_C = r15;
    const reg64_t reg_aux_C = r14;

    const reg64_t reg_addr_batch = r13;
    const reg64_t reg_A = r13;
    const reg64_t reg_B = r12;

    const reg64_t reg_aux_A = r11;
    const reg64_t reg_aux_B = r10;
    const reg64_t reg_aux_A_vpad = reg_aux_A;

    const reg64_t reg_bdb_loop = r9;
    const reg64_t reg_ldb_loop = r8;

    const reg64_t reg_stride_lda = reg_bdb_loop;
    const reg64_t reg_stride_ldb = reg_ldb_loop;
    const reg64_t reg_stride_ld_block = reg_ldb_loop;
    const reg64_t reg_s8_input_shift = reg_bdb_loop;
    const reg64_t reg_zp_a_input_shift = reg_bdb_loop;

    const reg64_t reg_BS_loop = rax;
    const reg64_t reg_rdb_loop = rbx;
    const reg64_t reg_BS = abi_not_param1;

    const reg64_t reg_a_offset = rdx;
    const reg64_t reg_b_offset = rsi;

    const reg64_t reg_aux1_batch = rbp;
    const reg64_t reg_aux1_A = rbp;
    const reg64_t reg_aux1_B = abi_param1;

    const reg64_t reg_offs_batch = reg_aux1_A;
    const reg64_t reg_strd_batch = reg_rdb_loop;

    const reg64_t reg_bias = reg_rdb_loop;
    const reg64_t reg_scales = reg_rdb_loop;
    const reg64_t reg_aux_bias = reg_rdb_loop;
    const reg64_t reg_binary_postops_oc_l = reg_rdb_loop;
    const reg64_t reg_aux_binary_postops_oc_l = reg_rdb_loop;
    const reg64_t reg_aux_binary_postops_sp = reg_rdb_loop;
    const reg64_t reg_binary_po_stack_frame = reg_rdb_loop;
    const reg64_t reg_zp_comp_a = reg_rdb_loop;
    const reg64_t reg_aux_zp_comp_a = reg_rdb_loop;
    const reg64_t reg_zp_comp_b = reg_rdb_loop;
    const reg64_t reg_aux_zp_comp_b = reg_rdb_loop;
    const reg64_t reg_zp_c_values = reg_rdb_loop;
    const reg64_t reg_aux_zp_c_values = reg_rdb_loop;

    const reg64_t reg_aux_scales = reg_aux_B;
    const reg64_t reg_do_post_ops = reg_rdb_loop;
    const reg64_t reg_do_comp = reg_rdb_loop;
    const reg64_t reg_skip_accm = reg_rdb_loop;
    const reg64_t reg_tmp_gpr = reg_rdb_loop;
    const reg64_t reg_ptr_sum_scale = reg_rdb_loop;
    const reg64_t reg_ptr_sum_zp = reg_bdb_loop;
    const reg64_t reg_zp_a_val = reg_rdb_loop;

    const reg64_t reg_buf = reg_rdb_loop;
    const reg64_t reg_compensation = reg_bias;
    const reg64_t reg_aux_compensation = reg_aux_bias;

    const reg64_t reg_D = reg_aux_A;
    const reg64_t reg_aux_D = reg_BS_loop;

    /* bf16 emulation */
    const reg64_t bf16_emu_scratch = reg_rdb_loop;

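    // Note: most of the post-op and auxiliary pointer registers above alias
    // reg_rdb_loop (or another loop register), so their values are kept in
    // the stack slots defined below and reloaded right before use.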
    constexpr static int origin_offs_batch_offs_ = 0;
    constexpr static int origin_strd_batch_offs_ = 0;
    constexpr static int reg_bias_offs_ = 8;
    constexpr static int reg_aux_bias_offs_ = 16;
    constexpr static int reg_do_post_ops_offs_ = 24;
    constexpr static int reg_D_offs_ = 32;
    constexpr static int reg_aux_D_offs_ = 40;
    constexpr static int reg_scales_offs_ = 48;
    constexpr static int reg_aux_scales_offs_ = 56;
    constexpr static int reg_bdb_loop_offs_ = 64;
    constexpr static int reg_ldb_loop_offs_ = 72;
    constexpr static int reg_buf_offs_ = 80;
    constexpr static int reg_comp_offs_ = reg_buf_offs_;
    constexpr static int reg_aux_comp_offs_ = 88;
    constexpr static int abi_param1_offs_ = 96;
    constexpr static int reg_binary_postops_oc_l_offs_ = 104;
    constexpr static int reg_aux_binary_postops_oc_l_offs_ = 112;
    constexpr static int reg_binary_postops_sp_offs_ = 120;
    constexpr static int reg_aux_binary_postops_sp_offs_ = 128;
    constexpr static int reg_zp_comp_a_offs_ = 136;
    constexpr static int reg_aux_zp_comp_a_offs_ = 144;
    constexpr static int reg_zp_comp_b_offs_ = 152;
    constexpr static int reg_aux_zp_comp_b_offs_ = 160;
    constexpr static int reg_zp_c_values_offs_ = 168;
    constexpr static int reg_aux_zp_c_values_offs_ = 176;
    constexpr static int reg_data_C_ptr_ = 184;
    constexpr static int reg_skip_accm_offs_ = 192;
    constexpr static int reg_zp_a_val_offs_ = 200;
    constexpr static int reg_do_comp_offs_ = 208;
    constexpr static int stack_space_needed_ = 216;

    bool is_ldb_loop_ = false;
    bool handle_binary_po_offset_ = false;
    bool with_binary_per_oc_bcast_ = false;
    bool with_binary_per_oc_sp_bcast_ = false;
    bool with_binary_channel_bcast_ = false;
    bool with_binary_per_mb_w_bcast_ = false;
    bool with_binary_per_w_bcast_ = false;
    bool with_binary_no_bcast_ = false;
    constexpr static int max_vregs = cpu_isa_traits<po_isa_t>::n_vregs;

    Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
    Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3);

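    // Vmm allocation: accumulators occupy the top of the register file
    // (max_vregs - 1 downwards); broadcast/load registers come either from
    // Vmm(0) or from just below the accumulator block, depending on whether
    // one broadcast is reused across several loads (n_bcast_1_load).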
    Vmm accm(int ld_block, int bd, int ld) {
        return Vmm(max_vregs - 1 - (bd * ld_block + ld));
    }

    Vmm bcst(int bd = 0) {
        if (n_bcast_1_load) {
            int idx = max_vregs - 1 - (brg.ld_block2 * brg.bd_block) - bd;
            assert(idx > 0);
            return Vmm(idx);
        } else
            return Vmm(0);
    }

    Vmm load(int ld = 0) {
        if (n_bcast_1_load) {
            return Vmm(0);
        } else {
            int idx = max_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld;
            assert(idx > 0);
            return Vmm(idx);
        }
    }

    Vmm vmm_tmp_1() const noexcept { return Vmm(0); }
    Vmm vmm_tmp_2() const noexcept { return Vmm(1); }
    Vmm vmm_tmp_3() const noexcept { return Vmm(2); }
    Vmm vmm_one_bytes() const noexcept { return Vmm(3); }
    Vmm vmm_zp_a_shift() const noexcept { return Vmm(2); }
    Vmm vmm_inp_shift() const noexcept { return Vmm(1); }

    /* bf16 emulation */
    Zmm bf16_emu_reserv_1() const noexcept { return Zmm(0); }
    Zmm bf16_emu_reserv_2() const noexcept { return Zmm(1); }
    Zmm bf16_emu_reserv_3() const noexcept { return Zmm(2); }
    Zmm bf16_emu_reserv_4() const noexcept { return Zmm(3); }
    // note: zmm reserv_5 is not necessary since it's only used for 'vdpbf16ps'

    Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
            Xbyak::Opmask ktail_mask) const;
    Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag,
            bool store, Xbyak::Opmask ktail_mask) const;

    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag, bool store, Xbyak::Opmask ktail_mask,
            int tail_size);

    void advance_ldb_post_op_regs();
    void restore_ldb_post_op_regs(int ld_block2);
    void advance_bdb_post_op_regs(int adj_bd_block);
    void restore_bdb_post_op_regs(int bd_block2);
    void ldb_regs_shift(int ld_block2, bool is_tail = false);
    void advance_bd_block2_post_op_regs(int bd_block2);

    void copy_post_ops_stack_values_to_aux(bool is_reg_tail);
    void read_params();
    void zero_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_ld_tail, bool skip_accumulation);

    void store_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_ld_tail, bool skip_accumulation);
    void store_accumulators_without_post_ops(
            int bd_block, int ld_block, bool is_ld_tail);
    void store_accumulators_apply_post_ops(int bd_block, int ld_block,
            int ldb_and_bdb_offset, bool is_ld_tail);
    void apply_compensation(int bd_block, int ld_block, bool is_ld_tail);
    void apply_alpha_beta(int bd_block, int ld_block, bool is_ld_tail);
    void apply_post_ops(int bd_block, int ld_block2, int ldb_and_bdb_offset,
            bool is_ld_tail);
    void restore_A_B_matrices();
    void set_A_B_matrices();

    void gemm_microkernel(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_rd_tail, bool is_ld_tail, int vpad, int rows_for_rd_tail);
    void gemm_microkernel_amx(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_rd_tail, bool is_ld_tail);

    void ldb_loop(int bd_block2, bool is_bdb_tail, int ld_block,
            int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
            bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
            bool skip_accumulation);
    void bdb_loop();

    void generate() override;

    int A_offset(int bd, int rd, bool is_amx = false) const noexcept;
    int B_offset(int ld, int rd, bool is_amx = false) const noexcept;
    int C_offset(int bd, int ld) const noexcept;
    int D_offset(int bd, int ld) const noexcept;
    int po_offset(int bd, int ld) const noexcept;

    int rdb_A_offset() const noexcept;
    int rdb_B_offset() const noexcept;

    int ldb_B_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_C_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_D_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_po_offset(int ld_block2, bool is_tail = false) const noexcept;

    int bdb_A_offset(int bd_block2) const noexcept;
    int bdb_C_offset(int bd_block2) const noexcept;
    int bdb_D_offset(int bd_block2) const noexcept;
    int bdb_po_offset(int bd_block2) const noexcept;

    int bias_offset(int ld, bool is_tail = false) const noexcept;
    int oc_logical_offset(int ld, bool is_tail = false) const noexcept;

    int compensations_offset(int ld, bool is_tail = false) const noexcept;
    int bdb_compensation_offset(int bd_block2) const noexcept;
    int compensation_vpad_offset(int ld, int bd) const noexcept;
    int scales_offset(int ld, bool is_tail = false) const noexcept;
    int zp_comp_a_offset(int ld, bool is_tail = false) const noexcept;
    int zp_comp_a_vpad_offset(int ld, int bd) const noexcept;
    int bdb_zp_comp_a_offset(int bd_block2) const noexcept;
    int zp_comp_b_offset(int bd) const noexcept;
    int bdb_zp_comp_b_offset(int bd_block2) const noexcept;
    int zp_c_values_offset(int ld, bool is_tail = false) const noexcept;

    bool n_bcast_1_load = false;
    bool vpad_exist = false;
    bool need_comp_pads = false;
};

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::A_offset(int bd, int rd, bool is_amx) const
        noexcept {
    return (is_amx) ? brg.typesize_A * (bd * brg.bd_block * brg.LDA)
                    : brg.typesize_A * (bd * brg.LDA + rd);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::B_offset(int ld, int rd, bool is_amx) const
        noexcept {
    if (is_amx) {
        return brg.typesize_B * (brg.rd_step * ld * brg.ld_block);
    } else {
        const int data_vnni_granularity = brg.ld_step;
        const int rdb0 = rd / data_vnni_granularity;
        // Note: offsets for elements within vnni_granularity are expected to
        // be handled within gemm_microkernel (e.g. odd-even converts), hence
        // no `rd % data_vnni_granularity` here.
        return brg.typesize_B
                * (rdb0 * data_vnni_granularity * brg.LDB
                        + data_vnni_granularity * ld * brg.ld_block);
    }
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::C_offset(int bd, int ld) const noexcept {
    return brg.typesize_C * (bd * brg.LDC + ld * brg.ld_block);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::D_offset(int bd, int ld) const noexcept {
    return brg.typesize_D * (bd * brg.LDD + ld * brg.ld_block);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::po_offset(int bd, int ld) const noexcept {
    return bd * brg.LDD + ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::rdb_A_offset() const noexcept {
    return brg.typesize_A * brg.rd_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::rdb_B_offset() const noexcept {
    return brg.typesize_B * brg.rd_block * brg.LDB;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::ldb_B_offset(
        int ld_block2, bool is_tail) const noexcept {
    return (is_tail) ? brg.typesize_B * brg.ldb_tail * brg.ld_step
                     : brg.typesize_B * ld_block2 * brg.ld_block * brg.ld_step;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::ldb_C_offset(
        int ld_block2, bool is_tail) const noexcept {
    return (is_tail) ? brg.typesize_C * brg.ldb_tail
                     : brg.typesize_C * ld_block2 * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::ldb_D_offset(
        int ld_block2, bool is_tail) const noexcept {
    return (is_tail) ? brg.typesize_D * brg.ldb_tail
                     : brg.typesize_D * ld_block2 * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::ldb_po_offset(
        int ld_block2, bool is_tail) const noexcept {
    return (is_tail) ? brg.ldb_tail : ld_block2 * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_A_offset(int bd_block2) const noexcept {
    return brg.typesize_A * bd_block2 * brg.bd_block * brg.LDA;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_C_offset(int bd_block2) const noexcept {
    return brg.typesize_C * bd_block2 * brg.bd_block * brg.LDC;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_D_offset(int bd_block2) const noexcept {
    return brg.typesize_D * bd_block2 * brg.bd_block * brg.LDD;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_po_offset(int bd_block2) const noexcept {
    return bd_block2 * brg.bd_block * brg.LDD;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bias_offset(int ld, bool is_tail) const
        noexcept {
    return (is_tail) ? brg.typesize_bias * brg.ldb_tail
                     : brg.typesize_bias * ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::oc_logical_offset(int ld, bool is_tail) const
        noexcept {
    return (is_tail) ? brg.ldb_tail : ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::compensations_offset(
        int ld, bool is_tail) const noexcept {
    return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
                     : sizeof(int32_t) * ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_compensation_offset(int bd_block2) const
        noexcept {
    return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::compensation_vpad_offset(
        int ld, int bd) const noexcept {
    return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::scales_offset(int ld, bool is_tail) const
        noexcept {
    return (is_tail) ? brg.is_oc_scale * sizeof(float) * brg.ldb_tail
                     : brg.is_oc_scale * sizeof(float) * ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_a_offset(int ld, bool is_tail) const
        noexcept {
    return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
                     : sizeof(int32_t) * ld * brg.ld_block;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_zp_comp_a_offset(int bd_block2) const
        noexcept {
    return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_a_vpad_offset(int ld, int bd) const
        noexcept {
    return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_b_offset(int bd) const noexcept {
    return sizeof(int32_t) * bd;
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::bdb_zp_comp_b_offset(int bd_block2) const
        noexcept {
    return zp_comp_b_offset(bd_block2 * brg.bd_block);
}

template <cpu_isa_t isa, typename Wmm>
int jit_brgemm_kernel_t<isa, Wmm>::zp_c_values_offset(
        int ld, bool is_tail) const noexcept {
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
                         : sizeof(int32_t) * ld * brg.ld_block;
    }

    return 0;
}
template <cpu_isa_t isa, typename Wmm>
typename jit_brgemm_kernel_t<isa, Wmm>::Vmm
jit_brgemm_kernel_t<isa, Wmm>::vmm_mask(const Vmm vmm_in, bool mask_flag,
        bool store, Xbyak::Opmask ktail_mask) const {
    return mask_flag && is_superset(brg.isa_impl, avx512_core)
            ? (store ? vmm_in | ktail_mask : vmm_in | ktail_mask | T_z)
            : vmm_in;
}

template <cpu_isa_t isa, typename Wmm>
typename jit_brgemm_kernel_t<isa, Wmm>::Vmm_lower_t
jit_brgemm_kernel_t<isa, Wmm>::vmm_lower_mask(const Vmm_lower_t vmm_lower_in,
        bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const {
    return mask_flag && is_superset(brg.isa_impl, avx512_core)
            ? (store ? vmm_lower_in | ktail_mask
                     : vmm_lower_in | ktail_mask | T_z)
            : vmm_lower_in;
}

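// Loads a value from `op` into `vmm_in` and converts it to f32. On AVX-512 a
// tail is handled with the opmask; otherwise a partial load via load_data()
// with an explicit tail_size is used.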
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag, bool store,
        Xbyak::Opmask ktail_mask, int tail_size) {
    Vmm vmm = vmm_in;
    const bool has_tail
            = op.isMEM() && tail_size != vreg_traits<Vmm>::vlen / sizeof(float);
    if (IMPLICATION(has_tail, is_superset(brg.isa_impl, avx512_core))) {
        vmm = vmm_mask(vmm_in, mask_flag, store, ktail_mask);
    } else {
        uni_vpxor(vmm_in, vmm_in, vmm_in);
        load_data(type_in, vmm_in, op.getAddress(), tail_size);
        if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
        return;
    }
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: uni_vmovups(vmm, op); break;
        case data_type::bf16:
            uni_vpmovzxwd(vmm, op);
            uni_vpslld(vmm, vmm, 16);
            break;
        case data_type::f16: vcvtph2ps(vmm, op); break;
        case data_type::s8: uni_vpmovsxbd(vmm, op); break;
        case data_type::u8: uni_vpmovzxbd(vmm, op); break;
        default: assert(!"unsupported data type");
    }
    if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::advance_ldb_post_op_regs() {
    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        add(reg_aux_bias, bias_offset(1));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        add(reg_aux_scales, scales_offset(1));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        add(reg_aux_binary_postops_oc_l, oc_logical_offset(1));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        add(reg_aux_zp_comp_a, zp_comp_a_offset(1));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        add(reg_aux_zp_c_values, zp_c_values_offset(1));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::restore_ldb_post_op_regs(int ld_block2) {
    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        sub(reg_aux_bias, bias_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        sub(reg_aux_scales, scales_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        sub(reg_aux_binary_postops_oc_l, oc_logical_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        sub(reg_aux_zp_comp_a, zp_comp_a_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        sub(reg_aux_zp_c_values, zp_c_values_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::advance_bdb_post_op_regs(int adj_bd_block) {
    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
        add(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(1));
        mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
    }
    if (with_binary_per_oc_sp_bcast_) {
        const injector_utils::register_preserve_guard_t register_guard(
                this, {reg_aux_binary_postops_oc_l});
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                        + register_guard.stack_space_occupied()]);
        add(reg_aux_binary_postops_oc_l, adj_bd_block);
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                        + register_guard.stack_space_occupied()],
                reg_aux_binary_postops_oc_l);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::restore_bdb_post_op_regs(int bd_block2) {
    bool post_processed = false;
    if (bd_block2 > 1) {
        if (brg.zp_type_b != brgemm_broadcast_t::none) {
            post_processed = true;
            mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
            sub(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(bd_block2 - 1));
            mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
        }
        if (with_binary_per_oc_sp_bcast_) {
            post_processed = true;
            const injector_utils::register_preserve_guard_t register_guard(
                    this, {reg_aux_binary_postops_oc_l});
            mov(reg_aux_binary_postops_oc_l,
                    ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                            + register_guard.stack_space_occupied()]);
            sub(reg_aux_binary_postops_oc_l, (bd_block2 - 1) * brg.bd_block);
            mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                            + register_guard.stack_space_occupied()],
                    reg_aux_binary_postops_oc_l);
        }
    }
    if (post_processed) mov(reg_buf, ptr[rsp + reg_buf_offs_]);
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::ldb_regs_shift(
        int ld_block2, bool is_tail) {
    int C_offset = (is_tail) ? ldb_C_offset(1, true) : ldb_C_offset(ld_block2);
    int D_offset = (is_tail) ? ldb_D_offset(1, true) : ldb_D_offset(ld_block2);
    add(reg_aux_C, C_offset);
    add(reg_aux_D, D_offset);

    add(reg_b_offset,
            (is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2));

    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        add(reg_aux_bias,
                (is_tail) ? bias_offset(1, true) : bias_offset(ld_block2));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.req_s8s8_compensation) {
        mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]);
        add(reg_aux_compensation,
                (is_tail) ? compensations_offset(1, true)
                          : compensations_offset(ld_block2));
        mov(ptr[rsp + reg_aux_comp_offs_], reg_aux_compensation);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        add(reg_aux_scales,
                (is_tail) ? scales_offset(1, true) : scales_offset(ld_block2));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_channel_bcast_) {
        const int po_offset
                = (is_tail) ? ldb_po_offset(1, true) : ldb_po_offset(ld_block2);
        mov(reg_aux_binary_postops_sp,
                ptr[rsp + reg_aux_binary_postops_sp_offs_]);
        add(reg_aux_binary_postops_sp, po_offset);
        mov(ptr[rsp + reg_aux_binary_postops_sp_offs_],
                reg_aux_binary_postops_sp);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        add(reg_aux_binary_postops_oc_l,
                (is_tail) ? oc_logical_offset(1, true)
                          : oc_logical_offset(ld_block2));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        add(reg_aux_zp_comp_a,
                (is_tail) ? zp_comp_a_offset(1, true)
                          : zp_comp_a_offset(ld_block2));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        add(reg_aux_zp_c_values,
                (is_tail) ? zp_c_values_offset(1, true)
                          : zp_c_values_offset(ld_block2));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::advance_bd_block2_post_op_regs(
        int bd_block2) {
    if (with_binary_per_oc_sp_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_binary_postops_oc_l_offs_]);
        add(reg_binary_postops_oc_l, bd_block2 * brg.bd_block);
        mov(ptr[rsp + reg_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (with_binary_channel_bcast_) {
        mov(reg_aux_binary_postops_sp, ptr[rsp + reg_binary_postops_sp_offs_]);
        add(reg_aux_binary_postops_sp, bdb_po_offset(bd_block2));
        mov(ptr[rsp + reg_binary_postops_sp_offs_], reg_aux_binary_postops_sp);
    }
    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
        add(reg_zp_comp_b, bdb_zp_comp_b_offset(bd_block2));
        mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::copy_post_ops_stack_values_to_aux(
        bool is_reg_tail) {
    if (!is_reg_tail) {
        mov(reg_aux_C, reg_C);
        mov(reg_aux_D, reg_D);
        xor_(reg_b_offset, reg_b_offset);
        if (brg.with_bias) {
            mov(reg_bias, ptr[rsp + reg_bias_offs_]);
            mov(ptr[rsp + reg_aux_bias_offs_], reg_bias);
        }
        if (brg.req_s8s8_compensation) {
            mov(reg_compensation, ptr[rsp + reg_comp_offs_]);
            mov(ptr[rsp + reg_aux_comp_offs_], reg_compensation);
        }
        if (brg.with_scales) {
            mov(reg_scales, ptr[rsp + reg_scales_offs_]);
            mov(ptr[rsp + reg_aux_scales_offs_], reg_scales);
        }
        if (with_binary_channel_bcast_) {
            mov(reg_aux_binary_postops_sp,
                    ptr[rsp + reg_binary_postops_sp_offs_]);
            mov(ptr[rsp + reg_aux_binary_postops_sp_offs_],
                    reg_aux_binary_postops_sp);
        }
        if (with_binary_per_oc_bcast_) {
            mov(reg_binary_postops_oc_l,
                    ptr[rsp + reg_binary_postops_oc_l_offs_]);
            mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                    reg_binary_postops_oc_l);
        }

        if (brg.zp_type_a != brgemm_broadcast_t::none) {
            mov(reg_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]);
            mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_zp_comp_a);
        }

        if (brg.zp_type_c != brgemm_broadcast_t::none) {
            mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
            mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_zp_c_values);
        }
    }
    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]);
        mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_zp_comp_b);
    }
    if (with_binary_per_oc_sp_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_binary_postops_oc_l_offs_]);
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_binary_postops_oc_l);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::read_params() {
    Label label_done;

    if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);

    if (brg.type == brgemm_addr) {
        mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]);
    } else {
        if (brg.layout == brgemm_row_major) {
            mov(reg_A, ptr[param1 + GET_OFF(ptr_A)]);
            mov(reg_B, ptr[param1 + GET_OFF(ptr_B)]);
        } else {
            mov(reg_A, ptr[param1 + GET_OFF(ptr_B)]);
            mov(reg_B, ptr[param1 + GET_OFF(ptr_A)]);
        }

        if (brg.type == brgemm_offs) {
            mov(reg_offs_batch, ptr[param1 + GET_OFF(batch)]);
            mov(ptr[rsp + origin_offs_batch_offs_], reg_offs_batch);
        } else {
            mov(reg_strd_batch, ptr[param1 + GET_OFF(batch)]);
            mov(ptr[rsp + origin_strd_batch_offs_], reg_strd_batch);
        }
    }

    mov(reg_C, ptr[param1 + GET_OFF(ptr_C)]);
    mov(reg_D, ptr[param1 + GET_OFF(ptr_D)]);
    mov(reg_BS, ptr[param1 + GET_OFF(BS)]);

    // ptr_buf is re-used for passing compensations for the
    // brg.req_s8s8_compensation case
    if (brg.is_tmm || brg.req_s8s8_compensation) {
        mov(reg_buf, ptr[param1 + GET_OFF(ptr_buf)]);
        mov(ptr[rsp + reg_buf_offs_], reg_buf);
    }

    if (brg.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]);
        mov(ptr[rsp + reg_bias_offs_], reg_bias);
    }
    if (brg.with_scales) {
        mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
        mov(ptr[rsp + reg_scales_offs_], reg_scales);
    }
    if (with_binary_no_bcast_) {
        mov(reg_aux_binary_postops_sp, ptr[param1 + GET_OFF(data_C_ptr_)]);
        mov(ptr[rsp + reg_data_C_ptr_], reg_aux_binary_postops_sp);
    }
    if (with_binary_channel_bcast_) {
        mov(reg_aux_binary_postops_sp,
                ptr[param1 + GET_OFF(first_mb_matrix_addr_off)]);
        mov(ptr[rsp + reg_binary_postops_sp_offs_], reg_aux_binary_postops_sp);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_binary_postops_oc_l, ptr[param1 + GET_OFF(oc_logical_off)]);
        mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_binary_postops_oc_l);
    } else if (with_binary_per_oc_sp_bcast_) {
        mov(reg_binary_postops_oc_l,
                ptr[param1 + GET_OFF(dst_row_logical_off)]);
        mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_binary_postops_oc_l);
    }

    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_zp_comp_a, ptr[param1 + GET_OFF(a_zp_compensations)]);
        mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a);
    }

    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_zp_comp_b, ptr[param1 + GET_OFF(b_zp_compensations)]);
        mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b);
    }

    if (brg.zp_type_c != brgemm_broadcast_t::none) {
        mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]);
        mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
    }

    mov(reg_do_post_ops, ptr[param1 + GET_OFF(do_post_ops)]);
    mov(ptr[rsp + reg_do_post_ops_offs_], reg_do_post_ops);

    mov(reg_skip_accm, ptr[param1 + GET_OFF(skip_accm)]);
    mov(ptr[rsp + reg_skip_accm_offs_], reg_skip_accm);

    mov(reg_zp_a_val, ptr[param1 + GET_OFF(zp_a_val)]);
    mov(ptr[rsp + reg_zp_a_val_offs_], reg_zp_a_val);

    mov(reg_do_comp, ptr[param1 + GET_OFF(do_apply_comp)]);
    mov(ptr[rsp + reg_do_comp_offs_], reg_do_comp);
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::zero_accumulators(int bd_block2,
        bool is_bdb_tail, int ld_block2, bool is_ld_tail,
        bool skip_accumulation) {
    if (brg.is_tmm) {
        // avoid usage of tile registers if there is no accumulation
        if (skip_accumulation) return;
        for_(int bdb = 0; bdb < bd_block2; bdb++)
        for (int ldb = 0; ldb < ld_block2; ldb++) {
            int idx = (is_ld_tail) ? brg.ld_block2 : ldb;
            tilezero(Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail, is_ld_tail)));
        }
    } else {
        int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
        for_(int bd = 0; bd < bd_block; bd++)
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm = accm(ld_block2, bd, ld);
            uni_vpxor(vmm, vmm, vmm);
        }
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_alpha_beta(
        int bd_block, int ld_block2, bool is_ld_tail) {
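    // Computes acc = alpha * acc (+ beta * C) on the accumulator registers,
    // converting int8 accumulators to f32 first when scaling is required.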
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;
    auto vmm_beta = vmm_tmp_1();
    auto vmm_alpha = vmm_tmp_2();
    auto vmm_prev_dst = vmm_tmp_3();

    const bool apply_alpha = brg.alpha != 1.f;
    const bool apply_beta = brg.beta != 0.f;
    if (!apply_alpha && !apply_beta) return;

    const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f);
    const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;

    if (apply_beta && !use_vadd_for_beta) {
        mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta)));
        uni_vmovq(Xmm(vmm_beta.getIdx()), reg_tmp_gpr);
        uni_vbroadcastss(vmm_beta, Xmm(vmm_beta.getIdx()));
    }
    if (apply_alpha) {
        mov(reg_tmp_gpr, float2int(static_cast<float>(brg.alpha)));
        uni_vmovq(Xmm(vmm_alpha.getIdx()), reg_tmp_gpr);
        uni_vbroadcastss(vmm_alpha, Xmm(vmm_alpha.getIdx()));
    }
    for_(int bd = 0; bd < bd_block; bd++)
    for (int ld = 0; ld < ld_block2; ld++) {
        const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
        auto vmm = accm(ld_block2, bd, ld);
        if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm);
        if (apply_alpha) uni_vmulps(vmm, vmm, vmm_alpha);
        if (apply_beta) {
            auto ptr_C = ptr[reg_aux_C + C_offset(bd, ld)];
            if (use_vadd_for_beta) {
                if (IMPLICATION(
                            is_tail, is_superset(brg.isa_impl, avx512_core))) {
                    auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask);
                    if (brg.is_int8)
                        uni_vpaddd(vmm_masked, vmm, ptr_C);
                    else
                        uni_vaddps(vmm_masked, vmm, ptr_C);
                } else {
                    load_data(brg.dt_c, vmm_prev_dst, ptr_C, ld_size);
                    if (brg.is_int8)
                        uni_vpaddd(vmm, vmm, vmm_prev_dst);
                    else
                        uni_vaddps(vmm, vmm, vmm_prev_dst);
                }
            } else {
                cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, true, false, k_mask,
                        ld_size);
                uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta);
            }
        }
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_post_ops(
        int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) {

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

    const injector_utils::conditional_register_preserve_guard_t register_guard(
            brg.with_binary, this, {param1});
    const auto guard_space = register_guard.stack_space_occupied();
    if (brg.with_binary) {
        mov(param1, ptr[rsp + abi_param1_offs_ + guard_space]);

        if (handle_binary_po_offset_) {
            for_(int bd = 0; bd < bd_block; bd++)
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto vmm_idx = accm(ld_block2, bd, ld).getIdx();

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_aux_D);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, D_offset(bd, ld));
                if (is_ld_tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }

    const auto sum_injector = [&] {
        const float *p_sum_scale = &brg.sum_scale;
        const int32_t *p_sum_zp = &brg.sum_zp;
        const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
        const bool p_sum_zp_reg_set = *p_sum_zp != 0;

        {
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard_sum_scale(
                            (handle_binary_po_offset_) && p_sum_scale_reg_set,
                            this, {reg_ptr_sum_scale});
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard_sum_zp(
                            p_sum_zp_reg_set, this, {reg_ptr_sum_zp});

            if (p_sum_scale_reg_set)
                mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));

            const auto &vmm_sum_zp = vmm_tmp_2();
            if (p_sum_zp_reg_set) {
                mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
                if (is_superset(brg.isa_impl, avx512_core)) {
                    vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
                } else {
                    uni_vpbroadcastd(vmm_sum_zp, ptr[reg_ptr_sum_zp]);
                    uni_vcvtdq2ps(vmm_sum_zp, vmm_sum_zp);
                }
            }

            const auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
            const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

            for (int bd = 0; bd < bd_block; bd++) {
                for (int ld = 0; ld < ld_block2; ld++) {
                    const auto vmm = accm(ld_block2, bd, ld);
                    const auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
                    const auto vmm_prev_dst = vmm_tmp_1();
                    cvt2ps(brg.sum_dt, vmm_prev_dst, addr, true, false, k_mask,
                            ld_size);
                    if (p_sum_zp_reg_set)
                        uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_sum_zp);
                    if (!p_sum_scale_reg_set)
                        uni_vaddps(vmm, vmm, vmm_prev_dst);
                    else {
                        if (is_superset(brg.isa_impl, avx512_core)) {
                            uni_vfmadd231ps(vmm, vmm_prev_dst,
                                    ptr_b[reg_ptr_sum_scale]);
                        } else {
                            auto vmm_tmp = vmm_tmp_2();
                            uni_vpbroadcastd(vmm_tmp, ptr[reg_ptr_sum_scale]);
                            uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_tmp);
                        }
                    }
                }
            }
        }
    };

    if (brg.with_sum) {
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }

    postops_injector_->compute_vector_range(
            max_vregs - bd_block * ld_block2, max_vregs, rhs_arg_params);
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators_apply_post_ops(
        int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) {
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

    // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
    // accumulated values are already converted to ps in apply_alpha_beta()
    const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f;
    const bool beta_uses_vadd
            = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f);
    const bool dq2ps_required = brg.is_int8
            && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);

    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        for (int bd = 0; bd < bd_block; bd++) {
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto addr = ptr[reg_aux_scales + scales_offset(ld)];
                auto vmm = accm(ld_block2, bd, ld);
                if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm);
                if (is_superset(brg.isa_impl, avx512_core)) {
                    const Vmm vmm_masked = vmm_mask(vmm, true, false, k_mask);
                    uni_vmulps(vmm_masked, vmm, addr);
                } else {
                    auto vmm_scales = vmm_tmp_1();
                    load_data(data_type::f32, vmm_scales, addr, ld_size);
                    uni_vmulps(vmm, vmm, vmm_scales);
                }
            }
        }
    }

    if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); }
    for_(int bd = 0; bd < bd_block; bd++)
    for (int ld = 0; ld < ld_block2; ld++) {
        auto vmm = accm(ld_block2, bd, ld);
        if (dq2ps_required && !brg.with_scales) uni_vcvtdq2ps(vmm, vmm);
        if (brg.with_bias) {
            auto vmm_bias = vmm_tmp_1();
            auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)];
            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, true, false, k_mask,
                    ld_size);
            uni_vaddps(vmm, vmm, vmm_bias);
        }
    }

    if (postops_injector_)
        apply_post_ops(bd_block, ld_block2, ldb_and_bdb_offset, is_ld_tail);

    if (brg.zp_type_c != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        auto vmm_zp_c = vmm_tmp_1();
        if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
            if (is_superset(brg.isa_impl, avx512_core)) {
                uni_vcvtdq2ps(vmm_zp_c,
                        EVEX_compress_addr(reg_aux_zp_c_values, 0, true));
            } else {
                uni_vpbroadcastd(vmm_zp_c, ptr[reg_aux_zp_c_values]);
                uni_vcvtdq2ps(vmm_zp_c, vmm_zp_c);
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
            if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                int zp_c_off = zp_c_values_offset(ld);
                auto zp_c_addr
                        = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off);
                cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, true, false, k_mask,
                        ld_size);
            }
            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vaddps(vmm, vmm, vmm_zp_c);
            }
        }
    }

    const bool dt_requires_saturation
            = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
    auto vmm_lbound = vmm_tmp_1();
    auto vmm_ubound = vmm_tmp_2();
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
    }

    if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

    for (int bd = 0; bd < bd_block; bd++) {
        if (dt_requires_saturation) {
            for (int ld = 0; ld < ld_block2; ld++) {
                auto vmm = accm(ld_block2, bd, ld);
                saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                uni_vcvtps2dq(vmm, vmm);
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
            auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
            auto vmm = accm(ld_block2, bd, ld);
            auto vmm_lower = Vmm_lower_t(vmm.getIdx());
            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
            const Vmm r_vmm = vmm_mask(vmm, is_tail, true, k_mask);
            const Vmm_lower_t r_ymm
                    = vmm_lower_mask(vmm_lower, is_tail, true, k_mask);
            if (is_superset(brg.isa_impl, avx512_core)) {
                switch (brg.dt_d) {
                    case data_type::f32:
                    case data_type::s32: uni_vmovups(addr, r_vmm); break;
                    case data_type::bf16: // TODO - clean
                        if (brg.is_bf16_emu) {
                            bf16_emu_->vcvtneps2bf16(vmm_lower, vmm);
                            vmovdqu16(addr, r_ymm);
                        } else {
                            vcvtneps2bf16(vmm_lower, vmm);
                            vmovdqu16(addr, r_ymm);
                        }
                        break;
                    case data_type::f16:
                        vcvtps2ph(vmm_lower, vmm, _op_mxcsr);
                        vmovdqu16(addr, r_ymm);
                        break;
                    case data_type::s8: vpmovsdb(addr, r_vmm); break;
                    case data_type::u8: vpmovusdb(addr, r_vmm); break;
                    default: assert(!"unknown dst_dt");
                }
            } else {
                const int ld_block = is_tail ? brg.ldb_tail : brg.ld_block;
                store_data(
                        brg.dt_d, vmm, reg_aux_D, D_offset(bd, ld), ld_block);
            }
        }
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_compensation(
        int bd_block, int ld_block2, bool is_ld_tail) {
    // apply compensation to accumulated values
    // to avoid the loss of accuracy when converting s32 to f32
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

    if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
        auto vmm_zp_a_val = vmm_tmp_2();
        mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
        uni_vpbroadcastd(vmm_zp_a_val, reg_zp_a_val.cvt32());

        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm_zp_comp_a = vmm_tmp_1();
            int zp_comp_a_off = zp_comp_a_offset(ld);
            auto zp_comp_a_addr
                    = EVEX_compress_addr(reg_aux_zp_comp_a, zp_comp_a_off);
            // apply src zero points value to the accumulated values
            if (is_superset(brg.isa_impl, avx512_core)) {
                auto vmm_zp_comp_a_masked
                        = vmm_mask(vmm_zp_comp_a, true, false, k_mask);
                uni_vmovups(vmm_zp_comp_a_masked, zp_comp_a_addr);
            } else {
                load_data(
                        data_type::s32, vmm_zp_comp_a, zp_comp_a_addr, ld_size);
            }
            uni_vpmulld(vmm_zp_comp_a, vmm_zp_comp_a, vmm_zp_a_val);

            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, vmm_zp_comp_a);
            }
        }
    }

    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
        for (int bd = 0; bd < bd_block; bd++) {
            int zp_comp_b_off = zp_comp_b_offset(bd);
            auto zp_comp_b_addr = EVEX_compress_addr(
                    reg_aux_zp_comp_b, zp_comp_b_off, true);
            for (int ld = 0; ld < ld_block2; ld++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, zp_comp_b_addr);
            }
        }
    }

    if (!brg.req_cal_comp_pads && brg.req_s8s8_compensation) {
        mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]);
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm_comp = vmm_tmp_1();
            int comp_offset = compensations_offset(ld);
            auto comp_addr
                    = EVEX_compress_addr(reg_aux_compensation, comp_offset);
            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
            if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) {
                vmm_comp = vmm_mask(vmm_comp, true, false, k_mask);
                uni_vmovups(vmm_comp, comp_addr);
            } else {
                load_data(data_type::s32, vmm_comp, comp_addr, ld_size);
            }
            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, vmm_comp);
            }
        }
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators_without_post_ops(
        int bd_block, int ld_block2, bool is_ld_tail) {

    // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
    // accumulated values are converted to ps in apply_alpha_beta()
    const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f;
    const bool beta_uses_vadd
            = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f);
    const bool dt_requires_saturation = brg.is_int8
            && !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);
    auto vmm_lbound = vmm_tmp_1();
    auto vmm_ubound = vmm_tmp_2();
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
    }

    for (int bd = 0; bd < bd_block; bd++) {
        if (dt_requires_saturation) {
            for (int ld = 0; ld < ld_block2; ld++) {
                auto vmm = accm(ld_block2, bd, ld);
                saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                uni_vcvtps2dq(vmm, vmm);
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm = accm(ld_block2, bd, ld);
            const auto addr_c = ptr[reg_aux_C + C_offset(bd, ld)];
            if (is_ld_tail) {
                if (is_superset(brg.isa_impl, avx512_core)) {
                    uni_vmovups(addr_c | ld_tail_mask | T_z, vmm);
                } else {
                    store_data(brg.dt_c, vmm, reg_aux_C, C_offset(bd, ld),
                            brg.ldb_tail);
                }
            } else
                uni_vmovups(ptr[reg_aux_C + C_offset(bd, ld)], vmm);
        }
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators(int bd_block2,
        bool is_bdb_tail, int ld_block2, bool is_ld_tail,
        bool skip_accumulation) {
    const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none,
            brg.zp_type_a, brg.zp_type_b, brg.zp_type_c);
    const bool are_post_ops_applicable = one_of(true, brg.with_eltwise,
            brg.with_binary, brg.with_scales, brg.with_bias, brg.with_sum,
            brg.dt_d != brg.dt_c, brg.req_s8s8_compensation, has_zero_points);
    const bool need_to_apply_alpha_beta = brg.beta != 0.f || brg.alpha != 1.f;

    if (brg.is_tmm) {
        if (need_to_apply_alpha_beta || are_post_ops_applicable)
            mov(reg_stride_ld_block, brg.ld_block * brg.typesize_C);
        else
            mov(reg_stride_ld_block, brg.LDC * brg.typesize_C);

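        // For the AMX path, C tiles are first spilled to a scratch buffer
        // (reg_buf) whenever alpha/beta or post-ops must be applied, so the
        // regular vector post-processing code can be reused row by row.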
1360 auto store_accumulators_amx = [=](const bool apply_post_ops) {
1361 mov(reg_buf, ptr[rsp + reg_buf_offs_]);
1362 for (int bdb = 0; bdb < bd_block2; bdb++) {
1363 int adj_bd_block = (brg.is_M_tail && is_bdb_tail)
1364 ? brg.bdb_tail
1365 : brg.bd_block;
1366 for (int ldb = 0; ldb < ld_block2; ldb++) {
1367 int idx = (is_ld_tail) ? brg.ld_block2 : ldb;
1368 if (need_to_apply_alpha_beta || are_post_ops_applicable) {
1369 if (skip_accumulation) {
1370 for (int bd = 0; bd < adj_bd_block; bd++) {
1371 auto vreg_acc = accm(1, bd, 0);
1372 uni_vpxor(vreg_acc, vreg_acc, vreg_acc);
1373 }
1374 } else {
1375 tilestored(ptr[reg_buf + reg_stride_ld_block],
1376 Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail,
1377 is_ld_tail)));
1378 for (int bd = 0; bd < adj_bd_block; bd++) {
1379 size_t buf_offset
1380 = (bd * brg.ld_block) * brg.typesize_C;
1381 auto vreg_acc = is_ld_tail
1382 ? accm(1, bd, 0) | ld_tail_mask | T_z
1383 : accm(1, bd, 0);
1384 uni_vmovups(
1385 vreg_acc, ptr[reg_buf + buf_offset]);
1386 }
1387 }
1388 if (need_to_apply_alpha_beta)
1389 apply_alpha_beta(adj_bd_block, 1, is_ld_tail);
1390
1391 if (apply_post_ops) {
1392 const size_t ldb_and_bdb_offset
1393 = ldb_po_offset(ldb) + bdb_po_offset(bdb);
1394 store_accumulators_apply_post_ops(adj_bd_block, 1,
1395 ldb_and_bdb_offset, is_ld_tail);
1396 if (ldb < ld_block2 - 1) advance_ldb_post_op_regs();
1397 add(reg_aux_D, ldb_D_offset(1));
1398 } else {
1399 store_accumulators_without_post_ops(
1400 adj_bd_block, 1, is_ld_tail);
1401 }
1402 mov(reg_buf, ptr[rsp + reg_buf_offs_]);
1403 } else {
1404 auto tmm = Tmm(brg.get_C_tensor(
1405 bdb, idx, is_bdb_tail, is_ld_tail));
1406 if (skip_accumulation) tilezero(tmm);
1407 tilestored(ptr[reg_aux_C + reg_stride_ld_block], tmm);
1408 }
1409 add(reg_aux_C, ldb_C_offset(1));
1410 }
1411 sub(reg_aux_C, ldb_C_offset(ld_block2));
1412 add(reg_aux_C, bdb_C_offset(1));
1413 if (apply_post_ops) {
1414 sub(reg_aux_D, ldb_D_offset(ld_block2));
1415 add(reg_aux_D, bdb_D_offset(1));
1416
1417 bool post_processed = false;
1418 if (ld_block2 > 1) {
1419 restore_ldb_post_op_regs(ld_block2);
1420 post_processed |= utils::one_of(true, brg.with_bias,
1421 brg.with_scales, with_binary_per_oc_bcast_,
1422 brg.zp_type_a != brgemm_broadcast_t::none,
1423 brg.zp_type_c == brgemm_broadcast_t::per_n);
1424 }
1425 if (bdb < bd_block2 - 1) {
1426 advance_bdb_post_op_regs(adj_bd_block);
1427 post_processed |= utils::one_of(true,
1428 brg.zp_type_b != brgemm_broadcast_t::none,
1429 with_binary_per_oc_sp_bcast_);
1430 }
1431 if (post_processed) mov(reg_buf, ptr[rsp + reg_buf_offs_]);
1432 }
1433 }
1434 sub(reg_aux_C, bdb_C_offset(bd_block2));
1435 if (apply_post_ops) {
1436 sub(reg_aux_D, bdb_D_offset(bd_block2));
1437 restore_bdb_post_op_regs(bd_block2);
1438 }
1439 };
1440
1441 Label label_done;
1442 if (are_post_ops_applicable) {
1443 Label label_store_without_post_ops;
1444 mov(reg_do_post_ops, ptr[rsp + reg_do_post_ops_offs_]);
1445 cmp(reg_do_post_ops, 0);
1446 jz(label_store_without_post_ops, T_NEAR);
1447
1448 store_accumulators_amx(true);
1449 jmp(label_done, T_NEAR);
1450
1451 L_aligned(label_store_without_post_ops);
1452 }
1453 store_accumulators_amx(false);
1454 L_aligned(label_done);
1455 } else {
1456 int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
1457
1458 if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) {
1459 Label label_store_without_comp;
1460 mov(reg_do_comp, ptr[rsp + reg_do_comp_offs_]);
1461 cmp(reg_do_comp, 0);
1462 jz(label_store_without_comp, T_NEAR);
1463 apply_compensation(bd_block, ld_block2, is_ld_tail);
1464
1465 L_aligned(label_store_without_comp);
1466 }
1467
1468 if (need_to_apply_alpha_beta)
1469 apply_alpha_beta(bd_block, ld_block2, is_ld_tail);
1470
1471 Label label_done;
1472 if (are_post_ops_applicable) {
1473 Label label_store_without_post_ops;
1474 mov(reg_do_post_ops, ptr[rsp + reg_do_post_ops_offs_]);
1475 cmp(reg_do_post_ops, 0);
1476 jz(label_store_without_post_ops, T_NEAR);
1477 store_accumulators_apply_post_ops(
1478 bd_block, ld_block2, 0, is_ld_tail);
1479 jmp(label_done, T_NEAR);
1480
1481 L_aligned(label_store_without_post_ops);
1482 }
1483 store_accumulators_without_post_ops(bd_block, ld_block2, is_ld_tail);
1484 L_aligned(label_done);
1485 }
1486}
1487
1488template <cpu_isa_t isa, typename Wmm>
1489void jit_brgemm_kernel_t<isa, Wmm>::restore_A_B_matrices() {
1490 auto restore_reg_batch = brg.brgattr.max_bs > 1 || vpad_exist;
1491 if (brg.type == brgemm_addr) {
1492 if (restore_reg_batch) mov(reg_aux1_batch, reg_addr_batch);
1493 } else {
1494 mov(reg_aux1_A, reg_A);
1495 mov(reg_aux1_B, reg_B);
1496
1497 if (restore_reg_batch) {
1498 if (brg.type == brgemm_offs)
1499 mov(reg_offs_batch, ptr[rsp + origin_offs_batch_offs_]);
1500 else
1501 mov(reg_strd_batch, ptr[rsp + origin_strd_batch_offs_]);
1502 }
1503 }
1504}
1505
1506template <cpu_isa_t isa, typename Wmm>
1507void jit_brgemm_kernel_t<isa, Wmm>::set_A_B_matrices() {
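    // Compute reg_aux_A / reg_aux_B for the current batch element according
    // to the batch kind: explicit pointers (addr), base plus per-element
    // offsets (offs), or base plus fixed strides (strd). For column-major
    // layout the roles of A and B are swapped.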
1508 if (brg.type == brgemm_addr) {
1509 if (brg.brgattr.max_bs > 1) {
1510 if (brg.layout == brgemm_row_major) {
1511 mov(reg_aux_A,
1512 ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
1513 mov(reg_aux_B,
1514 ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
1515 } else {
1516 mov(reg_aux_A,
1517 ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
1518 mov(reg_aux_B,
1519 ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
1520 }
1521 } else {
            // For max_batch == 1, the A and B pointers were stored in
            // reg_aux1_A and reg_aux1_B at the beginning of the kernel.
1524 if (brg.layout == brgemm_row_major) {
1525 mov(reg_aux_A, reg_aux1_A);
1526 mov(reg_aux_B, reg_aux1_B);
1527 } else {
1528 mov(reg_aux_A, reg_aux1_B);
1529 mov(reg_aux_B, reg_aux1_A);
1530 }
1531 }
1532
1533 if (brg.brgattr.max_bs > 1) {
1534 add(reg_aux1_batch, sizeof(brgemm_batch_element_t));
1535 prefetcht0(ptr[reg_aux1_batch]);
1536 }
1537 } else if (brg.type == brgemm_offs) {
1538 mov(reg_aux_A, reg_A);
1539 mov(reg_aux_B, reg_B);
1540
1541 add(reg_aux_A, ptr[reg_offs_batch + GET_OFF_BATCH_ELEMENT(offset.A)]);
1542 add(reg_aux_B, ptr[reg_offs_batch + GET_OFF_BATCH_ELEMENT(offset.B)]);
1543 add(reg_offs_batch, sizeof(brgemm_batch_element_t));
1544 } else if (brg.type == brgemm_strd) {
1545 mov(reg_aux_A, reg_aux1_A);
1546 mov(reg_aux_B, reg_aux1_B);
1547
1548 safe_add(reg_aux1_A, brg.stride_a, reg_tmp_gpr);
1549 safe_add(reg_aux1_B, brg.stride_b, reg_tmp_gpr);
1550 if (vpad_exist) {
1551 mov(reg_strd_batch, ptr[rsp + origin_strd_batch_offs_]);
1552 add(reg_strd_batch, sizeof(brgemm_batch_element_t));
1553 mov(ptr[rsp + origin_strd_batch_offs_], reg_strd_batch);
1554 }
1555 }
1556
1557 add(reg_aux_A, reg_a_offset);
1558 add(reg_aux_B, reg_b_offset);
1559}
1560
1561template <cpu_isa_t isa, typename Wmm>
1562void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel_amx(int bd_block2,
1563 bool is_bdb_tail, int ld_block2, bool is_rd_tail, bool is_ld_tail) {
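    // Select the AMX tile dot-product instruction matching the A/B data
    // types: bf16, f16, or one of the four s8/u8 combinations.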
1564 auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
1565 if (brg.dt_a == data_type::bf16 && brg.dt_b == data_type::bf16) {
1566 tdpbf16ps(x1, x2, x3);
1567 } else if (brg.dt_a == data_type::f16 && brg.dt_b == data_type::f16) {
1568 tdpfp16ps(x1, x2, x3);
1569 } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8) {
1570 tdpbuud(x1, x2, x3);
1571 } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8) {
1572 tdpbusd(x1, x2, x3);
1573 } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8) {
1574 tdpbsud(x1, x2, x3);
1575 } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8) {
1576 tdpbssd(x1, x2, x3);
1577 } else {
1578 assert(!"unsupported combination");
1579 }
1580 };
1581
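    // Prefer a non-temporal tile load when the hinted A/B/C working set does
    // not fit into the per-core L1 cache; otherwise use a regular tileloadd.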
1582 auto maybe_tileloadd_nt = [=](const Tmm &t1, reg64_t base, int offset,
1583 reg64_t stride, bool try_load_nt) {
1584 if (try_load_nt
1585 && static_cast<size_t>(
1586 brg.typesize_A * brg.brgattr.hint_expected_A_size
1587 + brg.typesize_B * brg.brgattr.hint_expected_B_size
1588 + brg.typesize_C * brg.brgattr.hint_expected_C_size)
1589 >= platform::get_per_core_cache_size(1))
1590 tileloaddt1(t1, ptr[base + offset + stride]);
1591 else
1592 tileloadd(t1, ptr[base + offset + stride]);
1593 };
1594
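    // For each reduction block: load the A and B tiles, then issue a tile
    // dot-product for every (bd, ld) pair, accumulating into the C tiles.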
1595 int rbd_block = (is_rd_tail) ? 1 : brg.rdb;
1596 for (int rdb = 0; rdb < rbd_block; rdb++) {
1597 for (int bdb = 0; bdb < bd_block2; bdb++) {
1598 maybe_tileloadd_nt(Tmm(brg.get_A_tensor(bdb, is_bdb_tail)),
1599 reg_aux_A, rdb * rdb_A_offset() + A_offset(bdb, 0, true),
1600 reg_stride_lda,
1601 brg.brgattr.hint_innermost_loop
1602 == brgemm_bd_loop_innermost);
1603 }
1604 for (int ldb = 0; ldb < ld_block2; ldb++) {
1605 const int idx = (is_ld_tail) ? brg.ld_block2 : ldb;
1606 maybe_tileloadd_nt(Tmm(brg.get_B_tensor(idx, is_ld_tail)),
1607 reg_aux_B, rdb * rdb_B_offset() + B_offset(ldb, 0, true),
1608 reg_stride_ldb,
1609 brg.brgattr.hint_innermost_loop
1610 == brgemm_ld_loop_innermost);
1611 for (int bdb = 0; bdb < bd_block2; bdb++) {
1612 tdpbxxd(Tmm(brg.get_C_tensor(
1613 bdb, idx, is_bdb_tail, is_ld_tail)),
1614 Tmm(brg.get_A_tensor(bdb, is_bdb_tail)),
1615 Tmm(brg.get_B_tensor(idx, is_ld_tail)));
1616 }
1617 }
1618 }
1619 if (!is_rd_tail) {
1620 add(reg_aux_A, brg.rdb * rdb_A_offset());
1621 add(reg_aux_B, brg.rdb * rdb_B_offset());
1622 }
1623}
1624
1625template <cpu_isa_t isa, typename Wmm>
1626void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel(int bd_block2,
1627 bool is_bdb_tail, int ld_block2, bool is_rd_tail, bool is_ld_tail,
1628 int vpad, int rows_for_rd_tail) {
1629 MAYBE_UNUSED(bd_block2);
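    // One multiply-accumulate step of a B vector with a broadcast A value:
    // plain FMA for f32 and f16 (and for bf16 on avx2_vnni_2, where inputs
    // are converted to f32 on load), vdpbf16ps for bf16, and vpdpbusd for
    // int8 (the broadcast operand goes into the unsigned-source position).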
1630 auto dot_product = [=](Vmm v1, Vmm v2, Vmm v3) {
1631 if (brg.is_f32 || brg.is_f16
1632 || (brg.is_bf16 && brg.isa_impl == avx2_vnni_2))
1633 uni_vfmadd231ps(v1, v2, v3);
1634 else if (brg.is_bf16)
1635 vdpbf16ps(v1, v2, v3);
1636 else if (brg.is_int8)
1637 vpdpbusd(v1, v3, v2, isa == avx2_vnni ? VexEncoding : EvexEncoding);
1638 };
1639
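    // Virtual padding clamps the range of A rows taking part in the
    // computation: a positive vpad skips rows at the top of the block, a
    // negative one skips rows at the bottom.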
1640 int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
1641 const auto bd_b = nstl::max(0, vpad);
1642 const auto bd_e = nstl::min(bd_block, bd_block + vpad);
1643 const auto is_valid_bd
1644 = need_comp_pads && vpad != 0 ? bd_b <= bd_e : bd_b < bd_e;
1645 if (!is_valid_bd) return;
1646
1647 bool is_emdbd = brg.embd_bcst;
1648
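    // For bf16/int8 the reduction-dimension tail is rounded up to a full
    // rd_step because each dot-product instruction consumes rd_step elements
    // per row.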
1649 int rd_loop = 0, rd_tail_size = 0;
1650 if (is_rd_tail) {
1651 if (brg.is_bf16 || brg.is_int8) {
1652 rd_tail_size = brg.rdb_tail % brg.rd_step;
1653 rd_loop = (rd_tail_size != 0)
1654 ? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step
1655 : brg.rdb_tail;
1656 } else
1657 rd_loop = brg.rdb_tail;
1658 } else
1659 rd_loop = brg.rd_block;
1660
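    // Broadcast one group of A elements to the whole vector, converting
    // f16/bf16 to f32 where needed. For the rd tail only the valid bytes are
    // read and the rest of the register is zeroed. The s8s8 input shift
    // (+128) is applied right after the broadcast.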
1661 auto broadcast = [=](Vmm v1, size_t offset, bool is_tail, data_type_t dt) {
1662 if (is_tail) {
1663 uni_vpxor(v1, v1, v1);
1664 Xmm xmm_tmp = Xmm(v1.getIdx());
1665 load_bytes(
1666 xmm_tmp, reg_aux_A, offset, rd_tail_size * brg.typesize_A);
1667 uni_vpbroadcastd(v1, xmm_tmp);
1668 } else {
1669 if (dt == data_type::f32) {
1670 uni_vbroadcastss(v1, ptr[reg_aux_A + offset]);
1671 } else if (dt == data_type::bf16) {
1672 if (brg.isa_impl == avx2_vnni_2)
1673 vbcstnebf162ps(v1, ptr[reg_aux_A + offset]);
1674 else
1675 uni_vpbroadcastd(v1, ptr[reg_aux_A + offset]);
1676 } else if (one_of(dt, data_type::s8, data_type::u8)) {
1677 uni_vpbroadcastd(v1, ptr[reg_aux_A + offset]);
1678 } else if (dt == data_type::f16) {
1679 if (brg.isa_impl == avx2_vnni_2)
1680 vbcstnesh2ps(v1, ptr[reg_aux_A + offset]);
1681 else
1682 vcvtph2psx(v1, ptr_b[reg_aux_A + offset]);
1683 }
1684 }
1685
1686 if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
1687 };
1688
1689 auto compensation_padding
1690 = [=](Vmm vmm_load, Vmm vmm_tmp, int ld, int bd_b, int bd_e) {
    /* req_cal_comp_pads -> compensation is calculated along with the main
     * computation rather than taken from a pre-calculated buffer; the padding
     * compensation is: accum - inp_shift * conv(1, wei_s32) */
1694 if (brg.req_s8s8_compensation) {
1695 if (brg.req_cal_comp_pads) {
1696 uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
1697 dot_product(vmm_tmp, vmm_load, vmm_inp_shift());
1698 }
1699
1700 for (int bd = bd_b; bd < bd_e; bd++) {
1701 auto vmm = accm(ld_block2, bd, ld);
1702 if (brg.req_cal_comp_pads) {
1703 uni_vpsubd(vmm, vmm, vmm_tmp);
1704 } else {
1705 dot_product(vmm, vmm_load, vmm_inp_shift());
1706 }
1707 }
1708 }
1709
1710 if (brg.zp_type_a != brgemm_broadcast_t::none) {
1711 uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
1712 dot_product(vmm_tmp, vmm_load, vmm_one_bytes());
1713 uni_vpmulld(vmm_tmp, vmm_tmp, vmm_zp_a_shift());
1714
1715 for (int bd = bd_b; bd < bd_e; bd++) {
1716 auto vmm = accm(ld_block2, bd, ld);
1717 if (brg.req_cal_comp_pads) {
1718 uni_vpsubd(vmm, vmm, vmm_tmp);
1719 } else {
1720 uni_vpaddd(vmm, vmm, vmm_tmp);
1721 }
1722 }
1723 }
1724 };
1725
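    // Compensation pass over B: fold the s8s8 / zero-point compensation
    // terms directly into the accumulators, either for the virtually padded
    // row ranges or, with req_cal_comp_pads, along with the main computation.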
1726 if (brg.req_cal_comp_pads
1727 || (vpad != 0
1728 && (brg.req_s8s8_compensation
1729 || brg.zp_type_a != brgemm_broadcast_t::none))) {
        // This block is only used for int8 compensation.
1731 assert(brg.is_int8);
1732 if (n_bcast_1_load && brg.zp_type_a != brgemm_broadcast_t::none) {
1733 mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
1734 const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
1735 mov(reg32_scratch, 0x1010101);
1736 uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch);
1737 mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]);
1738 uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch);
1739 mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
1740 }
1741
1742 for_(int rd = 0; rd < rd_loop; rd += brg.rd_step)
1743 for (int ld = 0; ld < ld_block2; ++ld) {
1744 const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
1745 const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
1746 auto vmm_store = load();
1747 if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) {
1748 vmm_store = vmm_mask(vmm_store, is_tail, false, ld_tail_mask);
1749 uni_vmovups(vmm_store, addr);
1750 } else {
1751 load_bytes(load(), addr,
1752 brg.typesize_B * brg.ldb_tail * brg.ld_step);
1753 }
1754
1755 if (brg.req_cal_comp_pads) {
1756 compensation_padding(vmm_store, bcst(), ld, bd_b, bd_e);
1757 } else if (vpad != 0) {
1758 if (bd_b > 0)
1759 compensation_padding(vmm_store, bcst(), ld, 0, bd_b);
1760 if (bd_e < bd_block)
1761 compensation_padding(vmm_store, bcst(), ld, bd_e, bd_block);
1762 }
1763 }
1764 }
1765
1766 bool maybe_load_bytes = (rows_for_rd_tail > 0 || brg.brgattr.wary_tail_read)
1767 && is_rd_tail && rd_tail_size != 0 && (brg.is_bf16 || brg.is_int8);
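    // Two operand-scheduling strategies: with n_bcast_1_load all A rows are
    // broadcast first and a single register is reused for the B loads;
    // otherwise all B columns are loaded first and a single broadcast
    // register is reused across rows.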
1768 if (n_bcast_1_load) {
1769 for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
1770 bool have_to_load_bytes
1771 = maybe_load_bytes && (rd == rd_loop - brg.rd_step);
1772
1773 auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0;
1774 for (int bd = bd_b; bd < bd_e && !is_emdbd; bd++) {
1775 const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes
1776 || brg.brgattr.wary_tail_read);
1777 broadcast(bcst(bd), A_offset(bd, rd),
1778 have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
1779 }
1780 for (int ld = 0; ld < ld_block2; ld++) {
1781 const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
1782 const Vmm vmm_load
1783 = vmm_mask(load(), is_ld_tail, false, ld_tail_mask);
1784 // Note: Assuming the tails are properly padded/blocked for
1785 // avx2_vnni_2, as the B matrix is generally
1786 // at least double-blocked.
1787 if (brg.dt_b == data_type::f16) {
1788 if (brg.isa_impl == avx2_vnni_2) {
1789 if (rd % 2 == 0)
1790 vcvtneeph2ps(vmm_load, addr);
1791 else
1792 vcvtneoph2ps(vmm_load, addr);
1793 } else
1794 vcvtph2psx(vmm_load, addr);
1795 } else if (brg.dt_b == data_type::bf16
1796 && brg.isa_impl == avx2_vnni_2) {
1797 if (rd % 2 == 0)
1798 vcvtneebf162ps(vmm_load, addr);
1799 else
1800 vcvtneobf162ps(vmm_load, addr);
1801 } else if (is_ld_tail) {
1802 if (is_superset(brg.isa_impl, avx512_core)) {
1803 uni_vmovups(vmm_load, addr);
1804 } else {
1805 load_bytes(vmm_load, addr,
1806 brg.typesize_B * brg.ldb_tail * brg.ld_step);
1807 }
1808 } else {
1809 uni_vmovups(vmm_load, addr);
1810 }
1811 for (int bd = bd_b; bd < bd_e; bd++) {
1812 auto vmm = accm(ld_block2, bd, ld);
1813 if (is_emdbd)
1814 uni_vfmadd231ps(vmm, load(),
1815 ptr_b[reg_aux_A + A_offset(bd, rd)]);
1816 else
1817 dot_product(vmm, load(), bcst(bd));
1818 }
1819 }
1820 }
1821 } else {
1822 for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
1823 int prefetch_count_B = 0;
1824 for (int ld = 0; ld < ld_block2; ld++) {
1825 const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
1826 const Vmm vmm_load
1827 = vmm_mask(load(ld), is_ld_tail, false, ld_tail_mask);
1828 // Note: Assuming the tails are properly padded/blocked for
1829 // avx2_vnni_2, as the B matrix is generally
1830 // at least double-blocked.
1831 if (brg.dt_b == data_type::f16) {
1832 if (brg.isa_impl == avx2_vnni_2) {
1833 if (rd % 2 == 0)
1834 vcvtneeph2ps(vmm_load, addr);
1835 else
1836 vcvtneoph2ps(vmm_load, addr);
1837 } else {
1838 vcvtph2psx(vmm_load, addr);
1839 }
1840 } else if (brg.dt_b == data_type::bf16
1841 && brg.isa_impl == avx2_vnni_2) {
1842 if (rd % 2 == 0)
1843 vcvtneebf162ps(vmm_load, addr);
1844 else
1845 vcvtneobf162ps(vmm_load, addr);
1846 } else if (is_ld_tail) {
1847 if (is_superset(brg.isa_impl, avx512_core)) {
1848 uni_vmovups(vmm_load, addr);
1849 } else {
1850 load_bytes(vmm_load, addr,
1851 brg.typesize_B * brg.ldb_tail * brg.ld_step);
1852 }
1853 } else {
1854 uni_vmovups(vmm_load, addr);
1855 }
1856 }
1857
1858 bool have_to_load_bytes
1859 = maybe_load_bytes && (rd == rd_loop - brg.rd_step);
1860
1861 auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0;
1862 for (int bd = bd_b; bd < bd_e; bd++) {
1863 if (!is_emdbd) {
1864 const auto bd_by_load_bytes
1865 = (bd >= bd_e - rows_by_load_bytes
1866 || brg.brgattr.wary_tail_read);
1867 broadcast(bcst(), A_offset(bd, rd),
1868 have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
1869 }
1870 if (prefetch_count_B < ld_block2) {
1871 prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd)
1872 + brg.LDB * brg.rd_block * brg.typesize_B]);
1873 }
1874 for (int ld = 0; ld < ld_block2; ld++) {
1875 auto vmm = accm(ld_block2, bd, ld);
1876 if (is_emdbd)
1877 uni_vfmadd231ps(vmm, load(ld),
1878 ptr_b[reg_aux_A + A_offset(bd, rd)]);
1879 else
1880 dot_product(vmm, load(ld), bcst());
1881 }
1882 }
1883 }
1884 }
1885}
1886
1887template <cpu_isa_t isa, typename Wmm>
1888void jit_brgemm_kernel_t<isa, Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
1889 int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
1890 bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
1891 bool skip_accumulation) {
1892
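    // Loop over ld (N) blocks: for each block the accumulators are zeroed,
    // the batch (BS) loop accumulates over all batch elements and the
    // reduction dimension, and the result is stored with optional post-ops.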
1893 Label ldb_loop_label;
1894 Label BS_loop_label;
1895
1896 copy_post_ops_stack_values_to_aux(is_reg_tail);
1897
1898 auto ld_loop_body = [=](int vpad) {
1899 set_A_B_matrices();
1900
1901 int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
1902 const auto bd_b = nstl::max(0, vpad);
1903 const auto bd_e = nstl::min(bd_block, bd_block + vpad);
1904 const auto is_valid_bd
1905 = need_comp_pads && vpad != 0 ? bd_b <= bd_e : bd_b < bd_e;
1906 if (!is_valid_bd) return;
1907
1908 if (brg.is_tmm) {
1909 const bool is_rd_tail = false;
1910 gemm_microkernel_amx(
1911 bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail);
1912 } else {
1913 if (brg.rdb > 0) {
1914 Label rdb_loop_label;
1915 mov(reg_rdb_loop, brg.rdb);
1916 L_aligned(rdb_loop_label, 64);
1917 {
1918 const bool is_rd_tail = false;
1919 gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
1920 is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);
1921
1922 add(reg_aux_A, rdb_A_offset());
1923 add(reg_aux_B, rdb_B_offset());
1924
1925 dec(reg_rdb_loop);
1926 cmp(reg_rdb_loop, 0);
1927 }
1928 jg(rdb_loop_label, T_NEAR);
1929 }
1930 }
1931 if (brg.rdb_tail != 0) {
1932 const bool is_rd_tail = true;
1933 if (brg.is_tmm) {
1934 gemm_microkernel_amx(bd_block2, is_bdb_tail, ld_block2,
1935 is_rd_tail, is_ld_tail);
1936 } else {
1937 gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, is_rd_tail,
1938 is_ld_tail, vpad, rows_for_rd_tail);
1939 }
1940 }
1941 };
1942 if (is_ldb_loop_) {
1943 mov(reg_ldb_loop, ldb_loop_length);
1944 if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
1945 }
1946
1947 L_aligned(ldb_loop_label, 64);
1948 {
1949 zero_accumulators(bd_block2, is_bdb_tail, ld_block2, is_ld_tail,
1950 skip_accumulation);
1951
1952 if (is_ldb_loop_)
1953 mov(ptr[rsp + reg_D_offs_], reg_D);
1954 else {
1955 mov(reg_ldb_loop, reg_D);
1956 if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
1957 }
1958 if (brg.brgattr.max_bs > 1) mov(ptr[rsp + reg_aux_D_offs_], reg_aux_D);
1959
1960 if (brg.alpha != 0.f && !skip_accumulation) {
1961 restore_A_B_matrices();
1962 if (brg.is_tmm) {
1963 mov(reg_stride_lda, brg.typesize_A * brg.LDA);
1964 mov(reg_stride_ldb, brg.rd_step * brg.typesize_B * brg.LDB);
1965 }
1966
1967 if (brg.req_s8s8_compensation) {
1968 mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
1969 mov(reg_s8_input_shift, 128);
1970 uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8());
1971 mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
1972 }
1973 if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
1974 mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
1975 const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
1976 mov(reg32_scratch, 0x1010101);
1977 uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch);
1978 mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]);
1979 uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch);
1980 mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
1981 }
1982
1983 if (brg.brgattr.max_bs > 1) mov(reg_BS_loop, reg_BS);
1984 L_aligned(BS_loop_label, 64);
1985 {
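                // The actual vpad of the current batch element is only known
                // at runtime: compare it against every possible value and
                // jump to the matching ld_loop_body specialization.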
1986 if (check_top_vpad || check_bottom_vpad) {
1987 const auto vpad_first = -brg.brgattr.max_bottom_vpad;
1988 const auto vpad_last = brg.brgattr.max_top_vpad;
1989 const auto n_vpads = vpad_last - vpad_first + 2;
1990 constexpr auto MAX_N_VPADS = 2 * brgemm_t::MAX_VPAD;
1991 assert(n_vpads < MAX_N_VPADS);
1992
1993 Label Vpad_loop_end_label;
1994 std::vector<Label> Vpad_loop_iter_label(MAX_N_VPADS);
1995 if (vpad_exist) {
1996 reg64_t reg_batch = (brg.type == brgemm_addr)
1997 ? reg_aux1_batch
1998 : ((brg.type == brgemm_offs) ? reg_offs_batch
1999 : reg_strd_batch);
2000 if (brg.type == brgemm_strd)
2001 mov(reg_strd_batch,
2002 ptr[rsp + origin_strd_batch_offs_]);
2003
2004 mov(reg_aux_A_vpad,
2005 ptr[reg_batch
2006 + GET_OFF_BATCH_ELEMENT(vvpad.top)]);
2007 sub(reg_aux_A_vpad,
2008 ptr[reg_batch
2009 + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
2010 } else
2011 xor_(reg_aux_A_vpad, reg_aux_A_vpad);
2012
2013 for (int vpad = vpad_first; vpad <= vpad_last; vpad++) {
2014 const auto label_vpad = vpad - vpad_first;
2015 L(Vpad_loop_iter_label[label_vpad]);
2016 if (!check_top_vpad && vpad > 0) continue;
2017 if (!check_bottom_vpad && vpad < 0) continue;
2018 auto real_vpad = vpad;
2019 if (check_bottom_vpad && brg.bdb_tail) {
2020 if (!is_bdb_tail) {
                                // For the last full block before bdb_tail:
                                // adjust vpad only when -vpad exceeds
                                // bdb_tail; otherwise the tail block fully
                                // covers the padding, so skip this case.
2023 if (brg.bdb_tail < -vpad)
2024 real_vpad += brg.bdb_tail;
2025 else
2026 continue;
2027 } else {
                                // For the tail block: when bdb_tail < -vpad
                                // and pre-calculated compensation is used,
                                // clamp vpad to the tail size so only the
                                // in-block padding area gets compensation.
2032 if (brg.bdb_tail < -vpad && need_comp_pads
2033 && !brg.req_cal_comp_pads)
2034 real_vpad = -brg.bdb_tail;
2035 }
2036 }
2037 cmp(reg_aux_A_vpad, vpad);
2038 jne(Vpad_loop_iter_label[label_vpad + 1], T_NEAR);
2039 ld_loop_body(real_vpad);
2040 jmp(Vpad_loop_end_label, T_NEAR);
2041 }
2042 L(Vpad_loop_iter_label[n_vpads - 1]);
2043 ld_loop_body(0);
2044 L(Vpad_loop_end_label);
2045 } else {
2046 ld_loop_body(0);
2047 }
2048 if (brg.brgattr.max_bs > 1) {
2049 dec(reg_BS_loop);
2050 cmp(reg_BS_loop, 0);
2051 jg(BS_loop_label, T_NEAR);
2052 }
2053 }
2054 }
2055
2056 if (is_ldb_loop_)
2057 mov(reg_D, ptr[rsp + reg_D_offs_]);
2058 else {
2059 if (brg.is_tmm) mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
2060 mov(reg_D, reg_ldb_loop);
2061 }
2062 if (brg.brgattr.max_bs > 1) mov(reg_aux_D, ptr[rsp + reg_aux_D_offs_]);
2063
2064 store_accumulators(bd_block2, is_bdb_tail, ld_block2, is_ld_tail,
2065 skip_accumulation);
2066 if (is_ldb_loop_) {
2067 if (brg.is_tmm) mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
2068 if (!is_ld_tail)
2069 ldb_regs_shift(ld_block2);
2070 else
2071 ldb_regs_shift(1, true);
2072 dec(reg_ldb_loop);
2073 cmp(reg_ldb_loop, 0);
2074 if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
2075 jg(ldb_loop_label, T_NEAR);
2076 }
2077 }
2078}
2079
2080template <cpu_isa_t isa, typename Wmm>
2081void jit_brgemm_kernel_t<isa, Wmm>::bdb_loop() {
2082 auto do_ldb_loop = [=](int bd_block2, bool is_bdb_tail, bool check_top_vpad,
2083 bool check_bottom_vpad, int rows_for_rd_tail,
2084 bool skip_accumulation) {
2085 if (brg.ldb2 > 0) {
2086 const bool is_ld_reg_tail = false;
2087 const bool is_ld_tail = false;
2088 ldb_loop(bd_block2, is_bdb_tail, brg.ld_block2, brg.ldb2,
2089 is_ld_reg_tail, is_ld_tail, check_top_vpad,
2090 check_bottom_vpad, rows_for_rd_tail, skip_accumulation);
2091 }
2092 if (brg.ldb2_tail > 0) {
            const bool is_ld_reg_tail = brg.ldb2 != 0;
2094 const bool is_ld_tail = false;
2095 ldb_loop(bd_block2, is_bdb_tail, brg.ldb2_tail, 1, is_ld_reg_tail,
2096 is_ld_tail, check_top_vpad, check_bottom_vpad,
2097 rows_for_rd_tail, skip_accumulation);
2098 }
2099 if (brg.ldb_tail > 0) {
            const bool is_ld_reg_tail = brg.ldb2 != 0 || brg.ldb2_tail != 0;
2102 const bool is_ld_tail = true;
2103 ldb_loop(bd_block2, is_bdb_tail, 1, 1, is_ld_reg_tail, is_ld_tail,
2104 check_top_vpad, check_bottom_vpad, rows_for_rd_tail,
2105 skip_accumulation);
2106 }
2107 };
2108
2109 auto bdb_loop_body = [=](int bd_block2, bool is_bdb_tail,
2110 bool check_top_vpad, bool check_bottom_vpad,
2111 int rows_for_rd_tail, bool skip_accumulation) {
2112 do_ldb_loop(bd_block2, is_bdb_tail, check_top_vpad, check_bottom_vpad,
2113 rows_for_rd_tail, skip_accumulation);
2114
2115 add(reg_C, bdb_C_offset(bd_block2));
2116 add(reg_D, bdb_D_offset(bd_block2));
2117 add(reg_a_offset, bdb_A_offset(bd_block2));
2118
2119 advance_bd_block2_post_op_regs(bd_block2);
2120 };
2121
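    // rows_for_rd_tail estimates how many trailing rows of A need the
    // byte-wise (safe) load for the reduction-dimension tail so that reads
    // do not run past the end of the A buffer.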
2122 int rows_for_rd_tail, bd_blocks_for_rd_tail;
2123
2124 if (brg.is_tmm) {
2125 rows_for_rd_tail = 0;
2126 bd_blocks_for_rd_tail = 0;
2127 n_bcast_1_load = false;
2128 } else {
2129 rows_for_rd_tail = 0;
2130 if (brg.rdb_tail != 0 && (brg.is_bf16 || brg.is_int8)) {
2131 const auto rd_tail_size = brg.rdb_tail % brg.rd_step;
2132 rows_for_rd_tail = rd_tail_size
2133 ? div_up(brg.rd_step - rd_tail_size, brg.reduce_dim)
2134 : 0;
2135 }
2136 bd_blocks_for_rd_tail
2137 = div_up(nstl::max(0,
2138 rows_for_rd_tail - brg.bdb_tail
2139 + brg.brgattr.max_bottom_vpad),
2140 brg.bd_block);
2141
2142 auto ld_block2 = (brg.ldb2 > 0)
2143 ? brg.ld_block2
2144 : ((brg.ldb2_tail > 0) ? brg.ldb2_tail : 1);
2145 const int free_vregs = max_vregs - brg.req_s8s8_compensation;
2146 n_bcast_1_load = brg.is_int8
2147 && ((brg.bd_block * (ld_block2 + 1) < free_vregs)
2148 && (bd_blocks_for_rd_tail == 0)
2149 && (rows_for_rd_tail == 0));
2150 if (brg.brgattr.hint_loop_order != brgemm_lo_default)
            n_bcast_1_load = brg.brgattr.hint_loop_order == brgemm_lo_bl_1load;
2154 }
2155
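    // Non-AMX path: the first, middle and last bd blocks are generated
    // separately so that top/bottom virtual-padding checks are only emitted
    // where such padding can actually occur.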
2156 auto bdb_loop_avx512 = [=](bool skip_accumulation) {
2157 Label bdb_loop_end_label, no_vpad_label;
2158 if (vpad_exist) {
            // max_top_vpad and max_bottom_vpad are restricted by bd_block
            // due to the brgemm_kernel implementation.
            // TODO: remove this restriction
2161 assert(brg.brgattr.max_top_vpad <= brg.bd_block
2162 && brg.brgattr.max_bottom_vpad <= brg.bd_block);
2163
2164 if (brg.type == brgemm_strd) {
2165 // if batch is nullptr then it means no vpadding in this call
2166 cmp(reg_offs_batch, 0);
2167 je(no_vpad_label, T_NEAR);
2168 }
2169
2170 // first bd_block --------------
2171 auto bdblocks = brg.bdb;
2172 if (bdblocks >= 1) {
2173 bdb_loop_body(1, false, true, brg.bdb == 1 && brg.bdb_tail == 0,
2174 brg.bdb - bd_blocks_for_rd_tail > 0 ? 0
2175 : rows_for_rd_tail,
2176 skip_accumulation);
2177 bdblocks--;
2178 }
2179 if (bdblocks > 1) {
2180 // middle bd_blocks -----------
2181 Label bdb_loop_label;
2182 mov(reg_bdb_loop, bdblocks);
2183 L_aligned(bdb_loop_label, 64);
2184 {
2185 bdb_loop_body(1, false, false, false,
2186 bd_blocks_for_rd_tail <= 1 ? 0 : rows_for_rd_tail,
2187 skip_accumulation);
2188 dec(reg_bdb_loop);
2189 cmp(reg_bdb_loop, 1);
2190 jg(bdb_loop_label, T_NEAR);
2191 }
2192 bdblocks = 1;
2193 }
2194 if (bdblocks == 1) {
2195 // last bd_block ------------
2196 bdb_loop_body(1, false, false, true,
2197 bd_blocks_for_rd_tail == 0 ? 0 : rows_for_rd_tail,
2198 skip_accumulation);
2199 }
2200 if (brg.bdb_tail > 0)
2201 do_ldb_loop(1, true, brg.bdb < 1, true, rows_for_rd_tail,
2202 skip_accumulation);
            // the "no vpadding" path below is also generated for brgemm_strd,
            // so jump over it here
2204 if (brg.type == brgemm_strd) jmp(bdb_loop_end_label);
2205 }
2206 if (!vpad_exist || brg.type == brgemm_strd) {
2207 // for brgemm_strd batch may be null so we need this code path
2208 L_aligned(no_vpad_label, 64);
2209 if (brg.bdb > 0) {
2210 mov(reg_bdb_loop, brg.bdb);
2211 if (brg.bdb > (rows_for_rd_tail ? 1 : 0)) {
2212 Label bdb_loop_label;
2213 L_aligned(bdb_loop_label, 64);
2214 {
2215 bdb_loop_body(1, false, false, false,
2216 bd_blocks_for_rd_tail <= 1 ? 0
2217 : rows_for_rd_tail,
2218 skip_accumulation);
2219 dec(reg_bdb_loop);
2220 cmp(reg_bdb_loop, rows_for_rd_tail ? 1 : 0);
2221 jg(bdb_loop_label, T_NEAR);
2222 }
2223 }
2224
2225 if (rows_for_rd_tail)
2226 bdb_loop_body(1, false, false, true,
2227 bd_blocks_for_rd_tail == 0 ? 0 : rows_for_rd_tail,
2228 skip_accumulation);
2229 }
2230 if (brg.bdb_tail > 0)
2231 do_ldb_loop(1, true, false, false, rows_for_rd_tail,
2232 skip_accumulation);
2233 }
2234 L_aligned(bdb_loop_end_label, 64);
2235 };
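    // AMX path: iterate over groups of bd blocks (bd_block2 at a time);
    // virtual-padding checks are not generated here.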
2236 auto bdb_loop_amx = [=](bool skip_accumulation) {
2237 Label bdb_loop_label;
2238 if (brg.bd_block2 >= 1) {
2239 mov(reg_bdb_loop, brg.bdb2);
2240 mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
2241 L_aligned(bdb_loop_label, 64);
2242 {
2243 bdb_loop_body(brg.bd_block2, false, false, false, 0,
2244 skip_accumulation);
2245 mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
2246 dec(reg_bdb_loop);
2247 cmp(reg_bdb_loop, 0);
2248 mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
2249 }
2250 jg(bdb_loop_label, T_NEAR);
2251 }
2252 if (brg.bdb2_tail > 0)
2253 bdb_loop_body(
2254 brg.bdb2_tail, false, false, false, 0, skip_accumulation);
2255 if (brg.bdb_tail > 0)
2256 do_ldb_loop(1, true, false, false, 0, skip_accumulation);
2257 };
2258
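    // Common entry: for a single-element address batch the A/B pointers are
    // read once up front, then the AMX or vector implementation is emitted.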
2259 auto bdb_loop_general = [=](bool skip_accumulation) {
2260 if (brg.type == brgemm_addr && brg.brgattr.max_bs == 1 && !vpad_exist
2261 && !skip_accumulation) {
2262 mov(reg_aux1_A, ptr[reg_addr_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
2263 mov(reg_aux1_B, ptr[reg_addr_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
2264 }
2265
2266 xor_(reg_a_offset, reg_a_offset);
2267 if (brg.is_tmm)
2268 bdb_loop_amx(skip_accumulation);
2269 else
2270 bdb_loop_avx512(skip_accumulation);
2271 };
2272
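    // Optionally emit two variants of the whole loop nest (with and without
    // accumulation) and select between them at runtime via the skip_accm
    // flag.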
2273 if (brg.brgattr.generate_skip_accumulation) {
2274 Label bdb_loop_skip_acc_label, bdb_loop_done_label;
2275 mov(reg_skip_accm, ptr[rsp + reg_skip_accm_offs_]);
2276 cmp(reg_skip_accm, 0);
2277 jnz(bdb_loop_skip_acc_label, T_NEAR);
2278
2279 bdb_loop_general(false);
2280 jmp(bdb_loop_done_label, T_NEAR);
2281
2282 L_aligned(bdb_loop_skip_acc_label, 64);
2283 bdb_loop_general(true);
2284
2285 L_aligned(bdb_loop_done_label, 64);
2286 } else
2287 bdb_loop_general(false);
2288}
2289
2290template <cpu_isa_t isa, typename Wmm>
2291void jit_brgemm_kernel_t<isa, Wmm>::generate() {
2292 preamble();
2293
2294 sub(rsp, stack_space_needed_);
2295
    vpad_exist
            = brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0;
2300 need_comp_pads = IMPLICATION(brg.zp_type_a == brgemm_broadcast_t::none,
2301 brg.req_s8s8_compensation)
2302 && IMPLICATION(!vpad_exist, brg.req_cal_comp_pads);
2303
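    // Initialize the opmask registers: ld_full_mask covers a complete
    // vector, ld_tail_mask covers only the ldb_tail remainder elements.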
2304 if (is_superset(brg.isa_impl, avx512_core)) {
2305 const auto full_mask = size_t {0xffffffffffffffff};
2306 const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
2307 reg64_t reg_mask = rax;
2308
2309 mov(reg_mask, full_mask);
2310 kmovq(ld_full_mask, reg_mask);
2311 mov(reg_mask, tail_mask);
2312 kmovq(ld_tail_mask, reg_mask);
2313 }
2314
2315 read_params();
2316
2317 bdb_loop();
2318
2319 add(rsp, stack_space_needed_);
2320
2321 postamble();
2322
2323 if (brg.with_eltwise) postops_injector_->prepare_table();
2324}
2325
2326brgemm_attr_t::brgemm_attr_t()
2327 : max_bs(INT_MAX)
2328 , max_top_vpad(0)
2329 , max_bottom_vpad(0)
2330 , hint_expected_A_size(platform::get_per_core_cache_size(1))
2331 , hint_expected_B_size(platform::get_per_core_cache_size(1))
2332 , hint_expected_C_size(platform::get_per_core_cache_size(1))
2333 , hint_innermost_loop(brgemm_ld_loop_innermost)
2334 , hint_loop_order(brgemm_kernel_loop_order_t::brgemm_lo_default)
2335 , hint_prefetching(brgemm_kernel_prefetching_t::brgemm_prf_default)
2336 , wary_tail_read(true)
2337 , generate_skip_accumulation(false)
2338 , bd_mask(nullptr)
2339 , bd_mask_level(0)
2340 , use_uker(false)
2341 , use_interleave_stores(false)
2342 , LDA2(0)
2343 , LDB2(0)
2344 , LDC2_M(0)
2345 , LDC2_N(0) {}
2346
2347template <cpu_isa_t isa, typename Wmm>
2348brgemm_kernel_common_t<isa, Wmm>::brgemm_kernel_common_t(const brgemm_t abrd) {
2349 brgemm_kernel_ = new jit_brgemm_kernel_t<isa, Wmm>(abrd);
2350}
2351
2352template <cpu_isa_t isa, typename Wmm>
2353status_t brgemm_kernel_common_t<isa, Wmm>::create_kernel() {
2354 return brgemm_kernel_->create_kernel();
2355}
2356
2357template <cpu_isa_t isa, typename Wmm>
2358void brgemm_kernel_common_t<isa, Wmm>::operator()(
2359 brgemm_kernel_params_t *params) const {
2360 (*brgemm_kernel_)(params);
2361}
2362
2363template <cpu_isa_t isa, typename Wmm>
2364brgemm_kernel_common_t<isa, Wmm>::~brgemm_kernel_common_t() {
2365 delete brgemm_kernel_;
2366}
2367
// ISA-specific instantiations are required because
// post-ops require the ISA template parameter.
2370template struct brgemm_kernel_common_t<avx512_core_amx_fp16, Xbyak::Tmm>;
2371template struct brgemm_kernel_common_t<avx512_core_amx, Xbyak::Tmm>;
2372template struct brgemm_kernel_common_t<avx512_core_fp16, Xbyak::Zmm>;
2373template struct brgemm_kernel_common_t<avx512_core_bf16, Xbyak::Zmm>;
2374template struct brgemm_kernel_common_t<avx512_core_vnni, Xbyak::Zmm>;
2375template struct brgemm_kernel_common_t<avx512_core, Xbyak::Zmm>;
2376template struct brgemm_kernel_common_t<avx2_vnni, Xbyak::Ymm>;
2377template struct brgemm_kernel_common_t<avx2_vnni_2, Xbyak::Ymm>;
2378template struct brgemm_kernel_common_t<avx2, Xbyak::Ymm>;
2379} // namespace x64
2380} // namespace cpu
2381} // namespace impl
2382} // namespace dnnl
2383
2384// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
2385