#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
#include <map>
#include <set>
#include <vector>
9 | |
10 | namespace triton { |
11 | namespace codegen{ |
12 | namespace transform{ |
13 | |
14 | |
15 | bool cts::is_shmem_op(ir::instruction* i, int op) { |
16 | if(i->get_id() == ir::INST_DOT) |
17 | return op == 0 || op == 1; |
18 | if(i->get_id() == ir::INST_COPY_FROM_SHARED) |
19 | return op==0; |
20 | if(i->get_id() == ir::INST_TRANS) |
21 | return op==0; |
22 | return false; |
23 | } |
24 | |
25 | bool cts::is_shmem_res(ir::value* v){ |
26 | ir::instruction* i = dynamic_cast<ir::instruction*>(v); |
27 | if(!i) |
28 | return false; |
29 | if(i->get_id() == ir::INST_TRANS) |
30 | return true; |
31 | if(i->get_id() == ir::INST_COPY_TO_SHARED) |
32 | return true; |
33 | if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC) |
34 | return true; |
35 | return false; |
36 | } |
37 | |
38 | |
39 | // run pass on module |
40 | void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) { |
41 | auto *i = dynamic_cast<ir::instruction*>(x); |
42 | // not an instruction |
43 | if(!i) { |
44 | builder.set_insert_point(parent); |
45 | ir::value *copy; |
46 | if(to_shared) |
47 | copy = builder.create_copy_to_shared(x); |
48 | else |
49 | copy = builder.create_copy_from_shared(x); |
50 | parent->replace_uses_of_with(x, copy); |
51 | return; |
52 | } |
53 | // phi node |
54 | if(auto* phi = dynamic_cast<ir::phi_node*>(x)) { |
55 | for(unsigned i = 0; i < phi->get_num_incoming(); ++i) |
56 | add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies); |
57 | return; |
58 | } |
59 | // already in shared memory |
60 | if(to_shared && is_shmem_res(i)) |
61 | return; |
62 | // copy |
63 | builder.set_insert_point_after(i); |
64 | ir::value *copy; |
65 | if(to_shared){ |
66 | copy = builder.create_copy_to_shared(x); |
67 | } |
68 | else |
69 | copy = builder.create_copy_from_shared(x); |
70 | copies.insert({x, copy}); |
71 | parent->replace_uses_of_with(x, copies.at(x)); |
72 | } |
73 | |
74 | void cts::run(ir::module &mod) { |
75 | // Precompute where copies should be added |
76 | std::set<ir::value*> shmem_ops; |
77 | std::set<ir::value*> shmem_res; |
78 | ir::for_each_instruction(mod, [&](ir::instruction* i) { |
79 | if(i->get_id() == ir::INST_DOT){ |
80 | ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i); |
81 | ir::value* lhs = i->get_operand(0); |
82 | ir::type* ty = lhs->get_type()->get_scalar_ty(); |
83 | analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma(); |
84 | // TODO: V100 |
85 | bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a()); |
86 | if(is_lhs_shmem) |
87 | shmem_ops.insert(lhs); |
88 | shmem_ops.insert(i->get_operand(1)); |
89 | } |
90 | if(i->get_id() == ir::INST_COPY_FROM_SHARED) |
91 | shmem_ops.insert(i->get_operand(0)); |
92 | if(i->get_id() == ir::INST_TRANS) |
93 | shmem_ops.insert(i->get_operand(0)); |
94 | if(i->get_id() == ir::INST_TRANS || |
95 | i->get_id() == ir::INST_COPY_TO_SHARED || |
96 | i->get_id() == ir::INST_MASKED_LOAD_ASYNC) |
97 | shmem_res.insert(i); |
98 | }); |
99 | |
100 | // Add shared copies |
101 | std::map<ir::value*, ir::value*> copies; |
102 | ir::builder &builder = mod.get_builder(); |
103 | ir::for_each_instruction(mod, [&](ir::instruction* i) { |
104 | size_t num_op = i->get_num_operands(); |
105 | for(size_t k = 0; k < num_op; k++){ |
106 | ir::value* op = i->get_operand(k); |
107 | // copy to shared operands |
108 | bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end(); |
109 | if(is_shmem_op) |
110 | add_copy(i, op, builder, true, copies); |
111 | } |
112 | }); |
113 | } |
114 | |
115 | |
116 | } |
117 | } |
118 | } |