1 | #include <algorithm> |
---|---|
2 | #include <iostream> |
3 | #include "triton/ir/utils.h" |
4 | #include "triton/ir/instructions.h" |
5 | #include "triton/ir/function.h" |
6 | #include "triton/ir/module.h" |
7 | #include "triton/codegen/transform/coalesce.h" |
8 | #include "triton/codegen/analysis/align.h" |
9 | #include "triton/codegen/analysis/layout.h" |
10 | |
11 | namespace triton { |
12 | namespace codegen{ |
13 | namespace transform{ |
14 | |
15 | coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80) |
16 | : align_(align), layout_(layouts), has_sm80_(has_sm80) { } |
17 | |
18 | void coalesce::run(ir::module &mod) { |
19 | std::set<analysis::data_layout*> invalidated; |
20 | ir::builder& builder = mod.get_builder(); |
21 | // add layout conversion instructions |
22 | for(ir::function *fn: mod.get_function_list()) |
23 | for(ir::basic_block *block: fn->blocks()) |
24 | for(ir::instruction* i: block->get_inst_list()){ |
25 | // coalesce before store |
26 | if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i)) |
27 | if(ir::value* op = i->get_operand(1)) |
28 | if(op->get_type()->is_block_ty()) |
29 | if(op->get_type()->get_tile_ranks1() == 2) |
30 | if(invalidated.find(layout_->get(op)) == invalidated.end()) |
31 | if(layout_->get(op)->to_mma()) |
32 | if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){ |
33 | ir::instruction* new_op = ir::cvt_layout_inst::create(op); |
34 | builder.set_insert_point(i); |
35 | builder.insert(new_op); |
36 | i->replace_uses_of_with(op, new_op); |
37 | } |
38 | // coalesce before copy_to_shared |
39 | // only necessary for sm < 80 as Ampere+ can handle reduction |
40 | // on MMA layout |
41 | if(!has_sm80_) |
42 | if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i)) |
43 | if(ir::value* op = i->get_operand(0)) |
44 | if(op->get_type()->is_block_ty()) |
45 | if(op->get_type()->get_tile_ranks1() == 2) |
46 | if(invalidated.find(layout_->get(op)) == invalidated.end()) |
47 | if(layout_->get(op)->to_mma()){ |
48 | ir::instruction* new_op = ir::cvt_layout_inst::create(op); |
49 | builder.set_insert_point(i); |
50 | builder.insert(new_op); |
51 | op->replace_all_uses_with(new_op); |
52 | new_op->replace_uses_of_with(new_op, op); |
53 | invalidated.insert(layout_->get(op)); |
54 | } |
55 | // uncoalesce after load |
56 | if(auto x = dynamic_cast<ir::load_inst*>(i)) |
57 | if(x->get_type()->is_block_ty()) |
58 | if(x->get_type()->get_tile_ranks1()==2) |
59 | if(layout_->get(x)->to_mma()) |
60 | if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){ |
61 | builder.set_insert_point_after(x); |
62 | ir::instruction* new_x = ir::cvt_layout_inst::create(x); |
63 | builder.insert(new_x); |
64 | x->replace_all_uses_with(new_x); |
65 | new_x->replace_uses_of_with(new_x, x); |
66 | } |
67 | } |
68 | for(ir::function *fn: mod.get_function_list()) |
69 | for(ir::basic_block *block: fn->blocks()) |
70 | for(ir::instruction* i: block->get_inst_list()){ |
71 | // re-arrange scanline to promote memory coalescing |
72 | if(auto x = dynamic_cast<ir::store_inst*>(i)){ |
73 | ir::value* ptr = x->get_pointer_operand(); |
74 | ir::value* val = x->get_value_operand(); |
75 | auto out_contig = align_->contiguous(ptr); |
76 | auto val_inst = dynamic_cast<ir::instruction*>(val); |
77 | if(!val_inst) |
78 | continue; |
79 | if(dynamic_cast<ir::cvt_layout_inst*>(val)) |
80 | continue; |
81 | if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1) |
82 | continue; |
83 | std::vector<unsigned> in_contig; |
84 | std::vector<ir::instruction*> queue = {val_inst}; |
85 | std::set<ir::instruction*> seen; |
86 | std::vector<ir::io_inst*> ios; |
87 | while(!queue.empty()){ |
88 | ir::instruction* curr = queue.back(); |
89 | seen.insert(curr); |
90 | queue.pop_back(); |
91 | if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr)) |
92 | break; |
93 | if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){ |
94 | in_contig = align_->contiguous(io_inst->get_pointer_operand()); |
95 | break; |
96 | } |
97 | for(ir::value* op: curr->ops()){ |
98 | auto inst_op = dynamic_cast<ir::instruction*>(op); |
99 | if(!inst_op || seen.find(inst_op) != seen.end()) |
100 | continue; |
101 | if(!op->get_type()->is_block_ty() || |
102 | !val->get_type()->is_block_ty()) |
103 | continue; |
104 | if(op->get_type()->get_tile_num_elements() == |
105 | val->get_type()->get_tile_num_elements()) |
106 | queue.push_back(inst_op); |
107 | } |
108 | } |
109 | if(in_contig.size() <= 1 || out_contig==in_contig) |
110 | continue; |
111 | builder.set_insert_point_after(val_inst); |
112 | auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst)); |
113 | x->replace_uses_of_with(val_inst, new_val); |
114 | } |
115 | } |
116 | } |
117 | |
118 | |
119 | } |
120 | } |
121 | } |
122 |