#include "taichi/aot/module_builder.h"

#include <utility>

#include "taichi/program/kernel.h"
3 | |
4 | namespace taichi::lang { |
5 | |
6 | void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) { |
7 | add_per_backend(identifier, kernel); |
8 | } |
9 | |
10 | void AotModuleBuilder::add_field(const std::string &identifier, |
11 | const SNode *rep_snode, |
12 | bool is_scalar, |
13 | DataType dt, |
14 | std::vector<int> shape, |
15 | int row_num, |
16 | int column_num) { |
17 | add_field_per_backend(identifier, rep_snode, is_scalar, dt, shape, row_num, |
18 | column_num); |
19 | } |
20 | |
21 | void AotModuleBuilder::add_kernel_template(const std::string &identifier, |
22 | const std::string &key, |
23 | Kernel *kernel) { |
24 | add_per_backend_tmpl(identifier, key, kernel); |
25 | } |
26 | |
27 | bool AotModuleBuilder::all_fields_are_dense_in_container( |
28 | const SNode *container) { |
29 | for (const auto &ch : container->ch) { |
30 | if (ch->type != SNodeType::place) { |
31 | return false; |
32 | } |
33 | } |
34 | const auto *parent = container->parent; |
35 | if (!parent) { |
36 | return false; |
37 | } |
38 | if (parent->type != SNodeType::root) { |
39 | return false; |
40 | } |
41 | return true; |
42 | } |
43 | |
44 | void AotModuleBuilder::load(const std::string &output_dir) { |
45 | TI_ERROR("Aot loader not supported" ); |
46 | } |
47 | |
48 | void AotModuleBuilder::dump_graph(std::string output_dir) const { |
49 | const std::string graph_file = fmt::format("{}/graphs.tcb" , output_dir); |
50 | write_to_binary_file(graphs_, graph_file); |
51 | } |
52 | |
53 | void AotModuleBuilder::add_graph(const std::string &name, |
54 | const aot::CompiledGraph &graph) { |
55 | if (graphs_.count(name) != 0) { |
56 | TI_ERROR("Graph {} already exists" , name); |
57 | } |
58 | // Handle adding kernels separately. |
59 | std::unordered_map<std::string, lang::Kernel *> kernels; |
60 | for (const auto &dispatch : graph.dispatches) { |
61 | kernels[dispatch.kernel_name] = dispatch.ti_kernel; |
62 | } |
63 | for (auto &e : kernels) { |
64 | add(e.first, e.second); |
65 | } |
66 | graphs_[name] = graph; |
67 | } |
68 | } // namespace taichi::lang |
69 | |