1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_COMPUTE_KERNEL_HPP
18#define GPU_COMPUTE_KERNEL_HPP
19
20#include <memory>
21#include <utility>
22
23#include "common/stream.hpp"
24#include "gpu/compute/kernel_arg_list.hpp"
25#include "gpu/compute/utils.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace compute {
31
32class program_list_t;
33class kernel_impl_t;
34
35class kernel_t {
36public:
37 using id_t = intptr_t;
38 kernel_t(kernel_impl_t *impl) : impl_(impl) {}
39
40 kernel_t() = default;
41 kernel_t(kernel_t &&other) = default;
42 kernel_t(const kernel_t &other) = default;
43 kernel_t &operator=(const kernel_t &other) = default;
44 virtual ~kernel_t() = default;
45
46 operator bool() const { return bool(impl_); }
47 id_t id() const;
48
49 kernel_impl_t *impl() const { return impl_.get(); }
50
51 status_t parallel_for(stream_t &stream, const nd_range_t &range,
52 const kernel_arg_list_t &arg_list) const;
53
54 status_t parallel_for(
55 stream_t &stream, const std::function<void(void *)> &cgf) const;
56
57 status_t realize(kernel_t *kernel, const engine_t *engine,
58 program_list_t *programs) const;
59
60 void clear();
61 status_t binary_size(size_t *binary_size) const;
62
63 status_t binary(engine_t *engine, compute::binary_t &binary) const;
64 const std::shared_ptr<compute::binary_t> &binary() const;
65
66 const std::vector<scalar_type_t> &arg_types() const;
67
68private:
69 std::shared_ptr<kernel_impl_t> impl_;
70};
71
72class kernel_impl_t {
73public:
74 kernel_impl_t() = default;
75
76 kernel_impl_t(const kernel_impl_t &) = delete;
77 kernel_impl_t &operator=(const kernel_impl_t &) = delete;
78 virtual ~kernel_impl_t() = default;
79
80 virtual status_t parallel_for(stream_t &stream, const nd_range_t &range,
81 const kernel_arg_list_t &arg_list) {
82 assert(!"unexpected");
83 return status::runtime_error;
84 }
85
86 virtual status_t parallel_for(
87 stream_t &stream, const std::function<void(void *)> &cgf) {
88 assert(!"unexpected");
89 return status::runtime_error;
90 }
91
92 virtual status_t realize(kernel_t *kernel, const engine_t *engine,
93 program_list_t *programs) const {
94 return status::success;
95 }
96
97 virtual void clear() {}
98
99 virtual status_t binary_size(size_t *binary_size) const {
100 assert(!"unexpected");
101 return status::runtime_error;
102 }
103 virtual status_t binary(engine_t *engine, compute::binary_t &binary) const {
104 assert(!"unexpected");
105 return status::runtime_error;
106 }
107
108 virtual const std::vector<scalar_type_t> &arg_types() const {
109 static const std::vector<scalar_type_t> dummy;
110 return dummy;
111 }
112
113 virtual const std::shared_ptr<compute::binary_t> &binary() const {
114 static const std::shared_ptr<compute::binary_t> dummy;
115 return dummy;
116 }
117};
118
119inline kernel_t::id_t kernel_t::id() const {
120 return reinterpret_cast<id_t>(impl_.get());
121}
122inline status_t kernel_t::parallel_for(stream_t &stream,
123 const nd_range_t &range, const kernel_arg_list_t &arg_list) const {
124 return impl_->parallel_for(stream, range, arg_list);
125}
126inline status_t kernel_t::parallel_for(
127 stream_t &stream, const std::function<void(void *)> &cgf) const {
128 return impl_->parallel_for(stream, cgf);
129}
130
131inline status_t kernel_t::realize(kernel_t *kernel, const engine_t *engine,
132 program_list_t *programs) const {
133 return impl_->realize(kernel, engine, programs);
134}
135
136inline void kernel_t::clear() {
137 impl_->clear();
138}
139
140inline status_t kernel_t::binary_size(size_t *binary_size) const {
141 return impl_->binary_size(binary_size);
142}
143
144inline status_t kernel_t::binary(
145 engine_t *engine, compute::binary_t &binary) const {
146 return impl_->binary(engine, binary);
147}
148
149inline const std::shared_ptr<compute::binary_t> &kernel_t::binary() const {
150 return impl_->binary();
151}
152
153inline const std::vector<scalar_type_t> &kernel_t::arg_types() const {
154 return impl_->arg_types();
155}
156
157} // namespace compute
158} // namespace gpu
159} // namespace impl
160} // namespace dnnl
161
162#endif // GPU_COMPUTE_KERNEL_HPP
163