1 | #include <iostream> |
2 | #include <algorithm> |
3 | #include "triton/codegen/transform/pipeline.h" |
4 | #include "triton/ir/module.h" |
5 | #include "triton/ir/function.h" |
6 | #include "triton/ir/basic_block.h" |
7 | #include "triton/ir/instructions.h" |
8 | #include "triton/ir/utils.h" |
9 | |
10 | namespace triton { |
11 | namespace codegen{ |
12 | namespace transform{ |
13 | |
14 | |
15 | void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){ |
16 | ir::instruction* i = dynamic_cast<ir::instruction*>(v); |
17 | if(!i || i->get_parent() != block) |
18 | return; |
19 | if(i->get_id()==ir::INST_PHI) |
20 | return; |
21 | ret.push_back(i); |
22 | for(ir::user* u: i->get_users()) |
23 | recursive_deps(u, block, ret); |
24 | } |
25 | |
26 | void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) { |
27 | auto instr = dynamic_cast<ir::instruction*>(cond); |
28 | for (auto op : instr->ops()) { |
29 | if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) { |
30 | phis.insert(phi_op); |
31 | return; |
32 | } |
33 | if (dynamic_cast<ir::instruction*>(op)) |
34 | get_induction_vars(op, phis); |
35 | } |
36 | } |
37 | |
38 | /// assume incoming block is 1 |
39 | ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v, |
40 | std::map<ir::phi_node*, ir::value*>& prev_phi_vals) { |
41 | ir::instruction* i = dynamic_cast<ir::instruction*>(v); |
42 | if(!i || i->get_parent() != block) |
43 | return v; |
44 | if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) { |
45 | if (prev_phi_vals.find(phi) == prev_phi_vals.end()) |
46 | throw std::runtime_error("Don't have that phi node\n" ); |
47 | return prev_phi_vals.at(phi); |
48 | } |
49 | |
50 | std::vector<ir::value*> new_ops; |
51 | for(ir::value* op: i->ops()){ |
52 | new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals)); |
53 | } |
54 | ir::instruction* ret = i->clone(); |
55 | for(size_t k = 0; k < new_ops.size(); k++) |
56 | ret->set_operand(k, new_ops[k]); |
57 | builder.insert(ret); |
58 | return ret; |
59 | } |
60 | |
61 | ir::value* rematerialize(ir::builder& builder, ir::basic_block* block, |
62 | ir::value* v, size_t phi_idx){ |
63 | ir::instruction* i = dynamic_cast<ir::instruction*>(v); |
64 | if(!i || i->get_parent() != block) |
65 | return v; |
66 | if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) |
67 | return phi->get_incoming_value(phi_idx); |
68 | |
69 | std::vector<ir::value*> new_ops; |
70 | for(ir::value* op: i->ops()){ |
71 | new_ops.push_back(rematerialize(builder, block, op, phi_idx)); |
72 | } |
73 | ir::instruction* ret = i->clone(); |
74 | for(size_t k = 0; k < new_ops.size(); k++) |
75 | ret->set_operand(k, new_ops[k]); |
76 | builder.insert(ret); |
77 | return ret; |
78 | } |
79 | |
80 | /// moving the prev phi vals to the next iteration |
81 | std::map<ir::phi_node*, ir::value*> update_prev_phi_vals( |
82 | ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) { |
83 | std::map<ir::phi_node*, ir::value*> next_phi_vals; |
84 | for (auto &[phi, val] : prev_phi_vals) { |
85 | next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals); |
86 | } |
87 | return next_phi_vals; |
88 | } |
89 | |
90 | void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs, |
91 | std::map<ir::phi_node*, ir::value*>& next_load_ivs) { |
92 | for (auto& [phi, val] : load_ivs) { |
93 | if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) { |
94 | ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs); |
95 | assert(new_phi->get_num_operands() == 1 && "should be incomplete phi" ); |
96 | new_phi->add_incoming(next_k, phi->get_incoming_block(1)); |
97 | // cache next_k (to be used by next_mask) |
98 | next_load_ivs[phi] = next_k; |
99 | } else |
100 | throw std::runtime_error("must be phi" ); |
101 | } |
102 | } |
103 | |
// One pipelining candidate: a load whose pointer is an in-block phi
// (pointer induction variable) and whose single user is a dot instruction.
struct pipeline_info_t {
  ir::load_inst* load;  // the load to pre-fetch
  ir::phi_node* ptr;    // its pointer induction variable
  ir::dot_inst* dot;    // the sole consumer of the load

  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
    : load(load), ptr(ptr), dot(dot) {}
};
112 | |
113 | void pipeline::run(ir::module &mod) { |
114 | if (num_stages_ <= 1) |
115 | return; |
116 | // *Very* conservative heuristics for pre-fetching. |
117 | // A load instruction can be pipelined if: |
118 | // - the pointer is a phi node that references a value |
119 | // in its basic block (i.e., pointer induction variable) |
120 | // - the load has only a single use in a dot instruction |
121 | // As more use cases become apparent, this pass will be improved |
122 | std::vector<pipeline_info_t> to_pipeline; |
123 | ir::for_each_instruction(mod, [&](ir::instruction *i){ |
124 | if(auto* load = dynamic_cast<ir::load_inst*>(i)){ |
125 | ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand()); |
126 | auto users = load->get_users(); |
127 | auto dot = dynamic_cast<ir::dot_inst*>(*users.begin()); |
128 | if(ptr && ptr->get_incoming_block(1) == ptr->get_parent() |
129 | && users.size() == 1 && dot) |
130 | to_pipeline.push_back({load, ptr, dot}); |
131 | }}); |
132 | // do the pipelining |
133 | std::vector<ir::phi_node*> new_loads; |
134 | ir::builder &builder = mod.get_builder(); |
135 | const int num_stages = num_stages_; |
136 | std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> ; // Used to reorder loads |
137 | |
138 | for(auto info: to_pipeline){ |
139 | ir::load_inst* load = info.load; |
140 | ir::phi_node* ptr = info.ptr; |
141 | ir::basic_block* block = load->get_parent(); |
142 | ir::basic_block* = block->get_predecessors()[0]; |
143 | auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back()); |
144 | auto* = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back()); |
145 | assert(block_br); |
146 | assert(header_br); |
147 | ir::type* ty = load->get_type(); |
148 | // multi-stage pipe |
149 | if (has_copy_async_ && num_stages > 2) { |
150 | ir::value* = header_br->get_cond(); |
151 | ir::value* block_cond = block_br->get_cond(); |
152 | // 1. collect induction variables |
153 | std::set<ir::phi_node*> induction_vars; |
154 | get_induction_vars(block_cond, induction_vars); |
155 | |
156 | std::vector<ir::value*> first_ptrs(num_stages-1); |
157 | std::vector<ir::value*> first_loads(num_stages-1); |
158 | std::vector<ir::value*> first_masks(num_stages-1); |
159 | std::vector<ir::value*> loop_conds(num_stages-1); |
160 | |
161 | std::map<ir::phi_node*, ir::value*> prev_phi_vals; |
162 | // initialize prev_phi_vals |
163 | // Add all phi nodes. The following DCE pass will delete dead ones. |
164 | for (ir::instruction *instr : block->get_inst_list()) |
165 | if (auto *phi = dynamic_cast<ir::phi_node*>(instr)) |
166 | if (phi->get_incoming_block(1) == block) |
167 | prev_phi_vals[phi] = phi->get_value_for_block(header); |
168 | |
169 | builder.set_insert_point(header->get_inst_list().back()); |
170 | first_ptrs[0] = ptr->get_value_for_block(header); |
171 | loop_conds[0] = header_cond; |
172 | first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes()); |
173 | ir::value* false_value = nullptr; |
174 | if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) { |
175 | ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ; |
176 | ir::value* remat_false_value = |
177 | rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals); |
178 | first_masks[0] = builder.create_and(first_masks[0], remat_mask); |
179 | false_value = remat_false_value; |
180 | } else |
181 | false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); |
182 | first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); |
183 | |
184 | for (int stage = 1; stage < num_stages-1; ++stage) { |
185 | // mask is the loop condition of the previous iteration |
186 | loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals); |
187 | prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals); |
188 | first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals); |
189 | first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes()); |
190 | if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) { |
191 | ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals); |
192 | ir::value* remat_false_value = |
193 | rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals); |
194 | first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); |
195 | false_value = remat_false_value; |
196 | } |
197 | first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); |
198 | } |
199 | |
200 | // create new phis for induction variables |
201 | builder.set_insert_point(block->get_first_non_phi()); |
202 | std::map<ir::phi_node*, ir::value*> load_ivs; |
203 | std::map<ir::phi_node*, ir::value*> next_load_ivs; |
204 | for (auto& [iv, val] : prev_phi_vals) { |
205 | ir::phi_node* pn = builder.create_phi(iv->get_type(), 2); |
206 | pn->add_incoming(prev_phi_vals[iv], header); |
207 | load_ivs[iv] = pn; |
208 | } |
209 | // add incoming for phis & update next_load_ivs |
210 | finalize_iv_vals(builder, block, load_ivs, next_load_ivs); |
211 | |
212 | // pre-fetch next iteration |
213 | builder.set_insert_point(block->get_inst_list().back()); |
214 | // ir::value* next_ptr = ptr->get_value_for_block(block); |
215 | ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs); |
216 | ir::value* next_mask = builder.create_splat( |
217 | rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes()); |
218 | if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) { |
219 | ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs); |
220 | // TODO: false may depends on some other phi nodes |
221 | ir::value* remat_false_value = |
222 | rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs); |
223 | next_mask = builder.create_and(next_mask, remat_mask); |
224 | false_value = remat_false_value; |
225 | } |
226 | ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); |
227 | |
228 | |
229 | // phi node |
230 | ptr->set_incoming_value(0, first_ptrs.back()); |
231 | builder.set_insert_point(block->get_first_non_phi()); |
232 | // nested phis for load |
233 | std::vector<ir::phi_node*> new_load_phis(num_stages-1); |
234 | for (auto& pn : new_load_phis) |
235 | pn = builder.create_phi(ty, 2); |
236 | for (int i=0; i<num_stages-2; ++i) { |
237 | new_load_phis[i]->add_incoming(first_loads[i], header); |
238 | new_load_phis[i]->add_incoming(new_load_phis[i+1], block); |
239 | } |
240 | new_load_phis.back()->add_incoming(first_loads.back(), header); |
241 | new_load_phis.back()->add_incoming(next_load, block); |
242 | load->replace_all_uses_with(new_load_phis.front()); |
243 | new_loads.push_back(new_load_phis.back()); |
244 | |
245 | // record first_loads to reorder them |
246 | preheader_loads.push_back({new_load_phis.front(), first_loads}); |
247 | } else { |
248 | // pre-fetch first iteration |
249 | builder.set_insert_point(header->get_inst_list().back()); |
250 | ir::value* first_ptr = ptr->get_value_for_block(header); |
251 | ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes()); |
252 | ir::value* false_value; |
253 | if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){ |
254 | ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0); |
255 | ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0); |
256 | first_mask = builder.create_and(first_mask, remat_mask); |
257 | false_value = remat_false_value; |
258 | } |
259 | else |
260 | false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); |
261 | ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); |
262 | // pre-fetch next iteration |
263 | builder.set_insert_point(block->get_inst_list().back()); |
264 | ir::value* next_ptr = ptr->get_value_for_block(block); |
265 | ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes()); |
266 | if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){ |
267 | ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1); |
268 | ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1); |
269 | next_mask = builder.create_and(next_mask, remat_mask); |
270 | false_value = remat_false_value; |
271 | } |
272 | ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile()); |
273 | // phi node |
274 | builder.set_insert_point(block->get_first_non_phi()); |
275 | ir::phi_node* new_load = builder.create_phi(ty, 2); |
276 | new_load->add_incoming(first_load, header); |
277 | new_load->add_incoming(next_load, block); |
278 | load->replace_all_uses_with(new_load); |
279 | new_loads.push_back(new_load); |
280 | } |
281 | } |
282 | |
283 | // try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to |
284 | // a0, b0, a1, b1, ... |
285 | if (!preheader_loads.empty()) { |
286 | ir::basic_block* = preheader_loads.begin()->first->get_incoming_block(0); |
287 | builder.set_insert_point(header->get_inst_list().back()); |
288 | for (int i=1; i<num_stages-1; ++i) { |
289 | for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) { |
290 | ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i)); |
291 | ir::instruction* moved_load = original_load->clone(); |
292 | builder.insert(moved_load); |
293 | original_load->replace_all_uses_with(moved_load); |
294 | } |
295 | } |
296 | } |
297 | |
298 | // try to move dot_inst after loads |
299 | // for better overlap of io and compute |
300 | struct move_config_t{ |
301 | std::vector<ir::instruction*> insts; |
302 | ir::load_inst* dst; |
303 | }; |
304 | std::vector<move_config_t> to_move(to_pipeline.size()); |
305 | |
306 | if(has_copy_async_){ |
307 | for (size_t idx = 0; idx < to_pipeline.size(); ++idx) { |
308 | auto info = to_pipeline[idx]; |
309 | ir::load_inst* load = info.load; |
310 | ir::phi_node* ptr = info.ptr; |
311 | ir::dot_inst* dot = info.dot; |
312 | ir::basic_block* bb = dot->get_parent(); |
313 | recursive_deps(dot, bb, to_move[idx].insts); |
314 | to_move[idx].dst = load; |
315 | } |
316 | |
317 | for(auto& move_config: to_move){ |
318 | builder.set_insert_point_after(move_config.dst); |
319 | for(ir::instruction* i: move_config.insts){ |
320 | i->get_parent()->erase(i); |
321 | builder.insert(i); |
322 | } |
323 | } |
324 | } |
325 | |
326 | |
327 | } |
328 | |
329 | } |
330 | } |
331 | } |
332 | |