1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_COMPUTE_KERNEL_HPP |
18 | #define GPU_COMPUTE_KERNEL_HPP |
19 | |
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "common/stream.hpp"
#include "gpu/compute/kernel_arg_list.hpp"
#include "gpu/compute/utils.hpp"
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace compute { |
31 | |
32 | class program_list_t; |
33 | class kernel_impl_t; |
34 | |
35 | class kernel_t { |
36 | public: |
37 | using id_t = intptr_t; |
38 | kernel_t(kernel_impl_t *impl) : impl_(impl) {} |
39 | |
40 | kernel_t() = default; |
41 | kernel_t(kernel_t &&other) = default; |
42 | kernel_t(const kernel_t &other) = default; |
43 | kernel_t &operator=(const kernel_t &other) = default; |
44 | virtual ~kernel_t() = default; |
45 | |
46 | operator bool() const { return bool(impl_); } |
47 | id_t id() const; |
48 | |
49 | kernel_impl_t *impl() const { return impl_.get(); } |
50 | |
51 | status_t parallel_for(stream_t &stream, const nd_range_t &range, |
52 | const kernel_arg_list_t &arg_list) const; |
53 | |
54 | status_t parallel_for( |
55 | stream_t &stream, const std::function<void(void *)> &cgf) const; |
56 | |
57 | status_t realize(kernel_t *kernel, const engine_t *engine, |
58 | program_list_t *programs) const; |
59 | |
60 | void clear(); |
61 | status_t binary_size(size_t *binary_size) const; |
62 | |
63 | status_t binary(engine_t *engine, compute::binary_t &binary) const; |
64 | const std::shared_ptr<compute::binary_t> &binary() const; |
65 | |
66 | const std::vector<scalar_type_t> &arg_types() const; |
67 | |
68 | private: |
69 | std::shared_ptr<kernel_impl_t> impl_; |
70 | }; |
71 | |
72 | class kernel_impl_t { |
73 | public: |
74 | kernel_impl_t() = default; |
75 | |
76 | kernel_impl_t(const kernel_impl_t &) = delete; |
77 | kernel_impl_t &operator=(const kernel_impl_t &) = delete; |
78 | virtual ~kernel_impl_t() = default; |
79 | |
80 | virtual status_t parallel_for(stream_t &stream, const nd_range_t &range, |
81 | const kernel_arg_list_t &arg_list) { |
82 | assert(!"unexpected" ); |
83 | return status::runtime_error; |
84 | } |
85 | |
86 | virtual status_t parallel_for( |
87 | stream_t &stream, const std::function<void(void *)> &cgf) { |
88 | assert(!"unexpected" ); |
89 | return status::runtime_error; |
90 | } |
91 | |
92 | virtual status_t realize(kernel_t *kernel, const engine_t *engine, |
93 | program_list_t *programs) const { |
94 | return status::success; |
95 | } |
96 | |
97 | virtual void clear() {} |
98 | |
99 | virtual status_t binary_size(size_t *binary_size) const { |
100 | assert(!"unexpected" ); |
101 | return status::runtime_error; |
102 | } |
103 | virtual status_t binary(engine_t *engine, compute::binary_t &binary) const { |
104 | assert(!"unexpected" ); |
105 | return status::runtime_error; |
106 | } |
107 | |
108 | virtual const std::vector<scalar_type_t> &arg_types() const { |
109 | static const std::vector<scalar_type_t> dummy; |
110 | return dummy; |
111 | } |
112 | |
113 | virtual const std::shared_ptr<compute::binary_t> &binary() const { |
114 | static const std::shared_ptr<compute::binary_t> dummy; |
115 | return dummy; |
116 | } |
117 | }; |
118 | |
119 | inline kernel_t::id_t kernel_t::id() const { |
120 | return reinterpret_cast<id_t>(impl_.get()); |
121 | } |
122 | inline status_t kernel_t::parallel_for(stream_t &stream, |
123 | const nd_range_t &range, const kernel_arg_list_t &arg_list) const { |
124 | return impl_->parallel_for(stream, range, arg_list); |
125 | } |
126 | inline status_t kernel_t::parallel_for( |
127 | stream_t &stream, const std::function<void(void *)> &cgf) const { |
128 | return impl_->parallel_for(stream, cgf); |
129 | } |
130 | |
131 | inline status_t kernel_t::realize(kernel_t *kernel, const engine_t *engine, |
132 | program_list_t *programs) const { |
133 | return impl_->realize(kernel, engine, programs); |
134 | } |
135 | |
136 | inline void kernel_t::clear() { |
137 | impl_->clear(); |
138 | } |
139 | |
140 | inline status_t kernel_t::binary_size(size_t *binary_size) const { |
141 | return impl_->binary_size(binary_size); |
142 | } |
143 | |
144 | inline status_t kernel_t::binary( |
145 | engine_t *engine, compute::binary_t &binary) const { |
146 | return impl_->binary(engine, binary); |
147 | } |
148 | |
149 | inline const std::shared_ptr<compute::binary_t> &kernel_t::binary() const { |
150 | return impl_->binary(); |
151 | } |
152 | |
153 | inline const std::vector<scalar_type_t> &kernel_t::arg_types() const { |
154 | return impl_->arg_types(); |
155 | } |
156 | |
157 | } // namespace compute |
158 | } // namespace gpu |
159 | } // namespace impl |
160 | } // namespace dnnl |
161 | |
162 | #endif // GPU_COMPUTE_KERNEL_HPP |
163 | |