1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CONV_BLOCK_HELPER_HPP |
18 | #define GPU_JIT_CONV_BLOCK_HELPER_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <cmath> |
22 | #include <functional> |
23 | #include <iostream> |
24 | #include <sstream> |
25 | #include <vector> |
26 | #include <initializer_list> |
27 | #include <unordered_map> |
28 | |
29 | #include "common/c_types_map.hpp" |
30 | #include "common/math_utils.hpp" |
31 | #include "common/utils.hpp" |
32 | #include "gpu/compute/device_info.hpp" |
33 | #include "gpu/jit/ir/core.hpp" |
34 | #include "gpu/jit/ir/fma.hpp" |
35 | #include "gpu/jit/ir/hw_config.hpp" |
36 | #include "gpu/jit/ngen/ngen.hpp" |
37 | #include "gpu/jit/utils/utils.hpp" |
38 | |
39 | namespace dnnl { |
40 | namespace impl { |
41 | namespace gpu { |
42 | namespace jit { |
43 | |
// Tile level - names the tiers of the blocking hierarchy. The values of
// iter/tg/loop double as array indices for per-level storage, so every
// enumerator is given its value explicitly.
enum class tile_level_t {
    unknown = 0,
    iter = 1, // Number of elements per iteration.
    tg = 2, // Number of threads per thread group.
    loop = 3, // Number of iterations per loop.
    _last = 4,
};
52 | |
53 | // Min/max integer indices of tile levels. |
54 | const int min_tile_level_idx = (int)tile_level_t::iter; |
55 | const int max_tile_level_idx = (int)tile_level_t::loop; |
56 | |
57 | // Describes the dimension value, contains either an integer or a special |
58 | // "unlimited" value which behaves like infinity when mixing in operations with |
59 | // integer values. |
60 | class dim_value_t { |
61 | public: |
62 | dim_value_t() = default; |
63 | dim_value_t(int value) : value_(value) {} |
64 | dim_value_t &operator=(int value) { |
65 | is_unlimited_ = false; |
66 | value_ = value; |
67 | return *this; |
68 | } |
69 | |
70 | bool is_unlimited() const { return is_unlimited_; } |
71 | |
72 | bool operator==(dim_value_t other) const { |
73 | if (is_unlimited() && other.is_unlimited()) return true; |
74 | return (is_unlimited_ == other.is_unlimited_) |
75 | && (value_ == other.value_); |
76 | } |
77 | bool operator==(int value) const { |
78 | return !is_unlimited() && value_ == value; |
79 | } |
80 | bool operator!=(dim_value_t other) const { return !(*this == other); } |
81 | bool operator!=(int value) const { return !(*this == value); } |
82 | |
83 | operator int() const { |
84 | if (is_unlimited_) { |
85 | ir_error_not_expected() << "Can't convert unlimited value to int." ; |
86 | return -1; |
87 | } |
88 | return value_; |
89 | } |
90 | |
91 | std::string str() const { |
92 | std::ostringstream oss; |
93 | if (is_unlimited()) { |
94 | oss << "(unlimited)" ; |
95 | } else { |
96 | oss << value_; |
97 | } |
98 | return oss.str(); |
99 | } |
100 | |
101 | IR_DEFINE_DUMP() |
102 | |
103 | static dim_value_t unlimited() { |
104 | dim_value_t ret; |
105 | ret.is_unlimited_ = true; |
106 | return ret; |
107 | } |
108 | |
109 | private: |
110 | bool is_unlimited_ = false; |
111 | int value_ = 0; |
112 | }; |
113 | |
114 | inline dim_value_t min(dim_value_t a, dim_value_t b) { |
115 | if (a.is_unlimited()) return b; |
116 | if (b.is_unlimited()) return a; |
117 | return std::min((int)a, (int)b); |
118 | } |
119 | |
120 | // Stores information about dimension blocking in context of BMNK/GEMM |
121 | // notation. |
122 | class dim_info_t { |
123 | public: |
124 | dim_info_t() { |
125 | for (int i = min_tile_level_idx; i <= max_tile_level_idx; i++) { |
126 | tile_dims_[i] = 1; |
127 | max_tile_dims_[i] = dim_value_t::unlimited(); |
128 | } |
129 | } |
130 | |
131 | dim_info_t(const std::string &name, int size) : dim_info_t() { |
132 | name_ = name; |
133 | size_ = size; |
134 | } |
135 | |
136 | const std::string &name() const { return name_; } |
137 | void set_name(const std::string &name) { name_ = name; } |
138 | |
139 | int size() const { return size_; } |
140 | void set_size(int value) { size_ = value; } |
141 | |
142 | int padded_size() const { |
143 | return utils::rnd_up(size(), math::lcm(tg_blk(), pad_blk_)); |
144 | } |
145 | |
146 | char bmnk() const { return bmnk_; } |
147 | void set_bmnk(char value) { bmnk_ = value; } |
148 | |
149 | int base_iter_block() const { return base_iter_blk_; } |
150 | void set_base_iter_block(int value) { base_iter_blk_ = value; } |
151 | |
152 | int pad_block() const { return pad_blk_; } |
153 | void set_pad_block(int value) { pad_blk_ = value; } |
154 | |
155 | int order_key() const { return order_key_; } |
156 | void set_order_key(int value) { order_key_ = value; } |
157 | |
158 | bool is_blocked() const { return is_blocked_; } |
159 | void set_blocked(bool value = true) { is_blocked_ = value; } |
160 | |
161 | bool allow_fuse() const { return allow_fuse_; } |
162 | bool allow_split() const { return allow_split_; } |
163 | void set_allow_fuse(bool value = true) { allow_fuse_ = value; } |
164 | void set_allow_split(bool value = true) { allow_split_ = value; } |
165 | |
166 | int grid_dim() const { |
167 | return ir_utils::safe_divide(padded_size(), tg_blk()); |
168 | } |
169 | dim_value_t tg_dim() const { return dim(tile_level_t::tg); } |
170 | dim_value_t loop_dim() const { return dim(tile_level_t::loop); } |
171 | dim_value_t iter_dim() const { return dim(tile_level_t::iter); } |
172 | dim_value_t dim(tile_level_t level) const { |
173 | int idx = (int)level; |
174 | ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx); |
175 | return tile_dims_[idx]; |
176 | } |
177 | dim_value_t max_dim(tile_level_t level) const { |
178 | int idx = (int)level; |
179 | ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx); |
180 | return max_tile_dims_[idx]; |
181 | } |
182 | bool pref_tg_block() { return pref_tg_block_; } |
183 | void set_pref_tg_block(bool value = true) { pref_tg_block_ = value; } |
184 | void set_tg_dim(dim_value_t value) { set_dim(tile_level_t::tg, value); } |
185 | void set_loop_dim(dim_value_t value) { set_dim(tile_level_t::loop, value); } |
186 | void set_iter_dim(dim_value_t value) { set_dim(tile_level_t::iter, value); } |
187 | void set_dim(tile_level_t level, dim_value_t value) { |
188 | int idx = (int)level; |
189 | ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx); |
190 | tile_dims_[idx] = value; |
191 | } |
192 | |
193 | void set_max_dim(tile_level_t level, dim_value_t value) { |
194 | int idx = (int)level; |
195 | ir_assert(idx >= min_tile_level_idx && idx <= max_tile_level_idx); |
196 | max_tile_dims_[idx] = value; |
197 | } |
198 | |
199 | int inner_dims() const { return inner_dims_; } |
200 | void incr_inner_dims() { inner_dims_++; } |
201 | |
202 | int iter_blk() const { return iter_dim(); } |
203 | int loop_blk() const { return loop_dim() * iter_blk(); } |
204 | int tg_blk() const { return tg_dim() * loop_blk(); } |
205 | |
206 | bool has_any_blocking() const { |
207 | if (iter_dim() != 1) return true; |
208 | if (loop_dim() != 1) return true; |
209 | if (tg_dim() != 1) return true; |
210 | return false; |
211 | } |
212 | |
213 | std::string str() const { |
214 | using namespace ir_utils; |
215 | std::ostringstream oss; |
216 | oss << "Dimension " << name_ << std::endl; |
217 | oss << " Size: " << size_ << std::endl; |
218 | oss << " Base iter block: " << base_iter_blk_ << std::endl; |
219 | oss << " BMNK: " << bmnk_ << std::endl; |
220 | oss << " Blocked: " << to_string(is_blocked_) << std::endl; |
221 | oss << " Allow fuse: " << to_string(allow_fuse_) << std::endl; |
222 | oss << " Allow split: " << to_string(allow_split_) << std::endl; |
223 | oss << " Order key: " << order_key_ << std::endl; |
224 | |
225 | const char *tags[] = {"Iteration" , "Thread group" , "Loop" , nullptr}; |
226 | int level_id = min_tile_level_idx; |
227 | for (const char **tag = tags; *tag; tag++) { |
228 | tile_level_t level = (tile_level_t)level_id++; |
229 | dim_value_t max_dim_val = max_dim(level); |
230 | oss << " " << pad_str(*tag + std::string(" dim:" ), -19); |
231 | oss << dim(level).str(); |
232 | if (!max_dim_val.is_unlimited()) { |
233 | oss << " (max: " << max_dim_val.str() << ")" ; |
234 | } |
235 | oss << std::endl; |
236 | } |
237 | return oss.str(); |
238 | } |
239 | |
240 | std::string brief_str() const { |
241 | std::ostringstream oss; |
242 | oss << "Dimension " << name_ << pad_str(":" , -18 + (int)name_.length()); |
243 | oss << "(grid:" << pad_int(grid_dim(), 5) << ") x " ; |
244 | oss << "(loop:" << pad_int(loop_dim(), 5) << ") x " ; |
245 | oss << "(tg:" << pad_int(tg_dim(), 5) << ") x " ; |
246 | oss << "(iter:" << pad_int(iter_dim(), 5) << ")" ; |
247 | return oss.str(); |
248 | } |
249 | |
250 | IR_DEFINE_DUMP() |
251 | |
252 | private: |
253 | static std::string pad_str(std::string s, int pad) { |
254 | auto pos = (pad >= 0 ? 0 : s.length()); |
255 | s.insert(pos, std::abs(pad) - s.length(), ' '); |
256 | return s; |
257 | } |
258 | |
259 | static std::string pad_int(int i, int pad) { |
260 | return pad_str(std::to_string(i), pad); |
261 | } |
262 | |
263 | // Dimension name. |
264 | std::string name_; |
265 | |
266 | // Dimension size. |
267 | int size_ = 0; |
268 | |
269 | // Minimal block size for iteration blocking. Iteration-level block must be |
270 | // divisible by this value. |
271 | int base_iter_blk_ = 1; |
272 | |
273 | // Block size to ensure correct zero padding. Blocked memory layouts must |
274 | // be fully covered to ensure they are zero-padded. |
275 | int pad_blk_ = 1; |
276 | |
277 | // Dimension kind in terms of BMNK notation. |
278 | char bmnk_ = ' '; |
279 | |
280 | // Number of prb dims for a BMNK dim |
281 | int inner_dims_ = 0; |
282 | |
283 | // Whether the dimension can be blocked. Dimensions without blocking are |
284 | // implicitly tiled on the grid level (handle one element per thread |
285 | // group). |
286 | bool is_blocked_ = false; |
287 | |
288 | // Whether the dimension can be fused with other "fused" dimensions on the |
289 | // same tile level. |
290 | bool allow_fuse_ = false; |
291 | |
292 | // Whether the dimension can be split between multiple tile levels. |
293 | bool allow_split_ = false; |
294 | |
295 | bool pref_tg_block_ = false; |
296 | |
297 | // Dimensions with smaller order keys are tiled first. |
298 | int order_key_ = -1; |
299 | |
300 | // Dimensions of tiles. |
301 | dim_value_t tile_dims_[max_tile_level_idx + 1]; |
302 | |
303 | // Max allowed dimensions of tiles. |
304 | dim_value_t max_tile_dims_[max_tile_level_idx + 1]; |
305 | }; |
306 | |
307 | // Block helper provides functionality to compute tiling/blocking for a |
308 | // GEMM-like problem (when problem dimensions can be classified in terms of |
309 | // BMNK dimensions). A typical flow consists of three steps: |
310 | // - Setting problem configuration: |
311 | // - Problem dimensions (sizes, padding requirements, base block |
312 | // requirements, BMNK behavior) |
313 | // - HW details, FMA kind, data types, etc |
314 | // - Setting restrictions/hints, for example: |
315 | // - Maximal block sizes (e.g. to limit GRF usage) |
316 | // - Fuse/split settings - to get the desired blocking decomposition |
//   - K-slicing settings (grid or thread group slicing)
318 | // - Computing block sizes for each tile: per iteration, per loop, per thread |
319 | // group |
320 | class block_helper_t { |
321 | public: |
322 | bool is_frozen() const { return is_frozen_; } |
323 | |
324 | const std::unordered_map<std::string, dim_info_t> &dims() const { |
325 | return dims_; |
326 | } |
327 | |
328 | dim_info_t &dim(const std::string &name) { |
329 | ir_assert(dims_.count(name) != 0) << "Dimension not found: " << name; |
330 | return dims_.at(name); |
331 | } |
332 | |
333 | void set_hw_config(const hw_config_t &hw_cfg) { |
334 | check_if_can_set(); |
335 | hw_cfg_ = hw_cfg; |
336 | } |
337 | |
338 | void set_fma_kind(fma_kind_t fma_kind) { |
339 | check_if_can_set(); |
340 | fma_kind_ = fma_kind; |
341 | } |
342 | |
343 | void set_simd_size(int simd_size) { |
344 | check_if_can_set(); |
345 | simd_size_ = simd_size; |
346 | } |
347 | |
348 | void set_vec_size(int vec_size) { |
349 | check_if_can_set(); |
350 | vec_size_ = vec_size; |
351 | } |
352 | |
353 | void set_max_tg_size(int max_tg_size) { |
354 | check_if_can_set(); |
355 | max_tg_size_ = max_tg_size; |
356 | } |
357 | |
358 | void set_max_tg_overridden(bool max_tg_overridden) { |
359 | check_if_can_set(); |
360 | max_tg_overridden_ = max_tg_overridden; |
361 | } |
362 | |
363 | void set_abc_types( |
364 | data_type_t a_type, data_type_t b_type, data_type_t c_type) { |
365 | check_if_can_set(); |
366 | a_type_ = a_type; |
367 | b_type_ = b_type; |
368 | c_type_ = c_type; |
369 | } |
370 | |
371 | void set_use_2d_send(bool use_a_2d_send, bool use_b_2d_send) { |
372 | use_a_2d_send_ = use_a_2d_send; |
373 | use_b_2d_send_ = use_b_2d_send; |
374 | } |
375 | |
376 | void set_max_m_tg_dim(int value) { |
377 | m_dim().set_max_dim(tile_level_t::tg, value); |
378 | } |
379 | |
380 | void set_max_n_tg_dim(int value) { |
381 | n_dim().set_max_dim(tile_level_t::tg, value); |
382 | } |
383 | |
384 | void set_max_k_tg_dim(int value) { |
385 | k_dim().set_max_dim(tile_level_t::tg, value); |
386 | } |
387 | |
388 | void set_dims(std::initializer_list<std::string> names, |
389 | std::initializer_list<int> sizes) { |
390 | check_if_can_set(); |
391 | ir_assert(names.size() == sizes.size()); |
392 | for (size_t i = 0; i < names.size(); i++) { |
393 | set_dim(*(names.begin() + i), *(sizes.begin() + i)); |
394 | } |
395 | } |
396 | |
397 | void set_dim(const std::string &name, int size) { |
398 | check_if_can_set(); |
399 | ir_assert(dims_.count(name) == 0) |
400 | << "Dimension already exists: " << name; |
401 | dims_.emplace(name, dim_info_t(name, size)); |
402 | } |
403 | |
404 | void set_b_dims(std::initializer_list<std::string> names) { |
405 | set_bmnk_dims(names, 'B'); |
406 | } |
407 | void set_m_dims(std::initializer_list<std::string> names) { |
408 | set_bmnk_dims(names, 'M'); |
409 | } |
410 | void set_n_dims(std::initializer_list<std::string> names) { |
411 | set_bmnk_dims(names, 'N'); |
412 | } |
413 | void set_k_dims(std::initializer_list<std::string> names) { |
414 | set_bmnk_dims(names, 'K'); |
415 | } |
416 | |
417 | void set_block_dims(std::initializer_list<std::string> names) { |
418 | check_if_can_set(); |
419 | for (auto &name : names) { |
420 | dim(name).set_blocked(); |
421 | } |
422 | } |
423 | |
424 | void set_loop_dim(const std::string &name, int value) { |
425 | dim(name).set_loop_dim(value); |
426 | } |
427 | |
428 | void set_tg_dim(const std::string &name, int value) { |
429 | dim(name).set_tg_dim(value); |
430 | } |
431 | |
432 | void set_max_iter_dim(const std::string &name, int value) { |
433 | dim(name).set_max_dim(tile_level_t::iter, value); |
434 | } |
435 | |
436 | void set_max_loop_dim(const std::string &name, int value) { |
437 | dim(name).set_max_dim(tile_level_t::loop, value); |
438 | } |
439 | |
440 | void set_max_tg_dim(const std::string &name, int value) { |
441 | dim(name).set_max_dim(tile_level_t::tg, value); |
442 | } |
443 | |
444 | void set_pref_tg_block(const std::string &name, bool value = true) { |
445 | dim(name).set_pref_tg_block(value); |
446 | } |
447 | |
448 | bool any_pref_tg_block() { |
449 | for (auto &kv : dims_) { |
450 | auto &d = kv.second; |
451 | if (d.pref_tg_block()) return true; |
452 | } |
453 | return false; |
454 | } |
455 | |
456 | void set_reduce_m_block_hint(bool value = true) { |
457 | reduce_m_block_hint_ = value; |
458 | reduce_m_block_hint_set_ = true; |
459 | } |
460 | |
461 | void allow_fuse(std::initializer_list<std::string> names) { |
462 | check_if_can_set(); |
463 | for (auto &name : names) { |
464 | dim(name).set_allow_fuse(); |
465 | } |
466 | } |
467 | |
468 | void allow_split(std::initializer_list<std::string> names) { |
469 | check_if_can_set(); |
470 | for (auto &name : names) { |
471 | dim(name).set_allow_split(); |
472 | } |
473 | } |
474 | |
475 | void allow_k_tg_slicing() { allow_k_tg_slicing_ = true; } |
476 | |
477 | void allow_k_grid_slicing() { allow_k_grid_slicing_ = true; } |
478 | |
479 | void set_vector_dim(const std::string &name) { |
480 | check_if_can_set(); |
481 | auto &d = dim(name); |
482 | d.set_base_iter_block(math::lcm(vec_size_, d.base_iter_block())); |
483 | vector_bmnk_ = d.bmnk(); |
484 | } |
485 | |
486 | void set_base_iter_block(const std::string &name, int block) { |
487 | check_if_can_set(); |
488 | dim(name).set_base_iter_block(block); |
489 | } |
490 | |
491 | void set_base_iter_block(const std::string &name, int block0, int block1) { |
492 | set_base_iter_block(name, math::lcm(block0, block1)); |
493 | } |
494 | |
495 | void set_pad_block(const std::string &name, int block) { |
496 | dim(name).set_pad_block(block); |
497 | } |
498 | |
499 | void reorder(std::initializer_list<std::string> names) { |
500 | check_if_can_set(); |
501 | int key = 0; |
502 | for (auto &name : names) |
503 | dim(name).set_order_key(key++); |
504 | } |
505 | |
506 | void compute(); |
507 | |
508 | bool has_dim(const std::string &name) const { |
509 | return dims_.count(name) != 0; |
510 | } |
511 | |
512 | int iter_dim(const std::string &name) const { return dim(name).iter_dim(); } |
513 | int loop_dim(const std::string &name) const { return dim(name).loop_dim(); } |
514 | int tg_dim(const std::string &name) const { return dim(name).tg_dim(); } |
515 | int grid_dim(const std::string &name) const { return dim(name).grid_dim(); } |
516 | |
517 | int iter_blk(const std::string &name) const { return dim(name).iter_blk(); } |
518 | int loop_blk(const std::string &name) const { return dim(name).loop_blk(); } |
519 | int tg_blk(const std::string &name) const { return dim(name).tg_blk(); } |
520 | |
521 | dim_value_t max_iter_dim(const std::string &name) const { |
522 | return dim(name).max_dim(tile_level_t::iter); |
523 | } |
524 | |
525 | int padded_size(const std::string &name) const { |
526 | return dim(name).padded_size(); |
527 | } |
528 | |
529 | std::string str() const { |
530 | std::ostringstream oss; |
531 | for (auto &kv : dims_) { |
532 | auto &d = kv.second; |
533 | if (!d.has_any_blocking()) continue; |
534 | oss << d.str(); |
535 | } |
536 | return oss.str(); |
537 | } |
538 | |
539 | std::string brief_str() { |
540 | std::ostringstream oss; |
541 | for (auto &kv : dims_) { |
542 | auto &d = kv.second; |
543 | if (!d.has_any_blocking()) continue; |
544 | oss << " " << d.brief_str() << std::endl; |
545 | } |
546 | return oss.str(); |
547 | } |
548 | |
549 | IR_DEFINE_DUMP() |
550 | |
551 | private: |
552 | void check_if_can_set() const { |
553 | ir_assert(!is_frozen_) << "Can't set: setup is already frozen." ; |
554 | } |
555 | |
556 | const dim_info_t &dim(const std::string &name) const { |
557 | ir_assert(dims_.count(name) != 0) << "Dimension not found: " << name; |
558 | return dims_.at(name); |
559 | } |
560 | |
561 | dim_info_t &b_dim() { return bmnk_dims_[0]; } |
562 | dim_info_t &m_dim() { return bmnk_dims_[1]; } |
563 | dim_info_t &n_dim() { return bmnk_dims_[2]; } |
564 | dim_info_t &k_dim() { return bmnk_dims_[3]; } |
565 | const dim_info_t &b_dim() const { return bmnk_dims_[0]; } |
566 | const dim_info_t &m_dim() const { return bmnk_dims_[1]; } |
567 | const dim_info_t &n_dim() const { return bmnk_dims_[2]; } |
568 | const dim_info_t &k_dim() const { return bmnk_dims_[3]; } |
569 | |
570 | dim_info_t &bmnk_dim(char bmnk) { |
571 | auto &ret = const_cast<const block_helper_t *>(this)->bmnk_dim(bmnk); |
572 | return const_cast<dim_info_t &>(ret); |
573 | } |
574 | |
575 | const dim_info_t &bmnk_dim(char bmnk) const { |
576 | switch (bmnk) { |
577 | case 'B': return b_dim(); |
578 | case 'M': return m_dim(); |
579 | case 'N': return n_dim(); |
580 | case 'K': return k_dim(); |
581 | default: ir_error_not_expected(); |
582 | } |
583 | return b_dim(); |
584 | } |
585 | |
586 | int prb_blocked_ndims(char bmnk) const { |
587 | int ret = 0; |
588 | for (auto &kv : dims_) { |
589 | auto &d = kv.second; |
590 | if (d.bmnk() != bmnk) continue; |
591 | if (!d.is_blocked()) continue; |
592 | if (d.size() == 1) continue; |
593 | ret++; |
594 | } |
595 | return ret; |
596 | } |
597 | |
598 | const dim_info_t &prb_blocked_dim(char bmnk) const { |
599 | ir_assert(prb_blocked_ndims(bmnk) == 1); |
600 | for (auto &kv : dims_) { |
601 | auto &d = kv.second; |
602 | if (d.bmnk() != bmnk) continue; |
603 | if (!d.is_blocked()) continue; |
604 | if (d.size() == 1) continue; |
605 | return d; |
606 | } |
607 | return dim("" ); |
608 | } |
609 | |
610 | int prb_max_dim(char bmnk, tile_level_t level) const { |
611 | int ret = 1; |
612 | for (auto &kv : dims_) { |
613 | auto &d = kv.second; |
614 | if (d.bmnk() != bmnk) continue; |
615 | if (!d.is_blocked()) continue; |
616 | ret *= min(d.size(), d.max_dim(level)); |
617 | } |
618 | return ret; |
619 | } |
620 | |
621 | void set_bmnk_dims(std::initializer_list<std::string> dims, char bmnk) { |
622 | check_if_can_set(); |
623 | for (auto &d : dims) |
624 | dims_.at(d).set_bmnk(bmnk); |
625 | } |
626 | |
627 | bool is_x8x8s32() const { |
628 | if (!utils::one_of(a_type_, data_type::s8, data_type::u8)) return false; |
629 | if (!utils::one_of(b_type_, data_type::s8, data_type::u8)) return false; |
630 | if (c_type_ != data_type::s32) return false; |
631 | return true; |
632 | } |
633 | bool is_tf32() const { |
634 | return a_type_ == data_type::tf32 && b_type_ == data_type::tf32 |
635 | && c_type_ == data_type::f32; |
636 | } |
637 | |
638 | bool vectorize_by_b() const { return vectorize_by_bmnk('B'); } |
639 | bool vectorize_by_m() const { return vectorize_by_bmnk('M'); } |
640 | bool vectorize_by_n() const { return vectorize_by_bmnk('N'); } |
641 | bool vectorize_by_k() const { return vectorize_by_bmnk('K'); } |
642 | |
643 | bool vectorize_by_bmnk(char bmnk) const { return vector_bmnk_ == bmnk; } |
644 | |
645 | int b_size() const { return b_dim().size(); } |
646 | int m_size() const { return m_dim().size(); } |
647 | int n_size() const { return n_dim().size(); } |
648 | int k_size() const { return k_dim().size(); } |
649 | |
650 | void init_bmnk_dims() { |
651 | for (char bmnk : {'B', 'M', 'N', 'K'}) { |
652 | auto &d = bmnk_dim(bmnk); |
653 | d.set_name(std::string(1, bmnk)); |
654 | d.set_size(1); |
655 | d.set_bmnk(bmnk); |
656 | } |
657 | |
658 | for (auto &kv : dims_) { |
659 | auto &d = kv.second; |
660 | auto &bmnk_d = bmnk_dim(d.bmnk()); |
661 | if (d.pref_tg_block()) bmnk_d.set_pref_tg_block(); |
662 | bmnk_d.set_size(bmnk_d.size() * d.size()); |
663 | bmnk_d.set_base_iter_block( |
664 | bmnk_d.base_iter_block() * d.base_iter_block()); |
665 | if (d.is_blocked() && d.size() != 1) bmnk_d.incr_inner_dims(); |
666 | } |
667 | } |
668 | |
669 | void init_bmnk_blocks(); |
670 | void init_k_blocking(); |
671 | bool enable_k_tg_slicing() const; |
672 | void init_prb_blocks(); |
673 | int compute_mad_k_block() const; |
674 | |
675 | static int compute_block(int dim, int target_blk, int base_iter_block, |
676 | double target_eff = 0.75) { |
677 | int nblks = ir_utils::safe_divide(target_blk, base_iter_block); |
678 | while (nblks != 1) { |
679 | int dim_padded = utils::rnd_up(dim, nblks * base_iter_block); |
680 | double eff = (double)dim / dim_padded; |
681 | if (eff >= target_eff) break; |
682 | nblks--; |
683 | } |
684 | return nblks * base_iter_block; |
685 | } |
686 | |
687 | // Whether compute() was already called. |
688 | bool is_frozen_ = false; |
689 | |
690 | // BMNK kind of dimension to vectorize. |
691 | char vector_bmnk_ = ' '; |
692 | |
693 | // General information about HW and computation. |
694 | hw_config_t hw_cfg_; |
695 | fma_kind_t fma_kind_ = fma_kind_t::unknown; |
696 | int vec_size_ = -1; |
697 | int simd_size_ = -1; |
698 | int max_tg_size_ = 0; |
699 | bool max_tg_overridden_ = false; |
700 | data_type_t a_type_ = data_type::undef; |
701 | data_type_t b_type_ = data_type::undef; |
702 | data_type_t c_type_ = data_type::undef; |
703 | |
704 | bool use_a_2d_send_ = false; |
705 | bool use_b_2d_send_ = false; |
706 | |
707 | // Whether K computation can be split across threads in thread group. |
708 | bool allow_k_tg_slicing_ = false; |
709 | |
710 | // Whether K computation can be split across thread groups in the grid. |
711 | bool allow_k_grid_slicing_ = false; |
712 | |
713 | // Problem dimensions. |
714 | std::unordered_map<std::string, dim_info_t> dims_; |
715 | |
716 | // BMNK dimensions. |
717 | static const int bmnk_length = 4; |
718 | dim_info_t bmnk_dims_[bmnk_length]; |
719 | |
720 | bool reduce_m_block_hint_; |
721 | bool reduce_m_block_hint_set_ = false; |
722 | }; |
723 | |
724 | } // namespace jit |
725 | } // namespace gpu |
726 | } // namespace impl |
727 | } // namespace dnnl |
728 | |
729 | #endif |
730 | |