1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/// @file
18/// C++ common API
19
20#ifndef ONEAPI_DNNL_DNNL_COMMON_HPP
21#define ONEAPI_DNNL_DNNL_COMMON_HPP
22
23/// @cond DO_NOT_DOCUMENT_THIS
24#include <algorithm>
25#include <cstdlib>
26#include <iterator>
27#include <memory>
28#include <string>
29#include <vector>
30#include <unordered_map>
31
32#include "oneapi/dnnl/dnnl_common.h"
33
34/// @endcond
35
// Decides whether the C++ API may use C++ exceptions. A user can force the
// choice by defining DNNL_ENABLE_EXCEPTIONS (to 0 or 1) before including
// this header; otherwise it is detected from the compiler below.
//
// __cpp_exceptions is documented at
// https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
// gcc < 5 does not define __cpp_exceptions but defines __EXCEPTIONS instead.
// The Microsoft C++ Compiler does not provide an option to disable
// exceptions, so they are assumed enabled there; clang with _MSC_VER is
// excluded and relies on the feature macros instead.
// Note: identifiers that are not defined evaluate to 0 inside #if.
#ifndef DNNL_ENABLE_EXCEPTIONS
#if __cpp_exceptions || __EXCEPTIONS \
        || (defined(_MSC_VER) && !defined(__clang__))
#define DNNL_ENABLE_EXCEPTIONS 1
#else
#define DNNL_ENABLE_EXCEPTIONS 0
#endif
#endif
48
// DNNL_TRAP() stops execution at the call site using a compiler intrinsic
// (__builtin_trap on GCC/clang, __debugbreak on ICC/MSVC). It is used by
// DNNL_THROW_ERROR to report fatal errors when exceptions are disabled.
// Any other compiler is rejected at compile time.
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
#error "unknown compiler"
#endif
56
// DNNL_THROW_ERROR(status, msg) reports a fatal error from the C++ API.
// - With exceptions enabled, it throws `error(status, msg)`. The name
//   `error` is looked up at the expansion site, which inside namespace dnnl
//   resolves to dnnl::error.
// - With exceptions disabled, it writes msg to stderr (no newline is
//   appended) and then calls DNNL_TRAP(); `status` is unused in this mode.
#if DNNL_ENABLE_EXCEPTIONS
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#else
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        fputs(msg, stderr); \
        DNNL_TRAP(); \
    } while (0)
#endif
67
68/// @addtogroup dnnl_api oneDNN API
69/// @{
70
71/// oneDNN namespace
72namespace dnnl {
73
74/// @addtogroup dnnl_api_common Common API
75/// @{
76
77/// @addtogroup dnnl_api_utils Utilities
78/// Utility types and definitions.
79/// @{
80
81/// oneDNN exception class.
82///
83/// This class captures the status returned by a failed C API function and
84/// the error message from the call site.
85struct error : public std::exception {
86 dnnl_status_t status;
87 const char *message;
88
89 /// Constructs an instance of an exception class.
90 ///
91 /// @param status The error status returned by a C API function.
92 /// @param message The error message.
93 error(dnnl_status_t status, const char *message)
94 : status(status), message(message) {}
95
96 /// Returns the explanatory string.
97 const char *what() const noexcept override { return message; }
98
99 /// A convenience function for wrapping calls to C API functions. Checks
100 /// the return status and throws an dnnl::error in case of failure.
101 ///
102 /// @param status The error status returned by a C API function.
103 /// @param message The error message.
104 static void wrap_c_api(dnnl_status_t status, const char *message) {
105 if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
106 }
107};
108
/// Customization point telling dnnl::handle how to release a C API handle
/// of type @p T.
///
/// The primary template is intentionally empty. Any C API handle type that
/// is wrapped with the default traits must provide a specialization that
/// exposes a static `destructor` function taking the handle and returning a
/// status code.
template <class T>
struct handle_traits {};
112
113/// oneDNN C API handle wrapper class.
114///
115/// This class is used as the base class for primitive (dnnl::primitive),
116/// engine (dnnl::engine), and stream (dnnl::stream) classes, as well as
117/// others. An object of the dnnl::handle class can be passed by value.
118///
119/// A handle can be weak, in which case it follows std::weak_ptr semantics.
120/// Otherwise, it follows `std::shared_ptr` semantics.
121///
122/// @note
123/// The implementation stores oneDNN C API handles in a `std::shared_ptr`
124/// with deleter set to a dummy function in the weak mode.
125///
126template <typename T, typename traits = handle_traits<T>>
127struct handle {
128private:
129 static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
130 std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
131
132protected:
133 bool operator==(const T other) const { return other == data_.get(); }
134 bool operator!=(const T other) const { return !(*this == other); }
135
136public:
137 /// Constructs an empty handle object.
138 ///
139 /// @warning
140 /// Uninitialized object cannot be used in most library calls and is
141 /// equivalent to a null pointer. Any attempt to use its methods, or
142 /// passing it to the other library function, will cause an exception
143 /// to be thrown.
144 handle() = default;
145
146 /// Copy constructor.
147 handle(const handle<T, traits> &) = default;
148 /// Assignment operator.
149 handle<T, traits> &operator=(const handle<T, traits> &) = default;
150 /// Move constructor.
151 handle(handle<T, traits> &&) = default;
152 /// Move assignment operator.
153 handle<T, traits> &operator=(handle<T, traits> &&) = default;
154
155 /// Constructs a handle wrapper object from a C API handle.
156 ///
157 /// @param t The C API handle to wrap.
158 /// @param weak A flag specifying whether to construct a weak wrapper;
159 /// defaults to @c false.
160 explicit handle(T t, bool weak = false) { reset(t, weak); }
161
162 /// Resets the handle wrapper objects to wrap a new C API handle.
163 ///
164 /// @param t The new value of the C API handle.
165 /// @param weak A flag specifying whether the wrapper should be weak;
166 /// defaults to @c false.
167 void reset(T t, bool weak = false) {
168 data_.reset(t, weak ? &dummy_destructor : traits::destructor);
169 }
170
171 /// Returns the underlying C API handle.
172 ///
173 /// @param allow_empty A flag signifying whether the method is allowed to
174 /// return an empty (null) object without throwing an exception.
175 /// @returns The underlying C API handle.
176 T get(bool allow_empty = false) const {
177 T result = data_.get();
178 if (allow_empty == false && result == nullptr)
179 DNNL_THROW_ERROR(
180 dnnl_invalid_arguments, "object is not initialized");
181 return result;
182 }
183
184 /// Converts a handle to the underlying C API handle type. Does not throw
185 /// and returns `nullptr` if the object is empty.
186 ///
187 /// @returns The underlying C API handle.
188 explicit operator T() const { return get(true); }
189
190 /// Checks whether the object is not empty.
191 ///
192 /// @returns Whether the object is not empty.
193 explicit operator bool() const { return get(true) != nullptr; }
194
195 /// Equality operator.
196 ///
197 /// @param other Another handle wrapper.
198 /// @returns @c true if this and the other handle wrapper manage the same
199 /// underlying C API handle, and @c false otherwise. Empty handle
200 /// objects are considered to be equal.
201 bool operator==(const handle<T, traits> &other) const {
202 return other.data_.get() == data_.get();
203 }
204
205 /// Inequality operator.
206 ///
207 /// @param other Another handle wrapper.
208 /// @returns @c true if this and the other handle wrapper manage different
209 /// underlying C API handles, and @c false otherwise. Empty handle
210 /// objects are considered to be equal.
211 bool operator!=(const handle &other) const { return !(*this == other); }
212};
213
214/// @} dnnl_api_utils
215
216/// @addtogroup dnnl_api_engine Engine
217///
218/// An abstraction of a computational device: a CPU, a specific GPU
219/// card in the system, etc. Most primitives are created to execute
220/// computations on one specific engine. The only exceptions are reorder
221/// primitives that transfer data between two different engines.
222///
223/// @sa @ref dev_guide_basic_concepts
224///
225/// @{
226
227/// @cond DO_NOT_DOCUMENT_THIS
228template <>
229struct handle_traits<dnnl_engine_t> {
230 static dnnl_status_t destructor(dnnl_engine_t p) {
231 return dnnl_engine_destroy(p);
232 }
233};
234/// @endcond
235
236/// An execution engine.
237struct engine : public handle<dnnl_engine_t> {
238 friend struct primitive;
239 friend struct reorder;
240
241 /// Kinds of engines.
242 enum class kind {
243 /// An unspecified engine
244 any = dnnl_any_engine,
245 /// CPU engine
246 cpu = dnnl_cpu,
247 /// GPU engine
248 gpu = dnnl_gpu,
249 };
250
251 using handle::handle;
252
253 /// Constructs an empty engine. An empty engine cannot be used in any
254 /// operations.
255 engine() = default;
256
257 /// Returns the number of engines of a certain kind.
258 ///
259 /// @param akind The kind of engines to count.
260 /// @returns The number of engines of the specified kind.
261 static size_t get_count(kind akind) {
262 return dnnl_engine_get_count(convert_to_c(akind));
263 }
264
265 /// Constructs an engine.
266 ///
267 /// @param akind The kind of engine to construct.
268 /// @param index The index of the engine. Must be less than the value
269 /// returned by #get_count() for this particular kind of engine.
270 engine(kind akind, size_t index) {
271 dnnl_engine_t engine;
272 error::wrap_c_api(
273 dnnl_engine_create(&engine, convert_to_c(akind), index),
274 "could not create an engine");
275 reset(engine);
276 }
277
278 /// Returns the kind of the engine.
279 /// @returns The kind of the engine.
280 kind get_kind() const {
281 dnnl_engine_kind_t kind;
282 error::wrap_c_api(dnnl_engine_get_kind(get(), &kind),
283 "could not get kind of an engine");
284 return static_cast<engine::kind>(kind);
285 }
286
287private:
288 static dnnl_engine_kind_t convert_to_c(kind akind) {
289 return static_cast<dnnl_engine_kind_t>(akind);
290 }
291};
292
293/// Converts engine kind enum value from C++ API to C API type.
294///
295/// @param akind C++ API engine kind enum value.
296/// @returns Corresponding C API engine kind enum value.
297inline dnnl_engine_kind_t convert_to_c(engine::kind akind) {
298 return static_cast<dnnl_engine_kind_t>(akind);
299}
300
301/// @} dnnl_api_engine
302
303/// @addtogroup dnnl_api_stream Stream
304///
305/// An encapsulation of execution context tied to a particular engine.
306///
307/// @sa @ref dev_guide_basic_concepts
308///
309/// @{
310
311/// @cond DO_NOT_DOCUMENT_THIS
312template <>
313struct handle_traits<dnnl_stream_t> {
314 static dnnl_status_t destructor(dnnl_stream_t p) {
315 return dnnl_stream_destroy(p);
316 }
317};
318/// @endcond
319
320/// An execution stream.
321struct stream : public handle<dnnl_stream_t> {
322 using handle::handle;
323
324 /// Stream flags. Can be combined using the bitwise OR operator.
325 enum class flags : unsigned {
326 /// In-order execution.
327 in_order = dnnl_stream_in_order,
328 /// Out-of-order execution.
329 out_of_order = dnnl_stream_out_of_order,
330 /// Default stream configuration.
331 default_flags = dnnl_stream_default_flags,
332 };
333
334 /// Constructs an empty stream. An empty stream cannot be used in any
335 /// operations.
336 stream() = default;
337
338 /// Constructs a stream for the specified engine and with behavior
339 /// controlled by the specified flags.
340 ///
341 /// @param aengine Engine to create the stream on.
342 /// @param aflags Flags controlling stream behavior.
343 explicit stream(
344 const engine &aengine, flags aflags = flags::default_flags) {
345 dnnl_stream_t stream;
346 error::wrap_c_api(dnnl_stream_create(&stream, aengine.get(),
347 static_cast<dnnl_stream_flags_t>(aflags)),
348 "could not create a stream");
349 reset(stream);
350 }
351
352 /// Returns the associated engine.
353 engine get_engine() const {
354 dnnl_engine_t c_engine;
355 error::wrap_c_api(dnnl_stream_get_engine(get(), &c_engine),
356 "could not get an engine from a stream object");
357 return engine(c_engine, true);
358 }
359
360 /// Waits for all primitives executing in the stream to finish.
361 /// @returns The stream itself.
362 stream &wait() {
363 error::wrap_c_api(
364 dnnl_stream_wait(get()), "could not wait on a stream");
365 return *this;
366 }
367};
368
369#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
370 inline enum_name operator|(enum_name lhs, enum_name rhs) { \
371 return static_cast<enum_name>( \
372 static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
373 } \
374\
375 inline enum_name operator&(enum_name lhs, enum_name rhs) { \
376 return static_cast<enum_name>( \
377 static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
378 } \
379\
380 inline enum_name operator^(enum_name lhs, enum_name rhs) { \
381 return static_cast<enum_name>( \
382 static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
383 } \
384\
385 inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
386 lhs = static_cast<enum_name>( \
387 static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
388 return lhs; \
389 } \
390\
391 inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
392 lhs = static_cast<enum_name>( \
393 static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
394 return lhs; \
395 } \
396\
397 inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
398 lhs = static_cast<enum_name>( \
399 static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
400 return lhs; \
401 } \
402\
403 inline enum_name operator~(enum_name rhs) { \
404 return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
405 }
406
407DNNL_DEFINE_BITMASK_OPS(stream::flags)
408
409/// @} dnnl_api_stream
410
411/// @addtogroup dnnl_api_fpmath_mode Floating-point Math Mode
412/// @{
413
414/// Floating-point math mode
415enum class fpmath_mode {
416 /// Default behavior, no downconversions allowed
417 strict = dnnl_fpmath_mode_strict,
418 /// Implicit f32->bf16 conversions allowed
419 bf16 = dnnl_fpmath_mode_bf16,
420 /// Implicit f32->f16 conversions allowed
421 f16 = dnnl_fpmath_mode_f16,
422 /// Implicit f32->tf32 conversions allowed
423 tf32 = dnnl_fpmath_mode_tf32,
424 /// Implicit f32->f16 or f32->bf16 conversions allowed
425 any = dnnl_fpmath_mode_any
426};
427
428/// Converts an fpmath mode enum value from C++ API to C API type.
429///
430/// @param mode C++ API fpmath mode enum value.
431/// @returns Corresponding C API fpmath mode enum value.
432inline dnnl_fpmath_mode_t convert_to_c(fpmath_mode mode) {
433 return static_cast<dnnl_fpmath_mode_t>(mode);
434}
435
436/// @} dnnl_api_fpmath_mode
437
438/// @} dnnl_api_common
439
440} // namespace dnnl
441
442/// @} dnnl_api
443
444#endif
445