1#pragma once
2
3#include <any>
4#include <memory>
5#include <string>
6#include <unordered_map>
7#include <variant>
8#include <vector>
9
10#include "taichi/aot/module_data.h"
11#include "taichi/rhi/device.h"
12#include "taichi/ir/snode.h"
13#include "taichi/aot/graph_data.h"
14
15namespace taichi::lang {
16
17struct RuntimeContext;
18class Graph;
19namespace aot {
20
21class TI_DLL_EXPORT Field {
22 public:
23 // Rule of 5 to make MSVC happy
24 Field() = default;
25 virtual ~Field() = default;
26 Field(const Field &) = delete;
27 Field &operator=(const Field &) = delete;
28 Field(Field &&) = default;
29 Field &operator=(Field &&) = default;
30};
31
32class TI_DLL_EXPORT KernelTemplateArg {
33 public:
34 using ArgUnion = std::variant<bool, int64_t, uint64_t, const Field *>;
35 template <typename T>
36 KernelTemplateArg(const std::string &name, T &&arg)
37 : name_(name), targ_(std::forward<T>(arg)) {
38 }
39
40 private:
41 std::string name_;
42 /**
43 * @brief Template arg
44 *
45 */
46 ArgUnion targ_;
47};
48
49class TI_DLL_EXPORT KernelTemplate {
50 public:
51 // Rule of 5 to make MSVC happy
52 KernelTemplate() = default;
53 virtual ~KernelTemplate() = default;
54 KernelTemplate(const KernelTemplate &) = delete;
55 KernelTemplate &operator=(const KernelTemplate &) = delete;
56 KernelTemplate(KernelTemplate &&) = default;
57 KernelTemplate &operator=(KernelTemplate &&) = default;
58
59 Kernel *get_kernel(const std::vector<KernelTemplateArg> &template_args);
60
61 protected:
62 virtual std::unique_ptr<Kernel> make_new_kernel(
63 const std::vector<KernelTemplateArg> &template_args) = 0;
64
65 private:
66 std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
67};
68
69class TI_DLL_EXPORT Module {
70 public:
71 // Rule of 5 to make MSVC happy
72 Module() = default;
73 virtual ~Module() = default;
74 Module(const Module &) = delete;
75 Module &operator=(const Module &) = delete;
76 Module(Module &&) = default;
77 Module &operator=(Module &&) = default;
78
79 static std::unique_ptr<Module> load(Arch arch, std::any mod_params);
80
81 // Module metadata
82 // TODO: Instead of virtualize these simple properties, just store them as
83 // member variables.
84 virtual Arch arch() const = 0;
85 virtual uint64_t version() const = 0;
86 virtual size_t get_root_size() const = 0;
87
88 Kernel *get_kernel(const std::string &name);
89 KernelTemplate *get_kernel_template(const std::string &name);
90 Field *get_snode_tree(const std::string &name);
91
92 virtual std::unique_ptr<aot::CompiledGraph> get_graph(
93 const std::string &name) {
94 TI_NOT_IMPLEMENTED;
95 }
96
97 virtual const DeviceCapabilityConfig &get_required_caps() const {
98 static DeviceCapabilityConfig default_cfg;
99 return default_cfg;
100 }
101
102 inline bool is_corrupted() const {
103 return is_corrupted_;
104 }
105
106 protected:
107 virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0;
108 virtual std::unique_ptr<KernelTemplate> make_new_kernel_template(
109 const std::string &name) = 0;
110 virtual std::unique_ptr<Field> make_new_field(const std::string &name) = 0;
111 inline void mark_corrupted() {
112 is_corrupted_ = true;
113 }
114 std::unordered_map<std::string, CompiledGraph> graphs_;
115
116 private:
117 bool is_corrupted_{false};
118 std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
119 std::unordered_map<std::string, std::unique_ptr<KernelTemplate>>
120 loaded_kernel_templates_;
121 std::unordered_map<std::string, std::unique_ptr<Field>> loaded_fields_;
122};
123
124} // namespace aot
125} // namespace taichi::lang
126