1 | #pragma once |
---|---|
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <dispatch.h> |
6 | |
7 | #include <c10/util/irange.h> |
8 | |
9 | #include <iostream> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | namespace fuser { |
14 | namespace cuda { |
15 | |
// Forward declarations. Within this header these types are only used as
// pointer or reference parameters, so their full definitions are not needed.
class Fusion;
namespace kir {
class Kernel;
class Scope;
} // namespace kir
21 | |
22 | //! Define pretty printing functions for IR nodes |
23 | //! |
24 | //! This class is intended for debug printing, so it attempts |
25 | //! to handle invalid states as well. |
26 | //! |
27 | class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { |
28 | static constexpr char const* kTab = " "; |
29 | |
30 | public: |
31 | explicit IrPrinter(std::ostream& os) : os_(os) {} |
32 | |
33 | // Indent the generated code |
34 | std::ostream& indent() { |
35 | for (const auto i : c10::irange(indent_size_)) { |
36 | (void)i; // Suppress unused variable warning |
37 | os_ << " "; |
38 | } |
39 | return os_; |
40 | } |
41 | |
42 | void resetIndent() { |
43 | indent_size_ = 0; |
44 | } |
45 | |
46 | bool printInline() const { |
47 | return print_inline_; |
48 | } |
49 | |
50 | using OptInConstDispatch::handle; |
51 | |
52 | virtual void handle(Fusion* f); |
53 | |
54 | // handle calls some non const fusion ops, |
55 | // eventhough fusion should remain unchanged. |
56 | // Need to look into this. |
57 | virtual void handle(const Fusion* f) { |
58 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
59 | handle(const_cast<Fusion*>(f)); |
60 | } |
61 | |
62 | virtual void handle(Fusion& f) { |
63 | handle(&f); |
64 | } |
65 | |
66 | virtual void handle(const kir::Kernel* kernel); |
67 | virtual void handle(kir::Kernel& kernel); |
68 | |
69 | void handleScope(const kir::Scope& scope); |
70 | |
71 | void handle(const Statement* s) final; |
72 | void handle(const Val* v) final; |
73 | void handle(const Expr* e) final; |
74 | |
75 | void handle(const IterDomain*) final; |
76 | void handle(const TensorDomain*) final; |
77 | void handle(const TensorView*) final; |
78 | |
79 | void handle(const Bool*) final; |
80 | void handle(const Double*) final; |
81 | void handle(const Int*) final; |
82 | void handle(const ComplexDouble*) final; |
83 | void handle(const NamedScalar*) final; |
84 | |
85 | void handle(const FullOp*) final; |
86 | void handle(const ARangeOp*) final; |
87 | void handle(const EyeOp*) final; |
88 | void handle(const UnaryOp*) final; |
89 | void handle(const BinaryOp*) final; |
90 | void handle(const TernaryOp*) final; |
91 | void handle(const RNGOp*) final; |
92 | void handle(const ReductionOp*) final; |
93 | void handle(const GroupedReductionOp*) final; |
94 | void handle(const WelfordOp*) final; |
95 | void handle(const GroupedWelfordOp*) final; |
96 | void handle(const LoadStoreOp*) final; |
97 | void handle(const MmaOp*) final; |
98 | void handle(const BroadcastOp*) final; |
99 | void handle(const TransposeOp*) final; |
100 | void handle(const ExpandOp*) final; |
101 | void handle(const ShiftOp*) final; |
102 | void handle(const GatherOp*) final; |
103 | void handle(const ViewAsScalar*) final; |
104 | void handle(const ViewOp*) final; |
105 | |
106 | void handle(const kir::Predicate*) final; |
107 | void handle(const kir::TensorIndex*) final; |
108 | void handle(const kir::IntPair*) final; |
109 | |
110 | void handle(const kir::GridBroadcast*) final; |
111 | void handle(const kir::GridReduction*) final; |
112 | void handle(const kir::GroupedGridReduction*) final; |
113 | void handle(const kir::GridWelford*) final; |
114 | void handle(const kir::GroupedGridWelford*) final; |
115 | void handle(const kir::ForLoop*) final; |
116 | void handle(const kir::IfThenElse*) final; |
117 | void handle(const kir::Allocate*) final; |
118 | void handle(const kir::BlockSync*) final; |
119 | void handle(const kir::GridSync*) final; |
120 | void handle(const kir::CpAsyncWait*) final; |
121 | void handle(const kir::CpAsyncCommit*) final; |
122 | void handle(const kir::InitMagicZero*) final; |
123 | void handle(const kir::UpdateMagicZero*) final; |
124 | void handle(const kir::AllocateFusedReduction*) final; |
125 | void handle(const kir::Swizzle2DInt*) final; |
126 | void handle(const kir::PairSelect*) final; |
127 | |
128 | // IR math printer overrides these to prevent them from printing, keep |
129 | // override |
130 | void handle(const Split*) override; |
131 | void handle(const Merge*) override; |
132 | void handle(const Swizzle2D*) override; |
133 | |
134 | void print_inline(const Statement* stmt) { |
135 | bool prev = print_inline_; |
136 | print_inline_ = true; |
137 | handle(stmt); |
138 | print_inline_ = prev; |
139 | } |
140 | |
141 | protected: |
142 | std::ostream& os() { |
143 | return os_; |
144 | } |
145 | |
146 | private: |
147 | std::ostream& os_; |
148 | bool print_inline_ = false; |
149 | int indent_size_ = 0; |
150 | }; |
151 | |
// Stream-insertion overloads for IR statements and fusions. Only the
// declarations are visible in this header; presumably the definitions print
// through IrPrinter -- TODO confirm against the implementation file.
TORCH_CUDA_CU_API std::ostream& operator<<(
    std::ostream& os,
    const Statement* stmt);

TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion* f);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream& os, Fusion& f);
158 | |
159 | } // namespace cuda |
160 | } // namespace fuser |
161 | } // namespace jit |
162 | } // namespace torch |
163 |