1#pragma once
2
3#include <c10/macros/Export.h>
4#include <dispatch.h>
5
6#include <sstream>
7#include <string>
8#include <unordered_map>
9#include <unordered_set>
10#include <vector>
11
12namespace torch {
13namespace jit {
14namespace fuser {
15namespace cuda {
16
17// Generates a DOT (https://www.graphviz.org) graph
18// representation of a fuser IR
19//
20// Usage:
21// 1) Add calls to IrGraphGenerator::print(), for example:
22// `IrGraphGenerator::print(&fusion, "ir.dot")`
23//
24// 2) Call IrGraphGenerator::print() from a debugger. Using gdb for example:
25// `call IrGraphGenerator::print(&fusion, "ir.dot",
26// IrGraphGenerator::DetailLevel::Explicit)`
27//
28// Notes:
29// - When called from the debugger, the detail_level must be
30// explicitly passed in (most debuggers don't support default arguments)
31//
32// - The output dot file path can't include shell specific notations,
33// for example you can't use "~/temp/ir.dot" ("/home/user/temp/ir.dot"
34// must be used instead)
35//
36class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch {
37 public:
38 enum class DetailLevel {
39 ComputeOnly, // Only dataflow (compute) nodes
40 Basic, // Compute + schedule, with minimal details (default)
41 Explicit, // Additional details (ex. symbolic names for scalar constants)
42 Verbose, // Includes all values and dead definitions
43 };
44
45 using ExprColorMap = std::unordered_map<const Expr*, size_t>;
46
47 public:
48 static void print(
49 const Fusion* fusion,
50 const char* filename,
51 DetailLevel detail_level = DetailLevel::Basic,
52 ExprColorMap* expr_color_map = nullptr);
53
54 static std::string toGraphviz(
55 const Fusion* fusion,
56 DetailLevel detail_level,
57 ExprColorMap* expr_color_map = nullptr);
58
59 private:
60 IrGraphGenerator(
61 const Fusion* fusion,
62 DetailLevel detail_level,
63 ExprColorMap* expr_color_map = nullptr);
64 ~IrGraphGenerator() override = default;
65
66 std::string generate();
67
68 void generateComputeGraph();
69 void generateScheduleGraph();
70
71 void handle(const Statement*) override;
72 void handle(const Val*) override;
73 void handle(const Expr*) override;
74
75 void handle(const TensorDomain*) override;
76 void handle(const TensorView*) override;
77 void handle(const IterDomain*) override;
78
79 void handle(const Bool*) override;
80 void handle(const Double*) override;
81 void handle(const Int*) override;
82 void handle(const ComplexDouble*) override;
83 void handle(const NamedScalar*) override;
84
85 void handle(const FullOp*) override;
86 void handle(const ARangeOp*) override;
87 void handle(const EyeOp*) override;
88 void handle(const UnaryOp*) override;
89 void handle(const BinaryOp*) override;
90 void handle(const TernaryOp*) override;
91 void handle(const RNGOp*) override;
92 void handle(const BroadcastOp*) override;
93 void handle(const ReductionOp*) override;
94
95 void handle(const Split*) override;
96 void handle(const Merge*) override;
97
98 // lookup the graph id, creating one if not found
99 std::string getid(const Statement* stm);
100
101 bool visited(const Statement* s) const {
102 return visited_.find(s) != visited_.end();
103 }
104
105 void addArc(
106 const Statement* src,
107 const Statement* dst,
108 const std::string& style = "");
109
110 void printExpr(const Expr* expr, const std::string& label);
111 void printValue(const Val* val, const std::string& label);
112
113 private:
114 const DetailLevel detail_level_;
115 const Fusion* const fusion_;
116 std::stringstream graph_def_;
117 std::unordered_map<const Statement*, std::string> id_map_;
118 std::unordered_set<const Statement*> visited_;
119 std::unordered_set<const Val*> inputs_;
120 std::unordered_set<const Val*> outputs_;
121 std::vector<const TensorView*> tensor_views_;
122 std::vector<std::string> arcs_;
123 int next_id_ = 1;
124 ExprColorMap* expr_color_map_ = nullptr;
125};
126
127} // namespace cuda
128} // namespace fuser
129} // namespace jit
130} // namespace torch
131