1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRDGMM_KERNEL_HPP |
18 | #define CPU_X64_JIT_BRDGMM_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
29 | #include "cpu/x64/jit_generator.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
// JIT assembly generator for the batch-reduce depthwise GEMM ("brdgmm")
// kernel. 'isa' selects the target instruction set; 'Wmm' is the vector
// register type requested by the caller (TMMs are not supported yet -- see
// the 'Vmm' alias below). Code emission itself lives in generate() and the
// private helpers, implemented outside this header.
template <cpu_isa_t isa, typename Wmm>
struct jit_brdgmm_kernel_base_t : public jit_generator {
    jit_brdgmm_kernel_base_t(const brgemm_t &abrd);

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brdgmm_kernel_base_t)

    // Kernel descriptor (dims, leading dims, data types, attributes);
    // provided at construction and read throughout code generation.
    brgemm_t brg;

    // Fast int8 path: depthwise gemm on avx512_core_vnni with no N tail.
    // The path relies on a lane permutation (see vmm_permute() /
    // permute_index_table / kblend_mask), hence the no-tail requirement.
    static bool is_fast_vnni_int8(const brgemm_t &brg) {
        return brg.is_dgmm && brg.is_int8 && brg.isa_impl == avx512_core_vnni
                && brg.ldb_tail /*n_vlen_tail*/ == 0;
    }

private:
    // note: this kernel doesn't yet support TMM's. We differentiate Wmm and Vmm
    // just to follow same template style as brgemm_kernel.
    using Vmm =
            typename utils::conditional<std::is_same<Wmm, Xbyak::Tmm>::value,
                    Xbyak::Zmm, Wmm>::type;
    // Narrower register of the same family (e.g. Ymm when Vmm is Zmm).
    using Vmm_low_t = typename vreg_traits<Vmm>::Vmm_lower_t;
    // ISA handed to the post-ops injector: avx2, avx2_vnni_2 and
    // avx512_core_fp16 map to themselves; anything else falls back to
    // avx512_core (presumably the argument right after 'isa' is the default
    // of utils::map -- confirm against common/utils.hpp).
    static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core, avx2,
            avx2, avx2_vnni_2, avx2_vnni_2, avx512_core_fp16, avx512_core_fp16);
    using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_t, Vmm>;
    // Emits post-op (eltwise/binary/sum) code; null when there are none.
    std::unique_ptr<po_injector_t> postops_injector_;
    // bf16 instruction emulation for targets without native bf16 support.
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    // Permute-index constants emitted into the code section; used by the
    // fast_vnni_int8 path.
    Xbyak::Label permute_index_table;

    using reg64_t = const Xbyak::Reg64;
    // Register decomposition
    const reg64_t param1 = abi_param1; // pointer to the call-params struct
    const reg64_t reg_A = abi_not_param1; // matrix A base
    const reg64_t reg_B = r8; // matrix B base
    const reg64_t reg_aux_batch_addr = r15; // current batch element
    const reg64_t reg_BS = rsi; // batch size

    // loop variables
    const reg64_t reg_BS_loop = r12;
    const reg64_t reg_aux_M = r13;
    const reg64_t reg_aux_D = rbx;
    const reg64_t reg_aux_C = rdx;
    const reg64_t reg_aux_A = r10;
    // NOTE: aliases param1; abi_param1 is spilled to the stack
    // (abi_param1_offs_) before this reuse.
    const reg64_t reg_aux_B = abi_param1;
    const reg64_t reg_aux1_A = reg_A; // brgemm_strd
    const reg64_t reg_aux1_B = reg_B; // brgemm_strd
    const reg64_t reg_a_offset = r9;
    const reg64_t reg_aux_N = r11;

    // virtual (runtime) row padding of the current A block
    const reg64_t reg_aux_A_vpad_top = r14;
    const reg64_t reg_aux_A_vpad_bottom = rbp;

    // rax serves several mutually-exclusive roles; the aliases below document
    // which one is live at a given point of the generated code.
    const reg64_t reg_table_base = rax;
    const reg64_t reg_tmp = reg_table_base;
    const reg64_t reg_total_padding = reg_table_base;
    const reg64_t reg_aux_bias = reg_table_base;
    const reg64_t reg_aux_scales = reg_table_base;
    const reg64_t reg_binary_params = abi_param1; // default for binary ops
    // sum post-op pointers reuse the vpad registers (not live simultaneously)
    const reg64_t reg_ptr_sum_scale = reg_aux_A_vpad_top;
    const reg64_t reg_ptr_sum_zp = reg_aux_A_vpad_bottom;

    // AVX-512 opmask registers used for tail/blend masking.
    Xbyak::Opmask k_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3);
    Xbyak::Opmask kblend_mask = Xbyak::Opmask(4);

    /* used for bfloat16 */
    reg64_t bf16_emu_scratch = reg_table_base;
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(0);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(1);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(2);
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(3);
    // note 1: zmm reserv_5 is not necessary since it's only used for
    // 'vdpbf16ps'
    // note 2: zmm0 collides with vmm_permute, hence need to write this value
    // before every loop.

    // Vector lane count and register-file size; set by the constructor
    // (values depend on isa/Vmm -- not visible in this header).
    const int simd_w_;
    const int max_vmms_;
    // Byte offsets of the 8-byte spill slots in the kernel's stack frame.
    constexpr static int reg_batch0_addr_offs_ = 0;
    constexpr static int reg_bias_offs_ = 8;
    constexpr static int reg_scales_offs_ = 16;
    constexpr static int reg_A_offs_ = 24; // brgemm_strd
    constexpr static int reg_B_offs_ = 32; // brgemm_strd
    constexpr static int abi_param1_offs_ = 40;
    constexpr static int stack_space_needed_ = 48;

    // set when a binary post-op uses a non-scalar broadcast strategy
    bool with_binary_non_scalar_bcast_ = false;

    // Two-level blocking accessors -- thin aliases over brg fields.
    // *_block1 = innermost (vector-sized) block, *_block2 = group of
    // block1's, nb_* = number of such blocks, *_tail = remainder.
    inline int M() { return brg.bcast_dim; } // rows (M dimension)
    inline int N() { return brg.load_dim; } // columns (N dimension)
    inline int m_block1() { return brg.bd_block; }
    inline int nb_m_block1() { return brg.bdb; }
    inline int m_block1_tail() { return brg.bdb_tail; }
    inline int m_block2() { return brg.bd_block2; }
    inline int nb_m_block2() { return brg.bdb2; }
    inline int m_block2_tail() { return brg.bdb2_tail; }

    inline int n_block1() { return brg.ld_block; }
    inline int nb_n_block1() { return brg.ldb; }
    inline int n_block1_tail() { return brg.ldb_tail; }
    inline int n_block2() { return brg.ld_block2; }
    inline int nb_n_block2() { return brg.ldb2; }
    inline int n_block2_tail() { return brg.ldb2_tail; }

    // valid lanes in the last, partially-filled vector along N
    int tail_length() { return n_block1_tail() % simd_w_; }
    // f32 on avx512 embeds the load as an fma memory operand, so no separate
    // b-register is allocated (cf. vmm_b() index arithmetic).
    bool is_fma_embd() { return brg.is_f32 && is_superset(isa, avx512_core); }
    bool is_fast_vnni_int8() { return is_fast_vnni_int8(brg); }
    // bf16/f16 on avx2_vnni_2 are handled in two substeps per vector
    // (interleaved vnni layout; cf. maybe_transpose_interleaved_vnni_to_plain)
    int vnni_substep() {
        return brg.isa_impl == avx2_vnni_2 && (brg.is_bf16 || brg.is_f16) ? 2
                                                                          : 1;
    }
    // Valid lane count for n-block 'n_i', simd sub-block 'v_i': simd_w_ for
    // full blocks, the remainder for the trailing block (may be <= 0 when
    // v_i lies fully past the tail -- presumably the caller skips that case).
    int get_substep_simd(int n_i, int v_i, bool has_n_tail) {
        const int last_n_block_sz
                = n_block2_tail() > 0 ? n_block2_tail() : n_block2();
        if (has_n_tail && n_i + 1 == last_n_block_sz) {
            return nstl::min(simd_w_, n_block1_tail() - v_i * simd_w_);
        } else {
            return simd_w_;
        }
    }
    // -------- vector-register allocation scheme --------
    // [0]            : permute vector (fast int8 path only)
    // [0 or 1]       : a-operand (shifted up by one on the fast int8 path)
    // [next]         : b-operand(s), skipped when the load is fma-embedded
    // [top, downward]: accumulators, then scratch just below them
    Vmm vmm_permute() { return Vmm(0); } // used in fast_vnni_int8
    Vmm vmm_a() { return Vmm(is_fast_vnni_int8()); }
    Vmm vmm_b(int bi = 0) {
        return Vmm(is_fast_vnni_int8() + !is_fma_embd() + bi);
    }
    // Accumulator for (m, n, vnni_idx) inside an m_blocks x n_blocks tile;
    // accumulators are packed at the top of the register file.
    Vmm accm(int m_blocks, int n_blocks, int m, int n, int vnni_idx) {
        assert(m_blocks <= m_block2() && m < m_blocks);
        assert(n_blocks <= n_block2() && n < n_blocks);
        const int accm_start = max_vmms_ - m_blocks * n_blocks * vnni_substep();
        const int accm_rel_idx
                = m * n_blocks * vnni_substep() + n * vnni_substep() + vnni_idx;
        const int idx = accm_start + accm_rel_idx;
        assert(idx < max_vmms_ && idx > vmm_b(0).getIdx());
        return Vmm(idx);
    }
    // i-th scratch register, placed immediately below the accumulator range.
    Vmm vmm_tmp(int i) {
        const int idx
                = max_vmms_ - m_block2() * n_block2() * vnni_substep() - 1 - i;
        assert(idx > (is_fast_vnni_int8() - 1));
        return Vmm(idx);
    }

    // Applies k-mask / tail masking to 'umm_in' when mask_flag is set.
    template <typename U>
    U maybe_mask(const U umm_in, bool mask_flag, bool store);
    void init_masks();
    // Loads fields of the call-params struct into registers / stack slots.
    void read_params();
    void load_accumulators(int m_blocks, int n_blocks);
    void restore_A_B_matrices();
    void set_A_B_matrices();
    void advance_A_B_matrices();
    void load_a(Vmm vmma, int m_i, int n_i, int v_i, bool has_n_tail);
    void load_b(Vmm vmmb, int n_i, int v_i, bool has_n_tail);
    // Innermost compute: one m_blocks x n_blocks tile of multiply-accumulate,
    // honoring virtual top/bottom padding of A.
    void brdgmm_microkernel(int m_blocks, int n_blocks, bool has_top_padding,
            bool has_bottom_padding, bool has_tail = false);
    void compute_loop();
    void batch_loop(const int m_blocks, const int n_blocks, bool has_n_tail);
    // Converts 'op' of data type 'type_in' to f32 into vmm_in.
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag, bool store);
    void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail);
    void maybe_transpose_interleaved_vnni_to_plain(
            int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators_without_post_ops(
            int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators_apply_post_ops(
            int m_blocks, int n_blocks, bool has_n_tail);

    // true when any batch element may carry virtual row padding
    bool has_vpad() {
        return brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0;
    }
    // padding needs per-block handling only when M exceeds one m-block2
    bool check_effective_padding() { return has_vpad() && M() > m_block2(); }

    // Byte/element offsets for block coordinates (m, n = block indices,
    // v = simd sub-block index within an n-block).
    int oc_logical_offset(int n) { return n * n_block1(); }
    int A_offset(int m, int n) {
        return brg.typesize_A * (m * brg.LDA + n * n_block1());
    }
    int B_offset(int n) { return brg.typesize_B * n * n_block1(); }
    int C_offset(int m, int n, int v) {
        return brg.typesize_C * (m * brg.LDC + n * n_block1() + v * simd_w_);
    }
    int D_offset(int m, int n, int v) {
        return brg.typesize_D * (m * brg.LDD + n * n_block1() + v * simd_w_);
    }
    int bias_offset(int n, int v) {
        return brg.typesize_bias * (n * n_block1() + v * simd_w_);
    }
    int scales_offset(int n, int v) {
        // is_oc_scale acts as 0/1 multiplier: common scale reads offset 0
        return sizeof(float) * brg.is_oc_scale * (n * n_block1() + v * simd_w_);
    }

    // Top-level code emission entry point (jit_generator hook).
    void generate() override;
};
227 | |
228 | } // namespace x64 |
229 | } // namespace cpu |
230 | } // namespace impl |
231 | } // namespace dnnl |
232 | |
233 | #endif |
234 | |