1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/graph_def_util.h" |
17 | |
18 | #include <set> |
19 | #include <unordered_map> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/framework/attr_value.pb.h" |
24 | #include "tensorflow/core/framework/function.h" |
25 | #include "tensorflow/core/framework/function.pb.h" |
26 | #include "tensorflow/core/framework/graph.pb.h" |
27 | #include "tensorflow/core/framework/node_def.pb.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/framework/op_def_util.h" |
30 | #include "tensorflow/core/framework/versions.pb.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/lib/core/status.h" |
33 | #include "tensorflow/core/lib/strings/str_util.h" |
34 | #include "tensorflow/core/lib/strings/strcat.h" |
35 | |
36 | namespace tensorflow { |
37 | |
38 | string SummarizeGraphDef(const GraphDef& graph_def) { |
39 | string ret; |
40 | strings::StrAppend( |
41 | &ret, "versions = " , graph_def.versions().ShortDebugString(), ";\n" ); |
42 | for (const NodeDef& node : graph_def.node()) { |
43 | strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n" ); |
44 | } |
45 | return ret; |
46 | } |
47 | |
48 | Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { |
49 | for (const NodeDef& node : graph_def.node()) { |
50 | TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); |
51 | } |
52 | return OkStatus(); |
53 | } |
54 | |
55 | Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, |
56 | const OpRegistryInterface& op_registry, |
57 | int node_offset) { |
58 | return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false); |
59 | } |
60 | |
61 | Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, |
62 | const OpRegistryInterface& op_registry, |
63 | int node_offset, bool skip_unknown_ops) { |
64 | if (node_offset > graph_def->node_size()) { |
65 | return errors::InvalidArgument( |
66 | "Tried to add default attrs to GraphDef " |
67 | "starting at offset " , |
68 | node_offset, " with total nodes in graph: " , graph_def->node_size()); |
69 | } |
70 | |
71 | for (int i = node_offset; i < graph_def->node_size(); ++i) { |
72 | NodeDef* node_def = graph_def->mutable_node(i); |
73 | const OpDef* op_def; |
74 | Status s = op_registry.LookUpOpDef(node_def->op(), &op_def); |
75 | if (s.ok()) { |
76 | AddDefaultsToNodeDef(*op_def, node_def); |
77 | } else if (!skip_unknown_ops) { |
78 | return s; |
79 | } |
80 | } |
81 | |
82 | return OkStatus(); |
83 | } |
84 | |
85 | static Status RemoveNewDefaultAttrsFromNodeDef( |
86 | NodeDef* node_def, const OpRegistryInterface& consumer_op_registry, |
87 | const OpRegistryInterface& producer_op_registry, |
88 | std::set<std::pair<string, string>>* op_attr_removed) { |
89 | const OpDef* producer_op_def; |
90 | const OpDef* consumer_op_def; |
91 | TF_RETURN_IF_ERROR( |
92 | producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def)); |
93 | TF_RETURN_IF_ERROR( |
94 | consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); |
95 | |
96 | std::vector<string> to_remove; |
97 | for (const auto& attr : node_def->attr()) { |
98 | // If the attr is not in consumer_op_def and doesn't start with '_'... |
99 | if (!absl::StartsWith(attr.first, "_" ) && |
100 | FindAttr(attr.first, *consumer_op_def) == nullptr) { |
101 | const OpDef::AttrDef* producer_attr_def = |
102 | FindAttr(attr.first, *producer_op_def); |
103 | if (producer_attr_def == nullptr) { |
104 | return errors::InvalidArgument( |
105 | "Attr '" , attr.first, |
106 | "' missing in producer's OpDef: " , SummarizeOpDef(*producer_op_def), |
107 | " but found in node: " , FormatNodeDefForError(*node_def)); |
108 | } |
109 | // ...and it has the same value as the default in producer, |
110 | if (producer_attr_def->has_default_value() && |
111 | AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) { |
112 | // then we will remove it below. |
113 | to_remove.emplace_back(attr.first); |
114 | } |
115 | } |
116 | } |
117 | // We separate identifying which attrs should be removed from |
118 | // actually removing them to avoid invalidating the loop iterators |
119 | // above. |
120 | for (const string& attr_name : to_remove) { |
121 | node_def->mutable_attr()->erase(attr_name); |
122 | if (op_attr_removed != nullptr) { |
123 | op_attr_removed->insert(std::make_pair(node_def->op(), attr_name)); |
124 | } |
125 | } |
126 | |
127 | return OkStatus(); |
128 | } |
129 | |
130 | static bool IsFunction(const GraphDef& graph_def, const string& op_name) { |
131 | for (const auto& func_def : graph_def.library().function()) { |
132 | if (op_name == func_def.signature().name()) return true; |
133 | } |
134 | return false; |
135 | } |
136 | |
137 | Status RemoveNewDefaultAttrsFromGraphDef( |
138 | GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, |
139 | const OpRegistryInterface& producer_op_registry, |
140 | std::set<std::pair<string, string>>* op_attr_removed) { |
141 | // TODO(joshL): Make IsFunction() faster by collecting the names of |
142 | // all functions as a preprocessing step. |
143 | for (int n = 0; n < graph_def->node_size(); ++n) { |
144 | NodeDef* node_def = graph_def->mutable_node(n); |
145 | if (!IsFunction(*graph_def, node_def->op())) { |
146 | TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( |
147 | node_def, consumer_op_registry, producer_op_registry, |
148 | op_attr_removed)); |
149 | } |
150 | } |
151 | for (int f = 0; f < graph_def->library().function_size(); ++f) { |
152 | FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f); |
153 | for (int n = 0; n < func_def->node_def_size(); ++n) { |
154 | NodeDef* node_def = func_def->mutable_node_def(n); |
155 | if (!IsFunction(*graph_def, node_def->op())) { |
156 | // TODO(josh11b): Better handling of attrs with placeholder values. |
157 | TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef( |
158 | node_def, consumer_op_registry, producer_op_registry, |
159 | op_attr_removed)); |
160 | } |
161 | } |
162 | } |
163 | |
164 | return OkStatus(); |
165 | } |
166 | |
167 | void StripDefaultAttributes(const OpRegistryInterface& op_registry, |
168 | protobuf::RepeatedPtrField<NodeDef>* nodes) { |
169 | for (int i = 0; i < nodes->size(); ++i) { |
170 | NodeDef* node = nodes->Mutable(i); |
171 | |
172 | const OpDef* op_def; |
173 | const OpRegistrationData* op_reg_data = nullptr; |
174 | Status s = op_registry.LookUp(node->op(), &op_reg_data); |
175 | if (!s.ok()) { |
176 | VLOG(1) << "Ignoring encountered unknown operation " |
177 | << SummarizeNodeDef(*node) |
178 | << " when stripping default attributes. It is likely a function, " |
179 | "in which case ignoring it is fine" ; |
180 | continue; |
181 | } |
182 | op_def = &op_reg_data->op_def; |
183 | |
184 | for (const OpDef::AttrDef& attr_def : op_def->attr()) { |
185 | if (attr_def.has_default_value()) { |
186 | AttrValueMap* attrs = node->mutable_attr(); |
187 | const string& name = attr_def.name(); |
188 | auto iter = attrs->find(name); |
189 | if (iter != attrs->end()) { |
190 | const AttrValue& default_value = attr_def.default_value(); |
191 | // There should never be an attribute whose default value is a tensor |
192 | // larger than 32MB so allow false negatives for efficient |
193 | // comparison. |
194 | if (AreAttrValuesEqual(iter->second, default_value, |
195 | /*allow_false_negatives=*/true)) { |
196 | attrs->erase(name); |
197 | } |
198 | } |
199 | } |
200 | } |
201 | } |
202 | } |
203 | |
204 | void OpsUsedByGraph(const GraphDef& graph_def, |
205 | std::set<string>* ops_used_in_graph) { |
206 | // Map function names to definitions. |
207 | std::unordered_map<string, const FunctionDef*> name_to_function; |
208 | for (const auto& function : graph_def.library().function()) { |
209 | name_to_function.insert( |
210 | std::make_pair(function.signature().name(), &function)); |
211 | } |
212 | |
213 | // Collect the sorted list of op names. Since functions can reference |
214 | // functions, we need a recursive traversal. |
215 | std::set<string> used_ops; // Includes both primitive ops and functions |
216 | std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops |
217 | // Collect the logic to mark an op in a lambda; it'll be used twice below. |
218 | const auto mark_op_as_used = [&used_ops, &functions_to_process, |
219 | &name_to_function](const string& op) { |
220 | if (used_ops.insert(op).second) { |
221 | // If it's a function, we'll need to process further |
222 | const auto it = name_to_function.find(op); |
223 | if (it != name_to_function.end()) { |
224 | functions_to_process.push_back(it->second); |
225 | } |
226 | } |
227 | }; |
228 | for (const auto& node : graph_def.node()) { |
229 | mark_op_as_used(node.op()); |
230 | } |
231 | while (!functions_to_process.empty()) { |
232 | const FunctionDef* fun = functions_to_process.back(); |
233 | functions_to_process.pop_back(); |
234 | for (const auto& node : fun->node_def()) { |
235 | mark_op_as_used(node.op()); |
236 | } |
237 | } |
238 | |
239 | // Filter out function names to produce output. |
240 | // TODO(josh11b): Change the above code to produce this directly. |
241 | ops_used_in_graph->clear(); |
242 | for (const string& op_name : used_ops) { |
243 | if (name_to_function.find(op_name) == name_to_function.end()) { |
244 | ops_used_in_graph->insert(op_name); |
245 | } |
246 | } |
247 | } |
248 | |
249 | Status StrippedOpListForGraph(const GraphDef& graph_def, |
250 | const OpRegistryInterface& op_registry, |
251 | OpList* stripped_op_list) { |
252 | std::set<string> used_ops; |
253 | OpsUsedByGraph(graph_def, &used_ops); |
254 | |
255 | // Build the stripped op list in sorted order, ignoring functions. |
256 | stripped_op_list->clear_op(); |
257 | for (const string& op_name : used_ops) { |
258 | const OpDef* op_def; |
259 | TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def)); |
260 | OpDef* stripped_op = stripped_op_list->add_op(); |
261 | stripped_op->CopyFrom(*op_def); |
262 | RemoveDescriptionsFromOpDef(stripped_op); |
263 | } |
264 | return OkStatus(); |
265 | } |
266 | |
267 | } // namespace tensorflow |
268 | |