1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/dump_graphviz.h" |
16 | |
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
23 | |
24 | #include "absl/memory/memory.h" |
25 | #include "absl/strings/str_replace.h" |
26 | #include "absl/strings/str_split.h" |
27 | #include "absl/strings/strip.h" |
28 | #include "re2/re2.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/lite/toco/model_flags.pb.h" |
31 | #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" |
32 | #include "tensorflow/lite/toco/toco_port.h" |
33 | #include "tensorflow/lite/toco/toco_types.h" |
34 | #include "tensorflow/lite/toco/tooling_util.h" |
35 | |
36 | using toco::port::AppendF; |
37 | using toco::port::StringF; |
38 | |
39 | namespace toco { |
40 | namespace { |
41 | |
// Graphviz (dot) templates used to emit the graph. Each is a printf-style
// format consumed by AppendF/StringF; the comment on each one lists the
// arguments it expects, in order.

// Top-level graph header; the opening "{" is closed by DumpGraphviz.
// Args: graph label (an HTML-like label, "<...>").
//
// 'nslimit' is a graphviz (dot) parameter that limits the iterations during
// the layout phase. Omitting it allows infinite iterations, causing some
// complex graphs to never finish. A value of 125 produces good graphs
// while allowing complex graphs to finish.
constexpr char kGraphFmt[] = R"CODE(digraph Computegraph { tooltip = "/"
nslimit=125 margin=36 ranksep = 2 labelloc="t" label=%s
)CODE";
// Opens a cluster subgraph; the "{" is closed by DumpNode.
// Args: cluster name, background color ("#RRGGBB"), HTML-like label.
// Note: tooltips are only supported on SVGs in Chrome.
constexpr char kSubgraphFmt[] =
R"CODE( subgraph "cluster_%s" { style=rounded bgcolor="%s" penwidth=0.0 label=%s
)CODE";
// Array node. Args: array id, HTML-like label, tooltip text, shape name,
// fill color ("#RRGGBB"), font color ("#RRGGBB"; "DD" alpha is appended).
constexpr char kArrayNodeFmt[] =
R"CODE( "%s" [label=%s tooltip="%s" shape=%s style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Operator node. Args: op id, HTML-like label, fill color, font color
// (both "#RRGGBB"; "DD" alpha is appended to the font color).
constexpr char kOpNodeFmt[] =
R"CODE( %s [label=%s tooltip=" " shape=box margin=0 style=filled fillcolor="%s" fontcolor="%sDD"];
)CODE";
// Edge from an array into an op input port. Args: array id, array compass
// point suffix (may be empty), op id, input port index, pen width, weight.
constexpr char kInputEdgeFmt[] =
R"CODE( "%s"%s -> %s:i%d:n [penwidth=%f weight=%f];
)CODE";
// Edge from an op output port to an array. Args: op id, output port index,
// array id, array compass point suffix (may be empty), pen width, weight.
constexpr char kOutputEdgeFmt[] =
R"CODE( %s:o%d:s -> "%s"%s [penwidth=%f weight=%f];
)CODE";
// RNN back-edge. Args: back-edge source array id, RNN state array id.
constexpr char kRNNBackEdgeFmt[] =
R"CODE( "%s":s -> "%s":n [color="#0F9D58" constraint=false];
)CODE";
// Multiplication sign, used between dimensions ("2×3") and stride values.
constexpr char kUnicodeMult[] = "\u00D7";
// Horizontal ellipsis (with surrounding spaces), used in buffer samples.
constexpr char kUnicodeEllipsis[] = " \u2026 ";
70 | |
71 | class Color { |
72 | public: |
73 | Color() {} |
74 | Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {} |
75 | explicit Color(uint32 word) |
76 | : r_((word & 0x00FF0000) >> 16), |
77 | g_((word & 0x0000FF00) >> 8), |
78 | b_((word & 0x000000FF) >> 0) {} |
79 | |
80 | // Returns the string serialization of this color in graphviz format, |
81 | // for use as 'fillcolor' in boxes. |
82 | std::string AsHexString() const { |
83 | return StringF("#%.2X%.2X%.2X" , r_, g_, b_); |
84 | } |
85 | // The color to use for this node; will be used as 'fillcolor' |
86 | // for its box. See Color::AsHexString. A suitable, different |
87 | // color will be chosen for the 'fontcolor' for the inside text |
88 | // label, see Color::TextColorString. |
89 | // Returns the serialization in graphviz format of a suitable color to use |
90 | // 'fontcolor' in the same boxes. It should black or white, whichever offers |
91 | // the better contrast from AsHexString(). |
92 | std::string TextColorString() const { |
93 | // https://en.wikipedia.org/wiki/Relative_luminance |
94 | const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_; |
95 | const uint8 l = luminance > 128.f ? 0 : 255; |
96 | return StringF("#%.2X%.2X%.2X" , l, l, l); |
97 | } |
98 | |
99 | private: |
100 | uint8 r_ = 0, g_ = 0, b_ = 0; |
101 | }; |
102 | |
103 | Color HashStringToColor(std::string s) { |
104 | // Return a unique color for a name. |
105 | // |
106 | // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes |
107 | // the string to a uint_32, then twiddles some bits to get a light and subtle |
108 | // color. This seems to be a good heuristic for keeping enough of the name to |
109 | // hash to a unique color while still revealing structure through naming |
110 | // similarities. |
111 | // |
112 | // The regular expression "_\d+" matches any underscore followed by numbers, |
113 | // which we strip out. Examples: |
114 | // |
115 | // "Conv" -> "Conv" |
116 | // "Conv_2" -> "Conv" |
117 | // "Conv_72" -> "Conv" |
118 | // "Pad_1_bias -> "Pad_bias" |
119 | // "Conv_abc" -> "Conv_abc" |
120 | |
121 | RE2::GlobalReplace(&s, R"CODE(_\d+)CODE" , "" ); |
122 | uint32 color_word = std::hash<std::string>{}(s); |
123 | color_word |= 0x00E0E0E0; |
124 | return Color(color_word); |
125 | } |
126 | |
127 | void GetArrayColorAndShape(const Model& model, const std::string& array_name, |
128 | Color* color, std::string* shape) { |
129 | // All colors in this file are from: |
130 | // https://material.io/guidelines/style/color.html |
131 | // Arrays involved in RNN back-edges have a different color |
132 | for (const auto& rnn_state : model.flags.rnn_states()) { |
133 | // RNN state, fed by a back-edge. Bold color. |
134 | if (array_name == rnn_state.state_array()) { |
135 | *color = Color(0x0F, 0x9D, 0x58); |
136 | *shape = "invhouse" ; |
137 | return; |
138 | } |
139 | // RNN back-edge source, feeding a RNN state. |
140 | // Light tone of the same color as RNN states. |
141 | if (array_name == rnn_state.back_edge_source_array()) { |
142 | *color = Color(0xB7, 0xE1, 0xCD); |
143 | *shape = "house" ; |
144 | return; |
145 | } |
146 | } |
147 | // Constant parameter arrays have their own bold color |
148 | if (model.GetArray(array_name).buffer) { |
149 | *color = Color(0x42, 0x85, 0xF4); |
150 | *shape = "cylinder" ; |
151 | return; |
152 | } |
153 | // Remaining arrays are activations. |
154 | // We use gray colors for them because they are the majority |
155 | // of arrays so we want to highlight other arrays instead of them. |
156 | // First, we use a bolder gray for input/output arrays: |
157 | if (IsInputArray(model, array_name)) { |
158 | *color = Color(0x9E, 0x9E, 0x9E); |
159 | *shape = "invhouse" ; |
160 | return; |
161 | } |
162 | if (IsOutputArray(model, array_name)) { |
163 | *color = Color(0x9E, 0x9E, 0x9E); |
164 | *shape = "house" ; |
165 | return; |
166 | } |
167 | // Remaining arrays are intermediate activation arrays. |
168 | // Lighter tone of the same grey as for input/output arrays: |
169 | // We want these to be very discrete. |
170 | *color = Color(0xF5, 0xF5, 0xF5); |
171 | *shape = "box" ; |
172 | } |
173 | |
174 | std::string GetArrayCompassPt(const Model& model, |
175 | const std::string& array_name) { |
176 | // The "compass point" is the point on the node where edge connections are |
177 | // made. For most arrays we don't care, but input's and outputs look better |
178 | // connected at the tip of the "house" and "invhouse" shapes used. So we |
179 | // append ":n" and ":s" respectively for those. |
180 | for (const auto& rnn_state : model.flags.rnn_states()) { |
181 | // RNN state is essentially an input |
182 | if (array_name == rnn_state.state_array()) { |
183 | return ":s" ; |
184 | } |
185 | // RNN back-edge source is essentially an output |
186 | if (array_name == rnn_state.back_edge_source_array()) { |
187 | return ":n" ; |
188 | } |
189 | } |
190 | if (IsInputArray(model, array_name)) { |
191 | return ":s" ; |
192 | } |
193 | if (IsOutputArray(model, array_name)) { |
194 | return ":n" ; |
195 | } |
196 | return "" ; |
197 | } |
198 | |
199 | void AppendArrayVal(std::string* string, Array const& array, int index) { |
200 | if (array.buffer->type == ArrayDataType::kFloat) { |
201 | const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data; |
202 | if (index >= data.size()) { |
203 | return; |
204 | } |
205 | AppendF(string, "%.3f" , data[index]); |
206 | } else if (array.buffer->type == ArrayDataType::kUint8) { |
207 | const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data; |
208 | if (index >= data.size()) { |
209 | return; |
210 | } |
211 | AppendF(string, "%d" , data[index]); |
212 | } else if (array.buffer->type == ArrayDataType::kInt16) { |
213 | const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data; |
214 | if (index >= data.size()) { |
215 | return; |
216 | } |
217 | AppendF(string, "%d" , data[index]); |
218 | } else if (array.buffer->type == ArrayDataType::kInt32) { |
219 | const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data; |
220 | if (index >= data.size()) { |
221 | return; |
222 | } |
223 | AppendF(string, "%d" , data[index]); |
224 | } else if (array.buffer->type == ArrayDataType::kInt64) { |
225 | const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data; |
226 | if (index >= data.size()) { |
227 | return; |
228 | } |
229 | AppendF(string, "%d" , data[index]); |
230 | } else if (array.buffer->type == ArrayDataType::kBool) { |
231 | const auto& data = array.GetBuffer<ArrayDataType::kBool>().data; |
232 | if (index >= data.size()) { |
233 | return; |
234 | } |
235 | AppendF(string, "%d" , data[index]); |
236 | } |
237 | } |
238 | |
// Name -> value pairs shown in an HTML-like label, rendered in key order.
using Attributes = std::map<std::string, std::string>;

// Renders each attribute as a two-cell table row "name: value" (name
// right-aligned, value left-aligned). Returns "" when `attributes` is empty.
// Keys and values are inserted verbatim, without HTML escaping. Taken by
// const reference: the map is only read, so copying it was wasted work.
std::string AttributesToHtml(const Attributes& attributes) {
  std::string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
252 | |
253 | std::string GetArrayLabel(const Model& model, const std::string& array_id) { |
254 | std::string html; |
255 | |
256 | // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) |
257 | html += "<" ; |
258 | |
259 | // Begin Table |
260 | html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE" ; |
261 | html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE" ; |
262 | |
263 | auto& array = model.GetArray(array_id); |
264 | if (array.buffer) { |
265 | // "cylinder" shapes require some extra head room. |
266 | html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE" ; |
267 | } |
268 | |
269 | // "Primary" name of array (last non-slash delimited group of characters). |
270 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE" ; |
271 | html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE" ; |
272 | AppendF(&html, R"CODE(%s)CODE" , |
273 | std::vector<std::string>(absl::StrSplit(array_id, '/')).back()); |
274 | html += R"CODE(</I></FONT>)CODE" ; |
275 | html += "</TD></TR>" ; |
276 | |
277 | // Array data type and dimensions |
278 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE" ; |
279 | html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE" ; |
280 | // Type |
281 | html += ArrayDataTypeName(array.data_type); |
282 | // Shape |
283 | if (array.has_shape()) { |
284 | auto& array_shape = array.shape(); |
285 | html += "[" ; |
286 | for (int dim = 0; dim < array_shape.dimensions_count(); dim++) { |
287 | AppendF(&html, "%d" , array_shape.dims(dim)); |
288 | if (dim + 1 < array_shape.dimensions_count()) { |
289 | html += kUnicodeMult; |
290 | } |
291 | } |
292 | html += "]" ; |
293 | } |
294 | |
295 | // Small buffer sample |
296 | int buffer_size = 0; |
297 | if (array.buffer) { |
298 | buffer_size = RequiredBufferSizeForShape(array.shape()); |
299 | } |
300 | if ((buffer_size > 0) && (buffer_size <= 4)) { |
301 | html += " = " ; |
302 | if (array.shape().dimensions_count() > 0) { |
303 | html += "{" ; |
304 | } |
305 | for (int i = 0; i < buffer_size; i++) { |
306 | AppendArrayVal(&html, array, i); |
307 | if (i + 1 < buffer_size) { |
308 | html += ", " ; |
309 | } |
310 | } |
311 | if (array.shape().dimensions_count() > 0) { |
312 | html += "}" ; |
313 | } |
314 | } |
315 | html += R"CODE(</B></FONT>)CODE" ; |
316 | html += "</TD></TR>" ; |
317 | |
318 | // Large buffer samples get their own line |
319 | if (buffer_size > 4) { |
320 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE" ; |
321 | AppendArrayVal(&html, array, 0); |
322 | html += ", " ; |
323 | AppendArrayVal(&html, array, 1); |
324 | html += kUnicodeEllipsis; |
325 | AppendArrayVal(&html, array, buffer_size - 2); |
326 | html += ", " ; |
327 | AppendArrayVal(&html, array, buffer_size - 1); |
328 | html += "}</TD></TR>" ; |
329 | } |
330 | |
331 | // Other array properties |
332 | Attributes attrs; |
333 | if (array.minmax) { |
334 | attrs["minmax" ] = |
335 | StringF("[%.7g, %.7g]" , array.minmax->min, array.minmax->max); |
336 | } |
337 | if (array.quantization_params) { |
338 | attrs["quant" ] = StringF("%7g\u00B7(x-%d)" , // Unicode "cdot" |
339 | array.quantization_params->scale, |
340 | array.quantization_params->zero_point); |
341 | } |
342 | if (array.alloc) { |
343 | attrs["alloc" ] = StringF("[%d, %d)" , array.alloc->start, array.alloc->end); |
344 | } |
345 | html += AttributesToHtml(attrs); |
346 | |
347 | // output array_id in ultra-small font so it can be searched and copied. |
348 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE" ; |
349 | html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE" ; |
350 | AppendF(&html, R"CODE("%s")CODE" , array_id); |
351 | html += R"CODE(</FONT>)CODE" ; |
352 | html += "</TD></TR>" ; |
353 | |
354 | // End Table and HTML-like label |
355 | html += R"CODE(</TABLE></FONT>)CODE" ; |
356 | html += ">" ; |
357 | return html; |
358 | } |
359 | |
360 | Attributes GetOpAttributes(const Model& model, const Operator& op) { |
361 | Attributes attrs; |
362 | switch (op.fused_activation_function) { |
363 | case FusedActivationFunctionType::kRelu: |
364 | attrs["func" ] = "ReLU" ; |
365 | break; |
366 | case FusedActivationFunctionType::kRelu6: |
367 | attrs["func" ] = "ReLU6" ; |
368 | break; |
369 | case FusedActivationFunctionType::kRelu1: |
370 | attrs["func" ] = "ReLU1" ; |
371 | break; |
372 | default: |
373 | break; |
374 | } |
375 | // Output state of member vars on derived operators. |
376 | switch (op.type) { |
377 | case OperatorType::kConv: { |
378 | const auto& conv_op = static_cast<const ConvOperator&>(op); |
379 | std::string stride; |
380 | AppendF(&stride, "%d" , conv_op.stride_width); |
381 | stride += kUnicodeMult; |
382 | AppendF(&stride, "%d" , conv_op.stride_height); |
383 | attrs["stride" ] = stride; |
384 | attrs["padding" ] = |
385 | (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid" ; |
386 | break; |
387 | } |
388 | case OperatorType::kDepthwiseConv: { |
389 | const auto& depthconv_op = static_cast<const ConvOperator&>(op); |
390 | std::string stride; |
391 | AppendF(&stride, "%d" , depthconv_op.stride_width); |
392 | stride += kUnicodeMult; |
393 | AppendF(&stride, "%d" , depthconv_op.stride_height); |
394 | attrs["stride" ] = stride; |
395 | attrs["padding" ] = |
396 | (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid" ; |
397 | break; |
398 | } |
399 | case OperatorType::kFakeQuant: { |
400 | const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op); |
401 | attrs["bits" ] = StringF("%d" , fakequant_op.num_bits); |
402 | if (fakequant_op.minmax) { |
403 | attrs["range" ] = StringF("[%g,%g]" , fakequant_op.minmax->min, |
404 | fakequant_op.minmax->max); |
405 | } else { |
406 | attrs["range" ] = "[?,?]" ; |
407 | } |
408 | break; |
409 | } |
410 | default: |
411 | break; |
412 | } |
413 | int64_t math_ops_count; |
414 | if (EstimateArithmeticOpsCount(model, op, &math_ops_count) && |
415 | (math_ops_count != 0)) { |
416 | attrs["math" ] = FormattedNumber(math_ops_count) + "ops" ; |
417 | } |
418 | |
419 | return attrs; |
420 | } |
421 | |
422 | Color GetOpColor(const Operator& op) { |
423 | if ((op.type == OperatorType::kDepthwiseConv) || |
424 | (op.type == OperatorType::kConv) || |
425 | (op.type == OperatorType::kFullyConnected) || |
426 | (op.type == OperatorType::kFakeQuant)) { |
427 | // Give some ops a bolder red |
428 | return Color(0xC5, 0x39, 0x29); |
429 | } else { |
430 | return Color(0xDB, 0x44, 0x37); |
431 | } |
432 | } |
433 | |
434 | std::string GetOpLabel(const Model& model, const Operator& op) { |
435 | // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) |
436 | std::string html; |
437 | html += "<" ; |
438 | |
439 | // Begin Table |
440 | html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE" ; |
441 | html += |
442 | R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE" ; |
443 | |
444 | // Input Ports |
445 | if (!op.inputs.empty()) { |
446 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE" ; |
447 | // Distribute evenly using a sub-table |
448 | html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE" ; |
449 | html += R"CODE(<TR>)CODE" ; |
450 | for (int i = 0; i < op.inputs.size(); i++) { |
451 | html += R"CODE(<TD PORT=")CODE" ; |
452 | AppendF(&html, "i%d" , i); |
453 | html += R"CODE(">)CODE" ; |
454 | if (op.inputs.size() > 1) { |
455 | // Only number inputs when op has two or more inputs |
456 | AppendF(&html, "%d" , i); |
457 | } |
458 | html += "</TD>" ; |
459 | } |
460 | html += "</TR>" ; |
461 | html += R"CODE(</TABLE></TD></TR>)CODE" ; |
462 | } |
463 | |
464 | // Name |
465 | html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE" ; |
466 | html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE" ; |
467 | if (op.type == OperatorType::kUnsupported) { |
468 | html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op; |
469 | } else { |
470 | html += |
471 | std::string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow" )); |
472 | } |
473 | html += R"CODE(</B></FONT>)CODE" ; |
474 | html += "</TD></TR>" ; |
475 | |
476 | // Attributes |
477 | Attributes attrs = GetOpAttributes(model, op); |
478 | html += AttributesToHtml(attrs); |
479 | |
480 | // Output Ports |
481 | if (!op.outputs.empty()) { |
482 | html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE" ; |
483 | // Distribute evenly using a sub-table |
484 | html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE" ; |
485 | html += R"CODE(<TR>)CODE" ; |
486 | for (int i = 0; i < op.outputs.size(); i++) { |
487 | html += R"CODE(<TD PORT=")CODE" ; |
488 | AppendF(&html, "o%d" , i); |
489 | html += R"CODE(">)CODE" ; |
490 | if (op.outputs.size() > 1) { |
491 | // Only number outputs when op has two or more outputs |
492 | AppendF(&html, "%d" , i); |
493 | } |
494 | html += "</TD>" ; |
495 | } |
496 | html += "</TR>" ; |
497 | html += R"CODE(</TABLE></TD></TR>)CODE" ; |
498 | } |
499 | |
500 | // End Table and HTML-like label |
501 | html += R"CODE(</TABLE></FONT>)CODE" ; |
502 | html += ">" ; |
503 | |
504 | return html; |
505 | } |
506 | |
507 | float GetLog2BufferSize(const Model& model, const std::string& array_id) { |
508 | auto& array = model.GetArray(array_id); |
509 | if (array.has_shape()) { |
510 | int buffer_size = 0; |
511 | if (IsNonEmpty(array.shape())) { |
512 | buffer_size = RequiredBufferSizeForShape(array.shape()); |
513 | return std::log2(static_cast<float>(buffer_size)); |
514 | } |
515 | } |
516 | return 0.0f; |
517 | } |
518 | |
519 | std::string GetOpId(int op_index) { return StringF("op%05d" , op_index); } |
520 | |
521 | void DumpOperator(const Model& model, std::string* output_file, int op_index) { |
522 | // Dump node for operator. |
523 | const Operator& op = *model.operators[op_index]; |
524 | Color color = GetOpColor(op); |
525 | std::string label = GetOpLabel(model, op); |
526 | std::string op_id = GetOpId(op_index); |
527 | AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(), |
528 | color.TextColorString()); |
529 | } |
530 | |
531 | void DumpOperatorEdges(const Model& model, std::string* output_file, |
532 | int op_index) { |
533 | // Inputs |
534 | const Operator& op = *model.operators[op_index]; |
535 | std::string op_id = GetOpId(op_index); |
536 | for (int i = 0; i < op.inputs.size(); i++) { |
537 | const auto& input = op.inputs[i]; |
538 | if (!model.HasArray(input)) { |
539 | // Connected arrays should _always_ exist. Except, perhaps, during |
540 | // development. |
541 | continue; |
542 | } |
543 | float log2_buffer_size = GetLog2BufferSize(model, input); |
544 | // Draw lines that transport more data thicker (Otherwise, where would the |
545 | // data fit? right?). |
546 | float line_width = std::max(0.5f, log2_buffer_size / 3.0f); |
547 | // Keep edges that transport more data shorter than those with less. |
548 | float weight = std::max(1.0f, log2_buffer_size); |
549 | if (!IsInputArray(model, input) && |
550 | GetOpWithOutput(model, input) == nullptr) { |
551 | // Give the main line of data flow a straighter path by penalizing edges |
552 | // to standalone buffers. Weights are generally very large buffers that |
553 | // would otherwise skew the layout. |
554 | weight = 1.0f; |
555 | } |
556 | std::string compass_pt = GetArrayCompassPt(model, input); |
557 | AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width, |
558 | weight); |
559 | } |
560 | // Outputs |
561 | for (int i = 0; i < op.outputs.size(); i++) { |
562 | const auto& output = op.outputs[i]; |
563 | if (!model.HasArray(output)) { |
564 | continue; |
565 | } |
566 | float log2_buffer_size = GetLog2BufferSize(model, output); |
567 | // See comments above regarding weight and line_width calculations. |
568 | float line_width = std::max(0.5f, log2_buffer_size / 3.0f); |
569 | float weight = std::max(1.0f, log2_buffer_size); |
570 | if (!IsArrayConsumed(model, output)) { |
571 | weight = 1.0f; |
572 | } |
573 | std::string compass_pt = GetArrayCompassPt(model, output); |
574 | AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt, |
575 | line_width, weight); |
576 | } |
577 | } |
578 | |
// A node of the tree that BuildArrayTree builds from "/"-separated array
// ids. DumpNode draws each named node as a cluster subgraph.
struct Node {
  // Name used as a key in the model's array map. Empty when no array has
  // exactly this node's path.
  std::string array_id;

  // Estimated number of math ops incurred by this node (the sum of the op
  // with this array as 1st output, plus all children nodes).
  int64_t math_ops = 0;

  // Child nodes keyed by the next path component. std::map already stores
  // keys as const, so the key type is a plain std::string.
  std::map<std::string, std::unique_ptr<Node>> children;
};
591 | |
592 | std::string GetSubgraphLabel(Node const& node, const std::string& subgraph) { |
593 | // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) |
594 | std::string html; |
595 | html += "<" ; |
596 | |
597 | // Begin Table |
598 | html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE" ; |
599 | html += |
600 | R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE" ; |
601 | |
602 | // Name |
603 | html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE" ; |
604 | html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE" ; |
605 | html += subgraph; |
606 | html += R"CODE(</I></FONT>)CODE" ; |
607 | html += "</TD></TR>" ; |
608 | |
609 | // Attributes |
610 | Attributes attrs; |
611 | if (node.math_ops > 0) { |
612 | attrs["math" ] = FormattedNumber(node.math_ops) + "ops" ; |
613 | } |
614 | html += AttributesToHtml(attrs); |
615 | |
616 | // End Table and HTML-like label |
617 | html += R"CODE(</TABLE></FONT>)CODE" ; |
618 | html += ">" ; |
619 | |
620 | return html; |
621 | } |
622 | |
623 | void (std::string* output_file, Node const& node, |
624 | const std::string& node_name) { |
625 | Color color = HashStringToColor(node_name); |
626 | std::string label = GetSubgraphLabel(node, node_name); |
627 | AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label); |
628 | } |
629 | |
630 | void DumpArray(const Model& model, std::string* output_file, |
631 | const std::string& array_id) { |
632 | Color color; |
633 | std::string shape; |
634 | GetArrayColorAndShape(model, array_id, &color, &shape); |
635 | std::string label = GetArrayLabel(model, array_id); |
636 | AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape, |
637 | color.AsHexString(), color.TextColorString()); |
638 | |
639 | // Ops are placed in the same subgraph as their first output. |
640 | for (int op_index = 0; op_index < model.operators.size(); op_index++) { |
641 | const Operator& op = *model.operators[op_index]; |
642 | if (!op.outputs.empty() && (op.outputs[0] == array_id)) { |
643 | DumpOperator(model, output_file, op_index); |
644 | } |
645 | } |
646 | } |
647 | |
648 | void DumpNode(const Model& model, std::string* output_file, |
649 | const std::string& node_name, Node const& node) { |
650 | bool not_root = !node_name.empty(); |
651 | if (not_root) { |
652 | DumpSubgraphHeader(output_file, node, node_name); |
653 | } |
654 | |
655 | for (const auto& child : node.children) { |
656 | if (!child.second->array_id.empty()) { |
657 | // Dump array if this node possesses one. |
658 | DumpArray(model, output_file, child.second->array_id); |
659 | } |
660 | // Note that it is always possible to have children. Unlike a filesystem, |
661 | // the existence of array "foo/bar" does _not_ prevent other arrays, such as |
662 | // and "foo/bar/baz", from being nested beneath it. |
663 | DumpNode(model, output_file, child.first, *child.second); |
664 | } |
665 | |
666 | if (not_root) { |
667 | // End subgraph |
668 | AppendF(output_file, " }\n" ); |
669 | } |
670 | } |
671 | |
672 | int64_t GetArithmeticOpsCount(const Model& model, const std::string& array_id) { |
673 | for (const auto& op : model.operators) { |
674 | if (!op->outputs.empty() && op->outputs[0] == array_id) { |
675 | int64_t count; |
676 | if (EstimateArithmeticOpsCount(model, *op, &count)) { |
677 | return count; |
678 | } else { |
679 | return 0; |
680 | } |
681 | } |
682 | } |
683 | return 0; |
684 | } |
685 | |
686 | void InsertNode(const Model& model, const std::string& array_id, Node* node, |
687 | std::vector<std::string> prefixes, int64_t* math_ops) { |
688 | if (prefixes.empty()) { |
689 | // Base case: store array in this node. |
690 | node->array_id = array_id; |
691 | *math_ops = GetArithmeticOpsCount(model, array_id); |
692 | } else { |
693 | // Insert into the sub-tree for that prefix. |
694 | std::string prefix = prefixes.back(); |
695 | prefixes.pop_back(); |
696 | if (node->children.count(prefix) == 0) { |
697 | // Create a new node if this prefix is unseen. |
698 | node->children[prefix] = std::make_unique<Node>(); |
699 | } |
700 | InsertNode(model, array_id, node->children[prefix].get(), prefixes, |
701 | math_ops); |
702 | } |
703 | // Sum estimated math ops into all nodes. |
704 | node->math_ops += *math_ops; |
705 | } |
706 | |
707 | void BuildArrayTree(const Model& model, Node* tree) { |
708 | // Delimit array names by path "/", then place into a tree based on this path. |
709 | for (const auto& array_id : model.GetArrayMap()) { |
710 | std::vector<std::string> prefixes = absl::StrSplit(array_id.first, '/'); |
711 | std::reverse(prefixes.begin(), prefixes.end()); |
712 | int64_t math_ops; // Temporary storage for math ops used during recursion. |
713 | InsertNode(model, array_id.first, tree, prefixes, &math_ops); |
714 | } |
715 | } |
716 | |
717 | std::string GetGraphLabel(const Model& model, const std::string& graph_name) { |
718 | // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html) |
719 | std::string html; |
720 | html += "<" ; |
721 | |
722 | // Begin Table |
723 | html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE" ; |
724 | html += |
725 | R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE" ; |
726 | |
727 | // Name |
728 | html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE" ; |
729 | html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE" ; |
730 | html += graph_name; |
731 | html += R"CODE(</I></B></FONT>)CODE" ; |
732 | html += "</TD></TR>" ; |
733 | |
734 | // Attributes |
735 | Attributes attrs; |
736 | attrs["arrays" ] = StringF("%d" , model.GetArrayMap().size()); |
737 | if (!model.optional_arrays.empty()) { |
738 | attrs["optional arrays" ] = StringF("%d" , model.optional_arrays.size()); |
739 | } |
740 | attrs["operators" ] = StringF("%d" , model.operators.size()); |
741 | int64_t ops_count; |
742 | if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) { |
743 | attrs["math" ] = FormattedNumber(ops_count) + "ops" ; |
744 | } |
745 | if (model.transient_data_size > 0) { |
746 | attrs["transient data size" ] = |
747 | StringF("%d KiB" , model.transient_data_size / 1024); |
748 | } |
749 | if (model.transient_data_alignment > 0) { |
750 | attrs["transient data alignment" ] = |
751 | StringF("%d bytes" , model.transient_data_alignment); |
752 | } |
753 | html += AttributesToHtml(attrs); |
754 | |
755 | // End Table and HTML-like label |
756 | html += R"CODE(</TABLE></FONT>)CODE" ; |
757 | html += ">" ; |
758 | |
759 | return html; |
760 | } |
761 | } // namespace |
762 | |
763 | void DumpGraphviz(const Model& model, std::string* output_file, |
764 | const std::string& graph_name) { |
765 | // Start graphviz format |
766 | AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name)); |
767 | |
768 | // Organize arrays into a tree for subgraphing |
769 | Node tree; |
770 | BuildArrayTree(model, &tree); |
771 | DumpNode(model, output_file, "" , tree); |
772 | |
773 | // Dump edges outside all subgraphs (otherwise the referred-to nodes are |
774 | // implicitly included in that subgraph). |
775 | for (int op_index = 0; op_index < model.operators.size(); op_index++) { |
776 | DumpOperatorEdges(model, output_file, op_index); |
777 | } |
778 | |
779 | // Dump RNN Backedges |
780 | for (const auto& rnn_state : model.flags.rnn_states()) { |
781 | AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(), |
782 | rnn_state.state_array()); |
783 | } |
784 | // End graphviz format |
785 | AppendF(output_file, "}\n" ); |
786 | } |
787 | } // namespace toco |
788 | |