1/* Copyright 2015-2017 Philippe Tillet
2 *
3 * Permission is hereby granted, free of charge, to any person obtaining
4 * a copy of this software and associated documentation files
5 * (the "Software"), to deal in the Software without restriction,
6 * including without limitation the rights to use, copy, modify, merge,
7 * publish, distribute, sublicense, and/or sell copies of the Software,
8 * and to permit persons to whom the Software is furnished to do so,
9 * subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be
12 * included in all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
17 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
18 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
19 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
20 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 */
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <memory>
#include <regex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#if __has_include(<unistd.h>)
#include <unistd.h>
#endif
#include "triton/driver/llvm.h"
#include "triton/driver/dispatch.h"
#include "triton/driver/error.h"
#include "triton/tools/sha1.hpp"
#include "triton/tools/sys/getenv.hpp"
#include "triton/tools/sys/mkdir.hpp"
#include "triton/tools/sys/exec.hpp"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Scalar.h"
52
53// begin AMD stuff
54#include "llvm/Support/FileSystem.h"
55#include "llvm/Support/FormattedStream.h"
56#include "llvm/Support/Program.h"
57#include "llvm/Support/ToolOutputFile.h"
58#include "llvm/ADT/StringRef.h"
59#include "llvm/Analysis/TargetLibraryInfo.h"
60// end AMD stuff
61
// No-op definitions of four terminfo entry points: each ignores its arguments
// and returns 0. NOTE(review): presumably provided so that LLVM's terminal
// support links without a curses/terminfo library — confirm with the build.
// C linkage keeps the symbol names unmangled.
extern "C" int set_curterm(char * /*nterm*/) { return 0; }
extern "C" int del_curterm(char * /*nterm*/) { return 0; }
extern "C" int tigetnum(char * /*capname*/) { return 0; }
extern "C" int setupterm(char * /*term*/, int /*fildes*/, int * /*errret*/) { return 0; }
69
70namespace triton
71{
72 namespace driver
73 {
74
75 void init_llvm()
76 {
77 LLVMInitializeNVPTXTargetInfo();
78 LLVMInitializeNVPTXTarget();
79 LLVMInitializeNVPTXTargetMC();
80 LLVMInitializeNVPTXAsmPrinter();
81 LLVMInitializeAMDGPUTargetInfo();
82 LLVMInitializeAMDGPUTarget();
83 LLVMInitializeAMDGPUTargetMC();
84 LLVMInitializeAMDGPUAsmPrinter();
85 }
86
87 /* ------------------------ */
88 // CUDA //
89 /* ------------------------ */
// Replaces the first occurrence of `begin`, together with everything up to and
// including the next occurrence of `end` (searched from where `begin` starts),
// with `target`. When `end` does not occur there, the replacement extends to
// the end of `str`.
// Returns true if `begin` was found (and the replacement made), false otherwise.
static bool find_and_replace(std::string &str, const std::string &begin, const std::string &end, const std::string &target)
{
  const size_t start_replace = str.find(begin);
  if (start_replace == std::string::npos)
    return false;
  const size_t end_replace = str.find(end, start_replace);
  // A missing terminator must not feed npos into the length arithmetic:
  // `npos + 1 - start` wraps around and, for a match at offset 0, replaced
  // nothing while still reporting success (an endless loop for callers that
  // repeat until false).
  const size_t stop_replace = (end_replace == std::string::npos) ? str.size() : end_replace + end.size();
  str.replace(start_replace, stop_replace - start_replace, target);
  return true;
}
99
100 std::string path_to_ptxas(int &version)
101 {
102 std::vector<std::string> rets;
103 std::string ret;
104 // search paths for ptxas
105 std::vector<std::string> ptxas_prefixes = {"", "/usr/local/cuda/bin/"};
106 std::string triton_ptxas = tools::getenv("TRITON_PTXAS_PATH");
107 if (!triton_ptxas.empty())
108 ptxas_prefixes.insert(ptxas_prefixes.begin(), triton_ptxas);
109 // see what path for ptxas are valid
110 std::vector<std::string> working_ptxas;
111 for (std::string prefix : ptxas_prefixes)
112 {
113 std::string ptxas = prefix + "ptxas";
114 bool works = tools::exec(ptxas + " --version 2>&1", ret) == 0;
115 if (works)
116 {
117 working_ptxas.push_back(ptxas);
118 rets.push_back(ret);
119 }
120 }
121 // error if no working ptxas was found
122 if (working_ptxas.empty())
123 throw std::runtime_error("`ptxas` was searched in TRITON_PTXAS_PATH, /usr/local/cuda/bin/ or PATH"
124 " but a working version could not be found.");
125 std::string ptxas = working_ptxas.front();
126 // parse version
127 std::regex version_regex("release (\\d+)\\.(\\d+)");
128 std::smatch match;
129 bool found = false;
130 // currently choosing the first ptxas. Other logics can be implemented in future
131 for (std::string ret : rets)
132 {
133 if (std::regex_search(ret, match, version_regex))
134 {
135 int major = std::stoi(match[1]);
136 int minor = std::stoi(match[2]);
137 version = major * 1000 + minor * 10;
138 found = true;
139 break;
140 }
141 }
142 if (not found)
143 {
144 throw std::runtime_error("Error in parsing version");
145 }
146 return ptxas;
147 }
148
// Maps a CUDA toolkit version, encoded as major * 1000 + minor * 10 (11040 is
// CUDA 11.4), to the PTX ISA version used for `.version`, encoded as
// major * 10 + minor. Every toolkit from CUDA 11.4 on maps to PTX 7.4. Older
// toolkits are rejected with std::runtime_error.
// Mapping for older toolkits, kept for reference:
//   11.3 -> 73, 11.2 -> 72, 11.1 -> 71, 11.0 -> 70,
//   10.2 -> 65, 10.1 -> 64, 10.0 -> 63
int vptx(int version)
{
  const int min_supported_cuda = 11040;
  const int ptx_isa = 74;
  const bool supported = version >= min_supported_cuda;
  if (!supported)
    throw std::runtime_error("Triton requires CUDA 11.4+");
  return ptx_isa;
}
162
163 std::string llir_to_ptx(llvm::Module *module, int cc, int version)
164 {
165 // LLVM version in use may not officially support target hardware
166 int max_nvvm_cc = 75;
167 int max_nvvm_ptx = 74;
168 // options
169 auto options = llvm::cl::getRegisteredOptions();
170 auto *short_ptr = static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
171 assert(short_ptr);
172 short_ptr->setValue(true);
173 // compute capability
174 std::string sm = "sm_" + std::to_string(cc);
175 // max PTX version
176 int ptx = vptx(version);
177 int ptx_major = ptx / 10;
178 int ptx_minor = ptx % 10;
179 // create
180 llvm::SmallVector<char, 0> buffer;
181 std::string triple = "nvptx64-nvidia-cuda";
182 std::string proc = "sm_" + std::to_string(std::min(cc, max_nvvm_cc));
183 std::string layout = "";
184 std::string features = "";
185 // std::string features = "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx));
186 init_llvm();
187 // verify and store llvm
188 llvm::legacy::PassManager pm;
189 // pm.add(llvm::createPrintModulePass(llvm::outs()));
190 pm.add(llvm::createVerifierPass());
191 pm.run(*module);
192 // module->print(llvm::outs(), nullptr);
193
194 // create machine
195 module->setTargetTriple(triple);
196 std::string error;
197 llvm::TargetMachine *machine;
198 auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
199 llvm::TargetOptions opt;
200 opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
201 opt.UnsafeFPMath = false;
202 opt.NoInfsFPMath = false;
203 opt.NoNaNsFPMath = true;
204 machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
205 llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive);
206 // set data layout
207 if (layout.empty())
208 module->setDataLayout(machine->createDataLayout());
209 else
210 module->setDataLayout(layout);
211 // emit machine code
212 for (llvm::Function &f : module->functions())
213 f.addFnAttr(llvm::Attribute::AlwaysInline);
214 llvm::legacy::PassManager pass;
215 llvm::raw_svector_ostream stream(buffer);
216 // emit
217 machine->addPassesToEmitFile(pass, stream, nullptr, llvm::CodeGenFileType::CGFT_AssemblyFile);
218 pass.run(*module);
219
220 // post-process
221 std::string result(buffer.begin(), buffer.end());
222 find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." + std::to_string(ptx_minor) + "\n");
223 find_and_replace(result, ".target", "\n", ".target " + sm + "\n");
224 while (find_and_replace(result, "\t// begin inline asm", "\n", ""))
225 ;
226 while (find_and_replace(result, "\t// end inline asm", "\n", ""))
227 ;
228 return result;
229 }
230
// Assembles `ptx` into a cubin for `sm_<cc>` by running the `ptxas` command
// (a path or command name, e.g. as returned by path_to_ptxas) and returns the
// cubin bytes.
// Throws std::runtime_error carrying the ptxas diagnostics when ptxas fails,
// or when temporary files cannot be created/written or no output is produced.
// All temporary files are removed on every exit path.
std::string ptx_to_cubin(const std::string &ptx, const std::string &ptxas, int cc)
{
  // Unlinks every registered path when the function exits, whether it returns
  // normally or throws.
  struct temp_file_cleanup
  {
    std::vector<std::string> paths;
    ~temp_file_cleanup()
    {
      for (const std::string &path : paths)
        unlink(path.c_str());
    }
  } cleanup;
  // Atomically creates a unique empty file under /tmp and registers it for
  // cleanup. Unlike std::tmpnam, mkstemp leaves no window between choosing a
  // name and creating the file.
  auto make_temp_file = [&cleanup]() -> std::string {
    std::string path = "/tmp/triton_ptxas_XXXXXX";
    const int fd = mkstemp(&path[0]);
    if (fd == -1)
      throw std::runtime_error("Failed to create a temporary file for ptxas");
    close(fd);
    cleanup.paths.push_back(path);
    return path;
  };
  const std::string fsrc = make_temp_file();
  const std::string flog = make_temp_file();
  // ptxas creates this file. It is registered so it is removed even on failure.
  const std::string fbin = fsrc + ".o";
  cleanup.paths.push_back(fbin);
  // Write the PTX source.
  {
    std::ofstream ofs(fsrc);
    ofs << ptx << '\n';
    ofs.close();
    if (!ofs)
      throw std::runtime_error("Failed to write PTX to " + fsrc);
  }
  // Compile with ptxas. Its stderr (diagnostics) goes to the log file.
  const std::string cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fbin + " 2> " + flog;
  if (std::system(cmd.c_str()) != 0)
  {
    std::ifstream log_file(flog);
    std::string log(std::istreambuf_iterator<char>(log_file), {});
    throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
  }
  // Read back the generated cubin.
  std::ifstream cubin_file(fbin, std::ios::binary);
  if (!cubin_file)
    throw std::runtime_error("ptxas did not produce an output file: " + fbin);
  std::string cubin(std::istreambuf_iterator<char>(cubin_file), {});
  return cubin;
}
265
266 /* ------------------------ */
267 // HIP //
268 /* ------------------------ */
269
270 std::string llir_to_amdgpu(llvm::Module *module, const std::string &_proc)
271 {
272 init_llvm();
273
274 // proc = std::get<0>(GetFeatureStrFromGCNArchName(rocminfo));
275 // features = std::get<1>(GetFeatureStrFromGCNArchName(rocminfo));
276
277 // create
278 llvm::SmallVector<char, 0> buffer;
279 std::string triple = "amdgcn-amd-amdhsa";
280 std::string layout = "";
281 std::string features;
282 std::string proc = "gfx908";
283 // verify and store llvm
284 llvm::legacy::PassManager pm;
285 pm.add(llvm::createVerifierPass());
286 pm.run(*module);
287 // create machine
288 module->setTargetTriple(triple);
289 std::string error;
290 auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
291 llvm::TargetOptions opt;
292 opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
293 opt.UnsafeFPMath = false;
294 opt.NoInfsFPMath = false;
295 opt.NoNaNsFPMath = true;
296 llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, features, opt,
297 llvm::Reloc::PIC_, llvm::None,
298 llvm::CodeGenOpt::Aggressive);
299 // set data layout
300 if (layout.empty())
301 module->setDataLayout(machine->createDataLayout());
302 else
303 module->setDataLayout(layout);
304 // emit machine code
305 for (llvm::Function &f : module->functions())
306 f.addFnAttr(llvm::Attribute::AlwaysInline);
307 llvm::legacy::PassManager pass;
308 llvm::raw_svector_ostream stream(buffer);
309
310 // create dump files
311 std::string module_name = module->getModuleIdentifier();
312 std::error_code ec;
313
314 // Save GCN ISA binary.
315 std::string isabin_path = std::string("/tmp/") + module_name + std::string(".o");
316 std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
317 new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
318 if (ec)
319 {
320 std::cout << isabin_path << " was not created. error code: " << ec << std::endl;
321 }
322
323 // emit
324 machine->addPassesToEmitFile(pass, *isabin_fs, nullptr, llvm::CGFT_ObjectFile);
325 pass.run(*module);
326 // Save GCN ISA.
327 std::string amdgcn_path = std::string("/tmp/") + module_name + std::string(".gcn");
328 std::string result(buffer.begin(), buffer.end());
329 std::ofstream amdgcn(amdgcn_path);
330 amdgcn << result;
331 amdgcn.close();
332
333 // generate HASCO file
334 std::string hsaco_path = std::string("/tmp/") + module_name + std::string(".hsaco");
335 std::string error_message;
336 int lld_result =
337 llvm::sys::ExecuteAndWait("/opt/rocm/llvm/bin/ld.lld",
338 {"/opt/rocm/llvm/bin/ld.lld", "-flavor", "gnu", "-shared", "-o", hsaco_path, isabin_path},
339 llvm::None, {}, 0, 0, &error_message);
340 if (lld_result)
341 {
342 std::cout << "ld.lld execute fail: " << std::endl;
343 std::cout << error_message << std::endl;
344 std::cout << lld_result << std::endl;
345 }
346
347 return hsaco_path;
348 }
349
350 hipModule_t amdgpu_to_hipmodule(const std::string &path)
351 {
352 // Read HSACO.
353 std::ifstream hsaco_file(path, std::ios::binary | std::ios::ate);
354 std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
355
356 std::vector<unsigned char> hsaco(hsaco_file_size);
357 hsaco_file.seekg(0, std::ios::beg);
358 hsaco_file.read(reinterpret_cast<char *>(&hsaco[0]), hsaco_file_size);
359 hsaco_file.close();
360 hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, hipJitOptionErrorLogBuffer,
361 hipJitOptionInfoLogBufferSizeBytes, hipJitOptionInfoLogBuffer,
362 hipJitOptionLogVerbose};
363 const unsigned int errbufsize = 8192;
364 const unsigned int logbufsize = 8192;
365 char _err[errbufsize];
366 char _log[logbufsize];
367 void *optval[] = {(void *)(uintptr_t)errbufsize,
368 (void *)_err, (void *)(uintptr_t)logbufsize,
369 (void *)_log, (void *)1};
370 hipModule_t ret;
371 dispatch::hipModuleLoadDataEx(&ret, hsaco.data(), 5, opt, optval);
372 return ret;
373 }
374
375 } // namespace driver
376} // namespace triton
377