1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/memory_tracking.hpp" |
20 | #include "common/tag_traits.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/cpu_primitive.hpp" |
25 | #include "cpu/scale_utils.hpp" |
26 | |
27 | #include "cpu/x64/amx_tile_configure.hpp" |
28 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
29 | #include "cpu/x64/matmul/brgemm_matmul.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | namespace matmul { |
36 | |
37 | using namespace dnnl::impl::memory_tracking::names; |
38 | using namespace dnnl::impl::utils; |
39 | |
40 | using namespace nstl; |
41 | |
42 | using namespace data_type; |
43 | |
44 | template <cpu_isa_t isa> |
45 | status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) { |
46 | const auto src_dt = src_md_.data_type; |
47 | const auto wei_dt = weights_md_.data_type; |
48 | const auto dst_dt = dst_md_.data_type; |
49 | |
50 | const bool is_f32 = everyone_is(f32, src_dt, wei_dt, dst_dt); |
51 | const bool is_int8 = one_of(src_dt, u8, s8) && wei_dt == s8 |
52 | && one_of(dst_dt, u8, s8, s32, f32, bf16); |
53 | const bool is_bf16 |
54 | = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32); |
55 | const bool is_f16 |
56 | = everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32); |
57 | |
58 | auto check_bias = [&]() -> bool { |
59 | const auto bia_dt = weights_md(1)->data_type; |
        // The condition in IMPLICATION should be a full expression (hence
        // "is_int8 == true") to work around an ICE in GCC 7.4.
62 | const bool is_bia_dt_correct |
63 | = IMPLICATION(is_int8 == true, |
64 | one_of(bia_dt, f32, s32, s8, u8, bf16)) |
65 | && IMPLICATION(!is_int8, one_of(bia_dt, f32, src_dt)); |
66 | return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN()); |
67 | }; |
68 | |
69 | auto check_attr_scales = [&]() -> bool { |
70 | using namespace data_type; |
71 | const std::vector<int> supported_args |
72 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; |
73 | bool ok = attr()->scales_.has_default_values(supported_args); |
74 | for (int arg : supported_args) { |
75 | const auto &mask = attr()->scales_.get(arg).mask_; |
76 | if (arg == DNNL_ARG_WEIGHTS) |
77 | ok = ok && (mask == 0 || mask == 1 << (dst_md()->ndims - 1)); |
78 | else |
79 | ok = ok && (mask == 0); |
80 | } |
81 | if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() |
82 | && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() |
83 | && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { |
84 | // This case requires scratchpad |
85 | if (N() == DNNL_RUNTIME_DIM_VAL) return false; |
86 | } |
87 | return ok; |
88 | }; |
89 | |
90 | auto check_attr_zero_points |
91 | = [&]() -> bool { return attr()->zero_points_.common(); }; |
92 | |
93 | const bool problem_dt_correct = is_int8 || is_bf16 || is_f32 || is_f16; |
94 | bool ok = mayiuse(isa) && problem_dt_correct && !has_zero_dim_memory() |
95 | && !has_runtime_dims_or_strides() |
96 | && attr()->has_default_values( |
97 | primitive_attr_t::skip_mask_t::scales_runtime |
98 | | primitive_attr_t::skip_mask_t::zero_points_runtime |
99 | | primitive_attr_t::skip_mask_t::post_ops |
100 | | primitive_attr_t::skip_mask_t::sum_dt, |
101 | dst_dt) |
102 | && attr()->post_ops_.check_sum_consistent_dt(dst_dt) |
103 | && check_attr_scales() && check_attr_zero_points() && check_bias(); |
104 | if (!ok) return status::unimplemented; |
105 | |
106 | CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_, |
107 | dst_md_, bias_md_, attr_)); |
108 | |
109 | const float alpha = 1.0; |
110 | const float beta = 1.0; |
111 | const float beta_init = 0.0; |
112 | |
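    // Create brgemm descriptors for all combinations of brgemm batch-size
    // tail, accumulator initialization (beta = 0 vs. beta = 1) and full/tail
    // blocks along M, N and K; combinations without a valid kernel index are
    // skipped.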
113 | for_(int i_bs = 0; i_bs < 2; i_bs++) |
114 | for_(int i_init = 0; i_init < 2; i_init++) |
115 | for_(int i_M = 0; i_M < 2; i_M++) |
116 | for_(int i_N = 0; i_N < 2; i_N++) |
117 | for (int i_K = 0; i_K < 2; i_K++) { |
118 | auto vbeta = (i_init) ? beta_init : beta; |
119 | auto vM = (i_M) ? bgmmc_.M_tail : bgmmc_.M_blk; |
120 | auto vN = (i_N) ? bgmmc_.N_tail : bgmmc_.N_blk; |
121 | auto vK = (i_K) ? bgmmc_.K_tail : bgmmc_.K_blk; |
122 | |
123 | int bs = get_brg_batchsize(bgmmc_, i_bs, i_K); |
124 | int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K); |
125 | if (idx < 0) continue; |
126 | brgemm_t &brg = brg_descs_[idx]; |
127 | auto LDA = i_K && bgmmc_.use_buffer_a_tail_only |
128 | ? (dim_t)bgmmc_.wei_k_blk |
129 | : bgmmc_.LDA; |
130 | CHECK(brgemm_desc_init(&brg, isa, bgmmc_.brg_type, bgmmc_.src_dt, |
131 | bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta, |
132 | LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK)); |
133 | |
134 | auto LDD = bgmmc_.LDD; |
135 | CHECK(brgemm_desc_set_postops( |
136 | &brg, attr(), &dst_md_, LDD, bgmmc_.bia_dt)); |
137 | |
138 | brgemm_attr_t brgattr; |
139 | brgattr.generate_skip_accumulation |
140 | = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1; |
141 | const bool is_amx = is_superset(isa, avx512_core_amx); |
142 | if (is_amx) { |
143 | if (!brgattr.generate_skip_accumulation) { |
144 | // TODO: uker doesn't yet support generate_skip_accumulation |
145 | brgattr.use_uker = true; |
146 | brgattr.use_interleave_stores = true; |
147 | } |
148 | brgattr.max_bs = bs; |
149 | brgattr.wary_tail_read = false; |
150 | |
151 | // TODO: change expected sizes to local chunks wrt L2 blocking |
152 | brgattr.hint_expected_A_size = vM * vK * bs; |
153 | brgattr.hint_expected_B_size = vN * vK * bs; |
154 | brgattr.hint_expected_C_size = vM * vN * bs; |
155 | brgattr.hint_innermost_loop = brgemm_ld_loop_innermost; |
156 | brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; |
157 | } |
158 | |
159 | CHECK(brgemm_desc_set_attr(&brg, brgattr)); |
160 | bgmmc_.wsp_tile_per_thr_bytes = nstl::max( |
161 | brg.get_wsp_buffer_size(), bgmmc_.wsp_tile_per_thr_bytes); |
162 | } |
163 | |
164 | auto scratchpad = scratchpad_registry().registrar(); |
165 | init_scratchpad(scratchpad, bgmmc_); |
166 | book_precomputed_scales(scratchpad, attr()->scales_, N()); |
167 | |
168 | return status::success; |
169 | } |
170 | |
171 | template <cpu_isa_t isa> |
172 | status_t brgemm_matmul_t<isa>::init(engine_t *engine) { |
173 | for_(int i_bs = 0; i_bs < 2; i_bs++) |
174 | for_(int i_M = 0; i_M < 2; i_M++) |
175 | for_(int i_N = 0; i_N < 2; i_N++) |
176 | for_(int i_K = 0; i_K < 2; i_K++) |
177 | for (int i_init = 0; i_init < 2; i_init++) { |
178 | int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K); |
179 | if (idx < 0) continue; |
180 | |
181 | brgemm_kernel_t *ker = nullptr; |
182 | CHECK(brgemm_kernel_create(&ker, pd()->get_brg_desc(idx))); |
183 | CHECK(safe_ptr_assign(brg_kernels_[idx], ker)); |
184 | if (is_superset(isa, avx512_core_amx)) |
185 | CHECK(brgemm_init_tiles( |
186 | pd()->get_brg_desc(idx), &brg_kernel_palettes_[idx][0])); |
187 | } |
188 | |
189 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
190 | if (bgmmc.use_buffer_b) |
191 | CHECK(create_brgemm_matmul_copy_b(copy_B_kernel_, &bgmmc)); |
192 | |
193 | if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only) |
194 | CHECK(create_brgemm_matmul_copy_a(copy_A_kernel_, &bgmmc)); |
195 | |
196 | if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == f32) { |
197 | CHECK(safe_ptr_assign( |
198 | acc_ker_f32_, new cpu_accumulator_1d_t<data_type::f32>())); |
199 | CHECK(acc_ker_f32_->create_kernel()); |
200 | } else if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == s32) { |
201 | CHECK(safe_ptr_assign( |
202 | acc_ker_s32_, new cpu_accumulator_1d_t<data_type::s32>())); |
203 | CHECK(acc_ker_s32_->create_kernel()); |
204 | } |
205 | |
206 | return status::success; |
207 | } |
208 | |
209 | template <cpu_isa_t isa> |
210 | status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const { |
211 | DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); |
212 | DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS); |
213 | DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); |
214 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
215 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
216 | |
217 | auto &scratchpad = ctx.get_scratchpad_grantor(); |
218 | const float *oscales = precompute_scales( |
219 | scratchpad, src_scales, wei_scales, pd()->N(), pd()->attr()); |
220 | |
221 | brg_matmul_exec_ctx_t brgmm_ctx( |
222 | ctx, pd(), oscales, src_zero_point, wei_zero_point, dst_zero_point); |
223 | |
224 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
225 | const bool use_buffer_a |
226 | = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only; |
227 | const bool is_amx = is_superset(isa, avx512_core_amx); |
228 | const int num_threads = brgmm_ctx.get_num_threads_for_parallelization(); |
229 | |
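    // Threads form an (nthr_k x nthr_bmn) grid: the "bmn" threads split the
    // (batch, M chunk, N chunk) iteration space, while the "k" threads split
    // K chunks when parallel reduction over K is used.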
230 | parallel(num_threads, [&](const int ithr, const int nthr) { |
231 | const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr); |
232 | const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr); |
233 | if (ithr_bmn < 0 || ithr_k < 0) return; |
234 | int start {0}, end {0}; |
235 | balance211(brgmm_ctx.get_parallel_work_amount(), |
236 | brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, start, end); |
237 | int kc_start {0}, kc_end {bgmmc.K_chunks}; |
238 | if (brgmm_ctx.parallel_reduction_is_used()) |
239 | balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(), |
240 | ithr_k, kc_start, kc_end); |
241 | |
242 | if (is_amx) { |
243 | const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx(); |
244 | amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]); |
245 | } |
246 | |
247 | int b {0}, mc {0}, nc {0}; |
248 | nd_iterator_init( |
249 | start, b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks); |
250 | while (start < end) { |
251 | auto m_start = mc * bgmmc.M_chunk_size; |
252 | auto m_end = nstl::min( |
253 | (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks); |
254 | auto n_start = nc * bgmmc.N_chunk_size; |
255 | auto n_end = nstl::min( |
256 | (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks); |
257 | for_(int kc = kc_start; kc < kc_end; kc++) |
258 | for (int nb = n_start; nb < n_end; nb++) { |
259 | if (bgmmc.use_buffer_b) |
260 | copy_b_chunk_in_buffer(brgmm_ctx, ithr, b, nb, kc); |
261 | for (int mb = m_start; mb < m_end; mb++) { |
262 | if (use_buffer_a && nb == n_start) |
263 | copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc); |
264 | compute_kernel( |
265 | brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start); |
266 | } |
267 | } |
268 | ++start; |
269 | nd_iterator_step( |
270 | b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks); |
271 | } |
272 | if (is_amx) { amx_tile_release(); } |
273 | }); |
274 | |
275 | maybe_reduce_partial_results_and_apply_postops(brgmm_ctx); |
276 | |
277 | return status::success; |
278 | } |
279 | |
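// Computes one (M block, N block) tile of the output for the given K chunk:
// runs the brgemm kernel over the full K blocks of the chunk and, if present,
// over the K tail block, fusing post-ops on the last K chunk when parallel
// reduction over K is not used.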
280 | template <cpu_isa_t isa> |
281 | void brgemm_matmul_t<isa>::compute_kernel( |
282 | const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx, |
283 | int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const { |
284 | const bool is_amx = is_superset(isa, avx512_core_amx); |
285 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
286 | const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr); |
287 | const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx(); |
288 | |
289 | const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr); |
290 | const int m = m_blk_idx * bgmmc.M_blk; |
291 | const int n = n_blk_idx * bgmmc.N_blk; |
292 | const int k_blk_idx = k_chunk_idx * bgmmc.brgemm_batch_size; |
293 | |
294 | const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk); |
295 | const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk); |
296 | const bool is_last_K_chunk = brgmm_ctx.is_last_K_chunk(k_chunk_idx); |
297 | |
298 | const int remaining_k_blks |
299 | = (bgmmc.use_buffer_a ? utils::rnd_up(bgmmc.K, bgmmc.K_blk) |
300 | : bgmmc.K) |
301 | - k_chunk_idx * bgmmc.K_chunk_elems; |
302 | const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx); |
303 | const bool is_K_tail |
304 | = is_last_K_chunk && (gemm_batch * bgmmc.K_blk) != remaining_k_blks; |
305 | auto is_bs_tail = (gemm_batch != bgmmc.brgemm_batch_size); |
306 | const int brg_ker_idx = pd()->get_brg_kernel_idx( |
307 | is_bs_tail, do_init, is_M_tail, is_N_tail, false); |
308 | const auto ptr_bias = brgmm_ctx.get_bias_ptr(n); |
309 | auto ptr_D = brgmm_ctx.get_data_C_ptr(b_idx, m, n); |
310 | auto ptr_C = (bgmmc.use_buffer_c) |
311 | ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx) |
312 | : ptr_D; |
313 | |
314 | const auto zp_comp_a |
315 | = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx); |
316 | const auto zp_comp_b |
317 | = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx); |
318 | const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr(); |
319 | const auto &post_ops_binary_rhs_arg_vec |
320 | = brgmm_ctx.get_post_ops_binary_rhs_arg_vec(); |
321 | const bool post_ops_applicable = bgmmc.post_ops_applicable |
322 | && (brgmm_ctx.get_num_threads_for_k() <= 1 || bgmmc.K_chunks == 1); |
323 | |
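    // Execute the brgemm kernel over the full K blocks of this chunk.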
324 | if (gemm_batch > 0 && brg_ker_idx >= 0) { |
325 | const auto brg_kernel = brg_kernels_[brg_ker_idx].get(); |
326 | assert(brg_kernel != nullptr); |
327 | |
328 | const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail); |
329 | if (is_tile_reconf_required) |
330 | amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]); |
331 | |
332 | brgmm_ctx.init_brgemm_batch_elements_values( |
333 | ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx); |
334 | |
335 | if (post_ops_applicable && is_last_K_chunk && !is_K_tail) { |
336 | void *scratch = is_amx |
337 | ? static_cast<void *>(wsp_tile) |
338 | : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr( |
339 | ithr, b_idx, n_blk_idx)); |
340 | |
341 | const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk; |
342 | const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1 |
343 | ? b_idx / bgmmc.batch_without_first_dim |
344 | : 0; |
345 | const size_t first_mb_matrix_addr_off |
346 | = batch_first_dim_idx * (bgmmc.M * bgmmc.N) |
347 | + (m * bgmmc.N + n); |
348 | const brgemm_post_ops_data_t post_ops_data { |
349 | static_cast<const void *>(ptr_bias), |
350 | brgmm_ctx.get_oscales_ptr(n), |
351 | post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n), |
352 | dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0), |
353 | first_mb_matrix_addr_off, |
354 | static_cast<const void *>(zp_comp_a), |
355 | static_cast<const void *>(zp_comp_b), |
356 | static_cast<const void *>(zp_c_val_ptr)}; |
357 | |
358 | brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch, |
359 | (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch); |
360 | } else { |
361 | brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch, |
362 | (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr); |
363 | } |
364 | |
365 | if (is_tile_reconf_required) |
366 | amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]); |
367 | } |
368 | if (is_K_tail) { |
369 | brgmm_ctx.init_brgemm_batch_elements_values( |
370 | ithr, gemm_batch, 1, b_idx, m_blk_idx, k_blk_idx, n_blk_idx); |
371 | |
372 | const bool use_init_ker = (do_init && gemm_batch == 0); |
373 | const int brg_ker_idx = pd()->get_brg_kernel_idx( |
374 | false, use_init_ker, is_M_tail, is_N_tail, true); |
375 | const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get(); |
376 | const bool is_tile_reconf_required |
377 | = is_amx && bgmmc.K_tail != bgmmc.K_blk; |
378 | if (is_tile_reconf_required) |
379 | amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]); |
380 | if (post_ops_applicable) { |
381 | void *scratch = is_amx |
382 | ? static_cast<void *>(wsp_tile) |
383 | : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr( |
384 | ithr, b_idx, n_blk_idx)); |
385 | |
386 | const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk; |
387 | const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1 |
388 | ? b_idx / bgmmc.batch_without_first_dim |
389 | : 0; |
390 | const size_t first_mb_matrix_addr_off |
391 | = batch_first_dim_idx * (bgmmc.M * bgmmc.N) |
392 | + (m * bgmmc.N + n); |
393 | const brgemm_post_ops_data_t post_ops_data { |
394 | static_cast<const void *>(ptr_bias), |
395 | brgmm_ctx.get_oscales_ptr(n), |
396 | post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n), |
397 | dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0), |
398 | first_mb_matrix_addr_off, |
399 | static_cast<const void *>(zp_comp_a), |
400 | static_cast<const void *>(zp_comp_b), |
401 | static_cast<const void *>(zp_c_val_ptr)}; |
402 | |
403 | brgemm_kernel_execute_postops(brg_kernel_k_tail, 1, addr_batch, |
404 | (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch); |
405 | } else { |
406 | brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch, |
407 | (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr); |
408 | } |
409 | if (is_tile_reconf_required) |
410 | amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]); |
411 | } |
412 | } |
413 | |
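// When parallel reduction over K is used, sums the partial accumulation
// buffers produced by the K-parallel threads and, if post-ops are present,
// applies them together with the conversion to the destination data type.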
414 | template <cpu_isa_t isa> |
415 | void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops( |
416 | const brg_matmul_exec_ctx_t &brgmm_ctx) const { |
417 | if (!brgmm_ctx.parallel_reduction_is_used()) return; |
418 | |
419 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
420 | const int num_threads = brgmm_ctx.get_num_threads_for_parallelization(); |
421 | |
422 | parallel(num_threads, [&](const int ithr, const int nthr) { |
423 | const int nthr_k = brgmm_ctx.get_num_threads_for_k(); |
424 | const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr); |
425 | const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr); |
426 | if (ithr_bmn < 0 || ithr_k < 0) return; |
427 | |
428 | const int num_reduction_buffers = nstl::min(nthr_k, bgmmc.K_chunks); |
429 | |
430 | int bmn_start {0}, bmn_end {0}; |
431 | int start {0}, end {0}; |
432 | balance211(brgmm_ctx.get_parallel_work_amount(), |
433 | brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, bmn_start, |
434 | bmn_end); |
435 | balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end); |
436 | |
437 | int b {0}, mc {0}, nc {0}; |
438 | |
439 | assert(bgmmc.batch == 1); |
440 | nd_iterator_init(bmn_start + start, b, bgmmc.batch, mc, bgmmc.M_chunks, |
441 | nc, bgmmc.N_chunks); |
442 | while (start < end) { |
443 | auto mb_start = mc * bgmmc.M_chunk_size; |
444 | auto mb_end = nstl::min( |
445 | (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks); |
446 | auto nb_start = nc * bgmmc.N_chunk_size; |
447 | auto nb_end = nstl::min( |
448 | (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks); |
449 | for (int mb = mb_start; mb < mb_end; mb++) { |
450 | const int curr_M_blk |
451 | = nstl::min(bgmmc.M - mb * bgmmc.M_blk, bgmmc.M_blk); |
452 | const bool is_M_tail = curr_M_blk < bgmmc.M_blk; |
453 | const int curr_N_chunk_size |
454 | = nstl::min(bgmmc.N, nb_end * bgmmc.N_blk) |
455 | - nb_start * bgmmc.N_blk; |
456 | char *buf_reduced_base = brgmm_ctx.get_buf_C_par_reduction_ptr( |
457 | 0, mb, nb_start); |
458 | const size_t m_offset = bgmmc.LDC * bgmmc.acc_dt_sz; |
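                // Sum partial results from the other reduction buffers into
                // buffer 0 row by row.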
459 | for (int r = 1; r < num_reduction_buffers; r++) { |
460 | const char *buf_to_reduce_base |
461 | = brgmm_ctx.get_buf_C_par_reduction_ptr( |
462 | r, mb, nb_start); |
463 | for (int m = 0; m < curr_M_blk; m++) { |
464 | accumulate(buf_reduced_base + m * m_offset, |
465 | buf_to_reduce_base + m * m_offset, |
466 | curr_N_chunk_size); |
467 | } |
468 | } |
469 | if (bgmmc.post_ops_applicable) { |
470 | for (int nb = nb_start; nb < nb_end; nb++) { |
471 | const bool is_N_tail |
472 | = (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk); |
473 | const int brg_ker_idx = pd()->get_brg_kernel_idx( |
474 | false, false, is_M_tail, is_N_tail, false); |
475 | const auto brg_kernel = brg_kernels_[brg_ker_idx].get(); |
476 | const int m = mb * bgmmc.M_blk; |
477 | const int n = nb * bgmmc.N_blk; |
478 | const auto ptr_bias = brgmm_ctx.get_bias_ptr(n); |
479 | auto ptr_D = brgmm_ctx.get_data_C_ptr(b, m, n); |
480 | auto ptr_C = brgmm_ctx.get_buf_C_par_reduction_ptr( |
481 | 0, mb, nb); |
482 | |
483 | // TODO: support reduction for zp/s8s8 compensations |
484 | // computed in copy routines |
485 | const auto zp_comp_a |
486 | = brgmm_ctx.get_zp_a_compensation_ptr( |
487 | ithr, b, nb); |
488 | const auto zp_comp_b |
489 | = brgmm_ctx.get_zp_b_compensation_result_ptr( |
490 | ithr, mb); |
491 | const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr(); |
492 | const auto &post_ops_binary_rhs_arg_vec |
493 | = brgmm_ctx.get_post_ops_binary_rhs_arg_vec(); |
494 | |
495 | const size_t dst_row_logical_off = mb * bgmmc.M_blk; |
496 | const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1 |
497 | ? b / bgmmc.batch_without_first_dim |
498 | : 0; |
499 | const size_t first_mb_matrix_addr_off |
500 | = batch_first_dim_idx * (bgmmc.M * bgmmc.N) |
501 | + (m * bgmmc.N + n); |
502 | // apply post-ops and convert to dst data type only |
503 | constexpr bool skip_accumulation = true; |
504 | const brgemm_post_ops_data_t post_ops_data { |
505 | static_cast<const void *>(ptr_bias), |
506 | brgmm_ctx.get_oscales_ptr(n), |
507 | post_ops_binary_rhs_arg_vec.data(), |
508 | static_cast<size_t>(n), dst_row_logical_off, |
509 | brgmm_ctx.get_data_C_ptr(0, 0, 0), |
510 | first_mb_matrix_addr_off, |
511 | static_cast<const void *>(zp_comp_a), |
512 | static_cast<const void *>(zp_comp_b), |
513 | static_cast<const void *>(zp_c_val_ptr), |
514 | skip_accumulation}; |
515 | |
516 | brgemm_kernel_execute_postops(brg_kernel, 0, nullptr, |
517 | (void *)ptr_C, (void *)ptr_D, post_ops_data, |
518 | nullptr); |
519 | } |
520 | } |
521 | } |
522 | ++start; |
523 | nd_iterator_step( |
524 | b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks); |
525 | } |
526 | }); |
527 | } |
528 | |
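// Copies the A blocks of one K chunk (or only the K tail when
// use_buffer_a_tail_only is set) into the per-thread buffer layout expected
// by the brgemm kernel.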
529 | template <cpu_isa_t isa> |
530 | void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer( |
531 | const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx, |
532 | int m_blk_idx, int k_chunk_idx) const { |
533 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
534 | |
535 | auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t(); |
536 | const int k_start = k_chunk_idx * bgmmc.K_chunk_elems; |
537 | const bool is_K_tail |
538 | = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0; |
539 | const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx); |
540 | const int gemm_batch_iters = bgmmc.use_buffer_a_tail_only ? 0 : gemm_batch; |
541 | |
542 | const int m = m_blk_idx * bgmmc.M_blk; |
543 | const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk); |
544 | ctx.current_M_blk = is_M_tail ? bgmmc.M_tail : bgmmc.M_blk; |
545 | ctx.zp_b_compensation_buffer_ptr |
546 | = (void *)brgmm_ctx.get_zp_b_compensation_buffer_ptr( |
547 | ithr, m_blk_idx); |
548 | ctx.zp_a_compensation_result_ptr |
549 | = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr( |
550 | ithr, m_blk_idx); |
551 | ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr(); |
552 | ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr(); |
553 | |
554 | for (int gb = 0; gb < gemm_batch_iters; gb++) { |
555 | const int k = k_start + gb * bgmmc.K_blk; |
556 | ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k); |
557 | ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(ithr, m_blk_idx, gb); |
558 | ctx.current_K_blk = nstl::min(bgmmc.K_blk, bgmmc.K); |
559 | ctx.current_K_start = k; |
560 | |
561 | (*copy_A_kernel_)(&ctx); |
562 | } |
563 | if (is_K_tail) { |
564 | const auto K_tail = bgmmc.K % bgmmc.K_blk; |
565 | const int k = k_start + gemm_batch * bgmmc.K_blk; |
566 | ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k); |
567 | ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr( |
568 | ithr, m_blk_idx, gemm_batch_iters); |
569 | ctx.current_K_blk = K_tail; |
570 | ctx.current_K_start = k; |
571 | |
572 | (*copy_A_kernel_)(&ctx); |
573 | } |
574 | } |
575 | |
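// Copies one B chunk into the per-thread buffer in the layout expected by the
// brgemm kernel, computing s8s8 / zero-point compensations when required; a
// blocked f16 B matrix is converted to f32 instead of using the copy kernel.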
576 | template <cpu_isa_t isa> |
577 | void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer( |
578 | const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx, |
579 | int n_blk_idx, int k_chunk_idx) const { |
580 | const auto &bgmmc = pd()->get_brgemm_matmul_conf(); |
581 | |
582 | const int k_start = k_chunk_idx * bgmmc.K_chunk_elems; |
583 | const bool is_K_tail |
584 | = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0; |
585 | const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx); |
586 | auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t(); |
587 | |
588 | const int n = n_blk_idx * bgmmc.N_blk; |
589 | const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk); |
590 | ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk; |
591 | ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr( |
592 | ithr, b_idx, n_blk_idx); |
593 | ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr(); |
594 | |
595 | int gb = 0; |
596 | for (; gb < gemm_batch; gb++) { |
597 | const int k = k_start + gb * bgmmc.K_blk; |
598 | ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n); |
599 | ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx); |
600 | ctx.compensation_ptr |
601 | = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); |
602 | ctx.current_K_start = k; |
603 | ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K); |
604 | if (bgmmc.blocked_B && isa == avx512_core_fp16) { |
605 | cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src, |
606 | bgmmc.wei_n_blk * ctx.current_K_iters); |
607 | } else { |
608 | (*copy_B_kernel_)(&ctx); |
609 | } |
610 | } |
611 | |
612 | if (is_K_tail) { |
613 | const int k = k_start + gb * bgmmc.K_blk; |
614 | ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n); |
615 | ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx); |
616 | ctx.compensation_ptr |
617 | = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); |
618 | ctx.current_K_start = k; |
619 | ctx.current_K_iters = bgmmc.K % bgmmc.K_blk; |
620 | if (bgmmc.blocked_B && isa == avx512_core_fp16) { |
621 | cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src, |
622 | bgmmc.wei_n_blk * ctx.current_K_iters); |
623 | } else { |
624 | (*copy_B_kernel_)(&ctx); |
625 | } |
626 | } |
627 | } |
628 | |
629 | template <cpu_isa_t isa> |
630 | void brgemm_matmul_t<isa>::accumulate( |
631 | char *result_ptr, const char *reduce_ptr, size_t size) const { |
632 | if (pd()->get_brgemm_matmul_conf().acc_dt == f32) |
633 | acc_ker_f32_->accumulate( |
634 | (float *)result_ptr, (const float *)reduce_ptr, size); |
635 | else if (pd()->get_brgemm_matmul_conf().acc_dt == s32) |
636 | acc_ker_s32_->accumulate( |
637 | (int32_t *)result_ptr, (const int32_t *)reduce_ptr, size); |
638 | else |
        assert(!"unsupported accumulation data type");
640 | } |
641 | |
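// Execution context created per primitive execution: resolves data,
// scratchpad and compensation pointers and holds the parallel decomposition
// parameters used by the compute routines above.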
642 | template <cpu_isa_t isa> |
643 | struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t { |
644 | brg_matmul_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd, |
645 | const float *oscales, int32_t src_zp, int32_t wei_zp, |
646 | int32_t dst_zp) |
647 | : bgmmc_(pd->get_brgemm_matmul_conf()) { |
648 | |
649 | data_A_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
650 | data_B_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
651 | data_C_ptr_ = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
652 | |
653 | bias_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
654 | oscales_ptr_ = oscales; |
655 | memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
656 | const auto &bgmmc = pd->get_brgemm_matmul_conf(); |
657 | |
658 | batch_element_ptr_ = scratchpad.template get<brgemm_batch_element_t>( |
659 | key_brgemm_primitive_batch); |
660 | |
661 | const bool use_buffer_a |
662 | = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only; |
663 | buf_A_ptr_ = (use_buffer_a) |
664 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a) |
665 | : nullptr; |
666 | |
667 | buf_B_ptr_ = (bgmmc.use_buffer_b) |
668 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b) |
669 | : nullptr; |
670 | |
671 | buf_C_ptr_ = (bgmmc.use_buffer_c) |
672 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer) |
673 | : nullptr; |
674 | |
675 | is_amx_ = is_superset(isa, avx512_core_amx); |
676 | wsp_tile_ptr_ = is_amx_ |
677 | ? ctx.get_scratchpad_grantor().template get<char>( |
678 | key_conv_amx_tile_buffer) |
679 | : nullptr; |
680 | |
681 | const memory_desc_wrapper weights_d(pd->weights_md(0)); |
682 | const dim_t comp_offset = bgmmc_.b_dt_sz |
683 | * (weights_d.size() - weights_d.additional_buffer_size()); |
684 | s8s8_compensation_ptr_ = (bgmmc.s8s8_compensation_required) |
685 | ? ((bgmmc.use_buffer_b) |
686 | ? scratchpad.template get<int32_t>( |
687 | key_brgemm_primitive_buffer_comp) |
688 | : const_cast<int32_t *>( |
689 | reinterpret_cast<const int32_t *>( |
690 | &data_B_ptr_[comp_offset]))) |
691 | : nullptr; |
692 | assert(IMPLICATION(bgmmc.s8s8_compensation_required, |
693 | bgmmc_.b_dt_sz == bgmmc_.tr_b_dt_sz)); |
694 | |
695 | zero_point_a_compensations_ptr_ = bgmmc.has_zero_point_a |
696 | ? scratchpad.template get<int32_t>( |
697 | key_brgemm_primitive_zp_comp_a) |
698 | : nullptr; |
699 | zero_point_b_compensations_ptr_ = bgmmc.has_zero_point_b |
700 | ? scratchpad.template get<int32_t>( |
701 | key_brgemm_primitive_zp_comp_b) |
702 | : nullptr; |
703 | |
704 | zero_point_a_negative_val_ = -src_zp; |
705 | zero_point_b_negative_val_ = -wei_zp; |
706 | zero_point_mixed_ab_compensation_component_ |
707 | = bgmmc.K * zero_point_a_negative_val_; |
708 | |
709 | zero_point_c_val_ = dst_zp; |
710 | |
711 | post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args( |
712 | pd->attr()->post_ops_, ctx); |
713 | base_brg_ker_idx_ |
714 | = pd->get_brg_kernel_idx(false, true, false, false, false); |
715 | vnni_factor = data_type_vnni_granularity(bgmmc.wei_dt); |
716 | |
717 | reorder_zp_a_comp_ptr_ = nullptr; |
718 | if (bgmmc_.has_zero_point_a && bgmmc_.blocked_B) { |
            // Store the pointer to compensation values computed during
            // reorder so they can be scaled locally by the zp_a value just
            // before usage in post-ops. Using a single global scaling before
            // the parallel section might produce significant overhead for
            // small problems running in multithreaded execution mode.
724 | const size_t reorder_zp_a_comp_offset |
725 | = weights_d.size() - weights_d.additional_buffer_size(); |
726 | const size_t b_batch |
727 | = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1; |
728 | const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required |
729 | ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str |
730 | : 0; |
731 | reorder_zp_a_comp_ptr_ |
732 | = const_cast<int32_t *>(reinterpret_cast<const int32_t *>( |
733 | &data_B_ptr_[reorder_zp_a_comp_offset |
734 | + s8s8_buffer_sz])); |
735 | } |
736 | |
737 | // Set last_chunk_brgemm_batch_size_ to brgemm_batch_size |
738 | // when K_tail = 0 and brgemm_batch_tail_size = 0 |
739 | last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_tail_size; |
740 | if (bgmmc.K_tail == 0 && last_chunk_brgemm_batch_size_ == 0) |
741 | last_chunk_brgemm_batch_size_ = bgmmc.brgemm_batch_size; |
742 | |
743 | // parallelization |
744 | parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks; |
745 | |
        // The number of threads available during primitive execution may
        // increase (e.g. Eigen threadpool implementation) or decrease
        // (e.g. nested parallelism) compared to the number of threads
        // available during primitive creation. So we limit the total number
        // of threads to the minimum of these two values to prevent potential
        // OOM issues.
752 | nthr_ = nstl::min(dnnl_get_current_num_threads(), bgmmc.nthr); |
753 | |
754 | nthr_k_ = bgmmc.nthr_k > 0 && bgmmc.nthr_k <= nthr_ ? bgmmc.nthr_k : 1; |
755 | nthr_bmn_ = nthr_ / nthr_k_; |
756 | num_threads_used_ = nthr_k_ * nthr_bmn_; |
757 | |
        // If parallel_work_amount_ == 1 and parallel reduction is not used,
        // we limit the number of threads to 1 because parallel(1, ...) does
        // not create a parallel section at all. We do not limit the number
        // of threads for 1 < parallel_work_amount_ < dnnl_get_max_threads()
        // to avoid potential overhead on spawning a different number of OMP
        // threads from layer to layer.
764 | if (parallel_work_amount_ == 1 && !parallel_reduction_is_used()) |
765 | nthr_ = nthr_bmn_ = nthr_k_ = 1; |
766 | |
767 | const bool need_to_calculate_compensation_for_a |
768 | = bgmmc.has_zero_point_b; |
769 | const bool need_to_calculate_compensation_for_b = !IMPLICATION( |
770 | (bgmmc.has_zero_point_a || bgmmc.s8s8_compensation_required), |
771 | bgmmc.blocked_B); |
772 | const bool calculate_compensations_in_copy_routines |
773 | = need_to_calculate_compensation_for_a |
774 | || need_to_calculate_compensation_for_b; |
        // Currently, parallel reduction is supported only for non-batched
        // problems without computation of any compensations in copy
        // routines.
778 | assert(IMPLICATION(parallel_reduction_is_used(), |
779 | bgmmc.batch == 1 && !calculate_compensations_in_copy_routines)); |
780 | MAYBE_UNUSED(need_to_calculate_compensation_for_a); |
781 | MAYBE_UNUSED(need_to_calculate_compensation_for_b); |
782 | MAYBE_UNUSED(calculate_compensations_in_copy_routines); |
783 | } |
784 | |
785 | // NOTE: gb --> generalized batch, bb --> broadcast batch |
786 | int get_bb_idx(int gb_idx, const brgemm_matmul_bcast_desc_t &bd) const { |
787 | if (!bd.bcast_mask) // no broadcast |
788 | return gb_idx; |
789 | |
790 | int gb_off_before_bcast = utils::rnd_dn( |
791 | gb_idx, bd.first_bcast_dim_to_last_batch_dim_prod); |
792 | int bb_idx = gb_off_before_bcast / (bd.bcast_dims_prod); |
793 | |
794 | dim_t cur_bcast_dims_prod = bd.bcast_dims_prod; |
795 | int mask = 1 << (bgmmc_.batch_ndims - bd.first_bcast_dim - 1); |
796 | for (int d = bd.first_bcast_dim; d < bd.last_bcast_dim; ++d) { |
797 | if (bd.bcast_mask & mask) // broadcast |
798 | cur_bcast_dims_prod /= bd.batch_dims[d]; |
799 | else { |
800 | int cur_b = (gb_idx / bd.gb_off[d]) % bd.batch_dims[d]; |
801 | bb_idx += cur_b * (bd.gb_off[d] / cur_bcast_dims_prod); |
802 | } |
803 | mask >>= 1; |
804 | } |
805 | bb_idx += gb_idx % bd.gb_off[bd.last_bcast_dim]; |
806 | return bb_idx; |
807 | } |
808 | |
809 | const char *get_data_A_ptr(int b, int m, int k) const { |
810 | int cur_b = get_bb_idx(b, bgmmc_.bcast_A_desc); |
811 | return data_A_ptr_ + get_data_A_off(cur_b, m, k); |
812 | } |
813 | |
814 | const char *get_data_B_ptr(int b, int k, int n) const { |
815 | int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc); |
816 | return data_B_ptr_ + get_data_B_off(cur_b, k, n); |
817 | } |
818 | |
819 | char *get_data_C_ptr(int b, int m, int n) const { |
820 | return data_C_ptr_ + get_data_C_off(b, m, n); |
821 | } |
822 | |
823 | brgemm_batch_element_t *get_batch_elem_ptr(int ithr) const { |
824 | return batch_element_ptr_ |
825 | + ithr * bgmmc_.brgemm_batch_element_per_thr_sz; |
826 | } |
827 | |
828 | void init_brgemm_batch_elements_values(int ithr, int brg_batch_start, |
829 | int brg_batch_iters, int b_idx, int m_blk_idx, int k_blk_idx, |
830 | int n_blk_idx) const { |
831 | auto addr_batch = get_batch_elem_ptr(ithr); |
832 | |
833 | const int m = m_blk_idx * bgmmc_.M_blk; |
834 | const int n = n_blk_idx * bgmmc_.N_blk; |
835 | |
836 | for (int b_iter = 0; b_iter < brg_batch_iters; b_iter++) { |
837 | const int brg_batch_idx = brg_batch_start + b_iter; |
838 | const int k = (k_blk_idx + brg_batch_idx) * bgmmc_.K_blk; |
839 | addr_batch[b_iter].ptr.A = bgmmc_.use_buffer_a |
840 | ? get_buf_A_ptr(ithr, m_blk_idx, brg_batch_idx) |
841 | : get_data_A_ptr(b_idx, m, k); |
842 | addr_batch[b_iter].ptr.B = (bgmmc_.use_buffer_b) |
843 | ? get_buf_B_ptr(ithr, brg_batch_idx, n_blk_idx) |
844 | : get_data_B_ptr(b_idx, k, n); |
845 | } |
846 | } |
847 | |
848 | char *get_buf_A_ptr(int ithr, int m_blk_idx, int k_blk_idx) const { |
849 | if (!bgmmc_.use_buffer_a && !bgmmc_.use_buffer_a_tail_only) |
850 | return nullptr; |
851 | |
852 | const int k_blk_local = bgmmc_.use_buffer_a_tail_only ? 0 : k_blk_idx; |
853 | const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size; |
854 | return buf_A_ptr_ + ithr * bgmmc_.buffer_a_per_thread_sz |
855 | + m_blk_local * bgmmc_.buffer_a_chunk_shift_along_m |
856 | + k_blk_local * bgmmc_.buffer_a_chunk_sz; |
857 | } |
858 | |
859 | char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx) const { |
860 | UNUSED(n_blk_idx); |
861 | if (!bgmmc_.use_buffer_b) return nullptr; |
862 | |
863 | return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz |
864 | + k_blk_idx * bgmmc_.buffer_b_chunk_sz; |
865 | } |
866 | |
867 | char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const { |
868 | if (!bgmmc_.use_buffer_c) return nullptr; |
869 | |
870 | if (bgmmc_.nthr_k > 1) { |
871 | const int nthr_k = bgmmc_.nthr_k <= nthr_ ? bgmmc_.nthr_k : 1; |
872 | const int nthr_bmn = nthr_ / nthr_k; |
873 | const int ithr_k = ithr / nthr_bmn; |
874 | return get_buf_C_par_reduction_ptr(ithr_k, m_blk_idx, n_blk_idx); |
875 | } |
876 | |
877 | const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size; |
878 | const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size; |
879 | const int buf_idx = bgmmc_.N_chunk_size * m_blk_local + n_blk_local; |
880 | |
881 | return buf_C_ptr_ + ithr * bgmmc_.buffer_c_per_thread_sz |
882 | + buf_idx * bgmmc_.buffer_c_chunk_sz; |
883 | } |
884 | |
885 | char *get_buf_C_par_reduction_ptr( |
886 | int ithr_k, int m_blk_idx, int n_blk_idx) const { |
887 | if (bgmmc_.nthr_k <= 1) return nullptr; |
888 | |
889 | const int m = m_blk_idx * bgmmc_.M_blk; |
890 | const int n = n_blk_idx * bgmmc_.N_blk; |
891 | |
892 | if (!bgmmc_.post_ops_applicable && ithr_k == 0) |
893 | return get_data_C_ptr(0, m, n); |
894 | |
895 | int k_buf_idx = ithr_k - (!bgmmc_.post_ops_applicable ? 1 : 0); |
896 | return buf_C_ptr_ + k_buf_idx * bgmmc_.buffer_c_per_thread_sz |
897 | + get_data_C_off(0, m, n) * bgmmc_.acc_dt_sz / bgmmc_.c_dt_sz; |
898 | } |
899 | |
    // Auxiliary functions for getting offsets with pre-calculated memory
    // strides for each tensor to get a general solution for all possible
    // dimensions without significant overhead.
903 | dim_t get_data_A_off(int b, int m, int k) const { |
904 | using namespace format_tag; |
905 | if (bgmmc_.src_tag == acbd || bgmmc_.src_tag == adbc) { |
906 | dim_t b_off = 0; |
907 | if (!bgmmc_.bcast_A_desc.bcast_mask) { // no broadcast |
908 | const dim_t batch_dim1 = bgmmc_.bcast_A_desc.batch_dims[1]; |
909 | b_off = bgmmc_.A_strides[2] * (b % batch_dim1) |
910 | + (b / batch_dim1) * bgmmc_.A_ptr_shift_b; |
911 | } else { |
912 | b_off = b * bgmmc_.A_ptr_shift_b; |
913 | } |
914 | return b_off + bgmmc_.A_strides[1] * m + bgmmc_.A_strides[0] * k; |
915 | } else { |
916 | return bgmmc_.A_strides[2] * b + bgmmc_.A_strides[1] * m |
917 | + bgmmc_.A_strides[0] * k; |
918 | } |
919 | } |
920 | |
921 | dim_t get_data_B_off(int b, int k, int n) const { |
922 | using namespace format_tag; |
923 | if (bgmmc_.wei_tag == acbd || bgmmc_.wei_tag == adbc) { |
924 | dim_t b_off = 0; |
925 | if (!bgmmc_.bcast_B_desc.bcast_mask) { // no broadcast |
926 | const dim_t batch_dim1 = bgmmc_.bcast_B_desc.batch_dims[1]; |
927 | b_off = bgmmc_.B_strides[2] * (b % batch_dim1) |
928 | + (b / batch_dim1) * bgmmc_.B_ptr_shift_b; |
929 | } else { |
930 | b_off = b * bgmmc_.B_ptr_shift_b; |
931 | } |
932 | return b_off + bgmmc_.B_strides[1] * k + bgmmc_.B_strides[0] * n; |
933 | } else { |
934 | int dt_b_k_blk = bgmmc_.is_bf32 |
935 | ? data_type_vnni_simd_elems<avx512_core>(f32) |
936 | : bgmmc_.wei_k_blk; |
937 | int k_idx = bgmmc_.blocked_B ? k / dt_b_k_blk : k; |
938 | int n_idx = bgmmc_.blocked_B ? n / bgmmc_.wei_n_blk : n; |
939 | return bgmmc_.B_strides[2] * b + bgmmc_.B_strides[1] * k_idx |
940 | + bgmmc_.B_strides[0] * n_idx |
941 | + get_data_B_off_within_block(k, n); |
942 | } |
943 | } |
944 | |
945 | dim_t get_data_B_off_within_block(int k, int n) const { |
946 | using namespace format_tag; |
947 | |
948 | if (!bgmmc_.blocked_B) return 0; |
949 | |
950 | int x0 = k % bgmmc_.wei_k_blk; |
951 | int x1 = n % bgmmc_.wei_n_blk; |
952 | dim_t offset = (x0 / vnni_factor) * vnni_factor * bgmmc_.wei_n_blk |
953 | + x1 * vnni_factor + x0 % vnni_factor; |
954 | return bgmmc_.b_dt_sz * offset; |
955 | } |
956 | |
957 | dim_t get_data_C_off(int b, int m, int n) const { |
958 | using namespace format_tag; |
959 | assert(bgmmc_.dst_tag != adbc); |
960 | if (bgmmc_.dst_tag == acbd) { |
961 | const dim_t batch_dim1 = bgmmc_.bcast_A_desc.batch_dims[1]; |
962 | dim_t b_off = bgmmc_.C_strides[2] * (b % batch_dim1) |
963 | + (b / batch_dim1) * bgmmc_.C_ptr_shift_b; |
964 | return b_off + bgmmc_.C_strides[1] * m + bgmmc_.C_strides[0] * n; |
965 | } else { |
966 | return bgmmc_.C_strides[2] * b + bgmmc_.C_strides[1] * m |
967 | + bgmmc_.C_strides[0] * n; |
968 | } |
969 | } |
970 | |
971 | const char *get_bias_ptr(int n) const { |
972 | if (!bgmmc_.with_bias) return nullptr; |
973 | |
974 | return bias_ptr_ + n * bgmmc_.bias_dt_sz; |
975 | } |
976 | |
977 | int32_t *get_s8s8_comp_ptr(int ithr, int b, int n_blk_idx) const { |
978 | if (!bgmmc_.s8s8_compensation_required) return nullptr; |
979 | |
980 | const int n_blk_local = bgmmc_.use_buffer_b |
981 | ? n_blk_idx % bgmmc_.N_chunk_size |
982 | : n_blk_idx; |
983 | return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str |
984 | + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str |
985 | + n_blk_local * bgmmc_.s8s8_comp_n_str; |
986 | } |
987 | |
988 | const float *get_oscales_ptr(int n) const { |
989 | return oscales_ptr_ + bgmmc_.is_oscale_per_n * n; |
990 | } |
991 | |
992 | const int32_t *get_zp_a_neg_val_ptr() const { |
993 | return &zero_point_a_negative_val_; |
994 | } |
995 | |
996 | const int32_t *get_zp_b_neg_val_ptr() const { |
997 | return &zero_point_b_negative_val_; |
998 | } |
999 | |
1000 | const int32_t *get_zp_ab_mixed_comp_ptr() const { |
1001 | return &zero_point_mixed_ab_compensation_component_; |
1002 | } |
1003 | |
1004 | const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; } |
1005 | |
1006 | int32_t *get_zp_a_compensation_ptr( |
1007 | int ithr, int b_idx, int n_blk_idx) const { |
1008 | if (!bgmmc_.has_zero_point_a) return nullptr; |
1009 | |
1010 | const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size; |
1011 | int32_t *zp_comp = zero_point_a_compensations_ptr_ |
1012 | + ithr * bgmmc_.zp_a_comp_elems_per_thr |
1013 | + n_blk_local * bgmmc_.zp_a_comp_shift_n; |
1014 | |
1015 | if (bgmmc_.blocked_B) { |
            // Scale the compensation values computed during reorder by the
            // zp_a value locally, just before usage. Using a single global
            // scaling before the parallel section might produce significant
            // overhead for small problems running in multithreaded execution
            // mode.
1020 | const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc) |
1021 | * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk) |
1022 | + n_blk_idx * bgmmc_.wei_n_blk; |
1023 | PRAGMA_OMP_SIMD() |
1024 | for (int b = 0; b < bgmmc_.wei_n_blk; b++) |
1025 | zp_comp[b] = -zero_point_a_negative_val_ |
1026 | * reorder_zp_a_comp_ptr_[base_offset + b]; |
1027 | } |
1028 | return zp_comp; |
1029 | } |
1030 | |
1031 | int32_t *get_zp_b_compensation_result_ptr(int ithr, int m_blk_idx) const { |
1032 | if (!bgmmc_.has_zero_point_b) return nullptr; |
1033 | |
1034 | const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size; |
1035 | return zero_point_b_compensations_ptr_ |
1036 | + ithr * bgmmc_.zp_b_comp_elems_per_thr |
1037 | + m_blk_local * bgmmc_.zp_b_comp_result_shift_m; |
1038 | } |
1039 | |
1040 | int32_t *get_zp_b_compensation_buffer_ptr(int ithr, int m_blk_idx) const { |
1041 | if (!bgmmc_.has_zero_point_b) return nullptr; |
1042 | |
1043 | const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size; |
1044 | return get_zp_b_compensation_result_ptr(ithr, 0) |
1045 | + bgmmc_.zp_b_comp_buffer_start |
1046 | + m_blk_local * bgmmc_.zp_b_comp_buffer_shift_m; |
1047 | } |
1048 | |
1049 | char *get_tile_workspace(int ithr) const { |
1050 | return is_amx_ ? wsp_tile_ptr_ + ithr * bgmmc_.wsp_tile_per_thr_bytes |
1051 | : nullptr; |
1052 | } |
1053 | |
1054 | const std::vector<const void *> &get_post_ops_binary_rhs_arg_vec() const { |
1055 | return post_ops_binary_rhs_arg_vec_; |
1056 | } |
1057 | |
1058 | int get_base_brgemm_kernel_idx() const { return base_brg_ker_idx_; } |
1059 | |
1060 | bool is_last_K_chunk(int k_chunk_idx) const { |
1061 | return k_chunk_idx == bgmmc_.K_chunks - 1; |
1062 | } |
1063 | |
1064 | int get_brgemm_batch_size(int k_chunk_idx) const { |
1065 | return is_last_K_chunk(k_chunk_idx) ? last_chunk_brgemm_batch_size_ |
1066 | : bgmmc_.brgemm_batch_size; |
1067 | } |
1068 | |
1069 | int get_parallel_work_amount() const { return parallel_work_amount_; } |
1070 | int get_num_threads_for_k() const { return nthr_k_; } |
1071 | bool parallel_reduction_is_used() const { |
1072 | return nthr_k_ > 1 && bgmmc_.K_chunks > 1; |
1073 | } |
1074 | int get_num_threads_for_bmn() const { return nthr_bmn_; } |
1075 | // ithr = ithr_k * nthr_bmn + ithr_bmn |
1076 | int get_thread_idx_for_k(int ithr) const { |
1077 | if (ithr >= num_threads_used_) return -1; |
1078 | const int ithr_k = ithr / nthr_bmn_; |
1079 | return ithr_k < bgmmc_.K_chunks ? ithr_k : -1; |
1080 | } |
1081 | int get_thread_idx_for_bmn(int ithr) const { |
1082 | if (ithr >= num_threads_used_) return -1; |
1083 | const int ithr_bmn = ithr % nthr_bmn_; |
1084 | return ithr_bmn < parallel_work_amount_ ? ithr_bmn : -1; |
1085 | } |
1086 | int get_num_threads_for_parallelization() const { return nthr_; } |
1087 | |
1088 | private: |
1089 | bool is_amx_; |
1090 | const brgemm_matmul_conf_t &bgmmc_; |
1091 | const char *data_A_ptr_; |
1092 | const char *data_B_ptr_; |
1093 | char *data_C_ptr_; |
1094 | brgemm_batch_element_t *batch_element_ptr_; |
1095 | |
1096 | char *buf_A_ptr_; |
1097 | char *buf_B_ptr_; |
1098 | char *buf_C_ptr_; |
1099 | |
1100 | char *wsp_tile_ptr_; |
1101 | const char *bias_ptr_; |
1102 | const float *oscales_ptr_; |
1103 | int32_t *s8s8_compensation_ptr_; |
1104 | |
1105 | int32_t *zero_point_a_compensations_ptr_; |
1106 | int32_t *zero_point_b_compensations_ptr_; |
1107 | int32_t *reorder_zp_a_comp_ptr_; |
1108 | |
1109 | int32_t zero_point_a_negative_val_; |
1110 | int32_t zero_point_b_negative_val_; |
1111 | int32_t zero_point_mixed_ab_compensation_component_; |
1112 | int32_t zero_point_c_val_; |
1113 | std::vector<const void *> post_ops_binary_rhs_arg_vec_; |
1114 | |
1115 | int base_brg_ker_idx_; |
1116 | int vnni_factor; |
1117 | |
1118 | // parallelization parameters |
1119 | int parallel_work_amount_; |
1120 | int nthr_, nthr_k_, nthr_bmn_, num_threads_used_; |
1121 | int last_chunk_brgemm_batch_size_; |
1122 | }; |
1123 | |
1124 | template struct brgemm_matmul_t<avx512_core_amx_fp16>; |
1125 | template struct brgemm_matmul_t<avx512_core_amx>; |
1126 | template struct brgemm_matmul_t<avx512_core_fp16>; |
1127 | template struct brgemm_matmul_t<avx512_core_bf16>; |
1128 | template struct brgemm_matmul_t<avx512_core_vnni>; |
1129 | template struct brgemm_matmul_t<avx512_core>; |
1130 | |
1131 | } // namespace matmul |
1132 | } // namespace x64 |
1133 | } // namespace cpu |
1134 | } // namespace impl |
1135 | } // namespace dnnl |
1136 | |