/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <unordered_set>

#include "common/dnnl_thread.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/matmul/brgemm_matmul_utils.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/matmul_utils.hpp"
#include "oneapi/dnnl/dnnl_debug.h"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

using namespace dnnl::impl::cpu::matmul;

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace data_type;
using namespace format_tag;

int get_default_n_block(format_tag_t matrix_b_tag) {
    // Note: consider using weights mem_descriptor 'inner_blks' to
    // return B's inner block for non-default cases.
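    // E.g. both the 2d tag BA16a64b4a and the 3d tag aCB16b64c4b carry an
    // inner N block of 64.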
    switch (matrix_b_tag) {
        case aCB16b64c:
        case aCB16b64c2b:
        case aCB16b64c4b:
        case BA16a64b4a:
        case BA16a64b2a:
        case BA16a64b: return 64;
        case aCB16b48c:
        case aCB16b48c2b:
        case aCB16b48c4b:
        case BA16a48b:
        case BA16a48b2a:
        case BA16a48b4a: return 48;
        case aCB16b32c:
        case aCB16b32c2b:
        case aCB16b32c4b:
        case BA16a32b:
        case BA16a32b2a:
        case BA16a32b4a: return 32;
        case aCB16b16c:
        case aCB16b16c2b:
        case aCB16b16c4b:
        case BA16a16b:
        case BA16a16b2a:
        case BA16a16b4a: return 16;
        default: return 64;
    }
}
// TODO: add support for post-ops with multiple binary and eltwise operations
bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;
    const auto ndims = dst_d.ndims();

    bool is_binary_po_per_oc_sp_bcast {};
    bool is_binary_po_channel_bcast {};
    bool is_binary_po_per_mb_w_bcast {};
    bool is_binary_po_per_w_bcast {};
    std::tie(is_binary_po_per_oc_sp_bcast, is_binary_po_channel_bcast,
            is_binary_po_per_mb_w_bcast, is_binary_po_per_w_bcast)
            = binary_injector_utils::bcast_strategies_present_tup(
                    post_ops.entry_, dst_d,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial,
                    broadcasting_strategy_t::per_mb_w,
                    broadcasting_strategy_t::per_w);
    const bool supported_binary_bcast
            = IMPLICATION(is_binary_po_per_oc_sp_bcast, ndims < 4)
            && IMPLICATION(
                    is_binary_po_channel_bcast, utils::one_of(ndims, 3, 4))
            && IMPLICATION(
                    is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
            && IMPLICATION(
                    is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4));
    return supported_binary_bcast
            && injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
                    {sum, eltwise, binary}, post_ops, &dst_d,
                    false /*sum_at_pos_0_only*/,
                    false /*sum_requires_scale_one*/,
                    false /*sum_requires_zp_zero*/,
                    {broadcasting_strategy_t::per_oc,
                            broadcasting_strategy_t::per_oc_spatial,
                            broadcasting_strategy_t::scalar,
                            broadcasting_strategy_t::per_mb_spatial,
                            broadcasting_strategy_t::per_mb_w,
                            broadcasting_strategy_t::per_w,
                            broadcasting_strategy_t::no_broadcast}));
}

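// Returns success only when the data-type combination is implemented on the
// given ISA.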
status_t check_isa_with_datatype(
        const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
    const bool ok = IMPLICATION(bm_conf_utils.is_f32(),
                            isa == avx512_core || bm_conf_utils.is_bf32())
            && IMPLICATION(bm_conf_utils.is_int8(),
                    one_of(isa, avx512_core_amx, avx512_core_vnni))
            && IMPLICATION(bm_conf_utils.is_bf16(),
                    one_of(isa, avx512_core_amx, avx512_core_bf16))
            && IMPLICATION(bm_conf_utils.is_f16(),
                    one_of(isa, avx512_core_amx_fp16, avx512_core_fp16))
            && IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
                    mayiuse(avx512_core_vnni));
    return ok ? status::success : status::unimplemented;
}

brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
        brgemm_matmul_conf_t &bgmmc, const cpu_isa_t isa,
        const primitive_attr_t &attr, bool A_any_layout, bool B_any_layout,
        bool C_any_layout, bool bias_any_layout)
    : bgmmc(bgmmc)
    , f32_dt(utils::everyone_is(f32, bgmmc.src_dt, bgmmc.wei_dt, bgmmc.dst_dt))
    , bf16_dt(utils::everyone_is(bf16, bgmmc.src_dt, bgmmc.wei_dt)
              && one_of(bgmmc.dst_dt, bf16, f32))
    , f16_dt(utils::everyone_is(f16, bgmmc.src_dt, bgmmc.wei_dt)
              && one_of(bgmmc.dst_dt, f16, f32))
    , int8_dt(utils::one_of(bgmmc.src_dt, u8, s8) && bgmmc.wei_dt == s8
              && one_of(bgmmc.dst_dt, u8, s8, s32, f32, bf16))
    , bf32_dt(f32_dt && attr.fpmath_mode_ == fpmath_mode::bf16
              && isa == avx512_core_amx)
    , A_any_layout(A_any_layout)
    , B_any_layout(B_any_layout)
    , C_any_layout(C_any_layout)
    , bias_any_layout(bias_any_layout)
    , plain_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ab, abc, abcd, abcde,
              abcdef, abcdefg, abcdefgh, abcdefghi, abcdefghij, abcdefghijk,
              abcdefghijkl))
    , transposed_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ba, acb, abdc,
              abced, abcdfe, abcdegf, abcdefhg, abcdefgih, abcdefghji,
              abcdefghikj, abcdefghijlk))
    , blocked_64n_B_layout_tag(pick_blocked_B_layout(64))
    , blocked_48n_B_layout_tag(pick_blocked_B_layout(48))
    , blocked_32n_B_layout_tag(pick_blocked_B_layout(32))
    , blocked_16n_B_layout_tag(pick_blocked_B_layout(16))
    , blocked_B_layouts_allowed(!utils::one_of(format_tag::undef,
              blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
              blocked_32n_B_layout_tag, blocked_16n_B_layout_tag))
    , n_blk_fixed((!B_any_layout) && blocked_B_layouts_allowed) {
    assert(int8_dt || bf16_dt || f16_dt || f32_dt || bf32_dt);
}

status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(
        memory_desc_t &B_md, bool init_n_tag) const {

    if (B_any_layout) {
        const int default_n_block = init_n_tag
                ? get_default_n_block(format_tag::undef)
                : bgmmc.N_blk;
        bgmmc.wei_tag = blocked_B_layouts_allowed
                ? this->pick_blocked_B_layout(default_n_block)
                : plain_tensor_layout_tag;
        if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented;

        CHECK(memory_desc_init_by_tag(B_md, bgmmc.wei_tag));
        const int dmax = nstl::min(bgmmc.ndims, 3);
        const memory_desc_wrapper B_d(&B_md);
        for (int d = 0; d < dmax; d++) {
            int dim = bgmmc.ndims - 1 - d;
            bgmmc.B_strides[d]
                    = bgmmc.b_dt_sz * B_d.blocking_desc().strides[dim];
        }
    } else {
        bgmmc.wei_tag = blocked_B_layouts_allowed
                ? memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, blocked_64n_B_layout_tag,
                        blocked_48n_B_layout_tag, blocked_32n_B_layout_tag,
                        blocked_16n_B_layout_tag)
                : memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, acbd, adbc);

        // For cases when the weights tensor is transposed but has
        // 'dim_size == 1', we can ignore the transposition and compute it as
        // a plain-format tensor. This removes the need to allocate a
        // scratchpad for copy_B.
        if (transposed_tensor_layout_tag == bgmmc.wei_tag) {
            memory_desc_t B_md_plain;
            const status_t status
                    = memory_desc_init_by_tag(B_md_plain, B_md.ndims, B_md.dims,
                            B_md.data_type, plain_tensor_layout_tag);
            if (status != status::success) return status;
            if (B_md_plain == B_md) bgmmc.wei_tag = plain_tensor_layout_tag;
        }

        if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented;
    }

    return status::success;
}

status_t brgemm_matmul_conf_utils_t::update_and_check_B_tag(
        memory_desc_t &B_md, int n_blk_size) const {

    if (n_blk_fixed && n_blk_size != bgmmc.wei_n_blk)
        return status::unimplemented;

    if (!(B_any_layout && blocked_B_layouts_allowed)) return status::success;

    return set_or_check_B_tag(B_md, false);
}

status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
        memory_desc_t &C_md, memory_desc_t &bias_md) const {
    if (A_any_layout) {
        const format_tag_t desired_A_tag = plain_tensor_layout_tag;
        CHECK(memory_desc_init_by_tag(A_md, desired_A_tag));
        bgmmc.src_tag = desired_A_tag;
    } else {
        bgmmc.src_tag = (this->is_bf16() || this->is_f32() || this->is_bf32()
                                || this->is_f16())
                ? memory_desc_matches_one_of_tag(A_md, plain_tensor_layout_tag,
                        transposed_tensor_layout_tag, acbd, adbc)
                : memory_desc_matches_one_of_tag(
                        A_md, plain_tensor_layout_tag, acbd);
    }

    if (C_any_layout) {
        const format_tag_t desired_C_tag = plain_tensor_layout_tag;
        CHECK(memory_desc_init_by_tag(C_md, desired_C_tag));
        bgmmc.dst_tag = desired_C_tag;
    } else {
        bgmmc.dst_tag = memory_desc_matches_one_of_tag(
                C_md, plain_tensor_layout_tag, acbd);
    }

    if (one_of(format_tag::undef, bgmmc.src_tag, bgmmc.dst_tag))
        return status::unimplemented;

    if (bgmmc.with_bias && bias_any_layout)
        CHECK(memory_desc_init_by_tag(bias_md, plain_tensor_layout_tag));

    return status::success;
}

status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {

    memory_desc_t want_B_md = B_md;
    // Set bits for all dimensions except the k dimension
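    // E.g. for a 2d B of shape (K, N): ndims == 2, so the mask is
    // (1 << 2) - 1 - (1 << 0) = 0b10, i.e. only the N-dimension bit is set.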
    const int compensation_mask
            = ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
    if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
        want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
        want_B_md.extra.compensation_mask = compensation_mask;
    }
    if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
        want_B_md.extra.flags
                |= memory_extra_flags::compensation_conv_asymmetric_src;
        want_B_md.extra.asymm_compensation_mask = compensation_mask;
    }

    if (B_any_layout) {
        B_md = want_B_md;
        return status::success;
    }

    return B_md == want_B_md ? status::success : status::unimplemented;
}

format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
        int n_blk) const {

    if (bgmmc.ndims > 3) return format_tag::undef;
    if (this->is_int8()) switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
            default: return format_tag::undef;
        }

    if (this->is_bf16()
            || (this->is_f16() && bgmmc.isa == avx512_core_amx_fp16))
        switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
            default: return format_tag::undef;
        }
    // Note: bf32 assumes f32 blocking
    if (this->is_f32() || this->is_bf32() || this->is_f16()) switch (n_blk) {
            case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
            case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
            case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
            case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
            default: return format_tag::undef;
        }
    return format_tag::undef;
}

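// Zero points are supported only as a single common value per tensor:
// attributes with default (all-zero) zero points map to 'none'; anything else
// is treated as a per-tensor broadcast.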
brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) {
    return attr.zero_points_.has_default_values(arg)
            ? brgemm_broadcast_t::none
            : brgemm_broadcast_t::per_tensor;
}

struct matmul_amx_blocking_params_t : public brgemm_matmul_conf_t {
    matmul_amx_blocking_params_t()
        : nthr_k_(0)
        , nthr_mnb_(0)
        , nthr_(0)
        , n_blk_(0)
        , n_chunk_size_(0)
        , n_chunk_elems_(0)
        , m_blk_(0)
        , m_chunk_size_(0)
        , m_chunk_elems_(0)
        , k_blk_(0)
        , k_chunk_size_(0)
        , k_chunk_elems_(0)
        , current_lda_(0)
        , need_buf_c_(false)
        , blocking_chunk_mem_size_(0)
        , efficiency_score_(0.0f) {}

    matmul_amx_blocking_params_t(const brgemm_matmul_conf_t &bgmmc)
        : brgemm_matmul_conf_t(bgmmc)
        , nthr_k_(nstl::max(nthr_k, 1))
        , nthr_mnb_(nthr / nthr_k_)
        , nthr_(nthr_mnb_ * nthr_k_)
        , n_blk_(N_blk)
        , n_chunk_size_(N_chunk_size)
        , n_chunk_elems_(n_blk_ * n_chunk_size_)
        , m_blk_(M_blk)
        , m_chunk_size_(M_chunk_size)
        , m_chunk_elems_(m_blk_ * m_chunk_size_)
        , k_blk_(K_blk)
        , k_chunk_size_(brgemm_batch_size)
        , k_chunk_elems_(k_blk_ * k_chunk_size_)
        , current_lda_(LDA)
        , need_buf_c_(use_buffer_c)
        , blocking_chunk_mem_size_(0)
        , efficiency_score_(0.0f) {}

    void set_blocking_parameters(int nthr_k, int n_blk, int n_chunk_size,
            int m_blk, int m_chunk_size);
    void update_configuration(brgemm_matmul_conf_t &bgmmc) const;
    float get_blocking_scores() const { return efficiency_score_; }

    static size_t L2_threshold();

private:
    // num threads for parallelism wrt k dimension
    int nthr_k_;
    // num threads for parallelism wrt m, n and batch dimensions
    int nthr_mnb_;
    int nthr_;
    dim_t n_blk_, n_chunk_size_, n_chunk_elems_;
    dim_t m_blk_, m_chunk_size_, m_chunk_elems_;
    dim_t k_blk_, k_chunk_size_, k_chunk_elems_;

    dim_t current_lda_;
    bool need_buf_c_;
    size_t blocking_chunk_mem_size_;
    float efficiency_score_;

    void update_k_blocking_dependent_params();
    dim_t get_actual_lda();
    bool is_buffer_c_required();
    size_t calculate_chunk_memory_size();
    float get_thread_balance_scores();
    float get_copied_data_reusage_scores();
    float get_L2_utilization_scores() const;
    float calculate_blocking_scores();
};

struct matmul_avx512_blocking_params_t {
    struct matmul_params_t {

        matmul_params_t(int m, int n, int k, int od)
            : M(m), N(n), K(k), batch(od) {}

        const int M;
        const int N;
        const int K;
        const int batch;
    };

    matmul_avx512_blocking_params_t(const matmul_params_t &m, const int nthr)
        : mp(m)
        , m_chunks(1)
        , m_blk(1)
        , m_tail(0)
        , n_chunks(1)
        , n_blk(1)
        , n_tail(0)
        , batch_size(1)
        , k_blk(1)
        , k_tail(0)
        , nthr_k(1)
        , nthr(nthr) {}

    matmul_avx512_blocking_params_t &operator=(
            const matmul_avx512_blocking_params_t &brgemm_params) {
        m_chunks = brgemm_params.m_chunks;
        m_blk = brgemm_params.m_blk;
        m_tail = brgemm_params.m_tail;
        n_chunks = brgemm_params.n_chunks;
        n_blk = brgemm_params.n_blk;
        n_tail = brgemm_params.n_tail;
        batch_size = brgemm_params.batch_size;
        k_blk = brgemm_params.k_blk;
        k_tail = brgemm_params.k_tail;
        nthr_k = brgemm_params.nthr_k;
        return *this;
    }

    const matmul_params_t &mp;
    int m_chunks, m_blk, m_tail;
    int n_chunks, n_blk, n_tail;
    int batch_size, k_blk, k_tail;
    int nthr_k;
    const int nthr;

    void update_params(int m_chunks_, int m_blk_, int n_chunks_, int n_blk_,
            int batch_size_, int k_blk_, int nthr_k_) {
        m_chunks = m_chunks_;
        m_blk = m_blk_;
        m_tail = mp.M % m_blk;
        n_chunks = n_chunks_;
        n_blk = n_blk_;
        n_tail = mp.N % n_blk;
        batch_size = batch_size_;
        k_blk = k_blk_;
        k_tail = mp.K % k_blk;
        nthr_k = nthr_k_;
    }

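    // Estimates the idle fraction of a thread block for 'work' items: with
    // less work than threads it is the share of threads left without work,
    // otherwise it penalizes the partially filled tail round.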
    float calculate_spatial_disbalance(size_t work, size_t thread_block) const {
        size_t mod = work % thread_block;
        size_t scalar = work < thread_block
                ? thread_block - mod
                : nstl::min(thread_block - mod, mod);
        return static_cast<float>(scalar) / thread_block;
    }

    float get_imbalance() const {
        const size_t cur_nthr = nthr / nthr_k;

        size_t parallel_work = get_parallel_work();
        const float parallel_work_disb
                = calculate_spatial_disbalance(parallel_work, cur_nthr);

        int m_work = (m_blk * div_up(mp.M, m_blk)) % mp.M;
        const float m_blk_disbalance = static_cast<float>(m_work) / mp.M;

        int num_n_blk = div_up(mp.N, n_blk);
        int par_n_chunks = div_up(num_n_blk, n_chunks);
        const float n_chunk_disbalance
                = (static_cast<float>(par_n_chunks) * n_chunks - num_n_blk)
                / num_n_blk;

        const float disbalance_nthr_k
                = calculate_spatial_disbalance(mp.K, nthr_k * k_blk);

        const float thread_allocation_disb
                = (cur_nthr * nthr_k) != static_cast<size_t>(nthr)
                ? (static_cast<float>(nthr) - cur_nthr * nthr_k) / nthr
                : 0;

        const float score
                = (parallel_work_disb + m_blk_disbalance + n_chunk_disbalance
                          + thread_allocation_disb + disbalance_nthr_k)
                / 5;

        return score;
    }

    size_t get_parallel_work() const {
        int m_elems = div_up(mp.M, m_blk * m_chunks);
        int n_elems = div_up(mp.N, n_blk * n_chunks);
        return static_cast<size_t>(m_elems) * n_elems * mp.batch;
    }

    inline dim_t get_actual_lda(bool use_buffer_a, dim_t a_dt_sz) const {
        if (!use_buffer_a) return mp.K;

        constexpr int bytes_in_cacheline = 64;
        const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
        dim_t lda = rnd_up(k_blk, elems_in_cacheline);
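        // Bump large power-of-2 leading dimensions by one cache line to
        // reduce cache-set aliasing (conflict misses) between rows.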
        const bool is_big_pow_2 = lda >= 512 && math::is_pow2(lda);
        if (is_big_pow_2) lda += elems_in_cacheline;
        return lda;
    }

    inline bool is_buffer_c_required(
            dim_t acc_dt, dim_t dst_dt, bool with_sum) const {
        const size_t k_chunk_elems = k_blk * batch_size;
        if (nthr_k > 1 && static_cast<size_t>(mp.K) > k_chunk_elems)
            return true;

        return ((acc_dt != dst_dt || with_sum)
                && (static_cast<size_t>(mp.K) > k_chunk_elems
                        || mp.K % k_blk > 0));
    }

    void update_configuration(brgemm_matmul_conf_t &bgmmc) const {
        bgmmc.M_blk = m_blk;
        bgmmc.M_chunk_size = m_chunks;
        bgmmc.N_blk = n_blk;
        bgmmc.N_chunk_size = n_chunks;

        bgmmc.K_blk = rnd_up(k_blk, bgmmc.required_k_granularity);
        bgmmc.brgemm_batch_size = batch_size;

        bgmmc.nthr_k = nthr_k;

        bgmmc.use_buffer_c = is_buffer_c_required(
                bgmmc.acc_dt, bgmmc.dst_dt, bgmmc.with_sum);
        bgmmc.LDA = (bgmmc.src_tag == acbd && !bgmmc.use_buffer_a
                        ? bgmmc.A_strides[1] / bgmmc.a_dt_sz
                        : get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz));
    }
};

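// Target working-set size for one blocking chunk: use 3/4 of the per-core L2
// cache, keeping some headroom for data outside the chunk.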
size_t matmul_amx_blocking_params_t::L2_threshold() {
    return 3 * platform::get_per_core_cache_size(2) / 4;
}

void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils,
        matmul_amx_blocking_params_t &best_blocking) {

    matmul_amx_blocking_params_t current_blocking(bgmmc);

    const int min_k_per_thread = 1024;
    const int max_k_parallel_work
            = div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
    const bool is_amx_xf16 = bgmmc.is_amx
            && (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
                    || bm_conf_utils.is_bf32());
    const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();

    const int max_nthr_k = is_amx_xf16 && bgmmc.batch == 1
            ? nstl::min(saturate(1, 7, bgmmc.nthr / 8), max_k_parallel_work)
            : 1;
    int iter = 0;
    for (int nthr_k = 1; nthr_k <= max_nthr_k; nthr_k++) {
        int num_M_blk = div_up(bgmmc.M, bgmmc.M_blk);
        int num_N_blk = div_up(bgmmc.N, bgmmc.N_blk);
        int k_parallel_work = nstl::min(max_k_parallel_work, nthr_k);
        int num_parallel_work
                = bgmmc.batch * num_M_blk * num_N_blk * k_parallel_work;
        const bool a_lot_of_parallel_work = num_parallel_work > 8 * bgmmc.nthr;
        const bool a_lot_of_parallel_work_lvl2
                = num_parallel_work > 16 * bgmmc.nthr;
        const bool low_parallelism
                = static_cast<float>(num_parallel_work) < 1.5f * bgmmc.nthr;
        const bool maybe_low_blocking
                = is_amx_int8 && bm_conf_utils.maybe_low_brg_blocking();
        const int min_M_blk
                = (maybe_low_blocking || low_parallelism) && bgmmc.M_blk > 32
                ? div_up(bgmmc.M_blk, 2)
                : bgmmc.M_blk;
        const int min_N_blk = low_parallelism && is_amx_xf16
                        && !bm_conf_utils.check_n_blk_fixed()
                        && bgmmc.N_blk > 32
                ? 32
                : bgmmc.N_blk;
        const int desired_M_chunk = nstl::min(
                (bgmmc.use_buffer_b || a_lot_of_parallel_work ? 4 : 1),
                num_M_blk);
        const int desired_N_chunk = nstl::min(a_lot_of_parallel_work_lvl2
                        ? 6
                        : (bgmmc.use_buffer_a || a_lot_of_parallel_work ? 4
                                                                        : 1),
                num_N_blk);

        std::unordered_set<int> mblk_candidates;
        for (int m_blk = bgmmc.M_blk; m_blk >= min_M_blk;
                m_blk = m_blk > 1 ? div_up(m_blk, 2) : m_blk - 1) {
            if (IMPLICATION(maybe_low_blocking, m_blk != bgmmc.M_blk))
                mblk_candidates.insert(m_blk);
        }

        if (bgmmc.M > 16) {
            // Add M block sizes that are multiples of 16 for consideration
            const int mul16_m_blk_max
                    = nstl::min(rnd_dn(static_cast<int>(bgmmc.M), 16), 64);
            const int mul16_m_blk_min = rnd_up(min_M_blk, 16);
            for (int m_blk = mul16_m_blk_max; m_blk >= mul16_m_blk_min;
                    m_blk -= 16) {
                mblk_candidates.insert(m_blk);
            }
        }

        for_(int n_blk = bgmmc.N_blk; n_blk >= min_N_blk; n_blk -= 16)
        for_(int m_blk : mblk_candidates)
        for_(int n_ch_sz = desired_N_chunk; n_ch_sz >= 1; n_ch_sz--)
        for (int m_ch_sz = desired_M_chunk; m_ch_sz >= 1; m_ch_sz--, iter++) {
            current_blocking.set_blocking_parameters(
                    nthr_k, n_blk, n_ch_sz, m_blk, m_ch_sz);
            if (current_blocking.get_blocking_scores()
                    > best_blocking.get_blocking_scores())
                best_blocking = current_blocking;
        }
    }
}

float compute_blocking_heuristic_avx512(brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils,
        const matmul_avx512_blocking_params_t::matmul_params_t &matmul,
        matmul_avx512_blocking_params_t &best_blocking) {

    const int nthr = bgmmc.nthr;

    const int max_m_blk = nstl::min(256, matmul.M);
    int min_m_blk = nstl::min(32, matmul.M);

    int n_blk = bgmmc.N_blk;
    const int n_chunks = div_up(matmul.N, n_blk);
    const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
    const int n_chunks_start = nstl::min(max_n_chunks, div_up(matmul.N, n_blk));

    // Note: do not extend K_blk for 'bwd_w' cases
    const bool use_extended_k_blk = matmul.K > 1024
            && (!bm_conf_utils.check_is_transposed(bgmmc.src_tag));
    int default_k_blk = use_extended_k_blk ? 1024 : 512;
    int k_blk = nstl::min(matmul.K, default_k_blk);
    int start_nthr_k = 1;

    // For cases with low parallel work, reduce 'min_m_blk' to improve the
    // potential parallelization balance.
    const size_t max_parallel = matmul.batch * n_chunks;
    const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
    if (low_parallel_work) {

        min_m_blk = nstl::min(matmul.M, 16);

        // 2nd level tuning for low parallel work cases:
        bool bwd_w_low_spatial_work
                = bm_conf_utils.check_is_transposed(bgmmc.src_tag)
                && matmul.M <= 512;
        bool low_spatial_work = matmul.M <= 40;
        if (low_spatial_work || bwd_w_low_spatial_work) {

            // Reduce n_blk size to increase the parallel space.
            // Note: over-reduction of n_blk size on 2d shapes when
            // n_chunks == 1 showed significant performance degradation.
            if (!bm_conf_utils.check_n_blk_fixed()
                    && IMPLICATION(n_chunks == 1, bgmmc.batch_ndims > 0))
                n_blk = nstl::min(matmul.N, 32);

            // Force plain B (wei) layout for small spatial sizes in FWD.
            // Note: this showed significant performance gain in WnD shapes.
            bool is_FWD = !(bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
                    || bm_conf_utils.check_is_transposed(bgmmc.src_tag));
            if (bgmmc.use_buffer_b && is_FWD) {
                bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b(false);
            }
        }

        // Parallelize across K for shapes with big 'K' dimension
        bool bwd_w_par_k_blk = bm_conf_utils.check_is_transposed(bgmmc.src_tag)
                && IMPLICATION(bm_conf_utils.is_bf16(), math::is_pow2(matmul.K))
                && matmul.K >= 2048;
        if (bwd_w_par_k_blk) {
            start_nthr_k = nstl::min(nthr, 4);
            assert(k_blk == nstl::min(matmul.K, 512));
        }
    }

    float best_imbalance = 1.f; // to be minimized
    for_(int nthr_k = start_nthr_k; nthr_k >= 1; --nthr_k)
    for_(int n_chunk_size = n_chunks_start; n_chunk_size >= 1; --n_chunk_size)
    for (int m_blk = max_m_blk; m_blk >= min_m_blk; --m_blk) {

        matmul_avx512_blocking_params_t cur_params(matmul, nthr);
        cur_params.update_params(
                1, m_blk, n_chunk_size, n_blk, 1, k_blk, nthr_k);

        float cur_imbalance = cur_params.get_imbalance();
        if (cur_imbalance < best_imbalance) {
            best_imbalance = cur_imbalance;
            best_blocking = cur_params;
        }
    }
    return best_imbalance;
}

status_t compute_blocking_heuristic(brgemm_matmul_conf_t &bgmmc,
        const brgemm_matmul_conf_utils_t &bm_conf_utils) {

    bgmmc.N_blk = nstl::min(static_cast<dim_t>(bgmmc.wei_n_blk), bgmmc.N);

    bgmmc.M_chunk_size = bgmmc.N_chunk_size = 1;

    if (bgmmc.is_amx) {

        // Configure matrix sizes
        const dim_t max_M = 64, min_M = 32;
        bgmmc.M_blk = 1;
        for (dim_t m_ = max_M; m_ >= min_M; m_--) {
            if (bgmmc.M % m_ == 0) {
                bgmmc.M_blk = m_;
                break;
            }
        }
        if (bgmmc.M_blk == 1) bgmmc.M_blk = nstl::min(bgmmc.M, max_M);

        // The AMX BRGEMM kernel requires (K_brgemm % 64 == 0 || K_brgemm < 64)
        // for the K_brgemm reduction value to avoid AMX tile re-configuration.
        // To satisfy this condition, the K_tail value is fixed to
        // K % wei_k_blk here.
        const bool fixed_K_tail_size
                = bgmmc.K % bgmmc.wei_k_blk > 0 && bgmmc.K > bgmmc.wei_k_blk;
        bgmmc.K_blk = bgmmc.K < bgmmc.wei_k_blk
                ? rnd_up(bgmmc.K, bgmmc.required_k_granularity)
                : fixed_K_tail_size ? bgmmc.wei_k_blk : bgmmc.K;
        bgmmc.brgemm_batch_size
                = nstl::max(bgmmc.K / bgmmc.K_blk, static_cast<dim_t>(1));
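        // E.g. with illustrative values K = 200, wei_k_blk = 64: the K tail
        // exists and K > wei_k_blk, so K_blk = 64, brgemm_batch_size = 3, and
        // the tail of 200 % 64 = 8 elements is handled via K_tail separately.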

        matmul_amx_blocking_params_t best_blocking(bgmmc);

        compute_blocking_heuristic_amx(bgmmc, bm_conf_utils, best_blocking);

        if (best_blocking.get_blocking_scores() == 0.0f)
            return status::unimplemented;

        best_blocking.update_configuration(bgmmc);

    } else {
        // TODO:
        // *) adjust K_BLK using 'rnd_up(bgmmc.K, bgmmc.required_k_granularity)'
        //    for non-f32 datatypes.
        // *) optimize param search complexity

        // Approach for selecting ideal 'blocking parameters':
        // M_blk:
        // - main param for having parallel_work optimally distributed.
        // - 'br_block' is a BRGeMM uKernel parameter derived from 'M_Blk',
        // however, there is no measured performance impact from small
        // variations in 'br_block' size.
        //
        // M_Chunks:
        // - no noticeable performance impact, i.e. for 'M_blk = M_Chunks *
        // M_Blk' with 'M_Chunks > 1', brgemm shows the same performance
        // results. Instead, choose a larger 'M_blk'.
        //
        // N_blk:
        // - ideally 64 (from 'get_default_n_block()').
        // - can be reduced to 32 to improve performance for some shapes, as
        // well as increasing parallelization search space.
        //
        // N_Chunks:
        // - No difference as long as thread/work balance is the same.
        // - Note: for A_Transposed cases using A_buffer (i.e. bwd-w): select
        // a higher count to increase performance - better for transposed
        // data reuse.
        //
        // K_blk:
        // - block size variation '512 <= K_blk < 1024' has negligible
        // performance difference. However, some cases benefit from a higher
        // block size.
        // - can parallelize if not enough work; notice: requires reduction!
        //
        // Batch_Size:
        // - unused.

        const matmul_avx512_blocking_params_t::matmul_params_t matmul(
                bgmmc.M, bgmmc.N, bgmmc.K, bgmmc.batch);

        matmul_avx512_blocking_params_t best_blocking(matmul, bgmmc.nthr);

        const float best_imbalance = compute_blocking_heuristic_avx512(
                bgmmc, bm_conf_utils, matmul, best_blocking);

        if (best_imbalance == 1.f) return status::unimplemented;

        best_blocking.update_configuration(bgmmc);
    }

    return status::success;
}

status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
        const matmul_desc_t &mmd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);

    bgmmc = zero<decltype(bgmmc)>();
    bgmmc.isa = isa;
    bgmmc.nthr = dnnl_get_max_threads();
    bgmmc.brg_type = brgemm_addr;

    bgmmc.src_dt = src_d.data_type();
    bgmmc.dst_dt = dst_d.data_type();
    bgmmc.wei_dt = weights_d.data_type();

    bgmmc.with_bias = mmd.bias_desc.format_kind != format_kind::undef;
    bgmmc.bia_dt = bgmmc.with_bias ? mmd.bias_desc.data_type : data_type::undef;
    bgmmc.s8s8_compensation_required
            = isa == avx512_core_vnni && bgmmc.src_dt == s8;
    bgmmc.ndims = dst_d.ndims();

    brgemm_matmul_conf_utils_t bm_conf_utils(bgmmc, isa, attr,
            src_d.format_kind() == format_kind::any,
            weights_d.format_kind() == format_kind::any,
            dst_d.format_kind() == format_kind::any,
            bias_md.format_kind == format_kind::any);

    CHECK(check_isa_with_datatype(isa, bm_conf_utils));

    bgmmc.is_amx = is_superset(isa, avx512_core_amx);
    bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
    bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);

    bgmmc.is_bf32 = bm_conf_utils.is_bf32();

    // Make BRGeMM compute MatMul as if it were in bfloat16, while the
    // down-conversion happens during copy-buffer computations.
    if (bgmmc.is_bf32) {
        bgmmc.src_dt = bf16;
        bgmmc.wei_dt = bf16;
        bgmmc.tr_a_dt_sz = types::data_type_size(bf16);
        bgmmc.tr_b_dt_sz = types::data_type_size(bf16);
    } else if (bm_conf_utils.is_f16() && bgmmc.isa == avx512_core_fp16) {
        // Similar to bf32, convert input data before compute
        bgmmc.src_dt = f32;
        bgmmc.wei_dt = f32;
        bgmmc.tr_a_dt_sz = types::data_type_size(f32);
        bgmmc.tr_b_dt_sz = types::data_type_size(f32);
    }

    bgmmc.acc_dt = bm_conf_utils.is_int8() ? s32 : f32;

    bgmmc.c_dt_sz = types::data_type_size(bgmmc.dst_dt);
    bgmmc.acc_dt_sz = types::data_type_size(bgmmc.acc_dt);
    if (bgmmc.with_bias) bgmmc.bias_dt_sz = types::data_type_size(bgmmc.bia_dt);

    const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
    bgmmc.with_scales = !src_scales.has_default_values()
            || !wei_scales.has_default_values();
    if (bgmmc.with_scales) {
        bgmmc.is_oscale_per_n = wei_scales.mask_ == 1 << (bgmmc.ndims - 1);

        // only common and per-oc-channel scales are supported
        const bool oscales_ok = wei_scales.mask_ == 0 || bgmmc.is_oscale_per_n;
        if (!oscales_ok) return status::unimplemented;
    }

    const auto &p = attr.post_ops_;
    bgmmc.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    bgmmc.with_eltwise = eltwise_ind != -1;
    const int binary_ind = p.find(primitive_kind::binary);
    bgmmc.with_binary = binary_ind != -1;

    if (!post_ops_ok(bgmmc, attr, dst_d)) return status::unimplemented;

    bgmmc.src_zp_type = get_zp_type(attr, DNNL_ARG_SRC);
    bgmmc.wei_zp_type = get_zp_type(attr, DNNL_ARG_WEIGHTS);
    bgmmc.dst_zp_type = get_zp_type(attr, DNNL_ARG_DST);

    if (!IMPLICATION(!bm_conf_utils.is_int8(),
                everyone_is(brgemm_broadcast_t::none, bgmmc.src_zp_type,
                        bgmmc.wei_zp_type, bgmmc.dst_zp_type)))
        return status::unimplemented;

    matmul_helper_t helper(src_d, weights_d, dst_d);

    bgmmc.batch_ndims = bgmmc.ndims - 2;
    bgmmc.M = helper.M();
    bgmmc.N = helper.N();
    bgmmc.K = helper.K();
    bgmmc.batch = helper.batch();
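    // Product of all batch dimensions except the outermost one (0 when there
    // is at most one batch dimension).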
    bgmmc.batch_without_first_dim
            = bgmmc.batch_ndims > 1 ? helper.batch() / dst_d.dims()[0] : 0;

    bgmmc.bcast_A_desc.set_params(
            src_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
    bgmmc.bcast_B_desc.set_params(
            weights_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);

    // Dispatch small shapes to VNNI for better performance
    const bool is_small_shapes = bgmmc.is_amx && bgmmc.ndims < 3
            && ((bgmmc.M == 1 && bgmmc.K == 256)
                    || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
                    || bgmmc.K <= 16);
    if (is_small_shapes) return status::unimplemented;

    // required granularity for k dimension
    bgmmc.required_k_granularity
            = bgmmc.is_amx ? data_type_vnni_granularity(bgmmc.wei_dt) : 1;
    if (bgmmc.required_k_granularity == 0) return status::unimplemented;
    bgmmc.wei_k_blk = data_type_vnni_simd_elems<avx512_core>(bgmmc.wei_dt);

    CHECK(bm_conf_utils.set_or_check_tags(src_md, dst_md, bias_md));
    CHECK(bm_conf_utils.set_or_check_B_tag(weights_md));

    bgmmc.req_wei_vnni_downconvert = bm_conf_utils.wei_down_convert_to_vnni();

    CHECK(attr.set_default_formats(&dst_md));

    bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);

    bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
    bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();

    bgmmc.transposed_A = (bm_conf_utils.check_is_transposed(bgmmc.src_tag)
            || bgmmc.src_tag == adbc);
    const bool lda_is_big_2pow
            = (bm_conf_utils.is_bf16()
                      || (bgmmc.is_amx && bm_conf_utils.is_f16()))
            && !bgmmc.transposed_A && math::is_pow2(bgmmc.K) && bgmmc.K >= 4096
            && bgmmc.M >= 1024;
    const bool is_copy_a_required
            = (bgmmc.is_amx
                      && ((bgmmc.K % bgmmc.required_k_granularity != 0)
                              || bm_conf_utils.is_bf32()))
            || (bm_conf_utils.is_f16() && isa == avx512_core_fp16)
            || bgmmc.wei_zp_type != brgemm_broadcast_t::none
            || bgmmc.transposed_A || lda_is_big_2pow;
    bgmmc.use_buffer_a = is_copy_a_required;

    // Copying only the part of A related to K_tail is supported when
    // is_copy_a_required == true, but current performance measurements show
    // worse results than copying the whole of A (especially for big K sizes).
    bgmmc.use_buffer_a_tail_only = false;

    const int dmax = nstl::min(bgmmc.ndims, 3);
    for (int d = 0; d < dmax; d++) {
        int dim = bgmmc.ndims - 1 - d;
        bgmmc.A_strides[d] = bgmmc.a_dt_sz * src_d.blocking_desc().strides[dim];
        bgmmc.B_strides[d]
                = bgmmc.b_dt_sz * weights_d.blocking_desc().strides[dim];
        bgmmc.C_strides[d] = bgmmc.c_dt_sz * dst_d.blocking_desc().strides[dim];
    }

    // BF32 'Hint' Heuristic:
    // Under the following conditions, F32 through AVX512_CORE performs better
    // than using BF32 arithmetic.
    if (bgmmc.is_bf32 && (bgmmc.M < 8)
            && ((bgmmc.wei_tag == abcd) || bm_conf_utils.is_any_B_layout()))
        return status::unimplemented;

    // Heuristic tries to optimize the following parameters:
    // - M_blk, M_Chunk
    // - N_blk, N_Chunk
    // - K_blk, batch_size
    // - nthr_K
    CHECK(compute_blocking_heuristic(bgmmc, bm_conf_utils));

    if (bgmmc.wei_n_blk > bgmmc.N_blk
            && IMPLICATION(
                    bgmmc.N == bgmmc.N_blk, bgmmc.N >= bgmmc.wei_n_blk)) {
        bgmmc.wei_n_blk = bgmmc.N_blk;
        CHECK(bm_conf_utils.update_and_check_B_tag(
                weights_md, bgmmc.wei_n_blk));

        bgmmc.req_wei_vnni_downconvert
                = bm_conf_utils.wei_down_convert_to_vnni();
    }

    CHECK(bm_conf_utils.set_B_flags(weights_md));

    bgmmc.M_tail = bgmmc.M % bgmmc.M_blk;
    bgmmc.N_tail = bgmmc.N % bgmmc.N_blk;
    bgmmc.K_tail = bgmmc.K > bgmmc.K_blk
            ? rnd_up(bgmmc.K % bgmmc.K_blk, bgmmc.required_k_granularity)
            : 0;

    bgmmc.LDB = bm_conf_utils.get_actual_LDB();
    bgmmc.LDD = bgmmc.dst_tag == acbd ? dst_d.blocking_desc().strides[2]
                                      : bgmmc.N;
    bgmmc.LDC
            = bgmmc.use_buffer_c && bgmmc.nthr_k <= 1 ? bgmmc.N_blk : bgmmc.LDD;

    init_aux_values(bgmmc, src_d, weights_d, dst_d);

    return status::success;
}

void init_aux_values(brgemm_matmul_conf_t &bgmmc,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
        const memory_desc_wrapper &dst_d) {

    bgmmc.M_chunk_elems = bgmmc.M_blk * bgmmc.M_chunk_size;
    bgmmc.N_chunk_elems = bgmmc.N_blk * bgmmc.N_chunk_size;
    bgmmc.K_chunk_elems = bgmmc.K_blk * bgmmc.brgemm_batch_size;
    bgmmc.M_chunks = div_up(bgmmc.M, bgmmc.M_chunk_elems);
    bgmmc.N_chunks = div_up(bgmmc.N, bgmmc.N_chunk_elems);
    bgmmc.K_chunks = div_up(bgmmc.K, bgmmc.K_chunk_elems);
    bgmmc.num_M_blocks = div_up(bgmmc.M, bgmmc.M_blk);
    bgmmc.num_N_blocks = div_up(bgmmc.N, bgmmc.N_blk);
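    // Number of K blocks that fall into the last (possibly partial) K chunk.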
    const int last_chunk_batch_size
            = (nstl::max(bgmmc.K, bgmmc.K_blk)
                      - (bgmmc.K_chunks - 1) * bgmmc.K_chunk_elems)
            / bgmmc.K_blk;
    bgmmc.brgemm_batch_tail_size
            = last_chunk_batch_size % bgmmc.brgemm_batch_size;

    bgmmc.buffer_c_chunk_sz = bgmmc.acc_dt_sz * bgmmc.LDC
            * (bgmmc.nthr_k > 1 ? bgmmc.M : bgmmc.M_blk);
    bgmmc.buffer_c_per_thread_sz = bgmmc.buffer_c_chunk_sz
            * (bgmmc.nthr_k > 1 ? 1 : bgmmc.M_chunk_size * bgmmc.N_chunk_size);

    bgmmc.buffer_a_chunk_sz = bgmmc.tr_a_dt_sz * bgmmc.M_blk
            * (bgmmc.use_buffer_a_tail_only ? bgmmc.wei_k_blk : bgmmc.LDA);
    bgmmc.buffer_a_chunk_shift_along_m = bgmmc.buffer_a_chunk_sz
            * (bgmmc.use_buffer_a_tail_only ? 1 : bgmmc.brgemm_batch_size);
    bgmmc.buffer_a_per_thread_sz
            = bgmmc.buffer_a_chunk_shift_along_m * bgmmc.M_chunk_size;

    bgmmc.buffer_b_chunk_sz = bgmmc.tr_b_dt_sz * bgmmc.LDB
            * rnd_up(bgmmc.K_blk, bgmmc.wei_k_blk);
    bgmmc.buffer_b_per_thread_sz
            = bgmmc.buffer_b_chunk_sz * bgmmc.brgemm_batch_size;

    bgmmc.s8s8_comp_ithr_str
            = bgmmc.use_buffer_b ? bgmmc.wei_n_blk * bgmmc.N_chunk_size : 0;
    bgmmc.s8s8_comp_b_str = bgmmc.use_buffer_b
            ? 0
            : div_up(bgmmc.N, bgmmc.wei_n_blk) * bgmmc.wei_n_blk;
    bgmmc.s8s8_comp_n_str = bgmmc.wei_n_blk;

    bgmmc.A_ptr_shift_b = 0;
    bgmmc.copy_A_src_stride = 0;
    if (bgmmc.src_tag == acbd || bgmmc.src_tag == adbc) {
        const dim_t factor = bgmmc.src_dt == f32 ? 2 : 1;
        const dim_t src_stride = bgmmc.src_tag == acbd ? bgmmc.A_strides[1]
                                                       : bgmmc.A_strides[0];
        bgmmc.copy_A_src_stride = nstl::min(src_d.blocking_desc().strides[0],
                                          src_stride / factor)
                * factor;
        const dim_t bcast_shift_b = bgmmc.src_tag == acbd ? bgmmc.K : bgmmc.M;
        bgmmc.A_ptr_shift_b
                = (bgmmc.bcast_A_desc.bcast_mask == 2
                                  ? bcast_shift_b
                                  : src_d.blocking_desc().strides[0])
                * bgmmc.a_dt_sz;
    }

    bgmmc.B_ptr_shift_b = 0;
    bgmmc.copy_B_wei_stride = 0;
    if (one_of(bgmmc.wei_tag, acbd, adbc)) {
        const dim_t factor = bgmmc.wei_dt == f32 ? 2 : 1;
        const dim_t wei_stride = bgmmc.wei_tag == acbd ? bgmmc.B_strides[1]
                                                       : bgmmc.B_strides[0];
        bgmmc.copy_B_wei_stride = nstl::min(wei_d.blocking_desc().strides[0],
                                          wei_stride / factor)
                * factor;
        const dim_t bcast_shift_b = bgmmc.wei_tag == acbd ? bgmmc.N : bgmmc.K;
        bgmmc.B_ptr_shift_b
                = (bgmmc.bcast_B_desc.bcast_mask == 2
                                  ? bcast_shift_b
                                  : wei_d.blocking_desc().strides[0])
                * bgmmc.b_dt_sz;
    }

    bgmmc.C_ptr_shift_b = bgmmc.dst_tag == acbd
            ? dst_d.blocking_desc().strides[0] * bgmmc.c_dt_sz
            : 0;

    bgmmc.has_zero_point_a = bgmmc.src_zp_type != brgemm_broadcast_t::none;
    bgmmc.has_zero_point_b = bgmmc.wei_zp_type != brgemm_broadcast_t::none;
    bgmmc.has_zero_point_c = bgmmc.dst_zp_type != brgemm_broadcast_t::none;
    bgmmc.post_ops_applicable = one_of(true, bgmmc.with_sum, bgmmc.with_bias,
            bgmmc.with_scales, bgmmc.with_eltwise, bgmmc.with_binary,
            bgmmc.acc_dt != bgmmc.dst_dt, bgmmc.s8s8_compensation_required,
            bgmmc.has_zero_point_a, bgmmc.has_zero_point_b,
            bgmmc.has_zero_point_c);

    bgmmc.zp_a_comp_shift_n = bgmmc.wei_n_blk;
    bgmmc.zp_a_comp_elems_per_thr
            = bgmmc.N_chunk_size * bgmmc.zp_a_comp_shift_n;

    const int s32_elems_in_cacheline = 16;
    bgmmc.zp_b_comp_result_shift_m = bgmmc.M_blk;
    bgmmc.zp_b_comp_buffer_start
            = bgmmc.M_chunk_size * bgmmc.zp_b_comp_result_shift_m;
    bgmmc.zp_b_comp_buffer_shift_m = s32_elems_in_cacheline * bgmmc.M_blk;
    bgmmc.zp_b_comp_elems_per_thr = bgmmc.M_chunk_size
            * (bgmmc.zp_b_comp_result_shift_m + bgmmc.zp_b_comp_buffer_shift_m);

    bgmmc.brgemm_batch_element_per_thr_sz = 16 * bgmmc.brgemm_batch_size;
}

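// Books all scratchpad buffers required by the chosen configuration: brgemm
// batch elements, copy buffers for A/B/C, s8s8 and zero-point compensation
// buffers, and the AMX tile workspace.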
void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const brgemm_matmul_conf_t &bgmmc) {
    const size_t default_data_align = sizeof(char);
    if (bgmmc.brg_type == brgemm_addr)
        scratchpad.book(key_brgemm_primitive_batch,
                static_cast<size_t>(bgmmc.nthr)
                        * bgmmc.brgemm_batch_element_per_thr_sz,
                sizeof(brgemm_batch_element_t), 64);

    if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
        scratchpad.book(key_brgemm_primitive_buffer_a,
                bgmmc.nthr * bgmmc.buffer_a_per_thread_sz, default_data_align);

    if (bgmmc.use_buffer_b) {
        scratchpad.book(key_brgemm_primitive_buffer_b,
                bgmmc.nthr * bgmmc.buffer_b_per_thread_sz, default_data_align);

        if (bgmmc.s8s8_compensation_required && (!bgmmc.blocked_B))
            scratchpad.book(key_brgemm_primitive_buffer_comp,
                    bgmmc.nthr * bgmmc.s8s8_comp_ithr_str,
                    types::data_type_size(f32));
    }

    if (bgmmc.use_buffer_c)
        scratchpad.book(key_brgemm_primitive_buffer,
                bgmmc.nthr * bgmmc.buffer_c_per_thread_sz, default_data_align);

    if (bgmmc.has_zero_point_a) {
        const auto num_elems = bgmmc.nthr * bgmmc.zp_a_comp_elems_per_thr;
        scratchpad.book(key_brgemm_primitive_zp_comp_a, num_elems,
                types::data_type_size(s32));
    }

    if (bgmmc.has_zero_point_b)
        scratchpad.book(key_brgemm_primitive_zp_comp_b,
                bgmmc.nthr * bgmmc.zp_b_comp_elems_per_thr,
                types::data_type_size(s32));

    if (is_superset(bgmmc.isa, avx512_core_amx))
        scratchpad.book(key_conv_amx_tile_buffer,
                static_cast<size_t>(bgmmc.nthr) * bgmmc.wsp_tile_per_thr_bytes,
                default_data_align);
}

void matmul_amx_blocking_params_t::update_k_blocking_dependent_params() {
    k_chunk_elems_ = k_blk_ * k_chunk_size_;
    current_lda_ = get_actual_lda();
    need_buf_c_ = is_buffer_c_required();
}

void matmul_amx_blocking_params_t::set_blocking_parameters(
        int nthr_k, int n_blk, int n_chunk_size, int m_blk, int m_chunk_size) {
    nthr_k_ = nstl::max(1, nthr_k);
    nthr_mnb_ = nthr / nthr_k_;
    nthr_ = nthr_mnb_ * nthr_k_;
    n_blk_ = n_blk;
    n_chunk_size_ = n_chunk_size;
    m_blk_ = m_blk;
    m_chunk_size_ = m_chunk_size;
    if (one_of(0, n_blk_, n_chunk_size_, m_blk_, m_chunk_size_)) {
        k_blk_ = k_chunk_size_ = k_chunk_elems_ = 0;
        efficiency_score_ = 0.0f;
        return;
    }

    n_chunk_elems_ = n_blk_ * n_chunk_size_;
    m_chunk_elems_ = m_blk_ * m_chunk_size_;

    if (K < wei_k_blk) {
        k_blk_ = is_amx ? rnd_up(K, required_k_granularity) : K;
        k_chunk_size_ = 1;
    } else {
        dim_t k_per_thr = div_up(K, nthr_k_);
        k_blk_ = nstl::min(
                is_amx ? rnd_up(k_per_thr, required_k_granularity) : k_per_thr,
                static_cast<dim_t>(wei_k_blk));
        k_chunk_size_ = nstl::min(nstl::max(static_cast<dim_t>(1), K / k_blk_),
                div_up(k_per_thr, k_blk_));

        update_k_blocking_dependent_params();
        auto chunk_sz = calculate_chunk_memory_size();
        float k_div = (float)chunk_sz / L2_threshold();
        if (k_div > 1.0f)
            k_chunk_size_ = static_cast<int>(
                    static_cast<float>(k_chunk_size_) / k_div + 0.6f);

        const dim_t current_k_tail = K % k_blk_;
        if (current_k_tail == 0 && K % (k_blk_ * k_chunk_size_) == 0) {
            k_blk_ *= k_chunk_size_;
            k_chunk_size_ = 1;
        } else if (nthr_k_ == 1
                && K == k_blk_ * k_chunk_size_ + current_k_tail) {
            k_blk_ *= k_chunk_size_;
            k_chunk_size_ = 2;
        }
    }

    update_k_blocking_dependent_params();

    blocking_chunk_mem_size_ = calculate_chunk_memory_size();

    efficiency_score_ = calculate_blocking_scores();
}

// Returns a score in the range [0, 1] for the distribution of parallel work
// over threads with the current blocking parameters. The maximum score is
// reached when all threads get the same amount of work without tails.
float matmul_amx_blocking_params_t::get_thread_balance_scores() {
    dim_t num_M_chunks = div_up(M, m_chunk_elems_);
    dim_t num_N_chunks = div_up(N, n_chunk_elems_);
    float mnb_parallel_score = batch * ((float)M / m_chunk_elems_)
            * ((float)N / n_chunk_elems_)
            / rnd_up(batch * num_M_chunks * num_N_chunks, nthr_mnb_)
            * nthr_mnb_;
    float k_parallel_score = 1.0f;
    if (nthr_k_ > 1) {
        dim_t num_K_chunks = div_up(K, k_chunk_elems_);
        const float parallel_reduction_penalty = 0.8f;
        k_parallel_score = parallel_reduction_penalty
                * ((float)K / k_chunk_elems_) / rnd_up(num_K_chunks, nthr_k_)
                * nthr_k_;
    }

    return mnb_parallel_score * k_parallel_score / nthr;
}

// Returns a score in the range [0, 1] for the reuse of copied data with the
// current blocking parameters.
float matmul_amx_blocking_params_t::get_copied_data_reusage_scores() {
    const int desired_M_chunk = use_buffer_b
            ? nstl::min(4, rnd_up(static_cast<int>(M), m_blk_))
            : 1;
    const int desired_N_chunk = use_buffer_a
            ? nstl::min(4, rnd_up(static_cast<int>(N), n_blk_))
            : 1;

    return 0.5f
            * (nstl::min((float)m_chunk_size_ / desired_M_chunk, 1.0f)
                    + nstl::min((float)n_chunk_size_ / desired_N_chunk, 1.0f));
}

// Returns a score in the range [0, 1] for L2 utilization with the current
// blocking parameters.
float matmul_amx_blocking_params_t::get_L2_utilization_scores() const {
    const float relative_difference_with_L2
            = fabsf((float)L2_threshold() - blocking_chunk_mem_size_)
            / nstl::max(L2_threshold(), blocking_chunk_mem_size_);
    return 1.0f - relative_difference_with_L2;
}

// Returns a score in the range [0, 1] for the current blocking parameters,
// composed of 3 parts, each with its own weight:
// 1) parallel work over threads distribution score
// 2) L2 utilization score
// 3) copied data reuse score
float matmul_amx_blocking_params_t::calculate_blocking_scores() {
    if (one_of(0, n_blk_, n_chunk_size_, m_blk_, m_chunk_size_, k_blk_,
                k_chunk_size_))
        return 0.0f;

    const float nthr_coeff = nstl::min(nthr, 100);
    const float reusage_factor = 1.0f;
    const float balance_factor = (nthr_coeff - 1.0f) / nthr_coeff;
    const float cache_utilization_factor = 1.0f / nthr_coeff;

    float scores = cache_utilization_factor * get_L2_utilization_scores()
            + reusage_factor * get_copied_data_reusage_scores();
    if (balance_factor > 0.0f)
        scores += balance_factor * get_thread_balance_scores();
    return scores
            / (reusage_factor + balance_factor + cache_utilization_factor);
}

void matmul_amx_blocking_params_t::update_configuration(
        brgemm_matmul_conf_t &bgmmc) const {
    bgmmc.nthr_k = nthr_k_;
    bgmmc.M_blk = m_blk_;
    bgmmc.M_chunk_size = m_chunk_size_;
    bgmmc.N_blk = n_blk_;
    bgmmc.N_chunk_size = n_chunk_size_;

    bgmmc.K_blk = k_blk_;
    bgmmc.brgemm_batch_size = k_chunk_size_;

    bgmmc.use_buffer_c = need_buf_c_;
    bgmmc.LDA = current_lda_;
}

dim_t matmul_amx_blocking_params_t::get_actual_lda() {
    if (!use_buffer_a) return src_tag == acbd ? A_strides[1] / a_dt_sz : K;

    constexpr int bytes_in_cacheline = 64;
    const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
    dim_t lda = rnd_up(k_blk_, elems_in_cacheline);
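    // Same as in the avx512 blocking helper above: pad large power-of-2
    // strides by one cache line to avoid cache-set aliasing.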
    const bool is_big_2_pow = lda >= 512 && math::is_pow2(lda);
    if (is_big_2_pow) lda += elems_in_cacheline;
    return lda;
}

bool matmul_amx_blocking_params_t::is_buffer_c_required() {
    if (nthr_k_ > 1 && K > k_chunk_elems_) return true;

    return ((acc_dt != dst_dt || with_sum)
            && (K > k_chunk_elems_ || K % k_blk_ > 0));
}

size_t matmul_amx_blocking_params_t::calculate_chunk_memory_size() {
    size_t A_chunk_sz = a_dt_sz * k_chunk_elems_ * m_chunk_elems_;
    size_t A_buf_sz = use_buffer_a
            ? tr_a_dt_sz * current_lda_ * k_chunk_size_ * m_chunk_elems_
            : 0;
    size_t B_chunk_sz = b_dt_sz * k_chunk_elems_ * n_chunk_elems_;
    size_t B_buf_sz = use_buffer_b ? tr_b_dt_sz * n_blk_ * k_chunk_elems_ : 0;
    size_t C_chunk_sz = c_dt_sz * m_chunk_elems_ * n_chunk_elems_;
    size_t C_buf_sz
            = need_buf_c_ ? acc_dt_sz * m_chunk_elems_ * n_chunk_elems_ : 0;
    return A_chunk_sz + A_buf_sz + B_chunk_sz + B_buf_sz + C_chunk_sz
            + C_buf_sz;
}

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl