1#pragma once
2
3#include "llvm/ADT/APFloat.h"
4#include "llvm/ADT/STLExtras.h"
5#include "llvm/IR/BasicBlock.h"
6#include "llvm/IR/Constants.h"
7#include "llvm/IR/DerivedTypes.h"
8#include "llvm/IR/Function.h"
9#include "llvm/IR/IRBuilder.h"
10#include "llvm/IR/Instructions.h"
11#include "llvm/IR/Intrinsics.h"
12#include "llvm/IR/IntrinsicsNVPTX.h"
13
14#if defined(TI_WITH_AMDGPU)
15#include "llvm/IR/IntrinsicsAMDGPU.h"
16#endif
17
18#include "llvm/IR/LLVMContext.h"
19#include "llvm/IR/LegacyPassManager.h"
20#include "llvm/IR/Module.h"
21#include "llvm/IR/Type.h"
22#include "llvm/IR/Verifier.h"
23#include "llvm/Support/TargetSelect.h"
24#include "llvm/Target/TargetMachine.h"
25#include "llvm/Transforms/InstCombine/InstCombine.h"
26#include "llvm/Transforms/Scalar.h"
27#include "llvm/Transforms/Scalar/GVN.h"
28#include "llvm/Transforms/Utils.h"
29#include "llvm/Transforms/Utils/Cloning.h"
30#include <algorithm>
31#include <cassert>
32#include <cctype>
33#include <cstdint>
34#include <cstdio>
35#include <cstdlib>
36#include <map>
37#include <memory>
38#include <string>
39#include <utility>
40#include <vector>
41
42#include "taichi/runtime/llvm/llvm_context.h"
43
44namespace taichi::lang {
45
// NOTE(review): not referenced anywhere in this header; presumably the name of
// the LLVMRuntime type/struct "PhysicalCoordinates" used by codegen — confirm
// at the use sites.
inline constexpr char kLLVMPhysicalCoordinatesName[] = "PhysicalCoordinates";

// Returns a string describing `type`. Defined out of line; presumably meant
// for diagnostics/error messages — confirm in the implementation.
std::string type_name(llvm::Type *type);

// Compares two LLVM types. Defined out of line. NOTE(review): whether "same"
// means pointer identity or a looser structural comparison is not visible
// here — confirm in the implementation.
bool is_same_type(llvm::Type *a, llvm::Type *b);

// Checks `arglist` against `func_type` for a call to `func_name`; the `call`
// helpers in this header invoke it immediately before IRBuilder::CreateCall.
// `arglist` is taken by non-const reference, so the implementation may rewrite
// arguments in place (presumably emitting fix-up casts via `builder`) — confirm
// in the implementation.
void check_func_call_signature(llvm::FunctionType *func_type,
                               llvm::StringRef func_name,
                               std::vector<llvm::Value *> &arglist,
                               llvm::IRBuilder<> *builder);
56
57class LLVMModuleBuilder {
58 public:
59 std::unique_ptr<llvm::Module> module{nullptr};
60 llvm::BasicBlock *entry_block{nullptr};
61 std::unique_ptr<llvm::IRBuilder<>> builder{nullptr};
62 TaichiLLVMContext *tlctx{nullptr};
63 llvm::LLVMContext *llvm_context{nullptr};
64
65 LLVMModuleBuilder(std::unique_ptr<llvm::Module> &&module,
66 TaichiLLVMContext *tlctx)
67 : module(std::move(module)), tlctx(tlctx) {
68 TI_ASSERT(this->module != nullptr);
69 TI_ASSERT(&this->module->getContext() == tlctx->get_this_thread_context());
70 }
71
72 llvm::Value *create_entry_block_alloca(llvm::Type *type,
73 std::size_t alignment = 0,
74 llvm::Value *array_size = nullptr) {
75 llvm::IRBuilderBase::InsertPointGuard guard(*builder);
76 builder->SetInsertPoint(entry_block);
77 auto alloca = builder->CreateAlloca(type, (unsigned)0, array_size);
78 if (alignment != 0) {
79 alloca->setAlignment(llvm::Align(alignment));
80 }
81 return alloca;
82 }
83
84 llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) {
85 auto type = tlctx->get_data_type(dt);
86 if (is_pointer)
87 type = llvm::PointerType::get(type, 0);
88 return create_entry_block_alloca(type);
89 }
90
91 llvm::Type *get_runtime_type(const std::string &name) {
92 return tlctx->get_runtime_type(name);
93 }
94
95 llvm::Function *get_runtime_function(const std::string &name) {
96 auto f = tlctx->get_runtime_function(name);
97 if (!f) {
98 TI_ERROR("LLVMRuntime function {} not found.", name);
99 }
100 f = llvm::cast<llvm::Function>(
101 module
102 ->getOrInsertFunction(name, f->getFunctionType(),
103 f->getAttributes())
104 .getCallee());
105 return f;
106 }
107
108 llvm::Value *call(llvm::IRBuilder<> *builder,
109 llvm::Value *func,
110 llvm::FunctionType *func_ty,
111 std::vector<llvm::Value *> args) {
112 check_func_call_signature(func_ty, func->getName(), args, builder);
113 return builder->CreateCall(func_ty, func, std::move(args));
114 }
115
116 llvm::Value *call(llvm::Value *func,
117 llvm::FunctionType *func_ty,
118 std::vector<llvm::Value *> args) {
119 return call(builder.get(), func, func_ty, std::move(args));
120 }
121
122 llvm::Value *call(llvm::IRBuilder<> *builder,
123 llvm::Function *func,
124 std::vector<llvm::Value *> args) {
125 return call(builder, func, func->getFunctionType(), std::move(args));
126 }
127
128 llvm::Value *call(llvm::Function *func, std::vector<llvm::Value *> args) {
129 return call(builder.get(), func, std::move(args));
130 }
131
132 llvm::Value *call(llvm::IRBuilder<> *builder,
133 const std::string &func_name,
134 std::vector<llvm::Value *> args) {
135 auto func = get_runtime_function(func_name);
136 return call(builder, func, std::move(args));
137 }
138
139 llvm::Value *call(const std::string &func_name,
140 std::vector<llvm::Value *> args) {
141 return call(builder.get(), func_name, std::move(args));
142 }
143
144 template <typename... Args>
145 llvm::Value *call(llvm::IRBuilder<> *builder,
146 llvm::Function *func,
147 Args *...args) {
148 return call(builder, func, {args...});
149 }
150
151 template <typename... Args>
152 llvm::Value *call(llvm::Function *func, Args &&...args) {
153 return call(builder.get(), func, std::forward<Args>(args)...);
154 }
155
156 template <typename... Args>
157 llvm::Value *call(llvm::IRBuilder<> *builder,
158 const std::string &func_name,
159 Args *...args) {
160 return call(builder, func_name, {args...});
161 }
162
163 template <typename... Args>
164 llvm::Value *call(const std::string &func_name, Args &&...args) {
165 return call(builder.get(), func_name, std::forward<Args>(args)...);
166 }
167};
168
169class RuntimeObject {
170 public:
171 std::string cls_name;
172 llvm::Value *ptr{nullptr};
173 LLVMModuleBuilder *mb{nullptr};
174 llvm::Type *type{nullptr};
175 llvm::IRBuilder<> *builder{nullptr};
176
177 RuntimeObject(const std::string &cls_name,
178 LLVMModuleBuilder *mb,
179 llvm::IRBuilder<> *builder,
180 llvm::Value *init = nullptr)
181 : cls_name(cls_name), mb(mb), builder(builder) {
182 type = mb->get_runtime_type(cls_name);
183 if (init == nullptr) {
184 ptr = mb->create_entry_block_alloca(type);
185 } else {
186 ptr = builder->CreateBitCast(init, llvm::PointerType::get(type, 0));
187 }
188 }
189
190 llvm::Value *get(const std::string &field) {
191 return call(fmt::format("get_{}", field));
192 }
193
194 llvm::Value *get(const std::string &field, llvm::Value *index) {
195 return call(fmt::format("get_{}", field), index);
196 }
197
198 llvm::Value *get_ptr(const std::string &field) {
199 return call(fmt::format("get_ptr_{}", field));
200 }
201
202 void set(const std::string &field, llvm::Value *val) {
203 call(fmt::format("set_{}", field), val);
204 }
205
206 void set(const std::string &field, llvm::Value *index, llvm::Value *val) {
207 call(fmt::format("set_{}", field), index, val);
208 }
209
210 template <typename... Args>
211 llvm::Value *call(const std::string &func_name, Args &&...args) {
212 auto func = get_func(func_name);
213 auto arglist = std::vector<llvm::Value *>({ptr, args...});
214 check_func_call_signature(func->getFunctionType(), func->getName(), arglist,
215 builder);
216 return builder->CreateCall(func, std::move(arglist));
217 }
218
219 llvm::Function *get_func(const std::string &func_name) const {
220 return mb->get_runtime_function(fmt::format("{}_{}", cls_name, func_name));
221 }
222};
223
224} // namespace taichi::lang
225