#include <algorithm>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <iostream>
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"

namespace triton{
namespace codegen{
namespace analysis{

/* -------------------------------- *
 *          Helper Functions        *
 * -------------------------------- */
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b);
  unsigned hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

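// Returns true if `v` is a dot instruction whose result can be computed with
// tensor-core (HMMA) instructions on the given SM: fp16*fp16, bf16*bf16,
// tf32*tf32 (sm >= 80, when allowed), or int8*int8 (sm >= 80).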
inline bool is_hmma_c(ir::value *v, int sm){
  bool result = false;
  if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
    ir::value *a = x->get_operand(0);
    ir::type *a_ty = a->get_type();
    ir::value *b = x->get_operand(1);
    ir::type *b_ty = b->get_type();
    result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
             (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
             (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
              x->allow_tf32() && sm >= 80) ||
             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
              sm >= 80);
  }
  return result;
}

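// Maps a dot instruction to the tensor-core configuration (accumulator type,
// operand types) it can use, or NOT_APPLICABLE if tensor cores cannot be used.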
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
  mma_layout::TensorCoreType mma_type;
  if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
    ir::value* a = dot->get_operand(0);
    ir::value* b = dot->get_operand(1);
    ir::type* a_ty = a->get_type();
    ir::type* b_ty = b->get_type();
    ir::type* c_ty = v->get_type();

    if (c_ty->get_scalar_ty()->is_fp32_ty()) {
      // floating-point tensor cores
      if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
        mma_type = mma_layout::FP32_FP16_FP16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
        mma_type = mma_layout::FP32_BF16_BF16_FP32;
        return mma_type;
      }
      if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
          && dot->allow_tf32()) {
        mma_type = mma_layout::FP32_TF32_TF32_FP32;
        return mma_type;
      }
    } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
      // integer tensor cores: int1 and int4 are not yet supported
      // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
      //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
      //   return mma_type;
      // }
      // if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
      //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
        mma_type = mma_layout::INT32_INT8_INT8_INT32;
        return mma_type;
      }
    }
  }
  return mma_layout::NOT_APPLICABLE;
}

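// The extract_* helpers below scan the users of `v` and record how it is
// consumed: as an I/O pointer, as operand `n` of a dot, or as operand `n`
// of a tensor-core (HMMA) dot.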
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::io_inst*>(u);
    if(i && i->get_pointer_operand() == v)
      result.insert(v);
  }
}

inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && i->get_operand(n) == v)
      result = v;
  }
}

inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
  for(ir::user* u: v->get_users()){
    auto i = dynamic_cast<ir::dot_inst*>(u);
    if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
      result = i;
    }
  }
}

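// Returns true if `v` is a transposition, or an instruction whose operands
// are all (recursively) transpositions.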
inline bool is_trans(ir::value *v) {
  if(dynamic_cast<ir::trans_inst *>(v)) {
    return true;
  }
  if(auto *i = dynamic_cast<ir::instruction *>(v)) {
    bool result = true;
    for(ir::value *op: i->ops())
      result = result && is_trans(op);
    return result;
  }
  return false;
}

/* -------------------------------- *
 *          Layout Visitor          *
 * -------------------------------- */

void layout_visitor::visit_layout(data_layout *layout) {
  layout->accept(this);
}


/* -------------------------------- *
 *          Base Data Layout        *
 * -------------------------------- */

data_layout::data_layout(id_t id,
                         const std::vector<int> &axes,
                         const std::vector<unsigned> &shape,
                         const std::vector<ir::value *> &values,
                         analysis::align* align): id_(id), axes_(axes), shape_(shape), values_(values) {
  // io pointers
  std::set<ir::value*> ptr;
  for(ir::value* v: values_)
    extract_io_use(v, ptr);
  // order the axes by decreasing contiguity of the most-contiguous pointer
  order_.resize(axes_.size());
  std::iota(order_.begin(), order_.end(), 0);
  std::vector<unsigned> max_contiguous;
  for(ir::value* p: ptr){
    std::vector<unsigned> curr = align->contiguous(p);
    if(curr.size() > max_contiguous.size())
      max_contiguous = curr;
    else if(curr.size() == max_contiguous.size()){
      if(*std::max_element(curr.begin(), curr.end()) > *std::max_element(max_contiguous.begin(), max_contiguous.end()))
        max_contiguous = curr;
    }
  }
  if(max_contiguous.size() > 0){
    std::sort(order_.begin(), order_.end(), [&](unsigned a, unsigned b) {
      return max_contiguous[a] > max_contiguous[b];
    });
  }
}

int data_layout::find_axis(int to_find) const {
  auto it = std::find(axes_.begin(), axes_.end(), to_find);
  if(it == axes_.end())
    return -1;
  return std::distance(axes_.begin(), it);
}


distributed_layout::distributed_layout(id_t id,
                                       const std::vector<int> &axes,
                                       const std::vector<unsigned> &shape,
                                       const std::vector<ir::value *> &values,
                                       analysis::align* align): data_layout(id, axes, shape, values, align)
{ }

/* -------------------------------- *
 *            MMA Layout            *
 * -------------------------------- */

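// Builds the distribution of a tensor-core (MMA) accumulator tile across warps.
// Roughly: fpw_ = fragments per warp, rep_ = per-warp repetitions, spw_ = the
// tile shape covered by one warp, wpt_ = warps per CTA tile, and shape_per_cta_
// = the tile shape covered by all warps of the CTA (see layout.h for the
// authoritative meaning of each field).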
mma_layout::mma_layout(size_t num_warps,
                       const std::vector<int>& axes,
                       const std::vector<unsigned>& shape,
                       const std::vector<ir::value *> &values,
                       analysis::align* align, target* tgt,
                       shared_layout *layout_a, shared_layout *layout_b,
                       ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
  tensor_core_type_ = get_mma_type(dot);
  /* fragments per warp */
  // try to make things as square as possible to maximize data re-use
  if(tgt->as_nvidia()->sm() < 80){
    fpw_ = {2, 2, 1};
    auto ord_a = layout_a->get_order();
    auto ord_b = layout_b->get_order();
    bool is_a_row = ord_a[0] != 0;
    bool is_b_row = ord_b[0] != 0;
    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
    bool is_b_vec4 =  is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
    int pack_size_0 = (is_a_row ||  is_a_vec4) ? 1 : 2;
    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
    spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
    contig_per_thread_ = {1, 1};
    order_ = {0, 1};
  }
  else{
    spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
    contig_per_thread_ = {1, 2};
    order_ = {1, 0};
  }

219
220 /* warps per tile */
221 wpt_ = {1, 1, 1};
222 // try to make warp-level tiles as square as possible to maximize data re-use
223 if (tgt->as_nvidia()->sm() < 80) {
224 std::vector<int> wpt_nm1;
225 do{
226 wpt_nm1 = wpt_;
227 if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
228 wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
229 if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
230 wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
231 }while(wpt_nm1 != wpt_);
232 } else {
233 bool changed = false;
234 // try to have a warp own entire rows of the output
235 // this makes it easier to fuse multiple mmas by fusing
236 // registers
237 bool one_warp_per_row = false;
238 for(ir::value* v: values)
239 for(ir::user* u: v->get_users()){
240 auto* dot = dynamic_cast<ir::dot_inst*>(u);
241 auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(u);
242 if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts)
243 one_warp_per_row = shape[0] / spw_[0] >= num_warps;
244 }

    if(one_warp_per_row){
      wpt_[1] = 1;
      wpt_[0] = num_warps;
    }
    else{
      do {
        changed = false;
        if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps)
          break;
        if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) {
          if (wpt_[0] < shape_[0] / spw_[0]) {
            wpt_[0] *= 2;
            changed = true;
          }
        } else {
          if (wpt_[1] < shape_[1] / (spw_[1]*2)) {
            wpt_[1] *= 2;
            changed = true;
          }
        }
      } while(changed);
    }
  }


  /* shape per block */
  shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
}


/* -------------------------------- *
 *          Scanline Layout         *
 * -------------------------------- */

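// Distributes a tile across threads in a simple strided ("scanline") fashion.
// For each dimension d, nts_[d] is the number of contiguous elements owned by
// one thread and mts_[d] is the number of threads along that dimension, so the
// CTA covers mts_[d]*nts_[d] elements per dimension (shape_per_cta_).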
scanline_layout::scanline_layout(size_t num_warps,
                                 const std::vector<int>& axes,
                                 const std::vector<unsigned>& shape,
                                 const std::vector<ir::value *> &values,
                                 analysis::align* align, target *tgt): distributed_layout(SCANLINE, axes, shape, values, align){
  unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int>());
  unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1;
  nts_.resize(shape_.size());
  mts_.resize(shape_.size());
  bool is_dot = std::any_of(values.begin(), values.end(),
                            [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });

  // collect pointers of i/o instructions that use these values, preferring higher tile ranks
  std::vector<ir::value*> ptrs;
  for(ir::value *v: values)
    for(ir::user *usr: v->get_users())
      if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
        if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
          ptrs.push_back(io->get_pointer_operand());
      }

  // vectorization along the most-contiguous dimension, capped at 128 bits per access
  unsigned i = order_[0];
  int contiguous = 1;
  for(ir::value* ptr: ptrs){
    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
    contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
  }

  nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
  mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
  size /= shape_[i];
  num_threads /= mts_[i];
  if(is_dot)
    nts_[order_[1]] = clamp(size / num_threads, 1, std::min<int>(4, shape_[order_[1]]));
  for(size_t d = 1; d < shape_.size(); d++){
    i = order_[d];
    if(d > 1 || !is_dot)
      nts_[i] = 1;
    mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
    num_threads = num_threads / mts_[i];
  }

  shape_per_cta_.resize(shape_.size());
  for(size_t d = 0; d < shape_.size(); d++)
    shape_per_cta_[d] = mts_[d]*nts_[d];
}


/* -------------------------------- *
 *           Shared Layout          *
 * -------------------------------- */

bool shared_layout::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){
  if(phi->get_parent() != terminator->get_parent())
    return false;
  if(auto *br = dynamic_cast<ir::cond_branch_inst*>(terminator))
    return br->get_true_dest() == phi->get_parent()
           || br->get_false_dest() == phi->get_parent();
  else if(dynamic_cast<ir::uncond_branch_inst*>(terminator))
    return false;
  else
    throw std::runtime_error("unreachable");
}

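// Detects the double-buffering pattern: a phi over two shared-memory writes
// (copy_to_shared / masked_load_async), one of which comes from a loop latch.
// On success, `res` records the two incoming shared-memory values and the phi.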
void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<double_buffer_info_t>& res) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  if(!phi || phi->get_num_incoming() != 2)
    return;
  ir::basic_block *block_0 = phi->get_incoming_block(0);
  ir::basic_block *block_1 = phi->get_incoming_block(1);
  ir::instruction *terminator_0 = block_0->get_inst_list().back();
  ir::instruction *terminator_1 = block_1->get_inst_list().back();
  bool is_latch_0 = is_loop_latch(phi, terminator_0);
  bool is_latch_1 = is_loop_latch(phi, terminator_1);
  ir::value *value_0 = phi->get_incoming_value(0);
  ir::value *value_1 = phi->get_incoming_value(1);
  ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
  ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
  if(!(i_0 && !i_1) &&
     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
    return;
  if(is_latch_1)
    res.reset(new double_buffer_info_t{value_0, value_1, phi});
  if(is_latch_0)
    res.reset(new double_buffer_info_t{value_1, value_0, phi});
}

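// Returns true if `v` is a shared-memory write (copy_to_shared or
// masked_load_async) that lives in basic block `bb`.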
static bool is_smem_in(ir::value* v, const ir::basic_block* bb) {
  if (ir::instruction *instr = dynamic_cast<ir::instruction*>(v)) {
    if (instr->get_parent() != bb)
      return false;
    if (dynamic_cast<ir::copy_to_shared_inst*>(v) ||
        dynamic_cast<ir::masked_load_async_inst*>(v)) {
      return true;
    }
  }
  return false;
}

/// Walks a chain of phi nodes and returns true if it forms a multi-stage
/// pipeline: each phi takes a shared-memory write from `bb0` and either a phi
/// or a final shared-memory write from `bb1`.
/// \param values_0 the shared-memory writes found in `bb0`, in chain order
/// \param value_1  the final ("next") shared-memory write found in `bb1`
static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1,
                                   std::vector<ir::value*>& values_0, ir::value*& value_1) {
  ir::value* next = phi;
  while (auto cphi = dynamic_cast<ir::phi_node*>(next)) {
    // smem from previous bb & phi/smem from current bb
    ir::value* c0 = cphi->get_incoming_value(0);
    ir::value* c1 = cphi->get_incoming_value(1);
    ir::basic_block *cbb0 = cphi->get_incoming_block(0);
    ir::basic_block *cbb1 = cphi->get_incoming_block(1);

    if (is_smem_in(c0, cbb0)) {
      assert(cbb0 == bb0);
      values_0.push_back(c0);
      if (auto phi1 = dynamic_cast<ir::phi_node*>(c1)) {
        next = phi1;
        continue;
      } else {
        if (is_smem_in(c1, cbb1)) {
          value_1 = c1;
          assert(cbb1 == bb1);
          return true;
        } else {
          return false;
        }
      }
    } else
      return false;
  }
  return false;
}

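// Detects N-stage (N > 2) software-pipelined buffers: a chain of phi nodes fed
// by several shared-memory writes. Keeps the deepest pipeline found so far
// (tracked through `prev_stages`).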
void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr<N_buffer_info_t> &res, int &prev_stages) {
  auto* phi = dynamic_cast<ir::phi_node*>(v);
  // only phi nodes can carry an N-stage buffer
  if (!phi)
    return;

  ir::basic_block *bb0 = phi->get_incoming_block(0);
  ir::basic_block *bb1 = phi->get_incoming_block(1);

  std::vector<ir::value*> values_0;
  ir::value* value_1 = nullptr;

  if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1))
    return;

  // double-buffering is a special case, handled by extract_double_bufferable
  if (values_0.size() == 1)
    return;

  // compute original values_0 input order
  std::map<ir::value*, int> order;
  int idx = 0;
  for (ir::instruction* instr : *bb0) {
    if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end())
      order[static_cast<ir::value*>(instr)] = idx++;
  }
  assert(order.size() == values_0.size() && "order size incorrect");

  int curr_stages = static_cast<int>(values_0.size()) + 1;
  if (curr_stages > prev_stages) {
    res.reset(new N_buffer_info_t{values_0, value_1, phi, order});
    prev_stages = curr_stages;
  }
}

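// A shared-memory layout: computes the element type, axis order, buffering
// scheme (single, double, or N-stage) and total size in bytes of the shared
// allocation backing `values`.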
shared_layout::shared_layout(data_layout *arg,
                             const std::vector<int>& axes,
                             const std::vector<unsigned>& shape,
                             const std::vector<ir::value *> &values,
                             ir::type *ty,
                             analysis::align* align, target *tgt, bool is_tmp)
    : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){

  size_ = 0;
  arg_layout_ = arg;

  // N-stage buffering
  int prev_stages = 0;
  for (ir::value *v : values)
    extract_N_bufferable(v, N_buffer_, prev_stages);

  // double-buffering
  if (!N_buffer_)
    for(ir::value *v: values)
      extract_double_bufferable(v, double_buffer_);

  // order
  std::vector<int> arg_order = arg ? arg->get_order() : std::vector<int>{0};
  order_ = arg_order;

  ir::value* dot_a = nullptr;
  ir::value* dot_b = nullptr;
  ir::value* hmma_dot_a = nullptr;
  ir::value* hmma_dot_b = nullptr;
  for(ir::value* v: values){
    extract_dot_use(v, dot_a, 0);
    extract_dot_use(v, dot_b, 1);
    extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
    extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
  }
  hmma_dot_a_ = hmma_dot_a;
  hmma_dot_b_ = hmma_dot_b;

  // update mma_vec
  if (hmma_dot_a_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
    mma_vec_     = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
    mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];

    // for now, disable swizzling when using lds.8
    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 0) // need transpose
        allow_swizzle_ = false;
  } else if (hmma_dot_b_) {
    assert(order_.size() == 2);
    std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
    mma_vec_     = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
    mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];

    // for now, disable swizzling when using lds.8
    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
      if (order_[0] == 1) // need transpose
        allow_swizzle_ = false;
  }

  // size in bytes, accounting for multi-stage buffering
  size_ = ty_->get_primitive_size_in_bits() / 8;
  for(auto s: shape_)
    size_ *= s;
  if(double_buffer_)
    size_ *= 2;
  if (N_buffer_) {
    size_ *= (N_buffer_->firsts.size() + 1);
  }
}

int shared_layout::get_num_stages() const {
  if (double_buffer_)
    return 2;
  if (N_buffer_)
    return N_buffer_->firsts.size() + 1;
  return 1;
}

size_t shared_layout::get_per_stage_elements() const {
  return get_per_stage_size() / (ty_->get_primitive_size_in_bits() / 8);
}

/* -------------------------------- *
 * ---- Layouts Inference Pass ---- *
 * -------------------------------- */

layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt)
    : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ }

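// Two values are connected (and therefore end up in the same layout group)
// when they are both block-typed and share at least one distribution axis.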
void layouts::connect(ir::value *x, ir::value *y) {
  if(x == y)
    return;
  if(!x->get_type()->is_block_ty())
    return;
  if(!y->get_type()->is_block_ty())
    return;
  std::vector<int> x_axes = axes_->get(x);
  std::vector<int> y_axes = axes_->get(y);
  std::set<int> sx_axes(x_axes.begin(), x_axes.end());
  std::set<int> sy_axes(y_axes.begin(), y_axes.end());
  std::set<int> common;
  std::set_intersection(sx_axes.begin(), sx_axes.end(),
                        sy_axes.begin(), sy_axes.end(),
                        std::inserter(common, common.begin()));
  graph_.add_edge(x, x);
  graph_.add_edge(y, y);
  if(!common.empty())
    graph_.add_edge(x, y);
}

void layouts::make_graph(ir::instruction *i) {
  for(ir::value* opx: i->ops())
    for(ir::value* opy: i->ops()){
      connect(i, opx);
      connect(opx, opy);
    }
}

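// Creates the layout for one connected component of values: an mma_layout if
// the component contains a tensor-core dot, a shared_layout if it contains a
// copy-to-shared or async load, and a scanline_layout otherwise.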
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
//  if(layouts_.find(id) != layouts_.end())
//    return;
  auto it_hmma_c = std::find_if(values.begin(), values.end(),
                                [&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
  auto cmp = [](ir::value* x, ir::value *y) {
    std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
    std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
    return xx < yy;
  };
  // pick the largest non-transposed value to determine axes and shapes
  std::vector<ir::value*> lvalue = values;
  auto lend = std::remove_if(lvalue.begin(), lvalue.end(),
                             [&](ir::value* v) { return dynamic_cast<ir::trans_inst*>(v); });
  if (lend == lvalue.begin()) // fall back to all values if everything is a transposition
    lend = lvalue.end();
  ir::value *largest = *std::max_element(lvalue.begin(), lend, cmp);
  const auto& axes = axes_->get(largest);
  const auto& shapes = largest->get_type()->get_block_shapes();
  auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
           dynamic_cast<ir::masked_load_async_inst*>(v);
  });
  // type
  if(it_hmma_c != values.end()){
    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
    ir::value *a = dot->get_operand(0);
    ir::value *b = dot->get_operand(1);
    create(groups_.at(a), values_.at(groups_.at(a)));
    create(groups_.at(b), values_.at(groups_.at(b)));
    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
                                  (shared_layout*)layouts_.at(groups_.at(a)),
                                  (shared_layout*)layouts_.at(groups_.at(b)),
                                  dot);
  }
  else if(it_cts != values.end()){
    ir::instruction *cts = (ir::instruction*)*it_cts;
    ir::value *arg = cts->get_operand(0);
    create(groups_.at(arg), values_.at(groups_.at(arg)));
    layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
  }
  else{
    layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
  }
}

// layout checkers
bool layouts::is_scanline(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_scanline() != nullptr;
}

bool layouts::is_coalesced_scanline(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    auto *scanline = this->get(i->get_operand(0))->to_scanline();
    return scanline && scanline->get_order()[0] == red->get_axis();
  }
  return false;
}

bool layouts::is_mma(ir::instruction *i) {
  return this->get(i->get_operand(0))->to_mma() != nullptr;
}

bool layouts::is_a100_mma(ir::instruction *i) {
  if (auto *red = dynamic_cast<ir::reduce_inst *>(i)) {
    return is_mma(red) && (tgt_->as_nvidia()->sm() >= 80) &&
           (red->get_axis() == 1);
  }
  return false;
}

void layouts::create_tmp_layout(size_t id, data_layout *arg,
                                const std::vector<int> &axes,
                                const std::vector<unsigned> &shape,
                                ir::instruction *i, bool is_index) {
  ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context())
                          : i->get_type()->get_scalar_ty();
  layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true);
  if (is_index) {
    tmp_index_[i] = id;
  } else {
    tmp_[i] = id;
  }
}

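// Main entry point of the pass: builds the axis-sharing graph, groups values
// into connected components, creates one layout per component, and finally
// creates the temporary shared-memory layouts needed by reductions, layout
// conversions, and atomics.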
void layouts::run(ir::module &mod) {
  // make graph
  graph_.clear();
  layouts_.clear();
  groups_.clear();

  ir::for_each_instruction(mod, [this](ir::instruction* i) {
    make_graph(i);
  });

  // connected components
  graph_.connected_components(&values_, &groups_);

  // create layouts
  for(const auto& x: values_)
    create(x.first, x.second);

  // create temporaries
  size_t id = values_.size();
  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
      ir::value *arg = red->get_operand(0);
      distributed_layout *layout =
          dynamic_cast<analysis::distributed_layout *>(get(arg));
      // shape
      auto shapes = arg->get_type()->get_block_shapes();
      unsigned axis = red->get_axis();
      shapes[axis] =
          layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
      // create layout
      id++;
      create_tmp_layout(id, layout, axes_->get(arg), shapes, red);

      if (red->with_index()) {
        id++;
        create_tmp_layout(id, layout, axes_->get(arg), shapes, red, true);
      }
    }
    if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
      distributed_layout* out_layout = dynamic_cast<distributed_layout*>(get(val));
      distributed_layout* in_layout = dynamic_cast<distributed_layout*>(get(i->get_operand(0)));
      size_t dim = val->get_type()->get_tile_rank();
      ir::type::block_shapes_t shape(dim);
      for(size_t k = 0; k < dim; k++){
        shape[k] = std::max(in_layout->shape_per_cta(k),
                            out_layout->shape_per_cta(k));
      }
      auto in_ord = in_layout->get_order();
      auto out_ord = out_layout->get_order();
      int in_vec = in_layout->contig_per_thread(in_ord[0]);
      int out_vec = out_layout->contig_per_thread(out_ord[0]);
      int pad = std::max(in_vec, out_vec);
      shape[out_ord[0]] += pad;
      id++;
      create_tmp_layout(id, out_layout, axes_->get(val), shape, val);
    }
    if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
      id++;
      create_tmp_layout(id, nullptr, {}, {1}, atom);
    }
  });
}

} // namespace analysis
} // namespace codegen
} // namespace triton