1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef UTILS_WRAPPER_HPP |
18 | #define UTILS_WRAPPER_HPP |
19 | |
#include <type_traits>

#include "oneapi/dnnl/dnnl.h"

#include "common.hpp"
#include "dnnl_debug.hpp"
24 | |
// Evaluates the destroy call `f` exactly once. On any status other than
// dnnl_success it reports the enclosing function, the line, the stringified
// call and the status, then flushes all output streams and terminates the
// process with exit code 2. On success it does nothing.
#define CHECK_DESTROY(f) \
    do { \
        const dnnl_status_t destroy_status__ = (f); \
        if (destroy_status__ == dnnl_success) break; \
        BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), \
                status2str(destroy_status__), \
                static_cast<int>(destroy_status__)); \
        fflush(0); \
        exit(2); \
    } while (0)
36 | |
// Per-handle-type destruction policy used by benchdnn_dnnl_wrapper_t.
// Each explicit specialization provides `static void destroy(T t)` that
// releases the handle through the matching oneDNN C API call.
// The primary template is intentionally declared but never defined: wrapping
// a handle type that has no specialization fails at compile time (incomplete
// type) instead of silently leaking the handle.
template <typename T>
struct dnnl_api_traits;
42 | |
43 | template <> |
44 | struct dnnl_api_traits<dnnl_primitive_t> { |
45 | static void destroy(dnnl_primitive_t t) { |
46 | CHECK_DESTROY(dnnl_primitive_destroy(t)); |
47 | } |
48 | }; |
49 | |
50 | template <> |
51 | struct dnnl_api_traits<dnnl_primitive_desc_t> { |
52 | static void destroy(dnnl_primitive_desc_t t) { |
53 | CHECK_DESTROY(dnnl_primitive_desc_destroy(t)); |
54 | } |
55 | }; |
56 | |
57 | template <> |
58 | struct dnnl_api_traits<dnnl_primitive_attr_t> { |
59 | static void destroy(dnnl_primitive_attr_t t) { |
60 | CHECK_DESTROY(dnnl_primitive_attr_destroy(t)); |
61 | } |
62 | }; |
63 | |
64 | template <> |
65 | struct dnnl_api_traits<dnnl_memory_desc_t> { |
66 | static void destroy(dnnl_memory_desc_t t) { |
67 | CHECK_DESTROY(dnnl_memory_desc_destroy(t)); |
68 | } |
69 | }; |
70 | |
71 | // Generic class providing RAII support for DNNL objects in benchdnn |
72 | template <typename T> |
73 | struct benchdnn_dnnl_wrapper_t { |
74 | benchdnn_dnnl_wrapper_t(T t = nullptr) : t_(t) { |
75 | static_assert(std::is_pointer<T>::value, "T is not a pointer type." ); |
76 | } |
77 | |
78 | benchdnn_dnnl_wrapper_t &operator=(benchdnn_dnnl_wrapper_t &&rhs) { |
79 | if (this == &rhs) return *this; |
80 | reset(rhs.release()); |
81 | return *this; |
82 | } |
83 | |
84 | benchdnn_dnnl_wrapper_t(benchdnn_dnnl_wrapper_t &&rhs) { |
85 | t_ = nullptr; |
86 | *this = std::move(rhs); |
87 | } |
88 | |
89 | ~benchdnn_dnnl_wrapper_t() { do_destroy(); } |
90 | |
91 | T release() { |
92 | T tmp = t_; |
93 | t_ = nullptr; |
94 | return tmp; |
95 | } |
96 | |
97 | void reset(T t) { |
98 | do_destroy(); |
99 | t_ = t; |
100 | } |
101 | |
102 | operator T() const { return t_; } |
103 | |
104 | BENCHDNN_DISALLOW_COPY_AND_ASSIGN(benchdnn_dnnl_wrapper_t); |
105 | |
106 | private: |
107 | T t_; |
108 | |
109 | void do_destroy() { |
110 | if (t_) { dnnl_api_traits<T>::destroy(t_); } |
111 | } |
112 | }; |
113 | |
114 | // Constructs a wrapper object (providing RAII support) |
115 | template <typename T> |
116 | benchdnn_dnnl_wrapper_t<T> make_benchdnn_dnnl_wrapper(T t) { |
117 | return benchdnn_dnnl_wrapper_t<T>(t); |
118 | } |
119 | |
120 | #undef CHECK_DESTROY |
121 | |
122 | #endif |
123 | |