1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/constant_folding.h"
17#include "tensorflow/core/common_runtime/graph_constructor.h"
18#include "tensorflow/core/graph/node_builder.h"
19#include "tensorflow/core/graph/subgraph.h"
20#include "tensorflow/core/lib/strings/str_util.h"
21#include "tensorflow/core/platform/init_main.h"
22#include "tensorflow/core/public/session.h"
23#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
24#include "tensorflow/tools/graph_transforms/transform_utils.h"
25
26namespace tensorflow {
27namespace graph_transforms {
28
29// Clears the device field of all ops in the graph.
30Status InsertLogging(const GraphDef& input_graph_def,
31 const TransformFuncContext& context,
32 GraphDef* output_graph_def) {
33 std::unordered_set<string> ops;
34 bool has_ops;
35 if (context.params.count("op")) {
36 has_ops = true;
37 for (const string& op : context.params.at("op")) {
38 ops.insert(op);
39 }
40 } else {
41 has_ops = false;
42 }
43
44 std::unordered_set<string> prefixes;
45 bool has_prefixes;
46 if (context.params.count("prefix")) {
47 has_prefixes = true;
48 for (const string& prefix : context.params.at("prefix")) {
49 prefixes.insert(prefix);
50 }
51 } else {
52 has_prefixes = false;
53 }
54
55 string message;
56 TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
57
58 bool show_name;
59 TF_RETURN_IF_ERROR(
60 context.GetOneBoolParameter("show_name", false, &show_name));
61
62 bool show_op;
63 TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
64
65 int32_t first_n;
66 TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
67
68 int32_t summarize;
69 TF_RETURN_IF_ERROR(
70 context.GetOneInt32Parameter("summarize", 1024, &summarize));
71
72 std::unordered_map<string, std::set<int>> node_outputs;
73 for (const NodeDef& node : input_graph_def.node()) {
74 for (const string& input : node.input()) {
75 const string canonical_input = CanonicalInputName(input);
76 string prefix;
77 string name;
78 string suffix;
79 NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
80 const string output_index_string = suffix.substr(1, suffix.size() - 1);
81 int32_t output_index;
82 if (!strings::safe_strto32(output_index_string, &output_index)) {
83 return errors::InvalidArgument("Couldn't understand output number in ",
84 input);
85 }
86 node_outputs[name].insert(output_index);
87 }
88 }
89
90 std::map<string, string> inputs_to_rename;
91 std::unordered_set<string> ignore_when_renaming;
92 GraphDef logged_graph_def;
93 for (const NodeDef& node : input_graph_def.node()) {
94 NodeDef* new_node = logged_graph_def.mutable_node()->Add();
95 *new_node = node;
96 if (node_outputs[node.name()].empty()) {
97 // There were no outputs found to this node, so skip it.
98 continue;
99 }
100 const bool op_matches = (ops.count(node.op()) > 0);
101 bool prefix_matches = false;
102 for (const string& prefix : prefixes) {
103 if (absl::StartsWith(node.name(), prefix)) {
104 prefix_matches = true;
105 }
106 }
107 // If we're not looking for ops, or we found the right op, and if we're not
108 // looking for prefixes or we found the right prefix, then add logging here.
109 if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
110 const string name_suffix = "__print__";
111 DataTypeVector input_types;
112 DataTypeVector output_types;
113 TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
114 NodeDef* print_node = logged_graph_def.mutable_node()->Add();
115 print_node->set_op("Print");
116 print_node->set_name(strings::StrCat(node.name(), name_suffix));
117 string node_message;
118 if (show_op) {
119 node_message += ";" + node.op() + ";";
120 }
121 if (show_name) {
122 node_message += ";" + print_node->name() + ";";
123 }
124 node_message += message;
125 SetNodeAttr("message", node_message, print_node);
126 SetNodeAttr("first_n", first_n, print_node);
127 SetNodeAttr("summarize", summarize, print_node);
128 print_node->add_input(node.name() + ":0");
129 SetNodeAttr("T", output_types[0], print_node);
130 for (int output_index : node_outputs[node.name()]) {
131 print_node->add_input(strings::StrCat(node.name(), ":", output_index));
132 }
133 SetNodeAttr("U", output_types, print_node);
134 ignore_when_renaming.insert(print_node->name());
135 // Rewrite the graph so all references to the first input of the original
136 // op now pull from the print op instead, so it's executed.
137 inputs_to_rename[node.name() + ":0"] =
138 strings::StrCat(node.name(), name_suffix, ":0");
139 }
140 }
141
142 output_graph_def->Clear();
143 return RenameNodeInputs(logged_graph_def, inputs_to_rename,
144 ignore_when_renaming, output_graph_def);
145}
146
147REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
148
149} // namespace graph_transforms
150} // namespace tensorflow
151