1 | #include "taichi/runtime/gfx/aot_module_loader_impl.h" |
2 | |
3 | #include <fstream> |
4 | #include <type_traits> |
5 | |
6 | #include "taichi/runtime/gfx/runtime.h" |
7 | #include "taichi/aot/graph_data.h" |
8 | |
9 | namespace taichi::lang { |
10 | namespace gfx { |
11 | namespace { |
12 | class FieldImpl : public aot::Field { |
13 | public: |
14 | explicit FieldImpl(GfxRuntime *runtime, const aot::CompiledFieldData &field) |
15 | : runtime_(runtime), field_(field) { |
16 | } |
17 | |
18 | private: |
19 | GfxRuntime *const runtime_; |
20 | aot::CompiledFieldData field_; |
21 | }; |
22 | |
23 | class AotModuleImpl : public aot::Module { |
24 | public: |
25 | explicit AotModuleImpl(const AotModuleParams ¶ms, Arch device_api_backend) |
26 | : module_path_(params.module_path), |
27 | runtime_(params.runtime), |
28 | device_api_backend_(device_api_backend) { |
29 | std::unique_ptr<io::VirtualDir> dir_alt = |
30 | io::VirtualDir::from_fs_dir(module_path_); |
31 | const io::VirtualDir *dir = |
32 | params.dir == nullptr ? dir_alt.get() : params.dir; |
33 | |
34 | bool succ = true; |
35 | |
36 | std::vector<uint8_t> metadata_json{}; |
37 | succ = dir->load_file("metadata.json" , metadata_json) != 0; |
38 | |
39 | if (!succ) { |
40 | mark_corrupted(); |
41 | TI_WARN("'metadata.json' cannot be read" ); |
42 | return; |
43 | } |
44 | auto json = liong::json::parse( |
45 | (const char *)metadata_json.data(), |
46 | (const char *)(metadata_json.data() + metadata_json.size())); |
47 | liong::json::deserialize(json, ti_aot_data_); |
48 | |
49 | if (!params.enable_lazy_loading) { |
50 | for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { |
51 | auto k = ti_aot_data_.kernels[i]; |
52 | std::vector<std::vector<uint32_t>> spirv_sources_codes; |
53 | for (int j = 0; j < k.tasks_attribs.size(); ++j) { |
54 | std::string spirv_path = k.tasks_attribs[j].name + ".spv" ; |
55 | |
56 | std::vector<uint32_t> spirv; |
57 | dir->load_file(spirv_path, spirv); |
58 | |
59 | if (spirv.size() == 0) { |
60 | mark_corrupted(); |
61 | TI_WARN("spirv '{}' cannot be read" , spirv_path); |
62 | return; |
63 | } |
64 | if (spirv.at(0) != 0x07230203) { |
65 | TI_WARN("spirv '{}' has a incorrect magic number {}" , spirv_path, |
66 | spirv.at(0)); |
67 | } |
68 | spirv_sources_codes.emplace_back(std::move(spirv)); |
69 | } |
70 | ti_aot_data_.spirv_codes.emplace_back(std::move(spirv_sources_codes)); |
71 | } |
72 | } |
73 | |
74 | std::vector<uint8_t> graphs_tcb{}; |
75 | succ = dir->load_file("graphs.tcb" , graphs_tcb) && |
76 | read_from_binary(graphs_, graphs_tcb.data(), graphs_tcb.size()); |
77 | |
78 | if (!succ) { |
79 | mark_corrupted(); |
80 | TI_WARN("'graphs.tcb' cannot be read" ); |
81 | return; |
82 | } |
83 | } |
84 | |
85 | std::unique_ptr<aot::CompiledGraph> get_graph( |
86 | const std::string &name) override { |
87 | auto it = graphs_.find(name); |
88 | if (it == graphs_.end()) { |
89 | TI_DEBUG("Cannot find graph {}" , name); |
90 | return nullptr; |
91 | } |
92 | |
93 | std::vector<aot::CompiledDispatch> dispatches; |
94 | for (auto &dispatch : it->second.dispatches) { |
95 | dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args, |
96 | get_kernel(dispatch.kernel_name)}); |
97 | } |
98 | aot::CompiledGraph graph{dispatches}; |
99 | return std::make_unique<aot::CompiledGraph>(std::move(graph)); |
100 | } |
101 | |
102 | size_t get_root_size() const override { |
103 | return ti_aot_data_.root_buffer_size; |
104 | } |
105 | |
106 | // Module metadata |
107 | Arch arch() const override { |
108 | return device_api_backend_; |
109 | } |
110 | uint64_t version() const override { |
111 | TI_NOT_IMPLEMENTED; |
112 | } |
113 | |
114 | private: |
115 | bool get_field_data_by_name(const std::string &name, |
116 | aot::CompiledFieldData &field) { |
117 | for (int i = 0; i < ti_aot_data_.fields.size(); ++i) { |
118 | if (ti_aot_data_.fields[i].field_name.rfind(name, 0) == 0) { |
119 | field = ti_aot_data_.fields[i]; |
120 | return true; |
121 | } |
122 | } |
123 | return false; |
124 | } |
125 | |
126 | bool get_kernel_params_by_name(const std::string &name, |
127 | GfxRuntime::RegisterParams &kernel) { |
128 | for (int i = 0; i < ti_aot_data_.kernels.size(); ++i) { |
129 | // Offloaded task names encode more than the name of the function, but for |
130 | // AOT, only use the name of the function which should be the first part |
131 | // of the struct |
132 | if (ti_aot_data_.kernels[i].name.rfind(name, 0) == 0) { |
133 | if (!try_load_spv_kernel(i)) { |
134 | return false; |
135 | } |
136 | kernel.kernel_attribs = ti_aot_data_.kernels[i]; |
137 | kernel.task_spirv_source_codes = ti_aot_data_.spirv_codes[i]; |
138 | // We don't have to store the number of SNodeTree in |ti_aot_data_| yet, |
139 | // because right now we only support a single SNodeTree during AOT. |
140 | // TODO: Support multiple SNodeTrees in AOT. |
141 | kernel.num_snode_trees = 1; |
142 | return true; |
143 | } |
144 | } |
145 | return false; |
146 | } |
147 | |
148 | std::unique_ptr<aot::Kernel> make_new_kernel( |
149 | const std::string &name) override { |
150 | GfxRuntime::RegisterParams kparams; |
151 | if (!get_kernel_params_by_name(name, kparams)) { |
152 | TI_DEBUG("Failed to load kernel {}" , name); |
153 | return nullptr; |
154 | } |
155 | return std::make_unique<KernelImpl>(runtime_, std::move(kparams)); |
156 | } |
157 | |
158 | std::unique_ptr<aot::KernelTemplate> make_new_kernel_template( |
159 | const std::string &name) override { |
160 | TI_NOT_IMPLEMENTED; |
161 | return nullptr; |
162 | } |
163 | |
164 | std::unique_ptr<aot::Field> make_new_field(const std::string &name) override { |
165 | aot::CompiledFieldData field; |
166 | if (!get_field_data_by_name(name, field)) { |
167 | TI_DEBUG("Failed to load field {}" , name); |
168 | return nullptr; |
169 | } |
170 | return std::make_unique<FieldImpl>(runtime_, field); |
171 | } |
172 | |
173 | bool try_load_spv_kernel(std::size_t index) { |
174 | if (index >= ti_aot_data_.spirv_codes.size() || |
175 | ti_aot_data_.spirv_codes[index].empty()) { |
176 | ti_aot_data_.spirv_codes.resize(index + 1); |
177 | auto &codes = ti_aot_data_.spirv_codes[index]; |
178 | const auto &k = ti_aot_data_.kernels[index]; |
179 | for (const auto &t : k.tasks_attribs) { |
180 | auto spv = read_spv_file(module_path_, t); |
181 | if (spv.empty()) { |
182 | mark_corrupted(); |
183 | return false; |
184 | } |
185 | codes.push_back(spv); |
186 | } |
187 | } |
188 | return true; |
189 | } |
190 | |
191 | static std::vector<uint32_t> read_spv_file(const std::string &output_dir, |
192 | const TaskAttributes &k) { |
193 | const std::string spv_path = fmt::format("{}/{}.spv" , output_dir, k.name); |
194 | std::vector<uint32_t> source_code; |
195 | std::ifstream fs(spv_path, std::ios_base::binary | std::ios::ate); |
196 | if (fs.is_open()) { |
197 | size_t size = fs.tellg(); |
198 | fs.seekg(0, std::ios::beg); |
199 | source_code.resize(size / sizeof(uint32_t)); |
200 | fs.read((char *)source_code.data(), size); |
201 | fs.close(); |
202 | } |
203 | return source_code; |
204 | } |
205 | |
206 | std::string module_path_; |
207 | TaichiAotData ti_aot_data_; |
208 | GfxRuntime *runtime_{nullptr}; |
209 | Arch device_api_backend_; |
210 | }; |
211 | |
212 | } // namespace |
213 | |
214 | std::unique_ptr<aot::Module> make_aot_module(std::any mod_params, |
215 | Arch device_api_backend) { |
216 | AotModuleParams params = std::any_cast<AotModuleParams &>(mod_params); |
217 | return std::make_unique<AotModuleImpl>(params, device_api_backend); |
218 | } |
219 | |
220 | } // namespace gfx |
221 | } // namespace taichi::lang |
222 | |