1 | #include "taichi/runtime/llvm/llvm_aot_module_loader.h" |
2 | #include "taichi/runtime/llvm/aot_graph_data.h" |
3 | |
4 | namespace taichi::lang { |
5 | |
6 | LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache( |
7 | const std::string &name) { |
8 | TI_ASSERT(cache_reader_ != nullptr); |
9 | auto *tlctx = executor_->get_llvm_context(); |
10 | LlvmOfflineCache::KernelCacheData loaded; |
11 | auto ok = cache_reader_->get_kernel_cache(loaded, name, |
12 | *tlctx->get_this_thread_context()); |
13 | TI_ERROR_IF(!ok, "Failed to load kernel={}" , name); |
14 | return loaded; |
15 | } |
16 | |
17 | std::unique_ptr<aot::Kernel> LlvmAotModule::make_new_kernel( |
18 | const std::string &name) { |
19 | auto fn = convert_module_to_function(name, load_kernel_from_cache(name)); |
20 | return std::make_unique<llvm_aot::KernelImpl>( |
21 | fn, LlvmOfflineCache::KernelCacheData()); |
22 | } |
23 | |
24 | std::unique_ptr<aot::Field> LlvmAotModule::make_new_field( |
25 | const std::string &name) { |
26 | // Check if "name" represents snode_tree_id. |
27 | // Avoid using std::atoi due to its poor error handling. |
28 | char *end; |
29 | int snode_tree_id = static_cast<int>(strtol(name.c_str(), &end, 10 /*base*/)); |
30 | |
31 | TI_ASSERT(end != name.c_str()); |
32 | TI_ASSERT(*end == '\0'); |
33 | |
34 | // Load FieldCache |
35 | LlvmOfflineCache::FieldCacheData loaded; |
36 | auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id); |
37 | TI_ERROR_IF(!ok, "Failed to load field with id={}" , snode_tree_id); |
38 | |
39 | return std::make_unique<llvm_aot::FieldImpl>(std::move(loaded)); |
40 | } |
41 | |
42 | std::unique_ptr<aot::CompiledGraph> LlvmAotModule::get_graph( |
43 | const std::string &name) { |
44 | auto it = graphs_.find(name); |
45 | if (it == graphs_.end()) { |
46 | TI_DEBUG("Cannot find graph {}" , name); |
47 | return nullptr; |
48 | } |
49 | |
50 | std::vector<aot::CompiledDispatch> dispatches; |
51 | for (auto &dispatch : it->second.dispatches) { |
52 | dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args, |
53 | get_kernel(dispatch.kernel_name)}); |
54 | } |
55 | |
56 | aot::CompiledGraph graph = aot::CompiledGraph({dispatches}); |
57 | executor_->prepare_runtime_context(&graph.ctx_); |
58 | |
59 | return std::make_unique<aot::CompiledGraph>(std::move(graph)); |
60 | } |
61 | |
62 | void allocate_aot_snode_tree_type(aot::Module *aot_module, |
63 | aot::Field *aot_field, |
64 | uint64 *result_buffer) { |
65 | auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module); |
66 | auto *aot_field_impl = dynamic_cast<llvm_aot::FieldImpl *>(aot_field); |
67 | |
68 | TI_ASSERT(llvm_aot_module != nullptr); |
69 | TI_ASSERT(aot_field_impl != nullptr); |
70 | |
71 | auto *runtime_executor = llvm_aot_module->get_runtime_executor(); |
72 | const auto &field_cache = aot_field_impl->get_snode_tree_cache(); |
73 | |
74 | int snode_tree_id = field_cache.tree_id; |
75 | if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) { |
76 | runtime_executor->initialize_llvm_runtime_snodes(field_cache, |
77 | result_buffer); |
78 | llvm_aot_module->set_initialized_snode_tree(snode_tree_id); |
79 | } |
80 | } |
81 | |
82 | } // namespace taichi::lang |
83 | |