1 | #pragma once |
2 | |
#include <string>
#include <unordered_map>
#include <vector>

#include "taichi/aot/module_data.h"
#include "taichi/rhi/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"
11 | |
12 | namespace taichi::lang { |
13 | |
14 | class Kernel; |
15 | class DataType; |
16 | |
17 | class AotModuleBuilder { |
18 | public: |
19 | virtual ~AotModuleBuilder() = default; |
20 | |
21 | void add(const std::string &identifier, Kernel *kernel); |
22 | |
23 | void add_field(const std::string &identifier, |
24 | const SNode *rep_snode, |
25 | bool is_scalar, |
26 | DataType dt, |
27 | std::vector<int> shape, |
28 | int row_num, |
29 | int column_num); |
30 | |
31 | void add_kernel_template(const std::string &identifier, |
32 | const std::string &key, |
33 | Kernel *kernel); |
34 | |
35 | virtual void load(const std::string &output_dir); |
36 | |
37 | virtual void dump(const std::string &output_dir, |
38 | const std::string &filename) const = 0; |
39 | |
40 | void add_graph(const std::string &name, const aot::CompiledGraph &graph); |
41 | |
42 | protected: |
43 | /** |
44 | * Intended to be overridden by each backend's implementation. |
45 | */ |
46 | virtual void add_per_backend(const std::string &identifier, |
47 | Kernel *kernel) = 0; |
48 | virtual void add_field_per_backend(const std::string &identifier, |
49 | const SNode *rep_snode, |
50 | bool is_scalar, |
51 | DataType dt, |
52 | std::vector<int> shape, |
53 | int row_num, |
54 | int column_num) { |
55 | TI_NOT_IMPLEMENTED; |
56 | } |
57 | |
58 | virtual void add_ndarray_per_backend(const std::string &identifier, |
59 | bool is_scalar, |
60 | DataType dt, |
61 | std::vector<int> shape, |
62 | int row_num, |
63 | int column_num) { |
64 | TI_NOT_IMPLEMENTED; |
65 | } |
66 | |
67 | virtual void add_per_backend_tmpl(const std::string &identifier, |
68 | const std::string &key, |
69 | Kernel *kernel) { |
70 | TI_NOT_IMPLEMENTED; |
71 | } |
72 | |
73 | void dump_graph(std::string output_dir) const; |
74 | |
75 | static bool all_fields_are_dense_in_container(const SNode *container); |
76 | |
77 | private: |
78 | std::unordered_map<std::string, aot::CompiledGraph> graphs_; |
79 | }; |
80 | |
81 | } // namespace taichi::lang |
82 | |