1#pragma once
2
3#include <c10/macros/Export.h>
4
5#include <dispatch.h>
6
7#include <c10/util/irange.h>
8
9#include <iostream>
10
11namespace torch {
12namespace jit {
13namespace fuser {
14namespace cuda {
15
16class Fusion;
17namespace kir {
18class Kernel;
19class Scope;
20} // namespace kir
21
22//! Define pretty printing functions for IR nodes
23//!
24//! This class is intended for debug printing, so it attempts
25//! to handle invalid states as well.
26//!
27class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
28 static constexpr char const* kTab = " ";
29
30 public:
31 explicit IrPrinter(std::ostream& os) : os_(os) {}
32
33 // Indent the generated code
34 std::ostream& indent() {
35 for (const auto i : c10::irange(indent_size_)) {
36 (void)i; // Suppress unused variable warning
37 os_ << " ";
38 }
39 return os_;
40 }
41
42 void resetIndent() {
43 indent_size_ = 0;
44 }
45
46 bool printInline() const {
47 return print_inline_;
48 }
49
50 using OptInConstDispatch::handle;
51
52 virtual void handle(Fusion* f);
53
54 // handle calls some non const fusion ops,
55 // eventhough fusion should remain unchanged.
56 // Need to look into this.
57 virtual void handle(const Fusion* f) {
58 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
59 handle(const_cast<Fusion*>(f));
60 }
61
62 virtual void handle(Fusion& f) {
63 handle(&f);
64 }
65
66 virtual void handle(const kir::Kernel* kernel);
67 virtual void handle(kir::Kernel& kernel);
68
69 void handleScope(const kir::Scope& scope);
70
71 void handle(const Statement* s) final;
72 void handle(const Val* v) final;
73 void handle(const Expr* e) final;
74
75 void handle(const IterDomain*) final;
76 void handle(const TensorDomain*) final;
77 void handle(const TensorView*) final;
78
79 void handle(const Bool*) final;
80 void handle(const Double*) final;
81 void handle(const Int*) final;
82 void handle(const ComplexDouble*) final;
83 void handle(const NamedScalar*) final;
84
85 void handle(const FullOp*) final;
86 void handle(const ARangeOp*) final;
87 void handle(const EyeOp*) final;
88 void handle(const UnaryOp*) final;
89 void handle(const BinaryOp*) final;
90 void handle(const TernaryOp*) final;
91 void handle(const RNGOp*) final;
92 void handle(const ReductionOp*) final;
93 void handle(const GroupedReductionOp*) final;
94 void handle(const WelfordOp*) final;
95 void handle(const GroupedWelfordOp*) final;
96 void handle(const LoadStoreOp*) final;
97 void handle(const MmaOp*) final;
98 void handle(const BroadcastOp*) final;
99 void handle(const TransposeOp*) final;
100 void handle(const ExpandOp*) final;
101 void handle(const ShiftOp*) final;
102 void handle(const GatherOp*) final;
103 void handle(const ViewAsScalar*) final;
104 void handle(const ViewOp*) final;
105
106 void handle(const kir::Predicate*) final;
107 void handle(const kir::TensorIndex*) final;
108 void handle(const kir::IntPair*) final;
109
110 void handle(const kir::GridBroadcast*) final;
111 void handle(const kir::GridReduction*) final;
112 void handle(const kir::GroupedGridReduction*) final;
113 void handle(const kir::GridWelford*) final;
114 void handle(const kir::GroupedGridWelford*) final;
115 void handle(const kir::ForLoop*) final;
116 void handle(const kir::IfThenElse*) final;
117 void handle(const kir::Allocate*) final;
118 void handle(const kir::BlockSync*) final;
119 void handle(const kir::GridSync*) final;
120 void handle(const kir::CpAsyncWait*) final;
121 void handle(const kir::CpAsyncCommit*) final;
122 void handle(const kir::InitMagicZero*) final;
123 void handle(const kir::UpdateMagicZero*) final;
124 void handle(const kir::AllocateFusedReduction*) final;
125 void handle(const kir::Swizzle2DInt*) final;
126 void handle(const kir::PairSelect*) final;
127
128 // IR math printer overrides these to prevent them from printing, keep
129 // override
130 void handle(const Split*) override;
131 void handle(const Merge*) override;
132 void handle(const Swizzle2D*) override;
133
134 void print_inline(const Statement* stmt) {
135 bool prev = print_inline_;
136 print_inline_ = true;
137 handle(stmt);
138 print_inline_ = prev;
139 }
140
141 protected:
142 std::ostream& os() {
143 return os_;
144 }
145
146 private:
147 std::ostream& os_;
148 bool print_inline_ = false;
149 int indent_size_ = 0;
150};
151
152TORCH_CUDA_CU_API std::ostream& operator<<(
153 std::ostream& os,
154 const Statement* stmt);
155
156TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion* f);
157TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion& f);
158
159} // namespace cuda
160} // namespace fuser
161} // namespace jit
162} // namespace torch
163