1#include <ir_cloner.h>
2
3#include <fusion.h>
4#include <ir_all_nodes.h>
5#include <ir_builder.h>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11
12IrCloner::IrCloner(IrContainer* container) : ir_container_(container) {}
13
14Statement* IrCloner::clone(const Statement* statement) {
15 if (statement == nullptr) {
16 return nullptr;
17 }
18
19 // Have we already cloned this node?
20 const auto it = clones_map_.find(statement);
21 if (it != clones_map_.end()) {
22 return it->second;
23 } else {
24 // Clone the new node, saving/restoring this->clone_
25 // since the cloning can be reentrant
26 auto saved_clone = clone_;
27 handle(statement);
28 auto new_node = clone_;
29 clone_ = saved_clone;
30
31 // The base cloning constructor (Statement) should have
32 // registered the new node. Failure to do so indicates
33 // that something went horribly wrong.
34 TORCH_INTERNAL_ASSERT(new_node != nullptr);
35 TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node);
36
37 return new_node;
38 }
39}
40
// Records `clone` as the clone of `src` in clones_map_.
//
// Both pointers must be non-null, and `src` must not already have an entry:
// the insertion is checked and a duplicate registration fails the last
// TORCH_CHECK instead of overwriting the existing mapping.
void IrCloner::registerClone(const Statement* src, Statement* clone) {
  TORCH_CHECK(src != nullptr);
  TORCH_CHECK(clone != nullptr);
  // insert() reports through `.second` whether a new entry was created.
  TORCH_CHECK(clones_map_.insert({src, clone}).second);
}
46
47void IrCloner::handle(const Statement* s) {
48 OptInConstDispatch::handle(s);
49}
50
51void IrCloner::handle(const Val* v) {
52 OptInConstDispatch::handle(v);
53}
54
55void IrCloner::handle(const Expr* e) {
56 OptInConstDispatch::handle(e);
57}
58
59void IrCloner::handle(const TensorDomain* td) {
60 clone_ = IrBuilder::clone(td, this);
61}
62
63void IrCloner::handle(const IterDomain* id) {
64 clone_ = IrBuilder::clone(id, this);
65}
66
67void IrCloner::handle(const Bool* b) {
68 clone_ = IrBuilder::clone(b, this);
69}
70
71void IrCloner::handle(const Double* d) {
72 clone_ = IrBuilder::clone(d, this);
73}
74
75void IrCloner::handle(const Int* i) {
76 clone_ = IrBuilder::clone(i, this);
77}
78
79void IrCloner::handle(const ComplexDouble* c) {
80 clone_ = IrBuilder::clone(c, this);
81}
82
83void IrCloner::handle(const NamedScalar* named_scalar) {
84 clone_ = IrBuilder::clone(named_scalar, this);
85}
86
87void IrCloner::handle(const TensorView* tv) {
88 clone_ = IrBuilder::clone(tv, this);
89}
90
91void IrCloner::handle(const FullOp* op) {
92 clone_ = IrBuilder::clone(op, this);
93}
94
95void IrCloner::handle(const ARangeOp* op) {
96 clone_ = IrBuilder::clone(op, this);
97}
98
99void IrCloner::handle(const EyeOp* op) {
100 clone_ = IrBuilder::clone(op, this);
101}
102
103void IrCloner::handle(const UnaryOp* op) {
104 clone_ = IrBuilder::clone(op, this);
105}
106
107void IrCloner::handle(const BinaryOp* op) {
108 clone_ = IrBuilder::clone(op, this);
109}
110
111void IrCloner::handle(const TernaryOp* op) {
112 clone_ = IrBuilder::clone(op, this);
113}
114
115void IrCloner::handle(const RNGOp* op) {
116 clone_ = IrBuilder::clone(op, this);
117}
118
119void IrCloner::handle(const BroadcastOp* op) {
120 clone_ = IrBuilder::clone(op, this);
121}
122
123void IrCloner::handle(const ReductionOp* op) {
124 clone_ = IrBuilder::clone(op, this);
125}
126
127void IrCloner::handle(const GroupedReductionOp* op) {
128 clone_ = IrBuilder::clone(op, this);
129}
130
131void IrCloner::handle(const WelfordOp* op) {
132 clone_ = IrBuilder::clone(op, this);
133}
134
135void IrCloner::handle(const LoadStoreOp* op) {
136 clone_ = IrBuilder::clone(op, this);
137}
138
139void IrCloner::handle(const MmaOp* op) {
140 clone_ = IrBuilder::clone(op, this);
141}
142
143void IrCloner::handle(const TransposeOp* op) {
144 clone_ = IrBuilder::clone(op, this);
145}
146
147void IrCloner::handle(const ExpandOp* op) {
148 clone_ = IrBuilder::clone(op, this);
149}
150
151void IrCloner::handle(const ShiftOp* op) {
152 clone_ = IrBuilder::clone(op, this);
153}
154
155void IrCloner::handle(const GatherOp* op) {
156 clone_ = IrBuilder::clone(op, this);
157}
158
159void IrCloner::handle(const ViewAsScalar* op) {
160 clone_ = IrBuilder::clone(op, this);
161}
162
163void IrCloner::handle(const ViewOp* op) {
164 clone_ = IrBuilder::clone(op, this);
165}
166
167void IrCloner::handle(const Split* split) {
168 clone_ = IrBuilder::clone(split, this);
169}
170
171void IrCloner::handle(const Merge* merge) {
172 clone_ = IrBuilder::clone(merge, this);
173}
174
175void IrCloner::handle(const Swizzle2D* swizzle) {
176 clone_ = IrBuilder::clone(swizzle, this);
177}
178
179TensorView* RecomputeTv::recompute(TensorView* tv) {
180 FusionGuard fg(tv->fusion());
181
182 // Disallow recomputation of inputs or outputs. User would have to be aware of
183 // these changes and informed they happened somehow.
184 TORCH_INTERNAL_ASSERT(
185 !tv->isFusionInput(),
186 "Cannot recompute buffers that are inputs of the fusion.");
187
188 // Grab all the expressions used to generate the TensorView
189 auto exprs = StmtSort::getExprs(tv->fusion(), {tv}, false);
190
191 // Run the replicator
192 RecomputeTv replicator(tv->fusion(), exprs);
193
194 // Make const version of pointer for lookup
195 const auto const_tv = tv;
196 // Find the recomputed tensor from the cloner
197 auto clone_it = replicator.clones_map_.find(const_tv);
198 TORCH_INTERNAL_ASSERT(clone_it != replicator.clones_map_.end());
199 auto cloned_val = clone_it->second;
200 TORCH_INTERNAL_ASSERT(
201 cloned_val->isA<TensorView>(),
202 "Cloned value is somehow not a tensor view.");
203
204 // Return the cloned value
205 return cloned_val->as<TensorView>();
206}
207
208RecomputeTv::RecomputeTv(Fusion* fusion, std::vector<Expr*> exprs)
209 : IrCloner(fusion), fusion_(fusion) {
210 // Add inputs to the clones map to prevent cloning them.
211 for (const auto inp : fusion->inputs()) {
212 clones_map_[inp] = inp;
213 }
214 // Adds all scalar values to clones map to prevent cloning them
215 for (const auto val : fusion->vals()) {
216 if (val->getValType().value() == ValType::Scalar ||
217 val->getValType().value() == ValType::NamedScalar) {
218 clones_map_[val] = val;
219 }
220 }
221 // Clone the expressions
222 for (auto expr : exprs) {
223 IrCloner::handle(expr);
224 }
225}
226
227void RecomputeTv::handle(const TensorDomain* td) {
228 // Make sure to recompute the history of the iteration domains, explicitly go
229 // through the expressions and send them to IrCloner.
230 auto exprs =
231 StmtSort::getExprs(fusion_, {td->domain().begin(), td->domain().end()});
232
233 for (auto expr : exprs) {
234 IrCloner::handle(expr);
235 }
236 IrCloner::handle(td);
237}
238
239} // namespace cuda
240} // namespace fuser
241} // namespace jit
242} // namespace torch
243