1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef DNNL_TEST_COMMON_OCL_HPP
18#define DNNL_TEST_COMMON_OCL_HPP
19
20#include "gpu/ocl/ocl_utils.hpp"
21#include "oneapi/dnnl/dnnl.hpp"
22#include "oneapi/dnnl/dnnl_debug.h"
23#include "oneapi/dnnl/dnnl_ocl.hpp"
24#include "gtest/gtest.h"
25
26#include <CL/cl.h>
27
// Define a separate macro that does not clash with OCL_CHECK from the library.
//
// TEST_OCL_CHECK(x) evaluates the OpenCL expression `x` exactly once and
// checks the resulting status:
//  - with DNNL_ENABLE_MEM_DEBUG, the status is converted to dnnl_status_t and
//    handed to dnnl::error::wrap_c_api() together with its string name
//    (presumably raising dnnl::error on failure -- confirm in dnnl.hpp);
//  - otherwise, a status other than CL_SUCCESS is reported as a non-fatal
//    gtest failure (EXPECT_EQ) and execution continues.
//
// The macro-local status variable deliberately has an unusual name: with a
// short name such as `s`, an argument that mentions a caller variable `s`
// (e.g. TEST_OCL_CHECK(f(s))) would read the macro's own, still uninitialized
// variable, because a declared name is in scope inside its own initializer.
#ifdef DNNL_ENABLE_MEM_DEBUG
#define TEST_OCL_CHECK(x) \
    do { \
        dnnl_status_t test_ocl_check_status_ \
                = dnnl::impl::gpu::ocl::convert_to_dnnl(x); \
        dnnl::error::wrap_c_api(test_ocl_check_status_, \
                dnnl_status2str(test_ocl_check_status_)); \
    } while (0)
#else
#define TEST_OCL_CHECK(x) \
    do { \
        int test_ocl_check_status_ = static_cast<int>(x); \
        EXPECT_EQ(test_ocl_check_status_, CL_SUCCESS) \
                << "OpenCL error: " << test_ocl_check_status_; \
    } while (0)
#endif
42
43static inline cl_device_id find_ocl_device(cl_device_type dev_type) {
44 cl_int err;
45 const size_t max_platforms = 16;
46
47 cl_uint nplatforms;
48 cl_platform_id ocl_platforms[max_platforms];
49 err = clGetPlatformIDs(max_platforms, ocl_platforms, &nplatforms);
50 if (err != CL_SUCCESS) {
51 // OpenCL has no support on the platform.
52 return nullptr;
53 }
54
55 for (cl_uint i = 0; i < nplatforms; ++i) {
56 cl_platform_id ocl_platform = ocl_platforms[i];
57
58 const size_t max_platform_vendor_size = 256;
59 std::string platform_vendor(max_platform_vendor_size + 1, 0);
60 TEST_OCL_CHECK(clGetPlatformInfo(ocl_platform, CL_PLATFORM_VENDOR,
61 max_platform_vendor_size * sizeof(char), &platform_vendor[0],
62 nullptr));
63 cl_uint ndevices;
64 cl_device_id ocl_dev;
65 err = clGetDeviceIDs(ocl_platform, dev_type, 1, &ocl_dev, &ndevices);
66 if (err == CL_SUCCESS) { return ocl_dev; }
67 }
68 return nullptr;
69}
70
// Base generic class providing RAII support for OpenCL objects.
//
// Holds a handle `t_` and an optional release function `release_`. On
// destruction, `release_(t_)` is called when both are non-null; its return
// status is ignored. The wrapper is move-only: moving transfers the handle and
// nulls it in the source, so each handle is released at most once.
template <typename T, typename release_t = int32_t(T)>
struct ocl_wrapper_base_t {
    // `release` may be nullptr for handles the wrapper must never release.
    ocl_wrapper_base_t(T t, release_t *release = nullptr)
        : t_(t), release_(release) {}

    // noexcept: only copies a handle and a function pointer. This lets
    // std::move_if_noexcept-based code (e.g. vector growth) move wrappers.
    ocl_wrapper_base_t(ocl_wrapper_base_t &&other) noexcept
        : t_(other.t_), release_(other.release_) {
        other.t_ = nullptr;
    }

    // Releases the currently held handle (if any), then takes over `other`'s
    // handle and release function. Self-move-assignment is a no-op.
    ocl_wrapper_base_t &operator=(ocl_wrapper_base_t &&other) noexcept {
        if (this != &other) {
            if (release_ && t_) { release_(t_); }
            t_ = other.t_;
            release_ = other.release_;
            other.t_ = nullptr;
        }
        return *this;
    }

    ~ocl_wrapper_base_t() {
        if (release_ && t_) { release_(t_); }
    }

    ocl_wrapper_base_t(const ocl_wrapper_base_t &) = delete;
    ocl_wrapper_base_t &operator=(const ocl_wrapper_base_t &) = delete;

    // Implicit conversion so the wrapper can be passed directly to APIs that
    // take the raw handle; the wrapper keeps ownership.
    operator T() const { return t_; }

private:
    T t_;
    release_t *release_;
};
93
// Primary template for the RAII wrappers of OpenCL handles. It is left as an
// empty struct on purpose: actual wrapping behavior is provided only by the
// explicit specializations for individual OpenCL handle types.
template <class T>
struct ocl_wrapper_t {
};
98
99template <>
100struct ocl_wrapper_t<cl_device_id> : ocl_wrapper_base_t<cl_device_id> {
101 ocl_wrapper_t(cl_device_id dev) : ocl_wrapper_base_t(dev) {}
102};
103
104template <>
105struct ocl_wrapper_t<cl_context> : ocl_wrapper_base_t<cl_context> {
106 ocl_wrapper_t(cl_context ctx)
107 : ocl_wrapper_base_t(ctx, &clReleaseContext) {}
108};
109
110template <>
111struct ocl_wrapper_t<cl_command_queue> : ocl_wrapper_base_t<cl_command_queue> {
112 ocl_wrapper_t(cl_command_queue queue)
113 : ocl_wrapper_base_t(queue, &clReleaseCommandQueue) {}
114};
115
116template <>
117struct ocl_wrapper_t<cl_kernel> : ocl_wrapper_base_t<cl_kernel> {
118 ocl_wrapper_t(cl_kernel kernel)
119 : ocl_wrapper_base_t(kernel, &clReleaseKernel) {}
120};
121
122// Constructs an OpenCL wrapper object (providing RAII support)
123template <typename T>
124ocl_wrapper_t<T> make_ocl_wrapper(T t) {
125 return ocl_wrapper_t<T>(t);
126}
127
128#endif
129