/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/brgemm/brgemm_types.hpp"
#include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"

#define GET_OFF(field) offsetof(brgemm_kernel_params_t, field)
#define GET_OFF_BATCH_ELEMENT(field) offsetof(brgemm_batch_element_t, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace Xbyak;

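// Kernel constructor: sets up the post-ops injector (for eltwise, binary,
// and sum post-ops) and the bf16 emulation helper when the brgemm descriptor
// requests it.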
template <cpu_isa_t isa, typename Wmm>
jit_brdgmm_kernel_base_t<isa, Wmm>::jit_brdgmm_kernel_base_t(
        const brgemm_t &abrd)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , brg(abrd)
    , simd_w_(vreg_traits<Vmm>::vlen / brg.typesize_C)
    , max_vmms_(isa_num_vregs(isa)) {

    if (brg.with_eltwise || brg.with_binary || brg.with_sum) {

        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const auto dst_md_wrapper = memory_desc_wrapper(brg.dst_md);
        const size_t tail = tail_length();

        static const bcast_set_t enabled_bcast_strategy
                = {broadcasting_strategy_t::scalar,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::no_broadcast};
        const binary_injector::rhs_arg_static_params_t rhs_sp {
                static_cast<size_t>(vmm_b().getIdx()), r14, r15, r13,
                preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(data_C_ptr_),
                dst_md_wrapper, tail, k_mask, use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t bsp {
                this->param1, enabled_bcast_strategy, rhs_sp};

        postops_injector_ = utils::make_unique<po_injector_t>(
                this, brg.attr->post_ops_, bsp);

        with_binary_non_scalar_bcast_
                = binary_injector::any_binary_postop_rhs_non_scalar_broadcast(
                        brg.attr->post_ops_, dst_md_wrapper);
    }
    if (brg.is_bf16_emu)
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4);
}

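// Applies the opmask k_mask to a register when mask_flag is set. Loads
// additionally use zero-masking (T_z) to clear masked-off lanes; stores use
// merge-masking so masked-off destination elements are left untouched.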
template <cpu_isa_t isa, typename Wmm>
template <typename U>
U jit_brdgmm_kernel_base_t<isa, Wmm>::maybe_mask(
        const U umm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? umm_in | k_mask : umm_in | k_mask | T_z)
                     : umm_in;
}

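// Loads the kernel arguments (brgemm_kernel_params_t) passed via param1 into
// registers, spilling the values that are needed later (batch pointer, bias,
// scales, abi param) to the reserved stack area.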
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::read_params() {
    Label label_done;

    mov(reg_BS, ptr[param1 + GET_OFF(BS)]);
    mov(reg_aux_C, ptr[param1 + GET_OFF(ptr_C)]);
    mov(reg_aux_D, ptr[param1 + GET_OFF(ptr_D)]);

    if (brg.type == brgemm_offs) {
        mov(reg_A, ptr[param1 + GET_OFF(ptr_A)]);
        mov(reg_B, ptr[param1 + GET_OFF(ptr_B)]);
    } else if (brg.type == brgemm_strd) {
        mov(reg_aux1_A, ptr[param1 + GET_OFF(ptr_A)]);
        mov(reg_aux1_B, ptr[param1 + GET_OFF(ptr_B)]);
        if (brg.brgattr.max_bs > 1) {
            mov(ptr[rsp + reg_A_offs_], reg_aux1_A);
            mov(ptr[rsp + reg_B_offs_], reg_aux1_B);
        }
    }

    if (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()) {
        mov(reg_aux_batch_addr, ptr[param1 + GET_OFF(batch)]);
        if (brg.brgattr.max_bs > 1)
            mov(ptr[rsp + reg_batch0_addr_offs_], reg_aux_batch_addr);
    }

    if (brg.with_bias) {
        mov(reg_tmp, ptr[param1 + GET_OFF(ptr_bias)]);
        mov(ptr[rsp + reg_bias_offs_], reg_tmp);
    }

    if (brg.with_scales) {
        mov(reg_tmp, ptr[param1 + GET_OFF(ptr_scales)]);
        mov(ptr[rsp + reg_scales_offs_], reg_tmp);
    }

    if (brg.with_binary) mov(ptr[rsp + abi_param1_offs_], param1);
}

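// Zero-initializes all accumulator registers.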
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_accumulators(
        int m_blocks, int n_blocks) {
    const int v_substep = vnni_substep();
    for_(int v = 0; v < v_substep; ++v)
    for_(int m = 0; m < m_blocks; ++m)
    for (int n = 0; n < n_blocks; ++n) {
        auto vmm = accm(m_blocks, n_blocks, m, n, v);
        uni_vpxor(vmm, vmm, vmm);
    }
}

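// Restores the batch pointer and the A/B base pointers saved on the stack
// before iterating over the batch again.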
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::restore_A_B_matrices() {
    if (brg.brgattr.max_bs > 1
            && (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()))
        mov(reg_aux_batch_addr, ptr[rsp + reg_batch0_addr_offs_]);

    if (brg.type == brgemm_strd && brg.brgattr.max_bs > 1) {
        mov(reg_aux1_A, ptr[rsp + reg_A_offs_]);
        mov(reg_aux1_B, ptr[rsp + reg_B_offs_]);
    }
}

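// Computes reg_aux_A/reg_aux_B for the current batch element depending on
// the addressing mode: brgemm_addr reads the pointers from the batch
// element, brgemm_offs adds per-element offsets to the base pointers, and
// brgemm_strd advances the pointers by fixed strides.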
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::set_A_B_matrices() {

    if (brg.type == brgemm_addr) {
        mov(reg_aux_A, ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(ptr.A)]);
        mov(reg_aux_B, ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(ptr.B)]);
    } else if (brg.type == brgemm_offs) {
        mov(reg_aux_A, reg_A);
        mov(reg_aux_B, reg_B);
        add(reg_aux_A,
                ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(offset.A)]);
        add(reg_aux_B,
                ptr[reg_aux_batch_addr + GET_OFF_BATCH_ELEMENT(offset.B)]);
    } else if (brg.type == brgemm_strd) {
        mov(reg_aux_A, reg_aux1_A);
        mov(reg_aux_B, reg_aux1_B);
        if (brg.brgattr.max_bs > 1) {
            safe_add(reg_aux1_A, brg.stride_a, reg_tmp);
            safe_add(reg_aux1_B, brg.stride_b, reg_tmp);
        }
    }

    add(reg_aux_A, reg_a_offset);
    lea(reg_aux_B, ptr[reg_aux_B + reg_aux_N * brg.typesize_B]);
}

template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::advance_A_B_matrices() {
    if (brg.brgattr.max_bs > 1
            && (one_of(brg.type, brgemm_addr, brgemm_offs) || has_vpad()))
        add(reg_aux_batch_addr, sizeof(brgemm_batch_element_t));
}

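// Loads `op` into vmm_in and converts it to f32. Tail loads use the opmask
// when the ISA supports masking, and otherwise fall back to the emulated
// load_data() path after zeroing the register.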
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag,
        bool store) {
    const int tail_size = tail_length();
    const bool is_load_tail = op.isMEM() && mask_flag && tail_size > 0
            && (tail_size
                    < static_cast<int>(vreg_traits<Vmm>::vlen / sizeof(float)));
    if (IMPLICATION(is_load_tail, isa_has_masks(brg.isa_impl))) {
        const Vmm vmm = maybe_mask(vmm_in, is_load_tail, store);
        switch (type_in) {
            case data_type::f32:
            case data_type::s32: vmovups(vmm, op); break;
            case data_type::bf16:
                vpmovzxwd(vmm, op);
                vpslld(vmm, vmm, 16);
                break;
            case data_type::f16: vcvtph2ps(vmm, op); break;
            case data_type::s8: vpmovsxbd(vmm, op); break;
            case data_type::u8: vpmovzxbd(vmm, op); break;
            default: assert(!"unsupported data type");
        }
    } else {
        uni_vpxor(vmm_in, vmm_in, vmm_in);
        load_data(type_in, vmm_in, op.getAddress(), tail_size);
    }
    if (types::is_integral_dt(type_in)) vcvtdq2ps(vmm_in, vmm_in);
}

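// Applies the attribute post-ops (eltwise/binary/sum) to the accumulators.
// The sum post-op is implemented as a lambda injected into the post-ops
// injector: it reloads the previous destination values, optionally shifts
// them by the sum zero-point, and scales them before accumulation.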
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::apply_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    injector_utils::vmm_index_set_t vmm_idxs_param;
    const int v_substep = vnni_substep();

    // Collect the vmm indices to apply post-ops to.
    // In case of avx2_vnni_2 tails, the last vnni_substep may not need
    // post-ops applied at all.
    for_(int v_i = 0; v_i < v_substep; ++v_i)
    for_(int m_i = 0; m_i < m_blocks; ++m_i)
    for (int n_i = 0; n_i < n_blocks; ++n_i) {
        if (get_substep_simd(n_i, v_i, has_n_tail) <= 0) continue;
        const auto vmm_idx = accm(m_blocks, n_blocks, m_i, n_i, v_i).getIdx();
        vmm_idxs_param.insert(vmm_idx);
    }

    if (brg.with_binary) {
        mov(reg_binary_params, ptr[rsp + abi_param1_offs_]);

        if (with_binary_non_scalar_bcast_) {

            for_(int v_i = 0; v_i < v_substep; ++v_i)
            for_(int m_i = 0; m_i < m_blocks; m_i++)
            for (int n_i = 0; n_i < n_blocks; n_i++) {
                const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
                if (substep_simd <= 0) continue;
                const auto vmm_idx
                        = accm(m_blocks, n_blocks, m_i, n_i, v_i).getIdx();
                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_aux_D);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, D_offset(m_i, n_i, v_i));

                if (n_i + 1 == n_blocks && has_n_tail && substep_simd < simd_w_)
                    rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }

    const auto sum_injector = [&] {
        const float *p_sum_scale = &brg.sum_scale;
        const int32_t *p_sum_zp = &brg.sum_zp;
        const bool p_sum_scale_reg_set = *p_sum_scale != 1.f;
        const bool p_sum_zp_reg_set = *p_sum_zp != 0;

        const injector_utils::conditional_register_preserve_guard_t
                register_guard_sum_scale(
                        (with_binary_non_scalar_bcast_) && p_sum_scale_reg_set,
                        this, {reg_ptr_sum_scale});
        const injector_utils::conditional_register_preserve_guard_t
                register_guard_sum_zp(p_sum_zp_reg_set, this, {reg_ptr_sum_zp});

        if (p_sum_scale_reg_set)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));

        auto vmm_sum_zp = vmm_tmp(0);
        if (p_sum_zp_reg_set) {
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
            vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
        }

        for_(int m_i = 0; m_i < m_blocks; m_i++)
        for_(int n_i = 0; n_i < n_blocks; n_i++)
        for (int v_i = 0; v_i < v_substep; v_i++) {
            const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const auto vmm = accm(m_blocks, n_blocks, m_i, n_i, v_i);
            const auto addr = ptr[reg_aux_D + D_offset(m_i, n_i, v_i)];
            const auto vmm_prev_dst = vmm_tmp(1);
            cvt2ps(brg.sum_dt, vmm_prev_dst, addr, substep_simd != simd_w_,
                    false);
            if (p_sum_zp_reg_set) vsubps(vmm_prev_dst, vmm_sum_zp);
            if (!p_sum_scale_reg_set)
                vaddps(vmm, vmm_prev_dst);
            else {
                if (is_superset(brg.isa_impl, avx512_core)) {
                    vfmadd231ps(vmm, vmm_prev_dst, ptr_b[reg_ptr_sum_scale]);
                } else {
                    auto vmm_scale = vmm_tmp(2);
                    uni_vpbroadcastd(vmm_scale, ptr[reg_ptr_sum_scale]);
                    uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_scale);
                }
            }
        }
    };

    if (brg.with_sum) {
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }

    postops_injector_->compute_vector_range(vmm_idxs_param, rhs_arg_params);
}

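// Stores the accumulators to D, applying scales, bias, post-ops, saturation,
// and the final conversion to the destination data type on the way.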
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators_apply_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    const bool dq2ps_required = brg.is_int8;
    const int v_substep = vnni_substep();
    if (brg.with_scales) {
        mov(reg_aux_scales, ptr[rsp + reg_scales_offs_]);
        if (brg.is_oc_scale) {
            lea(reg_aux_scales,
                    ptr[reg_aux_scales + reg_aux_N * sizeof(float)]);
        }
        for_(int m = 0; m < m_blocks; m++)
        for_(int n = 0; n < n_blocks; n++)
        for (int v_i = 0; v_i < v_substep; ++v_i) {
            const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const bool mask_flag = substep_simd < simd_w_;
            const Vmm vmm = maybe_mask(
                    accm(m_blocks, n_blocks, m, n, v_i), mask_flag, false);
            if (dq2ps_required) vcvtdq2ps(vmm, vmm);
            if (IMPLICATION(mask_flag || !brg.is_oc_scale,
                        is_superset(brg.isa_impl, avx512_core))) {
                if (brg.is_oc_scale) {
                    vmulps(vmm, vmm,
                            ptr[reg_aux_scales + scales_offset(n, v_i)]);
                } else {
                    vmulps(vmm, vmm, ptr_b[reg_aux_scales]);
                }
            } else {
                auto vmm_scale = vmm_tmp(0);
                const auto addr = ptr[reg_aux_scales + scales_offset(n, v_i)];
                if (brg.is_oc_scale) {
                    uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
                    load_data(data_type::f32, vmm_scale, addr, substep_simd);
                } else {
                    vbroadcastss(vmm_scale, ptr[reg_aux_scales]);
                }
                vmulps(vmm, vmm, vmm_scale);
            }
        }
    }

    if (brg.with_bias) {
        mov(reg_aux_bias, ptr[rsp + reg_bias_offs_]);
        lea(reg_aux_bias, ptr[reg_aux_bias + reg_aux_N * brg.typesize_bias]);
    }

    for_(int v_i = 0; v_i < v_substep; ++v_i)
    for (int n = 0; n < n_blocks; n++) {
        auto vmm_bias = vmm_tmp(0);
        const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
        if (substep_simd <= 0) continue;
        if (brg.with_bias) {
            auto ptr_bias = ptr[reg_aux_bias + bias_offset(n, v_i)];
            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, substep_simd != simd_w_,
                    false);
        }
        for (int m = 0; m < m_blocks; m++) {
            auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
            if (dq2ps_required && !brg.with_scales) vcvtdq2ps(vmm, vmm);
            if (brg.with_bias) { vaddps(vmm, vmm, vmm_bias); }
        }
    }

    if (postops_injector_) apply_post_ops(m_blocks, n_blocks, has_n_tail);

    const bool dt_requires_saturation
            = one_of(brg.dt_d, data_type::u8, data_type::s8, data_type::s32);
    auto vmm_lbound = vmm_tmp(0);
    auto vmm_ubound = vmm_tmp(1);
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d);
    }

    if (brg.is_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

    for (int m = 0; m < m_blocks; m++) {
        if (dt_requires_saturation) {
            for_(int n = 0; n < n_blocks; n++)
            for (int v_i = 0; v_i < v_substep; ++v_i) {
                if (get_substep_simd(n, v_i, has_n_tail) <= 0) continue;
                auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
                saturate_f32(vmm, vmm_lbound, vmm_ubound, brg.dt_d);
                vcvtps2dq(vmm, vmm);
            }
        }

        for_(int n = 0; n < n_blocks; n++)
        for (int v_i = 0; v_i < v_substep; ++v_i) {
            const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
            if (substep_simd <= 0) continue;
            const auto offset = D_offset(m, n, v_i);
            auto addr = ptr[reg_aux_D + offset];
            auto vmm = accm(m_blocks, n_blocks, m, n, v_i);
            auto vmm_low = Vmm_low_t(vmm.getIdx());
            const bool mask_flag = substep_simd < simd_w_;
            const Vmm r_vmm = maybe_mask(vmm, mask_flag, true);
            const Vmm_low_t r_vmm_low = maybe_mask(vmm_low, mask_flag, true);
            if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
                switch (brg.dt_d) {
                    case data_type::f32:
                    case data_type::s32: vmovups(addr, r_vmm); break;
                    case data_type::bf16:
                        if (brg.is_bf16_emu)
                            bf16_emu_->vcvtneps2bf16(vmm_low, vmm);
                        else
                            vcvtneps2bf16(vmm_low, vmm,
                                    brg.isa_impl == avx2_vnni_2
                                            ? Xbyak::VexEncoding
                                            : Xbyak::EvexEncoding);
                        if (mask_flag)
                            vmovdqu16(addr, r_vmm_low);
                        else
                            vmovups(addr, r_vmm_low);
                        break;
                    case data_type::f16:
                        vcvtps2ph(addr, r_vmm, _op_mxcsr);
                        break;
                    case data_type::s8: vpmovsdb(addr, r_vmm); break;
                    case data_type::u8: vpmovusdb(addr, r_vmm); break;
                    default: assert(!"unknown dst_dt");
                }
            } else {
                store_data(brg.dt_d, vmm, reg_aux_D, offset, substep_simd);
            }
        }
    }
}

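// Stores the raw accumulators to C; saturation is needed only when the
// intermediate type is int8-based and C is not s32.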
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators_without_post_ops(
        int m_blocks, int n_blocks, bool has_n_tail) {

    const bool dt_requires_saturation
            = brg.is_int8 && brg.dt_c != data_type::s32;
    auto vmm_lbound = vmm_tmp(0);
    auto vmm_ubound = vmm_tmp(1);
    if (dt_requires_saturation) {
        init_saturate_f32(
                vmm_lbound, vmm_ubound, reg_tmp, data_type::f32, brg.dt_d);
    }

    for_(int m = 0; m < m_blocks; m++)
    for_(int n = 0; n < n_blocks; n++)
    for (int v_i = 0; v_i < vnni_substep(); ++v_i) {
        const int substep_simd = get_substep_simd(n, v_i, has_n_tail);
        if (substep_simd <= 0) continue;
        const bool mask_flag = substep_simd < simd_w_;
        auto vmm_acc = accm(m_blocks, n_blocks, m, n, v_i);
        if (dt_requires_saturation) {
            saturate_f32(vmm_acc, vmm_lbound, vmm_ubound, brg.dt_d);
            vcvtps2dq(vmm_acc, vmm_acc);
        }
        const auto offset = C_offset(m, n, v_i);
        if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
            auto vmm_acc_masked = maybe_mask(vmm_acc, mask_flag, true);
            vmovups(ptr[reg_aux_C + offset], vmm_acc_masked);
        } else {
            store_data(brg.dt_c, vmm_acc, reg_aux_C, offset, substep_simd);
        }
    }
}

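// With avx2_vnni_2 (vnni_substep() == 2), the even/odd convert instructions
// (vcvtnee*/vcvtneo*) leave the two substep accumulators holding interleaved
// even and odd elements. Un-interleave them back to plain layout using a
// dword unpack followed by a 128-bit lane permute.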
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa,
        Wmm>::maybe_transpose_interleaved_vnni_to_plain(int m_blocks,
        int n_blocks, bool has_n_tail) {

    if (vnni_substep() == 1) return;

    // The tail block is always processed as plain.
    // No need to transpose it here.
    const int n_blocks_e = n_blocks - has_n_tail;

    auto ymm_aux0 = vmm_tmp(0);
    for_(int m_i = 0; m_i < m_blocks; m_i++)
    for (int n_i = 0; n_i < n_blocks_e; n_i++) {
        auto ymm_even = accm(m_blocks, n_blocks, m_i, n_i, 0);
        auto ymm_odd = accm(m_blocks, n_blocks, m_i, n_i, 1);
        // reusing ymm_odd as aux
        // TODO: Check for any latency due to register dependency
        auto ymm_aux1 = ymm_odd;
        vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
        vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
        vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
        vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
    }
}

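// Final store: un-interleaves the vnni substeps, applies the fast-vnni int8
// permute (reloading vmm_permute here when bf16 emulation is enabled, since
// generate() does not keep it live in that case), then dispatches to the
// post-ops or the plain store path.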
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::store_accumulators(
        int m_blocks, int n_blocks, bool has_n_tail) {

    maybe_transpose_interleaved_vnni_to_plain(m_blocks, n_blocks, has_n_tail);

    if (is_fast_vnni_int8() && brg.is_bf16_emu) {
        // load permute indices from data section
        mov(reg_tmp, permute_index_table);
        vmovdqu32(vmm_permute(), ptr[reg_tmp]);
    }

    if (is_fast_vnni_int8()) {
        for_(int m_i = 0; m_i < m_blocks; ++m_i)
        for (int n_i = 0; n_i < n_blocks; ++n_i) {
            auto vmm_out = accm(m_blocks, n_blocks, m_i, n_i, 0);
            vpermd(vmm_out, vmm_permute(), vmm_out);
        }
    }

    const bool are_post_ops_applicable
            = one_of(true, brg.with_eltwise, brg.with_binary, brg.with_scales,
                    brg.with_bias, brg.with_sum, brg.dt_d != brg.dt_c);

    Label label_done;
    if (are_post_ops_applicable) {
        store_accumulators_apply_post_ops(m_blocks, n_blocks, has_n_tail);
    } else {
        store_accumulators_without_post_ops(m_blocks, n_blocks, has_n_tail);
    }
}

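// Loads one simd block of the A (src) matrix into vmma, converting it to the
// compute type: f32 is loaded directly, bf16/f16 are upconverted (using the
// avx2_vnni_2 even/odd convert forms where applicable), and int8 is widened
// to dword; the fast-vnni int8 path broadcasts 128-bit groups instead.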
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_a(
        Vmm vmma, int m_i, int n_i, int v_i, bool has_n_tail) {
    const int n_blocks
            = has_n_tail && n_block2_tail() > 0 ? n_block2_tail() : n_block2();
    const int substep_simd = get_substep_simd(n_i, v_i, has_n_tail);
    const bool is_tail_block = has_n_tail && n_i + 1 == n_blocks;
    const bool mask_flag = substep_simd < simd_w_;
    const auto addr = ptr[reg_aux_A + A_offset(m_i, n_i)
            + is_tail_block * v_i * simd_w_ * brg.typesize_A];
    if (IMPLICATION(mask_flag, isa_has_masks(brg.isa_impl))) {
        vmma = maybe_mask(vmma, mask_flag, false);
        if (brg.is_f32) {
            vmovups(vmma, addr);
        } else if (brg.is_bf16) {
            if (brg.isa_impl == avx2_vnni_2) {
                if (is_tail_block) {
                    vpmovzxwd(vmma, addr);
                    vpslld(vmma, vmma, 16);
                } else if (v_i == 0)
                    vcvtneebf162ps(vmma, addr);
                else
                    vcvtneobf162ps(vmma, addr);
            } else {
                vpmovzxwd(vmma, addr);
                if (brg.is_bf16_tmm) vpslld(vmma, vmma, 16);
            }
        } else if (brg.is_f16) {
            if (brg.isa_impl == avx2_vnni_2) {
                if (is_tail_block)
                    vcvtph2ps(vmma, addr);
                else if (v_i == 0)
                    vcvtneeph2ps(vmma, addr);
                else
                    vcvtneoph2ps(vmma, addr);
            } else
                vcvtph2ps(vmma, addr);
        } else if (brg.is_int8) {
            if (is_fast_vnni_int8()) {
                assert(!mask_flag);
                vbroadcasti32x4(vmma, addr);
            } else
                vpmovzxbd(vmma, addr);
        }
    } else {
        uni_vpxor(vmma, vmma, vmma);
        load_data(brg.dt_a, vmma, addr, substep_simd);
    }
}

template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::load_b(
        Vmm vmmb, int n_i, int v_i, bool has_n_tail) {
    // For the B matrix we assume memory is padded, so it is safe to load full
    // simd elements. is_tail_block is only used for the avx_ne_convert tail
    // optimization.
    const int n_blocks
            = has_n_tail && n_block2_tail() > 0 ? n_block2_tail() : n_block2();
    const bool is_tail_block = has_n_tail && (n_i + 1 == n_blocks);
    const auto addr = ptr[reg_aux_B + B_offset(n_i)
            + is_tail_block * v_i * simd_w_ * brg.typesize_B];
    if (brg.is_f32) {
        vmovups(vmmb, addr);
    } else if (brg.is_int8) {
        // wei is sign-extended (s8), whereas src is zero-extended (u8).
        if (is_fast_vnni_int8()) {
            vbroadcasti32x4(vmmb, addr);
            vmovdqu8(vmmb | kblend_mask | T_z, vmmb);
        } else {
            vpmovsxbd(vmmb, addr);
        }
    } else if (brg.is_f16) {
        if (brg.isa_impl == avx2_vnni_2) {
            if (is_tail_block)
                vcvtph2ps(vmmb, addr);
            else if (v_i == 0)
                vcvtneeph2ps(vmmb, addr);
            else
                vcvtneoph2ps(vmmb, addr);
        } else
            vcvtph2ps(vmmb, addr);
    } else if (brg.is_bf16) {
        if (brg.isa_impl == avx2_vnni_2) {
            if (is_tail_block) {
                vpmovzxwd(vmmb, addr);
                vpslld(vmmb, vmmb, 16);
            } else if (v_i == 0)
                vcvtneebf162ps(vmmb, addr);
            else
                vcvtneobf162ps(vmmb, addr);
        } else {
            vpmovzxwd(vmmb, addr);
            if (brg.is_bf16_tmm) vpslld(vmmb, vmmb, 16);
        }
    }
}

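// The core update: for every (m, n) block, accumulates the elementwise
// product of A and B (C += A * B) across the batch dimension, preloading as
// many B registers as are available. With virtual padding, a jump table
// indexed by the top padding skips the padded leading rows, and a
// bottom-padding check exits early.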
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::brdgmm_microkernel(int m_blocks,
        int n_blocks, bool has_top_padding, bool has_bottom_padding,
        bool has_tail) {

    const bool has_padding = has_top_padding || has_bottom_padding;
    const int max_bvmms
            = accm(m_blocks, n_blocks, 0, 0, 0).getIdx() - vmm_b(0).getIdx();
    const int v_substep = vnni_substep();

    auto dot_product = [&](Vmm vmma, Vmm vmmb, int m_i, int n_i, int v_i) {
        auto vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, v_i);
        if (brg.is_f32) {
            if (is_fma_embd()) {
                const bool mask_flag = has_tail && (n_i + 1 == n_blocks);
                const auto addr = ptr[reg_aux_A + A_offset(m_i, n_i)];
                vmm_acc = maybe_mask(vmm_acc, mask_flag, false);
                vfmadd231ps(vmm_acc, vmmb, addr);
            } else {
                vfmadd231ps(vmm_acc, vmma, vmmb);
            }
        } else if (brg.is_bf16) {
            if (brg.is_bf16_tmm /* don't use vdpbf16ps on CPUs supporting AMX
                                   due to poor perf. */
                    || brg.isa_impl == avx2_vnni_2)
                vfmadd231ps(vmm_acc, vmma, vmmb);
            else
                vdpbf16ps(vmm_acc, vmma, vmmb);
        } else if (brg.is_f16) {
            vfmadd231ps(vmm_acc, vmma, vmmb);
        } else if (brg.is_int8) {
            vpdpbusd(vmm_acc, vmma, vmmb);
        }
    };

    if (!has_padding) {
        // preload vmm_b if possible.
        for_(int v_i = 0; v_i < v_substep; ++v_i)
        for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) {
            const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i;
            for (int i = 0; i < n_e; ++i) {
                const int n_i = nb_i + i;
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                load_b(vmm_b(i), n_i, v_i, has_tail);
            }
            for_(int m_i = 0; m_i < m_blocks; ++m_i)
            for (int i = 0; i < n_e; ++i) {
                const int n_i = nb_i + i;
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                if (!is_fma_embd()) load_a(vmm_a(), m_i, n_i, v_i, has_tail);
                dot_product(vmm_a(), vmm_b(i), m_i, n_i, v_i);
            }
        }
    } else {
        const int max_req_preload_vmms = n_blocks * vnni_substep();
        const int n_preload_b_vmms = max_bvmms >= max_req_preload_vmms
                ? max_req_preload_vmms
                : max_bvmms - 1 /*for ad-hoc load*/;
        for (int i = 0; i < n_preload_b_vmms; ++i) {
            const int n_i = i % n_blocks;
            const int v_i = i / n_blocks;
            if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
            load_b(vmm_b(i), n_i, v_i, has_tail);
        }

        Label done;
        Label jmp_table_base;
        std::vector<Label> jmp_table_labels(m_blocks);
        if (has_top_padding) {
            // jmp table
            mov(reg_table_base, jmp_table_base);
            lea(reg_table_base,
                    ptr[reg_table_base + reg_aux_A_vpad_top * sizeof(void *)]);
            jmp(ptr[reg_table_base]);
            align(8);
            L(jmp_table_base);
            for (int m_i = 0; m_i < m_blocks; ++m_i) {
                putL(jmp_table_labels[m_i]);
            }
        }

        for (int m_i = 0; m_i < m_blocks; ++m_i) {
            L(jmp_table_labels[m_i]);
            if (has_bottom_padding) {
                cmp(reg_aux_A_vpad_bottom, m_blocks - m_i);
                jge(done, T_NEAR);
            }

            for_(int v_i = 0, p_b_i = 0; v_i < v_substep; ++v_i)
            for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) {
                if (get_substep_simd(n_i, v_i, has_tail) <= 0) continue;
                if (!is_fma_embd()) load_a(vmm_a(), m_i, n_i, v_i, has_tail);
                if (p_b_i < n_preload_b_vmms) {
                    dot_product(vmm_a(), vmm_b(p_b_i), m_i, n_i, v_i);
                } else {
                    // preloaded vmm_b not available
                    const int b_idx = max_bvmms - 1;
                    load_b(vmm_b(b_idx), n_i, v_i, has_tail);
                    dot_product(vmm_a(), vmm_b(b_idx), m_i, n_i, v_i);
                }
            }
        }
        L(done);
    }
}

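// Iterates over the BS batch elements: initializes the accumulators, then
// for each element resolves the A/B pointers and the effective virtual
// padding, runs the microkernel, and finally stores the accumulators.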
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::batch_loop(
        const int m_blocks, const int n_blocks, bool has_n_tail) {

    auto get_padding_info = [&]() {
        const bool do_check_effective_padding = check_effective_padding();
        if (has_vpad()) {
            Label no_top_padding;

            if (brg.brgattr.max_bottom_vpad > 0) {
                if (do_check_effective_padding) {
                    Label done_adjust_bottom_padding;
                    mov(reg_aux_A_vpad_bottom, reg_aux_M);
                    add(reg_aux_A_vpad_bottom, m_blocks - M());
                    add(reg_aux_A_vpad_bottom,
                            ptr[reg_aux_batch_addr
                                    + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
                    jge(done_adjust_bottom_padding, T_NEAR);
                    xor_(reg_aux_A_vpad_bottom, reg_aux_A_vpad_bottom);
                    L(done_adjust_bottom_padding);
                } else {
                    mov(reg_aux_A_vpad_bottom,
                            ptr[reg_aux_batch_addr
                                    + GET_OFF_BATCH_ELEMENT(vvpad.bottom)]);
                }
                mov(reg_total_padding, reg_aux_A_vpad_bottom);
            }
            if (brg.brgattr.max_top_vpad > 0) {
                mov(reg_aux_A_vpad_top,
                        ptr[reg_aux_batch_addr
                                + GET_OFF_BATCH_ELEMENT(vvpad.top)]);
                if (do_check_effective_padding) {
                    Label done_adjust_top_padding;
                    sub(reg_aux_A_vpad_top, reg_aux_M);
                    jge(done_adjust_top_padding, T_NEAR);
                    xor_(reg_aux_A_vpad_top, reg_aux_A_vpad_top);
                    L(done_adjust_top_padding);
                }
                if (brg.brgattr.max_bottom_vpad > 0) {
                    add(reg_total_padding, reg_aux_A_vpad_top);
                } else {
                    mov(reg_total_padding, reg_aux_A_vpad_top);
                }
            }
        }
    };

    auto call_brdgmm_microkernel = [&]() {
        const int tpad = brg.brgattr.max_top_vpad;
        const int bpad = brg.brgattr.max_bottom_vpad;
        const bool vpad_exists = has_vpad();
        Label microkernel_with_padding, done_microkernel;

        if (vpad_exists) {
            cmp(reg_total_padding, 0);
            jg(microkernel_with_padding, T_NEAR);
        }
        brdgmm_microkernel(m_blocks, n_blocks, false, false, has_n_tail);
        if (vpad_exists) {
            jmp(done_microkernel, T_NEAR);
            L(microkernel_with_padding);
            if ((tpad + bpad) >= m_blocks) {
                cmp(reg_total_padding, m_blocks);
                jge(done_microkernel, T_NEAR);
            }
            brdgmm_microkernel(m_blocks, n_blocks, tpad, bpad, has_n_tail);
        }
        L(done_microkernel);
    };

    Label bs_loop_label, done_bs_loop;
    load_accumulators(m_blocks, n_blocks);
    cmp(reg_BS, 0);
    jle(done_bs_loop, T_NEAR);
    mov(reg_BS_loop, reg_BS);
    restore_A_B_matrices();

    L(bs_loop_label);
    {
        set_A_B_matrices();
        get_padding_info();
        advance_A_B_matrices();
        call_brdgmm_microkernel();
        dec(reg_BS_loop);
        jg(bs_loop_label, T_NEAR);
    }

    L(done_bs_loop);

    store_accumulators(m_blocks, n_blocks, has_n_tail);
}

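// Outer blocking: loops over m_block2 x n_block2 tiles, handling the n tail
// either through opmasks (updated inside the n loop) or through a separate
// trailing block on ISAs without masking support.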
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::compute_loop() {

    const bool has_m_block2_tail = m_block2_tail() > 0;
    const int loop_m = (nb_m_block2() - has_m_block2_tail);
    const bool do_loop_m = loop_m > 1;

    const bool has_n_block2_tail = n_block2_tail() > 0;
    const bool need_separate_n_block1_tail_block = n_block1_tail() != 0
            && !has_n_block2_tail && nb_n_block2() > 1
            && !isa_has_masks(brg.isa_impl);
    const int loop_n = nb_n_block2() - has_n_block2_tail
            - need_separate_n_block1_tail_block;
    const bool do_loop_n = loop_n > 1;
    const bool loop_n_update_aux_ptrs = do_loop_n || (loop_n < nb_n_block2());

    auto n_loop = [&](int m_blocks) {
        Label n_loop_label;
        const int n_blocks = n_block2();
        const int n_loop_step = oc_logical_offset(n_blocks);
        const int n_loop_work = loop_n * n_blocks * n_block1();
        const bool vlen_tail_in_loop = n_block1_tail() != 0
                && !need_separate_n_block1_tail_block && !has_n_block2_tail;

        xor_(reg_aux_N, reg_aux_N);

        L(n_loop_label);
        {
            if (do_loop_n) {
                if (vlen_tail_in_loop) {
                    Label done_k_mask;
                    cmp(reg_aux_N, n_loop_work - n_loop_step);
                    jl(done_k_mask, T_NEAR);
                    kmovd(k_mask, k_tail_mask);
                    L(done_k_mask);
                }
            }

            batch_loop(m_blocks, n_blocks, vlen_tail_in_loop);

            if (loop_n_update_aux_ptrs) {
                add(reg_aux_N, n_loop_step);
                add(reg_a_offset, n_loop_step * brg.typesize_A);
                add(reg_aux_C, n_loop_step * brg.typesize_C);
                add(reg_aux_D, n_loop_step * brg.typesize_D);
            }

            if (do_loop_n) {
                cmp(reg_aux_N, n_loop_work);
                jl(n_loop_label, T_NEAR);
            }
        }

        if (need_separate_n_block1_tail_block)
            batch_loop(m_blocks, n_blocks, true);

        if (has_n_block2_tail) {
            batch_loop(m_blocks, n_block2_tail(), n_block1_tail() != 0);
        }
    };

    auto m_loop = [&]() {
        Label m_loop_label;
        const int m_blocks = m_block2();
        const bool reset_mask = isa_has_masks(brg.isa_impl)
                && n_block1_tail() != 0 && do_loop_n && !has_n_block2_tail;

        xor_(reg_aux_M, reg_aux_M);
        xor_(reg_a_offset, reg_a_offset);

        L(m_loop_label);
        {
            if (reset_mask) kxnorq(k_mask, k_mask, k_mask);
            n_loop(m_blocks);

            if (do_loop_m || has_m_block2_tail) {
                add(reg_aux_M, m_blocks);
                const int n_loop_offset
                        = loop_n_update_aux_ptrs * loop_n * n_block2();
                add(reg_a_offset, A_offset(m_blocks, -n_loop_offset));
                add(reg_aux_C, C_offset(m_blocks, -n_loop_offset, 0));
                add(reg_aux_D, D_offset(m_blocks, -n_loop_offset, 0));
            }

            if (do_loop_m) {
                cmp(reg_aux_M, loop_m * m_block2());
                jl(m_loop_label, T_NEAR);
            }
        }

        if (m_block2_tail() > 0) {
            if (reset_mask) { kxnorq(k_mask, k_mask, k_mask); }
            n_loop(m_block2_tail());
        }
    };

    assert(m_block1_tail() == 0);
    m_loop();
}

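// Sets up the opmasks: the blend mask for the fast-vnni int8 path and the
// n-tail load/store mask (kept in k_tail_mask when it must be applied
// selectively inside the n loop).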
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::init_masks() {
    if (!isa_has_masks(brg.isa_impl)) return;

    if (is_fast_vnni_int8()) {
        mov(reg_tmp, 0x8888444422221111);
        kmovq(kblend_mask, reg_tmp);
    }

    if (n_block1_tail() != 0) {
        const auto tail_mask = size_t((1 << n_block1_tail()) - 1);
        const bool has_n_block2_tail = n_block2_tail() > 0;
        mov(reg_tmp, tail_mask);
        if (has_n_block2_tail || nb_n_block2() <= 1) {
            // The mask can be set only once.
            kmovq(k_mask, reg_tmp);
        } else {
            // The mask needs adjusting and must be set only when needed,
            // so store it temporarily in k_tail_mask.
            kmovq(k_tail_mask, reg_tmp);
        }
    } else if (brg.with_binary) {
        // The post-ops injector seems to use the mask unconditionally,
        // so set a default mask.
        kxnorq(k_mask, k_mask, k_mask);
    }
}

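// Kernel entry point: emits the preamble, reserves the stack area, runs the
// compute loop, and emits the post-ops table plus the permute index table
// (for the fast-vnni int8 path) in the data section.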
template <cpu_isa_t isa, typename Wmm>
void jit_brdgmm_kernel_base_t<isa, Wmm>::generate() {

    preamble();
    sub(rsp, stack_space_needed_);

    init_masks();

    if (is_fast_vnni_int8() && !brg.is_bf16_emu) {
        // load permute indices from data section
        mov(reg_tmp, permute_index_table);
        vmovdqu32(vmm_permute(), ptr[reg_tmp]);
    }

    read_params();
    compute_loop();

    add(rsp, stack_space_needed_);
    postamble();

    if (brg.with_eltwise) postops_injector_->prepare_table();

    if (is_fast_vnni_int8()) {
        align(64);
        L(permute_index_table);
        const uint32_t _idx[]
                = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dd(_idx[i]);
    }
}

template <cpu_isa_t isa, typename Wmm>
brdgmm_kernel_t<isa, Wmm>::brdgmm_kernel_t(const brgemm_t abrd) {
    brgemm_kernel_ = new jit_brdgmm_kernel_base_t<isa, Wmm>(abrd);
}

template <cpu_isa_t isa, typename Wmm>
status_t brdgmm_kernel_t<isa, Wmm>::create_kernel() {
    return brgemm_kernel_->create_kernel();
}

template <cpu_isa_t isa, typename Wmm>
void brdgmm_kernel_t<isa, Wmm>::operator()(
        brgemm_kernel_params_t *params) const {
    (*brgemm_kernel_)(params);
}

template <cpu_isa_t isa, typename Wmm>
brdgmm_kernel_t<isa, Wmm>::~brdgmm_kernel_t() {
    delete brgemm_kernel_;
}

template struct brdgmm_kernel_t<avx512_core_fp16, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core_bf16, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core_vnni, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx512_core, Xbyak::Zmm>;
template struct brdgmm_kernel_t<avx2, Xbyak::Ymm>;
template struct brdgmm_kernel_t<avx2_vnni_2, Xbyak::Ymm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl