1 | #include <iostream> |
2 | #include "triton/codegen/transform/inline.h" |
3 | #include "triton/ir/module.h" |
4 | #include "triton/ir/function.h" |
5 | #include "triton/ir/utils.h" |
6 | |
7 | namespace triton{ |
8 | namespace codegen{ |
9 | namespace transform{ |
10 | |
11 | |
12 | bool fncmp::operator()(ir::function* x, ir::function* y) const { |
13 | auto fn_list = x->get_parent()->get_function_list(); |
14 | return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y); |
15 | }; |
16 | |
17 | void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder, |
18 | std::list<ir::call_inst*>& callsites){ |
19 | ir::basic_block* parent_block = callsite->get_parent(); |
20 | ir::function* parent_fn = parent_block->get_parent(); |
21 | // the parent block is split into block A and block B: |
22 | // - block A (`new_blocks[0]`) is the entry block of the inlined function |
23 | // - block B (`exit`) resumes execution of the parent function |
24 | ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name()); |
25 | ir::basic_block* exit = entry->get_successors()[0]; |
26 | std::vector<ir::basic_block*> new_blocks = {entry}; |
27 | for(size_t i = 1; i < fn->blocks().size(); i++){ |
28 | ir::basic_block* block = fn->blocks()[i]; |
29 | ir::context& ctx = block->get_context(); |
30 | const std::string& name = block->get_parent()->get_name() + "_" + block->get_name(); |
31 | new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn)); |
32 | } |
33 | // a phi node holds the return values of the inlined function |
34 | if(exit->get_inst_list().empty()) |
35 | builder.set_insert_point(exit); |
36 | else |
37 | builder.set_insert_point(exit->get_first_non_phi()); |
38 | ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0); |
39 | callsite->replace_all_uses_with(exit_val); |
40 | callsite->erase_from_parent(); |
41 | // get arguments `fn` is called with |
42 | std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end()); |
43 | std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end()); |
44 | // Actually generate the instructions: |
45 | // - Remove the branch created by basic_block::split_before |
46 | // - Clone all instructions |
47 | // - Replace `ret` with incoming nodes to `exit_val` and branches to `exit` |
48 | ir::instruction* terminator = new_blocks[0]->get_inst_list().back(); |
49 | // new_blocks[0]->get_inst_list().back()->erase_from_parent(); |
50 | terminator->erase_from_parent(); |
51 | std::map<ir::instruction*, ir::instruction*> inst_map; |
52 | std::map<ir::argument*, ir::value*> arg_map; |
53 | for(size_t k = 0; k < fn->args().size(); k++) |
54 | arg_map[fn->args()[k]] = callsite->ops()[k]; |
55 | std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn); |
56 | // clone instructions |
57 | for(size_t i = 0; i < new_blocks.size(); i++){ |
58 | ir::basic_block* old_block = fn->blocks()[i]; |
59 | ir::basic_block* new_block = new_blocks[i]; |
60 | builder.set_insert_point(new_block); |
61 | for(ir::instruction* old_inst: old_block->get_inst_list()){ |
62 | ir::instruction* new_inst = old_inst->clone(); |
63 | inst_map[old_inst] = new_inst; |
64 | builder.insert(new_inst); |
65 | } |
66 | } |
67 | // update basic blocks |
68 | for(size_t i = 0; i < new_blocks.size(); i++) { |
69 | for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) { |
70 | // replace basic use cases |
71 | for(size_t k = 0; k < new_blocks.size(); k++) |
72 | new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]); |
73 | if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) { |
74 | // additionally replace basic blocks of phi-nodes since |
75 | // replace_uses_of_with() does not replace them. |
76 | for(unsigned in = 0; in < phi->get_num_incoming(); in++) |
77 | for(size_t k = 0; k < new_blocks.size(); k++) |
78 | if (phi->get_incoming_block(in) == fn->blocks()[k]) |
79 | phi->set_incoming_block(in, new_blocks[k]); |
80 | } |
81 | } |
82 | } |
83 | // replace operands of instructions after constructing inst_map |
84 | for (auto& it: inst_map) { |
85 | ir::instruction* new_inst = it.second; |
86 | for(size_t k = 0; k < new_inst->get_num_operands(); k++) { |
87 | ir::value* op = new_inst->get_operand(k); |
88 | if(auto arg_op = dynamic_cast<ir::argument*>(op)) |
89 | new_inst->set_operand(k, arg_map.at(arg_op)); |
90 | if(auto inst_op = dynamic_cast<ir::instruction*>(op)) |
91 | if(inst_map.find(inst_op) != inst_map.end()) |
92 | new_inst->set_operand(k, inst_map.at(inst_op)); |
93 | } |
94 | // handles a ret instruction. |
95 | // instead of returning we need to branch to after the function call |
96 | if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) { |
97 | if(ir::value* ret_val = ret->get_return_value()) |
98 | exit_val->add_incoming(ret_val, new_inst->get_parent()); |
99 | // replace ret with branch |
100 | ir::instruction* new_br_inst = ir::branch_inst::create(exit); |
101 | builder.set_insert_point(new_inst->get_parent()); |
102 | builder.insert(new_br_inst); |
103 | new_inst->erase_from_parent(); |
104 | } |
105 | } |
106 | if(exit_val->get_num_incoming() == 1) |
107 | exit_val->replace_all_uses_with(exit_val->get_incoming_value(0)); |
108 | // done -- make sure insert point is properly set to exit block |
109 | builder.set_insert_point(exit); |
110 | } |
111 | |
112 | void inliner::run(ir::module &mod) { |
113 | |
114 | // gather all call sites |
115 | while(true){ |
116 | std::map<ir::function*, size_t> counts; |
117 | for(ir::function* fn: mod.get_function_list()) |
118 | counts[fn] = 0; |
119 | |
120 | std::list<ir::call_inst*> callsites; |
121 | for(ir::function* fn: mod.get_function_list()){ |
122 | for(ir::basic_block* block: fn->blocks()) |
123 | for(ir::instruction* instr: block->get_inst_list()) |
124 | if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){ |
125 | callsites.push_back(call); |
126 | counts[call->get_fn()] += 1; |
127 | } |
128 | } |
129 | |
130 | for(auto& count: counts){ |
131 | if(!count.first->get_is_kernel() && count.second == 0) |
132 | count.first->get_parent()->remove_function(count.first); |
133 | } |
134 | |
135 | if(callsites.empty()) |
136 | break; |
137 | |
138 | for(ir::call_inst* call: callsites) |
139 | do_inline(call->get_fn(), call, mod.get_builder(), callsites); |
140 | } |
141 | |
142 | |
143 | } |
144 | |
145 | } |
146 | } |
147 | } |
148 | |