1 | #include <ir_cloner.h> |
2 | |
3 | #include <fusion.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <ir_builder.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
// All cloned nodes are created in (and owned by) the given IR container.
IrCloner::IrCloner(IrContainer* container) : ir_container_(container) {}
13 | |
14 | Statement* IrCloner::clone(const Statement* statement) { |
15 | if (statement == nullptr) { |
16 | return nullptr; |
17 | } |
18 | |
19 | // Have we already cloned this node? |
20 | const auto it = clones_map_.find(statement); |
21 | if (it != clones_map_.end()) { |
22 | return it->second; |
23 | } else { |
24 | // Clone the new node, saving/restoring this->clone_ |
25 | // since the cloning can be reentrant |
26 | auto saved_clone = clone_; |
27 | handle(statement); |
28 | auto new_node = clone_; |
29 | clone_ = saved_clone; |
30 | |
31 | // The base cloning constructor (Statement) should have |
32 | // registered the new node. Failure to do so indicates |
33 | // that something went horribly wrong. |
34 | TORCH_INTERNAL_ASSERT(new_node != nullptr); |
35 | TORCH_INTERNAL_ASSERT(clones_map_[statement] == new_node); |
36 | |
37 | return new_node; |
38 | } |
39 | } |
40 | |
41 | void IrCloner::registerClone(const Statement* src, Statement* clone) { |
42 | TORCH_CHECK(src != nullptr); |
43 | TORCH_CHECK(clone != nullptr); |
44 | TORCH_CHECK(clones_map_.insert({src, clone}).second); |
45 | } |
46 | |
// Generic dispatch entry points: forward to OptInConstDispatch, which routes
// the node to the concrete handle() overload for its runtime type.
void IrCloner::handle(const Statement* s) {
  OptInConstDispatch::handle(s);
}

void IrCloner::handle(const Val* v) {
  OptInConstDispatch::handle(v);
}

void IrCloner::handle(const Expr* e) {
  OptInConstDispatch::handle(e);
}
58 | |
// Concrete Val handlers: each delegates to IrBuilder::clone, which invokes
// the node's cloning constructor and leaves the new node in clone_.
void IrCloner::handle(const TensorDomain* td) {
  clone_ = IrBuilder::clone(td, this);
}

void IrCloner::handle(const IterDomain* id) {
  clone_ = IrBuilder::clone(id, this);
}

void IrCloner::handle(const Bool* b) {
  clone_ = IrBuilder::clone(b, this);
}

void IrCloner::handle(const Double* d) {
  clone_ = IrBuilder::clone(d, this);
}

void IrCloner::handle(const Int* i) {
  clone_ = IrBuilder::clone(i, this);
}

void IrCloner::handle(const ComplexDouble* c) {
  clone_ = IrBuilder::clone(c, this);
}

void IrCloner::handle(const NamedScalar* named_scalar) {
  clone_ = IrBuilder::clone(named_scalar, this);
}

void IrCloner::handle(const TensorView* tv) {
  clone_ = IrBuilder::clone(tv, this);
}
90 | |
// Concrete Expr and IterDomain-transform handlers: like the Val handlers
// above, each delegates to IrBuilder::clone and stores the result in clone_.
void IrCloner::handle(const FullOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ARangeOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const EyeOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const UnaryOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const BinaryOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const TernaryOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const RNGOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const BroadcastOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ReductionOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const GroupedReductionOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const WelfordOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const LoadStoreOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const MmaOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const TransposeOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ExpandOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ShiftOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const GatherOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ViewAsScalar* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const ViewOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const Split* split) {
  clone_ = IrBuilder::clone(split, this);
}

void IrCloner::handle(const Merge* merge) {
  clone_ = IrBuilder::clone(merge, this);
}

void IrCloner::handle(const Swizzle2D* swizzle) {
  clone_ = IrBuilder::clone(swizzle, this);
}
178 | |
179 | TensorView* RecomputeTv::recompute(TensorView* tv) { |
180 | FusionGuard fg(tv->fusion()); |
181 | |
182 | // Disallow recomputation of inputs or outputs. User would have to be aware of |
183 | // these changes and informed they happened somehow. |
184 | TORCH_INTERNAL_ASSERT( |
185 | !tv->isFusionInput(), |
186 | "Cannot recompute buffers that are inputs of the fusion." ); |
187 | |
188 | // Grab all the expressions used to generate the TensorView |
189 | auto exprs = StmtSort::getExprs(tv->fusion(), {tv}, false); |
190 | |
191 | // Run the replicator |
192 | RecomputeTv replicator(tv->fusion(), exprs); |
193 | |
194 | // Make const version of pointer for lookup |
195 | const auto const_tv = tv; |
196 | // Find the recomputed tensor from the cloner |
197 | auto clone_it = replicator.clones_map_.find(const_tv); |
198 | TORCH_INTERNAL_ASSERT(clone_it != replicator.clones_map_.end()); |
199 | auto cloned_val = clone_it->second; |
200 | TORCH_INTERNAL_ASSERT( |
201 | cloned_val->isA<TensorView>(), |
202 | "Cloned value is somehow not a tensor view." ); |
203 | |
204 | // Return the cloned value |
205 | return cloned_val->as<TensorView>(); |
206 | } |
207 | |
208 | RecomputeTv::RecomputeTv(Fusion* fusion, std::vector<Expr*> exprs) |
209 | : IrCloner(fusion), fusion_(fusion) { |
210 | // Add inputs to the clones map to prevent cloning them. |
211 | for (const auto inp : fusion->inputs()) { |
212 | clones_map_[inp] = inp; |
213 | } |
214 | // Adds all scalar values to clones map to prevent cloning them |
215 | for (const auto val : fusion->vals()) { |
216 | if (val->getValType().value() == ValType::Scalar || |
217 | val->getValType().value() == ValType::NamedScalar) { |
218 | clones_map_[val] = val; |
219 | } |
220 | } |
221 | // Clone the expressions |
222 | for (auto expr : exprs) { |
223 | IrCloner::handle(expr); |
224 | } |
225 | } |
226 | |
227 | void RecomputeTv::handle(const TensorDomain* td) { |
228 | // Make sure to recompute the history of the iteration domains, explicitly go |
229 | // through the expressions and send them to IrCloner. |
230 | auto exprs = |
231 | StmtSort::getExprs(fusion_, {td->domain().begin(), td->domain().end()}); |
232 | |
233 | for (auto expr : exprs) { |
234 | IrCloner::handle(expr); |
235 | } |
236 | IrCloner::handle(td); |
237 | } |
238 | |
239 | } // namespace cuda |
240 | } // namespace fuser |
241 | } // namespace jit |
242 | } // namespace torch |
243 | |