1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_OCL_GPU_KERNEL_HPP
18#define GPU_OCL_OCL_GPU_KERNEL_HPP
19
20#include <assert.h>
21#include <string>
22#include <CL/cl.h>
23
24#include "gpu/compute/compute.hpp"
25#include "gpu/ocl/ocl_utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace ocl {
31
32class ocl_gpu_kernel_cache_t;
33
34class ocl_gpu_kernel_t : public compute::kernel_impl_t {
35public:
36 ocl_gpu_kernel_t(const std::shared_ptr<compute::binary_t> &binary,
37 const std::string &kernel_name,
38 const std::vector<gpu::compute::scalar_type_t> &arg_types)
39 : state_(state_t::binary)
40 , ocl_kernel_(nullptr)
41 , binary_(binary)
42 , binary_size_(binary->size())
43 , kernel_name_(kernel_name)
44 , arg_types_(arg_types) {
45 MAYBE_UNUSED(state_);
46 }
47
48 ~ocl_gpu_kernel_t() override;
49
50 cl_kernel ocl_kernel() const {
51 assert(state_ == state_t::kernel);
52 return ocl_kernel_;
53 }
54
55 status_t parallel_for(stream_t &stream, const compute::nd_range_t &range,
56 const compute::kernel_arg_list_t &arg_list) override;
57
58 status_t realize(compute::kernel_t *kernel, const engine_t *engine,
59 compute::program_list_t *programs) const override;
60
61 const char *name() const {
62 assert(state_ == state_t::binary);
63 return kernel_name_.c_str();
64 }
65
66 const std::shared_ptr<compute::binary_t> &binary() const override {
67 assert(state_ == state_t::binary);
68 return binary_;
69 }
70
71 status_t binary(engine_t *engine, compute::binary_t &binary) const override;
72
73 const std::vector<gpu::compute::scalar_type_t> &arg_types() const override {
74 return arg_types_;
75 }
76
77 void clear() override {
78 assert(state_ == state_t::binary);
79 binary_->clear();
80 kernel_name_.clear();
81 arg_types_.clear();
82 }
83
84 status_t binary_size(size_t *binary_size) const override {
85 (*binary_size) = binary_size_;
86 return status::success;
87 }
88
89 enum class state_t { binary, kernel };
90
91protected:
92 ocl_gpu_kernel_t(cl_kernel ocl_kernel,
93 const std::vector<gpu::compute::scalar_type_t> &arg_types);
94
95 state_t state_;
96 cl_kernel ocl_kernel_;
97 std::shared_ptr<compute::binary_t> binary_;
98 // The binary_ is cleared via `clear()` to reduce memory footprint. Because
99 // of that the binary size is stored separately to avoid querying it.
100 size_t binary_size_;
101 std::string kernel_name_;
102
103 std::vector<gpu::compute::scalar_type_t> arg_types_;
104
105 std::shared_ptr<ocl_gpu_kernel_cache_t> cache_;
106};
107
108} // namespace ocl
109} // namespace gpu
110} // namespace impl
111} // namespace dnnl
112
113#endif // GPU_OCL_OCL_GPU_KERNEL_HPP
114