1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_types.h" |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/platform.hpp" |
27 | #include "cpu/scale_utils.hpp" |
28 | #include "cpu/x64/brgemm/brgemm_utils.hpp" |
29 | #include "cpu/x64/cpu_barrier.hpp" |
30 | #include "cpu/x64/cpu_isa_traits.hpp" |
31 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
32 | #include "cpu/x64/jit_brgemm_conv_utils.hpp" |
33 | #include "cpu/x64/jit_generator.hpp" |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace cpu { |
38 | namespace x64 { |
39 | |
40 | using namespace dnnl::impl::status; |
41 | using namespace dnnl::impl::format_tag; |
42 | using namespace dnnl::impl::memory_tracking::names; |
43 | using namespace dnnl::impl::utils; |
44 | |
45 | using namespace prop_kind; |
46 | using namespace data_type; |
47 | |
48 | namespace brgemm_convolution_utils { |
49 | |
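// Checks whether a src/dst memory descriptor created with format_kind::any
// may be initialized with the implementation-preferred tag (see init_tag()).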
50 | bool is_any_eligible(const jit_brgemm_conv_conf_t &jcp) { |
51 | return (jcp.prop_kind == prop_kind::forward_inference |
52 | || one_of(jcp.wei_dt, data_type::s8, data_type::f16) |
53 | || (jcp.isa == avx2_vnni_2) || is_amx(jcp.isa)); |
54 | } |
55 | |
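// Initializes `tag` (and `md`, if its format is 'any' and any_eligible) with
// tag_value; returns unimplemented when the user-provided format differs.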
56 | inline status_t init_tag(format_tag_t &tag, memory_desc_t &md, |
57 | const memory_desc_wrapper &mdw, const format_tag_t tag_value, |
58 | bool any_eligible) { |
59 | |
60 | if (mdw.format_kind() == format_kind::any) { |
61 | if (any_eligible) { |
62 | CHECK(memory_desc_init_by_tag(md, tag_value)); |
63 | tag = tag_value; |
64 | } else { |
65 | tag = format_tag::undef; |
66 | } |
67 | } else { |
68 | tag = mdw.matches_one_of_tag(tag_value); |
69 | } |
70 | |
71 | if (tag != tag_value) return status::unimplemented; |
72 | |
73 | return status::success; |
74 | } |
75 | |
76 | bool is_amx(cpu_isa_t isa) { |
77 | return is_superset(isa, avx512_core_amx); |
78 | } |
79 | |
80 | bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr, |
81 | const memory_desc_wrapper &dst_d) { |
82 | using namespace injector; |
83 | |
84 | const auto &post_ops = attr.post_ops_; |
85 | |
86 | return injector::post_ops_ok(post_ops_ok_args_t(jcp.isa, |
87 | {sum, eltwise, binary}, post_ops, &dst_d, |
88 | false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, |
89 | false /*sum_requires_zp_zero*/, |
90 | {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar, |
91 | broadcasting_strategy_t::no_broadcast})); |
92 | } |
93 | |
94 | bool is_groups_ok(jit_brgemm_conv_conf_t &jcp) { |
    // Enable grouped convs for shapes not supported by direct convs:
    // the direct approach supports int8/bf16 grouped convs only
    // when the number of channels per group is a multiple of 4,
    // and bf16 grouped convs with the nxc layout on the jit_bf16 impl.
    // TODO: remove this condition after the restriction on small ic is removed
100 | return jcp.ngroups > 1 |
101 | && IMPLICATION(one_of(jcp.src_dt, u8, s8, bf16), |
102 | jcp.ic % 4 == 0 && jcp.oc % 4 == 0); |
103 | } |
104 | |
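// Picks plain activation tags (nwc/nhwc/ndhwc) and a weights tag matching the
// oc_block, the VNNI granularity of wei_dt, and the grouped/ic-padded
// properties of the convolution.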
105 | status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md, |
106 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
107 | memory_desc_t &bias_md) { |
108 | format_tag_t src_tag, dst_tag, wei_tag; |
109 | dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc); |
110 | |
111 | const memory_desc_wrapper src_d(&src_md); |
112 | const memory_desc_wrapper weights_d(&weights_md); |
113 | const memory_desc_wrapper dst_d(&dst_md); |
114 | const memory_desc_wrapper bias_d(&bias_md); |
115 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
116 | const int vnni_granularity |
117 | = (jcp.wei_dt == f16 && jcp.isa == avx512_core_fp16) |
118 | ? 1 |
119 | : data_type_vnni_granularity(jcp.wei_dt); |
120 | |
121 | const bool is_1d = jcp.ndims == 3; |
122 | const bool is_2d = jcp.ndims == 4; |
123 | const bool is_3d = jcp.ndims == 5; |
124 | |
125 | if (jcp.wei_plain) { |
126 | jcp.LDB = jcp.oc; |
127 | if (is_3d) { |
128 | switch (vnni_granularity) { |
129 | case 1: wei_tag = with_groups ? gdhwio : dhwio; break; |
130 | case 2: wei_tag = with_groups ? gdhwIo2i : dhwIo2i; break; |
131 | case 4: wei_tag = with_groups ? gdhwIo4i : dhwIo4i; break; |
132 | default: return status::unimplemented; |
133 | } |
134 | } else if (is_1d) { |
135 | switch (vnni_granularity) { |
136 | case 1: wei_tag = with_groups ? gwio : wio; break; |
137 | case 2: wei_tag = with_groups ? gwIo2i : wIo2i; break; |
138 | case 4: wei_tag = with_groups ? gwIo4i : wIo4i; break; |
139 | default: return status::unimplemented; |
140 | } |
141 | } else { |
142 | assert(is_2d); |
143 | UNUSED(is_2d); |
144 | switch (vnni_granularity) { |
145 | case 1: wei_tag = with_groups ? ghwio : hwio; break; |
146 | case 2: wei_tag = with_groups ? ghwIo2i : hwIo2i; break; |
147 | case 4: wei_tag = with_groups ? ghwIo4i : hwIo4i; break; |
148 | default: return status::unimplemented; |
149 | } |
150 | } |
151 | } else { |
152 | jcp.LDB = jcp.oc_block; |
153 | if (jcp.oc_block == 64) { |
154 | if (is_3d) { |
155 | switch (vnni_granularity) { |
156 | case 1: wei_tag = with_groups ? gOdhwi64o : Odhwi64o; break; |
157 | case 2: |
158 | if (jcp.is_ic_padded) |
159 | wei_tag = with_groups ? gOdhwI16i64o2i |
160 | : OdhwI16i64o2i; |
161 | else |
162 | wei_tag = with_groups ? gOdhwI64o2i : OdhwI64o2i; |
163 | break; |
164 | case 4: |
165 | if (jcp.is_ic_padded) |
166 | wei_tag = with_groups ? gOdhwI16i64o4i |
167 | : OdhwI16i64o4i; |
168 | else |
169 | wei_tag = with_groups ? gOdhwI64o4i : OdhwI64o4i; |
170 | break; |
171 | default: return status::unimplemented; |
172 | } |
173 | } else if (is_1d) { |
174 | switch (vnni_granularity) { |
175 | case 1: wei_tag = with_groups ? gOwi64o : Owi64o; break; |
176 | case 2: |
177 | if (jcp.is_ic_padded) |
178 | wei_tag = with_groups ? gOwI16i64o2i : OwI16i64o2i; |
179 | else |
180 | wei_tag = with_groups ? gOwI64o2i : OwI64o2i; |
181 | break; |
182 | case 4: |
183 | if (jcp.is_ic_padded) |
184 | wei_tag = with_groups ? gOwI16i64o4i : OwI16i64o4i; |
185 | else |
186 | wei_tag = with_groups ? gOwI64o4i : OwI64o4i; |
187 | break; |
188 | default: return status::unimplemented; |
189 | } |
190 | } else { |
191 | assert(is_2d); |
192 | UNUSED(is_2d); |
193 | switch (vnni_granularity) { |
194 | case 1: wei_tag = with_groups ? gOhwi64o : Ohwi64o; break; |
195 | case 2: |
196 | if (jcp.is_ic_padded) |
197 | wei_tag = with_groups ? gOhwI16i64o2i |
198 | : OhwI16i64o2i; |
199 | else |
200 | wei_tag = with_groups ? gOhwI64o2i : OhwI64o2i; |
201 | break; |
202 | case 4: |
203 | if (jcp.is_ic_padded) |
204 | wei_tag = with_groups ? gOhwI16i64o4i |
205 | : OhwI16i64o4i; |
206 | else |
207 | wei_tag = with_groups ? gOhwI64o4i : OhwI64o4i; |
208 | break; |
209 | default: return status::unimplemented; |
210 | } |
211 | } |
212 | } else if (jcp.oc_block == 48) { |
213 | if (is_3d) { |
214 | switch (vnni_granularity) { |
215 | case 1: wei_tag = with_groups ? gOdhwi48o : Odhwi48o; break; |
216 | case 2: |
217 | if (jcp.is_ic_padded) |
218 | wei_tag = with_groups ? gOdhwI16i48o2i |
219 | : OdhwI16i48o2i; |
220 | else |
221 | wei_tag = with_groups ? gOdhwI48o2i : OdhwI48o2i; |
222 | break; |
223 | case 4: |
224 | if (jcp.is_ic_padded) |
225 | wei_tag = with_groups ? gOdhwI16i48o4i |
226 | : OdhwI16i48o4i; |
227 | else |
228 | wei_tag = with_groups ? gOdhwI48o4i : OdhwI48o4i; |
229 | break; |
230 | default: return status::unimplemented; |
231 | } |
232 | } else if (is_1d) { |
233 | switch (vnni_granularity) { |
234 | case 1: wei_tag = with_groups ? gOwi48o : Owi48o; break; |
235 | case 2: |
236 | if (jcp.is_ic_padded) |
237 | wei_tag = with_groups ? gOwI16i48o2i : OwI16i48o2i; |
238 | else |
239 | wei_tag = with_groups ? gOwI48o2i : OwI48o2i; |
240 | break; |
241 | case 4: |
242 | if (jcp.is_ic_padded) |
243 | wei_tag = with_groups ? gOwI16i48o4i : OwI16i48o4i; |
244 | else |
245 | wei_tag = with_groups ? gOwI48o4i : OwI48o4i; |
246 | break; |
247 | default: return status::unimplemented; |
248 | } |
249 | } else { |
250 | assert(is_2d); |
251 | UNUSED(is_2d); |
252 | switch (vnni_granularity) { |
253 | case 1: wei_tag = with_groups ? gOhwi48o : Ohwi48o; break; |
254 | case 2: |
255 | if (jcp.is_ic_padded) |
256 | wei_tag = with_groups ? gOhwI16i48o2i |
257 | : OhwI16i48o2i; |
258 | else |
259 | wei_tag = with_groups ? gOhwI48o2i : OhwI48o2i; |
260 | break; |
261 | case 4: |
262 | if (jcp.is_ic_padded) |
263 | wei_tag = with_groups ? gOhwI16i48o4i |
264 | : OhwI16i48o4i; |
265 | else |
266 | wei_tag = with_groups ? gOhwI48o4i : OhwI48o4i; |
267 | break; |
268 | default: return status::unimplemented; |
269 | } |
270 | } |
271 | } else if (jcp.oc_block == 32) { |
272 | if (is_3d) { |
273 | switch (vnni_granularity) { |
274 | case 1: wei_tag = with_groups ? gOdhwi32o : Odhwi32o; break; |
275 | case 2: |
276 | if (jcp.is_ic_padded) |
277 | wei_tag = with_groups ? gOdhwI16i32o2i |
278 | : OdhwI16i32o2i; |
279 | else |
280 | wei_tag = with_groups ? gOdhwI32o2i : OdhwI32o2i; |
281 | break; |
282 | case 4: |
283 | if (jcp.is_ic_padded) |
284 | wei_tag = with_groups ? gOdhwI16i32o4i |
285 | : OdhwI16i32o4i; |
286 | else |
287 | wei_tag = with_groups ? gOdhwI32o4i : OdhwI32o4i; |
288 | break; |
289 | default: return status::unimplemented; |
290 | } |
291 | } else if (is_1d) { |
292 | switch (vnni_granularity) { |
293 | case 1: wei_tag = with_groups ? gOwi32o : Owi32o; break; |
294 | case 2: |
295 | if (jcp.is_ic_padded) |
296 | wei_tag = with_groups ? gOwI16i32o2i : OwI16i32o2i; |
297 | else |
298 | wei_tag = with_groups ? gOwI32o2i : OwI32o2i; |
299 | break; |
300 | case 4: |
301 | if (jcp.is_ic_padded) |
302 | wei_tag = with_groups ? gOwI16i32o4i : OwI16i32o4i; |
303 | else |
304 | wei_tag = with_groups ? gOwI32o4i : OwI32o4i; |
305 | break; |
306 | default: return status::unimplemented; |
307 | } |
308 | } else { |
309 | assert(is_2d); |
310 | UNUSED(is_2d); |
311 | switch (vnni_granularity) { |
312 | case 1: wei_tag = with_groups ? gOhwi32o : Ohwi32o; break; |
313 | case 2: |
314 | if (jcp.is_ic_padded) |
315 | wei_tag = with_groups ? gOhwI16i32o2i |
316 | : OhwI16i32o2i; |
317 | else |
318 | wei_tag = with_groups ? gOhwI32o2i : OhwI32o2i; |
319 | break; |
320 | case 4: |
321 | if (jcp.is_ic_padded) |
322 | wei_tag = with_groups ? gOhwI16i32o4i |
323 | : OhwI16i32o4i; |
324 | else |
325 | wei_tag = with_groups ? gOhwI32o4i : OhwI32o4i; |
326 | break; |
327 | default: return status::unimplemented; |
328 | } |
329 | } |
330 | } else if (jcp.oc_block == 16) { |
331 | if (is_3d) { |
332 | switch (vnni_granularity) { |
333 | case 1: wei_tag = with_groups ? gOdhwi16o : Odhwi16o; break; |
334 | case 2: |
335 | if (jcp.is_ic_padded) |
336 | wei_tag = with_groups ? gOdhwI16i16o2i |
337 | : OdhwI16i16o2i; |
338 | else |
339 | wei_tag = with_groups ? gOdhwI16o2i : OdhwI16o2i; |
340 | break; |
341 | case 4: |
342 | if (jcp.is_ic_padded) |
343 | wei_tag = with_groups ? gOdhwI16i16o4i |
344 | : OdhwI16i16o4i; |
345 | else |
346 | wei_tag = with_groups ? gOdhwI16o4i : OdhwI16o4i; |
347 | break; |
348 | default: return status::unimplemented; |
349 | } |
350 | } else if (is_1d) { |
351 | switch (vnni_granularity) { |
352 | case 1: wei_tag = with_groups ? gOwi16o : Owi16o; break; |
353 | case 2: |
354 | if (jcp.is_ic_padded) |
355 | wei_tag = with_groups ? gOwI16i16o2i : OwI16i16o2i; |
356 | else |
357 | wei_tag = with_groups ? gOwI16o2i : OwI16o2i; |
358 | break; |
359 | case 4: |
360 | if (jcp.is_ic_padded) |
361 | wei_tag = with_groups ? gOwI16i16o4i : OwI16i16o4i; |
362 | else |
363 | wei_tag = with_groups ? gOwI16o4i : OwI16o4i; |
364 | break; |
365 | default: return status::unimplemented; |
366 | } |
367 | } else { |
368 | assert(is_2d); |
369 | UNUSED(is_2d); |
370 | |
371 | switch (vnni_granularity) { |
372 | case 1: wei_tag = with_groups ? gOhwi16o : Ohwi16o; break; |
373 | case 2: |
374 | if (jcp.is_ic_padded) |
375 | wei_tag = with_groups ? gOhwI16i16o2i |
376 | : OhwI16i16o2i; |
377 | else |
378 | wei_tag = with_groups ? gOhwI16o2i : OhwI16o2i; |
379 | break; |
380 | case 4: |
381 | if (jcp.is_ic_padded) |
382 | wei_tag = with_groups ? gOhwI16i16o4i |
383 | : OhwI16i16o4i; |
384 | else |
385 | wei_tag = with_groups ? gOhwI16o4i : OhwI16o4i; |
386 | break; |
387 | default: return status::unimplemented; |
388 | } |
389 | } |
390 | } else if (jcp.oc_block == 8) { |
391 | if (vnni_granularity == 1) |
392 | wei_tag = with_groups ? gOhwi8o : Ohwi8o; |
393 | else |
394 | return status::unimplemented; |
395 | } else { |
396 | return status::unimplemented; |
397 | } |
398 | } |
399 | |
400 | src_tag = dst_tag; |
401 | |
402 | const bool any_eligible = is_any_eligible(jcp); |
403 | CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible)); |
404 | CHECK(init_tag(jcp.dst_tag, dst_md, dst_d, dst_tag, any_eligible)); |
405 | CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag, true)); |
406 | |
407 | return status::success; |
408 | } |
409 | |
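// Blocking descriptor used to search for the best brgemm convolution
// parameters: extends jit_brgemm_conv_conf_t with an efficiency score (eff)
// and a model of the loop nest (loop[]) used to estimate data reuse.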
410 | struct brg_blocking_t : public jit_brgemm_conv_conf_t { |
411 | struct array_in_loop_t { |
412 | dim_t itersize; |
413 | float repeatn; |
414 | float overlap; |
415 | void set(dim_t iter_s, float rpt, float ovlp = 1.f) { |
416 | itersize = iter_s; |
417 | repeatn = rpt; |
418 | overlap = ovlp; |
419 | } |
420 | }; |
421 | |
422 | struct loop_t { |
423 | array_in_loop_t src; |
424 | array_in_loop_t wei; |
425 | array_in_loop_t dst; |
426 | }; |
427 | |
428 | brg_blocking_t() { |
429 | // TODO: This is a broken form of initialization for a base class. |
430 | // Either set default values in a base class, or provide a proper |
431 | // default ctor, or take a `jit_brgemm_conv_conf_t` object to initialize |
432 | // a base class object. |
433 | jit_brgemm_conv_conf_t *base |
434 | = static_cast<jit_brgemm_conv_conf_t *>(this); |
435 | *base = jit_brgemm_conv_conf_t(); |
436 | init(); |
437 | } |
438 | brg_blocking_t(const jit_brgemm_conv_conf_t &jcp) |
439 | : jit_brgemm_conv_conf_t(jcp) { |
440 | init(); |
441 | } |
442 | void init() { |
443 | ur = 0; |
444 | ur_block = 0; |
445 | ur_block_tail = 0; |
446 | eff = 0.f; |
447 | nb_kd = 0; |
448 | nb_kh = 0; |
449 | nb_kw = 0; |
450 | sp = 0; |
451 | sp_block = 0; |
452 | nb_sp = 0; |
453 | eff = 0; |
454 | max_regs = isa_num_vregs(isa); |
455 | bcast_simd = acc_simd_w; |
456 | } |
457 | |
458 | int ur, ur_block, ur_block_tail; |
459 | int nb_kd, nb_kh, nb_kw; |
460 | int max_regs; |
461 | int bcast_simd; |
462 | float eff; |
463 | static unsigned L1; |
464 | static unsigned L2; |
465 | static unsigned L3; |
    // Rough estimates of the relative latency of access to the various cache
    // levels. This is sufficient for estimating the data access cost.
    // TODO: Improve memory access estimates
469 | static constexpr float L1_k = 1.f; |
470 | static constexpr float L2_k = 3.f; |
471 | static constexpr float L3_k = 15.f; |
    // TODO: At the moment, we primarily evaluate whether the data fits into
    // L1/L2. The difference between L3 and main memory needs to be taken
    // into account as well.
475 | static constexpr float mem_k = 15.f; |
476 | static constexpr int bench_iterations = 1; |
477 | |
478 | int sp, sp_block, nb_sp; |
479 | static int last_ic_block_size; |
480 | |
481 | void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; } |
482 | void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; } |
483 | |
484 | status_t estimate_brgemm_ur(); |
485 | status_t get_brgemm_ur( |
486 | const primitive_attr_t *attr, const memory_desc_t &dst_md); |
487 | |
488 | float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk, |
489 | bool is_broadcast, bool is_shared) const; |
490 | |
491 | float io_k(const loop_t loop, const array_in_loop_t arr, float pk, |
492 | bool is_broadcast, bool is_shared) const; |
493 | |
494 | void select_ic_block(); |
495 | |
496 | void update_blocks(); |
497 | bool fast_check_oc_block() const; |
498 | float est_eff(); |
499 | void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block, |
500 | int kh_block, bool maybe_use_buffer, int max_ow_block_thr); |
501 | status_t calc_blocks(); |
502 | |
503 | bool fast_check_oc_block_1x1() const; |
504 | float est_eff_1x1(); |
505 | void calc_blocks_1x1(); |
506 | |
507 | // utils |
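    // Number of input spatial points required to compute `dst_size` outputs
    // with the given kernel size, stride, and dilation, clamped to
    // max_src_size.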
508 | static int get_inp_size( |
509 | int max_src_size, int dst_size, int k, int stride, int dilate) { |
510 | auto adj_str = nstl::min(k, stride); |
511 | const auto res = nstl::min(max_src_size, |
512 | calculate_end_padding(0, dst_size, 0, adj_str, |
513 | calculate_extended_filter_size(k, dilate))); |
514 | return res; |
515 | } |
516 | |
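    // Adjusts `eff` by `koeff`: koeff == 1 keeps eff as is; koeff < 1 pulls
    // eff toward 1, e.g. squeeze_val(0.5f, 0.5f) == 0.75f (the average of
    // eff and 1); koeff > 1 scales eff by koeff; non-positive koeff yields 1.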
517 | static float squeeze_val(float eff, float koeff) { |
518 | if (koeff <= 0) return 1; |
519 | if (koeff == 1) return eff; |
520 | const auto k = 1.f / koeff; |
521 | return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff; |
522 | } |
523 | |
524 | static int estimate_ur(int oc_block) { |
525 | const auto est_ur = (oc_block == 64) |
526 | ? 6 |
527 | : ((oc_block == 48) ? 9 : ((oc_block == 32) ? 14 : 28)); |
528 | return est_ur; |
529 | } |
530 | |
531 | int inp_w(int out_w, int ker_w) const { |
532 | return get_inp_size(iw, out_w, ker_w, stride_w, dilate_w); |
533 | } |
534 | |
535 | int rnd_simd(int val) const { return rnd_up(val, simd_w); } |
536 | |
537 | int rnd_inp_simd(int out_w, int ker_w, int vic) const { |
538 | const auto vsp = inp_w(out_w, ker_w); |
539 | return ((stride_w == 1 && vic >= ic) ? rnd_up(vsp * vic, simd_w) |
540 | : vsp * rnd_up(vic, simd_w)); |
541 | } |
542 | |
543 | static constexpr int MAXNLOOPS = 32; |
544 | loop_t loop[MAXNLOOPS]; |
545 | }; |
546 | |
547 | unsigned brg_blocking_t::L1; |
548 | unsigned brg_blocking_t::L2; |
549 | unsigned brg_blocking_t::L3; |
550 | int brg_blocking_t::last_ic_block_size; |
551 | |
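// Estimates the average relative cost of accessing an array in a loop that
// repeats n times: the first access costs `pk` (the cost inherited from the
// outer loop level), and the remaining n - 1 accesses cost the latency
// constant of the cache level the data fits into, i.e. the result is
// (pk + k * (n - 1)) / n.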
552 | float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk, |
553 | bool is_broadcast, bool is_shared) const { |
554 | if (n < 1) return 0; |
555 | if (n == 1) return pk; |
556 | const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz |
557 | + (use_buffer ? dst * acc_dsz : 0); |
558 | const auto amount_L1 = is_broadcast ? src * src_dsz : amount; |
559 | const auto k = is_broadcast |
560 | ? ((amount_L1 < L1) ? L1_k |
561 | : ((amount < L2) ? L2_k |
562 | : (is_shared ? L3_k : mem_k))) |
563 | : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k)); |
564 | const auto cost = pk + k * (n - 1); |
565 | return cost / n; |
566 | } |
567 | |
568 | float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr, |
569 | float pk, bool is_broadcast, bool is_shared) const { |
570 | return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize, |
571 | arr.repeatn * arr.overlap, pk, is_broadcast, is_shared); |
572 | } |
573 | |
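// Selects ic_block (and nb_ic). For AMX the brgemm kernel requirements are
// followed directly; otherwise the largest number of simd blocks is chosen
// that roughly fits src into L1 and weights plus src into L2 while keeping
// the blocking efficiency above nb_icb_eff_threshold.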
574 | void brg_blocking_t::select_ic_block() { |
575 | if (is_1x1 && is_amx(isa)) { |
576 | // TODO: merge with non-1x1 code block below |
577 | const int ic_padded_block = 16 * brg_blocking_t::last_ic_block_size; |
578 | assert(IMPLICATION( |
579 | !is_bf32, ic < ic_padded_block || ic % ic_padded_block == 0)); |
580 | MAYBE_UNUSED(ic_padded_block); |
        // Note: bf32 requires ic_block to be no larger than 64; otherwise it
        // produces incorrect output.
583 | ic_block = is_bf32 && (!is_rtus) ? nstl::min(64, ic) : ic; |
584 | nb_ic = utils::div_up(ic, ic_block); // trivially 1 for now |
585 | return; |
586 | } |
587 | auto nb_simd = utils::div_up(ic, simd_w); |
588 | auto max_simd_blocks = nstl::min(5 * simd_w, nb_simd); |
589 | const auto nb_icb_eff_threshold = 0.5f; |
590 | const auto padded_ic = last_ic_block_size * (is_ic_padded ? acc_simd_w : 1); |
591 | if (is_amx(isa)) { |
592 | if (ic * kw_sets < simd_w) { |
            // this is the current requirement from the brgemm kernel
594 | ic_block = rnd_up(ic, last_ic_block_size); |
595 | } else if (is_bf32) { |
596 | ic_block = simd_w; |
597 | } else { |
598 | if (exec_type == exec_trans) { |
599 | auto simd_blocks = 1; |
600 | for (int nb_icb = max_simd_blocks; nb_icb >= 1; nb_icb--) { |
601 | auto nb_icb_eff = static_cast<float>(nb_simd) |
602 | / rnd_up(nb_simd, nb_icb); |
603 | if (nb_icb_eff >= nb_icb_eff_threshold) { |
604 | simd_blocks = nb_icb; |
605 | break; |
606 | } |
607 | } |
608 | ic_block = simd_blocks * simd_w; |
609 | } else |
610 | ic_block = simd_w; |
611 | } |
612 | } else { |
613 | const auto est_ur = nstl::min(sp_block, estimate_ur(oc_block)); |
614 | const auto inp_ur = is_os_blocking ? est_ur : inp_w(est_ur, kw_block); |
615 | |
616 | if (kw_block > 1) { |
617 | // try to fit src into L1 |
618 | const auto inp_per_ic = static_cast<unsigned int>(inp_ur) * src_dsz; |
619 | max_simd_blocks = saturate(1, max_simd_blocks, |
620 | static_cast<int>(L1 / (inp_per_ic * simd_w))); |
621 | } |
        // try to fit the whole batch for ur into L2
623 | const auto wei_per_ic = static_cast<unsigned int>(kd_block) * kh_block |
624 | * kw_block * oc_block * wei_dsz; |
625 | const auto inp_per_ic = static_cast<unsigned int>(kd_block) * kh_block |
626 | * inp_ur * src_dsz; |
627 | const auto out_size |
628 | = static_cast<unsigned int>(ur) * oc_block * dst_dsz; |
629 | |
630 | max_simd_blocks = saturate(1, max_simd_blocks, |
631 | static_cast<int>((L2 - out_size) |
632 | / ((wei_per_ic + inp_per_ic) * simd_w))); |
633 | |
634 | auto simd_blocks = 1; |
635 | for (int nb_icb = nstl::min(max_simd_blocks, nb_simd); nb_icb >= 1; |
636 | nb_icb--) { |
637 | auto nb_icb_eff |
638 | = static_cast<float>(nb_simd) / rnd_up(nb_simd, nb_icb); |
639 | if (nb_icb_eff >= nb_icb_eff_threshold) { |
640 | simd_blocks = nb_icb; |
641 | break; |
642 | } |
643 | } |
644 | |
645 | ic_block = nstl::min( |
646 | (exec_type == exec_trans) ? rnd_up(ic, padded_ic) : ic, |
647 | simd_blocks * simd_w); |
648 | } |
649 | nb_ic = utils::div_up(ic, ic_block); |
650 | } |
651 | |
652 | status_t brg_blocking_t::estimate_brgemm_ur() { |
653 | // Simple simulation of brgemm_desc init |
654 | if (sp_block <= 0) return status::invalid_arguments; |
655 | LDA = is_rtus |
656 | ? (ic_block) |
657 | : (kh_sets > 1 ? kh_sets : 1) * (kw_sets > 1 ? kw_sets : stride_w) |
658 | * (exec_type == exec_trans ? ic_block |
659 | : ngroups * ic_without_padding); |
660 | LDB = oc_block; |
661 | LDC = use_buffer ? oc_block : oc_without_padding; |
662 | |
    // Configure matrix sizes.
    // For amx, if ic_block != ic, then exec_trans is used, so K is ic_block.
665 | const auto padded_ic = last_ic_block_size * (is_ic_padded ? acc_simd_w : 1); |
666 | |
667 | icp = rnd_up(ic, padded_ic); |
668 | M = brgM = sp >= sp_block ? sp_block : 0; |
669 | M_tail = brgM_tail = sp % sp_block; |
670 | if (is_os_blocking) { |
671 | if (!is_1x1) M_tail = brgM_tail = (oh * ow) % sp_block; |
672 | oskip = ((ext_kw - 1) / stride_w) * stride_h + (stride_h - 1) * ow; |
673 | |
674 | brgM = sp_block + oskip * (div_up(M, ow) - 1); |
675 | |
        // round up brgM to help the brgemm kernel use max amx_h as the
        // brgemm bd_block
677 | if (use_M_mask == 2) { |
678 | int ibrgM = 0; |
679 | const auto adj_ow = ow_block + oskip; |
680 | while (ibrgM < brgM) { |
681 | if (ibrgM % adj_ow < ow_block) |
682 | ibrgM += amx_h; |
683 | else |
684 | ibrgM++; |
685 | } |
686 | brgM = ibrgM; |
687 | } else |
688 | brgM = rnd_up(brgM, amx_h); |
689 | |
690 | brgM_tail = brgM; |
691 | } |
692 | |
693 | N = oc >= oc_block ? oc_block : 0; |
694 | N_tail = oc % oc_block; |
695 | |
696 | K = kh_sets * kw_sets * (ic >= ic_block ? ic_block : 0); |
697 | K_tail = kh_sets * kw_sets |
698 | * (exec_type == exec_trans && (!is_bf32) |
699 | ? ic_block |
700 | : rnd_up(ic % ic_block, last_ic_block_size)); |
701 | |
702 | const auto vK = K > 0 ? K : K_tail; |
703 | const auto vM = M > 0 ? M : M_tail; |
704 | const auto vN = N > 0 ? N : N_tail; |
705 | |
706 | const float alpha = 1.0; |
707 | const float beta = 0.0; |
708 | brgemm_t brg; |
709 | brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, |
710 | brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr, |
711 | is_bf32); |
712 | CHECK(brgemm_utils::brgemm_blocking(&brg)); |
713 | ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1); |
714 | ur_block = brg.bd_block; |
715 | if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) { |
716 | brgemm_t brg_sp_tail; |
717 | brgemm_utils::init_brgemm_conf(&brg_sp_tail, isa, brgemm_addr, src_dt, |
718 | wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, M_tail, |
719 | vN, vK, nullptr, is_bf32); |
720 | CHECK(brgemm_utils::brgemm_blocking(&brg_sp_tail)); |
721 | ur_block_tail = brg_sp_tail.bd_block; |
722 | } else { |
723 | ur_block_tail = 0; |
724 | } |
725 | return status::success; |
726 | } |
727 | |
728 | status_t brg_blocking_t::get_brgemm_ur( |
729 | const primitive_attr_t *attr, const memory_desc_t &dst_md) { |
730 | // Detailed simulation of brgemm convolution init |
731 | if (sp_block <= 0 || ic_block <= 0 || oc_block <= 0) |
732 | return status::invalid_arguments; |
733 | CHECK(estimate_brgemm_ur()); |
734 | |
735 | LDD = oc_without_padding; |
736 | |
737 | const float alpha = 1.0; |
738 | const float beta = 1.0; |
739 | const float beta_init = 0.0; |
740 | |
741 | for (int i = 0; i < M; i++) { |
742 | auto vM = i + 1; |
        // initialize only the brgemm descriptors that are actually needed
744 | if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1) |
745 | && vM != M && vM != M_tail) |
746 | continue; |
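        // iterate over the 2x2x2 variants:
        // {beta, beta_init} x {N, N_tail} x {K, K_tail}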
747 | for (int i_init = 0; i_init < 2; i_init++) { |
748 | for (int i_N = 0; i_N < 2; i_N++) { |
749 | for (int i_K = 0; i_K < 2; i_K++) { |
750 | auto vbeta = (i_init) ? beta_init : beta; |
751 | auto vN = (i_N) ? N_tail : N; |
752 | auto vK = (i_K) ? K_tail : K; |
753 | if (vN == 0 || vK == 0) continue; |
754 | brgemm_t brg; |
755 | brgemm_strides_t brg_strides; |
756 | brg_strides.stride_a = ngroups * ic_without_padding |
757 | * (dilate_w + 1) * src_dsz; |
                    // weights are padded by oc_block and last_ic_block
759 | brg_strides.stride_b = rnd_up(ic, last_ic_block_size) |
760 | * rnd_up(oc, oc_block) * wei_dsz; |
761 | const auto strides_ptr = (brg_type == brgemm_strd) |
762 | ? &brg_strides |
763 | : nullptr; |
764 | brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt, |
765 | wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, |
766 | LDC, vM, vN, vK, strides_ptr, is_bf32); |
767 | CHECK(brgemm_utils::brgemm_blocking(&brg)); |
768 | |
769 | brgemm_attr_t brgattr; |
770 | brgattr.max_bs = max_batch; |
771 | const auto max_vpad = (exec_type == exec_vpad) |
772 | ? nstl::max(l_pad, r_pad) |
773 | : 0; |
774 | brgattr.max_top_vpad = max_vpad; |
775 | brgattr.max_bottom_vpad = max_vpad; |
776 | brgattr.fpmath_mode = attr->fpmath_mode_; |
777 | CHECK(brgemm_desc_set_attr(&brg, brgattr)); |
778 | |
779 | brg.with_sum = with_sum; |
780 | CHECK(brgemm_desc_set_postops( |
781 | &brg, attr, &dst_md, LDD, bia_dt)); |
782 | } |
783 | } |
784 | } |
785 | } |
786 | |
787 | return status::success; |
788 | } |
789 | |
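// Recomputes the derived block counts (nb_*) and the spatial blocking
// (sp, sp_block, nb_sp) after any of the base blocks change.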
790 | void brg_blocking_t::update_blocks() { |
791 | if (sp_block <= 0 |
792 | || utils::one_of(0, od_block, oh_block, ic_block, oc_block, |
793 | kd_block, kh_block, kw_block, os_block, ow_block)) |
794 | return; |
795 | |
796 | nb_od = div_up(od, od_block); |
797 | nb_oh = div_up(oh, oh_block); |
798 | nb_ic = div_up(ic, ic_block); |
799 | nb_oc = div_up(oc, oc_block); |
800 | nb_kd = div_up(kd, kd_block); |
801 | nb_kh = div_up(kh, kh_block); |
802 | nb_kw = div_up(kw, kw_block); |
803 | nb_ow = div_up(ow, ow_block); |
804 | if (is_os_blocking) { |
805 | nb_os = div_up(os, os_block); |
806 | sp = os; |
807 | sp_block = os_block; |
808 | nb_sp = nb_os; |
809 | } else { |
810 | sp = ow; |
811 | sp_block = ow_block; |
812 | nb_sp = nb_ow; |
813 | iw_block = get_inp_size(iwp, ow_block, kw, stride_w, dilate_w); |
814 | } |
815 | } |
816 | |
817 | bool brg_blocking_t::fast_check_oc_block() const { |
    // This function reduces the number of blocking variants.
    // TODO: eliminate the heuristic in this function
820 | const auto rnd_oc = rnd_up(oc, acc_simd_w); |
821 | auto res = false; |
822 | if (oc_block == 64) { |
823 | res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz < 192 * 4); |
824 | } else if (oc_block == 48) { |
825 | const bool big_spatial |
826 | = id * ih * iw > 81 * stride_d * stride_h * stride_w; |
827 | res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz <= 384 * 4 |
828 | && big_spatial); |
829 | } else |
830 | res = true; |
831 | |
832 | return res; |
833 | } |
834 | |
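// Estimates the relative efficiency of the current blocking as a product of
// partial efficiencies (brgemm microkernel, ur, spatial, job/thread balance,
// oc blocking) and a cache efficiency derived from the loop nest model below.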
835 | float brg_blocking_t::est_eff() { |
836 | const auto ocblock = oc_block / acc_simd_w; |
837 | |
838 | const auto brgemm_microkernel_eff |
839 | = (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs); |
840 | |
841 | const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur); |
842 | const auto brgemm_eff = squeeze_val(ur |
843 | * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block)) |
844 | / 64, |
845 | 0.5f); |
846 | |
847 | const auto sp_amount = nb_od * nb_oh * nb_sp; |
848 | const auto work_amount = mb * ngroups * nb_oc * sp_amount; |
849 | const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block)); |
850 | |
851 | const auto thr_eff = static_cast<float>(work_amount) |
852 | / utils::rnd_up(work_amount, nthr); |
853 | |
854 | const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block); |
855 | |
856 | const auto job = div_up(work_amount, nthr); |
857 | |
858 | auto job_eff = 1.f; |
859 | if (job < nthr) { |
860 | std::vector<dim_t> thr_jobs(nthr); |
861 | |
862 | for (int ithr = 0; ithr < nthr; ithr++) { |
863 | thr_jobs[ithr] = 0; |
864 | if (ithr >= work_amount) continue; |
865 | dim_t thr_job = 0; |
866 | int start {0}, end {0}; |
867 | balance211(work_amount, nthr, ithr, start, end); |
868 | int n {0}, g {0}, ocb {0}, odp {0}, ohp {0}, spb {0}; |
869 | if (loop_order == loop_ndhwgc) |
870 | nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, g, |
871 | ngroups, ocb, nb_oc); |
872 | else if (loop_order == loop_ngcdhw) |
873 | nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp, od, |
874 | ohp, oh, spb, nb_sp); |
875 | |
876 | for (auto work = start; work < end; work++) { |
877 | const int ocp = ocb * oc_block; |
878 | const auto oc_sz = nstl::min(oc - ocp, oc_block); |
879 | int sp_sz = 0; |
880 | const int spp = spb * sp_block; |
881 | sp_sz = nstl::min(sp - spp, sp_block); |
882 | thr_job += sp_sz * oc_sz; |
883 | |
884 | if (loop_order == loop_ndhwgc) |
885 | nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g, |
886 | ngroups, ocb, nb_oc); |
887 | else if (loop_order == loop_ngcdhw) |
888 | nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od, |
889 | ohp, oh, spb, nb_sp); |
890 | } |
891 | thr_jobs[ithr] = thr_job; |
892 | } |
893 | |
894 | dim_t max_job = 0; |
895 | dim_t sum_job = 0; |
896 | for (int ithr = 0; ithr < nthr; ithr++) { |
897 | if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr]; |
898 | sum_job += thr_jobs[ithr]; |
899 | } |
900 | job_eff = max_job == 0 ? 1 |
901 | : static_cast<float>(sum_job) / (max_job * nthr); |
902 | |
903 | } else { |
904 | job_eff = thr_eff; |
905 | } |
906 | |
907 | const auto ic_blocking_size = ic_block * nb_ic_blocking; |
908 | const auto oc_blocking_size = oc_block * ic_blocking_size; |
909 | |
910 | int l = -1; |
911 | |
912 | // -- brgemm kernel: loop by simd_w -- |
913 | l++; |
914 | const auto inp_ur = inp_w(ur, kw_block); |
915 | loop[l].src.set(inp_ur * simd_w, 1, bcast_simd); |
916 | loop[l].dst.set(0, 1); |
917 | loop[l].wei.set(oc_block, 1); |
918 | |
919 | // -- brgemm kernel: loop by kw in kw_block -- |
920 | l++; |
921 | auto src_is = rnd_inp_simd(ur, kw_block, ic_blocking_size); |
922 | loop[l].src.set(src_is, 1, kw_block); |
923 | loop[l].dst.set(0, 1); |
924 | loop[l].wei.set(oc_blocking_size, 1); |
925 | |
926 | // -- brgemm kernel: loop by batch (grouped by kw_block) in ur -- |
927 | l++; |
928 | loop[l].src.set(src_is, 1); |
929 | loop[l].dst.set(0, 1); |
930 | auto wei_is = kw_block * oc_blocking_size; |
931 | loop[l].wei.set(wei_is, 1); |
932 | // -- brgemm kernel: loop by ur in sp_block -- |
933 | l++; |
934 | const auto nb_ur = div_up(sp_block, ur); |
935 | loop[l].src.set(kd_block * kh_block * src_is, 1); |
936 | loop[l].dst.set(ur * oc_block, 1); |
937 | wei_is = kd_block * kh_block * kw_block * oc_blocking_size; |
938 | loop[l].wei.set(wei_is, nb_ur); |
939 | |
940 | // -- harness: loop by k_blocks in ks -- |
941 | l++; |
942 | loop[l].src.set(kd_block * kh_block |
943 | * rnd_inp_simd(sp_block, kw_block, ic_blocking_size), |
944 | 1); |
945 | loop[l].dst.set(sp_block * oc_block, nb_kd * nb_kh * nb_kw); |
946 | loop[l].wei.set(wei_is, 1); |
947 | |
948 | // -- brgemm kernel: loop by ic_chunks -- |
949 | l++; |
950 | const auto ic_chunks = div_up(nb_ic, nb_ic_blocking); |
951 | loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, ic_blocking_size), 1); |
952 | loop[l].dst.set(sp_block * oc_block, ic_chunks); |
953 | wei_is = kd * kh * kw * oc_blocking_size; |
954 | loop[l].wei.set(wei_is, 1); |
955 | |
956 | const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount; |
957 | const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc)); |
958 | const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block); |
959 | const auto nsimd_oc_thr = div_up(oc_thr, simd_w); |
960 | |
961 | const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1; |
962 | const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp)); |
963 | const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block); |
964 | |
965 | int nb_oh_thr {1}, oh_thr {1}, nb_od_thr {1}, od_thr {1}; |
966 | if (!is_os_blocking) { |
967 | const auto dim_oh = nb_sp * dim_sp; |
968 | nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh)); |
969 | oh_thr = nstl::min(oh, nb_oh_thr * oh_block); |
970 | |
971 | const auto dim_od = nb_oh * dim_oh; |
972 | nb_od_thr = nstl::min(nb_od, div_up(job, dim_od)); |
973 | od_thr = nstl::min(od, nb_od_thr * od_block); |
974 | } |
975 | |
976 | src_is = kd * kh * rnd_inp_simd(sp_block, kw, ic); |
977 | |
978 | auto wei_op = kd * kh * kw * ocblock * ic; |
979 | if (loop_order == loop_ndhwgc) { |
980 | // -- harness: loop by oc_block -- |
981 | l++; |
982 | loop[l].src.set(src_is, nb_oc_thr); |
983 | loop[l].dst.set(sp_block * oc_block, 1); |
984 | wei_is = kd * kh * kw * oc_block * ic; |
985 | wei_op = kd * kh * kw * nsimd_oc_thr * ic; |
986 | loop[l].wei.set(wei_is, 1); |
987 | } |
988 | |
989 | // -- harness: loop by sp_blocks -- |
990 | l++; |
991 | loop[l].src.set(src_is, 1); |
992 | const auto rnd_oc_for_sp |
993 | = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock); |
994 | loop[l].dst.set(sp_block * rnd_oc_for_sp, 1); |
995 | loop[l].wei.set(wei_op * simd_w, nb_sp_thr); |
    // oh_block is almost always 1. TODO: handle oh_block != 1
997 | // -- harness: loop by oh_blocks -- |
998 | l++; |
999 | src_is = kd * kh * rnd_inp_simd(sp_thr, kw, ic); |
1000 | loop[l].src.set(oh_block * src_is, 1); |
1001 | loop[l].dst.set(sp_thr * rnd_oc_for_sp, 1); |
1002 | loop[l].wei.set(wei_op * simd_w, nb_oh_thr); |
    // od_block is almost always 1. TODO: handle od_block != 1
1004 | // -- harness: loop by od_blocks -- |
1005 | l++; |
1006 | loop[l].src.set(od_block * oh_thr * src_is, 1); |
1007 | loop[l].dst.set(oh_thr * sp_thr * rnd_oc_for_sp, 1); |
1008 | loop[l].wei.set(wei_op * simd_w, nb_od_thr); |
1009 | |
1010 | if (loop_order != loop_ndhwgc) { |
1011 | // -- harness: loop by oc_block -- |
1012 | l++; |
1013 | loop[l].src.set(od_thr * oh_thr * src_is, nb_oc_thr); |
1014 | loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1); |
1015 | loop[l].wei.set(kd * kh * kw * oc_block * ic, 1); |
1016 | } |
1017 | |
1018 | // -- harness: loop by mb -- |
1019 | l++; |
1020 | const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc)); |
1021 | loop[l].src.set(od_thr * oh_thr * src_is, 1); |
1022 | loop[l].dst.set(od_thr * oh_thr * sp_thr * nsimd_oc_thr * simd_w, 1); |
1023 | loop[l].wei.set(kd * kh * kw * nsimd_oc_thr * simd_w * ic, mb_thr); |
1024 | |
1025 | const auto src_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr |
1026 | * kd * kh * kw * ic; |
1027 | const auto dst_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr |
1028 | * nsimd_oc_thr; |
1029 | wei_op = kd * kh * kw * nsimd_oc_thr * ic; |
1030 | |
1031 | // for "real" application set bench_iterations to 1 |
1032 | const auto iterations = bench_iterations; |
1033 | l++; |
1034 | loop[l].src.set(src_op, iterations); |
1035 | loop[l].dst.set(dst_op * simd_w, iterations); |
1036 | loop[l].wei.set(wei_op * simd_w, iterations); |
1037 | |
1038 | auto src_mem_k = mem_k; |
1039 | auto dst_mem_k = mem_k; |
1040 | auto wei_mem_k = mem_k; |
1041 | float src_rp = 1; |
1042 | float dst_rp = 1; |
1043 | float wei_rp = 1; |
1044 | |
1045 | for (auto il = l; il >= 0; il--) { |
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true,
                loop_order != loop_ndhwgc);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false,
                loop_order == loop_ndhwgc);
1051 | src_rp *= loop[il].src.repeatn; |
1052 | dst_rp *= loop[il].dst.repeatn; |
1053 | wei_rp *= loop[il].wei.repeatn; |
1054 | } |
1055 | const auto src_ops = (src_op * src_rp) / iterations; |
1056 | const auto dst_ops = (dst_op * dst_rp) / iterations; |
1057 | const auto wei_ops = (wei_op * wei_rp) / iterations; |
1058 | |
1059 | const auto src_cost = src_mem_k * src_ops; |
1060 | const auto dst_cost = dst_mem_k * dst_ops; |
1061 | const auto wei_cost = wei_mem_k * wei_ops; |
1062 | const auto call_kernel_cost |
1063 | = 1000.f * job * ic_chunks * nb_kd * nb_kh * nb_kw; |
1064 | |
1065 | const auto cache_eff = (static_cast<dim_t>(mb) * od * oh * sp * ic * oc) |
1066 | / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost)); |
1067 | const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff |
1068 | * job_eff * ur_eff * cache_eff * brgemm_eff; |
1069 | return res_eff; |
1070 | } |
1071 | |
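// For the given kd/kh blocks, selects the kw/od/oh blocking, then sweeps the
// possible ow (sp) blocks and keeps the variant with the best est_eff() in
// best_brgb.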
1072 | void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_, |
1073 | int kh_block_, bool maybe_use_buffer, int max_ow_block_thr) { |
1074 | |
1075 | unsigned est_k_amount = ic * oc_block * wei_dsz; |
1076 | |
1077 | kd_block = kd_block_; |
1078 | kh_block = kh_block_; |
1079 | if (one_of(exec_type, exec_vpad, exec_trans)) { |
1080 | kw_block = kw; |
1081 | kd_block_pad = kd_block; |
1082 | kh_block_pad = kh_block; |
1083 | kw_block_pad = kw_block; |
1084 | } else { |
1085 | kw_block = (est_k_amount * kw < L2) ? kw : 1; |
1086 | kd_block_pad = kh_block >= kd ? kd : 1; |
1087 | kh_block_pad = kw_block >= kh ? kh : 1; |
1088 | kw_block_pad = kw; |
1089 | } |
1090 | |
1091 | if (exec_type == exec_vpad) { |
1092 | od_block = 1; |
1093 | oh_block = 1; |
1094 | } else if (exec_type == exec_trans) { |
1095 | const auto w_block_size |
1096 | = 2 * src_dsz * ic * iwp + dst_dsz * ow * oc_block; |
1097 | const auto other_size = wei_dsz * kd * kh * kw * ic * oc_block |
1098 | + acc_dsz * 2 * amx_h * oc_block; |
1099 | const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)), |
1100 | other_size > L2 ? 0 : L2 - other_size); |
1101 | if (idp * ihp * w_block_size > L2_available) { |
1102 | od_block = utils::saturate( |
1103 | 1, od, int(L2_available / (ihp * w_block_size))); |
1104 | if (od_block == 1) |
1105 | oh_block = utils::saturate( |
1106 | 1, oh, int(L2_available / (w_block_size))); |
1107 | else |
1108 | oh_block = oh; |
1109 | } else { |
1110 | od_block = 1; |
1111 | oh_block = oh; |
1112 | } |
1113 | if (is_amx(isa)) { |
1114 | // try to fit into L1 |
1115 | bool L1_fit_res = false; |
1116 | auto cur_od_block = od_block; |
1117 | auto cur_oh_block = oh_block; |
1118 | const auto src_w_block_size |
1119 | = src_dsz * ic * iwp + dst_dsz * ow * oc_block; |
1120 | if (src_w_block_size < L1) { |
1121 | cur_od_block = utils::saturate( |
1122 | 1, od, int(L1 / (ihp * src_w_block_size))); |
1123 | if (cur_od_block == 1) |
1124 | cur_oh_block = utils::saturate( |
1125 | 1, oh, int(L1 / (src_w_block_size))); |
1126 | } |
1127 | for (; cur_od_block > 1; cur_od_block--) { |
1128 | const auto sp_size = cur_od_block * cur_oh_block * iwp; |
1129 | if ((static_cast<float>(od) / rnd_up(od, cur_od_block)) > 0.9f |
1130 | && static_cast<float>(sp_size) / rnd_up(sp, amx_h) |
1131 | > 0.8f) { |
1132 | L1_fit_res = true; |
1133 | break; |
1134 | } |
1135 | } |
1136 | if (cur_od_block == 1) { |
1137 | for (; cur_oh_block > 1; cur_oh_block--) { |
1138 | const auto sp_size = cur_oh_block * iwp; |
1139 | if ((static_cast<float>(oh) / rnd_up(oh, cur_oh_block)) |
1140 | > 0.9f |
1141 | && sp_size > 128) { |
1142 | L1_fit_res = true; |
1143 | break; |
1144 | } |
1145 | } |
1146 | } |
1147 | if (L1_fit_res) { |
1148 | od_block = cur_od_block; |
1149 | oh_block = cur_oh_block; |
1150 | } |
1151 | } |
1152 | |
1153 | // limit oh_block to have good threading |
1154 | const auto thr_oc_block = div_up( |
1155 | nthr, mb * div_up((oc > 32 ? ngroups : 1) * oc, oc_block)); |
1156 | const auto thr_od_block = div_up(od, thr_oc_block); |
1157 | const auto thr_oh_block |
1158 | = div_up(oh, thr_oc_block * div_up(od, thr_od_block)); |
1159 | od_block = nstl::min(od_block, thr_od_block); |
1160 | oh_block = nstl::min(oh_block, thr_oh_block); |
1161 | } else { |
1162 | od_block = 1; |
1163 | oh_block = 1; |
1164 | } |
1165 | |
    // --- Select ow_block ---
1167 | const auto max_ow_block_L2 = ow; |
1168 | auto start_ow_block = nstl::min(max_ow_block_thr, max_ow_block_L2); |
1169 | |
1170 | sp = ow; |
1171 | const auto start_sp_block = is_os_blocking ? ow : start_ow_block; |
1172 | auto prev_spb = 0; |
1173 | for (auto ns = 1; ns <= sp; ns++) { |
1174 | const auto spb = div_up(sp, ns); |
1175 | if (spb == prev_spb || spb > start_sp_block) continue; |
1176 | if (is_os_blocking && spb != ow) continue; |
1177 | prev_spb = spb; |
1178 | ow_block = spb; |
1179 | sp_block = ow_block; |
1180 | |
1181 | select_ic_block(); |
1182 | |
1183 | use_buffer = maybe_use_buffer |
1184 | && (ic_block * nb_ic_blocking < ic || kd_block != kd |
1185 | || kh_block != kh || kw_block != kw |
1186 | || kd_block_pad != kd || kh_block_pad != kh |
1187 | || kw_block_pad != kw); |
1188 | if (exec_type == exec_base) |
1189 | use_buffer = use_buffer || (maybe_use_buffer && iwp != iw); |
1190 | |
1191 | const status_t st = estimate_brgemm_ur(); |
1192 | if (st != status::success) continue; |
1193 | os_block = sp_block = ow_block; |
1194 | update_blocks(); |
1195 | |
1196 | eff = est_eff(); |
1197 | |
1198 | if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this; |
1199 | } |
1200 | } |
1201 | |
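// Top-level blocking search for the generic convolution: tries full and unit
// kd/kh blocks and delegates the remaining choices to iterate_ker_block().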
1202 | status_t brg_blocking_t::calc_blocks() { |
1203 | sp = ow; |
1204 | |
1205 | nb_ic_blocking = 1; |
1206 | // --- Select kernel blocking --- |
    // The output buffer is needed if dst_dt != acc_dt, or if intermediate
    // results have to be stored for the sum post-op.
1209 | const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum); |
1210 | |
1211 | std::vector<int> kd_blocks(1), kh_blocks(1); |
1212 | kd_blocks[0] = kd; |
1213 | kh_blocks[0] = kh; |
1214 | if (kd != 1) { |
1215 | kd_blocks.resize(2); |
1216 | kd_blocks[1] = 1; |
1217 | } |
1218 | if (kh != 1) { |
1219 | kh_blocks.resize(2); |
1220 | kh_blocks[1] = 1; |
1221 | } |
1222 | |
1223 | const auto thr_eff_threshold = 0.9f; |
1224 | const auto max_ow_block_thr = utils::saturate(1, ow, |
1225 | static_cast<int>(div_up( |
1226 | mb * ngroups * nb_oc * os, thr_eff_threshold * nthr))); |
1227 | |
1228 | ow_block = os_block = sp_block = -1; |
1229 | brg_blocking_t best_brgb = *this; |
1230 | for (const auto &kd_block : kd_blocks) { |
1231 | for (const auto &kh_block : kh_blocks) { |
1232 | iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer, |
1233 | max_ow_block_thr); |
1234 | } |
1235 | } |
1236 | *this = best_brgb; |
1237 | if (!IMPLICATION(!is_os_blocking, sp_block > 0)) |
1238 | return status::unimplemented; |
1239 | |
1240 | if (is_os_blocking) { |
1241 | ow_block = ow; |
1242 | os_block = ow * oh_block; |
1243 | sp_block = os_block; |
1244 | ow_tail = 0; |
1245 | } else { |
1246 | ow_block = os_block = sp_block; |
1247 | ow_tail = ow % ow_block; |
1248 | } |
1249 | update_blocks(); |
1250 | return status::success; |
1251 | } |
1252 | |
1253 | bool brg_blocking_t::fast_check_oc_block_1x1() const { |
    // This function reduces the number of blocking variants.
    // TODO: eliminate the heuristic in this function
1256 | if (is_1x1 && is_amx(isa)) return true; |
1257 | const auto rnd_oc = rnd_up(oc, acc_simd_w); |
1258 | auto res = false; |
1259 | if (oc_block == 64) { |
1260 | const auto big_spatial |
1261 | = od * oh * ow >= 64 * stride_d * stride_h * stride_w; |
1262 | res = (rnd_oc % oc_block == 0 && big_spatial); |
1263 | } else if (oc_block == 48) { |
1264 | const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block); |
1265 | res = (oc_block_eff >= 0.95f); |
1266 | } else |
1267 | res = true; |
1268 | |
1269 | return res; |
1270 | } |
1271 | |
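// 1x1 variant of est_eff(): same overall structure, but on AMX the
// microkernel efficiency additionally models the 2x2 tile blocking.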
1272 | float brg_blocking_t::est_eff_1x1() { |
1273 | const auto ocblock = oc_block / acc_simd_w; |
1274 | |
1275 | auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float { |
1276 | const int nb = dim / block; |
1277 | constexpr int max_nb = 2; // only consider 2x2 tile blocking |
1278 | const int block2 = nstl::min(max_nb, nb); |
1279 | const int nb2 = nb / block2; |
1280 | const int nb2_tail = nb % block2; |
1281 | if (!use_ave) return block2; |
1282 | return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2); |
1283 | }; |
1284 | const bool use_ocb_ave = true; |
1285 | const auto ocb_ave = calc_ave_blk(oc_block, acc_simd_w, use_ocb_ave); |
1286 | const bool use_spb_ave = false; |
1287 | const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave); |
1288 | const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0; |
1289 | const auto M_tail_n_sp_blks |
1290 | = ur_block_tail > 0 ? M_tail / ur_block_tail : 0; |
1291 | |
1292 | // heuristic for maskrcnn workaround: use old blocking for some convolutions |
1293 | // TODO: remove this condition |
1294 | const bool maskrcnn_cond = (ic == 1024 && oc == 2048) |
1295 | || (ic == 1024 && oc == 512) || (ic == 256 && oc == 1024) |
1296 | || (ic == 512 && oc == 1024) || (ic == 512 && oc == 2048); |
1297 | const auto amx_fac = maskrcnn_cond |
1298 | ? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks)) |
1299 | : (static_cast<float>(div_up(M + M_tail, 16)) |
1300 | / (M_n_sp_blks + M_tail_n_sp_blks)); |
1301 | |
1302 | const auto brgemm_microkernel_eff = is_amx(isa) |
1303 | ? amx_fac * (static_cast<float>(ocb_ave) * spb_ave) |
1304 | / (ocb_ave + spb_ave) |
1305 | : (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs); |
1306 | const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur); |
1307 | const auto brgemm_eff = squeeze_val(ur |
1308 | * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block)) |
1309 | / 64, |
1310 | 0.5f); |
1311 | |
1312 | const auto sp_amount = is_os_blocking ? div_up(nb_os, nb_os_blocking) |
1313 | : nb_od * nb_oh * nb_sp; |
1314 | const auto work_amount = mb * ngroups * nb_oc * sp_amount; |
1315 | |
1316 | const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block); |
1317 | const auto thr_eff = static_cast<float>(work_amount) |
1318 | / utils::rnd_up(work_amount, nthr); |
1319 | const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block); |
1320 | |
1321 | const auto job = div_up(work_amount, nthr); |
1322 | |
1323 | const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount; |
1324 | const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc)); |
1325 | const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block); |
1326 | const auto nsimd_oc_thr = div_up(oc_thr, simd_w); |
1327 | |
1328 | const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1; |
1329 | const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp)); |
1330 | const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block); |
1331 | |
1332 | int nb_oh_thr {1}, oh_thr {1}, nb_od_thr {1}, od_thr {1}; |
1333 | if (!is_os_blocking) { |
1334 | const auto dim_oh = nb_sp * dim_sp; |
1335 | nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh)); |
1336 | oh_thr = nstl::min(oh, nb_oh_thr * oh_block); |
1337 | |
1338 | const auto dim_od = nb_oh * dim_oh; |
1339 | nb_od_thr = nstl::min(nb_od, div_up(job, dim_od)); |
1340 | od_thr = nstl::min(od, nb_od_thr * od_block); |
1341 | } |
1342 | |
1343 | auto job_eff = 1.f; |
1344 | if (job < nthr) { |
1345 | std::vector<dim_t> thr_jobs(nthr); |
1346 | for (int ithr = 0; ithr < nthr; ithr++) { |
1347 | thr_jobs[ithr] = 0; |
1348 | if (ithr >= work_amount) continue; |
1349 | dim_t thr_job = 0; |
1350 | int start {0}, end {0}; |
1351 | balance211(work_amount, nthr, ithr, start, end); |
1352 | int n {0}, g {0}, ocb {0}, oss {0}, odp {0}, ohp {0}, spb {0}; |
1353 | if (loop_order == loop_ndhwgc) { |
1354 | if (is_os_blocking) |
1355 | nd_iterator_init(start, n, mb, oss, sp_amount, g, ngroups, |
1356 | ocb, nb_oc); |
1357 | else |
1358 | nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, |
1359 | g, ngroups, ocb, nb_oc); |
1360 | } else if (loop_order == loop_ngcdhw) { |
1361 | if (is_os_blocking) |
1362 | nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, oss, |
1363 | sp_amount); |
1364 | else |
1365 | nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp, |
1366 | od, ohp, oh, spb, nb_sp); |
1367 | } |
1368 | |
1369 | for (auto work = start; work < end; work++) { |
1370 | const int ocp = ocb * oc_block; |
1371 | const auto oc_sz = nstl::min(oc - ocp, oc_block); |
1372 | int sp_sz = 0; |
1373 | if (is_os_blocking) { |
1374 | const auto osb_start = oss * nb_os_blocking; |
1375 | const auto osb_range |
1376 | = nstl::min(nb_os - osb_start, nb_os_blocking); |
1377 | for (int osb = 0; osb < osb_range; osb++) { |
1378 | const int osp = (osb_start + osb) * sp_block; |
1379 | sp_sz = nstl::min(os - osp, sp_block); |
1380 | } |
1381 | } else { |
1382 | const int spp = spb * sp_block; |
1383 | sp_sz = nstl::min(sp - spp, sp_block); |
1384 | } |
1385 | thr_job += sp_sz * oc_sz; |
1386 | |
1387 | if (loop_order == loop_ndhwgc) { |
1388 | if (is_os_blocking) |
1389 | nd_iterator_step( |
1390 | n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc); |
1391 | else |
1392 | nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g, |
1393 | ngroups, ocb, nb_oc); |
1394 | } else if (loop_order == loop_ngcdhw) { |
1395 | if (is_os_blocking) |
1396 | nd_iterator_step( |
1397 | n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount); |
1398 | else |
1399 | nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od, |
1400 | ohp, oh, spb, nb_sp); |
1401 | } |
1402 | } |
1403 | thr_jobs[ithr] = thr_job; |
1404 | } |
1405 | |
1406 | dim_t max_job = 0; |
1407 | dim_t sum_job = 0; |
1408 | for (int ithr = 0; ithr < nthr; ithr++) { |
1409 | if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr]; |
1410 | sum_job += thr_jobs[ithr]; |
1411 | } |
1412 | |
1413 | job_eff = max_job == 0 ? 1 |
1414 | : static_cast<float>(sum_job) / (max_job * nthr); |
1415 | } else { |
1416 | job_eff = thr_eff; |
1417 | } |
1418 | |
1419 | const auto ic_blocking_size = ic_block * nb_ic_blocking; |
1420 | const auto oc_blocking_size = oc_block * ic_blocking_size; |
1421 | |
1422 | int l = -1; |
1423 | // -- brgemm kernel: loop by simd_w -- |
1424 | l++; |
1425 | loop[l].src.set(ur * simd_w, 1, bcast_simd); |
1426 | loop[l].dst.set(0, 1); |
1427 | loop[l].wei.set(oc_block, 1); |
1428 | |
1429 | // -- brgemm kernel: loop by ur in sp_block -- |
1430 | l++; |
1431 | const auto nb_ur = div_up(sp_block, ur); |
1432 | const auto nb_sp_no_tail = sp / sp_block; |
1433 | const auto sp_block_tail = sp % sp_block; |
1434 | const auto nb_ur_average |
1435 | = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp; |
1436 | loop[l].src.set(ur * rnd_simd(ic_blocking_size), 1); |
1437 | loop[l].dst.set(ur * oc_block, 1); |
1438 | loop[l].wei.set(oc_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur); |
1439 | // -- brgemm kernel: loop by ic_chunks -- |
1440 | l++; |
1441 | const auto ic_chunks = div_up(nb_ic, nb_ic_blocking); |
1442 | loop[l].src.set(sp_block * ic_blocking_size, 1); |
1443 | loop[l].dst.set(sp_block * oc_block, ic_chunks); |
1444 | auto wei_is = oc_blocking_size; |
1445 | auto wei_op = ocblock * ic; |
1446 | loop[l].wei.set(wei_is, 1); |
1447 | |
1448 | if (loop_order == loop_ndhwgc) { |
1449 | // -- harness: loop by oc_block -- |
1450 | l++; |
1451 | loop[l].src.set(sp_block * rnd_simd(ic), nb_oc_thr); |
1452 | loop[l].dst.set(sp_block * oc_block, 1); |
1453 | wei_is = oc_block * ic; |
1454 | wei_op = nsimd_oc_thr * ic; |
1455 | loop[l].wei.set(wei_is, 1); |
1456 | } |
1457 | |
1458 | const auto rnd_oc_for_sp |
1459 | = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock); |
1460 | if (is_os_blocking) { |
1461 | // -- harness: loop by os_blocks -- |
1462 | l++; |
1463 | loop[l].src.set(sp_block * ic_blocking_size, 1); |
1464 | loop[l].dst.set(sp_block * rnd_oc_for_sp, 1); |
1465 | loop[l].wei.set(wei_op * simd_w, nb_sp_thr); |
1466 | } else { |
1467 | // -- harness: loop by sp_blocks -- |
1468 | l++; |
1469 | loop[l].src.set(sp_block * ic_blocking_size, 1); |
1470 | loop[l].dst.set(sp_block * rnd_oc_for_sp, 1); |
1471 | loop[l].wei.set(wei_op * simd_w, nb_sp_thr); |
1472 | // -- harness: loop by oh_blocks -- |
1473 | l++; |
1474 | loop[l].src.set(oh_block * sp_thr * rnd_simd(ic_blocking_size), 1); |
1475 | loop[l].dst.set(oh_block * sp_thr * rnd_oc_for_sp, 1); |
1476 | loop[l].wei.set(wei_op * simd_w, nb_oh_thr); |
1477 | // -- harness: loop by od_blocks -- |
1478 | l++; |
1479 | loop[l].src.set( |
1480 | od_block * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1); |
1481 | loop[l].dst.set(od_block * oh_thr * sp_thr * rnd_oc_for_sp, 1); |
1482 | loop[l].wei.set(wei_op * simd_w, nb_od_thr); |
1483 | } |
1484 | |
1485 | if (loop_order != loop_ndhwgc) { |
1486 | // -- harness: loop by oc_block -- |
1487 | l++; |
1488 | loop[l].src.set(od_thr * oh_thr * rnd_simd(sp_thr * ic_blocking_size), |
1489 | nb_oc_thr); |
1490 | loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1); |
1491 | loop[l].wei.set(oc_block * ic, 1); |
1492 | } |
1493 | |
1494 | // -- harness: loop by mb -- |
1495 | l++; |
1496 | const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc)); |
1497 | loop[l].src.set(od_thr * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1); |
1498 | loop[l].dst.set(nsimd_oc_thr * simd_w * od_thr * oh_thr * sp_thr, 1); |
1499 | loop[l].wei.set(nsimd_oc_thr * ic * simd_w, mb_thr); |
1500 | |
1501 | const auto src_op = static_cast<dim_t>(mb_thr) * od_thr * oh_thr * sp_thr |
1502 | * ic_blocking_size; |
1503 | const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_oc_thr * od_thr |
1504 | * oh_thr * sp_thr; |
1505 | wei_op = nsimd_oc_thr * ic; |
1506 | |
1507 | // for "real" application set bench_iterations to 1 |
1508 | const auto iterations = bench_iterations; |
1509 | l++; |
1510 | loop[l].src.set(src_op, iterations); |
1511 | loop[l].dst.set(dst_op * simd_w, iterations); |
1512 | loop[l].wei.set(wei_op * simd_w, iterations); |
1513 | |
1514 | auto src_mem_k = mem_k; |
1515 | auto dst_mem_k = mem_k; |
1516 | auto wei_mem_k = mem_k; |
1517 | float src_rp = 1; |
1518 | float dst_rp = 1; |
1519 | float wei_rp = 1; |
1520 | |
1521 | for (auto il = l; il >= 0; il--) { |
1522 | src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false); |
1523 | dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false); |
1524 | wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true); |
1525 | src_rp *= loop[il].src.repeatn; |
1526 | dst_rp *= loop[il].dst.repeatn; |
1527 | wei_rp *= loop[il].wei.repeatn; |
1528 | } |
1529 | const auto src_ops = (src_op * src_rp) / iterations; |
1530 | const auto dst_ops = (dst_op * dst_rp) / iterations; |
1531 | const auto wei_ops = (wei_op * wei_rp) / iterations; |
1532 | |
1533 | const auto src_cost = src_mem_k * src_ops; |
1534 | const auto dst_cost = dst_mem_k * dst_ops; |
1535 | const auto wei_cost = wei_mem_k * wei_ops; |
1536 | const auto call_kernel_cost = 1000.f * job * ic_chunks; |
1537 | |
1538 | const auto up_sp_size = is_os_blocking ? 1 : od * oh; |
1539 | |
1540 | const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * ic * oc) |
1541 | / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost)); |
1542 | |
1543 | const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff |
1544 | * job_eff * ur_eff * cache_eff * brgemm_eff; |
1545 | return res_eff; |
1546 | } |
1547 | |
1548 | void brg_blocking_t::calc_blocks_1x1() { |
1549 | const bool is_os_blocking_ok |
1550 | = utils::everyone_is(1, stride_d, stride_h) && iw % stride_w == 0; |
1551 | const bool is_ic_zero_padded = ic != ic_without_padding; |
1552 | is_rtus = is_ic_zero_padded || (!is_os_blocking_ok && is_amx(isa)); |
1553 | if (is_os_blocking_ok || is_rtus) { |
1554 | sp = os; |
1555 | is_os_blocking = true; |
1556 | } else { |
1557 | sp = ow; |
1558 | is_os_blocking = false; |
1559 | } |
1560 | |
1561 | od_block = 1; |
1562 | oh_block = 1; |
1563 | kd_block = kh_block = kw_block = 1; |
1564 | kd_block_pad = kh_block_pad = kw_block_pad = 1; |
1565 | nb_ic_blocking = 1; |
1566 | |
1567 | const auto thr_eff_threshold = 0.9f; |
1568 | |
1569 | const auto max_sp_block_L2 = os; |
    // TODO: nb_os_blocking is always 1 for now; update this code
1571 | nb_os_blocking = 1; |
1572 | int start_sp_block = 0; |
1573 | |
1574 | if (is_os_blocking) { |
1575 | ow_block = 0; |
1576 | |
1577 | const auto max_os_block_thr |
1578 | = (src_dsz * ic >= 1024 && src_dsz * ic < 4096) |
1579 | ? nstl::max(nstl::min(16, os), |
1580 | div_up(os, div_up(nthr, mb * div_up(oc, oc_block)))) |
1581 | : nstl::max(div_up(2048, oc_block), |
1582 | static_cast<int>(div_up(mb * ngroups * os, nthr))); |
1583 | const auto max_os_block_L2 = max_sp_block_L2; |
1584 | |
1585 | auto max_os_block_aliasing = 1000000 / nthr; |
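        // If the dst footprint is a multiple of the 4K page size, consecutive
        // os blocks may alias in the cache: halve the bound once per
        // power-of-two factor of oc (while the footprint stays above a page)
        // and make the result odd to break the 4K stride. The bounds here
        // look empirical.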
1586 | if ((oc_without_padding * os * dst_dsz) % P4K == 0) { |
1588 | for (auto cur_oc = oc_without_padding; |
1589 | max_os_block_aliasing * dst_dsz > 400 && cur_oc % 2 == 0 |
1590 | && cur_oc * os * dst_dsz >= P4K; |
1591 | cur_oc /= 2) { |
1592 | max_os_block_aliasing /= 2; |
1593 | } |
1594 | max_os_block_aliasing += max_os_block_aliasing % 2 ? 0 : 1; |
1595 | } |
1596 | max_os_block_aliasing |
1597 | = nstl::min(div_up(1001, dst_dsz), max_os_block_aliasing); |
1598 | |
1599 | start_sp_block = utils::saturate(1, os, |
1600 | nstl::min(nstl::min(max_os_block_thr, max_os_block_L2), |
1601 | max_os_block_aliasing)); |
1602 | |
1603 | } else { |
1604 | os_block = 0; |
1605 | |
1606 | const auto max_ow_block_thr = utils::saturate(1, ow, |
1607 | static_cast<int>(div_up( |
1608 | mb * ngroups * nb_oc * os, thr_eff_threshold * nthr))); |
1609 | const auto max_ow_block_L2 = max_sp_block_L2; |
1610 | |
1611 | start_sp_block = utils::saturate( |
1612 | 1, ow, nstl::min(max_ow_block_thr, max_ow_block_L2)); |
1613 | } |
1614 | os_block = ow_block = sp_block = -1; |
1615 | brg_blocking_t best_brgb = *this; |
1616 | |
1617 | auto prev_spb = 0; |
1618 | for (auto ns = 1; ns <= sp; ns++) { |
1619 | auto spb = div_up(sp, ns); |
1620 | if (is_amx(isa)) { |
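            // Pick the tile width w in [min_tile_width, max_tile_width] that
            // minimizes the distance from spb up to the next multiple of w
            // (assuming that is what nstl::additive_inverse_modulo computes),
            // then round spb down to a multiple of it so AMX tile rows stay
            // fully utilized.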
1621 | auto min_dis = 16; |
1622 | auto best_w = 16; |
1623 | const auto max_tile_width = nstl::min(16, sp); |
1624 | const auto min_tile_width = utils::saturate(1, 11, sp / 2); |
1625 | if (spb < min_tile_width) break; |
1626 | for (auto w = max_tile_width; w >= min_tile_width; w--) { |
1627 | const auto dis = nstl::additive_inverse_modulo(spb, w); |
1628 | if (dis < min_dis) { |
1629 | min_dis = dis; |
1630 | best_w = w; |
1631 | } |
1632 | } |
1633 | spb = nstl::min(sp, rnd_dn(spb, best_w)); |
1634 | if (spb == prev_spb) continue; |
1635 | } |
1636 | if (spb == prev_spb || spb > start_sp_block) continue; |
1637 | prev_spb = spb; |
1638 | os_block = ow_block = sp_block = spb; |
1639 | select_ic_block(); |
1640 | const status_t st = estimate_brgemm_ur(); |
1641 | if (st != status::success) continue; |
1642 | update_blocks(); |
1643 | |
1644 | use_buffer = (dst_dt != acc_dt || with_sum) |
1645 | && (ic_block * nb_ic_blocking < ic); |
1646 | |
1647 | eff = est_eff_1x1(); |
1648 | if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this; |
1649 | } |
1650 | *this = best_brgb; |
1651 | os_block = ow_block = sp_block; |
1652 | update_blocks(); |
1653 | } |
1654 | |
1655 | brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) { |
1656 | return attr.zero_points_.has_default_values(arg) |
1657 | ? brgemm_broadcast_t::none |
1658 | : brgemm_broadcast_t::per_tensor; |
1659 | } |
1660 | status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, |
1661 | const convolution_desc_t &cd, memory_desc_t &src_md, |
1662 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
1663 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
1664 | using namespace prop_kind; |
1665 | |
1666 | brg_blocking_t::L1 = platform::get_per_core_cache_size(1); |
1667 | brg_blocking_t::L2 = platform::get_per_core_cache_size(2); |
    // TODO: use the actual L3 size; approximated by the per-core L2 size here
    brg_blocking_t::L3 = platform::get_per_core_cache_size(2);
1669 | |
1670 | const memory_desc_wrapper src_d(&src_md); |
1671 | const memory_desc_wrapper weights_d(&weights_md); |
1672 | const memory_desc_wrapper dst_d(&dst_md); |
1673 | const memory_desc_wrapper bias_d(&bias_md); |
1674 | |
1675 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1676 | int ndims = src_d.ndims(); |
1677 | |
1678 | jcp = zero<decltype(jcp)>(); |
1679 | jcp.isa = isa; |
1680 | |
1681 | if (is_amx(isa)) { |
1682 | const int target_palette = amx::get_target_palette(); |
1683 | if (amx::get_max_tiles(target_palette) != 8 |
1684 | || amx::get_max_rows(target_palette) != 16) |
1685 | return status::unimplemented; |
1686 | } |
1687 | |
1688 | jcp.ndims = ndims; |
1689 | jcp.prop_kind = cd.prop_kind; |
1690 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1691 | jcp.mb = src_d.dims()[0]; |
1692 | jcp.oc_without_padding = dst_d.dims()[1]; |
1693 | jcp.oc = jcp.oc_without_padding / jcp.ngroups; |
1694 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
1695 | jcp.ic = jcp.ic_without_padding; |
1696 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
1697 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; |
1698 | jcp.iw = src_d.dims()[ndims - 1]; |
1699 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
1700 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; |
1701 | jcp.ow = dst_d.dims()[ndims - 1]; |
1702 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
1703 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1704 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1705 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
1706 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
1707 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1708 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
1709 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
1710 | jcp.stride_w = cd.strides[ndims - 3]; |
1711 | |
1712 | jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; |
1713 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4]; |
1714 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1715 | |
1716 | jcp.os = jcp.od * jcp.oh * jcp.ow; |
1717 | |
1718 | jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
1719 | jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
1720 | jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1721 | |
1722 | jcp.back_pad = calculate_end_padding( |
1723 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd); |
1724 | jcp.b_pad = calculate_end_padding( |
1725 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh); |
1726 | jcp.r_pad = calculate_end_padding( |
1727 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw); |
1728 | |
1729 | jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0 |
1730 | && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0 |
1731 | && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw); |
1732 | |
1733 | jcp.with_bias = bias_md.format_kind != format_kind::undef; |
1734 | |
1735 | jcp.src_dt = src_md.data_type; |
1736 | jcp.dst_dt = dst_md.data_type; |
1737 | jcp.wei_dt = weights_md.data_type; |
1738 | jcp.bia_dt = jcp.with_bias ? bias_md.data_type : data_type::undef; |
1739 | |
1740 | if (one_of(jcp.src_dt, u8, s8)) { |
1741 | jcp.acc_dt = s32; |
1742 | } else if (one_of(jcp.src_dt, f32, bf16, f16)) { |
1743 | jcp.acc_dt = f32; |
1744 | } else |
1745 | return status::unimplemented; |
1746 | |
1747 | jcp.src_dsz = types::data_type_size(jcp.src_dt); |
1748 | jcp.wei_dsz = types::data_type_size(jcp.wei_dt); |
1749 | jcp.dst_dsz = types::data_type_size(jcp.dst_dt); |
1750 | jcp.acc_dsz = types::data_type_size(jcp.acc_dt); |
1751 | jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0; |
1752 | |
1753 | jcp.simd_w = isa_max_vlen(isa) / jcp.src_dsz; |
1754 | jcp.acc_simd_w = isa_max_vlen(isa) / jcp.acc_dsz; |
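    // bf32: f32 src/weights computed via bf16 AMX instructions, with the
    // down-conversion done inside the brgemm kernel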
1755 | jcp.is_bf32 = everyone_is(f32, jcp.src_dt, jcp.wei_dt) |
1756 | && attr.fpmath_mode_ == fpmath_mode::bf16 && isa == avx512_core_amx; |
1757 | |
1758 | brg_blocking_t::last_ic_block_size |
1759 | = (jcp.wei_dt == f16 && isa == avx512_core_fp16) |
1760 | ? 1 |
1761 | : data_type_vnni_granularity(jcp.wei_dt); |
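    // (the vnni granularity is 2 for bf16/f16 and 4 for s8; native f16 on
    // avx512_core_fp16 needs no vnni packing, hence the block size of 1)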
1762 | |
    // TODO: optimize depthwise convolutions (for now the direct approach is faster)
1764 | const bool is_depthwise |
1765 | = with_groups && jcp.ngroups > 1 && everyone_is(1, jcp.ic, jcp.oc); |
1766 | if (is_depthwise) return status::unimplemented; |
1767 | |
1768 | // TODO: optimize grouped convolutions with small ic |
1769 | const bool is_grouped_small_ic |
1770 | = jcp.prop_kind != prop_kind::backward_weights && with_groups |
1771 | && jcp.ngroups > 1 && jcp.ic <= jcp.acc_simd_w |
1772 | && IMPLICATION(is_amx(jcp.isa), |
1773 | jcp.ic < 16 |
1774 | && jcp.oc < 16 |
1775 | // already optimized for amx 1x1 convs |
1776 | && !jcp.is_1x1) |
1777 | // Enable the shapes not supported in direct convs |
1778 | && IMPLICATION(with_groups, is_groups_ok(jcp)); |
1779 | if (is_grouped_small_ic) return status::unimplemented; |
1780 | |
1781 | // Dispatch the shapes to VNNI for better performance |
1782 | // TODO: optimize the perf of 3d shape with small ic and large spatial |
1783 | const auto max_small_shapes_sz = jcp.is_1x1 |
1784 | ? static_cast<int32_t>(brg_blocking_t::L1) / 2 |
1785 | : static_cast<int32_t>(brg_blocking_t::L1); |
1786 | const auto is_small_shape = is_amx(jcp.isa) && jcp.os <= 4 && jcp.ic <= 512 |
1787 | && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= max_small_shapes_sz; |
1788 | const auto is_3d_small_ic = is_amx(jcp.isa) && jcp.ndims == 5 |
1789 | && jcp.ic * jcp.oc <= 32 && jcp.od >= 128 && jcp.oh >= 128 |
1790 | && jcp.ow >= 128; |
1791 | if (one_of(jcp.prop_kind, prop_kind::forward_training, |
1792 | prop_kind::forward_inference) |
1793 | && (is_small_shape || is_3d_small_ic)) |
1794 | return status::unimplemented; |
1795 | |
1796 | jcp.s8s8_avx512 = jcp.src_dt == s8 && !is_amx(jcp.isa); |
1797 | |
1798 | if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni))) |
1799 | return status::unimplemented; |
1800 | if (!IMPLICATION(jcp.wei_dt == bf16, |
1801 | mayiuse(avx512_core_bf16) || mayiuse(avx2_vnni_2))) |
1802 | return status::unimplemented; |
1803 | if (!IMPLICATION(jcp.wei_dt == f16, |
1804 | mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2))) |
1805 | return status::unimplemented; |
1806 | const bool is_f32 |
1807 | = utils::everyone_is(f32, jcp.src_dt, jcp.wei_dt, jcp.dst_dt); |
1808 | if (!IMPLICATION(is_f32, one_of(isa, avx512_core, avx2) || jcp.is_bf32)) |
1809 | return status::unimplemented; |
1810 | |
1811 | if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented; |
1812 | |
1813 | jcp.amx_h = 16; |
1814 | jcp.amx_w = 64 / (jcp.is_bf32 ? types::data_type_size(bf16) : jcp.src_dsz); |
1815 | |
1816 | const auto &p = attr.post_ops_; |
1817 | jcp.with_sum = p.find(primitive_kind::sum) != -1; |
1818 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
1819 | jcp.with_eltwise = eltwise_ind != -1; |
1820 | |
1821 | const int binary_ind = p.find(primitive_kind::binary); |
1822 | jcp.with_binary = binary_ind != -1; |
1823 | |
1824 | jcp.src_zero_point |
1825 | = get_zp_type(attr, DNNL_ARG_SRC) != brgemm_broadcast_t::none; |
1826 | jcp.dst_zero_point |
1827 | = get_zp_type(attr, DNNL_ARG_DST) != brgemm_broadcast_t::none; |
1828 | |
    // Only common zero points for the whole output tensor are supported now
1830 | // TODO: Extend zero points support to AMX |
1831 | const bool has_zero_points = jcp.src_zero_point || jcp.dst_zero_point; |
1832 | if (has_zero_points || jcp.s8s8_avx512) { |
1833 | const bool params_ok = IMPLICATION(has_zero_points, !is_amx(jcp.isa)) |
1834 | && IMPLICATION( |
1835 | has_zero_points, utils::one_of(jcp.src_dt, u8, s8)) |
1836 | && IMPLICATION(jcp.src_zero_point, |
1837 | attr.zero_points_.common(DNNL_ARG_SRC)) |
1838 | && IMPLICATION(jcp.dst_zero_point, |
1839 | attr.zero_points_.common(DNNL_ARG_DST)); |
1840 | if (!params_ok) return status::unimplemented; |
1841 | } |
1842 | |
1843 | jcp.nthr = nthreads; |
1844 | jcp.kh_sets = 1; |
1845 | jcp.kw_sets = 1; |
1846 | jcp.copy_block_only = false; |
1847 | jcp.amx_tile_load_xx = false; |
1848 | jcp.use_M_mask = 0; |
1849 | jcp.is_os_blocking = false; |
1850 | jcp.oskip = 0; |
1851 | jcp.use_uker = false; |
1852 | jcp.use_interleave_stores = false; |
1853 | jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf_default; |
1854 | jcp.brgemm_bd_loop_innermost = false; |
1855 | |
1856 | if (jcp.prop_kind != prop_kind::backward_weights) { |
        // fast check of the data layout before spending time on blocking selection
1858 | format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc); |
1859 | CHECK(init_tag( |
1860 | jcp.src_tag, src_md, src_d, src_tag, is_any_eligible(jcp))); |
1861 | } |
1862 | if (jcp.with_bias) { |
1863 | if (bias_d.format_kind() == format_kind::any) |
1864 | CHECK(memory_desc_init_by_tag(bias_md, x)); |
1865 | } |
1866 | |
1867 | const auto ic_padded_block |
1868 | = jcp.acc_simd_w * brg_blocking_t::last_ic_block_size; |
1869 | jcp.is_ic_padded = !jcp.is_1x1 && one_of(jcp.wei_dt, bf16, f16, s8) |
1870 | && jcp.ic * jcp.kw_sets > ic_padded_block && is_amx(isa); |
1871 | |
1872 | jcp.idp = jcp.id + jcp.f_pad + jcp.back_pad; |
1873 | jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; |
1874 | jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; |
1875 | |
1876 | return status::success; |
1877 | } |
1878 | |
1879 | status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, |
1880 | const convolution_desc_t &cd, memory_desc_t &src_md, |
1881 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
1882 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
1883 | |
1884 | using namespace prop_kind; |
1885 | if (!mayiuse(isa)) return status::unimplemented; |
1886 | |
1887 | CHECK(init_jcp( |
1888 | jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads)); |
1889 | |
1890 | if (jcp.is_1x1) return status::unimplemented; |
1891 | const memory_desc_wrapper src_d(&src_md); |
1892 | const memory_desc_wrapper weights_d(&weights_md); |
1893 | const memory_desc_wrapper dst_d(&dst_md); |
1894 | const memory_desc_wrapper bias_d(&bias_md); |
1895 | |
1896 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1897 | |
1898 | // TODO: check these restrictions |
1899 | if (is_amx(isa)) { |
1900 | // disabled for two convolutions from ssd_resnet34 |
1901 | if ((jcp.ic == jcp.oc) && (jcp.ic == 128 || jcp.ic == 256) |
1902 | && (jcp.oh == jcp.ow) && (jcp.oh == 150)) |
1903 | return status::unimplemented; |
        // disabled for first convolutions, except for 3d ones
1905 | const bool is_real_3d = (jcp.ndims == 5 |
1906 | && (jcp.id > 1 || jcp.od > 1 || jcp.kd > 1 |
1907 | || jcp.dilate_d > 0)); |
1908 | |
1909 | if (jcp.ic <= 4 && !is_real_3d |
1910 | && IMPLICATION(with_groups, is_groups_ok(jcp))) |
1911 | return status::unimplemented; |
1912 | |
1913 | if (jcp.f_pad >= jcp.ext_kd || jcp.t_pad >= jcp.ext_kh |
1914 | || jcp.r_pad >= jcp.ext_kw) |
1915 | return status::unimplemented; |
1916 | } |
1917 | |
1918 | using namespace data_type; |
1919 | // ======================= blocking ================================= |
1920 | |
1921 | auto bcast_amount |
1922 | = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz; |
1923 | auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.kd * jcp.kh * jcp.kw |
1924 | * jcp.wei_dsz; |
1925 | |
1926 | jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc; |
1927 | |
1928 | const int min_oc_block = jcp.acc_simd_w; |
1929 | |
1930 | int selected_ur = 0; |
1931 | MAYBE_UNUSED(selected_ur); |
1932 | |
1933 | auto try_exec_type = [&]() { |
1934 | brg_blocking_t best_brgb = zero<decltype(best_brgb)>(); |
1935 | best_brgb.oc_block = min_oc_block; |
1936 | brg_blocking_t cur_brgb = zero<decltype(best_brgb)>(); |
1937 | cur_brgb.get_from_jcp(jcp); |
1938 | auto start_ocb = (is_amx(isa) && jcp.is_os_blocking) ? 2 : 4; |
1939 | if (jcp.wei_plain) |
1940 | start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32, |
1941 | div_up(jcp.oc, jcp.acc_simd_w)); |
1942 | start_ocb = nstl::min(div_up(jcp.oc, jcp.acc_simd_w), start_ocb); |
1943 | |
1944 | auto finish_ocb = 1; |
1945 | for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) { |
1946 | cur_brgb.oc_block = ocb * jcp.acc_simd_w; |
1947 | cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block); |
1948 | if (!cur_brgb.fast_check_oc_block()) continue; |
1949 | |
1950 | const status_t blocking_ok = cur_brgb.calc_blocks(); |
1951 | if (blocking_ok != status::success) continue; |
1952 | |
1953 | const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md); |
1954 | if (st != status::success) continue; |
1955 | cur_brgb.eff = cur_brgb.est_eff(); |
1956 | if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb; |
1957 | } |
1958 | if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0 |
1959 | || best_brgb.ow_block == 0) |
1960 | return false; |
1961 | best_brgb.save_to_jcp(jcp); |
1962 | selected_ur = best_brgb.ur; |
1963 | return true; |
1964 | }; |
1965 | |
1966 | //----------------------------------------------------------------------- |
1967 | |
1968 | jcp.exec_type = exec_base; |
1969 | jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM |
1970 | |
1971 | bool try_exec_vpad = false; |
1972 | bool try_exec_trans = false; |
1973 | bool try_exec_base = true; |
1974 | |
1975 | if (!is_amx(isa) && div_up(jcp.l_pad, jcp.stride_w) < jcp.kw |
1976 | && div_up(jcp.r_pad, jcp.stride_w) < jcp.kw) { |
1977 | try_exec_vpad = true; |
1978 | } |
1979 | |
1980 | const auto ic_padded_block |
1981 | = jcp.acc_simd_w * brg_blocking_t::last_ic_block_size; |
1982 | // TODO: remove this restriction |
1983 | const auto w_padding = jcp.l_pad > 0 || jcp.r_pad > 0; |
1984 | if (is_amx(isa)) { |
1985 | try_exec_base = !w_padding |
1986 | && IMPLICATION(jcp.ic <= ic_padded_block, |
1987 | jcp.ic % brg_blocking_t::last_ic_block_size == 0) |
1988 | && IMPLICATION( |
1989 | jcp.ic > ic_padded_block, jcp.ic % ic_padded_block == 0) |
1990 | && jcp.ow > 50 /*TODO: reinvestigate this heuristic */; |
1991 | try_exec_trans = !try_exec_base; |
1992 | } |
1993 | |
1994 | bool must_exec_vpad = false; |
1995 | |
    // TODO: in the future, use (kd/kh/kw) and (kd/kh/kw)_pad blocks for a more
    // precise calculation of jcp.max_batch
1998 | jcp.max_batch = jcp.kd * jcp.kh * jcp.kw; |
1999 | |
    // TODO: check wei plain
2001 | jcp.wei_plain = false; |
2002 | jcp.wei_plain = jcp.exec_type == exec_vpad ? jcp.wei_plain : false; |
2003 | |
2004 | bool try_exec_type_res = false; |
2005 | |
2006 | if (try_exec_vpad) { |
2007 | jcp.exec_type = exec_vpad; |
2008 | try_exec_type_res = try_exec_type(); |
        // to avoid the case when both top and bottom virtual padding are non-zero
2010 | // TODO: remove this restriction |
2011 | const auto iw_block = (jcp.ow_block - 1) * jcp.stride_w + 1; |
2012 | if (!must_exec_vpad && (iw_block > jcp.iw)) try_exec_type_res = false; |
2013 | } |
2014 | if (try_exec_type_res == false && try_exec_trans) { |
2015 | jcp.exec_type = exec_trans; |
2016 | |
        // always try loop_ndhwgc for exec_trans
2018 | jcp.loop_order = loop_ndhwgc; |
2019 | |
        // we read the input block only once for loop_ndhwgc, so we don't need
        // to keep it in memory
2022 | if (jcp.loop_order == loop_ndhwgc) { jcp.copy_block_only = true; } |
2023 | |
2024 | jcp.is_ic_padded = one_of(jcp.wei_dt, bf16, f16, s8) |
2025 | && jcp.ic * jcp.kw_sets > ic_padded_block; |
2026 | |
2027 | if (is_amx(isa) && (/* heuristic*/ jcp.kw_sets == 1 && jcp.ow < 256)) { |
2028 | jcp.is_os_blocking = jcp.f_pad < jcp.kd && jcp.back_pad < jcp.kd |
2029 | && jcp.t_pad < jcp.kh && jcp.b_pad < jcp.kh |
2030 | && jcp.r_pad < jcp.kw && jcp.l_pad < jcp.kw; |
2031 | jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0; |
2032 | jcp.use_uker = true; |
2033 | jcp.use_interleave_stores = true; |
2034 | jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; |
2035 | // assuming 2x2 decomposition in amx brgemm kernel |
2036 | // and overlap of input by kw |
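            // Illustrative footprint for an assumed bf16 shape with ic = 64,
            // kd = kh = 1, kw = 3: A_ds = 2 * 32 * 64 = 4 KiB,
            // B_ds = 2 * 32 * 64 * 3 = 12 KiB, C_ds = 4 * 32 * 32 = 4 KiB;
            // 20 KiB fits in a typical L1, so amx_tile_load_xx stays false.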
2037 | const auto bd_blocking = 2 * jcp.amx_h; |
2038 | const auto ld_blocking = 2 * 16; |
2039 | const auto A_ds |
2040 | = jcp.src_dsz * bd_blocking * jcp.ic * jcp.kd * jcp.kh; |
2041 | const auto B_ds = jcp.wei_dsz * ld_blocking * jcp.ic * jcp.kd |
2042 | * jcp.kh * jcp.kw; |
2043 | const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking; |
2044 | if (A_ds + B_ds + C_ds > brg_blocking_t::L1) |
2045 | jcp.amx_tile_load_xx = true; |
2046 | } |
2047 | |
2048 | try_exec_type_res = try_exec_type(); |
2049 | } |
2050 | if (try_exec_base && try_exec_type_res == false) { |
2051 | jcp.exec_type = exec_base; |
2052 | try_exec_type_res = try_exec_type(); |
2053 | } |
2054 | |
2055 | if (try_exec_type_res == false) return status::unimplemented; |
2056 | |
2057 | // ============ end blocking =========================================== |
2058 | if (jcp.exec_type == exec_vpad) |
2059 | jcp.max_vpad = nstl::max(jcp.l_pad, jcp.r_pad); |
2060 | else |
2061 | jcp.max_vpad = 0; |
2062 | |
2063 | if (jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0) |
2064 | return status::unimplemented; |
2065 | |
2066 | jcp.gemm_batch_size = jcp.nb_ic_blocking |
2067 | * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block, |
2068 | jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad); |
    // to avoid concurrent cache write access from different threads
2070 | size_t sc_size = sizeof(brgemm_batch_element_t); |
2071 | jcp.adjusted_batch_size |
2072 | = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size); |
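    // e.g., with gemm_batch_size = 48 and an assumed sc_size of 32 bytes:
    // rnd_up(48 * 32, 4096) = 4096, so adjusted_batch_size = 128 elements and
    // each thread's batch array starts on its own 4K page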
2073 | |
2074 | CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md)); |
2075 | CHECK(attr.set_default_formats(&dst_md)); |
2076 | |
2077 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
2078 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
2079 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
2080 | jcp.with_scales = !src_scales.has_default_values() |
2081 | || !wei_scales.has_default_values(); |
2082 | const int wei_mask_per_oc = 1 << (int)with_groups; |
2083 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
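    // without groups, mask 1 selects dim 0 (oc) of the weights; with groups,
    // mask 2 selects dim 1, i.e. oc within (g, oc, ic, ...)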
2084 | |
2085 | // only common and per-oc-channel scales are supported |
2086 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
2087 | && src_scales.mask_ == 0 && dst_scales.has_default_values(); |
2088 | if (!scales_ok) return status::unimplemented; |
2089 | |
2090 | jcp.buffer_size = jcp.LDC * jcp.M; |
2091 | |
2092 | jcp.nb_od = div_up(jcp.od, jcp.od_block); |
2093 | jcp.nb_oh = div_up(jcp.oh, jcp.oh_block); |
2094 | |
2095 | if (jcp.exec_type == exec_trans) { |
        // TODO: this is a rough estimation of the buffer size for the transposed input
2097 | dim_t ds = jcp.copy_block_only |
2098 | ? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd, |
2099 | jcp.stride_d, jcp.dilate_d) |
2100 | + nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad)) |
2101 | : jcp.idp; |
2102 | dim_t hs = jcp.copy_block_only |
2103 | ? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh, |
2104 | jcp.stride_h, jcp.dilate_h) |
2105 | + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad)) |
2106 | : jcp.ihp; |
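        // for os blocking, round hs * iwp up to a multiple of brgM so that the
        // copied input covers complete brgM-sized row blocks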
2107 | if (jcp.is_os_blocking) |
2108 | hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp); |
2109 | |
2110 | jcp.inp_buffer_size = rnd_up(ds * hs * jcp.iwp * jcp.ngroups * jcp.nb_ic |
2111 | * jcp.ic_block * jcp.kh_sets * jcp.kw_sets, |
2112 | P4K); |
2113 | jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_od) |
2114 | * jcp.nb_oh * jcp.nb_ow * jcp.ngroups * jcp.nb_ic, |
2115 | P4K); |
2116 | } |
2117 | |
2118 | const bool with_pad = jcp.f_pad > 0 || jcp.back_pad > 0 || jcp.t_pad > 0 |
2119 | || jcp.b_pad > 0; |
2120 | |
2121 | if (jcp.s8s8_avx512) { |
2122 | weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8; |
2123 | weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1; |
2124 | } |
2125 | if (jcp.src_zero_point && !is_amx(jcp.isa)) { |
2126 | weights_md.extra.flags |
2127 | |= memory_extra_flags::compensation_conv_asymmetric_src; |
2128 | weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1; |
2129 | } |
2130 | |
    // disable shapes with small ic but large spatial,
    // as well as specific large-spatial shapes, for int8 conv
2133 | const auto is_ok_large_spatial |
2134 | = IMPLICATION(!is_amx(jcp.isa) && jcp.ic <= 128, |
2135 | jcp.od * jcp.oh < 100 |
2136 | || jcp.ic * jcp.oc_block * jcp.ow_block > 8192) |
2137 | && !(is_amx(jcp.isa) && jcp.ic < 16 && jcp.ndims == 4) |
2138 | && IMPLICATION(is_amx(jcp.isa) && jcp.ic <= 16, |
2139 | jcp.ow < 2048 |
2140 | || div_up(jcp.ow_block, selected_ur) * jcp.kd |
2141 | * jcp.kh * jcp.kw |
2142 | > 8192) |
2143 | && !(!is_amx(jcp.isa) && jcp.oc == 1024 |
2144 | && utils::everyone_is(1, jcp.od, jcp.oh, jcp.kd, jcp.kh) |
2145 | && jcp.ow >= 595 && jcp.kw <= 5); |
2146 | if (one_of(jcp.src_dt, u8, s8) && !is_ok_large_spatial) |
2147 | return status::unimplemented; |
2148 | |
    // For shapes with padding, when the output size is small we calculate the
    // compensation along with the computation inside the brgemm kernel to get
    // optimal performance; otherwise we calculate it with the brgemm_comp_pad
    // kernel
2152 | const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.oc |
2153 | * jcp.od * jcp.oh * jcp.ow; |
2154 | const auto comp_with_pads = (jcp.src_zero_point || jcp.s8s8_avx512) |
2155 | && IMPLICATION(jcp.exec_type == exec_vpad, with_pad); |
2156 | jcp.req_brg_comp_pad = comp_with_pads && output_sz <= 8192 && jcp.oc < 512; |
2157 | jcp.req_cal_comp_pad = comp_with_pads && !jcp.req_brg_comp_pad; |
2158 | |
    // estimate the number of kernel range combinations for compensation
2160 | const auto kd_cnt = 1 + utils::div_up(abs(jcp.f_pad), jcp.dilate_d + 1) |
2161 | + utils::div_up(abs(jcp.back_pad), jcp.dilate_d + 1); |
2162 | const auto kh_cnt = 1 + utils::div_up(abs(jcp.t_pad), jcp.dilate_h + 1) |
2163 | + utils::div_up(abs(jcp.b_pad), jcp.dilate_h + 1); |
2164 | const auto kw_cnt |
2165 | = (1 |
2166 | + (utils::div_up(abs(jcp.l_pad), jcp.dilate_w + 1) |
2167 | + utils::div_up( |
2168 | abs(jcp.r_pad), jcp.dilate_w + 1))) |
2169 | * 2; |
2170 | |
2171 | jcp.ker_ranges_size = jcp.exec_type == exec_base ? kd_cnt * kh_cnt * kw_cnt |
2172 | : kd_cnt * kh_cnt; |
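    // e.g., kd = kh = kw = 3 with unit dilations and all pads equal to 1:
    // kd_cnt = kh_cnt = 3 and kw_cnt = (1 + (1 + 1)) * 2 = 6, so exec_base
    // tracks 3 * 3 * 6 = 54 kernel ranges and the other exec types 3 * 3 = 9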
2173 | jcp.comp_a_buffer_size |
2174 | = jcp.ngroups * jcp.nb_oc * jcp.ker_ranges_size * jcp.oc_block; |
2175 | jcp.s8s8_comp_buffer_size = jcp.comp_a_buffer_size; |
2176 | |
2177 | if (!IMPLICATION(jcp.is_bf32, jcp.use_uker)) return status::unimplemented; |
2178 | |
2179 | return status::success; |
2180 | } |
2181 | |
2182 | status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, |
2183 | const convolution_desc_t &cd, memory_desc_t &src_md, |
2184 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
2185 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
2186 | |
2187 | using namespace prop_kind; |
2188 | if (!mayiuse(isa)) return status::unimplemented; |
2189 | |
2190 | CHECK(init_jcp( |
2191 | jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads)); |
2192 | |
2193 | const memory_desc_wrapper src_d(&src_md); |
2194 | const memory_desc_wrapper weights_d(&weights_md); |
2195 | const memory_desc_wrapper dst_d(&dst_md); |
2196 | const memory_desc_wrapper bias_d(&bias_md); |
2197 | |
2198 | if (!jcp.is_1x1) return status::unimplemented; |
2199 | |
2200 | using namespace data_type; |
2201 | // ===================== blocking ================================= |
2202 | |
2203 | auto bcast_amount |
2204 | = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz; |
2205 | auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.wei_dsz; |
2206 | |
2207 | jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc; |
2208 | |
2209 | if (is_amx(isa)) { |
2210 | // round up ic if needed |
2211 | const int vnni_width = brg_blocking_t::last_ic_block_size; |
2212 | const int n_vnni_blocks = utils::div_up(jcp.ic, vnni_width); |
2213 | const int ic_block |
2214 | = nstl::min(jcp.acc_simd_w, n_vnni_blocks) * vnni_width; |
2215 | const bool do_zeropad = (!jcp.is_bf32) |
2216 | && (jcp.ic % vnni_width != 0 || jcp.ic > ic_block); |
2217 | if (do_zeropad) jcp.ic = utils::rnd_up(jcp.ic, ic_block); |
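        // e.g., bf16 with ic = 3: vnni_width = 2, n_vnni_blocks = 2,
        // ic_block = min(16, 2) * 2 = 4, so ic is zero-padded up to 4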
2218 | const auto ic_padded_block = jcp.acc_simd_w * vnni_width; |
2219 | jcp.is_ic_padded = jcp.ic > ic_padded_block && !(jcp.is_bf32); |
2220 | |
2221 | // try to choose optimal loop order |
2222 | // TODO: incorporate loop order into smart blocking selection |
2223 | auto wei_size = (size_t)jcp.oc * jcp.ic * jcp.wei_dsz; |
2224 | auto max_size = 0.75f * brg_blocking_t::L2; |
2225 | const dim_t os = jcp.od * jcp.oh * jcp.ow; |
        const dim_t os_cutoff = 400; // approximate and empirical
2227 | const bool use_loop_ngcdhw |
2228 | = max_size < wei_size || (jcp.mb == 1 && os < os_cutoff); |
2229 | jcp.loop_order = use_loop_ngcdhw ? loop_ngcdhw : loop_ndhwgc; |
2230 | } |
2231 | |
2232 | const auto min_oc_block = jcp.acc_simd_w; |
2233 | |
2234 | jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM |
2235 | |
2236 | // max_batch is 1 and max_vpad is 0 for 1x1 convolutions |
2237 | jcp.max_batch = 1; |
2238 | jcp.max_vpad = 0; |
2239 | |
2240 | jcp.wei_plain = false; |
2241 | |
2242 | brg_blocking_t best_brgb = zero<decltype(best_brgb)>(); |
2243 | best_brgb.oc_block = min_oc_block; |
2244 | brg_blocking_t cur_brgb = zero<decltype(cur_brgb)>(); |
2245 | cur_brgb.get_from_jcp(jcp); |
2246 | auto start_ocb = 4; |
2247 | if (jcp.wei_plain) |
2248 | start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32, |
2249 | div_up(jcp.oc, jcp.acc_simd_w)); |
2250 | start_ocb = nstl::min(div_up(jcp.oc, jcp.acc_simd_w), start_ocb); |
2251 | |
2252 | auto finish_ocb = 1; |
2253 | for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) { |
2254 | cur_brgb.oc_block = ocb * min_oc_block; |
2255 | cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block); |
2256 | |
2257 | if (!cur_brgb.fast_check_oc_block_1x1()) continue; |
2258 | |
2259 | cur_brgb.calc_blocks_1x1(); |
2260 | const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md); |
2261 | if (st != status::success) continue; |
2262 | cur_brgb.eff = cur_brgb.est_eff_1x1(); |
2263 | if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb; |
2264 | } |
2265 | best_brgb.save_to_jcp(jcp); |
2266 | |
2267 | // =============== end blocking ================================= |
2268 | jcp.brg_stride_a = jcp.ic_block * jcp.src_dsz; |
2269 | jcp.brg_stride_b = jcp.ic_block * jcp.oc * jcp.wei_dsz; |
2270 | |
2271 | if (jcp.ic_block == 0 || jcp.oc_block == 0) return status::unimplemented; |
2272 | |
2273 | // Configure matrix sizes |
2274 | |
2275 | if (best_brgb.is_os_blocking) { |
2276 | if (jcp.os_block == 0) return status::unimplemented; |
2277 | jcp.M = jcp.brgM = jcp.os_block; |
2278 | jcp.M_tail = jcp.brgM_tail = jcp.os % jcp.os_block; |
2279 | } else { |
2280 | if (jcp.ow_block == 0) return status::unimplemented; |
2281 | jcp.M = jcp.brgM = jcp.ow_block; |
2282 | jcp.M_tail = jcp.brgM_tail = jcp.ow % jcp.ow_block; |
2283 | } |
2284 | |
2285 | jcp.K = jcp.ic >= jcp.ic_block ? jcp.ic_block : 0; |
2286 | jcp.N = jcp.oc >= jcp.oc_block ? jcp.oc_block : 0; |
2287 | jcp.N_tail = jcp.oc % jcp.oc_block; |
2288 | jcp.K_tail = jcp.ic % jcp.ic_block; |
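    // e.g., ic = 70 with ic_block = 64 gives K = 64 and K_tail = 6, while
    // ic = 48 gives K = 0 and K_tail = 48 (the tail gemm does all the work)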
2289 | |
2290 | jcp.gemm_batch_size = jcp.nb_ic_blocking; |
    // to avoid concurrent cache access from different threads
2292 | size_t sc_size = sizeof(brgemm_batch_element_t); |
2293 | jcp.adjusted_batch_size |
2294 | = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size); |
2295 | |
2296 | if (is_amx(isa)) { |
2297 | // heuristic for small mb |
2298 | const bool is_small_mb = jcp.mb == 1 && jcp.ic * jcp.oh <= 28 * 1024 |
2299 | && jcp.oc * jcp.oh <= 14 * 1024; |
        // the non-unrolled kernel does not support bf32; dispatch only the
        // unrolled kernel for now
2302 | jcp.use_uker = jcp.is_bf32 || !is_small_mb; |
2303 | } |
2304 | |
    // TODO: heuristic to dispatch BF32 BRGEMM
2306 | // The following condition checks for shapes where down-convert execution |
2307 | // in brgemm fails |
2308 | if (jcp.is_bf32 && jcp.ic < 64 && jcp.ic % 32 != 0) |
2309 | return status::unimplemented; |
2310 | |
2311 | if (jcp.use_uker) |
2312 | jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; |
2313 | CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md)); |
2314 | CHECK(attr.set_default_formats(&dst_md)); |
2315 | |
2316 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
2317 | |
2318 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
2319 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
2320 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
2321 | jcp.with_scales = !src_scales.has_default_values() |
2322 | || !wei_scales.has_default_values(); |
2323 | const int wei_mask_per_oc = 1 << (int)with_groups; |
2324 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
2325 | |
2326 | // only common and per-oc-channel scales are supported |
2327 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
2328 | && src_scales.mask_ == 0 && dst_scales.has_default_values(); |
2329 | if (!scales_ok) return status::unimplemented; |
2330 | |
2331 | // no inp buffer or brgemm_vpad for 1x1 |
2332 | constexpr int align_size = platform::get_cache_line_size(); |
2333 | jcp.exec_type = jcp.is_rtus ? exec_trans : exec_base; |
2334 | jcp.inp_buffer_size |
2335 | = jcp.is_rtus ? rnd_up(jcp.LDA * jcp.os, align_size) : 0; |
2336 | jcp.inp_buffer_mask_size = jcp.is_rtus |
2337 | ? rnd_up(div_up(jcp.nb_ic, jcp.nb_ic_blocking) * jcp.nb_os, |
2338 | align_size) |
2339 | : 0; |
2340 | jcp.buffer_size = jcp.LDC * jcp.M; |
2341 | |
2342 | if (jcp.s8s8_avx512) { |
2343 | weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8; |
2344 | weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1; |
2345 | } |
2346 | if (jcp.src_zero_point) { |
2347 | weights_md.extra.flags |
2348 | |= memory_extra_flags::compensation_conv_asymmetric_src; |
2349 | weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1; |
2350 | } |
2351 | jcp.req_cal_comp_pad = false; |
2352 | jcp.s8s8_comp_buffer_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
2353 | jcp.comp_a_buffer_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
2354 | |
2355 | return status::success; |
2356 | } |
2357 | |
2358 | void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp) { |
    // ensure that buffers for individual threads do not lie on the same page
    // and are not contiguous
2361 | jcp.amx_buf_size_per_thread |
2362 | = utils::rnd_up(jcp.amx_buf_size_per_thread + 1, P4K); |
2363 | } |
2364 | |
2365 | void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
2366 | const jit_brgemm_conv_conf_t &jcp) { |
2367 | if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs |
2368 | || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad)) |
2369 | scratchpad.book(key_brgemm_primitive_batch, |
2370 | static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size, |
2371 | sizeof(brgemm_batch_element_t), 64, P4K); |
2372 | if (jcp.exec_type == exec_trans) { |
2373 | size_t inp_buffer_size |
2374 | = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size; |
2375 | scratchpad.book(key_conv_brgemm_inp_buffer, inp_buffer_size, |
2376 | jcp.src_dsz, 0, P4K); |
2377 | size_t inp_buffer_mask_size |
2378 | = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size; |
2379 | scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size, |
2380 | sizeof(uint8_t), 0, P4K); |
2381 | } |
2382 | if (jcp.use_buffer) { |
2383 | scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size, |
2384 | jcp.acc_dsz, 0, P4K); |
2385 | } |
2386 | if (is_amx(jcp.isa)) { |
2387 | scratchpad.book(key_conv_amx_tile_buffer, |
2388 | jcp.nthr * jcp.amx_buf_size_per_thread, sizeof(char), 0, P4K); |
2389 | } |
2390 | if (jcp.s8s8_avx512 && jcp.req_cal_comp_pad) { |
2391 | scratchpad.book(key_brgemm_primitive_buffer_comp, |
2392 | jcp.s8s8_comp_buffer_size, sizeof(int32_t), 0, P4K); |
2393 | } |
2394 | if (jcp.src_zero_point && jcp.req_cal_comp_pad && !is_amx(jcp.isa)) { |
2395 | scratchpad.book(key_brgemm_primitive_zp_comp_a, jcp.comp_a_buffer_size, |
2396 | sizeof(int32_t), 0, P4K); |
2397 | } |
2398 | } |
2399 | |
2400 | void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) { |
2401 | |
2402 | const auto os_chunks = jcp.nthr_mb_work; |
2403 | const auto oc_chunks = div_up(jcp.nb_oc, jcp.nb_oc_blocking); |
2404 | const auto ic_chunks = div_up(jcp.nb_ic, jcp.nb_ic_blocking); |
2405 | |
2406 | auto calc_mem_cost = [=](int nthr_mb, int nthr_g, int nthr_oc_b, |
2407 | int nthr_ic_b) { |
        /* calculate per-thread memory cost (read/write). The high-level
         * optimizer tries to minimize memory consumption. A few notes:
2410 | * (n1) if weights tensor size is less than source and destination |
2411 | * tensors we apply the ratio of the source and destination |
2412 | * tensor sizes to weights one as compensation coefficient to |
2413 | * avoid parallelization across batch size only, otherwise we |
2414 | * apply additional coefficient to source component based on |
2415 | * performance measurements |
2416 | * (n2) use scales based on output vs input channels ratio for |
2417 | * source and destination components to improve threading |
2418 | * balance across input and output channels */ |
2419 | |
2420 | const dim_t src_type_size = 2; |
2421 | const dim_t wei_type_size = 4; |
2422 | const dim_t acc_type_size = wei_type_size; |
2423 | |
2424 | const auto wei_ks = jcp.kh * jcp.kw * jcp.kd; |
2425 | |
2426 | const auto src_spatial = (dim_t)jcp.mb * jcp.id * jcp.ih * jcp.tr_iw; |
2427 | const auto dst_spatial = (dim_t)jcp.mb * jcp.od * jcp.oh * jcp.tr_ow; |
2428 | |
2429 | dim_t src_size = src_spatial * jcp.ic * src_type_size; |
2430 | dim_t dst_size = dst_spatial * jcp.oc * src_type_size; |
2431 | dim_t wei_size = (dim_t)jcp.oc * jcp.ic * wei_ks * wei_type_size; |
2432 | |
2433 | float wei_compensation_scale = 0.5f * (dst_size + src_size) / wei_size; |
2434 | float oi_channels_ratio = (float)(oc_chunks) / ic_chunks; |
2435 | |
2436 | auto get_src_coef = [=]() { |
2437 | float src_coef = nstl::max(1.0f / oi_channels_ratio, 1.0f); |
2438 | if (wei_compensation_scale < 1.0f) src_coef *= 4.0f; |
2439 | return src_coef; |
2440 | }; |
2441 | |
2442 | auto get_dst_coef |
2443 | = [=]() { return nstl::max(oi_channels_ratio, 1.0f); }; |
2444 | |
2445 | auto get_wei_coef |
2446 | = [=]() { return nstl::max(wei_compensation_scale, 1.0f); }; |
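        // e.g., if dst_size + src_size == 8 * wei_size, then
        // wei_compensation_scale = 4 and wei_coef = 4, while src_coef keeps
        // its channel-ratio value; this biases the optimizer against
        // parallelizing across the batch only (see note n1)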
2447 | |
2448 | const float src_coef = get_src_coef(); |
2449 | const float dst_coef = get_dst_coef(); |
2450 | const float wei_coef = get_wei_coef(); |
2451 | |
2452 | const auto thr_mb = div_up(os_chunks, nthr_mb); |
2453 | const auto nb_oc_job = jcp.oc_block * jcp.nb_oc_blocking; |
2454 | const auto nb_ic_job = jcp.ic_block * jcp.nb_ic_blocking; |
2455 | |
2456 | const auto src_chunk = src_spatial / os_chunks; |
2457 | const auto dst_chunk = dst_spatial / os_chunks; |
2458 | |
2459 | const auto thr_g = div_up(jcp.ngroups, nthr_g); |
2460 | const auto thr_ic_b = div_up(ic_chunks, nthr_ic_b); |
2461 | const auto thr_src_sp = thr_mb * src_chunk / jcp.stride_d / jcp.stride_h |
2462 | / jcp.stride_w; |
2463 | const auto thr_dst_sp = thr_mb * dst_chunk; |
2464 | const auto thr_ic_amount = thr_ic_b * nb_ic_job; |
2465 | |
2466 | const auto thr_oc_b = div_up(oc_chunks, nb_oc_job * nthr_oc_b); |
2467 | |
2468 | const auto thr_oc_amount = thr_oc_b * nb_oc_job; |
2469 | float src_v |
2470 | = src_type_size * src_coef * thr_g * thr_ic_amount * thr_src_sp; |
2471 | float dst_v |
2472 | = src_type_size * dst_coef * thr_g * thr_oc_amount * thr_dst_sp; |
2473 | float wei_v = acc_type_size * wei_coef * thr_g * thr_oc_amount |
2474 | * thr_ic_amount * wei_ks; |
2475 | |
2476 | return src_v + dst_v + wei_v; |
2477 | }; |
2478 | |
2479 | auto balance = [=](int &nthr_, int &nthr_mb_, int &nthr_g_, int &nthr_oc_b_, |
2480 | int &nthr_ic_b_) { |
2481 | nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1; |
2482 | |
2483 | if (jcp.nthr < jcp.ngroups) { |
2484 | /* simplification... fortunately it doesn't hurt much */ |
2485 | nthr_ = nthr_g_ = jcp.nthr; |
2486 | return; |
2487 | } |
2488 | |
2489 | nthr_g_ = jcp.ngroups; |
2490 | const int nthr = jcp.nthr / nthr_g_; |
2491 | |
2492 | float best_mem_cost |
2493 | = calc_mem_cost(nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_); |
2494 | |
2495 | /* find the best thread distribution with lowest memory cost */ |
2496 | |
2497 | const int nthr_mb_max = nstl::min(nthr, jcp.nthr_mb_work); |
2498 | for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
2499 | const int nthr_par = nthr / nthr_mb; |
2500 | const int nthr_oc_b_max = nstl::min(nthr_par, |
2501 | oc_chunks); // Amount of nb_oc_blocks |
2502 | for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
2503 | int nthr_ic_b = nstl::min( |
2504 | nthr_par / nthr_oc_b, (jcp.nb_ic / jcp.nb_ic_blocking)); |
2505 | |
2506 | float mem_cost |
2507 | = calc_mem_cost(nthr_mb, nthr_g_, nthr_oc_b, nthr_ic_b); |
2508 | if (mem_cost <= best_mem_cost) { |
2509 | best_mem_cost = mem_cost; |
2510 | nthr_mb_ = nthr_mb; |
2511 | nthr_oc_b_ = nthr_oc_b; |
2512 | nthr_ic_b_ = nthr_ic_b; |
2513 | } |
2514 | } |
2515 | } |
2516 | |
2517 | if (nthr_mb_ > nthr / 2 && nthr_mb_ < nthr) |
2518 | nthr_mb_ = nstl::min(jcp.nthr_mb_work, nthr); |
2519 | nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_; |
2520 | |
2521 | assert(nthr_ <= jcp.nthr); |
2522 | }; |
2523 | |
2524 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
2525 | balance(nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b); |
2526 | |
    // empirical balancing for some shapes
2528 | const auto sps = (jcp.ih * jcp.iw); |
2529 | bool neat_1x1 |
2530 | = everyone_is(1, jcp.id, jcp.kh, jcp.kw, jcp.ngroups, jcp.stride_h); |
2531 | if (neat_1x1 && jcp.nthr >= 28 && jcp.mb >= jcp.nthr) { |
2532 | const bool more_oc = (jcp.ic < jcp.oc); |
2533 | if (sps >= 56 * 56 && jcp.ic >= 64 && jcp.oc >= 64) { |
2534 | nthr_mb = jcp.nthr; |
2535 | nthr_oc_b = 1; |
2536 | } else if (sps >= 28 * 28 && jcp.ic >= 128 && jcp.oc >= 128) { |
2537 | nthr_mb = jcp.nthr / 4; |
2538 | nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1; |
2539 | } else if (sps >= 14 * 14 && jcp.ic >= 256 && jcp.oc >= 256) { |
2540 | nthr_mb = div_up(jcp.nthr, 8); |
2541 | nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1; |
2542 | } else if (sps >= 7 * 7 && jcp.ic >= 512 && jcp.oc >= 512) { |
2543 | nthr_mb = div_up(jcp.nthr, 14); |
2544 | nthr_oc_b = more_oc ? jcp.nthr / nthr_mb : 1; |
2545 | } |
2546 | nthr_ic_b = jcp.nthr / (nthr_mb * nthr_oc_b); |
2547 | nthr = nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b; |
2548 | } |
2549 | |
2550 | jcp.nthr = nthr; |
2551 | jcp.nthr_mb = nthr_mb; |
2552 | jcp.nthr_g = nthr_g; |
2553 | jcp.nthr_oc_b = nthr_oc_b; |
2554 | jcp.nthr_ic_b = nthr_ic_b; |
2555 | |
2556 | // TODO: Optimize memory allocation when threaded on height and depth |
2557 | jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id; |
2558 | jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od; |
2559 | jcp.tr_src_buf_count = jcp.global_transpose |
2560 | ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups |
2561 | : jcp.nthr; |
2562 | jcp.tr_diff_dst_buf_count = jcp.global_transpose |
2563 | ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups |
2564 | : jcp.nthr; |
2565 | } |
2566 | |
2567 | status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, |
2568 | const convolution_desc_t &cd, memory_desc_t &src_md, |
2569 | memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md, |
2570 | memory_desc_t &diff_dst_md, primitive_attr_t &attr, int nthreads) { |
2571 | |
2572 | const memory_desc_wrapper src_d(&src_md); |
2573 | const memory_desc_wrapper diff_weights_d(&diff_weights_md); |
2574 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
2575 | const memory_desc_wrapper diff_bias_d(&diff_bias_md); |
2576 | |
2577 | const bool is_f16 = src_d.data_type() == data_type::f16; |
2578 | |
2579 | jcp.isa = is_f16 ? avx512_core_amx_fp16 : avx512_core_amx; |
2580 | if (!mayiuse(jcp.isa)) return status::unimplemented; |
2581 | |
2582 | const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1; |
2583 | int ndims = src_d.ndims(); |
2584 | |
2585 | CHECK(init_jcp(jcp, jcp.isa, cd, src_md, diff_weights_md, diff_dst_md, |
2586 | diff_bias_md, attr, nthreads)); |
2587 | |
2588 | jcp.max_batch = jcp.od * jcp.oh; |
2589 | jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM |
2590 | jcp.use_uker = true; |
2591 | jcp.var_bs = true; |
2592 | |
    // Process some 1x1 convolutions with small iw as 1d (h=1, w = h*w)
    // convolutions to make the brgemm K dimension bigger, for better
    // utilization of AMX tiles
2596 | bool neat_1x1_2d = (everyone_is( |
2597 | 1, jcp.kh, jcp.kw, jcp.stride_h, jcp.stride_w) |
2598 | && everyone_is(0, jcp.t_pad, jcp.b_pad, jcp.l_pad, jcp.r_pad)); |
2599 | bool make_1d = neat_1x1_2d && jcp.iw <= 28; |
2600 | if (make_1d) { |
2601 | jcp.iw *= jcp.ih; |
2602 | jcp.ih = 1; |
2603 | jcp.ow *= jcp.oh; |
2604 | jcp.oh = 1; |
2605 | jcp.max_batch = jcp.od; |
2606 | } |
    // TODO: sometimes we can call the brgemm kernel with bs = 0 to do
    // initialization; review this condition
2609 | if (jcp.max_batch == 1 |
2610 | && everyone_is(0, jcp.f_pad, jcp.back_pad, jcp.t_pad, jcp.b_pad)) |
2611 | jcp.var_bs = false; |
2612 | |
2613 | jcp.has_vnni = true; // Needed for transpose routines |
2614 | |
2615 | jcp.typesize_in = sizeof(bfloat16_t); |
2616 | jcp.typesize_out = sizeof(float); |
2617 | |
2618 | bool ok = true |
2619 | // general condition to simplify dilations |
2620 | && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1) |
2621 | && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1) |
2622 | && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1) |
2623 | // special condition to simplify dilations in compute_oh_loop_common |
2624 | && IMPLICATION(jcp.dilate_h != 0, jcp.ext_kh <= jcp.ih); |
2625 | if (!ok) return status::unimplemented; |
2626 | |
2627 | jcp.transform_to_vnni = diff_weights_d.data_type() != data_type::f32; |
2628 | |
2629 | /* XXX: no support for padding when dilation_d > 0 */ |
2630 | if (!IMPLICATION(jcp.dilate_d > 0, everyone_is(0, jcp.back_pad, jcp.f_pad))) |
2631 | return status::unimplemented; |
2632 | |
2633 | const bool is_depthwise = true && with_groups && jcp.ngroups > 1 |
2634 | && everyone_is(1, jcp.ic, jcp.oc); |
2635 | if (is_depthwise) |
        return status::unimplemented; // TODO: add support for DW convolution
2637 | |
2638 | const int dat_format_tag = ndims - 3; |
2639 | format_tag_t dat_tag_nspc = utils::pick(dat_format_tag, format_tag::nwc, |
2640 | format_tag::nhwc, format_tag::ndhwc); |
2641 | format_tag_t dat_tag_opt = dat_tag_nspc; |
2642 | |
2643 | if (src_d.format_kind() == format_kind::any) { |
2644 | CHECK(memory_desc_init_by_tag(src_md, dat_tag_opt)); |
2645 | jcp.src_tag = dat_tag_opt; |
2646 | } else |
2647 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag_opt); |
2648 | if (!one_of(jcp.src_tag, dat_tag_opt)) return status::unimplemented; |
2649 | |
2650 | const bool is_nspc = jcp.src_tag == dat_tag_nspc; |
2651 | if (!is_nspc) return status::unimplemented; |
2652 | |
2653 | if (diff_dst_d.format_kind() == format_kind::any) { |
2654 | CHECK(memory_desc_init_by_tag(diff_dst_md, jcp.src_tag)); |
2655 | jcp.dst_tag = jcp.src_tag; |
2656 | } else |
2657 | jcp.dst_tag = diff_dst_d.matches_one_of_tag(jcp.src_tag); |
2658 | if (jcp.dst_tag != jcp.src_tag) return status::unimplemented; |
2659 | |
2660 | const int wei_format_tag = 2 * ndims - 6 + with_groups; |
2661 | format_tag_t wei_tag; |
2662 | if (jcp.transform_to_vnni) |
2663 | wei_tag = pick(wei_format_tag, format_tag::OIw16i16o2i, |
2664 | format_tag::gOIw16i16o2i, format_tag::OIhw16i16o2i, |
2665 | format_tag::gOIhw16i16o2i, format_tag::OIdhw16i16o2i, |
2666 | format_tag::gOIdhw16i16o2i); |
2667 | else |
2668 | wei_tag = pick(wei_format_tag, format_tag::OIw16i16o, |
2669 | format_tag::gOIw16i16o, format_tag::OIhw16i16o, |
2670 | format_tag::gOIhw16i16o, format_tag::OIdhw16i16o, |
2671 | format_tag::gOIdhw16i16o); |
2672 | if (diff_weights_md.format_kind == format_kind::any) { |
2673 | CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag)); |
2674 | jcp.wei_tag = wei_tag; |
2675 | } else { |
2676 | jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag); |
2677 | if (jcp.wei_tag != wei_tag) return status::unimplemented; |
2678 | } |
2679 | jcp.wei_dt = diff_weights_d.data_type(); |
2680 | |
2681 | /* kernel applicability check wrt boundaries |
2682 | * the conditions are quite general across the kernels we have, |
2683 | * but ideally the check should belong to a specific kernel... */ |
2684 | const int max_pad_h = jcp.ext_kh / 2; |
2685 | const bool boundaries_ok = true && jcp.l_pad < jcp.ext_kw |
2686 | && jcp.r_pad < jcp.ext_kw && jcp.t_pad <= max_pad_h |
2687 | && jcp.b_pad <= max_pad_h && jcp.f_pad < jcp.ext_kd |
2688 | && jcp.back_pad < jcp.ext_kd; |
2689 | if (!boundaries_ok) return status::unimplemented; |
2690 | |
2691 | jcp.ic_block = 16; |
2692 | jcp.oc_block = 16; |
2693 | |
2694 | jcp.nb_ic = utils::div_up(jcp.ic, jcp.ic_block); |
2695 | jcp.nb_oc = utils::div_up(jcp.oc, jcp.oc_block); |
2696 | |
2697 | jcp.ic_tail = jcp.ic % jcp.ic_block; |
2698 | jcp.oc_tail = jcp.oc % jcp.oc_block; |
2699 | |
2700 | jcp.nb_oc_blocking = (jcp.nb_oc > 1) ? 2 : 1; |
2701 | jcp.nb_ic_blocking = (jcp.nb_ic > 1) ? 2 : 1; |
2702 | |
2703 | const bool is_2d = (ndims == 4); |
2704 | const bool is_3d = (ndims == 5); |
2705 | |
2706 | // TODO: Find more shapes (especially 3D with large spatials) for which |
2707 | // local transposition will be beneficial. Furthermore, for TBB threads |
2708 | // more shapes can potentially benefit from spatial blocking |
2709 | int optimal_blk_size = is_3d ? jcp.od : is_2d ? jcp.oh : jcp.ow; |
2710 | |
2711 | jcp.global_transpose = dnnl_thr_syncable(); |
2712 | jcp.spatial_blk_size = optimal_blk_size; |
2713 | |
2714 | const int tr_round = 32; // To load full tile register |
2715 | int tr_pad |
2716 | = rnd_up(nstl::max(jcp.l_pad, jcp.r_pad + 1), tr_round); //!!! why? |
2717 | jcp.tr_iw = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad, jcp.stride_w), |
2718 | tr_round) |
2719 | * jcp.stride_w; |
2720 | |
    // TODO: only xf16 training is supported
2722 | const auto rnd_val = data_type_vnni_granularity(bf16); |
2723 | jcp.tr_src_num_guard_elems = tr_pad; // upper bound |
2724 | jcp.tr_ow = rnd_up(jcp.ow, rnd_val); |
2725 | if (jcp.tr_ow > tr_round) { |
        // we may increase tr_ow to get a better bd_block in the brgemm kernel
2727 | int best_bdb = jcp.tr_ow / rnd_val; |
2728 | int best_tr_ow = jcp.tr_ow; |
2729 | for (int tr_ow = jcp.tr_ow; tr_ow <= rnd_up(jcp.tr_ow, tr_round); |
2730 | tr_ow += rnd_val) { |
2731 | for (int i = tr_round; i > 0; i -= rnd_val) { |
2732 | if (tr_ow % i == 0) { |
2733 | const auto cbdb = tr_ow / i; |
2734 | if (cbdb < best_bdb) { |
2735 | best_bdb = cbdb; |
2736 | best_tr_ow = tr_ow; |
2737 | } |
2738 | break; |
2739 | } |
2740 | } |
2741 | } |
2742 | jcp.tr_ow = best_tr_ow; |
2743 | } |
2744 | |
2745 | bool args_ok = true && jcp.ic <= src_d.padded_dims()[1] |
2746 | && jcp.oc <= diff_dst_d.padded_dims()[1] |
2747 | && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1] |
2748 | && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0]; |
2749 | if (!args_ok) return status::unimplemented; |
2750 | |
2751 | jcp.harness = ndims == 5 ? harness_3d_reduction : harness_2d_reduction; |
2752 | |
2753 | if (!one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) { |
2754 | return status::unimplemented; |
2755 | } |
2756 | |
2757 | switch (jcp.harness) { |
2758 | case harness_2d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.oh; break; |
2759 | case harness_3d_reduction: jcp.nthr_mb_work = jcp.mb * jcp.od; break; |
2760 | default: assert(!"Invalid harness" ); jcp.nthr_mb_work = jcp.mb; |
2761 | } |
2762 | |
2763 | balance_bwd_w(jcp); |
2764 | |
2765 | if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) { |
2766 | jcp.K = jcp.tr_ow; |
2767 | } |
2768 | jcp.K_tail = 0; |
2769 | |
2770 | jcp.M = jcp.ic_block * jcp.nb_ic_blocking; |
2771 | // assumption that jcp.nb_ic_blocking is always 2 |
2772 | if (jcp.nb_ic % jcp.nthr_ic_b == 0 |
2773 | && (jcp.nb_ic / jcp.nthr_ic_b) % jcp.nb_ic_blocking == 0) |
2774 | jcp.M_tail = 0; |
2775 | else |
2776 | jcp.M_tail = jcp.ic_block; |
2777 | |
2778 | jcp.N = jcp.oc_block * jcp.nb_oc_blocking; |
2779 | // assumption that jcp.nb_oc_blocking is always 2 |
2780 | if (jcp.nb_oc % jcp.nthr_oc_b == 0 |
2781 | && (jcp.nb_oc / jcp.nthr_oc_b) % jcp.nb_oc_blocking == 0) |
2782 | jcp.N_tail = 0; |
2783 | else |
2784 | jcp.N_tail = jcp.oc_block; |
2785 | |
    // For convolutions with big spatials, transpose only a chunk
    // (oc_block * nb_oc_blocking) of diff_dst on each iteration over oc blocks
    // for better cache utilization. An equal number of channel blocks per
    // thread is required by this approach to avoid hangs.
2791 | bool tr_ocb_chunk_allowed = (jcp.nb_oc % jcp.nthr_oc_b == 0); |
2792 | jcp.tr_ocb_chunk = tr_ocb_chunk_allowed && (jcp.oh * jcp.ow > 38 * 38); |
2793 | jcp.tr_icb_chunk = false; |
2794 | |
2795 | const int irow_size = jcp.src_dsz * jcp.tr_iw * jcp.ic_block |
2796 | * div_up(jcp.nb_ic, jcp.nthr_ic_b) |
2797 | * 2 /*we have real and transposed input */; |
2798 | const int orow_size = jcp.dst_dsz * jcp.tr_ow * jcp.oc_block |
2799 | * div_up(jcp.nb_oc, jcp.nthr_oc_b) |
2800 | * 2 /*we have real and transposed diff_dst*/; |
2801 | int oh_block_limit = nstl::max(1.f, |
2802 | nstl::max(0.f, 0.8f * brg_blocking_t::L2 - jcp.kh * irow_size) |
2803 | / (irow_size + orow_size)); |
2804 | // try to split oh by equal oh blocks |
2805 | oh_block_limit = div_up(jcp.oh, div_up(jcp.oh, oh_block_limit)); |
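    // e.g., oh = 30 with a raw limit of 16 becomes
    // div_up(30, div_up(30, 16)) = 15, i.e. two equal blocks of 15 rows
    // instead of 16 + 14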
2806 | jcp.oh_block = utils::saturate(1, jcp.oh, oh_block_limit); |
2807 | |
2808 | const int iframe_size = irow_size * jcp.id; |
2809 | const int oframe_size = orow_size * jcp.od; |
2810 | int od_block_limit = nstl::max(1.f, |
2811 | nstl::max(0.f, 0.8f * brg_blocking_t::L2 - jcp.kd * iframe_size) |
2812 | / (iframe_size + oframe_size)); |
2813 | // try to split od by equal od blocks |
2814 | od_block_limit = div_up(jcp.od, div_up(jcp.od, od_block_limit)); |
2815 | jcp.od_block = utils::saturate(1, jcp.od, od_block_limit); |
2816 | |
2817 | jcp.use_interleave_stores = false; |
2818 | jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; |
2819 | jcp.amx_tile_load_xx = false; |
2820 | |
2821 | if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) { |
2822 | jcp.LDA = jcp.tr_iw; |
2823 | jcp.LDB = jcp.oc_block; |
2824 | jcp.LDC = jcp.LDD = jcp.oc_block; |
2825 | } |
2826 | |
2827 | jcp.gemm_batch_size = jcp.max_batch; |
    // to avoid concurrent cache access from different threads
2829 | size_t sc_size = sizeof(brgemm_batch_element_t); |
2830 | jcp.adjusted_batch_size |
2831 | = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size); |
2832 | |
2833 | return status::success; |
2834 | } |
2835 | |
2836 | status_t init_scratchpad_bwd_w(memory_tracking::registrar_t &scratchpad, |
2837 | const jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md, |
2838 | memory_desc_t &diff_weights_md, memory_desc_t &diff_dst_md) { |
2839 | const memory_desc_wrapper src_d(&src_md); |
2840 | const memory_desc_wrapper diff_weights_d(&diff_weights_md); |
2841 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
2842 | |
2843 | // XXX: See the comment about tr_iw and guarding elements in |
2844 | // jit_avx512_core_amx_bwd_weights_kernel_t::init_conf() |
2845 | const size_t tr_src_size |
2846 | = (jcp.tr_src_buf_count * jcp.tr_src_buf_size * jcp.nb_ic_blocking) |
2847 | + jcp.tr_src_num_guard_elems; |
2848 | scratchpad.book(key_conv_tr_src, tr_src_size, jcp.src_dsz); |
2849 | |
2850 | /* prepare synchronization contexts */ |
2851 | if (jcp.global_transpose && jcp.nthr_oc_b > 1) { |
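        // One barrier context per group of jcp.nthr_oc_b threads that
        // share the same transposed src buffer (jcp.nthr / jcp.nthr_oc_b
        // groups in total).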
2852 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
2853 | scratchpad.book<simple_barrier::ctx_t>( |
2854 | key_conv_tr_src_bctx, tr_src_bctx_size); |
2855 | } |
2856 | |
    // Since tr_ow <= tr_iw, some guarding is needed at the end of diff_dst.
2858 | // TODO: update this guarding: |
2859 | // (jcp.tr_diff_dst_buf_size + jcp.tr_iw * jcp.oc_block) |
    const auto tr_diff_dst_size = jcp.tr_diff_dst_buf_count
            * (jcp.tr_diff_dst_buf_size + jcp.tr_iw * jcp.oc_block)
            * jcp.nb_oc_blocking;
2864 | |
2865 | const size_t min_align = 64; |
2866 | scratchpad.book( |
2867 | key_conv_tr_diff_dst, tr_diff_dst_size, jcp.src_dsz, min_align); |
2868 | |
2869 | /* prepare synchronization contexts */ |
2870 | if (jcp.global_transpose && jcp.nthr_ic_b > 1) { |
2871 | const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b; |
2872 | scratchpad.book<simple_barrier::ctx_t>( |
2873 | key_conv_tr_diff_dst_bctx, tr_diff_dst_bctx_size); |
2874 | } |
2875 | |
2876 | if (IMPLICATION(jcp.nthr_mb == 1, |
2877 | (jcp.with_bias && jcp.bia_dt != data_type::f32) |
2878 | || jcp.wei_dt != data_type::f32)) { |
2879 | const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block |
2880 | * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd; |
2881 | const size_t bia_size |
2882 | = jcp.with_bias * jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
2883 | |
2884 | const int num_wei_buffers |
2885 | = jcp.wei_dt != data_type::f32 ? jcp.nthr_mb : jcp.nthr_mb - 1; |
2886 | const int num_bia_buffers = jcp.with_bias |
2887 | ? (jcp.bia_dt != data_type::f32 ? jcp.nthr_mb : jcp.nthr_mb - 1) |
2888 | : 0; |
2889 | |
2890 | const size_t wei_bia_reduction_size |
2891 | = wei_size * num_wei_buffers + bia_size * num_bia_buffers; |
2892 | |
2893 | scratchpad.book<float>( |
2894 | key_conv_wei_bia_reduction, wei_bia_reduction_size); |
2895 | |
2896 | scratchpad.book<simple_barrier::ctx_t>( |
2897 | key_conv_wei_bia_reduction_bctx, 1); |
2898 | } |
2899 | |
2900 | if (jcp.with_bias |
2901 | && ((jcp.oc % jcp.oc_block != 0) && jcp.bia_dt == data_type::f32)) { |
2902 | scratchpad.book(key_conv_padded_bias, |
2903 | jcp.ngroups * jcp.nb_oc * jcp.oc_block, jcp.bia_dsz); |
2904 | } |
2905 | scratchpad.book(key_conv_amx_tilecfg, 1, 64); // 1 whole cacheline |
2906 | |
    constexpr size_t scratchpad_limit_by_absolute_value = (size_t)32
            << 30; // 32 GiB - TODO: maybe it's too large?
2909 | const size_t scratchpad_limit_by_tensor_sizes = (size_t)64 * jcp.nthr |
2910 | * (src_d.size() + diff_weights_d.size() + diff_dst_d.size()); |
2911 | const size_t scratchpad_limit |
2912 | = nstl::min(scratchpad_limit_by_absolute_value, |
2913 | scratchpad_limit_by_tensor_sizes); |
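    // The effective limit is the stricter of the absolute cap and a bound
    // proportional to the tensor sizes scaled by the thread count;
    // exceeding it below returns unimplemented so dispatching can fall
    // back to a less scratchpad-hungry implementation.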
2914 | |
2915 | scratchpad.book(key_brgemm_primitive_batch, |
2916 | static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size, |
2917 | sizeof(brgemm_batch_element_t), 64, P4K); |
2918 | |
2919 | scratchpad.book( |
2920 | key_conv_amx_tile_buffer, jcp.nthr * 2 * P4K, sizeof(char), 0, P4K); |
2921 | |
2922 | if (scratchpad.size() > scratchpad_limit) |
2923 | return status::unimplemented; |
2924 | else |
2925 | return status::success; |
2926 | } |
2927 | |
2928 | } // namespace brgemm_convolution_utils |
2929 | |
2930 | } // namespace x64 |
2931 | } // namespace cpu |
2932 | } // namespace impl |
2933 | } // namespace dnnl |
2934 | |