1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_KERNEL_UTILS_HPP
18#define GPU_OCL_KERNEL_UTILS_HPP
19
20#include <vector>
21#include <unordered_map>
22
23#include "gpu/compute/compute.hpp"
24#include "gpu/ocl/ocl_gpu_engine.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace ocl {
30
// Looks up the source string for the kernel called `name`. Only declared
// here; the definition lives elsewhere. It is the default lookup function
// used by the three-argument create_kernels() overload in this header.
// NOTE(review): ownership/lifetime of the returned pointer is not visible
// from this file - confirm against the definition.
const char *get_kernel_source(const char *name);
// Looks up a kernel header by `name`. Only declared here; the definition
// lives elsewhere. NOTE(review): behavior for unknown names is not visible
// from this file - confirm against the definition.
const char *get_kernel_header(const std::string &name);
33
34template <typename GetKernelSourceFunc>
35status_t create_kernels(const compute::compute_engine_t *engine,
36 compute::kernel_list_t &kernel_list,
37 const compute::kernel_ctx_t &kernel_ctx,
38 const GetKernelSourceFunc &get_kernel_source_func) {
39 auto *ocl_engine = utils::downcast<const ocl::ocl_gpu_engine_t *>(engine);
40
41 // Group kernels by their source.
42 std::unordered_map<const char *, std::vector<const char *>> source_to_names;
43 for (auto &kv : kernel_list.kernels()) {
44 auto &name = kv.first;
45 const char *source = get_kernel_source_func(name.c_str());
46 source_to_names[source].push_back(name.c_str());
47 }
48
49 // Iterate through sources, create all kernels for the current source.
50 for (auto &kv : source_to_names) {
51 std::vector<compute::kernel_t> kernels;
52 CHECK(ocl_engine->create_kernels_from_ocl_source(
53 &kernels, kv.second, kv.first, kernel_ctx));
54
55 // Update kernel list with created kernels.
56 for (size_t i = 0; i < kv.second.size(); ++i) {
57 kernel_list.set(kv.second[i], kernels[i]);
58 }
59 }
60 return status::success;
61}
62
63inline status_t create_kernels(const compute::compute_engine_t *engine,
64 compute::kernel_list_t &kernel_list,
65 const compute::kernel_ctx_t &kernel_ctx) {
66 return create_kernels(
67 engine, kernel_list, kernel_ctx, ocl::get_kernel_source);
68}
69
70} // namespace ocl
71} // namespace gpu
72} // namespace impl
73} // namespace dnnl
74
75#endif
76