1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/nstl.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21#include "cpu/x64/jit_generator.hpp"
22
23#include "cpu/x64/jit_brgemm_transpose_utils.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30using namespace dnnl::impl::format_tag;
31using namespace dnnl::impl::utils;
32using namespace Xbyak;
33
34#define GET_OFF(x) offsetof(ctx_t, x)
35
36struct jit_brgemm_trans_m_k_f32_t : public jit_brgemm_trans_src_t,
37 public jit_generator {
38 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f32_t)
39
40 jit_brgemm_trans_m_k_f32_t(const jit_brgemm_primitive_conf_t *conf)
41 : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {}
42
43 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
44 status_t create_kernel() override { return jit_generator::create_kernel(); }
45
46private:
47 using reg64_t = const Xbyak::Reg64;
48 using reg32_t = const Xbyak::Reg32;
49 using opmask_t = const Xbyak::Opmask;
50
51 enum { typesize = sizeof(float), transpose_size = 16 };
52 dim_t src_stride = 0, tr_src_stride = 0;
53
54 opmask_t k3333 = k1;
55 opmask_t k5555 = k2;
56 opmask_t kAAAA = k3;
57 opmask_t kCCCC = k4;
58 opmask_t k0F0F = k5;
59 opmask_t kF0F0 = k6;
60 opmask_t kTail = k7;
61
62 reg64_t reg_src_base = rax;
63 reg64_t reg_tr_src_base = rbx;
64
65 reg64_t reg_src = r8;
66 reg64_t reg_tr_src = r9;
67 reg64_t reg_loop_K = r10;
68 reg64_t reg_loop_M = r11;
69 reg64_t reg_loop_batch = r12;
70 reg64_t reg_tr_src_tmp = r13;
71 reg32_t regw_tmp = r14d;
72 reg64_t reg_row_loop = r15;
73
74 void transpose_16x16(int nrows, int ncolumns);
75 void transpose(int nrows, int ncolumns);
76 void generate() override;
77};
78
79void jit_brgemm_trans_m_k_f32_t::transpose_16x16(int nrows, int ncolumns) {
80 assert(nrows >= 0 && nrows <= transpose_size);
81 static_assert(transpose_size == 16, "Unsupported transpose size");
82 if (!nrows) return;
83
84 auto src_zmm = [=](int i) {
85 assert(i >= 0 && i < 16);
86 return Zmm(i);
87 };
88
89 auto tmp_zmm = [=](int i) {
90 assert(i >= 0 && i < 16);
91 return Zmm(16 + i);
92 };
93
94 auto kmovw = [=](Opmask k, unsigned w) {
95 mov(regw_tmp, w);
96 jit_generator::kmovw(k, regw_tmp);
97 };
98
99 auto load = [=](int i) {
100 auto src_load = src_zmm(i);
101 if (i >= nrows) {
102 vpxord(src_load, src_load, src_load);
103 return;
104 }
105
106 if (ncolumns < transpose_size) {
107 kmovw(kTail, (1 << ncolumns) - 1);
108 src_load = src_zmm(i) | kTail | T_z;
109 }
110 vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
111 };
112
113 auto store = [=](Zmm r, int i) {
114 mov(reg_tr_src_tmp, reg_tr_src);
115 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
116
117 // Xbyak does not allow k0 to be specified explicitly via the '|'
118 // operator, so we have to do this via a method call (implicitly
119 // EVEX encoding uses k0 to mean 'no mask')
120 const bool partial_store = nrows < transpose_size;
121 const auto k = partial_store ? kTail : k0;
122 auto base = reg_tr_src_tmp;
123 base.setOpmaskIdx(k.getIdx(), true);
124
125 const auto addr = EVEX_compress_addr(base, i * tr_src_stride);
126 vmovups(addr, r);
127 };
128
129 auto transpose16x8 = [=](int base_idx) {
130 assert(base_idx == 0 || base_idx == 8);
131
132 // swap 1
133 for (int i = 0; i < 4; i++) {
134 const int src_idx0 = base_idx + i * 2;
135 const int src_idx1 = src_idx0 + 1;
136
137 const int next_src_idx0 = src_idx0 + 2;
138 const int next_src_idx1 = src_idx1 + 2;
139 const bool load_next = base_idx == 0 || i < 3;
140
141 if (base_idx == 0 && i == 0) {
142 load(src_idx0);
143 if (src_idx1 < nrows)
144 load(src_idx1);
145 else
146 vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
147 src_zmm(src_idx1));
148 }
149
150 const auto tmp0 = tmp_zmm(src_idx0);
151 const auto tmp1 = tmp_zmm(src_idx1);
152 const auto src0 = src_zmm(src_idx0);
153 const auto src1 = src_zmm(src_idx1);
154
155 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
156 valignd(tmp0, src0, src0, 0x1);
157
158 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
159 valignd(tmp1, src1, src1, 0xf);
160
161 vmovaps(src0 | kAAAA, tmp1);
162 vmovaps(src1 | k5555, tmp0);
163 }
164 // swap 2
165 for (int i = 0; i < 4; i++) {
166 const int select_half = (i < 2) ? 0 : 2;
167 const int src_idx0 = base_idx + i + select_half + 0;
168 const int src_idx2 = src_idx0 + 2;
169
170 const auto tmp0 = tmp_zmm(src_idx0);
171 const auto tmp1 = tmp_zmm(src_idx2);
172 const auto src0 = src_zmm(src_idx0);
173 const auto src2 = src_zmm(src_idx2);
174
175 valignd(tmp0, src0, src0, 0x2);
176 valignd(tmp1, src2, src2, 0xe);
177 vmovaps(src2 | k3333, tmp0);
178 vmovaps(src0 | kCCCC, tmp1);
179 }
180
181 // swap 4
182 for (int i = 0; i < 4; i++) {
183 const int src_idx0 = base_idx + i;
184 const int src_idx4 = src_idx0 + 4;
185
186 const auto tmp0 = tmp_zmm(src_idx0);
187 const auto src0 = src_zmm(src_idx0);
188 const auto src4 = src_zmm(src_idx4);
189
190 vmovaps(tmp0, src0);
191 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
192 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
193 }
194 };
195
196 auto fixup16x16 = [=]() {
197 // swap 8
198 const auto max_iters_phase_1 = std::min(ncolumns, 8);
199 for (int i = 0; i < max_iters_phase_1; i++) {
200 const auto tmp = tmp_zmm(i);
201 const auto src0 = src_zmm(i);
202 const auto src8 = src_zmm(8 + i);
203 vshuff64x2(tmp, src0, src8, 0x44);
204 store(tmp, i);
205 }
206
207 const auto max_iters_phase_2 = std::min(ncolumns - 8, 8);
208 for (int i = 0; i < max_iters_phase_2; i++) {
209 const auto tmp = tmp_zmm(8 + i);
210 const auto src0 = src_zmm(i);
211 const auto src8 = src_zmm(8 + i);
212 vshuff64x2(tmp, src0, src8, 0xee);
213 store(tmp, 8 + i);
214 }
215 };
216
217 transpose16x8(0);
218 transpose16x8(8);
219 fixup16x16();
220}
221
222void jit_brgemm_trans_m_k_f32_t::transpose(int nrows, int ncolumns) {
223
224 Label K_loop, K_tail_or_done, K_done;
225 const int num_nrows_loop = nrows / transpose_size;
226 const int nrows_tail = nrows % transpose_size;
227 const dim_t src_shift = transpose_size * conf_->ic * typesize;
228 const dim_t tr_src_shift = transpose_size * typesize;
229
230 if (num_nrows_loop > 1) mov(reg_row_loop, num_nrows_loop);
231 L(K_loop);
232 if (num_nrows_loop > 0) transpose_16x16(transpose_size, ncolumns);
233 if (num_nrows_loop > 1 || (num_nrows_loop > 0 && nrows_tail > 0)) {
234 add(reg_src, src_shift);
235 add(reg_tr_src, tr_src_shift);
236 }
237 if (num_nrows_loop > 1) {
238 dec(reg_row_loop);
239 jg(K_loop);
240 }
241
242 if (nrows_tail > 0) { transpose_16x16(nrows_tail, ncolumns); }
243
244 if (num_nrows_loop > 1 || nrows_tail > 0) {
245 // reset pointers
246 sub(reg_src, src_shift * num_nrows_loop);
247 sub(reg_tr_src, tr_src_shift * num_nrows_loop);
248 }
249}
250
251void jit_brgemm_trans_m_k_f32_t::generate() {
252 preamble();
253 assert(conf_->ic_block % transpose_size == 0);
254 const int os_block = conf_->os_block;
255 const int last_os_block_tail = conf_->K_tail % os_block;
256 const int ic_tail = conf_->M_tail % transpose_size;
257 src_stride = conf_->ic * typesize;
258 tr_src_stride = conf_->LDA * typesize;
259 const dim_t m_src_shift = transpose_size * typesize;
260 const dim_t m_tr_src_shift = tr_src_stride * transpose_size;
261
262 const dim_t batch_src_shift = src_stride * os_block;
263 const dim_t batch_tr_src_shift = tr_src_stride * conf_->M;
264
265 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
266 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
267 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
268 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
269
270 auto kmovw = [=](Opmask k, unsigned w) {
271 mov(regw_tmp, w);
272 jit_generator::kmovw(k, regw_tmp);
273 };
274
275 kmovw(k3333, 0x3333); // 0011001100110011
276 kmovw(k5555, 0x5555); // 0101010101010101
277 kmovw(kAAAA, 0xaaaa); // 1010101010101010
278 kmovw(kCCCC, 0xcccc); // 1100110011001100
279 kmovw(k0F0F, 0x0f0f); // 0000111100001111
280 kmovw(kF0F0, 0xf0f0); // 1111000011110000
281
282 auto compute_M = [=](bool is_os_tail) {
283 const auto nrows = is_os_tail ? last_os_block_tail : os_block;
284 mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
285 mov(reg_src, reg_src_base);
286 mov(reg_tr_src, reg_tr_src_base);
287 Label M_loop, M_tail_or_done, M_done;
288 if (ic_tail > 0) {
289 cmp(reg_loop_M, transpose_size);
290 jl(M_tail_or_done, T_NEAR);
291 }
292
293 L(M_loop);
294 transpose(nrows, transpose_size);
295 if (conf_->ic_block > transpose_size) {
296 add(reg_src, m_src_shift);
297 add(reg_tr_src, m_tr_src_shift);
298 sub(reg_loop_M, transpose_size);
299 cmp(reg_loop_M, transpose_size);
300 jge(M_loop, T_NEAR);
301 } else {
302 jmp(M_done, T_NEAR);
303 }
304
305 L(M_tail_or_done);
306 if (ic_tail > 0) {
307 cmp(reg_loop_M, 0);
308 jle(M_done, T_NEAR);
309
310 transpose(nrows, ic_tail);
311 }
312 L(M_done);
313 };
314
315 auto compute_batch = [=](bool is_os_tail) {
316 Label batch_loop;
317 L(batch_loop);
318
319 compute_M(is_os_tail);
320 add(reg_src_base, batch_src_shift);
321 add(reg_tr_src_base, batch_tr_src_shift);
322
323 sub(reg_loop_batch, 1);
324 jnz(batch_loop, T_NEAR);
325 };
326
327 Label K_tail;
328 if (last_os_block_tail > 0) {
329 cmp(reg_loop_K, os_block);
330 jl(K_tail, T_NEAR);
331 }
332
333 compute_batch(false);
334
335 if (last_os_block_tail > 0) {
336 Label K_done;
337 jmp(K_done, T_NEAR);
338
339 L(K_tail);
340 compute_batch(true);
341 L(K_done);
342 }
343
344 postamble();
345}
346
347struct jit_brgemm_trans_m_k_bf16_t : public jit_brgemm_trans_src_t,
348 public jit_generator {
349 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_bf16_t)
350 jit_brgemm_trans_m_k_bf16_t(const jit_brgemm_primitive_conf_t *conf)
351 : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {}
352
353 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
354 status_t create_kernel() override { return jit_generator::create_kernel(); }
355
356private:
357 using reg64_t = const Xbyak::Reg64;
358 using reg32_t = const Xbyak::Reg32;
359 using opmask_t = const Xbyak::Opmask;
360
361 enum {
362 typesize = sizeof(int16_t),
363 transpose_size = 16,
364 };
365 dim_t src_stride = 0, tr_src_stride = 0;
366
367 opmask_t kFFFF = k1;
368 opmask_t k5555 = k2;
369 opmask_t kAAAA = k3;
370 opmask_t kAA = k4;
371 opmask_t k55 = k5;
372 opmask_t kCC = k6;
373 opmask_t k33 = k7;
374 opmask_t kTail = k1;
375
376 reg32_t regw_tmp = r15d;
377
378 reg64_t reg_k_src = r14;
379 reg64_t reg_k_tr_src = r13;
380
381 reg64_t reg_m_src = r12;
382 reg64_t reg_m_tr_src = r11;
383
384 reg64_t reg_batch_src = r10;
385 reg64_t reg_batch_tr_src = r9;
386
387 reg64_t reg_loop_batch = r8;
388 reg64_t reg_loop_K = rax;
389 reg64_t reg_loop_M = rbx;
390
391 reg64_t reg_tr_src_tmp = abi_not_param1; // lnx -> rcx
392 reg64_t imm_addr64 = rdx;
393
394 Xbyak::Zmm vidx1 = zmm31;
395 Xbyak::Zmm vidx2 = zmm30;
396 Xbyak::Zmm vidx3 = zmm29;
397 Xbyak::Zmm vidx4 = zmm28;
398 Xbyak::Zmm vidx5 = zmm27;
399 Xbyak::Zmm zmm_tmp = zmm26;
400
401 void transpose(
402 reg64_t dst, reg64_t src, int nrows, int ncolumns = transpose_size);
403 void generate() override;
404};
405
406void jit_brgemm_trans_m_k_bf16_t::transpose(
407 reg64_t dst, reg64_t src, int nrows, int ncolumns) {
408 assert(nrows >= 0 && nrows <= transpose_size);
409 static_assert(transpose_size == 16, "Unsupported transpose size");
410 if (!nrows) return;
411
412 auto src_zmm = [=](int i) { return Zmm(i); };
413
414 auto src_ymm = [=](int i) {
415 assert(i >= 0 && i < 16);
416 return Ymm(i);
417 };
418
419 auto kmovw = [=](Opmask k, unsigned w) {
420 mov(regw_tmp, w);
421 jit_generator::kmovw(k, regw_tmp);
422 };
423
424 auto kmovd = [=](Opmask k, unsigned w) {
425 mov(regw_tmp, w);
426 jit_generator::kmovd(k, regw_tmp);
427 };
428
429 auto store = [=](Zmm r, int i) {
430 mov(reg_tr_src_tmp, dst);
431
432 auto k = kTail;
433 auto base = reg_tr_src_tmp;
434 base.setOpmaskIdx(k.getIdx(), true);
435
436 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
437 vmovups(addr, r);
438 };
439
440 const int ic_block = ncolumns;
441 kmovd(kFFFF, ic_block < transpose_size ? (1 << ic_block) - 1 : 0xffff);
442
443 for (int i = 0; i < nrows / 2; i++) {
444 auto zmm_src0 = src_zmm(2 * i);
445 auto zmm_src1 = src_zmm(2 * i + 1);
446 auto src1 = src_ymm(2 * i + 1);
447 vmovdqu16(zmm_src0 | kFFFF | T_z,
448 EVEX_compress_addr(src, 2 * i * src_stride));
449 vmovdqu16(zmm_src1 | kFFFF | T_z,
450 EVEX_compress_addr(src, (2 * i + 1) * src_stride));
451 vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
452 vpermw(zmm_src0, vidx5, zmm_src0);
453 }
454
455 // for odd numbers we need to mix row with zeroes
456 if (nrows % 2) {
457 int i = nrows / 2;
458 auto zmm_src0 = src_zmm(2 * i);
459 vmovdqu16(zmm_src0 | kFFFF | T_z,
460 EVEX_compress_addr(src, 2 * i * src_stride));
461 vpermw(zmm_src0, vidx5, zmm_src0);
462 }
463
464 for (int i = rnd_up(nrows, 2); i < 16; i += 2) {
465 vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
466 }
467
468 // swap 1
469 for (int i = 0; i < 4; i++) {
470 auto zmm0 = src_zmm(4 * i);
471 auto zmm1 = src_zmm(4 * i + 2);
472 auto tmp0 = src_zmm(4 * i + 1);
473 auto tmp1 = src_zmm(4 * i + 3);
474
475 vmovups(tmp0, zmm0);
476 vmovups(tmp1, zmm1);
477
478 vpermps(tmp0 | kAAAA, vidx3, zmm1);
479 vpermps(tmp1 | k5555, vidx3, zmm0);
480 }
481 // swap 2
482 int base_idx;
483 base_idx = 0;
484 for (int i = 0; i < 2; i++) {
485 auto zmm0 = src_zmm(base_idx + 2 * i + 1);
486 auto zmm1 = src_zmm(base_idx + 2 * i + 5);
487
488 auto tmp0 = src_zmm(base_idx + 2 * i);
489 auto tmp1 = src_zmm(base_idx + 2 * i + 4);
490
491 vmovupd(tmp0, zmm0);
492 vmovupd(tmp1, zmm1);
493
494 vpermpd(tmp0 | kAA, vidx2, zmm1);
495 vpermpd(tmp1 | k55, vidx2, zmm0);
496 }
497 base_idx = 8;
498 for (int i = 0; i < 2; i++) {
499 auto zmm0 = src_zmm(base_idx + 2 * i + 1);
500 auto zmm1 = src_zmm(base_idx + 2 * i + 5);
501
502 auto tmp0 = src_zmm(base_idx + 2 * i);
503 auto tmp1 = src_zmm(base_idx + 2 * i + 4);
504
505 vmovupd(tmp0, zmm0);
506 vmovupd(tmp1, zmm1);
507
508 vpermpd(tmp0 | kAA, vidx2, zmm1);
509 vpermpd(tmp1 | k55, vidx2, zmm0);
510 }
511
512 // swap 3
513 for (int i = 0; i < 4; i++) {
514 auto zmm0 = src_zmm(2 * i);
515 auto zmm1 = src_zmm(2 * i + 8);
516
517 auto tmp0 = src_zmm(2 * i + 1);
518 auto tmp1 = src_zmm(2 * i + 9);
519
520 vmovupd(tmp0, zmm0);
521 vmovupd(tmp1, zmm1);
522
523 vpermpd(tmp0 | kCC, vidx1, zmm1);
524 vpermpd(tmp1 | k33, vidx1, zmm0);
525 }
526
527 // all stores
528 for (int i = 0; i < 8; i++)
529 vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);
530
531 auto get_vec_idx = [=](int ic_idx) {
532 assert(ic_idx < 16 && ic_idx >= 0);
533 switch (ic_idx) {
534 case 0: return 1;
535 case 1: return 0;
536 case 2: return 3;
537 case 3: return 2;
538 case 4: return 9;
539 case 5: return 8;
540 case 6: return 11;
541 case 7: return 10;
542 case 8: return 5;
543 case 9: return 4;
544 case 10: return 7;
545 case 11: return 6;
546 case 12: return 13;
547 case 13: return 12;
548 case 14: return 15;
549 default: return 14;
550 }
551 };
552
553 int store_tail = rnd_up(nrows, 2);
554 kmovw(kTail, (1 << store_tail / 2) - 1);
555
556 for (int ic = 0; ic < ic_block; ic++)
557 store(src_zmm(get_vec_idx(ic)), ic);
558}
559
560void jit_brgemm_trans_m_k_bf16_t::generate() {
561 preamble();
562
563 alignas(64) static constexpr const int64_t idx1[8]
564 = {2, 3, 0, 1, 6, 7, 4, 5};
565 alignas(64) static constexpr const int64_t idx2[8]
566 = {1, 0, 3, 2, 5, 4, 7, 6};
567 alignas(64) static constexpr const int32_t idx3[16]
568 = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
569 alignas(64) static constexpr const int32_t idx4[16]
570 = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
571 alignas(64) static constexpr const uint16_t idx5[32]
572 = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
573 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};
574
575 constexpr int amx_xf16_granularity = 2;
576 const bool last_row_padded = is_superset(conf_->isa, avx512_core_amx)
577 && conf_->os % amx_xf16_granularity != 0;
578 const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);
579
580 const int os_block = conf_->os_block;
581 const int last_os_block_tail = eff_K_tail % transpose_size;
582 const int ic_tail = conf_->M_tail % transpose_size;
583 src_stride = conf_->ic * typesize;
584 tr_src_stride = conf_->LDA * typesize;
585
586 const dim_t batch_src_shift = src_stride * os_block;
587 const dim_t batch_tr_src_shift = tr_src_stride * conf_->M;
588
589 const dim_t M_src_shift = transpose_size * typesize;
590 const dim_t M_tr_src_shift = transpose_size * conf_->LDA * typesize;
591
592 const dim_t K_src_shift = transpose_size * conf_->ic * typesize;
593 const dim_t K_tr_src_shift = transpose_size * typesize;
594
595 auto kmovw = [=](Opmask k, unsigned w) {
596 mov(regw_tmp, w);
597 jit_generator::kmovw(k, regw_tmp);
598 };
599
600 kmovw(kFFFF, 0xffff);
601 kmovw(k5555, 0x5555);
602 kmovw(kAAAA, 0xaaaa);
603 kmovw(kAA, 0xaa);
604 kmovw(k55, 0x55);
605 kmovw(kCC, 0xcc);
606 kmovw(k33, 0x33);
607
608 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
609 mov(imm_addr64, reinterpret_cast<size_t>(addr));
610 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
611 };
612
613 auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
614 mov(imm_addr64, reinterpret_cast<size_t>(addr));
615 jit_generator::vmovdqa32(z, ptr[imm_addr64]);
616 };
617
618 vmovdqa64(vidx1, idx1);
619 vmovdqa64(vidx2, idx2);
620 vmovdqa32(vidx3, idx3);
621 vmovdqa32(vidx4, idx4);
622 vmovdqa32(vidx5, (const int32_t *)idx5);
623
624 auto compute_m_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
625 bool is_os_tail) {
626 mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
627 mov(reg_m_src, reg_base);
628 mov(reg_m_tr_src, reg_tr_base);
629
630 Label M_loop_tail, M_loop;
631 if (ic_tail > 0) {
632 cmp(reg_loop_M, transpose_size);
633 jl(M_loop_tail, T_NEAR);
634 }
635 L(M_loop);
636 {
637 transpose(reg_m_tr_src, reg_m_src,
638 is_os_tail ? last_os_block_tail : transpose_size,
639 transpose_size);
640 add(reg_m_src, M_src_shift);
641 add(reg_m_tr_src, M_tr_src_shift);
642 }
643 sub(reg_loop_M, transpose_size);
644 cmp(reg_loop_M, transpose_size);
645 jge(M_loop, T_NEAR);
646
647 if (ic_tail > 0) {
648 Label M_loop_done;
649 L(M_loop_tail);
650 cmp(reg_loop_M, 0);
651 jle(M_loop_done, T_NEAR);
652
653 transpose(reg_m_tr_src, reg_m_src,
654 is_os_tail ? last_os_block_tail : transpose_size, ic_tail);
655 L(M_loop_done);
656 }
657 };
658
659 auto compute_k_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
660 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
661 mov(reg_k_src, reg_base);
662 mov(reg_k_tr_src, reg_tr_base);
663
664 Label K_tail, K_loop, K_done;
665 if (last_os_block_tail > 0) {
666 cmp(reg_loop_K, transpose_size);
667 jl(K_tail, T_NEAR);
668 }
669 L(K_loop);
670 {
671 compute_m_loop(reg_k_src, reg_k_tr_src, false);
672 add(reg_k_src, K_src_shift);
673 add(reg_k_tr_src, K_tr_src_shift);
674 }
675 sub(reg_loop_K, transpose_size);
676 cmp(reg_loop_K, transpose_size);
677 jge(K_loop, T_NEAR);
678
679 cmp(reg_loop_K, 0);
680 je(K_done, T_NEAR);
681
682 if (last_os_block_tail > 0) {
683 L(K_tail);
684 compute_m_loop(reg_k_src, reg_k_tr_src, true);
685 }
686 L(K_done);
687 };
688
689 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
690 mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
691 mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);
692
693 Label batch_loop;
694 L(batch_loop);
695 {
696 compute_k_loop(reg_batch_src, reg_batch_tr_src);
697
698 add(reg_batch_src, batch_src_shift);
699 add(reg_batch_tr_src, batch_tr_src_shift);
700 }
701 sub(reg_loop_batch, 1);
702 jnz(batch_loop, T_NEAR);
703
704 postamble();
705}
706
707struct jit_brgemm_trans_m_k_f16_t : public jit_brgemm_trans_src_t,
708 public jit_generator {
709 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f16_t)
710
711 jit_brgemm_trans_m_k_f16_t(const jit_brgemm_primitive_conf_t *conf)
712 : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {}
713
714 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
715 status_t create_kernel() override { return jit_generator::create_kernel(); }
716
717private:
718 using reg64_t = const Xbyak::Reg64;
719 using reg32_t = const Xbyak::Reg32;
720 using opmask_t = const Xbyak::Opmask;
721
722 enum {
723 typesize_in = sizeof(float16_t),
724 typesize_out = sizeof(float),
725 transpose_size = 16
726 };
727 dim_t src_stride = 0, tr_src_stride = 0;
728
729 opmask_t k3333 = k1;
730 opmask_t k5555 = k2;
731 opmask_t kAAAA = k3;
732 opmask_t kCCCC = k4;
733 opmask_t k0F0F = k5;
734 opmask_t kF0F0 = k6;
735 opmask_t kTail = k7;
736
737 reg64_t reg_src_base = rax;
738 reg64_t reg_tr_src_base = rbx;
739
740 reg64_t reg_src = r8;
741 reg64_t reg_tr_src = r9;
742 reg64_t reg_loop_K = r10;
743 reg64_t reg_loop_M = r11;
744 reg64_t reg_loop_batch = r12;
745 reg64_t reg_tr_src_tmp = r13;
746 reg32_t regw_tmp = r14d;
747
748 void transpose_16x16(int nrows, int ncolumns = transpose_size);
749 void generate() override;
750};
751
752void jit_brgemm_trans_m_k_f16_t::transpose_16x16(int nrows, int ncolumns) {
753 assert(nrows >= 0 && nrows <= transpose_size);
754 static_assert(transpose_size == 16, "Unsupported transpose size");
755 if (!nrows) return;
756
757 auto src_zmm = [=](int i) {
758 assert(i >= 0 && i < 16);
759 return Zmm(i);
760 };
761
762 auto tmp_zmm = [=](int i) {
763 assert(i >= 0 && i < 16);
764 return Zmm(16 + i);
765 };
766
767 auto kmovw = [=](Opmask k, unsigned w) {
768 mov(regw_tmp, w);
769 jit_generator::kmovw(k, regw_tmp);
770 };
771
772 auto load = [=](int i) {
773 auto src_load = src_zmm(i);
774 if (i >= nrows) {
775 vpxord(src_load, src_load, src_load);
776 return;
777 }
778
779 if (ncolumns < transpose_size) {
780 kmovw(kTail, (1 << ncolumns) - 1);
781 src_load = src_zmm(i) | kTail | T_z;
782 }
783 vcvtph2psx(src_load, EVEX_compress_addr(reg_src, i * src_stride));
784 };
785
786 auto store = [=](Zmm r, int i) {
787 mov(reg_tr_src_tmp, reg_tr_src);
788 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
789
790 // Xbyak does not allow k0 to be specified explicitly via the '|'
791 // operator, so we have to do this via a method call (implicitly
792 // EVEX encoding uses k0 to mean 'no mask')
793 const bool partial_store = nrows < transpose_size;
794 const auto k = partial_store ? kTail : k0;
795 auto base = reg_tr_src_tmp;
796 base.setOpmaskIdx(k.getIdx(), true);
797
798 const auto addr = EVEX_compress_addr(base, i * tr_src_stride);
799 vmovups(addr, r);
800 };
801
802 auto transpose16x8 = [=](int base_idx) {
803 assert(base_idx == 0 || base_idx == 8);
804
805 // swap 1
806 for (int i = 0; i < 4; i++) {
807 const int src_idx0 = base_idx + i * 2;
808 const int src_idx1 = src_idx0 + 1;
809
810 const int next_src_idx0 = src_idx0 + 2;
811 const int next_src_idx1 = src_idx1 + 2;
812 const bool load_next = base_idx == 0 || i < 3;
813
814 if (base_idx == 0 && i == 0) {
815 load(src_idx0);
816 if (src_idx1 < nrows)
817 load(src_idx1);
818 else
819 vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
820 src_zmm(src_idx1));
821 }
822
823 const auto tmp0 = tmp_zmm(src_idx0);
824 const auto tmp1 = tmp_zmm(src_idx1);
825 const auto src0 = src_zmm(src_idx0);
826 const auto src1 = src_zmm(src_idx1);
827
828 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
829 valignd(tmp0, src0, src0, 0x1);
830
831 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
832 valignd(tmp1, src1, src1, 0xf);
833
834 vmovaps(src0 | kAAAA, tmp1);
835 vmovaps(src1 | k5555, tmp0);
836 }
837 // swap 2
838 for (int i = 0; i < 4; i++) {
839 const int select_half = (i < 2) ? 0 : 2;
840 const int src_idx0 = base_idx + i + select_half + 0;
841 const int src_idx2 = src_idx0 + 2;
842
843 const auto tmp0 = tmp_zmm(src_idx0);
844 const auto tmp1 = tmp_zmm(src_idx2);
845 const auto src0 = src_zmm(src_idx0);
846 const auto src2 = src_zmm(src_idx2);
847
848 valignd(tmp0, src0, src0, 0x2);
849 valignd(tmp1, src2, src2, 0xe);
850 vmovaps(src2 | k3333, tmp0);
851 vmovaps(src0 | kCCCC, tmp1);
852 }
853
854 // swap 4
855 for (int i = 0; i < 4; i++) {
856 const int src_idx0 = base_idx + i;
857 const int src_idx4 = src_idx0 + 4;
858
859 const auto tmp0 = tmp_zmm(src_idx0);
860 const auto src0 = src_zmm(src_idx0);
861 const auto src4 = src_zmm(src_idx4);
862
863 vmovaps(tmp0, src0);
864 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
865 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
866 }
867 };
868
869 auto fixup16x16 = [=]() {
870 // swap 8
871 const auto max_iters_phase_1 = std::min(ncolumns, 8);
872 for (int i = 0; i < max_iters_phase_1; i++) {
873 const auto tmp = tmp_zmm(i);
874 const auto src0 = src_zmm(i);
875 const auto src8 = src_zmm(8 + i);
876 vshuff64x2(tmp, src0, src8, 0x44);
877 store(tmp, i);
878 }
879
880 const auto max_iters_phase_2 = std::min(ncolumns - 8, 8);
881 for (int i = 0; i < max_iters_phase_2; i++) {
882 const auto tmp = tmp_zmm(8 + i);
883 const auto src0 = src_zmm(i);
884 const auto src8 = src_zmm(8 + i);
885 vshuff64x2(tmp, src0, src8, 0xee);
886 store(tmp, 8 + i);
887 }
888 };
889
890 transpose16x8(0);
891 transpose16x8(8);
892 fixup16x16();
893}
894
895void jit_brgemm_trans_m_k_f16_t::generate() {
896 preamble();
897 assert(conf_->ic_block % transpose_size == 0);
898 const int os_block = conf_->os_block;
899 const int last_os_block_tail = conf_->K_tail % transpose_size;
900 const int ic_tail = conf_->M_tail % transpose_size;
901 src_stride = conf_->ic * typesize_in;
902 tr_src_stride = conf_->LDA * typesize_out;
903 const dim_t m_src_shift = transpose_size * typesize_in;
904 const dim_t m_tr_src_shift = tr_src_stride * transpose_size;
905
906 const dim_t batch_src_shift = src_stride * os_block;
907 const dim_t batch_tr_src_shift = tr_src_stride * conf_->M;
908
909 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
910 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
911 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
912 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
913
914 auto kmovw = [=](Opmask k, unsigned w) {
915 mov(regw_tmp, w);
916 jit_generator::kmovw(k, regw_tmp);
917 };
918
919 kmovw(k3333, 0x3333); // 0011001100110011
920 kmovw(k5555, 0x5555); // 0101010101010101
921 kmovw(kAAAA, 0xaaaa); // 1010101010101010
922 kmovw(kCCCC, 0xcccc); // 1100110011001100
923 kmovw(k0F0F, 0x0f0f); // 0000111100001111
924 kmovw(kF0F0, 0xf0f0); // 1111000011110000
925
926 auto compute_M = [=](bool is_os_tail) {
927 const auto nrows = is_os_tail ? last_os_block_tail : transpose_size;
928 mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]);
929 mov(reg_src, reg_src_base);
930 mov(reg_tr_src, reg_tr_src_base);
931 Label M_loop, M_tail_or_done, M_done;
932 if (ic_tail > 0) {
933 cmp(reg_loop_M, transpose_size);
934 jl(M_tail_or_done, T_NEAR);
935 }
936
937 L(M_loop);
938 transpose_16x16(nrows, transpose_size);
939 if (conf_->ic_block > transpose_size) {
940 add(reg_src, m_src_shift);
941 add(reg_tr_src, m_tr_src_shift);
942 sub(reg_loop_M, transpose_size);
943 cmp(reg_loop_M, transpose_size);
944 jge(M_loop, T_NEAR);
945 } else {
946 jmp(M_done, T_NEAR);
947 }
948
949 L(M_tail_or_done);
950 if (ic_tail > 0) {
951 cmp(reg_loop_M, 0);
952 jle(M_done, T_NEAR);
953
954 transpose_16x16(nrows, ic_tail);
955 }
956 L(M_done);
957 };
958
959 auto compute_batch = [=](bool is_os_tail) {
960 Label batch_loop;
961 L(batch_loop);
962
963 compute_M(is_os_tail);
964 add(reg_src_base, batch_src_shift);
965 add(reg_tr_src_base, batch_tr_src_shift);
966
967 sub(reg_loop_batch, 1);
968 jnz(batch_loop, T_NEAR);
969 };
970
971 Label K_tail;
972 if (last_os_block_tail > 0) {
973 cmp(reg_loop_K, transpose_size);
974 jl(K_tail, T_NEAR);
975 }
976
977 compute_batch(false);
978
979 if (last_os_block_tail > 0) {
980 Label K_done;
981 jmp(K_done, T_NEAR);
982
983 L(K_tail);
984 compute_batch(true);
985 L(K_done);
986 }
987
988 postamble();
989}
990
991void jit_brgemm_copy_to_coarse_t::copy_row_blks(int num_row_blks) {
992 int rnd_row_blks = div_up(num_row_blks, row_loop_unroll);
993
994 for (int row_b = 0; row_b < rnd_row_blks; ++row_b) {
995 const int row_start = 0;
996 const int row_end = nstl::min(static_cast<int>(row_loop_unroll),
997 num_row_blks - row_b * static_cast<int>(row_loop_unroll));
998
999 for (int row = row_start; row < row_end; ++row) {
1000 const int row_idx = row_b * row_loop_unroll + row;
1001 const auto offset = addr_offset(row_idx);
1002
1003 const auto zmm = get_zmm_copy(row);
1004 const auto addr = EVEX_compress_addr(reg_data, offset);
1005 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
1006
1007 vmovdqu8(zmm, addr);
1008 vmovdqu8(addr_tr, zmm);
1009 }
1010 }
1011}
1012
1013void jit_brgemm_copy_to_coarse_t::copy_row_tail(
1014 bool is_last_iteration, int row_offset) {
1015 // Masks for row tail load and store are already set up
1016 const auto load_mask = is_last_iteration ? reg_m_last_row_tail_load
1017 : reg_m_full_row_tail_load;
1018 const auto store_mask = is_last_iteration ? reg_m_last_row_tail_store
1019 : reg_m_full_row_tail_store;
1020
1021 const auto zmm_data = zmm_row_tail | load_mask | T_z;
1022 const auto zmm_tr_data = zmm_row_tail | store_mask;
1023
1024 const auto offset = addr_offset(row_offset);
1025 const auto addr = EVEX_compress_addr(reg_data, offset);
1026 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
1027
1028 vmovdqu8(zmm_data, addr);
1029 vmovdqu8(addr_tr, zmm_tr_data);
1030}
1031
1032void jit_brgemm_copy_to_coarse_t::zero_out_rows() {
1033 const int row_blk = row_size_ % tr_row_size_;
1034 const int rnd_up_row_blk = utils::rnd_up(row_blk, row_step_);
1035
1036 int zero_row_blks = tr_row_size_ - rnd_up_row_blk;
1037 if (zero_row_blks == 0) return;
1038
1039 const auto zmm_step = row_step_, ymm_step = row_step_ / 2,
1040 xmm_step = row_step_ / 4;
1041 assert(zero_row_blks % xmm_step == 0);
1042 MAYBE_UNUSED(xmm_step);
1043
1044 int zmm_iters = zero_row_blks / zmm_step;
1045 zero_row_blks %= zmm_step;
1046 int ymm_iters = zero_row_blks / ymm_step;
1047 zero_row_blks %= ymm_step;
1048 int xmm_iters = zero_row_blks / xmm_step;
1049
1050 auto offset = addr_offset(rnd_up_row_blk / row_step_);
1051
1052 for (int row = 0; row < zmm_iters; ++row) {
1053 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
1054 vmovdqu8(addr_tr, zmm_zero);
1055 offset += (zmm_step * typesize_);
1056 }
1057
1058 const auto ymm_zero = Xbyak::Ymm(zmm_zero.getIdx());
1059 const auto xmm_zero = Xbyak::Xmm(zmm_zero.getIdx());
1060
1061 assert(xmm_iters <= 1 && ymm_iters <= 1);
1062 if (ymm_iters > 0) {
1063 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
1064 vmovdqu8(addr_tr, ymm_zero);
1065 offset += (ymm_step * typesize_);
1066 }
1067
1068 if (xmm_iters > 0) {
1069 const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset);
1070 vmovdqu8(addr_tr, xmm_zero);
1071 }
1072}
1073
1074void jit_brgemm_copy_to_coarse_t::copy_row_loop() {
1075 Xbyak::Label label_row_tail, label_row_exit;
1076
1077 // Note: copying is done in chunks of size row_step_
1078 const auto copy_row = [&](bool is_last_iteration) {
1079 const int row_blk
1080 = is_last_iteration ? (row_size_ % tr_row_size_) : tr_row_size_;
1081 const int row_iters = row_blk / row_step_;
1082 const int row_iters_tail = row_blk % row_step_;
1083
1084 copy_row_blks(row_iters);
1085 if (row_iters_tail != 0)
1086 copy_row_tail(is_last_iteration, /* row_offset = */ row_iters);
1087
1088 // For the last iteration, zero-out rows if needed
1089 if (is_last_iteration) zero_out_rows();
1090 };
1091
1092 const bool only_row_tail = row_size_ < tr_row_size_;
1093
1094 if (!only_row_tail) {
1095 cmp(reg_last_row_blk, 0);
1096 jne(label_row_tail, T_NEAR);
1097
1098 copy_row(/* is_last_iteration = */ false);
1099 jmp(label_row_exit, T_NEAR);
1100 }
1101
1102 L(label_row_tail);
1103 copy_row(/* is_last_iteration = */ true);
1104
1105 L(label_row_exit);
1106}
1107
1108void jit_brgemm_copy_to_coarse_t::copy_os_loop() {
1109
1110 Label loop_os;
1111 L(loop_os);
1112
1113 copy_row_loop();
1114 add(reg_data, data_stride_);
1115 add(reg_tr_data, tr_data_stride_);
1116
1117 dec(reg_os_work);
1118 jnz(loop_os, T_NEAR);
1119}
1120
1121void jit_brgemm_copy_to_coarse_t::set_last_row_tail_masks() {
1122 const int row_tail = (row_size_ % tr_row_size_) % row_step_;
1123 assert(row_tail > 0 && "kernel is meant to be used with tail processing");
1124
1125 // Set load mask
1126 const size_t tail_mask_load
1127 = (static_cast<size_t>(1) << (typesize_ * row_tail)) - 1;
1128 mov(reg_tail_mask, tail_mask_load);
1129 kmovq(reg_m_last_row_tail_load, reg_tail_mask);
1130
1131 // Caution: Since size of ZMM equals 64 bytes therefore we need
1132 // different masks to store tails with smaller row_block_size_
1133 constexpr auto full_mask = size_t {0xffffffffffffffff};
1134 constexpr auto half_mask = size_t {0x00000000ffffffff};
1135 constexpr auto quad_mask = size_t {0x000000000000ffff};
1136
1137 const auto num_bytes = [](size_t mask) -> int {
1138 // Given by 1 + position of leftmost 1 bit
1139 return 1 + math::ilog2q(mask);
1140 };
1141
1142 const int row_tail_store_size
1143 = utils::rnd_up(row_tail, row_block_size_) * typesize_;
1144 if (row_tail_store_size >= num_bytes(full_mask))
1145 mov(reg_tail_mask, full_mask);
1146 else if (row_tail_store_size >= num_bytes(half_mask))
1147 mov(reg_tail_mask, half_mask);
1148 else {
1149 assert(row_tail_store_size == num_bytes(quad_mask));
1150 mov(reg_tail_mask, quad_mask);
1151 }
1152 kmovq(reg_m_last_row_tail_store, reg_tail_mask);
1153}
1154
1155void jit_brgemm_copy_to_coarse_t::set_full_row_tail_masks() {
1156 const auto full_row_tail = tr_row_size_ % row_step_;
1157 assert(row_step_ == 2 * full_row_tail || row_step_ == 4 * full_row_tail);
1158
1159 const auto tail_mask = row_step_ == 2 * full_row_tail
1160 ? size_t {0x00000000ffffffff}
1161 : size_t {0x000000000000ffff};
1162
1163 mov(reg_tail_mask, tail_mask);
1164 kmovq(reg_m_full_row_tail_store, reg_tail_mask);
1165 kmovq(reg_m_full_row_tail_load, reg_tail_mask);
1166}
1167
1168void jit_brgemm_copy_to_coarse_t::generate() {
1169 preamble();
1170
1171 // set up masks for tail processing
1172 set_last_row_tail_masks();
1173 const bool has_full_row_tail_ = tr_row_size_ % row_step_ != 0;
1174 if (has_full_row_tail_) set_full_row_tail_masks();
1175
1176 // init zero vreg (zmm_zero) if it is needed
1177 const int last_row_size
1178 = utils::rnd_up(row_size_ % tr_row_size_, row_step_);
1179 const bool zero_iters_needed
1180 = last_row_size > 0 && last_row_size < tr_row_size_;
1181 if (zero_iters_needed) vpxord(zmm_zero, zmm_zero, zmm_zero);
1182
1183 // load arguments to the jit kernel
1184 mov(reg_data, ptr[param1 + GET_OFF(data)]);
1185 mov(reg_tr_data, ptr[param1 + GET_OFF(tr_data)]);
1186 mov(reg_os_work, ptr[param1 + GET_OFF(os_work)]);
1187 mov(reg_last_row_blk, ptr[param1 + GET_OFF(last_row_blk)]);
1188
1189 // enter the `main` loop
1190 copy_os_loop();
1191
1192 postamble();
1193}
1194
1195struct jit_trans_to_vnni_t : public jit_brgemm_trans_to_vnni_t,
1196 public jit_generator {
1197 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_to_vnni_t)
1198 jit_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf,
1199 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
1200 matrix_to_transform)
1201 : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform)
1202 , jit_generator(jit_name()) {}
1203
1204 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1205 status_t create_kernel() override { return jit_generator::create_kernel(); }
1206
1207private:
1208 using reg64_t = const Xbyak::Reg64;
1209 using reg32_t = const Xbyak::Reg32;
1210 using opmask_t = const Xbyak::Opmask;
1211 using zmm = const Xbyak::Zmm;
1212
1213 enum {
1214 typesize_data = sizeof(int16_t),
1215 typesize_acc = sizeof(float),
1216 transpose_size = 16,
1217 };
1218
1219 int last_row_block_tail = 0, col_tail = 0;
1220 dim_t src_stride = 0, tr_src_stride = 0;
1221 dim_t src_col_shift = 0, tr_src_col_shift = 0;
1222 dim_t src_row_shift = 0, tr_src_row_shift = 0;
1223 dim_t src_batch_shift = 0, tr_src_batch_shift = 0;
1224
1225 opmask_t kFFFF = k1;
1226 opmask_t mask_tail = k2;
1227
1228 zmm vidx1 = zmm31;
1229
1230 reg32_t regw_tmp = r15d;
1231
1232 reg64_t reg_batch_src = r14;
1233 reg64_t reg_batch_tr_src = r13;
1234
1235 reg64_t reg_row_src = r12;
1236 reg64_t reg_row_tr_src = r11;
1237
1238 reg64_t reg_col_src = r10;
1239 reg64_t reg_col_tr_src = r9;
1240
1241 reg64_t reg_loop_batch = r8;
1242 reg64_t reg_loop_row = rax;
1243 reg64_t reg_loop_col = rbx;
1244
1245 reg64_t imm_addr64 = abi_not_param1; // lnx -> rcx
1246
1247 void maybe_zero_pad_col(reg64_t dst);
1248 void transpose(reg64_t dst, reg64_t src, int nrows,
1249 int ncolumns = transpose_size, bool pad_by_zeroes = false);
1250 void generate() override;
1251};
1252
1253void jit_trans_to_vnni_t::maybe_zero_pad_col(reg64_t dst) {
1254 auto zmm_zero = Xbyak::Zmm(0);
1255 vpxord(zmm_zero, zmm_zero, zmm_zero);
1256 const int oc_utilized = rnd_up(conf_->oc % conf_->oc_block, transpose_size);
1257 const int iters = (conf_->oc_block - oc_utilized) / transpose_size;
1258 for (int n = 0; n < iters; ++n) {
1259 for (int i = 0; i < transpose_size; i += 2) {
1260 auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
1261 vmovups(addr, zmm_zero);
1262 }
1263 add(reg_col_tr_src, tr_src_col_shift);
1264 }
1265}
1266
1267void jit_trans_to_vnni_t::transpose(
1268 reg64_t dst, reg64_t src, int nrows, int ncolumns, bool pad_by_zeroes) {
1269 assert(nrows >= 0 && nrows <= transpose_size);
1270 static_assert(transpose_size == 16, "Unsupported transpose size");
1271 if (!nrows) return;
1272
1273 auto src_zmm = [=](int i) { return Zmm(i); };
1274
1275 auto src_ymm = [=](int i) {
1276 assert(i >= 0 && i < 16);
1277 return Ymm(i);
1278 };
1279
1280 auto store = [=](Zmm r, int i) {
1281 auto addr = EVEX_compress_addr(dst, i * tr_src_stride);
1282 vmovups(addr, r);
1283 };
1284 auto mask = ncolumns == transpose_size ? kFFFF : mask_tail;
1285
1286 int i = 0;
1287 for (; i < nrows / 2; i++) {
1288 auto src1 = src_ymm(2 * i + 1);
1289 auto zmm_src0 = src_zmm(2 * i);
1290 auto zmm_src1 = src_zmm(2 * i + 1);
1291 if (matrix_to_transform_ == matrix_B) {
1292 vmovdqu16(zmm_src0 | mask | T_z,
1293 EVEX_compress_addr(src, 2 * i * src_stride));
1294 vmovdqu16(zmm_src1 | mask | T_z,
1295 EVEX_compress_addr(src, (2 * i + 1) * src_stride));
1296 vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
1297 } else {
1298 vmovups(zmm_src0 | mask | T_z,
1299 EVEX_compress_addr(src, 2 * i * src_stride));
1300 vmovups(zmm_src1 | mask | T_z,
1301 EVEX_compress_addr(src, (2 * i + 1) * src_stride));
1302 vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0);
1303 }
1304 vpermw(zmm_src0, vidx1, zmm_src0);
1305 store(zmm_src0, 2 * i);
1306 }
1307
1308 if (nrows % 2) {
1309 auto zmm_src0 = src_zmm(2 * i);
1310 if (matrix_to_transform_ == matrix_B) {
1311 vmovdqu16(zmm_src0 | mask | T_z,
1312 EVEX_compress_addr(src, 2 * i * src_stride));
1313 } else {
1314 auto zmm_zero = src_zmm(2 * i + 1);
1315 vmovups(zmm_src0 | mask | T_z,
1316 EVEX_compress_addr(src, 2 * i * src_stride));
1317 vpxord(zmm_zero, zmm_zero, zmm_zero);
1318 vcvtne2ps2bf16(zmm_src0, zmm_zero, zmm_src0);
1319 }
1320 vpermw(zmm_src0, vidx1, zmm_src0);
1321 store(zmm_src0, 2 * i);
1322 i++;
1323 }
1324
1325 if (pad_by_zeroes && i < transpose_size / 2) {
1326 auto zmm_zero = src_zmm(2 * i);
1327 vpxord(zmm_zero, zmm_zero, zmm_zero);
1328 for (; i < transpose_size / 2; i++)
1329 store(zmm_zero, 2 * i);
1330 }
1331}
1332
1333void jit_trans_to_vnni_t::generate() {
1334 preamble();
1335
1336 alignas(64) static constexpr const int16_t idx1[32]
1337 = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
1338 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
1339
1340 if (matrix_to_transform_ == matrix_B) {
1341 int row_block = conf_->os_block;
1342
1343 constexpr int amx_xf16_granularity = 2;
1344 const bool last_row_padded = is_superset(conf_->isa, avx512_core_amx)
1345 && conf_->os % amx_xf16_granularity != 0;
1346 const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0);
1347
1348 last_row_block_tail = eff_K_tail % transpose_size;
1349 col_tail = conf_->oc % transpose_size;
1350 src_stride = conf_->oc * typesize_data;
1351 tr_src_stride = conf_->LDB * typesize_data;
1352
1353 src_batch_shift = src_stride * row_block;
1354 tr_src_batch_shift = tr_src_stride * rnd_up(conf_->K, 2);
1355
1356 src_col_shift = transpose_size * typesize_data;
1357 tr_src_col_shift = 2 * transpose_size * typesize_data;
1358
1359 src_row_shift = transpose_size * conf_->oc * typesize_data;
1360 tr_src_row_shift = transpose_size * conf_->LDB * typesize_data;
1361
1362 } else { // matrix_to_transform_ == matrix_C
1363 int row_block = conf_->ic_block;
1364 last_row_block_tail = conf_->M_tail % transpose_size;
1365 assert(row_block == transpose_size);
1366 col_tail = conf_->oc % transpose_size;
1367 src_stride = conf_->LDC * typesize_acc;
1368 tr_src_stride = conf_->LDD * typesize_data;
1369
1370 src_batch_shift = src_stride * row_block;
1371 tr_src_batch_shift = tr_src_stride * rnd_up(conf_->M, 2);
1372
1373 src_col_shift = transpose_size * typesize_acc;
1374 tr_src_col_shift = 2 * transpose_size * typesize_data;
1375 }
1376
1377 // mov(reg_src, ptr[param1 + GET_OFF(src)]);
1378 // mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1379 // mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
1380
1381 auto kmovw = [=](Opmask k, unsigned w) {
1382 mov(regw_tmp, w);
1383 jit_generator::kmovw(k, regw_tmp);
1384 };
1385 auto kmovd = [=](Opmask k, unsigned w) {
1386 mov(regw_tmp, w);
1387 jit_generator::kmovd(k, regw_tmp);
1388 };
1389
1390 kmovw(kFFFF, 0xffff); // 1111111111111111
1391 kmovd(mask_tail, (1 << col_tail) - 1);
1392
1393 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
1394 mov(imm_addr64, reinterpret_cast<size_t>(addr));
1395 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
1396 };
1397
1398 vmovdqa64(vidx1, (const int64_t *)idx1);
1399
1400 auto compute_col_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base,
1401 bool is_row_tail) {
1402 const bool pad_by_zeroes = matrix_to_transform_ == matrix_C;
1403 int nrows = is_row_tail ? last_row_block_tail : transpose_size;
1404
1405 mov(reg_col_src, reg_base);
1406 mov(reg_col_tr_src, reg_tr_base);
1407 mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
1408
1409 Label col_loop, col_loop_tail;
1410 cmp(reg_loop_col, transpose_size);
1411 jl(col_loop_tail, T_NEAR);
1412
1413 L(col_loop);
1414 {
1415 transpose(reg_col_tr_src, reg_col_src, nrows, transpose_size,
1416 pad_by_zeroes);
1417 add(reg_col_src, src_col_shift);
1418 add(reg_col_tr_src, tr_src_col_shift);
1419 }
1420 sub(reg_loop_col, transpose_size);
1421 cmp(reg_loop_col, transpose_size);
1422 jge(col_loop, T_NEAR);
1423
1424 L(col_loop_tail);
1425 if (col_tail > 0) {
1426 Label col_loop_done;
1427 cmp(reg_loop_col, 0);
1428 jle(col_loop_done, T_NEAR);
1429 transpose(reg_col_tr_src, reg_col_src, nrows, col_tail,
1430 pad_by_zeroes);
1431 L(col_loop_done);
1432 }
1433 const int oc_block_tail = conf_->oc % conf_->oc_block;
1434 const bool full_oc_block_utilized = oc_block_tail == 0
1435 || rnd_up(oc_block_tail, transpose_size) == conf_->oc_block;
1436 const bool col_pad_required = pad_by_zeroes && !full_oc_block_utilized;
1437
1438 if (col_pad_required) {
1439 Label col_pad_done;
1440 mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
1441 cmp(reg_loop_col, conf_->oc_block);
1442 je(col_pad_done, T_NEAR);
1443 if (col_tail > 0) add(reg_col_tr_src, tr_src_col_shift);
1444 maybe_zero_pad_col(reg_col_tr_src);
1445 L(col_pad_done);
1446 }
1447 };
1448
1449 auto compute_row_loop = [&](reg64_t &reg_base, reg64_t &reg_tr_base) {
1450 mov(reg_row_src, reg_base);
1451 mov(reg_row_tr_src, reg_tr_base);
1452 mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
1453
1454 Label row_tail, row_loop, row_done;
1455 if (last_row_block_tail > 0) {
1456 cmp(reg_loop_row, transpose_size);
1457 jl(row_tail, T_NEAR);
1458 }
1459 L(row_loop);
1460 {
1461 compute_col_loop(reg_row_src, reg_row_tr_src, false);
1462
1463 add(reg_row_src, src_row_shift);
1464 add(reg_row_tr_src, tr_src_row_shift);
1465 }
1466 sub(reg_loop_row, transpose_size);
1467 cmp(reg_loop_row, transpose_size);
1468 jge(row_loop, T_NEAR);
1469
1470 cmp(reg_loop_row, 0);
1471 je(row_done, T_NEAR);
1472
1473 if (last_row_block_tail > 0) {
1474 L(row_tail);
1475 compute_col_loop(reg_row_src, reg_row_tr_src, true);
1476 }
1477 L(row_done);
1478 };
1479
1480 mov(reg_batch_src, ptr[param1 + GET_OFF(src)]);
1481 mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1482 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1483
1484 Label batch_loop;
1485 L(batch_loop);
1486 {
1487 compute_row_loop(reg_batch_src, reg_batch_tr_src);
1488
1489 add(reg_batch_src, src_batch_shift);
1490 add(reg_batch_tr_src, tr_src_batch_shift);
1491 }
1492 sub(reg_loop_batch, 1);
1493 jnz(batch_loop, T_NEAR);
1494
1495 postamble();
1496}
1497
1498struct jit_copy_f32_t : public jit_brgemm_trans_to_vnni_t,
1499 public jit_generator {
1500 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f32_t)
1501 jit_copy_f32_t(const jit_brgemm_primitive_conf_t *conf,
1502 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
1503 matrix_to_transform)
1504 : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform)
1505 , jit_generator(jit_name()) {}
1506
1507 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1508 status_t create_kernel() override { return jit_generator::create_kernel(); }
1509
1510private:
1511 using reg64_t = const Xbyak::Reg64;
1512 using reg32_t = const Xbyak::Reg32;
1513 using opmask_t = const Xbyak::Opmask;
1514 using zmm = const Xbyak::Zmm;
1515
1516 enum {
1517 typesize_data = sizeof(float),
1518 column_step = 16,
1519 num_regs = 32,
1520 };
1521
1522 dim_t src_stride = 0, tr_src_stride = 0;
1523 dim_t src_batch_shift = 0, tr_src_batch_shift = 0;
1524 dim_t col_shift = column_step * typesize_data;
1525
1526 opmask_t mask_tail = k2;
1527
1528 reg64_t reg_src = r8;
1529 reg64_t reg_tr_src = r9;
1530 reg64_t reg_loop_batch = r10;
1531 reg64_t reg_loop_row = r11;
1532 reg64_t reg_loop_col = r12;
1533 reg32_t regw_tmp = r14d;
1534 reg64_t reg_long_offt = r15;
1535
1536 void copy_block(int nrows, int ncolumns);
1537 void generate() override;
1538};
1539
1540void jit_copy_f32_t::copy_block(int nrows, int ncolumns) {
1541
1542 auto kmovd = [=](Opmask k, unsigned w) {
1543 mov(regw_tmp, w);
1544 jit_generator::kmovd(k, regw_tmp);
1545 };
1546
1547 const int nc_tail = ncolumns % column_step;
1548 if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1);
1549
1550 auto get_zmm = [=](int i) { return Zmm(i % num_regs); };
1551
1552 auto load = [=](int r, int cb) {
1553 auto src_reg = get_zmm(r * cb);
1554 const bool is_tail
1555 = nc_tail > 0 && ncolumns - cb * column_step < column_step;
1556 auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg;
1557 const dim_t offset = r * src_stride + cb * col_shift;
1558 auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt);
1559 vmovups(src_load, addr);
1560 };
1561
1562 auto store = [=](int r, int cb) {
1563 auto reg = get_zmm(r * cb);
1564 const dim_t offset = r * tr_src_stride + cb * col_shift;
1565 auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt);
1566 vmovups(addr, reg);
1567 };
1568
1569 for_(int r = 0; r < nrows; r++)
1570 for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) {
1571 load(r, cb);
1572 store(r, cb);
1573 }
1574}
1575
1576void jit_copy_f32_t::generate() {
1577 preamble();
1578
1579 const int row_block = conf_->os_block;
1580 const int row_tail = conf_->os % row_block;
1581 const int col_block = conf_->oc_block * conf_->nb_oc_blocking;
1582 const int col_tail = conf_->oc % col_block;
1583 src_stride = conf_->oc * typesize_data;
1584 tr_src_stride = conf_->LDB * typesize_data;
1585 src_batch_shift = src_stride * row_block;
1586 tr_src_batch_shift = tr_src_stride * row_block;
1587
1588 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1589 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1590 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1591 mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
1592 mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
1593
1594 auto compute_batch = [=](int nrows, int ncolumns) {
1595 Label batch_loop;
1596 L(batch_loop);
1597
1598 copy_block(nrows, ncolumns);
1599 add(reg_src, src_batch_shift);
1600 add(reg_tr_src, tr_src_batch_shift);
1601
1602 sub(reg_loop_batch, 1);
1603 jnz(batch_loop, T_NEAR);
1604 };
1605
1606 auto compute_rows = [=](int ncolumns) {
1607 Label row_done;
1608 if (row_tail > 0) {
1609 Label row_common;
1610 cmp(reg_loop_row, row_block);
1611 je(row_common, T_NEAR);
1612
1613 compute_batch(row_tail, ncolumns);
1614 jmp(row_done, T_NEAR);
1615
1616 L(row_common);
1617 }
1618
1619 compute_batch(row_block, ncolumns);
1620 L(row_done);
1621 };
1622
1623 Label col_done;
1624 if (col_tail > 0) {
1625 Label col_common;
1626 cmp(reg_loop_col, col_block);
1627 je(col_common, T_NEAR);
1628
1629 compute_rows(col_tail);
1630 jmp(col_done, T_NEAR);
1631
1632 L(col_common);
1633 }
1634
1635 compute_rows(col_block);
1636 L(col_done);
1637
1638 postamble();
1639}
1640
1641struct jit_copy_f16_t : public jit_brgemm_trans_to_vnni_t,
1642 public jit_generator {
1643 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f16_t)
1644 jit_copy_f16_t(const jit_brgemm_primitive_conf_t *conf,
1645 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t
1646 matrix_to_transform)
1647 : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform)
1648 , jit_generator(jit_name()) {
1649
1650 // matrix_to_transform_ == matrix_B, copy(f16) -> f32
1651 // matrix_to_transform_ == matrix_C, copy(f32) -> f16 + zero_pad
1652 if (matrix_to_transform_ == matrix_B) {
1653 row_block = conf_->os_block;
1654 row_tail = conf_->os % row_block;
1655 col_block = conf_->oc_block * conf_->nb_oc_blocking;
1656 col_tail = conf_->oc % col_block;
1657 typesize_in = types::data_type_size(data_type::f16);
1658 typesize_out = types::data_type_size(data_type::f32);
1659 src_stride = conf_->oc * typesize_in;
1660 tr_src_stride = conf_->LDB * typesize_out;
1661 src_batch_shift = src_stride * row_block;
1662 tr_src_batch_shift = tr_src_stride * row_block;
1663 } else { // matrix_C
1664 row_block = conf_->os_block;
1665 row_tail = conf_->os % row_block;
1666 col_block = conf_->oc_block * conf_->nb_oc_blocking;
1667 col_tail = conf_->oc % col_block;
1668 typesize_in = types::data_type_size(data_type::f32);
1669 typesize_out = types::data_type_size(data_type::f16);
1670 src_stride = conf_->LDB * typesize_in;
1671 tr_src_stride = conf_->LDB * typesize_out;
1672 src_batch_shift = src_stride * row_block;
1673 tr_src_batch_shift = tr_src_stride * row_block;
1674 }
1675
1676 col_shift_in = column_step * typesize_in;
1677 col_shift_out = column_step * typesize_out;
1678 }
1679
1680 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1681 status_t create_kernel() override { return jit_generator::create_kernel(); }
1682
1683private:
1684 using reg64_t = const Xbyak::Reg64;
1685 using reg32_t = const Xbyak::Reg32;
1686 using opmask_t = const Xbyak::Opmask;
1687 using zmm = const Xbyak::Zmm;
1688
1689 enum {
1690 column_step = 16,
1691 num_regs = 32,
1692 };
1693
1694 size_t typesize_in = 0;
1695 size_t typesize_out = 0;
1696
1697 int row_block = 0, row_tail = 0;
1698 int col_block = 0, col_tail = 0;
1699 dim_t src_stride = 0, tr_src_stride = 0;
1700 dim_t src_batch_shift = 0, tr_src_batch_shift = 0;
1701 dim_t col_shift_in = 0;
1702 dim_t col_shift_out = 0;
1703
1704 opmask_t mask_tail = k2;
1705
1706 reg64_t reg_src = r8;
1707 reg64_t reg_tr_src = r9;
1708 reg64_t reg_loop_batch = r10;
1709 reg64_t reg_loop_row = r11;
1710 reg64_t reg_loop_col = r12;
1711 reg32_t regw_tmp = r14d;
1712 reg64_t reg_long_offt = r15;
1713
1714 void copy_block(bool is_row_tail, bool is_col_tail);
1715 void generate() override;
1716};
1717
1718void jit_copy_f16_t::copy_block(bool is_row_tail, bool is_col_tail) {
1719
1720 const int nrows = is_row_tail && matrix_to_transform_ != matrix_C
1721 ? row_tail
1722 : row_block;
1723 const int ncolumns = is_col_tail && matrix_to_transform_ != matrix_C
1724 ? col_tail
1725 : col_block;
1726
1727 auto kmovd = [=](Opmask k, unsigned w) {
1728 mov(regw_tmp, w);
1729 jit_generator::kmovd(k, regw_tmp);
1730 };
1731
1732 const int nc_tail = ncolumns % column_step;
1733 if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1);
1734
1735 auto get_zmm = [=](int i) { return Zmm(i % num_regs); };
1736
1737 auto load = [=](int r, int cb) {
1738 auto src_reg = get_zmm(r * cb);
1739 const bool is_tail
1740 = nc_tail > 0 && ncolumns - cb * column_step < column_step;
1741 auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg;
1742 const dim_t offset = r * src_stride + cb * col_shift_in;
1743 auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt);
1744 if (matrix_to_transform_ == matrix_B)
1745 vcvtph2psx(src_load, addr);
1746 else { // matrix_c
1747 if (r < nrows)
1748 vmovups(src_load, addr);
1749 else
1750 vpxord(src_load, src_load, src_load);
1751 }
1752 };
1753
1754 auto store = [=](int r, int cb) {
1755 auto reg = get_zmm(r * cb);
1756 const dim_t offset = r * tr_src_stride + cb * col_shift_out;
1757 auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt);
1758 if (matrix_to_transform_ == matrix_B)
1759 vmovups(addr, reg);
1760 else // matrix_C
1761 vcvtps2ph(addr, reg, 0x4);
1762 };
1763
1764 for_(int r = 0; r < nrows; r++)
1765 for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) {
1766 load(r, cb);
1767 store(r, cb);
1768 }
1769}
1770
1771void jit_copy_f16_t::generate() {
1772 preamble();
1773
1774 mov(reg_src, ptr[param1 + GET_OFF(src)]);
1775 mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
1776 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
1777 mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]);
1778 mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]);
1779
1780 auto compute_batch = [=](bool is_row_tail, bool is_col_tail) {
1781 Label batch_loop;
1782 L(batch_loop);
1783
1784 copy_block(is_row_tail, is_col_tail);
1785 add(reg_src, src_batch_shift);
1786 add(reg_tr_src, tr_src_batch_shift);
1787
1788 sub(reg_loop_batch, 1);
1789 jnz(batch_loop, T_NEAR);
1790 };
1791
1792 auto compute_rows = [=](bool is_col_tail) {
1793 Label row_done;
1794 if (row_tail > 0) {
1795 Label row_common;
1796 cmp(reg_loop_row, row_block);
1797 je(row_common, T_NEAR);
1798
1799 compute_batch(true, is_col_tail);
1800 jmp(row_done, T_NEAR);
1801
1802 L(row_common);
1803 }
1804
1805 compute_batch(false, is_col_tail);
1806 L(row_done);
1807 };
1808
1809 Label col_done;
1810 if (col_tail > 0) {
1811 Label col_common;
1812 cmp(reg_loop_col, col_block);
1813 je(col_common, T_NEAR);
1814
1815 compute_rows(true);
1816 jmp(col_done, T_NEAR);
1817
1818 L(col_common);
1819 }
1820
1821 compute_rows(false);
1822 L(col_done);
1823
1824 postamble();
1825}
1826
1827struct jit_brgemm_trans_wei_f32_t : public jit_brgemm_trans_wei_t,
1828 public jit_generator {
1829 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f32_t)
1830
1831 jit_brgemm_trans_wei_f32_t(const jit_brgemm_primitive_conf_t *conf)
1832 : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {}
1833
1834 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
1835 status_t create_kernel() override { return jit_generator::create_kernel(); }
1836
1837private:
1838 using reg64_t = const Xbyak::Reg64;
1839 using reg32_t = const Xbyak::Reg32;
1840 using opmask_t = const Xbyak::Opmask;
1841
1842 enum { typesize = sizeof(float), transpose_size = 16 };
1843 dim_t src_stride = 0, tr_src_stride = 0;
1844
1845 opmask_t k3333 = k1;
1846 opmask_t k5555 = k2;
1847 opmask_t kAAAA = k3;
1848 opmask_t kCCCC = k4;
1849 opmask_t k0F0F = k5;
1850 opmask_t kF0F0 = k6;
1851 opmask_t kTail = k7;
1852
1853 reg64_t reg_src_base = rax;
1854 reg64_t reg_tr_src_base = rbx;
1855
1856 reg64_t reg_src = r8;
1857 reg64_t reg_tr_src = r9;
1858 reg64_t reg_loop_N = r10;
1859 reg64_t reg_loop_K = r11;
1860 reg64_t reg_loop_batch = r12;
1861 reg64_t reg_tr_src_tmp = r13;
1862 reg32_t regw_tmp = r14d;
1863
1864 void transpose_16x16(int nrows, int ncolumns = transpose_size);
1865 void generate() override;
1866};
1867
1868void jit_brgemm_trans_wei_f32_t::transpose_16x16(int nrows, int ncolumns) {
1869 assert(nrows >= 0 && nrows <= transpose_size);
1870 static_assert(transpose_size == 16, "Unsupported transpose size");
1871 if (!nrows) return;
1872
1873 auto src_zmm = [=](int i) {
1874 assert(i >= 0 && i < 16);
1875 return Zmm(i);
1876 };
1877
1878 auto tmp_zmm = [=](int i) {
1879 assert(i >= 0 && i < 16);
1880 return Zmm(16 + i);
1881 };
1882
1883 auto kmovw = [=](Opmask k, unsigned w) {
1884 mov(regw_tmp, w);
1885 jit_generator::kmovw(k, regw_tmp);
1886 };
1887
1888 auto load = [=](int i) {
1889 auto src_load = src_zmm(i);
1890 if (ncolumns < transpose_size) {
1891 kmovw(kTail, (1 << ncolumns) - 1);
1892 src_load = src_zmm(i) | kTail | T_z;
1893 }
1894 vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
1895 };
1896
1897 auto store = [=](Zmm r, int i) {
1898 mov(reg_tr_src_tmp, reg_tr_src);
1899 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
1900
1901 // Xbyak does not allow k0 to be specified explicitly via the '|'
1902 // operator, so we have to do this via a method call (implicitly
1903 // EVEX encoding uses k0 to mean 'no mask')
1904 bool partial_store = nrows < transpose_size;
1905 auto k = partial_store ? kTail : k0;
1906 auto base = reg_tr_src_tmp;
1907 base.setOpmaskIdx(k.getIdx(), true);
1908
1909 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
1910 vmovups(addr, r);
1911 };
1912
1913 auto transpose16x8 = [=](int base_idx) {
1914 assert(base_idx == 0 || base_idx == 8);
1915
1916 // swap 1
1917 for (int i = 0; i < 4; i++) {
1918 int src_idx0 = base_idx + i * 2;
1919 int src_idx1 = src_idx0 + 1;
1920
1921 int next_src_idx0 = src_idx0 + 2;
1922 int next_src_idx1 = src_idx1 + 2;
1923 bool load_next = base_idx == 0 || i < 3;
1924
1925 if (base_idx == 0 && i == 0) {
1926 load(src_idx0);
1927 if (src_idx1 < nrows)
1928 load(src_idx1);
1929 else
1930 vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
1931 src_zmm(src_idx1));
1932 }
1933
1934 auto tmp0 = tmp_zmm(src_idx0);
1935 auto tmp1 = tmp_zmm(src_idx1);
1936 auto src0 = src_zmm(src_idx0);
1937 auto src1 = src_zmm(src_idx1);
1938
1939 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
1940 valignd(tmp0, src0, src0, 0x1);
1941
1942 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
1943 valignd(tmp1, src1, src1, 0xf);
1944
1945 vmovaps(src0 | kAAAA, tmp1);
1946 vmovaps(src1 | k5555, tmp0);
1947 }
1948 // swap 2
1949 for (int i = 0; i < 4; i++) {
1950 int select_half = (i < 2) ? 0 : 2;
1951 int src_idx0 = base_idx + i + select_half + 0;
1952 int src_idx2 = src_idx0 + 2;
1953
1954 auto tmp0 = tmp_zmm(src_idx0);
1955 auto tmp1 = tmp_zmm(src_idx2);
1956 auto src0 = src_zmm(src_idx0);
1957 auto src2 = src_zmm(src_idx2);
1958
1959 valignd(tmp0, src0, src0, 0x2);
1960 valignd(tmp1, src2, src2, 0xe);
1961 vmovaps(src2 | k3333, tmp0);
1962 vmovaps(src0 | kCCCC, tmp1);
1963 }
1964
1965 // swap 4
1966 for (int i = 0; i < 4; i++) {
1967 int src_idx0 = base_idx + i;
1968 int src_idx4 = src_idx0 + 4;
1969
1970 auto tmp0 = tmp_zmm(src_idx0);
1971 auto src0 = src_zmm(src_idx0);
1972 auto src4 = src_zmm(src_idx4);
1973
1974 vmovaps(tmp0, src0);
1975 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
1976 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
1977 }
1978 };
1979
1980 auto fixup16x16 = [=]() {
1981 // swap 8
1982 for (int i = 0; i < 8; i++) {
1983 auto tmp = tmp_zmm(i);
1984 auto src0 = src_zmm(i);
1985 auto src8 = src_zmm(8 + i);
1986 vshuff64x2(tmp, src0, src8, 0x44);
1987 store(tmp, i);
1988 }
1989
1990 for (int i = 0; i < 8; i++) {
1991 auto tmp = tmp_zmm(8 + i);
1992 auto src0 = src_zmm(i);
1993 auto src8 = src_zmm(8 + i);
1994 vshuff64x2(tmp, src0, src8, 0xee);
1995 store(tmp, 8 + i);
1996 }
1997 };
1998
1999 transpose16x8(0);
2000 transpose16x8(8);
2001 fixup16x16();
2002}
2003
2004void jit_brgemm_trans_wei_f32_t::generate() {
2005 preamble();
2006 assert(conf_->oc_block % transpose_size == 0);
2007 int fwd_ic_block = conf_->simd_w;
2008 int fwd_oc_block = 0;
2009 switch (conf_->wei_tag) {
2010 case OI16i64o:
2011 case OIw16i64o:
2012 case OIhw16i64o:
2013 case OIdhw16i64o:
2014 case OI8i64o2i:
2015 case OIw8i64o2i:
2016 case OIhw8i64o2i:
2017 case OIdhw8i64o2i:
2018 case OI16i64o2i:
2019 case OIw16i64o2i:
2020 case OIhw16i64o2i:
2021 case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
2022 case OI16i32o:
2023 case OIw16i32o:
2024 case OIhw16i32o:
2025 case OIdhw16i32o:
2026 case OI8i32o2i:
2027 case OIw8i32o2i:
2028 case OIhw8i32o2i:
2029 case OIdhw8i32o2i:
2030 case OI16i32o2i:
2031 case OIw16i32o2i:
2032 case OIhw16i32o2i:
2033 case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
2034 default: fwd_oc_block = conf_->simd_w;
2035 };
2036
2037 int oc_tail = conf_->K_tail % transpose_size;
2038 int ic_block = conf_->ic_block;
2039 int ic_tail = conf_->N_tail % transpose_size;
2040 src_stride = fwd_oc_block * typesize;
2041 tr_src_stride = ic_block * typesize;
2042 dim_t N_src_shift = conf_->kd * conf_->kh * conf_->kw * fwd_ic_block
2043 * fwd_oc_block * typesize;
2044 dim_t N_tr_src_shift = conf_->simd_w * typesize;
2045 dim_t K_src_shift = conf_->simd_w * typesize;
2046 dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;
2047
2048 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
2049 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
2050 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
2051 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
2052
2053 auto kmovw = [=](Opmask k, unsigned w) {
2054 mov(regw_tmp, w);
2055 jit_generator::kmovw(k, regw_tmp);
2056 };
2057
2058 kmovw(k3333, 0x3333); // 0011001100110011
2059 kmovw(k5555, 0x5555); // 0101010101010101
2060 kmovw(kAAAA, 0xaaaa); // 1010101010101010
2061 kmovw(kCCCC, 0xcccc); // 1100110011001100
2062 kmovw(k0F0F, 0x0f0f); // 0000111100001111
2063 kmovw(kF0F0, 0xf0f0); // 1111000011110000
2064
2065 auto compute_N = [=](bool is_oc_tail) {
2066 mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
2067 mov(reg_src, reg_src_base);
2068 mov(reg_tr_src, reg_tr_src_base);
2069 Label N_loop, N_loop_tail;
2070
2071 cmp(reg_loop_N, transpose_size);
2072 jl(N_loop_tail, T_NEAR);
2073
2074 L(N_loop);
2075
2076 transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size);
2077 add(reg_src, N_src_shift);
2078 add(reg_tr_src, N_tr_src_shift);
2079
2080 sub(reg_loop_N, transpose_size);
2081 cmp(reg_loop_N, transpose_size);
2082 jge(N_loop, T_NEAR);
2083
2084 L(N_loop_tail);
2085 if (ic_tail > 0) {
2086 Label N_loop_done;
2087 cmp(reg_loop_N, 0);
2088 jle(N_loop_done, T_NEAR);
2089 transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size);
2090 L(N_loop_done);
2091 }
2092 };
2093
2094 Label K_loop, K_tail;
2095 if (oc_tail > 0) {
2096 cmp(reg_loop_K, transpose_size);
2097 jl(K_tail, T_NEAR);
2098 }
2099
2100 L(K_loop);
2101 compute_N(false);
2102 add(reg_src_base, K_src_shift);
2103 add(reg_tr_src_base, K_tr_src_shift);
2104
2105 sub(reg_loop_K, transpose_size);
2106 cmp(reg_loop_K, transpose_size);
2107 jge(K_loop, T_NEAR);
2108
2109 L(K_tail);
2110 if (oc_tail > 0) {
2111 Label K_loop_done;
2112 cmp(reg_loop_K, 0);
2113 jle(K_loop_done, T_NEAR);
2114
2115 compute_N(true);
2116 L(K_loop_done);
2117 }
2118
2119 postamble();
2120}
2121
2122struct jit_brgemm_trans_wei_bf16_t : public jit_brgemm_trans_wei_t,
2123 public jit_generator {
2124 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_bf16_t)
2125
2126 jit_brgemm_trans_wei_bf16_t(const jit_brgemm_primitive_conf_t *conf)
2127 : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {}
2128
2129 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
2130 status_t create_kernel() override { return jit_generator::create_kernel(); }
2131
2132private:
2133 using reg64_t = const Xbyak::Reg64;
2134 using reg32_t = const Xbyak::Reg32;
2135 using opmask_t = const Xbyak::Opmask;
2136 using zmm = const Xbyak::Zmm;
2137
2138 enum { typesize = sizeof(int16_t), transpose_size = 16 };
2139 dim_t src_stride = 0, tr_src_stride = 0;
2140
2141 opmask_t kTail = k7;
2142
2143 reg64_t reg_src_base = rax;
2144 reg64_t reg_tr_src_base = rbx;
2145
2146 reg64_t reg_src = r8;
2147 reg64_t reg_tr_src = r9;
2148 reg64_t reg_loop_N = r10;
2149 reg64_t reg_loop_K = r11;
2150 reg64_t reg_loop_batch = r12;
2151 reg64_t reg_tr_src_tmp = r13;
2152 reg32_t regw_tmp = r14d;
2153 reg64_t imm_addr64 = r15;
2154
2155 zmm v_abcdefgh_to_abefcdgh = zmm31;
2156
2157 void transpose_16x16_vnni(int nrows, int ncolumns = transpose_size);
2158 void generate() override;
2159};
2160
2161void jit_brgemm_trans_wei_bf16_t::transpose_16x16_vnni(
2162 int nrows, int ncolumns) {
2163 assert(nrows >= 0 && nrows <= transpose_size);
2164 static_assert(transpose_size == 16, "Unsupported transpose size");
2165 if (!nrows) return;
2166
2167 auto src_zmm = [=](int i) {
2168 assert(i >= 0 && i < 8);
2169 return Zmm(i);
2170 };
2171
2172 auto tmp_zmm = [=](int i) {
2173 assert(i >= 0 && i < 8);
2174 return Zmm(8 + i);
2175 };
2176
2177 auto kmovw = [=](Opmask k, unsigned w) {
2178 mov(regw_tmp, w);
2179 jit_generator::kmovw(k, regw_tmp);
2180 };
2181
2182 auto load = [=](int i) {
2183 auto src_load = src_zmm(i);
2184 if (ncolumns < transpose_size) {
2185 kmovw(kTail, (1 << ncolumns) - 1);
2186 src_load = src_zmm(i) | kTail | T_z;
2187 }
2188 vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride));
2189 };
2190
2191 auto store = [=](Zmm r, int i) {
2192 mov(reg_tr_src_tmp, reg_tr_src);
2193 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
2194
2195 // Xbyak does not allow k0 to be specified explicitly via the '|'
2196 // operator, so we have to do this via a method call (implicitly
2197 // EVEX encoding uses k0 to mean 'no mask')
2198 bool partial_store = nrows < transpose_size;
2199 auto k = partial_store ? kTail : k0;
2200 auto base = reg_tr_src_tmp;
2201 base.setOpmaskIdx(k.getIdx(), true);
2202
2203 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
2204 vmovups(addr, r);
2205 };
2206
2207 for (int i = 0; i < 8; i++)
2208 load(i);
2209
2210 for (int i = 0; i < 8; i++)
2211 vpshufb(src_zmm(i), src_zmm(i), v_abcdefgh_to_abefcdgh);
2212
2213 for (int i = 0; i < 2; i++) {
2214 vpunpcklqdq(tmp_zmm(2 * i + 0), src_zmm(2 * i), src_zmm(2 * i + 1));
2215 vpunpckhqdq(tmp_zmm(2 * i + 1), src_zmm(2 * i), src_zmm(2 * i + 1));
2216 }
2217
2218 for (int i = 0; i < 2; i++) {
2219 vpunpcklqdq(
2220 src_zmm(2 * i + 0), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
2221 vpunpckhqdq(
2222 src_zmm(2 * i + 1), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1));
2223 }
2224
2225 for (int i = 0; i < 2; i++) {
2226 vshufi32x4(src_zmm(4 + 0 + i), tmp_zmm(i), tmp_zmm(2 + i), 0x88);
2227 vshufi32x4(src_zmm(4 + 2 + i), tmp_zmm(i), tmp_zmm(2 + i), 0xdd);
2228 }
2229
2230 for (int i = 0; i < 2; i++) {
2231 vshufi32x4(tmp_zmm(0 + i), src_zmm(i), src_zmm(2 + i), 0x88);
2232 vshufi32x4(tmp_zmm(2 + i), src_zmm(i), src_zmm(2 + i), 0xdd);
2233 }
2234
2235 for (int i = 0; i < 4; i++)
2236 vshufi32x4(src_zmm(i), src_zmm(4 + i), tmp_zmm(i), 0x88);
2237
2238 for (int i = 0; i < 4; i++)
2239 vshufi32x4(src_zmm(4 + i), src_zmm(4 + i), tmp_zmm(i), 0xdd);
2240
2241 for (int i = 0; i < 8; i++)
2242 store(src_zmm(i), i);
2243}
2244
2245void jit_brgemm_trans_wei_bf16_t::generate() {
2246 preamble();
2247 int fwd_oc_block = 0;
2248 switch (conf_->wei_tag) {
2249 case OI16i64o:
2250 case OIw16i64o:
2251 case OIhw16i64o:
2252 case OIdhw16i64o:
2253 case OI8i64o2i:
2254 case OIw8i64o2i:
2255 case OIhw8i64o2i:
2256 case OIdhw8i64o2i:
2257 case OI16i64o2i:
2258 case OIw16i64o2i:
2259 case OIhw16i64o2i:
2260 case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
2261 case OI16i32o:
2262 case OIw16i32o:
2263 case OIhw16i32o:
2264 case OIdhw16i32o:
2265 case OI8i32o2i:
2266 case OIw8i32o2i:
2267 case OIhw8i32o2i:
2268 case OIdhw8i32o2i:
2269 case OI16i32o2i:
2270 case OIw16i32o2i:
2271 case OIhw16i32o2i:
2272 case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
2273 default: fwd_oc_block = conf_->simd_w;
2274 };
2275
2276 int oc_tail = conf_->K_tail % transpose_size;
2277 int ic_block = conf_->ic_block;
2278 int ic_tail = conf_->N_tail % transpose_size;
2279 src_stride = 2 * fwd_oc_block * typesize;
2280 tr_src_stride = 2 * ic_block * typesize;
2281 dim_t N_src_shift = conf_->simd_w * fwd_oc_block * typesize;
2282 dim_t N_tr_src_shift = 2 * conf_->simd_w * typesize;
2283 dim_t K_src_shift = 2 * conf_->simd_w * typesize;
2284 dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize;
2285
2286 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
2287 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
2288 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
2289 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
2290
2291 alignas(64) static constexpr const int32_t abcdefgh_to_abefcdgh[16]
2292 = {0x05040100, 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100,
2293 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302,
2294 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302, 0x0d0c0908,
2295 0x0f0e0b0a};
2296
2297 auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
2298 mov(imm_addr64, reinterpret_cast<size_t>(addr));
2299 jit_generator::vmovdqa64(z, ptr[imm_addr64]);
2300 };
2301
2302 vmovdqa64(v_abcdefgh_to_abefcdgh, (const int64_t *)abcdefgh_to_abefcdgh);
2303 auto compute_N = [=](bool is_oc_tail) {
2304 mov(reg_src, reg_src_base);
2305 mov(reg_tr_src, reg_tr_src_base);
2306 mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
2307
2308 Label N_loop, N_loop_tail;
2309 cmp(reg_loop_N, transpose_size);
2310 jl(N_loop_tail, T_NEAR);
2311
2312 L(N_loop);
2313
2314 transpose_16x16_vnni(
2315 transpose_size, is_oc_tail ? oc_tail : transpose_size);
2316 add(reg_src, N_src_shift);
2317 add(reg_tr_src, N_tr_src_shift);
2318
2319 sub(reg_loop_N, transpose_size);
2320 cmp(reg_loop_N, transpose_size);
2321 jge(N_loop, T_NEAR);
2322
2323 L(N_loop_tail);
2324 if (ic_tail > 0) {
2325 Label N_loop_done;
2326 cmp(reg_loop_N, 0);
2327 jle(N_loop_done, T_NEAR);
2328 transpose_16x16_vnni(
2329 ic_tail, is_oc_tail ? oc_tail : transpose_size);
2330 L(N_loop_done);
2331 }
2332 };
2333
2334 Label K_loop, K_tail;
2335 if (oc_tail > 0) {
2336 cmp(reg_loop_K, transpose_size);
2337 jl(K_tail, T_NEAR);
2338 }
2339
2340 L(K_loop);
2341 compute_N(false);
2342 add(reg_src_base, K_src_shift);
2343 add(reg_tr_src_base, K_tr_src_shift);
2344
2345 sub(reg_loop_K, transpose_size);
2346 cmp(reg_loop_K, transpose_size);
2347 jge(K_loop, T_NEAR);
2348
2349 L(K_tail);
2350 if (oc_tail > 0) {
2351 Label K_loop_done;
2352 cmp(reg_loop_K, 0);
2353 jle(K_loop_done, T_NEAR);
2354 compute_N(true);
2355 L(K_loop_done);
2356 }
2357
2358 postamble();
2359}
2360
2361struct jit_brgemm_trans_wei_f16_t : public jit_brgemm_trans_wei_t,
2362 public jit_generator {
2363 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f16_t)
2364
2365 jit_brgemm_trans_wei_f16_t(const jit_brgemm_primitive_conf_t *conf)
2366 : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {}
2367
2368 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
2369 status_t create_kernel() override { return jit_generator::create_kernel(); }
2370
2371private:
2372 using reg64_t = const Xbyak::Reg64;
2373 using reg32_t = const Xbyak::Reg32;
2374 using opmask_t = const Xbyak::Opmask;
2375
2376 enum {
2377 typesize_in = sizeof(int16_t),
2378 typesize_out = sizeof(float),
2379 transpose_size = 16
2380 };
2381 dim_t src_stride = 0, tr_src_stride = 0;
2382
2383 opmask_t k3333 = k1;
2384 opmask_t k5555 = k2;
2385 opmask_t kAAAA = k3;
2386 opmask_t kCCCC = k4;
2387 opmask_t k0F0F = k5;
2388 opmask_t kF0F0 = k6;
2389 opmask_t kTail = k7;
2390
2391 reg64_t reg_src_base = rax;
2392 reg64_t reg_tr_src_base = rbx;
2393
2394 reg64_t reg_src = r8;
2395 reg64_t reg_tr_src = r9;
2396 reg64_t reg_loop_N = r10;
2397 reg64_t reg_loop_K = r11;
2398 reg64_t reg_loop_batch = r12;
2399 reg64_t reg_tr_src_tmp = r13;
2400 reg32_t regw_tmp = r14d;
2401 reg64_t imm_addr64 = r15;
2402
2403 Xbyak::Zmm v_abcdefgh_to_abefcdgh = zmm31;
2404
2405 void transpose_16x16(int nrows, int ncolumns = transpose_size);
2406 void generate() override;
2407};
2408
2409void jit_brgemm_trans_wei_f16_t::transpose_16x16(int nrows, int ncolumns) {
2410 assert(nrows >= 0 && nrows <= transpose_size);
2411 static_assert(transpose_size == 16, "Unsupported transpose size");
2412 if (!nrows) return;
2413
2414 auto src_zmm = [=](int i) {
2415 assert(i >= 0 && i < 16);
2416 return Zmm(i);
2417 };
2418
2419 auto tmp_zmm = [=](int i) {
2420 assert(i >= 0 && i < 16);
2421 return Zmm(16 + i);
2422 };
2423
2424 auto kmovw = [=](Opmask k, unsigned w) {
2425 mov(regw_tmp, w);
2426 jit_generator::kmovw(k, regw_tmp);
2427 };
2428
2429 auto load = [=](int i) {
2430 auto src_load = src_zmm(i);
2431 if (ncolumns < transpose_size) {
2432 kmovw(kTail, (1 << ncolumns) - 1);
2433 src_load = src_zmm(i) | kTail | T_z;
2434 }
2435 // TODO: Maybe do tranformations in fp16 data type and at the end
2436 // cvt to f32. Thus reducing instructions
2437 vcvtph2psx(src_load, EVEX_compress_addr(reg_src, i * src_stride));
2438 };
2439
2440 auto store = [=](Zmm r, int i) {
2441 mov(reg_tr_src_tmp, reg_tr_src);
2442 if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1);
2443
2444 // Xbyak does not allow k0 to be specified explicitly via the '|'
2445 // operator, so we have to do this via a method call (implicitly
2446 // EVEX encoding uses k0 to mean 'no mask')
2447 bool partial_store = nrows < transpose_size;
2448 auto k = partial_store ? kTail : k0;
2449 auto base = reg_tr_src_tmp;
2450 base.setOpmaskIdx(k.getIdx(), true);
2451
2452 auto addr = EVEX_compress_addr(base, i * tr_src_stride);
2453 vmovups(addr, r);
2454 };
2455
2456 auto transpose16x8 = [=](int base_idx) {
2457 assert(base_idx == 0 || base_idx == 8);
2458
2459 // swap 1
2460 for (int i = 0; i < 4; i++) {
2461 int src_idx0 = base_idx + i * 2;
2462 int src_idx1 = src_idx0 + 1;
2463
2464 int next_src_idx0 = src_idx0 + 2;
2465 int next_src_idx1 = src_idx1 + 2;
2466 bool load_next = base_idx == 0 || i < 3;
2467
2468 if (base_idx == 0 && i == 0) {
2469 load(src_idx0);
2470 if (src_idx1 < nrows)
2471 load(src_idx1);
2472 else
2473 vpxord(src_zmm(src_idx1), src_zmm(src_idx1),
2474 src_zmm(src_idx1));
2475 }
2476
2477 auto tmp0 = tmp_zmm(src_idx0);
2478 auto tmp1 = tmp_zmm(src_idx1);
2479 auto src0 = src_zmm(src_idx0);
2480 auto src1 = src_zmm(src_idx1);
2481
2482 if (next_src_idx0 < nrows && load_next) load(next_src_idx0);
2483 valignd(tmp0, src0, src0, 0x1);
2484
2485 if (next_src_idx1 < nrows && load_next) load(next_src_idx1);
2486 valignd(tmp1, src1, src1, 0xf);
2487
2488 vmovaps(src0 | kAAAA, tmp1);
2489 vmovaps(src1 | k5555, tmp0);
2490 }
2491 // swap 2
2492 for (int i = 0; i < 4; i++) {
2493 int select_half = (i < 2) ? 0 : 2;
2494 int src_idx0 = base_idx + i + select_half + 0;
2495 int src_idx2 = src_idx0 + 2;
2496
2497 auto tmp0 = tmp_zmm(src_idx0);
2498 auto tmp1 = tmp_zmm(src_idx2);
2499 auto src0 = src_zmm(src_idx0);
2500 auto src2 = src_zmm(src_idx2);
2501
2502 valignd(tmp0, src0, src0, 0x2);
2503 valignd(tmp1, src2, src2, 0xe);
2504 vmovaps(src2 | k3333, tmp0);
2505 vmovaps(src0 | kCCCC, tmp1);
2506 }
2507
2508 // swap 4
2509 for (int i = 0; i < 4; i++) {
2510 int src_idx0 = base_idx + i;
2511 int src_idx4 = src_idx0 + 4;
2512
2513 auto tmp0 = tmp_zmm(src_idx0);
2514 auto src0 = src_zmm(src_idx0);
2515 auto src4 = src_zmm(src_idx4);
2516
2517 vmovaps(tmp0, src0);
2518 vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
2519 vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
2520 }
2521 };
2522
2523 auto fixup16x16 = [=]() {
2524 // swap 8
2525 for (int i = 0; i < 8; i++) {
2526 auto tmp = tmp_zmm(i);
2527 auto src0 = src_zmm(i);
2528 auto src8 = src_zmm(8 + i);
2529 vshuff64x2(tmp, src0, src8, 0x44);
2530 store(tmp, i);
2531 }
2532
2533 for (int i = 0; i < 8; i++) {
2534 auto tmp = tmp_zmm(8 + i);
2535 auto src0 = src_zmm(i);
2536 auto src8 = src_zmm(8 + i);
2537 vshuff64x2(tmp, src0, src8, 0xee);
2538 store(tmp, 8 + i);
2539 }
2540 };
2541
2542 transpose16x8(0);
2543 transpose16x8(8);
2544 fixup16x16();
2545}
2546
2547void jit_brgemm_trans_wei_f16_t::generate() {
2548 preamble();
2549 assert(conf_->oc_block % transpose_size == 0);
2550 int fwd_ic_block = conf_->simd_w;
2551 int fwd_oc_block = 0;
2552 switch (conf_->wei_tag) {
2553 case OI16i64o:
2554 case OIw16i64o:
2555 case OIhw16i64o:
2556 case OIdhw16i64o:
2557 case OI8i64o2i:
2558 case OIw8i64o2i:
2559 case OIhw8i64o2i:
2560 case OIdhw8i64o2i:
2561 case OI16i64o2i:
2562 case OIw16i64o2i:
2563 case OIhw16i64o2i:
2564 case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break;
2565 case OI16i32o:
2566 case OIw16i32o:
2567 case OIhw16i32o:
2568 case OIdhw16i32o:
2569 case OI8i32o2i:
2570 case OIw8i32o2i:
2571 case OIhw8i32o2i:
2572 case OIdhw8i32o2i:
2573 case OI16i32o2i:
2574 case OIw16i32o2i:
2575 case OIhw16i32o2i:
2576 case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break;
2577 default: fwd_oc_block = conf_->simd_w;
2578 };
2579
2580 int oc_tail = conf_->K_tail % transpose_size;
2581 int ic_block = conf_->ic_block;
2582 int ic_tail = conf_->N_tail % transpose_size;
2583 src_stride = fwd_oc_block * typesize_in;
2584 tr_src_stride = ic_block * typesize_out;
2585 dim_t N_src_shift = (conf_->kd * conf_->kh * conf_->kw * fwd_ic_block)
2586 * fwd_oc_block * typesize_in;
2587 dim_t N_tr_src_shift = conf_->simd_w * typesize_out;
2588 dim_t K_src_shift = conf_->simd_w * typesize_in;
2589 dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize_out;
2590
2591 mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
2592 mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
2593 mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]);
2594 mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]);
2595
2596 auto kmovw = [=](Opmask k, unsigned w) {
2597 mov(regw_tmp, w);
2598 jit_generator::kmovw(k, regw_tmp);
2599 };
2600
2601 kmovw(k3333, 0x3333); // 0011001100110011
2602 kmovw(k5555, 0x5555); // 0101010101010101
2603 kmovw(kAAAA, 0xaaaa); // 1010101010101010
2604 kmovw(kCCCC, 0xcccc); // 1100110011001100
2605 kmovw(k0F0F, 0x0f0f); // 0000111100001111
2606 kmovw(kF0F0, 0xf0f0); // 1111000011110000
2607
2608 auto compute_N = [=](bool is_oc_tail) {
2609 mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]);
2610 mov(reg_src, reg_src_base);
2611 mov(reg_tr_src, reg_tr_src_base);
2612 Label N_loop, N_loop_tail;
2613
2614 cmp(reg_loop_N, transpose_size);
2615 jl(N_loop_tail, T_NEAR);
2616
2617 L(N_loop);
2618
2619 transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size);
2620 add(reg_src, N_src_shift);
2621 add(reg_tr_src, N_tr_src_shift);
2622
2623 sub(reg_loop_N, transpose_size);
2624 cmp(reg_loop_N, transpose_size);
2625 jge(N_loop, T_NEAR);
2626
2627 L(N_loop_tail);
2628 if (ic_tail > 0) {
2629 Label N_loop_done;
2630 cmp(reg_loop_N, 0);
2631 jle(N_loop_done, T_NEAR);
2632 transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size);
2633 L(N_loop_done);
2634 }
2635 };
2636
2637 Label K_loop, K_tail;
2638 if (oc_tail > 0) {
2639 cmp(reg_loop_K, transpose_size);
2640 jl(K_tail, T_NEAR);
2641 }
2642
2643 L(K_loop);
2644 compute_N(false);
2645 add(reg_src_base, K_src_shift);
2646 add(reg_tr_src_base, K_tr_src_shift);
2647
2648 sub(reg_loop_K, transpose_size);
2649 cmp(reg_loop_K, transpose_size);
2650 jge(K_loop, T_NEAR);
2651
2652 L(K_tail);
2653 if (oc_tail > 0) {
2654 Label K_loop_done;
2655 cmp(reg_loop_K, 0);
2656 jle(K_loop_done, T_NEAR);
2657
2658 compute_N(true);
2659 L(K_loop_done);
2660 }
2661
2662 postamble();
2663}
2664
2665struct jit_amx_ip_trans_diff_wei_to_vnni_t : public jit_amx_ip_trans_diff_wei,
2666 public jit_generator {
2667 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_ip_trans_diff_wei_to_vnni)
2668
2669 jit_amx_ip_trans_diff_wei_to_vnni_t(const jit_brgemm_primitive_conf_t *jbgp,
2670 const int ext_ic_block, const int ext_oc_block)
2671 : jit_amx_ip_trans_diff_wei(jbgp, ext_ic_block, ext_oc_block)
2672 , jit_generator(jit_name()) {}
2673
2674 void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }
2675 status_t create_kernel() override { return jit_generator::create_kernel(); }
2676
2677private:
2678 void generate() override;
2679};
2680
2681void jit_amx_ip_trans_diff_wei_to_vnni_t::generate() {
2682 const int typesize_out = 2;
2683 const int typesize_acc = 4;
2684 const int simd_w = 16;
2685
2686 using reg64_t = const Xbyak::Reg64;
2687 using reg32_t = const Xbyak::Reg32;
2688
2689 const reg64_t &reg_output = r15;
2690 const reg64_t &reg_input = r14;
2691 const reg64_t &reg_prm_table = r13;
2692 const reg64_t &reg_last_ic_block = r12;
2693 const reg64_t &reg_last_oc_block = r11;
2694 const reg32_t &regw_tmp = r10d;
2695
2696 const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31);
2697 auto get_zmm_src = [&](int ic) { return Xbyak::Zmm(ic % 8); };
2698
2699 Xbyak::Label prm_table;
2700 Xbyak::Label skip_oc_tail, to_exit;
2701
2702 Xbyak::Opmask load_mask = k4;
2703
2704 int tail_mask = (jbgp_->N_tail % simd_w)
2705 ? (1 << (jbgp_->N_tail % simd_w)) - 1
2706 : 0xffff;
2707 auto kmovw = [=](Xbyak::Opmask k, unsigned w) {
2708 mov(regw_tmp, w);
2709 jit_generator::kmovw(k, regw_tmp);
2710 };
2711
2712 auto reorder_oc_block = [&](int icb, int ic_block, bool is_oc_tail) {
2713 // INP: [64i][No] : FP32
2714 // OUT: [OCB][ICB][16i][No][2i]: BF16
2715 if (ic_block <= 0) return;
2716
2717 dim_t inp_icb_offset = typesize_acc
2718 * (icb * ext_ic_block_ * jbgp_->oc_block); // Internal
2719 dim_t out_icb_offset = typesize_out
2720 * (icb * div_up(ext_ic_block_, 2) * ext_oc_block_
2721 * 2); // External
2722
2723 const int oc_padded = rnd_up(jbgp_->oc, jbgp_->oc_block);
2724 const int oc_padded_ext = rnd_up(jbgp_->oc, ext_oc_block_);
2725
2726 bool tailing_done = false;
2727 for (int oc = 0; oc < jbgp_->oc_block; oc += simd_w) {
2728 int ext_oc = oc % ext_oc_block_;
2729 int ext_ocb = oc / ext_oc_block_;
2730 dim_t ext_ocb_offset = typesize_out
2731 * (ext_ocb * div_up(jbgp_->ic, ext_ic_block_)
2732 * div_up(ext_ic_block_, 2) * ext_oc_block_ * 2);
2733 if (is_oc_tail && oc_padded != oc_padded_ext
2734 && oc + simd_w > ext_oc_block_)
2735 break;
2736 dim_t inp_offset = inp_icb_offset + typesize_acc * (oc); // Internal
2737 dim_t out_offset = out_icb_offset + typesize_out * (ext_oc * 2)
2738 + ext_ocb_offset; // External
2739 kmovw(load_mask, 0xffff);
2740 if (is_oc_tail) {
2741 if (jbgp_->N_tail && (oc + simd_w) >= jbgp_->N_tail) {
2742 if (tailing_done == false) {
2743 kmovw(load_mask, tail_mask);
2744 tailing_done = true;
2745 } else {
2746 auto zmm_src_0 = get_zmm_src(0);
2747 vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
2748 for (int ic = 0; ic < ext_ic_block_ / 2; ic++) {
2749 vmovups(ptr[reg_output + out_offset
2750 + typesize_out
2751 * (ic * ext_oc_block_ * 2)],
2752 zmm_src_0);
2753 }
2754 continue;
2755 }
2756 }
2757 }
2758
2759 int ic = 0;
2760 for (; ic < ic_block / 2; ic++) {
2761 int ic1 = 2 * ic;
2762 int ic2 = 2 * ic + 1;
2763
2764 auto zmm_src_0 = get_zmm_src(ic1);
2765 auto zmm_src_1 = get_zmm_src(ic2);
2766
2767 vmovups(zmm_src_0 | load_mask | T_z,
2768 ptr[reg_input + inp_offset
2769 + typesize_acc * (ic1 * jbgp_->oc_block)]);
2770 vmovups(zmm_src_1 | load_mask | T_z,
2771 ptr[reg_input + inp_offset
2772 + typesize_acc * (ic2 * jbgp_->oc_block)]);
2773 if (jbgp_->wei_dt == data_type::bf16) {
2774 vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0);
2775 } else {
2776 assert(jbgp_->wei_dt == data_type::f16);
2777 vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0);
2778 vcvtps2phx(Ymm(zmm_src_1.getIdx()), zmm_src_1);
2779 vinsertf32x8(
2780 zmm_src_0, zmm_src_0, Ymm(zmm_src_1.getIdx()), 1);
2781 }
2782 vpermw(zmm_src_0, zmm_idx, zmm_src_0);
2783
2784 vmovups(ptr[reg_output + out_offset
2785 + typesize_out * (ic * ext_oc_block_ * 2)],
2786 zmm_src_0);
2787 }
2788 if (ic_block % 2) {
2789 int ic1 = 2 * ic;
2790 auto zmm_src_0 = get_zmm_src(ic1);
2791
2792 vmovups(zmm_src_0 | load_mask | T_z,
2793 ptr[reg_input + inp_offset
2794 + typesize_acc * (ic1 * jbgp_->oc_block)]);
2795
2796 if (jbgp_->wei_dt == data_type::bf16) {
2797 vcvtneps2bf16(Ymm(zmm_src_0.getIdx()), zmm_src_0);
2798 } else {
2799 assert(jbgp_->wei_dt == data_type::f16);
2800 vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0);
2801 }
2802 vpermw(zmm_src_0, zmm_idx, zmm_src_0);
2803
2804 vmovups(ptr[reg_output + out_offset
2805 + typesize_out * (ic * ext_oc_block_ * 2)],
2806 zmm_src_0);
2807 ic++;
2808 }
2809 if (ic < ext_ic_block_ / 2) {
2810 auto zmm_src_0 = get_zmm_src(0);
2811 vpxord(zmm_src_0, zmm_src_0, zmm_src_0);
2812 for (; ic < ext_ic_block_ / 2; ic++) {
2813 vmovups(ptr[reg_output + out_offset
2814 + typesize_out * (ic * ext_oc_block_ * 2)],
2815 zmm_src_0);
2816 }
2817 }
2818 }
2819 };
2820
2821 auto reorder_ic_block = [&](bool is_oc_tail, bool is_ic_tail) {
2822 int nb_ic = div_up(jbgp_->ic_block, ext_ic_block_);
2823 for (int icb = 0; icb < nb_ic; icb++) {
2824 int ic_0 = icb * ext_ic_block_;
2825 int ic_1 = (icb + 1) * ext_ic_block_;
2826 if (is_ic_tail) {
2827 int ext_ic_tail = (jbgp_->ic % ext_ic_block_)
2828 ? (jbgp_->ic % ext_ic_block_)
2829 : ext_ic_block_;
2830 if (jbgp_->M_tail && ic_0 >= jbgp_->M_tail) break;
2831 if (jbgp_->M_tail && ic_0 <= jbgp_->M_tail
2832 && jbgp_->M_tail <= ic_1) {
2833 reorder_oc_block(icb, ext_ic_tail, is_oc_tail);
2834 } else {
2835 reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
2836 }
2837 } else {
2838 reorder_oc_block(icb, ext_ic_block_, is_oc_tail);
2839 }
2840 }
2841 };
2842
2843 auto reorder = [&](bool is_oc_tail) {
2844 Xbyak::Label skip_ic_tail, to_exit_1;
2845
2846 cmp(reg_last_ic_block, 0);
2847 je(skip_ic_tail, T_NEAR);
2848
2849 reorder_ic_block(is_oc_tail, true);
2850 jmp(to_exit, T_NEAR);
2851
2852 L(skip_ic_tail);
2853 reorder_ic_block(is_oc_tail, false);
2854
2855 L(to_exit_1);
2856 };
2857
2858 preamble();
2859
2860 mov(reg_input, ptr[abi_param1 + GET_OFF(src)]);
2861 mov(reg_output, ptr[abi_param1 + GET_OFF(dst)]);
2862 mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]);
2863 mov(reg_last_oc_block, ptr[abi_param1 + GET_OFF(last_oc_block)]);
2864
2865 mov(reg_prm_table, prm_table);
2866 vmovups(zmm_idx, ptr[reg_prm_table]);
2867
2868 cmp(reg_last_oc_block, 0);
2869 je(skip_oc_tail, T_NEAR);
2870
2871 reorder(true);
2872 jmp(to_exit, T_NEAR);
2873
2874 L(skip_oc_tail);
2875 reorder(false);
2876
2877 L(to_exit);
2878 postamble();
2879
2880 align(64);
2881 L(prm_table);
2882 const uint16_t prm_array[32]
2883 = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
2884 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
2885 for (size_t i = 0; i < 32; ++i)
2886 dw(prm_array[i]);
2887}
2888
2889#undef GET_OFF
2890
2891status_t create_brgemm_trans_src(
2892 std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker,
2893 const jit_brgemm_primitive_conf_t *conf) {
2894
2895 if (conf->prop_kind != dnnl_backward_weights)
2896 return status::invalid_arguments;
2897
2898 if (conf->src_dt == data_type::f32) {
2899 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f32_t(conf)));
2900 } else if (utils::one_of(conf->src_dt, data_type::bf16, data_type::f16)
2901 && conf->isa != avx512_core_fp16) {
2902 CHECK(safe_ptr_assign(
2903 trans_ker, new jit_brgemm_trans_m_k_bf16_t(conf)));
2904 } else if (conf->src_dt == data_type::f16) {
2905 assert(conf->isa == avx512_core_fp16);
2906 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f16_t(conf)));
2907 } else {
2908 return status::invalid_arguments;
2909 }
2910
2911 return trans_ker->create_kernel();
2912}
2913
2914status_t create_brgemm_copy_to_coarse(
2915 std::unique_ptr<jit_brgemm_copy_to_coarse_t> &copy_ker,
2916 const jit_brgemm_primitive_conf_t *conf) {
2917 if (is_superset(conf->isa, avx512_core_amx))
2918 CHECK(safe_ptr_assign(copy_ker, new jit_brgemm_copy_to_coarse_t(conf)));
2919 else
2920 return status::invalid_arguments;
2921
2922 return copy_ker->create_kernel();
2923}
2924
2925status_t create_brgemm_trans_to_vnni(
2926 std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker,
2927 const jit_brgemm_primitive_conf_t *conf,
2928 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform) {
2929 if (conf->prop_kind != dnnl_backward_weights)
2930 return status::invalid_arguments;
2931
2932 if (conf->dst_dt == data_type::f32) {
2933 CHECK(safe_ptr_assign(
2934 trans_ker, new jit_copy_f32_t(conf, matrix_to_transform)));
2935 } else if (one_of(conf->dst_dt, data_type::bf16, data_type::f16)
2936 && conf->isa != avx512_core_fp16) {
2937 CHECK(safe_ptr_assign(
2938 trans_ker, new jit_trans_to_vnni_t(conf, matrix_to_transform)));
2939 } else if (conf->dst_dt == data_type::f16) {
2940 CHECK(safe_ptr_assign(
2941 trans_ker, new jit_copy_f16_t(conf, matrix_to_transform)));
2942 } else {
2943 return status::invalid_arguments;
2944 }
2945
2946 return trans_ker->create_kernel();
2947}
2948
2949status_t create_brgemm_trans_wei(
2950 std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker,
2951 const jit_brgemm_primitive_conf_t *conf) {
2952
2953 if (conf->prop_kind != dnnl_backward_data) return status::invalid_arguments;
2954
2955 if (conf->wei_dt == data_type::f32) {
2956 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f32_t(conf)));
2957 } else if (one_of(conf->wei_dt, data_type::bf16, data_type::f16)
2958 && conf->isa != avx512_core_fp16) {
2959 CHECK(safe_ptr_assign(
2960 trans_ker, new jit_brgemm_trans_wei_bf16_t(conf)));
2961 } else if (conf->wei_dt == data_type::f16) {
2962 assert(conf->isa == avx512_core_fp16);
2963 CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f16_t(conf)));
2964 } else {
2965 return status::invalid_arguments;
2966 }
2967
2968 return trans_ker->create_kernel();
2969}
2970
2971status_t create_brgemm_amx_ip_trans_wei(
2972 std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker,
2973 const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block,
2974 const int ext_oc_block) {
2975 if (conf->prop_kind == dnnl_backward_weights
2976 && one_of(conf->wei_dt, data_type::bf16, data_type::f16)) {
2977 CHECK(safe_ptr_assign(trans_ker,
2978 new jit_amx_ip_trans_diff_wei_to_vnni_t(
2979 conf, ext_ic_block, ext_oc_block)));
2980 } else
2981 return status::invalid_arguments;
2982
2983 return trans_ker->create_kernel();
2984}
2985
2986} // namespace x64
2987} // namespace cpu
2988} // namespace impl
2989} // namespace dnnl
2990