1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/fma.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace jit { |
23 | |
24 | std::string fma_kind::to_string(fma_kind_t val) { |
25 | switch (val) { |
26 | case fma_kind_t::mad: return "mad" ; |
27 | case fma_kind_t::dp4a: return "dp4a" ; |
28 | case fma_kind_t::dpas: return "dpas" ; |
29 | case fma_kind_t::dpasw: return "dpasw" ; |
30 | case fma_kind_t::unknown: return "unknown" ; |
31 | default: assert(!"unknown fma kind" ); return "unknown" ; |
32 | } |
33 | } |
34 | |
35 | fma_kind_t fma_kind::from_string(std::string enum_string) { |
36 | for (int enum_int = static_cast<int>(fma_kind_t::mad); |
37 | enum_int <= static_cast<int>(fma_kind_t::unknown); enum_int++) { |
38 | fma_kind_t enum_val = static_cast<fma_kind_t>(enum_int); |
39 | if (fma_kind::to_string(enum_val).compare(enum_string) == 0) |
40 | return enum_val; |
41 | } |
42 | assert(!"unknown fma kind" ); |
43 | return fma_kind_t::unknown; |
44 | } |
45 | |
46 | fma_kind_t fma_kind::get_supported_kind( |
47 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) { |
48 | if (hw >= ngen::HW::XeHP && dpas_t::matches_types(hw, a, b, c)) { |
49 | if (hw >= ngen::HW::XeHPC) |
50 | return fma_kind_t::dpas; |
51 | else |
52 | return fma_kind_t::dpasw; |
53 | } |
54 | if (mad_t::matches_types(hw, a, b, c)) return fma_kind_t::mad; |
55 | return fma_kind_t::unknown; |
56 | } |
57 | |
58 | int fma_kind::get_simd_size(ngen::HW hw, const fma_kind_t kind, const type_t &a, |
59 | const type_t &b, const type_t &c) { |
60 | int max_simd_size = 16; |
61 | int min_simd_size = hw >= ngen::HW::XeHPC ? 16 : 8; |
62 | int ret = 0; |
63 | switch (kind) { |
64 | case fma_kind_t::dp4a: |
65 | ret = mad_t::get_simd_size(hw, a.with_elems(4), b.with_elems(4), c); |
66 | break; |
67 | case fma_kind_t::dpas: |
68 | case fma_kind_t::dpasw: ret = hw >= ngen::HW::XeHPC ? 16 : 8; break; |
69 | case fma_kind_t::mad: ret = mad_t::get_simd_size(hw, a, b, c); break; |
70 | default: break; |
71 | } |
72 | ir_assert(ret != 0); |
73 | ret = std::max(std::min(ret, max_simd_size), min_simd_size); |
74 | return ret; |
75 | } |
76 | |
77 | type_t multiply_desc_t::get_c_type( |
78 | const type_t &a, const type_t &b, bool force_c_upconvert) { |
79 | if (utils::one_of( |
80 | a, type_t::s8(), type_t::u8(), type_t::s16(), type_t::s32()) |
81 | && utils::one_of(b, type_t::s8(), type_t::u8(), type_t::s16(), |
82 | type_t::s32())) |
83 | return type_t::s32(); |
84 | |
85 | if (a == type_t::bf16() && b == type_t::bf16()) return type_t::f32(); |
86 | if (a == type_t::tf32() && b == type_t::tf32()) return type_t::f32(); |
87 | if (a == type_t::f32() && b == type_t::f32()) return type_t::f32(); |
88 | if (a == type_t::f64() && b == type_t::f64()) return type_t::f64(); |
89 | |
90 | if (utils::one_of(a, type_t::f16(), type_t::bf16()) && b == type_t::f32()) { |
91 | return type_t::f32(); |
92 | } |
93 | |
94 | if (a == type_t::f16() && b == type_t::f16()) { |
95 | if (force_c_upconvert) return type_t::f32(); |
96 | return type_t::f16(); |
97 | } |
98 | |
99 | ir_error_not_expected() |
100 | << "Can't deduce C type. A type: " << a << " B type: " << b; |
101 | return type_t::undef(); |
102 | } |
103 | |
104 | bool dpas_t::is_src_type(type_t type) { |
105 | return utils::one_of(type.kind(), type_kind_t::u8, type_kind_t::s8, |
106 | type_kind_t::bf16, type_kind_t::f16, type_kind_t::tf32); |
107 | } |
108 | |
109 | layout_t dpas_t::a_layout() const { |
110 | if (!is_src_type(src1_type)) ir_error_not_expected(); |
111 | |
112 | int m_blk = exec_size; |
113 | int inner_blk = 4 / src1_type.size(); |
114 | int outer_blk = sdepth; |
115 | std::vector<std::pair<int, dim_t>> blocks |
116 | = {{1, outer_blk}, {0, m_blk}, {1, inner_blk}}; |
117 | return layout_t(src1_type, 0, blocks); |
118 | } |
119 | |
120 | layout_t dpas_t::b_layout() const { |
121 | if (!is_src_type(src2_type)) ir_error_not_expected(); |
122 | |
123 | int n_blk = rcount; |
124 | int k_blk = sdepth * 4 / src2_type.size(); |
125 | std::vector<dim_t> blocks = {n_blk, k_blk}; |
126 | auto tmp = layout_t(src2_type, 0, blocks); |
127 | return tmp.transpose(); |
128 | } |
129 | |
130 | layout_t dpas_t::c_layout() const { |
131 | int m_blk = exec_size; |
132 | int n_blk = rcount; |
133 | std::vector<dim_t> dims = {n_blk, m_blk}; |
134 | return layout_t(dst_type, 0, dims).transpose(); |
135 | } |
136 | |
137 | bool dpas_t::matches(const multiply_desc_t &desc) const { |
138 | int m_blk = exec_size; |
139 | int n_blk = rcount; |
140 | int k_blk = sdepth * 4 / src1_type.size(); |
141 | |
142 | if (desc.m() % m_blk != 0 || desc.k() % k_blk != 0) return false; |
143 | |
144 | auto a_blk_layout = desc.a_layout().map(tensor_t({m_blk, k_blk})); |
145 | auto b_blk_layout = desc.b_layout().map(tensor_t({k_blk, n_blk})); |
146 | |
147 | if (a_blk_layout != a_layout()) return false; |
148 | if (b_blk_layout != b_layout()) return false; |
149 | |
150 | return true; |
151 | } |
152 | |
153 | bool dpas_t::matches_types( |
154 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) { |
155 | if (a.is_x8() && b.is_x8() && c.is_s32()) return true; |
156 | if (a.is_f16() && b.is_f16() && c.is_f32()) return true; |
157 | if (a.is_bf16() && b.is_bf16() && c.is_f32()) return true; |
158 | if (a.is_tf32() && b.is_tf32() && c.is_f32() && hw >= ngen::HW::XeHPC) |
159 | return true; |
160 | |
161 | return false; |
162 | } |
163 | |
164 | bool mad_t::matches_types( |
165 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) { |
166 | if (a != b) return false; |
167 | |
168 | if (a.is_f64() && c.is_f64()) return true; |
169 | if (a.is_f32() && c.is_f32()) return true; |
170 | if (a.is_f16() && c.is_f16()) return true; |
171 | if (a.is_f16() && c.is_f32()) return true; |
172 | if (hw >= ngen::HW::XeHP) { |
173 | if (a.is_bf16() && c.is_f32()) return true; |
174 | if (a.is_f32() && c.is_bf16()) return true; |
175 | } |
176 | if (a.is_x8() && (c.is_x16() || c.is_x32())) return true; |
177 | if ((a.is_x16() || a.is_x32()) && (c.is_x16() || c.is_x32())) return true; |
178 | |
179 | return false; |
180 | } |
181 | |
182 | } // namespace jit |
183 | } // namespace gpu |
184 | } // namespace impl |
185 | } // namespace dnnl |
186 | |