1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_WRAPPER_HPP
18#define UTILS_WRAPPER_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "common.hpp"
23#include "dnnl_debug.hpp"
24
// Evaluates `f`, an expression yielding a dnnl_status_t (here: a DNNL
// *_destroy() call). On anything other than dnnl_success it prints the
// enclosing function signature, line number, the stringified expression and
// the status (name and numeric value), flushes all output streams
// (fflush(0)), and terminates the process with exit code 2.
// The macro is #undef'd at the end of this header, so it is only meant for
// the dnnl_api_traits specializations below.
#define CHECK_DESTROY(f) \
    do { \
        const dnnl_status_t status__ = f; \
        if (status__ == dnnl_success) break; \
        BENCHDNN_PRINT(0, "error [%s:%d]: '%s' -> %s(%d)\n", \
                __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), \
                status2str(status__), (int)status__); \
        fflush(0); \
        exit(2); \
    } while (0)
36
// Per-handle-type destruction policy consumed by benchdnn_dnnl_wrapper_t.
// The primary template is intentionally declared but never defined: every
// supported handle type provides an explicit specialization exposing
// `static void destroy(T)`. Wrapping an unsupported type therefore fails at
// compile time (use of an incomplete type) instead of leaking at run time.
template <typename T>
struct dnnl_api_traits;
42
// Destruction policy for primitive handles: releases `t` through
// dnnl_primitive_destroy(). A non-success status makes CHECK_DESTROY print an
// error and call exit(2).
// NOTE: CHECK_DESTROY stringifies its argument into the error message, so the
// exact call text (including the parameter name `t`) is user-visible output.
template <>
struct dnnl_api_traits<dnnl_primitive_t> {
    static void destroy(dnnl_primitive_t t) {
        CHECK_DESTROY(dnnl_primitive_destroy(t));
    }
};
49
// Destruction policy for primitive descriptor handles: releases `t` through
// dnnl_primitive_desc_destroy(). A non-success status makes CHECK_DESTROY
// print an error and call exit(2).
// NOTE: CHECK_DESTROY stringifies its argument into the error message, so the
// exact call text (including the parameter name `t`) is user-visible output.
template <>
struct dnnl_api_traits<dnnl_primitive_desc_t> {
    static void destroy(dnnl_primitive_desc_t t) {
        CHECK_DESTROY(dnnl_primitive_desc_destroy(t));
    }
};
56
// Destruction policy for primitive attribute handles: releases `t` through
// dnnl_primitive_attr_destroy(). A non-success status makes CHECK_DESTROY
// print an error and call exit(2).
// NOTE: CHECK_DESTROY stringifies its argument into the error message, so the
// exact call text (including the parameter name `t`) is user-visible output.
template <>
struct dnnl_api_traits<dnnl_primitive_attr_t> {
    static void destroy(dnnl_primitive_attr_t t) {
        CHECK_DESTROY(dnnl_primitive_attr_destroy(t));
    }
};
63
// Destruction policy for memory descriptor handles: releases `t` through
// dnnl_memory_desc_destroy(). A non-success status makes CHECK_DESTROY print
// an error and call exit(2).
// NOTE: CHECK_DESTROY stringifies its argument into the error message, so the
// exact call text (including the parameter name `t`) is user-visible output.
template <>
struct dnnl_api_traits<dnnl_memory_desc_t> {
    static void destroy(dnnl_memory_desc_t t) {
        CHECK_DESTROY(dnnl_memory_desc_destroy(t));
    }
};
70
71// Generic class providing RAII support for DNNL objects in benchdnn
72template <typename T>
73struct benchdnn_dnnl_wrapper_t {
74 benchdnn_dnnl_wrapper_t(T t = nullptr) : t_(t) {
75 static_assert(std::is_pointer<T>::value, "T is not a pointer type.");
76 }
77
78 benchdnn_dnnl_wrapper_t &operator=(benchdnn_dnnl_wrapper_t &&rhs) {
79 if (this == &rhs) return *this;
80 reset(rhs.release());
81 return *this;
82 }
83
84 benchdnn_dnnl_wrapper_t(benchdnn_dnnl_wrapper_t &&rhs) {
85 t_ = nullptr;
86 *this = std::move(rhs);
87 }
88
89 ~benchdnn_dnnl_wrapper_t() { do_destroy(); }
90
91 T release() {
92 T tmp = t_;
93 t_ = nullptr;
94 return tmp;
95 }
96
97 void reset(T t) {
98 do_destroy();
99 t_ = t;
100 }
101
102 operator T() const { return t_; }
103
104 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(benchdnn_dnnl_wrapper_t);
105
106private:
107 T t_;
108
109 void do_destroy() {
110 if (t_) { dnnl_api_traits<T>::destroy(t_); }
111 }
112};
113
114// Constructs a wrapper object (providing RAII support)
115template <typename T>
116benchdnn_dnnl_wrapper_t<T> make_benchdnn_dnnl_wrapper(T t) {
117 return benchdnn_dnnl_wrapper_t<T>(t);
118}
119
120#undef CHECK_DESTROY
121
122#endif
123