#include <vector>
#include <set>
#include <map>
#include <iterator>
#include <algorithm>
#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/prefetch.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"

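// Memory-barrier insertion pass. It inserts the synchronization required by
// shared-memory accesses: `barrier` instructions for read-after-write and
// write-after-read hazards on shared tiles, and `async_wait` instructions
// (followed by a barrier) before reads that depend on still-pending
// asynchronous copies (masked_load_async). Double- and N-buffered layouts
// are used to skip barriers that the buffering already makes unnecessary.
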
namespace triton {
namespace codegen {
namespace transform {

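// Returns the position, within the list of pending asynchronous copies
// `async_write`, of the copy that `v` may read from. Phi nodes are resolved
// through their double-buffer / N-buffer metadata when available, otherwise
// the largest (most conservative) group among their operands is taken.
// Values that only use a temporary layout are attributed to the most recent
// pending copy.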
int membar::group_of(ir::value* v, std::vector<ir::value*> &async_write) {
  if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(v)){
    analysis::shared_layout* layout = layouts_->get(v)->to_shared();
    if (analysis::double_buffer_info_t* info = layout->get_double_buffer())
      return group_of(info->first, async_write);
    else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) {
      if (v == info->phi)
        return group_of(info->firsts[0], async_write);
      else // prefetched value
        return group_of(info->firsts[1], async_write);
    }
    std::vector<int> groups(phi->get_num_operands());
    std::transform(phi->op_begin(), phi->op_end(), groups.begin(),
                   [&](ir::value* op){ return group_of(op, async_write); });
    return *std::max_element(groups.begin(), groups.end());
  }
  else{
    if(layouts_->has_tmp(v))
      return async_write.size() - 1;
    // // Ignore copy_to_shared. It won't modify async behavior.
    // if(dynamic_cast<ir::copy_to_shared_inst*>(v))
    //   return 0;
    auto it = std::find(async_write.begin(), async_write.end(), v);
    return std::distance(async_write.begin(), it);
  }
}

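// Returns true when the shared-memory regions allocated to the two layouts
// overlap. A null layout (a value that does not live in shared memory)
// intersects with nothing.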
inline bool membar::intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout) {
  if(!a_layout || !b_layout)
    return false;
  int a_start = alloc_->offset(a_layout);
  int a_end = a_start + a_layout->get_size();
  int b_start = alloc_->offset(b_layout);
  int b_end = b_start + b_layout->get_size();
  // the half-open ranges [a_start, a_end) and [b_start, b_end) overlap
  // iff each one starts before the other ends
  if(a_start < b_end && b_start < a_end)
    return true;
  return false;
}

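// Returns the subset of `bs` whose shared-memory footprint (including the
// temporary and temporary-index buffers attached to a value, if any)
// overlaps with the footprint of at least one value in `as`.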
membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) {
  val_set_t ret;
  for(ir::value* a: as){
    if(!a->get_type()->is_block_ty())
      continue;
    analysis::shared_layout* a_layout = layouts_->get(a)->to_shared();
    analysis::shared_layout* a_tmp = layouts_->has_tmp(a) ? layouts_->get(layouts_->tmp(a))->to_shared() : nullptr;
    analysis::shared_layout* a_tmp_index = layouts_->has_tmp_index(a) ? layouts_->get(layouts_->tmp_index(a))->to_shared() : nullptr;
    for(ir::value* b: bs){
      if(!b->get_type()->is_block_ty())
        continue;
      analysis::shared_layout* b_layout = layouts_->get(b)->to_shared();
      analysis::shared_layout* b_tmp = layouts_->has_tmp(b) ? layouts_->get(layouts_->tmp(b))->to_shared() : nullptr;
      analysis::shared_layout* b_tmp_index = layouts_->has_tmp_index(b) ? layouts_->get(layouts_->tmp_index(b))->to_shared() : nullptr;
      if(intersect_with(a_layout, b_layout) ||
         intersect_with(a_layout, b_tmp) ||
         intersect_with(a_layout, b_tmp_index) ||
         intersect_with(a_tmp, b_layout) ||
         intersect_with(a_tmp, b_tmp) ||
         intersect_with(a_tmp, b_tmp_index) ||
         intersect_with(a_tmp_index, b_layout) ||
         intersect_with(a_tmp_index, b_tmp) ||
         intersect_with(a_tmp_index, b_tmp_index))
        ret.insert(b);
    }
  }
  return ret;
}

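// Returns true when a write-after-read hazard on `i` needs no barrier:
// `i` writes a double- or N-buffered shared layout, so the write targets a
// different buffer than the one currently being read. A masked_load_async
// that has not been prefetched is the exception and still requires a barrier.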
bool membar::check_safe_war(ir::instruction* i) {
  bool is_i_shared_block = i->get_type()->is_block_ty() &&
                           layouts_->get(i)->to_shared();
  bool is_i_double_buffered = is_i_shared_block &&
                              layouts_->get(i)->to_shared()->get_double_buffer();
  bool is_i_n_buffered = is_i_shared_block &&
                         layouts_->get(i)->to_shared()->get_N_buffer();

  if (is_i_double_buffered || is_i_n_buffered) {
    // with async copy & prefetch_s disabled, WARs are not safe
    if (dynamic_cast<ir::masked_load_async_inst*>(i) && !prefetch_->is_prefetched(i))
      return false;
    else
      return true;
  }
  return false;
}

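// Per-block step of the data-flow analysis driven by run(). Walks the
// instructions of `block` while maintaining:
//   - async_write: pending asynchronous copies to shared memory, in program order;
//   - sync_write:  writes to shared memory not yet covered by a barrier;
//   - sync_read:   reads from shared memory not yet covered by a barrier.
// When a read may alias a pending async copy, an async_wait followed by a
// barrier is inserted (RAW); when an access may alias an unsynchronized
// write or read, a barrier is inserted (RAW/WAR). `inserted` is set whenever
// the block is modified. Finally, back-to-back async_wait/barrier/prefetch
// sequences are coalesced into a single async_wait.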
void membar::transfer(ir::basic_block *block,
                      val_vec_t& async_write,
                      val_set_t& sync_write,
                      val_set_t& sync_read,
                      std::set<ir::value*>& safe_war,
                      bool& inserted, ir::builder& builder) {
  std::vector<ir::async_wait_inst*> async_waits;
  ir::basic_block::inst_list_t instructions = block->get_inst_list();
  for(ir::instruction *i: instructions){
    if(dynamic_cast<ir::phi_node*>(i))
      continue;
    if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() &&
       dynamic_cast<ir::masked_load_async_inst*>(i)){
      async_write.push_back(i);
    }
    if(dynamic_cast<ir::copy_to_shared_inst*>(i))
      sync_write.insert(i);
    ir::barrier_inst* barrier = dynamic_cast<ir::barrier_inst*>(i);
    ir::async_wait_inst* async_wait = dynamic_cast<ir::async_wait_inst*>(i);
    // Get shared memory reads
    std::set<ir::value*> read;
    std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()),
                 [&](ir::value* op){ return op->get_type()->is_block_ty() && layouts_->get(op)->to_shared(); });
    if(layouts_->has_tmp(i))
      read.insert(i);
    // RAW (async)
    val_set_t tmp;
    std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin()));
    if(intersect_with(read, tmp).size()){
      std::vector<int> groups(read.size());
      std::transform(read.begin(), read.end(), groups.begin(),
                     [&](ir::value* v){ return group_of(v, async_write); });
      int N = *std::max_element(groups.begin(), groups.end());
      if(N < async_write.size()){
        builder.set_insert_point(i);
        async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N);
        barrier = (ir::barrier_inst*)builder.create_barrier();
        inserted = true;
        async_waits.push_back(async_wait);
      }
    }
    // RAW, WAR
    bool is_safe_war = check_safe_war(i);
    // WAR barrier is not required when data is double-buffered
    if(!intersect_with(read, sync_write).empty() ||
       (!intersect_with({i}, sync_read).empty() && !is_safe_war)) {
      builder.set_insert_point(i);
      barrier = (ir::barrier_inst*)builder.create_barrier();
      inserted = true;
    }
    // update state of asynchronous copies
    if(async_wait){
      int N = async_write.size() - async_wait->get_N();
      async_write.erase(async_write.begin(), async_write.begin() + N);
    }
    // all the copy_to_shared and read from shared are synchronized after barrier
    if(barrier){
      sync_write.clear();
      sync_read.clear();
    }
    sync_read.insert(read.begin(), read.end());
  }

  // coalesce barriers
  // fixme: to support more general cases
  if (async_waits.size() == 2) {
    // (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;)
    for (int idx=0; idx<async_waits.size()-1; ++idx) {
      ir::async_wait_inst *first_async_wait = async_waits[idx];
      std::vector<ir::instruction*> to_erase;
      ir::basic_block::inst_list_t instructions = block->get_inst_list();
      for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){
        ir::instruction *i = *iter;
        if (static_cast<ir::instruction*>(first_async_wait) == i) {
          // peek at the next 5 instructions
          auto peek_iter = std::next(iter);
          if (std::distance(peek_iter, instructions.end()) >= 5) {
            auto first_bar = dynamic_cast<ir::barrier_inst*>(*peek_iter++);
            auto first_pf = dynamic_cast<ir::prefetch_s_inst*>(*peek_iter++);
            auto second_async_wait = dynamic_cast<ir::async_wait_inst*>(*peek_iter++);
            auto second_bar = dynamic_cast<ir::barrier_inst*>(*peek_iter++);
            auto second_pf = dynamic_cast<ir::prefetch_s_inst*>(*peek_iter);
            if (first_bar && first_pf && second_async_wait && second_bar && second_pf) {
              int second_n = second_async_wait->get_N();
              to_erase.push_back(second_async_wait);
              to_erase.push_back(second_bar);
              first_async_wait->set_N(second_n);
            }
          } else
            break;
          for (ir::instruction *inst : to_erase)
            block->erase(inst);
        }
      }
    }
  }
}

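// Entry point of the pass. For each function, basic blocks are visited in
// reverse post-order and the per-block transfer function is re-applied,
// joining the pending async copies and unsynchronized shared-memory accesses
// of all predecessors, until no new barrier or async_wait is inserted
// (a simple forward data-flow fixed point).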
void membar::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();
  // extract the phi-nodes associated with double-buffered
  // shared-memory copies. These can be read from and written to
  // without needing synchronization
  std::set<ir::value*> safe_war;
  for(const auto& x: layouts_->get_all()){
    analysis::shared_layout* layout = x.second->to_shared();
    if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer())
      continue;
    for(ir::value *v: layout->get_values())
      if(v != layout->get_double_buffer()->phi){
        safe_war.insert(v);
      }
  }

  for(ir::function *fn: mod.get_function_list()){
    std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
    std::map<ir::basic_block*, val_vec_t> async_writes;
    std::map<ir::basic_block*, val_set_t> sync_writes;
    std::map<ir::basic_block*, val_set_t> sync_reads;
    bool inserted;
    do{
      inserted = false;
      // find barrier location
      for(ir::basic_block *block: rpo){
        // join inputs
        val_vec_t async_write;
        val_set_t sync_write;
        val_set_t sync_read;
        val_set_t tmp;
        for(ir::basic_block* pred: block->get_predecessors()){
          for(ir::value* v: async_writes[pred])
            if(tmp.insert(v).second)
              async_write.push_back(v);
          sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end());
          sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end());
        }
        transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder);
        async_writes[block] = async_write;
        sync_writes[block] = sync_write;
        sync_reads[block] = sync_read;
      }
    }while(inserted);
  }
}

} // namespace transform
} // namespace codegen
} // namespace triton