/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <memory>

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "cpu/simple_q10n.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_gemm_inner_product_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace inner_product_utils {

using namespace dnnl::impl::cpu::inner_product_utils;
using namespace Xbyak;
using namespace data_type;

template <cpu_isa_t isa>
struct jit_pp_kernel_t : public pp_kernel_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(inner_product_utils::jit_pp_kernel_t);

    jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
            const primitive_attr_t *attr, data_type_t bias_dt,
            data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum);

    void operator()(void *dst, const void *acc, const char *bias,
            const float *scales, float dst_scale, size_t start,
            size_t dst_logical_off, size_t dim1_off, size_t end,
            size_t runtime_oc, dim_t dst_mb_stride,
            const float *dst_zero_points,
            const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
            size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx,
            const memory_desc_t &dst_md) const override;

    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
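    // Vmm is the widest vector register type for the target ISA:
    // Xmm for sse41, Ymm for avx2, Zmm for avx512_core*.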
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;

    enum class arg_t { dst, acc, bias, stack, scale, sum };
    enum class data_op_t { load, store };

    void apply_postops(const bool apply_mask, const int vmm_idx,
            const size_t offset, bool runtime_tail_mask);
    void prepare_mask(const size_t tail);
    void load_no_tail(const Vmm v, Xbyak::Address op, const data_type_t dt);
    void load_tail(const Vmm v, const arg_t arg_num, const size_t off,
            const data_type_t dt, const size_t tail);
    void load_and_cvt(const Vmm v, const arg_t arg_num, const size_t off,
            const size_t tail, bool do_cvt = true);
    // convert and store instances for each case of Vmm
    void cvt_and_store(const Xbyak::Zmm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    void cvt_and_store(const Xbyak::Ymm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    void cvt_and_store(const Xbyak::Xmm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    void runtime_tail_load_cvt(const Vmm v, const arg_t arg_num,
            const size_t off, bool cvt = true);
    void runtime_tail_cvt_store(
            const Vmm v, const arg_t arg_num, const size_t off);
    void data_copy(const Vmm v, const arg_t arg_num, const size_t off,
            data_op_t data_op, const size_t tail,
            const bool is_needed_runtime_tail_process,
            const bool do_cvt = true);
    void generate() override;
    void compute_oc_channel_blk();
    void compute_mb_blk(); // vectorize across minibatch

    void advance_binary_postops_off(const size_t offset);

    void advance_binary_postops_off(const Xbyak::Reg64 offset);

    template <typename T>
    void advance_binary_postops_per_oc_off(const T &offset);

    void update_binary_postops_per_tensor_off();

    template <typename T>
    void advance_binary_postops_channel_bcast_off(const T &offset);

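    // Argument bundle passed from operator() to the generated kernel through a
    // single pointer (reg_param); fields are accessed via PARAM_OFF() offsets.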
    struct ker_args_t {
        char *dst = nullptr;
        const char *acc = nullptr;
        const char *bias = nullptr;
        const float *scales = nullptr;
        float dst_scale = 1.0f;
        const float *dst_zero_points = nullptr;
        float nslope = 0;
        size_t oc = 0;
        size_t len = 0;
        size_t oc_offset = 0;
        size_t dim1_off = 0;
        size_t dst_logical_off = 0;
        size_t first_mb_matrix_addr_off = 0;
        dim_t dst_mb_stride = 0;
        const void *post_ops_binary_rhs_arg_vec = nullptr;
        const void *dst_orig = nullptr;
    };

    const bool is_avx512_ = utils::one_of(isa, avx512_core, avx512_core_bf16);
    static constexpr cpu_isa_t inject_isa_
            = isa == avx512_core_bf16 ? avx512_core : isa;
    std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa_>>
            postops_injector_;

    std::unique_ptr<bf16_emulation_t> bf16_emu_;

#ifdef _WIN32
    const Xbyak::Reg64 reg_binary_inj_param_ = abi_not_param1;
#else
    const Xbyak::Reg64 reg_binary_inj_param_ = abi_param1;
#endif

    const Xbyak::Reg64 reg_param = abi_param1;
    const Xbyak::Reg64 reg_stack_frame_ = rbp;
    const Xbyak::Reg64 reg_dst = rdx;
    const Xbyak::Reg64 reg_acc = rax;
    const Xbyak::Reg64 reg_bias = rbx;
    const Xbyak::Reg64 reg_scales = rsi;

    const Xbyak::Reg64 reg_oc = r13;
    const Xbyak::Reg64 reg_len = r8;
    const Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
    const Xbyak::Reg64 reg_tail = reg_tmp;
    const Xbyak::Reg64 reg_oc_offset = r9;
    const Xbyak::Reg64 reg_rem_mask = r10;
    const Xbyak::Opmask kreg_rem_mask = k1;
    const Xbyak::Opmask opmask_binary = k3;
    const Vmm vmm_rem_mask = Vmm(0);
    // register used for temp computation; does not need to be preserved
    const Xbyak::Reg64 reg_tmp_comp = r15;

    // *mb_stride used only in matmul_pp_kernel && compute_oc_channel_blk()
    const Xbyak::Reg64 reg_dst_mb_stride = r12;
    const Xbyak::Reg64 reg_acc_mb_stride = r14;

    // Will be assigned in constructor
    Vmm vreg_zero, vreg_saturation_ubound, vreg_scale, vreg_dst_scale,
            vreg_sum_scale, vreg_sum_zp, vreg_dst_zero_points;

    const Xbyak::Reg64 eltwise_reserved_gpr_ = r11;
    const Xbyak::Opmask eltwise_reserved_opmask_ = k2;

    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(28);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(30);
    Xbyak::Reg64 bf16_emu_reserv_4 = reg_tmp_comp;
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(31);

    int max_OC_loop_unroll_ = 13;
    int idx_compute_vreg_start_ = is_avx512_ ? 0 : 1;
    int idx_compute_vreg_max_ = is_avx512_ ? 31 : 15;
    int compute_vregs_per_iter_ = 1;
    int compute_vreg_bias_shift_ = 0;
    int compute_vreg_prev_dst_shift_ = 0;

    const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
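    // Small stack frame (one 64-bit slot each) used for binary post-op
    // bookkeeping: per-oc offset, per-tensor offset, channel-broadcast spatial
    // offset, and the original dst pointer.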
    static constexpr int reg64_size_ = sizeof(int64_t);
    static constexpr int reg_binary_post_op_oc_off_ = 0;
    static constexpr int reg_binary_post_op_offset_ = 1 * reg64_size_;
    static constexpr int reg_binary_post_op_sp_off_ = 2 * reg64_size_;
    static constexpr int reg_origin_dst_ptr_ = 3 * reg64_size_;
    static constexpr int stack_space_needed_ = 4 * reg64_size_;

    bool any_binary_postop_is_no_bcast_type_ = false;
    bool any_binary_postop_is_per_oc_bcast_type_ = false;
    bool any_binary_postop_is_per_oc_sp_bcast_type_ = false;
    bool any_binary_postop_is_oc_bcast_type_ = false;

    int vreg_dst_idx(const int iter) const {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_;
        assert(idx <= idx_compute_vreg_max_);
        return idx;
    }

    Vmm vreg_dst(int iter) { return Vmm(vreg_dst_idx(iter)); }

    Vmm vreg_prev_dst(int iter) {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_
                + compute_vreg_prev_dst_shift_;
        assert(idx <= idx_compute_vreg_max_);
        return Vmm(idx);
    }

    Vmm vreg_bias(int iter) {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_
                + compute_vreg_bias_shift_;
        assert(idx <= idx_compute_vreg_max_);
        return Vmm(idx);
    }

    Xbyak::Address dst_ptr(const size_t offt) { return ptr[reg_dst + offt]; }

    Xbyak::Address acc_ptr(const size_t offt) { return ptr[reg_acc + offt]; }

    Xbyak::Address bias_ptr(const size_t offt) { return ptr[reg_bias + offt]; }

    Xbyak::Address stack_ptr(const size_t offt) { return ptr[rsp + offt]; }

    Xbyak::Address scale_ptr(const size_t offt) {
        return ptr[reg_scales + offt];
    }

    Xbyak::Address get_address(const arg_t arg_num, const size_t off) {
        switch (arg_num) {
            case arg_t::dst:
            case arg_t::sum: return dst_ptr(off);
            case arg_t::acc: return acc_ptr(off);
            case arg_t::bias: return bias_ptr(off);
            case arg_t::stack: return stack_ptr(off);
            case arg_t::scale: return scale_ptr(off);
            default: assert(!"unsupported arg_num"); break;
        }
        return Xbyak::Address(0);
    }

    Xbyak::Reg64 get_reg_address(const arg_t arg_num) {
        switch (arg_num) {
            case arg_t::dst:
            case arg_t::sum: return reg_dst;
            case arg_t::acc: return reg_acc;
            case arg_t::bias: return reg_bias;
            case arg_t::stack: return rsp;
            case arg_t::scale: return reg_scales;
            default: assert(!"unsupported arg_num"); break;
        }
        return rsp;
    }

    data_type_t get_data_type(const arg_t arg_num) {
        switch (arg_num) {
            case arg_t::dst: return this->dst_data_type_;
            case arg_t::sum: return this->sum_data_type_;
            case arg_t::acc: return this->acc_data_type_;
            case arg_t::bias: return this->bias_data_type_;
            // default for stack or scale operation
            default: return f32;
        }
        return data_type::undef;
    }
};

template <cpu_isa_t isa>
jit_pp_kernel_t<isa>::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
        const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt,
        const memory_desc_t *dst_md, bool skip_sum)
    : pp_kernel_t(
            OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum)
    , jit_generator(jit_name()) {
    assert(IMPLICATION(this->dst_data_type_ == bf16, mayiuse(avx512_core)));

    if (this->do_scale_) vreg_scale = Vmm(idx_compute_vreg_start_++);

    if (this->dst_data_type_ == u8) vreg_zero = Vmm(idx_compute_vreg_start_++);
    if (utils::one_of(this->dst_data_type_, u8, s8, s32))
        vreg_saturation_ubound = Vmm(idx_compute_vreg_start_++);

    if (this->do_sum_) {
        compute_vreg_prev_dst_shift_ = compute_vregs_per_iter_++;
        if (this->sum_scale_ != 1.f)
            vreg_sum_scale = Vmm(idx_compute_vreg_start_++);
        if (this->sum_zp_ != 0) vreg_sum_zp = Vmm(idx_compute_vreg_start_++);
    }

    if (this->do_bias()) compute_vreg_bias_shift_ = compute_vregs_per_iter_++;

    if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) {
        this->do_dst_scale_ = true;
        vreg_dst_scale = Vmm(idx_compute_vreg_start_++);
    }

    if (!attr->zero_points_.has_default_values(DNNL_ARG_DST)) {
        this->do_dst_zero_points_ = true;
        vreg_dst_zero_points = Vmm(idx_compute_vreg_start_++);
    }

    if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16) {
        idx_compute_vreg_max_ = 27;
        bf16_emu_.reset(new bf16_emulation_t(this, bf16_emu_reserv_1,
                bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
                bf16_emu_reserv_5));
    }

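    // Clamp the unroll factor so that each unrolled iteration still has enough
    // free vector registers for its dst/bias/prev_dst working set.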
    int max_unroll = (idx_compute_vreg_max_ - idx_compute_vreg_start_ + 1)
            / compute_vregs_per_iter_;
    max_OC_loop_unroll_ = nstl::min(max_OC_loop_unroll_, max_unroll);
    if (this->do_eltwise_ || this->do_binary_) {
#define PARAM_OFF(field) offsetof(ker_args_t, field)
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static const size_t helper_vmm_idx = is_avx512_ ? 31 : 15;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const auto dst_md_wrapper = memory_desc_wrapper(*dst_md);

        size_t OC_loop, OC_tail;
        if (OC < max_OC_loop_unroll_ * vlen) {
            // Fully unroll small loops
            OC_loop = 0;
            OC_tail = OC;
        } else {
            OC_loop = vlen * max_OC_loop_unroll_;
            OC_tail = OC % OC_loop;
        }
        size_t tail_size = OC_tail % vlen;
        // enable tail processing for runtime load even if there is no tail
        // for the OC
        tail_size = !!tail_size ? tail_size : 1;
        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, eltwise_reserved_gpr_, r14, r15, preserve_gpr,
                preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
                PARAM_OFF(dst_orig), dst_md_wrapper, tail_size, opmask_binary,
                reg_tmp, use_exact_tail_scalar_bcast};
        static const bcast_set_t enabled_bcast_strategy
                = {broadcasting_strategy_t::scalar,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial,
                        broadcasting_strategy_t::per_mb_spatial,
                        broadcasting_strategy_t::per_mb_w,
                        broadcasting_strategy_t::per_w,
                        broadcasting_strategy_t::no_broadcast};
        const binary_injector::static_params_t binary_static_params {
                reg_binary_inj_param_, enabled_bcast_strategy,
                rhs_arg_static_params};
        static constexpr bool save_state = true;
        const eltwise_injector::static_params_t eltwise_static_params {
                save_state, reg_tmp_comp, eltwise_reserved_opmask_};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<inject_isa_>>(this,
                this->post_ops_, binary_static_params, eltwise_static_params);

        using namespace dnnl::impl::cpu::binary_injector_utils;
        std::tie(any_binary_postop_is_no_bcast_type_,
                any_binary_postop_is_per_oc_bcast_type_,
                any_binary_postop_is_per_oc_sp_bcast_type_,
                any_binary_postop_is_oc_bcast_type_)
                = bcast_strategies_present_tup(this->post_ops_.entry_,
                        dst_md_wrapper, broadcasting_strategy_t::no_broadcast,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial,
                        broadcasting_strategy_t::per_mb_spatial);
    }
#undef PARAM_OFF
}

template <cpu_isa_t isa>
template <typename T>
void jit_pp_kernel_t<isa>::advance_binary_postops_per_oc_off(const T &offset) {

    const auto binary_post_op_oc_off_reg = reg_tmp_comp;
    const auto binary_post_op_current_offset_on_stack
            = ptr[rsp + reg_binary_post_op_oc_off_];

    mov(binary_post_op_oc_off_reg, binary_post_op_current_offset_on_stack);
    add(binary_post_op_oc_off_reg, offset);

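    // For the 2D case the per-oc offset wraps back to zero once it reaches OC,
    // since the per-oc broadcast rhs operand only spans OC elements.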
    if (this->ndims_ == 2) {
        Xbyak::Label end;
        cmp(binary_post_op_oc_off_reg, this->OC_);
        jl(end, T_NEAR);
        xor_(binary_post_op_oc_off_reg, binary_post_op_oc_off_reg);
        L(end);
    }

    mov(binary_post_op_current_offset_on_stack, binary_post_op_oc_off_reg);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::update_binary_postops_per_tensor_off() {
    // subtract dst_origin from the current dst pointer and divide by the dst
    // data type size to get the correct element offset
    const auto binary_post_op_offset_reg = reg_tmp_comp;
    const auto binary_post_op_current_offset_on_stack
            = ptr[rsp + reg_binary_post_op_offset_];
    mov(binary_post_op_offset_reg, reg_dst);
    sub(binary_post_op_offset_reg, ptr[rsp + reg_origin_dst_ptr_]);
    sar(binary_post_op_offset_reg,
            std::log2(types::data_type_size(get_data_type(arg_t::dst))));
    mov(binary_post_op_current_offset_on_stack, binary_post_op_offset_reg);
}

template <cpu_isa_t isa>
template <typename T>
void jit_pp_kernel_t<isa>::advance_binary_postops_channel_bcast_off(
        const T &offset) {

    const auto binary_post_op_offset_reg = reg_tmp_comp;
    const auto binary_post_op_current_offset_on_stack
            = ptr[rsp + reg_binary_post_op_sp_off_];
    mov(binary_post_op_offset_reg, binary_post_op_current_offset_on_stack);
    add(binary_post_op_offset_reg, offset);
    mov(binary_post_op_current_offset_on_stack, binary_post_op_offset_reg);
}

/*
 * Advance binary postops offsets with per_tensor_offset passed as plain value
 * type (const offset value).
 */
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::advance_binary_postops_off(const size_t offset) {
    if (offset) {
        if (any_binary_postop_is_per_oc_bcast_type_)
            advance_binary_postops_per_oc_off(offset);
        if (any_binary_postop_is_no_bcast_type_)
            update_binary_postops_per_tensor_off();
        if (any_binary_postop_is_oc_bcast_type_)
            advance_binary_postops_channel_bcast_off(offset);
    }
}

/*
 * Advance binary postops offsets with per_tensor_offset passed in Reg64.
 */
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::advance_binary_postops_off(
        const Xbyak::Reg64 reg_offset) {
    if (any_binary_postop_is_per_oc_bcast_type_)
        advance_binary_postops_per_oc_off(reg_offset);
    if (any_binary_postop_is_no_bcast_type_)
        update_binary_postops_per_tensor_off();
    if (any_binary_postop_is_oc_bcast_type_)
        advance_binary_postops_channel_bcast_off(reg_offset);
}

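// For binary post-ops the injector needs to know which vreg holds the current
// dst block, the element offset of that block within dst, and whether a static
// or runtime tail mask applies; eltwise post-ops need no extra parameters.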
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::apply_postops(const bool apply_mask,
        const int vmm_idx, const size_t offset, const bool runtime_tail_mask) {
    if (this->do_eltwise_ || this->do_binary_) {
        if (this->do_binary_) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            rhs_arg_params.tail_load_mode = runtime_tail_mask
                    ? binary_injector::tail_lode_mode_t::DYNAMIC
                    : binary_injector::tail_lode_mode_t::DEFAULT;

            rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, offset);

            postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
        } else
            postops_injector_->compute_vector(vmm_idx);
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::prepare_mask(const size_t tail) {
    assert(tail > 0 && tail <= vlen - 1);
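    // avx512: build a k-mask with the lowest `tail` bits set;
    // avx2: load a per-lane f32 mask from a static table positioned so that
    // exactly `tail` leading lanes are all-ones.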
    if (is_avx512_) {
        const size_t tail_mask = (1 << tail) - 1;
        mov(reg_tmp, tail_mask);
        kmovq(kreg_rem_mask, reg_tmp);
    } else if (isa == avx2) {
        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask_f32[7 - tail]));
        vmovups(vmm_rem_mask, ptr[reg_tmp]);
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::load_no_tail(
        const Vmm v, Xbyak::Address op, const data_type_t dt) {
    using namespace data_type;
    switch (dt) {
        case s8: uni_vpmovsxbd(v, op); break;
        case u8: uni_vpmovzxbd(v, op); break;
        case s32:
        case f32: uni_vmovups(v, op); break;
        case bf16:
            vpmovzxwd(v, op);
            vpslld(v, v, 0x10);
            break;
        default: assert(!"unimplemented");
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::load_tail(const Vmm v, const arg_t arg_num,
        const size_t off, const data_type_t dt, const size_t tail) {
    using namespace data_type;
    if (is_avx512_) {
        auto v_dst = tail ? v | kreg_rem_mask : v;
        load_no_tail(v_dst, get_address(arg_num, off), dt);
    } else {
        if (utils::one_of(dt, s8, u8)) {
            const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
            for (size_t i = 0; i < tail; i++)
                uni_vpinsrb(x, x, get_address(arg_num, i + off), i);
            if (dt == s8)
                uni_vpmovsxbd(v, v);
            else
                uni_vpmovzxbd(v, v);
        } else {
            const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
            if (is_ymm) {
                vmaskmovps(v, vmm_rem_mask, get_address(arg_num, off));
            } else {
                const size_t dt_size = types::data_type_size(f32);
                for (size_t i = 0; i < tail; i++)
                    uni_vpinsrd(
                            v, v, get_address(arg_num, i * dt_size + off), i);
            }
        }
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::load_and_cvt(const Vmm v, const arg_t arg_num,
        const size_t off, const size_t tail, bool do_cvt) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    if (tail)
        load_tail(v, arg_num, off, dt, tail);
    else
        load_no_tail(v, get_address(arg_num, off), dt);

    if (do_cvt && utils::one_of(dt, u8, s8, s32)) uni_vcvtdq2ps(v, v);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Zmm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    if (!utils::one_of(dt, f32, bf16)) {
        Vmm vreg = Vmm(v.getIdx()); // in case Ymm is used for bf16
        saturate_f32(vreg, vreg_zero, vreg_saturation_ubound, dt);
        vcvtps2dq(v, v);
    } else if (dt == bf16) {
        if (isa == avx512_core_bf16)
            vcvtneps2bf16(Ymm(v.getIdx()), v);
        else
            bf16_emu_->vcvtneps2bf16(Ymm(v.getIdx()), v);
    }

    auto v_src = tail ? v | kreg_rem_mask : v;
    const Xbyak::Address dst = get_address(arg_num, off);
    switch (dt) {
        case s8: vpmovsdb(dst, v_src); break;
        case u8: vpmovusdb(dst, v_src); break;
        case f32:
        case s32: uni_vmovups(dst, v_src); break;
        case bf16:
            vmovdqu16(dst,
                    tail ? Ymm(v.getIdx()) | kreg_rem_mask : Ymm(v.getIdx()));
            break;
        default: assert(!"unimplemented");
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Ymm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    const Xbyak::Address dst = get_address(arg_num, off);
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    if (dt == bf16) {
        // use Zmm implementation for bf16 with Ymm
        cvt_and_store(Xbyak::Zmm(v.getIdx()), arg_num, off, tail);
        return;
    } else if (utils::one_of(dt, s8, u8, s32)) {
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        vcvtps2dq(v, v);

        if (dt != s32) {
            // v = { 8x32 }
            vpackssdw(v, v, vreg_zero);
            // v = { 4x16, 0, 4x16, 0 }
            vpermq(v, v, 0x58);
            // v = { 8x16, 0 }
            if (dt == s8)
                vpacksswb(v, v, vreg_zero);
            else
                vpackuswb(v, v, vreg_zero);
        }
    }

    if (tail) {
        switch (dt) {
            case s8:
            case u8:
                for (size_t i = 0; i < tail; i++)
                    vpextrb(get_address(arg_num, off + i), x, i);
                break;
            case f32:
            case s32: vmaskmovps(dst, vmm_rem_mask, v); break;
            default: assert(!"unimplemented");
        }
    } else {
        switch (dt) {
            case s8:
            case u8: vmovq(dst, x); break;
            case f32:
            case s32: vmovups(dst, v); break;
            default: assert(!"unimplemented");
        }
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Xmm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    const Xbyak::Address dst = get_address(arg_num, off);
    if (utils::one_of(dt, s8, u8, s32)) {
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        uni_vcvtps2dq(v, v);

        if (dt != s32) {
            // v = { 4x32 }
            uni_vpackssdw(v, v, vreg_zero);
            // v = { 4x16, 0 }
            if (dt == s8)
                uni_vpacksswb(v, v, vreg_zero);
            else
                uni_vpackuswb(v, v, vreg_zero);
        }
    }

    if (tail) {
        switch (dt) {
            case s8:
            case u8:
                for (size_t i = 0; i < tail; i++)
                    uni_vpextrb(get_address(arg_num, off + i), v, i);
                break;
            case f32:
            case s32: {
                const size_t dt_size = types::data_type_size(f32);
                for (size_t i = 0; i < tail; i++)
                    uni_vpextrd(get_address(arg_num, off + i * dt_size), v, i);
            } break;
            default: assert(!"unimplemented");
        }
    } else {
        switch (dt) {
            case s8:
            case u8: uni_vmovd(dst, v); break;
            case f32:
            case s32: uni_vmovups(dst, v); break;
            default: assert(!"unimplemented");
        }
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::runtime_tail_load_cvt(
        const Vmm v, const arg_t arg_num, const size_t off, bool cvt) {
    assert(!is_avx512_);
    const data_type_t dt = get_data_type(arg_num);
    const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    const Xbyak::Ymm y = Xbyak::Ymm(v.getIdx());
    const Xbyak::Reg64 reg_addr = get_reg_address(arg_num);

    auto runtime_tail_load = [&](int load_size) {
        if (is_ymm)
            load_data(dt, y, reg_addr, off, load_size);
        else
            load_data(dt, x, reg_addr, off, load_size);
    };

    runtime_tail_process<Vmm>(reg_tail, reg_rem_mask, runtime_tail_load);

    if (cvt && utils::one_of(dt, u8, s8, s32)) uni_vcvtdq2ps(v, v);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::runtime_tail_cvt_store(
        const Vmm v, const arg_t arg_num, const size_t off) {
    assert(!is_avx512_);
    const data_type_t dt = get_data_type(arg_num);
    const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    const Xbyak::Ymm y = Xbyak::Ymm(v.getIdx());
    const Xbyak::Reg64 reg_addr = get_reg_address(arg_num);

    if (utils::one_of(dt, u8, s8, s32)) {
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        uni_vcvtps2dq(v, v);
    }

    auto runtime_tail_store = [&](int store_size) {
        if (is_ymm)
            store_data(dt, y, reg_addr, off, store_size);
        else
            store_data(dt, x, reg_addr, off, store_size);
    };

    runtime_tail_process<Vmm>(reg_tail, reg_rem_mask, runtime_tail_store);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::data_copy(const Vmm v, const arg_t arg_num,
        const size_t off, data_op_t data_op, const size_t tail,
        const bool is_needed_runtime_tail_process, const bool do_cvt) {
    if (data_op == data_op_t::load) {
        if (is_needed_runtime_tail_process)
            runtime_tail_load_cvt(v, arg_num, off, do_cvt);
        else
            load_and_cvt(v, arg_num, off, tail, do_cvt);
    } else {
        if (is_needed_runtime_tail_process)
            runtime_tail_cvt_store(v, arg_num, off);
        else
            cvt_and_store(v, arg_num, off, tail);
    }
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::compute_oc_channel_blk() {
    // Load accumulated value, convert to float, apply bias (if any), scaling,
    // and eltwise (if any); then convert to destination type and store

    auto compute = [&](size_t offset, int idx, bool runtime_tail_mask,
                           int tail = 0) {
        const bool is_needed_runtime_tail_process
                = runtime_tail_mask && tail && !is_avx512_;

        if (this->do_scale_ && this->scale_idx_mult_ == 1)
            data_copy(vreg_scale, arg_t::scale, offset * sizeof(float),
                    data_op_t::load, tail, is_needed_runtime_tail_process,
                    false);

        if (this->do_binary_ && tail && is_avx512_)
            kmovq(opmask_binary, kreg_rem_mask);

        const int dst_idx = vreg_dst_idx(idx);
        auto vreg_dst_ = Vmm(dst_idx);
        data_copy(vreg_dst_, arg_t::acc, offset * this->acc_data_type_size_,
                data_op_t::load, tail, is_needed_runtime_tail_process);

        if (this->do_scale_) uni_vmulps(vreg_dst_, vreg_dst_, vreg_scale);

        if (this->do_bias()) {
            auto vreg_bias_ = vreg_bias(idx);
            data_copy(vreg_bias_, arg_t::bias,
                    offset * this->bias_data_type_size_, data_op_t::load, tail,
                    is_needed_runtime_tail_process);
            uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
        }

        if (this->do_sum_) {
            auto vreg_prev_dst_ = vreg_prev_dst(idx);
            data_copy(vreg_prev_dst_, arg_t::sum,
                    offset * this->dst_data_type_size_, data_op_t::load, tail,
                    is_needed_runtime_tail_process);
            if (this->sum_zp_ != 0)
                uni_vsubps(vreg_prev_dst_, vreg_prev_dst_, vreg_sum_zp);
            if (this->sum_scale_ != 1.f)
                uni_vfmadd231ps(vreg_dst_, vreg_prev_dst_, vreg_sum_scale);
            else
                uni_vaddps(vreg_dst_, vreg_dst_, vreg_prev_dst_);
        }

        apply_postops(!!tail, dst_idx, offset * this->dst_data_type_size_,
                is_needed_runtime_tail_process);

        if (this->do_dst_scale_)
            uni_vmulps(vreg_dst_, vreg_dst_, vreg_dst_scale);

        if (this->do_dst_zero_points_)
            uni_vaddps(vreg_dst_, vreg_dst_, vreg_dst_zero_points);

        data_copy(vreg_dst_, arg_t::dst, offset * this->dst_data_type_size_,
                data_op_t::store, tail, is_needed_runtime_tail_process);
    };

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * this->dst_data_type_size_);
        add(reg_acc, offset * this->acc_data_type_size_);
        if (this->do_scale_ && this->scale_idx_mult_ == 1)
            add(reg_scales, offset * sizeof(float));
        if (this->do_bias()) add(reg_bias, offset * this->bias_data_type_size_);
        if (this->do_binary_) { advance_binary_postops_off(offset); }
    };

    // Advance all pointers by a value stored in a register
    auto advance_ptrs_reg = [&](const Reg64 offset) {
        lea(reg_dst, ptr[reg_dst + offset * this->dst_data_type_size_]);
        lea(reg_acc, ptr[reg_acc + offset * this->acc_data_type_size_]);
        if (this->do_scale_ && this->scale_idx_mult_ == 1)
            lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
        if (this->do_bias())
            lea(reg_bias, ptr[reg_bias + offset * this->bias_data_type_size_]);
        if (this->do_binary_) advance_binary_postops_off(offset);
    };

    // in case of non-trivial dst_mb_stride, fix up reg_dst and reg_acc
    auto maybe_advance_mb_stride = [&]() {
        if (!this->has_trivial_mb_stride()) {
            lea(reg_dst,
                    ptr[reg_dst
                            + reg_dst_mb_stride * this->dst_data_type_size_]);
            lea(reg_acc,
                    ptr[reg_acc
                            + reg_acc_mb_stride * this->acc_data_type_size_]);
        }
        if (this->do_binary_ && any_binary_postop_is_no_bcast_type_)
            update_binary_postops_per_tensor_off();
    };

    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    auto rewind_ptrs = [&]() {
        neg(reg_oc);
        if (this->do_bias())
            lea(reg_bias, ptr[reg_bias + reg_oc * this->bias_data_type_size_]);
        if (this->do_scale_ && this->scale_idx_mult_ == 1)
            lea(reg_scales, ptr[reg_scales + reg_oc * sizeof(float)]);

        neg(reg_oc);
    };

    // Process one row of reg_tmp elements
    auto process_runtime_oc = [&]() {
        Label l_loop, l_loop_tail, l_loop_end;
        cmp(reg_tmp, vlen);
        jl(l_loop_tail, T_NEAR);

        L(l_loop);
        {
            compute(0, 0, true);
            advance_ptrs_imm(vlen);

            sub(reg_tmp, vlen);
            cmp(reg_tmp, vlen);
            jge(l_loop, T_NEAR);
        }

        L(l_loop_tail);
        cmp(reg_tmp, 0);
        je(l_loop_end, T_NEAR);

        if (is_avx512_) {
            mov(reg_rem_mask, 1);
            shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
            sub(reg_rem_mask, 1);
            kmovq(kreg_rem_mask, reg_rem_mask);
        }
        // tail size does not matter for runtime load
        compute(0, 0, true, true);
        advance_ptrs_reg(reg_tmp);

        L(l_loop_end);
    };

    //   <-------------------- OC ------------------------------->
    //
    // ^ +....................+----------------------------------+
    // | :    not accessed    |          Prologue loop           |
    // | +--------------------+----------------------------------+
    // | |                                                       |
    // M |                  Main loop (unrolled)                 |
    // B |                                                       |
    //   +-------------------------------+-----------------------+
    //   |         Epilogue loop         |     not accessed      :
    // v +-------------------------------+.......................+

    if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16)
        bf16_emu_->init_vcvtneps2bf16();

    // Prologue loop
    Label l_prologue_end;
    cmp(reg_oc_offset, 0);
    je(l_prologue_end, T_NEAR);
    {
        mov(reg_tmp, reg_oc);
        sub(reg_tmp, reg_oc_offset);
        cmp(reg_tmp, reg_len);
        cmovg(reg_tmp, reg_len);
        sub(reg_len, reg_tmp);
        process_runtime_oc();
        rewind_ptrs();
        maybe_advance_mb_stride();
    }
    L(l_prologue_end);

    // Main loop
    Label l_main_loop_end;
    cmp(reg_len, reg_oc);
    jle(l_main_loop_end, T_NEAR);
    if (this->runtime_oc()) {
        Label l_main_loop;
        L(l_main_loop);
        {
            mov(reg_tmp, reg_oc);

            process_runtime_oc();
            rewind_ptrs();

            sub(reg_len, reg_oc);
            maybe_advance_mb_stride();
            cmp(reg_len, reg_oc);
            jge(l_main_loop, T_NEAR);
        }
    } else {
        Label l_main_loop;
        L(l_main_loop);
        {
            size_t OC_loop, OC_tail;
            if (this->OC_ < max_OC_loop_unroll_ * vlen) {
                // Fully unroll small loops
                OC_loop = 0;
                OC_tail = this->OC_;
            } else {
                OC_loop = vlen * max_OC_loop_unroll_;
                OC_tail = this->OC_ % OC_loop;
            }

            assert(!!OC_loop || !!OC_tail);

            const int vlen_tail = OC_tail % vlen;
            if (vlen_tail) prepare_mask(vlen_tail);

            if (OC_loop) {
                mov(reg_tmp, utils::rnd_dn(this->OC_, OC_loop));
                Label l_oc_loop;
                L(l_oc_loop);
                {
                    for (size_t offset = 0; offset < OC_loop; offset += vlen)
                        compute(offset, offset / vlen, false);
                    advance_ptrs_imm(OC_loop);
                    sub(reg_tmp, OC_loop);
                    jnz(l_oc_loop);
                }
            }

            if (OC_tail) {
                for (size_t offset = 0; offset < OC_tail; offset += vlen) {
                    const bool use_mask = (offset + vlen) > OC_tail;
                    compute(offset, offset / vlen, false,
                            use_mask ? vlen_tail : 0);
                }
                advance_ptrs_imm(OC_tail);
            }

            if (any_binary_postop_is_per_oc_sp_bcast_type_
                    && this->ndims_ <= 3) {
                static constexpr size_t offset_oc_spatial = 1;
                advance_binary_postops_per_oc_off(offset_oc_spatial);
            }

            rewind_ptrs();
            sub(reg_len, reg_oc);
            maybe_advance_mb_stride();
            cmp(reg_len, reg_oc);
            jge(l_main_loop, T_NEAR);
        }
    }
    L(l_main_loop_end);

    // Epilogue loop
    Label l_epilogue_end;
    cmp(reg_len, 0);
    je(l_epilogue_end, T_NEAR);
    {
        mov(reg_tmp, reg_len);
        process_runtime_oc();
    }
    L(l_epilogue_end);
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::compute_mb_blk() {
    auto compute = [&](int tail, bool runtime_tail = false) {
        auto vmm_bias = vreg_bias(0);
        auto vmm_dst = vreg_dst(0);
        assert(utils::one_of(this->acc_data_type_, s32, f32));
        data_copy(vmm_dst, arg_t::acc, 0, data_op_t::load, tail, runtime_tail);
        uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
        data_copy(vmm_dst, arg_t::dst, 0, data_op_t::store, tail, runtime_tail);
    };

    Label mb_main_loop, end_main_loop;

    bool expl_broadcast
            = this->OC_ == 1 && utils::one_of(this->bias_data_type_, s32, f32);
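    // Number of minibatch rows that fit into one vector register when the
    // whole OC block is vectorized (this path is only taken for OC_ <= vlen / 2,
    // see generate()).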
    size_t mb_step = vlen / this->OC_;
    size_t mb_tail = this->MB_ % mb_step;
    size_t mb_oc_blk = mb_step * this->OC_;
    size_t tail_size = mb_oc_blk % vlen;
    auto vmm_bias = vreg_bias(0);

    if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16)
        bf16_emu_->init_vcvtneps2bf16();

    if (expl_broadcast) {
        // when OC == 1 bias can be loaded directly into simd
        switch (this->bias_data_type_) {
            case s32: uni_vpbroadcastd(vmm_bias, ptr[reg_bias]); break;
            case f32: uni_vbroadcastss(vmm_bias, ptr[reg_bias]); break;
            // TODO: enable broadcast for other data types
            default: assert(!"unimplemented");
        }
    } else {
        // prepare bias data for simd computation
        prepare_mask(this->OC_); // this->OC_ will never be larger than vlen / 2
        load_and_cvt(vmm_bias, arg_t::bias, 0, this->OC_, false);

        // write repeated MB*OC entries into stack
        sub(rsp, mb_oc_blk * sizeof(uint32_t));
        for (size_t i = 0; i < mb_step; ++i)
            cvt_and_store(vmm_bias, arg_t::stack,
                    i * this->OC_ * sizeof(uint32_t), this->OC_);

        // load into simd
        if (tail_size) prepare_mask(tail_size);
        load_and_cvt(vmm_bias, arg_t::stack, 0, tail_size, false);
    }
    if (utils::one_of(this->bias_data_type_, u8, s8, s32))
        uni_vcvtdq2ps(vmm_bias, vmm_bias);
    L(mb_main_loop);
    {
        cmp(reg_len, mb_oc_blk);
        jl(end_main_loop, T_NEAR);

        compute(!expl_broadcast ? tail_size : 0);
        add(reg_dst, mb_oc_blk * this->dst_data_type_size_);
        add(reg_acc, mb_oc_blk * this->acc_data_type_size_);
        sub(reg_len, mb_oc_blk);
        jmp(mb_main_loop, T_NEAR);
    }
    L(end_main_loop);

    if (mb_tail > 0) {
        Label mb_tail_loop, runtime_tail, end_runtime_tail;
        tail_size = (mb_tail * this->OC_);
        if (tail_size) prepare_mask(tail_size);
        L(mb_tail_loop);
        {
            cmp(reg_len, tail_size);
            jl(runtime_tail, T_NEAR);
            compute(tail_size);
            add(reg_dst, tail_size * this->dst_data_type_size_);
            add(reg_acc, tail_size * this->acc_data_type_size_);
            sub(reg_len, tail_size);
            jmp(mb_tail_loop, T_NEAR);
        }
        // Load tail in runtime if len < mb_tail * oc
        L(runtime_tail);
        {
            cmp(reg_len, 0);
            jle(end_runtime_tail, T_NEAR);
            mov(reg_tail, reg_len); // save tail
            if (is_avx512_) {
                mov(reg_rem_mask, 1);
                shl(reg_rem_mask, cl); // cl == last 8 bits of reg_tail
                sub(reg_rem_mask, 1);
                kmovq(kreg_rem_mask, reg_rem_mask);
            }
            compute(tail_size, !is_avx512_);
        }
        L(end_runtime_tail);
    }

    if (!expl_broadcast) add(rsp, mb_oc_blk * sizeof(uint32_t));
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::generate() {
    preamble();

#ifdef _WIN32
    // The binary post-ops injector needs the pointer to the kernel params,
    // but on Windows abi_param1 is rcx, which is also used as a temp register
    // for shifts, so keep a copy of the pointer in an extra register.
    if (this->do_binary_) mov(reg_binary_inj_param_, param1);
#endif

#define PARAM_OFF(x) offsetof(ker_args_t, x)
    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    if (this->do_scale_) mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
    if (this->do_dst_scale_) {
        // don't overwrite reg_param
        assert(reg_tmp_comp.getIdx() != reg_param.getIdx());
        mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dst_scale)]);
        auto xreg_dst_scale = Xmm(vreg_dst_scale.getIdx());
        uni_vmovq(xreg_dst_scale, reg_tmp_comp);
        uni_vbroadcastss(vreg_dst_scale, xreg_dst_scale);
    }
    if (this->do_dst_zero_points_) {
        // use reg_oc as a temporary one (alas, reg_tmp = reg_param on Windows)
        mov(reg_oc, ptr[reg_param + PARAM_OFF(dst_zero_points)]);
        uni_vbroadcastss(vreg_dst_zero_points, ptr[reg_oc]);
    }
    if (this->runtime_oc())
        mov(reg_oc, ptr[reg_param + PARAM_OFF(oc)]);
    else
        mov(reg_oc, this->OC_);
    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
    if (this->do_binary_) {
        mov(reg_stack_frame_, rsp);
        sub(rsp, stack_space_needed_);
        if (any_binary_postop_is_per_oc_sp_bcast_type_
                || any_binary_postop_is_per_oc_bcast_type_) {
            mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dim1_off)]);
            mov(ptr[rsp + reg_binary_post_op_oc_off_], reg_tmp_comp);
        }
        if (any_binary_postop_is_no_bcast_type_) {
            // store the original dst pointer to compute the proper per-tensor
            // binary src1 offset
            mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dst_orig)]);
            mov(ptr[rsp + reg_origin_dst_ptr_], reg_tmp_comp);
            // init offset
            update_binary_postops_per_tensor_off();
        }
        if (any_binary_postop_is_oc_bcast_type_) {
            // initialize the binary post-ops channel-broadcast (spatial)
            // offset accumulator
            mov(reg_tmp_comp,
                    ptr[reg_param + PARAM_OFF(first_mb_matrix_addr_off)]);
            mov(ptr[rsp + reg_binary_post_op_sp_off_], reg_tmp_comp);
        }
    }
    if (this->do_scale_ && this->scale_idx_mult_ == 0)
        uni_vbroadcastss(vreg_scale, dword[reg_scales]);
    if (!this->has_trivial_mb_stride()) {
        mov(reg_dst_mb_stride, ptr[reg_param + PARAM_OFF(dst_mb_stride)]);
        sub(reg_dst_mb_stride, reg_oc);
        // if dst and acc point to the same address (in-place), the strides
        // must match; otherwise assume the acc buffer is dense.
        xor_(reg_acc_mb_stride, reg_acc_mb_stride);
        cmp(reg_dst, reg_acc);
        cmove(reg_acc_mb_stride, reg_dst_mb_stride);
    }
#undef PARAM_OFF

    if (this->do_sum_) {
        if (this->sum_scale_ != 1.f) {
            mov(reg_tmp, float2int(this->sum_scale_));
            auto xreg_sum_scale = Xmm(vreg_sum_scale.getIdx());
            uni_vmovq(xreg_sum_scale, reg_tmp);
            uni_vbroadcastss(vreg_sum_scale, xreg_sum_scale);
        }
        if (this->sum_zp_ != 0) {
            mov(reg_tmp, this->sum_zp_);
            auto xreg_sum_zp = Xmm(vreg_sum_zp.getIdx());
            uni_vmovq(xreg_sum_zp, reg_tmp);
            uni_vbroadcastss(vreg_sum_zp, xreg_sum_zp);
            uni_vcvtdq2ps(vreg_sum_zp, vreg_sum_zp);
        }
    }

    init_saturate_f32(vreg_zero, vreg_saturation_ubound, reg_tmp_comp, f32,
            this->dst_data_type_);

    // at least 2 blocks of mb within vlen
    bool dim_restrict = !this->runtime_oc() && !this->runtime_mb()
            && (this->OC_ <= vlen / 2) && (this->MB_ >= vlen);
    bool supported_postops = this->do_scale_ || this->do_eltwise_
            || this->do_binary_ || this->do_sum_ || this->do_dst_zero_points_
            || this->do_dst_scale_;
    if (this->do_bias() && !supported_postops && dim_restrict
            && this->has_trivial_mb_stride()) {
        this->mb_blk_kernel_ = true;
        compute_mb_blk();
    } else {
        compute_oc_channel_blk();
    }

    if (this->do_binary_) add(rsp, stack_space_needed_);
    postamble();

    if (this->do_eltwise_) postops_injector_->prepare_table();
}

template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::operator()(void *dst, const void *acc,
        const char *bias, const float *scales, float dst_scale, size_t start,
        size_t dst_logical_off, size_t dim1_off, size_t end, size_t runtime_oc,
        dim_t dst_mb_stride, const float *dst_zero_points,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        size_t first_mb_matrix_addr_off, const exec_ctx_t & /* ctx */,
        const memory_desc_t & /* dst_md */) const {

    if (end <= start) return;
    const size_t OC = this->runtime_oc() ? runtime_oc : this->OC_;

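    // Translate the logical [start, end) range into per-call pointers and
    // offsets for dst/acc/bias/scales; the generated kernel then walks `len`
    // elements from there.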
    ker_args_t args;
    size_t oc_offset = start % OC;
    if (this->has_trivial_mb_stride()) {
        args.dst = static_cast<char *>(dst) + this->dst_data_type_size_ * start;
        args.acc = static_cast<const char *>(acc)
                + this->acc_data_type_size_ * start;
    } else {
        const dim_t offt = (start / OC) * dst_mb_stride + oc_offset;
        args.dst = static_cast<char *>(dst) + this->dst_data_type_size_ * offt;
        // if dst and acc point to the same address (in-place), the strides
        // must match; otherwise assume the acc buffer is dense.
        const auto stride = dst == acc ? offt : start;
        args.acc = static_cast<const char *>(acc)
                + this->acc_data_type_size_ * stride;
    }
    args.bias = bias + oc_offset * this->bias_data_type_size_;
    args.scales = scales + this->scale_idx_mult_ * oc_offset;
    args.dst_scale = dst_scale;
    args.dst_zero_points = dst_zero_points;
    args.oc = OC;
    args.len = end - start;
    args.oc_offset = oc_offset;
    args.dst_logical_off = dst_logical_off;
    args.dim1_off = dim1_off;
    args.dst_mb_stride = dst_mb_stride;
    args.first_mb_matrix_addr_off = first_mb_matrix_addr_off;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;
    jit_generator::operator()(&args);
}

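// Pick the widest ISA available at runtime; returns nullptr when no supported
// SIMD ISA is present, in which case the caller is expected to fall back to a
// non-JIT implementation.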
pp_kernel_t *jit_pp_kernel_create(size_t OC, size_t MB, dim_t dst_mb_stride,
        const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt,
        const memory_desc_t *dst_md, bool skip_sum) {
    if (mayiuse(avx512_core_bf16)) {
        return new jit_pp_kernel_t<avx512_core_bf16>(
                OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
    } else if (mayiuse(avx512_core)) {
        return new jit_pp_kernel_t<avx512_core>(
                OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
    } else if (mayiuse(avx2)) {
        return new jit_pp_kernel_t<avx2>(
                OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
    } else if (mayiuse(sse41)) {
        return new jit_pp_kernel_t<sse41>(
                OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum);
    } else {
        return nullptr;
    }
}

} // namespace inner_product_utils
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl