/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/conv/config.hpp"

#include <cctype>
#include <cstring>

#include "common/type_helpers.hpp"
#include "gpu/jit/conv/block_helper.hpp"
#include "gpu/jit/conv/config_lookup_table.hpp"
#include "gpu/jit/conv/grf_usage.hpp"
#include "gpu/jit/conv/normalization.hpp"
#include "gpu/jit/ir/block_2d_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {
// Helper functions.
layout_t make_layout(const memory_desc_t &md) {
    if (md.format_kind == format_kind::any) return layout_t();
    return layout_t(md, /*do_normalize=*/false);
}

layout_t make_layout(const memory_desc_t &md, const std::string &tag) {
    return layout_t(md, tag, /*do_normalize=*/false);
}

layout_t make_layout(const type_t &type, const std::vector<dim_t> &dims,
        const std::string &tag) {
    return layout_t(type, 0, tag, dims, /*do_normalize=*/false);
}

void set_default_format(memory_desc_t &md, const std::string &tag) {
    if (md.format_kind != format_kind::any) return;
    md = make_layout(md, tag).to_dnnl(md.dims);
}

bool matches_tag(const layout_t &layout, const std::string &tag) {
    if (layout.is_empty()) return false;
    auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
    if (layout != tag_layout) return false;
    return true;
}

bool matches_tag_strict(const layout_t &layout, const std::string &tag) {
    if (layout.is_empty()) return false;
    auto tag_layout = make_layout(layout.type(), layout.dims(), tag);
    if (!layout.is_strictly_equal(tag_layout)) return false;
    return true;
}

bool matches_tag(const memory_desc_t &md, const std::string &tag) {
    if (md.format_kind == format_kind::any) return false;
    return matches_tag(make_layout(md), tag);
}

bool matches_tag_strict(const memory_desc_t &md, const std::string &tag) {
    if (md.format_kind == format_kind::any) return false;
    return matches_tag_strict(make_layout(md), tag);
}

layout_t init_layout(memory_desc_t &user_md, const std::string &optimal_tag) {
    auto optimal = make_layout(user_md, optimal_tag);
    if (user_md.format_kind != format_kind::any) {
        auto user = make_layout(user_md);
        // If the layouts are physically different, return the user's layout
        // and let the implementation report unimplemented later.
        if (user != optimal) return user;
    } else {
        user_md = optimal.to_dnnl(user_md.dims);
    }
    return optimal;
}

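// Illustrative example (assuming the tag's dimension letters fall within
// DNNL_MAX_NDIMS): prepend_groups_to_tag("aBx16b") returns "abCx16c" - each
// dimension letter is shifted by one ('a' -> 'b', 'B' -> 'C', 'b' -> 'c'),
// 'x' and the digits are kept, and a new outermost "a" (groups) is prepended.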
std::string prepend_groups_to_tag(const std::string &tag) {
    auto ret = tag;
    for (auto &c : ret) {
        bool is_lower_dim = ('a' <= c && c < 'a' + DNNL_MAX_NDIMS);
        bool is_upper_dim = ('A' <= c && c < 'A' + DNNL_MAX_NDIMS);
        if (!is_lower_dim && !is_upper_dim) continue;
        c += 1;
    }
    return "a" + ret;
}

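// Heuristic: a channel dimension is considered "small" when it spans at most
// 16 bytes for sub-dword types or at most 8 elements for 4-byte and wider
// types. Illustrative thresholds: f32 -> ic <= 8, f16/bf16 -> ic <= 8,
// s8/u8 -> ic <= 16.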
bool is_small_ic(const conv_problem_t &prb) {
    int size = (int)types::data_type_size(prb.src_data_type);
    if (size >= 4)
        return prb.ic <= 8;
    else
        return prb.ic * size <= 16;
}

bool is_small_oc(const conv_problem_t &prb) {
    int size = (int)types::data_type_size(prb.dst_data_type);
    if (size >= 4)
        return prb.oc <= 8;
    else
        return prb.oc * size <= 16;
}

bool is_dw_large_mb(const conv_problem_t &prb) {
    return prb.is_dw && prb.mb >= 16;
}

status_t conv_problem_t::init(
        const engine_t *engine, const convolution_pd_t *conv_pd) {
    using namespace compute;

    if (conv_pd->has_zero_dim_memory()) return status::unimplemented;

    this->conv_pd = conv_pd;
    attr = conv_pd->attr();
    is_fwd = conv_pd->is_fwd();
    is_bwd_d = conv_pd->is_bwd_d();
    is_bwd_w = conv_pd->is_bwd_w();
    with_bias = conv_pd->with_bias();
    with_groups = conv_pd->with_groups();
    with_sum = with_sum_post_op();

    src_data_type = conv_pd->invariant_src_md()->data_type;
    wei_data_type = conv_pd->invariant_wei_md()->data_type;
    bia_data_type = conv_pd->invariant_bia_md()->data_type;
    dst_data_type = conv_pd->invariant_dst_md()->data_type;
    fpmath_mode = attr->fpmath_mode_;

    ndims = conv_pd->ndims();

    mb = conv_pd->MB();
    g = conv_pd->G();
    ic = ir_utils::safe_divide(conv_pd->IC(), g);
    oc = ir_utils::safe_divide(conv_pd->OC(), g);

    // Input spatial.
    id = conv_pd->ID();
    ih = conv_pd->IH();
    iw = conv_pd->IW();

    // Output spatial.
    od = conv_pd->OD();
    oh = conv_pd->OH();
    ow = conv_pd->OW();

    // Kernel sizes.
    kd = conv_pd->KD();
    kh = conv_pd->KH();
    kw = conv_pd->KW();

    // Strides.
    sd = conv_pd->KSD();
    sh = conv_pd->KSH();
    sw = conv_pd->KSW();

    // Padding.
    pd = conv_pd->padFront();
    ph = conv_pd->padT();
    pw = conv_pd->padL();

    // Dilation.
    dd = conv_pd->KDD();
    dh = conv_pd->KDH();
    dw = conv_pd->KDW();

    try_reduce_to_1d();

    is_dw = with_groups && (g > 1) && (oc == 1) && (ic == 1);
    ksp = kd * kh * kw;
    isp = id * ih * iw;
    osp = od * oh * ow;

    auto *compute_engine = utils::downcast<const compute_engine_t *>(engine);
    auto *device_info = compute_engine->device_info();
    gpu_arch_t gpu_arch = device_info->gpu_arch();
    auto hw = convert_dnnl_arch_to_ngen(gpu_arch);

    CHECK(init_abc_data_types(hw));
    CHECK(init_acc_data_type());
    CHECK(init_zero_points_config());

    return status::success;
}

void conv_problem_t::try_reduce_to_1d() {
    bool is_1x1 = (kd * kh * kw == 1);
    bool is_eq_oi = (od == id && oh == ih && ow == iw);
    bool is_iw_1 = iw == 1 && kw == 1 && pw == 0 && ow == 1;
    bool is_ih_1 = ih == 1 && kh == 1 && ph == 0 && oh == 1;
    reduced_dim = 0;
    auto shift_oh_to_ow = [&]() {
        ow = oh;
        iw = ih;
        ih = 1;
        oh = 1;
        kw = kh;
        kh = 1;
        pw = ph;
        ph = 0;
        sw = sh;
        sh = 1;
        dw = dh;
        dh = 0;
        reduced_dim += 1;
    };
    auto shift_od_to_oh = [&]() {
        oh = od;
        ih = id;
        id = 1;
        od = 1;
        kh = kd;
        kd = 1;
        ph = pd;
        pd = 0;
        sh = sd;
        sd = 1;
        dh = dd;
        dd = 0;
        reduced_dim += 1;
    };

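    // Illustrative examples: when iw == kw == ow == 1, the W dimension is
    // degenerate, so H is shifted into W (and D into H); when both H and W
    // are degenerate, two extra shifts effectively make the problem 1D over
    // the original D dimension.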
    if (is_iw_1) { shift_oh_to_ow(); }
    if (is_ih_1 || is_iw_1) { shift_od_to_oh(); }
    if (is_iw_1 && is_ih_1) { shift_oh_to_ow(); }

    if (is_1x1 && is_stride1() && is_eq_oi) {
        ir_assert(pd == 0 && ph == 0 && pw == 0);
        ow = od * oh * ow;
        iw = id * ih * iw;
        od = id = kd = 1;
        oh = ih = kh = 1;
        reduced_dim = 3;
    }
}

status_t conv_problem_t::init_zero_points_config() {
    zp_cfg = zero_points_config_t();
    zp_cfg.do_src_compensation
            = !attr->zero_points_.has_default_values(DNNL_ARG_SRC);
    zp_cfg.do_dst_compensation
            = !attr->zero_points_.has_default_values(DNNL_ARG_DST);
    zp_cfg.is_runtime_src_zero_points
            = !attr->zero_points_.defined(DNNL_ARG_SRC);
    zp_cfg.is_runtime_dst_zero_points
            = !attr->zero_points_.defined(DNNL_ARG_DST);
    zp_cfg.is_common_src_zero_point = attr->zero_points_.common(DNNL_ARG_SRC);
    zp_cfg.is_common_dst_zero_point = attr->zero_points_.common(DNNL_ARG_DST);
    zp_cfg.common_src_zero_point = 0;
    zp_cfg.common_dst_zero_point = 0;
    return status::success;
}

std::string conv_problem_t::desc_str(bool print_mb) const {
    std::ostringstream oss;
    if (print_mb) oss << "mb" << mb;
    if (g > 1) oss << "g" << g;
    oss << "ic" << ic;

    std::vector<int> xd = {id, od, kd, sd, dd, pd};
    std::vector<int> xh = {ih, oh, kh, sh, dh, ph};
    std::vector<int> xw = {iw, ow, kw, sw, dw, pw};
    std::vector<int> xdef = {1, 1, 1, 1, 0, 0};
    bool has_d = !ir_utils::is_equal(xd, xdef);
    bool has_h = !ir_utils::is_equal(xh, xdef);
    bool is_square = ir_utils::is_equal(xh, xw);
    bool is_cubic = is_square && ir_utils::is_equal(xd, xh);
    bool print_d = has_d;
    bool print_h = has_h && !is_cubic;
    bool print_w = !is_cubic && !is_square;

    if (print_d) oss << "id" << id;
    if (print_h) oss << "ih" << ih;
    if (print_w) oss << "iw" << iw;
    oss << "oc" << oc;
    if (print_d) oss << "od" << od;
    if (print_h) oss << "oh" << oh;
    if (print_w) oss << "ow" << ow;
    if (print_d) oss << "kd" << kd;
    if (print_h) oss << "kh" << kh;
    if (print_w) oss << "kw" << kw;
    if (print_d && sd != 1) oss << "sd" << sd;
    if (print_h && sh != 1) oss << "sh" << sh;
    if (print_w && sw != 1) oss << "sw" << sw;
    if (print_d && dd != 0) oss << "dd" << dd;
    if (print_h && dh != 0) oss << "dh" << dh;
    if (print_w && dw != 0) oss << "dw" << dw;
    if (print_d) oss << "pd" << pd;
    if (print_h) oss << "ph" << ph;
    if (print_w) oss << "pw" << pw;
    return oss.str();
}

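// Illustrative arithmetic (the numbers are not tied to a specific GPU): with
// max_eus_per_wg = 8 and threads_per_eu = 8, the first operand caps the
// thread group at 64 threads; the second caps it by the maximum work group
// size divided by the SIMD-adjusted per-thread width (wg_per_thr).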
int get_default_max_tg_size(const hw_config_t &hw_cfg, int regs, int simd) {
    const compute::gpu_arch_t arch = convert_ngen_arch_to_dnnl(hw_cfg.hw());
    const int max_eus_per_wg = compute::device_info_t::max_eus_per_wg(arch);
    const int threads_per_eu
            = compute::device_info_t::threads_per_eu(arch, regs > 128);
    const int wg_per_thr = simd * compute::device_info_t::threads_per_eu(arch)
            / threads_per_eu;

    // Optimal thread group size may differ from hardware thread count due
    // to simd_size used in computation.
    return std::min(max_eus_per_wg * utils::rnd_down_pow2(threads_per_eu),
            static_cast<int>(hw_cfg.max_wg_size() / wg_per_thr));
}

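// Bit i of the PReLU mask selects whether the weights span dimension i of the
// destination; cleared bits are broadcast. Illustrative example: for an NCHW
// destination, mask = 1 << 1 (per-channel weights) yields dims {1, C, 1, 1}.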
std::vector<dim_t> get_prelu_weights_dims(
        uint32_t mask, const memory_desc_t &md) {
    std::vector<dim_t> dims(md.dims, md.dims + md.ndims);
    for (int i = 0; i < md.ndims; ++i)
        dims[i] = (mask & (1 << i)) ? dims[i] : 1;
    return dims;
}

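// Builds a dnnl-style layout tag from per-dimension inner/outer blocks.
// Illustrative example: build_tag({16, 32}, {1, 1}, {'a', 'b'}, {1, 0})
// returns "ABx16a32b": both dimensions are blocked (hence the upper-case
// letters), with 'x' standing for the remaining spatial dimensions.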
std::string build_tag(const std::vector<int> &inner_blocks,
        const std::vector<int> &outer_blocks, const std::vector<char> &letters,
        const std::vector<int> &idxs) {
    size_t n = letters.size();
    ir_assert(inner_blocks.size() == n);
    ir_assert(outer_blocks.size() == n);
    ir_assert(idxs.size() == n);

    std::string tag;
    std::vector<bool> seen(n);

    // Iterate through outer blocks.
    for (int i = (int)n - 1; i >= 0; i--) {
        int idx = idxs[i];
        int blk = outer_blocks[idx];
        if (blk == 1) continue;
        seen[idx] = true;
        tag += std::to_string(blk) + letters[idx];
    }

    // Iterate through inner blocks.
    for (int i = (int)n - 1; i >= 0; i--) {
        int idx = idxs[i];
        int blk = inner_blocks[idx];
        if (blk == 1) continue;
        seen[idx] = true;
        tag += std::to_string(blk) + letters[idx];
    }

    if (tag.empty()) {
        // Assume this is an activations tag, use NHWC by default.
        tag = "axb";
    } else {
        tag = 'x' + tag;
        for (int i = (int)n - 1; i >= 0; i--) {
            char c = letters[i];
            if (c == ' ') continue;
            if (seen[i]) c = std::toupper(c);
            tag = c + tag;
        }
    }

    return tag;
}

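// Picks the largest candidate block that still fits dim. Illustrative
// examples: pick_block(24, 16, 32) returns 16 (24 < 32), while
// pick_block_rnd_up(24, 16, 32) returns 32, since 24 > 32 / 2 and rounding
// the dimension up to 32 wastes less than half of the block.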
int pick_block_impl(bool prefer_rnd_up, int dim, int b0, int b1, int b2) {
    int blocks[3] = {b0, b1, b2};
    int prev_blk = 1;
    for (int i = 0; i < 3; i++) {
        if (blocks[i] == 0) continue;
        if (prefer_rnd_up) {
            if (dim <= blocks[i] / 2) return prev_blk;
        } else {
            if (dim < blocks[i]) return prev_blk;
        }
        prev_blk = blocks[i];
    }
    return prev_blk;
}

int pick_block_rnd_up(int dim, int b0, int b1 = 0, int b2 = 0) {
    return pick_block_impl(true, dim, b0, b1, b2);
}

int pick_block(int dim, int b0, int b1 = 0, int b2 = 0) {
    return pick_block_impl(false, dim, b0, b1, b2);
}

struct nc_block_t {
    nc_block_t(int n_block, int c_block, bool nc_order = true)
        : n_block_(n_block), c_block_(c_block), nc_order_(nc_order) {}

    std::string tag() const {
        std::vector<int> idxs = {1, 0};
        if (!nc_order_) std::swap(idxs[0], idxs[1]);
        return build_tag({n_block_, c_block_}, {1, 1}, {'a', 'b'}, idxs);
    }

    // Ideally, this should depend only on the data type, direction, mb, c,
    // and g to enable the same src/dst formats and avoid reorders between
    // convolutions.
    static nc_block_t get_default_blocking(type_t type, bool is_dw, int n,
            int c, int g, bool is_input, bool is_small_c,
            int min_block_size = 0, bool nc_order = true,
            bool force_default_c_blk = false) {
        bool is_small_c_input
                = (type.size() <= 2 && is_input && g == 1 && is_small_c);
        auto default_c_blk = type.size() == 1 ? 32 : 16;
        auto c_block = [&]() {
            if (force_default_c_blk) return default_c_blk;
            // Special case for small input channel shapes with dpas.
            if (is_small_c_input) {
                int packed_dword_elems = 4 / type.size();
                return std::max(packed_dword_elems, utils::rnd_up_pow2(c));
            }
            auto blk_dim = is_dw ? g : g * c;
            return pick_block_rnd_up(blk_dim, default_c_blk);
        }();

        // Non-depthwise convolutions currently require the channel count to
        // be a multiple of c_block. If that implementation restriction is
        // removed, this logic can be removed.
        if (g > 1 && !is_dw && c % c_block != 0 && c_block % c != 0)
            c_block = 1;

        auto default_n_blk = type.size() < 4 ? 32 : 16;
        auto n_block = [&]() {
            if (c_block == 1)
                return 1;
            else if (is_small_c_input)
                return pick_block(n, 8, 16);
            else
                return pick_block(n, 16, default_n_blk);
        }();

        // Require a minimum block size; this is used to enable better
        // message behavior.
        while (n_block * c_block * type.size() < min_block_size) {
            // Prefer increasing blocks in dimensions with available data, and
            // otherwise just increase c_block to meet requirements. Limit
            // blocking dimensions to avoid untested edge cases.
            if (c_block < c && c_block < default_c_blk)
                c_block *= 2;
            else if (n_block < n && n_block < default_n_blk)
                n_block *= 2;
            else {
                c_block = utils::div_up(min_block_size, type.size() * n_block);
                if (c_block > default_c_blk) c_block = default_c_blk;
                break;
            }
        }

        return nc_block_t(n_block, c_block, nc_order);
    }

private:
    int n_block_;
    int c_block_;
    bool nc_order_;
};

struct goi_block_t {
    goi_block_t(fma_kind_t fma_kind, bool is_dw, bool is_bwd_d, int g_block,
            int o_block, int i_block, int o_block_outer, int i_block_outer)
        : fma_kind_(fma_kind)
        , is_dw_(is_dw)
        , is_bwd_d_(is_bwd_d)
        , g_block_(g_block)
        , o_block_(o_block)
        , i_block_(i_block)
        , o_block_outer_(o_block_outer)
        , i_block_outer_(i_block_outer) {}

    std::string tag() const {
        std::vector<char> wei_letters(3, ' ');
        char wei_letter = 'a';
        for (int i = (is_dw_ ? 0 : 1); i < 3; i++) {
            wei_letters[i] = wei_letter++;
        }
        std::vector<int> wei_idxs = {0, 1, 2}; // g, ic, oc
        // dpas requires ic to go before oc in innermost blocks for weights.
        if (fma_kind_ != fma_kind_t::mad) std::swap(wei_idxs[1], wei_idxs[2]);
        if (is_bwd_d_) std::swap(wei_idxs[1], wei_idxs[2]);
        return build_tag({g_block_, o_block_, i_block_},
                {1, o_block_outer_, i_block_outer_}, wei_letters, wei_idxs);
    }

    static goi_block_t get_default_blocking(type_t type, int vec_size,
            fma_kind_t fma_kind, bool is_bwd_d, bool is_small_ic, int g, int o,
            int i) {
        int x = o;
        int y = i;
        int g_block = 1;
        int o_block = 1;
        int i_block = 1;
        int o_block_outer = 1;
        int i_block_outer = 1;
        int *x_block = &o_block;
        int *y_block = &i_block;
        int *x_block_outer = &o_block_outer;
        int *y_block_outer = &i_block_outer;
        // Backward by data requires flipped ic/oc in weights.
        if (is_bwd_d) {
            std::swap(x, y);
            std::swap(x_block, y_block);
            std::swap(x_block_outer, y_block_outer);
        }
        get_default_blocking(type, vec_size, fma_kind, is_bwd_d, is_small_ic, g,
                x, y, g_block, *x_block, *y_block, *x_block_outer,
                *y_block_outer);
        return goi_block_t(fma_kind, is_dw(g, o, i), is_bwd_d, g_block, o_block,
                i_block, o_block_outer, i_block_outer);
    }

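    // Illustrative outcome (assuming s8 weights, vec_size = 16 and a dpas
    // FMA): the branch below yields x_block = 16 and y_block = 4 (elements
    // packed per dword), plus y_block_outer = 8 unless ic is small, which
    // corresponds to an OIhw8i16o4i-like tag.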
    static void get_default_blocking(type_t type, int vec_size,
            fma_kind_t fma_kind, bool is_bwd_d, bool is_small_ic, int g, int x,
            int y, int &g_block, int &x_block, int &y_block, int &x_block_outer,
            int &y_block_outer) {
        if (is_dw(g, x, y)) {
            g_block = type.is_x8() ? 32 : 16;
        } else if (fma_kind == fma_kind_t::mad) {
            x_block = vec_size;
            y_block = pick_block(y, 8, 16);
        } else {
            int packed_dword_elems = 4 / type.size();
            x_block = vec_size;
            y_block = packed_dword_elems;
            if (is_bwd_d || !is_small_ic) y_block_outer = 8;
        }
    }

private:
    static bool is_dw(int g, int o, int i) {
        return (g > 1 && o == 1 && i == 1);
    }

    fma_kind_t fma_kind_;
    bool is_dw_;
    bool is_bwd_d_;
    int g_block_;
    int o_block_;
    int i_block_;
    int o_block_outer_;
    int i_block_outer_;
};

// TODO: Remove this logic and switch to an IR generation-driven flow.
bool can_use_2d_send(const conv_config_t &cfg, const layout_t &l, bool is_a) {
    const auto &prb = cfg.prb();
    bool is_b = !is_a;
    if (!cfg.is_ge_xe_hpc()) return false;

    bool with_blocking
            = !cfg.iter_dims().is_empty() || !cfg.loop_dims().is_empty();

    // Can't use 2D block messages for non-trivial strided dimensions.
    if (is_a && (prb.is_fwd || prb.is_bwd_w) && prb.sw != 1
            && (prb.kw != 1 || prb.pw != 0)) {
        if (with_blocking) {
            if (cfg.iter_dim({"osp", "ow", "iw"}) > 1) return false;
        } else if (prb.mb < 16) {
            return false;
        }
    }
    if (is_a && prb.is_bwd_d && prb.sw != 1) {
        if (with_blocking && cfg.iter_dim({"osp", "ow", "iw"}) > 1) {
            return false;
        } else if (prb.mb < 16) {
            return false;
        }
    }

    // Can't use 2D block messages for compound blocks.
    if (is_a && with_blocking) {
        bool has_mb_block = (cfg.iter_dim("mb") > 1);
        bool has_sp_block = (cfg.iter_dim({"osp", "ow", "iw"}) > 1);
        if (has_mb_block && has_sp_block) return false;
    }

    // 2D block messages do not support the VNNI format with 4-byte elements.
    if (type_t(prb.b_data_type).size() >= 4) return false;

    auto is_plain_ok = [&]() {
        if (is_a || prb.is_bwd_w) return matches_tag_strict(l, "axb");
        if (is_b && l.is_empty()) return true;
        if (is_b && prb.is_fwd) return matches_tag_strict(l, "xba");
        if (is_b && prb.is_bwd_d) return matches_tag_strict(l, "xab");
        return false;
    };

    if (!is_plain_ok()) return false;

    // Check 2D block message limitations.
    // Layouts for A:
    // FWD: NHWC (src)
    // BWD_D: NHWC (dst)
    // BWD_W: NHWC (src)
    // Layouts for B:
    // FWD: HWIO (wei)
    // BWD_D: HWOI (wei)
    // BWD_W: NHWC (dst)
    int a_width = (prb.is_fwd || prb.is_bwd_w) ? prb.ic : prb.oc;
    int b_width = (prb.is_fwd || prb.is_bwd_w) ? prb.oc : prb.ic;
    int a_max_height
            = std::max((prb.is_fwd || prb.is_bwd_w) ? prb.iw : prb.ow, prb.mb);
    int b_max_height = prb.is_fwd
            ? prb.ic
            : (prb.is_bwd_d ? prb.oc : std::max(prb.ow, prb.mb));
    int a_max_pitch = (prb.is_fwd || prb.is_bwd_w) ? (prb.ic * prb.isp)
                                                   : (prb.oc * prb.osp);
    int b_max_pitch
            = (prb.is_fwd || prb.is_bwd_d) ? b_width : (prb.oc * prb.osp);
    int data_type_size = (is_a ? prb.a_data_type_size : prb.b_data_type_size);
    int width = (is_a ? a_width : b_width);
    int max_height = (is_a ? a_max_height : b_max_height);
    int max_pitch = (is_a ? a_max_pitch : b_max_pitch);
    if (!block_2d_width_ok(width, data_type_size)) return false;
    if (!block_2d_height_ok(max_height)) return false;
    if (!block_2d_pitch_ok(cfg.hw_cfg(), width, data_type_size)) return false;
    if (!block_2d_pitch_ok(cfg.hw_cfg(), max_pitch, data_type_size))
        return false;
    return true;
}

void init_data_tags(const conv_config_t &cfg, bool allow_src_reorder,
        bool allow_wei_reorder, bool allow_dst_reorder,
        const memory_desc_t &src_md, const memory_desc_t &wei_md,
        const memory_desc_t &dst_md,
        std::string &src_tag, std::string &wei_tag, std::string &dst_tag,
        std::string &user_wei_tag) {
    const auto &prb = cfg.prb();
    auto src_compute_type = prb.is_bwd_d ? prb.c_data_type : prb.a_data_type;
    auto dst_compute_type = prb.is_fwd
            ? prb.c_data_type
            : (prb.is_bwd_d ? prb.a_data_type : prb.b_data_type);
    auto wei_compute_type = prb.is_bwd_w ? prb.c_data_type : prb.b_data_type;

    int src_type_size = (int)types::data_type_size(src_compute_type);

    // Prefer larger messages for large mb bwd_w.
    bool is_bwd_w_message_opt = prb.is_bwd_w && src_type_size <= 2
            && allow_src_reorder && prb.mb >= 16;
    int min_block_size = is_bwd_w_message_opt ? 128 : 0;
    bool nc_order = !is_bwd_w_message_opt;

    nc_block_t src_blk = nc_block_t::get_default_blocking(src_compute_type,
            prb.is_dw, prb.mb, prb.ic, prb.g, prb.is_fwd || prb.is_bwd_w,
            is_small_ic(prb), min_block_size, nc_order);
    // TODO: Force use of default_c_blk for bwd_w with bias due to reduction
    // limitation to register granularity.
    nc_block_t dst_blk = nc_block_t::get_default_blocking(dst_compute_type,
            prb.is_dw, prb.mb, prb.oc, prb.g, prb.is_bwd_d || prb.is_bwd_w,
            is_small_oc(prb), 0, true, prb.is_bwd_w && prb.with_bias);

    auto wei_blk = goi_block_t::get_default_blocking(wei_compute_type,
            cfg.vec_size(), cfg.fma_kind(), prb.is_bwd_d, is_small_ic(prb),
            prb.g, prb.oc, prb.ic);

    src_tag = src_blk.tag();
    wei_tag = wei_blk.tag();
    dst_tag = dst_blk.tag();

    // Use OhwIXoYi weights for small-channel forward convolution to ensure
    // c-after-w order of reduction blocks to match the source layout.
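    // Illustrative example: a blocked weights tag such as "ABx8a2b" is
    // rewritten to "AxB8a2b", moving the spatial dims in between oc and ic.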
    if (is_small_ic(prb) && !prb.is_dw && prb.is_fwd && cfg.is_dp_fma()) {
        const char *patterns[] = {"ABx", "AxB", "Abx", "Axb", nullptr};
        bool found = false;
        for (auto *p = patterns; *p; p += 2) {
            auto pos = wei_tag.find(*p);
            if (pos == std::string::npos) continue;
            wei_tag = wei_tag.replace(pos, std::strlen(*p), *(p + 1));
            found = true;
            break;
        }
        ir_assert(found) << wei_tag;
    }

    // Align weights layout between forward/backward by data in some cases via
    // internal reorder to eliminate user-side reorder.
    auto fwd_wei_blk = goi_block_t::get_default_blocking(wei_compute_type,
            cfg.vec_size(), cfg.fma_kind(), /*is_bwd_d=*/false,
            is_small_ic(prb), prb.g, prb.oc, prb.ic);
    auto fwd_wei_tag = fwd_wei_blk.tag();
    if (fwd_wei_tag != wei_tag && allow_wei_reorder) {
        user_wei_tag = fwd_wei_tag;
    }

    // Override compute layouts when using nhwc with block 2D messages.
    bool a_2d_ok = can_use_2d_send(cfg, make_layout(prb.a_md()), true);
    bool b_2d_ok = can_use_2d_send(cfg, make_layout(prb.b_md()), false);
    if (a_2d_ok && b_2d_ok) {
        if (prb.is_bwd_d && !is_small_ic(prb)) {
            wei_tag = "xab";
        } else {
            wei_tag = "xba";
        }
        user_wei_tag = "xba";
    }

    // Override compute layouts for nhwc case.
    bool src_matches = matches_tag(src_md, src_tag);
    bool dst_matches = matches_tag(dst_md, dst_tag);
    bool src_axb = matches_tag(src_md, "axb");
    bool dst_axb = matches_tag(dst_md, "axb");
    if (src_axb && dst_axb && (!src_matches || !dst_matches)) {
        if (!allow_src_reorder) src_tag = "axb";
        if (!allow_dst_reorder) dst_tag = "axb";
    }

    // Override compute layouts for plain outputs.
    if (prb.is_fwd && dst_axb) dst_tag = "axb";
    if (prb.is_bwd_d && src_axb) src_tag = "axb";
}

status_t init_tensor_layouts(conv_config_t &cfg, convolution_pd_t *pd) {
    const auto &prb = cfg.prb();
    // Compute layout tags and user layout tags. If a compute layout is
    // different from a user layout then an extra pre/post reorder will be
    // executed before/after convolution.
    std::string src_tag, user_src_tag;
    std::string wei_tag, user_wei_tag;
    std::string dst_tag, user_dst_tag;

    auto &src_md = *pd->invariant_src_md();
    auto &wei_md = *pd->invariant_wei_md();
    auto &dst_md = *pd->invariant_dst_md();
    auto &bia_md = *pd->invariant_bia_md();

    // If src or dst is nhwc and the other one is format_kind::any, set it to
    // nhwc too (except for the 1st convolution).
    bool is_small_ic_non_dw = is_small_ic(prb) && prb.g == 1;
    bool is_small_oc_non_dw = is_small_oc(prb) && prb.g == 1;
    bool propagate_nhwc = (matches_tag(src_md, "axb") && !is_small_ic_non_dw)
            || matches_tag(dst_md, "axb");
    if (propagate_nhwc) {
        set_default_format(src_md, "axb");
        set_default_format(dst_md, "axb");
    }

    bool allow_src_reorder = false;
    // Allow internal weights reorder in some cases. The goal is to have
    // aligned weights layouts between fwd/bwd_d/bwd_w to reduce potential
    // weights reorders during training. In general it's more efficient than
    // the external reorder.
    bool allow_wei_reorder = cfg.is_ge_xe_hpc() && cfg.is_dp_fma();
    bool allow_dst_reorder = false;
    bool src_abx = matches_tag(src_md, "abx");
    bool src_axb = matches_tag(src_md, "axb");
    if ((src_abx || src_axb) && (prb.is_fwd || prb.is_bwd_w)
            && is_small_ic_non_dw) {
        allow_src_reorder = true;
    }

    init_data_tags(cfg, allow_src_reorder, allow_wei_reorder, allow_dst_reorder,
            src_md, wei_md, dst_md, src_tag, wei_tag, dst_tag, user_wei_tag);

    if (allow_src_reorder) {
        if (src_abx) user_src_tag = "abx";
        if (src_axb) user_src_tag = "axb";
    }

    // Prefer nhwc for small-channel inputs.
    if (user_src_tag.empty() && prb.is_fwd && is_small_ic_non_dw) {
        if (!matches_tag(src_md, src_tag)) user_src_tag = "axb";
    }
    if (user_dst_tag.empty() && prb.is_bwd_d && is_small_oc_non_dw) {
        if (!matches_tag(dst_md, dst_tag)) user_dst_tag = "axb";
    }

    // Allow internal reorder from oihw/ohwi to more optimal weights layout.
    if (allow_wei_reorder) {
        if (matches_tag(wei_md, "abx")) user_wei_tag = "abx";
        if (matches_tag(wei_md, "axb")) user_wei_tag = "axb";
    }

    if (user_src_tag.empty()) user_src_tag = src_tag;
    if (user_wei_tag.empty()) user_wei_tag = wei_tag;
    if (user_dst_tag.empty()) user_dst_tag = dst_tag;

    bool wei_prepend_groups = (prb.with_groups && !prb.is_dw);
    if (wei_prepend_groups) {
        wei_tag = prepend_groups_to_tag(wei_tag);
        user_wei_tag = prepend_groups_to_tag(user_wei_tag);
    }

    // Select user layouts.
    auto user_src_layout = init_layout(src_md, user_src_tag);
    auto user_wei_layout = init_layout(wei_md, user_wei_tag);
    auto user_dst_layout = init_layout(dst_md, user_dst_tag);

    layout_t user_bia_layout;
    if (prb.with_bias) user_bia_layout = init_layout(bia_md, "a");

    if (!user_src_layout.is_strictly_equal(make_layout(src_md, user_src_tag)))
        return status::unimplemented;
    if (!user_dst_layout.is_strictly_equal(make_layout(dst_md, user_dst_tag)))
        return status::unimplemented;
    if (!user_wei_layout.is_strictly_equal(make_layout(wei_md, user_wei_tag)))
        return status::unimplemented;

    auto src_layout = (src_tag != user_src_tag) ? make_layout(src_md, src_tag)
                                                : user_src_layout;
    auto wei_layout = (wei_tag != user_wei_tag) ? make_layout(wei_md, wei_tag)
                                                : user_wei_layout;
    auto dst_layout = (dst_tag != user_dst_tag) ? make_layout(dst_md, dst_tag)
                                                : user_dst_layout;
    auto bia_layout = user_bia_layout;

    if (prb.is_bwd_w) {
        if (prb.wei_data_type == data_type::bf16)
            wei_layout = wei_layout.retype(type_t::f32());
        if (prb.bia_data_type == data_type::bf16)
            bia_layout = bia_layout.retype(type_t::f32());
    }

    auto &src = cfg.src_layout();
    auto &wei = cfg.wei_layout();
    auto &dst = cfg.dst_layout();
    auto &bia = cfg.bia_layout();

    src.set_compute_unnormalized(src_layout);
    src.set_user_unnormalized(user_src_layout);
    wei.set_compute_unnormalized(wei_layout);
    wei.set_user_unnormalized(user_wei_layout);
    dst.set_compute_unnormalized(dst_layout);
    dst.set_user_unnormalized(user_dst_layout);
    bia.set_compute_unnormalized(bia_layout);
    bia.set_user_unnormalized(user_bia_layout);

    // Normalize layouts: add group dimension for all layouts and reduce/fuse
    // spatial dimensions when applicable.
    normalize_conv_layouts(src_layout, wei_layout, dst_layout, bia_layout,
            prb.with_groups, prb.g, prb.ic, prb.oc, prb.is_dw, prb.reduced_dim,
            /*fuse_spatial=*/false,
            /*add_groups=*/true);
    normalize_conv_layouts(user_src_layout, user_wei_layout, user_dst_layout,
            user_bia_layout, prb.with_groups, prb.g, prb.ic, prb.oc, prb.is_dw,
            prb.reduced_dim,
            /*fuse_spatial=*/false,
            /*add_groups=*/true);

    src.set_compute(src_layout);
    src.set_user(user_src_layout);
    wei.set_compute(wei_layout);
    wei.set_user(user_wei_layout);
    dst.set_compute(dst_layout);
    dst.set_user(user_dst_layout);
    bia.set_compute(bia_layout);
    bia.set_user(user_bia_layout);

    return status::success;
}

bool hw_ok(const hw_config_t &hw_cfg) {
    // Disable pre-XeHP until performance parity is reached with OpenCL
    // kernels.
    if (hw_cfg.hw() < ngen::HW::XeHP) return false;
    return true;
}

bool data_types_ok(const conv_problem_t &prb, const hw_config_t &hw_cfg) {
    auto src = prb.src_data_type;
    auto wei = prb.wei_data_type;
    auto dst = prb.dst_data_type;
    auto bia = prb.bia_data_type;
    bool is_bf16 = utils::one_of(data_type::bf16, src, wei, dst, bia);
    if (!prb.is_f64_conv() && utils::one_of(data_type::f64, src, wei, dst, bia))
        return false;
    if (is_bf16 && hw_cfg.hw() <= ngen::HW::XeLP) return false;
    if (prb.is_f64_conv()
            && utils::one_of(hw_cfg.hw(), ngen::HW::XeLP, ngen::HW::XeHPG))
        return false;
    if (prb.is_fwd) return true;
    if (prb.is_bwd_d) return true;
    if (prb.is_bwd_w) {
        bool ok = true;
        data_type_t default_acc_type
                = src == data_type::f64 ? data_type::f64 : data_type::f32;
        ok &= utils::one_of(
                src, data_type::bf16, data_type::f32, data_type::f64);
        ok &= (dst == src);
        ok &= utils::one_of(wei, src, default_acc_type);

        if (prb.with_bias) { ok &= utils::one_of(bia, src, data_type::f32); }
        return ok;
    }
    return false;
}

bool zero_points_ok(const conv_problem_t &prb) {
    auto *pd = prb.conv_pd;
    auto *attr = pd->attr();

    // TODO: implement the rest of the cases and remove this 'if'
    bool ic_kdhw
            = (prb.ic <= 8) && (prb.kd * prb.kh * prb.kw > 1) && !prb.is_dw;
    if (!attr->zero_points_.has_default_values(DNNL_ARG_SRC) && ic_kdhw)
        return false;

    using namespace data_type;
    const auto input_type = (prb.is_fwd) ? pd->invariant_src_md()->data_type
                                         : pd->invariant_dst_md()->data_type;
    int mask_src = 0, mask_dst = 0;
    attr->zero_points_.get(DNNL_ARG_SRC, &mask_src);
    attr->zero_points_.get(DNNL_ARG_DST, &mask_dst);

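    // Mask semantics: 0 means a common (scalar) zero point, 1 << 1 means
    // per-channel zero points along dimension 1 (channels).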
    return IMPLICATION(!utils::one_of(input_type, s8, u8),
                   attr->zero_points_.has_default_values())
            && attr->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
            && (mask_src == 0 || mask_src == 1 << 1)
            && (mask_dst == 0 || mask_dst == 1 << 1);
}

std::vector<int> get_scale_args(const conv_problem_t &prb) {
    conv_arg_helper_t h(prb);
    std::vector<int> ret = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
    return ret;
}

bool post_ops_ok(const conv_problem_t &prb, const hw_config_t &hw_cfg) {
    auto *attr = prb.attr;

    // No post-ops are supported for f64.
    if (prb.is_f64_conv() && !attr->has_default_values()) return false;

    if (prb.is_fwd || prb.is_bwd_d) {
        using sm = primitive_attr_t::skip_mask_t;
        auto attr_skip_mask = sm::post_ops | sm::sum_dt
                | sm::zero_points_runtime | sm::scales_runtime;
        if (!attr->has_default_values(attr_skip_mask)) return false;
    } else {
        if (!attr->has_default_values()) return false;
    }

    if (!attr->scales_.has_default_values() && !prb.is_s32_accumulator())
        return false;
    auto scale_args = get_scale_args(prb);
    if (!attr->scales_.has_default_values(scale_args)) return false;
    for (int arg : scale_args) {
        int mask = attr->scales_.get(arg).mask_;
        // XXX: per_oc for BWD_D is treated as per_ic assuming it's called from
        // deconvolution.
        int c_idx = prb.with_groups;
        if (arg == DNNL_ARG_WEIGHTS) {
            if (!utils::one_of(mask, 0, 1 << c_idx)) return false;
        } else {
            if (mask != 0) return false;
        }
    }

    for (int i = 0; i < attr->post_ops_.len(); i++) {
        auto &po = attr->post_ops_.entry_[i];
        if (po.is_eltwise()) {
            if (!jit_eltwise_injector_f32_is_supported(po.eltwise.alg))
                return false;
            else if (po.eltwise.alg == alg_kind::eltwise_tanh
                    && hw_cfg.hw() == ngen::HW::XeHPG
                    && hw_cfg.eu_count() <= 128)
                // Workaround for a hard-to-reproduce issue in end-to-end
                // workloads. It is unclear what the actual issue is, as the
                // kernel always works correctly in benchdnn.
                return false;
        }
    }
    return true;
}

const memory_desc_t *output_md(const convolution_pd_t *pd) {
    if (pd->is_fwd()) return pd->dst_md();
    if (pd->is_bwd_d()) return pd->diff_src_md();
    if (pd->is_bwd_w()) return pd->diff_weights_md();
    ir_error_not_expected();
    return nullptr;
}

void maybe_override_from_lookup_table(conv_config_t &cfg) {
    static conv_config_lookup_table_t table;
    auto *s_params = table.find(cfg);
    if (s_params) cfg.override_set(s_params);
}

void maybe_override_from_env(conv_config_t &cfg) {
    auto cfg_env = ir_utils::getenv_str("cfg", "");
    if (cfg_env.empty()) return;
    cfg.override_set(cfg_env.c_str());
}

void maybe_override(conv_config_t &cfg) {
    maybe_override_from_lookup_table(cfg);
#ifdef GEN_CONV_DEBUG
    maybe_override_from_env(cfg);
#endif
}

status_t init_fma_kind(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto fma_kind = fma_kind::get_supported_kind(
            cfg.hw(), prb.a_data_type, prb.b_data_type, prb.acc_data_type);
    // Force mad for some cases.
    if (prb.is_dw || (prb.g > 1 && prb.ic < 4 && prb.oc < 4 && prb.mb < 8))
        fma_kind = fma_kind_t::mad;
    if (fma_kind == fma_kind_t::unknown) return status::unimplemented;
    cfg.set_fma_kind(fma_kind);
    return status::success;
}

status_t init_simd(conv_config_t &cfg) {
    if (cfg.exec_cfg_param().is_overridden("simd")) return status::success;

    const auto &prb = cfg.prb();
    int simd = fma_kind::get_simd_size(cfg.hw(), cfg.fma_kind(),
            prb.a_data_type, prb.b_data_type, prb.acc_data_type);
    cfg.set_simd(simd);
    return status::success;
}

status_t init_vec_size(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    int vec_size = cfg.simd();
    if (cfg.fma_kind() == fma_kind_t::mad) {
        int grf_elems = cfg.grf_size() / prb.acc_data_type_size;
        int vec_dim = (prb.is_fwd || prb.is_bwd_w) ? prb.oc : prb.ic;
        if (vec_size > grf_elems && vec_dim <= 8) vec_size = grf_elems;
    }
    cfg.set_vec_size(vec_size);
    return status::success;
}

bool post_op_layouts_ok(const conv_problem_t &prb) {
    auto *pd = prb.conv_pd;
    auto *attr = pd->attr();

    for (int i = 0; i < attr->post_ops_.len(); i++) {
        auto &po = attr->post_ops_.entry_[i];
        if (po.is_binary() || po.is_prelu()) {
            int mask = po.is_prelu()
                    ? po.prelu.mask
                    : utils::get_dims_mask(pd->invariant_dst_md()->dims,
                            po.binary.src1_desc.dims, prb.ndims, true);
            // These cases don't have message-related limitations.
            if ((mask & (1 << 1)) == 0 || mask == (1 << 1)) continue;
            auto rhs_layout = po.is_prelu() ? layout_t(type_t::f32(), 0,
                                      get_prelu_weights_dims(po.prelu.mask,
                                              *pd->invariant_dst_md()))
                                            : layout_t(po.binary.src1_desc);
            // No blocks means it's a scalar, which can always be loaded.
            if (rhs_layout.blocks().empty()) return true;

            auto rhs0 = rhs_layout.blocks()[0];
            // Innermost block must:
            // - be across output channels
            // - be dense
            if (rhs0.dim_idx != 1 || dim_t(rhs0.stride) != 1) return false;
        }
    }
    return true;
}

status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg,
        const engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr) {
    hw_config_t hw_cfg(engine);

    if (!hw_ok(hw_cfg)) return status::unimplemented;
    if (!data_types_ok(prb, hw_cfg)) return status::unimplemented;
    if (!post_ops_ok(prb, hw_cfg)) return status::unimplemented;
    if (!zero_points_ok(prb)) return status::unimplemented;

    cfg.set_prb(prb);
    cfg.set_exec_cfg(exec_config_t(hw_cfg));

    maybe_override(cfg);

    CHECK(init_fma_kind(cfg));
    CHECK(init_simd(cfg));
    CHECK(init_vec_size(cfg));
    CHECK(init_tensor_layouts(cfg, pd));

    CHECK(attr->set_default_formats(&prb.c_md()));

    if (!post_op_layouts_ok(prb)) return status::unimplemented;

    return status::success;
}

void init_hint(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (prb.is_fwd && is_small_ic(prb)) {
        int max_tg = 16;
        auto hint = cfg.hint();
        if (hint.max_tg_size() > max_tg) {
            hint.set_max_tg_size(max_tg);
            cfg.set_hint(hint);
        }
    }
}

void init_pipeline(conv_config_t &cfg) {
    if (cfg.pipeline().is_overridden()) return;

    const auto &prb = cfg.prb();
    bool do_unroll = true;
    if (prb.is_fwd) {
        const int max_unroll = 9;
        if (prb.ksp > max_unroll) do_unroll = false;
        if (is_small_ic(prb)) do_unroll = false;
    } else if (prb.is_bwd_d) {
        // Do not perform full unrolling when there are too many inner
        // iterations.
        int kernel_limit = prb.is_f32_conv() ? 4 : 9;
        if (prb.ksp > kernel_limit) do_unroll = false;

        // Do not perform full unrolling with non-unit stride unless special
        // stride optimization is enabled. These cases have non-trivial
        // post-increment updates which result in unrolling all reduction loops
        // and exceeding the instruction cache.
        if (!prb.is_stride1() && !cfg.bwd_d_optimize_strided_iw())
            do_unroll = false;
    } else if (prb.is_bwd_w) {
        int mb_iter_blk = cfg.iter_dim("mb");
        do_unroll = (cfg.is_ge_xe_hpc() && cfg.is_dp_fma() && mb_iter_blk > 1);
    }
    // Unrolling with mad or dp4a results in overly large kernels.
    if (utils::one_of(cfg.fma_kind(), fma_kind_t::mad, fma_kind_t::dp4a)
            && (cfg.hw() >= ngen::HW::XeHPG || prb.mb != 1))
        do_unroll = false;
    cfg.pipeline().set(do_unroll);
}

void init_send_2d_nhwc(conv_config_t &cfg) {
    const auto &prb = cfg.prb();

    bool a_ok = can_use_a_2d_send(cfg);
    bool b_ok = can_use_b_2d_send(cfg);

    int64_t est_threads = 1;
    est_threads *= prb.g;
    est_threads *= prb.ic;
    est_threads *= prb.ksp;
    est_threads *= prb.mb;
    est_threads *= prb.oc;
    est_threads *= prb.osp;

    // Estimated max reduction size per thread for BWD_W.
    const int bwd_w_max_k_per_thr = 1000;
    // Estimated M x N elements per thread.
    const int mn_per_thr = 16 * 16;
    // Crossover point to enable 2D send and blocking.
    const int min_threads_to_enable_2d = 1024;

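    // Rough parallelism estimate: the total M x N x K work is divided by a
    // per-thread M x N tile (mn_per_thr) and a per-thread reduction size (k).
    // The constants above are illustrative heuristics, not hardware limits.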
    int k_fwd = prb.ic;
    int k_bwd_d = prb.oc;
    int k_bwd_w = std::min(bwd_w_max_k_per_thr, prb.mb * prb.osp);
    int k = prb.pick_by_dir(k_fwd, k_bwd_d, k_bwd_w);
    est_threads /= mn_per_thr;
    est_threads /= k;

    if (est_threads < min_threads_to_enable_2d) {
        cfg.set_send_2d_nhwc(false);
        return;
    }

    cfg.set_send_2d_nhwc(a_ok && b_ok);
}

void init_fuse_spatial(conv_config_t &cfg) {
    if (cfg.fuse_spatial_param().is_overridden()) return;

    const auto &prb = cfg.prb();
    if (!prb.is_fwd || is_small_ic(prb)) return;

    // Spatial fusion may be suboptimal for small batch due to:
    // - Using smaller messages (load blocks are not fully dense anymore)
    // - Extra division arithmetic to work with fused indices
    if (cfg.src_layout().compute().inner_block(0) == 1) {
        if (!prb.is_fwd || cfg.is_ge_xe_hpc()) return;
        // Enable fusion only for cases without an m block and with an
        // overwhelming spatial dim.
        if (prb.is_int8_dst() || (prb.osp < 4096)
                || !(prb.oh == prb.ow && prb.ow == prb.od)) {
            return;
        }
    }

    cfg.set_fuse_spatial(true);
}

void init_hoist_masks_from_compute_loop(conv_config_t &cfg) {
    if (cfg.send_2d_nhwc()) {
        cfg.set_hoist_masks_from_compute_loop(true);
        return;
    }
    if (!cfg.fuse_spatial()) return;
    if (cfg.hw() < ngen::HW::XeHPC) return;

    // Both nhwc layouts and mask hoisting require extra GRF memory so avoid
    // enabling both.
    if (matches_tag(cfg.a_layout().compute_unnormalized(), "axb")) return;

    cfg.set_hoist_masks_from_compute_loop(true);
}

void init_ow_kw_grf_cache(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!prb.is_fwd || !is_small_ic(prb) || prb.kw < 3 || is_dw_large_mb(prb))
        return;
    if (cfg.is_dp_fma()) return;
    if (cfg.fuse_spatial()) return;

    const int iw_blk_limit = 40;
    const int max_ow_blk = 16;
    int max_iw_blk
            = (prb.sw * (max_ow_blk - 1) + (prb.kw - 1) * (1 + prb.dw) + 1);
    if (max_iw_blk > iw_blk_limit) return;

    cfg.set_ow_kw_grf_cache(true);
}

void init_common_blocking(conv_config_t &cfg, block_helper_t &bh) {
    const auto &prb = cfg.prb();

    auto &src_layout = cfg.src_layout().compute();
    auto &wei_layout = cfg.wei_layout().compute();
    auto &dst_layout = cfg.dst_layout().compute();

    bh.set_hw_config(cfg.hw_cfg());
    bh.set_fma_kind(cfg.fma_kind());
    bh.set_simd_size(cfg.simd());
    bh.set_vec_size(cfg.vec_size());
    bh.set_max_tg_size(cfg.hint().max_tg_size());
    bh.set_max_tg_overridden(cfg.hint().max_tg_overridden());
    bh.set_abc_types(prb.a_data_type, prb.b_data_type, prb.acc_data_type);

    bh.set_dim("mb", prb.mb);
    bh.set_dim("g", prb.g);
    bh.set_dim("oc", prb.oc);
    // Take blocked ic channels into account when selecting block sizes.
1243 | bh.set_dim("ic" , |
1244 | prb.is_bwd_w ? std::max(src_layout.dims()[2], wei_layout.dims()[2]) |
1245 | : prb.ic); |
1246 | bh.set_dims({"kd" , "kh" , "kw" }, {prb.kd, prb.kh, prb.kw}); |
1247 | |
1248 | bh.set_b_dims({"g" }); |
1249 | |
1250 | if (prb.is_fwd) { |
1251 | if (cfg.fuse_spatial()) { |
1252 | bh.set_dims({"osp" }, {prb.osp}); |
1253 | bh.set_m_dims({"mb" , "osp" }); |
1254 | } else { |
1255 | bh.set_dims({"od" , "oh" , "ow" }, {prb.od, prb.oh, prb.ow}); |
1256 | bh.set_m_dims({"mb" , "od" , "oh" , "ow" }); |
1257 | } |
1258 | bh.set_n_dims({"oc" }); |
1259 | bh.set_k_dims({"ic" , "kd" , "kh" , "kw" }); |
1260 | } else if (prb.is_bwd_d) { |
1261 | ir_assert(!cfg.fuse_spatial()); |
1262 | bh.set_dims({"id" , "ih" , "iw" }, {prb.id, prb.ih, prb.iw}); |
1263 | bh.set_m_dims({"mb" , "id" , "ih" , "iw" }); |
1264 | bh.set_n_dims({"ic" }); |
1265 | bh.set_k_dims({"oc" , "kd" , "kh" , "kw" }); |
1266 | } else if (prb.is_bwd_w) { |
1267 | ir_assert(!cfg.fuse_spatial()); |
1268 | bh.set_dims({"od" , "oh" , "ow" }, {prb.od, prb.oh, prb.ow}); |
1269 | bh.set_m_dims({"ic" , "kd" , "kh" , "kw" }); |
1270 | bh.set_n_dims({"oc" }); |
1271 | bh.set_k_dims({"mb" , "od" , "oh" , "ow" }); |
1272 | } else { |
1273 | ir_error_not_expected(); |
1274 | } |
1275 | |
1276 | for (auto &kv : bh.dims()) { |
1277 | bh.set_pad_block(kv.first, cfg.pad_block(kv.first)); |
1278 | } |
1279 | |
    // Set base blocks to align kernel blocking with layout blocking.
    if (prb.is_fwd) {
        bh.set_base_iter_block("mb", src_layout.inner_block(0));
        int src_g_blk = prb.is_dw ? src_layout.inner_block(1) : 1;
        int wei_g_blk = prb.is_dw ? wei_layout.inner_block(0) : 1;
        bh.set_base_iter_block("g", src_g_blk, wei_g_blk);
        int src_ic_blk = src_layout.inner_block(2);
        int wei_ic_blk = wei_layout.inner_block(2);
        bh.set_base_iter_block("ic", src_ic_blk, wei_ic_blk);
        if (cfg.is_g_mad()) {
            bh.set_base_iter_block(
                    "oc", dst_layout.inner_block(2), wei_layout.inner_block(1));
            bh.dim("oc").set_iter_dim(bh.dim("oc").base_iter_block());
        }
    } else if (prb.is_bwd_d) {
        bh.set_base_iter_block("mb", dst_layout.inner_block(0));
        int dst_g_blk = dst_layout.inner_block(1);
        int wei_g_blk = wei_layout.inner_block(0);
        bh.set_base_iter_block("g", dst_g_blk, wei_g_blk);
        int dst_oc_blk = dst_layout.inner_block(2);
        int wei_oc_blk = wei_layout.inner_block(1);
        bh.set_base_iter_block("oc", dst_oc_blk, wei_oc_blk);
    } else if (prb.is_bwd_w) {
        bh.set_base_iter_block("g", wei_layout.inner_block(0));
        int wei_oc_blk = wei_layout.inner_block(2);
        int dst_oc_blk = dst_layout.inner_block(2);
        bh.set_base_iter_block("oc", wei_oc_blk, dst_oc_blk);
        int src_ic_blk = src_layout.inner_block(2);
        int wei_ic_blk = wei_layout.inner_block(2);
        bh.set_base_iter_block("ic", src_ic_blk, wei_ic_blk);
        int src_mb_blk = src_layout.inner_block(0);
        int dst_mb_blk = dst_layout.inner_block(0);
        bh.set_base_iter_block("mb", src_mb_blk, dst_mb_blk);
    }
}

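// Prefer blocking the dimension that wastes less padding when rounded up to
// a 16-wide block. Illustrative example: mb = 20 pads to 32 (ratio 0.625)
// while ow = 56 pads to 64 (ratio 0.875), so spatial blocking wins.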
bool should_use_spatial_blocking(const conv_config_t &cfg,
        dim_value_t mb_max_iter_dim, int d, int h, int w) {
    const auto &prb = cfg.prb();
    if (!cfg.is_ge_xe_hpc()) return true;
    if (mb_max_iter_dim == 1) return true;
    if (cfg.send_2d_nhwc() && prb.is_bwd_d && prb.sw != 1) return false;
    int sp = (prb.ksp == 1 && prb.is_fwd) ? (d * h * w) : w;
    int block = 16;
    double mb_ratio = (double)prb.mb / utils::rnd_up(prb.mb, block);
    double sp_ratio = (double)sp / utils::rnd_up(sp, block);
    return sp_ratio >= mb_ratio;
}

void init_fwd(conv_config_t &cfg, block_helper_t &bh) {
    using namespace ir_utils;

    const auto &prb = cfg.prb();
    const char *osp_name = cfg.fuse_spatial() ? "osp" : "ow";

    // Set the iteration block for cases with no m block and a large spatial
    // dim.
    if (!cfg.is_ge_xe_hpc() && cfg.src_layout().compute().inner_block(0) == 1
            && prb.mb > 1 && (prb.oh == prb.ow && prb.ow == prb.od)
            && prb.osp >= 512) {
        bh.set_base_iter_block(osp_name, 16);
    }

    if (cfg.ow_kw_grf_cache()) {
        bh.set_base_iter_block("mb", 1);
        bh.dim("mb").set_iter_dim(1);
        bh.set_max_iter_dim("mb", 1);
        bh.set_max_m_tg_dim(2);
        bh.set_max_n_tg_dim(2);
    }
    if (cfg.is_g_mad()) {
        bh.set_base_iter_block("mb", 1);
        bh.dim("mb").set_iter_dim(1);
        bh.set_max_iter_dim("mb", 1);
        bh.set_max_iter_dim(osp_name, 4);
    }
    bh.set_loop_dim("kd", prb.kd);
    bh.set_loop_dim("kh", prb.kh);
    if (is_small_ic(prb) && !is_dw_large_mb(prb)
            && (prb.g == 1 || prb.ic == prb.oc)) {
        bh.set_block_dims({"kw"});
    } else {
        bh.set_loop_dim("kw", prb.kw);
        // mad is not tested with thread group k-slicing.
        if (cfg.is_dp_fma()) {
            bh.allow_k_tg_slicing();
            bh.set_max_k_tg_dim(8);
        }
    }

    bh.set_block_dims({"g", "oc", "ic", "mb", osp_name});
    bh.set_vector_dim(prb.is_dw || cfg.is_g_mad() ? "g" : "oc");
    bh.allow_fuse({"ic", "kw"});
    bh.allow_split({"oc", "ic", "kw"});

    int mb_base_iter_blk = bh.dim("mb").base_iter_block();
    // mb blocking is always outer so we can safely use a smaller divisor to
    // have more flexible blocking for some cases.
    int mb_base_iter_divisor = is_dw_large_mb(prb) ? 32 : 8;
    mb_base_iter_blk = math::gcd(mb_base_iter_divisor, mb_base_iter_blk);

    bh.set_base_iter_block("mb", mb_base_iter_blk);

    bool use_sp_blocking = false;
    if (matches_tag(cfg.src_layout().compute_unnormalized(), "axb")) {
        use_sp_blocking = should_use_spatial_blocking(
                cfg, bh.max_iter_dim("mb"), prb.od, prb.oh, prb.ow);
    } else if (cfg.src_layout().compute().inner_block(0) == 1) {
        use_sp_blocking = true;
    } else if (prb.is_dw && !is_dw_large_mb(prb)) {
        use_sp_blocking = true;
    } else if (cfg.is_g_mad() || cfg.ow_kw_grf_cache()) {
        use_sp_blocking = true;
    }

    if (use_sp_blocking) {
        if (prb.is_dw) bh.set_pref_tg_block(osp_name);
        bh.allow_split({osp_name, "mb"});
        bh.reorder({osp_name, "mb"});
        if (!prb.is_int8_dst() && !cfg.fuse_spatial() && prb.mb < 16
                && prb.iw % 8 != 0 && !prb.is_dw) {
            bh.set_max_m_tg_dim(1);
        }
    } else {
        const int large_sp_threshold = cfg.is_ge_xe_hpc() ? 128 : 256;
        if (!prb.is_dw && prb.ow > large_sp_threshold)
            bh.set_pref_tg_block("oc");
        else if (cfg.is_dp_fma() && prb.mb >= 16)
            bh.set_pref_tg_block(osp_name);
        bh.reorder({"mb", osp_name});
        auto spatial_dim = cfg.fuse_spatial() ? prb.osp : prb.ow;
        if (!cfg.send_2d_nhwc() && prb.mb >= 128
                && (spatial_dim % 4 != 0 || spatial_dim < 64))
            bh.allow_split({"mb"});
    }

    if (prb.mb < 8 && !bh.any_pref_tg_block())
        bh.set_pref_tg_block(prb.ow > prb.oc ? osp_name : "oc");

    bh.reorder({"ic", "kw"});

    if (cfg.send_2d_nhwc()) {
        // Use 64-byte reduction step to avoid partial cache line loads.
        bh.set_base_iter_block("ic", 64 / prb.a_data_type_size);
        bh.set_reduce_m_block_hint(false);
    }

    bh.compute();
}

1429 | void init_bwd_d(conv_config_t &cfg, block_helper_t &bh) { |
1430 | using namespace ir_utils; |
1431 | |
1432 | const auto &prb = cfg.prb(); |
1433 | bh.set_loop_dim("kw" , prb.kw); |
1434 | bh.set_loop_dim("kd" , prb.kd); |
1435 | bh.set_loop_dim("kh" , prb.kh); |
1436 | bh.set_block_dims({"g" , "oc" , "ic" , "mb" , "iw" }); |
1437 | bh.set_vector_dim(prb.is_dw ? "g" : "ic" ); |
1438 | bh.allow_split({"oc" , "ic" }); |
1439 | |
1440 | bool use_w_blocking = false; |
1441 | if (matches_tag(cfg.dst_layout().compute_unnormalized(), "axb" )) { |
1442 | use_w_blocking = should_use_spatial_blocking( |
1443 | cfg, bh.max_iter_dim("mb" ), prb.id, prb.ih, prb.iw); |
1444 | } else if (cfg.dst_layout().compute().inner_block(0) == 1) { |
1445 | use_w_blocking = true; |
1446 | } |
1447 | |
1448 | if (use_w_blocking) { |
1449 | bh.allow_fuse({"iw" , "mb" }); |
1450 | bh.allow_split({"iw" , "mb" }); |
1451 | bh.reorder({"iw" , "mb" }); |
1452 | } else { |
1453 | bh.reorder({"mb" , "iw" }); |
1454 | bh.set_base_iter_block("mb" , 8); |
1455 | } |
1456 | |
1457 | if (cfg.send_2d_nhwc()) { |
1458 | bh.set_base_iter_block("oc" , 64 / prb.a_data_type_size); |
1459 | if (!prb.is_stride1()) bh.allow_split({"mb" }); |
1460 | bh.set_reduce_m_block_hint(false); |
1461 | } |
1462 | |
1463 | bh.compute(); |
1464 | } |

void init_bwd_w(conv_config_t &cfg, block_helper_t &bh) {
    const auto &prb = cfg.prb();
    bh.allow_k_grid_slicing();

    bh.set_block_dims({"g", "oc", "ic", "mb", "oh", "ow"});
    bh.set_vector_dim(prb.is_dw ? "g" : "oc");

    if (prb.oc <= 32) bh.set_max_iter_dim("oc", 16);
    if (prb.ic <= 32) bh.set_max_iter_dim("ic", 16);

    if (is_small_ic(prb) && !prb.is_dw) {
        bh.set_block_dims({"kw"});
        bh.set_max_tg_dim("kw", 1);
        bh.set_max_iter_dim("kw", 8);
    }

    // Avoid 2D spatial blocking when 1D blocking is enough: extra oh/od
    // loops may result in assembly bloat due to pipeline unrolling.
    if (prb.mb >= 32 && prb.ow >= 16) {
        bh.set_max_loop_dim("oh", 1);
        bh.set_max_loop_dim("od", 1);
    }

    bh.set_max_iter_dim("oh", 1);

    bh.allow_split({"oc", "ic", "mb", "ow"});
    bh.allow_fuse({"ic", "kw"});
    bh.allow_fuse({"mb", "oh", "ow"});
    bh.set_max_loop_dim("mb", 2);
    bh.set_base_iter_block(
            "mb", math::gcd(16, bh.dim("mb").base_iter_block()));

    bh.reorder({"mb", "ow", "oh"});

    if (cfg.send_2d_nhwc()) bh.set_reduce_m_block_hint(false);

    bh.compute();
}

void init_blocking(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    block_helper_t bh;
    init_common_blocking(cfg, bh);
    if (prb.is_fwd) {
        init_fwd(cfg, bh);
    } else if (prb.is_bwd_d) {
        init_bwd_d(cfg, bh);
    } else if (prb.is_bwd_w) {
        init_bwd_w(cfg, bh);
    } else {
        ir_error_not_expected();
    }

    auto &dims = cfg.dims();
    auto &iter_dims = cfg.iter_dims();
    auto &thread_group_dims = cfg.thread_group_dims();
    auto &loop_dims = cfg.loop_dims();

    for (auto &kv : bh.dims()) {
        auto &name = kv.first;
        auto &d = kv.second;
        if (!dims.is_overridden()) dims.set(name, d.size());
        if (!iter_dims.is_overridden()) iter_dims.set(name, d.iter_dim());
        if (!thread_group_dims.is_overridden())
            thread_group_dims.set(name, d.tg_dim());
        if (!loop_dims.is_overridden()) loop_dims.set(name, d.loop_dim());
        if (cfg.shrink_tg_dims()) {
            int dim = cfg.dim(name);
            int iter = cfg.iter_dim(name);
            int tg = cfg.thread_group_dim(name);
            int loop = cfg.loop_dim(name);
            int pad_blk = cfg.pad_block(name);
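            // Halve the thread group dim while the padded problem size is
            // at least twice the real size for this dimension.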
            while (tg > 1) {
                int padded = utils::rnd_up(
                        dim, math::lcm(iter * tg * loop, pad_blk));
                if (dim * 2 > padded) break;
                tg = std::max(1, tg / 2);
            }
            cfg.thread_group_dims().set(name, tg);
        }
    }
}

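// Returns the nullptr-terminated list of convolution dimensions mapped to
// kernel grid dimension idx (0..2).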
const char **get_kernel_grid_conv_dims(const conv_problem_t &prb, int idx) {
    static const char *fwd_0[] = {"oc", nullptr};
    static const char *fwd_1[] = {"g", "osp", "od", "oh", "ow", nullptr};
    static const char *fwd_2[] = {"mb", nullptr};
    static const char *bwd_d_0[] = {"ic", nullptr};
    static const char *bwd_d_1[] = {"g", "id", "ih", "iw", nullptr};
    static const char *bwd_d_2[] = {"mb", nullptr};
    static const char *bwd_w_0[] = {"oc", nullptr};
    static const char *bwd_w_1[]
            = {"ic", "kd", "kh", "kw", "od", "oh", "ow", nullptr};
    static const char *bwd_w_2[] = {"g", "mb", nullptr};
    static const char **fwd[] = {fwd_0, fwd_1, fwd_2};
    static const char **bwd_d[] = {bwd_d_0, bwd_d_1, bwd_d_2};
    static const char **bwd_w[] = {bwd_w_0, bwd_w_1, bwd_w_2};
    ir_assert(idx >= 0 && idx < 3);
    if (prb.is_fwd) return fwd[idx];
    if (prb.is_bwd_d) return bwd_d[idx];
    if (prb.is_bwd_w) return bwd_w[idx];
    ir_error_not_expected();
    return nullptr;
}

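// Returns the nullptr-terminated list of convolution dimensions mapped to
// thread group grid dimension idx (0..2).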
const char **get_thread_group_grid_conv_dims(
        const conv_problem_t &prb, int idx) {
    static const char *fwd_0[] = {"oc", nullptr};
    static const char *fwd_1[] = {"mb", "osp", "ow", nullptr};
    static const char *fwd_2[] = {"ic", nullptr};
    static const char *bwd_d_0[] = {"ic", nullptr};
    static const char *bwd_d_1[] = {"mb", "iw", nullptr};
    static const char *bwd_d_2[] = {"oc", nullptr};
    static const char *bwd_w_0[] = {"oc", nullptr};
    static const char *bwd_w_1[] = {"ic", nullptr};
    static const char *bwd_w_2[] = {nullptr};
    static const char **fwd[] = {fwd_0, fwd_1, fwd_2};
    static const char **bwd_d[] = {bwd_d_0, bwd_d_1, bwd_d_2};
    static const char **bwd_w[] = {bwd_w_0, bwd_w_1, bwd_w_2};
    ir_assert(idx >= 0 && idx < 3);
    if (prb.is_fwd) return fwd[idx];
    if (prb.is_bwd_d) return bwd_d[idx];
    if (prb.is_bwd_w) return bwd_w[idx];
    ir_error_not_expected();
    return nullptr;
}

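// Pads each problem dimension to a multiple of lcm(iter * tg * loop,
// pad_block). E.g. dim = 50 with blk = 16 and pad_blk = 4 is padded to
// rnd_up(50, lcm(16, 4)) = 64.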
void init_padded_dims(conv_config_t &cfg) {
    for (auto &kv : cfg.dims().get()) {
        auto &name = kv.first;
        int dim = cfg.dim(name);
        int iter = cfg.iter_dim(name);
        int tg = cfg.thread_group_dim(name);
        int loop = cfg.loop_dim(name);
        int blk = iter * tg * loop;
        int pad_blk = cfg.pad_block(name);
        int padded = utils::rnd_up(dim, math::lcm(blk, pad_blk));
        cfg.padded_dims().set(name, padded);
    }
}

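// Initializes the kernel grid: each of the three grid dimensions is the
// product of padded_dim / (iter * loop * tg) over the convolution
// dimensions mapped to it.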
void init_kernel_grid(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto get = [&](const char *name) {
        int padded = cfg.padded_dim(name);
        int iter = cfg.iter_dim(name);
        int loop = cfg.loop_dim(name);
        int tg = cfg.thread_group_dim(name);
        int tg_block = iter * loop * tg;
        return ir_utils::safe_divide(padded, tg_block);
    };

    const int grid_ndims = 3;
    std::vector<int> dims = {1, 1, 1};
    for (int i = 0; i < grid_ndims; i++) {
        auto **dd = get_kernel_grid_conv_dims(prb, i);
        for (auto **d = dd; *d; d++)
            dims[i] *= get(*d);
    }
    cfg.set_kernel_grid(grid_info_t(dims, "grid_idx"));
}

void init_thread_group_grid(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto get = [&](const char *name) {
        return cfg.thread_group_dims().get(name);
    };

    const int grid_ndims = 3;
    std::vector<int> dims = {1, 1, 1};
    for (int i = 0; i < grid_ndims; i++) {
        auto **dd = get_thread_group_grid_conv_dims(prb, i);
        for (auto **d = dd; *d; d++)
            dims[i] *= get(*d);
    }
    cfg.set_thread_group_grid(grid_info_t(dims, "tg_idx"));
}

// Enable optimization for strided BWD_D convolution.
void init_bwd_d_optimize_strided(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!prb.is_bwd_d) return;
    if (prb.is_stride1()) return;

    cfg.set_bwd_d_optimize_strided(true);

    if (cfg.iter_dim("iw") > 1) return;
    if (prb.iw % prb.sw != 0) return;
    cfg.set_bwd_d_optimize_strided_iw(true);

    // Update blocks.
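    // Move the thread group factor released from "iw" over to "mb",
    // keeping the total thread group size unchanged where possible.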
    int iw_tg_dim0 = cfg.thread_group_dim("iw");
    ir_assert(math::is_pow2(iw_tg_dim0));
    ir_assert(prb.iw % prb.sw == 0);
    for (int tg_dim = iw_tg_dim0; tg_dim >= 1; tg_dim /= 2) {
        if ((prb.iw / prb.sw) % tg_dim != 0) continue;

        cfg.thread_group_dims().set("iw", tg_dim);
        int mb_iter_dim = cfg.iter_dim("mb");
        int new_mb_tg_dim = cfg.thread_group_dim("mb") * iw_tg_dim0 / tg_dim;
        // TODO: Non-uniform thread groups are unsupported.
        while (new_mb_tg_dim > 1
                && utils::rnd_up(prb.mb, mb_iter_dim * new_mb_tg_dim) - prb.mb
                        >= mb_iter_dim) {
            new_mb_tg_dim /= 2;
        }
        if (mb_iter_dim * new_mb_tg_dim <= prb.mb) {
            cfg.thread_group_dims().set("mb", new_mb_tg_dim);
        }
        break;
    }
}

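// Sets loop unroll factors. Currently only BWD_W sets explicit unroll:
// "mb" is unrolled by its loop dim, and "ow" is additionally unrolled for
// dot-product FMA kernels with small "ow" loops.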
void init_unroll(conv_config_t &cfg) {
    if (cfg.unroll().is_overridden()) return;

    const auto &prb = cfg.prb();

    if (prb.is_bwd_w) {
        int mb_loop_dim = cfg.loop_dim("mb");
        int ow_loop_dim = cfg.loop_dim("ow");
        cfg.unroll().set("mb", mb_loop_dim);
        if (cfg.iter_dim("ow") > 1 && ow_loop_dim <= 8 && cfg.is_dp_fma()) {
            cfg.unroll().set("ow", ow_loop_dim);
        }
    }
}

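// Example: with tg_size = 8, elems = 256, type_size = 2 (bf16), 256 is a
// power of two, 256 % 8 == 0, and (256 / 8) * 2 = 64 bytes per thread
// satisfies the oword alignment check (64 % 16 == 0), so the split is
// allowed.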
bool can_split_across_thread_group(int tg_size, int elems, int type_size) {
    // Thread group grid is limited to powers of two. We can reliably split
    // only powers of two elements across such grids.
    if (!math::is_pow2(elems)) return false;

    // Check that the buffer can be uniformly distributed.
    if (elems % tg_size != 0) return false;

    // Check that SLM can be stored with oword messages.
    int bytes_per_thr = (elems / tg_size) * type_size;
    if (bytes_per_thr % 16 != 0) return false;

    return true;
}

void init_slm(conv_config_t &cfg) {
    if (cfg.slm().is_overridden()) return;

    const auto &prb = cfg.prb();
    if (cfg.hw() >= ngen::HW::XeHPC) return;

    int bufs = 0;
    int gmem_bufs = 0;
    bool enable_a = false;
    bool enable_b = false;
    auto &tg = cfg.thread_group_grid();
    int tg_size = tg.elems();
    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    if (!cfg.ow_kw_grf_cache()) {
        // Check that SLM for A can be stored with oword messages.
        int bytes_per_tg = m_tg_blk * k_iter_blk * prb.a_data_type_size;
        int align = prb.is_bwd_w ? 32 : 16;
        bool can_split_a = bytes_per_tg % align == 0
                && bytes_per_tg / tg_size >= k_iter_blk && k_iter_blk % 2 == 0;
        enable_a = (tg.dim(0) > 1) && can_split_a;
    }
    bool can_split_b = can_split_across_thread_group(
            tg_size, n_tg_blk * k_iter_blk, prb.b_data_type_size);
    enable_b = (tg.dim(1) > 1) && can_split_b;

    if (enable_a || enable_b) {
        bool is_small_tg = (tg.dim(0) * tg.dim(1) <= 8);
        int pref_bufs
                = ((is_small_tg || prb.is_f32_conv()) && prb.mb > 1 ? 2 : 3);
        if (cfg.pipeline().do_unroll()) {
            bufs = pref_bufs;
            gmem_bufs = (cfg.is_dp_fma() ? 2 : 1);
        } else {
            // Double/triple SLM buffering is not supported when only one
            // matrix is SLM-buffered.
            bufs = (enable_a == enable_b ? pref_bufs : 1);
            gmem_bufs = 1;
        }
    }
    cfg.slm().set(bufs, gmem_bufs, enable_a, enable_b);
}

void init_prefetch(conv_config_t &cfg) {
    if (cfg.prefetch().is_overridden()) return;

    const auto &prb = cfg.prb();
    if (cfg.hw() < ngen::HW::XeHPC) return;

    auto &tg = cfg.thread_group_grid();
    int tg_size = tg.elems();
    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    bool can_split_a = (tg.dim(0) == 1)
            || can_split_across_thread_group(
                    tg_size, m_tg_blk * k_iter_blk, prb.a_data_type_size);
    bool can_split_b = (tg.dim(1) == 1)
            || can_split_across_thread_group(
                    tg_size, n_tg_blk * k_iter_blk, prb.b_data_type_size);

    bool use_prefetch = (can_split_a && can_split_b);
    if (!use_prefetch) return;

    cfg.prefetch().set(prb.is_f32_conv() ? 2 : 3);
}

void init_allow_grf_reorder(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    cfg.set_allow_a_grf_reorder(!prb.matches_user_types());
    cfg.set_allow_b_grf_reorder(!prb.matches_user_types());

    bool is_mad = !cfg.is_dp_fma();
    if (is_mad && prb.is_s32_accumulator()) {
        cfg.set_allow_a_grf_reorder(true);
        cfg.set_allow_b_grf_reorder(true);
        return;
    }

    if (is_mad && prb.b_data_type == data_type::bf16) {
        cfg.set_allow_b_grf_reorder(true);
        return;
    }

    bool use_a_2d_send = can_use_a_2d_send(cfg);
    bool use_b_2d_send = can_use_b_2d_send(cfg);
    bool is_a_grf_blocked
            = (cfg.a_layout().compute().innermost_block_layout().size()
                            % cfg.grf_size()
                    == 0);
    if ((prb.is_fwd || prb.is_bwd_d) && !use_a_2d_send && !is_a_grf_blocked) {
        const char *dim_name = (prb.is_fwd ? "ic" : "oc");
        int dim = (prb.is_fwd ? prb.ic : prb.oc);
        int blk = cfg.iter_dim(dim_name);
        if (blk * prb.a_data_type_size % cfg.grf_size() != 0
                || dim != cfg.padded_dim(dim_name)) {
            cfg.set_allow_a_grf_reorder(true);
        }
    }
    if (cfg.send_2d_nhwc()) cfg.set_allow_a_grf_reorder(true);

    bool a_is_small_c = (prb.is_fwd || prb.is_bwd_w) ? is_small_ic(prb)
                                                     : is_small_oc(prb);
    if (cfg.is_dp_fma() && !prb.is_dw && a_is_small_c) {
        cfg.set_allow_a_grf_reorder(true);
    }

    if (prb.is_bwd_w && cfg.is_dp_fma()) {
        cfg.set_allow_a_grf_reorder(true);
        if (!use_b_2d_send) cfg.set_allow_b_grf_reorder(true);
    }
}

void init_allow_slm_tg_slicing(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!prb.is_bwd_w) return;
    if (!utils::everyone_is(prb.a_data_type, prb.b_data_type, data_type::bf16))
        return;
    if (!cfg.is_dp_fma()) return;

    // Enable only for layouts with batch blocking.
    int src_mb_blk = cfg.src_layout().compute().inner_block(0);
    int src_ic_blk = cfg.src_layout().compute().inner_block(2);
    int dst_mb_blk = cfg.dst_layout().compute().inner_block(0);
    int dst_oc_blk = cfg.dst_layout().compute().inner_block(2);
    if (src_mb_blk < 16 || dst_mb_blk < 16) return;

    bmnk_dim_helper_t h(cfg);
    int k_iter_blk = h.iter_dim('k');
    int m_iter_blk = h.iter_dim('m');
    int n_iter_blk = h.iter_dim('n');
    int m_tg_dim = h.thread_group_dim('m');
    int n_tg_dim = h.thread_group_dim('n');
    int tg_size = m_tg_dim * n_tg_dim;

    // Backward by weights with dpas layouts requires GRF reorders for A/B
    // (e.g. 2c*16n16c -> 32c16n). When SLM is used, such reorders are
    // generated after the load from GMEM and before the store to SLM. For
    // optimal performance we need the load/store layouts to have large
    // dense blocks. This means that in some cases we have to use only a
    // sub-grid of the thread group (i.e. rely on TG slicing) to perform
    // the load-store operation; otherwise we may end up with reorders like
    // 8n16c -> 16c*8n which result in scattered loads/stores.
    // At the same time, using sub-grids results in higher GRF consumption,
    // so we only enable TG slicing when the resulting sub-grid consists of
    // at least half of the total threads.
    int src_reorder_elems = k_iter_blk * src_ic_blk;
    int src_tg_elems = m_iter_blk * m_tg_dim * k_iter_blk;
    if (src_tg_elems % tg_size != 0) return;
    int src_elems_per_thr = src_tg_elems / tg_size;
    int src_slices = utils::div_up(src_reorder_elems, src_elems_per_thr);
    if (src_slices > 2) return;

    int dst_reorder_elems = k_iter_blk * dst_oc_blk;
    int dst_tg_elems = n_iter_blk * n_tg_dim * k_iter_blk;
    if (dst_tg_elems % tg_size != 0) return;
    int dst_elems_per_thr = dst_tg_elems / tg_size;
    int dst_slices = utils::div_up(dst_reorder_elems, dst_elems_per_thr);
    if (dst_slices > 2) return;

    cfg.set_allow_slm_tg_slicing(true);
}

void init_reduce_b(conv_config_t &cfg) {
    const auto &prb = cfg.prb();

    if (prb.is_bwd_w && prb.with_bias) { cfg.set_reduce_b(true); }
}

void init_assign_sbids(conv_config_t &cfg) {
    if (cfg.is_dp_fma()) cfg.set_assign_sbids(true);
}

// Overwrites parameters that are implied by other parameters.
status_t fixup_config(conv_config_t &cfg) {
    const auto &prb = cfg.prb();

    // Downgrade dpasw -> dpas for some cases.
    if (cfg.fma_kind() == fma_kind_t::dpasw) {
        // dpasw is executed by fused EUs (across the X thread group
        // dimension). Do not use dpasw if X is uneven.
        if (cfg.thread_group_grid().dim(0) % 2 != 0)
            cfg.set_fma_kind(fma_kind_t::dpas);
        // dpasw can't be generated when a GRF reorder is combined with a
        // direct load from GMEM (i.e. when SLM buffering is not used).
        if (prb.is_bwd_w
                && (cfg.allow_a_grf_reorder() || cfg.allow_b_grf_reorder())
                && (!cfg.slm().a() || !cfg.slm().b()))
            cfg.set_fma_kind(fma_kind_t::dpas);
    }

    return status::success;
}

template <typename GetFuncT>
bool in_grid_dims(
        GetFuncT get_func, const conv_problem_t &prb, const std::string &dim) {
    for (int i = 0; i < 3; i++) {
        auto **dd = get_func(prb, i);
        for (auto **d = dd; *d; d++)
            if (*d == dim) return true;
    }
    return false;
}

status_t check_config(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (prb.is_fwd) {
        if (cfg.send_2d_nhwc() && prb.sw != 1 && (prb.kw != 1 || prb.pw != 0)) {
            int osp_iter_blk = cfg.iter_dim("osp") * cfg.iter_dim("ow");
            ir_assert(osp_iter_blk == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    } else if (prb.is_bwd_d) {
        if (cfg.send_2d_nhwc() && prb.mb < 16 && prb.sw != 1) {
            ir_assert(cfg.iter_dim("iw") == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    } else if (prb.is_bwd_w) {
        if (cfg.send_2d_nhwc() && prb.sw != 1 && (prb.kw != 1 || prb.pw != 0)) {
            ir_assert(cfg.iter_dim("ow") == 1)
                    << "Can't use 2D block messages for non-trivial "
                       "strided dimensions.";
        }
    }

    for (auto &kv : cfg.dims().get()) {
        auto &name = kv.first;
        int tg = cfg.thread_group_dim(name);
        int grid = cfg.grid_dim(name);
        if (tg != 1)
            ir_assert(in_grid_dims(get_thread_group_grid_conv_dims, prb, name))
                    << name;
        if (grid != 1)
            ir_assert(in_grid_dims(get_kernel_grid_conv_dims, prb, name))
                    << name;
    }
    return status::success;
}

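// Tries to fit the kernel into the GRF budget by progressively disabling
// optimizations: extra GMEM load buffers first, then B/A subtiles, then
// triple and double SLM buffering, and finally pipeline unrolling.
// Returns true if the estimated register count fits.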
bool try_reduce_grf_usage(conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    if (!cfg.reduce_grf_usage()) return true;

    // TODO: Improve the register count estimate; it fails to account for
    // temporary values such as mask registers, among other things.
    int max_regs = cfg.regs();
    int est_regs = estimate_register_count(cfg);
    if (est_regs <= max_regs) return true;

    // Try to disable GRF buffering.
    if (cfg.slm().gmem_bufs() > 1) {
        cfg.slm().set_gmem_bufs(1);
        int est_regs = estimate_register_count(cfg);
        if (est_regs <= max_regs) return true;
    }

    bmnk_dim_helper_t h(cfg);

    // Try to use subtiles for B.
    if (!cfg.subtiles().is_overridden()) {
        int n_iter_blk = h.iter_dim('n');
        int max_b_subtiles
                = std::min((cfg.slm().b() ? 4 : 2), n_iter_blk / cfg.simd());
        // XXX: avoid layout mismatch for B loads
        if (cfg.hw() >= ngen::HW::XeHPC && prb.is_bwd_w)
            max_b_subtiles = std::min(2, max_b_subtiles);
        while (cfg.subtiles().b() < max_b_subtiles) {
            cfg.subtiles().set_b(cfg.subtiles().b() * 2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }

        // Try to use subtiles for A.
        int m_iter_blk = h.iter_dim('m');
        int max_a_subtiles = std::min((cfg.slm().a() ? 4 : 2), m_iter_blk / 8);
        if (cfg.subtiles().b() > 1) max_a_subtiles = 1;
        while (cfg.subtiles().a() < max_a_subtiles) {
            cfg.subtiles().set_a(cfg.subtiles().a() * 2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }
    }

    if (!cfg.slm().is_overridden()) {
        // Try to use double SLM buffering.
        if (cfg.slm().bufs() == 3) {
            cfg.slm().set_bufs(2);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }

        // Try to use single SLM buffering.
        if (cfg.slm().bufs() == 2) {
            cfg.slm().set_bufs(1);
            int est_regs = estimate_register_count(cfg);
            if (est_regs <= max_regs) return true;
        }
    }

    if (!cfg.pipeline().is_overridden()) {
        // Last resort settings to reduce GRF usage.
        cfg.pipeline().set(false);
    }

    return estimate_register_count(cfg) <= max_regs;
}

status_t try_init_cfg(conv_config_t &cfg) {
    init_hint(cfg);
    init_send_2d_nhwc(cfg);
    init_fuse_spatial(cfg);
    init_hoist_masks_from_compute_loop(cfg);
    init_ow_kw_grf_cache(cfg);
    init_blocking(cfg);
    init_bwd_d_optimize_strided(cfg);
    init_pipeline(cfg);
    init_padded_dims(cfg);
    init_kernel_grid(cfg);
    init_thread_group_grid(cfg);
    init_unroll(cfg);
    init_slm(cfg);
    init_prefetch(cfg);
    init_allow_grf_reorder(cfg);
    init_allow_slm_tg_slicing(cfg);
    init_reduce_b(cfg);
    init_assign_sbids(cfg);

    CHECK(fixup_config(cfg));
    CHECK(check_config(cfg));

    if (!try_reduce_grf_usage(cfg)) return status::unimplemented;

    return status::success;
}

// Returns max SLM size per thread group assuming max utilization (max
// concurrent threads per EU).
int max_slm_size(const conv_config_t &cfg) {
    ngen::HW hw = cfg.hw();
    int regs = cfg.regs();
    return compute::device_info_t::max_slm_size_per_tg(
            convert_ngen_arch_to_dnnl(hw), regs > 128);
}

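// Returns the total SLM size in bytes consumed by the current A/B
// buffering configuration (0 when SLM buffering is disabled).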
int slm_size(const conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    auto &slm = cfg.slm();
    if (slm.bufs() == 0) return 0;

    bmnk_dim_helper_t h(cfg);
    int m_tg_blk = h.thread_group_dim('m') * h.iter_dim('m');
    int n_tg_blk = h.thread_group_dim('n') * h.iter_dim('n');
    int k_iter_blk = h.iter_dim('k');
    int a_slm_size = m_tg_blk * k_iter_blk * prb.a_data_type_size;
    int b_slm_size = n_tg_blk * k_iter_blk * prb.b_data_type_size;

    int ret = 0;
    if (slm.a()) ret += a_slm_size;
    if (slm.b()) ret += b_slm_size;
    ret *= slm.bufs();

    return ret;
}

status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd) {
    cfg.set_pd(pd);

    // Try large GRF mode first.
    int try_regs = cfg.hw_cfg().large_grf_support() ? 256 : 128;
    // if (prb.g == 1 && prb.is_f32_conv()) try_regs = 128;

    int def_max_tg_size
            = get_default_max_tg_size(cfg.hw_cfg(), try_regs, cfg.simd());
    conv_hint_t hint(def_max_tg_size);

    // Use a fixed number of iterations to avoid an infinite loop.
    int max_iters = 10;
    bool ok = false;
    for (int i = 0; i < max_iters; i++) {
        conv_config_t try_cfg = cfg;
        try_cfg.set_regs(try_regs);
        try_cfg.set_hint(hint);

        CHECK(try_init_cfg(try_cfg));

        // Reduce thread group size if SLM size is too large.
        if (try_cfg.check_slm_size()) {
            if (slm_size(try_cfg) > max_slm_size(try_cfg)) {
                hint.set_max_tg_size(hint.max_tg_size() / 2);
                continue;
            }
        }

        // If the kernel fits 128 registers, switch to the normal mode which
        // is expected to have better performance for such cases.
        int bound = (!try_cfg.is_dp_fma() ? 128 : 116);
        int estimated_peak_grf_usage = estimate_register_count(try_cfg);
        if (try_regs == 256 && estimated_peak_grf_usage <= bound) {
            try_regs = 128;
            continue;
        }
        cfg = try_cfg;
        ok = true;
        break;
    }

    return ok ? status::success : status::runtime_error;
}

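// Applies parameter overrides, given as a list of key=value entries, to
// the matching config parameters.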
void conv_config_t::override_set(const std::string &s) {
    std::vector<param_t *> params;
    for (auto &gp : get_params_)
        params.push_back(gp(this));
    auto parts = ir_utils::split(s);
    for (auto &p : parts) {
        auto sub_parts = ir_utils::split(p, "=");
        ir_assert(sub_parts.size() == 2);
        auto &key = sub_parts[0];
        auto &value = sub_parts[1];
        bool found = false;
        for (auto *param : params) {
            if (param->accept_key(key)) {
                ir_info() << "Override " << param->name() << ": " << key << "="
                          << value << std::endl;
                param->override_set(key, value);
                found = true;
                break;
            }
        }
        ir_assert(found) << p;
    }
}

int get_thread_count(const conv_config_t &cfg) {
    return cfg.kernel_grid().elems() * cfg.thread_group_grid().elems();
}

// Return thread utilization as a percentage. If this value is low,
// parallelism is a fundamental limitation to the current work scheduling.
float get_thread_utilization(const conv_config_t &cfg) {
    auto arch = convert_ngen_arch_to_dnnl(cfg.hw());
    int slice_eu_count = compute::device_info_t::max_eus_per_wg(arch);
    int slice_count = cfg.hw_cfg().eu_count() / slice_eu_count;

    int min_wg_per_slice_wave
            = std::max(slice_eu_count / cfg.thread_group_grid().elems(), 1);
    int min_wg_per_wave = slice_count * min_wg_per_slice_wave;
    int wg = cfg.kernel_grid().elems();
    return ((float)wg / utils::rnd_up(wg, min_wg_per_wave)) * 100;
}

// Return wave utilization as a percentage. If this value is low, memory
// latency may be an issue due to limited use of SMT to hide the latency.
float get_wave_utilization(const conv_config_t &cfg) {
    auto arch = convert_ngen_arch_to_dnnl(cfg.hw());
    int threads_per_eu = compute::device_info_t::threads_per_eu(
            arch, cfg.hw_cfg().large_grf_support());
    int slice_eu_count = compute::device_info_t::max_eus_per_wg(arch);
    int slice_count = cfg.hw_cfg().eu_count() / slice_eu_count;

    int max_wg_per_slice_wave
            = slice_eu_count * threads_per_eu / cfg.thread_group_grid().elems();
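    // E.g. 16 EUs per slice, 8 threads per EU, and a 32-thread thread
    // group give at most 16 * 8 / 32 = 4 thread groups per slice wave.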
    int max_wg_per_wave = slice_count * max_wg_per_slice_wave;
    int wg = cfg.kernel_grid().elems();
    return ((float)wg / utils::rnd_up(wg, max_wg_per_wave)) * 100;
}

std::string conv_config_t::str() const {
    using namespace ir_utils;

    std::ostringstream oss;
    // clang-format off
    oss << " HW config: " << exec_cfg().str(hint().max_tg_size()) << std::endl;
    oss << " Problem: " << prb().desc_str() << std::endl;
    const char *names[] = {"Source", "Weights", "Destination"};
    const layout_param_t *layouts[] = {&src_layout(), &wei_layout(), &dst_layout()};
    for (int i = 0; i < 3; i++) {
        std::string desc = std::string(names[i]) + " layout:";
        desc.insert(desc.size(), 28 - desc.size(), ' ');
        auto &compute_layout = layouts[i]->compute_unnormalized();
        auto &user_layout = layouts[i]->user_unnormalized();
        oss << " " << desc << compute_layout;
        if (user_layout != compute_layout) {
            oss << " (user: " << user_layout << ")";
        }
        oss << std::endl;
    }
    int estimated_peak_grf_usage = estimate_register_count(*this);
    oss << blocking_brief_str();
    oss << " Kernel grid: " << kernel_grid() << std::endl;
    oss << " Thread group: " << thread_group_grid() << std::endl;
    oss << " Threads: " << get_thread_count(*this) << " (utilization: "
        << get_thread_utilization(*this) << "% thread, "
        << get_wave_utilization(*this) << "% wave)" << std::endl;
    oss << " FMA kind: " << fma_kind::to_string(fma_kind()) << std::endl;
    oss << " SLM buffering: " << "A: " << to_string(slm().a()) << ", B: " << to_string(slm().b())
        << ", buffers: " << slm().bufs() << ", pad: " << to_string(pad_slm()) << std::endl;
    oss << " GRF buffers for GMEM load: " << slm().gmem_bufs() << std::endl;
    oss << " Prefetch: " << to_string(prefetch()) << ", buffers: " << prefetch().bufs() << std::endl;
    oss << " Do pipeline unroll: " << to_string(pipeline().do_unroll()) << std::endl;
    oss << " Assign SBIDs: " << to_string(assign_sbids()) << std::endl;
    oss << " Reduce GRF usage: " << to_string(reduce_grf_usage()) << std::endl;
    oss << " Reuse headers: " << to_string(pipeline().reuse_headers()) << std::endl;
    oss << " Allow GRF reorder: " << "A: " << to_string(allow_a_grf_reorder()) << ", B: " << to_string(allow_b_grf_reorder()) << std::endl;
    oss << " Subtiles: " << "A: " << subtiles().a() << ", B: " << subtiles().b() << std::endl;
    oss << " Estimated GRF usage: " << estimated_peak_grf_usage << std::endl;
    // clang-format on
    return oss.str();
}

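// Pads s with spaces up to std::abs(pad) characters: right-aligns when
// pad >= 0, left-aligns when pad < 0.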
std::string pad_str(std::string s, int pad) {
    auto pos = (pad >= 0 ? 0 : s.length());
    s.insert(pos, std::abs(pad) - s.length(), ' ');
    return s;
}

std::string pad_int(int i, int pad) {
    return pad_str(std::to_string(i), pad);
}

std::string conv_config_t::blocking_brief_str() const {
    std::ostringstream oss;
    std::vector<std::string> names;
    for (auto &kv : dims().get()) {
        names.push_back(kv.first);
    }
    std::sort(names.begin(), names.end());
    for (auto &name : names) {
        int iter = iter_dim(name);
        int tg = thread_group_dim(name);
        int loop = loop_dim(name);
        int grid = grid_dim(name);
        if (iter == 1 && loop == 1 && tg == 1) continue;
        oss << " Dimension " << name << pad_str(":", -18 + (int)name.length());
        oss << "(grid:" << pad_int(grid, 5) << ") x ";
        oss << "(tg:" << pad_int(tg, 5) << ") x ";
        oss << "(loop:" << pad_int(loop, 5) << ") x ";
        oss << "(iter:" << pad_int(iter, 5) << ")\n";
    }
    return oss.str();
}

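// Weights zero-out can be skipped when a single thread group covers the
// whole reduction (K) dimension, i.e. every weights element is fully
// accumulated by one thread group.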
bool conv_config_t::can_skip_wei_zero_out() const {
    if (!prb().is_bwd_w) return true;
    bmnk_dim_helper_t h(*this);
    int k_iter_dim = h.iter_dim('k');
    int k_loop_dim = h.loop_dim('k');
    int k_tg_dim = h.thread_group_dim('k');
    int k_tg_block = k_iter_dim * k_loop_dim * k_tg_dim;
    int k_padded = padded_dim("mb") * padded_dim("od") * padded_dim("oh")
            * padded_dim("ow");
    return k_tg_block >= k_padded;
}

bool conv_config_t::can_skip_bia_zero_out() const {
    if (!prb().is_bwd_w || !prb().with_bias) return true;
    return can_skip_wei_zero_out() && !slm().b();
}

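// Registers extra input tensors implied by attributes: runtime zero
// points, scales, and binary/prelu post-op right-hand sides.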
void init_extra_tensors(
        const conv_config_t &cfg, tensor_config_t &tensor_cfg) {
    const auto &prb = cfg.prb();
    auto &zp_cfg = prb.zp_cfg;
    auto *pd = prb.conv_pd;
    auto *attr = prb.attr;
    if (zp_cfg.do_src_compensation && zp_cfg.is_runtime_src_zero_points) {
        int zp_ic = (zp_cfg.is_common_src_zero_point) ? 1 : prb.ic;
        std::vector<dim_t> dims = {zp_ic};
        layout_t zp_layout(type_t::s32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        tensor_cfg.add_tensor("src_zero_points", arg_key,
                /*is_input=*/true, /*is_output=*/false, zp_layout);
    }
    if (zp_cfg.do_dst_compensation && zp_cfg.is_runtime_dst_zero_points) {
        std::vector<dim_t> dims = {prb.oc};
        layout_t zp_layout(type_t::s32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        tensor_cfg.add_tensor("dst_zero_points", arg_key,
                /*is_input=*/true, /*is_output=*/false, zp_layout);
    }
    auto scale_args = get_scale_args(prb);
    const char *scale_names[] = {"src_scales", "wei_scales", "dst_scales"};
    const int scale_names_len = sizeof(scale_names) / sizeof(scale_names[0]);
    ir_assert((int)scale_args.size() == scale_names_len);
    for (int i = 0; i < (int)scale_args.size(); i++) {
        int arg = scale_args[i];
        auto &s = attr->scales_.get(arg);
        if (s.has_default_values()) continue;
        int dim = s.mask_ == 0 ? 1 : (prb.is_fwd ? prb.oc : prb.ic);
        std::vector<dim_t> dims = {dim};
        layout_t layout(type_t::f32(), 0, dims);
        int arg_key = DNNL_ARG_ATTR_SCALES | arg;
        tensor_cfg.add_tensor(scale_names[i], arg_key, /*is_input=*/true,
                /*is_output=*/false, layout);
    }
    for (int i = 0; i < attr->post_ops_.len(); i++) {
        auto &po = attr->post_ops_.entry_[i];
        if (po.is_eltwise()
                || po.is_sum(/*require_scale_one=*/false,
                        /*require_zp_zero=*/false)) {
            // No extra tensors.
        } else if (po.is_binary()) {
            auto layout = make_layout(po.binary.src1_desc);
            int arg_key = DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1;
            tensor_cfg.add_tensor("binary_rhs_" + std::to_string(i), arg_key,
                    /*is_input=*/true,
                    /*is_output=*/false, layout);
        } else if (po.is_prelu()) {
            layout_t layout(type_t::f32(), 0,
                    get_prelu_weights_dims(
                            po.prelu.mask, *pd->invariant_dst_md()));
            int arg_key = DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_WEIGHTS;
            tensor_cfg.add_tensor("prelu_rhs_" + std::to_string(i), arg_key,
                    /*is_input=*/true, /*is_output=*/false, layout);
        } else {
            ir_error_not_expected();
        }
    }
}

tensor_config_t get_tensor_config(const conv_config_t &cfg) {
    const auto &prb = cfg.prb();
    tensor_config_t tensor_cfg;
    conv_arg_helper_t h(prb);
    auto &src = cfg.src_layout();
    auto &wei = cfg.wei_layout();
    auto &bia = cfg.bia_layout();
    auto &dst = cfg.dst_layout();
    tensor_cfg.add_tensor("src", h.src_arg_key(), h.is_src_input(),
            h.is_src_output(), src.compute(), src.user());
    tensor_cfg.add_tensor("wei", h.wei_arg_key(), h.is_wei_input(),
            h.is_wei_output(), wei.compute(), wei.user());
    if (prb.with_bias)
        tensor_cfg.add_tensor("bia", h.bia_arg_key(), h.is_bia_input(),
                h.is_bia_output(), bia.compute(), bia.user());
    tensor_cfg.add_tensor("dst", h.dst_arg_key(), h.is_dst_input(),
            h.is_dst_output(), dst.compute(), dst.user());
    if (prb.is_bwd_w && !prb.with_sum) {
        tensor_cfg.require_zero_out("wei");
        if (prb.with_bias) tensor_cfg.require_zero_out("bia");
    }
    init_extra_tensors(cfg, tensor_cfg);
    return tensor_cfg;
}

int estimate_register_count(const conv_config_t &cfg) {
    return estimate_grf_usage(cfg).total();
}

bool can_use_a_2d_send(const conv_config_t &cfg) {
    return can_use_2d_send(cfg, cfg.a_layout().compute_unnormalized(), true);
}

bool can_use_b_2d_send(const conv_config_t &cfg) {
    return can_use_2d_send(cfg, cfg.b_layout().compute_unnormalized(), false);
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl