1 | #include "taichi/aot/module_loader.h" |
---|---|
2 | |
3 | #include "taichi/runtime/gfx/aot_module_loader_impl.h" |
4 | #include "taichi/runtime/dx12/aot_module_loader_impl.h" |
5 | |
6 | namespace taichi::lang { |
7 | namespace aot { |
8 | namespace { |
9 | |
10 | std::string make_kernel_key( |
11 | const std::vector<KernelTemplateArg> &template_args) { |
12 | TI_NOT_IMPLEMENTED; |
13 | return ""; |
14 | } |
15 | |
16 | } // namespace |
17 | |
18 | Kernel *KernelTemplate::get_kernel( |
19 | const std::vector<KernelTemplateArg> &template_args) { |
20 | const auto key = make_kernel_key(template_args); |
21 | auto itr = loaded_kernels_.find(key); |
22 | if (itr != loaded_kernels_.end()) { |
23 | return itr->second.get(); |
24 | } |
25 | auto k = make_new_kernel(template_args); |
26 | auto *kptr = k.get(); |
27 | loaded_kernels_[key] = std::move(k); |
28 | return kptr; |
29 | } |
30 | |
31 | std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) { |
32 | if (arch == Arch::vulkan) { |
33 | #ifdef TI_WITH_VULKAN |
34 | return gfx::make_aot_module(mod_params, arch); |
35 | #endif |
36 | } else if (arch == Arch::opengl) { |
37 | #ifdef TI_WITH_OPENGL |
38 | return gfx::make_aot_module(mod_params, arch); |
39 | #endif |
40 | } else if (arch == Arch::gles) { |
41 | #ifdef TI_WITH_OPENGL |
42 | return gfx::make_aot_module(mod_params, arch); |
43 | #endif |
44 | } else if (arch == Arch::dx11) { |
45 | #ifdef TI_WITH_DX11 |
46 | return gfx::make_aot_module(mod_params, arch); |
47 | #endif |
48 | } else if (arch == Arch::dx12) { |
49 | #ifdef TI_WITH_DX12 |
50 | return directx12::make_aot_module(mod_params, arch); |
51 | #endif |
52 | } else if (arch == Arch::metal) { |
53 | #ifdef TI_WITH_METAL |
54 | return gfx::make_aot_module(mod_params, arch); |
55 | #endif |
56 | } |
57 | TI_NOT_IMPLEMENTED; |
58 | } |
59 | |
60 | Kernel *Module::get_kernel(const std::string &name) { |
61 | auto itr = loaded_kernels_.find(name); |
62 | if (itr != loaded_kernels_.end()) { |
63 | return itr->second.get(); |
64 | } |
65 | auto k = make_new_kernel(name); |
66 | auto *kptr = k.get(); |
67 | loaded_kernels_[name] = std::move(k); |
68 | return kptr; |
69 | } |
70 | |
71 | KernelTemplate *Module::get_kernel_template(const std::string &name) { |
72 | auto itr = loaded_kernel_templates_.find(name); |
73 | if (itr != loaded_kernel_templates_.end()) { |
74 | return itr->second.get(); |
75 | } |
76 | auto kt = make_new_kernel_template(name); |
77 | auto *kt_ptr = kt.get(); |
78 | loaded_kernel_templates_[name] = std::move(kt); |
79 | return kt_ptr; |
80 | } |
81 | |
82 | Field *Module::get_snode_tree(const std::string &name) { |
83 | auto itr = loaded_fields_.find(name); |
84 | if (itr != loaded_fields_.end()) { |
85 | return itr->second.get(); |
86 | } |
87 | auto k = make_new_field(name); |
88 | auto *kptr = k.get(); |
89 | loaded_fields_[name] = std::move(k); |
90 | return kptr; |
91 | } |
92 | |
93 | } // namespace aot |
94 | } // namespace taichi::lang |
95 |