1 | #include "triton/codegen/pass.h" |
2 | |
3 | #include "llvm/IR/Constants.h" |
4 | #include "llvm/IR/LegacyPassManager.h" |
5 | #include "llvm/IR/Module.h" |
6 | #include "llvm/IR/Verifier.h" |
7 | #include "llvm/IRReader/IRReader.h" |
8 | #include "llvm/Linker/Linker.h" |
9 | #include "llvm/Support/SourceMgr.h" |
10 | #include "llvm/Transforms/IPO.h" |
11 | #include "llvm/Transforms/IPO/PassManagerBuilder.h" |
12 | #include "triton/codegen/analysis/align.h" |
13 | #include "triton/codegen/analysis/allocation.h" |
14 | #include "triton/codegen/analysis/axes.h" |
15 | #include "triton/codegen/analysis/liveness.h" |
16 | #include "triton/codegen/analysis/swizzle.h" |
17 | #include "triton/codegen/selection/generator.h" |
18 | #include "triton/codegen/transform/coalesce.h" |
19 | #include "triton/codegen/transform/cts.h" |
20 | #include "triton/codegen/transform/dce.h" |
21 | #include "triton/codegen/transform/disassociate.h" |
22 | #include "triton/codegen/transform/inline.h" |
23 | #include "triton/codegen/transform/membar.h" |
24 | #include "triton/codegen/transform/peephole.h" |
25 | #include "triton/codegen/transform/pipeline.h" |
26 | #include "triton/codegen/transform/prefetch.h" |
27 | #include "triton/ir/function.h" |
28 | #include "triton/ir/module.h" |
29 | #include "triton/ir/print.h" |
30 | |
31 | namespace triton { |
32 | namespace codegen { |
33 | |
34 | static void link_extern_libs(const ExternLibMap& user_extern_lib_map, |
35 | const ExternLibMap& target_extern_lib_map, |
36 | ir::module& ir, llvm::LLVMContext& ctx, |
37 | std::unique_ptr<llvm::Module>& llvm) { |
38 | for (const auto& iter : target_extern_lib_map) { |
39 | auto &lib_name = iter.first; |
40 | if (user_extern_lib_map.count(lib_name) != 0 && |
41 | user_extern_lib_map.at(lib_name)->path() != "" ) { |
42 | // If the user specified a path for this library, use it. |
43 | user_extern_lib_map.at(lib_name)->install(ctx, llvm); |
44 | } else { |
45 | // Otherwise, use the default path. |
46 | iter.second->install(ctx, llvm); |
47 | } |
48 | } |
49 | |
50 | std::set<llvm::StringRef> function_names; |
51 | for (auto& func : ir.get_function_list()) { |
52 | function_names.insert(func->get_name()); |
53 | } |
54 | llvm::legacy::PassManager pass; |
55 | pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool { |
56 | if (function_names.count(v.getName()) != 0) { |
57 | // Preserve global functions |
58 | return true; |
59 | } |
60 | // Internalize all device functions |
61 | return false; |
62 | })); |
63 | |
64 | llvm::legacy::PassManager pm; |
65 | pm.add(llvm::createVerifierPass()); |
66 | pm.run(*llvm); |
67 | |
68 | llvm::PassManagerBuilder builder; |
69 | builder.OptLevel = 3; |
70 | builder.SizeLevel = 0; |
71 | builder.populateModulePassManager(pass); |
72 | |
73 | pass.run(*llvm); |
74 | } |
75 | |
76 | // TODO: |
77 | // There should be a proper pass manager there! |
78 | std::unique_ptr<llvm::Module> add_passes_to_emit_bin( |
79 | ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target, |
80 | int num_warps, int num_stages, int& shared_static, |
81 | const ExternLibMap& extern_lib_map) { |
82 | // generate llvm code |
83 | std::string name = ir.get_function_list()[0]->get_name(); |
84 | std::unique_ptr<llvm::Module> llvm(new llvm::Module(name, ctx)); |
85 | // optimizations |
86 | bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80; |
87 | // create passes |
88 | codegen::analysis::align align; |
89 | codegen::transform::inliner inliner; |
90 | codegen::analysis::axes axes; |
91 | codegen::transform::pipeline pipeline(has_sm80, num_stages); |
92 | codegen::transform::disassociate disassociate; |
93 | codegen::analysis::layouts layouts(&axes, &align, num_warps, target); |
94 | codegen::transform::cts cts(&layouts, has_sm80); |
95 | codegen::analysis::liveness liveness(&layouts); |
96 | codegen::analysis::swizzle swizzle(&layouts, target); |
97 | codegen::analysis::allocation allocation(&liveness); |
98 | codegen::transform::dce dce; |
99 | codegen::transform::peephole peephole(target, &layouts); |
100 | codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); |
101 | codegen::transform::prefetch prefetch_s(target); |
102 | codegen::transform::membar barriers(&liveness, &layouts, &allocation, |
103 | &prefetch_s, target); |
104 | codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, |
105 | target, num_warps); |
106 | // run passes |
107 | inliner.run(ir); |
108 | dce.run(ir); |
109 | peephole.run(ir); |
110 | dce.run(ir); |
111 | pipeline.run(ir); |
112 | dce.run(ir); |
113 | // ir.print(std::cout); |
114 | disassociate.run(ir); |
115 | dce.run(ir); |
116 | align.run(ir); |
117 | axes.run(ir); |
118 | layouts.run(ir); |
119 | peephole.run(ir); |
120 | dce.run(ir); |
121 | if (target->is_gpu()) cts.run(ir); |
122 | align.run(ir); |
123 | axes.run(ir); |
124 | layouts.run(ir); |
125 | coalesce.run(ir); |
126 | dce.run(ir); |
127 | align.run(ir); |
128 | dce.run(ir); |
129 | if (target->is_gpu()) cts.run(ir); |
130 | dce.run(ir); |
131 | align.run(ir); |
132 | axes.run(ir); |
133 | layouts.run(ir); |
134 | peephole.run(ir); |
135 | dce.run(ir); |
136 | align.run(ir); |
137 | axes.run(ir); |
138 | layouts.run(ir); |
139 | swizzle.run(ir); |
140 | // std::cout << "---" << std::endl; |
141 | // ir.print(std::cout); |
142 | // std::cout << "---" << std::endl; |
143 | // ir.print(std::cout); |
144 | liveness.run(ir); |
145 | allocation.run(ir); |
146 | prefetch_s.run(ir); |
147 | barriers.run(ir); |
148 | // exit(1); |
149 | // ir.print(std::cout); |
150 | isel.visit(ir, *llvm); |
151 | shared_static = allocation.allocated_size(); |
152 | if (target->as_nvidia() && target->as_nvidia()->sm() < 70) { |
153 | // sm < 70 (Pascal) has little shared memory resource. |
154 | // Instead of having "Error: Invalid argument" on launching a kernel, let's throw an error here. |
155 | if (shared_static >= 65536) { |
156 | throw std::runtime_error("Device does not support shared memory of " + std::to_string(shared_static) + "bytes" ); |
157 | } |
158 | } |
159 | |
160 | if (isel.get_extern_lib_map().size() > 0) { |
161 | // If there's any extern lib calls, |
162 | // we need to link them in. |
163 | link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm); |
164 | } |
165 | |
166 | return llvm; |
167 | } |
168 | |
169 | } // namespace codegen |
170 | } // namespace triton |
171 | |