1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cassert>
18#include <string>
19#include <unordered_map>
20#include <vector>
21
22#include "gpu/ocl/kernel_utils.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace ocl {
28
29
// Kernel source strings. They are only declared here; their definitions are
// not part of this file. Declarations are grouped by identifier prefix and
// sorted alphabetically inside each group.

// Kernels without an architecture or "ref_" prefix.
extern const char *combined_reduction_kernel;
extern const char *custom_reorder_kernel;
extern const char *gemm_post_ops_inner_product_kernel;
extern const char *gemm_with_post_ops_kernel;
extern const char *generic_reorder_kernel;
extern const char *many_inputs_sum_kernel;
extern const char *rnn_reorder_kernel;
extern const char *simple_concat_kernel;
extern const char *simple_sum_kernel;
extern const char *vectorized_resampling_kernel;

// "gen9_" kernels.
extern const char *gen9_binary_kernel;
extern const char *gen9_bnorm_kernel;
extern const char *gen9_concat_kernel;
extern const char *gen9_conv_bwd_data_kernel;
extern const char *gen9_conv_bwd_weights_kernel;
extern const char *gen9_conv_dw_bwd_data_kernel;
extern const char *gen9_conv_dw_fwd_data_kernel;
extern const char *gen9_conv_fwd_data_kernel;
extern const char *gen9_conv_nhwc_bwd_data_kernel;
extern const char *gen9_conv_nhwc_bwd_weights_f32_kernel;
extern const char *gen9_conv_nhwc_fwd_data_kernel;
extern const char *gen9_eltwise_kernel;
extern const char *gen9_gemm_beta_kernel;
extern const char *gen9_gemm_compute_kernel;
extern const char *gen9_gemm_copy_kernel;
extern const char *gen9_gemm_nocopy_f16_kernel;
extern const char *gen9_gemm_nocopy_f32_kernel;
extern const char *gen9_gemm_nocopy_scale_x8x8s32_kernel;
extern const char *gen9_gemm_nocopy_superkernel_f32_kernel;
extern const char *gen9_gemm_nocopy_x8x8s32_kernel;
extern const char *gen9_global_pooling_kernel;
extern const char *gen9_pooling_kernel;
extern const char *gen9_reduction_kernel;
extern const char *gen9_softmax_kernel;
extern const char *gen9_sum_kernel;
extern const char *gen9_wino_conv_fwd_data_2x3_kernel;
extern const char *gen9_wino_conv_fwd_data_fused_kernel;

// "ref_" kernels.
extern const char *ref_binary_kernel;
extern const char *ref_bnorm_kernel;
extern const char *ref_convolution_kernel;
extern const char *ref_deconv_backward_bias_kernel;
extern const char *ref_eltwise_kernel;
extern const char *ref_gemm_kernel;
extern const char *ref_inner_product_kernel;
extern const char *ref_layer_normalization_kernel;
extern const char *ref_lrn_kernel;
extern const char *ref_matmul_kernel;
extern const char *ref_pooling_kernel;
extern const char *ref_prelu_kernel;
extern const char *ref_reduction_kernel;
extern const char *ref_reorder_kernel;
extern const char *ref_resampling_kernel;
extern const char *ref_rnn_kernel;
extern const char *ref_shuffle_kernel;
extern const char *ref_softmax_kernel;
extern const char *ref_zero_pad_kernel;

// "xe_hp_" / "xe_hpc_" kernels.
extern const char *xe_hp_systolic_gemm_copy_kernel;
extern const char *xe_hpc_systolic_gemm_copy_kernel;

// "xe_lp_" kernels.
extern const char *xe_lp_1x1_conv_fwd_data_x8s8x_kernel;
extern const char *xe_lp_conv_bwd_data_mb_block_x8s8x8_kernel;
extern const char *xe_lp_conv_bwd_data_x8s8x8_kernel;
extern const char *xe_lp_conv_dw_fwd_data_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_dw_fwd_data_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_first_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_first_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_x8s8x_kernel;
extern const char *xe_lp_gemm_nocopy_scale_x8x8s32_kernel;
extern const char *xe_lp_gemm_nocopy_x8x8s32_kernel;
extern const char *xe_lp_nhwc_1x1_conv_fwd_x8s8x_kernel;
extern const char *xe_lp_x8s8x_compensation_kernel;
104
// Header text strings, one per header that get_kernel_header() can return.
// They are only declared here; their definitions are not part of this file.
// Sorted alphabetically by identifier.
extern const char *binary_types_header;
extern const char *metrics_discovery_api_header;
extern const char *ocl_eltwise_header;
extern const char *ocl_gemm_attrs_header;
extern const char *ocl_math_utils_header;
extern const char *ocl_post_ops_header;
extern const char *ocl_scales_header;
extern const char *ocl_types_header;
extern const char *ocl_zero_points_header;
extern const char *offsets_header;
extern const char *reorder_common_header;
extern const char *rnn_types_header;
extern const char *zero_pad_struct_header;
118
119const char *get_kernel_source(const char *name) {
120 static const std::unordered_map<std::string, const char *> kernel_list = {
121
122 { "combined_reduce", combined_reduction_kernel },
123 { "custom_reorder", custom_reorder_kernel },
124 { "gemm_post_ops", gemm_with_post_ops_kernel },
125 { "gen9_gemm_beta", gen9_gemm_beta_kernel },
126 { "gen9_gemm_compute", gen9_gemm_compute_kernel },
127 { "gen9_gemm_copy", gen9_gemm_copy_kernel },
128 { "gen9_gemm_nocopy_f16", gen9_gemm_nocopy_f16_kernel },
129 { "gen9_gemm_nocopy_f16", gen9_gemm_nocopy_f16_kernel },
130 { "gen9_gemm_nocopy_f16", gen9_gemm_nocopy_f16_kernel },
131 { "gen9_gemm_nocopy_f16", gen9_gemm_nocopy_f16_kernel },
132 { "gen9_gemm_nocopy_f32", gen9_gemm_nocopy_f32_kernel },
133 { "gen9_gemm_nocopy_f32", gen9_gemm_nocopy_f32_kernel },
134 { "gen9_gemm_nocopy_f32", gen9_gemm_nocopy_f32_kernel },
135 { "gen9_gemm_nocopy_f32", gen9_gemm_nocopy_f32_kernel },
136 { "gen9_gemm_scale_x8x8s32", gen9_gemm_nocopy_scale_x8x8s32_kernel },
137 { "gen9_gemm_nocopy_superkernel_f32", gen9_gemm_nocopy_superkernel_f32_kernel },
138 { "gen9_gemm_nocopy_superkernel_f32", gen9_gemm_nocopy_superkernel_f32_kernel },
139 { "gen9_gemm_compute_x8x8s32", gen9_gemm_nocopy_x8x8s32_kernel },
140 { "ref_gemm", ref_gemm_kernel },
141 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
142 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
143 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
144 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
145 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
146 { "xe_hp_systolic_gemm_copy", xe_hp_systolic_gemm_copy_kernel },
147 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
148 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
149 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
150 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
151 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
152 { "xe_hpc_systolic_gemm_copy", xe_hpc_systolic_gemm_copy_kernel },
153 { "xe_lp_gemm_scale_x8x8s32", xe_lp_gemm_nocopy_scale_x8x8s32_kernel },
154 { "xe_lp_gemm_compute_x8x8s32", xe_lp_gemm_nocopy_x8x8s32_kernel },
155 { "xe_lp_gemm_compute_x8x8s32", xe_lp_gemm_nocopy_x8x8s32_kernel },
156 { "xe_lp_gemm_compute_x8x8s32", xe_lp_gemm_nocopy_x8x8s32_kernel },
157 { "xe_lp_gemm_compute_x8x8s32", xe_lp_gemm_nocopy_x8x8s32_kernel },
158 { "gemm_post_ops_inner_product", gemm_post_ops_inner_product_kernel },
159 { "gen9_binary", gen9_binary_kernel },
160 { "gen9_binary", gen9_binary_kernel },
161 { "gen9_binary", gen9_binary_kernel },
162 { "gen9_binary", gen9_binary_kernel },
163 { "gen9_fused_reduce_init", gen9_bnorm_kernel },
164 { "gen9_calc_mean_var", gen9_bnorm_kernel },
165 { "gen9_calc_mean_var", gen9_bnorm_kernel },
166 { "gen9_reduce_mean_var", gen9_bnorm_kernel },
167 { "gen9_fused_reduce_final", gen9_bnorm_kernel },
168 { "gen9_calc_mean", gen9_bnorm_kernel },
169 { "gen9_calc_mean", gen9_bnorm_kernel },
170 { "gen9_reduce_mean", gen9_bnorm_kernel },
171 { "gen9_calc_variance", gen9_bnorm_kernel },
172 { "gen9_calc_variance", gen9_bnorm_kernel },
173 { "gen9_reduce_variance", gen9_bnorm_kernel },
174 { "gen9_bnorm_fwd", gen9_bnorm_kernel },
175 { "gen9_bnorm_fwd", gen9_bnorm_kernel },
176 { "gen9_calculate_stats", gen9_bnorm_kernel },
177 { "gen9_calculate_stats", gen9_bnorm_kernel },
178 { "gen9_reduce_stats", gen9_bnorm_kernel },
179 { "gen9_fused_reduce_final", gen9_bnorm_kernel },
180 { "gen9_bnorm_bwd", gen9_bnorm_kernel },
181 { "gen9_bnorm_bwd", gen9_bnorm_kernel },
182 { "gen9_concat", gen9_concat_kernel },
183 { "gen9_conv_bwd_data", gen9_conv_bwd_data_kernel },
184 { "gen9_conv_bwd_weights", gen9_conv_bwd_weights_kernel },
185 { "gen9_conv_dw_bwd_data", gen9_conv_dw_bwd_data_kernel },
186 { "gen9_conv_dw_fwd", gen9_conv_dw_fwd_data_kernel },
187 { "gen9_conv_fwd", gen9_conv_fwd_data_kernel },
188 { "gen9_conv_nhwc_bwd_data", gen9_conv_nhwc_bwd_data_kernel },
189 { "gen9_conv_nhwc_bwd_weights", gen9_conv_nhwc_bwd_weights_f32_kernel },
190 { "gen9_conv_nhwc_fwd", gen9_conv_nhwc_fwd_data_kernel },
191 { "gen9_eltwise_fwd", gen9_eltwise_kernel },
192 { "gen9_eltwise_bwd", gen9_eltwise_kernel },
193 { "gen9_global_pooling_fwd", gen9_global_pooling_kernel },
194 { "gen9_global_pooling_bwd", gen9_global_pooling_kernel },
195 { "gen9_pooling_fwd", gen9_pooling_kernel },
196 { "gen9_pooling_bwd", gen9_pooling_kernel },
197 { "gen9_initial_reduce", gen9_reduction_kernel },
198 { "gen9_final_reduce", gen9_reduction_kernel },
199 { "gen9_softmax_fwd", gen9_softmax_kernel },
200 { "gen9_softmax_bwd", gen9_softmax_kernel },
201 { "gen9_sum", gen9_sum_kernel },
202 { "gen9_wino_wei_transform_2x3", gen9_wino_conv_fwd_data_2x3_kernel },
203 { "gen9_wino_src_transform_2x3", gen9_wino_conv_fwd_data_2x3_kernel },
204 { "gen9_wino_dst_transform_2x3", gen9_wino_conv_fwd_data_2x3_kernel },
205 { "gen9_wino_conv_fwd_2x3", gen9_wino_conv_fwd_data_2x3_kernel },
206 { "gen9_wino_wei_transform", gen9_wino_conv_fwd_data_fused_kernel },
207 { "gen9_wino_conv_fwd", gen9_wino_conv_fwd_data_fused_kernel },
208 { "generic_reorder", generic_reorder_kernel },
209 { "many_inputs_sum", many_inputs_sum_kernel },
210 { "many_inputs_sum_batched", many_inputs_sum_kernel },
211 { "ref_binary", ref_binary_kernel },
212 { "ref_binary", ref_binary_kernel },
213 { "calculate_mean", ref_bnorm_kernel },
214 { "calculate_variance", ref_bnorm_kernel },
215 { "calculate_mean_variance", ref_bnorm_kernel },
216 { "calculate_mean", ref_bnorm_kernel },
217 { "calculate_variance", ref_bnorm_kernel },
218 { "reduce_mean", ref_bnorm_kernel },
219 { "reduce_variance", ref_bnorm_kernel },
220 { "ref_bnorm_fwd", ref_bnorm_kernel },
221 { "calculate_stats", ref_bnorm_kernel },
222 { "reduce_stats", ref_bnorm_kernel },
223 { "calculate_stats", ref_bnorm_kernel },
224 { "reduce_stats", ref_bnorm_kernel },
225 { "ref_bnorm_bwd", ref_bnorm_kernel },
226 { "ref_convolution_fwd", ref_convolution_kernel },
227 { "ref_convolution_bwd_data", ref_convolution_kernel },
228 { "ref_convolution_bwd_weights", ref_convolution_kernel },
229 { "ref_deconv_backward_bias", ref_deconv_backward_bias_kernel },
230 { "ref_eltwise_fwd", ref_eltwise_kernel },
231 { "ref_eltwise_bwd", ref_eltwise_kernel },
232 { "ref_inner_product_fwd", ref_inner_product_kernel },
233 { "ref_inner_product_bwd_data", ref_inner_product_kernel },
234 { "ref_inner_product_bwd_weights", ref_inner_product_kernel },
235 { "ref_lnorm_fwd", ref_layer_normalization_kernel },
236 { "ref_lnorm_fwd", ref_layer_normalization_kernel },
237 { "ref_lnorm_bwd_scaleshift", ref_layer_normalization_kernel },
238 { "ref_lnorm_bwd_scaleshift_final", ref_layer_normalization_kernel },
239 { "ref_lnorm_bwd_scaleshift", ref_layer_normalization_kernel },
240 { "ref_lnorm_bwd", ref_layer_normalization_kernel },
241 { "ref_lnorm_bwd", ref_layer_normalization_kernel },
242 { "ref_lrn_fwd", ref_lrn_kernel },
243 { "ref_lrn_bwd", ref_lrn_kernel },
244 { "ref_matmul", ref_matmul_kernel },
245 { "ref_pooling_fwd", ref_pooling_kernel },
246 { "ref_pooling_bwd", ref_pooling_kernel },
247 { "ref_prelu_fwd", ref_prelu_kernel },
248 { "ref_prelu_bwd", ref_prelu_kernel },
249 { "ref_reduce", ref_reduction_kernel },
250 { "ref_reorder", ref_reorder_kernel },
251 { "ref_resampling_fwd", ref_resampling_kernel },
252 { "ref_resampling_bwd", ref_resampling_kernel },
253 { "ref_shuffle", ref_shuffle_kernel },
254 { "ref_softmax_fwd_generic", ref_softmax_kernel },
255 { "ref_softmax_bwd_generic", ref_softmax_kernel },
256 { "ref_zero_pad", ref_zero_pad_kernel },
257 { "ref_zero_pad_subg_16", ref_zero_pad_kernel },
258 { "ref_zero_pad_subg_16_mask_and_clear_dt_1b", ref_zero_pad_kernel },
259 { "ref_rnn_copy_init_layer", ref_rnn_kernel },
260 { "ref_rnn_copy_init_iter", ref_rnn_kernel },
261 { "ref_rnn_copy_res_layer", ref_rnn_kernel },
262 { "ref_rnn_copy_res_iter", ref_rnn_kernel },
263 { "ref_rnn_ws_set", ref_rnn_kernel },
264 { "ref_rnn_ws_print", ref_rnn_kernel },
265 { "ref_rnn_bias_prepare", ref_rnn_kernel },
266 { "ref_rnn_elemwise_fwd", ref_rnn_kernel },
267 { "ref_rnn_elemwise_fwd", ref_rnn_kernel },
268 { "ref_rnn_elemwise_bwd", ref_rnn_kernel },
269 { "ref_rnn_gates_reduction", ref_rnn_kernel },
270 { "wei_reorder", rnn_reorder_kernel },
271 { "simple_concat", simple_concat_kernel },
272 { "simple_sum", simple_sum_kernel },
273 { "vectorized_resampling_bwd", vectorized_resampling_kernel },
274 { "xe_lp_1x1_conv_fwd_x8s8x", xe_lp_1x1_conv_fwd_data_x8s8x_kernel },
275 { "conv_bwd_data_mb_block_x8s8x8", xe_lp_conv_bwd_data_mb_block_x8s8x8_kernel },
276 { "conv_bwd_data_x8s8x8", xe_lp_conv_bwd_data_x8s8x8_kernel },
277 { "conv_dw_fwd_mb_block_x8s8x", xe_lp_conv_dw_fwd_data_mb_block_x8s8x_kernel },
278 { "conv_dw_fwd_ow_block_x8s8x", xe_lp_conv_dw_fwd_data_ow_block_x8s8x_kernel },
279 { "conv_fwd_first_x8s8x", xe_lp_conv_fwd_data_first_x8s8x_kernel },
280 { "conv_fwd_mb_block_x8s8x", xe_lp_conv_fwd_data_mb_block_x8s8x_kernel },
281 { "conv_fwd_ow_block_x8s8x", xe_lp_conv_fwd_data_ow_block_x8s8x_kernel },
282 { "conv_nhwc_fwd_dw_mb_block_x8s8x", xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x_kernel },
283 { "conv_nhwc_fwd_dw_ow_block_x8s8x", xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x_kernel },
284 { "conv_nhwc_fwd_first_x8s8x", xe_lp_conv_nhwc_fwd_first_x8s8x_kernel },
285 { "conv_nhwc_fwd_x8s8x", xe_lp_conv_nhwc_fwd_x8s8x_kernel },
286 { "xe_lp_nhwc_1x1_conv_fwd_x8s8x", xe_lp_nhwc_1x1_conv_fwd_x8s8x_kernel },
287 { "xe_lp_x8s8x_compensation", xe_lp_x8s8x_compensation_kernel },
288 { "xe_lp_x8s8x_compensation", xe_lp_x8s8x_compensation_kernel },
289 };
290
291 if (!name) return nullptr;
292
293 assert(kernel_list.count(name) == 1);
294 return kernel_list.at(name);
295}
296
297const char *get_kernel_header(const std::string &name) {
298 static const std::unordered_map<std::string, const char *> kernel_header_list ={
299
300 {"gpu/ocl/binary_types.h", binary_types_header},
301 {"gpu/ocl/gemm/ocl_gemm_attrs.h", ocl_gemm_attrs_header},
302 {"gpu/ocl/mdapi/metrics_discovery_api.h", metrics_discovery_api_header},
303 {"gpu/ocl/ocl_eltwise.h", ocl_eltwise_header},
304 {"gpu/ocl/ocl_math_utils.h", ocl_math_utils_header},
305 {"gpu/ocl/ocl_post_ops.h", ocl_post_ops_header},
306 {"gpu/ocl/ocl_scales.h", ocl_scales_header},
307 {"gpu/ocl/ocl_types.h", ocl_types_header},
308 {"gpu/ocl/ocl_zero_points.h", ocl_zero_points_header},
309 {"gpu/ocl/offsets.h", offsets_header},
310 {"gpu/ocl/reorder_common.h", reorder_common_header},
311 {"gpu/ocl/rnn/rnn_types.h", rnn_types_header},
312 {"gpu/zero_pad_struct.h", zero_pad_struct_header},
313 };
314
315 assert(kernel_header_list.count(name) == 1);
316 return kernel_header_list.at(name);
317}
318
319} // namespace ocl
320} // namespace gpu
321} // namespace impl
322} // namespace dnnl
323