1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/dump_graphviz.h"
16
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
23
24#include "absl/memory/memory.h"
25#include "absl/strings/str_replace.h"
26#include "absl/strings/str_split.h"
27#include "absl/strings/strip.h"
28#include "re2/re2.h"
29#include "tensorflow/core/platform/logging.h"
30#include "tensorflow/lite/toco/model_flags.pb.h"
31#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
32#include "tensorflow/lite/toco/toco_port.h"
33#include "tensorflow/lite/toco/toco_types.h"
34#include "tensorflow/lite/toco/tooling_util.h"
35
36using toco::port::AppendF;
37using toco::port::StringF;
38
39namespace toco {
40namespace {
41
// Graphviz (dot) templates filled in with StringF/AppendF by the functions
// below.
//
// kGraphFmt opens the top-level digraph; its %s is the HTML-like graph label.
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
 nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Note: tooltips are only supported on SVGs in Chrome.
// kSubgraphFmt opens a "cluster_" subgraph. Placeholders: (name, background
// color, HTML-like label). The matching closing brace is appended separately.
constexpr char kSubgraphFmt[] =
    R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node. Placeholders: (id, label, tooltip, shape, fill color, text
// color). The literal "DD" appends an alpha byte to the "#RRGGBB" text color.
constexpr char kArrayNodeFmt[] =
    R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node. Placeholders: (id, label, fill color, text color).
constexpr char kOpNodeFmt[] =
    R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array (id, compass-point suffix) into input port "i<k>" of an
// operator (id, k), attached at the port's north side; then pen width and
// layout weight.
constexpr char kInputEdgeFmt[] =
    R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from output port "o<k>" of an operator (id, k), attached at the port's
// south side, to an array (id, compass-point suffix); then pen width and
// layout weight.
constexpr char kOutputEdgeFmt[] =
    R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge from (back-edge source array) to (state array).
// constraint=false keeps the edge from influencing node ranking.
constexpr char kRNNBackEdgeFmt[] =
    R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
// U+00D7 MULTIPLICATION SIGN, printed between dimensions and stride values.
constexpr char kUnicodeMult[] = "\u00D7";
// U+2026 HORIZONTAL ELLIPSIS with surrounding spaces, used to elide the middle
// of long buffer samples.
constexpr char kUnicodeEllipsis[] = " \u2026 ";
70
71class Color {
72 public:
73 Color() {}
74 Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
75 explicit Color(uint32 word)
76 : r_((word & 0x00FF0000) >> 16),
77 g_((word & 0x0000FF00) >> 8),
78 b_((word & 0x000000FF) >> 0) {}
79
80 // Returns the string serialization of this color in graphviz format,
81 // for use as 'fillcolor' in boxes.
82 std::string AsHexString() const {
83 return StringF("#%.2X%.2X%.2X", r_, g_, b_);
84 }
85 // The color to use for this node; will be used as 'fillcolor'
86 // for its box. See Color::AsHexString. A suitable, different
87 // color will be chosen for the 'fontcolor' for the inside text
88 // label, see Color::TextColorString.
89 // Returns the serialization in graphviz format of a suitable color to use
90 // 'fontcolor' in the same boxes. It should black or white, whichever offers
91 // the better contrast from AsHexString().
92 std::string TextColorString() const {
93 // https://en.wikipedia.org/wiki/Relative_luminance
94 const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
95 const uint8 l = luminance > 128.f ? 0 : 255;
96 return StringF("#%.2X%.2X%.2X", l, l, l);
97 }
98
99 private:
100 uint8 r_ = 0, g_ = 0, b_ = 0;
101};
102
103Color HashStringToColor(std::string s) {
104 // Return a unique color for a name.
105 //
106 // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
107 // the string to a uint_32, then twiddles some bits to get a light and subtle
108 // color. This seems to be a good heuristic for keeping enough of the name to
109 // hash to a unique color while still revealing structure through naming
110 // similarities.
111 //
112 // The regular expression "_\d+" matches any underscore followed by numbers,
113 // which we strip out. Examples:
114 //
115 // "Conv" -> "Conv"
116 // "Conv_2" -> "Conv"
117 // "Conv_72" -> "Conv"
118 // "Pad_1_bias -> "Pad_bias"
119 // "Conv_abc" -> "Conv_abc"
120
121 RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
122 uint32 color_word = std::hash<std::string>{}(s);
123 color_word |= 0x00E0E0E0;
124 return Color(color_word);
125}
126
127void GetArrayColorAndShape(const Model& model, const std::string& array_name,
128 Color* color, std::string* shape) {
129 // All colors in this file are from:
130 // https://material.io/guidelines/style/color.html
131 // Arrays involved in RNN back-edges have a different color
132 for (const auto& rnn_state : model.flags.rnn_states()) {
133 // RNN state, fed by a back-edge. Bold color.
134 if (array_name == rnn_state.state_array()) {
135 *color = Color(0x0F, 0x9D, 0x58);
136 *shape = "invhouse";
137 return;
138 }
139 // RNN back-edge source, feeding a RNN state.
140 // Light tone of the same color as RNN states.
141 if (array_name == rnn_state.back_edge_source_array()) {
142 *color = Color(0xB7, 0xE1, 0xCD);
143 *shape = "house";
144 return;
145 }
146 }
147 // Constant parameter arrays have their own bold color
148 if (model.GetArray(array_name).buffer) {
149 *color = Color(0x42, 0x85, 0xF4);
150 *shape = "cylinder";
151 return;
152 }
153 // Remaining arrays are activations.
154 // We use gray colors for them because they are the majority
155 // of arrays so we want to highlight other arrays instead of them.
156 // First, we use a bolder gray for input/output arrays:
157 if (IsInputArray(model, array_name)) {
158 *color = Color(0x9E, 0x9E, 0x9E);
159 *shape = "invhouse";
160 return;
161 }
162 if (IsOutputArray(model, array_name)) {
163 *color = Color(0x9E, 0x9E, 0x9E);
164 *shape = "house";
165 return;
166 }
167 // Remaining arrays are intermediate activation arrays.
168 // Lighter tone of the same grey as for input/output arrays:
169 // We want these to be very discrete.
170 *color = Color(0xF5, 0xF5, 0xF5);
171 *shape = "box";
172}
173
174std::string GetArrayCompassPt(const Model& model,
175 const std::string& array_name) {
176 // The "compass point" is the point on the node where edge connections are
177 // made. For most arrays we don't care, but input's and outputs look better
178 // connected at the tip of the "house" and "invhouse" shapes used. So we
179 // append ":n" and ":s" respectively for those.
180 for (const auto& rnn_state : model.flags.rnn_states()) {
181 // RNN state is essentially an input
182 if (array_name == rnn_state.state_array()) {
183 return ":s";
184 }
185 // RNN back-edge source is essentially an output
186 if (array_name == rnn_state.back_edge_source_array()) {
187 return ":n";
188 }
189 }
190 if (IsInputArray(model, array_name)) {
191 return ":s";
192 }
193 if (IsOutputArray(model, array_name)) {
194 return ":n";
195 }
196 return "";
197}
198
199void AppendArrayVal(std::string* string, Array const& array, int index) {
200 if (array.buffer->type == ArrayDataType::kFloat) {
201 const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
202 if (index >= data.size()) {
203 return;
204 }
205 AppendF(string, "%.3f", data[index]);
206 } else if (array.buffer->type == ArrayDataType::kUint8) {
207 const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
208 if (index >= data.size()) {
209 return;
210 }
211 AppendF(string, "%d", data[index]);
212 } else if (array.buffer->type == ArrayDataType::kInt16) {
213 const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
214 if (index >= data.size()) {
215 return;
216 }
217 AppendF(string, "%d", data[index]);
218 } else if (array.buffer->type == ArrayDataType::kInt32) {
219 const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
220 if (index >= data.size()) {
221 return;
222 }
223 AppendF(string, "%d", data[index]);
224 } else if (array.buffer->type == ArrayDataType::kInt64) {
225 const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
226 if (index >= data.size()) {
227 return;
228 }
229 AppendF(string, "%d", data[index]);
230 } else if (array.buffer->type == ArrayDataType::kBool) {
231 const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
232 if (index >= data.size()) {
233 return;
234 }
235 AppendF(string, "%d", data[index]);
236 }
237}
238
// Attribute name -> display value. std::map keeps rows sorted by name.
typedef std::map<std::string, std::string> Attributes;

// Renders `attributes` as HTML-like table rows for a graphviz label, one row
// per entry in key order: a right-aligned "name:" cell followed by a
// left-aligned value cell. Returns an empty string for an empty map. Names and
// values are inserted verbatim (no HTML escaping).
std::string AttributesToHtml(const Attributes& attributes) {
  std::string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
252
253std::string GetArrayLabel(const Model& model, const std::string& array_id) {
254 std::string html;
255
256 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
257 html += "<";
258
259 // Begin Table
260 html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
261 html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
262
263 auto& array = model.GetArray(array_id);
264 if (array.buffer) {
265 // "cylinder" shapes require some extra head room.
266 html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
267 }
268
269 // "Primary" name of array (last non-slash delimited group of characters).
270 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
271 html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
272 AppendF(&html, R"CODE(%s)CODE",
273 std::vector<std::string>(absl::StrSplit(array_id, '/')).back());
274 html += R"CODE(</I></FONT>)CODE";
275 html += "</TD></TR>";
276
277 // Array data type and dimensions
278 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
279 html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
280 // Type
281 html += ArrayDataTypeName(array.data_type);
282 // Shape
283 if (array.has_shape()) {
284 auto& array_shape = array.shape();
285 html += "[";
286 for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
287 AppendF(&html, "%d", array_shape.dims(dim));
288 if (dim + 1 < array_shape.dimensions_count()) {
289 html += kUnicodeMult;
290 }
291 }
292 html += "]";
293 }
294
295 // Small buffer sample
296 int buffer_size = 0;
297 if (array.buffer) {
298 buffer_size = RequiredBufferSizeForShape(array.shape());
299 }
300 if ((buffer_size > 0) && (buffer_size <= 4)) {
301 html += " = ";
302 if (array.shape().dimensions_count() > 0) {
303 html += "{";
304 }
305 for (int i = 0; i < buffer_size; i++) {
306 AppendArrayVal(&html, array, i);
307 if (i + 1 < buffer_size) {
308 html += ", ";
309 }
310 }
311 if (array.shape().dimensions_count() > 0) {
312 html += "}";
313 }
314 }
315 html += R"CODE(</B></FONT>)CODE";
316 html += "</TD></TR>";
317
318 // Large buffer samples get their own line
319 if (buffer_size > 4) {
320 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
321 AppendArrayVal(&html, array, 0);
322 html += ", ";
323 AppendArrayVal(&html, array, 1);
324 html += kUnicodeEllipsis;
325 AppendArrayVal(&html, array, buffer_size - 2);
326 html += ", ";
327 AppendArrayVal(&html, array, buffer_size - 1);
328 html += "}</TD></TR>";
329 }
330
331 // Other array properties
332 Attributes attrs;
333 if (array.minmax) {
334 attrs["minmax"] =
335 StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
336 }
337 if (array.quantization_params) {
338 attrs["quant"] = StringF("%7g\u00B7(x-%d)", // Unicode "cdot"
339 array.quantization_params->scale,
340 array.quantization_params->zero_point);
341 }
342 if (array.alloc) {
343 attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
344 }
345 html += AttributesToHtml(attrs);
346
347 // output array_id in ultra-small font so it can be searched and copied.
348 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
349 html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
350 AppendF(&html, R"CODE("%s")CODE", array_id);
351 html += R"CODE(</FONT>)CODE";
352 html += "</TD></TR>";
353
354 // End Table and HTML-like label
355 html += R"CODE(</TABLE></FONT>)CODE";
356 html += ">";
357 return html;
358}
359
360Attributes GetOpAttributes(const Model& model, const Operator& op) {
361 Attributes attrs;
362 switch (op.fused_activation_function) {
363 case FusedActivationFunctionType::kRelu:
364 attrs["func"] = "ReLU";
365 break;
366 case FusedActivationFunctionType::kRelu6:
367 attrs["func"] = "ReLU6";
368 break;
369 case FusedActivationFunctionType::kRelu1:
370 attrs["func"] = "ReLU1";
371 break;
372 default:
373 break;
374 }
375 // Output state of member vars on derived operators.
376 switch (op.type) {
377 case OperatorType::kConv: {
378 const auto& conv_op = static_cast<const ConvOperator&>(op);
379 std::string stride;
380 AppendF(&stride, "%d", conv_op.stride_width);
381 stride += kUnicodeMult;
382 AppendF(&stride, "%d", conv_op.stride_height);
383 attrs["stride"] = stride;
384 attrs["padding"] =
385 (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
386 break;
387 }
388 case OperatorType::kDepthwiseConv: {
389 const auto& depthconv_op = static_cast<const ConvOperator&>(op);
390 std::string stride;
391 AppendF(&stride, "%d", depthconv_op.stride_width);
392 stride += kUnicodeMult;
393 AppendF(&stride, "%d", depthconv_op.stride_height);
394 attrs["stride"] = stride;
395 attrs["padding"] =
396 (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
397 break;
398 }
399 case OperatorType::kFakeQuant: {
400 const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
401 attrs["bits"] = StringF("%d", fakequant_op.num_bits);
402 if (fakequant_op.minmax) {
403 attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
404 fakequant_op.minmax->max);
405 } else {
406 attrs["range"] = "[?,?]";
407 }
408 break;
409 }
410 default:
411 break;
412 }
413 int64_t math_ops_count;
414 if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
415 (math_ops_count != 0)) {
416 attrs["math"] = FormattedNumber(math_ops_count) + "ops";
417 }
418
419 return attrs;
420}
421
422Color GetOpColor(const Operator& op) {
423 if ((op.type == OperatorType::kDepthwiseConv) ||
424 (op.type == OperatorType::kConv) ||
425 (op.type == OperatorType::kFullyConnected) ||
426 (op.type == OperatorType::kFakeQuant)) {
427 // Give some ops a bolder red
428 return Color(0xC5, 0x39, 0x29);
429 } else {
430 return Color(0xDB, 0x44, 0x37);
431 }
432}
433
434std::string GetOpLabel(const Model& model, const Operator& op) {
435 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
436 std::string html;
437 html += "<";
438
439 // Begin Table
440 html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
441 html +=
442 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
443
444 // Input Ports
445 if (!op.inputs.empty()) {
446 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
447 // Distribute evenly using a sub-table
448 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
449 html += R"CODE(<TR>)CODE";
450 for (int i = 0; i < op.inputs.size(); i++) {
451 html += R"CODE(<TD PORT=")CODE";
452 AppendF(&html, "i%d", i);
453 html += R"CODE(">)CODE";
454 if (op.inputs.size() > 1) {
455 // Only number inputs when op has two or more inputs
456 AppendF(&html, "%d", i);
457 }
458 html += "</TD>";
459 }
460 html += "</TR>";
461 html += R"CODE(</TABLE></TD></TR>)CODE";
462 }
463
464 // Name
465 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
466 html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
467 if (op.type == OperatorType::kUnsupported) {
468 html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
469 } else {
470 html +=
471 std::string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
472 }
473 html += R"CODE(</B></FONT>)CODE";
474 html += "</TD></TR>";
475
476 // Attributes
477 Attributes attrs = GetOpAttributes(model, op);
478 html += AttributesToHtml(attrs);
479
480 // Output Ports
481 if (!op.outputs.empty()) {
482 html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
483 // Distribute evenly using a sub-table
484 html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
485 html += R"CODE(<TR>)CODE";
486 for (int i = 0; i < op.outputs.size(); i++) {
487 html += R"CODE(<TD PORT=")CODE";
488 AppendF(&html, "o%d", i);
489 html += R"CODE(">)CODE";
490 if (op.outputs.size() > 1) {
491 // Only number outputs when op has two or more outputs
492 AppendF(&html, "%d", i);
493 }
494 html += "</TD>";
495 }
496 html += "</TR>";
497 html += R"CODE(</TABLE></TD></TR>)CODE";
498 }
499
500 // End Table and HTML-like label
501 html += R"CODE(</TABLE></FONT>)CODE";
502 html += ">";
503
504 return html;
505}
506
507float GetLog2BufferSize(const Model& model, const std::string& array_id) {
508 auto& array = model.GetArray(array_id);
509 if (array.has_shape()) {
510 int buffer_size = 0;
511 if (IsNonEmpty(array.shape())) {
512 buffer_size = RequiredBufferSizeForShape(array.shape());
513 return std::log2(static_cast<float>(buffer_size));
514 }
515 }
516 return 0.0f;
517}
518
519std::string GetOpId(int op_index) { return StringF("op%05d", op_index); }
520
521void DumpOperator(const Model& model, std::string* output_file, int op_index) {
522 // Dump node for operator.
523 const Operator& op = *model.operators[op_index];
524 Color color = GetOpColor(op);
525 std::string label = GetOpLabel(model, op);
526 std::string op_id = GetOpId(op_index);
527 AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
528 color.TextColorString());
529}
530
531void DumpOperatorEdges(const Model& model, std::string* output_file,
532 int op_index) {
533 // Inputs
534 const Operator& op = *model.operators[op_index];
535 std::string op_id = GetOpId(op_index);
536 for (int i = 0; i < op.inputs.size(); i++) {
537 const auto& input = op.inputs[i];
538 if (!model.HasArray(input)) {
539 // Connected arrays should _always_ exist. Except, perhaps, during
540 // development.
541 continue;
542 }
543 float log2_buffer_size = GetLog2BufferSize(model, input);
544 // Draw lines that transport more data thicker (Otherwise, where would the
545 // data fit? right?).
546 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
547 // Keep edges that transport more data shorter than those with less.
548 float weight = std::max(1.0f, log2_buffer_size);
549 if (!IsInputArray(model, input) &&
550 GetOpWithOutput(model, input) == nullptr) {
551 // Give the main line of data flow a straighter path by penalizing edges
552 // to standalone buffers. Weights are generally very large buffers that
553 // would otherwise skew the layout.
554 weight = 1.0f;
555 }
556 std::string compass_pt = GetArrayCompassPt(model, input);
557 AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
558 weight);
559 }
560 // Outputs
561 for (int i = 0; i < op.outputs.size(); i++) {
562 const auto& output = op.outputs[i];
563 if (!model.HasArray(output)) {
564 continue;
565 }
566 float log2_buffer_size = GetLog2BufferSize(model, output);
567 // See comments above regarding weight and line_width calculations.
568 float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
569 float weight = std::max(1.0f, log2_buffer_size);
570 if (!IsArrayConsumed(model, output)) {
571 weight = 1.0f;
572 }
573 std::string compass_pt = GetArrayCompassPt(model, output);
574 AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
575 line_width, weight);
576 }
577}
578
// A node in the tree built from the '/'-separated components of array names;
// each non-root node is dumped as a graphviz cluster subgraph.
struct Node {
  // Name used as a key in the model's array map. Empty when no array has
  // exactly this path (the node then only groups its children).
  std::string array_id;

  // Estimated number of math ops incurred by this node (the sum of the op
  // with this array as 1st output, plus all children nodes).
  int64_t math_ops = 0;

  // Child nodes keyed by the next name component. std::map keys are already
  // immutable, so the key type is a plain std::string.
  std::map<std::string, std::unique_ptr<Node>> children;
};
591
592std::string GetSubgraphLabel(Node const& node, const std::string& subgraph) {
593 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
594 std::string html;
595 html += "<";
596
597 // Begin Table
598 html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
599 html +=
600 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
601
602 // Name
603 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
604 html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
605 html += subgraph;
606 html += R"CODE(</I></FONT>)CODE";
607 html += "</TD></TR>";
608
609 // Attributes
610 Attributes attrs;
611 if (node.math_ops > 0) {
612 attrs["math"] = FormattedNumber(node.math_ops) + "ops";
613 }
614 html += AttributesToHtml(attrs);
615
616 // End Table and HTML-like label
617 html += R"CODE(</TABLE></FONT>)CODE";
618 html += ">";
619
620 return html;
621}
622
623void DumpSubgraphHeader(std::string* output_file, Node const& node,
624 const std::string& node_name) {
625 Color color = HashStringToColor(node_name);
626 std::string label = GetSubgraphLabel(node, node_name);
627 AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
628}
629
630void DumpArray(const Model& model, std::string* output_file,
631 const std::string& array_id) {
632 Color color;
633 std::string shape;
634 GetArrayColorAndShape(model, array_id, &color, &shape);
635 std::string label = GetArrayLabel(model, array_id);
636 AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
637 color.AsHexString(), color.TextColorString());
638
639 // Ops are placed in the same subgraph as their first output.
640 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
641 const Operator& op = *model.operators[op_index];
642 if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
643 DumpOperator(model, output_file, op_index);
644 }
645 }
646}
647
648void DumpNode(const Model& model, std::string* output_file,
649 const std::string& node_name, Node const& node) {
650 bool not_root = !node_name.empty();
651 if (not_root) {
652 DumpSubgraphHeader(output_file, node, node_name);
653 }
654
655 for (const auto& child : node.children) {
656 if (!child.second->array_id.empty()) {
657 // Dump array if this node possesses one.
658 DumpArray(model, output_file, child.second->array_id);
659 }
660 // Note that it is always possible to have children. Unlike a filesystem,
661 // the existence of array "foo/bar" does _not_ prevent other arrays, such as
662 // and "foo/bar/baz", from being nested beneath it.
663 DumpNode(model, output_file, child.first, *child.second);
664 }
665
666 if (not_root) {
667 // End subgraph
668 AppendF(output_file, " }\n");
669 }
670}
671
672int64_t GetArithmeticOpsCount(const Model& model, const std::string& array_id) {
673 for (const auto& op : model.operators) {
674 if (!op->outputs.empty() && op->outputs[0] == array_id) {
675 int64_t count;
676 if (EstimateArithmeticOpsCount(model, *op, &count)) {
677 return count;
678 } else {
679 return 0;
680 }
681 }
682 }
683 return 0;
684}
685
686void InsertNode(const Model& model, const std::string& array_id, Node* node,
687 std::vector<std::string> prefixes, int64_t* math_ops) {
688 if (prefixes.empty()) {
689 // Base case: store array in this node.
690 node->array_id = array_id;
691 *math_ops = GetArithmeticOpsCount(model, array_id);
692 } else {
693 // Insert into the sub-tree for that prefix.
694 std::string prefix = prefixes.back();
695 prefixes.pop_back();
696 if (node->children.count(prefix) == 0) {
697 // Create a new node if this prefix is unseen.
698 node->children[prefix] = std::make_unique<Node>();
699 }
700 InsertNode(model, array_id, node->children[prefix].get(), prefixes,
701 math_ops);
702 }
703 // Sum estimated math ops into all nodes.
704 node->math_ops += *math_ops;
705}
706
707void BuildArrayTree(const Model& model, Node* tree) {
708 // Delimit array names by path "/", then place into a tree based on this path.
709 for (const auto& array_id : model.GetArrayMap()) {
710 std::vector<std::string> prefixes = absl::StrSplit(array_id.first, '/');
711 std::reverse(prefixes.begin(), prefixes.end());
712 int64_t math_ops; // Temporary storage for math ops used during recursion.
713 InsertNode(model, array_id.first, tree, prefixes, &math_ops);
714 }
715}
716
717std::string GetGraphLabel(const Model& model, const std::string& graph_name) {
718 // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
719 std::string html;
720 html += "<";
721
722 // Begin Table
723 html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
724 html +=
725 R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
726
727 // Name
728 html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
729 html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
730 html += graph_name;
731 html += R"CODE(</I></B></FONT>)CODE";
732 html += "</TD></TR>";
733
734 // Attributes
735 Attributes attrs;
736 attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
737 if (!model.optional_arrays.empty()) {
738 attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
739 }
740 attrs["operators"] = StringF("%d", model.operators.size());
741 int64_t ops_count;
742 if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
743 attrs["math"] = FormattedNumber(ops_count) + "ops";
744 }
745 if (model.transient_data_size > 0) {
746 attrs["transient data size"] =
747 StringF("%d KiB", model.transient_data_size / 1024);
748 }
749 if (model.transient_data_alignment > 0) {
750 attrs["transient data alignment"] =
751 StringF("%d bytes", model.transient_data_alignment);
752 }
753 html += AttributesToHtml(attrs);
754
755 // End Table and HTML-like label
756 html += R"CODE(</TABLE></FONT>)CODE";
757 html += ">";
758
759 return html;
760}
761} // namespace
762
763void DumpGraphviz(const Model& model, std::string* output_file,
764 const std::string& graph_name) {
765 // Start graphviz format
766 AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
767
768 // Organize arrays into a tree for subgraphing
769 Node tree;
770 BuildArrayTree(model, &tree);
771 DumpNode(model, output_file, "", tree);
772
773 // Dump edges outside all subgraphs (otherwise the referred-to nodes are
774 // implicitly included in that subgraph).
775 for (int op_index = 0; op_index < model.operators.size(); op_index++) {
776 DumpOperatorEdges(model, output_file, op_index);
777 }
778
779 // Dump RNN Backedges
780 for (const auto& rnn_state : model.flags.rnn_states()) {
781 AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
782 rnn_state.state_array());
783 }
784 // End graphviz format
785 AppendF(output_file, "}\n");
786}
787} // namespace toco
788