1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_OCL_USM_MEMORY_STORAGE_HPP
18#define GPU_OCL_OCL_USM_MEMORY_STORAGE_HPP
19
20#include <CL/cl.h>
21
22#include <functional>
23
24#include "common/c_types_map.hpp"
25#include "common/memory_storage.hpp"
26#include "common/utils.hpp"
27#include "gpu/ocl/ocl_gpu_engine.hpp"
28#include "gpu/ocl/ocl_memory_storage_base.hpp"
29#include "gpu/ocl/ocl_usm_utils.hpp"
30#include "gpu/ocl/ocl_utils.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace ocl {
36
37class ocl_usm_memory_storage_t : public ocl_memory_storage_base_t {
38public:
39 using ocl_memory_storage_base_t::ocl_memory_storage_base_t;
40
41 void *usm_ptr() const { return usm_ptr_.get(); }
42
43 memory_kind_t memory_kind() const override { return memory_kind::usm; }
44
45 status_t get_data_handle(void **handle) const override {
46 *handle = usm_ptr_.get();
47 return status::success;
48 }
49
50 status_t set_data_handle(void *handle) override {
51 usm_ptr_ = decltype(usm_ptr_)(handle, [](void *) {});
52 usm_kind_ = usm::get_pointer_type(engine(), handle);
53 return status::success;
54 }
55
56 status_t map_data(
57 void **mapped_ptr, stream_t *stream, size_t size) const override;
58 status_t unmap_data(void *mapped_ptr, stream_t *stream) const override;
59
60 bool is_host_accessible() const override {
61 return utils::one_of(usm_kind_, usm::ocl_usm_kind_t::host,
62 usm::ocl_usm_kind_t::shared, usm::ocl_usm_kind_t::unknown);
63 }
64
65 std::unique_ptr<memory_storage_t> get_sub_storage(
66 size_t offset, size_t size) const override;
67 std::unique_ptr<memory_storage_t> clone() const override;
68
69protected:
70 status_t init_allocate(size_t size) override {
71 usm_kind_ = usm::ocl_usm_kind_t::shared;
72 void *usm_ptr_alloc = usm::malloc_shared(engine(), size);
73 if (!usm_ptr_alloc) return status::out_of_memory;
74
75 usm_ptr_ = decltype(usm_ptr_)(
76 usm_ptr_alloc, [&](void *ptr) { usm::free(engine(), ptr); });
77 return status::success;
78 }
79
80private:
81 std::unique_ptr<void, std::function<void(void *)>> usm_ptr_;
82 usm::ocl_usm_kind_t usm_kind_ = usm::ocl_usm_kind_t::unknown;
83
84 DNNL_DISALLOW_COPY_AND_ASSIGN(ocl_usm_memory_storage_t);
85};
86} // namespace ocl
87} // namespace gpu
88} // namespace impl
89} // namespace dnnl
90
91#endif
92