1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/nstl.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21#include "cpu/x64/jit_generator.hpp"
22
23#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29namespace matmul {
30
31using namespace dnnl::impl::format_tag;
32using namespace dnnl::impl::utils;
33using namespace Xbyak;
34
35#define GET_OFF(x) offsetof(ctx_t, x)
36
37struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
38 public jit_generator {
39 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_impl_t)
40
41 jit_brgemm_matmul_copy_a_impl_t(const brgemm_matmul_conf_t *conf)
42 : jit_brgemm_matmul_copy_a_t(conf)
43 , jit_generator(jit_name())
44 , typesize(conf_->a_dt_sz)
45 , tr_typesize(conf_->tr_a_dt_sz)
46 , vnni_granularity(data_type_vnni_granularity(conf_->src_dt))
47 , k_step(bytes_in_zmm / nstl::max(typesize, tr_typesize)) {}
48
49 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
50 status_t create_kernel() override { return jit_generator::create_kernel(); }
51
52private:
53 using reg64_t = const Xbyak::Reg64;
54 using reg32_t = const Xbyak::Reg32;
55 using opmask_t = const Xbyak::Opmask;
56 using zmm = const Xbyak::Zmm;
57 using ymm = const Xbyak::Ymm;
58 using xmm = const Xbyak::Xmm;
59
60 enum {
61 num_comp_acc = 8,
62 k_loop_unroll = 16,
63 bytes_in_zmm = 64,
64 };
65 const int typesize;
66 const int tr_typesize;
67 const int vnni_granularity;
68 const int k_step;
69
70 dim_t src_stride = 0, tr_src_stride = 0;
71 bool do_compute_compensation = false;
72
73 opmask_t kTail_load = k7;
74 opmask_t kTail_store = k6;
75 opmask_t kTail_comp = k5;
76
77 reg64_t reg_src = rax;
78 reg64_t reg_tr_src = rbx;
79 reg64_t reg_K_start = abi_not_param1;
80
81 reg64_t reg_zp_comp_buf_ptr = rdx;
82 reg64_t reg_zp_comp_res_ptr = rsi;
83
84 reg64_t reg_M_blk = r9;
85 reg64_t reg_K_blk = r10;
86 reg64_t reg_batch = r11;
87 reg64_t reg_aux_src = r12;
88 reg64_t reg_aux_tr_src = r13;
89 reg64_t regq_tmp = r14;
90 reg64_t imm_addr64 = r15;
91 reg64_t reg_zp_ab_comp_ptr = imm_addr64;
92 reg64_t reg_zp_b_neg_val_ptr = reg_K_blk;
93
94 zmm zmm_comp_mul = zmm30;
95 zmm zmm_comp_add = zmm31;
96
97 // Allows to shift A data by 128 for s8s8 problem for AVX512 in copy
98 // routine, not in compute kernel. It's disabled for now, as it
99 // requires setting some hint to brgemm kerenel to avoid double shifting
100 const bool allow_input_shift_for_s8s8 = false;
101
102 Xbyak::Zmm get_zmm_comp_acc(int i) {
103 assert(i >= 0 && i < num_comp_acc);
104 return Xbyak::Zmm(i);
105 }
106
107 Xbyak::Zmm get_zmm_copy(int i) {
108 assert(i >= 0 && i < k_loop_unroll);
109 return Xbyak::Zmm(29 - i);
110 }
111 void reduce_compensation_across_accumulators(int num_accumulators);
112 void copy_row(int ncolumns);
113 void copy_K_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter);
114 void copy_M_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter);
115 void generate() override;
116};
117
118void jit_brgemm_matmul_copy_a_impl_t::reduce_compensation_across_accumulators(
119 int num_accumulators) {
120 int num = num_accumulators;
121 while (num > 1) {
122 for (int i = 0; i < num / 2; i++) {
123 const auto zmm_acc0 = get_zmm_comp_acc(i);
124 const auto zmm_acc1 = get_zmm_comp_acc(div_up(num, 2) + i);
125 vpaddd(zmm_acc0, zmm_acc0, zmm_acc1);
126 }
127 num = div_up(num, 2);
128 }
129}
130
131void jit_brgemm_matmul_copy_a_impl_t::copy_K_loop(
132 bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
133 MAYBE_UNUSED(is_K_tail);
134 MAYBE_UNUSED(is_first_K_iter);
135 MAYBE_UNUSED(is_last_K_iter);
136
137 const int K_blk = is_K_tail ? conf_->K % conf_->K_blk
138 : nstl::min(conf_->K, conf_->K_blk);
139 const int k_tail = K_blk % k_step;
140 const int num_k_iters = K_blk / k_step;
141 const int num_acc = utils::saturate(1, (int)num_comp_acc, num_k_iters);
142
143 if (do_compute_compensation) {
144 for (int i = 0; i < num_acc; i++) {
145 const auto zmm_acc = get_zmm_comp_acc(i);
146 vpxord(zmm_acc, zmm_acc, zmm_acc);
147 }
148 }
149
150 auto maybe_compute_compensation = [=](int k_idx, zmm zmm_copy) {
151 if (do_compute_compensation) {
152 const auto zmm_comp_acc = get_zmm_comp_acc(k_idx % num_acc);
153 if (conf_->src_dt == data_type::s8)
154 vpdpbusd(zmm_comp_acc, zmm_comp_mul, zmm_copy);
155 else
156 vpdpbusd(zmm_comp_acc, zmm_copy, zmm_comp_mul);
157 }
158 };
159
160 for (int kb = 0; kb < div_up(num_k_iters, k_loop_unroll); kb++) {
161 int k_start = 0;
162 int k_end = nstl::min(
163 (int)k_loop_unroll, num_k_iters - kb * k_loop_unroll);
164 for (int k = k_start; k < k_end; k++) {
165 const int k_idx = kb * k_loop_unroll + k;
166 const size_t offset = (size_t)k_idx * k_step * typesize;
167 const auto addr = EVEX_compress_addr(reg_src, offset);
168 if (conf_->isa == avx512_core_fp16) {
169 vcvtph2psx(get_zmm_copy(k), addr);
170 } else {
171 vmovdqu8(get_zmm_copy(k), EVEX_compress_addr(reg_src, offset));
172 }
173 maybe_compute_compensation(k_idx, get_zmm_copy(k));
174 }
175 if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) {
176 for (int k = k_start; k < k_end; k++)
177 vpaddb(get_zmm_copy(k), get_zmm_copy(k), zmm_comp_add);
178 }
179 if (conf_->is_bf32) {
180 assert(typesize != tr_typesize);
181 int k = k_start;
182 const int k_end_2 = rnd_dn(k_end, 2);
183 for (; k < k_end_2; k += 2) {
184 const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step
185 * tr_typesize;
186 auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
187
188 auto zmm_src = get_zmm_copy(k);
189 auto zmm_src_next = get_zmm_copy(k + 1);
190
191 vcvtne2ps2bf16(zmm_src, zmm_src_next, zmm_src);
192 vmovups(tr_src_addr, zmm_src);
193 }
194 if (k < k_end) {
195 const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step
196 * tr_typesize;
197 auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
198 ymm ymm_downcvt_bf16 = ymm(get_zmm_copy(k).getIdx());
199 vcvtneps2bf16(ymm_downcvt_bf16, get_zmm_copy(k));
200 vmovdqu16(tr_src_addr, ymm_downcvt_bf16);
201 }
202 } else {
203 for (int k = k_start; k < k_end; k++) {
204 const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step
205 * tr_typesize;
206 auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
207 vmovdqu8(tr_src_addr, get_zmm_copy(k));
208 }
209 }
210 }
211
212 if (k_tail > 0) {
213 const auto kmovx = [=](Opmask k, size_t q) {
214 if (conf_->is_bf32) {
215 mov(regq_tmp.cvt32(), q);
216 jit_generator::kmovw(k, regq_tmp.cvt32());
217 } else {
218 mov(regq_tmp, q);
219 jit_generator::kmovq(k, regq_tmp);
220 }
221 };
222
223 const size_t dt_step = conf_->is_bf32 || conf_->isa == avx512_core_fp16
224 ? 1
225 : typesize;
226 const size_t tail_mask_load
227 = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
228 kmovx(kTail_load, tail_mask_load);
229 const int k_tail_st = rnd_up(k_tail, vnni_granularity);
230 const size_t full_mask
231 = conf_->is_bf32 ? ((size_t)1 << 16) - 1 : 0xffffffffffffffff;
232 const size_t tail_mask_store = k_tail_st == k_step
233 ? full_mask
234 : size_t(((size_t)1 << (dt_step * k_tail_st)) - 1);
235 kmovx(kTail_store, tail_mask_store);
236
237 auto zmm_tail = get_zmm_copy(0) | kTail_load | T_z;
238 auto load_addr
239 = EVEX_compress_addr(reg_src, num_k_iters * k_step * typesize);
240 if (conf_->is_bf32)
241 vmovups(zmm_tail, load_addr);
242 else if (conf_->isa == avx512_core_fp16)
243 vcvtph2psx(zmm_tail, load_addr);
244 else
245 vmovdqu8(zmm_tail, load_addr);
246
247 maybe_compute_compensation(0, get_zmm_copy(0));
248
249 if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required)
250 vpaddb(get_zmm_copy(0), get_zmm_copy(0), zmm_comp_add);
251
252 auto tr_src_addr = EVEX_compress_addr(
253 reg_tr_src, num_k_iters * k_step * tr_typesize);
254 if (conf_->is_bf32) {
255 ymm ymm_downcvt_bf16 = ymm(get_zmm_copy(0).getIdx());
256 vcvtneps2bf16(ymm_downcvt_bf16, get_zmm_copy(0));
257 vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
258 } else if (conf_->isa == avx512_core_fp16) {
259 vmovups(tr_src_addr, get_zmm_copy(0) | kTail_store);
260 } else
261 vmovdqu8(tr_src_addr, get_zmm_copy(0) | kTail_store);
262 }
263
264 if (do_compute_compensation) {
265 reduce_compensation_across_accumulators(num_acc);
266
267 const auto addr_buf = zword[reg_zp_comp_buf_ptr];
268 if (!is_first_K_iter)
269 vpaddd(get_zmm_comp_acc(0), get_zmm_comp_acc(0), addr_buf);
270 if (!is_last_K_iter) {
271 vmovups(addr_buf, get_zmm_comp_acc(0));
272 return;
273 }
274
275 // is_last_K_iter == true: we need to reduce values within acc
276 // register, add mixed ab_compensation component if any, multiply
277 // it by negative zp_b_value and finally store the reslt
278
279 // step 1: reduce values within acc register
280 const auto ymm_red0 = ymm(get_zmm_comp_acc(0).getIdx());
281 const auto ymm_red1 = ymm(get_zmm_comp_acc(1).getIdx());
282 vextracti64x4(ymm_red1, get_zmm_comp_acc(0), 1);
283 vphaddd(ymm_red0, ymm_red0, ymm_red1);
284 vpxord(ymm_red1, ymm_red1, ymm_red1);
285 vphaddd(ymm_red0, ymm_red0, ymm_red1);
286 vphaddd(ymm_red0, ymm_red0, ymm_red1);
287 const auto xmm_red1 = xmm(ymm_red1.getIdx());
288 vextractf128(xmm_red1, ymm_red0, 1);
289 vpaddd(ymm_red0, ymm_red0, ymm_red1);
290
291 // step 2: add -K * zp_a_val as mixed ab_compensation component
292 if (conf_->src_zp_type != brgemm_broadcast_t::none) {
293 assert(conf_->src_zp_type == brgemm_broadcast_t::per_tensor);
294 reg64_t reg_zp_ab_comp_ptr = imm_addr64;
295 mov(reg_zp_ab_comp_ptr, ptr[param1 + GET_OFF(zp_ab_comp_ptr)]);
296
297 const auto addr_ab_comp = zword_b[reg_zp_ab_comp_ptr];
298 const auto zmm_res = get_zmm_comp_acc(0) | kTail_comp;
299 vpaddd(zmm_res, get_zmm_comp_acc(0), addr_ab_comp);
300 }
301
302 // step 3: multiply by zp_b_val
303 mov(reg_zp_b_neg_val_ptr, ptr[param1 + GET_OFF(zp_b_neg_value_ptr)]);
304 const auto zmm_zp_b_neg_val = get_zmm_comp_acc(1);
305 vbroadcastss(zmm_zp_b_neg_val, ptr[reg_zp_b_neg_val_ptr]);
306 vpmulld(get_zmm_comp_acc(0), get_zmm_comp_acc(0), zmm_zp_b_neg_val);
307
308 // step 4: store the final result value
309 vmovups(ptr[reg_zp_comp_res_ptr], get_zmm_comp_acc(0) | kTail_comp);
310 }
311}
312
313void jit_brgemm_matmul_copy_a_impl_t::copy_M_loop(
314 bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
315
316 if (do_compute_compensation) {
317 mov(imm_addr64, 1);
318 vpbroadcastb(zmm_comp_mul, imm_addr64.cvt8());
319 if (!(is_first_K_iter && is_last_K_iter))
320 mov(reg_zp_comp_buf_ptr,
321 ptr[param1 + GET_OFF(zp_b_compensation_buffer_ptr)]);
322
323 if (is_last_K_iter) {
324 mov(reg_zp_comp_res_ptr,
325 ptr[param1 + GET_OFF(zp_a_compensation_result_ptr)]);
326 const auto kmovw = [=](Opmask k, size_t q) {
327 mov(regq_tmp, q);
328 jit_generator::kmovw(k, imm_addr64.cvt32());
329 };
330 kmovw(kTail_comp, 1);
331 }
332 }
333
334 Label loop_M;
335 L(loop_M);
336
337 copy_K_loop(is_K_tail, is_first_K_iter, is_last_K_iter);
338
339 add(reg_src, src_stride);
340 add(reg_tr_src, tr_src_stride);
341 if (do_compute_compensation) {
342 // shift comp pointers
343 if (!(is_first_K_iter && is_last_K_iter))
344 add(reg_zp_comp_buf_ptr, sizeof(int32_t) * 16);
345 if (is_last_K_iter) add(reg_zp_comp_res_ptr, sizeof(int32_t));
346 }
347
348 dec(reg_M_blk);
349 jnz(loop_M, T_NEAR);
350}
351
352void jit_brgemm_matmul_copy_a_impl_t::generate() {
353 preamble();
354
355 src_stride = conf_->src_tag == format_tag::acbd ? conf_->copy_A_src_stride
356 : conf_->K * typesize;
357 const dim_t LDA = conf_->use_buffer_a_tail_only ? (dim_t)conf_->wei_k_blk
358 : conf_->LDA;
359 tr_src_stride = LDA * tr_typesize;
360 do_compute_compensation = conf_->has_zero_point_b;
361
362 mov(reg_src, ptr[param1 + GET_OFF(src)]);
363 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
364 mov(reg_K_blk, ptr[param1 + GET_OFF(current_K_blk)]);
365 mov(reg_M_blk, ptr[param1 + GET_OFF(current_M_blk)]);
366
367 if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) {
368 mov(imm_addr64, 128);
369 vpbroadcastb(zmm_comp_add, imm_addr64.cvt8());
370 }
371
372 auto copy_body = [=](bool is_first_K_iter, bool is_last_K_iter) {
373 Label copy_body_done;
374 // might be different from conf_->K_tail
375 const dim_t K_blk_tail
376 = conf_->K_tail > 0 ? conf_->K % conf_->K_blk : 0;
377 if (K_blk_tail > 0) {
378 Label not_K_tail;
379 cmp(reg_K_blk, K_blk_tail);
380 jne(not_K_tail, T_NEAR);
381 copy_M_loop(true, is_first_K_iter, is_last_K_iter);
382 jmp(copy_body_done, T_NEAR);
383
384 L(not_K_tail);
385 }
386
387 copy_M_loop(false, is_first_K_iter, is_last_K_iter);
388 L(copy_body_done);
389 };
390
391 Label done;
392 if (do_compute_compensation) {
393 assert(conf_->wei_zp_type == brgemm_broadcast_t::per_tensor);
394
395 mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
396 const auto last_K_threshold
397 = rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk;
398 Label not_first, not_first_not_last;
399 cmp(reg_K_start, 0);
400 jne(not_first, T_NEAR);
401 {
402 // first K iteration
403 Label first_not_last;
404 cmp(reg_K_start, last_K_threshold);
405 jl(first_not_last, T_NEAR);
406 copy_body(true, true);
407 jmp(done, T_NEAR);
408
409 L(first_not_last);
410 copy_body(true, false);
411 jmp(done, T_NEAR);
412 }
413
414 L(not_first);
415 cmp(reg_K_start, last_K_threshold);
416 jl(not_first_not_last, T_NEAR);
417
418 copy_body(false, true);
419 jmp(done, T_NEAR);
420 L(not_first_not_last);
421 }
422 copy_body(false, false);
423 L(done);
424
425 postamble();
426}
427
428struct jit_brgemm_matmul_copy_a_transposed_impl_t
429 : public jit_brgemm_matmul_copy_a_t,
430 public jit_generator {
431 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_transposed_impl_t)
432
433 jit_brgemm_matmul_copy_a_transposed_impl_t(const brgemm_matmul_conf_t *conf)
434 : jit_brgemm_matmul_copy_a_t(conf)
435 , jit_generator(jit_name())
436 , typesize(conf_->a_dt_sz)
437 , tr_typesize(conf_->tr_a_dt_sz)
438 , src_stride(conf_->src_tag == format_tag::adbc
439 ? conf_->copy_A_src_stride
440 : conf_->M * typesize)
441 , dst_stride(conf_->LDA * tr_typesize)
442 , m_loop_src_shift(columns_step * typesize)
443 , m_loop_dst_shift(columns_step * dst_stride)
444 , k_loop_src_shift(rows_step * src_stride)
445 , k_loop_dst_shift(rows_step * tr_typesize)
446 , is_f32(everyone_is(data_type::f32, conf_->src_dt, conf_->wei_dt))
447 , is_bf32(conf_->is_bf32) {}
448
449 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
450 status_t create_kernel() override { return jit_generator::create_kernel(); }
451
452private:
453 using reg64_t = const Xbyak::Reg64;
454 using reg32_t = const Xbyak::Reg32;
455 using opmask_t = const Xbyak::Opmask;
456
457 const size_t typesize;
458 const size_t tr_typesize;
459 const int rows_step = 16;
460 const int columns_step = rows_step;
461 const dim_t src_stride, dst_stride;
462 const dim_t m_loop_src_shift;
463 const dim_t m_loop_dst_shift;
464 const dim_t k_loop_src_shift;
465 const dim_t k_loop_dst_shift;
466 const bool is_f32;
467 const bool is_bf32;
468
469 opmask_t kFFFF = k1;
470 opmask_t k3333 = k1;
471 opmask_t k5555 = k2;
472 opmask_t kAAAA = k3;
473 opmask_t kAA = k4;
474 opmask_t kCCCC = k4;
475 opmask_t k55 = k5;
476 opmask_t k0F0F = k5;
477 opmask_t kCC = k6;
478 opmask_t kF0F0 = k6;
479 opmask_t k33 = k7;
480 opmask_t kTail = is_f32 ? k7 : k1;
481
482 reg32_t regw_tmp = r15d;
483 reg64_t reg_k_src = r14;
484 reg64_t reg_k_dst = r13;
485 reg64_t reg_m_src = r12;
486 reg64_t reg_m_dst = r11;
487 reg64_t reg_loop_k = rax;
488 reg64_t reg_loop_m = rbx;
489 reg64_t imm_addr64 = rdx;
490
491 Xbyak::Zmm vidx1 = zmm31;
492 Xbyak::Zmm vidx2 = zmm30;
493 Xbyak::Zmm vidx3 = zmm29;
494 Xbyak::Zmm vidx4 = zmm28;
495 Xbyak::Zmm vidx5 = zmm27;
496 Xbyak::Zmm zmm_tmp = zmm26;
497
498 void transpose_f32(reg64_t dst, reg64_t src, int nrows, int ncolumns);
499 void transpose_bf16(reg64_t dst, reg64_t src, int nrows, int ncolumns);
500 void deploy_transpose(reg64_t dst, reg64_t src, int nrows, int ncolumns);
501 void generate() override;
502};
503
504void jit_brgemm_matmul_copy_a_transposed_impl_t::transpose_bf16(
505 reg64_t dst, reg64_t src, int nrows, int ncolumns) {
506 assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0
507 && ncolumns <= columns_step);
508 if (!nrows) return;
509
510 auto src_zmm = [=](int i) { return Zmm(i); };
511
512 auto src_ymm = [=](int i) {
513 assert(i >= 0 && i < 16);
514 return Ymm(i);
515 };
516
517 auto kmovx = [=](Opmask k, unsigned w, bool use_word_sz = false) {
518 mov(regw_tmp, w);
519 if (use_word_sz)
520 jit_generator::kmovw(k, regw_tmp);
521 else
522 jit_generator::kmovd(k, regw_tmp);
523 };
524
525 auto store = [=](Zmm r, int i) {
526 auto addr = EVEX_compress_addr(dst, i * dst_stride);
527 vmovdqu16(addr, r | kTail);
528 };
529
530 const int load_mask
531 = ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff;
532 kmovx(kFFFF, load_mask, is_bf32);
533
534 for (int i = 0; i < nrows / 2; i++) {
535 auto idx0 = 2 * i;
536 auto idx1 = 2 * i + 1;
537 auto zmm_src0 = src_zmm(idx0);
538 auto zmm_src1 = src_zmm(idx1);
539 auto src_addr_0 = EVEX_compress_addr(src, idx0 * src_stride);
540 auto src_addr_1 = EVEX_compress_addr(src, idx1 * src_stride);
541 if (is_bf32) {
542 vmovups(zmm_src0 | kFFFF | T_z, src_addr_0);
543 vmovups(zmm_src1 | kFFFF | T_z, src_addr_1);
544 vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0);
545 } else {
546 auto src1 = src_ymm(idx1);
547 vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr_0);
548 vmovdqu16(zmm_src1 | kFFFF | T_z, src_addr_1);
549 vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
550 }
551 vpermw(zmm_src0, vidx5, zmm_src0);
552 }
553
554 // for odd numbers we need to mix row with zeroes
555 if (nrows % 2) {
556 int i = nrows / 2;
557 auto zmm_src0 = src_zmm(2 * i);
558 auto src_addr = EVEX_compress_addr(src, 2 * i * src_stride);
559 if (is_bf32) {
560 vmovups(zmm_src0 | kFFFF | T_z, src_addr);
561 vcvtneps2bf16(Ymm(zmm_src0.getIdx()), zmm_src0);
562 } else
563 vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr);
564 vpermw(zmm_src0, vidx5, zmm_src0);
565 }
566
567 for (int i = rnd_up(nrows, 2); i < rows_step; i += 2) {
568 vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
569 }
570
571 // swap 1
572 for (int i = 0; i < 4; i++) {
573 auto zmm0 = src_zmm(4 * i);
574 auto zmm1 = src_zmm(4 * i + 2);
575 auto tmp0 = src_zmm(4 * i + 1);
576 auto tmp1 = src_zmm(4 * i + 3);
577
578 vmovups(tmp0, zmm0);
579 vmovups(tmp1, zmm1);
580
581 vpermps(tmp0 | kAAAA, vidx3, zmm1);
582 vpermps(tmp1 | k5555, vidx3, zmm0);
583 }
584 // swap 2
585 int base_idx;
586 base_idx = 0;
587 for (int i = 0; i < 2; i++) {
588 auto zmm0 = src_zmm(base_idx + 2 * i + 1);
589 auto zmm1 = src_zmm(base_idx + 2 * i + 5);
590
591 auto tmp0 = src_zmm(base_idx + 2 * i);
592 auto tmp1 = src_zmm(base_idx + 2 * i + 4);
593
594 vmovupd(tmp0, zmm0);
595 vmovupd(tmp1, zmm1);
596
597 vpermpd(tmp0 | kAA, vidx2, zmm1);
598 vpermpd(tmp1 | k55, vidx2, zmm0);
599 }
600 base_idx = 8;
601 for (int i = 0; i < 2; i++) {
602 auto zmm0 = src_zmm(base_idx + 2 * i + 1);
603 auto zmm1 = src_zmm(base_idx + 2 * i + 5);
604
605 auto tmp0 = src_zmm(base_idx + 2 * i);
606 auto tmp1 = src_zmm(base_idx + 2 * i + 4);
607
608 vmovupd(tmp0, zmm0);
609 vmovupd(tmp1, zmm1);
610
611 vpermpd(tmp0 | kAA, vidx2, zmm1);
612 vpermpd(tmp1 | k55, vidx2, zmm0);
613 }
614
615 // swap 3
616 for (int i = 0; i < 4; i++) {
617 auto zmm0 = src_zmm(2 * i);
618 auto zmm1 = src_zmm(2 * i + 8);
619
620 auto tmp0 = src_zmm(2 * i + 1);
621 auto tmp1 = src_zmm(2 * i + 9);
622
623 vmovupd(tmp0, zmm0);
624 vmovupd(tmp1, zmm1);
625
626 vpermpd(tmp0 | kCC, vidx1, zmm1);
627 vpermpd(tmp1 | k33, vidx1, zmm0);
628 }
629
630 // all stores
631 for (int i = 0; i < 8; i++)
632 vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);
633
634 auto get_vec_idx = [=](int col_idx) {
635 assert(col_idx < columns_step && col_idx >= 0);
636 const int blk_sz = 4;
637 const int blk_idx = col_idx / blk_sz;
638 const int idx_within_blk = col_idx % blk_sz;
639
640 // 0 1 2 3 -> 0 2 1 3
641 const int mapped_blk_idx = 2 * blk_idx - (blk_idx / 2) * 3;
642 // 0 1 2 3 -> 1 0 3 2
643 const int mapped_idx_within_blk
644 = idx_within_blk + 1 - 2 * (idx_within_blk % 2);
645 return blk_sz * mapped_blk_idx + mapped_idx_within_blk;
646 };
647 const int columns_to_store = rnd_up(nrows, 2);
648 const int store_mask = columns_to_store < rows_step
649 ? (1 << columns_to_store) - 1
650 : 0xffff;
651 kmovx(kTail, store_mask);
652
653 for (int col_idx = 0; col_idx < ncolumns; col_idx++)
654 store(src_zmm(get_vec_idx(col_idx)), col_idx);
655}
656
657void jit_brgemm_matmul_copy_a_transposed_impl_t::transpose_f32(
658 reg64_t dst, reg64_t src, int nrows, int ncolumns) {
659 assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0
660 && ncolumns <= columns_step);
661 if (!nrows) return;
662
663 auto kmovw = [=](Opmask k, size_t q) {
664 mov(regw_tmp, q);
665 jit_generator::kmovw(k, regw_tmp);
666 };
667
668 const int load_mask
669 = ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff;
670 kmovw(kTail, load_mask);
671
672 auto src_zmm = [=](int i) {
673 assert(i >= 0 && i < 16);
674 return Zmm(i);
675 };
676
677 auto tmp_zmm = [=](int i) {
678 assert(i >= 0 && i < 16);
679 return Zmm(16 + i);
680 };
681
682 auto load = [=](int i) {
683 const auto addr = EVEX_compress_addr(src, i * src_stride);
684 if (i < nrows)
685 if (conf_->isa == avx512_core_fp16)
686 vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
687 else
688 vmovups(src_zmm(i) | kTail | T_z, addr);
689 else
690 vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
691 };
692
693 auto store = [=](Zmm r, int i) {
694 auto addr = EVEX_compress_addr(dst, i * dst_stride);
695 vmovups(addr, r | kTail);
696 };
697
698 auto transpose16x8 = [=](int base_idx) {
699 assert(base_idx == 0 || base_idx == 8);
700
701 // swap 1
702 for (int i = 0; i < 4; i++) {
703 int src_idx0 = base_idx + i * 2;
704 int src_idx1 = src_idx0 + 1;
705
706 int next_src_idx0 = src_idx0 + 2;
707 int next_src_idx1 = src_idx1 + 2;
708 bool load_next = base_idx == 0 || i < 3;
709
710 if (base_idx == 0 && i == 0) {
711 load(src_idx0);
712 load(src_idx1);
713 }
714
715 auto tmp0 = tmp_zmm(src_idx0);
716 auto tmp1 = tmp_zmm(src_idx1);
717 auto src0 = src_zmm(src_idx0);
718 auto src1 = src_zmm(src_idx1);
719
720 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
721 valignd(tmp0, src0, src0, 0x1);
722
723 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
724 valignd(tmp1, src1, src1, 0xf);
725
726 vmovaps(src0 | kAAAA, tmp1);
727 vmovaps(src1 | k5555, tmp0);
728 }
729
730 // swap 2
731 for (int i = 0; i < 4; i++) {
732 int select_half = (i < 2) ? 0 : 2;
733 int src_idx0 = base_idx + i + select_half + 0;
734 int src_idx2 = src_idx0 + 2;
735
736 auto tmp0 = tmp_zmm(src_idx0);
737 auto tmp1 = tmp_zmm(src_idx2);
738 auto src0 = src_zmm(src_idx0);
739 auto src2 = src_zmm(src_idx2);
740
741 valignd(tmp0, src0, src0, 0x2);
742 valignd(tmp1, src2, src2, 0xe);
743 vmovaps(src2 | k3333, tmp0);
744 vmovaps(src0 | kCCCC, tmp1);
745 }
746
747 // swap 4
748 for (int i = 0; i < 4; i++) {
749 int src_idx0 = base_idx + i;
750 int src_idx4 = src_idx0 + 4;
751
752 auto tmp0 = tmp_zmm(src_idx0);
753 auto src0 = src_zmm(src_idx0);
754 auto src4 = src_zmm(src_idx4);
755
756 vmovaps(tmp0, src0);
757 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
758 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
759 }
760 };
761
762 auto fixup16x16 = [=]() {
763 const int store_mask = nrows < rows_step ? (1 << nrows) - 1 : 0xffff;
764 kmovw(kTail, store_mask);
765
766 // swap 8
767 for (int i = 0; i < nstl::min(8, ncolumns); i++) {
768 auto tmp = tmp_zmm(i);
769 auto src0 = src_zmm(i);
770 auto src8 = src_zmm(8 + i);
771 vshuff64x2(tmp, src0, src8, 0x44);
772 store(tmp, i);
773 }
774
775 for (int i = 0; i < nstl::max(0, ncolumns - 8); i++) {
776 auto tmp = tmp_zmm(8 + i);
777 auto src0 = src_zmm(i);
778 auto src8 = src_zmm(8 + i);
779 vshuff64x2(tmp, src0, src8, 0xee);
780 store(tmp, 8 + i);
781 }
782 };
783
784 transpose16x8(0);
785 transpose16x8(8);
786 fixup16x16();
787}
788
789void jit_brgemm_matmul_copy_a_transposed_impl_t::deploy_transpose(
790 reg64_t dst, reg64_t src, int nrows, int ncolumns) {
791 if (is_f32 || conf_->isa == avx512_core_fp16)
792 transpose_f32(dst, src, nrows, ncolumns);
793 else
794 transpose_bf16(dst, src, nrows, ncolumns);
795}
796
797void jit_brgemm_matmul_copy_a_transposed_impl_t::generate() {
798
799 // only bf16, f16 and f32 supported for now
800 if (!one_of(conf_->src_dt, data_type::bf16, data_type::f32, data_type::f16))
801 return;
802 preamble();
803
804 alignas(64) static constexpr const int64_t idx1[8]
805 = {2, 3, 0, 1, 6, 7, 4, 5};
806 alignas(64) static constexpr const int64_t idx2[8]
807 = {1, 0, 3, 2, 5, 4, 7, 6};
808 alignas(64) static constexpr const int32_t idx3[16]
809 = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
810 alignas(64) static constexpr const int32_t idx4[16]
811 = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
812 alignas(64) static constexpr const uint16_t idx5[32]
813 = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
814 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};
815
816 const int k_block_tail = conf_->K_blk % rows_step;
817 const int last_k_block_tail = (conf_->K % conf_->K_blk) % rows_step;
818 const int m_block_tail = conf_->M_blk % columns_step;
819 const int last_m_block_tail = conf_->M_tail % columns_step;
820
821 auto kmovw = [=](Opmask k, unsigned w) {
822 mov(regw_tmp, w);
823 jit_generator::kmovw(k, regw_tmp);
824 };
825
826 if (is_f32) {
827 kmovw(k3333, 0x3333); // 0011001100110011
828 kmovw(k5555, 0x5555); // 0101010101010101
829 kmovw(kAAAA, 0xaaaa); // 1010101010101010
830 kmovw(kCCCC, 0xcccc); // 1100110011001100
831 kmovw(k0F0F, 0x0f0f); // 0000111100001111
832 kmovw(kF0F0, 0xf0f0); // 1111000011110000
833 } else {
834 kmovw(kFFFF, 0xffff);
835 kmovw(k5555, 0x5555);
836 kmovw(kAAAA, 0xaaaa);
837 kmovw(kAA, 0xaa);
838 kmovw(k55, 0x55);
839 kmovw(kCC, 0xcc);
840 kmovw(k33, 0x33);
841 }
842
843 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
844 mov(imm_addr64, reinterpret_cast<size_t>(addr));
845 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
846 };
847
848 auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
849 mov(imm_addr64, reinterpret_cast<size_t>(addr));
850 jit_generator::vmovdqa32(z, ptr[imm_addr64]);
851 };
852
853 if (!is_f32) {
854 vmovdqa64(vidx1, idx1);
855 vmovdqa64(vidx2, idx2);
856 vmovdqa32(vidx3, idx3);
857 vmovdqa32(vidx4, idx4);
858 vmovdqa32(vidx5, (const int32_t *)idx5);
859 }
860
861 auto compute_m_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
862 int nrows) {
863 mov(reg_loop_m, ptr[param1 + GET_OFF(current_M_blk)]);
864 mov(reg_m_src, reg_base);
865 mov(reg_m_dst, reg_tr_base);
866
867 Label m_loop_tail_or_done, m_loop, compute_m_loop_done;
868 cmp(reg_loop_m, columns_step);
869 jl(m_loop_tail_or_done, T_NEAR);
870
871 L(m_loop);
872 {
873 deploy_transpose(reg_m_dst, reg_m_src, nrows, columns_step);
874 add(reg_m_src, m_loop_src_shift);
875 add(reg_m_dst, m_loop_dst_shift);
876 }
877 sub(reg_loop_m, columns_step);
878 cmp(reg_loop_m, columns_step);
879 jge(m_loop, T_NEAR);
880
881 if (m_block_tail > 0 || last_m_block_tail > 0)
882 jz(compute_m_loop_done, T_NEAR);
883
884 L(m_loop_tail_or_done);
885
886 if (m_block_tail > 0) {
887 Label m_block_tail_done;
888 cmp(reg_loop_m, m_block_tail);
889 jne(m_block_tail_done, T_NEAR);
890
891 deploy_transpose(reg_m_dst, reg_m_src, nrows, m_block_tail);
892 jmp(compute_m_loop_done, T_NEAR);
893
894 L(m_block_tail_done);
895 }
896 if (last_m_block_tail > 0 && last_m_block_tail != m_block_tail) {
897 Label last_m_block_tail_done;
898 cmp(reg_loop_m, last_m_block_tail);
899 jne(last_m_block_tail_done, T_NEAR);
900
901 deploy_transpose(reg_m_dst, reg_m_src, nrows, last_m_block_tail);
902 jmp(compute_m_loop_done, T_NEAR);
903
904 L(last_m_block_tail_done);
905 }
906
907 L(compute_m_loop_done);
908 };
909
910 auto compute_k_loop = [&]() {
911 mov(reg_k_src, ptr[param1 + GET_OFF(src)]);
912 mov(reg_k_dst, ptr[param1 + GET_OFF(tr_src)]);
913 mov(reg_loop_k, ptr[param1 + GET_OFF(current_K_blk)]);
914
915 Label k_tail_or_done, k_loop, compute_k_loop_done;
916 cmp(reg_loop_k, rows_step);
917 jl(k_tail_or_done, T_NEAR);
918
919 L(k_loop);
920 {
921 compute_m_loop(reg_k_src, reg_k_dst, rows_step);
922 add(reg_k_src, k_loop_src_shift);
923 add(reg_k_dst, k_loop_dst_shift);
924 }
925 sub(reg_loop_k, rows_step);
926 cmp(reg_loop_k, rows_step);
927 jge(k_loop, T_NEAR);
928
929 if (k_block_tail > 0 || last_k_block_tail > 0)
930 jz(compute_k_loop_done, T_NEAR);
931
932 L(k_tail_or_done);
933
934 if (k_block_tail > 0) {
935 Label k_block_tail_done;
936 cmp(reg_loop_k, k_block_tail);
937 jne(k_block_tail_done, T_NEAR);
938
939 compute_m_loop(reg_k_src, reg_k_dst, k_block_tail);
940 jmp(compute_k_loop_done, T_NEAR);
941
942 L(k_block_tail_done);
943 }
944 if (last_k_block_tail > 0 && last_k_block_tail != k_block_tail) {
945 Label last_k_block_tail_done;
946 cmp(reg_loop_k, last_k_block_tail);
947 jne(last_k_block_tail_done, T_NEAR);
948
949 compute_m_loop(reg_k_src, reg_k_dst, last_k_block_tail);
950 jmp(compute_k_loop_done, T_NEAR);
951
952 L(last_k_block_tail_done);
953 }
954
955 L(compute_k_loop_done);
956 };
957
958 compute_k_loop();
959
960 postamble();
961}
962
963struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t,
964 public jit_generator {
965 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_int8_t)
966
967 jit_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf)
968 : jit_brgemm_matmul_copy_b_t(conf), jit_generator(jit_name()) {}
969
970 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
971 status_t create_kernel() override { return jit_generator::create_kernel(); }
972
973private:
974 using reg64_t = const Xbyak::Reg64;
975 using reg32_t = const Xbyak::Reg32;
976 using opmask_t = const Xbyak::Opmask;
977 using zmm = const Xbyak::Zmm;
978 using ymm = const Xbyak::Ymm;
979
980 enum { typesize = sizeof(int8_t), k_blk_step = 4, n_blk_step = 64 };
981 dim_t src_stride = 0, tr_src_stride = 0;
982 bool is_amx = false;
983 bool do_compute_compensation = false;
984
985 opmask_t kTail = k7;
986
987 reg64_t reg_src = rax;
988 reg64_t reg_tr_src = rbx;
989 reg64_t reg_comp_ptr = rdx;
990 reg64_t reg_zp_comp_ptr = r11;
991 reg64_t reg_zp_a_neg_val_ptr = r12;
992
993 reg64_t reg_K_iters = r8;
994 reg64_t reg_N_blk = r9;
995 reg64_t reg_K_start = r10;
996 reg64_t regq_tmp = r14;
997 reg64_t imm_addr64 = r15;
998
999 zmm vreg_idx_lo_256 = zmm26;
1000 zmm vreg_idx_hi_256 = zmm27;
1001 zmm vreg_idx_lo_128 = zmm28;
1002 zmm vreg_idx_hi_128 = zmm29;
1003 zmm zmm_comp_mul = zmm30;
1004 zmm zmm_zero = zmm31;
1005
1006 Xbyak::Zmm get_comp_acc(int i) { return Xbyak::Zmm(25 - i); }
1007 Xbyak::Zmm get_zmm_zp_comp_res(int i) { return get_comp_acc(i); }
1008 Xbyak::Zmm get_zmm_oscale_comp_res(int i) { return Xbyak::Zmm(i); }
1009 void copy_4x64_vnni_avx512_core(int nrows, int ncolumns);
1010 void copy_4x64_vnni_amx(int nrows, int ncolumns);
1011 void copy_4x64_vnni(int nrows, int ncolumns);
1012 void generate() override;
1013};
1014
1015void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni(int nrows, int ncolumns) {
1016 if (is_amx)
1017 copy_4x64_vnni_amx(nrows, ncolumns);
1018 else
1019 copy_4x64_vnni_avx512_core(nrows, ncolumns);
1020}
1021
1022void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni_amx(
1023 int nrows, int ncolumns) {
1024 auto kmovq = [=](Opmask k, size_t q) {
1025 mov(regq_tmp, q);
1026 jit_generator::kmovq(k, regq_tmp);
1027 };
1028
1029 const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
1030 if (ncolumns < n_blk_step) kmovq(kTail, tail_mask);
1031
1032 const int blk_sz = 6;
1033 const int max_unroll = (do_compute_compensation ? 21 : 25) / blk_sz;
1034 auto get_zmm = [=](int blk, int idx) {
1035 assert(idx >= 0 && idx < blk_sz && blk >= 0);
1036 auto reg_idx = blk_sz * blk + idx;
1037 assert(reg_idx >= 0 && reg_idx < 32);
1038 return zmm(reg_idx);
1039 };
1040
1041 auto load = [=](int blk, int i) {
1042 auto src_reg = get_zmm(blk, i % k_blk_step);
1043 auto src_load = ncolumns < n_blk_step ? src_reg | kTail | T_z : src_reg;
1044 vmovdqu8(src_load, EVEX_compress_addr(reg_src, i * src_stride));
1045 };
1046
1047 for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step); kb++)
1048 for (int k = 0;
1049 k < nstl::min(max_unroll,
1050 div_up(nrows - kb * max_unroll * k_blk_step, k_blk_step));
1051 k++) {
1052 const int row_start = (kb * max_unroll + k) * k_blk_step;
1053 const int row_end = nstl::min(row_start + k_blk_step, nrows);
1054
1055 for (int i = row_start; i < row_end; i++)
1056 load(k, i);
1057 if (row_end == nrows && nrows % k_blk_step > 0) {
1058 for (int i = nrows; i < rnd_up(nrows, k_blk_step); i++) {
1059 auto src_reg = get_zmm(k, i % k_blk_step);
1060 vpxord(src_reg, src_reg, src_reg);
1061 }
1062 }
1063
1064 vmovups(get_zmm(k, 4), vreg_idx_lo_256);
1065 vpermi2b(get_zmm(k, 4), get_zmm(k, 0), get_zmm(k, 2));
1066 vmovups(get_zmm(k, 5), vreg_idx_hi_256);
1067 vpermi2b(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 2));
1068 vmovups(get_zmm(k, 0), vreg_idx_lo_256);
1069 vpermi2b(get_zmm(k, 0), get_zmm(k, 1), get_zmm(k, 3));
1070 vmovups(get_zmm(k, 2), vreg_idx_hi_256);
1071 vpermi2b(get_zmm(k, 2), get_zmm(k, 1), get_zmm(k, 3));
1072
1073 vmovups(get_zmm(k, 1), vreg_idx_lo_128);
1074 vpermi2b(get_zmm(k, 1), get_zmm(k, 4), get_zmm(k, 0));
1075 dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride;
1076 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base), get_zmm(k, 1));
1077 if (do_compute_compensation)
1078 vpdpbusd(get_comp_acc(0), zmm_comp_mul, get_zmm(k, 1));
1079
1080 if (ncolumns > 16) {
1081 vmovups(get_zmm(k, 3), vreg_idx_hi_128);
1082 vpermi2b(get_zmm(k, 3), get_zmm(k, 4), get_zmm(k, 0));
1083 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
1084 get_zmm(k, 3));
1085 if (do_compute_compensation)
1086 vpdpbusd(get_comp_acc(1), zmm_comp_mul, get_zmm(k, 3));
1087 } else if (conf_->wei_n_blk > 16) {
1088 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
1089 zmm_zero);
1090 }
1091
1092 if (ncolumns > 32) {
1093 vmovups(get_zmm(k, 4), vreg_idx_lo_128);
1094 vpermi2b(get_zmm(k, 4), get_zmm(k, 5), get_zmm(k, 2));
1095 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
1096 get_zmm(k, 4));
1097 if (do_compute_compensation)
1098 vpdpbusd(get_comp_acc(2), zmm_comp_mul, get_zmm(k, 4));
1099 } else if (conf_->wei_n_blk > 32) {
1100 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
1101 zmm_zero);
1102 }
1103
1104 if (ncolumns > 48) {
1105 vmovups(get_zmm(k, 0), vreg_idx_hi_128);
1106 vpermi2b(get_zmm(k, 0), get_zmm(k, 5), get_zmm(k, 2));
1107 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
1108 get_zmm(k, 0));
1109 if (do_compute_compensation)
1110 vpdpbusd(get_comp_acc(3), zmm_comp_mul, get_zmm(k, 0));
1111 } else if (conf_->wei_n_blk > 48) {
1112 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
1113 zmm_zero);
1114 }
1115 }
1116}
1117
1118void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni_avx512_core(
1119 int nrows, int ncolumns) {
1120 auto kmovq = [=](Opmask k, size_t q) {
1121 mov(regq_tmp, q);
1122 jit_generator::kmovq(k, regq_tmp);
1123 };
1124
1125 const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
1126 if (ncolumns < n_blk_step) kmovq(kTail, tail_mask);
1127
1128 const int blk_sz = 6;
1129 const int max_unroll = (do_compute_compensation ? 21 : 25) / blk_sz;
1130 auto get_zmm = [=](int blk, int idx) {
1131 assert(idx >= 0 && idx < blk_sz && blk >= 0);
1132 auto reg_idx = blk_sz * blk + idx;
1133 assert(reg_idx >= 0 && reg_idx < 32);
1134 return zmm(reg_idx);
1135 };
1136 auto load = [=](int blk, int i) {
1137 auto src_reg = get_zmm(blk, i % k_blk_step);
1138 auto src_load = ncolumns < n_blk_step ? src_reg | kTail | T_z : src_reg;
1139 vmovdqu8(src_load, EVEX_compress_addr(reg_src, i * src_stride));
1140 };
1141
1142 for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step); kb++)
1143 for (int k = 0;
1144 k < nstl::min(max_unroll,
1145 div_up(nrows - kb * max_unroll * k_blk_step, k_blk_step));
1146 k++) {
1147 const int row_start = (kb * max_unroll + k) * k_blk_step;
1148 const int row_end = nstl::min(row_start + k_blk_step, nrows);
1149
1150 for (int i = row_start; i < row_end; i++)
1151 load(k, i);
1152 if (row_end == nrows && nrows % k_blk_step > 0) {
1153 for (int i = nrows; i < rnd_up(nrows, k_blk_step); i++) {
1154 auto src_reg = get_zmm(k, i % k_blk_step);
1155 vpxord(src_reg, src_reg, src_reg);
1156 }
1157 }
1158
1159 vpunpcklbw(get_zmm(k, 4), get_zmm(k, 0), get_zmm(k, 1));
1160 vpunpckhbw(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 1));
1161 vpunpcklbw(get_zmm(k, 0), get_zmm(k, 2), get_zmm(k, 3));
1162 vpunpckhbw(get_zmm(k, 1), get_zmm(k, 2), get_zmm(k, 3));
1163
1164 vpunpcklwd(get_zmm(k, 2), get_zmm(k, 4), get_zmm(k, 0));
1165 vpunpckhwd(get_zmm(k, 3), get_zmm(k, 4), get_zmm(k, 0));
1166 vpunpcklwd(get_zmm(k, 4), get_zmm(k, 5), get_zmm(k, 1));
1167 vpunpckhwd(get_zmm(k, 5), get_zmm(k, 5), get_zmm(k, 1));
1168
1169 vmovups(get_zmm(k, 0), vreg_idx_lo_256);
1170 vpermi2q(get_zmm(k, 0), get_zmm(k, 2), get_zmm(k, 4));
1171 vmovups(get_zmm(k, 1), vreg_idx_hi_256);
1172 vpermi2q(get_zmm(k, 1), get_zmm(k, 2), get_zmm(k, 4));
1173 vmovups(get_zmm(k, 2), vreg_idx_lo_256);
1174 vpermi2q(get_zmm(k, 2), get_zmm(k, 3), get_zmm(k, 5));
1175 vmovups(get_zmm(k, 4), vreg_idx_hi_256);
1176 vpermi2q(get_zmm(k, 4), get_zmm(k, 3), get_zmm(k, 5));
1177
1178 vmovups(get_zmm(k, 3), vreg_idx_lo_128);
1179 vpermi2q(get_zmm(k, 3), get_zmm(k, 0), get_zmm(k, 2));
1180 dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride;
1181 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base), get_zmm(k, 3));
1182 if (do_compute_compensation)
1183 vpdpbusd(get_comp_acc(0), zmm_comp_mul, get_zmm(k, 3));
1184
1185 if (ncolumns > 16) {
1186 vmovups(get_zmm(k, 5), vreg_idx_hi_128);
1187 vpermi2q(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 2));
1188 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
1189 get_zmm(k, 5));
1190 if (do_compute_compensation)
1191 vpdpbusd(get_comp_acc(1), zmm_comp_mul, get_zmm(k, 5));
1192 } else if (conf_->wei_n_blk > 16) {
1193 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
1194 zmm_zero);
1195 }
1196
1197 if (ncolumns > 32) {
1198 vmovups(get_zmm(k, 0), vreg_idx_lo_128);
1199 vpermi2q(get_zmm(k, 0), get_zmm(k, 1), get_zmm(k, 4));
1200 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
1201 get_zmm(k, 0));
1202 if (do_compute_compensation)
1203 vpdpbusd(get_comp_acc(2), zmm_comp_mul, get_zmm(k, 0));
1204 } else if (conf_->wei_n_blk > 32) {
1205 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
1206 zmm_zero);
1207 }
1208
1209 if (ncolumns > 48) {
1210 vmovups(get_zmm(k, 2), vreg_idx_hi_128);
1211 vpermi2q(get_zmm(k, 2), get_zmm(k, 1), get_zmm(k, 4));
1212 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
1213 get_zmm(k, 2));
1214 if (do_compute_compensation)
1215 vpdpbusd(get_comp_acc(3), zmm_comp_mul, get_zmm(k, 2));
1216 } else if (conf_->wei_n_blk > 48) {
1217 vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
1218 zmm_zero);
1219 }
1220 }
1221}
1222
1223void jit_brgemm_matmul_copy_b_int8_t::generate() {
1224 preamble();
1225 vpxord(zmm_zero, zmm_zero, zmm_zero);
1226 src_stride = (conf_->wei_tag == format_tag::acbd ? conf_->copy_B_wei_stride
1227 : conf_->N * typesize);
1228 tr_src_stride = conf_->LDB * k_blk_step * typesize;
1229 is_amx = mayiuse(avx512_core_amx);
1230 do_compute_compensation
1231 = conf_->s8s8_compensation_required || conf_->has_zero_point_a;
1232
1233 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1234 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1235 mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
1236 mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
1237
1238 auto vmovdqa64 = [=](Zmm z, const void *addr) {
1239 mov(imm_addr64, reinterpret_cast<size_t>(addr));
1240 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
1241 };
1242
1243 alignas(64) static constexpr const int64_t idx_lo_256[8]
1244 = {0, 1, 2, 3, 8, 9, 10, 11};
1245 alignas(64) static constexpr const int64_t idx_hi_256[8]
1246 = {4, 5, 6, 7, 12, 13, 14, 15};
1247
1248 alignas(64) static constexpr const int64_t idx_lo_128[8]
1249 = {0, 1, 8, 9, 4, 5, 12, 13};
1250 alignas(64) static constexpr const int64_t idx_hi_128[8]
1251 = {2, 3, 10, 11, 6, 7, 14, 15};
1252 alignas(64) static constexpr const uint8_t idx_lo_16[64]
1253 = {0, 1, 64, 65, 4, 5, 68, 69, 2, 3, 66, 67, 6, 7, 70, 71, 8, 9, 72,
1254 73, 12, 13, 76, 77, 10, 11, 74, 75, 14, 15, 78, 79, 16, 17,
1255 80, 81, 20, 21, 84, 85, 18, 19, 82, 83, 22, 23, 86, 87, 24,
1256 25, 88, 89, 28, 29, 92, 93, 26, 27, 90, 91, 30, 31, 94, 95};
1257
1258 alignas(64) static constexpr const uint8_t idx_hi_16[64] = {32, 33, 96, 97,
1259 36, 37, 100, 101, 34, 35, 98, 99, 38, 39, 102, 103, 40, 41, 104,
1260 105, 44, 45, 108, 109, 42, 43, 106, 107, 46, 47, 110, 111, 48, 49,
1261 112, 113, 52, 53, 116, 117, 50, 51, 114, 115, 54, 55, 118, 119, 56,
1262 57, 120, 121, 60, 61, 124, 125, 58, 59, 122, 123, 62, 63, 126, 127};
1263
1264 alignas(64) static constexpr const uint8_t idx_lo_8[64]
1265 = {0, 64, 2, 66, 1, 65, 3, 67, 8, 72, 10, 74, 9, 73, 11, 75, 4, 68,
1266 6, 70, 5, 69, 7, 71, 12, 76, 14, 78, 13, 77, 15, 79, 16, 80,
1267 18, 82, 17, 81, 19, 83, 24, 88, 26, 90, 25, 89, 27, 91, 20,
1268 84, 22, 86, 21, 85, 23, 87, 28, 92, 30, 94, 29, 93, 31, 95};
1269
1270 alignas(64) static constexpr const uint8_t idx_hi_8[64] = {32, 96, 34, 98,
1271 33, 97, 35, 99, 40, 104, 42, 106, 41, 105, 43, 107, 36, 100, 38,
1272 102, 37, 101, 39, 103, 44, 108, 46, 110, 45, 109, 47, 111, 48, 112,
1273 50, 114, 49, 113, 51, 115, 56, 120, 58, 122, 57, 121, 59, 123, 52,
1274 116, 54, 118, 53, 117, 55, 119, 60, 124, 62, 126, 61, 125, 63, 127};
1275
1276 vmovdqa64(vreg_idx_lo_256,
1277 is_amx ? (const void *)idx_lo_16 : (const void *)idx_lo_256);
1278 vmovdqa64(vreg_idx_hi_256,
1279 is_amx ? (const void *)idx_hi_16 : (const void *)idx_hi_256);
1280 vmovdqa64(vreg_idx_lo_128,
1281 is_amx ? (const void *)idx_lo_8 : (const void *)idx_lo_128);
1282 vmovdqa64(vreg_idx_hi_128,
1283 is_amx ? (const void *)idx_hi_8 : (const void *)idx_hi_128);
1284
1285 if (do_compute_compensation) {
1286 int n_iters = div_up(conf_->wei_n_blk, 16);
1287 for (int i = 0; i < n_iters; i++)
1288 vpxord(get_comp_acc(i), get_comp_acc(i), get_comp_acc(i));
1289 mov(imm_addr64, 1);
1290 vpbroadcastb(zmm_comp_mul, imm_addr64.cvt8());
1291 }
1292
1293 auto compute_K_loop = [=](bool is_N_tail) {
1294 const int k_unroll = 4;
1295 int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk;
1296
1297 Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done;
1298 cmp(reg_K_iters, k_unroll * k_blk_step);
1299 jl(K_loop_single, T_NEAR);
1300
1301 L(K_loop_unrolled);
1302 copy_4x64_vnni(k_unroll * k_blk_step, ncolumns);
1303 add(reg_src, k_unroll * k_blk_step * src_stride);
1304 add(reg_tr_src, k_unroll * tr_src_stride);
1305
1306 sub(reg_K_iters, k_unroll * k_blk_step);
1307 cmp(reg_K_iters, k_unroll * k_blk_step);
1308 jge(K_loop_unrolled, T_NEAR);
1309
1310 L(K_loop_single);
1311 cmp(reg_K_iters, k_blk_step);
1312 jl(K_loop_tail_or_done, T_NEAR);
1313
1314 copy_4x64_vnni(k_blk_step, ncolumns);
1315 add(reg_src, k_blk_step * src_stride);
1316 add(reg_tr_src, tr_src_stride);
1317
1318 sub(reg_K_iters, k_blk_step);
1319 jmp(K_loop_single, T_NEAR);
1320
1321 L(K_loop_tail_or_done);
1322
1323 int k_blk_tail = conf_->K % k_blk_step;
1324 if (k_blk_tail > 0) {
1325 Label K_loop_done;
1326 cmp(reg_K_iters, 0);
1327 jle(K_loop_done, T_NEAR);
1328
1329 copy_4x64_vnni(k_blk_tail, ncolumns);
1330 sub(reg_K_iters, k_blk_tail);
1331 L(K_loop_done);
1332 }
1333 };
1334
1335 Label done;
1336 if (conf_->N_tail > 0) {
1337 Label not_N_tail;
1338 cmp(reg_N_blk, conf_->N_tail);
1339 jne(not_N_tail, T_NEAR);
1340 compute_K_loop(true);
1341 jmp(done, T_NEAR);
1342
1343 L(not_N_tail);
1344 }
1345
1346 compute_K_loop(false);
1347 L(done);
1348
1349 if (do_compute_compensation) {
1350 const bool req_s8s8_comp = conf_->s8s8_compensation_required;
1351 const bool req_zp_comp = conf_->has_zero_point_a;
1352 int n_iters = div_up(conf_->wei_n_blk, 16);
1353 assert(IMPLICATION(req_zp_comp,
1354 conf_->src_zp_type == brgemm_broadcast_t::per_tensor));
1355
1356 // copy 'comp_acc' into s8s8_comp accumulator
1357 if (req_s8s8_comp) {
1358 for (int i = 0; i < n_iters; i++)
1359 vmovups(get_zmm_oscale_comp_res(i), get_comp_acc(i));
1360 }
1361
1362 Label skip_acc, store;
1363 if (req_s8s8_comp)
1364 mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]);
1365 if (req_zp_comp)
1366 mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]);
1367
1368 mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
1369 cmp(reg_K_start, 0);
1370 je(skip_acc, T_NEAR);
1371 if (req_s8s8_comp) {
1372 for (int i = 0; i < n_iters; i++) {
1373 const auto zmm_acc = get_comp_acc(i);
1374 const auto zmm_res = get_zmm_oscale_comp_res(i);
1375 const auto addr = EVEX_compress_addr(reg_comp_ptr, i * 64);
1376 vpaddd(zmm_res, zmm_acc, addr);
1377 }
1378 }
1379
1380 if (req_zp_comp) {
1381 for (int i = 0; i < n_iters; i++) {
1382 const auto zmm_acc = get_comp_acc(i);
1383 const auto zmm_res = get_zmm_zp_comp_res(i);
1384 const auto addr = EVEX_compress_addr(reg_zp_comp_ptr, i * 64);
1385 vpaddd(zmm_res, zmm_acc, addr);
1386 }
1387 }
1388
1389 L(skip_acc);
1390 cmp(reg_K_start, rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk);
1391 jl(store, T_NEAR);
1392
1393 if (req_s8s8_comp) {
1394 mov(imm_addr64, 0xffffffff);
1395 const auto zmm_all_bits_1 = zmm_comp_mul;
1396 vpbroadcastd(zmm_all_bits_1, imm_addr64.cvt32());
1397 mov(imm_addr64, 0x1);
1398 const auto zmm_one_s32 = zmm_zero;
1399 vpbroadcastd(zmm_one_s32, imm_addr64.cvt32());
1400
1401 for (int i = 0; i < n_iters; i++) {
1402 const auto zmm_res = get_zmm_oscale_comp_res(i);
1403 // multiply by 128
1404 vpslld(zmm_res, zmm_res, 7);
1405 // change sign
1406 vpandnq(zmm_res, zmm_res, zmm_all_bits_1);
1407 vpaddd(zmm_res, zmm_res, zmm_one_s32);
1408 }
1409 }
1410
1411 if (req_zp_comp) {
1412 mov(reg_zp_a_neg_val_ptr,
1413 ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]);
1414 const auto zmm_zp_a_neg_val = vreg_idx_hi_128;
1415 vbroadcastss(zmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]);
1416
1417 for (int i = 0; i < n_iters; i++) {
1418 const auto zmm_res = get_zmm_zp_comp_res(i);
1419 vpmulld(zmm_res, zmm_res, zmm_zp_a_neg_val);
1420 }
1421 }
1422
1423 L(store);
1424 if (req_s8s8_comp) {
1425 for (int i = 0; i < n_iters; i++) {
1426 const auto zmm_res = get_zmm_oscale_comp_res(i);
1427 const auto addr = EVEX_compress_addr(reg_comp_ptr, i * 64);
1428 vmovups(addr, zmm_res);
1429 }
1430 }
1431 if (req_zp_comp) {
1432 for (int i = 0; i < n_iters; i++) {
1433 const auto zmm_res = get_zmm_zp_comp_res(i);
1434 const auto addr = EVEX_compress_addr(reg_zp_comp_ptr, i * 64);
1435 vmovups(addr, zmm_res);
1436 }
1437 }
1438 }
1439
1440 postamble();
1441}
1442
1443struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
1444 public jit_generator {
1445 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_bf16_t)
1446
1447 jit_brgemm_matmul_copy_b_bf16_t(const brgemm_matmul_conf_t *conf)
1448 : jit_brgemm_matmul_copy_b_t(conf)
1449 , jit_generator(jit_name())
1450 , typesize(conf->b_dt_sz)
1451 , tr_typesize(conf->tr_b_dt_sz)
1452 , src_stride(conf_->wei_tag == format_tag::acbd
1453 ? conf->copy_B_wei_stride
1454 : conf->req_wei_vnni_downconvert
1455 ? conf_->LDB * typesize
1456 : conf_->N * typesize)
1457 , tr_src_stride(conf_->LDB * k_blk_step * tr_typesize) {}
1458
1459 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1460 status_t create_kernel() override { return jit_generator::create_kernel(); }
1461
1462private:
1463 using reg64_t = const Xbyak::Reg64;
1464 using reg32_t = const Xbyak::Reg32;
1465 using opmask_t = const Xbyak::Opmask;
1466 using zmm = const Xbyak::Zmm;
1467 using ymm = const Xbyak::Ymm;
1468
1469 enum { k_blk_step = 2, n_blk_step = 16 };
1470 const int typesize, tr_typesize;
1471 const dim_t src_stride, tr_src_stride;
1472
1473 opmask_t kTail = k7;
1474 opmask_t kFFFF = k6;
1475
1476 reg64_t reg_src = rax;
1477 reg64_t reg_tr_src = rbx;
1478
1479 reg64_t reg_K_iters = r8;
1480 reg64_t reg_N_blk = r9;
1481 reg64_t reg_K_start = r10;
1482 reg32_t regw_tmp = r14d;
1483 reg64_t imm_addr64 = r15;
1484
1485 zmm zmm_permw = zmm30;
1486 zmm zmm_zero = zmm31;
1487
1488 void copy_2x32_vnni(int nrows, int ncolumns);
1489 void generate() override;
1490};
1491
1492void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32_vnni(int nrows, int ncolumns) {
1493
1494 auto kmovx = [=](Opmask k, unsigned w) {
1495 mov(regw_tmp, w);
1496 if (conf_->is_bf32)
1497 jit_generator::kmovw(k, regw_tmp);
1498 else
1499 jit_generator::kmovd(k, regw_tmp);
1500 };
1501
1502 const int columns_tail = ncolumns % n_blk_step;
1503 const auto tail_mask = (1 << columns_tail) - 1;
1504 if (columns_tail < n_blk_step) kmovx(kTail, tail_mask);
1505
1506 const int blk_sz = k_blk_step;
1507 const int max_regs_available = 30;
1508 const int max_unroll = max_regs_available / blk_sz;
1509 auto get_zmm = [=](int blk, int idx) {
1510 assert(idx >= 0 && idx < blk_sz && blk >= 0);
1511 auto reg_idx = max_unroll * ((idx + 1) % blk_sz) + blk;
1512 assert(reg_idx >= 0 && reg_idx < max_regs_available);
1513 return zmm(reg_idx);
1514 };
1515
1516 auto load = [=](int blk, int k, int n, opmask_t current_mask) {
1517 auto src_reg = get_zmm(blk, k % k_blk_step);
1518 auto src_load = src_reg | current_mask | T_z;
1519 auto load_addr
1520 = EVEX_compress_addr(reg_src, k * src_stride + n * typesize);
1521 if (conf_->is_bf32) {
1522 vmovups(src_load, load_addr);
1523 } else {
1524 vmovdqu16(src_load, load_addr);
1525 }
1526 };
1527
1528 int iter = 0;
1529 for_(int k = 0; k < nrows; k += k_blk_step)
1530 for (int n = 0; n < conf_->wei_n_blk; n += n_blk_step) {
1531 const int k_blk = k / k_blk_step;
1532 const dim_t tr_src_off
1533 = k_blk * tr_src_stride + n * k_blk_step * tr_typesize;
1534 const auto store_addr = EVEX_compress_addr(reg_tr_src, tr_src_off);
1535 if (ncolumns - n <= 0) {
1536 vmovups(store_addr, zmm_zero);
1537 continue;
1538 }
1539
1540 const opmask_t curr_msk = ncolumns - n < n_blk_step ? kTail : kFFFF;
1541 const int blk_idx = iter % max_unroll;
1542 load(blk_idx, k, n, curr_msk);
1543
1544 const auto src_zmm0 = get_zmm(blk_idx, 0);
1545 if (nrows - k >= k_blk_step) {
1546 load(blk_idx, k + 1, n, curr_msk);
1547 const auto src_zmm1 = get_zmm(blk_idx, 1);
1548 if (conf_->is_bf32) {
1549 vcvtne2ps2bf16(src_zmm0, src_zmm1, src_zmm0);
1550 } else {
1551 const auto src_ymm1 = ymm(src_zmm1.getIdx());
1552 vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1);
1553 }
1554 } else if (conf_->is_bf32) {
1555 vcvtneps2bf16(ymm(src_zmm0.getIdx()), src_zmm0);
1556 }
1557
1558 vpermw(src_zmm0, zmm_permw, src_zmm0);
1559
1560 vmovups(store_addr, src_zmm0);
1561 iter++;
1562 }
1563}
1564
1565void jit_brgemm_matmul_copy_b_bf16_t::generate() {
1566 assert(tr_typesize == sizeof(bfloat16_t));
1567 preamble();
1568 vpxord(zmm_zero, zmm_zero, zmm_zero);
1569
1570 alignas(64) static constexpr const int16_t bf16_vnni_permute[32]
1571 = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
1572 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
1573
1574 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1575 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1576 mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
1577 mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
1578
1579 kxnorw(kFFFF, kFFFF, kFFFF); // 1111 1111 1111 1111
1580 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
1581 mov(imm_addr64, reinterpret_cast<size_t>(addr));
1582 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
1583 };
1584
1585 vmovdqa64(zmm_permw, (const int64_t *)bf16_vnni_permute);
1586
1587 auto compute_K_loop = [=](bool is_N_tail) {
1588 const int k_unroll = 8;
1589 int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk;
1590
1591 Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done;
1592 cmp(reg_K_iters, k_unroll * k_blk_step);
1593 jl(K_loop_single, T_NEAR);
1594
1595 L(K_loop_unrolled);
1596 copy_2x32_vnni(k_unroll * k_blk_step, ncolumns);
1597 add(reg_src, k_unroll * k_blk_step * src_stride);
1598 add(reg_tr_src, k_unroll * tr_src_stride);
1599
1600 sub(reg_K_iters, k_unroll * k_blk_step);
1601 cmp(reg_K_iters, k_unroll * k_blk_step);
1602 jge(K_loop_unrolled, T_NEAR);
1603
1604 L(K_loop_single);
1605 cmp(reg_K_iters, k_blk_step);
1606 jl(K_loop_tail_or_done, T_NEAR);
1607
1608 copy_2x32_vnni(k_blk_step, ncolumns);
1609 add(reg_src, k_blk_step * src_stride);
1610 add(reg_tr_src, tr_src_stride);
1611
1612 sub(reg_K_iters, k_blk_step);
1613 jmp(K_loop_single, T_NEAR);
1614
1615 L(K_loop_tail_or_done);
1616
1617 int k_blk_tail = conf_->K % k_blk_step;
1618 if (k_blk_tail > 0) {
1619 Label K_loop_done;
1620 cmp(reg_K_iters, 0);
1621 jle(K_loop_done, T_NEAR);
1622
1623 copy_2x32_vnni(k_blk_tail, ncolumns);
1624 sub(reg_K_iters, k_blk_tail);
1625 L(K_loop_done);
1626 }
1627 };
1628
1629 Label done;
1630 if (conf_->N_tail > 0) {
1631 Label not_N_tail;
1632 cmp(reg_N_blk, conf_->N_tail);
1633 jne(not_N_tail, T_NEAR);
1634 compute_K_loop(true);
1635 jmp(done, T_NEAR);
1636
1637 L(not_N_tail);
1638 }
1639
1640 compute_K_loop(false);
1641 L(done);
1642
1643 postamble();
1644}
1645
1646struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t,
1647 public jit_generator {
1648 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_f32_t)
1649
1650 jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf)
1651 : jit_brgemm_matmul_copy_b_t(conf)
1652 , jit_generator(jit_name())
1653 , dt_in_(conf->isa == avx512_core_fp16 ? data_type::f16
1654 : data_type::f32)
1655 , typesize_in_(types::data_type_size(dt_in_))
1656 , src_stride_(conf_->wei_tag == acbd ? conf_->copy_B_wei_stride
1657 : conf_->N * typesize_in_)
1658 , tr_src_stride_(conf_->LDB * typesize_out_) {}
1659
1660 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1661 status_t create_kernel() override { return jit_generator::create_kernel(); }
1662
1663private:
1664 using reg64_t = const Xbyak::Reg64;
1665 using reg32_t = const Xbyak::Reg32;
1666 using opmask_t = const Xbyak::Opmask;
1667 using zmm = const Xbyak::Zmm;
1668
1669 enum { n_blk_step = 16, max_regs_available = 30 };
1670 const data_type_t dt_in_;
1671 const size_t typesize_in_;
1672 const size_t typesize_out_ = sizeof(float);
1673 dim_t src_stride_, tr_src_stride_;
1674
1675 opmask_t kTail = k7;
1676 opmask_t kFFFF = k6;
1677
1678 reg64_t reg_src = rax;
1679 reg64_t reg_tr_src = rbx;
1680
1681 reg64_t reg_K_iters = r8;
1682 reg64_t reg_N_blk = r9;
1683 reg64_t reg_K_start = r10;
1684 reg32_t regw_tmp = r14d;
1685 reg64_t imm_addr64 = r15;
1686
1687 zmm zmm_permw = zmm30;
1688 zmm zmm_zero = zmm31;
1689
1690 inline void kmovw(Opmask k, unsigned w) {
1691 mov(regw_tmp, w);
1692 jit_generator::kmovd(k, regw_tmp);
1693 }
1694 void copy_16_x_n_block(int nrows, int ncolumns);
1695 void compute_k_loop(int ncolumns);
1696 void generate() override;
1697};
1698
void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block(
        int nrows, int ncolumns) {

    auto get_zmm = [=](int reg_idx) {
        assert(reg_idx >= 0 && reg_idx < max_regs_available);
        return zmm(reg_idx);
    };

    auto load = [=](int blk, int k, int n, opmask_t current_mask) {
        auto src_zmm = get_zmm(blk);
        auto src_zmm_m = src_zmm | current_mask | T_z;
        auto addr = EVEX_compress_addr(
                reg_src, k * src_stride_ + n * typesize_in_);
        if (dt_in_ == data_type::f16)
            vcvtph2psx(src_zmm_m, addr);
        else
            vmovups(src_zmm_m, addr);
    };

    const int columns_tail = ncolumns % n_blk_step;
    const auto tail_mask = (1 << columns_tail) - 1;
    if (columns_tail > 0) kmovw(kTail, tail_mask);

    int iter = 0;
    for_(int k = 0; k < nrows; k++)
    for (int n = 0; n < conf_->wei_n_blk; n += n_blk_step) {
        const dim_t tr_src_off = k * tr_src_stride_ + n * typesize_out_;
        const auto store_addr = EVEX_compress_addr(reg_tr_src, tr_src_off);

        const int zero_padding = ncolumns - n;
        if (zero_padding <= 0) {
            vmovups(store_addr, zmm_zero);
            continue;
        }

        const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : kFFFF;
        const int blk_idx = iter % max_regs_available;
        load(blk_idx, k, n, curr_msk);

        const auto src_zmm0 = get_zmm(blk_idx);
        vmovups(store_addr, src_zmm0);
        iter++;
    }
}

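// K loop: the bulk of K is processed in chunks of k_unroll (16) rows, then
// the same loop body is re-emitted with unroll = 1 to handle the remainder.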
void jit_brgemm_matmul_copy_b_f32_t::compute_k_loop(int ncolumns) {

    auto compute_uni_k_loop = [&](int unroll) {
        Label K_start_label, K_end_label;

        L(K_start_label);
        cmp(reg_K_iters, unroll);
        jl(K_end_label, T_NEAR);

        copy_16_x_n_block(unroll, ncolumns);
        add(reg_src, unroll * src_stride_);
        add(reg_tr_src, unroll * tr_src_stride_);

        sub(reg_K_iters, unroll);
        jmp(K_start_label, T_NEAR);

        L(K_end_label);
    };

    constexpr int k_unroll = 16;
    compute_uni_k_loop(k_unroll);
    compute_uni_k_loop(1);
}

void jit_brgemm_matmul_copy_b_f32_t::generate() {
    preamble();
    vpxord(zmm_zero, zmm_zero, zmm_zero);

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
    mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
    kmovw(kFFFF, 0xffff); // 1111111111111111

    Label done;
    if (conf_->N_tail > 0) {
        Label not_N_tail;
        cmp(reg_N_blk, conf_->N_tail);
        jne(not_N_tail, T_NEAR);
        compute_k_loop(conf_->N_tail);
        jmp(done, T_NEAR);

        L(not_N_tail);
    }

    compute_k_loop(conf_->N_blk);
    L(done);

    postamble();
}

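// Copy-B kernel for weights given in a transposed (N-major) layout: 16-row x
// 64-byte tiles are transposed in registers and written out in the
// VNNI-blocked destination layout. When zero-point or s8s8 compensation is
// required, per-output-column sums of B are accumulated on the fly with
// vpdpbusd and written back at the end of each K chunk.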
struct jit_brgemm_matmul_copy_b_transposed_t
    : public jit_brgemm_matmul_copy_b_t,
      public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_transposed_t)

    jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf)
        : jit_brgemm_matmul_copy_b_t(conf)
        , jit_generator(jit_name())
        , typesize(conf_->b_dt_sz)
        , tr_typesize(conf_->tr_b_dt_sz)
        , vnni_granularity(data_type_vnni_granularity(conf_->wei_dt))
        , k_blk_step(bytes_in_zmm / tr_typesize)
        , do_compute_compensation(
                  conf_->has_zero_point_a || conf_->s8s8_compensation_required)
        , is_bf32(conf->is_bf32)
        , req_zp_comp(conf_->has_zero_point_a)
        , req_s8s8_comp(conf_->s8s8_compensation_required)
        , src_stride(conf_->wei_tag == format_tag::adbc
                          ? conf_->copy_B_wei_stride
                          : conf_->K * typesize)
        , tr_src_stride(conf_->LDB * vnni_granularity * tr_typesize) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        n_blk_step = 16,
        bytes_in_zmm = 64,
        bf32_k_blk_step = 16,
    };

    const int typesize;
    const int tr_typesize;
    const int vnni_granularity;
    const int k_blk_step;
    const bool do_compute_compensation;
    const bool is_bf32;
    const bool req_zp_comp;
    const bool req_s8s8_comp;

    const dim_t src_stride, tr_src_stride;

    opmask_t k3333 = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kCCCC = k4;
    opmask_t k0F0F = k5;
    opmask_t kF0F0 = k6;
    opmask_t kTail = k7;

    reg64_t reg_src_base = rax;
    reg64_t reg_tr_src_base = rbx;
    reg64_t reg_comp_ptr = rdx;

    reg64_t reg_K_iters = r8;
    reg64_t reg_N_iters = r9;
    reg64_t reg_src = r10;
    reg64_t reg_tr_src = r11;
    reg64_t reg_zp_comp_ptr = r12;
    reg64_t reg_zp_a_neg_val_ptr = r13;
    reg64_t reg_K_start = r14;

    reg64_t regq_tmp = r15;
    reg32_t regw_tmp = r15d;
    reg64_t imm_addr64 = abi_not_param1;

    zmm zmm_zp_a_neg_val = zmm29;
    zmm zmm_comp_acc = zmm30;
    zmm zmm_comp_mul = zmm31;
    zmm zmm_s8s8_comp_acc = zmm28;
    zmm zmm_all_bits_1 = zmm27;
    zmm zmm_one_s32 = zmm26;

    void kmovw(Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    void kmovq(Opmask k, size_t q) {
        mov(regq_tmp, q);
        jit_generator::kmovq(k, regq_tmp);
    };

    void copy_16x64_vnni(int nrows, int ncolumns);
    void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter,
            bool is_last_K_iter);
    void compute_N_loop(
            int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter);

    void generate() override;
};

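// Transposes one 16-row x 64-byte tile held in zmm0..zmm15 via a sequence of
// element swaps: swap-1 and swap-2 use valignd under the kAAAA/k5555 and
// k3333/kCCCC masks, swap-4 uses vshuff32x4 under kF0F0/k0F0F (all within an
// 8-row half), and the final swap-8 across halves is done with vshuff64x2 in
// fixup16x16, which also stores the result and feeds the compensation
// accumulator. For bf32, pairs of f32 inputs are down-converted to bf16 on
// load.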
void jit_brgemm_matmul_copy_b_transposed_t::copy_16x64_vnni(
        int nrows, int ncolumns) {
    assert(nrows >= 0 && nrows <= n_blk_step && ncolumns >= 0
            && ncolumns <= k_blk_step);
    if (!nrows) return;

    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        // If compensation compute is required, the last 6 zmms are reserved
        // for it
        assert(i >= 0 && i < 16 - do_compute_compensation * 6);
        return Zmm(16 + i);
    };

    const int columns_tail
            = ncolumns % (is_bf32 ? bf32_k_blk_step : k_blk_step);
    if (columns_tail > 0) {
        const int dt_step
                = (is_bf32 || conf_->isa == avx512_core_fp16) ? 1 : typesize;
        const auto tail_mask
                = size_t(((size_t)1 << dt_step * columns_tail) - 1);
        if (is_bf32)
            kmovw(kTail, tail_mask);
        else
            kmovq(kTail, tail_mask);
    }

    auto load_bf32 = [=](int i) {
        auto src_reg = src_zmm(i);
        auto src_reg_next = tmp_zmm(i);

        if (i >= nrows) {
            vpxord(src_reg, src_reg, src_reg);
            return;
        }

        // Check whether the K tail exists and falls within the first zmm
        auto zmm_src = columns_tail > 0 && ncolumns < bf32_k_blk_step
                ? src_reg | kTail | T_z
                : src_reg;
        vmovups(zmm_src, EVEX_compress_addr(reg_src, i * src_stride));

        if (ncolumns <= bf32_k_blk_step) {
            vpxord(src_reg_next, src_reg_next, src_reg_next);
        } else {
            auto zmm_src_next = columns_tail > 0 ? src_reg_next | kTail | T_z
                                                 : src_reg_next;
            vmovups(zmm_src_next,
                    EVEX_compress_addr(reg_src,
                            i * src_stride + bf32_k_blk_step * typesize));
        }

        vcvtne2ps2bf16(src_reg, src_reg_next, src_reg);
    };

    auto load = [=](int i) {
        auto src_reg = src_zmm(i);
        if (i >= nrows) {
            vpxord(src_reg, src_reg, src_reg);
            return;
        }

        auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg;
        const auto addr = EVEX_compress_addr(reg_src, i * src_stride);
        if (conf_->isa == avx512_core_fp16)
            vcvtph2psx(src_load, addr);
        else
            vmovdqu8(src_load, addr);
    };

    auto store = [=](Zmm r, int i) {
        auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride);
        vmovups(addr, r);
    };

    auto transpose16x8 = [=](int base_idx) {
        assert(base_idx == 0 || base_idx == 8);
        // If compensation compute is required, use tmp(0) ... tmp(7)
        // so as not to clobber the reserved registers' values
        const int tmp_corr_idx = do_compute_compensation * base_idx;

        // swap 1
        if (is_bf32) {
            for (int i = 0; i < 4; i++) {
                const int src_idx0 = base_idx + i * 2;
                const int src_idx1 = src_idx0 + 1;

                if (base_idx == 0 && i == 0) {
                    load_bf32(src_idx0);
                    load_bf32(src_idx1);
                }

                int next_src_idx0 = src_idx0 + 2;
                int next_src_idx1 = src_idx1 + 2;

                bool load_next = base_idx == 0 || i < 3;

                const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx);
                const auto tmp1 = tmp_zmm(src_idx1 - tmp_corr_idx);
                const auto src0 = src_zmm(src_idx0);
                const auto src1 = src_zmm(src_idx1);

                if (next_src_idx0 < nrows && load_next)
                    load_bf32(next_src_idx0);
                valignd(tmp0, src0, src0, 0x1);

                if (next_src_idx1 < nrows && load_next)
                    load_bf32(next_src_idx1);
                valignd(tmp1, src1, src1, 0xf);

                vmovaps(src0 | kAAAA, tmp1);
                vmovaps(src1 | k5555, tmp0);
            }
        } else {
            for (int i = 0; i < 4; i++) {
                const int src_idx0 = base_idx + i * 2;
                const int src_idx1 = src_idx0 + 1;

                int next_src_idx0 = src_idx0 + 2;
                int next_src_idx1 = src_idx1 + 2;
                bool load_next = base_idx == 0 || i < 3;

                if (base_idx == 0 && i == 0) {
                    load(src_idx0);
                    load(src_idx1);
                }

                const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx);
                const auto tmp1 = tmp_zmm(src_idx1 - tmp_corr_idx);
                const auto src0 = src_zmm(src_idx0);
                const auto src1 = src_zmm(src_idx1);

                if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
                valignd(tmp0, src0, src0, 0x1);

                if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
                valignd(tmp1, src1, src1, 0xf);

                vmovaps(src0 | kAAAA, tmp1);
                vmovaps(src1 | k5555, tmp0);
            }
        }
        // swap 2
        for (int i = 0; i < 4; i++) {
            const int select_half = (i < 2) ? 0 : 2;
            const int src_idx0 = base_idx + i + select_half + 0;
            const int src_idx2 = src_idx0 + 2;

            const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx);
            const auto tmp1 = tmp_zmm(src_idx2 - tmp_corr_idx);
            const auto src0 = src_zmm(src_idx0);
            const auto src2 = src_zmm(src_idx2);

            valignd(tmp0, src0, src0, 0x2);
            valignd(tmp1, src2, src2, 0xe);
            vmovaps(src2 | k3333, tmp0);
            vmovaps(src0 | kCCCC, tmp1);
        }

        // swap 4
        for (int i = 0; i < 4; i++) {
            const int src_idx0 = base_idx + i;
            const int src_idx4 = src_idx0 + 4;

            const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx);
            const auto src0 = src_zmm(src_idx0);
            const auto src4 = src_zmm(src_idx4);

            vmovaps(tmp0, src0);
            vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
            vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
        }
    };

    auto fixup16x16 = [=]() {
        // swap 8
        for (int i = 0; i < 8; i++) {
            const auto tmp = tmp_zmm(i);
            const auto src0 = src_zmm(i);
            const auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0x44);
            if (do_compute_compensation)
                vpdpbusd(zmm_comp_acc, zmm_comp_mul, tmp);
            store(tmp, i);
        }

        for (int i = 0; i < 8; i++) {
            // If compensation compute is required, the upper zmms are
            // reserved (see tmp_zmm), so reuse src registers whose data has
            // already been consumed
            const auto tmp = IMPLICATION(do_compute_compensation, i < 2)
                    ? tmp_zmm(8 + i)
                    : src_zmm((i - 2) / 2 + (i % 2) * 8);
            const auto src0 = src_zmm(i);
            const auto src8 = src_zmm(8 + i);
            vshuff64x2(tmp, src0, src8, 0xee);
            if (do_compute_compensation)
                vpdpbusd(zmm_comp_acc, zmm_comp_mul, tmp);
            store(tmp, 8 + i);
        }
    };

    transpose16x8(0);
    transpose16x8(8);
    fixup16x16();
}

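// K loop for one 16-row strip of the source: full k_blk_step chunks are
// handled in the main loop, an optional trailing call handles curr_K_tail
// columns, and the accumulated compensation (if any) is combined with
// previously stored partial results when present and written back.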
void jit_brgemm_matmul_copy_b_transposed_t::compute_K_loop(bool is_N_tail,
        int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
    MAYBE_UNUSED(is_first_K_iter);
    MAYBE_UNUSED(is_last_K_iter);
    const int N_chunk_tail = conf_->N % n_blk_step;
    int nrows = is_N_tail ? N_chunk_tail : n_blk_step;
    if (do_compute_compensation)
        vpxord(zmm_comp_acc, zmm_comp_acc, zmm_comp_acc);

    Label K_loop, K_loop_tail_or_done;
    mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);

    mov(reg_src, reg_src_base);
    mov(reg_tr_src, reg_tr_src_base);
    if (curr_K_tail > 0) {
        cmp(reg_K_iters, k_blk_step);
        jl(K_loop_tail_or_done, T_NEAR);
    }

    L(K_loop);
    copy_16x64_vnni(nrows, k_blk_step);
    add(reg_src, k_blk_step * typesize);
    add(reg_tr_src, k_blk_step / vnni_granularity * tr_src_stride);

    sub(reg_K_iters, k_blk_step);
    cmp(reg_K_iters, k_blk_step);
    jge(K_loop, T_NEAR);

    L(K_loop_tail_or_done);

    if (curr_K_tail > 0) copy_16x64_vnni(nrows, curr_K_tail);

    if (req_s8s8_comp) {
        const auto addr = zword[reg_comp_ptr];
        if (!is_first_K_iter)
            vpaddd(zmm_s8s8_comp_acc, zmm_comp_acc, addr);
        else
            vmovups(zmm_s8s8_comp_acc, zmm_comp_acc);

        if (is_last_K_iter) {
            // Finalize: multiply the accumulated column sums by 128
            // (shift left by 7), then negate via two's complement
            // (bitwise NOT followed by +1)
            vpslld(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, 7);
            vpandnq(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, zmm_all_bits_1);
            vpaddd(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, zmm_one_s32);
        }
        vmovups(addr, zmm_s8s8_comp_acc);
    }
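    // Zero-point compensation: add this chunk's column sums to the partial
    // result of previous K chunks and, on the last chunk, multiply by the
    // negated zero-point value of A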
    if (req_zp_comp) {
        const auto addr = zword[reg_zp_comp_ptr];
        if (!is_first_K_iter) vpaddd(zmm_comp_acc, zmm_comp_acc, addr);
        if (is_last_K_iter)
            vpmulld(zmm_comp_acc, zmm_comp_acc, zmm_zp_a_neg_val);
        vmovups(addr, zmm_comp_acc);
    }
}

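// N loop: walks the source in strips of n_blk_step (16) rows, advancing the
// source, destination, and compensation pointers per strip; a partial strip
// of N % n_blk_step rows is handled after the main loop.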
void jit_brgemm_matmul_copy_b_transposed_t::compute_N_loop(
        int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
    const int N_chunk_tail = conf_->N % n_blk_step;
    // One zmm (16 x s32) of compensation values per n_blk_step columns
    const size_t comp_shift = 64;

    Label N_loop, N_loop_tail_or_done;
    if (N_chunk_tail > 0) {
        cmp(reg_N_iters, n_blk_step);
        jl(N_loop_tail_or_done, T_NEAR);
    }

    L(N_loop);
    compute_K_loop(false, curr_K_tail, is_first_K_iter, is_last_K_iter);
    add(reg_src_base, n_blk_step * src_stride);
    add(reg_tr_src_base, n_blk_step * vnni_granularity * tr_typesize);

    if (req_zp_comp) add(reg_zp_comp_ptr, comp_shift);
    if (req_s8s8_comp) add(reg_comp_ptr, comp_shift);

    sub(reg_N_iters, n_blk_step);
    cmp(reg_N_iters, n_blk_step);
    jge(N_loop, T_NEAR);

    L(N_loop_tail_or_done);
    if (N_chunk_tail > 0) {
        Label N_loop_done;
        cmp(reg_N_iters, 0);
        jle(N_loop_done, T_NEAR);

        compute_K_loop(true, curr_K_tail, is_first_K_iter, is_last_K_iter);
        L(N_loop_done);
    }
}

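// Kernel entry point: loads the runtime arguments, sets up the constant
// transpose masks and, when compensation is required, dispatches between the
// first/last K-chunk combinations so that partial compensation results are
// initialized on the first chunk and finalized on the last one.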
void jit_brgemm_matmul_copy_b_transposed_t::generate() {

    preamble();

    mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
    mov(reg_N_iters, ptr[param1 + GET_OFF(current_N_blk)]);

    kmovw(k3333, 0x3333); // 0011001100110011
    kmovw(k5555, 0x5555); // 0101010101010101
    kmovw(kAAAA, 0xaaaa); // 1010101010101010
    kmovw(kCCCC, 0xcccc); // 1100110011001100
    kmovw(k0F0F, 0x0f0f); // 0000111100001111
    kmovw(kF0F0, 0xf0f0); // 1111000011110000

    const dim_t N_chunk_elems = conf_->N_chunk_elems;
    assert(N_chunk_elems % n_blk_step == 0 || N_chunk_elems == conf_->N);
    UNUSED(N_chunk_elems);

    const auto K_blk_tail = nstl::min(conf_->K, conf_->K_blk) % k_blk_step;
    const auto K_tail_tail = (conf_->K % conf_->K_blk) % k_blk_step;

    auto compute_body = [=](bool is_first_K_iter, bool is_last_K_iter) {
        if (is_last_K_iter) {
            if (req_s8s8_comp) {
                mov(imm_addr64, 0xffffffff);
                vpbroadcastd(zmm_all_bits_1, imm_addr64.cvt32());
                mov(imm_addr64, 0x1);
                vpbroadcastd(zmm_one_s32, imm_addr64.cvt32());
            }
            if (req_zp_comp) {
                mov(reg_zp_a_neg_val_ptr,
                        ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]);
                vbroadcastss(zmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]);
            }
        }

        Label compute_body_done;
        if (conf_->K_tail > 0 && K_blk_tail != K_tail_tail) {
            Label not_K_tail;
            cmp(reg_K_iters, conf_->K_blk);
            je(not_K_tail, T_NEAR);
            compute_N_loop(K_tail_tail, is_first_K_iter, is_last_K_iter);
            jmp(compute_body_done, T_NEAR);

            L(not_K_tail);
        }

        compute_N_loop(K_blk_tail, is_first_K_iter, is_last_K_iter);
        L(compute_body_done);
    };

    Label done;
    if (do_compute_compensation) {
        assert(IMPLICATION(req_zp_comp,
                conf_->src_zp_type == brgemm_broadcast_t::per_tensor));

        mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
        if (req_s8s8_comp)
            mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]);
        if (req_zp_comp)
            mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]);

        mov(regq_tmp, 1);
        vpbroadcastb(zmm_comp_mul, regq_tmp.cvt8());

        const auto last_K_threshold
                = rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk;
        Label not_first, not_first_not_last;
        cmp(reg_K_start, 0);
        jne(not_first, T_NEAR);
        {
            // first K iteration
            Label first_not_last;
            cmp(reg_K_start, last_K_threshold);
            jl(first_not_last, T_NEAR);
            compute_body(true, true);
            jmp(done, T_NEAR);

            L(first_not_last);
            compute_body(true, false);
            jmp(done, T_NEAR);
        }

        L(not_first);
        cmp(reg_K_start, last_K_threshold);
        jl(not_first_not_last, T_NEAR);

        compute_body(false, true);
        jmp(done, T_NEAR);
        L(not_first_not_last);
    }

    compute_body(false, false);
    L(done);

    postamble();
}

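// Selects the copy-B kernel: transposed weight tags use the transposing
// kernel; otherwise bf16/f16/bf32 data goes to the bf16 kernel, f32 (as well
// as f16 up-converted on avx512_core_fp16) goes to the f32 kernel, and int8
// falls through to the int8 kernel.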
status_t create_brgemm_matmul_copy_b(
        std::unique_ptr<jit_brgemm_matmul_copy_b_t> &copy_ker,
        const brgemm_matmul_conf_t *conf) {
    const bool is_B_transposed
            = one_of(conf->wei_tag, ba, acb, abdc, adbc, abced, abcdfe, abcdegf,
                    abcdefhg, abcdefgih, abcdefghji, abcdefghikj, abcdefghijlk);
    const bool is_bf16
            = everyone_is(data_type::bf16, conf->src_dt, conf->wei_dt);
    const bool is_f32 = everyone_is(data_type::f32, conf->src_dt, conf->wei_dt);
    // Note: f16 support through avx512_core_fp16 sets src_dt and wei_dt as f32
    // to imply up-conversion. So the assumption is that `is_f16` below
    // evaluates to `false` on avx512_core_fp16.
    const bool is_f16 = everyone_is(data_type::f16, conf->src_dt, conf->wei_dt);
    if (is_B_transposed) {
        CHECK(safe_ptr_assign(
                copy_ker, new jit_brgemm_matmul_copy_b_transposed_t(conf)));
    } else {
        if (is_bf16 || is_f16 || conf->is_bf32) {
            CHECK(safe_ptr_assign(
                    copy_ker, new jit_brgemm_matmul_copy_b_bf16_t(conf)));
        } else if (is_f32 || conf->isa == avx512_core_fp16) {
            CHECK(safe_ptr_assign(
                    copy_ker, new jit_brgemm_matmul_copy_b_f32_t(conf)));
        } else {
            CHECK(safe_ptr_assign(
                    copy_ker, new jit_brgemm_matmul_copy_b_int8_t(conf)));
        }
    }

    return copy_ker->create_kernel();
}

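// Selects the copy-A kernel: the transposed variant when A is transposed, the
// plain one otherwise.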
status_t create_brgemm_matmul_copy_a(
        std::unique_ptr<jit_brgemm_matmul_copy_a_t> &copy_ker,
        const brgemm_matmul_conf_t *conf) {
    if (conf->transposed_A) {
        CHECK(safe_ptr_assign(copy_ker,
                new jit_brgemm_matmul_copy_a_transposed_impl_t(conf)));
    } else {
        CHECK(safe_ptr_assign(
                copy_ker, new jit_brgemm_matmul_copy_a_impl_t(conf)));
    }

    return copy_ker->create_kernel();
}

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl