1 | #pragma once |
---|---|
2 | |
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "taichi/ir/type.h"
#include "taichi/aot/graph_data.h"
8 | |
9 | namespace taichi::lang { |
10 | class Kernel; |
11 | class GraphBuilder; |
12 | |
13 | class Node { |
14 | public: |
15 | Node() = default; |
16 | virtual ~Node() = default; |
17 | Node(const Node &) = delete; |
18 | Node &operator=(const Node &) = delete; |
19 | Node(Node &&) = default; |
20 | Node &operator=(Node &&) = default; |
21 | |
22 | virtual void compile( |
23 | std::vector<aot::CompiledDispatch> &compiled_dispatches) = 0; |
24 | }; |
25 | |
26 | class Dispatch : public Node { |
27 | public: |
28 | explicit Dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) |
29 | : kernel_(kernel), symbolic_args_(args) { |
30 | } |
31 | |
32 | void compile( |
33 | std::vector<aot::CompiledDispatch> &compiled_dispatches) override; |
34 | |
35 | private: |
36 | mutable bool serialized_{false}; |
37 | Kernel *kernel_{nullptr}; |
38 | std::vector<aot::Arg> symbolic_args_; |
39 | }; |
40 | |
41 | class Sequential : public Node { |
42 | public: |
43 | explicit Sequential(GraphBuilder *graph) : owning_graph_(graph) { |
44 | } |
45 | |
46 | void append(Node *node); |
47 | |
48 | void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args); |
49 | |
50 | void compile( |
51 | std::vector<aot::CompiledDispatch> &compiled_dispatches) override; |
52 | |
53 | private: |
54 | std::vector<Node *> sequence_; |
55 | GraphBuilder *owning_graph_{nullptr}; |
56 | }; |
57 | |
58 | class GraphBuilder { |
59 | public: |
60 | explicit GraphBuilder(); |
61 | |
62 | // TODO: compile() can take in Arch argument |
63 | std::unique_ptr<aot::CompiledGraph> compile(); |
64 | |
65 | Node *new_dispatch_node(Kernel *kernel, const std::vector<aot::Arg> &args); |
66 | |
67 | Sequential *new_sequential_node(); |
68 | |
69 | void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args); |
70 | |
71 | Sequential *seq() const; |
72 | |
73 | private: |
74 | std::unique_ptr<Sequential> seq_{nullptr}; |
75 | std::unordered_map<std::string, aot::Arg> all_args_; |
76 | std::vector<std::unique_ptr<Node>> all_nodes_; |
77 | }; |
78 | |
79 | } // namespace taichi::lang |
80 |