1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_types.h" |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/platform.hpp" |
27 | #include "cpu/x64/brgemm/brgemm_utils.hpp" |
28 | #include "cpu/x64/cpu_isa_traits.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
30 | #include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp" |
31 | #include "cpu/x64/jit_generator.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | |
38 | using namespace dnnl::impl::status; |
39 | using namespace dnnl::impl::format_tag; |
40 | using namespace dnnl::impl::memory_tracking::names; |
41 | using namespace dnnl::impl::utils; |
42 | |
43 | using namespace prop_kind; |
44 | using namespace data_type; |
45 | |
46 | namespace brgemm_convolution_bwd_utils { |
47 | |
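// If the user left the memory format undecided (format_kind::any), force
// `tag_value`; otherwise require that the user-provided format matches it.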
48 | inline status_t init_tag(format_tag_t &tag, memory_desc_t &md, |
49 | const memory_desc_wrapper &mdw, const format_tag_t tag_value) { |
50 | if (mdw.format_kind() == format_kind::any) { |
51 | CHECK(memory_desc_init_by_tag(md, tag_value)); |
52 | tag = tag_value; |
53 | } else { |
54 | tag = mdw.matches_one_of_tag(tag_value); |
55 | } |
56 | |
57 | if (tag != tag_value) return status::unimplemented; |
58 | |
59 | return status::success; |
60 | } |
61 | |
62 | bool is_amx(cpu_isa_t isa) { |
63 | return is_superset(isa, avx512_core_amx); |
64 | } |
65 | |
66 | bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr, |
67 | const memory_desc_wrapper &dst_d, bool enable_postops) { |
68 | using namespace injector; |
69 | |
70 | const auto &post_ops = attr.post_ops_; |
71 | |
72 | if (post_ops.len() > 0 && !enable_postops) return false; |
73 | |
74 | return injector::post_ops_ok(post_ops_ok_args_t(jcp.isa, |
75 | {sum, eltwise, binary}, post_ops, &dst_d, |
76 | false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, |
77 | false /*sum_requires_zp_zero*/, |
78 | {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar, |
79 | broadcasting_strategy_t::no_broadcast})); |
80 | } |
81 | |
82 | bool is_groups_ok(jit_brgemm_conv_conf_t &jcp) { |
    // Enable grouped convs for shapes not supported by direct convs:
    // the direct approach supports int8/bf16 grouped conv only when the
    // number of channels per group is a multiple of 4, and bf16 grouped
    // conv with layout nxc only on the jit_bf16 impl.
    // TODO: remove this condition after the restriction on small oc is removed
88 | return jcp.ngroups > 1 |
89 | && IMPLICATION(one_of(jcp.src_dt, u8, s8, bf16, f16), |
90 | jcp.oc % 4 == 0 && jcp.ic % 4 == 0); |
91 | } |
92 | |
93 | status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &diff_dst_md, |
94 | memory_desc_t &weights_md, memory_desc_t &diff_src_md, |
95 | memory_desc_t &bias_md) { |
96 | format_tag_t src_tag, dst_tag, wei_tag; |
97 | dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc); |
98 | |
99 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
100 | const memory_desc_wrapper weights_d(&weights_md); |
101 | const memory_desc_wrapper diff_src_d(&diff_src_md); |
102 | const memory_desc_wrapper bias_d(&bias_md); |
103 | const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; |
104 | |
105 | const bool is_1d = jcp.ndims == 3; |
106 | const bool is_2d = jcp.ndims == 4; |
107 | const bool is_3d = jcp.ndims == 5; |
108 | |
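    // Weight tag names follow the usual oneDNN convention: uppercase letters
    // are outer (blocked) dims and number+lowercase suffixes are the inner
    // blocks, e.g. gIdhwO16o64i4o = groups, outer ic, spatial dhw, outer oc,
    // then inner blocks of 16 oc x 64 ic x 4 oc (the trailing 4o/2o is the
    // VNNI granularity of the weights data type). The branches below differ
    // only in ic block size (64/48/32/16), dimensionality, and data type.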
109 | if (jcp.wei_plain) { |
110 | return status::unimplemented; |
111 | } else { |
112 | jcp.LDB = jcp.ic_block; |
113 | if (jcp.ic_block == 64) { |
114 | if (is_3d) { |
115 | if (jcp.wei_dt == f32) |
116 | wei_tag = with_groups ? gIdhwo64i : Idhwo64i; |
117 | else if (jcp.wei_dt == s8) { |
118 | if (jcp.is_oc_padded) |
119 | wei_tag = with_groups ? gIdhwO16o64i4o : IdhwO16o64i4o; |
120 | else |
121 | wei_tag = with_groups ? gIdhwO64i4o : IdhwO64i4o; |
122 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
123 | if (jcp.is_oc_padded) |
124 | wei_tag = with_groups ? gIdhwO16o64i2o : IdhwO16o64i2o; |
125 | else |
126 | wei_tag = with_groups ? gIdhwO64i2o : IdhwO64i2o; |
127 | } else |
128 | return status::unimplemented; |
129 | } else if (is_1d) { |
130 | if (jcp.wei_dt == f32) |
131 | wei_tag = with_groups ? gIwo64i : Iwo64i; |
132 | else if (jcp.wei_dt == s8) { |
133 | if (jcp.is_oc_padded) |
134 | wei_tag = with_groups ? gIwO16o64i4o : IwO16o64i4o; |
135 | else |
136 | wei_tag = with_groups ? gIwO64i4o : IwO64i4o; |
137 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
138 | if (jcp.is_oc_padded) |
139 | wei_tag = with_groups ? gIwO16o64i2o : IwO16o64i2o; |
140 | else |
141 | wei_tag = with_groups ? gIwO64i2o : IwO64i2o; |
142 | } else |
143 | return status::unimplemented; |
144 | } else { |
145 | assert(is_2d); |
146 | UNUSED(is_2d); |
147 | if (jcp.wei_dt == f32) |
148 | wei_tag = with_groups ? gIhwo64i : Ihwo64i; |
149 | else if (jcp.wei_dt == s8) { |
150 | if (jcp.is_oc_padded) |
151 | wei_tag = with_groups ? gIhwO16o64i4o : IhwO16o64i4o; |
152 | else |
153 | wei_tag = with_groups ? gIhwO64i4o : IhwO64i4o; |
154 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
155 | if (jcp.is_oc_padded) |
156 | wei_tag = with_groups ? gIhwO16o64i2o : IhwO16o64i2o; |
157 | else |
158 | wei_tag = with_groups ? gIhwO64i2o : IhwO64i2o; |
159 | } else |
160 | return status::unimplemented; |
161 | } |
162 | } else if (jcp.ic_block == 48) { |
163 | if (is_3d) { |
164 | if (jcp.wei_dt == f32) |
165 | wei_tag = with_groups ? gIdhwo48i : Idhwo48i; |
166 | else if (jcp.wei_dt == s8) { |
167 | if (jcp.is_oc_padded) |
168 | wei_tag = with_groups ? gIdhwO16o48i4o : IdhwO16o48i4o; |
169 | else |
170 | wei_tag = with_groups ? gIdhwO48i4o : IdhwO48i4o; |
171 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
172 | if (jcp.is_oc_padded) |
173 | wei_tag = with_groups ? gIdhwO16o48i2o : IdhwO16o48i2o; |
174 | else |
175 | wei_tag = with_groups ? gIdhwO48i2o : IdhwO48i2o; |
176 | } else |
177 | return status::unimplemented; |
178 | } else if (is_1d) { |
179 | if (jcp.wei_dt == f32) |
180 | wei_tag = with_groups ? gIwo48i : Iwo48i; |
181 | else if (jcp.wei_dt == s8) { |
182 | if (jcp.is_oc_padded) |
183 | wei_tag = with_groups ? gIwO16o48i4o : IwO16o48i4o; |
184 | else |
185 | wei_tag = with_groups ? gIwO48i4o : IwO48i4o; |
186 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
187 | if (jcp.is_oc_padded) |
188 | wei_tag = with_groups ? gIwO16o48i2o : IwO16o48i2o; |
189 | else |
190 | wei_tag = with_groups ? gIwO48i2o : IwO48i2o; |
191 | } else |
192 | return status::unimplemented; |
193 | } else { |
194 | assert(is_2d); |
195 | UNUSED(is_2d); |
196 | if (jcp.wei_dt == f32) |
197 | wei_tag = with_groups ? gIhwo48i : Ihwo48i; |
198 | else if (jcp.wei_dt == s8) { |
199 | if (jcp.is_oc_padded) |
200 | wei_tag = with_groups ? gIhwO16o48i4o : IhwO16o48i4o; |
201 | else |
202 | wei_tag = with_groups ? gIhwO48i4o : IhwO48i4o; |
203 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
204 | if (jcp.is_oc_padded) |
205 | wei_tag = with_groups ? gIhwO16o48i2o : IhwO16o48i2o; |
206 | else |
207 | wei_tag = with_groups ? gIhwO48i2o : IhwO48i2o; |
208 | } else |
209 | return status::unimplemented; |
210 | } |
211 | } else if (jcp.ic_block == 32) { |
212 | if (is_3d) { |
213 | if (jcp.wei_dt == f32) |
214 | wei_tag = with_groups ? gIdhwo32i : Idhwo32i; |
215 | else if (jcp.wei_dt == s8) { |
216 | if (jcp.is_oc_padded) |
217 | wei_tag = with_groups ? gIdhwO16o32i4o : IdhwO16o32i4o; |
218 | else |
219 | wei_tag = with_groups ? gIdhwO32i4o : IdhwO32i4o; |
220 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
221 | if (jcp.is_oc_padded) |
222 | wei_tag = with_groups ? gIdhwO16o32i2o : IdhwO16o32i2o; |
223 | else |
224 | wei_tag = with_groups ? gIdhwO32i2o : IdhwO32i2o; |
225 | } else |
226 | return status::unimplemented; |
227 | } else if (is_1d) { |
228 | if (jcp.wei_dt == f32) |
229 | wei_tag = with_groups ? gIwo32i : Iwo32i; |
230 | else if (jcp.wei_dt == s8) { |
231 | if (jcp.is_oc_padded) |
232 | wei_tag = with_groups ? gIwO16o32i4o : IwO16o32i4o; |
233 | else |
234 | wei_tag = with_groups ? gIwO32i4o : IwO32i4o; |
235 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
236 | if (jcp.is_oc_padded) |
237 | wei_tag = with_groups ? gIwO16o32i2o : IwO16o32i2o; |
238 | else |
239 | wei_tag = with_groups ? gIwO32i2o : IwO32i2o; |
240 | } else |
241 | return status::unimplemented; |
242 | } else { |
243 | assert(is_2d); |
244 | UNUSED(is_2d); |
245 | if (jcp.wei_dt == f32) |
246 | wei_tag = with_groups ? gIhwo32i : Ihwo32i; |
247 | else if (jcp.wei_dt == s8) { |
248 | if (jcp.is_oc_padded) |
249 | wei_tag = with_groups ? gIhwO16o32i4o : IhwO16o32i4o; |
250 | else |
251 | wei_tag = with_groups ? gIhwO32i4o : IhwO32i4o; |
252 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
253 | if (jcp.is_oc_padded) |
254 | wei_tag = with_groups ? gIhwO16o32i2o : IhwO16o32i2o; |
255 | else |
256 | wei_tag = with_groups ? gIhwO32i2o : IhwO32i2o; |
257 | } else |
258 | return status::unimplemented; |
259 | } |
260 | } else { |
261 | if (is_3d) { |
262 | if (jcp.wei_dt == f32) |
263 | wei_tag = with_groups ? gIdhwo16i : Idhwo16i; |
264 | else if (jcp.wei_dt == s8) { |
265 | if (jcp.is_oc_padded) |
266 | wei_tag = with_groups ? gIdhwO16o16i4o : IdhwO16o16i4o; |
267 | else |
268 | wei_tag = with_groups ? gIdhwO16i4o : IdhwO16i4o; |
269 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
270 | if (jcp.is_oc_padded) |
271 | wei_tag = with_groups ? gIdhwO16o16i2o : IdhwO16o16i2o; |
272 | else |
273 | wei_tag = with_groups ? gIdhwO16i2o : IdhwO16i2o; |
274 | } else |
275 | return status::unimplemented; |
276 | } else if (is_1d) { |
277 | if (jcp.wei_dt == f32) |
278 | wei_tag = with_groups ? gIwo16i : Iwo16i; |
279 | else if (jcp.wei_dt == s8) { |
280 | if (jcp.is_oc_padded) |
281 | wei_tag = with_groups ? gIwO16o16i4o : IwO16o16i4o; |
282 | else |
283 | wei_tag = with_groups ? gIwO16i4o : IwO16i4o; |
284 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
285 | if (jcp.is_oc_padded) |
286 | wei_tag = with_groups ? gIwO16o16i2o : IwO16o16i2o; |
287 | else |
288 | wei_tag = with_groups ? gIwO16i2o : IwO16i2o; |
289 | } else |
290 | return status::unimplemented; |
291 | } else { |
292 | assert(is_2d); |
293 | UNUSED(is_2d); |
294 | |
295 | if (jcp.wei_dt == f32) |
296 | wei_tag = with_groups ? gIhwo16i : Ihwo16i; |
297 | else if (jcp.wei_dt == s8) { |
298 | if (jcp.is_oc_padded) |
299 | wei_tag = with_groups ? gIhwO16o16i4o : IhwO16o16i4o; |
300 | else |
301 | wei_tag = with_groups ? gIhwO16i4o : IhwO16i4o; |
302 | } else if (one_of(jcp.wei_dt, bf16, f16)) { |
303 | if (jcp.is_oc_padded) |
304 | wei_tag = with_groups ? gIhwO16o16i2o : IhwO16o16i2o; |
305 | else |
306 | wei_tag = with_groups ? gIhwO16i2o : IhwO16i2o; |
307 | } else |
308 | return status::unimplemented; |
309 | } |
310 | } |
311 | } |
312 | |
313 | src_tag = dst_tag; |
314 | |
315 | CHECK(init_tag(jcp.src_tag, diff_dst_md, diff_dst_d, src_tag)); |
316 | CHECK(init_tag(jcp.dst_tag, diff_src_md, diff_src_d, dst_tag)); |
317 | CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag)); |
318 | |
319 | return status::success; |
320 | } |
321 | |
322 | struct brg_blocking_t : public jit_brgemm_conv_conf_t { |
323 | struct array_in_loop_t { |
324 | dim_t itersize; |
325 | float repeatn; |
326 | float overlap; |
327 | void set(dim_t iter_s, float rpt, float ovlp = 1.f) { |
328 | itersize = iter_s; |
329 | repeatn = rpt; |
330 | overlap = ovlp; |
331 | } |
332 | }; |
333 | |
334 | struct loop_t { |
335 | array_in_loop_t src; |
336 | array_in_loop_t wei; |
337 | array_in_loop_t dst; |
338 | }; |
339 | |
340 | brg_blocking_t() { |
341 | // TODO: This is a broken form of initialization for a base class. |
342 | // Either set default values in a base class, or provide a proper |
343 | // default ctor, or take a `jit_brgemm_conv_conf_t` object to initialize |
344 | // a base class object. |
345 | jit_brgemm_conv_conf_t *base |
346 | = static_cast<jit_brgemm_conv_conf_t *>(this); |
347 | *base = jit_brgemm_conv_conf_t(); |
348 | init(); |
349 | } |
350 | brg_blocking_t(const jit_brgemm_conv_conf_t &jcp) |
351 | : jit_brgemm_conv_conf_t(jcp) { |
352 | init(); |
353 | } |
354 | void init() { |
355 | ur = 0; |
356 | ur_block = 0; |
357 | ur_block_tail = 0; |
358 | eff = 0.f; |
359 | nb_kd = 0; |
360 | nb_kh = 0; |
361 | nb_kw = 0; |
362 | sp = 0; |
363 | sp_block = 0; |
364 | nb_sp = 0; |
366 | } |
367 | |
368 | int ur, ur_block, ur_block_tail; |
369 | int nb_kd, nb_kh, nb_kw; |
370 | float eff; |
371 | static unsigned L1; |
372 | static unsigned L2; |
373 | static unsigned L3; |
    // Rough relative estimates of the access latency to the various cache
    // levels; sufficient for estimating the data access cost.
376 | // TODO: Improve memory access estimates |
377 | static constexpr float L1_k = 1.f; |
378 | static constexpr float L2_k = 3.f; |
379 | static constexpr float L3_k = 15.f; |
380 | // TODO: At the moment, we are primarily evaluating the fit of the data into |
381 | // the L1/L2. Need to take into account the difference between the L3 and |
382 | // memory. |
383 | static constexpr float mem_k = 15.f; |
384 | static constexpr int bench_iterations = 1; |
385 | static constexpr int max_regs = 32; |
386 | static constexpr int bcast_simd = 16; |
387 | |
388 | int sp, sp_block, nb_sp; |
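    // VNNI granularity of the weights data type: 1 for f32, 2 for bf16/f16,
    // 4 for s8 (set from data_type_vnni_granularity() in init_jcp).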
389 | static int last_oc_block_size; |
390 | |
391 | void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; } |
392 | void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; } |
393 | |
394 | status_t estimate_brgemm_ur(); |
395 | status_t get_brgemm_ur( |
396 | const primitive_attr_t *attr, const memory_desc_t &dst_md); |
397 | |
398 | float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk, |
399 | bool is_broadcast, bool is_shared) const; |
400 | |
401 | float io_k(const loop_t loop, const array_in_loop_t arr, float pk, |
402 | bool is_broadcast, bool is_shared) const; |
403 | |
404 | void select_oc_block(); |
405 | |
406 | void update_blocks(); |
407 | bool fast_check_ic_block() const; |
408 | float est_eff(); |
409 | void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block, |
410 | int kh_block, bool maybe_use_buffer, int max_iw_block_thr); |
411 | status_t calc_blocks(); |
412 | |
413 | bool fast_check_ic_block_1x1() const; |
414 | float est_eff_1x1(); |
415 | |
416 | // utils |
417 | static int get_inp_size( |
418 | int max_src_size, int dst_size, int k, int stride, int dilate) { |
419 | auto adj_str = nstl::min(k, stride); |
420 | const auto res = nstl::min(max_src_size, |
421 | calculate_end_padding(0, dst_size, 0, adj_str, |
422 | calculate_extended_filter_size(k, dilate))); |
423 | return res; |
424 | } |
425 | |
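    // Number of diff_dst (kernel input) elements needed to compute `out_size`
    // diff_src elements, given the stride phase of the padding and the filter
    // overhang.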
426 | static int get_inp_block_size( |
427 | int out_size, int stride, int ext_k, int padding) { |
428 | const auto res = div_up(out_size + padding % stride, stride) |
429 | + (ext_k - 1 - padding % stride) / stride; |
430 | return res; |
431 | } |
432 | |
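    // Dampens an efficiency factor: with koeff < 1 `eff` is pulled toward 1
    // (reducing its influence on the product of factors), with koeff > 1 it
    // is amplified.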
433 | static float squeeze_val(float eff, float koeff) { |
434 | if (koeff <= 0) return 1; |
435 | if (koeff == 1) return eff; |
436 | const auto k = 1.f / koeff; |
437 | return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff; |
438 | } |
439 | |
440 | static int estimate_ur(int ic_block) { |
441 | const auto est_ur = (ic_block == 64) |
442 | ? 6 |
443 | : ((ic_block == 48) ? 9 : ((ic_block == 32) ? 14 : 28)); |
444 | return est_ur; |
445 | } |
446 | |
447 | int inp_w(int out_w, int ker_w) const { |
448 | return get_inp_size(ow, out_w, ker_w, stride_w, dilate_w); |
449 | } |
450 | |
451 | int rnd_simd(int val) const { return rnd_up(val, simd_w); } |
452 | |
453 | int rnd_inp_simd(int out_w, int ker_w, int voc) const { |
454 | const auto vsp = inp_w(out_w, ker_w); |
455 | return ((stride_w == 1 && voc >= oc) ? rnd_up(vsp * voc, simd_w) |
456 | : vsp * rnd_up(voc, simd_w)); |
457 | } |
458 | |
459 | static constexpr int MAXNLOOPS = 32; |
460 | loop_t loop[MAXNLOOPS]; |
461 | }; |
462 | |
463 | unsigned brg_blocking_t::L1; |
464 | unsigned brg_blocking_t::L2; |
465 | unsigned brg_blocking_t::L3; |
466 | int brg_blocking_t::last_oc_block_size; |
467 | |
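// Average cost of one access to an array that is touched `n` times at a loop
// level: the first access costs `pk` (the cost inherited from the outer
// level), and the remaining n - 1 accesses hit the cheapest cache level the
// working set fits into.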
468 | float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk, |
469 | bool is_broadcast, bool is_shared) const { |
470 | if (n < 1) return 0; |
471 | if (n == 1) return pk; |
472 | const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz |
473 | + (use_buffer ? dst * acc_dsz : 0); |
474 | const auto amount_L1 = is_broadcast ? src * src_dsz : amount; |
475 | const auto k = is_broadcast |
476 | ? ((amount_L1 < L1) ? L1_k |
477 | : ((amount < L2) ? L2_k |
478 | : (is_shared ? L3_k : mem_k))) |
479 | : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k)); |
480 | const auto cost = pk + k * (n - 1); |
481 | return cost / n; |
482 | } |
483 | |
484 | float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr, |
485 | float pk, bool is_broadcast, bool is_shared) const { |
486 | return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize, |
487 | arr.repeatn * arr.overlap, pk, is_broadcast, is_shared); |
488 | } |
489 | |
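// oc is not split into multiple blocks: oc_block covers the whole
// (VNNI-padded) oc range, so nb_oc always ends up as 1.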
490 | void brg_blocking_t::select_oc_block() { |
491 | const auto padded_oc = last_oc_block_size * (is_oc_padded ? 16 : 1); |
492 | oc_block = rnd_up(oc, padded_oc); |
493 | nb_oc = utils::div_up(oc, oc_block); |
494 | } |
495 | |
496 | status_t brg_blocking_t::estimate_brgemm_ur() { |
497 | // Simple simulation of brgemm_desc init |
498 | if (sp_block <= 0) return status::invalid_arguments; |
499 | LDA = oc_block; |
500 | LDB = ic_block; |
501 | LDC = use_buffer ? ic_block : stride_w * ic_without_padding; |
502 | |
    // Configure matrix sizes.
    // For AMX, if oc_block != oc then exec_trans is used, so K is oc_block.
505 | const auto padded_oc = last_oc_block_size * (is_oc_padded ? 16 : 1); |
506 | |
507 | ocp = rnd_up(oc, padded_oc); |
508 | |
509 | const auto adj_sp = div_up(iw_block, stride_w); |
510 | M = brgM = adj_sp >= sp_block ? sp_block : 0; |
511 | M_tail = brgM_tail = adj_sp % sp_block; |
512 | |
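    // GEMM dimension mapping for bwd-d: M is the (stride-adjusted) spatial
    // block, N the diff_src channel block (ic), and K the reduction over
    // diff_dst channels (oc).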
513 | N = ic >= ic_block ? ic_block : 0; |
514 | N_tail = ic % ic_block; |
515 | K = oc >= oc_block ? oc_block : 0; |
    K_tail = oc % oc_block;
517 | |
518 | const auto vK = K > 0 ? K : K_tail; |
519 | const auto vM = M > 0 ? M : M_tail; |
520 | const auto vN = N > 0 ? N : N_tail; |
521 | |
522 | const float alpha = 1.0; |
523 | const float beta = 0.0; |
524 | brgemm_t brg; |
525 | brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, |
526 | brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr, |
527 | is_bf32); |
528 | CHECK(brgemm_utils::brgemm_blocking(&brg)); |
529 | ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1); |
530 | if (ur == 0) return status::invalid_arguments; |
531 | ur_block = brg.bd_block; |
532 | if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) { |
533 | brgemm_t brg_sp_tail; |
534 | brgemm_utils::init_brgemm_conf(&brg_sp_tail, isa, brgemm_addr, src_dt, |
535 | wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, M_tail, |
536 | vN, vK, nullptr, is_bf32); |
537 | CHECK(brgemm_utils::brgemm_blocking(&brg_sp_tail)); |
538 | ur_block_tail = brg_sp_tail.bd_block; |
539 | } else { |
540 | ur_block_tail = 0; |
541 | } |
542 | return status::success; |
543 | } |
544 | |
545 | status_t brg_blocking_t::get_brgemm_ur( |
546 | const primitive_attr_t *attr, const memory_desc_t &dst_md) { |
547 | // Detailed simulation of brgemm convolution init |
548 | if (sp_block <= 0 || oc_block <= 0 || ic_block <= 0) |
549 | return status::invalid_arguments; |
550 | CHECK(estimate_brgemm_ur()); |
551 | |
552 | LDD = stride_w * ic_without_padding; |
553 | |
554 | const float alpha = 1.0; |
555 | const float beta = 1.0; |
556 | const float beta_init = 0.0; |
557 | |
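    // Instantiate every (beta, N, K) descriptor variant the convolution may
    // need: i_init toggles beta (overwrite vs accumulate), i_N and i_K toggle
    // the tail sizes of N and K.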
558 | for (int i = 0; i < M; i++) { |
559 | auto vM = i + 1; |
560 | // init only needed brgemm descriptors |
561 | if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1) |
562 | && vM != M && vM != M_tail) |
563 | continue; |
564 | for (int i_init = 0; i_init < 2; i_init++) { |
565 | for (int i_N = 0; i_N < 2; i_N++) { |
566 | for (int i_K = 0; i_K < 2; i_K++) { |
567 | auto vbeta = (i_init) ? beta_init : beta; |
568 | auto vN = (i_N) ? N_tail : N; |
569 | auto vK = (i_K) ? K_tail : K; |
570 | if (vN == 0 || vK == 0) continue; |
571 | brgemm_t brg; |
572 | brgemm_strides_t brg_strides; |
573 | brg_strides.stride_a = ngroups * oc_without_padding |
574 | * (dilate_w + 1) * src_dsz; |
                    // weights are padded along ic (to ic_block) and oc (to
                    // last_oc_block_size)
576 | brg_strides.stride_b = rnd_up(oc, last_oc_block_size) |
577 | * rnd_up(ic, ic_block) * wei_dsz; |
578 | const auto strides_ptr = (brg_type == brgemm_strd) |
579 | ? &brg_strides |
580 | : nullptr; |
581 | brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt, |
582 | wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, |
583 | LDC, vM, vN, vK, strides_ptr, is_bf32); |
584 | CHECK(brgemm_utils::brgemm_blocking(&brg)); |
585 | |
586 | brgemm_attr_t brgattr; |
587 | brgattr.max_bs = max_batch; |
588 | const auto max_vpad = (exec_type == exec_vpad) |
589 | ? nstl::max(l_pad, r_pad) |
590 | : 0; |
591 | brgattr.max_top_vpad = max_vpad; |
592 | brgattr.max_bottom_vpad = max_vpad; |
593 | brgattr.fpmath_mode = attr->fpmath_mode_; |
594 | CHECK(brgemm_desc_set_attr(&brg, brgattr)); |
595 | |
596 | brg.with_sum = with_sum; |
597 | CHECK(brgemm_desc_set_postops( |
598 | &brg, attr, &dst_md, LDD, bia_dt)); |
599 | } |
600 | } |
601 | } |
602 | } |
603 | |
604 | return status::success; |
605 | } |
606 | |
607 | void brg_blocking_t::update_blocks() { |
608 | if (sp_block <= 0 |
609 | || utils::one_of(0, id_block, ih_block, oc_block, ic_block, |
610 | kd_block, kh_block, kw_block, is_block, iw_block)) |
611 | return; |
612 | |
613 | nb_id = div_up(id, id_block); |
614 | nb_ih = div_up(ih, ih_block); |
615 | nb_oc = div_up(oc, oc_block); |
616 | nb_ic = div_up(ic, ic_block); |
617 | nb_kd = div_up(kd, kd_block); |
618 | nb_kh = div_up(kh, kh_block); |
619 | nb_kw = div_up(kw, kw_block); |
620 | nb_iw = div_up(iw, iw_block); |
621 | |
622 | sp = iw; |
623 | sp_block = iw_block; |
624 | nb_sp = nb_iw; |
625 | |
626 | ow_block = get_inp_block_size(iw_block, stride_w, ext_kw, l_pad); |
627 | oh_block = get_inp_block_size(ih_block, stride_h, ext_kh, t_pad); |
628 | od_block = get_inp_block_size(id_block, stride_d, ext_kd, f_pad); |
629 | } |
630 | |
631 | bool brg_blocking_t::fast_check_ic_block() const { |
632 | // This function is for reducing the number of blocking variants |
633 | // TODO: eliminate heuristic in this function |
634 | if (is_1x1) return fast_check_ic_block_1x1(); |
635 | const auto rnd_ic = rnd_up(ic, 16); |
636 | auto res = false; |
637 | if (ic_block == 64) { |
638 | res = (rnd_ic % ic_block == 0 && rnd_ic * wei_dsz < 192 * 4); |
639 | } else if (ic_block == 48) { |
640 | // TODO: edit this heuristic for bwd_d |
641 | const bool big_spatial |
642 | = od * oh * ow > 81 * stride_d * stride_h * stride_w; |
643 | res = (rnd_ic % ic_block == 0 && rnd_ic * wei_dsz <= 384 * 4 |
644 | && big_spatial); |
645 | } else |
646 | res = true; |
647 | |
648 | return res; |
649 | } |
650 | |
651 | float brg_blocking_t::est_eff() { |
652 | if (is_1x1) return est_eff_1x1(); |
653 | const auto icblock = ic_block / 16; |
654 | |
655 | const auto brgemm_microkernel_eff |
656 | = (static_cast<float>(icblock) * ur) / ((ur + icblock) * max_regs); |
657 | |
658 | const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur); |
659 | const auto brgemm_eff = squeeze_val(ur |
660 | * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block)) |
661 | / 64, |
662 | 0.5f); |
663 | |
664 | const auto sp_amount = nb_id * nb_ih * nb_sp; |
665 | const auto work_amount = mb * ngroups * nb_ic * sp_amount; |
666 | const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block)); |
667 | |
668 | const auto thr_eff = static_cast<float>(work_amount) |
669 | / utils::rnd_up(work_amount, nthr); |
670 | |
671 | const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block); |
672 | |
673 | const auto job = div_up(work_amount, nthr); |
674 | |
675 | auto job_eff = 1.f; |
676 | if (job < nthr) { |
677 | std::vector<dim_t> thr_jobs(nthr); |
678 | |
679 | for (int ithr = 0; ithr < nthr; ithr++) { |
680 | thr_jobs[ithr] = 0; |
681 | if (ithr >= work_amount) continue; |
682 | dim_t thr_job = 0; |
683 | int start {0}, end {0}; |
684 | balance211(work_amount, nthr, ithr, start, end); |
685 | int n {0}, g {0}, icb {0}, idp {0}, ihp {0}, spb {0}; |
686 | nd_iterator_init(start, n, mb, idp, id, ihp, ih, spb, nb_sp, g, |
687 | ngroups, icb, nb_ic); |
688 | |
689 | for (auto work = start; work < end; work++) { |
690 | const int icp = icb * ic_block; |
691 | const auto ic_sz = nstl::min(ic - icp, ic_block); |
692 | int sp_sz = 0; |
693 | const int spp = spb * sp_block; |
694 | sp_sz = nstl::min(sp - spp, sp_block); |
695 | thr_job += sp_sz * ic_sz; |
696 | |
697 | nd_iterator_step(n, mb, idp, id, ihp, ih, spb, nb_sp, g, |
698 | ngroups, icb, nb_ic); |
699 | } |
700 | thr_jobs[ithr] = thr_job; |
701 | } |
702 | |
703 | dim_t max_job = 0; |
704 | dim_t sum_job = 0; |
705 | for (int ithr = 0; ithr < nthr; ithr++) { |
706 | if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr]; |
707 | sum_job += thr_jobs[ithr]; |
708 | } |
709 | job_eff = max_job == 0 ? 1 |
710 | : static_cast<float>(sum_job) / (max_job * nthr); |
711 | |
712 | } else { |
713 | job_eff = thr_eff; |
714 | } |
715 | |
716 | const auto oc_blocking_size = oc_block * nb_oc_blocking; |
717 | const auto ic_blocking_size = ic_block * oc_blocking_size; |
718 | |
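    // Build the loop-nest model from the innermost (brgemm microkernel) level
    // outward to the mb loop; io_k() then walks it outermost-first to
    // estimate the per-array memory access costs.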
719 | int l = -1; |
720 | |
721 | // -- brgemm kernel: loop by simd_w -- |
722 | l++; |
723 | const auto inp_ur = inp_w(ur, kw_block); |
724 | loop[l].src.set(inp_ur * simd_w, 1, bcast_simd); |
725 | loop[l].dst.set(0, 1); |
726 | loop[l].wei.set(ic_block, 1); |
727 | |
728 | // -- brgemm kernel: loop by kw in kw_block -- |
729 | l++; |
730 | auto src_is = rnd_inp_simd(ur, kw_block, oc_blocking_size); |
731 | loop[l].src.set(src_is, 1, kw_block); |
732 | loop[l].dst.set(0, 1); |
733 | loop[l].wei.set(ic_blocking_size, 1); |
734 | |
735 | // -- brgemm kernel: loop by batch (grouped by kw_block) in ur -- |
736 | l++; |
737 | loop[l].src.set(src_is, 1); |
738 | loop[l].dst.set(0, 1); |
739 | auto wei_is = kw_block * ic_blocking_size; |
740 | loop[l].wei.set(wei_is, 1); |
741 | // -- brgemm kernel: loop by ur in sp_block -- |
742 | l++; |
743 | const auto nb_ur = div_up(sp_block, ur); |
744 | loop[l].src.set(kd_block * kh_block * src_is, 1); |
745 | loop[l].dst.set(ur * ic_block, 1); |
746 | wei_is = kd_block * kh_block * kw_block * ic_blocking_size; |
747 | loop[l].wei.set(wei_is, nb_ur); |
748 | |
749 | // -- harness: loop by k_blocks in ks -- |
750 | l++; |
751 | loop[l].src.set(kd_block * kh_block |
752 | * rnd_inp_simd(sp_block, kw_block, oc_blocking_size), |
753 | 1); |
754 | loop[l].dst.set(sp_block * ic_block, nb_kd * nb_kh * nb_kw); |
755 | loop[l].wei.set(wei_is, 1); |
756 | |
757 | // -- brgemm kernel: loop by oc_chunks -- |
758 | l++; |
759 | const auto oc_chunks = div_up(nb_oc, nb_oc_blocking); |
760 | loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, oc_blocking_size), 1); |
761 | loop[l].dst.set(sp_block * ic_block, oc_chunks); |
762 | wei_is = kd * kh * kw * ic_blocking_size; |
763 | loop[l].wei.set(wei_is, 1); |
764 | |
765 | const auto dim_ic = 1; |
766 | const auto nb_ic_thr = nstl::min(nb_ic, div_up(job, dim_ic)); |
767 | const auto ic_thr = nstl::min(ic, nb_ic_thr * ic_block); |
768 | const auto nsimd_ic_thr = div_up(ic_thr, simd_w); |
769 | |
770 | const auto dim_sp = ngroups * nb_ic; |
771 | const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp)); |
772 | const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block); |
773 | |
774 | const auto dim_ih = nb_sp * dim_sp; |
775 | const int nb_ih_thr = nstl::min(nb_ih, div_up(job, dim_ih)); |
776 | const int ih_thr = nstl::min(ih, nb_ih_thr * ih_block); |
777 | |
778 | const auto dim_id = nb_ih * dim_ih; |
779 | const int nb_id_thr = nstl::min(nb_id, div_up(job, dim_id)); |
780 | const int id_thr = nstl::min(id, nb_id_thr * id_block); |
781 | |
782 | src_is = kd * kh * rnd_inp_simd(sp_block, kw, oc); |
783 | |
784 | auto wei_op = kd * kh * kw * icblock * oc; |
785 | |
786 | // -- harness: loop by ic_block -- |
787 | l++; |
788 | loop[l].src.set(src_is, nb_ic_thr); |
789 | loop[l].dst.set(sp_block * ic_block, 1); |
790 | wei_is = kd * kh * kw * ic_block * oc; |
791 | wei_op = kd * kh * kw * nsimd_ic_thr * oc; |
792 | loop[l].wei.set(wei_is, 1); |
793 | |
794 | // -- harness: loop by sp_blocks -- |
795 | l++; |
796 | loop[l].src.set(src_is, 1); |
797 | const auto rnd_ic_for_sp = simd_w * nsimd_ic_thr; |
798 | loop[l].dst.set(sp_block * rnd_ic_for_sp, 1); |
799 | loop[l].wei.set(wei_op * simd_w, nb_sp_thr); |
    // ih_block is almost always 1. TODO: handle ih_block != 1
    // -- harness: loop by ih_blocks --
802 | l++; |
803 | src_is = kd * kh * rnd_inp_simd(sp_thr, kw, oc); |
804 | loop[l].src.set(ih_block * src_is, 1); |
805 | loop[l].dst.set(sp_thr * rnd_ic_for_sp, 1); |
806 | loop[l].wei.set(wei_op * simd_w, nb_ih_thr); |
    // id_block is almost always 1. TODO: handle id_block != 1
    // -- harness: loop by id_blocks --
809 | l++; |
810 | loop[l].src.set(id_block * ih_thr * src_is, 1); |
811 | loop[l].dst.set(ih_thr * sp_thr * rnd_ic_for_sp, 1); |
812 | loop[l].wei.set(wei_op * simd_w, nb_id_thr); |
813 | |
814 | // -- harness: loop by mb -- |
815 | l++; |
816 | const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_ic)); |
817 | loop[l].src.set(id_thr * ih_thr * src_is, 1); |
818 | loop[l].dst.set(id_thr * ih_thr * sp_thr * nsimd_ic_thr * simd_w, 1); |
819 | loop[l].wei.set(kd * kh * kw * nsimd_ic_thr * simd_w * oc, mb_thr); |
820 | |
821 | const auto src_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr |
822 | * kd * kh * kw * oc; |
823 | const auto dst_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr |
824 | * nsimd_ic_thr; |
825 | wei_op = kd * kh * kw * nsimd_ic_thr * oc; |
826 | |
827 | // for "real" application set bench_iterations to 1 |
828 | const auto iterations = bench_iterations; |
829 | l++; |
830 | loop[l].src.set(src_op, iterations); |
831 | loop[l].dst.set(dst_op * simd_w, iterations); |
832 | loop[l].wei.set(wei_op * simd_w, iterations); |
833 | |
834 | auto src_mem_k = mem_k; |
835 | auto dst_mem_k = mem_k; |
836 | auto wei_mem_k = mem_k; |
837 | float src_rp = 1; |
838 | float dst_rp = 1; |
839 | float wei_rp = 1; |
840 | |
841 | for (auto il = l; il >= 0; il--) { |
842 | src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false); |
843 | dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false); |
844 | wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true); |
845 | src_rp *= loop[il].src.repeatn; |
846 | dst_rp *= loop[il].dst.repeatn; |
847 | wei_rp *= loop[il].wei.repeatn; |
848 | } |
849 | const auto src_ops = (src_op * src_rp) / iterations; |
850 | const auto dst_ops = (dst_op * dst_rp) / iterations; |
851 | const auto wei_ops = (wei_op * wei_rp) / iterations; |
852 | |
853 | const auto src_cost = src_mem_k * src_ops; |
854 | const auto dst_cost = dst_mem_k * dst_ops; |
855 | const auto wei_cost = wei_mem_k * wei_ops; |
856 | const auto call_kernel_cost = job * oc_chunks * nb_kd * nb_kh * nb_kw; |
857 | |
858 | const auto cache_eff = (static_cast<dim_t>(mb) * id * ih * sp * oc * ic) |
859 | / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost)); |
860 | const auto res_eff = ic_block_eff * brgemm_microkernel_eff * sp_eff |
861 | * job_eff * ur_eff * cache_eff * brgemm_eff; |
862 | return res_eff; |
863 | } |
864 | |
865 | void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_, |
866 | int kh_block_, bool maybe_use_buffer, int max_iw_block_thr) { |
867 | kd_block = kd_block_; |
868 | kh_block = kh_block_; |
869 | |
870 | kw_block = kw; |
871 | kd_block_pad = kd_block; |
872 | kh_block_pad = kh_block; |
873 | kw_block_pad = kw_block; |
874 | |
875 | const auto w_block_size = 2 * src_dsz * oc * owp + dst_dsz * iw * ic_block; |
876 | const auto other_size = wei_dsz * kd * kh * kw * oc * ic_block |
877 | + acc_dsz * 2 * amx_h * ic_block; |
878 | const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)), |
879 | other_size > L2 ? 0 : L2 - other_size); |
880 | if (odp * ohp * w_block_size > L2_available) { |
881 | id_block = utils::saturate( |
882 | 1, id, int(L2_available / (ohp * w_block_size))); |
883 | if (id_block == 1) |
884 | ih_block = utils::saturate( |
885 | 1, ih, int(L2_available / (w_block_size))); |
886 | else |
887 | ih_block = ih; |
888 | } else { |
889 | id_block = 1; |
890 | ih_block = ih; |
891 | } |
892 | if (is_amx(isa)) { |
893 | // try to fit into L1 |
894 | bool L1_fit_res = false; |
895 | auto cur_id_block = id_block; |
896 | auto cur_ih_block = ih_block; |
897 | const auto src_w_block_size |
898 | = src_dsz * oc * owp + dst_dsz * iw * ic_block; |
899 | if (src_w_block_size < L1) { |
900 | cur_id_block = utils::saturate( |
901 | 1, id, int(L1 / (ohp * src_w_block_size))); |
902 | if (cur_id_block == 1) |
903 | cur_ih_block |
904 | = utils::saturate(1, ih, int(L1 / (src_w_block_size))); |
905 | } |
906 | for (; cur_id_block > 1; cur_id_block--) { |
907 | const auto sp_size = cur_id_block * cur_ih_block * owp; |
908 | if ((static_cast<float>(id) / rnd_up(id, cur_id_block)) > 0.9f |
909 | && static_cast<float>(sp_size) / rnd_up(sp, amx_h) > 0.8f) { |
910 | L1_fit_res = true; |
911 | break; |
912 | } |
913 | } |
914 | if (cur_id_block == 1) { |
915 | for (; cur_ih_block > 1; cur_ih_block--) { |
916 | const auto sp_size = cur_ih_block * owp; |
917 | if ((static_cast<float>(ih) / rnd_up(ih, cur_ih_block)) > 0.9f |
918 | && sp_size > 128) { |
919 | L1_fit_res = true; |
920 | break; |
921 | } |
922 | } |
923 | } |
924 | if (L1_fit_res) { |
925 | id_block = cur_id_block; |
926 | ih_block = cur_ih_block; |
927 | } |
928 | } |
929 | |
930 | // limit ih_block to have good threading |
931 | const auto thr_ic_block |
932 | = div_up(nthr, mb * div_up((ic > 32 ? ngroups : 1) * ic, ic_block)); |
933 | const auto thr_id_block = div_up(id, thr_ic_block); |
934 | const auto thr_ih_block |
935 | = div_up(ih, thr_ic_block * div_up(id, thr_id_block)); |
936 | id_block = nstl::min(id_block, thr_id_block); |
937 | ih_block = nstl::min(ih_block, thr_ih_block); |
938 | while ((id_block % stride_d != 0 || id % id_block != 0) && id_block < id) |
939 | id_block++; |
940 | while ((ih_block % stride_h != 0 || ih % ih_block != 0) && ih_block < ih) |
941 | ih_block++; |
942 | |
    // --- Select iw_block ---
944 | const auto max_iw_block_L2 = iw; |
945 | auto start_iw_block = nstl::min(max_iw_block_thr, max_iw_block_L2); |
946 | |
947 | sp = iw; |
948 | const auto start_sp_block = start_iw_block; |
949 | auto prev_spb = 0; |
950 | for (auto ns = 1; ns <= sp; ns++) { |
951 | const auto spb = div_up(sp, ns); |
952 | if (spb == prev_spb || spb > start_sp_block) continue; |
953 | if (spb % stride_w != 0) continue; |
954 | if (iw % spb != 0) continue; |
955 | |
956 | prev_spb = spb; |
957 | iw_block = spb; |
958 | sp_block = iw_block; |
959 | |
960 | select_oc_block(); |
961 | |
962 | use_buffer = maybe_use_buffer; |
963 | |
964 | const status_t st = estimate_brgemm_ur(); |
965 | if (st != status::success) continue; |
966 | os_block = sp_block = iw_block; |
967 | update_blocks(); |
968 | |
969 | eff = est_eff(); |
970 | if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this; |
971 | } |
972 | } |
973 | |
974 | status_t brg_blocking_t::calc_blocks() { |
975 | sp = iw; |
976 | |
977 | nb_oc_blocking = 1; |
978 | // --- Select kernel blocking --- |
    // the out buffer is needed if dst_dt != acc_dt or if intermediate
    // results (e.g. for the sum post-op) must be stored
981 | const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum); |
982 | |
983 | std::vector<int> kd_blocks(1), kh_blocks(1); |
984 | kd_blocks[0] = kd; |
985 | kh_blocks[0] = kh; |
986 | if (kd != 1) { |
987 | kd_blocks.resize(2); |
988 | kd_blocks[1] = 1; |
989 | } |
990 | if (kh != 1) { |
991 | kh_blocks.resize(2); |
992 | kh_blocks[1] = 1; |
993 | } |
994 | |
995 | const auto thr_eff_threshold = 0.9f; |
996 | const auto max_iw_block_thr = utils::saturate(1, iw, |
997 | static_cast<int>(div_up( |
998 | mb * ngroups * nb_ic * is, thr_eff_threshold * nthr))); |
999 | |
1000 | iw_block = is_block = sp_block = -1; |
1001 | brg_blocking_t best_brgb = *this; |
1002 | for (const auto &kd_block : kd_blocks) { |
1003 | for (const auto &kh_block : kh_blocks) { |
1004 | iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer, |
1005 | max_iw_block_thr); |
1006 | } |
1007 | } |
1008 | *this = best_brgb; |
1009 | if (sp_block <= 0) return status::unimplemented; |
1010 | |
1011 | iw_block = is_block = sp_block; |
1012 | iw_tail = iw % iw_block; |
1013 | |
1014 | update_blocks(); |
1015 | |
1016 | return status::success; |
1017 | } |
1018 | |
1019 | bool brg_blocking_t::fast_check_ic_block_1x1() const { |
    // This function is for reducing the number of blocking variants
1021 | // TODO: eliminate heuristic in this function |
1022 | if (is_1x1 && is_amx(isa)) return true; |
1023 | const auto rnd_ic = rnd_up(ic, 16); |
1024 | auto res = false; |
1025 | if (ic_block == 64) { |
1026 | const auto big_spatial |
1027 | = id * ih * iw >= 64 * stride_d * stride_h * stride_w; |
1028 | res = (rnd_ic % ic_block == 0 && big_spatial); |
1029 | } else if (ic_block == 48) { |
1030 | const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block); |
1031 | res = (ic_block_eff >= 0.95f); |
1032 | } else |
1033 | res = true; |
1034 | |
1035 | return res; |
1036 | } |
1037 | |
1038 | float brg_blocking_t::est_eff_1x1() { |
1039 | const auto icblock = ic_block / 16; |
1040 | |
1041 | auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float { |
1042 | const int nb = dim / block; |
1043 | constexpr int max_nb = 2; // only consider 2x2 tile blocking |
1044 | const int block2 = nstl::min(max_nb, nb); |
1045 | const int nb2 = nb / block2; |
1046 | const int nb2_tail = nb % block2; |
1047 | if (!use_ave) return block2; |
1048 | return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2); |
1049 | }; |
1050 | const bool use_ocb_ave = true; |
1051 | const auto icb_ave = calc_ave_blk(ic_block, 16, use_ocb_ave); |
1052 | const bool use_spb_ave = false; |
1053 | const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave); |
1054 | const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0; |
1055 | const auto M_tail_n_sp_blks |
1056 | = ur_block_tail > 0 ? M_tail / ur_block_tail : 0; |
1057 | |
1058 | // heuristic for maskrcnn workaround: use old blocking for some convolutions |
1059 | // TODO: remove this condition |
1060 | const bool maskrcnn_cond = (ic == 1024 && oc == 2048) |
1061 | || (ic == 1024 && oc == 512) || (ic == 256 && oc == 1024) |
1062 | || (ic == 512 && oc == 1024) || (ic == 512 && oc == 2048); |
1063 | const auto amx_fac = maskrcnn_cond |
1064 | ? (div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks)) |
1065 | : (static_cast<float>(div_up(M + M_tail, 16)) |
1066 | / (M_n_sp_blks + M_tail_n_sp_blks)); |
1067 | |
1068 | const auto brgemm_microkernel_eff = is_amx(isa) |
1069 | ? amx_fac * (static_cast<float>(icb_ave) * spb_ave) |
1070 | / (icb_ave + spb_ave) |
1071 | : (static_cast<float>(icblock) * ur) / ((ur + icblock) * max_regs); |
1072 | const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur); |
1073 | const auto brgemm_eff = squeeze_val(ur |
1074 | * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block)) |
1075 | / 64, |
1076 | 0.5f); |
1077 | |
1078 | const auto sp_amount = nb_id * nb_ih * nb_sp; |
1079 | const auto work_amount = mb * ngroups * nb_ic * sp_amount; |
1080 | |
1081 | const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block); |
1082 | const auto thr_eff = static_cast<float>(work_amount) |
1083 | / utils::rnd_up(work_amount, nthr); |
1084 | const auto ic_block_eff = static_cast<float>(ic) / rnd_up(ic, ic_block); |
1085 | |
1086 | const auto job = div_up(work_amount, nthr); |
1087 | |
1088 | const auto dim_ic = 1; |
1089 | const auto nb_ic_thr = nstl::min(nb_ic, div_up(job, dim_ic)); |
1090 | const auto ic_thr = nstl::min(ic, nb_ic_thr * ic_block); |
1091 | const auto nsimd_ic_thr = div_up(ic_thr, simd_w); |
1092 | |
1093 | const auto dim_sp = ngroups * nb_ic; |
1094 | const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp)); |
1095 | const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block); |
1096 | |
1097 | const auto dim_ih = nb_sp * dim_sp; |
1098 | const int nb_ih_thr = nstl::min(nb_ih, div_up(job, dim_ih)); |
1099 | const int ih_thr = nstl::min(ih, nb_ih_thr * ih_block); |
1100 | |
1101 | const auto dim_id = nb_ih * dim_ih; |
1102 | const int nb_id_thr = nstl::min(nb_id, div_up(job, dim_id)); |
1103 | const int id_thr = nstl::min(id, nb_id_thr * id_block); |
1104 | |
1105 | auto job_eff = 1.f; |
1106 | if (job < nthr) { |
1107 | std::vector<dim_t> thr_jobs(nthr); |
1108 | for (int ithr = 0; ithr < nthr; ithr++) { |
1109 | thr_jobs[ithr] = 0; |
1110 | if (ithr >= work_amount) continue; |
1111 | dim_t thr_job = 0; |
1112 | int start {0}, end {0}; |
1113 | balance211(work_amount, nthr, ithr, start, end); |
1114 | int n {0}, g {0}, icb {0}, idp {0}, ihp {0}, spb {0}; |
1115 | nd_iterator_init(start, n, mb, idp, id, ihp, ih, spb, nb_sp, g, |
1116 | ngroups, icb, nb_ic); |
1117 | |
1118 | for (auto work = start; work < end; work++) { |
1119 | const int icp = icb * ic_block; |
1120 | const auto ic_sz = nstl::min(ic - icp, ic_block); |
1121 | int sp_sz = 0; |
1122 | const int spp = spb * sp_block; |
1123 | sp_sz = nstl::min(sp - spp, sp_block); |
1124 | thr_job += sp_sz * ic_sz; |
1125 | nd_iterator_step(n, mb, idp, id, ihp, ih, spb, nb_sp, g, |
1126 | ngroups, icb, nb_ic); |
1127 | } |
1128 | thr_jobs[ithr] = thr_job; |
1129 | } |
1130 | |
1131 | dim_t max_job = 0; |
1132 | dim_t sum_job = 0; |
1133 | for (int ithr = 0; ithr < nthr; ithr++) { |
1134 | if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr]; |
1135 | sum_job += thr_jobs[ithr]; |
1136 | } |
1137 | |
1138 | job_eff = max_job == 0 ? 1 |
1139 | : static_cast<float>(sum_job) / (max_job * nthr); |
1140 | } else { |
1141 | job_eff = thr_eff; |
1142 | } |
1143 | |
1144 | const auto oc_blocking_size = oc_block * nb_oc_blocking; |
1145 | const auto ic_blocking_size = ic_block * oc_blocking_size; |
1146 | |
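    // Same loop-nest cost model as in est_eff(), specialized for the 1x1
    // case.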
1147 | int l = -1; |
1148 | // -- brgemm kernel: loop by simd_w -- |
1149 | l++; |
1150 | loop[l].src.set(ur * simd_w, 1, bcast_simd); |
1151 | loop[l].dst.set(0, 1); |
1152 | loop[l].wei.set(ic_block, 1); |
1153 | |
1154 | // -- brgemm kernel: loop by ur in sp_block -- |
1155 | l++; |
1156 | const auto nb_ur = div_up(sp_block, ur); |
1157 | const auto nb_sp_no_tail = sp / sp_block; |
1158 | const auto sp_block_tail = sp % sp_block; |
1159 | const auto nb_ur_average |
1160 | = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp; |
1161 | loop[l].src.set(ur * rnd_simd(oc_blocking_size), 1); |
1162 | loop[l].dst.set(ur * ic_block, 1); |
1163 | loop[l].wei.set(ic_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur); |
    // -- brgemm kernel: loop by oc_chunks --
1165 | l++; |
1166 | const auto oc_chunks = div_up(nb_oc, nb_oc_blocking); |
1167 | loop[l].src.set(sp_block * oc_blocking_size, 1); |
1168 | loop[l].dst.set(sp_block * ic_block, oc_chunks); |
1169 | auto wei_is = ic_blocking_size; |
1170 | auto wei_op = icblock * oc; |
1171 | loop[l].wei.set(wei_is, 1); |
1172 | |
    // -- harness: loop by ic_block --
    l++;
    loop[l].src.set(sp_block * rnd_simd(oc), nb_ic_thr);
    loop[l].dst.set(sp_block * ic_block, 1);
    wei_is = ic_block * oc;
    wei_op = nsimd_ic_thr * oc;
    loop[l].wei.set(wei_is, 1);
1180 | |
1181 | const auto rnd_ic_for_sp = simd_w * nsimd_ic_thr; |
1182 | // -- harness: loop by sp_blocks -- |
1183 | l++; |
1184 | loop[l].src.set(sp_block * oc_blocking_size, 1); |
1185 | loop[l].dst.set(sp_block * rnd_ic_for_sp, 1); |
1186 | loop[l].wei.set(wei_op * simd_w, nb_sp_thr); |
    // -- harness: loop by ih_blocks --
1188 | l++; |
1189 | loop[l].src.set(ih_block * sp_thr * rnd_simd(oc_blocking_size), 1); |
1190 | loop[l].dst.set(ih_block * sp_thr * rnd_ic_for_sp, 1); |
1191 | loop[l].wei.set(wei_op * simd_w, nb_ih_thr); |
    // -- harness: loop by id_blocks --
1193 | l++; |
1194 | loop[l].src.set(id_block * ih_thr * sp_thr * rnd_simd(oc_blocking_size), 1); |
1195 | loop[l].dst.set(id_block * ih_thr * sp_thr * rnd_ic_for_sp, 1); |
1196 | loop[l].wei.set(wei_op * simd_w, nb_id_thr); |
1197 | |
1198 | // -- harness: loop by mb -- |
1199 | l++; |
1200 | const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_ic)); |
1201 | loop[l].src.set(id_thr * ih_thr * sp_thr * rnd_simd(oc_blocking_size), 1); |
1202 | loop[l].dst.set(nsimd_ic_thr * simd_w * id_thr * ih_thr * sp_thr, 1); |
1203 | loop[l].wei.set(nsimd_ic_thr * oc * simd_w, mb_thr); |
1204 | |
1205 | const auto src_op = static_cast<dim_t>(mb_thr) * id_thr * ih_thr * sp_thr |
1206 | * oc_blocking_size; |
1207 | const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_ic_thr * id_thr |
1208 | * ih_thr * sp_thr; |
1209 | wei_op = nsimd_ic_thr * oc; |
1210 | |
1211 | // for "real" application set bench_iterations to 1 |
1212 | const auto iterations = bench_iterations; |
1213 | l++; |
1214 | loop[l].src.set(src_op, iterations); |
1215 | loop[l].dst.set(dst_op * simd_w, iterations); |
1216 | loop[l].wei.set(wei_op * simd_w, iterations); |
1217 | |
1218 | auto src_mem_k = mem_k; |
1219 | auto dst_mem_k = mem_k; |
1220 | auto wei_mem_k = mem_k; |
1221 | float src_rp = 1; |
1222 | float dst_rp = 1; |
1223 | float wei_rp = 1; |
1224 | |
1225 | for (auto il = l; il >= 0; il--) { |
1226 | src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false); |
1227 | dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false); |
1228 | wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true); |
1229 | src_rp *= loop[il].src.repeatn; |
1230 | dst_rp *= loop[il].dst.repeatn; |
1231 | wei_rp *= loop[il].wei.repeatn; |
1232 | } |
1233 | const auto src_ops = (src_op * src_rp) / iterations; |
1234 | const auto dst_ops = (dst_op * dst_rp) / iterations; |
1235 | const auto wei_ops = (wei_op * wei_rp) / iterations; |
1236 | |
1237 | const auto src_cost = src_mem_k * src_ops; |
1238 | const auto dst_cost = dst_mem_k * dst_ops; |
1239 | const auto wei_cost = wei_mem_k * wei_ops; |
1240 | const auto call_kernel_cost = job * oc_chunks; |
1241 | |
1242 | const auto up_sp_size = id * ih; |
1243 | |
1244 | const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * oc * ic) |
1245 | / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost)); |
1246 | |
1247 | const auto res_eff = ic_block_eff * brgemm_microkernel_eff * sp_eff |
1248 | * job_eff * ur_eff * cache_eff * brgemm_eff; |
1249 | return res_eff; |
1250 | } |
1251 | |
1252 | brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) { |
1253 | return attr.zero_points_.has_default_values(arg) |
1254 | ? brgemm_broadcast_t::none |
1255 | : brgemm_broadcast_t::per_tensor; |
1256 | } |
1257 | |
1258 | status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, |
1259 | const convolution_desc_t &cd, memory_desc_t &diff_dst_md, |
1260 | memory_desc_t &weights_md, memory_desc_t &diff_src_md, |
1261 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads, |
1262 | bool enable_postops) { |
1263 | using namespace prop_kind; |
1264 | |
1265 | brg_blocking_t::L1 = platform::get_per_core_cache_size(1); |
1266 | brg_blocking_t::L2 = platform::get_per_core_cache_size(2); |
    brg_blocking_t::L3 = platform::get_per_core_cache_size(
            2); // FIXME: L2 size used as a stand-in for per-core L3
1268 | |
1269 | if (!mayiuse(avx512_core)) return status::unimplemented; |
1270 | |
1271 | const memory_desc_wrapper diff_dst_d(&diff_dst_md); |
1272 | const memory_desc_wrapper weights_d(&weights_md); |
1273 | const memory_desc_wrapper diff_src_d(&diff_src_md); |
1274 | const memory_desc_wrapper bias_d(&bias_md); |
1275 | |
1276 | const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1; |
1277 | int ndims = diff_src_d.ndims(); |
1278 | |
1279 | jcp = zero<decltype(jcp)>(); |
1280 | jcp.isa = isa; |
1281 | jcp.ndims = ndims; |
1282 | jcp.prop_kind = cd.prop_kind; |
1283 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1284 | jcp.mb = diff_src_d.dims()[0]; |
1285 | jcp.oc_without_padding = diff_dst_d.dims()[1] / jcp.ngroups; |
1286 | jcp.oc = jcp.oc_without_padding; |
1287 | jcp.ic_without_padding = diff_src_d.dims()[1]; |
1288 | jcp.ic = jcp.ic_without_padding / jcp.ngroups; |
1289 | jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1; |
1290 | jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims - 2]; |
1291 | jcp.iw = diff_src_d.dims()[ndims - 1]; |
1292 | jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1; |
1293 | jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims - 2]; |
1294 | jcp.ow = diff_dst_d.dims()[ndims - 1]; |
1295 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
1296 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1297 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1298 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
1299 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
1300 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1301 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
1302 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
1303 | jcp.stride_w = cd.strides[ndims - 3]; |
1304 | |
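    // Only strided cases are handled here; shapes with all unit strides are
    // expected to be dispatched to a different (non-strided) bwd-d
    // implementation.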
1305 | if (everyone_is(1, jcp.stride_d, jcp.stride_h, jcp.stride_w)) |
1306 | return status::unimplemented; |
1307 | |
1308 | if (jcp.id % jcp.stride_d != 0 || jcp.ih % jcp.stride_h != 0 |
            || jcp.iw % jcp.stride_w != 0)
1310 | return status::unimplemented; |
1311 | |
1312 | jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0; |
1313 | jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4]; |
1314 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1315 | |
1316 | if (!everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)) |
1317 | return status::unimplemented; |
1318 | |
1319 | jcp.is = jcp.id * jcp.ih * jcp.iw; |
1320 | |
1321 | jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
1322 | jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
1323 | jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1324 | |
1325 | jcp.back_pad = calculate_end_padding( |
1326 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd); |
1327 | jcp.b_pad = calculate_end_padding( |
1328 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh); |
1329 | jcp.r_pad = calculate_end_padding( |
1330 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw); |
1331 | |
1332 | jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0 |
1333 | && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0 |
1334 | && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw); |
1335 | |
1336 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
1337 | |
1338 | jcp.src_dt = diff_dst_md.data_type; |
1339 | jcp.dst_dt = diff_src_md.data_type; |
1340 | jcp.wei_dt = weights_md.data_type; |
1341 | jcp.bia_dt = jcp.with_bias ? bias_md.data_type : data_type::undef; |
1342 | |
1343 | jcp.is_bf32 = everyone_is(f32, jcp.src_dt, jcp.wei_dt) |
1344 | && attr.fpmath_mode_ == fpmath_mode::bf16 && isa == avx512_core_amx; |
1345 | |
1346 | if (jcp.is_bf32) return status::unimplemented; |
1347 | |
1348 | brg_blocking_t::last_oc_block_size = data_type_vnni_granularity(jcp.wei_dt); |
1349 | |
1350 | // TODO: optimize grouped convolutions with small oc |
1351 | const bool is_grouped_small_oc |
1352 | = jcp.prop_kind != prop_kind::backward_weights && with_groups |
1353 | && jcp.ngroups > 1 && jcp.oc <= 16 |
1354 | && IMPLICATION(is_amx(jcp.isa), |
1355 | jcp.oc < 16 |
1356 | && jcp.ic < 16 |
1357 | // already optimized for amx 1x1 convs |
1358 | && !jcp.is_1x1) |
1359 | // Enable the shapes not supported in direct convs |
1360 | && IMPLICATION(with_groups, is_groups_ok(jcp)); |
1361 | if (is_grouped_small_oc) return status::unimplemented; |
1362 | |
    // Dispatch small shapes to VNNI-based implementations for better
    // performance
1364 | // TODO: optimize the perf of 3d shape with small oc and large spatial |
1365 | const auto max_small_shapes_sz = jcp.is_1x1 |
1366 | ? static_cast<int32_t>(brg_blocking_t::L1) / 2 |
1367 | : static_cast<int32_t>(brg_blocking_t::L1); |
1368 | const auto is_small_shape = is_amx(jcp.isa) && jcp.is <= 4 && jcp.oc <= 512 |
1369 | && jcp.mb * jcp.ngroups * jcp.oc * jcp.ic <= max_small_shapes_sz; |
1370 | const auto is_3d_small_oc = is_amx(jcp.isa) && jcp.ndims == 5 |
1371 | && jcp.oc * jcp.ic <= 32 && jcp.id >= 128 && jcp.ih >= 128 |
1372 | && jcp.iw >= 128; |
1373 | if (is_small_shape || is_3d_small_oc) return status::unimplemented; |
1374 | |
1375 | jcp.s8s8_avx512 = jcp.src_dt == s8 && !is_amx(jcp.isa); |
1376 | |
1377 | if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni))) |
1378 | return status::unimplemented; |
1379 | |
1380 | if (!IMPLICATION(jcp.wei_dt == bf16, mayiuse(avx512_core_bf16))) |
1381 | return status::unimplemented; |
1382 | |
1383 | if (!IMPLICATION(jcp.wei_dt == f16, mayiuse(avx512_core_fp16))) |
1384 | return status::unimplemented; |
1385 | |
1386 | jcp.acc_dt = types::is_integral_dt(jcp.src_dt) ? s32 : f32; |
1387 | |
1388 | jcp.src_dsz = types::data_type_size(jcp.src_dt); |
1389 | jcp.wei_dsz = types::data_type_size(jcp.wei_dt); |
1390 | jcp.dst_dsz = types::data_type_size(jcp.dst_dt); |
1391 | jcp.acc_dsz = types::data_type_size(jcp.acc_dt); |
1392 | jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0; |
1393 | |
1394 | if (!post_ops_ok(jcp, attr, diff_src_d, enable_postops)) |
1395 | return status::unimplemented; |
1396 | |
1397 | jcp.simd_w = cpu_isa_traits<avx512_core>::vlen / jcp.src_dsz; |
1398 | jcp.amx_h = 16; |
1399 | jcp.amx_w = 64 / jcp.src_dsz; |
1400 | |
1401 | if (jcp.with_bias) { |
1402 | if (bias_d.format_kind() == format_kind::any) |
1403 | CHECK(memory_desc_init_by_tag(bias_md, x)); |
1404 | } |
1405 | |
1406 | const auto &p = attr.post_ops_; |
1407 | jcp.with_sum = p.find(primitive_kind::sum) != -1; |
1408 | const int eltwise_ind = p.find(primitive_kind::eltwise); |
1409 | jcp.with_eltwise = eltwise_ind != -1; |
1410 | |
1411 | const int binary_ind = p.find(primitive_kind::binary); |
1412 | jcp.with_binary = binary_ind != -1; |
1413 | |
1414 | jcp.src_zero_point |
1415 | = get_zp_type(attr, DNNL_ARG_DIFF_DST) != brgemm_broadcast_t::none; |
1416 | jcp.dst_zero_point |
1417 | = get_zp_type(attr, DNNL_ARG_DIFF_SRC) != brgemm_broadcast_t::none; |
1418 | |
1419 | const bool has_zero_points = jcp.src_zero_point || jcp.dst_zero_point; |
1420 | if (has_zero_points || jcp.s8s8_avx512) return status::unimplemented; |
1421 | |
1422 | jcp.nthr = nthreads; |
1423 | jcp.kh_sets = 1; |
1424 | jcp.kw_sets = 1; |
1425 | jcp.copy_block_only = false; |
1426 | jcp.amx_tile_load_xx = false; |
1427 | jcp.use_M_mask = 0; |
1428 | jcp.is_is_blocking = false; |
1429 | jcp.oskip = 0; |
1430 | jcp.use_uker = false; |
1431 | jcp.use_interleave_stores = false; |
1432 | jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf_default; |
1433 | jcp.brgemm_bd_loop_innermost = false; |
1434 | |
    // fast check of the data layout before spending time on blocking selection
1436 | format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc); |
1437 | |
1438 | CHECK(init_tag(jcp.src_tag, diff_dst_md, diff_dst_d, src_tag)); |
1439 | |
1440 | return status::success; |
1441 | } |
1442 | |
status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &diff_dst_md,
        memory_desc_t &weights_md, memory_desc_t &diff_src_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads,
        bool enable_postops) {

    using namespace prop_kind;

    if (!mayiuse(isa)) return status::unimplemented;

    CHECK(init_jcp(jcp, isa, cd, diff_dst_md, weights_md, diff_src_md, bias_md,
            attr, nthreads, enable_postops));

    const memory_desc_wrapper diff_dst_d(&diff_dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper diff_src_d(&diff_src_md);
    const memory_desc_wrapper bias_d(&bias_md);

    jcp.l_ovf = nstl::max(0, jcp.ext_kw - 1 - jcp.l_pad) / jcp.stride_w;
    jcp.r_ovf = nstl::max(0, jcp.ext_kw - 1 - jcp.r_pad) / jcp.stride_w;
    jcp.t_ovf = nstl::max(0, jcp.ext_kh - 1 - jcp.t_pad) / jcp.stride_h;
    jcp.b_ovf = nstl::max(0, jcp.ext_kh - 1 - jcp.b_pad) / jcp.stride_h;
    jcp.f_ovf = nstl::max(0, jcp.ext_kd - 1 - jcp.f_pad) / jcp.stride_d;
    jcp.back_ovf = nstl::max(0, jcp.ext_kd - 1 - jcp.back_pad) / jcp.stride_d;

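    // The padded diff_dst sizes below keep the overflow fringe on each side
    // so the copy buffer covers everything the kernel reads; e.g. ext_kw = 3,
    // l_pad = 0, stride_w = 1 give l_ovf = 2 extra columns on the left.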
    jcp.odp = jcp.od + jcp.f_ovf + jcp.back_ovf;
    jcp.ohp = jcp.oh + jcp.t_ovf + jcp.b_ovf;
    jcp.owp = jcp.ow + jcp.l_ovf + jcp.r_ovf;

    using namespace data_type;
    // ======================= blocking =================================

    const int min_ic_block = 16;
    int selected_ur = 0;

    //-----------------------------------------------------------------------

    jcp.exec_type = exec_trans;
    jcp.brg_type = brgemm_addr; // TODO: choose the right brgemm type

    // TODO: in the future, use the (kd/kh/kw) and (kd/kh/kw)_pad blocks for a
    // more precise calculation of jcp.max_batch
    jcp.max_batch = jcp.kd * jcp.kh * jcp.kw;
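    // E.g. a 3x3 kernel in 2D gives a batch of 9 brgemm calls per block; a
    // 3x3x3 kernel in 3D gives 27.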

    jcp.wei_plain = false;

    // Always try loop_ndhwgc for exec_trans.
    jcp.loop_order = loop_ndhwgc;

    jcp.copy_block_only = true;

    const auto oc_padded_block = 16 * brg_blocking_t::last_oc_block_size;
    jcp.is_oc_padded = one_of(jcp.wei_dt, bf16, s8)
            && jcp.oc * jcp.kw_sets > oc_padded_block;

    if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.iw < 256)) {
        jcp.use_M_mask = 0;

        jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1;

        // Assuming a 2x2 tile decomposition in the AMX brgemm kernel and
        // overlap of the input across kw.
        const auto bd_blocking = 2 * jcp.amx_h;
        const auto ld_blocking = 2 * 16;
        const auto A_ds = jcp.src_dsz * bd_blocking * jcp.oc * jcp.kd * jcp.kh;
        const auto B_ds
                = jcp.wei_dsz * ld_blocking * jcp.oc * jcp.kd * jcp.kh * jcp.kw;
        const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking;
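        // Illustrative working-set estimate for bf16 (src_dsz = wei_dsz = 2)
        // with oc = 64, kd = kh = 1, kw = 3: A_ds = 2 * 32 * 64 = 4 KiB,
        // B_ds = 2 * 32 * 64 * 3 = 12 KiB, C_ds = 4 * 32 * 32 = 4 KiB.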
        if (A_ds + B_ds + C_ds > brg_blocking_t::L1)
            jcp.amx_tile_load_xx = true;
    }

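    // Blocking search: sweep ic_block candidates downwards (at most 64, in
    // steps of 16) and keep the configuration with the best estimated
    // efficiency.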
    auto try_exec_type = [&]() {
        brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
        best_brgb.ic_block = min_ic_block;
        brg_blocking_t cur_brgb = zero<decltype(best_brgb)>();
        cur_brgb.get_from_jcp(jcp);
        const auto start_icb = nstl::min(div_up(jcp.ic, 16), 4);

        const auto finish_icb = 1;
        for (auto icb = start_icb; icb >= finish_icb; icb--) {
            cur_brgb.ic_block = icb * 16;
            cur_brgb.nb_ic = utils::div_up(jcp.ic, cur_brgb.ic_block);
            if (!cur_brgb.fast_check_ic_block()) continue;

            const status_t blocking_ok = cur_brgb.calc_blocks();
            if (blocking_ok != status::success) continue;

            const status_t st = cur_brgb.get_brgemm_ur(&attr, diff_src_md);
            if (st != status::success) continue;
            cur_brgb.eff = cur_brgb.est_eff();
            if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
        }
        if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0
                || best_brgb.iw_block == 0)
            return false;
        best_brgb.save_to_jcp(jcp);
        selected_ur = best_brgb.ur;
        return true;
    };

    if (!try_exec_type()) return status::unimplemented;

    // ============ end blocking ===========================================
    jcp.max_vpad = 0;

    if (jcp.iw_block == 0 || jcp.oc_block == 0 || jcp.ic_block == 0)
        return status::unimplemented;

    jcp.gemm_batch_size = jcp.nb_oc_blocking
            * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block,
                    jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad);
    // Round the per-thread batch up to a full page to avoid concurrent
    // writes to the same cache line from different threads.
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jcp.adjusted_batch_size
            = div_up(rnd_up(jcp.gemm_batch_size * sc_size, P4K), sc_size);
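    // Illustrative arithmetic, assuming sizeof(brgemm_batch_element_t) == 32:
    // gemm_batch_size = 27 needs 864 bytes, rounded up to 4096, i.e. 128
    // batch elements per thread.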

    CHECK(pick_tags(jcp, diff_dst_md, weights_md, diff_src_md, bias_md));
    CHECK(attr.set_default_formats(&diff_src_md));

    jcp.buffer_size = jcp.LDC * (jcp.M > 0 ? jcp.M : jcp.M_tail);

    jcp.nb_id = div_up(jcp.id, jcp.id_block);
    jcp.nb_ih = div_up(jcp.ih, jcp.ih_block);

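    // Per-thread scratch: inp_buffer holds the padded diff_dst copy and
    // inp_buffer_mask tracks which blocks were already transformed, so with
    // copy_block_only each block is copied at most once.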
    jcp.inp_buffer_size = rnd_up(jcp.odp * jcp.ohp * jcp.owp * jcp.ngroups
                    * jcp.nb_oc * jcp.oc_block,
            P4K);
    jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_id) * jcp.nb_ih
                    * jcp.nb_iw * jcp.ngroups * jcp.nb_oc,
            P4K);

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    const bool with_pad = jcp.f_pad > 0 || jcp.back_pad > 0 || jcp.t_pad > 0
            || jcp.b_pad > 0 || jcp.l_pad > 0 || jcp.r_pad > 0;

    if (jcp.s8s8_avx512) {
        weights_md.extra.flags = 0 | memory_extra_flags::compensation_conv_s8s8;
        weights_md.extra.compensation_mask = with_groups ? 0x3 : 0x1;
    }
    if (jcp.src_zero_point && !is_amx(jcp.isa)) {
        weights_md.extra.flags
                |= memory_extra_flags::compensation_conv_asymmetric_src;
        weights_md.extra.asymm_compensation_mask = with_groups ? 0x3 : 0x1;
    }

    // For padded shapes, when the output size is small we compute the
    // compensation along with the main computation inside the brgemm kernel
    // to get optimal performance; otherwise we compute it with the
    // brgemm_comp_pad kernel.
    const auto output_sz = static_cast<dim_t>(jcp.mb) * jcp.ngroups * jcp.ic
            * jcp.id * jcp.ih * jcp.iw;
    const auto comp_with_pads = (jcp.src_zero_point || jcp.s8s8_avx512)
            && IMPLICATION(jcp.exec_type == exec_vpad, with_pad);
    jcp.req_brg_comp_pad = comp_with_pads && output_sz <= 8192 && jcp.ic < 512;
    jcp.req_cal_comp_pad = comp_with_pads && !jcp.req_brg_comp_pad;

    // Estimate the number of kernel-range combinations for compensation.
    const auto kd_cnt = 1 + utils::div_up(abs(jcp.f_pad), jcp.dilate_d + 1)
            + utils::div_up(abs(jcp.back_pad), jcp.dilate_d + 1);
    const auto kh_cnt = 1 + utils::div_up(abs(jcp.t_pad), jcp.dilate_h + 1)
            + utils::div_up(abs(jcp.b_pad), jcp.dilate_h + 1);

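    // E.g. f_pad = back_pad = 1 and dilate_d = 0 give kd_cnt = 1 + 1 + 1 = 3.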
    jcp.ker_ranges_size = kd_cnt * kh_cnt;
    jcp.comp_a_buffer_size = static_cast<dim_t>(jcp.ngroups) * jcp.nb_ic
            * jcp.ker_ranges_size * jcp.iw * jcp.ic_block;
    jcp.s8s8_comp_buffer_size = jcp.comp_a_buffer_size;

    return status::success;
}

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const jit_brgemm_conv_conf_t &jcp) {
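    // Note: the bookings below use P4K (4 KiB) padding so that per-thread
    // regions are unlikely to share cache lines across threads.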
    if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs
            || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad))
        scratchpad.book(key_brgemm_primitive_batch,
                static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
                sizeof(brgemm_batch_element_t), 64, P4K);

    size_t inp_buffer_size
            = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size;
    scratchpad.book(
            key_conv_brgemm_inp_buffer, inp_buffer_size, jcp.src_dsz, 0, P4K);
    size_t inp_buffer_mask_size
            = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size;
    scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size,
            sizeof(uint8_t), 0, P4K);

    if (jcp.use_buffer) {
        scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size,
                jcp.acc_dsz, 0, P4K);
    }
    if (is_amx(jcp.isa)) {
        scratchpad.book(key_conv_amx_tile_buffer, jcp.nthr * 2 * P4K,
                sizeof(char), 0, P4K);
    }
    if (jcp.s8s8_avx512 && jcp.req_cal_comp_pad) {
        scratchpad.book(key_brgemm_primitive_buffer_comp,
                jcp.s8s8_comp_buffer_size, sizeof(int32_t), 0, P4K);
    }
    if (jcp.src_zero_point && jcp.req_cal_comp_pad && !is_amx(jcp.isa)) {
        scratchpad.book(key_brgemm_primitive_zp_comp_a, jcp.comp_a_buffer_size,
                sizeof(int32_t), 0, P4K);
    }
}

} // namespace brgemm_convolution_bwd_utils

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl