/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_BRDGMM_KERNEL_HPP
#define CPU_X64_JIT_BRDGMM_KERNEL_HPP

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, typename Wmm>
struct jit_brdgmm_kernel_base_t : public jit_generator {
    jit_brdgmm_kernel_base_t(const brgemm_t &abrd);

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brdgmm_kernel_base_t)

    brgemm_t brg;

    static bool is_fast_vnni_int8(const brgemm_t &brg) {
        return brg.is_dgmm && brg.is_int8 && brg.isa_impl == avx512_core_vnni
                && brg.ldb_tail /*n_vlen_tail*/ == 0;
    }
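
    // Example: the predicate above holds for an int8 depthwise (dgmm) kernel
    // on avx512_core_vnni whenever N is a whole multiple of the vector
    // length, i.e. brg.ldb_tail == 0, so no N-tail masking is required.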

private:
    // note: this kernel does not yet support TMMs. We differentiate Wmm and
    // Vmm only to follow the same template style as brgemm_kernel.
    using Vmm =
            typename utils::conditional<std::is_same<Wmm, Xbyak::Tmm>::value,
                    Xbyak::Zmm, Wmm>::type;
    using Vmm_low_t = typename vreg_traits<Vmm>::Vmm_lower_t;
    static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core, avx2,
            avx2, avx2_vnni_2, avx2_vnni_2, avx512_core_fp16, avx512_core_fp16);
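    // utils::map(key, default, k1, v1, k2, v2, ...) yields the value paired
    // with the first matching key, so the post-ops injector ISA resolves to:
    //   avx2             -> avx2
    //   avx2_vnni_2      -> avx2_vnni_2
    //   avx512_core_fp16 -> avx512_core_fp16
    //   anything else    -> avx512_core (the default)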
    using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_t, Vmm>;
    std::unique_ptr<po_injector_t> postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    Xbyak::Label permute_index_table;

    using reg64_t = const Xbyak::Reg64;
    // Register decomposition
    const reg64_t param1 = abi_param1;
    const reg64_t reg_A = abi_not_param1;
    const reg64_t reg_B = r8;
    const reg64_t reg_aux_batch_addr = r15;
    const reg64_t reg_BS = rsi;

    // loop variables
    const reg64_t reg_BS_loop = r12;
    const reg64_t reg_aux_M = r13;
    const reg64_t reg_aux_D = rbx;
    const reg64_t reg_aux_C = rdx;
    const reg64_t reg_aux_A = r10;
    const reg64_t reg_aux_B = abi_param1;
    const reg64_t reg_aux1_A = reg_A; // brgemm_strd
    const reg64_t reg_aux1_B = reg_B; // brgemm_strd
    const reg64_t reg_a_offset = r9;
    const reg64_t reg_aux_N = r11;

    const reg64_t reg_aux_A_vpad_top = r14;
    const reg64_t reg_aux_A_vpad_bottom = rbp;

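    // rax (reg_table_base) is time-shared by the short-lived temporaries
    // below; they must never be live at the same time.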
    const reg64_t reg_table_base = rax;
    const reg64_t reg_tmp = reg_table_base;
    const reg64_t reg_total_padding = reg_table_base;
    const reg64_t reg_aux_bias = reg_table_base;
    const reg64_t reg_aux_scales = reg_table_base;
    const reg64_t reg_binary_params = abi_param1; // default for binary ops
    const reg64_t reg_ptr_sum_scale = reg_aux_A_vpad_top;
    const reg64_t reg_ptr_sum_zp = reg_aux_A_vpad_bottom;

    Xbyak::Opmask k_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_tail_mask = Xbyak::Opmask(3);
    Xbyak::Opmask kblend_mask = Xbyak::Opmask(4);

    /* used for bfloat16 */
    reg64_t bf16_emu_scratch = reg_table_base;
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(0);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(1);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(2);
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(3);
    // note 1: a fifth reserved zmm is not needed, since it is only used for
    // 'vdpbf16ps' emulation.
    // note 2: zmm0 collides with vmm_permute, so the permute vector has to be
    // rewritten before every loop.

    const int simd_w_;
    const int max_vmms_;
    constexpr static int reg_batch0_addr_offs_ = 0;
    constexpr static int reg_bias_offs_ = 8;
    constexpr static int reg_scales_offs_ = 16;
    constexpr static int reg_A_offs_ = 24; // brgemm_strd
    constexpr static int reg_B_offs_ = 32; // brgemm_strd
    constexpr static int abi_param1_offs_ = 40;
    constexpr static int stack_space_needed_ = 48;
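    // Sketch of the spill area implied by the offsets above, one 8-byte slot
    // per entry:
    //   +0  address of batch element 0    +24 A pointer (brgemm_strd)
    //   +8  bias pointer                  +32 B pointer (brgemm_strd)
    //   +16 scales pointer                +40 saved abi_param1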

    bool with_binary_non_scalar_bcast_ = false;

    inline int M() { return brg.bcast_dim; }
    inline int N() { return brg.load_dim; }
    inline int m_block1() { return brg.bd_block; }
    inline int nb_m_block1() { return brg.bdb; }
    inline int m_block1_tail() { return brg.bdb_tail; }
    inline int m_block2() { return brg.bd_block2; }
    inline int nb_m_block2() { return brg.bdb2; }
    inline int m_block2_tail() { return brg.bdb2_tail; }

    inline int n_block1() { return brg.ld_block; }
    inline int nb_n_block1() { return brg.ldb; }
    inline int n_block1_tail() { return brg.ldb_tail; }
    inline int n_block2() { return brg.ld_block2; }
    inline int nb_n_block2() { return brg.ldb2; }
    inline int n_block2_tail() { return brg.ldb2_tail; }

    int tail_length() { return n_block1_tail() % simd_w_; }
    bool is_fma_embd() { return brg.is_f32 && is_superset(isa, avx512_core); }
    bool is_fast_vnni_int8() { return is_fast_vnni_int8(brg); }
    int vnni_substep() {
        return brg.isa_impl == avx2_vnni_2 && (brg.is_bf16 || brg.is_f16) ? 2
                                                                          : 1;
    }
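    // On avx2_vnni_2, bf16/f16 data is handled as interleaved even/odd
    // half-vectors (cf. maybe_transpose_interleaved_vnni_to_plain below), so
    // each n block takes two substeps; all other configurations take one.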
    int get_substep_simd(int n_i, int v_i, bool has_n_tail) {
        const int last_n_block_sz
                = n_block2_tail() > 0 ? n_block2_tail() : n_block2();
        if (has_n_tail && n_i + 1 == last_n_block_sz) {
            return nstl::min(simd_w_, n_block1_tail() - v_i * simd_w_);
        } else {
            return simd_w_;
        }
    }
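    // For example, with simd_w_ = 8 and n_block1_tail() = 13, the last block
    // is processed in substeps of 8 (v_i = 0) and 5 (v_i = 1); every non-tail
    // block gets the full simd_w_.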
    Vmm vmm_permute() { return Vmm(0); } // used in fast_vnni_int8
    Vmm vmm_a() { return Vmm(is_fast_vnni_int8()); }
    Vmm vmm_b(int bi = 0) {
        return Vmm(is_fast_vnni_int8() + !is_fma_embd() + bi);
    }
    Vmm accm(int m_blocks, int n_blocks, int m, int n, int vnni_idx) {
        assert(m_blocks <= m_block2() && m < m_blocks);
        assert(n_blocks <= n_block2() && n < n_blocks);
        const int accm_start = max_vmms_ - m_blocks * n_blocks * vnni_substep();
        const int accm_rel_idx
                = m * n_blocks * vnni_substep() + n * vnni_substep() + vnni_idx;
        const int idx = accm_start + accm_rel_idx;
        assert(idx < max_vmms_ && idx > vmm_b(0).getIdx());
        return Vmm(idx);
    }
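    // Register budget sketch: accumulators are handed out top-down from
    // vmm(max_vmms_ - 1). E.g. with max_vmms_ = 32, m_blocks = 4, n_blocks = 2
    // and vnni_substep() = 1, accm() returns vmm24..vmm31, while vmm_a(),
    // vmm_b() and vmm_tmp() stay at the low end of the register file, as the
    // asserts enforce.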
    Vmm vmm_tmp(int i) {
        const int idx
                = max_vmms_ - m_block2() * n_block2() * vnni_substep() - 1 - i;
        assert(idx > (is_fast_vnni_int8() - 1));
        return Vmm(idx);
    }

    template <typename U>
    U maybe_mask(const U umm_in, bool mask_flag, bool store);
    void init_masks();
    void read_params();
    void load_accumulators(int m_blocks, int n_blocks);
    void restore_A_B_matrices();
    void set_A_B_matrices();
    void advance_A_B_matrices();
    void load_a(Vmm vmma, int m_i, int n_i, int v_i, bool has_n_tail);
    void load_b(Vmm vmmb, int n_i, int v_i, bool has_n_tail);
    void brdgmm_microkernel(int m_blocks, int n_blocks, bool has_top_padding,
            bool has_bottom_padding, bool has_tail = false);
    void compute_loop();
    void batch_loop(const int m_blocks, const int n_blocks, bool has_n_tail);
    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag, bool store);
    void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail);
    void maybe_transpose_interleaved_vnni_to_plain(
            int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators_without_post_ops(
            int m_blocks, int n_blocks, bool has_n_tail);
    void store_accumulators_apply_post_ops(
            int m_blocks, int n_blocks, bool has_n_tail);

    bool has_vpad() {
        return brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0;
    }
    bool check_effective_padding() { return has_vpad() && M() > m_block2(); }

    int oc_logical_offset(int n) { return n * n_block1(); }
    int A_offset(int m, int n) {
        return brg.typesize_A * (m * brg.LDA + n * n_block1());
    }
    int B_offset(int n) { return brg.typesize_B * n * n_block1(); }
    int C_offset(int m, int n, int v) {
        return brg.typesize_C * (m * brg.LDC + n * n_block1() + v * simd_w_);
    }
    int D_offset(int m, int n, int v) {
        return brg.typesize_D * (m * brg.LDD + n * n_block1() + v * simd_w_);
    }
    int bias_offset(int n, int v) {
        return brg.typesize_bias * (n * n_block1() + v * simd_w_);
    }
    int scales_offset(int n, int v) {
        return sizeof(float) * brg.is_oc_scale * (n * n_block1() + v * simd_w_);
    }
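    // Example (hypothetical values): for f32 with typesize_A = 4, brg.LDA = 64
    // and n_block1() = 16, A_offset(2, 1) = 4 * (2 * 64 + 16) = 576 bytes from
    // the current A pointer.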

    void generate() override;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif