1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | |
23 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/jit_primitive_conf.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator { |
33 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_1x1_fwd_kernel_t) |
34 | |
35 | jit_avx512_core_amx_1x1_fwd_kernel_t(const jit_conv_conf_t &ajcp, |
36 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
37 | |
38 | static status_t init_conf(jit_conv_conf_t &jcp, |
39 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
40 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
41 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
42 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
43 | const jit_conv_conf_t &jcp, const primitive_attr_t &attr); |
44 | |
45 | // Tile-registers decomposition |
46 | enum { C_BASE = 0, W_BASE = 6, I_BASE = 4 }; |
47 | |
48 | void tile_configure(char *tcgf_buff); |
49 | |
50 | jit_conv_conf_t jcp; |
51 | const primitive_attr_t &attr_; |
52 | |
53 | private: |
54 | constexpr static int isa_simd_width_ |
55 | = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
56 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> |
57 | postops_injector_; |
58 | |
59 | enum { |
60 | zmm_idx_limit_bf16 = 29, |
61 | zmm_idx_limit_int8 = 27, |
62 | }; |
63 | |
64 | int row_count_ = 0; |
65 | int buf_count_ = 0; |
66 | bool is_store_done_ = false; |
67 | bool is_buffer_empty_ = true; |
68 | bool check_last_sb_ = false; |
69 | bool last_oc_block_flag_ = false; |
70 | |
71 | /* data regs */ |
72 | const Xbyak::Reg64 inp_ptr = r15; |
73 | const Xbyak::Reg64 wei_ptr = r14; |
74 | const Xbyak::Reg64 out_ptr = r13; |
75 | const Xbyak::Reg64 wsp_ptr = r12; |
76 | |
77 | const Xbyak::Reg64 reg_bias = r11; |
78 | const Xbyak::Reg64 reg_ptr_scales = r10; |
79 | const Xbyak::Reg64 reg_ptr_sum_scale = r9; |
80 | const Xbyak::Reg64 reg_ptr_sum_zp = rax; |
81 | const Xbyak::Reg64 aux_reg_saturation = reg_ptr_sum_scale; |
82 | const Xbyak::Reg64 reg_last_h = r8; |
83 | |
84 | const Xbyak::Reg64 stride_seq = rbx; |
85 | const Xbyak::Reg64 stride_nhwc = rsi; |
86 | const Xbyak::Reg64 reg_tmp = abi_not_param1; |
87 | |
88 | const Xbyak::Reg64 reg_oc_blocks = rdx; |
89 | const Xbyak::Reg64 reg_is_osb = rsi; |
90 | const Xbyak::Reg64 reg_postop = abi_not_param1; |
91 | const Xbyak::Reg64 reg_scratch = reg_bias; |
92 | const Xbyak::Reg64 reg_tilebuff = reg_ptr_scales; |
93 | /* zero-point */ |
94 | const Xbyak::Reg64 reg_zp_compensation = reg_last_h; |
95 | const Xbyak::Reg64 reg_src_zero_point = reg_oc_blocks; |
96 | const Xbyak::Reg64 reg_dst_zero_point = rax; |
97 | |
98 | /* scale */ |
99 | const Xbyak::Reg64 reg_ptr_dst_scale = reg_ptr_scales; |
100 | |
101 | const Xbyak::Zmm zmm_bias = zmm31; |
102 | const Xbyak::Zmm zmm_saturation = zmm_bias; |
103 | const Xbyak::Zmm zmm_zero = zmm30; |
104 | const Xbyak::Zmm zmm_prev_dst = zmm29; |
105 | const Xbyak::Zmm zmm_sum_zp = zmm26; |
106 | /* zero-point */ |
107 | const Xbyak::Zmm zmm_zp = zmm29; |
108 | const Xbyak::Zmm zmm_src_zp = zmm28; |
109 | const Xbyak::Zmm zmm_dst_zp = zmm27; |
110 | |
111 | const Xbyak::Reg64 bin_injector_helper_reg_1 = r14; |
112 | const Xbyak::Reg64 bin_injector_helper_reg_2 = r15; |
113 | const Xbyak::Reg64 bin_injector_helper_reg_3 = r11; |
114 | |
115 | const Xbyak::Opmask ktail_mask = k2; |
116 | |
117 | bool is_bf16() const; |
118 | |
119 | void init_runtime_counters(); |
120 | |
121 | int get_out_tensor(int h, int i) const; |
122 | int get_inp_tensor(int h) const; |
123 | int get_wei_tensor(int i) const; |
124 | int get_ic_tail() const; |
125 | |
126 | size_t out_h_shift() const; |
127 | size_t out_w_shift() const; |
128 | size_t inp_offset(int ih, int iw, int icb) const; |
129 | size_t out_row_offset(int h, int w, int ocb) const; |
130 | |
131 | void prepare_output(); |
132 | |
133 | void cvt2ps(data_type_t type_in, const Xbyak::Zmm ymm_in, |
134 | const Xbyak::Operand &op, bool mask_flag); |
135 | Xbyak::Zmm zmm_out(const int idx) { |
136 | const int upper_limit |
137 | = is_bf16() ? zmm_idx_limit_bf16 : zmm_idx_limit_int8; |
138 | assert(upper_limit > idx); |
139 | MAYBE_UNUSED(upper_limit); |
140 | return Xbyak::Zmm(idx); |
141 | } |
142 | Xbyak::Zmm zmm_mask( |
143 | const Xbyak::Zmm zmm_in, bool mask_flag, bool store = false); |
144 | Xbyak::Ymm ymm_mask( |
145 | const Xbyak::Ymm ymm_in, bool mask_flag, bool store = false); |
146 | |
147 | void update_buffer_pointers(); |
148 | void interleave_store(); |
149 | void apply_sum(const Xbyak::Zmm zmm_out, const float *p_sum_scale, |
150 | const int32_t *p_sum_zp, const Xbyak::Address &addr, |
151 | const bool mask_flag); |
152 | void apply_postops(const Xbyak::Zmm zmm_out, const float *p_sum_scale, |
153 | const int32_t *p_sum_zp, const Xbyak::Address &addr, |
154 | const size_t off, const bool mask_flag); |
155 | static bool is_fast_postops(const jit_conv_conf_t &jcp); |
156 | void store_output_vectors_int8(int ocb, int osb); |
157 | void store_output_vector_int8( |
158 | const Xbyak::Zmm zmm_out, int ocb, int h, int w); |
159 | inline void store_output_ymm_bf16( |
160 | const int idx, const Xbyak::Address &addr, const bool mask_flag); |
161 | void store_output_vectors_bf16(int ocb, int osb); |
162 | void store_output_vector_bf16( |
163 | const Xbyak::Zmm zmm_out, int ocb, int h, int w); |
164 | void store_output_vectors(int ocb, int osb); |
165 | void store_output_vector(const Xbyak::Zmm zmm_out, int ocb, int h, int w); |
166 | void store_output(bool do_store, bool is_tail); |
167 | void icb_loop(bool do_store); |
168 | void osb_loop(int nb_os = 1); |
169 | |
170 | void generate() override; |
171 | }; |
172 | |
173 | } // namespace x64 |
174 | } // namespace cpu |
175 | } // namespace impl |
176 | } // namespace dnnl |
177 | |
178 | #endif |
179 | |