1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_FMA_HPP
18#define GPU_JIT_IR_FMA_HPP
19
20#include <sstream>
21#include <string>
22
23#include "gpu/jit/ir/tensor.hpp"
24#include "gpu/jit/ngen/ngen.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
// Possible backend instruction kinds used to implement a multiplication.
// The dpas/dpasw/dp4a flavors are modeled by dpas_t and mad by mad_t, both
// declared later in this header.
enum class fma_kind_t {
    mad,
    dp4a,
    dpas,
    dpasw,
    // NOTE(review): presumably a sentinel for "no/unrecognized kind" - confirm
    // against the fma_kind:: helpers.
    unknown,
};
39
// Free helpers operating on fma_kind_t values. Only the declarations are
// visible in this header; the definitions live elsewhere.
namespace fma_kind {

// Returns the string form of `val`.
std::string to_string(fma_kind_t val);
// Returns the fma_kind_t value for the given string.
// NOTE(review): handling of unrecognized strings is not visible from here -
// presumably fma_kind_t::unknown or an error; confirm in the definition.
fma_kind_t from_string(std::string enum_string);

// NOTE(review): presumably picks the FMA kind usable on `hw` for the given
// A/B/C element types - confirm in the definition.
fma_kind_t get_supported_kind(
        ngen::HW hw, const type_t &a, const type_t &b, const type_t &c);

// NOTE(review): presumably returns the SIMD width for `kind` on `hw` with the
// given A/B/C element types - confirm in the definition.
int get_simd_size(ngen::HW hw, fma_kind_t kind, const type_t &a,
        const type_t &b, const type_t &c);

} // namespace fma_kind
52
53class multiply_desc_t {
54public:
55 multiply_desc_t() = default;
56
57 multiply_desc_t(const layout_t &a_layout, const layout_t &b_layout,
58 bool force_c_upconvert)
59 : a_layout_(a_layout), b_layout_(b_layout) {
60 ir_assert(a_layout.ndims() == 2 && b_layout.ndims() == 2)
61 << "Expected 2D layouts, A layout: " << a_layout
62 << " B layout: " << b_layout;
63
64 c_type_ = get_c_type(a_type(), b_type(), force_c_upconvert);
65 }
66
67 const layout_t &a_layout() const { return a_layout_; }
68 const layout_t &b_layout() const { return b_layout_; }
69
70 const type_t &a_type() const { return a_layout_.type(); }
71 const type_t &b_type() const { return b_layout_.type(); }
72 const type_t &c_type() const { return c_type_; }
73
74 int m() const { return a_layout_.dims()[0]; }
75 int n() const { return b_layout_.dims()[1]; }
76 int k() const { return a_layout_.dims()[1]; }
77
78 static type_t get_c_type(
79 const type_t &a, const type_t &b, bool force_c_upconvert);
80
81private:
82 layout_t a_layout_;
83 layout_t b_layout_;
84 type_t c_type_;
85};
86
87// Function representing DPAS instruction.
88class dpas_t : public func_impl_t {
89public:
90 IR_DECL_DERIVED_TYPE_ID(dpas_t, func_impl_t)
91
92 static func_t make(bool is_dpasw, int exec_size, int sdepth, int rcount,
93 const type_t &dst_type, const type_t &src1_type,
94 const type_t &src2_type) {
95 return func_t(new dpas_t(is_dpasw, exec_size, sdepth, rcount, dst_type,
96 src1_type, src2_type));
97 }
98
99 static func_t make_dpasw(const dpas_t &dpas) {
100 return func_t(new dpas_t(true, dpas.exec_size, dpas.sdepth, dpas.rcount,
101 dpas.dst_type, dpas.src1_type, dpas.src2_type));
102 }
103
104 static bool is_dp4a_call(const stmt_t &s) {
105 auto call = s.as_ptr<func_call_t>();
106 return call && call->func.as<dpas_t>().is_dp4a();
107 }
108
109 bool is_dp4a() const { return rcount == 1 && sdepth == 1; }
110
111 bool is_equal(const object_impl_t &obj) const override {
112 if (!obj.is<self_type>()) return false;
113 auto &other = obj.as<self_type>();
114
115 return (is_dpasw == other.is_dpasw) && (sdepth == other.sdepth)
116 && (rcount == other.rcount) && (dst_type == other.dst_type)
117 && (src1_type == other.src1_type)
118 && (src2_type == other.src2_type);
119 }
120
121 size_t get_hash() const override {
122 return ir_utils::get_hash(
123 is_dpasw, sdepth, rcount, dst_type, src1_type, src2_type);
124 }
125
126 std::string str() const override {
127 std::ostringstream oss;
128 oss << (is_dpasw ? "dpasw" : is_dp4a() ? "dp4a" : "dpas");
129 if (!is_dp4a()) oss << "." << sdepth << "x" << rcount;
130 return oss.str();
131 }
132
133 IR_DEFINE_ARG_GET(dst, 0)
134 IR_DEFINE_ARG_GET(src0, 1)
135 IR_DEFINE_ARG_GET(src1, 2)
136 IR_DEFINE_ARG_GET(src2, 3)
137
138 stmt_t operator()(const expr_t &dst, const expr_t &src0, const expr_t &src1,
139 const expr_t &src2) const {
140 return call({dst, src0, src1, src2});
141 }
142
143 int dst_size() const { return exec_size * rcount * sizeof(uint32_t); }
144 int src0_size() const { return dst_size(); }
145 int src1_size() const { return exec_size * sdepth * sizeof(uint32_t); }
146 int src2_size() const {
147 const int dpas_size = sdepth * rcount * sizeof(uint32_t);
148 return is_dpasw ? dpas_size / 2 : dpas_size;
149 }
150
151 layout_t a_layout() const;
152 layout_t b_layout() const;
153 layout_t c_layout() const;
154
155 bool matches(const multiply_desc_t &desc) const;
156
157 static bool matches_types(
158 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c);
159 static bool is_src_type(type_t type);
160
161 bool is_dpasw;
162
163 int exec_size;
164 int sdepth;
165 int rcount;
166
167 type_t dst_type; // src0 type is same as dst_type.
168 type_t src1_type;
169 type_t src2_type;
170
171private:
172 dpas_t(bool is_dpasw, int exec_size, int sdepth, int rcount,
173 const type_t &dst_type, const type_t &src1_type,
174 const type_t &src2_type)
175 : func_impl_t(_type_info())
176 , is_dpasw(is_dpasw)
177 , exec_size(exec_size)
178 , sdepth(sdepth)
179 , rcount(rcount)
180 , dst_type(dst_type)
181 , src1_type(src1_type)
182 , src2_type(src2_type) {}
183};
184
185// Function representing MAD instruction.
186class mad_t : public func_impl_t {
187public:
188 IR_DECL_DERIVED_TYPE_ID(mad_t, func_impl_t)
189
190 static func_t make(ngen::HW hw, const type_t &dst_type, int exec_size,
191 const type_t &src1_type, int src1_stride, const type_t src2_type,
192 int src2_stride) {
193 return func_t(new mad_t(hw, dst_type, exec_size, src1_type, src1_stride,
194 src2_type, src2_stride));
195 }
196
197 bool is_equal(const object_impl_t &obj) const override {
198 if (!obj.is<self_type>()) return false;
199 auto &other = obj.as<self_type>();
200
201 return (dst_type == other.dst_type) && (src1_type == other.src1_type)
202 && (src2_type == other.src2_type)
203 && (exec_size == other.exec_size)
204 && (src1_stride == other.src1_stride)
205 && (src2_stride == other.src2_stride);
206 }
207
208 size_t get_hash() const override {
209 return ir_utils::get_hash(dst_type, src1_type, src2_type, exec_size,
210 src2_stride, src1_stride);
211 }
212
213 std::string str() const override {
214 std::ostringstream oss;
215 oss << "mad";
216 return oss.str();
217 }
218
219 IR_DEFINE_ARG_GET(dst, 0)
220 IR_DEFINE_ARG_GET(src0, 1)
221 IR_DEFINE_ARG_GET(src1, 2)
222 IR_DEFINE_ARG_GET(src2, 3)
223
224 stmt_t operator()(const expr_t &dst, const expr_t &src0, const expr_t &src1,
225 const expr_t &src2) const {
226 return call({dst, src0, src1, src2});
227 }
228
229 int dst_size() const { return exec_size * dst_type.size(); }
230 int src0_size() const { return dst_size(); }
231 int src1_size() const {
232 return std::max(
233 src1_type.size(), src1_stride * src1_type.size() * exec_size);
234 }
235 int src2_size() const {
236 return std::max(
237 src2_type.size(), src2_stride * src2_type.size() * exec_size);
238 }
239
240 static bool matches_types(
241 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c);
242
243 static const int max_exec_size = 32;
244 static const int get_max_exec_size_bytes(ngen::HW hw) {
245 return hw >= ngen::HW::XeHPC ? 128 : 64;
246 }
247 static int get_simd_size(
248 ngen::HW hw, const type_t &a, const type_t &b, const type_t &c) {
249 int max_exec_size_bytes = get_max_exec_size_bytes(hw);
250 int max_size = max_exec_size;
251 if (max_exec_size_bytes / a.size() < max_size)
252 max_size = max_exec_size_bytes / a.size();
253 if (max_exec_size_bytes / b.size() < max_size)
254 max_size = max_exec_size_bytes / b.size();
255 if (max_exec_size_bytes / c.size() < max_size)
256 max_size = max_exec_size_bytes / c.size();
257 return max_size;
258 }
259 int get_exec_size() const { return exec_size; }
260
261 type_t dst_type;
262 type_t src1_type;
263 type_t src2_type;
264
265 int exec_size;
266 int src1_stride;
267 int src2_stride;
268
269private:
270 mad_t(ngen::HW hw, const type_t &dst_type, int exec_size,
271 const type_t &src1_type, int src1_stride, const type_t &src2_type,
272 int src2_stride)
273 : func_impl_t(_type_info())
274 , dst_type(dst_type)
275 , src1_type(src1_type)
276 , src2_type(src2_type)
277 , exec_size(exec_size)
278 , src1_stride(src1_stride)
279 , src2_stride(src2_stride) {
280 int max_exec_size_bytes = get_max_exec_size_bytes(hw);
281 ir_assert(math::is_pow2(exec_size));
282
283 ir_assert(exec_size <= max_exec_size);
284 ir_assert(dst_size() <= max_exec_size_bytes);
285 ir_assert(src1_size() <= max_exec_size_bytes);
286 ir_assert(src2_size() <= max_exec_size_bytes);
287 }
288};
289
290} // namespace jit
291} // namespace gpu
292} // namespace impl
293} // namespace dnnl
294
295#endif
296