1#pragma once
2
3#include <c10/macros/Export.h>
4#include <c10/util/Exception.h>
5
6#include <utils.h>
7
8#include <unordered_map>
9
// dispatch.h removes the need to add manual dispatch in every class that
// wants to define how to process a series of nodes. dispatch.h provides
// classes that can be inherited, providing a means to override functions on a
// per-node basis. The provided dispatch mechanisms are:
14//
15// OptOutDispatch:
16//
17// provides the functions:
18// virtual void handle(ValType* irnode){}
19//
// This provides a mechanism to override this handle for particular node
21// types. For example if we only wanted to actually run a function on
22// BinaryOps, we could inherit OptOutDispatch and simply override: void
23// handle(BinaryOp*) { doSomething; } Then we could run through all our
24// Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
25// encountered our override function will be called. For every other node,
26// nothing will be done.
27//
28// OptInDispatch:
29//
30// This class is similar to OptOutDispatch, however if we encounter a node
31// that we haven't specified an override for in the derived class, an error
32// will be thrown. This is useful if we create a class that is expected to
33// handle any type of node it encounters.
34//
35// OptOutMutator:
36//
37// This class is similar to OptOutDispatch except the functions provided are of
38// type: virtual Statement* mutate(Statement*) this is useful for when we want
39// to have an IR node result from our overloaded functions.
40//
// OptInMutator:
//
// This class is similar to OptInDispatch except the functions provided are of
// type: virtual Statement* mutate(Statement*) this is useful for when we want
// to have an IR node result from our overloaded functions.
//
// NOTE(review): OptInMutator is described above but is not declared in this
// file — confirm it still exists elsewhere, or remove this section.
46
47namespace torch {
48namespace jit {
49namespace fuser {
50namespace cuda {
51class IrContainer;
52class Fusion;
53
// Hierarchical dispatch functions for handle
55class Statement;
56class Expr;
57class Val;
58
59// Vals
60class IterDomain;
61class TensorDomain;
62class TensorView;
63
64class Bool;
65class Double;
66class Int;
67class ComplexDouble;
68class NamedScalar;
69
70// Exprs
71class FullOp;
72class ARangeOp;
73class EyeOp;
74class UnaryOp;
75class BinaryOp;
76class TernaryOp;
77class RNGOp;
78class ReductionOp;
79class GroupedReductionOp;
80class WelfordOp;
81class GroupedWelfordOp;
82class LoadStoreOp;
83class MmaOp;
84class BroadcastOp;
85class TransposeOp;
86class ExpandOp;
87class ShiftOp;
88class GatherOp;
89class ViewAsScalar;
90class ViewOp;
91
92// Exprs
93class Split;
94class Merge;
95class Swizzle2D;
96
97namespace kir {
98class Predicate;
99class TensorIndex;
100class IntPair;
101
102class Allocate;
103class BlockSync;
104class GridSync;
105class CpAsyncWait;
106class CpAsyncCommit;
107class ForLoop;
108class IfThenElse;
109class GridReduction;
110class GroupedGridReduction;
111class GridBroadcast;
112class GridWelford;
113class GroupedGridWelford;
114class AllocateFusedReduction;
115class InitMagicZero;
116class UpdateMagicZero;
117class Swizzle2DInt;
118class PairSelect;
119
120} // namespace kir
121
122// By default, all IR nodes are handled in this dispatch, and will call an empty
123// function on all nodes.
class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
 protected:
  // Fallback for node types the derived class did not override; a no-op, so
  // unhandled nodes are silently skipped ("opt out" semantics).
  virtual void unhandled(const Statement*) {}

 public:
  // Hierarchical dispatch functions for handle: resolve the concrete node
  // type and forward to the most specific overload below.
  virtual void handle(const Statement*);
  virtual void handle(const Expr*);
  virtual void handle(const Val*);

  // Vals
  virtual void handle(const IterDomain* stmt);
  virtual void handle(const TensorDomain* stmt);
  virtual void handle(const TensorView* stmt);
  virtual void handle(const Bool* stmt);
  virtual void handle(const Double* stmt);
  virtual void handle(const Int* stmt);
  virtual void handle(const ComplexDouble* stmt);
  virtual void handle(const NamedScalar* stmt);

  // Kernel IR vals
  virtual void handle(const kir::Predicate*);
  virtual void handle(const kir::TensorIndex*);
  virtual void handle(const kir::IntPair*);

  // Exprs
  virtual void handle(const FullOp* stmt);
  virtual void handle(const ARangeOp* stmt);
  virtual void handle(const EyeOp* stmt);
  virtual void handle(const UnaryOp* stmt);
  virtual void handle(const BinaryOp* stmt);
  virtual void handle(const TernaryOp* stmt);
  virtual void handle(const RNGOp* stmt);
  virtual void handle(const ReductionOp* stmt);
  virtual void handle(const GroupedReductionOp* stmt);
  virtual void handle(const WelfordOp* stmt);
  virtual void handle(const GroupedWelfordOp* stmt);
  virtual void handle(const LoadStoreOp* stmt);
  virtual void handle(const MmaOp* stmt);
  virtual void handle(const BroadcastOp* stmt);

  // IterDomain exprs
  virtual void handle(const Split* stmt);
  virtual void handle(const Merge* stmt);
  virtual void handle(const Swizzle2D* stmt);
  virtual void handle(const TransposeOp* stmt);
  virtual void handle(const ExpandOp* stmt);
  virtual void handle(const ShiftOp* stmt);
  virtual void handle(const GatherOp* stmt);
  virtual void handle(const ViewAsScalar* stmt);
  virtual void handle(const ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(const kir::Allocate*);
  virtual void handle(const kir::BlockSync*);
  virtual void handle(const kir::GridSync*);
  virtual void handle(const kir::CpAsyncWait*);
  virtual void handle(const kir::CpAsyncCommit*);
  virtual void handle(const kir::InitMagicZero*);
  virtual void handle(const kir::UpdateMagicZero*);
  virtual void handle(const kir::ForLoop*);
  virtual void handle(const kir::IfThenElse*);
  virtual void handle(const kir::GridReduction*);
  virtual void handle(const kir::GroupedGridReduction*);
  virtual void handle(const kir::GridBroadcast*);
  virtual void handle(const kir::GridWelford*);
  virtual void handle(const kir::GroupedGridWelford*);
  virtual void handle(const kir::AllocateFusedReduction*);
  virtual void handle(const kir::Swizzle2DInt*);
  virtual void handle(const kir::PairSelect*);
};
192
// Non-const counterpart of OptOutConstDispatch: same dispatch structure, but
// handlers receive mutable node pointers.
class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
 protected:
  // Fallback for node types the derived class did not override ("opt out"
  // semantics). Unlike the const variant, the definition is not inline here.
  virtual void unhandled(Statement*);

 public:
  // Hierarchical dispatch functions for handle: resolve the concrete node
  // type and forward to the most specific overload below.
  virtual void handle(Statement*);
  virtual void handle(Expr*);
  virtual void handle(Val*);

  // Vals
  virtual void handle(Bool* stmt);
  virtual void handle(Double* stmt);
  virtual void handle(Int* stmt);
  virtual void handle(ComplexDouble* stmt);
  virtual void handle(NamedScalar* stmt);
  virtual void handle(IterDomain* stmt);
  virtual void handle(TensorDomain* stmt);
  virtual void handle(TensorView* stmt);

  // Kernel IR vals
  virtual void handle(kir::Predicate*);
  virtual void handle(kir::TensorIndex*);
  virtual void handle(kir::IntPair*);

  // Exprs
  virtual void handle(FullOp* stmt);
  virtual void handle(ARangeOp* stmt);
  virtual void handle(EyeOp* stmt);
  virtual void handle(UnaryOp* stmt);
  virtual void handle(BinaryOp* stmt);
  virtual void handle(TernaryOp* stmt);
  virtual void handle(RNGOp* stmt);
  virtual void handle(ReductionOp* stmt);
  virtual void handle(GroupedReductionOp* stmt);
  virtual void handle(WelfordOp* stmt);
  virtual void handle(GroupedWelfordOp* stmt);
  virtual void handle(LoadStoreOp* stmt);
  virtual void handle(MmaOp* stmt);
  virtual void handle(BroadcastOp* stmt);

  // IterDomain exprs
  virtual void handle(Split* stmt);
  virtual void handle(Merge* stmt);
  virtual void handle(Swizzle2D* stmt);
  virtual void handle(TransposeOp* stmt);
  virtual void handle(ExpandOp* stmt);
  virtual void handle(ShiftOp* stmt);
  virtual void handle(GatherOp* stmt);
  virtual void handle(ViewAsScalar* stmt);
  virtual void handle(ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(kir::Allocate* stmt);
  virtual void handle(kir::BlockSync* stmt);
  virtual void handle(kir::GridSync* stmt);
  virtual void handle(kir::CpAsyncWait* stmt);
  virtual void handle(kir::CpAsyncCommit* stmt);
  virtual void handle(kir::InitMagicZero* stmt);
  virtual void handle(kir::UpdateMagicZero* stmt);
  virtual void handle(kir::ForLoop* stmt);
  virtual void handle(kir::IfThenElse* stmt);
  virtual void handle(kir::GridReduction* stmt);
  virtual void handle(kir::GroupedGridReduction* stmt);
  virtual void handle(kir::GridBroadcast* stmt);
  virtual void handle(kir::GridWelford* stmt);
  virtual void handle(kir::GroupedGridWelford* stmt);
  virtual void handle(kir::AllocateFusedReduction* stmt);
  virtual void handle(kir::Swizzle2DInt* stmt);
  virtual void handle(kir::PairSelect* stmt);
};
261
// "Opt in" const dispatch: inherits every handle() overload from
// OptOutConstDispatch, but per the file header an error is raised for any
// node type the derived class did not explicitly handle.
class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

 protected:
  // Final override of the fallback: derived classes can no longer opt out of
  // unhandled node types. Defined outside this header.
  virtual void unhandled(const Statement* stmt) final;
};
269
// "Opt in" non-const dispatch: inherits every handle() overload from
// OptOutDispatch, but per the file header an error is raised for any node
// type the derived class did not explicitly handle.
class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch {
 public:
  using OptOutDispatch::handle;

 protected:
  // Final override of the fallback: derived classes can no longer opt out of
  // unhandled node types. Defined outside this header.
  virtual void unhandled(Statement* stmt) final;
};
277
// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but
// when mutating values they have to be registered through registerMutation so
// that exprs can detect there's been a mutation and know to modify all
// instances of that Val. This means each Val should be mutated "consistently".
// Otherwise behavior may be difficult to understand as it depends on which
// order mutate is called in. This class expects the user to call mutate on
// the statements of interest in topological order, so inputs are mutated
// before the exprs depending on them.
286//
287// Warning: TensorViews need to be treated carefully. As we don't generally
288// register their mutation when their tensor domains only change. If a TV needs
289// to be swapped out, it needs to be registered as a "proper" mutation like
290// other vals, on top of TensorDomain being updated in the mutated TensorView.
291//
292// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
293class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
294 public:
295 // Hierarchal dispatch functions for handle
296 virtual void mutate(Statement* s);
297 virtual void mutate(Expr* e);
298 virtual void mutate(Val* v);
299
300 void registerMutation(Val* val, Val* mutation);
301
302 Val* maybeMutated(Val* val) {
303 if (mutations.find(val) == mutations.end()) {
304 return val;
305 }
306 return mutations.at(val);
307 }
308
309 std::unordered_map<Val*, Val*> mutations;
310
311 //****Functions below defined in mutator.cpp*****
312
313 // Vals
314 virtual void mutate(Bool*);
315 virtual void mutate(Double*);
316 virtual void mutate(Int*);
317 virtual void mutate(ComplexDouble*);
318 virtual void mutate(NamedScalar*);
319 virtual void mutate(IterDomain*);
320 virtual void mutate(TensorDomain*);
321 virtual void mutate(TensorView*);
322
323 virtual void mutate(kir::Predicate*);
324 virtual void mutate(kir::TensorIndex*);
325 virtual void mutate(kir::IntPair*);
326
327 // Exprs
328 virtual void mutate(FullOp*);
329 virtual void mutate(ARangeOp*);
330 virtual void mutate(EyeOp*);
331 virtual void mutate(UnaryOp*);
332 virtual void mutate(BinaryOp*);
333 virtual void mutate(TernaryOp*);
334 virtual void mutate(RNGOp*);
335 virtual void mutate(ReductionOp*);
336 virtual void mutate(GroupedReductionOp*);
337 virtual void mutate(WelfordOp*);
338 virtual void mutate(GroupedWelfordOp*);
339 virtual void mutate(LoadStoreOp*);
340 virtual void mutate(MmaOp*);
341 virtual void mutate(BroadcastOp*);
342
343 virtual void mutate(Split*);
344 virtual void mutate(Merge*);
345 virtual void mutate(Swizzle2D*);
346 virtual void mutate(TransposeOp*);
347 virtual void mutate(ExpandOp*);
348 virtual void mutate(ShiftOp*);
349 virtual void mutate(GatherOp*);
350 virtual void mutate(ViewAsScalar*);
351 virtual void mutate(ViewOp*);
352
353 virtual void mutate(kir::Allocate*);
354 virtual void mutate(kir::BlockSync*);
355 virtual void mutate(kir::GridSync*);
356 virtual void mutate(kir::CpAsyncWait*);
357 virtual void mutate(kir::CpAsyncCommit*);
358 virtual void mutate(kir::InitMagicZero*);
359 virtual void mutate(kir::UpdateMagicZero*);
360 virtual void mutate(kir::ForLoop*);
361 virtual void mutate(kir::IfThenElse*);
362 virtual void mutate(kir::GridReduction*);
363 virtual void mutate(kir::GroupedGridReduction*);
364 virtual void mutate(kir::GridBroadcast*);
365 virtual void mutate(kir::GridWelford*);
366 virtual void mutate(kir::GroupedGridWelford*);
367 virtual void mutate(kir::AllocateFusedReduction*);
368 virtual void mutate(kir::Swizzle2DInt*);
369 virtual void mutate(kir::PairSelect*);
370
371 protected:
372 void removeExpr(IrContainer*, Expr*);
373};
374
375} // namespace cuda
376} // namespace fuser
377} // namespace jit
378} // namespace torch
379