#include <ir_builder.h>
#include <kernel_ir_dispatch.h>
#include <lower_utils.h>

#include <lower_fusion_simplifier.h>

#include <algorithm>
#include <utility>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11
12namespace {
13
14// Replace trivial reductions with unary ops.
15class TrivialReductionReplacement : private OptOutMutator {
16 public:
17 TrivialReductionReplacement(
18 Fusion* fusion,
19 const TrivialReductionInfo& trivial_reduction_info)
20 : trivial_reduction_info_(trivial_reduction_info) {
21 FusionGuard fg(fusion);
22 auto exprs = StmtSort::getExprs(fusion);
23 for (auto expr : exprs) {
24 mutate(expr);
25 }
26 }
27
28 private:
29 using OptOutMutator::mutate;
30 void mutate(ReductionOp* rop) final {
31 if (ir_utils::isTvOp(rop)) {
32 auto out_tv = ir_utils::getTvOutput(rop);
33 if (std::all_of(
34 out_tv->domain()->domain().begin(),
35 out_tv->domain()->domain().end(),
36 [&](IterDomain* id) {
37 // If id is a reduction axis, is it a trivial reduction?
38 if (id->isReduction()) {
39 return trivial_reduction_info_.isDerived(id);
40 } else {
41 return true;
42 }
43 })) {
44 auto out = rop->out();
45 auto in = rop->in();
46 auto container = out->container();
47 removeExpr(container, rop);
48 IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in);
49 }
50 }
51 }
52
53 void mutate(GroupedReductionOp* grouped_rop) final {
54 if (ir_utils::isTvOp(grouped_rop)) {
55 // The inputs and outputs are all uniform in grouped reductions,
56 // so just checking one of the input and output pair should be
57 // sufficient.
58 auto out_tv = ir_utils::getTvOutput(grouped_rop);
59 if (std::all_of(
60 out_tv->domain()->domain().begin(),
61 out_tv->domain()->domain().end(),
62 [&](IterDomain* id) {
63 // If id is a reduction axis, is it a trivial reduction?
64 if (id->isReduction()) {
65 return trivial_reduction_info_.isDerived(id);
66 } else {
67 return true;
68 }
69 })) {
70 auto outputs = grouped_rop->outputs();
71 auto inputs = grouped_rop->inputs();
72 auto container = out_tv->container();
73 removeExpr(container, grouped_rop);
74 for (const auto i : c10::irange(outputs.size())) {
75 IrBuilder::create<UnaryOp>(
76 container, UnaryOpType::Set, outputs.at(i), inputs.at(i));
77 }
78 }
79 }
80 }
81
82 const TrivialReductionInfo& trivial_reduction_info_;
83};
84
85// Replaces Transpose, Shift, Gather, and View Ops with Unary Ops.
86class UnaryOpInserter : private kir::ExprMutator {
87 public:
88 static std::vector<Expr*> insert(const std::vector<Expr*>& exprs) {
89 UnaryOpInserter inserter(exprs);
90 return inserter.exprs_;
91 }
92
93 private:
94 using kir::ExprMutator::handle;
95
96 UnaryOpInserter(const std::vector<Expr*>& exprs) {
97 kir::ExprMutator::traverseAndInsert(exprs);
98 }
99
100 void handle(TransposeOp* top) final {
101 auto out = top->out();
102 auto in = top->in();
103 auto container = out->container();
104 registerReplace(
105 top, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in));
106 }
107
108 void handle(ExpandOp* eop) final {
109 auto out = eop->out();
110 auto in = eop->in();
111 auto container = out->container();
112 registerReplace(
113 eop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in));
114 }
115
116 void handle(ShiftOp* sop) final {
117 auto out = sop->out();
118 auto in = sop->in();
119 auto container = out->container();
120 registerReplace(
121 sop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in));
122 }
123
124 void handle(GatherOp* gop) final {
125 auto out = gop->out();
126 auto in = gop->in();
127 auto container = out->container();
128 registerReplace(
129 gop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in));
130 }
131
132 void handle(ViewOp* vop) final {
133 auto out = vop->out();
134 auto in = vop->in();
135 auto container = out->container();
136 registerReplace(
137 vop, IrBuilder::create<UnaryOp>(container, UnaryOpType::Set, out, in));
138 }
139};
140
141} // namespace
142
143void trivialReductionReplacement(
144 Fusion* fusion,
145 const TrivialReductionInfo& trivial_reduction_info) {
146 TrivialReductionReplacement replacement(fusion, trivial_reduction_info);
147}
148
149// Transpose, Shift, Gather, and View Ops with Unary Set Ops
150std::vector<Expr*> unarySetOpInserter(const std::vector<Expr*>& exprs) {
151 return UnaryOpInserter::insert(exprs);
152}
153
154} // namespace cuda
155} // namespace fuser
156} // namespace jit
157} // namespace torch
158