1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/lib/strings/str_util.h" |
21 | #include "tensorflow/core/platform/init_main.h" |
22 | #include "tensorflow/core/public/session.h" |
23 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
24 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace graph_transforms { |
28 | |
29 | // Clears the device field of all ops in the graph. |
30 | Status InsertLogging(const GraphDef& input_graph_def, |
31 | const TransformFuncContext& context, |
32 | GraphDef* output_graph_def) { |
33 | std::unordered_set<string> ops; |
34 | bool has_ops; |
35 | if (context.params.count("op" )) { |
36 | has_ops = true; |
37 | for (const string& op : context.params.at("op" )) { |
38 | ops.insert(op); |
39 | } |
40 | } else { |
41 | has_ops = false; |
42 | } |
43 | |
44 | std::unordered_set<string> prefixes; |
45 | bool has_prefixes; |
46 | if (context.params.count("prefix" )) { |
47 | has_prefixes = true; |
48 | for (const string& prefix : context.params.at("prefix" )) { |
49 | prefixes.insert(prefix); |
50 | } |
51 | } else { |
52 | has_prefixes = false; |
53 | } |
54 | |
55 | string message; |
56 | TF_RETURN_IF_ERROR(context.GetOneStringParameter("message" , "" , &message)); |
57 | |
58 | bool show_name; |
59 | TF_RETURN_IF_ERROR( |
60 | context.GetOneBoolParameter("show_name" , false, &show_name)); |
61 | |
62 | bool show_op; |
63 | TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op" , false, &show_op)); |
64 | |
65 | int32_t first_n; |
66 | TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n" , -1, &first_n)); |
67 | |
68 | int32_t summarize; |
69 | TF_RETURN_IF_ERROR( |
70 | context.GetOneInt32Parameter("summarize" , 1024, &summarize)); |
71 | |
72 | std::unordered_map<string, std::set<int>> node_outputs; |
73 | for (const NodeDef& node : input_graph_def.node()) { |
74 | for (const string& input : node.input()) { |
75 | const string canonical_input = CanonicalInputName(input); |
76 | string prefix; |
77 | string name; |
78 | string suffix; |
79 | NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix); |
80 | const string output_index_string = suffix.substr(1, suffix.size() - 1); |
81 | int32_t output_index; |
82 | if (!strings::safe_strto32(output_index_string, &output_index)) { |
83 | return errors::InvalidArgument("Couldn't understand output number in " , |
84 | input); |
85 | } |
86 | node_outputs[name].insert(output_index); |
87 | } |
88 | } |
89 | |
90 | std::map<string, string> inputs_to_rename; |
91 | std::unordered_set<string> ignore_when_renaming; |
92 | GraphDef logged_graph_def; |
93 | for (const NodeDef& node : input_graph_def.node()) { |
94 | NodeDef* new_node = logged_graph_def.mutable_node()->Add(); |
95 | *new_node = node; |
96 | if (node_outputs[node.name()].empty()) { |
97 | // There were no outputs found to this node, so skip it. |
98 | continue; |
99 | } |
100 | const bool op_matches = (ops.count(node.op()) > 0); |
101 | bool prefix_matches = false; |
102 | for (const string& prefix : prefixes) { |
103 | if (absl::StartsWith(node.name(), prefix)) { |
104 | prefix_matches = true; |
105 | } |
106 | } |
107 | // If we're not looking for ops, or we found the right op, and if we're not |
108 | // looking for prefixes or we found the right prefix, then add logging here. |
109 | if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) { |
110 | const string name_suffix = "__print__" ; |
111 | DataTypeVector input_types; |
112 | DataTypeVector output_types; |
113 | TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types)); |
114 | NodeDef* print_node = logged_graph_def.mutable_node()->Add(); |
115 | print_node->set_op("Print" ); |
116 | print_node->set_name(strings::StrCat(node.name(), name_suffix)); |
117 | string node_message; |
118 | if (show_op) { |
119 | node_message += ";" + node.op() + ";" ; |
120 | } |
121 | if (show_name) { |
122 | node_message += ";" + print_node->name() + ";" ; |
123 | } |
124 | node_message += message; |
125 | SetNodeAttr("message" , node_message, print_node); |
126 | SetNodeAttr("first_n" , first_n, print_node); |
127 | SetNodeAttr("summarize" , summarize, print_node); |
128 | print_node->add_input(node.name() + ":0" ); |
129 | SetNodeAttr("T" , output_types[0], print_node); |
130 | for (int output_index : node_outputs[node.name()]) { |
131 | print_node->add_input(strings::StrCat(node.name(), ":" , output_index)); |
132 | } |
133 | SetNodeAttr("U" , output_types, print_node); |
134 | ignore_when_renaming.insert(print_node->name()); |
135 | // Rewrite the graph so all references to the first input of the original |
136 | // op now pull from the print op instead, so it's executed. |
137 | inputs_to_rename[node.name() + ":0" ] = |
138 | strings::StrCat(node.name(), name_suffix, ":0" ); |
139 | } |
140 | } |
141 | |
142 | output_graph_def->Clear(); |
143 | return RenameNodeInputs(logged_graph_def, inputs_to_rename, |
144 | ignore_when_renaming, output_graph_def); |
145 | } |
146 | |
// Registers InsertLogging as the graph transform named "insert_logging".
REGISTER_GRAPH_TRANSFORM("insert_logging" , InsertLogging);
148 | |
149 | } // namespace graph_transforms |
150 | } // namespace tensorflow |
151 | |