1#pragma once
2
#include <string>
#include <unordered_map>
#include <vector>

#include "taichi/aot/module_data.h"
#include "taichi/rhi/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"
11
12namespace taichi::lang {
13
14class Kernel;
15class DataType;
16
17class AotModuleBuilder {
18 public:
19 virtual ~AotModuleBuilder() = default;
20
21 void add(const std::string &identifier, Kernel *kernel);
22
23 void add_field(const std::string &identifier,
24 const SNode *rep_snode,
25 bool is_scalar,
26 DataType dt,
27 std::vector<int> shape,
28 int row_num,
29 int column_num);
30
31 void add_kernel_template(const std::string &identifier,
32 const std::string &key,
33 Kernel *kernel);
34
35 virtual void load(const std::string &output_dir);
36
37 virtual void dump(const std::string &output_dir,
38 const std::string &filename) const = 0;
39
40 void add_graph(const std::string &name, const aot::CompiledGraph &graph);
41
42 protected:
43 /**
44 * Intended to be overridden by each backend's implementation.
45 */
46 virtual void add_per_backend(const std::string &identifier,
47 Kernel *kernel) = 0;
48 virtual void add_field_per_backend(const std::string &identifier,
49 const SNode *rep_snode,
50 bool is_scalar,
51 DataType dt,
52 std::vector<int> shape,
53 int row_num,
54 int column_num) {
55 TI_NOT_IMPLEMENTED;
56 }
57
58 virtual void add_ndarray_per_backend(const std::string &identifier,
59 bool is_scalar,
60 DataType dt,
61 std::vector<int> shape,
62 int row_num,
63 int column_num) {
64 TI_NOT_IMPLEMENTED;
65 }
66
67 virtual void add_per_backend_tmpl(const std::string &identifier,
68 const std::string &key,
69 Kernel *kernel) {
70 TI_NOT_IMPLEMENTED;
71 }
72
73 void dump_graph(std::string output_dir) const;
74
75 static bool all_fields_are_dense_in_container(const SNode *container);
76
77 private:
78 std::unordered_map<std::string, aot::CompiledGraph> graphs_;
79};
80
81} // namespace taichi::lang
82