1 | #include <algorithm> |
2 | #include <iostream> |
3 | #include "triton/ir/module.h" |
4 | #include "triton/ir/function.h" |
5 | #include "triton/codegen/transform/peephole.h" |
6 | #include "triton/codegen/analysis/layout.h" |
7 | |
8 | namespace triton { |
9 | namespace codegen{ |
10 | namespace transform{ |
11 | |
12 | |
13 | ir::value* rewrite_trans_phi_impl(ir::value *value, ir::builder &builder, |
14 | const std::vector<int>& perm) { |
15 | if(auto phi = dynamic_cast<ir::phi_node*>(value)) { |
16 | // transpose operands |
17 | std::vector<ir::value*> incs; |
18 | for(unsigned n = 0; n < phi->get_num_incoming(); n++) |
19 | incs.push_back(rewrite_trans_phi_impl(phi->get_incoming_value(n), builder, perm)); |
20 | // create phi for transposed values |
21 | builder.set_insert_point(phi); |
22 | ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size()); |
23 | for(unsigned n = 0; n < phi->get_num_incoming(); n++) |
24 | result->add_incoming(incs[n], phi->get_incoming_block(n)); |
25 | return result; |
26 | } |
27 | else if(auto i = dynamic_cast<ir::instruction*>(value)){ |
28 | ir::basic_block* block = i->get_parent(); |
29 | auto it = std::find(block->begin(), block->end(), i); |
30 | it++; |
31 | builder.set_insert_point(it); |
32 | ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm); |
33 | trans->set_operand(0, i); |
34 | return trans; |
35 | } |
36 | return nullptr; |
37 | } |
38 | |
39 | bool peephole::rewrite_trans_phi(ir::instruction* value, ir::builder& builder) { |
40 | auto trans = dynamic_cast<ir::trans_inst*>(value); |
41 | if(!trans) |
42 | return false; |
43 | auto users = trans->get_users(); |
44 | auto ops = trans->ops(); |
45 | if(users.size() > 1 || ops.size() > 1) |
46 | return false; |
47 | ir::value* op = *ops.begin(); |
48 | // trans(phi) -> phi(trans(), trans()...) |
49 | auto* phi = dynamic_cast<ir::phi_node*>(op); |
50 | if(!phi) |
51 | return false; |
52 | ir::value* new_phi = rewrite_trans_phi_impl(phi, builder, trans->get_perm()); |
53 | if(!new_phi) |
54 | return false; |
55 | trans->replace_all_uses_with(new_phi); |
56 | |
57 | return true; |
58 | } |
59 | |
60 | bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ |
61 | // dot(a, b, c) + d -> dot(a, b, c + d) |
62 | // d + dot(a, b, c) -> dot(a, b, c + d) |
63 | auto add = dynamic_cast<ir::binary_operator*>(value); |
64 | if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) { |
65 | bool is_int_dot = add->get_op() == ir::binary_op_t::Add; |
66 | ir::value *lhs = add->get_operand(0); |
67 | ir::value *rhs = add->get_operand(1); |
68 | ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs); |
69 | ir::dot_inst *rhs_dot = dynamic_cast<ir::dot_inst*>(rhs); |
70 | if(!lhs_dot && !rhs_dot) |
71 | return false; |
72 | ir::dot_inst *dot = lhs_dot ? lhs_dot : rhs_dot; |
73 | ir::value *other = (dot == lhs) ? rhs : lhs; |
74 | ir::value *acc = dot->get_operand(2); |
75 | ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc); |
76 | ir::constant *_0 = nullptr; |
77 | if(splat) |
78 | _0 = dynamic_cast<ir::constant*>(splat->get_operand(0)); |
79 | if(!_0) |
80 | return false; |
81 | if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0)) |
82 | if (fp_0->get_value() != 0.0) |
83 | return false; |
84 | if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0)) |
85 | if (int_0->get_value() != 0) |
86 | return false; |
87 | ir::value *a = dot->get_operand(0); |
88 | ir::value *b = dot->get_operand(1); |
89 | builder.set_insert_point(add); |
90 | ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name())); |
91 | add->replace_all_uses_with(new_dot); |
92 | return true; |
93 | } |
94 | return false; |
95 | } |
96 | |
97 | //bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){ |
98 | // auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value); |
99 | // if(cfs) { |
100 | // ir::value *arg = cfs->get_operand(0); |
101 | // ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg); |
102 | // if(!cts) |
103 | // return false; |
104 | // cfs->replace_all_uses_with(cts->get_operand(0)); |
105 | // return true; |
106 | // } |
107 | |
108 | //} |
109 | |
110 | bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){ |
111 | auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value); |
112 | if(!copy_to_shared) |
113 | return false; |
114 | ir::value *arg = copy_to_shared->get_operand(0); |
115 | ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg); |
116 | if(!ld) |
117 | return false; |
118 | builder.set_insert_point(copy_to_shared); |
119 | ir::value *ptr = ld->get_pointer_operand(); |
120 | ir::value *msk = ld->get_mask_operand(); |
121 | ir::value *val = ld->get_false_value_operand(); |
122 | analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); |
123 | int nts = layout->nts(layout->get_order()[0]); |
124 | int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; |
125 | if(nts*dtsize >= 4){ |
126 | ir::value* new_load = builder.create_masked_load_async(ptr, msk, val, ld->get_cache_modifier(), ld->get_eviction_policy()); |
127 | copy_to_shared->replace_all_uses_with(new_load); |
128 | return true; |
129 | } |
130 | return false; |
131 | // analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); |
132 | // std::cout << layout->nts(layout->get_order(0)) << std::endl; |
133 | // return true; |
134 | |
135 | } |
136 | |
137 | bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ |
138 | auto x = dynamic_cast<ir::reduce_inst*>(value); |
139 | if(!x) |
140 | return false; |
141 | ir::value *arg = x->get_operand(0); |
142 | auto shapes = arg->get_type()->get_block_shapes(); |
143 | if(shapes[x->get_axis()] == 1){ |
144 | builder.set_insert_point(x); |
145 | ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_block_shapes()); |
146 | x->replace_all_uses_with(new_red); |
147 | return true; |
148 | } |
149 | return false; |
150 | } |
151 | |
152 | bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) { |
153 | auto binop = dynamic_cast<ir::binary_operator*>(value); |
154 | if(binop && binop->get_op() == ir::binary_op_t::Mul) { |
155 | ir::value *lhs = binop->get_operand(0); |
156 | ir::value *rhs = binop->get_operand(1); |
157 | ir::constant_int *_1_lhs = nullptr; |
158 | if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs)){ |
159 | auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0)); |
160 | if(cst && cst->get_value() == 1) |
161 | _1_lhs = cst; |
162 | } |
163 | ir::constant_int *_1_rhs = nullptr; |
164 | if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs)){ |
165 | auto *cst = dynamic_cast<ir::constant_int*>(splat->get_operand(0)); |
166 | if(cst && cst->get_value() == 1) |
167 | _1_rhs = cst; |
168 | } |
169 | if(_1_lhs){ |
170 | binop->replace_all_uses_with(rhs); |
171 | return true; |
172 | } |
173 | else if(_1_rhs){ |
174 | binop->replace_all_uses_with(lhs); |
175 | return true; |
176 | } |
177 | } |
178 | return false; |
179 | } |
180 | |
181 | bool peephole::(ir::instruction *value, ir::builder& builder){ |
182 | auto = dynamic_cast<ir::extract_value_inst*>(value); |
183 | if(!extracted) |
184 | return false; |
185 | size_t = extracted->get_idx(); |
186 | ir::value* agg = extracted->get_operand(0); |
187 | auto insert = dynamic_cast<ir::insert_value_inst*>(agg); |
188 | while(insert){ |
189 | agg = insert->get_operand(0); |
190 | ir::value* inserted = insert->get_operand(1); |
191 | size_t insert_idx = insert->get_idx(); |
192 | insert = dynamic_cast<ir::insert_value_inst*>(agg); |
193 | if(extract_idx == insert_idx){ |
194 | extracted->replace_all_uses_with(inserted); |
195 | return true; |
196 | } |
197 | insert = dynamic_cast<ir::insert_value_inst*>(agg); |
198 | } |
199 | return false; |
200 | } |
201 | |
202 | |
203 | bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) { |
204 | auto x = dynamic_cast<ir::getelementptr_inst*>(value); |
205 | if(!x) |
206 | return false; |
207 | auto y = dynamic_cast<ir::getelementptr_inst*>(x->get_pointer_operand()); |
208 | if(!y) |
209 | return false; |
210 | auto idx = *y->idx_begin(); |
211 | auto z = dynamic_cast<ir::binary_operator*>(idx); |
212 | if(!z) |
213 | return false; |
214 | bool is_sub = z->get_op() == ir::binary_op_t::Sub; |
215 | auto *lhs = dynamic_cast<ir::constant_int*>(z->get_operand(0)); |
216 | bool is_lhs_0 = lhs && (lhs->get_value()==0); |
217 | bool is_rhs_eq_x_rhs = z->get_operand(1) == *x->idx_begin(); |
218 | if(is_sub && is_lhs_0 && is_rhs_eq_x_rhs){ |
219 | x->replace_all_uses_with(y->get_pointer_operand()); |
220 | return true; |
221 | } |
222 | return false; |
223 | } |
224 | |
225 | bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){ |
226 | auto select = dynamic_cast<ir::select_inst*>(value); |
227 | if(!select) |
228 | return false; |
229 | auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op()); |
230 | if(!if_value) |
231 | return false; |
232 | if(select->get_pred_op() != if_value->get_mask_operand()) |
233 | return false; |
234 | builder.set_insert_point(select); |
235 | ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(), |
236 | if_value->get_mask_operand(), |
237 | select->get_else_value_op(), |
238 | if_value->get_cache_modifier(), |
239 | if_value->get_eviction_policy(), |
240 | if_value->get_is_volatile()); |
241 | select->replace_all_uses_with(new_load); |
242 | return true; |
243 | } |
244 | |
245 | bool peephole::rewrite_cvt_layout(ir::instruction *value, ir::builder& builder){ |
246 | auto cvt = dynamic_cast<ir::cvt_layout_inst*>(value); |
247 | if(!cvt) |
248 | return false; |
249 | ir::instruction* op = dynamic_cast<ir::instruction*>(cvt->get_operand(0)); |
250 | if(!op) |
251 | return false; |
252 | // // convert(elementwise(x, y)) = elementwise(convert(x), convert(y)) |
253 | // if(op->get_id() == ir::INST_BINOP){ |
254 | // for(size_t i = 0; i < op->get_num_operands(); i++){ |
255 | // ir::value* arg_i = op->get_operand(i); |
256 | // builder.set_insert_point(op); |
257 | // // create new layout transform |
258 | // ir::instruction* new_arg_i = cvt->clone(); |
259 | // layouts_->copy(new_arg_i, op); |
260 | // builder.insert(new_arg_i); |
261 | // // set the right args |
262 | // new_arg_i->replace_uses_of_with(new_arg_i->get_operand(0), arg_i); |
263 | // op->replace_uses_of_with(arg_i, new_arg_i); |
264 | // } |
265 | // cvt->replace_all_uses_with(op); |
266 | // return true; |
267 | // } |
268 | auto cvt_op = dynamic_cast<ir::cvt_layout_inst*>(op); |
269 | if(!cvt_op) |
270 | return false; |
271 | // convert1(convert2(x)) if convert1 is the inverse of convert2 |
272 | ir::value* op_op = cvt_op->get_operand(0); |
273 | if(layouts_->has(cvt) && layouts_->has(op_op) && |
274 | layouts_->get(cvt) && layouts_->get(op_op)){ |
275 | cvt->replace_all_uses_with(op_op); |
276 | return true; |
277 | } |
278 | return false; |
279 | } |
280 | |
281 | void peephole::run(ir::module &mod) { |
282 | ir::builder &builder = mod.get_builder(); |
283 | // keep track of whether any modification was made |
284 | std::set<ir::value*> seen; |
285 | size_t n_seen; |
286 | |
287 | // rewrite dots first |
288 | do{ |
289 | n_seen = seen.size(); |
290 | for(ir::function *fn: mod.get_function_list()) |
291 | for(ir::basic_block *block: fn->blocks()) |
292 | for(ir::instruction* i: block->get_inst_list()){ |
293 | if(seen.find(i) != seen.end()) |
294 | continue; |
295 | bool was_modified = rewrite_dot(i, builder); |
296 | if(was_modified){ |
297 | seen.insert(i); |
298 | } |
299 | } |
300 | }while(seen.size() != n_seen); |
301 | |
302 | // rewrite other ops |
303 | seen.clear(); |
304 | do{ |
305 | n_seen = seen.size(); |
306 | for(ir::function *fn: mod.get_function_list()) |
307 | for(ir::basic_block *block: fn->blocks()) |
308 | for(ir::instruction* i: block->get_inst_list()){ |
309 | if(seen.find(i) != seen.end()) |
310 | continue; |
311 | bool was_modified = false; |
312 | was_modified = was_modified || rewrite_mult(i, builder); |
313 | // was_modified = was_modified || rewrite_cts_cfs(i, builder); |
314 | // was_modified = was_modified || rewrite_trans_phi(i, builder); |
315 | was_modified = was_modified || rewrite_insert_extract(i, builder); |
316 | was_modified = was_modified || rewrite_unit_red(i, builder); |
317 | was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); |
318 | // TODO: DOESN'T WORK FOR VECTORIZED MASKED LOAD |
319 | // was_modified = was_modified || rewrite_select_masked_load(i, builder); |
320 | was_modified = was_modified || rewrite_cvt_layout(i, builder); |
321 | if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) |
322 | was_modified = was_modified || rewrite_load_to_shared(i, builder); |
323 | if(was_modified) |
324 | seen.insert(i); |
325 | } |
326 | }while(seen.size() != n_seen); |
327 | } |
328 | |
329 | } |
330 | } |
331 | } |
332 | |