1#include <ir_graphviz.h>
2
3#include <fusion.h>
4#include <ir_all_nodes.h>
5#include <ir_builder.h>
6#include <type.h>
7
8#include <fstream>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15namespace {
16
17// Private helper, generating node labels for IrGraphGenerator
18class IrNodeLabel : private OptInConstDispatch {
19 using DetailLevel = IrGraphGenerator::DetailLevel;
20
21 public:
22 static std::string gen(
23 const Statement* node,
24 DetailLevel detail_level = DetailLevel::Basic) {
25 IrNodeLabel generator(detail_level);
26 generator.OptInConstDispatch::handle(node);
27 return generator.label_.str();
28 }
29
30 private:
31 explicit IrNodeLabel(DetailLevel detail_level)
32 : detail_level_(detail_level) {}
33
34 ~IrNodeLabel() override = default;
35
36 void handle(const Bool* b) override {
37 if (b->isSymbolic()) {
38 label_ << "b" << b->name();
39 } else {
40 if (detail_level_ >= DetailLevel::Explicit) {
41 label_ << "b" << b->name() << "=";
42 }
43 label_ << *b->value();
44 }
45 }
46
47 void handle(const Double* d) override {
48 if (d->isSymbolic()) {
49 label_ << "d" << d->name();
50 } else {
51 if (detail_level_ >= DetailLevel::Explicit) {
52 label_ << "d" << d->name() << "=";
53 }
54 label_ << *d->value();
55 }
56 }
57
58 void handle(const Int* i) override {
59 if (i->isSymbolic()) {
60 label_ << "i" << i->name();
61 } else {
62 if (detail_level_ >= DetailLevel::Explicit) {
63 label_ << "i" << i->name() << "=";
64 }
65 label_ << *i->value();
66 }
67 }
68
69 void handle(const NamedScalar* ns) override {
70 label_ << ns->name();
71 }
72
73 void handle(const IterDomain* id) override {
74 label_ << id->getIterType();
75 label_ << id->getParallelType();
76
77 label_ << "(";
78 if (!id->start()->isZeroInt()) {
79 label_ << IrNodeLabel::gen(id->start()) << " : ";
80 }
81 label_ << IrNodeLabel::gen(id->extent());
82 label_ << ")";
83 }
84
85 void handle(const Split* split) override {
86 label_ << "Split(inner=" << (split->innerSplit() ? "true" : "false")
87 << ", factor=" << IrNodeLabel::gen(split->factor()) << ")";
88 }
89
90 void handle(const Merge* merge) override {
91 label_ << "Merge";
92 }
93
94 private:
95 std::stringstream label_;
96 const DetailLevel detail_level_;
97};
98
// Small color palette from the X11 theme
//
// Maps an arbitrary index onto one of ten color names; indices wrap around
// modulo the palette size.
static const char* getColorFromIndex(size_t index) {
  static constexpr const char* kPalette[] = {
      "azure",
      "pink",
      "green",
      "grey",
      "yellow",
      "lavender",
      "cyan",
      "white",
      "magenta",
      "red"};
  constexpr size_t kPaletteSize = sizeof(kPalette) / sizeof(kPalette[0]);
  return kPalette[index % kPaletteSize];
}
129
130} // anonymous namespace
131
132void IrGraphGenerator::print(
133 const Fusion* fusion,
134 const char* filename,
135 DetailLevel detail_level,
136 ExprColorMap* expr_color_map) {
137 std::ofstream dot_file(filename);
138 TORCH_CHECK(dot_file.good(), "Failed to open the IR graph file");
139 dot_file << toGraphviz(fusion, detail_level, expr_color_map);
140}
141
142std::string IrGraphGenerator::toGraphviz(
143 const Fusion* fusion,
144 DetailLevel detail_level,
145 ExprColorMap* expr_color_map) {
146 IrGraphGenerator ir_graph(fusion, detail_level, expr_color_map);
147 return ir_graph.generate();
148}
149
150IrGraphGenerator::IrGraphGenerator(
151 const Fusion* fusion,
152 DetailLevel detail_level,
153 ExprColorMap* expr_color_map)
154 : detail_level_(detail_level),
155 fusion_(fusion),
156 expr_color_map_(expr_color_map) {
157 // setup inputs & outputs
158 // (indexes used to quickly check if a value is fusion input or output)
159 for (const auto* input : fusion->inputs()) {
160 TORCH_CHECK(inputs_.count(input) == 0);
161 inputs_.insert(input);
162 }
163 for (const auto* output : fusion->outputs()) {
164 TORCH_CHECK(outputs_.count(output) == 0);
165 outputs_.insert(output);
166 }
167}
168
169std::string IrGraphGenerator::getid(const Statement* stm) {
170 const auto it = id_map_.find(stm);
171 if (it == id_map_.end()) {
172 // First reference, generate a new id
173 std::stringstream new_id;
174 new_id << "stm_" << next_id_++;
175 id_map_.insert({stm, new_id.str()});
176 return new_id.str();
177 } else {
178 return it->second;
179 }
180}
181
182void IrGraphGenerator::addArc(
183 const Statement* src,
184 const Statement* dst,
185 const std::string& style) {
186 // We automatically visit (handle) the arc's source and destination
187 handle(src);
188 handle(dst);
189
190 // generate and queue the arc definition
191 std::stringstream arc_def;
192 arc_def << getid(src) << " -> " << getid(dst) << " " << style;
193 arcs_.push_back(arc_def.str());
194}
195
196void IrGraphGenerator::printExpr(const Expr* expr, const std::string& label) {
197 graph_def_ << " " << getid(expr) << " "
198 << "[label=\"" << label << "\", shape=oval, color=blue, "
199 << "style=filled, fillcolor=";
200 if (expr_color_map_ != nullptr && expr_color_map_->count(expr)) {
201 graph_def_ << getColorFromIndex(expr_color_map_->at(expr));
202 } else {
203 graph_def_ << "azure";
204 }
205 graph_def_ << "];\n";
206}
207
208void IrGraphGenerator::printValue(const Val* val, const std::string& label) {
209 graph_def_ << " " << getid(val) << " [label=\"" << label
210 << "\", shape=rect, color=green, fontsize=10];\n";
211}
212
213std::string IrGraphGenerator::generate() {
214 // IrGraphGenerator instances are not reusable
215 TORCH_CHECK(graph_def_.str().empty());
216 TORCH_CHECK(visited_.empty());
217
218 // record detail level
219 graph_def_ << "// detail level: ";
220 switch (detail_level_) {
221 case DetailLevel::ComputeOnly:
222 graph_def_ << "compute only\n";
223 break;
224 case DetailLevel::Basic:
225 graph_def_ << "minimal\n";
226 break;
227 case DetailLevel::Explicit:
228 graph_def_ << "explicit\n";
229 break;
230 case DetailLevel::Verbose:
231 graph_def_ << "verbose\n";
232 break;
233 default:
234 TORCH_CHECK(!"Unexpected detail level");
235 }
236
237 graph_def_ << "digraph fusion_ir {\n"
238 << " node [shape=circle, color=gray];\n"
239 << " edge [color=black];\n";
240
241 // Compute graph
242 generateComputeGraph();
243
244 // Schedule graph
245 if (detail_level_ > DetailLevel::ComputeOnly) {
246 generateScheduleGraph();
247 }
248
249 // All expressions & values
250 // (These are otherwise unreacheable (dead) nodes)
251 if (detail_level_ >= DetailLevel::Verbose) {
252 for (const auto* expr : fusion_->unordered_exprs()) {
253 handle(expr);
254 }
255 for (const auto* val : fusion_->vals()) {
256 handle(val);
257 }
258 }
259
260 // Finally, print all arc definitions
261 for (const auto& arc : arcs_) {
262 graph_def_ << " " << arc << ";\n";
263 }
264
265 graph_def_ << "}\n";
266
267 // Make sure that all referenced nodes have been visited
268 for (const auto& kv : id_map_) {
269 TORCH_CHECK(visited(kv.first));
270 }
271
272 return graph_def_.str();
273}
274
275void IrGraphGenerator::generateComputeGraph() {
276 graph_def_ << " subgraph cluster_compute {\n"
277 << " label=\"compute\";\n"
278 << " style=dashed;\n";
279
280 // Inputs
281 for (const auto* input : fusion_->inputs()) {
282 handle(input);
283 }
284
285 // Outputs
286 for (const auto* output : fusion_->outputs()) {
287 handle(output);
288 }
289
290 graph_def_ << " }\n";
291}
292
293void IrGraphGenerator::generateScheduleGraph() {
294 graph_def_ << " subgraph cluster_schedule {\n"
295 << " label=\"schedule\";\n"
296 << " style=dashed;\n";
297
298 // Connect TensorView with their TensorDomain
299 // (this will trigger the traversal of the schedule graph)
300
301 for (auto tv : tensor_views_) {
302 addArc(tv->domain(), tv, "[style=dashed, arrowhead=none]");
303 if (detail_level_ >= DetailLevel::Explicit) {
304 // Maybe not the best way to handle the root domain, but should be okay
305 addArc(
306 tv,
307 IrBuilder::create<TensorDomain>(tv->getRootDomain()),
308 "[style=dashed, color=green, arrowhead=none]");
309
310 if (tv->domain()->hasRFactor())
311 addArc(
312 tv,
313 IrBuilder::create<TensorDomain>(tv->domain()->getRFactorDomain()),
314 "[style=dashed, color=green, arrowhead=none]");
315 }
316 }
317
318 graph_def_ << " }\n";
319}
320
321void IrGraphGenerator::handle(const Statement* s) {
322 OptInConstDispatch::handle(s);
323}
324
325void IrGraphGenerator::handle(const Val* v) {
326 if (!visited(v)) {
327 visited_.insert(v);
328 if (const auto* def = v->definition()) {
329 handle(def);
330 }
331 OptInConstDispatch::handle(v);
332 }
333}
334
335void IrGraphGenerator::handle(const Expr* e) {
336 if (!visited(e)) {
337 visited_.insert(e);
338 OptInConstDispatch::handle(e);
339 }
340}
341
342void IrGraphGenerator::handle(const TensorDomain* td) {
343 graph_def_ << " " << getid(td) << " [label=\"TensorDomain\", "
344 << "shape=note, color=gray, "
345 << "style=filled, fillcolor=gray90, fontsize=10];\n";
346 for (auto iter_domain : td->domain()) {
347 addArc(iter_domain, td, "[color=gray]");
348 }
349}
350
351void IrGraphGenerator::handle(const IterDomain* id) {
352 graph_def_ << " " << getid(id) << " [label=\"" << IrNodeLabel::gen(id)
353 << "\", shape=cds, color=gray, fontsize=10];\n";
354
355 if (!id->start()->isZeroInt()) {
356 addArc(id->start(), id, "[color=gray]");
357 }
358
359 addArc(id->extent(), id, "[color=gray]");
360}
361
362void IrGraphGenerator::handle(const Bool* b) {
363 printValue(b, IrNodeLabel::gen(b, detail_level_));
364}
365
366void IrGraphGenerator::handle(const Double* d) {
367 printValue(d, IrNodeLabel::gen(d, detail_level_));
368}
369
370void IrGraphGenerator::handle(const Int* i) {
371 printValue(i, IrNodeLabel::gen(i, detail_level_));
372}
373
374void IrGraphGenerator::handle(const ComplexDouble* i) {
375 printValue(i, IrNodeLabel::gen(i, detail_level_));
376}
377
378void IrGraphGenerator::handle(const NamedScalar* i) {
379 printValue(i, IrNodeLabel::gen(i, detail_level_));
380}
381
382void IrGraphGenerator::handle(const TensorView* tv) {
383 std::stringstream label;
384 label << "{T" << tv->name() << "|";
385 label << "{";
386 bool first_axis = true;
387 for (auto iter_domain : tv->domain()->domain()) {
388 if (first_axis) {
389 first_axis = false;
390 } else {
391 label << "|";
392 }
393 label << IrNodeLabel::gen(iter_domain);
394 }
395 label << "}}";
396
397 const bool is_input = inputs_.find(tv) != inputs_.end();
398 const bool is_output = outputs_.find(tv) != outputs_.end();
399
400 const char* style = is_input ? "style=filled, fillcolor=palegreen"
401 : is_output ? "style=filled, fillcolor=lightblue"
402 : "style=filled, fillcolor=beige";
403
404 graph_def_ << " " << getid(tv) << " [label=\"" << label.str()
405 << "\", shape=Mrecord, color=brown, " << style << "];\n";
406
407 tensor_views_.push_back(tv);
408}
409
410void IrGraphGenerator::handle(const FullOp* fop) {
411 // node
412 printExpr(fop, "full");
413
414 // inputs & outputs
415 addArc(fop->getFillValue(), fop);
416 addArc(fop, fop->output(0));
417}
418
419void IrGraphGenerator::handle(const ARangeOp* aop) {
420 // node
421 printExpr(aop, "arange");
422
423 // inputs & outputs
424 addArc(aop->start(), aop);
425 addArc(aop->end(), aop);
426 addArc(aop->step(), aop);
427 addArc(aop, aop->output(0));
428}
429
430void IrGraphGenerator::handle(const EyeOp* eop) {
431 // node
432 printExpr(eop, "eye");
433
434 // inputs & outputs
435 addArc(eop, eop->output(0));
436}
437
438void IrGraphGenerator::handle(const UnaryOp* uop) {
439 // node
440 std::stringstream label;
441 label << uop->getUnaryOpType();
442 printExpr(uop, label.str());
443
444 // inputs & outputs
445 addArc(uop->in(), uop);
446 addArc(uop, uop->out());
447}
448
449void IrGraphGenerator::handle(const BinaryOp* bop) {
450 // node
451 std::stringstream label;
452 label << bop->getBinaryOpType();
453 printExpr(bop, label.str());
454
455 // inputs & outputs
456 addArc(bop->lhs(), bop);
457 addArc(bop->rhs(), bop, "[color=blue]");
458 addArc(bop, bop->out());
459}
460
461void IrGraphGenerator::handle(const TernaryOp* op) {
462 // node
463 std::stringstream label;
464 label << op->getTernaryOpType();
465 printExpr(op, label.str());
466
467 // inputs & outputs
468 addArc(op->in1(), op);
469 addArc(op->in2(), op, "[color=blue]");
470 addArc(op->in3(), op, "[color=brown]");
471 addArc(op, op->out());
472}
473
474void IrGraphGenerator::handle(const RNGOp* op) {
475 // node
476 std::stringstream label;
477 label << op->getRNGOpType();
478 printExpr(op, label.str());
479
480 // inputs & outputs
481 addArc(op, op->output(0));
482}
483
484void IrGraphGenerator::handle(const BroadcastOp* op) {
485 printExpr(op, "Broadcast");
486 addArc(op->in(), op);
487 addArc(op, op->out());
488}
489
490void IrGraphGenerator::handle(const ReductionOp* op) {
491 // node
492 std::stringstream label;
493 label << "Reduction(" << op->getReductionOpType() << ")";
494 printExpr(op, label.str());
495
496 // inputs & outputs
497 addArc(op->in(), op);
498 addArc(op->init(), op, "[color=blue]");
499 addArc(op, op->out());
500}
501
502void IrGraphGenerator::handle(const Split* split) {
503 printExpr(split, IrNodeLabel::gen(split));
504 addArc(split->in(), split);
505 addArc(split, split->outer());
506 addArc(split, split->inner());
507}
508
509void IrGraphGenerator::handle(const Merge* merge) {
510 printExpr(merge, IrNodeLabel::gen(merge));
511 addArc(merge->outer(), merge);
512 addArc(merge->inner(), merge);
513 addArc(merge, merge->out());
514}
515
516} // namespace cuda
517} // namespace fuser
518} // namespace jit
519} // namespace torch
520