/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_GEMM_GEN_GEMM_HPP
#define GPU_JIT_GEMM_GEN_GEMM_HPP

#include <assert.h>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/gemm_utils.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/compute/kernel.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
#include "gpu/jit/gemm/jit_gemm_pd.hpp"
#include "gpu/jit/jit_post_op_injector.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

struct gen_gemm_t : public gpu_gemm_t {
    struct pd_t : public jit_gemm_pd_t {
        using jit_gemm_pd_t::jit_gemm_pd_t;
        using kernel_desc_t = gen_gemm_nocopy_kernel_desc_t;

        DECLARE_COMMON_PD_T("jit:gemm:any", gen_gemm_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            using namespace data_type;
            using namespace primitive_kind;
            using namespace alg_kind;
            using smask_t = primitive_attr_t::skip_mask_t;
            using arch_t = compute::gpu_arch_t;

            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            // LIMITATIONS:
            // - runtime dims are not supported
            bool ok = true;

            auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops;

            dev_info_ = compute_engine->device_info();
            arch_ = dev_info_->gpu_arch();
            int stepping = dev_info_->stepping_id();

            ok = set_default_formats();
            if (!ok) return status::unimplemented;

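            // m == 1 f16 cases with unit ldc (and a compatible A layout) are
            // handled by swapping A and B so the kernel sees the transposed,
            // n == 1 problem instead.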
            bool check_lda
                    = ((desc()->transa() == dnnl_notrans && desc()->lda() == 1)
                            || (desc()->transa() == dnnl_trans));
            swap_ab_ = (desc()->a_type() == data_type::f16 && desc()->m() == 1
                    && desc()->ldc() == 1 && check_lda);

            const auto d = desc();

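            // Integer (s8/u8) GEMM: check the B data type, the A/B/C
            // zero-point attributes, and architecture support for the given
            // C type.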
            if (utils::one_of(d->c_type(), s32, f16, f32, u8, s8)
                    && utils::one_of(d->a_type(), u8, s8)) {
                ok = ok && utils::one_of(d->b_type(), u8, s8);

                a_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC);
                b_zp_ = !attr()->zero_points_.has_default_values(
                        DNNL_ARG_WEIGHTS);
                if (swap_ab_) std::swap(a_zp_, b_zp_);

                int cmask_a = 0, cmask_b = 0, cmask_c = 0;
                attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &cmask_b);
                attr()->zero_points_.get(DNNL_ARG_SRC, &cmask_a);
                attr()->zero_points_.get(DNNL_ARG_DST, &cmask_c);
                ok &= (cmask_a == 0) && (cmask_b == 0)
                        && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1);

                attr_skip_mask |= smask_t::zero_points_runtime;

                ok = ok
                        && IMPLICATION(
                                utils::one_of(d->c_type(), f32, s8, u8, f16),
                                arch_ >= arch_t::xe_hp);
            } else if (d->a_type() == bf16) {
                ok = ok && d->b_type() == bf16
                        && utils::one_of(d->c_type(), bf16, f32)
                        && utils::one_of(d->acc_type, bf16, f32);
            } else {
                ok = ok && utils::one_of(d->a_type(), f32, f16)
                        && d->b_type() == d->a_type()
                        && utils::one_of(d->acc_type, d->a_type(), f32);
            }

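            // General limitations: plain (non-blocked) layouts, at most two
            // batch dimensions, no runtime dimensions, restricted bias types
            // and broadcast masks, and only the supported attributes.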
            ok = ok && !has_blocks() && batch_dims() <= 2
                    && !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(),
                            d->k(), d->lda(), d->ldb(), d->ldc(), d->batch())
                    && IMPLICATION(with_bias(),
                            utils::one_of(d->bias_type(), f32, bf16, f16)
                                    && (d->bias_desc.ndims <= 3)
                                    && utils::one_of(bias_cmask(), 0, 1, 2, 3))
                    && IMPLICATION(utils::one_of(d->bias_type(), bf16, f16),
                            (d->bias_type() == d->c_type()))
                    && compute_engine->mayiuse_ngen_kernels()
                    && attr()->has_default_values(attr_skip_mask)
                    && attr()->output_scales_.mask_ == 0
                    && IMPLICATION(with_sum_ab(),
                            !with_bias()
                                    && (attr()->zero_points_.has_default_values(
                                            DNNL_ARG_DST)));

            auto status = init_post_ops();
            if (status != status::success) return status;

            bool with_binary = (post_ops_.find(binary) != -1);

            // check GPU architecture
            ok &= utils::one_of(arch_, arch_t::gen9, arch_t::xe_lp,
                    arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc);
            ok &= IMPLICATION(with_binary, arch_ >= arch_t::xe_hp);

            if (!ok) return status::unimplemented;

            // size checks for fused reduction kernels
            if (with_sum_ab()) {
                auto mnk = d->m() * d->n() * d->k();
                if (arch_ == arch_t::xe_hpc && d->a_type() == f32)
                    ok &= (mnk <= 256 * 1024 * 1024);

                if (!ok) return status::unimplemented;
            }

            // choose kernel
            auto co_type = with_bias()
                    ? d->bias_type()
                    : with_sum_ab() ? d->sum_ab_type
                                    : (utils::one_of(eff_a_type(), s8, u8)
                                                    ? s32
                                                    : d->c_type());

            auto acc_type = utils::one_of(eff_a_type(), s8, u8) ? s32 : f32;

            if (d->c_type() == f16 && arch_ < compute::gpu_arch_t::xe_hpg)
                acc_type = data_type::f16;

            if (types::data_type_size(acc_type) < 4) {
                // Limited post-op support for low-precision accumulation.
                ok &= !with_binary && IMPLICATION(with_sum_, sum_at_begin_);
            }

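            // Allow the kernel to down-convert f32 data to tf32 or bf16 when
            // the attribute's fpmath mode permits it.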
            kernel_desc_t::compute_mode mode = kernel_desc_t::mode_default;

            if (attr()->mayidownconvert(f32, tf32))
                mode = static_cast<decltype(mode)>(
                        mode | kernel_desc_t::mode_tf32);

            if (attr()->mayidownconvert(f32, bf16))
                mode = static_cast<decltype(mode)>(
                        mode | kernel_desc_t::mode_bf16x1);

            status = kernel_desc_.select_kernel(arch_, stepping,
                    dev_info_->eu_count(), mode, batch_dims(), eff_transa(),
                    eff_transb(), eff_trans_bias(), swap_ab(),
                    with_a_zero_points(), with_b_zero_points(),
                    with_c_zero_points(), with_bias(), sum_ab(), alpha(),
                    beta(), post_ops_, eff_a_type(), eff_b_type(),
                    desc()->c_type(), co_type, acc_type, eff_align_a(),
                    eff_align_b(), align_c(), eff_m(), eff_n(), d->k(),
                    eff_lda(), eff_ldb(), d->ldc(), d->batch());

            if (status != status::success) return status;

            // Global k-parallel kernels don't support post-ops;
            // use them only when C is f32 or s32.
            bool k_parallel_global = kernel_desc_.driver_info()->kParallel;
            bool with_eltwise = (post_ops_.find(eltwise) != -1);

            ok &= IMPLICATION(k_parallel_global,
                    !with_bias() && !with_eltwise && !with_binary
                            && utils::one_of(d->c_type(), f32, s32));

            if (!ok) return status::unimplemented;

            return status::success;
        }

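        // Report the preferred number of hardware threads per EU: kernels
        // built for the large (256-register) GRF mode run 4 threads per EU,
        // otherwise 8.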
        status_t query(query_t what, int idx, void *result) const override {
            switch ((int)what) {
                case (int)query::preferred_gpu_threads_per_eu: {
                    int grfs = kernel_desc_.driver_info()->grfCount;
                    *(int *)result = (grfs > 128) ? 4 : 8;
                    break;
                }
                default: return gpu_gemm_pd_t::query(what, idx, result);
            }
            return status::success;
        }

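        // Choose plain layouts for any-format memory descriptors. For small
        // f16 problems on XeHP and later, prefer a T/N layout combination (in
        // the kernel's column-major view) when it yields 4-byte-aligned
        // leading dimensions.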
        bool set_default_formats() {
            using namespace data_type;
            using namespace format_tag;
            using arch_t = compute::gpu_arch_t;

            auto d = desc();

            auto m = d->m();
            auto n = d->n();
            auto k = d->k();
            auto a_t = d->a_type();
            auto b_t = d->b_type();
            auto c_t = d->c_type();
            auto a_t_sz = types::data_type_size(a_t);
            auto b_t_sz = types::data_type_size(b_t);

            bool is_f16 = utils::everyone_is(f16, a_t, b_t, c_t);
            bool is_xe_hp_plus = arch_ >= arch_t::xe_hp;

            // Rename memory descriptors following column major format.
            auto &a_desc = desc_.b_desc;
            auto &b_desc = desc_.a_desc;
            auto &c_desc = desc_.c_desc;

            memory_desc_wrapper a_mdw(&a_desc);
            memory_desc_wrapper b_mdw(&b_desc);
            memory_desc_wrapper c_mdw(&c_desc);

            bool a_any = a_mdw.format_any();
            bool b_any = b_mdw.format_any();
            bool c_any = c_mdw.format_any();

            if (!a_any && !is_md_gemm_compatible_plain_format(&a_desc))
                return false;
            if (!b_any && !is_md_gemm_compatible_plain_format(&b_desc))
                return false;
            if (!c_any && !is_md_gemm_compatible_plain_format(&c_desc, true))
                return false;

            bool is_a_trans = (desc()->transa() == dnnl_trans);
            bool is_b_trans = (desc()->transb() == dnnl_trans);

            auto lda = is_a_trans ? m : k;
            auto ldb = is_b_trans ? k : n;

            auto is_aligned = [](dim_t ld, size_t sz, int byte) {
                return ld * sz % byte == 0;
            };

            bool a_4B_aligned = is_aligned(lda, a_t_sz, 4);
            bool b_4B_aligned = is_aligned(ldb, b_t_sz, 4);
            bool ab_4B_aligned = a_4B_aligned && b_4B_aligned;

            bool a_tn_4B_aligned = is_aligned(k, a_t_sz, 4);
            bool b_tn_4B_aligned = is_aligned(k, b_t_sz, 4);
            bool ab_tn_4B_aligned = a_tn_4B_aligned && b_tn_4B_aligned;

            bool use_tn = (m <= 32 || n <= 32) && !ab_4B_aligned
                    && ab_tn_4B_aligned;

            bool batch = d->is_batched();

            auto dotrans = batch ? acb : ba;
            auto notrans = batch ? abc : ab;

            if (is_f16 && is_xe_hp_plus && use_tn) {
                if (a_any && b_any) {
                    CHECK(memory_desc_init_by_tag(a_desc, dotrans));
                    CHECK(memory_desc_init_by_tag(b_desc, notrans));
                } else if (a_any && !is_b_trans) {
                    CHECK(memory_desc_init_by_tag(a_desc, dotrans));
                } else if (b_any && is_a_trans) {
                    CHECK(memory_desc_init_by_tag(b_desc, notrans));
                }
            }

            return gpu_gemm_pd_t::set_default_formats();
        }

        float alpha() const { return 1.0f; }

        float beta() const { return beta_; }

        bool with_bias() const {
            return desc()->bias_type() != data_type::undef
                    && !bias_via_binary_;
        }

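        // Map the descriptor's bias mask onto the kernel's cmask convention;
        // the lookup table reverses the three mask bits.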
        int bias_cmask() const {
            unsigned char to_cmask[8] = {0, 4, 2, 6, 1, 5, 3, 7};
            assert(unsigned(desc()->bias_mask()) < 8);
            return with_bias() ? to_cmask[desc()->bias_mask() & 7] : -1;
        }

        sum_ab_t sum_ab() const { return desc()->sum_ab; }

        bool with_sum_ab() const { return sum_ab() != sum_ab::sum_none; }

        int sum_ab_cmask() const {
            switch (sum_ab()) {
                default:
                case sum_ab::sum_none: return 0;
                case sum_ab::sum_a_row: return 1;
                case sum_ab::sum_b_col: return 2;
            }
        }

        bool with_a_zero_points() const { return a_zp_; }
        bool with_b_zero_points() const { return b_zp_; }

        bool with_c_zero_points() const {
            return !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
        }

        bool swap_ab() const { return swap_ab_; }

        int batch_dims() const {
            return nstl::max(desc()->c_desc.ndims - 2, 0);
        }

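        // Alignment of each matrix in bytes: the largest power of two that
        // divides its leading dimension times the element size.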
        int align_a() const {
            return int(utils::max_pow2_div(
                    types::data_type_size(desc()->a_type()) * desc()->lda()));
        }
        int align_b() const {
            return int(utils::max_pow2_div(
                    types::data_type_size(desc()->b_type()) * desc()->ldb()));
        }
        int align_c() const {
            return int(utils::max_pow2_div(
                    types::data_type_size(desc()->c_type()) * desc()->ldc()));
        }

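        // The eff_* accessors describe the problem as the kernel sees it:
        // with swap_ab(), A and B (along with their sizes, strides, and
        // transpose flags) are exchanged so the kernel solves the equivalent
        // transposed problem.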
        int eff_align_a() const { return !swap_ab() ? align_a() : align_b(); }
        int eff_align_b() const { return !swap_ab() ? align_b() : align_a(); }
        bool eff_transa() const {
            return !swap_ab() ? (desc()->transa() == dnnl_trans)
                              : (desc()->transb() == dnnl_notrans);
        }
        bool eff_transb() const {
            return !swap_ab() ? (desc()->transb() == dnnl_trans) : false;
        }
        bool eff_trans_bias() const {
            return swap_ab() ? (desc()->trans_bias() == dnnl_notrans)
                             : (desc()->trans_bias() == dnnl_trans);
        }
        dim_t eff_m() const { return !swap_ab() ? desc()->m() : desc()->n(); }
        dim_t eff_n() const { return !swap_ab() ? desc()->n() : desc()->m(); }
        dim_t eff_lda() const {
            return !swap_ab() ? desc()->lda() : desc()->ldb();
        }
        dim_t eff_ldb() const {
            return !swap_ab() ? desc()->ldb() : desc()->lda();
        }
        data_type_t eff_a_type() const {
            return !swap_ab() ? desc()->a_type() : desc()->b_type();
        }
        data_type_t eff_b_type() const {
            return !swap_ab() ? desc()->b_type() : desc()->a_type();
        }
        const gen_gemm_nocopy_kernel_desc_t *kernel_desc() const {
            return &kernel_desc_;
        }

        size_t dyn_offset_a = 0;
        size_t dyn_offset_b = 0;
        size_t dyn_offset_c = 0;
        size_t dyn_offset_co = 0;

        bool swap_ab_ = false;
        bool a_zp_ = false, b_zp_ = false;

        const compute::device_info_t *dev_info_ = nullptr;
        compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown;

        kernel_desc_t kernel_desc_;
    };

    gen_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {}

    status_t init(engine_t *engine) override { return init_nocopy(engine); }

    status_t init_nocopy(engine_t *engine) {
        using kernel_t = gen_gemm_kernel_t;
        using namespace data_type;

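        // Build the selected kernel with nGEN and record the scalar type it
        // uses for alpha/beta arguments.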
        auto kd = pd()->kernel_desc();
        kernel_t kernel(*kd);

        CHECK(create_kernel(engine, &nocopy_kernel_, &kernel));

        scalar_type_ = kd->scalar_type();

        if (get_verbose() >= 2) {
            auto info = kd->driver_info();
            printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%dx%dx%d\n",
                    info->unroll[LoopM], info->unroll[LoopN], info->wg[LoopM],
                    info->wg[LoopN], info->wg[LoopK]);
        }

        return status::success;
    }

    status_t execute(const gemm_exec_ctx_t &ctx) const override;

private:
    status_t launch_nocopy(const gemm_exec_ctx_t &ctx,
            compute::compute_stream_t *s, const memory_storage_t &a,
            const memory_storage_t &b, const memory_storage_t &c,
            const memory_storage_t *ao, const memory_storage_t *bo,
            const memory_storage_t &co, int po_count,
            const memory_storage_t **po_src, int64_t offset_a, int64_t offset_b,
            int64_t offset_c, int32_t offset_co, int32_t *offset_po_src,
            int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n,
            int32_t k, int32_t k0, float alpha, float beta, int32_t cmask,
            bool last_k_block, bool swapab, bool disable_hilbert) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    const CommonDriverInfo *nocopy_info() const {
        return pd()->kernel_desc()->driver_info();
    }

    compute::kernel_t nocopy_kernel_;
    compute::scalar_type_t scalar_type_;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl
#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s