1#include "taichi/aot/module_loader.h"
2
3#include "taichi/runtime/gfx/aot_module_loader_impl.h"
4#include "taichi/runtime/dx12/aot_module_loader_impl.h"
5
6namespace taichi::lang {
7namespace aot {
8namespace {
9
10std::string make_kernel_key(
11 const std::vector<KernelTemplateArg> &template_args) {
12 TI_NOT_IMPLEMENTED;
13 return "";
14}
15
16} // namespace
17
18Kernel *KernelTemplate::get_kernel(
19 const std::vector<KernelTemplateArg> &template_args) {
20 const auto key = make_kernel_key(template_args);
21 auto itr = loaded_kernels_.find(key);
22 if (itr != loaded_kernels_.end()) {
23 return itr->second.get();
24 }
25 auto k = make_new_kernel(template_args);
26 auto *kptr = k.get();
27 loaded_kernels_[key] = std::move(k);
28 return kptr;
29}
30
31std::unique_ptr<Module> Module::load(Arch arch, std::any mod_params) {
32 if (arch == Arch::vulkan) {
33#ifdef TI_WITH_VULKAN
34 return gfx::make_aot_module(mod_params, arch);
35#endif
36 } else if (arch == Arch::opengl) {
37#ifdef TI_WITH_OPENGL
38 return gfx::make_aot_module(mod_params, arch);
39#endif
40 } else if (arch == Arch::gles) {
41#ifdef TI_WITH_OPENGL
42 return gfx::make_aot_module(mod_params, arch);
43#endif
44 } else if (arch == Arch::dx11) {
45#ifdef TI_WITH_DX11
46 return gfx::make_aot_module(mod_params, arch);
47#endif
48 } else if (arch == Arch::dx12) {
49#ifdef TI_WITH_DX12
50 return directx12::make_aot_module(mod_params, arch);
51#endif
52 } else if (arch == Arch::metal) {
53#ifdef TI_WITH_METAL
54 return gfx::make_aot_module(mod_params, arch);
55#endif
56 }
57 TI_NOT_IMPLEMENTED;
58}
59
60Kernel *Module::get_kernel(const std::string &name) {
61 auto itr = loaded_kernels_.find(name);
62 if (itr != loaded_kernels_.end()) {
63 return itr->second.get();
64 }
65 auto k = make_new_kernel(name);
66 auto *kptr = k.get();
67 loaded_kernels_[name] = std::move(k);
68 return kptr;
69}
70
71KernelTemplate *Module::get_kernel_template(const std::string &name) {
72 auto itr = loaded_kernel_templates_.find(name);
73 if (itr != loaded_kernel_templates_.end()) {
74 return itr->second.get();
75 }
76 auto kt = make_new_kernel_template(name);
77 auto *kt_ptr = kt.get();
78 loaded_kernel_templates_[name] = std::move(kt);
79 return kt_ptr;
80}
81
82Field *Module::get_snode_tree(const std::string &name) {
83 auto itr = loaded_fields_.find(name);
84 if (itr != loaded_fields_.end()) {
85 return itr->second.get();
86 }
87 auto k = make_new_field(name);
88 auto *kptr = k.get();
89 loaded_fields_[name] = std::move(k);
90 return kptr;
91}
92
93} // namespace aot
94} // namespace taichi::lang
95