1 | #include "triton/codegen/extern_lib.h" |
2 | |
3 | #include "llvm/IR/Constants.h" |
4 | #include "llvm/IR/LegacyPassManager.h" |
5 | #include "llvm/IR/Metadata.h" |
6 | #include "llvm/IR/Type.h" |
7 | #include "llvm/Linker/Linker.h" |
8 | #include "llvm/Transforms/IPO/PassManagerBuilder.h" |
9 | #include "triton/codegen/pass.h" |
10 | |
11 | namespace triton { |
12 | |
13 | namespace codegen { |
14 | |
15 | std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) { |
16 | llvm::SMDiagnostic err; |
17 | auto mod = llvm::parseIRFile(this->path_, err, ctx); |
18 | if (!mod) { |
19 | throw std::runtime_error("Failed to load extern lib " + this->name_ + |
20 | " at " + this->path_); |
21 | } |
22 | return mod; |
23 | } |
24 | |
25 | void ExternLib::link(std::unique_ptr<llvm::Module>& llvm, |
26 | std::unique_ptr<llvm::Module>& mod) { |
27 | // Set triple and data layout to match the target module |
28 | mod->setTargetTriple(llvm->getTargetTriple()); |
29 | mod->setDataLayout(llvm->getDataLayout()); |
30 | if (llvm::Linker::linkModules(*llvm, std::move(mod))) { |
31 | throw std::runtime_error("Failed to link extern lib " + this->name_ + |
32 | " at " + this->path_); |
33 | } |
34 | } |
35 | |
36 | void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) { |
37 | // Add nvvm reflect flags to llvm module |
38 | // https://llvm.org/docs/LangRef.html#module-flags-metadata |
39 | // i32 4: Override the other module. |
40 | // i32 1: Emit an error |
41 | // If both modules specify Override, but the values differ, an error |
42 | // will be emitted. |
43 | llvm::Type* I32 = llvm::Type::getInt32Ty(ctx); |
44 | llvm::Metadata* md_four = |
45 | llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); |
46 | llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz" ); |
47 | llvm::Metadata* md_one = |
48 | llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); |
49 | llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one}); |
50 | llvm->addModuleFlag(reflect); |
51 | } |
52 | |
53 | std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name, |
54 | const std::string& lib_path) { |
55 | if (lib_name == "libdevice" ) { |
56 | return std::make_unique<LibDevice>(lib_name, lib_path); |
57 | } else { |
58 | throw std::runtime_error("Unknown external library: " + lib_name); |
59 | } |
60 | } |
61 | |
62 | } // namespace codegen |
63 | } // namespace triton |
64 | |