1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/nstl.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/x64/jit_generator.hpp" |
22 | |
23 | #include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | namespace matmul { |
30 | |
31 | using namespace dnnl::impl::format_tag; |
32 | using namespace dnnl::impl::utils; |
33 | using namespace Xbyak; |
34 | |
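// Byte offset of a field in the runtime call arguments (ctx_t), used to
// address them relative to the kernel's first parameter register.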
35 | #define GET_OFF(x) offsetof(ctx_t, x) |
36 | |
37 | struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t, |
38 | public jit_generator { |
39 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_impl_t) |
40 | |
41 | jit_brgemm_matmul_copy_a_impl_t(const brgemm_matmul_conf_t *conf) |
42 | : jit_brgemm_matmul_copy_a_t(conf) |
43 | , jit_generator(jit_name()) |
44 | , typesize(conf_->a_dt_sz) |
45 | , tr_typesize(conf_->tr_a_dt_sz) |
46 | , vnni_granularity(data_type_vnni_granularity(conf_->src_dt)) |
47 | , k_step(bytes_in_zmm / nstl::max(typesize, tr_typesize)) {} |
48 | |
49 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
50 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
51 | |
52 | private: |
53 | using reg64_t = const Xbyak::Reg64; |
54 | using reg32_t = const Xbyak::Reg32; |
55 | using opmask_t = const Xbyak::Opmask; |
56 | using zmm = const Xbyak::Zmm; |
57 | using ymm = const Xbyak::Ymm; |
58 | using xmm = const Xbyak::Xmm; |
59 | |
60 | enum { |
61 | num_comp_acc = 8, |
62 | k_loop_unroll = 16, |
63 | bytes_in_zmm = 64, |
64 | }; |
65 | const int typesize; |
66 | const int tr_typesize; |
67 | const int vnni_granularity; |
68 | const int k_step; |
69 | |
70 | dim_t src_stride = 0, tr_src_stride = 0; |
71 | bool do_compute_compensation = false; |
72 | |
73 | opmask_t kTail_load = k7; |
74 | opmask_t kTail_store = k6; |
75 | opmask_t kTail_comp = k5; |
76 | |
77 | reg64_t reg_src = rax; |
78 | reg64_t reg_tr_src = rbx; |
79 | reg64_t reg_K_start = abi_not_param1; |
80 | |
81 | reg64_t reg_zp_comp_buf_ptr = rdx; |
82 | reg64_t reg_zp_comp_res_ptr = rsi; |
83 | |
84 | reg64_t reg_M_blk = r9; |
85 | reg64_t reg_K_blk = r10; |
86 | reg64_t reg_batch = r11; |
87 | reg64_t reg_aux_src = r12; |
88 | reg64_t reg_aux_tr_src = r13; |
89 | reg64_t regq_tmp = r14; |
90 | reg64_t imm_addr64 = r15; |
91 | reg64_t reg_zp_ab_comp_ptr = imm_addr64; |
92 | reg64_t reg_zp_b_neg_val_ptr = reg_K_blk; |
93 | |
94 | zmm zmm_comp_mul = zmm30; |
95 | zmm zmm_comp_add = zmm31; |
96 | |
    // Allows shifting the A data by 128 for the s8s8 problem on AVX512 in the
    // copy routine instead of in the compute kernel. Disabled for now, as it
    // requires passing a hint to the brgemm kernel to avoid double shifting.
100 | const bool allow_input_shift_for_s8s8 = false; |
101 | |
102 | Xbyak::Zmm get_zmm_comp_acc(int i) { |
103 | assert(i >= 0 && i < num_comp_acc); |
104 | return Xbyak::Zmm(i); |
105 | } |
106 | |
107 | Xbyak::Zmm get_zmm_copy(int i) { |
108 | assert(i >= 0 && i < k_loop_unroll); |
109 | return Xbyak::Zmm(29 - i); |
110 | } |
111 | void reduce_compensation_across_accumulators(int num_accumulators); |
112 | void copy_row(int ncolumns); |
113 | void copy_K_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter); |
114 | void copy_M_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter); |
115 | void generate() override; |
116 | }; |
117 | |
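// Pairwise tree reduction: each pass adds the upper half of the live
// accumulators into the lower half until the partial zero-point compensation
// sums end up in accumulator 0.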
118 | void jit_brgemm_matmul_copy_a_impl_t::reduce_compensation_across_accumulators( |
119 | int num_accumulators) { |
120 | int num = num_accumulators; |
121 | while (num > 1) { |
122 | for (int i = 0; i < num / 2; i++) { |
123 | const auto zmm_acc0 = get_zmm_comp_acc(i); |
124 | const auto zmm_acc1 = get_zmm_comp_acc(div_up(num, 2) + i); |
125 | vpaddd(zmm_acc0, zmm_acc0, zmm_acc1); |
126 | } |
127 | num = div_up(num, 2); |
128 | } |
129 | } |
130 | |
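// Copies one K block of a single A row into the blocked buffer, converting on
// the fly when required (f32 -> bf16 for bf32, f16 -> f32 for
// avx512_core_fp16) and, when zero-point compensation for B is needed,
// accumulating per-row byte sums with vpdpbusd along the way.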
131 | void jit_brgemm_matmul_copy_a_impl_t::copy_K_loop( |
132 | bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) { |
133 | MAYBE_UNUSED(is_K_tail); |
134 | MAYBE_UNUSED(is_first_K_iter); |
135 | MAYBE_UNUSED(is_last_K_iter); |
136 | |
137 | const int K_blk = is_K_tail ? conf_->K % conf_->K_blk |
138 | : nstl::min(conf_->K, conf_->K_blk); |
139 | const int k_tail = K_blk % k_step; |
140 | const int num_k_iters = K_blk / k_step; |
141 | const int num_acc = utils::saturate(1, (int)num_comp_acc, num_k_iters); |
142 | |
143 | if (do_compute_compensation) { |
144 | for (int i = 0; i < num_acc; i++) { |
145 | const auto zmm_acc = get_zmm_comp_acc(i); |
146 | vpxord(zmm_acc, zmm_acc, zmm_acc); |
147 | } |
148 | } |
149 | |
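    // vpdpbusd treats its first source as unsigned bytes and its second as
    // signed ones, so the all-ones vector takes whichever slot the source
    // data type cannot; each dword lane then accumulates a plain byte sum.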
150 | auto maybe_compute_compensation = [=](int k_idx, zmm zmm_copy) { |
151 | if (do_compute_compensation) { |
152 | const auto zmm_comp_acc = get_zmm_comp_acc(k_idx % num_acc); |
153 | if (conf_->src_dt == data_type::s8) |
154 | vpdpbusd(zmm_comp_acc, zmm_comp_mul, zmm_copy); |
155 | else |
156 | vpdpbusd(zmm_comp_acc, zmm_copy, zmm_comp_mul); |
157 | } |
158 | }; |
159 | |
160 | for (int kb = 0; kb < div_up(num_k_iters, k_loop_unroll); kb++) { |
161 | int k_start = 0; |
162 | int k_end = nstl::min( |
163 | (int)k_loop_unroll, num_k_iters - kb * k_loop_unroll); |
164 | for (int k = k_start; k < k_end; k++) { |
165 | const int k_idx = kb * k_loop_unroll + k; |
166 | const size_t offset = (size_t)k_idx * k_step * typesize; |
167 | const auto addr = EVEX_compress_addr(reg_src, offset); |
            if (conf_->isa == avx512_core_fp16)
                vcvtph2psx(get_zmm_copy(k), addr);
            else
                vmovdqu8(get_zmm_copy(k), addr);
173 | maybe_compute_compensation(k_idx, get_zmm_copy(k)); |
174 | } |
175 | if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) { |
176 | for (int k = k_start; k < k_end; k++) |
177 | vpaddb(get_zmm_copy(k), get_zmm_copy(k), zmm_comp_add); |
178 | } |
179 | if (conf_->is_bf32) { |
180 | assert(typesize != tr_typesize); |
181 | int k = k_start; |
182 | const int k_end_2 = rnd_dn(k_end, 2); |
183 | for (; k < k_end_2; k += 2) { |
184 | const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step |
185 | * tr_typesize; |
186 | auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset); |
187 | |
188 | auto zmm_src = get_zmm_copy(k); |
189 | auto zmm_src_next = get_zmm_copy(k + 1); |
190 | |
191 | vcvtne2ps2bf16(zmm_src, zmm_src_next, zmm_src); |
192 | vmovups(tr_src_addr, zmm_src); |
193 | } |
194 | if (k < k_end) { |
195 | const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step |
196 | * tr_typesize; |
197 | auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset); |
198 | ymm ymm_downcvt_bf16 = ymm(get_zmm_copy(k).getIdx()); |
199 | vcvtneps2bf16(ymm_downcvt_bf16, get_zmm_copy(k)); |
200 | vmovdqu16(tr_src_addr, ymm_downcvt_bf16); |
201 | } |
202 | } else { |
203 | for (int k = k_start; k < k_end; k++) { |
204 | const size_t offset = ((size_t)kb * k_loop_unroll + k) * k_step |
205 | * tr_typesize; |
206 | auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset); |
207 | vmovdqu8(tr_src_addr, get_zmm_copy(k)); |
208 | } |
209 | } |
210 | } |
211 | |
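    // Handle the K remainder with masked loads/stores: the load mask covers
    // exactly k_tail source elements, while the store mask is rounded up to
    // the VNNI granularity so partially filled groups are written in full.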
212 | if (k_tail > 0) { |
213 | const auto kmovx = [=](Opmask k, size_t q) { |
214 | if (conf_->is_bf32) { |
215 | mov(regq_tmp.cvt32(), q); |
216 | jit_generator::kmovw(k, regq_tmp.cvt32()); |
217 | } else { |
218 | mov(regq_tmp, q); |
219 | jit_generator::kmovq(k, regq_tmp); |
220 | } |
221 | }; |
222 | |
223 | const size_t dt_step = conf_->is_bf32 || conf_->isa == avx512_core_fp16 |
224 | ? 1 |
225 | : typesize; |
226 | const size_t tail_mask_load |
227 | = size_t(((size_t)1 << (dt_step * k_tail)) - 1); |
228 | kmovx(kTail_load, tail_mask_load); |
229 | const int k_tail_st = rnd_up(k_tail, vnni_granularity); |
230 | const size_t full_mask |
231 | = conf_->is_bf32 ? ((size_t)1 << 16) - 1 : 0xffffffffffffffff; |
232 | const size_t tail_mask_store = k_tail_st == k_step |
233 | ? full_mask |
234 | : size_t(((size_t)1 << (dt_step * k_tail_st)) - 1); |
235 | kmovx(kTail_store, tail_mask_store); |
236 | |
237 | auto zmm_tail = get_zmm_copy(0) | kTail_load | T_z; |
238 | auto load_addr |
239 | = EVEX_compress_addr(reg_src, num_k_iters * k_step * typesize); |
240 | if (conf_->is_bf32) |
241 | vmovups(zmm_tail, load_addr); |
242 | else if (conf_->isa == avx512_core_fp16) |
243 | vcvtph2psx(zmm_tail, load_addr); |
244 | else |
245 | vmovdqu8(zmm_tail, load_addr); |
246 | |
247 | maybe_compute_compensation(0, get_zmm_copy(0)); |
248 | |
249 | if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) |
250 | vpaddb(get_zmm_copy(0), get_zmm_copy(0), zmm_comp_add); |
251 | |
252 | auto tr_src_addr = EVEX_compress_addr( |
253 | reg_tr_src, num_k_iters * k_step * tr_typesize); |
254 | if (conf_->is_bf32) { |
255 | ymm ymm_downcvt_bf16 = ymm(get_zmm_copy(0).getIdx()); |
256 | vcvtneps2bf16(ymm_downcvt_bf16, get_zmm_copy(0)); |
257 | vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store); |
258 | } else if (conf_->isa == avx512_core_fp16) { |
259 | vmovups(tr_src_addr, get_zmm_copy(0) | kTail_store); |
        } else {
            vmovdqu8(tr_src_addr, get_zmm_copy(0) | kTail_store);
        }
262 | } |
263 | |
264 | if (do_compute_compensation) { |
265 | reduce_compensation_across_accumulators(num_acc); |
266 | |
267 | const auto addr_buf = zword[reg_zp_comp_buf_ptr]; |
268 | if (!is_first_K_iter) |
269 | vpaddd(get_zmm_comp_acc(0), get_zmm_comp_acc(0), addr_buf); |
270 | if (!is_last_K_iter) { |
271 | vmovups(addr_buf, get_zmm_comp_acc(0)); |
272 | return; |
273 | } |
274 | |
        // is_last_K_iter == true: we need to reduce the values within the acc
        // register, add the mixed ab_compensation component if any, multiply
        // it by the negative zp_b_value and finally store the result
278 | |
279 | // step 1: reduce values within acc register |
280 | const auto ymm_red0 = ymm(get_zmm_comp_acc(0).getIdx()); |
281 | const auto ymm_red1 = ymm(get_zmm_comp_acc(1).getIdx()); |
282 | vextracti64x4(ymm_red1, get_zmm_comp_acc(0), 1); |
283 | vphaddd(ymm_red0, ymm_red0, ymm_red1); |
284 | vpxord(ymm_red1, ymm_red1, ymm_red1); |
285 | vphaddd(ymm_red0, ymm_red0, ymm_red1); |
286 | vphaddd(ymm_red0, ymm_red0, ymm_red1); |
287 | const auto xmm_red1 = xmm(ymm_red1.getIdx()); |
288 | vextractf128(xmm_red1, ymm_red0, 1); |
289 | vpaddd(ymm_red0, ymm_red0, ymm_red1); |
290 | |
291 | // step 2: add -K * zp_a_val as mixed ab_compensation component |
292 | if (conf_->src_zp_type != brgemm_broadcast_t::none) { |
293 | assert(conf_->src_zp_type == brgemm_broadcast_t::per_tensor); |
294 | reg64_t reg_zp_ab_comp_ptr = imm_addr64; |
295 | mov(reg_zp_ab_comp_ptr, ptr[param1 + GET_OFF(zp_ab_comp_ptr)]); |
296 | |
297 | const auto addr_ab_comp = zword_b[reg_zp_ab_comp_ptr]; |
298 | const auto zmm_res = get_zmm_comp_acc(0) | kTail_comp; |
299 | vpaddd(zmm_res, get_zmm_comp_acc(0), addr_ab_comp); |
300 | } |
301 | |
        // step 3: multiply by the negative zp_b value
303 | mov(reg_zp_b_neg_val_ptr, ptr[param1 + GET_OFF(zp_b_neg_value_ptr)]); |
304 | const auto zmm_zp_b_neg_val = get_zmm_comp_acc(1); |
305 | vbroadcastss(zmm_zp_b_neg_val, ptr[reg_zp_b_neg_val_ptr]); |
306 | vpmulld(get_zmm_comp_acc(0), get_zmm_comp_acc(0), zmm_zp_b_neg_val); |
307 | |
308 | // step 4: store the final result value |
309 | vmovups(ptr[reg_zp_comp_res_ptr], get_zmm_comp_acc(0) | kTail_comp); |
310 | } |
311 | } |
312 | |
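// Walks over the M rows of the current block: each iteration copies one K
// block via copy_K_loop, then advances the source/destination pointers and,
// when compensation is computed, the per-row compensation pointers.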
313 | void jit_brgemm_matmul_copy_a_impl_t::copy_M_loop( |
314 | bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) { |
315 | |
316 | if (do_compute_compensation) { |
317 | mov(imm_addr64, 1); |
318 | vpbroadcastb(zmm_comp_mul, imm_addr64.cvt8()); |
319 | if (!(is_first_K_iter && is_last_K_iter)) |
320 | mov(reg_zp_comp_buf_ptr, |
321 | ptr[param1 + GET_OFF(zp_b_compensation_buffer_ptr)]); |
322 | |
323 | if (is_last_K_iter) { |
324 | mov(reg_zp_comp_res_ptr, |
325 | ptr[param1 + GET_OFF(zp_a_compensation_result_ptr)]); |
            const auto kmovw = [=](Opmask k, size_t q) {
                mov(regq_tmp, q);
                jit_generator::kmovw(k, regq_tmp.cvt32());
            };
330 | kmovw(kTail_comp, 1); |
331 | } |
332 | } |
333 | |
334 | Label loop_M; |
335 | L(loop_M); |
336 | |
337 | copy_K_loop(is_K_tail, is_first_K_iter, is_last_K_iter); |
338 | |
339 | add(reg_src, src_stride); |
340 | add(reg_tr_src, tr_src_stride); |
341 | if (do_compute_compensation) { |
342 | // shift comp pointers |
343 | if (!(is_first_K_iter && is_last_K_iter)) |
344 | add(reg_zp_comp_buf_ptr, sizeof(int32_t) * 16); |
345 | if (is_last_K_iter) add(reg_zp_comp_res_ptr, sizeof(int32_t)); |
346 | } |
347 | |
348 | dec(reg_M_blk); |
349 | jnz(loop_M, T_NEAR); |
350 | } |
351 | |
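// The generated code dispatches at run time on the position of the current K
// chunk (first and/or last), which determines whether partial compensation is
// read from / written to the intermediate buffer and whether the final
// compensation value is produced.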
352 | void jit_brgemm_matmul_copy_a_impl_t::generate() { |
353 | preamble(); |
354 | |
355 | src_stride = conf_->src_tag == format_tag::acbd ? conf_->copy_A_src_stride |
356 | : conf_->K * typesize; |
357 | const dim_t LDA = conf_->use_buffer_a_tail_only ? (dim_t)conf_->wei_k_blk |
358 | : conf_->LDA; |
359 | tr_src_stride = LDA * tr_typesize; |
360 | do_compute_compensation = conf_->has_zero_point_b; |
361 | |
362 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
363 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
364 | mov(reg_K_blk, ptr[param1 + GET_OFF(current_K_blk)]); |
365 | mov(reg_M_blk, ptr[param1 + GET_OFF(current_M_blk)]); |
366 | |
367 | if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) { |
368 | mov(imm_addr64, 128); |
369 | vpbroadcastb(zmm_comp_add, imm_addr64.cvt8()); |
370 | } |
371 | |
372 | auto copy_body = [=](bool is_first_K_iter, bool is_last_K_iter) { |
373 | Label copy_body_done; |
374 | // might be different from conf_->K_tail |
375 | const dim_t K_blk_tail |
376 | = conf_->K_tail > 0 ? conf_->K % conf_->K_blk : 0; |
377 | if (K_blk_tail > 0) { |
378 | Label not_K_tail; |
379 | cmp(reg_K_blk, K_blk_tail); |
380 | jne(not_K_tail, T_NEAR); |
381 | copy_M_loop(true, is_first_K_iter, is_last_K_iter); |
382 | jmp(copy_body_done, T_NEAR); |
383 | |
384 | L(not_K_tail); |
385 | } |
386 | |
387 | copy_M_loop(false, is_first_K_iter, is_last_K_iter); |
388 | L(copy_body_done); |
389 | }; |
390 | |
391 | Label done; |
392 | if (do_compute_compensation) { |
393 | assert(conf_->wei_zp_type == brgemm_broadcast_t::per_tensor); |
394 | |
395 | mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]); |
396 | const auto last_K_threshold |
397 | = rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk; |
398 | Label not_first, not_first_not_last; |
399 | cmp(reg_K_start, 0); |
400 | jne(not_first, T_NEAR); |
401 | { |
402 | // first K iteration |
403 | Label first_not_last; |
404 | cmp(reg_K_start, last_K_threshold); |
405 | jl(first_not_last, T_NEAR); |
406 | copy_body(true, true); |
407 | jmp(done, T_NEAR); |
408 | |
409 | L(first_not_last); |
410 | copy_body(true, false); |
411 | jmp(done, T_NEAR); |
412 | } |
413 | |
414 | L(not_first); |
415 | cmp(reg_K_start, last_K_threshold); |
416 | jl(not_first_not_last, T_NEAR); |
417 | |
418 | copy_body(false, true); |
419 | jmp(done, T_NEAR); |
420 | L(not_first_not_last); |
421 | } |
422 | copy_body(false, false); |
423 | L(done); |
424 | |
425 | postamble(); |
426 | } |
427 | |
428 | struct jit_brgemm_matmul_copy_a_transposed_impl_t |
429 | : public jit_brgemm_matmul_copy_a_t, |
430 | public jit_generator { |
431 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_transposed_impl_t) |
432 | |
433 | jit_brgemm_matmul_copy_a_transposed_impl_t(const brgemm_matmul_conf_t *conf) |
434 | : jit_brgemm_matmul_copy_a_t(conf) |
435 | , jit_generator(jit_name()) |
436 | , typesize(conf_->a_dt_sz) |
437 | , tr_typesize(conf_->tr_a_dt_sz) |
438 | , src_stride(conf_->src_tag == format_tag::adbc |
439 | ? conf_->copy_A_src_stride |
440 | : conf_->M * typesize) |
441 | , dst_stride(conf_->LDA * tr_typesize) |
442 | , m_loop_src_shift(columns_step * typesize) |
443 | , m_loop_dst_shift(columns_step * dst_stride) |
444 | , k_loop_src_shift(rows_step * src_stride) |
445 | , k_loop_dst_shift(rows_step * tr_typesize) |
446 | , is_f32(everyone_is(data_type::f32, conf_->src_dt, conf_->wei_dt)) |
447 | , is_bf32(conf_->is_bf32) {} |
448 | |
449 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
450 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
451 | |
452 | private: |
453 | using reg64_t = const Xbyak::Reg64; |
454 | using reg32_t = const Xbyak::Reg32; |
455 | using opmask_t = const Xbyak::Opmask; |
456 | |
457 | const size_t typesize; |
458 | const size_t tr_typesize; |
459 | const int rows_step = 16; |
460 | const int columns_step = rows_step; |
461 | const dim_t src_stride, dst_stride; |
462 | const dim_t m_loop_src_shift; |
463 | const dim_t m_loop_dst_shift; |
464 | const dim_t k_loop_src_shift; |
465 | const dim_t k_loop_dst_shift; |
466 | const bool is_f32; |
467 | const bool is_bf32; |
468 | |
469 | opmask_t kFFFF = k1; |
470 | opmask_t k3333 = k1; |
471 | opmask_t k5555 = k2; |
472 | opmask_t kAAAA = k3; |
473 | opmask_t kAA = k4; |
474 | opmask_t kCCCC = k4; |
475 | opmask_t k55 = k5; |
476 | opmask_t k0F0F = k5; |
477 | opmask_t kCC = k6; |
478 | opmask_t kF0F0 = k6; |
479 | opmask_t k33 = k7; |
480 | opmask_t kTail = is_f32 ? k7 : k1; |
481 | |
482 | reg32_t regw_tmp = r15d; |
483 | reg64_t reg_k_src = r14; |
484 | reg64_t reg_k_dst = r13; |
485 | reg64_t reg_m_src = r12; |
486 | reg64_t reg_m_dst = r11; |
487 | reg64_t reg_loop_k = rax; |
488 | reg64_t reg_loop_m = rbx; |
489 | reg64_t imm_addr64 = rdx; |
490 | |
491 | Xbyak::Zmm vidx1 = zmm31; |
492 | Xbyak::Zmm vidx2 = zmm30; |
493 | Xbyak::Zmm vidx3 = zmm29; |
494 | Xbyak::Zmm vidx4 = zmm28; |
495 | Xbyak::Zmm vidx5 = zmm27; |
496 | Xbyak::Zmm zmm_tmp = zmm26; |
497 | |
498 | void transpose_f32(reg64_t dst, reg64_t src, int nrows, int ncolumns); |
499 | void transpose_bf16(reg64_t dst, reg64_t src, int nrows, int ncolumns); |
500 | void deploy_transpose(reg64_t dst, reg64_t src, int nrows, int ncolumns); |
501 | void generate() override; |
502 | }; |
503 | |
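// 16x16 bf16 transpose: rows are loaded two at a time and interleaved into a
// single zmm with vpermw (vidx5); three rounds of masked permutes ("swap
// 1/2/3" below) then exchange progressively larger element groups, and the
// stores use get_vec_idx() to undo the interleaved register numbering.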
504 | void jit_brgemm_matmul_copy_a_transposed_impl_t::transpose_bf16( |
505 | reg64_t dst, reg64_t src, int nrows, int ncolumns) { |
506 | assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0 |
507 | && ncolumns <= columns_step); |
508 | if (!nrows) return; |
509 | |
510 | auto src_zmm = [=](int i) { return Zmm(i); }; |
511 | |
512 | auto src_ymm = [=](int i) { |
513 | assert(i >= 0 && i < 16); |
514 | return Ymm(i); |
515 | }; |
516 | |
517 | auto kmovx = [=](Opmask k, unsigned w, bool use_word_sz = false) { |
518 | mov(regw_tmp, w); |
519 | if (use_word_sz) |
520 | jit_generator::kmovw(k, regw_tmp); |
521 | else |
522 | jit_generator::kmovd(k, regw_tmp); |
523 | }; |
524 | |
525 | auto store = [=](Zmm r, int i) { |
526 | auto addr = EVEX_compress_addr(dst, i * dst_stride); |
527 | vmovdqu16(addr, r | kTail); |
528 | }; |
529 | |
530 | const int load_mask |
531 | = ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff; |
532 | kmovx(kFFFF, load_mask, is_bf32); |
533 | |
534 | for (int i = 0; i < nrows / 2; i++) { |
535 | auto idx0 = 2 * i; |
536 | auto idx1 = 2 * i + 1; |
537 | auto zmm_src0 = src_zmm(idx0); |
538 | auto zmm_src1 = src_zmm(idx1); |
539 | auto src_addr_0 = EVEX_compress_addr(src, idx0 * src_stride); |
540 | auto src_addr_1 = EVEX_compress_addr(src, idx1 * src_stride); |
541 | if (is_bf32) { |
542 | vmovups(zmm_src0 | kFFFF | T_z, src_addr_0); |
543 | vmovups(zmm_src1 | kFFFF | T_z, src_addr_1); |
544 | vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0); |
545 | } else { |
546 | auto src1 = src_ymm(idx1); |
547 | vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr_0); |
548 | vmovdqu16(zmm_src1 | kFFFF | T_z, src_addr_1); |
549 | vinsertf64x4(zmm_src0, zmm_src0, src1, 1); |
550 | } |
551 | vpermw(zmm_src0, vidx5, zmm_src0); |
552 | } |
553 | |
    // for an odd number of rows we need to mix the last row with zeroes
555 | if (nrows % 2) { |
556 | int i = nrows / 2; |
557 | auto zmm_src0 = src_zmm(2 * i); |
558 | auto src_addr = EVEX_compress_addr(src, 2 * i * src_stride); |
559 | if (is_bf32) { |
560 | vmovups(zmm_src0 | kFFFF | T_z, src_addr); |
561 | vcvtneps2bf16(Ymm(zmm_src0.getIdx()), zmm_src0); |
562 | } else |
563 | vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr); |
564 | vpermw(zmm_src0, vidx5, zmm_src0); |
565 | } |
566 | |
567 | for (int i = rnd_up(nrows, 2); i < rows_step; i += 2) { |
568 | vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); |
569 | } |
570 | |
571 | // swap 1 |
572 | for (int i = 0; i < 4; i++) { |
573 | auto zmm0 = src_zmm(4 * i); |
574 | auto zmm1 = src_zmm(4 * i + 2); |
575 | auto tmp0 = src_zmm(4 * i + 1); |
576 | auto tmp1 = src_zmm(4 * i + 3); |
577 | |
578 | vmovups(tmp0, zmm0); |
579 | vmovups(tmp1, zmm1); |
580 | |
581 | vpermps(tmp0 | kAAAA, vidx3, zmm1); |
582 | vpermps(tmp1 | k5555, vidx3, zmm0); |
583 | } |
    // swap 2
    for (int base_idx : {0, 8}) {
        for (int i = 0; i < 2; i++) {
            auto zmm0 = src_zmm(base_idx + 2 * i + 1);
            auto zmm1 = src_zmm(base_idx + 2 * i + 5);

            auto tmp0 = src_zmm(base_idx + 2 * i);
            auto tmp1 = src_zmm(base_idx + 2 * i + 4);

            vmovupd(tmp0, zmm0);
            vmovupd(tmp1, zmm1);

            vpermpd(tmp0 | kAA, vidx2, zmm1);
            vpermpd(tmp1 | k55, vidx2, zmm0);
        }
    }
614 | |
615 | // swap 3 |
616 | for (int i = 0; i < 4; i++) { |
617 | auto zmm0 = src_zmm(2 * i); |
618 | auto zmm1 = src_zmm(2 * i + 8); |
619 | |
620 | auto tmp0 = src_zmm(2 * i + 1); |
621 | auto tmp1 = src_zmm(2 * i + 9); |
622 | |
623 | vmovupd(tmp0, zmm0); |
624 | vmovupd(tmp1, zmm1); |
625 | |
626 | vpermpd(tmp0 | kCC, vidx1, zmm1); |
627 | vpermpd(tmp1 | k33, vidx1, zmm0); |
628 | } |
629 | |
630 | // all stores |
631 | for (int i = 0; i < 8; i++) |
632 | vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1); |
633 | |
634 | auto get_vec_idx = [=](int col_idx) { |
635 | assert(col_idx < columns_step && col_idx >= 0); |
636 | const int blk_sz = 4; |
637 | const int blk_idx = col_idx / blk_sz; |
638 | const int idx_within_blk = col_idx % blk_sz; |
639 | |
640 | // 0 1 2 3 -> 0 2 1 3 |
641 | const int mapped_blk_idx = 2 * blk_idx - (blk_idx / 2) * 3; |
642 | // 0 1 2 3 -> 1 0 3 2 |
643 | const int mapped_idx_within_blk |
644 | = idx_within_blk + 1 - 2 * (idx_within_blk % 2); |
645 | return blk_sz * mapped_blk_idx + mapped_idx_within_blk; |
646 | }; |
647 | const int columns_to_store = rnd_up(nrows, 2); |
648 | const int store_mask = columns_to_store < rows_step |
649 | ? (1 << columns_to_store) - 1 |
650 | : 0xffff; |
651 | kmovx(kTail, store_mask); |
652 | |
653 | for (int col_idx = 0; col_idx < ncolumns; col_idx++) |
654 | store(src_zmm(get_vec_idx(col_idx)), col_idx); |
655 | } |
656 | |
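// 16x16 f32 transpose built from two independent 16x8 transposes
// (valignd/vshuff32x4 exchanges at strides 1, 2 and 4) followed by a fixup
// pass that merges the two halves with vshuff64x2 ("swap 8").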
657 | void jit_brgemm_matmul_copy_a_transposed_impl_t::transpose_f32( |
658 | reg64_t dst, reg64_t src, int nrows, int ncolumns) { |
659 | assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0 |
660 | && ncolumns <= columns_step); |
661 | if (!nrows) return; |
662 | |
663 | auto kmovw = [=](Opmask k, size_t q) { |
664 | mov(regw_tmp, q); |
665 | jit_generator::kmovw(k, regw_tmp); |
666 | }; |
667 | |
668 | const int load_mask |
669 | = ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff; |
670 | kmovw(kTail, load_mask); |
671 | |
672 | auto src_zmm = [=](int i) { |
673 | assert(i >= 0 && i < 16); |
674 | return Zmm(i); |
675 | }; |
676 | |
677 | auto tmp_zmm = [=](int i) { |
678 | assert(i >= 0 && i < 16); |
679 | return Zmm(16 + i); |
680 | }; |
681 | |
    auto load = [=](int i) {
        const auto addr = EVEX_compress_addr(src, i * src_stride);
        if (i < nrows) {
            if (conf_->isa == avx512_core_fp16)
                vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
            else
                vmovups(src_zmm(i) | kTail | T_z, addr);
        } else {
            vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
        }
    };
692 | |
693 | auto store = [=](Zmm r, int i) { |
694 | auto addr = EVEX_compress_addr(dst, i * dst_stride); |
695 | vmovups(addr, r | kTail); |
696 | }; |
697 | |
698 | auto transpose16x8 = [=](int base_idx) { |
699 | assert(base_idx == 0 || base_idx == 8); |
700 | |
701 | // swap 1 |
702 | for (int i = 0; i < 4; i++) { |
703 | int src_idx0 = base_idx + i * 2; |
704 | int src_idx1 = src_idx0 + 1; |
705 | |
706 | int next_src_idx0 = src_idx0 + 2; |
707 | int next_src_idx1 = src_idx1 + 2; |
708 | bool load_next = base_idx == 0 || i < 3; |
709 | |
710 | if (base_idx == 0 && i == 0) { |
711 | load(src_idx0); |
712 | load(src_idx1); |
713 | } |
714 | |
715 | auto tmp0 = tmp_zmm(src_idx0); |
716 | auto tmp1 = tmp_zmm(src_idx1); |
717 | auto src0 = src_zmm(src_idx0); |
718 | auto src1 = src_zmm(src_idx1); |
719 | |
720 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
721 | valignd(tmp0, src0, src0, 0x1); |
722 | |
723 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
724 | valignd(tmp1, src1, src1, 0xf); |
725 | |
726 | vmovaps(src0 | kAAAA, tmp1); |
727 | vmovaps(src1 | k5555, tmp0); |
728 | } |
729 | |
730 | // swap 2 |
731 | for (int i = 0; i < 4; i++) { |
732 | int select_half = (i < 2) ? 0 : 2; |
733 | int src_idx0 = base_idx + i + select_half + 0; |
734 | int src_idx2 = src_idx0 + 2; |
735 | |
736 | auto tmp0 = tmp_zmm(src_idx0); |
737 | auto tmp1 = tmp_zmm(src_idx2); |
738 | auto src0 = src_zmm(src_idx0); |
739 | auto src2 = src_zmm(src_idx2); |
740 | |
741 | valignd(tmp0, src0, src0, 0x2); |
742 | valignd(tmp1, src2, src2, 0xe); |
743 | vmovaps(src2 | k3333, tmp0); |
744 | vmovaps(src0 | kCCCC, tmp1); |
745 | } |
746 | |
747 | // swap 4 |
748 | for (int i = 0; i < 4; i++) { |
749 | int src_idx0 = base_idx + i; |
750 | int src_idx4 = src_idx0 + 4; |
751 | |
752 | auto tmp0 = tmp_zmm(src_idx0); |
753 | auto src0 = src_zmm(src_idx0); |
754 | auto src4 = src_zmm(src_idx4); |
755 | |
756 | vmovaps(tmp0, src0); |
757 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
758 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
759 | } |
760 | }; |
761 | |
762 | auto fixup16x16 = [=]() { |
763 | const int store_mask = nrows < rows_step ? (1 << nrows) - 1 : 0xffff; |
764 | kmovw(kTail, store_mask); |
765 | |
766 | // swap 8 |
767 | for (int i = 0; i < nstl::min(8, ncolumns); i++) { |
768 | auto tmp = tmp_zmm(i); |
769 | auto src0 = src_zmm(i); |
770 | auto src8 = src_zmm(8 + i); |
771 | vshuff64x2(tmp, src0, src8, 0x44); |
772 | store(tmp, i); |
773 | } |
774 | |
775 | for (int i = 0; i < nstl::max(0, ncolumns - 8); i++) { |
776 | auto tmp = tmp_zmm(8 + i); |
777 | auto src0 = src_zmm(i); |
778 | auto src8 = src_zmm(8 + i); |
779 | vshuff64x2(tmp, src0, src8, 0xee); |
780 | store(tmp, 8 + i); |
781 | } |
782 | }; |
783 | |
784 | transpose16x8(0); |
785 | transpose16x8(8); |
786 | fixup16x16(); |
787 | } |
788 | |
789 | void jit_brgemm_matmul_copy_a_transposed_impl_t::deploy_transpose( |
790 | reg64_t dst, reg64_t src, int nrows, int ncolumns) { |
791 | if (is_f32 || conf_->isa == avx512_core_fp16) |
792 | transpose_f32(dst, src, nrows, ncolumns); |
793 | else |
794 | transpose_bf16(dst, src, nrows, ncolumns); |
795 | } |
796 | |
797 | void jit_brgemm_matmul_copy_a_transposed_impl_t::generate() { |
798 | |
799 | // only bf16, f16 and f32 supported for now |
800 | if (!one_of(conf_->src_dt, data_type::bf16, data_type::f32, data_type::f16)) |
801 | return; |
802 | preamble(); |
803 | |
804 | alignas(64) static constexpr const int64_t idx1[8] |
805 | = {2, 3, 0, 1, 6, 7, 4, 5}; |
806 | alignas(64) static constexpr const int64_t idx2[8] |
807 | = {1, 0, 3, 2, 5, 4, 7, 6}; |
808 | alignas(64) static constexpr const int32_t idx3[16] |
809 | = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14}; |
810 | alignas(64) static constexpr const int32_t idx4[16] |
811 | = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7}; |
812 | alignas(64) static constexpr const uint16_t idx5[32] |
813 | = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17, |
814 | 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31}; |
815 | |
816 | const int k_block_tail = conf_->K_blk % rows_step; |
817 | const int last_k_block_tail = (conf_->K % conf_->K_blk) % rows_step; |
818 | const int m_block_tail = conf_->M_blk % columns_step; |
819 | const int last_m_block_tail = conf_->M_tail % columns_step; |
820 | |
821 | auto kmovw = [=](Opmask k, unsigned w) { |
822 | mov(regw_tmp, w); |
823 | jit_generator::kmovw(k, regw_tmp); |
824 | }; |
825 | |
826 | if (is_f32) { |
827 | kmovw(k3333, 0x3333); // 0011001100110011 |
828 | kmovw(k5555, 0x5555); // 0101010101010101 |
829 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
830 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
831 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
832 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
833 | } else { |
834 | kmovw(kFFFF, 0xffff); |
835 | kmovw(k5555, 0x5555); |
836 | kmovw(kAAAA, 0xaaaa); |
837 | kmovw(kAA, 0xaa); |
838 | kmovw(k55, 0x55); |
839 | kmovw(kCC, 0xcc); |
840 | kmovw(k33, 0x33); |
841 | } |
842 | |
843 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
844 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
845 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
846 | }; |
847 | |
848 | auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { |
849 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
850 | jit_generator::vmovdqa32(z, ptr[imm_addr64]); |
851 | }; |
852 | |
853 | if (!is_f32) { |
854 | vmovdqa64(vidx1, idx1); |
855 | vmovdqa64(vidx2, idx2); |
856 | vmovdqa32(vidx3, idx3); |
857 | vmovdqa32(vidx4, idx4); |
858 | vmovdqa32(vidx5, (const int32_t *)idx5); |
859 | } |
860 | |
861 | auto compute_m_loop = [&](reg64_t ®_base, reg64_t ®_tr_base, |
862 | int nrows) { |
863 | mov(reg_loop_m, ptr[param1 + GET_OFF(current_M_blk)]); |
864 | mov(reg_m_src, reg_base); |
865 | mov(reg_m_dst, reg_tr_base); |
866 | |
867 | Label m_loop_tail_or_done, m_loop, compute_m_loop_done; |
868 | cmp(reg_loop_m, columns_step); |
869 | jl(m_loop_tail_or_done, T_NEAR); |
870 | |
871 | L(m_loop); |
872 | { |
873 | deploy_transpose(reg_m_dst, reg_m_src, nrows, columns_step); |
874 | add(reg_m_src, m_loop_src_shift); |
875 | add(reg_m_dst, m_loop_dst_shift); |
876 | } |
877 | sub(reg_loop_m, columns_step); |
878 | cmp(reg_loop_m, columns_step); |
879 | jge(m_loop, T_NEAR); |
880 | |
881 | if (m_block_tail > 0 || last_m_block_tail > 0) |
882 | jz(compute_m_loop_done, T_NEAR); |
883 | |
884 | L(m_loop_tail_or_done); |
885 | |
886 | if (m_block_tail > 0) { |
887 | Label m_block_tail_done; |
888 | cmp(reg_loop_m, m_block_tail); |
889 | jne(m_block_tail_done, T_NEAR); |
890 | |
891 | deploy_transpose(reg_m_dst, reg_m_src, nrows, m_block_tail); |
892 | jmp(compute_m_loop_done, T_NEAR); |
893 | |
894 | L(m_block_tail_done); |
895 | } |
896 | if (last_m_block_tail > 0 && last_m_block_tail != m_block_tail) { |
897 | Label last_m_block_tail_done; |
898 | cmp(reg_loop_m, last_m_block_tail); |
899 | jne(last_m_block_tail_done, T_NEAR); |
900 | |
901 | deploy_transpose(reg_m_dst, reg_m_src, nrows, last_m_block_tail); |
902 | jmp(compute_m_loop_done, T_NEAR); |
903 | |
904 | L(last_m_block_tail_done); |
905 | } |
906 | |
907 | L(compute_m_loop_done); |
908 | }; |
909 | |
910 | auto compute_k_loop = [&]() { |
911 | mov(reg_k_src, ptr[param1 + GET_OFF(src)]); |
912 | mov(reg_k_dst, ptr[param1 + GET_OFF(tr_src)]); |
913 | mov(reg_loop_k, ptr[param1 + GET_OFF(current_K_blk)]); |
914 | |
915 | Label k_tail_or_done, k_loop, compute_k_loop_done; |
916 | cmp(reg_loop_k, rows_step); |
917 | jl(k_tail_or_done, T_NEAR); |
918 | |
919 | L(k_loop); |
920 | { |
921 | compute_m_loop(reg_k_src, reg_k_dst, rows_step); |
922 | add(reg_k_src, k_loop_src_shift); |
923 | add(reg_k_dst, k_loop_dst_shift); |
924 | } |
925 | sub(reg_loop_k, rows_step); |
926 | cmp(reg_loop_k, rows_step); |
927 | jge(k_loop, T_NEAR); |
928 | |
929 | if (k_block_tail > 0 || last_k_block_tail > 0) |
930 | jz(compute_k_loop_done, T_NEAR); |
931 | |
932 | L(k_tail_or_done); |
933 | |
934 | if (k_block_tail > 0) { |
935 | Label k_block_tail_done; |
936 | cmp(reg_loop_k, k_block_tail); |
937 | jne(k_block_tail_done, T_NEAR); |
938 | |
939 | compute_m_loop(reg_k_src, reg_k_dst, k_block_tail); |
940 | jmp(compute_k_loop_done, T_NEAR); |
941 | |
942 | L(k_block_tail_done); |
943 | } |
944 | if (last_k_block_tail > 0 && last_k_block_tail != k_block_tail) { |
945 | Label last_k_block_tail_done; |
946 | cmp(reg_loop_k, last_k_block_tail); |
947 | jne(last_k_block_tail_done, T_NEAR); |
948 | |
949 | compute_m_loop(reg_k_src, reg_k_dst, last_k_block_tail); |
950 | jmp(compute_k_loop_done, T_NEAR); |
951 | |
952 | L(last_k_block_tail_done); |
953 | } |
954 | |
955 | L(compute_k_loop_done); |
956 | }; |
957 | |
958 | compute_k_loop(); |
959 | |
960 | postamble(); |
961 | } |
962 | |
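// Copies B for int8 problems into VNNI-4 layout: groups of k_blk_step = 4
// consecutive K rows are interleaved so each destination dword holds 4 bytes
// of the same column, n_blk_step = 64 columns at a time. Optionally
// accumulates per-column sums for s8s8 and/or zero-point compensation.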
963 | struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t, |
964 | public jit_generator { |
965 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_int8_t) |
966 | |
967 | jit_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf) |
968 | : jit_brgemm_matmul_copy_b_t(conf), jit_generator(jit_name()) {} |
969 | |
970 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
971 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
972 | |
973 | private: |
974 | using reg64_t = const Xbyak::Reg64; |
975 | using reg32_t = const Xbyak::Reg32; |
976 | using opmask_t = const Xbyak::Opmask; |
977 | using zmm = const Xbyak::Zmm; |
978 | using ymm = const Xbyak::Ymm; |
979 | |
980 | enum { typesize = sizeof(int8_t), k_blk_step = 4, n_blk_step = 64 }; |
981 | dim_t src_stride = 0, tr_src_stride = 0; |
982 | bool is_amx = false; |
983 | bool do_compute_compensation = false; |
984 | |
985 | opmask_t kTail = k7; |
986 | |
987 | reg64_t reg_src = rax; |
988 | reg64_t reg_tr_src = rbx; |
989 | reg64_t reg_comp_ptr = rdx; |
990 | reg64_t reg_zp_comp_ptr = r11; |
991 | reg64_t reg_zp_a_neg_val_ptr = r12; |
992 | |
993 | reg64_t reg_K_iters = r8; |
994 | reg64_t reg_N_blk = r9; |
995 | reg64_t reg_K_start = r10; |
996 | reg64_t regq_tmp = r14; |
997 | reg64_t imm_addr64 = r15; |
998 | |
999 | zmm vreg_idx_lo_256 = zmm26; |
1000 | zmm vreg_idx_hi_256 = zmm27; |
1001 | zmm vreg_idx_lo_128 = zmm28; |
1002 | zmm vreg_idx_hi_128 = zmm29; |
1003 | zmm zmm_comp_mul = zmm30; |
1004 | zmm zmm_zero = zmm31; |
1005 | |
1006 | Xbyak::Zmm get_comp_acc(int i) { return Xbyak::Zmm(25 - i); } |
1007 | Xbyak::Zmm get_zmm_zp_comp_res(int i) { return get_comp_acc(i); } |
1008 | Xbyak::Zmm get_zmm_oscale_comp_res(int i) { return Xbyak::Zmm(i); } |
1009 | void copy_4x64_vnni_avx512_core(int nrows, int ncolumns); |
1010 | void copy_4x64_vnni_amx(int nrows, int ncolumns); |
1011 | void copy_4x64_vnni(int nrows, int ncolumns); |
1012 | void generate() override; |
1013 | }; |
1014 | |
1015 | void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni(int nrows, int ncolumns) { |
1016 | if (is_amx) |
1017 | copy_4x64_vnni_amx(nrows, ncolumns); |
1018 | else |
1019 | copy_4x64_vnni_avx512_core(nrows, ncolumns); |
1020 | } |
1021 | |
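// AMX variant: the 4-row interleave is done directly with byte-granularity
// two-source permutes (vpermi2b) over the idx_{lo,hi}_{16,8} tables prepared
// in generate().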
1022 | void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni_amx( |
1023 | int nrows, int ncolumns) { |
1024 | auto kmovq = [=](Opmask k, size_t q) { |
1025 | mov(regq_tmp, q); |
1026 | jit_generator::kmovq(k, regq_tmp); |
1027 | }; |
1028 | |
    if (ncolumns < n_blk_step) {
        // note: the shift is well defined only for ncolumns < 64
        const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
        kmovq(kTail, tail_mask);
    }
1031 | |
1032 | const int blk_sz = 6; |
1033 | const int max_unroll = (do_compute_compensation ? 21 : 25) / blk_sz; |
1034 | auto get_zmm = [=](int blk, int idx) { |
1035 | assert(idx >= 0 && idx < blk_sz && blk >= 0); |
1036 | auto reg_idx = blk_sz * blk + idx; |
1037 | assert(reg_idx >= 0 && reg_idx < 32); |
1038 | return zmm(reg_idx); |
1039 | }; |
1040 | |
1041 | auto load = [=](int blk, int i) { |
1042 | auto src_reg = get_zmm(blk, i % k_blk_step); |
1043 | auto src_load = ncolumns < n_blk_step ? src_reg | kTail | T_z : src_reg; |
1044 | vmovdqu8(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
1045 | }; |
1046 | |
1047 | for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step); kb++) |
1048 | for (int k = 0; |
1049 | k < nstl::min(max_unroll, |
1050 | div_up(nrows - kb * max_unroll * k_blk_step, k_blk_step)); |
1051 | k++) { |
1052 | const int row_start = (kb * max_unroll + k) * k_blk_step; |
1053 | const int row_end = nstl::min(row_start + k_blk_step, nrows); |
1054 | |
1055 | for (int i = row_start; i < row_end; i++) |
1056 | load(k, i); |
1057 | if (row_end == nrows && nrows % k_blk_step > 0) { |
1058 | for (int i = nrows; i < rnd_up(nrows, k_blk_step); i++) { |
1059 | auto src_reg = get_zmm(k, i % k_blk_step); |
1060 | vpxord(src_reg, src_reg, src_reg); |
1061 | } |
1062 | } |
1063 | |
1064 | vmovups(get_zmm(k, 4), vreg_idx_lo_256); |
1065 | vpermi2b(get_zmm(k, 4), get_zmm(k, 0), get_zmm(k, 2)); |
1066 | vmovups(get_zmm(k, 5), vreg_idx_hi_256); |
1067 | vpermi2b(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 2)); |
1068 | vmovups(get_zmm(k, 0), vreg_idx_lo_256); |
1069 | vpermi2b(get_zmm(k, 0), get_zmm(k, 1), get_zmm(k, 3)); |
1070 | vmovups(get_zmm(k, 2), vreg_idx_hi_256); |
1071 | vpermi2b(get_zmm(k, 2), get_zmm(k, 1), get_zmm(k, 3)); |
1072 | |
1073 | vmovups(get_zmm(k, 1), vreg_idx_lo_128); |
1074 | vpermi2b(get_zmm(k, 1), get_zmm(k, 4), get_zmm(k, 0)); |
1075 | dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride; |
1076 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base), get_zmm(k, 1)); |
1077 | if (do_compute_compensation) |
1078 | vpdpbusd(get_comp_acc(0), zmm_comp_mul, get_zmm(k, 1)); |
1079 | |
1080 | if (ncolumns > 16) { |
1081 | vmovups(get_zmm(k, 3), vreg_idx_hi_128); |
1082 | vpermi2b(get_zmm(k, 3), get_zmm(k, 4), get_zmm(k, 0)); |
1083 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64), |
1084 | get_zmm(k, 3)); |
1085 | if (do_compute_compensation) |
1086 | vpdpbusd(get_comp_acc(1), zmm_comp_mul, get_zmm(k, 3)); |
1087 | } else if (conf_->wei_n_blk > 16) { |
1088 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64), |
1089 | zmm_zero); |
1090 | } |
1091 | |
1092 | if (ncolumns > 32) { |
1093 | vmovups(get_zmm(k, 4), vreg_idx_lo_128); |
1094 | vpermi2b(get_zmm(k, 4), get_zmm(k, 5), get_zmm(k, 2)); |
1095 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128), |
1096 | get_zmm(k, 4)); |
1097 | if (do_compute_compensation) |
1098 | vpdpbusd(get_comp_acc(2), zmm_comp_mul, get_zmm(k, 4)); |
1099 | } else if (conf_->wei_n_blk > 32) { |
1100 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128), |
1101 | zmm_zero); |
1102 | } |
1103 | |
1104 | if (ncolumns > 48) { |
1105 | vmovups(get_zmm(k, 0), vreg_idx_hi_128); |
1106 | vpermi2b(get_zmm(k, 0), get_zmm(k, 5), get_zmm(k, 2)); |
1107 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192), |
1108 | get_zmm(k, 0)); |
1109 | if (do_compute_compensation) |
1110 | vpdpbusd(get_comp_acc(3), zmm_comp_mul, get_zmm(k, 0)); |
1111 | } else if (conf_->wei_n_blk > 48) { |
1112 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192), |
1113 | zmm_zero); |
1114 | } |
1115 | } |
1116 | } |
1117 | |
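// AVX-512 core variant: vpermi2b (AVX512_VBMI) is not available, so the
// interleave is built from byte/word unpacks followed by qword permutes over
// the 256- and 128-bit index tables.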
1118 | void jit_brgemm_matmul_copy_b_int8_t::copy_4x64_vnni_avx512_core( |
1119 | int nrows, int ncolumns) { |
1120 | auto kmovq = [=](Opmask k, size_t q) { |
1121 | mov(regq_tmp, q); |
1122 | jit_generator::kmovq(k, regq_tmp); |
1123 | }; |
1124 | |
    if (ncolumns < n_blk_step) {
        // note: the shift is well defined only for ncolumns < 64
        const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
        kmovq(kTail, tail_mask);
    }
1127 | |
1128 | const int blk_sz = 6; |
1129 | const int max_unroll = (do_compute_compensation ? 21 : 25) / blk_sz; |
1130 | auto get_zmm = [=](int blk, int idx) { |
1131 | assert(idx >= 0 && idx < blk_sz && blk >= 0); |
1132 | auto reg_idx = blk_sz * blk + idx; |
1133 | assert(reg_idx >= 0 && reg_idx < 32); |
1134 | return zmm(reg_idx); |
1135 | }; |
1136 | auto load = [=](int blk, int i) { |
1137 | auto src_reg = get_zmm(blk, i % k_blk_step); |
1138 | auto src_load = ncolumns < n_blk_step ? src_reg | kTail | T_z : src_reg; |
1139 | vmovdqu8(src_load, EVEX_compress_addr(reg_src, i * src_stride)); |
1140 | }; |
1141 | |
1142 | for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step); kb++) |
1143 | for (int k = 0; |
1144 | k < nstl::min(max_unroll, |
1145 | div_up(nrows - kb * max_unroll * k_blk_step, k_blk_step)); |
1146 | k++) { |
1147 | const int row_start = (kb * max_unroll + k) * k_blk_step; |
1148 | const int row_end = nstl::min(row_start + k_blk_step, nrows); |
1149 | |
1150 | for (int i = row_start; i < row_end; i++) |
1151 | load(k, i); |
1152 | if (row_end == nrows && nrows % k_blk_step > 0) { |
1153 | for (int i = nrows; i < rnd_up(nrows, k_blk_step); i++) { |
1154 | auto src_reg = get_zmm(k, i % k_blk_step); |
1155 | vpxord(src_reg, src_reg, src_reg); |
1156 | } |
1157 | } |
1158 | |
1159 | vpunpcklbw(get_zmm(k, 4), get_zmm(k, 0), get_zmm(k, 1)); |
1160 | vpunpckhbw(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 1)); |
1161 | vpunpcklbw(get_zmm(k, 0), get_zmm(k, 2), get_zmm(k, 3)); |
1162 | vpunpckhbw(get_zmm(k, 1), get_zmm(k, 2), get_zmm(k, 3)); |
1163 | |
1164 | vpunpcklwd(get_zmm(k, 2), get_zmm(k, 4), get_zmm(k, 0)); |
1165 | vpunpckhwd(get_zmm(k, 3), get_zmm(k, 4), get_zmm(k, 0)); |
1166 | vpunpcklwd(get_zmm(k, 4), get_zmm(k, 5), get_zmm(k, 1)); |
1167 | vpunpckhwd(get_zmm(k, 5), get_zmm(k, 5), get_zmm(k, 1)); |
1168 | |
1169 | vmovups(get_zmm(k, 0), vreg_idx_lo_256); |
1170 | vpermi2q(get_zmm(k, 0), get_zmm(k, 2), get_zmm(k, 4)); |
1171 | vmovups(get_zmm(k, 1), vreg_idx_hi_256); |
1172 | vpermi2q(get_zmm(k, 1), get_zmm(k, 2), get_zmm(k, 4)); |
1173 | vmovups(get_zmm(k, 2), vreg_idx_lo_256); |
1174 | vpermi2q(get_zmm(k, 2), get_zmm(k, 3), get_zmm(k, 5)); |
1175 | vmovups(get_zmm(k, 4), vreg_idx_hi_256); |
1176 | vpermi2q(get_zmm(k, 4), get_zmm(k, 3), get_zmm(k, 5)); |
1177 | |
1178 | vmovups(get_zmm(k, 3), vreg_idx_lo_128); |
1179 | vpermi2q(get_zmm(k, 3), get_zmm(k, 0), get_zmm(k, 2)); |
1180 | dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride; |
1181 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base), get_zmm(k, 3)); |
1182 | if (do_compute_compensation) |
1183 | vpdpbusd(get_comp_acc(0), zmm_comp_mul, get_zmm(k, 3)); |
1184 | |
1185 | if (ncolumns > 16) { |
1186 | vmovups(get_zmm(k, 5), vreg_idx_hi_128); |
1187 | vpermi2q(get_zmm(k, 5), get_zmm(k, 0), get_zmm(k, 2)); |
1188 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64), |
1189 | get_zmm(k, 5)); |
1190 | if (do_compute_compensation) |
1191 | vpdpbusd(get_comp_acc(1), zmm_comp_mul, get_zmm(k, 5)); |
1192 | } else if (conf_->wei_n_blk > 16) { |
1193 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64), |
1194 | zmm_zero); |
1195 | } |
1196 | |
1197 | if (ncolumns > 32) { |
1198 | vmovups(get_zmm(k, 0), vreg_idx_lo_128); |
1199 | vpermi2q(get_zmm(k, 0), get_zmm(k, 1), get_zmm(k, 4)); |
1200 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128), |
1201 | get_zmm(k, 0)); |
1202 | if (do_compute_compensation) |
1203 | vpdpbusd(get_comp_acc(2), zmm_comp_mul, get_zmm(k, 0)); |
1204 | } else if (conf_->wei_n_blk > 32) { |
1205 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128), |
1206 | zmm_zero); |
1207 | } |
1208 | |
1209 | if (ncolumns > 48) { |
1210 | vmovups(get_zmm(k, 2), vreg_idx_hi_128); |
1211 | vpermi2q(get_zmm(k, 2), get_zmm(k, 1), get_zmm(k, 4)); |
1212 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192), |
1213 | get_zmm(k, 2)); |
1214 | if (do_compute_compensation) |
1215 | vpdpbusd(get_comp_acc(3), zmm_comp_mul, get_zmm(k, 2)); |
1216 | } else if (conf_->wei_n_blk > 48) { |
1217 | vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192), |
1218 | zmm_zero); |
1219 | } |
1220 | } |
1221 | } |
1222 | |
1223 | void jit_brgemm_matmul_copy_b_int8_t::generate() { |
1224 | preamble(); |
1225 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1226 | src_stride = (conf_->wei_tag == format_tag::acbd ? conf_->copy_B_wei_stride |
1227 | : conf_->N * typesize); |
1228 | tr_src_stride = conf_->LDB * k_blk_step * typesize; |
1229 | is_amx = mayiuse(avx512_core_amx); |
1230 | do_compute_compensation |
1231 | = conf_->s8s8_compensation_required || conf_->has_zero_point_a; |
1232 | |
1233 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1234 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1235 | mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); |
1236 | mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); |
1237 | |
1238 | auto vmovdqa64 = [=](Zmm z, const void *addr) { |
1239 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
1240 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
1241 | }; |
1242 | |
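    // Permutation tables: the int64_t qword indices drive the vpermi2q
    // shuffles of the AVX-512 core path, while the uint8_t byte indices are
    // their vpermi2b counterparts used on AMX-capable machines.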
1243 | alignas(64) static constexpr const int64_t idx_lo_256[8] |
1244 | = {0, 1, 2, 3, 8, 9, 10, 11}; |
1245 | alignas(64) static constexpr const int64_t idx_hi_256[8] |
1246 | = {4, 5, 6, 7, 12, 13, 14, 15}; |
1247 | |
1248 | alignas(64) static constexpr const int64_t idx_lo_128[8] |
1249 | = {0, 1, 8, 9, 4, 5, 12, 13}; |
1250 | alignas(64) static constexpr const int64_t idx_hi_128[8] |
1251 | = {2, 3, 10, 11, 6, 7, 14, 15}; |
1252 | alignas(64) static constexpr const uint8_t idx_lo_16[64] |
1253 | = {0, 1, 64, 65, 4, 5, 68, 69, 2, 3, 66, 67, 6, 7, 70, 71, 8, 9, 72, |
1254 | 73, 12, 13, 76, 77, 10, 11, 74, 75, 14, 15, 78, 79, 16, 17, |
1255 | 80, 81, 20, 21, 84, 85, 18, 19, 82, 83, 22, 23, 86, 87, 24, |
1256 | 25, 88, 89, 28, 29, 92, 93, 26, 27, 90, 91, 30, 31, 94, 95}; |
1257 | |
1258 | alignas(64) static constexpr const uint8_t idx_hi_16[64] = {32, 33, 96, 97, |
1259 | 36, 37, 100, 101, 34, 35, 98, 99, 38, 39, 102, 103, 40, 41, 104, |
1260 | 105, 44, 45, 108, 109, 42, 43, 106, 107, 46, 47, 110, 111, 48, 49, |
1261 | 112, 113, 52, 53, 116, 117, 50, 51, 114, 115, 54, 55, 118, 119, 56, |
1262 | 57, 120, 121, 60, 61, 124, 125, 58, 59, 122, 123, 62, 63, 126, 127}; |
1263 | |
1264 | alignas(64) static constexpr const uint8_t idx_lo_8[64] |
1265 | = {0, 64, 2, 66, 1, 65, 3, 67, 8, 72, 10, 74, 9, 73, 11, 75, 4, 68, |
1266 | 6, 70, 5, 69, 7, 71, 12, 76, 14, 78, 13, 77, 15, 79, 16, 80, |
1267 | 18, 82, 17, 81, 19, 83, 24, 88, 26, 90, 25, 89, 27, 91, 20, |
1268 | 84, 22, 86, 21, 85, 23, 87, 28, 92, 30, 94, 29, 93, 31, 95}; |
1269 | |
1270 | alignas(64) static constexpr const uint8_t idx_hi_8[64] = {32, 96, 34, 98, |
1271 | 33, 97, 35, 99, 40, 104, 42, 106, 41, 105, 43, 107, 36, 100, 38, |
1272 | 102, 37, 101, 39, 103, 44, 108, 46, 110, 45, 109, 47, 111, 48, 112, |
1273 | 50, 114, 49, 113, 51, 115, 56, 120, 58, 122, 57, 121, 59, 123, 52, |
1274 | 116, 54, 118, 53, 117, 55, 119, 60, 124, 62, 126, 61, 125, 63, 127}; |
1275 | |
1276 | vmovdqa64(vreg_idx_lo_256, |
1277 | is_amx ? (const void *)idx_lo_16 : (const void *)idx_lo_256); |
1278 | vmovdqa64(vreg_idx_hi_256, |
1279 | is_amx ? (const void *)idx_hi_16 : (const void *)idx_hi_256); |
1280 | vmovdqa64(vreg_idx_lo_128, |
1281 | is_amx ? (const void *)idx_lo_8 : (const void *)idx_lo_128); |
1282 | vmovdqa64(vreg_idx_hi_128, |
1283 | is_amx ? (const void *)idx_hi_8 : (const void *)idx_hi_128); |
1284 | |
1285 | if (do_compute_compensation) { |
1286 | int n_iters = div_up(conf_->wei_n_blk, 16); |
1287 | for (int i = 0; i < n_iters; i++) |
1288 | vpxord(get_comp_acc(i), get_comp_acc(i), get_comp_acc(i)); |
1289 | mov(imm_addr64, 1); |
1290 | vpbroadcastb(zmm_comp_mul, imm_addr64.cvt8()); |
1291 | } |
1292 | |
1293 | auto compute_K_loop = [=](bool is_N_tail) { |
1294 | const int k_unroll = 4; |
1295 | int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk; |
1296 | |
1297 | Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done; |
1298 | cmp(reg_K_iters, k_unroll * k_blk_step); |
1299 | jl(K_loop_single, T_NEAR); |
1300 | |
1301 | L(K_loop_unrolled); |
1302 | copy_4x64_vnni(k_unroll * k_blk_step, ncolumns); |
1303 | add(reg_src, k_unroll * k_blk_step * src_stride); |
1304 | add(reg_tr_src, k_unroll * tr_src_stride); |
1305 | |
1306 | sub(reg_K_iters, k_unroll * k_blk_step); |
1307 | cmp(reg_K_iters, k_unroll * k_blk_step); |
1308 | jge(K_loop_unrolled, T_NEAR); |
1309 | |
1310 | L(K_loop_single); |
1311 | cmp(reg_K_iters, k_blk_step); |
1312 | jl(K_loop_tail_or_done, T_NEAR); |
1313 | |
1314 | copy_4x64_vnni(k_blk_step, ncolumns); |
1315 | add(reg_src, k_blk_step * src_stride); |
1316 | add(reg_tr_src, tr_src_stride); |
1317 | |
1318 | sub(reg_K_iters, k_blk_step); |
1319 | jmp(K_loop_single, T_NEAR); |
1320 | |
1321 | L(K_loop_tail_or_done); |
1322 | |
1323 | int k_blk_tail = conf_->K % k_blk_step; |
1324 | if (k_blk_tail > 0) { |
1325 | Label K_loop_done; |
1326 | cmp(reg_K_iters, 0); |
1327 | jle(K_loop_done, T_NEAR); |
1328 | |
1329 | copy_4x64_vnni(k_blk_tail, ncolumns); |
1330 | sub(reg_K_iters, k_blk_tail); |
1331 | L(K_loop_done); |
1332 | } |
1333 | }; |
1334 | |
1335 | Label done; |
1336 | if (conf_->N_tail > 0) { |
1337 | Label not_N_tail; |
1338 | cmp(reg_N_blk, conf_->N_tail); |
1339 | jne(not_N_tail, T_NEAR); |
1340 | compute_K_loop(true); |
1341 | jmp(done, T_NEAR); |
1342 | |
1343 | L(not_N_tail); |
1344 | } |
1345 | |
1346 | compute_K_loop(false); |
1347 | L(done); |
1348 | |
1349 | if (do_compute_compensation) { |
1350 | const bool req_s8s8_comp = conf_->s8s8_compensation_required; |
1351 | const bool req_zp_comp = conf_->has_zero_point_a; |
1352 | int n_iters = div_up(conf_->wei_n_blk, 16); |
1353 | assert(IMPLICATION(req_zp_comp, |
1354 | conf_->src_zp_type == brgemm_broadcast_t::per_tensor)); |
1355 | |
1356 | // copy 'comp_acc' into s8s8_comp accumulator |
1357 | if (req_s8s8_comp) { |
1358 | for (int i = 0; i < n_iters; i++) |
1359 | vmovups(get_zmm_oscale_comp_res(i), get_comp_acc(i)); |
1360 | } |
1361 | |
1362 | Label skip_acc, store; |
1363 | if (req_s8s8_comp) |
1364 | mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]); |
1365 | if (req_zp_comp) |
1366 | mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]); |
1367 | |
1368 | mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]); |
1369 | cmp(reg_K_start, 0); |
1370 | je(skip_acc, T_NEAR); |
1371 | if (req_s8s8_comp) { |
1372 | for (int i = 0; i < n_iters; i++) { |
1373 | const auto zmm_acc = get_comp_acc(i); |
1374 | const auto zmm_res = get_zmm_oscale_comp_res(i); |
1375 | const auto addr = EVEX_compress_addr(reg_comp_ptr, i * 64); |
1376 | vpaddd(zmm_res, zmm_acc, addr); |
1377 | } |
1378 | } |
1379 | |
1380 | if (req_zp_comp) { |
1381 | for (int i = 0; i < n_iters; i++) { |
1382 | const auto zmm_acc = get_comp_acc(i); |
1383 | const auto zmm_res = get_zmm_zp_comp_res(i); |
1384 | const auto addr = EVEX_compress_addr(reg_zp_comp_ptr, i * 64); |
1385 | vpaddd(zmm_res, zmm_acc, addr); |
1386 | } |
1387 | } |
1388 | |
1389 | L(skip_acc); |
1390 | cmp(reg_K_start, rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk); |
1391 | jl(store, T_NEAR); |
1392 | |
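        // Finalize only on the last K chunk: s8s8 compensation becomes
        // -128 * column_sum (shift left by 7, then two's-complement negate);
        // zero-point compensation is scaled by the negative zp_a value.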
1393 | if (req_s8s8_comp) { |
1394 | mov(imm_addr64, 0xffffffff); |
1395 | const auto zmm_all_bits_1 = zmm_comp_mul; |
1396 | vpbroadcastd(zmm_all_bits_1, imm_addr64.cvt32()); |
1397 | mov(imm_addr64, 0x1); |
1398 | const auto zmm_one_s32 = zmm_zero; |
1399 | vpbroadcastd(zmm_one_s32, imm_addr64.cvt32()); |
1400 | |
1401 | for (int i = 0; i < n_iters; i++) { |
1402 | const auto zmm_res = get_zmm_oscale_comp_res(i); |
1403 | // multiply by 128 |
1404 | vpslld(zmm_res, zmm_res, 7); |
1405 | // change sign |
1406 | vpandnq(zmm_res, zmm_res, zmm_all_bits_1); |
1407 | vpaddd(zmm_res, zmm_res, zmm_one_s32); |
1408 | } |
1409 | } |
1410 | |
1411 | if (req_zp_comp) { |
1412 | mov(reg_zp_a_neg_val_ptr, |
1413 | ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]); |
1414 | const auto zmm_zp_a_neg_val = vreg_idx_hi_128; |
1415 | vbroadcastss(zmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]); |
1416 | |
1417 | for (int i = 0; i < n_iters; i++) { |
1418 | const auto zmm_res = get_zmm_zp_comp_res(i); |
1419 | vpmulld(zmm_res, zmm_res, zmm_zp_a_neg_val); |
1420 | } |
1421 | } |
1422 | |
1423 | L(store); |
1424 | if (req_s8s8_comp) { |
1425 | for (int i = 0; i < n_iters; i++) { |
1426 | const auto zmm_res = get_zmm_oscale_comp_res(i); |
1427 | const auto addr = EVEX_compress_addr(reg_comp_ptr, i * 64); |
1428 | vmovups(addr, zmm_res); |
1429 | } |
1430 | } |
1431 | if (req_zp_comp) { |
1432 | for (int i = 0; i < n_iters; i++) { |
1433 | const auto zmm_res = get_zmm_zp_comp_res(i); |
1434 | const auto addr = EVEX_compress_addr(reg_zp_comp_ptr, i * 64); |
1435 | vmovups(addr, zmm_res); |
1436 | } |
1437 | } |
1438 | } |
1439 | |
1440 | postamble(); |
1441 | } |
1442 | |
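// Copies B for bf16 problems (including bf32, i.e. f32 data downconverted to
// bf16 on the fly) into VNNI-2 layout: pairs of K rows are interleaved with
// vpermw so each destination dword holds 2 bf16 values of the same column,
// 16 columns per zmm.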
1443 | struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t, |
1444 | public jit_generator { |
1445 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_bf16_t) |
1446 | |
1447 | jit_brgemm_matmul_copy_b_bf16_t(const brgemm_matmul_conf_t *conf) |
1448 | : jit_brgemm_matmul_copy_b_t(conf) |
1449 | , jit_generator(jit_name()) |
1450 | , typesize(conf->b_dt_sz) |
1451 | , tr_typesize(conf->tr_b_dt_sz) |
1452 | , src_stride(conf_->wei_tag == format_tag::acbd |
1453 | ? conf->copy_B_wei_stride |
1454 | : conf->req_wei_vnni_downconvert |
1455 | ? conf_->LDB * typesize |
1456 | : conf_->N * typesize) |
1457 | , tr_src_stride(conf_->LDB * k_blk_step * tr_typesize) {} |
1458 | |
1459 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1460 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1461 | |
1462 | private: |
1463 | using reg64_t = const Xbyak::Reg64; |
1464 | using reg32_t = const Xbyak::Reg32; |
1465 | using opmask_t = const Xbyak::Opmask; |
1466 | using zmm = const Xbyak::Zmm; |
1467 | using ymm = const Xbyak::Ymm; |
1468 | |
1469 | enum { k_blk_step = 2, n_blk_step = 16 }; |
1470 | const int typesize, tr_typesize; |
1471 | const dim_t src_stride, tr_src_stride; |
1472 | |
1473 | opmask_t kTail = k7; |
1474 | opmask_t kFFFF = k6; |
1475 | |
1476 | reg64_t reg_src = rax; |
1477 | reg64_t reg_tr_src = rbx; |
1478 | |
1479 | reg64_t reg_K_iters = r8; |
1480 | reg64_t reg_N_blk = r9; |
1481 | reg64_t reg_K_start = r10; |
1482 | reg32_t regw_tmp = r14d; |
1483 | reg64_t imm_addr64 = r15; |
1484 | |
1485 | zmm zmm_permw = zmm30; |
1486 | zmm zmm_zero = zmm31; |
1487 | |
1488 | void copy_2x32_vnni(int nrows, int ncolumns); |
1489 | void generate() override; |
1490 | }; |
1491 | |
1492 | void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32_vnni(int nrows, int ncolumns) { |
1493 | |
1494 | auto kmovx = [=](Opmask k, unsigned w) { |
1495 | mov(regw_tmp, w); |
1496 | if (conf_->is_bf32) |
1497 | jit_generator::kmovw(k, regw_tmp); |
1498 | else |
1499 | jit_generator::kmovd(k, regw_tmp); |
1500 | }; |
1501 | |
    const int columns_tail = ncolumns % n_blk_step;
    if (columns_tail > 0) {
        const auto tail_mask = (1 << columns_tail) - 1;
        kmovx(kTail, tail_mask);
    }
1505 | |
1506 | const int blk_sz = k_blk_step; |
1507 | const int max_regs_available = 30; |
1508 | const int max_unroll = max_regs_available / blk_sz; |
1509 | auto get_zmm = [=](int blk, int idx) { |
1510 | assert(idx >= 0 && idx < blk_sz && blk >= 0); |
1511 | auto reg_idx = max_unroll * ((idx + 1) % blk_sz) + blk; |
1512 | assert(reg_idx >= 0 && reg_idx < max_regs_available); |
1513 | return zmm(reg_idx); |
1514 | }; |
1515 | |
1516 | auto load = [=](int blk, int k, int n, opmask_t current_mask) { |
1517 | auto src_reg = get_zmm(blk, k % k_blk_step); |
1518 | auto src_load = src_reg | current_mask | T_z; |
1519 | auto load_addr |
1520 | = EVEX_compress_addr(reg_src, k * src_stride + n * typesize); |
1521 | if (conf_->is_bf32) { |
1522 | vmovups(src_load, load_addr); |
1523 | } else { |
1524 | vmovdqu16(src_load, load_addr); |
1525 | } |
1526 | }; |
1527 | |
1528 | int iter = 0; |
1529 | for_(int k = 0; k < nrows; k += k_blk_step) |
1530 | for (int n = 0; n < conf_->wei_n_blk; n += n_blk_step) { |
1531 | const int k_blk = k / k_blk_step; |
1532 | const dim_t tr_src_off |
1533 | = k_blk * tr_src_stride + n * k_blk_step * tr_typesize; |
1534 | const auto store_addr = EVEX_compress_addr(reg_tr_src, tr_src_off); |
1535 | if (ncolumns - n <= 0) { |
1536 | vmovups(store_addr, zmm_zero); |
1537 | continue; |
1538 | } |
1539 | |
1540 | const opmask_t curr_msk = ncolumns - n < n_blk_step ? kTail : kFFFF; |
1541 | const int blk_idx = iter % max_unroll; |
1542 | load(blk_idx, k, n, curr_msk); |
1543 | |
1544 | const auto src_zmm0 = get_zmm(blk_idx, 0); |
1545 | if (nrows - k >= k_blk_step) { |
1546 | load(blk_idx, k + 1, n, curr_msk); |
1547 | const auto src_zmm1 = get_zmm(blk_idx, 1); |
1548 | if (conf_->is_bf32) { |
1549 | vcvtne2ps2bf16(src_zmm0, src_zmm1, src_zmm0); |
1550 | } else { |
1551 | const auto src_ymm1 = ymm(src_zmm1.getIdx()); |
1552 | vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1); |
1553 | } |
1554 | } else if (conf_->is_bf32) { |
1555 | vcvtneps2bf16(ymm(src_zmm0.getIdx()), src_zmm0); |
1556 | } |
1557 | |
1558 | vpermw(src_zmm0, zmm_permw, src_zmm0); |
1559 | |
1560 | vmovups(store_addr, src_zmm0); |
1561 | iter++; |
1562 | } |
1563 | } |
1564 | |
1565 | void jit_brgemm_matmul_copy_b_bf16_t::generate() { |
1566 | assert(tr_typesize == sizeof(bfloat16_t)); |
1567 | preamble(); |
1568 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1569 | |
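    // vnni_2 interleave indices for vpermw: word i of the low half (row k)
    // is paired with word 16 + i of the high half (row k + 1), producing
    // {B(k, 0), B(k + 1, 0), B(k, 1), B(k + 1, 1), ...}.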
1570 | alignas(64) static constexpr const int16_t bf16_vnni_permute[32] |
1571 | = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, |
1572 | 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
1573 | |
1574 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1575 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1576 | mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); |
1577 | mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); |
1578 | |
1579 | kxnorw(kFFFF, kFFFF, kFFFF); // 1111 1111 1111 1111 |
1580 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
1581 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
1582 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
1583 | }; |
1584 | |
1585 | vmovdqa64(zmm_permw, (const int64_t *)bf16_vnni_permute); |
1586 | |
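    // The K loop is emitted in three tiers: a main loop unrolled by k_unroll
    // (8 * k_blk_step = 16 rows per iteration), a single-step loop handling
    // k_blk_step = 2 rows at a time and, when conf_->K is odd, a 1-row tail
    // whose missing second row reads as zeros.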
1587 | auto compute_K_loop = [=](bool is_N_tail) { |
1588 | const int k_unroll = 8; |
1589 | int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk; |
1590 | |
1591 | Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done; |
1592 | cmp(reg_K_iters, k_unroll * k_blk_step); |
1593 | jl(K_loop_single, T_NEAR); |
1594 | |
1595 | L(K_loop_unrolled); |
1596 | copy_2x32_vnni(k_unroll * k_blk_step, ncolumns); |
1597 | add(reg_src, k_unroll * k_blk_step * src_stride); |
1598 | add(reg_tr_src, k_unroll * tr_src_stride); |
1599 | |
1600 | sub(reg_K_iters, k_unroll * k_blk_step); |
1601 | cmp(reg_K_iters, k_unroll * k_blk_step); |
1602 | jge(K_loop_unrolled, T_NEAR); |
1603 | |
1604 | L(K_loop_single); |
1605 | cmp(reg_K_iters, k_blk_step); |
1606 | jl(K_loop_tail_or_done, T_NEAR); |
1607 | |
1608 | copy_2x32_vnni(k_blk_step, ncolumns); |
1609 | add(reg_src, k_blk_step * src_stride); |
1610 | add(reg_tr_src, tr_src_stride); |
1611 | |
1612 | sub(reg_K_iters, k_blk_step); |
1613 | jmp(K_loop_single, T_NEAR); |
1614 | |
1615 | L(K_loop_tail_or_done); |
1616 | |
1617 | int k_blk_tail = conf_->K % k_blk_step; |
1618 | if (k_blk_tail > 0) { |
1619 | Label K_loop_done; |
1620 | cmp(reg_K_iters, 0); |
1621 | jle(K_loop_done, T_NEAR); |
1622 | |
1623 | copy_2x32_vnni(k_blk_tail, ncolumns); |
1624 | sub(reg_K_iters, k_blk_tail); |
1625 | L(K_loop_done); |
1626 | } |
1627 | }; |
1628 | |
1629 | Label done; |
1630 | if (conf_->N_tail > 0) { |
1631 | Label not_N_tail; |
1632 | cmp(reg_N_blk, conf_->N_tail); |
1633 | jne(not_N_tail, T_NEAR); |
1634 | compute_K_loop(true); |
1635 | jmp(done, T_NEAR); |
1636 | |
1637 | L(not_N_tail); |
1638 | } |
1639 | |
1640 | compute_K_loop(false); |
1641 | L(done); |
1642 | |
1643 | postamble(); |
1644 | } |
1645 | |
1646 | struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, |
1647 | public jit_generator { |
1648 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_f32_t) |
1649 | |
1650 | jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf) |
1651 | : jit_brgemm_matmul_copy_b_t(conf) |
1652 | , jit_generator(jit_name()) |
1653 | , dt_in_(conf->isa == avx512_core_fp16 ? data_type::f16 |
1654 | : data_type::f32) |
1655 | , typesize_in_(types::data_type_size(dt_in_)) |
1656 | , src_stride_(conf_->wei_tag == acbd ? conf_->copy_B_wei_stride |
1657 | : conf_->N * typesize_in_) |
1658 | , tr_src_stride_(conf_->LDB * typesize_out_) {} |
1659 | |
1660 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1661 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1662 | |
1663 | private: |
1664 | using reg64_t = const Xbyak::Reg64; |
1665 | using reg32_t = const Xbyak::Reg32; |
1666 | using opmask_t = const Xbyak::Opmask; |
1667 | using zmm = const Xbyak::Zmm; |
1668 | |
1669 | enum { n_blk_step = 16, max_regs_available = 30 }; |
1670 | const data_type_t dt_in_; |
1671 | const size_t typesize_in_; |
1672 | const size_t typesize_out_ = sizeof(float); |
1673 | dim_t src_stride_, tr_src_stride_; |
1674 | |
1675 | opmask_t kTail = k7; |
1676 | opmask_t kFFFF = k6; |
1677 | |
1678 | reg64_t reg_src = rax; |
1679 | reg64_t reg_tr_src = rbx; |
1680 | |
1681 | reg64_t reg_K_iters = r8; |
1682 | reg64_t reg_N_blk = r9; |
1683 | reg64_t reg_K_start = r10; |
1684 | reg32_t regw_tmp = r14d; |
1685 | reg64_t imm_addr64 = r15; |
1686 | |
1687 | zmm zmm_permw = zmm30; |
1688 | zmm zmm_zero = zmm31; |
1689 | |
    // Loads a 16-bit mask; kmovd also clears the unused upper mask bits.
    inline void kmovw(Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    }
1694 | void copy_16_x_n_block(int nrows, int ncolumns); |
1695 | void compute_k_loop(int ncolumns); |
1696 | void generate() override; |
1697 | }; |
1698 | |
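// Plain row-major copy of an (nrows x ncolumns) block of B in 16-column zmm
// chunks; an f16 source is up-converted to f32 on load via vcvtph2psx.
// Column blocks at or past ncolumns are stored as zeros to pad the block up
// to conf_->wei_n_blk.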
1699 | void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block( |
1700 | int nrows, int ncolumns) { |
1701 | |
1702 | auto get_zmm = [=](int reg_idx) { |
1703 | assert(reg_idx >= 0 && reg_idx < max_regs_available); |
1704 | return zmm(reg_idx); |
1705 | }; |
1706 | |
1707 | auto load = [=](int blk, int k, int n, opmask_t current_mask) { |
1708 | auto src_zmm = get_zmm(blk); |
1709 | auto src_zmm_m = src_zmm | current_mask | T_z; |
1710 | auto addr = EVEX_compress_addr( |
1711 | reg_src, k * src_stride_ + n * typesize_in_); |
1712 | if (dt_in_ == data_type::f16) |
1713 | vcvtph2psx(src_zmm_m, addr); |
1714 | else |
1715 | vmovups(src_zmm_m, addr); |
1716 | }; |
1717 | |
1718 | const int columns_tail = ncolumns % n_blk_step; |
1719 | const auto tail_mask = (1 << columns_tail) - 1; |
    if (columns_tail > 0) kmovw(kTail, tail_mask);
1721 | |
1722 | int iter = 0; |
1723 | for_(int k = 0; k < nrows; k++) |
1724 | for (int n = 0; n < conf_->wei_n_blk; n += n_blk_step) { |
1725 | const dim_t tr_src_off = k * tr_src_stride_ + n * typesize_out_; |
1726 | const auto store_addr = EVEX_compress_addr(reg_tr_src, tr_src_off); |
1727 | |
1728 | const int zero_padding = ncolumns - n; |
1729 | if (zero_padding <= 0) { |
1730 | vmovups(store_addr, zmm_zero); |
1731 | continue; |
1732 | } |
1733 | |
1734 | const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : kFFFF; |
1735 | const int blk_idx = iter % max_regs_available; |
1736 | load(blk_idx, k, n, curr_msk); |
1737 | |
1738 | const auto src_zmm0 = get_zmm(blk_idx); |
1739 | vmovups(store_addr, src_zmm0); |
1740 | iter++; |
1741 | } |
1742 | } |
1743 | |
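// Consumes reg_K_iters in two passes over the same loop emitter: first in
// chunks of k_unroll = 16 rows, then row by row for the remainder.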
1744 | void jit_brgemm_matmul_copy_b_f32_t::compute_k_loop(int ncolumns) { |
1745 | |
1746 | auto compute_uni_k_loop = [&](int unroll) { |
1747 | Label K_start_label, K_end_label; |
1748 | |
1749 | L(K_start_label); |
1750 | cmp(reg_K_iters, unroll); |
1751 | jl(K_end_label, T_NEAR); |
1752 | |
1753 | copy_16_x_n_block(unroll, ncolumns); |
1754 | add(reg_src, unroll * src_stride_); |
1755 | add(reg_tr_src, unroll * tr_src_stride_); |
1756 | |
1757 | sub(reg_K_iters, unroll); |
1758 | jmp(K_start_label, T_NEAR); |
1759 | |
1760 | L(K_end_label); |
1761 | }; |
1762 | |
1763 | constexpr int k_unroll = 16; |
1764 | compute_uni_k_loop(k_unroll); |
1765 | compute_uni_k_loop(1); |
1766 | } |
1767 | |
1768 | void jit_brgemm_matmul_copy_b_f32_t::generate() { |
1769 | preamble(); |
1770 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
1771 | |
1772 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
1773 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
1774 | mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); |
1775 | mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); |
1776 | kmovw(kFFFF, 0xffff); // 1111111111111111 |
1777 | |
1778 | Label done; |
1779 | if (conf_->N_tail > 0) { |
1780 | Label not_N_tail; |
1781 | cmp(reg_N_blk, conf_->N_tail); |
1782 | jne(not_N_tail, T_NEAR); |
1783 | compute_k_loop(conf_->N_tail); |
1784 | jmp(done, T_NEAR); |
1785 | |
1786 | L(not_N_tail); |
1787 | } |
1788 | |
1789 | compute_k_loop(conf_->N_blk); |
1790 | L(done); |
1791 | |
1792 | postamble(); |
1793 | } |
1794 | |
1795 | struct jit_brgemm_matmul_copy_b_transposed_t |
1796 | : public jit_brgemm_matmul_copy_b_t, |
1797 | public jit_generator { |
1798 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_transposed_t) |
1799 | |
1800 | jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf) |
1801 | : jit_brgemm_matmul_copy_b_t(conf) |
1802 | , jit_generator(jit_name()) |
1803 | , typesize(conf_->b_dt_sz) |
1804 | , tr_typesize(conf_->tr_b_dt_sz) |
1805 | , vnni_granularity(data_type_vnni_granularity(conf_->wei_dt)) |
1806 | , k_blk_step(bytes_in_zmm / tr_typesize) |
1807 | , do_compute_compensation( |
1808 | conf_->has_zero_point_a || conf_->s8s8_compensation_required) |
1809 | , is_bf32(conf->is_bf32) |
1810 | , req_zp_comp(conf_->has_zero_point_a) |
1811 | , req_s8s8_comp(conf_->s8s8_compensation_required) |
1812 | , src_stride(conf_->wei_tag == format_tag::adbc |
1813 | ? conf_->copy_B_wei_stride |
1814 | : conf_->K * typesize) |
1815 | , tr_src_stride(conf_->LDB * vnni_granularity * tr_typesize) {} |
1816 | |
1817 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
1818 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
1819 | |
1820 | private: |
1821 | using reg64_t = const Xbyak::Reg64; |
1822 | using reg32_t = const Xbyak::Reg32; |
1823 | using opmask_t = const Xbyak::Opmask; |
1824 | using zmm = const Xbyak::Zmm; |
1825 | |
1826 | enum { |
1827 | n_blk_step = 16, |
1828 | bytes_in_zmm = 64, |
1829 | bf32_k_blk_step = 16, |
1830 | }; |
1831 | |
1832 | const int typesize; |
1833 | const int tr_typesize; |
1834 | const int vnni_granularity; |
1835 | const int k_blk_step; |
1836 | const bool do_compute_compensation; |
1837 | const bool is_bf32; |
1838 | const bool req_zp_comp; |
1839 | const bool req_s8s8_comp; |
1840 | |
1841 | const dim_t src_stride, tr_src_stride; |
1842 | |
1843 | opmask_t k3333 = k1; |
1844 | opmask_t k5555 = k2; |
1845 | opmask_t kAAAA = k3; |
1846 | opmask_t kCCCC = k4; |
1847 | opmask_t k0F0F = k5; |
1848 | opmask_t kF0F0 = k6; |
1849 | opmask_t kTail = k7; |
1850 | |
1851 | reg64_t reg_src_base = rax; |
1852 | reg64_t reg_tr_src_base = rbx; |
1853 | reg64_t reg_comp_ptr = rdx; |
1854 | |
1855 | reg64_t reg_K_iters = r8; |
1856 | reg64_t reg_N_iters = r9; |
1857 | reg64_t reg_src = r10; |
1858 | reg64_t reg_tr_src = r11; |
1859 | reg64_t reg_zp_comp_ptr = r12; |
1860 | reg64_t reg_zp_a_neg_val_ptr = r13; |
1861 | reg64_t reg_K_start = r14; |
1862 | |
1863 | reg64_t regq_tmp = r15; |
1864 | reg32_t regw_tmp = r15d; |
1865 | reg64_t imm_addr64 = abi_not_param1; |
1866 | |
1867 | zmm zmm_zp_a_neg_val = zmm29; |
1868 | zmm zmm_comp_acc = zmm30; |
1869 | zmm zmm_comp_mul = zmm31; |
1870 | zmm zmm_s8s8_comp_acc = zmm28; |
1871 | zmm zmm_all_bits_1 = zmm27; |
1872 | zmm zmm_one_s32 = zmm26; |
1873 | |
    void kmovw(Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    }

    void kmovq(Opmask k, size_t q) {
        mov(regq_tmp, q);
        jit_generator::kmovq(k, regq_tmp);
    }
1883 | |
1884 | void copy_16x64_vnni(int nrows, int ncolumns); |
1885 | void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter, |
1886 | bool is_last_K_iter); |
1887 | void compute_N_loop( |
1888 | int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter); |
1889 | |
1890 | void generate() override; |
1891 | }; |
1892 | |
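// Transposes a block of up to 16 rows x 64 bytes of B using a butterfly
// network: valignd swaps with strides 1 and 2 plus vshuff32x4 swaps with
// stride 4 inside transpose16x8(), then vshuff64x2 merges of the two 8-row
// halves in fixup16x16(). When compensation is required, every transposed
// 64-byte output row is also fed to vpdpbusd against a vector of ones
// (zmm_comp_mul), so zmm_comp_acc gathers the per-n-lane sums of B over the
// K block.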
1893 | void jit_brgemm_matmul_copy_b_transposed_t::copy_16x64_vnni( |
1894 | int nrows, int ncolumns) { |
1895 | assert(nrows >= 0 && nrows <= n_blk_step && ncolumns >= 0 |
1896 | && ncolumns <= k_blk_step); |
1897 | if (!nrows) return; |
1898 | |
1899 | auto src_zmm = [=](int i) { |
1900 | assert(i >= 0 && i < 16); |
1901 | return Zmm(i); |
1902 | }; |
1903 | |
1904 | auto tmp_zmm = [=](int i) { |
1905 | // If compensation compute is required - last 6 zmms are reserved for it |
1906 | assert(i >= 0 && i < 16 - do_compute_compensation * 6); |
1907 | return Zmm(16 + i); |
1908 | }; |
1909 | |
1910 | const int columns_tail |
1911 | = ncolumns % (is_bf32 ? bf32_k_blk_step : k_blk_step); |
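    // Mask granularity differs by load type: one bit per element for bf32
    // (vmovups) and f16 (vcvtph2psx) loads, but one bit per byte for
    // vmovdqu8, hence the typesize factor and the 64-bit kmovq below.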
1912 | if (columns_tail > 0) { |
1913 | const int dt_step |
1914 | = (is_bf32 || conf_->isa == avx512_core_fp16) ? 1 : typesize; |
1915 | const auto tail_mask |
1916 | = size_t(((size_t)1 << dt_step * columns_tail) - 1); |
1917 | if (is_bf32) |
1918 | kmovw(kTail, tail_mask); |
1919 | else |
1920 | kmovq(kTail, tail_mask); |
1921 | } |
1922 | |
1923 | auto load_bf32 = [=](int i) { |
1924 | auto src_reg = src_zmm(i); |
1925 | auto src_reg_next = tmp_zmm(i); |
1926 | |
1927 | if (i >= nrows) { |
1928 | vpxord(src_reg, src_reg, src_reg); |
1929 | return; |
1930 | } |
1931 | |
        // check if the k tail exists and lies within the first zmm
1933 | auto zmm_src = columns_tail > 0 && ncolumns < bf32_k_blk_step |
1934 | ? src_reg | kTail | T_z |
1935 | : src_reg; |
1936 | vmovups(zmm_src, EVEX_compress_addr(reg_src, i * src_stride)); |
1937 | |
1938 | if (ncolumns <= bf32_k_blk_step) { |
1939 | vpxord(src_reg_next, src_reg_next, src_reg_next); |
1940 | } else { |
1941 | auto zmm_src_next = columns_tail > 0 ? src_reg_next | kTail | T_z |
1942 | : src_reg_next; |
1943 | vmovups(zmm_src_next, |
1944 | EVEX_compress_addr(reg_src, |
1945 | i * src_stride + bf32_k_blk_step * typesize)); |
1946 | } |
1947 | |
1948 | vcvtne2ps2bf16(src_reg, src_reg_next, src_reg); |
1949 | }; |
1950 | |
1951 | auto load = [=](int i) { |
1952 | auto src_reg = src_zmm(i); |
1953 | if (i >= nrows) { |
1954 | vpxord(src_reg, src_reg, src_reg); |
1955 | return; |
1956 | } |
1957 | |
1958 | auto src_load = columns_tail > 0 ? src_reg | kTail | T_z : src_reg; |
1959 | const auto addr = EVEX_compress_addr(reg_src, i * src_stride); |
1960 | if (conf_->isa == avx512_core_fp16) |
1961 | vcvtph2psx(src_load, addr); |
1962 | else |
1963 | vmovdqu8(src_load, addr); |
1964 | }; |
1965 | |
1966 | auto store = [=](Zmm r, int i) { |
1967 | auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); |
1968 | vmovups(addr, r); |
1969 | }; |
1970 | |
1971 | auto transpose16x8 = [=](int base_idx) { |
1972 | assert(base_idx == 0 || base_idx == 8); |
1973 | // If compensation compute is required - use tmp(0) ... tmp(7) |
1974 | // to not spoil reserved registers' values |
1975 | const int tmp_corr_idx = do_compute_compensation * base_idx; |
1976 | |
1977 | // swap 1 |
1978 | if (is_bf32) { |
1979 | for (int i = 0; i < 4; i++) { |
1980 | const int src_idx0 = base_idx + i * 2; |
1981 | const int src_idx1 = src_idx0 + 1; |
1982 | |
1983 | if (base_idx == 0 && i == 0) { |
1984 | load_bf32(src_idx0); |
1985 | load_bf32(src_idx1); |
1986 | } |
1987 | |
1988 | int next_src_idx0 = src_idx0 + 2; |
1989 | int next_src_idx1 = src_idx1 + 2; |
1990 | |
1991 | bool load_next = base_idx == 0 || i < 3; |
1992 | |
1993 | const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx); |
1994 | const auto tmp1 = tmp_zmm(src_idx1 - tmp_corr_idx); |
1995 | const auto src0 = src_zmm(src_idx0); |
1996 | const auto src1 = src_zmm(src_idx1); |
1997 | |
1998 | if (next_src_idx0 < nrows && load_next) |
1999 | load_bf32(next_src_idx0); |
2000 | valignd(tmp0, src0, src0, 0x1); |
2001 | |
2002 | if (next_src_idx1 < nrows && load_next) |
2003 | load_bf32(next_src_idx1); |
2004 | valignd(tmp1, src1, src1, 0xf); |
2005 | |
2006 | vmovaps(src0 | kAAAA, tmp1); |
2007 | vmovaps(src1 | k5555, tmp0); |
2008 | } |
2009 | } else { |
2010 | for (int i = 0; i < 4; i++) { |
2011 | const int src_idx0 = base_idx + i * 2; |
2012 | const int src_idx1 = src_idx0 + 1; |
2013 | |
2014 | int next_src_idx0 = src_idx0 + 2; |
2015 | int next_src_idx1 = src_idx1 + 2; |
2016 | bool load_next = base_idx == 0 || i < 3; |
2017 | |
2018 | if (base_idx == 0 && i == 0) { |
2019 | load(src_idx0); |
2020 | load(src_idx1); |
2021 | } |
2022 | |
2023 | const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx); |
2024 | const auto tmp1 = tmp_zmm(src_idx1 - tmp_corr_idx); |
2025 | const auto src0 = src_zmm(src_idx0); |
2026 | const auto src1 = src_zmm(src_idx1); |
2027 | |
2028 | if (next_src_idx0 < nrows && load_next) load(next_src_idx0); |
2029 | valignd(tmp0, src0, src0, 0x1); |
2030 | |
2031 | if (next_src_idx1 < nrows && load_next) load(next_src_idx1); |
2032 | valignd(tmp1, src1, src1, 0xf); |
2033 | |
2034 | vmovaps(src0 | kAAAA, tmp1); |
2035 | vmovaps(src1 | k5555, tmp0); |
2036 | } |
2037 | } |
2038 | // swap 2 |
2039 | for (int i = 0; i < 4; i++) { |
2040 | const int select_half = (i < 2) ? 0 : 2; |
2041 | const int src_idx0 = base_idx + i + select_half + 0; |
2042 | const int src_idx2 = src_idx0 + 2; |
2043 | |
2044 | const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx); |
2045 | const auto tmp1 = tmp_zmm(src_idx2 - tmp_corr_idx); |
2046 | const auto src0 = src_zmm(src_idx0); |
2047 | const auto src2 = src_zmm(src_idx2); |
2048 | |
2049 | valignd(tmp0, src0, src0, 0x2); |
2050 | valignd(tmp1, src2, src2, 0xe); |
2051 | vmovaps(src2 | k3333, tmp0); |
2052 | vmovaps(src0 | kCCCC, tmp1); |
2053 | } |
2054 | |
2055 | // swap 4 |
2056 | for (int i = 0; i < 4; i++) { |
2057 | const int src_idx0 = base_idx + i; |
2058 | const int src_idx4 = src_idx0 + 4; |
2059 | |
2060 | const auto tmp0 = tmp_zmm(src_idx0 - tmp_corr_idx); |
2061 | const auto src0 = src_zmm(src_idx0); |
2062 | const auto src4 = src_zmm(src_idx4); |
2063 | |
2064 | vmovaps(tmp0, src0); |
2065 | vshuff32x4(src0 | kF0F0, src4, src4, 0xb1); |
2066 | vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1); |
2067 | } |
2068 | }; |
2069 | |
2070 | auto fixup16x16 = [=]() { |
2071 | // swap 8 |
2072 | for (int i = 0; i < 8; i++) { |
2073 | const auto tmp = tmp_zmm(i); |
2074 | const auto src0 = src_zmm(i); |
2075 | const auto src8 = src_zmm(8 + i); |
2076 | vshuff64x2(tmp, src0, src8, 0x44); |
2077 | if (do_compute_compensation) |
2078 | vpdpbusd(zmm_comp_acc, zmm_comp_mul, tmp); |
2079 | store(tmp, i); |
2080 | } |
2081 | |
2082 | for (int i = 0; i < 8; i++) { |
            // If compensation compute is required - the last 6 zmms are
            // reserved for it, so only tmp_zmm(8) and tmp_zmm(9) are free
2084 | const auto tmp = IMPLICATION(do_compute_compensation, i < 2) |
2085 | ? tmp_zmm(8 + i) |
2086 | : src_zmm((i - 2) / 2 + (i % 2) * 8); |
2087 | const auto src0 = src_zmm(i); |
2088 | const auto src8 = src_zmm(8 + i); |
2089 | vshuff64x2(tmp, src0, src8, 0xee); |
2090 | if (do_compute_compensation) |
2091 | vpdpbusd(zmm_comp_acc, zmm_comp_mul, tmp); |
2092 | store(tmp, 8 + i); |
2093 | } |
2094 | }; |
2095 | |
2096 | transpose16x8(0); |
2097 | transpose16x8(8); |
2098 | fixup16x16(); |
2099 | } |
2100 | |
2101 | void jit_brgemm_matmul_copy_b_transposed_t::compute_K_loop(bool is_N_tail, |
2102 | int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) { |
2103 | MAYBE_UNUSED(is_first_K_iter); |
2104 | MAYBE_UNUSED(is_last_K_iter); |
2105 | const int N_chunk_tail = conf_->N % n_blk_step; |
2106 | int nrows = is_N_tail ? N_chunk_tail : n_blk_step; |
2107 | if (do_compute_compensation) |
2108 | vpxord(zmm_comp_acc, zmm_comp_acc, zmm_comp_acc); |
2109 | |
2110 | Label K_loop, K_loop_tail_or_done; |
2111 | mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); |
2112 | |
2113 | mov(reg_src, reg_src_base); |
2114 | mov(reg_tr_src, reg_tr_src_base); |
2115 | if (curr_K_tail > 0) { |
2116 | cmp(reg_K_iters, k_blk_step); |
2117 | jl(K_loop_tail_or_done, T_NEAR); |
2118 | } |
2119 | |
2120 | L(K_loop); |
2121 | copy_16x64_vnni(nrows, k_blk_step); |
2122 | add(reg_src, k_blk_step * typesize); |
2123 | add(reg_tr_src, k_blk_step / vnni_granularity * tr_src_stride); |
2124 | |
2125 | sub(reg_K_iters, k_blk_step); |
2126 | cmp(reg_K_iters, k_blk_step); |
2127 | jge(K_loop, T_NEAR); |
2128 | |
2129 | L(K_loop_tail_or_done); |
2130 | |
2131 | if (curr_K_tail > 0) copy_16x64_vnni(nrows, curr_K_tail); |
2132 | |
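    // Compensation staging: partial per-n-lane sums of B live in the
    // s8s8 / zero-point buffers across K chunks. On the last chunk the s8s8
    // term is finalized to -128 * sum(B): vpslld by 7 multiplies by 128 and
    // bitwise-NOT (vpandnq against all-ones) plus 1 is a two's-complement
    // negation; the zero-point term is multiplied by the negated zp_a value.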
2133 | if (req_s8s8_comp) { |
2134 | const auto addr = zword[reg_comp_ptr]; |
2135 | if (!is_first_K_iter) |
2136 | vpaddd(zmm_s8s8_comp_acc, zmm_comp_acc, addr); |
2137 | else |
2138 | vmovups(zmm_s8s8_comp_acc, zmm_comp_acc); |
2139 | |
2140 | if (is_last_K_iter) { |
2141 | // multiply by 128 |
2142 | vpslld(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, 7); |
2143 | // change sign |
2144 | vpandnq(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, zmm_all_bits_1); |
2145 | vpaddd(zmm_s8s8_comp_acc, zmm_s8s8_comp_acc, zmm_one_s32); |
2146 | } |
2147 | vmovups(addr, zmm_s8s8_comp_acc); |
2148 | } |
2149 | if (req_zp_comp) { |
2150 | const auto addr = zword[reg_zp_comp_ptr]; |
2151 | if (!is_first_K_iter) vpaddd(zmm_comp_acc, zmm_comp_acc, addr); |
2152 | if (is_last_K_iter) |
2153 | vpmulld(zmm_comp_acc, zmm_comp_acc, zmm_zp_a_neg_val); |
2154 | vmovups(addr, zmm_comp_acc); |
2155 | } |
2156 | } |
2157 | |
2158 | void jit_brgemm_matmul_copy_b_transposed_t::compute_N_loop( |
2159 | int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) { |
2160 | const int N_chunk_tail = conf_->N % n_blk_step; |
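    // Each 16-column N block yields 16 s32 compensation values = 64 bytes.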
2161 | const size_t comp_shift = 64; |
2162 | |
2163 | Label N_loop, N_loop_tail_or_done; |
2164 | if (N_chunk_tail > 0) { |
2165 | cmp(reg_N_iters, n_blk_step); |
2166 | jl(N_loop_tail_or_done, T_NEAR); |
2167 | } |
2168 | |
2169 | L(N_loop); |
2170 | compute_K_loop(false, curr_K_tail, is_first_K_iter, is_last_K_iter); |
2171 | add(reg_src_base, n_blk_step * src_stride); |
2172 | add(reg_tr_src_base, n_blk_step * vnni_granularity * tr_typesize); |
2173 | |
2174 | if (req_zp_comp) add(reg_zp_comp_ptr, comp_shift); |
2175 | if (req_s8s8_comp) add(reg_comp_ptr, comp_shift); |
2176 | |
2177 | sub(reg_N_iters, n_blk_step); |
2178 | cmp(reg_N_iters, n_blk_step); |
2179 | jge(N_loop, T_NEAR); |
2180 | |
2181 | L(N_loop_tail_or_done); |
2182 | if (N_chunk_tail > 0) { |
2183 | Label N_loop_done; |
2184 | cmp(reg_N_iters, 0); |
2185 | jle(N_loop_done, T_NEAR); |
2186 | |
2187 | compute_K_loop(true, curr_K_tail, is_first_K_iter, is_last_K_iter); |
2188 | L(N_loop_done); |
2189 | } |
2190 | } |
2191 | |
2192 | void jit_brgemm_matmul_copy_b_transposed_t::generate() { |
2193 | |
2194 | preamble(); |
2195 | |
2196 | mov(reg_src_base, ptr[param1 + GET_OFF(src)]); |
2197 | mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); |
2198 | mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); |
2199 | mov(reg_N_iters, ptr[param1 + GET_OFF(current_N_blk)]); |
2200 | |
2201 | kmovw(k3333, 0x3333); // 0011001100110011 |
2202 | kmovw(k5555, 0x5555); // 0101010101010101 |
2203 | kmovw(kAAAA, 0xaaaa); // 1010101010101010 |
2204 | kmovw(kCCCC, 0xcccc); // 1100110011001100 |
2205 | kmovw(k0F0F, 0x0f0f); // 0000111100001111 |
2206 | kmovw(kF0F0, 0xf0f0); // 1111000011110000 |
2207 | |
2208 | const dim_t N_chunk_elems = conf_->N_chunk_elems; |
2209 | assert(N_chunk_elems % n_blk_step == 0 || N_chunk_elems == conf_->N); |
2210 | UNUSED(N_chunk_elems); |
2211 | |
2212 | const auto K_blk_tail = nstl::min(conf_->K, conf_->K_blk) % k_blk_step; |
2213 | const auto K_tail_tail = (conf_->K % conf_->K_blk) % k_blk_step; |
2214 | |
2215 | auto compute_body = [=](bool is_first_K_iter, bool is_last_K_iter) { |
2216 | if (is_last_K_iter) { |
2217 | if (req_s8s8_comp) { |
2218 | mov(imm_addr64, 0xffffffff); |
2219 | vpbroadcastd(zmm_all_bits_1, imm_addr64.cvt32()); |
2220 | mov(imm_addr64, 0x1); |
2221 | vpbroadcastd(zmm_one_s32, imm_addr64.cvt32()); |
2222 | } |
2223 | if (req_zp_comp) { |
2224 | mov(reg_zp_a_neg_val_ptr, |
2225 | ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]); |
2226 | vbroadcastss(zmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]); |
2227 | } |
2228 | } |
2229 | |
2230 | Label compute_body_done; |
2231 | if (conf_->K_tail > 0 && K_blk_tail != K_tail_tail) { |
2232 | Label not_K_tail; |
2233 | cmp(reg_K_iters, conf_->K_blk); |
2234 | je(not_K_tail, T_NEAR); |
2235 | compute_N_loop(K_tail_tail, is_first_K_iter, is_last_K_iter); |
2236 | jmp(compute_body_done, T_NEAR); |
2237 | |
2238 | L(not_K_tail); |
2239 | } |
2240 | |
2241 | compute_N_loop(K_blk_tail, is_first_K_iter, is_last_K_iter); |
2242 | L(compute_body_done); |
2243 | }; |
2244 | |
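    // Compensation needs the position of the current K chunk within K:
    // reg_K_start == 0 marks the first chunk (initialize the accumulators
    // instead of adding to the buffers) and reg_K_start >= last_K_threshold
    // marks the last one (finalize and store the results), so four
    // specializations of compute_body are emitted.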
2245 | Label done; |
2246 | if (do_compute_compensation) { |
2247 | assert(IMPLICATION(req_zp_comp, |
2248 | conf_->src_zp_type == brgemm_broadcast_t::per_tensor)); |
2249 | |
2250 | mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]); |
2251 | if (req_s8s8_comp) |
2252 | mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]); |
2253 | if (req_zp_comp) |
2254 | mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]); |
2255 | |
2256 | mov(regq_tmp, 1); |
2257 | vpbroadcastb(zmm_comp_mul, regq_tmp.cvt8()); |
2258 | |
2259 | const auto last_K_threshold |
2260 | = rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk; |
2261 | Label not_first, not_first_not_last; |
2262 | cmp(reg_K_start, 0); |
2263 | jne(not_first, T_NEAR); |
2264 | { |
2265 | // first K iteration |
2266 | Label first_not_last; |
2267 | cmp(reg_K_start, last_K_threshold); |
2268 | jl(first_not_last, T_NEAR); |
2269 | compute_body(true, true); |
2270 | jmp(done, T_NEAR); |
2271 | |
2272 | L(first_not_last); |
2273 | compute_body(true, false); |
2274 | jmp(done, T_NEAR); |
2275 | } |
2276 | |
2277 | L(not_first); |
2278 | cmp(reg_K_start, last_K_threshold); |
2279 | jl(not_first_not_last, T_NEAR); |
2280 | |
2281 | compute_body(false, true); |
2282 | jmp(done, T_NEAR); |
2283 | L(not_first_not_last); |
2284 | } |
2285 | |
2286 | compute_body(false, false); |
2287 | L(done); |
2288 | |
2289 | postamble(); |
2290 | } |
2291 | |
2292 | status_t create_brgemm_matmul_copy_b( |
2293 | std::unique_ptr<jit_brgemm_matmul_copy_b_t> ©_ker, |
2294 | const brgemm_matmul_conf_t *conf) { |
2295 | const bool is_B_transposed |
2296 | = one_of(conf->wei_tag, ba, acb, abdc, adbc, abced, abcdfe, abcdegf, |
2297 | abcdefhg, abcdefgih, abcdefghji, abcdefghikj, abcdefghijlk); |
2298 | const bool is_bf16 |
2299 | = everyone_is(data_type::bf16, conf->src_dt, conf->wei_dt); |
2300 | const bool is_f32 = everyone_is(data_type::f32, conf->src_dt, conf->wei_dt); |
    // Note: f16 support through avx512_core_fp16 sets src_dt and wei_dt to
    // f32 to imply up-conversion. So, the assumption is that `is_f16` below
    // evaluates to `false` on avx512_core_fp16.
2304 | const bool is_f16 = everyone_is(data_type::f16, conf->src_dt, conf->wei_dt); |
2305 | if (is_B_transposed) { |
2306 | CHECK(safe_ptr_assign( |
2307 | copy_ker, new jit_brgemm_matmul_copy_b_transposed_t(conf))); |
2308 | } else { |
2309 | if (is_bf16 || is_f16 || conf->is_bf32) { |
2310 | CHECK(safe_ptr_assign( |
2311 | copy_ker, new jit_brgemm_matmul_copy_b_bf16_t(conf))); |
2312 | } else if (is_f32 || conf->isa == avx512_core_fp16) { |
2313 | CHECK(safe_ptr_assign( |
2314 | copy_ker, new jit_brgemm_matmul_copy_b_f32_t(conf))); |
2315 | } else { |
2316 | CHECK(safe_ptr_assign( |
2317 | copy_ker, new jit_brgemm_matmul_copy_b_int8_t(conf))); |
2318 | } |
2319 | } |
2320 | |
2321 | return copy_ker->create_kernel(); |
2322 | } |
2323 | |
2324 | status_t create_brgemm_matmul_copy_a( |
2325 | std::unique_ptr<jit_brgemm_matmul_copy_a_t> ©_ker, |
2326 | const brgemm_matmul_conf_t *conf) { |
2327 | if (conf->transposed_A) { |
2328 | CHECK(safe_ptr_assign(copy_ker, |
2329 | new jit_brgemm_matmul_copy_a_transposed_impl_t(conf))); |
2330 | } else { |
2331 | CHECK(safe_ptr_assign( |
2332 | copy_ker, new jit_brgemm_matmul_copy_a_impl_t(conf))); |
2333 | } |
2334 | |
2335 | return copy_ker->create_kernel(); |
2336 | } |
2337 | |
2338 | } // namespace matmul |
2339 | } // namespace x64 |
2340 | } // namespace cpu |
2341 | } // namespace impl |
2342 | } // namespace dnnl |
2343 | |