/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CONV_CONFIG_HPP
#define GPU_JIT_CONV_CONFIG_HPP

#include <cctype>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/convolution_pd.hpp"
#include "common/math_utils.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/compute/compute_engine.hpp"
#include "gpu/jit/conv/block_helper.hpp"
#include "gpu/jit/conv/tensor_config.hpp"
#include "gpu/jit/ir/fma.hpp"
#include "gpu/jit/ir/hw_config.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/jit_eltwise_injector.hpp"
#include "gpu/jit/utils/utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Description of the convolution problem.
class conv_problem_t {
public:
    conv_problem_t() = default;

    status_t init(const engine_t *engine, const convolution_pd_t *conv_pd);

    bool is_stride1() const { return sd == 1 && sh == 1 && sw == 1; }

    // Reduces dimensions for 1x1 kernel.
    void try_reduce_to_1d();

    // Helper methods.
    bool is_s32_accumulator() const { return acc_data_type == data_type::s32; }
    bool is_f32_conv() const {
        return utils::everyone_is(src_data_type, wei_data_type, data_type::f32);
    }
    bool is_f64_conv() const {
        return utils::everyone_is(src_data_type, wei_data_type, data_type::f64);
    }
    bool is_int8_dst() const {
        return utils::one_of(dst_data_type, data_type::s8, data_type::u8);
    }
    bool is_mixed_int8() const {
        return utils::one_of(a_data_type, dnnl_f16, dnnl_f32)
                && utils::one_of(c_data_type, dnnl_u8, dnnl_s8);
    }
    bool matches_user_types() const {
        if (is_fwd) {
            return a_data_type == src_data_type && b_data_type == wei_data_type
                    && c_data_type == dst_data_type;
        } else if (is_bwd_d) {
            return a_data_type == dst_data_type && b_data_type == wei_data_type
                    && c_data_type == src_data_type;
        } else if (is_bwd_w) {
            return a_data_type == src_data_type && b_data_type == dst_data_type
                    && c_data_type == wei_data_type;
        } else {
            ir_error_not_expected();
            return false;
        }
    }

    const memory_desc_t &a_md() const {
        return *pick_a(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

    const memory_desc_t &b_md() const {
        return *pick_b(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

    const memory_desc_t &c_md() const {
        return *pick_c(conv_pd->invariant_src_md(), conv_pd->invariant_wei_md(),
                conv_pd->invariant_dst_md());
    }

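    // The pick_* helpers map (src, wei, dst) arguments to GEMM-style
    // (A, B, C) tensors: FWD picks (src, wei, dst), BWD_D picks
    // (dst, wei, src), and BWD_W picks (src, dst, wei). See
    // init_abc_data_types() below for the same convention.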
    template <typename T>
    T &&pick_a(T &&src, T &&wei, T &&dst) const {
        return (is_fwd || is_bwd_w) ? std::forward<T>(src)
                                    : std::forward<T>(dst);
    }

    template <typename T>
    T &&pick_b(T &&src, T &&wei, T &&dst) const {
        return (is_fwd || is_bwd_d) ? std::forward<T>(wei)
                                    : std::forward<T>(dst);
    }

    template <typename T>
    T &&pick_c(T &&src, T &&wei, T &&dst) const {
        return std::forward<T>(is_fwd ? dst : is_bwd_d ? src : wei);
    }

    template <typename T>
    T &&pick_by_dir(T &&fwd, T &&bwd_d, T &&bwd_w) const {
        return std::forward<T>(is_fwd ? fwd : is_bwd_d ? bwd_d : bwd_w);
    }

    std::string desc_str(bool print_mb = true) const;

    const convolution_pd_t *conv_pd;
    const primitive_attr_t *attr;

    data_type_t src_data_type;
    data_type_t wei_data_type;
    data_type_t dst_data_type;
    data_type_t bia_data_type;
    fpmath_mode_t fpmath_mode;

    bool is_fwd;
    bool is_bwd_d;
    bool is_bwd_w;
    bool with_bias;
    bool with_groups;
    bool with_sum;
    bool is_dw;

    int ndims;
    int mb; // Batch size.
    int g; // Groups.
    int ic, oc; // Input and output channels.
    int id, ih, iw; // Input spatial sizes.
    int od, oh, ow; // Output spatial sizes.
    int kd, kh, kw; // Kernel sizes.
    int sd, sh, sw; // Strides.
    int pd, ph, pw; // Padding in the beginning.
    int dd, dh, dw; // Dilation.
    int reduced_dim; // Indicates which dims were reduced (see try_reduce_to_1d()).
    int isp, osp, ksp; // Combined input/output/kernel spatial size.

    data_type_t a_data_type;
    data_type_t b_data_type;
    data_type_t c_data_type;
    data_type_t acc_data_type;

    int a_data_type_size;
    int b_data_type_size;
    int c_data_type_size;
    int acc_data_type_size;

    // Specific to FWD int8.
    struct zero_points_config_t {
        bool do_src_compensation;
        bool do_dst_compensation;
        bool is_runtime_src_zero_points;
        bool is_runtime_dst_zero_points;
        bool is_common_src_zero_point;
        bool is_common_dst_zero_point;
        int common_src_zero_point;
        int common_dst_zero_point;
    } zp_cfg;

private:
    // Initializes A/B/C data types (GEMM notation: C += A * B) according to
    // the following convention:
    // FWD: src -> A, wei -> B, dst -> C
    // BWD_D: diff_dst -> A, wei -> B, diff_src -> C
    // BWD_W: src -> A, diff_dst -> B, diff_wei -> C
    status_t init_abc_data_types(ngen::HW hw) {
        a_data_type = pick_a(src_data_type, wei_data_type, dst_data_type);
        b_data_type = pick_b(src_data_type, wei_data_type, dst_data_type);
        // For BWD_W, always use f32 for accumulation/storing in the main
        // kernel.
        c_data_type = is_bwd_w
                ? data_type::f32
                : pick_c(src_data_type, wei_data_type, dst_data_type);

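        // For all-f32 convolutions, the fpmath attribute may allow
        // implicitly down-converting A/B to a lower-precision type when a
        // supported FMA kind exists for it: tf32 by default, or bf16/f16
        // when GEN_CONV_DEBUG enables use_matching_fpmath.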
        if (utils::everyone_is(
                    data_type::f32, a_data_type, b_data_type, c_data_type)) {

            // TODO: bf16 and f16 currently perform worse than tf32; this is
            // likely due to an extra reorder required on the B buffer.
            bool use_matching_fpmath = false;
#ifdef GEN_CONV_DEBUG
            use_matching_fpmath = ir_utils::getenv_bool(
                    "use_matching_fpmath", use_matching_fpmath);
#endif
            if (use_matching_fpmath
                    && attr->mayidownconvert(data_type::f32, data_type::bf16)
                    && fma_kind::get_supported_kind(hw, data_type::bf16,
                               data_type::bf16, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::bf16;
                b_data_type = data_type::bf16;
            } else if (use_matching_fpmath
                    && attr->mayidownconvert(data_type::f32, data_type::f16)
                    && fma_kind::get_supported_kind(hw, data_type::f16,
                               data_type::f16, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::f16;
                b_data_type = data_type::f16;
            } else if (attr->mayidownconvert(data_type::f32, data_type::tf32)
                    && fma_kind::get_supported_kind(hw, data_type::tf32,
                               data_type::tf32, data_type::f32)
                            != fma_kind_t::unknown) {
                a_data_type = data_type::tf32;
                b_data_type = data_type::tf32;
            }
        }

        a_data_type_size = (int)types::data_type_size(a_data_type);
        b_data_type_size = (int)types::data_type_size(b_data_type);
        c_data_type_size = (int)types::data_type_size(c_data_type);
        return status::success;
    }

    status_t init_acc_data_type() {
        auto a = a_data_type;
        auto b = b_data_type;
        acc_data_type = data_type::undef;
        if (utils::one_of(a, data_type::s8, data_type::u8)
                && utils::one_of(b, data_type::s8, data_type::u8)) {
            acc_data_type = data_type::s32;
        } else if (utils::everyone_is(data_type::f16, a, b)
                || utils::everyone_is(data_type::bf16, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::tf32, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::f32, a, b)) {
            acc_data_type = data_type::f32;
        } else if (utils::everyone_is(data_type::f64, a, b)) {
            acc_data_type = data_type::f64;
        }
        if (acc_data_type == data_type::undef) return status::unimplemented;
        acc_data_type_size = (int)types::data_type_size(acc_data_type);
        return status::success;
    }

    status_t init_zero_points_config();

    bool with_sum_post_op() {
        auto &post_ops = attr->post_ops_;
        return post_ops.find(primitive_kind::sum) != -1;
    }
};

class conv_hint_t {
public:
    conv_hint_t() = default;
    conv_hint_t(int def_max_tg_size) : def_max_tg_size_(def_max_tg_size) {}

    int max_tg_size() const {
        if (max_tg_size_ != 0) return max_tg_size_;
        return def_max_tg_size_;
    }

    bool max_tg_overridden() const { return max_tg_overridden_; }

    void set_max_tg_size(int value) {
        max_tg_overridden_ = max_tg_size_ != 0;
        max_tg_size_ = value;
    }

private:
    int max_tg_size_ = 0;
    int def_max_tg_size_ = 0;
    bool max_tg_overridden_ = false;
};

class conv_arg_helper_t {
public:
    conv_arg_helper_t(const conv_problem_t &prb) : prb_(prb) {}

    int src_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_SRC;
        if (prb_.is_bwd_d) return DNNL_ARG_DIFF_SRC;
        if (prb_.is_bwd_w) return DNNL_ARG_SRC;
        ir_error_not_expected();
        return -1;
    }

    bool is_src_input() const { return prb_.is_fwd || prb_.is_bwd_w; }
    bool is_src_output() const { return prb_.is_bwd_d; }

    int wei_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_WEIGHTS;
        if (prb_.is_bwd_d) return DNNL_ARG_WEIGHTS;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_WEIGHTS;
        ir_error_not_expected();
        return -1;
    }

    bool is_wei_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
    bool is_wei_output() const { return prb_.is_bwd_w; }

    int bia_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_BIAS;
        if (prb_.is_bwd_d) return DNNL_ARG_BIAS;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_BIAS;
        ir_error_not_expected();
        return -1;
    }

    bool is_bia_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
    bool is_bia_output() const { return prb_.is_bwd_w; }

    int dst_arg_key() const {
        if (prb_.is_fwd) return DNNL_ARG_DST;
        if (prb_.is_bwd_d) return DNNL_ARG_DIFF_DST;
        if (prb_.is_bwd_w) return DNNL_ARG_DIFF_DST;
        ir_error_not_expected();
        return -1;
    }

    bool is_dst_input() const { return prb_.is_bwd_d || prb_.is_bwd_w; }
    bool is_dst_output() const { return prb_.is_fwd; }

private:
    const conv_problem_t &prb_;
};

class param_t {
public:
    virtual std::string name() const = 0;
    virtual std::string short_name() const { return name(); }
    virtual std::string desc() const = 0;

    virtual bool accept_key(const std::string &key) const {
        return key == short_name();
    }

    virtual void set_from_str(const std::string &s) { ir_error_not_expected(); }
    virtual void set_from_str(
            const std::string &key, const std::string &value) {
        if (key == short_name()) {
            set_from_str(value);
            return;
        }
        ir_error_not_expected();
    }
    void override_set(const std::string &key, const std::string &value) {
        is_overridden_[key] = true;
        set_from_str(key, value);
    }

    bool is_overridden() const {
        if (is_overridden_.empty()) return false;
        return is_overridden(short_name());
    }

    bool is_overridden(const std::string &key) const {
        auto it = is_overridden_.find(key);
        if (it == is_overridden_.end()) return false;
        return it->second;
    }

private:
    std::unordered_map<std::string, bool> is_overridden_;
};

template <typename T>
class value_param_t : public param_t {
public:
    using value_t = T;
    using param_t::is_overridden;

    value_param_t() = default;
    value_param_t(const T &value) : value_(value) {}

    const T &get() const { return value_; }

    operator const T &() const { return get(); }

    void set(const T &value) { value_ = value; }

protected:
    T value_;
};

class bool_param_t : public value_param_t<bool> {
public:
    using value_param_t::value_param_t;

    void set_from_str(const std::string &s) override {
        value_ = ir_utils::to_bool(s);
    }
};

class int_param_t : public value_param_t<int> {
public:
    using value_param_t::value_param_t;

    void set_from_str(const std::string &s) override { value_ = std::stoi(s); }
};

class grid_param_t : public value_param_t<grid_info_t> {
public:
    using value_param_t::value_param_t;
};

class layout_param_t : public param_t {
public:
    const layout_t &user() const { return user_; }
    const layout_t &compute() const { return compute_; }
    const layout_t &user_unnormalized() const { return user_unnormalized_; }
    const layout_t &compute_unnormalized() const {
        return compute_unnormalized_;
    }

    void set_from_str(const std::string &s) override {
        ir_error_not_implemented();
    }

    void set_user(const layout_t &l) { user_ = l; }
    void set_compute(const layout_t &l) { compute_ = l; }
    void set_user_unnormalized(const layout_t &l) { user_unnormalized_ = l; }
    void set_compute_unnormalized(const layout_t &l) {
        compute_unnormalized_ = l;
    }

private:
    layout_t user_;
    layout_t compute_;
    layout_t user_unnormalized_;
    layout_t compute_unnormalized_;
};

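// Parses a string of (name, value) pairs into a map. The format (as
// implemented by the loop below): each name is a run of non-digit
// characters immediately followed by its integer value. For example
// (illustrative), to_map("ic32oc16") returns {{"ic", 32}, {"oc", 16}}.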
inline std::unordered_map<std::string, int> to_map(const std::string &s) {
    std::unordered_map<std::string, int> ret;
    int name_beg = -1;
    int value_beg = -1;
    for (int pos = 0; pos < (int)s.size() + 1; pos++) {
        bool prev_digit = pos > 0 && std::isdigit(s[pos - 1]);
        bool cur_digit = pos < (int)s.size() && std::isdigit(s[pos]);
        if ((pos == 0 || prev_digit) && !cur_digit) {
            if (name_beg != -1 && value_beg != -1) {
                auto key = s.substr(name_beg, value_beg - name_beg);
                auto value = std::stoi(s.substr(value_beg, pos - value_beg));
                ret[key] = value;
            }
            name_beg = pos;
            value_beg = -1;
        }
        if (!prev_digit && cur_digit) value_beg = pos;
    }
    return ret;
}

class map_param_t : public param_t {
public:
    using value_t = std::unordered_map<std::string, int>;

    const value_t &get() const { return map_; }

    bool is_empty() const { return map_.empty(); }

    int get(const std::string &name) const {
        auto it = map_.find(name);
        if (it == map_.end()) return 1;
        return it->second;
    }

    int operator()(const std::string &name) const { return get(name); }

    void set_from_str(const std::string &s) override {
        map_.clear();
        map_ = to_map(s);
    }

    void set(const std::string &name, int dim) {
        auto it = map_.find(name);
        if (dim == 1) {
            if (it != map_.end()) map_.erase(it);
            return;
        }
        map_[name] = dim;
    }

    void set(const value_t &value) { map_ = value; }

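    // Note: map_ is an unordered container, so the order of "<name><value>"
    // pairs in the resulting string is unspecified.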
    std::string str() const {
        std::ostringstream oss;
        for (auto &kv : map_) {
            oss << kv.first << kv.second;
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    value_t map_;
};

// Parameters for kernel generation.

// TODO: Remove, GRF reorder is to be emitted/accounted for depending on
// layouts, FMA kind and other parameters.
class allow_a_grf_reorder_param_t : public bool_param_t {
public:
    allow_a_grf_reorder_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-a-grf-reorder"; }
    std::string desc() const override {
        return "Whether to allow GRF reorders to FMA-friendly layouts for A.";
    }
};

// TODO: Remove, GRF reorder is to be emitted/accounted for depending on
// layouts, FMA kind and other parameters.
class allow_b_grf_reorder_param_t : public bool_param_t {
public:
    allow_b_grf_reorder_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-b-grf-reorder"; }
    std::string desc() const override {
        return "Whether to allow GRF reorders to FMA-friendly layouts for B.";
    }
};

// TODO: Remove, use internal logic to determine if SLM thread group slicing is
// needed.
class allow_slm_tg_slicing_param_t : public bool_param_t {
public:
    allow_slm_tg_slicing_param_t() : bool_param_t(false) {}
    std::string name() const override { return "allow-slm-tg-slicing"; }
    std::string desc() const override {
        return "Whether to allow thread group split for SLM load/store.";
    }
};

class assign_sbids_param_t : public bool_param_t {
public:
    assign_sbids_param_t() : bool_param_t(false) {}
    std::string name() const override { return "assign-sbids"; }
    std::string desc() const override {
        return "Whether to manually assign SBID tokens to dpas/send.";
    }
};

class bia_layout_param_t : public layout_param_t {
    std::string name() const override { return "bia"; }
    std::string desc() const override { return "Bias layout."; }
};

class bwd_d_optimize_strided_param_t : public bool_param_t {
public:
    bwd_d_optimize_strided_param_t() : bool_param_t(false) {}
    std::string name() const override { return "bwd-d-optimize-strided"; }
    std::string desc() const override {
        return "Apply special optimization for strided BWD_D convolution.";
    }
};

class bwd_d_optimize_strided_iw_param_t : public bool_param_t {
public:
    bwd_d_optimize_strided_iw_param_t() : bool_param_t(false) {}
    std::string name() const override { return "bwd-d-optimize-strided-iw"; }
    std::string desc() const override {
        return "Apply special optimization for strided BWD_D convolution (iw "
               "dimension).";
    }
};

// TODO: Remove, use heuristics to determine if it's worth to sacrifice EU
// utilization for larger SLM size.
class check_slm_size_param_t : public bool_param_t {
public:
    check_slm_size_param_t() : bool_param_t(true) {}
    std::string name() const override { return "check-slm-size"; }
    std::string short_name() const override { return "c"; }
    std::string desc() const override {
        return "Whether to check SLM size to ensure full EU utilization.";
    }
};

class dims_param_t : public map_param_t {
public:
    std::string name() const override { return "dims"; }
    std::string desc() const override { return "Problem dimensions."; }
};

class dst_layout_param_t : public layout_param_t {
    std::string name() const override { return "dst"; }
    std::string desc() const override { return "Destination layout."; }
};

class exec_cfg_param_t : public value_param_t<exec_config_t> {
public:
    using value_param_t::accept_key;
    using value_param_t::is_overridden;
    using value_param_t::value_param_t;

    std::string name() const override { return "exec-cfg"; }
    std::string desc() const override {
        return "Execution config (hardware config, number of registers, SIMD, "
               "etc).";
    }

    bool accept_key(const std::string &key) const override {
        if (key == "simd") return true;
        return false;
    }

    void set_from_str(
            const std::string &key, const std::string &value) override {
        if (key == "simd") {
            value_.set_simd(std::stoi(value));
        } else {
            ir_error_not_expected() << key;
        }
    }
};

class fma_kind_param_t : public value_param_t<fma_kind_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "fma"; }
    std::string desc() const override { return "FMA kind."; }

    void set_from_str(const std::string &s) override {
        value_ = fma_kind::from_string(s);
    }
};

class fuse_spatial_param_t : public bool_param_t {
public:
    fuse_spatial_param_t() : bool_param_t(false) {}
    std::string name() const override { return "fuse-spatial"; }
    std::string short_name() const override { return "fsp"; }
    std::string desc() const override {
        return "Whether to apply blocking to fused spatial (otherwise only `w` "
               "is blocked).";
    }
};

class hint_param_t : public value_param_t<conv_hint_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "hint"; }
    std::string desc() const override { return "Configuration hint."; }
};

// TODO: Remove, use internal logic.
class hoist_masks_from_compute_loop_param_t : public bool_param_t {
public:
    hoist_masks_from_compute_loop_param_t() : bool_param_t(false) {}
    std::string name() const override {
        return "hoist-masks-from-compute-loop";
    }
    std::string desc() const override {
        return "Whether to move send mask initialization out of compute loop.";
    }
};

class kernel_grid_param_t : public grid_param_t {
public:
    std::string name() const override { return "kernel-grid"; }
    std::string desc() const override {
        return "Number of thread groups across dimensions (kernel grid).";
    }
};

class iter_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "iter"; }
    std::string short_name() const override { return "i"; }
    std::string desc() const override {
        return "Iteration-level dimension blocks.";
    }
};

class loop_dims_param_t : public dims_param_t {
public:
    std::string name() const override { return "loop"; }
    std::string short_name() const override { return "l"; }
    std::string desc() const override { return "Loop-level dimension blocks."; }
};

class shrink_tg_dims_param_t : public bool_param_t {
public:
    shrink_tg_dims_param_t() : bool_param_t(false) {}
    std::string name() const override { return "shrink-tg-dims"; }
    std::string short_name() const override { return "stg"; }
    std::string desc() const override {
        return "Whether to adjust tile sizes depending on batch size.";
    }
};

class ow_kw_grf_cache_param_t : public bool_param_t {
public:
    ow_kw_grf_cache_param_t() : bool_param_t(false) {}
    std::string name() const override { return "ow-kw-grf-cache"; }
    std::string desc() const override {
        return "Whether to use GRF cache to reuse source for ow/kw pairs.";
    }
};

class pad_slm_param_t : public bool_param_t {
public:
    pad_slm_param_t() : bool_param_t(true) {}
    std::string name() const override { return "pad-slm"; }
    std::string desc() const override {
        return "Whether to pad SLM layout to avoid bank conflicts.";
    }
};

class padded_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "pad"; }
    std::string desc() const override {
        return "Padded dimensions (rounded-up for blocks and to comply with "
               "required zero padding in output layouts).";
    }
};

class pipeline_param_t : public param_t {
public:
    std::string name() const override { return "pipeline"; }
    std::string short_name() const override { return "P"; }
    std::string desc() const override { return "General pipeline parameters."; }

    bool do_unroll() const { return do_unroll_; }
    bool reuse_headers() const { return !do_unroll(); }

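    // Example (illustrative): set_from_str("u") enables loop unrolling;
    // reuse_headers() then returns false.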
    void set_from_str(const std::string &s) override {
        do_unroll_ = false;
        for (auto c : s) {
            switch (c) {
                case 'u': do_unroll_ = true; break;
                default: ir_error_not_expected() << s;
            }
        }
    }

    void set(bool do_unroll) { do_unroll_ = do_unroll; }

private:
    bool do_unroll_ = false;
};

class prb_param_t : public value_param_t<conv_problem_t> {
public:
    using value_param_t::value_param_t;

    std::string name() const override { return "prb"; }
    std::string desc() const override { return "Convolution problem."; }

    void set_pd(const convolution_pd_t *pd) {
        value_.conv_pd = pd;
        value_.attr = pd->attr();
    }
};

class prefetch_param_t : public param_t {
public:
    std::string name() const override { return "prefetch"; }
    std::string short_name() const override { return "p"; }
    std::string desc() const override { return "Parameters for prefetching."; }

    int bufs() const { return bufs_; }

    operator bool() const { return bufs_ > 0; }

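    // Example (illustrative): set_from_str("x3") requests prefetching with
    // three buffers.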
    void set_from_str(const std::string &s) override {
        auto parts = ir_utils::split(s, ".");
        for (auto &p : parts) {
            ir_assert(p.size() >= 2) << p;
            char name = p[0];
            int value = std::stoi(p.substr(1));
            switch (name) {
                case 'x': bufs_ = value; break;
                default: ir_error_not_expected() << p;
            }
        }
    }

    void set(int bufs) { bufs_ = bufs; }

private:
    int bufs_ = 0;
};

class reduce_b_param_t : public bool_param_t {
public:
    reduce_b_param_t() : bool_param_t(false) {}
    std::string name() const override { return "reduce-b"; }
    std::string desc() const override {
        return "Whether to reduce B tensor (used for dst reduction in backward "
               "by weights).";
    }
};

class reduce_grf_usage_param_t : public bool_param_t {
public:
    reduce_grf_usage_param_t() : bool_param_t(true) {}
    std::string name() const override { return "reduce-grf-usage"; }
    std::string short_name() const override { return "r"; }
    std::string desc() const override {
        return "Whether to try to reduce GRF usage based on heuristics.";
    }
};

// TODO: Remove this parameter and enable 2D block messages based on the
// generation flow.
class send_2d_nhwc_param_t : public bool_param_t {
public:
    send_2d_nhwc_param_t() : bool_param_t(false) {}
    std::string name() const override { return "2d-send-nhwc"; }
    std::string desc() const override {
        return "Whether to use the optimal NHWC setup relying on 2D block "
               "messages.";
    }
};

class slm_param_t : public param_t {
public:
    std::string name() const override { return "slm"; }
    std::string short_name() const override { return "s"; }
    std::string desc() const override { return "SLM buffering parameters."; }

    int bufs() const { return bufs_; }
    int gmem_bufs() const { return gmem_bufs_; }
    int sync_version() const { return sync_version_; }
    bool a() const { return a_; }
    bool b() const { return b_; }

    operator bool() const { return bufs() > 0; }

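    // Example (illustrative): set_from_str("x2.g2.v1") requests double SLM
    // buffering for both A and B, two GRF buffers for the GMEM -> SLM copy,
    // and sync version 1.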
    void set_from_str(const std::string &s) override {
        auto parts = ir_utils::split(s, ".");
        for (auto &p : parts) {
            ir_assert(p.size() >= 2) << p;
            char name = p[0];
            int value = std::stoi(p.substr(1));
            switch (name) {
                case 'x': bufs_ = value; break;
                case 'g': gmem_bufs_ = value; break;
                case 'v': sync_version_ = value; break;
                default: ir_error_not_expected() << p;
            }
        }
        if (bufs_ > 0) {
            a_ = true;
            b_ = true;
        }
    }

    void set(int bufs, int gmem_bufs, bool a, bool b) {
        bufs_ = bufs;
        gmem_bufs_ = gmem_bufs;
        a_ = a;
        b_ = b;
    }

    void set_bufs(int bufs) { bufs_ = bufs; }
    void set_gmem_bufs(int gmem_bufs) { gmem_bufs_ = gmem_bufs; }

private:
    // Number of SLM buffers to use (0, 1, 2 or 3).
    int bufs_ = 0;
    // Number of GRF buffers to use for GMEM -> SLM copy (0, 1 or 2).
    int gmem_bufs_ = 0;
    // See slm_sync_manager_t for more details.
    int sync_version_ = -1;
    // Whether SLM buffering for A is enabled.
    bool a_ = false;
    // Whether SLM buffering for B is enabled.
    bool b_ = false;
};

class src_layout_param_t : public layout_param_t {
public:
    std::string name() const override { return "src"; }
    std::string desc() const override { return "Source layout."; }
};

// Subtiles to split into for the inner A x B multiplication:
//
// Case 1. a_subtiles = 1, b_subtiles = 1
//     A = load(...)
//     B = load(...)
//     C += A * B
//
// Case 2. a_subtiles > 1, b_subtiles = 1
//     B = load(...)
//     for i in range(0, a_subtiles):
//         A_i = load(...)
//         C_i += A_i * B
//
// Case 3. a_subtiles = 1, b_subtiles > 1
//     A = load(...)
//     for j in range(0, b_subtiles):
//         B_j = load(...)
//         C_j += A * B_j
//
// Tiling for A and tiling for B are mutually exclusive. Using subtiles helps
// to reduce GRF consumption.
class subtiles_param_t : public param_t {
public:
    std::string name() const override { return "subtiles"; }
    std::string short_name() const override { return "S"; }
    std::string desc() const override { return "Sub-iteration blocking."; }

    int a() const { return a_; }
    int b() const { return b_; }

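    // Example (illustrative): set_from_str("a2") splits A into two subtiles
    // (case 2 above): B is loaded once, A is loaded and multiplied in two
    // pieces.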
    void set_from_str(const std::string &s) override {
        a_ = 1;
        b_ = 1;
        for (auto &kv : to_map(s)) {
            if (kv.first == "a") {
                a_ = kv.second;
            } else if (kv.first == "b") {
                b_ = kv.second;
            } else {
                ir_error_not_expected() << kv.first;
            }
        }
    }

    void set(int a, int b) {
        a_ = a;
        b_ = b;
    }

    void set_a(int a) { a_ = a; }
    void set_b(int b) { b_ = b; }

private:
    int a_ = 1;
    int b_ = 1;
};

class thread_group_grid_param_t : public grid_param_t {
public:
    std::string name() const override { return "tg-grid"; }
    std::string desc() const override { return "Thread group grid."; }
};

class thread_group_dims_param_t : public map_param_t {
public:
    std::string name() const override { return "tg"; }
    std::string short_name() const override { return "T"; }
    std::string desc() const override {
        return "Thread group-level dimension blocks.";
    }
};

class unroll_param_t : public map_param_t {
public:
    std::string name() const override { return "unroll"; }
    std::string short_name() const override { return "u"; }
    std::string desc() const override {
        return "Per-dimension unroll factors.";
    }
};

class wei_layout_param_t : public layout_param_t {
    std::string name() const override { return "wei"; }
    std::string desc() const override { return "Weights layout."; }
};

namespace constants {
// Maximum number of SLM buffers.
static const int max_slm_bufs = 3;

// GRF usage for kernel arguments, local work IDs/sizes, signal header,
// temporary expressions, etc.
static const int reserved_regs = 20;
} // namespace constants

class conv_config_t {
public:
    conv_config_t() = default;

#define DECL_PARAM(name) \
    const name##_param_t &name##_param() const { \
        (void)name##_init_; \
        return name##_; \
    } \
    name##_param_t &name##_param() { return name##_; } \
    const name##_param_t::value_t &name() const { return name##_.get(); } \
    void set_##name(const name##_param_t::value_t &value) { \
        name##_.set(value); \
    }

#define DECL_PARAM2(name) \
    const name##_param_t &name() const { \
        (void)name##_init_; \
        return name##_; \
    } \
    name##_param_t &name() { return name##_; }

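    // For instance, DECL_PARAM(fma_kind) generates fma_kind_param()
    // accessors, fma_kind() returning the underlying fma_kind_t value, and
    // set_fma_kind(); DECL_PARAM2(slm) exposes the slm_param_t object
    // directly via slm().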
    DECL_PARAM(allow_a_grf_reorder)
    DECL_PARAM(allow_b_grf_reorder)
    DECL_PARAM(allow_slm_tg_slicing)
    DECL_PARAM(assign_sbids)
    DECL_PARAM(bwd_d_optimize_strided)
    DECL_PARAM(bwd_d_optimize_strided_iw)
    DECL_PARAM(check_slm_size)
    DECL_PARAM(exec_cfg)
    DECL_PARAM(fma_kind)
    DECL_PARAM(fuse_spatial)
    DECL_PARAM(hint)
    DECL_PARAM(hoist_masks_from_compute_loop)
    DECL_PARAM(kernel_grid)
    DECL_PARAM(ow_kw_grf_cache)
    DECL_PARAM(pad_slm)
    DECL_PARAM(prb)
    DECL_PARAM(reduce_b)
    DECL_PARAM(reduce_grf_usage)
    DECL_PARAM(send_2d_nhwc)
    DECL_PARAM(shrink_tg_dims)
    DECL_PARAM(thread_group_grid)
    DECL_PARAM2(bia_layout)
    DECL_PARAM2(dims)
    DECL_PARAM2(dst_layout)
    DECL_PARAM2(iter_dims)
    DECL_PARAM2(loop_dims)
    DECL_PARAM2(padded_dims)
    DECL_PARAM2(pipeline)
    DECL_PARAM2(prefetch)
    DECL_PARAM2(slm)
    DECL_PARAM2(src_layout)
    DECL_PARAM2(subtiles)
    DECL_PARAM2(thread_group_dims)
    DECL_PARAM2(unroll)
    DECL_PARAM2(wei_layout)

#undef DECL_PARAM
#undef DECL_PARAM2

    void override_set(const std::string &s);

    std::string str() const;

    std::string blocking_brief_str() const;

    // Helper methods.
    int dim(const std::string &name) const { return dims()(name); }

    int iter_dim(const std::string &name) const { return iter_dims()(name); }

    int padded_dim(const std::string &name) const {
        return padded_dims()(name);
    }

    int loop_dim(const std::string &name) const { return loop_dims()(name); }

    int thread_group_dim(const std::string &name) const {
        return thread_group_dims()(name);
    }

    // Blocks for padding. This is to comply with zero-padding requirements.
    // For example if the output layout is nChw32c but there are only 8
    // channels to compute and store, we still need to pad 8 to 32 and spawn
    // more thread groups to ensure the 32c block is properly zero-padded.
    int pad_block(const std::string &name) const {
        auto &src = src_layout().compute();
        auto &wei = wei_layout().compute();
        auto &dst = dst_layout().compute();

#define CASE(_name, layout, idx) \
    if (name == _name) return layout.inner_block(idx)

        if (prb().is_fwd) {
            CASE("mb", dst, 0);
            CASE("g", dst, 1);
            CASE("oc", dst, 2);
        } else if (prb().is_bwd_d) {
            CASE("mb", src, 0);
            CASE("g", src, 1);
            CASE("ic", src, 2);
        } else if (prb().is_bwd_w) {
            CASE("g", wei, 0);
            CASE("oc", wei, 1);
            CASE("ic", wei, 2);
        }
#undef CASE
        return 1;
    }

    int unroll(const std::string &name) const { return unroll()(name); }

    int reserved_regs() const { return constants::reserved_regs; }

    const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); }

    ngen::HW hw() const { return hw_cfg().hw(); }

    bool is_ge_xe_hpc() const { return hw() >= ngen::HW::XeHPC; }

    int grf_size() const { return hw_cfg().grf_size(); }

    int regs() const { return exec_cfg().regs(); }

    int simd() const { return exec_cfg().simd(); }

    int vec_size() const { return exec_cfg().vec_size(); }

    bool is_g_mad() const {
        return fma_kind() == fma_kind_t::mad && prb().g > 1 && prb().ic < 4
                && prb().oc < 4 && prb().mb < 8 && !prb().is_dw;
    }

    bool is_dp_fma() const {
        return utils::one_of(fma_kind(), fma_kind_t::dpas, fma_kind_t::dpasw,
                fma_kind_t::dp4a);
    }

    bool is_dpas_or_dpasw_fma() const {
        return utils::one_of(fma_kind(), fma_kind_t::dpas, fma_kind_t::dpasw);
    }

    const layout_param_t &a_layout() const {
        return prb().pick_a<const layout_param_t &>(
                src_layout(), wei_layout(), dst_layout());
    }

    const layout_param_t &b_layout() const {
        return prb().pick_b<const layout_param_t &>(
                src_layout(), wei_layout(), dst_layout());
    }

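    // Computes the dispatch ranges. Illustration: with simd() == 16, a
    // thread group grid of (4, 2, 1) and a kernel grid of (8, 4, 2), the
    // local size is {64, 2, 1} and the global size is {512, 8, 2}.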
    compute::nd_range_t nd_range() const {
        size_t gws[3];
        size_t lws[3];
        for (int i = 0; i < 3; i++) {
            lws[i] = thread_group_grid().dim(i) * (i == 0 ? simd() : 1);
            gws[i] = kernel_grid().dim(i) * lws[i];
        }
        return compute::nd_range_t(gws, lws);
    }

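    // Number of kernel grid blocks along a dimension: the padded size
    // divided by the combined loop, thread group and iteration blocks.
    // Example (illustrative): with padded_dim("oc") == 512 and loop/tg/iter
    // blocks of 1/4/32, grid_dim("oc") == 512 / (1 * 4 * 32) == 4.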
    int grid_dim(const std::string &dim) const {
        return ir_utils::safe_divide(padded_dim(dim),
                loop_dim(dim) * thread_group_dim(dim) * iter_dim(dim));
    }

    int iter_dim(std::initializer_list<const char *> dims) const {
        int ret = 1;
        for (auto *dim : dims)
            ret *= iter_dim(dim);
        return ret;
    }

    void set_pd(const convolution_pd_t *pd) { prb_.set_pd(pd); }

    void set_regs(int regs) {
        auto tmp = exec_cfg();
        tmp.set_regs(regs);
        set_exec_cfg(tmp);
    }

    void set_simd(int simd) {
        auto tmp = exec_cfg();
        tmp.set_simd(simd);
        set_exec_cfg(tmp);
    }

    void set_vec_size(int vec_size) {
        auto tmp = exec_cfg();
        tmp.set_vec_size(vec_size);
        set_exec_cfg(tmp);
    }

    bool can_skip_wei_zero_out() const;
    bool can_skip_bia_zero_out() const;

private:
    struct param_init_t {};

    template <typename GetParamFuncT, typename PtrT>
    static param_init_t register_param(
            PtrT ptr, std::vector<GetParamFuncT> &get_params_) {
        get_params_.push_back([=](conv_config_t *cfg) { return &(cfg->*ptr); });
        return param_init_t();
    }

    std::vector<std::function<param_t *(conv_config_t *)>> get_params_;

#define INIT_PARAM(name) \
    name##_param_t name##_; \
    param_init_t name##_init_ \
            = register_param(&conv_config_t::name##_, get_params_);

    INIT_PARAM(allow_a_grf_reorder)
    INIT_PARAM(allow_b_grf_reorder)
    INIT_PARAM(allow_slm_tg_slicing)
    INIT_PARAM(assign_sbids)
    INIT_PARAM(bia_layout)
    INIT_PARAM(bwd_d_optimize_strided)
    INIT_PARAM(bwd_d_optimize_strided_iw)
    INIT_PARAM(check_slm_size)
    INIT_PARAM(dims)
    INIT_PARAM(dst_layout)
    INIT_PARAM(exec_cfg)
    INIT_PARAM(fma_kind)
    INIT_PARAM(fuse_spatial)
    INIT_PARAM(hint)
    INIT_PARAM(hoist_masks_from_compute_loop)
    INIT_PARAM(iter_dims)
    INIT_PARAM(kernel_grid)
    INIT_PARAM(loop_dims)
    INIT_PARAM(ow_kw_grf_cache)
    INIT_PARAM(pad_slm)
    INIT_PARAM(padded_dims)
    INIT_PARAM(pipeline)
    INIT_PARAM(prb)
    INIT_PARAM(prefetch)
    INIT_PARAM(reduce_b)
    INIT_PARAM(reduce_grf_usage)
    INIT_PARAM(send_2d_nhwc)
    INIT_PARAM(shrink_tg_dims)
    INIT_PARAM(slm)
    INIT_PARAM(src_layout)
    INIT_PARAM(subtiles)
    INIT_PARAM(thread_group_dims)
    INIT_PARAM(thread_group_grid)
    INIT_PARAM(unroll)
    INIT_PARAM(wei_layout)

#undef INIT_PARAM
};

inline std::ostream &operator<<(std::ostream &out, const conv_config_t &cfg) {
    out << cfg.str();
    return out;
}

class bmnk_dim_helper_t {
public:
    bmnk_dim_helper_t(const conv_config_t &cfg) : prb_(cfg.prb()), cfg_(cfg) {}

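    // Accumulates blocks over all problem dims that map to the given GEMM
    // dim. Example (illustrative): for FWD with iteration blocks
    // {mb: 8, ow: 2, oc: 32}, iter_dim('m') == 16 (mb and ow both map to M)
    // and iter_dim('n') == 32.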
    int iter_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.iter_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

    int thread_group_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.thread_group_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

    int loop_dim(char bmnk) const {
        int ret = 1;
        for (auto &kv : cfg_.loop_dims().get()) {
            if (to_bmnk(kv.first) != bmnk) continue;
            ret *= kv.second;
        }
        return ret;
    }

private:
    static bool contains(const char **array, const std::string &s) {
        for (const char **ptr = array; *ptr; ptr++) {
            if (s == *ptr) return true;
        }
        return false;
    }

    char to_bmnk(const std::string &dim_name) const {
        static const char *fwd_b_dims[] = {"g", nullptr};
        static const char *fwd_m_dims[]
                = {"mb", "osp", "od", "oh", "ow", nullptr};
        static const char *fwd_n_dims[] = {"oc", nullptr};
        static const char *fwd_k_dims[] = {"ic", "kd", "kh", "kw", nullptr};
        static const char *bwd_d_b_dims[] = {"g", nullptr};
        static const char *bwd_d_m_dims[] = {"mb", "id", "ih", "iw", nullptr};
        static const char *bwd_d_n_dims[] = {"ic", nullptr};
        static const char *bwd_d_k_dims[] = {"oc", "kd", "kh", "kw", nullptr};
        static const char *bwd_w_b_dims[] = {"g", nullptr};
        static const char *bwd_w_m_dims[] = {"ic", "kd", "kh", "kw", nullptr};
        static const char *bwd_w_n_dims[] = {"oc", nullptr};
        static const char *bwd_w_k_dims[] = {"mb", "od", "oh", "ow", nullptr};

        const char **b_dims = prb_.pick_by_dir<const char **>(
                fwd_b_dims, bwd_d_b_dims, bwd_w_b_dims);
        const char **m_dims = prb_.pick_by_dir<const char **>(
                fwd_m_dims, bwd_d_m_dims, bwd_w_m_dims);
        const char **n_dims = prb_.pick_by_dir<const char **>(
                fwd_n_dims, bwd_d_n_dims, bwd_w_n_dims);
        const char **k_dims = prb_.pick_by_dir<const char **>(
                fwd_k_dims, bwd_d_k_dims, bwd_w_k_dims);

        if (contains(b_dims, dim_name)) return 'b';
        if (contains(m_dims, dim_name)) return 'm';
        if (contains(n_dims, dim_name)) return 'n';
        if (contains(k_dims, dim_name)) return 'k';

        ir_error_not_expected() << dim_name;
        return ' ';
    }

    const conv_problem_t &prb_;
    const conv_config_t &cfg_;
};

status_t init_pd_time_cfg(const conv_problem_t &prb, conv_config_t &cfg,
        const engine_t *engine, convolution_pd_t *pd, primitive_attr_t *attr);
status_t init_cfg(conv_config_t &cfg, const convolution_pd_t *pd);
tensor_config_t get_tensor_config(const conv_config_t &cfg);
int estimate_register_count(const conv_config_t &cfg);
bool can_use_a_2d_send(const conv_config_t &cfg);
bool can_use_b_2d_send(const conv_config_t &cfg);
const char **get_kernel_grid_conv_dims(const conv_problem_t &prb, int idx);
const char **get_thread_group_grid_conv_dims(
        const conv_problem_t &prb, int idx);

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif