#include "taichi/aot/module_builder.h"

#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "taichi/program/kernel.h"
3
4namespace taichi::lang {
5
6void AotModuleBuilder::add(const std::string &identifier, Kernel *kernel) {
7 add_per_backend(identifier, kernel);
8}
9
10void AotModuleBuilder::add_field(const std::string &identifier,
11 const SNode *rep_snode,
12 bool is_scalar,
13 DataType dt,
14 std::vector<int> shape,
15 int row_num,
16 int column_num) {
17 add_field_per_backend(identifier, rep_snode, is_scalar, dt, shape, row_num,
18 column_num);
19}
20
21void AotModuleBuilder::add_kernel_template(const std::string &identifier,
22 const std::string &key,
23 Kernel *kernel) {
24 add_per_backend_tmpl(identifier, key, kernel);
25}
26
27bool AotModuleBuilder::all_fields_are_dense_in_container(
28 const SNode *container) {
29 for (const auto &ch : container->ch) {
30 if (ch->type != SNodeType::place) {
31 return false;
32 }
33 }
34 const auto *parent = container->parent;
35 if (!parent) {
36 return false;
37 }
38 if (parent->type != SNodeType::root) {
39 return false;
40 }
41 return true;
42}
43
44void AotModuleBuilder::load(const std::string &output_dir) {
45 TI_ERROR("Aot loader not supported");
46}
47
48void AotModuleBuilder::dump_graph(std::string output_dir) const {
49 const std::string graph_file = fmt::format("{}/graphs.tcb", output_dir);
50 write_to_binary_file(graphs_, graph_file);
51}
52
53void AotModuleBuilder::add_graph(const std::string &name,
54 const aot::CompiledGraph &graph) {
55 if (graphs_.count(name) != 0) {
56 TI_ERROR("Graph {} already exists", name);
57 }
58 // Handle adding kernels separately.
59 std::unordered_map<std::string, lang::Kernel *> kernels;
60 for (const auto &dispatch : graph.dispatches) {
61 kernels[dispatch.kernel_name] = dispatch.ti_kernel;
62 }
63 for (auto &e : kernels) {
64 add(e.first, e.second);
65 }
66 graphs_[name] = graph;
67}
68} // namespace taichi::lang
69