1 | #include <ir_iostream.h> |
2 | #include <ir_printer.h> |
3 | |
4 | #include <fusion.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_all_nodes.h> |
7 | #include <ir_utils.h> |
8 | #include <kernel.h> |
9 | #include <lower_utils.h> |
10 | |
11 | #include <c10/util/irange.h> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | namespace { |
// Returns the textual spelling of a boolean: "true" or "false".
const char* boolLiteral(bool value) {
  if (value) {
    return "true";
  }
  return "false";
}
22 | |
23 | std::string varName(const Val* val) { |
24 | std::stringstream value_name; |
25 | if (val == nullptr) { |
26 | value_name << "$nullptr" ; |
27 | } else { |
28 | value_name << val->name(); |
29 | } |
30 | return value_name.str(); |
31 | } |
32 | |
33 | } // namespace |
34 | |
35 | // Make sure we can inline something, before we attempt to. |
36 | static void checkInlineable(const Expr* expr) { |
37 | for (auto input : expr->inputs()) { |
38 | TORCH_CHECK( |
39 | input->isScalar(), |
40 | "Printing inline computations involving values other than scalars is not currently supported." ); |
41 | } |
42 | TORCH_CHECK( |
43 | expr->outputs().size() == 1, |
44 | "Cannot print inline computations if there's more than one output." ); |
45 | TORCH_CHECK( |
46 | expr->output(0)->isScalar(), |
47 | "Printing inline computations involving values other than scalars is not currently supported." ); |
48 | } |
49 | |
50 | void IrPrinter::handle(const Statement* s) { |
51 | OptInConstDispatch::handle(s); |
52 | } |
53 | |
54 | void IrPrinter::handle(const Val* v) { |
55 | OptInConstDispatch::handle(v); |
56 | } |
57 | |
58 | void IrPrinter::handle(const Expr* e) { |
59 | OptInConstDispatch::handle(e); |
60 | } |
61 | |
62 | void IrPrinter::handle(Fusion* fusion) { |
63 | FUSER_PERF_SCOPE("IrPrinter" ); |
64 | resetIndent(); |
65 | for (const Expr* expr : fusion->exprs()) { |
66 | handle(expr); |
67 | } |
68 | } |
69 | |
70 | void IrPrinter::handle(const kir::Kernel* kernel) { |
71 | TORCH_CHECK(kernel != nullptr); |
72 | |
73 | // kernel declaration |
74 | os_ << "\nKERNEL (" ; |
75 | for (auto in : kernel->inputs()) { |
76 | handle(in); |
77 | if (in != kernel->inputs().back()) { |
78 | os_ << ", " ; |
79 | } |
80 | } |
81 | os_ << ") -> (" ; |
82 | for (auto out : kernel->outputs()) { |
83 | handle(out); |
84 | if (out != kernel->outputs().back()) { |
85 | os_ << ", " ; |
86 | } |
87 | } |
88 | os_ << ") :\n" ; |
89 | |
90 | // kernel body |
91 | indent_size_++; |
92 | for (auto expr : kernel->topLevelExprs()) { |
93 | handle(expr); |
94 | } |
95 | indent_size_--; |
96 | os_ << "END.\n\n" ; |
97 | } |
98 | |
99 | void IrPrinter::handle(kir::Kernel& kernel) { |
100 | handle(&kernel); |
101 | } |
102 | |
103 | void IrPrinter::handleScope(const kir::Scope& scope) { |
104 | // Save the uses of the parent scope |
105 | indent_size_++; |
106 | for (auto expr : scope.exprs()) { |
107 | handle(expr); |
108 | } |
109 | indent_size_--; |
110 | } |
111 | |
112 | void IrPrinter::handle(const IterDomain* id) { |
113 | os_ << id->getIterType(); |
114 | os_ << id->getParallelType(); |
115 | os_ << varName(id); |
116 | os_ << "{" ; |
117 | if (!id->start()->isZeroInt()) { |
118 | print_inline(id->start()); |
119 | os_ << " : " ; |
120 | } |
121 | if (id->stop() != id->extent()) { |
122 | print_inline(id->stop()); |
123 | os_ << " : " ; |
124 | } |
125 | if (id->isBroadcast() && id->hasExpandedExtent()) { |
126 | print_inline(id->expandedExtent()); |
127 | } else { |
128 | print_inline(id->extent()); |
129 | } |
130 | os_ << "}" ; |
131 | if (id->isRFactorProduct()) |
132 | os_ << "rf" ; |
133 | if (id->hasPaddingToMultipleOfWarp()) { |
134 | os_ << "_p" ; |
135 | } |
136 | } |
137 | |
138 | void IrPrinter::handle(const TensorDomain* td) { |
139 | if (td->nDims() == 0) { |
140 | os_ << "[ 0 ]" ; |
141 | return; |
142 | } |
143 | os_ << "[ " ; |
144 | for (const auto i : c10::irange(td->nDims())) { |
145 | handle(td->axis(i)); |
146 | if (i != td->nDims() - 1) |
147 | os_ << ", " ; |
148 | } |
149 | os_ << " ]" ; |
150 | } |
151 | |
152 | void IrPrinter::handle(const TensorView* tv) { |
153 | os_ << "T" << varName(tv); |
154 | switch (tv->getMemoryType()) { |
155 | case MemoryType::Global: |
156 | os_ << "_g" ; |
157 | break; |
158 | case MemoryType::Shared: |
159 | os_ << "_s" ; |
160 | break; |
161 | case MemoryType::Local: |
162 | os_ << "_l" ; |
163 | break; |
164 | } |
165 | handle(tv->domain()); |
166 | |
167 | if (tv->getComputeAtPosition() > 0) { |
168 | os_ << " ca_pos( " ; |
169 | os_ << tv->getComputeAtPosition(); |
170 | os_ << " )" ; |
171 | } |
172 | if (tv->getMaxProducerPosition() > 0) { |
173 | os_ << " produce_pos( " ; |
174 | os_ << tv->getMaxProducerPosition(); |
175 | os_ << ")" ; |
176 | } |
177 | } |
178 | |
179 | void IrPrinter::handle(const Bool* b) { |
180 | if (print_inline_ && b->definition() != nullptr) { |
181 | os_ << "( " ; |
182 | handle(b->definition()); |
183 | os_ << " )" ; |
184 | return; |
185 | } |
186 | |
187 | os_ << "b" << varName(b); |
188 | if (b->isConst()) { |
189 | os_ << "(" << (b->value().value() ? "true" : "false" ) << ")" ; |
190 | } |
191 | } |
192 | |
193 | void IrPrinter::handle(const Double* d) { |
194 | if (print_inline_ && d->definition() != nullptr) { |
195 | os_ << "( " ; |
196 | handle(d->definition()); |
197 | os_ << " )" ; |
198 | return; |
199 | } |
200 | |
201 | if (d->isSymbolic()) { |
202 | os_ << "d" << varName(d); |
203 | } else { |
204 | os_ << "double(" |
205 | << std::setprecision( |
206 | std::numeric_limits<Double::ScalarType>::max_digits10) |
207 | << *(d->value()) << ")" ; |
208 | } |
209 | } |
210 | |
211 | void IrPrinter::handle(const Int* i) { |
212 | if (print_inline_) { |
213 | if (auto def = i->definition()) { |
214 | os_ << "( " ; |
215 | handle(def); |
216 | os_ << " )" ; |
217 | return; |
218 | } |
219 | } |
220 | |
221 | if (i->isSymbolic()) { |
222 | os_ << "i" << varName(i); |
223 | } else { |
224 | os_ << *(i->value()); |
225 | } |
226 | } |
227 | |
228 | void IrPrinter::handle(const ComplexDouble* c) { |
229 | if (print_inline_) { |
230 | if (auto def = c->definition()) { |
231 | os_ << "( " ; |
232 | handle(def); |
233 | os_ << " )" ; |
234 | return; |
235 | } |
236 | } |
237 | |
238 | if (c->isSymbolic()) { |
239 | os_ << "c" << varName(c); |
240 | } else { |
241 | os_ << "std::complex<double>" |
242 | << std::setprecision(std::numeric_limits<double>::max_digits10) |
243 | << *(c->value()); |
244 | } |
245 | } |
246 | |
247 | void IrPrinter::handle(const NamedScalar* ns) { |
248 | os_ << ns->name(); |
249 | } |
250 | |
251 | void IrPrinter::handle(const FullOp* fop) { |
252 | if (!print_inline_) { |
253 | indent(); |
254 | os_ << fop->output(0) << "\n" ; |
255 | indent_size_++; |
256 | indent(); |
257 | os_ << " = " ; |
258 | } else { |
259 | checkInlineable(fop); |
260 | } |
261 | |
262 | os_ << "full({" ; |
263 | for (auto i : c10::irange(fop->inputs().size())) { |
264 | if (i == fop->inputs().size() - 1) { |
265 | os_ << "}" ; |
266 | } |
267 | if (i > 0) { |
268 | os_ << ", " ; |
269 | } |
270 | handle(fop->input(i)); |
271 | } |
272 | os_ << ", " << fop->dtype() << ")" ; |
273 | |
274 | indent_size_--; |
275 | |
276 | if (!print_inline_) |
277 | os_ << ";\n" ; |
278 | } |
279 | |
280 | void IrPrinter::handle(const ARangeOp* aop) { |
281 | if (!print_inline_) { |
282 | indent() << aop->output(0); |
283 | os_ << "\n" ; |
284 | indent_size_++; |
285 | indent(); |
286 | os_ << " = " ; |
287 | } else { |
288 | checkInlineable(aop); |
289 | } |
290 | |
291 | os_ << "arange(" ; |
292 | handle(aop->start()); |
293 | os_ << ", " ; |
294 | handle(aop->end()); |
295 | os_ << ", " ; |
296 | handle(aop->step()); |
297 | os_ << ", " << aop->dtype() << ")" ; |
298 | |
299 | indent_size_--; |
300 | |
301 | if (!print_inline_) |
302 | os_ << ";\n" ; |
303 | } |
304 | |
305 | void IrPrinter::handle(const EyeOp* eop) { |
306 | if (!print_inline_) { |
307 | indent(); |
308 | os_ << eop->output(0) << "\n" ; |
309 | indent_size_++; |
310 | indent(); |
311 | os_ << " = " ; |
312 | } else { |
313 | checkInlineable(eop); |
314 | } |
315 | |
316 | os_ << "eye(" ; |
317 | handle(eop->input(0)); |
318 | os_ << ", " << eop->dtype() << ")" ; |
319 | |
320 | indent_size_--; |
321 | |
322 | if (!print_inline_) |
323 | os_ << ";\n" ; |
324 | } |
325 | |
326 | void IrPrinter::handle(const UnaryOp* uop) { |
327 | bool istvop = ir_utils::isTvOp(uop); |
328 | if (!print_inline_) { |
329 | indent() << uop->out(); |
330 | if (istvop) { |
331 | os_ << "\n" ; |
332 | indent_size_++; |
333 | indent(); |
334 | } |
335 | os_ << " = " ; |
336 | } else { |
337 | checkInlineable(uop); |
338 | } |
339 | |
340 | auto op_type = uop->getUnaryOpType(); |
341 | |
342 | if (auto inline_uop = inline_op_str(op_type)) { |
343 | os_ << inline_uop.value(); |
344 | handle(uop->in()); |
345 | } else { |
346 | if (op_type == UnaryOpType::Cast) { |
347 | c10::optional<std::string> cast_str = cast_func_str(std::make_pair( |
348 | uop->in()->getDataType().value(), uop->out()->getDataType().value())); |
349 | TORCH_INTERNAL_ASSERT(cast_str != c10::nullopt, "Unsupported Cast" ); |
350 | os_ << cast_str.value(); |
351 | } else { |
352 | if (alsoBooleanOperator(op_type) && |
353 | uop->out()->getDataType().value() == DataType::Bool) { |
354 | os_ << stringifyBooleanOp(op_type); |
355 | } else { |
356 | os_ << op_type; |
357 | } |
358 | if (uop->out()->getDataType().value() == DataType::Float && |
359 | needFloatSuffix(op_type)) { |
360 | os_ << "f" ; |
361 | } |
362 | } |
363 | os_ << "(" ; |
364 | handle(uop->in()); |
365 | os_ << ")" ; |
366 | } |
367 | |
368 | if (istvop) |
369 | indent_size_--; |
370 | |
371 | if (!print_inline_) |
372 | os_ << ";\n" ; |
373 | } |
374 | |
375 | void IrPrinter::handle(const BinaryOp* bop) { |
376 | bool istvop = ir_utils::isTvOp(bop); |
377 | if (!print_inline_) { |
378 | indent() << bop->out(); |
379 | |
380 | // tensor operations tend to be long, break them up into multiple lines |
381 | if (istvop) { |
382 | os_ << "\n" ; |
383 | indent_size_++; |
384 | indent(); |
385 | } |
386 | |
387 | os_ << " = " ; |
388 | } else { |
389 | checkInlineable(bop); |
390 | } |
391 | |
392 | auto op_type = bop->getBinaryOpType(); |
393 | if (auto inline_bop = inline_op_str(op_type)) { |
394 | handle(bop->lhs()); |
395 | if (istvop) { |
396 | os_ << "\n" ; |
397 | indent(); |
398 | } |
399 | os_ << " " << inline_bop.value() << " " ; |
400 | handle(bop->rhs()); |
401 | } else { |
402 | if (alsoBooleanOperator(op_type) && |
403 | bop->out()->getDataType().value() == DataType::Bool) { |
404 | os_ << stringifyBooleanOp(op_type); |
405 | } else { |
406 | os_ << op_type; |
407 | } |
408 | if (bop->out()->getDataType().value() == DataType::Float && |
409 | needFloatSuffix(op_type)) { |
410 | os_ << "f" ; |
411 | } |
412 | os_ << "(" ; |
413 | handle(bop->lhs()); |
414 | if (istvop) { |
415 | os_ << "\n" ; |
416 | indent(); |
417 | } |
418 | os_ << ", " ; |
419 | handle(bop->rhs()); |
420 | os_ << ")" ; |
421 | } |
422 | |
423 | if (istvop) |
424 | indent_size_--; |
425 | |
426 | if (!print_inline_) |
427 | os_ << ";\n" ; |
428 | } |
429 | |
430 | void IrPrinter::handle(const TernaryOp* top) { |
431 | bool istvop = ir_utils::isTvOp(top); |
432 | if (!print_inline_) { |
433 | indent(); |
434 | os_ << top->out(); |
435 | |
436 | // tensor operations tend to be long, break them up into multiple lines |
437 | if (istvop) { |
438 | os_ << "\n" ; |
439 | indent_size_++; |
440 | indent(); |
441 | } |
442 | |
443 | os_ << " = " ; |
444 | } else { |
445 | checkInlineable(top); |
446 | } |
447 | |
448 | os_ << top->getTernaryOpType() << "(" ; |
449 | handle(top->in1()); |
450 | if (istvop) { |
451 | os_ << "\n" ; |
452 | indent(); |
453 | } |
454 | os_ << ", " ; |
455 | handle(top->in2()); |
456 | if (istvop) { |
457 | os_ << "\n" ; |
458 | indent(); |
459 | } |
460 | os_ << ", " ; |
461 | handle(top->in3()); |
462 | os_ << ")" ; |
463 | |
464 | if (istvop) |
465 | indent_size_--; |
466 | |
467 | if (!print_inline_) |
468 | os_ << ";\n" ; |
469 | } |
470 | |
471 | void IrPrinter::handle(const RNGOp* rop) { |
472 | if (!print_inline_) { |
473 | indent(); |
474 | os_ << rop->output(0) << "\n" ; |
475 | indent_size_++; |
476 | indent(); |
477 | os_ << " = " ; |
478 | } else { |
479 | checkInlineable(rop); |
480 | } |
481 | |
482 | os_ << rop->getRNGOpType() << "({" ; |
483 | bool first = true; |
484 | for (auto i : rop->getShape()) { |
485 | if (!first) { |
486 | os_ << ", " ; |
487 | } |
488 | handle(i); |
489 | first = false; |
490 | } |
491 | os_ << "}" ; |
492 | for (auto i : rop->getParameters()) { |
493 | os_ << ", " ; |
494 | handle(i); |
495 | } |
496 | os_ << ", " << rop->dtype() << ")" ; |
497 | |
498 | indent_size_--; |
499 | |
500 | if (!print_inline_) { |
501 | os_ << ";\n" ; |
502 | } |
503 | } |
504 | |
505 | void IrPrinter::handle(const ReductionOp* rop) { |
506 | indent() << rop->out() << "\n" ; |
507 | indent() << " = reduction( " << rop->in() |
508 | << ", op = " << rop->getReductionOpType() |
509 | << ", initial value = " << rop->init() |
510 | << ", allreduce = " << (rop->isAllreduce() ? "true" : "false" ) |
511 | << " )\n" ; |
512 | } |
513 | |
514 | void IrPrinter::handle(const GroupedReductionOp* grouped_rop) { |
515 | indent() << "GroupedReductionOp(\n" ; |
516 | ++indent_size_; |
517 | for (const auto i : c10::irange(grouped_rop->numExprs())) { |
518 | indent() << grouped_rop->output(i) << " = reduction( " |
519 | << grouped_rop->input(i) |
520 | << ", op = " << grouped_rop->getReductionOpType(i) |
521 | << ", initial value = " << grouped_rop->initVal(i) << " )\n" ; |
522 | } |
523 | indent() << "allreduce = " << (grouped_rop->isAllreduce() ? "true" : "false" ) |
524 | << " )\n" ; |
525 | --indent_size_; |
526 | } |
527 | |
528 | void IrPrinter::handle(const WelfordOp* wop) { |
529 | indent() << wop->outAvg() << "(Avg),\n" |
530 | << wop->outVar() << "(Var),\n" |
531 | << wop->outN() << "(Count)" |
532 | << "\n = Welford ( " ; |
533 | if (wop->singleValue()) { |
534 | os_ << wop->inAvg() << "(Avg), " ; |
535 | } else { |
536 | os_ << wop->inAvg() << "(Avg)\n " << wop->inVar() << "(Var)\n " |
537 | << wop->inN() << "(Count)" ; |
538 | } |
539 | if (wop->hasInit()) { |
540 | os_ << "\n initial value = " << wop->initAvg() << "(Avg)\n " |
541 | << wop->initVar() << "(Var)\n " << wop->initN() << "(N)" ; |
542 | } |
543 | os_ << "\n allreduce = " << (wop->isAllreduce() ? "true" : "false" ); |
544 | os_ << " )\n" ; |
545 | } |
546 | |
547 | void IrPrinter::handle(const GroupedWelfordOp* grouped_wop) { |
548 | indent() << "GroupedWelford(\n" ; |
549 | ++indent_size_; |
550 | for (const auto i : c10::irange(grouped_wop->numExprs())) { |
551 | indent() << grouped_wop->outAvg(i) << " (Avg),\n" ; |
552 | indent() << grouped_wop->outVar(i) << " (Var),\n" ; |
553 | indent() << grouped_wop->outN(i) << " (Count)\n" ; |
554 | indent() << " = Welford ( " ; |
555 | ++indent_size_; |
556 | indent() << grouped_wop->inAvg(i) << " (Avg),\n" ; |
557 | indent() << grouped_wop->inVar(i) << " (Var),\n" ; |
558 | indent() << grouped_wop->inN(i) << " (Count)\n" ; |
559 | indent() << "initial value =\n" ; |
560 | ++indent_size_; |
561 | indent() << grouped_wop->initAvg(i) << " (Avg),\n" ; |
562 | indent() << grouped_wop->initVar(i) << " (Var),\n" ; |
563 | indent() << grouped_wop->initN(i) << " (Count) )\n" ; |
564 | indent_size_ -= 2; |
565 | } |
566 | indent() << "allreduce = " << (grouped_wop->isAllreduce() ? "true" : "false" ) |
567 | << " )\n" ; |
568 | --indent_size_; |
569 | } |
570 | |
571 | void IrPrinter::handle(const LoadStoreOp* ldst) { |
572 | indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in() |
573 | << " )\n" ; |
574 | } |
575 | |
576 | void IrPrinter::handle(const BroadcastOp* bop) { |
577 | indent() << bop->out() << "\n" ; |
578 | indent() << " = broadcast( " << bop->in() << " )\n" ; |
579 | } |
580 | |
581 | void IrPrinter::handle(const Split* s) { |
582 | os_ << (s->innerSplit() ? "Split: " : "Outer split: " ); |
583 | handle(s->in()); |
584 | os_ << " by factor " << s->factor() << " -> " ; |
585 | handle(s->outer()); |
586 | os_ << ", " ; |
587 | handle(s->inner()); |
588 | if (s->startOffset()) { |
589 | os_ << ", start offset: " ; |
590 | handle(s->startOffset()); |
591 | } |
592 | if (s->stopOffset()) { |
593 | os_ << ", stop offset: " ; |
594 | handle(s->stopOffset()); |
595 | } |
596 | os_ << "\n" ; |
597 | } |
598 | |
599 | void IrPrinter::handle(const Merge* m) { |
600 | os_ << "Merge: " ; |
601 | handle(m->outer()); |
602 | os_ << " and " ; |
603 | handle(m->inner()); |
604 | os_ << " -> " ; |
605 | handle(m->out()); |
606 | os_ << "\n" ; |
607 | } |
608 | |
609 | void IrPrinter::handle(const Swizzle2D* s) { |
610 | os_ << s->swizzleType() << "(2D): " ; |
611 | handle(s->inX()); |
612 | os_ << " , " ; |
613 | handle(s->inY()); |
614 | os_ << " -> " ; |
615 | handle(s->outX()); |
616 | os_ << " , " ; |
617 | handle(s->outY()); |
618 | os_ << "\n" ; |
619 | } |
620 | |
621 | void IrPrinter::handle(const TransposeOp* top) { |
622 | indent() << top->out() << " = transpose( " << top->in() << " )\n" ; |
623 | } |
624 | |
625 | void IrPrinter::handle(const ExpandOp* eop) { |
626 | indent() << eop->out() << " = expand( " << eop->in() << ", {" ; |
627 | std::stringstream ss; |
628 | for (auto expanded_extent : eop->expanded_extents()) { |
629 | if (ss.tellp()) { |
630 | ss << ", " ; |
631 | } |
632 | ss << expanded_extent; |
633 | } |
634 | os_ << ss.str() << "} )\n" ; |
635 | } |
636 | |
637 | void IrPrinter::handle(const ShiftOp* sop) { |
638 | indent() << sop->out() << " = shift( " << sop->in() << ", {" << sop->offsets() |
639 | << "}, {" << sop->padWidth() << "} )\n" ; |
640 | } |
641 | |
642 | void IrPrinter::handle(const MmaOp* mma) { |
643 | indent() << mma->out() << " = mma(" << mma->inA() << "," << mma->inB(); |
644 | os_ << ")\n" ; |
645 | } |
646 | |
647 | void IrPrinter::handle(const GatherOp* op) { |
648 | indent() << op->out() << " = gather( " << op->in() << ", {" ; |
649 | bool no_comma = true; |
650 | for (const auto& s : op->windowShape()) { |
651 | if (!no_comma) { |
652 | os_ << ", " ; |
653 | } |
654 | os_ << s; |
655 | no_comma = false; |
656 | } |
657 | os_ << "}, {" ; |
658 | no_comma = true; |
659 | for (const auto& pad : op->padWidth()) { |
660 | if (!no_comma) { |
661 | os_ << ", " ; |
662 | } |
663 | os_ << "{" << pad[0] << ", " << pad[1] << "}" ; |
664 | no_comma = false; |
665 | } |
666 | os_ << "} )\n" ; |
667 | } |
668 | |
669 | void IrPrinter::handle(const ViewAsScalar* top) { |
670 | indent() << top->out() << " = view_as_scalar( " << top->in() << ", " |
671 | << top->vector_id() << " )\n" ; |
672 | } |
673 | |
674 | void IrPrinter::handle(const ViewOp* top) { |
675 | indent() << top->out() << " = view( " << top->in() << " )\n" ; |
676 | } |
677 | |
678 | void IrPrinter::handle(const kir::Predicate* node) { |
679 | switch (node->predicate_type()) { |
680 | case PredicateType::Manual: { |
681 | os_ << node->value(); |
682 | break; |
683 | } |
684 | default: |
685 | os_ << node->predicate_type(); |
686 | break; |
687 | } |
688 | } |
689 | |
690 | void IrPrinter::handle(const kir::TensorIndex* ti) { |
691 | os_ << "T" << varName(ti); |
692 | switch (ti->view()->getMemoryType()) { |
693 | case MemoryType::Global: |
694 | os_ << "_g" ; |
695 | break; |
696 | case MemoryType::Shared: |
697 | os_ << "_s" ; |
698 | break; |
699 | case MemoryType::Local: |
700 | os_ << "_l" ; |
701 | break; |
702 | } |
703 | os_ << "[" ; |
704 | for (auto index : ti->indices()) { |
705 | print_inline(index); |
706 | if (index != ti->indices().back()) { |
707 | os_ << ", " ; |
708 | } |
709 | } |
710 | os_ << "]" ; |
711 | os_ << " view( T" << varName(ti->view()) << " )" ; |
712 | } |
713 | |
714 | void IrPrinter::handle(const kir::Allocate* node) { |
715 | indent(); |
716 | handle(node->buffer()); |
717 | os_ << " = ALLOCATE(" |
718 | << "mem_type=" << node->memoryType() << ", " |
719 | << "size=" ; |
720 | print_inline(node->size()); |
721 | os_ << ", " |
722 | << "zero_init=" << boolLiteral(node->zeroInit()) << ")\n" ; |
723 | if (node->alias() != nullptr) { |
724 | indent() << kTab << ".alias=" ; |
725 | handle(node->alias()->buffer()); |
726 | os_ << "\n" ; |
727 | } |
728 | } |
729 | |
730 | void IrPrinter::handle(const kir::BlockSync* node) { |
731 | indent() << "BLOCKSYNC(war_hazard=" << boolLiteral(node->isWarHazardSync()) |
732 | << ")\n" ; |
733 | } |
734 | |
735 | void IrPrinter::handle(const kir::CpAsyncWait* node) { |
736 | indent() << "CPASYNC_WAIT(" << node->keepStages() << ")\n" ; |
737 | } |
738 | |
739 | void IrPrinter::handle(const kir::CpAsyncCommit* node) { |
740 | indent() << "CPASYNC_WAIT()\n" ; |
741 | } |
742 | |
743 | void IrPrinter::handle(const kir::GridSync* node) { |
744 | indent() << "GRIDSYNC(" << node->syncDims().toString() << ", " ; |
745 | handle(node->syncBuffer()); |
746 | os_ << ")\n" ; |
747 | } |
748 | |
749 | void IrPrinter::handle(const kir::ForLoop* node) { |
750 | indent() << "FOR " ; |
751 | handle(node->index()); |
752 | os_ << " in " ; |
753 | handle(node->iter_domain()); |
754 | os_ << ":\n" ; |
755 | handleScope(node->body()); |
756 | } |
757 | |
758 | void IrPrinter::handle(const kir::IfThenElse* node) { |
759 | indent() << "IF " ; |
760 | handle(node->predicate()); |
761 | os_ << ":\n" ; |
762 | handleScope(node->thenBody()); |
763 | if (node->hasElse()) { |
764 | indent() << "ELSE:\n" ; |
765 | handleScope(node->elseBody()); |
766 | } |
767 | } |
768 | |
769 | void IrPrinter::handle(const kir::GridBroadcast* node) { |
770 | const auto* broadcast_op = node->broadcast_op(); |
771 | indent(); |
772 | handle(broadcast_op->out()); |
773 | os_ << " = " |
774 | << "GRID_BROADCAST(in=" ; |
775 | handle(broadcast_op->in()); |
776 | os_ << ")\n" ; |
777 | indent() << kTab << ".broadcast_buffer=" ; |
778 | handle(node->broadcast_buffer()->buffer()); |
779 | os_ << "\n" ; |
780 | indent() << kTab << ".sync_buffer=" ; |
781 | handle(node->sync_buffer()->buffer()); |
782 | os_ << "\n" ; |
783 | } |
784 | |
785 | void IrPrinter::handle(const kir::GridReduction* node) { |
786 | indent() << node->out() << " = reduction( " << node->in() |
787 | << ", op = " << node->getReductionOpType() |
788 | << ", initial value = " << node->init() << ",\n" ; |
789 | ++indent_size_; |
790 | indent() << "reduction buffer = " << node->reduction_buffer()->buffer() |
791 | << ",\n" ; |
792 | indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n" ; |
793 | indent() << "read predicate = " ; |
794 | if (node->predicate() != nullptr) { |
795 | os_ << node->predicate(); |
796 | } else { |
797 | os_ << "nullptr" ; |
798 | } |
799 | os_ << ",\n" ; |
800 | indent() << "write predicate = " ; |
801 | if (node->writePredicate() != nullptr) { |
802 | os_ << node->writePredicate(); |
803 | } else { |
804 | os_ << "nullptr" ; |
805 | } |
806 | os_ << ",\n" ; |
807 | indent() << "thread predicate = " << node->threadPredicate().toString() |
808 | << ",\n" ; |
809 | indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false" ) |
810 | << " )\n" ; |
811 | --indent_size_; |
812 | } |
813 | |
814 | void IrPrinter::handle(const kir::GroupedGridReduction* node) { |
815 | indent() << "GroupedGridReduction(\n" ; |
816 | ++indent_size_; |
817 | for (const auto i : c10::irange(node->numExprs())) { |
818 | indent() << node->output(i) << " = reduction( " << node->input(i) |
819 | << ", op = " << node->getReductionOpType(i) |
820 | << ", initial value = " << node->initVal(i) |
821 | << ", reduction buffer = " |
822 | << node->reduction_buffers().at(i)->buffer() << " )\n" ; |
823 | } |
824 | indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n" ; |
825 | indent() << "read predicate = " ; |
826 | if (node->predicate() != nullptr) { |
827 | os_ << node->predicate(); |
828 | } else { |
829 | os_ << "nullptr" ; |
830 | } |
831 | os_ << ",\n" ; |
832 | indent() << "write predicate = " ; |
833 | if (node->writePredicate() != nullptr) { |
834 | os_ << node->writePredicate(); |
835 | } else { |
836 | os_ << "nullptr" ; |
837 | } |
838 | os_ << ",\n" ; |
839 | indent() << "thread predicate = " << node->threadPredicate().toString() |
840 | << ",\n" ; |
841 | indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false" ) |
842 | << " )\n" ; |
843 | --indent_size_; |
844 | } |
845 | |
846 | void IrPrinter::handle(const kir::GridWelford* node) { |
847 | std::cerr << "current indent size: " << indent_size_ << std::endl; |
848 | const auto* welford_op = node->welford_op(); |
849 | indent() << welford_op->outAvg() << " (Avg),\n" ; |
850 | indent() << welford_op->outVar() << " (Var),\n" ; |
851 | indent() << welford_op->outN() << " (Count)\n" ; |
852 | indent() << " = Welford (\n" ; |
853 | ++indent_size_; |
854 | indent() << welford_op->inAvg() << " (Avg),\n" ; |
855 | indent() << welford_op->inVar() << " (Var),\n" ; |
856 | indent() << welford_op->inN() << " (Count)\n" ; |
857 | indent() << "initial value =\n" ; |
858 | ++indent_size_; |
859 | indent() << welford_op->initAvg() << " (Avg),\n" ; |
860 | indent() << welford_op->initVar() << " (Var),\n" ; |
861 | indent() << welford_op->initN() << " (Count),\n" ; |
862 | --indent_size_; |
863 | indent() << "reduction buffer =\n" ; |
864 | ++indent_size_; |
865 | indent() << node->avg_buffer()->buffer() << " (Avg),\n" ; |
866 | indent() << node->var_buffer()->buffer() << " (Var),\n" ; |
867 | indent() << node->N_buffer()->buffer() << " (Count),\n" ; |
868 | --indent_size_; |
869 | indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n" ; |
870 | indent() << "read predicate = " ; |
871 | if (welford_op->predicate() != nullptr) { |
872 | os_ << welford_op->predicate(); |
873 | } else { |
874 | os_ << "nullptr" ; |
875 | } |
876 | os_ << ",\n" ; |
877 | indent() << "write predicate = " ; |
878 | if (welford_op->writePredicate() != nullptr) { |
879 | os_ << welford_op->writePredicate(); |
880 | } else { |
881 | os_ << "nullptr" ; |
882 | } |
883 | os_ << ",\n" ; |
884 | indent() << "grid read predicate = " ; |
885 | if (node->predicate() != nullptr) { |
886 | os_ << node->predicate(); |
887 | } else { |
888 | os_ << "nullptr" ; |
889 | } |
890 | os_ << ",\n" ; |
891 | indent() << "grid write predicate = " ; |
892 | if (node->writePredicate() != nullptr) { |
893 | os_ << node->writePredicate(); |
894 | } else { |
895 | os_ << "nullptr" ; |
896 | } |
897 | os_ << ",\n" ; |
898 | indent() << "thread predicate = " << node->threadPredicate().toString() |
899 | << ",\n" ; |
900 | indent() << "allreduce = " << (welford_op->isAllreduce() ? "true" : "false" ) |
901 | << " )\n" ; |
902 | --indent_size_; |
903 | std::cerr << "Ending indent size: " << indent_size_ << std::endl; |
904 | } |
905 | |
906 | void IrPrinter::handle(const kir::GroupedGridWelford* node) { |
907 | indent() << "GroupedGridWelford(\n" ; |
908 | ++indent_size_; |
909 | for (const auto i : c10::irange(node->numExprs())) { |
910 | indent() << node->outAvg(i) << " (Avg),\n" ; |
911 | indent() << node->outVar(i) << " (Var),\n" ; |
912 | indent() << node->outN(i) << " (Count)\n" ; |
913 | indent() << " = Welford (\n" ; |
914 | ++indent_size_; |
915 | indent() << node->inAvg(i) << " (Avg),\n" ; |
916 | indent() << node->inVar(i) << " (Var),\n" ; |
917 | indent() << node->inN(i) << " (Count)\n" ; |
918 | indent() << "initial value =\n" ; |
919 | ++indent_size_; |
920 | indent() << node->initAvg(i) << " (Avg),\n" ; |
921 | indent() << node->initVar(i) << " (Var),\n" ; |
922 | indent() << node->initN(i) << " (Count),\n" ; |
923 | --indent_size_; |
924 | indent() << "reduction buffer =\n" ; |
925 | ++indent_size_; |
926 | indent() << node->reduction_buffers()[0].at(i)->buffer() << " (Avg),\n" ; |
927 | indent() << node->reduction_buffers()[1].at(i)->buffer() << " (Var),\n" ; |
928 | indent() << node->reduction_buffers()[2].at(i)->buffer() << " (Count) )\n" ; |
929 | indent_size_ -= 2; |
930 | } |
931 | indent() << "sync buffer = " << node->sync_buffer()->buffer() << ",\n" ; |
932 | indent() << "read predicate = " ; |
933 | if (node->predicate() != nullptr) { |
934 | os_ << node->predicate(); |
935 | } else { |
936 | os_ << "nullptr" ; |
937 | } |
938 | os_ << ",\n" ; |
939 | indent() << "write predicate = " ; |
940 | if (node->writePredicate() != nullptr) { |
941 | os_ << node->writePredicate(); |
942 | } else { |
943 | os_ << "nullptr" ; |
944 | } |
945 | os_ << ",\n" ; |
946 | indent() << "thread predicate = " << node->threadPredicate().toString() |
947 | << ",\n" ; |
948 | indent() << "allreduce = " << (node->isAllreduce() ? "true" : "false" ) |
949 | << " )\n" ; |
950 | --indent_size_; |
951 | } |
952 | |
953 | void IrPrinter::handle(const kir::InitMagicZero* node) { |
954 | indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n" ; |
955 | } |
956 | |
957 | void IrPrinter::handle(const kir::UpdateMagicZero* node) { |
958 | indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n" ; |
959 | } |
960 | |
961 | void IrPrinter::handle(const kir::AllocateFusedReduction* node) { |
962 | indent() << "AllocateFusedReduction(reduction buffer=" ; |
963 | handle(node->out()); |
964 | os_ << ")\n" ; |
965 | } |
966 | |
967 | void IrPrinter::handle(const kir::IntPair* node) { |
968 | if (print_inline_) { |
969 | if (node->definition()) { |
970 | handle(node->definition()); |
971 | return; |
972 | } |
973 | } |
974 | os_ << "iPair" << varName(node); |
975 | } |
976 | |
977 | void IrPrinter::handle(const kir::Swizzle2DInt* node) { |
978 | if (!print_inline_) { |
979 | indent(); |
980 | handle(node->out()); |
981 | os_ << " = " ; |
982 | } |
983 | |
984 | os_ << node->swizzleType() << "2D(" ; |
985 | handle(node->inX()); |
986 | os_ << "," ; |
987 | handle(node->inY()); |
988 | os_ << ")" ; |
989 | } |
990 | |
991 | void IrPrinter::handle(const kir::PairSelect* node) { |
992 | if (!print_inline_) { |
993 | indent(); |
994 | handle(node->out()); |
995 | os_ << " = " ; |
996 | } |
997 | |
998 | handle(node->in()); |
999 | |
1000 | switch (node->selection()) { |
1001 | case kir::PairSelect::Selection::X: |
1002 | os_ << ".x" ; |
1003 | break; |
1004 | case kir::PairSelect::Selection::Y: |
1005 | os_ << ".y" ; |
1006 | break; |
1007 | default: |
1008 | break; |
1009 | } |
1010 | } |
1011 | |
1012 | void IrTransformPrinter::handle(Fusion* f) { |
1013 | auto all_vals = f->usedMathVals(); |
1014 | |
1015 | for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) { |
1016 | IrPrinter::handle(tv); |
1017 | os() << "\n" ; |
1018 | printTransforms(tv); |
1019 | } |
1020 | } |
1021 | |
1022 | void IrTransformPrinter::printTransforms(TensorView* tv) { |
1023 | auto root_domain = tv->domain()->getRootDomain(); |
1024 | os() << " root domain : (" ; |
1025 | for (const auto root_idx : c10::irange(root_domain.size())) { |
1026 | IrPrinter::handle(root_domain[root_idx]); |
1027 | if (root_idx + 1 < root_domain.size()) { |
1028 | os() << "," ; |
1029 | } |
1030 | } |
1031 | os() << ")\n" ; |
1032 | |
1033 | if (tv->hasRFactor()) { |
1034 | auto rfactor_domain = tv->domain()->getRFactorDomain(); |
1035 | |
1036 | auto all_exp = DependencyCheck::getAllExprsBetween( |
1037 | {root_domain.begin(), root_domain.end()}, |
1038 | {rfactor_domain.begin(), rfactor_domain.end()}); |
1039 | |
1040 | for (auto exp : all_exp) { |
1041 | os() << " " ; |
1042 | IrPrinter::handle(exp); |
1043 | } |
1044 | |
1045 | os() << " rfactor domain : (" ; |
1046 | for (const auto root_idx : c10::irange(rfactor_domain.size())) { |
1047 | IrPrinter::handle(rfactor_domain[root_idx]); |
1048 | if (root_idx + 1 < rfactor_domain.size()) { |
1049 | os() << "," ; |
1050 | } |
1051 | } |
1052 | os() << ")\n" ; |
1053 | } |
1054 | |
1055 | auto from = tv->getMaybeRFactorDomain(); |
1056 | auto all_exp = DependencyCheck::getAllExprsBetween( |
1057 | {from.begin(), from.end()}, |
1058 | {tv->domain()->domain().begin(), tv->domain()->domain().end()}); |
1059 | |
1060 | for (auto exp : all_exp) { |
1061 | os() << " " ; |
1062 | IrPrinter::handle(exp); |
1063 | } |
1064 | } |
1065 | |
1066 | std::ostream& operator<<(std::ostream& os, const Statement* stmt) { |
1067 | IrPrinter p(os); |
1068 | p.handle(stmt); |
1069 | return os; |
1070 | } |
1071 | |
1072 | std::ostream& operator<<(std::ostream& os, Fusion* f) { |
1073 | IrPrinter p(os); |
1074 | FusionGuard guard(f); |
1075 | p.handle(f); |
1076 | return os; |
1077 | } |
1078 | |
1079 | std::ostream& operator<<(std::ostream& os, Fusion& f) { |
1080 | return os << &f; |
1081 | } |
1082 | |
1083 | } // namespace cuda |
1084 | } // namespace fuser |
1085 | } // namespace jit |
1086 | } // namespace torch |
1087 | |