1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_BLOCK_HELPER_HPP
18#define GPU_JIT_CONV_BLOCK_HELPER_HPP
19
20#include <algorithm>
21#include <cmath>
22#include <functional>
23#include <iostream>
24#include <sstream>
25#include <vector>
26#include <initializer_list>
27#include <unordered_map>
28
29#include "common/c_types_map.hpp"
30#include "common/math_utils.hpp"
31#include "common/utils.hpp"
32#include "gpu/compute/device_info.hpp"
33#include "gpu/jit/ir/core.hpp"
34#include "gpu/jit/ir/fma.hpp"
35#include "gpu/jit/ir/hw_config.hpp"
36#include "gpu/jit/ngen/ngen.hpp"
37#include "gpu/jit/utils/utils.hpp"
38
39namespace dnnl {
40namespace impl {
41namespace gpu {
42namespace jit {
43
// Tile levels: the hierarchy of dimensions used for blocking. Enumerator
// values are spelled out because they are used directly as array indices.
enum class tile_level_t {
    unknown = 0,
    iter = 1, // Number of elements per iteration.
    tg = 2, // Number of threads per thread group.
    loop = 3, // Number of iterations per loop.
    _last = 4,
};

// Integer index range [min_tile_level_idx, max_tile_level_idx] covering the
// real tile levels (iter, tg, loop); `unknown` and `_last` are excluded.
constexpr int min_tile_level_idx = static_cast<int>(tile_level_t::iter);
constexpr int max_tile_level_idx = static_cast<int>(tile_level_t::loop);
56
57// Describes the dimension value, contains either an integer or a special
58// "unlimited" value which behaves like infinity when mixing in operations with
59// integer values.
60class dim_value_t {
61public:
62 dim_value_t() = default;
63 dim_value_t(int value) : value_(value) {}
64 dim_value_t &operator=(int value) {
65 is_unlimited_ = false;
66 value_ = value;
67 return *this;
68 }
69
70 bool is_unlimited() const { return is_unlimited_; }
71
72 bool operator==(dim_value_t other) const {
73 if (is_unlimited() && other.is_unlimited()) return true;
74 return (is_unlimited_ == other.is_unlimited_)
75 && (value_ == other.value_);
76 }
77 bool operator==(int value) const {
78 return !is_unlimited() && value_ == value;
79 }
80 bool operator!=(dim_value_t other) const { return !(*this == other); }
81 bool operator!=(int value) const { return !(*this == value); }
82
83 operator int() const {
84 if (is_unlimited_) {
85 ir_error_not_expected() << "Can't convert unlimited value to int.";
86 return -1;
87 }
88 return value_;
89 }
90
91 std::string str() const {
92 std::ostringstream oss;
93 if (is_unlimited()) {
94 oss << "(unlimited)";
95 } else {
96 oss << value_;
97 }
98 return oss.str();
99 }
100
101 IR_DEFINE_DUMP()
102
103 static dim_value_t unlimited() {
104 dim_value_t ret;
105 ret.is_unlimited_ = true;
106 return ret;
107 }
108
109private:
110 bool is_unlimited_ = false;
111 int value_ = 0;
112};
113
114inline dim_value_t min(dim_value_t a, dim_value_t b) {
115 if (a.is_unlimited()) return b;
116 if (b.is_unlimited()) return a;
117 return std::min((int)a, (int)b);
118}
119
120// Stores information about dimension blocking in context of BMNK/GEMM
121// notation.
122class dim_info_t {
123public:
124 dim_info_t() {
125 for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) {
126 tile_dims_[i] = 1;
127 max_tile_dims_[i] = dim_value_t::unlimited();
128 }
129 }
130
131 dim_info_t(const std::string &name, int size) : dim_info_t() {
132 name_ = name;
133 size_ = size;
134 }
135
136 const std::string &name() const { return name_; }
137 void set_name(const std::string &name) { name_ = name; }
138
139 int size() const { return size_; }
140 void set_size(int value) { size_ = value; }
141
142 int padded_size() const {
143 return utils::rnd_up(size(), math::lcm(tg_blk(), pad_blk_));
144 }
145
146 char bmnk() const { return bmnk_; }
147 void set_bmnk(char value) { bmnk_ = value; }
148
149 int base_iter_block() const { return base_iter_blk_; }
150 void set_base_iter_block(int value) { base_iter_blk_ = value; }
151
152 int pad_block() const { return pad_blk_; }
153 void set_pad_block(int value) { pad_blk_ = value; }
154
155 int order_key() const { return order_key_; }
156 void set_order_key(int value) { order_key_ = value; }
157
158 bool is_blocked() const { return is_blocked_; }
159 void set_blocked(bool value = true) { is_blocked_ = value; }
160
161 bool allow_fuse() const { return allow_fuse_; }
162 bool allow_split() const { return allow_split_; }
163 void set_allow_fuse(bool value = true) { allow_fuse_ = value; }
164 void set_allow_split(bool value = true) { allow_split_ = value; }
165
166 int grid_dim() const {
167 return ir_utils::safe_divide(padded_size(), tg_blk());
168 }
169 dim_value_t tg_dim() const { return dim(tile_level_t::tg); }
170 dim_value_t loop_dim() const { return dim(tile_level_t::loop); }
171 dim_value_t iter_dim() const { return dim(tile_level_t::iter); }
172 dim_value_t dim(tile_level_t level) const {
173 int idx = (int)level;
174 ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx);
175 return tile_dims_[idx];
176 }
177 dim_value_t max_dim(tile_level_t level) const {
178 int idx = (int)level;
179 ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx);
180 return max_tile_dims_[idx];
181 }
182 bool pref_tg_block() { return pref_tg_block_; }
183 void set_pref_tg_block(bool value = true) { pref_tg_block_ = value; }
184 void set_tg_dim(dim_value_t value) { set_dim(tile_level_t::tg, value); }
185 void set_loop_dim(dim_value_t value) { set_dim(tile_level_t::loop, value); }
186 void set_iter_dim(dim_value_t value) { set_dim(tile_level_t::iter, value); }
187 void set_dim(tile_level_t level, dim_value_t value) {
188 int idx = (int)level;
189 ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx);
190 tile_dims_[idx] = value;
191 }
192
193 void set_max_dim(tile_level_t level, dim_value_t value) {
194 int idx = (int)level;
195 ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx);
196 max_tile_dims_[idx] = value;
197 }
198
199 int inner_dims() const { return inner_dims_; }
200 void incr_inner_dims() { inner_dims_++; }
201
202 int iter_blk() const { return iter_dim(); }
203 int loop_blk() const { return loop_dim() * iter_blk(); }
204 int tg_blk() const { return tg_dim() * loop_blk(); }
205
206 bool has_any_blocking() const {
207 if (iter_dim() != 1) return true;
208 if (loop_dim() != 1) return true;
209 if (tg_dim() != 1) return true;
210 return false;
211 }
212
213 std::string str() const {
214 using namespace ir_utils;
215 std::ostringstream oss;
216 oss << "Dimension " << name_ << std::endl;
217 oss << " Size: " << size_ << std::endl;
218 oss << " Base iter block: " << base_iter_blk_ << std::endl;
219 oss << " BMNK: " << bmnk_ << std::endl;
220 oss << " Blocked: " << to_string(is_blocked_) << std::endl;
221 oss << " Allow fuse: " << to_string(allow_fuse_) << std::endl;
222 oss << " Allow split: " << to_string(allow_split_) << std::endl;
223 oss << " Order key: " << order_key_ << std::endl;
224
225 const char *tags[] = {"Iteration", "Thread group", "Loop", nullptr};
226 int level_id = min_tile_level_idx;
227 for (const char **tag = tags; *tag; tag++) {
228 tile_level_t level = (tile_level_t)level_id++;
229 dim_value_t max_dim_val = max_dim(level);
230 oss << " " << pad_str(*tag + std::string(" dim:"), -19);
231 oss << dim(level).str();
232 if (!max_dim_val.is_unlimited()) {
233 oss << " (max: " << max_dim_val.str() << ")";
234 }
235 oss << std::endl;
236 }
237 return oss.str();
238 }
239
240 std::string brief_str() const {
241 std::ostringstream oss;
242 oss << "Dimension " << name_ << pad_str(":", -18 + (int)name_.length());
243 oss << "(grid:" << pad_int(grid_dim(), 5) << ") x ";
244 oss << "(loop:" << pad_int(loop_dim(), 5) << ") x ";
245 oss << "(tg:" << pad_int(tg_dim(), 5) << ") x ";
246 oss << "(iter:" << pad_int(iter_dim(), 5) << ")";
247 return oss.str();
248 }
249
250 IR_DEFINE_DUMP()
251
252private:
253 static std::string pad_str(std::string s, int pad) {
254 auto pos = (pad >= 0 ? 0 : s.length());
255 s.insert(pos, std::abs(pad) - s.length(), ' ');
256 return s;
257 }
258
259 static std::string pad_int(int i, int pad) {
260 return pad_str(std::to_string(i), pad);
261 }
262
263 // Dimension name.
264 std::string name_;
265
266 // Dimension size.
267 int size_ = 0;
268
269 // Minimal block size for iteration blocking. Iteration-level block must be
270 // divisible by this value.
271 int base_iter_blk_ = 1;
272
273 // Block size to ensure correct zero padding. Blocked memory layouts must
274 // be fully covered to ensure they are zero-padded.
275 int pad_blk_ = 1;
276
277 // Dimension kind in terms of BMNK notation.
278 char bmnk_ = ' ';
279
280 // Number of prb dims for a BMNK dim
281 int inner_dims_ = 0;
282
283 // Whether the dimension can be blocked. Dimensions without blocking are
284 // implicitly tiled on the grid level (handle one element per thread
285 // group).
286 bool is_blocked_ = false;
287
288 // Whether the dimension can be fused with other "fused" dimensions on the
289 // same tile level.
290 bool allow_fuse_ = false;
291
292 // Whether the dimension can be split between multiple tile levels.
293 bool allow_split_ = false;
294
295 bool pref_tg_block_ = false;
296
297 // Dimensions with smaller order keys are tiled first.
298 int order_key_ = -1;
299
300 // Dimensions of tiles.
301 dim_value_t tile_dims_[max_tile_level_idx + 1];
302
303 // Max allowed dimensions of tiles.
304 dim_value_t max_tile_dims_[max_tile_level_idx + 1];
305};
306
307// Block helper provides functionality to compute tiling/blocking for a
308// GEMM-like problem (when problem dimensions can be classified in terms of
309// BMNK dimensions). A typical flow consists of three steps:
310// - Setting problem configuration:
311// - Problem dimensions (sizes, padding requirements, base block
312// requirements, BMNK behavior)
313// - HW details, FMA kind, data types, etc
314// - Setting restrictions/hints, for example:
315// - Maximal block sizes (e.g. to limit GRF usage)
316// - Fuse/split settings - to get the desired blocking decomposition
317// - K-slicing settings (grid or thread group sliciing)
318// - Computing block sizes for each tile: per iteration, per loop, per thread
319// group
320class block_helper_t {
321public:
322 bool is_frozen() const { return is_frozen_; }
323
324 const std::unordered_map<std::string, dim_info_t> &dims() const {
325 return dims_;
326 }
327
328 dim_info_t &dim(const std::string &name) {
329 ir_assert(dims_.count(name) != 0) << "Dimension not found: " << name;
330 return dims_.at(name);
331 }
332
333 void set_hw_config(const hw_config_t &hw_cfg) {
334 check_if_can_set();
335 hw_cfg_ = hw_cfg;
336 }
337
338 void set_fma_kind(fma_kind_t fma_kind) {
339 check_if_can_set();
340 fma_kind_ = fma_kind;
341 }
342
343 void set_simd_size(int simd_size) {
344 check_if_can_set();
345 simd_size_ = simd_size;
346 }
347
348 void set_vec_size(int vec_size) {
349 check_if_can_set();
350 vec_size_ = vec_size;
351 }
352
353 void set_max_tg_size(int max_tg_size) {
354 check_if_can_set();
355 max_tg_size_ = max_tg_size;
356 }
357
358 void set_max_tg_overridden(bool max_tg_overridden) {
359 check_if_can_set();
360 max_tg_overridden_ = max_tg_overridden;
361 }
362
363 void set_abc_types(
364 data_type_t a_type, data_type_t b_type, data_type_t c_type) {
365 check_if_can_set();
366 a_type_ = a_type;
367 b_type_ = b_type;
368 c_type_ = c_type;
369 }
370
371 void set_use_2d_send(bool use_a_2d_send, bool use_b_2d_send) {
372 use_a_2d_send_ = use_a_2d_send;
373 use_b_2d_send_ = use_b_2d_send;
374 }
375
376 void set_max_m_tg_dim(int value) {
377 m_dim().set_max_dim(tile_level_t::tg, value);
378 }
379
380 void set_max_n_tg_dim(int value) {
381 n_dim().set_max_dim(tile_level_t::tg, value);
382 }
383
384 void set_max_k_tg_dim(int value) {
385 k_dim().set_max_dim(tile_level_t::tg, value);
386 }
387
388 void set_dims(std::initializer_list<std::string> names,
389 std::initializer_list<int> sizes) {
390 check_if_can_set();
391 ir_assert(names.size() == sizes.size());
392 for (size_t i = 0; i < names.size(); i++) {
393 set_dim(*(names.begin() + i), *(sizes.begin() + i));
394 }
395 }
396
397 void set_dim(const std::string &name, int size) {
398 check_if_can_set();
399 ir_assert(dims_.count(name) == 0)
400 << "Dimension already exists: " << name;
401 dims_.emplace(name, dim_info_t(name, size));
402 }
403
404 void set_b_dims(std::initializer_list<std::string> names) {
405 set_bmnk_dims(names, 'B');
406 }
407 void set_m_dims(std::initializer_list<std::string> names) {
408 set_bmnk_dims(names, 'M');
409 }
410 void set_n_dims(std::initializer_list<std::string> names) {
411 set_bmnk_dims(names, 'N');
412 }
413 void set_k_dims(std::initializer_list<std::string> names) {
414 set_bmnk_dims(names, 'K');
415 }
416
417 void set_block_dims(std::initializer_list<std::string> names) {
418 check_if_can_set();
419 for (auto &name : names) {
420 dim(name).set_blocked();
421 }
422 }
423
424 void set_loop_dim(const std::string &name, int value) {
425 dim(name).set_loop_dim(value);
426 }
427
428 void set_tg_dim(const std::string &name, int value) {
429 dim(name).set_tg_dim(value);
430 }
431
432 void set_max_iter_dim(const std::string &name, int value) {
433 dim(name).set_max_dim(tile_level_t::iter, value);
434 }
435
436 void set_max_loop_dim(const std::string &name, int value) {
437 dim(name).set_max_dim(tile_level_t::loop, value);
438 }
439
440 void set_max_tg_dim(const std::string &name, int value) {
441 dim(name).set_max_dim(tile_level_t::tg, value);
442 }
443
444 void set_pref_tg_block(const std::string &name, bool value = true) {
445 dim(name).set_pref_tg_block(value);
446 }
447
448 bool any_pref_tg_block() {
449 for (auto &kv : dims_) {
450 auto &d = kv.second;
451 if (d.pref_tg_block()) return true;
452 }
453 return false;
454 }
455
456 void set_reduce_m_block_hint(bool value = true) {
457 reduce_m_block_hint_ = value;
458 reduce_m_block_hint_set_ = true;
459 }
460
461 void allow_fuse(std::initializer_list<std::string> names) {
462 check_if_can_set();
463 for (auto &name : names) {
464 dim(name).set_allow_fuse();
465 }
466 }
467
468 void allow_split(std::initializer_list<std::string> names) {
469 check_if_can_set();
470 for (auto &name : names) {
471 dim(name).set_allow_split();
472 }
473 }
474
475 void allow_k_tg_slicing() { allow_k_tg_slicing_ = true; }
476
477 void allow_k_grid_slicing() { allow_k_grid_slicing_ = true; }
478
479 void set_vector_dim(const std::string &name) {
480 check_if_can_set();
481 auto &d = dim(name);
482 d.set_base_iter_block(math::lcm(vec_size_, d.base_iter_block()));
483 vector_bmnk_ = d.bmnk();
484 }
485
486 void set_base_iter_block(const std::string &name, int block) {
487 check_if_can_set();
488 dim(name).set_base_iter_block(block);
489 }
490
491 void set_base_iter_block(const std::string &name, int block0, int block1) {
492 set_base_iter_block(name, math::lcm(block0, block1));
493 }
494
495 void set_pad_block(const std::string &name, int block) {
496 dim(name).set_pad_block(block);
497 }
498
499 void reorder(std::initializer_list<std::string> names) {
500 check_if_can_set();
501 int key = 0;
502 for (auto &name : names)
503 dim(name).set_order_key(key++);
504 }
505
506 void compute();
507
508 bool has_dim(const std::string &name) const {
509 return dims_.count(name) != 0;
510 }
511
512 int iter_dim(const std::string &name) const { return dim(name).iter_dim(); }
513 int loop_dim(const std::string &name) const { return dim(name).loop_dim(); }
514 int tg_dim(const std::string &name) const { return dim(name).tg_dim(); }
515 int grid_dim(const std::string &name) const { return dim(name).grid_dim(); }
516
517 int iter_blk(const std::string &name) const { return dim(name).iter_blk(); }
518 int loop_blk(const std::string &name) const { return dim(name).loop_blk(); }
519 int tg_blk(const std::string &name) const { return dim(name).tg_blk(); }
520
521 dim_value_t max_iter_dim(const std::string &name) const {
522 return dim(name).max_dim(tile_level_t::iter);
523 }
524
525 int padded_size(const std::string &name) const {
526 return dim(name).padded_size();
527 }
528
529 std::string str() const {
530 std::ostringstream oss;
531 for (auto &kv : dims_) {
532 auto &d = kv.second;
533 if (!d.has_any_blocking()) continue;
534 oss << d.str();
535 }
536 return oss.str();
537 }
538
539 std::string brief_str() {
540 std::ostringstream oss;
541 for (auto &kv : dims_) {
542 auto &d = kv.second;
543 if (!d.has_any_blocking()) continue;
544 oss << " " << d.brief_str() << std::endl;
545 }
546 return oss.str();
547 }
548
549 IR_DEFINE_DUMP()
550
551private:
552 void check_if_can_set() const {
553 ir_assert(!is_frozen_) << "Can't set: setup is already frozen.";
554 }
555
556 const dim_info_t &dim(const std::string &name) const {
557 ir_assert(dims_.count(name) != 0) << "Dimension not found: " << name;
558 return dims_.at(name);
559 }
560
561 dim_info_t &b_dim() { return bmnk_dims_[0]; }
562 dim_info_t &m_dim() { return bmnk_dims_[1]; }
563 dim_info_t &n_dim() { return bmnk_dims_[2]; }
564 dim_info_t &k_dim() { return bmnk_dims_[3]; }
565 const dim_info_t &b_dim() const { return bmnk_dims_[0]; }
566 const dim_info_t &m_dim() const { return bmnk_dims_[1]; }
567 const dim_info_t &n_dim() const { return bmnk_dims_[2]; }
568 const dim_info_t &k_dim() const { return bmnk_dims_[3]; }
569
570 dim_info_t &bmnk_dim(char bmnk) {
571 auto &ret = const_cast<const block_helper_t *>(this)->bmnk_dim(bmnk);
572 return const_cast<dim_info_t &>(ret);
573 }
574
575 const dim_info_t &bmnk_dim(char bmnk) const {
576 switch (bmnk) {
577 case 'B': return b_dim();
578 case 'M': return m_dim();
579 case 'N': return n_dim();
580 case 'K': return k_dim();
581 default: ir_error_not_expected();
582 }
583 return b_dim();
584 }
585
586 int prb_blocked_ndims(char bmnk) const {
587 int ret = 0;
588 for (auto &kv : dims_) {
589 auto &d = kv.second;
590 if (d.bmnk() != bmnk) continue;
591 if (!d.is_blocked()) continue;
592 if (d.size() == 1) continue;
593 ret++;
594 }
595 return ret;
596 }
597
598 const dim_info_t &prb_blocked_dim(char bmnk) const {
599 ir_assert(prb_blocked_ndims(bmnk) == 1);
600 for (auto &kv : dims_) {
601 auto &d = kv.second;
602 if (d.bmnk() != bmnk) continue;
603 if (!d.is_blocked()) continue;
604 if (d.size() == 1) continue;
605 return d;
606 }
607 return dim("");
608 }
609
610 int prb_max_dim(char bmnk, tile_level_t level) const {
611 int ret = 1;
612 for (auto &kv : dims_) {
613 auto &d = kv.second;
614 if (d.bmnk() != bmnk) continue;
615 if (!d.is_blocked()) continue;
616 ret *= min(d.size(), d.max_dim(level));
617 }
618 return ret;
619 }
620
621 void set_bmnk_dims(std::initializer_list<std::string> dims, char bmnk) {
622 check_if_can_set();
623 for (auto &d : dims)
624 dims_.at(d).set_bmnk(bmnk);
625 }
626
627 bool is_x8x8s32() const {
628 if (!utils::one_of(a_type_, data_type::s8, data_type::u8)) return false;
629 if (!utils::one_of(b_type_, data_type::s8, data_type::u8)) return false;
630 if (c_type_ != data_type::s32) return false;
631 return true;
632 }
633 bool is_tf32() const {
634 return a_type_ == data_type::tf32 && b_type_ == data_type::tf32
635 && c_type_ == data_type::f32;
636 }
637
638 bool vectorize_by_b() const { return vectorize_by_bmnk('B'); }
639 bool vectorize_by_m() const { return vectorize_by_bmnk('M'); }
640 bool vectorize_by_n() const { return vectorize_by_bmnk('N'); }
641 bool vectorize_by_k() const { return vectorize_by_bmnk('K'); }
642
643 bool vectorize_by_bmnk(char bmnk) const { return vector_bmnk_ == bmnk; }
644
645 int b_size() const { return b_dim().size(); }
646 int m_size() const { return m_dim().size(); }
647 int n_size() const { return n_dim().size(); }
648 int k_size() const { return k_dim().size(); }
649
650 void init_bmnk_dims() {
651 for (char bmnk : {'B', 'M', 'N', 'K'}) {
652 auto &d = bmnk_dim(bmnk);
653 d.set_name(std::string(1, bmnk));
654 d.set_size(1);
655 d.set_bmnk(bmnk);
656 }
657
658 for (auto &kv : dims_) {
659 auto &d = kv.second;
660 auto &bmnk_d = bmnk_dim(d.bmnk());
661 if (d.pref_tg_block()) bmnk_d.set_pref_tg_block();
662 bmnk_d.set_size(bmnk_d.size() * d.size());
663 bmnk_d.set_base_iter_block(
664 bmnk_d.base_iter_block() * d.base_iter_block());
665 if (d.is_blocked() && d.size() != 1) bmnk_d.incr_inner_dims();
666 }
667 }
668
669 void init_bmnk_blocks();
670 void init_k_blocking();
671 bool enable_k_tg_slicing() const;
672 void init_prb_blocks();
673 int compute_mad_k_block() const;
674
675 static int compute_block(int dim, int target_blk, int base_iter_block,
676 double target_eff = 0.75) {
677 int nblks = ir_utils::safe_divide(target_blk, base_iter_block);
678 while (nblks != 1) {
679 int dim_padded = utils::rnd_up(dim, nblks * base_iter_block);
680 double eff = (double)dim / dim_padded;
681 if (eff >= target_eff) break;
682 nblks--;
683 }
684 return nblks * base_iter_block;
685 }
686
687 // Whether compute() was already called.
688 bool is_frozen_ = false;
689
690 // BMNK kind of dimension to vectorize.
691 char vector_bmnk_ = ' ';
692
693 // General information about HW and computation.
694 hw_config_t hw_cfg_;
695 fma_kind_t fma_kind_ = fma_kind_t::unknown;
696 int vec_size_ = -1;
697 int simd_size_ = -1;
698 int max_tg_size_ = 0;
699 bool max_tg_overridden_ = false;
700 data_type_t a_type_ = data_type::undef;
701 data_type_t b_type_ = data_type::undef;
702 data_type_t c_type_ = data_type::undef;
703
704 bool use_a_2d_send_ = false;
705 bool use_b_2d_send_ = false;
706
707 // Whether K computation can be split across threads in thread group.
708 bool allow_k_tg_slicing_ = false;
709
710 // Whether K computation can be split across thread groups in the grid.
711 bool allow_k_grid_slicing_ = false;
712
713 // Problem dimensions.
714 std::unordered_map<std::string, dim_info_t> dims_;
715
716 // BMNK dimensions.
717 static const int bmnk_length = 4;
718 dim_info_t bmnk_dims_[bmnk_length];
719
720 bool reduce_m_block_hint_;
721 bool reduce_m_block_hint_set_ = false;
722};
723
724} // namespace jit
725} // namespace gpu
726} // namespace impl
727} // namespace dnnl
728
729#endif
730