#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <vector>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
13 | |
14 | namespace triton { |
15 | |
16 | namespace codegen{ |
17 | namespace transform{ |
18 | |
19 | |
20 | |
21 | int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) { |
22 | if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){ |
23 | analysis::shared_layout* layout = layouts_->get(v)->to_shared(); |
24 | if (analysis::double_buffer_info_t* info = layout->get_double_buffer()) |
25 | return group_of(info->first, async_write); |
26 | else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) { |
27 | if (v == info->phi) |
28 | return group_of(info->firsts[0], async_write); |
29 | else // prefetched value |
30 | return group_of(info->firsts[1], async_write); |
31 | } |
32 | std::vector<int> groups(phi->get_num_operands()); |
33 | std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); |
34 | return *std::max_element(groups.begin(), groups.end()); |
35 | } |
36 | else{ |
37 | if(layouts_->has_tmp(v)) |
38 | return async_write.size() - 1; |
39 | // // Ignore copy_to_shared. It won't modify async behavior. |
40 | // if(dynamic_cast<ir::copy_to_shared_inst*>(v)) |
41 | // return 0; |
42 | auto it = std::find(async_write.begin(), async_write.end(), v); |
43 | return std::distance(async_write.begin(), it); |
44 | } |
45 | } |
46 | |
47 | inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) { |
48 | if(!a_layout || !b_layout) |
49 | return false; |
50 | int a_start = alloc_->offset(a_layout); |
51 | int a_end = a_start + a_layout->get_size(); |
52 | int b_start = alloc_->offset(b_layout); |
53 | int b_end = b_start + b_layout->get_size(); |
54 | if(a_start < b_end || b_start < a_end) |
55 | return true; |
56 | return false; |
57 | } |
58 | |
59 | membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) { |
60 | val_set_t ret; |
61 | for(ir::value* a: as){ |
62 | if(!a->get_type()->is_block_ty()) |
63 | continue; |
64 | analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); |
65 | analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr; |
66 | analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr; |
67 | for(ir::value* b: bs){ |
68 | if(!b->get_type()->is_block_ty()) |
69 | continue; |
70 | analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); |
71 | analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr; |
72 | analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr; |
73 | if(intersect_with(a_layout, b_layout) || |
74 | intersect_with(a_layout, b_tmp) || |
75 | intersect_with(a_layout, b_tmp_index) || |
76 | intersect_with(a_tmp, b_layout) || |
77 | intersect_with(a_tmp, b_tmp) || |
78 | intersect_with(a_tmp, b_tmp_index) || |
79 | intersect_with(a_tmp_index, b_layout) || |
80 | intersect_with(a_tmp_index, b_tmp) || |
81 | intersect_with(a_tmp_index, b_tmp_index)) |
82 | ret.insert(b); |
83 | } |
84 | } |
85 | return ret; |
86 | } |
87 | |
88 | bool membar::check_safe_war(ir::instruction* i) { |
89 | bool is_i_shared_block = i->get_type()->is_block_ty() && |
90 | layouts_->get(i)->to_shared(); |
91 | bool is_i_double_buffered = is_i_shared_block && |
92 | layouts_->get(i)->to_shared()->get_double_buffer(); |
93 | bool is_i_n_buffered = is_i_shared_block && |
94 | layouts_->get(i)->to_shared()->get_N_buffer(); |
95 | |
96 | if (is_i_double_buffered || is_i_n_buffered) { |
97 | // with async copy & prefetch_s disabled, WARs are not safe |
98 | if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i)) |
99 | return false; |
100 | else |
101 | return true; |
102 | } |
103 | return false; |
104 | } |
105 | |
106 | void membar::transfer(ir::basic_block *block, |
107 | val_vec_t& async_write, |
108 | val_set_t& sync_write, |
109 | val_set_t& sync_read, |
110 | std::set<ir::value*>& safe_war, |
111 | bool& inserted, ir::builder& builder) { |
112 | std::vector<ir::async_wait_inst*> async_waits; |
113 | ir::basic_block::inst_list_t instructions = block->get_inst_list(); |
114 | for(ir::instruction *i: instructions){ |
115 | if(dynamic_cast<ir::phi_node*>(i)) |
116 | continue; |
117 | if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() && |
118 | dynamic_cast<ir::masked_load_async_inst*>(i)){ |
119 | async_write.push_back(i); |
120 | } |
121 | if(dynamic_cast<ir::copy_to_shared_inst*>(i)) |
122 | sync_write.insert(i); |
123 | ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i); |
124 | ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i); |
125 | // Get shared memory reads |
126 | std::set<ir::value*> read; |
127 | std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()), |
128 | [&](ir::value* i){ return i->get_type()->is_block_ty() && layouts_->get(i)->to_shared();}); |
129 | if(layouts_->has_tmp(i)) |
130 | read.insert(i); |
131 | // RAW (async) |
132 | val_set_t tmp; |
133 | std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin())); |
134 | if(intersect_with(read, tmp).size()){ |
135 | std::vector<int> groups(read.size()); |
136 | std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); |
137 | int N = *std::max_element(groups.begin(), groups.end()); |
138 | if(N < async_write.size()){ |
139 | builder.set_insert_point(i); |
140 | async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N); |
141 | barrier = (ir::barrier_inst*)builder.create_barrier(); |
142 | inserted = true; |
143 | async_waits.push_back(async_wait); |
144 | } |
145 | } |
146 | // RAW, WAR |
147 | bool is_safe_war = check_safe_war(i); |
148 | // WAR barrier is not required when data is double-buffered |
149 | if(!intersect_with(read, sync_write).empty() || |
150 | (!intersect_with({i}, sync_read).empty() && !is_safe_war)) { |
151 | builder.set_insert_point(i); |
152 | barrier = (ir::barrier_inst*)builder.create_barrier(); |
153 | inserted = true; |
154 | } |
155 | // update state of asynchronous copies |
156 | if(async_wait){ |
157 | int N = async_write.size() - async_wait->get_N(); |
158 | async_write.erase(async_write.begin(), async_write.begin() + N); |
159 | } |
160 | // all the copy_to_shared and read from shared are synchronized after barrier |
161 | if(barrier){ |
162 | sync_write.clear(); |
163 | sync_read.clear(); |
164 | } |
165 | sync_read.insert(read.begin(), read.end()); |
166 | } |
167 | |
168 | // coalesce barriers |
169 | // fixme: to support more general cases |
170 | if (async_waits.size() == 2) { |
171 | // (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;) |
172 | for (int idx=0; idx<async_waits.size()-1; ++idx) { |
173 | ir::async_wait_inst *first_async_wait = async_waits[idx]; |
174 | std::vector<ir::instruction*> to_erase; |
175 | ir::basic_block::inst_list_t instructions = block->get_inst_list(); |
176 | for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){ |
177 | ir::instruction *i = *iter; |
178 | if (static_cast<ir::instruction*>(first_async_wait) == i) { |
179 | // peak next 5 instructions |
180 | auto peak_iter = std::next(iter); |
181 | if (std::distance(peak_iter, instructions.end()) >= 5) { |
182 | auto first_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++); |
183 | auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter++); |
184 | auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peak_iter++); |
185 | auto second_bar = dynamic_cast<ir::barrier_inst*>(*peak_iter++); |
186 | auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peak_iter); |
187 | if (first_bar && first_pf && second_async_wait && second_bar && second_pf) { |
188 | int first_n = first_async_wait->get_N(); |
189 | int second_n = second_async_wait->get_N(); |
190 | to_erase.push_back(second_async_wait); |
191 | to_erase.push_back(second_bar); |
192 | first_async_wait->set_N(second_n); |
193 | } |
194 | } else |
195 | break; |
196 | for (ir::instruction *i : to_erase) |
197 | block->erase(i); |
198 | } |
199 | } |
200 | } |
201 | } |
202 | } |
203 | |
204 | void membar::run(ir::module &mod) { |
205 | ir::builder &builder = mod.get_builder(); |
206 | // extract phi-node associates with double-buffered |
207 | // shared-memory copies. These can be read from and written to |
208 | // without needing synchronization |
209 | std::set<ir::value*> safe_war; |
210 | for(const auto& x: layouts_->get_all()){ |
211 | analysis::shared_layout* layout = x.second->to_shared(); |
212 | if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer()) |
213 | continue; |
214 | for(ir::value *v: layout->get_values()) |
215 | if(v != layout->get_double_buffer()->phi){ |
216 | safe_war.insert(v); |
217 | } |
218 | } |
219 | |
220 | for(ir::function *fn: mod.get_function_list()){ |
221 | std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn); |
222 | std::map<ir::basic_block*, val_vec_t> async_writes; |
223 | std::map<ir::basic_block*, val_set_t> sync_writes; |
224 | std::map<ir::basic_block*, val_set_t> sync_reads; |
225 | std::list<ir::value *> pipelined; |
226 | bool inserted; |
227 | do{ |
228 | inserted = false; |
229 | // find barrier location |
230 | for(ir::basic_block *block: rpo){ |
231 | // join inputs |
232 | val_vec_t async_write; |
233 | val_set_t sync_write; |
234 | val_set_t sync_read; |
235 | val_set_t tmp; |
236 | for(ir::basic_block* pred: block->get_predecessors()){ |
237 | for(ir::value* v: async_writes[pred]) |
238 | if(tmp.insert(v).second) |
239 | async_write.push_back(v); |
240 | sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end()); |
241 | sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end()); |
242 | } |
243 | transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder); |
244 | async_writes[block] = async_write; |
245 | sync_writes[block] = sync_write; |
246 | sync_reads[block] = sync_read; |
247 | } |
248 | }while(inserted); |
249 | } |
250 | } |
251 | |
252 | } |
253 | } |
254 | } |
255 | |