1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <memory> |
17 | #include <vector> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
29 | #include "cpu/x64/jit_generator.hpp" |
30 | |
31 | #define GET_OFF(field) offsetof(brgemm_kernel_params_t, field) |
32 | #define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field) |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
39 | using namespace dnnl::impl::utils; |
40 | using namespace Xbyak; |
// JIT assembly generator for a single BRGEMM (batch-reduce GEMM) kernel.
// `Wmm` selects the working matrix register type; when Wmm is Xbyak::Tmm
// (AMX) the auxiliary vector code still uses Zmm (see the `Vmm` alias below).
template <cpu_isa_t isa, typename Wmm>
struct jit_brgemm_kernel_t : public jit_generator {
    // The constructor only prepares helper objects (post-ops injector,
    // bf16 emulation); the kernel body itself is emitted by generate().
    jit_brgemm_kernel_t(const brgemm_t &abrg)
        : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, abrg.isa_impl)
        , brg(abrg)
        , postops_injector_(nullptr) {

        // The implementation uses is_superset(), is_subset() utilities.
        // So avoid isa_all, isa_undef in these comparisions.
        assert(!utils::one_of(brg.isa_impl, isa_all, isa_undef));
        const int is_ldb2_tail = brg.ldb2_tail ? 1 : 0;
        const int is_ldb_tail = brg.ldb_tail ? 1 : 0;
        // A loop over ld blocks is needed only when more than one block
        // group (full ldb2 groups plus the two possible tails) exists.
        is_ldb_loop_ = brg.ldb2 + is_ldb2_tail + is_ldb_tail > 1;

        if (brg.with_eltwise || brg.with_binary || brg.with_sum) {

            static constexpr bool preserve_gpr = true;
            static constexpr bool preserve_vmm = true;
            static constexpr bool use_exact_tail_scalar_bcast = false;
            const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md);

            static const bcast_set_t enabled_bcast_strategy
                    = {broadcasting_strategy_t::scalar,
                            broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast};
            const binary_injector::rhs_arg_static_params_t rhs_sp {
                    static_cast<size_t>(Vmm(1).getIdx()), this->r14, this->r15,
                    this->r13, preserve_gpr, preserve_vmm,
                    GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
                    dst_md_wrapper, static_cast<size_t>(brg.ldb_tail),
                    ld_tail_mask, use_exact_tail_scalar_bcast};
            const binary_injector::static_params_t bsp {
                    this->param1, enabled_bcast_strategy, rhs_sp};

            postops_injector_ = utils::make_unique<po_injector_t>(
                    this, brg.attr->post_ops_, bsp);

            // Cache which binary post-op broadcast strategies are actually
            // present so codegen can skip unused pointer bookkeeping.
            using namespace dnnl::impl::cpu::binary_injector_utils;
            std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
                    with_binary_channel_bcast_, with_binary_per_mb_w_bcast_,
                    with_binary_per_w_bcast_, with_binary_no_bcast_)
                    = bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
                            dst_md_wrapper, broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast);
            handle_binary_po_offset_ = with_binary_per_oc_bcast_
                    || with_binary_per_oc_sp_bcast_
                    || with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
                    || with_binary_per_w_bcast_ || with_binary_no_bcast_;
        }
        if (brg.is_bf16_emu)
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1(), bf16_emu_reserv_2(),
                    bf16_emu_reserv_3(), bf16_emu_scratch, bf16_emu_reserv_4(),
                    bf16_emu_reserv_4());
    }

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_kernel_t)

    // Descriptor of the kernel being generated (blocking sizes, data types,
    // post-ops, zero-point configuration, ...).
    brgemm_t brg;

private:
    // For AMX (Wmm == Xbyak::Tmm) accumulation happens in tiles, but the
    // surrounding vector code still needs full Zmm registers.
    using Vmm =
            typename utils::conditional<std::is_same<Wmm, Xbyak::Tmm>::value,
                    Xbyak::Zmm, Wmm>::type;
    using Vmm_lower_t = typename vreg_traits<Vmm>::Vmm_lower_t;
    // ISA the post-ops injector is instantiated with for each kernel ISA.
    static constexpr cpu_isa_t po_isa_t = utils::map(isa, avx512_core,
            avx512_core_amx_fp16, avx512_core_fp16, avx512_core_amx,
            avx512_core_fp16, avx512_core_fp16, avx512_core_fp16, avx2_vnni_2,
            avx2_vnni_2, avx2_vnni, avx2, avx2, avx2);
    using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_t, Vmm>;
    std::unique_ptr<po_injector_t> postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    using reg64_t = const Xbyak::Reg64;

    // Register decomposition
    // NOTE: many names below alias the same physical register (e.g. the
    // long run of aliases of reg_rdb_loop / reg_bdb_loop); the aliased
    // values are spilled to and restored from the stack slots whose
    // offsets are defined further down.
    const reg64_t param1 = abi_param1;

    const reg64_t reg_C = r15;
    const reg64_t reg_aux_C = r14;

    const reg64_t reg_addr_batch = r13;
    const reg64_t reg_A = r13;
    const reg64_t reg_B = r12;

    const reg64_t reg_aux_A = r11;
    const reg64_t reg_aux_B = r10;
    const reg64_t reg_aux_A_vpad = reg_aux_A;

    const reg64_t reg_bdb_loop = r9;
    const reg64_t reg_ldb_loop = r8;

    const reg64_t reg_stride_lda = reg_bdb_loop;
    const reg64_t reg_stride_ldb = reg_ldb_loop;
    const reg64_t reg_stride_ld_block = reg_ldb_loop;
    const reg64_t reg_s8_input_shift = reg_bdb_loop;
    const reg64_t reg_zp_a_input_shift = reg_bdb_loop;

    const reg64_t reg_BS_loop = rax;
    const reg64_t reg_rdb_loop = rbx;
    const reg64_t reg_BS = abi_not_param1;

    const reg64_t reg_a_offset = rdx;
    const reg64_t reg_b_offset = rsi;

    const reg64_t reg_aux1_batch = rbp;
    const reg64_t reg_aux1_A = rbp;
    const reg64_t reg_aux1_B = abi_param1;

    const reg64_t reg_offs_batch = reg_aux1_A;
    const reg64_t reg_strd_batch = reg_rdb_loop;

    const reg64_t reg_bias = reg_rdb_loop;
    const reg64_t reg_scales = reg_rdb_loop;
    const reg64_t reg_aux_bias = reg_rdb_loop;
    const reg64_t reg_binary_postops_oc_l = reg_rdb_loop;
    const reg64_t reg_aux_binary_postops_oc_l = reg_rdb_loop;
    const reg64_t reg_aux_binary_postops_sp = reg_rdb_loop;
    const reg64_t reg_binary_po_stack_frame = reg_rdb_loop;
    const reg64_t reg_zp_comp_a = reg_rdb_loop;
    const reg64_t reg_aux_zp_comp_a = reg_rdb_loop;
    const reg64_t reg_zp_comp_b = reg_rdb_loop;
    const reg64_t reg_aux_zp_comp_b = reg_rdb_loop;
    const reg64_t reg_zp_c_values = reg_rdb_loop;
    const reg64_t reg_aux_zp_c_values = reg_rdb_loop;

    const reg64_t reg_aux_scales = reg_aux_B;
    const reg64_t reg_do_post_ops = reg_rdb_loop;
    const reg64_t reg_do_comp = reg_rdb_loop;
    const reg64_t reg_skip_accm = reg_rdb_loop;
    const reg64_t reg_tmp_gpr = reg_rdb_loop;
    const reg64_t reg_ptr_sum_scale = reg_rdb_loop;
    const reg64_t reg_ptr_sum_zp = reg_bdb_loop;
    const reg64_t reg_zp_a_val = reg_rdb_loop;

    const reg64_t reg_buf = reg_rdb_loop;
    const reg64_t reg_compensation = reg_bias;
    const reg64_t reg_aux_compensation = reg_aux_bias;

    const reg64_t reg_D = reg_aux_A;
    const reg64_t reg_aux_D = reg_BS_loop;

    /* bf16 emulation */
    const reg64_t bf16_emu_scratch = reg_rdb_loop;

    // Offsets (in bytes, relative to rsp) of the spill slots used to keep
    // the values of the aliased registers above; total frame size is
    // stack_space_needed_.
    constexpr static int origin_offs_batch_offs_ = 0;
    constexpr static int origin_strd_batch_offs_ = 0;
    constexpr static int reg_bias_offs_ = 8;
    constexpr static int reg_aux_bias_offs_ = 16;
    constexpr static int reg_do_post_ops_offs_ = 24;
    constexpr static int reg_D_offs_ = 32;
    constexpr static int reg_aux_D_offs_ = 40;
    constexpr static int reg_scales_offs_ = 48;
    constexpr static int reg_aux_scales_offs_ = 56;
    constexpr static int reg_bdb_loop_offs_ = 64;
    constexpr static int reg_ldb_loop_offs_ = 72;
    constexpr static int reg_buf_offs_ = 80;
    constexpr static int reg_comp_offs_ = reg_buf_offs_;
    constexpr static int reg_aux_comp_offs_ = 88;
    constexpr static int abi_param1_offs_ = 96;
    constexpr static int reg_binary_postops_oc_l_offs_ = 104;
    constexpr static int reg_aux_binary_postops_oc_l_offs_ = 112;
    constexpr static int reg_binary_postops_sp_offs_ = 120;
    constexpr static int reg_aux_binary_postops_sp_offs_ = 128;
    constexpr static int reg_zp_comp_a_offs_ = 136;
    constexpr static int reg_aux_zp_comp_a_offs_ = 144;
    constexpr static int reg_zp_comp_b_offs_ = 152;
    constexpr static int reg_aux_zp_comp_b_offs_ = 160;
    constexpr static int reg_zp_c_values_offs_ = 168;
    constexpr static int reg_aux_zp_c_values_offs_ = 176;
    constexpr static int reg_data_C_ptr_ = 184;
    constexpr static int reg_skip_accm_offs_ = 192;
    constexpr static int reg_zp_a_val_offs_ = 200;
    constexpr static int reg_do_comp_offs_ = 208;
    constexpr static int stack_space_needed_ = 216;

    bool is_ldb_loop_ = false;
    bool handle_binary_po_offset_ = false;
    bool with_binary_per_oc_bcast_ = false;
    bool with_binary_per_oc_sp_bcast_ = false;
    bool with_binary_channel_bcast_ = false;
    bool with_binary_per_mb_w_bcast_ = false;
    bool with_binary_per_w_bcast_ = false;
    bool with_binary_no_bcast_ = false;
    constexpr static int max_vregs = cpu_isa_traits<po_isa_t>::n_vregs;

    // k-masks used for full/tail loads and stores along the ld dimension.
    Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
    Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3);

    // Accumulator register for element (bd, ld) of the current tile;
    // accumulators are allocated from the top of the register file down.
    Vmm accm(int ld_block, int bd, int ld) {
        return Vmm(max_vregs - 1 - (bd * ld_block + ld));
    }

    // Broadcast register(s) for A. Two allocation schemes exist depending
    // on n_bcast_1_load: either per-bd broadcast registers below the
    // accumulators, or a single shared Vmm(0).
    Vmm bcst(int bd = 0) {
        if (n_bcast_1_load) {
            int idx = max_vregs - 1 - (brg.ld_block2 * brg.bd_block) - bd;
            assert(idx > 0);
            return Vmm(idx);
        } else
            return Vmm(0);
    }

    // Load register(s) for B; mirrors bcst(): whichever operand is not
    // broadcast per-row gets the register range below the accumulators.
    Vmm load(int ld = 0) {
        if (n_bcast_1_load) {
            return Vmm(0);
        } else {
            int idx = max_vregs - 1 - (brg.ld_block2 * brg.bd_block) - ld;
            assert(idx > 0);
            return Vmm(idx);
        }
    }

    // Scratch/special-purpose vector registers (low indices, which are
    // never used by accm/bcst/load above).
    Vmm vmm_tmp_1() const noexcept { return Vmm(0); }
    Vmm vmm_tmp_2() const noexcept { return Vmm(1); }
    Vmm vmm_tmp_3() const noexcept { return Vmm(2); }
    Vmm vmm_one_bytes() const noexcept { return Vmm(3); }
    Vmm vmm_zp_a_shift() const noexcept { return Vmm(2); }
    Vmm vmm_inp_shift() const noexcept { return Vmm(1); }

    /* bf16 emulation */
    Zmm bf16_emu_reserv_1() const noexcept { return Zmm(0); }
    Zmm bf16_emu_reserv_2() const noexcept { return Zmm(1); }
    Zmm bf16_emu_reserv_3() const noexcept { return Zmm(2); }
    Zmm bf16_emu_reserv_4() const noexcept { return Zmm(3); }
    // note: zmm reserv_5 is not necessary since it's only used for 'vdpbf16ps'

    Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store,
            Xbyak::Opmask ktail_mask) const;
    Vmm_lower_t vmm_lower_mask(const Vmm_lower_t vmm_lower_in, bool mask_flag,
            bool store, Xbyak::Opmask ktail_mask) const;

    void cvt2ps(data_type_t type_in, const Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag, bool store, Xbyak::Opmask ktail_mask,
            int tail_size);

    void advance_ldb_post_op_regs();
    void restore_ldb_post_op_regs(int ld_block2);
    void advance_bdb_post_op_regs(int adj_bd_block);
    void restore_bdb_post_op_regs(int bd_block2);
    void ldb_regs_shift(int ld_block2, bool is_tail = false);
    void advance_bd_block2_post_op_regs(int bd_block2);

    void copy_post_ops_stack_values_to_aux(bool is_reg_tail);
    void read_params();
    void zero_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_ld_tail, bool skip_accumulation);

    void store_accumulators(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_ld_tail, bool skip_accumulation);
    void store_accumulators_without_post_ops(
            int bd_block, int ld_block, bool is_ld_tail);
    void store_accumulators_apply_post_ops(int bd_block, int ld_block,
            int ldb_and_bdb_offset, bool is_ld_tail);
    void apply_compensation(int bd_block, int ld_block, bool is_ld_tail);
    void apply_alpha_beta(int bd_block, int ld_block, bool is_ld_tail);
    void apply_post_ops(int bd_block, int ld_block2, int ldb_and_bdb_offset,
            bool is_ld_tail);
    void restore_A_B_matrices();
    void set_A_B_matrices();

    void gemm_microkernel(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_rd_tail, bool is_ld_tail, int vpad, int rows_for_rd_tail);
    void gemm_microkernel_amx(int bd_block2, bool is_bdb_tail, int ld_block,
            bool is_rd_tail, bool is_ld_tail);

    void ldb_loop(int bd_block2, bool is_bdb_tail, int ld_block,
            int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
            bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
            bool skip_accumulation);
    void bdb_loop();

    void generate() override;

    // Byte-offset helpers into the A/B/C/D buffers and the various
    // post-op/compensation arrays (definitions follow the class).
    int A_offset(int bd, int rd, bool is_amx = false) const noexcept;
    int B_offset(int ld, int rd, bool is_amx = false) const noexcept;
    int C_offset(int bd, int ld) const noexcept;
    int D_offset(int bd, int ld) const noexcept;
    int po_offset(int bd, int ld) const noexcept;

    int rdb_A_offset() const noexcept;
    int rdb_B_offset() const noexcept;

    int ldb_B_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_C_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_D_offset(int ld_block2, bool is_tail = false) const noexcept;
    int ldb_po_offset(int ld_block2, bool is_tail = false) const noexcept;

    int bdb_A_offset(int bd_block2) const noexcept;
    int bdb_C_offset(int bd_block2) const noexcept;
    int bdb_D_offset(int bd_block2) const noexcept;
    int bdb_po_offset(int bd_block2) const noexcept;

    int bias_offset(int ld, bool is_tail = false) const noexcept;
    int oc_logical_offset(int ld, bool is_tail = false) const noexcept;

    int compensations_offset(int ld, bool is_tail = false) const noexcept;
    int bdb_compensation_offset(int bd_block2) const noexcept;
    int compensation_vpad_offset(int ld, int bd) const noexcept;
    int scales_offset(int ld, bool is_tail = false) const noexcept;
    int zp_comp_a_offset(int ld, bool is_tail = false) const noexcept;
    int zp_comp_a_vpad_offset(int ld, int bd) const noexcept;
    int bdb_zp_comp_a_offset(int bd_block2) const noexcept;
    int zp_comp_b_offset(int bd) const noexcept;
    int bdb_zp_comp_b_offset(int bd_block2) const noexcept;
    int zp_c_values_offset(int ld, bool is_tail = false) const noexcept;

    // Selects between the two bcst/load register-allocation schemes above.
    bool n_bcast_1_load = false;
    bool vpad_exist = false;
    bool need_comp_pads = false;
};
359 | |
360 | template <cpu_isa_t isa, typename Wmm> |
361 | int jit_brgemm_kernel_t<isa, Wmm>::A_offset(int bd, int rd, bool is_amx) const |
362 | noexcept { |
363 | return (is_amx) ? brg.typesize_A * (bd * brg.bd_block * brg.LDA) |
364 | : brg.typesize_A * (bd * brg.LDA + rd); |
365 | } |
366 | |
367 | template <cpu_isa_t isa, typename Wmm> |
368 | int jit_brgemm_kernel_t<isa, Wmm>::B_offset(int ld, int rd, bool is_amx) const |
369 | noexcept { |
370 | if (is_amx) { |
371 | return brg.typesize_B * (brg.rd_step * ld * brg.ld_block); |
372 | } else { |
373 | const int data_vnni_granularity = brg.ld_step; |
374 | const int rdb0 = rd / data_vnni_granularity; |
375 | // Note: Offsets for elements within vnni_granularity are expected to be |
376 | // handled within gemm_microkernel (for ex: odd-even converts). |
377 | // hence no `rd % data_vnni_granularity` |
378 | return brg.typesize_B |
379 | * (rdb0 * data_vnni_granularity * brg.LDB |
380 | + data_vnni_granularity * ld * brg.ld_block); |
381 | } |
382 | } |
383 | |
384 | template <cpu_isa_t isa, typename Wmm> |
385 | int jit_brgemm_kernel_t<isa, Wmm>::C_offset(int bd, int ld) const noexcept { |
386 | return brg.typesize_C * (bd * brg.LDC + ld * brg.ld_block); |
387 | } |
388 | |
389 | template <cpu_isa_t isa, typename Wmm> |
390 | int jit_brgemm_kernel_t<isa, Wmm>::D_offset(int bd, int ld) const noexcept { |
391 | return brg.typesize_D * (bd * brg.LDD + ld * brg.ld_block); |
392 | } |
393 | |
394 | template <cpu_isa_t isa, typename Wmm> |
395 | int jit_brgemm_kernel_t<isa, Wmm>::po_offset(int bd, int ld) const noexcept { |
396 | return bd * brg.LDD + ld * brg.ld_block; |
397 | } |
398 | |
399 | template <cpu_isa_t isa, typename Wmm> |
400 | int jit_brgemm_kernel_t<isa, Wmm>::rdb_A_offset() const noexcept { |
401 | return brg.typesize_A * brg.rd_block; |
402 | } |
403 | |
404 | template <cpu_isa_t isa, typename Wmm> |
405 | int jit_brgemm_kernel_t<isa, Wmm>::rdb_B_offset() const noexcept { |
406 | return brg.typesize_B * brg.rd_block * brg.LDB; |
407 | } |
408 | |
409 | template <cpu_isa_t isa, typename Wmm> |
410 | int jit_brgemm_kernel_t<isa, Wmm>::ldb_B_offset( |
411 | int ld_block2, bool is_tail) const noexcept { |
412 | return (is_tail) ? brg.typesize_B * brg.ldb_tail * brg.ld_step |
413 | : brg.typesize_B * ld_block2 * brg.ld_block * brg.ld_step; |
414 | } |
415 | |
416 | template <cpu_isa_t isa, typename Wmm> |
417 | int jit_brgemm_kernel_t<isa, Wmm>::ldb_C_offset( |
418 | int ld_block2, bool is_tail) const noexcept { |
419 | return (is_tail) ? brg.typesize_C * brg.ldb_tail |
420 | : brg.typesize_C * ld_block2 * brg.ld_block; |
421 | } |
422 | |
423 | template <cpu_isa_t isa, typename Wmm> |
424 | int jit_brgemm_kernel_t<isa, Wmm>::ldb_D_offset( |
425 | int ld_block2, bool is_tail) const noexcept { |
426 | return (is_tail) ? brg.typesize_D * brg.ldb_tail |
427 | : brg.typesize_D * ld_block2 * brg.ld_block; |
428 | } |
429 | |
430 | template <cpu_isa_t isa, typename Wmm> |
431 | int jit_brgemm_kernel_t<isa, Wmm>::ldb_po_offset( |
432 | int ld_block2, bool is_tail) const noexcept { |
433 | return (is_tail) ? brg.ldb_tail : ld_block2 * brg.ld_block; |
434 | } |
435 | |
436 | template <cpu_isa_t isa, typename Wmm> |
437 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_A_offset(int bd_block2) const noexcept { |
438 | return brg.typesize_A * bd_block2 * brg.bd_block * brg.LDA; |
439 | } |
440 | |
441 | template <cpu_isa_t isa, typename Wmm> |
442 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_C_offset(int bd_block2) const noexcept { |
443 | return brg.typesize_C * bd_block2 * brg.bd_block * brg.LDC; |
444 | } |
445 | |
446 | template <cpu_isa_t isa, typename Wmm> |
447 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_D_offset(int bd_block2) const noexcept { |
448 | return brg.typesize_D * bd_block2 * brg.bd_block * brg.LDD; |
449 | } |
450 | |
451 | template <cpu_isa_t isa, typename Wmm> |
452 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_po_offset(int bd_block2) const noexcept { |
453 | return bd_block2 * brg.bd_block * brg.LDD; |
454 | } |
455 | |
456 | template <cpu_isa_t isa, typename Wmm> |
457 | int jit_brgemm_kernel_t<isa, Wmm>::bias_offset(int ld, bool is_tail) const |
458 | noexcept { |
459 | return (is_tail) ? brg.typesize_bias * brg.ldb_tail |
460 | : brg.typesize_bias * ld * brg.ld_block; |
461 | } |
462 | |
463 | template <cpu_isa_t isa, typename Wmm> |
464 | int jit_brgemm_kernel_t<isa, Wmm>::oc_logical_offset(int ld, bool is_tail) const |
465 | noexcept { |
466 | return (is_tail) ? brg.ldb_tail : ld * brg.ld_block; |
467 | } |
468 | |
469 | template <cpu_isa_t isa, typename Wmm> |
470 | int jit_brgemm_kernel_t<isa, Wmm>::compensations_offset( |
471 | int ld, bool is_tail) const noexcept { |
472 | return (is_tail) ? sizeof(int32_t) * brg.ldb_tail |
473 | : sizeof(int32_t) * ld * brg.ld_block; |
474 | } |
475 | |
476 | template <cpu_isa_t isa, typename Wmm> |
477 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_compensation_offset(int bd_block2) const |
478 | noexcept { |
479 | return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB; |
480 | } |
481 | |
482 | template <cpu_isa_t isa, typename Wmm> |
483 | int jit_brgemm_kernel_t<isa, Wmm>::compensation_vpad_offset( |
484 | int ld, int bd) const noexcept { |
485 | return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB); |
486 | } |
487 | |
488 | template <cpu_isa_t isa, typename Wmm> |
489 | int jit_brgemm_kernel_t<isa, Wmm>::scales_offset(int ld, bool is_tail) const |
490 | noexcept { |
491 | return (is_tail) ? brg.is_oc_scale * sizeof(float) * brg.ldb_tail |
492 | : brg.is_oc_scale * sizeof(float) * ld * brg.ld_block; |
493 | } |
494 | |
495 | template <cpu_isa_t isa, typename Wmm> |
496 | int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_a_offset(int ld, bool is_tail) const |
497 | noexcept { |
498 | return (is_tail) ? sizeof(int32_t) * brg.ldb_tail |
499 | : sizeof(int32_t) * ld * brg.ld_block; |
500 | } |
501 | |
502 | template <cpu_isa_t isa, typename Wmm> |
503 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_zp_comp_a_offset(int bd_block2) const |
504 | noexcept { |
505 | return sizeof(int32_t) * bd_block2 * brg.bd_block * brg.LDB; |
506 | } |
507 | |
508 | template <cpu_isa_t isa, typename Wmm> |
509 | int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_a_vpad_offset(int ld, int bd) const |
510 | noexcept { |
511 | return sizeof(int32_t) * (ld * brg.ld_block + bd * brg.LDB); |
512 | } |
513 | |
514 | template <cpu_isa_t isa, typename Wmm> |
515 | int jit_brgemm_kernel_t<isa, Wmm>::zp_comp_b_offset(int bd) const noexcept { |
516 | return sizeof(int32_t) * bd; |
517 | } |
518 | |
519 | template <cpu_isa_t isa, typename Wmm> |
520 | int jit_brgemm_kernel_t<isa, Wmm>::bdb_zp_comp_b_offset(int bd_block2) const |
521 | noexcept { |
522 | return zp_comp_b_offset(bd_block2 * brg.bd_block); |
523 | } |
524 | |
525 | template <cpu_isa_t isa, typename Wmm> |
526 | int jit_brgemm_kernel_t<isa, Wmm>::zp_c_values_offset( |
527 | int ld, bool is_tail) const noexcept { |
528 | if (brg.zp_type_c == brgemm_broadcast_t::per_n) { |
529 | return (is_tail) ? sizeof(int32_t) * brg.ldb_tail |
530 | : sizeof(int32_t) * ld * brg.ld_block; |
531 | } |
532 | |
533 | return 0; |
534 | } |
535 | template <cpu_isa_t isa, typename Wmm> |
536 | typename jit_brgemm_kernel_t<isa, Wmm>::Vmm |
537 | jit_brgemm_kernel_t<isa, Wmm>::vmm_mask(const Vmm vmm_in, bool mask_flag, |
538 | bool store, Xbyak::Opmask ktail_mask) const { |
539 | return mask_flag && is_superset(brg.isa_impl, avx512_core) |
540 | ? (store ? vmm_in | ktail_mask : vmm_in | ktail_mask | T_z) |
541 | : vmm_in; |
542 | } |
543 | |
544 | template <cpu_isa_t isa, typename Wmm> |
545 | typename jit_brgemm_kernel_t<isa, Wmm>::Vmm_lower_t |
546 | jit_brgemm_kernel_t<isa, Wmm>::vmm_lower_mask(const Vmm_lower_t vmm_lower_in, |
547 | bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const { |
548 | return mask_flag && is_superset(brg.isa_impl, avx512_core) |
549 | ? (store ? vmm_lower_in | ktail_mask |
550 | : vmm_lower_in | ktail_mask | T_z) |
551 | : vmm_lower_in; |
552 | } |
553 | |
// Emits code that loads `op` into `vmm_in` converted to f32.
// On AVX512 (or when there is no tail) the conversion uses a single
// masked instruction; on lower ISAs a tail falls back to an
// element-by-element load_data() path.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag, bool store,
        Xbyak::Opmask ktail_mask, int tail_size) {
    Vmm vmm = vmm_in;
    // A tail exists only for memory operands whose element count is less
    // than a full vector of f32 lanes.
    const bool has_tail
            = op.isMEM() && tail_size != vreg_traits<Vmm>::vlen / sizeof(float);
    if (IMPLICATION(has_tail, is_superset(brg.isa_impl, avx512_core))) {
        vmm = vmm_mask(vmm_in, mask_flag, store, ktail_mask);
    } else {
        // No k-masks available: zero the register, load `tail_size`
        // elements, convert if integral, and we are done.
        uni_vpxor(vmm_in, vmm_in, vmm_in);
        load_data(type_in, vmm_in, op.getAddress(), tail_size);
        if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
        return;
    }
    // Masked (or full) load with widening to 32-bit lanes per source type.
    switch (type_in) {
        case data_type::f32:
        case data_type::s32: uni_vmovups(vmm, op); break;
        case data_type::bf16:
            uni_vpmovzxwd(vmm, op);
            uni_vpslld(vmm, vmm, 16);
            break;
        case data_type::f16: vcvtph2ps(vmm, op); break;
        case data_type::s8: uni_vpmovsxbd(vmm, op); break;
        case data_type::u8: uni_vpmovzxbd(vmm, op); break;
        default: assert(!"unsupported data type" );
    }
    // Integer inputs still need the int32 -> f32 conversion.
    if (types::is_integral_dt(type_in)) uni_vcvtdq2ps(vmm_in, vmm_in);
}
583 | |
// Emits code that advances every active auxiliary post-op pointer (kept in
// stack spill slots because the gprs are heavily aliased) by one ld block.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::advance_ldb_post_op_regs() {
    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        add(reg_aux_bias, bias_offset(1));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        add(reg_aux_scales, scales_offset(1));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        add(reg_aux_binary_postops_oc_l, oc_logical_offset(1));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        add(reg_aux_zp_comp_a, zp_comp_a_offset(1));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        add(reg_aux_zp_c_values, zp_c_values_offset(1));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}
614 | |
// Emits code that rewinds the auxiliary post-op pointers by the
// (ld_block2 - 1) blocks advanced inside an ld-block loop, restoring each
// spilled value to its position at the first block.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::restore_ldb_post_op_regs(int ld_block2) {
    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        sub(reg_aux_bias, bias_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        sub(reg_aux_scales, scales_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        sub(reg_aux_binary_postops_oc_l, oc_logical_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        sub(reg_aux_zp_comp_a, zp_comp_a_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        sub(reg_aux_zp_c_values, zp_c_values_offset(ld_block2 - 1));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}
645 | |
// Emits code that advances bd-dependent post-op pointers (zp-B
// compensation and per-oc-spatial binary offset) by `adj_bd_block` rows.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::advance_bdb_post_op_regs(int adj_bd_block) {
    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
        add(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(1));
        mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
    }
    if (with_binary_per_oc_sp_bcast_) {
        // The guard pushes the register on the stack, so all rsp-relative
        // spill-slot offsets must be adjusted by the pushed size.
        const injector_utils::register_preserve_guard_t register_guard(
                this, {reg_aux_binary_postops_oc_l});
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                        + register_guard.stack_space_occupied()]);
        add(reg_aux_binary_postops_oc_l, adj_bd_block);
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                    + register_guard.stack_space_occupied()],
                reg_aux_binary_postops_oc_l);
    }
}
665 | |
// Emits code that rewinds the bd-dependent post-op pointers by the
// (bd_block2 - 1) blocks advanced inside a bd-block loop.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::restore_bdb_post_op_regs(int bd_block2) {
    bool post_processed = false;
    if (bd_block2 > 1) {
        if (brg.zp_type_b != brgemm_broadcast_t::none) {
            post_processed = true;
            mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
            sub(reg_aux_zp_comp_b, bdb_zp_comp_b_offset(bd_block2 - 1));
            mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_aux_zp_comp_b);
        }
        if (with_binary_per_oc_sp_bcast_) {
            post_processed = true;
            // Guard pushes the register, shifting rsp; adjust spill-slot
            // offsets accordingly (same pattern as advance_bdb_post_op_regs).
            const injector_utils::register_preserve_guard_t register_guard(
                    this, {reg_aux_binary_postops_oc_l});
            mov(reg_aux_binary_postops_oc_l,
                    ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                            + register_guard.stack_space_occupied()]);
            sub(reg_aux_binary_postops_oc_l, (bd_block2 - 1) * brg.bd_block);
            mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_
                        + register_guard.stack_space_occupied()],
                    reg_aux_binary_postops_oc_l);
        }
    }
    // reg_buf shares a physical register with the gprs clobbered above
    // (all alias reg_rdb_loop), so reload it from its spill slot.
    if (post_processed) mov(reg_buf, ptr[rsp + reg_buf_offs_]);
}
691 | |
// Emits code that shifts the C/D output pointers, the B offset and every
// active auxiliary post-op pointer by `ld_block2` ld blocks (or by the
// tail block when `is_tail` is set).
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::ldb_regs_shift(
        int ld_block2, bool is_tail) {
    int C_offset = (is_tail) ? ldb_C_offset(1, true) : ldb_C_offset(ld_block2);
    int D_offset = (is_tail) ? ldb_D_offset(1, true) : ldb_D_offset(ld_block2);
    add(reg_aux_C, C_offset);
    add(reg_aux_D, D_offset);

    add(reg_b_offset,
            (is_tail) ? ldb_B_offset(1, true) : ldb_B_offset(ld_block2));

    // The remaining pointers live in stack spill slots (their gprs are
    // aliased); each is loaded, advanced and stored back.
    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]);
        add(reg_aux_bias,
                (is_tail) ? bias_offset(1, true) : bias_offset(ld_block2));
        mov(ptr[rsp + reg_aux_bias_offs_], reg_aux_bias);
    }
    if (brg.req_s8s8_compensation) {
        mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]);
        add(reg_aux_compensation,
                (is_tail) ? compensations_offset(1, true)
                          : compensations_offset(ld_block2));
        mov(ptr[rsp + reg_aux_comp_offs_], reg_aux_compensation);
    }
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        add(reg_aux_scales,
                (is_tail) ? scales_offset(1, true) : scales_offset(ld_block2));
        mov(ptr[rsp + reg_aux_scales_offs_], reg_aux_scales);
    }
    if (with_binary_channel_bcast_) {
        const int po_offset
                = (is_tail) ? ldb_po_offset(1, true) : ldb_po_offset(ld_block2);
        mov(reg_aux_binary_postops_sp,
                ptr[rsp + reg_aux_binary_postops_sp_offs_]);
        add(reg_aux_binary_postops_sp, po_offset);
        mov(ptr[rsp + reg_aux_binary_postops_sp_offs_],
                reg_aux_binary_postops_sp);
    }
    if (with_binary_per_oc_bcast_) {
        mov(reg_aux_binary_postops_oc_l,
                ptr[rsp + reg_aux_binary_postops_oc_l_offs_]);
        add(reg_aux_binary_postops_oc_l,
                (is_tail) ? oc_logical_offset(1, true)
                          : oc_logical_offset(ld_block2));
        mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_],
                reg_aux_binary_postops_oc_l);
    }
    if (brg.zp_type_a != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        add(reg_aux_zp_comp_a,
                (is_tail) ? zp_comp_a_offset(1, true)
                          : zp_comp_a_offset(ld_block2));
        mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_aux_zp_comp_a);
    }
    if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        add(reg_aux_zp_c_values,
                (is_tail) ? zp_c_values_offset(1, true)
                          : zp_c_values_offset(ld_block2));
        mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_aux_zp_c_values);
    }
}
755 | |
756 | template <cpu_isa_t isa, typename Wmm> |
757 | void jit_brgemm_kernel_t<isa, Wmm>::advance_bd_block2_post_op_regs( |
758 | int bd_block2) { |
759 | if (with_binary_per_oc_sp_bcast_) { |
760 | mov(reg_aux_binary_postops_oc_l, |
761 | ptr[rsp + reg_binary_postops_oc_l_offs_]); |
762 | add(reg_binary_postops_oc_l, bd_block2 * brg.bd_block); |
763 | mov(ptr[rsp + reg_binary_postops_oc_l_offs_], |
764 | reg_aux_binary_postops_oc_l); |
765 | } |
766 | if (with_binary_channel_bcast_) { |
767 | mov(reg_aux_binary_postops_sp, ptr[rsp + reg_binary_postops_sp_offs_]); |
768 | add(reg_aux_binary_postops_sp, bdb_po_offset(bd_block2)); |
769 | mov(ptr[rsp + reg_binary_postops_sp_offs_], reg_aux_binary_postops_sp); |
770 | } |
771 | if (brg.zp_type_b != brgemm_broadcast_t::none) { |
772 | mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); |
773 | add(reg_zp_comp_b, bdb_zp_comp_b_offset(bd_block2)); |
774 | mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b); |
775 | } |
776 | } |
777 | |
778 | template <cpu_isa_t isa, typename Wmm> |
779 | void jit_brgemm_kernel_t<isa, Wmm>::copy_post_ops_stack_values_to_aux( |
780 | bool is_reg_tail) { |
781 | if (!is_reg_tail) { |
782 | mov(reg_aux_C, reg_C); |
783 | mov(reg_aux_D, reg_D); |
784 | xor_(reg_b_offset, reg_b_offset); |
785 | if (brg.with_bias) { |
786 | mov(reg_bias, ptr[rsp + reg_bias_offs_]); |
787 | mov(ptr[rsp + reg_aux_bias_offs_], reg_bias); |
788 | } |
789 | if (brg.req_s8s8_compensation) { |
790 | mov(reg_compensation, ptr[rsp + reg_comp_offs_]); |
791 | mov(ptr[rsp + reg_aux_comp_offs_], reg_compensation); |
792 | } |
793 | if (brg.with_scales) { |
794 | mov(reg_scales, ptr[rsp + reg_scales_offs_]); |
795 | mov(ptr[rsp + reg_aux_scales_offs_], reg_scales); |
796 | } |
797 | if (with_binary_channel_bcast_) { |
798 | mov(reg_aux_binary_postops_sp, |
799 | ptr[rsp + reg_binary_postops_sp_offs_]); |
800 | mov(ptr[rsp + reg_aux_binary_postops_sp_offs_], |
801 | reg_aux_binary_postops_sp); |
802 | } |
803 | if (with_binary_per_oc_bcast_) { |
804 | mov(reg_binary_postops_oc_l, |
805 | ptr[rsp + reg_binary_postops_oc_l_offs_]); |
806 | mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_], |
807 | reg_binary_postops_oc_l); |
808 | } |
809 | |
810 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
811 | mov(reg_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]); |
812 | mov(ptr[rsp + reg_aux_zp_comp_a_offs_], reg_zp_comp_a); |
813 | } |
814 | |
815 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
816 | mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]); |
817 | mov(ptr[rsp + reg_aux_zp_c_values_offs_], reg_zp_c_values); |
818 | } |
819 | } |
820 | if (brg.zp_type_b != brgemm_broadcast_t::none) { |
821 | mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); |
822 | mov(ptr[rsp + reg_aux_zp_comp_b_offs_], reg_zp_comp_b); |
823 | } |
824 | if (with_binary_per_oc_sp_bcast_) { |
825 | mov(reg_aux_binary_postops_oc_l, |
826 | ptr[rsp + reg_binary_postops_oc_l_offs_]); |
827 | mov(ptr[rsp + reg_aux_binary_postops_oc_l_offs_], |
828 | reg_binary_postops_oc_l); |
829 | } |
830 | } |
831 | |
832 | template <cpu_isa_t isa, typename Wmm> |
833 | void jit_brgemm_kernel_t<isa, Wmm>::read_params() { |
834 | Label label_done; |
835 | |
836 | if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1); |
837 | |
838 | if (brg.type == brgemm_addr) { |
839 | mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]); |
840 | } else { |
841 | if (brg.layout == brgemm_row_major) { |
842 | mov(reg_A, ptr[param1 + GET_OFF(ptr_A)]); |
843 | mov(reg_B, ptr[param1 + GET_OFF(ptr_B)]); |
844 | } else { |
845 | mov(reg_A, ptr[param1 + GET_OFF(ptr_B)]); |
846 | mov(reg_B, ptr[param1 + GET_OFF(ptr_A)]); |
847 | } |
848 | |
849 | if (brg.type == brgemm_offs) { |
850 | mov(reg_offs_batch, ptr[param1 + GET_OFF(batch)]); |
851 | mov(ptr[rsp + origin_offs_batch_offs_], reg_offs_batch); |
852 | } else { |
853 | mov(reg_strd_batch, ptr[param1 + GET_OFF(batch)]); |
854 | mov(ptr[rsp + origin_strd_batch_offs_], reg_strd_batch); |
855 | } |
856 | } |
857 | |
858 | mov(reg_C, ptr[param1 + GET_OFF(ptr_C)]); |
859 | mov(reg_D, ptr[param1 + GET_OFF(ptr_D)]); |
860 | mov(reg_BS, ptr[param1 + GET_OFF(BS)]); |
861 | |
862 | // ptr_buf is re-used for passing compensations for |
863 | // brg.req_s8s8_compensation case |
864 | if (brg.is_tmm || brg.req_s8s8_compensation) { |
865 | mov(reg_buf, ptr[param1 + GET_OFF(ptr_buf)]); |
866 | mov(ptr[rsp + reg_buf_offs_], reg_buf); |
867 | } |
868 | |
869 | if (brg.with_bias) { |
870 | mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]); |
871 | mov(ptr[rsp + reg_bias_offs_], reg_bias); |
872 | } |
873 | if (brg.with_scales) { |
874 | mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]); |
875 | mov(ptr[rsp + reg_scales_offs_], reg_scales); |
876 | } |
877 | if (with_binary_no_bcast_) { |
878 | mov(reg_aux_binary_postops_sp, ptr[param1 + GET_OFF(data_C_ptr_)]); |
879 | mov(ptr[rsp + reg_data_C_ptr_], reg_aux_binary_postops_sp); |
880 | } |
881 | if (with_binary_channel_bcast_) { |
882 | mov(reg_aux_binary_postops_sp, |
883 | ptr[param1 + GET_OFF(first_mb_matrix_addr_off)]); |
884 | mov(ptr[rsp + reg_binary_postops_sp_offs_], reg_aux_binary_postops_sp); |
885 | } |
886 | if (with_binary_per_oc_bcast_) { |
887 | mov(reg_binary_postops_oc_l, ptr[param1 + GET_OFF(oc_logical_off)]); |
888 | mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_binary_postops_oc_l); |
889 | } else if (with_binary_per_oc_sp_bcast_) { |
890 | mov(reg_binary_postops_oc_l, |
891 | ptr[param1 + GET_OFF(dst_row_logical_off)]); |
892 | mov(ptr[rsp + reg_binary_postops_oc_l_offs_], reg_binary_postops_oc_l); |
893 | } |
894 | |
895 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
896 | mov(reg_zp_comp_a, ptr[param1 + GET_OFF(a_zp_compensations)]); |
897 | mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a); |
898 | } |
899 | |
900 | if (brg.zp_type_b != brgemm_broadcast_t::none) { |
901 | mov(reg_zp_comp_b, ptr[param1 + GET_OFF(b_zp_compensations)]); |
902 | mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b); |
903 | } |
904 | |
905 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
906 | mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]); |
907 | mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values); |
908 | } |
909 | |
910 | mov(reg_do_post_ops, ptr[param1 + GET_OFF(do_post_ops)]); |
911 | mov(ptr[rsp + reg_do_post_ops_offs_], reg_do_post_ops); |
912 | |
913 | mov(reg_skip_accm, ptr[param1 + GET_OFF(skip_accm)]); |
914 | mov(ptr[rsp + reg_skip_accm_offs_], reg_skip_accm); |
915 | |
916 | mov(reg_zp_a_val, ptr[param1 + GET_OFF(zp_a_val)]); |
917 | mov(ptr[rsp + reg_zp_a_val_offs_], reg_zp_a_val); |
918 | |
919 | mov(reg_do_comp, ptr[param1 + GET_OFF(do_apply_comp)]); |
920 | mov(ptr[rsp + reg_do_comp_offs_], reg_do_comp); |
921 | } |
922 | |
923 | template <cpu_isa_t isa, typename Wmm> |
924 | void jit_brgemm_kernel_t<isa, Wmm>::zero_accumulators(int bd_block2, |
925 | bool is_bdb_tail, int ld_block2, bool is_ld_tail, |
926 | bool skip_accumulation) { |
927 | if (brg.is_tmm) { |
928 | // avoid usage of tile registers if there is no accumulation |
929 | if (skip_accumulation) return; |
930 | for_(int bdb = 0; bdb < bd_block2; bdb++) |
931 | for (int ldb = 0; ldb < ld_block2; ldb++) { |
932 | int idx = (is_ld_tail) ? brg.ld_block2 : ldb; |
933 | tilezero(Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail, is_ld_tail))); |
934 | } |
935 | } else { |
936 | int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; |
937 | for_(int bd = 0; bd < bd_block; bd++) |
938 | for (int ld = 0; ld < ld_block2; ld++) { |
939 | auto vmm = accm(ld_block2, bd, ld); |
940 | uni_vpxor(vmm, vmm, vmm); |
941 | } |
942 | } |
943 | } |
944 | |
// Apply acc = alpha * acc + beta * C to the live accumulator registers.
// No-op when alpha == 1 and beta == 0. For int8 kernels the s32
// accumulators are converted to f32 first, unless the pure integer-add
// shortcut (beta == 1, alpha == 1) applies.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_alpha_beta(
        int bd_block, int ld_block2, bool is_ld_tail) {
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;
    auto vmm_beta = vmm_tmp_1();
    auto vmm_alpha = vmm_tmp_2();
    auto vmm_prev_dst = vmm_tmp_3();

    const bool apply_alpha = brg.alpha != 1.f;
    const bool apply_beta = brg.beta != 0.f;
    if (!apply_alpha && !apply_beta) return;

    // dq2ps: int8 accumulators hold s32 and must be converted to f32 unless
    // only a plain integer add of C is needed (beta == 1, no alpha).
    const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f);
    const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;

    if (apply_beta && !use_vadd_for_beta) {
        // Broadcast the beta scalar into a full vector register.
        mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta)));
        uni_vmovq(Xmm(vmm_beta.getIdx()), reg_tmp_gpr);
        uni_vbroadcastss(vmm_beta, Xmm(vmm_beta.getIdx()));
    }
    if (apply_alpha) {
        // Broadcast the alpha scalar into a full vector register.
        mov(reg_tmp_gpr, float2int(static_cast<float>(brg.alpha)));
        uni_vmovq(Xmm(vmm_alpha.getIdx()), reg_tmp_gpr);
        uni_vbroadcastss(vmm_alpha, Xmm(vmm_alpha.getIdx()));
    }
    for_(int bd = 0; bd < bd_block; bd++)
    for (int ld = 0; ld < ld_block2; ld++) {
        // Only the last ld sub-block can be the masked tail.
        const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
        auto vmm = accm(ld_block2, bd, ld);
        if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm);
        if (apply_alpha) uni_vmulps(vmm, vmm, vmm_alpha);
        if (apply_beta) {
            auto ptr_C = ptr[reg_aux_C + C_offset(bd, ld)];
            if (use_vadd_for_beta) {
                // Masked memory-operand add needs avx512; otherwise load the
                // partial C row into a temporary first.
                if (IMPLICATION(
                            is_tail, is_superset(brg.isa_impl, avx512_core))) {
                    auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask);
                    if (brg.is_int8)
                        uni_vpaddd(vmm_masked, vmm, ptr_C);
                    else
                        uni_vaddps(vmm_masked, vmm, ptr_C);
                } else {
                    load_data(brg.dt_c, vmm_prev_dst, ptr_C, ld_size);
                    if (brg.is_int8)
                        uni_vpaddd(vmm, vmm, vmm_prev_dst);
                    else
                        uni_vaddps(vmm, vmm, vmm_prev_dst);
                }
            } else {
                // acc += beta * C, with C converted to f32 as needed.
                cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, true, false, k_mask,
                        ld_size);
                uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_beta);
            }
        }
    }
}
1002 | |
// Run the post-ops injector over the live accumulator registers.
// For binary post-ops the original abi_param1 is restored from its stack
// slot (offset adjusted by the space the preserve guard pushed) and each
// accumulator vmm is mapped to its output address in D. Sum post-op is
// implemented via a custom lambda injector that also handles the sum
// scale and sum zero point.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_post_ops(
        int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) {

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;

    // param1 is clobbered below, so preserve it — but only when binary
    // post-ops actually make use of it.
    const injector_utils::conditional_register_preserve_guard_t register_guard(
            brg.with_binary, this, {param1});
    const auto guard_space = register_guard.stack_space_occupied();
    if (brg.with_binary) {
        mov(param1, ptr[rsp + abi_param1_offs_ + guard_space]);

        if (handle_binary_po_offset_) {
            for_(int bd = 0; bd < bd_block; bd++)
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto vmm_idx = accm(ld_block2, bd, ld).getIdx();

                // Tell the injector where each accumulator lands in D.
                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_aux_D);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, D_offset(bd, ld));
                if (is_ld_tail) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }

    const auto sum_injector = [&] {
        const float *p_sum_scale = &brg.sum_scale;
        const int32_t *p_sum_zp = &brg.sum_zp;
        const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
        const bool p_sum_zp_reg_set = *p_sum_zp != 0;

        {
            // Preserve the GPRs used to address sum scale / zero point for
            // the duration of this scope.
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard_sum_scale(
                            (handle_binary_po_offset_) && p_sum_scale_reg_set,
                            this, {reg_ptr_sum_scale});
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard_sum_zp(
                            p_sum_zp_reg_set, this, {reg_ptr_sum_zp});

            if (p_sum_scale_reg_set)
                mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));

            const auto &vmm_sum_zp = vmm_tmp_2();
            if (p_sum_zp_reg_set) {
                // Broadcast the s32 sum zero point and convert to f32.
                mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
                if (is_superset(brg.isa_impl, avx512_core)) {
                    vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
                } else {
                    uni_vpbroadcastd(vmm_sum_zp, ptr[reg_ptr_sum_zp]);
                    uni_vcvtdq2ps(vmm_sum_zp, vmm_sum_zp);
                }
            }

            const auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
            const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

            for (int bd = 0; bd < bd_block; bd++) {
                for (int ld = 0; ld < ld_block2; ld++) {
                    const auto vmm = accm(ld_block2, bd, ld);
                    const auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
                    const auto vmm_prev_dst = vmm_tmp_1();
                    // acc += sum_scale * (D - sum_zp)
                    cvt2ps(brg.sum_dt, vmm_prev_dst, addr, true, false, k_mask,
                            ld_size);
                    if (p_sum_zp_reg_set)
                        uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_sum_zp);
                    if (!p_sum_scale_reg_set)
                        uni_vaddps(vmm, vmm, vmm_prev_dst);
                    else {
                        if (is_superset(brg.isa_impl, avx512_core)) {
                            // fma with an embedded-broadcast memory operand.
                            uni_vfmadd231ps(vmm, vmm_prev_dst,
                                    ptr_b[reg_ptr_sum_scale]);
                        } else {
                            auto vmm_tmp = vmm_tmp_2();
                            uni_vpbroadcastd(vmm_tmp, ptr[reg_ptr_sum_scale]);
                            uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_tmp);
                        }
                    }
                }
            }
        }
    };

    if (brg.with_sum) {
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }

    // Accumulators occupy the top of the vmm file:
    // [max_vregs - bd_block * ld_block2, max_vregs).
    postops_injector_->compute_vector_range(
            max_vregs - bd_block * ld_block2, max_vregs, rhs_arg_params);
}
1094 | |
// Full epilogue for one accumulator group: convert int8 accumulators to
// f32 where needed, apply scales, bias, injector post-ops, dst zero point,
// then saturate/convert and store to D in brg.dt_d.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators_apply_post_ops(
        int bd_block, int ld_block2, int ldb_and_bdb_offset, bool is_ld_tail) {
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

    // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
    // accumulated values are already converted to ps in apply_alpha_beta()
    const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f;
    const bool beta_uses_vadd
            = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f);
    const bool dq2ps_required = brg.is_int8
            && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd);

    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_aux_scales_offs_]);
        for (int bd = 0; bd < bd_block; bd++) {
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto addr = ptr[reg_aux_scales + scales_offset(ld)];
                auto vmm = accm(ld_block2, bd, ld);
                if (dq2ps_required) uni_vcvtdq2ps(vmm, vmm);
                if (is_superset(brg.isa_impl, avx512_core)) {
                    // Multiply by the per-ld scales directly from memory.
                    const Vmm vmm_masked = vmm_mask(vmm, true, false, k_mask);
                    uni_vmulps(vmm_masked, vmm, addr);
                } else {
                    auto vmm_scales = vmm_tmp_1();
                    load_data(data_type::f32, vmm_scales, addr, ld_size);
                    uni_vmulps(vmm, vmm, vmm_scales);
                }
            }
        }
    }

    if (brg.with_bias) { mov(reg_aux_bias, ptr[rsp + reg_aux_bias_offs_]); }
    for_(int bd = 0; bd < bd_block; bd++)
    for (int ld = 0; ld < ld_block2; ld++) {
        auto vmm = accm(ld_block2, bd, ld);
        // If scales were applied the conversion already happened above.
        if (dq2ps_required && !brg.with_scales) uni_vcvtdq2ps(vmm, vmm);
        if (brg.with_bias) {
            auto vmm_bias = vmm_tmp_1();
            auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)];
            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, true, false, k_mask,
                    ld_size);
            uni_vaddps(vmm, vmm, vmm_bias);
        }
    }

    if (postops_injector_)
        apply_post_ops(bd_block, ld_block2, ldb_and_bdb_offset, is_ld_tail);

    if (brg.zp_type_c != brgemm_broadcast_t::none) {
        mov(reg_aux_zp_c_values, ptr[rsp + reg_aux_zp_c_values_offs_]);
        auto vmm_zp_c = vmm_tmp_1();
        if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
            // Single zero point broadcast across all columns.
            if (is_superset(brg.isa_impl, avx512_core)) {
                uni_vcvtdq2ps(vmm_zp_c,
                        EVEX_compress_addr(reg_aux_zp_c_values, 0, true));
            } else {
                uni_vpbroadcastd(vmm_zp_c, ptr[reg_aux_zp_c_values]);
                uni_vcvtdq2ps(vmm_zp_c, vmm_zp_c);
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
            if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                // Reload a per-column zero point vector for each ld.
                int zp_c_off = zp_c_values_offset(ld);
                auto zp_c_addr
                        = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off);
                cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, true, false, k_mask,
                        ld_size);
            }
            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vaddps(vmm, vmm, vmm_zp_c);
            }
        }
    }

    const bool dt_requires_saturation
            = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
    auto vmm_lbound = vmm_tmp_1();
    auto vmm_ubound = vmm_tmp_2();
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d);
    }

    if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

    for (int bd = 0; bd < bd_block; bd++) {
        if (dt_requires_saturation) {
            // Clamp to the dst range and convert f32 -> s32 before the
            // down-converting stores below.
            for (int ld = 0; ld < ld_block2; ld++) {
                auto vmm = accm(ld_block2, bd, ld);
                saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                uni_vcvtps2dq(vmm, vmm);
            }
        }
        for (int ld = 0; ld < ld_block2; ld++) {
            auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
            auto vmm = accm(ld_block2, bd, ld);
            auto vmm_lower = Vmm_lower_t(vmm.getIdx());
            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
            const Vmm r_vmm = vmm_mask(vmm, is_tail, true, k_mask);
            const Vmm_lower_t r_ymm
                    = vmm_lower_mask(vmm_lower, is_tail, true, k_mask);
            if (is_superset(brg.isa_impl, avx512_core)) {
                switch (brg.dt_d) {
                    case data_type::f32:
                    case data_type::s32: uni_vmovups(addr, r_vmm); break;
                    case data_type::bf16: // TODO - clean
                        if (brg.is_bf16_emu) {
                            bf16_emu_->vcvtneps2bf16(vmm_lower, vmm);
                            vmovdqu16(addr, r_ymm);
                        } else {
                            vcvtneps2bf16(vmm_lower, vmm);
                            vmovdqu16(addr, r_ymm);
                        }
                        break;
                    case data_type::f16:
                        vcvtps2ph(vmm_lower, vmm, _op_mxcsr);
                        vmovdqu16(addr, r_ymm);
                        break;
                    case data_type::s8: vpmovsdb(addr, r_vmm); break;
                    case data_type::u8: vpmovusdb(addr, r_vmm); break;
                    default: assert(!"unknown dst_dt" );
                }
            } else {
                // Pre-avx512 path: generic element-count-aware store.
                const int ld_block = is_tail ? brg.ldb_tail : brg.ld_block;
                store_data(
                        brg.dt_d, vmm, reg_aux_D, D_offset(bd, ld), ld_block);
            }
        }
    }
}
1228 | |
// Add the precomputed compensation terms (src zero point, weights zero
// point, s8s8 shift) to the s32 accumulators while they are still integer,
// before any s32 -> f32 conversion happens in the epilogue.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::apply_compensation(
        int bd_block, int ld_block2, bool is_ld_tail) {
    // apply compensation to accumulated values
    // to avoid the loss of accuracy when converting s32 to f32
    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
    const int ld_size = is_ld_tail ? brg.ldb_tail : brg.ld_block;

    if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
        // acc += zp_a_val * zp_comp_a[ld] for every column block.
        auto vmm_zp_a_val = vmm_tmp_2();
        mov(reg_zp_a_val, ptr[rsp + reg_zp_a_val_offs_]);
        uni_vpbroadcastd(vmm_zp_a_val, reg_zp_a_val.cvt32());

        mov(reg_aux_zp_comp_a, ptr[rsp + reg_aux_zp_comp_a_offs_]);
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm_zp_comp_a = vmm_tmp_1();
            int zp_comp_a_off = zp_comp_a_offset(ld);
            auto zp_comp_a_addr
                    = EVEX_compress_addr(reg_aux_zp_comp_a, zp_comp_a_off);
            // apply src zero points value to the accumulated values
            if (is_superset(brg.isa_impl, avx512_core)) {
                auto vmm_zp_comp_a_masked
                        = vmm_mask(vmm_zp_comp_a, true, false, k_mask);
                uni_vmovups(vmm_zp_comp_a_masked, zp_comp_a_addr);
            } else {
                load_data(
                        data_type::s32, vmm_zp_comp_a, zp_comp_a_addr, ld_size);
            }
            uni_vpmulld(vmm_zp_comp_a, vmm_zp_comp_a, vmm_zp_a_val);

            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, vmm_zp_comp_a);
            }
        }
    }

    if (brg.zp_type_b != brgemm_broadcast_t::none) {
        // Per-row term broadcast across the columns of each row.
        mov(reg_aux_zp_comp_b, ptr[rsp + reg_aux_zp_comp_b_offs_]);
        for (int bd = 0; bd < bd_block; bd++) {
            int zp_comp_b_off = zp_comp_b_offset(bd);
            auto zp_comp_b_addr = EVEX_compress_addr(
                    reg_aux_zp_comp_b, zp_comp_b_off, true);
            for (int ld = 0; ld < ld_block2; ld++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, zp_comp_b_addr);
            }
        }
    }

    if (!brg.req_cal_comp_pads && brg.req_s8s8_compensation) {
        mov(reg_aux_compensation, ptr[rsp + reg_aux_comp_offs_]);
        for (int ld = 0; ld < ld_block2; ld++) {
            auto vmm_comp = vmm_tmp_1();
            int comp_offset = compensations_offset(ld);
            auto comp_addr
                    = EVEX_compress_addr(reg_aux_compensation, comp_offset);
            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
            if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) {
                // NOTE(review): the mask flag here is hard-coded 'true'
                // rather than 'is_tail' as in apply_alpha_beta() — confirm
                // this is intentional for the non-tail avx512 path.
                vmm_comp = vmm_mask(vmm_comp, true, false, k_mask);
                uni_vmovups(vmm_comp, comp_addr);
            } else {
                load_data(data_type::s32, vmm_comp, comp_addr, ld_size);
            }
            for (int bd = 0; bd < bd_block; bd++) {
                auto vmm = accm(ld_block2, bd, ld);
                uni_vpaddd(vmm, vmm, vmm_comp);
            }
        }
    }
}
1300 | |
1301 | template <cpu_isa_t isa, typename Wmm> |
1302 | void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators_without_post_ops( |
1303 | int bd_block, int ld_block2, bool is_ld_tail) { |
1304 | |
1305 | // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) -> |
1306 | // accumulated values are converted to ps in apply_alpha_beta() |
1307 | const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f; |
1308 | const bool beta_uses_vadd |
1309 | = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f); |
1310 | const bool dt_requires_saturation = brg.is_int8 |
1311 | && !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); |
1312 | auto vmm_lbound = vmm_tmp_1(); |
1313 | auto vmm_ubound = vmm_tmp_2(); |
1314 | if (dt_requires_saturation) { |
1315 | init_saturate_f32( |
1316 | vmm_lbound, vmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d); |
1317 | } |
1318 | |
1319 | for (int bd = 0; bd < bd_block; bd++) { |
1320 | if (dt_requires_saturation) { |
1321 | for (int ld = 0; ld < ld_block2; ld++) { |
1322 | auto vmm = accm(ld_block2, bd, ld); |
1323 | saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d); |
1324 | uni_vcvtps2dq(vmm, vmm); |
1325 | } |
1326 | } |
1327 | for (int ld = 0; ld < ld_block2; ld++) { |
1328 | auto vmm = accm(ld_block2, bd, ld); |
1329 | const auto addr_c = ptr[reg_aux_C + C_offset(bd, ld)]; |
1330 | if (is_ld_tail) { |
1331 | if (is_superset(brg.isa_impl, avx512_core)) { |
1332 | uni_vmovups(addr_c | ld_tail_mask | T_z, vmm); |
1333 | } else { |
1334 | store_data(brg.dt_c, vmm, reg_aux_C, C_offset(bd, ld), |
1335 | brg.ldb_tail); |
1336 | } |
1337 | } else |
1338 | uni_vmovups(ptr[reg_aux_C + C_offset(bd, ld)], vmm); |
1339 | } |
1340 | } |
1341 | } |
1342 | |
// Top-level accumulator store dispatch. Chooses between the AMX tile path
// (spill tiles to a scratch buffer, then post-process row by row) and the
// vector path, and between the post-op and plain-store epilogues based on
// the runtime do_post_ops / do_apply_comp flags.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::store_accumulators(int bd_block2,
        bool is_bdb_tail, int ld_block2, bool is_ld_tail,
        bool skip_accumulation) {
    const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none,
            brg.zp_type_a, brg.zp_type_b, brg.zp_type_c);
    const bool are_post_ops_applicable = one_of(true, brg.with_eltwise,
            brg.with_binary, brg.with_scales, brg.with_bias, brg.with_sum,
            brg.dt_d != brg.dt_c, brg.req_s8s8_compensation, has_zero_points);
    const bool need_to_apply_alpha_beta = brg.beta != 0.f || brg.alpha != 1.f;

    if (brg.is_tmm) {
        // Tile stores go through the dense scratch buffer when any further
        // processing is needed, or straight to C (stride LDC) otherwise.
        if (need_to_apply_alpha_beta || are_post_ops_applicable)
            mov(reg_stride_ld_block, brg.ld_block * brg.typesize_C);
        else
            mov(reg_stride_ld_block, brg.LDC * brg.typesize_C);

        auto store_accumulators_amx = [=](const bool apply_post_ops) {
            mov(reg_buf, ptr[rsp + reg_buf_offs_]);
            for (int bdb = 0; bdb < bd_block2; bdb++) {
                int adj_bd_block = (brg.is_M_tail && is_bdb_tail)
                        ? brg.bdb_tail
                        : brg.bd_block;
                for (int ldb = 0; ldb < ld_block2; ldb++) {
                    int idx = (is_ld_tail) ? brg.ld_block2 : ldb;
                    if (need_to_apply_alpha_beta || are_post_ops_applicable) {
                        if (skip_accumulation) {
                            // No tile data: post-process zeroed vregs.
                            for (int bd = 0; bd < adj_bd_block; bd++) {
                                auto vreg_acc = accm(1, bd, 0);
                                uni_vpxor(vreg_acc, vreg_acc, vreg_acc);
                            }
                        } else {
                            // Spill the tile, then reload it row by row
                            // into vector registers for the epilogue.
                            tilestored(ptr[reg_buf + reg_stride_ld_block],
                                    Tmm(brg.get_C_tensor(bdb, idx, is_bdb_tail,
                                            is_ld_tail)));
                            for (int bd = 0; bd < adj_bd_block; bd++) {
                                size_t buf_offset
                                        = (bd * brg.ld_block) * brg.typesize_C;
                                auto vreg_acc = is_ld_tail
                                        ? accm(1, bd, 0) | ld_tail_mask | T_z
                                        : accm(1, bd, 0);
                                uni_vmovups(
                                        vreg_acc, ptr[reg_buf + buf_offset]);
                            }
                        }
                        if (need_to_apply_alpha_beta)
                            apply_alpha_beta(adj_bd_block, 1, is_ld_tail);

                        if (apply_post_ops) {
                            const size_t ldb_and_bdb_offset
                                    = ldb_po_offset(ldb) + bdb_po_offset(bdb);
                            store_accumulators_apply_post_ops(adj_bd_block, 1,
                                    ldb_and_bdb_offset, is_ld_tail);
                            if (ldb < ld_block2 - 1) advance_ldb_post_op_regs();
                            add(reg_aux_D, ldb_D_offset(1));
                        } else {
                            store_accumulators_without_post_ops(
                                    adj_bd_block, 1, is_ld_tail);
                        }
                        // The epilogue may clobber reg_buf; reload it.
                        mov(reg_buf, ptr[rsp + reg_buf_offs_]);
                    } else {
                        // Fast path: store the tile directly into C.
                        auto tmm = Tmm(brg.get_C_tensor(
                                bdb, idx, is_bdb_tail, is_ld_tail));
                        if (skip_accumulation) tilezero(tmm);
                        tilestored(ptr[reg_aux_C + reg_stride_ld_block], tmm);
                    }
                    add(reg_aux_C, ldb_C_offset(1));
                }
                // Rewind columns, step one row block down.
                sub(reg_aux_C, ldb_C_offset(ld_block2));
                add(reg_aux_C, bdb_C_offset(1));
                if (apply_post_ops) {
                    sub(reg_aux_D, ldb_D_offset(ld_block2));
                    add(reg_aux_D, bdb_D_offset(1));

                    bool post_processed = false;
                    if (ld_block2 > 1) {
                        restore_ldb_post_op_regs(ld_block2);
                        post_processed |= utils::one_of(true, brg.with_bias,
                                brg.with_scales, with_binary_per_oc_bcast_,
                                brg.zp_type_a != brgemm_broadcast_t::none,
                                brg.zp_type_c == brgemm_broadcast_t::per_n);
                    }
                    if (bdb < bd_block2 - 1) {
                        advance_bdb_post_op_regs(adj_bd_block);
                        post_processed |= utils::one_of(true,
                                brg.zp_type_b != brgemm_broadcast_t::none,
                                with_binary_per_oc_sp_bcast_);
                    }
                    // Register shuffling above may clobber reg_buf.
                    if (post_processed) mov(reg_buf, ptr[rsp + reg_buf_offs_]);
                }
            }
            // Rewind to the group origin for the caller.
            sub(reg_aux_C, bdb_C_offset(bd_block2));
            if (apply_post_ops) {
                sub(reg_aux_D, bdb_D_offset(bd_block2));
                restore_bdb_post_op_regs(bd_block2);
            }
        };

        Label label_done;
        if (are_post_ops_applicable) {
            // Branch at runtime on the do_post_ops kernel parameter.
            Label label_store_without_post_ops;
            mov(reg_do_post_ops, ptr[rsp + reg_do_post_ops_offs_]);
            cmp(reg_do_post_ops, 0);
            jz(label_store_without_post_ops, T_NEAR);

            store_accumulators_amx(true);
            jmp(label_done, T_NEAR);

            L_aligned(label_store_without_post_ops);
        }
        store_accumulators_amx(false);
        L_aligned(label_done);
    } else {
        int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;

        if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) {
            // Compensation is applied only when do_apply_comp is set.
            Label label_store_without_comp;
            mov(reg_do_comp, ptr[rsp + reg_do_comp_offs_]);
            cmp(reg_do_comp, 0);
            jz(label_store_without_comp, T_NEAR);
            apply_compensation(bd_block, ld_block2, is_ld_tail);

            L_aligned(label_store_without_comp);
        }

        if (need_to_apply_alpha_beta)
            apply_alpha_beta(bd_block, ld_block2, is_ld_tail);

        Label label_done;
        if (are_post_ops_applicable) {
            // Branch at runtime on the do_post_ops kernel parameter.
            Label label_store_without_post_ops;
            mov(reg_do_post_ops, ptr[rsp + reg_do_post_ops_offs_]);
            cmp(reg_do_post_ops, 0);
            jz(label_store_without_post_ops, T_NEAR);
            store_accumulators_apply_post_ops(
                    bd_block, ld_block2, 0, is_ld_tail);
            jmp(label_done, T_NEAR);

            L_aligned(label_store_without_post_ops);
        }
        store_accumulators_without_post_ops(bd_block, ld_block2, is_ld_tail);
        L_aligned(label_done);
    }
}
1487 | |
1488 | template <cpu_isa_t isa, typename Wmm> |
1489 | void jit_brgemm_kernel_t<isa, Wmm>::restore_A_B_matrices() { |
1490 | auto restore_reg_batch = brg.brgattr.max_bs > 1 || vpad_exist; |
1491 | if (brg.type == brgemm_addr) { |
1492 | if (restore_reg_batch) mov(reg_aux1_batch, reg_addr_batch); |
1493 | } else { |
1494 | mov(reg_aux1_A, reg_A); |
1495 | mov(reg_aux1_B, reg_B); |
1496 | |
1497 | if (restore_reg_batch) { |
1498 | if (brg.type == brgemm_offs) |
1499 | mov(reg_offs_batch, ptr[rsp + origin_offs_batch_offs_]); |
1500 | else |
1501 | mov(reg_strd_batch, ptr[rsp + origin_strd_batch_offs_]); |
1502 | } |
1503 | } |
1504 | } |
1505 | |
// Compute reg_aux_A / reg_aux_B for the current batch element, according
// to the batch kind: explicit address list (brgemm_addr), base + per-item
// offsets (brgemm_offs), or fixed strides (brgemm_strd). The accumulated
// reg_a_offset / reg_b_offset block offsets are added at the end.
template <cpu_isa_t isa, typename Wmm>
void jit_brgemm_kernel_t<isa, Wmm>::set_A_B_matrices() {
    if (brg.type == brgemm_addr) {
        if (brg.brgattr.max_bs > 1) {
            // Column-major swaps the A/B fields read from the batch element.
            if (brg.layout == brgemm_row_major) {
                mov(reg_aux_A,
                        ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
                mov(reg_aux_B,
                        ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
            } else {
                mov(reg_aux_A,
                        ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
                mov(reg_aux_B,
                        ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
            }
        } else {
            // for max_batch == 1 we stored A and B pointers at the beginning
            // of kernel in reg_aux1_A and reg_aux1_B
            if (brg.layout == brgemm_row_major) {
                mov(reg_aux_A, reg_aux1_A);
                mov(reg_aux_B, reg_aux1_B);
            } else {
                mov(reg_aux_A, reg_aux1_B);
                mov(reg_aux_B, reg_aux1_A);
            }
        }

        if (brg.brgattr.max_bs > 1) {
            // Step to the next batch element and prefetch it.
            add(reg_aux1_batch, sizeof(brgemm_batch_element_t));
            prefetcht0(ptr[reg_aux1_batch]);
        }
    } else if (brg.type == brgemm_offs) {
        mov(reg_aux_A, reg_A);
        mov(reg_aux_B, reg_B);

        add(reg_aux_A, ptr[reg_offs_batch + GET_OFF_BATCH_ELEMENT(offset.A)]);
        add(reg_aux_B, ptr[reg_offs_batch + GET_OFF_BATCH_ELEMENT(offset.B)]);
        add(reg_offs_batch, sizeof(brgemm_batch_element_t));
    } else if (brg.type == brgemm_strd) {
        mov(reg_aux_A, reg_aux1_A);
        mov(reg_aux_B, reg_aux1_B);

        // safe_add handles strides that do not fit an imm32.
        safe_add(reg_aux1_A, brg.stride_a, reg_tmp_gpr);
        safe_add(reg_aux1_B, brg.stride_b, reg_tmp_gpr);
        if (vpad_exist) {
            // Keep the batch cursor in its stack slot in sync for vpad use.
            mov(reg_strd_batch, ptr[rsp + origin_strd_batch_offs_]);
            add(reg_strd_batch, sizeof(brgemm_batch_element_t));
            mov(ptr[rsp + origin_strd_batch_offs_], reg_strd_batch);
        }
    }

    // Apply the block offsets accumulated by the ldb/rdb loops.
    add(reg_aux_A, reg_a_offset);
    add(reg_aux_B, reg_b_offset);
}
1560 | |
template <cpu_isa_t isa, typename Wmm>
// AMX (tile) flavor of the inner gemm microkernel: for each reduce-dim block,
// load the A tiles and B tiles and issue one tile dot-product per
// (bd-block, ld-block) pair accumulating into the C tiles.
void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel_amx(int bd_block2,
        bool is_bdb_tail, int ld_block2, bool is_rd_tail, bool is_ld_tail) {
    // Pick the tile dot-product instruction matching the (dt_a, dt_b) pair.
    auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
        if (brg.dt_a == data_type::bf16 && brg.dt_b == data_type::bf16) {
            tdpbf16ps(x1, x2, x3);
        } else if (brg.dt_a == data_type::f16 && brg.dt_b == data_type::f16) {
            tdpfp16ps(x1, x2, x3);
        } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8) {
            tdpbuud(x1, x2, x3);
        } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8) {
            tdpbusd(x1, x2, x3);
        } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8) {
            tdpbsud(x1, x2, x3);
        } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8) {
            tdpbssd(x1, x2, x3);
        } else {
            assert(!"unsupported combination" );
        }
    };

    // Use the non-temporal tile load (tileloaddt1) when the hinted working
    // set (A + B + C) does not fit in the per-core level-1 cache and this
    // operand sits on the innermost loop; otherwise use a regular tileloadd.
    auto maybe_tileloadd_nt = [=](const Tmm &t1, reg64_t base, int offset,
                                      reg64_t stride, bool try_load_nt) {
        if (try_load_nt
                && static_cast<size_t>(
                           brg.typesize_A * brg.brgattr.hint_expected_A_size
                           + brg.typesize_B * brg.brgattr.hint_expected_B_size
                           + brg.typesize_C * brg.brgattr.hint_expected_C_size)
                        >= platform::get_per_core_cache_size(1))
            tileloaddt1(t1, ptr[base + offset + stride]);
        else
            tileloadd(t1, ptr[base + offset + stride]);
    };

    int rbd_block = (is_rd_tail) ? 1 : brg.rdb;
    for (int rdb = 0; rdb < rbd_block; rdb++) {
        // Load all A tiles for this reduce block first ...
        for (int bdb = 0; bdb < bd_block2; bdb++) {
            maybe_tileloadd_nt(Tmm(brg.get_A_tensor(bdb, is_bdb_tail)),
                    reg_aux_A, rdb * rdb_A_offset() + A_offset(bdb, 0, true),
                    reg_stride_lda,
                    brg.brgattr.hint_innermost_loop
                            == brgemm_bd_loop_innermost);
        }
        // ... then, per B tile, accumulate into all C tiles in its column.
        for (int ldb = 0; ldb < ld_block2; ldb++) {
            // The tail uses a dedicated B tensor slot (index brg.ld_block2).
            const int idx = (is_ld_tail) ? brg.ld_block2 : ldb;
            maybe_tileloadd_nt(Tmm(brg.get_B_tensor(idx, is_ld_tail)),
                    reg_aux_B, rdb * rdb_B_offset() + B_offset(ldb, 0, true),
                    reg_stride_ldb,
                    brg.brgattr.hint_innermost_loop
                            == brgemm_ld_loop_innermost);
            for (int bdb = 0; bdb < bd_block2; bdb++) {
                tdpbxxd(Tmm(brg.get_C_tensor(
                                bdb, idx, is_bdb_tail, is_ld_tail)),
                        Tmm(brg.get_A_tensor(bdb, is_bdb_tail)),
                        Tmm(brg.get_B_tensor(idx, is_ld_tail)));
            }
        }
    }
    // Advance A/B past the full reduce blocks consumed above; the tail call
    // (is_rd_tail == true) must not advance since the caller handles it.
    if (!is_rd_tail) {
        add(reg_aux_A, brg.rdb * rdb_A_offset());
        add(reg_aux_B, brg.rdb * rdb_B_offset());
    }
}
1624 | |
template <cpu_isa_t isa, typename Wmm>
// Vector-register flavor of the inner gemm microkernel for one
// (bd_block, ld_block2) block: iterates the reduce dimension in rd_step
// chunks, broadcasting A elements and loading B vectors, accumulating into
// the accm() registers. Also computes int8 zero-point / s8s8 compensation
// for padded rows when required. `vpad` restricts the valid bd row range.
void jit_brgemm_kernel_t<isa, Wmm>::gemm_microkernel(int bd_block2,
        bool is_bdb_tail, int ld_block2, bool is_rd_tail, bool is_ld_tail,
        int vpad, int rows_for_rd_tail) {
    MAYBE_UNUSED(bd_block2);
    // One multiply-accumulate step, selected by data type:
    // fma for f32 (and f16 / bf16 already converted to f32 on avx2_vnni_2),
    // vdpbf16ps for bf16, vpdpbusd for int8.
    auto dot_product = [=](Vmm v1, Vmm v2, Vmm v3) {
        if (brg.is_f32 || brg.is_f16
                || (brg.is_bf16 && brg.isa_impl == avx2_vnni_2))
            uni_vfmadd231ps(v1, v2, v3);
        else if (brg.is_bf16)
            vdpbf16ps(v1, v2, v3);
        else if (brg.is_int8)
            vpdpbusd(v1, v3, v2, isa == avx2_vnni ? VexEncoding : EvexEncoding);
    };

    // Effective bd row range after applying virtual padding: vpad > 0 trims
    // rows from the top, vpad < 0 from the bottom.
    int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
    const auto bd_b = nstl::max(0, vpad);
    const auto bd_e = nstl::min(bd_block, bd_block + vpad);
    const auto is_valid_bd
            = need_comp_pads && vpad != 0 ? bd_b <= bd_e : bd_b < bd_e;
    if (!is_valid_bd) return;

    bool is_emdbd = brg.embd_bcst;

    // For bf16/int8 the reduce tail is rounded up to a whole rd_step chunk
    // (the vnni/dot-product granularity); leftover bytes are handled by the
    // zero-padded tail loads below.
    int rd_loop = 0, rd_tail_size = 0;
    if (is_rd_tail) {
        if (brg.is_bf16 || brg.is_int8) {
            rd_tail_size = brg.rdb_tail % brg.rd_step;
            rd_loop = (rd_tail_size != 0)
                    ? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step
                    : brg.rdb_tail;
        } else
            rd_loop = brg.rdb_tail;
    } else
        rd_loop = brg.rd_block;

    // Broadcast one A element (rd_step bytes) into v1. The tail path loads
    // only the valid bytes (zero-filling the rest) to avoid reading past the
    // end of A.
    auto broadcast = [=](Vmm v1, size_t offset, bool is_tail, data_type_t dt) {
        if (is_tail) {
            uni_vpxor(v1, v1, v1);
            Xmm xmm_tmp = Xmm(v1.getIdx());
            load_bytes(
                    xmm_tmp, reg_aux_A, offset, rd_tail_size * brg.typesize_A);
            uni_vpbroadcastd(v1, xmm_tmp);
        } else {
            if (dt == data_type::f32) {
                uni_vbroadcastss(v1, ptr[reg_aux_A + offset]);
            } else if (dt == data_type::bf16) {
                // avx2_vnni_2 converts to f32 at load time; otherwise the
                // packed pair of bf16 values is broadcast as one dword.
                if (brg.isa_impl == avx2_vnni_2)
                    vbcstnebf162ps(v1, ptr[reg_aux_A + offset]);
                else
                    uni_vpbroadcastd(v1, ptr[reg_aux_A + offset]);
            } else if (one_of(dt, data_type::s8, data_type::u8)) {
                uni_vpbroadcastd(v1, ptr[reg_aux_A + offset]);
            } else if (dt == data_type::f16) {
                if (brg.isa_impl == avx2_vnni_2)
                    vbcstnesh2ps(v1, ptr[reg_aux_A + offset]);
                else
                    vcvtph2psx(v1, ptr_b[reg_aux_A + offset]);
            }
        }

        // s8s8 compensation: shift the A bytes by vmm_inp_shift (set to 128
        // in ldb_loop()) so vpdpbusd can treat them as unsigned.
        if (brg.req_s8s8_compensation) uni_vpaddb(v1, v1, vmm_inp_shift());
    };

    auto compensation_padding
            = [=](Vmm vmm_load, Vmm vmm_tmp, int ld, int bd_b, int bd_e) {
    /* req_cal_comp_pads -> only calculate compensation along with computation
     * and do not use pre-calculate compensation, calculate comp padding as:
     * accum - inp_shift * conv(1, wei_s32) */
                  if (brg.req_s8s8_compensation) {
                      if (brg.req_cal_comp_pads) {
                          uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
                          dot_product(vmm_tmp, vmm_load, vmm_inp_shift());
                      }

                      for (int bd = bd_b; bd < bd_e; bd++) {
                          auto vmm = accm(ld_block2, bd, ld);
                          if (brg.req_cal_comp_pads) {
                              uni_vpsubd(vmm, vmm, vmm_tmp);
                          } else {
                              dot_product(vmm, vmm_load, vmm_inp_shift());
                          }
                      }
                  }

                  // Zero-point compensation for A: comp = zp_a_val * sum(B),
                  // computed as dot(B, ones) * zp_a_shift.
                  if (brg.zp_type_a != brgemm_broadcast_t::none) {
                      uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
                      dot_product(vmm_tmp, vmm_load, vmm_one_bytes());
                      uni_vpmulld(vmm_tmp, vmm_tmp, vmm_zp_a_shift());

                      for (int bd = bd_b; bd < bd_e; bd++) {
                          auto vmm = accm(ld_block2, bd, ld);
                          if (brg.req_cal_comp_pads) {
                              uni_vpsubd(vmm, vmm, vmm_tmp);
                          } else {
                              uni_vpaddd(vmm, vmm, vmm_tmp);
                          }
                      }
                  }
              };

    // Compensation pass over B for padded rows (int8 only).
    if (brg.req_cal_comp_pads
            || (vpad != 0
                    && (brg.req_s8s8_compensation
                            || brg.zp_type_a != brgemm_broadcast_t::none))) {
        // only used for int8 compensation related things.
        assert(brg.is_int8);
        // In the n_bcast_1_load layout the zp helper vregs were not set up
        // by ldb_loop(), so materialize them here (reg_bdb_loop is borrowed
        // as scratch and spilled around the sequence).
        if (n_bcast_1_load && brg.zp_type_a != brgemm_broadcast_t::none) {
            mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
            const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
            mov(reg32_scratch, 0x1010101);
            uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch);
            mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]);
            uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch);
            mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
        }

        for_(int rd = 0; rd < rd_loop; rd += brg.rd_step)
        for (int ld = 0; ld < ld_block2; ++ld) {
            const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
            const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
            auto vmm_store = load();
            if (IMPLICATION(is_tail, is_superset(brg.isa_impl, avx512_core))) {
                vmm_store = vmm_mask(vmm_store, is_tail, false, ld_tail_mask);
                uni_vmovups(vmm_store, addr);
            } else {
                // No opmask available: load exactly the tail bytes.
                load_bytes(load(), addr,
                        brg.typesize_B * brg.ldb_tail * brg.ld_step);
            }

            if (brg.req_cal_comp_pads) {
                compensation_padding(vmm_store, bcst(), ld, bd_b, bd_e);
            } else if (vpad != 0) {
                // Only the padded row ranges need compensation here.
                if (bd_b > 0)
                    compensation_padding(vmm_store, bcst(), ld, 0, bd_b);
                if (bd_e < bd_block)
                    compensation_padding(vmm_store, bcst(), ld, bd_e, bd_block);
            }
        }
    }

    // Rows near the end of the rd tail must use the byte-exact (zero padded)
    // A load to avoid reading out of bounds.
    bool maybe_load_bytes = (rows_for_rd_tail > 0 || brg.brgattr.wary_tail_read)
            && is_rd_tail && rd_tail_size != 0 && (brg.is_bf16 || brg.is_int8);
    if (n_bcast_1_load) {
        // Layout: broadcast all A rows first (one bcst vreg per row), then a
        // single B load per ld feeds all rows.
        for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
            bool have_to_load_bytes
                    = maybe_load_bytes && (rd == rd_loop - brg.rd_step);

            auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0;
            for (int bd = bd_b; bd < bd_e && !is_emdbd; bd++) {
                const auto bd_by_load_bytes = (bd >= bd_e - rows_by_load_bytes
                        || brg.brgattr.wary_tail_read);
                broadcast(bcst(bd), A_offset(bd, rd),
                        have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
            }
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
                const Vmm vmm_load
                        = vmm_mask(load(), is_ld_tail, false, ld_tail_mask);
                // Note: Assuming the tails are properly padded/blocked for
                // avx2_vnni_2, as the B matrix is generally
                // at least double-blocked.
                if (brg.dt_b == data_type::f16) {
                    if (brg.isa_impl == avx2_vnni_2) {
                        // Even/odd element conversion pairs with the
                        // even/odd A broadcasts of avx2_vnni_2.
                        if (rd % 2 == 0)
                            vcvtneeph2ps(vmm_load, addr);
                        else
                            vcvtneoph2ps(vmm_load, addr);
                    } else
                        vcvtph2psx(vmm_load, addr);
                } else if (brg.dt_b == data_type::bf16
                        && brg.isa_impl == avx2_vnni_2) {
                    if (rd % 2 == 0)
                        vcvtneebf162ps(vmm_load, addr);
                    else
                        vcvtneobf162ps(vmm_load, addr);
                } else if (is_ld_tail) {
                    if (is_superset(brg.isa_impl, avx512_core)) {
                        uni_vmovups(vmm_load, addr);
                    } else {
                        load_bytes(vmm_load, addr,
                                brg.typesize_B * brg.ldb_tail * brg.ld_step);
                    }
                } else {
                    uni_vmovups(vmm_load, addr);
                }
                for (int bd = bd_b; bd < bd_e; bd++) {
                    auto vmm = accm(ld_block2, bd, ld);
                    if (is_emdbd)
                        // Embedded broadcast: fold the A element directly
                        // into the fma memory operand.
                        uni_vfmadd231ps(vmm, load(),
                                ptr_b[reg_aux_A + A_offset(bd, rd)]);
                    else
                        dot_product(vmm, load(), bcst(bd));
                }
            }
        }
    } else {
        // Layout: load all B vectors first (one load vreg per ld), then one
        // A broadcast per row feeds all ld columns.
        for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
            int prefetch_count_B = 0;
            for (int ld = 0; ld < ld_block2; ld++) {
                const auto addr = ptr[reg_aux_B + B_offset(ld, rd)];
                const Vmm vmm_load
                        = vmm_mask(load(ld), is_ld_tail, false, ld_tail_mask);
                // Note: Assuming the tails are properly padded/blocked for
                // avx2_vnni_2, as the B matrix is generally
                // at least double-blocked.
                if (brg.dt_b == data_type::f16) {
                    if (brg.isa_impl == avx2_vnni_2) {
                        if (rd % 2 == 0)
                            vcvtneeph2ps(vmm_load, addr);
                        else
                            vcvtneoph2ps(vmm_load, addr);
                    } else {
                        vcvtph2psx(vmm_load, addr);
                    }
                } else if (brg.dt_b == data_type::bf16
                        && brg.isa_impl == avx2_vnni_2) {
                    if (rd % 2 == 0)
                        vcvtneebf162ps(vmm_load, addr);
                    else
                        vcvtneobf162ps(vmm_load, addr);
                } else if (is_ld_tail) {
                    if (is_superset(brg.isa_impl, avx512_core)) {
                        uni_vmovups(vmm_load, addr);
                    } else {
                        load_bytes(vmm_load, addr,
                                brg.typesize_B * brg.ldb_tail * brg.ld_step);
                    }
                } else {
                    uni_vmovups(vmm_load, addr);
                }
            }

            bool have_to_load_bytes
                    = maybe_load_bytes && (rd == rd_loop - brg.rd_step);

            auto rows_by_load_bytes = have_to_load_bytes ? rows_for_rd_tail : 0;
            for (int bd = bd_b; bd < bd_e; bd++) {
                if (!is_emdbd) {
                    const auto bd_by_load_bytes
                            = (bd >= bd_e - rows_by_load_bytes
                                    || brg.brgattr.wary_tail_read);
                    broadcast(bcst(), A_offset(bd, rd),
                            have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
                }
                // Interleave one prefetch of the next rd block of B per row.
                if (prefetch_count_B < ld_block2) {
                    prefetcht0(ptr[reg_aux_B + B_offset(prefetch_count_B++, rd)
                            + brg.LDB * brg.rd_block * brg.typesize_B]);
                }
                for (int ld = 0; ld < ld_block2; ld++) {
                    auto vmm = accm(ld_block2, bd, ld);
                    if (is_emdbd)
                        uni_vfmadd231ps(vmm, load(ld),
                                ptr_b[reg_aux_A + A_offset(bd, rd)]);
                    else
                        dot_product(vmm, load(ld), bcst());
                }
            }
        }
    }
}
1886 | |
template <cpu_isa_t isa, typename Wmm>
// Emit the loop over ld (N-dimension) blocks for one bd block: zero the
// accumulators, run the batch (BS) loop of microkernels (with virtual
// padding dispatch when requested), then store the accumulators.
// `ldb_loop_length` is the ld iteration count; `is_ld_tail` selects the
// masked-tail variant.
void jit_brgemm_kernel_t<isa, Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
        int ld_block2, int ldb_loop_length, bool is_reg_tail, bool is_ld_tail,
        bool check_top_vpad, bool check_bottom_vpad, int rows_for_rd_tail,
        bool skip_accumulation) {

    Label ldb_loop_label;
    Label BS_loop_label;

    copy_post_ops_stack_values_to_aux(is_reg_tail);

    // Accumulate the contribution of one batch element with the given
    // virtual padding applied to the bd rows.
    auto ld_loop_body = [=](int vpad) {
        set_A_B_matrices();

        // Skip entirely if padding leaves no valid rows (same predicate as
        // in gemm_microkernel()).
        int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
        const auto bd_b = nstl::max(0, vpad);
        const auto bd_e = nstl::min(bd_block, bd_block + vpad);
        const auto is_valid_bd
                = need_comp_pads && vpad != 0 ? bd_b <= bd_e : bd_b < bd_e;
        if (!is_valid_bd) return;

        if (brg.is_tmm) {
            const bool is_rd_tail = false;
            gemm_microkernel_amx(
                    bd_block2, is_bdb_tail, ld_block2, is_rd_tail, is_ld_tail);
        } else {
            // Full reduce-dimension blocks in a runtime loop ...
            if (brg.rdb > 0) {
                Label rdb_loop_label;
                mov(reg_rdb_loop, brg.rdb);
                L_aligned(rdb_loop_label, 64);
                {
                    const bool is_rd_tail = false;
                    gemm_microkernel(bd_block2, is_bdb_tail, ld_block2,
                            is_rd_tail, is_ld_tail, vpad, rows_for_rd_tail);

                    add(reg_aux_A, rdb_A_offset());
                    add(reg_aux_B, rdb_B_offset());

                    dec(reg_rdb_loop);
                    cmp(reg_rdb_loop, 0);
                }
                jg(rdb_loop_label, T_NEAR);
            }
        }
        // ... followed by the reduce tail, if any.
        if (brg.rdb_tail != 0) {
            const bool is_rd_tail = true;
            if (brg.is_tmm) {
                gemm_microkernel_amx(bd_block2, is_bdb_tail, ld_block2,
                        is_rd_tail, is_ld_tail);
            } else {
                gemm_microkernel(bd_block2, is_bdb_tail, ld_block2, is_rd_tail,
                        is_ld_tail, vpad, rows_for_rd_tail);
            }
        }
    };
    if (is_ldb_loop_) {
        mov(reg_ldb_loop, ldb_loop_length);
        // On AMX reg_ldb_loop is clobbered inside the body, so keep a copy
        // on the stack.
        if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
    }

    L_aligned(ldb_loop_label, 64);
    {
        zero_accumulators(bd_block2, is_bdb_tail, ld_block2, is_ld_tail,
                skip_accumulation);

        // Preserve reg_D across the BS loop: on the stack when looping over
        // ldb, otherwise in the (then unused) reg_ldb_loop register.
        if (is_ldb_loop_)
            mov(ptr[rsp + reg_D_offs_], reg_D);
        else {
            mov(reg_ldb_loop, reg_D);
            if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
        }
        if (brg.brgattr.max_bs > 1) mov(ptr[rsp + reg_aux_D_offs_], reg_aux_D);

        // alpha == 0 (or skip_accumulation) means C is not read: the zeroed
        // accumulators go straight to the store below.
        if (brg.alpha != 0.f && !skip_accumulation) {
            restore_A_B_matrices();
            if (brg.is_tmm) {
                mov(reg_stride_lda, brg.typesize_A * brg.LDA);
                mov(reg_stride_ldb, brg.rd_step * brg.typesize_B * brg.LDB);
            }

            // Set up the s8s8 input-shift constant (+128 bytes) used by the
            // microkernel's A broadcasts.
            if (brg.req_s8s8_compensation) {
                mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
                mov(reg_s8_input_shift, 128);
                uni_vpbroadcastb(vmm_inp_shift(), reg_s8_input_shift.cvt8());
                mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
            }
            // Set up ones vector and zero-point value used by the padding
            // compensation code.
            if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
                mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
                const auto reg32_scratch = reg_zp_a_input_shift.cvt32();
                mov(reg32_scratch, 0x1010101);
                uni_vpbroadcastd(vmm_one_bytes(), reg32_scratch);
                mov(reg32_scratch, ptr[rsp + reg_zp_a_val_offs_]);
                uni_vpbroadcastd(vmm_zp_a_shift(), reg32_scratch);
                mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
            }

            if (brg.brgattr.max_bs > 1) mov(reg_BS_loop, reg_BS);
            L_aligned(BS_loop_label, 64);
            {
                if (check_top_vpad || check_bottom_vpad) {
                    // Runtime dispatch on the current element's virtual
                    // padding: one compare-and-jump per possible vpad value,
                    // ending in the vpad == 0 fallthrough.
                    const auto vpad_first = -brg.brgattr.max_bottom_vpad;
                    const auto vpad_last = brg.brgattr.max_top_vpad;
                    const auto n_vpads = vpad_last - vpad_first + 2;
                    constexpr auto MAX_N_VPADS = 2 * brgemm_t::MAX_VPAD;
                    assert(n_vpads < MAX_N_VPADS);

                    Label Vpad_loop_end_label;
                    std::vector<Label> Vpad_loop_iter_label(MAX_N_VPADS);
                    if (vpad_exist) {
                        reg64_t reg_batch = (brg.type == brgemm_addr)
                                ? reg_aux1_batch
                                : ((brg.type == brgemm_offs) ? reg_offs_batch
                                                             : reg_strd_batch);
                        if (brg.type == brgemm_strd)
                            mov(reg_strd_batch,
                                    ptr[rsp + origin_strd_batch_offs_]);

                        // Encode (top - bottom) padding as a single value to
                        // compare against.
                        mov(reg_aux_A_vpad,
                                ptr[reg_batch
                                        + GET_OFF_BATCH_ELEMENT(vvpad.top)]);
                        sub(reg_aux_A_vpad,
                                ptr[reg_batch
                                        + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
                    } else
                        xor_(reg_aux_A_vpad, reg_aux_A_vpad);

                    for (int vpad = vpad_first; vpad <= vpad_last; vpad++) {
                        const auto label_vpad = vpad - vpad_first;
                        L(Vpad_loop_iter_label[label_vpad]);
                        if (!check_top_vpad && vpad > 0) continue;
                        if (!check_bottom_vpad && vpad < 0) continue;
                        auto real_vpad = vpad;
                        if (check_bottom_vpad && brg.bdb_tail) {
                            if (!is_bdb_tail) {
                                // for last full block before
                                // bdb_tail && -vpad greater than bdb_tail
                                if (brg.bdb_tail < -vpad)
                                    real_vpad += brg.bdb_tail;
                                else
                                    continue;
                            } else {
                                // for block with tail, call ldb_loop()
                                // to only calculate compensation for
                                // padding area when bdb_tail < -vpad for
                                // the cases using pre-cal compensation
                                if (brg.bdb_tail < -vpad && need_comp_pads
                                        && !brg.req_cal_comp_pads)
                                    real_vpad = -brg.bdb_tail;
                            }
                        }
                        cmp(reg_aux_A_vpad, vpad);
                        jne(Vpad_loop_iter_label[label_vpad + 1], T_NEAR);
                        ld_loop_body(real_vpad);
                        jmp(Vpad_loop_end_label, T_NEAR);
                    }
                    // Default case: no padding matched.
                    L(Vpad_loop_iter_label[n_vpads - 1]);
                    ld_loop_body(0);
                    L(Vpad_loop_end_label);
                } else {
                    ld_loop_body(0);
                }
                if (brg.brgattr.max_bs > 1) {
                    dec(reg_BS_loop);
                    cmp(reg_BS_loop, 0);
                    jg(BS_loop_label, T_NEAR);
                }
            }
        }

        // Restore reg_D from wherever it was preserved above.
        if (is_ldb_loop_)
            mov(reg_D, ptr[rsp + reg_D_offs_]);
        else {
            if (brg.is_tmm) mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
            mov(reg_D, reg_ldb_loop);
        }
        if (brg.brgattr.max_bs > 1) mov(reg_aux_D, ptr[rsp + reg_aux_D_offs_]);

        store_accumulators(bd_block2, is_bdb_tail, ld_block2, is_ld_tail,
                skip_accumulation);
        if (is_ldb_loop_) {
            if (brg.is_tmm) mov(reg_ldb_loop, ptr[rsp + reg_ldb_loop_offs_]);
            // Advance B/C/D pointers to the next group of ld blocks.
            if (!is_ld_tail)
                ldb_regs_shift(ld_block2);
            else
                ldb_regs_shift(1, true);
            dec(reg_ldb_loop);
            cmp(reg_ldb_loop, 0);
            if (brg.is_tmm) mov(ptr[rsp + reg_ldb_loop_offs_], reg_ldb_loop);
            jg(ldb_loop_label, T_NEAR);
        }
    }
}
2079 | |
template <cpu_isa_t isa, typename Wmm>
// Emit the outer loop over bd (M-dimension) blocks. Chooses between the AMX
// and vector-register loop bodies, decides the A-broadcast/B-load ordering
// (n_bcast_1_load), and optionally emits a second, skip-accumulation code
// path selected at runtime.
void jit_brgemm_kernel_t<isa, Wmm>::bdb_loop() {
    // Run ldb_loop() over the full ld_block2 groups, the partial group, and
    // the masked ld tail, in that order.
    auto do_ldb_loop = [=](int bd_block2, bool is_bdb_tail, bool check_top_vpad,
                               bool check_bottom_vpad, int rows_for_rd_tail,
                               bool skip_accumulation) {
        if (brg.ldb2 > 0) {
            const bool is_ld_reg_tail = false;
            const bool is_ld_tail = false;
            ldb_loop(bd_block2, is_bdb_tail, brg.ld_block2, brg.ldb2,
                    is_ld_reg_tail, is_ld_tail, check_top_vpad,
                    check_bottom_vpad, rows_for_rd_tail, skip_accumulation);
        }
        if (brg.ldb2_tail > 0) {
            const bool is_ld_reg_tail = (brg.ldb2 == 0) ? false : true;
            const bool is_ld_tail = false;
            ldb_loop(bd_block2, is_bdb_tail, brg.ldb2_tail, 1, is_ld_reg_tail,
                    is_ld_tail, check_top_vpad, check_bottom_vpad,
                    rows_for_rd_tail, skip_accumulation);
        }
        if (brg.ldb_tail > 0) {
            const bool is_ld_reg_tail
                    = (brg.ldb2 == 0 && brg.ldb2_tail == 0) ? false : true;
            const bool is_ld_tail = true;
            ldb_loop(bd_block2, is_bdb_tail, 1, 1, is_ld_reg_tail, is_ld_tail,
                    check_top_vpad, check_bottom_vpad, rows_for_rd_tail,
                    skip_accumulation);
        }
    };

    // One bd-block iteration: all ld work plus pointer advancement.
    auto bdb_loop_body = [=](int bd_block2, bool is_bdb_tail,
                                 bool check_top_vpad, bool check_bottom_vpad,
                                 int rows_for_rd_tail, bool skip_accumulation) {
        do_ldb_loop(bd_block2, is_bdb_tail, check_top_vpad, check_bottom_vpad,
                rows_for_rd_tail, skip_accumulation);

        add(reg_C, bdb_C_offset(bd_block2));
        add(reg_D, bdb_D_offset(bd_block2));
        add(reg_a_offset, bdb_A_offset(bd_block2));

        advance_bd_block2_post_op_regs(bd_block2);
    };

    int rows_for_rd_tail, bd_blocks_for_rd_tail;

    if (brg.is_tmm) {
        rows_for_rd_tail = 0;
        bd_blocks_for_rd_tail = 0;
        n_bcast_1_load = false;
    } else {
        // How many trailing rows (and bd blocks) need byte-exact tail loads
        // of A to stay within bounds (bf16/int8 vnni packing only).
        rows_for_rd_tail = 0;
        if (brg.rdb_tail != 0 && (brg.is_bf16 || brg.is_int8)) {
            const auto rd_tail_size = brg.rdb_tail % brg.rd_step;
            rows_for_rd_tail = rd_tail_size
                    ? div_up(brg.rd_step - rd_tail_size, brg.reduce_dim)
                    : 0;
        }
        bd_blocks_for_rd_tail
                = div_up(nstl::max(0,
                                 rows_for_rd_tail - brg.bdb_tail
                                         + brg.brgattr.max_bottom_vpad),
                        brg.bd_block);

        // Heuristic: use the "broadcast A rows first, single B load" layout
        // for int8 when enough vregs are free; a loop-order hint overrides.
        auto ld_block2 = (brg.ldb2 > 0)
                ? brg.ld_block2
                : ((brg.ldb2_tail > 0) ? brg.ldb2_tail : 1);
        const int free_vregs = max_vregs - brg.req_s8s8_compensation;
        n_bcast_1_load = brg.is_int8
                && ((brg.bd_block * (ld_block2 + 1) < free_vregs)
                        && (bd_blocks_for_rd_tail == 0)
                        && (rows_for_rd_tail == 0));
        if (brg.brgattr.hint_loop_order != brgemm_lo_default)
            n_bcast_1_load = (brg.brgattr.hint_loop_order == brgemm_lo_bl_1load)
                    ? true
                    : false;
    }

    // Vector-register bd loop. Virtual padding can only affect the first
    // (top) and last (bottom) bd blocks, so those are peeled off; the middle
    // blocks run in a compact runtime loop.
    auto bdb_loop_avx512 = [=](bool skip_accumulation) {
        Label bdb_loop_end_label, no_vpad_label;
        if (vpad_exist) {
            // max_top_vp is restricted by bd_block due to
            // brgemm_kernel implementation. TODO: remove this restriction
            assert(brg.brgattr.max_top_vpad <= brg.bd_block
                    && brg.brgattr.max_bottom_vpad <= brg.bd_block);

            if (brg.type == brgemm_strd) {
                // if batch is nullptr then it means no vpadding in this call
                cmp(reg_offs_batch, 0);
                je(no_vpad_label, T_NEAR);
            }

            // first bd_block --------------
            auto bdblocks = brg.bdb;
            if (bdblocks >= 1) {
                bdb_loop_body(1, false, true, brg.bdb == 1 && brg.bdb_tail == 0,
                        brg.bdb - bd_blocks_for_rd_tail > 0 ? 0
                                                            : rows_for_rd_tail,
                        skip_accumulation);
                bdblocks--;
            }
            if (bdblocks > 1) {
                // middle bd_blocks -----------
                Label bdb_loop_label;
                mov(reg_bdb_loop, bdblocks);
                L_aligned(bdb_loop_label, 64);
                {
                    bdb_loop_body(1, false, false, false,
                            bd_blocks_for_rd_tail <= 1 ? 0 : rows_for_rd_tail,
                            skip_accumulation);
                    dec(reg_bdb_loop);
                    cmp(reg_bdb_loop, 1);
                    jg(bdb_loop_label, T_NEAR);
                }
                bdblocks = 1;
            }
            if (bdblocks == 1) {
                // last bd_block ------------
                bdb_loop_body(1, false, false, true,
                        bd_blocks_for_rd_tail == 0 ? 0 : rows_for_rd_tail,
                        skip_accumulation);
            }
            if (brg.bdb_tail > 0)
                do_ldb_loop(1, true, brg.bdb < 1, true, rows_for_rd_tail,
                        skip_accumulation);
            // for brgemm_strd "no vpadding" case may be implemented, so skip it
            if (brg.type == brgemm_strd) jmp(bdb_loop_end_label);
        }
        if (!vpad_exist || brg.type == brgemm_strd) {
            // for brgemm_strd batch may be null so we need this code path
            L_aligned(no_vpad_label, 64);
            if (brg.bdb > 0) {
                mov(reg_bdb_loop, brg.bdb);
                if (brg.bdb > (rows_for_rd_tail ? 1 : 0)) {
                    Label bdb_loop_label;
                    L_aligned(bdb_loop_label, 64);
                    {
                        bdb_loop_body(1, false, false, false,
                                bd_blocks_for_rd_tail <= 1 ? 0
                                                           : rows_for_rd_tail,
                                skip_accumulation);
                        dec(reg_bdb_loop);
                        cmp(reg_bdb_loop, rows_for_rd_tail ? 1 : 0);
                        jg(bdb_loop_label, T_NEAR);
                    }
                }

                // Peel the last block when it needs the careful rd-tail
                // handling.
                if (rows_for_rd_tail)
                    bdb_loop_body(1, false, false, true,
                            bd_blocks_for_rd_tail == 0 ? 0 : rows_for_rd_tail,
                            skip_accumulation);
            }
            if (brg.bdb_tail > 0)
                do_ldb_loop(1, true, false, false, rows_for_rd_tail,
                        skip_accumulation);
        }
        L_aligned(bdb_loop_end_label, 64);
    };
    // AMX bd loop: iterate bd_block2 groups of tiles; the loop counter is
    // kept on the stack because reg_bdb_loop is clobbered inside the body.
    auto bdb_loop_amx = [=](bool skip_accumulation) {
        Label bdb_loop_label;
        if (brg.bd_block2 >= 1) {
            mov(reg_bdb_loop, brg.bdb2);
            mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
            L_aligned(bdb_loop_label, 64);
            {
                bdb_loop_body(brg.bd_block2, false, false, false, 0,
                        skip_accumulation);
                mov(reg_bdb_loop, ptr[rsp + reg_bdb_loop_offs_]);
                dec(reg_bdb_loop);
                cmp(reg_bdb_loop, 0);
                mov(ptr[rsp + reg_bdb_loop_offs_], reg_bdb_loop);
            }
            jg(bdb_loop_label, T_NEAR);
        }
        if (brg.bdb2_tail > 0)
            bdb_loop_body(
                    brg.bdb2_tail, false, false, false, 0, skip_accumulation);
        if (brg.bdb_tail > 0)
            do_ldb_loop(1, true, false, false, 0, skip_accumulation);
    };

    auto bdb_loop_general = [=](bool skip_accumulation) {
        // Fast path for a single addr-mode batch element: read A/B pointers
        // once up front instead of per BS iteration.
        if (brg.type == brgemm_addr && brg.brgattr.max_bs == 1 && !vpad_exist
                && !skip_accumulation) {
            mov(reg_aux1_A, ptr[reg_addr_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]);
            mov(reg_aux1_B, ptr[reg_addr_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]);
        }

        xor_(reg_a_offset, reg_a_offset);
        if (brg.is_tmm)
            bdb_loop_amx(skip_accumulation);
        else
            bdb_loop_avx512(skip_accumulation);
    };

    // When requested, generate both a normal and a skip-accumulation version
    // of the whole loop nest and branch on the runtime skip_accm flag.
    if (brg.brgattr.generate_skip_accumulation) {
        Label bdb_loop_skip_acc_label, bdb_loop_done_label;
        mov(reg_skip_accm, ptr[rsp + reg_skip_accm_offs_]);
        cmp(reg_skip_accm, 0);
        jnz(bdb_loop_skip_acc_label, T_NEAR);

        bdb_loop_general(false);
        jmp(bdb_loop_done_label, T_NEAR);

        L_aligned(bdb_loop_skip_acc_label, 64);
        bdb_loop_general(true);

        L_aligned(bdb_loop_done_label, 64);
    } else
        bdb_loop_general(false);
}
2289 | |
template <cpu_isa_t isa, typename Wmm>
// Top-level code generation: prologue, stack frame, opmask setup, parameter
// loading, the full loop nest, and epilogue.
void jit_brgemm_kernel_t<isa, Wmm>::generate() {
    preamble();

    // Local spill area used throughout the kernel (reg_*_offs_ slots).
    sub(rsp, stack_space_needed_);

    vpad_exist
            = (brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0)
            ? true
            : false;
    // Padding compensation is needed for int8 kernels with either s8s8
    // compensation or an A zero-point, combined with the vpad /
    // runtime-computed-compensation conditions.
    need_comp_pads = IMPLICATION(brg.zp_type_a == brgemm_broadcast_t::none,
                             brg.req_s8s8_compensation)
            && IMPLICATION(!vpad_exist, brg.req_cal_comp_pads);

    if (is_superset(brg.isa_impl, avx512_core)) {
        // Prepare the opmasks: a full mask and one covering only the
        // ldb_tail lanes for masked tail loads/stores.
        const auto full_mask = size_t {0xffffffffffffffff};
        const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
        reg64_t reg_mask = rax;

        mov(reg_mask, full_mask);
        kmovq(ld_full_mask, reg_mask);
        mov(reg_mask, tail_mask);
        kmovq(ld_tail_mask, reg_mask);
    }

    read_params();

    bdb_loop();

    add(rsp, stack_space_needed_);

    postamble();

    // The eltwise injector's lookup table must be emitted after the code.
    if (brg.with_eltwise) postops_injector_->prepare_table();
}
2325 | |
// Default brgemm attributes. The hint_expected_*_size defaults equal the
// per-core level-1 cache size, which (see gemm_microkernel_amx) makes the
// non-temporal tile-load heuristic trigger by default.
brgemm_attr_t::brgemm_attr_t()
    : max_bs(INT_MAX)
    , max_top_vpad(0)
    , max_bottom_vpad(0)
    , hint_expected_A_size(platform::get_per_core_cache_size(1))
    , hint_expected_B_size(platform::get_per_core_cache_size(1))
    , hint_expected_C_size(platform::get_per_core_cache_size(1))
    , hint_innermost_loop(brgemm_ld_loop_innermost)
    , hint_loop_order(brgemm_kernel_loop_order_t::brgemm_lo_default)
    , hint_prefetching(brgemm_kernel_prefetching_t::brgemm_prf_default)
    , wary_tail_read(true)
    , generate_skip_accumulation(false)
    , bd_mask(nullptr)
    , bd_mask_level(0)
    , use_uker(false)
    , use_interleave_stores(false)
    , LDA2(0)
    , LDB2(0)
    , LDC2_M(0)
    , LDC2_N(0) {}
2346 | |
template <cpu_isa_t isa, typename Wmm>
// Allocates and owns the jit kernel; released in ~brgemm_kernel_common_t().
// `abrd` is taken by value; the jit kernel keeps its own copy of the
// descriptor (see jit_brgemm_kernel_t's ctor), so no aliasing concerns.
brgemm_kernel_common_t<isa, Wmm>::brgemm_kernel_common_t(const brgemm_t abrd) {
    brgemm_kernel_ = new jit_brgemm_kernel_t<isa, Wmm>(abrd);
}
2351 | |
2352 | template <cpu_isa_t isa, typename Wmm> |
2353 | status_t brgemm_kernel_common_t<isa, Wmm>::create_kernel() { |
2354 | return brgemm_kernel_->create_kernel(); |
2355 | } |
2356 | |
2357 | template <cpu_isa_t isa, typename Wmm> |
2358 | void brgemm_kernel_common_t<isa, Wmm>::operator()( |
2359 | brgemm_kernel_params_t *params) const { |
2360 | (*brgemm_kernel_)(params); |
2361 | } |
2362 | |
template <cpu_isa_t isa, typename Wmm>
brgemm_kernel_common_t<isa, Wmm>::~brgemm_kernel_common_t() {
    // Release the jit kernel allocated in the constructor.
    delete brgemm_kernel_;
}
2367 | |
// isa specific instantiations are required because
// post-ops require template isa param.
// AMX kernels use tile registers (Tmm); avx512 variants use Zmm; avx2
// variants use Ymm.
template struct brgemm_kernel_common_t<avx512_core_amx_fp16, Xbyak::Tmm>;
template struct brgemm_kernel_common_t<avx512_core_amx, Xbyak::Tmm>;
template struct brgemm_kernel_common_t<avx512_core_fp16, Xbyak::Zmm>;
template struct brgemm_kernel_common_t<avx512_core_bf16, Xbyak::Zmm>;
template struct brgemm_kernel_common_t<avx512_core_vnni, Xbyak::Zmm>;
template struct brgemm_kernel_common_t<avx512_core, Xbyak::Zmm>;
template struct brgemm_kernel_common_t<avx2_vnni, Xbyak::Ymm>;
template struct brgemm_kernel_common_t<avx2_vnni_2, Xbyak::Ymm>;
template struct brgemm_kernel_common_t<avx2, Xbyak::Ymm>;
2379 | } // namespace x64 |
2380 | } // namespace cpu |
2381 | } // namespace impl |
2382 | } // namespace dnnl |
2383 | |
2384 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
2385 | |