1#pragma once
2
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "taichi/ir/type.h"
#include "taichi/aot/graph_data.h"
8
9namespace taichi::lang {
10class Kernel;
11class GraphBuilder;
12
13class Node {
14 public:
15 Node() = default;
16 virtual ~Node() = default;
17 Node(const Node &) = delete;
18 Node &operator=(const Node &) = delete;
19 Node(Node &&) = default;
20 Node &operator=(Node &&) = default;
21
22 virtual void compile(
23 std::vector<aot::CompiledDispatch> &compiled_dispatches) = 0;
24};
25
26class Dispatch : public Node {
27 public:
28 explicit Dispatch(Kernel *kernel, const std::vector<aot::Arg> &args)
29 : kernel_(kernel), symbolic_args_(args) {
30 }
31
32 void compile(
33 std::vector<aot::CompiledDispatch> &compiled_dispatches) override;
34
35 private:
36 mutable bool serialized_{false};
37 Kernel *kernel_{nullptr};
38 std::vector<aot::Arg> symbolic_args_;
39};
40
41class Sequential : public Node {
42 public:
43 explicit Sequential(GraphBuilder *graph) : owning_graph_(graph) {
44 }
45
46 void append(Node *node);
47
48 void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);
49
50 void compile(
51 std::vector<aot::CompiledDispatch> &compiled_dispatches) override;
52
53 private:
54 std::vector<Node *> sequence_;
55 GraphBuilder *owning_graph_{nullptr};
56};
57
58class GraphBuilder {
59 public:
60 explicit GraphBuilder();
61
62 // TODO: compile() can take in Arch argument
63 std::unique_ptr<aot::CompiledGraph> compile();
64
65 Node *new_dispatch_node(Kernel *kernel, const std::vector<aot::Arg> &args);
66
67 Sequential *new_sequential_node();
68
69 void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);
70
71 Sequential *seq() const;
72
73 private:
74 std::unique_ptr<Sequential> seq_{nullptr};
75 std::unordered_map<std::string, aot::Arg> all_args_;
76 std::vector<std::unique_ptr<Node>> all_nodes_;
77};
78
79} // namespace taichi::lang
80