1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/nstl.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/x64/jit_generator.hpp" |
22 | |
23 | #include "cpu/x64/jit_brgemm_transpose_utils.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::format_tag; |
31 | using namespace dnnl::impl::utils; |
32 | using namespace Xbyak; |
33 | |
34 | #define GET_OFF(x) offsetof(ctx_t, x) |
35 | |
36 | struct jit_brgemm_trans_m_k_f32_t : public jit_brgemm_trans_src_t, |
37 | public jit_generator { |
38 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f32_t) |
39 | |
40 | jit_brgemm_trans_m_k_f32_t(const jit_brgemm_primitive_conf_t *conf) |
41 | : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {} |
42 | |
43 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
44 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
45 | |
46 | private: |
47 | using reg64_t = const Xbyak::Reg64; |
48 | using reg32_t = const Xbyak::Reg32; |
49 | using opmask_t = const Xbyak::Opmask; |
50 | |
51 | enum { typesize = sizeof(float), transpose_size = 16 }; |
52 | dim_t src_stride = 0, tr_src_stride = 0; |
53 | |
54 | opmask_t k3333 = k1; |
55 | opmask_t k5555 = k2; |
56 | opmask_t kAAAA = k3; |
57 | opmask_t kCCCC = k4; |
58 | opmask_t k0F0F = k5; |
59 | opmask_t kF0F0 = k6; |
60 | opmask_t kTail = k7; |
61 | |
62 | reg64_t reg_src_base = rax; |
63 | reg64_t reg_tr_src_base = rbx; |
64 | |
65 | reg64_t reg_src = r8; |
66 | reg64_t reg_tr_src = r9; |
67 | reg64_t reg_loop_K = r10; |
68 | reg64_t reg_loop_M = r11; |
69 | reg64_t reg_loop_batch = r12; |
70 | reg64_t reg_tr_src_tmp = r13; |
71 | reg32_t regw_tmp = r14d; |
72 | reg64_t reg_row_loop = r15; |
73 | |
74 | void transpose_16x16(int nrows, int ncolumns); |
75 | void transpose(int nrows, int ncolumns); |
76 | void generate() override; |
77 | }; |
78 | |
79 | void jit_brgemm_trans_m_k_f32_t::transpose_16x16(int nrows, int ncolumns) { |
80 | assert(nrows >= 0 && nrows <= transpose_size); |
81 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
82 | if (!nrows) return; |
83 | |
84 | auto src_zmm = [=](int i) { |
85 | assert(i >= 0 && i < 16); |
86 | return Zmm(i); |
87 | }; |
88 | |
89 | auto tmp_zmm = [=](int i) { |
90 | assert(i >= 0 && i < 16); |
91 | return Zmm(16 + i); |
92 | }; |
93 | |
94 | auto kmovw = [=](Opmask k, unsigned w) { |
95 | mov(regw_tmp, w); |
96 | jit_generator::kmovw(k, regw_tmp); |
97 | }; |
98 | |
99 | auto load = [=](int i) { |
100 | auto src_load = src_zmm(i); |
101 | if (i >= nrows) { |
102 | vpxord(src_load, src_load, src_load); |
103 | return; |
104 | } |
105 | |
106 | if (ncolumns < transpose_size) { |
107 | kmovw(kTail, (1 << ncolumns) - 1); |
108 | src_load = src_zmm(i) | kTail | T_z; |
109 | } |
110 | vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
111 | }; |
112 | |
113 | auto store = [=](Zmm r, int i) { |
114 | mov(reg_tr_src_tmp, reg_tr_src); |
115 | if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1); |
116 | |
117 | // Xbyak does not allow k0 to be specified explicitly via the '|' |
118 | // operator, so we have to do this via a method call (implicitly |
119 | // EVEX encoding uses k0 to mean 'no mask') |
120 | const bool partial_store = nrows < transpose_size; |
121 | const auto k = partial_store ? kTail : k0; |
122 | auto base = reg_tr_src_tmp; |
123 | base.setOpmaskIdx(k.getIdx(), true); |
124 | |
125 | const auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
126 | vmovups(addr, r); |
127 | }; |
128 | |
129 | auto transpose16x8 = [=](int base_idx) { |
130 | assert(base_idx == 0 || base_idx == 8); |
131 | |
132 | // swap 1 |
133 | for (int i = 0; i < 4; i++) { |
134 | const int src_idx0 = base_idx + i * 2; |
135 | const int src_idx1 = src_idx0 + 1; |
136 | |
137 | const int next_src_idx0 = src_idx0 + 2; |
138 | const int next_src_idx1 = src_idx1 + 2; |
139 | const bool load_next = base_idx == 0 || i < 3; |
140 | |
141 | if (base_idx == 0 && i == 0) { |
142 | load(src_idx0); |
143 | if (src_idx1 < nrows) |
144 | load(src_idx1); |
145 | else |
146 | vpxord(src_zmm(src_idx1), src_zmm(src_idx1), |
147 | src_zmm(src_idx1)); |
148 | } |
149 | |
150 | const auto tmp0 = tmp_zmm(src_idx0); |
151 | const auto tmp1 = tmp_zmm(src_idx1); |
152 | const auto src0 = src_zmm(src_idx0); |
153 | const auto src1 = src_zmm(src_idx1); |
154 | |
155 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
156 | valignd(tmp0, src0, src0, 0x1); |
157 | |
158 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
159 | valignd(tmp1, src1, src1, 0xf); |
160 | |
161 | vmovaps(src0 | kAAAA, tmp1); |
162 | vmovaps(src1 | k5555, tmp0); |
163 | } |
164 | // swap 2 |
165 | for (int i = 0; i < 4; i++) { |
166 | const int select_half = (i < 2) ? 0 : 2; |
167 | const int src_idx0 = base_idx + i + select_half + 0; |
168 | const int src_idx2 = src_idx0 + 2; |
169 | |
170 | const auto tmp0 = tmp_zmm(src_idx0); |
171 | const auto tmp1 = tmp_zmm(src_idx2); |
172 | const auto src0 = src_zmm(src_idx0); |
173 | const auto src2 = src_zmm(src_idx2); |
174 | |
175 | valignd(tmp0, src0, src0, 0x2); |
176 | valignd(tmp1, src2, src2, 0xe); |
177 | vmovaps(src2 | k3333, tmp0); |
178 | vmovaps(src0 | kCCCC, tmp1); |
179 | } |
180 | |
181 | // swap 4 |
182 | for (int i = 0; i < 4; i++) { |
183 | const int src_idx0 = base_idx + i; |
184 | const int src_idx4 = src_idx0 + 4; |
185 | |
186 | const auto tmp0 = tmp_zmm(src_idx0); |
187 | const auto src0 = src_zmm(src_idx0); |
188 | const auto src4 = src_zmm(src_idx4); |
189 | |
190 | vmovaps(tmp0, src0); |
191 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
192 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
193 | } |
194 | }; |
195 | |
196 | auto fixup16x16 = [=]() { |
197 | // swap 8 |
198 | const auto max_iters_phase_1 = std::min(ncolumns, 8); |
199 | for (int i = 0; i < max_iters_phase_1; i++) { |
200 | const auto tmp = tmp_zmm(i); |
201 | const auto src0 = src_zmm(i); |
202 | const auto src8 = src_zmm(8 + i); |
203 | vshuff64x2(tmp, src0, src8, 0x44); |
204 | store(tmp, i); |
205 | } |
206 | |
207 | const auto max_iters_phase_2 = std::min(ncolumns - 8, 8); |
208 | for (int i = 0; i < max_iters_phase_2; i++) { |
209 | const auto tmp = tmp_zmm(8 + i); |
210 | const auto src0 = src_zmm(i); |
211 | const auto src8 = src_zmm(8 + i); |
212 | vshuff64x2(tmp, src0, src8, 0xee); |
213 | store(tmp, 8 + i); |
214 | } |
215 | }; |
216 | |
217 | transpose16x8(0); |
218 | transpose16x8(8); |
219 | fixup16x16(); |
220 | } |
221 | |
222 | void jit_brgemm_trans_m_k_f32_t::transpose(int nrows, int ncolumns) { |
223 | |
224 | Label K_loop, K_tail_or_done, K_done; |
225 | const int num_nrows_loop = nrows / transpose_size; |
226 | const int nrows_tail = nrows % transpose_size; |
227 | const dim_t src_shift = transpose_size * conf_->ic * typesize; |
228 | const dim_t tr_src_shift = transpose_size * typesize; |
229 | |
230 | if (num_nrows_loop > 1) mov(reg_row_loop, num_nrows_loop); |
231 | L(K_loop); |
232 | if (num_nrows_loop > 0) transpose_16x16(transpose_size, ncolumns); |
233 | if (num_nrows_loop > 1 || (num_nrows_loop > 0 && nrows_tail > 0)) { |
234 | add(reg_src, src_shift); |
235 | add(reg_tr_src, tr_src_shift); |
236 | } |
237 | if (num_nrows_loop > 1) { |
238 | dec(reg_row_loop); |
239 | jg(K_loop); |
240 | } |
241 | |
242 | if (nrows_tail > 0) { transpose_16x16(nrows_tail, ncolumns); } |
243 | |
244 | if (num_nrows_loop > 1 || nrows_tail > 0) { |
245 | // reset pointers |
246 | sub(reg_src, src_shift * num_nrows_loop); |
247 | sub(reg_tr_src, tr_src_shift * num_nrows_loop); |
248 | } |
249 | } |
250 | |
251 | void jit_brgemm_trans_m_k_f32_t::generate() { |
252 | preamble(); |
253 | assert(conf_->ic_block % transpose_size == 0); |
254 | const int os_block = conf_->os_block; |
255 | const int last_os_block_tail = conf_->K_tail % os_block; |
256 | const int ic_tail = conf_->M_tail % transpose_size; |
257 | src_stride = conf_->ic * typesize; |
258 | tr_src_stride = conf_->LDA * typesize; |
259 | const dim_t m_src_shift = transpose_size * typesize; |
260 | const dim_t m_tr_src_shift = tr_src_stride * transpose_size; |
261 | |
262 | const dim_t batch_src_shift = src_stride * os_block; |
263 | const dim_t batch_tr_src_shift = tr_src_stride * conf_->M; |
264 | |
265 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
266 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
267 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
268 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
269 | |
270 | auto kmovw = [=](Opmask k, unsigned w) { |
271 | mov(regw_tmp, w); |
272 | jit_generator::kmovw(k, regw_tmp); |
273 | }; |
274 | |
275 | kmovw(k3333, 0x3333); // 0011001100110011 |
276 | kmovw(k5555, 0x5555); // 0101010101010101 |
277 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
278 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
279 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
280 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
281 | |
282 | auto compute_M = [=](bool is_os_tail) { |
283 | const auto nrows = is_os_tail ? last_os_block_tail : os_block; |
284 | mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]); |
285 | mov(reg_src, reg_src_base); |
286 | mov(reg_tr_src, reg_tr_src_base); |
287 | Label M_loop, M_tail_or_done, M_done; |
288 | if (ic_tail > 0) { |
289 | cmp(reg_loop_M, transpose_size); |
290 | jl(M_tail_or_done, T_NEAR); |
291 | } |
292 | |
293 | L(M_loop); |
294 | transpose(nrows, transpose_size); |
295 | if (conf_->ic_block > transpose_size) { |
296 | add(reg_src, m_src_shift); |
297 | add(reg_tr_src, m_tr_src_shift); |
298 | sub(reg_loop_M, transpose_size); |
299 | cmp(reg_loop_M, transpose_size); |
300 | jge(M_loop, T_NEAR); |
301 | } else { |
302 | jmp(M_done, T_NEAR); |
303 | } |
304 | |
305 | L(M_tail_or_done); |
306 | if (ic_tail > 0) { |
307 | cmp(reg_loop_M, 0); |
308 | jle(M_done, T_NEAR); |
309 | |
310 | transpose(nrows, ic_tail); |
311 | } |
312 | L(M_done); |
313 | }; |
314 | |
315 | auto compute_batch = [=](bool is_os_tail) { |
316 | Label batch_loop; |
317 | L(batch_loop); |
318 | |
319 | compute_M(is_os_tail); |
320 | add(reg_src_base, batch_src_shift); |
321 | add(reg_tr_src_base, batch_tr_src_shift); |
322 | |
323 | sub(reg_loop_batch, 1); |
324 | jnz(batch_loop, T_NEAR); |
325 | }; |
326 | |
327 | Label K_tail; |
328 | if (last_os_block_tail > 0) { |
329 | cmp(reg_loop_K, os_block); |
330 | jl(K_tail, T_NEAR); |
331 | } |
332 | |
333 | compute_batch(false); |
334 | |
335 | if (last_os_block_tail > 0) { |
336 | Label K_done; |
337 | jmp(K_done, T_NEAR); |
338 | |
339 | L(K_tail); |
340 | compute_batch(true); |
341 | L(K_done); |
342 | } |
343 | |
344 | postamble(); |
345 | } |
346 | |
347 | struct jit_brgemm_trans_m_k_bf16_t : public jit_brgemm_trans_src_t, |
348 | public jit_generator { |
349 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_bf16_t) |
350 | jit_brgemm_trans_m_k_bf16_t(const jit_brgemm_primitive_conf_t *conf) |
351 | : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {} |
352 | |
353 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
354 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
355 | |
356 | private: |
357 | using reg64_t = const Xbyak::Reg64; |
358 | using reg32_t = const Xbyak::Reg32; |
359 | using opmask_t = const Xbyak::Opmask; |
360 | |
361 | enum { |
362 | typesize = sizeof(int16_t), |
363 | transpose_size = 16, |
364 | }; |
365 | dim_t src_stride = 0, tr_src_stride = 0; |
366 | |
367 | opmask_t kFFFF = k1; |
368 | opmask_t k5555 = k2; |
369 | opmask_t kAAAA = k3; |
370 | opmask_t kAA = k4; |
371 | opmask_t k55 = k5; |
372 | opmask_t kCC = k6; |
373 | opmask_t k33 = k7; |
374 | opmask_t kTail = k1; |
375 | |
376 | reg32_t regw_tmp = r15d; |
377 | |
378 | reg64_t reg_k_src = r14; |
379 | reg64_t reg_k_tr_src = r13; |
380 | |
381 | reg64_t reg_m_src = r12; |
382 | reg64_t reg_m_tr_src = r11; |
383 | |
384 | reg64_t reg_batch_src = r10; |
385 | reg64_t reg_batch_tr_src = r9; |
386 | |
387 | reg64_t reg_loop_batch = r8; |
388 | reg64_t reg_loop_K = rax; |
389 | reg64_t reg_loop_M = rbx; |
390 | |
391 | reg64_t reg_tr_src_tmp = abi_not_param1; // lnx -> rcx |
392 | reg64_t imm_addr64 = rdx; |
393 | |
394 | Xbyak::Zmm vidx1 = zmm31; |
395 | Xbyak::Zmm vidx2 = zmm30; |
396 | Xbyak::Zmm vidx3 = zmm29; |
397 | Xbyak::Zmm vidx4 = zmm28; |
398 | Xbyak::Zmm vidx5 = zmm27; |
399 | Xbyak::Zmm zmm_tmp = zmm26; |
400 | |
401 | void transpose( |
402 | reg64_t dst, reg64_t src, int nrows, int ncolumns = transpose_size); |
403 | void generate() override; |
404 | }; |
405 | |
406 | void jit_brgemm_trans_m_k_bf16_t::transpose( |
407 | reg64_t dst, reg64_t src, int nrows, int ncolumns) { |
408 | assert(nrows >= 0 && nrows <= transpose_size); |
409 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
410 | if (!nrows) return; |
411 | |
412 | auto src_zmm = [=](int i) { return Zmm(i); }; |
413 | |
414 | auto src_ymm = [=](int i) { |
415 | assert(i >= 0 && i < 16); |
416 | return Ymm(i); |
417 | }; |
418 | |
419 | auto kmovw = [=](Opmask k, unsigned w) { |
420 | mov(regw_tmp, w); |
421 | jit_generator::kmovw(k, regw_tmp); |
422 | }; |
423 | |
424 | auto kmovd = [=](Opmask k, unsigned w) { |
425 | mov(regw_tmp, w); |
426 | jit_generator::kmovd(k, regw_tmp); |
427 | }; |
428 | |
429 | auto store = [=](Zmm r, int i) { |
430 | mov(reg_tr_src_tmp, dst); |
431 | |
432 | auto k = kTail; |
433 | auto base = reg_tr_src_tmp; |
434 | base.setOpmaskIdx(k.getIdx(), true); |
435 | |
436 | auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
437 | vmovups(addr, r); |
438 | }; |
439 | |
440 | const int ic_block = ncolumns; |
441 | kmovd(kFFFF, ic_block < transpose_size ? (1 << ic_block) - 1 : 0xffff); |
442 | |
443 | for (int i = 0; i < nrows / 2; i++) { |
444 | auto zmm_src0 = src_zmm(2 * i); |
445 | auto zmm_src1 = src_zmm(2 * i + 1); |
446 | auto src1 = src_ymm(2 * i + 1); |
447 | vmovdqu16(zmm_src0 | kFFFF | T_z, |
448 | EVEX_compress_addr(src, 2 * i * src_stride)); |
449 | vmovdqu16(zmm_src1 | kFFFF | T_z, |
450 | EVEX_compress_addr(src, (2 * i + 1) * src_stride)); |
451 | vinsertf64x4(zmm_src0, zmm_src0, src1, 1); |
452 | vpermw(zmm_src0, vidx5, zmm_src0); |
453 | } |
454 | |
455 | // for odd numbers we need to mix row with zeroes |
456 | if (nrows % 2) { |
457 | int i = nrows / 2; |
458 | auto zmm_src0 = src_zmm(2 * i); |
459 | vmovdqu16(zmm_src0 | kFFFF | T_z, |
460 | EVEX_compress_addr(src, 2 * i * src_stride)); |
461 | vpermw(zmm_src0, vidx5, zmm_src0); |
462 | } |
463 | |
464 | for (int i = rnd_up(nrows, 2); i < 16; i += 2) { |
465 | vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); |
466 | } |
467 | |
468 | // swap 1 |
469 | for (int i = 0; i < 4; i++) { |
470 | auto zmm0 = src_zmm(4 * i); |
471 | auto zmm1 = src_zmm(4 * i + 2); |
472 | auto tmp0 = src_zmm(4 * i + 1); |
473 | auto tmp1 = src_zmm(4 * i + 3); |
474 | |
475 | vmovups(tmp0, zmm0); |
476 | vmovups(tmp1, zmm1); |
477 | |
478 | vpermps(tmp0 | kAAAA, vidx3, zmm1); |
479 | vpermps(tmp1 | k5555, vidx3, zmm0); |
480 | } |
481 | // swap 2 |
482 | int base_idx; |
483 | base_idx = 0; |
484 | for (int i = 0; i < 2; i++) { |
485 | auto zmm0 = src_zmm(base_idx + 2 * i + 1); |
486 | auto zmm1 = src_zmm(base_idx + 2 * i + 5); |
487 | |
488 | auto tmp0 = src_zmm(base_idx + 2 * i); |
489 | auto tmp1 = src_zmm(base_idx + 2 * i + 4); |
490 | |
491 | vmovupd(tmp0, zmm0); |
492 | vmovupd(tmp1, zmm1); |
493 | |
494 | vpermpd(tmp0 | kAA, vidx2, zmm1); |
495 | vpermpd(tmp1 | k55, vidx2, zmm0); |
496 | } |
497 | base_idx = 8; |
498 | for (int i = 0; i < 2; i++) { |
499 | auto zmm0 = src_zmm(base_idx + 2 * i + 1); |
500 | auto zmm1 = src_zmm(base_idx + 2 * i + 5); |
501 | |
502 | auto tmp0 = src_zmm(base_idx + 2 * i); |
503 | auto tmp1 = src_zmm(base_idx + 2 * i + 4); |
504 | |
505 | vmovupd(tmp0, zmm0); |
506 | vmovupd(tmp1, zmm1); |
507 | |
508 | vpermpd(tmp0 | kAA, vidx2, zmm1); |
509 | vpermpd(tmp1 | k55, vidx2, zmm0); |
510 | } |
511 | |
512 | // swap 3 |
513 | for (int i = 0; i < 4; i++) { |
514 | auto zmm0 = src_zmm(2 * i); |
515 | auto zmm1 = src_zmm(2 * i + 8); |
516 | |
517 | auto tmp0 = src_zmm(2 * i + 1); |
518 | auto tmp1 = src_zmm(2 * i + 9); |
519 | |
520 | vmovupd(tmp0, zmm0); |
521 | vmovupd(tmp1, zmm1); |
522 | |
523 | vpermpd(tmp0 | kCC, vidx1, zmm1); |
524 | vpermpd(tmp1 | k33, vidx1, zmm0); |
525 | } |
526 | |
527 | // all stores |
528 | for (int i = 0; i < 8; i++) |
529 | vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1); |
530 | |
531 | auto get_vec_idx = [=](int ic_idx) { |
532 | assert(ic_idx < 16 && ic_idx >= 0); |
533 | switch (ic_idx) { |
534 | case 0: return 1; |
535 | case 1: return 0; |
536 | case 2: return 3; |
537 | case 3: return 2; |
538 | case 4: return 9; |
539 | case 5: return 8; |
540 | case 6: return 11; |
541 | case 7: return 10; |
542 | case 8: return 5; |
543 | case 9: return 4; |
544 | case 10: return 7; |
545 | case 11: return 6; |
546 | case 12: return 13; |
547 | case 13: return 12; |
548 | case 14: return 15; |
549 | default: return 14; |
550 | } |
551 | }; |
552 | |
553 | int store_tail = rnd_up(nrows, 2); |
554 | kmovw(kTail, (1 << store_tail / 2) - 1); |
555 | |
556 | for (int ic = 0; ic < ic_block; ic++) |
557 | store(src_zmm(get_vec_idx(ic)), ic); |
558 | } |
559 | |
560 | void jit_brgemm_trans_m_k_bf16_t::generate() { |
561 | preamble(); |
562 | |
563 | alignas(64) static constexpr const int64_t idx1[8] |
564 | = {2, 3, 0, 1, 6, 7, 4, 5}; |
565 | alignas(64) static constexpr const int64_t idx2[8] |
566 | = {1, 0, 3, 2, 5, 4, 7, 6}; |
567 | alignas(64) static constexpr const int32_t idx3[16] |
568 | = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14}; |
569 | alignas(64) static constexpr const int32_t idx4[16] |
570 | = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7}; |
571 | alignas(64) static constexpr const uint16_t idx5[32] |
572 | = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17, |
573 | 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31}; |
574 | |
575 | constexpr int amx_xf16_granularity = 2; |
576 | const bool last_row_padded = is_superset(conf_->isa, avx512_core_amx) |
577 | && conf_->os % amx_xf16_granularity != 0; |
578 | const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0); |
579 | |
580 | const int os_block = conf_->os_block; |
581 | const int last_os_block_tail = eff_K_tail % transpose_size; |
582 | const int ic_tail = conf_->M_tail % transpose_size; |
583 | src_stride = conf_->ic * typesize; |
584 | tr_src_stride = conf_->LDA * typesize; |
585 | |
586 | const dim_t batch_src_shift = src_stride * os_block; |
587 | const dim_t batch_tr_src_shift = tr_src_stride * conf_->M; |
588 | |
589 | const dim_t M_src_shift = transpose_size * typesize; |
590 | const dim_t M_tr_src_shift = transpose_size * conf_->LDA * typesize; |
591 | |
592 | const dim_t K_src_shift = transpose_size * conf_->ic * typesize; |
593 | const dim_t K_tr_src_shift = transpose_size * typesize; |
594 | |
595 | auto kmovw = [=](Opmask k, unsigned w) { |
596 | mov(regw_tmp, w); |
597 | jit_generator::kmovw(k, regw_tmp); |
598 | }; |
599 | |
600 | kmovw(kFFFF, 0xffff); |
601 | kmovw(k5555, 0x5555); |
602 | kmovw(kAAAA, 0xaaaa); |
603 | kmovw(kAA, 0xaa); |
604 | kmovw(k55, 0x55); |
605 | kmovw(kCC, 0xcc); |
606 | kmovw(k33, 0x33); |
607 | |
608 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
609 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
610 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
611 | }; |
612 | |
613 | auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { |
614 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
615 | jit_generator::vmovdqa32(z, ptr[imm_addr64]); |
616 | }; |
617 | |
618 | vmovdqa64(vidx1, idx1); |
619 | vmovdqa64(vidx2, idx2); |
620 | vmovdqa32(vidx3, idx3); |
621 | vmovdqa32(vidx4, idx4); |
622 | vmovdqa32(vidx5, (const int32_t *)idx5); |
623 | |
624 | auto compute_m_loop = [&](reg64_t ®_base, reg64_t ®_tr_base, |
625 | bool is_os_tail) { |
626 | mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]); |
627 | mov(reg_m_src, reg_base); |
628 | mov(reg_m_tr_src, reg_tr_base); |
629 | |
630 | Label M_loop_tail, M_loop; |
631 | if (ic_tail > 0) { |
632 | cmp(reg_loop_M, transpose_size); |
633 | jl(M_loop_tail, T_NEAR); |
634 | } |
635 | L(M_loop); |
636 | { |
637 | transpose(reg_m_tr_src, reg_m_src, |
638 | is_os_tail ? last_os_block_tail : transpose_size, |
639 | transpose_size); |
640 | add(reg_m_src, M_src_shift); |
641 | add(reg_m_tr_src, M_tr_src_shift); |
642 | } |
643 | sub(reg_loop_M, transpose_size); |
644 | cmp(reg_loop_M, transpose_size); |
645 | jge(M_loop, T_NEAR); |
646 | |
647 | if (ic_tail > 0) { |
648 | Label M_loop_done; |
649 | L(M_loop_tail); |
650 | cmp(reg_loop_M, 0); |
651 | jle(M_loop_done, T_NEAR); |
652 | |
653 | transpose(reg_m_tr_src, reg_m_src, |
654 | is_os_tail ? last_os_block_tail : transpose_size, ic_tail); |
655 | L(M_loop_done); |
656 | } |
657 | }; |
658 | |
659 | auto compute_k_loop = [&](reg64_t ®_base, reg64_t ®_tr_base) { |
660 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
661 | mov(reg_k_src, reg_base); |
662 | mov(reg_k_tr_src, reg_tr_base); |
663 | |
664 | Label K_tail, K_loop, K_done; |
665 | if (last_os_block_tail > 0) { |
666 | cmp(reg_loop_K, transpose_size); |
667 | jl(K_tail, T_NEAR); |
668 | } |
669 | L(K_loop); |
670 | { |
671 | compute_m_loop(reg_k_src, reg_k_tr_src, false); |
672 | add(reg_k_src, K_src_shift); |
673 | add(reg_k_tr_src, K_tr_src_shift); |
674 | } |
675 | sub(reg_loop_K, transpose_size); |
676 | cmp(reg_loop_K, transpose_size); |
677 | jge(K_loop, T_NEAR); |
678 | |
679 | cmp(reg_loop_K, 0); |
680 | je(K_done, T_NEAR); |
681 | |
682 | if (last_os_block_tail > 0) { |
683 | L(K_tail); |
684 | compute_m_loop(reg_k_src, reg_k_tr_src, true); |
685 | } |
686 | L(K_done); |
687 | }; |
688 | |
689 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
690 | mov(reg_batch_src, ptr[param1 + GET_OFF(src)]); |
691 | mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
692 | |
693 | Label batch_loop; |
694 | L(batch_loop); |
695 | { |
696 | compute_k_loop(reg_batch_src, reg_batch_tr_src); |
697 | |
698 | add(reg_batch_src, batch_src_shift); |
699 | add(reg_batch_tr_src, batch_tr_src_shift); |
700 | } |
701 | sub(reg_loop_batch, 1); |
702 | jnz(batch_loop, T_NEAR); |
703 | |
704 | postamble(); |
705 | } |
706 | |
707 | struct jit_brgemm_trans_m_k_f16_t : public jit_brgemm_trans_src_t, |
708 | public jit_generator { |
709 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_m_k_f16_t) |
710 | |
711 | jit_brgemm_trans_m_k_f16_t(const jit_brgemm_primitive_conf_t *conf) |
712 | : jit_brgemm_trans_src_t(conf), jit_generator(jit_name()) {} |
713 | |
714 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
715 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
716 | |
717 | private: |
718 | using reg64_t = const Xbyak::Reg64; |
719 | using reg32_t = const Xbyak::Reg32; |
720 | using opmask_t = const Xbyak::Opmask; |
721 | |
722 | enum { |
723 | typesize_in = sizeof(float16_t), |
724 | typesize_out = sizeof(float), |
725 | transpose_size = 16 |
726 | }; |
727 | dim_t src_stride = 0, tr_src_stride = 0; |
728 | |
729 | opmask_t k3333 = k1; |
730 | opmask_t k5555 = k2; |
731 | opmask_t kAAAA = k3; |
732 | opmask_t kCCCC = k4; |
733 | opmask_t k0F0F = k5; |
734 | opmask_t kF0F0 = k6; |
735 | opmask_t kTail = k7; |
736 | |
737 | reg64_t reg_src_base = rax; |
738 | reg64_t reg_tr_src_base = rbx; |
739 | |
740 | reg64_t reg_src = r8; |
741 | reg64_t reg_tr_src = r9; |
742 | reg64_t reg_loop_K = r10; |
743 | reg64_t reg_loop_M = r11; |
744 | reg64_t reg_loop_batch = r12; |
745 | reg64_t reg_tr_src_tmp = r13; |
746 | reg32_t regw_tmp = r14d; |
747 | |
748 | void transpose_16x16(int nrows, int ncolumns = transpose_size); |
749 | void generate() override; |
750 | }; |
751 | |
752 | void jit_brgemm_trans_m_k_f16_t::transpose_16x16(int nrows, int ncolumns) { |
753 | assert(nrows >= 0 && nrows <= transpose_size); |
754 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
755 | if (!nrows) return; |
756 | |
757 | auto src_zmm = [=](int i) { |
758 | assert(i >= 0 && i < 16); |
759 | return Zmm(i); |
760 | }; |
761 | |
762 | auto tmp_zmm = [=](int i) { |
763 | assert(i >= 0 && i < 16); |
764 | return Zmm(16 + i); |
765 | }; |
766 | |
767 | auto kmovw = [=](Opmask k, unsigned w) { |
768 | mov(regw_tmp, w); |
769 | jit_generator::kmovw(k, regw_tmp); |
770 | }; |
771 | |
772 | auto load = [=](int i) { |
773 | auto src_load = src_zmm(i); |
774 | if (i >= nrows) { |
775 | vpxord(src_load, src_load, src_load); |
776 | return; |
777 | } |
778 | |
779 | if (ncolumns < transpose_size) { |
780 | kmovw(kTail, (1 << ncolumns) - 1); |
781 | src_load = src_zmm(i) | kTail | T_z; |
782 | } |
783 | vcvtph2psx(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
784 | }; |
785 | |
786 | auto store = [=](Zmm r, int i) { |
787 | mov(reg_tr_src_tmp, reg_tr_src); |
788 | if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1); |
789 | |
790 | // Xbyak does not allow k0 to be specified explicitly via the '|' |
791 | // operator, so we have to do this via a method call (implicitly |
792 | // EVEX encoding uses k0 to mean 'no mask') |
793 | const bool partial_store = nrows < transpose_size; |
794 | const auto k = partial_store ? kTail : k0; |
795 | auto base = reg_tr_src_tmp; |
796 | base.setOpmaskIdx(k.getIdx(), true); |
797 | |
798 | const auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
799 | vmovups(addr, r); |
800 | }; |
801 | |
802 | auto transpose16x8 = [=](int base_idx) { |
803 | assert(base_idx == 0 || base_idx == 8); |
804 | |
805 | // swap 1 |
806 | for (int i = 0; i < 4; i++) { |
807 | const int src_idx0 = base_idx + i * 2; |
808 | const int src_idx1 = src_idx0 + 1; |
809 | |
810 | const int next_src_idx0 = src_idx0 + 2; |
811 | const int next_src_idx1 = src_idx1 + 2; |
812 | const bool load_next = base_idx == 0 || i < 3; |
813 | |
814 | if (base_idx == 0 && i == 0) { |
815 | load(src_idx0); |
816 | if (src_idx1 < nrows) |
817 | load(src_idx1); |
818 | else |
819 | vpxord(src_zmm(src_idx1), src_zmm(src_idx1), |
820 | src_zmm(src_idx1)); |
821 | } |
822 | |
823 | const auto tmp0 = tmp_zmm(src_idx0); |
824 | const auto tmp1 = tmp_zmm(src_idx1); |
825 | const auto src0 = src_zmm(src_idx0); |
826 | const auto src1 = src_zmm(src_idx1); |
827 | |
828 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
829 | valignd(tmp0, src0, src0, 0x1); |
830 | |
831 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
832 | valignd(tmp1, src1, src1, 0xf); |
833 | |
834 | vmovaps(src0 | kAAAA, tmp1); |
835 | vmovaps(src1 | k5555, tmp0); |
836 | } |
837 | // swap 2 |
838 | for (int i = 0; i < 4; i++) { |
839 | const int select_half = (i < 2) ? 0 : 2; |
840 | const int src_idx0 = base_idx + i + select_half + 0; |
841 | const int src_idx2 = src_idx0 + 2; |
842 | |
843 | const auto tmp0 = tmp_zmm(src_idx0); |
844 | const auto tmp1 = tmp_zmm(src_idx2); |
845 | const auto src0 = src_zmm(src_idx0); |
846 | const auto src2 = src_zmm(src_idx2); |
847 | |
848 | valignd(tmp0, src0, src0, 0x2); |
849 | valignd(tmp1, src2, src2, 0xe); |
850 | vmovaps(src2 | k3333, tmp0); |
851 | vmovaps(src0 | kCCCC, tmp1); |
852 | } |
853 | |
854 | // swap 4 |
855 | for (int i = 0; i < 4; i++) { |
856 | const int src_idx0 = base_idx + i; |
857 | const int src_idx4 = src_idx0 + 4; |
858 | |
859 | const auto tmp0 = tmp_zmm(src_idx0); |
860 | const auto src0 = src_zmm(src_idx0); |
861 | const auto src4 = src_zmm(src_idx4); |
862 | |
863 | vmovaps(tmp0, src0); |
864 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
865 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
866 | } |
867 | }; |
868 | |
869 | auto fixup16x16 = [=]() { |
870 | // swap 8 |
871 | const auto max_iters_phase_1 = std::min(ncolumns, 8); |
872 | for (int i = 0; i < max_iters_phase_1; i++) { |
873 | const auto tmp = tmp_zmm(i); |
874 | const auto src0 = src_zmm(i); |
875 | const auto src8 = src_zmm(8 + i); |
876 | vshuff64x2(tmp, src0, src8, 0x44); |
877 | store(tmp, i); |
878 | } |
879 | |
880 | const auto max_iters_phase_2 = std::min(ncolumns - 8, 8); |
881 | for (int i = 0; i < max_iters_phase_2; i++) { |
882 | const auto tmp = tmp_zmm(8 + i); |
883 | const auto src0 = src_zmm(i); |
884 | const auto src8 = src_zmm(8 + i); |
885 | vshuff64x2(tmp, src0, src8, 0xee); |
886 | store(tmp, 8 + i); |
887 | } |
888 | }; |
889 | |
890 | transpose16x8(0); |
891 | transpose16x8(8); |
892 | fixup16x16(); |
893 | } |
894 | |
895 | void jit_brgemm_trans_m_k_f16_t::generate() { |
896 | preamble(); |
897 | assert(conf_->ic_block % transpose_size == 0); |
898 | const int os_block = conf_->os_block; |
899 | const int last_os_block_tail = conf_->K_tail % transpose_size; |
900 | const int ic_tail = conf_->M_tail % transpose_size; |
901 | src_stride = conf_->ic * typesize_in; |
902 | tr_src_stride = conf_->LDA * typesize_out; |
903 | const dim_t m_src_shift = transpose_size * typesize_in; |
904 | const dim_t m_tr_src_shift = tr_src_stride * transpose_size; |
905 | |
906 | const dim_t batch_src_shift = src_stride * os_block; |
907 | const dim_t batch_tr_src_shift = tr_src_stride * conf_->M; |
908 | |
909 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
910 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
911 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
912 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
913 | |
914 | auto kmovw = [=](Opmask k, unsigned w) { |
915 | mov(regw_tmp, w); |
916 | jit_generator::kmovw(k, regw_tmp); |
917 | }; |
918 | |
919 | kmovw(k3333, 0x3333); // 0011001100110011 |
920 | kmovw(k5555, 0x5555); // 0101010101010101 |
921 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
922 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
923 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
924 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
925 | |
926 | auto compute_M = [=](bool is_os_tail) { |
927 | const auto nrows = is_os_tail ? last_os_block_tail : transpose_size; |
928 | mov(reg_loop_M, ptr[param1 + GET_OFF(current_M)]); |
929 | mov(reg_src, reg_src_base); |
930 | mov(reg_tr_src, reg_tr_src_base); |
931 | Label M_loop, M_tail_or_done, M_done; |
932 | if (ic_tail > 0) { |
933 | cmp(reg_loop_M, transpose_size); |
934 | jl(M_tail_or_done, T_NEAR); |
935 | } |
936 | |
937 | L(M_loop); |
938 | transpose_16x16(nrows, transpose_size); |
939 | if (conf_->ic_block > transpose_size) { |
940 | add(reg_src, m_src_shift); |
941 | add(reg_tr_src, m_tr_src_shift); |
942 | sub(reg_loop_M, transpose_size); |
943 | cmp(reg_loop_M, transpose_size); |
944 | jge(M_loop, T_NEAR); |
945 | } else { |
946 | jmp(M_done, T_NEAR); |
947 | } |
948 | |
949 | L(M_tail_or_done); |
950 | if (ic_tail > 0) { |
951 | cmp(reg_loop_M, 0); |
952 | jle(M_done, T_NEAR); |
953 | |
954 | transpose_16x16(nrows, ic_tail); |
955 | } |
956 | L(M_done); |
957 | }; |
958 | |
959 | auto compute_batch = [=](bool is_os_tail) { |
960 | Label batch_loop; |
961 | L(batch_loop); |
962 | |
963 | compute_M(is_os_tail); |
964 | add(reg_src_base, batch_src_shift); |
965 | add(reg_tr_src_base, batch_tr_src_shift); |
966 | |
967 | sub(reg_loop_batch, 1); |
968 | jnz(batch_loop, T_NEAR); |
969 | }; |
970 | |
971 | Label K_tail; |
972 | if (last_os_block_tail > 0) { |
973 | cmp(reg_loop_K, transpose_size); |
974 | jl(K_tail, T_NEAR); |
975 | } |
976 | |
977 | compute_batch(false); |
978 | |
979 | if (last_os_block_tail > 0) { |
980 | Label K_done; |
981 | jmp(K_done, T_NEAR); |
982 | |
983 | L(K_tail); |
984 | compute_batch(true); |
985 | L(K_done); |
986 | } |
987 | |
988 | postamble(); |
989 | } |
990 | |
991 | void jit_brgemm_copy_to_coarse_t::copy_row_blks(int num_row_blks) { |
992 | int rnd_row_blks = div_up(num_row_blks, row_loop_unroll); |
993 | |
994 | for (int row_b = 0; row_b < rnd_row_blks; ++row_b) { |
995 | const int row_start = 0; |
996 | const int row_end = nstl::min(static_cast<int>(row_loop_unroll), |
997 | num_row_blks - row_b * static_cast<int>(row_loop_unroll)); |
998 | |
999 | for (int row = row_start; row < row_end; ++row) { |
1000 | const int row_idx = row_b * row_loop_unroll + row; |
1001 | const auto offset = addr_offset(row_idx); |
1002 | |
1003 | const auto zmm = get_zmm_copy(row); |
1004 | const auto addr = EVEX_compress_addr(reg_data, offset); |
1005 | const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset); |
1006 | |
1007 | vmovdqu8(zmm, addr); |
1008 | vmovdqu8(addr_tr, zmm); |
1009 | } |
1010 | } |
1011 | } |
1012 | |
1013 | void jit_brgemm_copy_to_coarse_t::copy_row_tail( |
1014 | bool is_last_iteration, int row_offset) { |
1015 | // Masks for row tail load and store are already set up |
1016 | const auto load_mask = is_last_iteration ? reg_m_last_row_tail_load |
1017 | : reg_m_full_row_tail_load; |
1018 | const auto store_mask = is_last_iteration ? reg_m_last_row_tail_store |
1019 | : reg_m_full_row_tail_store; |
1020 | |
1021 | const auto zmm_data = zmm_row_tail | load_mask | T_z; |
1022 | const auto zmm_tr_data = zmm_row_tail | store_mask; |
1023 | |
1024 | const auto offset = addr_offset(row_offset); |
1025 | const auto addr = EVEX_compress_addr(reg_data, offset); |
1026 | const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset); |
1027 | |
1028 | vmovdqu8(zmm_data, addr); |
1029 | vmovdqu8(addr_tr, zmm_tr_data); |
1030 | } |
1031 | |
1032 | void jit_brgemm_copy_to_coarse_t::zero_out_rows() { |
1033 | const int row_blk = row_size_ % tr_row_size_; |
1034 | const int rnd_up_row_blk = utils::rnd_up(row_blk, row_step_); |
1035 | |
1036 | int zero_row_blks = tr_row_size_ - rnd_up_row_blk; |
1037 | if (zero_row_blks == 0) return; |
1038 | |
1039 | const auto zmm_step = row_step_, ymm_step = row_step_ / 2, |
1040 | xmm_step = row_step_ / 4; |
1041 | assert(zero_row_blks % xmm_step == 0); |
1042 | MAYBE_UNUSED(xmm_step); |
1043 | |
1044 | int zmm_iters = zero_row_blks / zmm_step; |
1045 | zero_row_blks %= zmm_step; |
1046 | int ymm_iters = zero_row_blks / ymm_step; |
1047 | zero_row_blks %= ymm_step; |
1048 | int xmm_iters = zero_row_blks / xmm_step; |
1049 | |
1050 | auto offset = addr_offset(rnd_up_row_blk / row_step_); |
1051 | |
1052 | for (int row = 0; row < zmm_iters; ++row) { |
1053 | const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset); |
1054 | vmovdqu8(addr_tr, zmm_zero); |
1055 | offset += (zmm_step * typesize_); |
1056 | } |
1057 | |
1058 | const auto ymm_zero = Xbyak::Ymm(zmm_zero.getIdx()); |
1059 | const auto xmm_zero = Xbyak::Xmm(zmm_zero.getIdx()); |
1060 | |
1061 | assert(xmm_iters <= 1 && ymm_iters <= 1); |
1062 | if (ymm_iters > 0) { |
1063 | const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset); |
1064 | vmovdqu8(addr_tr, ymm_zero); |
1065 | offset += (ymm_step * typesize_); |
1066 | } |
1067 | |
1068 | if (xmm_iters > 0) { |
1069 | const auto addr_tr = EVEX_compress_addr(reg_tr_data, offset); |
1070 | vmovdqu8(addr_tr, xmm_zero); |
1071 | } |
1072 | } |
1073 | |
1074 | void jit_brgemm_copy_to_coarse_t::copy_row_loop() { |
1075 | Xbyak::Label label_row_tail, label_row_exit; |
1076 | |
1077 | // Note: copying is done in chunks of size row_step_ |
1078 | const auto copy_row = [&](bool is_last_iteration) { |
1079 | const int row_blk |
1080 | = is_last_iteration ? (row_size_ % tr_row_size_) : tr_row_size_; |
1081 | const int row_iters = row_blk / row_step_; |
1082 | const int row_iters_tail = row_blk % row_step_; |
1083 | |
1084 | copy_row_blks(row_iters); |
1085 | if (row_iters_tail != 0) |
1086 | copy_row_tail(is_last_iteration, /* row_offset = */ row_iters); |
1087 | |
1088 | // For the last iteration, zero-out rows if needed |
1089 | if (is_last_iteration) zero_out_rows(); |
1090 | }; |
1091 | |
1092 | const bool only_row_tail = row_size_ < tr_row_size_; |
1093 | |
1094 | if (!only_row_tail) { |
1095 | cmp(reg_last_row_blk, 0); |
1096 | jne(label_row_tail, T_NEAR); |
1097 | |
1098 | copy_row(/* is_last_iteration = */ false); |
1099 | jmp(label_row_exit, T_NEAR); |
1100 | } |
1101 | |
1102 | L(label_row_tail); |
1103 | copy_row(/* is_last_iteration = */ true); |
1104 | |
1105 | L(label_row_exit); |
1106 | } |
1107 | |
1108 | void jit_brgemm_copy_to_coarse_t::copy_os_loop() { |
1109 | |
1110 | Label loop_os; |
1111 | L(loop_os); |
1112 | |
1113 | copy_row_loop(); |
1114 | add(reg_data, data_stride_); |
1115 | add(reg_tr_data, tr_data_stride_); |
1116 | |
1117 | dec(reg_os_work); |
1118 | jnz(loop_os, T_NEAR); |
1119 | } |
1120 | |
1121 | void jit_brgemm_copy_to_coarse_t::set_last_row_tail_masks() { |
1122 | const int row_tail = (row_size_ % tr_row_size_) % row_step_; |
1123 | assert(row_tail > 0 && "kernel is meant to be used with tail processing" ); |
1124 | |
1125 | // Set load mask |
1126 | const size_t tail_mask_load |
1127 | = (static_cast<size_t>(1) << (typesize_ * row_tail)) - 1; |
1128 | mov(reg_tail_mask, tail_mask_load); |
1129 | kmovq(reg_m_last_row_tail_load, reg_tail_mask); |
1130 | |
1131 | // Caution: Since size of ZMM equals 64 bytes therefore we need |
1132 | // different masks to store tails with smaller row_block_size_ |
1133 | constexpr auto full_mask = size_t {0xffffffffffffffff}; |
1134 | constexpr auto half_mask = size_t {0x00000000ffffffff}; |
1135 | constexpr auto quad_mask = size_t {0x000000000000ffff}; |
1136 | |
1137 | const auto num_bytes = [](size_t mask) -> int { |
1138 | // Given by 1 + position of leftmost 1 bit |
1139 | return 1 + math::ilog2q(mask); |
1140 | }; |
1141 | |
1142 | const int row_tail_store_size |
1143 | = utils::rnd_up(row_tail, row_block_size_) * typesize_; |
1144 | if (row_tail_store_size >= num_bytes(full_mask)) |
1145 | mov(reg_tail_mask, full_mask); |
1146 | else if (row_tail_store_size >= num_bytes(half_mask)) |
1147 | mov(reg_tail_mask, half_mask); |
1148 | else { |
1149 | assert(row_tail_store_size == num_bytes(quad_mask)); |
1150 | mov(reg_tail_mask, quad_mask); |
1151 | } |
1152 | kmovq(reg_m_last_row_tail_store, reg_tail_mask); |
1153 | } |
1154 | |
1155 | void jit_brgemm_copy_to_coarse_t::set_full_row_tail_masks() { |
1156 | const auto full_row_tail = tr_row_size_ % row_step_; |
1157 | assert(row_step_ == 2 * full_row_tail || row_step_ == 4 * full_row_tail); |
1158 | |
1159 | const auto tail_mask = row_step_ == 2 * full_row_tail |
1160 | ? size_t {0x00000000ffffffff} |
1161 | : size_t {0x000000000000ffff}; |
1162 | |
1163 | mov(reg_tail_mask, tail_mask); |
1164 | kmovq(reg_m_full_row_tail_store, reg_tail_mask); |
1165 | kmovq(reg_m_full_row_tail_load, reg_tail_mask); |
1166 | } |
1167 | |
1168 | void jit_brgemm_copy_to_coarse_t::generate() { |
1169 | preamble(); |
1170 | |
1171 | // set up masks for tail processing |
1172 | set_last_row_tail_masks(); |
1173 | const bool has_full_row_tail_ = tr_row_size_ % row_step_ != 0; |
1174 | if (has_full_row_tail_) set_full_row_tail_masks(); |
1175 | |
1176 | // init zero vreg (zmm_zero) if it is needed |
1177 | const int last_row_size |
1178 | = utils::rnd_up(row_size_ % tr_row_size_, row_step_); |
1179 | const bool zero_iters_needed |
1180 | = last_row_size > 0 && last_row_size < tr_row_size_; |
1181 | if (zero_iters_needed) vpxord(zmm_zero, zmm_zero, zmm_zero); |
1182 | |
1183 | // load arguments to the jit kernel |
1184 | mov(reg_data, ptr[param1 + GET_OFF(data)]); |
1185 | mov(reg_tr_data, ptr[param1 + GET_OFF(tr_data)]); |
1186 | mov(reg_os_work, ptr[param1 + GET_OFF(os_work)]); |
1187 | mov(reg_last_row_blk, ptr[param1 + GET_OFF(last_row_blk)]); |
1188 | |
1189 | // enter the `main` loop |
1190 | copy_os_loop(); |
1191 | |
1192 | postamble(); |
1193 | } |
1194 | |
1195 | struct jit_trans_to_vnni_t : public jit_brgemm_trans_to_vnni_t, |
1196 | public jit_generator { |
1197 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_to_vnni_t) |
1198 | jit_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf, |
1199 | jit_brgemm_trans_to_vnni_t::matrix_to_transform_t |
1200 | matrix_to_transform) |
1201 | : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) |
1202 | , jit_generator(jit_name()) {} |
1203 | |
1204 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1205 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1206 | |
1207 | private: |
1208 | using reg64_t = const Xbyak::Reg64; |
1209 | using reg32_t = const Xbyak::Reg32; |
1210 | using opmask_t = const Xbyak::Opmask; |
1211 | using zmm = const Xbyak::Zmm; |
1212 | |
1213 | enum { |
1214 | typesize_data = sizeof(int16_t), |
1215 | typesize_acc = sizeof(float), |
1216 | transpose_size = 16, |
1217 | }; |
1218 | |
1219 | int last_row_block_tail = 0, col_tail = 0; |
1220 | dim_t src_stride = 0, tr_src_stride = 0; |
1221 | dim_t src_col_shift = 0, tr_src_col_shift = 0; |
1222 | dim_t src_row_shift = 0, tr_src_row_shift = 0; |
1223 | dim_t src_batch_shift = 0, tr_src_batch_shift = 0; |
1224 | |
1225 | opmask_t kFFFF = k1; |
1226 | opmask_t mask_tail = k2; |
1227 | |
1228 | zmm vidx1 = zmm31; |
1229 | |
1230 | reg32_t regw_tmp = r15d; |
1231 | |
1232 | reg64_t reg_batch_src = r14; |
1233 | reg64_t reg_batch_tr_src = r13; |
1234 | |
1235 | reg64_t reg_row_src = r12; |
1236 | reg64_t reg_row_tr_src = r11; |
1237 | |
1238 | reg64_t reg_col_src = r10; |
1239 | reg64_t reg_col_tr_src = r9; |
1240 | |
1241 | reg64_t reg_loop_batch = r8; |
1242 | reg64_t reg_loop_row = rax; |
1243 | reg64_t reg_loop_col = rbx; |
1244 | |
1245 | reg64_t imm_addr64 = abi_not_param1; // lnx -> rcx |
1246 | |
1247 | void maybe_zero_pad_col(reg64_t dst); |
1248 | void transpose(reg64_t dst, reg64_t src, int nrows, |
1249 | int ncolumns = transpose_size, bool pad_by_zeroes = false); |
1250 | void generate() override; |
1251 | }; |
1252 | |
1253 | void jit_trans_to_vnni_t::maybe_zero_pad_col(reg64_t dst) { |
1254 | auto zmm_zero = Xbyak::Zmm(0); |
1255 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1256 | const int oc_utilized = rnd_up(conf_->oc % conf_->oc_block, transpose_size); |
1257 | const int iters = (conf_->oc_block - oc_utilized) / transpose_size; |
1258 | for (int n = 0; n < iters; ++n) { |
1259 | for (int i = 0; i < transpose_size; i += 2) { |
1260 | auto addr = EVEX_compress_addr(dst, i * tr_src_stride); |
1261 | vmovups(addr, zmm_zero); |
1262 | } |
1263 | add(reg_col_tr_src, tr_src_col_shift); |
1264 | } |
1265 | } |
1266 | |
1267 | void jit_trans_to_vnni_t::transpose( |
1268 | reg64_t dst, reg64_t src, int nrows, int ncolumns, bool pad_by_zeroes) { |
1269 | assert(nrows >= 0 && nrows <= transpose_size); |
1270 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
1271 | if (!nrows) return; |
1272 | |
1273 | auto src_zmm = [=](int i) { return Zmm(i); }; |
1274 | |
1275 | auto src_ymm = [=](int i) { |
1276 | assert(i >= 0 && i < 16); |
1277 | return Ymm(i); |
1278 | }; |
1279 | |
1280 | auto store = [=](Zmm r, int i) { |
1281 | auto addr = EVEX_compress_addr(dst, i * tr_src_stride); |
1282 | vmovups(addr, r); |
1283 | }; |
1284 | auto mask = ncolumns == transpose_size ? kFFFF : mask_tail; |
1285 | |
1286 | int i = 0; |
1287 | for (; i < nrows / 2; i++) { |
1288 | auto src1 = src_ymm(2 * i + 1); |
1289 | auto zmm_src0 = src_zmm(2 * i); |
1290 | auto zmm_src1 = src_zmm(2 * i + 1); |
1291 | if (matrix_to_transform_ == matrix_B) { |
1292 | vmovdqu16(zmm_src0 | mask | T_z, |
1293 | EVEX_compress_addr(src, 2 * i * src_stride)); |
1294 | vmovdqu16(zmm_src1 | mask | T_z, |
1295 | EVEX_compress_addr(src, (2 * i + 1) * src_stride)); |
1296 | vinsertf64x4(zmm_src0, zmm_src0, src1, 1); |
1297 | } else { |
1298 | vmovups(zmm_src0 | mask | T_z, |
1299 | EVEX_compress_addr(src, 2 * i * src_stride)); |
1300 | vmovups(zmm_src1 | mask | T_z, |
1301 | EVEX_compress_addr(src, (2 * i + 1) * src_stride)); |
1302 | vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0); |
1303 | } |
1304 | vpermw(zmm_src0, vidx1, zmm_src0); |
1305 | store(zmm_src0, 2 * i); |
1306 | } |
1307 | |
1308 | if (nrows % 2) { |
1309 | auto zmm_src0 = src_zmm(2 * i); |
1310 | if (matrix_to_transform_ == matrix_B) { |
1311 | vmovdqu16(zmm_src0 | mask | T_z, |
1312 | EVEX_compress_addr(src, 2 * i * src_stride)); |
1313 | } else { |
1314 | auto zmm_zero = src_zmm(2 * i + 1); |
1315 | vmovups(zmm_src0 | mask | T_z, |
1316 | EVEX_compress_addr(src, 2 * i * src_stride)); |
1317 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1318 | vcvtne2ps2bf16(zmm_src0, zmm_zero, zmm_src0); |
1319 | } |
1320 | vpermw(zmm_src0, vidx1, zmm_src0); |
1321 | store(zmm_src0, 2 * i); |
1322 | i++; |
1323 | } |
1324 | |
1325 | if (pad_by_zeroes && i < transpose_size / 2) { |
1326 | auto zmm_zero = src_zmm(2 * i); |
1327 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1328 | for (; i < transpose_size / 2; i++) |
1329 | store(zmm_zero, 2 * i); |
1330 | } |
1331 | } |
1332 | |
1333 | void jit_trans_to_vnni_t::generate() { |
1334 | preamble(); |
1335 | |
1336 | alignas(64) static constexpr const int16_t idx1[32] |
1337 | = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, |
1338 | 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
1339 | |
1340 | if (matrix_to_transform_ == matrix_B) { |
1341 | int row_block = conf_->os_block; |
1342 | |
1343 | constexpr int amx_xf16_granularity = 2; |
1344 | const bool last_row_padded = is_superset(conf_->isa, avx512_core_amx) |
1345 | && conf_->os % amx_xf16_granularity != 0; |
1346 | const int eff_K_tail = conf_->K_tail - (last_row_padded ? 1 : 0); |
1347 | |
1348 | last_row_block_tail = eff_K_tail % transpose_size; |
1349 | col_tail = conf_->oc % transpose_size; |
1350 | src_stride = conf_->oc * typesize_data; |
1351 | tr_src_stride = conf_->LDB * typesize_data; |
1352 | |
1353 | src_batch_shift = src_stride * row_block; |
1354 | tr_src_batch_shift = tr_src_stride * rnd_up(conf_->K, 2); |
1355 | |
1356 | src_col_shift = transpose_size * typesize_data; |
1357 | tr_src_col_shift = 2 * transpose_size * typesize_data; |
1358 | |
1359 | src_row_shift = transpose_size * conf_->oc * typesize_data; |
1360 | tr_src_row_shift = transpose_size * conf_->LDB * typesize_data; |
1361 | |
1362 | } else { // matrix_to_transform_ == matrix_C |
1363 | int row_block = conf_->ic_block; |
1364 | last_row_block_tail = conf_->M_tail % transpose_size; |
1365 | assert(row_block == transpose_size); |
1366 | col_tail = conf_->oc % transpose_size; |
1367 | src_stride = conf_->LDC * typesize_acc; |
1368 | tr_src_stride = conf_->LDD * typesize_data; |
1369 | |
1370 | src_batch_shift = src_stride * row_block; |
1371 | tr_src_batch_shift = tr_src_stride * rnd_up(conf_->M, 2); |
1372 | |
1373 | src_col_shift = transpose_size * typesize_acc; |
1374 | tr_src_col_shift = 2 * transpose_size * typesize_data; |
1375 | } |
1376 | |
1377 | // mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1378 | // mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1379 | // mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]); |
1380 | |
1381 | auto kmovw = [=](Opmask k, unsigned w) { |
1382 | mov(regw_tmp, w); |
1383 | jit_generator::kmovw(k, regw_tmp); |
1384 | }; |
1385 | auto kmovd = [=](Opmask k, unsigned w) { |
1386 | mov(regw_tmp, w); |
1387 | jit_generator::kmovd(k, regw_tmp); |
1388 | }; |
1389 | |
1390 | kmovw(kFFFF, 0xffff); // 1111111111111111 |
1391 | kmovd(mask_tail, (1 << col_tail) - 1); |
1392 | |
1393 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
1394 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
1395 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
1396 | }; |
1397 | |
1398 | vmovdqa64(vidx1, (const int64_t *)idx1); |
1399 | |
1400 | auto compute_col_loop = [&](reg64_t ®_base, reg64_t ®_tr_base, |
1401 | bool is_row_tail) { |
1402 | const bool pad_by_zeroes = matrix_to_transform_ == matrix_C; |
1403 | int nrows = is_row_tail ? last_row_block_tail : transpose_size; |
1404 | |
1405 | mov(reg_col_src, reg_base); |
1406 | mov(reg_col_tr_src, reg_tr_base); |
1407 | mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]); |
1408 | |
1409 | Label col_loop, col_loop_tail; |
1410 | cmp(reg_loop_col, transpose_size); |
1411 | jl(col_loop_tail, T_NEAR); |
1412 | |
1413 | L(col_loop); |
1414 | { |
1415 | transpose(reg_col_tr_src, reg_col_src, nrows, transpose_size, |
1416 | pad_by_zeroes); |
1417 | add(reg_col_src, src_col_shift); |
1418 | add(reg_col_tr_src, tr_src_col_shift); |
1419 | } |
1420 | sub(reg_loop_col, transpose_size); |
1421 | cmp(reg_loop_col, transpose_size); |
1422 | jge(col_loop, T_NEAR); |
1423 | |
1424 | L(col_loop_tail); |
1425 | if (col_tail > 0) { |
1426 | Label col_loop_done; |
1427 | cmp(reg_loop_col, 0); |
1428 | jle(col_loop_done, T_NEAR); |
1429 | transpose(reg_col_tr_src, reg_col_src, nrows, col_tail, |
1430 | pad_by_zeroes); |
1431 | L(col_loop_done); |
1432 | } |
1433 | const int oc_block_tail = conf_->oc % conf_->oc_block; |
1434 | const bool full_oc_block_utilized = oc_block_tail == 0 |
1435 | || rnd_up(oc_block_tail, transpose_size) == conf_->oc_block; |
1436 | const bool col_pad_required = pad_by_zeroes && !full_oc_block_utilized; |
1437 | |
1438 | if (col_pad_required) { |
1439 | Label col_pad_done; |
1440 | mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]); |
1441 | cmp(reg_loop_col, conf_->oc_block); |
1442 | je(col_pad_done, T_NEAR); |
1443 | if (col_tail > 0) add(reg_col_tr_src, tr_src_col_shift); |
1444 | maybe_zero_pad_col(reg_col_tr_src); |
1445 | L(col_pad_done); |
1446 | } |
1447 | }; |
1448 | |
1449 | auto compute_row_loop = [&](reg64_t ®_base, reg64_t ®_tr_base) { |
1450 | mov(reg_row_src, reg_base); |
1451 | mov(reg_row_tr_src, reg_tr_base); |
1452 | mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]); |
1453 | |
1454 | Label row_tail, row_loop, row_done; |
1455 | if (last_row_block_tail > 0) { |
1456 | cmp(reg_loop_row, transpose_size); |
1457 | jl(row_tail, T_NEAR); |
1458 | } |
1459 | L(row_loop); |
1460 | { |
1461 | compute_col_loop(reg_row_src, reg_row_tr_src, false); |
1462 | |
1463 | add(reg_row_src, src_row_shift); |
1464 | add(reg_row_tr_src, tr_src_row_shift); |
1465 | } |
1466 | sub(reg_loop_row, transpose_size); |
1467 | cmp(reg_loop_row, transpose_size); |
1468 | jge(row_loop, T_NEAR); |
1469 | |
1470 | cmp(reg_loop_row, 0); |
1471 | je(row_done, T_NEAR); |
1472 | |
1473 | if (last_row_block_tail > 0) { |
1474 | L(row_tail); |
1475 | compute_col_loop(reg_row_src, reg_row_tr_src, true); |
1476 | } |
1477 | L(row_done); |
1478 | }; |
1479 | |
1480 | mov(reg_batch_src, ptr[param1 + GET_OFF(src)]); |
1481 | mov(reg_batch_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1482 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
1483 | |
1484 | Label batch_loop; |
1485 | L(batch_loop); |
1486 | { |
1487 | compute_row_loop(reg_batch_src, reg_batch_tr_src); |
1488 | |
1489 | add(reg_batch_src, src_batch_shift); |
1490 | add(reg_batch_tr_src, tr_src_batch_shift); |
1491 | } |
1492 | sub(reg_loop_batch, 1); |
1493 | jnz(batch_loop, T_NEAR); |
1494 | |
1495 | postamble(); |
1496 | } |
1497 | |
1498 | struct jit_copy_f32_t : public jit_brgemm_trans_to_vnni_t, |
1499 | public jit_generator { |
1500 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f32_t) |
1501 | jit_copy_f32_t(const jit_brgemm_primitive_conf_t *conf, |
1502 | jit_brgemm_trans_to_vnni_t::matrix_to_transform_t |
1503 | matrix_to_transform) |
1504 | : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) |
1505 | , jit_generator(jit_name()) {} |
1506 | |
1507 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1508 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1509 | |
1510 | private: |
1511 | using reg64_t = const Xbyak::Reg64; |
1512 | using reg32_t = const Xbyak::Reg32; |
1513 | using opmask_t = const Xbyak::Opmask; |
1514 | using zmm = const Xbyak::Zmm; |
1515 | |
1516 | enum { |
1517 | typesize_data = sizeof(float), |
1518 | column_step = 16, |
1519 | num_regs = 32, |
1520 | }; |
1521 | |
1522 | dim_t src_stride = 0, tr_src_stride = 0; |
1523 | dim_t src_batch_shift = 0, tr_src_batch_shift = 0; |
1524 | dim_t col_shift = column_step * typesize_data; |
1525 | |
1526 | opmask_t mask_tail = k2; |
1527 | |
1528 | reg64_t reg_src = r8; |
1529 | reg64_t reg_tr_src = r9; |
1530 | reg64_t reg_loop_batch = r10; |
1531 | reg64_t reg_loop_row = r11; |
1532 | reg64_t reg_loop_col = r12; |
1533 | reg32_t regw_tmp = r14d; |
1534 | reg64_t reg_long_offt = r15; |
1535 | |
1536 | void copy_block(int nrows, int ncolumns); |
1537 | void generate() override; |
1538 | }; |
1539 | |
1540 | void jit_copy_f32_t::copy_block(int nrows, int ncolumns) { |
1541 | |
1542 | auto kmovd = [=](Opmask k, unsigned w) { |
1543 | mov(regw_tmp, w); |
1544 | jit_generator::kmovd(k, regw_tmp); |
1545 | }; |
1546 | |
1547 | const int nc_tail = ncolumns % column_step; |
1548 | if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1); |
1549 | |
1550 | auto get_zmm = [=](int i) { return Zmm(i % num_regs); }; |
1551 | |
1552 | auto load = [=](int r, int cb) { |
1553 | auto src_reg = get_zmm(r * cb); |
1554 | const bool is_tail |
1555 | = nc_tail > 0 && ncolumns - cb * column_step < column_step; |
1556 | auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg; |
1557 | const dim_t offset = r * src_stride + cb * col_shift; |
1558 | auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt); |
1559 | vmovups(src_load, addr); |
1560 | }; |
1561 | |
1562 | auto store = [=](int r, int cb) { |
1563 | auto reg = get_zmm(r * cb); |
1564 | const dim_t offset = r * tr_src_stride + cb * col_shift; |
1565 | auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt); |
1566 | vmovups(addr, reg); |
1567 | }; |
1568 | |
1569 | for_(int r = 0; r < nrows; r++) |
1570 | for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) { |
1571 | load(r, cb); |
1572 | store(r, cb); |
1573 | } |
1574 | } |
1575 | |
1576 | void jit_copy_f32_t::generate() { |
1577 | preamble(); |
1578 | |
1579 | const int row_block = conf_->os_block; |
1580 | const int row_tail = conf_->os % row_block; |
1581 | const int col_block = conf_->oc_block * conf_->nb_oc_blocking; |
1582 | const int col_tail = conf_->oc % col_block; |
1583 | src_stride = conf_->oc * typesize_data; |
1584 | tr_src_stride = conf_->LDB * typesize_data; |
1585 | src_batch_shift = src_stride * row_block; |
1586 | tr_src_batch_shift = tr_src_stride * row_block; |
1587 | |
1588 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1589 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1590 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
1591 | mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]); |
1592 | mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]); |
1593 | |
1594 | auto compute_batch = [=](int nrows, int ncolumns) { |
1595 | Label batch_loop; |
1596 | L(batch_loop); |
1597 | |
1598 | copy_block(nrows, ncolumns); |
1599 | add(reg_src, src_batch_shift); |
1600 | add(reg_tr_src, tr_src_batch_shift); |
1601 | |
1602 | sub(reg_loop_batch, 1); |
1603 | jnz(batch_loop, T_NEAR); |
1604 | }; |
1605 | |
1606 | auto compute_rows = [=](int ncolumns) { |
1607 | Label row_done; |
1608 | if (row_tail > 0) { |
1609 | Label row_common; |
1610 | cmp(reg_loop_row, row_block); |
1611 | je(row_common, T_NEAR); |
1612 | |
1613 | compute_batch(row_tail, ncolumns); |
1614 | jmp(row_done, T_NEAR); |
1615 | |
1616 | L(row_common); |
1617 | } |
1618 | |
1619 | compute_batch(row_block, ncolumns); |
1620 | L(row_done); |
1621 | }; |
1622 | |
1623 | Label col_done; |
1624 | if (col_tail > 0) { |
1625 | Label col_common; |
1626 | cmp(reg_loop_col, col_block); |
1627 | je(col_common, T_NEAR); |
1628 | |
1629 | compute_rows(col_tail); |
1630 | jmp(col_done, T_NEAR); |
1631 | |
1632 | L(col_common); |
1633 | } |
1634 | |
1635 | compute_rows(col_block); |
1636 | L(col_done); |
1637 | |
1638 | postamble(); |
1639 | } |
1640 | |
1641 | struct jit_copy_f16_t : public jit_brgemm_trans_to_vnni_t, |
1642 | public jit_generator { |
1643 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_copy_f16_t) |
1644 | jit_copy_f16_t(const jit_brgemm_primitive_conf_t *conf, |
1645 | jit_brgemm_trans_to_vnni_t::matrix_to_transform_t |
1646 | matrix_to_transform) |
1647 | : jit_brgemm_trans_to_vnni_t(conf, matrix_to_transform) |
1648 | , jit_generator(jit_name()) { |
1649 | |
1650 | // matrix_to_transform_ == matrix_B, copy(f16) -> f32 |
1651 | // matrix_to_transform_ == matrix_C, copy(f32) -> f16 + zero_pad |
1652 | if (matrix_to_transform_ == matrix_B) { |
1653 | row_block = conf_->os_block; |
1654 | row_tail = conf_->os % row_block; |
1655 | col_block = conf_->oc_block * conf_->nb_oc_blocking; |
1656 | col_tail = conf_->oc % col_block; |
1657 | typesize_in = types::data_type_size(data_type::f16); |
1658 | typesize_out = types::data_type_size(data_type::f32); |
1659 | src_stride = conf_->oc * typesize_in; |
1660 | tr_src_stride = conf_->LDB * typesize_out; |
1661 | src_batch_shift = src_stride * row_block; |
1662 | tr_src_batch_shift = tr_src_stride * row_block; |
1663 | } else { // matrix_C |
1664 | row_block = conf_->os_block; |
1665 | row_tail = conf_->os % row_block; |
1666 | col_block = conf_->oc_block * conf_->nb_oc_blocking; |
1667 | col_tail = conf_->oc % col_block; |
1668 | typesize_in = types::data_type_size(data_type::f32); |
1669 | typesize_out = types::data_type_size(data_type::f16); |
1670 | src_stride = conf_->LDB * typesize_in; |
1671 | tr_src_stride = conf_->LDB * typesize_out; |
1672 | src_batch_shift = src_stride * row_block; |
1673 | tr_src_batch_shift = tr_src_stride * row_block; |
1674 | } |
1675 | |
1676 | col_shift_in = column_step * typesize_in; |
1677 | col_shift_out = column_step * typesize_out; |
1678 | } |
1679 | |
1680 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1681 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1682 | |
1683 | private: |
1684 | using reg64_t = const Xbyak::Reg64; |
1685 | using reg32_t = const Xbyak::Reg32; |
1686 | using opmask_t = const Xbyak::Opmask; |
1687 | using zmm = const Xbyak::Zmm; |
1688 | |
1689 | enum { |
1690 | column_step = 16, |
1691 | num_regs = 32, |
1692 | }; |
1693 | |
1694 | size_t typesize_in = 0; |
1695 | size_t typesize_out = 0; |
1696 | |
1697 | int row_block = 0, row_tail = 0; |
1698 | int col_block = 0, col_tail = 0; |
1699 | dim_t src_stride = 0, tr_src_stride = 0; |
1700 | dim_t src_batch_shift = 0, tr_src_batch_shift = 0; |
1701 | dim_t col_shift_in = 0; |
1702 | dim_t col_shift_out = 0; |
1703 | |
1704 | opmask_t mask_tail = k2; |
1705 | |
1706 | reg64_t reg_src = r8; |
1707 | reg64_t reg_tr_src = r9; |
1708 | reg64_t reg_loop_batch = r10; |
1709 | reg64_t reg_loop_row = r11; |
1710 | reg64_t reg_loop_col = r12; |
1711 | reg32_t regw_tmp = r14d; |
1712 | reg64_t reg_long_offt = r15; |
1713 | |
1714 | void copy_block(bool is_row_tail, bool is_col_tail); |
1715 | void generate() override; |
1716 | }; |
1717 | |
1718 | void jit_copy_f16_t::copy_block(bool is_row_tail, bool is_col_tail) { |
1719 | |
1720 | const int nrows = is_row_tail && matrix_to_transform_ != matrix_C |
1721 | ? row_tail |
1722 | : row_block; |
1723 | const int ncolumns = is_col_tail && matrix_to_transform_ != matrix_C |
1724 | ? col_tail |
1725 | : col_block; |
1726 | |
1727 | auto kmovd = [=](Opmask k, unsigned w) { |
1728 | mov(regw_tmp, w); |
1729 | jit_generator::kmovd(k, regw_tmp); |
1730 | }; |
1731 | |
1732 | const int nc_tail = ncolumns % column_step; |
1733 | if (nc_tail > 0) kmovd(mask_tail, (1 << nc_tail) - 1); |
1734 | |
1735 | auto get_zmm = [=](int i) { return Zmm(i % num_regs); }; |
1736 | |
1737 | auto load = [=](int r, int cb) { |
1738 | auto src_reg = get_zmm(r * cb); |
1739 | const bool is_tail |
1740 | = nc_tail > 0 && ncolumns - cb * column_step < column_step; |
1741 | auto src_load = is_tail ? src_reg | mask_tail | T_z : src_reg; |
1742 | const dim_t offset = r * src_stride + cb * col_shift_in; |
1743 | auto addr = EVEX_compress_addr_safe(reg_src, offset, reg_long_offt); |
1744 | if (matrix_to_transform_ == matrix_B) |
1745 | vcvtph2psx(src_load, addr); |
1746 | else { // matrix_c |
1747 | if (r < nrows) |
1748 | vmovups(src_load, addr); |
1749 | else |
1750 | vpxord(src_load, src_load, src_load); |
1751 | } |
1752 | }; |
1753 | |
1754 | auto store = [=](int r, int cb) { |
1755 | auto reg = get_zmm(r * cb); |
1756 | const dim_t offset = r * tr_src_stride + cb * col_shift_out; |
1757 | auto addr = EVEX_compress_addr_safe(reg_tr_src, offset, reg_long_offt); |
1758 | if (matrix_to_transform_ == matrix_B) |
1759 | vmovups(addr, reg); |
1760 | else // matrix_C |
1761 | vcvtps2ph(addr, reg, 0x4); |
1762 | }; |
1763 | |
1764 | for_(int r = 0; r < nrows; r++) |
1765 | for (int cb = 0; cb < div_up(ncolumns, column_step); cb++) { |
1766 | load(r, cb); |
1767 | store(r, cb); |
1768 | } |
1769 | } |
1770 | |
1771 | void jit_copy_f16_t::generate() { |
1772 | preamble(); |
1773 | |
1774 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1775 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1776 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
1777 | mov(reg_loop_row, ptr[param1 + GET_OFF(current_row_size)]); |
1778 | mov(reg_loop_col, ptr[param1 + GET_OFF(current_col_size)]); |
1779 | |
1780 | auto compute_batch = [=](bool is_row_tail, bool is_col_tail) { |
1781 | Label batch_loop; |
1782 | L(batch_loop); |
1783 | |
1784 | copy_block(is_row_tail, is_col_tail); |
1785 | add(reg_src, src_batch_shift); |
1786 | add(reg_tr_src, tr_src_batch_shift); |
1787 | |
1788 | sub(reg_loop_batch, 1); |
1789 | jnz(batch_loop, T_NEAR); |
1790 | }; |
1791 | |
1792 | auto compute_rows = [=](bool is_col_tail) { |
1793 | Label row_done; |
1794 | if (row_tail > 0) { |
1795 | Label row_common; |
1796 | cmp(reg_loop_row, row_block); |
1797 | je(row_common, T_NEAR); |
1798 | |
1799 | compute_batch(true, is_col_tail); |
1800 | jmp(row_done, T_NEAR); |
1801 | |
1802 | L(row_common); |
1803 | } |
1804 | |
1805 | compute_batch(false, is_col_tail); |
1806 | L(row_done); |
1807 | }; |
1808 | |
1809 | Label col_done; |
1810 | if (col_tail > 0) { |
1811 | Label col_common; |
1812 | cmp(reg_loop_col, col_block); |
1813 | je(col_common, T_NEAR); |
1814 | |
1815 | compute_rows(true); |
1816 | jmp(col_done, T_NEAR); |
1817 | |
1818 | L(col_common); |
1819 | } |
1820 | |
1821 | compute_rows(false); |
1822 | L(col_done); |
1823 | |
1824 | postamble(); |
1825 | } |
1826 | |
1827 | struct jit_brgemm_trans_wei_f32_t : public jit_brgemm_trans_wei_t, |
1828 | public jit_generator { |
1829 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f32_t) |
1830 | |
1831 | jit_brgemm_trans_wei_f32_t(const jit_brgemm_primitive_conf_t *conf) |
1832 | : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {} |
1833 | |
1834 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1835 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1836 | |
1837 | private: |
1838 | using reg64_t = const Xbyak::Reg64; |
1839 | using reg32_t = const Xbyak::Reg32; |
1840 | using opmask_t = const Xbyak::Opmask; |
1841 | |
1842 | enum { typesize = sizeof(float), transpose_size = 16 }; |
1843 | dim_t src_stride = 0, tr_src_stride = 0; |
1844 | |
1845 | opmask_t k3333 = k1; |
1846 | opmask_t k5555 = k2; |
1847 | opmask_t kAAAA = k3; |
1848 | opmask_t kCCCC = k4; |
1849 | opmask_t k0F0F = k5; |
1850 | opmask_t kF0F0 = k6; |
1851 | opmask_t kTail = k7; |
1852 | |
1853 | reg64_t reg_src_base = rax; |
1854 | reg64_t reg_tr_src_base = rbx; |
1855 | |
1856 | reg64_t reg_src = r8; |
1857 | reg64_t reg_tr_src = r9; |
1858 | reg64_t reg_loop_N = r10; |
1859 | reg64_t reg_loop_K = r11; |
1860 | reg64_t reg_loop_batch = r12; |
1861 | reg64_t reg_tr_src_tmp = r13; |
1862 | reg32_t regw_tmp = r14d; |
1863 | |
1864 | void transpose_16x16(int nrows, int ncolumns = transpose_size); |
1865 | void generate() override; |
1866 | }; |
1867 | |
1868 | void jit_brgemm_trans_wei_f32_t::transpose_16x16(int nrows, int ncolumns) { |
1869 | assert(nrows >= 0 && nrows <= transpose_size); |
1870 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
1871 | if (!nrows) return; |
1872 | |
1873 | auto src_zmm = [=](int i) { |
1874 | assert(i >= 0 && i < 16); |
1875 | return Zmm(i); |
1876 | }; |
1877 | |
1878 | auto tmp_zmm = [=](int i) { |
1879 | assert(i >= 0 && i < 16); |
1880 | return Zmm(16 + i); |
1881 | }; |
1882 | |
1883 | auto kmovw = [=](Opmask k, unsigned w) { |
1884 | mov(regw_tmp, w); |
1885 | jit_generator::kmovw(k, regw_tmp); |
1886 | }; |
1887 | |
1888 | auto load = [=](int i) { |
1889 | auto src_load = src_zmm(i); |
1890 | if (ncolumns < transpose_size) { |
1891 | kmovw(kTail, (1 << ncolumns) - 1); |
1892 | src_load = src_zmm(i) | kTail | T_z; |
1893 | } |
1894 | vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
1895 | }; |
1896 | |
1897 | auto store = [=](Zmm r, int i) { |
1898 | mov(reg_tr_src_tmp, reg_tr_src); |
1899 | if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1); |
1900 | |
1901 | // Xbyak does not allow k0 to be specified explicitly via the '|' |
1902 | // operator, so we have to do this via a method call (implicitly |
1903 | // EVEX encoding uses k0 to mean 'no mask') |
1904 | bool partial_store = nrows < transpose_size; |
1905 | auto k = partial_store ? kTail : k0; |
1906 | auto base = reg_tr_src_tmp; |
1907 | base.setOpmaskIdx(k.getIdx(), true); |
1908 | |
1909 | auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
1910 | vmovups(addr, r); |
1911 | }; |
1912 | |
1913 | auto transpose16x8 = [=](int base_idx) { |
1914 | assert(base_idx == 0 || base_idx == 8); |
1915 | |
1916 | // swap 1 |
1917 | for (int i = 0; i < 4; i++) { |
1918 | int src_idx0 = base_idx + i * 2; |
1919 | int src_idx1 = src_idx0 + 1; |
1920 | |
1921 | int next_src_idx0 = src_idx0 + 2; |
1922 | int next_src_idx1 = src_idx1 + 2; |
1923 | bool load_next = base_idx == 0 || i < 3; |
1924 | |
1925 | if (base_idx == 0 && i == 0) { |
1926 | load(src_idx0); |
1927 | if (src_idx1 < nrows) |
1928 | load(src_idx1); |
1929 | else |
1930 | vpxord(src_zmm(src_idx1), src_zmm(src_idx1), |
1931 | src_zmm(src_idx1)); |
1932 | } |
1933 | |
1934 | auto tmp0 = tmp_zmm(src_idx0); |
1935 | auto tmp1 = tmp_zmm(src_idx1); |
1936 | auto src0 = src_zmm(src_idx0); |
1937 | auto src1 = src_zmm(src_idx1); |
1938 | |
1939 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
1940 | valignd(tmp0, src0, src0, 0x1); |
1941 | |
1942 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
1943 | valignd(tmp1, src1, src1, 0xf); |
1944 | |
1945 | vmovaps(src0 | kAAAA, tmp1); |
1946 | vmovaps(src1 | k5555, tmp0); |
1947 | } |
1948 | // swap 2 |
1949 | for (int i = 0; i < 4; i++) { |
1950 | int select_half = (i < 2) ? 0 : 2; |
1951 | int src_idx0 = base_idx + i + select_half + 0; |
1952 | int src_idx2 = src_idx0 + 2; |
1953 | |
1954 | auto tmp0 = tmp_zmm(src_idx0); |
1955 | auto tmp1 = tmp_zmm(src_idx2); |
1956 | auto src0 = src_zmm(src_idx0); |
1957 | auto src2 = src_zmm(src_idx2); |
1958 | |
1959 | valignd(tmp0, src0, src0, 0x2); |
1960 | valignd(tmp1, src2, src2, 0xe); |
1961 | vmovaps(src2 | k3333, tmp0); |
1962 | vmovaps(src0 | kCCCC, tmp1); |
1963 | } |
1964 | |
1965 | // swap 4 |
1966 | for (int i = 0; i < 4; i++) { |
1967 | int src_idx0 = base_idx + i; |
1968 | int src_idx4 = src_idx0 + 4; |
1969 | |
1970 | auto tmp0 = tmp_zmm(src_idx0); |
1971 | auto src0 = src_zmm(src_idx0); |
1972 | auto src4 = src_zmm(src_idx4); |
1973 | |
1974 | vmovaps(tmp0, src0); |
1975 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
1976 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
1977 | } |
1978 | }; |
1979 | |
1980 | auto fixup16x16 = [=]() { |
1981 | // swap 8 |
1982 | for (int i = 0; i < 8; i++) { |
1983 | auto tmp = tmp_zmm(i); |
1984 | auto src0 = src_zmm(i); |
1985 | auto src8 = src_zmm(8 + i); |
1986 | vshuff64x2(tmp, src0, src8, 0x44); |
1987 | store(tmp, i); |
1988 | } |
1989 | |
1990 | for (int i = 0; i < 8; i++) { |
1991 | auto tmp = tmp_zmm(8 + i); |
1992 | auto src0 = src_zmm(i); |
1993 | auto src8 = src_zmm(8 + i); |
1994 | vshuff64x2(tmp, src0, src8, 0xee); |
1995 | store(tmp, 8 + i); |
1996 | } |
1997 | }; |
1998 | |
1999 | transpose16x8(0); |
2000 | transpose16x8(8); |
2001 | fixup16x16(); |
2002 | } |
2003 | |
2004 | void jit_brgemm_trans_wei_f32_t::generate() { |
2005 | preamble(); |
2006 | assert(conf_->oc_block % transpose_size == 0); |
2007 | int fwd_ic_block = conf_->simd_w; |
2008 | int fwd_oc_block = 0; |
2009 | switch (conf_->wei_tag) { |
2010 | case OI16i64o: |
2011 | case OIw16i64o: |
2012 | case OIhw16i64o: |
2013 | case OIdhw16i64o: |
2014 | case OI8i64o2i: |
2015 | case OIw8i64o2i: |
2016 | case OIhw8i64o2i: |
2017 | case OIdhw8i64o2i: |
2018 | case OI16i64o2i: |
2019 | case OIw16i64o2i: |
2020 | case OIhw16i64o2i: |
2021 | case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break; |
2022 | case OI16i32o: |
2023 | case OIw16i32o: |
2024 | case OIhw16i32o: |
2025 | case OIdhw16i32o: |
2026 | case OI8i32o2i: |
2027 | case OIw8i32o2i: |
2028 | case OIhw8i32o2i: |
2029 | case OIdhw8i32o2i: |
2030 | case OI16i32o2i: |
2031 | case OIw16i32o2i: |
2032 | case OIhw16i32o2i: |
2033 | case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break; |
2034 | default: fwd_oc_block = conf_->simd_w; |
2035 | }; |
2036 | |
2037 | int oc_tail = conf_->K_tail % transpose_size; |
2038 | int ic_block = conf_->ic_block; |
2039 | int ic_tail = conf_->N_tail % transpose_size; |
2040 | src_stride = fwd_oc_block * typesize; |
2041 | tr_src_stride = ic_block * typesize; |
2042 | dim_t N_src_shift = conf_->kd * conf_->kh * conf_->kw * fwd_ic_block |
2043 | * fwd_oc_block * typesize; |
2044 | dim_t N_tr_src_shift = conf_->simd_w * typesize; |
2045 | dim_t K_src_shift = conf_->simd_w * typesize; |
2046 | dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize; |
2047 | |
2048 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
2049 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
2050 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
2051 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
2052 | |
2053 | auto kmovw = [=](Opmask k, unsigned w) { |
2054 | mov(regw_tmp, w); |
2055 | jit_generator::kmovw(k, regw_tmp); |
2056 | }; |
2057 | |
2058 | kmovw(k3333, 0x3333); // 0011001100110011 |
2059 | kmovw(k5555, 0x5555); // 0101010101010101 |
2060 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
2061 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
2062 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
2063 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
2064 | |
2065 | auto compute_N = [=](bool is_oc_tail) { |
2066 | mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]); |
2067 | mov(reg_src, reg_src_base); |
2068 | mov(reg_tr_src, reg_tr_src_base); |
2069 | Label N_loop, N_loop_tail; |
2070 | |
2071 | cmp(reg_loop_N, transpose_size); |
2072 | jl(N_loop_tail, T_NEAR); |
2073 | |
2074 | L(N_loop); |
2075 | |
2076 | transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size); |
2077 | add(reg_src, N_src_shift); |
2078 | add(reg_tr_src, N_tr_src_shift); |
2079 | |
2080 | sub(reg_loop_N, transpose_size); |
2081 | cmp(reg_loop_N, transpose_size); |
2082 | jge(N_loop, T_NEAR); |
2083 | |
2084 | L(N_loop_tail); |
2085 | if (ic_tail > 0) { |
2086 | Label N_loop_done; |
2087 | cmp(reg_loop_N, 0); |
2088 | jle(N_loop_done, T_NEAR); |
2089 | transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size); |
2090 | L(N_loop_done); |
2091 | } |
2092 | }; |
2093 | |
2094 | Label K_loop, K_tail; |
2095 | if (oc_tail > 0) { |
2096 | cmp(reg_loop_K, transpose_size); |
2097 | jl(K_tail, T_NEAR); |
2098 | } |
2099 | |
2100 | L(K_loop); |
2101 | compute_N(false); |
2102 | add(reg_src_base, K_src_shift); |
2103 | add(reg_tr_src_base, K_tr_src_shift); |
2104 | |
2105 | sub(reg_loop_K, transpose_size); |
2106 | cmp(reg_loop_K, transpose_size); |
2107 | jge(K_loop, T_NEAR); |
2108 | |
2109 | L(K_tail); |
2110 | if (oc_tail > 0) { |
2111 | Label K_loop_done; |
2112 | cmp(reg_loop_K, 0); |
2113 | jle(K_loop_done, T_NEAR); |
2114 | |
2115 | compute_N(true); |
2116 | L(K_loop_done); |
2117 | } |
2118 | |
2119 | postamble(); |
2120 | } |
2121 | |
2122 | struct jit_brgemm_trans_wei_bf16_t : public jit_brgemm_trans_wei_t, |
2123 | public jit_generator { |
2124 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_bf16_t) |
2125 | |
2126 | jit_brgemm_trans_wei_bf16_t(const jit_brgemm_primitive_conf_t *conf) |
2127 | : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {} |
2128 | |
2129 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
2130 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
2131 | |
2132 | private: |
2133 | using reg64_t = const Xbyak::Reg64; |
2134 | using reg32_t = const Xbyak::Reg32; |
2135 | using opmask_t = const Xbyak::Opmask; |
2136 | using zmm = const Xbyak::Zmm; |
2137 | |
2138 | enum { typesize = sizeof(int16_t), transpose_size = 16 }; |
2139 | dim_t src_stride = 0, tr_src_stride = 0; |
2140 | |
2141 | opmask_t kTail = k7; |
2142 | |
2143 | reg64_t reg_src_base = rax; |
2144 | reg64_t reg_tr_src_base = rbx; |
2145 | |
2146 | reg64_t reg_src = r8; |
2147 | reg64_t reg_tr_src = r9; |
2148 | reg64_t reg_loop_N = r10; |
2149 | reg64_t reg_loop_K = r11; |
2150 | reg64_t reg_loop_batch = r12; |
2151 | reg64_t reg_tr_src_tmp = r13; |
2152 | reg32_t regw_tmp = r14d; |
2153 | reg64_t imm_addr64 = r15; |
2154 | |
2155 | zmm v_abcdefgh_to_abefcdgh = zmm31; |
2156 | |
2157 | void transpose_16x16_vnni(int nrows, int ncolumns = transpose_size); |
2158 | void generate() override; |
2159 | }; |
2160 | |
2161 | void jit_brgemm_trans_wei_bf16_t::transpose_16x16_vnni( |
2162 | int nrows, int ncolumns) { |
2163 | assert(nrows >= 0 && nrows <= transpose_size); |
2164 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
2165 | if (!nrows) return; |
2166 | |
2167 | auto src_zmm = [=](int i) { |
2168 | assert(i >= 0 && i < 8); |
2169 | return Zmm(i); |
2170 | }; |
2171 | |
2172 | auto tmp_zmm = [=](int i) { |
2173 | assert(i >= 0 && i < 8); |
2174 | return Zmm(8 + i); |
2175 | }; |
2176 | |
2177 | auto kmovw = [=](Opmask k, unsigned w) { |
2178 | mov(regw_tmp, w); |
2179 | jit_generator::kmovw(k, regw_tmp); |
2180 | }; |
2181 | |
2182 | auto load = [=](int i) { |
2183 | auto src_load = src_zmm(i); |
2184 | if (ncolumns < transpose_size) { |
2185 | kmovw(kTail, (1 << ncolumns) - 1); |
2186 | src_load = src_zmm(i) | kTail | T_z; |
2187 | } |
2188 | vmovups(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
2189 | }; |
2190 | |
2191 | auto store = [=](Zmm r, int i) { |
2192 | mov(reg_tr_src_tmp, reg_tr_src); |
2193 | if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1); |
2194 | |
2195 | // Xbyak does not allow k0 to be specified explicitly via the '|' |
2196 | // operator, so we have to do this via a method call (implicitly |
2197 | // EVEX encoding uses k0 to mean 'no mask') |
2198 | bool partial_store = nrows < transpose_size; |
2199 | auto k = partial_store ? kTail : k0; |
2200 | auto base = reg_tr_src_tmp; |
2201 | base.setOpmaskIdx(k.getIdx(), true); |
2202 | |
2203 | auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
2204 | vmovups(addr, r); |
2205 | }; |
2206 | |
2207 | for (int i = 0; i < 8; i++) |
2208 | load(i); |
2209 | |
2210 | for (int i = 0; i < 8; i++) |
2211 | vpshufb(src_zmm(i), src_zmm(i), v_abcdefgh_to_abefcdgh); |
2212 | |
2213 | for (int i = 0; i < 2; i++) { |
2214 | vpunpcklqdq(tmp_zmm(2 * i + 0), src_zmm(2 * i), src_zmm(2 * i + 1)); |
2215 | vpunpckhqdq(tmp_zmm(2 * i + 1), src_zmm(2 * i), src_zmm(2 * i + 1)); |
2216 | } |
2217 | |
2218 | for (int i = 0; i < 2; i++) { |
2219 | vpunpcklqdq( |
2220 | src_zmm(2 * i + 0), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1)); |
2221 | vpunpckhqdq( |
2222 | src_zmm(2 * i + 1), src_zmm(4 + 2 * i), src_zmm(4 + 2 * i + 1)); |
2223 | } |
2224 | |
2225 | for (int i = 0; i < 2; i++) { |
2226 | vshufi32x4(src_zmm(4 + 0 + i), tmp_zmm(i), tmp_zmm(2 + i), 0x88); |
2227 | vshufi32x4(src_zmm(4 + 2 + i), tmp_zmm(i), tmp_zmm(2 + i), 0xdd); |
2228 | } |
2229 | |
2230 | for (int i = 0; i < 2; i++) { |
2231 | vshufi32x4(tmp_zmm(0 + i), src_zmm(i), src_zmm(2 + i), 0x88); |
2232 | vshufi32x4(tmp_zmm(2 + i), src_zmm(i), src_zmm(2 + i), 0xdd); |
2233 | } |
2234 | |
2235 | for (int i = 0; i < 4; i++) |
2236 | vshufi32x4(src_zmm(i), src_zmm(4 + i), tmp_zmm(i), 0x88); |
2237 | |
2238 | for (int i = 0; i < 4; i++) |
2239 | vshufi32x4(src_zmm(4 + i), src_zmm(4 + i), tmp_zmm(i), 0xdd); |
2240 | |
2241 | for (int i = 0; i < 8; i++) |
2242 | store(src_zmm(i), i); |
2243 | } |
2244 | |
2245 | void jit_brgemm_trans_wei_bf16_t::generate() { |
2246 | preamble(); |
2247 | int fwd_oc_block = 0; |
2248 | switch (conf_->wei_tag) { |
2249 | case OI16i64o: |
2250 | case OIw16i64o: |
2251 | case OIhw16i64o: |
2252 | case OIdhw16i64o: |
2253 | case OI8i64o2i: |
2254 | case OIw8i64o2i: |
2255 | case OIhw8i64o2i: |
2256 | case OIdhw8i64o2i: |
2257 | case OI16i64o2i: |
2258 | case OIw16i64o2i: |
2259 | case OIhw16i64o2i: |
2260 | case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break; |
2261 | case OI16i32o: |
2262 | case OIw16i32o: |
2263 | case OIhw16i32o: |
2264 | case OIdhw16i32o: |
2265 | case OI8i32o2i: |
2266 | case OIw8i32o2i: |
2267 | case OIhw8i32o2i: |
2268 | case OIdhw8i32o2i: |
2269 | case OI16i32o2i: |
2270 | case OIw16i32o2i: |
2271 | case OIhw16i32o2i: |
2272 | case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break; |
2273 | default: fwd_oc_block = conf_->simd_w; |
2274 | }; |
2275 | |
2276 | int oc_tail = conf_->K_tail % transpose_size; |
2277 | int ic_block = conf_->ic_block; |
2278 | int ic_tail = conf_->N_tail % transpose_size; |
2279 | src_stride = 2 * fwd_oc_block * typesize; |
2280 | tr_src_stride = 2 * ic_block * typesize; |
2281 | dim_t N_src_shift = conf_->simd_w * fwd_oc_block * typesize; |
2282 | dim_t N_tr_src_shift = 2 * conf_->simd_w * typesize; |
2283 | dim_t K_src_shift = 2 * conf_->simd_w * typesize; |
2284 | dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize; |
2285 | |
2286 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
2287 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
2288 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
2289 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
2290 | |
2291 | alignas(64) static constexpr const int32_t abcdefgh_to_abefcdgh[16] |
2292 | = {0x05040100, 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100, |
2293 | 0x07060302, 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302, |
2294 | 0x0d0c0908, 0x0f0e0b0a, 0x05040100, 0x07060302, 0x0d0c0908, |
2295 | 0x0f0e0b0a}; |
2296 | |
2297 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
2298 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
2299 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
2300 | }; |
2301 | |
2302 | vmovdqa64(v_abcdefgh_to_abefcdgh, (const int64_t *)abcdefgh_to_abefcdgh); |
2303 | auto compute_N = [=](bool is_oc_tail) { |
2304 | mov(reg_src, reg_src_base); |
2305 | mov(reg_tr_src, reg_tr_src_base); |
2306 | mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]); |
2307 | |
2308 | Label N_loop, N_loop_tail; |
2309 | cmp(reg_loop_N, transpose_size); |
2310 | jl(N_loop_tail, T_NEAR); |
2311 | |
2312 | L(N_loop); |
2313 | |
2314 | transpose_16x16_vnni( |
2315 | transpose_size, is_oc_tail ? oc_tail : transpose_size); |
2316 | add(reg_src, N_src_shift); |
2317 | add(reg_tr_src, N_tr_src_shift); |
2318 | |
2319 | sub(reg_loop_N, transpose_size); |
2320 | cmp(reg_loop_N, transpose_size); |
2321 | jge(N_loop, T_NEAR); |
2322 | |
2323 | L(N_loop_tail); |
2324 | if (ic_tail > 0) { |
2325 | Label N_loop_done; |
2326 | cmp(reg_loop_N, 0); |
2327 | jle(N_loop_done, T_NEAR); |
2328 | transpose_16x16_vnni( |
2329 | ic_tail, is_oc_tail ? oc_tail : transpose_size); |
2330 | L(N_loop_done); |
2331 | } |
2332 | }; |
2333 | |
2334 | Label K_loop, K_tail; |
2335 | if (oc_tail > 0) { |
2336 | cmp(reg_loop_K, transpose_size); |
2337 | jl(K_tail, T_NEAR); |
2338 | } |
2339 | |
2340 | L(K_loop); |
2341 | compute_N(false); |
2342 | add(reg_src_base, K_src_shift); |
2343 | add(reg_tr_src_base, K_tr_src_shift); |
2344 | |
2345 | sub(reg_loop_K, transpose_size); |
2346 | cmp(reg_loop_K, transpose_size); |
2347 | jge(K_loop, T_NEAR); |
2348 | |
2349 | L(K_tail); |
2350 | if (oc_tail > 0) { |
2351 | Label K_loop_done; |
2352 | cmp(reg_loop_K, 0); |
2353 | jle(K_loop_done, T_NEAR); |
2354 | compute_N(true); |
2355 | L(K_loop_done); |
2356 | } |
2357 | |
2358 | postamble(); |
2359 | } |
2360 | |
2361 | struct jit_brgemm_trans_wei_f16_t : public jit_brgemm_trans_wei_t, |
2362 | public jit_generator { |
2363 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_trans_wei_f16_t) |
2364 | |
2365 | jit_brgemm_trans_wei_f16_t(const jit_brgemm_primitive_conf_t *conf) |
2366 | : jit_brgemm_trans_wei_t(conf), jit_generator(jit_name()) {} |
2367 | |
2368 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
2369 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
2370 | |
2371 | private: |
2372 | using reg64_t = const Xbyak::Reg64; |
2373 | using reg32_t = const Xbyak::Reg32; |
2374 | using opmask_t = const Xbyak::Opmask; |
2375 | |
2376 | enum { |
2377 | typesize_in = sizeof(int16_t), |
2378 | typesize_out = sizeof(float), |
2379 | transpose_size = 16 |
2380 | }; |
2381 | dim_t src_stride = 0, tr_src_stride = 0; |
2382 | |
2383 | opmask_t k3333 = k1; |
2384 | opmask_t k5555 = k2; |
2385 | opmask_t kAAAA = k3; |
2386 | opmask_t kCCCC = k4; |
2387 | opmask_t k0F0F = k5; |
2388 | opmask_t kF0F0 = k6; |
2389 | opmask_t kTail = k7; |
2390 | |
2391 | reg64_t reg_src_base = rax; |
2392 | reg64_t reg_tr_src_base = rbx; |
2393 | |
2394 | reg64_t reg_src = r8; |
2395 | reg64_t reg_tr_src = r9; |
2396 | reg64_t reg_loop_N = r10; |
2397 | reg64_t reg_loop_K = r11; |
2398 | reg64_t reg_loop_batch = r12; |
2399 | reg64_t reg_tr_src_tmp = r13; |
2400 | reg32_t regw_tmp = r14d; |
2401 | reg64_t imm_addr64 = r15; |
2402 | |
2403 | Xbyak::Zmm v_abcdefgh_to_abefcdgh = zmm31; |
2404 | |
2405 | void transpose_16x16(int nrows, int ncolumns = transpose_size); |
2406 | void generate() override; |
2407 | }; |
2408 | |
2409 | void jit_brgemm_trans_wei_f16_t::transpose_16x16(int nrows, int ncolumns) { |
2410 | assert(nrows >= 0 && nrows <= transpose_size); |
2411 | static_assert(transpose_size == 16, "Unsupported transpose size" ); |
2412 | if (!nrows) return; |
2413 | |
2414 | auto src_zmm = [=](int i) { |
2415 | assert(i >= 0 && i < 16); |
2416 | return Zmm(i); |
2417 | }; |
2418 | |
2419 | auto tmp_zmm = [=](int i) { |
2420 | assert(i >= 0 && i < 16); |
2421 | return Zmm(16 + i); |
2422 | }; |
2423 | |
2424 | auto kmovw = [=](Opmask k, unsigned w) { |
2425 | mov(regw_tmp, w); |
2426 | jit_generator::kmovw(k, regw_tmp); |
2427 | }; |
2428 | |
2429 | auto load = [=](int i) { |
2430 | auto src_load = src_zmm(i); |
2431 | if (ncolumns < transpose_size) { |
2432 | kmovw(kTail, (1 << ncolumns) - 1); |
2433 | src_load = src_zmm(i) | kTail | T_z; |
2434 | } |
2435 | // TODO: Maybe do tranformations in fp16 data type and at the end |
2436 | // cvt to f32. Thus reducing instructions |
2437 | vcvtph2psx(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
2438 | }; |
2439 | |
2440 | auto store = [=](Zmm r, int i) { |
2441 | mov(reg_tr_src_tmp, reg_tr_src); |
2442 | if (nrows < transpose_size) kmovw(kTail, (1 << nrows) - 1); |
2443 | |
2444 | // Xbyak does not allow k0 to be specified explicitly via the '|' |
2445 | // operator, so we have to do this via a method call (implicitly |
2446 | // EVEX encoding uses k0 to mean 'no mask') |
2447 | bool partial_store = nrows < transpose_size; |
2448 | auto k = partial_store ? kTail : k0; |
2449 | auto base = reg_tr_src_tmp; |
2450 | base.setOpmaskIdx(k.getIdx(), true); |
2451 | |
2452 | auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
2453 | vmovups(addr, r); |
2454 | }; |
2455 | |
2456 | auto transpose16x8 = [=](int base_idx) { |
2457 | assert(base_idx == 0 || base_idx == 8); |
2458 | |
2459 | // swap 1 |
2460 | for (int i = 0; i < 4; i++) { |
2461 | int src_idx0 = base_idx + i * 2; |
2462 | int src_idx1 = src_idx0 + 1; |
2463 | |
2464 | int next_src_idx0 = src_idx0 + 2; |
2465 | int next_src_idx1 = src_idx1 + 2; |
2466 | bool load_next = base_idx == 0 || i < 3; |
2467 | |
2468 | if (base_idx == 0 && i == 0) { |
2469 | load(src_idx0); |
2470 | if (src_idx1 < nrows) |
2471 | load(src_idx1); |
2472 | else |
2473 | vpxord(src_zmm(src_idx1), src_zmm(src_idx1), |
2474 | src_zmm(src_idx1)); |
2475 | } |
2476 | |
2477 | auto tmp0 = tmp_zmm(src_idx0); |
2478 | auto tmp1 = tmp_zmm(src_idx1); |
2479 | auto src0 = src_zmm(src_idx0); |
2480 | auto src1 = src_zmm(src_idx1); |
2481 | |
2482 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
2483 | valignd(tmp0, src0, src0, 0x1); |
2484 | |
2485 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
2486 | valignd(tmp1, src1, src1, 0xf); |
2487 | |
2488 | vmovaps(src0 | kAAAA, tmp1); |
2489 | vmovaps(src1 | k5555, tmp0); |
2490 | } |
2491 | // swap 2 |
2492 | for (int i = 0; i < 4; i++) { |
2493 | int select_half = (i < 2) ? 0 : 2; |
2494 | int src_idx0 = base_idx + i + select_half + 0; |
2495 | int src_idx2 = src_idx0 + 2; |
2496 | |
2497 | auto tmp0 = tmp_zmm(src_idx0); |
2498 | auto tmp1 = tmp_zmm(src_idx2); |
2499 | auto src0 = src_zmm(src_idx0); |
2500 | auto src2 = src_zmm(src_idx2); |
2501 | |
2502 | valignd(tmp0, src0, src0, 0x2); |
2503 | valignd(tmp1, src2, src2, 0xe); |
2504 | vmovaps(src2 | k3333, tmp0); |
2505 | vmovaps(src0 | kCCCC, tmp1); |
2506 | } |
2507 | |
2508 | // swap 4 |
2509 | for (int i = 0; i < 4; i++) { |
2510 | int src_idx0 = base_idx + i; |
2511 | int src_idx4 = src_idx0 + 4; |
2512 | |
2513 | auto tmp0 = tmp_zmm(src_idx0); |
2514 | auto src0 = src_zmm(src_idx0); |
2515 | auto src4 = src_zmm(src_idx4); |
2516 | |
2517 | vmovaps(tmp0, src0); |
2518 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
2519 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
2520 | } |
2521 | }; |
2522 | |
2523 | auto fixup16x16 = [=]() { |
2524 | // swap 8 |
2525 | for (int i = 0; i < 8; i++) { |
2526 | auto tmp = tmp_zmm(i); |
2527 | auto src0 = src_zmm(i); |
2528 | auto src8 = src_zmm(8 + i); |
2529 | vshuff64x2(tmp, src0, src8, 0x44); |
2530 | store(tmp, i); |
2531 | } |
2532 | |
2533 | for (int i = 0; i < 8; i++) { |
2534 | auto tmp = tmp_zmm(8 + i); |
2535 | auto src0 = src_zmm(i); |
2536 | auto src8 = src_zmm(8 + i); |
2537 | vshuff64x2(tmp, src0, src8, 0xee); |
2538 | store(tmp, 8 + i); |
2539 | } |
2540 | }; |
2541 | |
2542 | transpose16x8(0); |
2543 | transpose16x8(8); |
2544 | fixup16x16(); |
2545 | } |
2546 | |
2547 | void jit_brgemm_trans_wei_f16_t::generate() { |
2548 | preamble(); |
2549 | assert(conf_->oc_block % transpose_size == 0); |
2550 | int fwd_ic_block = conf_->simd_w; |
2551 | int fwd_oc_block = 0; |
2552 | switch (conf_->wei_tag) { |
2553 | case OI16i64o: |
2554 | case OIw16i64o: |
2555 | case OIhw16i64o: |
2556 | case OIdhw16i64o: |
2557 | case OI8i64o2i: |
2558 | case OIw8i64o2i: |
2559 | case OIhw8i64o2i: |
2560 | case OIdhw8i64o2i: |
2561 | case OI16i64o2i: |
2562 | case OIw16i64o2i: |
2563 | case OIhw16i64o2i: |
2564 | case OIdhw16i64o2i: fwd_oc_block = 4 * conf_->simd_w; break; |
2565 | case OI16i32o: |
2566 | case OIw16i32o: |
2567 | case OIhw16i32o: |
2568 | case OIdhw16i32o: |
2569 | case OI8i32o2i: |
2570 | case OIw8i32o2i: |
2571 | case OIhw8i32o2i: |
2572 | case OIdhw8i32o2i: |
2573 | case OI16i32o2i: |
2574 | case OIw16i32o2i: |
2575 | case OIhw16i32o2i: |
2576 | case OIdhw16i32o2i: fwd_oc_block = 2 * conf_->simd_w; break; |
2577 | default: fwd_oc_block = conf_->simd_w; |
2578 | }; |
2579 | |
2580 | int oc_tail = conf_->K_tail % transpose_size; |
2581 | int ic_block = conf_->ic_block; |
2582 | int ic_tail = conf_->N_tail % transpose_size; |
2583 | src_stride = fwd_oc_block * typesize_in; |
2584 | tr_src_stride = ic_block * typesize_out; |
2585 | dim_t N_src_shift = (conf_->kd * conf_->kh * conf_->kw * fwd_ic_block) |
2586 | * fwd_oc_block * typesize_in; |
2587 | dim_t N_tr_src_shift = conf_->simd_w * typesize_out; |
2588 | dim_t K_src_shift = conf_->simd_w * typesize_in; |
2589 | dim_t K_tr_src_shift = conf_->ic_block * conf_->simd_w * typesize_out; |
2590 | |
2591 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
2592 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
2593 | mov(reg_loop_batch, ptr[param1 + GET_OFF(current_gemm_batch)]); |
2594 | mov(reg_loop_K, ptr[param1 + GET_OFF(current_K)]); |
2595 | |
2596 | auto kmovw = [=](Opmask k, unsigned w) { |
2597 | mov(regw_tmp, w); |
2598 | jit_generator::kmovw(k, regw_tmp); |
2599 | }; |
2600 | |
2601 | kmovw(k3333, 0x3333); // 0011001100110011 |
2602 | kmovw(k5555, 0x5555); // 0101010101010101 |
2603 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
2604 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
2605 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
2606 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
2607 | |
2608 | auto compute_N = [=](bool is_oc_tail) { |
2609 | mov(reg_loop_N, ptr[param1 + GET_OFF(current_N)]); |
2610 | mov(reg_src, reg_src_base); |
2611 | mov(reg_tr_src, reg_tr_src_base); |
2612 | Label N_loop, N_loop_tail; |
2613 | |
2614 | cmp(reg_loop_N, transpose_size); |
2615 | jl(N_loop_tail, T_NEAR); |
2616 | |
2617 | L(N_loop); |
2618 | |
2619 | transpose_16x16(transpose_size, is_oc_tail ? oc_tail : transpose_size); |
2620 | add(reg_src, N_src_shift); |
2621 | add(reg_tr_src, N_tr_src_shift); |
2622 | |
2623 | sub(reg_loop_N, transpose_size); |
2624 | cmp(reg_loop_N, transpose_size); |
2625 | jge(N_loop, T_NEAR); |
2626 | |
2627 | L(N_loop_tail); |
2628 | if (ic_tail > 0) { |
2629 | Label N_loop_done; |
2630 | cmp(reg_loop_N, 0); |
2631 | jle(N_loop_done, T_NEAR); |
2632 | transpose_16x16(ic_tail, is_oc_tail ? oc_tail : transpose_size); |
2633 | L(N_loop_done); |
2634 | } |
2635 | }; |
2636 | |
2637 | Label K_loop, K_tail; |
2638 | if (oc_tail > 0) { |
2639 | cmp(reg_loop_K, transpose_size); |
2640 | jl(K_tail, T_NEAR); |
2641 | } |
2642 | |
2643 | L(K_loop); |
2644 | compute_N(false); |
2645 | add(reg_src_base, K_src_shift); |
2646 | add(reg_tr_src_base, K_tr_src_shift); |
2647 | |
2648 | sub(reg_loop_K, transpose_size); |
2649 | cmp(reg_loop_K, transpose_size); |
2650 | jge(K_loop, T_NEAR); |
2651 | |
2652 | L(K_tail); |
2653 | if (oc_tail > 0) { |
2654 | Label K_loop_done; |
2655 | cmp(reg_loop_K, 0); |
2656 | jle(K_loop_done, T_NEAR); |
2657 | |
2658 | compute_N(true); |
2659 | L(K_loop_done); |
2660 | } |
2661 | |
2662 | postamble(); |
2663 | } |
2664 | |
2665 | struct jit_amx_ip_trans_diff_wei_to_vnni_t : public jit_amx_ip_trans_diff_wei, |
2666 | public jit_generator { |
2667 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_amx_ip_trans_diff_wei_to_vnni) |
2668 | |
2669 | jit_amx_ip_trans_diff_wei_to_vnni_t(const jit_brgemm_primitive_conf_t *jbgp, |
2670 | const int ext_ic_block, const int ext_oc_block) |
2671 | : jit_amx_ip_trans_diff_wei(jbgp, ext_ic_block, ext_oc_block) |
2672 | , jit_generator(jit_name()) {} |
2673 | |
2674 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
2675 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
2676 | |
2677 | private: |
2678 | void generate() override; |
2679 | }; |
2680 | |
2681 | void jit_amx_ip_trans_diff_wei_to_vnni_t::generate() { |
2682 | const int typesize_out = 2; |
2683 | const int typesize_acc = 4; |
2684 | const int simd_w = 16; |
2685 | |
2686 | using reg64_t = const Xbyak::Reg64; |
2687 | using reg32_t = const Xbyak::Reg32; |
2688 | |
2689 | const reg64_t ®_output = r15; |
2690 | const reg64_t ®_input = r14; |
2691 | const reg64_t ®_prm_table = r13; |
2692 | const reg64_t ®_last_ic_block = r12; |
2693 | const reg64_t ®_last_oc_block = r11; |
2694 | const reg32_t ®w_tmp = r10d; |
2695 | |
2696 | const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31); |
2697 | auto get_zmm_src = [&](int ic) { return Xbyak::Zmm(ic % 8); }; |
2698 | |
2699 | Xbyak::Label prm_table; |
2700 | Xbyak::Label skip_oc_tail, to_exit; |
2701 | |
2702 | Xbyak::Opmask load_mask = k4; |
2703 | |
2704 | int tail_mask = (jbgp_->N_tail % simd_w) |
2705 | ? (1 << (jbgp_->N_tail % simd_w)) - 1 |
2706 | : 0xffff; |
2707 | auto kmovw = [=](Xbyak::Opmask k, unsigned w) { |
2708 | mov(regw_tmp, w); |
2709 | jit_generator::kmovw(k, regw_tmp); |
2710 | }; |
2711 | |
2712 | auto reorder_oc_block = [&](int icb, int ic_block, bool is_oc_tail) { |
2713 | // INP: [64i][No] : FP32 |
2714 | // OUT: [OCB][ICB][16i][No][2i]: BF16 |
2715 | if (ic_block <= 0) return; |
2716 | |
2717 | dim_t inp_icb_offset = typesize_acc |
2718 | * (icb * ext_ic_block_ * jbgp_->oc_block); // Internal |
2719 | dim_t out_icb_offset = typesize_out |
2720 | * (icb * div_up(ext_ic_block_, 2) * ext_oc_block_ |
2721 | * 2); // External |
2722 | |
2723 | const int oc_padded = rnd_up(jbgp_->oc, jbgp_->oc_block); |
2724 | const int oc_padded_ext = rnd_up(jbgp_->oc, ext_oc_block_); |
2725 | |
2726 | bool tailing_done = false; |
2727 | for (int oc = 0; oc < jbgp_->oc_block; oc += simd_w) { |
2728 | int ext_oc = oc % ext_oc_block_; |
2729 | int ext_ocb = oc / ext_oc_block_; |
2730 | dim_t ext_ocb_offset = typesize_out |
2731 | * (ext_ocb * div_up(jbgp_->ic, ext_ic_block_) |
2732 | * div_up(ext_ic_block_, 2) * ext_oc_block_ * 2); |
2733 | if (is_oc_tail && oc_padded != oc_padded_ext |
2734 | && oc + simd_w > ext_oc_block_) |
2735 | break; |
2736 | dim_t inp_offset = inp_icb_offset + typesize_acc * (oc); // Internal |
2737 | dim_t out_offset = out_icb_offset + typesize_out * (ext_oc * 2) |
2738 | + ext_ocb_offset; // External |
2739 | kmovw(load_mask, 0xffff); |
2740 | if (is_oc_tail) { |
2741 | if (jbgp_->N_tail && (oc + simd_w) >= jbgp_->N_tail) { |
2742 | if (tailing_done == false) { |
2743 | kmovw(load_mask, tail_mask); |
2744 | tailing_done = true; |
2745 | } else { |
2746 | auto zmm_src_0 = get_zmm_src(0); |
2747 | vpxord(zmm_src_0, zmm_src_0, zmm_src_0); |
2748 | for (int ic = 0; ic < ext_ic_block_ / 2; ic++) { |
2749 | vmovups(ptr[reg_output + out_offset |
2750 | + typesize_out |
2751 | * (ic * ext_oc_block_ * 2)], |
2752 | zmm_src_0); |
2753 | } |
2754 | continue; |
2755 | } |
2756 | } |
2757 | } |
2758 | |
2759 | int ic = 0; |
2760 | for (; ic < ic_block / 2; ic++) { |
2761 | int ic1 = 2 * ic; |
2762 | int ic2 = 2 * ic + 1; |
2763 | |
2764 | auto zmm_src_0 = get_zmm_src(ic1); |
2765 | auto zmm_src_1 = get_zmm_src(ic2); |
2766 | |
2767 | vmovups(zmm_src_0 | load_mask | T_z, |
2768 | ptr[reg_input + inp_offset |
2769 | + typesize_acc * (ic1 * jbgp_->oc_block)]); |
2770 | vmovups(zmm_src_1 | load_mask | T_z, |
2771 | ptr[reg_input + inp_offset |
2772 | + typesize_acc * (ic2 * jbgp_->oc_block)]); |
2773 | if (jbgp_->wei_dt == data_type::bf16) { |
2774 | vcvtne2ps2bf16(zmm_src_0, zmm_src_1, zmm_src_0); |
2775 | } else { |
2776 | assert(jbgp_->wei_dt == data_type::f16); |
2777 | vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0); |
2778 | vcvtps2phx(Ymm(zmm_src_1.getIdx()), zmm_src_1); |
2779 | vinsertf32x8( |
2780 | zmm_src_0, zmm_src_0, Ymm(zmm_src_1.getIdx()), 1); |
2781 | } |
2782 | vpermw(zmm_src_0, zmm_idx, zmm_src_0); |
2783 | |
2784 | vmovups(ptr[reg_output + out_offset |
2785 | + typesize_out * (ic * ext_oc_block_ * 2)], |
2786 | zmm_src_0); |
2787 | } |
2788 | if (ic_block % 2) { |
2789 | int ic1 = 2 * ic; |
2790 | auto zmm_src_0 = get_zmm_src(ic1); |
2791 | |
2792 | vmovups(zmm_src_0 | load_mask | T_z, |
2793 | ptr[reg_input + inp_offset |
2794 | + typesize_acc * (ic1 * jbgp_->oc_block)]); |
2795 | |
2796 | if (jbgp_->wei_dt == data_type::bf16) { |
2797 | vcvtneps2bf16(Ymm(zmm_src_0.getIdx()), zmm_src_0); |
2798 | } else { |
2799 | assert(jbgp_->wei_dt == data_type::f16); |
2800 | vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0); |
2801 | } |
2802 | vpermw(zmm_src_0, zmm_idx, zmm_src_0); |
2803 | |
2804 | vmovups(ptr[reg_output + out_offset |
2805 | + typesize_out * (ic * ext_oc_block_ * 2)], |
2806 | zmm_src_0); |
2807 | ic++; |
2808 | } |
2809 | if (ic < ext_ic_block_ / 2) { |
2810 | auto zmm_src_0 = get_zmm_src(0); |
2811 | vpxord(zmm_src_0, zmm_src_0, zmm_src_0); |
2812 | for (; ic < ext_ic_block_ / 2; ic++) { |
2813 | vmovups(ptr[reg_output + out_offset |
2814 | + typesize_out * (ic * ext_oc_block_ * 2)], |
2815 | zmm_src_0); |
2816 | } |
2817 | } |
2818 | } |
2819 | }; |
2820 | |
2821 | auto reorder_ic_block = [&](bool is_oc_tail, bool is_ic_tail) { |
2822 | int nb_ic = div_up(jbgp_->ic_block, ext_ic_block_); |
2823 | for (int icb = 0; icb < nb_ic; icb++) { |
2824 | int ic_0 = icb * ext_ic_block_; |
2825 | int ic_1 = (icb + 1) * ext_ic_block_; |
2826 | if (is_ic_tail) { |
2827 | int ext_ic_tail = (jbgp_->ic % ext_ic_block_) |
2828 | ? (jbgp_->ic % ext_ic_block_) |
2829 | : ext_ic_block_; |
2830 | if (jbgp_->M_tail && ic_0 >= jbgp_->M_tail) break; |
2831 | if (jbgp_->M_tail && ic_0 <= jbgp_->M_tail |
2832 | && jbgp_->M_tail <= ic_1) { |
2833 | reorder_oc_block(icb, ext_ic_tail, is_oc_tail); |
2834 | } else { |
2835 | reorder_oc_block(icb, ext_ic_block_, is_oc_tail); |
2836 | } |
2837 | } else { |
2838 | reorder_oc_block(icb, ext_ic_block_, is_oc_tail); |
2839 | } |
2840 | } |
2841 | }; |
2842 | |
2843 | auto reorder = [&](bool is_oc_tail) { |
2844 | Xbyak::Label skip_ic_tail, to_exit_1; |
2845 | |
2846 | cmp(reg_last_ic_block, 0); |
2847 | je(skip_ic_tail, T_NEAR); |
2848 | |
2849 | reorder_ic_block(is_oc_tail, true); |
2850 | jmp(to_exit, T_NEAR); |
2851 | |
2852 | L(skip_ic_tail); |
2853 | reorder_ic_block(is_oc_tail, false); |
2854 | |
2855 | L(to_exit_1); |
2856 | }; |
2857 | |
2858 | preamble(); |
2859 | |
2860 | mov(reg_input, ptr[abi_param1 + GET_OFF(src)]); |
2861 | mov(reg_output, ptr[abi_param1 + GET_OFF(dst)]); |
2862 | mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]); |
2863 | mov(reg_last_oc_block, ptr[abi_param1 + GET_OFF(last_oc_block)]); |
2864 | |
2865 | mov(reg_prm_table, prm_table); |
2866 | vmovups(zmm_idx, ptr[reg_prm_table]); |
2867 | |
2868 | cmp(reg_last_oc_block, 0); |
2869 | je(skip_oc_tail, T_NEAR); |
2870 | |
2871 | reorder(true); |
2872 | jmp(to_exit, T_NEAR); |
2873 | |
2874 | L(skip_oc_tail); |
2875 | reorder(false); |
2876 | |
2877 | L(to_exit); |
2878 | postamble(); |
2879 | |
2880 | align(64); |
2881 | L(prm_table); |
2882 | const uint16_t prm_array[32] |
2883 | = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, |
2884 | 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
2885 | for (size_t i = 0; i < 32; ++i) |
2886 | dw(prm_array[i]); |
2887 | } |
2888 | |
2889 | #undef GET_OFF |
2890 | |
2891 | status_t create_brgemm_trans_src( |
2892 | std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker, |
2893 | const jit_brgemm_primitive_conf_t *conf) { |
2894 | |
2895 | if (conf->prop_kind != dnnl_backward_weights) |
2896 | return status::invalid_arguments; |
2897 | |
2898 | if (conf->src_dt == data_type::f32) { |
2899 | CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f32_t(conf))); |
2900 | } else if (utils::one_of(conf->src_dt, data_type::bf16, data_type::f16) |
2901 | && conf->isa != avx512_core_fp16) { |
2902 | CHECK(safe_ptr_assign( |
2903 | trans_ker, new jit_brgemm_trans_m_k_bf16_t(conf))); |
2904 | } else if (conf->src_dt == data_type::f16) { |
2905 | assert(conf->isa == avx512_core_fp16); |
2906 | CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_m_k_f16_t(conf))); |
2907 | } else { |
2908 | return status::invalid_arguments; |
2909 | } |
2910 | |
2911 | return trans_ker->create_kernel(); |
2912 | } |
2913 | |
2914 | status_t create_brgemm_copy_to_coarse( |
2915 | std::unique_ptr<jit_brgemm_copy_to_coarse_t> ©_ker, |
2916 | const jit_brgemm_primitive_conf_t *conf) { |
2917 | if (is_superset(conf->isa, avx512_core_amx)) |
2918 | CHECK(safe_ptr_assign(copy_ker, new jit_brgemm_copy_to_coarse_t(conf))); |
2919 | else |
2920 | return status::invalid_arguments; |
2921 | |
2922 | return copy_ker->create_kernel(); |
2923 | } |
2924 | |
2925 | status_t create_brgemm_trans_to_vnni( |
2926 | std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker, |
2927 | const jit_brgemm_primitive_conf_t *conf, |
2928 | jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform) { |
2929 | if (conf->prop_kind != dnnl_backward_weights) |
2930 | return status::invalid_arguments; |
2931 | |
2932 | if (conf->dst_dt == data_type::f32) { |
2933 | CHECK(safe_ptr_assign( |
2934 | trans_ker, new jit_copy_f32_t(conf, matrix_to_transform))); |
2935 | } else if (one_of(conf->dst_dt, data_type::bf16, data_type::f16) |
2936 | && conf->isa != avx512_core_fp16) { |
2937 | CHECK(safe_ptr_assign( |
2938 | trans_ker, new jit_trans_to_vnni_t(conf, matrix_to_transform))); |
2939 | } else if (conf->dst_dt == data_type::f16) { |
2940 | CHECK(safe_ptr_assign( |
2941 | trans_ker, new jit_copy_f16_t(conf, matrix_to_transform))); |
2942 | } else { |
2943 | return status::invalid_arguments; |
2944 | } |
2945 | |
2946 | return trans_ker->create_kernel(); |
2947 | } |
2948 | |
2949 | status_t create_brgemm_trans_wei( |
2950 | std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker, |
2951 | const jit_brgemm_primitive_conf_t *conf) { |
2952 | |
2953 | if (conf->prop_kind != dnnl_backward_data) return status::invalid_arguments; |
2954 | |
2955 | if (conf->wei_dt == data_type::f32) { |
2956 | CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f32_t(conf))); |
2957 | } else if (one_of(conf->wei_dt, data_type::bf16, data_type::f16) |
2958 | && conf->isa != avx512_core_fp16) { |
2959 | CHECK(safe_ptr_assign( |
2960 | trans_ker, new jit_brgemm_trans_wei_bf16_t(conf))); |
2961 | } else if (conf->wei_dt == data_type::f16) { |
2962 | assert(conf->isa == avx512_core_fp16); |
2963 | CHECK(safe_ptr_assign(trans_ker, new jit_brgemm_trans_wei_f16_t(conf))); |
2964 | } else { |
2965 | return status::invalid_arguments; |
2966 | } |
2967 | |
2968 | return trans_ker->create_kernel(); |
2969 | } |
2970 | |
2971 | status_t create_brgemm_amx_ip_trans_wei( |
2972 | std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker, |
2973 | const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block, |
2974 | const int ext_oc_block) { |
2975 | if (conf->prop_kind == dnnl_backward_weights |
2976 | && one_of(conf->wei_dt, data_type::bf16, data_type::f16)) { |
2977 | CHECK(safe_ptr_assign(trans_ker, |
2978 | new jit_amx_ip_trans_diff_wei_to_vnni_t( |
2979 | conf, ext_ic_block, ext_oc_block))); |
2980 | } else |
2981 | return status::invalid_arguments; |
2982 | |
2983 | return trans_ker->create_kernel(); |
2984 | } |
2985 | |
2986 | } // namespace x64 |
2987 | } // namespace cpu |
2988 | } // namespace impl |
2989 | } // namespace dnnl |
2990 | |