#include "taichi/runtime/gfx/aot_module_builder_impl.h"

#include <fstream>
#include <type_traits>

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/runtime/gfx/aot_graph_data.h"

namespace taichi::lang {
namespace gfx {

namespace {
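// Converts the gfx backend's in-memory AOT data (TaichiAotData) into the
// backend-agnostic aot::ModuleData layout that gets serialized to disk.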
class AotDataConverter {
 public:
  static aot::ModuleData convert(const TaichiAotData &in) {
    AotDataConverter c{};
    return c.visit(in);
  }

 private:
  explicit AotDataConverter() = default;

  aot::ModuleData visit(const TaichiAotData &in) const {
    aot::ModuleData res{};
    for (const auto &ker : in.kernels) {
      res.kernels[ker.name] = visit(ker);
    }
    res.fields = in.fields;
    res.required_caps = in.required_caps;
    res.root_buffer_size = in.root_buffer_size;
    return res;
  }

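  // Flattens one kernel: its offloaded tasks, argument/return buffer sizes,
  // and per-argument metadata (scalar vs. array).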
  aot::CompiledTaichiKernel visit(
      const spirv::TaichiKernelAttributes &in) const {
    aot::CompiledTaichiKernel res{};
    res.tasks.reserve(in.tasks_attribs.size());
    for (const auto &t : in.tasks_attribs) {
      res.tasks.push_back(visit(t));
    }
    res.args_count = in.ctx_attribs.args().size();
    res.rets_count = in.ctx_attribs.rets().size();
    res.args_buffer_size = in.ctx_attribs.args_bytes();
    res.rets_buffer_size = in.ctx_attribs.rets_bytes();
    for (const auto &arg : in.ctx_attribs.args()) {
      if (!arg.is_array) {
        aot::ScalarArg scalar_arg{};
        scalar_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
        scalar_arg.offset_in_args_buf = arg.offset_in_mem;
        res.scalar_args[arg.index] = scalar_arg;
      } else {
        aot::ArrayArg arr_arg{};
        arr_arg.dtype_name = PrimitiveType::get(arg.dtype).to_string();
        arr_arg.field_dim = arg.field_dim;
        arr_arg.element_shape = arg.element_shape;
        arr_arg.shape_offset_in_args_buf = arg.index * sizeof(int32_t);
        res.arr_args[arg.index] = arr_arg;
      }
    }
    return res;
  }

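  // Converts a single offloaded task. A constant range hint is emitted only
  // when both loop bounds are known at compile time.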
  aot::CompiledOffloadedTask visit(const TaskAttributes &in) const {
    aot::CompiledOffloadedTask res{};
    res.type = offloaded_task_type_name(in.task_type);
    res.name = in.name;
    if (in.range_for_attribs && in.range_for_attribs->const_begin &&
        in.range_for_attribs->const_end) {
      res.range_hint = std::to_string(in.range_for_attribs->end -
                                      in.range_for_attribs->begin);
    }
    res.gpu_block_size = in.advisory_num_threads_per_group;
    for (auto &buffer_bind : in.buffer_binds) {
      if (buffer_bind.buffer.type == BufferType::Root) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Root, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::Rets) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Rets, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::GlobalTmps) {
        res.buffer_binds.push_back(
            {{aot::BufferType::GlobalTmps, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      } else if (buffer_bind.buffer.type == BufferType::Args) {
        res.buffer_binds.push_back(
            {{aot::BufferType::Args, buffer_bind.buffer.root_id},
             buffer_bind.binding});
      }
    }

    for (auto &texture_bind : in.texture_binds) {
      res.texture_binds.push_back(
          {texture_bind.arg_id, texture_bind.binding, texture_bind.is_storage});
    }
    return res;
  }
};

} // namespace

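// Snapshots the device capabilities and the root buffer size into the AOT
// metadata. Only the first compiled SNode tree is consulted, since a single
// SNodeTree is supported during AOT (see try_get_kernel_register_params()).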
AotModuleBuilderImpl::AotModuleBuilderImpl(
    const std::vector<CompiledSNodeStructs> &compiled_structs,
    Arch device_api_backend,
    const CompileConfig &compile_config,
    const DeviceCapabilityConfig &caps)
    : compiled_structs_(compiled_structs),
      device_api_backend_(device_api_backend),
      config_(compile_config),
      caps_(caps) {
  for (const auto &pair : caps.to_inner()) {
    ti_aot_data_.required_caps[to_string(pair.first)] = pair.second;
  }
  if (!compiled_structs.empty()) {
    ti_aot_data_.root_buffer_size = compiled_structs[0].root_size;
  }
}

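// Writes one task's SPIR-V words to "<output_dir>/<task_name>.spv" and
// returns the file name relative to output_dir.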
std::string AotModuleBuilderImpl::write_spv_file(
    const std::string &output_dir,
    const TaskAttributes &k,
    const std::vector<uint32_t> &source_code) const {
  const std::string spv_path = fmt::format("{}/{}.spv", output_dir, k.name);
  std::ofstream fs(spv_path, std::ios::binary | std::ios::trunc);
  fs.write(reinterpret_cast<const char *>(source_code.data()),
           source_code.size() * sizeof(uint32_t));
  fs.close();
  return k.name + ".spv";
}

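// Serializes the module: binary metadata (metadata.tcb), one .spv file per
// compiled task, a JSON mirror of the metadata (metadata.json), and the
// graphs (via dump_graph()).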
void AotModuleBuilderImpl::dump(const std::string &output_dir,
                                const std::string &filename) const {
  TI_WARN_IF(!filename.empty(),
             "Filename prefix is ignored on Unified Device API backends.");
  const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir);
  write_to_binary_file(ti_aot_data_, bin_path);

  auto converted = AotDataConverter::convert(ti_aot_data_);
  const auto &spirv_codes = ti_aot_data_.spirv_codes;
  for (std::size_t i = 0;
       i < std::min(ti_aot_data_.kernels.size(), spirv_codes.size()); ++i) {
    auto &k = ti_aot_data_.kernels[i];
    for (std::size_t j = 0;
         j < std::min(k.tasks_attribs.size(), spirv_codes[i].size()); ++j) {
      if (!spirv_codes[i][j].empty()) {
        std::string spv_path =
            write_spv_file(output_dir, k.tasks_attribs[j], spirv_codes[i][j]);
        converted.kernels[k.name].tasks[j].source_path = spv_path;
      }
    }
  }

  std::string json = liong::json::print(liong::json::serialize(ti_aot_data_));
  std::fstream f(output_dir + "/metadata.json",
                 std::ios::trunc | std::ios::out);
  f.write(json.data(), json.size());

  dump_graph(output_dir);
}

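// Rewrites each task name to "<kernel_name><task_index>", keeping task names
// deterministic for the offline cache.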
void AotModuleBuilderImpl::mangle_aot_data() {
  // Only for offline cache
  for (auto &kernel : ti_aot_data_.kernels) {
    const auto &prefix = kernel.name;
    for (std::size_t i = 0; i < kernel.tasks_attribs.size(); ++i) {
      kernel.tasks_attribs[i].name = prefix + std::to_string(i);
    }
  }
}

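// Appends the kernels recorded in an existing metadata.tcb under `path`, if
// one is present, so that the offline cache accumulates across runs.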
void AotModuleBuilderImpl::merge_with_old_meta_data(const std::string &path) {
  // Only for offline cache
  auto filename = taichi::join_path(path, "metadata.tcb");
  if (taichi::path_exists(filename)) {
    TaichiAotData old_data;
    read_from_binary_file(old_data, filename);
    // Ignore root_buffer_size and fields, which aren't needed for the offline
    // cache
    ti_aot_data_.kernels.insert(ti_aot_data_.kernels.end(),
                                old_data.kernels.begin(),
                                old_data.kernels.end());
  }
}

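// Looks up a compiled kernel by name and packages its attributes plus the
// per-task SPIR-V sources for runtime registration; returns std::nullopt if
// no kernel with that name has been added.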
std::optional<GfxRuntime::RegisterParams>
AotModuleBuilderImpl::try_get_kernel_register_params(
    const std::string &kernel_name) const {
  const auto &kernels = ti_aot_data_.kernels;
  for (std::size_t i = 0; i < kernels.size(); ++i) {
    if (kernels[i].name == kernel_name) {
      GfxRuntime::RegisterParams result;
      result.kernel_attribs = kernels[i];
      result.task_spirv_source_codes = ti_aot_data_.spirv_codes[i];
      // We only support a single SNodeTree during AOT.
      result.num_snode_trees = 1;
      return result;
    }
  }
  return std::nullopt;
}

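// Lowers the kernel and runs SPIR-V codegen for the target device API, then
// records the compiled attributes and sources under `identifier`.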
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
                                           Kernel *kernel) {
  spirv::lower(config_, kernel);
  auto compiled = run_codegen(kernel, device_api_backend_, caps_,
                              compiled_structs_, config_);
  compiled.kernel_attribs.name = identifier;
  ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

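// Records the metadata of a dense field: dtype, shape, and the field's
// memory offset inside its parent cell.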
void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
                                                 const SNode *rep_snode,
                                                 bool is_scalar,
                                                 DataType dt,
                                                 std::vector<int> shape,
                                                 int row_num,
                                                 int column_num) {
  // Note that currently we only support adding dense fields in AOT for all
  // backends. The OpenGL backend errors out only when a non-dense field is
  // added to the AOT module, whereas the Metal backend errors out earlier,
  // when the AOT module is constructed. Ideally we would unify this behavior,
  // but it doesn't matter much for now.
  TI_ERROR_IF(!all_fields_are_dense_in_container(rep_snode->parent),
              "AOT: only dense fields are supported");

  const auto &dense_desc =
      compiled_structs_[0].snode_descriptors.at(rep_snode->parent->id);

  aot::CompiledFieldData field_data;
  field_data.field_name = identifier;
  field_data.is_scalar = is_scalar;
  field_data.dtype = static_cast<int>(dt->cast<PrimitiveType>()->type);
  field_data.dtype_name = dt.to_string();
  field_data.shape = shape;
  field_data.mem_offset_in_parent = dense_desc.mem_offset_in_parent_cell;
  if (!is_scalar) {
    field_data.element_shape = {row_num, column_num};
  }
  ti_aot_data_.fields.push_back(field_data);
}

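// Same as add_per_backend(), but mangles the kernel name as
// "identifier|key" to distinguish template instantiations.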
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
                                                const std::string &key,
                                                Kernel *kernel) {
  spirv::lower(config_, kernel);
  auto compiled = run_codegen(kernel, device_api_backend_, caps_,
                              compiled_structs_, config_);
  compiled.kernel_attribs.name = identifier + "|" + key;
  ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

} // namespace gfx
} // namespace taichi::lang