#include "taichi/runtime/gfx/aot_module_builder_impl.h"

#include <fstream>
#include <type_traits>

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/runtime/gfx/aot_graph_data.h"

namespace taichi::lang {
namespace gfx {

namespace {
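// Converts the in-memory TaichiAotData produced during codegen into the
// serializable aot::ModuleData layout that is written out as module metadata.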
class AotDataConverter {
 public:
  static aot::ModuleData convert(const TaichiAotData &in) {
    AotDataConverter c{};
    return c.visit(in);
  }

 private:
  explicit AotDataConverter() = default;

  aot::ModuleData visit(const TaichiAotData &in) const {
    aot::ModuleData res{};
    for (const auto &ker : in.kernels) {
      auto val = visit(ker);
      res.kernels[ker.name] = val;
    }
    res.fields = in.fields;
    res.required_caps = in.required_caps;
    res.root_buffer_size = in.root_buffer_size;
    return res;
  }

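  // Flattens one kernel: its per-task attributes plus the scalar and array
  // argument layouts recorded in the kernel's context attributes.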
  aot::CompiledTaichiKernel visit(
      const spirv::TaichiKernelAttributes &in) const {
    aot::CompiledTaichiKernel res{};
    res.tasks.reserve(in.tasks_attribs.size());
    for (const auto &t : in.tasks_attribs) {
      res.tasks.push_back(visit(t));
    }
    res.args_count = in.ctx_attribs.args().size();
    res.rets_count = in.ctx_attribs.rets().size();
    res.args_buffer_size = in.ctx_attribs.args_bytes();
    res.rets_buffer_size = in.ctx_attribs.rets_bytes();
    for (const auto &arg : in.ctx_attribs.args()) {
      if (!arg.is_array) {
        aot::ScalarArg scalar_arg{};
        scalar_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
        scalar_arg.offset_in_args_buf = arg.offset_in_mem;
        res.scalar_args[arg.index] = scalar_arg;
      } else {
        aot::ArrayArg arr_arg{};
        arr_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
        arr_arg.field_dim = arg.field_dim;
        arr_arg.element_shape = arg.element_shape;
        arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t);
        res.arr_args[arg.index] = arr_arg;
      }
    }
    return res;
  }

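  // Translates a single offloaded task: its type and advisory block size, a
  // constant range hint when both loop bounds are known at compile time, and
  // the buffer/texture bindings the task expects at dispatch time.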
  aot::CompiledOffloadedTask visit(const TaskAttributes &in) const {
    aot::CompiledOffloadedTask res{};
    res.type = offloaded_task_type_name(in.task_type);
    res.name = in.name;
    if (in.range_for_attribs && in.range_for_attribs->const_begin &&
        in.range_for_attribs->const_end) {
      res.range_hint = std::to_string(in.range_for_attribs->end -
                                      in.range_for_attribs->begin);
    }
    res.gpu_block_size = in.advisory_num_threads_per_group;
    for (auto &buffer_bind : in.buffer_binds) {
      if (buffer_bind.buffer.type == BufferType::Root) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Root, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::Rets) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Rets, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::GlobalTmps) {
        res.buffer_binds.push_back(
            {{aot::BufferType::GlobalTmps, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::Args) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Args, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      }
    }

    for (auto &texture_bind : in.texture_binds) {
      res.texture_binds.push_back(
          {texture_bind.arg_id, texture_bind.binding,
           texture_bind.is_storage});
    }
    return res;
  }
};

}  // namespace

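// Records the device capabilities required by the compiled module and, when
// SNode structs are present, the root buffer size of the first SNode tree.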
AotModuleBuilderImpl::AotModuleBuilderImpl(
    const std::vector<CompiledSNodeStructs> &compiled_structs,
    Arch device_api_backend,
    const CompileConfig &compile_config,
    const DeviceCapabilityConfig &caps)
    : compiled_structs_(compiled_structs),
      device_api_backend_(device_api_backend),
      config_(compile_config),
      caps_(caps) {
  for (const auto &pair : caps.to_inner()) {
    ti_aot_data_.required_caps[to_string(pair.first)] = pair.second;
  }
  if (!compiled_structs.empty()) {
    ti_aot_data_.root_buffer_size = compiled_structs[0].root_size;
  }
}

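// Writes the SPIR-V words of one task to "<output_dir>/<task_name>.spv" and
// returns the file name relative to the output directory, which is what gets
// recorded in the JSON metadata.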
std::string AotModuleBuilderImpl::write_spv_file(
    const std::string &output_dir,
    const TaskAttributes &k,
    const std::vector<uint32_t> &source_code) const {
  const std::string spv_path = fmt::format("{}/{}.spv", output_dir, k.name);
  std::ofstream fs(spv_path, std::ios_base::binary | std::ios::trunc);
  fs.write(reinterpret_cast<const char *>(source_code.data()),
           source_code.size() * sizeof(uint32_t));
  fs.close();
  return k.name + ".spv";
}

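// Serializes the module into `output_dir`: binary metadata (metadata.tcb), a
// JSON dump of the same data (metadata.json), one .spv file per task with
// non-empty SPIR-V, plus whatever dump_graph() emits for the compiled graphs.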
void AotModuleBuilderImpl::dump(const std::string &output_dir,
                                const std::string &filename) const {
  TI_WARN_IF(!filename.empty(),
             "Filename prefix is ignored on Unified Device API backends.");
  const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir);
  write_to_binary_file(ti_aot_data_, bin_path);

  auto converted = AotDataConverter::convert(ti_aot_data_);
  const auto &spirv_codes = ti_aot_data_.spirv_codes;
  for (std::size_t i = 0;
       i < std::min(ti_aot_data_.kernels.size(), spirv_codes.size()); ++i) {
    auto &k = ti_aot_data_.kernels[i];
    for (std::size_t j = 0;
         j < std::min(k.tasks_attribs.size(), spirv_codes[i].size()); ++j) {
      if (!spirv_codes[i][j].empty()) {
        std::string spv_path =
            write_spv_file(output_dir, k.tasks_attribs[j], spirv_codes[i][j]);
        converted.kernels[k.name].tasks[j].source_path = spv_path;
      }
    }
  }

  std::string json = liong::json::print(liong::json::serialize(ti_aot_data_));
  std::fstream f(output_dir + "/metadata.json",
                 std::ios::trunc | std::ios::out);
  f.write(json.data(), json.size());

  dump_graph(output_dir);
}

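// Renames each task to "<kernel name><task index>" so that cached kernels
// carry unambiguous, per-kernel task names.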
void AotModuleBuilderImpl::mangle_aot_data() {
  // Only used by the offline cache.
  for (auto &kernel : ti_aot_data_.kernels) {
    const auto &prefix = kernel.name;
    for (std::size_t i = 0; i < kernel.tasks_attribs.size(); ++i) {
      kernel.tasks_attribs[i].name = prefix + std::to_string(i);
    }
  }
}

void AotModuleBuilderImpl::merge_with_old_meta_data(const std::string &path) {
  // Only used by the offline cache.
  auto filename = taichi::join_path(path, "metadata.tcb");
  if (taichi::path_exists(filename)) {
    TaichiAotData old_data;
    read_from_binary_file(old_data, filename);
    // Ignore root_buffer_size and fields, which aren't needed for the offline
    // cache.
    ti_aot_data_.kernels.insert(ti_aot_data_.kernels.end(),
                                old_data.kernels.begin(),
                                old_data.kernels.end());
  }
}

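// Looks up a compiled kernel by name and, if found, packages its attributes
// and per-task SPIR-V sources into the parameters GfxRuntime needs to
// register the kernel.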
std::optional<GfxRuntime::RegisterParams>
AotModuleBuilderImpl::try_get_kernel_register_params(
    const std::string &kernel_name) const {
  const auto &kernels = ti_aot_data_.kernels;
  for (std::size_t i = 0; i < kernels.size(); ++i) {
    if (kernels[i].name == kernel_name) {
      GfxRuntime::RegisterParams result;
      result.kernel_attribs = kernels[i];
      result.task_spirv_source_codes = ti_aot_data_.spirv_codes[i];
      // We only support a single SNodeTree during AOT.
      result.num_snode_trees = 1;
      return result;
    }
  }
  return std::nullopt;
}

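// Lowers and codegens a single kernel for the target device API backend, then
// records its attributes and per-task SPIR-V under `identifier`.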
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
                                           Kernel *kernel) {
  spirv::lower(config_, kernel);
  auto compiled = run_codegen(kernel, device_api_backend_, caps_,
                              compiled_structs_, config_);
  compiled.kernel_attribs.name = identifier;
  ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
                                                 const SNode *rep_snode,
                                                 bool is_scalar,
                                                 DataType dt,
                                                 std::vector<int> shape,
                                                 int row_num,
                                                 int column_num) {
  // Note that we currently only support adding dense fields in AOT for all
  // backends. The OpenGL backend only errors out when a non-dense field is
  // added to the AOT module, whereas the Metal backend errors out earlier,
  // when the AOT module is constructed. Ideally we would unify this behavior,
  // but it doesn't matter too much for now.
  TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
              "AOT: only supports dense field");

  const auto &dense_desc =
      compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

  aot::CompiledFieldData field_data;
  field_data.field_name = identifier;
  field_data.is_scalar = is_scalar;
  field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
  field_data.dtype_name = dt.to_string();
  field_data.shape = shape;
  field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell;
  if (!is_scalar) {
    field_data.element_shape = {row_num, column_num};
  }
  ti_aot_data_.fields.push_back(field_data);
}

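// Same as add_per_backend(), but for template instantiations: the kernel is
// stored under the mangled name "<identifier>|<key>".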
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
                                                const std::string &key,
                                                Kernel *kernel) {
  spirv::lower(config_, kernel);
  auto compiled = run_codegen(kernel, device_api_backend_, caps_,
                              compiled_structs_, config_);
  compiled.kernel_attribs.name = identifier + "|" + key;
  ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

}  // namespace gfx
}  // namespace taichi::lang