1#include <iostream>
2#include "triton/codegen/transform/inline.h"
3#include "triton/ir/module.h"
4#include "triton/ir/function.h"
5#include "triton/ir/utils.h"
6
7namespace triton{
8namespace codegen{
9namespace transform{
10
11
12bool fncmp::operator()(ir::function* x, ir::function* y) const {
13 auto fn_list = x->get_parent()->get_function_list();
14 return std::find(fn_list.begin(), fn_list.end(), x) < std::find(fn_list.begin(), fn_list.end(), y);
15};
16
17void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& builder,
18 std::list<ir::call_inst*>& callsites){
19 ir::basic_block* parent_block = callsite->get_parent();
20 ir::function* parent_fn = parent_block->get_parent();
21 // the parent block is split into block A and block B:
22 // - block A (`new_blocks[0]`) is the entry block of the inlined function
23 // - block B (`exit`) resumes execution of the parent function
24 ir::basic_block* entry = parent_block->split_before(callsite, fn->get_name());
25 ir::basic_block* exit = entry->get_successors()[0];
26 std::vector<ir::basic_block*> new_blocks = {entry};
27 for(size_t i = 1; i < fn->blocks().size(); i++){
28 ir::basic_block* block = fn->blocks()[i];
29 ir::context& ctx = block->get_context();
30 const std::string& name = block->get_parent()->get_name() + "_" + block->get_name();
31 new_blocks.push_back(ir::basic_block::create(ctx, name, parent_fn));
32 }
33 // a phi node holds the return values of the inlined function
34 if(exit->get_inst_list().empty())
35 builder.set_insert_point(exit);
36 else
37 builder.set_insert_point(exit->get_first_non_phi());
38 ir::phi_node* exit_val = builder.create_phi(fn->get_fn_type()->get_return_ty(), 0);
39 callsite->replace_all_uses_with(exit_val);
40 callsite->erase_from_parent();
41 // get arguments `fn` is called with
42 std::vector<ir::value*> tgt_args(callsite->op_begin(), callsite->op_end());
43 std::vector<ir::argument*> src_args(fn->args().begin(), fn->args().end());
44 // Actually generate the instructions:
45 // - Remove the branch created by basic_block::split_before
46 // - Clone all instructions
47 // - Replace `ret` with incoming nodes to `exit_val` and branches to `exit`
48 ir::instruction* terminator = new_blocks[0]->get_inst_list().back();
49// new_blocks[0]->get_inst_list().back()->erase_from_parent();
50 terminator->erase_from_parent();
51 std::map<ir::instruction*, ir::instruction*> inst_map;
52 std::map<ir::argument*, ir::value*> arg_map;
53 for(size_t k = 0; k < fn->args().size(); k++)
54 arg_map[fn->args()[k]] = callsite->ops()[k];
55 std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
56 // clone instructions
57 for(size_t i = 0; i < new_blocks.size(); i++){
58 ir::basic_block* old_block = fn->blocks()[i];
59 ir::basic_block* new_block = new_blocks[i];
60 builder.set_insert_point(new_block);
61 for(ir::instruction* old_inst: old_block->get_inst_list()){
62 ir::instruction* new_inst = old_inst->clone();
63 inst_map[old_inst] = new_inst;
64 builder.insert(new_inst);
65 }
66 }
67 // update basic blocks
68 for(size_t i = 0; i < new_blocks.size(); i++) {
69 for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
70 // replace basic use cases
71 for(size_t k = 0; k < new_blocks.size(); k++)
72 new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
73 if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
74 // additionally replace basic blocks of phi-nodes since
75 // replace_uses_of_with() does not replace them.
76 for(unsigned in = 0; in < phi->get_num_incoming(); in++)
77 for(size_t k = 0; k < new_blocks.size(); k++)
78 if (phi->get_incoming_block(in) == fn->blocks()[k])
79 phi->set_incoming_block(in, new_blocks[k]);
80 }
81 }
82 }
83 // replace operands of instructions after constructing inst_map
84 for (auto& it: inst_map) {
85 ir::instruction* new_inst = it.second;
86 for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
87 ir::value* op = new_inst->get_operand(k);
88 if(auto arg_op = dynamic_cast<ir::argument*>(op))
89 new_inst->set_operand(k, arg_map.at(arg_op));
90 if(auto inst_op = dynamic_cast<ir::instruction*>(op))
91 if(inst_map.find(inst_op) != inst_map.end())
92 new_inst->set_operand(k, inst_map.at(inst_op));
93 }
94 // handles a ret instruction.
95 // instead of returning we need to branch to after the function call
96 if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
97 if(ir::value* ret_val = ret->get_return_value())
98 exit_val->add_incoming(ret_val, new_inst->get_parent());
99 // replace ret with branch
100 ir::instruction* new_br_inst = ir::branch_inst::create(exit);
101 builder.set_insert_point(new_inst->get_parent());
102 builder.insert(new_br_inst);
103 new_inst->erase_from_parent();
104 }
105 }
106 if(exit_val->get_num_incoming() == 1)
107 exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
108 // done -- make sure insert point is properly set to exit block
109 builder.set_insert_point(exit);
110}
111
112void inliner::run(ir::module &mod) {
113
114 // gather all call sites
115 while(true){
116 std::map<ir::function*, size_t> counts;
117 for(ir::function* fn: mod.get_function_list())
118 counts[fn] = 0;
119
120 std::list<ir::call_inst*> callsites;
121 for(ir::function* fn: mod.get_function_list()){
122 for(ir::basic_block* block: fn->blocks())
123 for(ir::instruction* instr: block->get_inst_list())
124 if(ir::call_inst* call = dynamic_cast<ir::call_inst*>(instr)){
125 callsites.push_back(call);
126 counts[call->get_fn()] += 1;
127 }
128 }
129
130 for(auto& count: counts){
131 if(!count.first->get_is_kernel() && count.second == 0)
132 count.first->get_parent()->remove_function(count.first);
133 }
134
135 if(callsites.empty())
136 break;
137
138 for(ir::call_inst* call: callsites)
139 do_inline(call->get_fn(), call, mod.get_builder(), callsites);
140 }
141
142
143}
144
145}
146}
147}
148