1#pragma once
2
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "taichi/aot/module_builder.h"
#include "taichi/runtime/gfx/aot_utils.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/codegen/spirv/kernel_utils.h"
11
12namespace taichi::lang {
13namespace gfx {
14
15class AotModuleBuilderImpl : public AotModuleBuilder {
16 public:
17 explicit AotModuleBuilderImpl(
18 const std::vector<CompiledSNodeStructs> &compiled_structs,
19 Arch device_api_backend,
20 const CompileConfig &compile_config,
21 const DeviceCapabilityConfig &caps);
22
23 void dump(const std::string &output_dir,
24 const std::string &filename) const override;
25
26 void mangle_aot_data();
27 void merge_with_old_meta_data(const std::string &path);
28 std::optional<GfxRuntime::RegisterParams> try_get_kernel_register_params(
29 const std::string &kernel_name) const;
30
31 private:
32 void add_per_backend(const std::string &identifier, Kernel *kernel) override;
33
34 void add_field_per_backend(const std::string &identifier,
35 const SNode *rep_snode,
36 bool is_scalar,
37 DataType dt,
38 std::vector<int> shape,
39 int row_num,
40 int column_num) override;
41
42 void add_per_backend_tmpl(const std::string &identifier,
43 const std::string &key,
44 Kernel *kernel) override;
45
46 std::string write_spv_file(const std::string &output_dir,
47 const TaskAttributes &k,
48 const std::vector<uint32_t> &source_code) const;
49
50 const std::vector<CompiledSNodeStructs> &compiled_structs_;
51 TaichiAotData ti_aot_data_;
52
53 Arch device_api_backend_;
54 const CompileConfig &config_;
55 DeviceCapabilityConfig caps_;
56};
57
58} // namespace gfx
59} // namespace taichi::lang
60