1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_X8S8S32X_1X1_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_UNI_X8S8S32X_1X1_CONV_KERNEL_HPP |
19 | |
#include <cassert>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | template <cpu_isa_t isa, typename Vmm> |
33 | struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator { |
34 | DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_1x1_conv_kernel) |
35 | _jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp, |
36 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
37 | |
38 | int get_tail_size() { return jcp.oc_without_padding % jcp.oc_block; } |
39 | |
40 | jit_1x1_conv_conf_t jcp; |
41 | const primitive_attr_t &attr_; |
42 | |
43 | private: |
44 | std::unique_ptr<injector::jit_uni_postops_injector_t<isa>> |
45 | postops_injector_; |
46 | |
47 | enum { |
48 | ker_max_reg_idx = 13, |
49 | }; |
50 | const Xbyak::Reg64 reg_bcast_data = r8; |
51 | const Xbyak::Reg64 reg_ptr_scales = r8; |
52 | const Xbyak::Reg64 reg_output_data = r9; |
53 | const Xbyak::Reg64 reg_load_data = r10; |
54 | const Xbyak::Reg64 reg_ptr_sum_scale = r10; |
55 | const Xbyak::Reg64 reg_ptr_sum_zp = rdx; |
56 | const Xbyak::Reg64 reg_reduce_loop_work = r11; |
57 | const Xbyak::Reg64 reg_bias_data = r12; |
58 | const Xbyak::Reg64 reg_comp_data = r12; |
59 | const Xbyak::Reg64 reg_ptr_dst_scale = r12; |
60 | const Xbyak::Reg64 reg_init_bcast = r13; |
61 | const Xbyak::Reg64 reg_store_bcast = r13; |
62 | const Xbyak::Reg64 reg_reduce_loop_iter = r13; |
63 | const Xbyak::Reg64 aux_reg_bcast_data = r14; |
64 | const Xbyak::Reg64 aux_reg_load_data = r15; |
65 | const Xbyak::Reg64 aux_reg_saturation = r15; |
66 | const Xbyak::Reg64 reg_reduce_pos_flag = rax; |
67 | const Xbyak::Reg64 aux1_reg_bcast_data = rbx; |
68 | const Xbyak::Reg64 reg_bcast_loop_work = rbx; |
69 | const Xbyak::Reg64 reg_bcast_loop_iter = rdx; |
70 | const Xbyak::Reg64 reg_load_loop_work = rsi; |
71 | const Xbyak::Reg64 aux_reg_output_data = abi_not_param1; |
72 | // zero-point computation |
73 | const Xbyak::Reg64 reg_zp_compensation = aux_reg_load_data; // r15 |
74 | const Xbyak::Reg64 reg_src_zero_point = aux_reg_bcast_data; // r14 |
75 | const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point; |
76 | |
77 | const Vmm vmm_tmp = Vmm(3); |
78 | const Vmm vmm_one = Vmm(2); |
79 | const Vmm vmm_zero = Vmm(1); |
80 | const Vmm vmm_shift = Vmm(1); |
81 | const Vmm vmm_bcast = Vmm(0); |
82 | const Vmm vmm_saturation = Vmm(0); |
83 | /* used during scale section of store_output */ |
84 | const Vmm vmm_scale = Vmm(1); |
85 | /* used during post_op sum section of store_output */ |
86 | const Vmm vmm_prev_dst = Vmm(1); |
87 | /* used during bias section of store_output */ |
88 | const Vmm vmm_comp = Vmm(0); // only for signed input |
89 | const Vmm vmm_bias = Vmm(3); |
90 | /* zero-point */ |
91 | const Vmm vmm_zp = Vmm(1); |
92 | const Vmm vmm_zp_comp = Vmm(2); |
93 | /* dst scale */ |
94 | const Vmm vmm_dst_scale = Vmm(1); |
95 | |
96 | constexpr static int simd_w = isa == avx2 ? 8 : 4; |
97 | constexpr static int reg64_size = sizeof(int64_t); |
98 | constexpr static int bcast_loop_work_off = 0; |
99 | constexpr static int reg_bias_data_off = 1 * reg64_size; |
100 | constexpr static int reg_bcast_data_off = 2 * reg64_size; |
101 | constexpr static int reg_load_data_off = 3 * reg64_size; |
102 | constexpr static int reg_ptr_sum_scale_off = 4 * reg64_size; |
103 | constexpr static int reg_bcast_loop_iter_off = 5 * reg64_size; |
104 | constexpr static int reg_comp_data_off = 6 * reg64_size; |
105 | constexpr static int reg_zp_compensation_off = 7 * reg64_size; |
106 | constexpr static int reg_src_zero_point_off = 8 * reg64_size; |
107 | constexpr static int reg_dst_zero_point_off = 9 * reg64_size; |
108 | constexpr static int reg_dst_scale_off = 10 * reg64_size; |
109 | constexpr static int reg_binary_post_op_acc_off = 11 * reg64_size; |
110 | constexpr static int stack_space_needed = 12 * reg64_size; |
111 | |
112 | int vreg_accum_idx( |
113 | const int load_loop_blk, const int i_load, const int i_ur); |
114 | Vmm vreg_accum(const int load_loop_blk, const int i_load, const int i_ur); |
115 | int output_ptr(const int i_load, const int i_ur); |
116 | void bcast_loop(int load_loop_blk); |
117 | void apply_sum(const int ur, const int load_loop_blk, |
118 | const bool mask_flag_in, const float *p_sum_scale, |
119 | const int32_t *p_sum_zp); |
120 | void apply_postops(const int ur, const int load_loop_blk, |
121 | const bool mask_flag_in, const float *p_sum_scale, |
122 | const int32_t *p_sum_zp); |
123 | void reduce_loop(int load_loop_blk, int ur, bool wraparound); |
124 | |
125 | void generate() override; |
126 | void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Xbyak::Reg64 ®, |
127 | int offset, int load_size); |
128 | }; |
129 | |
130 | template <cpu_isa_t isa> |
131 | struct jit_uni_x8s8s32x_1x1_conv_kernel { |
132 | |
133 | jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp, |
134 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
135 | : kernel_(nullptr) { |
136 | |
137 | switch (isa) { |
138 | case avx2: |
139 | kernel_ = new jit_avx2_x8s8s32x_1x1_conv_kernel( |
140 | ajcp, attr, dst_md); |
141 | return; |
142 | case sse41: |
143 | kernel_ = new jit_sse41_x8s8s32x_1x1_conv_kernel( |
144 | ajcp, attr, dst_md); |
145 | return; |
146 | default: assert(!"Current ISA is not supported!" ); |
147 | } |
148 | } |
149 | |
150 | status_t create_kernel() { return kernel_->create_kernel(); } |
151 | |
152 | ~jit_uni_x8s8s32x_1x1_conv_kernel() { delete kernel_; } |
153 | |
154 | void operator()(const jit_1x1_conv_call_s *p) const { (*kernel_)(p); } |
155 | |
156 | static status_t init_conf(jit_1x1_conv_conf_t &jcp, |
157 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
158 | const memory_desc_wrapper &weights_d, |
159 | const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d, |
160 | primitive_attr_t &attr, int nthreads, bool reduce_src); |
161 | |
162 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
163 | const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr); |
164 | |
165 | using jit_sse41_x8s8s32x_1x1_conv_kernel |
166 | = _jit_uni_x8s8s32x_1x1_conv_kernel<sse41, Xbyak::Xmm>; |
167 | using jit_avx2_x8s8s32x_1x1_conv_kernel |
168 | = _jit_uni_x8s8s32x_1x1_conv_kernel<avx2, Xbyak::Ymm>; |
169 | |
170 | constexpr static int simd_w = isa == avx2 ? 8 : 4; |
171 | |
172 | private: |
173 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_1x1_conv_kernel); |
174 | jit_generator *kernel_; |
175 | }; |
176 | |
177 | } // namespace x64 |
178 | } // namespace cpu |
179 | } // namespace impl |
180 | } // namespace dnnl |
181 | |
182 | #endif |
183 | |