1#include <iostream>
2#include <algorithm>
3#include "triton/codegen/transform/pipeline.h"
4#include "triton/ir/module.h"
5#include "triton/ir/function.h"
6#include "triton/ir/basic_block.h"
7#include "triton/ir/instructions.h"
8#include "triton/ir/utils.h"
9
10namespace triton {
11namespace codegen{
12namespace transform{
13
14
15void recursive_deps(ir::value* v, ir::basic_block* block, std::vector<ir::instruction*>& ret){
16 ir::instruction* i = dynamic_cast<ir::instruction*>(v);
17 if(!i || i->get_parent() != block)
18 return;
19 if(i->get_id()==ir::INST_PHI)
20 return;
21 ret.push_back(i);
22 for(ir::user* u: i->get_users())
23 recursive_deps(u, block, ret);
24}
25
26void get_induction_vars(ir::value* cond, std::set<ir::phi_node*>& phis) {
27 auto instr = dynamic_cast<ir::instruction*>(cond);
28 for (auto op : instr->ops()) {
29 if (auto phi_op = dynamic_cast<ir::phi_node*>(op)) {
30 phis.insert(phi_op);
31 return;
32 }
33 if (dynamic_cast<ir::instruction*>(op))
34 get_induction_vars(op, phis);
35 }
36}
37
38/// assume incoming block is 1
39ir::value* rematerialize_vals(ir::builder& builder, ir::basic_block* block, ir::value* v,
40 std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
41 ir::instruction* i = dynamic_cast<ir::instruction*>(v);
42 if(!i || i->get_parent() != block)
43 return v;
44 if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)) {
45 if (prev_phi_vals.find(phi) == prev_phi_vals.end())
46 throw std::runtime_error("Don't have that phi node\n");
47 return prev_phi_vals.at(phi);
48 }
49
50 std::vector<ir::value*> new_ops;
51 for(ir::value* op: i->ops()){
52 new_ops.push_back(rematerialize_vals(builder, block, op, prev_phi_vals));
53 }
54 ir::instruction* ret = i->clone();
55 for(size_t k = 0; k < new_ops.size(); k++)
56 ret->set_operand(k, new_ops[k]);
57 builder.insert(ret);
58 return ret;
59}
60
61ir::value* rematerialize(ir::builder& builder, ir::basic_block* block,
62 ir::value* v, size_t phi_idx){
63 ir::instruction* i = dynamic_cast<ir::instruction*>(v);
64 if(!i || i->get_parent() != block)
65 return v;
66 if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v))
67 return phi->get_incoming_value(phi_idx);
68
69 std::vector<ir::value*> new_ops;
70 for(ir::value* op: i->ops()){
71 new_ops.push_back(rematerialize(builder, block, op, phi_idx));
72 }
73 ir::instruction* ret = i->clone();
74 for(size_t k = 0; k < new_ops.size(); k++)
75 ret->set_operand(k, new_ops[k]);
76 builder.insert(ret);
77 return ret;
78}
79
80/// moving the prev phi vals to the next iteration
81std::map<ir::phi_node*, ir::value*> update_prev_phi_vals(
82 ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& prev_phi_vals) {
83 std::map<ir::phi_node*, ir::value*> next_phi_vals;
84 for (auto &[phi, val] : prev_phi_vals) {
85 next_phi_vals[phi] = rematerialize_vals(builder, block, phi->get_incoming_value(1), prev_phi_vals);
86 }
87 return next_phi_vals;
88}
89
90void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir::phi_node*, ir::value*>& load_ivs,
91 std::map<ir::phi_node*, ir::value*>& next_load_ivs) {
92 for (auto& [phi, val] : load_ivs) {
93 if (auto new_phi = dynamic_cast<ir::phi_node*>(val)) {
94 ir::value* next_k = rematerialize_vals(builder, block, phi->get_incoming_value(1), load_ivs);
95 assert(new_phi->get_num_operands() == 1 && "should be incomplete phi");
96 new_phi->add_incoming(next_k, phi->get_incoming_block(1));
97 // cache next_k (to be used by next_mask)
98 next_load_ivs[phi] = next_k;
99 } else
100 throw std::runtime_error("must be phi");
101 }
102}
103
104struct pipeline_info_t {
105 ir::load_inst* load;
106 ir::phi_node* ptr;
107 ir::dot_inst* dot;
108
109 pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
110 : load(load), ptr(ptr), dot(dot) {}
111};
112
113void pipeline::run(ir::module &mod) {
114 if (num_stages_ <= 1)
115 return;
116 // *Very* conservative heuristics for pre-fetching.
117 // A load instruction can be pipelined if:
118 // - the pointer is a phi node that references a value
119 // in its basic block (i.e., pointer induction variable)
120 // - the load has only a single use in a dot instruction
121 // As more use cases become apparent, this pass will be improved
122 std::vector<pipeline_info_t> to_pipeline;
123 ir::for_each_instruction(mod, [&](ir::instruction *i){
124 if(auto* load = dynamic_cast<ir::load_inst*>(i)){
125 ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
126 auto users = load->get_users();
127 auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
128 if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
129 && users.size() == 1 && dot)
130 to_pipeline.push_back({load, ptr, dot});
131 }});
132 // do the pipelining
133 std::vector<ir::phi_node*> new_loads;
134 ir::builder &builder = mod.get_builder();
135 const int num_stages = num_stages_;
136 std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
137
138 for(auto info: to_pipeline){
139 ir::load_inst* load = info.load;
140 ir::phi_node* ptr = info.ptr;
141 ir::basic_block* block = load->get_parent();
142 ir::basic_block* header = block->get_predecessors()[0];
143 auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
144 auto* header_br = dynamic_cast<ir::cond_branch_inst*>(header->get_inst_list().back());
145 assert(block_br);
146 assert(header_br);
147 ir::type* ty = load->get_type();
148 // multi-stage pipe
149 if (has_copy_async_ && num_stages > 2) {
150 ir::value* header_cond = header_br->get_cond();
151 ir::value* block_cond = block_br->get_cond();
152 // 1. collect induction variables
153 std::set<ir::phi_node*> induction_vars;
154 get_induction_vars(block_cond, induction_vars);
155
156 std::vector<ir::value*> first_ptrs(num_stages-1);
157 std::vector<ir::value*> first_loads(num_stages-1);
158 std::vector<ir::value*> first_masks(num_stages-1);
159 std::vector<ir::value*> loop_conds(num_stages-1);
160
161 std::map<ir::phi_node*, ir::value*> prev_phi_vals;
162 // initialize prev_phi_vals
163 // Add all phi nodes. The following DCE pass will delete dead ones.
164 for (ir::instruction *instr : block->get_inst_list())
165 if (auto *phi = dynamic_cast<ir::phi_node*>(instr))
166 if (phi->get_incoming_block(1) == block)
167 prev_phi_vals[phi] = phi->get_value_for_block(header);
168
169 builder.set_insert_point(header->get_inst_list().back());
170 first_ptrs[0] = ptr->get_value_for_block(header);
171 loop_conds[0] = header_cond;
172 first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes());
173 ir::value* false_value = nullptr;
174 if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
175 ir::value* remat_mask =rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals) ;
176 ir::value* remat_false_value =
177 rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
178 first_masks[0] = builder.create_and(first_masks[0], remat_mask);
179 false_value = remat_false_value;
180 } else
181 false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
182 first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
183
184 for (int stage = 1; stage < num_stages-1; ++stage) {
185 // mask is the loop condition of the previous iteration
186 loop_conds[stage] = rematerialize_vals(builder, block, block_cond, prev_phi_vals);
187 prev_phi_vals = update_prev_phi_vals(builder, block, prev_phi_vals);
188 first_ptrs[stage] = rematerialize_vals(builder, block, ptr, prev_phi_vals);
189 first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes());
190 if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
191 ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), prev_phi_vals);
192 ir::value* remat_false_value =
193 rematerialize_vals(builder, block, masked_load->get_false_value_operand(), prev_phi_vals);
194 first_masks[stage] = builder.create_and(first_masks[stage], remat_mask);
195 false_value = remat_false_value;
196 }
197 first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
198 }
199
200 // create new phis for induction variables
201 builder.set_insert_point(block->get_first_non_phi());
202 std::map<ir::phi_node*, ir::value*> load_ivs;
203 std::map<ir::phi_node*, ir::value*> next_load_ivs;
204 for (auto& [iv, val] : prev_phi_vals) {
205 ir::phi_node* pn = builder.create_phi(iv->get_type(), 2);
206 pn->add_incoming(prev_phi_vals[iv], header);
207 load_ivs[iv] = pn;
208 }
209 // add incoming for phis & update next_load_ivs
210 finalize_iv_vals(builder, block, load_ivs, next_load_ivs);
211
212 // pre-fetch next iteration
213 builder.set_insert_point(block->get_inst_list().back());
214// ir::value* next_ptr = ptr->get_value_for_block(block);
215 ir::value* next_ptr = rematerialize_vals(builder, block, ptr->get_value_for_block(block), load_ivs);
216 ir::value* next_mask = builder.create_splat(
217 rematerialize_vals(builder, block, block_cond, load_ivs), ty->get_block_shapes());
218 if (auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)) {
219 ir::value* remat_mask = rematerialize_vals(builder, block, masked_load->get_mask_operand(), next_load_ivs);
220 // TODO: false may depends on some other phi nodes
221 ir::value* remat_false_value =
222 rematerialize_vals(builder, block, masked_load->get_false_value_operand(), next_load_ivs);
223 next_mask = builder.create_and(next_mask, remat_mask);
224 false_value = remat_false_value;
225 }
226 ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
227
228
229 // phi node
230 ptr->set_incoming_value(0, first_ptrs.back());
231 builder.set_insert_point(block->get_first_non_phi());
232 // nested phis for load
233 std::vector<ir::phi_node*> new_load_phis(num_stages-1);
234 for (auto& pn : new_load_phis)
235 pn = builder.create_phi(ty, 2);
236 for (int i=0; i<num_stages-2; ++i) {
237 new_load_phis[i]->add_incoming(first_loads[i], header);
238 new_load_phis[i]->add_incoming(new_load_phis[i+1], block);
239 }
240 new_load_phis.back()->add_incoming(first_loads.back(), header);
241 new_load_phis.back()->add_incoming(next_load, block);
242 load->replace_all_uses_with(new_load_phis.front());
243 new_loads.push_back(new_load_phis.back());
244
245 // record first_loads to reorder them
246 preheader_loads.push_back({new_load_phis.front(), first_loads});
247 } else {
248 // pre-fetch first iteration
249 builder.set_insert_point(header->get_inst_list().back());
250 ir::value* first_ptr = ptr->get_value_for_block(header);
251 ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes());
252 ir::value* false_value;
253 if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
254 ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 0);
255 ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 0);
256 first_mask = builder.create_and(first_mask, remat_mask);
257 false_value = remat_false_value;
258 }
259 else
260 false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes());
261 ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
262 // pre-fetch next iteration
263 builder.set_insert_point(block->get_inst_list().back());
264 ir::value* next_ptr = ptr->get_value_for_block(block);
265 ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes());
266 if(auto* masked_load = dynamic_cast<ir::masked_load_inst*>(load)){
267 ir::value* remat_mask = rematerialize(builder, block, masked_load->get_mask_operand(), 1);
268 ir::value* remat_false_value = rematerialize(builder, block, masked_load->get_false_value_operand(), 1);
269 next_mask = builder.create_and(next_mask, remat_mask);
270 false_value = remat_false_value;
271 }
272 ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value, load->get_cache_modifier(), load->get_eviction_policy(), load->get_is_volatile());
273 // phi node
274 builder.set_insert_point(block->get_first_non_phi());
275 ir::phi_node* new_load = builder.create_phi(ty, 2);
276 new_load->add_incoming(first_load, header);
277 new_load->add_incoming(next_load, block);
278 load->replace_all_uses_with(new_load);
279 new_loads.push_back(new_load);
280 }
281 }
282
283 // try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to
284 // a0, b0, a1, b1, ...
285 if (!preheader_loads.empty()) {
286 ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0);
287 builder.set_insert_point(header->get_inst_list().back());
288 for (int i=1; i<num_stages-1; ++i) {
289 for (auto iter = preheader_loads.begin(); iter != preheader_loads.end(); ++iter) {
290 ir::instruction* original_load = static_cast<ir::instruction*>(iter->second.at(i));
291 ir::instruction* moved_load = original_load->clone();
292 builder.insert(moved_load);
293 original_load->replace_all_uses_with(moved_load);
294 }
295 }
296 }
297
298 // try to move dot_inst after loads
299 // for better overlap of io and compute
300 struct move_config_t{
301 std::vector<ir::instruction*> insts;
302 ir::load_inst* dst;
303 };
304 std::vector<move_config_t> to_move(to_pipeline.size());
305
306 if(has_copy_async_){
307 for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
308 auto info = to_pipeline[idx];
309 ir::load_inst* load = info.load;
310 ir::phi_node* ptr = info.ptr;
311 ir::dot_inst* dot = info.dot;
312 ir::basic_block* bb = dot->get_parent();
313 recursive_deps(dot, bb, to_move[idx].insts);
314 to_move[idx].dst = load;
315 }
316
317 for(auto& move_config: to_move){
318 builder.set_insert_point_after(move_config.dst);
319 for(ir::instruction* i: move_config.insts){
320 i->get_parent()->erase(i);
321 builder.insert(i);
322 }
323 }
324 }
325
326
327}
328
329}
330}
331}
332