/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"

#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace Xbyak;

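// The constructor only prepares optional helpers: a post-ops injector when
// eltwise/binary/sum post-ops are present, and a bf16 emulation context when
// brg.is_bf16_emu is set. Code emission itself happens in generate().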
template <cpu_isa_t isa, typename Wmm>
jit_brdgmm_kernel_base_t<isa, Wmm>::jit_brdgmm_kernel_base_t(
        const brgemm_t &abrd)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , brg(abrd)
    , simd_w_(vreg_traits<Vmm>::vlen / brg.typesize_C)
    , max_vmms_(isa_num_vregs(isa)) {

    if (brg.with_eltwise || brg.with_binary || brg.with_sum) {

        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md);
        const size_t tail = tail_length();

        static const bcast_set_t enabled_bcast_strategy
                = {broadcasting_strategy_t::scalar,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::no_broadcast};
        const binary_injector::rhs_arg_static_params_t rhs_sp {
                static_cast<size_t>(vmm_b().getIdx()), r14, r15, r13,
                preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
                dst_md_wrapper, tail, k_mask, use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t bsp {
                this->param1, enabled_bcast_strategy, rhs_sp};

        postops_injector_ = utils::make_unique<po_injector_t>(
                this, brg.attr->post_ops_, bsp);

        with_binary_non_scalar_bcast_
                = binary_injector::any_binary_postop_rhs_non_scalar_broadcast(
                        brg.attr->post_ops_, dst_md_wrapper);
    }
    if (brg.is_bf16_emu)
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4);
}

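// Attaches k_mask to the register when mask_flag is set: merge-masking for
// stores, zero-masking (T_z) for loads.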
template <cpu_isa_t isa, typename Wmm>
template <typename U>
U jit_brdgmm_kernel_base_t<isa, Wmm>::maybe_mask(
        const U umm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? umm_in | k_mask : umm_in | k_mask | T_z)
                     : umm_in;
}

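// Reads the runtime arguments from brgemm_kernel_params_t (param1) and spills
// the pointers needed later (batch, bias, scales, abi param, strided A/B
// bases) to the stack frame reserved in generate().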
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::read_params() {
    Label label_done;

    mov(reg_BS, ptr[param1 + GET_OFF(BS)]);
    mov(reg_aux_C, ptr[param1 + GET_OFF(ptr_C)]);
    mov(reg_aux_D, ptr[param1 + GET_OFF(ptr_D)]);

    if (brg.type == brgemm_offs) {
        mov(reg_A, ptr[param1 + GET_OFF(ptr_A)]);
        mov(reg_B, ptr[param1 + GET_OFF(ptr_B)]);
    } else if (brg.type == brgemm_strd) {
        mov(reg_aux1_A, ptr[param1 + GET_OFF(ptr_A)]);
        mov(reg_aux1_B, ptr[param1 + GET_OFF(ptr_B)]);
        if (brg.brgattr.max_bs > 1) {
            mov(ptr[rsp + reg_A_offs_], reg_aux1_A);
            mov(ptr[rsp + reg_B_offs_], reg_aux1_B);
        }
    }

    if (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()) {
        mov(reg_aux_batch_addr, ptr[param1 + GET_OFF(batch)]);
        if (brg.brgattr.max_bs > 1)
            mov(ptr[rsp + reg_batch0_addr_offs_], reg_aux_batch_addr);
    }

    if (brg.with_bias) {
        mov(reg_tmp, ptr[param1 + GET_OFF(ptr_bias)]);
        mov(ptr[rsp + reg_bias_offs_], reg_tmp);
    }

    if (brg.with_scales) {
        mov(reg_tmp, ptr[param1 + GET_OFF(ptr_scales)]);
        mov(ptr[rsp + reg_scales_offs_], reg_tmp);
    }

    if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);
}

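// Zeroes the accumulator registers of an m_blocks x n_blocks tile.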
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_accumulators(
        int m_blocks, int n_blocks) {
    const int v_substep = vnni_substep();
    for_(int v = 0; v < v_substep; ++v)
    for_(int m = 0; m < m_blocks; ++m)
    for (int n = 0; n < n_blocks; ++n) {
        auto vmm = accm(m_blocks, n_blocks, m, n, v);
        uni_vpxor(vmm, vmm, vmm);
    }
}

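// Re-initializes the per-batch A/B pointers (or the batch descriptor pointer)
// from the stack before each pass over the batch elements.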
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::restore_A_B_matrices() {
    if (brg.brgattr.max_bs > 1
            && (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()))
        mov(reg_aux_batch_addr, ptr[rsp + reg_batch0_addr_offs_]);

    if (brg.type == brgemm_strd && brg.brgattr.max_bs > 1) {
        mov(reg_aux1_A, ptr[rsp + reg_A_offs_]);
        mov(reg_aux1_B, ptr[rsp + reg_B_offs_]);
    }
}

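// Computes the effective A/B addresses for the current batch element according
// to the addressing mode (explicit pointers, base + offsets, or fixed
// strides), then applies the current row offset to A and the N offset to B.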
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::set_A_B_matrices() {

    if (brg.type == brgemm_addr) {
        mov(reg_aux_A, ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(ptr.A)]);
        mov(reg_aux_B, ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(ptr.B)]);
    } else if (brg.type == brgemm_offs) {
        mov(reg_aux_A, reg_A);
        mov(reg_aux_B, reg_B);
        add(reg_aux_A,
                ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(offset.A)]);
        add(reg_aux_B,
                ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(offset.B)]);
    } else if (brg.type == brgemm_strd) {
        mov(reg_aux_A, reg_aux1_A);
        mov(reg_aux_B, reg_aux1_B);
        if (brg.brgattr.max_bs > 1) {
            safe_add(reg_aux1_A, brg.stride_a, reg_tmp);
            safe_add(reg_aux1_B, brg.stride_b, reg_tmp);
        }
    }

    add(reg_aux_A, reg_a_offset);
    lea(reg_aux_B, ptr[reg_aux_B + reg_aux_N * brg.typesize_B]);
}

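// Advances to the next brgemm_batch_element_t entry when a batch descriptor is
// in use (addr/offs modes or virtual padding).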
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::advance_A_B_matrices() {
    if (brg.brgattr.max_bs > 1
            && (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()))
        add(reg_aux_batch_addr, sizeof(brgemm_batch_element_t));
}

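// Loads a vector from 'op' and up-converts it to f32: opmask-based tail
// handling where the ISA provides masks, otherwise a zero-filled partial load
// via load_data; integer inputs are finally converted with vcvtdq2ps.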
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag,
        bool store) {
    const int tail_size = tail_length();
    const bool is_load_tail = op.isMEM() && mask_flag && tail_size > 0
            && (tail_size
                    < static_cast<int>(vreg_traits<Vmm>::vlen / sizeof(float)));
    if (IMPLICATION(is_load_tail, isa_has_masks(brg.isa_impl))) {
        const Vmm vmm = maybe_mask(vmm_in, is_load_tail, store);
        switch (type_in) {
            case data_type::f32:
            case data_type::s32: vmovups(vmm, op); break;
            case data_type::bf16:
                vpmovzxwd(vmm, op);
                vpslld(vmm, vmm, 16);
                break;
            case data_type::f16: vcvtph2ps(vmm, op); break;
            case data_type::s8: vpmovsxbd(vmm, op); break;
            case data_type::u8: vpmovzxbd(vmm, op); break;
            default: assert(!"unsupported data type");
        }
    } else {
        uni_vpxor(vmm_in, vmm_in, vmm_in);
        load_data(type_in, vmm_in, op.getAddress(), tail_size);
    }
    if (types::is_integral_dt(type_in)) vcvtdq2ps(vmm_in, vmm_in);
}

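// Applies the attribute post-ops to the accumulators of the current tile.
// Binary post-ops get per-vmm output addresses; the sum post-op is injected as
// a lambda that reloads D and applies the sum zero-point and scale if present.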
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::apply_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    injector_utils::vmm_index_set_t vmm_idxs_param;
    const int v_substep = vnni_substep();

    // Collect the vmm indices to apply post-ops to. In case of avx2_vnni_2
    // tails, it is possible that post-ops do not need to be applied to the
    // last vnni_substep.
    for_(int v_i = 0; v_i < v_substep; ++v_i)
    for_(int m_i = 0; m_i < m_blocks; ++m_i)
    for (int n_i = 0; n_i < n_blocks; ++n_i) {
        if (get_substep_simd(n_i, v_i, has_n_tail) <= 0) continue;
        const auto vmm_idx = accm(m_blocks, n_blocks, m_i, n_i, v_i).getIdx();
        vmm_idxs_param.insert(vmm_idx);
    }

    if (brg.with_binary) {
        mov(reg_binary_params, ptr[rsp + abi_param1_offs_]);

        if (with_binary_non_scalar_bcast_) {

            for_(int v_i = 0; v_i < v_substep; ++v_i)
            for_(int m_i = 0; m_i < m_blocks; m_i++)
            for (int n_i = 0; n_i < n_blocks; n_i++) {
                const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
                if (substep_simd <= 0) continue;
                const auto vmm_idx
                        = accm(m_blocks, n_blocks, m_i, n_i, v_i).getIdx();
                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_aux_D);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, D_offset(m_i, n_i, v_i));

                if (n_i + 1 == n_blocks && has_n_tail && substep_simd < simd_w_)
                    rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }

    const auto sum_injector = [&] {
        const float *p_sum_scale = &brg.sum_scale;
        const int32_t *p_sum_zp = &brg.sum_zp;
        const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
        const bool p_sum_zp_reg_set = *p_sum_zp != 0;

        const injector_utils::conditional_register_preserve_guard_t
                register_guard_sum_scale(
                        (with_binary_non_scalar_bcast_) && p_sum_scale_reg_set,
                        this, {reg_ptr_sum_scale});
        const injector_utils::conditional_register_preserve_guard_t
                register_guard_sum_zp(p_sum_zp_reg_set, this, {reg_ptr_sum_zp});

        if (p_sum_scale_reg_set)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));

        auto vmm_sum_zp = vmm_tmp(0);
        if (p_sum_zp_reg_set) {
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
            vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
        }

        for_(int m_i = 0; m_i < m_blocks; m_i++)
        for_(int n_i = 0; n_i < n_blocks; n_i++)
        for (int v_i = 0; v_i < v_substep; v_i++) {
            const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const auto vmm = accm(m_blocks, n_blocks, m_i, n_i, v_i);
            const auto addr = ptr[reg_aux_D + D_offset(m_i, n_i, v_i)];
            const auto vmm_prev_dst = vmm_tmp(1);
            cvt2ps(brg.sum_dt, vmm_prev_dst, addr, substep_simd != simd_w_,
                    false);
            if (p_sum_zp_reg_set) vsubps(vmm_prev_dst, vmm_sum_zp);
            if (!p_sum_scale_reg_set)
                vaddps(vmm, vmm_prev_dst);
            else {
                if (is_superset(brg.isa_impl, avx512_core)) {
                    vfmadd231ps(vmm, vmm_prev_dst, ptr_b[reg_ptr_sum_scale]);
                } else {
                    auto vmm_scale = vmm_tmp(2);
                    uni_vpbroadcastd(vmm_scale, ptr[reg_ptr_sum_scale]);
                    uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_scale);
                }
            }
        }
    };

    if (brg.with_sum) {
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }

    postops_injector_->compute_vector_range(vmm_idxs_param, rhs_arg_params);
}

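// Store path with post-processing: applies scales and bias, runs the post-ops
// injector, saturates and converts to the destination data type, then writes
// the tile to D (with masked or emulated tails).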
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators_apply_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    const bool dq2ps_required = brg.is_int8;
    const int v_substep = vnni_substep();
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_scales_offs_]);
        if (brg.is_oc_scale) {
            lea(reg_aux_scales,
                    ptr[reg_aux_scales + reg_aux_N * sizeof(float)]);
        }
        for_(int m = 0; m < m_blocks; m++)
        for_(int n = 0; n < n_blocks; n++)
        for (int v_i = 0; v_i < v_substep; ++v_i) {
            const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const bool mask_flag = substep_simd < simd_w_;
            const Vmm vmm = maybe_mask(
                    accm(m_blocks, n_blocks, m, n, v_i), mask_flag, false);
            if (dq2ps_required) vcvtdq2ps(vmm, vmm);
            if (IMPLICATION(mask_flag || !brg.is_oc_scale,
                        is_superset(brg.isa_impl, avx512_core))) {
                if (brg.is_oc_scale) {
                    vmulps(vmm, vmm,
                            ptr[reg_aux_scales + scales_offset(n, v_i)]);
                } else {
                    vmulps(vmm, vmm, ptr_b[reg_aux_scales]);
                }
            } else {
                auto vmm_scale = vmm_tmp(0);
                const auto addr = ptr[reg_aux_scales + scales_offset(n, v_i)];
                if (brg.is_oc_scale) {
                    uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
                    load_data(data_type::f32, vmm_scale, addr, substep_simd);
                } else {
                    vbroadcastss(vmm_scale, ptr[reg_aux_scales]);
                }
                vmulps(vmm, vmm, vmm_scale);
            }
        }
    }

    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_bias_offs_]);
        lea(reg_aux_bias, ptr[reg_aux_bias + reg_aux_N * brg.typesize_bias]);
    }

    for_(int v_i = 0; v_i < v_substep; ++v_i)
    for (int n = 0; n < n_blocks; n++) {
        auto vmm_bias = vmm_tmp(0);
        const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
        if (substep_simd <= 0) continue;
        if (brg.with_bias) {
            auto ptr_bias = ptr[reg_aux_bias + bias_offset(n, v_i)];
            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, substep_simd != simd_w_,
                    false);
        }
        for (int m = 0; m < m_blocks; m++) {
            auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
            if (dq2ps_required && !brg.with_scales) vcvtdq2ps(vmm, vmm);
            if (brg.with_bias) { vaddps(vmm, vmm, vmm_bias); }
        }
    }

    if (postops_injector_) apply_post_ops(m_blocks, n_blocks, has_n_tail);

    const bool dt_requires_saturation
            = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
    auto vmm_lbound = vmm_tmp(0);
    auto vmm_ubound = vmm_tmp(1);
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d);
    }

    if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

    for (int m = 0; m < m_blocks; m++) {
        if (dt_requires_saturation) {
            for_(int n = 0; n < n_blocks; n++)
            for (int v_i = 0; v_i < v_substep; ++v_i) {
                if (get_substep_simd(n, v_i, has_n_tail) <= 0) continue;
                auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
                saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                vcvtps2dq(vmm, vmm);
            }
        }

        for_(int n = 0; n < n_blocks; n++)
        for (int v_i = 0; v_i < v_substep; ++v_i) {
            const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const auto offset = D_offset(m, n, v_i);
            auto addr = ptr[reg_aux_D + offset];
            auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
            auto vmm_low = Vmm_low_t(vmm.getIdx());
            const bool mask_flag = substep_simd < simd_w_;
            const Vmm r_vmm = maybe_mask(vmm, mask_flag, true);
            const Vmm_low_t r_vmm_low = maybe_mask(vmm_low, mask_flag, true);
            if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
                switch (brg.dt_d) {
                    case data_type::f32:
                    case data_type::s32: vmovups(addr, r_vmm); break;
                    case data_type::bf16:
                        if (brg.is_bf16_emu)
                            bf16_emu_->vcvtneps2bf16(vmm_low, vmm);
                        else
                            vcvtneps2bf16(vmm_low, vmm,
                                    brg.isa_impl == avx2_vnni_2
                                            ? Xbyak::VexEncoding
                                            : Xbyak::EvexEncoding);
                        if (mask_flag)
                            vmovdqu16(addr, r_vmm_low);
                        else
                            vmovups(addr, r_vmm_low);
                        break;
                    case data_type::f16:
                        vcvtps2ph(addr, r_vmm, _op_mxcsr);
                        break;
                    case data_type::s8: vpmovsdb(addr, r_vmm); break;
                    case data_type::u8: vpmovusdb(addr, r_vmm); break;
                    default: assert(!"unknown dst_dt");
                }
            } else {
                store_data(brg.dt_d, vmm, reg_aux_D, offset, substep_simd);
            }
        }
    }
}

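// Store path without post-processing: optionally saturates int8 results and
// writes the accumulators to C in brg.dt_c.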
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators_without_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    const bool dt_requires_saturation
            = brg.is_int8 && brg.dt_c != data_type::s32;
    auto vmm_lbound = vmm_tmp(0);
    auto vmm_ubound = vmm_tmp(1);
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d);
    }

    for_(int m = 0; m < m_blocks; m++)
    for_(int n = 0; n < n_blocks; n++)
    for (int v_i = 0; v_i < vnni_substep(); ++v_i) {
        const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
        if (substep_simd <= 0) continue;
        const bool mask_flag = substep_simd < simd_w_;
        auto vmm_acc = accm(m_blocks, n_blocks, m, n, v_i);
        if (dt_requires_saturation) {
            saturate_f32(vmm_acc, vmm_lbound, vmm_ubound, brg.dt_d);
            vcvtps2dq(vmm_acc, vmm_acc);
        }
        const auto offset = C_offset(m, n, v_i);
        if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
            auto vmm_acc_masked = maybe_mask(vmm_acc, mask_flag, true);
            vmovups(ptr[reg_aux_C + offset], vmm_acc_masked);
        } else {
            store_data(brg.dt_c, vmm_acc, reg_aux_C, offset, substep_simd);
        }
    }
}

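// Used when vnni_substep() > 1 (the avx2_vnni_2 even/odd load path): the
// accumulators hold interleaved even/odd elements from the vcvtnee*/vcvtneo*
// loads, so unpack and permute them back to plain element order before storing.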
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa,
        Wmm>::maybe_transpose_interleaved_vnni_to_plain(int m_blocks,
        int n_blocks, bool has_n_tail) {

    if (vnni_substep() == 1) return;

    // The tail block is always processed as plain.
    // No need to transpose it here.
    const int n_blocks_e = n_blocks - has_n_tail;

    auto ymm_aux0 = vmm_tmp(0);
    for_(int m_i = 0; m_i < m_blocks; m_i++)
    for (int n_i = 0; n_i < n_blocks_e; n_i++) {
        auto ymm_even = accm(m_blocks, n_blocks, m_i, n_i, 0);
        auto ymm_odd = accm(m_blocks, n_blocks, m_i, n_i, 1);
        // reusing ymm_odd as aux
        // TODO: Check for any latency due to register dependency
        auto ymm_aux1 = ymm_odd;
        vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
        vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
        vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
        vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
    }
}

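// Stores one accumulator tile: undoes the even/odd interleaving if needed,
// applies the fast-int8 result permutation, and dispatches to the post-ops or
// plain store path.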
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators(
        int m_blocks, int n_blocks, bool has_n_tail) {

    maybe_transpose_interleaved_vnni_to_plain(m_blocks, n_blocks, has_n_tail);

    if (is_fast_vnni_int8() && brg.is_bf16_emu) {
        // load permute indices from data section
        mov(reg_tmp, permute_index_table);
        vmovdqu32(vmm_permute(), ptr[reg_tmp]);
    }

    if (is_fast_vnni_int8()) {
        for_(int m_i = 0; m_i < m_blocks; ++m_i)
        for (int n_i = 0; n_i < n_blocks; ++n_i) {
            auto vmm_out = accm(m_blocks, n_blocks, m_i, n_i, 0);
            vpermd(vmm_out, vmm_permute(), vmm_out);
        }
    }

    const bool are_post_ops_applicable
            = one_of(true, brg.with_eltwise, brg.with_binary, brg.with_scales,
                    brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c);

    Label label_done;
    if (are_post_ops_applicable) {
        store_accumulators_apply_post_ops(m_blocks, n_blocks, has_n_tail);
    } else {
        store_accumulators_without_post_ops(m_blocks, n_blocks, has_n_tail);
    }
}

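// Loads one simd-wide slice of A for block (m_i, n_i) and vnni substep v_i,
// widening it according to the source data type; tails use opmasks where the
// ISA supports them and an emulated partial load otherwise.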
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_a(
        Vmm vmma, int m_i, int n_i, int v_i, bool has_n_tail) {
    const int n_blocks
            = has_n_tail && n_block2_tail() > 0 ? n_block2_tail() : n_block2();
    const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
    const bool is_tail_block = has_n_tail && n_i + 1 == n_blocks;
    const bool mask_flag = substep_simd < simd_w_;
    const auto addr = ptr[reg_aux_A + A_offset(m_i, n_i)
            + is_tail_block * v_i * simd_w_ * brg.typesize_A];
    if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
        vmma = maybe_mask(vmma, mask_flag, false);
        if (brg.is_f32) {
            vmovups(vmma, addr);
        } else if (brg.is_bf16) {
            if (brg.isa_impl == avx2_vnni_2) {
                if (is_tail_block) {
                    vpmovzxwd(vmma, addr);
                    vpslld(vmma, vmma, 16);
                } else if (v_i == 0)
                    vcvtneebf162ps(vmma, addr);
                else
                    vcvtneobf162ps(vmma, addr);
            } else {
                vpmovzxwd(vmma, addr);
                if (brg.is_bf16_tmm) vpslld(vmma, vmma, 16);
            }
        } else if (brg.is_f16) {
            if (brg.isa_impl == avx2_vnni_2) {
                if (is_tail_block)
                    vcvtph2ps(vmma, addr);
                else if (v_i == 0)
                    vcvtneeph2ps(vmma, addr);
                else
                    vcvtneoph2ps(vmma, addr);
            } else
                vcvtph2ps(vmma, addr);
        } else if (brg.is_int8) {
            if (is_fast_vnni_int8()) {
                assert(!mask_flag);
                vbroadcasti32x4(vmma, addr);
            } else
                vpmovzxbd(vmma, addr);
        }
    } else {
        uni_vpxor(vmma, vmma, vmma);
        load_data(brg.dt_a, vmma, addr, substep_simd);
    }
}

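// Loads the matching slice of B, mirroring load_a for the floating-point data
// types. On the fast int8 path, 16 weight bytes are broadcast to every 128-bit
// lane and blended through kblend_mask so that each dword keeps a single s8
// weight byte for vpdpbusd.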
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_b(
        Vmm vmmb, int n_i, int v_i, bool has_n_tail) {
    // For the B matrix we assume the memory is padded, so it is safe to load a
    // full simd width. is_tail_block is only used for the avx_ne_convert tail
    // optimization.
    const int n_blocks
            = has_n_tail && n_block2_tail() > 0 ? n_block2_tail() : n_block2();
    const bool is_tail_block = has_n_tail && (n_i + 1 == n_blocks);
    const auto addr = ptr[reg_aux_B + B_offset(n_i)
            + is_tail_block * v_i * simd_w_ * brg.typesize_B];
    if (brg.is_f32) {
        vmovups(vmmb, addr);
    } else if (brg.is_int8) {
        // wei is sign-extended (s8), whereas src is zero-extended (u8).
        if (is_fast_vnni_int8()) {
            vbroadcasti32x4(vmmb, addr);
            vmovdqu8(vmmb | kblend_mask | T_z, vmmb);
        } else {
            vpmovsxbd(vmmb, addr);
        }
    } else if (brg.is_f16) {
        if (brg.isa_impl == avx2_vnni_2) {
            if (is_tail_block)
                vcvtph2ps(vmmb, addr);
            else if (v_i == 0)
                vcvtneeph2ps(vmmb, addr);
            else
                vcvtneoph2ps(vmmb, addr);
        } else
            vcvtph2ps(vmmb, addr);
    } else if (brg.is_bf16) {
        if (brg.isa_impl == avx2_vnni_2) {
            if (is_tail_block) {
                vpmovzxwd(vmmb, addr);
                vpslld(vmmb, vmmb, 16);
            } else if (v_i == 0)
                vcvtneebf162ps(vmmb, addr);
            else
                vcvtneobf162ps(vmmb, addr);
        } else {
            vpmovzxwd(vmmb, addr);
            if (brg.is_bf16_tmm) vpslld(vmmb, vmmb, 16);
        }
    }
}

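// Computes the accumulator updates for one batch element. Without virtual
// padding, B slices are preloaded in groups up to the register budget; with
// padding, a jump table (top padding) and a bottom-padding check skip the
// padded rows of A.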
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::brdgmm_microkernel(int m_blocks,
        int n_blocks, bool has_top_padding, bool has_bottom_padding,
        bool has_tail) {

    const bool has_padding = has_top_padding || has_bottom_padding;
    const int max_bvmms
            = accm(m_blocks, n_blocks, 0, 0, 0).getIdx() - vmm_b(0).getIdx();
    const int v_substep = vnni_substep();

    auto dot_product = [&](Vmm vmma, Vmm vmmb, int m_i, int n_i, int v_i) {
        auto vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, v_i);
        if (brg.is_f32) {
            if (is_fma_embd()) {
                const bool mask_flag = has_tail && (n_i + 1 == n_blocks);
                const auto addr = ptr[reg_aux_A + A_offset(m_i, n_i)];
                vmm_acc = maybe_mask(vmm_acc, mask_flag, false);
                vfmadd231ps(vmm_acc, vmmb, addr);
            } else {
                vfmadd231ps(vmm_acc, vmma, vmmb);
            }
        } else if (brg.is_bf16) {
            if (brg.is_bf16_tmm /* don't use vdpbf16ps on CPUs supporting AMX
                                   due to poor performance */
                    || brg.isa_impl == avx2_vnni_2)
                vfmadd231ps(vmm_acc, vmma, vmmb);
            else
                vdpbf16ps(vmm_acc, vmma, vmmb);
        } else if (brg.is_f16) {
            vfmadd231ps(vmm_acc, vmma, vmmb);
        } else if (brg.is_int8) {
            vpdpbusd(vmm_acc, vmma, vmmb);
        }
    };

    if (!has_padding) {
        // preload vmm_b if possible.
        for_(int v_i = 0; v_i < v_substep; ++v_i)
        for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
            const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
            for (int i = 0; i < n_e; ++i) {
                const int n_i = nb_i + i;
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                load_b(vmm_b(i), n_i, v_i, has_tail);
            }
            for_(int m_i = 0; m_i < m_blocks; ++m_i)
            for (int i = 0; i < n_e; ++i) {
                const int n_i = nb_i + i;
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                if (!is_fma_embd()) load_a(vmm_a(), m_i, n_i, v_i, has_tail);
                dot_product(vmm_a(), vmm_b(i), m_i, n_i, v_i);
            }
        }
    } else {
        const int max_req_preload_vmms = n_blocks * vnni_substep();
        const int n_preload_b_vmms = max_bvmms >= max_req_preload_vmms
                ? max_req_preload_vmms
                : max_bvmms - 1 /*for ad-hoc load*/;
        for (int i = 0; i < n_preload_b_vmms; ++i) {
            const int n_i = i % n_blocks;
            const int v_i = i / n_blocks;
            if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
            load_b(vmm_b(i), n_i, v_i, has_tail);
        }

        Label done;
        Label jmp_table_base;
        std::vector<Label> jmp_table_labels(m_blocks);
        if (has_top_padding) {
            // jmp table
            mov(reg_table_base, jmp_table_base);
            lea(reg_table_base,
                    ptr[reg_table_base + reg_aux_A_vpad_top * sizeof(void *)]);
            jmp(ptr[reg_table_base]);
            align(8);
            L(jmp_table_base);
            for (int m_i = 0; m_i < m_blocks; ++m_i) {
                putL(jmp_table_labels[m_i]);
            }
        }

        for (int m_i = 0; m_i < m_blocks; ++m_i) {
            L(jmp_table_labels[m_i]);
            if (has_bottom_padding) {
                cmp(reg_aux_A_vpad_bottom, m_blocks - m_i);
                jge(done, T_NEAR);
            }

            for_(int v_i = 0, p_b_i = 0; v_i < v_substep; ++v_i)
            for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                if (!is_fma_embd()) load_a(vmm_a(), m_i, n_i, v_i, has_tail);
                if (p_b_i < n_preload_b_vmms) {
                    dot_product(vmm_a(), vmm_b(p_b_i), m_i, n_i, v_i);
                } else {
                    // preloaded vmm_b not available
                    const int b_idx = max_bvmms - 1;
                    load_b(vmm_b(b_idx), n_i, v_i, has_tail);
                    dot_product(vmm_a(), vmm_b(b_idx), m_i, n_i, v_i);
                }
            }
        }
        L(done);
    }
}

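// Runs the batch (BS) loop for one tile: zeroes the accumulators, walks the
// batch elements computing their effective virtual padding, invokes the
// microkernel per element, and finally stores the tile.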
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::batch_loop(
        const int m_blocks, const int n_blocks, bool has_n_tail) {

    auto get_padding_info = [&]() {
        const bool do_check_effective_padding = check_effective_padding();
        if (has_vpad()) {
            Label no_top_padding;

            if (brg.brgattr.max_bottom_vpad > 0) {
                if (do_check_effective_padding) {
                    Label done_adjust_bottom_padding;
                    mov(reg_aux_A_vpad_bottom, reg_aux_M);
                    add(reg_aux_A_vpad_bottom, m_blocks - M());
                    add(reg_aux_A_vpad_bottom,
                            ptr[reg_aux_batch_addr
                                    + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
                    jge(done_adjust_bottom_padding, T_NEAR);
                    xor_(reg_aux_A_vpad_bottom, reg_aux_A_vpad_bottom);
                    L(done_adjust_bottom_padding);
                } else {
                    mov(reg_aux_A_vpad_bottom,
                            ptr[reg_aux_batch_addr
                                    + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
                }
                mov(reg_total_padding, reg_aux_A_vpad_bottom);
            }
            if (brg.brgattr.max_top_vpad > 0) {
                mov(reg_aux_A_vpad_top,
                        ptr[reg_aux_batch_addr
                                + GET_OFF_BATCH_ELEMENT(vvpad.top)]);
                if (do_check_effective_padding) {
                    Label done_adjust_top_padding;
                    sub(reg_aux_A_vpad_top, reg_aux_M);
                    jge(done_adjust_top_padding, T_NEAR);
                    xor_(reg_aux_A_vpad_top, reg_aux_A_vpad_top);
                    L(done_adjust_top_padding);
                }
                if (brg.brgattr.max_bottom_vpad > 0) {
                    add(reg_total_padding, reg_aux_A_vpad_top);
                } else {
                    mov(reg_total_padding, reg_aux_A_vpad_top);
                }
            }
        }
    };

    auto call_brdgmm_microkernel = [&]() {
        const int tpad = brg.brgattr.max_top_vpad;
        const int bpad = brg.brgattr.max_bottom_vpad;
        const bool vpad_exists = has_vpad();
        Label microkernel_with_padding, done_microkernel;

        if (vpad_exists) {
            cmp(reg_total_padding, 0);
            jg(microkernel_with_padding, T_NEAR);
        }
        brdgmm_microkernel(m_blocks, n_blocks, false, false, has_n_tail);
        if (vpad_exists) {
            jmp(done_microkernel, T_NEAR);
            L(microkernel_with_padding);
            if ((tpad + bpad) >= m_blocks) {
                cmp(reg_total_padding, m_blocks);
                jge(done_microkernel, T_NEAR);
            }
            brdgmm_microkernel(m_blocks, n_blocks, tpad, bpad, has_n_tail);
        }
        L(done_microkernel);
    };

    Label bs_loop_label, done_bs_loop;
    load_accumulators(m_blocks, n_blocks);
    cmp(reg_BS, 0);
    jle(done_bs_loop, T_NEAR);
    mov(reg_BS_loop, reg_BS);
    restore_A_B_matrices();

    L(bs_loop_label);
    {
        set_A_B_matrices();
        get_padding_info();
        advance_A_B_matrices();
        call_brdgmm_microkernel();
        dec(reg_BS_loop);
        jg(bs_loop_label, T_NEAR);
    }

    L(done_bs_loop);

    store_accumulators(m_blocks, n_blocks, has_n_tail);
}

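// Outer blocking: drives the M and N block loops around batch_loop, handling
// the N tail (via masks or a separate tail block) and the trailing M block.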
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::compute_loop() {

    const bool has_m_block2_tail = m_block2_tail() > 0;
    const int loop_m = (nb_m_block2() - has_m_block2_tail);
    const bool do_loop_m = loop_m > 1;

    const bool has_n_block2_tail = n_block2_tail() > 0;
    const bool need_separate_n_block1_tail_block = n_block1_tail() != 0
            && !has_n_block2_tail && nb_n_block2() > 1
            && !isa_has_masks(brg.isa_impl);
    const int loop_n = nb_n_block2() - has_n_block2_tail
            - need_separate_n_block1_tail_block;
    const bool do_loop_n = loop_n > 1;
    const bool loop_n_update_aux_ptrs = do_loop_n || (loop_n < nb_n_block2());

    auto n_loop = [&](int m_blocks) {
        Label n_loop_label;
        const int n_blocks = n_block2();
        const int n_loop_step = oc_logical_offset(n_blocks);
        const int n_loop_work = loop_n * n_blocks * n_block1();
        const bool vlen_tail_in_loop = n_block1_tail() != 0
                && !need_separate_n_block1_tail_block && !has_n_block2_tail;

        xor_(reg_aux_N, reg_aux_N);

        L(n_loop_label);
        {
            if (do_loop_n) {
                if (vlen_tail_in_loop) {
                    Label done_k_mask;
                    cmp(reg_aux_N, n_loop_work - n_loop_step);
                    jl(done_k_mask, T_NEAR);
                    kmovd(k_mask, k_tail_mask);
                    L(done_k_mask);
                }
            }

            batch_loop(m_blocks, n_blocks, vlen_tail_in_loop);

            if (loop_n_update_aux_ptrs) {
                add(reg_aux_N, n_loop_step);
                add(reg_a_offset, n_loop_step * brg.typesize_A);
                add(reg_aux_C, n_loop_step * brg.typesize_C);
                add(reg_aux_D, n_loop_step * brg.typesize_D);
            }

            if (do_loop_n) {
                cmp(reg_aux_N, n_loop_work);
                jl(n_loop_label, T_NEAR);
            }
        }

        if (need_separate_n_block1_tail_block)
            batch_loop(m_blocks, n_blocks, true);

        if (has_n_block2_tail) {
            batch_loop(m_blocks, n_block2_tail(), n_block1_tail() != 0);
        }
    };

    auto m_loop = [&]() {
        Label m_loop_label;
        const int m_blocks = m_block2();
        const bool reset_mask = isa_has_masks(brg.isa_impl)
                && n_block1_tail() != 0 && do_loop_n && !has_n_block2_tail;

        xor_(reg_aux_M, reg_aux_M);
        xor_(reg_a_offset, reg_a_offset);

        L(m_loop_label);
        {
            if (reset_mask) kxnorq(k_mask, k_mask, k_mask);
            n_loop(m_blocks);

            if (do_loop_m || has_m_block2_tail) {
                add(reg_aux_M, m_blocks);
                const int n_loop_offset
                        = loop_n_update_aux_ptrs * loop_n * n_block2();
                add(reg_a_offset, A_offset(m_blocks, -n_loop_offset));
                add(reg_aux_C, C_offset(m_blocks, -n_loop_offset, 0));
                add(reg_aux_D, D_offset(m_blocks, -n_loop_offset, 0));
            }

            if (do_loop_m) {
                cmp(reg_aux_M, loop_m * m_block2());
                jl(m_loop_label, T_NEAR);
            }
        }

        if (m_block2_tail() > 0) {
            if (reset_mask) { kxnorq(k_mask, k_mask, k_mask); }
            n_loop(m_block2_tail());
        }
    };

    assert(m_block1_tail() == 0);
    m_loop();
}

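// Sets up the opmask registers: the byte-blend mask for the fast int8 path and
// the N tail mask (kept in k_tail_mask when it has to be re-applied inside the
// N loop).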
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::init_masks() {
    if (!isa_has_masks(brg.isa_impl)) return;

    if (is_fast_vnni_int8()) {
        mov(reg_tmp, 0x8888444422221111);
        kmovq(kblend_mask, reg_tmp);
    }

    if (n_block1_tail() != 0) {
        const auto tail_mask = size_t((1 << n_block1_tail()) - 1);
        const bool has_n_block2_tail = n_block2_tail() > 0;
        mov(reg_tmp, tail_mask);
        if (has_n_block2_tail || nb_n_block2() <= 1) {
            // The mask can be set only once.
            kmovq(k_mask, reg_tmp);
        } else {
            // Need to adjust mask, and set only when needed.
            // So store it temporarily in k_tail_mask.
            kmovq(k_tail_mask, reg_tmp);
        }
    } else if (brg.with_binary) {
        // The post-ops injector seems to use the mask unconditionally,
        // so set a default (all-ones) mask.
        kxnorq(k_mask, k_mask, k_mask);
    }
}

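// Top-level kernel emission: stack frame setup, mask initialization, optional
// permute-table load, parameter loading, the compute loop, and the trailing
// data section holding the int8 permute indices.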
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::generate() {

    preamble();
    sub(rsp, stack_space_needed_);

    init_masks();

    if (is_fast_vnni_int8() && !brg.is_bf16_emu) {
        // load permute indices from data section
        mov(reg_tmp, permute_index_table);
        vmovdqu32(vmm_permute(), ptr[reg_tmp]);
    }

    read_params();
    compute_loop();

    add(rsp, stack_space_needed_);
    postamble();

    if (brg.with_eltwise) postops_injector_->prepare_table();

    if (is_fast_vnni_int8()) {
        align(64);
        L(permute_index_table);
        const uint32_t _idx[]
                = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dd(_idx[i]);
    }
}

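// brdgmm_kernel_t is a thin owner of the generated kernel: it forwards
// create_kernel() and the call operator to jit_brdgmm_kernel_base_t.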
template <cpu_isa_t isa, typename Wmm>
brdgmm_kernel_t<isa, Wmm>::brdgmm_kernel_t(const brgemm_t abrd) {
    brgemm_kernel_ = new jit_brdgmm_kernel_base_t<isa, Wmm>(abrd);
}

template <cpu_isa_t isa, typename Wmm>
status_t brdgmm_kernel_t<isa, Wmm>::create_kernel() {
    return brgemm_kernel_->create_kernel();
}

template <cpu_isa_t isa, typename Wmm>
void brdgmm_kernel_t<isa, Wmm>::operator()(
        brgemm_kernel_params_t *params) const {
    (*brgemm_kernel_)(params);
}

template <cpu_isa_t isa, typename Wmm>
brdgmm_kernel_t<isa, Wmm>::~brdgmm_kernel_t() {
    delete brgemm_kernel_;
}

template struct brdgmm_kernel_t<avx512_core_fp16, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core_bf16, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core_vnni, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx2, Xbyak::Ymm>;
template struct brdgmm_kernel_t<avx2_vnni_2, Xbyak::Ymm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl