1 | #include "triton/codegen/transform/disassociate.h" |
2 | #include "triton/ir/utils.h" |
3 | #include "triton/ir/instructions.h" |
4 | #include "triton/ir/builder.h" |
5 | #include "triton/ir/module.h" |
6 | #include <iostream> |
7 | |
8 | namespace triton { |
9 | namespace codegen{ |
10 | namespace transform{ |
11 | |
12 | ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root, |
13 | std::set<ir::value*>& seen) { |
14 | if (dynamic_cast<ir::phi_node*>(root)) |
15 | return root; |
16 | if(!seen.insert(root).second) |
17 | return root; |
18 | if(!root->get_type()->is_block_ty()) |
19 | return root; |
20 | |
21 | bld.set_insert_point(root); |
22 | ir::instruction *new_root = bld.insert(root->clone()); |
23 | for(ir::value *op: root->ops()){ |
24 | ir::instruction *i = dynamic_cast<ir::instruction*>(op); |
25 | if(!i || i->get_id() == ir::INST_REDUCE) |
26 | continue; |
27 | ir::instruction* new_op = rematerialize(bld, i, seen); |
28 | new_root->replace_uses_of_with(op, new_op); |
29 | } |
30 | return new_root; |
31 | } |
32 | |
33 | void disassociate::run(ir::module &mod) { |
34 | ir::builder &bld = mod.get_builder(); |
35 | |
36 | // ir::for_each_instruction(mod, [&](ir::instruction *i){ |
37 | // bld.set_insert_point(i); |
38 | // for(ir::value* op: i->ops()){ |
39 | // auto reshape = dynamic_cast<ir::make_range*>(op); |
40 | // if(!reshape) |
41 | // continue; |
42 | // ir::instruction* new_op = bld.insert(reshape->clone()); |
43 | // i->replace_uses_of_with(op, new_op); |
44 | // } |
45 | // }); |
46 | |
47 | |
48 | ir::for_each_instruction(mod, [&](ir::instruction *i){ |
49 | if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){ |
50 | std::set<ir::value*> seen; |
51 | ir::instruction* new_i = rematerialize(bld, i, seen); |
52 | i->replace_all_uses_with(new_i); |
53 | } |
54 | }); |
55 | |
56 | |
57 | } |
58 | |
59 | |
60 | } |
61 | } |
62 | } |
63 | |