#include "triton/codegen/transform/disassociate.h"
#include "triton/ir/utils.h"
#include "triton/ir/instructions.h"
#include "triton/ir/builder.h"
#include "triton/ir/module.h"
#include <iostream>
#include <map>
#include <set>
7
8namespace triton {
9namespace codegen{
10namespace transform{
11
12ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root,
13 std::set<ir::value*>& seen) {
14 if (dynamic_cast<ir::phi_node*>(root))
15 return root;
16 if(!seen.insert(root).second)
17 return root;
18 if(!root->get_type()->is_block_ty())
19 return root;
20
21 bld.set_insert_point(root);
22 ir::instruction *new_root = bld.insert(root->clone());
23 for(ir::value *op: root->ops()){
24 ir::instruction *i = dynamic_cast<ir::instruction*>(op);
25 if(!i || i->get_id() == ir::INST_REDUCE)
26 continue;
27 ir::instruction* new_op = rematerialize(bld, i, seen);
28 new_root->replace_uses_of_with(op, new_op);
29 }
30 return new_root;
31}
32
33void disassociate::run(ir::module &mod) {
34 ir::builder &bld = mod.get_builder();
35
36// ir::for_each_instruction(mod, [&](ir::instruction *i){
37// bld.set_insert_point(i);
38// for(ir::value* op: i->ops()){
39// auto reshape = dynamic_cast<ir::make_range*>(op);
40// if(!reshape)
41// continue;
42// ir::instruction* new_op = bld.insert(reshape->clone());
43// i->replace_uses_of_with(op, new_op);
44// }
45// });
46
47
48 ir::for_each_instruction(mod, [&](ir::instruction *i){
49 if(dynamic_cast<ir::reshape_inst*>(i) || dynamic_cast<ir::splat_inst*>(i)){
50 std::set<ir::value*> seen;
51 ir::instruction* new_i = rematerialize(bld, i, seen);
52 i->replace_all_uses_with(new_i);
53 }
54 });
55
56
57}
58
59
60}
61}
62}
63