#include <algorithm>
#include <numeric>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
// #include "triton/ir/type.h"

namespace triton{
namespace codegen{
namespace analysis{

/* -------------------------------- *
 *          Helper Functions        *
 * -------------------------------- */

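// clamp x to [min(a, b), max(a, b)]; unlike std::clamp, the bounds may be passed in either order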
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b);
  unsigned hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

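// returns true if v is a dot instruction whose operands can be lowered to tensor cores:
// fp16 x fp16, bf16 x bf16, tf32 x tf32 (sm >= 80, when allow_tf32 is set), or int8 x int8 (sm >= 80)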
inline bool is_hmma_c(ir::value *v, int sm){
  bool result = false;
  if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
    ir::value *a = x->get_operand(0);
    ir::type *a_ty = a->get_type();
    ir::value *b = x->get_operand(1);
    ir::type *b_ty = b->get_type();
    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
              x->allow_tf32() && sm >= 80) ||
             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
              sm >= 80);
  }
  return result;
}

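// maps a dot instruction to the tensor-core configuration (operand/accumulator types) it requires,
// or NOT_APPLICABLE if it cannot use tensor cores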
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
  mma_layout::TensorCoreType mma_type;
  if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
    ir::value* a = dot->get_operand(0);
    ir::value* b = dot->get_operand(1);
    ir::type* a_ty = a->get_type();
    ir::type* b_ty = b->get_type();
    ir::type* c_ty = v->get_type();

    if (c_ty->get_scalar_ty()->is_fp32_ty()) {
      // floating point tensor cores
      if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
        mma_type = mma_layout::FP32_FP16_FP16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
        mma_type = mma_layout::FP32_BF16_BF16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
          && dot->allow_tf32()) {
        mma_type = mma_layout::FP32_TF32_TF32_FP32;
        return mma_type;
      }
    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
      // throw std::runtime_error("integer tensor cores are not yet supported");
      // // integer tensor cores
      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
      //   return mma_type;
      // }
      // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
        mma_type = mma_layout::INT32_INT8_INT8_INT32;
        return mma_type;
      }
    }
  }
  return mma_layout::NOT_APPLICABLE;
}

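// adds v to `result` if it is used as the pointer operand of some load/store (io) instruction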
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::io_inst*>(u);
    if(i && i->get_pointer_operand() == v)
      result.insert(v);
  }
}

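// records v in `result` if it is used as operand n (0 = A, 1 = B) of some dot instruction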
inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && i->get_operand(n) == v)
      result = v;
  }
}

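// records the dot instruction itself in `result` if v feeds operand n of a dot
// that can use tensor cores (see is_hmma_c)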
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
      result = i;
    }
  }
}

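// a value is considered transposed if it is a trans instruction, or an instruction
// all of whose operands are (recursively) transposed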
inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {
    return true;
  }
  if(auto *inst = dynamic_cast<ir::instruction *>(v)) {
    bool result = true;
    for(ir::value *op: inst->ops())
      result = result && is_trans(op);
    return result;
  }
  return false;
}

/* -------------------------------- *
 *          Layout Visitor          *
 * -------------------------------- */

void layout_visitor::visit_layout(data_layout *layout) {
  layout->accept(this);
}


/* -------------------------------- *
 *         Base Data Layout         *
 * -------------------------------- */

data_layout::data_layout(id_t id,
                         const std::vector<int> &axes,
                         const std::vector<unsigned> &shape,
                         const std::vector<ir::value *> &values,
                         analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
  // io pointer
  std::set<ir::value*> ptr;
  for(ir::value* v: values_)
    extract_io_use(v, ptr);
  order_.resize(axes_.size());
  std::iota(order_.begin(), order_.end(), 0);
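  // order the axes from most- to least-contiguous, based on the pointer
  // access with the richest contiguity information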
  std::vector<unsigned> max_contiguous;
  for(ir::value* p: ptr){
    std::vector<unsigned> curr = align->contiguous(p);
    if(curr.size() > max_contiguous.size())
      max_contiguous = curr;
    else if(curr.size() == max_contiguous.size()){
      if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
        max_contiguous = curr;
    }
  }
  if(max_contiguous.size() > 0){
    std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
      return max_contiguous[a] > max_contiguous[b];
    });
    // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl;
    // std::cout << order_[0] << " " << order_[1] << std::endl;
  }
}

int data_layout::find_axis(int to_find) const {
  auto it = std::find(axes_.begin(), axes_.end(), to_find);
  if(it == axes_.end())
    return -1;
  return std::distance(axes_.begin(), it);
}


distributed_layout::distributed_layout(id_t id,
                                       const std::vector<int> &axes,
                                       const std::vector<unsigned> &shape,
                                       const std::vector<ir::value *> &values,
                                       analysis::align* align): data_layout(id, axes, shape, values, align)
{ }

/* -------------------------------- *
 *            MMA Layout            *
 * -------------------------------- */

mma_layout::mma_layout(size_t num_warps,
                       const std::vector<int>& axes,
                       const std::vector<unsigned>& shape,
                       const std::vector<ir::value *> &values,
                       analysis::align* align, target* tgt,
                       shared_layout *layout_a, shared_layout *layout_b,
                       ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
  tensor_core_type_ = get_mma_type(dot);
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  if(tgt->as_nvidia()->sm() < 80){
    fpw_ = {2, 2, 1};
    auto ord_a = layout_a->get_order();
    auto ord_b = layout_b->get_order();
    bool is_a_row = ord_a[0] != 0;
    bool is_b_row = ord_b[0] != 0;
    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
    bool is_b_vec4 =  is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
    int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2;
    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
    spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
    contig_per_thread_ = {1, 1};
    order_ = {0, 1};
  }
  else{
    spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
    contig_per_thread_ = {1, 2};
    order_ = {1, 0};
  }

  /* warps per tile */
  wpt_ = {1, 1, 1};
  // try to make warp-level tiles as square as possible to maximize data re-use
  if (tgt->as_nvidia()->sm() < 80) {
    std::vector<int> wpt_nm1;
    do{
      wpt_nm1 = wpt_;
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
      if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
        wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
    }while(wpt_nm1 != wpt_);
  } else {
    bool changed = false;
    // try to have a warp own entire rows of the output
    // this makes it easier to fuse multiple mmas by fusing
    // registers
    bool one_warp_per_row = false;
    for(ir::value* v: values)
    for(ir::user* u: v->get_users()){
      auto* dot = dynamic_cast<ir::dot_inst*>(u);
      auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
      if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
        one_warp_per_row = shape[0] / spw_[0] >= num_warps;
    }
    // std::cout << one_warp_per_row << std::endl;

    if(one_warp_per_row){
      wpt_[1] = 1;
      wpt_[0] = num_warps;
    }
    else{
      do {
        changed = false;
        if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
          break;
        if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
          if (wpt_[0] < shape_[0] / spw_[0]) {
            wpt_[0] *= 2;
            changed = true;
          }
        } else {
          if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
            wpt_[1] *= 2;
            changed = true;
          }
        }
      } while(changed);
    }
  }

  // std::cout << wpt_[0] << " " << wpt_[1] << std::endl;

  /* shape per block */
  shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}


/* -------------------------------- *
 *         Scanline Layout          *
 * -------------------------------- */

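// Distributed layout for non-tensor-core tiles: nts_[d] is the number of contiguous elements
// each thread owns along dimension d, mts_[d] the number of threads along d. The most contiguous
// dimension is vectorized up to 128 bits per thread, as allowed by pointer alignment.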
scanline_layout::scanline_layout(size_t num_warps,
                                 const std::vector<int>& axes,
                                 const std::vector<unsigned>& shape,
                                 const std::vector<ir::value *> &values,
                                 analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
  unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
  unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
  nts_.resize(shape_.size());
  mts_.resize(shape_.size());
  bool is_dot = std::any_of(values.begin(), values.end(),
                            [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

  std::vector<ir::value*> ptrs;
  for(ir::value *v: values)
    for(ir::user *usr: v->get_users())
      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
        if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
          ptrs.push_back(io->get_pointer_operand());
      }

  unsigned i = order_[0];
  int contiguous = 1;
  for(ir::value* ptr: ptrs){
    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
    contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
  }

  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
  size /= shape_[i];
  num_threads /= mts_[i];
  if(is_dot)
    nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
  for(size_t d = 1; d < shape_.size(); d++){
    i = order_[d];
    if(d > 1 || !is_dot)
      nts_[i] = 1;
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    num_threads = num_threads / mts_[i];
  }

  shape_per_cta_.resize(shape_.size());
  for(size_t d = 0; d < shape_.size(); d++)
    shape_per_cta_[d] = mts_[d]*nts_[d];
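  // e.g. (a sketch, assuming shape {128, 64}, 4 warps, fp32 elements, order_ = {0, 1},
  // pointers contiguous by at least 4 elements, and no dot in the group):
  // nts_ = {4, 1}, mts_ = {32, 4}, so shape_per_cta_ = {128, 4}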
}


/* -------------------------------- *
 *          Shared Layout           *
 * -------------------------------- */

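// returns true when `terminator` belongs to the same block as `phi` and conditionally
// branches back to that block, i.e. the block acts as its own loop latch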
bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  if(phi->get_parent() != terminator->get_parent())
    return false;
  if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
    return br->get_true_dest() == phi->get_parent()
           || br->get_false_dest() == phi->get_parent();
  else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  else
    throw std::runtime_error("unreachable");
}


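// detects the double-buffering pattern: a two-way phi whose incoming values are shared-memory
// tiles (copy_to_shared / masked_load_async); records {first, latch, phi} in `res`, where the
// latch value is the one coming from the loop back-edge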
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  if(!phi || phi->get_num_incoming() != 2)
    return;
  ir::basic_block *block_0 = phi->get_incoming_block(0);
  ir::basic_block *block_1 = phi->get_incoming_block(1);
  ir::instruction *terminator_0 = block_0->get_inst_list().back();
  ir::instruction *terminator_1 = block_1->get_inst_list().back();
  bool is_latch_0 = is_loop_latch(phi, terminator_0);
  bool is_latch_1 = is_loop_latch(phi, terminator_1);
  ir::value *value_0 = phi->get_incoming_value(0);
  ir::value *value_1 = phi->get_incoming_value(1);
  ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
  ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
  if(!(i_0 && !i_1) &&
     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
    return;
  if(is_latch_1)
    res.reset(new double_buffer_info_t{value_0, value_1, phi});
  if(is_latch_0)
    res.reset(new double_buffer_info_t{value_1, value_0, phi});
}

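// true if v is an instruction of `bb` that writes shared memory
// (copy_to_shared or masked_load_async)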
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
  if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
    if (instr->get_parent() != bb)
      return false;
    if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
        dynamic_cast<ir::masked_load_async_inst*>(v)) {
      return true;
    }
  }
  return false;
}

/// Walks the chain of two-way phis produced by multi-stage software pipelining.
/// Collects the shared-memory values defined in bb0 into values_0 and returns, through
/// value_1, the next-iteration value coming from bb1; returns false if the pattern
/// does not match.
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
                                   std::vector<ir::value*>& values_0, ir::value*& value_1) {
  ir::value* next = phi;
  while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
    // smem from previous bb & phi/smem from current bb
    ir::value* c0 = cphi->get_incoming_value(0);
    ir::value* c1 = cphi->get_incoming_value(1);
    ir::basic_block *cbb0 = cphi->get_incoming_block(0);
    ir::basic_block *cbb1 = cphi->get_incoming_block(1);

    if (is_smem_in(c0, cbb0)) {
      assert(cbb0 == bb0);
      values_0.push_back(c0);
      if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
        next = phi1;
        continue;
      } else {
        if (is_smem_in(c1, cbb1)) {
          value_1 = c1;
          assert(cbb1 == bb1);
          return true;
        } else {
          return false;
        }
      }
    } else
      return false;
  }
  return false;
}

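// detects N-stage (multi-buffered) software pipelining: records the first-stage shared-memory
// values, the next-iteration value, the phi, and the program order of the first-stage values;
// only the deepest pipeline seen so far (tracked by prev_stages) is kept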
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  // bail out if v is not a phi node
  if (!phi)
    return;

  ir::basic_block *bb0 = phi->get_incoming_block(0);
  ir::basic_block *bb1 = phi->get_incoming_block(1);

  std::vector<ir::value*> values_0;
  ir::value* value_1 = nullptr;

  if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
    return;

  // double-buffer is a special case
  if (values_0.size() == 1)
    return;

  // compute original values_0 input order
  std::map<ir::value*, int> order;
  int idx = 0;
  for (ir::instruction* instr : *bb0) {
    if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
      order[static_cast<ir::value*>(instr)] = idx++;
  }
  assert(order.size() == values_0.size() && "order size incorrect");

  int curr_stages = values_0.size() + 1;
  if (curr_stages > prev_stages) {
    res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
    prev_stages = curr_stages;
  }
}


shared_layout::shared_layout(data_layout *arg,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             ir::type *ty,
                             analysis::align* align, target *tgt, bool is_tmp)
    : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){

  size_ = 0;
  arg_layout_ = arg;

  // N-stage buffering
  int prev_stages = 0;
  for (ir::value *v : values)
    extract_N_bufferable(v, N_buffer_, prev_stages);

  // double-buffering
  if (!N_buffer_)
    for(ir::value *v: values)
      extract_double_bufferable(v, double_buffer_);

  // order
  std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
  order_ = arg_order;

  ir::value* dot_a = nullptr;
  ir::value* dot_b = nullptr;
  ir::value* hmma_dot_a = nullptr;
  ir::value* hmma_dot_b = nullptr;
  for(ir::value* v: values){
    extract_dot_use(v, dot_a, 0);
    extract_dot_use(v, dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
    extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
  }
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // Update mma_vec
  if (hmma_dot_a_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
    mma_vec_     = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 0) // need transpose
        allow_swizzle_ = false;
  } else if (hmma_dot_b_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
    mma_vec_     = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];

    // for now, disable swizzle when using lds.8
    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 1) // need transpose
        allow_swizzle_ = false;
  }

  // size
  size_ = ty_->get_primitive_size_in_bits() / 8;
  for(auto s: shape_)
    size_ *= s;
  if(double_buffer_)
    size_ *= 2;
  if (N_buffer_) {
    size_ *= (N_buffer_->firsts.size() + 1);
  }
}

int shared_layout::get_num_stages() const {
  if (double_buffer_)
    return 2;
  if (N_buffer_)
    return N_buffer_->firsts.size() + 1;
  return 1;
}

size_t shared_layout::get_per_stage_elements() const {
  return get_per_stage_size() / (ty_->get_primitive_size_in_bits() / 8);
}

/* -------------------------------- *
 * ---- Layouts Inference Pass ---- *
 * -------------------------------- */

layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
  : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }

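// two block-typed values are placed in the same layout group when they share at least
// one distribution axis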
void layouts::connect(ir::value *x, ir::value *y) {
  if(x == y)
    return;
  if(!x->get_type()->is_block_ty())
    return;
  if(!y->get_type()->is_block_ty())
    return;
  std::vector<int> x_axes = axes_->get(x);
  std::vector<int> y_axes = axes_->get(y);
  std::set<int> sx_axes(x_axes.begin(), x_axes.end());
  std::set<int> sy_axes(y_axes.begin(), y_axes.end());
  std::set<int> common;
  std::set_intersection(sx_axes.begin(), sx_axes.end(),
                        sy_axes.begin(), sy_axes.end(),
                        std::inserter(common, common.begin()));
  graph_.add_edge(x, x);
  graph_.add_edge(y, y);
  if(!common.empty())
    graph_.add_edge(x, y);
}

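// connect every instruction to its operands (and its operands to each other) so that
// connected components of the graph become layout groups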
void layouts::make_graph(ir::instruction *i) {
  for(ir::value* opx: i->ops())
  for(ir::value* opy: i->ops()){
    connect(i, opx);
    connect(opx, opy);
  }
}

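// creates the layout of one connected component: mma_layout if the group contains a tensor-core
// dot, shared_layout if it contains a copy-to-shared / async load, scanline_layout otherwise;
// axes and shapes are taken from the largest non-transpose value in the group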
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
//  if(layouts_.find(id) != layouts_.end())
//    return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(),
                                [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
    return xx < yy;
  };
  std::vector<ir::value*> lvalue = values;
  lvalue.erase(std::remove_if(lvalue.begin(), lvalue.end(),
                              [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); }),
               lvalue.end());
  if(lvalue.empty())
    lvalue = values;
  ir::value *largest = *std::max_element(lvalue.begin(), lvalue.end(), cmp);
  const auto& axes = axes_->get(largest);
  const auto& shapes = largest->get_type()->get_block_shapes();
  auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
      return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
             dynamic_cast<ir::masked_load_async_inst*>(v);
  });
  // type
  if(it_hmma_c != values.end()){
    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    create(groups_.at(a), values_.at(groups_.at(a)));
    create(groups_.at(b), values_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
                                  (shared_layout*)layouts_.at(groups_.at(a)),
                                  (shared_layout*)layouts_.at(groups_.at(b)),
                                  dot);
  }
  else if(it_cts != values.end()){
    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
  }
  else{
    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
  }
}

// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_scanline() != nullptr;
}

bool layouts::is_coalesced_scanline(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    auto *scanline = this->get(i->get_operand(0))->to_scanline();
    return scanline && scanline->get_order()[0] == red->get_axis();
  }
  return false;
}

bool layouts::is_mma(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_mma() != nullptr;
}

bool layouts::is_a100_mma(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
           (red->get_axis() == 1);
  }
  return false;
}

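// allocates a temporary shared-memory layout used as scratch space by instruction `i`;
// the `is_index` variant stores int32 indices (for reductions that also produce an index)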
void layouts::create_tmp_layout(size_t id, data_layout *arg,
                                const std::vector<int> &axes,
                                const std::vector<unsigned> &shape,
                                ir::instruction *i, bool is_index) {
  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
                          : i->get_type()->get_scalar_ty();
  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
  if (is_index) {
    tmp_index_[i] = id;
  } else {
    tmp_[i] = id;
  }
}

void layouts::run(ir::module &mod) {
  // make graph
  graph_.clear();
  layouts_.clear();
  groups_.clear();

  ir::for_each_instruction(mod, [this](ir::instruction* i) {
    make_graph(i);
  });


  // connected components
  graph_.connected_components(&values_, &groups_);

  // create layouts
  for(const auto& x: values_)
    create(x.first, x.second);

  // create temporaries
  size_t id = values_.size();
  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
//    std::cout << "layout: " << std::endl;
//    i->print(std::cout);
    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
      ir::value *arg = red->get_operand(0);
      distributed_layout *layout =
          dynamic_cast<analysis::distributed_layout *>(get(arg));
      // shape
      auto shapes = arg->get_type()->get_block_shapes();
      unsigned axis = red->get_axis();
      shapes[axis] =
          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
      // create layout
      id++;
      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);

      if (red->with_index()) {
        id++;
        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
      }
    }
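    // layout conversions are staged through a shared-memory tile that covers one CTA tile
    // of both layouts, padded along the output's fastest-varying dimension by the larger
    // per-thread contiguity of the two layouts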
    if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
      distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
      distributed_layout* in_layout  = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
      size_t dim = val->get_type()->get_tile_rank();
      ir::type::block_shapes_t shape(dim);
      for(size_t k = 0; k < dim; k++){
        shape[k] = std::max(in_layout->shape_per_cta(k),
                            out_layout->shape_per_cta(k));
      }
      auto in_ord  = in_layout->get_order();
      auto out_ord = out_layout->get_order();
      int in_vec  = in_layout->contig_per_thread(in_ord[0]);
      int out_vec = out_layout->contig_per_thread(out_ord[0]);
      int pad = std::max(in_vec, out_vec);
      shape[out_ord[0]] += pad;
      id++;
      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
    }
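    // atomics are given a one-element temporary shared buffer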
    if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
      id++;
      create_tmp_layout(id, nullptr, {}, {1}, atom);
    }
  });

}

}
}
}