1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/graph_def_util.h"
17
18#include <set>
19#include <unordered_map>
20#include <unordered_set>
21#include <vector>
22
23#include "tensorflow/core/framework/attr_value.pb.h"
24#include "tensorflow/core/framework/function.h"
25#include "tensorflow/core/framework/function.pb.h"
26#include "tensorflow/core/framework/graph.pb.h"
27#include "tensorflow/core/framework/node_def.pb.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/op_def_util.h"
30#include "tensorflow/core/framework/versions.pb.h"
31#include "tensorflow/core/lib/core/errors.h"
32#include "tensorflow/core/lib/core/status.h"
33#include "tensorflow/core/lib/strings/str_util.h"
34#include "tensorflow/core/lib/strings/strcat.h"
35
36namespace tensorflow {
37
38string SummarizeGraphDef(const GraphDef& graph_def) {
39 string ret;
40 strings::StrAppend(
41 &ret, "versions = ", graph_def.versions().ShortDebugString(), ";\n");
42 for (const NodeDef& node : graph_def.node()) {
43 strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
44 }
45 return ret;
46}
47
48Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
49 for (const NodeDef& node : graph_def.node()) {
50 TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
51 }
52 return OkStatus();
53}
54
55Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
56 const OpRegistryInterface& op_registry,
57 int node_offset) {
58 return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
59}
60
61Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
62 const OpRegistryInterface& op_registry,
63 int node_offset, bool skip_unknown_ops) {
64 if (node_offset > graph_def->node_size()) {
65 return errors::InvalidArgument(
66 "Tried to add default attrs to GraphDef "
67 "starting at offset ",
68 node_offset, " with total nodes in graph: ", graph_def->node_size());
69 }
70
71 for (int i = node_offset; i < graph_def->node_size(); ++i) {
72 NodeDef* node_def = graph_def->mutable_node(i);
73 const OpDef* op_def;
74 Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
75 if (s.ok()) {
76 AddDefaultsToNodeDef(*op_def, node_def);
77 } else if (!skip_unknown_ops) {
78 return s;
79 }
80 }
81
82 return OkStatus();
83}
84
85static Status RemoveNewDefaultAttrsFromNodeDef(
86 NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
87 const OpRegistryInterface& producer_op_registry,
88 std::set<std::pair<string, string>>* op_attr_removed) {
89 const OpDef* producer_op_def;
90 const OpDef* consumer_op_def;
91 TF_RETURN_IF_ERROR(
92 producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
93 TF_RETURN_IF_ERROR(
94 consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
95
96 std::vector<string> to_remove;
97 for (const auto& attr : node_def->attr()) {
98 // If the attr is not in consumer_op_def and doesn't start with '_'...
99 if (!absl::StartsWith(attr.first, "_") &&
100 FindAttr(attr.first, *consumer_op_def) == nullptr) {
101 const OpDef::AttrDef* producer_attr_def =
102 FindAttr(attr.first, *producer_op_def);
103 if (producer_attr_def == nullptr) {
104 return errors::InvalidArgument(
105 "Attr '", attr.first,
106 "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
107 " but found in node: ", FormatNodeDefForError(*node_def));
108 }
109 // ...and it has the same value as the default in producer,
110 if (producer_attr_def->has_default_value() &&
111 AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
112 // then we will remove it below.
113 to_remove.emplace_back(attr.first);
114 }
115 }
116 }
117 // We separate identifying which attrs should be removed from
118 // actually removing them to avoid invalidating the loop iterators
119 // above.
120 for (const string& attr_name : to_remove) {
121 node_def->mutable_attr()->erase(attr_name);
122 if (op_attr_removed != nullptr) {
123 op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
124 }
125 }
126
127 return OkStatus();
128}
129
130static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
131 for (const auto& func_def : graph_def.library().function()) {
132 if (op_name == func_def.signature().name()) return true;
133 }
134 return false;
135}
136
137Status RemoveNewDefaultAttrsFromGraphDef(
138 GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
139 const OpRegistryInterface& producer_op_registry,
140 std::set<std::pair<string, string>>* op_attr_removed) {
141 // TODO(joshL): Make IsFunction() faster by collecting the names of
142 // all functions as a preprocessing step.
143 for (int n = 0; n < graph_def->node_size(); ++n) {
144 NodeDef* node_def = graph_def->mutable_node(n);
145 if (!IsFunction(*graph_def, node_def->op())) {
146 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
147 node_def, consumer_op_registry, producer_op_registry,
148 op_attr_removed));
149 }
150 }
151 for (int f = 0; f < graph_def->library().function_size(); ++f) {
152 FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
153 for (int n = 0; n < func_def->node_def_size(); ++n) {
154 NodeDef* node_def = func_def->mutable_node_def(n);
155 if (!IsFunction(*graph_def, node_def->op())) {
156 // TODO(josh11b): Better handling of attrs with placeholder values.
157 TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
158 node_def, consumer_op_registry, producer_op_registry,
159 op_attr_removed));
160 }
161 }
162 }
163
164 return OkStatus();
165}
166
167void StripDefaultAttributes(const OpRegistryInterface& op_registry,
168 protobuf::RepeatedPtrField<NodeDef>* nodes) {
169 for (int i = 0; i < nodes->size(); ++i) {
170 NodeDef* node = nodes->Mutable(i);
171
172 const OpDef* op_def;
173 const OpRegistrationData* op_reg_data = nullptr;
174 Status s = op_registry.LookUp(node->op(), &op_reg_data);
175 if (!s.ok()) {
176 VLOG(1) << "Ignoring encountered unknown operation "
177 << SummarizeNodeDef(*node)
178 << " when stripping default attributes. It is likely a function, "
179 "in which case ignoring it is fine";
180 continue;
181 }
182 op_def = &op_reg_data->op_def;
183
184 for (const OpDef::AttrDef& attr_def : op_def->attr()) {
185 if (attr_def.has_default_value()) {
186 AttrValueMap* attrs = node->mutable_attr();
187 const string& name = attr_def.name();
188 auto iter = attrs->find(name);
189 if (iter != attrs->end()) {
190 const AttrValue& default_value = attr_def.default_value();
191 // There should never be an attribute whose default value is a tensor
192 // larger than 32MB so allow false negatives for efficient
193 // comparison.
194 if (AreAttrValuesEqual(iter->second, default_value,
195 /*allow_false_negatives=*/true)) {
196 attrs->erase(name);
197 }
198 }
199 }
200 }
201 }
202}
203
204void OpsUsedByGraph(const GraphDef& graph_def,
205 std::set<string>* ops_used_in_graph) {
206 // Map function names to definitions.
207 std::unordered_map<string, const FunctionDef*> name_to_function;
208 for (const auto& function : graph_def.library().function()) {
209 name_to_function.insert(
210 std::make_pair(function.signature().name(), &function));
211 }
212
213 // Collect the sorted list of op names. Since functions can reference
214 // functions, we need a recursive traversal.
215 std::set<string> used_ops; // Includes both primitive ops and functions
216 std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops
217 // Collect the logic to mark an op in a lambda; it'll be used twice below.
218 const auto mark_op_as_used = [&used_ops, &functions_to_process,
219 &name_to_function](const string& op) {
220 if (used_ops.insert(op).second) {
221 // If it's a function, we'll need to process further
222 const auto it = name_to_function.find(op);
223 if (it != name_to_function.end()) {
224 functions_to_process.push_back(it->second);
225 }
226 }
227 };
228 for (const auto& node : graph_def.node()) {
229 mark_op_as_used(node.op());
230 }
231 while (!functions_to_process.empty()) {
232 const FunctionDef* fun = functions_to_process.back();
233 functions_to_process.pop_back();
234 for (const auto& node : fun->node_def()) {
235 mark_op_as_used(node.op());
236 }
237 }
238
239 // Filter out function names to produce output.
240 // TODO(josh11b): Change the above code to produce this directly.
241 ops_used_in_graph->clear();
242 for (const string& op_name : used_ops) {
243 if (name_to_function.find(op_name) == name_to_function.end()) {
244 ops_used_in_graph->insert(op_name);
245 }
246 }
247}
248
249Status StrippedOpListForGraph(const GraphDef& graph_def,
250 const OpRegistryInterface& op_registry,
251 OpList* stripped_op_list) {
252 std::set<string> used_ops;
253 OpsUsedByGraph(graph_def, &used_ops);
254
255 // Build the stripped op list in sorted order, ignoring functions.
256 stripped_op_list->clear_op();
257 for (const string& op_name : used_ops) {
258 const OpDef* op_def;
259 TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
260 OpDef* stripped_op = stripped_op_list->add_op();
261 stripped_op->CopyFrom(*op_def);
262 RemoveDescriptionsFromOpDef(stripped_op);
263 }
264 return OkStatus();
265}
266
267} // namespace tensorflow
268