1#include "triton/codegen/pass.h"
2
3#include "llvm/IR/Constants.h"
4#include "llvm/IR/LegacyPassManager.h"
5#include "llvm/IR/Module.h"
6#include "llvm/IR/Verifier.h"
7#include "llvm/IRReader/IRReader.h"
8#include "llvm/Linker/Linker.h"
9#include "llvm/Support/SourceMgr.h"
10#include "llvm/Transforms/IPO.h"
11#include "llvm/Transforms/IPO/PassManagerBuilder.h"
12#include "triton/codegen/analysis/align.h"
13#include "triton/codegen/analysis/allocation.h"
14#include "triton/codegen/analysis/axes.h"
15#include "triton/codegen/analysis/liveness.h"
16#include "triton/codegen/analysis/swizzle.h"
17#include "triton/codegen/selection/generator.h"
18#include "triton/codegen/transform/coalesce.h"
19#include "triton/codegen/transform/cts.h"
20#include "triton/codegen/transform/dce.h"
21#include "triton/codegen/transform/disassociate.h"
22#include "triton/codegen/transform/inline.h"
23#include "triton/codegen/transform/membar.h"
24#include "triton/codegen/transform/peephole.h"
25#include "triton/codegen/transform/pipeline.h"
26#include "triton/codegen/transform/prefetch.h"
27#include "triton/ir/function.h"
28#include "triton/ir/module.h"
29#include "triton/ir/print.h"
30
31namespace triton {
32namespace codegen {
33
34static void link_extern_libs(const ExternLibMap& user_extern_lib_map,
35 const ExternLibMap& target_extern_lib_map,
36 ir::module& ir, llvm::LLVMContext& ctx,
37 std::unique_ptr<llvm::Module>& llvm) {
38 for (const auto& iter : target_extern_lib_map) {
39 auto &lib_name = iter.first;
40 if (user_extern_lib_map.count(lib_name) != 0 &&
41 user_extern_lib_map.at(lib_name)->path() != "") {
42 // If the user specified a path for this library, use it.
43 user_extern_lib_map.at(lib_name)->install(ctx, llvm);
44 } else {
45 // Otherwise, use the default path.
46 iter.second->install(ctx, llvm);
47 }
48 }
49
50 std::set<llvm::StringRef> function_names;
51 for (auto& func : ir.get_function_list()) {
52 function_names.insert(func->get_name());
53 }
54 llvm::legacy::PassManager pass;
55 pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool {
56 if (function_names.count(v.getName()) != 0) {
57 // Preserve global functions
58 return true;
59 }
60 // Internalize all device functions
61 return false;
62 }));
63
64 llvm::legacy::PassManager pm;
65 pm.add(llvm::createVerifierPass());
66 pm.run(*llvm);
67
68 llvm::PassManagerBuilder builder;
69 builder.OptLevel = 3;
70 builder.SizeLevel = 0;
71 builder.populateModulePassManager(pass);
72
73 pass.run(*llvm);
74}
75
76// TODO:
77// There should be a proper pass manager there!
78std::unique_ptr<llvm::Module> add_passes_to_emit_bin(
79 ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target,
80 int num_warps, int num_stages, int& shared_static,
81 const ExternLibMap& extern_lib_map) {
82 // generate llvm code
83 std::string name = ir.get_function_list()[0]->get_name();
84 std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx));
85 // optimizations
86 bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80;
87 // create passes
88 codegen::analysis::align align;
89 codegen::transform::inliner inliner;
90 codegen::analysis::axes axes;
91 codegen::transform::pipeline pipeline(has_sm80, num_stages);
92 codegen::transform::disassociate disassociate;
93 codegen::analysis::layouts layouts(&axes, &align, num_warps, target);
94 codegen::transform::cts cts(&layouts, has_sm80);
95 codegen::analysis::liveness liveness(&layouts);
96 codegen::analysis::swizzle swizzle(&layouts, target);
97 codegen::analysis::allocation allocation(&liveness);
98 codegen::transform::dce dce;
99 codegen::transform::peephole peephole(target, &layouts);
100 codegen::transform::coalesce coalesce(&align, &layouts, has_sm80);
101 codegen::transform::prefetch prefetch_s(target);
102 codegen::transform::membar barriers(&liveness, &layouts, &allocation,
103 &prefetch_s, target);
104 codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle,
105 target, num_warps);
106 // run passes
107 inliner.run(ir);
108 dce.run(ir);
109 peephole.run(ir);
110 dce.run(ir);
111 pipeline.run(ir);
112 dce.run(ir);
113 // ir.print(std::cout);
114 disassociate.run(ir);
115 dce.run(ir);
116 align.run(ir);
117 axes.run(ir);
118 layouts.run(ir);
119 peephole.run(ir);
120 dce.run(ir);
121 if (target->is_gpu()) cts.run(ir);
122 align.run(ir);
123 axes.run(ir);
124 layouts.run(ir);
125 coalesce.run(ir);
126 dce.run(ir);
127 align.run(ir);
128 dce.run(ir);
129 if (target->is_gpu()) cts.run(ir);
130 dce.run(ir);
131 align.run(ir);
132 axes.run(ir);
133 layouts.run(ir);
134 peephole.run(ir);
135 dce.run(ir);
136 align.run(ir);
137 axes.run(ir);
138 layouts.run(ir);
139 swizzle.run(ir);
140 // std::cout << "---" << std::endl;
141 // ir.print(std::cout);
142 // std::cout << "---" << std::endl;
143 // ir.print(std::cout);
144 liveness.run(ir);
145 allocation.run(ir);
146 prefetch_s.run(ir);
147 barriers.run(ir);
148 // exit(1);
149 // ir.print(std::cout);
150 isel.visit(ir, *llvm);
151 shared_static = allocation.allocated_size();
152 if (target->as_nvidia() && target->as_nvidia()->sm() < 70) {
153 // sm < 70 (Pascal) has little shared memory resource.
154 // Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here.
155 if (shared_static >= 65536) {
156 throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes");
157 }
158 }
159
160 if (isel.get_extern_lib_map().size() > 0) {
161 // If there's any extern lib calls,
162 // we need to link them in.
163 link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm);
164 }
165
166 return llvm;
167}
168
169} // namespace codegen
170} // namespace triton
171