1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <c10/util/Exception.h> |
5 | |
6 | #include <utils.h> |
7 | |
8 | #include <unordered_map> |
9 | |
// dispatch.h removes the need to add manual dispatch in every class that
// wants to define how to process a series of nodes. dispatch.h provides 4
// classes that can be inherited, providing a means to override functions on a
// per-node basis. There are currently 4 provided dispatch mechanisms:
14 | // |
15 | // OptOutDispatch: |
16 | // |
17 | // provides the functions: |
18 | // virtual void handle(ValType* irnode){} |
19 | // |
// This provides a mechanism to override handle() for particular node
// types. For example, if we only wanted to actually run a function on
// BinaryOps, we could inherit OptOutDispatch and simply override: void
// handle(BinaryOp*) { doSomething; }. Then we could run through all our
// Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
// encountered, our override function will be called. For every other node,
// nothing will be done.
27 | // |
28 | // OptInDispatch: |
29 | // |
30 | // This class is similar to OptOutDispatch, however if we encounter a node |
31 | // that we haven't specified an override for in the derived class, an error |
32 | // will be thrown. This is useful if we create a class that is expected to |
33 | // handle any type of node it encounters. |
34 | // |
35 | // OptOutMutator: |
36 | // |
37 | // This class is similar to OptOutDispatch except the functions provided are of |
38 | // type: virtual Statement* mutate(Statement*) this is useful for when we want |
39 | // to have an IR node result from our overloaded functions. |
40 | // |
41 | // OptInMutator: |
42 | // |
43 | // This class is similar to OptInDispatch except the functions provided are of |
44 | // type: virtual Statement* mutate(Statement*) this is useful for when we want |
45 | // to have an IR node result from our overloaded functions. |
46 | |
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Forward declarations of every IR node type that the dispatch/mutator
// classes below provide an overload for. Only pointers to these types appear
// in this header, so the full class definitions are not needed here.
class IrContainer;
class Fusion;

// Base classes of the IR hierarchy; dispatch starts here and refines to a
// concrete node type.
class Statement;
class Expr;
class Val;

// Vals
class IterDomain;
class TensorDomain;
class TensorView;

// Scalar Vals
class Bool;
class Double;
class Int;
class ComplexDouble;
class NamedScalar;

// Exprs
class FullOp;
class ARangeOp;
class EyeOp;
class UnaryOp;
class BinaryOp;
class TernaryOp;
class RNGOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class GroupedWelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
class TransposeOp;
class ExpandOp;
class ShiftOp;
class GatherOp;
class ViewAsScalar;
class ViewOp;

// Exprs
class Split;
class Merge;
class Swizzle2D;

// Nodes in the kir namespace. NOTE(review): the name suggests "kernel IR"
// (lowered IR) — confirm against the kir headers.
namespace kir {
class Predicate;
class TensorIndex;
class IntPair;

class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
class IfThenElse;
class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
class Swizzle2DInt;
class PairSelect;

} // namespace kir
121 | |
// By default, all IR nodes are handled in this dispatch, and will call an
// empty function on all nodes. Derived classes override handle() only for the
// node types they care about; every other node type falls through to
// unhandled(), which is a no-op ("opt-out" semantics, see the file-level
// comment above). This variant takes const pointers, so it is suitable for
// read-only traversals.
class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
 protected:
  // Fallback invoked for every node type whose handle() overload is not
  // overridden in a derived class. Intentionally does nothing.
  virtual void unhandled(const Statement*) {}

 public:
  // Hierarchical dispatch functions for handle: calling handle() on a base
  // type (Statement/Expr/Val) forwards to the most specific overload below.
  virtual void handle(const Statement*);
  virtual void handle(const Expr*);
  virtual void handle(const Val*);

  // Vals
  virtual void handle(const IterDomain* stmt);
  virtual void handle(const TensorDomain* stmt);
  virtual void handle(const TensorView* stmt);
  virtual void handle(const Bool* stmt);
  virtual void handle(const Double* stmt);
  virtual void handle(const Int* stmt);
  virtual void handle(const ComplexDouble* stmt);
  virtual void handle(const NamedScalar* stmt);

  // Kernel IR Vals
  virtual void handle(const kir::Predicate*);
  virtual void handle(const kir::TensorIndex*);
  virtual void handle(const kir::IntPair*);

  // Exprs
  virtual void handle(const FullOp* stmt);
  virtual void handle(const ARangeOp* stmt);
  virtual void handle(const EyeOp* stmt);
  virtual void handle(const UnaryOp* stmt);
  virtual void handle(const BinaryOp* stmt);
  virtual void handle(const TernaryOp* stmt);
  virtual void handle(const RNGOp* stmt);
  virtual void handle(const ReductionOp* stmt);
  virtual void handle(const GroupedReductionOp* stmt);
  virtual void handle(const WelfordOp* stmt);
  virtual void handle(const GroupedWelfordOp* stmt);
  virtual void handle(const LoadStoreOp* stmt);
  virtual void handle(const MmaOp* stmt);
  virtual void handle(const BroadcastOp* stmt);

  // IterDomain transformation exprs
  virtual void handle(const Split* stmt);
  virtual void handle(const Merge* stmt);
  virtual void handle(const Swizzle2D* stmt);
  virtual void handle(const TransposeOp* stmt);
  virtual void handle(const ExpandOp* stmt);
  virtual void handle(const ShiftOp* stmt);
  virtual void handle(const GatherOp* stmt);
  virtual void handle(const ViewAsScalar* stmt);
  virtual void handle(const ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(const kir::Allocate*);
  virtual void handle(const kir::BlockSync*);
  virtual void handle(const kir::GridSync*);
  virtual void handle(const kir::CpAsyncWait*);
  virtual void handle(const kir::CpAsyncCommit*);
  virtual void handle(const kir::InitMagicZero*);
  virtual void handle(const kir::UpdateMagicZero*);
  virtual void handle(const kir::ForLoop*);
  virtual void handle(const kir::IfThenElse*);
  virtual void handle(const kir::GridReduction*);
  virtual void handle(const kir::GroupedGridReduction*);
  virtual void handle(const kir::GridBroadcast*);
  virtual void handle(const kir::GridWelford*);
  virtual void handle(const kir::GroupedGridWelford*);
  virtual void handle(const kir::AllocateFusedReduction*);
  virtual void handle(const kir::Swizzle2DInt*);
  virtual void handle(const kir::PairSelect*);
};
192 | |
// Non-const counterpart of OptOutConstDispatch: identical opt-out dispatch
// semantics (unhandled node types are silently skipped — see the file-level
// comment above), but nodes are passed as mutable pointers so handlers may
// modify them in place.
class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
 protected:
  // Fallback for node types without an overridden handle(). Declared here and
  // defined out-of-line; per the opt-out contract it does nothing by default.
  virtual void unhandled(Statement*);

 public:
  // Hierarchical dispatch functions for handle: calling handle() on a base
  // type (Statement/Expr/Val) forwards to the most specific overload below.
  virtual void handle(Statement*);
  virtual void handle(Expr*);
  virtual void handle(Val*);

  // Vals
  virtual void handle(Bool* stmt);
  virtual void handle(Double* stmt);
  virtual void handle(Int* stmt);
  virtual void handle(ComplexDouble* stmt);
  virtual void handle(NamedScalar* stmt);
  virtual void handle(IterDomain* stmt);
  virtual void handle(TensorDomain* stmt);
  virtual void handle(TensorView* stmt);

  // Kernel IR Vals
  virtual void handle(kir::Predicate*);
  virtual void handle(kir::TensorIndex*);
  virtual void handle(kir::IntPair*);

  // Exprs
  virtual void handle(FullOp* stmt);
  virtual void handle(ARangeOp* stmt);
  virtual void handle(EyeOp* stmt);
  virtual void handle(UnaryOp* stmt);
  virtual void handle(BinaryOp* stmt);
  virtual void handle(TernaryOp* stmt);
  virtual void handle(RNGOp* stmt);
  virtual void handle(ReductionOp* stmt);
  virtual void handle(GroupedReductionOp* stmt);
  virtual void handle(WelfordOp* stmt);
  virtual void handle(GroupedWelfordOp* stmt);
  virtual void handle(LoadStoreOp* stmt);
  virtual void handle(MmaOp* stmt);
  virtual void handle(BroadcastOp* stmt);

  // IterDomain transformation exprs
  virtual void handle(Split* stmt);
  virtual void handle(Merge* stmt);
  virtual void handle(Swizzle2D* stmt);
  virtual void handle(TransposeOp* stmt);
  virtual void handle(ExpandOp* stmt);
  virtual void handle(ShiftOp* stmt);
  virtual void handle(GatherOp* stmt);
  virtual void handle(ViewAsScalar* stmt);
  virtual void handle(ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(kir::Allocate* stmt);
  virtual void handle(kir::BlockSync* stmt);
  virtual void handle(kir::GridSync* stmt);
  virtual void handle(kir::CpAsyncWait* stmt);
  virtual void handle(kir::CpAsyncCommit* stmt);
  virtual void handle(kir::InitMagicZero* stmt);
  virtual void handle(kir::UpdateMagicZero* stmt);
  virtual void handle(kir::ForLoop* stmt);
  virtual void handle(kir::IfThenElse* stmt);
  virtual void handle(kir::GridReduction* stmt);
  virtual void handle(kir::GroupedGridReduction* stmt);
  virtual void handle(kir::GridBroadcast* stmt);
  virtual void handle(kir::GridWelford* stmt);
  virtual void handle(kir::GroupedGridWelford* stmt);
  virtual void handle(kir::AllocateFusedReduction* stmt);
  virtual void handle(kir::Swizzle2DInt* stmt);
  virtual void handle(kir::PairSelect* stmt);
};
261 | |
// "Opt-in" const dispatch: inherits all handle() overloads from
// OptOutConstDispatch, but any node type the derived class does not override
// reaches unhandled(), which raises an error (see the OptInDispatch
// description at the top of this file). Useful for visitors that must cover
// every node type they encounter.
class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

 protected:
  // Final override: errors out on any node type the derived class did not
  // explicitly handle, instead of silently skipping it.
  virtual void unhandled(const Statement* stmt) final;
};
269 | |
// Non-const "opt-in" dispatch: inherits all handle() overloads from
// OptOutDispatch, but any node type the derived class does not override
// reaches unhandled(), which raises an error (see the OptInDispatch
// description at the top of this file).
class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch {
 public:
  using OptOutDispatch::handle;

 protected:
  // Final override: errors out on any node type the derived class did not
  // explicitly handle, instead of silently skipping it.
  virtual void unhandled(Statement* stmt) final;
};
277 | |
// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but
// when mutating values they have to be registered through registerMutation so
// that exprs can detect there's been a mutation and know to modify all
// instances of that Val. This means each Val should be mutated "consistently".
// Otherwise behavior may be difficult to understand as it depends on which
// order mutate is called in. This class expects the user to call mutate on the
// statements of interest in topological order, so inputs are mutated before
// the exprs that depend on them.
286 | // |
// Warning: TensorViews need to be treated carefully, as we don't generally
// register their mutation when only their tensor domains change. If a TV needs
// to be swapped out, it needs to be registered as a "proper" mutation like
// other vals, on top of TensorDomain being updated in the mutated TensorView.
291 | // |
292 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
 public:
  // Hierarchical dispatch functions for mutate: calling mutate() on a base
  // type (Statement/Expr/Val) forwards to the most specific overload below.
  virtual void mutate(Statement* s);
  virtual void mutate(Expr* e);
  virtual void mutate(Val* v);

  // Records that `val` should be replaced by `mutation` wherever it is used,
  // so exprs visited later can detect the mutation (see the class comment).
  void registerMutation(Val* val, Val* mutation);

  // Returns the registered replacement for `val`, or `val` itself when no
  // mutation has been registered for it.
  Val* maybeMutated(Val* val) {
    if (mutations.find(val) == mutations.end()) {
      return val;
    }
    return mutations.at(val);
  }

  // Registry mapping each original Val to its registered replacement.
  std::unordered_map<Val*, Val*> mutations;

  //****Functions below defined in mutator.cpp*****

  // Vals
  virtual void mutate(Bool*);
  virtual void mutate(Double*);
  virtual void mutate(Int*);
  virtual void mutate(ComplexDouble*);
  virtual void mutate(NamedScalar*);
  virtual void mutate(IterDomain*);
  virtual void mutate(TensorDomain*);
  virtual void mutate(TensorView*);

  // Kernel IR Vals
  virtual void mutate(kir::Predicate*);
  virtual void mutate(kir::TensorIndex*);
  virtual void mutate(kir::IntPair*);

  // Exprs
  virtual void mutate(FullOp*);
  virtual void mutate(ARangeOp*);
  virtual void mutate(EyeOp*);
  virtual void mutate(UnaryOp*);
  virtual void mutate(BinaryOp*);
  virtual void mutate(TernaryOp*);
  virtual void mutate(RNGOp*);
  virtual void mutate(ReductionOp*);
  virtual void mutate(GroupedReductionOp*);
  virtual void mutate(WelfordOp*);
  virtual void mutate(GroupedWelfordOp*);
  virtual void mutate(LoadStoreOp*);
  virtual void mutate(MmaOp*);
  virtual void mutate(BroadcastOp*);

  // IterDomain transformation exprs
  virtual void mutate(Split*);
  virtual void mutate(Merge*);
  virtual void mutate(Swizzle2D*);
  virtual void mutate(TransposeOp*);
  virtual void mutate(ExpandOp*);
  virtual void mutate(ShiftOp*);
  virtual void mutate(GatherOp*);
  virtual void mutate(ViewAsScalar*);
  virtual void mutate(ViewOp*);

  // Kernel IR exprs
  virtual void mutate(kir::Allocate*);
  virtual void mutate(kir::BlockSync*);
  virtual void mutate(kir::GridSync*);
  virtual void mutate(kir::CpAsyncWait*);
  virtual void mutate(kir::CpAsyncCommit*);
  virtual void mutate(kir::InitMagicZero*);
  virtual void mutate(kir::UpdateMagicZero*);
  virtual void mutate(kir::ForLoop*);
  virtual void mutate(kir::IfThenElse*);
  virtual void mutate(kir::GridReduction*);
  virtual void mutate(kir::GroupedGridReduction*);
  virtual void mutate(kir::GridBroadcast*);
  virtual void mutate(kir::GridWelford*);
  virtual void mutate(kir::GroupedGridWelford*);
  virtual void mutate(kir::AllocateFusedReduction*);
  virtual void mutate(kir::Swizzle2DInt*);
  virtual void mutate(kir::PairSelect*);

 protected:
  // Removes an Expr from the given IrContainer. NOTE(review): defined in
  // mutator.cpp — confirm exact removal semantics there.
  void removeExpr(IrContainer*, Expr*);
};
374 | |
375 | } // namespace cuda |
376 | } // namespace fuser |
377 | } // namespace jit |
378 | } // namespace torch |
379 | |