1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | #include <dispatch.h> |
5 | |
6 | #include <sstream> |
7 | #include <string> |
8 | #include <unordered_map> |
9 | #include <unordered_set> |
10 | #include <vector> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | // Generates a DOT (https://www.graphviz.org) graph |
18 | // representation of a fuser IR |
19 | // |
20 | // Usage: |
21 | // 1) Add calls to IrGraphGenerator::print(), for example: |
22 | // `IrGraphGenerator::print(&fusion, "ir.dot")` |
23 | // |
24 | // 2) Call IrGraphGenerator::print() from a debugger. Using gdb for example: |
25 | // `call IrGraphGenerator::print(&fusion, "ir.dot", |
26 | // IrGraphGenerator::DetailLevel::Explicit)` |
27 | // |
28 | // Notes: |
29 | // - When called from the debugger, the detail_level must be |
30 | // explicitly passed in (most debuggers don't support default arguments) |
31 | // |
32 | // - The output dot file path can't include shell specific notations, |
33 | // for example you can't use "~/temp/ir.dot" ("/home/user/temp/ir.dot" |
34 | // must be used instead) |
35 | // |
36 | class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch { |
37 | public: |
38 | enum class DetailLevel { |
39 | ComputeOnly, // Only dataflow (compute) nodes |
40 | Basic, // Compute + schedule, with minimal details (default) |
41 | Explicit, // Additional details (ex. symbolic names for scalar constants) |
42 | Verbose, // Includes all values and dead definitions |
43 | }; |
44 | |
45 | using ExprColorMap = std::unordered_map<const Expr*, size_t>; |
46 | |
47 | public: |
48 | static void print( |
49 | const Fusion* fusion, |
50 | const char* filename, |
51 | DetailLevel detail_level = DetailLevel::Basic, |
52 | ExprColorMap* expr_color_map = nullptr); |
53 | |
54 | static std::string toGraphviz( |
55 | const Fusion* fusion, |
56 | DetailLevel detail_level, |
57 | ExprColorMap* expr_color_map = nullptr); |
58 | |
59 | private: |
60 | IrGraphGenerator( |
61 | const Fusion* fusion, |
62 | DetailLevel detail_level, |
63 | ExprColorMap* expr_color_map = nullptr); |
64 | ~IrGraphGenerator() override = default; |
65 | |
66 | std::string generate(); |
67 | |
68 | void generateComputeGraph(); |
69 | void generateScheduleGraph(); |
70 | |
71 | void handle(const Statement*) override; |
72 | void handle(const Val*) override; |
73 | void handle(const Expr*) override; |
74 | |
75 | void handle(const TensorDomain*) override; |
76 | void handle(const TensorView*) override; |
77 | void handle(const IterDomain*) override; |
78 | |
79 | void handle(const Bool*) override; |
80 | void handle(const Double*) override; |
81 | void handle(const Int*) override; |
82 | void handle(const ComplexDouble*) override; |
83 | void handle(const NamedScalar*) override; |
84 | |
85 | void handle(const FullOp*) override; |
86 | void handle(const ARangeOp*) override; |
87 | void handle(const EyeOp*) override; |
88 | void handle(const UnaryOp*) override; |
89 | void handle(const BinaryOp*) override; |
90 | void handle(const TernaryOp*) override; |
91 | void handle(const RNGOp*) override; |
92 | void handle(const BroadcastOp*) override; |
93 | void handle(const ReductionOp*) override; |
94 | |
95 | void handle(const Split*) override; |
96 | void handle(const Merge*) override; |
97 | |
98 | // lookup the graph id, creating one if not found |
99 | std::string getid(const Statement* stm); |
100 | |
101 | bool visited(const Statement* s) const { |
102 | return visited_.find(s) != visited_.end(); |
103 | } |
104 | |
105 | void addArc( |
106 | const Statement* src, |
107 | const Statement* dst, |
108 | const std::string& style = "" ); |
109 | |
110 | void printExpr(const Expr* expr, const std::string& label); |
111 | void printValue(const Val* val, const std::string& label); |
112 | |
113 | private: |
114 | const DetailLevel detail_level_; |
115 | const Fusion* const fusion_; |
116 | std::stringstream graph_def_; |
117 | std::unordered_map<const Statement*, std::string> id_map_; |
118 | std::unordered_set<const Statement*> visited_; |
119 | std::unordered_set<const Val*> inputs_; |
120 | std::unordered_set<const Val*> outputs_; |
121 | std::vector<const TensorView*> tensor_views_; |
122 | std::vector<std::string> arcs_; |
123 | int next_id_ = 1; |
124 | ExprColorMap* expr_color_map_ = nullptr; |
125 | }; |
126 | |
127 | } // namespace cuda |
128 | } // namespace fuser |
129 | } // namespace jit |
130 | } // namespace torch |
131 | |