1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include <string> |
19 | #include <unordered_map> |
20 | #include <vector> |
21 | |
22 | #include "gpu/ocl/kernel_utils.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | |
// Kernel source strings, one symbol per kernel source. They are only declared
// here (`extern`); the definitions live in other translation units.
// get_kernel_source() below maps lookup names onto these symbols.
extern const char *combined_reduction_kernel;
extern const char *custom_reorder_kernel;
extern const char *gemm_with_post_ops_kernel;
extern const char *gen9_gemm_beta_kernel;
extern const char *gen9_gemm_compute_kernel;
extern const char *gen9_gemm_copy_kernel;
extern const char *gen9_gemm_nocopy_f16_kernel;
extern const char *gen9_gemm_nocopy_f32_kernel;
extern const char *gen9_gemm_nocopy_scale_x8x8s32_kernel;
extern const char *gen9_gemm_nocopy_superkernel_f32_kernel;
extern const char *gen9_gemm_nocopy_x8x8s32_kernel;
extern const char *ref_gemm_kernel;
extern const char *xe_hp_systolic_gemm_copy_kernel;
extern const char *xe_hpc_systolic_gemm_copy_kernel;
extern const char *xe_lp_gemm_nocopy_scale_x8x8s32_kernel;
extern const char *xe_lp_gemm_nocopy_x8x8s32_kernel;
extern const char *gemm_post_ops_inner_product_kernel;
extern const char *gen9_binary_kernel;
extern const char *gen9_bnorm_kernel;
extern const char *gen9_concat_kernel;
extern const char *gen9_conv_bwd_data_kernel;
extern const char *gen9_conv_bwd_weights_kernel;
extern const char *gen9_conv_dw_bwd_data_kernel;
extern const char *gen9_conv_dw_fwd_data_kernel;
extern const char *gen9_conv_fwd_data_kernel;
extern const char *gen9_conv_nhwc_bwd_data_kernel;
extern const char *gen9_conv_nhwc_bwd_weights_f32_kernel;
extern const char *gen9_conv_nhwc_fwd_data_kernel;
extern const char *gen9_eltwise_kernel;
extern const char *gen9_global_pooling_kernel;
extern const char *gen9_pooling_kernel;
extern const char *gen9_reduction_kernel;
extern const char *gen9_softmax_kernel;
extern const char *gen9_sum_kernel;
extern const char *gen9_wino_conv_fwd_data_2x3_kernel;
extern const char *gen9_wino_conv_fwd_data_fused_kernel;
extern const char *generic_reorder_kernel;
extern const char *many_inputs_sum_kernel;
extern const char *ref_binary_kernel;
extern const char *ref_bnorm_kernel;
extern const char *ref_convolution_kernel;
extern const char *ref_deconv_backward_bias_kernel;
extern const char *ref_eltwise_kernel;
extern const char *ref_inner_product_kernel;
extern const char *ref_layer_normalization_kernel;
extern const char *ref_lrn_kernel;
extern const char *ref_matmul_kernel;
extern const char *ref_pooling_kernel;
extern const char *ref_prelu_kernel;
extern const char *ref_reduction_kernel;
extern const char *ref_reorder_kernel;
extern const char *ref_resampling_kernel;
extern const char *ref_shuffle_kernel;
extern const char *ref_softmax_kernel;
extern const char *ref_zero_pad_kernel;
extern const char *ref_rnn_kernel;
extern const char *rnn_reorder_kernel;
extern const char *simple_concat_kernel;
extern const char *simple_sum_kernel;
extern const char *vectorized_resampling_kernel;
extern const char *xe_lp_1x1_conv_fwd_data_x8s8x_kernel;
extern const char *xe_lp_conv_bwd_data_mb_block_x8s8x8_kernel;
extern const char *xe_lp_conv_bwd_data_x8s8x8_kernel;
extern const char *xe_lp_conv_dw_fwd_data_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_dw_fwd_data_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_first_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_fwd_data_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_first_x8s8x_kernel;
extern const char *xe_lp_conv_nhwc_fwd_x8s8x_kernel;
extern const char *xe_lp_nhwc_1x1_conv_fwd_x8s8x_kernel;
extern const char *xe_lp_x8s8x_compensation_kernel;

// Header text strings referenced by the header lookup table further down in
// this file, declared in the same order as that table lists them. As with the
// kernel sources, the definitions live in other translation units.
// NOTE(review): these identifiers were absent from the declarations and have
// been restored from the header table's values — confirm against whatever
// generates this file.
extern const char *binary_types_header;
extern const char *ocl_gemm_attrs_header;
extern const char *metrics_discovery_api_header;
extern const char *ocl_eltwise_header;
extern const char *ocl_math_utils_header;
extern const char *ocl_post_ops_header;
extern const char *ocl_scales_header;
extern const char *ocl_types_header;
extern const char *ocl_zero_points_header;
extern const char *offsets_header;
extern const char *reorder_common_header;
extern const char *rnn_types_header;
extern const char *zero_pad_struct_header;
118 | |
119 | const char *get_kernel_source(const char *name) { |
120 | static const std::unordered_map<std::string, const char *> kernel_list = { |
121 | |
122 | { "combined_reduce" , combined_reduction_kernel }, |
123 | { "custom_reorder" , custom_reorder_kernel }, |
124 | { "gemm_post_ops" , gemm_with_post_ops_kernel }, |
125 | { "gen9_gemm_beta" , gen9_gemm_beta_kernel }, |
126 | { "gen9_gemm_compute" , gen9_gemm_compute_kernel }, |
127 | { "gen9_gemm_copy" , gen9_gemm_copy_kernel }, |
128 | { "gen9_gemm_nocopy_f16" , gen9_gemm_nocopy_f16_kernel }, |
129 | { "gen9_gemm_nocopy_f16" , gen9_gemm_nocopy_f16_kernel }, |
130 | { "gen9_gemm_nocopy_f16" , gen9_gemm_nocopy_f16_kernel }, |
131 | { "gen9_gemm_nocopy_f16" , gen9_gemm_nocopy_f16_kernel }, |
132 | { "gen9_gemm_nocopy_f32" , gen9_gemm_nocopy_f32_kernel }, |
133 | { "gen9_gemm_nocopy_f32" , gen9_gemm_nocopy_f32_kernel }, |
134 | { "gen9_gemm_nocopy_f32" , gen9_gemm_nocopy_f32_kernel }, |
135 | { "gen9_gemm_nocopy_f32" , gen9_gemm_nocopy_f32_kernel }, |
136 | { "gen9_gemm_scale_x8x8s32" , gen9_gemm_nocopy_scale_x8x8s32_kernel }, |
137 | { "gen9_gemm_nocopy_superkernel_f32" , gen9_gemm_nocopy_superkernel_f32_kernel }, |
138 | { "gen9_gemm_nocopy_superkernel_f32" , gen9_gemm_nocopy_superkernel_f32_kernel }, |
139 | { "gen9_gemm_compute_x8x8s32" , gen9_gemm_nocopy_x8x8s32_kernel }, |
140 | { "ref_gemm" , ref_gemm_kernel }, |
141 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
142 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
143 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
144 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
145 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
146 | { "xe_hp_systolic_gemm_copy" , xe_hp_systolic_gemm_copy_kernel }, |
147 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
148 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
149 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
150 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
151 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
152 | { "xe_hpc_systolic_gemm_copy" , xe_hpc_systolic_gemm_copy_kernel }, |
153 | { "xe_lp_gemm_scale_x8x8s32" , xe_lp_gemm_nocopy_scale_x8x8s32_kernel }, |
154 | { "xe_lp_gemm_compute_x8x8s32" , xe_lp_gemm_nocopy_x8x8s32_kernel }, |
155 | { "xe_lp_gemm_compute_x8x8s32" , xe_lp_gemm_nocopy_x8x8s32_kernel }, |
156 | { "xe_lp_gemm_compute_x8x8s32" , xe_lp_gemm_nocopy_x8x8s32_kernel }, |
157 | { "xe_lp_gemm_compute_x8x8s32" , xe_lp_gemm_nocopy_x8x8s32_kernel }, |
158 | { "gemm_post_ops_inner_product" , gemm_post_ops_inner_product_kernel }, |
159 | { "gen9_binary" , gen9_binary_kernel }, |
160 | { "gen9_binary" , gen9_binary_kernel }, |
161 | { "gen9_binary" , gen9_binary_kernel }, |
162 | { "gen9_binary" , gen9_binary_kernel }, |
163 | { "gen9_fused_reduce_init" , gen9_bnorm_kernel }, |
164 | { "gen9_calc_mean_var" , gen9_bnorm_kernel }, |
165 | { "gen9_calc_mean_var" , gen9_bnorm_kernel }, |
166 | { "gen9_reduce_mean_var" , gen9_bnorm_kernel }, |
167 | { "gen9_fused_reduce_final" , gen9_bnorm_kernel }, |
168 | { "gen9_calc_mean" , gen9_bnorm_kernel }, |
169 | { "gen9_calc_mean" , gen9_bnorm_kernel }, |
170 | { "gen9_reduce_mean" , gen9_bnorm_kernel }, |
171 | { "gen9_calc_variance" , gen9_bnorm_kernel }, |
172 | { "gen9_calc_variance" , gen9_bnorm_kernel }, |
173 | { "gen9_reduce_variance" , gen9_bnorm_kernel }, |
174 | { "gen9_bnorm_fwd" , gen9_bnorm_kernel }, |
175 | { "gen9_bnorm_fwd" , gen9_bnorm_kernel }, |
176 | { "gen9_calculate_stats" , gen9_bnorm_kernel }, |
177 | { "gen9_calculate_stats" , gen9_bnorm_kernel }, |
178 | { "gen9_reduce_stats" , gen9_bnorm_kernel }, |
179 | { "gen9_fused_reduce_final" , gen9_bnorm_kernel }, |
180 | { "gen9_bnorm_bwd" , gen9_bnorm_kernel }, |
181 | { "gen9_bnorm_bwd" , gen9_bnorm_kernel }, |
182 | { "gen9_concat" , gen9_concat_kernel }, |
183 | { "gen9_conv_bwd_data" , gen9_conv_bwd_data_kernel }, |
184 | { "gen9_conv_bwd_weights" , gen9_conv_bwd_weights_kernel }, |
185 | { "gen9_conv_dw_bwd_data" , gen9_conv_dw_bwd_data_kernel }, |
186 | { "gen9_conv_dw_fwd" , gen9_conv_dw_fwd_data_kernel }, |
187 | { "gen9_conv_fwd" , gen9_conv_fwd_data_kernel }, |
188 | { "gen9_conv_nhwc_bwd_data" , gen9_conv_nhwc_bwd_data_kernel }, |
189 | { "gen9_conv_nhwc_bwd_weights" , gen9_conv_nhwc_bwd_weights_f32_kernel }, |
190 | { "gen9_conv_nhwc_fwd" , gen9_conv_nhwc_fwd_data_kernel }, |
191 | { "gen9_eltwise_fwd" , gen9_eltwise_kernel }, |
192 | { "gen9_eltwise_bwd" , gen9_eltwise_kernel }, |
193 | { "gen9_global_pooling_fwd" , gen9_global_pooling_kernel }, |
194 | { "gen9_global_pooling_bwd" , gen9_global_pooling_kernel }, |
195 | { "gen9_pooling_fwd" , gen9_pooling_kernel }, |
196 | { "gen9_pooling_bwd" , gen9_pooling_kernel }, |
197 | { "gen9_initial_reduce" , gen9_reduction_kernel }, |
198 | { "gen9_final_reduce" , gen9_reduction_kernel }, |
199 | { "gen9_softmax_fwd" , gen9_softmax_kernel }, |
200 | { "gen9_softmax_bwd" , gen9_softmax_kernel }, |
201 | { "gen9_sum" , gen9_sum_kernel }, |
202 | { "gen9_wino_wei_transform_2x3" , gen9_wino_conv_fwd_data_2x3_kernel }, |
203 | { "gen9_wino_src_transform_2x3" , gen9_wino_conv_fwd_data_2x3_kernel }, |
204 | { "gen9_wino_dst_transform_2x3" , gen9_wino_conv_fwd_data_2x3_kernel }, |
205 | { "gen9_wino_conv_fwd_2x3" , gen9_wino_conv_fwd_data_2x3_kernel }, |
206 | { "gen9_wino_wei_transform" , gen9_wino_conv_fwd_data_fused_kernel }, |
207 | { "gen9_wino_conv_fwd" , gen9_wino_conv_fwd_data_fused_kernel }, |
208 | { "generic_reorder" , generic_reorder_kernel }, |
209 | { "many_inputs_sum" , many_inputs_sum_kernel }, |
210 | { "many_inputs_sum_batched" , many_inputs_sum_kernel }, |
211 | { "ref_binary" , ref_binary_kernel }, |
212 | { "ref_binary" , ref_binary_kernel }, |
213 | { "calculate_mean" , ref_bnorm_kernel }, |
214 | { "calculate_variance" , ref_bnorm_kernel }, |
215 | { "calculate_mean_variance" , ref_bnorm_kernel }, |
216 | { "calculate_mean" , ref_bnorm_kernel }, |
217 | { "calculate_variance" , ref_bnorm_kernel }, |
218 | { "reduce_mean" , ref_bnorm_kernel }, |
219 | { "reduce_variance" , ref_bnorm_kernel }, |
220 | { "ref_bnorm_fwd" , ref_bnorm_kernel }, |
221 | { "calculate_stats" , ref_bnorm_kernel }, |
222 | { "reduce_stats" , ref_bnorm_kernel }, |
223 | { "calculate_stats" , ref_bnorm_kernel }, |
224 | { "reduce_stats" , ref_bnorm_kernel }, |
225 | { "ref_bnorm_bwd" , ref_bnorm_kernel }, |
226 | { "ref_convolution_fwd" , ref_convolution_kernel }, |
227 | { "ref_convolution_bwd_data" , ref_convolution_kernel }, |
228 | { "ref_convolution_bwd_weights" , ref_convolution_kernel }, |
229 | { "ref_deconv_backward_bias" , ref_deconv_backward_bias_kernel }, |
230 | { "ref_eltwise_fwd" , ref_eltwise_kernel }, |
231 | { "ref_eltwise_bwd" , ref_eltwise_kernel }, |
232 | { "ref_inner_product_fwd" , ref_inner_product_kernel }, |
233 | { "ref_inner_product_bwd_data" , ref_inner_product_kernel }, |
234 | { "ref_inner_product_bwd_weights" , ref_inner_product_kernel }, |
235 | { "ref_lnorm_fwd" , ref_layer_normalization_kernel }, |
236 | { "ref_lnorm_fwd" , ref_layer_normalization_kernel }, |
237 | { "ref_lnorm_bwd_scaleshift" , ref_layer_normalization_kernel }, |
238 | { "ref_lnorm_bwd_scaleshift_final" , ref_layer_normalization_kernel }, |
239 | { "ref_lnorm_bwd_scaleshift" , ref_layer_normalization_kernel }, |
240 | { "ref_lnorm_bwd" , ref_layer_normalization_kernel }, |
241 | { "ref_lnorm_bwd" , ref_layer_normalization_kernel }, |
242 | { "ref_lrn_fwd" , ref_lrn_kernel }, |
243 | { "ref_lrn_bwd" , ref_lrn_kernel }, |
244 | { "ref_matmul" , ref_matmul_kernel }, |
245 | { "ref_pooling_fwd" , ref_pooling_kernel }, |
246 | { "ref_pooling_bwd" , ref_pooling_kernel }, |
247 | { "ref_prelu_fwd" , ref_prelu_kernel }, |
248 | { "ref_prelu_bwd" , ref_prelu_kernel }, |
249 | { "ref_reduce" , ref_reduction_kernel }, |
250 | { "ref_reorder" , ref_reorder_kernel }, |
251 | { "ref_resampling_fwd" , ref_resampling_kernel }, |
252 | { "ref_resampling_bwd" , ref_resampling_kernel }, |
253 | { "ref_shuffle" , ref_shuffle_kernel }, |
254 | { "ref_softmax_fwd_generic" , ref_softmax_kernel }, |
255 | { "ref_softmax_bwd_generic" , ref_softmax_kernel }, |
256 | { "ref_zero_pad" , ref_zero_pad_kernel }, |
257 | { "ref_zero_pad_subg_16" , ref_zero_pad_kernel }, |
258 | { "ref_zero_pad_subg_16_mask_and_clear_dt_1b" , ref_zero_pad_kernel }, |
259 | { "ref_rnn_copy_init_layer" , ref_rnn_kernel }, |
260 | { "ref_rnn_copy_init_iter" , ref_rnn_kernel }, |
261 | { "ref_rnn_copy_res_layer" , ref_rnn_kernel }, |
262 | { "ref_rnn_copy_res_iter" , ref_rnn_kernel }, |
263 | { "ref_rnn_ws_set" , ref_rnn_kernel }, |
264 | { "ref_rnn_ws_print" , ref_rnn_kernel }, |
265 | { "ref_rnn_bias_prepare" , ref_rnn_kernel }, |
266 | { "ref_rnn_elemwise_fwd" , ref_rnn_kernel }, |
267 | { "ref_rnn_elemwise_fwd" , ref_rnn_kernel }, |
268 | { "ref_rnn_elemwise_bwd" , ref_rnn_kernel }, |
269 | { "ref_rnn_gates_reduction" , ref_rnn_kernel }, |
270 | { "wei_reorder" , rnn_reorder_kernel }, |
271 | { "simple_concat" , simple_concat_kernel }, |
272 | { "simple_sum" , simple_sum_kernel }, |
273 | { "vectorized_resampling_bwd" , vectorized_resampling_kernel }, |
274 | { "xe_lp_1x1_conv_fwd_x8s8x" , xe_lp_1x1_conv_fwd_data_x8s8x_kernel }, |
275 | { "conv_bwd_data_mb_block_x8s8x8" , xe_lp_conv_bwd_data_mb_block_x8s8x8_kernel }, |
276 | { "conv_bwd_data_x8s8x8" , xe_lp_conv_bwd_data_x8s8x8_kernel }, |
277 | { "conv_dw_fwd_mb_block_x8s8x" , xe_lp_conv_dw_fwd_data_mb_block_x8s8x_kernel }, |
278 | { "conv_dw_fwd_ow_block_x8s8x" , xe_lp_conv_dw_fwd_data_ow_block_x8s8x_kernel }, |
279 | { "conv_fwd_first_x8s8x" , xe_lp_conv_fwd_data_first_x8s8x_kernel }, |
280 | { "conv_fwd_mb_block_x8s8x" , xe_lp_conv_fwd_data_mb_block_x8s8x_kernel }, |
281 | { "conv_fwd_ow_block_x8s8x" , xe_lp_conv_fwd_data_ow_block_x8s8x_kernel }, |
282 | { "conv_nhwc_fwd_dw_mb_block_x8s8x" , xe_lp_conv_nhwc_fwd_dw_mb_block_x8s8x_kernel }, |
283 | { "conv_nhwc_fwd_dw_ow_block_x8s8x" , xe_lp_conv_nhwc_fwd_dw_ow_block_x8s8x_kernel }, |
284 | { "conv_nhwc_fwd_first_x8s8x" , xe_lp_conv_nhwc_fwd_first_x8s8x_kernel }, |
285 | { "conv_nhwc_fwd_x8s8x" , xe_lp_conv_nhwc_fwd_x8s8x_kernel }, |
286 | { "xe_lp_nhwc_1x1_conv_fwd_x8s8x" , xe_lp_nhwc_1x1_conv_fwd_x8s8x_kernel }, |
287 | { "xe_lp_x8s8x_compensation" , xe_lp_x8s8x_compensation_kernel }, |
288 | { "xe_lp_x8s8x_compensation" , xe_lp_x8s8x_compensation_kernel }, |
289 | }; |
290 | |
291 | if (!name) return nullptr; |
292 | |
293 | assert(kernel_list.count(name) == 1); |
294 | return kernel_list.at(name); |
295 | } |
296 | |
297 | const char *(const std::string &name) { |
298 | static const std::unordered_map<std::string, const char *> ={ |
299 | |
300 | {"gpu/ocl/binary_types.h" , binary_types_header}, |
301 | {"gpu/ocl/gemm/ocl_gemm_attrs.h" , ocl_gemm_attrs_header}, |
302 | {"gpu/ocl/mdapi/metrics_discovery_api.h" , metrics_discovery_api_header}, |
303 | {"gpu/ocl/ocl_eltwise.h" , ocl_eltwise_header}, |
304 | {"gpu/ocl/ocl_math_utils.h" , ocl_math_utils_header}, |
305 | {"gpu/ocl/ocl_post_ops.h" , ocl_post_ops_header}, |
306 | {"gpu/ocl/ocl_scales.h" , ocl_scales_header}, |
307 | {"gpu/ocl/ocl_types.h" , ocl_types_header}, |
308 | {"gpu/ocl/ocl_zero_points.h" , ocl_zero_points_header}, |
309 | {"gpu/ocl/offsets.h" , offsets_header}, |
310 | {"gpu/ocl/reorder_common.h" , reorder_common_header}, |
311 | {"gpu/ocl/rnn/rnn_types.h" , rnn_types_header}, |
312 | {"gpu/zero_pad_struct.h" , zero_pad_struct_header}, |
313 | }; |
314 | |
315 | assert(kernel_header_list.count(name) == 1); |
316 | return kernel_header_list.at(name); |
317 | } |
318 | |
319 | } // namespace ocl |
320 | } // namespace gpu |
321 | } // namespace impl |
322 | } // namespace dnnl |
323 | |