1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_POST_OPS_HPP |
18 | #define CPU_X64_JIT_BRGEMM_POST_OPS_HPP |
19 | |
20 | #include <memory> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | |
25 | #include "cpu/cpu_engine.hpp" |
26 | |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
29 | #include "cpu/x64/jit_brgemm_primitive_conf.hpp" |
30 | #include "cpu/x64/jit_generator.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
// Call-time argument block for the diff-bias reduction kernel
// (jit_brgemm_kernel_diff_bias_t). All pointers default to null and the
// reduction-phase flags to 0.
struct brgemm_kernel_diff_bias_t {
    brgemm_kernel_diff_bias_t()
        : ptr_diff_dst(nullptr)
        , ptr_diff_bias_acc(nullptr)
        , ptr_diff_bias(nullptr)
        , flags(0) {} // fixed: dropped stray ';' after the ctor body

    void *ptr_diff_dst; // gradient tensor chunk to reduce
    void *ptr_diff_bias_acc; // intermediate accumulation buffer
    void *ptr_diff_bias; // final diff-bias output
    int flags; // FLAG_REDUCE_FIRST / FLAG_REDUCE_LAST phase bits
};
49 | |
50 | #define GET_OFF(field) offsetof(brgemm_kernel_diff_bias_t, field) |
51 | |
52 | struct jit_brgemm_kernel_diff_bias_t : public jit_generator { |
53 | jit_brgemm_kernel_diff_bias_t( |
54 | const jit_brgemm_primitive_conf_t &ajbgp, const brgemm_t &abrg) |
55 | : jit_generator(jit_name()) |
56 | , brg_(abrg) |
57 | , ddst_dt_(ajbgp.dst_dt) |
58 | , bia_dt_(ajbgp.bia_dt) |
59 | , acc_dt_(ajbgp.acc_dt) |
60 | , bia_typesize_(types::data_type_size(bia_dt_)) |
61 | , acc_typesize_(types::data_type_size(acc_dt_)) { |
62 | |
63 | ddst_dt_ = (ajbgp.isa == avx512_core_fp16 && ajbgp.use_buffer_b) |
64 | ? data_type::f32 |
65 | : ajbgp.dst_dt; |
66 | ddst_typesize_ = types::data_type_size(ddst_dt_); |
67 | mult_ = data_type_vnni_granularity(ddst_dt_); |
68 | } |
69 | |
70 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_diff_bias_t) |
71 | |
72 | private: |
73 | brgemm_t brg_; |
74 | data_type_t ddst_dt_; |
75 | data_type_t bia_dt_; |
76 | data_type_t acc_dt_; |
77 | |
78 | int ddst_typesize_; |
79 | int bia_typesize_; |
80 | int acc_typesize_; |
81 | int mult_; |
82 | |
83 | using reg64_t = const Xbyak::Reg64; |
84 | // Register decomposition |
85 | const reg64_t param1 = abi_param1; |
86 | const reg64_t reg_ddst = r15; |
87 | const reg64_t reg_bias = r14; |
88 | const reg64_t reg_bias_acc = r13; |
89 | const reg64_t aux_reg_ddst = r12; |
90 | const reg64_t reg_k_iter = r11; |
91 | const reg64_t reg_flag = r10; |
92 | |
93 | Xbyak::Opmask k_full_mask = Xbyak::Opmask(2); |
94 | Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3); |
95 | Xbyak::Opmask k_f16_perm_mask = Xbyak::Opmask(4); |
96 | Xbyak::Zmm vreg_unit = Xbyak::Zmm(31); |
97 | Xbyak::Zmm vreg_perm = Xbyak::Zmm(30); |
98 | |
99 | const int n_max_regs_ = 4; |
100 | |
101 | const Xbyak::Zmm zmm_mask(const Xbyak::Zmm zmm_in, bool mask_flag, |
102 | bool store, Xbyak::Opmask ktail_mask) { |
103 | return mask_flag |
104 | ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
105 | : zmm_in; |
106 | } |
107 | |
108 | Xbyak::Zmm get_bias_reg(int n) const { return Xbyak::Zmm(n); } |
109 | Xbyak::Ymm get_bias_reg_lower(int n) const { return Xbyak::Ymm(n); } |
110 | Xbyak::Zmm get_ddst_reg(int n) const { return Xbyak::Zmm(n + n_max_regs_); } |
111 | |
112 | void accumulate_bias(int idx, bool mask_flag) { |
113 | auto vddst = get_ddst_reg(idx); |
114 | auto vddst_load = zmm_mask(vddst, mask_flag, false, k_tail_mask); |
115 | auto vbias = get_bias_reg(idx); |
116 | if (ddst_dt_ == data_type::f16) { |
117 | // As we do not have fp16_vnni, we add twice to accumulate |
118 | // adjacent elements. |
119 | for (int i = 0; i < 2; ++i) { |
120 | auto addr = ptr[aux_reg_ddst |
121 | + ddst_typesize_ * mult_ * idx * brg_.ld_block + i * 2]; |
122 | vmovups(vddst_load, addr); |
123 | vpermw(vddst | k_f16_perm_mask | T_z, vreg_perm, vddst); |
124 | vcvtph2psx(vddst, Xbyak::Ymm(vddst.getIdx())); |
125 | vaddps(vbias, vbias, vddst); |
126 | } |
127 | } else { |
128 | auto addr = ptr[aux_reg_ddst |
129 | + ddst_typesize_ * mult_ * idx * brg_.ld_block]; |
130 | vmovups(vddst_load, addr); |
131 | if (ddst_dt_ == data_type::bf16) |
132 | vdpbf16ps(vbias, vreg_unit, vddst); |
133 | else |
134 | vaddps(vbias, vbias, vddst); |
135 | } |
136 | } |
137 | |
138 | void store(int idx, bool mask_flag) { |
139 | auto addr = ptr[reg_bias + bia_typesize_ * idx * brg_.ld_block]; |
140 | auto vbias = get_bias_reg(idx); |
141 | auto vbias_lower = get_bias_reg_lower(idx); |
142 | switch (bia_dt_) { |
143 | case data_type::bf16: |
144 | vcvtneps2bf16(vbias_lower, vbias); |
145 | if (mask_flag) { |
146 | vmovdqu16(addr, |
147 | zmm_mask(vbias, mask_flag, true, k_tail_mask)); |
148 | } else { |
149 | vmovups(addr, vbias_lower); |
150 | } |
151 | break; |
152 | case data_type::f16: |
153 | vcvtps2ph(vbias_lower, vbias, 0x4); |
154 | if (mask_flag) { |
155 | vmovdqu16(addr, |
156 | zmm_mask(vbias, mask_flag, true, k_tail_mask)); |
157 | } else { |
158 | vmovups(addr, vbias_lower); |
159 | } |
160 | break; |
161 | case data_type::f32: |
162 | vmovups(addr, |
163 | zmm_mask(get_bias_reg(idx), mask_flag, true, |
164 | k_tail_mask)); |
165 | break; |
166 | default: assert("Unsupported bias data type" ); |
167 | } |
168 | } |
169 | |
170 | void loop_by_N(int n_loop, int nb_tail) { |
171 | |
172 | mov(aux_reg_ddst, reg_ddst); |
173 | |
174 | int n_iters = n_loop; |
175 | if (nb_tail > 0) n_iters--; |
176 | Xbyak::Label k_loop, init_zero, init_done; |
177 | int n_ = 0; |
178 | |
179 | test(reg_flag, FLAG_REDUCE_FIRST); |
180 | jnz(init_zero, T_NEAR); // FLAG_REDUCE_FIRST is set |
181 | |
182 | for (; n_ < n_iters; n_++) { |
183 | auto vbias = get_bias_reg(n_); |
184 | auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block]; |
185 | vmovups(vbias, addr); |
186 | } |
187 | if (nb_tail > 0) { |
188 | auto vbias = zmm_mask(get_bias_reg(n_), true, false, k_tail_mask); |
189 | auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block]; |
190 | vmovups(vbias, addr); |
191 | } |
192 | jmp(init_done, T_NEAR); |
193 | L(init_zero); |
194 | |
195 | for (int n_ = 0; n_ < n_loop; n_++) { |
196 | vxorpd(get_bias_reg(n_), get_bias_reg(n_), get_bias_reg(n_)); |
197 | } |
198 | L(init_done); |
199 | |
200 | mov(reg_k_iter, utils::div_up(brg_.reduce_dim, mult_)); |
201 | L(k_loop); |
202 | { |
203 | int n_ = 0; |
204 | for (; n_ < n_iters; n_++) |
205 | accumulate_bias(n_, false); |
206 | |
207 | if (nb_tail > 0) accumulate_bias(n_, true); |
208 | |
209 | add(aux_reg_ddst, ddst_typesize_ * mult_ * brg_.LDB); |
210 | |
211 | sub(reg_k_iter, 1); |
212 | jnz(k_loop, T_NEAR); |
213 | } |
214 | |
215 | Xbyak::Label store_final, store_done; |
216 | test(reg_flag, FLAG_REDUCE_LAST); |
217 | jnz(store_final, T_NEAR); // FLAG_REDUCE_LAST is set |
218 | |
219 | n_ = 0; |
220 | for (; n_ < n_iters; n_++) { |
221 | auto vbias = get_bias_reg(n_); |
222 | auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block]; |
223 | vmovups(addr, vbias); |
224 | } |
225 | if (nb_tail > 0) { |
226 | auto addr = ptr[reg_bias_acc + acc_typesize_ * n_ * brg_.ld_block]; |
227 | auto vbias = zmm_mask(get_bias_reg(n_), true, true, k_tail_mask); |
228 | vmovups(addr, vbias); |
229 | } |
230 | jmp(store_done, T_NEAR); |
231 | |
232 | L(store_final); |
233 | n_ = 0; |
234 | |
235 | for (; n_ < n_iters; n_++) |
236 | store(n_, false); |
237 | |
238 | if (nb_tail > 0) store(n_, true); |
239 | |
240 | L(store_done); |
241 | } |
242 | |
243 | void generate() override { |
244 | preamble(); |
245 | |
246 | int nb = utils::div_up(brg_.load_dim, brg_.ld_block); |
247 | int nb_tail = brg_.load_dim % brg_.ld_block; |
248 | |
249 | int n_loop = nb / n_max_regs_; |
250 | int n_loop_tail = nb % n_max_regs_; |
251 | if (n_loop_tail == 0 && nb_tail > 0) { |
252 | n_loop--; |
253 | n_loop_tail = n_max_regs_; |
254 | } |
255 | |
256 | const auto full_mask = size_t {0xffffffffffffffff}; |
257 | const auto tail_mask = size_t((1 << nb_tail) - 1); |
258 | reg64_t reg_mask = rax; |
259 | |
260 | mov(reg_mask, full_mask); |
261 | kmovq(k_full_mask, reg_mask); |
262 | mov(reg_mask, tail_mask); |
263 | kmovq(k_tail_mask, reg_mask); |
264 | |
265 | if (ddst_dt_ == data_type::bf16) { |
266 | auto reg_unit_val = reg_mask.cvt16(); |
267 | mov(reg_unit_val, 0x3f80); // bf16 value of 1. |
268 | vpbroadcastw(vreg_unit, reg_unit_val); |
269 | } |
270 | |
271 | Xbyak::Label f16_perm_table; |
272 | if (ddst_dt_ == data_type::f16) { |
273 | const auto half_mask = size_t((1 << 16) - 1); |
274 | mov(reg_mask, half_mask); |
275 | kmovq(k_f16_perm_mask, reg_mask); |
276 | |
277 | mov(reg_mask, f16_perm_table); |
278 | vmovups(vreg_perm | k_f16_perm_mask | T_z, ptr[reg_mask]); |
279 | } |
280 | |
281 | mov(reg_ddst, ptr[param1 + GET_OFF(ptr_diff_dst)]); |
282 | mov(reg_bias_acc, ptr[param1 + GET_OFF(ptr_diff_bias_acc)]); |
283 | mov(reg_bias, ptr[param1 + GET_OFF(ptr_diff_bias)]); |
284 | mov(reg_flag, ptr[param1 + GET_OFF(flags)]); |
285 | |
286 | for (int nb_ = 0; nb_ < n_loop; nb_++) { |
287 | loop_by_N(n_max_regs_, 0); |
288 | |
289 | add(reg_ddst, ddst_typesize_ * mult_ * n_max_regs_ * brg_.ld_block); |
290 | add(reg_bias, bia_typesize_ * n_max_regs_ * brg_.ld_block); |
291 | add(reg_bias_acc, acc_typesize_ * n_max_regs_ * brg_.ld_block); |
292 | } |
293 | |
294 | if (n_loop_tail > 0) loop_by_N(n_loop_tail, nb_tail); |
295 | postamble(); |
296 | |
297 | if (ddst_dt_ == data_type::f16) { |
298 | // convert interleaved vnni data with holes to packed. |
299 | const uint16_t f16_prm_array[16] = { |
300 | 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; |
301 | align(64); |
302 | L(f16_perm_table); |
303 | for (int i = 0; i < 16; ++i) |
304 | dw(f16_prm_array[i]); |
305 | } |
306 | } |
307 | }; |
308 | |
309 | #undef GET_OFF |
310 | |
311 | #define GET_OFF(field) offsetof(brgemm_kernel_post_ops_t, field) |
312 | |
// Call-time argument block for the post-ops kernel. All members are now
// default-initialized: previously only `apply_comp` and `a_comp_val` had
// initializers while every pointer was left indeterminate, which was
// inconsistent and error-prone for partially-filled argument structs.
struct brgemm_kernel_post_ops_t {
    void *ptr_in = nullptr; // accumulator (C) buffer to read
    void *ptr_out = nullptr; // destination (D) buffer to write
    void *ptr_bias = nullptr;
    void *ptr_scales = nullptr;
    const void *ptr_binary_post_ops_rhs = nullptr;
    size_t apply_comp = 0; // runtime flag: apply compensation terms
    int32_t a_comp_val = 1;
    int32_t *a_zp_compensation = nullptr;
    int32_t *c_zp_values = nullptr;
    int32_t *s8s8_compensation = nullptr;
    const void *dst_orig = nullptr;
};
326 | |
327 | template <cpu_isa_t isa> |
328 | struct jit_brgemm_kernel_post_ops : public jit_generator { |
329 | |
    // Configure the post-ops kernel: create the post-ops injector (only when
    // some post-op can actually fire), optional bf16 emulation helper, and
    // cache data-type/stride parameters from the brgemm/conv descriptors.
    jit_brgemm_kernel_post_ops(const jit_brgemm_conv_conf_t &ajcp,
            const brgemm_t &abrg, const primitive_attr_t &aattr)
        : jit_generator(jit_name())
        , brg(abrg)
        , jcp(ajcp)
        , attr(aattr)
        , postops_injector_(nullptr)
        , with_binary_non_scalar_bcast_(brg.with_binary
                  && binary_injector::
                          any_binary_postop_rhs_non_scalar_broadcast(
                                  brg.attr->post_ops_,
                                  memory_desc_wrapper(brg.dst_md))) {

        // sum is only applied with beta != 0; binary/eltwise only with
        // alpha != 0 (see apply_post_ops for the same gating).
        if ((jcp.with_sum && brg.beta != 0)
                || ((jcp.with_binary || jcp.with_eltwise) && brg.alpha != 0)) {
            static constexpr bool preserve_gpr = true;
            static constexpr bool preserve_vmm = true;
            static constexpr bool use_exact_tail_scalar_bcast = false;

            const binary_injector::rhs_arg_static_params_t rhs_sp {
                    static_cast<size_t>(vmm_tmp(4).getIdx()), this->r14,
                    this->r15, this->r13, preserve_gpr, preserve_vmm,
                    GET_OFF(ptr_binary_post_ops_rhs), GET_OFF(dst_orig),
                    memory_desc_wrapper(brg.dst_md),
                    static_cast<size_t>(brg.load_dim % brg.ld_block),
                    k_tail_mask, use_exact_tail_scalar_bcast};
            const binary_injector::static_params_t bsp {this->param1, rhs_sp};

            const bool save_state = (brg.alpha != 0) && jcp.with_eltwise;
            const auto &reserved_eltwise_gpr = rax;
            const auto reserved_eltwise_maskr = Xbyak::Opmask(1);

            const eltwise_injector::static_params_t esp {
                    save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

            postops_injector_ = utils::make_unique<
                    injector::jit_uni_postops_injector_t<po_isa_t>>(
                    this, attr.post_ops_, bsp, esp);
        }
        // Software bf16 down-conversion for ISAs without native support.
        if (brg.is_bf16_emu)
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                    bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4);

        const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
        // per_oc: conv: 1 << 0, 1 << 1 (with groups)
        // per_oc: ip: 1 << 0
        is_oc_scale_ = utils::one_of(wei_scales.mask_, 1 << 0, 1 << 1);

        LDD_ = brg.LDD;
        inp_dt_ = brg.dt_c; // kernel input = accumulator type
        out_dt_ = brg.dt_d; // kernel output = destination type
        bia_dt_ = jcp.bia_dt;
        inp_typesize_ = types::data_type_size(inp_dt_);
        out_typesize_ = types::data_type_size(out_dt_);
        bia_typesize_ = (jcp.with_bias) ? types::data_type_size(bia_dt_) : 0;
    }
387 | |
388 | ~jit_brgemm_kernel_post_ops() = default; |
389 | |
390 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_post_ops) |
391 | |
392 | brgemm_t brg; |
393 | jit_brgemm_conv_conf_t jcp; |
394 | const primitive_attr_t &attr; |
395 | |
396 | private: |
397 | int LDD_; |
398 | |
399 | data_type_t inp_dt_; |
400 | data_type_t out_dt_; |
401 | data_type_t bia_dt_; |
402 | static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core, avx2, |
403 | avx2, avx2_vnni_2, avx2_vnni_2, avx512_core_fp16, avx512_core_fp16); |
404 | std::unique_ptr<injector::jit_uni_postops_injector_t<po_isa_t>> |
405 | postops_injector_; |
406 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
407 | |
408 | const bool with_binary_non_scalar_bcast_; |
409 | |
410 | int inp_typesize_; |
411 | int out_typesize_; |
412 | int bia_typesize_; |
413 | |
414 | int is_oc_scale_; |
415 | constexpr static int max_vregs_ = cpu_isa_traits<po_isa_t>::n_vregs; |
416 | |
417 | using reg64_t = const Xbyak::Reg64; |
418 | using Vmm = |
419 | typename utils::conditional<utils::one_of(isa, avx2, avx2_vnni_2), |
420 | Xbyak::Ymm, Xbyak::Zmm>::type; |
421 | using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t; |
422 | |
423 | // Register decomposition |
424 | const reg64_t param1 = abi_param1; |
425 | const reg64_t reg_in = r15; |
426 | const reg64_t reg_out = r14; |
427 | const reg64_t aux_reg_in = r13; |
428 | const reg64_t aux_reg_out = r12; |
429 | |
430 | const reg64_t reg_bias = r11; |
431 | const reg64_t aux_reg_bias = r10; |
432 | |
433 | const reg64_t reg_scales = r9; |
434 | const reg64_t aux_reg_scales = r8; |
435 | |
436 | const reg64_t reg_ptr_sum_scale = rdx; |
437 | const reg64_t reg_ptr_sum_zp = rsi; |
438 | |
439 | const reg64_t reg_zp_c_values = rbx; |
440 | const reg64_t aux_reg_zp_c_values = rbx; |
441 | const reg64_t reg_zp_a_comp = rbx; |
442 | const reg64_t aux_reg_zp_a_comp = rbx; |
443 | const reg64_t reg_s8s8_comp = rbx; |
444 | const reg64_t aux_reg_s8s8_comp = rbx; |
445 | const reg64_t reg_zp_a_val = rbx; |
446 | const reg64_t reg_apply_comp = rbx; |
447 | |
448 | constexpr static int reg_zp_c_values_offs_ = 0; |
449 | constexpr static int aux_reg_zp_c_values_offs_ = 8; |
450 | constexpr static int reg_zp_a_comp_offs_ = 16; |
451 | constexpr static int aux_reg_zp_a_comp_offs_ = 24; |
452 | constexpr static int reg_s8s8_comp_offs_ = 32; |
453 | constexpr static int aux_reg_s8s8_comp_offs_ = 40; |
454 | constexpr static int reg_zp_a_val_offs_ = 48; |
455 | constexpr static int reg_apply_comp_offs_ = 56; |
456 | constexpr static int stack_space_needed_ = 64; |
457 | |
458 | /* bf16 emulation */ |
459 | Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27); |
460 | Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(24); |
461 | Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(25); |
462 | Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(26); |
463 | reg64_t bf16_emu_scratch = rax; |
464 | |
465 | Xbyak::Opmask k_full_mask = Xbyak::Opmask(2); |
466 | Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3); |
467 | |
468 | const int n_block2_ = 4; |
469 | |
    // Temporaries are allocated from the top of the vmm register file,
    // away from the accumulators which grow upward from index 0.
    Vmm vmm_tmp(int i) const { return Vmm(max_vregs_ - 1 - i); }
471 | |
472 | int zp_c_values_offset(int n, bool is_tail = false) const noexcept { |
473 | if (brg.zp_type_c == brgemm_broadcast_t::per_n) { |
474 | return (is_tail) ? sizeof(int32_t) * brg.ldb_tail |
475 | : sizeof(int32_t) * n * brg.ld_block; |
476 | } |
477 | |
478 | return 0; |
479 | } |
480 | int zp_comp_a_vpad_offset(int n, int m, bool is_tail = false) const |
481 | noexcept { |
482 | return (is_tail) ? sizeof(int32_t) * (brg.ldb_tail + m * brg.LDB) |
483 | : sizeof(int32_t) * (n * brg.ld_block + m * brg.LDB); |
484 | } |
485 | int mb_zp_comp_a_offset(int m_block) const noexcept { |
486 | return sizeof(int32_t) * m_block * brg.LDB; |
487 | } |
488 | int compensation_vpad_offset(int n, int m, bool is_tail = false) const |
489 | noexcept { |
490 | return (is_tail) ? sizeof(int32_t) * (brg.ldb_tail + m * brg.LDB) |
491 | : sizeof(int32_t) * (n * brg.ld_block + m * brg.LDB); |
492 | } |
493 | int mb_compensation_offset(int m_block) const noexcept { |
494 | return sizeof(int32_t) * m_block * brg.LDB; |
495 | } |
496 | |
497 | template <typename T> |
498 | const T maybe_mask(const T vmm_in, bool mask_flag, bool store, |
499 | Xbyak::Opmask ktail_mask) { |
500 | assert(IMPLICATION(mask_flag, isa_has_masks(isa))); |
501 | return mask_flag |
502 | ? (store ? vmm_in | ktail_mask : vmm_in | ktail_mask | T_z) |
503 | : vmm_in; |
504 | } |
505 | |
    // Load `op` into `vmm_in`, converting `type_in` to f32 on the fly.
    // Integral inputs stay as s32 when `skip_cvt2ps` is set (so integer
    // compensation can still be applied afterwards). A memory operand with a
    // partial tail uses `ktail_mask` when the ISA has opmask support,
    // otherwise falls back to the emulated load_data() path.
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            int tail_size, bool store, Xbyak::Opmask ktail_mask,
            bool skip_cvt2ps = false) {
        const bool is_tail = op.isMEM()
                && tail_size != vreg_traits<Vmm>::vlen / sizeof(float)
                // The current kernel is written such that tail_size = 0 implies
                // no tail and full vmm must be processed.
                && tail_size > 0;

        if (IMPLICATION(is_tail, isa_has_masks(isa))) {
            const Vmm vmm = maybe_mask(vmm_in, is_tail, store, ktail_mask);
            switch (type_in) {
                case data_type::f32:
                case data_type::s32: vmovups(vmm, op); break;
                case data_type::s8: vpmovsxbd(vmm, op); break;
                case data_type::u8: vpmovzxbd(vmm, op); break;
                case data_type::bf16:
                    // bf16 -> f32: zero-extend to dwords, then shift the
                    // payload into the high half of each dword.
                    vpmovzxwd(vmm, op);
                    vpslld(vmm, vmm, 16);
                    break;
                case data_type::f16: vcvtph2ps(vmm, op); break;
                default: assert(!"unsupported data type" );
            }
        } else {
            load_data(type_in, vmm_in, op.getAddress(), tail_size);
        }
        if (!skip_cvt2ps && types::is_integral_dt(type_in))
            vcvtdq2ps(vmm_in, vmm_in);
    }
535 | |
536 | Vmm vector(int m, int n, int n_block) { return Vmm(m * n_block + n); }; |
537 | |
    // Run the attr post-ops chain (sum / binary / eltwise) on the
    // m_block x n_block accumulator registers. `tail` selects the tail
    // opmask and flags the vmm indices as tail registers for the injector.
    void inject_attr_postops(int m_block, int n_block, int tail = 0) {
        const auto &p = attr.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const auto k_mask = tail == 0 ? k_full_mask : k_tail_mask;
        const auto sum_dt = p.get_sum_dt(out_dt_);

        // sum: acc += scale * (dst - zero_point), with dst read back from
        // the output buffer in sum_dt.
        const auto sum_injector = [&] {
            const float *p_sum_scale = &p.entry_[sum_idx].sum.scale;
            const int32_t *p_sum_zp = &p.entry_[sum_idx].sum.zero_point;
            if (*p_sum_scale != 1.f)
                mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
            auto vmm_sum_zp = vmm_tmp(1);
            if (*p_sum_zp != 0) {
                mov(reg_ptr_sum_zp, (size_t)p_sum_zp);
                vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
            }

            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto vmm = vector(m, n, n_block);
                const auto addr = ptr[aux_reg_out
                        + out_typesize_ * (m * LDD_ + n * brg.ld_block)];

                const auto vmm_prev_dst = vmm_tmp(0);
                cvt2ps(sum_dt, vmm_prev_dst, addr, tail, false, k_mask);
                if (*p_sum_zp != 0) vsubps(vmm_prev_dst, vmm_sum_zp);
                if (*p_sum_scale == 1.f)
                    vaddps(vmm, vmm_prev_dst);
                else {
                    if (is_superset(isa, avx512_core)) {
                        // Embedded broadcast of the scale operand.
                        vfmadd231ps(
                                vmm, vmm_prev_dst, ptr_b[reg_ptr_sum_scale]);
                    } else {
                        // No embedded broadcast below avx512: materialize
                        // the scale in a temporary register first.
                        auto vmm_sum_scale = vmm_tmp(1);
                        vpbroadcastd(vmm_sum_scale, ptr[reg_ptr_sum_scale]);
                        vfmadd231ps(vmm, vmm_prev_dst, vmm_sum_scale);
                    }
                }
            }
        };

        if (jcp.with_sum && brg.beta != 0) {
            postops_injector_->set_lambda_injector(
                    primitive_kind::sum, sum_injector);
        }

        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

        // Tell the binary injector where each accumulator's output element
        // lives so per-element rhs operands can be addressed.
        if (with_binary_non_scalar_bcast_) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto vmm_idx = vector(m, n, n_block).getIdx();
                const size_t aux_output_offset
                        = out_typesize_ * (m * LDD_ + n * brg.ld_block);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, aux_reg_out);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
                if (tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }

        postops_injector_->compute_vector_range(
                0, m_block * n_block, rhs_arg_params);
    }
603 | |
    // Add integer compensation terms to the s32 accumulators: the A
    // zero-point compensation (scaled by the runtime zp-A value read from
    // the stack) and the s8s8 compensation, each loaded per column block.
    void apply_comp(int m_block, int n_block, int tail = 0) {
        auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;

        if (brg.alpha != 0 && brg.zp_type_a != brgemm_broadcast_t::none) {
            auto vmm_zp_a_val = vmm_tmp(1);
            mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
            vpbroadcastd(vmm_zp_a_val, reg_zp_a_val.cvt32());

            mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]);
            for (int n = 0; n < n_block; n++) {
                auto vmm_zp_comp_a = vmm_tmp(0);
                auto zp_comp_a_addr = EVEX_compress_addr(aux_reg_zp_a_comp,
                        sizeof(int32_t) * (n * brg.ld_block));
                // The (zeroing) tail mask applies to the load and the
                // multiply below via the masked destination operand.
                vmm_zp_comp_a
                        = maybe_mask(vmm_zp_comp_a, tail > 0, false, k_mask);
                vmovups(vmm_zp_comp_a, zp_comp_a_addr);
                vpmulld(vmm_zp_comp_a, vmm_zp_a_val, zp_comp_a_addr);

                // The same compensation vector is added to every row.
                for (int m = 0; m < m_block; m++) {
                    auto vmm = vector(m, n, n_block);
                    vpaddd(vmm, vmm, vmm_zp_comp_a);
                }
            }
        }

        if (brg.alpha != 0 && brg.req_s8s8_compensation) {
            mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]);
            for (int n = 0; n < n_block; n++) {
                auto vmm_comp = vmm_tmp(0);
                auto comp_addr = EVEX_compress_addr(aux_reg_s8s8_comp,
                        sizeof(int32_t) * (n * brg.ld_block));
                vmm_comp = maybe_mask(vmm_comp, tail > 0, false, k_mask);
                vmovups(vmm_comp, comp_addr);

                for (int m = 0; m < m_block; m++) {
                    auto vmm = vector(m, n, n_block);
                    vpaddd(vmm, vmm, vmm_comp);
                }
            }
        }
    }
645 | |
    // Apply compensation only when the runtime `apply_comp` flag (spilled on
    // the stack) is non-zero, then unconditionally convert the s32
    // accumulators to f32 for the remaining (float) post-op pipeline.
    void maybe_apply_comp(int m_block, int n_block, int tail = 0) {
        Xbyak::Label label_apply_without_comp;
        mov(reg_apply_comp, ptr[rsp + reg_apply_comp_offs_]);
        cmp(reg_apply_comp, 0);
        je(label_apply_without_comp, T_NEAR);
        apply_comp(m_block, n_block, tail);
        L_aligned(label_apply_without_comp);

        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            vcvtdq2ps(vector(m, n, n_block), vector(m, n, n_block));
        }
    }
659 | |
    // Full per-tile epilogue: load/convert the accumulators, apply integer
    // compensation, bias, output scales, attr post-ops and the C zero-point,
    // then saturate/convert to the destination type and store.
    void apply_post_ops(int m_block, int n_block, int tail = 0) {
        const auto vector = [=](int m, int n) { return Vmm(m * n_block + n); };
        auto k_mask = (tail == 0) ? k_full_mask : k_tail_mask;
        const auto &p = attr.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const auto req_comp = brg.is_int8 && brg.alpha != 0
                && (brg.req_s8s8_compensation
                        || brg.zp_type_a != brgemm_broadcast_t::none);

        // brg.alpha == 0 means no read from input, no bias, no eltwise - just
        // initialize registers by zero at the beginning of kernel
        // brg.beta == 0 means no sum - just registers write to output
        // req_comp == true -> convert accumulated values to f32 after applying
        // compensation to avoid the loss of accuracy when converting s32 to f32
        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            if (brg.alpha == 0) {
                if (sum_idx != -1 && brg.beta != 0) {
                    // if sum then have to init vmm each time
                    uni_vpxor(vector(m, n), vector(m, n), vector(m, n));
                }
            } else {
                auto inp_addr = ptr[aux_reg_in
                        + inp_typesize_ * (m * brg.LDC + n * brg.ld_block)];
                // req_comp delays the s32->f32 conversion until after
                // maybe_apply_comp below.
                cvt2ps(inp_dt_, vector(m, n), inp_addr, tail, false, k_mask,
                        req_comp);
            }
        }

        if (req_comp) maybe_apply_comp(m_block, n_block, tail);

        // Bias is per output channel: one load per n, reused for every m.
        if (brg.alpha != 0 && jcp.with_bias) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                auto vmm_bias = vmm_tmp(0);
                auto bias_addr = ptr[aux_reg_bias
                        + bia_typesize_ * (n * brg.ld_block)];

                cvt2ps(bia_dt_, vmm_bias, bias_addr, tail, false, k_mask);
                vaddps(vector(m, n), vmm_bias);
            }
        }

        // Output scales: per tensor or per oc (is_oc_scale_ selects the
        // address stride).
        if (brg.alpha != 0) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                const auto addr = ptr[aux_reg_scales
                        + is_oc_scale_ * sizeof(float) * (n * brg.ld_block)];
                auto vmm = vector(m, n);
                if (IMPLICATION(tail > 0, isa_has_masks(isa))) {
                    vmm = maybe_mask(vector(m, n), tail > 0, false, k_mask);
                    vmulps(vmm, vmm, addr);
                } else {
                    // No opmask support: emulated tail load of the scales.
                    auto vmm_scales = vmm_tmp(0);
                    load_data(data_type::f32, vmm_scales, addr, tail);
                    vmulps(vmm, vmm, vmm_scales);
                }
            }
        }

        if (postops_injector_) inject_attr_postops(m_block, n_block, tail);

        // Add the C zero-point (per tensor: broadcast once; per n: reload
        // per column block).
        if (brg.alpha != 0 && brg.zp_type_c != brgemm_broadcast_t::none) {
            mov(aux_reg_zp_c_values, ptr[rsp + aux_reg_zp_c_values_offs_]);
            auto vmm_zp_c = vmm_tmp(0);
            if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
                vcvtdq2ps(vmm_zp_c,
                        EVEX_compress_addr(aux_reg_zp_c_values, 0, true));
            }
            for (int n = 0; n < n_block; n++) {
                if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                    int zp_c_off = zp_c_values_offset(n);
                    auto zp_c_addr
                            = EVEX_compress_addr(aux_reg_zp_c_values, zp_c_off);
                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, tail, false,
                            k_mask);
                }
                for (int m = 0; m < m_block; m++)
                    vaddps(vector(m, n), vmm_zp_c);
            }
        }

        const bool dt_requires_saturation = utils::one_of(
                brg.dt_d, data_type::u8, data_type::s8, data_type::s32);

        const reg64_t reg_tmp_gpr = rax;
        auto vmm_lbound = vmm_tmp(0);
        auto vmm_ubound = vmm_tmp(1);
        if (dt_requires_saturation) {
            init_saturate_f32(vmm_lbound, vmm_ubound, reg_tmp_gpr,
                    data_type::f32, brg.dt_d);
        }

        if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

        // Final conversion and store of every accumulator.
        for_(int m = 0; m < m_block; m++)
        for (int n = 0; n < n_block; n++) {
            // incase of tail, stores are unconditionally masked, regardless
            // of `n`, implying n_block must be equal to `1`.
            assert(IMPLICATION(tail > 0, n_block == 1));
            auto vmm = vector(m, n);
            const size_t offset = out_typesize_ * (m * LDD_ + n * brg.ld_block);
            const auto addr = ptr[aux_reg_out + offset];

            if (utils::one_of(out_dt_, data_type::bf16, data_type::f16)) {
                // Down-convert to 16-bit fp into the lower half register,
                // then do a word-granular (optionally masked) store.
                Vmm_lower_t vmm_low = Vmm_lower_t(vmm.getIdx());
                if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
                    if (brg.is_f16)
                        vcvtps2ph(vmm_low, vmm, _op_mxcsr);
                    else if (brg.is_bf16_emu)
                        bf16_emu_->vcvtneps2bf16(vmm_low, vmm);
                    else
                        vcvtneps2bf16(vmm_low, vmm);
                }
                vmm_low = maybe_mask(vmm_low, tail > 0, true, k_mask);
                vmovdqu16(addr, vmm_low);
            } else {
                if (brg.alpha != 0 || (sum_idx != -1 && brg.beta != 0)) {
                    saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                    if (out_dt_ != data_type::f32) vcvtps2dq(vmm, vmm);
                }
                if (IMPLICATION(tail > 0, isa_has_masks(isa))) {
                    vmm = maybe_mask(vmm, tail > 0, true, k_mask);
                    switch (out_dt_) {
                        case data_type::f32:
                        case data_type::s32: vmovups(addr, vmm); break;
                        case data_type::s8: vpmovsdb(addr, vmm); break;
                        case data_type::u8: vpmovusdb(addr, vmm); break;
                        default: assert(!"unknown dst_dt" );
                    }
                } else {
                    // No opmask support: emulated (scalar-tail) store.
                    store_data(out_dt_, vmm, aux_reg_out, offset, tail);
                }
            }
        }
    }
796 | |
797 | void loop_by_N(int m_block, int nb2, int nb2_tail, int nb_tail) { |
798 | |
799 | if (brg.alpha) { |
800 | mov(aux_reg_in, reg_in); |
801 | if (jcp.with_bias) mov(aux_reg_bias, reg_bias); |
802 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
803 | mov(aux_reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]); |
804 | mov(ptr[rsp + aux_reg_zp_c_values_offs_], aux_reg_zp_c_values); |
805 | } |
806 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
807 | mov(aux_reg_zp_a_comp, ptr[rsp + reg_zp_a_comp_offs_]); |
808 | mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp); |
809 | } |
810 | if (brg.req_s8s8_compensation) { |
811 | mov(aux_reg_s8s8_comp, ptr[rsp + reg_s8s8_comp_offs_]); |
812 | mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp); |
813 | } |
814 | mov(aux_reg_scales, reg_scales); |
815 | } |
816 | mov(aux_reg_out, reg_out); |
817 | |
818 | for (int n_loop_ = 0; n_loop_ < nb2; n_loop_++) { |
819 | apply_post_ops(m_block, n_block2_); |
820 | |
821 | const auto oc_l_offset = n_block2_ * brg.ld_block; |
822 | |
823 | add(aux_reg_out, out_typesize_ * oc_l_offset); |
824 | if (brg.alpha != 0) { |
825 | add(aux_reg_in, inp_typesize_ * oc_l_offset); |
826 | |
827 | if (jcp.with_bias) |
828 | add(aux_reg_bias, bia_typesize_ * oc_l_offset); |
829 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
830 | mov(aux_reg_zp_c_values, |
831 | ptr[rsp + aux_reg_zp_c_values_offs_]); |
832 | add(aux_reg_zp_c_values, zp_c_values_offset(n_block2_)); |
833 | mov(ptr[rsp + aux_reg_zp_c_values_offs_], |
834 | aux_reg_zp_c_values); |
835 | } |
836 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
837 | mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]); |
838 | add(aux_reg_zp_a_comp, sizeof(int32_t) * oc_l_offset); |
839 | mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp); |
840 | } |
841 | if (brg.req_s8s8_compensation) { |
842 | mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]); |
843 | add(aux_reg_s8s8_comp, sizeof(int32_t) * oc_l_offset); |
844 | mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp); |
845 | } |
846 | |
847 | add(aux_reg_scales, is_oc_scale_ * sizeof(float) * oc_l_offset); |
848 | } |
849 | } |
850 | if (nb2_tail > 0) { |
851 | apply_post_ops(m_block, nb2_tail); |
852 | const auto oc_l_offset = nb2_tail * brg.ld_block; |
853 | |
854 | add(aux_reg_out, out_typesize_ * oc_l_offset); |
855 | if (brg.alpha != 0) { |
856 | add(aux_reg_in, inp_typesize_ * oc_l_offset); |
857 | if (jcp.with_bias) |
858 | add(aux_reg_bias, bia_typesize_ * oc_l_offset); |
859 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
860 | mov(aux_reg_zp_c_values, |
861 | ptr[rsp + aux_reg_zp_c_values_offs_]); |
862 | add(aux_reg_zp_c_values, zp_c_values_offset(nb2_tail)); |
863 | mov(ptr[rsp + aux_reg_zp_c_values_offs_], |
864 | aux_reg_zp_c_values); |
865 | } |
866 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
867 | mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]); |
868 | add(aux_reg_zp_a_comp, sizeof(int32_t) * oc_l_offset); |
869 | mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp); |
870 | } |
871 | if (brg.req_s8s8_compensation) { |
872 | mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]); |
873 | add(aux_reg_s8s8_comp, sizeof(int32_t) * oc_l_offset); |
874 | mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp); |
875 | } |
876 | |
877 | add(aux_reg_scales, is_oc_scale_ * sizeof(float) * oc_l_offset); |
878 | } |
879 | } |
880 | if (nb_tail > 0) { |
881 | apply_post_ops(m_block, 1, nb_tail); |
882 | |
883 | if (brg.alpha != 0) { |
884 | add(aux_reg_in, inp_typesize_ * (nb_tail)); |
885 | if (jcp.with_bias) add(aux_reg_bias, bia_typesize_ * (nb_tail)); |
886 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
887 | mov(aux_reg_zp_c_values, |
888 | ptr[rsp + aux_reg_zp_c_values_offs_]); |
889 | add(aux_reg_zp_c_values, zp_c_values_offset(1, nb_tail)); |
890 | mov(ptr[rsp + aux_reg_zp_c_values_offs_], |
891 | aux_reg_zp_c_values); |
892 | } |
893 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
894 | mov(aux_reg_zp_a_comp, ptr[rsp + aux_reg_zp_a_comp_offs_]); |
895 | add(aux_reg_zp_a_comp, sizeof(int32_t) * nb_tail); |
896 | mov(ptr[rsp + aux_reg_zp_a_comp_offs_], aux_reg_zp_a_comp); |
897 | } |
898 | if (brg.req_s8s8_compensation) { |
899 | mov(aux_reg_s8s8_comp, ptr[rsp + aux_reg_s8s8_comp_offs_]); |
900 | add(aux_reg_s8s8_comp, sizeof(int32_t) * nb_tail); |
901 | mov(ptr[rsp + aux_reg_s8s8_comp_offs_], aux_reg_s8s8_comp); |
902 | } |
903 | add(aux_reg_scales, is_oc_scale_ * bia_typesize_ * (nb_tail)); |
904 | } |
905 | add(aux_reg_out, out_typesize_ * (nb_tail)); |
906 | } |
907 | } |
908 | |
    // Emits the whole post-ops kernel. The generated code walks the
    // (bcast_dim x load_dim) output surface in stripes of m_block rows by
    // n_block vector registers of columns, delegating per-stripe work to
    // loop_by_N().
    void generate() override {
        preamble();

        // Reserve the scratch area used for register spills (the
        // reg_*_offs_ / aux_reg_*_offs_ stack slots referenced below).
        sub(rsp, stack_space_needed_);

        // Blocking over the load (N) dimension: nb full vector blocks of
        // ld_block elements plus a sub-vector remainder of nb_tail elements.
        int nb = brg.load_dim / brg.ld_block;
        int nb_tail = brg.load_dim % brg.ld_block;

        // Group full vector blocks by n_block2_; n_block is the number of
        // vector registers consumed per row of the accumulator tile.
        int nb2 = nb / n_block2_;
        int nb2_tail = nb % n_block2_;
        int n_block = (nb2 == 0) ? nstl::max(1, nb2_tail) : n_block2_;

        // Rows per tile are bounded by the vector registers available:
        // the bf16 emulation path keeps 24 usable vregs, otherwise 4 of
        // max_vregs_ are reserved as temporaries.
        int m_max_regs = (brg.is_bf16_emu ? 24 : max_vregs_ - 4) / n_block;
        int m_block = nstl::min(brg.bcast_dim, m_max_regs);

        int mb = brg.bcast_dim / m_block;
        int mb_tail = brg.bcast_dim % m_block;

        // Pre-load the opmasks: k_full_mask covers a whole vector,
        // k_tail_mask covers the nb_tail-element sub-vector remainder.
        if (isa_has_masks(isa)) {
            const auto full_mask = size_t {0xffffffffffffffff};
            const auto tail_mask = size_t((1 << nb_tail) - 1);

            reg64_t reg_mask = rax; // scratch, only used to feed kmovq

            mov(reg_mask, full_mask);
            kmovq(k_full_mask, reg_mask);
            mov(reg_mask, tail_mask);
            kmovq(k_tail_mask, reg_mask);
        }

        // Load the kernel-call arguments from the args struct (param1).
        // Pointers that cannot stay live in registers are spilled to the
        // stack slots reserved above and reloaded inside loop_by_N().
        if (brg.alpha != 0) {
            mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
            mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]);
            mov(reg_apply_comp, ptr[param1 + GET_OFF(apply_comp)]);
            mov(ptr[rsp + reg_apply_comp_offs_], reg_apply_comp);

            if (jcp.with_bias) mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]);
            // C zero-point values, A zero-point compensation (+ scalar
            // value) and s8s8 compensation are each loaded and immediately
            // spilled to their stack slots.
            if (brg.zp_type_c != brgemm_broadcast_t::none) {
                mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]);
                mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values);
            }
            if (brg.zp_type_a != brgemm_broadcast_t::none) {
                mov(reg_zp_a_comp, ptr[param1 + GET_OFF(a_zp_compensation)]);
                mov(ptr[rsp + reg_zp_a_comp_offs_], reg_zp_a_comp);

                mov(reg_zp_a_val, ptr[param1 + GET_OFF(a_comp_val)]);
                mov(ptr[rsp + reg_zp_a_val_offs_], reg_zp_a_val);
            }
            if (brg.req_s8s8_compensation) {
                mov(reg_s8s8_comp, ptr[param1 + GET_OFF(s8s8_compensation)]);
                mov(ptr[rsp + reg_s8s8_comp_offs_], reg_s8s8_comp);
            }
        }
        mov(reg_out, ptr[param1 + GET_OFF(ptr_out)]);

        // brg.alpha == 0 means no read from input, no bias, no eltwise - just
        // initialize registers by zero
        // brg.beta == 0 means no sum - just registers write to output
        if (brg.alpha == 0) {
            for_(int m = 0; m < m_block; m++)
            for (int n = 0; n < n_block; n++) {
                auto vmm = Vmm(m * n_block + n);
                uni_vpxor(vmm, vmm, vmm);
            }
        }

        // Walk the M dimension stripe by stripe; loop_by_N() emits the code
        // for one m_block-row stripe across the full N dimension.
        for (int mb_ = 0; mb_ < mb; mb_++) {
            loop_by_N(m_block, nb2, nb2_tail, nb_tail);

            // Advance the input/output row pointers to the next stripe
            // (brg.LDC / LDD_ are the respective leading dimensions).
            if (brg.alpha != 0)
                add(reg_in, inp_typesize_ * (m_block * brg.LDC));
            add(reg_out, out_typesize_ * (m_block * LDD_));
        }
        if (mb_tail > 0) loop_by_N(mb_tail, nb2, nb2_tail, nb_tail);

        // Release the spill area and emit the function epilogue.
        add(rsp, stack_space_needed_);

        postamble();

        // The eltwise injector's lookup-table data is emitted after the
        // kernel code (only needed when the eltwise path was generated).
        if (brg.alpha != 0 && jcp.with_eltwise)
            postops_injector_->prepare_table();
    }
991 | }; |
992 | |
993 | #undef GET_OFF |
994 | |
995 | } // namespace x64 |
996 | } // namespace cpu |
997 | } // namespace impl |
998 | } // namespace dnnl |
999 | |
1000 | #endif |
1001 | |