1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/function_utils.h"
17
18#include "tensorflow/core/common_runtime/function_body.h"
19#include "tensorflow/core/framework/function.h"
20#include "tensorflow/core/framework/graph.pb.h"
21#include "tensorflow/core/framework/node_def.pb.h"
22#include "tensorflow/core/framework/node_def_util.h"
23#include "tensorflow/core/framework/op_def.pb.h"
24#include "tensorflow/core/framework/versions.pb.h"
25#include "tensorflow/core/graph/algorithm.h"
26#include "tensorflow/core/graph/control_flow.h"
27#include "tensorflow/core/graph/graph.h"
28
29namespace tensorflow {
30
// Prefix used for the names of helper nodes that this file adds to a graph.
static constexpr char kNodeLabel[] = "Func";
32
33// Represents the index-th output of a node.
34struct Endpoint {
35 Node* node;
36 int index;
37
38 // Returns the string name represents this endpoint.
39 string name() const {
40 if (index == 0) {
41 return node->name();
42 } else {
43 return strings::StrCat(node->name(), ":", index);
44 }
45 }
46
47 DataType dtype() const { return node->output_type(index); }
48};
49
50// The following Add* routines are used to add a few graph nodes while
51// functions are transformed.
52static Node* AddNoOp(StringPiece name, Graph* g) {
53 NodeDef ndef;
54 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
55 ndef.set_op("NoOp");
56 Status s;
57 Node* ret = g->AddNode(ndef, &s);
58 TF_CHECK_OK(s);
59 return ret;
60}
61
62static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
63 DCHECK_LT(0, input.dtype());
64 NodeDef ndef;
65 ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
66 ndef.set_op("Identity");
67 ndef.add_input(input.name());
68 AddNodeAttr("T", BaseType(input.dtype()), &ndef);
69 Status s;
70 Node* ret = g->AddNode(ndef, &s);
71 TF_CHECK_OK(s);
72 g->AddEdge(input.node, input.index, ret, 0);
73 return ret;
74}
75
76void DumpGraph(StringPiece label, const Graph* g) {
77 // TODO(zhifengc): Change Graph to record #nodes.
78 VLOG(2) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
79 << g->num_edges();
80 if (VLOG_IS_ON(5)) {
81 for (const auto& line : str_util::Split(DebugString(g), '\n')) {
82 VLOG(5) << "|| " << line;
83 }
84 }
85}
86
87bool RemoveDeadNodes(Graph* g) {
88 VLOG(2) << "Removing dead nodes";
89 std::unordered_set<const Node*> nodes;
90 for (auto n : g->nodes()) {
91 if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
92 n->op_def().is_stateful()) {
93 nodes.insert(n);
94 }
95 }
96 return PruneForReverseReachability(g, std::move(nodes));
97}
98
99namespace {
100// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
101// returns a nullptr.
102const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
103 const Edge* ret = nullptr;
104 for (const Edge* e : edges) {
105 if (e->IsControlEdge() || ret) {
106 // Don't touch it if there is a control edge.
107 return nullptr;
108 }
109 if (IsRefType(e->src()->output_type(e->src_output()))) {
110 // Don't touch it if the identity node is effectively de-reffing
111 // a ref.
112 return nullptr;
113 }
114 if (IsRecv(e->src()) || IsSwitch(e->src())) {
115 // Don't touch it if the identity is introduced for control flow.
116 // Recv disables all its successors if it receives a dead signal.
117 // When Recv has an outgoing control edge, the current executor
118 // would not disable the destination. The current solution (see
119 // graph_partition.cc) is to add an identity after Recv and change
120 // the control edge to be from this identity node. So the identity
121 // can't be removed.
122 return nullptr;
123 }
124 ret = e;
125 }
126 return ret;
127}
128} // end namespace
129
130bool RemoveIdentityNodes(Graph* g) {
131 VLOG(2) << "Removing identity nodes";
132 bool removed_any = false;
133 gtl::InlinedVector<Node*, 8> matches;
134 for (Node* n : g->nodes()) {
135 if (!n->IsIdentity()) continue;
136 if (!GetTheOnlyDataEdge(n->in_edges())) continue;
137
138 // Some identity nodes are used as sink nodes to give names to output
139 // tensors. These nodes are not going to be executed unless they are in the
140 // fetch set. But if they are in the fetch set we don't want to remove them.
141 if (n->out_edges().empty()) continue;
142
143 matches.push_back(n);
144 }
145 if (!matches.empty()) {
146 for (Node* n : matches) {
147 const Edge* in = GetTheOnlyDataEdge(n->in_edges());
148 for (const Edge* out : n->out_edges()) {
149 if (out->IsControlEdge()) {
150 g->AddControlEdge(in->src(), out->dst());
151 } else {
152 g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
153 }
154 }
155 VLOG(2) << "Remove Identity: " << n->DebugString();
156 g->RemoveNode(n);
157 removed_any = true;
158 }
159 }
160 return removed_any;
161}
162
163bool RemoveListArrayConverter(Graph* g) {
164 VLOG(2) << "Removing list array converter";
165 gtl::InlinedVector<Node*, 8> matches;
166 for (Node* n : g->nodes()) {
167 if ((n->type_string() == "_ListToArray") ||
168 (n->type_string() == "_ArrayToList")) {
169 matches.push_back(n);
170 }
171 }
172 bool removed_any = false;
173 if (!matches.empty()) {
174 for (Node* n : matches) {
175 if (n->num_inputs() != n->num_outputs()) {
176 continue; // Not expected. Skip.
177 }
178 gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
179
180 const auto no_op = [&](StringPiece name) -> Node* {
181 return AddNoOp(absl::StrCat(n->name(), "/", name), g);
182 };
183
184 const auto identity = [&](StringPiece name, Endpoint input) -> Node* {
185 Node* node = AddIdentity(absl::StrCat(n->name(), "/", name), g, input);
186 node->set_requested_device(input.node->def().device());
187 return node;
188 };
189
190 // Process input edges first.
191 Node* input_control_node = nullptr;
192 for (const Edge* e : n->in_edges()) {
193 if (e->IsControlEdge()) {
194 if (input_control_node == nullptr) {
195 // If node "n" has any control dependencies, adds a no-op
196 // node (input_control_node) which the additional Identity
197 // nodes depends on and the input_control_node depends on
198 // the node "n"s control dependencies.
199 input_control_node = no_op("input_control_node");
200 }
201 g->AddControlEdge(e->src(), input_control_node);
202 } else {
203 const int index = e->dst_input();
204 Node** id_node = &identity_nodes[index];
205 if (*id_node != nullptr) {
206 LOG(ERROR)
207 << "RemoveListArrayConverter unexpected duplicated input: "
208 << e->dst_input();
209 return removed_any;
210 }
211 *id_node = identity("input", {e->src(), e->src_output()});
212 }
213 }
214
215 // If node "n" has any control dependencies, the added identity
216 // nodes should have control dependencies on input_control_node.
217 if (input_control_node != nullptr) {
218 for (Node* id : identity_nodes) {
219 g->AddControlEdge(input_control_node, id);
220 }
221 }
222
223 Node* output_control_node = nullptr;
224 for (const Edge* e : n->out_edges()) {
225 if (e->IsControlEdge()) {
226 if (output_control_node == nullptr) {
227 // If node "n" is control-depended upon by other nodes,
228 // adds a no-op node (output_control_node) which those
229 // nodes will depend on and output_control_node depends on
230 // all Identity nodes.
231 output_control_node = no_op("output_control_node");
232 }
233 g->AddControlEdge(output_control_node, e->dst());
234 } else {
235 Node* id_node = identity_nodes[e->src_output()];
236 if (id_node == nullptr) {
237 LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
238 << e->src_output();
239 return removed_any;
240 }
241 CHECK(id_node);
242 g->AddEdge(id_node, 0, e->dst(), e->dst_input());
243 }
244 }
245
246 // If any nodes have control dependencies on node "n", those
247 // nodes should have control dependencies on
248 // output_control_node.
249 if (output_control_node != nullptr) {
250 for (Node* id : identity_nodes) {
251 g->AddControlEdge(id, output_control_node);
252 }
253 }
254
255 g->RemoveNode(n);
256 removed_any = true;
257 }
258 }
259 return removed_any;
260}
261
262Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
263 NameAttrList* function) {
264 if (call_def.op() == "PartitionedCall" ||
265 call_def.op() == "StatefulPartitionedCall") {
266 TF_RETURN_IF_ERROR(GetNodeAttr(call_def, "f", function));
267 } else {
268 function->set_name(call_def.op());
269 *function->mutable_attr() = call_def.attr();
270 }
271 return OkStatus();
272}
273
274Status InstantiateFunctionCall(const NodeDef& call_def,
275 FunctionLibraryRuntime* flr,
276 FunctionLibraryRuntime::Handle* handle) {
277 NameAttrList function;
278 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(call_def, &function));
279 return flr->Instantiate(function.name(), AttrSlice(&function.attr()), handle);
280}
281
282bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
283 const Node& node) {
284 return node.IsFunctionCall();
285}
286
287string NewName(const Node* n, bool pretty) {
288 if (pretty) {
289 return strings::StrCat(n->type_string(), n->id());
290 } else {
291 return strings::StrCat("n", n->id());
292 }
293}
294
295// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
296// and stash the original NodeDef name as an attr for documentation
297// purpose.
298void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
299 // We visit nodes in forward topological sort order, which is a
300 // possible execution order of the graph.
301 gtl::InlinedVector<const Edge*, 4> inputs;
302 gdef->Clear();
303 *gdef->mutable_versions() = g->versions();
304
305 std::vector<Node*> start_nodes;
306 for (Node* n : g->nodes()) {
307 if (n->out_edges().empty()) {
308 start_nodes.push_back(n);
309 }
310 }
311
312 ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
313 if (!n->IsOp()) return;
314 NodeDef* ndef = gdef->add_node();
315 ndef->set_name(NewName(n, pretty));
316 ndef->set_op(n->type_string());
317 for (const auto& attr : n->attrs()) {
318 (*ndef->mutable_attr())[attr.first] = attr.second;
319 }
320
321 if (!n->assigned_device_name().empty()) {
322 ndef->set_device(n->assigned_device_name());
323 } else {
324 ndef->set_device(n->requested_device());
325 }
326
327 inputs.clear();
328 inputs.resize(n->num_inputs());
329 for (const Edge* e : n->in_edges()) {
330 if (e->IsControlEdge()) {
331 inputs.push_back(e);
332 } else {
333 if (inputs[e->dst_input()] == nullptr) {
334 inputs[e->dst_input()] = e;
335 } else {
336 LOG(WARNING) << "Malformed graph node. multiple input edges: "
337 << n->DebugString();
338 }
339 }
340 }
341 // node->name() is merely NodeDef::name, which are not guaranteed
342 // to be unique and stable after optimization rewrites. Therefore,
343 // we use "n<node id>" instead.
344 for (const Edge* e : inputs) {
345 if (e == nullptr) {
346 ndef->add_input("unknown");
347 continue;
348 }
349 const string srcname = NewName(e->src(), pretty);
350 if (!e->src()->IsOp()) {
351 } else if (e->IsControlEdge()) {
352 ndef->add_input(strings::StrCat("^", srcname));
353 } else if (e->src_output() == 0) {
354 ndef->add_input(srcname);
355 } else {
356 ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
357 }
358 }
359 });
360}
361
362string DebugString(const Graph* g) {
363 GraphDef gdef;
364 ToGraphDef(g, &gdef);
365 return DebugString(gdef);
366}
367
368} // end namespace tensorflow
369