1 | #pragma once |
2 | |
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "taichi/aot/module_builder.h"
#include "taichi/runtime/gfx/aot_utils.h"
#include "taichi/runtime/gfx/runtime.h"
#include "taichi/codegen/spirv/snode_struct_compiler.h"
#include "taichi/codegen/spirv/kernel_utils.h"
11 | |
12 | namespace taichi::lang { |
13 | namespace gfx { |
14 | |
15 | class AotModuleBuilderImpl : public AotModuleBuilder { |
16 | public: |
17 | explicit AotModuleBuilderImpl( |
18 | const std::vector<CompiledSNodeStructs> &compiled_structs, |
19 | Arch device_api_backend, |
20 | const CompileConfig &compile_config, |
21 | const DeviceCapabilityConfig &caps); |
22 | |
23 | void dump(const std::string &output_dir, |
24 | const std::string &filename) const override; |
25 | |
26 | void mangle_aot_data(); |
27 | void merge_with_old_meta_data(const std::string &path); |
28 | std::optional<GfxRuntime::RegisterParams> try_get_kernel_register_params( |
29 | const std::string &kernel_name) const; |
30 | |
31 | private: |
32 | void add_per_backend(const std::string &identifier, Kernel *kernel) override; |
33 | |
34 | void add_field_per_backend(const std::string &identifier, |
35 | const SNode *rep_snode, |
36 | bool is_scalar, |
37 | DataType dt, |
38 | std::vector<int> shape, |
39 | int row_num, |
40 | int column_num) override; |
41 | |
42 | void add_per_backend_tmpl(const std::string &identifier, |
43 | const std::string &key, |
44 | Kernel *kernel) override; |
45 | |
46 | std::string write_spv_file(const std::string &output_dir, |
47 | const TaskAttributes &k, |
48 | const std::vector<uint32_t> &source_code) const; |
49 | |
50 | const std::vector<CompiledSNodeStructs> &compiled_structs_; |
51 | TaichiAotData ti_aot_data_; |
52 | |
53 | Arch device_api_backend_; |
54 | const CompileConfig &config_; |
55 | DeviceCapabilityConfig caps_; |
56 | }; |
57 | |
58 | } // namespace gfx |
59 | } // namespace taichi::lang |
60 | |