1/*******************************************************************************
2 Copyright (c) The Taichi Authors (2016- ). All Rights Reserved.
3 The use of this software is governed by the LICENSE file.
4*******************************************************************************/
5
6#pragma once
7
#include "taichi/common/dict.h"

#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <new>
#include <string>
#include <vector>
16
17namespace taichi {
18
19template <typename T>
20std::shared_ptr<T> create_instance(const std::string &alias);
21
22template <typename T>
23std::shared_ptr<T> create_instance(const std::string &alias,
24 const Config &config);
25
26template <typename T>
27std::unique_ptr<T> create_instance_unique(const std::string &alias);
28
29template <typename T>
30std::unique_ptr<T> create_instance_unique(const std::string &alias,
31 const Config &config);
32template <typename T>
33std::unique_ptr<T> create_instance_unique_ctor(const std::string &alias,
34 const Config &config);
35
36template <typename T>
37T *create_instance_raw(const std::string &alias);
38
39template <typename T>
40T *create_instance_raw(const std::string &alias, const Config &config);
41
42template <typename T>
43T *create_instance_placement(const std::string &alias, void *place);
44
45template <typename T>
46T *create_instance_placement(const std::string &alias,
47 void *place,
48 const Config &config);
49
50template <typename T>
51std::vector<std::string> get_implementation_names();
52
53class Unit {
54 public:
55 Unit() {
56 }
57
58 virtual void initialize(const Config &config) {
59 }
60
61 virtual bool test() const {
62 return true;
63 }
64
65 virtual std::string get_name() const {
66 TI_NOT_IMPLEMENTED;
67 return "";
68 }
69
70 virtual std::string general_action(const Config &config) {
71 TI_NOT_IMPLEMENTED;
72 return "";
73 }
74
75 virtual ~Unit() {
76 }
77};
78
// Token-pasting helpers. For interface T they name the registry class
// generated by TI_INTERFACE(T) and the global pointer to its singleton
// instance (declared by TI_INTERFACE, defined by TI_INTERFACE_DEF).
#define TI_IMPLEMENTATION_HOLDER_NAME(T) ImplementationHolder_##T
#define TI_IMPLEMENTATION_HOLDER_PTR(T) instance_ImplementationHolder_##T

// Type-erased base of every TI_INTERFACE-generated registry. It lets
// InterfaceHolder keep registries of different interface types in one map.
class ImplementationHolderBase {
 public:
  // Interface name; set by the generated registry's constructor and used in
  // its "not found" diagnostics.
  std::string name;

  // Registries are handled through base-class pointers, so the destructor
  // must be virtual for `delete base_ptr` to run the derived destructor.
  // Deleting through a base with a non-virtual destructor is undefined
  // behavior.
  virtual ~ImplementationHolderBase() = default;

  // Returns true if an implementation is registered under `alias`.
  virtual bool has(const std::string &alias) const = 0;

  // Unregisters the implementation registered under `alias`.
  virtual void remove(const std::string &alias) = 0;

  // Returns the aliases of all registered implementations.
  virtual std::vector<std::string> get_implementation_names() const = 0;
};
92
93class InterfaceHolder {
94 public:
95 typedef std::function<void(void *)> RegistrationMethod;
96 std::map<std::string, RegistrationMethod> methods;
97 std::map<std::string, ImplementationHolderBase *> interfaces;
98
99 void register_registration_method(const std::string &name,
100 const RegistrationMethod &method) {
101 methods[name] = method;
102 }
103
104 void register_interface(const std::string &name,
105 ImplementationHolderBase *interface_) {
106 interfaces[name] = interface_;
107 }
108
109 static InterfaceHolder *get_instance() {
110 static InterfaceHolder holder;
111 return &holder;
112 }
113};
114
// TI_INTERFACE(T) declares the implementation registry for interface T.
//
// It generates ImplementationHolder_T, a final subclass of
// ImplementationHolderBase. The class holds one alias -> factory map per
// creation flavour:
//   - std::shared_ptr;
//   - std::unique_ptr;
//   - std::unique_ptr constructed from a Dict (filled only by insert_new<G>);
//   - raw `new`;
//   - placement `new`.
// Each create* member looks the alias up and fails through TI_ASSERT_INFO
// when it is missing. get_instance() forwards to
// get_implementation_holder_instance_T(), which TI_INTERFACE_DEF(T, ...)
// defines.
//
// Comments inside the macro body use /* */ because a // comment would also
// swallow the line continuation.
#define TI_INTERFACE(T) \
  extern void *get_implementation_holder_instance_##T(); \
  class TI_IMPLEMENTATION_HOLDER_NAME(T) final \
      : public ImplementationHolderBase { \
   public: \
    explicit TI_IMPLEMENTATION_HOLDER_NAME(T)(const std::string &name) { \
      this->name = name; \
    } \
    using FactoryMethod = std::function<std::shared_ptr<T>()>; \
    using FactoryUniqueMethod = std::function<std::unique_ptr<T>()>; \
    using FactoryUniqueCtorMethod = \
        std::function<std::unique_ptr<T>(const Dict &config)>; \
    using FactoryRawMethod = std::function<T *()>; \
    using FactoryPlacementMethod = std::function<T *(void *)>; \
    std::map<std::string, FactoryMethod> implementation_factories; \
    std::map<std::string, FactoryUniqueMethod> \
        implementation_unique_factories; \
    std::map<std::string, FactoryUniqueCtorMethod> \
        implementation_unique_ctor_factories; \
    std::map<std::string, FactoryRawMethod> implementation_raw_factories; \
    std::map<std::string, FactoryPlacementMethod> \
        implementation_placement_factories; \
    std::vector<std::string> get_implementation_names() const override { \
      std::vector<std::string> names; \
      names.reserve(implementation_factories.size()); \
      for (const auto &kv : implementation_factories) { \
        names.push_back(kv.first); \
      } \
      return names; \
    } \
    /* Registers default-constructible G under alias for every flavour */ \
    /* except the Dict-constructor one. An alias that is already present */ \
    /* is left untouched, because std::map::insert does not overwrite. */ \
    /* The factories are stateless, so the lambdas capture nothing. */ \
    template <typename G> \
    void insert(const std::string &alias) { \
      implementation_factories.insert( \
          std::make_pair(alias, []() { return std::make_shared<G>(); })); \
      implementation_unique_factories.insert( \
          std::make_pair(alias, []() { return std::make_unique<G>(); })); \
      implementation_raw_factories.insert( \
          std::make_pair(alias, []() { return new G(); })); \
      implementation_placement_factories.insert(std::make_pair( \
          alias, [](void *place) { return new (place) G(); })); \
    } \
    /* Like insert<G>, plus a factory that passes a Dict to G's ctor. */ \
    template <typename G> \
    void insert_new(const std::string &alias) { \
      implementation_factories.insert( \
          std::make_pair(alias, []() { return std::make_shared<G>(); })); \
      implementation_unique_factories.insert( \
          std::make_pair(alias, []() { return std::make_unique<G>(); })); \
      implementation_unique_ctor_factories.insert(std::make_pair( \
          alias, \
          [](const Dict &config) { return std::make_unique<G>(config); })); \
      implementation_raw_factories.insert( \
          std::make_pair(alias, []() { return new G(); })); \
      implementation_placement_factories.insert(std::make_pair( \
          alias, [](void *place) { return new (place) G(); })); \
    } \
    /* Registers only a shared_ptr factory for alias. */ \
    void insert(const std::string &alias, const FactoryMethod &f) { \
      implementation_factories.insert(std::make_pair(alias, f)); \
    } \
    bool has(const std::string &alias) const override { \
      return implementation_factories.find(alias) != \
             implementation_factories.end(); \
    } \
    /* Erases alias from every factory map. Erasing only the shared_ptr */ \
    /* map left the other flavours registered; update<G>() (remove, then */ \
    /* an insert that never overwrites) then kept creating the old type */ \
    /* through create_unique, create_raw and create_placement. */ \
    void remove(const std::string &alias) override { \
      TI_ASSERT_INFO(has(alias), \
                     std::string("Implementation ") + alias + " not found!"); \
      implementation_factories.erase(alias); \
      implementation_unique_factories.erase(alias); \
      implementation_unique_ctor_factories.erase(alias); \
      implementation_raw_factories.erase(alias); \
      implementation_placement_factories.erase(alias); \
    } \
    void update(const std::string &alias, const FactoryMethod &f) { \
      if (has(alias)) { \
        remove(alias); \
      } \
      insert(alias, f); \
    } \
    template <typename G> \
    void update(const std::string &alias) { \
      if (has(alias)) { \
        remove(alias); \
      } \
      insert<G>(alias); \
    } \
    std::shared_ptr<T> create(const std::string &alias) { \
      auto factory = implementation_factories.find(alias); \
      TI_ASSERT_INFO( \
          factory != implementation_factories.end(), \
          "Implementation [" + name + "::" + alias + "] not found!"); \
      return (factory->second)(); \
    } \
    std::unique_ptr<T> create_unique(const std::string &alias) { \
      auto factory = implementation_unique_factories.find(alias); \
      TI_ASSERT_INFO( \
          factory != implementation_unique_factories.end(), \
          "Implementation [" + name + "::" + alias + "] not found!"); \
      return (factory->second)(); \
    } \
    std::unique_ptr<T> create_unique_ctor(const std::string &alias, \
                                          const Dict &config) { \
      auto factory = implementation_unique_ctor_factories.find(alias); \
      TI_ASSERT_INFO( \
          factory != implementation_unique_ctor_factories.end(), \
          "Implementation [" + name + "::" + alias + "] not found!"); \
      return (factory->second)(config); \
    } \
    T *create_raw(const std::string &alias) { \
      auto factory = implementation_raw_factories.find(alias); \
      TI_ASSERT_INFO( \
          factory != implementation_raw_factories.end(), \
          "Implementation [" + name + "::" + alias + "] not found!"); \
      return (factory->second)(); \
    } \
    T *create_placement(const std::string &alias, void *place) { \
      auto factory = implementation_placement_factories.find(alias); \
      TI_ASSERT_INFO( \
          factory != implementation_placement_factories.end(), \
          "Implementation [" + name + "::" + alias + "] not found!"); \
      return (factory->second)(place); \
    } \
    static TI_IMPLEMENTATION_HOLDER_NAME(T) * get_instance() { \
      return static_cast<TI_IMPLEMENTATION_HOLDER_NAME(T) *>( \
          get_implementation_holder_instance_##T()); \
    } \
  }; \
  extern TI_IMPLEMENTATION_HOLDER_NAME(T) * TI_IMPLEMENTATION_HOLDER_PTR(T);
237
// TI_INTERFACE_DEF(class_name, base_alias) supplies the out-of-line parts of
// the registry that TI_INTERFACE(class_name) declares:
//   - an explicit specialization of every create_instance* template and of
//     get_implementation_names for class_name;
//   - the definition of the global holder pointer;
//   - get_implementation_holder_instance_<class_name>(), which creates the
//     holder (named base_alias) on first call. The holder is never deleted by
//     this code.
// These are non-inline definitions, so the macro must be expanded in exactly
// one translation unit.
//
// Ordering matters. Each Config-taking specialization calls another
// specialization, which must already be declared at that point; otherwise
// the call would implicitly instantiate the primary template first. The
// holder pointer is constant-initialized to nullptr, so the lazy creator also
// works when it is called from other translation units' static initializers.
#define TI_INTERFACE_DEF(class_name, base_alias) \
  template <> \
  std::shared_ptr<class_name> create_instance(const std::string &alias) { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->create(alias); \
  } \
  template <> \
  std::shared_ptr<class_name> create_instance(const std::string &alias, \
                                              const Config &config) { \
    std::shared_ptr<class_name> made = create_instance<class_name>(alias); \
    made->initialize(config); \
    return made; \
  } \
  template <> \
  std::unique_ptr<class_name> create_instance_unique( \
      const std::string &alias) { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->create_unique(alias); \
  } \
  template <> \
  std::unique_ptr<class_name> create_instance_unique(const std::string &alias, \
                                                     const Config &config) { \
    std::unique_ptr<class_name> made = \
        create_instance_unique<class_name>(alias); \
    made->initialize(config); \
    return made; \
  } \
  template <> \
  std::unique_ptr<class_name> create_instance_unique_ctor( \
      const std::string &alias, const Dict &config) { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->create_unique_ctor(alias, config); \
  } \
  template <> \
  class_name *create_instance_raw(const std::string &alias) { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->create_raw(alias); \
  } \
  template <> \
  class_name *create_instance_placement(const std::string &alias, \
                                        void *place) { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->create_placement(alias, place); \
  } \
  template <> \
  class_name *create_instance_placement(const std::string &alias, void *place, \
                                        const Config &config) { \
    class_name *made = create_instance_placement<class_name>(alias, place); \
    made->initialize(config); \
    return made; \
  } \
  template <> \
  class_name *create_instance_raw(const std::string &alias, \
                                  const Config &config) { \
    class_name *made = create_instance_raw<class_name>(alias); \
    made->initialize(config); \
    return made; \
  } \
  template <> \
  std::vector<std::string> get_implementation_names<class_name>() { \
    auto *holder = TI_IMPLEMENTATION_HOLDER_NAME(class_name)::get_instance(); \
    return holder->get_implementation_names(); \
  } \
  TI_IMPLEMENTATION_HOLDER_NAME(class_name) * \
      TI_IMPLEMENTATION_HOLDER_PTR(class_name) = nullptr; \
  void *get_implementation_holder_instance_##class_name() { \
    if (TI_IMPLEMENTATION_HOLDER_PTR(class_name) == nullptr) { \
      TI_IMPLEMENTATION_HOLDER_PTR(class_name) = \
          new TI_IMPLEMENTATION_HOLDER_NAME(class_name)(base_alias); \
    } \
    return static_cast<void *>(TI_IMPLEMENTATION_HOLDER_PTR(class_name)); \
  }
309
// TI_IMPLEMENTATION(base_class_name, class_name, alias) registers class_name
// as an implementation of interface base_class_name under the string `alias`.
// It defines an injector class plus one object of that class, and the
// object's constructor performs the registration via insert<class_name>().
// When the macro is expanded at namespace scope, this happens during the
// translation unit's dynamic initialization.
#define TI_IMPLEMENTATION(base_class_name, class_name, alias) \
  class ImplementationInjector_##base_class_name##class_name { \
   public: \
    ImplementationInjector_##base_class_name##class_name() { \
      auto *target_registry = \
          TI_IMPLEMENTATION_HOLDER_NAME(base_class_name)::get_instance(); \
      target_registry->insert<class_name>(alias); \
    } \
  } ImplementationInjector_##base_class_name##class_name##instance;
318
// TI_IMPLEMENTATION_NEW(base_class_name, class_name) registers class_name
// with base_class_name's registry through insert_new<class_name>(). The alias
// is class_name::get_name_static() (as provided by TI_NAME). insert_new also
// registers a factory that forwards a Dict to class_name's constructor.
// Registration happens in the constructor of the injector object the macro
// defines.
#define TI_IMPLEMENTATION_NEW(base_class_name, class_name) \
  class ImplementationInjector_##base_class_name##class_name { \
   public: \
    ImplementationInjector_##base_class_name##class_name() { \
      auto *target_registry = \
          TI_IMPLEMENTATION_HOLDER_NAME(base_class_name)::get_instance(); \
      target_registry->insert_new<class_name>( \
          class_name::get_name_static()); \
    } \
  } ImplementationInjector_##base_class_name##class_name##instance;
327
// TI_NAME(alias) is placed inside a class body. It adds a static
// get_name_static() returning `alias`, and overrides the inherited virtual
// get_name() (e.g. Unit::get_name) to return the same value.
// TI_IMPLEMENTATION_NEW registers classes under get_name_static().
#define TI_NAME(alias) \
  std::string get_name() const override { \
    return get_name_static(); \
  } \
  static std::string get_name_static() { \
    return alias; \
  }
331
332} // namespace taichi
333