1 | #include <ir_builder.h> |
2 | #include <kernel_ir_dispatch.h> |
3 | #include <lower_utils.h> |
4 | |
5 | #include <lower_fusion_simplifier.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
12 | namespace { |
13 | |
14 | // Replace trivial reductions with unary ops. |
15 | class TrivialReductionReplacement : private OptOutMutator { |
16 | public: |
17 | TrivialReductionReplacement( |
18 | Fusion* fusion, |
19 | const TrivialReductionInfo& trivial_reduction_info) |
20 | : trivial_reduction_info_(trivial_reduction_info) { |
21 | FusionGuard fg(fusion); |
22 | auto exprs = StmtSort::getExprs(fusion); |
23 | for (auto expr : exprs) { |
24 | mutate(expr); |
25 | } |
26 | } |
27 | |
28 | private: |
29 | using OptOutMutator::mutate; |
30 | void mutate(ReductionOp* rop) final { |
31 | if (ir_utils::isTvOp(rop)) { |
32 | auto out_tv = ir_utils::getTvOutput(rop); |
33 | if (std::all_of( |
34 | out_tv->domain()->domain().begin(), |
35 | out_tv->domain()->domain().end(), |
36 | [&](IterDomain* id) { |
37 | // If id is a reduction axis, is it a trivial reduction? |
38 | if (id->isReduction()) { |
39 | return trivial_reduction_info_.isDerived(id); |
40 | } else { |
41 | return true; |
42 | } |
43 | })) { |
44 | auto out = rop->out(); |
45 | auto in = rop->in(); |
46 | auto container = out->container(); |
47 | removeExpr(container, rop); |
48 | IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in); |
49 | } |
50 | } |
51 | } |
52 | |
53 | void mutate(GroupedReductionOp* grouped_rop) final { |
54 | if (ir_utils::isTvOp(grouped_rop)) { |
55 | // The inputs and outputs are all uniform in grouped reductions, |
56 | // so just checking one of the input and output pair should be |
57 | // sufficient. |
58 | auto out_tv = ir_utils::getTvOutput(grouped_rop); |
59 | if (std::all_of( |
60 | out_tv->domain()->domain().begin(), |
61 | out_tv->domain()->domain().end(), |
62 | [&](IterDomain* id) { |
63 | // If id is a reduction axis, is it a trivial reduction? |
64 | if (id->isReduction()) { |
65 | return trivial_reduction_info_.isDerived(id); |
66 | } else { |
67 | return true; |
68 | } |
69 | })) { |
70 | auto outputs = grouped_rop->outputs(); |
71 | auto inputs = grouped_rop->inputs(); |
72 | auto container = out_tv->container(); |
73 | removeExpr(container, grouped_rop); |
74 | for (const auto i : c10::irange(outputs.size())) { |
75 | IrBuilder::create<UnaryOp>( |
76 | container, UnaryOpType::Set, outputs.at(i), inputs.at(i)); |
77 | } |
78 | } |
79 | } |
80 | } |
81 | |
82 | const TrivialReductionInfo& trivial_reduction_info_; |
83 | }; |
84 | |
85 | // Replaces Transpose, Shift, Gather, and View Ops with Unary Ops. |
86 | class UnaryOpInserter : private kir::ExprMutator { |
87 | public: |
88 | static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) { |
89 | UnaryOpInserter inserter(exprs); |
90 | return inserter.exprs_; |
91 | } |
92 | |
93 | private: |
94 | using kir::ExprMutator::handle; |
95 | |
96 | UnaryOpInserter(const std::vector<Expr*>& exprs) { |
97 | kir::ExprMutator::traverseAndInsert(exprs); |
98 | } |
99 | |
100 | void handle(TransposeOp* top) final { |
101 | auto out = top->out(); |
102 | auto in = top->in(); |
103 | auto container = out->container(); |
104 | registerReplace( |
105 | top, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in)); |
106 | } |
107 | |
108 | void handle(ExpandOp* eop) final { |
109 | auto out = eop->out(); |
110 | auto in = eop->in(); |
111 | auto container = out->container(); |
112 | registerReplace( |
113 | eop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in)); |
114 | } |
115 | |
116 | void handle(ShiftOp* sop) final { |
117 | auto out = sop->out(); |
118 | auto in = sop->in(); |
119 | auto container = out->container(); |
120 | registerReplace( |
121 | sop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in)); |
122 | } |
123 | |
124 | void handle(GatherOp* gop) final { |
125 | auto out = gop->out(); |
126 | auto in = gop->in(); |
127 | auto container = out->container(); |
128 | registerReplace( |
129 | gop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in)); |
130 | } |
131 | |
132 | void handle(ViewOp* vop) final { |
133 | auto out = vop->out(); |
134 | auto in = vop->in(); |
135 | auto container = out->container(); |
136 | registerReplace( |
137 | vop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in)); |
138 | } |
139 | }; |
140 | |
141 | } // namespace |
142 | |
143 | void trivialReductionReplacement( |
144 | Fusion* fusion, |
145 | const TrivialReductionInfo& trivial_reduction_info) { |
146 | TrivialReductionReplacement replacement(fusion, trivial_reduction_info); |
147 | } |
148 | |
149 | // Transpose, Shift, Gather, and View Ops with Unary Set Ops |
150 | std::vector<Expr*> unarySetOpInserter(const std::vector<Expr*>& exprs) { |
151 | return UnaryOpInserter::insert(exprs); |
152 | } |
153 | |
154 | } // namespace cuda |
155 | } // namespace fuser |
156 | } // namespace jit |
157 | } // namespace torch |
158 | |