1#include "taichi/runtime/gfx/aot_module_loader_impl.h"
2
3#include <fstream>
4#include <type_traits>
5
6#include "taichi/runtime/gfx/runtime.h"
7#include "taichi/aot/graph_data.h"
8
9namespace taichi::lang {
10namespace gfx {
11namespace {
12class FieldImpl : public aot::Field {
13 public:
14 explicit FieldImpl(GfxRuntime *runtime, const aot::CompiledFieldData &field)
15 : runtime_(runtime), field_(field) {
16 }
17
18 private:
19 GfxRuntime *const runtime_;
20 aot::CompiledFieldData field_;
21};
22
23class AotModuleImpl : public aot::Module {
24 public:
25 explicit AotModuleImpl(const AotModuleParams &params, Arch device_api_backend)
26 : module_path_(params.module_path),
27 runtime_(params.runtime),
28 device_api_backend_(device_api_backend) {
29 std::unique_ptr<io::VirtualDir> dir_alt =
30 io::VirtualDir::from_fs_dir(module_path_);
31 const io::VirtualDir *dir =
32 params.dir == nullptr ? dir_alt.get() : params.dir;
33
34 bool succ = true;
35
36 std::vector<uint8_t> metadata_json{};
37 succ = dir->load_file("metadata.json", metadata_json) != 0;
38
39 if (!succ) {
40 mark_corrupted();
41 TI_WARN("'metadata.json' cannot be read");
42 return;
43 }
44 auto json = liong::json::parse(
45 (const char *)metadata_json.data(),
46 (const char *)(metadata_json.data() + metadata_json.size()));
47 liong::json::deserialize(json, ti_aot_data_);
48
49 if (!params.enable_lazy_loading) {
50 for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) {
51 auto k = ti_aot_data_.kernels[i];
52 std::vector<std::vector<uint32_t>> spirv_sources_codes;
53 for (int j = 0; j < k.tasks_attribs.size(); ++j) {
54 std::string spirv_path = k.tasks_attribs[j].name + ".spv";
55
56 std::vector<uint32_t> spirv;
57 dir->load_file(spirv_path, spirv);
58
59 if (spirv.size() == 0) {
60 mark_corrupted();
61 TI_WARN("spirv '{}' cannot be read", spirv_path);
62 return;
63 }
64 if (spirv.at(0) != 0x07230203) {
65 TI_WARN("spirv '{}' has a incorrect magic number {}", spirv_path,
66 spirv.at(0));
67 }
68 spirv_sources_codes.emplace_back(std::move(spirv));
69 }
70 ti_aot_data_.spirv_codes.emplace_back(std::move(spirv_sources_codes));
71 }
72 }
73
74 std::vector<uint8_t> graphs_tcb{};
75 succ = dir->load_file("graphs.tcb", graphs_tcb) &&
76 read_from_binary(graphs_, graphs_tcb.data(), graphs_tcb.size());
77
78 if (!succ) {
79 mark_corrupted();
80 TI_WARN("'graphs.tcb' cannot be read");
81 return;
82 }
83 }
84
85 std::unique_ptr<aot::CompiledGraph> get_graph(
86 const std::string &name) override {
87 auto it = graphs_.find(name);
88 if (it == graphs_.end()) {
89 TI_DEBUG("Cannot find graph {}", name);
90 return nullptr;
91 }
92
93 std::vector<aot::CompiledDispatch> dispatches;
94 for (auto &dispatch : it->second.dispatches) {
95 dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args,
96 get_kernel(dispatch.kernel_name)});
97 }
98 aot::CompiledGraph graph{dispatches};
99 return std::make_unique<aot::CompiledGraph>(std::move(graph));
100 }
101
102 size_t get_root_size() const override {
103 return ti_aot_data_.root_buffer_size;
104 }
105
106 // Module metadata
107 Arch arch() const override {
108 return device_api_backend_;
109 }
110 uint64_t version() const override {
111 TI_NOT_IMPLEMENTED;
112 }
113
114 private:
115 bool get_field_data_by_name(const std::string &name,
116 aot::CompiledFieldData &field) {
117 for (int i = 0; i < ti_aot_data_.fields.size(); ++i) {
118 if (ti_aot_data_.fields[i].field_name.rfind(name, 0) == 0) {
119 field = ti_aot_data_.fields[i];
120 return true;
121 }
122 }
123 return false;
124 }
125
126 bool get_kernel_params_by_name(const std::string &name,
127 GfxRuntime::RegisterParams &kernel) {
128 for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) {
129 // Offloaded task names encode more than the name of the function, but for
130 // AOT, only use the name of the function which should be the first part
131 // of the struct
132 if (ti_aot_data_.kernels[i].name.rfind(name, 0) == 0) {
133 if (!try_load_spv_kernel(i)) {
134 return false;
135 }
136 kernel.kernel_attribs = ti_aot_data_.kernels[i];
137 kernel.task_spirv_source_codes = ti_aot_data_.spirv_codes[i];
138 // We don't have to store the number of SNodeTree in |ti_aot_data_| yet,
139 // because right now we only support a single SNodeTree during AOT.
140 // TODO: Support multiple SNodeTrees in AOT.
141 kernel.num_snode_trees = 1;
142 return true;
143 }
144 }
145 return false;
146 }
147
148 std::unique_ptr<aot::Kernel> make_new_kernel(
149 const std::string &name) override {
150 GfxRuntime::RegisterParams kparams;
151 if (!get_kernel_params_by_name(name, kparams)) {
152 TI_DEBUG("Failed to load kernel {}", name);
153 return nullptr;
154 }
155 return std::make_unique<KernelImpl>(runtime_, std::move(kparams));
156 }
157
158 std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
159 const std::string &name) override {
160 TI_NOT_IMPLEMENTED;
161 return nullptr;
162 }
163
164 std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
165 aot::CompiledFieldData field;
166 if (!get_field_data_by_name(name, field)) {
167 TI_DEBUG("Failed to load field {}", name);
168 return nullptr;
169 }
170 return std::make_unique<FieldImpl>(runtime_, field);
171 }
172
173 bool try_load_spv_kernel(std::size_t index) {
174 if (index >= ti_aot_data_.spirv_codes.size() ||
175 ti_aot_data_.spirv_codes[index].empty()) {
176 ti_aot_data_.spirv_codes.resize(index + 1);
177 auto &codes = ti_aot_data_.spirv_codes[index];
178 const auto &k = ti_aot_data_.kernels[index];
179 for (const auto &t : k.tasks_attribs) {
180 auto spv = read_spv_file(module_path_, t);
181 if (spv.empty()) {
182 mark_corrupted();
183 return false;
184 }
185 codes.push_back(spv);
186 }
187 }
188 return true;
189 }
190
191 static std::vector<uint32_t> read_spv_file(const std::string &output_dir,
192 const TaskAttributes &k) {
193 const std::string spv_path = fmt::format("{}/{}.spv", output_dir, k.name);
194 std::vector<uint32_t> source_code;
195 std::ifstream fs(spv_path, std::ios_base::binary | std::ios::ate);
196 if (fs.is_open()) {
197 size_t size = fs.tellg();
198 fs.seekg(0, std::ios::beg);
199 source_code.resize(size / sizeof(uint32_t));
200 fs.read((char *)source_code.data(), size);
201 fs.close();
202 }
203 return source_code;
204 }
205
206 std::string module_path_;
207 TaichiAotData ti_aot_data_;
208 GfxRuntime *runtime_{nullptr};
209 Arch device_api_backend_;
210};
211
212} // namespace
213
214std::unique_ptr<aot::Module> make_aot_module(std::any mod_params,
215 Arch device_api_backend) {
216 AotModuleParams params = std::any_cast<AotModuleParams &>(mod_params);
217 return std::make_unique<AotModuleImpl>(params, device_api_backend);
218}
219
220} // namespace gfx
221} // namespace taichi::lang
222