1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/function_utils.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/function_body.h" |
19 | #include "tensorflow/core/framework/function.h" |
20 | #include "tensorflow/core/framework/graph.pb.h" |
21 | #include "tensorflow/core/framework/node_def.pb.h" |
22 | #include "tensorflow/core/framework/node_def_util.h" |
23 | #include "tensorflow/core/framework/op_def.pb.h" |
24 | #include "tensorflow/core/framework/versions.pb.h" |
25 | #include "tensorflow/core/graph/algorithm.h" |
26 | #include "tensorflow/core/graph/control_flow.h" |
27 | #include "tensorflow/core/graph/graph.h" |
28 | |
29 | namespace tensorflow { |
30 | |
// Prefix used for the names of helper nodes (NoOp / Identity) that this file
// inserts while transforming graphs; see AddNoOp() and AddIdentity(), which
// build names of the form "Func/<name>" via Graph::NewName().
static constexpr const char* const kNodeLabel = "Func";
32 | |
33 | // Represents the index-th output of a node. |
34 | struct Endpoint { |
35 | Node* node; |
36 | int index; |
37 | |
38 | // Returns the string name represents this endpoint. |
39 | string name() const { |
40 | if (index == 0) { |
41 | return node->name(); |
42 | } else { |
43 | return strings::StrCat(node->name(), ":" , index); |
44 | } |
45 | } |
46 | |
47 | DataType dtype() const { return node->output_type(index); } |
48 | }; |
49 | |
50 | // The following Add* routines are used to add a few graph nodes while |
51 | // functions are transformed. |
52 | static Node* AddNoOp(StringPiece name, Graph* g) { |
53 | NodeDef ndef; |
54 | ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/" , name))); |
55 | ndef.set_op("NoOp" ); |
56 | Status s; |
57 | Node* ret = g->AddNode(ndef, &s); |
58 | TF_CHECK_OK(s); |
59 | return ret; |
60 | } |
61 | |
62 | static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { |
63 | DCHECK_LT(0, input.dtype()); |
64 | NodeDef ndef; |
65 | ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/" , name))); |
66 | ndef.set_op("Identity" ); |
67 | ndef.add_input(input.name()); |
68 | AddNodeAttr("T" , BaseType(input.dtype()), &ndef); |
69 | Status s; |
70 | Node* ret = g->AddNode(ndef, &s); |
71 | TF_CHECK_OK(s); |
72 | g->AddEdge(input.node, input.index, ret, 0); |
73 | return ret; |
74 | } |
75 | |
76 | void DumpGraph(StringPiece label, const Graph* g) { |
77 | // TODO(zhifengc): Change Graph to record #nodes. |
78 | VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " |
79 | << g->num_edges(); |
80 | if (VLOG_IS_ON(5)) { |
81 | for (const auto& line : str_util::Split(DebugString(g), '\n')) { |
82 | VLOG(5) << "|| " << line; |
83 | } |
84 | } |
85 | } |
86 | |
87 | bool RemoveDeadNodes(Graph* g) { |
88 | VLOG(2) << "Removing dead nodes" ; |
89 | std::unordered_set<const Node*> nodes; |
90 | for (auto n : g->nodes()) { |
91 | if (n->IsSource() || n->IsSink() || n->IsControlFlow() || |
92 | n->op_def().is_stateful()) { |
93 | nodes.insert(n); |
94 | } |
95 | } |
96 | return PruneForReverseReachability(g, std::move(nodes)); |
97 | } |
98 | |
99 | namespace { |
100 | // If 'edges' contains only 1 non-control edge, returns it. Otherwise, |
101 | // returns a nullptr. |
102 | const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) { |
103 | const Edge* ret = nullptr; |
104 | for (const Edge* e : edges) { |
105 | if (e->IsControlEdge() || ret) { |
106 | // Don't touch it if there is a control edge. |
107 | return nullptr; |
108 | } |
109 | if (IsRefType(e->src()->output_type(e->src_output()))) { |
110 | // Don't touch it if the identity node is effectively de-reffing |
111 | // a ref. |
112 | return nullptr; |
113 | } |
114 | if (IsRecv(e->src()) || IsSwitch(e->src())) { |
115 | // Don't touch it if the identity is introduced for control flow. |
116 | // Recv disables all its successors if it receives a dead signal. |
117 | // When Recv has an outgoing control edge, the current executor |
118 | // would not disable the destination. The current solution (see |
119 | // graph_partition.cc) is to add an identity after Recv and change |
120 | // the control edge to be from this identity node. So the identity |
121 | // can't be removed. |
122 | return nullptr; |
123 | } |
124 | ret = e; |
125 | } |
126 | return ret; |
127 | } |
128 | } // end namespace |
129 | |
130 | bool RemoveIdentityNodes(Graph* g) { |
131 | VLOG(2) << "Removing identity nodes" ; |
132 | bool removed_any = false; |
133 | gtl::InlinedVector<Node*, 8> matches; |
134 | for (Node* n : g->nodes()) { |
135 | if (!n->IsIdentity()) continue; |
136 | if (!GetTheOnlyDataEdge(n->in_edges())) continue; |
137 | |
138 | // Some identity nodes are used as sink nodes to give names to output |
139 | // tensors. These nodes are not going to be executed unless they are in the |
140 | // fetch set. But if they are in the fetch set we don't want to remove them. |
141 | if (n->out_edges().empty()) continue; |
142 | |
143 | matches.push_back(n); |
144 | } |
145 | if (!matches.empty()) { |
146 | for (Node* n : matches) { |
147 | const Edge* in = GetTheOnlyDataEdge(n->in_edges()); |
148 | for (const Edge* out : n->out_edges()) { |
149 | if (out->IsControlEdge()) { |
150 | g->AddControlEdge(in->src(), out->dst()); |
151 | } else { |
152 | g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input()); |
153 | } |
154 | } |
155 | VLOG(2) << "Remove Identity: " << n->DebugString(); |
156 | g->RemoveNode(n); |
157 | removed_any = true; |
158 | } |
159 | } |
160 | return removed_any; |
161 | } |
162 | |
163 | bool RemoveListArrayConverter(Graph* g) { |
164 | VLOG(2) << "Removing list array converter" ; |
165 | gtl::InlinedVector<Node*, 8> matches; |
166 | for (Node* n : g->nodes()) { |
167 | if ((n->type_string() == "_ListToArray" ) || |
168 | (n->type_string() == "_ArrayToList" )) { |
169 | matches.push_back(n); |
170 | } |
171 | } |
172 | bool removed_any = false; |
173 | if (!matches.empty()) { |
174 | for (Node* n : matches) { |
175 | if (n->num_inputs() != n->num_outputs()) { |
176 | continue; // Not expected. Skip. |
177 | } |
178 | gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr); |
179 | |
180 | const auto no_op = [&](StringPiece name) -> Node* { |
181 | return AddNoOp(absl::StrCat(n->name(), "/" , name), g); |
182 | }; |
183 | |
184 | const auto identity = [&](StringPiece name, Endpoint input) -> Node* { |
185 | Node* node = AddIdentity(absl::StrCat(n->name(), "/" , name), g, input); |
186 | node->set_requested_device(input.node->def().device()); |
187 | return node; |
188 | }; |
189 | |
190 | // Process input edges first. |
191 | Node* input_control_node = nullptr; |
192 | for (const Edge* e : n->in_edges()) { |
193 | if (e->IsControlEdge()) { |
194 | if (input_control_node == nullptr) { |
195 | // If node "n" has any control dependencies, adds a no-op |
196 | // node (input_control_node) which the additional Identity |
197 | // nodes depends on and the input_control_node depends on |
198 | // the node "n"s control dependencies. |
199 | input_control_node = no_op("input_control_node" ); |
200 | } |
201 | g->AddControlEdge(e->src(), input_control_node); |
202 | } else { |
203 | const int index = e->dst_input(); |
204 | Node** id_node = &identity_nodes[index]; |
205 | if (*id_node != nullptr) { |
206 | LOG(ERROR) |
207 | << "RemoveListArrayConverter unexpected duplicated input: " |
208 | << e->dst_input(); |
209 | return removed_any; |
210 | } |
211 | *id_node = identity("input" , {e->src(), e->src_output()}); |
212 | } |
213 | } |
214 | |
215 | // If node "n" has any control dependencies, the added identity |
216 | // nodes should have control dependencies on input_control_node. |
217 | if (input_control_node != nullptr) { |
218 | for (Node* id : identity_nodes) { |
219 | g->AddControlEdge(input_control_node, id); |
220 | } |
221 | } |
222 | |
223 | Node* output_control_node = nullptr; |
224 | for (const Edge* e : n->out_edges()) { |
225 | if (e->IsControlEdge()) { |
226 | if (output_control_node == nullptr) { |
227 | // If node "n" is control-depended upon by other nodes, |
228 | // adds a no-op node (output_control_node) which those |
229 | // nodes will depend on and output_control_node depends on |
230 | // all Identity nodes. |
231 | output_control_node = no_op("output_control_node" ); |
232 | } |
233 | g->AddControlEdge(output_control_node, e->dst()); |
234 | } else { |
235 | Node* id_node = identity_nodes[e->src_output()]; |
236 | if (id_node == nullptr) { |
237 | LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: " |
238 | << e->src_output(); |
239 | return removed_any; |
240 | } |
241 | CHECK(id_node); |
242 | g->AddEdge(id_node, 0, e->dst(), e->dst_input()); |
243 | } |
244 | } |
245 | |
246 | // If any nodes have control dependencies on node "n", those |
247 | // nodes should have control dependencies on |
248 | // output_control_node. |
249 | if (output_control_node != nullptr) { |
250 | for (Node* id : identity_nodes) { |
251 | g->AddControlEdge(id, output_control_node); |
252 | } |
253 | } |
254 | |
255 | g->RemoveNode(n); |
256 | removed_any = true; |
257 | } |
258 | } |
259 | return removed_any; |
260 | } |
261 | |
262 | Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, |
263 | NameAttrList* function) { |
264 | if (call_def.op() == "PartitionedCall" || |
265 | call_def.op() == "StatefulPartitionedCall" ) { |
266 | TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f" , function)); |
267 | } else { |
268 | function->set_name(call_def.op()); |
269 | *function->mutable_attr() = call_def.attr(); |
270 | } |
271 | return OkStatus(); |
272 | } |
273 | |
274 | Status InstantiateFunctionCall(const NodeDef& call_def, |
275 | FunctionLibraryRuntime* flr, |
276 | FunctionLibraryRuntime::Handle* handle) { |
277 | NameAttrList function; |
278 | TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function)); |
279 | return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle); |
280 | } |
281 | |
282 | bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, |
283 | const Node& node) { |
284 | return node.IsFunctionCall(); |
285 | } |
286 | |
287 | string NewName(const Node* n, bool pretty) { |
288 | if (pretty) { |
289 | return strings::StrCat(n->type_string(), n->id()); |
290 | } else { |
291 | return strings::StrCat("n" , n->id()); |
292 | } |
293 | } |
294 | |
295 | // TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef. |
296 | // and stash the original NodeDef name as an attr for documentation |
297 | // purpose. |
298 | void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { |
299 | // We visit nodes in forward topological sort order, which is a |
300 | // possible execution order of the graph. |
301 | gtl::InlinedVector<const Edge*, 4> inputs; |
302 | gdef->Clear(); |
303 | *gdef->mutable_versions() = g->versions(); |
304 | |
305 | std::vector<Node*> start_nodes; |
306 | for (Node* n : g->nodes()) { |
307 | if (n->out_edges().empty()) { |
308 | start_nodes.push_back(n); |
309 | } |
310 | } |
311 | |
312 | ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) { |
313 | if (!n->IsOp()) return; |
314 | NodeDef* ndef = gdef->add_node(); |
315 | ndef->set_name(NewName(n, pretty)); |
316 | ndef->set_op(n->type_string()); |
317 | for (const auto& attr : n->attrs()) { |
318 | (*ndef->mutable_attr())[attr.first] = attr.second; |
319 | } |
320 | |
321 | if (!n->assigned_device_name().empty()) { |
322 | ndef->set_device(n->assigned_device_name()); |
323 | } else { |
324 | ndef->set_device(n->requested_device()); |
325 | } |
326 | |
327 | inputs.clear(); |
328 | inputs.resize(n->num_inputs()); |
329 | for (const Edge* e : n->in_edges()) { |
330 | if (e->IsControlEdge()) { |
331 | inputs.push_back(e); |
332 | } else { |
333 | if (inputs[e->dst_input()] == nullptr) { |
334 | inputs[e->dst_input()] = e; |
335 | } else { |
336 | LOG(WARNING) << "Malformed graph node. multiple input edges: " |
337 | << n->DebugString(); |
338 | } |
339 | } |
340 | } |
341 | // node->name() is merely NodeDef::name, which are not guaranteed |
342 | // to be unique and stable after optimization rewrites. Therefore, |
343 | // we use "n<node id>" instead. |
344 | for (const Edge* e : inputs) { |
345 | if (e == nullptr) { |
346 | ndef->add_input("unknown" ); |
347 | continue; |
348 | } |
349 | const string srcname = NewName(e->src(), pretty); |
350 | if (!e->src()->IsOp()) { |
351 | } else if (e->IsControlEdge()) { |
352 | ndef->add_input(strings::StrCat("^" , srcname)); |
353 | } else if (e->src_output() == 0) { |
354 | ndef->add_input(srcname); |
355 | } else { |
356 | ndef->add_input(strings::StrCat(srcname, ":" , e->src_output())); |
357 | } |
358 | } |
359 | }); |
360 | } |
361 | |
362 | string DebugString(const Graph* g) { |
363 | GraphDef gdef; |
364 | ToGraphDef(g, &gdef); |
365 | return DebugString(gdef); |
366 | } |
367 | |
368 | } // end namespace tensorflow |
369 | |