#include <fusion.h>
#include <ir_all_nodes.h>
#include <type.h>

#include <dispatch.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

template <typename T>
T* ptr(T& obj) {
  return &obj;
}

template <typename T>
T* ptr(T* obj) {
  return obj;
}

/*
 * Generic dispatch for any handler that does not modify the IR directly.
 * For example, we may want to walk the graph to construct a topologically
 * sorted set of exprs; this does not modify the IR. We also use this to print
 * the IR itself.
 * This dispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   int handler(node_type* node)
 *
 * handler should call:
 *   dispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   int handler(Statement* stmt) {
 *     dispatch(this, stmt);
 *   }
 *
 * and therefore dispatch should never call:
 *   ptr(handler)->handle(this->as<Statement>());
 */
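//
// For illustration only (not part of the dispatch implementation): a minimal
// handler sketch, assuming a hypothetical visitor that counts TensorViews.
// It derives from OptOutConstDispatch (declared in dispatch.h) and overrides
// only the overload it cares about; every other node type falls through to
// unhandled(), which is a no-op for the OptOut variants.
//
//   class TensorViewCounter : public OptOutConstDispatch {
//    public:
//     using OptOutConstDispatch::handle;
//     void handle(const TensorView* tv) override {
//       ++count_;
//     }
//     int count_ = 0;
//   };
//
//   // Usage sketch: counter.handle(stmt) routes through
//   // Statement::constDispatch below, which calls back into the most
//   // specific handle() overload.
//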

template <typename T>
void Val::dispatch(T handler, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(handler)->handle(val->as<Double>());
          return;
        case DataType::Int:
        case DataType::Int32:
          // Dispatch to Int even with Int32 as we don't have Int32 IR
          // node.
          ptr(handler)->handle(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(handler)->handle(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(handler)->handle(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(handler)->handle(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(handler)->handle(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(handler)->handle(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(handler)->handle(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::dispatch(T handler, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::FullOp:
      ptr(handler)->handle(expr->as<FullOp>());
      return;
    case ExprType::ARangeOp:
      ptr(handler)->handle(expr->as<ARangeOp>());
      return;
    case ExprType::EyeOp:
      ptr(handler)->handle(expr->as<EyeOp>());
      return;
    case ExprType::UnaryOp:
      ptr(handler)->handle(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(expr->as<TernaryOp>());
      return;
    case ExprType::RNGOp:
      ptr(handler)->handle(expr->as<RNGOp>());
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(handler)->handle(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(handler)->handle(expr->as<WelfordOp>());
      return;
    case ExprType::GroupedWelfordOp:
      ptr(handler)->handle(expr->as<GroupedWelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(handler)->handle(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(handler)->handle(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(handler)->handle(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(handler)->handle(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(handler)->handle(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(handler)->handle(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(handler)->handle(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(handler)->handle(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(handler)->handle(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(handler)->handle(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(handler)->handle(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(handler)->handle(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(handler)->handle(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(handler)->handle(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(handler)->handle(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(handler)->handle(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(handler)->handle(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(handler)->handle(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(handler)->handle(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(handler)->handle(expr->as<kir::GridWelford>());
      return;
    case ExprType::GroupedGridWelford:
      ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(handler)->handle(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::dispatch(T handler, Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(stmt->as<Val>());
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(stmt->as<Expr>());
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
  }
}

template <typename T>
void Val::constDispatch(T handler, const Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(handler)->handle(val->as<Double>());
          return;
        case DataType::Int:
        case DataType::Int32:
          // Dispatch to Int even with Int32 as we don't have Int32 IR
          // node.
          ptr(handler)->handle(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(handler)->handle(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(handler)->handle(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(handler)->handle(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(handler)->handle(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(handler)->handle(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(handler)->handle(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::constDispatch(T handler, const Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::FullOp:
      ptr(handler)->handle(expr->as<FullOp>());
      return;
    case ExprType::ARangeOp:
      ptr(handler)->handle(expr->as<ARangeOp>());
      return;
    case ExprType::EyeOp:
      ptr(handler)->handle(expr->as<EyeOp>());
      return;
    case ExprType::UnaryOp:
      ptr(handler)->handle(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(expr->as<TernaryOp>());
      return;
    case ExprType::RNGOp:
      ptr(handler)->handle(expr->as<RNGOp>());
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(handler)->handle(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(handler)->handle(expr->as<WelfordOp>());
      return;
    case ExprType::GroupedWelfordOp:
      ptr(handler)->handle(expr->as<GroupedWelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(handler)->handle(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(handler)->handle(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(handler)->handle(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(handler)->handle(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(handler)->handle(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(handler)->handle(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(handler)->handle(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(handler)->handle(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(handler)->handle(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(handler)->handle(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(handler)->handle(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(handler)->handle(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(handler)->handle(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(handler)->handle(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(handler)->handle(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(handler)->handle(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(handler)->handle(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(handler)->handle(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(handler)->handle(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(handler)->handle(expr->as<kir::GridWelford>());
      return;
    case ExprType::GroupedGridWelford:
      ptr(handler)->handle(expr->as<kir::GroupedGridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(handler)->handle(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::constDispatch(T handler, const Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(stmt->as<Val>());
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(stmt->as<Expr>());
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
  }
}

/*
 * Generic mutatorDispatch for any handler that modifies the IR. This could be
 * a transformation on loop structures, or parallelizing a loop. This
 * mutatorDispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   Statement* mutate(node_type* node)
 *
 * mutate should call:
 *   node_to_dispatch->mutatorDispatch(this)
 *
 * It could also implement:
 *   Statement* mutate(Statement* stmt) {
 *     stmt->mutatorDispatch(this);
 *   }
 *
 * and therefore dispatch should never call:
 *   ptr(mutator)->mutate(this->as<Statement>());
 */
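//
// For illustration only: a minimal mutator sketch, assuming a hypothetical
// pass derived from OptOutMutator (declared in dispatch.h) that wants to
// rewrite every Double scalar it visits. Only the overload of interest is
// overridden; all other node types keep the default mutation behavior. The
// exact replacement bookkeeping depends on OptOutMutator's interface and is
// left out here.
//
//   class RewriteDoubles : public OptOutMutator {
//    public:
//     using OptOutMutator::mutate;
//     void mutate(Double* d) override {
//       // Record or register the replacement for d here.
//     }
//   };
//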
template <typename T>
void Val::mutatorDispatch(T mutator, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(mutator)->mutate(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(mutator)->mutate(val->as<Double>());
          return;
        case DataType::Int:
          ptr(mutator)->mutate(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(mutator)->mutate(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(mutator)->mutate(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(mutator)->mutate(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(mutator)->mutate(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(mutator)->mutate(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(mutator)->mutate(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(mutator)->mutate(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(mutator)->mutate(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::mutatorDispatch(T mutator, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::FullOp:
      ptr(mutator)->mutate(expr->as<FullOp>());
      return;
    case ExprType::ARangeOp:
      ptr(mutator)->mutate(expr->as<ARangeOp>());
      return;
    case ExprType::EyeOp:
      ptr(mutator)->mutate(expr->as<EyeOp>());
      return;
    case ExprType::UnaryOp:
      ptr(mutator)->mutate(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(mutator)->mutate(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(mutator)->mutate(expr->as<TernaryOp>());
      return;
    case ExprType::RNGOp:
      ptr(mutator)->mutate(expr->as<RNGOp>());
      return;
    case ExprType::ReductionOp:
      ptr(mutator)->mutate(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(mutator)->mutate(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(mutator)->mutate(expr->as<WelfordOp>());
      return;
    case ExprType::GroupedWelfordOp:
      ptr(mutator)->mutate(expr->as<GroupedWelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(mutator)->mutate(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(mutator)->mutate(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(mutator)->mutate(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(mutator)->mutate(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(mutator)->mutate(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(mutator)->mutate(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(mutator)->mutate(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(mutator)->mutate(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(mutator)->mutate(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(mutator)->mutate(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(mutator)->mutate(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(mutator)->mutate(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(mutator)->mutate(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(mutator)->mutate(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(mutator)->mutate(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(mutator)->mutate(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(mutator)->mutate(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(mutator)->mutate(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(mutator)->mutate(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(mutator)->mutate(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(mutator)->mutate(expr->as<kir::GridWelford>());
      return;
    case ExprType::GroupedGridWelford:
      ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(mutator)->mutate(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(mutator)->mutate(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::mutatorDispatch(T mutator, Statement* stmt) {
  if (stmt->isVal()) {
    ptr(mutator)->mutate(stmt->as<Val>());
    return;
  }
  if (stmt->isExpr()) {
    ptr(mutator)->mutate(stmt->as<Expr>());
    return;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}

/*
 * Handler template instantiations. These should only have to be done on base
 * classes. Actual visitors/mutators should inherit from these classes and call
 * ->dispatch(this) to avoid needing an explicit instantiation.
 */
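//
// For illustration only: if a handler type did not inherit from one of the
// base dispatch classes, it would need its own explicit instantiation in this
// translation unit, since the dispatch templates are defined here, e.g.
// (hypothetical type name):
//
//   template void Expr::dispatch(MyStandaloneHandler&, Expr*);
//
// Deriving from OptOutDispatch / OptOutConstDispatch / OptOutMutator instead
// reuses the instantiations below through the virtual handle()/mutate()
// overloads, so no new instantiation is required.
//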
template void Statement::dispatch(OptOutDispatch&, Statement*);
template void Statement::dispatch(OptOutDispatch*, Statement*);
template void Val::dispatch(OptOutDispatch&, Val*);
template void Val::dispatch(OptOutDispatch*, Val*);
template void Expr::dispatch(OptOutDispatch&, Expr*);
template void Expr::dispatch(OptOutDispatch*, Expr*);

template void Statement::dispatch(OptInDispatch, Statement*);
template void Statement::dispatch(OptInDispatch*, Statement*);
template void Val::dispatch(OptInDispatch, Val*);
template void Val::dispatch(OptInDispatch*, Val*);
template void Expr::dispatch(OptInDispatch, Expr*);
template void Expr::dispatch(OptInDispatch*, Expr*);

template void Statement::constDispatch(OptOutConstDispatch&, const Statement*);
template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
template void Val::constDispatch(OptOutConstDispatch&, const Val*);
template void Val::constDispatch(OptOutConstDispatch*, const Val*);
template void Expr::constDispatch(OptOutConstDispatch&, const Expr*);
template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);

template void Statement::constDispatch(OptInConstDispatch&, const Statement*);
template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
template void Val::constDispatch(OptInConstDispatch&, const Val*);
template void Val::constDispatch(OptInConstDispatch*, const Val*);
template void Expr::constDispatch(OptInConstDispatch&, const Expr*);
template void Expr::constDispatch(OptInConstDispatch*, const Expr*);

template void Statement::mutatorDispatch(OptOutMutator&, Statement*);
template void Statement::mutatorDispatch(OptOutMutator*, Statement*);
template void Val::mutatorDispatch(OptOutMutator&, Val*);
template void Val::mutatorDispatch(OptOutMutator*, Val*);
template void Expr::mutatorDispatch(OptOutMutator&, Expr*);
template void Expr::mutatorDispatch(OptOutMutator*, Expr*);

void OptOutDispatch::handle(Statement* s) {
  Statement::dispatch(this, s);
}

void OptOutDispatch::handle(Expr* e) {
  Expr::dispatch(this, e);
}

void OptOutDispatch::handle(Val* v) {
  Val::dispatch(this, v);
}

void OptOutConstDispatch::handle(const Statement* s) {
  Statement::constDispatch(this, s);
}

void OptOutConstDispatch::handle(const Expr* e) {
  Expr::constDispatch(this, e);
}

void OptOutConstDispatch::handle(const Val* v) {
  Val::constDispatch(this, v);
}

void OptInConstDispatch::unhandled(const Statement* stmt) {
  if (stmt->isExpr()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getExprType().value(), ".");
  } else if (stmt->isVal()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getValType().value(), ".");
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type.");
  }
}

void OptInDispatch::unhandled(Statement* stmt) {
  if (stmt->isExpr()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getExprType().value(), ".");
  } else if (stmt->isVal()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getValType().value(), ".");
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type.");
  }
}

// Vals
void OptOutConstDispatch::handle(const Bool* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Double* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Int* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ComplexDouble* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const NamedScalar* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const IterDomain* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TensorDomain* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TensorView* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const kir::Predicate* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::IntPair* stmt) {
  unhandled(stmt);
}

// Exprs
void OptOutConstDispatch::handle(const FullOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ARangeOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const EyeOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const UnaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const BinaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TernaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const RNGOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const WelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const MmaOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const BroadcastOp* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const Split* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Merge* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TransposeOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ExpandOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ShiftOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const GatherOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ViewAsScalar* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ViewOp* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const kir::Allocate* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::BlockSync* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::UpdateMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::ForLoop* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::IfThenElse* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GroupedGridReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::Swizzle2DInt* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::PairSelect* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::unhandled(Statement*) {}

// Vals
void OptOutDispatch::handle(Bool* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Double* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Int* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ComplexDouble* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(NamedScalar* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(IterDomain* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TensorDomain* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TensorView* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(kir::Predicate* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::TensorIndex* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::IntPair* stmt) {
  unhandled(stmt);
}

// Exprs
void OptOutDispatch::handle(FullOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ARangeOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(EyeOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(UnaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(BinaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TernaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(RNGOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(GroupedReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(WelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(GroupedWelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(LoadStoreOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(MmaOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(BroadcastOp* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(Split* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Merge* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TransposeOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ExpandOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ShiftOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(GatherOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ViewAsScalar* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ViewOp* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(kir::Allocate* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::BlockSync* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridSync* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::UpdateMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::ForLoop* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::IfThenElse* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GroupedGridReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridBroadcast* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridWelford* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::Swizzle2DInt* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::PairSelect* stmt) {
  unhandled(stmt);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch