/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_DECONV_ZP_PAD_STR_KERNEL_HPP
18#define CPU_X64_JIT_UNI_DECONV_ZP_PAD_STR_KERNEL_HPP
19
20#include "cpu/x64/jit_generator.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26
27struct jit_conv_conf_t;
28
29namespace zp {
30
/*
 * Argument pack passed by pointer to
 * jit_uni_deconv_zp_pad_str_kernel_base_t::operator().
 *
 * NOTE(review): JIT kernels usually read such a structure by byte offset, so
 * the member order and types are kept exactly as they were - confirm against
 * the kernel implementation before reordering or resizing any field.
 */
struct jit_uni_deconv_zp_pad_str_call_params_t {
    // Signed 8-bit weights (read-only).
    const int8_t *wei;
    // 32-bit source zero point data (read-only).
    const int32_t *src_zero_point;
    // 32-bit destination buffer the kernel writes to.
    int32_t *dst_scratchpad;
    // Flag for the last output-channel block; presumably selects tail
    // handling inside the kernel - confirm in the implementation.
    bool last_oc_block;
};
37
38/*
39 * Compute zero point source compensation applied during filter application on
40 * the padding as well as stride holes.
41 *
42 * zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
43 *
44 * output_format - dhwc
45 */
46class jit_uni_deconv_zp_pad_str_kernel_base_t : public jit_generator {
47public:
48 jit_uni_deconv_zp_pad_str_kernel_base_t(const jit_conv_conf_t &jcp);
49
50 void operator()(const jit_uni_deconv_zp_pad_str_call_params_t *params) {
51 jit_generator::operator()(params);
52 }
53
54 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_deconv_zp_pad_str_kernel_base_t);
55
56private:
57 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_deconv_zp_pad_str_kernel_base_t);
58
59 void generate() override;
60 void load_addresses();
61 void compute();
62 virtual void compute_step(const dim_t icb_offset) = 0;
63 virtual void apply_zero_point() = 0;
64 virtual void store_result() = 0;
65 virtual void init() = 0;
66
67protected:
68 size_t number_reserved_vmms_ = 0;
69 size_t reserve_vmm();
70
71 const jit_conv_conf_t &jcp_;
72 const Xbyak::Reg64 &reg_src_zp_ = r8;
73 const Xbyak::Reg64 &reg_wei_ = r9;
74 const Xbyak::Reg64 &reg_dst_ = r10;
75 const Xbyak::Reg64 &reg_tmp_ = r11;
76 const Xbyak::Reg8 &reg_last_oc_block_ = r12b;
77 const size_t tail_size_;
78};
79
80template <cpu_isa_t isa, typename Vmm>
81class jit_uni_deconv_zp_pad_str_kernel_t
82 : public jit_uni_deconv_zp_pad_str_kernel_base_t {
83public:
84 jit_uni_deconv_zp_pad_str_kernel_t(const jit_conv_conf_t &jcp);
85
86private:
87 void init() override;
88 void compute_step(const dim_t icb_offset) override;
89 void apply_zero_point() override;
90 void store_result() override;
91
92 Vmm get_next_vmm();
93
94 const Vmm result_acc_;
95 const Vmm vmm_tmp_;
96 const Vmm vmm_one_bytes_;
97 const Vmm vmm_one_words_;
98
99 const Xbyak::Opmask &ktail_mask_ = k2;
100 dim_t current_vmm_;
101};
102
103bool should_calculate_deconv_zp_src_pad_str_comp(
104 const jit_conv_conf_t &jcp) noexcept;
105
106template <cpu_isa_t isa>
107jit_uni_deconv_zp_pad_str_kernel_base_t *create_deconv_zp_pad_str_comp_ker(
108 const jit_conv_conf_t &jcp);
109
110void compute_deconv_zp_pad_str_comp_ker(const jit_conv_conf_t &jcp_,
111 const bool with_groups, const memory_desc_wrapper &wei_d,
112 const int8_t *wei, const int32_t *src_zp, int32_t *dst,
113 jit_uni_deconv_zp_pad_str_kernel_base_t *ker);
114
115} // namespace zp
116} // namespace x64
117} // namespace cpu
118} // namespace impl
119} // namespace dnnl
120
121#endif
122