#include <ir_graphviz.h>

#include <fusion.h>
#include <ir_all_nodes.h>
#include <ir_builder.h>
#include <type.h>

#include <fstream>
#include <sstream>
#include <string>
#include <vector>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | namespace { |
16 | |
17 | // Private helper, generating node labels for IrGraphGenerator |
18 | class IrNodeLabel : private OptInConstDispatch { |
19 | using DetailLevel = IrGraphGenerator::DetailLevel; |
20 | |
21 | public: |
22 | static std::string gen( |
23 | const Statement* node, |
24 | DetailLevel detail_level = DetailLevel::Basic) { |
25 | IrNodeLabel generator(detail_level); |
26 | generator.OptInConstDispatch::handle(node); |
27 | return generator.label_.str(); |
28 | } |
29 | |
30 | private: |
31 | explicit IrNodeLabel(DetailLevel detail_level) |
32 | : detail_level_(detail_level) {} |
33 | |
34 | ~IrNodeLabel() override = default; |
35 | |
36 | void handle(const Bool* b) override { |
37 | if (b->isSymbolic()) { |
38 | label_ << "b" << b->name(); |
39 | } else { |
40 | if (detail_level_ >= DetailLevel::Explicit) { |
41 | label_ << "b" << b->name() << "=" ; |
42 | } |
43 | label_ << *b->value(); |
44 | } |
45 | } |
46 | |
47 | void handle(const Double* d) override { |
48 | if (d->isSymbolic()) { |
49 | label_ << "d" << d->name(); |
50 | } else { |
51 | if (detail_level_ >= DetailLevel::Explicit) { |
52 | label_ << "d" << d->name() << "=" ; |
53 | } |
54 | label_ << *d->value(); |
55 | } |
56 | } |
57 | |
58 | void handle(const Int* i) override { |
59 | if (i->isSymbolic()) { |
60 | label_ << "i" << i->name(); |
61 | } else { |
62 | if (detail_level_ >= DetailLevel::Explicit) { |
63 | label_ << "i" << i->name() << "=" ; |
64 | } |
65 | label_ << *i->value(); |
66 | } |
67 | } |
68 | |
69 | void handle(const NamedScalar* ns) override { |
70 | label_ << ns->name(); |
71 | } |
72 | |
73 | void handle(const IterDomain* id) override { |
74 | label_ << id->getIterType(); |
75 | label_ << id->getParallelType(); |
76 | |
77 | label_ << "(" ; |
78 | if (!id->start()->isZeroInt()) { |
79 | label_ << IrNodeLabel::gen(id->start()) << " : " ; |
80 | } |
81 | label_ << IrNodeLabel::gen(id->extent()); |
82 | label_ << ")" ; |
83 | } |
84 | |
85 | void handle(const Split* split) override { |
86 | label_ << "Split(inner=" << (split->innerSplit() ? "true" : "false" ) |
87 | << ", factor=" << IrNodeLabel::gen(split->factor()) << ")" ; |
88 | } |
89 | |
90 | void handle(const Merge* merge) override { |
91 | label_ << "Merge" ; |
92 | } |
93 | |
94 | private: |
95 | std::stringstream label_; |
96 | const DetailLevel detail_level_; |
97 | }; |
98 | |
// Maps an arbitrary index onto a small X11 color palette. Indices wrap around
// modulo the palette size, so every index yields a valid color name.
static const char* getColorFromIndex(size_t index) {
  static const char* const kPalette[] = {
      "azure",
      "pink",
      "green",
      "grey",
      "yellow",
      "lavender",
      "cyan",
      "white",
      "magenta",
      "red",
  };
  const size_t palette_size = sizeof(kPalette) / sizeof(kPalette[0]);
  return kPalette[index % palette_size];
}
129 | |
130 | } // anonymous namespace |
131 | |
132 | void IrGraphGenerator::print( |
133 | const Fusion* fusion, |
134 | const char* filename, |
135 | DetailLevel detail_level, |
136 | ExprColorMap* expr_color_map) { |
137 | std::ofstream dot_file(filename); |
138 | TORCH_CHECK(dot_file.good(), "Failed to open the IR graph file" ); |
139 | dot_file << toGraphviz(fusion, detail_level, expr_color_map); |
140 | } |
141 | |
142 | std::string IrGraphGenerator::toGraphviz( |
143 | const Fusion* fusion, |
144 | DetailLevel detail_level, |
145 | ExprColorMap* expr_color_map) { |
146 | IrGraphGenerator ir_graph(fusion, detail_level, expr_color_map); |
147 | return ir_graph.generate(); |
148 | } |
149 | |
150 | IrGraphGenerator::IrGraphGenerator( |
151 | const Fusion* fusion, |
152 | DetailLevel detail_level, |
153 | ExprColorMap* expr_color_map) |
154 | : detail_level_(detail_level), |
155 | fusion_(fusion), |
156 | expr_color_map_(expr_color_map) { |
157 | // setup inputs & outputs |
158 | // (indexes used to quickly check if a value is fusion input or output) |
159 | for (const auto* input : fusion->inputs()) { |
160 | TORCH_CHECK(inputs_.count(input) == 0); |
161 | inputs_.insert(input); |
162 | } |
163 | for (const auto* output : fusion->outputs()) { |
164 | TORCH_CHECK(outputs_.count(output) == 0); |
165 | outputs_.insert(output); |
166 | } |
167 | } |
168 | |
169 | std::string IrGraphGenerator::getid(const Statement* stm) { |
170 | const auto it = id_map_.find(stm); |
171 | if (it == id_map_.end()) { |
172 | // First reference, generate a new id |
173 | std::stringstream new_id; |
174 | new_id << "stm_" << next_id_++; |
175 | id_map_.insert({stm, new_id.str()}); |
176 | return new_id.str(); |
177 | } else { |
178 | return it->second; |
179 | } |
180 | } |
181 | |
182 | void IrGraphGenerator::addArc( |
183 | const Statement* src, |
184 | const Statement* dst, |
185 | const std::string& style) { |
186 | // We automatically visit (handle) the arc's source and destination |
187 | handle(src); |
188 | handle(dst); |
189 | |
190 | // generate and queue the arc definition |
191 | std::stringstream arc_def; |
192 | arc_def << getid(src) << " -> " << getid(dst) << " " << style; |
193 | arcs_.push_back(arc_def.str()); |
194 | } |
195 | |
196 | void IrGraphGenerator::printExpr(const Expr* expr, const std::string& label) { |
197 | graph_def_ << " " << getid(expr) << " " |
198 | << "[label=\"" << label << "\", shape=oval, color=blue, " |
199 | << "style=filled, fillcolor=" ; |
200 | if (expr_color_map_ != nullptr && expr_color_map_->count(expr)) { |
201 | graph_def_ << getColorFromIndex(expr_color_map_->at(expr)); |
202 | } else { |
203 | graph_def_ << "azure" ; |
204 | } |
205 | graph_def_ << "];\n" ; |
206 | } |
207 | |
208 | void IrGraphGenerator::printValue(const Val* val, const std::string& label) { |
209 | graph_def_ << " " << getid(val) << " [label=\"" << label |
210 | << "\", shape=rect, color=green, fontsize=10];\n" ; |
211 | } |
212 | |
213 | std::string IrGraphGenerator::generate() { |
214 | // IrGraphGenerator instances are not reusable |
215 | TORCH_CHECK(graph_def_.str().empty()); |
216 | TORCH_CHECK(visited_.empty()); |
217 | |
218 | // record detail level |
219 | graph_def_ << "// detail level: " ; |
220 | switch (detail_level_) { |
221 | case DetailLevel::ComputeOnly: |
222 | graph_def_ << "compute only\n" ; |
223 | break; |
224 | case DetailLevel::Basic: |
225 | graph_def_ << "minimal\n" ; |
226 | break; |
227 | case DetailLevel::Explicit: |
228 | graph_def_ << "explicit\n" ; |
229 | break; |
230 | case DetailLevel::Verbose: |
231 | graph_def_ << "verbose\n" ; |
232 | break; |
233 | default: |
234 | TORCH_CHECK(!"Unexpected detail level" ); |
235 | } |
236 | |
237 | graph_def_ << "digraph fusion_ir {\n" |
238 | << " node [shape=circle, color=gray];\n" |
239 | << " edge [color=black];\n" ; |
240 | |
241 | // Compute graph |
242 | generateComputeGraph(); |
243 | |
244 | // Schedule graph |
245 | if (detail_level_ > DetailLevel::ComputeOnly) { |
246 | generateScheduleGraph(); |
247 | } |
248 | |
249 | // All expressions & values |
250 | // (These are otherwise unreacheable (dead) nodes) |
251 | if (detail_level_ >= DetailLevel::Verbose) { |
252 | for (const auto* expr : fusion_->unordered_exprs()) { |
253 | handle(expr); |
254 | } |
255 | for (const auto* val : fusion_->vals()) { |
256 | handle(val); |
257 | } |
258 | } |
259 | |
260 | // Finally, print all arc definitions |
261 | for (const auto& arc : arcs_) { |
262 | graph_def_ << " " << arc << ";\n" ; |
263 | } |
264 | |
265 | graph_def_ << "}\n" ; |
266 | |
267 | // Make sure that all referenced nodes have been visited |
268 | for (const auto& kv : id_map_) { |
269 | TORCH_CHECK(visited(kv.first)); |
270 | } |
271 | |
272 | return graph_def_.str(); |
273 | } |
274 | |
275 | void IrGraphGenerator::generateComputeGraph() { |
276 | graph_def_ << " subgraph cluster_compute {\n" |
277 | << " label=\"compute\";\n" |
278 | << " style=dashed;\n" ; |
279 | |
280 | // Inputs |
281 | for (const auto* input : fusion_->inputs()) { |
282 | handle(input); |
283 | } |
284 | |
285 | // Outputs |
286 | for (const auto* output : fusion_->outputs()) { |
287 | handle(output); |
288 | } |
289 | |
290 | graph_def_ << " }\n" ; |
291 | } |
292 | |
293 | void IrGraphGenerator::generateScheduleGraph() { |
294 | graph_def_ << " subgraph cluster_schedule {\n" |
295 | << " label=\"schedule\";\n" |
296 | << " style=dashed;\n" ; |
297 | |
298 | // Connect TensorView with their TensorDomain |
299 | // (this will trigger the traversal of the schedule graph) |
300 | |
301 | for (auto tv : tensor_views_) { |
302 | addArc(tv->domain(), tv, "[style=dashed, arrowhead=none]" ); |
303 | if (detail_level_ >= DetailLevel::Explicit) { |
304 | // Maybe not the best way to handle the root domain, but should be okay |
305 | addArc( |
306 | tv, |
307 | IrBuilder::create<TensorDomain>(tv->getRootDomain()), |
308 | "[style=dashed, color=green, arrowhead=none]" ); |
309 | |
310 | if (tv->domain()->hasRFactor()) |
311 | addArc( |
312 | tv, |
313 | IrBuilder::create<TensorDomain>(tv->domain()->getRFactorDomain()), |
314 | "[style=dashed, color=green, arrowhead=none]" ); |
315 | } |
316 | } |
317 | |
318 | graph_def_ << " }\n" ; |
319 | } |
320 | |
321 | void IrGraphGenerator::handle(const Statement* s) { |
322 | OptInConstDispatch::handle(s); |
323 | } |
324 | |
325 | void IrGraphGenerator::handle(const Val* v) { |
326 | if (!visited(v)) { |
327 | visited_.insert(v); |
328 | if (const auto* def = v->definition()) { |
329 | handle(def); |
330 | } |
331 | OptInConstDispatch::handle(v); |
332 | } |
333 | } |
334 | |
335 | void IrGraphGenerator::handle(const Expr* e) { |
336 | if (!visited(e)) { |
337 | visited_.insert(e); |
338 | OptInConstDispatch::handle(e); |
339 | } |
340 | } |
341 | |
342 | void IrGraphGenerator::handle(const TensorDomain* td) { |
343 | graph_def_ << " " << getid(td) << " [label=\"TensorDomain\", " |
344 | << "shape=note, color=gray, " |
345 | << "style=filled, fillcolor=gray90, fontsize=10];\n" ; |
346 | for (auto iter_domain : td->domain()) { |
347 | addArc(iter_domain, td, "[color=gray]" ); |
348 | } |
349 | } |
350 | |
351 | void IrGraphGenerator::handle(const IterDomain* id) { |
352 | graph_def_ << " " << getid(id) << " [label=\"" << IrNodeLabel::gen(id) |
353 | << "\", shape=cds, color=gray, fontsize=10];\n" ; |
354 | |
355 | if (!id->start()->isZeroInt()) { |
356 | addArc(id->start(), id, "[color=gray]" ); |
357 | } |
358 | |
359 | addArc(id->extent(), id, "[color=gray]" ); |
360 | } |
361 | |
362 | void IrGraphGenerator::handle(const Bool* b) { |
363 | printValue(b, IrNodeLabel::gen(b, detail_level_)); |
364 | } |
365 | |
366 | void IrGraphGenerator::handle(const Double* d) { |
367 | printValue(d, IrNodeLabel::gen(d, detail_level_)); |
368 | } |
369 | |
370 | void IrGraphGenerator::handle(const Int* i) { |
371 | printValue(i, IrNodeLabel::gen(i, detail_level_)); |
372 | } |
373 | |
374 | void IrGraphGenerator::handle(const ComplexDouble* i) { |
375 | printValue(i, IrNodeLabel::gen(i, detail_level_)); |
376 | } |
377 | |
378 | void IrGraphGenerator::handle(const NamedScalar* i) { |
379 | printValue(i, IrNodeLabel::gen(i, detail_level_)); |
380 | } |
381 | |
382 | void IrGraphGenerator::handle(const TensorView* tv) { |
383 | std::stringstream label; |
384 | label << "{T" << tv->name() << "|" ; |
385 | label << "{" ; |
386 | bool first_axis = true; |
387 | for (auto iter_domain : tv->domain()->domain()) { |
388 | if (first_axis) { |
389 | first_axis = false; |
390 | } else { |
391 | label << "|" ; |
392 | } |
393 | label << IrNodeLabel::gen(iter_domain); |
394 | } |
395 | label << "}}" ; |
396 | |
397 | const bool is_input = inputs_.find(tv) != inputs_.end(); |
398 | const bool is_output = outputs_.find(tv) != outputs_.end(); |
399 | |
400 | const char* style = is_input ? "style=filled, fillcolor=palegreen" |
401 | : is_output ? "style=filled, fillcolor=lightblue" |
402 | : "style=filled, fillcolor=beige" ; |
403 | |
404 | graph_def_ << " " << getid(tv) << " [label=\"" << label.str() |
405 | << "\", shape=Mrecord, color=brown, " << style << "];\n" ; |
406 | |
407 | tensor_views_.push_back(tv); |
408 | } |
409 | |
410 | void IrGraphGenerator::handle(const FullOp* fop) { |
411 | // node |
412 | printExpr(fop, "full" ); |
413 | |
414 | // inputs & outputs |
415 | addArc(fop->getFillValue(), fop); |
416 | addArc(fop, fop->output(0)); |
417 | } |
418 | |
419 | void IrGraphGenerator::handle(const ARangeOp* aop) { |
420 | // node |
421 | printExpr(aop, "arange" ); |
422 | |
423 | // inputs & outputs |
424 | addArc(aop->start(), aop); |
425 | addArc(aop->end(), aop); |
426 | addArc(aop->step(), aop); |
427 | addArc(aop, aop->output(0)); |
428 | } |
429 | |
430 | void IrGraphGenerator::handle(const EyeOp* eop) { |
431 | // node |
432 | printExpr(eop, "eye" ); |
433 | |
434 | // inputs & outputs |
435 | addArc(eop, eop->output(0)); |
436 | } |
437 | |
438 | void IrGraphGenerator::handle(const UnaryOp* uop) { |
439 | // node |
440 | std::stringstream label; |
441 | label << uop->getUnaryOpType(); |
442 | printExpr(uop, label.str()); |
443 | |
444 | // inputs & outputs |
445 | addArc(uop->in(), uop); |
446 | addArc(uop, uop->out()); |
447 | } |
448 | |
449 | void IrGraphGenerator::handle(const BinaryOp* bop) { |
450 | // node |
451 | std::stringstream label; |
452 | label << bop->getBinaryOpType(); |
453 | printExpr(bop, label.str()); |
454 | |
455 | // inputs & outputs |
456 | addArc(bop->lhs(), bop); |
457 | addArc(bop->rhs(), bop, "[color=blue]" ); |
458 | addArc(bop, bop->out()); |
459 | } |
460 | |
461 | void IrGraphGenerator::handle(const TernaryOp* op) { |
462 | // node |
463 | std::stringstream label; |
464 | label << op->getTernaryOpType(); |
465 | printExpr(op, label.str()); |
466 | |
467 | // inputs & outputs |
468 | addArc(op->in1(), op); |
469 | addArc(op->in2(), op, "[color=blue]" ); |
470 | addArc(op->in3(), op, "[color=brown]" ); |
471 | addArc(op, op->out()); |
472 | } |
473 | |
474 | void IrGraphGenerator::handle(const RNGOp* op) { |
475 | // node |
476 | std::stringstream label; |
477 | label << op->getRNGOpType(); |
478 | printExpr(op, label.str()); |
479 | |
480 | // inputs & outputs |
481 | addArc(op, op->output(0)); |
482 | } |
483 | |
484 | void IrGraphGenerator::handle(const BroadcastOp* op) { |
485 | printExpr(op, "Broadcast" ); |
486 | addArc(op->in(), op); |
487 | addArc(op, op->out()); |
488 | } |
489 | |
490 | void IrGraphGenerator::handle(const ReductionOp* op) { |
491 | // node |
492 | std::stringstream label; |
493 | label << "Reduction(" << op->getReductionOpType() << ")" ; |
494 | printExpr(op, label.str()); |
495 | |
496 | // inputs & outputs |
497 | addArc(op->in(), op); |
498 | addArc(op->init(), op, "[color=blue]" ); |
499 | addArc(op, op->out()); |
500 | } |
501 | |
502 | void IrGraphGenerator::handle(const Split* split) { |
503 | printExpr(split, IrNodeLabel::gen(split)); |
504 | addArc(split->in(), split); |
505 | addArc(split, split->outer()); |
506 | addArc(split, split->inner()); |
507 | } |
508 | |
509 | void IrGraphGenerator::handle(const Merge* merge) { |
510 | printExpr(merge, IrNodeLabel::gen(merge)); |
511 | addArc(merge->outer(), merge); |
512 | addArc(merge->inner(), merge); |
513 | addArc(merge, merge->out()); |
514 | } |
515 | |
516 | } // namespace cuda |
517 | } // namespace fuser |
518 | } // namespace jit |
519 | } // namespace torch |
520 | |