1 | #include "taichi/program/graph_builder.h" |
---|---|
2 | #include "taichi/program/ndarray.h" |
3 | #include "taichi/program/program.h" |
4 | |
5 | namespace taichi::lang { |
6 | void Dispatch::compile( |
7 | std::vector<aot::CompiledDispatch> &compiled_dispatches) { |
8 | aot::CompiledDispatch dispatch; |
9 | dispatch.kernel_name = kernel_->get_name(); |
10 | dispatch.symbolic_args = symbolic_args_; |
11 | dispatch.ti_kernel = kernel_; |
12 | dispatch.compiled_kernel = nullptr; |
13 | compiled_dispatches.push_back(std::move(dispatch)); |
14 | } |
15 | |
16 | void Sequential::compile( |
17 | std::vector<aot::CompiledDispatch> &compiled_dispatches) { |
18 | // In the future we can do more across-kernel optimization here. |
19 | for (Node *n : sequence_) { |
20 | n->compile(compiled_dispatches); |
21 | } |
22 | } |
23 | |
24 | void Sequential::append(Node *node) { |
25 | sequence_.push_back(node); |
26 | } |
27 | |
28 | void Sequential::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) { |
29 | Node *n = owning_graph_->new_dispatch_node(kernel, args); |
30 | sequence_.push_back(n); |
31 | } |
32 | |
33 | GraphBuilder::GraphBuilder() { |
34 | seq_ = std::make_unique<Sequential>(this); |
35 | } |
36 | |
37 | Node *GraphBuilder::new_dispatch_node(Kernel *kernel, |
38 | const std::vector<aot::Arg> &args) { |
39 | for (const auto &arg : args) { |
40 | if (all_args_.find(arg.name) != all_args_.end()) { |
41 | TI_ERROR_IF(all_args_[arg.name] != arg, |
42 | "An arg with name {} already exists!", arg.name); |
43 | } else { |
44 | all_args_[arg.name] = arg; |
45 | } |
46 | } |
47 | all_nodes_.push_back(std::make_unique<Dispatch>(kernel, args)); |
48 | return all_nodes_.back().get(); |
49 | } |
50 | |
51 | Sequential *GraphBuilder::new_sequential_node() { |
52 | all_nodes_.push_back(std::make_unique<Sequential>(this)); |
53 | return static_cast<Sequential *>(all_nodes_.back().get()); |
54 | } |
55 | |
56 | std::unique_ptr<aot::CompiledGraph> GraphBuilder::compile() { |
57 | std::vector<aot::CompiledDispatch> dispatches; |
58 | seq()->compile(dispatches); |
59 | aot::CompiledGraph graph{dispatches, all_args_}; |
60 | return std::make_unique<aot::CompiledGraph>(std::move(graph)); |
61 | } |
62 | |
63 | Sequential *GraphBuilder::seq() const { |
64 | return seq_.get(); |
65 | } |
66 | |
67 | void GraphBuilder::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) { |
68 | seq()->dispatch(kernel, args); |
69 | } |
70 | |
71 | } // namespace taichi::lang |
72 |