#include <functional>
#include <memory>
#include <string>

#include "pybind11/pybind11.h"
#include "taichi/common/interface.h"
#include "taichi/common/task.h"
#include "taichi/system/benchmark.h"

8 | namespace taichi { |
9 | |
10 | #define TI_INTERFACE_DEF_WITH_PYBIND11(class_name, base_alias) \ |
11 | \ |
12 | class InterfaceInjector_##class_name { \ |
13 | public: \ |
14 | explicit InterfaceInjector_##class_name(const std::string &name) { \ |
15 | InterfaceHolder::get_instance()->register_registration_method( \ |
16 | base_alias, [&](void *m) { \ |
17 | ((pybind11::module *)m) \ |
18 | ->def("create_" base_alias, \ |
19 | static_cast<std::shared_ptr<class_name> (*)( \ |
20 | const std::string &name)>( \ |
21 | &create_instance<class_name>)); \ |
22 | ((pybind11::module *)m) \ |
23 | ->def("create_initialized_" base_alias, \ |
24 | static_cast<std::shared_ptr<class_name> (*)( \ |
25 | const std::string &name, const Config &config)>( \ |
26 | &create_instance<class_name>)); \ |
27 | }); \ |
28 | InterfaceHolder::get_instance()->register_interface( \ |
29 | base_alias, (ImplementationHolderBase *) \ |
30 | get_implementation_holder_instance_##class_name()); \ |
31 | } \ |
32 | } ImplementationInjector_##base_class_name##class_name##instance(base_alias); |
33 | |
34 | TI_INTERFACE_DEF_WITH_PYBIND11(Benchmark, "benchmark") |
35 | TI_INTERFACE_DEF_WITH_PYBIND11(Task, "task") |
36 | |
37 | } // namespace taichi |
38 |