1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_COMPUTE_KERNEL_ARG_LIST_HPP
18#define GPU_COMPUTE_KERNEL_ARG_LIST_HPP
19
20#include <cassert>
21#include <cstddef>
22#include <type_traits>
23
24#include "common/bfloat16.hpp"
25#include "common/float16.hpp"
26#include "common/memory_storage.hpp"
27#include "common/nstl.hpp"
28#include "common/verbose.hpp"
29
30#include "gpu/zero_pad_struct.h"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace compute {
36
// Category of a kernel argument held by kernel_arg_t:
//   undef  - never assigned
//   global - refers to a memory_storage_t object
//   local  - a local-memory size with no value attached
//   scalar - a copied scalar value
//   svm    - a raw SVM pointer
enum class kernel_arg_kind_t : int { undef, global, local, scalar, svm };
44
// Type tag of a scalar kernel argument. `undef` means no/unknown type.
// Enumerator order (and therefore each numeric value) is kept unchanged.
enum class scalar_type_t : int {
    undef, //  0
    _char, //  1
    _bfloat16, //  2
    _float, //  3
    _half, //  4
    _int, //  5
    _long, //  6
    _short, //  7
    _uchar, //  8
    _uint, //  9
    _ulong, // 10
    _ushort, // 11
    _zero_pad_mask_t, // 12
};
60
// Maps a C++ type to its scalar_type_t tag through a static `type` member.
// The primary template is deliberately empty: naming `::type` for a type
// that has no specialization is a compile-time error.
template <class T>
struct scalar_type_traits {};
63
64template <>
65struct scalar_type_traits<float16_t> {
66 static const auto type = scalar_type_t::_half;
67};
68template <>
69struct scalar_type_traits<bfloat16_t> {
70 static const auto type = scalar_type_t::_bfloat16;
71};
72template <>
73struct scalar_type_traits<float> {
74 static const auto type = scalar_type_t::_float;
75};
76
77template <>
78struct scalar_type_traits<uint8_t> {
79 static const auto type = scalar_type_t::_uchar;
80};
81template <>
82struct scalar_type_traits<uint16_t> {
83 static const auto type = scalar_type_t::_ushort;
84};
85template <>
86struct scalar_type_traits<uint32_t> {
87 static const auto type = scalar_type_t::_uint;
88};
89template <>
90struct scalar_type_traits<uint64_t> {
91 static const auto type = scalar_type_t::_ulong;
92};
93
94template <>
95struct scalar_type_traits<int8_t> {
96 static const auto type = scalar_type_t::_char;
97};
98template <>
99struct scalar_type_traits<int16_t> {
100 static const auto type = scalar_type_t::_short;
101};
102template <>
103struct scalar_type_traits<int32_t> {
104 static const auto type = scalar_type_t::_int;
105};
106template <>
107struct scalar_type_traits<int64_t> {
108 static const auto type = scalar_type_t::_long;
109};
110template <>
111struct scalar_type_traits<zero_pad_mask_t> {
112 static const auto type = scalar_type_t::_zero_pad_mask_t;
113};
114
115class kernel_arg_t {
116public:
117 kernel_arg_kind_t kind() const { return kind_; }
118 scalar_type_t scalar_type() const { return scalar_type_; }
119 size_t size() const { return size_; }
120
121 bool is_global() const { return kind() == kernel_arg_kind_t::global; }
122 bool is_local() const { return kind() == kernel_arg_kind_t::local; }
123 bool is_svm_pointer() const { return kind_ == kernel_arg_kind_t::svm; }
124
125 kernel_arg_t &set_value(const memory_storage_t &storage) {
126 kind_ = kernel_arg_kind_t::global;
127 size_ = 0;
128 value_ = static_cast<const void *>(&storage);
129 return *this;
130 }
131
132 template <typename T>
133 kernel_arg_t &set_value(const T &value, void *&data_pool) {
134 assert(size_ <= sizeof(T));
135 if (value_ == nullptr) {
136 assert(data_pool != nullptr);
137 size_ = sizeof(T);
138 data_pool = utils::align_ptr(data_pool, alignof(T));
139 value_ = data_pool;
140 data_pool = static_cast<char *>(data_pool) + size_;
141 }
142 kind_ = kernel_arg_kind_t::scalar;
143 scalar_type_ = scalar_type_traits<T>::type;
144 new (const_cast<void *>(value_)) T(value);
145 return *this;
146 }
147
148 kernel_arg_t &set_value(size_t size, std::nullptr_t) {
149 kind_ = kernel_arg_kind_t::local;
150 size_ = size;
151 value_ = nullptr;
152 return *this;
153 }
154
155 void set_value(void *svm_ptr, kernel_arg_kind_t kind) {
156 assert(kind == kernel_arg_kind_t::svm);
157 kind_ = kernel_arg_kind_t::svm;
158 size_ = 0;
159 value_ = svm_ptr;
160 }
161
162 const void *value() const {
163 assert(kind() != kernel_arg_kind_t::undef);
164 return value_;
165 }
166
167 template <typename T>
168 T as() const {
169 assert(kind() == kernel_arg_kind_t::scalar);
170 assert(scalar_type() == scalar_type_traits<T>::type);
171 return *(const T *)value();
172 }
173
174private:
175 kernel_arg_kind_t kind_ = kernel_arg_kind_t::undef;
176 scalar_type_t scalar_type_ = scalar_type_t::undef;
177 size_t size_ = 0;
178 const void *value_ = nullptr;
179};
180
181class kernel_arg_list_t {
182public:
183 kernel_arg_list_t() { nargs_ = 0; }
184 void set(int index, const memory_storage_t &storage) {
185 assert(index < max_args);
186 nargs_ = nstl::max(nargs_, index + 1);
187 args_[index].set_value(storage);
188 }
189
190 void set(int index, void *value, kernel_arg_kind_t kind) {
191 assert(index < max_args);
192 nargs_ = nstl::max(nargs_, index + 1);
193 args_[index].set_value(value, kind);
194 }
195
196 template <class T>
197 void set(int index, const T &value) {
198 assert(index < max_args);
199 nargs_ = nstl::max(nargs_, index + 1);
200 args_[index].set_value(value, unused_storage);
201
202 assert(unused_storage
203 <= reinterpret_cast<char *>(&scalar_storage_) + storage_size);
204 }
205
206 void set(int index, size_t size, std::nullptr_t) {
207 assert(index < max_args);
208 nargs_ = nstl::max(nargs_, index + 1);
209 args_[index].set_value(size, nullptr);
210 }
211
212 int nargs() const { return nargs_; }
213
214 const kernel_arg_t &get(int index) const {
215 assert(index < nargs());
216 return args_[index];
217 }
218
219 const memory_storage_t &get_memory_storage(int index) const {
220 assert(args_[index].kind() == kernel_arg_kind_t::global);
221 return *static_cast<const memory_storage_t *>(args_[index].value());
222 }
223
224private:
225 static constexpr int max_args = 96;
226 static constexpr int storage_size = 512;
227 static constexpr int storage_alginment = 8;
228
229 int nargs_ = 0;
230 kernel_arg_t args_[max_args];
231 typename std::aligned_storage<storage_size, storage_alginment>::type
232 scalar_storage_;
233 void *unused_storage = &scalar_storage_;
234
235 kernel_arg_list_t(const kernel_arg_list_t &) = delete;
236 kernel_arg_list_t(kernel_arg_list_t &&) = delete;
237 kernel_arg_list_t &operator=(const kernel_arg_list_t &) = delete;
238 kernel_arg_list_t &operator=(kernel_arg_list_t &&) = delete;
239};
240
241template <typename T>
242void set_scalar_arg_cvt(kernel_arg_list_t &arg_list, int index, T scalar,
243 scalar_type_t requested_type) {
244 if (scalar_type_traits<T>::type == requested_type) {
245 arg_list.set(index, scalar);
246 return;
247 }
248
249 switch (requested_type) {
250 case scalar_type_t::_half:
251 arg_list.set(index, (float16_t)scalar);
252 break;
253 case scalar_type_t::_uchar: arg_list.set(index, (uint8_t)scalar); break;
254 case scalar_type_t::_char: arg_list.set(index, (int8_t)scalar); break;
255 default: assert(!"Cannot convert scalar to the requested type.");
256 }
257}
258
259inline status_t check_scalar_arguments(const kernel_arg_list_t &arg_list,
260 const std::vector<scalar_type_t> &arg_types) {
261 for (int i = 0; i < arg_list.nargs(); i++) {
262 auto &arg = arg_list.get(i);
263 auto req_arg_type = arg_types[i];
264
265 if (!arg.is_global() && !arg.is_local() && !arg.is_svm_pointer()) {
266 if (req_arg_type == gpu::compute::scalar_type_t::undef) {
267 // Types of kernel arguments may not be available when zebin
268 // is used.
269 continue;
270 }
271
272 if (req_arg_type != arg.scalar_type()) {
273 if (get_verbose()) {
274 printf("onednn_verbose,gpu,error,type of a scalar kernel "
275 "argument #%d is different from the type of the "
276 "given scalar\n",
277 i);
278 fflush(nullptr);
279 }
280 return status::invalid_arguments;
281 }
282 }
283 }
284 return status::success;
285}
286
287} // namespace compute
288} // namespace gpu
289} // namespace impl
290} // namespace dnnl
291
292#endif // GPU_COMPUTE_KERNEL_ARG_LIST_HPP
293