1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/ir/fma.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace jit {
23
24std::string fma_kind::to_string(fma_kind_t val) {
25 switch (val) {
26 case fma_kind_t::mad: return "mad";
27 case fma_kind_t::dp4a: return "dp4a";
28 case fma_kind_t::dpas: return "dpas";
29 case fma_kind_t::dpasw: return "dpasw";
30 case fma_kind_t::unknown: return "unknown";
31 default: assert(!"unknown fma kind"); return "unknown";
32 }
33}
34
35fma_kind_t fma_kind::from_string(std::string enum_string) {
36 for (int enum_int = static_cast<int>(fma_kind_t::mad);
37 enum_int <= static_cast<int>(fma_kind_t::unknown); enum_int++) {
38 fma_kind_t enum_val = static_cast<fma_kind_t>(enum_int);
39 if (fma_kind::to_string(enum_val).compare(enum_string) == 0)
40 return enum_val;
41 }
42 assert(!"unknown fma kind");
43 return fma_kind_t::unknown;
44}
45
46fma_kind_t fma_kind::get_supported_kind(
47 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) {
48 if (hw >= ngen::HW::XeHP && dpas_t::matches_types(hw, a, b, c)) {
49 if (hw >= ngen::HW::XeHPC)
50 return fma_kind_t::dpas;
51 else
52 return fma_kind_t::dpasw;
53 }
54 if (mad_t::matches_types(hw, a, b, c)) return fma_kind_t::mad;
55 return fma_kind_t::unknown;
56}
57
58int fma_kind::get_simd_size(ngen::HW hw, const fma_kind_t kind, const type_t &a,
59 const type_t &b, const type_t &c) {
60 int max_simd_size = 16;
61 int min_simd_size = hw >= ngen::HW::XeHPC ? 16 : 8;
62 int ret = 0;
63 switch (kind) {
64 case fma_kind_t::dp4a:
65 ret = mad_t::get_simd_size(hw, a.with_elems(4), b.with_elems(4), c);
66 break;
67 case fma_kind_t::dpas:
68 case fma_kind_t::dpasw: ret = hw >= ngen::HW::XeHPC ? 16 : 8; break;
69 case fma_kind_t::mad: ret = mad_t::get_simd_size(hw, a, b, c); break;
70 default: break;
71 }
72 ir_assert(ret != 0);
73 ret = std::max(std::min(ret, max_simd_size), min_simd_size);
74 return ret;
75}
76
77type_t multiply_desc_t::get_c_type(
78 const type_t &a, const type_t &b, bool force_c_upconvert) {
79 if (utils::one_of(
80 a, type_t::s8(), type_t::u8(), type_t::s16(), type_t::s32())
81 && utils::one_of(b, type_t::s8(), type_t::u8(), type_t::s16(),
82 type_t::s32()))
83 return type_t::s32();
84
85 if (a == type_t::bf16() && b == type_t::bf16()) return type_t::f32();
86 if (a == type_t::tf32() && b == type_t::tf32()) return type_t::f32();
87 if (a == type_t::f32() && b == type_t::f32()) return type_t::f32();
88 if (a == type_t::f64() && b == type_t::f64()) return type_t::f64();
89
90 if (utils::one_of(a, type_t::f16(), type_t::bf16()) && b == type_t::f32()) {
91 return type_t::f32();
92 }
93
94 if (a == type_t::f16() && b == type_t::f16()) {
95 if (force_c_upconvert) return type_t::f32();
96 return type_t::f16();
97 }
98
99 ir_error_not_expected()
100 << "Can't deduce C type. A type: " << a << " B type: " << b;
101 return type_t::undef();
102}
103
104bool dpas_t::is_src_type(type_t type) {
105 return utils::one_of(type.kind(), type_kind_t::u8, type_kind_t::s8,
106 type_kind_t::bf16, type_kind_t::f16, type_kind_t::tf32);
107}
108
109layout_t dpas_t::a_layout() const {
110 if (!is_src_type(src1_type)) ir_error_not_expected();
111
112 int m_blk = exec_size;
113 int inner_blk = 4 / src1_type.size();
114 int outer_blk = sdepth;
115 std::vector<std::pair<int, dim_t>> blocks
116 = {{1, outer_blk}, {0, m_blk}, {1, inner_blk}};
117 return layout_t(src1_type, 0, blocks);
118}
119
120layout_t dpas_t::b_layout() const {
121 if (!is_src_type(src2_type)) ir_error_not_expected();
122
123 int n_blk = rcount;
124 int k_blk = sdepth * 4 / src2_type.size();
125 std::vector<dim_t> blocks = {n_blk, k_blk};
126 auto tmp = layout_t(src2_type, 0, blocks);
127 return tmp.transpose();
128}
129
130layout_t dpas_t::c_layout() const {
131 int m_blk = exec_size;
132 int n_blk = rcount;
133 std::vector<dim_t> dims = {n_blk, m_blk};
134 return layout_t(dst_type, 0, dims).transpose();
135}
136
137bool dpas_t::matches(const multiply_desc_t &desc) const {
138 int m_blk = exec_size;
139 int n_blk = rcount;
140 int k_blk = sdepth * 4 / src1_type.size();
141
142 if (desc.m() % m_blk != 0 || desc.k() % k_blk != 0) return false;
143
144 auto a_blk_layout = desc.a_layout().map(tensor_t({m_blk, k_blk}));
145 auto b_blk_layout = desc.b_layout().map(tensor_t({k_blk, n_blk}));
146
147 if (a_blk_layout != a_layout()) return false;
148 if (b_blk_layout != b_layout()) return false;
149
150 return true;
151}
152
153bool dpas_t::matches_types(
154 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) {
155 if (a.is_x8() && b.is_x8() && c.is_s32()) return true;
156 if (a.is_f16() && b.is_f16() && c.is_f32()) return true;
157 if (a.is_bf16() && b.is_bf16() && c.is_f32()) return true;
158 if (a.is_tf32() && b.is_tf32() && c.is_f32() && hw >= ngen::HW::XeHPC)
159 return true;
160
161 return false;
162}
163
164bool mad_t::matches_types(
165 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) {
166 if (a != b) return false;
167
168 if (a.is_f64() && c.is_f64()) return true;
169 if (a.is_f32() && c.is_f32()) return true;
170 if (a.is_f16() && c.is_f16()) return true;
171 if (a.is_f16() && c.is_f32()) return true;
172 if (hw >= ngen::HW::XeHP) {
173 if (a.is_bf16() && c.is_f32()) return true;
174 if (a.is_f32() && c.is_bf16()) return true;
175 }
176 if (a.is_x8() && (c.is_x16() || c.is_x32())) return true;
177 if ((a.is_x16() || a.is_x32()) && (c.is_x16() || c.is_x32())) return true;
178
179 return false;
180}
181
182} // namespace jit
183} // namespace gpu
184} // namespace impl
185} // namespace dnnl
186