#include "taichi/runtime/llvm/llvm_aot_module_loader.h"

#include <charconv>
#include <system_error>

#include "taichi/runtime/llvm/aot_graph_data.h"
3
4namespace taichi::lang {
5
6LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache(
7 const std::string &name) {
8 TI_ASSERT(cache_reader_ != nullptr);
9 auto *tlctx = executor_->get_llvm_context();
10 LlvmOfflineCache::KernelCacheData loaded;
11 auto ok = cache_reader_->get_kernel_cache(loaded, name,
12 *tlctx->get_this_thread_context());
13 TI_ERROR_IF(!ok, "Failed to load kernel={}", name);
14 return loaded;
15}
16
17std::unique_ptr<aot::Kernel> LlvmAotModule::make_new_kernel(
18 const std::string &name) {
19 auto fn = convert_module_to_function(name, load_kernel_from_cache(name));
20 return std::make_unique<llvm_aot::KernelImpl>(
21 fn, LlvmOfflineCache::KernelCacheData());
22}
23
24std::unique_ptr<aot::Field> LlvmAotModule::make_new_field(
25 const std::string &name) {
26 // Check if "name" represents snode_tree_id.
27 // Avoid using std::atoi due to its poor error handling.
28 char *end;
29 int snode_tree_id = static_cast<int>(strtol(name.c_str(), &end, 10 /*base*/));
30
31 TI_ASSERT(end != name.c_str());
32 TI_ASSERT(*end == '\0');
33
34 // Load FieldCache
35 LlvmOfflineCache::FieldCacheData loaded;
36 auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id);
37 TI_ERROR_IF(!ok, "Failed to load field with id={}", snode_tree_id);
38
39 return std::make_unique<llvm_aot::FieldImpl>(std::move(loaded));
40}
41
42std::unique_ptr<aot::CompiledGraph> LlvmAotModule::get_graph(
43 const std::string &name) {
44 auto it = graphs_.find(name);
45 if (it == graphs_.end()) {
46 TI_DEBUG("Cannot find graph {}", name);
47 return nullptr;
48 }
49
50 std::vector<aot::CompiledDispatch> dispatches;
51 for (auto &dispatch : it->second.dispatches) {
52 dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args,
53 get_kernel(dispatch.kernel_name)});
54 }
55
56 aot::CompiledGraph graph = aot::CompiledGraph({dispatches});
57 executor_->prepare_runtime_context(&graph.ctx_);
58
59 return std::make_unique<aot::CompiledGraph>(std::move(graph));
60}
61
62void allocate_aot_snode_tree_type(aot::Module *aot_module,
63 aot::Field *aot_field,
64 uint64 *result_buffer) {
65 auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module);
66 auto *aot_field_impl = dynamic_cast<llvm_aot::FieldImpl *>(aot_field);
67
68 TI_ASSERT(llvm_aot_module != nullptr);
69 TI_ASSERT(aot_field_impl != nullptr);
70
71 auto *runtime_executor = llvm_aot_module->get_runtime_executor();
72 const auto &field_cache = aot_field_impl->get_snode_tree_cache();
73
74 int snode_tree_id = field_cache.tree_id;
75 if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) {
76 runtime_executor->initialize_llvm_runtime_snodes(field_cache,
77 result_buffer);
78 llvm_aot_module->set_initialized_snode_tree(snode_tree_id);
79 }
80}
81
82} // namespace taichi::lang
83