1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef DNNL_TEST_COMMON_OCL_HPP |
18 | #define DNNL_TEST_COMMON_OCL_HPP |
19 | |
20 | #include "gpu/ocl/ocl_utils.hpp" |
21 | #include "oneapi/dnnl/dnnl.hpp" |
22 | #include "oneapi/dnnl/dnnl_debug.h" |
23 | #include "oneapi/dnnl/dnnl_ocl.hpp" |
24 | #include "gtest/gtest.h" |
25 | |
26 | #include <CL/cl.h> |
27 | |
// Define a separate macro that does not clash with OCL_CHECK from the library.
//
// TEST_OCL_CHECK(x) evaluates the OpenCL call expression `x` exactly once and
// checks the status it returns:
//  - with DNNL_ENABLE_MEM_DEBUG defined, the status is converted with
//    dnnl::impl::gpu::ocl::convert_to_dnnl() and handed to
//    dnnl::error::wrap_c_api() (presumably raising dnnl::error on failure --
//    confirm against dnnl.hpp);
//  - otherwise, a status other than CL_SUCCESS is reported through gtest's
//    non-fatal EXPECT_EQ, and execution continues.
// The status local has a deliberately unusual name: with a plain `s`, a call
// such as TEST_OCL_CHECK(s) would expand to `int s = int(s)`, which reads the
// freshly declared, uninitialized local instead of the caller's variable.
#ifdef DNNL_ENABLE_MEM_DEBUG
#define TEST_OCL_CHECK(x) \
    do { \
        dnnl_status_t test_ocl_check_status_ \
                = dnnl::impl::gpu::ocl::convert_to_dnnl(x); \
        dnnl::error::wrap_c_api(test_ocl_check_status_, \
                dnnl_status2str(test_ocl_check_status_)); \
    } while (0)
#else
#define TEST_OCL_CHECK(x) \
    do { \
        int test_ocl_check_status_ = int(x); \
        EXPECT_EQ(test_ocl_check_status_, CL_SUCCESS) \
                << "OpenCL error: " << test_ocl_check_status_; \
    } while (0)
#endif
42 | |
43 | static inline cl_device_id find_ocl_device(cl_device_type dev_type) { |
44 | cl_int err; |
45 | const size_t max_platforms = 16; |
46 | |
47 | cl_uint nplatforms; |
48 | cl_platform_id ocl_platforms[max_platforms]; |
49 | err = clGetPlatformIDs(max_platforms, ocl_platforms, &nplatforms); |
50 | if (err != CL_SUCCESS) { |
51 | // OpenCL has no support on the platform. |
52 | return nullptr; |
53 | } |
54 | |
55 | for (cl_uint i = 0; i < nplatforms; ++i) { |
56 | cl_platform_id ocl_platform = ocl_platforms[i]; |
57 | |
58 | const size_t max_platform_vendor_size = 256; |
59 | std::string platform_vendor(max_platform_vendor_size + 1, 0); |
60 | TEST_OCL_CHECK(clGetPlatformInfo(ocl_platform, CL_PLATFORM_VENDOR, |
61 | max_platform_vendor_size * sizeof(char), &platform_vendor[0], |
62 | nullptr)); |
63 | cl_uint ndevices; |
64 | cl_device_id ocl_dev; |
65 | err = clGetDeviceIDs(ocl_platform, dev_type, 1, &ocl_dev, &ndevices); |
66 | if (err == CL_SUCCESS) { return ocl_dev; } |
67 | } |
68 | return nullptr; |
69 | } |
70 | |
// Base generic class providing RAII support for OpenCL objects.
//
// Holds a handle `t_`. If a release function was supplied, the wrapper calls
// it on a non-null handle when it is destroyed, or when its contents are
// replaced by move assignment. A null release function means the handle is
// not released by the wrapper.
//
// The wrapper is move-only. A moved-from wrapper holds nullptr and releases
// nothing.
//
// Template parameters:
//   T         - handle type; assumed to be a pointer-like type, since it is
//               tested for truthiness and assigned nullptr.
//   release_t - function type of the release routine. Defaults to int32_t(T),
//               matching the cl_int-returning clRelease* functions.
template <typename T, typename release_t = int32_t(T)>
struct ocl_wrapper_base_t {
    ocl_wrapper_base_t(T t, release_t *release = nullptr)
        : t_(t), release_(release) {}

    ocl_wrapper_base_t(ocl_wrapper_base_t &&other) noexcept
        : t_(other.t_), release_(other.release_) {
        other.t_ = nullptr;
    }

    // Releases the currently held handle (if any), then takes over `other`'s
    // handle and release function, leaving `other` empty.
    ocl_wrapper_base_t &operator=(ocl_wrapper_base_t &&other) noexcept {
        if (this != &other) {
            reset();
            t_ = other.t_;
            release_ = other.release_;
            other.t_ = nullptr;
        }
        return *this;
    }

    ~ocl_wrapper_base_t() { reset(); }

    ocl_wrapper_base_t(const ocl_wrapper_base_t &) = delete;
    ocl_wrapper_base_t &operator=(const ocl_wrapper_base_t &) = delete;

    operator T() const { return t_; }

private:
    // Calls the release function on the held handle when both are non-null,
    // then leaves the wrapper empty. The release status is ignored.
    void reset() {
        if (release_ && t_) { release_(t_); }
        t_ = nullptr;
    }

    T t_;
    release_t *release_;
};
93 | |
// Auxiliary RAII wrapper for OpenCL objects. The primary template is an empty
// placeholder; each supported OpenCL handle type has an explicit
// specialization that forwards the handle (and its release routine, if any)
// to ocl_wrapper_base_t.
template <class T>
struct ocl_wrapper_t {};
98 | |
99 | template <> |
100 | struct ocl_wrapper_t<cl_device_id> : ocl_wrapper_base_t<cl_device_id> { |
101 | ocl_wrapper_t(cl_device_id dev) : ocl_wrapper_base_t(dev) {} |
102 | }; |
103 | |
104 | template <> |
105 | struct ocl_wrapper_t<cl_context> : ocl_wrapper_base_t<cl_context> { |
106 | ocl_wrapper_t(cl_context ctx) |
107 | : ocl_wrapper_base_t(ctx, &clReleaseContext) {} |
108 | }; |
109 | |
110 | template <> |
111 | struct ocl_wrapper_t<cl_command_queue> : ocl_wrapper_base_t<cl_command_queue> { |
112 | ocl_wrapper_t(cl_command_queue queue) |
113 | : ocl_wrapper_base_t(queue, &clReleaseCommandQueue) {} |
114 | }; |
115 | |
116 | template <> |
117 | struct ocl_wrapper_t<cl_kernel> : ocl_wrapper_base_t<cl_kernel> { |
118 | ocl_wrapper_t(cl_kernel kernel) |
119 | : ocl_wrapper_base_t(kernel, &clReleaseKernel) {} |
120 | }; |
121 | |
122 | // Constructs an OpenCL wrapper object (providing RAII support) |
123 | template <typename T> |
124 | ocl_wrapper_t<T> make_ocl_wrapper(T t) { |
125 | return ocl_wrapper_t<T>(t); |
126 | } |
127 | |
128 | #endif |
129 | |