// The LLVM backend for CPUs/NVPTX/AMDGPU
#pragma once

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#ifdef TI_WITH_LLVM

#include "taichi/ir/ir.h"
#include "taichi/runtime/llvm/launch_arg_info.h"
#include "taichi/codegen/llvm/llvm_codegen_utils.h"
#include "taichi/codegen/llvm/llvm_compiled_data.h"
#include "taichi/program/program.h"

namespace taichi::lang {

class TaskCodeGenLLVM;

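// RAII helper that redirects IR emission into a freshly created function and
// restores the previous insertion point on destruction.
//
// A minimal usage sketch (the argument type and the `llvm_context` name are
// illustrative; see TaskCodeGenLLVM::get_function_creation_guard() below):
//
//   {
//     auto guard = get_function_creation_guard(
//         {llvm::Type::getInt8PtrTy(*llvm_context)});
//     // IR emitted here goes into the body of the new function.
//   }  // The previous insertion point is restored here.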
class FunctionCreationGuard {
 public:
  TaskCodeGenLLVM *mb;
  llvm::Function *old_func;
  llvm::Function *body;
  llvm::BasicBlock *old_entry, *allocas, *entry, *old_final, *final;
  llvm::IRBuilder<>::InsertPoint ip;

  FunctionCreationGuard(TaskCodeGenLLVM *mb,
                        std::vector<llvm::Type *> arguments,
                        const std::string &func_name);

  ~FunctionCreationGuard();
};

class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
 public:
  const CompileConfig &compile_config;
  Kernel *kernel;
  IRNode *ir;
  Program *prog;
  std::string kernel_name;
  std::vector<llvm::Value *> kernel_args;
  llvm::Type *context_ty;
  llvm::Type *physical_coordinate_ty;
  llvm::Value *current_coordinates;
  llvm::Value *parent_coordinates{nullptr};
  llvm::Value *block_corner_coordinates{nullptr};
  llvm::GlobalVariable *bls_buffer{nullptr};
  // Mainly for supporting the `continue` statement
  llvm::BasicBlock *current_loop_reentry;
  // Mainly for supporting the `break` statement
  llvm::BasicBlock *current_while_after_loop;
  llvm::FunctionType *task_function_type;
  std::unordered_map<Stmt *, llvm::Value *> llvm_val;
  llvm::Function *func;
  OffloadedStmt *current_offload{nullptr};
  std::unique_ptr<OffloadedTask> current_task;
  std::vector<OffloadedTask> offloaded_tasks;
  llvm::BasicBlock *func_body_bb;
  llvm::BasicBlock *final_block;
  std::set<std::string> linked_modules;
  bool returned{false};
  std::unordered_set<int> used_tree_ids;
  std::unordered_set<int> struct_for_tls_sizes;
  Callable *current_callable{nullptr};

  std::unordered_map<const Stmt *, std::vector<llvm::Value *>> loop_vars_llvm;

  std::unordered_map<Function *, llvm::Function *> func_map;

  using IRVisitor::visit;
  using LLVMModuleBuilder::call;

  explicit TaskCodeGenLLVM(const CompileConfig &config,
                           TaichiLLVMContext &tlctx,
                           Kernel *kernel,
                           IRNode *ir = nullptr,
                           std::unique_ptr<llvm::Module> &&module = nullptr);

  Arch current_arch() const {
    return compile_config.arch;
  }

  void initialize_context();

  llvm::Value *get_arg(int i);

  llvm::Value *get_context();

  llvm::Value *get_tls_base_ptr();

  llvm::Type *get_tls_buffer_type();

  std::vector<llvm::Type *> get_xlogue_argument_types();

  std::vector<llvm::Type *> get_mesh_xlogue_argument_types();

  llvm::Type *get_xlogue_function_type();

  llvm::Type *get_mesh_xlogue_function_type();

  llvm::Value *get_root(int snode_tree_id);

  llvm::Value *get_runtime();

  void emit_struct_meta_base(const std::string &name,
                             llvm::Value *node_meta,
                             SNode *snode);

  void create_elementwise_binary(
      BinaryOpStmt *stmt,
      std::function<llvm::Value *(llvm::Value *lhs, llvm::Value *rhs)> f);

  void create_elementwise_cast(
      UnaryOpStmt *stmt,
      llvm::Type *to_ty,
      std::function<llvm::Value *(llvm::Value *, llvm::Type *)> f,
      bool on_self = false);

  std::unique_ptr<RuntimeObject> emit_struct_meta_object(SNode *snode);

  llvm::Value *emit_struct_meta(SNode *snode);

  virtual void emit_to_module();

  void eliminate_unused_functions();

  /**
   * @brief Runs the codegen and produces the compiled result.
   *
   * After this call, `module` and `tasks` have been moved out of this object
   * and must not be used again.
   *
   * @return LLVMCompiledTask
   */
  virtual LLVMCompiledTask run_compilation();
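  // A hedged usage sketch (`TaskCodeGenSomeArch` stands in for a concrete
  // backend subclass; this base class itself is abstract):
  //
  //   TaskCodeGenSomeArch codegen(config, tlctx, kernel);
  //   LLVMCompiledTask compiled = codegen.run_compilation();
  //   // `compiled` now owns the emitted module and the offloaded tasks.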

  // For debugging only
  virtual llvm::Value *create_print(std::string tag,
                                    DataType dt,
                                    llvm::Value *value);

  llvm::Value *create_print(std::string tag, llvm::Value *value);

  void create_return(const std::vector<Stmt *> &elements);

  llvm::Value *cast_pointer(llvm::Value *val,
                            std::string dest_ty_name,
                            int addr_space = 0);

  void emit_list_gen(OffloadedStmt *listgen);

  void emit_gc(OffloadedStmt *stmt);
  void emit_gc_rc();

  llvm::Value *call(SNode *snode,
                    llvm::Value *node_ptr,
                    const std::string &method,
                    const std::vector<llvm::Value *> &arguments);

  llvm::Function *get_struct_function(const std::string &name, int tree_id);

  template <typename... Args>
  llvm::Value *call_struct_func(int tree_id,
                                const std::string &func_name,
                                Args &&...args);
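  // A hedged usage sketch (the struct function name is hypothetical; real
  // names live in the struct module of the given SNode tree):
  //
  //   llvm::Value *elem =
  //       call_struct_func(tree_id, "SomeSNode_lookup_element", node_ptr, i);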

  void create_increment(llvm::Value *ptr, llvm::Value *value);

  // Direct (naive, serial) translation of a range-for
  void create_naive_range_for(RangeForStmt *for_stmt);

  static std::string get_runtime_snode_name(SNode *snode);

  void visit(Block *stmt_list) override;

  void visit(AllocaStmt *stmt) override;

  void visit(RandStmt *stmt) override;

  virtual void emit_extra_unary(UnaryOpStmt *stmt);

  void visit(DecorationStmt *stmt) override;

  void visit(UnaryOpStmt *stmt) override;

  void visit(BinaryOpStmt *stmt) override;

  void visit(TernaryOpStmt *stmt) override;

  void visit(IfStmt *if_stmt) override;

  void visit(PrintStmt *stmt) override;

  void visit(ConstStmt *stmt) override;

  void visit(WhileControlStmt *stmt) override;

  void visit(ContinueStmt *stmt) override;

  void visit(WhileStmt *stmt) override;

  void visit(RangeForStmt *for_stmt) override;

  void visit(ArgLoadStmt *stmt) override;

  void visit(ReturnStmt *stmt) override;

  void visit(LocalLoadStmt *stmt) override;

  void visit(LocalStoreStmt *stmt) override;

  void visit(AssertStmt *stmt) override;

  void visit(SNodeOpStmt *stmt) override;

  llvm::Value *atomic_add_quant_fixed(llvm::Value *ptr,
                                      llvm::Type *physical_type,
                                      QuantFixedType *qfxt,
                                      llvm::Value *value);

  llvm::Value *atomic_add_quant_int(llvm::Value *ptr,
                                    llvm::Type *physical_type,
                                    QuantIntType *qit,
                                    llvm::Value *value,
                                    bool value_is_signed);

  llvm::Value *to_quant_fixed(llvm::Value *real, QuantFixedType *qfxt);

  virtual llvm::Value *optimized_reduction(AtomicOpStmt *stmt);

  virtual llvm::Value *quant_type_atomic(AtomicOpStmt *stmt);

  virtual llvm::Value *integral_type_atomic(AtomicOpStmt *stmt);

  virtual llvm::Value *atomic_op_using_cas(
      llvm::Value *output_address,
      llvm::Value *val,
      std::function<llvm::Value *(llvm::Value *, llvm::Value *)> op);

  virtual llvm::Value *real_type_atomic(AtomicOpStmt *stmt);

  void visit(AtomicOpStmt *stmt) override;

  void visit(GlobalPtrStmt *stmt) override;

  void visit(MatrixPtrStmt *stmt) override;

  void store_quant_int(llvm::Value *ptr,
                       llvm::Type *physical_type,
                       QuantIntType *qit,
                       llvm::Value *value,
                       bool atomic);

  void store_quant_fixed(llvm::Value *ptr,
                         llvm::Type *physical_type,
                         QuantFixedType *qfxt,
                         llvm::Value *value,
                         bool atomic);

  void store_masked(llvm::Value *ptr,
                    llvm::Type *ty,
                    uint64 mask,
                    llvm::Value *value,
                    bool atomic);

  void visit(GlobalStoreStmt *stmt) override;

  llvm::Value *quant_int_or_quant_fixed_to_bits(llvm::Value *val,
                                                Type *input_type,
                                                llvm::Type *output_type);

  void visit(BitStructStoreStmt *stmt) override;

  void store_quant_floats_with_shared_exponents(BitStructStoreStmt *stmt);

  llvm::Value *extract_quant_float(llvm::Value *physical_value,
                                   BitStructType *bit_struct,
                                   int digits_id);

  llvm::Value *extract_quant_int(llvm::Value *physical_value,
                                 llvm::Value *bit_offset,
                                 QuantIntType *qit);

  llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
                                       QuantFixedType *qfxt);

  llvm::Value *reconstruct_quant_float(llvm::Value *input_digits,
                                       llvm::Value *input_exponent_val,
                                       QuantFloatType *qflt,
                                       bool shared_exponent);

  virtual llvm::Value *create_intrinsic_load(llvm::Value *ptr, llvm::Type *ty);

  void create_global_load(GlobalLoadStmt *stmt, bool should_cache_as_read_only);

  void visit(GlobalLoadStmt *stmt) override;

  void visit(GetRootStmt *stmt) override;

  void visit(LinearizeStmt *stmt) override;

  void visit(IntegerOffsetStmt *stmt) override;

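  // A "bit pointer" addresses a sub-word field inside a bit-packed physical
  // word; judging from the signatures below, it pairs a byte pointer with a
  // bit offset: create_bit_ptr() packs the pair, load_bit_ptr() unpacks it.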
  llvm::Value *create_bit_ptr(llvm::Value *byte_ptr, llvm::Value *bit_offset);

  std::tuple<llvm::Value *, llvm::Value *> load_bit_ptr(llvm::Value *bit_ptr);

  void visit(SNodeLookupStmt *stmt) override;

  void visit(GetChStmt *stmt) override;

  void visit(ExternalPtrStmt *stmt) override;

  void visit(ExternalTensorShapeAlongAxisStmt *stmt) override;

  virtual bool kernel_argument_by_val() const {
    return false;  // On CPU backends, kernel arguments are passed by pointer.
  }
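  // Backends that take kernel arguments by value can override the hook above;
  // a hedged sketch (`TaskCodeGenSomeGPU` is illustrative, not a class
  // declared in this codebase):
  //
  //   class TaskCodeGenSomeGPU : public TaskCodeGenLLVM {
  //     bool kernel_argument_by_val() const override {
  //       return true;  // Pass kernel arguments by value on the device.
  //     }
  //   };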

  std::string init_offloaded_task_function(OffloadedStmt *stmt,
                                           std::string suffix = "");

  void finalize_offloaded_task_function();

  FunctionCreationGuard get_function_creation_guard(
      std::vector<llvm::Type *> argument_types,
      const std::string &func_name = "function_body");

  std::tuple<llvm::Value *, llvm::Value *> get_range_for_bounds(
      OffloadedStmt *stmt);

  virtual void create_offload_range_for(OffloadedStmt *stmt) = 0;

  virtual void create_offload_mesh_for(OffloadedStmt *stmt) {
    TI_NOT_IMPLEMENTED;
  }

  void create_offload_struct_for(OffloadedStmt *stmt);

  void visit(LoopIndexStmt *stmt) override;

  void visit(LoopLinearIndexStmt *stmt) override;

  void visit(BlockCornerIndexStmt *stmt) override;

  void visit(GlobalTemporaryStmt *stmt) override;

  void visit(ThreadLocalPtrStmt *stmt) override;

  void visit(BlockLocalPtrStmt *stmt) override;

  void visit(ClearListStmt *stmt) override;

  void visit(InternalFuncStmt *stmt) override;

  // Autodiff-stack statements

  void visit(AdStackAllocaStmt *stmt) override;

  void visit(AdStackPopStmt *stmt) override;

  void visit(AdStackPushStmt *stmt) override;

  void visit(AdStackLoadTopStmt *stmt) override;

  void visit(AdStackLoadTopAdjStmt *stmt) override;

  void visit(AdStackAccAdjointStmt *stmt) override;

  void visit(RangeAssumptionStmt *stmt) override;

  void visit(LoopUniqueStmt *stmt) override;

  void visit_call_bitcode(ExternalFuncCallStmt *stmt);

  void visit_call_shared_object(ExternalFuncCallStmt *stmt);

  void visit(ExternalFuncCallStmt *stmt) override;

  void visit(MeshPatchIndexStmt *stmt) override;

  void visit(ReferenceStmt *stmt) override;

  void visit(MatrixInitStmt *stmt) override;

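  // "Xlogue" here denotes the prologue/epilogue helper functions that wrap an
  // offloaded task body (e.g. the TLS prologue/epilogue of a struct-for).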
  llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

  llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);

  llvm::Value *extract_exponent_from_f32(llvm::Value *f);

  llvm::Value *extract_digits_from_f32(llvm::Value *f, bool full);

  llvm::Value *extract_digits_from_f32_with_shared_exponent(
      llvm::Value *f,
      llvm::Value *shared_exp);

  llvm::Value *get_exponent_offset(llvm::Value *exponent, QuantFloatType *qflt);

  void visit(FuncCallStmt *stmt) override;

  void visit(GetElementStmt *stmt) override;

  llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type);
  llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

  ~TaskCodeGenLLVM() override = default;

 private:
  void create_return(llvm::Value *buffer,
                     llvm::Type *buffer_type,
                     const std::vector<Stmt *> &elements,
                     const Type *current_type,
                     int &current_element,
                     std::vector<llvm::Value *> &current_index);

  virtual std::tuple<llvm::Value *, llvm::Value *> get_spmd_info() = 0;
};

}  // namespace taichi::lang

#endif  // #ifdef TI_WITH_LLVM