1 | #ifndef C10_UTIL_REGISTRY_H_ |
2 | #define C10_UTIL_REGISTRY_H_ |
3 | |
4 | /** |
5 | * Simple registry implementation that uses static variables to |
6 | * register object creators during program initialization time. |
7 | */ |
8 | |
// NB: This Registry works poorly when you have other namespaces: the
// C10_REGISTER_* macros refer to `RegistryName()` and
// `Registerer##RegistryName` by unqualified name. Make all macro invocations
// from inside the `at` namespace.
11 | |
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
21 | |
22 | #include <c10/macros/Macros.h> |
23 | #include <c10/util/Type.h> |
24 | |
25 | namespace c10 { |
26 | |
/**
 * Renders a registry key for diagnostic messages.
 *
 * The generic version cannot print arbitrary key types, so it returns a fixed
 * placeholder string regardless of the key's value.
 */
template <typename KeyType>
inline std::string KeyStrRepr(const KeyType& /*key*/) {
  const char* const placeholder = "[key type printing not supported]";
  return std::string(placeholder);
}

/// std::string keys are printable as-is: return a copy of the key itself.
template <>
inline std::string KeyStrRepr<std::string>(const std::string& key) {
  std::string printable(key);
  return printable;
}
36 | |
/**
 * Priority attached to each registration. When two creators are registered
 * under the same key, the one with the numerically larger priority wins;
 * equal priorities are treated as a duplicate-registration error.
 */
enum RegistryPriority {
  // Lowest priority: replaced by any later default/preferred registration.
  REGISTRY_FALLBACK = 1,
  // Priority used when none is specified.
  REGISTRY_DEFAULT = REGISTRY_FALLBACK + 1,
  // Highest priority: overrides fallback and default registrations.
  REGISTRY_PREFERRED = REGISTRY_DEFAULT + 1,
};
42 | |
43 | /** |
44 | * @brief A template class that allows one to register classes by keys. |
45 | * |
46 | * The keys are usually a std::string specifying the name, but can be anything |
47 | * that can be used in a std::map. |
48 | * |
49 | * You should most likely not use the Registry class explicitly, but use the |
50 | * helper macros below to declare specific registries as well as registering |
51 | * objects. |
52 | */ |
53 | template <class SrcType, class ObjectPtrType, class... Args> |
54 | class Registry { |
55 | public: |
56 | typedef std::function<ObjectPtrType(Args...)> Creator; |
57 | |
58 | Registry(bool warning = true) |
59 | : registry_(), priority_(), terminate_(true), warning_(warning) {} |
60 | |
61 | void Register( |
62 | const SrcType& key, |
63 | Creator creator, |
64 | const RegistryPriority priority = REGISTRY_DEFAULT) { |
65 | std::lock_guard<std::mutex> lock(register_mutex_); |
66 | // The if statement below is essentially the same as the following line: |
67 | // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key |
68 | // << " registered twice."; |
69 | // However, TORCH_CHECK_EQ depends on google logging, and since registration |
70 | // is carried out at static initialization time, we do not want to have an |
71 | // explicit dependency on glog's initialization function. |
72 | if (registry_.count(key) != 0) { |
73 | auto cur_priority = priority_[key]; |
74 | if (priority > cur_priority) { |
75 | #ifdef DEBUG |
76 | std::string warn_msg = |
77 | "Overwriting already registered item for key " + KeyStrRepr(key); |
78 | fprintf(stderr, "%s\n" , warn_msg.c_str()); |
79 | #endif |
80 | registry_[key] = creator; |
81 | priority_[key] = priority; |
82 | } else if (priority == cur_priority) { |
83 | std::string err_msg = |
84 | "Key already registered with the same priority: " + KeyStrRepr(key); |
85 | fprintf(stderr, "%s\n" , err_msg.c_str()); |
86 | if (terminate_) { |
87 | std::exit(1); |
88 | } else { |
89 | throw std::runtime_error(err_msg); |
90 | } |
91 | } else if (warning_) { |
92 | std::string warn_msg = |
93 | "Higher priority item already registered, skipping registration of " + |
94 | KeyStrRepr(key); |
95 | fprintf(stderr, "%s\n" , warn_msg.c_str()); |
96 | } |
97 | } else { |
98 | registry_[key] = creator; |
99 | priority_[key] = priority; |
100 | } |
101 | } |
102 | |
103 | void Register( |
104 | const SrcType& key, |
105 | Creator creator, |
106 | const std::string& help_msg, |
107 | const RegistryPriority priority = REGISTRY_DEFAULT) { |
108 | Register(key, creator, priority); |
109 | help_message_[key] = help_msg; |
110 | } |
111 | |
112 | inline bool Has(const SrcType& key) { |
113 | return (registry_.count(key) != 0); |
114 | } |
115 | |
116 | ObjectPtrType Create(const SrcType& key, Args... args) { |
117 | auto it = registry_.find(key); |
118 | if (it == registry_.end()) { |
119 | // Returns nullptr if the key is not registered. |
120 | return nullptr; |
121 | } |
122 | return it->second(args...); |
123 | } |
124 | |
125 | /** |
126 | * Returns the keys currently registered as a std::vector. |
127 | */ |
128 | std::vector<SrcType> Keys() const { |
129 | std::vector<SrcType> keys; |
130 | keys.reserve(registry_.size()); |
131 | for (const auto& it : registry_) { |
132 | keys.push_back(it.first); |
133 | } |
134 | return keys; |
135 | } |
136 | |
137 | inline const std::unordered_map<SrcType, std::string>& HelpMessage() const { |
138 | return help_message_; |
139 | } |
140 | |
141 | const char* HelpMessage(const SrcType& key) const { |
142 | auto it = help_message_.find(key); |
143 | if (it == help_message_.end()) { |
144 | return nullptr; |
145 | } |
146 | return it->second.c_str(); |
147 | } |
148 | |
149 | // Used for testing, if terminate is unset, Registry throws instead of |
150 | // calling std::exit |
151 | void SetTerminate(bool terminate) { |
152 | terminate_ = terminate; |
153 | } |
154 | |
155 | private: |
156 | std::unordered_map<SrcType, Creator> registry_; |
157 | std::unordered_map<SrcType, RegistryPriority> priority_; |
158 | bool terminate_; |
159 | const bool warning_; |
160 | std::unordered_map<SrcType, std::string> help_message_; |
161 | std::mutex register_mutex_; |
162 | |
163 | C10_DISABLE_COPY_AND_ASSIGN(Registry); |
164 | }; |
165 | |
166 | template <class SrcType, class ObjectPtrType, class... Args> |
167 | class Registerer { |
168 | public: |
169 | explicit Registerer( |
170 | const SrcType& key, |
171 | Registry<SrcType, ObjectPtrType, Args...>* registry, |
172 | typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator, |
173 | const std::string& help_msg = "" ) { |
174 | registry->Register(key, creator, help_msg); |
175 | } |
176 | |
177 | explicit Registerer( |
178 | const SrcType& key, |
179 | const RegistryPriority priority, |
180 | Registry<SrcType, ObjectPtrType, Args...>* registry, |
181 | typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator, |
182 | const std::string& help_msg = "" ) { |
183 | registry->Register(key, creator, help_msg, priority); |
184 | } |
185 | |
186 | template <class DerivedType> |
187 | static ObjectPtrType DefaultCreator(Args... args) { |
188 | return ObjectPtrType(new DerivedType(args...)); |
189 | } |
190 | }; |
191 | |
/**
 * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
 * declaration, as well as creating a convenient typename for its corresponding
 * registerer.
 *
 * It expands to:
 *   - a declaration of `RegistryName()`, returning a pointer to
 *     `::c10::Registry<SrcType, PtrType<ObjectType>, ...>`, where the trailing
 *     `...` are the argument types that registered creators accept; and
 *   - a typedef `Registerer##RegistryName` for the matching
 *     `::c10::Registerer`, which the C10_REGISTER_* macros use by
 *     unqualified name.
 * The expansion has no trailing semicolon; the invocation supplies it.
 */
// Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
// as import and DEFINE as export, because these registry macros will be used
// in downstream shared libraries as well, and one cannot use *_API - the API
// macro will be defined on a per-shared-library basis. Semantically, when one
// declares a typed registry it is always going to be IMPORT, and when one
// defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
// the instantiation unit is always going to be exported.
//
// The only unique condition is when in the same file one does DECLARE and
// DEFINE - in Windows compilers, this generates a warning that dllimport and
// dllexport are mixed, but the warning is fine and the linker will properly
// export the symbol. Same thing happens in the gflags flag declaration and
// definition cases.
#define C10_DECLARE_TYPED_REGISTRY(                                        \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  C10_IMPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  RegistryName();                                                          \
  typedef ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>   \
      Registerer##RegistryName
216 | |
// C10_DEFINE_TYPED_REGISTRY defines the `RegistryName()` function declared by
// C10_DECLARE_TYPED_REGISTRY. The Registry (constructed with its default
// warning = true) is heap-allocated on the first call and cached in a
// function-local static pointer; it is never deleted. NOTE(review): the leak
// is presumably intentional so the registry stays valid during static
// destruction — confirm. Invoke this ONLY ONCE and ONLY IN A SOURCE FILE.
#define C10_DEFINE_TYPED_REGISTRY(                                         \
    RegistryName, SrcType, ObjectType, PtrType, ...)                       \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>* \
  RegistryName() {                                                         \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*   \
        registry = new ::c10::                                             \
            Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>();       \
    return registry;                                                       \
  }
226 | |
// Same as C10_DEFINE_TYPED_REGISTRY, but constructs the Registry with
// warning = false. A registration that is skipped because a higher-priority
// creator already exists is then not reported on stderr.
#define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                            \
    RegistryName, SrcType, ObjectType, PtrType, ...)                          \
  C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
  RegistryName() {                                                            \
    static ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*      \
        registry =                                                            \
            new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>( \
                false);                                                       \
    return registry;                                                          \
  }
237 | |
// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
// creator with a comma in its template arguments.
//
// C10_REGISTER_TYPED_CREATOR defines a static Registerer##RegistryName object
// whose name is produced by C10_ANONYMOUS_VARIABLE (presumably a unique
// identifier per use). Its constructor registers into RegistryName() under
// `key` with REGISTRY_DEFAULT priority. __VA_ARGS__ is passed after `key` and
// the registry: the creator, optionally followed by a help message. The
// expansion already ends in a semicolon.
#define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...)                  \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, RegistryName(), ##__VA_ARGS__);
243 | |
// Like C10_REGISTER_TYPED_CREATOR, but registers with an explicit
// RegistryPriority `priority`. __VA_ARGS__ is the creator, optionally followed
// by a help message.
#define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                           \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key, priority, RegistryName(), ##__VA_ARGS__);
248 | |
// Registers class __VA_ARGS__ under `key` with REGISTRY_DEFAULT priority,
// using Registerer::DefaultCreator<__VA_ARGS__> as the creator. The help
// message is ::c10::demangle_type<__VA_ARGS__>() (presumably the demangled
// class name; see c10/util/Type.h).
#define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...)                    \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());
255 | |
// Like C10_REGISTER_TYPED_CLASS, but registers with an explicit
// RegistryPriority `priority`.
#define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                             \
    RegistryName, key, priority, ...)                                       \
  static Registerer##RegistryName C10_ANONYMOUS_VARIABLE(g_##RegistryName)( \
      key,                                                                  \
      priority,                                                             \
      RegistryName(),                                                       \
      Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,                \
      ::c10::demangle_type<__VA_ARGS__>());
264 | |
// C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
// std::string as the key type, because that is the most common use case.
// Objects are returned as std::unique_ptr<ObjectType>; the trailing arguments
// are the argument types that registered creators accept.
#define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

#define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

// Like C10_DEFINE_REGISTRY, but skipped lower-priority registrations are not
// reported on stderr.
#define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                               \
      RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)
278 | |
// std::string-keyed registries whose creators return
// std::shared_ptr<ObjectType> instead of std::unique_ptr<ObjectType>.
#define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DECLARE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

#define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
  C10_DEFINE_TYPED_REGISTRY(                                      \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

// Like C10_DEFINE_SHARED_REGISTRY, but skipped lower-priority registrations
// are not reported on stderr.
#define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
    RegistryName, ObjectType, ...)                  \
  C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(        \
      RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
291 | |
// C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use
// std::string as the key type, because that is the most common use case.
// `key` is written as a bare token and stringified with `#key`, e.g.
// C10_REGISTER_CLASS(MyRegistry, Foo, FooImpl) registers FooImpl under "Foo".
#define C10_REGISTER_CREATOR(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)

#define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)

#define C10_REGISTER_CLASS(RegistryName, key, ...) \
  C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

#define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
  C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                                  \
      RegistryName, #key, priority, __VA_ARGS__)
308 | |
309 | } // namespace c10 |
310 | |
311 | #endif // C10_UTIL_REGISTRY_H_ |
312 | |