#pragma once

// LLVM backend compiler (x64, arm64, CUDA, AMDGPU, etc.).
// In charge of creating & JITing arch-specific LLVM modules,
// and invoking compiled functions (kernels).
// Designed to be multithreaded for parallel compilation.

#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "taichi/util/lang_util.h"
#include "taichi/runtime/llvm/llvm_fwd.h"
#include "taichi/ir/snode.h"
#include "taichi/jit/jit_session.h"
#include "taichi/codegen/llvm/llvm_compiled_data.h"

18namespace taichi::lang {
19
20class JITSessionCPU;
21class LlvmProgramImpl;
22
23/**
24 * Manages an LLVMContext for Taichi's usage.
25 */
26class TaichiLLVMContext {
27 private:
28 struct ThreadLocalData {
29 std::unique_ptr<llvm::orc::ThreadSafeContext> thread_safe_llvm_context{
30 nullptr};
31 llvm::LLVMContext *llvm_context{nullptr};
32 std::unique_ptr<llvm::Module> runtime_module{nullptr};
33 std::unordered_map<int, std::unique_ptr<llvm::Module>> struct_modules;
34 explicit ThreadLocalData(std::unique_ptr<llvm::orc::ThreadSafeContext> ctx);
35 ~ThreadLocalData();
36 };
37 const CompileConfig &config_;
38
39 public:
40 std::unique_ptr<JITSession> jit{nullptr};
41 // main_thread is defined to be the thread that runs the initializer
42
43 std::unique_ptr<ThreadLocalData> linking_context_data{nullptr};
44
45 TaichiLLVMContext(const CompileConfig &config, Arch arch);
46
47 virtual ~TaichiLLVMContext();
48
49 llvm::LLVMContext *get_this_thread_context();
50
51 llvm::orc::ThreadSafeContext *get_this_thread_thread_safe_context();
52
53 /**
54 * Updates the LLVM module of the JIT compiled SNode structs.
55 *
56 * @param module Module containing the JIT compiled SNode structs.
57 */
58 void add_struct_module(std::unique_ptr<llvm::Module> module, int tree_id);
59
60 void init_runtime_module(llvm::Module *runtime_module);
61
62 /**
63 * Clones the LLVM module compiled from llvm/runtime.cpp
64 *
65 * @return The cloned module.
66 */
67 std::unique_ptr<llvm::Module> clone_runtime_module();
68
69 std::unique_ptr<llvm::Module> module_from_file(const std::string &file);
70
71 llvm::Type *get_data_type(DataType dt);
72
73 template <typename T>
74 llvm::Type *get_data_type() {
75 return TaichiLLVMContext::get_data_type(taichi::lang::get_data_type<T>());
76 }
77
78 std::size_t get_type_size(llvm::Type *type);
79
80 std::size_t get_struct_element_offset(llvm::StructType *type, int idx);
81
82 const StructType *get_struct_type_with_data_layout(const StructType *old_ty,
83 const std::string &layout);
84
85 template <typename T>
86 llvm::Value *get_constant(T t);
87
88 template <typename T>
89 llvm::Value *get_constant(DataType dt, T t);
90
91 llvm::DataLayout get_data_layout();
92
93 std::string get_data_layout_string();
94
95 std::string type_name(llvm::Type *type);
96
97 static void mark_inline(llvm::Function *func);
98
99 static void print_huge_functions(llvm::Module *module);
100
101 // remove all functions that are not (directly & indirectly) used by those
102 // with export_indicator(func_name) = true
103 static void eliminate_unused_functions(
104 llvm::Module *module,
105 std::function<bool(const std::string &)> export_indicator);
106
107 void mark_function_as_cuda_kernel(llvm::Function *func, int block_dim = 0);
108
109 void mark_function_as_amdgpu_kernel(llvm::Function *func);
110
111 void fetch_this_thread_struct_module();
112 llvm::Module *get_this_thread_runtime_module();
113 llvm::Function *get_runtime_function(const std::string &name);
114 llvm::Function *get_struct_function(const std::string &name, int tree_id);
115 llvm::Type *get_runtime_type(const std::string &name);
116
117 std::unique_ptr<llvm::Module> new_module(
118 std::string name,
119 llvm::LLVMContext *context = nullptr);
120
121 void delete_snode_tree(int id);
122
123 void add_struct_for_func(llvm::Module *module, int tls_size);
124
125 static std::string get_struct_for_func_name(int tls_size);
126
127 LLVMCompiledKernel link_compiled_tasks(
128 std::vector<std::unique_ptr<LLVMCompiledTask>> data_list);
129
130 private:
131 std::unique_ptr<llvm::Module> clone_module_to_context(
132 llvm::Module *module,
133 llvm::LLVMContext *target_context);
134
135 void link_module_with_cuda_libdevice(std::unique_ptr<llvm::Module> &module);
136
137 void link_module_with_amdgpu_libdevice(std::unique_ptr<llvm::Module> &module);
138
139 static int num_instructions(llvm::Function *func);
140
141 void insert_nvvm_annotation(llvm::Function *func, std::string key, int val);
142
143 std::unique_ptr<llvm::Module> clone_module_to_this_thread_context(
144 llvm::Module *module);
145
146 ThreadLocalData *get_this_thread_data();
147
148 std::unordered_map<std::thread::id, std::unique_ptr<ThreadLocalData>>
149 per_thread_data_;
150
151 Arch arch_;
152
153 std::thread::id main_thread_id_;
154 ThreadLocalData *main_thread_data_{nullptr};
155 std::mutex mut_;
156 std::mutex thread_map_mut_;
157
158 std::unordered_map<int, std::vector<std::string>> snode_tree_funcs_;
159};
160
161class LlvmModuleBitcodeLoader {
162 public:
163 LlvmModuleBitcodeLoader &set_bitcode_path(const std::string &bitcode_path) {
164 bitcode_path_ = bitcode_path;
165 return *this;
166 }
167
168 LlvmModuleBitcodeLoader &set_buffer_id(const std::string &buffer_id) {
169 buffer_id_ = buffer_id;
170 return *this;
171 }
172
173 LlvmModuleBitcodeLoader &set_inline_funcs(bool inline_funcs) {
174 inline_funcs_ = inline_funcs;
175 return *this;
176 }
177
178 std::unique_ptr<llvm::Module> load(llvm::LLVMContext *ctx) const;
179
180 private:
181 std::string bitcode_path_;
182 std::string buffer_id_;
183 bool inline_funcs_{false};
184};
185
186std::unique_ptr<llvm::Module> module_from_bitcode_file(
187 const std::string &bitcode_path,
188 llvm::LLVMContext *ctx);
189
190} // namespace taichi::lang
191