/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <typename Vmm>
struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t)
    _jit_avx512_core_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    jit_1x1_conv_conf_t jcp;
    const primitive_attr_t &attr_;

private:
    constexpr static int isa_simd_width_
            = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;

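    /* Note: several logical register names below alias the same physical GPR
     * (e.g. r8, r10, r12, r15 and rbx each back more than one name). The
     * aliased names are live in different phases of the bcast/reduce loops;
     * values that must survive across such reuse are preserved in the stack
     * slots defined further below. */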
    /* register mapping */
    const Xbyak::Reg64 reg_last_load = r8;
    const Xbyak::Reg64 reg_bcast_data = r8;
    const Xbyak::Reg64 reg_ptr_scales = r8;
    const Xbyak::Reg64 reg_ptr_saturation_ubound = r8;
    const Xbyak::Reg64 reg_output_data = r9;
    const Xbyak::Reg64 reg_load_data = r10;
    const Xbyak::Reg64 reg_ptr_sum_scale = r10;
    const Xbyak::Reg64 reg_reduce_loop_work = r11;
    const Xbyak::Reg64 reg_bias_data = r12;
    const Xbyak::Reg64 reg_comp_data = r12;
    const Xbyak::Reg64 reg_ptr_dst_scale = r12;
    const Xbyak::Reg64 reg_scratch = r13;
    const Xbyak::Reg64 aux_reg_bcast_data = r14;
    const Xbyak::Reg64 aux_reg_load_data = r15;
    const Xbyak::Reg64 imm_addr64 = r15;
    const Xbyak::Reg64 reg_reduce_pos_flag = rax;
    const Xbyak::Reg64 aux1_reg_bcast_data = rbx;
    const Xbyak::Reg64 reg_bcast_loop_work = rbx;
    const Xbyak::Reg64 bcast_loop_iter = rdx; // Note: Fix me
    const Xbyak::Reg64 reg_load_loop_work = rsi;
    const Xbyak::Reg64 aux_reg_output_data = abi_not_param1;
    const Xbyak::Reg64 reduce_loop_iter = abi_param1;
    // zero-point computation
    const Xbyak::Reg64 reg_zp_compensation = aux_reg_load_data; // r15
    const Xbyak::Reg64 reg_src_zero_point = aux_reg_bcast_data; // r14
    const Xbyak::Reg64 reg_dst_zero_point = reg_src_zero_point;
    const Xbyak::Reg64 reg_load_dim_tail_mask = reg_scratch;

    const Xbyak::Opmask k_load_dim_mask = Xbyak::Opmask(2);
    const Xbyak::Opmask k_load_dim_mask_extended = Xbyak::Opmask(3);
    const Xbyak::Opmask k_load_dim_tail_mask = Xbyak::Opmask(4);
    const Xbyak::Opmask k_load_dim_tail_mask_extended = Xbyak::Opmask(5);
    const Xbyak::Opmask postops_mask = Xbyak::Opmask(6);
    const Xbyak::Opmask vmask = k7;

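    /* Vmm 28-31 serve as shared temporaries: several of the logical names
     * below map to the same vector register index, so at most one of the
     * aliased roles may be live at any point in the generated code. */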
    const Vmm vmm_tmp = Vmm(28);
    const Vmm vmm_saturation = Vmm(28);
    const Vmm vmm_one = Vmm(29);
    const Vmm vmm_zero = Vmm(30);
    const Vmm vmm_prev_dst = Vmm(30);
    const Vmm vmm_shift = Vmm(30);
    const Vmm vmm_bcast = Vmm(31);
    /* zero-point */
    const Vmm vmm_zp = Vmm(30);
    const Vmm vmm_zp_tmp = vmm_zp;

    const Vmm vmm_dst_scale = Vmm(30);

    /* bfloat16 */
    const Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(25);
    const Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(26);
    const Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(27);
    const Xbyak::Reg64 bf16_emu_reserv_4 = imm_addr64;
    const Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(28);
    const Xbyak::Ymm ymm_store = Xbyak::Ymm(31);

    std::unique_ptr<bf16_emulation_t> bf16_emu_;

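    /* Byte offsets of the 8-byte spill slots inside the stack area reserved by
     * the generated kernel (stack_space_needed bytes in total). They hold
     * kernel arguments and values whose registers are reused above. */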
    constexpr static int reg64_size_ = sizeof(int64_t);
    constexpr static int bcast_loop_work_off = 0;
    constexpr static int reg_bias_data_off = 1 * reg64_size_;
    constexpr static int reg_bcast_data_off = 2 * reg64_size_;
    constexpr static int reg_load_data_off = 3 * reg64_size_;
    constexpr static int reg_ptr_sum_scale_off = 4 * reg64_size_;
    constexpr static int reg_ptr_sum_zp_off = 5 * reg64_size_;
    constexpr static int reg_comp_data_off = 6 * reg64_size_;
    constexpr static int reg_zp_compensation_off = 7 * reg64_size_;
    constexpr static int reg_src_zero_point_off = 8 * reg64_size_;
    constexpr static int reg_dst_zero_point_off = 9 * reg64_size_;
    constexpr static int reg_dst_scale_off = 10 * reg64_size_;
    constexpr static int reg_binary_post_op_acc_off = 11 * reg64_size_;
    constexpr static int reg_abi_param1_backup = 12 * reg64_size_;
    constexpr static int stack_space_needed = 13 * reg64_size_;

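    /* When the current load-dimension block is a tail that does not fill a
     * full vector, the helpers below attach the corresponding Opmask so that
     * only the valid lanes are accessed. */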
    inline Vmm maybe_mask_vmm(Vmm vmm, bool mask_flag) {
        return mask_flag ? vmm | k_load_dim_mask_extended : vmm;
    }
    inline Vmm_down_t maybe_mask_vmm_down(Vmm_down_t vmm_down, bool mask_flag) {
        return mask_flag ? vmm_down | k_load_dim_mask : vmm_down;
    }
    inline Vmm_down_t vmm_store() { return Vmm_down_t(ymm_store.getIdx()); }

    void bcast_loop(int load_loop_blk);
    void reduce_loop(int load_loop_blk, int ur, bool wraparound);

    Xbyak::Address output_ptr(const int i_load, const int i_ur);
    int vreg_accum_idx(const int load_loop_blk, int i_load, int i_ur) const;
    Vmm vreg_accum(const int load_loop_blk, int i_load, int i_ur) const;
    void apply_sum(const int load_loop_blk, const int ur,
            const bool mask_flag_in, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void apply_postops(const int load_loop_blk, const int ur,
            const bool mask_flag_in, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void generate() override;
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag);
};

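/* Thin dispatcher: instantiates the JIT kernel with the vector register type
 * matching the channel blocking in jcp.ic_block (16 -> Zmm, 8 -> Ymm,
 * 4 -> Xmm). */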
struct jit_avx512_core_x8s8s32x_1x1_conv_kernel {
    jit_avx512_core_x8s8s32x_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md)
        : kernel_(nullptr) {
        int ch_block = ajcp.ic_block;
        switch (ch_block) {
            case 16:
                kernel_ = new _jit_avx512_core_x8s8s32x_1x1_conv_kernel<
                        Xbyak::Zmm>(ajcp, attr, dst_md);
                return;
            case 8:
                kernel_ = new _jit_avx512_core_x8s8s32x_1x1_conv_kernel<
                        Xbyak::Ymm>(ajcp, attr, dst_md);
                return;
            case 4:
                kernel_ = new _jit_avx512_core_x8s8s32x_1x1_conv_kernel<
                        Xbyak::Xmm>(ajcp, attr, dst_md);
                return;
            default: assert(!"invalid channel blocking");
        }
    }

    status_t create_kernel() { return kernel_->create_kernel(); }

    ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() { delete kernel_; }

    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
            const convolution_desc_t &cd, const memory_desc_t *&src_md,
            memory_desc_t &weights_md, memory_desc_t &dst_md,
            memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads,
            bool reduce_src);

    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);

    void operator()(const jit_1x1_conv_call_s *p) const { (*kernel_)(p); }
    const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }

private:
    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_x8s8s32x_1x1_conv_kernel);
    jit_generator *kernel_;
};
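/* Illustrative call sequence (a sketch only; the real call sites live in the
 * matching convolution primitive implementation, and fill_call_params is a
 * hypothetical helper):
 *
 *     jit_avx512_core_x8s8s32x_1x1_conv_kernel kernel(jcp, attr, dst_md);
 *     if (kernel.create_kernel() != status::success) return;
 *     jit_1x1_conv_call_s p;
 *     fill_call_params(p); // set per-block pointers and work amounts
 *     kernel(&p);          // run the generated 1x1 convolution kernel
 */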

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif