/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_OCL_GPU_ENGINE_HPP
#define GPU_OCL_OCL_GPU_ENGINE_HPP

#include "oneapi/dnnl/dnnl.h"

#include "common/c_types_map.hpp"
#include "common/engine.hpp"
#include "common/stream.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_impl_list.hpp"
#include "gpu/ocl/ocl_gpu_engine_id.hpp"
#include "gpu/ocl/ocl_gpu_kernel.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

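// Engine implementation for the OpenCL GPU runtime. It wraps a cl_device_id /
// cl_context pair and creates the OpenCL-backed streams, memory storages, and
// kernels used by the GPU primitives. The constructor only stores the handles;
// the engine becomes usable after init() (or the cache-blob overload) succeeds.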
class ocl_gpu_engine_t : public compute::compute_engine_t {
public:
    ocl_gpu_engine_t(cl_device_id adevice, cl_context acontext, size_t index)
        : compute::compute_engine_t(engine_kind::gpu, runtime_kind::ocl, index)
        , device_(adevice)
        , context_(acontext)
        , is_user_context_(acontext) {}

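    // Two-phase initialization; the cache-blob overload initializes the engine
    // from a blob previously captured via get_cache_blob() below.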
    status_t init() override;
    status_t init(const std::vector<uint8_t> &cache_blob);

    status_t create_memory_storage(memory_storage_t **storage, unsigned flags,
            size_t size, void *handle) override;

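    // Stream creation; the second overload wraps an existing cl_command_queue
    // supplied by the caller instead of creating a new one.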
    status_t create_stream(stream_t **stream, unsigned flags) override;
    status_t create_stream(stream_t **stream, cl_command_queue queue);

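    // Kernel creation: from a binary generator (jitter) and/or a cache blob,
    // from a generated kernel context, or directly from OpenCL C source.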
    status_t create_kernel(compute::kernel_t *kernel,
            jit::jit_generator_base *jitter,
            cache_blob_t cache_blob) const override;

    status_t create_kernels(std::vector<compute::kernel_t> *kernels,
            const std::vector<const char *> &kernel_names,
            const compute::kernel_ctx_t &kernel_ctx,
            cache_blob_t cache_blob) const override;

    status_t create_kernels_from_ocl_source(
            std::vector<compute::kernel_t> *kernels,
            const std::vector<const char *> &kernel_names,
            const char *source_string,
            const compute::kernel_ctx_t &kernel_ctx) const override;

    std::function<void(void *)> get_program_list_deleter() const override;

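    // Implementation lists consulted when dispatching concat, reorder, sum,
    // and regular primitives on this engine.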
    const impl_list_item_t *get_concat_implementation_list() const override {
        return gpu_impl_list_t::get_concat_implementation_list();
    }

    const impl_list_item_t *get_reorder_implementation_list(
            const memory_desc_t *src_md,
            const memory_desc_t *dst_md) const override {
        return gpu_impl_list_t::get_reorder_implementation_list(src_md, dst_md);
    }

    const impl_list_item_t *get_sum_implementation_list() const override {
        return gpu_impl_list_t::get_sum_implementation_list();
    }

    const impl_list_item_t *get_implementation_list(
            const op_desc_t *desc) const override {
        return gpu_impl_list_t::get_implementation_list(desc);
    }

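    // Raw OpenCL handles backing this engine.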
    cl_device_id device() const { return device_; }
    cl_context context() const { return context_; }
    cl_platform_id platform() const { return platform_; }

    device_id_t device_id() const override {
        return std::make_tuple(0, reinterpret_cast<uint64_t>(device()), 0);
    }

    status_t serialize_device(serialization_stream_t &sstream) const override;

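    // Cache-blob queries, forwarded to the device info object.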
    status_t get_cache_blob_size(size_t *size) const {
        return device_info_->get_cache_blob_size(size);
    }

    status_t get_cache_blob(size_t size, uint8_t *cache_blob) const {
        return device_info_->get_cache_blob(size, cache_blob);
    }

    engine_id_t engine_id() const override {
        return engine_id_t(new ocl_gpu_engine_id_impl_t(
                device(), context(), kind(), runtime_kind(), index()));
    }

protected:
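    // The destructor releases the OpenCL device and context handles.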
    ~ocl_gpu_engine_t() override {
        if (device_) { clReleaseDevice(device_); }
        if (context_) { clReleaseContext(context_); }
    }

protected:
    status_t init_device_info() override;
    status_t init_device_info(const std::vector<uint8_t> &cache_blob) override;

private:
    cl_device_id device_;
    cl_context context_;
    cl_platform_id platform_ = nullptr;
    bool is_user_context_; // true when a pre-existing context was passed in

};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif