1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_FMA_HPP |
18 | #define GPU_JIT_IR_FMA_HPP |
19 | |
#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>

#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/ngen/ngen.hpp"
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
// Possible backend instruction sets.
// Identifies which multiply-accumulate instruction family is used.
enum class fma_kind_t {
    mad, // MAD instruction (see mad_t).
    dp4a, // DP4A; modeled in this file as dpas_t with sdepth == 1 and rcount == 1.
    dpas, // DPAS instruction (see dpas_t).
    dpasw, // DPASW variant of DPAS (see dpas_t::is_dpasw).
    unknown, // NOTE(review): presumably the "no match / unsupported" value - confirm.
};
39 | |
// Free helpers operating on fma_kind_t values. Only declarations live here;
// the definitions are outside this header.
namespace fma_kind {

// Converts an fma_kind_t value to its string form.
std::string to_string(fma_kind_t val);
// Converts a string to the matching fma_kind_t value.
// NOTE(review): behavior for unrecognized strings is not visible here -
// presumably fma_kind_t::unknown or an assertion; confirm in the definition.
fma_kind_t from_string(std::string enum_string);

// Selects an FMA kind for hardware `hw` and the operand types `a`, `b`, `c`.
// NOTE(review): selection priority is defined in the implementation - confirm
// there before relying on a particular order.
fma_kind_t get_supported_kind(
        ngen::HW hw, const type_t &a, const type_t &b, const type_t &c);

// Returns the SIMD size to use for `kind` on `hw` with operand types
// `a`, `b`, `c`.
int get_simd_size(ngen::HW hw, fma_kind_t kind, const type_t &a,
        const type_t &b, const type_t &c);

} // namespace fma_kind
52 | |
53 | class multiply_desc_t { |
54 | public: |
55 | multiply_desc_t() = default; |
56 | |
57 | multiply_desc_t(const layout_t &a_layout, const layout_t &b_layout, |
58 | bool force_c_upconvert) |
59 | : a_layout_(a_layout), b_layout_(b_layout) { |
60 | ir_assert(a_layout.ndims() == 2 && b_layout.ndims() == 2) |
61 | << "Expected 2D layouts, A layout: " << a_layout |
62 | << " B layout: " << b_layout; |
63 | |
64 | c_type_ = get_c_type(a_type(), b_type(), force_c_upconvert); |
65 | } |
66 | |
67 | const layout_t &a_layout() const { return a_layout_; } |
68 | const layout_t &b_layout() const { return b_layout_; } |
69 | |
70 | const type_t &a_type() const { return a_layout_.type(); } |
71 | const type_t &b_type() const { return b_layout_.type(); } |
72 | const type_t &c_type() const { return c_type_; } |
73 | |
74 | int m() const { return a_layout_.dims()[0]; } |
75 | int n() const { return b_layout_.dims()[1]; } |
76 | int k() const { return a_layout_.dims()[1]; } |
77 | |
78 | static type_t get_c_type( |
79 | const type_t &a, const type_t &b, bool force_c_upconvert); |
80 | |
81 | private: |
82 | layout_t a_layout_; |
83 | layout_t b_layout_; |
84 | type_t c_type_; |
85 | }; |
86 | |
87 | // Function representing DPAS instruction. |
88 | class dpas_t : public func_impl_t { |
89 | public: |
90 | IR_DECL_DERIVED_TYPE_ID(dpas_t, func_impl_t) |
91 | |
92 | static func_t make(bool is_dpasw, int exec_size, int sdepth, int rcount, |
93 | const type_t &dst_type, const type_t &src1_type, |
94 | const type_t &src2_type) { |
95 | return func_t(new dpas_t(is_dpasw, exec_size, sdepth, rcount, dst_type, |
96 | src1_type, src2_type)); |
97 | } |
98 | |
99 | static func_t make_dpasw(const dpas_t &dpas) { |
100 | return func_t(new dpas_t(true, dpas.exec_size, dpas.sdepth, dpas.rcount, |
101 | dpas.dst_type, dpas.src1_type, dpas.src2_type)); |
102 | } |
103 | |
104 | static bool is_dp4a_call(const stmt_t &s) { |
105 | auto call = s.as_ptr<func_call_t>(); |
106 | return call && call->func.as<dpas_t>().is_dp4a(); |
107 | } |
108 | |
109 | bool is_dp4a() const { return rcount == 1 && sdepth == 1; } |
110 | |
111 | bool is_equal(const object_impl_t &obj) const override { |
112 | if (!obj.is<self_type>()) return false; |
113 | auto &other = obj.as<self_type>(); |
114 | |
115 | return (is_dpasw == other.is_dpasw) && (sdepth == other.sdepth) |
116 | && (rcount == other.rcount) && (dst_type == other.dst_type) |
117 | && (src1_type == other.src1_type) |
118 | && (src2_type == other.src2_type); |
119 | } |
120 | |
121 | size_t get_hash() const override { |
122 | return ir_utils::get_hash( |
123 | is_dpasw, sdepth, rcount, dst_type, src1_type, src2_type); |
124 | } |
125 | |
126 | std::string str() const override { |
127 | std::ostringstream oss; |
128 | oss << (is_dpasw ? "dpasw" : is_dp4a() ? "dp4a" : "dpas" ); |
129 | if (!is_dp4a()) oss << "." << sdepth << "x" << rcount; |
130 | return oss.str(); |
131 | } |
132 | |
133 | IR_DEFINE_ARG_GET(dst, 0) |
134 | IR_DEFINE_ARG_GET(src0, 1) |
135 | IR_DEFINE_ARG_GET(src1, 2) |
136 | IR_DEFINE_ARG_GET(src2, 3) |
137 | |
138 | stmt_t operator()(const expr_t &dst, const expr_t &src0, const expr_t &src1, |
139 | const expr_t &src2) const { |
140 | return call({dst, src0, src1, src2}); |
141 | } |
142 | |
143 | int dst_size() const { return exec_size * rcount * sizeof(uint32_t); } |
144 | int src0_size() const { return dst_size(); } |
145 | int src1_size() const { return exec_size * sdepth * sizeof(uint32_t); } |
146 | int src2_size() const { |
147 | const int dpas_size = sdepth * rcount * sizeof(uint32_t); |
148 | return is_dpasw ? dpas_size / 2 : dpas_size; |
149 | } |
150 | |
151 | layout_t a_layout() const; |
152 | layout_t b_layout() const; |
153 | layout_t c_layout() const; |
154 | |
155 | bool matches(const multiply_desc_t &desc) const; |
156 | |
157 | static bool matches_types( |
158 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c); |
159 | static bool is_src_type(type_t type); |
160 | |
161 | bool is_dpasw; |
162 | |
163 | int exec_size; |
164 | int sdepth; |
165 | int rcount; |
166 | |
167 | type_t dst_type; // src0 type is same as dst_type. |
168 | type_t src1_type; |
169 | type_t src2_type; |
170 | |
171 | private: |
172 | dpas_t(bool is_dpasw, int exec_size, int sdepth, int rcount, |
173 | const type_t &dst_type, const type_t &src1_type, |
174 | const type_t &src2_type) |
175 | : func_impl_t(_type_info()) |
176 | , is_dpasw(is_dpasw) |
177 | , exec_size(exec_size) |
178 | , sdepth(sdepth) |
179 | , rcount(rcount) |
180 | , dst_type(dst_type) |
181 | , src1_type(src1_type) |
182 | , src2_type(src2_type) {} |
183 | }; |
184 | |
185 | // Function representing MAD instruction. |
186 | class mad_t : public func_impl_t { |
187 | public: |
188 | IR_DECL_DERIVED_TYPE_ID(mad_t, func_impl_t) |
189 | |
190 | static func_t make(ngen::HW hw, const type_t &dst_type, int exec_size, |
191 | const type_t &src1_type, int src1_stride, const type_t src2_type, |
192 | int src2_stride) { |
193 | return func_t(new mad_t(hw, dst_type, exec_size, src1_type, src1_stride, |
194 | src2_type, src2_stride)); |
195 | } |
196 | |
197 | bool is_equal(const object_impl_t &obj) const override { |
198 | if (!obj.is<self_type>()) return false; |
199 | auto &other = obj.as<self_type>(); |
200 | |
201 | return (dst_type == other.dst_type) && (src1_type == other.src1_type) |
202 | && (src2_type == other.src2_type) |
203 | && (exec_size == other.exec_size) |
204 | && (src1_stride == other.src1_stride) |
205 | && (src2_stride == other.src2_stride); |
206 | } |
207 | |
208 | size_t get_hash() const override { |
209 | return ir_utils::get_hash(dst_type, src1_type, src2_type, exec_size, |
210 | src2_stride, src1_stride); |
211 | } |
212 | |
213 | std::string str() const override { |
214 | std::ostringstream oss; |
215 | oss << "mad" ; |
216 | return oss.str(); |
217 | } |
218 | |
219 | IR_DEFINE_ARG_GET(dst, 0) |
220 | IR_DEFINE_ARG_GET(src0, 1) |
221 | IR_DEFINE_ARG_GET(src1, 2) |
222 | IR_DEFINE_ARG_GET(src2, 3) |
223 | |
224 | stmt_t operator()(const expr_t &dst, const expr_t &src0, const expr_t &src1, |
225 | const expr_t &src2) const { |
226 | return call({dst, src0, src1, src2}); |
227 | } |
228 | |
229 | int dst_size() const { return exec_size * dst_type.size(); } |
230 | int src0_size() const { return dst_size(); } |
231 | int src1_size() const { |
232 | return std::max( |
233 | src1_type.size(), src1_stride * src1_type.size() * exec_size); |
234 | } |
235 | int src2_size() const { |
236 | return std::max( |
237 | src2_type.size(), src2_stride * src2_type.size() * exec_size); |
238 | } |
239 | |
240 | static bool matches_types( |
241 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c); |
242 | |
243 | static const int max_exec_size = 32; |
244 | static const int get_max_exec_size_bytes(ngen::HW hw) { |
245 | return hw >= ngen::HW::XeHPC ? 128 : 64; |
246 | } |
247 | static int get_simd_size( |
248 | ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) { |
249 | int max_exec_size_bytes = get_max_exec_size_bytes(hw); |
250 | int max_size = max_exec_size; |
251 | if (max_exec_size_bytes / a.size() < max_size) |
252 | max_size = max_exec_size_bytes / a.size(); |
253 | if (max_exec_size_bytes / b.size() < max_size) |
254 | max_size = max_exec_size_bytes / b.size(); |
255 | if (max_exec_size_bytes / c.size() < max_size) |
256 | max_size = max_exec_size_bytes / c.size(); |
257 | return max_size; |
258 | } |
259 | int get_exec_size() const { return exec_size; } |
260 | |
261 | type_t dst_type; |
262 | type_t src1_type; |
263 | type_t src2_type; |
264 | |
265 | int exec_size; |
266 | int src1_stride; |
267 | int src2_stride; |
268 | |
269 | private: |
270 | mad_t(ngen::HW hw, const type_t &dst_type, int exec_size, |
271 | const type_t &src1_type, int src1_stride, const type_t &src2_type, |
272 | int src2_stride) |
273 | : func_impl_t(_type_info()) |
274 | , dst_type(dst_type) |
275 | , src1_type(src1_type) |
276 | , src2_type(src2_type) |
277 | , exec_size(exec_size) |
278 | , src1_stride(src1_stride) |
279 | , src2_stride(src2_stride) { |
280 | int max_exec_size_bytes = get_max_exec_size_bytes(hw); |
281 | ir_assert(math::is_pow2(exec_size)); |
282 | |
283 | ir_assert(exec_size <= max_exec_size); |
284 | ir_assert(dst_size() <= max_exec_size_bytes); |
285 | ir_assert(src1_size() <= max_exec_size_bytes); |
286 | ir_assert(src2_size() <= max_exec_size_bytes); |
287 | } |
288 | }; |
289 | |
290 | } // namespace jit |
291 | } // namespace gpu |
292 | } // namespace impl |
293 | } // namespace dnnl |
294 | |
295 | #endif |
296 | |