1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_X8S8S32X_CONV_KERNEL_HPP
18#define CPU_X64_JIT_UNI_X8S8S32X_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32template <cpu_isa_t isa, typename Vmm>
33struct _jit_uni_x8s8s32x_fwd_kernel : public jit_generator {
34 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_conv_fwd_ker_t_)
35
36 _jit_uni_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp,
37 const primitive_attr_t &attr, const memory_desc_t &dst_md);
38
39 jit_conv_conf_t jcp;
40 const primitive_attr_t &attr_;
41
42private:
43 constexpr static int isa_simd_width_
44 = cpu_isa_traits<isa>::vlen / sizeof(float);
45 std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
46 postops_injector_;
47 enum {
48 typesize = sizeof(float),
49 ic_sub_step = 4,
50 ker_zp_reg_base_idx = 9,
51 ker_reg_base_idx = 12,
52 ker_dw_reg_base_idx = 14,
53 ker_max_reg = 15,
54 };
55 enum ic_block_t {
56 no_last_block,
57 last_ic_block,
58 last_sp_block,
59 };
60
61 /* data registers */
62 const Xbyak::Reg64 reg_ptr_scales = rax;
63 const Xbyak::Reg64 reg_ptr_saturation_ubound = rax;
64 const Xbyak::Reg64 reg_inp = r8;
65 const Xbyak::Reg64 reg_ker = r9;
66 const Xbyak::Reg64 reg_out = r10;
67 const Xbyak::Reg64 aux_reg_inp = r11;
68 const Xbyak::Reg64 reg_ptr_sum_scale = r11;
69 const Xbyak::Reg64 reg_ptr_sum_zp = rdx;
70 const Xbyak::Reg64 aux_reg_ker = r12;
71 const Xbyak::Reg64 aux_reg_inp_d = r13;
72 const Xbyak::Reg64 reg_compensation = r14;
73 const Xbyak::Reg64 aux_reg_ker_d = r15;
74 const Xbyak::Reg64 reg_ker_long_offt = r13;
75
76 /* counter regs */
77 const Xbyak::Reg64 reg_oi = rbx;
78 const Xbyak::Reg64 reg_bias = rdx;
79 const Xbyak::Reg64 reg_oc_blocks = rsi;
80 const Xbyak::Reg64 reg_owb = aux_reg_ker;
81 const Xbyak::Reg64 reg_scratch = reg_compensation;
82 const Xbyak::Reg64 reg_ki = reg_compensation;
83 const Xbyak::Reg64 reg_kj = reg_ptr_scales;
84 const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
85 const Xbyak::Reg64 reg_icb = reg_bias;
86 // Using 3d regs as depthwise3d is not yet supported
87 const Xbyak::Reg64 reg_inp_buffer_ptr = aux_reg_inp_d;
88 const Xbyak::Reg64 aux_reg_inp_buffer_ptr = aux_reg_ker_d;
89 const Xbyak::Reg64 reg_jmp_tbl_base = reg_kj;
90 // zero-point computation
91 const Xbyak::Reg64 reg_zp_compensation = aux_reg_inp;
92 const Xbyak::Reg64 reg_src_zero_point = aux_reg_ker_d;
93 const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point;
94 // dst scale
95 const Xbyak::Reg64 reg_dst_scale = reg_dst_zero_point;
96
97 /* binary post-ops operand */
98 const Xbyak::Reg64 temp_offset_reg = r12;
99
100 const Vmm vmm_wei = Vmm(0);
101 /* used during bias/comp/scale section of store_output */
102 const Vmm vmm_bias = Vmm(0);
103 const Vmm vmm_comp = Vmm(2); // only for signed input
104 const Vmm vmm_scale = Vmm(1);
105 /* used during post_op sum section of store_output */
106 const Vmm vmm_prev_dst = Vmm(0);
107 /* used during write-out section of store_output */
108 const Vmm vmm_zero = Vmm(0);
109 const Vmm vmm_saturation = Vmm(0);
110 /* used for zero-point */
111 const Vmm vmm_zp = Vmm(6);
112 const Vmm vmm_zp_one = Vmm(5);
113 const Vmm vmm_zp_comp = vmm_zp_one;
114 const Vmm vmm_zp_dw_tmp = vmm_zp_one;
115 /* dst scale */
116 const Vmm vmm_dst_scale = Vmm(5);
117
118 /* used in compute_ker (but set during prepare_output) */
119 const Vmm vmm_shift = Vmm(1); // only for signed input
120 /* used in compute_ker */
121 const Vmm vmm_tmp = Vmm(2); // not used for depthwise
122 const Vmm vmm_one
123 = Vmm(3); // set at start of kernel, not used for depthwise.
124 /* used only for depthwise */
125 Vmm vmm_dw_tmp;
126 Vmm vmm_dw_src;
127
128 int vmm_out_idx(int i_ur, int i_oc) {
129 const int idx_limit = jcp.src_zero_point
130 ? ker_zp_reg_base_idx
131 : jcp.is_depthwise ? ker_dw_reg_base_idx - jcp.signed_input
132 : ker_reg_base_idx;
133 const int nb_x_blocking
134 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
135 const int idx = i_ur * nb_x_blocking + i_oc;
136 assert(idx < idx_limit);
137 MAYBE_UNUSED(idx_limit);
138 /* remap register indices from 4 to 15
139 * to avoid passing xmm0 to comp*/
140 return ker_max_reg - idx;
141 }
142 Vmm vmm_out(int i_ur, int i_oc) {
143 const int idx = vmm_out_idx(i_ur, i_oc);
144 return Vmm(idx);
145 }
146 Vmm vmm_inp(int i_ic, int nb_x_blocking) {
147 int idx = i_ic + nb_x_blocking * jcp.ur_w;
148 assert(idx < ker_max_reg);
149 return Vmm(ker_max_reg - idx);
150 }
151 int get_ow_start(int ki, int pad_l) {
152 return nstl::max(0,
153 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
154 }
155 int get_ow_end(int ur_w, int ki, int pad_r) {
156 return ur_w
157 - nstl::max(0,
158 utils::div_up(
159 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
160 jcp.stride_w));
161 }
162 int get_blocking_size() {
163 return jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
164 }
165 int get_tail_size() {
166 return jcp.is_depthwise ? jcp.ngroups % jcp.ch_block
167 : jcp.oc_without_padding % jcp.oc_block;
168 }
169
170 void prepare_output(int ur_w);
171 void store_output(int ur_w, bool last_oc_block_flag);
172 void compute_ker_dw(int ur_w, int pad_l, int pad_r,
173 ic_block_t last_ic_block_flag, bool h_padded);
174 void compute_ker(int ur_w, int pad_l, int pad_r,
175 ic_block_t last_ic_block_flag, bool h_padded = false);
176 void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
177 void icb_loop(int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
178 void generate() override;
179
180 void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Xbyak::Reg64 &reg,
181 int offset, int load_size);
182 void apply_sum(const int nb_oc_block, const int ur_w,
183 const bool last_oc_block_flag, const int oc_block,
184 const float *p_sum_scale, const int32_t *p_sum_zp);
185 void apply_postops(const int nb_oc_block, const int ur_w,
186 const bool last_oc_block_flag, const int oc_block,
187 const float *p_sum_scale, const int32_t *p_sum_zp);
188};
189
190template <cpu_isa_t isa>
191struct jit_uni_x8s8s32x_fwd_kernel {
192
193 jit_uni_x8s8s32x_fwd_kernel(const jit_conv_conf_t &ajcp,
194 const primitive_attr_t &attr, const memory_desc_t &dst_md)
195 : kernel_(nullptr) {
196 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
197 switch (ch_block) {
198 case 8:
199 if (utils::one_of(isa, avx2)) {
200 kernel_ = new _jit_uni_x8s8s32x_fwd_kernel<isa, Xbyak::Ymm>(
201 ajcp, attr, dst_md);
202 } else
203 assert(!"invalid channel blocking for current ISA");
204 return;
205 case 4:
206 kernel_ = new _jit_uni_x8s8s32x_fwd_kernel<isa, Xbyak::Xmm>(
207 ajcp, attr, dst_md);
208 return;
209 default: assert(!"invalid channel blocking");
210 }
211 }
212
213 status_t create_kernel() { return kernel_->create_kernel(); }
214
215 ~jit_uni_x8s8s32x_fwd_kernel() { delete kernel_; }
216
217 void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
218
219 static status_t init_conf(jit_conv_conf_t &jcp,
220 const convolution_desc_t &cd, memory_desc_t &src_pd,
221 memory_desc_t &weights_pd, memory_desc_t &dst_pd,
222 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
223 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
224 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
225
226 void (*jit_ker)(jit_conv_call_s *);
227
228private:
229 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_fwd_kernel);
230 jit_generator *kernel_;
231};
232
233} // namespace x64
234} // namespace cpu
235} // namespace impl
236} // namespace dnnl
237
238#endif
239