1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_X8S8S32X_1X1_CONV_KERNEL_HPP
18#define CPU_X64_JIT_UNI_X8S8S32X_1X1_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32template <cpu_isa_t isa, typename Vmm>
33struct _jit_uni_x8s8s32x_1x1_conv_kernel : public jit_generator {
34 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_1x1_conv_kernel)
35 _jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
36 const primitive_attr_t &attr, const memory_desc_t &dst_md);
37
38 int get_tail_size() { return jcp.oc_without_padding % jcp.oc_block; }
39
40 jit_1x1_conv_conf_t jcp;
41 const primitive_attr_t &attr_;
42
43private:
44 std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
45 postops_injector_;
46
47 enum {
48 ker_max_reg_idx = 13,
49 };
50 const Xbyak::Reg64 reg_bcast_data = r8;
51 const Xbyak::Reg64 reg_ptr_scales = r8;
52 const Xbyak::Reg64 reg_output_data = r9;
53 const Xbyak::Reg64 reg_load_data = r10;
54 const Xbyak::Reg64 reg_ptr_sum_scale = r10;
55 const Xbyak::Reg64 reg_ptr_sum_zp = rdx;
56 const Xbyak::Reg64 reg_reduce_loop_work = r11;
57 const Xbyak::Reg64 reg_bias_data = r12;
58 const Xbyak::Reg64 reg_comp_data = r12;
59 const Xbyak::Reg64 reg_ptr_dst_scale = r12;
60 const Xbyak::Reg64 reg_init_bcast = r13;
61 const Xbyak::Reg64 reg_store_bcast = r13;
62 const Xbyak::Reg64 reg_reduce_loop_iter = r13;
63 const Xbyak::Reg64 aux_reg_bcast_data = r14;
64 const Xbyak::Reg64 aux_reg_load_data = r15;
65 const Xbyak::Reg64 aux_reg_saturation = r15;
66 const Xbyak::Reg64 reg_reduce_pos_flag = rax;
67 const Xbyak::Reg64 aux1_reg_bcast_data = rbx;
68 const Xbyak::Reg64 reg_bcast_loop_work = rbx;
69 const Xbyak::Reg64 reg_bcast_loop_iter = rdx;
70 const Xbyak::Reg64 reg_load_loop_work = rsi;
71 const Xbyak::Reg64 aux_reg_output_data = abi_not_param1;
72 // zero-point computation
73 const Xbyak::Reg64 reg_zp_compensation = aux_reg_load_data; // r15
74 const Xbyak::Reg64 reg_src_zero_point = aux_reg_bcast_data; // r14
75 const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point;
76
77 const Vmm vmm_tmp = Vmm(3);
78 const Vmm vmm_one = Vmm(2);
79 const Vmm vmm_zero = Vmm(1);
80 const Vmm vmm_shift = Vmm(1);
81 const Vmm vmm_bcast = Vmm(0);
82 const Vmm vmm_saturation = Vmm(0);
83 /* used during scale section of store_output */
84 const Vmm vmm_scale = Vmm(1);
85 /* used during post_op sum section of store_output */
86 const Vmm vmm_prev_dst = Vmm(1);
87 /* used during bias section of store_output */
88 const Vmm vmm_comp = Vmm(0); // only for signed input
89 const Vmm vmm_bias = Vmm(3);
90 /* zero-point */
91 const Vmm vmm_zp = Vmm(1);
92 const Vmm vmm_zp_comp = Vmm(2);
93 /* dst scale */
94 const Vmm vmm_dst_scale = Vmm(1);
95
96 constexpr static int simd_w = isa == avx2 ? 8 : 4;
97 constexpr static int reg64_size = sizeof(int64_t);
98 constexpr static int bcast_loop_work_off = 0;
99 constexpr static int reg_bias_data_off = 1 * reg64_size;
100 constexpr static int reg_bcast_data_off = 2 * reg64_size;
101 constexpr static int reg_load_data_off = 3 * reg64_size;
102 constexpr static int reg_ptr_sum_scale_off = 4 * reg64_size;
103 constexpr static int reg_bcast_loop_iter_off = 5 * reg64_size;
104 constexpr static int reg_comp_data_off = 6 * reg64_size;
105 constexpr static int reg_zp_compensation_off = 7 * reg64_size;
106 constexpr static int reg_src_zero_point_off = 8 * reg64_size;
107 constexpr static int reg_dst_zero_point_off = 9 * reg64_size;
108 constexpr static int reg_dst_scale_off = 10 * reg64_size;
109 constexpr static int reg_binary_post_op_acc_off = 11 * reg64_size;
110 constexpr static int stack_space_needed = 12 * reg64_size;
111
112 int vreg_accum_idx(
113 const int load_loop_blk, const int i_load, const int i_ur);
114 Vmm vreg_accum(const int load_loop_blk, const int i_load, const int i_ur);
115 int output_ptr(const int i_load, const int i_ur);
116 void bcast_loop(int load_loop_blk);
117 void apply_sum(const int ur, const int load_loop_blk,
118 const bool mask_flag_in, const float *p_sum_scale,
119 const int32_t *p_sum_zp);
120 void apply_postops(const int ur, const int load_loop_blk,
121 const bool mask_flag_in, const float *p_sum_scale,
122 const int32_t *p_sum_zp);
123 void reduce_loop(int load_loop_blk, int ur, bool wraparound);
124
125 void generate() override;
126 void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Xbyak::Reg64 &reg,
127 int offset, int load_size);
128};
129
130template <cpu_isa_t isa>
131struct jit_uni_x8s8s32x_1x1_conv_kernel {
132
133 jit_uni_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
134 const primitive_attr_t &attr, const memory_desc_t &dst_md)
135 : kernel_(nullptr) {
136
137 switch (isa) {
138 case avx2:
139 kernel_ = new jit_avx2_x8s8s32x_1x1_conv_kernel(
140 ajcp, attr, dst_md);
141 return;
142 case sse41:
143 kernel_ = new jit_sse41_x8s8s32x_1x1_conv_kernel(
144 ajcp, attr, dst_md);
145 return;
146 default: assert(!"Current ISA is not supported!");
147 }
148 }
149
150 status_t create_kernel() { return kernel_->create_kernel(); }
151
152 ~jit_uni_x8s8s32x_1x1_conv_kernel() { delete kernel_; }
153
154 void operator()(const jit_1x1_conv_call_s *p) const { (*kernel_)(p); }
155
156 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
157 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
158 const memory_desc_wrapper &weights_d,
159 const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
160 primitive_attr_t &attr, int nthreads, bool reduce_src);
161
162 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
163 const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
164
165 using jit_sse41_x8s8s32x_1x1_conv_kernel
166 = _jit_uni_x8s8s32x_1x1_conv_kernel<sse41, Xbyak::Xmm>;
167 using jit_avx2_x8s8s32x_1x1_conv_kernel
168 = _jit_uni_x8s8s32x_1x1_conv_kernel<avx2, Xbyak::Ymm>;
169
170 constexpr static int simd_w = isa == avx2 ? 8 : 4;
171
172private:
173 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_1x1_conv_kernel);
174 jit_generator *kernel_;
175};
176
177} // namespace x64
178} // namespace cpu
179} // namespace impl
180} // namespace dnnl
181
182#endif
183