/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/jit_brgemm_inner_product_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace prop_kind;
using namespace data_type;

namespace brgemm_inner_product_utils {

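// The kernel index below is a 5-bit encoding of the tail/initialization flags:
// bit 4 = batch-size tail, bit 3 = accumulator initialization, bit 2 = M tail,
// bit 1 = N tail, bit 0 = K tail. For example, a kernel that initializes the
// accumulator and handles an N tail (all other flags false) gets index
// 8 + 2 = 10.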
int get_brg_kernel_index(const jit_brgemm_primitive_conf_t &jbgp,
        bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail,
        bool is_K_tail) {
    int idx = 16 * (int)is_bs_tail + 8 * (int)do_initialization
            + 4 * (int)is_M_tail + 2 * (int)is_N_tail + (int)is_K_tail;

    assert(idx < max_num_brg_kernels_ip);
    return idx;
}

int get_os_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust,
        bool is_adjustment) {
    const bool is_amx_int8
            = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8);
    const bool is_amx_xf16 = is_superset(jbgp.isa, avx512_core_amx)
            && one_of(jbgp.wei_dt, bf16, f16);
    const bool is_avx512_bf16 = jbgp.isa == avx512_core_bf16;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_bf32 = jbgp.is_bf32;

    int max_os_block = 0;
    int min_os_block = 0;

    if (try_to_adjust
            || one_of(jbgp.prop_kind, forward_training, forward_inference)) {
        min_os_block = (is_amx_int8 || is_amx_xf16) ? 16 : 6;
        // Currently the "gigantic" flag is used to separate out transformer_lt
        // and alexnet shapes, for which a larger os_block gives better
        // performance.
        // TODO: Figure out how much the constraints for "gigantic-ness" can
        // be loosened further.
        const bool is_gigantic_shape
                = jbgp.ic >= 9216 && jbgp.oc >= 4096 && jbgp.os >= 512;
        const bool use_128_block_for_amx
                = is_amx_xf16 && jbgp.os % 128 == 0 && jbgp.oc > 128;
        const bool enable_128_os_blocking
                = use_128_block_for_amx || is_gigantic_shape;
        max_os_block = enable_128_os_blocking ? 128 : 64;
        // Work done by each thread is given by:
        //     (nb_oc / nb_oc_blocking) * (nb_os / nb_os_blocking)
        // As a first approximation we take nb_oc_blocking = nb_os_blocking = 1.
        // Furthermore, we recall that
        //     nb_oc = oc / oc_block
        //     nb_os = os / os_block
        //
        // For the f32 data type our objective is to determine the optimal
        // value of os_block such that the work amount per thread is ~2.
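        // For example, with os = 256, nb_oc = 2 and nthr = 56:
        // div_up(256, 64) * 2 = 8 < 1.8 * 56, so the block is shrunk to
        // saturate(16, 64, div_up(256 * 2, 2 * 56)) = 16.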
        if (is_f32 && !is_bf32) {
            const bool small_work_amt_per_thread
                    = div_up(jbgp.os, max_os_block) * jbgp.nb_oc
                    < 1.8f * jbgp.nthr;
            if (small_work_amt_per_thread)
                max_os_block = saturate(16, max_os_block,
                        div_up(jbgp.os * jbgp.nb_oc, 2 * jbgp.nthr));
        }
    } else if (jbgp.prop_kind == backward_data) {
        int plat_max_os_block = 0;
        if (is_amx_xf16) {
            plat_max_os_block
                    = (jbgp.ic >= 512 && jbgp.oc / jbgp.ic <= 4) ? 128 : 64;
        } else if (is_avx512_bf16) {
            plat_max_os_block = (jbgp.ic > 256) ? 128 : 64;
        } else {
            plat_max_os_block = 64;
        }
        max_os_block = nstl::min(plat_max_os_block, jbgp.os);
        min_os_block = is_amx_xf16 ? 16 : 6;
    } else if (jbgp.prop_kind == backward_weights) {
        constexpr int amx_xf16_row = 64;
        constexpr int amx_xf16_half_row = amx_xf16_row / 2;
        // ensure that os_tail <= amx_xf16_half_row
        const bool use_large_os_block = (jbgp.os >= amx_xf16_row)
                && (jbgp.os % amx_xf16_row) <= amx_xf16_half_row;
        return is_amx_xf16
                ? (use_large_os_block ? amx_xf16_row : amx_xf16_half_row)
                : 16;
    } else
        assert(!"unsupported case");

    if (is_adjustment) max_os_block /= 2;
    int os_block = 1;
    for (int osb = max_os_block; osb >= min_os_block; osb--) {
        if (osb == 0) break;
        if (jbgp.os % osb == 0) {
            os_block = osb;
            break;
        }
    }
    if (os_block == 1) os_block = nstl::min(jbgp.os, max_os_block);

    return os_block;
}

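// The three candidate tags in each branch below correspond to oc blocks of
// 64, 32 and 16, in that order; get_oc_block() and get_brgemm_ip_weights_tag()
// rely on this ordering.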
std::vector<format_tag_t> get_desired_weights_tag(
        const jit_brgemm_primitive_conf_t &jbgp) {
    using namespace format_tag;
    const int n_sp_dims = jbgp.ndims - 2;
    const bool is_xf16 = utils::one_of(jbgp.wei_dt, bf16, f16);
    const bool is_not_vnni_tag = jbgp.wei_dt == f32
            || (jbgp.wei_dt == f16 && jbgp.isa == avx512_core_fp16);
    if (is_not_vnni_tag) {
        return {pick(n_sp_dims, OI16i64o, OIw16i64o, OIhw16i64o, OIdhw16i64o),
                pick(n_sp_dims, OI16i32o, OIw16i32o, OIhw16i32o, OIdhw16i32o),
                pick(n_sp_dims, OI16i16o, OIw16i16o, OIhw16i16o, OIdhw16i16o)};
    } else if (is_xf16) {
        if (is_superset(jbgp.isa, avx512_core_amx)) {
            return {pick(n_sp_dims, OI16i64o2i, OIw16i64o2i, OIhw16i64o2i,
                            OIdhw16i64o2i),
                    pick(n_sp_dims, OI16i32o2i, OIw16i32o2i, OIhw16i32o2i,
                            OIdhw16i32o2i),
                    pick(n_sp_dims, OI16i16o2i, OIw16i16o2i, OIhw16i16o2i,
                            OIdhw16i16o2i)};
        } else {
            return {pick(n_sp_dims, OI8i64o2i, OIw8i64o2i, OIhw8i64o2i,
                            OIdhw8i64o2i),
                    pick(n_sp_dims, OI8i32o2i, OIw8i32o2i, OIhw8i32o2i,
                            OIdhw8i32o2i),
                    pick(n_sp_dims, OI8i16o2i, OIw8i16o2i, OIhw8i16o2i,
                            OIdhw8i16o2i)};
        }
    } else if (jbgp.wei_dt == data_type::s8) {
        if (jbgp.isa == avx512_core_amx) {
            return {pick(n_sp_dims, OI16i64o4i, OIw16i64o4i, OIhw16i64o4i,
                            OIdhw16i64o4i),
                    pick(n_sp_dims, OI16i32o4i, OIw16i32o4i, OIhw16i32o4i,
                            OIdhw16i32o4i),
                    pick(n_sp_dims, OI16i16o4i, OIw16i16o4i, OIhw16i16o4i,
                            OIdhw16i16o4i)};
        } else {
            return {pick(n_sp_dims, OI4i64o4i, OIw4i64o4i, OIhw4i64o4i,
                            OIdhw4i64o4i),
                    pick(n_sp_dims, OI4i32o4i, OIw4i32o4i, OIhw4i32o4i,
                            OIdhw4i32o4i),
                    pick(n_sp_dims, OI4i16o4i, OIw4i16o4i, OIhw4i16o4i,
                            OIdhw4i16o4i)};
        }
    } else {
        return std::vector<format_tag_t> {format_tag::undef};
    }
}

int get_oc_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust) {
    const bool amx_xf16_bwd_d_noadjust = !try_to_adjust
            && jbgp.prop_kind == backward_data
            && is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;
    if (amx_xf16_bwd_d_noadjust) {
        constexpr int amx_xf16_row = 64;
        return amx_xf16_row;
    } else if (!jbgp.is_wei_layout_any) {
        std::vector<format_tag_t> weights_tag = get_desired_weights_tag(jbgp);
        if (jbgp.wei_tag == weights_tag[0])
            return 64;
        else if (jbgp.wei_tag == weights_tag[1])
            return 32;
        else
            return 16;
    } else {
        if (jbgp.oc >= 64) {
            return 64;
        } else if (jbgp.oc >= 32) {
            return 32;
        } else {
            return 16;
        }
    }
}

int ip_fwd_get_nb_oc_blocking(
        const jit_brgemm_primitive_conf_t &jbgp, bool is_adjustment) {
    const int small_oc_threshold = 256;
    const int small_os_threshold = 8;
    if (jbgp.os <= small_os_threshold && jbgp.oc <= small_oc_threshold) {
        // For small problems compute all oc blocks as a single chunk to avoid
        // a parallel section.
        return div_up(jbgp.oc,
                (is_adjustment) ? ip_fwd_get_adjusted_oc_block(jbgp)
                                : get_oc_block(jbgp));
    } else
        return 1;
}

bool ip_fwd_adjust_thread_balance(const jit_brgemm_primitive_conf_t &jbgp) {
    if (IMPLICATION(jbgp.is_wei_layout_any,
                !is_superset(jbgp.isa, avx512_core_amx)))
        return false;

    int os_chunks = div_up(jbgp.os, get_os_block(jbgp, true, false));

    int nb_oc = div_up(jbgp.oc, get_oc_block(jbgp, true));
    int nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp);
    int oc_chunks = div_up(nb_oc, nb_oc_blocking);

    int work_amount = oc_chunks * os_chunks;
    const auto work_per_thread = work_amount / jbgp.nthr;
    float wb_ratio = static_cast<float>(work_amount % jbgp.nthr) / jbgp.nthr;

    // Return true if the work distribution between threads is significantly
    // imbalanced: the amount of work per thread is small and the last
    // iteration keeps fewer than half of the available threads busy for the
    // current block sizes.
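    // For example, work_amount = 9 and nthr = 8 gives work_per_thread = 1 and
    // wb_ratio = 1 / 8 = 0.125, so the function returns true and the callers
    // adjust the block sizes.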
    return (work_per_thread < 3 && wb_ratio > 0.0f && wb_ratio < .5f);
}

int ip_fwd_get_adjusted_oc_block(const jit_brgemm_primitive_conf_t &jbgp) {
    const bool is_amx_xf16
            = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;

    // We cannot change the block size for forward and weights update if the
    // (external) layout is set by the user; for backward data it can be chosen
    // differently from the external one in this case because of the copy
    // routine.
    const bool not_adjustable_oc_block_size
            = !jbgp.is_wei_layout_any && jbgp.prop_kind != backward_data;

    if (IMPLICATION(is_amx_xf16 || jbgp.is_bf32, not_adjustable_oc_block_size))
        return get_oc_block(jbgp);

    int oc_block = get_oc_block(jbgp, true);
    if (ip_fwd_adjust_thread_balance(jbgp)) {
        oc_block = (oc_block > 16) ? oc_block / 2 : oc_block;
    }

    constexpr int amx_bf16_half_row = 32;
    // ensure that oc_tail <= amx_bf16_half_row (requirement for brgemm kernel)
    while (jbgp.oc % oc_block > amx_bf16_half_row)
        oc_block /= 2;
    return oc_block;
}

format_tag_t get_brgemm_ip_weights_tag(cpu_isa_t isa,
        const jit_brgemm_primitive_conf_t &jbgp,
        const memory_desc_t &weights_md) {
    std::vector<format_tag_t> weights_tag = get_desired_weights_tag(jbgp);
    if (!jbgp.is_wei_layout_any) {
        return memory_desc_matches_one_of_tag(
                weights_md, weights_tag[0], weights_tag[1], weights_tag[2]);
    } else {
        const int oc_block = ip_fwd_get_adjusted_oc_block(jbgp);
        const int idx = (oc_block == 64 ? 0 : (oc_block == 32 ? 1 : 2));
        return weights_tag[idx];
    }
}

bool post_ops_ok(jit_brgemm_primitive_conf_t &jbgp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;

    return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
            {sum, eltwise, binary}, post_ops, &dst_d,
            false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
            true /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar,
                    broadcasting_strategy_t::no_broadcast}));
}

status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    const bool is_amx_int8
            = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8);
    const bool is_amx_xf16 = is_superset(jbgp.isa, avx512_core_amx)
            && one_of(jbgp.wei_dt, bf16, f16) && !jbgp.is_bf32;
    const bool is_int8 = one_of(jbgp.src_dt, u8, s8) && jbgp.wei_dt == s8;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    jbgp.is_amx = is_superset(jbgp.isa, avx512_core_amx);
    const auto &p = attr.post_ops_;
    jbgp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jbgp.with_eltwise = eltwise_ind != -1;
    const int binary_ind = p.find(primitive_kind::binary);
    jbgp.with_binary = binary_ind != -1;
    if (!post_ops_ok(jbgp, attr, dst_d)) return status::unimplemented;
    if (jbgp.with_scales) {
        const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
        jbgp.is_oc_scale = wei_scales.mask_ == 1 << 0;

        // only common and per-oc-channel scales are supported
        const bool wei_scales_ok = one_of(wei_scales.mask_, 0, 1 << 0);
        if (!wei_scales_ok) return status::unimplemented;
    }
    const int min_ic_divisor = is_amx_int8 ? 4 : is_amx_xf16 ? 2 : 1;

    jbgp.use_buffer_a = jbgp.ic % min_ic_divisor != 0;

    constexpr int amx_int8_row = 64;
    constexpr int amx_xf16_row = 32;
    jbgp.ic_block = (is_amx_int8) ? amx_int8_row
            : (is_amx_xf16) ? amx_xf16_row : jbgp.simd_w;
    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);

    // gemm-based inner product performs better when oc = 1
    if (is_f32 && !jbgp.is_bf32 && jbgp.oc == 1) return status::unimplemented;

    jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
    jbgp.nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp);

    jbgp.os_block = get_os_block(jbgp, false, false);
    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);

    jbgp.nb_os_blocking = 1;
    // Work done by each thread is given by:
    //     (nb_oc / nb_oc_blocking) * (nb_os / nb_os_blocking)
    // For f32 data type we want to increase the nb_os_blocking such that
    //   * 1 <= nb_os_blocking <= 8 AND nb_os_blocking <= nb_os
    //   * Work amount per thread ~ 2
    //   * NOTE: here nb_oc_blocking = 1 as os is large
    if (jbgp.os > 256 && is_f32 && !jbgp.is_bf32) {
        jbgp.nb_os_blocking = saturate(1, nstl::min(8, jbgp.nb_os),
                nstl::min(nstl::max(jbgp.oc / jbgp.os / 2, 1),
                        div_up(jbgp.nb_os * jbgp.nb_oc, 2 * jbgp.nthr)));
    }

    // NOTE: the comment about is_gigantic_shape is in get_os_block()
    const bool is_gigantic_shape = jbgp.oc >= 4096 && jbgp.os >= 512;
    const int num_work_to_parallel = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking)
            * div_up(jbgp.nb_os, jbgp.nb_os_blocking);

    // TODO: although the heuristic below produces good performance for fp32,
    // num_work_to_parallel needs to be compared with nthr (instead of nb_ic)
    // and os_block needs some further tuning.

    // Use parallel IC reduction for f32 if we have:
    // * very large input channels
    // * a work amount in the mb and oc dimensions that is small compared to
    //   nb_ic
    // * number of threads > 1
    // * not a "gigantic shape", since that already has a lot of parallelism
    //   in the mb and oc dimensions w/o enabling IC parallelism
    const bool use_parallel_ic_reduction = is_f32 && !jbgp.is_bf32
            && jbgp.ic > 1024 && num_work_to_parallel < jbgp.nb_ic
            && jbgp.nthr > 1 && !is_gigantic_shape;

    // For os > 256, compute all os blocks as a single chunk when performing
    // IC reduction. Note that this condition is empirical.
    if (use_parallel_ic_reduction && jbgp.os > 256 && jbgp.nb_os_blocking > 1)
        jbgp.nb_os_blocking = jbgp.nb_os;

    jbgp.nb_ic_blocking = 1;
    jbgp.nthr_ic_b = 1;
    const int k_blk = jbgp.is_bf32 ? amx_xf16_row : jbgp.ic_block;
    const int max_nb_ic_blocking = nstl::min(64, jbgp.nb_ic);
    if (IMPLICATION(!is_int8, jbgp.ic <= max_nb_ic_blocking * jbgp.ic_block)
            && everyone_is(1, jbgp.kw, jbgp.kh, jbgp.kd)
            && !jbgp.use_buffer_a) {
        // Optimization: the data & weights layouts allow generating a brgemm
        // kernel with K = ic & batch = 1
        // (K = rnd_dn(ic, ic_block), K_tail = ic % ic_block & batch = 1)
        // instead of K = ic_block & batch = nb_ic_blocking.
        jbgp.K = jbgp.ic <= jbgp.ic_block ? jbgp.ic : rnd_dn(jbgp.ic, k_blk);
        jbgp.nb_ic_blocking = jbgp.nb_ic;
        jbgp.gemm_batch_size = 1;
    } else if (!jbgp.use_buffer_a && use_parallel_ic_reduction) {
        const int min_chunk_sz = 16;
        const int num_min_chunk_sz = div_up(jbgp.nb_ic, min_chunk_sz);
        float reduce_work = 0.5f * num_min_chunk_sz * jbgp.nb_os
                + (float)num_min_chunk_sz / jbgp.nb_oc + 0.5f;
        const int reduce_thr_groups = jbgp.nb_ic >= 1024 ? 8 : 4;
        jbgp.nthr_ic_b
                = saturate(1, nstl::min(reduce_thr_groups, num_min_chunk_sz),
                        int(reduce_work));
        jbgp.nthr_ic_b = nstl::min(jbgp.nthr_ic_b, jbgp.nthr);
        if (jbgp.nthr_ic_b > 1) {
            jbgp.nb_ic_blocking = div_up(jbgp.nb_ic, jbgp.nthr_ic_b);
            jbgp.nb_ic_blocking /= div_up(jbgp.nb_ic_blocking, 64);
        }
        jbgp.gemm_batch_size = jbgp.nb_ic_blocking;
        jbgp.K = jbgp.ic_block;
    } else {
        // Note: here ic is divided into K blocks of the gemm batch
        const int ic_blks_per_k = div_up(k_blk, jbgp.ic_block);
        const int nb_k_blk = div_up(jbgp.ic, k_blk);
        const int max_nb_k_blocking = div_up(max_nb_ic_blocking, ic_blks_per_k);
        int nb_k_blocking = max_div(nb_k_blk, max_nb_k_blocking);
        const bool small_nb_k_blk = nb_k_blk <= max_nb_k_blocking;
        if (small_nb_k_blk && nb_k_blocking == 1)
            nb_k_blocking = max_nb_k_blocking;

        // For a non-small shape (i.e. one with nb_ic > 64) that has
        // gcd(nb_ic, 64) < 16, we manually set nb_ic_blocking = 64; the
        // coefficients 64 (used in max_nb_ic_blocking) and 16 are empirical.
        const int min_nb_k_blocking = small_nb_k_blk ? 1 : 16;
        if (nb_k_blocking < min_nb_k_blocking)
            nb_k_blocking = max_nb_k_blocking;

        jbgp.nb_ic_blocking = nb_k_blocking * ic_blks_per_k;
        jbgp.K = k_blk;
        jbgp.gemm_batch_size = nb_k_blocking;
    }

    // to avoid concurrent cache write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    if (is_amx_xf16 || jbgp.is_bf32) {
        if (ip_fwd_adjust_thread_balance(jbgp)) {
            // Adjust oc_block to improve thread balancing
            jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);
            jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
            jbgp.nb_oc_blocking = ip_fwd_get_nb_oc_blocking(jbgp, true);

            // Adjust os_block to improve thread balancing
            if (jbgp.oc <= 16
                    || types::data_type_size(jbgp.src_dt) * jbgp.mb * jbgp.ic
                            <= (size_t)platform::get_per_core_cache_size(2)) {
                jbgp.os_block = get_os_block(jbgp, false, true);
                jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);
            }
        }
    }
    jbgp.use_buffer = (IMPLICATION(jbgp.dst_dt == jbgp.acc_dt, jbgp.with_sum))
            || (jbgp.nthr_ic_b > 1);

    // Configure matrix sizes
    jbgp.M = jbgp.os_block;
    jbgp.M_tail = jbgp.os % jbgp.os_block;

    jbgp.N = jbgp.oc_block;
    jbgp.N_tail = jbgp.oc % jbgp.oc_block;
    jbgp.K_tail = jbgp.use_buffer_a ? 0 : jbgp.ic % jbgp.K;

    jbgp.LDA = jbgp.use_buffer_a ? jbgp.K * jbgp.gemm_batch_size
                                 : jbgp.ic_without_padding;
    jbgp.LDB = jbgp.N;
    jbgp.LDD = jbgp.oc_without_padding;
    jbgp.LDC = (jbgp.use_buffer && jbgp.nthr_ic_b == 1) ? jbgp.N : jbgp.LDD;
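    // Forward pass brgemm layout: M = os_block, N = oc_block, and K runs over
    // ic; A is a block of src (or of its copy when use_buffer_a is set), B is
    // a packed weights block, and C/D accumulate into the (os x oc)
    // destination. LDA/LDB/LDC/LDD above are the corresponding row strides.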

    if (jbgp.is_bf32) {
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.oc);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.ic);
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}

status_t init_ip_conf_bwd_d(jit_brgemm_primitive_conf_t &jbgp) {
    const bool is_amx_xf16
            = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;
    const bool is_avx512_bf16 = jbgp.isa == avx512_core_bf16;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_bf16 = everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt);
    jbgp.is_amx = is_amx_xf16;

    constexpr int amx_xf16_granularity = 2;
    jbgp.use_buffer_a = is_amx_xf16 && jbgp.oc % amx_xf16_granularity != 0;
    jbgp.use_buffer_b = true;
    jbgp.ip_bwd_d_global_b_transpose = false;

    jbgp.oc_block = ip_fwd_get_adjusted_oc_block(jbgp);

    // Optimization: for small shapes we avoid a large ic_block.
    // Thinking of os, ic, and oc as three dimensions, the boundary for small
    // shapes is heuristically chosen via the following constraints:
    //     os <= 128 && max(ic, oc) <= 2048 && min(ic, oc) <= 1000
    //
    // TODO: Would this optimization be useful for the bf16 data type?
    const bool avoid_max_ic_block = is_f32 && !jbgp.is_bf32 && jbgp.os <= 128
            && nstl::max(jbgp.ic, jbgp.oc) <= 2048
            && nstl::min(jbgp.ic, jbgp.oc) <= 1000;
    jbgp.ic_block = !avoid_max_ic_block
                    && jbgp.ic >= (is_f32 && !jbgp.is_bf32 ? 512 : 64)
            ? 64
            : jbgp.ic >= 32 ? 32 : 16;

    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
    jbgp.nb_ic_blocking = 1;
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);

    jbgp.os_block = get_os_block(jbgp, false, false);

    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);
    jbgp.nb_os_blocking = 1;
    int os_blocking_max = 2;
    for (int bl = os_blocking_max; bl >= 1; bl--)
        if (jbgp.nb_os % bl == 0) {
            jbgp.nb_os_blocking = bl;
            break;
        }

    if (is_amx_xf16 || jbgp.is_bf32) {
        const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
        const int work_amount = jbgp.nb_ic * os_chunks;
        float wb_ratio = (float)work_amount / (float)jbgp.nthr;
        if (wb_ratio != 1.f && wb_ratio < 2.f) {
            jbgp.ic_block
                    = (jbgp.ic_block > 16) ? jbgp.ic_block / 2 : jbgp.ic_block;
            jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
        }
    }

    jbgp.nb_oc_blocking = 1;
    const int oc_chunk_max_size = 64;
    for (int bl = oc_chunk_max_size; bl >= 1; bl--)
        if (jbgp.nb_oc % bl == 0) {
            jbgp.nb_oc_blocking = bl;
            break;
        }

    jbgp.nthr_oc_b = 1;
    const int num_work_to_parallel = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking)
            * div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    // Use oc reduction if we have
    // * very large output channels
    // * a small work amount available to each thread
    if ((num_work_to_parallel < 2 * jbgp.nthr
                || jbgp.oc > (is_bf16 || jbgp.is_bf32 ? 4096 : 1024))) {
        const int min_chunk_sz = (is_avx512_bf16) ? 32 : 16;
        const int num_min_chunk_sz = div_up(jbgp.nb_oc, min_chunk_sz);
        float reduce_work = 0.5f * num_min_chunk_sz * jbgp.nb_os
                + (float)num_min_chunk_sz / jbgp.nb_ic + 0.5f;

        // optimization for transformer_lt on CPX/SKX
        const int max_nthr_oc_b
                = (!is_amx_xf16 && !jbgp.is_bf32 && jbgp.oc > 32000)
                ? jbgp.nthr / 2
                : 4;
        jbgp.nthr_oc_b = saturate(1, nstl::min(max_nthr_oc_b, num_min_chunk_sz),
                int(reduce_work));
        jbgp.nthr_oc_b = nstl::min(jbgp.nthr_oc_b, jbgp.nthr);
        if (jbgp.nthr_oc_b > 1) {
            jbgp.nb_oc_blocking = div_up(jbgp.nb_oc, jbgp.nthr_oc_b);
            jbgp.nb_oc_blocking
                    /= div_up(jbgp.nb_oc_blocking, oc_chunk_max_size);
        }
    }
    jbgp.gemm_batch_size = jbgp.nb_oc_blocking;
    // to avoid concurrent cache write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    jbgp.use_buffer = jbgp.src_dt != jbgp.acc_dt || jbgp.nthr_oc_b > 1;

    jbgp.M = jbgp.os_block;
    jbgp.M_tail = jbgp.os % jbgp.os_block;

    jbgp.K = jbgp.oc_block;
    jbgp.N = jbgp.ic_block;
    jbgp.N_tail = jbgp.ic % jbgp.ic_block;
    jbgp.K_tail = jbgp.use_buffer_a ? 0 : jbgp.oc % jbgp.oc_block;

    jbgp.LDA = jbgp.use_buffer_a ? jbgp.K * jbgp.nb_oc_blocking
                                 : jbgp.oc_without_padding;
    jbgp.LDB = jbgp.N;
    jbgp.LDD = jbgp.ic_without_padding;
    jbgp.LDC = jbgp.use_buffer && jbgp.nthr_oc_b == 1 ? jbgp.N : jbgp.LDD;
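    // Backward-by-data brgemm layout: M = os_block, N = ic_block, and K runs
    // over oc; A is a block of diff_dst (or of its copy when use_buffer_a is
    // set), B is a transposed weights block built in buffer B, and C/D
    // accumulate into the (os x ic) diff_src.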

    if (jbgp.is_bf32) {
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.ic);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.oc);
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}

void thread_balance(const jit_brgemm_primitive_conf_t &j, int &nb_os_blocking_,
        int &nb_oc_blocking_, int &nb_ic_blocking_, int &nthr_, int &nthr_mb_,
        int &nthr_oc_b_, int &nthr_ic_b_) {
    nthr_ = nthr_mb_ = nthr_oc_b_ = nthr_ic_b_ = 1;
    nb_os_blocking_ = j.nb_os_blocking;
    nb_oc_blocking_ = j.nb_oc_blocking;
    nb_ic_blocking_ = j.nb_ic_blocking;

    const bool is_f32 = everyone_is(f32, j.src_dt, j.wei_dt, j.dst_dt);
    const bool is_xf16 = one_of(j.src_dt, bf16, f16) && (j.src_dt == j.dst_dt);

    const int max_threads = j.nthr;
    const int nthr = max_threads;
    auto calc_mem_cost = [=](int nb_os_blocking, int nb_oc_blocking,
                                 int nb_ic_blocking, int nthr_mb, int nthr_oc,
                                 int nthr_ic) {
        float src_size = static_cast<float>(j.ic) * j.mb;
        float dst_size = static_cast<float>(j.oc) * j.mb;
        float wei_size = static_cast<float>(j.ic) * j.oc;
        int os_chunks = div_up(j.nb_os, nb_os_blocking);
        int oc_chunks = div_up(j.nb_oc, nb_oc_blocking);
        int ic_chunks = div_up(j.nb_ic, nb_ic_blocking);

        float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;

        float oi_channels_ratio = 0;
        if (is_xf16) {
            oi_channels_ratio = ((j.oc > 3 * j.ic && os_chunks > 1)
                                        || (os_chunks == 1 && j.ic > j.oc))
                    ? src_size / dst_size
                    : dst_size / src_size;
        } else {
            oi_channels_ratio = src_size / dst_size;
        }

        auto get_src_coef = [=]() {
            if (is_f32) {
                float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
                src_coef *= types::data_type_size(j.src_dt);
                src_coef *= 4 * saturate(1, 4, div_up(j.ic, 1024));
                if (wei_compensation_scale < 2.0f)
                    src_coef += sqrtf(2.0f / wei_compensation_scale);
                return src_coef;
            }
            float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
            src_coef *= 4 * types::data_type_size(j.src_dt);
            if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;

            return src_coef;
        };

        auto get_dst_coef = [=]() {
            if (is_f32) {
                float dst_coef = types::data_type_size(j.dst_dt)
                        * nstl::max(oi_channels_ratio, 1.0f);
                return dst_coef;
            }

            return 2 * types::data_type_size(j.dst_dt)
                    * nstl::max(oi_channels_ratio, 1.0f);
        };

        auto get_wei_coef = [=]() {
            if (is_f32) {
                return nstl::max(
                        4.0f - j.mb / 2048 * wei_compensation_scale, 1.0f);
            }

            // limit the range of coefficient values to have more stable
            // behavior for extreme cases
            const float low_limit = 1.0f;
            const float upper_limit = 1024.0f;
            return utils::saturate(
                    low_limit, upper_limit, wei_compensation_scale);
        };

        float src_tr = 0.0f;
        if (j.use_buffer_a && !is_f32) {
            int src_tr_oc_par_work = div_up(os_chunks, nthr_mb)
                    * div_up(ic_chunks, nthr_ic) * nb_ic_blocking;
            src_tr = get_src_coef() * div_up(src_tr_oc_par_work, nthr_oc)
                    * nb_os_blocking * j.os_block * j.ic_block;
        }

        float dst_tr = 0.0f;
        if (j.use_buffer_b && !is_f32) {
            int dst_tr_ic_par_work = div_up(os_chunks, nthr_mb)
                    * div_up(oc_chunks, nthr_oc) * nb_oc_blocking;
            dst_tr = get_dst_coef() * div_up(dst_tr_ic_par_work, nthr_ic)
                    * nb_os_blocking * j.os_block * j.oc_block;
        }

        float src_v = get_src_coef() * div_up(os_chunks, nthr_mb)
                * div_up(ic_chunks, nthr_ic) * nb_os_blocking * j.os_block
                * nb_ic_blocking * j.ic_block;
        float dst_v = get_dst_coef() * div_up(os_chunks, nthr_mb)
                * div_up(oc_chunks, nthr_oc) * nb_os_blocking * j.os_block
                * nb_oc_blocking * j.oc_block;

        auto acc_dt_sz = types::data_type_size(j.acc_dt);
        float wei_v = get_wei_coef() * acc_dt_sz * div_up(oc_chunks, nthr_oc)
                * div_up(ic_chunks, nthr_ic) * nb_oc_blocking * j.oc_block
                * nb_ic_blocking * j.ic_block;

        float wei_r = 0;
        if (nthr_mb > 1) {
            auto wei_dt_sz = types::data_type_size(j.wei_dt);
            int wei_r_mb_par_work = div_up(oc_chunks, nthr_oc)
                    * div_up(ic_chunks, nthr_ic) * nb_oc_blocking
                    * nb_ic_blocking;
            wei_r = get_wei_coef() * div_up(wei_r_mb_par_work, nthr_mb)
                    * j.oc_block * j.ic_block
                    * (wei_dt_sz
                            + (is_f32 ? div_up(j.os, 1024) : 1) * nthr_mb
                                    * acc_dt_sz);
        }

        return src_tr + dst_tr + src_v + dst_v + wei_v + wei_r;
    };

    float best_mem_cost = calc_mem_cost(nb_os_blocking_, nb_oc_blocking_,
            nb_ic_blocking_, nthr_mb_, nthr_oc_b_, nthr_ic_b_);

    /* Set the range of values to try for the nb_oc_blocking/nb_ic_blocking
       parameters. Use powers-of-2 values to avoid potential issues at the
       stage of converting to the blocked weights layout. */
    auto get_blk_values
            = [](int max_blk_value, int init_blk, int dim_blk_limit) {
                  int val_1st = rnd_up_pow2(init_blk);
                  int val_end = nstl::min(max_blk_value, dim_blk_limit);
                  std::vector<int> values;
                  for (int val = val_1st; val <= val_end; val <<= 1)
                      values.push_back(val);
                  return values;
              };
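
    // For example, get_blk_values(4, 3, 16) rounds 3 up to the next power of
    // two and returns {4}, while get_blk_values(8, 1, 2) returns {1, 2}.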

    const int max_nb_oc_blocking_pow
            = j.ip_bwd_w_local_buffers_for_input_tensors ? 4 : j.nb_oc_blocking;
    auto nb_oc_blocking_values
            = get_blk_values(max_nb_oc_blocking_pow, j.nb_oc_blocking, j.nb_oc);
    const int max_nb_ic_blocking_pow
            = j.ip_bwd_w_local_buffers_for_input_tensors ? 4 : j.nb_ic_blocking;
    auto nb_ic_blocking_values
            = get_blk_values(max_nb_ic_blocking_pow, j.nb_ic_blocking, j.nb_ic);

    /* find the best thread distribution with lowest memory cost */
    const int min_osb_chunk = is_f32 ? 32 : is_xf16 ? 8 : 1;
    const int nthr_mb_max = nstl::min(nthr, div_up(j.nb_os, min_osb_chunk));
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        int nb_os_blocking = j.nb_os_blocking;
        int os_chunks = div_up(j.nb_os, nb_os_blocking);
        if (os_chunks < nthr_mb) {
            int coef = saturate(1, 4, 2 * j.mb / (j.oc + j.ic));
            int os_blocking_max = div_up(div_up(j.nb_os, coef), nthr_mb);
            for (int bl = os_blocking_max; bl >= 1; bl--)
                if (j.nb_os % bl == 0) {
                    nb_os_blocking = bl;
                    break;
                }
        }

        const int nthr_par = nthr / nthr_mb;

        for (auto nb_oc_blocking : nb_oc_blocking_values) {
            int num_oc_chunks = div_up(j.nb_oc, nb_oc_blocking);
            const int nthr_oc_b_max = nstl::min(nthr_par, num_oc_chunks);
            for_(int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b)
            for (auto nb_ic_blocking : nb_ic_blocking_values) {
                int num_ic_chunks = div_up(j.nb_ic, nb_ic_blocking);

                int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, num_ic_chunks);
                float mem_cost = calc_mem_cost(nb_os_blocking, nb_oc_blocking,
                        nb_ic_blocking, nthr_mb, nthr_oc_b, nthr_ic_b);
                if (mem_cost <= best_mem_cost) {
                    best_mem_cost = mem_cost;
                    nb_os_blocking_ = nb_os_blocking;
                    nb_oc_blocking_ = nb_oc_blocking;
                    nb_ic_blocking_ = nb_ic_blocking;
                    nthr_mb_ = nthr_mb;
                    nthr_oc_b_ = nthr_oc_b;
                    nthr_ic_b_ = nthr_ic_b;
                }
            }
        }
    }

    nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_;
}

status_t init_ip_conf_bwd_w(jit_brgemm_primitive_conf_t &jbgp) {
    const bool is_amx_xf16
            = is_superset(jbgp.isa, avx512_core_amx) && !jbgp.is_bf32;
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool has_weights_buffer = jbgp.wei_dt != jbgp.acc_dt;
    jbgp.is_amx = is_amx_xf16;

    const int amx_xf16_row = 64;
    const bool big_ic_blk_ok
            = is_f32 && jbgp.ic % (4 * jbgp.simd_w) == 0 && jbgp.mb <= 128;
    jbgp.ic_block = big_ic_blk_ok && !is_amx_xf16
            ? 4 * jbgp.simd_w
            : (is_amx_xf16 && has_weights_buffer) ? amx_xf16_row : jbgp.simd_w;
    jbgp.ic_block_ext
            = is_amx_xf16 || (jbgp.wei_dt == dnnl::impl::data_type::bf16) ? 32
                                                                          : 16;

    jbgp.oc_block = has_weights_buffer ? get_oc_block(jbgp)
                                       : ip_fwd_get_adjusted_oc_block(jbgp);
    jbgp.oc_block_ext = ip_fwd_get_adjusted_oc_block(jbgp);

    jbgp.os_block = get_os_block(jbgp, false, false);
    jbgp.nb_os = div_up(jbgp.os, jbgp.os_block);

    jbgp.nb_ic = div_up(jbgp.ic, jbgp.ic_block);
    jbgp.nb_oc = div_up(jbgp.oc, jbgp.oc_block);
    jbgp.nb_oc_blocking = 1;
    jbgp.nb_ic_blocking = jbgp.nb_ic % 2 ? 1 : 2;

    // Configure matrix sizes
    jbgp.M = jbgp.ic_block;
    jbgp.M_tail = jbgp.ic % jbgp.ic_block;

    jbgp.N = jbgp.oc_block;
    jbgp.N_tail = jbgp.oc % jbgp.oc_block;

    constexpr int amx_xf16_granularity = 2;
    // sanity check, must hold for transpose routines to work fine
    assert(IMPLICATION(is_amx_xf16, jbgp.os_block % amx_xf16_granularity == 0));
    const bool do_rnd_os = is_amx_xf16 && jbgp.os % amx_xf16_granularity != 0;

    jbgp.K = jbgp.os_block;
    jbgp.K_tail = (jbgp.os % jbgp.os_block) + (do_rnd_os ? 1 : 0);

    jbgp.nb_os_blocking = 1;
    int os_blocking_max = (is_amx_xf16 && jbgp.nb_os >= 64)
            ? (types::data_type_size(jbgp.src_dt) * jbgp.mb * jbgp.ic
                      < platform::get_per_core_cache_size(2))
                    ? 8
                    : 4
            : nstl::min(64, jbgp.nb_os);

    for (int bl = os_blocking_max; bl >= 1; bl--)
        if (jbgp.nb_os % bl == 0) {
            jbgp.nb_os_blocking = bl;
            break;
        }

    jbgp.use_buffer_a = true;
    const bool is_oc_big_2_pow = jbgp.oc >= 512 && math::is_pow2(jbgp.oc);
    const bool is_huge_oc = jbgp.oc >= 4 * 1024;
    jbgp.use_buffer_b = jbgp.dst_dt != f32 || is_oc_big_2_pow || is_huge_oc;
    const bool os_dim_dominating = jbgp.os >= 5 * (jbgp.ic + jbgp.oc);
    const int big_nb_os_threshold = is_amx_xf16 ? 64 : 256;
    jbgp.ip_bwd_w_local_buffers_for_input_tensors
            = is_amx_xf16 && jbgp.nb_os >= big_nb_os_threshold;
    jbgp.harness = os_dim_dominating && jbgp.nb_os >= big_nb_os_threshold
            ? harness_mb_reduction
            : harness_2d_reduction;

    int nb_os_blocking, nb_oc_blocking, nb_ic_blocking, nthr, nthr_mb, nthr_oc,
            nthr_ic;
    // Caution: thread_balance requires the `use_buffer_a` and `use_buffer_b`
    // fields of jbgp to be properly set.
    thread_balance(jbgp, nb_os_blocking, nb_oc_blocking, nb_ic_blocking, nthr,
            nthr_mb, nthr_oc, nthr_ic);

    jbgp.nb_os_blocking = nb_os_blocking;
    jbgp.nb_oc_blocking = nb_oc_blocking;
    jbgp.nb_ic_blocking = nb_ic_blocking;
    jbgp.nthr = nthr;
    jbgp.nthr_mb = nthr_mb;
    jbgp.nthr_oc_b = nthr_oc;
    jbgp.nthr_ic_b = nthr_ic;

    jbgp.gemm_batch_size = jbgp.nb_os_blocking;
    // to avoid concurrent cache write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jbgp.adjusted_batch_size
            = div_up(rnd_up(jbgp.gemm_batch_size * sc_size, 4096), sc_size);

    jbgp.use_buffer = IMPLICATION(!has_weights_buffer, jbgp.nthr_mb > 1);

    jbgp.LDA = jbgp.K;
    jbgp.LDB = (jbgp.use_buffer_b) ? jbgp.N * jbgp.nb_oc_blocking
                                   : jbgp.oc_without_padding;
    jbgp.LDC = jbgp.LDD = jbgp.N;
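    // Backward-by-weights brgemm layout: M = ic_block, N = oc_block, and K
    // runs over os; A is a transposed src block from buffer A, B is a
    // diff_dst block (copied into buffer B when use_buffer_b is set), and C
    // accumulates the (ic x oc) diff_weights.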

    if (jbgp.is_bf32) {
        const float M = static_cast<float>(jbgp.M);
        const float N = nstl::min<float>(jbgp.N, jbgp.oc);
        const float K
                = nstl::min<float>(jbgp.K * jbgp.gemm_batch_size, jbgp.os);
        const float tmul_efficiency = (M / 16) * (N / 16) * (K / 32);
        // TODO: Adjust blocking such that bigger M, N, K are generated.
        if (one_of(true, M <= 8, K <= 8, N < 16, tmul_efficiency <= 2.25))
            return status::unimplemented;
    }

    return status::success;
}

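// On avx512_core_fp16 the copy buffers store f16 data as f32; in all other
// cases the buffer keeps the original data type.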
size_t buf_dt_size(data_type_t dt, cpu_isa_t isa) {
    const auto buf_dt = isa == avx512_core_fp16 && dt == data_type::f16
            ? data_type::f32
            : dt;
    return types::data_type_size(buf_dt);
}

status_t init_ip_conf(cpu_isa_t isa, jit_brgemm_primitive_conf_t &jbgp,
        const inner_product_desc_t &ipd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);

    using namespace prop_kind;
    if (!mayiuse(avx512_core) && !mayiuse(avx2_vnni_2))
        return status::unimplemented;

    int ndims = src_d.ndims();
    if (weights_d.ndims() != ndims || dst_d.ndims() != 2)
        return status::unimplemented;

    jbgp = zero<decltype(jbgp)>();
    jbgp.ndims = ndims;
    jbgp.isa = isa;
    jbgp.prop_kind = ipd.prop_kind;
    jbgp.ngroups = 1;
    jbgp.mb = src_d.dims()[0];
    jbgp.os = jbgp.mb;
    jbgp.oc_without_padding = dst_d.dims()[1];
    jbgp.oc = jbgp.oc_without_padding;
    jbgp.ic_without_padding = src_d.dims()[1];
    jbgp.ic = jbgp.ic_without_padding;
    jbgp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jbgp.ih = (ndims < 4) ? 1 : src_d.dims()[ndims - 2];
    jbgp.iw = (ndims < 3) ? 1 : src_d.dims()[ndims - 1];
    jbgp.od = jbgp.oh = jbgp.ow = 1;
    jbgp.kd = (ndims == 5) ? weights_d.dims()[2] : 1;
    jbgp.kh = (ndims < 4) ? 1 : weights_d.dims()[ndims - 2];
    jbgp.kw = (ndims < 3) ? 1 : weights_d.dims()[ndims - 1];
    jbgp.stride_d = jbgp.stride_h = jbgp.stride_w = 1;

    if (!everyone_is(1, jbgp.ow, jbgp.oh, jbgp.od))
        return status::unimplemented;
    if (jbgp.kw != jbgp.iw || jbgp.kh != jbgp.ih || jbgp.kd != jbgp.id)
        return status::unimplemented;
    if (!everyone_is(1, jbgp.kw, jbgp.kh, jbgp.kd))
        return status::unimplemented;

    const int full_simd_w = 16;
    jbgp.simd_w = full_simd_w;

    jbgp.with_bias
            = pick_by_prop_kind(jbgp.prop_kind, ipd.bias_desc.format_kind,
                      format_kind::undef, ipd.diff_bias_desc.format_kind)
            != format_kind::undef;

    jbgp.src_dt = src_d.data_type();
    jbgp.dst_dt = dst_d.data_type();
    jbgp.wei_dt = weights_d.data_type();
    jbgp.bia_dt = jbgp.with_bias
            ? pick_by_prop_kind(jbgp.prop_kind, ipd.bias_desc.data_type,
                    data_type::undef, ipd.diff_bias_desc.data_type)
            : data_type::undef;
    jbgp.signed_input = one_of(isa, avx512_core_vnni, avx512_core_bf16)
            && jbgp.src_dt == s8;
    const bool is_int8 = one_of(jbgp.src_dt, u8, s8) && jbgp.wei_dt == s8;
    const bool is_bf16
            = everyone_is(bf16, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt)
            || pick_by_prop_kind(jbgp.prop_kind,
                    everyone_is(bf16, jbgp.src_dt, jbgp.wei_dt)
                            && jbgp.dst_dt == f32,
                    everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt)
                            && jbgp.src_dt == f32,
                    everyone_is(bf16, jbgp.src_dt, jbgp.dst_dt)
                            && jbgp.wei_dt == f32);
    const bool is_f16 = everyone_is(f16, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt)
            || pick_by_prop_kind(jbgp.prop_kind,
                    everyone_is(f16, jbgp.src_dt, jbgp.wei_dt)
                            && jbgp.dst_dt == f32,
                    everyone_is(f16, jbgp.wei_dt, jbgp.dst_dt)
                            && jbgp.src_dt == f32,
                    everyone_is(f16, jbgp.src_dt, jbgp.dst_dt)
                            && jbgp.wei_dt == f32);
    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    jbgp.is_bf32 = is_f32 && attr.fpmath_mode_ == fpmath_mode::bf16
            && isa == avx512_core_amx;

    if (!IMPLICATION(is_int8,
                one_of(isa, avx512_core_vnni, avx512_core_bf16,
                        avx512_core_amx)))
        return status::unimplemented;
    if (!IMPLICATION(is_bf16,
                one_of(isa, avx2_vnni_2, avx512_core_bf16, avx512_core_amx)))
        return status::unimplemented;
    if (!IMPLICATION(is_f32, jbgp.is_bf32 || (isa == avx512_core)))
        return status::unimplemented;
    if (!IMPLICATION(is_f16,
                one_of(isa, avx2_vnni_2, avx512_core_fp16,
                        avx512_core_amx_fp16)))
        return status::unimplemented;

    if (!one_of(true, is_int8, is_bf16, is_f16, is_f32))
        return status::unimplemented;
    if (is_int8) {
        jbgp.acc_dt = s32;
        jbgp.with_scales = true;
    } else
        jbgp.acc_dt = f32;

    // Dispatch small shapes to VNNI for better performance
    const bool is_amx_int8
            = jbgp.isa == avx512_core_amx && one_of(jbgp.wei_dt, s8, u8);
    const auto amx_row
            = static_cast<int32_t>(data_type_vnni_granularity(jbgp.src_dt))
            * jbgp.simd_w;
    const auto max_size = is_amx_int8 ? 1024 : 512;
    const bool is_small_shapes
            = (jbgp.os <= 16 && jbgp.ic <= amx_row && jbgp.oc <= amx_row)
            || (jbgp.ic <= max_size && jbgp.oc <= max_size && jbgp.mb == 1
                    && jbgp.ic % amx_row != 0);
    if (is_superset(jbgp.isa, avx512_core_amx) && is_small_shapes)
        return status::unimplemented;

    auto set_or_check_tags = [&]() -> status_t {
        using namespace format_tag;
        format_tag_t desired_src_tag = pick(ndims - 2, nc, ncw, nchw, ncdhw);
        format_tag_t desired_dst_tag = nc;

        if (src_d.format_kind() == format_kind::any) {
            CHECK(memory_desc_init_by_tag(src_md, desired_src_tag));
            jbgp.src_tag = desired_src_tag;
        } else {
            jbgp.src_tag
                    = memory_desc_matches_one_of_tag(src_md, desired_src_tag);
        }

        if (dst_d.format_kind() == format_kind::any) {
            CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag));
            jbgp.dst_tag = desired_dst_tag;
        } else {
            jbgp.dst_tag = memory_desc_matches_one_of_tag(dst_md, nc);
        }

        if (one_of(format_tag::undef, jbgp.src_tag, jbgp.dst_tag))
            return status::unimplemented;

        if (jbgp.with_bias && bias_md.format_kind == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));

        jbgp.is_wei_layout_any = weights_d.format_kind() == format_kind::any;

        memory_desc_t want_wei_md = weights_md;
        jbgp.wei_tag = get_brgemm_ip_weights_tag(isa, jbgp, weights_md);
        if (jbgp.wei_tag == format_tag::undef) return status::unimplemented;
        CHECK(memory_desc_init_by_tag(want_wei_md, jbgp.wei_tag));

        if (jbgp.signed_input) {
            want_wei_md.extra.flags = 0
                    | memory_extra_flags::compensation_conv_s8s8
                    | memory_extra_flags::scale_adjust;
            want_wei_md.extra.compensation_mask = (1 << 0);
            want_wei_md.extra.scale_adjust
                    = platform::s8s8_weights_scale_factor();
            if (weights_md.format_kind != format_kind::any
                    && want_wei_md != weights_md)
                return status::unimplemented;
        }
        weights_md = want_wei_md;
        return status::success;
    };

    jbgp.brg_type = brgemm_addr;
    jbgp.nthr = nthreads;

    jbgp.use_uker = true;
    jbgp.use_interleave_stores = jbgp.use_uker;
    if (jbgp.use_uker)
        jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
    CHECK(set_or_check_tags());
    CHECK(attr.set_default_formats(&dst_md));

    switch (jbgp.prop_kind) {
        case forward_training:
        case forward_inference:
            CHECK(init_ip_conf_fwd(jbgp, attr, dst_d));
            break;
        case backward_data: CHECK(init_ip_conf_bwd_d(jbgp)); break;
        case backward_weights: CHECK(init_ip_conf_bwd_w(jbgp)); break;
        default: assert(!"invalid prop_kind"); return invalid_arguments;
    }

    return status::success;
}

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const jit_brgemm_primitive_conf_t &jbgp) {

    size_t sc_size = sizeof(brgemm_batch_element_t);
    size_t n_elems = (size_t)jbgp.nthr * jbgp.adjusted_batch_size;

    if (jbgp.brg_type == brgemm_addr) {
        scratchpad.book(key_brgemm_primitive_batch, n_elems, sc_size, 64);
    }
    if (jbgp.use_buffer) {
        size_t nelements = (size_t)jbgp.nthr * jbgp.LDC * jbgp.M;
        if (jbgp.prop_kind == dnnl_backward_weights
                && (jbgp.nthr_mb > 1 || jbgp.harness == harness_mb_reduction)) {
            const size_t n_reduction_buffers = jbgp.nthr_mb > 1
                    ? jbgp.nthr_mb - (jbgp.wei_dt == f32)
                    : 1;
            const size_t num_ic_chunks
                    = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);
            const size_t num_oc_chunks
                    = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);
            nelements = (size_t)n_reduction_buffers * num_ic_chunks
                    * num_oc_chunks * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking
                    * jbgp.ic_block * jbgp.oc_block;
        } else if (jbgp.prop_kind == dnnl_backward_weights
                && jbgp.nthr_mb == 1) {
            nelements = (size_t)jbgp.nthr * jbgp.nb_ic_blocking * jbgp.ic_block
                    * jbgp.nb_oc_blocking * jbgp.oc_block;
        } else if (jbgp.prop_kind == dnnl_backward_data && jbgp.nthr_oc_b > 1) {
            const int adj_buffers = (jbgp.src_dt == f32) ? 1 : 0;
            int n_reduction_buffers = jbgp.nthr_oc_b - adj_buffers;
            nelements = (size_t)n_reduction_buffers * jbgp.LDC * jbgp.os;
        } else if (one_of(jbgp.prop_kind, forward_training, forward_inference)
                && jbgp.nthr_ic_b > 1) {
            const bool need_extra_buffer
                    = (jbgp.dst_dt == f32 && jbgp.with_sum);
            int n_reduction_buffers = jbgp.nthr_ic_b - !need_extra_buffer;
            nelements = (size_t)n_reduction_buffers * jbgp.oc * jbgp.os;
        }
        scratchpad.book(key_brgemm_primitive_buffer, nelements,
                types::data_type_size(jbgp.acc_dt));
    }
    if (jbgp.use_buffer_a && jbgp.prop_kind == dnnl_backward_weights) {
        const dim_t num_ic_chunks_per_thread
                = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? 1
                : div_up(div_up(jbgp.nb_ic, jbgp.nb_ic_blocking),
                        jbgp.nthr_ic_b);
        const dim_t num_os_chunks_per_thread
                = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? 1
                : div_up(div_up(jbgp.nb_os, jbgp.nb_os_blocking), jbgp.nthr_mb);
        const dim_t num_elems_per_thread = num_ic_chunks_per_thread
                * num_os_chunks_per_thread * jbgp.gemm_batch_size
                * jbgp.os_block * jbgp.ic_block * jbgp.nb_ic_blocking;
        scratchpad.book(key_brgemm_primitive_buffer_a,
                jbgp.nthr * num_elems_per_thread,
                buf_dt_size(jbgp.src_dt, jbgp.isa));
    } else if (jbgp.use_buffer_a && jbgp.prop_kind == dnnl_backward_data) {
        scratchpad.book(key_brgemm_primitive_buffer_a,
                (size_t)jbgp.nthr * jbgp.os_block * jbgp.LDA,
                buf_dt_size(jbgp.dst_dt, jbgp.isa));
    } else if (jbgp.use_buffer_a) { // FWD
        scratchpad.book(key_brgemm_primitive_buffer_a,
                (size_t)jbgp.nthr * jbgp.LDA * jbgp.os_block
                        * jbgp.nb_os_blocking,
                buf_dt_size(jbgp.src_dt, jbgp.isa));
    }

    if (jbgp.use_buffer_b && jbgp.prop_kind == dnnl_backward_weights) {
        int num_os_chunks_per_thread
                = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? 1
                : div_up(div_up(jbgp.nb_os, jbgp.nb_os_blocking), jbgp.nthr_mb);
        const dim_t num_elems_per_thread = num_os_chunks_per_thread
                * jbgp.gemm_batch_size * jbgp.os_block * jbgp.LDB;
        scratchpad.book(key_brgemm_primitive_buffer_b,
                (size_t)jbgp.nthr * num_elems_per_thread,
                buf_dt_size(jbgp.dst_dt, jbgp.isa));
    }

    if (jbgp.use_buffer_b && jbgp.prop_kind == dnnl_backward_data) {
        auto size_B = (size_t)jbgp.LDB * rnd_up(jbgp.K, 2);

        if (!jbgp.ip_bwd_d_global_b_transpose)
            scratchpad.book(key_brgemm_primitive_buffer_b,
                    (dim_t)jbgp.nthr * jbgp.gemm_batch_size * size_B,
                    buf_dt_size(jbgp.wei_dt, jbgp.isa));
        else
            scratchpad.book(key_brgemm_primitive_buffer_b,
                    (dim_t)jbgp.nb_oc * jbgp.nb_ic * size_B,
                    buf_dt_size(jbgp.wei_dt, jbgp.isa));
    }

    if (jbgp.prop_kind == dnnl_backward_weights && jbgp.with_bias
            && (jbgp.bia_dt != f32 || jbgp.nthr_mb > 1)) {
        int nbuffers = jbgp.nthr_mb - (jbgp.bia_dt == f32);
        scratchpad.book(key_iprod_bias_bf16_convert_wsp,
                (size_t)nbuffers * jbgp.oc, types::data_type_size(jbgp.acc_dt));
    }

    if (dnnl_thr_syncable() && jbgp.prop_kind == dnnl_backward_weights)
        scratchpad.book<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx, 1);

    if (jbgp.is_amx)
        scratchpad.book(key_conv_amx_tile_buffer,
                (size_t)jbgp.nthr * jbgp.amx_buf_size_per_thread, sizeof(char));
}

} // namespace brgemm_inner_product_utils

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl