1 | #include "triton/codegen/analysis/axes.h" |
2 | #include "triton/ir/utils.h" |
3 | #include "triton/ir/instructions.h" |
4 | #include "triton/ir/type.h" |
5 | #include <iostream> |
6 | |
7 | |
8 | namespace triton{ |
9 | namespace codegen{ |
10 | namespace analysis{ |
11 | |
12 | axes::axes() {} |
13 | |
14 | void axes::update_graph_reduce(ir::instruction *i) { |
15 | auto* red = static_cast<ir::reduce_inst*>(i); |
16 | unsigned axis = red->get_axis(); |
17 | ir::value *arg = red->get_operand(0); |
18 | auto in_shapes = arg->get_type()->get_block_shapes(); |
19 | unsigned current = 0; |
20 | for(unsigned d = 0; d < in_shapes.size(); d++){ |
21 | if(d == axis) |
22 | continue; |
23 | graph_.add_edge({i, current++}, {arg, d}); |
24 | } |
25 | } |
26 | |
27 | void axes::update_graph_reshape(ir::instruction *i) { |
28 | auto* reshape = static_cast<ir::reshape_inst*>(i); |
29 | // operands |
30 | ir::value *op = reshape->get_operand(0); |
31 | // shapes |
32 | auto op_shapes = op->get_type()->get_block_shapes(); |
33 | auto res_shapes = reshape->get_type()->get_block_shapes(); |
34 | // construct edges |
35 | unsigned current = 0; |
36 | bool is_skewed = false; |
37 | for(unsigned d = 0; d < res_shapes.size(); d ++){ |
38 | bool same_shape = res_shapes[d] == op_shapes[current]; |
39 | // either add edge between axis or just add a node in the graph |
40 | if(!is_skewed && same_shape) |
41 | graph_.add_edge({i, d}, {op, current++}); |
42 | else |
43 | graph_.add_edge({i, d}, {i, d}); |
44 | // reshaping is skewed |
45 | if(res_shapes[d] > 1 && !same_shape) |
46 | is_skewed = true; |
47 | } |
48 | } |
49 | |
50 | void axes::update_graph_trans(ir::instruction *i) { |
51 | auto *trans = static_cast<ir::trans_inst*>(i); |
52 | ir::value *op = trans->get_operand(0); |
53 | auto perm = trans->get_perm(); |
54 | // add edge between axis perm[d] and axis d |
55 | for(unsigned d = 0; d < perm.size(); d++) |
56 | graph_.add_edge({i, perm[d]}, {op, d}); |
57 | } |
58 | |
59 | void axes::update_graph_dequantize(ir::instruction *i) { |
60 | auto *dequantize = static_cast<ir::dequantize_inst*>(i); |
61 | auto shapes = dequantize->get_type()->get_block_shapes(); |
62 | ir::value *op = dequantize->get_operand(0); |
63 | |
64 | // add edge except the last axis |
65 | for(unsigned d = 0; d < shapes.size() - 1; d ++){ |
66 | graph_.add_edge({i, d}, {op, d}); |
67 | } |
68 | } |
69 | |
70 | void axes::update_graph_broadcast(ir::instruction *i) { |
71 | auto *broadcast = static_cast<ir::broadcast_inst*>(i); |
72 | auto shapes = broadcast->get_type()->get_block_shapes(); |
73 | ir::value *op = broadcast->get_operand(0); |
74 | ir::type *op_ty = op->get_type(); |
75 | const auto& op_shapes = op_ty->get_block_shapes(); |
76 | // add edge between non-broadcast axes |
77 | for(unsigned d = 0; d < shapes.size(); d ++) |
78 | if(op_shapes[d] == shapes[d]) |
79 | graph_.add_edge({i, d}, {op, d}); |
80 | } |
81 | |
82 | void axes::update_graph_dot(ir::instruction *i) { |
83 | auto *dot = static_cast<ir::dot_inst*>(i); |
84 | auto shapes = dot->get_type()->get_block_shapes(); |
85 | ir::value *A = dot->get_operand(0); |
86 | ir::value *B = dot->get_operand(1); |
87 | ir::value *D = dot->get_operand(2); |
88 | // add edges between result and accumulator |
89 | for(unsigned d = 0; d < shapes.size(); d++) |
90 | graph_.add_edge({dot, d}, {D, d}); |
91 | } |
92 | |
93 | void axes::update_graph_elementwise(ir::instruction *i, |
94 | bool is_masked_load_async) { |
95 | if(i->get_num_operands() == 0) |
96 | return; |
97 | ir::value *op = i->get_operand(0); |
98 | if(!op->get_type()->is_block_ty()) |
99 | return; |
100 | auto rank = op->get_type()->get_tile_rank(); |
101 | for(unsigned d = 0; d < rank; d++) { |
102 | // If we are dealing with a masked async load we need to attach the |
103 | // dimensions so we match the behaviour of the copy_to_shared instruction |
104 | // which async masked load replaces. |
105 | if (is_masked_load_async) { |
106 | graph_.add_edge({i, d}, {i, d}); |
107 | } |
108 | |
109 | for(ir::value* opx: i->ops()) |
110 | for(ir::value* opy: i->ops()) { |
111 | if(!is_masked_load_async && !i->get_type()->is_void_ty()) |
112 | graph_.add_edge({i, d}, {opx, d}); |
113 | graph_.add_edge({opx, d}, {opy, d}); |
114 | } |
115 | } |
116 | } |
117 | |
118 | void axes::update_graph_no_edge(ir::instruction *i) { |
119 | if(!i->get_type()->is_block_ty()) |
120 | return; |
121 | auto rank = i->get_type()->get_tile_rank(); |
122 | for(unsigned d = 0; d < rank; d++) |
123 | graph_.add_edge({i, d}, {i, d}); |
124 | } |
125 | |
126 | void axes::update_graph(ir::instruction *i) { |
127 | switch (i->get_id()) { |
128 | case ir::INST_REDUCE: return update_graph_reduce(i); |
129 | case ir::INST_RESHAPE: return update_graph_reshape(i); |
130 | case ir::INST_SPLAT: return update_graph_no_edge(i); |
131 | case ir::INST_CAT: return update_graph_elementwise(i, true); |
132 | case ir::INST_TRANS: return update_graph_trans(i); |
133 | case ir::INST_DEQUANTIZE: return update_graph_dequantize(i); |
134 | case ir::INST_BROADCAST: return update_graph_broadcast(i); |
135 | case ir::INST_DOT: return update_graph_dot(i); |
136 | case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); |
137 | case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true); |
138 | case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); |
139 | case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i); |
140 | default: return update_graph_elementwise(i); |
141 | } |
142 | return; |
143 | } |
144 | |
145 | |
146 | int axes::get(ir::value *value, unsigned dim) { |
147 | return axes_.at({value, dim}); |
148 | } |
149 | |
150 | std::vector<int> axes::get(ir::value *value) { |
151 | std::vector<int> result; |
152 | for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++) |
153 | result.push_back(this->get(value, d)); |
154 | return result; |
155 | } |
156 | |
157 | void axes::run(ir::module &mod) { |
158 | // make graph |
159 | graph_.clear(); |
160 | axes_.clear(); |
161 | ir::for_each_instruction(mod, [this](ir::instruction *x) { |
162 | update_graph(x); |
163 | }); |
164 | // find connected components |
165 | graph_.connected_components(nullptr, &axes_); |
166 | std::set<size_t> uniq; |
167 | for(auto x: axes_) |
168 | uniq.insert(x.second); |
169 | } |
170 | |
171 | } |
172 | } |
173 | |
174 | } |
175 | |