1 | #pragma once |
2 | |
#include <string>
#include <utility>
#include <vector>

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Function.h"
#include "llvm/Pass.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/Cloning.h"
15 | |
16 | #if defined(TI_WITH_AMDGPU) |
17 | #include "taichi/rhi/amdgpu/amdgpu_context.h" |
18 | #endif |
19 | |
20 | namespace taichi { |
21 | namespace lang { |
22 | using namespace llvm; |
23 | |
24 | struct AddStructForFuncPass : public ModulePass { |
25 | static inline char ID{0}; |
26 | std::string func_name_; |
27 | int tls_size_; |
28 | AddStructForFuncPass(std::string func_name, int tls_size) : ModulePass(ID) { |
29 | func_name_ = func_name; |
30 | tls_size_ = tls_size; |
31 | } |
32 | bool runOnModule(llvm::Module &M) override { |
33 | auto struct_for_func = M.getFunction("parallel_struct_for" ); |
34 | auto &llvm_context = M.getContext(); |
35 | auto value_map = llvm::ValueToValueMapTy(); |
36 | auto patched_struct_for_func = |
37 | llvm::CloneFunction(struct_for_func, value_map); |
38 | patched_struct_for_func->setName(func_name_); |
39 | |
40 | int num_found_alloca = 0; |
41 | llvm::AllocaInst *alloca = nullptr; |
42 | |
43 | auto char_type = llvm::Type::getInt8Ty(llvm_context); |
44 | |
45 | // Find the "1" in "char tls_buffer[1]" and replace it with |
46 | // "tls_buffer_size" |
47 | for (auto &bb : *patched_struct_for_func) { |
48 | for (llvm::Instruction &inst : bb) { |
49 | auto now_alloca = llvm::dyn_cast<AllocaInst>(&inst); |
50 | if (!now_alloca || now_alloca->getAlign().value() != 8) |
51 | continue; |
52 | auto alloca_type = now_alloca->getAllocatedType(); |
53 | // Allocated type should be array [1 x i8] |
54 | if (alloca_type->isArrayTy() && |
55 | alloca_type->getArrayNumElements() == 1 && |
56 | alloca_type->getArrayElementType() == char_type) { |
57 | alloca = now_alloca; |
58 | num_found_alloca++; |
59 | } |
60 | } |
61 | } |
62 | TI_ASSERT(num_found_alloca == 1 && alloca); |
63 | auto new_type = llvm::ArrayType::get(char_type, tls_size_); |
64 | llvm::IRBuilder<> builder(alloca); |
65 | auto *new_alloca = builder.CreateAlloca(new_type); |
66 | new_alloca->setAlignment(Align(8)); |
67 | TI_ASSERT(alloca->hasOneUse()); |
68 | auto *gep = llvm::cast<llvm::GetElementPtrInst>(alloca->user_back()); |
69 | TI_ASSERT(gep->getPointerOperand() == alloca); |
70 | std::vector<Value *> indices(gep->idx_begin(), gep->idx_end()); |
71 | builder.SetInsertPoint(gep); |
72 | auto *new_gep = builder.CreateInBoundsGEP(new_type, new_alloca, indices); |
73 | gep->replaceAllUsesWith(new_gep); |
74 | gep->eraseFromParent(); |
75 | alloca->eraseFromParent(); |
76 | return false; |
77 | } |
78 | }; |
79 | |
80 | #if defined(TI_WITH_AMDGPU) |
81 | struct AMDGPUConvertAllocaInstAddressSpacePass : public FunctionPass { |
82 | static inline char ID{0}; |
83 | AMDGPUConvertAllocaInstAddressSpacePass() : FunctionPass(ID) { |
84 | } |
85 | bool runOnFunction(llvm::Function &f) override { |
86 | f.addFnAttr("target-cpu" , |
87 | "gfx" + AMDGPUContext::get_instance().get_mcpu().substr(3, 4)); |
88 | f.addFnAttr("target-features" , "" ); |
89 | for (auto &bb : f) { |
90 | std::vector<AllocaInst *> alloca_inst_vec; |
91 | for (Instruction &inst : bb) { |
92 | AllocaInst *now_alloca = dyn_cast<AllocaInst>(&inst); |
93 | if (!now_alloca || |
94 | now_alloca->getType()->getAddressSpace() != (unsigned)0) { |
95 | continue; |
96 | } |
97 | alloca_inst_vec.push_back(now_alloca); |
98 | } |
99 | for (auto &allocainst : alloca_inst_vec) { |
100 | auto alloca_type = allocainst->getAllocatedType(); |
101 | llvm::IRBuilder<> builder(allocainst); |
102 | auto *new_alloca = builder.CreateAlloca(alloca_type, (unsigned)5); |
103 | auto new_type = llvm::PointerType::get(alloca_type, (unsigned)0); |
104 | new_alloca->setAlignment(Align(allocainst->getAlign().value())); |
105 | auto *addrspacecast = builder.CreateAddrSpaceCast(new_alloca, new_type); |
106 | allocainst->replaceAllUsesWith(addrspacecast); |
107 | allocainst->eraseFromParent(); |
108 | } |
109 | } |
110 | return false; |
111 | } |
112 | }; |
113 | |
114 | struct AMDGPUAddStructForFuncPass : public ModulePass { |
115 | static inline char ID{0}; |
116 | std::string func_name_; |
117 | int tls_size_; |
118 | AMDGPUAddStructForFuncPass(std::string func_name, int tls_size) |
119 | : ModulePass(ID) { |
120 | func_name_ = func_name; |
121 | tls_size_ = tls_size; |
122 | } |
123 | bool runOnModule(llvm::Module &M) override { |
124 | auto struct_for_func = M.getFunction("parallel_struct_for" ); |
125 | auto &llvm_context = M.getContext(); |
126 | auto value_map = llvm::ValueToValueMapTy(); |
127 | auto patched_struct_for_func = |
128 | llvm::CloneFunction(struct_for_func, value_map); |
129 | patched_struct_for_func->setName(func_name_); |
130 | |
131 | int num_found_alloca = 0; |
132 | llvm::AllocaInst *alloca = nullptr; |
133 | |
134 | auto char_type = llvm::Type::getInt8Ty(llvm_context); |
135 | |
136 | // Find the "1" in "char tls_buffer[1]" and replace it with |
137 | // "tls_buffer_size" |
138 | for (auto &bb : *patched_struct_for_func) { |
139 | for (llvm::Instruction &inst : bb) { |
140 | auto now_alloca = llvm::dyn_cast<AllocaInst>(&inst); |
141 | if (!now_alloca || now_alloca->getAlign().value() != 8) |
142 | continue; |
143 | auto alloca_type = now_alloca->getAllocatedType(); |
144 | // Allocated type should be array [1 x i8] |
145 | if (alloca_type->isArrayTy() && |
146 | alloca_type->getArrayNumElements() == 1 && |
147 | alloca_type->getArrayElementType() == char_type) { |
148 | alloca = now_alloca; |
149 | num_found_alloca++; |
150 | } |
151 | } |
152 | } |
153 | TI_ASSERT(num_found_alloca == 1 && alloca); |
154 | auto new_type = llvm::ArrayType::get(char_type, tls_size_); |
155 | llvm::IRBuilder<> builder(alloca); |
156 | auto *new_alloca = builder.CreateAlloca(new_type, (unsigned)5); |
157 | new_alloca->setAlignment(Align(8)); |
158 | auto new_ty = llvm::PointerType::get(new_type, unsigned(0)); |
159 | auto *new_cast = builder.CreateAddrSpaceCast(new_alloca, new_ty); |
160 | new_alloca->setAlignment(Align(8)); |
161 | TI_ASSERT(alloca->hasOneUse()); |
162 | auto *cast = llvm::cast<llvm::AddrSpaceCastInst>(alloca->user_back()); |
163 | TI_ASSERT(cast->hasOneUse()); |
164 | auto *gep = llvm::cast<llvm::GetElementPtrInst>(cast->user_back()); |
165 | TI_ASSERT(gep->getPointerOperand() == cast); |
166 | std::vector<Value *> indices(gep->idx_begin(), gep->idx_end()); |
167 | builder.SetInsertPoint(gep); |
168 | auto *new_gep = builder.CreateInBoundsGEP(new_type, new_cast, indices); |
169 | gep->replaceAllUsesWith(new_gep); |
170 | gep->eraseFromParent(); |
171 | cast->eraseFromParent(); |
172 | alloca->eraseFromParent(); |
173 | return false; |
174 | } |
175 | }; |
176 | |
177 | struct AMDGPUConvertFuncParamAddressSpacePass : public ModulePass { |
178 | static inline char ID{0}; |
179 | AMDGPUConvertFuncParamAddressSpacePass() : ModulePass(ID) { |
180 | } |
181 | bool runOnModule(llvm::Module &M) override { |
182 | for (auto &f : M) { |
183 | bool is_kernel = false; |
184 | const std::string func_name = f.getName().str(); |
185 | if (starts_with(func_name, "runtime_" )) { |
186 | f.setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); |
187 | // ref https://llvm.org/docs/AMDGPUUsage.html |
188 | // “amdgpu-flat-work-group-size”=”min,max” |
189 | // Specify the minimum and maximum flat work group sizes that will be |
190 | // specified when the kernel is dispatched. Generated by the |
191 | // amdgpu_flat_work_group_size CLANG attribute [CLANG-ATTR]. The implied |
192 | // default value is 1,1024. |
193 | f.addFnAttr("amdgpu-flat-work-group-size" , "1, 1024" ); |
194 | is_kernel = true; |
195 | } |
196 | if (!is_kernel && !f.isDeclaration()) |
197 | f.setLinkage(llvm::Function::PrivateLinkage); |
198 | } |
199 | std::vector<llvm::Function *> kernel_function; |
200 | for (auto &f : M) { |
201 | if (f.getCallingConv() == llvm::CallingConv::AMDGPU_KERNEL) |
202 | kernel_function.push_back(&f); |
203 | } |
204 | for (auto &f : kernel_function) { |
205 | llvm::FunctionType *func_type = f->getFunctionType(); |
206 | std::vector<llvm::Type *> new_func_params; |
207 | for (auto &arg : f->args()) { |
208 | if (arg.getType()->getTypeID() == llvm::Type::PointerTyID) { |
209 | auto new_type = llvm::PointerType::get( |
210 | arg.getType()->getPointerElementType(), unsigned(1)); |
211 | new_func_params.push_back(new_type); |
212 | } else { |
213 | new_func_params.push_back(arg.getType()); |
214 | } |
215 | } |
216 | auto new_func_type = llvm::FunctionType::get(func_type->getReturnType(), |
217 | new_func_params, false); |
218 | auto new_func = llvm::Function::Create(new_func_type, f->getLinkage(), |
219 | f->getAddressSpace()); |
220 | new_func->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); |
221 | new_func->addFnAttr("amdgpu-flat-work-group-size" , "1, 1024" ); |
222 | new_func->addFnAttr( |
223 | "target-cpu" , |
224 | "gfx" + AMDGPUContext::get_instance().get_mcpu().substr(3, 4)); |
225 | new_func->setComdat(f->getComdat()); |
226 | f->getParent()->getFunctionList().insert(f->getIterator(), new_func); |
227 | new_func->takeName(f); |
228 | new_func->getBasicBlockList().splice(new_func->begin(), |
229 | f->getBasicBlockList()); |
230 | for (llvm::Function::arg_iterator I = f->arg_begin(), E = f->arg_end(), |
231 | I2 = new_func->arg_begin(); |
232 | I != E; ++I, ++I2) { |
233 | if (I->getType()->getTypeID() == llvm::Type::PointerTyID) { |
234 | auto &front_bb = new_func->getBasicBlockList().front(); |
235 | llvm::Instruction *addrspacecast = |
236 | new AddrSpaceCastInst(I2, I->getType()); |
237 | front_bb.getInstList().insertAfter(front_bb.getFirstInsertionPt(), |
238 | addrspacecast); |
239 | I->replaceAllUsesWith(addrspacecast); |
240 | I2->takeName(&*I); |
241 | } else { |
242 | I->replaceAllUsesWith(&*I2); |
243 | I2->takeName(&*I); |
244 | } |
245 | } |
246 | |
247 | f->eraseFromParent(); |
248 | } |
249 | return false; |
250 | } |
251 | }; |
252 | |
253 | #endif |
254 | |
255 | } // namespace lang |
256 | } // namespace taichi |
257 | |