1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_HPP |
18 | #define GPU_JIT_GEMM_GEN_GEMM_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "gpu/compute/compute.hpp" |
23 | #include "gpu/compute/device_info.hpp" |
24 | #include "gpu/compute/kernel_arg_list.hpp" |
25 | #include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp" |
26 | #include "gpu/jit/gemm/kernel_catalog.hpp" |
27 | #include "gpu/jit/gemm/kernel_evaluator.hpp" |
28 | #include "gpu/jit/jit_generator_base.hpp" |
29 | #include "gpu/jit/utils/ngen_type_bridge.hpp" |
30 | #include "gpu/primitive_conf.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace jit { |
36 | |
37 | struct gen_gemm_kernel_desc_t { |
38 | friend struct gen_gemm_kernel_t; |
39 | |
40 | const GEMMProblem *problem() const { return &problem_; }; |
41 | const GEMMStrategy *strategy() const { return &strategy_; }; |
42 | |
43 | const CommonDriverInfo *driver_info() const { return &driver_info_; }; |
44 | const EvaluateAuxOutput *aux_params() const { return &aux_params_; }; |
45 | |
46 | compute::scalar_type_t scalar_type() const { |
47 | switch (problem_.Ts) { |
48 | case Type::s8: return compute::scalar_type_t::_char; |
49 | case Type::u8: return compute::scalar_type_t::_uchar; |
50 | case Type::s16: return compute::scalar_type_t::_short; |
51 | case Type::u16: return compute::scalar_type_t::_ushort; |
52 | case Type::s32: return compute::scalar_type_t::_int; |
53 | case Type::u32: return compute::scalar_type_t::_uint; |
54 | case Type::s64: return compute::scalar_type_t::_long; |
55 | case Type::u64: return compute::scalar_type_t::_ulong; |
56 | case Type::bf16: return compute::scalar_type_t::_bfloat16; |
57 | case Type::f16: return compute::scalar_type_t::_half; |
58 | case Type::f32: return compute::scalar_type_t::_float; |
59 | default: return compute::scalar_type_t::undef; |
60 | } |
61 | } |
62 | |
63 | protected: |
64 | static Type convert_dnnl_to_kernel_type(data_type_t type) { |
65 | switch (type) { |
66 | default: assert(!"Unknown type" ); |
67 | case data_type::f32: return Type::f32; |
68 | case data_type::f16: return Type::f16; |
69 | case data_type::bf16: return Type::bf16; |
70 | case data_type::s32: return Type::s32; |
71 | case data_type::u8: return Type::u8; |
72 | case data_type::s8: return Type::s8; |
73 | } |
74 | } |
75 | |
76 | static ngen::HW convert_dnnl_arch_to_hw(compute::gpu_arch_t arch) { |
77 | switch (arch) { |
78 | case compute::gpu_arch_t::gen9: return ngen::HW::Gen9; |
79 | case compute::gpu_arch_t::xe_lp: return ngen::HW::XeLP; |
80 | case compute::gpu_arch_t::xe_hp: return ngen::HW::XeHP; |
81 | case compute::gpu_arch_t::xe_hpg: return ngen::HW::XeHPG; |
82 | case compute::gpu_arch_t::xe_hpc: return ngen::HW::XeHPC; |
83 | default: return ngen::HW::Unknown; |
84 | } |
85 | } |
86 | |
87 | compute::gpu_arch_t arch_; |
88 | ngen::HW hw_ = ngen::HW::Unknown; |
89 | int stepping_ = 0; |
90 | GEMMProblem problem_; |
91 | GEMMStrategy strategy_; |
92 | const kcatalog::Entry *entry_ = nullptr; |
93 | EvaluateAuxOutput aux_params_; |
94 | CommonDriverInfo driver_info_; |
95 | |
96 | bool a_offset_ = false, b_offset_ = false; |
97 | |
98 | /* optional information to fine-tune kernel */ |
99 | int m_ = -1, n_ = -1, k_ = -1; |
100 | int eu_count_ = -1; |
101 | |
102 | status_t transfer_post_ops( |
103 | const post_ops_t &post_ops, bool swap_ab = false); |
104 | |
105 | status_t finalize(); |
106 | void update_driver_info(); |
107 | }; |
108 | |
109 | struct gen_gemm_nocopy_kernel_desc_t : public gen_gemm_kernel_desc_t { |
110 | enum compute_mode { mode_default = 0, mode_tf32 = 0x1, mode_bf16x1 = 0x2 }; |
111 | |
112 | status_t select_kernel(compute::gpu_arch_t arch, int stepping, int eu_count, |
113 | compute_mode mode, int batch_dims, bool trans_a, bool trans_b, |
114 | bool trans_co, bool swap_ab, bool a_offset, bool b_offset, |
115 | bool c_offset, bool bias, sum_ab_t reduce_ab, float alpha, |
116 | float beta, const post_ops_t &post_ops, data_type_t a_type, |
117 | data_type_t b_type, data_type_t c_type, data_type_t co_type, |
118 | data_type_t acc_type, int align_a, int align_b, int align_c, |
119 | dim_t m, dim_t n, dim_t k, dim_t lda, dim_t ldb, dim_t ldc, |
120 | dim_t batch); |
121 | }; |
122 | |
123 | struct gen_gemm_xe_systolic_kernel_desc_t : public gen_gemm_kernel_desc_t { |
124 | status_t select_kernel(compute::gpu_arch_t arch, int eu_count, |
125 | int batch_dims, bool packed_c, bool a_offset, bool b_offset, |
126 | bool c_offset, bool bias, float alpha, float beta, |
127 | const post_ops_t &post_ops, data_type_t a_type, data_type_t b_type, |
128 | data_type_t c_type, data_type_t co_type, data_type_t acc_type, |
129 | dim_t m, dim_t n, dim_t k, dim_t batch, int unroll_m, int unroll_n, |
130 | bool alt); |
131 | |
132 | static void choose_unrolls(compute::gpu_arch_t arch, int eu_count, |
133 | data_type_t a_type, data_type_t b_type, data_type_t c_type, dim_t m, |
134 | dim_t n, dim_t k, dim_t batch, int &unroll_m, int &unroll_n, |
135 | bool &alt); |
136 | |
137 | static int min_block_k(data_type_t a_type) { return 2048; } |
138 | }; |
139 | |
140 | struct gen_gemm_kernel_t : public jit_generator_base { |
141 | |
142 | explicit gen_gemm_kernel_t(const gen_gemm_kernel_desc_t &desc) |
143 | : desc_(desc) {} |
144 | |
145 | const char *kernel_name() const override { return "gemm_kernel" ; } |
146 | cl_kernel get_kernel(cl_context context, cl_device_id device) override; |
147 | |
148 | const gen_gemm_kernel_desc_t *desc() const { return &desc_; } |
149 | |
150 | protected: |
151 | const gen_gemm_kernel_desc_t &desc_; |
152 | ngen::NEOInterfaceHandler interface_ {ngen::HW::Unknown}; |
153 | |
154 | void init_interface(); |
155 | }; |
156 | |
157 | } // namespace jit |
158 | } // namespace gpu |
159 | } // namespace impl |
160 | } // namespace dnnl |
161 | |
162 | #endif |
163 | |