1 | #include <fusion.h> |
2 | #include <ir_all_nodes.h> |
3 | #include <type.h> |
4 | |
5 | #include <dispatch.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
// Yields a pointer to an object received by reference (or by value into a
// local). Lets the dispatch templates below accept a handler either as an
// object or as a pointer and always call through a pointer.
template <typename T>
auto ptr(T& obj) -> T* {
  T* const address = &obj;
  return address;
}
16 | |
// Identity overload for handlers that are already pointers. Because it is
// more specialized than ptr(T&), it is preferred whenever a pointer is passed,
// so callers get the pointer itself rather than a pointer-to-pointer.
template <typename T>
auto ptr(T* obj) -> T* {
  return obj;
}
21 | |
/*
 * Generic dispatch for any handler that does not modify the IR directly.
 * For example, we may want to walk the graph to construct a topologically
 * sorted set of exprs; this does not modify the IR. We also use this to print
 * the IR itself.
 *
 * This dispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   void handle(node_type* node)
 *
 * handle should call:
 *   dispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   void handle(Statement* stmt) {
 *     dispatch(this, stmt);
 *   }
 *
 * Therefore dispatch must never call:
 *   ptr(handler)->handle(stmt->as<Statement>());
 * because handle(Statement*) may itself call dispatch, and the two would
 * recurse forever.
 */
42 | |
43 | template <typename T> |
44 | void Val::dispatch(T handler, Val* val) { |
45 | switch (*(val->getValType())) { |
46 | case ValType::Scalar: |
47 | switch (*(val->getDataType())) { |
48 | case DataType::Bool: |
49 | ptr(handler)->handle(val->as<Bool>()); |
50 | return; |
51 | case DataType::Double: |
52 | ptr(handler)->handle(val->as<Double>()); |
53 | return; |
54 | case DataType::Int: |
55 | case DataType::Int32: |
56 | // Dispatch to Int even with Int32 as we don't have Int32 IR |
57 | // node. |
58 | ptr(handler)->handle(val->as<Int>()); |
59 | return; |
60 | case DataType::ComplexDouble: |
61 | ptr(handler)->handle(val->as<ComplexDouble>()); |
62 | return; |
63 | default: |
64 | break; |
65 | } |
66 | break; |
67 | case ValType::NamedScalar: |
68 | ptr(handler)->handle(val->as<NamedScalar>()); |
69 | return; |
70 | |
71 | case ValType::IterDomain: |
72 | ptr(handler)->handle(val->as<IterDomain>()); |
73 | return; |
74 | case ValType::TensorDomain: |
75 | ptr(handler)->handle(val->as<TensorDomain>()); |
76 | return; |
77 | case ValType::TensorView: |
78 | ptr(handler)->handle(val->as<TensorView>()); |
79 | return; |
80 | case ValType::Predicate: |
81 | ptr(handler)->handle(val->as<kir::Predicate>()); |
82 | return; |
83 | case ValType::TensorIndex: |
84 | ptr(handler)->handle(val->as<kir::TensorIndex>()); |
85 | return; |
86 | case ValType::IntPair: |
87 | ptr(handler)->handle(val->as<kir::IntPair>()); |
88 | return; |
89 | default: |
90 | break; |
91 | } |
92 | TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!" ); |
93 | } |
94 | |
95 | template <typename T> |
96 | void Expr::dispatch(T handler, Expr* expr) { |
97 | switch (*(expr->getExprType())) { |
98 | case ExprType::FullOp: |
99 | ptr(handler)->handle(expr->as<FullOp>()); |
100 | return; |
101 | case ExprType::ARangeOp: |
102 | ptr(handler)->handle(expr->as<ARangeOp>()); |
103 | return; |
104 | case ExprType::EyeOp: |
105 | ptr(handler)->handle(expr->as<EyeOp>()); |
106 | return; |
107 | case ExprType::UnaryOp: |
108 | ptr(handler)->handle(expr->as<UnaryOp>()); |
109 | return; |
110 | case ExprType::BinaryOp: |
111 | ptr(handler)->handle(expr->as<BinaryOp>()); |
112 | return; |
113 | case ExprType::TernaryOp: |
114 | ptr(handler)->handle(expr->as<TernaryOp>()); |
115 | return; |
116 | case ExprType::RNGOp: |
117 | ptr(handler)->handle(expr->as<RNGOp>()); |
118 | return; |
119 | case ExprType::ReductionOp: |
120 | ptr(handler)->handle(expr->as<ReductionOp>()); |
121 | return; |
122 | case ExprType::GroupedReductionOp: |
123 | ptr(handler)->handle(expr->as<GroupedReductionOp>()); |
124 | return; |
125 | case ExprType::WelfordOp: |
126 | ptr(handler)->handle(expr->as<WelfordOp>()); |
127 | return; |
128 | case ExprType::GroupedWelfordOp: |
129 | ptr(handler)->handle(expr->as<GroupedWelfordOp>()); |
130 | return; |
131 | case ExprType::LoadStoreOp: |
132 | ptr(handler)->handle(expr->as<LoadStoreOp>()); |
133 | return; |
134 | case ExprType::MmaOp: |
135 | ptr(handler)->handle(expr->as<MmaOp>()); |
136 | return; |
137 | case ExprType::BroadcastOp: |
138 | ptr(handler)->handle(expr->as<BroadcastOp>()); |
139 | return; |
140 | |
141 | case ExprType::Split: |
142 | ptr(handler)->handle(expr->as<Split>()); |
143 | return; |
144 | case ExprType::Merge: |
145 | ptr(handler)->handle(expr->as<Merge>()); |
146 | return; |
147 | case ExprType::Swizzle2D: |
148 | ptr(handler)->handle(expr->as<Swizzle2D>()); |
149 | return; |
150 | case ExprType::TransposeOp: |
151 | ptr(handler)->handle(expr->as<TransposeOp>()); |
152 | return; |
153 | case ExprType::ExpandOp: |
154 | ptr(handler)->handle(expr->as<ExpandOp>()); |
155 | return; |
156 | case ExprType::ShiftOp: |
157 | ptr(handler)->handle(expr->as<ShiftOp>()); |
158 | return; |
159 | case ExprType::GatherOp: |
160 | ptr(handler)->handle(expr->as<GatherOp>()); |
161 | return; |
162 | case ExprType::ViewAsScalar: |
163 | ptr(handler)->handle(expr->as<ViewAsScalar>()); |
164 | return; |
165 | case ExprType::ViewOp: |
166 | ptr(handler)->handle(expr->as<ViewOp>()); |
167 | return; |
168 | |
169 | case ExprType::Allocate: |
170 | ptr(handler)->handle(expr->as<kir::Allocate>()); |
171 | return; |
172 | case ExprType::BlockSync: |
173 | ptr(handler)->handle(expr->as<kir::BlockSync>()); |
174 | return; |
175 | case ExprType::GridSync: |
176 | ptr(handler)->handle(expr->as<kir::GridSync>()); |
177 | return; |
178 | case ExprType::CpAsyncWait: |
179 | ptr(handler)->handle(expr->as<kir::CpAsyncWait>()); |
180 | return; |
181 | case ExprType::CpAsyncCommit: |
182 | ptr(handler)->handle(expr->as<kir::CpAsyncCommit>()); |
183 | return; |
184 | case ExprType::InitMagicZero: |
185 | ptr(handler)->handle(expr->as<kir::InitMagicZero>()); |
186 | return; |
187 | case ExprType::UpdateMagicZero: |
188 | ptr(handler)->handle(expr->as<kir::UpdateMagicZero>()); |
189 | return; |
190 | case ExprType::ForLoop: |
191 | ptr(handler)->handle(expr->as<kir::ForLoop>()); |
192 | return; |
193 | case ExprType::IfThenElse: |
194 | ptr(handler)->handle(expr->as<kir::IfThenElse>()); |
195 | return; |
196 | case ExprType::GridReduction: |
197 | ptr(handler)->handle(expr->as<kir::GridReduction>()); |
198 | return; |
199 | case ExprType::GroupedGridReduction: |
200 | ptr(handler)->handle(expr->as<kir::GroupedGridReduction>()); |
201 | return; |
202 | case ExprType::GridBroadcast: |
203 | ptr(handler)->handle(expr->as<kir::GridBroadcast>()); |
204 | return; |
205 | case ExprType::GridWelford: |
206 | ptr(handler)->handle(expr->as<kir::GridWelford>()); |
207 | return; |
208 | case ExprType::GroupedGridWelford: |
209 | ptr(handler)->handle(expr->as<kir::GroupedGridWelford>()); |
210 | return; |
211 | case ExprType::AllocateFusedReduction: |
212 | ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>()); |
213 | return; |
214 | case ExprType::Swizzle2DInt: |
215 | ptr(handler)->handle(expr->as<kir::Swizzle2DInt>()); |
216 | return; |
217 | case ExprType::PairSelect: |
218 | ptr(handler)->handle(expr->as<kir::PairSelect>()); |
219 | return; |
220 | default: |
221 | TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!" ); |
222 | } |
223 | } |
224 | |
225 | template <typename T> |
226 | void Statement::dispatch(T handler, Statement* stmt) { |
227 | if (stmt->isVal()) { |
228 | ptr(handler)->handle(stmt->as<Val>()); |
229 | } else if (stmt->isExpr()) { |
230 | ptr(handler)->handle(stmt->as<Expr>()); |
231 | } else |
232 | TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!" ); |
233 | } |
234 | |
235 | template <typename T> |
236 | void Val::constDispatch(T handler, const Val* val) { |
237 | switch (*(val->getValType())) { |
238 | case ValType::Scalar: |
239 | switch (*(val->getDataType())) { |
240 | case DataType::Bool: |
241 | ptr(handler)->handle(val->as<Bool>()); |
242 | return; |
243 | case DataType::Double: |
244 | ptr(handler)->handle(val->as<Double>()); |
245 | return; |
246 | case DataType::Int: |
247 | case DataType::Int32: |
248 | // Dispatch to Int even with Int32 as we don't have Int32 IR |
249 | // node. |
250 | ptr(handler)->handle(val->as<Int>()); |
251 | return; |
252 | case DataType::ComplexDouble: |
253 | ptr(handler)->handle(val->as<ComplexDouble>()); |
254 | return; |
255 | default: |
256 | break; |
257 | } |
258 | break; |
259 | case ValType::NamedScalar: |
260 | ptr(handler)->handle(val->as<NamedScalar>()); |
261 | return; |
262 | |
263 | case ValType::IterDomain: |
264 | ptr(handler)->handle(val->as<IterDomain>()); |
265 | return; |
266 | case ValType::TensorDomain: |
267 | ptr(handler)->handle(val->as<TensorDomain>()); |
268 | return; |
269 | case ValType::TensorView: |
270 | ptr(handler)->handle(val->as<TensorView>()); |
271 | return; |
272 | case ValType::Predicate: |
273 | ptr(handler)->handle(val->as<kir::Predicate>()); |
274 | return; |
275 | case ValType::TensorIndex: |
276 | ptr(handler)->handle(val->as<kir::TensorIndex>()); |
277 | return; |
278 | case ValType::IntPair: |
279 | ptr(handler)->handle(val->as<kir::IntPair>()); |
280 | return; |
281 | default: |
282 | break; |
283 | } |
284 | TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!" ); |
285 | } |
286 | |
287 | template <typename T> |
288 | void Expr::constDispatch(T handler, const Expr* expr) { |
289 | switch (*(expr->getExprType())) { |
290 | case ExprType::FullOp: |
291 | ptr(handler)->handle(expr->as<FullOp>()); |
292 | return; |
293 | case ExprType::ARangeOp: |
294 | ptr(handler)->handle(expr->as<ARangeOp>()); |
295 | return; |
296 | case ExprType::EyeOp: |
297 | ptr(handler)->handle(expr->as<EyeOp>()); |
298 | return; |
299 | case ExprType::UnaryOp: |
300 | ptr(handler)->handle(expr->as<UnaryOp>()); |
301 | return; |
302 | case ExprType::BinaryOp: |
303 | ptr(handler)->handle(expr->as<BinaryOp>()); |
304 | return; |
305 | case ExprType::TernaryOp: |
306 | ptr(handler)->handle(expr->as<TernaryOp>()); |
307 | return; |
308 | case ExprType::RNGOp: |
309 | ptr(handler)->handle(expr->as<RNGOp>()); |
310 | return; |
311 | case ExprType::ReductionOp: |
312 | ptr(handler)->handle(expr->as<ReductionOp>()); |
313 | return; |
314 | case ExprType::GroupedReductionOp: |
315 | ptr(handler)->handle(expr->as<GroupedReductionOp>()); |
316 | return; |
317 | case ExprType::WelfordOp: |
318 | ptr(handler)->handle(expr->as<WelfordOp>()); |
319 | return; |
320 | case ExprType::GroupedWelfordOp: |
321 | ptr(handler)->handle(expr->as<GroupedWelfordOp>()); |
322 | return; |
323 | case ExprType::LoadStoreOp: |
324 | ptr(handler)->handle(expr->as<LoadStoreOp>()); |
325 | return; |
326 | case ExprType::MmaOp: |
327 | ptr(handler)->handle(expr->as<MmaOp>()); |
328 | return; |
329 | case ExprType::BroadcastOp: |
330 | ptr(handler)->handle(expr->as<BroadcastOp>()); |
331 | return; |
332 | |
333 | case ExprType::Split: |
334 | ptr(handler)->handle(expr->as<Split>()); |
335 | return; |
336 | case ExprType::Merge: |
337 | ptr(handler)->handle(expr->as<Merge>()); |
338 | return; |
339 | case ExprType::Swizzle2D: |
340 | ptr(handler)->handle(expr->as<Swizzle2D>()); |
341 | return; |
342 | case ExprType::TransposeOp: |
343 | ptr(handler)->handle(expr->as<TransposeOp>()); |
344 | return; |
345 | case ExprType::ExpandOp: |
346 | ptr(handler)->handle(expr->as<ExpandOp>()); |
347 | return; |
348 | case ExprType::ShiftOp: |
349 | ptr(handler)->handle(expr->as<ShiftOp>()); |
350 | return; |
351 | case ExprType::GatherOp: |
352 | ptr(handler)->handle(expr->as<GatherOp>()); |
353 | return; |
354 | case ExprType::ViewAsScalar: |
355 | ptr(handler)->handle(expr->as<ViewAsScalar>()); |
356 | return; |
357 | case ExprType::ViewOp: |
358 | ptr(handler)->handle(expr->as<ViewOp>()); |
359 | return; |
360 | |
361 | case ExprType::Allocate: |
362 | ptr(handler)->handle(expr->as<kir::Allocate>()); |
363 | return; |
364 | case ExprType::BlockSync: |
365 | ptr(handler)->handle(expr->as<kir::BlockSync>()); |
366 | return; |
367 | case ExprType::GridSync: |
368 | ptr(handler)->handle(expr->as<kir::GridSync>()); |
369 | return; |
370 | case ExprType::CpAsyncWait: |
371 | ptr(handler)->handle(expr->as<kir::CpAsyncWait>()); |
372 | return; |
373 | case ExprType::CpAsyncCommit: |
374 | ptr(handler)->handle(expr->as<kir::CpAsyncCommit>()); |
375 | return; |
376 | case ExprType::InitMagicZero: |
377 | ptr(handler)->handle(expr->as<kir::InitMagicZero>()); |
378 | return; |
379 | case ExprType::UpdateMagicZero: |
380 | ptr(handler)->handle(expr->as<kir::UpdateMagicZero>()); |
381 | return; |
382 | case ExprType::ForLoop: |
383 | ptr(handler)->handle(expr->as<kir::ForLoop>()); |
384 | return; |
385 | case ExprType::IfThenElse: |
386 | ptr(handler)->handle(expr->as<kir::IfThenElse>()); |
387 | return; |
388 | case ExprType::GridReduction: |
389 | ptr(handler)->handle(expr->as<kir::GridReduction>()); |
390 | return; |
391 | case ExprType::GroupedGridReduction: |
392 | ptr(handler)->handle(expr->as<kir::GroupedGridReduction>()); |
393 | return; |
394 | case ExprType::GridBroadcast: |
395 | ptr(handler)->handle(expr->as<kir::GridBroadcast>()); |
396 | return; |
397 | case ExprType::GridWelford: |
398 | ptr(handler)->handle(expr->as<kir::GridWelford>()); |
399 | return; |
400 | case ExprType::GroupedGridWelford: |
401 | ptr(handler)->handle(expr->as<kir::GroupedGridWelford>()); |
402 | return; |
403 | case ExprType::AllocateFusedReduction: |
404 | ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>()); |
405 | return; |
406 | case ExprType::Swizzle2DInt: |
407 | ptr(handler)->handle(expr->as<kir::Swizzle2DInt>()); |
408 | return; |
409 | case ExprType::PairSelect: |
410 | ptr(handler)->handle(expr->as<kir::PairSelect>()); |
411 | return; |
412 | default: |
413 | TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!" ); |
414 | } |
415 | } |
416 | |
417 | template <typename T> |
418 | void Statement::constDispatch(T handler, const Statement* stmt) { |
419 | if (stmt->isVal()) { |
420 | ptr(handler)->handle(stmt->as<Val>()); |
421 | } else if (stmt->isExpr()) { |
422 | ptr(handler)->handle(stmt->as<Expr>()); |
423 | } else |
424 | TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!" ); |
425 | } |
426 | |
/*
 * Generic mutatorDispatch for any handler that modifies the IR. This could be
 * a transformation on loop structures, or parallelizing a loop.
 *
 * This mutatorDispatch is paired with a class that implements a
 *   mutate(node_type* node)
 * overload for each node type (any value mutate returns is ignored by the
 * dispatch functions below).
 *
 * mutate should call:
 *   Statement::mutatorDispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   mutate(Statement* stmt) {
 *     Statement::mutatorDispatch(this, stmt);
 *   }
 *
 * Therefore mutatorDispatch must never call:
 *   ptr(mutator)->mutate(stmt->as<Statement>());
 * because mutate(Statement*) may itself call mutatorDispatch, and the two
 * would recurse forever.
 */
438 | template <typename T> |
439 | void Val::mutatorDispatch(T mutator, Val* val) { |
440 | switch (*(val->getValType())) { |
441 | case ValType::Scalar: |
442 | switch (*(val->getDataType())) { |
443 | case DataType::Bool: |
444 | ptr(mutator)->mutate(val->as<Bool>()); |
445 | return; |
446 | case DataType::Double: |
447 | ptr(mutator)->mutate(val->as<Double>()); |
448 | return; |
449 | case DataType::Int: |
450 | ptr(mutator)->mutate(val->as<Int>()); |
451 | return; |
452 | case DataType::ComplexDouble: |
453 | ptr(mutator)->mutate(val->as<ComplexDouble>()); |
454 | return; |
455 | default: |
456 | break; |
457 | } |
458 | break; |
459 | case ValType::NamedScalar: |
460 | ptr(mutator)->mutate(val->as<NamedScalar>()); |
461 | return; |
462 | |
463 | case ValType::IterDomain: |
464 | ptr(mutator)->mutate(val->as<IterDomain>()); |
465 | return; |
466 | case ValType::TensorDomain: |
467 | ptr(mutator)->mutate(val->as<TensorDomain>()); |
468 | return; |
469 | case ValType::TensorView: |
470 | ptr(mutator)->mutate(val->as<TensorView>()); |
471 | return; |
472 | case ValType::Predicate: |
473 | ptr(mutator)->mutate(val->as<kir::Predicate>()); |
474 | return; |
475 | case ValType::TensorIndex: |
476 | ptr(mutator)->mutate(val->as<kir::TensorIndex>()); |
477 | return; |
478 | case ValType::IntPair: |
479 | ptr(mutator)->mutate(val->as<kir::IntPair>()); |
480 | return; |
481 | default: |
482 | break; |
483 | } |
484 | TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!" ); |
485 | } |
486 | |
487 | template <typename T> |
488 | void Expr::mutatorDispatch(T mutator, Expr* expr) { |
489 | switch (*(expr->getExprType())) { |
490 | case ExprType::FullOp: |
491 | ptr(mutator)->mutate(expr->as<FullOp>()); |
492 | return; |
493 | case ExprType::ARangeOp: |
494 | ptr(mutator)->mutate(expr->as<ARangeOp>()); |
495 | return; |
496 | case ExprType::EyeOp: |
497 | ptr(mutator)->mutate(expr->as<EyeOp>()); |
498 | return; |
499 | case ExprType::UnaryOp: |
500 | ptr(mutator)->mutate(expr->as<UnaryOp>()); |
501 | return; |
502 | case ExprType::BinaryOp: |
503 | ptr(mutator)->mutate(expr->as<BinaryOp>()); |
504 | return; |
505 | case ExprType::TernaryOp: |
506 | ptr(mutator)->mutate(expr->as<TernaryOp>()); |
507 | return; |
508 | case ExprType::RNGOp: |
509 | ptr(mutator)->mutate(expr->as<RNGOp>()); |
510 | return; |
511 | case ExprType::ReductionOp: |
512 | ptr(mutator)->mutate(expr->as<ReductionOp>()); |
513 | return; |
514 | case ExprType::GroupedReductionOp: |
515 | ptr(mutator)->mutate(expr->as<GroupedReductionOp>()); |
516 | return; |
517 | case ExprType::WelfordOp: |
518 | ptr(mutator)->mutate(expr->as<WelfordOp>()); |
519 | return; |
520 | case ExprType::GroupedWelfordOp: |
521 | ptr(mutator)->mutate(expr->as<GroupedWelfordOp>()); |
522 | return; |
523 | case ExprType::LoadStoreOp: |
524 | ptr(mutator)->mutate(expr->as<LoadStoreOp>()); |
525 | return; |
526 | case ExprType::MmaOp: |
527 | ptr(mutator)->mutate(expr->as<MmaOp>()); |
528 | return; |
529 | case ExprType::BroadcastOp: |
530 | ptr(mutator)->mutate(expr->as<BroadcastOp>()); |
531 | return; |
532 | |
533 | case ExprType::Split: |
534 | ptr(mutator)->mutate(expr->as<Split>()); |
535 | return; |
536 | case ExprType::Merge: |
537 | ptr(mutator)->mutate(expr->as<Merge>()); |
538 | return; |
539 | case ExprType::Swizzle2D: |
540 | ptr(mutator)->mutate(expr->as<Swizzle2D>()); |
541 | return; |
542 | case ExprType::TransposeOp: |
543 | ptr(mutator)->mutate(expr->as<TransposeOp>()); |
544 | return; |
545 | case ExprType::ExpandOp: |
546 | ptr(mutator)->mutate(expr->as<ExpandOp>()); |
547 | return; |
548 | case ExprType::ShiftOp: |
549 | ptr(mutator)->mutate(expr->as<ShiftOp>()); |
550 | return; |
551 | case ExprType::GatherOp: |
552 | ptr(mutator)->mutate(expr->as<GatherOp>()); |
553 | return; |
554 | case ExprType::ViewAsScalar: |
555 | ptr(mutator)->mutate(expr->as<ViewAsScalar>()); |
556 | return; |
557 | case ExprType::ViewOp: |
558 | ptr(mutator)->mutate(expr->as<ViewOp>()); |
559 | return; |
560 | |
561 | case ExprType::Allocate: |
562 | ptr(mutator)->mutate(expr->as<kir::Allocate>()); |
563 | return; |
564 | case ExprType::BlockSync: |
565 | ptr(mutator)->mutate(expr->as<kir::BlockSync>()); |
566 | return; |
567 | case ExprType::GridSync: |
568 | ptr(mutator)->mutate(expr->as<kir::GridSync>()); |
569 | return; |
570 | case ExprType::CpAsyncWait: |
571 | ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>()); |
572 | return; |
573 | case ExprType::CpAsyncCommit: |
574 | ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>()); |
575 | return; |
576 | case ExprType::InitMagicZero: |
577 | ptr(mutator)->mutate(expr->as<kir::InitMagicZero>()); |
578 | return; |
579 | case ExprType::UpdateMagicZero: |
580 | ptr(mutator)->mutate(expr->as<kir::UpdateMagicZero>()); |
581 | return; |
582 | case ExprType::ForLoop: |
583 | ptr(mutator)->mutate(expr->as<kir::ForLoop>()); |
584 | return; |
585 | case ExprType::IfThenElse: |
586 | ptr(mutator)->mutate(expr->as<kir::IfThenElse>()); |
587 | return; |
588 | case ExprType::GridReduction: |
589 | ptr(mutator)->mutate(expr->as<kir::GridReduction>()); |
590 | return; |
591 | case ExprType::GroupedGridReduction: |
592 | ptr(mutator)->mutate(expr->as<kir::GroupedGridReduction>()); |
593 | return; |
594 | case ExprType::GridBroadcast: |
595 | ptr(mutator)->mutate(expr->as<kir::GridBroadcast>()); |
596 | return; |
597 | case ExprType::GridWelford: |
598 | ptr(mutator)->mutate(expr->as<kir::GridWelford>()); |
599 | return; |
600 | case ExprType::GroupedGridWelford: |
601 | ptr(mutator)->mutate(expr->as<kir::GroupedGridWelford>()); |
602 | return; |
603 | case ExprType::AllocateFusedReduction: |
604 | ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>()); |
605 | return; |
606 | case ExprType::Swizzle2DInt: |
607 | ptr(mutator)->mutate(expr->as<kir::Swizzle2DInt>()); |
608 | return; |
609 | case ExprType::PairSelect: |
610 | ptr(mutator)->mutate(expr->as<kir::PairSelect>()); |
611 | return; |
612 | default: |
613 | TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!" ); |
614 | } |
615 | } |
616 | |
617 | template <typename T> |
618 | void Statement::mutatorDispatch(T mutator, Statement* stmt) { |
619 | if (stmt->isVal()) { |
620 | ptr(mutator)->mutate(stmt->as<Val>()); |
621 | return; |
622 | } |
623 | if (stmt->isExpr()) { |
624 | ptr(mutator)->mutate(stmt->as<Expr>()); |
625 | return; |
626 | } |
627 | TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!" ); |
628 | } |
629 | |
/*
 * Handler template instantiations. These should only have to be done on base
 * classes. Actual visitors/mutators should inherit from these classes and call
 * ->dispatch(this) to avoid needing an explicit instantiation.
 */
635 | template void Statement::dispatch(OptOutDispatch&, Statement*); |
636 | template void Statement::dispatch(OptOutDispatch*, Statement*); |
637 | template void Val::dispatch(OptOutDispatch&, Val*); |
638 | template void Val::dispatch(OptOutDispatch*, Val*); |
639 | template void Expr::dispatch(OptOutDispatch&, Expr*); |
640 | template void Expr::dispatch(OptOutDispatch*, Expr*); |
641 | |
642 | template void Statement::dispatch(OptInDispatch, Statement*); |
643 | template void Statement::dispatch(OptInDispatch*, Statement*); |
644 | template void Val::dispatch(OptInDispatch, Val*); |
645 | template void Val::dispatch(OptInDispatch*, Val*); |
646 | template void Expr::dispatch(OptInDispatch, Expr*); |
647 | template void Expr::dispatch(OptInDispatch*, Expr*); |
648 | |
649 | template void Statement::constDispatch(OptOutConstDispatch&, const Statement*); |
650 | template void Statement::constDispatch(OptOutConstDispatch*, const Statement*); |
651 | template void Val::constDispatch(OptOutConstDispatch&, const Val*); |
652 | template void Val::constDispatch(OptOutConstDispatch*, const Val*); |
653 | template void Expr::constDispatch(OptOutConstDispatch&, const Expr*); |
654 | template void Expr::constDispatch(OptOutConstDispatch*, const Expr*); |
655 | |
656 | template void Statement::constDispatch(OptInConstDispatch&, const Statement*); |
657 | template void Statement::constDispatch(OptInConstDispatch*, const Statement*); |
658 | template void Val::constDispatch(OptInConstDispatch&, const Val*); |
659 | template void Val::constDispatch(OptInConstDispatch*, const Val*); |
660 | template void Expr::constDispatch(OptInConstDispatch&, const Expr*); |
661 | template void Expr::constDispatch(OptInConstDispatch*, const Expr*); |
662 | |
663 | template void Statement::mutatorDispatch(OptOutMutator&, Statement*); |
664 | template void Statement::mutatorDispatch(OptOutMutator*, Statement*); |
665 | template void Val::mutatorDispatch(OptOutMutator&, Val*); |
666 | template void Val::mutatorDispatch(OptOutMutator*, Val*); |
667 | template void Expr::mutatorDispatch(OptOutMutator&, Expr*); |
668 | template void Expr::mutatorDispatch(OptOutMutator*, Expr*); |
669 | |
670 | void OptOutDispatch::handle(Statement* s) { |
671 | Statement::dispatch(this, s); |
672 | } |
673 | |
674 | void OptOutDispatch::handle(Expr* e) { |
675 | Expr::dispatch(this, e); |
676 | } |
677 | |
678 | void OptOutDispatch::handle(Val* v) { |
679 | Val::dispatch(this, v); |
680 | } |
681 | |
682 | void OptOutConstDispatch::handle(const Statement* s) { |
683 | Statement::constDispatch(this, s); |
684 | } |
685 | |
686 | void OptOutConstDispatch::handle(const Expr* e) { |
687 | Expr::constDispatch(this, e); |
688 | } |
689 | |
690 | void OptOutConstDispatch::handle(const Val* v) { |
691 | Val::constDispatch(this, v); |
692 | } |
693 | |
694 | void OptInConstDispatch::unhandled(const Statement* stmt) { |
695 | if (stmt->isExpr()) { |
696 | TORCH_INTERNAL_ASSERT( |
697 | false, "Handle not overriden for " , stmt->getExprType().value(), "." ); |
698 | } else if (stmt->isVal()) { |
699 | TORCH_INTERNAL_ASSERT( |
700 | false, "Handle not overriden for " , stmt->getValType().value(), "." ); |
701 | } else { |
702 | TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type." ); |
703 | } |
704 | } |
705 | |
706 | void OptInDispatch::unhandled(Statement* stmt) { |
707 | if (stmt->isExpr()) { |
708 | TORCH_INTERNAL_ASSERT( |
709 | false, "Handle not overriden for " , stmt->getExprType().value(), "." ); |
710 | } else if (stmt->isVal()) { |
711 | TORCH_INTERNAL_ASSERT( |
712 | false, "Handle not overriden for " , stmt->getValType().value(), "." ); |
713 | } else { |
714 | TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type." ); |
715 | } |
716 | } |
717 | |
718 | // Vals |
719 | void OptOutConstDispatch::handle(const Bool* stmt) { |
720 | unhandled(stmt); |
721 | } |
722 | void OptOutConstDispatch::handle(const Double* stmt) { |
723 | unhandled(stmt); |
724 | } |
725 | void OptOutConstDispatch::handle(const Int* stmt) { |
726 | unhandled(stmt); |
727 | } |
728 | void OptOutConstDispatch::handle(const ComplexDouble* stmt) { |
729 | unhandled(stmt); |
730 | } |
731 | void OptOutConstDispatch::handle(const NamedScalar* stmt) { |
732 | unhandled(stmt); |
733 | } |
734 | void OptOutConstDispatch::handle(const IterDomain* stmt) { |
735 | unhandled(stmt); |
736 | } |
737 | void OptOutConstDispatch::handle(const TensorDomain* stmt) { |
738 | unhandled(stmt); |
739 | } |
740 | void OptOutConstDispatch::handle(const TensorView* stmt) { |
741 | unhandled(stmt); |
742 | } |
743 | |
744 | void OptOutConstDispatch::handle(const kir::Predicate* stmt) { |
745 | unhandled(stmt); |
746 | } |
747 | void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) { |
748 | unhandled(stmt); |
749 | } |
750 | void OptOutConstDispatch::handle(const kir::IntPair* stmt) { |
751 | unhandled(stmt); |
752 | } |
753 | |
754 | // Exprs |
755 | void OptOutConstDispatch::handle(const FullOp* stmt) { |
756 | unhandled(stmt); |
757 | } |
758 | void OptOutConstDispatch::handle(const ARangeOp* stmt) { |
759 | unhandled(stmt); |
760 | } |
761 | void OptOutConstDispatch::handle(const EyeOp* stmt) { |
762 | unhandled(stmt); |
763 | } |
764 | void OptOutConstDispatch::handle(const UnaryOp* stmt) { |
765 | unhandled(stmt); |
766 | } |
767 | void OptOutConstDispatch::handle(const BinaryOp* stmt) { |
768 | unhandled(stmt); |
769 | } |
770 | void OptOutConstDispatch::handle(const TernaryOp* stmt) { |
771 | unhandled(stmt); |
772 | } |
773 | void OptOutConstDispatch::handle(const RNGOp* stmt) { |
774 | unhandled(stmt); |
775 | } |
776 | void OptOutConstDispatch::handle(const ReductionOp* stmt) { |
777 | unhandled(stmt); |
778 | } |
779 | void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) { |
780 | unhandled(stmt); |
781 | } |
782 | void OptOutConstDispatch::handle(const WelfordOp* stmt) { |
783 | unhandled(stmt); |
784 | } |
785 | void OptOutConstDispatch::handle(const GroupedWelfordOp* stmt) { |
786 | unhandled(stmt); |
787 | } |
788 | void OptOutConstDispatch::handle(const LoadStoreOp* stmt) { |
789 | unhandled(stmt); |
790 | } |
791 | void OptOutConstDispatch::handle(const MmaOp* stmt) { |
792 | unhandled(stmt); |
793 | } |
794 | void OptOutConstDispatch::handle(const BroadcastOp* stmt) { |
795 | unhandled(stmt); |
796 | } |
797 | |
798 | void OptOutConstDispatch::handle(const Split* stmt) { |
799 | unhandled(stmt); |
800 | } |
801 | void OptOutConstDispatch::handle(const Merge* stmt) { |
802 | unhandled(stmt); |
803 | } |
804 | void OptOutConstDispatch::handle(const Swizzle2D* stmt) { |
805 | unhandled(stmt); |
806 | } |
807 | void OptOutConstDispatch::handle(const TransposeOp* stmt) { |
808 | unhandled(stmt); |
809 | } |
810 | void OptOutConstDispatch::handle(const ExpandOp* stmt) { |
811 | unhandled(stmt); |
812 | } |
813 | void OptOutConstDispatch::handle(const ShiftOp* stmt) { |
814 | unhandled(stmt); |
815 | } |
816 | void OptOutConstDispatch::handle(const GatherOp* stmt) { |
817 | unhandled(stmt); |
818 | } |
819 | void OptOutConstDispatch::handle(const ViewAsScalar* stmt) { |
820 | unhandled(stmt); |
821 | } |
822 | void OptOutConstDispatch::handle(const ViewOp* stmt) { |
823 | unhandled(stmt); |
824 | } |
825 | |
826 | void OptOutConstDispatch::handle(const kir::Allocate* stmt) { |
827 | unhandled(stmt); |
828 | } |
829 | void OptOutConstDispatch::handle(const kir::BlockSync* stmt) { |
830 | unhandled(stmt); |
831 | } |
832 | void OptOutConstDispatch::handle(const kir::GridSync* stmt) { |
833 | unhandled(stmt); |
834 | } |
835 | void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) { |
836 | unhandled(stmt); |
837 | } |
838 | void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) { |
839 | unhandled(stmt); |
840 | } |
841 | void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) { |
842 | unhandled(stmt); |
843 | } |
844 | void OptOutConstDispatch::handle(const kir::UpdateMagicZero* stmt) { |
845 | unhandled(stmt); |
846 | } |
847 | void OptOutConstDispatch::handle(const kir::ForLoop* stmt) { |
848 | unhandled(stmt); |
849 | } |
850 | void OptOutConstDispatch::handle(const kir::IfThenElse* stmt) { |
851 | unhandled(stmt); |
852 | } |
853 | void OptOutConstDispatch::handle(const kir::GridReduction* stmt) { |
854 | unhandled(stmt); |
855 | } |
856 | void OptOutConstDispatch::handle(const kir::GroupedGridReduction* stmt) { |
857 | unhandled(stmt); |
858 | } |
859 | void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) { |
860 | unhandled(stmt); |
861 | } |
862 | void OptOutConstDispatch::handle(const kir::GridWelford* stmt) { |
863 | unhandled(stmt); |
864 | } |
865 | void OptOutConstDispatch::handle(const kir::GroupedGridWelford* stmt) { |
866 | unhandled(stmt); |
867 | } |
868 | void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) { |
869 | unhandled(stmt); |
870 | } |
871 | void OptOutConstDispatch::handle(const kir::Swizzle2DInt* stmt) { |
872 | unhandled(stmt); |
873 | } |
874 | void OptOutConstDispatch::handle(const kir::PairSelect* stmt) { |
875 | unhandled(stmt); |
876 | } |
877 | |
878 | void OptOutDispatch::unhandled(Statement*) {} |
879 | |
880 | // Vals |
881 | void OptOutDispatch::handle(Bool* stmt) { |
882 | unhandled(stmt); |
883 | } |
884 | void OptOutDispatch::handle(Double* stmt) { |
885 | unhandled(stmt); |
886 | } |
887 | void OptOutDispatch::handle(Int* stmt) { |
888 | unhandled(stmt); |
889 | } |
890 | void OptOutDispatch::handle(ComplexDouble* stmt) { |
891 | unhandled(stmt); |
892 | } |
893 | void OptOutDispatch::handle(NamedScalar* stmt) { |
894 | unhandled(stmt); |
895 | } |
896 | void OptOutDispatch::handle(IterDomain* stmt) { |
897 | unhandled(stmt); |
898 | } |
899 | void OptOutDispatch::handle(TensorDomain* stmt) { |
900 | unhandled(stmt); |
901 | } |
902 | void OptOutDispatch::handle(TensorView* stmt) { |
903 | unhandled(stmt); |
904 | } |
905 | |
906 | void OptOutDispatch::handle(kir::Predicate* stmt) { |
907 | unhandled(stmt); |
908 | } |
909 | void OptOutDispatch::handle(kir::TensorIndex* stmt) { |
910 | unhandled(stmt); |
911 | } |
912 | void OptOutDispatch::handle(kir::IntPair* stmt) { |
913 | unhandled(stmt); |
914 | } |
915 | |
916 | // Exprs |
917 | void OptOutDispatch::handle(FullOp* stmt) { |
918 | unhandled(stmt); |
919 | } |
920 | void OptOutDispatch::handle(ARangeOp* stmt) { |
921 | unhandled(stmt); |
922 | } |
923 | void OptOutDispatch::handle(EyeOp* stmt) { |
924 | unhandled(stmt); |
925 | } |
926 | void OptOutDispatch::handle(UnaryOp* stmt) { |
927 | unhandled(stmt); |
928 | } |
929 | void OptOutDispatch::handle(BinaryOp* stmt) { |
930 | unhandled(stmt); |
931 | } |
932 | void OptOutDispatch::handle(TernaryOp* stmt) { |
933 | unhandled(stmt); |
934 | } |
935 | void OptOutDispatch::handle(RNGOp* stmt) { |
936 | unhandled(stmt); |
937 | } |
938 | void OptOutDispatch::handle(ReductionOp* stmt) { |
939 | unhandled(stmt); |
940 | } |
941 | void OptOutDispatch::handle(GroupedReductionOp* stmt) { |
942 | unhandled(stmt); |
943 | } |
944 | void OptOutDispatch::handle(WelfordOp* stmt) { |
945 | unhandled(stmt); |
946 | } |
947 | void OptOutDispatch::handle(GroupedWelfordOp* stmt) { |
948 | unhandled(stmt); |
949 | } |
950 | void OptOutDispatch::handle(LoadStoreOp* stmt) { |
951 | unhandled(stmt); |
952 | } |
953 | void OptOutDispatch::handle(MmaOp* stmt) { |
954 | unhandled(stmt); |
955 | } |
956 | void OptOutDispatch::handle(BroadcastOp* stmt) { |
957 | unhandled(stmt); |
958 | } |
959 | |
960 | void OptOutDispatch::handle(Split* stmt) { |
961 | unhandled(stmt); |
962 | } |
963 | void OptOutDispatch::handle(Merge* stmt) { |
964 | unhandled(stmt); |
965 | } |
966 | void OptOutDispatch::handle(Swizzle2D* stmt) { |
967 | unhandled(stmt); |
968 | } |
969 | void OptOutDispatch::handle(TransposeOp* stmt) { |
970 | unhandled(stmt); |
971 | } |
972 | void OptOutDispatch::handle(ExpandOp* stmt) { |
973 | unhandled(stmt); |
974 | } |
975 | void OptOutDispatch::handle(ShiftOp* stmt) { |
976 | unhandled(stmt); |
977 | } |
978 | void OptOutDispatch::handle(GatherOp* stmt) { |
979 | unhandled(stmt); |
980 | } |
981 | void OptOutDispatch::handle(ViewAsScalar* stmt) { |
982 | unhandled(stmt); |
983 | } |
984 | void OptOutDispatch::handle(ViewOp* stmt) { |
985 | unhandled(stmt); |
986 | } |
987 | |
988 | void OptOutDispatch::handle(kir::Allocate* stmt) { |
989 | unhandled(stmt); |
990 | } |
991 | void OptOutDispatch::handle(kir::BlockSync* stmt) { |
992 | unhandled(stmt); |
993 | } |
994 | void OptOutDispatch::handle(kir::GridSync* stmt) { |
995 | unhandled(stmt); |
996 | } |
997 | void OptOutDispatch::handle(kir::CpAsyncWait* stmt) { |
998 | unhandled(stmt); |
999 | } |
1000 | void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) { |
1001 | unhandled(stmt); |
1002 | } |
1003 | void OptOutDispatch::handle(kir::InitMagicZero* stmt) { |
1004 | unhandled(stmt); |
1005 | } |
1006 | void OptOutDispatch::handle(kir::UpdateMagicZero* stmt) { |
1007 | unhandled(stmt); |
1008 | } |
1009 | void OptOutDispatch::handle(kir::ForLoop* stmt) { |
1010 | unhandled(stmt); |
1011 | } |
1012 | void OptOutDispatch::handle(kir::IfThenElse* stmt) { |
1013 | unhandled(stmt); |
1014 | } |
1015 | void OptOutDispatch::handle(kir::GridReduction* stmt) { |
1016 | unhandled(stmt); |
1017 | } |
1018 | void OptOutDispatch::handle(kir::GroupedGridReduction* stmt) { |
1019 | unhandled(stmt); |
1020 | } |
1021 | void OptOutDispatch::handle(kir::GridBroadcast* stmt) { |
1022 | unhandled(stmt); |
1023 | } |
1024 | void OptOutDispatch::handle(kir::GridWelford* stmt) { |
1025 | unhandled(stmt); |
1026 | } |
1027 | void OptOutDispatch::handle(kir::GroupedGridWelford* stmt) { |
1028 | unhandled(stmt); |
1029 | } |
1030 | void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) { |
1031 | unhandled(stmt); |
1032 | } |
1033 | void OptOutDispatch::handle(kir::Swizzle2DInt* stmt) { |
1034 | unhandled(stmt); |
1035 | } |
1036 | void OptOutDispatch::handle(kir::PairSelect* stmt) { |
1037 | unhandled(stmt); |
1038 | } |
1039 | |
1040 | } // namespace cuda |
1041 | } // namespace fuser |
1042 | } // namespace jit |
1043 | } // namespace torch |
1044 | |