1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/gemm_schedule.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace jit { |
23 | |
24 | layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind, |
25 | const std::vector<bmnk_kind_t> &bmnk_kinds, const view_t &view) const { |
26 | auto layout = view.create_pseudo_vlayout(); |
27 | return map_to_bmnk(abc_kind, bmnk_kinds, layout); |
28 | } |
29 | |
30 | layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind, |
31 | const std::vector<bmnk_kind_t> &bmnk_kinds, |
32 | const layout_t &layout) const { |
33 | std::vector<block_t> blocks; |
34 | for (auto &b : layout.blocks()) { |
35 | auto b_bmnk_kind = bmnk_kind(abc_kind, b.dim_idx); |
36 | bool found = false; |
37 | for (int i = 0; i < int(bmnk_kinds.size()); i++) { |
38 | if (bmnk_kinds[i] == b_bmnk_kind) { |
39 | blocks.emplace_back(i, b.block, b.stride); |
40 | found = true; |
41 | break; |
42 | } |
43 | } |
44 | if (!found) ir_error_not_expected() << "MNK dimension not found." ; |
45 | } |
46 | return layout_t(layout.type(), int(bmnk_kinds.size()), 0, blocks); |
47 | } |
48 | |
49 | void bmnk_block_mapper_t::push_block(abc_kind_t abc_kind, const block_t &b) { |
50 | auto bmnk_kind = bmnk_mapper_.bmnk_kind(abc_kind, b.dim_idx); |
51 | switch (bmnk_kind) { |
52 | case bmnk_kind_t::m: m_blocks_.emplace_back(abc_kind, b); break; |
53 | case bmnk_kind_t::n: n_blocks_.emplace_back(abc_kind, b); break; |
54 | case bmnk_kind_t::k: k_blocks_.emplace_back(abc_kind, b); break; |
55 | default: ir_error_not_expected() << "Unknown MNK kind." ; |
56 | } |
57 | } |
58 | |
59 | layout_t bmnk_block_mapper_t::map_from_bmnk(abc_kind_t abc_kind, |
60 | const std::vector<bmnk_kind_t> &bmnk_kinds, |
61 | const layout_t &bmnk_layout) const { |
62 | ir_assert(bmnk_layout.ndims() <= 3); |
63 | ir_assert(bmnk_layout.has_zero_offset()); |
64 | std::vector<block_t> blocks; |
65 | std::vector<std::vector<block_t>> tmp_blocks( |
66 | static_cast<int>(bmnk_kind_t::k) + 1); |
67 | tmp_blocks[static_cast<int>(bmnk_kind_t::m)] |
68 | = create_prb_blocks(abc_kind, m_blocks_); |
69 | tmp_blocks[static_cast<int>(bmnk_kind_t::n)] |
70 | = create_prb_blocks(abc_kind, n_blocks_); |
71 | tmp_blocks[static_cast<int>(bmnk_kind_t::k)] |
72 | = create_prb_blocks(abc_kind, k_blocks_); |
73 | for (auto &b : bmnk_layout.blocks()) { |
74 | auto &bmnk_blocks = tmp_blocks[static_cast<int>(bmnk_kinds[b.dim_idx])]; |
75 | bool ok = pop_block(bmnk_blocks, blocks, b); |
76 | ir_assert(ok) << "Can't map from bmnk layout to problem layout." ; |
77 | MAYBE_UNUSED(ok); |
78 | } |
79 | for (auto bmnk_kind : bmnk_kinds) { |
80 | auto &bmnk_blocks = tmp_blocks[static_cast<int>(bmnk_kind)]; |
81 | pop_size_1_blocks(bmnk_blocks); |
82 | ir_assert(bmnk_blocks.empty()); |
83 | } |
84 | |
85 | // Fix strides to make them dense. |
86 | dim_t dense_stride = 1; |
87 | for (auto &b : blocks) { |
88 | b.stride = stride_t(dense_stride); |
89 | dense_stride *= b.block; |
90 | } |
91 | |
92 | return layout_t( |
93 | bmnk_layout.type(), bmnk_mapper_.ndims(abc_kind), 0, blocks); |
94 | } |
95 | |
96 | bool bmnk_block_mapper_t::pop_block(std::vector<block_t> &bmnk_blocks, |
97 | std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const { |
98 | if (bmnk_block.block == 1) return true; |
99 | |
100 | pop_size_1_blocks(bmnk_blocks); |
101 | if (bmnk_blocks.empty()) return false; |
102 | |
103 | auto &next_block = bmnk_blocks.front(); |
104 | dim_t common_block = math::gcd(next_block.block, bmnk_block.block); |
105 | if (common_block == bmnk_block.block) { |
106 | prb_blocks.emplace_back( |
107 | next_block.dim_idx, common_block, next_block.stride); |
108 | next_block.block /= common_block; |
109 | next_block.stride *= common_block; |
110 | return true; |
111 | } else if (common_block == next_block.block) { |
112 | prb_blocks.emplace_back( |
113 | next_block.dim_idx, common_block, next_block.stride); |
114 | bmnk_blocks.erase(bmnk_blocks.begin()); |
115 | auto tmp_block = bmnk_block; |
116 | tmp_block.block /= common_block; |
117 | return pop_block(bmnk_blocks, prb_blocks, tmp_block); |
118 | } |
119 | return false; |
120 | } |
121 | |
122 | } // namespace jit |
123 | } // namespace gpu |
124 | } // namespace impl |
125 | } // namespace dnnl |
126 | |