1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <memory> |
17 | |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/platform.hpp" |
24 | #include "cpu/x64/brgemm/brgemm.hpp" |
25 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
26 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
27 | #include "cpu/x64/jit_generator.hpp" |
28 | |
29 | #define GET_OFF(field) offsetof(brgemm_kernel_params_t, field) |
30 | #define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field) |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | using namespace dnnl::impl::utils; |
38 | using namespace Xbyak; |
39 | |
40 | struct jit_brgemm_amx_uker_base_t : public jit_generator { |
41 | jit_brgemm_amx_uker_base_t(const brgemm_t &abrg) |
42 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core) |
43 | , brg(abrg) |
44 | , postops_injector_(nullptr) { |
45 | |
46 | if (brg.with_eltwise || brg.with_binary || brg.with_sum) { |
47 | |
48 | static constexpr bool preserve_gpr = true; |
49 | // we don't use zmm1 for storing vectors |
50 | // so we don't need to preserve vmm |
51 | static constexpr bool preserve_vmm = false; |
52 | static constexpr bool use_exact_tail_scalar_bcast = false; |
53 | const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md); |
54 | |
55 | static const bcast_set_t enabled_bcast_strategy |
56 | = {broadcasting_strategy_t::scalar, |
57 | broadcasting_strategy_t::per_oc, |
58 | broadcasting_strategy_t::per_oc_spatial, |
59 | broadcasting_strategy_t::per_mb_spatial, |
60 | broadcasting_strategy_t::per_mb_w, |
61 | broadcasting_strategy_t::per_w, |
62 | broadcasting_strategy_t::no_broadcast}; |
63 | const binary_injector::rhs_arg_static_params_t rhs_sp { |
64 | static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14, |
65 | this->r15, this->r13, preserve_gpr, preserve_vmm, |
66 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_), |
67 | dst_md_wrapper, static_cast<size_t>(brg.ldb_tail), |
68 | ld_tail_mask, use_exact_tail_scalar_bcast}; |
69 | const binary_injector::static_params_t bsp { |
70 | this->param1, enabled_bcast_strategy, rhs_sp}; |
71 | |
72 | eltwise_injector::static_params_t esp; |
73 | esp.preserve_vmm = preserve_vmm; |
74 | esp.preserve_p_table = false; |
75 | |
76 | postops_injector_ = utils::make_unique<po_injector_t>( |
77 | this, brg.attr->post_ops_, bsp, esp); |
78 | |
79 | using namespace dnnl::impl::cpu::binary_injector_utils; |
80 | std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_, |
81 | with_binary_channel_bcast_, with_binary_per_mb_w_bcast_, |
82 | with_binary_per_w_bcast_, with_binary_no_bcast_) |
83 | = bcast_strategies_present_tup(brg.attr->post_ops_.entry_, |
84 | dst_md_wrapper, broadcasting_strategy_t::per_oc, |
85 | broadcasting_strategy_t::per_oc_spatial, |
86 | broadcasting_strategy_t::per_mb_spatial, |
87 | broadcasting_strategy_t::per_mb_w, |
88 | broadcasting_strategy_t::per_w, |
89 | broadcasting_strategy_t::no_broadcast); |
90 | handle_binary_po_offset_ = with_binary_per_oc_bcast_ |
91 | || with_binary_per_oc_sp_bcast_ |
92 | || with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_ |
93 | || with_binary_per_w_bcast_ || with_binary_no_bcast_; |
94 | } |
95 | use_ils_ = brg.brgattr.use_interleave_stores; |
96 | } |
97 | |
98 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_amx_uker_base_t) |
99 | |
100 | brgemm_t brg; |
101 | |
102 | private: |
103 | static constexpr cpu_isa_t po_isa_ = avx512_core_fp16; |
104 | using po_injector_t = injector::jit_uni_postops_injector_t<po_isa_>; |
105 | std::unique_ptr<po_injector_t> postops_injector_; |
106 | |
107 | using reg64_t = const Xbyak::Reg64; |
108 | enum { |
109 | simd_w = 16, |
110 | zmm_width_in_bytes = cpu_isa_traits<avx512_core>::vlen, |
111 | tile_size = 1024 |
112 | }; |
113 | |
114 | // Register decomposition |
115 | const reg64_t param1 = abi_param1; |
116 | |
117 | const reg64_t reg_addr_batch = r13; |
118 | const reg64_t reg_aux1_batch = rbp; |
119 | const reg64_t reg_aux_A = r11; |
120 | const reg64_t reg_aux_B = r10; |
121 | const reg64_t reg_stride_lda = r14; |
122 | const reg64_t reg_stride_ldb = abi_not_param1; |
123 | const reg64_t reg_C = r15; |
124 | const reg64_t reg_D = r12; |
125 | |
126 | const reg64_t reg_buf = r8; |
127 | const reg64_t reg_BS = rbx; |
128 | const reg64_t reg_BS_loop = r9; |
129 | const reg64_t reg_bias = rbx; |
130 | const reg64_t reg_scales = rbx; |
131 | |
132 | const reg64_t reg_stride_ld_block = rdx; |
133 | const reg64_t reg_do_post_ops = rbx; |
134 | const reg64_t reg_tmp_gpr = rbx; |
135 | const reg64_t reg_ptr_sum_scale = rbx; |
136 | |
137 | const reg64_t reg_zp_comp_a = rbx; |
138 | const reg64_t reg_aux_zp_comp_a = rbx; |
139 | const reg64_t reg_zp_comp_b = rbx; |
140 | const reg64_t reg_zp_c_values = rbx; |
141 | const reg64_t reg_ptr_sum_zp = r9; |
142 | const reg64_t reg_bf32_stride = rsi; |
143 | |
144 | constexpr static int abi_param1_offs_ = 0; |
145 | constexpr static int reg_zp_comp_a_offs_ = 8; |
146 | constexpr static int reg_zp_comp_b_offs_ = 16; |
147 | constexpr static int reg_zp_c_values_offs_ = 24; |
148 | constexpr static int stack_space_needed_ = 32; |
149 | |
150 | bool are_post_ops_applicable_ = false; |
151 | bool need_to_apply_alpha_beta_ = false; |
152 | bool may_load_accumulators_ = false; |
153 | |
154 | bool handle_binary_po_offset_ = false; |
155 | bool with_binary_per_oc_bcast_ = false; |
156 | bool with_binary_per_oc_sp_bcast_ = false; |
157 | bool with_binary_channel_bcast_ = false; |
158 | bool with_binary_per_mb_w_bcast_ = false; |
159 | bool with_binary_per_w_bcast_ = false; |
160 | bool with_binary_no_bcast_ = false; |
161 | bool prepare_post_ops_registers_once_ = false; |
162 | |
163 | char *bd_mask_buffer_ptr_ = nullptr; |
164 | std::vector<size_t> adj_bd_mask_buffer_; |
165 | size_t *adj_bd_mask_buffer_ptr_ = nullptr; |
166 | std::vector<size_t> skipped_bd_mask_buffer_; |
167 | size_t *skipped_bd_mask_buffer_ptr_ = nullptr; |
168 | palette_config_t palette_; |
169 | // used to store offsets within wsp buffer where the data is |
170 | // transformed(downconverted), to reuse when needed. |
171 | std::unordered_map<std::string, size_t> transform_buf_map_A_; |
172 | std::unordered_map<std::string, size_t> transform_buf_map_B_; |
173 | |
174 | size_t LDA_size_ = 0, LDA2_size_ = 0; |
175 | size_t LDB_size_ = 0, LDB2_size_ = 0; |
176 | size_t LDC_size_ = 0, LDC2_size_M_ = 0, LDC2_size_N_ = 0; |
177 | size_t LDD_size_ = 0; |
178 | size_t ld_block_B_size_ = 0; |
179 | size_t ld_block_C_size_ = 0; |
180 | size_t ld_block_D_size_ = 0; |
181 | size_t ld_block_bias_size_ = 0; |
182 | size_t ld_block_scales_size_ = 0; |
183 | size_t ld_block_zp_size_ = 0; |
184 | |
185 | size_t ldb_tail_B_size_ = 0; |
186 | size_t ldb_tail_C_size_ = 0; |
187 | size_t ldb_tail_D_size_ = 0; |
188 | size_t ldb_tail_zp_size_ = 0; |
189 | |
190 | enum matrix_kind_t { matrix_A, matrix_B, matrix_C, matrix_D }; |
191 | |
192 | // Loops in brgemm kernel are (two outermost loops depend on loop order): |
193 | // by bd block2 |
194 | // by ld block2 |
195 | // by batch_size |
196 | // by rd block |
197 | // gemm_microkernel |
198 | // Structures below (dim_iteration_t, bd_iteration_t, bs_iteration_t and |
199 | // iteration_map_t) describe the structure of cycles |
200 | // and are used for JIT code generation |
201 | struct dim_iteration_t { |
202 | size_t idx = 0; |
203 | size_t pos = 0; |
204 | int block = 0; |
205 | int block2 = 0; |
206 | bool is_tail = false; |
207 | dim_iteration_t() = default; |
208 | dim_iteration_t( |
209 | size_t pos_, int block_, int block2_, bool is_tail_ = false) |
210 | : pos(pos_), block(block_), block2(block2_), is_tail(is_tail_) {} |
211 | }; |
212 | |
213 | struct bd_iteration_t : public dim_iteration_t { |
214 | std ::vector<size_t> bdb_pos; |
215 | bd_iteration_t() = default; |
216 | bd_iteration_t( |
217 | size_t pos_, int block_, int block2_, bool is_tail_ = false) |
218 | : dim_iteration_t(pos_, block_, block2_, is_tail_) {} |
219 | }; |
220 | |
221 | struct bs_iteration_t { |
222 | size_t idx = 0; |
223 | size_t pos = 0; |
224 | bool is_first = false; |
225 | bool is_last = false; |
226 | bs_iteration_t() = default; |
227 | bs_iteration_t( |
228 | size_t pos_, bool is_first_ = true, bool is_last_ = false) |
229 | : pos(pos_), is_first(is_first_), is_last(is_last_) {} |
230 | }; |
231 | |
232 | struct iteration_map_t { |
233 | std::vector<dim_iteration_t> ldis; |
234 | std::vector<bd_iteration_t> bdis; |
235 | std::vector<bs_iteration_t> bsis; |
236 | std::vector<dim_iteration_t> rdis; |
237 | bool is_last_ldi(const dim_iteration_t &ldi) const { |
238 | return (ldi.idx == ldis.size() - 1); |
239 | } |
240 | bool is_last_bdi(const dim_iteration_t &bdi) const { |
241 | return (bdi.idx == bdis.size() - 1); |
242 | } |
243 | bool is_last_rdi(const dim_iteration_t &rdi) const { |
244 | return (rdi.idx == rdis.size() - 1); |
245 | } |
246 | }; |
247 | |
248 | struct brgemm_iteration_t { |
249 | bd_iteration_t bdi; |
250 | dim_iteration_t ldi; |
251 | bs_iteration_t bsi; |
252 | dim_iteration_t rdi; |
253 | brgemm_iteration_t *prev_iter = nullptr; |
254 | bool apply_postops = false; |
255 | brgemm_iteration_t() = default; |
256 | }; |
257 | |
258 | struct prf_t { |
259 | brgemm_kernel_prefetching_t pft; |
260 | int dist = -1; |
261 | int vec = 0; |
262 | }; |
263 | |
264 | // iteration map |
265 | iteration_map_t imap_; |
266 | |
267 | // interleave stores |
268 | bool use_ils_ = false; |
269 | bool was_prev_bi = false; |
270 | // saved parameters for storing |
271 | brgemm_iteration_t prev_bi_; |
272 | // current storing coordinates |
273 | int ils_vec_ = 0, ils_bdb_ = 0, ils_ldb_ = 0, ils_bd_start_ = 0; |
274 | int ils_bd_step_ = 3; // heuristic value |
275 | prf_t prf1A, prf2A, prf1B, prf2B, prf1C, prf2C; |
276 | |
277 | bool dt_requires_saturation_ = false; |
278 | |
279 | Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2); |
280 | Xbyak::Opmask ld_tail_mask = Xbyak::Opmask(3); |
281 | Xbyak::Opmask bf32_col_mask = Xbyak::Opmask(4); |
282 | |
283 | // Zmm map below |
284 | const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm0; } |
285 | const Xbyak::Zmm &zmm_tmp_2() const noexcept { return this->zmm1; } |
286 | const Xbyak::Zmm &zmm_tmp_3() const noexcept { return this->zmm2; } |
287 | |
288 | const Xbyak::Zmm zmm_bf32_pemute = zmm6; |
289 | const Xbyak::Zmm zmm_zp_comp_a = zmm6; |
290 | const Xbyak::Zmm zmm_zp_c = zmm7; |
291 | const Xbyak::Zmm zmm_lbound = zmm8; |
292 | const Xbyak::Zmm zmm_ubound = zmm9; |
293 | |
294 | // zmm_bias, zmm_bias and accm shouldn't be overlapped |
295 | Xbyak::Zmm accm(int bd) const { |
296 | assert(bd < 16); |
297 | return Xbyak::Zmm(31 - (bd % ils_bd_step_)); |
298 | } |
299 | |
300 | Xbyak::Zmm zmm_bias(int ldb) const { |
301 | assert(ldb < 5); |
302 | // zmm10 - zmm14 |
303 | return Xbyak::Zmm(10 + ldb); |
304 | } |
305 | |
306 | Xbyak::Zmm zmm_scales(int ldb) const { |
307 | assert(ldb < 5); |
308 | assert(ils_bd_step_ < 10); |
309 | // zmm15 - zmm19 |
310 | return Xbyak::Zmm(15 + ldb); |
311 | } |
312 | |
313 | Xbyak::Zmm zmm_mask(const Xbyak::Zmm zmm_in, bool mask_flag, bool store, |
314 | Xbyak::Opmask ktail_mask) const; |
315 | Xbyak::Ymm ymm_mask(const Xbyak::Ymm ymm_in, bool mask_flag, bool store, |
316 | Xbyak::Opmask ktail_mask) const; |
317 | |
318 | void cvt2ps(data_type_t type_in, const Xbyak::Zmm zmm_in, |
319 | const Xbyak::Operand &op, bool mask_flag, bool store, |
320 | Xbyak::Opmask ktail_mask); |
321 | |
322 | void read_params(); |
323 | void load_accumulators(brgemm_iteration_t &bi); |
324 | |
325 | void maybe_saturation(Xbyak::Zmm &zmm); |
326 | void apply_alpha_beta_to_vector( |
327 | const int idx, const Address &addr, bool is_ld_tail); |
328 | void apply_post_ops_to_range(brgemm_iteration_t &bi, int bd_start, |
329 | int bd_finish, int bd_inp_bdb, int ldb); |
330 | void store_vector_with_post_ops(const int idx, const Address &addr, |
331 | const int bd, const int ldb, bool is_ld_tail); |
332 | void prepare_post_ops_registers_ldb(brgemm_iteration_t &bi, int ldb); |
333 | void prepare_post_ops_registers(brgemm_iteration_t &bi); |
334 | |
335 | bool bi_shift_output( |
336 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi); |
337 | bool bi_shift_A( |
338 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi); |
339 | bool bi_shift_B( |
340 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi); |
341 | |
342 | void uni_prefetch(const Address &addr, brgemm_kernel_prefetching_t pft); |
343 | void prefetch_CD_range(brgemm_iteration_t &bi, |
344 | brgemm_kernel_prefetching_t pft, int bd_start, int bd_finish, |
345 | int bd_inp_bdb, int ldb); |
346 | int calc_ops_CD(brgemm_iteration_t &bi) const noexcept; |
347 | void prefetch_CD(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi, |
348 | prf_t &prf, bool prefetch_all); |
349 | |
350 | void prefetch_A(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi, |
351 | prf_t &prf, bool prefetch_all); |
352 | void prefetch_B(brgemm_iteration_t &bi, brgemm_iteration_t &pfo_bi, |
353 | prf_t &prf, bool prefetch_all); |
354 | void prefetching(brgemm_iteration_t &bi, bool prefetch_all); |
355 | |
356 | void process_output_range(brgemm_iteration_t &bi, int bd_start, |
357 | int bd_finish, int bd_inp_bdb, int bdb, int ldb); |
358 | void store_vector_without_post_ops( |
359 | const int idx, const Address &addr, bool is_ld_tail); |
360 | void store_vector( |
361 | brgemm_iteration_t &bi, const int idx, const int bd, const int ldb); |
362 | |
363 | void interleave_store(brgemm_iteration_t &bi, bool store_all); |
364 | |
365 | void store_accumulators(brgemm_iteration_t &bi); |
366 | |
367 | void set_A_B_matrices(int bs); |
368 | void set_A_B_matrices(); |
369 | |
370 | void bf32_downconvert(int num_rows, int tile_num_col_bytes, |
371 | reg64_t reg_data, int offset, reg64_t reg_data_stride, |
372 | reg64_t reg_buf, bool is_rd_tail); |
373 | |
374 | void bf32_downconvert_to_vnni(int num_rows, int tile_num_col_bytes, |
375 | reg64_t reg_data, int offset, reg64_t reg_data_stride, |
376 | reg64_t reg_buf, bool is_rd_tail); |
377 | |
378 | void maybe_pre_process_data(brgemm_iteration_t &bi, const Tmm &t1, |
379 | reg64_t reg_base, size_t offset, reg64_t reg_stride, |
380 | matrix_kind_t mk); |
381 | |
382 | void maybe_tileloadd_nt( |
383 | brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset); |
384 | |
385 | void tdpbxxd(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx, |
386 | bool do_pre_tilestore, bool do_post_tilestore); |
387 | |
388 | void gemm_microkernel_amx(brgemm_iteration_t &bi); |
389 | |
390 | void rdb_loop(brgemm_iteration_t &bi); |
391 | |
392 | void bs_loop_body(brgemm_iteration_t &bi); |
393 | void bs_loop(brgemm_iteration_t &bi); |
394 | |
395 | void ldb_loop_body(brgemm_iteration_t &bi); |
396 | void ldb_loop(brgemm_iteration_t &bi); |
397 | |
398 | void bdb_loop_body(brgemm_iteration_t &bi); |
399 | void bdb_loop(brgemm_iteration_t &bi); |
400 | |
401 | void init(brgemm_iteration_t &bi); |
402 | void generate() override; |
403 | |
404 | void prepare_bd_mask() noexcept; |
405 | int skipped_bd_mask(int inp_bd) noexcept; |
406 | |
407 | bool get_store_by_vectors(bool apply_post_ops) const { |
408 | const bool need_to_apply_post_ops |
409 | = are_post_ops_applicable_ && apply_post_ops; |
410 | const auto store_by_vectors = need_to_apply_alpha_beta_ |
411 | || need_to_apply_post_ops || brg.brgattr.bd_mask_level; |
412 | return store_by_vectors; |
413 | } |
414 | |
415 | size_t A_offset(int bdb) const noexcept; |
416 | |
417 | size_t B_offset(int ldb) const noexcept; |
418 | size_t C_offset(int bd, int ldb) const noexcept; |
419 | size_t C_block_offset(int bd, int ldb) const noexcept; |
420 | size_t D_offset(int bd, int ldb) const noexcept; |
421 | |
422 | size_t lda() const noexcept; |
423 | size_t ldb() const noexcept; |
424 | size_t rdb_A_offset(const brgemm_iteration_t &bi) const noexcept; |
425 | size_t rdb_B_offset(const brgemm_iteration_t &bi) const noexcept; |
426 | size_t ldb_B_offset(const brgemm_iteration_t &bi) const noexcept; |
427 | |
428 | size_t bias_offset(int ldb) const noexcept; |
429 | |
430 | size_t scales_offset(int ldb) const noexcept; |
431 | size_t zp_comp_a_offset(int ldb) const noexcept; |
432 | size_t zp_comp_b_offset(int bd) const noexcept; |
433 | size_t zp_c_values_offset(brgemm_iteration_t &bi, int ldb) const noexcept; |
434 | int get_out_bd(int bd_inp_bdb, int bd) const; |
435 | |
436 | void maybe_tilestore(brgemm_iteration_t &bi, int bdb_idx, int ldb_idx, |
437 | bool do_pre_tilestore, bool do_post_tilestore); |
438 | int get_C_tensor(brgemm_iteration_t &bi, int m, int n) const noexcept; |
439 | void top_loop(brgemm_iteration_t &bi); |
440 | |
441 | void fill_imap(); |
442 | }; |
443 | |
444 | bool jit_brgemm_amx_uker_base_t::bi_shift_output( |
445 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) { |
446 | res_bi = bi; |
447 | if (shift == 0) return true; |
448 | |
449 | size_t lidx = 0; |
450 | size_t bd_idx = 0; |
451 | size_t ld_idx = 0; |
452 | if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost) { |
453 | lidx = bi.bdi.idx * imap_.ldis.size() + bi.ldi.idx; |
454 | lidx += shift; |
455 | bd_idx = lidx / imap_.ldis.size(); |
456 | ld_idx = lidx % imap_.ldis.size(); |
457 | } else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost) { |
458 | lidx = bi.ldi.idx * imap_.bdis.size() + bi.bdi.idx; |
459 | lidx += shift; |
460 | ld_idx = lidx / imap_.bdis.size(); |
461 | bd_idx = lidx % imap_.bdis.size(); |
462 | } else |
463 | assert(!"Unknown loop order!" ); |
464 | if (lidx >= imap_.ldis.size() * imap_.bdis.size()) return false; |
465 | res_bi.bdi = imap_.bdis[bd_idx]; |
466 | res_bi.ldi = imap_.ldis[ld_idx]; |
467 | |
468 | return true; |
469 | } |
470 | |
471 | bool jit_brgemm_amx_uker_base_t::bi_shift_A( |
472 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) { |
473 | res_bi = bi; |
474 | auto lidx = bi.bdi.idx * imap_.rdis.size() + bi.rdi.idx; |
475 | lidx += shift; |
476 | if (lidx >= imap_.rdis.size() * imap_.bdis.size()) return false; |
477 | |
478 | const auto bd_idx = lidx / imap_.rdis.size(); |
479 | const auto rd_idx = lidx % imap_.rdis.size(); |
480 | |
481 | res_bi.bdi = imap_.bdis[bd_idx]; |
482 | res_bi.rdi = imap_.rdis[rd_idx]; |
483 | |
484 | return true; |
485 | } |
486 | |
487 | bool jit_brgemm_amx_uker_base_t::bi_shift_B( |
488 | brgemm_iteration_t &bi, int shift, brgemm_iteration_t &res_bi) { |
489 | res_bi = bi; |
490 | auto lidx = bi.ldi.idx * imap_.rdis.size() + bi.rdi.idx; |
491 | lidx += shift; |
492 | if (lidx >= imap_.rdis.size() * imap_.ldis.size()) return false; |
493 | |
494 | const auto ld_idx = lidx / imap_.rdis.size(); |
495 | const auto rd_idx = lidx % imap_.rdis.size(); |
496 | |
497 | res_bi.ldi = imap_.ldis[ld_idx]; |
498 | res_bi.rdi = imap_.rdis[rd_idx]; |
499 | |
500 | return true; |
501 | } |
502 | |
503 | int jit_brgemm_amx_uker_base_t::get_C_tensor( |
504 | brgemm_iteration_t &bi, int m, int n) const noexcept { |
505 | return brg.get_C_tensor(m, n, bi.bdi.is_tail, bi.ldi.is_tail); |
506 | } |
507 | |
508 | void jit_brgemm_amx_uker_base_t::prepare_bd_mask() noexcept { |
509 | if (!brg.brgattr.bd_mask_level) return; |
510 | bd_mask_buffer_ptr_ = brg.brgattr.bd_mask; |
511 | const auto bd_mask_size = brg.bcast_dim; |
512 | adj_bd_mask_buffer_.resize(bd_mask_size); |
513 | adj_bd_mask_buffer_ptr_ = adj_bd_mask_buffer_.data(); |
514 | skipped_bd_mask_buffer_.resize(bd_mask_size); |
515 | skipped_bd_mask_buffer_ptr_ = skipped_bd_mask_buffer_.data(); |
516 | if (!utils::any_null(bd_mask_buffer_ptr_, adj_bd_mask_buffer_ptr_)) { |
517 | int out_ibd = 0; |
518 | for (int i = 0; i < bd_mask_size; i++) { |
519 | adj_bd_mask_buffer_ptr_[i] = out_ibd; |
520 | out_ibd += bd_mask_buffer_ptr_[i]; |
521 | skipped_bd_mask_buffer_ptr_[i] = i; |
522 | for (auto ii = i; ii < bd_mask_size; ii++) { |
523 | if (bd_mask_buffer_ptr_[ii]) { |
524 | skipped_bd_mask_buffer_ptr_[i] = ii; |
525 | break; |
526 | } |
527 | } |
528 | } |
529 | } else |
530 | assert(!"struct nullptr error" ); |
531 | } |
532 | |
533 | int jit_brgemm_amx_uker_base_t::skipped_bd_mask(int inp_bd) noexcept { |
534 | if (brg.brgattr.bd_mask_level != 2) |
535 | return inp_bd; |
536 | else |
537 | return skipped_bd_mask_buffer_ptr_[inp_bd]; |
538 | } |
539 | |
540 | size_t jit_brgemm_amx_uker_base_t::A_offset(int bdb) const noexcept { |
541 | return bdb * LDA2_size_; |
542 | } |
543 | |
544 | size_t jit_brgemm_amx_uker_base_t::B_offset(int ldb) const noexcept { |
545 | return (brg.is_blocked ? 1 : brg.rd_step) * ldb * ld_block_B_size_; |
546 | } |
547 | |
548 | size_t jit_brgemm_amx_uker_base_t::C_offset(int bd, int ldb) const noexcept { |
549 | return bd * LDC_size_ + ldb * ld_block_C_size_; |
550 | } |
551 | |
552 | size_t jit_brgemm_amx_uker_base_t::C_block_offset(int bd, int ldb) const |
553 | noexcept { |
554 | return (size_t)bd * LDC2_size_M_ + (size_t)ldb * LDC2_size_N_; |
555 | } |
556 | |
557 | size_t jit_brgemm_amx_uker_base_t::D_offset(int bd, int ldb) const noexcept { |
558 | return bd * LDD_size_ + ldb * ld_block_D_size_; |
559 | } |
560 | |
561 | size_t jit_brgemm_amx_uker_base_t::lda() const noexcept { |
562 | return LDA_size_; |
563 | } |
564 | |
565 | size_t jit_brgemm_amx_uker_base_t::ldb() const noexcept { |
566 | return LDB_size_ * brg.rd_step; |
567 | } |
568 | |
569 | size_t jit_brgemm_amx_uker_base_t::rdb_A_offset( |
570 | const brgemm_iteration_t &bi) const noexcept { |
571 | return bi.rdi.pos * brg.typesize_A * brg.rd_block; |
572 | } |
573 | |
574 | size_t jit_brgemm_amx_uker_base_t::rdb_B_offset( |
575 | const brgemm_iteration_t &bi) const noexcept { |
576 | return bi.rdi.pos * brg.rd_block * LDB_size_; |
577 | } |
578 | |
579 | size_t jit_brgemm_amx_uker_base_t::ldb_B_offset( |
580 | const brgemm_iteration_t &bi) const noexcept { |
581 | return bi.ldi.pos * ld_block_B_size_ * brg.ld_step; |
582 | } |
583 | |
584 | size_t jit_brgemm_amx_uker_base_t::bias_offset(int ldb) const noexcept { |
585 | return ldb * ld_block_bias_size_; |
586 | } |
587 | |
588 | size_t jit_brgemm_amx_uker_base_t::scales_offset(int ldb) const noexcept { |
589 | return brg.is_oc_scale * ldb * ld_block_scales_size_; |
590 | } |
591 | |
592 | size_t jit_brgemm_amx_uker_base_t::zp_comp_a_offset(int ldb) const noexcept { |
593 | return ldb * ld_block_zp_size_; |
594 | } |
595 | |
596 | size_t jit_brgemm_amx_uker_base_t::zp_comp_b_offset(int bd) const noexcept { |
597 | return sizeof(int32_t) * bd; |
598 | } |
599 | |
600 | size_t jit_brgemm_amx_uker_base_t::zp_c_values_offset( |
601 | brgemm_iteration_t &bi, int ldb) const noexcept { |
602 | if (brg.zp_type_c == brgemm_broadcast_t::per_n) { |
603 | return (bi.ldi.is_tail) ? ldb_tail_zp_size_ |
604 | : (bi.ldi.pos + ldb) * ld_block_zp_size_; |
605 | } |
606 | |
607 | return 0; |
608 | } |
609 | |
610 | int jit_brgemm_amx_uker_base_t::get_out_bd(int bd_inp_bdb, int bd) const { |
611 | const auto bd_out_bd = bd_inp_bdb + bd; |
612 | if (brg.brgattr.bd_mask_level && !bd_mask_buffer_ptr_[bd_out_bd]) |
613 | return -1; |
614 | else { |
615 | if (brg.brgattr.bd_mask_level) |
616 | return adj_bd_mask_buffer_ptr_[bd_out_bd]; |
617 | else |
618 | return bd_out_bd; |
619 | } |
620 | } |
621 | |
622 | Xbyak::Zmm jit_brgemm_amx_uker_base_t::zmm_mask(const Xbyak::Zmm zmm_in, |
623 | bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const { |
624 | return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
625 | : zmm_in; |
626 | } |
627 | |
628 | Xbyak::Ymm jit_brgemm_amx_uker_base_t::ymm_mask(const Xbyak::Ymm ymm_in, |
629 | bool mask_flag, bool store, Xbyak::Opmask ktail_mask) const { |
630 | return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z) |
631 | : ymm_in; |
632 | } |
633 | |
634 | void jit_brgemm_amx_uker_base_t::cvt2ps(data_type_t type_in, |
635 | const Xbyak::Zmm zmm_in, const Xbyak::Operand &op, bool mask_flag, |
636 | bool store, Xbyak::Opmask ktail_mask) { |
637 | const Xbyak::Zmm zmm = zmm_mask(zmm_in, mask_flag, store, ktail_mask); |
638 | switch (type_in) { |
639 | case data_type::f32: |
640 | case data_type::s32: vmovups(zmm, op); break; |
641 | case data_type::bf16: |
642 | vpmovzxwd(zmm, op); |
643 | vpslld(zmm, zmm, 16); |
644 | break; |
645 | case data_type::f16: vcvtph2ps(zmm, op); break; |
646 | case data_type::s8: vpmovsxbd(zmm, op); break; |
647 | case data_type::u8: vpmovzxbd(zmm, op); break; |
648 | default: assert(!"unsupported data type" ); |
649 | } |
650 | if (types::is_integral_dt(type_in)) vcvtdq2ps(zmm_in, zmm_in); |
651 | } |
652 | |
653 | void jit_brgemm_amx_uker_base_t::read_params() { |
654 | Label label_done; |
655 | |
656 | mov(reg_C, ptr[param1 + GET_OFF(ptr_C)]); |
657 | mov(reg_D, ptr[param1 + GET_OFF(ptr_D)]); |
658 | mov(reg_BS, ptr[param1 + GET_OFF(BS)]); |
659 | |
660 | mov(reg_addr_batch, ptr[param1 + GET_OFF(batch)]); |
661 | |
662 | mov(reg_buf, ptr[param1 + GET_OFF(ptr_buf)]); |
663 | |
664 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
665 | mov(reg_zp_comp_a, ptr[param1 + GET_OFF(a_zp_compensations)]); |
666 | mov(ptr[rsp + reg_zp_comp_a_offs_], reg_zp_comp_a); |
667 | } |
668 | |
669 | if (brg.zp_type_b != brgemm_broadcast_t::none) { |
670 | mov(reg_zp_comp_b, ptr[param1 + GET_OFF(b_zp_compensations)]); |
671 | mov(ptr[rsp + reg_zp_comp_b_offs_], reg_zp_comp_b); |
672 | } |
673 | |
674 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
675 | mov(reg_zp_c_values, ptr[param1 + GET_OFF(c_zp_values)]); |
676 | mov(ptr[rsp + reg_zp_c_values_offs_], reg_zp_c_values); |
677 | } |
678 | } |
679 | |
680 | void jit_brgemm_amx_uker_base_t::load_accumulators(brgemm_iteration_t &bi) { |
681 | assert(IMPLICATION(bi.ldi.is_tail, bi.ldi.block2 == 1)); |
682 | if (may_load_accumulators_) mov(reg_stride_ld_block, LDC_size_); |
683 | |
684 | for (int bdb = 0; bdb < bi.bdi.block2; bdb++) { |
685 | const auto bd_out_bdb = get_out_bd(bi.bdi.bdb_pos[bdb], 0); |
686 | for (int ldb = 0; ldb < bi.ldi.block2; ldb++) { |
687 | if (may_load_accumulators_) { |
688 | const auto c_offset |
689 | = C_block_offset(bd_out_bdb, bi.ldi.pos + ldb); |
690 | tileloadd(Tmm(get_C_tensor(bi, bdb, ldb)), |
691 | ptr[reg_C + c_offset + reg_stride_ld_block]); |
692 | } else { |
693 | // call tilezero on very first iteration |
694 | if (!brg.interleave_tilestores_ |
695 | || everyone_is(0u, bi.bdi.idx, bi.ldi.idx)) |
696 | tilezero(Tmm(get_C_tensor(bi, bdb, ldb))); |
697 | } |
698 | } |
699 | } |
700 | } |
701 | |
702 | void jit_brgemm_amx_uker_base_t::apply_alpha_beta_to_vector( |
703 | const int idx, const Address &addr, bool is_ld_tail) { |
704 | auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask; |
705 | auto zmm = Zmm(idx); |
706 | auto zmm_beta = zmm_tmp_1(); |
707 | auto zmm_alpha = zmm_tmp_2(); |
708 | auto zmm_prev_dst = zmm_tmp_3(); |
709 | |
710 | const bool apply_alpha = brg.alpha != 1.f; |
711 | const bool apply_beta = brg.beta != 0.f; |
712 | if (!apply_alpha && !apply_beta) return; |
713 | |
714 | const bool dq2ps_required = brg.is_int8 && (apply_alpha || brg.beta != 1.f); |
715 | const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required; |
716 | |
717 | if (apply_beta && !use_vadd_for_beta) { |
718 | mov(reg_tmp_gpr, float2int(static_cast<float>(brg.beta))); |
719 | vmovq(Xmm(zmm_beta.getIdx()), reg_tmp_gpr); |
720 | vbroadcastss(zmm_beta, Xmm(zmm_beta.getIdx())); |
721 | } |
722 | if (apply_alpha) { |
723 | mov(reg_tmp_gpr, float2int(static_cast<float>(brg.alpha))); |
724 | vmovq(Xmm(zmm_alpha.getIdx()), reg_tmp_gpr); |
725 | vbroadcastss(zmm_alpha, Xmm(zmm_alpha.getIdx())); |
726 | } |
727 | if (dq2ps_required) vcvtdq2ps(zmm, zmm); |
728 | if (apply_alpha) vmulps(zmm, zmm, zmm_alpha); |
729 | if (apply_beta) { |
730 | if (use_vadd_for_beta) { |
731 | auto zmm_masked = zmm | k_mask | T_z; |
732 | if (brg.is_int8) |
733 | vpaddd(zmm_masked, zmm, addr); |
734 | else |
735 | vaddps(zmm_masked, zmm, addr); |
736 | } else { |
737 | cvt2ps(brg.dt_c, zmm_prev_dst, addr, true, false, k_mask); |
738 | vfmadd231ps(zmm, zmm_prev_dst, zmm_beta); |
739 | } |
740 | } |
741 | } |
742 | |
743 | void jit_brgemm_amx_uker_base_t::apply_post_ops_to_range(brgemm_iteration_t &bi, |
744 | int bd_start, int bd_finish, int bd_inp_bdb, int ldb) { |
745 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
746 | |
747 | auto ldb_pos = bi.ldi.pos + ldb; |
748 | auto is_ld_tail = bi.ldi.is_tail; |
749 | |
750 | if (brg.with_binary) { |
751 | if (handle_binary_po_offset_) { |
752 | for (int bd = bd_start; bd < bd_finish; bd++) { |
753 | // We have no way to tell the injector to skip some vectors. |
754 | // Therefore, we must set parameters correctly for all registers. |
755 | // TODO: Make it possible to specify "skipped" vectors to injector |
756 | const auto idx = accm(bd).getIdx(); |
757 | if (is_ld_tail) rhs_arg_params.vmm_tail_idx_.emplace(idx); |
758 | rhs_arg_params.vmm_idx_to_out_reg.emplace(idx, reg_D); |
759 | |
760 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
761 | if (bd_out_bd == -1) continue; |
762 | |
763 | const auto d_offset = D_offset(bd_out_bd, ldb_pos); |
764 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
765 | idx, d_offset); |
766 | } |
767 | } |
768 | } |
769 | |
770 | const auto sum_injector = [&] { |
771 | const float *p_sum_scale = &brg.sum_scale; |
772 | const int32_t *p_sum_zp = &brg.sum_zp; |
773 | const bool p_sum_scale_reg_set = *p_sum_scale != 1.f; |
774 | const bool p_sum_zp_reg_set = *p_sum_zp != 0; |
775 | |
776 | { |
777 | if (p_sum_scale_reg_set) |
778 | mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale)); |
779 | |
780 | const auto &zmm_sum_zp = zmm_tmp_2(); |
781 | if (p_sum_zp_reg_set) { |
782 | mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp)); |
783 | vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]); |
784 | } |
785 | |
786 | const auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask; |
787 | const auto zmm_prev_dst = Xbyak::Zmm(0); |
788 | |
789 | for (int bd = bd_start; bd < bd_finish; bd++) { |
790 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
791 | if (bd_out_bd == -1) continue; |
792 | |
793 | auto zmm = accm(bd); |
794 | const auto d_offset = D_offset(bd_out_bd, ldb_pos); |
795 | auto addr = EVEX_compress_addr(reg_D, d_offset); |
796 | |
797 | cvt2ps(brg.sum_dt, zmm_prev_dst, addr, true, false, k_mask); |
798 | if (p_sum_zp_reg_set) vsubps(zmm_prev_dst, zmm_sum_zp); |
799 | if (!p_sum_scale_reg_set) |
800 | vaddps(zmm, zmm_prev_dst); |
801 | else |
802 | vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]); |
803 | } |
804 | } |
805 | }; |
806 | |
807 | if (brg.with_sum) { |
808 | postops_injector_->set_lambda_injector( |
809 | primitive_kind::sum, sum_injector); |
810 | } |
811 | |
812 | // Using knowledge how "accm" assign zmm registers. |
813 | // TODO: make this code more clear |
814 | const auto finish_idx = accm(bd_start).getIdx() + 1; |
815 | const auto start_idx = accm(bd_finish - 1).getIdx(); |
816 | postops_injector_->compute_vector_range( |
817 | start_idx, finish_idx, rhs_arg_params); |
818 | } |
819 | |
820 | void jit_brgemm_amx_uker_base_t::maybe_saturation(Xbyak::Zmm &zmm) { |
821 | if (!dt_requires_saturation_) return; |
822 | saturate_f32(zmm, zmm_lbound, zmm_ubound, brg.dt_d); |
823 | vcvtps2dq(zmm, zmm); |
824 | } |
825 | |
826 | void jit_brgemm_amx_uker_base_t::prepare_post_ops_registers_ldb( |
827 | brgemm_iteration_t &bi, int ldb) { |
828 | if (!bi.apply_postops) return; |
829 | auto k_mask = (!bi.ldi.is_tail) ? ld_full_mask : ld_tail_mask; |
830 | |
831 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
832 | mov(reg_aux_zp_comp_a, ptr[rsp + reg_zp_comp_a_offs_]); |
833 | |
834 | const auto zp_comp_a_off = zp_comp_a_offset(bi.ldi.pos + ldb); |
835 | const auto zp_comp_a_addr |
836 | = EVEX_compress_addr(reg_aux_zp_comp_a, zp_comp_a_off); |
837 | cvt2ps(data_type::s32, zmm_zp_comp_a, zp_comp_a_addr, true, false, |
838 | k_mask); |
839 | } |
840 | |
841 | if (brg.zp_type_c != brgemm_broadcast_t::none) { |
842 | mov(reg_zp_c_values, ptr[rsp + reg_zp_c_values_offs_]); |
843 | if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) { |
844 | vcvtdq2ps(zmm_zp_c, EVEX_compress_addr(reg_zp_c_values, 0, true)); |
845 | } |
846 | if (brg.zp_type_c == brgemm_broadcast_t::per_n) { |
847 | const auto zp_c_off = zp_c_values_offset(bi, ldb); |
848 | const auto zp_c_addr |
849 | = EVEX_compress_addr(reg_zp_c_values, zp_c_off); |
850 | cvt2ps(data_type::s32, zmm_zp_c, zp_c_addr, true, false, k_mask); |
851 | } |
852 | } |
853 | } |
854 | |
855 | void jit_brgemm_amx_uker_base_t::prepare_post_ops_registers( |
856 | brgemm_iteration_t &bi) { |
857 | if (!bi.apply_postops) return; |
858 | dim_iteration_t &ldi = bi.ldi; |
859 | auto k_mask = (!ldi.is_tail) ? ld_full_mask : ld_tail_mask; |
860 | if (brg.with_scales) { |
861 | mov(reg_scales, ptr[param1 + GET_OFF(ptr_scales)]); |
862 | for (int ldb = 0; ldb < ldi.block2; ldb++) { |
863 | auto scales_ptr = EVEX_compress_addr( |
864 | reg_scales, scales_offset(ldi.pos + ldb)); |
865 | vmovups(zmm_scales(ldb) | k_mask | T_z, scales_ptr); |
866 | } |
867 | } |
868 | |
869 | if (brg.with_bias) { |
870 | mov(reg_bias, ptr[param1 + GET_OFF(ptr_bias)]); |
871 | |
872 | for (int ldb = 0; ldb < ldi.block2; ldb++) { |
873 | auto ptr_bias |
874 | = EVEX_compress_addr(reg_bias, bias_offset(ldi.pos + ldb)); |
875 | cvt2ps(brg.dt_bias, zmm_bias(ldb), ptr_bias, true, false, k_mask); |
876 | } |
877 | } |
878 | } |
879 | |
880 | void jit_brgemm_amx_uker_base_t::uni_prefetch( |
881 | const Address &addr, brgemm_kernel_prefetching_t pft) { |
882 | if (pft == brgemm_kernel_prefetching_t::brgemm_prf1) |
883 | prefetcht1(addr); |
884 | else if (pft == brgemm_kernel_prefetching_t::brgemm_prf2) |
885 | prefetcht2(addr); |
886 | } |
887 | |
888 | void jit_brgemm_amx_uker_base_t::prefetch_CD_range(brgemm_iteration_t &bi, |
889 | brgemm_kernel_prefetching_t pft, int bd_start, int bd_finish, |
890 | int bd_inp_bdb, int ldb) { |
891 | auto ldb_pos = bi.ldi.pos + ldb; |
892 | for (int bd = bd_start; bd < bd_finish; bd++) { |
893 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
894 | if (bd_out_bd == -1) continue; |
895 | if (bi.apply_postops) { |
896 | const auto d_offset = D_offset(bd_out_bd, ldb_pos); |
897 | auto ptr_D = EVEX_compress_addr(reg_D, d_offset); |
898 | uni_prefetch(ptr_D, pft); |
899 | } else if (are_post_ops_applicable_) { |
900 | const auto c_offset = C_offset(bd_out_bd, ldb_pos); |
901 | auto ptr_C = EVEX_compress_addr(reg_C, c_offset); |
902 | uni_prefetch(ptr_C, pft); |
903 | } else { |
904 | const auto d_offset = D_offset(bd_out_bd, ldb_pos); |
905 | auto ptr_D = EVEX_compress_addr(reg_D, d_offset); |
906 | uni_prefetch(ptr_D, pft); |
907 | } |
908 | } |
909 | } |
910 | |
911 | int jit_brgemm_amx_uker_base_t::calc_ops_CD(brgemm_iteration_t &bi) const |
912 | noexcept { |
913 | return (brg.rdb + (brg.rdb_tail ? 1 : 0)) * bi.ldi.block2 * bi.bdi.block2 |
914 | * (brg.brgattr.var_bs ? 1 : brg.brgattr.max_bs); |
915 | } |
916 | |
917 | void jit_brgemm_amx_uker_base_t::prefetch_CD(brgemm_iteration_t &bi, |
918 | brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) { |
919 | |
920 | const auto calc_ops = calc_ops_CD(bi); |
921 | const auto bdb_row = pfo_bi.bdi.block * pfo_bi.ldi.block2; |
922 | const auto tot_vecs = pfo_bi.bdi.block2 * bdb_row; |
923 | const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0; |
924 | |
925 | const auto nvecs = prefetch_all |
926 | ? tot_vecs |
927 | : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec); |
928 | |
929 | const auto out_typesize |
930 | = (are_post_ops_applicable_ && !prev_bi_.apply_postops) |
931 | ? brg.typesize_C |
932 | : brg.typesize_D; |
933 | for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) { |
934 | const auto bdb = prf.vec / bdb_row; |
935 | const auto vec_in_bdb_row = prf.vec - bdb * bdb_row; |
936 | const auto ldb = vec_in_bdb_row / pfo_bi.bdi.block; |
937 | const auto bd = vec_in_bdb_row % pfo_bi.bdi.block; |
938 | // prefetch output cache lines only once |
939 | if ((pfo_bi.ldi.pos + ldb) % (4 / out_typesize) == 0) { |
940 | auto bd_inp_bdb = pfo_bi.bdi.bdb_pos[bdb]; |
941 | prefetch_CD_range(pfo_bi, prf.pft, 0, 1, bd_inp_bdb + bd, ldb); |
942 | } |
943 | prf.vec++; |
944 | } |
945 | } |
946 | |
947 | void jit_brgemm_amx_uker_base_t::prefetch_A(brgemm_iteration_t &bi, |
948 | brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) { |
949 | |
950 | const auto calc_ops = bi.ldi.block2 * bi.bdi.block2; |
951 | const auto tot_vecs = pfo_bi.bdi.block2 * pfo_bi.bdi.block; |
952 | const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0; |
953 | |
954 | const auto nvecs = prefetch_all |
955 | ? tot_vecs |
956 | : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec); |
957 | |
958 | const auto rdb_A_off = rdb_A_offset(pfo_bi); |
959 | |
960 | for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) { |
961 | const auto bdb = prf.vec / pfo_bi.bdi.block; |
962 | const auto bd = prf.vec % pfo_bi.bdi.block; |
963 | const auto bd_inp_bdb = pfo_bi.bdi.bdb_pos[bdb]; |
964 | |
965 | //TODO: looks like we have to prefetch in each bs separately |
966 | const auto ptr_A = EVEX_compress_addr( |
967 | reg_aux_A, A_offset(bd_inp_bdb) + bd * LDA_size_ + rdb_A_off); |
968 | uni_prefetch(ptr_A, prf.pft); |
969 | prf.vec++; |
970 | } |
971 | } |
972 | |
973 | void jit_brgemm_amx_uker_base_t::prefetch_B(brgemm_iteration_t &bi, |
974 | brgemm_iteration_t &pfo_bi, prf_t &prf, bool prefetch_all) { |
975 | |
976 | const auto calc_ops = bi.ldi.block2 * bi.bdi.block2; |
977 | const auto tot_vecs = pfo_bi.ldi.block2 * pfo_bi.rdi.block; |
978 | const auto pfo_vecs_per_store = (calc_ops) ? div_up(tot_vecs, calc_ops) : 0; |
979 | |
980 | const auto nvecs = prefetch_all |
981 | ? tot_vecs |
982 | : nstl::min(pfo_vecs_per_store, tot_vecs - prf.vec); |
983 | |
984 | // TODO: check these addressing for correctness |
985 | const auto rdb_B_off = rdb_B_offset(pfo_bi) + ldb_B_offset(pfo_bi); |
986 | |
987 | for (int iv = 0; iv < nvecs && prf.vec < tot_vecs; iv++) { |
988 | |
989 | const auto ldb = prf.vec / pfo_bi.rdi.block; |
990 | const auto rb = prf.vec % pfo_bi.rdi.block; |
991 | //TODO: looks like we have to prefetch in each bs separately |
992 | const auto ptr_B = EVEX_compress_addr( |
993 | reg_aux_B, B_offset(ldb) + rdb_B_off + rb * LDB_size_); |
994 | |
995 | uni_prefetch(ptr_B, prf.pft); |
996 | prf.vec++; |
997 | } |
998 | } |
999 | |
1000 | void jit_brgemm_amx_uker_base_t::prefetching( |
1001 | brgemm_iteration_t &bi, bool prefetch_all) { |
1002 | // for var_bs we do prefetch on last iteration by bs only |
1003 | if (brg.brgattr.var_bs && !bi.bsi.is_last) return; |
1004 | brgemm_iteration_t pfo_bi; |
1005 | if (brg.prfC.dist1 >= 0) { |
1006 | bool is_pfo_bi = false; |
1007 | brgemm_iteration_t pfo_bi; |
1008 | if (use_ils_ && get_store_by_vectors(bi.apply_postops)) { |
1009 | if (was_prev_bi && brg.prfC.dist1 == 0) { |
1010 | is_pfo_bi = true; |
1011 | pfo_bi = prev_bi_; |
1012 | } else if (brg.prfC.dist1 > 0) { |
1013 | is_pfo_bi = bi_shift_output(bi, brg.prfC.dist1 - 1, pfo_bi); |
1014 | } |
1015 | } else { |
1016 | is_pfo_bi = bi_shift_output(bi, brg.prfC.dist1, pfo_bi); |
1017 | } |
1018 | if (is_pfo_bi) prefetch_CD(bi, pfo_bi, prf1C, prefetch_all); |
1019 | } |
1020 | if (brg.prfC.dist2 >= 0) { |
1021 | bool is_pfo_bi = false; |
1022 | brgemm_iteration_t pfo_bi; |
1023 | if (use_ils_ && get_store_by_vectors(bi.apply_postops)) { |
1024 | if (was_prev_bi && brg.prfC.dist2 == 0) { |
1025 | is_pfo_bi = true; |
1026 | pfo_bi = prev_bi_; |
1027 | } else if (brg.prfC.dist2 > 0) { |
1028 | is_pfo_bi = bi_shift_output(bi, brg.prfC.dist2 - 1, pfo_bi); |
1029 | } |
1030 | } else { |
1031 | is_pfo_bi = bi_shift_output(bi, brg.prfC.dist2, pfo_bi); |
1032 | } |
1033 | if (is_pfo_bi) prefetch_CD(bi, pfo_bi, prf2C, prefetch_all); |
1034 | } |
1035 | if (brg.prfA.dist1 >= 0) { |
1036 | if (bi_shift_A(bi, brg.prfA.dist1, pfo_bi)) |
1037 | prefetch_A(bi, pfo_bi, prf1A, prefetch_all); |
1038 | } |
1039 | if (brg.prfA.dist2 >= 0) { |
1040 | if (bi_shift_A(bi, brg.prfA.dist2, pfo_bi)) |
1041 | prefetch_A(bi, pfo_bi, prf2A, prefetch_all); |
1042 | } |
1043 | if (brg.prfB.dist1 >= 0) { |
1044 | if (bi_shift_B(bi, brg.prfB.dist1, pfo_bi)) |
1045 | prefetch_B(bi, pfo_bi, prf1B, prefetch_all); |
1046 | } |
1047 | if (brg.prfB.dist2 >= 0) { |
1048 | if (bi_shift_B(bi, brg.prfB.dist2, pfo_bi)) |
1049 | prefetch_B(bi, pfo_bi, prf2B, prefetch_all); |
1050 | } |
1051 | } |
1052 | |
1053 | void jit_brgemm_amx_uker_base_t::process_output_range(brgemm_iteration_t &bi, |
1054 | int bd_start, int bd_finish, int bd_inp_bdb, int bdb, int ldb) { |
1055 | |
1056 | const auto wsp_offset = (use_ils_ || brg.interleave_tilestores_) |
1057 | ? (bdb * prev_bi_.ldi.block2 + ldb) * brg.bd_block |
1058 | * ld_block_C_size_ |
1059 | : 0; |
1060 | |
1061 | const auto k_mask = (!bi.ldi.is_tail) ? ld_full_mask : ld_tail_mask; |
1062 | |
1063 | // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) -> |
1064 | // accumulated values are already converted to ps in apply_alpha_beta() |
1065 | const bool alpha_or_beta_applicable = brg.alpha != 1.0f || brg.beta != 0.f; |
1066 | const bool beta_uses_vadd |
1067 | = brg.beta == 1.f && IMPLICATION(brg.is_int8, brg.alpha == 1.0f); |
1068 | const bool dq2ps_required = brg.is_int8 |
1069 | && IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); |
1070 | |
1071 | bool some_bd_mask = false; |
1072 | for (int bd = bd_start; bd < bd_finish; bd++) { |
1073 | auto zmm = accm(bd); |
1074 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1075 | if (bd_out_bd == -1) continue; |
1076 | |
1077 | auto vreg_acc |
1078 | = bi.ldi.is_tail ? accm(bd) | ld_tail_mask | T_z : accm(bd); |
1079 | some_bd_mask = true; |
1080 | |
1081 | const auto buf_offset = bd * ld_block_C_size_; |
1082 | vmovups(vreg_acc, ptr[reg_buf + buf_offset + wsp_offset]); |
1083 | |
1084 | const auto c_offset = C_offset(bd_out_bd, bi.ldi.pos + ldb); |
1085 | const auto ptr_C = EVEX_compress_addr(reg_C, c_offset); |
1086 | |
1087 | if (need_to_apply_alpha_beta_) |
1088 | apply_alpha_beta_to_vector(zmm.getIdx(), ptr_C, bi.ldi.is_tail); |
1089 | |
1090 | if (!bi.apply_postops) continue; |
1091 | |
1092 | if (dq2ps_required) vcvtdq2ps(zmm, zmm); |
1093 | } |
1094 | |
1095 | if (!bi.apply_postops || !some_bd_mask) return; |
1096 | |
1097 | if (brg.zp_type_a != brgemm_broadcast_t::none) { |
1098 | for (int bd = bd_start; bd < bd_finish; bd++) { |
1099 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1100 | if (bd_out_bd == -1) continue; |
1101 | |
1102 | auto zmm = accm(bd); |
1103 | vaddps(zmm, zmm, zmm_zp_comp_a); |
1104 | } |
1105 | } |
1106 | |
1107 | if (brg.zp_type_b != brgemm_broadcast_t::none) { |
1108 | mov(reg_zp_comp_b, ptr[rsp + reg_zp_comp_b_offs_]); |
1109 | |
1110 | auto zmm_zp_comp_b = zmm_tmp_1(); |
1111 | for (int bd = bd_start; bd < bd_finish; bd++) { |
1112 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1113 | if (bd_out_bd == -1) continue; |
1114 | |
1115 | auto zmm = accm(bd); |
1116 | |
1117 | const auto zp_comp_b_off = zp_comp_b_offset(bd_out_bd); |
1118 | vcvtdq2ps(zmm_zp_comp_b, |
1119 | EVEX_compress_addr(reg_zp_comp_b, zp_comp_b_off, true)); |
1120 | |
1121 | vaddps(zmm, zmm, zmm_zp_comp_b); |
1122 | } |
1123 | } |
1124 | |
1125 | if (brg.with_scales) { |
1126 | for (int bd = bd_start; bd < bd_finish; bd++) { |
1127 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1128 | if (bd_out_bd == -1) continue; |
1129 | |
1130 | auto zmm = accm(bd); |
1131 | const Xbyak::Zmm scaled_zmm = zmm_mask(zmm, true, false, k_mask); |
1132 | vmulps(scaled_zmm, scaled_zmm, zmm_scales(ldb)); |
1133 | } |
1134 | } |
1135 | |
1136 | if (brg.with_bias) { |
1137 | for (int bd = bd_start; bd < bd_finish; bd++) { |
1138 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1139 | if (bd_out_bd == -1) continue; |
1140 | |
1141 | auto zmm = accm(bd); |
1142 | vaddps(zmm, zmm, zmm_bias(ldb)); |
1143 | } |
1144 | } |
1145 | |
1146 | if (postops_injector_) { |
1147 | apply_post_ops_to_range(bi, bd_start, bd_finish, bd_inp_bdb, ldb); |
1148 | } |
1149 | } |
1150 | |
1151 | void jit_brgemm_amx_uker_base_t::store_vector_with_post_ops(const int idx, |
1152 | const Address &addr, const int bd, const int ldb, bool is_ld_tail) { |
1153 | auto zmm = Zmm(idx); |
1154 | auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask; |
1155 | |
1156 | if (brg.zp_type_c != brgemm_broadcast_t::none) vaddps(zmm, zmm, zmm_zp_c); |
1157 | |
1158 | maybe_saturation(zmm); |
1159 | |
1160 | auto ymm = Xbyak::Ymm(idx); |
1161 | const Xbyak::Zmm r_zmm = zmm_mask(zmm, true, true, k_mask); |
1162 | const Xbyak::Ymm r_ymm = ymm_mask(ymm, true, true, k_mask); |
1163 | |
1164 | switch (brg.dt_d) { |
1165 | case data_type::f32: |
1166 | case data_type::s32: vmovups(addr, r_zmm); break; |
1167 | case data_type::bf16: |
1168 | vcvtneps2bf16(ymm, zmm); |
1169 | vmovdqu16(addr, r_ymm); |
1170 | break; |
1171 | case data_type::f16: |
1172 | vcvtps2ph(ymm, zmm, _op_mxcsr); |
1173 | vmovdqu16(addr, r_ymm); |
1174 | break; |
1175 | case data_type::s8: vpmovsdb(addr, r_zmm); break; |
1176 | case data_type::u8: vpmovusdb(addr, r_zmm); break; |
1177 | default: assert(!"unknown dst_dt" ); |
1178 | } |
1179 | } |
1180 | |
1181 | void jit_brgemm_amx_uker_base_t::store_vector_without_post_ops( |
1182 | const int idx, const Address &addr, bool is_ld_tail) { |
1183 | auto zmm = Zmm(idx); |
1184 | |
1185 | maybe_saturation(zmm); |
1186 | |
1187 | if (is_ld_tail) |
1188 | vmovups(addr | ld_tail_mask | T_z, zmm); |
1189 | else |
1190 | vmovups(addr, zmm); |
1191 | } |
1192 | |
1193 | void jit_brgemm_amx_uker_base_t::store_vector( |
1194 | brgemm_iteration_t &bi, const int idx, const int bd, const int ldb) { |
1195 | auto ldb_pos = bi.ldi.pos + ldb; |
1196 | auto is_ld_tail = bi.ldi.is_tail; |
1197 | const auto c_offset = C_offset(bd, ldb_pos); |
1198 | const auto d_offset = D_offset(bd, ldb_pos); |
1199 | |
1200 | auto ptr_C = EVEX_compress_addr(reg_C, c_offset); |
1201 | auto ptr_D = EVEX_compress_addr(reg_D, d_offset); |
1202 | |
1203 | if (bi.apply_postops) |
1204 | store_vector_with_post_ops(idx, ptr_D, bd, ldb_pos, is_ld_tail); |
1205 | else if (are_post_ops_applicable_) |
1206 | store_vector_without_post_ops(idx, ptr_C, is_ld_tail); |
1207 | else |
1208 | store_vector_without_post_ops(idx, ptr_D, is_ld_tail); |
1209 | } |
1210 | |
1211 | void jit_brgemm_amx_uker_base_t::interleave_store( |
1212 | brgemm_iteration_t &bi, bool store_all) { |
1213 | |
1214 | if (!use_ils_) return; |
1215 | if (!was_prev_bi) return; |
1216 | if (!get_store_by_vectors(prev_bi_.apply_postops)) return; |
1217 | |
1218 | if (store_all) prefetching(prev_bi_, true); |
1219 | |
1220 | const auto bd_inp_bdb = prev_bi_.bdi.pos; |
1221 | |
1222 | auto cur_bdb = ils_bdb_; |
1223 | auto cur_ldb = ils_ldb_; |
1224 | |
1225 | // if first block |
1226 | if (ils_vec_ == 0) { |
1227 | if (!prepare_post_ops_registers_once_) { |
1228 | prepare_post_ops_registers(prev_bi_); |
1229 | } |
1230 | prepare_post_ops_registers_ldb(prev_bi_, 0); |
1231 | ils_bd_start_ = 0; |
1232 | auto bd_finish = nstl::min(ils_bd_step_, prev_bi_.bdi.block); |
1233 | process_output_range( |
1234 | prev_bi_, 0, bd_finish, bd_inp_bdb, cur_bdb, cur_ldb); |
1235 | } |
1236 | |
1237 | const auto calc_ops = calc_ops_CD(bi); |
1238 | const auto ils_store_ops |
1239 | = prev_bi_.ldi.block2 * prev_bi_.bdi.block2 * prev_bi_.bdi.block; |
1240 | const auto ils_vecs_per_store |
1241 | = (calc_ops) ? div_up(ils_store_ops, calc_ops) : 0; |
1242 | |
1243 | // last bd_block may be bd_tail |
1244 | const auto bdb_row = prev_bi_.bdi.block * prev_bi_.ldi.block2; |
1245 | const auto total_vectors = prev_bi_.bdi.block2 * bdb_row; |
1246 | const auto nvecs = store_all ? total_vectors : ils_vecs_per_store; |
1247 | for (int vec = 0; vec < nvecs && ils_vec_ < total_vectors; vec++) { |
1248 | const auto bdb = ils_vec_ / bdb_row; |
1249 | const auto vec_in_bdb_row = ils_vec_ - bdb * bdb_row; |
1250 | const auto ldb = vec_in_bdb_row / prev_bi_.bdi.block; |
1251 | const auto bd = vec_in_bdb_row % prev_bi_.bdi.block; |
1252 | |
1253 | auto bd_inp_bdb = prev_bi_.bdi.bdb_pos[bdb]; |
1254 | if (ldb != cur_ldb) prepare_post_ops_registers_ldb(prev_bi_, ldb); |
1255 | |
1256 | if (bdb != cur_bdb || ldb != cur_ldb |
1257 | || rnd_dn(bd, ils_bd_step_) != ils_bd_start_) { |
1258 | ils_bd_start_ = rnd_dn(bd, ils_bd_step_); |
1259 | auto bd_finish = nstl::min( |
1260 | ils_bd_start_ + ils_bd_step_, prev_bi_.bdi.block); |
1261 | process_output_range( |
1262 | prev_bi_, ils_bd_start_, bd_finish, bd_inp_bdb, bdb, ldb); |
1263 | } |
1264 | |
1265 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1266 | if (bd_out_bd != -1) { |
1267 | auto vreg_acc = prev_bi_.ldi.is_tail ? accm(bd) | ld_tail_mask | T_z |
1268 | : accm(bd); |
1269 | |
1270 | store_vector(prev_bi_, vreg_acc.getIdx(), bd_out_bd, ldb); |
1271 | } |
1272 | cur_bdb = bdb; |
1273 | cur_ldb = ldb; |
1274 | ils_vec_++; |
1275 | } |
1276 | ils_ldb_ = cur_ldb; |
1277 | ils_bdb_ = cur_bdb; |
1278 | } |
1279 | |
1280 | void jit_brgemm_amx_uker_base_t::store_accumulators(brgemm_iteration_t &bi) { |
1281 | |
1282 | const auto store_by_vectors = get_store_by_vectors(bi.apply_postops); |
1283 | |
1284 | if (store_by_vectors) { |
1285 | if (!brg.interleave_tilestores_) |
1286 | mov(reg_stride_ld_block, ld_block_C_size_); |
1287 | } else |
1288 | mov(reg_stride_ld_block, LDC_size_); |
1289 | |
1290 | prev_bi_ = bi; |
1291 | |
1292 | ils_vec_ = 0; |
1293 | ils_bdb_ = 0; |
1294 | ils_ldb_ = 0; |
1295 | was_prev_bi = true; |
1296 | |
1297 | prf1C.vec = 0; |
1298 | prf2C.vec = 0; |
1299 | |
1300 | if (store_by_vectors && !use_ils_ && !prepare_post_ops_registers_once_) |
1301 | prepare_post_ops_registers(bi); |
1302 | |
1303 | for (int bdb = 0; bdb < bi.bdi.block2; bdb++) { |
1304 | const auto bd_inp_bdb = bi.bdi.bdb_pos[bdb]; |
1305 | |
1306 | for (int ldb = 0; ldb < bi.ldi.block2; ldb++) { |
1307 | if (store_by_vectors) { |
1308 | if (!brg.interleave_tilestores_) { |
1309 | const auto wsp_offset = use_ils_ |
1310 | ? (bdb * bi.ldi.block2 + ldb) * brg.bd_block |
1311 | * ld_block_C_size_ |
1312 | : 0; |
1313 | tilestored(ptr[reg_buf + reg_stride_ld_block + wsp_offset], |
1314 | Tmm(get_C_tensor(bi, bdb, ldb))); |
1315 | } |
1316 | if (use_ils_) continue; |
1317 | |
1318 | prepare_post_ops_registers_ldb(bi, ldb); |
1319 | |
1320 | for (int bd_step = 0; bd_step < bi.bdi.block; |
1321 | bd_step += ils_bd_step_) { |
1322 | auto bd_finish |
1323 | = nstl::min(bd_step + ils_bd_step_, bi.bdi.block); |
1324 | process_output_range( |
1325 | bi, bd_step, bd_finish, bd_inp_bdb, bdb, ldb); |
1326 | |
1327 | for (int bd = bd_step; bd < bd_finish; bd++) { |
1328 | const auto bd_out_bd = get_out_bd(bd_inp_bdb, bd); |
1329 | if (bd_out_bd == -1) continue; |
1330 | |
1331 | auto vreg_acc = bi.ldi.is_tail |
1332 | ? accm(bd) | ld_tail_mask | T_z |
1333 | : accm(bd); |
1334 | store_vector(bi, vreg_acc.getIdx(), bd_out_bd, ldb); |
1335 | } |
1336 | } |
1337 | } else if (!brg.interleave_tilestores_) { |
1338 | const auto bd_out_bdb = get_out_bd(bd_inp_bdb, 0); |
1339 | const auto c_offset |
1340 | = C_block_offset(bd_out_bdb, bi.ldi.pos + ldb); |
1341 | tilestored(ptr[reg_C + reg_stride_ld_block + c_offset], |
1342 | Tmm(get_C_tensor(bi, bdb, ldb))); |
1343 | } |
1344 | } |
1345 | } |
1346 | } |
1347 | |
1348 | void jit_brgemm_amx_uker_base_t::set_A_B_matrices(int bs) { |
1349 | assert(brg.type == brgemm_addr); |
1350 | if (brg.brgattr.max_bs == 1) return; |
1351 | auto batch_offset = (size_t)bs * sizeof(brgemm_batch_element_t); |
1352 | if (brg.layout == brgemm_row_major) { |
1353 | mov(reg_aux_A, |
1354 | EVEX_compress_addr(reg_addr_batch, |
1355 | batch_offset + GET_OFF_BATCH_ELEMENT(ptr.A))); |
1356 | mov(reg_aux_B, |
1357 | EVEX_compress_addr(reg_addr_batch, |
1358 | batch_offset + GET_OFF_BATCH_ELEMENT(ptr.B))); |
1359 | } else { |
1360 | mov(reg_aux_A, |
1361 | EVEX_compress_addr(reg_addr_batch, |
1362 | batch_offset + GET_OFF_BATCH_ELEMENT(ptr.B))); |
1363 | mov(reg_aux_B, |
1364 | EVEX_compress_addr(reg_addr_batch, |
1365 | batch_offset + GET_OFF_BATCH_ELEMENT(ptr.A))); |
1366 | } |
1367 | } |
1368 | |
1369 | void jit_brgemm_amx_uker_base_t::set_A_B_matrices() { |
1370 | assert(brg.type == brgemm_addr); |
1371 | assert(brg.brgattr.var_bs); |
1372 | |
1373 | if (brg.layout == brgemm_row_major) { |
1374 | mov(reg_aux_A, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]); |
1375 | mov(reg_aux_B, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]); |
1376 | } else { |
1377 | mov(reg_aux_A, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.B)]); |
1378 | mov(reg_aux_B, ptr[reg_aux1_batch + GET_OFF_BATCH_ELEMENT(ptr.A)]); |
1379 | } |
1380 | } |
1381 | |
1382 | void jit_brgemm_amx_uker_base_t::maybe_tileloadd_nt( |
1383 | brgemm_iteration_t &bi, matrix_kind_t mk, int xdb, size_t offset) { |
1384 | |
1385 | const bool is_A = mk == matrix_kind_t::matrix_A; |
1386 | bool load_nt = is_A ? brg.load_nt_A : brg.load_nt_B; |
1387 | |
1388 | auto t1 = Tmm(is_A ? brg.get_A_tensor(xdb, bi.bdi.is_tail) |
1389 | : brg.get_B_tensor(xdb, bi.ldi.is_tail)); |
1390 | auto reg_base = is_A ? reg_aux_A : reg_aux_B; |
1391 | auto reg_stride = is_A ? reg_stride_lda : reg_stride_ldb; |
1392 | |
1393 | if (brg.is_bf32) |
1394 | // try_load_nt is not supported in maybe_pre_process_data as there is |
1395 | // no guarantee that the data is cacheline aligned. |
1396 | maybe_pre_process_data(bi, t1, reg_base, offset, reg_stride, mk); |
1397 | else if (load_nt) |
1398 | tileloaddt1(t1, ptr[reg_base + offset + reg_stride]); |
1399 | else |
1400 | tileloadd(t1, ptr[reg_base + offset + reg_stride]); |
1401 | } |
1402 | |
1403 | void jit_brgemm_amx_uker_base_t::maybe_tilestore(brgemm_iteration_t &bi, |
1404 | int bdb_idx, int ldb_idx, bool do_pre_tilestore, |
1405 | bool do_post_tilestore) { |
1406 | auto current_tensor_idx = get_C_tensor(bi, bdb_idx, ldb_idx); |
1407 | |
1408 | if (!brg.interleave_tilestores_) return; |
1409 | const auto current_tensor_number |
1410 | = current_tensor_idx - get_C_tensor(bi, 0, 0); |
1411 | const auto store_tensor_shift |
1412 | = do_pre_tilestore ? (bi.bdi.block2 == brg.bdb2_tail ? 2 : 1) : 0; |
1413 | const auto store_tensor_idx = current_tensor_idx + store_tensor_shift; |
1414 | const auto store_tensor_number = current_tensor_number + store_tensor_shift; |
1415 | |
1416 | const auto max_store_tensor_number |
1417 | = prev_bi_.bdi.block2 * prev_bi_.ldi.block2; |
1418 | bool perform_store |
1419 | = (do_pre_tilestore |
1420 | && (store_tensor_number >= 2 |
1421 | && store_tensor_number < max_store_tensor_number)) |
1422 | || (do_post_tilestore && (store_tensor_number < 2)); |
1423 | |
1424 | if (perform_store) { |
1425 | if (do_pre_tilestore) { |
1426 | bdb_idx = store_tensor_idx / bi.ldi.block2; |
1427 | ldb_idx = store_tensor_idx % bi.ldi.block2; |
1428 | } |
1429 | const bool store_by_vectors = get_store_by_vectors(bi.apply_postops); |
1430 | Tmm acc = Tmm(store_tensor_idx); |
1431 | if (store_by_vectors) { |
1432 | const auto wsp_offset = (use_ils_ || brg.interleave_tilestores_) |
1433 | ? (bdb_idx * bi.ldi.block2 + ldb_idx) * brg.bd_block |
1434 | * ld_block_C_size_ |
1435 | : 0; |
1436 | tilestored(ptr[reg_buf + reg_stride_ld_block + wsp_offset], acc); |
1437 | } else { |
1438 | const auto &store_inp_bdi |
1439 | = do_pre_tilestore ? prev_bi_.bdi : bi.bdi; |
1440 | const auto store_ldb_ind |
1441 | = do_pre_tilestore ? prev_bi_.ldi.pos : bi.ldi.pos; |
1442 | const auto bd_out_bdb |
1443 | = get_out_bd(store_inp_bdi.bdb_pos[bdb_idx], 0); |
1444 | const auto c_offset |
1445 | = C_block_offset(bd_out_bdb, store_ldb_ind + ldb_idx); |
1446 | tilestored(ptr[reg_C + reg_stride_ld_block + c_offset], acc); |
1447 | } |
1448 | tilezero(acc); |
1449 | } |
1450 | } |
1451 | |
1452 | void jit_brgemm_amx_uker_base_t::tdpbxxd(brgemm_iteration_t &bi, int bdb_idx, |
1453 | int ldb_idx, bool do_pre_tilestore, bool do_post_tilestore) { |
1454 | prefetching(bi, false); |
1455 | maybe_tilestore(bi, bdb_idx, ldb_idx, do_pre_tilestore, false); |
1456 | |
1457 | const Tmm &x1 = Tmm(get_C_tensor(bi, bdb_idx, ldb_idx)); |
1458 | const Tmm &x2 = Tmm(brg.get_A_tensor(bdb_idx, bi.bdi.is_tail)); |
1459 | const Tmm &x3 = Tmm(brg.get_B_tensor(ldb_idx, bi.ldi.is_tail)); |
1460 | |
1461 | if (brg.is_bf32 |
1462 | || (brg.dt_a == data_type::bf16 && brg.dt_b == data_type::bf16)) { |
1463 | tdpbf16ps(x1, x2, x3); |
1464 | } else if (brg.dt_a == data_type::f16 && brg.dt_b == data_type::f16) { |
1465 | tdpfp16ps(x1, x2, x3); |
1466 | } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8) { |
1467 | tdpbuud(x1, x2, x3); |
1468 | } else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8) { |
1469 | tdpbusd(x1, x2, x3); |
1470 | } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8) { |
1471 | tdpbsud(x1, x2, x3); |
1472 | } else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8) { |
1473 | tdpbssd(x1, x2, x3); |
1474 | } else { |
1475 | assert(!"unsupported combination" ); |
1476 | } |
1477 | interleave_store(bi, false); |
1478 | maybe_tilestore(bi, bdb_idx, ldb_idx, false, do_post_tilestore); |
1479 | } |
1480 | |
1481 | // This method down-converts the data from f32 to bf16 and saves at reg_buf. |
1482 | // Generally used by matrix_A, where no vnni transformation of data is needed. |
1483 | void jit_brgemm_amx_uker_base_t::bf32_downconvert(int num_rows, |
1484 | int tile_num_col_bytes, reg64_t reg_data, int offset, |
1485 | reg64_t reg_data_stride, reg64_t reg_buf, bool is_rd_tail) { |
1486 | const auto rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block; |
1487 | const auto max_num_cols |
1488 | = nstl::min<int>(tile_num_col_bytes / sizeof(bfloat16_t), rd_block); |
1489 | const auto col_tail = max_num_cols % simd_w; |
1490 | auto zmm_1 = zmm_tmp_1(); |
1491 | auto zmm_2 = zmm_tmp_2(); |
1492 | auto zmm_2_masked = col_tail ? zmm_2 | bf32_col_mask | T_z : zmm_2; |
1493 | |
1494 | assert(max_num_cols > 0); |
1495 | |
1496 | if (col_tail) { |
1497 | const int tail_mask = (1 << col_tail) - 1; |
1498 | auto reg_tmp_32 = reg_tmp_gpr.cvt32(); |
1499 | mov(reg_tmp_32, tail_mask); |
1500 | kmovw(bf32_col_mask, reg_tmp_32); |
1501 | } |
1502 | |
1503 | // Note: using the same register used in col_tail, so order is important |
1504 | const auto reg_data_aux = reg_tmp_gpr; |
1505 | lea(reg_data_aux, ptr[reg_data + offset]); |
1506 | |
1507 | for (int r = 0; r < num_rows; ++r) { |
1508 | if (max_num_cols > 16) { |
1509 | vmovups(zmm_1, ptr[reg_data_aux]); |
1510 | vmovups(zmm_2_masked, ptr[reg_data_aux + zmm_width_in_bytes]); |
1511 | vcvtne2ps2bf16(zmm_1, zmm_2, zmm_1); |
1512 | // we assume enough padding space is available. |
1513 | vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_1); |
1514 | } else { |
1515 | auto ymm_1 = Ymm(zmm_1.getIdx()); |
1516 | auto ymm_1_masked |
1517 | = max_num_cols == 16 ? ymm_1 : ymm_1 | bf32_col_mask | T_z; |
1518 | vcvtneps2bf16(ymm_1_masked, ptr[reg_data_aux]); |
1519 | vmovups(ptr[reg_buf + r * zmm_width_in_bytes], ymm_1); |
1520 | } |
1521 | add(reg_data_aux, reg_data_stride); |
1522 | } |
1523 | } |
1524 | |
1525 | // This method down-converts and transforms the data from f32 to bf16_vnni |
1526 | // format. Generally used by matrix_B. |
1527 | void jit_brgemm_amx_uker_base_t::bf32_downconvert_to_vnni(int num_rows, |
1528 | int tile_num_col_bytes, reg64_t reg_data, int offset, |
1529 | reg64_t reg_data_stride, reg64_t reg_buf, bool is_rd_tail) { |
1530 | const auto num_cols_ele = tile_num_col_bytes / sizeof(bfloat16_t); |
1531 | const auto num_N = num_cols_ele / sizeof(bfloat16_t); |
1532 | const auto col_tail = num_N % simd_w; |
1533 | const auto zmm_1 = zmm_tmp_1(); |
1534 | const auto zmm_2 = zmm_tmp_2(); |
1535 | |
1536 | assert(num_N > 0); |
1537 | |
1538 | auto load = [&](Zmm zmm, Address addr) { |
1539 | if (col_tail) |
1540 | vmovups(zmm | bf32_col_mask | T_z, addr); |
1541 | else |
1542 | vmovups(zmm, addr); |
1543 | }; |
1544 | |
1545 | if (col_tail) { |
1546 | const int tail_mask = (1 << col_tail) - 1; |
1547 | auto reg_tmp_32 = reg_tmp_gpr.cvt32(); |
1548 | mov(reg_tmp_32, tail_mask); |
1549 | kmovw(bf32_col_mask, reg_tmp_32); |
1550 | } |
1551 | |
1552 | // Note: using the same register used in col_tail, so order is important |
1553 | const auto reg_data_aux = reg_tmp_gpr; |
1554 | lea(reg_data_aux, ptr[reg_data + offset]); |
1555 | |
1556 | const auto rd_block = is_rd_tail ? brg.rdb_tail : brg.rd_block; |
1557 | const int vnni_granularity |
1558 | = data_type_vnni_granularity(data_type_t::dnnl_bf16); |
1559 | const auto r_end |
1560 | = nstl::min(utils::div_up(rd_block, vnni_granularity), num_rows); |
1561 | |
1562 | for (int r = 0; r < r_end; ++r) { |
1563 | load(zmm_1, ptr[reg_data_aux]); |
1564 | |
1565 | if (r * vnni_granularity + 1 >= rd_block) { |
1566 | vpxord(zmm_2, zmm_2, zmm_2); |
1567 | } else { |
1568 | load(zmm_2, ptr[reg_data_aux + reg_data_stride]); |
1569 | } |
1570 | |
1571 | vcvtne2ps2bf16(zmm_1, zmm_2, zmm_1); |
1572 | vpermw(zmm_1, zmm_bf32_pemute, zmm_1); |
1573 | vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_1); |
1574 | lea(reg_data_aux, |
1575 | ptr[reg_data_aux + vnni_granularity * reg_data_stride]); |
1576 | } |
1577 | |
1578 | // zero rest of the tile data |
1579 | if (r_end < num_rows) { |
1580 | vpxord(zmm_2, zmm_2, zmm_2); |
1581 | for (int r = r_end; r < num_rows; ++r) |
1582 | vmovups(ptr[reg_buf + r * zmm_width_in_bytes], zmm_2); |
1583 | } |
1584 | } |
1585 | |
1586 | void jit_brgemm_amx_uker_base_t::maybe_pre_process_data(brgemm_iteration_t &bi, |
1587 | const Tmm &t1, reg64_t reg_base, size_t offset, reg64_t reg_stride, |
1588 | matrix_kind_t mk) { |
1589 | |
1590 | auto should_save_transform = [&](matrix_kind_t mk) { |
1591 | // save if there is a reuse |
1592 | if (mk == matrix_A) { |
1593 | return brg.ldb + (brg.ldb_tail != 0) > brg.ld_block2; |
1594 | } else { |
1595 | return brg.bdb + (brg.bdb_tail != 0) > brg.bd_block2; |
1596 | } |
1597 | }; |
1598 | |
1599 | const bool is_A = mk == matrix_A; |
1600 | auto &transform_buf = is_A ? transform_buf_map_A_ : transform_buf_map_B_; |
1601 | |
1602 | const auto transform_offset |
1603 | = use_ils_ ? brg.get_num_C_tiles() * tile_size : 0; |
1604 | const auto max_bdb2 = brg.bd_block2; |
1605 | const auto max_rdb = brg.rdb + (brg.rdb_tail != 0); |
1606 | const auto matrix_a_offset = transform_offset; |
1607 | const auto matrix_b_offset = transform_offset |
1608 | + tile_size |
1609 | * (nstl::max<int>(should_save_transform(mk), |
1610 | should_save_transform(matrix_A) * brg.brgattr.max_bs |
1611 | * max_bdb2 * max_rdb)); |
1612 | const auto matrix_offset = is_A ? matrix_a_offset : matrix_b_offset; |
1613 | const std::string key |
1614 | = std::to_string(bi.bsi.pos) + "_" + std::to_string(offset); |
1615 | |
1616 | if (transform_buf.find(key) != transform_buf.end()) { |
1617 | auto buf_idx = transform_buf[key]; |
1618 | auto offt = matrix_offset + buf_idx * tile_size; |
1619 | tileloadd(t1, ptr[reg_buf + reg_bf32_stride + offt]); |
1620 | return; |
1621 | } |
1622 | |
1623 | auto buf_offt = matrix_offset; |
1624 | // save offset of the transformation if required. |
1625 | if (should_save_transform(mk)) { |
1626 | auto buf_idx = transform_buf.size(); |
1627 | buf_offt = matrix_offset + buf_idx * tile_size; |
1628 | transform_buf[key] = buf_idx; |
1629 | } |
1630 | |
1631 | if (buf_offt) add(reg_buf, buf_offt); |
1632 | mov(reg_bf32_stride, zmm_width_in_bytes); |
1633 | |
1634 | assert(t1.getIdx() >= 0 && t1.getIdx() < 16); |
1635 | const auto num_rows = palette_.rows[t1.getIdx()]; |
1636 | const auto num_col_bytes = palette_.cols[t1.getIdx()]; |
1637 | if (is_A) { |
1638 | bf32_downconvert(num_rows, num_col_bytes, reg_base, offset, reg_stride, |
1639 | reg_buf, bi.rdi.is_tail); |
1640 | } else { |
1641 | bf32_downconvert_to_vnni(num_rows, num_col_bytes, reg_base, offset, |
1642 | reg_stride, reg_buf, bi.rdi.is_tail); |
1643 | } |
1644 | |
1645 | // load into tmm from the transformed data. |
1646 | tileloadd(t1, ptr[reg_buf + reg_bf32_stride]); |
1647 | |
1648 | // reset buf pointer. |
1649 | if (buf_offt) sub(reg_buf, buf_offt); |
1650 | } |
1651 | |
1652 | void jit_brgemm_amx_uker_base_t::gemm_microkernel_amx(brgemm_iteration_t &bi) { |
1653 | prf1A.vec = 0; |
1654 | prf2A.vec = 0; |
1655 | prf1B.vec = 0; |
1656 | prf2B.vec = 0; |
1657 | |
1658 | const auto store_by_vectors = get_store_by_vectors(bi.apply_postops); |
1659 | |
1660 | bool do_post_tilestore = (brg.interleave_tilestores_ && bi.bsi.is_last |
1661 | && imap_.is_last_rdi(bi.rdi)); |
1662 | |
1663 | bool do_pre_tilestore = (brg.interleave_tilestores_ && bi.bsi.is_first |
1664 | && bi.rdi.pos == 0 && was_prev_bi); |
1665 | |
1666 | if (store_by_vectors) |
1667 | mov(reg_stride_ld_block, ld_block_C_size_); |
1668 | else |
1669 | mov(reg_stride_ld_block, LDC_size_); |
1670 | |
1671 | const auto rdb_A_off = rdb_A_offset(bi); |
1672 | const auto rdb_B_off = rdb_B_offset(bi) + ldb_B_offset(bi); |
1673 | |
1674 | for (int bdb = 0; bdb < bi.bdi.block2; bdb++) { |
1675 | const auto bd_inp_bdb = bi.bdi.bdb_pos[bdb]; |
1676 | maybe_tileloadd_nt(bi, matrix_kind_t::matrix_A, bdb, |
1677 | rdb_A_off + A_offset(bd_inp_bdb)); |
1678 | for (int ldb = 0; ldb < bi.ldi.block2; ldb++) { |
1679 | if (bdb == 0) |
1680 | maybe_tileloadd_nt(bi, matrix_kind_t::matrix_B, ldb, |
1681 | rdb_B_off + B_offset(ldb)); |
1682 | if (ldb == 0) { |
1683 | if (bdb > 0) |
1684 | tdpbxxd(bi, bdb - 1, bi.ldi.block2 - 1, do_pre_tilestore, |
1685 | do_post_tilestore); |
1686 | } else |
1687 | tdpbxxd(bi, bdb, ldb - 1, do_pre_tilestore, do_post_tilestore); |
1688 | } |
1689 | } |
1690 | // last tdpbxxd |
1691 | tdpbxxd(bi, bi.bdi.block2 - 1, bi.ldi.block2 - 1, do_pre_tilestore, |
1692 | do_post_tilestore); |
1693 | } |
1694 | |
1695 | void jit_brgemm_amx_uker_base_t::rdb_loop(brgemm_iteration_t &bi) { |
1696 | for (size_t irdi = 0; irdi < imap_.rdis.size(); irdi++) { |
1697 | bi.rdi = imap_.rdis[irdi]; |
1698 | gemm_microkernel_amx(bi); |
1699 | } |
1700 | } |
1701 | |
1702 | void jit_brgemm_amx_uker_base_t::bs_loop_body(brgemm_iteration_t &bi) { |
1703 | if (brg.brgattr.var_bs) { |
1704 | set_A_B_matrices(); |
1705 | add(reg_aux1_batch, sizeof(brgemm_batch_element_t)); |
1706 | prefetcht0(ptr[reg_aux1_batch]); |
1707 | } else { |
1708 | set_A_B_matrices(bi.bsi.pos); |
1709 | } |
1710 | |
1711 | rdb_loop(bi); |
1712 | } |
1713 | |
1714 | void jit_brgemm_amx_uker_base_t::bs_loop(brgemm_iteration_t &bi) { |
1715 | |
1716 | load_accumulators(bi); |
1717 | |
1718 | if (brg.brgattr.var_bs) { |
1719 | if (brg.alpha != 0.f) { |
1720 | Label BS_loop_label, end_BS_loop_label, first_BS_loop_label, |
1721 | last_BS_loop_label; |
1722 | |
1723 | mov(reg_BS_loop, reg_BS); |
1724 | cmp(reg_BS_loop, 0); |
1725 | jz(end_BS_loop_label, T_NEAR); |
1726 | |
1727 | mov(reg_aux1_batch, reg_addr_batch); |
1728 | // first bs iteration |
1729 | cmp(reg_BS_loop, 1); |
1730 | jg(first_BS_loop_label, T_NEAR); |
1731 | |
1732 | bi.bsi = imap_.bsis[0]; |
1733 | // only one BS iteration: first and last |
1734 | bi.bsi.is_first = true; |
1735 | bi.bsi.is_last = true; |
1736 | bs_loop_body(bi); |
1737 | jmp(end_BS_loop_label, T_NEAR); |
1738 | |
1739 | // first BS iteration |
1740 | L_aligned(first_BS_loop_label, 64); |
1741 | bi.bsi.is_first = true; |
1742 | bi.bsi.is_last = false; |
1743 | bs_loop_body(bi); |
1744 | |
1745 | dec(reg_BS_loop); |
1746 | cmp(reg_BS_loop, 1); |
1747 | je(last_BS_loop_label, T_NEAR); |
1748 | |
1749 | // middle BS iterations |
1750 | L_aligned(BS_loop_label, 64); |
1751 | { |
1752 | bi.bsi.is_first = false; |
1753 | bi.bsi.is_last = false; |
1754 | bs_loop_body(bi); |
1755 | dec(reg_BS_loop); |
1756 | cmp(reg_BS_loop, 1); |
1757 | jg(BS_loop_label, T_NEAR); |
1758 | } |
1759 | // last BS iteration |
1760 | L_aligned(last_BS_loop_label, 64); |
1761 | bi.bsi.is_first = false; |
1762 | bi.bsi.is_last = true; |
1763 | bs_loop_body(bi); |
1764 | |
1765 | L_aligned(end_BS_loop_label, 64); |
1766 | } |
1767 | store_accumulators(bi); |
1768 | } else { |
1769 | if (brg.alpha != 0.f) { |
1770 | for (int bs = 0; bs < brg.brgattr.max_bs; bs++) { |
1771 | bi.bsi = imap_.bsis[bs]; |
1772 | bs_loop_body(bi); |
1773 | } |
1774 | } |
1775 | store_accumulators(bi); |
1776 | } |
1777 | } |
1778 | |
1779 | void jit_brgemm_amx_uker_base_t::ldb_loop_body(brgemm_iteration_t &bi) { |
1780 | if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost) |
1781 | bdb_loop(bi); |
1782 | else if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost) |
1783 | bs_loop(bi); |
1784 | else |
1785 | assert(!"Unknown loop order!" ); |
1786 | } |
1787 | |
1788 | void jit_brgemm_amx_uker_base_t::ldb_loop(brgemm_iteration_t &bi) { |
1789 | // clear the transform cache for A, as the existing data is invalid as |
1790 | // we move to next bdb2 block. |
1791 | transform_buf_map_A_.clear(); |
1792 | bi.ldi = imap_.ldis[0]; |
1793 | for (size_t ildi = 0; ildi < imap_.ldis.size(); ildi++) { |
1794 | bi.ldi = imap_.ldis[ildi]; |
1795 | ldb_loop_body(bi); |
1796 | } |
1797 | } |
1798 | |
1799 | void jit_brgemm_amx_uker_base_t::bdb_loop_body(brgemm_iteration_t &bi) { |
1800 | if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost) |
1801 | ldb_loop(bi); |
1802 | else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost) |
1803 | bs_loop(bi); |
1804 | else |
1805 | assert(!"Unknown loop order!" ); |
1806 | }; |
1807 | |
1808 | void jit_brgemm_amx_uker_base_t::bdb_loop(brgemm_iteration_t &bi) { |
1809 | bi.bdi = imap_.bdis[0]; |
1810 | |
1811 | for (size_t ibdi = 0; ibdi < imap_.bdis.size(); ibdi++) { |
1812 | bi.bdi = imap_.bdis[ibdi]; |
1813 | bdb_loop_body(bi); |
1814 | } |
1815 | } |
1816 | |
1817 | void jit_brgemm_amx_uker_base_t::top_loop(brgemm_iteration_t &bi) { |
1818 | init(bi); |
1819 | if (brg.brgattr.hint_innermost_loop == brgemm_ld_loop_innermost) |
1820 | bdb_loop(bi); |
1821 | else if (brg.brgattr.hint_innermost_loop == brgemm_bd_loop_innermost) |
1822 | ldb_loop(bi); |
1823 | else |
1824 | assert(!"Unknown loop order!" ); |
1825 | |
1826 | if (brg.interleave_tilestores_) { |
1827 | bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2); |
1828 | for_(int bdb = 0; bdb < brg.bd_block2; bdb++) |
1829 | for (int ldb = 0; ldb < brg.ld_block2; ldb++) { |
1830 | maybe_tilestore(bi, bdb, ldb, true, false); |
1831 | } |
1832 | } |
1833 | |
1834 | interleave_store(bi, true); |
1835 | } |
1836 | |
1837 | void jit_brgemm_amx_uker_base_t::fill_imap() { |
1838 | imap_.bdis.clear(); |
1839 | imap_.ldis.clear(); |
1840 | imap_.rdis.clear(); |
1841 | imap_.bsis.clear(); |
1842 | size_t bdi_pos = skipped_bd_mask(0); |
1843 | auto bdi = bd_iteration_t(bdi_pos, brg.bd_block, brg.bd_block2); |
1844 | for (int bdb2 = 0; bdb2 < brg.bdb2; bdb2++) { |
1845 | bdi.pos = bdi_pos; |
1846 | bdi.bdb_pos.clear(); |
1847 | for (int bdb = 0; bdb < bdi.block2; bdb++) { |
1848 | bdi.bdb_pos.push_back(bdi_pos); |
1849 | bdi_pos += brg.bd_block; |
1850 | bdi_pos = skipped_bd_mask(bdi_pos); |
1851 | } |
1852 | bdi.idx = imap_.bdis.size(); |
1853 | imap_.bdis.push_back(bdi); |
1854 | } |
1855 | if (brg.bdb2_tail > 0) { |
1856 | bdi.block2 = brg.bdb2_tail; |
1857 | bdi.pos = bdi_pos; |
1858 | bdi.bdb_pos.clear(); |
1859 | for (int bdb = 0; bdb < bdi.block2; bdb++) { |
1860 | bdi.bdb_pos.push_back(bdi_pos); |
1861 | bdi_pos += brg.bd_block; |
1862 | bdi_pos = skipped_bd_mask(bdi_pos); |
1863 | } |
1864 | bdi.idx = imap_.bdis.size(); |
1865 | imap_.bdis.push_back(bdi); |
1866 | } |
1867 | if (brg.bdb_tail > 0) { |
1868 | bdi.block2 = 1; |
1869 | bdi.block = brg.bdb_tail; |
1870 | bdi.is_tail = true; |
1871 | bdi.pos = bdi_pos; |
1872 | bdi.bdb_pos.clear(); |
1873 | bdi.bdb_pos.push_back(bdi_pos); |
1874 | bdi.idx = imap_.bdis.size(); |
1875 | imap_.bdis.push_back(bdi); |
1876 | } |
1877 | |
1878 | auto ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2); |
1879 | for (int ldb2 = 0; ldb2 < brg.ldb2; ldb2++) { |
1880 | ldi.idx = imap_.ldis.size(); |
1881 | imap_.ldis.push_back(ldi); |
1882 | ldi.pos += ldi.block2; |
1883 | } |
1884 | if (brg.ldb2_tail > 0) { |
1885 | ldi.block2 = brg.ldb2_tail; |
1886 | ldi.idx = imap_.ldis.size(); |
1887 | imap_.ldis.push_back(ldi); |
1888 | ldi.pos += ldi.block2; |
1889 | } |
1890 | if (brg.ldb_tail > 0) { |
1891 | ldi.block2 = 1; |
1892 | ldi.block = brg.ldb_tail; |
1893 | ldi.is_tail = true; |
1894 | ldi.idx = imap_.ldis.size(); |
1895 | imap_.ldis.push_back(ldi); |
1896 | } |
1897 | |
1898 | auto rdi = dim_iteration_t(0, brg.rd_block, 1); |
1899 | for (int rdb = 0; rdb < brg.rdb; rdb++) { |
1900 | rdi.idx = imap_.rdis.size(); |
1901 | imap_.rdis.push_back(rdi); |
1902 | rdi.pos++; |
1903 | } |
1904 | if (brg.rdb_tail > 0) { |
1905 | rdi.block = brg.rdb_tail; |
1906 | rdi.is_tail = true; |
1907 | rdi.idx = imap_.rdis.size(); |
1908 | imap_.rdis.push_back(rdi); |
1909 | } |
1910 | |
1911 | bs_iteration_t bsi; |
1912 | for (int bs = 0; bs < brg.brgattr.max_bs; bs++) { |
1913 | bsi.pos = bs; |
1914 | bsi.is_first = (bs == 0); |
1915 | bsi.is_last = (bs == brg.brgattr.max_bs - 1); |
1916 | bsi.idx = imap_.bsis.size(); |
1917 | imap_.bsis.push_back(bsi); |
1918 | } |
1919 | } |
1920 | |
1921 | void jit_brgemm_amx_uker_base_t::init(brgemm_iteration_t &bi) { |
1922 | was_prev_bi = false; |
1923 | |
1924 | if (brg.brgattr.max_bs == 1) { |
1925 | if (brg.layout == brgemm_row_major) { |
1926 | mov(reg_aux_A, |
1927 | EVEX_compress_addr( |
1928 | reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.A))); |
1929 | mov(reg_aux_B, |
1930 | EVEX_compress_addr( |
1931 | reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.B))); |
1932 | } else { |
1933 | mov(reg_aux_A, |
1934 | EVEX_compress_addr( |
1935 | reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.B))); |
1936 | mov(reg_aux_B, |
1937 | EVEX_compress_addr( |
1938 | reg_addr_batch, GET_OFF_BATCH_ELEMENT(ptr.A))); |
1939 | } |
1940 | } |
1941 | |
1942 | // for many primitives which use brgemm the brg.ldb2 is equal or less than 1 |
1943 | // so we can read post ops data only once per brgemm call |
1944 | |
1945 | if (brg.ldb2 > 1) { |
1946 | prepare_post_ops_registers_once_ = false; |
1947 | } else if (brg.ldb2 == 1) { |
1948 | if (brg.ldb2_tail == 0 && brg.ldb_tail == 0) { |
1949 | prepare_post_ops_registers_once_ = true; |
1950 | bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ld_block2); |
1951 | prepare_post_ops_registers(bi); |
1952 | } |
1953 | } else if (brg.ldb2_tail > 0) { |
1954 | if (brg.ldb_tail == 0) { |
1955 | prepare_post_ops_registers_once_ = true; |
1956 | bi.ldi = dim_iteration_t(0, brg.ld_block, brg.ldb2_tail); |
1957 | prepare_post_ops_registers(bi); |
1958 | } |
1959 | } else { |
1960 | prepare_post_ops_registers_once_ = true; |
1961 | bi.ldi = dim_iteration_t(0, brg.ldb_tail, 1); |
1962 | bi.ldi.is_tail = true; |
1963 | prepare_post_ops_registers(bi); |
1964 | } |
1965 | if (bi.apply_postops) |
1966 | dt_requires_saturation_ = one_of( |
1967 | brg.dt_d, data_type::u8, data_type::s8, data_type::s32); |
1968 | else { |
1969 | // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) -> |
1970 | // accumulated values are converted to ps in apply_alpha_beta() |
1971 | const bool alpha_or_beta_applicable |
1972 | = brg.alpha != 1.0f || brg.beta != 0.f; |
1973 | const bool beta_uses_vadd = brg.beta == 1.f |
1974 | && IMPLICATION(brg.is_int8, brg.alpha == 1.0f); |
1975 | dt_requires_saturation_ = brg.is_int8 |
1976 | && !IMPLICATION(alpha_or_beta_applicable, beta_uses_vadd); |
1977 | } |
1978 | if (dt_requires_saturation_) { |
1979 | init_saturate_f32( |
1980 | zmm_lbound, zmm_ubound, reg_tmp_gpr, data_type::f32, brg.dt_d); |
1981 | } |
1982 | |
1983 | fill_imap(); |
1984 | prf1A.pft = brgemm_kernel_prefetching_t::brgemm_prf1; |
1985 | prf1A.dist = brg.prfA.dist1; |
1986 | prf2A.pft = brgemm_kernel_prefetching_t::brgemm_prf2; |
1987 | prf2A.dist = brg.prfA.dist2; |
1988 | prf1B.pft = brgemm_kernel_prefetching_t::brgemm_prf1; |
1989 | prf1B.dist = brg.prfB.dist1; |
1990 | prf2B.pft = brgemm_kernel_prefetching_t::brgemm_prf2; |
1991 | prf2B.dist = brg.prfB.dist2; |
1992 | prf1C.pft = brgemm_kernel_prefetching_t::brgemm_prf1; |
1993 | prf1C.dist = brg.prfC.dist1; |
1994 | prf2C.pft = brgemm_kernel_prefetching_t::brgemm_prf2; |
1995 | prf2C.dist = brg.prfC.dist2; |
1996 | } |
1997 | |
1998 | void jit_brgemm_amx_uker_base_t::generate() { |
1999 | preamble(); |
2000 | |
2001 | sub(rsp, stack_space_needed_); |
2002 | |
2003 | const auto full_mask = size_t {0xffffffffffffffff}; |
2004 | const auto tail_mask = size_t((1 << brg.ldb_tail) - 1); |
2005 | LDA_size_ = brg.typesize_A * brg.LDA; |
2006 | LDB_size_ = brg.typesize_B * brg.LDB; |
2007 | LDC_size_ = brg.typesize_C * brg.LDC; |
2008 | LDD_size_ = brg.typesize_D * brg.LDD; |
2009 | |
2010 | LDA2_size_ = brg.typesize_A * brg.LDA2; |
2011 | LDB2_size_ = brg.typesize_B * brg.LDB2; |
2012 | LDC2_size_M_ = brg.typesize_C * brg.LDC2_M; |
2013 | LDC2_size_N_ = brg.typesize_C * brg.LDC2_N; |
2014 | |
2015 | ld_block_B_size_ = brg.typesize_B |
2016 | * ((brg.brgattr.LDB2 != 0) ? brg.brgattr.LDB2 : brg.ld_block); |
2017 | ld_block_C_size_ = brg.typesize_C * brg.ld_block; |
2018 | ld_block_D_size_ = brg.typesize_D * brg.ld_block; |
2019 | ld_block_bias_size_ = brg.typesize_bias * brg.ld_block; |
2020 | ld_block_scales_size_ = sizeof(float) * brg.ld_block; |
2021 | ld_block_zp_size_ = sizeof(int32_t) * brg.ld_block; |
2022 | ldb_tail_B_size_ = brg.typesize_B * brg.ldb_tail; |
2023 | ldb_tail_C_size_ = brg.typesize_C * brg.ldb_tail; |
2024 | ldb_tail_D_size_ = brg.typesize_D * brg.ldb_tail; |
2025 | ldb_tail_zp_size_ = sizeof(int32_t) * brg.ldb_tail; |
2026 | |
2027 | // if beta == 1 and C datatype is f32 it is better to perform addition by |
2028 | // reading tiles directly from C instead of by reading/writing by vectors |
2029 | may_load_accumulators_ = one_of(brg.alpha, 0, 1) && brg.beta == 1.f |
2030 | && brg.dt_c == brg.dt_d && !brg.is_bf32 |
2031 | && IMPLICATION( |
2032 | brg.is_f32 || brg.is_bf16, brg.dt_c == data_type::f32) |
2033 | && IMPLICATION(brg.is_int8, brg.dt_c == data_type::s32); |
2034 | need_to_apply_alpha_beta_ |
2035 | = (brg.beta != 0.f && !may_load_accumulators_) || brg.alpha != 1.f; |
2036 | const bool has_zero_points = !everyone_is(brgemm_broadcast_t::none, |
2037 | brg.zp_type_a, brg.zp_type_b, brg.zp_type_c); |
2038 | are_post_ops_applicable_ = one_of(true, brg.with_eltwise, brg.with_binary, |
2039 | brg.with_scales, brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c, |
2040 | has_zero_points); |
2041 | |
2042 | // second level blocking eligible only if we don't use store by vectors for now |
2043 | assert(IMPLICATION(are_post_ops_applicable_ || need_to_apply_alpha_beta_ |
2044 | || brg.brgattr.bd_mask_level, |
2045 | !brg.is_blocked && !brg.brgattr.var_bs)); |
2046 | assert(IMPLICATION(brg.brgattr.var_bs, !brg.is_bf32)); |
2047 | read_params(); |
2048 | prepare_bd_mask(); |
2049 | Label permute_index_table; |
2050 | if (brg.is_bf32) { |
2051 | brgemm_init_tiles(brg, (char *)(&palette_)); |
2052 | mov(reg_tmp_gpr, permute_index_table); |
2053 | vmovups(zmm_bf32_pemute, ptr[reg_tmp_gpr]); |
2054 | } |
2055 | |
2056 | reg64_t reg_mask = rax; |
2057 | |
2058 | mov(reg_mask, full_mask); |
2059 | kmovq(ld_full_mask, reg_mask); |
2060 | mov(reg_mask, tail_mask); |
2061 | kmovq(ld_tail_mask, reg_mask); |
2062 | |
2063 | mov(reg_stride_lda, lda()); |
2064 | mov(reg_stride_ldb, ldb()); |
2065 | |
2066 | bool non_postops_generate |
2067 | = !are_post_ops_applicable_ || !brg.brgattr.postops_only; |
2068 | brgemm_iteration_t bi; |
2069 | |
2070 | Label label_to_ret; |
2071 | if (are_post_ops_applicable_) { |
2072 | Label label_store_without_post_ops; |
2073 | mov(reg_do_post_ops, ptr[param1 + GET_OFF(do_post_ops)]); |
2074 | cmp(reg_do_post_ops, 0); |
2075 | jz(label_store_without_post_ops, T_NEAR); |
2076 | bi.apply_postops = true; |
2077 | top_loop(bi); |
2078 | if (non_postops_generate) jmp(label_to_ret, T_NEAR); |
2079 | transform_buf_map_A_.clear(); |
2080 | transform_buf_map_B_.clear(); |
2081 | L(label_store_without_post_ops); |
2082 | } |
2083 | if (non_postops_generate) { |
2084 | bi.apply_postops = false; |
2085 | top_loop(bi); |
2086 | } |
2087 | L(label_to_ret); |
2088 | |
2089 | add(rsp, stack_space_needed_); |
2090 | |
2091 | postamble(); |
2092 | |
2093 | if (brg.with_eltwise) postops_injector_->prepare_table(); |
2094 | |
2095 | if (brg.is_bf32) { |
2096 | align(64); |
2097 | L(permute_index_table); |
2098 | const uint16_t _idx[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, |
2099 | 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, |
2100 | 15, 31}; |
2101 | for (size_t i = 0; i < 32; ++i) |
2102 | dw(_idx[i]); |
2103 | } |
2104 | } |
2105 | |
2106 | brgemm_amx_uker_t::brgemm_amx_uker_t(const brgemm_t abrd) { |
2107 | brgemm_kernel_ = new jit_brgemm_amx_uker_base_t(abrd); |
2108 | } |
2109 | |
2110 | status_t brgemm_amx_uker_t::create_kernel() { |
2111 | return brgemm_kernel_->create_kernel(); |
2112 | } |
2113 | |
2114 | void brgemm_amx_uker_t::operator()(brgemm_kernel_params_t *params) const { |
2115 | (*brgemm_kernel_)(params); |
2116 | } |
2117 | |
2118 | brgemm_amx_uker_t::~brgemm_amx_uker_t() { |
2119 | delete brgemm_kernel_; |
2120 | } |
2121 | |
2122 | } // namespace x64 |
2123 | } // namespace cpu |
2124 | } // namespace impl |
2125 | } // namespace dnnl |
2126 | |
2127 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
2128 | |