1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_OCL_ENGINE_HPP |
18 | #define GPU_OCL_OCL_ENGINE_HPP |
19 | |
20 | #include "gpu/ocl/ocl_gpu_engine.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace ocl { |
26 | |
27 | class ocl_engine_factory_t : public engine_factory_t { |
28 | public: |
29 | ocl_engine_factory_t(engine_kind_t engine_kind) { |
30 | assert(engine_kind == engine_kind::gpu); |
31 | MAYBE_UNUSED(engine_kind); |
32 | } |
33 | |
34 | size_t count() const override { |
35 | std::vector<cl_device_id> ocl_devices; |
36 | status_t status = get_ocl_devices(&ocl_devices, CL_DEVICE_TYPE_GPU); |
37 | if (status != status::success) return status; |
38 | return ocl_devices.size(); |
39 | } |
40 | |
41 | status_t engine_create(engine_t **engine, size_t index) const override { |
42 | status_t status; |
43 | std::vector<cl_device_id> ocl_devices; |
44 | |
45 | status = get_ocl_devices(&ocl_devices, CL_DEVICE_TYPE_GPU); |
46 | if (status != status::success) return status; |
47 | |
48 | if (index >= ocl_devices.size()) return status::invalid_arguments; |
49 | |
50 | auto *ocl_engine |
51 | = new ocl_gpu_engine_t(ocl_devices[index], nullptr, index); |
52 | if (!ocl_engine) return status::out_of_memory; |
53 | |
54 | status = ocl_engine->init(); |
55 | if (status != status::success) { |
56 | ocl_engine->release(); |
57 | return status; |
58 | } |
59 | *engine = ocl_engine; |
60 | return status::success; |
61 | } |
62 | |
63 | status_t engine_create(engine_t **engine, cl_device_id device, |
64 | cl_context context, size_t index, |
65 | const std::vector<uint8_t> &cache_blob = {}) { |
66 | auto *ocl_engine = new ocl_gpu_engine_t(device, context, index); |
67 | if (!ocl_engine) return status::out_of_memory; |
68 | |
69 | status_t status = ocl_engine->init(cache_blob); |
70 | if (status != status::success) { |
71 | ocl_engine->release(); |
72 | return status; |
73 | } |
74 | *engine = ocl_engine; |
75 | return status::success; |
76 | } |
77 | }; |
78 | } // namespace ocl |
79 | } // namespace gpu |
80 | } // namespace impl |
81 | } // namespace dnnl |
82 | |
83 | #endif // GPU_OCL_OCL_ENGINE_HPP |
84 | |