1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_types.h"
18
19#include "common/bfloat16.hpp"
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/type_helpers.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/platform.hpp"
27#include "cpu/scale_utils.hpp"
28#include "cpu/x64/brgemm/brgemm_utils.hpp"
29#include "cpu/x64/cpu_barrier.hpp"
30#include "cpu/x64/cpu_isa_traits.hpp"
31#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
32#include "cpu/x64/jit_brgemm_conv_utils.hpp"
33#include "cpu/x64/jit_generator.hpp"
34
35namespace dnnl {
36namespace impl {
37namespace cpu {
38namespace x64 {
39
40using namespace dnnl::impl::status;
41using namespace dnnl::impl::format_tag;
42using namespace dnnl::impl::memory_tracking::names;
43using namespace dnnl::impl::utils;
44
45using namespace prop_kind;
46using namespace data_type;
47
48namespace brgemm_convolution_utils {
49
50bool is_any_eligible(const jit_brgemm_conv_conf_t &jcp) {
51 return (jcp.prop_kind == prop_kind::forward_inference
52 || one_of(jcp.wei_dt, data_type::s8, data_type::f16)
53 || (jcp.isa == avx2_vnni_2) || is_amx(jcp.isa));
54}
55
56inline status_t init_tag(format_tag_t &tag, memory_desc_t &md,
57 const memory_desc_wrapper &mdw, const format_tag_t tag_value,
58 bool any_eligible) {
59
60 if (mdw.format_kind() == format_kind::any) {
61 if (any_eligible) {
62 CHECK(memory_desc_init_by_tag(md, tag_value));
63 tag = tag_value;
64 } else {
65 tag = format_tag::undef;
66 }
67 } else {
68 tag = mdw.matches_one_of_tag(tag_value);
69 }
70
71 if (tag != tag_value) return status::unimplemented;
72
73 return status::success;
74}
75
76bool is_amx(cpu_isa_t isa) {
77 return is_superset(isa, avx512_core_amx);
78}
79
80bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr,
81 const memory_desc_wrapper &dst_d) {
82 using namespace injector;
83
84 const auto &post_ops = attr.post_ops_;
85
86 return injector::post_ops_ok(post_ops_ok_args_t(jcp.isa,
87 {sum, eltwise, binary}, post_ops, &dst_d,
88 false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
89 false /*sum_requires_zp_zero*/,
90 {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar,
91 broadcasting_strategy_t::no_broadcast}));
92}
93
94bool is_groups_ok(jit_brgemm_conv_conf_t &jcp) {
95 // Enable grouped convs for the shapes not supported in direct convs
96 // direct approach only supports int8/bf16 grouped conv
97 // when channels per groups is at least multiple of 4
98 // and bf16 grouped conv with layout nxc on jit_bf16 impl
99 // TODO: remove this condition after the restriction on small ic is removed
100 return jcp.ngroups > 1
101 && IMPLICATION(one_of(jcp.src_dt, u8, s8, bf16),
102 jcp.ic % 4 == 0 && jcp.oc % 4 == 0);
103}
104
105status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md,
106 memory_desc_t &weights_md, memory_desc_t &dst_md,
107 memory_desc_t &bias_md) {
108 format_tag_t src_tag, dst_tag, wei_tag;
109 dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);
110
111 const memory_desc_wrapper src_d(&src_md);
112 const memory_desc_wrapper weights_d(&weights_md);
113 const memory_desc_wrapper dst_d(&dst_md);
114 const memory_desc_wrapper bias_d(&bias_md);
115 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
116 const int vnni_granularity
117 = (jcp.wei_dt == f16 && jcp.isa == avx512_core_fp16)
118 ? 1
119 : data_type_vnni_granularity(jcp.wei_dt);
120
121 const bool is_1d = jcp.ndims == 3;
122 const bool is_2d = jcp.ndims == 4;
123 const bool is_3d = jcp.ndims == 5;
124
125 if (jcp.wei_plain) {
126 jcp.LDB = jcp.oc;
127 if (is_3d) {
128 switch (vnni_granularity) {
129 case 1: wei_tag = with_groups ? gdhwio : dhwio; break;
130 case 2: wei_tag = with_groups ? gdhwIo2i : dhwIo2i; break;
131 case 4: wei_tag = with_groups ? gdhwIo4i : dhwIo4i; break;
132 default: return status::unimplemented;
133 }
134 } else if (is_1d) {
135 switch (vnni_granularity) {
136 case 1: wei_tag = with_groups ? gwio : wio; break;
137 case 2: wei_tag = with_groups ? gwIo2i : wIo2i; break;
138 case 4: wei_tag = with_groups ? gwIo4i : wIo4i; break;
139 default: return status::unimplemented;
140 }
141 } else {
142 assert(is_2d);
143 UNUSED(is_2d);
144 switch (vnni_granularity) {
145 case 1: wei_tag = with_groups ? ghwio : hwio; break;
146 case 2: wei_tag = with_groups ? ghwIo2i : hwIo2i; break;
147 case 4: wei_tag = with_groups ? ghwIo4i : hwIo4i; break;
148 default: return status::unimplemented;
149 }
150 }
151 } else {
152 jcp.LDB = jcp.oc_block;
153 if (jcp.oc_block == 64) {
154 if (is_3d) {
155 switch (vnni_granularity) {
156 case 1: wei_tag = with_groups ? gOdhwi64o : Odhwi64o; break;
157 case 2:
158 if (jcp.is_ic_padded)
159 wei_tag = with_groups ? gOdhwI16i64o2i
160 : OdhwI16i64o2i;
161 else
162 wei_tag = with_groups ? gOdhwI64o2i : OdhwI64o2i;
163 break;
164 case 4:
165 if (jcp.is_ic_padded)
166 wei_tag = with_groups ? gOdhwI16i64o4i
167 : OdhwI16i64o4i;
168 else
169 wei_tag = with_groups ? gOdhwI64o4i : OdhwI64o4i;
170 break;
171 default: return status::unimplemented;
172 }
173 } else if (is_1d) {
174 switch (vnni_granularity) {
175 case 1: wei_tag = with_groups ? gOwi64o : Owi64o; break;
176 case 2:
177 if (jcp.is_ic_padded)
178 wei_tag = with_groups ? gOwI16i64o2i : OwI16i64o2i;
179 else
180 wei_tag = with_groups ? gOwI64o2i : OwI64o2i;
181 break;
182 case 4:
183 if (jcp.is_ic_padded)
184 wei_tag = with_groups ? gOwI16i64o4i : OwI16i64o4i;
185 else
186 wei_tag = with_groups ? gOwI64o4i : OwI64o4i;
187 break;
188 default: return status::unimplemented;
189 }
190 } else {
191 assert(is_2d);
192 UNUSED(is_2d);
193 switch (vnni_granularity) {
194 case 1: wei_tag = with_groups ? gOhwi64o : Ohwi64o; break;
195 case 2:
196 if (jcp.is_ic_padded)
197 wei_tag = with_groups ? gOhwI16i64o2i
198 : OhwI16i64o2i;
199 else
200 wei_tag = with_groups ? gOhwI64o2i : OhwI64o2i;
201 break;
202 case 4:
203 if (jcp.is_ic_padded)
204 wei_tag = with_groups ? gOhwI16i64o4i
205 : OhwI16i64o4i;
206 else
207 wei_tag = with_groups ? gOhwI64o4i : OhwI64o4i;
208 break;
209 default: return status::unimplemented;
210 }
211 }
212 } else if (jcp.oc_block == 48) {
213 if (is_3d) {
214 switch (vnni_granularity) {
215 case 1: wei_tag = with_groups ? gOdhwi48o : Odhwi48o; break;
216 case 2:
217 if (jcp.is_ic_padded)
218 wei_tag = with_groups ? gOdhwI16i48o2i
219 : OdhwI16i48o2i;
220 else
221 wei_tag = with_groups ? gOdhwI48o2i : OdhwI48o2i;
222 break;
223 case 4:
224 if (jcp.is_ic_padded)
225 wei_tag = with_groups ? gOdhwI16i48o4i
226 : OdhwI16i48o4i;
227 else
228 wei_tag = with_groups ? gOdhwI48o4i : OdhwI48o4i;
229 break;
230 default: return status::unimplemented;
231 }
232 } else if (is_1d) {
233 switch (vnni_granularity) {
234 case 1: wei_tag = with_groups ? gOwi48o : Owi48o; break;
235 case 2:
236 if (jcp.is_ic_padded)
237 wei_tag = with_groups ? gOwI16i48o2i : OwI16i48o2i;
238 else
239 wei_tag = with_groups ? gOwI48o2i : OwI48o2i;
240 break;
241 case 4:
242 if (jcp.is_ic_padded)
243 wei_tag = with_groups ? gOwI16i48o4i : OwI16i48o4i;
244 else
245 wei_tag = with_groups ? gOwI48o4i : OwI48o4i;
246 break;
247 default: return status::unimplemented;
248 }
249 } else {
250 assert(is_2d);
251 UNUSED(is_2d);
252 switch (vnni_granularity) {
253 case 1: wei_tag = with_groups ? gOhwi48o : Ohwi48o; break;
254 case 2:
255 if (jcp.is_ic_padded)
256 wei_tag = with_groups ? gOhwI16i48o2i
257 : OhwI16i48o2i;
258 else
259 wei_tag = with_groups ? gOhwI48o2i : OhwI48o2i;
260 break;
261 case 4:
262 if (jcp.is_ic_padded)
263 wei_tag = with_groups ? gOhwI16i48o4i
264 : OhwI16i48o4i;
265 else
266 wei_tag = with_groups ? gOhwI48o4i : OhwI48o4i;
267 break;
268 default: return status::unimplemented;
269 }
270 }
271 } else if (jcp.oc_block == 32) {
272 if (is_3d) {
273 switch (vnni_granularity) {
274 case 1: wei_tag = with_groups ? gOdhwi32o : Odhwi32o; break;
275 case 2:
276 if (jcp.is_ic_padded)
277 wei_tag = with_groups ? gOdhwI16i32o2i
278 : OdhwI16i32o2i;
279 else
280 wei_tag = with_groups ? gOdhwI32o2i : OdhwI32o2i;
281 break;
282 case 4:
283 if (jcp.is_ic_padded)
284 wei_tag = with_groups ? gOdhwI16i32o4i
285 : OdhwI16i32o4i;
286 else
287 wei_tag = with_groups ? gOdhwI32o4i : OdhwI32o4i;
288 break;
289 default: return status::unimplemented;
290 }
291 } else if (is_1d) {
292 switch (vnni_granularity) {
293 case 1: wei_tag = with_groups ? gOwi32o : Owi32o; break;
294 case 2:
295 if (jcp.is_ic_padded)
296 wei_tag = with_groups ? gOwI16i32o2i : OwI16i32o2i;
297 else
298 wei_tag = with_groups ? gOwI32o2i : OwI32o2i;
299 break;
300 case 4:
301 if (jcp.is_ic_padded)
302 wei_tag = with_groups ? gOwI16i32o4i : OwI16i32o4i;
303 else
304 wei_tag = with_groups ? gOwI32o4i : OwI32o4i;
305 break;
306 default: return status::unimplemented;
307 }
308 } else {
309 assert(is_2d);
310 UNUSED(is_2d);
311 switch (vnni_granularity) {
312 case 1: wei_tag = with_groups ? gOhwi32o : Ohwi32o; break;
313 case 2:
314 if (jcp.is_ic_padded)
315 wei_tag = with_groups ? gOhwI16i32o2i
316 : OhwI16i32o2i;
317 else
318 wei_tag = with_groups ? gOhwI32o2i : OhwI32o2i;
319 break;
320 case 4:
321 if (jcp.is_ic_padded)
322 wei_tag = with_groups ? gOhwI16i32o4i
323 : OhwI16i32o4i;
324 else
325 wei_tag = with_groups ? gOhwI32o4i : OhwI32o4i;
326 break;
327 default: return status::unimplemented;
328 }
329 }
330 } else if (jcp.oc_block == 16) {
331 if (is_3d) {
332 switch (vnni_granularity) {
333 case 1: wei_tag = with_groups ? gOdhwi16o : Odhwi16o; break;
334 case 2:
335 if (jcp.is_ic_padded)
336 wei_tag = with_groups ? gOdhwI16i16o2i
337 : OdhwI16i16o2i;
338 else
339 wei_tag = with_groups ? gOdhwI16o2i : OdhwI16o2i;
340 break;
341 case 4:
342 if (jcp.is_ic_padded)
343 wei_tag = with_groups ? gOdhwI16i16o4i
344 : OdhwI16i16o4i;
345 else
346 wei_tag = with_groups ? gOdhwI16o4i : OdhwI16o4i;
347 break;
348 default: return status::unimplemented;
349 }
350 } else if (is_1d) {
351 switch (vnni_granularity) {
352 case 1: wei_tag = with_groups ? gOwi16o : Owi16o; break;
353 case 2:
354 if (jcp.is_ic_padded)
355 wei_tag = with_groups ? gOwI16i16o2i : OwI16i16o2i;
356 else
357 wei_tag = with_groups ? gOwI16o2i : OwI16o2i;
358 break;
359 case 4:
360 if (jcp.is_ic_padded)
361 wei_tag = with_groups ? gOwI16i16o4i : OwI16i16o4i;
362 else
363 wei_tag = with_groups ? gOwI16o4i : OwI16o4i;
364 break;
365 default: return status::unimplemented;
366 }
367 } else {
368 assert(is_2d);
369 UNUSED(is_2d);
370
371 switch (vnni_granularity) {
372 case 1: wei_tag = with_groups ? gOhwi16o : Ohwi16o; break;
373 case 2:
374 if (jcp.is_ic_padded)
375 wei_tag = with_groups ? gOhwI16i16o2i
376 : OhwI16i16o2i;
377 else
378 wei_tag = with_groups ? gOhwI16o2i : OhwI16o2i;
379 break;
380 case 4:
381 if (jcp.is_ic_padded)
382 wei_tag = with_groups ? gOhwI16i16o4i
383 : OhwI16i16o4i;
384 else
385 wei_tag = with_groups ? gOhwI16o4i : OhwI16o4i;
386 break;
387 default: return status::unimplemented;
388 }
389 }
390 } else if (jcp.oc_block == 8) {
391 if (vnni_granularity == 1)
392 wei_tag = with_groups ? gOhwi8o : Ohwi8o;
393 else
394 return status::unimplemented;
395 } else {
396 return status::unimplemented;
397 }
398 }
399
400 src_tag = dst_tag;
401
402 const bool any_eligible = is_any_eligible(jcp);
403 CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible));
404 CHECK(init_tag(jcp.dst_tag, dst_md, dst_d, dst_tag, any_eligible));
405 CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag, true));
406
407 return status::success;
408}
409
410struct brg_blocking_t : public jit_brgemm_conv_conf_t {
411 struct array_in_loop_t {
412 dim_t itersize;
413 float repeatn;
414 float overlap;
415 void set(dim_t iter_s, float rpt, float ovlp = 1.f) {
416 itersize = iter_s;
417 repeatn = rpt;
418 overlap = ovlp;
419 }
420 };
421
422 struct loop_t {
423 array_in_loop_t src;
424 array_in_loop_t wei;
425 array_in_loop_t dst;
426 };
427
428 brg_blocking_t() {
429 // TODO: This is a broken form of initialization for a base class.
430 // Either set default values in a base class, or provide a proper
431 // default ctor, or take a `jit_brgemm_conv_conf_t` object to initialize
432 // a base class object.
433 jit_brgemm_conv_conf_t *base
434 = static_cast<jit_brgemm_conv_conf_t *>(this);
435 *base = jit_brgemm_conv_conf_t();
436 init();
437 }
438 brg_blocking_t(const jit_brgemm_conv_conf_t &jcp)
439 : jit_brgemm_conv_conf_t(jcp) {
440 init();
441 }
442 void init() {
443 ur = 0;
444 ur_block = 0;
445 ur_block_tail = 0;
446 eff = 0.f;
447 nb_kd = 0;
448 nb_kh = 0;
449 nb_kw = 0;
450 sp = 0;
451 sp_block = 0;
452 nb_sp = 0;
453 eff = 0;
454 max_regs = isa_num_vregs(isa);
455 bcast_simd = acc_simd_w;
456 }
457
458 int ur, ur_block, ur_block_tail;
459 int nb_kd, nb_kh, nb_kw;
460 int max_regs;
461 int bcast_simd;
462 float eff;
463 static unsigned L1;
464 static unsigned L2;
465 static unsigned L3;
466 // These are rough estimates of the latency (relative) of access to various
467 // cache levels. This is enough for an estimation of data access cost.
468 // TODO: Improve memory access estimates
469 static constexpr float L1_k = 1.f;
470 static constexpr float L2_k = 3.f;
471 static constexpr float L3_k = 15.f;
472 // TODO: At the moment, we are primarily evaluating the fit of the data into
473 // the L1/L2. Need to take into account the difference between the L3 and
474 // memory.
475 static constexpr float mem_k = 15.f;
476 static constexpr int bench_iterations = 1;
477
478 int sp, sp_block, nb_sp;
479 static int last_ic_block_size;
480
481 void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
482 void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }
483
484 status_t estimate_brgemm_ur();
485 status_t get_brgemm_ur(
486 const primitive_attr_t *attr, const memory_desc_t &dst_md);
487
488 float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
489 bool is_broadcast, bool is_shared) const;
490
491 float io_k(const loop_t loop, const array_in_loop_t arr, float pk,
492 bool is_broadcast, bool is_shared) const;
493
494 void select_ic_block();
495
496 void update_blocks();
497 bool fast_check_oc_block() const;
498 float est_eff();
499 void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block,
500 int kh_block, bool maybe_use_buffer, int max_ow_block_thr);
501 status_t calc_blocks();
502
503 bool fast_check_oc_block_1x1() const;
504 float est_eff_1x1();
505 void calc_blocks_1x1();
506
507 // utils
508 static int get_inp_size(
509 int max_src_size, int dst_size, int k, int stride, int dilate) {
510 auto adj_str = nstl::min(k, stride);
511 const auto res = nstl::min(max_src_size,
512 calculate_end_padding(0, dst_size, 0, adj_str,
513 calculate_extended_filter_size(k, dilate)));
514 return res;
515 }
516
517 static float squeeze_val(float eff, float koeff) {
518 if (koeff <= 0) return 1;
519 if (koeff == 1) return eff;
520 const auto k = 1.f / koeff;
521 return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff;
522 }
523
524 static int estimate_ur(int oc_block) {
525 const auto est_ur = (oc_block == 64)
526 ? 6
527 : ((oc_block == 48) ? 9 : ((oc_block == 32) ? 14 : 28));
528 return est_ur;
529 }
530
531 int inp_w(int out_w, int ker_w) const {
532 return get_inp_size(iw, out_w, ker_w, stride_w, dilate_w);
533 }
534
535 int rnd_simd(int val) const { return rnd_up(val, simd_w); }
536
537 int rnd_inp_simd(int out_w, int ker_w, int vic) const {
538 const auto vsp = inp_w(out_w, ker_w);
539 return ((stride_w == 1 && vic >= ic) ? rnd_up(vsp * vic, simd_w)
540 : vsp * rnd_up(vic, simd_w));
541 }
542
543 static constexpr int MAXNLOOPS = 32;
544 loop_t loop[MAXNLOOPS];
545};
546
547unsigned brg_blocking_t::L1;
548unsigned brg_blocking_t::L2;
549unsigned brg_blocking_t::L3;
550int brg_blocking_t::last_ic_block_size;
551
552float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
553 bool is_broadcast, bool is_shared) const {
554 if (n < 1) return 0;
555 if (n == 1) return pk;
556 const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz
557 + (use_buffer ? dst * acc_dsz : 0);
558 const auto amount_L1 = is_broadcast ? src * src_dsz : amount;
559 const auto k = is_broadcast
560 ? ((amount_L1 < L1) ? L1_k
561 : ((amount < L2) ? L2_k
562 : (is_shared ? L3_k : mem_k)))
563 : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k));
564 const auto cost = pk + k * (n - 1);
565 return cost / n;
566}
567
568float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
569 float pk, bool is_broadcast, bool is_shared) const {
570 return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize,
571 arr.repeatn * arr.overlap, pk, is_broadcast, is_shared);
572}
573
574void brg_blocking_t::select_ic_block() {
575 if (is_1x1 && is_amx(isa)) {
576 // TODO: merge with non-1x1 code block below
577 const int ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
578 assert(IMPLICATION(
579 !is_bf32, ic < ic_padded_block || ic % ic_padded_block == 0));
580 MAYBE_UNUSED(ic_padded_block);
581 // Note: bf32 requires ic_block be less than 64, otherwise it results
582 // in incorrect output.
583 ic_block = is_bf32 && (!is_rtus) ? nstl::min(64, ic) : ic;
584 nb_ic = utils::div_up(ic, ic_block); // trivially 1 for now
585 return;
586 }
587 auto nb_simd = utils::div_up(ic, simd_w);
588 auto max_simd_blocks = nstl::min(5 * simd_w, nb_simd);
589 const auto nb_icb_eff_threshold = 0.5f;
590 const auto padded_ic = last_ic_block_size * (is_ic_padded ? acc_simd_w : 1);
591 if (is_amx(isa)) {
592 if (ic * kw_sets < simd_w) {
593 // this is current requirement from brgemm kernel
594 ic_block = rnd_up(ic, last_ic_block_size);
595 } else if (is_bf32) {
596 ic_block = simd_w;
597 } else {
598 if (exec_type == exec_trans) {
599 auto simd_blocks = 1;
600 for (int nb_icb = max_simd_blocks; nb_icb >= 1; nb_icb--) {
601 auto nb_icb_eff = static_cast<float>(nb_simd)
602 / rnd_up(nb_simd, nb_icb);
603 if (nb_icb_eff >= nb_icb_eff_threshold) {
604 simd_blocks = nb_icb;
605 break;
606 }
607 }
608 ic_block = simd_blocks * simd_w;
609 } else
610 ic_block = simd_w;
611 }
612 } else {
613 const auto est_ur = nstl::min(sp_block, estimate_ur(oc_block));
614 const auto inp_ur = is_os_blocking ? est_ur : inp_w(est_ur, kw_block);
615
616 if (kw_block > 1) {
617 // try to fit src into L1
618 const auto inp_per_ic = static_cast<unsigned int>(inp_ur) * src_dsz;
619 max_simd_blocks = saturate(1, max_simd_blocks,
620 static_cast<int>(L1 / (inp_per_ic * simd_w)));
621 }
622 // try to fit all batch for ur into L2
623 const auto wei_per_ic = static_cast<unsigned int>(kd_block) * kh_block
624 * kw_block * oc_block * wei_dsz;
625 const auto inp_per_ic = static_cast<unsigned int>(kd_block) * kh_block
626 * inp_ur * src_dsz;
627 const auto out_size
628 = static_cast<unsigned int>(ur) * oc_block * dst_dsz;
629
630 max_simd_blocks = saturate(1, max_simd_blocks,
631 static_cast<int>((L2 - out_size)
632 / ((wei_per_ic + inp_per_ic) * simd_w)));
633
634 auto simd_blocks = 1;
635 for (int nb_icb = nstl::min(max_simd_blocks, nb_simd); nb_icb >= 1;
636 nb_icb--) {
637 auto nb_icb_eff
638 = static_cast<float>(nb_simd) / rnd_up(nb_simd, nb_icb);
639 if (nb_icb_eff >= nb_icb_eff_threshold) {
640 simd_blocks = nb_icb;
641 break;
642 }
643 }
644
645 ic_block = nstl::min(
646 (exec_type == exec_trans) ? rnd_up(ic, padded_ic) : ic,
647 simd_blocks * simd_w);
648 }
649 nb_ic = utils::div_up(ic, ic_block);
650}
651
652status_t brg_blocking_t::estimate_brgemm_ur() {
653 // Simple simulation of brgemm_desc init
654 if (sp_block <= 0) return status::invalid_arguments;
655 LDA = is_rtus
656 ? (ic_block)
657 : (kh_sets > 1 ? kh_sets : 1) * (kw_sets > 1 ? kw_sets : stride_w)
658 * (exec_type == exec_trans ? ic_block
659 : ngroups * ic_without_padding);
660 LDB = oc_block;
661 LDC = use_buffer ? oc_block : oc_without_padding;
662
663 // Configure matrix sizes
664 // for amx if ic_block != ic then we use exec_trans so K is ic_block
665 const auto padded_ic = last_ic_block_size * (is_ic_padded ? acc_simd_w : 1);
666
667 icp = rnd_up(ic, padded_ic);
668 M = brgM = sp >= sp_block ? sp_block : 0;
669 M_tail = brgM_tail = sp % sp_block;
670 if (is_os_blocking) {
671 if (!is_1x1) M_tail = brgM_tail = (oh * ow) % sp_block;
672 oskip = ((ext_kw - 1) / stride_w) * stride_h + (stride_h - 1) * ow;
673
674 brgM = sp_block + oskip * (div_up(M, ow) - 1);
675
676 // round up brgM to help brgemm kernel use max amx_h as brgemm bd_block
677 if (use_M_mask == 2) {
678 int ibrgM = 0;
679 const auto adj_ow = ow_block + oskip;
680 while (ibrgM < brgM) {
681 if (ibrgM % adj_ow < ow_block)
682 ibrgM += amx_h;
683 else
684 ibrgM++;
685 }
686 brgM = ibrgM;
687 } else
688 brgM = rnd_up(brgM, amx_h);
689
690 brgM_tail = brgM;
691 }
692
693 N = oc >= oc_block ? oc_block : 0;
694 N_tail = oc % oc_block;
695
696 K = kh_sets * kw_sets * (ic >= ic_block ? ic_block : 0);
697 K_tail = kh_sets * kw_sets
698 * (exec_type == exec_trans && (!is_bf32)
699 ? ic_block
700 : rnd_up(ic % ic_block, last_ic_block_size));
701
702 const auto vK = K > 0 ? K : K_tail;
703 const auto vM = M > 0 ? M : M_tail;
704 const auto vN = N > 0 ? N : N_tail;
705
706 const float alpha = 1.0;
707 const float beta = 0.0;
708 brgemm_t brg;
709 brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
710 brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr,
711 is_bf32);
712 CHECK(brgemm_utils::brgemm_blocking(&brg));
713 ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1);
714 ur_block = brg.bd_block;
715 if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) {
716 brgemm_t brg_sp_tail;
717 brgemm_utils::init_brgemm_conf(&brg_sp_tail, isa, brgemm_addr, src_dt,
718 wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, M_tail,
719 vN, vK, nullptr, is_bf32);
720 CHECK(brgemm_utils::brgemm_blocking(&brg_sp_tail));
721 ur_block_tail = brg_sp_tail.bd_block;
722 } else {
723 ur_block_tail = 0;
724 }
725 return status::success;
726}
727
728status_t brg_blocking_t::get_brgemm_ur(
729 const primitive_attr_t *attr, const memory_desc_t &dst_md) {
730 // Detailed simulation of brgemm convolution init
731 if (sp_block <= 0 || ic_block <= 0 || oc_block <= 0)
732 return status::invalid_arguments;
733 CHECK(estimate_brgemm_ur());
734
735 LDD = oc_without_padding;
736
737 const float alpha = 1.0;
738 const float beta = 1.0;
739 const float beta_init = 0.0;
740
741 for (int i = 0; i < M; i++) {
742 auto vM = i + 1;
743 // init only needed brgemm descriptors
744 if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1)
745 && vM != M && vM != M_tail)
746 continue;
747 for (int i_init = 0; i_init < 2; i_init++) {
748 for (int i_N = 0; i_N < 2; i_N++) {
749 for (int i_K = 0; i_K < 2; i_K++) {
750 auto vbeta = (i_init) ? beta_init : beta;
751 auto vN = (i_N) ? N_tail : N;
752 auto vK = (i_K) ? K_tail : K;
753 if (vN == 0 || vK == 0) continue;
754 brgemm_t brg;
755 brgemm_strides_t brg_strides;
756 brg_strides.stride_a = ngroups * ic_without_padding
757 * (dilate_w + 1) * src_dsz;
758 //weights are padded by oc_block and last_ic_block
759 brg_strides.stride_b = rnd_up(ic, last_ic_block_size)
760 * rnd_up(oc, oc_block) * wei_dsz;
761 const auto strides_ptr = (brg_type == brgemm_strd)
762 ? &brg_strides
763 : nullptr;
764 brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt,
765 wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB,
766 LDC, vM, vN, vK, strides_ptr, is_bf32);
767 CHECK(brgemm_utils::brgemm_blocking(&brg));
768
769 brgemm_attr_t brgattr;
770 brgattr.max_bs = max_batch;
771 const auto max_vpad = (exec_type == exec_vpad)
772 ? nstl::max(l_pad, r_pad)
773 : 0;
774 brgattr.max_top_vpad = max_vpad;
775 brgattr.max_bottom_vpad = max_vpad;
776 brgattr.fpmath_mode = attr->fpmath_mode_;
777 CHECK(brgemm_desc_set_attr(&brg, brgattr));
778
779 brg.with_sum = with_sum;
780 CHECK(brgemm_desc_set_postops(
781 &brg, attr, &dst_md, LDD, bia_dt));
782 }
783 }
784 }
785 }
786
787 return status::success;
788}
789
790void brg_blocking_t::update_blocks() {
791 if (sp_block <= 0
792 || utils::one_of(0, od_block, oh_block, ic_block, oc_block,
793 kd_block, kh_block, kw_block, os_block, ow_block))
794 return;
795
796 nb_od = div_up(od, od_block);
797 nb_oh = div_up(oh, oh_block);
798 nb_ic = div_up(ic, ic_block);
799 nb_oc = div_up(oc, oc_block);
800 nb_kd = div_up(kd, kd_block);
801 nb_kh = div_up(kh, kh_block);
802 nb_kw = div_up(kw, kw_block);
803 nb_ow = div_up(ow, ow_block);
804 if (is_os_blocking) {
805 nb_os = div_up(os, os_block);
806 sp = os;
807 sp_block = os_block;
808 nb_sp = nb_os;
809 } else {
810 sp = ow;
811 sp_block = ow_block;
812 nb_sp = nb_ow;
813 iw_block = get_inp_size(iwp, ow_block, kw, stride_w, dilate_w);
814 }
815}
816
817bool brg_blocking_t::fast_check_oc_block() const {
818 // This function for reducing the number of blocking variants
819 // TODO: eliminate heuristic in this function
820 const auto rnd_oc = rnd_up(oc, acc_simd_w);
821 auto res = false;
822 if (oc_block == 64) {
823 res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz < 192 * 4);
824 } else if (oc_block == 48) {
825 const bool big_spatial
826 = id * ih * iw > 81 * stride_d * stride_h * stride_w;
827 res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz <= 384 * 4
828 && big_spatial);
829 } else
830 res = true;
831
832 return res;
833}
834
835float brg_blocking_t::est_eff() {
836 const auto ocblock = oc_block / acc_simd_w;
837
838 const auto brgemm_microkernel_eff
839 = (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);
840
841 const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
842 const auto brgemm_eff = squeeze_val(ur
843 * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
844 / 64,
845 0.5f);
846
847 const auto sp_amount = nb_od * nb_oh * nb_sp;
848 const auto work_amount = mb * ngroups * nb_oc * sp_amount;
849 const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block));
850
851 const auto thr_eff = static_cast<float>(work_amount)
852 / utils::rnd_up(work_amount, nthr);
853
854 const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
855
856 const auto job = div_up(work_amount, nthr);
857
858 auto job_eff = 1.f;
859 if (job < nthr) {
860 std::vector<dim_t> thr_jobs(nthr);
861
862 for (int ithr = 0; ithr < nthr; ithr++) {
863 thr_jobs[ithr] = 0;
864 if (ithr >= work_amount) continue;
865 dim_t thr_job = 0;
866 int start {0}, end {0};
867 balance211(work_amount, nthr, ithr, start, end);
868 int n {0}, g {0}, ocb {0}, odp {0}, ohp {0}, spb {0};
869 if (loop_order == loop_ndhwgc)
870 nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, g,
871 ngroups, ocb, nb_oc);
872 else if (loop_order == loop_ngcdhw)
873 nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp, od,
874 ohp, oh, spb, nb_sp);
875
876 for (auto work = start; work < end; work++) {
877 const int ocp = ocb * oc_block;
878 const auto oc_sz = nstl::min(oc - ocp, oc_block);
879 int sp_sz = 0;
880 const int spp = spb * sp_block;
881 sp_sz = nstl::min(sp - spp, sp_block);
882 thr_job += sp_sz * oc_sz;
883
884 if (loop_order == loop_ndhwgc)
885 nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
886 ngroups, ocb, nb_oc);
887 else if (loop_order == loop_ngcdhw)
888 nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
889 ohp, oh, spb, nb_sp);
890 }
891 thr_jobs[ithr] = thr_job;
892 }
893
894 dim_t max_job = 0;
895 dim_t sum_job = 0;
896 for (int ithr = 0; ithr < nthr; ithr++) {
897 if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
898 sum_job += thr_jobs[ithr];
899 }
900 job_eff = max_job == 0 ? 1
901 : static_cast<float>(sum_job) / (max_job * nthr);
902
903 } else {
904 job_eff = thr_eff;
905 }
906
907 const auto ic_blocking_size = ic_block * nb_ic_blocking;
908 const auto oc_blocking_size = oc_block * ic_blocking_size;
909
910 int l = -1;
911
912 // -- brgemm kernel: loop by simd_w --
913 l++;
914 const auto inp_ur = inp_w(ur, kw_block);
915 loop[l].src.set(inp_ur * simd_w, 1, bcast_simd);
916 loop[l].dst.set(0, 1);
917 loop[l].wei.set(oc_block, 1);
918
919 // -- brgemm kernel: loop by kw in kw_block --
920 l++;
921 auto src_is = rnd_inp_simd(ur, kw_block, ic_blocking_size);
922 loop[l].src.set(src_is, 1, kw_block);
923 loop[l].dst.set(0, 1);
924 loop[l].wei.set(oc_blocking_size, 1);
925
926 // -- brgemm kernel: loop by batch (grouped by kw_block) in ur --
927 l++;
928 loop[l].src.set(src_is, 1);
929 loop[l].dst.set(0, 1);
930 auto wei_is = kw_block * oc_blocking_size;
931 loop[l].wei.set(wei_is, 1);
932 // -- brgemm kernel: loop by ur in sp_block --
933 l++;
934 const auto nb_ur = div_up(sp_block, ur);
935 loop[l].src.set(kd_block * kh_block * src_is, 1);
936 loop[l].dst.set(ur * oc_block, 1);
937 wei_is = kd_block * kh_block * kw_block * oc_blocking_size;
938 loop[l].wei.set(wei_is, nb_ur);
939
940 // -- harness: loop by k_blocks in ks --
941 l++;
942 loop[l].src.set(kd_block * kh_block
943 * rnd_inp_simd(sp_block, kw_block, ic_blocking_size),
944 1);
945 loop[l].dst.set(sp_block * oc_block, nb_kd * nb_kh * nb_kw);
946 loop[l].wei.set(wei_is, 1);
947
948 // -- brgemm kernel: loop by ic_chunks --
949 l++;
950 const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
951 loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, ic_blocking_size), 1);
952 loop[l].dst.set(sp_block * oc_block, ic_chunks);
953 wei_is = kd * kh * kw * oc_blocking_size;
954 loop[l].wei.set(wei_is, 1);
955
956 const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
957 const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
958 const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
959 const auto nsimd_oc_thr = div_up(oc_thr, simd_w);
960
961 const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
962 const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
963 const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);
964
965 int nb_oh_thr {1}, oh_thr {1}, nb_od_thr {1}, od_thr {1};
966 if (!is_os_blocking) {
967 const auto dim_oh = nb_sp * dim_sp;
968 nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
969 oh_thr = nstl::min(oh, nb_oh_thr * oh_block);
970
971 const auto dim_od = nb_oh * dim_oh;
972 nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
973 od_thr = nstl::min(od, nb_od_thr * od_block);
974 }
975
976 src_is = kd * kh * rnd_inp_simd(sp_block, kw, ic);
977
978 auto wei_op = kd * kh * kw * ocblock * ic;
979 if (loop_order == loop_ndhwgc) {
980 // -- harness: loop by oc_block --
981 l++;
982 loop[l].src.set(src_is, nb_oc_thr);
983 loop[l].dst.set(sp_block * oc_block, 1);
984 wei_is = kd * kh * kw * oc_block * ic;
985 wei_op = kd * kh * kw * nsimd_oc_thr * ic;
986 loop[l].wei.set(wei_is, 1);
987 }
988
989 // -- harness: loop by sp_blocks --
990 l++;
991 loop[l].src.set(src_is, 1);
992 const auto rnd_oc_for_sp
993 = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
994 loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
995 loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
996 // oh_block almost all is 1. TODO: manage oh_block != 1
997 // -- harness: loop by oh_blocks --
998 l++;
999 src_is = kd * kh * rnd_inp_simd(sp_thr, kw, ic);
1000 loop[l].src.set(oh_block * src_is, 1);
1001 loop[l].dst.set(sp_thr * rnd_oc_for_sp, 1);
1002 loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
1003 // od_block almost all is 1. TODO: manage oh_block != 1
1004 // -- harness: loop by od_blocks --
1005 l++;
1006 loop[l].src.set(od_block * oh_thr * src_is, 1);
1007 loop[l].dst.set(oh_thr * sp_thr * rnd_oc_for_sp, 1);
1008 loop[l].wei.set(wei_op * simd_w, nb_od_thr);
1009
1010 if (loop_order != loop_ndhwgc) {
1011 // -- harness: loop by oc_block --
1012 l++;
1013 loop[l].src.set(od_thr * oh_thr * src_is, nb_oc_thr);
1014 loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
1015 loop[l].wei.set(kd * kh * kw * oc_block * ic, 1);
1016 }
1017
1018 // -- harness: loop by mb --
1019 l++;
1020 const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
1021 loop[l].src.set(od_thr * oh_thr * src_is, 1);
1022 loop[l].dst.set(od_thr * oh_thr * sp_thr * nsimd_oc_thr * simd_w, 1);
1023 loop[l].wei.set(kd * kh * kw * nsimd_oc_thr * simd_w * ic, mb_thr);
1024
1025 const auto src_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr
1026 * kd * kh * kw * ic;
1027 const auto dst_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr
1028 * nsimd_oc_thr;
1029 wei_op = kd * kh * kw * nsimd_oc_thr * ic;
1030
1031 // for "real" application set bench_iterations to 1
1032 const auto iterations = bench_iterations;
1033 l++;
1034 loop[l].src.set(src_op, iterations);
1035 loop[l].dst.set(dst_op * simd_w, iterations);
1036 loop[l].wei.set(wei_op * simd_w, iterations);
1037
1038 auto src_mem_k = mem_k;
1039 auto dst_mem_k = mem_k;
1040 auto wei_mem_k = mem_k;
1041 float src_rp = 1;
1042 float dst_rp = 1;
1043 float wei_rp = 1;
1044
1045 for (auto il = l; il >= 0; il--) {
1046 src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true,
1047 loop_order == loop_ndhwgc ? false : true);
1048 dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
1049 wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false,
1050 loop_order == loop_ndhwgc ? true : false);
1051 src_rp *= loop[il].src.repeatn;
1052 dst_rp *= loop[il].dst.repeatn;
1053 wei_rp *= loop[il].wei.repeatn;
1054 }
1055 const auto src_ops = (src_op * src_rp) / iterations;
1056 const auto dst_ops = (dst_op * dst_rp) / iterations;
1057 const auto wei_ops = (wei_op * wei_rp) / iterations;
1058
1059 const auto src_cost = src_mem_k * src_ops;
1060 const auto dst_cost = dst_mem_k * dst_ops;
1061 const auto wei_cost = wei_mem_k * wei_ops;
1062 const auto call_kernel_cost
1063 = 1000.f * job * ic_chunks * nb_kd * nb_kh * nb_kw;
1064
1065 const auto cache_eff = (static_cast<dim_t>(mb) * od * oh * sp * ic * oc)
1066 / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));
1067 const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
1068 * job_eff * ur_eff * cache_eff * brgemm_eff;
1069 return res_eff;
1070}
1071
1072void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_,
1073 int kh_block_, bool maybe_use_buffer, int max_ow_block_thr) {
1074
1075 unsigned est_k_amount = ic * oc_block * wei_dsz;
1076
1077 kd_block = kd_block_;
1078 kh_block = kh_block_;
1079 if (one_of(exec_type, exec_vpad, exec_trans)) {
1080 kw_block = kw;
1081 kd_block_pad = kd_block;
1082 kh_block_pad = kh_block;
1083 kw_block_pad = kw_block;
1084 } else {
1085 kw_block = (est_k_amount * kw < L2) ? kw : 1;
1086 kd_block_pad = kh_block >= kd ? kd : 1;
1087 kh_block_pad = kw_block >= kh ? kh : 1;
1088 kw_block_pad = kw;
1089 }
1090
1091 if (exec_type == exec_vpad) {
1092 od_block = 1;
1093 oh_block = 1;
1094 } else if (exec_type == exec_trans) {
1095 const auto w_block_size
1096 = 2 * src_dsz * ic * iwp + dst_dsz * ow * oc_block;
1097 const auto other_size = wei_dsz * kd * kh * kw * ic * oc_block
1098 + acc_dsz * 2 * amx_h * oc_block;
1099 const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)),
1100 other_size > L2 ? 0 : L2 - other_size);
1101 if (idp * ihp * w_block_size > L2_available) {
1102 od_block = utils::saturate(
1103 1, od, int(L2_available / (ihp * w_block_size)));
1104 if (od_block == 1)
1105 oh_block = utils::saturate(
1106 1, oh, int(L2_available / (w_block_size)));
1107 else
1108 oh_block = oh;
1109 } else {
1110 od_block = 1;
1111 oh_block = oh;
1112 }
1113 if (is_amx(isa)) {
1114 // try to fit into L1
1115 bool L1_fit_res = false;
1116 auto cur_od_block = od_block;
1117 auto cur_oh_block = oh_block;
1118 const auto src_w_block_size
1119 = src_dsz * ic * iwp + dst_dsz * ow * oc_block;
1120 if (src_w_block_size < L1) {
1121 cur_od_block = utils::saturate(
1122 1, od, int(L1 / (ihp * src_w_block_size)));
1123 if (cur_od_block == 1)
1124 cur_oh_block = utils::saturate(
1125 1, oh, int(L1 / (src_w_block_size)));
1126 }
1127 for (; cur_od_block > 1; cur_od_block--) {
1128 const auto sp_size = cur_od_block * cur_oh_block * iwp;
1129 if ((static_cast<float>(od) / rnd_up(od, cur_od_block)) > 0.9f
1130 && static_cast<float>(sp_size) / rnd_up(sp, amx_h)
1131 > 0.8f) {
1132 L1_fit_res = true;
1133 break;
1134 }
1135 }
1136 if (cur_od_block == 1) {
1137 for (; cur_oh_block > 1; cur_oh_block--) {
1138 const auto sp_size = cur_oh_block * iwp;
1139 if ((static_cast<float>(oh) / rnd_up(oh, cur_oh_block))
1140 > 0.9f
1141 && sp_size > 128) {
1142 L1_fit_res = true;
1143 break;
1144 }
1145 }
1146 }
1147 if (L1_fit_res) {
1148 od_block = cur_od_block;
1149 oh_block = cur_oh_block;
1150 }
1151 }
1152
1153 // limit oh_block to have good threading
1154 const auto thr_oc_block = div_up(
1155 nthr, mb * div_up((oc > 32 ? ngroups : 1) * oc, oc_block));
1156 const auto thr_od_block = div_up(od, thr_oc_block);
1157 const auto thr_oh_block
1158 = div_up(oh, thr_oc_block * div_up(od, thr_od_block));
1159 od_block = nstl::min(od_block, thr_od_block);
1160 oh_block = nstl::min(oh_block, thr_oh_block);
1161 } else {
1162 od_block = 1;
1163 oh_block = 1;
1164 }
1165
1166 // --- Select ow_block ----
1167 const auto max_ow_block_L2 = ow;
1168 auto start_ow_block = nstl::min(max_ow_block_thr, max_ow_block_L2);
1169
1170 sp = ow;
1171 const auto start_sp_block = is_os_blocking ? ow : start_ow_block;
1172 auto prev_spb = 0;
1173 for (auto ns = 1; ns <= sp; ns++) {
1174 const auto spb = div_up(sp, ns);
1175 if (spb == prev_spb || spb > start_sp_block) continue;
1176 if (is_os_blocking && spb != ow) continue;
1177 prev_spb = spb;
1178 ow_block = spb;
1179 sp_block = ow_block;
1180
1181 select_ic_block();
1182
1183 use_buffer = maybe_use_buffer
1184 && (ic_block * nb_ic_blocking < ic || kd_block != kd
1185 || kh_block != kh || kw_block != kw
1186 || kd_block_pad != kd || kh_block_pad != kh
1187 || kw_block_pad != kw);
1188 if (exec_type == exec_base)
1189 use_buffer = use_buffer || (maybe_use_buffer && iwp != iw);
1190
1191 const status_t st = estimate_brgemm_ur();
1192 if (st != status::success) continue;
1193 os_block = sp_block = ow_block;
1194 update_blocks();
1195
1196 eff = est_eff();
1197
1198 if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
1199 }
1200}
1201
1202status_t brg_blocking_t::calc_blocks() {
1203 sp = ow;
1204
1205 nb_ic_blocking = 1;
1206 // --- Select kernel blocking ---
1207 // if dst_dt != acc_dt and we need to store intermediate
1208 // results then we need the out buffer
1209 const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum);
1210
1211 std::vector<int> kd_blocks(1), kh_blocks(1);
1212 kd_blocks[0] = kd;
1213 kh_blocks[0] = kh;
1214 if (kd != 1) {
1215 kd_blocks.resize(2);
1216 kd_blocks[1] = 1;
1217 }
1218 if (kh != 1) {
1219 kh_blocks.resize(2);
1220 kh_blocks[1] = 1;
1221 }
1222
1223 const auto thr_eff_threshold = 0.9f;
1224 const auto max_ow_block_thr = utils::saturate(1, ow,
1225 static_cast<int>(div_up(
1226 mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));
1227
1228 ow_block = os_block = sp_block = -1;
1229 brg_blocking_t best_brgb = *this;
1230 for (const auto &kd_block : kd_blocks) {
1231 for (const auto &kh_block : kh_blocks) {
1232 iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer,
1233 max_ow_block_thr);
1234 }
1235 }
1236 *this = best_brgb;
1237 if (!IMPLICATION(!is_os_blocking, sp_block > 0))
1238 return status::unimplemented;
1239
1240 if (is_os_blocking) {
1241 ow_block = ow;
1242 os_block = ow * oh_block;
1243 sp_block = os_block;
1244 ow_tail = 0;
1245 } else {
1246 ow_block = os_block = sp_block;
1247 ow_tail = ow % ow_block;
1248 }
1249 update_blocks();
1250 return status::success;
1251}
1252
1253bool brg_blocking_t::fast_check_oc_block_1x1() const {
1254 // This function for reducing the number of blocking variants
1255 // TODO: eliminate heuristic in this function
1256 if (is_1x1 && is_amx(isa)) return true;
1257 const auto rnd_oc = rnd_up(oc, acc_simd_w);
1258 auto res = false;
1259 if (oc_block == 64) {
1260 const auto big_spatial
1261 = od * oh * ow >= 64 * stride_d * stride_h * stride_w;
1262 res = (rnd_oc % oc_block == 0 && big_spatial);
1263 } else if (oc_block == 48) {
1264 const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
1265 res = (oc_block_eff >= 0.95f);
1266 } else
1267 res = true;
1268
1269 return res;
1270}
1271
1272float brg_blocking_t::est_eff_1x1() {
1273 const auto ocblock = oc_block / acc_simd_w;
1274
1275 auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float {
1276 const int nb = dim / block;
1277 constexpr int max_nb = 2; // only consider 2x2 tile blocking
1278 const int block2 = nstl::min(max_nb, nb);
1279 const int nb2 = nb / block2;
1280 const int nb2_tail = nb % block2;
1281 if (!use_ave) return block2;
1282 return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2);
1283 };
1284 const bool use_ocb_ave = true;
1285 const auto ocb_ave = calc_ave_blk(oc_block, acc_simd_w, use_ocb_ave);
1286 const bool use_spb_ave = false;
1287 const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave);
1288 const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0;
1289 const auto M_tail_n_sp_blks
1290 = ur_block_tail > 0 ? M_tail / ur_block_tail : 0;
1291
1292 // heuristic for maskrcnn workaround: use old blocking for some convolutions
1293 // TODO: remove this condition
1294 const bool maskrcnn_cond = (ic == 1024 && oc == 2048)
1295 || (ic == 1024 && oc == 512) || (ic == 256 && oc == 1024)
1296 || (ic == 512 && oc == 1024) || (ic == 512 && oc == 2048);
1297 const auto amx_fac = maskrcnn_cond
1298 ? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks))
1299 : (static_cast<float>(div_up(M + M_tail, 16))
1300 / (M_n_sp_blks + M_tail_n_sp_blks));
1301
1302 const auto brgemm_microkernel_eff = is_amx(isa)
1303 ? amx_fac * (static_cast<float>(ocb_ave) * spb_ave)
1304 / (ocb_ave + spb_ave)
1305 : (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);
1306 const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
1307 const auto brgemm_eff = squeeze_val(ur
1308 * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
1309 / 64,
1310 0.5f);
1311
1312 const auto sp_amount = is_os_blocking ? div_up(nb_os, nb_os_blocking)
1313 : nb_od * nb_oh * nb_sp;
1314 const auto work_amount = mb * ngroups * nb_oc * sp_amount;
1315
1316 const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block);
1317 const auto thr_eff = static_cast<float>(work_amount)
1318 / utils::rnd_up(work_amount, nthr);
1319 const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
1320
1321 const auto job = div_up(work_amount, nthr);
1322
1323 const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
1324 const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
1325 const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
1326 const auto nsimd_oc_thr = div_up(oc_thr, simd_w);
1327
1328 const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
1329 const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
1330 const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);
1331
1332 int nb_oh_thr {1}, oh_thr {1}, nb_od_thr {1}, od_thr {1};
1333 if (!is_os_blocking) {
1334 const auto dim_oh = nb_sp * dim_sp;
1335 nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
1336 oh_thr = nstl::min(oh, nb_oh_thr * oh_block);
1337
1338 const auto dim_od = nb_oh * dim_oh;
1339 nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
1340 od_thr = nstl::min(od, nb_od_thr * od_block);
1341 }
1342
1343 auto job_eff = 1.f;
1344 if (job < nthr) {
1345 std::vector<dim_t> thr_jobs(nthr);
1346 for (int ithr = 0; ithr < nthr; ithr++) {
1347 thr_jobs[ithr] = 0;
1348 if (ithr >= work_amount) continue;
1349 dim_t thr_job = 0;
1350 int start {0}, end {0};
1351 balance211(work_amount, nthr, ithr, start, end);
1352 int n {0}, g {0}, ocb {0}, oss {0}, odp {0}, ohp {0}, spb {0};
1353 if (loop_order == loop_ndhwgc) {
1354 if (is_os_blocking)
1355 nd_iterator_init(start, n, mb, oss, sp_amount, g, ngroups,
1356 ocb, nb_oc);
1357 else
1358 nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp,
1359 g, ngroups, ocb, nb_oc);
1360 } else if (loop_order == loop_ngcdhw) {
1361 if (is_os_blocking)
1362 nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, oss,
1363 sp_amount);
1364 else
1365 nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp,
1366 od, ohp, oh, spb, nb_sp);
1367 }
1368
1369 for (auto work = start; work < end; work++) {
1370 const int ocp = ocb * oc_block;
1371 const auto oc_sz = nstl::min(oc - ocp, oc_block);
1372 int sp_sz = 0;
1373 if (is_os_blocking) {
1374 const auto osb_start = oss * nb_os_blocking;
1375 const auto osb_range
1376 = nstl::min(nb_os - osb_start, nb_os_blocking);
1377 for (int osb = 0; osb < osb_range; osb++) {
1378 const int osp = (osb_start + osb) * sp_block;
1379 sp_sz = nstl::min(os - osp, sp_block);
1380 }
1381 } else {
1382 const int spp = spb * sp_block;
1383 sp_sz = nstl::min(sp - spp, sp_block);
1384 }
1385 thr_job += sp_sz * oc_sz;
1386
1387 if (loop_order == loop_ndhwgc) {
1388 if (is_os_blocking)
1389 nd_iterator_step(
1390 n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc);
1391 else
1392 nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
1393 ngroups, ocb, nb_oc);
1394 } else if (loop_order == loop_ngcdhw) {
1395 if (is_os_blocking)
1396 nd_iterator_step(
1397 n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount);
1398 else
1399 nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
1400 ohp, oh, spb, nb_sp);
1401 }
1402 }
1403 thr_jobs[ithr] = thr_job;
1404 }
1405
1406 dim_t max_job = 0;
1407 dim_t sum_job = 0;
1408 for (int ithr = 0; ithr < nthr; ithr++) {
1409 if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
1410 sum_job += thr_jobs[ithr];
1411 }
1412
1413 job_eff = max_job == 0 ? 1
1414 : static_cast<float>(sum_job) / (max_job * nthr);
1415 } else {
1416 job_eff = thr_eff;
1417 }
1418
1419 const auto ic_blocking_size = ic_block * nb_ic_blocking;
1420 const auto oc_blocking_size = oc_block * ic_blocking_size;
1421
1422 int l = -1;
1423 // -- brgemm kernel: loop by simd_w --
1424 l++;
1425 loop[l].src.set(ur * simd_w, 1, bcast_simd);
1426 loop[l].dst.set(0, 1);
1427 loop[l].wei.set(oc_block, 1);
1428
1429 // -- brgemm kernel: loop by ur in sp_block --
1430 l++;
1431 const auto nb_ur = div_up(sp_block, ur);
1432 const auto nb_sp_no_tail = sp / sp_block;
1433 const auto sp_block_tail = sp % sp_block;
1434 const auto nb_ur_average
1435 = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp;
1436 loop[l].src.set(ur * rnd_simd(ic_blocking_size), 1);
1437 loop[l].dst.set(ur * oc_block, 1);
1438 loop[l].wei.set(oc_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur);
1439 // -- brgemm kernel: loop by ic_chunks --
1440 l++;
1441 const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
1442 loop[l].src.set(sp_block * ic_blocking_size, 1);
1443 loop[l].dst.set(sp_block * oc_block, ic_chunks);
1444 auto wei_is = oc_blocking_size;
1445 auto wei_op = ocblock * ic;
1446 loop[l].wei.set(wei_is, 1);
1447
1448 if (loop_order == loop_ndhwgc) {
1449 // -- harness: loop by oc_block --
1450 l++;
1451 loop[l].src.set(sp_block * rnd_simd(ic), nb_oc_thr);
1452 loop[l].dst.set(sp_block * oc_block, 1);
1453 wei_is = oc_block * ic;
1454 wei_op = nsimd_oc_thr * ic;
1455 loop[l].wei.set(wei_is, 1);
1456 }
1457
1458 const auto rnd_oc_for_sp
1459 = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
1460 if (is_os_blocking) {
1461 // -- harness: loop by os_blocks --
1462 l++;
1463 loop[l].src.set(sp_block * ic_blocking_size, 1);
1464 loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
1465 loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
1466 } else {
1467 // -- harness: loop by sp_blocks --
1468 l++;
1469 loop[l].src.set(sp_block * ic_blocking_size, 1);
1470 loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
1471 loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
1472 // -- harness: loop by oh_blocks --
1473 l++;
1474 loop[l].src.set(oh_block * sp_thr * rnd_simd(ic_blocking_size), 1);
1475 loop[l].dst.set(oh_block * sp_thr * rnd_oc_for_sp, 1);
1476 loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
1477 // -- harness: loop by od_blocks --
1478 l++;
1479 loop[l].src.set(
1480 od_block * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
1481 loop[l].dst.set(od_block * oh_thr * sp_thr * rnd_oc_for_sp, 1);
1482 loop[l].wei.set(wei_op * simd_w, nb_od_thr);
1483 }
1484
1485 if (loop_order != loop_ndhwgc) {
1486 // -- harness: loop by oc_block --
1487 l++;
1488 loop[l].src.set(od_thr * oh_thr * rnd_simd(sp_thr * ic_blocking_size),
1489 nb_oc_thr);
1490 loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
1491 loop[l].wei.set(oc_block * ic, 1);
1492 }
1493
1494 // -- harness: loop by mb --
1495 l++;
1496 const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
1497 loop[l].src.set(od_thr * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
1498 loop[l].dst.set(nsimd_oc_thr * simd_w * od_thr * oh_thr * sp_thr, 1);
1499 loop[l].wei.set(nsimd_oc_thr * ic * simd_w, mb_thr);
1500
1501 const auto src_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr
1502 * ic_blocking_size;
1503 const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_oc_thr * od_thr
1504 * oh_thr * sp_thr;
1505 wei_op = nsimd_oc_thr * ic;
1506
1507 // for "real" application set bench_iterations to 1
1508 const auto iterations = bench_iterations;
1509 l++;
1510 loop[l].src.set(src_op, iterations);
1511 loop[l].dst.set(dst_op * simd_w, iterations);
1512 loop[l].wei.set(wei_op * simd_w, iterations);
1513
1514 auto src_mem_k = mem_k;
1515 auto dst_mem_k = mem_k;
1516 auto wei_mem_k = mem_k;
1517 float src_rp = 1;
1518 float dst_rp = 1;
1519 float wei_rp = 1;
1520
1521 for (auto il = l; il >= 0; il--) {
1522 src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false);
1523 dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
1524 wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true);
1525 src_rp *= loop[il].src.repeatn;
1526 dst_rp *= loop[il].dst.repeatn;
1527 wei_rp *= loop[il].wei.repeatn;
1528 }
1529 const auto src_ops = (src_op * src_rp) / iterations;
1530 const auto dst_ops = (dst_op * dst_rp) / iterations;
1531 const auto wei_ops = (wei_op * wei_rp) / iterations;
1532
1533 const auto src_cost = src_mem_k * src_ops;
1534 const auto dst_cost = dst_mem_k * dst_ops;
1535 const auto wei_cost = wei_mem_k * wei_ops;
1536 const auto call_kernel_cost = 1000.f * job * ic_chunks;
1537
1538 const auto up_sp_size = is_os_blocking ? 1 : od * oh;
1539
1540 const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * ic * oc)
1541 / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));
1542
1543 const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
1544 * job_eff * ur_eff * cache_eff * brgemm_eff;
1545 return res_eff;
1546}
1547
1548void brg_blocking_t::calc_blocks_1x1() {
1549 const bool is_os_blocking_ok
1550 = utils::everyone_is(1, stride_d, stride_h) && iw % stride_w == 0;
1551 const bool is_ic_zero_padded = ic != ic_without_padding;
1552 is_rtus = is_ic_zero_padded || (!is_os_blocking_ok && is_amx(isa));
1553 if (is_os_blocking_ok || is_rtus) {
1554 sp = os;
1555 is_os_blocking = true;
1556 } else {
1557 sp = ow;
1558 is_os_blocking = false;
1559 }
1560
1561 od_block = 1;
1562 oh_block = 1;
1563 kd_block = kh_block = kw_block = 1;
1564 kd_block_pad = kh_block_pad = kw_block_pad = 1;
1565 nb_ic_blocking = 1;
1566
1567 const auto thr_eff_threshold = 0.9f;
1568
1569 const auto max_sp_block_L2 = os;
1570 // TODO: nb_os_blocking always is 1 for now. Update this code
1571 nb_os_blocking = 1;
1572 int start_sp_block = 0;
1573
1574 if (is_os_blocking) {
1575 ow_block = 0;
1576
1577 const auto max_os_block_thr
1578 = (src_dsz * ic >= 1024 && src_dsz * ic < 4096)
1579 ? nstl::max(nstl::min(16, os),
1580 div_up(os, div_up(nthr, mb * div_up(oc, oc_block))))
1581 : nstl::max(div_up(2048, oc_block),
1582 static_cast<int>(div_up(mb * ngroups * os, nthr)));
1583 const auto max_os_block_L2 = max_sp_block_L2;
1584
1585 auto max_os_block_aliasing = 1000000 / nthr;
1586 if ((oc_without_padding * os * dst_dsz) % P4K == 0) {
1587 max_os_block_aliasing /= 1;
1588 for (auto cur_oc = oc_without_padding;
1589 max_os_block_aliasing * dst_dsz > 400 && cur_oc % 2 == 0
1590 && cur_oc * os * dst_dsz >= P4K;
1591 cur_oc /= 2) {
1592 max_os_block_aliasing /= 2;
1593 }
1594 max_os_block_aliasing += max_os_block_aliasing % 2 ? 0 : 1;
1595 }
1596 max_os_block_aliasing
1597 = nstl::min(div_up(1001, dst_dsz), max_os_block_aliasing);
1598
1599 start_sp_block = utils::saturate(1, os,
1600 nstl::min(nstl::min(max_os_block_thr, max_os_block_L2),
1601 max_os_block_aliasing));
1602
1603 } else {
1604 os_block = 0;
1605
1606 const auto max_ow_block_thr = utils::saturate(1, ow,
1607 static_cast<int>(div_up(
1608 mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));
1609 const auto max_ow_block_L2 = max_sp_block_L2;
1610
1611 start_sp_block = utils::saturate(
1612 1, ow, nstl::min(max_ow_block_thr, max_ow_block_L2));
1613 }
1614 os_block = ow_block = sp_block = -1;
1615 brg_blocking_t best_brgb = *this;
1616
1617 auto prev_spb = 0;
1618 for (auto ns = 1; ns <= sp; ns++) {
1619 auto spb = div_up(sp, ns);
1620 if (is_amx(isa)) {
1621 auto min_dis = 16;
1622 auto best_w = 16;
1623 const auto max_tile_width = nstl::min(16, sp);
1624 const auto min_tile_width = utils::saturate(1, 11, sp / 2);
1625 if (spb < min_tile_width) break;
1626 for (auto w = max_tile_width; w >= min_tile_width; w--) {
1627 const auto dis = nstl::additive_inverse_modulo(spb, w);
1628 if (dis < min_dis) {
1629 min_dis = dis;
1630 best_w = w;
1631 }
1632 }
1633 spb = nstl::min(sp, rnd_dn(spb, best_w));
1634 if (spb == prev_spb) continue;
1635 }
1636 if (spb == prev_spb || spb > start_sp_block) continue;
1637 prev_spb = spb;
1638 os_block = ow_block = sp_block = spb;
1639 select_ic_block();
1640 const status_t st = estimate_brgemm_ur();
1641 if (st != status::success) continue;
1642 update_blocks();
1643
1644 use_buffer = (dst_dt != acc_dt || with_sum)
1645 && (ic_block * nb_ic_blocking < ic);
1646
1647 eff = est_eff_1x1();
1648 if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
1649 }
1650 *this = best_brgb;
1651 os_block = ow_block = sp_block;
1652 update_blocks();
1653}
1654
1655brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) {
1656 return attr.zero_points_.has_default_values(arg)
1657 ? brgemm_broadcast_t::none
1658 : brgemm_broadcast_t::per_tensor;
1659}
1660status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1661 const convolution_desc_t &cd, memory_desc_t &src_md,
1662 memory_desc_t &weights_md, memory_desc_t &dst_md,
1663 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1664 using namespace prop_kind;
1665
1666 brg_blocking_t::L1 = platform::get_per_core_cache_size(1);
1667 brg_blocking_t::L2 = platform::get_per_core_cache_size(2);
1668 brg_blocking_t::L3 = platform::get_per_core_cache_size(2);
1669
1670 const memory_desc_wrapper src_d(&src_md);
1671 const memory_desc_wrapper weights_d(&weights_md);
1672 const memory_desc_wrapper dst_d(&dst_md);
1673 const memory_desc_wrapper bias_d(&bias_md);
1674
1675 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1676 int ndims = src_d.ndims();
1677
1678 jcp = zero<decltype(jcp)>();
1679 jcp.isa = isa;
1680
1681 if (is_amx(isa)) {
1682 const int target_palette = amx::get_target_palette();
1683 if (amx::get_max_tiles(target_palette) != 8
1684 || amx::get_max_rows(target_palette) != 16)
1685 return status::unimplemented;
1686 }
1687
1688 jcp.ndims = ndims;
1689 jcp.prop_kind = cd.prop_kind;
1690 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1691 jcp.mb = src_d.dims()[0];
1692 jcp.oc_without_padding = dst_d.dims()[1];
1693 jcp.oc = jcp.oc_without_padding / jcp.ngroups;
1694 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
1695 jcp.ic = jcp.ic_without_padding;
1696 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1697 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
1698 jcp.iw = src_d.dims()[ndims - 1];
1699 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
1700 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
1701 jcp.ow = dst_d.dims()[ndims - 1];
1702 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1703 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
1704 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1705 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1706 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
1707 jcp.l_pad = cd.padding[0][ndims - 3];
1708 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1709 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
1710 jcp.stride_w = cd.strides[ndims - 3];
1711
1712 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1713 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
1714 jcp.dilate_w = cd.dilates[ndims - 3];
1715
1716 jcp.os = jcp.od * jcp.oh * jcp.ow;
1717
1718 jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1719 jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1720 jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1721
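    // end paddings implied by the output size, stride, dilated kernel size,
    // and the leading padding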
1722 jcp.back_pad = calculate_end_padding(
1723 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd);
1724 jcp.b_pad = calculate_end_padding(
1725 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh);
1726 jcp.r_pad = calculate_end_padding(
1727 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw);
1728
1729 jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0
1730 && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0
1731 && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw);
1732
1733 jcp.with_bias = bias_md.format_kind != format_kind::undef;
1734
1735 jcp.src_dt = src_md.data_type;
1736 jcp.dst_dt = dst_md.data_type;
1737 jcp.wei_dt = weights_md.data_type;
1738 jcp.bia_dt = jcp.with_bias ? bias_md.data_type : data_type::undef;
1739
1740 if (one_of(jcp.src_dt, u8, s8)) {
1741 jcp.acc_dt = s32;
1742 } else if (one_of(jcp.src_dt, f32, bf16, f16)) {
1743 jcp.acc_dt = f32;
1744 } else
1745 return status::unimplemented;
1746
1747 jcp.src_dsz = types::data_type_size(jcp.src_dt);
1748 jcp.wei_dsz = types::data_type_size(jcp.wei_dt);
1749 jcp.dst_dsz = types::data_type_size(jcp.dst_dt);
1750 jcp.acc_dsz = types::data_type_size(jcp.acc_dt);
1751 jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;
1752
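    // number of source/accumulator elements per vector register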
1753 jcp.simd_w = isa_max_vlen(isa) / jcp.src_dsz;
1754 jcp.acc_simd_w = isa_max_vlen(isa) / jcp.acc_dsz;
1755 jcp.is_bf32 = everyone_is(f32, jcp.src_dt, jcp.wei_dt)
1756 && attr.fpmath_mode_ == fpmath_mode::bf16 && isa == avx512_core_amx;
1757
1758 brg_blocking_t::last_ic_block_size
1759 = (jcp.wei_dt == f16 && isa == avx512_core_fp16)
1760 ? 1
1761 : data_type_vnni_granularity(jcp.wei_dt);
1762
1763 // TODO: optimize depthwise convolutions (for now direct approach is faster)
1764 const bool is_depthwise
1765 = with_groups && jcp.ngroups > 1 && everyone_is(1, jcp.ic, jcp.oc);
1766 if (is_depthwise) return status::unimplemented;
1767
1768 // TODO: optimize grouped convolutions with small ic
1769 const bool is_grouped_small_ic
1770 = jcp.prop_kind != prop_kind::backward_weights && with_groups
1771 && jcp.ngroups > 1 && jcp.ic <= jcp.acc_simd_w
1772 && IMPLICATION(is_amx(jcp.isa),
1773 jcp.ic < 16
1774 && jcp.oc < 16
1775 // already optimized for amx 1x1 convs
1776 && !jcp.is_1x1)
1777 // Enable the shapes not supported in direct convs
1778 && IMPLICATION(with_groups, is_groups_ok(jcp));
1779 if (is_grouped_small_ic) return status::unimplemented;
1780
1781    // Dispatch these shapes to the VNNI implementation for better performance
1782    // TODO: optimize the perf of 3D shapes with small ic and large spatial size
1783 const auto max_small_shapes_sz = jcp.is_1x1
1784 ? static_cast<int32_t>(brg_blocking_t::L1) / 2
1785 : static_cast<int32_t>(brg_blocking_t::L1);
1786 const auto is_small_shape = is_amx(jcp.isa) && jcp.os <= 4 && jcp.ic <= 512
1787 && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= max_small_shapes_sz;
1788 const auto is_3d_small_ic = is_amx(jcp.isa) && jcp.ndims == 5
1789 && jcp.ic * jcp.oc <= 32 && jcp.od >= 128 && jcp.oh >= 128
1790 && jcp.ow >= 128;
1791 if (one_of(jcp.prop_kind, prop_kind::forward_training,
1792 prop_kind::forward_inference)
1793 && (is_small_shape || is_3d_small_ic))
1794 return status::unimplemented;
1795
1796 jcp.s8s8_avx512 = jcp.src_dt == s8 && !is_amx(jcp.isa);
1797
1798 if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni)))
1799 return status::unimplemented;
1800 if (!IMPLICATION(jcp.wei_dt == bf16,
1801 mayiuse(avx512_core_bf16) || mayiuse(avx2_vnni_2)))
1802 return status::unimplemented;
1803 if (!IMPLICATION(jcp.wei_dt == f16,
1804 mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)))
1805 return status::unimplemented;
1806 const bool is_f32
1807 = utils::everyone_is(f32, jcp.src_dt, jcp.wei_dt, jcp.dst_dt);
1808 if (!IMPLICATION(is_f32, one_of(isa, avx512_core, avx2) || jcp.is_bf32))
1809 return status::unimplemented;
1810
1811 if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented;
1812
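    // AMX tile geometry: 16 rows of 64 bytes each; amx_w is the tile row width
    // in source elements (bf16 elements when running in bf32 mode)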
1813 jcp.amx_h = 16;
1814 jcp.amx_w = 64 / (jcp.is_bf32 ? types::data_type_size(bf16) : jcp.src_dsz);
1815
1816 const auto &p = attr.post_ops_;
1817 jcp.with_sum = p.find(primitive_kind::sum) != -1;
1818 const int eltwise_ind = p.find(primitive_kind::eltwise);
1819 jcp.with_eltwise = eltwise_ind != -1;
1820
1821 const int binary_ind = p.find(primitive_kind::binary);
1822 jcp.with_binary = binary_ind != -1;
1823
1824 jcp.src_zero_point
1825 = get_zp_type(attr, DNNL_ARG_SRC) != brgemm_broadcast_t::none;
1826 jcp.dst_zero_point
1827 = get_zp_type(attr, DNNL_ARG_DST) != brgemm_broadcast_t::none;
1828
1829    // Only common zero points for the whole output tensor are supported now
1830 // TODO: Extend zero points support to AMX
1831 const bool has_zero_points = jcp.src_zero_point || jcp.dst_zero_point;
1832 if (has_zero_points || jcp.s8s8_avx512) {
1833 const bool params_ok = IMPLICATION(has_zero_points, !is_amx(jcp.isa))
1834 && IMPLICATION(
1835 has_zero_points, utils::one_of(jcp.src_dt, u8, s8))
1836 && IMPLICATION(jcp.src_zero_point,
1837 attr.zero_points_.common(DNNL_ARG_SRC))
1838 && IMPLICATION(jcp.dst_zero_point,
1839 attr.zero_points_.common(DNNL_ARG_DST));
1840 if (!params_ok) return status::unimplemented;
1841 }
1842
1843 jcp.nthr = nthreads;
1844 jcp.kh_sets = 1;
1845 jcp.kw_sets = 1;
1846 jcp.copy_block_only = false;
1847 jcp.amx_tile_load_xx = false;
1848 jcp.use_M_mask = 0;
1849 jcp.is_os_blocking = false;
1850 jcp.oskip = 0;
1851 jcp.use_uker = false;
1852 jcp.use_interleave_stores = false;
1853 jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf_default;
1854 jcp.brgemm_bd_loop_innermost = false;
1855
1856 if (jcp.prop_kind != prop_kind::backward_weights) {
1857        // quick data layout check before spending time on blocking selection
1858 format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);
1859 CHECK(init_tag(
1860 jcp.src_tag, src_md, src_d, src_tag, is_any_eligible(jcp)));
1861 }
1862 if (jcp.with_bias) {
1863 if (bias_d.format_kind() == format_kind::any)
1864 CHECK(memory_desc_init_by_tag(bias_md, x));
1865 }
1866
1867 const auto ic_padded_block
1868 = jcp.acc_simd_w * brg_blocking_t::last_ic_block_size;
1869 jcp.is_ic_padded = !jcp.is_1x1 && one_of(jcp.wei_dt, bf16, f16, s8)
1870 && jcp.ic * jcp.kw_sets > ic_padded_block && is_amx(isa);
1871
1872 jcp.idp = jcp.id + jcp.f_pad + jcp.back_pad;
1873 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1874 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1875
1876 return status::success;
1877}
1878
1879status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1880 const convolution_desc_t &cd, memory_desc_t &src_md,
1881 memory_desc_t &weights_md, memory_desc_t &dst_md,
1882 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1883
1884 using namespace prop_kind;
1885 if (!mayiuse(isa)) return status::unimplemented;
1886
1887 CHECK(init_jcp(
1888 jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));
1889
1890 if (jcp.is_1x1) return status::unimplemented;
1891 const memory_desc_wrapper src_d(&src_md);
1892 const memory_desc_wrapper weights_d(&weights_md);
1893 const memory_desc_wrapper dst_d(&dst_md);
1894 const memory_desc_wrapper bias_d(&bias_md);
1895
1896 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1897
1898 // TODO: check these restrictions
1899 if (is_amx(isa)) {
1900 // disabled for two convolutions from ssd_resnet34
1901 if ((jcp.ic == jcp.oc) && (jcp.ic == 128 || jcp.ic == 256)
1902 && (jcp.oh == jcp.ow) && (jcp.oh == 150))
1903 return status::unimplemented;
1904        // disabled for first convolutions except 3D ones
1905 const bool is_real_3d = (jcp.ndims == 5
1906 && (jcp.id > 1 || jcp.od > 1 || jcp.kd > 1
1907 || jcp.dilate_d > 0));
1908
1909 if (jcp.ic <= 4 && !is_real_3d
1910 && IMPLICATION(with_groups, is_groups_ok(jcp)))
1911 return status::unimplemented;
1912
1913 if (jcp.f_pad >= jcp.ext_kd || jcp.t_pad >= jcp.ext_kh
1914 || jcp.r_pad >= jcp.ext_kw)
1915 return status::unimplemented;
1916 }
1917
1918 using namespace data_type;
1919 // ======================= blocking =================================
1920
1921 auto bcast_amount
1922 = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
1923 auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.kd * jcp.kh * jcp.kw
1924 * jcp.wei_dsz;
1925
1926 jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;
1927
1928 const int min_oc_block = jcp.acc_simd_w;
1929
1930 int selected_ur = 0;
1931 MAYBE_UNUSED(selected_ur);
1932
1933 auto try_exec_type = [&]() {
1934 brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
1935 best_brgb.oc_block = min_oc_block;
1936 brg_blocking_t cur_brgb = zero<decltype(best_brgb)>();
1937 cur_brgb.get_from_jcp(jcp);
1938 auto start_ocb = (is_amx(isa) && jcp.is_os_blocking) ? 2 : 4;
1939 if (jcp.wei_plain)
1940 start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
1941 div_up(jcp.oc, jcp.acc_simd_w));
1942 start_ocb = nstl::min(div_up(jcp.oc, jcp.acc_simd_w), start_ocb);
1943
1944 auto finish_ocb = 1;
1945 for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
1946 cur_brgb.oc_block = ocb * jcp.acc_simd_w;
1947 cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);
1948 if (!cur_brgb.fast_check_oc_block()) continue;
1949
1950 const status_t blocking_ok = cur_brgb.calc_blocks();
1951 if (blocking_ok != status::success) continue;
1952
1953 const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
1954 if (st != status::success) continue;
1955 cur_brgb.eff = cur_brgb.est_eff();
1956 if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
1957 }
1958 if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0
1959 || best_brgb.ow_block == 0)
1960 return false;
1961 best_brgb.save_to_jcp(jcp);
1962 selected_ur = best_brgb.ur;
1963 return true;
1964 };
1965
1966 //-----------------------------------------------------------------------
1967
1968 jcp.exec_type = exec_base;
1969 jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM
1970
1971 bool try_exec_vpad = false;
1972 bool try_exec_trans = false;
1973 bool try_exec_base = true;
1974
1975 if (!is_amx(isa) && div_up(jcp.l_pad, jcp.stride_w) < jcp.kw
1976 && div_up(jcp.r_pad, jcp.stride_w) < jcp.kw) {
1977 try_exec_vpad = true;
1978 }
1979
1980 const auto ic_padded_block
1981 = jcp.acc_simd_w * brg_blocking_t::last_ic_block_size;
1982 // TODO: remove this restriction
1983 const auto w_padding = jcp.l_pad > 0 || jcp.r_pad > 0;
1984 if (is_amx(isa)) {
1985 try_exec_base = !w_padding
1986 && IMPLICATION(jcp.ic <= ic_padded_block,
1987 jcp.ic % brg_blocking_t::last_ic_block_size == 0)
1988 && IMPLICATION(
1989 jcp.ic > ic_padded_block, jcp.ic % ic_padded_block == 0)
1990 && jcp.ow > 50 /*TODO: reinvestigate this heuristic */;
1991 try_exec_trans = !try_exec_base;
1992 }
1993
1994 bool must_exec_vpad = false;
1995
1996    // TODO: in the future use (kd/kh/kw) and (kd/kh/kw)_pad blocks for a more
1997    // precise calculation of jcp.max_batch
1998 jcp.max_batch = jcp.kd * jcp.kh * jcp.kw;
1999
2000    // TODO: check wei plain
2001 jcp.wei_plain = false;
2002 jcp.wei_plain = jcp.exec_type == exec_vpad ? jcp.wei_plain : false;
2003
2004 bool try_exec_type_res = false;
2005
2006 if (try_exec_vpad) {
2007 jcp.exec_type = exec_vpad;
2008 try_exec_type_res = try_exec_type();
2009        // avoid the case when both left and right virtual padding are non-zero
2010 // TODO: remove this restriction
2011 const auto iw_block = (jcp.ow_block - 1) * jcp.stride_w + 1;
2012 if (!must_exec_vpad && (iw_block > jcp.iw)) try_exec_type_res = false;
2013 }
2014 if (try_exec_type_res == false && try_exec_trans) {
2015 jcp.exec_type = exec_trans;
2016
2017 // try loop_ndhwgc always for exec_trans
2018 jcp.loop_order = loop_ndhwgc;
2019
2020        // we read the input block only once for loop_ndhwgc, so we don't need
2021        // to keep it in memory
2022 if (jcp.loop_order == loop_ndhwgc) { jcp.copy_block_only = true; }
2023
2024 jcp.is_ic_padded = one_of(jcp.wei_dt, bf16, f16, s8)
2025 && jcp.ic * jcp.kw_sets > ic_padded_block;
2026
2027        if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.ow < 256)) {
2028 jcp.is_os_blocking = jcp.f_pad < jcp.kd && jcp.back_pad < jcp.kd
2029 && jcp.t_pad < jcp.kh && jcp.b_pad < jcp.kh
2030 && jcp.r_pad < jcp.kw && jcp.l_pad < jcp.kw;
2031 jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0;
2032 jcp.use_uker = true;
2033 jcp.use_interleave_stores = true;
2034 jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
2035 // assuming 2x2 decomposition in amx brgemm kernel
2036 // and overlap of input by kw
2037 const auto bd_blocking = 2 * jcp.amx_h;
2038 const auto ld_blocking = 2 * 16;
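            // approximate L1 footprint of the A (source), B (weights), and
            // C (accumulator) blocks touched by one 2x2 tile step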
2039 const auto A_ds
2040 = jcp.src_dsz * bd_blocking * jcp.ic * jcp.kd * jcp.kh;
2041 const auto B_ds = jcp.wei_dsz * ld_blocking * jcp.ic * jcp.kd
2042 * jcp.kh * jcp.kw;
2043 const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking;
2044 if (A_ds + B_ds + C_ds > brg_blocking_t::L1)
2045 jcp.amx_tile_load_xx = true;
2046 }
2047
2048 try_exec_type_res = try_exec_type();
2049 }
2050 if (try_exec_base && try_exec_type_res == false) {
2051 jcp.exec_type = exec_base;
2052 try_exec_type_res = try_exec_type();
2053 }
2054
2055 if (try_exec_type_res == false) return status::unimplemented;
2056
2057 // ============ end blocking ===========================================
2058 if (jcp.exec_type == exec_vpad)
2059 jcp.max_vpad = nstl::max(jcp.l_pad, jcp.r_pad);
2060 else
2061 jcp.max_vpad = 0;
2062
2063 if (jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0)
2064 return status::unimplemented;
2065
2066 jcp.gemm_batch_size = jcp.nb_ic_blocking
2067 * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block,
2068 jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad);
2069    // to avoid concurrent cache write access from different threads
2070 size_t sc_size = sizeof(brgemm_batch_element_t);
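    // round the per-thread batch storage up to a whole 4K page and convert it
    // back to a number of batch elements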
2071 jcp.adjusted_batch_size
2072 = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size);
2073
2074 CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
2075 CHECK(attr.set_default_formats(&dst_md));
2076
2077 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
2078 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
2079 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
2080 jcp.with_scales = !src_scales.has_default_values()
2081 || !wei_scales.has_default_values();
2082 const int wei_mask_per_oc = 1 << (int)with_groups;
2083 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
2084
2085 // only common and per-oc-channel scales are supported
2086 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
2087 && src_scales.mask_ == 0 && dst_scales.has_default_values();
2088 if (!scales_ok) return status::unimplemented;
2089
2090 jcp.buffer_size = jcp.LDC * jcp.M;
2091
2092 jcp.nb_od = div_up(jcp.od, jcp.od_block);
2093 jcp.nb_oh = div_up(jcp.oh, jcp.oh_block);
2094
2095 if (jcp.exec_type == exec_trans) {
2096        // TODO: this is a rough estimate of the transposed input buffer size
2097 dim_t ds = jcp.copy_block_only
2098 ? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd,
2099 jcp.stride_d, jcp.dilate_d)
2100 + nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad))
2101 : jcp.idp;
2102 dim_t hs = jcp.copy_block_only
2103 ? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh,
2104 jcp.stride_h, jcp.dilate_h)
2105 + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad))
2106 : jcp.ihp;
2107 if (jcp.is_os_blocking)
2108 hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp);
2109
2110 jcp.inp_buffer_size = rnd_up(ds * hs * jcp.iwp * jcp.ngroups * jcp.nb_ic
2111 * jcp.ic_block * jcp.kh_sets * jcp.kw_sets,
2112 P4K);
2113 jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_od)
2114 * jcp.nb_oh * jcp.nb_ow * jcp.ngroups * jcp.nb_ic,
2115 P4K);
2116 }
2117
2118 const bool with_pad = jcp.f_pad > 0 || jcp.back_pad > 0 || jcp.t_pad > 0
2119 || jcp.b_pad > 0;
2120
2121 if (jcp.s8s8_avx512) {
2122 weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8;
2123 weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1;
2124 }
2125 if (jcp.src_zero_point && !is_amx(jcp.isa)) {
2126 weights_md.extra.flags
2127 |= memory_extra_flags::compensation_conv_asymmetric_src;
2128 weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1;
2129 }
2130
2131    // disable shapes with small ic but large spatial size, as well as
2132    // specific large-spatial shapes, for int8 conv
2133 const auto is_ok_large_spatial
2134 = IMPLICATION(!is_amx(jcp.isa) && jcp.ic <= 128,
2135 jcp.od * jcp.oh < 100
2136 || jcp.ic * jcp.oc_block * jcp.ow_block > 8192)
2137 && !(is_amx(jcp.isa) && jcp.ic < 16 && jcp.ndims == 4)
2138 && IMPLICATION(is_amx(jcp.isa) && jcp.ic <= 16,
2139 jcp.ow < 2048
2140 || div_up(jcp.ow_block, selected_ur) * jcp.kd
2141 * jcp.kh * jcp.kw
2142 > 8192)
2143 && !(!is_amx(jcp.isa) && jcp.oc == 1024
2144 && utils::everyone_is(1, jcp.od, jcp.oh, jcp.kd, jcp.kh)
2145 && jcp.ow >= 595 && jcp.kw <= 5);
2146 if (one_of(jcp.src_dt, u8, s8) && !is_ok_large_spatial)
2147 return status::unimplemented;
2148
2149    // For padded shapes, the compensation is calculated along with the main
2150    // computation inside the brgemm kernel when the output size is small, to
2151    // get optimal perf; otherwise it is done by the brgemm_comp_pad kernel
2152 const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.oc
2153 * jcp.od * jcp.oh * jcp.ow;
2154 const auto comp_with_pads = (jcp.src_zero_point || jcp.s8s8_avx512)
2155 && IMPLICATION(jcp.exec_type == exec_vpad, with_pad);
2156 jcp.req_brg_comp_pad = comp_with_pads && output_sz <= 8192 && jcp.oc < 512;
2157 jcp.req_cal_comp_pad = comp_with_pads && !jcp.req_brg_comp_pad;
2158
2159    // estimate the number of kernel range combinations for compensation
2160 const auto kd_cnt = 1 + utils::div_up(abs(jcp.f_pad), jcp.dilate_d + 1)
2161 + utils::div_up(abs(jcp.back_pad), jcp.dilate_d + 1);
2162 const auto kh_cnt = 1 + utils::div_up(abs(jcp.t_pad), jcp.dilate_h + 1)
2163 + utils::div_up(abs(jcp.b_pad), jcp.dilate_h + 1);
2164 const auto kw_cnt
2165 = (1
2166 + (utils::div_up(abs(jcp.l_pad), jcp.dilate_w + 1)
2167 + utils::div_up(
2168 abs(jcp.r_pad), jcp.dilate_w + 1)))
2169 * 2;
2170
2171 jcp.ker_ranges_size = jcp.exec_type == exec_base ? kd_cnt * kh_cnt * kw_cnt
2172 : kd_cnt * kh_cnt;
2173 jcp.comp_a_buffer_size
2174 = jcp.ngroups * jcp.nb_oc * jcp.ker_ranges_size * jcp.oc_block;
2175 jcp.s8s8_comp_buffer_size = jcp.comp_a_buffer_size;
2176
2177 if (!IMPLICATION(jcp.is_bf32, jcp.use_uker)) return status::unimplemented;
2178
2179 return status::success;
2180}
2181
2182status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
2183 const convolution_desc_t &cd, memory_desc_t &src_md,
2184 memory_desc_t &weights_md, memory_desc_t &dst_md,
2185 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
2186
2187 using namespace prop_kind;
2188 if (!mayiuse(isa)) return status::unimplemented;
2189
2190 CHECK(init_jcp(
2191 jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));
2192
2193 const memory_desc_wrapper src_d(&src_md);
2194 const memory_desc_wrapper weights_d(&weights_md);
2195 const memory_desc_wrapper dst_d(&dst_md);
2196 const memory_desc_wrapper bias_d(&bias_md);
2197
2198 if (!jcp.is_1x1) return status::unimplemented;
2199
2200 using namespace data_type;
2201 // ===================== blocking =================================
2202
2203 auto bcast_amount
2204 = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
2205 auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.wei_dsz;
2206
2207 jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;
2208
2209 if (is_amx(isa)) {
2210 // round up ic if needed
2211 const int vnni_width = brg_blocking_t::last_ic_block_size;
2212 const int n_vnni_blocks = utils::div_up(jcp.ic, vnni_width);
2213 const int ic_block
2214 = nstl::min(jcp.acc_simd_w, n_vnni_blocks) * vnni_width;
2215 const bool do_zeropad = (!jcp.is_bf32)
2216 && (jcp.ic % vnni_width != 0 || jcp.ic > ic_block);
2217 if (do_zeropad) jcp.ic = utils::rnd_up(jcp.ic, ic_block);
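        // e.g., bf16 weights (vnni_width = 2), ic = 3: n_vnni_blocks = 2,
        // ic_block = 4, so ic is zero-padded to 4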
2218 const auto ic_padded_block = jcp.acc_simd_w * vnni_width;
2219 jcp.is_ic_padded = jcp.ic > ic_padded_block && !(jcp.is_bf32);
2220
2221 // try to choose optimal loop order
2222 // TODO: incorporate loop order into smart blocking selection
2223 auto wei_size = (size_t)jcp.oc * jcp.ic * jcp.wei_dsz;
2224 auto max_size = 0.75f * brg_blocking_t::L2;
2225 const dim_t os = jcp.od * jcp.oh * jcp.ow;
2226        const dim_t os_cutoff = 400; // approximate and empirical
2227 const bool use_loop_ngcdhw
2228 = max_size < wei_size || (jcp.mb == 1 && os < os_cutoff);
2229 jcp.loop_order = use_loop_ngcdhw ? loop_ngcdhw : loop_ndhwgc;
2230 }
2231
2232 const auto min_oc_block = jcp.acc_simd_w;
2233
2234 jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM
2235
2236 // max_batch is 1 and max_vpad is 0 for 1x1 convolutions
2237 jcp.max_batch = 1;
2238 jcp.max_vpad = 0;
2239
2240 jcp.wei_plain = false;
2241
2242 brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
2243 best_brgb.oc_block = min_oc_block;
2244 brg_blocking_t cur_brgb = zero<decltype(cur_brgb)>();
2245 cur_brgb.get_from_jcp(jcp);
2246 auto start_ocb = 4;
2247 if (jcp.wei_plain)
2248 start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
2249 div_up(jcp.oc, jcp.acc_simd_w));
2250 start_ocb = nstl::min(div_up(jcp.oc, jcp.acc_simd_w), start_ocb);
2251
2252 auto finish_ocb = 1;
2253 for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
2254 cur_brgb.oc_block = ocb * min_oc_block;
2255 cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);
2256
2257 if (!cur_brgb.fast_check_oc_block_1x1()) continue;
2258
2259 cur_brgb.calc_blocks_1x1();
2260 const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
2261 if (st != status::success) continue;
2262 cur_brgb.eff = cur_brgb.est_eff_1x1();
2263 if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
2264 }
2265 best_brgb.save_to_jcp(jcp);
2266
2267 // =============== end blocking =================================
2268 jcp.brg_stride_a = jcp.ic_block * jcp.src_dsz;
2269 jcp.brg_stride_b = jcp.ic_block * jcp.oc * jcp.wei_dsz;
2270
2271 if (jcp.ic_block == 0 || jcp.oc_block == 0) return status::unimplemented;
2272
2273 // Configure matrix sizes
2274
2275 if (best_brgb.is_os_blocking) {
2276 if (jcp.os_block == 0) return status::unimplemented;
2277 jcp.M = jcp.brgM = jcp.os_block;
2278 jcp.M_tail = jcp.brgM_tail = jcp.os % jcp.os_block;
2279 } else {
2280 if (jcp.ow_block == 0) return status::unimplemented;
2281 jcp.M = jcp.brgM = jcp.ow_block;
2282 jcp.M_tail = jcp.brgM_tail = jcp.ow % jcp.ow_block;
2283 }
2284
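    // brgemm K (ic) and N (oc) dimensions: a full block if the channel count
    // covers at least one block (0 otherwise), with the remainder in the tail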
2285 jcp.K = jcp.ic >= jcp.ic_block ? jcp.ic_block : 0;
2286 jcp.N = jcp.oc >= jcp.oc_block ? jcp.oc_block : 0;
2287 jcp.N_tail = jcp.oc % jcp.oc_block;
2288 jcp.K_tail = jcp.ic % jcp.ic_block;
2289
2290 jcp.gemm_batch_size = jcp.nb_ic_blocking;
2291    // to avoid concurrent cache access from different threads
2292 size_t sc_size = sizeof(brgemm_batch_element_t);
2293 jcp.adjusted_batch_size
2294 = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size);
2295
2296 if (is_amx(isa)) {
2297 // heuristic for small mb
2298 const bool is_small_mb = jcp.mb == 1 && jcp.ic * jcp.oh <= 28 * 1024
2299 && jcp.oc * jcp.oh <= 14 * 1024;
2300        // the non-unrolled kernel does not support bf32, so only the unrolled
2301        // kernel is dispatched for now
2302 jcp.use_uker = jcp.is_bf32 || !is_small_mb;
2303 }
2304
2305 // TODO: heuristic to dispatch BF32 BRGeMM
2306 // The following condition checks for shapes where down-convert execution
2307 // in brgemm fails
2308 if (jcp.is_bf32 && jcp.ic < 64 && jcp.ic % 32 != 0)
2309 return status::unimplemented;
2310
2311 if (jcp.use_uker)
2312 jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
2313 CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
2314 CHECK(attr.set_default_formats(&dst_md));
2315
2316 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
2317
2318 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
2319 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
2320 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
2321 jcp.with_scales = !src_scales.has_default_values()
2322 || !wei_scales.has_default_values();
2323 const int wei_mask_per_oc = 1 << (int)with_groups;
2324 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
2325
2326 // only common and per-oc-channel scales are supported
2327 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
2328 && src_scales.mask_ == 0 && dst_scales.has_default_values();
2329 if (!scales_ok) return status::unimplemented;
2330
2331    // no brgemm_vpad for 1x1; an input buffer is needed only for rtus
2332 constexpr int align_size = platform::get_cache_line_size();
2333 jcp.exec_type = jcp.is_rtus ? exec_trans : exec_base;
2334 jcp.inp_buffer_size
2335 = jcp.is_rtus ? rnd_up(jcp.LDA * jcp.os, align_size) : 0;
2336 jcp.inp_buffer_mask_size = jcp.is_rtus
2337 ? rnd_up(div_up(jcp.nb_ic, jcp.nb_ic_blocking) * jcp.nb_os,
2338 align_size)
2339 : 0;
2340 jcp.buffer_size = jcp.LDC * jcp.M;
2341
2342 if (jcp.s8s8_avx512) {
2343 weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8;
2344 weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1;
2345 }
2346 if (jcp.src_zero_point) {
2347 weights_md.extra.flags
2348 |= memory_extra_flags::compensation_conv_asymmetric_src;
2349 weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1;
2350 }
2351 jcp.req_cal_comp_pad = false;
2352 jcp.s8s8_comp_buffer_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
2353 jcp.comp_a_buffer_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
2354
2355 return status::success;
2356}
2357
2358void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp) {
2359    // ensure that buffers of individual threads do not lie on the same page and
2360    // are not contiguous; the +1 guarantees an extra unused page whenever the
    // size is already a multiple of P4K
2361 jcp.amx_buf_size_per_thread
2362 = utils::rnd_up(jcp.amx_buf_size_per_thread + 1, P4K);
2363}
2364
2365void init_scratchpad(memory_tracking::registrar_t &scratchpad,
2366 const jit_brgemm_conv_conf_t &jcp) {
2367 if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs
2368 || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad))
2369 scratchpad.book(key_brgemm_primitive_batch,
2370 static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
2371 sizeof(brgemm_batch_element_t), 64, P4K);
2372 if (jcp.exec_type == exec_trans) {
2373 size_t inp_buffer_size
2374 = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size;
2375 scratchpad.book(key_conv_brgemm_inp_buffer, inp_buffer_size,
2376 jcp.src_dsz, 0, P4K);
2377 size_t inp_buffer_mask_size
2378 = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size;
2379 scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size,
2380 sizeof(uint8_t), 0, P4K);
2381 }
2382 if (jcp.use_buffer) {
2383 scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size,
2384 jcp.acc_dsz, 0, P4K);
2385 }
2386 if (is_amx(jcp.isa)) {
2387 scratchpad.book(key_conv_amx_tile_buffer,
2388 jcp.nthr * jcp.amx_buf_size_per_thread, sizeof(char), 0, P4K);
2389 }
2390 if (jcp.s8s8_avx512 && jcp.req_cal_comp_pad) {
2391 scratchpad.book(key_brgemm_primitive_buffer_comp,
2392 jcp.s8s8_comp_buffer_size, sizeof(int32_t), 0, P4K);
2393 }
2394 if (jcp.src_zero_point && jcp.req_cal_comp_pad && !is_amx(jcp.isa)) {
2395 scratchpad.book(key_brgemm_primitive_zp_comp_a, jcp.comp_a_buffer_size,
2396 sizeof(int32_t), 0, P4K);
2397 }
2398}
2399
2400void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) {
2401
2402 const auto os_chunks = jcp.nthr_mb_work;
2403 const auto oc_chunks = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
2404 const auto ic_chunks = div_up(jcp.nb_ic, jcp.nb_ic_blocking);
2405
2406 auto calc_mem_cost = [=](int nthr_mb, int nthr_g, int nthr_oc_b,
2407 int nthr_ic_b) {
2408        /* calculate per-thread memory cost (read/write). The high-level
2409         * optimizer tries to minimize memory consumption. A few notes:
2410 * (n1) if weights tensor size is less than source and destination
2411 * tensors we apply the ratio of the source and destination
2412 * tensor sizes to weights one as compensation coefficient to
2413 * avoid parallelization across batch size only, otherwise we
2414 * apply additional coefficient to source component based on
2415 * performance measurements
2416 * (n2) use scales based on output vs input channels ratio for
2417 * source and destination components to improve threading
2418 * balance across input and output channels */
2419
2420 const dim_t src_type_size = 2;
2421 const dim_t wei_type_size = 4;
2422 const dim_t acc_type_size = wei_type_size;
2423
2424 const auto wei_ks = jcp.kh * jcp.kw * jcp.kd;
2425
2426 const auto src_spatial = (dim_t)jcp.mb * jcp.id * jcp.ih * jcp.tr_iw;
2427 const auto dst_spatial = (dim_t)jcp.mb * jcp.od * jcp.oh * jcp.tr_ow;
2428
2429 dim_t src_size = src_spatial * jcp.ic * src_type_size;
2430 dim_t dst_size = dst_spatial * jcp.oc * src_type_size;
2431 dim_t wei_size = (dim_t)jcp.oc * jcp.ic * wei_ks * wei_type_size;
2432
2433 float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size;
2434 float oi_channels_ratio = (float)(oc_chunks) / ic_chunks;
2435
2436 auto get_src_coef = [=]() {
2437 float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f);
2438 if (wei_compensation_scale < 1.0f) src_coef *= 4.0f;
2439 return src_coef;
2440 };
2441
2442 auto get_dst_coef
2443 = [=]() { return nstl::max(oi_channels_ratio, 1.0f); };
2444
2445 auto get_wei_coef
2446 = [=]() { return nstl::max(wei_compensation_scale, 1.0f); };
2447
2448 const float src_coef = get_src_coef();
2449 const float dst_coef = get_dst_coef();
2450 const float wei_coef = get_wei_coef();
2451
2452 const auto thr_mb = div_up(os_chunks, nthr_mb);
2453 const auto nb_oc_job = jcp.oc_block * jcp.nb_oc_blocking;
2454 const auto nb_ic_job = jcp.ic_block * jcp.nb_ic_blocking;
2455
2456 const auto src_chunk = src_spatial / os_chunks;
2457 const auto dst_chunk = dst_spatial / os_chunks;
2458
2459 const auto thr_g = div_up(jcp.ngroups, nthr_g);
2460 const auto thr_ic_b = div_up(ic_chunks, nthr_ic_b);
2461 const auto thr_src_sp = thr_mb * src_chunk / jcp.stride_d / jcp.stride_h
2462 / jcp.stride_w;
2463 const auto thr_dst_sp = thr_mb * dst_chunk;
2464 const auto thr_ic_amount = thr_ic_b * nb_ic_job;
2465
2466 const auto thr_oc_b = div_up(oc_chunks, nb_oc_job * nthr_oc_b);
2467
2468 const auto thr_oc_amount = thr_oc_b * nb_oc_job;
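        // approximate per-thread bytes read/written for source, destination,
        // and weights under the candidate thread decomposition, weighted by
        // the coefficients above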
2469 float src_v
2470 = src_type_size * src_coef * thr_g * thr_ic_amount * thr_src_sp;
2471 float dst_v
2472 = src_type_size * dst_coef * thr_g * thr_oc_amount * thr_dst_sp;
2473 float wei_v = acc_type_size * wei_coef * thr_g * thr_oc_amount
2474 * thr_ic_amount * wei_ks;
2475
2476 return src_v + dst_v + wei_v;
2477 };
2478
2479 auto balance = [=](int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_,
2480 int &nthr_ic_b_) {
2481 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
2482
2483 if (jcp.nthr < jcp.ngroups) {
2484 /* simplification... fortunately it doesn't hurt much */
2485 nthr_ = nthr_g_ = jcp.nthr;
2486 return;
2487 }
2488
2489 nthr_g_ = jcp.ngroups;
2490 const int nthr = jcp.nthr / nthr_g_;
2491
2492 float best_mem_cost
2493 = calc_mem_cost(nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_);
2494
2495 /* find the best thread distribution with lowest memory cost */
2496
2497 const int nthr_mb_max = nstl::min(nthr, jcp.nthr_mb_work);
2498 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
2499 const int nthr_par = nthr / nthr_mb;
2500 const int nthr_oc_b_max = nstl::min(nthr_par,
2501 oc_chunks); // Amount of nb_oc_blocks
2502 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
2503 int nthr_ic_b = nstl::min(
2504 nthr_par / nthr_oc_b, (jcp.nb_ic / jcp.nb_ic_blocking));
2505
2506 float mem_cost
2507 = calc_mem_cost(nthr_mb, nthr_g_, nthr_oc_b, nthr_ic_b);
2508 if (mem_cost <= best_mem_cost) {
2509 best_mem_cost = mem_cost;
2510 nthr_mb_ = nthr_mb;
2511 nthr_oc_b_ = nthr_oc_b;
2512 nthr_ic_b_ = nthr_ic_b;
2513 }
2514 }
2515 }
2516
2517 if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr)
2518 nthr_mb_ = nstl::min(jcp.nthr_mb_work, nthr);
2519 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
2520
2521 assert(nthr_ <= jcp.nthr);
2522 };
2523
2524 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
2525 balance(nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
2526
2527    // empirical balancing for some shapes
2528 const auto sps = (jcp.ih * jcp.iw);
2529 bool neat_1x1
2530 = everyone_is(1, jcp.id, jcp.kh, jcp.kw, jcp.ngroups, jcp.stride_h);
2531 if (neat_1x1 && jcp.nthr >= 28 && jcp.mb >= jcp.nthr) {
2532 const bool more_oc = (jcp.ic < jcp.oc);
2533 if (sps >= 56 * 56 && jcp.ic >= 64 && jcp.oc >= 64) {
2534 nthr_mb = jcp.nthr;
2535 nthr_oc_b = 1;
2536 } else if (sps >= 28 * 28 && jcp.ic >= 128 && jcp.oc >= 128) {
2537 nthr_mb = jcp.nthr / 4;
2538 nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1;
2539 } else if (sps >= 14 * 14 && jcp.ic >= 256 && jcp.oc >= 256) {
2540 nthr_mb = div_up(jcp.nthr, 8);
2541 nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1;
2542 } else if (sps >= 7 * 7 && jcp.ic >= 512 && jcp.oc >= 512) {
2543 nthr_mb = div_up(jcp.nthr, 14);
2544 nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1;
2545 }
2546 nthr_ic_b = jcp.nthr / (nthr_mb * nthr_oc_b);
2547 nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b;
2548 }
2549
2550 jcp.nthr = nthr;
2551 jcp.nthr_mb = nthr_mb;
2552 jcp.nthr_g = nthr_g;
2553 jcp.nthr_oc_b = nthr_oc_b;
2554 jcp.nthr_ic_b = nthr_ic_b;
2555
2556 // TODO: Optimize memory allocation when threaded on height and depth
2557 jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id;
2558 jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od;
2559 jcp.tr_src_buf_count = jcp.global_transpose
2560 ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups
2561 : jcp.nthr;
2562 jcp.tr_diff_dst_buf_count = jcp.global_transpose
2563 ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups
2564 : jcp.nthr;
2565}
2566
2567status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
2568 const convolution_desc_t &cd, memory_desc_t &src_md,
2569 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
2570 memory_desc_t &diff_dst_md, primitive_attr_t &attr, int nthreads) {
2571
2572 const memory_desc_wrapper src_d(&src_md);
2573 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
2574 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
2575 const memory_desc_wrapper diff_bias_d(&diff_bias_md);
2576
2577 const bool is_f16 = src_d.data_type() == data_type::f16;
2578
2579 jcp.isa = is_f16 ? avx512_core_amx_fp16 : avx512_core_amx;
2580 if (!mayiuse(jcp.isa)) return status::unimplemented;
2581
2582 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2583 int ndims = src_d.ndims();
2584
2585 CHECK(init_jcp(jcp, jcp.isa, cd, src_md, diff_weights_md, diff_dst_md,
2586 diff_bias_md, attr, nthreads));
2587
2588 jcp.max_batch = jcp.od * jcp.oh;
2589 jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM
2590 jcp.use_uker = true;
2591 jcp.var_bs = true;
2592
2593 // Process some 1x1 convolutions with small iw as 1d (h=1, w = h*w)
2594 // convolutions to make brgemm K dimension bigger for better utilization of
2595 // AMX tiles
2596 bool neat_1x1_2d = (everyone_is(
2597 1, jcp.kh, jcp.kw, jcp.stride_h, jcp.stride_w)
2598 && everyone_is(0, jcp.t_pad, jcp.b_pad, jcp.l_pad, jcp.r_pad));
2599 bool make_1d = neat_1x1_2d && jcp.iw <= 28;
2600 if (make_1d) {
2601 jcp.iw *= jcp.ih;
2602 jcp.ih = 1;
2603 jcp.ow *= jcp.oh;
2604 jcp.oh = 1;
2605 jcp.max_batch = jcp.od;
2606 }
2607    // TODO: sometimes we can call the brgemm kernel with bs = 0 to do the
2608    // initialization; review this condition
2609 if (jcp.max_batch == 1
2610 && everyone_is(0, jcp.f_pad, jcp.back_pad, jcp.t_pad, jcp.b_pad))
2611 jcp.var_bs = false;
2612
2613 jcp.has_vnni = true; // Needed for transpose routines
2614
2615 jcp.typesize_in = sizeof(bfloat16_t);
2616 jcp.typesize_out = sizeof(float);
2617
2618 bool ok = true
2619 // general condition to simplify dilations
2620 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
2621 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
2622 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
2623 // special condition to simplify dilations in compute_oh_loop_common
2624 && IMPLICATION(jcp.dilate_h != 0, jcp.ext_kh <= jcp.ih);
2625 if (!ok) return status::unimplemented;
2626
2627 jcp.transform_to_vnni = diff_weights_d.data_type() != data_type::f32;
2628
2629 /* XXX: no support for padding when dilation_d > 0 */
2630 if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad)))
2631 return status::unimplemented;
2632
2633 const bool is_depthwise = true && with_groups && jcp.ngroups > 1
2634 && everyone_is(1, jcp.ic, jcp.oc);
2635 if (is_depthwise)
2636 return status::unimplemented; // TODO: add support of DW convolution
2637
2638 const int dat_format_tag = ndims - 3;
2639 format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc,
2640 format_tag::nhwc, format_tag::ndhwc);
2641 format_tag_t dat_tag_opt = dat_tag_nspc;
2642
2643 if (src_d.format_kind() == format_kind::any) {
2644 CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt));
2645 jcp.src_tag = dat_tag_opt;
2646 } else
2647 jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt);
2648 if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented;
2649
2650 const bool is_nspc = jcp.src_tag == dat_tag_nspc;
2651 if (!is_nspc) return status::unimplemented;
2652
2653 if (diff_dst_d.format_kind() == format_kind::any) {
2654 CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag));
2655 jcp.dst_tag = jcp.src_tag;
2656 } else
2657 jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag);
2658 if (jcp.dst_tag != jcp.src_tag) return status::unimplemented;
2659
2660 const int wei_format_tag = 2 * ndims - 6 + with_groups;
2661 format_tag_t wei_tag;
2662 if (jcp.transform_to_vnni)
2663 wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i,
2664 format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i,
2665 format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i,
2666 format_tag::gOIdhw16i16o2i);
2667 else
2668 wei_tag = pick(wei_format_tag, format_tag::OIw16i16o,
2669 format_tag::gOIw16i16o, format_tag::OIhw16i16o,
2670 format_tag::gOIhw16i16o, format_tag::OIdhw16i16o,
2671 format_tag::gOIdhw16i16o);
2672 if (diff_weights_md.format_kind == format_kind::any) {
2673 CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
2674 jcp.wei_tag = wei_tag;
2675 } else {
2676 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
2677 if (jcp.wei_tag != wei_tag) return status::unimplemented;
2678 }
2679 jcp.wei_dt = diff_weights_d.data_type();
2680
2681 /* kernel applicability check wrt boundaries
2682 * the conditions are quite general across the kernels we have,
2683 * but ideally the check should belong to a specific kernel... */
2684 const int max_pad_h = jcp.ext_kh / 2;
2685 const bool boundaries_ok = true && jcp.l_pad < jcp.ext_kw
2686 && jcp.r_pad < jcp.ext_kw && jcp.t_pad <= max_pad_h
2687 && jcp.b_pad <= max_pad_h && jcp.f_pad < jcp.ext_kd
2688 && jcp.back_pad < jcp.ext_kd;
2689 if (!boundaries_ok) return status::unimplemented;
2690
2691 jcp.ic_block = 16;
2692 jcp.oc_block = 16;
2693
2694 jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block);
2695 jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block);
2696
2697 jcp.ic_tail = jcp.ic % jcp.ic_block;
2698 jcp.oc_tail = jcp.oc % jcp.oc_block;
2699
2700 jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1;
2701 jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1;
2702
2703 const bool is_2d = (ndims == 4);
2704 const bool is_3d = (ndims == 5);
2705
2706 // TODO: Find more shapes (especially 3D with large spatials) for which
2707 // local transposition will be beneficial. Furthermore, for TBB threads
2708 // more shapes can potentially benefit from spatial blocking
2709 int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow;
2710
2711 jcp.global_transpose = dnnl_thr_syncable();
2712 jcp.spatial_blk_size = optimal_blk_size;
2713
2714 const int tr_round = 32; // To load full tile register
2715 int tr_pad
2716 = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round); //!!! why?
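    // tr_iw: width of a transposed input row, rounded up so that each strided
    // row fills whole 32-element tile loads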
2717 jcp.tr_iw = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, jcp.stride_w),
2718 tr_round)
2719 * jcp.stride_w;
2720
2721    // TODO: only xf16 training is supported
2722 const auto rnd_val = data_type_vnni_granularity(bf16);
2723 jcp.tr_src_num_guard_elems = tr_pad; // upper bound
2724 jcp.tr_ow = rnd_up(jcp.ow, rnd_val);
2725 if (jcp.tr_ow > tr_round) {
2726 // we may increase tr_ow to have better bd_block in brgemm kernel
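        // e.g., tr_ow = 56 (rnd_val = 2): divisor i = 28 already gives 2 bd
        // blocks, which no padded candidate up to 64 improves on, so tr_ow
        // stays 56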
2727 int best_bdb = jcp.tr_ow / rnd_val;
2728 int best_tr_ow = jcp.tr_ow;
2729 for (int tr_ow = jcp.tr_ow; tr_ow <= rnd_up(jcp.tr_ow, tr_round);
2730 tr_ow += rnd_val) {
2731 for (int i = tr_round; i > 0; i -= rnd_val) {
2732 if (tr_ow % i == 0) {
2733 const auto cbdb = tr_ow / i;
2734 if (cbdb < best_bdb) {
2735 best_bdb = cbdb;
2736 best_tr_ow = tr_ow;
2737 }
2738 break;
2739 }
2740 }
2741 }
2742 jcp.tr_ow = best_tr_ow;
2743 }
2744
2745 bool args_ok = true && jcp.ic <= src_d.padded_dims()[1]
2746 && jcp.oc <= diff_dst_d.padded_dims()[1]
2747 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
2748 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
2749 if (!args_ok) return status::unimplemented;
2750
2751 jcp.harness = ndims == 5 ? harness_3d_reduction : harness_2d_reduction;
2752
2753 if (!one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) {
2754 return status::unimplemented;
2755 }
2756
2757 switch (jcp.harness) {
2758 case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break;
2759 case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break;
2760 default: assert(!"Invalid harness"); jcp.nthr_mb_work = jcp.mb;
2761 }
2762
2763 balance_bwd_w(jcp);
2764
2765 if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) {
2766 jcp.K = jcp.tr_ow;
2767 }
2768 jcp.K_tail = 0;
2769
2770 jcp.M = jcp.ic_block * jcp.nb_ic_blocking;
2771 // assumption that jcp.nb_ic_blocking is always 2
2772 if (jcp.nb_ic % jcp.nthr_ic_b == 0
2773 && (jcp.nb_ic / jcp.nthr_ic_b) % jcp.nb_ic_blocking == 0)
2774 jcp.M_tail = 0;
2775 else
2776 jcp.M_tail = jcp.ic_block;
2777
2778 jcp.N = jcp.oc_block * jcp.nb_oc_blocking;
2779 // assumption that jcp.nb_oc_blocking is always 2
2780 if (jcp.nb_oc % jcp.nthr_oc_b == 0
2781 && (jcp.nb_oc / jcp.nthr_oc_b) % jcp.nb_oc_blocking == 0)
2782 jcp.N_tail = 0;
2783 else
2784 jcp.N_tail = jcp.oc_block;
2785
2786    // for convolutions with large spatial size: transpose only a chunk
2787    // (oc_block * nb_oc_blocking) of diff_dst on each iteration over oc blocks
2788    // for better cache utilization;
2789    // an equal number of channel blocks per thread is required to use this
2790    // approach without hangs
2791 bool tr_ocb_chunk_allowed = (jcp.nb_oc % jcp.nthr_oc_b == 0);
2792 jcp.tr_ocb_chunk = tr_ocb_chunk_allowed && (jcp.oh * jcp.ow > 38 * 38);
2793 jcp.tr_icb_chunk = false;
2794
2795 const int irow_size = jcp.src_dsz * jcp.tr_iw * jcp.ic_block
2796 * div_up(jcp.nb_ic, jcp.nthr_ic_b)
2797 * 2 /*we have real and transposed input */;
2798 const int orow_size = jcp.dst_dsz * jcp.tr_ow * jcp.oc_block
2799 * div_up(jcp.nb_oc, jcp.nthr_oc_b)
2800 * 2 /*we have real and transposed diff_dst*/;
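    // choose oh_block so that the kh input rows reused across the kernel plus
    // one input and one output row per output line fit into ~80% of L2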
2801 int oh_block_limit = nstl::max(1.f,
2802 nstl::max(0.f, 0.8f * brg_blocking_t::L2 - jcp.kh * irow_size)
2803 / (irow_size + orow_size));
2804 // try to split oh by equal oh blocks
2805 oh_block_limit = div_up(jcp.oh, div_up(jcp.oh, oh_block_limit));
2806 jcp.oh_block = utils::saturate(1, jcp.oh, oh_block_limit);
2807
2808 const int iframe_size = irow_size * jcp.id;
2809 const int oframe_size = orow_size * jcp.od;
2810 int od_block_limit = nstl::max(1.f,
2811 nstl::max(0.f, 0.8f * brg_blocking_t::L2 - jcp.kd * iframe_size)
2812 / (iframe_size + oframe_size));
2813 // try to split od by equal od blocks
2814 od_block_limit = div_up(jcp.od, div_up(jcp.od, od_block_limit));
2815 jcp.od_block = utils::saturate(1, jcp.od, od_block_limit);
2816
2817 jcp.use_interleave_stores = false;
2818 jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;
2819 jcp.amx_tile_load_xx = false;
2820
2821 if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) {
2822 jcp.LDA = jcp.tr_iw;
2823 jcp.LDB = jcp.oc_block;
2824 jcp.LDC = jcp.LDD = jcp.oc_block;
2825 }
2826
2827 jcp.gemm_batch_size = jcp.max_batch;
2828    // to avoid concurrent cache access from different threads
2829 size_t sc_size = sizeof(brgemm_batch_element_t);
2830 jcp.adjusted_batch_size
2831 = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size);
2832
2833 return status::success;
2834}
2835
2836status_t init_scratchpad_bwd_w(memory_tracking::registrar_t &scratchpad,
2837 const jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md,
2838 memory_desc_t &diff_weights_md, memory_desc_t &diff_dst_md) {
2839 const memory_desc_wrapper src_d(&src_md);
2840 const memory_desc_wrapper diff_weights_d(&diff_weights_md);
2841 const memory_desc_wrapper diff_dst_d(&diff_dst_md);
2842
2843 // XXX: See the comment about tr_iw and guarding elements in
2844 // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf()
2845 const size_t tr_src_size
2846 = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking)
2847 + jcp.tr_src_num_guard_elems;
2848 scratchpad.book(key_conv_tr_src, tr_src_size, jcp.src_dsz);
2849
2850 /* prepare synchronization contexts */
2851 if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
2852 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
2853 scratchpad.book<simple_barrier::ctx_t>(
2854 key_conv_tr_src_bctx, tr_src_bctx_size);
2855 }
2856
2857    // Since tr_ow <= tr_iw, we need some guarding at the end of diff_dst
2858 // TODO: update this guarding:
2859 // (jcp.tr_diff_dst_buf_size + jcp.tr_iw * jcp.oc_block)
2860 const auto tr_diff_dst_size = jcp.tr_diff_dst_buf_count
2862 * (jcp.tr_diff_dst_buf_size + jcp.tr_iw * jcp.oc_block)
2863 * jcp.nb_oc_blocking;
2864
2865 const size_t min_align = 64;
2866 scratchpad.book(
2867 key_conv_tr_diff_dst, tr_diff_dst_size, jcp.src_dsz, min_align);
2868
2869 /* prepare synchronization contexts */
2870 if (jcp.global_transpose && jcp.nthr_ic_b > 1) {
2871 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
2872 scratchpad.book<simple_barrier::ctx_t>(
2873 key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size);
2874 }
2875
2876 if (IMPLICATION(jcp.nthr_mb == 1,
2877 (jcp.with_bias && jcp.bia_dt != data_type::f32)
2878 || jcp.wei_dt != data_type::f32)) {
2879 const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
2880 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
2881 const size_t bia_size
2882 = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block;
2883
2884 const int num_wei_buffers
2885 = jcp.wei_dt != data_type::f32 ? jcp.nthr_mb : jcp.nthr_mb - 1;
2886 const int num_bia_buffers = jcp.with_bias
2887 ? (jcp.bia_dt != data_type::f32 ? jcp.nthr_mb : jcp.nthr_mb - 1)
2888 : 0;
2889
2890 const size_t wei_bia_reduction_size
2891 = wei_size * num_wei_buffers + bia_size * num_bia_buffers;
2892
2893 scratchpad.book<float>(
2894 key_conv_wei_bia_reduction, wei_bia_reduction_size);
2895
2896 scratchpad.book<simple_barrier::ctx_t>(
2897 key_conv_wei_bia_reduction_bctx, 1);
2898 }
2899
2900 if (jcp.with_bias
2901 && ((jcp.oc % jcp.oc_block != 0) && jcp.bia_dt == data_type::f32)) {
2902 scratchpad.book(key_conv_padded_bias,
2903 jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.bia_dsz);
2904 }
2905 scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline
2906
2907 constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
2908            << 30; // 32 GB - TODO: maybe it's too large?
2909 const size_t scratchpad_limit_by_tensor_sizes = (size_t)64 * jcp.nthr
2910 * (src_d.size() + diff_weights_d.size() + diff_dst_d.size());
2911 const size_t scratchpad_limit
2912 = nstl::min(scratchpad_limit_by_absolute_value,
2913 scratchpad_limit_by_tensor_sizes);
2914
2915 scratchpad.book(key_brgemm_primitive_batch,
2916 static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
2917 sizeof(brgemm_batch_element_t), 64, P4K);
2918
2919 scratchpad.book(
2920 key_conv_amx_tile_buffer, jcp.nthr * 2 * P4K, sizeof(char), 0, P4K);
2921
2922 if (scratchpad.size() > scratchpad_limit)
2923 return status::unimplemented;
2924 else
2925 return status::success;
2926}
2927
2928} // namespace brgemm_convolution_utils
2929
2930} // namespace x64
2931} // namespace cpu
2932} // namespace impl
2933} // namespace dnnl
2934