1#include <ir_iostream.h>
2#include <ir_printer.h>
3
4#include <fusion.h>
5#include <instrumentation.h>
6#include <ir_all_nodes.h>
7#include <ir_utils.h>
8#include <kernel.h>
9#include <lower_utils.h>
10
11#include <c10/util/irange.h>
12
13namespace torch {
14namespace jit {
15namespace fuser {
16namespace cuda {
17
18namespace {
// Maps a boolean onto its textual literal: "true" or "false".
const char* boolLiteral(bool value) {
  if (value) {
    return "true";
  }
  return "false";
}
22
23std::string varName(const Val* val) {
24 std::stringstream value_name;
25 if (val == nullptr) {
26 value_name << "$nullptr";
27 } else {
28 value_name << val->name();
29 }
30 return value_name.str();
31}
32
33} // namespace
34
35// Make sure we can inline something, before we attempt to.
36static void checkInlineable(const Expr* expr) {
37 for (auto input : expr->inputs()) {
38 TORCH_CHECK(
39 input->isScalar(),
40 "Printing inline computations involving values other than scalars is not currently supported.");
41 }
42 TORCH_CHECK(
43 expr->outputs().size() == 1,
44 "Cannot print inline computations if there's more than one output.");
45 TORCH_CHECK(
46 expr->output(0)->isScalar(),
47 "Printing inline computations involving values other than scalars is not currently supported.");
48}
49
50void IrPrinter::handle(const Statement* s) {
51 OptInConstDispatch::handle(s);
52}
53
54void IrPrinter::handle(const Val* v) {
55 OptInConstDispatch::handle(v);
56}
57
58void IrPrinter::handle(const Expr* e) {
59 OptInConstDispatch::handle(e);
60}
61
62void IrPrinter::handle(Fusion* fusion) {
63 FUSER_PERF_SCOPE("IrPrinter");
64 resetIndent();
65 for (const Expr* expr : fusion->exprs()) {
66 handle(expr);
67 }
68}
69
70void IrPrinter::handle(const kir::Kernel* kernel) {
71 TORCH_CHECK(kernel != nullptr);
72
73 // kernel declaration
74 os_ << "\nKERNEL (";
75 for (auto in : kernel->inputs()) {
76 handle(in);
77 if (in != kernel->inputs().back()) {
78 os_ << ", ";
79 }
80 }
81 os_ << ") -> (";
82 for (auto out : kernel->outputs()) {
83 handle(out);
84 if (out != kernel->outputs().back()) {
85 os_ << ", ";
86 }
87 }
88 os_ << ") :\n";
89
90 // kernel body
91 indent_size_++;
92 for (auto expr : kernel->topLevelExprs()) {
93 handle(expr);
94 }
95 indent_size_--;
96 os_ << "END.\n\n";
97}
98
99void IrPrinter::handle(kir::Kernel& kernel) {
100 handle(&kernel);
101}
102
103void IrPrinter::handleScope(const kir::Scope& scope) {
104 // Save the uses of the parent scope
105 indent_size_++;
106 for (auto expr : scope.exprs()) {
107 handle(expr);
108 }
109 indent_size_--;
110}
111
112void IrPrinter::handle(const IterDomain* id) {
113 os_ << id->getIterType();
114 os_ << id->getParallelType();
115 os_ << varName(id);
116 os_ << "{";
117 if (!id->start()->isZeroInt()) {
118 print_inline(id->start());
119 os_ << " : ";
120 }
121 if (id->stop() != id->extent()) {
122 print_inline(id->stop());
123 os_ << " : ";
124 }
125 if (id->isBroadcast() && id->hasExpandedExtent()) {
126 print_inline(id->expandedExtent());
127 } else {
128 print_inline(id->extent());
129 }
130 os_ << "}";
131 if (id->isRFactorProduct())
132 os_ << "rf";
133 if (id->hasPaddingToMultipleOfWarp()) {
134 os_ << "_p";
135 }
136}
137
138void IrPrinter::handle(const TensorDomain* td) {
139 if (td->nDims() == 0) {
140 os_ << "[ 0 ]";
141 return;
142 }
143 os_ << "[ ";
144 for (const auto i : c10::irange(td->nDims())) {
145 handle(td->axis(i));
146 if (i != td->nDims() - 1)
147 os_ << ", ";
148 }
149 os_ << " ]";
150}
151
152void IrPrinter::handle(const TensorView* tv) {
153 os_ << "T" << varName(tv);
154 switch (tv->getMemoryType()) {
155 case MemoryType::Global:
156 os_ << "_g";
157 break;
158 case MemoryType::Shared:
159 os_ << "_s";
160 break;
161 case MemoryType::Local:
162 os_ << "_l";
163 break;
164 }
165 handle(tv->domain());
166
167 if (tv->getComputeAtPosition() > 0) {
168 os_ << " ca_pos( ";
169 os_ << tv->getComputeAtPosition();
170 os_ << " )";
171 }
172 if (tv->getMaxProducerPosition() > 0) {
173 os_ << " produce_pos( ";
174 os_ << tv->getMaxProducerPosition();
175 os_ << ")";
176 }
177}
178
179void IrPrinter::handle(const Bool* b) {
180 if (print_inline_ && b->definition() != nullptr) {
181 os_ << "( ";
182 handle(b->definition());
183 os_ << " )";
184 return;
185 }
186
187 os_ << "b" << varName(b);
188 if (b->isConst()) {
189 os_ << "(" << (b->value().value() ? "true" : "false") << ")";
190 }
191}
192
193void IrPrinter::handle(const Double* d) {
194 if (print_inline_ && d->definition() != nullptr) {
195 os_ << "( ";
196 handle(d->definition());
197 os_ << " )";
198 return;
199 }
200
201 if (d->isSymbolic()) {
202 os_ << "d" << varName(d);
203 } else {
204 os_ << "double("
205 << std::setprecision(
206 std::numeric_limits<Double::ScalarType>::max_digits10)
207 << *(d->value()) << ")";
208 }
209}
210
211void IrPrinter::handle(const Int* i) {
212 if (print_inline_) {
213 if (auto def = i->definition()) {
214 os_ << "( ";
215 handle(def);
216 os_ << " )";
217 return;
218 }
219 }
220
221 if (i->isSymbolic()) {
222 os_ << "i" << varName(i);
223 } else {
224 os_ << *(i->value());
225 }
226}
227
228void IrPrinter::handle(const ComplexDouble* c) {
229 if (print_inline_) {
230 if (auto def = c->definition()) {
231 os_ << "( ";
232 handle(def);
233 os_ << " )";
234 return;
235 }
236 }
237
238 if (c->isSymbolic()) {
239 os_ << "c" << varName(c);
240 } else {
241 os_ << "std::complex<double>"
242 << std::setprecision(std::numeric_limits<double>::max_digits10)
243 << *(c->value());
244 }
245}
246
247void IrPrinter::handle(const NamedScalar* ns) {
248 os_ << ns->name();
249}
250
251void IrPrinter::handle(const FullOp* fop) {
252 if (!print_inline_) {
253 indent();
254 os_ << fop->output(0) << "\n";
255 indent_size_++;
256 indent();
257 os_ << " = ";
258 } else {
259 checkInlineable(fop);
260 }
261
262 os_ << "full({";
263 for (auto i : c10::irange(fop->inputs().size())) {
264 if (i == fop->inputs().size() - 1) {
265 os_ << "}";
266 }
267 if (i > 0) {
268 os_ << ", ";
269 }
270 handle(fop->input(i));
271 }
272 os_ << ", " << fop->dtype() << ")";
273
274 indent_size_--;
275
276 if (!print_inline_)
277 os_ << ";\n";
278}
279
280void IrPrinter::handle(const ARangeOp* aop) {
281 if (!print_inline_) {
282 indent() << aop->output(0);
283 os_ << "\n";
284 indent_size_++;
285 indent();
286 os_ << " = ";
287 } else {
288 checkInlineable(aop);
289 }
290
291 os_ << "arange(";
292 handle(aop->start());
293 os_ << ", ";
294 handle(aop->end());
295 os_ << ", ";
296 handle(aop->step());
297 os_ << ", " << aop->dtype() << ")";
298
299 indent_size_--;
300
301 if (!print_inline_)
302 os_ << ";\n";
303}
304
305void IrPrinter::handle(const EyeOp* eop) {
306 if (!print_inline_) {
307 indent();
308 os_ << eop->output(0) << "\n";
309 indent_size_++;
310 indent();
311 os_ << " = ";
312 } else {
313 checkInlineable(eop);
314 }
315
316 os_ << "eye(";
317 handle(eop->input(0));
318 os_ << ", " << eop->dtype() << ")";
319
320 indent_size_--;
321
322 if (!print_inline_)
323 os_ << ";\n";
324}
325
326void IrPrinter::handle(const UnaryOp* uop) {
327 bool istvop = ir_utils::isTvOp(uop);
328 if (!print_inline_) {
329 indent() << uop->out();
330 if (istvop) {
331 os_ << "\n";
332 indent_size_++;
333 indent();
334 }
335 os_ << " = ";
336 } else {
337 checkInlineable(uop);
338 }
339
340 auto op_type = uop->getUnaryOpType();
341
342 if (auto inline_uop = inline_op_str(op_type)) {
343 os_ << inline_uop.value();
344 handle(uop->in());
345 } else {
346 if (op_type == UnaryOpType::Cast) {
347 c10::optional<std::string> cast_str = cast_func_str(std::make_pair(
348 uop->in()->getDataType().value(), uop->out()->getDataType().value()));
349 TORCH_INTERNAL_ASSERT(cast_str != c10::nullopt, "Unsupported Cast");
350 os_ << cast_str.value();
351 } else {
352 if (alsoBooleanOperator(op_type) &&
353 uop->out()->getDataType().value() == DataType::Bool) {
354 os_ << stringifyBooleanOp(op_type);
355 } else {
356 os_ << op_type;
357 }
358 if (uop->out()->getDataType().value() == DataType::Float &&
359 needFloatSuffix(op_type)) {
360 os_ << "f";
361 }
362 }
363 os_ << "(";
364 handle(uop->in());
365 os_ << ")";
366 }
367
368 if (istvop)
369 indent_size_--;
370
371 if (!print_inline_)
372 os_ << ";\n";
373}
374
375void IrPrinter::handle(const BinaryOp* bop) {
376 bool istvop = ir_utils::isTvOp(bop);
377 if (!print_inline_) {
378 indent() << bop->out();
379
380 // tensor operations tend to be long, break them up into multiple lines
381 if (istvop) {
382 os_ << "\n";
383 indent_size_++;
384 indent();
385 }
386
387 os_ << " = ";
388 } else {
389 checkInlineable(bop);
390 }
391
392 auto op_type = bop->getBinaryOpType();
393 if (auto inline_bop = inline_op_str(op_type)) {
394 handle(bop->lhs());
395 if (istvop) {
396 os_ << "\n";
397 indent();
398 }
399 os_ << " " << inline_bop.value() << " ";
400 handle(bop->rhs());
401 } else {
402 if (alsoBooleanOperator(op_type) &&
403 bop->out()->getDataType().value() == DataType::Bool) {
404 os_ << stringifyBooleanOp(op_type);
405 } else {
406 os_ << op_type;
407 }
408 if (bop->out()->getDataType().value() == DataType::Float &&
409 needFloatSuffix(op_type)) {
410 os_ << "f";
411 }
412 os_ << "(";
413 handle(bop->lhs());
414 if (istvop) {
415 os_ << "\n";
416 indent();
417 }
418 os_ << ", ";
419 handle(bop->rhs());
420 os_ << ")";
421 }
422
423 if (istvop)
424 indent_size_--;
425
426 if (!print_inline_)
427 os_ << ";\n";
428}
429
430void IrPrinter::handle(const TernaryOp* top) {
431 bool istvop = ir_utils::isTvOp(top);
432 if (!print_inline_) {
433 indent();
434 os_ << top->out();
435
436 // tensor operations tend to be long, break them up into multiple lines
437 if (istvop) {
438 os_ << "\n";
439 indent_size_++;
440 indent();
441 }
442
443 os_ << " = ";
444 } else {
445 checkInlineable(top);
446 }
447
448 os_ << top->getTernaryOpType() << "(";
449 handle(top->in1());
450 if (istvop) {
451 os_ << "\n";
452 indent();
453 }
454 os_ << ", ";
455 handle(top->in2());
456 if (istvop) {
457 os_ << "\n";
458 indent();
459 }
460 os_ << ", ";
461 handle(top->in3());
462 os_ << ")";
463
464 if (istvop)
465 indent_size_--;
466
467 if (!print_inline_)
468 os_ << ";\n";
469}
470
471void IrPrinter::handle(const RNGOp* rop) {
472 if (!print_inline_) {
473 indent();
474 os_ << rop->output(0) << "\n";
475 indent_size_++;
476 indent();
477 os_ << " = ";
478 } else {
479 checkInlineable(rop);
480 }
481
482 os_ << rop->getRNGOpType() << "({";
483 bool first = true;
484 for (auto i : rop->getShape()) {
485 if (!first) {
486 os_ << ", ";
487 }
488 handle(i);
489 first = false;
490 }
491 os_ << "}";
492 for (auto i : rop->getParameters()) {
493 os_ << ", ";
494 handle(i);
495 }
496 os_ << ", " << rop->dtype() << ")";
497
498 indent_size_--;
499
500 if (!print_inline_) {
501 os_ << ";\n";
502 }
503}
504
505void IrPrinter::handle(const ReductionOp* rop) {
506 indent() << rop->out() << "\n";
507 indent() << " = reduction( " << rop->in()
508 << ", op = " << rop->getReductionOpType()
509 << ", initial value = " << rop->init()
510 << ", allreduce = " << (rop->isAllreduce() ? "true" : "false")
511 << " )\n";
512}
513
514void IrPrinter::handle(const GroupedReductionOp* grouped_rop) {
515 indent() << "GroupedReductionOp(\n";
516 ++indent_size_;
517 for (const auto i : c10::irange(grouped_rop->numExprs())) {
518 indent() << grouped_rop->output(i) << " = reduction( "
519 << grouped_rop->input(i)
520 << ", op = " << grouped_rop->getReductionOpType(i)
521 << ", initial value = " << grouped_rop->initVal(i) << " )\n";
522 }
523 indent() << "allreduce = " << (grouped_rop->isAllreduce() ? "true" : "false")
524 << " )\n";
525 --indent_size_;
526}
527
528void IrPrinter::handle(const WelfordOp* wop) {
529 indent() << wop->outAvg() << "(Avg),\n"
530 << wop->outVar() << "(Var),\n"
531 << wop->outN() << "(Count)"
532 << "\n = Welford ( ";
533 if (wop->singleValue()) {
534 os_ << wop->inAvg() << "(Avg), ";
535 } else {
536 os_ << wop->inAvg() << "(Avg)\n " << wop->inVar() << "(Var)\n "
537 << wop->inN() << "(Count)";
538 }
539 if (wop->hasInit()) {
540 os_ << "\n initial value = " << wop->initAvg() << "(Avg)\n "
541 << wop->initVar() << "(Var)\n " << wop->initN() << "(N)";
542 }
543 os_ << "\n allreduce = " << (wop->isAllreduce() ? "true" : "false");
544 os_ << " )\n";
545}
546
547void IrPrinter::handle(const GroupedWelfordOp* grouped_wop) {
548 indent() << "GroupedWelford(\n";
549 ++indent_size_;
550 for (const auto i : c10::irange(grouped_wop->numExprs())) {
551 indent() << grouped_wop->outAvg(i) << " (Avg),\n";
552 indent() << grouped_wop->outVar(i) << " (Var),\n";
553 indent() << grouped_wop->outN(i) << " (Count)\n";
554 indent() << " = Welford ( ";
555 ++indent_size_;
556 indent() << grouped_wop->inAvg(i) << " (Avg),\n";
557 indent() << grouped_wop->inVar(i) << " (Var),\n";
558 indent() << grouped_wop->inN(i) << " (Count)\n";
559 indent() << "initial value =\n";
560 ++indent_size_;
561 indent() << grouped_wop->initAvg(i) << " (Avg),\n";
562 indent() << grouped_wop->initVar(i) << " (Var),\n";
563 indent() << grouped_wop->initN(i) << " (Count) )\n";
564 indent_size_ -= 2;
565 }
566 indent() << "allreduce = " << (grouped_wop->isAllreduce() ? "true" : "false")
567 << " )\n";
568 --indent_size_;
569}
570
571void IrPrinter::handle(const LoadStoreOp* ldst) {
572 indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in()
573 << " )\n";
574}
575
576void IrPrinter::handle(const BroadcastOp* bop) {
577 indent() << bop->out() << "\n";
578 indent() << " = broadcast( " << bop->in() << " )\n";
579}
580
581void IrPrinter::handle(const Split* s) {
582 os_ << (s->innerSplit() ? "Split: " : "Outer split: ");
583 handle(s->in());
584 os_ << " by factor " << s->factor() << " -> ";
585 handle(s->outer());
586 os_ << ", ";
587 handle(s->inner());
588 if (s->startOffset()) {
589 os_ << ", start offset: ";
590 handle(s->startOffset());
591 }
592 if (s->stopOffset()) {
593 os_ << ", stop offset: ";
594 handle(s->stopOffset());
595 }
596 os_ << "\n";
597}
598
599void IrPrinter::handle(const Merge* m) {
600 os_ << "Merge: ";
601 handle(m->outer());
602 os_ << " and ";
603 handle(m->inner());
604 os_ << " -> ";
605 handle(m->out());
606 os_ << "\n";
607}
608
609void IrPrinter::handle(const Swizzle2D* s) {
610 os_ << s->swizzleType() << "(2D): ";
611 handle(s->inX());
612 os_ << " , ";
613 handle(s->inY());
614 os_ << " -> ";
615 handle(s->outX());
616 os_ << " , ";
617 handle(s->outY());
618 os_ << "\n";
619}
620
621void IrPrinter::handle(const TransposeOp* top) {
622 indent() << top->out() << " = transpose( " << top->in() << " )\n";
623}
624
625void IrPrinter::handle(const ExpandOp* eop) {
626 indent() << eop->out() << " = expand( " << eop->in() << ", {";
627 std::stringstream ss;
628 for (auto expanded_extent : eop->expanded_extents()) {
629 if (ss.tellp()) {
630 ss << ", ";
631 }
632 ss << expanded_extent;
633 }
634 os_ << ss.str() << "} )\n";
635}
636
637void IrPrinter::handle(const ShiftOp* sop) {
638 indent() << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets()
639 << "}, {" << sop->padWidth() << "} )\n";
640}
641
642void IrPrinter::handle(const MmaOp* mma) {
643 indent() << mma->out() << " = mma(" << mma->inA() << "," << mma->inB();
644 os_ << ")\n";
645}
646
647void IrPrinter::handle(const GatherOp* op) {
648 indent() << op->out() << " = gather( " << op->in() << ", {";
649 bool no_comma = true;
650 for (const auto& s : op->windowShape()) {
651 if (!no_comma) {
652 os_ << ", ";
653 }
654 os_ << s;
655 no_comma = false;
656 }
657 os_ << "}, {";
658 no_comma = true;
659 for (const auto& pad : op->padWidth()) {
660 if (!no_comma) {
661 os_ << ", ";
662 }
663 os_ << "{" << pad[0] << ", " << pad[1] << "}";
664 no_comma = false;
665 }
666 os_ << "} )\n";
667}
668
669void IrPrinter::handle(const ViewAsScalar* top) {
670 indent() << top->out() << " = view_as_scalar( " << top->in() << ", "
671 << top->vector_id() << " )\n";
672}
673
674void IrPrinter::handle(const ViewOp* top) {
675 indent() << top->out() << " = view( " << top->in() << " )\n";
676}
677
678void IrPrinter::handle(const kir::Predicate* node) {
679 switch (node->predicate_type()) {
680 case PredicateType::Manual: {
681 os_ << node->value();
682 break;
683 }
684 default:
685 os_ << node->predicate_type();
686 break;
687 }
688}
689
690void IrPrinter::handle(const kir::TensorIndex* ti) {
691 os_ << "T" << varName(ti);
692 switch (ti->view()->getMemoryType()) {
693 case MemoryType::Global:
694 os_ << "_g";
695 break;
696 case MemoryType::Shared:
697 os_ << "_s";
698 break;
699 case MemoryType::Local:
700 os_ << "_l";
701 break;
702 }
703 os_ << "[";
704 for (auto index : ti->indices()) {
705 print_inline(index);
706 if (index != ti->indices().back()) {
707 os_ << ", ";
708 }
709 }
710 os_ << "]";
711 os_ << " view( T" << varName(ti->view()) << " )";
712}
713
714void IrPrinter::handle(const kir::Allocate* node) {
715 indent();
716 handle(node->buffer());
717 os_ << " = ALLOCATE("
718 << "mem_type=" << node->memoryType() << ", "
719 << "size=";
720 print_inline(node->size());
721 os_ << ", "
722 << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n";
723 if (node->alias() != nullptr) {
724 indent() << kTab << ".alias=";
725 handle(node->alias()->buffer());
726 os_ << "\n";
727 }
728}
729
730void IrPrinter::handle(const kir::BlockSync* node) {
731 indent() << "BLOCKSYNC(war_hazard=" << boolLiteral(node->isWarHazardSync())
732 << ")\n";
733}
734
735void IrPrinter::handle(const kir::CpAsyncWait* node) {
736 indent() << "CPASYNC_WAIT(" << node->keepStages() << ")\n";
737}
738
739void IrPrinter::handle(const kir::CpAsyncCommit* node) {
740 indent() << "CPASYNC_WAIT()\n";
741}
742
743void IrPrinter::handle(const kir::GridSync* node) {
744 indent() << "GRIDSYNC(" << node->syncDims().toString() << ", ";
745 handle(node->syncBuffer());
746 os_ << ")\n";
747}
748
749void IrPrinter::handle(const kir::ForLoop* node) {
750 indent() << "FOR ";
751 handle(node->index());
752 os_ << " in ";
753 handle(node->iter_domain());
754 os_ << ":\n";
755 handleScope(node->body());
756}
757
758void IrPrinter::handle(const kir::IfThenElse* node) {
759 indent() << "IF ";
760 handle(node->predicate());
761 os_ << ":\n";
762 handleScope(node->thenBody());
763 if (node->hasElse()) {
764 indent() << "ELSE:\n";
765 handleScope(node->elseBody());
766 }
767}
768
769void IrPrinter::handle(const kir::GridBroadcast* node) {
770 const auto* broadcast_op = node->broadcast_op();
771 indent();
772 handle(broadcast_op->out());
773 os_ << " = "
774 << "GRID_BROADCAST(in=";
775 handle(broadcast_op->in());
776 os_ << ")\n";
777 indent() << kTab << ".broadcast_buffer=";
778 handle(node->broadcast_buffer()->buffer());
779 os_ << "\n";
780 indent() << kTab << ".sync_buffer=";
781 handle(node->sync_buffer()->buffer());
782 os_ << "\n";
783}
784
785void IrPrinter::handle(const kir::GridReduction* node) {
786 indent() << node->out() << " = reduction( " << node->in()
787 << ", op = " << node->getReductionOpType()
788 << ", initial value = " << node->init() << ",\n";
789 ++indent_size_;
790 indent() << "reduction buffer = " << node->reduction_buffer()->buffer()
791 << ",\n";
792 indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n";
793 indent() << "read predicate = ";
794 if (node->predicate() != nullptr) {
795 os_ << node->predicate();
796 } else {
797 os_ << "nullptr";
798 }
799 os_ << ",\n";
800 indent() << "write predicate = ";
801 if (node->writePredicate() != nullptr) {
802 os_ << node->writePredicate();
803 } else {
804 os_ << "nullptr";
805 }
806 os_ << ",\n";
807 indent() << "thread predicate = " << node->threadPredicate().toString()
808 << ",\n";
809 indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false")
810 << " )\n";
811 --indent_size_;
812}
813
814void IrPrinter::handle(const kir::GroupedGridReduction* node) {
815 indent() << "GroupedGridReduction(\n";
816 ++indent_size_;
817 for (const auto i : c10::irange(node->numExprs())) {
818 indent() << node->output(i) << " = reduction( " << node->input(i)
819 << ", op = " << node->getReductionOpType(i)
820 << ", initial value = " << node->initVal(i)
821 << ", reduction buffer = "
822 << node->reduction_buffers().at(i)->buffer() << " )\n";
823 }
824 indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n";
825 indent() << "read predicate = ";
826 if (node->predicate() != nullptr) {
827 os_ << node->predicate();
828 } else {
829 os_ << "nullptr";
830 }
831 os_ << ",\n";
832 indent() << "write predicate = ";
833 if (node->writePredicate() != nullptr) {
834 os_ << node->writePredicate();
835 } else {
836 os_ << "nullptr";
837 }
838 os_ << ",\n";
839 indent() << "thread predicate = " << node->threadPredicate().toString()
840 << ",\n";
841 indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false")
842 << " )\n";
843 --indent_size_;
844}
845
846void IrPrinter::handle(const kir::GridWelford* node) {
847 std::cerr << "current indent size: " << indent_size_ << std::endl;
848 const auto* welford_op = node->welford_op();
849 indent() << welford_op->outAvg() << " (Avg),\n";
850 indent() << welford_op->outVar() << " (Var),\n";
851 indent() << welford_op->outN() << " (Count)\n";
852 indent() << " = Welford (\n";
853 ++indent_size_;
854 indent() << welford_op->inAvg() << " (Avg),\n";
855 indent() << welford_op->inVar() << " (Var),\n";
856 indent() << welford_op->inN() << " (Count)\n";
857 indent() << "initial value =\n";
858 ++indent_size_;
859 indent() << welford_op->initAvg() << " (Avg),\n";
860 indent() << welford_op->initVar() << " (Var),\n";
861 indent() << welford_op->initN() << " (Count),\n";
862 --indent_size_;
863 indent() << "reduction buffer =\n";
864 ++indent_size_;
865 indent() << node->avg_buffer()->buffer() << " (Avg),\n";
866 indent() << node->var_buffer()->buffer() << " (Var),\n";
867 indent() << node->N_buffer()->buffer() << " (Count),\n";
868 --indent_size_;
869 indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n";
870 indent() << "read predicate = ";
871 if (welford_op->predicate() != nullptr) {
872 os_ << welford_op->predicate();
873 } else {
874 os_ << "nullptr";
875 }
876 os_ << ",\n";
877 indent() << "write predicate = ";
878 if (welford_op->writePredicate() != nullptr) {
879 os_ << welford_op->writePredicate();
880 } else {
881 os_ << "nullptr";
882 }
883 os_ << ",\n";
884 indent() << "grid read predicate = ";
885 if (node->predicate() != nullptr) {
886 os_ << node->predicate();
887 } else {
888 os_ << "nullptr";
889 }
890 os_ << ",\n";
891 indent() << "grid write predicate = ";
892 if (node->writePredicate() != nullptr) {
893 os_ << node->writePredicate();
894 } else {
895 os_ << "nullptr";
896 }
897 os_ << ",\n";
898 indent() << "thread predicate = " << node->threadPredicate().toString()
899 << ",\n";
900 indent() << "allreduce = " << (welford_op->isAllreduce() ? "true" : "false")
901 << " )\n";
902 --indent_size_;
903 std::cerr << "Ending indent size: " << indent_size_ << std::endl;
904}
905
906void IrPrinter::handle(const kir::GroupedGridWelford* node) {
907 indent() << "GroupedGridWelford(\n";
908 ++indent_size_;
909 for (const auto i : c10::irange(node->numExprs())) {
910 indent() << node->outAvg(i) << " (Avg),\n";
911 indent() << node->outVar(i) << " (Var),\n";
912 indent() << node->outN(i) << " (Count)\n";
913 indent() << " = Welford (\n";
914 ++indent_size_;
915 indent() << node->inAvg(i) << " (Avg),\n";
916 indent() << node->inVar(i) << " (Var),\n";
917 indent() << node->inN(i) << " (Count)\n";
918 indent() << "initial value =\n";
919 ++indent_size_;
920 indent() << node->initAvg(i) << " (Avg),\n";
921 indent() << node->initVar(i) << " (Var),\n";
922 indent() << node->initN(i) << " (Count),\n";
923 --indent_size_;
924 indent() << "reduction buffer =\n";
925 ++indent_size_;
926 indent() << node->reduction_buffers()[0].at(i)->buffer() << " (Avg),\n";
927 indent() << node->reduction_buffers()[1].at(i)->buffer() << " (Var),\n";
928 indent() << node->reduction_buffers()[2].at(i)->buffer() << " (Count) )\n";
929 indent_size_ -= 2;
930 }
931 indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n";
932 indent() << "read predicate = ";
933 if (node->predicate() != nullptr) {
934 os_ << node->predicate();
935 } else {
936 os_ << "nullptr";
937 }
938 os_ << ",\n";
939 indent() << "write predicate = ";
940 if (node->writePredicate() != nullptr) {
941 os_ << node->writePredicate();
942 } else {
943 os_ << "nullptr";
944 }
945 os_ << ",\n";
946 indent() << "thread predicate = " << node->threadPredicate().toString()
947 << ",\n";
948 indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false")
949 << " )\n";
950 --indent_size_;
951}
952
953void IrPrinter::handle(const kir::InitMagicZero* node) {
954 indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
955}
956
957void IrPrinter::handle(const kir::UpdateMagicZero* node) {
958 indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
959}
960
961void IrPrinter::handle(const kir::AllocateFusedReduction* node) {
962 indent() << "AllocateFusedReduction(reduction buffer=";
963 handle(node->out());
964 os_ << ")\n";
965}
966
967void IrPrinter::handle(const kir::IntPair* node) {
968 if (print_inline_) {
969 if (node->definition()) {
970 handle(node->definition());
971 return;
972 }
973 }
974 os_ << "iPair" << varName(node);
975}
976
977void IrPrinter::handle(const kir::Swizzle2DInt* node) {
978 if (!print_inline_) {
979 indent();
980 handle(node->out());
981 os_ << " = ";
982 }
983
984 os_ << node->swizzleType() << "2D(";
985 handle(node->inX());
986 os_ << ",";
987 handle(node->inY());
988 os_ << ")";
989}
990
991void IrPrinter::handle(const kir::PairSelect* node) {
992 if (!print_inline_) {
993 indent();
994 handle(node->out());
995 os_ << " = ";
996 }
997
998 handle(node->in());
999
1000 switch (node->selection()) {
1001 case kir::PairSelect::Selection::X:
1002 os_ << ".x";
1003 break;
1004 case kir::PairSelect::Selection::Y:
1005 os_ << ".y";
1006 break;
1007 default:
1008 break;
1009 }
1010}
1011
1012void IrTransformPrinter::handle(Fusion* f) {
1013 auto all_vals = f->usedMathVals();
1014
1015 for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
1016 IrPrinter::handle(tv);
1017 os() << "\n";
1018 printTransforms(tv);
1019 }
1020}
1021
1022void IrTransformPrinter::printTransforms(TensorView* tv) {
1023 auto root_domain = tv->domain()->getRootDomain();
1024 os() << " root domain : (";
1025 for (const auto root_idx : c10::irange(root_domain.size())) {
1026 IrPrinter::handle(root_domain[root_idx]);
1027 if (root_idx + 1 < root_domain.size()) {
1028 os() << ",";
1029 }
1030 }
1031 os() << ")\n";
1032
1033 if (tv->hasRFactor()) {
1034 auto rfactor_domain = tv->domain()->getRFactorDomain();
1035
1036 auto all_exp = DependencyCheck::getAllExprsBetween(
1037 {root_domain.begin(), root_domain.end()},
1038 {rfactor_domain.begin(), rfactor_domain.end()});
1039
1040 for (auto exp : all_exp) {
1041 os() << " ";
1042 IrPrinter::handle(exp);
1043 }
1044
1045 os() << " rfactor domain : (";
1046 for (const auto root_idx : c10::irange(rfactor_domain.size())) {
1047 IrPrinter::handle(rfactor_domain[root_idx]);
1048 if (root_idx + 1 < rfactor_domain.size()) {
1049 os() << ",";
1050 }
1051 }
1052 os() << ")\n";
1053 }
1054
1055 auto from = tv->getMaybeRFactorDomain();
1056 auto all_exp = DependencyCheck::getAllExprsBetween(
1057 {from.begin(), from.end()},
1058 {tv->domain()->domain().begin(), tv->domain()->domain().end()});
1059
1060 for (auto exp : all_exp) {
1061 os() << " ";
1062 IrPrinter::handle(exp);
1063 }
1064}
1065
1066std::ostream& operator<<(std::ostream& os, const Statement* stmt) {
1067 IrPrinter p(os);
1068 p.handle(stmt);
1069 return os;
1070}
1071
1072std::ostream& operator<<(std::ostream& os, Fusion* f) {
1073 IrPrinter p(os);
1074 FusionGuard guard(f);
1075 p.handle(f);
1076 return os;
1077}
1078
1079std::ostream& operator<<(std::ostream& os, Fusion& f) {
1080 return os << &f;
1081}
1082
1083} // namespace cuda
1084} // namespace fuser
1085} // namespace jit
1086} // namespace torch
1087