1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_OCL_ENGINE_HPP
18#define GPU_OCL_OCL_ENGINE_HPP
19
20#include "gpu/ocl/ocl_gpu_engine.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace ocl {
26
27class ocl_engine_factory_t : public engine_factory_t {
28public:
29 ocl_engine_factory_t(engine_kind_t engine_kind) {
30 assert(engine_kind == engine_kind::gpu);
31 MAYBE_UNUSED(engine_kind);
32 }
33
34 size_t count() const override {
35 std::vector<cl_device_id> ocl_devices;
36 status_t status = get_ocl_devices(&ocl_devices, CL_DEVICE_TYPE_GPU);
37 if (status != status::success) return status;
38 return ocl_devices.size();
39 }
40
41 status_t engine_create(engine_t **engine, size_t index) const override {
42 status_t status;
43 std::vector<cl_device_id> ocl_devices;
44
45 status = get_ocl_devices(&ocl_devices, CL_DEVICE_TYPE_GPU);
46 if (status != status::success) return status;
47
48 if (index >= ocl_devices.size()) return status::invalid_arguments;
49
50 auto *ocl_engine
51 = new ocl_gpu_engine_t(ocl_devices[index], nullptr, index);
52 if (!ocl_engine) return status::out_of_memory;
53
54 status = ocl_engine->init();
55 if (status != status::success) {
56 ocl_engine->release();
57 return status;
58 }
59 *engine = ocl_engine;
60 return status::success;
61 }
62
63 status_t engine_create(engine_t **engine, cl_device_id device,
64 cl_context context, size_t index,
65 const std::vector<uint8_t> &cache_blob = {}) {
66 auto *ocl_engine = new ocl_gpu_engine_t(device, context, index);
67 if (!ocl_engine) return status::out_of_memory;
68
69 status_t status = ocl_engine->init(cache_blob);
70 if (status != status::success) {
71 ocl_engine->release();
72 return status;
73 }
74 *engine = ocl_engine;
75 return status::success;
76 }
77};
78} // namespace ocl
79} // namespace gpu
80} // namespace impl
81} // namespace dnnl
82
83#endif // GPU_OCL_OCL_ENGINE_HPP
84