1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /// @file |
18 | /// C++ common API |
19 | |
20 | #ifndef ONEAPI_DNNL_DNNL_COMMON_HPP |
21 | #define ONEAPI_DNNL_DNNL_COMMON_HPP |
22 | |
23 | /// @cond DO_NOT_DOCUMENT_THIS |
24 | #include <algorithm> |
25 | #include <cstdlib> |
26 | #include <iterator> |
27 | #include <memory> |
28 | #include <string> |
29 | #include <vector> |
30 | #include <unordered_map> |
31 | |
32 | #include "oneapi/dnnl/dnnl_common.h" |
33 | |
34 | /// @endcond |
35 | |
// The __cpp_exceptions check follows
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html.
// gcc < 5 does not define __cpp_exceptions, only __EXCEPTIONS, and the
// Microsoft C++ compiler does not provide an option to disable exceptions.
// DNNL_ENABLE_EXCEPTIONS selects how the C++ API reports failures: 1 means
// DNNL_THROW_ERROR throws, 0 means it prints the message and traps.
// A value predefined by the user (e.g. on the command line) takes precedence
// over the detection below.
#ifndef DNNL_ENABLE_EXCEPTIONS
// Identifiers that are not defined evaluate to 0 inside #if, so each
// feature macro only contributes when the compiler defines it non-zero.
// MSVC proper always counts as exceptions-enabled; clang-cl (which also
// defines _MSC_VER) is excluded — presumably because it can be built with
// exceptions disabled (NOTE(review): confirm).
#if __cpp_exceptions || __EXCEPTIONS \
        || (defined(_MSC_VER) && !defined(__clang__))
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif
48 | |
// DNNL_TRAP() stops execution with a compiler-specific intrinsic:
// __builtin_trap() on gcc/clang, __debugbreak() on the Intel compiler and
// MSVC (which breaks into an attached debugger). It is invoked by
// DNNL_THROW_ERROR when exceptions are disabled.
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
// No known trap intrinsic for this compiler: fail the build explicitly.
#error "unknown compiler"
#endif
56 | |
// DNNL_THROW_ERROR(status, msg) reports a failed oneDNN operation.
// - Exceptions enabled: throws error(status, msg). `error` is looked up at
//   the expansion site, which in this header is inside namespace dnnl.
// - Exceptions disabled: writes msg followed by a newline to stderr, so the
//   diagnostic does not run into subsequent output, and then executes
//   DNNL_TRAP(). `status` is not used in this mode.
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#else
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        std::fputs(msg, stderr); \
        std::fputc('\n', stderr); \
        DNNL_TRAP(); \
    } while (0)
#endif
67 | |
68 | /// @addtogroup dnnl_api oneDNN API |
69 | /// @{ |
70 | |
71 | /// oneDNN namespace |
72 | namespace dnnl { |
73 | |
74 | /// @addtogroup dnnl_api_common Common API |
75 | /// @{ |
76 | |
77 | /// @addtogroup dnnl_api_utils Utilities |
78 | /// Utility types and definitions. |
79 | /// @{ |
80 | |
81 | /// oneDNN exception class. |
82 | /// |
83 | /// This class captures the status returned by a failed C API function and |
84 | /// the error message from the call site. |
85 | struct error : public std::exception { |
86 | dnnl_status_t status; |
87 | const char *message; |
88 | |
89 | /// Constructs an instance of an exception class. |
90 | /// |
91 | /// @param status The error status returned by a C API function. |
92 | /// @param message The error message. |
93 | error(dnnl_status_t status, const char *message) |
94 | : status(status), message(message) {} |
95 | |
96 | /// Returns the explanatory string. |
97 | const char *what() const noexcept override { return message; } |
98 | |
99 | /// A convenience function for wrapping calls to C API functions. Checks |
100 | /// the return status and throws an dnnl::error in case of failure. |
101 | /// |
102 | /// @param status The error status returned by a C API function. |
103 | /// @param message The error message. |
104 | static void wrap_c_api(dnnl_status_t status, const char *message) { |
105 | if (status != dnnl_success) DNNL_THROW_ERROR(status, message); |
106 | } |
107 | }; |
108 | |
/// Customization point that supplies the destructor for a oneDNN C API
/// handle type.
///
/// The primary template is deliberately empty. Each wrapped handle type
/// provides a specialization with a static `destructor` function, which
/// dnnl::handle installs as the deleter.
template <class T>
struct handle_traits {};
112 | |
113 | /// oneDNN C API handle wrapper class. |
114 | /// |
115 | /// This class is used as the base class for primitive (dnnl::primitive), |
116 | /// engine (dnnl::engine), and stream (dnnl::stream) classes, as well as |
117 | /// others. An object of the dnnl::handle class can be passed by value. |
118 | /// |
119 | /// A handle can be weak, in which case it follows std::weak_ptr semantics. |
120 | /// Otherwise, it follows `std::shared_ptr` semantics. |
121 | /// |
122 | /// @note |
123 | /// The implementation stores oneDNN C API handles in a `std::shared_ptr` |
124 | /// with deleter set to a dummy function in the weak mode. |
125 | /// |
126 | template <typename T, typename traits = handle_traits<T>> |
127 | struct handle { |
128 | private: |
129 | static dnnl_status_t dummy_destructor(T) { return dnnl_success; } |
130 | std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0}; |
131 | |
132 | protected: |
133 | bool operator==(const T other) const { return other == data_.get(); } |
134 | bool operator!=(const T other) const { return !(*this == other); } |
135 | |
136 | public: |
137 | /// Constructs an empty handle object. |
138 | /// |
139 | /// @warning |
140 | /// Uninitialized object cannot be used in most library calls and is |
141 | /// equivalent to a null pointer. Any attempt to use its methods, or |
142 | /// passing it to the other library function, will cause an exception |
143 | /// to be thrown. |
144 | handle() = default; |
145 | |
146 | /// Copy constructor. |
147 | handle(const handle<T, traits> &) = default; |
148 | /// Assignment operator. |
149 | handle<T, traits> &operator=(const handle<T, traits> &) = default; |
150 | /// Move constructor. |
151 | handle(handle<T, traits> &&) = default; |
152 | /// Move assignment operator. |
153 | handle<T, traits> &operator=(handle<T, traits> &&) = default; |
154 | |
155 | /// Constructs a handle wrapper object from a C API handle. |
156 | /// |
157 | /// @param t The C API handle to wrap. |
158 | /// @param weak A flag specifying whether to construct a weak wrapper; |
159 | /// defaults to @c false. |
160 | explicit handle(T t, bool weak = false) { reset(t, weak); } |
161 | |
162 | /// Resets the handle wrapper objects to wrap a new C API handle. |
163 | /// |
164 | /// @param t The new value of the C API handle. |
165 | /// @param weak A flag specifying whether the wrapper should be weak; |
166 | /// defaults to @c false. |
167 | void reset(T t, bool weak = false) { |
168 | data_.reset(t, weak ? &dummy_destructor : traits::destructor); |
169 | } |
170 | |
171 | /// Returns the underlying C API handle. |
172 | /// |
173 | /// @param allow_empty A flag signifying whether the method is allowed to |
174 | /// return an empty (null) object without throwing an exception. |
175 | /// @returns The underlying C API handle. |
176 | T get(bool allow_empty = false) const { |
177 | T result = data_.get(); |
178 | if (allow_empty == false && result == nullptr) |
179 | DNNL_THROW_ERROR( |
180 | dnnl_invalid_arguments, "object is not initialized" ); |
181 | return result; |
182 | } |
183 | |
184 | /// Converts a handle to the underlying C API handle type. Does not throw |
185 | /// and returns `nullptr` if the object is empty. |
186 | /// |
187 | /// @returns The underlying C API handle. |
188 | explicit operator T() const { return get(true); } |
189 | |
190 | /// Checks whether the object is not empty. |
191 | /// |
192 | /// @returns Whether the object is not empty. |
193 | explicit operator bool() const { return get(true) != nullptr; } |
194 | |
195 | /// Equality operator. |
196 | /// |
197 | /// @param other Another handle wrapper. |
198 | /// @returns @c true if this and the other handle wrapper manage the same |
199 | /// underlying C API handle, and @c false otherwise. Empty handle |
200 | /// objects are considered to be equal. |
201 | bool operator==(const handle<T, traits> &other) const { |
202 | return other.data_.get() == data_.get(); |
203 | } |
204 | |
205 | /// Inequality operator. |
206 | /// |
207 | /// @param other Another handle wrapper. |
208 | /// @returns @c true if this and the other handle wrapper manage different |
209 | /// underlying C API handles, and @c false otherwise. Empty handle |
210 | /// objects are considered to be equal. |
211 | bool operator!=(const handle &other) const { return !(*this == other); } |
212 | }; |
213 | |
214 | /// @} dnnl_api_utils |
215 | |
216 | /// @addtogroup dnnl_api_engine Engine |
217 | /// |
218 | /// An abstraction of a computational device: a CPU, a specific GPU |
219 | /// card in the system, etc. Most primitives are created to execute |
220 | /// computations on one specific engine. The only exceptions are reorder |
221 | /// primitives that transfer data between two different engines. |
222 | /// |
223 | /// @sa @ref dev_guide_basic_concepts |
224 | /// |
225 | /// @{ |
226 | |
227 | /// @cond DO_NOT_DOCUMENT_THIS |
228 | template <> |
229 | struct handle_traits<dnnl_engine_t> { |
230 | static dnnl_status_t destructor(dnnl_engine_t p) { |
231 | return dnnl_engine_destroy(p); |
232 | } |
233 | }; |
234 | /// @endcond |
235 | |
236 | /// An execution engine. |
237 | struct engine : public handle<dnnl_engine_t> { |
238 | friend struct primitive; |
239 | friend struct reorder; |
240 | |
241 | /// Kinds of engines. |
242 | enum class kind { |
243 | /// An unspecified engine |
244 | any = dnnl_any_engine, |
245 | /// CPU engine |
246 | cpu = dnnl_cpu, |
247 | /// GPU engine |
248 | gpu = dnnl_gpu, |
249 | }; |
250 | |
251 | using handle::handle; |
252 | |
253 | /// Constructs an empty engine. An empty engine cannot be used in any |
254 | /// operations. |
255 | engine() = default; |
256 | |
257 | /// Returns the number of engines of a certain kind. |
258 | /// |
259 | /// @param akind The kind of engines to count. |
260 | /// @returns The number of engines of the specified kind. |
261 | static size_t get_count(kind akind) { |
262 | return dnnl_engine_get_count(convert_to_c(akind)); |
263 | } |
264 | |
265 | /// Constructs an engine. |
266 | /// |
267 | /// @param akind The kind of engine to construct. |
268 | /// @param index The index of the engine. Must be less than the value |
269 | /// returned by #get_count() for this particular kind of engine. |
270 | engine(kind akind, size_t index) { |
271 | dnnl_engine_t engine; |
272 | error::wrap_c_api( |
273 | dnnl_engine_create(&engine, convert_to_c(akind), index), |
274 | "could not create an engine" ); |
275 | reset(engine); |
276 | } |
277 | |
278 | /// Returns the kind of the engine. |
279 | /// @returns The kind of the engine. |
280 | kind get_kind() const { |
281 | dnnl_engine_kind_t kind; |
282 | error::wrap_c_api(dnnl_engine_get_kind(get(), &kind), |
283 | "could not get kind of an engine" ); |
284 | return static_cast<engine::kind>(kind); |
285 | } |
286 | |
287 | private: |
288 | static dnnl_engine_kind_t convert_to_c(kind akind) { |
289 | return static_cast<dnnl_engine_kind_t>(akind); |
290 | } |
291 | }; |
292 | |
293 | /// Converts engine kind enum value from C++ API to C API type. |
294 | /// |
295 | /// @param akind C++ API engine kind enum value. |
296 | /// @returns Corresponding C API engine kind enum value. |
297 | inline dnnl_engine_kind_t convert_to_c(engine::kind akind) { |
298 | return static_cast<dnnl_engine_kind_t>(akind); |
299 | } |
300 | |
301 | /// @} dnnl_api_engine |
302 | |
303 | /// @addtogroup dnnl_api_stream Stream |
304 | /// |
305 | /// An encapsulation of execution context tied to a particular engine. |
306 | /// |
307 | /// @sa @ref dev_guide_basic_concepts |
308 | /// |
309 | /// @{ |
310 | |
311 | /// @cond DO_NOT_DOCUMENT_THIS |
312 | template <> |
313 | struct handle_traits<dnnl_stream_t> { |
314 | static dnnl_status_t destructor(dnnl_stream_t p) { |
315 | return dnnl_stream_destroy(p); |
316 | } |
317 | }; |
318 | /// @endcond |
319 | |
320 | /// An execution stream. |
321 | struct stream : public handle<dnnl_stream_t> { |
322 | using handle::handle; |
323 | |
324 | /// Stream flags. Can be combined using the bitwise OR operator. |
325 | enum class flags : unsigned { |
326 | /// In-order execution. |
327 | in_order = dnnl_stream_in_order, |
328 | /// Out-of-order execution. |
329 | out_of_order = dnnl_stream_out_of_order, |
330 | /// Default stream configuration. |
331 | default_flags = dnnl_stream_default_flags, |
332 | }; |
333 | |
334 | /// Constructs an empty stream. An empty stream cannot be used in any |
335 | /// operations. |
336 | stream() = default; |
337 | |
338 | /// Constructs a stream for the specified engine and with behavior |
339 | /// controlled by the specified flags. |
340 | /// |
341 | /// @param aengine Engine to create the stream on. |
342 | /// @param aflags Flags controlling stream behavior. |
343 | explicit stream( |
344 | const engine &aengine, flags aflags = flags::default_flags) { |
345 | dnnl_stream_t stream; |
346 | error::wrap_c_api(dnnl_stream_create(&stream, aengine.get(), |
347 | static_cast<dnnl_stream_flags_t>(aflags)), |
348 | "could not create a stream" ); |
349 | reset(stream); |
350 | } |
351 | |
352 | /// Returns the associated engine. |
353 | engine get_engine() const { |
354 | dnnl_engine_t c_engine; |
355 | error::wrap_c_api(dnnl_stream_get_engine(get(), &c_engine), |
356 | "could not get an engine from a stream object" ); |
357 | return engine(c_engine, true); |
358 | } |
359 | |
360 | /// Waits for all primitives executing in the stream to finish. |
361 | /// @returns The stream itself. |
362 | stream &wait() { |
363 | error::wrap_c_api( |
364 | dnnl_stream_wait(get()), "could not wait on a stream" ); |
365 | return *this; |
366 | } |
367 | }; |
368 | |
// Defines |, &, ^ and ~ plus the compound assignments |=, &=, ^= for the
// enumeration `enum_name`, so that its values can be combined as flags.
// Operands are combined through `unsigned`, so the enumeration is assumed
// to use `unsigned` as its underlying type (NOTE(review): verify for each
// use). The compound assignments are expressed via the binary operators
// defined first in the same expansion.
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
    inline enum_name operator|(enum_name a, enum_name b) { \
        return static_cast<enum_name>( \
                static_cast<unsigned>(a) | static_cast<unsigned>(b)); \
    } \
\
    inline enum_name operator&(enum_name a, enum_name b) { \
        return static_cast<enum_name>( \
                static_cast<unsigned>(a) & static_cast<unsigned>(b)); \
    } \
\
    inline enum_name operator^(enum_name a, enum_name b) { \
        return static_cast<enum_name>( \
                static_cast<unsigned>(a) ^ static_cast<unsigned>(b)); \
    } \
\
    inline enum_name &operator|=(enum_name &a, enum_name b) { \
        return a = a | b; \
    } \
\
    inline enum_name &operator&=(enum_name &a, enum_name b) { \
        return a = a & b; \
    } \
\
    inline enum_name &operator^=(enum_name &a, enum_name b) { \
        return a = a ^ b; \
    } \
\
    inline enum_name operator~(enum_name a) { \
        return static_cast<enum_name>(~static_cast<unsigned>(a)); \
    }
406 | |
407 | DNNL_DEFINE_BITMASK_OPS(stream::flags) |
408 | |
409 | /// @} dnnl_api_stream |
410 | |
411 | /// @addtogroup dnnl_api_fpmath_mode Floating-point Math Mode |
412 | /// @{ |
413 | |
414 | /// Floating-point math mode |
415 | enum class fpmath_mode { |
416 | /// Default behavior, no downconversions allowed |
417 | strict = dnnl_fpmath_mode_strict, |
418 | /// Implicit f32->bf16 conversions allowed |
419 | bf16 = dnnl_fpmath_mode_bf16, |
420 | /// Implicit f32->f16 conversions allowed |
421 | f16 = dnnl_fpmath_mode_f16, |
422 | /// Implicit f32->tf32 conversions allowed |
423 | tf32 = dnnl_fpmath_mode_tf32, |
424 | /// Implicit f32->f16 or f32->bf16 conversions allowed |
425 | any = dnnl_fpmath_mode_any |
426 | }; |
427 | |
428 | /// Converts an fpmath mode enum value from C++ API to C API type. |
429 | /// |
430 | /// @param mode C++ API fpmath mode enum value. |
431 | /// @returns Corresponding C API fpmath mode enum value. |
432 | inline dnnl_fpmath_mode_t convert_to_c(fpmath_mode mode) { |
433 | return static_cast<dnnl_fpmath_mode_t>(mode); |
434 | } |
435 | |
436 | /// @} dnnl_api_fpmath_mode |
437 | |
438 | /// @} dnnl_api_common |
439 | |
440 | } // namespace dnnl |
441 | |
442 | /// @} dnnl_api |
443 | |
444 | #endif |
445 | |