#include "triton/codegen/analysis/layout.h"
#include "triton/codegen/transform/cts.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include <iostream>
#include <map>
#include <set>
#include <vector>
9
10namespace triton {
11namespace codegen{
12namespace transform{
13
14
15bool cts::is_shmem_op(ir::instruction* i, int op) {
16 if(i->get_id() == ir::INST_DOT)
17 return op == 0 || op == 1;
18 if(i->get_id() == ir::INST_COPY_FROM_SHARED)
19 return op==0;
20 if(i->get_id() == ir::INST_TRANS)
21 return op==0;
22 return false;
23}
24
25bool cts::is_shmem_res(ir::value* v){
26 ir::instruction* i = dynamic_cast<ir::instruction*>(v);
27 if(!i)
28 return false;
29 if(i->get_id() == ir::INST_TRANS)
30 return true;
31 if(i->get_id() == ir::INST_COPY_TO_SHARED)
32 return true;
33 if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
34 return true;
35 return false;
36}
37
38
39// run pass on module
40void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
41 auto *i = dynamic_cast<ir::instruction*>(x);
42 // not an instruction
43 if(!i) {
44 builder.set_insert_point(parent);
45 ir::value *copy;
46 if(to_shared)
47 copy = builder.create_copy_to_shared(x);
48 else
49 copy = builder.create_copy_from_shared(x);
50 parent->replace_uses_of_with(x, copy);
51 return;
52 }
53 // phi node
54 if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
55 for(unsigned i = 0; i < phi->get_num_incoming(); ++i)
56 add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies);
57 return;
58 }
59 // already in shared memory
60 if(to_shared && is_shmem_res(i))
61 return;
62 // copy
63 builder.set_insert_point_after(i);
64 ir::value *copy;
65 if(to_shared){
66 copy = builder.create_copy_to_shared(x);
67 }
68 else
69 copy = builder.create_copy_from_shared(x);
70 copies.insert({x, copy});
71 parent->replace_uses_of_with(x, copies.at(x));
72}
73
74void cts::run(ir::module &mod) {
75 // Precompute where copies should be added
76 std::set<ir::value*> shmem_ops;
77 std::set<ir::value*> shmem_res;
78 ir::for_each_instruction(mod, [&](ir::instruction* i) {
79 if(i->get_id() == ir::INST_DOT){
80 ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
81 ir::value* lhs = i->get_operand(0);
82 ir::type* ty = lhs->get_type()->get_scalar_ty();
83 analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
84 // TODO: V100
85 bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
86 if(is_lhs_shmem)
87 shmem_ops.insert(lhs);
88 shmem_ops.insert(i->get_operand(1));
89 }
90 if(i->get_id() == ir::INST_COPY_FROM_SHARED)
91 shmem_ops.insert(i->get_operand(0));
92 if(i->get_id() == ir::INST_TRANS)
93 shmem_ops.insert(i->get_operand(0));
94 if(i->get_id() == ir::INST_TRANS ||
95 i->get_id() == ir::INST_COPY_TO_SHARED ||
96 i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
97 shmem_res.insert(i);
98 });
99
100 // Add shared copies
101 std::map<ir::value*, ir::value*> copies;
102 ir::builder &builder = mod.get_builder();
103 ir::for_each_instruction(mod, [&](ir::instruction* i) {
104 size_t num_op = i->get_num_operands();
105 for(size_t k = 0; k < num_op; k++){
106 ir::value* op = i->get_operand(k);
107 // copy to shared operands
108 bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end();
109 if(is_shmem_op)
110 add_copy(i, op, builder, true, copies);
111 }
112 });
113}
114
115
116}
117}
118}