#include <algorithm>
#include <iostream>
#include <set>
#include <vector>
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/analysis/layout.h"
7
8namespace triton {
9namespace codegen{
10namespace transform{
11
12
13ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder,
14 const std::vector<int>& perm) {
15 if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
16 // transpose operands
17 std::vector<ir::value*> incs;
18 for(unsigned n = 0; n < phi->get_num_incoming(); n++)
19 incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm));
20 // create phi for transposed values
21 builder.set_insert_point(phi);
22 ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
23 for(unsigned n = 0; n < phi->get_num_incoming(); n++)
24 result->add_incoming(incs[n], phi->get_incoming_block(n));
25 return result;
26 }
27 else if(auto i = dynamic_cast<ir::instruction*>(value)){
28 ir::basic_block* block = i->get_parent();
29 auto it = std::find(block->begin(), block->end(), i);
30 it++;
31 builder.set_insert_point(it);
32 ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
33 trans->set_operand(0, i);
34 return trans;
35 }
36 return nullptr;
37}
38
39bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) {
40 auto trans = dynamic_cast<ir::trans_inst*>(value);
41 if(!trans)
42 return false;
43 auto users = trans->get_users();
44 auto ops = trans->ops();
45 if(users.size() > 1 || ops.size() > 1)
46 return false;
47 ir::value* op = *ops.begin();
48 // trans(phi) -> phi(trans(), trans()...)
49 auto* phi = dynamic_cast<ir::phi_node*>(op);
50 if(!phi)
51 return false;
52 ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm());
53 if(!new_phi)
54 return false;
55 trans->replace_all_uses_with(new_phi);
56
57 return true;
58}
59
60bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
61 // dot(a, b, c) + d -> dot(a, b, c + d)
62 // d + dot(a, b, c) -> dot(a, b, c + d)
63 auto add = dynamic_cast<ir::binary_operator*>(value);
64 if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
65 bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
66 ir::value *lhs = add->get_operand(0);
67 ir::value *rhs = add->get_operand(1);
68 ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
69 ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs);
70 if(!lhs_dot && !rhs_dot)
71 return false;
72 ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot;
73 ir::value *other = (dot == lhs) ? rhs : lhs;
74 ir::value *acc = dot->get_operand(2);
75 ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
76 ir::constant *_0 = nullptr;
77 if(splat)
78 _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
79 if(!_0)
80 return false;
81 if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
82 if (fp_0->get_value() != 0.0)
83 return false;
84 if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
85 if (int_0->get_value() != 0)
86 return false;
87 ir::value *a = dot->get_operand(0);
88 ir::value *b = dot->get_operand(1);
89 builder.set_insert_point(add);
90 ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name()));
91 add->replace_all_uses_with(new_dot);
92 return true;
93 }
94 return false;
95}
96
97//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
98// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
99// if(cfs) {
100// ir::value *arg = cfs->get_operand(0);
101// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
102// if(!cts)
103// return false;
104// cfs->replace_all_uses_with(cts->get_operand(0));
105// return true;
106// }
107
108//}
109
110bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
111 auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
112 if(!copy_to_shared)
113 return false;
114 ir::value *arg = copy_to_shared->get_operand(0);
115 ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
116 if(!ld)
117 return false;
118 builder.set_insert_point(copy_to_shared);
119 ir::value *ptr = ld->get_pointer_operand();
120 ir::value *msk = ld->get_mask_operand();
121 ir::value *val = ld->get_false_value_operand();
122 analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
123 int nts = layout->nts(layout->get_order()[0]);
124 int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
125 if(nts*dtsize >= 4){
126 ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy());
127 copy_to_shared->replace_all_uses_with(new_load);
128 return true;
129 }
130 return false;
131// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline();
132// std::cout << layout->nts(layout->get_order(0)) << std::endl;
133// return true;
134
135}
136
137bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
138 auto x = dynamic_cast<ir::reduce_inst*>(value);
139 if(!x)
140 return false;
141 ir::value *arg = x->get_operand(0);
142 auto shapes = arg->get_type()->get_block_shapes();
143 if(shapes[x->get_axis()] == 1){
144 builder.set_insert_point(x);
145 ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes());
146 x->replace_all_uses_with(new_red);
147 return true;
148 }
149 return false;
150}
151
152bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
153 auto binop = dynamic_cast<ir::binary_operator*>(value);
154 if(binop && binop->get_op() == ir::binary_op_t::Mul) {
155 ir::value *lhs = binop->get_operand(0);
156 ir::value *rhs = binop->get_operand(1);
157 ir::constant_int *_1_lhs = nullptr;
158 if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){
159 auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
160 if(cst && cst->get_value() == 1)
161 _1_lhs = cst;
162 }
163 ir::constant_int *_1_rhs = nullptr;
164 if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){
165 auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
166 if(cst && cst->get_value() == 1)
167 _1_rhs = cst;
168 }
169 if(_1_lhs){
170 binop->replace_all_uses_with(rhs);
171 return true;
172 }
173 else if(_1_rhs){
174 binop->replace_all_uses_with(lhs);
175 return true;
176 }
177 }
178 return false;
179}
180
181bool peephole::rewrite_insert_extract(ir::instruction *value, ir::builder& builder){
182 auto extracted = dynamic_cast<ir::extract_value_inst*>(value);
183 if(!extracted)
184 return false;
185 size_t extract_idx = extracted->get_idx();
186 ir::value* agg = extracted->get_operand(0);
187 auto insert = dynamic_cast<ir::insert_value_inst*>(agg);
188 while(insert){
189 agg = insert->get_operand(0);
190 ir::value* inserted = insert->get_operand(1);
191 size_t insert_idx = insert->get_idx();
192 insert = dynamic_cast<ir::insert_value_inst*>(agg);
193 if(extract_idx == insert_idx){
194 extracted->replace_all_uses_with(inserted);
195 return true;
196 }
197 insert = dynamic_cast<ir::insert_value_inst*>(agg);
198 }
199 return false;
200}
201
202
203bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
204 auto x = dynamic_cast<ir::getelementptr_inst*>(value);
205 if(!x)
206 return false;
207 auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand());
208 if(!y)
209 return false;
210 auto idx = *y->idx_begin();
211 auto z = dynamic_cast<ir::binary_operator*>(idx);
212 if(!z)
213 return false;
214 bool is_sub = z->get_op() == ir::binary_op_t::Sub;
215 auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0));
216 bool is_lhs_0 = lhs && (lhs->get_value()==0);
217 bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin();
218 if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){
219 x->replace_all_uses_with(y->get_pointer_operand());
220 return true;
221 }
222 return false;
223}
224
225bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
226 auto select = dynamic_cast<ir::select_inst*>(value);
227 if(!select)
228 return false;
229 auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
230 if(!if_value)
231 return false;
232 if(select->get_pred_op() != if_value->get_mask_operand())
233 return false;
234 builder.set_insert_point(select);
235 ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
236 if_value->get_mask_operand(),
237 select->get_else_value_op(),
238 if_value->get_cache_modifier(),
239 if_value->get_eviction_policy(),
240 if_value->get_is_volatile());
241 select->replace_all_uses_with(new_load);
242 return true;
243}
244
245bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){
246 auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value);
247 if(!cvt)
248 return false;
249 ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0));
250 if(!op)
251 return false;
252// // convert(elementwise(x, y)) = elementwise(convert(x), convert(y))
253// if(op->get_id() == ir::INST_BINOP){
254// for(size_t i = 0; i < op->get_num_operands(); i++){
255// ir::value* arg_i = op->get_operand(i);
256// builder.set_insert_point(op);
257// // create new layout transform
258// ir::instruction* new_arg_i = cvt->clone();
259// layouts_->copy(new_arg_i, op);
260// builder.insert(new_arg_i);
261// // set the right args
262// new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i);
263// op->replace_uses_of_with(arg_i, new_arg_i);
264// }
265// cvt->replace_all_uses_with(op);
266// return true;
267// }
268 auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op);
269 if(!cvt_op)
270 return false;
271 // convert1(convert2(x)) if convert1 is the inverse of convert2
272 ir::value* op_op = cvt_op->get_operand(0);
273 if(layouts_->has(cvt) && layouts_->has(op_op) &&
274 layouts_->get(cvt) && layouts_->get(op_op)){
275 cvt->replace_all_uses_with(op_op);
276 return true;
277 }
278 return false;
279}
280
281void peephole::run(ir::module &mod) {
282 ir::builder &builder = mod.get_builder();
283 // keep track of whether any modification was made
284 std::set<ir::value*> seen;
285 size_t n_seen;
286
287 // rewrite dots first
288 do{
289 n_seen = seen.size();
290 for(ir::function *fn: mod.get_function_list())
291 for(ir::basic_block *block: fn->blocks())
292 for(ir::instruction* i: block->get_inst_list()){
293 if(seen.find(i) != seen.end())
294 continue;
295 bool was_modified = rewrite_dot(i, builder);
296 if(was_modified){
297 seen.insert(i);
298 }
299 }
300 }while(seen.size() != n_seen);
301
302 // rewrite other ops
303 seen.clear();
304 do{
305 n_seen = seen.size();
306 for(ir::function *fn: mod.get_function_list())
307 for(ir::basic_block *block: fn->blocks())
308 for(ir::instruction* i: block->get_inst_list()){
309 if(seen.find(i) != seen.end())
310 continue;
311 bool was_modified = false;
312 was_modified = was_modified || rewrite_mult(i, builder);
313 // was_modified = was_modified || rewrite_cts_cfs(i, builder);
314// was_modified = was_modified || rewrite_trans_phi(i, builder);
315 was_modified = was_modified || rewrite_insert_extract(i, builder);
316 was_modified = was_modified || rewrite_unit_red(i, builder);
317 was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
318 // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD
319// was_modified = was_modified || rewrite_select_masked_load(i, builder);
320 was_modified = was_modified || rewrite_cvt_layout(i, builder);
321 if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
322 was_modified = was_modified || rewrite_load_to_shared(i, builder);
323 if(was_modified)
324 seen.insert(i);
325 }
326 }while(seen.size() != n_seen);
327}
328
329}
330}
331}
332