1#include <algorithm>
2#include <iostream>
3#include "triton/ir/utils.h"
4#include "triton/ir/instructions.h"
5#include "triton/ir/function.h"
6#include "triton/ir/module.h"
7#include "triton/codegen/transform/coalesce.h"
8#include "triton/codegen/analysis/align.h"
9#include "triton/codegen/analysis/layout.h"
10
11namespace triton {
12namespace codegen{
13namespace transform{
14
15coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80)
16 : align_(align), layout_(layouts), has_sm80_(has_sm80) { }
17
18void coalesce::run(ir::module &mod) {
19 std::set<analysis::data_layout*> invalidated;
20 ir::builder& builder = mod.get_builder();
21 // add layout conversion instructions
22 for(ir::function *fn: mod.get_function_list())
23 for(ir::basic_block *block: fn->blocks())
24 for(ir::instruction* i: block->get_inst_list()){
25 // coalesce before store
26 if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
27 if(ir::value* op = i->get_operand(1))
28 if(op->get_type()->is_block_ty())
29 if(op->get_type()->get_tile_ranks1() == 2)
30 if(invalidated.find(layout_->get(op)) == invalidated.end())
31 if(layout_->get(op)->to_mma())
32 if(dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
33 ir::instruction* new_op = ir::cvt_layout_inst::create(op);
34 builder.set_insert_point(i);
35 builder.insert(new_op);
36 i->replace_uses_of_with(op, new_op);
37 }
38 // coalesce before copy_to_shared
39 // only necessary for sm < 80 as Ampere+ can handle reduction
40 // on MMA layout
41 if(!has_sm80_)
42 if(dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::reduce_inst*>(i))
43 if(ir::value* op = i->get_operand(0))
44 if(op->get_type()->is_block_ty())
45 if(op->get_type()->get_tile_ranks1() == 2)
46 if(invalidated.find(layout_->get(op)) == invalidated.end())
47 if(layout_->get(op)->to_mma()){
48 ir::instruction* new_op = ir::cvt_layout_inst::create(op);
49 builder.set_insert_point(i);
50 builder.insert(new_op);
51 op->replace_all_uses_with(new_op);
52 new_op->replace_uses_of_with(new_op, op);
53 invalidated.insert(layout_->get(op));
54 }
55 // uncoalesce after load
56 if(auto x = dynamic_cast<ir::load_inst*>(i))
57 if(x->get_type()->is_block_ty())
58 if(x->get_type()->get_tile_ranks1()==2)
59 if(layout_->get(x)->to_mma())
60 if(!has_sm80_ || dynamic_cast<ir::io_inst*>(i)->get_eviction_policy()==ir::io_inst::NORMAL){
61 builder.set_insert_point_after(x);
62 ir::instruction* new_x = ir::cvt_layout_inst::create(x);
63 builder.insert(new_x);
64 x->replace_all_uses_with(new_x);
65 new_x->replace_uses_of_with(new_x, x);
66 }
67 }
68 for(ir::function *fn: mod.get_function_list())
69 for(ir::basic_block *block: fn->blocks())
70 for(ir::instruction* i: block->get_inst_list()){
71 // re-arrange scanline to promote memory coalescing
72 if(auto x = dynamic_cast<ir::store_inst*>(i)){
73 ir::value* ptr = x->get_pointer_operand();
74 ir::value* val = x->get_value_operand();
75 auto out_contig = align_->contiguous(ptr);
76 auto val_inst = dynamic_cast<ir::instruction*>(val);
77 if(!val_inst)
78 continue;
79 if(dynamic_cast<ir::cvt_layout_inst*>(val))
80 continue;
81 if(!val->get_type()->is_block_ty() || val->get_type()->get_tile_ranks1()==1)
82 continue;
83 std::vector<unsigned> in_contig;
84 std::vector<ir::instruction*> queue = {val_inst};
85 std::set<ir::instruction*> seen;
86 std::vector<ir::io_inst*> ios;
87 while(!queue.empty()){
88 ir::instruction* curr = queue.back();
89 seen.insert(curr);
90 queue.pop_back();
91 if(auto dot_inst = dynamic_cast<ir::dot_inst*>(curr))
92 break;
93 if(auto io_inst = dynamic_cast<ir::io_inst*>(curr)){
94 in_contig = align_->contiguous(io_inst->get_pointer_operand());
95 break;
96 }
97 for(ir::value* op: curr->ops()){
98 auto inst_op = dynamic_cast<ir::instruction*>(op);
99 if(!inst_op || seen.find(inst_op) != seen.end())
100 continue;
101 if(!op->get_type()->is_block_ty() ||
102 !val->get_type()->is_block_ty())
103 continue;
104 if(op->get_type()->get_tile_num_elements() ==
105 val->get_type()->get_tile_num_elements())
106 queue.push_back(inst_op);
107 }
108 }
109 if(in_contig.size() <= 1 || out_contig==in_contig)
110 continue;
111 builder.set_insert_point_after(val_inst);
112 auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
113 x->replace_uses_of_with(val_inst, new_val);
114 }
115 }
116}
117
118
119}
120}
121}
122