/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstdlib>
#include <functional>

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
#include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace gemm_x8s8s32x_convolution_utils {
using namespace dnnl::impl::cpu::gemm_x8s8s32x_convolution_utils;

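// jit_pp_ker_t is the avx512_core post-processing kernel for the int8
// (x8s8s32x) GEMM-based convolution: it takes the int32 GEMM accumulator,
// converts it to f32, applies zero-point compensation, scales, bias, sum
// and eltwise/binary post-ops, then saturates and stores in the
// destination data type.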
struct jit_pp_ker_t : pp_ker_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            gemm_x8s8s32x_convolution_utils::jit_pp_ker_t);

    jit_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);

    status_t create_kernel() override { return jit_generator::create_kernel(); }
    void operator()(void *void_dst, const acc_data_t *acc, const char *bias,
            const float *scales, float dst_scale, float sum_scale,
            float signed_scale, int g, size_t start, size_t end,
            const zero_point_call_params_t &zp,
            const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
            const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */,
            const single_gemm_conv_chunk_desc_t &) const override;

private:
    void apply_postops(const Xbyak::Reg64 &reg_dst, const int idx);
    void generate() override;
    void append_zp_src_comp(size_t offset, int idx, bool apply_mask);
    void load_as_f32(const Xbyak::Zmm &dst, const Xbyak::Opmask &mask,
            const Xbyak::Address &src_addr, const data_type_t &src_dt);

    int vreg_dst_idx(const int idx) const noexcept;
    Xbyak::Zmm get_vreg_dst(int idx) const;
    Xbyak::Zmm get_vreg_bias(int idx) const;
    Xbyak::Zmm get_vreg_prev_dst(int idx) const;
    Xbyak::Zmm get_vreg_zp_comp_src(int idx) const;
    Xbyak::Zmm get_masked_vreg_dst(int idx, bool apply_mask) const;
    Xbyak::Zmm reserve_zmm();

    template <typename T>
    void advance_binary_postops_off(const T &offset);
    void zero_binary_postops_off();
    void set_binary_postops_off(const Xbyak::Reg64 &reg);
    const Xbyak::Opmask &opmask_binary = k2;

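    // Argument block passed to the generated kernel through the single
    // ABI pointer parameter (reg_param_).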
    struct ker_args_t {
        char *dst;
        const acc_data_t *acc;
        const char *bias;
        const float *scales;
        float dst_scale;
        float sum_scale;
        float signed_scale;
        size_t len;
        size_t oc_offset;
        const int32_t *zp_src;
        const int32_t *zp_dst;
        const int32_t *zp_src_comp;
        const int32_t *zp_src_pad_comp;
        size_t g_oc_offset_prologue;
        size_t g_oc_offset;
        const void *post_ops_binary_rhs_arg_vec;
        const void *dst_orig;
        dim_t h;
        dim_t w;
        dim_t w_size;
        dim_t w_off;
        dim_t zp_src_pad_com_d_offset;
        bool should_apply_zp_src_pad_comp_d;
    };

    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
            postops_injector_;

    size_t number_of_reserved_zmm_regs_;
    const size_t bias_data_type_size_;
    const size_t dst_data_type_size_;
    const bool saturation_needed_;

    const Xbyak::Reg64 &reg_param_ = rdi;
    const Xbyak::Reg64 &reg_tmp_ = rcx; // rcx on purpose: variable shifts take the count in cl

    const Xbyak::Reg64 &reg_dst_ = rdx;
    const Xbyak::Reg64 &reg_acc_ = rax;
    const Xbyak::Reg64 &reg_bias_ = rbx;
    const Xbyak::Reg64 &reg_scales_ = rsi;
    const Xbyak::Reg64 &reg_len_ = r8;
    const Xbyak::Reg64 &reg_oc_offset_ = r9;
    const Xbyak::Reg64 &reg_rem_mask_short_ = r10;
    const Xbyak::Reg64 &reg_rem_mask_vlen_ = reg_rem_mask_short_;
    const Xbyak::Reg64 &reg_zp_pad_comp_temp_ = r10;
    const Xbyak::Reg64 &reg_zp_pad_comp_ = r11;
    const Xbyak::Reg8 &reg_should_apply_src_pad_comp_ = r13b;

    const Xbyak::Reg64 &reg_tmp_comp_
            = r12; // used to broadcast scalar values to vreg
    const Xbyak::Reg64 &reg_g_oc_off_ = reg_tmp_comp_;
    const Xbyak::Reg64 &reg_zp_src_comp_ = r14;

    const Xbyak::Zmm vreg_zero_;
    const Xbyak::Zmm vreg_scale_;
    const Xbyak::Zmm vreg_dst_scale_;
    const Xbyak::Zmm vreg_sum_scale_;
    const Xbyak::Zmm vreg_signed_scale_;
    const Xbyak::Zmm vreg_saturation_ubound_;
    const Xbyak::Zmm vreg_zp_dst_common_;

    const Xbyak::Opmask &kreg_rem_mask_short_ = k3;
    const Xbyak::Opmask &kreg_rem_mask_vlen_ = k4;

    static constexpr size_t def_unroll_ = 4u;
    size_t zmm_step_;
    const size_t bias_step_factor_;
    const size_t sum_step_factor_;
    const size_t max_unroll_;
    int dst_l_offset_ = 0;

    std::unique_ptr<jit_gemm_x8s8s32x_zp_pad_comp_helper> zp_pad_comp_helper_;
};

jit_pp_ker_t::jit_pp_ker_t(
        const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
    : pp_ker_t(pd, jcp)
    , jit_generator(jit_name())
    , number_of_reserved_zmm_regs_(0)
    , bias_data_type_size_(jcp.bias_data_type != data_type::undef
                      ? types::data_type_size(jcp.bias_data_type)
                      : 0u)
    , dst_data_type_size_(types::data_type_size(jcp.dst_data_type))
    , saturation_needed_(utils::one_of(
              jcp_.dst_data_type, data_type::u8, data_type::s8, data_type::s32))
    , vreg_zero_((jcp_.with_eltwise || saturation_needed_) ? reserve_zmm()
                                                           : Xbyak::Zmm(0))
    , vreg_scale_(reserve_zmm())
    , vreg_dst_scale_(reserve_zmm())
    , vreg_sum_scale_(jcp_.with_sum ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_signed_scale_(jcp_.signed_input ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_saturation_ubound_(
              saturation_needed_ ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_zp_dst_common_(jcp_.zp.dst_exists ? reserve_zmm() : Xbyak::Zmm(0))
    , zmm_step_(1u)
    , bias_step_factor_(jcp_.with_bias ? zmm_step_++ : 0u)
    , sum_step_factor_(jcp_.with_sum ? zmm_step_++ : 0)
    , max_unroll_((cpu_isa_traits<avx512_core>::n_vregs
                          - number_of_reserved_zmm_regs_)
            / zmm_step_)
    , zp_pad_comp_helper_(jit_gemm_convolution_utils::padding_exists(jcp)
                              && jcp.zp.src_exists
                      ? utils::make_unique<
                              jit_gemm_x8s8s32x_zp_pad_comp_helper>(this, jcp_,
                              reg_zp_pad_comp_, reg_zp_pad_comp_temp_,
                              reg_should_apply_src_pad_comp_,
                              pd->src_md()->ndims)
                      : nullptr)

{

    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr size_t helper_vmm_idx = 31;
        // tail_size = 1 only signals to the injector that tail processing
        // is required; the actual tail length is carried by the opmask
        // passed to the injector.
        static constexpr size_t tail_size = 1;
        static constexpr bool use_exact_tail_scalar_bcast = false;

#define PARAM_OFF(x) offsetof(ker_args_t, x)
        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r13, r14, r15, preserve_gpr, preserve_vmm,
                PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig),
                memory_desc_wrapper(pd->dst_md()), tail_size, opmask_binary,
                use_exact_tail_scalar_bcast};
#undef PARAM_OFF

        const static_params_t static_params {reg_param_, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp_.post_ops, static_params);
    }
}

void jit_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc,
        const char *bias, const float *scales, float dst_scale, float sum_scale,
        float signed_scale, int g, size_t start, size_t end,
        const zero_point_call_params_t &zp,
        const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
        const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */,
        const single_gemm_conv_chunk_desc_t &chunk_desc) const {

    if (end <= start) return;

    char *dst = (char *)void_dst;

    ker_args_t args;
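    // 'start' indexes the flattened (os, oc) accumulator, so split it back
    // into a spatial offset (quotient) and a channel offset (remainder).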
    const auto dv = std::div(start, jcp_.oc);
    const size_t oc_offset = dv.rem;
    const size_t os_offset = dv.quot;
    args.acc = acc + start;
    args.dst = dst
            + (os_offset * jcp_.dst_os_stride + oc_offset)
                    * dst_data_type_size_;

    const ptrdiff_t g_oc_offset = g * jcp_.oc;
    const ptrdiff_t g_oc_offset_prologue = g_oc_offset + oc_offset;
    args.bias = bias + g_oc_offset_prologue * bias_data_type_size_;
    args.zp_src = zp.src + (jcp_.zp.src_is_common ? 0 : g_oc_offset_prologue);
    args.zp_src_comp
            = zp.src_comp ? zp.src_comp + g_oc_offset_prologue : nullptr;
    args.zp_dst = zp.dst;
    args.scales = scales + jcp_.scale_idx_mult * g_oc_offset_prologue;
    args.dst_scale = dst_scale;
    args.sum_scale = sum_scale;
    args.signed_scale = signed_scale;
    args.len = end - start;
    args.oc_offset = oc_offset;

    args.g_oc_offset = g_oc_offset;
    args.g_oc_offset_prologue = g_oc_offset_prologue;

    args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
    args.dst_orig = dst_orig;

    if (zp_pad_comp_helper_) {
        const auto hw
                = std::div(static_cast<dim_t>(os_offset), chunk_desc.w_size_);
        args.h = hw.quot + chunk_desc.h_off_;
        args.w = hw.rem + chunk_desc.w_off_;
        args.w_size = chunk_desc.w_size_ + chunk_desc.w_off_;
        args.w_off = chunk_desc.w_off_;
        args.zp_src_pad_comp = zp.src_pad_comp;
        const auto zp_src_pad_com_d
                = zp_pad_comp_helper_->calculate_zp_src_pad_com_d(
                        chunk_desc.d_off_);
        args.zp_src_pad_com_d_offset = zp_src_pad_com_d.offset;
        args.should_apply_zp_src_pad_comp_d
                = zp_src_pad_com_d.should_apply_pad_comp_d;
    }

    jit_generator::operator()(&args);
}

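// Advance the binary post-op channel offset by 'offset', wrapping back to
// zero once a full row of jcp_.oc output channels has been consumed.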
template <typename T>
void jit_pp_ker_t::advance_binary_postops_off(const T &offset) {
    add(reg_g_oc_off_, offset);

    Xbyak::Label end;
    cmp(reg_g_oc_off_, jcp_.oc);
    jl(end, T_NEAR);
    xor_(reg_g_oc_off_, reg_g_oc_off_);

    L(end);
}
void jit_pp_ker_t::zero_binary_postops_off() {
    xor_(reg_g_oc_off_, reg_g_oc_off_);
    dst_l_offset_ = 0;
}
void jit_pp_ker_t::set_binary_postops_off(const Xbyak::Reg64 &reg) {
    mov(reg_g_oc_off_, reg);
    dst_l_offset_ = 0;
}

Xbyak::Zmm jit_pp_ker_t::reserve_zmm() {
    return Xbyak::Zmm(number_of_reserved_zmm_regs_++);
}

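// Zmm registers below number_of_reserved_zmm_regs_ hold broadcast constants;
// each unroll step then owns a group of zmm_step_ consecutive registers:
// the dst vector plus, when enabled, a bias and a previous-dst (sum) vector.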
int jit_pp_ker_t::vreg_dst_idx(const int idx) const noexcept {
    return (number_of_reserved_zmm_regs_ + idx * zmm_step_);
}

Xbyak::Zmm jit_pp_ker_t::get_vreg_dst(int idx) const {
    return Xbyak::Zmm(vreg_dst_idx(idx));
}

Xbyak::Zmm jit_pp_ker_t::get_vreg_bias(int idx) const {
    return Xbyak::Zmm(vreg_dst_idx(idx) + bias_step_factor_);
}

Xbyak::Zmm jit_pp_ker_t::get_vreg_prev_dst(int idx) const {
    return Xbyak::Zmm(vreg_dst_idx(idx) + sum_step_factor_);
}

Xbyak::Zmm jit_pp_ker_t::get_masked_vreg_dst(int idx, bool apply_mask) const {
    auto vreg_dst = this->get_vreg_dst(idx);
    if (apply_mask)
        vreg_dst = vreg_dst | kreg_rem_mask_short_;
    else
        vreg_dst = vreg_dst | kreg_rem_mask_vlen_;
    return vreg_dst;
}

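// Add the per-channel source zero-point compensation to the int32
// accumulator; with spatial padding, the pad compensation term is added
// on top through the helper for the points that require it.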
void jit_pp_ker_t::append_zp_src_comp(size_t offset, int idx, bool apply_mask) {
    const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask);
    const auto vreg_dst = get_vreg_dst(idx);
    const auto zp_src_comp_offset = offset * sizeof(int32_t);
    const auto zp_src_comp_addr = ptr[reg_zp_src_comp_ + zp_src_comp_offset];

    vpaddd(vreg_dst_masked, vreg_dst, zp_src_comp_addr);

    if (zp_pad_comp_helper_)
        zp_pad_comp_helper_->zp_src_comp_pad_operation(
                [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                    vpaddd(vreg_dst_masked, vreg_dst,
                            ptr[reg_zp_pad_comp + zp_src_comp_offset]);
                });
}

void jit_pp_ker_t::apply_postops(const Xbyak::Reg64 &reg_dst, const int idx) {
#define PARAM_OFF(x) offsetof(ker_args_t, x)
    if (jcp_.with_eltwise || jcp_.with_binary) {
        if (jcp_.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            const auto vmm_idx = vreg_dst_idx(idx);

            rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx,
                    dst_l_offset_ * types::data_type_size(jcp_.dst_data_type));
            rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

            postops_injector_->compute_vector(
                    vreg_dst_idx(idx), rhs_arg_params);
        } else
            postops_injector_->compute_vector(vreg_dst_idx(idx));
    }
#undef PARAM_OFF
}

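// Load a (possibly masked) vector from src_addr and widen it to f32:
// s8/u8 are first sign-/zero-extended to s32 and then converted, s32 is
// converted directly, and f32 is loaded as-is.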
void jit_pp_ker_t::load_as_f32(const Xbyak::Zmm &dst,
        const Xbyak::Opmask &mask_reg, const Xbyak::Address &src_addr,
        const data_type_t &src_dt) {

    const auto dst_masked = dst | mask_reg;

    switch (src_dt) {
        case data_type::s8: vpmovsxbd(dst_masked, src_addr); break;
        case data_type::u8: vpmovzxbd(dst_masked, src_addr); break;
        case data_type::s32: vcvtdq2ps(dst_masked, src_addr); break;
        case data_type::f32: vmovups(dst_masked, src_addr); break;
        default: assert(!"unimplemented");
    }

    if (utils::one_of(src_dt, data_type::s8, data_type::u8))
        vcvtdq2ps(dst_masked, dst);
}

void jit_pp_ker_t::generate() {
    using namespace Xbyak;
    using namespace utils;

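    // Pick the widest vector length (in f32 lanes, at most 16 for zmm)
    // that divides jcp_.oc evenly, so a row of output channels is always
    // an integral number of vectors.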
    size_t vlen = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    for (; vlen >= 1 && (jcp_.oc % vlen != 0); --vlen) {}

    preamble();

#ifdef _WIN32
    mov(reg_param_, rcx);
#endif

#define PARAM_OFF(x) offsetof(ker_args_t, x)
    mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]);
    mov(reg_acc_, ptr[reg_param_ + PARAM_OFF(acc)]);
    mov(reg_bias_, ptr[reg_param_ + PARAM_OFF(bias)]);
    mov(reg_scales_, ptr[reg_param_ + PARAM_OFF(scales)]);
    mov(reg_len_, ptr[reg_param_ + PARAM_OFF(len)]);
    mov(reg_oc_offset_, ptr[reg_param_ + PARAM_OFF(oc_offset)]);

    if (jcp_.zp.src_exists) {
        mov(reg_zp_src_comp_, ptr[reg_param_ + PARAM_OFF(zp_src_comp)]);
        if (zp_pad_comp_helper_)
            zp_pad_comp_helper_->init(PARAM_OFF(w), PARAM_OFF(h),
                    PARAM_OFF(w_size), PARAM_OFF(w_off),
                    PARAM_OFF(zp_src_pad_comp), PARAM_OFF(g_oc_offset_prologue),
                    PARAM_OFF(g_oc_offset), PARAM_OFF(zp_src_pad_com_d_offset),
                    PARAM_OFF(should_apply_zp_src_pad_comp_d));
    }

    if (jcp_.zp.dst_exists) {
        mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(zp_dst)]);
        vcvtdq2ps(vreg_zp_dst_common_, ptr_b[reg_tmp_]);
    }

    if (jcp_.with_dst_scale)
        vbroadcastss(vreg_dst_scale_, ptr[reg_param_ + PARAM_OFF(dst_scale)]);
    if (jcp_.with_sum)
        vbroadcastss(vreg_sum_scale_, ptr[reg_param_ + PARAM_OFF(sum_scale)]);
    if (jcp_.signed_input)
        vbroadcastss(
                vreg_signed_scale_, ptr[reg_param_ + PARAM_OFF(signed_scale)]);
    if (jcp_.scale_idx_mult == 0) vbroadcastss(vreg_scale_, dword[reg_scales_]);
#undef PARAM_OFF

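    // kreg_rem_mask_vlen_ keeps the low vlen bits set: even "full" vectors
    // must be masked to vlen lanes whenever vlen < 16.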
    mov(reg_rem_mask_vlen_, 1);
    shl(reg_rem_mask_vlen_, vlen);
    sub(reg_rem_mask_vlen_, 1);
    kmovq(kreg_rem_mask_vlen_, reg_rem_mask_vlen_);

    if (jcp_.with_eltwise) vxorps(vreg_zero_, vreg_zero_, vreg_zero_);
    if (saturation_needed_)
        init_saturate_f32(vreg_zero_, vreg_saturation_ubound_, reg_tmp_comp_,
                data_type::f32, jcp_.dst_data_type);

    if (jcp_.with_binary) set_binary_postops_off(reg_oc_offset_);

    // Load the accumulated value, convert to float, apply source zero-point
    // compensation, scaling, bias, sum and post-ops (if any);
    // then convert to the destination type and store.
    const auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc_ + offset * sizeof(acc_data_t)];

        const auto &mask_reg
                = apply_mask ? kreg_rem_mask_short_ : kreg_rem_mask_vlen_;

        if (jcp_.scale_idx_mult > 0) {
            assert(jcp_.scale_idx_mult == 1);
            const auto scale_addr = ptr[reg_scales_ + offset * sizeof(float)];
            auto vreg_scale = vreg_scale_;
            vreg_scale = vreg_scale | mask_reg;
            vmovups(vreg_scale, scale_addr);
        }

        if (jcp_.with_binary) {
            if (offset) {
                advance_binary_postops_off(vlen);
                dst_l_offset_ += offset;
            }
            kmovq(opmask_binary, mask_reg);
        }
        const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask);
        const auto vreg_dst = get_vreg_dst(idx);
        if (jcp_.zp.src_exists) {
            vmovups(vreg_dst_masked, acc_addr);
            append_zp_src_comp(offset, idx, apply_mask);
            vcvtdq2ps(vreg_dst_masked, vreg_dst);
        } else {
            vcvtdq2ps(vreg_dst_masked, acc_addr);
        }

        if (jcp_.signed_input)
            vmulps(vreg_dst_masked, vreg_dst, vreg_signed_scale_);

        vmulps(vreg_dst_masked, vreg_dst, vreg_scale_);

        if (jcp_.with_bias) {
            const auto bias_addr
                    = ptr[reg_bias_ + offset * bias_data_type_size_];
            const auto vreg_bias = get_vreg_bias(idx);
            load_as_f32(vreg_bias, mask_reg, bias_addr, jcp_.bias_data_type);
            vaddps(vreg_dst_masked, vreg_dst, vreg_bias);
        }

        const auto dst_addr = ptr[reg_dst_ + offset * dst_data_type_size_];

        if (jcp_.with_sum) {
            const auto vreg_prev_dst = get_vreg_prev_dst(idx);
            load_as_f32(vreg_prev_dst, mask_reg, dst_addr, jcp_.sum_data_type);
            vfmadd231ps(vreg_dst_masked, vreg_prev_dst, vreg_sum_scale_);
        }

        apply_postops(reg_dst_, idx);

        if (jcp_.with_dst_scale) {
            vmulps(vreg_dst_masked, vreg_dst, vreg_dst_scale_);
        }

        if (jcp_.zp.dst_exists) {
            vaddps(vreg_dst_masked, vreg_dst, vreg_zp_dst_common_);
        }

        if (saturation_needed_) {
            saturate_f32(get_vreg_dst(idx), vreg_zero_, vreg_saturation_ubound_,
                    jcp_.dst_data_type);
            vcvtps2dq(vreg_dst_masked, vreg_dst);
        }

        switch (jcp_.dst_data_type) {
            case data_type::s8: vpmovsdb(dst_addr, vreg_dst_masked); break;
            case data_type::u8: vpmovusdb(dst_addr, vreg_dst_masked); break;
            case data_type::f32:
            case data_type::s32: vmovups(dst_addr, vreg_dst_masked); break;
            default: assert(!"unimplemented");
        }
    };

    // Advance all pointers by an immediate
    const auto advance_ptrs_imm = [&](const size_t offset,
                                          const size_t binary_offset) {
        add(reg_dst_, offset * dst_data_type_size_);
        add(reg_acc_, offset * sizeof(acc_data_t));
        if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            add(reg_scales_, offset * sizeof(float));
        }
        if (jcp_.with_bias) add(reg_bias_, offset * bias_data_type_size_);
        if (jcp_.zp.src_exists) {
            add(reg_zp_src_comp_, offset * sizeof(int32_t));

            if (zp_pad_comp_helper_) {
                zp_pad_comp_helper_->zp_src_comp_pad_operation(
                        [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                            add(reg_zp_pad_comp, offset * sizeof(int32_t));
                        });
            }
        }
    };

    // Advance all pointers by a value stored in a register
    const auto advance_ptrs_reg = [&](const Reg64 offset,
                                          const Reg64 binary_offset) {
        lea(reg_dst_, ptr[reg_dst_ + offset * dst_data_type_size_]);
        lea(reg_acc_, ptr[reg_acc_ + offset * sizeof(acc_data_t)]);
        if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            lea(reg_scales_, ptr[reg_scales_ + offset * sizeof(float)]);
        }
        if (jcp_.with_bias)
            lea(reg_bias_, ptr[reg_bias_ + offset * bias_data_type_size_]);

        if (jcp_.zp.src_exists) {
            lea(reg_zp_src_comp_,
                    ptr[reg_zp_src_comp_ + offset * sizeof(int32_t)]);

            if (zp_pad_comp_helper_)
                zp_pad_comp_helper_->zp_src_comp_pad_operation(
                        [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                            lea(reg_zp_pad_comp,
                                    ptr[reg_zp_pad_comp
                                            + offset * sizeof(int32_t)]);
                        });
        }
    };

    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    const auto rewind_ptrs = [&]() {
        if (jcp_.with_bias) sub(reg_bias_, jcp_.oc * bias_data_type_size_);
        if (jcp_.with_binary) {
            zero_binary_postops_off();
            dst_l_offset_ = 0;
        }
        if (jcp_.zp.src_exists) {
            const auto offset = jcp_.oc * sizeof(int32_t);
            sub(reg_zp_src_comp_, offset);
            if (zp_pad_comp_helper_)
                zp_pad_comp_helper_->load_next_point_zp_src_comp_pad_addr();
        }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            sub(reg_scales_, jcp_.oc * sizeof(float));
        }
        add(reg_dst_, (jcp_.dst_os_stride - jcp_.oc) * dst_data_type_size_);
    };

    //     <--------------------------- OC --------------------------->
    //
    // ^   ................+..............+-------------+.............
    // |   .               : not accessed |Prologue loop|            .
    // |   .               +--------------+-------------+            .
    //     .               |                            |            .
    // O   .               |    Main loop (unrolled)    |            .
    // S   .               |                            |            .
    //     .               +--------------+-------------+            .
    // |   .               | Epilogue loop| not accessed:            .
    // v   ................+--------------+.............+.............

    Label prologue_end;
    cmp(reg_oc_offset_, 0);
    je(prologue_end, T_NEAR);

    // Prologue loop
    {
        mov(reg_tmp_, jcp_.oc);
        sub(reg_tmp_, reg_oc_offset_);
        cmp(reg_tmp_, reg_len_);
        cmovg(reg_tmp_, reg_len_);
        sub(reg_len_, reg_tmp_);

        Label prologue_loop, prologue_loop_tail, prologue_loop_end;
        cmp(reg_tmp_, vlen);
        jle(prologue_loop_tail, T_NEAR);
        L(prologue_loop);
        {
            compute(0, max_unroll_ - 1, false);
            advance_ptrs_imm(vlen, vlen);
            sub(reg_tmp_, vlen);
            cmp(reg_tmp_, vlen);
            jge(prologue_loop, T_NEAR);
        }

        L(prologue_loop_tail);
        mov(reg_rem_mask_short_, 1);
        // reg_tmp_ is rcx, so cl holds the shift count; reg_tmp_ <= vlen here
        shl(reg_rem_mask_short_, cl);
        sub(reg_rem_mask_short_, 1);
        jz(prologue_loop_end, T_NEAR);

        kmovq(kreg_rem_mask_short_, reg_rem_mask_short_);
        compute(0, max_unroll_ - 1, true);
        advance_ptrs_reg(reg_tmp_, reg_tmp_);

        L(prologue_loop_end);
        rewind_ptrs();
    }
    L(prologue_end);

    // Main loop
    Label main_loop_end;
    {
        cmp(reg_len_, jcp_.oc);
        jle(main_loop_end, T_NEAR);

        Label main_loop;
        L(main_loop);
        {
            size_t OC_loop, OC_tail;
            if (static_cast<size_t>(jcp_.oc) < max_unroll_ * vlen) {
                // Fully unroll small loops
                OC_loop = 0;
                OC_tail = jcp_.oc;
            } else {
                OC_loop = vlen * def_unroll_;
                OC_tail = jcp_.oc % OC_loop;
            }

            assert(!!OC_loop || !!OC_tail);

            const int vlen_tail = OC_tail % vlen;
            if (vlen_tail) {
                unsigned tail_mask = (1 << vlen_tail) - 1;
                mov(reg_tmp_, tail_mask);
                kmovq(kreg_rem_mask_short_, reg_tmp_);
            }

            if (OC_loop) {
                mov(reg_tmp_, rnd_dn(jcp_.oc, OC_loop));
                Label oc_loop;
                L(oc_loop);
                {
                    for (size_t offset = 0; offset < OC_loop; offset += vlen)
                        compute(offset, offset / vlen, false);
                    advance_ptrs_imm(OC_loop, vlen);
                    sub(reg_tmp_, OC_loop);
                    jnz(oc_loop);
                }
            }

            if (OC_tail) {
                for (size_t offset = 0; offset < OC_tail; offset += vlen) {
                    bool use_mask = (offset + vlen) > OC_tail;
                    compute(offset, offset / vlen, use_mask);
                }
                const size_t oc_tail_rem = OC_tail % vlen;
                const size_t binary_offset = oc_tail_rem ? oc_tail_rem : vlen;
                advance_ptrs_imm(OC_tail, binary_offset);
            }

            rewind_ptrs();
            sub(reg_len_, jcp_.oc);
            cmp(reg_len_, jcp_.oc);
            jge(main_loop, T_NEAR);
        }
    }
    L(main_loop_end);

    // Epilogue loop
    Label epilogue_end;
    {
        cmp(reg_len_, 0);
        je(epilogue_end, T_NEAR);

        Label epilogue_loop, epilogue_loop_tail;
        cmp(reg_len_, vlen);
        jle(epilogue_loop_tail, T_NEAR);
        L(epilogue_loop);
        {
            compute(0, 0, false);
            sub(reg_len_, vlen);
            advance_ptrs_imm(vlen, vlen);
            cmp(reg_len_, vlen);
            jge(epilogue_loop, T_NEAR);
        }

        L(epilogue_loop_tail);
        mov(reg_tmp_, reg_len_); // reg_tmp_ is rcx; the shift below needs cl
        mov(reg_rem_mask_short_, 1);
        shl(reg_rem_mask_short_, cl); // reg_len_ <= vlen here, so cl fits
        sub(reg_rem_mask_short_, 1);
        jz(epilogue_end, T_NEAR);
        kmovq(kreg_rem_mask_short_, reg_rem_mask_short_);
        compute(0, 0, true);
    }

    L(epilogue_end);

    if (zp_pad_comp_helper_) zp_pad_comp_helper_->fin();

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept {
    const auto is_bf16_dst_dt = dst_dt == data_type::bf16;
    return mayiuse(avx512_core) && !is_bf16_dst_dt;
}

pp_ker_t *jit_pp_ker_create(
        const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
    return mayiuse_jit_pp_kernel(pd->dst_md()->data_type)
            ? new jit_pp_ker_t(pd, jcp)
            : nullptr;
}
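
// Illustrative only (a sketch, not code from this file): a caller is
// expected to use the factory above roughly as follows.
//
//   pp_ker_t *ker = jit_pp_ker_create(pd, jcp); // nullptr if unsupported
//   if (ker) CHECK(ker->create_kernel()); // generates the JIT code
//   // later, once per work chunk:
//   (*ker)(dst, acc, bias, scales, dst_scale, sum_scale, signed_scale, g,
//           start, end, zp, post_ops_binary_rhs_arg_vec, dst_orig, ctx,
//           dst_md, chunk_desc);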

bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) {
    using namespace x64::injector;
    static constexpr bool sum_at_pos_0_only = true;
    static constexpr bool sum_requires_scale_one = false;
    return mayiuse_jit_pp_kernel(dst_d->data_type())
            && dnnl::impl::cpu::x64::injector::post_ops_ok(
                    {avx512_core, {binary, eltwise, sum}, post_ops, dst_d,
                            sum_at_pos_0_only, sum_requires_scale_one});
}

} // namespace gemm_x8s8s32x_convolution_utils
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl