1#include "taichi/program/graph_builder.h"
2#include "taichi/program/ndarray.h"
3#include "taichi/program/program.h"
4
5namespace taichi::lang {
6void Dispatch::compile(
7 std::vector<aot::CompiledDispatch> &compiled_dispatches) {
8 aot::CompiledDispatch dispatch;
9 dispatch.kernel_name = kernel_->get_name();
10 dispatch.symbolic_args = symbolic_args_;
11 dispatch.ti_kernel = kernel_;
12 dispatch.compiled_kernel = nullptr;
13 compiled_dispatches.push_back(std::move(dispatch));
14}
15
16void Sequential::compile(
17 std::vector<aot::CompiledDispatch> &compiled_dispatches) {
18 // In the future we can do more across-kernel optimization here.
19 for (Node *n : sequence_) {
20 n->compile(compiled_dispatches);
21 }
22}
23
24void Sequential::append(Node *node) {
25 sequence_.push_back(node);
26}
27
28void Sequential::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
29 Node *n = owning_graph_->new_dispatch_node(kernel, args);
30 sequence_.push_back(n);
31}
32
33GraphBuilder::GraphBuilder() {
34 seq_ = std::make_unique<Sequential>(this);
35}
36
37Node *GraphBuilder::new_dispatch_node(Kernel *kernel,
38 const std::vector<aot::Arg> &args) {
39 for (const auto &arg : args) {
40 if (all_args_.find(arg.name) != all_args_.end()) {
41 TI_ERROR_IF(all_args_[arg.name] != arg,
42 "An arg with name {} already exists!", arg.name);
43 } else {
44 all_args_[arg.name] = arg;
45 }
46 }
47 all_nodes_.push_back(std::make_unique<Dispatch>(kernel, args));
48 return all_nodes_.back().get();
49}
50
51Sequential *GraphBuilder::new_sequential_node() {
52 all_nodes_.push_back(std::make_unique<Sequential>(this));
53 return static_cast<Sequential *>(all_nodes_.back().get());
54}
55
56std::unique_ptr<aot::CompiledGraph> GraphBuilder::compile() {
57 std::vector<aot::CompiledDispatch> dispatches;
58 seq()->compile(dispatches);
59 aot::CompiledGraph graph{dispatches, all_args_};
60 return std::make_unique<aot::CompiledGraph>(std::move(graph));
61}
62
63Sequential *GraphBuilder::seq() const {
64 return seq_.get();
65}
66
67void GraphBuilder::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
68 seq()->dispatch(kernel, args);
69}
70
71} // namespace taichi::lang
72