1#include "triton/codegen/analysis/axes.h"
2#include "triton/ir/utils.h"
3#include "triton/ir/instructions.h"
4#include "triton/ir/type.h"
5#include <iostream>
6
7
8namespace triton{
9namespace codegen{
10namespace analysis{
11
12axes::axes() {}
13
14void axes::update_graph_reduce(ir::instruction *i) {
15 auto* red = static_cast<ir::reduce_inst*>(i);
16 unsigned axis = red->get_axis();
17 ir::value *arg = red->get_operand(0);
18 auto in_shapes = arg->get_type()->get_block_shapes();
19 unsigned current = 0;
20 for(unsigned d = 0; d < in_shapes.size(); d++){
21 if(d == axis)
22 continue;
23 graph_.add_edge({i, current++}, {arg, d});
24 }
25}
26
27void axes::update_graph_reshape(ir::instruction *i) {
28 auto* reshape = static_cast<ir::reshape_inst*>(i);
29 // operands
30 ir::value *op = reshape->get_operand(0);
31 // shapes
32 auto op_shapes = op->get_type()->get_block_shapes();
33 auto res_shapes = reshape->get_type()->get_block_shapes();
34 // construct edges
35 unsigned current = 0;
36 bool is_skewed = false;
37 for(unsigned d = 0; d < res_shapes.size(); d ++){
38 bool same_shape = res_shapes[d] == op_shapes[current];
39 // either add edge between axis or just add a node in the graph
40 if(!is_skewed && same_shape)
41 graph_.add_edge({i, d}, {op, current++});
42 else
43 graph_.add_edge({i, d}, {i, d});
44 // reshaping is skewed
45 if(res_shapes[d] > 1 && !same_shape)
46 is_skewed = true;
47 }
48}
49
50void axes::update_graph_trans(ir::instruction *i) {
51 auto *trans = static_cast<ir::trans_inst*>(i);
52 ir::value *op = trans->get_operand(0);
53 auto perm = trans->get_perm();
54 // add edge between axis perm[d] and axis d
55 for(unsigned d = 0; d < perm.size(); d++)
56 graph_.add_edge({i, perm[d]}, {op, d});
57}
58
59void axes::update_graph_dequantize(ir::instruction *i) {
60 auto *dequantize = static_cast<ir::dequantize_inst*>(i);
61 auto shapes = dequantize->get_type()->get_block_shapes();
62 ir::value *op = dequantize->get_operand(0);
63
64 // add edge except the last axis
65 for(unsigned d = 0; d < shapes.size() - 1; d ++){
66 graph_.add_edge({i, d}, {op, d});
67 }
68}
69
70void axes::update_graph_broadcast(ir::instruction *i) {
71 auto *broadcast = static_cast<ir::broadcast_inst*>(i);
72 auto shapes = broadcast->get_type()->get_block_shapes();
73 ir::value *op = broadcast->get_operand(0);
74 ir::type *op_ty = op->get_type();
75 const auto& op_shapes = op_ty->get_block_shapes();
76 // add edge between non-broadcast axes
77 for(unsigned d = 0; d < shapes.size(); d ++)
78 if(op_shapes[d] == shapes[d])
79 graph_.add_edge({i, d}, {op, d});
80}
81
82void axes::update_graph_dot(ir::instruction *i) {
83 auto *dot = static_cast<ir::dot_inst*>(i);
84 auto shapes = dot->get_type()->get_block_shapes();
85 ir::value *A = dot->get_operand(0);
86 ir::value *B = dot->get_operand(1);
87 ir::value *D = dot->get_operand(2);
88 // add edges between result and accumulator
89 for(unsigned d = 0; d < shapes.size(); d++)
90 graph_.add_edge({dot, d}, {D, d});
91}
92
// Elementwise instructions preserve axes: every operand shares each of its
// tile axes with every other operand, and (usually) with the result too.
// `is_masked_load_async` alters the wiring: the result's axes are then kept
// independent of the operands' axes, mirroring copy_to_shared.
void axes::update_graph_elementwise(ir::instruction *i,
                                    bool is_masked_load_async) {
  if(i->get_num_operands() == 0)
    return;
  ir::value *op = i->get_operand(0);
  // scalars carry no tile axes — nothing to record
  if(!op->get_type()->is_block_ty())
    return;
  auto rank = op->get_type()->get_tile_rank();
  for(unsigned d = 0; d < rank; d++) {
    // If we are dealing with a masked async load we need to attach the
    // dimensions so we match the behaviour of the copy_to_shared instruction
    // which async masked load replaces.
    if (is_masked_load_async) {
      graph_.add_edge({i, d}, {i, d});
    }

    // Tie all operands together pairwise along axis d. The result is tied
    // to the operands as well, unless it is void (e.g. a store) or a
    // masked async load (whose result layout stays independent).
    for(ir::value* opx: i->ops())
      for(ir::value* opy: i->ops()) {
        if(!is_masked_load_async && !i->get_type()->is_void_ty())
          graph_.add_edge({i, d}, {opx, d});
        graph_.add_edge({opx, d}, {opy, d});
      }
  }
}
117
118void axes::update_graph_no_edge(ir::instruction *i) {
119 if(!i->get_type()->is_block_ty())
120 return;
121 auto rank = i->get_type()->get_tile_rank();
122 for(unsigned d = 0; d < rank; d++)
123 graph_.add_edge({i, d}, {i, d});
124}
125
126void axes::update_graph(ir::instruction *i) {
127 switch (i->get_id()) {
128 case ir::INST_REDUCE: return update_graph_reduce(i);
129 case ir::INST_RESHAPE: return update_graph_reshape(i);
130 case ir::INST_SPLAT: return update_graph_no_edge(i);
131 case ir::INST_CAT: return update_graph_elementwise(i, true);
132 case ir::INST_TRANS: return update_graph_trans(i);
133 case ir::INST_DEQUANTIZE: return update_graph_dequantize(i);
134 case ir::INST_BROADCAST: return update_graph_broadcast(i);
135 case ir::INST_DOT: return update_graph_dot(i);
136 case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
137 case ir::INST_MASKED_LOAD_ASYNC: return update_graph_elementwise(i, true);
138 case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
139 case ir::INST_CVT_LAYOUT: return update_graph_no_edge(i);
140 default: return update_graph_elementwise(i);
141 }
142 return;
143}
144
145
// Return the connected-component id assigned to axis `dim` of `value`.
// Throws std::out_of_range if run() never labeled this (value, dim) pair.
int axes::get(ir::value *value, unsigned dim) {
  return axes_.at({value, dim});
}
149
150std::vector<int> axes::get(ir::value *value) {
151 std::vector<int> result;
152 for(size_t d = 0; d < value->get_type()->get_tile_rank(); d++)
153 result.push_back(this->get(value, d));
154 return result;
155}
156
157void axes::run(ir::module &mod) {
158 // make graph
159 graph_.clear();
160 axes_.clear();
161 ir::for_each_instruction(mod, [this](ir::instruction *x) {
162 update_graph(x);
163 });
164 // find connected components
165 graph_.connected_components(nullptr, &axes_);
166 std::set<size_t> uniq;
167 for(auto x: axes_)
168 uniq.insert(x.second);
169}
170
171}
172}
173
174}
175