1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_SUPPORT_REGISTER_H |
17 | #define GLOW_SUPPORT_REGISTER_H |
18 | |
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
22 | |
23 | namespace glow { |
24 | |
25 | /// Base factory interface which needs to be implemented |
26 | /// for static registration of arbitrary classes. |
27 | /// For example, CPUFactory would be responsible for creating CPU backends |
28 | /// registred with "CPU" key. |
29 | template <class Key, class Base> class BaseFactory { |
30 | public: |
31 | virtual ~BaseFactory(); |
32 | |
33 | /// Create an object of Base type. |
34 | virtual Base *create() = 0; |
35 | /// Key used for a registered factory. |
36 | virtual Key getRegistrationKey() const = 0; |
37 | /// Number of devices available for the registered factory. |
38 | virtual unsigned numDevices() const = 0; |
39 | /// Scan devices available and return their ids. |
40 | virtual std::vector<unsigned> scanDeviceIDs() const = 0; |
41 | }; |
42 | |
43 | /// General registry for implementation factories. |
44 | /// The registry is templated by the Key class and Base class that a |
45 | /// set of factories inherits from. |
46 | template <class Key, class Base> class FactoryRegistry { |
47 | public: |
48 | using FactoryMap = std::map<Key, BaseFactory<Key, Base> *>; |
49 | |
50 | /// Register \p factory in a static map. |
51 | static void registerFactory(BaseFactory<Key, Base> &factory) { |
52 | Key registrationKey = factory.getRegistrationKey(); |
53 | assert(findRegistration(factory) == factories().end() && |
54 | "Double registration of base factory" ); |
55 | auto inserted = factories().emplace(registrationKey, &factory); |
56 | assert(inserted.second && |
57 | "Double registration of a factory with the same key" ); |
58 | (void)inserted; |
59 | } |
60 | |
61 | static void unregisterFactory(BaseFactory<Key, Base> &factory) { |
62 | auto registration = findRegistration(factory); |
63 | assert(registration != factories().end() && |
64 | "Could not unregister a base factory" ); |
65 | factories().erase(registration); |
66 | } |
67 | |
68 | /// \returns newly created object from factory keyed by \p key. |
69 | /// \returns nullptr if there is no factory registered with \p key. |
70 | static Base *get(const Key &key) { |
71 | auto it = factories().find(key); |
72 | |
73 | if (it == factories().end()) { |
74 | return nullptr; |
75 | } |
76 | |
77 | return it->second->create(); |
78 | } |
79 | |
80 | /// \returns all registered factories. |
81 | static FactoryMap &factories() { |
82 | static FactoryMap *factories = new FactoryMap(); |
83 | return *factories; |
84 | } |
85 | |
86 | private: |
87 | /// Find a registration of the given factory. |
88 | /// \returns iterator referring to the found registration or factories::end() |
89 | /// if nothing was found. |
90 | static typename FactoryMap::iterator |
91 | findRegistration(BaseFactory<Key, Base> &factory) { |
92 | // Unfortunately, factory.getRegistrationKey() cannot be used here as it is |
93 | // a virtual function and findRegistration could be invoked from |
94 | // destructors, which are not supposed to invoke any virtual functions. |
95 | // Therefore find the factory registration using the address of the factory. |
96 | for (auto it = factories().begin(), e = factories().end(); it != e; ++it) { |
97 | if (it->second != &factory) { |
98 | continue; |
99 | } |
100 | return it; |
101 | } |
102 | return factories().end(); |
103 | } |
104 | }; |
105 | |
106 | template <class Key, class Base> BaseFactory<Key, Base>::~BaseFactory() { |
107 | FactoryRegistry<Key, Base>::unregisterFactory(*this); |
108 | } |
109 | |
110 | /// Factory registration template, all static registration should be done |
111 | /// via RegisterFactory. It allows to register specific implementation factory |
112 | /// with the FactoryRegistry by instantiating this templated class with the |
113 | /// specific factory class, specific Key class and the general Base class. |
114 | /// |
115 | /// Example registration: |
116 | /// static Registry::RegisterFactory< |
117 | /// SpecificKeyType, SpecificFactory, BaseFactory> registered_; |
118 | template <class Key, class Factory, class Base> class RegisterFactory { |
119 | public: |
120 | RegisterFactory() { FactoryRegistry<Key, Base>::registerFactory(factory_); } |
121 | |
122 | private: |
123 | Factory factory_{}; |
124 | }; |
125 | |
126 | } // namespace glow |
127 | |
128 | #endif // GLOW_SUPPORT_REGISTER_H |
129 | |