1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/ir/gemm_schedule.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace jit {
23
24layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind,
25 const std::vector<bmnk_kind_t> &bmnk_kinds, const view_t &view) const {
26 auto layout = view.create_pseudo_vlayout();
27 return map_to_bmnk(abc_kind, bmnk_kinds, layout);
28}
29
30layout_t bmnk_mapper_t::map_to_bmnk(abc_kind_t abc_kind,
31 const std::vector<bmnk_kind_t> &bmnk_kinds,
32 const layout_t &layout) const {
33 std::vector<block_t> blocks;
34 for (auto &b : layout.blocks()) {
35 auto b_bmnk_kind = bmnk_kind(abc_kind, b.dim_idx);
36 bool found = false;
37 for (int i = 0; i < int(bmnk_kinds.size()); i++) {
38 if (bmnk_kinds[i] == b_bmnk_kind) {
39 blocks.emplace_back(i, b.block, b.stride);
40 found = true;
41 break;
42 }
43 }
44 if (!found) ir_error_not_expected() << "MNK dimension not found.";
45 }
46 return layout_t(layout.type(), int(bmnk_kinds.size()), 0, blocks);
47}
48
49void bmnk_block_mapper_t::push_block(abc_kind_t abc_kind, const block_t &b) {
50 auto bmnk_kind = bmnk_mapper_.bmnk_kind(abc_kind, b.dim_idx);
51 switch (bmnk_kind) {
52 case bmnk_kind_t::m: m_blocks_.emplace_back(abc_kind, b); break;
53 case bmnk_kind_t::n: n_blocks_.emplace_back(abc_kind, b); break;
54 case bmnk_kind_t::k: k_blocks_.emplace_back(abc_kind, b); break;
55 default: ir_error_not_expected() << "Unknown MNK kind.";
56 }
57}
58
59layout_t bmnk_block_mapper_t::map_from_bmnk(abc_kind_t abc_kind,
60 const std::vector<bmnk_kind_t> &bmnk_kinds,
61 const layout_t &bmnk_layout) const {
62 ir_assert(bmnk_layout.ndims() <= 3);
63 ir_assert(bmnk_layout.has_zero_offset());
64 std::vector<block_t> blocks;
65 std::vector<std::vector<block_t>> tmp_blocks(
66 static_cast<int>(bmnk_kind_t::k) + 1);
67 tmp_blocks[static_cast<int>(bmnk_kind_t::m)]
68 = create_prb_blocks(abc_kind, m_blocks_);
69 tmp_blocks[static_cast<int>(bmnk_kind_t::n)]
70 = create_prb_blocks(abc_kind, n_blocks_);
71 tmp_blocks[static_cast<int>(bmnk_kind_t::k)]
72 = create_prb_blocks(abc_kind, k_blocks_);
73 for (auto &b : bmnk_layout.blocks()) {
74 auto &bmnk_blocks = tmp_blocks[static_cast<int>(bmnk_kinds[b.dim_idx])];
75 bool ok = pop_block(bmnk_blocks, blocks, b);
76 ir_assert(ok) << "Can't map from bmnk layout to problem layout.";
77 MAYBE_UNUSED(ok);
78 }
79 for (auto bmnk_kind : bmnk_kinds) {
80 auto &bmnk_blocks = tmp_blocks[static_cast<int>(bmnk_kind)];
81 pop_size_1_blocks(bmnk_blocks);
82 ir_assert(bmnk_blocks.empty());
83 }
84
85 // Fix strides to make them dense.
86 dim_t dense_stride = 1;
87 for (auto &b : blocks) {
88 b.stride = stride_t(dense_stride);
89 dense_stride *= b.block;
90 }
91
92 return layout_t(
93 bmnk_layout.type(), bmnk_mapper_.ndims(abc_kind), 0, blocks);
94}
95
96bool bmnk_block_mapper_t::pop_block(std::vector<block_t> &bmnk_blocks,
97 std::vector<block_t> &prb_blocks, const block_t &bmnk_block) const {
98 if (bmnk_block.block == 1) return true;
99
100 pop_size_1_blocks(bmnk_blocks);
101 if (bmnk_blocks.empty()) return false;
102
103 auto &next_block = bmnk_blocks.front();
104 dim_t common_block = math::gcd(next_block.block, bmnk_block.block);
105 if (common_block == bmnk_block.block) {
106 prb_blocks.emplace_back(
107 next_block.dim_idx, common_block, next_block.stride);
108 next_block.block /= common_block;
109 next_block.stride *= common_block;
110 return true;
111 } else if (common_block == next_block.block) {
112 prb_blocks.emplace_back(
113 next_block.dim_idx, common_block, next_block.stride);
114 bmnk_blocks.erase(bmnk_blocks.begin());
115 auto tmp_block = bmnk_block;
116 tmp_block.block /= common_block;
117 return pop_block(bmnk_blocks, prb_blocks, tmp_block);
118 }
119 return false;
120}
121
122} // namespace jit
123} // namespace gpu
124} // namespace impl
125} // namespace dnnl
126