/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/brgemm/brgemm_utils.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace prop_kind;
using namespace data_type;

namespace brgemm_convolution_bwd_utils {
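
// Helper that resolves a memory format: if the descriptor was left as
// format_kind::any it is initialized with the expected tag, otherwise the
// user-provided layout must already match that tag; anything else is
// reported as unimplemented.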
inline status_t init_tag(format_tag_t &tag, memory_desc_t &md,
        const memory_desc_wrapper &mdw, const format_tag_t tag_value) {
    if (mdw.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(md, tag_value));
        tag = tag_value;
    } else {
        tag = mdw.matches_one_of_tag(tag_value);
    }

    if (tag != tag_value) return status::unimplemented;

    return status::success;
}

bool is_amx(cpu_isa_t isa) {
    return is_superset(isa, avx512_core_amx);
}

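// Post-ops are validated through the common injector check: only sum, eltwise
// and binary post-ops with per_oc, scalar or no_broadcast broadcasting
// strategies are accepted, and any post-op is rejected outright when post-ops
// support is disabled for this implementation.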
bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d, bool enable_postops) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;

    if (post_ops.len() > 0 && !enable_postops) return false;

    return injector::post_ops_ok(post_ops_ok_args_t(jcp.isa,
            {sum, eltwise, binary}, post_ops, &dst_d,
            false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
            false /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar,
                    broadcasting_strategy_t::no_broadcast}));
}

bool is_groups_ok(jit_brgemm_conv_conf_t &jcp) {
    // Enable grouped convs for the shapes not supported in direct convs:
    // the direct approach only supports int8/bf16 grouped convolution when
    // the number of channels per group is a multiple of 4, and bf16 grouped
    // convolution with the nxc layout on the jit_bf16 implementation.
    // TODO: remove this condition after the restriction on small oc is removed
    return jcp.ngroups > 1
            && IMPLICATION(one_of(jcp.src_dt, u8, s8, bf16, f16),
                    jcp.oc % 4 == 0 && jcp.ic % 4 == 0);
}

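// Picks memory format tags for the bwd-d primitive: activations always use
// the channels-last (nwc/nhwc/ndhwc) family, while weights get a blocked tag
// selected from ic_block (64/48/32/16) and the weights data type. In the tag
// names, the trailing "...i<N>o" blocks presumably reflect the vnni
// granularity of the weights data type (2o for bf16/f16, 4o for s8); plain
// f32 weights use the non-vnni "o<N>i" variants.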
status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &diff_dst_md,
        memory_desc_t &weights_md, memory_desc_t &diff_src_md,
        memory_desc_t &bias_md) {
    format_tag_t src_tag, dst_tag, wei_tag;
    dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);

    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper bias_d(&bias_md);
    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;

    const bool is_1d = jcp.ndims == 3;
    const bool is_2d = jcp.ndims == 4;
    const bool is_3d = jcp.ndims == 5;

    if (jcp.wei_plain) {
        return status::unimplemented;
    } else {
        jcp.LDB = jcp.ic_block;
        if (jcp.ic_block == 64) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIdhwo64i : Idhwo64i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o64i4o : IdhwO16o64i4o;
                    else
                        wei_tag = with_groups ? gIdhwO64i4o : IdhwO64i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o64i2o : IdhwO16o64i2o;
                    else
                        wei_tag = with_groups ? gIdhwO64i2o : IdhwO64i2o;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIwo64i : Iwo64i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o64i4o : IwO16o64i4o;
                    else
                        wei_tag = with_groups ? gIwO64i4o : IwO64i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o64i2o : IwO16o64i2o;
                    else
                        wei_tag = with_groups ? gIwO64i2o : IwO64i2o;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIhwo64i : Ihwo64i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o64i4o : IhwO16o64i4o;
                    else
                        wei_tag = with_groups ? gIhwO64i4o : IhwO64i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o64i2o : IhwO16o64i2o;
                    else
                        wei_tag = with_groups ? gIhwO64i2o : IhwO64i2o;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.ic_block == 48) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIdhwo48i : Idhwo48i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o48i4o : IdhwO16o48i4o;
                    else
                        wei_tag = with_groups ? gIdhwO48i4o : IdhwO48i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o48i2o : IdhwO16o48i2o;
                    else
                        wei_tag = with_groups ? gIdhwO48i2o : IdhwO48i2o;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIwo48i : Iwo48i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o48i4o : IwO16o48i4o;
                    else
                        wei_tag = with_groups ? gIwO48i4o : IwO48i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o48i2o : IwO16o48i2o;
                    else
                        wei_tag = with_groups ? gIwO48i2o : IwO48i2o;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIhwo48i : Ihwo48i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o48i4o : IhwO16o48i4o;
                    else
                        wei_tag = with_groups ? gIhwO48i4o : IhwO48i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o48i2o : IhwO16o48i2o;
                    else
                        wei_tag = with_groups ? gIhwO48i2o : IhwO48i2o;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.ic_block == 32) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIdhwo32i : Idhwo32i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o32i4o : IdhwO16o32i4o;
                    else
                        wei_tag = with_groups ? gIdhwO32i4o : IdhwO32i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o32i2o : IdhwO16o32i2o;
                    else
                        wei_tag = with_groups ? gIdhwO32i2o : IdhwO32i2o;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIwo32i : Iwo32i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o32i4o : IwO16o32i4o;
                    else
                        wei_tag = with_groups ? gIwO32i4o : IwO32i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o32i2o : IwO16o32i2o;
                    else
                        wei_tag = with_groups ? gIwO32i2o : IwO32i2o;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIhwo32i : Ihwo32i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o32i4o : IhwO16o32i4o;
                    else
                        wei_tag = with_groups ? gIhwO32i4o : IhwO32i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o32i2o : IhwO16o32i2o;
                    else
                        wei_tag = with_groups ? gIhwO32i2o : IhwO32i2o;
                } else
                    return status::unimplemented;
            }
        } else {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIdhwo16i : Idhwo16i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o16i4o : IdhwO16o16i4o;
                    else
                        wei_tag = with_groups ? gIdhwO16i4o : IdhwO16i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIdhwO16o16i2o : IdhwO16o16i2o;
                    else
                        wei_tag = with_groups ? gIdhwO16i2o : IdhwO16i2o;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIwo16i : Iwo16i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o16i4o : IwO16o16i4o;
                    else
                        wei_tag = with_groups ? gIwO16i4o : IwO16i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIwO16o16i2o : IwO16o16i2o;
                    else
                        wei_tag = with_groups ? gIwO16i2o : IwO16i2o;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);

                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gIhwo16i : Ihwo16i;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o16i4o : IhwO16o16i4o;
                    else
                        wei_tag = with_groups ? gIhwO16i4o : IhwO16i4o;
                } else if (one_of(jcp.wei_dt, bf16, f16)) {
                    if (jcp.is_oc_padded)
                        wei_tag = with_groups ? gIhwO16o16i2o : IhwO16o16i2o;
                    else
                        wei_tag = with_groups ? gIhwO16i2o : IhwO16i2o;
                } else
                    return status::unimplemented;
            }
        }
    }

    src_tag = dst_tag;

    CHECK(init_tag(jcp.src_tag, diff_dst_md, diff_dst_d, src_tag));
    CHECK(init_tag(jcp.dst_tag, diff_src_md, diff_src_d, dst_tag));
    CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag));

    return status::success;
}

struct brg_blocking_t : public jit_brgemm_conv_conf_t {
    struct array_in_loop_t {
        dim_t itersize;
        float repeatn;
        float overlap;
        void set(dim_t iter_s, float rpt, float ovlp = 1.f) {
            itersize = iter_s;
            repeatn = rpt;
            overlap = ovlp;
        }
    };

    struct loop_t {
        array_in_loop_t src;
        array_in_loop_t wei;
        array_in_loop_t dst;
    };

    brg_blocking_t() {
        // TODO: This is a broken form of initialization for a base class.
        // Either set default values in a base class, or provide a proper
        // default ctor, or take a `jit_brgemm_conv_conf_t` object to initialize
        // a base class object.
        jit_brgemm_conv_conf_t *base
                = static_cast<jit_brgemm_conv_conf_t *>(this);
        *base = jit_brgemm_conv_conf_t();
        init();
    }
    brg_blocking_t(const jit_brgemm_conv_conf_t &jcp)
        : jit_brgemm_conv_conf_t(jcp) {
        init();
    }
    void init() {
        ur = 0;
        ur_block = 0;
        ur_block_tail = 0;
        eff = 0.f;
        nb_kd = 0;
        nb_kh = 0;
        nb_kw = 0;
        sp = 0;
        sp_block = 0;
        nb_sp = 0;
        eff = 0;
    }

    int ur, ur_block, ur_block_tail;
    int nb_kd, nb_kh, nb_kw;
    float eff;
    static unsigned L1;
    static unsigned L2;
    static unsigned L3;
    // Rough relative estimates of the access latency for the different cache
    // levels; this is sufficient for estimating the cost of data accesses.
    // TODO: Improve memory access estimates
    static constexpr float L1_k = 1.f;
    static constexpr float L2_k = 3.f;
    static constexpr float L3_k = 15.f;
    // TODO: At the moment we primarily evaluate whether the data fits into
    // L1/L2; the difference between L3 and main memory still needs to be
    // taken into account.
    static constexpr float mem_k = 15.f;
    static constexpr int bench_iterations = 1;
    static constexpr int max_regs = 32;
    static constexpr int bcast_simd = 16;

    int sp, sp_block, nb_sp;
    static int last_oc_block_size;

    void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
    void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }

    status_t estimate_brgemm_ur();
    status_t get_brgemm_ur(
            const primitive_attr_t *attr, const memory_desc_t &dst_md);

    float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
            bool is_broadcast, bool is_shared) const;

    float io_k(const loop_t loop, const array_in_loop_t arr, float pk,
            bool is_broadcast, bool is_shared) const;

    void select_oc_block();

    void update_blocks();
    bool fast_check_ic_block() const;
    float est_eff();
    void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block,
            int kh_block, bool maybe_use_buffer, int max_iw_block_thr);
    status_t calc_blocks();

    bool fast_check_ic_block_1x1() const;
    float est_eff_1x1();

    // utils
    static int get_inp_size(
            int max_src_size, int dst_size, int k, int stride, int dilate) {
        auto adj_str = nstl::min(k, stride);
        const auto res = nstl::min(max_src_size,
                calculate_end_padding(0, dst_size, 0, adj_str,
                        calculate_extended_filter_size(k, dilate)));
        return res;
    }

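    // Roughly, the size of the diff_dst block required to compute an input
    // (diff_src) block of `out_size` elements along one dimension, given the
    // stride, the extended kernel size and the leading padding. A small
    // worked example with assumed values: out_size = 7, stride = 2,
    // ext_k = 3, padding = 1 gives div_up(7 + 1, 2) + (3 - 1 - 1) / 2 = 4.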
    static int get_inp_block_size(
            int out_size, int stride, int ext_k, int padding) {
        const auto res = div_up(out_size + padding % stride, stride)
                + (ext_k - 1 - padding % stride) / stride;
        return res;
    }

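    // Softens an efficiency value: for koeff in (0, 1) the result is mapped
    // into [1 - koeff, 1], e.g. squeeze_val(0.2f, 0.5f) == 0.6f, so a poor
    // raw efficiency is penalized less strongly; koeff == 1 keeps eff as is.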
    static float squeeze_val(float eff, float koeff) {
        if (koeff <= 0) return 1;
        if (koeff == 1) return eff;
        const auto k = 1.f / koeff;
        return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff;
    }

    static int estimate_ur(int ic_block) {
        const auto est_ur = (ic_block == 64)
                ? 6
                : ((ic_block == 48) ? 9 : ((ic_block == 32) ? 14 : 28));
        return est_ur;
    }

    int inp_w(int out_w, int ker_w) const {
        return get_inp_size(ow, out_w, ker_w, stride_w, dilate_w);
    }

    int rnd_simd(int val) const { return rnd_up(val, simd_w); }

    int rnd_inp_simd(int out_w, int ker_w, int voc) const {
        const auto vsp = inp_w(out_w, ker_w);
        return ((stride_w == 1 && voc >= oc) ? rnd_up(vsp * voc, simd_w)
                                             : vsp * rnd_up(voc, simd_w));
    }

    static constexpr int MAXNLOOPS = 32;
    loop_t loop[MAXNLOOPS];
};

unsigned brg_blocking_t::L1;
unsigned brg_blocking_t::L2;
unsigned brg_blocking_t::L3;
int brg_blocking_t::last_oc_block_size;

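// io_k() estimates the average per-access cost of touching an array inside a
// loop nest: the first of the n accesses costs `pk` (the cost propagated from
// the enclosing loop), while the remaining n - 1 accesses cost L1_k, L2_k,
// L3_k or mem_k depending on whether the combined footprint of one iteration
// fits into L1 (broadcast arrays only), L2, or a shared cache.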
float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
        bool is_broadcast, bool is_shared) const {
    if (n < 1) return 0;
    if (n == 1) return pk;
    const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz
            + (use_buffer ? dst * acc_dsz : 0);
    const auto amount_L1 = is_broadcast ? src * src_dsz : amount;
    const auto k = is_broadcast
            ? ((amount_L1 < L1) ? L1_k
                                : ((amount < L2) ? L2_k
                                                 : (is_shared ? L3_k : mem_k)))
            : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k));
    const auto cost = pk + k * (n - 1);
    return cost / n;
}

float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
        float pk, bool is_broadcast, bool is_shared) const {
    return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize,
            arr.repeatn * arr.overlap, pk, is_broadcast, is_shared);
}

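// In this bwd-d implementation oc is the brgemm reduction (K) dimension, so
// the whole (padded) oc range is taken as a single block: oc_block is oc
// rounded up to the vnni granularity (or to 16x the granularity when oc is
// padded), which makes nb_oc effectively 1.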
void brg_blocking_t::select_oc_block() {
    const auto padded_oc = last_oc_block_size * (is_oc_padded ? 16 : 1);
    oc_block = rnd_up(oc, padded_oc);
    nb_oc = utils::div_up(oc, oc_block);
}

status_t brg_blocking_t::estimate_brgemm_ur() {
    // Simple simulation of brgemm_desc init
    if (sp_block <= 0) return status::invalid_arguments;
    LDA = oc_block;
    LDB = ic_block;
    LDC = use_buffer ? ic_block : stride_w * ic_without_padding;

    // Configure matrix sizes
    // For AMX, if oc_block != oc then exec_trans is used, so K equals oc_block.
    const auto padded_oc = last_oc_block_size * (is_oc_padded ? 16 : 1);

    ocp = rnd_up(oc, padded_oc);

    const auto adj_sp = div_up(iw_block, stride_w);
    M = brgM = adj_sp >= sp_block ? sp_block : 0;
    M_tail = brgM_tail = adj_sp % sp_block;

    N = ic >= ic_block ? ic_block : 0;
    N_tail = ic % ic_block;
    K = oc >= oc_block ? oc_block : 0;
    K_tail = oc_block;

    const auto vK = K > 0 ? K : K_tail;
    const auto vM = M > 0 ? M : M_tail;
    const auto vN = N > 0 ? N : N_tail;

    const float alpha = 1.0;
    const float beta = 0.0;
    brgemm_t brg;
    brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
            brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr,
            is_bf32);
    CHECK(brgemm_utils::brgemm_blocking(&brg));
    ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1);
    if (ur == 0) return status::invalid_arguments;
    ur_block = brg.bd_block;
    if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) {
        brgemm_t brg_sp_tail;
        brgemm_utils::init_brgemm_conf(&brg_sp_tail, isa, brgemm_addr, src_dt,
                wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, M_tail,
                vN, vK, nullptr, is_bf32);
        CHECK(brgemm_utils::brgemm_blocking(&brg_sp_tail));
        ur_block_tail = brg_sp_tail.bd_block;
    } else {
        ur_block_tail = 0;
    }
    return status::success;
}

status_t brg_blocking_t::get_brgemm_ur(
        const primitive_attr_t *attr, const memory_desc_t &dst_md) {
    // Detailed simulation of brgemm convolution init
    if (sp_block <= 0 || oc_block <= 0 || ic_block <= 0)
        return status::invalid_arguments;
    CHECK(estimate_brgemm_ur());

    LDD = stride_w * ic_without_padding;

    const float alpha = 1.0;
    const float beta = 1.0;
    const float beta_init = 0.0;

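    // Enumerate the brgemm descriptor variants the convolution may invoke:
    // row counts up to M (only M and M_tail are needed for exec_trans,
    // exec_vpad and 1x1), beta = 0 (accumulator init) and beta = 1, and
    // full/tail sizes for both N and K. This mirrors the descriptor setup
    // performed at primitive creation so an unsupported blocking is rejected
    // early.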
    for (int i = 0; i < M; i++) {
        auto vM = i + 1;
        // initialize only the brgemm descriptors that are actually needed
        if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1)
                && vM != M && vM != M_tail)
            continue;
        for (int i_init = 0; i_init < 2; i_init++) {
            for (int i_N = 0; i_N < 2; i_N++) {
                for (int i_K = 0; i_K < 2; i_K++) {
                    auto vbeta = (i_init) ? beta_init : beta;
                    auto vN = (i_N) ? N_tail : N;
                    auto vK = (i_K) ? K_tail : K;
                    if (vN == 0 || vK == 0) continue;
                    brgemm_t brg;
                    brgemm_strides_t brg_strides;
                    brg_strides.stride_a = ngroups * oc_without_padding
                            * (dilate_w + 1) * src_dsz;
                    // weights are padded by ic_block and last_oc_block_size
                    brg_strides.stride_b = rnd_up(oc, last_oc_block_size)
                            * rnd_up(ic, ic_block) * wei_dsz;
                    const auto strides_ptr = (brg_type == brgemm_strd)
                            ? &brg_strides
                            : nullptr;
                    brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt,
                            wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB,
                            LDC, vM, vN, vK, strides_ptr, is_bf32);
                    CHECK(brgemm_utils::brgemm_blocking(&brg));

                    brgemm_attr_t brgattr;
                    brgattr.max_bs = max_batch;
                    const auto max_vpad = (exec_type == exec_vpad)
                            ? nstl::max(l_pad, r_pad)
                            : 0;
                    brgattr.max_top_vpad = max_vpad;
                    brgattr.max_bottom_vpad = max_vpad;
                    brgattr.fpmath_mode = attr->fpmath_mode_;
                    CHECK(brgemm_desc_set_attr(&brg, brgattr));

                    brg.with_sum = with_sum;
                    CHECK(brgemm_desc_set_postops(
                            &brg, attr, &dst_md, LDD, bia_dt));
                }
            }
        }
    }

    return status::success;
}

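// Recomputes the derived block counts (nb_*) once the individual blocks are
// chosen, and derives the diff_dst block sizes (ow/oh/od_block) required by
// the selected diff_src blocks.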
void brg_blocking_t::update_blocks() {
    if (sp_block <= 0
            || utils::one_of(0, id_block, ih_block, oc_block, ic_block,
                    kd_block, kh_block, kw_block, is_block, iw_block))
        return;

    nb_id = div_up(id, id_block);
    nb_ih = div_up(ih, ih_block);
    nb_oc = div_up(oc, oc_block);
    nb_ic = div_up(ic, ic_block);
    nb_kd = div_up(kd, kd_block);
    nb_kh = div_up(kh, kh_block);
    nb_kw = div_up(kw, kw_block);
    nb_iw = div_up(iw, iw_block);

    sp = iw;
    sp_block = iw_block;
    nb_sp = nb_iw;

    ow_block = get_inp_block_size(iw_block, stride_w, ext_kw, l_pad);
    oh_block = get_inp_block_size(ih_block, stride_h, ext_kh, t_pad);
    od_block = get_inp_block_size(id_block, stride_d, ext_kd, f_pad);
}

bool brg_blocking_t::fast_check_ic_block() const {
    // This function is for reducing the number of blocking variants
    // TODO: eliminate the heuristic in this function
    if (is_1x1) return fast_check_ic_block_1x1();
    const auto rnd_ic = rnd_up(ic, 16);
    auto res = false;
    if (ic_block == 64) {
        res = (rnd_ic % ic_block == 0 && rnd_ic * wei_dsz < 192 * 4);
    } else if (ic_block == 48) {
        // TODO: edit this heuristic for bwd_d
        const bool big_spatial
                = od * oh * ow > 81 * stride_d * stride_h * stride_w;
        res = (rnd_ic % ic_block == 0 && rnd_ic * wei_dsz <= 384 * 4
                && big_spatial);
    } else
        res = true;

    return res;
}

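// est_eff() scores a blocking candidate by multiplying several efficiency
// factors: brgemm microkernel register utilization, rounding losses from the
// ur/sp/ic blocking, thread load balance, and a cache cost model built from
// the loop nest described below. Higher is better; the caller keeps the
// candidate with the largest score.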
float brg_blocking_t::est_eff() {
    if (is_1x1) return est_eff_1x1();
    const auto icblock = ic_block / 16;

    const auto brgemm_microkernel_eff
            = (static_cast<float>(icblock) * ur) / ((ur + icblock) * max_regs);

    const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
    const auto brgemm_eff = squeeze_val(ur
                    * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
                    / 64,
            0.5f);

    const auto sp_amount = nb_id * nb_ih * nb_sp;
    const auto work_amount = mb * ngroups * nb_ic * sp_amount;
    const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block));

    const auto thr_eff = static_cast<float>(work_amount)
            / utils::rnd_up(work_amount, nthr);

    const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block);

    const auto job = div_up(work_amount, nthr);

    auto job_eff = 1.f;
    if (job < nthr) {
        std::vector<dim_t> thr_jobs(nthr);

        for (int ithr = 0; ithr < nthr; ithr++) {
            thr_jobs[ithr] = 0;
            if (ithr >= work_amount) continue;
            dim_t thr_job = 0;
            int start {0}, end {0};
            balance211(work_amount, nthr, ithr, start, end);
            int n {0}, g {0}, icb {0}, idp {0}, ihp {0}, spb {0};
            nd_iterator_init(start, n, mb, idp, id, ihp, ih, spb, nb_sp, g,
                    ngroups, icb, nb_ic);

            for (auto work = start; work < end; work++) {
                const int icp = icb * ic_block;
                const auto ic_sz = nstl::min(ic - icp, ic_block);
                int sp_sz = 0;
                const int spp = spb * sp_block;
                sp_sz = nstl::min(sp - spp, sp_block);
                thr_job += sp_sz * ic_sz;

                nd_iterator_step(n, mb, idp, id, ihp, ih, spb, nb_sp, g,
                        ngroups, icb, nb_ic);
            }
            thr_jobs[ithr] = thr_job;
        }

        dim_t max_job = 0;
        dim_t sum_job = 0;
        for (int ithr = 0; ithr < nthr; ithr++) {
            if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
            sum_job += thr_jobs[ithr];
        }
        job_eff = max_job == 0
                ? 1
                : static_cast<float>(sum_job) / (max_job * nthr);

    } else {
        job_eff = thr_eff;
    }

    const auto oc_blocking_size = oc_block * nb_oc_blocking;
    const auto ic_blocking_size = ic_block * oc_blocking_size;

    int l = -1;

    // -- brgemm kernel: loop by simd_w --
    l++;
    const auto inp_ur = inp_w(ur, kw_block);
    loop[l].src.set(inp_ur * simd_w, 1, bcast_simd);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(ic_block, 1);

    // -- brgemm kernel: loop by kw in kw_block --
    l++;
    auto src_is = rnd_inp_simd(ur, kw_block, oc_blocking_size);
    loop[l].src.set(src_is, 1, kw_block);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(ic_blocking_size, 1);

    // -- brgemm kernel: loop by batch (grouped by kw_block) in ur --
    l++;
    loop[l].src.set(src_is, 1);
    loop[l].dst.set(0, 1);
    auto wei_is = kw_block * ic_blocking_size;
    loop[l].wei.set(wei_is, 1);
    // -- brgemm kernel: loop by ur in sp_block --
    l++;
    const auto nb_ur = div_up(sp_block, ur);
    loop[l].src.set(kd_block * kh_block * src_is, 1);
    loop[l].dst.set(ur * ic_block, 1);
    wei_is = kd_block * kh_block * kw_block * ic_blocking_size;
    loop[l].wei.set(wei_is, nb_ur);

    // -- harness: loop by k_blocks in ks --
    l++;
    loop[l].src.set(kd_block * kh_block
                    * rnd_inp_simd(sp_block, kw_block, oc_blocking_size),
            1);
    loop[l].dst.set(sp_block * ic_block, nb_kd * nb_kh * nb_kw);
    loop[l].wei.set(wei_is, 1);

    // -- brgemm kernel: loop by oc_chunks --
    l++;
    const auto oc_chunks = div_up(nb_oc, nb_oc_blocking);
    loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, oc_blocking_size), 1);
    loop[l].dst.set(sp_block * ic_block, oc_chunks);
    wei_is = kd * kh * kw * ic_blocking_size;
    loop[l].wei.set(wei_is, 1);

    const auto dim_ic = 1;
    const auto nb_ic_thr = nstl::min(nb_ic, div_up(job, dim_ic));
    const auto ic_thr = nstl::min(ic, nb_ic_thr * ic_block);
    const auto nsimd_ic_thr = div_up(ic_thr, simd_w);

    const auto dim_sp = ngroups * nb_ic;
    const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
    const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);

    const auto dim_ih = nb_sp * dim_sp;
    const int nb_ih_thr = nstl::min(nb_ih, div_up(job, dim_ih));
    const int ih_thr = nstl::min(ih, nb_ih_thr * ih_block);

    const auto dim_id = nb_ih * dim_ih;
    const int nb_id_thr = nstl::min(nb_id, div_up(job, dim_id));
    const int id_thr = nstl::min(id, nb_id_thr * id_block);

    src_is = kd * kh * rnd_inp_simd(sp_block, kw, oc);

    auto wei_op = kd * kh * kw * icblock * oc;

    // -- harness: loop by ic_block --
    l++;
    loop[l].src.set(src_is, nb_ic_thr);
    loop[l].dst.set(sp_block * ic_block, 1);
    wei_is = kd * kh * kw * ic_block * oc;
    wei_op = kd * kh * kw * nsimd_ic_thr * oc;
    loop[l].wei.set(wei_is, 1);

    // -- harness: loop by sp_blocks --
    l++;
    loop[l].src.set(src_is, 1);
    const auto rnd_ic_for_sp = simd_w * nsimd_ic_thr;
    loop[l].dst.set(sp_block * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
    // oh_block is almost always 1. TODO: handle oh_block != 1
    // -- harness: loop by oh_blocks --
    l++;
    src_is = kd * kh * rnd_inp_simd(sp_thr, kw, oc);
    loop[l].src.set(ih_block * src_is, 1);
    loop[l].dst.set(sp_thr * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_ih_thr);
    // od_block is almost always 1. TODO: handle od_block != 1
    // -- harness: loop by od_blocks --
    l++;
    loop[l].src.set(id_block * ih_thr * src_is, 1);
    loop[l].dst.set(ih_thr * sp_thr * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_id_thr);

    // -- harness: loop by mb --
    l++;
    const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_ic));
    loop[l].src.set(id_thr * ih_thr * src_is, 1);
    loop[l].dst.set(id_thr * ih_thr * sp_thr * nsimd_ic_thr * simd_w, 1);
    loop[l].wei.set(kd * kh * kw * nsimd_ic_thr * simd_w * oc, mb_thr);

    const auto src_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr
            * kd * kh * kw * oc;
    const auto dst_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr
            * nsimd_ic_thr;
    wei_op = kd * kh * kw * nsimd_ic_thr * oc;

    // for a "real" application, bench_iterations is set to 1
    const auto iterations = bench_iterations;
    l++;
    loop[l].src.set(src_op, iterations);
    loop[l].dst.set(dst_op * simd_w, iterations);
    loop[l].wei.set(wei_op * simd_w, iterations);

    auto src_mem_k = mem_k;
    auto dst_mem_k = mem_k;
    auto wei_mem_k = mem_k;
    float src_rp = 1;
    float dst_rp = 1;
    float wei_rp = 1;

    for (auto il = l; il >= 0; il--) {
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true);
        src_rp *= loop[il].src.repeatn;
        dst_rp *= loop[il].dst.repeatn;
        wei_rp *= loop[il].wei.repeatn;
    }
    const auto src_ops = (src_op * src_rp) / iterations;
    const auto dst_ops = (dst_op * dst_rp) / iterations;
    const auto wei_ops = (wei_op * wei_rp) / iterations;

    const auto src_cost = src_mem_k * src_ops;
    const auto dst_cost = dst_mem_k * dst_ops;
    const auto wei_cost = wei_mem_k * wei_ops;
    const auto call_kernel_cost = job * oc_chunks * nb_kd * nb_kh * nb_kw;

    const auto cache_eff = (static_cast<dim_t>(mb) * id * ih * sp * oc * ic)
            / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));
    const auto res_eff = ic_block_eff * brgemm_microkernel_eff * sp_eff
            * job_eff * ur_eff * cache_eff * brgemm_eff;
    return res_eff;
}

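// For a given kd/kh kernel blocking, first choose id/ih blocks so that the
// per-block working set fits into (at most half of) L2 after reserving space
// for the weights and accumulators, and on AMX preferably into L1 as well;
// then scan the candidate iw blocks that are stride-aligned and evenly divide
// iw, keeping the variant with the best est_eff() score in best_brgb.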
void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_,
        int kh_block_, bool maybe_use_buffer, int max_iw_block_thr) {
    kd_block = kd_block_;
    kh_block = kh_block_;

    kw_block = kw;
    kd_block_pad = kd_block;
    kh_block_pad = kh_block;
    kw_block_pad = kw_block;

    const auto w_block_size = 2 * src_dsz * oc * owp + dst_dsz * iw * ic_block;
    const auto other_size = wei_dsz * kd * kh * kw * oc * ic_block
            + acc_dsz * 2 * amx_h * ic_block;
    const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)),
            other_size > L2 ? 0 : L2 - other_size);
    if (odp * ohp * w_block_size > L2_available) {
        id_block = utils::saturate(
                1, id, int(L2_available / (ohp * w_block_size)));
        if (id_block == 1)
            ih_block = utils::saturate(
                    1, ih, int(L2_available / (w_block_size)));
        else
            ih_block = ih;
    } else {
        id_block = 1;
        ih_block = ih;
    }
    if (is_amx(isa)) {
        // try to fit into L1
        bool L1_fit_res = false;
        auto cur_id_block = id_block;
        auto cur_ih_block = ih_block;
        const auto src_w_block_size
                = src_dsz * oc * owp + dst_dsz * iw * ic_block;
        if (src_w_block_size < L1) {
            cur_id_block = utils::saturate(
                    1, id, int(L1 / (ohp * src_w_block_size)));
            if (cur_id_block == 1)
                cur_ih_block
                        = utils::saturate(1, ih, int(L1 / (src_w_block_size)));
        }
        for (; cur_id_block > 1; cur_id_block--) {
            const auto sp_size = cur_id_block * cur_ih_block * owp;
            if ((static_cast<float>(id) / rnd_up(id, cur_id_block)) > 0.9f
                    && static_cast<float>(sp_size) / rnd_up(sp, amx_h) > 0.8f) {
                L1_fit_res = true;
                break;
            }
        }
        if (cur_id_block == 1) {
            for (; cur_ih_block > 1; cur_ih_block--) {
                const auto sp_size = cur_ih_block * owp;
                if ((static_cast<float>(ih) / rnd_up(ih, cur_ih_block)) > 0.9f
                        && sp_size > 128) {
                    L1_fit_res = true;
                    break;
                }
            }
        }
        if (L1_fit_res) {
            id_block = cur_id_block;
            ih_block = cur_ih_block;
        }
    }

    // limit ih_block to have good threading
    const auto thr_ic_block
            = div_up(nthr, mb * div_up((ic > 32 ? ngroups : 1) * ic, ic_block));
    const auto thr_id_block = div_up(id, thr_ic_block);
    const auto thr_ih_block
            = div_up(ih, thr_ic_block * div_up(id, thr_id_block));
    id_block = nstl::min(id_block, thr_id_block);
    ih_block = nstl::min(ih_block, thr_ih_block);
    while ((id_block % stride_d != 0 || id % id_block != 0) && id_block < id)
        id_block++;
    while ((ih_block % stride_h != 0 || ih % ih_block != 0) && ih_block < ih)
        ih_block++;

    // --- Select iw_block ----
    const auto max_iw_block_L2 = iw;
    auto start_iw_block = nstl::min(max_iw_block_thr, max_iw_block_L2);

    sp = iw;
    const auto start_sp_block = start_iw_block;
    auto prev_spb = 0;
    for (auto ns = 1; ns <= sp; ns++) {
        const auto spb = div_up(sp, ns);
        if (spb == prev_spb || spb > start_sp_block) continue;
        if (spb % stride_w != 0) continue;
        if (iw % spb != 0) continue;

        prev_spb = spb;
        iw_block = spb;
        sp_block = iw_block;

        select_oc_block();

        use_buffer = maybe_use_buffer;

        const status_t st = estimate_brgemm_ur();
        if (st != status::success) continue;
        os_block = sp_block = iw_block;
        update_blocks();

        eff = est_eff();
        if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
    }
}

status_t brg_blocking_t::calc_blocks() {
    sp = iw;

    nb_oc_blocking = 1;
    // --- Select kernel blocking ---
    // If dst_dt != acc_dt and intermediate results need to be stored, then
    // the output buffer is required.
    const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum);

    std::vector<int> kd_blocks(1), kh_blocks(1);
    kd_blocks[0] = kd;
    kh_blocks[0] = kh;
    if (kd != 1) {
        kd_blocks.resize(2);
        kd_blocks[1] = 1;
    }
    if (kh != 1) {
        kh_blocks.resize(2);
        kh_blocks[1] = 1;
    }

    const auto thr_eff_threshold = 0.9f;
    const auto max_iw_block_thr = utils::saturate(1, iw,
            static_cast<int>(div_up(
                    mb * ngroups * nb_ic * is, thr_eff_threshold * nthr)));

    iw_block = is_block = sp_block = -1;
    brg_blocking_t best_brgb = *this;
    for (const auto &kd_block : kd_blocks) {
        for (const auto &kh_block : kh_blocks) {
            iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer,
                    max_iw_block_thr);
        }
    }
    *this = best_brgb;
    if (sp_block <= 0) return status::unimplemented;

    iw_block = is_block = sp_block;
    iw_tail = iw % iw_block;

    update_blocks();

    return status::success;
}

bool brg_blocking_t::fast_check_ic_block_1x1() const {
    // This check reduces the number of blocking variants to evaluate.
    // TODO: eliminate the heuristic in this function
    if (is_1x1 && is_amx(isa)) return true;
    const auto rnd_ic = rnd_up(ic, 16);
    auto res = false;
    if (ic_block == 64) {
        const auto big_spatial
                = id * ih * iw >= 64 * stride_d * stride_h * stride_w;
        res = (rnd_ic % ic_block == 0 && big_spatial);
    } else if (ic_block == 48) {
        const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block);
        res = (ic_block_eff >= 0.95f);
    } else
        res = true;

    return res;
}

float brg_blocking_t::est_eff_1x1() {
    const auto icblock = ic_block / 16;

    auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float {
        const int nb = dim / block;
        constexpr int max_nb = 2; // only consider 2x2 tile blocking
        const int block2 = nstl::min(max_nb, nb);
        const int nb2 = nb / block2;
        const int nb2_tail = nb % block2;
        if (!use_ave) return block2;
        return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2);
    };
    const bool use_ocb_ave = true;
    const auto icb_ave = calc_ave_blk(ic_block, 16, use_ocb_ave);
    const bool use_spb_ave = false;
    const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave);
    const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0;
    const auto M_tail_n_sp_blks
            = ur_block_tail > 0 ? M_tail / ur_block_tail : 0;

    // heuristic for maskrcnn workaround: use old blocking for some convolutions
    // TODO: remove this condition
    const bool maskrcnn_cond = (ic == 1024 && oc == 2048)
            || (ic == 1024 && oc == 512) || (ic == 256 && oc == 1024)
            || (ic == 512 && oc == 1024) || (ic == 512 && oc == 2048);
    const auto amx_fac = maskrcnn_cond
            ? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks))
            : (static_cast<float>(div_up(M + M_tail, 16))
                    / (M_n_sp_blks + M_tail_n_sp_blks));

    const auto brgemm_microkernel_eff = is_amx(isa)
            ? amx_fac * (static_cast<float>(icb_ave) * spb_ave)
                    / (icb_ave + spb_ave)
            : (static_cast<float>(icblock) * ur) / ((ur + icblock) * max_regs);
    const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
    const auto brgemm_eff = squeeze_val(ur
                    * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
                    / 64,
            0.5f);

    const auto sp_amount = nb_id * nb_ih * nb_sp;
    const auto work_amount = mb * ngroups * nb_ic * sp_amount;

    const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block);
    const auto thr_eff = static_cast<float>(work_amount)
            / utils::rnd_up(work_amount, nthr);
    const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block);

    const auto job = div_up(work_amount, nthr);

    const auto dim_ic = 1;
    const auto nb_ic_thr = nstl::min(nb_ic, div_up(job, dim_ic));
    const auto ic_thr = nstl::min(ic, nb_ic_thr * ic_block);
    const auto nsimd_ic_thr = div_up(ic_thr, simd_w);

    const auto dim_sp = ngroups * nb_ic;
    const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
    const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);

    const auto dim_ih = nb_sp * dim_sp;
    const int nb_ih_thr = nstl::min(nb_ih, div_up(job, dim_ih));
    const int ih_thr = nstl::min(ih, nb_ih_thr * ih_block);

    const auto dim_id = nb_ih * dim_ih;
    const int nb_id_thr = nstl::min(nb_id, div_up(job, dim_id));
    const int id_thr = nstl::min(id, nb_id_thr * id_block);

    auto job_eff = 1.f;
    if (job < nthr) {
        std::vector<dim_t> thr_jobs(nthr);
        for (int ithr = 0; ithr < nthr; ithr++) {
            thr_jobs[ithr] = 0;
            if (ithr >= work_amount) continue;
            dim_t thr_job = 0;
            int start {0}, end {0};
            balance211(work_amount, nthr, ithr, start, end);
            int n {0}, g {0}, icb {0}, idp {0}, ihp {0}, spb {0};
            nd_iterator_init(start, n, mb, idp, id, ihp, ih, spb, nb_sp, g,
                    ngroups, icb, nb_ic);

            for (auto work = start; work < end; work++) {
                const int icp = icb * ic_block;
                const auto ic_sz = nstl::min(ic - icp, ic_block);
                int sp_sz = 0;
                const int spp = spb * sp_block;
                sp_sz = nstl::min(sp - spp, sp_block);
                thr_job += sp_sz * ic_sz;
                nd_iterator_step(n, mb, idp, id, ihp, ih, spb, nb_sp, g,
                        ngroups, icb, nb_ic);
            }
            thr_jobs[ithr] = thr_job;
        }

        dim_t max_job = 0;
        dim_t sum_job = 0;
        for (int ithr = 0; ithr < nthr; ithr++) {
            if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
            sum_job += thr_jobs[ithr];
        }

        job_eff = max_job == 0
                ? 1
                : static_cast<float>(sum_job) / (max_job * nthr);
    } else {
        job_eff = thr_eff;
    }

    const auto oc_blocking_size = oc_block * nb_oc_blocking;
    const auto ic_blocking_size = ic_block * oc_blocking_size;

    int l = -1;
    // -- brgemm kernel: loop by simd_w --
    l++;
    loop[l].src.set(ur * simd_w, 1, bcast_simd);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(ic_block, 1);

    // -- brgemm kernel: loop by ur in sp_block --
    l++;
    const auto nb_ur = div_up(sp_block, ur);
    const auto nb_sp_no_tail = sp / sp_block;
    const auto sp_block_tail = sp % sp_block;
    const auto nb_ur_average
            = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp;
    loop[l].src.set(ur * rnd_simd(oc_blocking_size), 1);
    loop[l].dst.set(ur * ic_block, 1);
    loop[l].wei.set(ic_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur);
    // -- brgemm kernel: loop by oc_chunks --
    l++;
    const auto oc_chunks = div_up(nb_oc, nb_oc_blocking);
    loop[l].src.set(sp_block * oc_blocking_size, 1);
    loop[l].dst.set(sp_block * ic_block, oc_chunks);
    auto wei_is = ic_blocking_size;
    auto wei_op = icblock * oc;
    loop[l].wei.set(wei_is, 1);

    // -- harness: loop by oc_block --
    l++;
    loop[l].src.set(sp_block * rnd_simd(ic), nb_ic_thr);
    loop[l].dst.set(sp_block * ic_block, 1);
    wei_is = ic_block * ic;
    wei_op = nsimd_ic_thr * ic;
    loop[l].wei.set(wei_is, 1);

    const auto rnd_ic_for_sp = simd_w * nsimd_ic_thr;
    // -- harness: loop by sp_blocks --
    l++;
    loop[l].src.set(sp_block * oc_blocking_size, 1);
    loop[l].dst.set(sp_block * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
    // -- harness: loop by oh_blocks --
    l++;
    loop[l].src.set(ih_block * sp_thr * rnd_simd(oc_blocking_size), 1);
    loop[l].dst.set(ih_block * sp_thr * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_ih_thr);
    // -- harness: loop by od_blocks --
    l++;
    loop[l].src.set(id_block * ih_thr * sp_thr * rnd_simd(oc_blocking_size), 1);
    loop[l].dst.set(id_block * ih_thr * sp_thr * rnd_ic_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_id_thr);

    // -- harness: loop by mb --
    l++;
    const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_ic));
    loop[l].src.set(id_thr * ih_thr * sp_thr * rnd_simd(oc_blocking_size), 1);
    loop[l].dst.set(nsimd_ic_thr * simd_w * id_thr * ih_thr * sp_thr, 1);
    loop[l].wei.set(nsimd_ic_thr * oc * simd_w, mb_thr);

    const auto src_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr
            * oc_blocking_size;
    const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_ic_thr * id_thr
            * ih_thr * sp_thr;
    wei_op = nsimd_ic_thr * oc;

    // for a "real" application, bench_iterations is set to 1
    const auto iterations = bench_iterations;
    l++;
    loop[l].src.set(src_op, iterations);
    loop[l].dst.set(dst_op * simd_w, iterations);
    loop[l].wei.set(wei_op * simd_w, iterations);

    auto src_mem_k = mem_k;
    auto dst_mem_k = mem_k;
    auto wei_mem_k = mem_k;
    float src_rp = 1;
    float dst_rp = 1;
    float wei_rp = 1;

    for (auto il = l; il >= 0; il--) {
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true);
        src_rp *= loop[il].src.repeatn;
        dst_rp *= loop[il].dst.repeatn;
        wei_rp *= loop[il].wei.repeatn;
    }
    const auto src_ops = (src_op * src_rp) / iterations;
    const auto dst_ops = (dst_op * dst_rp) / iterations;
    const auto wei_ops = (wei_op * wei_rp) / iterations;

    const auto src_cost = src_mem_k * src_ops;
    const auto dst_cost = dst_mem_k * dst_ops;
    const auto wei_cost = wei_mem_k * wei_ops;
    const auto call_kernel_cost = job * oc_chunks;

    const auto up_sp_size = id * ih;

    const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * oc * ic)
            / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));

    const auto res_eff = ic_block_eff * brgemm_microkernel_eff * sp_eff
            * job_eff * ur_eff * cache_eff * brgemm_eff;
    return res_eff;
}

brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) {
    return attr.zero_points_.has_default_values(arg)
            ? brgemm_broadcast_t::none
            : brgemm_broadcast_t::per_tensor;
}

status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &diff_dst_md,
        memory_desc_t &weights_md, memory_desc_t &diff_src_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads,
        bool enable_postops) {
    using namespace prop_kind;

    brg_blocking_t::L1 = platform::get_per_core_cache_size(1);
    brg_blocking_t::L2 = platform::get_per_core_cache_size(2);
    brg_blocking_t::L3 = platform::get_per_core_cache_size(2);

    if (!mayiuse(avx512_core)) return status::unimplemented;

    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    int ndims = diff_src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.isa = isa;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = diff_src_d.dims()[0];
    jcp.oc_without_padding = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic_without_padding = diff_src_d.dims()[1];
    jcp.ic = jcp.ic_without_padding / jcp.ngroups;
    jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2];
    jcp.iw = diff_src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2];
    jcp.ow = diff_dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

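    // This implementation handles the strided case only: convolutions with
    // all unit strides, or whose input spatial dims are not divisible by the
    // corresponding strides, are rejected here.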
    if (everyone_is(1, jcp.stride_d, jcp.stride_h, jcp.stride_w))
        return status::unimplemented;

    if (jcp.id % jcp.stride_d != 0 || jcp.ih % jcp.stride_h != 0
            || jcp.iw % jcp.stride_w)
        return status::unimplemented;

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    if (!everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w))
        return status::unimplemented;

    jcp.is = jcp.id * jcp.ih * jcp.iw;

    jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);

    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw);

    jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0
            && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0
            && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw);

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.src_dt = diff_dst_md.data_type;
    jcp.dst_dt = diff_src_md.data_type;
    jcp.wei_dt = weights_md.data_type;
    jcp.bia_dt = jcp.with_bias ? bias_md.data_type : data_type::undef;

    jcp.is_bf32 = everyone_is(f32, jcp.src_dt, jcp.wei_dt)
            && attr.fpmath_mode_ == fpmath_mode::bf16 && isa == avx512_core_amx;

    if (jcp.is_bf32) return status::unimplemented;

    brg_blocking_t::last_oc_block_size = data_type_vnni_granularity(jcp.wei_dt);

    // TODO: optimize grouped convolutions with small oc
    const bool is_grouped_small_oc
            = jcp.prop_kind != prop_kind::backward_weights && with_groups
            && jcp.ngroups > 1 && jcp.oc <= 16
            && IMPLICATION(is_amx(jcp.isa),
                    jcp.oc < 16
                            && jcp.ic < 16
                            // already optimized for amx 1x1 convs
                            && !jcp.is_1x1)
            // Enable the shapes not supported in direct convs
            && IMPLICATION(with_groups, is_groups_ok(jcp));
    if (is_grouped_small_oc) return status::unimplemented;

    // Dispatch small shapes to VNNI-based implementations for better performance
    // TODO: optimize the perf of 3d shapes with small oc and large spatial
    const auto max_small_shapes_sz = jcp.is_1x1
            ? static_cast<int32_t>(brg_blocking_t::L1) / 2
            : static_cast<int32_t>(brg_blocking_t::L1);
    const auto is_small_shape = is_amx(jcp.isa) && jcp.is <= 4 && jcp.oc <= 512
            && jcp.mb * jcp.ngroups * jcp.oc * jcp.ic <= max_small_shapes_sz;
    const auto is_3d_small_oc = is_amx(jcp.isa) && jcp.ndims == 5
            && jcp.oc * jcp.ic <= 32 && jcp.id >= 128 && jcp.ih >= 128
            && jcp.iw >= 128;
    if (is_small_shape || is_3d_small_oc) return status::unimplemented;

    jcp.s8s8_avx512 = jcp.src_dt == s8 && !is_amx(jcp.isa);

    if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni)))
        return status::unimplemented;

    if (!IMPLICATION(jcp.wei_dt == bf16, mayiuse(avx512_core_bf16)))
        return status::unimplemented;

    if (!IMPLICATION(jcp.wei_dt == f16, mayiuse(avx512_core_fp16)))
        return status::unimplemented;

    jcp.acc_dt = types::is_integral_dt(jcp.src_dt) ? s32 : f32;

    jcp.src_dsz = types::data_type_size(jcp.src_dt);
    jcp.wei_dsz = types::data_type_size(jcp.wei_dt);
    jcp.dst_dsz = types::data_type_size(jcp.dst_dt);
    jcp.acc_dsz = types::data_type_size(jcp.acc_dt);
    jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    if (!post_ops_ok(jcp, attr, diff_src_d, enable_postops))
        return status::unimplemented;

    jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / jcp.src_dsz;
    jcp.amx_h = 16;
    jcp.amx_w = 64 / jcp.src_dsz;

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));
    }

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    jcp.src_zero_point
            = get_zp_type(attr, DNNL_ARG_DIFF_DST) != brgemm_broadcast_t::none;
    jcp.dst_zero_point
            = get_zp_type(attr, DNNL_ARG_DIFF_SRC) != brgemm_broadcast_t::none;

    const bool has_zero_points = jcp.src_zero_point || jcp.dst_zero_point;
    if (has_zero_points || jcp.s8s8_avx512) return status::unimplemented;

    jcp.nthr = nthreads;
    jcp.kh_sets = 1;
    jcp.kw_sets = 1;
    jcp.copy_block_only = false;
    jcp.amx_tile_load_xx = false;
    jcp.use_M_mask = 0;
    jcp.is_is_blocking = false;
    jcp.oskip = 0;
    jcp.use_uker = false;
    jcp.use_interleave_stores = false;
    jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf_default;
    jcp.brgemm_bd_loop_innermost = false;

    // fast-check the data layout before spending time on blocking selection
    format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);

    CHECK(init_tag(jcp.src_tag, diff_dst_md, diff_dst_d, src_tag));

    return status::success;
}

status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &diff_dst_md,
        memory_desc_t &weights_md, memory_desc_t &diff_src_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads,
        bool enable_postops) {

    using namespace prop_kind;

    if (!mayiuse(isa)) return status::unimplemented;

    CHECK(init_jcp(jcp, isa, cd, diff_dst_md, weights_md, diff_src_md, bias_md,
            attr, nthreads, enable_postops));

    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper bias_d(&bias_md);

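    // Roughly, the *_ovf values are the number of extra diff_dst elements
    // needed beyond each border when gathering the data for a diff_src block;
    // odp/ohp/owp are the resulting padded sizes used for the per-thread
    // diff_dst copy buffer.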
    jcp.l_ovf = nstl::max(0, jcp.ext_kw - 1 - jcp.l_pad) / jcp.stride_w;
    jcp.r_ovf = nstl::max(0, jcp.ext_kw - 1 - jcp.r_pad) / jcp.stride_w;
    jcp.t_ovf = nstl::max(0, jcp.ext_kh - 1 - jcp.t_pad) / jcp.stride_h;
    jcp.b_ovf = nstl::max(0, jcp.ext_kh - 1 - jcp.b_pad) / jcp.stride_h;
    jcp.f_ovf = nstl::max(0, jcp.ext_kd - 1 - jcp.f_pad) / jcp.stride_d;
    jcp.back_ovf = nstl::max(0, jcp.kd - 1 - jcp.back_pad) / jcp.stride_d;

    jcp.odp = jcp.od + jcp.f_ovf + jcp.back_ovf;
    jcp.ohp = jcp.oh + jcp.t_ovf + jcp.b_ovf;
    jcp.owp = jcp.ow + jcp.l_ovf + jcp.r_ovf;

    using namespace data_type;
    // ======================= blocking =================================

    const int min_ic_block = 16;
    int selected_ur = 0;

    //-----------------------------------------------------------------------

    jcp.exec_type = exec_trans;
    jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM

    // TODO: in future use (kd/kh/kw) and (kd/kh/kw)_pad blocks for more
    // precise calculation of jcp.max_batch
    jcp.max_batch = jcp.kd * jcp.kh * jcp.kw;

    jcp.wei_plain = false;

    // always try loop_ndhwgc for exec_trans
    jcp.loop_order = loop_ndhwgc;

    jcp.copy_block_only = true;

    const auto oc_padded_block = 16 * brg_blocking_t::last_oc_block_size;
    jcp.is_oc_padded = one_of(jcp.wei_dt, bf16, s8)
            && jcp.oc * jcp.kw_sets > oc_padded_block;

    if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.iw < 256)) {
        jcp.use_M_mask = 0;

        jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;

        // assuming 2x2 decomposition in amx brgemm kernel
        // and overlap of input by kw
        const auto bd_blocking = 2 * jcp.amx_h;
        const auto ld_blocking = 2 * 16;
        const auto A_ds = jcp.src_dsz * bd_blocking * jcp.oc * jcp.kd * jcp.kh;
        const auto B_ds
                = jcp.wei_dsz * ld_blocking * jcp.oc * jcp.kd * jcp.kh * jcp.kw;
        const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking;
        if (A_ds + B_ds + C_ds > brg_blocking_t::L1)
            jcp.amx_tile_load_xx = true;
    }

    auto try_exec_type = [&]() {
        brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
        best_brgb.ic_block = min_ic_block;
        brg_blocking_t cur_brgb = zero<decltype(best_brgb)>();
        cur_brgb.get_from_jcp(jcp);
        const auto start_icb = nstl::min(div_up(jcp.ic, 16), 4);

        auto finish_icb = 1;
        for (auto icb = start_icb; icb >= finish_icb; icb--) {
            cur_brgb.ic_block = icb * 16;
            cur_brgb.nb_ic = utils::div_up(jcp.ic, cur_brgb.ic_block);
            if (!cur_brgb.fast_check_ic_block()) continue;

            const status_t blocking_ok = cur_brgb.calc_blocks();
            if (blocking_ok != status::success) continue;

            const status_t st = cur_brgb.get_brgemm_ur(&attr, diff_src_md);
            if (st != status::success) continue;
            cur_brgb.eff = cur_brgb.est_eff();
            if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
        }
        if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0
                || best_brgb.iw_block == 0)
            return false;
        best_brgb.save_to_jcp(jcp);
        selected_ur = best_brgb.ur;
        return true;
    };

    if (!try_exec_type()) return status::unimplemented;

    // ============ end blocking ===========================================
    jcp.max_vpad = 0;

    if (jcp.iw_block == 0 || jcp.oc_block == 0 || jcp.ic_block == 0)
        return status::unimplemented;

    jcp.gemm_batch_size = jcp.nb_oc_blocking
            * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block,
                    jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad);
    // round up to a page to avoid concurrent cache write access from
    // different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jcp.adjusted_batch_size
            = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size);

    CHECK(pick_tags(jcp, diff_dst_md, weights_md, diff_src_md, bias_md));
    CHECK(attr.set_default_formats(&diff_src_md));

    jcp.buffer_size = jcp.LDC * (jcp.M > 0 ? jcp.M : jcp.M_tail);

    jcp.nb_id = div_up(jcp.id, jcp.id_block);
    jcp.nb_ih = div_up(jcp.ih, jcp.ih_block);

    jcp.inp_buffer_size = rnd_up(jcp.odp * jcp.ohp * jcp.owp * jcp.ngroups
                    * jcp.nb_oc * jcp.oc_block,
            P4K);
    jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_id) * jcp.nb_ih
                    * jcp.nb_iw * jcp.ngroups * jcp.nb_oc,
            P4K);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    const bool with_pad = jcp.f_pad > 0 || jcp.back_pad > 0 || jcp.t_pad > 0
            || jcp.b_pad > 0 || jcp.l_pad > 0 || jcp.r_pad > 0;

    if (jcp.s8s8_avx512) {
        weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8;
        weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1;
    }
    if (jcp.src_zero_point && !is_amx(jcp.isa)) {
        weights_md.extra.flags
                |= memory_extra_flags::compensation_conv_asymmetric_src;
        weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1;
    }

    // For shapes with padding, the compensation is calculated along with the
    // main computation inside the brgemm kernel when the output size is small
    // (for optimal performance); otherwise it is calculated with the
    // brgemm_comp_pad kernel.
    const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.ic
            * jcp.id * jcp.ih * jcp.iw;
    const auto comp_with_pads = (jcp.src_zero_point || jcp.s8s8_avx512)
            && IMPLICATION(jcp.exec_type == exec_vpad, with_pad);
    jcp.req_brg_comp_pad = comp_with_pads && output_sz <= 8192 && jcp.ic < 512;
    jcp.req_cal_comp_pad = comp_with_pads && !jcp.req_brg_comp_pad;

    // estimate the number of kernel range combinations for compensation
    const auto kd_cnt = 1 + utils::div_up(abs(jcp.f_pad), jcp.dilate_d + 1)
            + utils::div_up(abs(jcp.back_pad), jcp.dilate_d + 1);
    const auto kh_cnt = 1 + utils::div_up(abs(jcp.t_pad), jcp.dilate_h + 1)
            + utils::div_up(abs(jcp.b_pad), jcp.dilate_h + 1);

    jcp.ker_ranges_size = kd_cnt * kh_cnt;
    jcp.comp_a_buffer_size = static_cast<dim_t>(jcp.ngroups) * jcp.nb_ic
            * jcp.ker_ranges_size * jcp.iw * jcp.ic_block;
    jcp.s8s8_comp_buffer_size = jcp.comp_a_buffer_size;

    return status::success;
}

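// Books the scratchpad buffers used at execution time: the brgemm batch
// element array, the padded diff_dst copy buffer and its block mask, the
// intermediate accumulation buffer when one is needed, per-thread AMX tile
// scratch space, and the compensation buffers for the s8s8 and source
// zero-point cases.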
void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const jit_brgemm_conv_conf_t &jcp) {
    if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs
            || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad))
        scratchpad.book(key_brgemm_primitive_batch,
                static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
                sizeof(brgemm_batch_element_t), 64, P4K);

    size_t inp_buffer_size
            = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size;
    scratchpad.book(
            key_conv_brgemm_inp_buffer, inp_buffer_size, jcp.src_dsz, 0, P4K);
    size_t inp_buffer_mask_size
            = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size;
    scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size,
            sizeof(uint8_t), 0, P4K);

    if (jcp.use_buffer) {
        scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size,
                jcp.acc_dsz, 0, P4K);
    }
    if (is_amx(jcp.isa)) {
        scratchpad.book(key_conv_amx_tile_buffer, jcp.nthr * 2 * P4K,
                sizeof(char), 0, P4K);
    }
    if (jcp.s8s8_avx512 && jcp.req_cal_comp_pad) {
        scratchpad.book(key_brgemm_primitive_buffer_comp,
                jcp.s8s8_comp_buffer_size, sizeof(int32_t), 0, P4K);
    }
    if (jcp.src_zero_point && jcp.req_cal_comp_pad && !is_amx(jcp.isa)) {
        scratchpad.book(key_brgemm_primitive_zp_comp_a, jcp.comp_a_buffer_size,
                sizeof(int32_t), 0, P4K);
    }
}

} // namespace brgemm_convolution_bwd_utils

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl