/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CONV_CONFIG_HPP
#define GPU_JIT_CONV_CONFIG_HPP

#include <iostream>
#include <sstream>
#include <unordered_map>

#include "common/c_types_map.hpp"
#include "common/convolution_pd.hpp"
#include "common/math_utils.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/compute/compute_engine.hpp"
#include "gpu/jit/conv/block_helper.hpp"
#include "gpu/jit/conv/tensor_config.hpp"
#include "gpu/jit/ir/fma.hpp"
#include "gpu/jit/ir/hw_config.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/jit_eltwise_injector.hpp"
#include "gpu/jit/utils/utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Description of the convolution problem.
class conv_problem_t {
public:
    conv_problem_t() = default;

    status_t init(const engine_t *engine, const convolution_pd_t *conv_pd);

    bool is_stride1() const { return sd == 1 && sh == 1 && sw == 1; }

    // Reduces dimensions for 1x1 kernel.
    void try_reduce_to_1d();

    // Helper methods.
    bool is_s32_accumulator() const { return acc_data_type == data_type::s32; }
    bool is_f32_conv() const {
        return utils::everyone_is(src_data_type, wei_data_type, data_type::f32);
    }
    bool is_f64_conv() const {
        return utils::everyone_is(src_data_type, wei_data_type, data_type::f64);
    }
    bool is_int8_dst() const {
        return utils::one_of(dst_data_type, data_type::s8, data_type::u8);
    }
    bool is_mixed_int8() const {
        return utils::one_of(a_data_type, dnnl_f16, dnnl_f32)
                && utils::one_of(c_data_type, dnnl_u8, dnnl_s8);
    }
    bool matches_user_types() const {
        if (is_fwd) {
            return a_data_type == src_data_type && b_data_type == wei_data_type
                    && c_data_type == dst_data_type;
        } else if (is_bwd_d) {
            return a_data_type == dst_data_type && b_data_type == wei_data_type
                    && c_data_type == src_data_type;
        } else if (is_bwd_w) {
            return a_data_type == src_data_type && b_data_type == dst_data_type
                    && c_data_type == wei_data_type;
        } else {
            ir_error_not_expected();
            return false;
        }
    }

    const memory_desc_t &a_md() const {
        return *pick_a(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

    const memory_desc_t &b_md() const {
        return *pick_b(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

    const memory_desc_t &c_md() const {
        return *pick_c(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

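    // pick_a/pick_b/pick_c return the src/wei/dst argument that plays the
    // corresponding GEMM role (C += A * B) for the current propagation kind:
    //   FWD:   A = src, B = wei, C = dst
    //   BWD_D: A = dst, B = wei, C = src
    //   BWD_W: A = src, B = dst, C = wei
    // For example, for a backward-by-weights problem a_md() returns the source
    // memory descriptor and c_md() returns the (diff) weights descriptor.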
    template <typename T>
    T &&pick_a(T &&src, T &&wei, T &&dst) const {
        return (is_fwd || is_bwd_w) ? std::forward<T>(src)
                                    : std::forward<T>(dst);
    }

    template <typename T>
    T &&pick_b(T &&src, T &&wei, T &&dst) const {
        return (is_fwd || is_bwd_d) ? std::forward<T>(wei)
                                    : std::forward<T>(dst);
    }

    template <typename T>
    T &&pick_c(T &&src, T &&wei, T &&dst) const {
        return std::forward<T>(is_fwd ? dst : is_bwd_d ? src : wei);
    }

    template <typename T>
    T &&pick_by_dir(T &&fwd, T &&bwd_d, T &&bwd_w) const {
        return std::forward<T>(is_fwd ? fwd : is_bwd_d ? bwd_d : bwd_w);
    }

    std::string desc_str(bool print_mb = true) const;

    const convolution_pd_t *conv_pd;
    const primitive_attr_t *attr;

    data_type_t src_data_type;
    data_type_t wei_data_type;
    data_type_t dst_data_type;
    data_type_t bia_data_type;
    fpmath_mode_t fpmath_mode;

    bool is_fwd;
    bool is_bwd_d;
    bool is_bwd_w;
    bool with_bias;
    bool with_groups;
    bool with_sum;
    bool is_dw;

    int ndims;
    int mb; // Batch size.
    int g; // Groups.
    int ic, oc; // Input and output channels.
    int id, ih, iw; // Input spatial sizes.
    int od, oh, ow; // Output spatial sizes.
    int kd, kh, kw; // Kernel sizes.
    int sd, sh, sw; // Strides.
    int pd, ph, pw; // Padding in the beginning.
    int dd, dh, dw; // Dilation.
    int reduced_dim; // Indicates which dims were shifted over or reduced.
    int isp, osp, ksp; // Combined input/output/kernel spatial size.

    data_type_t a_data_type;
    data_type_t b_data_type;
    data_type_t c_data_type;
    data_type_t acc_data_type;

    int a_data_type_size;
    int b_data_type_size;
    int c_data_type_size;
    int acc_data_type_size;

    // Specific to FWD int8
    struct zero_points_config_t {
        bool do_src_compensation;
        bool do_dst_compensation;
        bool is_runtime_src_zero_points;
        bool is_runtime_dst_zero_points;
        bool is_common_src_zero_point;
        bool is_common_dst_zero_point;
        int common_src_zero_point;
        int common_dst_zero_point;
    } zp_cfg;

private:
    // Initializes A/B/C data types (GEMM notation: C += A * B) according to
    // the following convention:
    // FWD: src -> A, wei -> B, dst -> C
    // BWD_D: diff_dst -> A, wei -> B, diff_src -> C
    // BWD_W: src -> A, diff_dst -> B, diff_wei -> C
    status_t init_abc_data_types(ngen::HW hw) {
        a_data_type = pick_a(src_data_type, wei_data_type, dst_data_type);
        b_data_type = pick_b(src_data_type, wei_data_type, dst_data_type);
        // Always use f32 for accumulation/storing in the main kernel.
        c_data_type = is_bwd_w
                ? data_type::f32
                : pick_c(src_data_type, wei_data_type, dst_data_type);

        if (utils::everyone_is(
                    data_type::f32, a_data_type, b_data_type, c_data_type)) {

            // TODO: bf16 and f16 currently perform worse than tf32; this is
            // likely due to an extra reorder required on the B buffer.
            bool use_matching_fpmath = false;
#ifdef GEN_CONV_DEBUG
            use_matching_fpmath = ir_utils::getenv_bool(
                    "use_matching_fpmath", use_matching_fpmath);
#endif
            if (use_matching_fpmath
                    && attr->mayidownconvert(data_type::f32, data_type::bf16)
                    && fma_kind::get_supported_kind(hw, data_type::bf16,
                               data_type::bf16, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::bf16;
                b_data_type = data_type::bf16;
            } else if (use_matching_fpmath
                    && attr->mayidownconvert(data_type::f32, data_type::f16)
                    && fma_kind::get_supported_kind(hw, data_type::f16,
                               data_type::f16, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::f16;
                b_data_type = data_type::f16;
            } else if (attr->mayidownconvert(data_type::f32, data_type::tf32)
                    && fma_kind::get_supported_kind(hw, data_type::tf32,
                               data_type::tf32, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::tf32;
                b_data_type = data_type::tf32;
            }
        }

        a_data_type_size = (int)types::data_type_size(a_data_type);
        b_data_type_size = (int)types::data_type_size(b_data_type);
        c_data_type_size = (int)types::data_type_size(c_data_type);
        return status::success;
    }

    status_t init_acc_data_type() {
        auto a = a_data_type;
        auto b = b_data_type;
        acc_data_type = data_type::undef;
        if (utils::one_of(a, data_type::s8, data_type::u8)
                && utils::one_of(b, data_type::s8, data_type::u8)) {
            acc_data_type = data_type::s32;
        } else if (utils::everyone_is(data_type::f16, a, b)
                || utils::everyone_is(data_type::bf16, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::tf32, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::f32, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::f64, a, b)) {
            acc_data_type = data_type::f64;
        }
        if (acc_data_type == data_type::undef) return status::unimplemented;
        acc_data_type_size = (int)types::data_type_size(acc_data_type);
        return status::success;
    }

    status_t init_zero_points_config();

    bool with_sum_post_op() {
        auto &post_ops = attr->post_ops_;
        return post_ops.find(primitive_kind::sum) != -1;
    }
};

class conv_hint_t {
public:
    conv_hint_t() = default;
    conv_hint_t(int def_max_tg_size) : def_max_tg_size_(def_max_tg_size) {}

    int max_tg_size() const {
        if (max_tg_size_ != 0) return max_tg_size_;
        return def_max_tg_size_;
    }

    bool max_tg_overridden() const { return max_tg_overridden_; }

    void set_max_tg_size(int value) {
        max_tg_overridden_ = max_tg_size_ != 0;
        max_tg_size_ = value;
    }

private:
    int max_tg_size_ = 0;
    int def_max_tg_size_ = 0;
    bool max_tg_overridden_ = false;
};

class conv_arg_helper_t {
public:
    conv_arg_helper_t(const conv_problem_t &prb) : prb_(prb) {}

    int src_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_SRC;
        if (prb_.is_bwd_d) return DNNL_ARG_DIFF_SRC;
        if (prb_.is_bwd_w) return DNNL_ARG_SRC;
        ir_error_not_expected();
        return -1;
    }

    bool is_src_input() const { return prb_.is_fwd || prb_.is_bwd_w; }
    bool is_src_output() const { return prb_.is_bwd_d; }

    int wei_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_WEIGHTS;
        if (prb_.is_bwd_d) return DNNL_ARG_WEIGHTS;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_WEIGHTS;
        ir_error_not_expected();
        return -1;
    }

    bool is_wei_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
    bool is_wei_output() const { return prb_.is_bwd_w; }

    int bia_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_BIAS;
        if (prb_.is_bwd_d) return DNNL_ARG_BIAS;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_BIAS;
        ir_error_not_expected();
        return -1;
    }

    bool is_bia_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
    bool is_bia_output() const { return prb_.is_bwd_w; }

    int dst_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_DST;
        if (prb_.is_bwd_d) return DNNL_ARG_DIFF_DST;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_DST;
        ir_error_not_expected();
        return -1;
    }

    bool is_dst_input() const { return prb_.is_bwd_d || prb_.is_bwd_w; }
    bool is_dst_output() const { return prb_.is_fwd; }

private:
    const conv_problem_t &prb_;
};

class param_t {
public:
    virtual std::string name() const = 0;
    virtual std::string short_name() const { return name(); }
    virtual std::string desc() const = 0;

    virtual bool accept_key(const std::string &key) const {
        return key == short_name();
    }

    virtual void set_from_str(const std::string &s) { ir_error_not_expected(); }
    virtual void set_from_str(
            const std::string &key, const std::string &value) {
        if (key == short_name()) {
            set_from_str(value);
            return;
        }
        ir_error_not_expected();
    }
    void override_set(const std::string &key, const std::string &value) {
        is_overridden_[key] = true;
        set_from_str(key, value);
    }

    bool is_overridden() const {
        if (is_overridden_.empty()) return false;
        return is_overridden(short_name());
    }

    bool is_overridden(const std::string &key) const {
        auto it = is_overridden_.find(key);
        if (it == is_overridden_.end()) return false;
        return it->second;
    }

private:
    std::unordered_map<std::string, bool> is_overridden_;
};

template <typename T>
class value_param_t : public param_t {
public:
    using value_t = T;
    using param_t::is_overridden;

    value_param_t() = default;
    value_param_t(const T &value) : value_(value) {}

    const T &get() const { return value_; }

    operator const T &() const { return get(); }

    void set(const T &value) { value_ = value; }

protected:
    T value_;
};

class bool_param_t : public value_param_t<bool> {
public:
    using value_param_t::value_param_t;

    void set_from_str(const std::string &s) override {
        value_ = ir_utils::to_bool(s);
    }
};

class int_param_t : public value_param_t<int> {
public:
    using value_param_t::value_param_t;

    void set_from_str(const std::string &s) override { value_ = std::stoi(s); }
};

class grid_param_t : public value_param_t<grid_info_t> {
public:
    using value_param_t::value_param_t;
};

class layout_param_t : public param_t {
public:
    const layout_t &user() const { return user_; }
    const layout_t &compute() const { return compute_; }
    const layout_t &user_unnormalized() const { return user_unnormalized_; }
    const layout_t &compute_unnormalized() const {
        return compute_unnormalized_;
    }

    void set_from_str(const std::string &s) override {
        ir_error_not_implemented();
    }

    void set_user(const layout_t &l) { user_ = l; }
    void set_compute(const layout_t &l) { compute_ = l; }
    void set_user_unnormalized(const layout_t &l) { user_unnormalized_ = l; }
    void set_compute_unnormalized(const layout_t &l) {
        compute_unnormalized_ = l;
    }

private:
    layout_t user_;
    layout_t compute_;
    layout_t user_unnormalized_;
    layout_t compute_unnormalized_;
};

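// Parses a string of (name, integer) pairs into a map. For example,
// to_map("ic32oc64") returns {"ic": 32, "oc": 64}. This is the string format
// used by the map-based parameters below (dimension blocks, unroll factors,
// etc).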
inline std::unordered_map<std::string, int> to_map(const std::string &s) {
    std::unordered_map<std::string, int> ret;
    int name_beg = -1;
    int value_beg = -1;
    for (int pos = 0; pos < (int)s.size() + 1; pos++) {
        bool prev_digit = pos > 0 && std::isdigit(s[pos - 1]);
        bool cur_digit = pos < (int)s.size() && std::isdigit(s[pos]);
        if ((pos == 0 || prev_digit) && !cur_digit) {
            if (name_beg != -1 && value_beg != -1) {
                auto key = s.substr(name_beg, value_beg - name_beg);
                auto value = std::stoi(s.substr(value_beg, pos - value_beg));
                ret[key] = value;
            }
            name_beg = pos;
            value_beg = -1;
        }
        if (!prev_digit && cur_digit) value_beg = pos;
    }
    return ret;
}

class map_param_t : public param_t {
public:
    using value_t = std::unordered_map<std::string, int>;

    const value_t &get() const { return map_; }

    bool is_empty() const { return map_.empty(); }

    int get(const std::string &name) const {
        auto it = map_.find(name);
        if (it == map_.end()) return 1;
        return it->second;
    }

    int operator()(const std::string &name) const { return get(name); }

    void set_from_str(const std::string &s) override {
        map_.clear();
        map_ = to_map(s);
    }

    void set(const std::string &name, int dim) {
        auto it = map_.find(name);
        if (dim == 1) {
            if (it != map_.end()) map_.erase(it);
            return;
        }
        map_[name] = dim;
    }

    void set(const value_t &value) { map_ = value; }

    std::string str() const {
        std::ostringstream oss;
        for (auto &kv : map_) {
            oss << kv.first << kv.second;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    value_t map_;
};

// Parameters for kernel generation.

// TODO: Remove, GRF reorder is to be emitted/accounted for depending on
// layouts, FMA kind and other parameters.
class allow_a_grf_reorder_param_t : public bool_param_t {
public:
    allow_a_grf_reorder_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-a-grf-reorder"; }
    std::string desc() const override {
        return "Whether to allow GRF reorders to FMA-friendly layouts for A.";
    }
};

// TODO: Remove, GRF reorder is to be emitted/accounted for depending on
// layouts, FMA kind and other parameters.
class allow_b_grf_reorder_param_t : public bool_param_t {
public:
    allow_b_grf_reorder_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-b-grf-reorder"; }
    std::string desc() const override {
        return "Whether to allow GRF reorders to FMA-friendly layouts for B.";
    }
};

// TODO: Remove, use internal logic to determine if SLM thread group slicing is
// needed.
class allow_slm_tg_slicing_param_t : public bool_param_t {
public:
    allow_slm_tg_slicing_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-slm-tg-slicing"; }
    std::string desc() const override {
        return "Whether to allow thread group split for SLM load/store.";
    }
};

class assign_sbids_param_t : public bool_param_t {
public:
    assign_sbids_param_t() : bool_param_t(false) {}
    std::string name() const override { return "assign-sbids"; }
    std::string desc() const override {
        return "Whether to manually assign SBID tokens to dpas/send.";
    }
};

class bia_layout_param_t : public layout_param_t {
    std::string name() const override { return "bia"; }
    std::string desc() const override { return "Bias layout."; }
};

class bwd_d_optimize_strided_param_t : public bool_param_t {
public:
    bwd_d_optimize_strided_param_t() : bool_param_t(false) {}
    std::string name() const override { return "bwd-d-optimize-strided"; }
    std::string desc() const override {
        return "Apply special optimization for strided BWD_D convolution.";
    }
};

class bwd_d_optimize_strided_iw_param_t : public bool_param_t {
public:
    bwd_d_optimize_strided_iw_param_t() : bool_param_t(false) {}
    std::string name() const override { return "bwd-d-optimize-strided-iw"; }
    std::string desc() const override {
        return "Apply special optimization for strided BWD_D convolution (iw "
               "dimension).";
    }
};

// TODO: Remove, use heuristics to determine if it's worth sacrificing EU
// utilization for a larger SLM size.
class check_slm_size_param_t : public bool_param_t {
public:
    check_slm_size_param_t() : bool_param_t(true) {}
    std::string name() const override { return "check-slm-size"; }
    std::string short_name() const override { return "c"; }
    std::string desc() const override {
        return "Whether to check SLM size to ensure full EU utilization.";
    }
};

class dims_param_t : public map_param_t {
public:
    std::string name() const override { return "dims"; }
    std::string desc() const override { return "Problem dimensions."; }
};

class dst_layout_param_t : public layout_param_t {
    std::string name() const override { return "dst"; }
    std::string desc() const override { return "Destination layout."; }
};

class exec_cfg_param_t : public value_param_t<exec_config_t> {
public:
    using value_param_t::accept_key;
    using value_param_t::is_overridden;
    using value_param_t::value_param_t;

    std::string name() const override { return "exec-cfg"; }
    std::string desc() const override {
        return "Execution config (hardware config, number of registers, SIMD, "
               "etc).";
    }

    bool accept_key(const std::string &key) const override {
        if (key == "simd") return true;
        return false;
    }

    void set_from_str(
            const std::string &key, const std::string &value) override {
        if (key == "simd") {
            value_.set_simd(std::stoi(value));
        } else {
            ir_error_not_expected() << key;
        }
    }
};

class fma_kind_param_t : public value_param_t<fma_kind_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "fma"; }
    std::string desc() const override { return "FMA kind."; }

    void set_from_str(const std::string &s) override {
        value_ = fma_kind::from_string(s);
    }
};

class fuse_spatial_param_t : public bool_param_t {
public:
    fuse_spatial_param_t() : bool_param_t(false) {}
    std::string name() const override { return "fuse-spatial"; }
    std::string short_name() const override { return "fsp"; }
    std::string desc() const override {
        return "Whether to apply blocking to fused spatial (otherwise only `w` "
               "is blocked).";
    }
};

class hint_param_t : public value_param_t<conv_hint_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "hint"; }
    std::string desc() const override { return "Configuration hint."; }
};

// TODO: Remove, use internal logic.
class hoist_masks_from_compute_loop_param_t : public bool_param_t {
public:
    hoist_masks_from_compute_loop_param_t() : bool_param_t(false) {}
    std::string name() const override {
        return "hoist-masks-from-compute-loop";
    }
    std::string desc() const override {
        return "Whether to move send mask initialization out of compute loop.";
    }
};

class kernel_grid_param_t : public grid_param_t {
public:
    std::string name() const override { return "kernel-grid"; }
    std::string desc() const override {
        return "Number of thread groups across dimensions (kernel grid).";
    }
};

class iter_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "iter"; }
    std::string short_name() const override { return "i"; }
    std::string desc() const override {
        return "Iteration-level dimension blocks.";
    }
};

class loop_dims_param_t : public dims_param_t {
public:
    std::string name() const override { return "loop"; }
    std::string short_name() const override { return "l"; }
    std::string desc() const override { return "Loop-level dimension blocks."; }
};

class shrink_tg_dims_param_t : public bool_param_t {
public:
    shrink_tg_dims_param_t() : bool_param_t(false) {}
    std::string name() const override { return "shrink-tg-dims"; }
    std::string short_name() const override { return "stg"; }
    std::string desc() const override {
        return "Whether to adjust tile sizes depending on batch size.";
    }
};

class ow_kw_grf_cache_param_t : public bool_param_t {
public:
    ow_kw_grf_cache_param_t() : bool_param_t(false) {}
    std::string name() const override { return "ow-kw-grf-cache"; }
    std::string desc() const override {
        return "Whether to use GRF cache to reuse source for ow/kw pairs.";
    }
};

class pad_slm_param_t : public bool_param_t {
public:
    pad_slm_param_t() : bool_param_t(true) {}
    std::string name() const override { return "pad-slm"; }
    std::string desc() const override {
        return "Whether to pad SLM layout to avoid bank conflicts.";
    }
};

class padded_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "pad"; }
    std::string desc() const override {
        return "Padded dimensions (rounded up to block sizes and to comply "
               "with required zero padding in output layouts).";
    }
};

class pipeline_param_t : public param_t {
public:
    std::string name() const override { return "pipeline"; }
    std::string short_name() const override { return "P"; }
    std::string desc() const override { return "General pipeline parameters."; }

    bool do_unroll() const { return do_unroll_; }
    bool reuse_headers() const { return !do_unroll(); }

    void set_from_str(const std::string &s) override {
        do_unroll_ = false;
        for (auto c : s) {
            switch (c) {
                case 'u': do_unroll_ = true; break;
                default: ir_error_not_expected() << s;
            }
        }
    }

    void set(bool do_unroll) { do_unroll_ = do_unroll; }

private:
    bool do_unroll_ = false;
};

class prb_param_t : public value_param_t<conv_problem_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "prb"; }
    std::string desc() const override { return "Convolution problem."; }

    void set_pd(const convolution_pd_t *pd) {
        value_.conv_pd = pd;
        value_.attr = pd->attr();
    }
};

class prefetch_param_t : public param_t {
public:
    std::string name() const override { return "prefetch"; }
    std::string short_name() const override { return "p"; }
    std::string desc() const override { return "Parameters for prefetching."; }

    int bufs() const { return bufs_; }

    operator bool() const { return bufs_ > 0; }

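    // Accepts a dot-separated list of single-letter settings, each followed by
    // an integer value. Currently only 'x' (number of prefetch buffers) is
    // recognized, e.g. "x3" enables prefetching with three buffers.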
    void set_from_str(const std::string &s) override {
        auto parts = ir_utils::split(s, ".");
        for (auto &p : parts) {
            ir_assert(p.size() >= 2) << p;
            char name = p[0];
            int value = std::stoi(p.substr(1));
            switch (name) {
                case 'x': bufs_ = value; break;
                default: ir_error_not_expected() << p;
            }
        }
    }

    void set(int bufs) { bufs_ = bufs; }

private:
    int bufs_ = 0;
};

class reduce_b_param_t : public bool_param_t {
public:
    reduce_b_param_t() : bool_param_t(false) {}
    std::string name() const override { return "reduce-b"; }
    std::string desc() const override {
        return "Whether to reduce B tensor (used for dst reduction in backward "
               "by weights).";
    }
};

class reduce_grf_usage_param_t : public bool_param_t {
public:
    reduce_grf_usage_param_t() : bool_param_t(true) {}
    std::string name() const override { return "reduce-grf-usage"; }
    std::string short_name() const override { return "r"; }
    std::string desc() const override {
        return "Whether to try to reduce GRF usage based on heuristics.";
    }
};

// TODO: Remove this parameter and enable 2D block messages based on the
// generation flow.
class send_2d_nhwc_param_t : public bool_param_t {
public:
    send_2d_nhwc_param_t() : bool_param_t(false) {}
    std::string name() const override { return "2d-send-nhwc"; }
    std::string desc() const override {
        return "Whether to use the optimal NHWC setup relying on 2D block "
               "messages.";
    }
};

class slm_param_t : public param_t {
public:
    std::string name() const override { return "slm"; }
    std::string short_name() const override { return "s"; }
    std::string desc() const override { return "SLM buffering parameters."; }

    int bufs() const { return bufs_; }
    int gmem_bufs() const { return gmem_bufs_; }
    int sync_version() const { return sync_version_; }
    bool a() const { return a_; }
    bool b() const { return b_; }

    operator bool() const { return bufs() > 0; }

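    // Accepts a dot-separated list of single-letter settings, each followed by
    // an integer value: 'x' is the number of SLM buffers, 'g' is the number of
    // GRF buffers for the GMEM -> SLM copy, and 'v' is the SLM sync version.
    // For example, "x2.g2" requests double SLM buffering with two GRF copy
    // buffers and enables SLM buffering for both A and B.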
    void set_from_str(const std::string &s) override {
        auto parts = ir_utils::split(s, ".");
        for (auto &p : parts) {
            ir_assert(p.size() >= 2) << p;
            char name = p[0];
            int value = std::stoi(p.substr(1));
            switch (name) {
                case 'x': bufs_ = value; break;
                case 'g': gmem_bufs_ = value; break;
                case 'v': sync_version_ = value; break;
                default: ir_error_not_expected() << p;
            }
        }
        if (bufs_ > 0) {
            a_ = true;
            b_ = true;
        }
    }

    void set(int bufs, int gmem_bufs, bool a, bool b) {
        bufs_ = bufs;
        gmem_bufs_ = gmem_bufs;
        a_ = a;
        b_ = b;
    }

    void set_bufs(int bufs) { bufs_ = bufs; }
    void set_gmem_bufs(int gmem_bufs) { gmem_bufs_ = gmem_bufs; }

private:
    // Number of SLM buffers to use (0, 1, 2 or 3).
    int bufs_ = 0;
    // Number of GRF buffers to use for GMEM -> SLM copy (0, 1 or 2).
    int gmem_bufs_ = 0;
    // See slm_sync_manager_t for more details.
    int sync_version_ = -1;
    // Whether SLM buffering for A is enabled.
    bool a_ = false;
    // Whether SLM buffering for B is enabled.
    bool b_ = false;
};

class src_layout_param_t : public layout_param_t {
public:
    std::string name() const override { return "src"; }
    std::string desc() const override { return "Source layout."; }
};

// Subtiles to split into for the inner A x B multiplication:
//
// Case 1. a_subtiles = 1, b_subtiles = 1
//     A = load(...)
//     B = load(...)
//     C += A * B
//
// Case 2. a_subtiles > 1, b_subtiles = 1
//     B = load(...)
//     for i in range(0, a_subtiles):
//         A_i = load(...)
//         C_i += A_i * B
//
// Case 3. a_subtiles = 1, b_subtiles > 1
//     A = load(...)
//     for j in range(0, b_subtiles):
//         B_j = load(...)
//         C_j += A * B_j
//
// Tiling for A and tiling for B are mutually exclusive. Using subtiles helps
// to reduce GRF consumption.
class subtiles_param_t : public param_t {
public:
    std::string name() const override { return "subtiles"; }
    std::string short_name() const override { return "S"; }
    std::string desc() const override { return "Sub-iteration blocking."; }

    int a() const { return a_; }
    int b() const { return b_; }

    void set_from_str(const std::string &s) override {
        a_ = 1;
        b_ = 1;
        for (auto &kv : to_map(s)) {
            if (kv.first == "a") {
                a_ = kv.second;
            } else if (kv.first == "b") {
                b_ = kv.second;
            } else {
                ir_error_not_expected() << kv.first;
            }
        }
    }

    void set(int a, int b) {
        a_ = a;
        b_ = b;
    }

    void set_a(int a) { a_ = a; }
    void set_b(int b) { b_ = b; }

private:
    int a_ = 1;
    int b_ = 1;
};

class thread_group_grid_param_t : public grid_param_t {
public:
    std::string name() const override { return "tg-grid"; }
    std::string desc() const override { return "Thread group grid."; }
};

class thread_group_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "tg"; }
    std::string short_name() const override { return "T"; }
    std::string desc() const override {
        return "Thread group-level dimension blocks.";
    }
};

class unroll_param_t : public map_param_t {
public:
    std::string name() const override { return "unroll"; }
    std::string short_name() const override { return "u"; }
    std::string desc() const override {
        return "Per-dimension unroll factors.";
    }
};

class wei_layout_param_t : public layout_param_t {
    std::string name() const override { return "wei"; }
    std::string desc() const override { return "Weights layout."; }
};

namespace constants {
// Maximum number of SLM buffers.
static const int max_slm_bufs = 3;

// GRF usage for kernel arguments, local work IDs/sizes, signal header,
// temporary expressions, etc.
static const int reserved_regs = 20;
} // namespace constants

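// Full set of parameters for convolution kernel generation. Individual
// parameters are stored as param_t objects and exposed through typed
// accessors generated by the DECL_PARAM/DECL_PARAM2 macros below.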
class conv_config_t {
public:
    conv_config_t() = default;

#define DECL_PARAM(name) \
    const name##_param_t &name##_param() const { \
        (void)name##_init_; \
        return name##_; \
    } \
    name##_param_t &name##_param() { return name##_; } \
    const name##_param_t::value_t &name() const { return name##_.get(); } \
    void set_##name(const name##_param_t::value_t &value) { \
        name##_.set(value); \
    }

#define DECL_PARAM2(name) \
    const name##_param_t &name() const { \
        (void)name##_init_; \
        return name##_; \
    } \
    name##_param_t &name() { return name##_; }

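    // For example, DECL_PARAM(fma_kind) declares the following members:
    //   const fma_kind_param_t &fma_kind_param() const;
    //   fma_kind_param_t &fma_kind_param();
    //   const fma_kind_t &fma_kind() const;  // Parameter value.
    //   void set_fma_kind(const fma_kind_t &value);
    // DECL_PARAM2(slm) instead declares slm() accessors that return the whole
    // slm_param_t object rather than a plain value.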
    DECL_PARAM(allow_a_grf_reorder)
    DECL_PARAM(allow_b_grf_reorder)
    DECL_PARAM(allow_slm_tg_slicing)
    DECL_PARAM(assign_sbids)
    DECL_PARAM(bwd_d_optimize_strided)
    DECL_PARAM(bwd_d_optimize_strided_iw)
    DECL_PARAM(check_slm_size)
    DECL_PARAM(exec_cfg)
    DECL_PARAM(fma_kind)
    DECL_PARAM(fuse_spatial)
    DECL_PARAM(hint)
    DECL_PARAM(hoist_masks_from_compute_loop)
    DECL_PARAM(kernel_grid)
    DECL_PARAM(ow_kw_grf_cache)
    DECL_PARAM(pad_slm)
    DECL_PARAM(prb)
    DECL_PARAM(reduce_b)
    DECL_PARAM(reduce_grf_usage)
    DECL_PARAM(send_2d_nhwc)
    DECL_PARAM(shrink_tg_dims)
    DECL_PARAM(thread_group_grid)
    DECL_PARAM2(bia_layout)
    DECL_PARAM2(dims)
    DECL_PARAM2(dst_layout)
    DECL_PARAM2(iter_dims)
    DECL_PARAM2(loop_dims)
    DECL_PARAM2(padded_dims)
    DECL_PARAM2(pipeline)
    DECL_PARAM2(prefetch)
    DECL_PARAM2(slm)
    DECL_PARAM2(src_layout)
    DECL_PARAM2(subtiles)
    DECL_PARAM2(thread_group_dims)
    DECL_PARAM2(unroll)
    DECL_PARAM2(wei_layout)

#undef DECL_PARAM
#undef DECL_PARAM2

    void override_set(const std::string &s);

    std::string str() const;

    std::string blocking_brief_str() const;

    // Helper methods.
    int dim(const std::string &name) const { return dims()(name); }

    int iter_dim(const std::string &name) const { return iter_dims()(name); }

    int padded_dim(const std::string &name) const {
        return padded_dims()(name);
    }

    int loop_dim(const std::string &name) const { return loop_dims()(name); }

    int thread_group_dim(const std::string &name) const {
        return thread_group_dims()(name);
    }

    // Blocks for padding. This is to comply with zero-padding requirements.
    // For example, if the output layout is nChw32c but there are only 8
    // channels to compute and store, we still need to pad 8 to 32 and spawn
    // more thread groups to ensure the 32c block is properly zero-padded.
    int pad_block(const std::string &name) const {
        auto &src = src_layout().compute();
        auto &wei = wei_layout().compute();
        auto &dst = dst_layout().compute();

#define CASE(_name, layout, idx) \
    if (name == _name) return layout.inner_block(idx)

        if (prb().is_fwd) {
            CASE("mb", dst, 0);
            CASE("g", dst, 1);
            CASE("oc", dst, 2);
        } else if (prb().is_bwd_d) {
            CASE("mb", src, 0);
            CASE("g", src, 1);
            CASE("ic", src, 2);
        } else if (prb().is_bwd_w) {
            CASE("g", wei, 0);
            CASE("oc", wei, 1);
            CASE("ic", wei, 2);
        }
#undef CASE
        return 1;
    }

    int unroll(const std::string &name) const { return unroll()(name); }

    int reserved_regs() const { return constants::reserved_regs; }

    const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); }

    ngen::HW hw() const { return hw_cfg().hw(); }

    bool is_ge_xe_hpc() const { return hw() >= ngen::HW::XeHPC; }

    int grf_size() const { return hw_cfg().grf_size(); }

    int regs() const { return exec_cfg().regs(); }

    int simd() const { return exec_cfg().simd(); }

    int vec_size() const { return exec_cfg().vec_size(); }

    bool is_g_mad() const {
        return fma_kind() == fma_kind_t::mad && prb().g > 1 && prb().ic < 4
                && prb().oc < 4 && prb().mb < 8 && !prb().is_dw;
    }

    bool is_dp_fma() const {
        return utils::one_of(fma_kind(), fma_kind_t::dpas, fma_kind_t::dpasw,
                fma_kind_t::dp4a);
    }

    bool is_dpas_or_dpasw_fma() const {
        return utils::one_of(fma_kind(), fma_kind_t::dpas, fma_kind_t::dpasw);
    }

    const layout_param_t &a_layout() const {
        return prb().pick_a<const layout_param_t &>(
                src_layout(), wei_layout(), dst_layout());
    }

    const layout_param_t &b_layout() const {
        return prb().pick_b<const layout_param_t &>(
                src_layout(), wei_layout(), dst_layout());
    }

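    // Global/local sizes for kernel submission. The thread group grid is
    // scaled by the SIMD width along dimension 0 since a thread group covers
    // (tg_dim0 * simd, tg_dim1, tg_dim2) work items. For example, with SIMD
    // 16, a thread group grid of (8, 4, 1) and a kernel grid of (100, 50, 2),
    // the local size is (128, 4, 1) and the global size is (12800, 200, 2).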
    compute::nd_range_t nd_range() const {
        size_t gws[3];
        size_t lws[3];
        for (int i = 0; i < 3; i++) {
            lws[i] = thread_group_grid().dim(i) * (i == 0 ? simd() : 1);
            gws[i] = kernel_grid().dim(i) * lws[i];
        }
        return compute::nd_range_t(gws, lws);
    }

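    // Kernel grid size along a dimension: the padded dimension divided by the
    // per-thread-group coverage. For example, a padded "oc" of 512 with
    // iteration block 32, thread group block 4 and loop block 1 gives a grid
    // size of 4.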
    int grid_dim(const std::string &dim) const {
        return ir_utils::safe_divide(padded_dim(dim),
                loop_dim(dim) * thread_group_dim(dim) * iter_dim(dim));
    }

    int iter_dim(std::initializer_list<const char *> dims) const {
        int ret = 1;
        for (auto *dim : dims)
            ret *= iter_dim(dim);
        return ret;
    }

    void set_pd(const convolution_pd_t *pd) { prb_.set_pd(pd); }

    void set_regs(int regs) {
        auto tmp = exec_cfg();
        tmp.set_regs(regs);
        set_exec_cfg(tmp);
    }

    void set_simd(int simd) {
        auto tmp = exec_cfg();
        tmp.set_simd(simd);
        set_exec_cfg(tmp);
    }

    void set_vec_size(int vec_size) {
        auto tmp = exec_cfg();
        tmp.set_vec_size(vec_size);
        set_exec_cfg(tmp);
    }

    bool can_skip_wei_zero_out() const;
    bool can_skip_bia_zero_out() const;

private:
    struct param_init_t {};

    template <typename GetParamFuncT, typename PtrT>
    static param_init_t register_param(
            PtrT ptr, std::vector<GetParamFuncT> &get_params_) {
        get_params_.push_back([=](conv_config_t *cfg) { return &(cfg->*ptr); });
        return param_init_t();
    }

    std::vector<std::function<param_t *(conv_config_t *)>> get_params_;

#define INIT_PARAM(name) \
    name##_param_t name##_; \
    param_init_t name##_init_ \
            = register_param(&conv_config_t::name##_, get_params_);

    INIT_PARAM(allow_a_grf_reorder)
    INIT_PARAM(allow_b_grf_reorder)
    INIT_PARAM(allow_slm_tg_slicing)
    INIT_PARAM(assign_sbids)
    INIT_PARAM(bia_layout)
    INIT_PARAM(bwd_d_optimize_strided)
    INIT_PARAM(bwd_d_optimize_strided_iw)
    INIT_PARAM(check_slm_size)
    INIT_PARAM(dims)
    INIT_PARAM(dst_layout)
    INIT_PARAM(exec_cfg)
    INIT_PARAM(fma_kind)
    INIT_PARAM(fuse_spatial)
    INIT_PARAM(hint)
    INIT_PARAM(hoist_masks_from_compute_loop)
    INIT_PARAM(iter_dims)
    INIT_PARAM(kernel_grid)
    INIT_PARAM(loop_dims)
    INIT_PARAM(ow_kw_grf_cache)
    INIT_PARAM(pad_slm)
    INIT_PARAM(padded_dims)
    INIT_PARAM(pipeline)
    INIT_PARAM(prb)
    INIT_PARAM(prefetch)
    INIT_PARAM(reduce_b)
    INIT_PARAM(reduce_grf_usage)
    INIT_PARAM(send_2d_nhwc)
    INIT_PARAM(shrink_tg_dims)
    INIT_PARAM(slm)
    INIT_PARAM(src_layout)
    INIT_PARAM(subtiles)
    INIT_PARAM(thread_group_dims)
    INIT_PARAM(thread_group_grid)
    INIT_PARAM(unroll)
    INIT_PARAM(wei_layout)

#undef INIT_PARAM
};

inline std::ostream &operator<<(std::ostream &out, const conv_config_t &cfg) {
    out << cfg.str();
    return out;
}

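// Maps convolution dimensions to BMNK GEMM roles ('b' is the batch dimension,
// which is always groups here; M/N/K are the standard GEMM dimensions) and
// aggregates the configured blocks per GEMM dimension. For example, for
// forward convolution loop_dim('k') returns the product of the loop-level
// blocks over ic, kd, kh and kw.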
class bmnk_dim_helper_t {
public:
    bmnk_dim_helper_t(const conv_config_t &cfg) : prb_(cfg.prb()), cfg_(cfg) {}

    int iter_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.iter_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

    int thread_group_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.thread_group_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

    int loop_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.loop_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

private:
    static bool contains(const char **array, const std::string &s) {
        for (const char **ptr = array; *ptr; ptr++) {
            if (s == *ptr) return true;
        }
        return false;
    }

    char to_bmnk(const std::string &dim_name) const {
        static const char *fwd_b_dims[] = {"g", nullptr};
        static const char *fwd_m_dims[]
                = {"mb", "osp", "od", "oh", "ow", nullptr};
        static const char *fwd_n_dims[] = {"oc", nullptr};
        static const char *fwd_k_dims[] = {"ic", "kd", "kh", "kw", nullptr};
        static const char *bwd_d_b_dims[] = {"g", nullptr};
        static const char *bwd_d_m_dims[] = {"mb", "id", "ih", "iw", nullptr};
        static const char *bwd_d_n_dims[] = {"ic", nullptr};
        static const char *bwd_d_k_dims[] = {"oc", "kd", "kh", "kw", nullptr};
        static const char *bwd_w_b_dims[] = {"g", nullptr};
        static const char *bwd_w_m_dims[] = {"ic", "kd", "kh", "kw", nullptr};
        static const char *bwd_w_n_dims[] = {"oc", nullptr};
        static const char *bwd_w_k_dims[] = {"mb", "od", "oh", "ow", nullptr};

        const char **b_dims = prb_.pick_by_dir<const char **>(
                fwd_b_dims, bwd_d_b_dims, bwd_w_b_dims);
        const char **m_dims = prb_.pick_by_dir<const char **>(
                fwd_m_dims, bwd_d_m_dims, bwd_w_m_dims);
        const char **n_dims = prb_.pick_by_dir<const char **>(
                fwd_n_dims, bwd_d_n_dims, bwd_w_n_dims);
        const char **k_dims = prb_.pick_by_dir<const char **>(
                fwd_k_dims, bwd_d_k_dims, bwd_w_k_dims);

        if (contains(b_dims, dim_name)) return 'b';
        if (contains(m_dims, dim_name)) return 'm';
        if (contains(n_dims, dim_name)) return 'n';
        if (contains(k_dims, dim_name)) return 'k';

        ir_error_not_expected() << dim_name;
        return ' ';
    }

    const conv_problem_t &prb_;
    const conv_config_t &cfg_;
};

status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg,
        const engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr);
status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd);
tensor_config_t get_tensor_config(const conv_config_t &cfg);
int estimate_register_count(const conv_config_t &cfg);
bool can_use_a_2d_send(const conv_config_t &cfg);
bool can_use_b_2d_send(const conv_config_t &cfg);
const char **get_kernel_grid_conv_dims(const conv_problem_t &prb, int idx);
const char **get_thread_group_grid_conv_dims(
        const conv_problem_t &prb, int idx);

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif