1 | #pragma once |
2 | |
#include <any>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "taichi/aot/module_data.h"
#include "taichi/rhi/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/graph_data.h"
14 | |
15 | namespace taichi::lang { |
16 | |
17 | struct RuntimeContext; |
18 | class Graph; |
19 | namespace aot { |
20 | |
21 | class TI_DLL_EXPORT Field { |
22 | public: |
23 | // Rule of 5 to make MSVC happy |
24 | Field() = default; |
25 | virtual ~Field() = default; |
26 | Field(const Field &) = delete; |
27 | Field &operator=(const Field &) = delete; |
28 | Field(Field &&) = default; |
29 | Field &operator=(Field &&) = default; |
30 | }; |
31 | |
32 | class TI_DLL_EXPORT KernelTemplateArg { |
33 | public: |
34 | using ArgUnion = std::variant<bool, int64_t, uint64_t, const Field *>; |
35 | template <typename T> |
36 | KernelTemplateArg(const std::string &name, T &&arg) |
37 | : name_(name), targ_(std::forward<T>(arg)) { |
38 | } |
39 | |
40 | private: |
41 | std::string name_; |
42 | /** |
43 | * @brief Template arg |
44 | * |
45 | */ |
46 | ArgUnion targ_; |
47 | }; |
48 | |
49 | class TI_DLL_EXPORT KernelTemplate { |
50 | public: |
51 | // Rule of 5 to make MSVC happy |
52 | KernelTemplate() = default; |
53 | virtual ~KernelTemplate() = default; |
54 | KernelTemplate(const KernelTemplate &) = delete; |
55 | KernelTemplate &operator=(const KernelTemplate &) = delete; |
56 | KernelTemplate(KernelTemplate &&) = default; |
57 | KernelTemplate &operator=(KernelTemplate &&) = default; |
58 | |
59 | Kernel *get_kernel(const std::vector<KernelTemplateArg> &template_args); |
60 | |
61 | protected: |
62 | virtual std::unique_ptr<Kernel> make_new_kernel( |
63 | const std::vector<KernelTemplateArg> &template_args) = 0; |
64 | |
65 | private: |
66 | std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_; |
67 | }; |
68 | |
69 | class TI_DLL_EXPORT Module { |
70 | public: |
71 | // Rule of 5 to make MSVC happy |
72 | Module() = default; |
73 | virtual ~Module() = default; |
74 | Module(const Module &) = delete; |
75 | Module &operator=(const Module &) = delete; |
76 | Module(Module &&) = default; |
77 | Module &operator=(Module &&) = default; |
78 | |
79 | static std::unique_ptr<Module> load(Arch arch, std::any mod_params); |
80 | |
81 | // Module metadata |
82 | // TODO: Instead of virtualize these simple properties, just store them as |
83 | // member variables. |
84 | virtual Arch arch() const = 0; |
85 | virtual uint64_t version() const = 0; |
86 | virtual size_t get_root_size() const = 0; |
87 | |
88 | Kernel *get_kernel(const std::string &name); |
89 | KernelTemplate *get_kernel_template(const std::string &name); |
90 | Field *get_snode_tree(const std::string &name); |
91 | |
92 | virtual std::unique_ptr<aot::CompiledGraph> get_graph( |
93 | const std::string &name) { |
94 | TI_NOT_IMPLEMENTED; |
95 | } |
96 | |
97 | virtual const DeviceCapabilityConfig &get_required_caps() const { |
98 | static DeviceCapabilityConfig default_cfg; |
99 | return default_cfg; |
100 | } |
101 | |
102 | inline bool is_corrupted() const { |
103 | return is_corrupted_; |
104 | } |
105 | |
106 | protected: |
107 | virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0; |
108 | virtual std::unique_ptr<KernelTemplate> make_new_kernel_template( |
109 | const std::string &name) = 0; |
110 | virtual std::unique_ptr<Field> make_new_field(const std::string &name) = 0; |
111 | inline void mark_corrupted() { |
112 | is_corrupted_ = true; |
113 | } |
114 | std::unordered_map<std::string, CompiledGraph> graphs_; |
115 | |
116 | private: |
117 | bool is_corrupted_{false}; |
118 | std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_; |
119 | std::unordered_map<std::string, std::unique_ptr<KernelTemplate>> |
120 | loaded_kernel_templates_; |
121 | std::unordered_map<std::string, std::unique_ptr<Field>> loaded_fields_; |
122 | }; |
123 | |
124 | } // namespace aot |
125 | } // namespace taichi::lang |
126 | |