1#include "triton/codegen/extern_lib.h"
2
3#include "llvm/IR/Constants.h"
4#include "llvm/IR/LegacyPassManager.h"
5#include "llvm/IR/Metadata.h"
6#include "llvm/IR/Type.h"
7#include "llvm/Linker/Linker.h"
8#include "llvm/Transforms/IPO/PassManagerBuilder.h"
9#include "triton/codegen/pass.h"
10
11namespace triton {
12
13namespace codegen {
14
15std::unique_ptr<llvm::Module> ExternLib::load(llvm::LLVMContext& ctx) {
16 llvm::SMDiagnostic err;
17 auto mod = llvm::parseIRFile(this->path_, err, ctx);
18 if (!mod) {
19 throw std::runtime_error("Failed to load extern lib " + this->name_ +
20 " at " + this->path_);
21 }
22 return mod;
23}
24
25void ExternLib::link(std::unique_ptr<llvm::Module>& llvm,
26 std::unique_ptr<llvm::Module>& mod) {
27 // Set triple and data layout to match the target module
28 mod->setTargetTriple(llvm->getTargetTriple());
29 mod->setDataLayout(llvm->getDataLayout());
30 if (llvm::Linker::linkModules(*llvm, std::move(mod))) {
31 throw std::runtime_error("Failed to link extern lib " + this->name_ +
32 " at " + this->path_);
33 }
34}
35
36void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr<llvm::Module>& llvm) {
37 // Add nvvm reflect flags to llvm module
38 // https://llvm.org/docs/LangRef.html#module-flags-metadata
39 // i32 4: Override the other module.
40 // i32 1: Emit an error
41 // If both modules specify Override, but the values differ, an error
42 // will be emitted.
43 llvm::Type* I32 = llvm::Type::getInt32Ty(ctx);
44 llvm::Metadata* md_four =
45 llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4));
46 llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz");
47 llvm::Metadata* md_one =
48 llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1));
49 llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one});
50 llvm->addModuleFlag(reflect);
51}
52
53std::unique_ptr<ExternLib> create_extern_lib(const std::string& lib_name,
54 const std::string& lib_path) {
55 if (lib_name == "libdevice") {
56 return std::make_unique<LibDevice>(lib_name, lib_path);
57 } else {
58 throw std::runtime_error("Unknown external library: " + lib_name);
59 }
60}
61
62} // namespace codegen
63} // namespace triton
64