1#include "triton/codegen/analysis/swizzle.h"
2#include "triton/codegen/analysis/layout.h"
3#include "triton/codegen/target.h"
4#include "triton/ir/type.h"
5#include <iostream>
6
7namespace triton{
8namespace codegen{
9namespace analysis{
10
11
12void swizzle::run(ir::module &) {
13 per_phase_.clear();
14 max_phase_.clear();
15
16 for(auto &x: layouts_->get_all()){
17 shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
18 if(!layout)
19 continue;
20 ir::value* mma_dot_a = layout->hmma_dot_a();
21 ir::value* mma_dot_b = layout->hmma_dot_b();
22
23 if(!mma_dot_a && !mma_dot_b){
24 per_phase_[layout] = 1;
25 max_phase_[layout] = 1;
26 vec_[layout] = 1;
27 continue;
28 }
29 auto ord = layout->get_order();
30 scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
31 int per_phase = 1;
32 int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
33 if(in_layout)
34 per_phase = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
35 else
36 per_phase = 1;
37 if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){
38 int inner = mma_dot_a ? 0 : 1;
39 per_phase_[layout] = per_phase;
40 max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
41 if(mma_dot_a)
42 vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
43 else
44 vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
45 }
46 else {
47 if (!layout->allow_swizzle()) {
48 per_phase_[layout] = 1;
49 max_phase_[layout] = 1;
50 vec_[layout] = 1;
51 } else {
52 per_phase_[layout] = per_phase;
53 max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
54 vec_[layout] = layout->get_mma_vec();
55 }
56 }
57 }
58}
59
60}
61}
62}
63
64
65