1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_COMPUTE_KERNEL_ARG_LIST_HPP |
18 | #define GPU_COMPUTE_KERNEL_ARG_LIST_HPP |
19 | |
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <new>
#include <type_traits>
#include <vector>

#include "common/bfloat16.hpp"
#include "common/float16.hpp"
#include "common/memory_storage.hpp"
#include "common/nstl.hpp"
#include "common/verbose.hpp"

#include "gpu/zero_pad_struct.h"
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace compute { |
36 | |
// What a kernel_arg_t currently holds.
enum class kernel_arg_kind_t {
    undef = 0, // Not set yet (initial state of kernel_arg_t).
    global = 1, // Reference to a memory_storage_t.
    local = 2, // Local memory argument: only a size, no value.
    scalar = 3, // Copy of a scalar value.
    svm = 4, // Raw SVM pointer.
};
44 | |
// Type tag of a scalar kernel argument (see scalar_type_traits for the
// mapping from C++ types). `undef` means "no type / type unknown".
enum class scalar_type_t {
    undef = 0,
    _char = 1,
    _bfloat16 = 2,
    _float = 3,
    _half = 4,
    _int = 5,
    _long = 6,
    _short = 7,
    _uchar = 8,
    _uint = 9,
    _ulong = 10,
    _ushort = 11,
    _zero_pad_mask_t = 12,
};
60 | |
61 | template <typename T> |
62 | struct scalar_type_traits {}; |
63 | |
64 | template <> |
65 | struct scalar_type_traits<float16_t> { |
66 | static const auto type = scalar_type_t::_half; |
67 | }; |
68 | template <> |
69 | struct scalar_type_traits<bfloat16_t> { |
70 | static const auto type = scalar_type_t::_bfloat16; |
71 | }; |
72 | template <> |
73 | struct scalar_type_traits<float> { |
74 | static const auto type = scalar_type_t::_float; |
75 | }; |
76 | |
77 | template <> |
78 | struct scalar_type_traits<uint8_t> { |
79 | static const auto type = scalar_type_t::_uchar; |
80 | }; |
81 | template <> |
82 | struct scalar_type_traits<uint16_t> { |
83 | static const auto type = scalar_type_t::_ushort; |
84 | }; |
85 | template <> |
86 | struct scalar_type_traits<uint32_t> { |
87 | static const auto type = scalar_type_t::_uint; |
88 | }; |
89 | template <> |
90 | struct scalar_type_traits<uint64_t> { |
91 | static const auto type = scalar_type_t::_ulong; |
92 | }; |
93 | |
94 | template <> |
95 | struct scalar_type_traits<int8_t> { |
96 | static const auto type = scalar_type_t::_char; |
97 | }; |
98 | template <> |
99 | struct scalar_type_traits<int16_t> { |
100 | static const auto type = scalar_type_t::_short; |
101 | }; |
102 | template <> |
103 | struct scalar_type_traits<int32_t> { |
104 | static const auto type = scalar_type_t::_int; |
105 | }; |
106 | template <> |
107 | struct scalar_type_traits<int64_t> { |
108 | static const auto type = scalar_type_t::_long; |
109 | }; |
110 | template <> |
111 | struct scalar_type_traits<zero_pad_mask_t> { |
112 | static const auto type = scalar_type_t::_zero_pad_mask_t; |
113 | }; |
114 | |
115 | class kernel_arg_t { |
116 | public: |
117 | kernel_arg_kind_t kind() const { return kind_; } |
118 | scalar_type_t scalar_type() const { return scalar_type_; } |
119 | size_t size() const { return size_; } |
120 | |
121 | bool is_global() const { return kind() == kernel_arg_kind_t::global; } |
122 | bool is_local() const { return kind() == kernel_arg_kind_t::local; } |
123 | bool is_svm_pointer() const { return kind_ == kernel_arg_kind_t::svm; } |
124 | |
125 | kernel_arg_t &set_value(const memory_storage_t &storage) { |
126 | kind_ = kernel_arg_kind_t::global; |
127 | size_ = 0; |
128 | value_ = static_cast<const void *>(&storage); |
129 | return *this; |
130 | } |
131 | |
132 | template <typename T> |
133 | kernel_arg_t &set_value(const T &value, void *&data_pool) { |
134 | assert(size_ <= sizeof(T)); |
135 | if (value_ == nullptr) { |
136 | assert(data_pool != nullptr); |
137 | size_ = sizeof(T); |
138 | data_pool = utils::align_ptr(data_pool, alignof(T)); |
139 | value_ = data_pool; |
140 | data_pool = static_cast<char *>(data_pool) + size_; |
141 | } |
142 | kind_ = kernel_arg_kind_t::scalar; |
143 | scalar_type_ = scalar_type_traits<T>::type; |
144 | new (const_cast<void *>(value_)) T(value); |
145 | return *this; |
146 | } |
147 | |
148 | kernel_arg_t &set_value(size_t size, std::nullptr_t) { |
149 | kind_ = kernel_arg_kind_t::local; |
150 | size_ = size; |
151 | value_ = nullptr; |
152 | return *this; |
153 | } |
154 | |
155 | void set_value(void *svm_ptr, kernel_arg_kind_t kind) { |
156 | assert(kind == kernel_arg_kind_t::svm); |
157 | kind_ = kernel_arg_kind_t::svm; |
158 | size_ = 0; |
159 | value_ = svm_ptr; |
160 | } |
161 | |
162 | const void *value() const { |
163 | assert(kind() != kernel_arg_kind_t::undef); |
164 | return value_; |
165 | } |
166 | |
167 | template <typename T> |
168 | T as() const { |
169 | assert(kind() == kernel_arg_kind_t::scalar); |
170 | assert(scalar_type() == scalar_type_traits<T>::type); |
171 | return *(const T *)value(); |
172 | } |
173 | |
174 | private: |
175 | kernel_arg_kind_t kind_ = kernel_arg_kind_t::undef; |
176 | scalar_type_t scalar_type_ = scalar_type_t::undef; |
177 | size_t size_ = 0; |
178 | const void *value_ = nullptr; |
179 | }; |
180 | |
181 | class kernel_arg_list_t { |
182 | public: |
183 | kernel_arg_list_t() { nargs_ = 0; } |
184 | void set(int index, const memory_storage_t &storage) { |
185 | assert(index < max_args); |
186 | nargs_ = nstl::max(nargs_, index + 1); |
187 | args_[index].set_value(storage); |
188 | } |
189 | |
190 | void set(int index, void *value, kernel_arg_kind_t kind) { |
191 | assert(index < max_args); |
192 | nargs_ = nstl::max(nargs_, index + 1); |
193 | args_[index].set_value(value, kind); |
194 | } |
195 | |
196 | template <class T> |
197 | void set(int index, const T &value) { |
198 | assert(index < max_args); |
199 | nargs_ = nstl::max(nargs_, index + 1); |
200 | args_[index].set_value(value, unused_storage); |
201 | |
202 | assert(unused_storage |
203 | <= reinterpret_cast<char *>(&scalar_storage_) + storage_size); |
204 | } |
205 | |
206 | void set(int index, size_t size, std::nullptr_t) { |
207 | assert(index < max_args); |
208 | nargs_ = nstl::max(nargs_, index + 1); |
209 | args_[index].set_value(size, nullptr); |
210 | } |
211 | |
212 | int nargs() const { return nargs_; } |
213 | |
214 | const kernel_arg_t &get(int index) const { |
215 | assert(index < nargs()); |
216 | return args_[index]; |
217 | } |
218 | |
219 | const memory_storage_t &get_memory_storage(int index) const { |
220 | assert(args_[index].kind() == kernel_arg_kind_t::global); |
221 | return *static_cast<const memory_storage_t *>(args_[index].value()); |
222 | } |
223 | |
224 | private: |
225 | static constexpr int max_args = 96; |
226 | static constexpr int storage_size = 512; |
227 | static constexpr int storage_alginment = 8; |
228 | |
229 | int nargs_ = 0; |
230 | kernel_arg_t args_[max_args]; |
231 | typename std::aligned_storage<storage_size, storage_alginment>::type |
232 | scalar_storage_; |
233 | void *unused_storage = &scalar_storage_; |
234 | |
235 | kernel_arg_list_t(const kernel_arg_list_t &) = delete; |
236 | kernel_arg_list_t(kernel_arg_list_t &&) = delete; |
237 | kernel_arg_list_t &operator=(const kernel_arg_list_t &) = delete; |
238 | kernel_arg_list_t &operator=(kernel_arg_list_t &&) = delete; |
239 | }; |
240 | |
241 | template <typename T> |
242 | void set_scalar_arg_cvt(kernel_arg_list_t &arg_list, int index, T scalar, |
243 | scalar_type_t requested_type) { |
244 | if (scalar_type_traits<T>::type == requested_type) { |
245 | arg_list.set(index, scalar); |
246 | return; |
247 | } |
248 | |
249 | switch (requested_type) { |
250 | case scalar_type_t::_half: |
251 | arg_list.set(index, (float16_t)scalar); |
252 | break; |
253 | case scalar_type_t::_uchar: arg_list.set(index, (uint8_t)scalar); break; |
254 | case scalar_type_t::_char: arg_list.set(index, (int8_t)scalar); break; |
255 | default: assert(!"Cannot convert scalar to the requested type." ); |
256 | } |
257 | } |
258 | |
259 | inline status_t check_scalar_arguments(const kernel_arg_list_t &arg_list, |
260 | const std::vector<scalar_type_t> &arg_types) { |
261 | for (int i = 0; i < arg_list.nargs(); i++) { |
262 | auto &arg = arg_list.get(i); |
263 | auto req_arg_type = arg_types[i]; |
264 | |
265 | if (!arg.is_global() && !arg.is_local() && !arg.is_svm_pointer()) { |
266 | if (req_arg_type == gpu::compute::scalar_type_t::undef) { |
267 | // Types of kernel arguments may not be available when zebin |
268 | // is used. |
269 | continue; |
270 | } |
271 | |
272 | if (req_arg_type != arg.scalar_type()) { |
273 | if (get_verbose()) { |
274 | printf("onednn_verbose,gpu,error,type of a scalar kernel " |
275 | "argument #%d is different from the type of the " |
276 | "given scalar\n" , |
277 | i); |
278 | fflush(nullptr); |
279 | } |
280 | return status::invalid_arguments; |
281 | } |
282 | } |
283 | } |
284 | return status::success; |
285 | } |
286 | |
287 | } // namespace compute |
288 | } // namespace gpu |
289 | } // namespace impl |
290 | } // namespace dnnl |
291 | |
292 | #endif // GPU_COMPUTE_KERNEL_ARG_LIST_HPP |
293 | |