/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_BRGEMM_POST_OPS_HPP
#define CPU_X64_JIT_BRGEMM_POST_OPS_HPP

#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/cpu_engine.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_brgemm_primitive_conf.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

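// Kernel call arguments for the diff-bias reduction kernel below. The flags
// field carries FLAG_REDUCE_FIRST / FLAG_REDUCE_LAST, which select whether the
// accumulator is zero-initialized and whether the final (converted) diff_bias
// is stored instead of the intermediate f32 accumulation buffer.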
struct brgemm_kernel_diff_bias_t {
    brgemm_kernel_diff_bias_t()
        : ptr_diff_dst(nullptr)
        , ptr_diff_bias_acc(nullptr)
        , ptr_diff_bias(nullptr)
        , flags(0) {}

    void *ptr_diff_dst;
    void *ptr_diff_bias_acc;
    void *ptr_diff_bias;
    int flags;
};

#define GET_OFF(field) offsetof(brgemm_kernel_diff_bias_t, field)

struct jit_brgemm_kernel_diff_bias_t : public jit_generator {
    jit_brgemm_kernel_diff_bias_t(
            const jit_brgemm_primitive_conf_t &ajbgp, const brgemm_t &abrg)
        : jit_generator(jit_name())
        , brg_(abrg)
        , ddst_dt_(ajbgp.dst_dt)
        , bia_dt_(ajbgp.bia_dt)
        , acc_dt_(ajbgp.acc_dt)
        , bia_typesize_(types::data_type_size(bia_dt_))
        , acc_typesize_(types::data_type_size(acc_dt_)) {

        ddst_dt_ = (ajbgp.isa == avx512_core_fp16 && ajbgp.use_buffer_b)
                ? data_type::f32
                : ajbgp.dst_dt;
        ddst_typesize_ = types::data_type_size(ddst_dt_);
        mult_ = data_type_vnni_granularity(ddst_dt_);
    }

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_diff_bias_t)

private:
    brgemm_t brg_;
    data_type_t ddst_dt_;
    data_type_t bia_dt_;
    data_type_t acc_dt_;

    int ddst_typesize_;
    int bia_typesize_;
    int acc_typesize_;
    int mult_;

    using reg64_t = const Xbyak::Reg64;
    // Register decomposition
    const reg64_t param1 = abi_param1;
    const reg64_t reg_ddst = r15;
    const reg64_t reg_bias = r14;
    const reg64_t reg_bias_acc = r13;
    const reg64_t aux_reg_ddst = r12;
    const reg64_t reg_k_iter = r11;
    const reg64_t reg_flag = r10;

    Xbyak::Opmask k_full_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3);
    Xbyak::Opmask k_f16_perm_mask = Xbyak::Opmask(4);
    Xbyak::Zmm vreg_unit = Xbyak::Zmm(31);
    Xbyak::Zmm vreg_perm = Xbyak::Zmm(30);

    const int n_max_regs_ = 4;

    const Xbyak::Zmm zmm_mask(const Xbyak::Zmm zmm_in, bool mask_flag,
            bool store, Xbyak::Opmask ktail_mask) {
        return mask_flag
                ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
                : zmm_in;
    }

    Xbyak::Zmm get_bias_reg(int n) const { return Xbyak::Zmm(n); }
    Xbyak::Ymm get_bias_reg_lower(int n) const { return Xbyak::Ymm(n); }
    Xbyak::Zmm get_ddst_reg(int n) const { return Xbyak::Zmm(n + n_max_regs_); }

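    // Accumulate one ld_block-wide column of diff_dst into the bias
    // accumulator register. bf16 uses vdpbf16ps against a vector of 1.0f
    // (vreg_unit), reducing a pair of adjacent K elements per instruction;
    // f16 has no VNNI form, so two permuted half-vectors are converted to f32
    // and added instead.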
    void accumulate_bias(int idx, bool mask_flag) {
        auto vddst = get_ddst_reg(idx);
        auto vddst_load = zmm_mask(vddst, mask_flag, false, k_tail_mask);
        auto vbias = get_bias_reg(idx);
        if (ddst_dt_ == data_type::f16) {
            // As we do not have fp16_vnni, we add twice to accumulate
            // adjacent elements.
            for (int i = 0; i < 2; ++i) {
                auto addr = ptr[aux_reg_ddst
                        + ddst_typesize_ * mult_ * idx * brg_.ld_block + i * 2];
                vmovups(vddst_load, addr);
                vpermw(vddst | k_f16_perm_mask | T_z, vreg_perm, vddst);
                vcvtph2psx(vddst, Xbyak::Ymm(vddst.getIdx()));
                vaddps(vbias, vbias, vddst);
            }
        } else {
            auto addr = ptr[aux_reg_ddst
                    + ddst_typesize_ * mult_ * idx * brg_.ld_block];
            vmovups(vddst_load, addr);
            if (ddst_dt_ == data_type::bf16)
                vdpbf16ps(vbias, vreg_unit, vddst);
            else
                vaddps(vbias, vbias, vddst);
        }
    }

    void store(int idx, bool mask_flag) {
        auto addr = ptr[reg_bias + bia_typesize_ * idx * brg_.ld_block];
        auto vbias = get_bias_reg(idx);
        auto vbias_lower = get_bias_reg_lower(idx);
        switch (bia_dt_) {
            case data_type::bf16:
                vcvtneps2bf16(vbias_lower, vbias);
                if (mask_flag) {
                    vmovdqu16(addr,
                            zmm_mask(vbias, mask_flag, true, k_tail_mask));
                } else {
                    vmovups(addr, vbias_lower);
                }
                break;
            case data_type::f16:
                vcvtps2ph(vbias_lower, vbias, 0x4);
                if (mask_flag) {
                    vmovdqu16(addr,
                            zmm_mask(vbias, mask_flag, true, k_tail_mask));
                } else {
                    vmovups(addr, vbias_lower);
                }
                break;
            case data_type::f32:
                vmovups(addr,
                        zmm_mask(get_bias_reg(idx), mask_flag, true,
                                k_tail_mask));
                break;
            default: assert(!"Unsupported bias data type");
        }
    }

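    // Process a group of up to n_max_regs_ ld_block-wide columns: load the
    // running f32 accumulator (or zero it when FLAG_REDUCE_FIRST is set), walk
    // the reduce dimension accumulating diff_dst rows, then either spill the
    // accumulator back or, when FLAG_REDUCE_LAST is set, convert and store the
    // final diff_bias values.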
    void loop_by_N(int n_loop, int nb_tail) {

        mov(aux_reg_ddst, reg_ddst);

        int n_iters = n_loop;
        if (nb_tail > 0) n_iters--;
        Xbyak::Label k_loop, init_zero, init_done;
        int n_ = 0;

        test(reg_flag, FLAG_REDUCE_FIRST);
        jnz(init_zero, T_NEAR); // FLAG_REDUCE_FIRST is set

        for (; n_ < n_iters; n_++) {
            auto vbias = get_bias_reg(n_);
            auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block];
            vmovups(vbias, addr);
        }
        if (nb_tail > 0) {
            auto vbias = zmm_mask(get_bias_reg(n_), true, false, k_tail_mask);
            auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block];
            vmovups(vbias, addr);
        }
        jmp(init_done, T_NEAR);
        L(init_zero);

        for (int n_ = 0; n_ < n_loop; n_++) {
            vxorpd(get_bias_reg(n_), get_bias_reg(n_), get_bias_reg(n_));
        }
        L(init_done);

        mov(reg_k_iter, utils::div_up(brg_.reduce_dim, mult_));
        L(k_loop);
        {
            int n_ = 0;
            for (; n_ < n_iters; n_++)
                accumulate_bias(n_, false);

            if (nb_tail > 0) accumulate_bias(n_, true);

            add(aux_reg_ddst, ddst_typesize_ * mult_ * brg_.LDB);

            sub(reg_k_iter, 1);
            jnz(k_loop, T_NEAR);
        }

        Xbyak::Label store_final, store_done;
        test(reg_flag, FLAG_REDUCE_LAST);
        jnz(store_final, T_NEAR); // FLAG_REDUCE_LAST is set

        n_ = 0;
        for (; n_ < n_iters; n_++) {
            auto vbias = get_bias_reg(n_);
            auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block];
            vmovups(addr, vbias);
        }
        if (nb_tail > 0) {
            auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block];
            auto vbias = zmm_mask(get_bias_reg(n_), true, true, k_tail_mask);
            vmovups(addr, vbias);
        }
        jmp(store_done, T_NEAR);

        L(store_final);
        n_ = 0;

        for (; n_ < n_iters; n_++)
            store(n_, false);

        if (nb_tail > 0) store(n_, true);

        L(store_done);
    }

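    // Kernel entry point: set up the full/tail opmasks, broadcast the bf16 1.0
    // constant or load the f16 permutation table when needed, then iterate
    // over the load dimension in groups of n_max_regs_ blocks via loop_by_N().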
    void generate() override {
        preamble();

        int nb = utils::div_up(brg_.load_dim, brg_.ld_block);
        int nb_tail = brg_.load_dim % brg_.ld_block;

        int n_loop = nb / n_max_regs_;
        int n_loop_tail = nb % n_max_regs_;
        if (n_loop_tail == 0 && nb_tail > 0) {
            n_loop--;
            n_loop_tail = n_max_regs_;
        }

        const auto full_mask = size_t {0xffffffffffffffff};
        const auto tail_mask = size_t((1 << nb_tail) - 1);
        reg64_t reg_mask = rax;

        mov(reg_mask, full_mask);
        kmovq(k_full_mask, reg_mask);
        mov(reg_mask, tail_mask);
        kmovq(k_tail_mask, reg_mask);

        if (ddst_dt_ == data_type::bf16) {
            auto reg_unit_val = reg_mask.cvt16();
            mov(reg_unit_val, 0x3f80); // bf16 value of 1.
            vpbroadcastw(vreg_unit, reg_unit_val);
        }

        Xbyak::Label f16_perm_table;
        if (ddst_dt_ == data_type::f16) {
            const auto half_mask = size_t((1 << 16) - 1);
            mov(reg_mask, half_mask);
            kmovq(k_f16_perm_mask, reg_mask);

            mov(reg_mask, f16_perm_table);
            vmovups(vreg_perm | k_f16_perm_mask | T_z, ptr[reg_mask]);
        }

        mov(reg_ddst, ptr[param1 + GET_OFF(ptr_diff_dst)]);
        mov(reg_bias_acc, ptr[param1 + GET_OFF(ptr_diff_bias_acc)]);
        mov(reg_bias, ptr[param1 + GET_OFF(ptr_diff_bias)]);
        mov(reg_flag, ptr[param1 + GET_OFF(flags)]);

        for (int nb_ = 0; nb_ < n_loop; nb_++) {
            loop_by_N(n_max_regs_, 0);

            add(reg_ddst, ddst_typesize_ * mult_ * n_max_regs_ * brg_.ld_block);
            add(reg_bias, bia_typesize_ * n_max_regs_ * brg_.ld_block);
            add(reg_bias_acc, acc_typesize_ * n_max_regs_ * brg_.ld_block);
        }

        if (n_loop_tail > 0) loop_by_N(n_loop_tail, nb_tail);
        postamble();

        if (ddst_dt_ == data_type::f16) {
            // convert interleaved vnni data with holes to packed.
            const uint16_t f16_prm_array[16] = {
                    0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30};
            align(64);
            L(f16_perm_table);
            for (int i = 0; i < 16; ++i)
                dw(f16_prm_array[i]);
        }
    }
};

#undef GET_OFF

#define GET_OFF(field) offsetof(brgemm_kernel_post_ops_t, field)

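// Kernel call arguments for the standalone post-ops kernel: the accumulator
// (input) and destination (output) buffers, bias and scales, the binary
// post-op rhs pointers, and zero-point / s8s8 compensation data.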
struct brgemm_kernel_post_ops_t {
    void *ptr_in;
    void *ptr_out;
    void *ptr_bias;
    void *ptr_scales;
    const void *ptr_binary_post_ops_rhs;
    size_t apply_comp = 0;
    int32_t a_comp_val = 1;
    int32_t *a_zp_compensation;
    int32_t *c_zp_values;
    int32_t *s8s8_compensation;
    const void *dst_orig;
};

template <cpu_isa_t isa>
struct jit_brgemm_kernel_post_ops : public jit_generator {

    jit_brgemm_kernel_post_ops(const jit_brgemm_conv_conf_t &ajcp,
            const brgemm_t &abrg, const primitive_attr_t &aattr)
        : jit_generator(jit_name())
        , brg(abrg)
        , jcp(ajcp)
        , attr(aattr)
        , postops_injector_(nullptr)
        , with_binary_non_scalar_bcast_(brg.with_binary
                  && binary_injector::
                          any_binary_postop_rhs_non_scalar_broadcast(
                                  brg.attr->post_ops_,
                                  memory_desc_wrapper(brg.dst_md))) {

        if ((jcp.with_sum && brg.beta != 0)
                || ((jcp.with_binary || jcp.with_eltwise) && brg.alpha != 0)) {
            static constexpr bool preserve_gpr = true;
            static constexpr bool preserve_vmm = true;
            static constexpr bool use_exact_tail_scalar_bcast = false;

            const binary_injector::rhs_arg_static_params_t rhs_sp {
                    static_cast<size_t>(vmm_tmp(4).getIdx()), this->r14,
                    this->r15, this->r13, preserve_gpr, preserve_vmm,
                    GET_OFF(ptr_binary_post_ops_rhs), GET_OFF(dst_orig),
                    memory_desc_wrapper(brg.dst_md),
                    static_cast<size_t>(brg.load_dim % brg.ld_block),
                    k_tail_mask, use_exact_tail_scalar_bcast};
            const binary_injector::static_params_t bsp {this->param1, rhs_sp};

            const bool save_state = (brg.alpha != 0) && jcp.with_eltwise;
            const auto &reserved_eltwise_gpr = rax;
            const auto reserved_eltwise_maskr = Xbyak::Opmask(1);

            const eltwise_injector::static_params_t esp {
                    save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

            postops_injector_ = utils::make_unique<
                    injector::jit_uni_postops_injector_t<po_isa_t>>(
                    this, attr.post_ops_, bsp, esp);
        }
        if (brg.is_bf16_emu)
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                    bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4);

        const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
        // per_oc: conv: 1 << 0, 1 << 1 (with groups)
        // per_oc: ip: 1 << 0
        is_oc_scale_ = utils::one_of(wei_scales.mask_, 1 << 0, 1 << 1);

        LDD_ = brg.LDD;
        inp_dt_ = brg.dt_c;
        out_dt_ = brg.dt_d;
        bia_dt_ = jcp.bia_dt;
        inp_typesize_ = types::data_type_size(inp_dt_);
        out_typesize_ = types::data_type_size(out_dt_);
        bia_typesize_ = (jcp.with_bias) ? types::data_type_size(bia_dt_) : 0;
    }

    ~jit_brgemm_kernel_post_ops() = default;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_post_ops)

    brgemm_t brg;
    jit_brgemm_conv_conf_t jcp;
    const primitive_attr_t &attr;

private:
    int LDD_;

    data_type_t inp_dt_;
    data_type_t out_dt_;
    data_type_t bia_dt_;
    static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core, avx2,
            avx2, avx2_vnni_2, avx2_vnni_2, avx512_core_fp16, avx512_core_fp16);
    std::unique_ptr<injector::jit_uni_postops_injector_t<po_isa_t>>
            postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    const bool with_binary_non_scalar_bcast_;

    int inp_typesize_;
    int out_typesize_;
    int bia_typesize_;

    int is_oc_scale_;
    constexpr static int max_vregs_ = cpu_isa_traits<po_isa_t>::n_vregs;

    using reg64_t = const Xbyak::Reg64;
    using Vmm =
            typename utils::conditional<utils::one_of(isa, avx2, avx2_vnni_2),
                    Xbyak::Ymm, Xbyak::Zmm>::type;
    using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t;

    // Register decomposition
    const reg64_t param1 = abi_param1;
    const reg64_t reg_in = r15;
    const reg64_t reg_out = r14;
    const reg64_t aux_reg_in = r13;
    const reg64_t aux_reg_out = r12;

    const reg64_t reg_bias = r11;
    const reg64_t aux_reg_bias = r10;

    const reg64_t reg_scales = r9;
    const reg64_t aux_reg_scales = r8;

    const reg64_t reg_ptr_sum_scale = rdx;
    const reg64_t reg_ptr_sum_zp = rsi;

    const reg64_t reg_zp_c_values = rbx;
    const reg64_t aux_reg_zp_c_values = rbx;
    const reg64_t reg_zp_a_comp = rbx;
    const reg64_t aux_reg_zp_a_comp = rbx;
    const reg64_t reg_s8s8_comp = rbx;
    const reg64_t aux_reg_s8s8_comp = rbx;
    const reg64_t reg_zp_a_val = rbx;
    const reg64_t reg_apply_comp = rbx;

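    // All of the rbx-based "registers" above alias the same GPR, so their
    // values are spilled to the stack slots below within the frame reserved in
    // generate().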
    constexpr static int reg_zp_c_values_offs_ = 0;
    constexpr static int aux_reg_zp_c_values_offs_ = 8;
    constexpr static int reg_zp_a_comp_offs_ = 16;
    constexpr static int aux_reg_zp_a_comp_offs_ = 24;
    constexpr static int reg_s8s8_comp_offs_ = 32;
    constexpr static int aux_reg_s8s8_comp_offs_ = 40;
    constexpr static int reg_zp_a_val_offs_ = 48;
    constexpr static int reg_apply_comp_offs_ = 56;
    constexpr static int stack_space_needed_ = 64;

    /* bf16 emulation */
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(24);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(25);
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(26);
    reg64_t bf16_emu_scratch = rax;

    Xbyak::Opmask k_full_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3);

    const int n_block2_ = 4;

    Vmm vmm_tmp(int i) const { return Vmm(max_vregs_ - 1 - i); }

    int zp_c_values_offset(int n, bool is_tail = false) const noexcept {
        if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
            return (is_tail) ? sizeof(int32_t) * brg.ldb_tail
                             : sizeof(int32_t) * n * brg.ld_block;
        }

        return 0;
    }
    int zp_comp_a_vpad_offset(int n, int m, bool is_tail = false) const
            noexcept {
        return (is_tail) ? sizeof(int32_t) * (brg.ldb_tail + m * brg.LDB)
                         : sizeof(int32_t) * (n * brg.ld_block + m * brg.LDB);
    }
    int mb_zp_comp_a_offset(int m_block) const noexcept {
        return sizeof(int32_t) * m_block * brg.LDB;
    }
    int compensation_vpad_offset(int n, int m, bool is_tail = false) const
            noexcept {
        return (is_tail) ? sizeof(int32_t) * (brg.ldb_tail + m * brg.LDB)
                         : sizeof(int32_t) * (n * brg.ld_block + m * brg.LDB);
    }
    int mb_compensation_offset(int m_block) const noexcept {
        return sizeof(int32_t) * m_block * brg.LDB;
    }

    template <typename T>
    const T maybe_mask(const T vmm_in, bool mask_flag, bool store,
            Xbyak::Opmask ktail_mask) {
        assert(IMPLICATION(mask_flag, isa_has_masks(isa)));
        return mask_flag
                ? (store ? vmm_in | ktail_mask : vmm_in | ktail_mask | T_z)
                : vmm_in;
    }

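    // Load a full vector (or a masked tail) from `op` and convert it to f32 in
    // vmm_in. When the ISA has no opmasks, fall back to the emulated
    // load_data() helper. skip_cvt2ps keeps integer data in s32 form.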
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            int tail_size, bool store, Xbyak::Opmask ktail_mask,
            bool skip_cvt2ps = false) {
        const bool is_tail = op.isMEM()
                && tail_size != vreg_traits<Vmm>::vlen / sizeof(float)
                // The current kernel is written such that tail_size = 0 implies
                // no tail and full vmm must be processed.
                && tail_size > 0;

        if (IMPLICATION(is_tail, isa_has_masks(isa))) {
            const Vmm vmm = maybe_mask(vmm_in, is_tail, store, ktail_mask);
            switch (type_in) {
                case data_type::f32:
                case data_type::s32: vmovups(vmm, op); break;
                case data_type::s8: vpmovsxbd(vmm, op); break;
                case data_type::u8: vpmovzxbd(vmm, op); break;
                case data_type::bf16:
                    vpmovzxwd(vmm, op);
                    vpslld(vmm, vmm, 16);
                    break;
                case data_type::f16: vcvtph2ps(vmm, op); break;
                default: assert(!"unsupported data type");
            }
        } else {
            load_data(type_in, vmm_in, op.getAddress(), tail_size);
        }
        if (!skip_cvt2ps && types::is_integral_dt(type_in))
            vcvtdq2ps(vmm_in, vmm_in);
    }

    Vmm vector(int m, int n, int n_block) { return Vmm(m * n_block + n); }

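    // Delegate sum/eltwise/binary post-ops to the injector. The sum post-op is
    // registered as a lambda so it can reload the previously stored
    // destination, apply its zero-point and scale, and accumulate into the
    // block registers.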
    void inject_attr_postops(int m_block, int n_block, int tail = 0) {
        const auto &p = attr.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const auto k_mask = tail == 0 ? k_full_mask : k_tail_mask;
        const auto sum_dt = p.get_sum_dt(out_dt_);

        const auto sum_injector = [&] {
            const float *p_sum_scale = &p.entry_[sum_idx].sum.scale;
            const int32_t *p_sum_zp = &p.entry_[sum_idx].sum.zero_point;
            if (*p_sum_scale != 1.f)
                mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
            auto vmm_sum_zp = vmm_tmp(1);
            if (*p_sum_zp != 0) {
                mov(reg_ptr_sum_zp, (size_t)p_sum_zp);
                vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
            }

            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto vmm = vector(m, n, n_block);
                const auto addr = ptr[aux_reg_out
                        + out_typesize_ * (m * LDD_ + n * brg.ld_block)];

                const auto vmm_prev_dst = vmm_tmp(0);
                cvt2ps(sum_dt, vmm_prev_dst, addr, tail, false, k_mask);
                if (*p_sum_zp != 0) vsubps(vmm_prev_dst, vmm_sum_zp);
                if (*p_sum_scale == 1.f)
                    vaddps(vmm, vmm_prev_dst);
                else {
                    if (is_superset(isa, avx512_core)) {
                        vfmadd231ps(
                                vmm, vmm_prev_dst, ptr_b[reg_ptr_sum_scale]);
                    } else {
                        auto vmm_sum_scale = vmm_tmp(1);
                        vpbroadcastd(vmm_sum_scale, ptr[reg_ptr_sum_scale]);
                        vfmadd231ps(vmm, vmm_prev_dst, vmm_sum_scale);
                    }
                }
            }
        };

        if (jcp.with_sum && brg.beta != 0) {
            postops_injector_->set_lambda_injector(
                    primitive_kind::sum, sum_injector);
        }

        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

        if (with_binary_non_scalar_bcast_) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto vmm_idx = vector(m, n, n_block).getIdx();
                const size_t aux_output_offset
                        = out_typesize_ * (m * LDD_ + n * brg.ld_block);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, aux_reg_out);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
                if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }

        postops_injector_->compute_vector_range(
                0, m_block * n_block, rhs_arg_params);
    }

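    // Add the src zero-point and s8s8 compensation terms (still in s32) to
    // every accumulator register of the current block.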
    void apply_comp(int m_block, int n_block, int tail = 0) {
        auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;

        if (brg.alpha != 0 && brg.zp_type_a != brgemm_broadcast_t::none) {
            auto vmm_zp_a_val = vmm_tmp(1);
            mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
            vpbroadcastd(vmm_zp_a_val, reg_zp_a_val.cvt32());

            mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]);
            for (int n = 0; n < n_block; n++) {
                auto vmm_zp_comp_a = vmm_tmp(0);
                auto zp_comp_a_addr = EVEX_compress_addr(aux_reg_zp_a_comp,
                        sizeof(int32_t) * (n * brg.ld_block));
                vmm_zp_comp_a
                        = maybe_mask(vmm_zp_comp_a, tail > 0, false, k_mask);
                vmovups(vmm_zp_comp_a, zp_comp_a_addr);
                vpmulld(vmm_zp_comp_a, vmm_zp_a_val, zp_comp_a_addr);

                for (int m = 0; m < m_block; m++) {
                    auto vmm = vector(m, n, n_block);
                    vpaddd(vmm, vmm, vmm_zp_comp_a);
                }
            }
        }

        if (brg.alpha != 0 && brg.req_s8s8_compensation) {
            mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
            for (int n = 0; n < n_block; n++) {
                auto vmm_comp = vmm_tmp(0);
                auto comp_addr = EVEX_compress_addr(aux_reg_s8s8_comp,
                        sizeof(int32_t) * (n * brg.ld_block));
                vmm_comp = maybe_mask(vmm_comp, tail > 0, false, k_mask);
                vmovups(vmm_comp, comp_addr);

                for (int m = 0; m < m_block; m++) {
                    auto vmm = vector(m, n, n_block);
                    vpaddd(vmm, vmm, vmm_comp);
                }
            }
        }
    }

    void maybe_apply_comp(int m_block, int n_block, int tail = 0) {
        Xbyak::Label label_apply_without_comp;
        mov(reg_apply_comp, ptr[rsp + reg_apply_comp_offs_]);
        cmp(reg_apply_comp, 0);
        je(label_apply_without_comp, T_NEAR);
        apply_comp(m_block, n_block, tail);
        L_aligned(label_apply_without_comp);

        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            vcvtdq2ps(vector(m, n, n_block), vector(m, n, n_block));
        }
    }

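    // Post-ops pipeline for one m_block x n_block register tile: load and
    // convert the accumulator, apply compensation, bias and scales, run the
    // attribute post-ops, add the dst zero-point, then saturate/convert and
    // store to the destination.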
    void apply_post_ops(int m_block, int n_block, int tail = 0) {
        const auto vector = [=](int m, int n) { return Vmm(m * n_block + n); };
        auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;
        const auto &p = attr.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const auto req_comp = brg.is_int8 && brg.alpha != 0
                && (brg.req_s8s8_compensation
                        || brg.zp_type_a != brgemm_broadcast_t::none);

        // brg.alpha == 0 means no read from input, no bias, no eltwise - just
        // initialize registers by zero at the beginning of kernel
        // brg.beta == 0 means no sum - just registers write to output
        // req_comp == true -> convert accumulated values to f32 after applying
        // compensation to avoid the loss of accuracy when converting s32 to f32
        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            if (brg.alpha == 0) {
                if (sum_idx != -1 && brg.beta != 0) {
                    // if sum then have to init vmm each time
                    uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
                }
            } else {
                auto inp_addr = ptr[aux_reg_in
                        + inp_typesize_ * (m * brg.LDC + n * brg.ld_block)];
                cvt2ps(inp_dt_, vector(m, n), inp_addr, tail, false, k_mask,
                        req_comp);
            }
        }

        if (req_comp) maybe_apply_comp(m_block, n_block, tail);

        if (brg.alpha != 0 && jcp.with_bias) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                auto vmm_bias = vmm_tmp(0);
                auto bias_addr = ptr[aux_reg_bias
                        + bia_typesize_ * (n * brg.ld_block)];

                cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
                vaddps(vector(m, n), vmm_bias);
            }
        }

        if (brg.alpha != 0) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto addr = ptr[aux_reg_scales
                        + is_oc_scale_ * sizeof(float) * (n * brg.ld_block)];
                auto vmm = vector(m, n);
                if (IMPLICATION(tail > 0, isa_has_masks(isa))) {
                    vmm = maybe_mask(vector(m, n), tail > 0, false, k_mask);
                    vmulps(vmm, vmm, addr);
                } else {
                    auto vmm_scales = vmm_tmp(0);
                    load_data(data_type::f32, vmm_scales, addr, tail);
                    vmulps(vmm, vmm, vmm_scales);
                }
            }
        }

        if (postops_injector_) inject_attr_postops(m_block, n_block, tail);

        if (brg.alpha != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
            mov(aux_reg_zp_c_values, ptr[rsp + aux_reg_zp_c_values_offs_]);
            auto vmm_zp_c = vmm_tmp(0);
            if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
                vcvtdq2ps(vmm_zp_c,
                        EVEX_compress_addr(aux_reg_zp_c_values, 0, true));
            }
            for (int n = 0; n < n_block; n++) {
                if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                    int zp_c_off = zp_c_values_offset(n);
                    auto zp_c_addr
                            = EVEX_compress_addr(aux_reg_zp_c_values, zp_c_off);
                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, tail, false,
                            k_mask);
                }
                for (int m = 0; m < m_block; m++)
                    vaddps(vector(m, n), vmm_zp_c);
            }
        }

        const bool dt_requires_saturation = utils::one_of(
                brg.dt_d, data_type::u8, data_type::s8, data_type::s32);

        const reg64_t reg_tmp_gpr = rax;
        auto vmm_lbound = vmm_tmp(0);
        auto vmm_ubound = vmm_tmp(1);
        if (dt_requires_saturation) {
            init_saturate_f32(vmm_lbound, vmm_ubound, reg_tmp_gpr,
                    data_type::f32, brg.dt_d);
        }

        if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            // In case of a tail, stores are unconditionally masked, regardless
            // of `n`, implying n_block must be equal to `1`.
            assert(IMPLICATION(tail > 0, n_block == 1));
            auto vmm = vector(m, n);
            const size_t offset = out_typesize_ * (m * LDD_ + n * brg.ld_block);
            const auto addr = ptr[aux_reg_out + offset];

            if (utils::one_of(out_dt_, data_type::bf16, data_type::f16)) {
                Vmm_lower_t vmm_low = Vmm_lower_t(vmm.getIdx());
                if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
                    if (brg.is_f16)
                        vcvtps2ph(vmm_low, vmm, _op_mxcsr);
                    else if (brg.is_bf16_emu)
                        bf16_emu_->vcvtneps2bf16(vmm_low, vmm);
                    else
                        vcvtneps2bf16(vmm_low, vmm);
                }
                vmm_low = maybe_mask(vmm_low, tail > 0, true, k_mask);
                vmovdqu16(addr, vmm_low);
            } else {
                if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
                    saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                    if (out_dt_ != data_type::f32) vcvtps2dq(vmm, vmm);
                }
                if (IMPLICATION(tail > 0, isa_has_masks(isa))) {
                    vmm = maybe_mask(vmm, tail > 0, true, k_mask);
                    switch (out_dt_) {
                        case data_type::f32:
                        case data_type::s32: vmovups(addr, vmm); break;
                        case data_type::s8: vpmovsdb(addr, vmm); break;
                        case data_type::u8: vpmovusdb(addr, vmm); break;
                        default: assert(!"unknown dst_dt");
                    }
                } else {
                    store_data(out_dt_, vmm, aux_reg_out, offset, tail);
                }
            }
        }
    }

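    // Walk the load dimension for one m_block of rows: full groups of
    // n_block2_ blocks first, then a partial group, then the ld_block tail,
    // advancing the input/output/bias/scale/compensation pointers as we go.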
    void loop_by_N(int m_block, int nb2, int nb2_tail, int nb_tail) {

        if (brg.alpha) {
            mov(aux_reg_in, reg_in);
            if (jcp.with_bias) mov(aux_reg_bias, reg_bias);
            if (brg.zp_type_c != brgemm_broadcast_t::none) {
                mov(aux_reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]);
                mov(ptr[rsp + aux_reg_zp_c_values_offs_], aux_reg_zp_c_values);
            }
            if (brg.zp_type_a != brgemm_broadcast_t::none) {
                mov(aux_reg_zp_a_comp, ptr[rsp + reg_zp_a_comp_offs_]);
                mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp);
            }
            if (brg.req_s8s8_compensation) {
                mov(aux_reg_s8s8_comp, ptr[rsp + reg_s8s8_comp_offs_]);
                mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp);
            }
            mov(aux_reg_scales, reg_scales);
        }
        mov(aux_reg_out, reg_out);

        for (int n_loop_ = 0; n_loop_ < nb2; n_loop_++) {
            apply_post_ops(m_block, n_block2_);

            const auto oc_l_offset = n_block2_ * brg.ld_block;

            add(aux_reg_out, out_typesize_ * oc_l_offset);
            if (brg.alpha != 0) {
                add(aux_reg_in, inp_typesize_ * oc_l_offset);

                if (jcp.with_bias)
                    add(aux_reg_bias, bia_typesize_ * oc_l_offset);
                if (brg.zp_type_c != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_c_values,
                            ptr[rsp + aux_reg_zp_c_values_offs_]);
                    add(aux_reg_zp_c_values, zp_c_values_offset(n_block2_));
                    mov(ptr[rsp + aux_reg_zp_c_values_offs_],
                            aux_reg_zp_c_values);
                }
                if (brg.zp_type_a != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]);
                    add(aux_reg_zp_a_comp, sizeof(int32_t) * oc_l_offset);
                    mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp);
                }
                if (brg.req_s8s8_compensation) {
                    mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
                    add(aux_reg_s8s8_comp, sizeof(int32_t) * oc_l_offset);
                    mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp);
                }

                add(aux_reg_scales, is_oc_scale_ * sizeof(float) * oc_l_offset);
            }
        }
        if (nb2_tail > 0) {
            apply_post_ops(m_block, nb2_tail);
            const auto oc_l_offset = nb2_tail * brg.ld_block;

            add(aux_reg_out, out_typesize_ * oc_l_offset);
            if (brg.alpha != 0) {
                add(aux_reg_in, inp_typesize_ * oc_l_offset);
                if (jcp.with_bias)
                    add(aux_reg_bias, bia_typesize_ * oc_l_offset);
                if (brg.zp_type_c != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_c_values,
                            ptr[rsp + aux_reg_zp_c_values_offs_]);
                    add(aux_reg_zp_c_values, zp_c_values_offset(nb2_tail));
                    mov(ptr[rsp + aux_reg_zp_c_values_offs_],
                            aux_reg_zp_c_values);
                }
                if (brg.zp_type_a != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]);
                    add(aux_reg_zp_a_comp, sizeof(int32_t) * oc_l_offset);
                    mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp);
                }
                if (brg.req_s8s8_compensation) {
                    mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
                    add(aux_reg_s8s8_comp, sizeof(int32_t) * oc_l_offset);
                    mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp);
                }

                add(aux_reg_scales, is_oc_scale_ * sizeof(float) * oc_l_offset);
            }
        }
        if (nb_tail > 0) {
            apply_post_ops(m_block, 1, nb_tail);

            if (brg.alpha != 0) {
                add(aux_reg_in, inp_typesize_ * (nb_tail));
                if (jcp.with_bias) add(aux_reg_bias, bia_typesize_ * (nb_tail));
                if (brg.zp_type_c != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_c_values,
                            ptr[rsp + aux_reg_zp_c_values_offs_]);
                    add(aux_reg_zp_c_values, zp_c_values_offset(1, nb_tail));
                    mov(ptr[rsp + aux_reg_zp_c_values_offs_],
                            aux_reg_zp_c_values);
                }
                if (brg.zp_type_a != brgemm_broadcast_t::none) {
                    mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]);
                    add(aux_reg_zp_a_comp, sizeof(int32_t) * nb_tail);
                    mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp);
                }
                if (brg.req_s8s8_compensation) {
                    mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
                    add(aux_reg_s8s8_comp, sizeof(int32_t) * nb_tail);
                    mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp);
                }
                add(aux_reg_scales, is_oc_scale_ * sizeof(float) * (nb_tail));
            }
            add(aux_reg_out, out_typesize_ * (nb_tail));
        }
    }

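    // Kernel entry point: reserve the stack slots for the rbx-aliased
    // pointers, set up the opmasks where available, load the call arguments,
    // and iterate over the bcast (M) dimension in m_block-row chunks.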
    void generate() override {
        preamble();

        sub(rsp, stack_space_needed_);

        int nb = brg.load_dim / brg.ld_block;
        int nb_tail = brg.load_dim % brg.ld_block;

        int nb2 = nb / n_block2_;
        int nb2_tail = nb % n_block2_;
        int n_block = (nb2 == 0) ? nstl::max(1, nb2_tail) : n_block2_;

        int m_max_regs = (brg.is_bf16_emu ? 24 : max_vregs_ - 4) / n_block;
        int m_block = nstl::min(brg.bcast_dim, m_max_regs);

        int mb = brg.bcast_dim / m_block;
        int mb_tail = brg.bcast_dim % m_block;

        if (isa_has_masks(isa)) {
            const auto full_mask = size_t {0xffffffffffffffff};
            const auto tail_mask = size_t((1 << nb_tail) - 1);

            reg64_t reg_mask = rax;

            mov(reg_mask, full_mask);
            kmovq(k_full_mask, reg_mask);
            mov(reg_mask, tail_mask);
            kmovq(k_tail_mask, reg_mask);
        }

        if (brg.alpha != 0) {
            mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
            mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
            mov(reg_apply_comp, ptr[param1 + GET_OFF(apply_comp)]);
            mov(ptr[rsp + reg_apply_comp_offs_], reg_apply_comp);

            if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]);
            if (brg.zp_type_c != brgemm_broadcast_t::none) {
                mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]);
                mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
            }
            if (brg.zp_type_a != brgemm_broadcast_t::none) {
                mov(reg_zp_a_comp, ptr[param1 + GET_OFF(a_zp_compensation)]);
                mov(ptr[rsp + reg_zp_a_comp_offs_], reg_zp_a_comp);

                mov(reg_zp_a_val, ptr[param1 + GET_OFF(a_comp_val)]);
                mov(ptr[rsp + reg_zp_a_val_offs_], reg_zp_a_val);
            }
            if (brg.req_s8s8_compensation) {
                mov(reg_s8s8_comp, ptr[param1 + GET_OFF(s8s8_compensation)]);
                mov(ptr[rsp + reg_s8s8_comp_offs_], reg_s8s8_comp);
            }
        }
        mov(reg_out, ptr[param1 + GET_OFF(ptr_out)]);

        // brg.alpha == 0 means no read from input, no bias, no eltwise - just
        // initialize registers by zero
        // brg.beta == 0 means no sum - just registers write to output
        if (brg.alpha == 0) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                auto vmm = Vmm(m * n_block + n);
                uni_vpxor(vmm, vmm, vmm);
            }
        }

        for (int mb_ = 0; mb_ < mb; mb_++) {
            loop_by_N(m_block, nb2, nb2_tail, nb_tail);

            if (brg.alpha != 0)
                add(reg_in, inp_typesize_ * (m_block * brg.LDC));
            add(reg_out, out_typesize_ * (m_block * LDD_));
        }
        if (mb_tail > 0) loop_by_N(mb_tail, nb2, nb2_tail, nb_tail);

        add(rsp, stack_space_needed_);

        postamble();

        if (brg.alpha != 0 && jcp.with_eltwise)
            postops_injector_->prepare_table();
    }
};

#undef GET_OFF

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif