1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstdlib> |
18 | #include <functional> |
19 | |
20 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
21 | #include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp" |
22 | #include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp" |
23 | #include "cpu/x64/jit_generator.hpp" |
24 | #include "cpu/x64/jit_primitive_conf.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | namespace gemm_x8s8s32x_convolution_utils { |
32 | using namespace dnnl::impl::cpu::gemm_x8s8s32x_convolution_utils; |
33 | |
// JIT (avx512_core) post-processing kernel for int8 GEMM-based convolution:
// converts int32 accumulators to the destination data type while applying
// bias, scales, sum, zero points and attribute post-ops.
struct jit_pp_ker_t : pp_ker_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            gemm_x8s8s32x_convolution_utils::jit_pp_ker_t);

    jit_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp);

    status_t create_kernel() override { return jit_generator::create_kernel(); }
    // Marshals the call parameters into a ker_args_t and runs the generated
    // code over the [start, end) range of accumulator elements.
    void operator()(void *void_dst, const acc_data_t *acc, const char *bias,
            const float *scales, float dst_scale, float sum_scale,
            float signed_scale, int g, size_t start, size_t end,
            const zero_point_call_params_t &zp,
            const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
            const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */,
            const single_gemm_conv_chunk_desc_t &) const override;

private:
    // Emits eltwise/binary post-op computation for the dst vreg selected by
    // idx; no-op if neither post-op kind is configured.
    void apply_postops(const Xbyak::Reg64 &reg_dst, const int idx);
    void generate() override;
    // Adds src zero-point compensation (and, when enabled, zero-point padding
    // compensation) to the integral accumulator vreg.
    void append_zp_src_comp(size_t offset, int idx, bool apply_mask);
    // Loads src_dt data from src_addr into dst (under mask) and converts the
    // result to f32.
    void load_as_f32(const Xbyak::Zmm &dst, const Xbyak::Opmask &mask,
            const Xbyak::Address &src_addr, const data_type_t &src_dt);

    // Vreg layout helpers: dst vregs start right after the reserved constant
    // vregs and are spaced by zmm_step_ (see vreg_dst_idx()).
    int vreg_dst_idx(const int idx) const noexcept;
    Xbyak::Zmm get_vreg_dst(int idx) const;
    Xbyak::Zmm get_vreg_bias(int idx) const;
    Xbyak::Zmm get_vreg_prev_dst(int idx) const;
    // NOTE(review): declared but no definition/use is visible in this file —
    // looks like a leftover; confirm before removing.
    Xbyak::Zmm get_vreg_zp_comp_src(int idx) const;
    Xbyak::Zmm get_masked_vreg_dst(int idx, bool apply_mask) const;
    // Reserves the next free zmm index for a kernel-lifetime constant.
    Xbyak::Zmm reserve_zmm();

    // Maintains the running output-channel offset consumed by the binary
    // post-ops injector; wraps around at jcp_.oc.
    template <typename T>
    void advance_binary_postops_off(const T &offset);
    void zero_binary_postops_off();
    void set_binary_postops_off(const Xbyak::Reg64 &reg);
    const Xbyak::Opmask &opmask_binary = k2;

    // Argument block passed by pointer to the generated kernel. Field offsets
    // are taken with offsetof, so this must remain a plain aggregate.
    struct ker_args_t {
        char *dst;
        const acc_data_t *acc;
        const char *bias;
        const float *scales;
        float dst_scale;
        float sum_scale;
        float signed_scale;
        size_t len;
        size_t oc_offset;
        const int32_t *zp_src;
        const int32_t *zp_dst;
        const int32_t *zp_src_comp;
        const int32_t *zp_src_pad_comp;
        size_t g_oc_offset_prologue;
        size_t g_oc_offset;
        const void *post_ops_binary_rhs_arg_vec;
        const void *dst_orig;
        // Spatial coordinates/extents used only by the zero-point source
        // padding compensation path.
        dim_t h;
        dim_t w;
        dim_t w_size;
        dim_t w_off;
        dim_t zp_src_pad_com_d_offset;
        bool should_apply_zp_src_pad_comp_d;
    };

    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
            postops_injector_;

    // Number of zmm registers claimed via reserve_zmm() for constants.
    size_t number_of_reserved_zmm_regs_;
    const size_t bias_data_type_size_;
    const size_t dst_data_type_size_;
    const bool saturation_needed_;

    const Xbyak::Reg64 &reg_param_ = rdi;
    const Xbyak::Reg64 &reg_tmp_ = rcx; // intentional for shifting purposes

    const Xbyak::Reg64 &reg_dst_ = rdx;
    const Xbyak::Reg64 &reg_acc_ = rax;
    const Xbyak::Reg64 &reg_bias_ = rbx;
    const Xbyak::Reg64 &reg_scales_ = rsi;
    const Xbyak::Reg64 &reg_len_ = r8;
    const Xbyak::Reg64 &reg_oc_offset_ = r9;
    // r10 is shared by the two mask builders and the zp pad temp; their live
    // ranges do not overlap in generate().
    const Xbyak::Reg64 &reg_rem_mask_short_ = r10;
    const Xbyak::Reg64 &reg_rem_mask_vlen_ = reg_rem_mask_short_;
    const Xbyak::Reg64 &reg_zp_pad_comp_temp_ = r10;
    const Xbyak::Reg64 &reg_zp_pad_comp_ = r11;
    const Xbyak::Reg8 &reg_should_apply_src_pad_comp_ = r13b;

    const Xbyak::Reg64 &reg_tmp_comp_
            = r12; // used to broadcast scalar values to vreg
    const Xbyak::Reg64 &reg_g_oc_off_ = reg_tmp_comp_;
    const Xbyak::Reg64 &reg_zp_src_comp_ = r14;

    // Constant vregs; each is only reserved when the corresponding feature is
    // enabled (otherwise aliased to Zmm(0) and unused).
    const Xbyak::Zmm vreg_zero_;
    const Xbyak::Zmm vreg_scale_;
    const Xbyak::Zmm vreg_dst_scale_;
    const Xbyak::Zmm vreg_sum_scale_;
    const Xbyak::Zmm vreg_signed_scale_;
    const Xbyak::Zmm vreg_saturation_ubound_;
    const Xbyak::Zmm vreg_zp_dst_common_;

    // k3 holds the short (tail) mask, k4 the full-vlen mask.
    const Xbyak::Opmask &kreg_rem_mask_short_ = k3;
    const Xbyak::Opmask &kreg_rem_mask_vlen_ = k4;

    static constexpr size_t def_unroll_ = 4u;
    // Number of zmm registers consumed per unrolled dst element (grows with
    // bias/sum support).
    size_t zmm_step_;
    const size_t bias_step_factor_;
    const size_t sum_step_factor_;
    const size_t max_unroll_;
    // Element offset inside the current dst row, tracked for binary post-ops.
    int dst_l_offset_ = 0;

    std::unique_ptr<jit_gemm_x8s8s32x_zp_pad_comp_helper> zp_pad_comp_helper_;
};
144 | |
// Constructor: reserves constant vregs, derives the per-element vreg stride
// and builds the post-ops injector when eltwise/binary post-ops are present.
// NOTE: the member-init list relies on declaration order — reserve_zmm() and
// zmm_step_++ have side effects that later initializers depend on; do not
// reorder the members or these initializers.
jit_pp_ker_t::jit_pp_ker_t(
        const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
    : pp_ker_t(pd, jcp)
    , jit_generator(jit_name())
    , number_of_reserved_zmm_regs_(0)
    , bias_data_type_size_(jcp.bias_data_type != data_type::undef
                      ? types::data_type_size(jcp.bias_data_type)
                      : 0u)
    , dst_data_type_size_(types::data_type_size(jcp.dst_data_type))
    , saturation_needed_(utils::one_of(
              jcp_.dst_data_type, data_type::u8, data_type::s8, data_type::s32))
    // Each constant vreg is reserved only when its feature is enabled,
    // leaving the remaining zmm registers free for unrolling.
    , vreg_zero_((jcp_.with_eltwise || saturation_needed_) ? reserve_zmm()
                                                           : Xbyak::Zmm(0))
    , vreg_scale_(reserve_zmm())
    , vreg_dst_scale_(reserve_zmm())
    , vreg_sum_scale_(jcp_.with_sum ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_signed_scale_(jcp_.signed_input ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_saturation_ubound_(
              saturation_needed_ ? reserve_zmm() : Xbyak::Zmm(0))
    , vreg_zp_dst_common_(jcp_.zp.dst_exists ? reserve_zmm() : Xbyak::Zmm(0))
    // zmm_step_ grows by one for bias and one for sum; the step factors
    // record each feature's slot within a dst-element group.
    , zmm_step_(1u)
    , bias_step_factor_(jcp_.with_bias ? zmm_step_++ : 0u)
    , sum_step_factor_(jcp_.with_sum ? zmm_step_++ : 0)
    // Max unroll = free vregs divided by vregs needed per dst element.
    , max_unroll_((cpu_isa_traits<avx512_core>::n_vregs
                          - number_of_reserved_zmm_regs_)
              / zmm_step_)
    // Padding compensation is only needed when the conv has padding and src
    // zero points exist.
    , zp_pad_comp_helper_(jit_gemm_convolution_utils::padding_exists(jcp)
                              && jcp.zp.src_exists
                      ? utils::make_unique<
                              jit_gemm_x8s8s32x_zp_pad_comp_helper>(this, jcp_,
                              reg_zp_pad_comp_, reg_zp_pad_comp_temp_,
                              reg_should_apply_src_pad_comp_,
                              pd->src_md()->ndims)
                      : nullptr)

{

    if (jcp.with_eltwise || jcp.with_binary) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        // zmm31 is handed to the injector as scratch.
        static constexpr size_t helper_vmm_idx = 31;
        // tail_size = 1 just indicates that tailing is to be performed
        // actual tail value is held in opmask passed to injector
        static constexpr size_t tail_size = 1;
        static constexpr bool use_exact_tail_scalar_bcast = false;

#define PARAM_OFF(x) offsetof(ker_args_t, x)
        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r13, r14, r15, preserve_gpr, preserve_vmm,
                PARAM_OFF(post_ops_binary_rhs_arg_vec), PARAM_OFF(dst_orig),
                memory_desc_wrapper(pd->dst_md()), tail_size, opmask_binary,
                use_exact_tail_scalar_bcast};
#undef PARAM_OFF

        const static_params_t static_params {reg_param_, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp_.post_ops, static_params);
    }
}
207 | |
208 | void jit_pp_ker_t::operator()(void *void_dst, const acc_data_t *acc, |
209 | const char *bias, const float *scales, float dst_scale, float sum_scale, |
210 | float signed_scale, int g, size_t start, size_t end, |
211 | const zero_point_call_params_t &zp, |
212 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
213 | const exec_ctx_t & /* ctx */, const memory_desc_t & /* dst_md */, |
214 | const single_gemm_conv_chunk_desc_t &chunk_desc) const { |
215 | |
216 | if (end <= start) return; |
217 | |
218 | char *dst = (char *)void_dst; |
219 | |
220 | ker_args_t args; |
221 | const auto dv = std::div(start, jcp_.oc); |
222 | const size_t oc_offset = dv.rem; |
223 | const size_t os_offset = dv.quot; |
224 | args.acc = acc + start; |
225 | args.dst = dst |
226 | + (os_offset * jcp_.dst_os_stride + oc_offset) |
227 | * dst_data_type_size_; |
228 | |
229 | const ptrdiff_t g_oc_offset = g * jcp_.oc; |
230 | const ptrdiff_t g_oc_offset_prologue = g_oc_offset + oc_offset; |
231 | args.bias = bias + g_oc_offset_prologue * bias_data_type_size_; |
232 | args.zp_src = zp.src + (jcp_.zp.src_is_common ? 0 : g_oc_offset_prologue); |
233 | args.zp_src_comp |
234 | = zp.src_comp ? zp.src_comp + g_oc_offset_prologue : nullptr; |
235 | args.zp_dst = zp.dst; |
236 | args.scales = scales + jcp_.scale_idx_mult * g_oc_offset_prologue; |
237 | args.dst_scale = dst_scale; |
238 | args.sum_scale = sum_scale; |
239 | args.signed_scale = signed_scale; |
240 | args.len = end - start; |
241 | args.oc_offset = oc_offset; |
242 | |
243 | args.g_oc_offset = g_oc_offset; |
244 | args.g_oc_offset_prologue = g_oc_offset_prologue; |
245 | |
246 | args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
247 | args.dst_orig = dst_orig; |
248 | |
249 | if (zp_pad_comp_helper_) { |
250 | const auto hw |
251 | = std::div(static_cast<dim_t>(os_offset), chunk_desc.w_size_); |
252 | args.h = hw.quot + chunk_desc.h_off_; |
253 | args.w = hw.rem + chunk_desc.w_off_; |
254 | args.w_size = chunk_desc.w_size_ + chunk_desc.w_off_; |
255 | args.w_off = chunk_desc.w_off_; |
256 | args.zp_src_pad_comp = zp.src_pad_comp; |
257 | const auto zp_src_pad_com_d |
258 | = zp_pad_comp_helper_->calculate_zp_src_pad_com_d( |
259 | chunk_desc.d_off_); |
260 | args.zp_src_pad_com_d_offset = zp_src_pad_com_d.offset; |
261 | args.should_apply_zp_src_pad_comp_d |
262 | = zp_src_pad_com_d.should_apply_pad_comp_d; |
263 | } |
264 | |
265 | jit_generator::operator()(&args); |
266 | } |
267 | |
// Emits code advancing the binary post-ops oc offset register by `offset`,
// wrapping back to zero once it reaches jcp_.oc (the offset is per-channel
// and must stay inside [0, oc)).
template <typename T>
void jit_pp_ker_t::advance_binary_postops_off(const T &offset) {
    add(reg_g_oc_off_, offset);

    Xbyak::Label end;
    cmp(reg_g_oc_off_, jcp_.oc);
    jl(end, T_NEAR);
    // Reached the end of the channel range — wrap to channel 0.
    xor_(reg_g_oc_off_, reg_g_oc_off_);

    L(end);
}
// Emits code resetting the binary post-ops oc offset register to zero and
// clears the host-side dst element offset tracker.
void jit_pp_ker_t::zero_binary_postops_off() {
    xor_(reg_g_oc_off_, reg_g_oc_off_);
    dst_l_offset_ = 0;
}
// Emits code loading the binary post-ops oc offset register from `reg` and
// clears the host-side dst element offset tracker.
void jit_pp_ker_t::set_binary_postops_off(const Xbyak::Reg64 &reg) {
    mov(reg_g_oc_off_, reg);
    dst_l_offset_ = 0;
}
287 | |
288 | Xbyak::Zmm jit_pp_ker_t::reserve_zmm() { |
289 | return Xbyak::Zmm(number_of_reserved_zmm_regs_++); |
290 | } |
291 | |
292 | int jit_pp_ker_t::vreg_dst_idx(const int idx) const noexcept { |
293 | return (number_of_reserved_zmm_regs_ + idx * zmm_step_); |
294 | } |
295 | |
296 | Xbyak::Zmm jit_pp_ker_t::get_vreg_dst(int idx) const { |
297 | return Xbyak::Zmm(vreg_dst_idx(idx)); |
298 | } |
299 | |
300 | Xbyak::Zmm jit_pp_ker_t::get_vreg_bias(int idx) const { |
301 | return Xbyak::Zmm(vreg_dst_idx(idx) + bias_step_factor_); |
302 | } |
303 | |
304 | Xbyak::Zmm jit_pp_ker_t::get_vreg_prev_dst(int idx) const { |
305 | return Xbyak::Zmm(vreg_dst_idx(idx) + sum_step_factor_); |
306 | } |
307 | |
308 | Xbyak::Zmm jit_pp_ker_t::get_masked_vreg_dst(int idx, bool apply_mask) const { |
309 | auto vreg_dst = this->get_vreg_dst(idx); |
310 | if (apply_mask) |
311 | vreg_dst = vreg_dst | kreg_rem_mask_short_; |
312 | else |
313 | vreg_dst = vreg_dst | kreg_rem_mask_vlen_; |
314 | return vreg_dst; |
315 | } |
316 | |
// Emits integer additions of the src zero-point compensation (and the
// zero-point padding compensation when the helper is present) into the dst
// vreg of unroll slot `idx`. Must run while the accumulator is still int32.
void jit_pp_ker_t::append_zp_src_comp(size_t offset, int idx, bool apply_mask) {
    const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask);
    const auto vreg_dst = get_vreg_dst(idx);
    const auto zp_src_comp_offset = offset * sizeof(int32_t);
    const auto zp_src_comp_addr = ptr[reg_zp_src_comp_ + zp_src_comp_offset];

    vpaddd(vreg_dst_masked, vreg_dst, zp_src_comp_addr);

    if (zp_pad_comp_helper_)
        // The helper supplies the register holding the pad-compensation base
        // pointer for the current spatial point.
        zp_pad_comp_helper_->zp_src_comp_pad_operation(
                [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                    vpaddd(vreg_dst_masked, vreg_dst,
                            ptr[reg_zp_pad_comp + zp_src_comp_offset]);
                });
}
332 | |
333 | void jit_pp_ker_t::apply_postops(const Xbyak::Reg64 ®_dst, const int idx) { |
334 | #define PARAM_OFF(x) offsetof(ker_args_t, x) |
335 | if (jcp_.with_eltwise || jcp_.with_binary) { |
336 | if (jcp_.with_binary) { |
337 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
338 | const auto vmm_idx = vreg_dst_idx(idx); |
339 | |
340 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); |
341 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, |
342 | dst_l_offset_ * types::data_type_size(jcp_.dst_data_type)); |
343 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
344 | |
345 | postops_injector_->compute_vector( |
346 | vreg_dst_idx(idx), rhs_arg_params); |
347 | } else |
348 | postops_injector_->compute_vector(vreg_dst_idx(idx)); |
349 | } |
350 | #undef PARAM_OFF |
351 | } |
352 | |
// Emits a masked load of `src_dt` data from `src_addr` into `dst`, converted
// to f32. s8/u8 are sign/zero-extended to int32 first and then converted;
// s32 and f32 load-convert in a single instruction.
void jit_pp_ker_t::load_as_f32(const Xbyak::Zmm &dst,
        const Xbyak::Opmask &mask_reg, const Xbyak::Address &src_addr,
        const data_type_t &src_dt) {

    const auto dst_masked = dst | mask_reg;

    switch (src_dt) {
        case data_type::s8: vpmovsxbd(dst_masked, src_addr); break;
        case data_type::u8: vpmovzxbd(dst_masked, src_addr); break;
        case data_type::s32: vcvtdq2ps(dst_masked, src_addr); break;
        case data_type::f32: vmovups(dst_masked, src_addr); break;
        default: assert(!"unimplemented" );
    }

    // The widened 8-bit values are still int32; finish with an in-register
    // int -> f32 conversion.
    if (utils::one_of(src_dt, data_type::s8, data_type::u8))
        vcvtdq2ps(dst_masked, dst);
}
370 | |
// Emits the whole post-processing kernel: per-element pipeline is
// acc(int32) -> +zp_src comp -> f32 -> *signed_scale -> *scale -> +bias ->
// +sum*sum_scale -> post-ops -> *dst_scale -> +zp_dst -> saturate ->
// store as dst type. The element range is walked in three phases
// (prologue / unrolled main / epilogue) so per-oc arrays stay aligned.
void jit_pp_ker_t::generate() {
    using namespace Xbyak;
    using namespace utils;

    // Pick the largest simd width (in f32 lanes) that evenly divides oc, so
    // channel boundaries always coincide with vector boundaries.
    size_t vlen = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    for (; vlen >= 1 && (jcp_.oc % vlen != 0); --vlen) {}

    preamble();

#ifdef _WIN32
    // Windows x64 ABI passes the first argument in rcx, not rdi.
    mov(reg_param_, rcx);
#endif

#define PARAM_OFF(x) offsetof(ker_args_t, x)
    mov(reg_dst_, ptr[reg_param_ + PARAM_OFF(dst)]);
    mov(reg_acc_, ptr[reg_param_ + PARAM_OFF(acc)]);
    mov(reg_bias_, ptr[reg_param_ + PARAM_OFF(bias)]);
    mov(reg_scales_, ptr[reg_param_ + PARAM_OFF(scales)]);
    mov(reg_len_, ptr[reg_param_ + PARAM_OFF(len)]);
    mov(reg_oc_offset_, ptr[reg_param_ + PARAM_OFF(oc_offset)]);

    if (jcp_.zp.src_exists) {
        mov(reg_zp_src_comp_, ptr[reg_param_ + PARAM_OFF(zp_src_comp)]);
        if (zp_pad_comp_helper_)
            zp_pad_comp_helper_->init(PARAM_OFF(w), PARAM_OFF(h),
                    PARAM_OFF(w_size), PARAM_OFF(w_off),
                    PARAM_OFF(zp_src_pad_comp), PARAM_OFF(g_oc_offset_prologue),
                    PARAM_OFF(g_oc_offset), PARAM_OFF(zp_src_pad_com_d_offset),
                    PARAM_OFF(should_apply_zp_src_pad_comp_d));
    }

    if (jcp_.zp.dst_exists) {
        // Broadcast the common dst zero point, converted to f32.
        mov(reg_tmp_, ptr[reg_param_ + PARAM_OFF(zp_dst)]);
        vcvtdq2ps(vreg_zp_dst_common_, ptr_b[reg_tmp_]);
    }

    if (jcp_.with_dst_scale)
        vbroadcastss(vreg_dst_scale_, ptr[reg_param_ + PARAM_OFF(dst_scale)]);
    if (jcp_.with_sum)
        vbroadcastss(vreg_sum_scale_, ptr[reg_param_ + PARAM_OFF(sum_scale)]);
    if (jcp_.signed_input)
        vbroadcastss(
                vreg_signed_scale_, ptr[reg_param_ + PARAM_OFF(signed_scale)]);
    // With a single common scale, load it once up front; otherwise the
    // compute lambda reloads per-channel scales each iteration.
    if (jcp_.scale_idx_mult == 0) vbroadcastss(vreg_scale_, dword[reg_scales_]);
#undef PARAM_OFF

    // Build the full-width mask: low `vlen` bits set.
    mov(reg_rem_mask_vlen_, 1);
    shl(reg_rem_mask_vlen_, vlen);
    sub(reg_rem_mask_vlen_, 1);
    kmovq(kreg_rem_mask_vlen_, reg_rem_mask_vlen_);

    if (jcp_.with_eltwise) vxorps(vreg_zero_, vreg_zero_, vreg_zero_);
    if (saturation_needed_)
        init_saturate_f32(vreg_zero_, vreg_saturation_ubound_, reg_tmp_comp_,
                data_type::f32, jcp_.dst_data_type);

    if (jcp_.with_binary) set_binary_postops_off(reg_oc_offset_);

    // Load accumulated value, convert to float, apply sum (if any),
    // bias (if any), scaling, and relu (if any);
    // then convert to destination type and store
    const auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc_ + offset * sizeof(acc_data_t)];

        const auto &mask_reg
                = apply_mask ? kreg_rem_mask_short_ : kreg_rem_mask_vlen_;

        if (jcp_.scale_idx_mult > 0) {
            assert(jcp_.scale_idx_mult == 1);
            const auto scale_addr = ptr[reg_scales_ + offset * sizeof(float)];
            auto vreg_scale = vreg_scale_;
            vreg_scale = vreg_scale | mask_reg;
            vmovups(vreg_scale, scale_addr);
        }

        if (jcp_.with_binary) {
            if (offset) {
                advance_binary_postops_off(vlen);
                dst_l_offset_ += offset;
            }
            kmovq(opmask_binary, mask_reg);
        }
        const auto vreg_dst_masked = get_masked_vreg_dst(idx, apply_mask);
        const auto vreg_dst = get_vreg_dst(idx);
        if (jcp_.zp.src_exists) {
            // Keep the value integral until zp compensation is added,
            // then convert to f32.
            vmovups(vreg_dst_masked, acc_addr);
            append_zp_src_comp(offset, idx, apply_mask);
            vcvtdq2ps(vreg_dst_masked, vreg_dst);
        } else {
            vcvtdq2ps(vreg_dst_masked, acc_addr);
        }

        if (jcp_.signed_input)
            vmulps(vreg_dst_masked, vreg_dst, vreg_signed_scale_);

        vmulps(vreg_dst_masked, vreg_dst, vreg_scale_);

        if (jcp_.with_bias) {
            const auto bias_addr
                    = ptr[reg_bias_ + offset * bias_data_type_size_];
            const auto vreg_bias = get_vreg_bias(idx);
            load_as_f32(vreg_bias, mask_reg, bias_addr, jcp_.bias_data_type);
            vaddps(vreg_dst_masked, vreg_dst, vreg_bias);
        }

        const auto dst_addr = ptr[reg_dst_ + offset * dst_data_type_size_];

        if (jcp_.with_sum) {
            // dst += prev_dst * sum_scale
            const auto vreg_prev_dst = get_vreg_prev_dst(idx);
            load_as_f32(vreg_prev_dst, mask_reg, dst_addr, jcp_.sum_data_type);
            vfmadd231ps(vreg_dst_masked, vreg_prev_dst, vreg_sum_scale_);
        }

        apply_postops(reg_dst_, idx);

        if (jcp_.with_dst_scale) {
            vmulps(vreg_dst_masked, vreg_dst, vreg_dst_scale_);
        }

        if (jcp_.zp.dst_exists) {
            vaddps(vreg_dst_masked, vreg_dst, vreg_zp_dst_common_);
        }

        if (saturation_needed_) {
            // Clamp into dst-type range before the f32 -> int conversion.
            saturate_f32(get_vreg_dst(idx), vreg_zero_, vreg_saturation_ubound_,
                    jcp_.dst_data_type);
            vcvtps2dq(vreg_dst_masked, vreg_dst);
        }

        // Down-convert (if needed) and store with the active mask.
        switch (jcp_.dst_data_type) {
            case data_type::s8: vpmovsdb(dst_addr, vreg_dst_masked); break;
            case data_type::u8: vpmovusdb(dst_addr, vreg_dst_masked); break;
            case data_type::f32:
            case data_type::s32: vmovups(dst_addr, vreg_dst_masked); break;
            default: assert(!"unimplemented" );
        }
    };

    // Advance all pointers by an immediate
    const auto advance_ptrs_imm = [&](const size_t offset,
                                          const size_t binary_offset) {
        add(reg_dst_, offset * dst_data_type_size_);
        add(reg_acc_, offset * sizeof(acc_data_t));
        if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            add(reg_scales_, offset * sizeof(float));
        }
        if (jcp_.with_bias) add(reg_bias_, offset * bias_data_type_size_);
        if (jcp_.zp.src_exists) {
            add(reg_zp_src_comp_, offset * sizeof(int32_t));

            if (zp_pad_comp_helper_) {
                zp_pad_comp_helper_->zp_src_comp_pad_operation(
                        [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                            add(reg_zp_pad_comp, offset * sizeof(int32_t));
                        });
            }
        }
    };

    // Advance all pointers by a value stored in a register
    const auto advance_ptrs_reg = [&](const Reg64 offset,
                                          const Reg64 binary_offset) {
        lea(reg_dst_, ptr[reg_dst_ + offset * dst_data_type_size_]);
        lea(reg_acc_, ptr[reg_acc_ + offset * sizeof(acc_data_t)]);
        if (jcp_.with_binary) { advance_binary_postops_off(binary_offset); }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            lea(reg_scales_, ptr[reg_scales_ + offset * sizeof(float)]);
        }
        if (jcp_.with_bias)
            lea(reg_bias_, ptr[reg_bias_ + offset * bias_data_type_size_]);

        if (jcp_.zp.src_exists) {
            lea(reg_zp_src_comp_,
                    ptr[reg_zp_src_comp_ + offset * sizeof(int32_t)]);

            if (zp_pad_comp_helper_)
                zp_pad_comp_helper_->zp_src_comp_pad_operation(
                        [&](const Xbyak::Reg64 &reg_zp_pad_comp) {
                            lea(reg_zp_pad_comp,
                                    ptr[reg_zp_pad_comp
                                            + offset * sizeof(int32_t)]);
                        });
        }
    };

    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    const auto rewind_ptrs = [&]() {
        if (jcp_.with_bias) sub(reg_bias_, jcp_.oc * bias_data_type_size_);
        if (jcp_.with_binary) {
            zero_binary_postops_off();
            dst_l_offset_ = 0;
        }
        if (jcp_.zp.src_exists) {
            const auto offset = jcp_.oc * sizeof(int32_t);
            sub(reg_zp_src_comp_, offset);
            if (zp_pad_comp_helper_)
                zp_pad_comp_helper_->load_next_point_zp_src_comp_pad_addr();
        }
        if (jcp_.scale_idx_mult) {
            assert(jcp_.scale_idx_mult == 1);
            sub(reg_scales_, jcp_.oc * sizeof(float));
        }
        // Hop to the start of the next output row (dst may be strided).
        add(reg_dst_, (jcp_.dst_os_stride - jcp_.oc) * dst_data_type_size_);
    };

    //      <--------- OC --------------->
    //
    // ^    ................+..............+-------------+.......................
    // |    .               : not accessed |Prologue loop|                      .
    // |    .               +--------------+-------------+                      .
    //      .               |                             |                     .
    // O    .               |  Main loop (unrolled)       |                     .
    // S    .               |                             |                     .
    //      .               +--------------+-------------+                     .
    // |    .               | Epilogue loop|not accessed :                      .
    // v    ................+--------------+.............+.......................

    Label prologue_end;
    cmp(reg_oc_offset_, 0);
    je(prologue_end, T_NEAR);

    // Prologue loop: finish the partially-started oc row so the main loop
    // always begins at an oc boundary.
    {
        mov(reg_tmp_, jcp_.oc);
        sub(reg_tmp_, reg_oc_offset_);
        cmp(reg_tmp_, reg_len_);
        cmovg(reg_tmp_, reg_len_);
        sub(reg_len_, reg_tmp_);

        Label prologue_loop, prologue_loop_tail, prologue_loop_end;
        cmp(reg_tmp_, vlen);
        jle(prologue_loop_tail, T_NEAR);
        L(prologue_loop);
        {
            compute(0, max_unroll_ - 1, false);
            advance_ptrs_imm(vlen, vlen);
            sub(reg_tmp_, vlen);
            cmp(reg_tmp_, vlen);
            jge(prologue_loop, T_NEAR);
        }

        L(prologue_loop_tail);
        mov(reg_rem_mask_short_, 1);
        // cl == reg_tmp_ because reg_tmp_ <= vlen here
        shl(reg_rem_mask_short_, cl);
        sub(reg_rem_mask_short_, 1);
        jz(prologue_loop_end, T_NEAR);

        kmovq(kreg_rem_mask_short_, reg_rem_mask_short_);
        compute(0, max_unroll_ - 1, true);
        advance_ptrs_reg(reg_tmp_, reg_tmp_);

        L(prologue_loop_end);
        rewind_ptrs();
    }
    L(prologue_end);

    // Main loop: whole oc rows, unrolled where register budget allows.
    Label main_loop_end;
    {
        cmp(reg_len_, jcp_.oc);
        jle(main_loop_end, T_NEAR);

        Label main_loop;
        L(main_loop);
        {
            size_t OC_loop, OC_tail;
            if (static_cast<size_t>(jcp_.oc) < max_unroll_ * vlen) {
                // Fully unroll small loops
                OC_loop = 0;
                OC_tail = jcp_.oc;
            } else {
                OC_loop = vlen * def_unroll_;
                OC_tail = jcp_.oc % OC_loop;
            }

            assert(!!OC_loop || !!OC_tail);

            const int vlen_tail = OC_tail % vlen;
            if (vlen_tail) {
                unsigned tail_mask = (1 << vlen_tail) - 1;
                mov(reg_tmp_, tail_mask);
                kmovq(kreg_rem_mask_short_, reg_tmp_);
            }

            if (OC_loop) {
                mov(reg_tmp_, rnd_dn(jcp_.oc, OC_loop));
                Label oc_loop;
                L(oc_loop);
                {
                    for (size_t offset = 0; offset < OC_loop; offset += vlen)
                        compute(offset, offset / vlen, false);
                    advance_ptrs_imm(OC_loop, vlen);
                    sub(reg_tmp_, OC_loop);
                    jnz(oc_loop);
                }
            }

            if (OC_tail) {
                for (size_t offset = 0; offset < OC_tail; offset += vlen) {
                    bool use_mask = (offset + vlen) > OC_tail;
                    compute(offset, offset / vlen, use_mask);
                }
                const size_t oc_tail_rem = OC_tail % vlen;
                const size_t binary_offset = oc_tail_rem ? oc_tail_rem : vlen;
                advance_ptrs_imm(OC_tail, binary_offset);
            }

            rewind_ptrs();
            sub(reg_len_, jcp_.oc);
            cmp(reg_len_, jcp_.oc);
            jge(main_loop, T_NEAR);
        }
    }
    L(main_loop_end);

    // Epilogue loop: leftover elements of the final, partial oc row.
    Label epilogue_end;
    {
        cmp(reg_len_, 0);
        je(epilogue_end, T_NEAR);

        Label epilogue_loop, epilogue_loop_tail;
        cmp(reg_len_, vlen);
        jle(epilogue_loop_tail, T_NEAR);
        L(epilogue_loop);
        {
            compute(0, 0, false);
            sub(reg_len_, vlen);
            advance_ptrs_imm(vlen, vlen);
            cmp(reg_len_, vlen);
            jge(epilogue_loop, T_NEAR);
        }

        L(epilogue_loop_tail);
        mov(reg_tmp_,
                reg_len_); // reg_tmp_ is rcx, and we need cl for the shift
        mov(reg_rem_mask_short_, 1);
        shl(reg_rem_mask_short_, cl); // reg_tmp_ == rcx and reg_tail < vlen
        sub(reg_rem_mask_short_, 1);
        jz(epilogue_end, T_NEAR);
        kmovq(kreg_rem_mask_short_, reg_rem_mask_short_);
        compute(0, 0, true);
    }

    L(epilogue_end);

    if (zp_pad_comp_helper_) zp_pad_comp_helper_->fin();

    postamble();

    // Eltwise lookup tables must be emitted after the code body.
    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}
728 | |
729 | bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept { |
730 | const auto is_bf16_dst_dt = dst_dt == data_type::bf16; |
731 | return mayiuse(avx512_core) && !is_bf16_dst_dt; |
732 | } |
733 | |
734 | pp_ker_t *jit_pp_ker_create( |
735 | const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { |
736 | return mayiuse_jit_pp_kernel(pd->dst_md()->data_type) |
737 | ? new jit_pp_ker_t(pd, jcp) |
738 | : nullptr; |
739 | } |
740 | |
741 | bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) { |
742 | using namespace x64::injector; |
743 | static constexpr bool sum_at_pos_0_only = true; |
744 | static constexpr bool sum_requires_scale_one = false; |
745 | return mayiuse_jit_pp_kernel(dst_d->data_type()) |
746 | && dnnl::impl::cpu::x64::injector::post_ops_ok( |
747 | {avx512_core, {binary, eltwise, sum}, post_ops, dst_d, |
748 | sum_at_pos_0_only, sum_requires_scale_one}); |
749 | } |
750 | |
751 | } // namespace gemm_x8s8s32x_convolution_utils |
752 | } // namespace x64 |
753 | } // namespace cpu |
754 | } // namespace impl |
755 | } // namespace dnnl |
756 | |