1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_COMPUTE_PROGRAM_LIST_HPP
18#define GPU_COMPUTE_PROGRAM_LIST_HPP
19
#include <cassert>
#include <functional>
#include <type_traits>
#include <unordered_map>

#include "gpu/compute/compute_engine.hpp"
#include "gpu/compute/utils.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace compute {
30
31class program_list_t {
32public:
33 program_list_t(engine_t *engine) {
34 auto compute_engine = utils::downcast<compute_engine_t *>(engine);
35 deleter_ = compute_engine->get_program_list_deleter();
36 }
37
38 void add(const binary_t *binary, void *program) {
39 assert(programs_.count(binary) == 0);
40 auto it = programs_.insert({binary, program});
41 assert(it.second);
42 MAYBE_UNUSED(it);
43 }
44
45 template <typename program_t>
46 program_t get(const binary_t *binary) const {
47 static_assert(std::is_pointer<program_t>::value,
48 "program_t is expected to be a pointer.");
49
50 auto it = programs_.find(binary);
51 if (it == programs_.end()) return nullptr;
52 return reinterpret_cast<program_t>(it->second);
53 }
54
55 ~program_list_t() {
56 assert(deleter_);
57 for (const auto &p : programs_)
58 deleter_(p.second);
59 }
60
61private:
62 program_list_t() = delete;
63 DNNL_DISALLOW_COPY_AND_ASSIGN(program_list_t);
64
65 std::function<void(void *)> deleter_;
66 std::unordered_map<const binary_t *, void *> programs_;
67};
68
69} // namespace compute
70} // namespace gpu
71} // namespace impl
72} // namespace dnnl
73
74#endif // GPU_COMPUTE_PROGRAM_LIST_HPP
75