1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_HPP
18#define GPU_JIT_GEMM_GEN_GEMM_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/type_helpers.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/compute/device_info.hpp"
24#include "gpu/compute/kernel_arg_list.hpp"
25#include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp"
26#include "gpu/jit/gemm/kernel_catalog.hpp"
27#include "gpu/jit/gemm/kernel_evaluator.hpp"
28#include "gpu/jit/jit_generator_base.hpp"
29#include "gpu/jit/utils/ngen_type_bridge.hpp"
30#include "gpu/primitive_conf.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace jit {
36
37struct gen_gemm_kernel_desc_t {
38 friend struct gen_gemm_kernel_t;
39
40 const GEMMProblem *problem() const { return &problem_; };
41 const GEMMStrategy *strategy() const { return &strategy_; };
42
43 const CommonDriverInfo *driver_info() const { return &driver_info_; };
44 const EvaluateAuxOutput *aux_params() const { return &aux_params_; };
45
46 compute::scalar_type_t scalar_type() const {
47 switch (problem_.Ts) {
48 case Type::s8: return compute::scalar_type_t::_char;
49 case Type::u8: return compute::scalar_type_t::_uchar;
50 case Type::s16: return compute::scalar_type_t::_short;
51 case Type::u16: return compute::scalar_type_t::_ushort;
52 case Type::s32: return compute::scalar_type_t::_int;
53 case Type::u32: return compute::scalar_type_t::_uint;
54 case Type::s64: return compute::scalar_type_t::_long;
55 case Type::u64: return compute::scalar_type_t::_ulong;
56 case Type::bf16: return compute::scalar_type_t::_bfloat16;
57 case Type::f16: return compute::scalar_type_t::_half;
58 case Type::f32: return compute::scalar_type_t::_float;
59 default: return compute::scalar_type_t::undef;
60 }
61 }
62
63protected:
64 static Type convert_dnnl_to_kernel_type(data_type_t type) {
65 switch (type) {
66 default: assert(!"Unknown type");
67 case data_type::f32: return Type::f32;
68 case data_type::f16: return Type::f16;
69 case data_type::bf16: return Type::bf16;
70 case data_type::s32: return Type::s32;
71 case data_type::u8: return Type::u8;
72 case data_type::s8: return Type::s8;
73 }
74 }
75
76 static ngen::HW convert_dnnl_arch_to_hw(compute::gpu_arch_t arch) {
77 switch (arch) {
78 case compute::gpu_arch_t::gen9: return ngen::HW::Gen9;
79 case compute::gpu_arch_t::xe_lp: return ngen::HW::XeLP;
80 case compute::gpu_arch_t::xe_hp: return ngen::HW::XeHP;
81 case compute::gpu_arch_t::xe_hpg: return ngen::HW::XeHPG;
82 case compute::gpu_arch_t::xe_hpc: return ngen::HW::XeHPC;
83 default: return ngen::HW::Unknown;
84 }
85 }
86
87 compute::gpu_arch_t arch_;
88 ngen::HW hw_ = ngen::HW::Unknown;
89 int stepping_ = 0;
90 GEMMProblem problem_;
91 GEMMStrategy strategy_;
92 const kcatalog::Entry *entry_ = nullptr;
93 EvaluateAuxOutput aux_params_;
94 CommonDriverInfo driver_info_;
95
96 bool a_offset_ = false, b_offset_ = false;
97
98 /* optional information to fine-tune kernel */
99 int m_ = -1, n_ = -1, k_ = -1;
100 int eu_count_ = -1;
101
102 status_t transfer_post_ops(
103 const post_ops_t &post_ops, bool swap_ab = false);
104
105 status_t finalize();
106 void update_driver_info();
107};
108
// Descriptor for non-copying (direct-load) GEMM kernels.
struct gen_gemm_nocopy_kernel_desc_t : public gen_gemm_kernel_desc_t {
    // Bit flags controlling allowed math modes (combinable):
    //   mode_tf32   - allow tf32 arithmetic for f32 data
    //   mode_bf16x1 - allow single-pass bf16 emulation of f32
    enum compute_mode { mode_default = 0, mode_tf32 = 0x1, mode_bf16x1 = 0x2 };

    // Choose a kernel from the catalog matching the given hardware
    // (arch/stepping/eu_count), layout flags (trans_*, swap_ab, alignments),
    // offset/bias/post-op configuration, data types, and problem sizes.
    // Returns a failure status if no suitable kernel exists.
    status_t select_kernel(compute::gpu_arch_t arch, int stepping, int eu_count,
            compute_mode mode, int batch_dims, bool trans_a, bool trans_b,
            bool trans_co, bool swap_ab, bool a_offset, bool b_offset,
            bool c_offset, bool bias, sum_ab_t reduce_ab, float alpha,
            float beta, const post_ops_t &post_ops, data_type_t a_type,
            data_type_t b_type, data_type_t c_type, data_type_t co_type,
            data_type_t acc_type, int align_a, int align_b, int align_c,
            dim_t m, dim_t n, dim_t k, dim_t lda, dim_t ldb, dim_t ldc,
            dim_t batch);
};
122
123struct gen_gemm_xe_systolic_kernel_desc_t : public gen_gemm_kernel_desc_t {
124 status_t select_kernel(compute::gpu_arch_t arch, int eu_count,
125 int batch_dims, bool packed_c, bool a_offset, bool b_offset,
126 bool c_offset, bool bias, float alpha, float beta,
127 const post_ops_t &post_ops, data_type_t a_type, data_type_t b_type,
128 data_type_t c_type, data_type_t co_type, data_type_t acc_type,
129 dim_t m, dim_t n, dim_t k, dim_t batch, int unroll_m, int unroll_n,
130 bool alt);
131
132 static void choose_unrolls(compute::gpu_arch_t arch, int eu_count,
133 data_type_t a_type, data_type_t b_type, data_type_t c_type, dim_t m,
134 dim_t n, dim_t k, dim_t batch, int &unroll_m, int &unroll_n,
135 bool &alt);
136
137 static int min_block_k(data_type_t a_type) { return 2048; }
138};
139
// Wraps a kernel descriptor and generates the corresponding OpenCL kernel
// object on demand via the nGEN-based generator.
struct gen_gemm_kernel_t : public jit_generator_base {

    // The descriptor is held by reference; it must outlive this object.
    explicit gen_gemm_kernel_t(const gen_gemm_kernel_desc_t &desc)
        : desc_(desc) {}

    const char *kernel_name() const override { return "gemm_kernel"; }
    // Generate (or retrieve) the compiled kernel for the given OpenCL
    // context/device.
    cl_kernel get_kernel(cl_context context, cl_device_id device) override;

    const gen_gemm_kernel_desc_t *desc() const { return &desc_; }

protected:
    const gen_gemm_kernel_desc_t &desc_;
    // Kernel argument interface; hardware is filled in during init_interface().
    ngen::NEOInterfaceHandler interface_ {ngen::HW::Unknown};

    // Populate interface_ (argument layout, etc.) from desc_.
    void init_interface();
};
156
157} // namespace jit
158} // namespace gpu
159} // namespace impl
160} // namespace dnnl
161
162#endif
163