/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/subgraph.h"

#include <algorithm>
#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace subgraph {

// ----------------------------------------------------------------------------
// Subgraph construction-related routines
// ----------------------------------------------------------------------------
// TODO(vrv): Profile the unordered_set and unordered_map use in this file to
// see if we should use an alternative implementation.

namespace {

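// Maps the name of each node in the graph to the corresponding Node*.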
typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex;

// Rewrites the graph by replacing each output tensor named by the entries of
// "feed_rewrites" with a special feed node created by the corresponding
// PruneRewrite, redirecting the original consumers of that tensor (and, for
// Placeholder nodes, their outgoing control edges) to the new node.
// "*name_index" is updated with the added nodes, and the base type of each
// fed tensor is appended to "*out_feed_types".
//
// Returns OK on success. On error, returns an appropriate error status
// (and *g is left in an indeterminate state).
Status FeedInputs(
    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
    NameIndex* name_index, DataTypeVector* out_feed_types) {
  out_feed_types->clear();
  out_feed_types->reserve(feed_rewrites.size());
  for (size_t i = 0; i < feed_rewrites.size(); ++i) {
    const string& t = feed_rewrites[i]->endpoint_name();
    TensorId id(ParseTensorName(t));

    auto iter = name_index->find(id.first);
    if (iter == name_index->end()) {
      return errors::NotFound("FeedInputs: unable to find feed output ", t);
    }
    Node* n = iter->second;
    DCHECK_EQ(n->name(), id.first);
    if (id.second >= n->num_outputs()) {
      return errors::InvalidArgument(
          "FeedInputs: ", t, " should have output index < ", n->num_outputs());
    }

    Node* feed_node;
    TF_RETURN_IF_ERROR(
        feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node));

    // Update name_index
    (*name_index)[feed_node->name()] = feed_node;
    // Duplicate control edges aren't allowed, but feed_node was *just*
    // created so there's no need to check for a duplicate.
    g->AddControlEdge(g->source_node(), feed_node, true);

    // Look through edges coming out of "n" for edges whose src_output() index
    // matches the fed output index, id.second. If found, replace the edges
    // with a connection from the special feed node.
    std::vector<const Edge*> to_remove;
    for (const Edge* e : n->out_edges()) {
      if (e->src_output() == id.second) {
        to_remove.emplace_back(e);
      } else if (e->src_output() == Graph::kControlSlot &&
                 (n->type_string() == "Placeholder" ||
                  n->type_string() == "PlaceholderV2")) {
        // When feeding a Placeholder node, any outgoing control edges
        // will be replaced with a control edge from the replacement
        // feed_node.
        // TODO(josh11b,mrry): Come up with a more elegant way of addressing
        // the general version of this problem.
        to_remove.emplace_back(e);
      }
    }

    for (const Edge* e : to_remove) {
      if (e->src_output() == id.second) {
        g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
      } else {
        CHECK_EQ(Graph::kControlSlot, e->src_output());
        // Duplicate control edges aren't allowed, but feed_node was *just*
        // created so there's no need to check for a duplicate.
        g->AddControlEdge(feed_node, e->dst(), true);
      }
      g->RemoveEdge(e);
    }
    out_feed_types->push_back(BaseType(n->output_type(id.second)));
  }
  return OkStatus();
}

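// Creates one fetch node per entry in "fetch_rewrites" by invoking the
// corresponding PruneRewrite on the named output tensor, and connects each
// new node to the graph's sink via a control edge. "*name_index" is updated
// with the added nodes; the fetch nodes and the base types of the fetched
// tensors are returned in "*out_fetch_nodes" and "*out_fetch_types".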
Status FetchOutputs(
    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
    NameIndex* name_index, std::vector<Node*>* out_fetch_nodes,
    DataTypeVector* out_fetch_types) {
  out_fetch_nodes->clear();
  out_fetch_nodes->reserve(fetch_rewrites.size());
  for (size_t i = 0; i < fetch_rewrites.size(); ++i) {
    const string& t = fetch_rewrites[i]->endpoint_name();

    // Parse t into node_name and output_index.
    TensorId id(ParseTensorName(t));

    // Find node in graph with that name.
    auto iter = name_index->find(id.first);
    if (iter == name_index->end()) {
      return errors::NotFound("FetchOutputs node ", t, ": not found");
    }
    Node* n = iter->second;
    DCHECK_EQ(n->name(), id.first);
    VLOG(2) << "Found fetch node for " << t;

    // Validate output_index
    if (n->num_outputs() == 0) {
      return errors::InvalidArgument(
          "Tried to fetch data for '", t,
          "', which produces no output. To run to a node but not fetch any "
          "data, pass '",
          t,
          "' as an argument to the 'target_node_names' argument of the "
          "Session::Run API.");
    } else if (id.second >= n->num_outputs()) {
      return errors::InvalidArgument("FetchOutputs ", t,
                                     ": output index too large, must be < ",
                                     n->num_outputs());
    }

    // Create the fetch Node and connect it up
    Node* fetch_node;
    TF_RETURN_IF_ERROR(
        fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node));

    // Update the index.
    (*name_index)[fetch_node->name()] = fetch_node;

    // Duplicate control edges aren't allowed, but fetch_node was *just*
    // created so there's no need to check for a duplicate.
    g->AddControlEdge(fetch_node, g->sink_node(), true);
    out_fetch_nodes->push_back(fetch_node);
    out_fetch_types->push_back(BaseType(n->output_type(id.second)));
  }

  return OkStatus();
}

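// Parses "node_or_tensor_name" into a node name, looks that name up in
// "name_index", and, if present, adds the corresponding node to "*targets".
// Returns false if no node with that name exists in the index.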
bool AddNodeToTargets(const string& node_or_tensor_name,
                      const NameIndex& name_index,
                      std::unordered_set<const Node*>* targets) {
  TensorId id = ParseTensorName(node_or_tensor_name);
  auto iter = name_index.find(id.first);
  if (iter == name_index.end()) {
    return false;
  }
  const Node* n = iter->second;
  CHECK_EQ(n->name(), id.first);
  targets->insert(n);
  return true;
}

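// Prunes "*g" down to the transitive fan-in of the nodes in "fetch_nodes" and
// "target_nodes", then reconnects the remaining graph to the source and sink
// nodes. Returns NotFound if any fetch or target node is missing from
// "name_index".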
Status PruneForTargets(Graph* g, const NameIndex& name_index,
                       const std::vector<Node*>& fetch_nodes,
                       const gtl::ArraySlice<string>& target_nodes) {
  string not_found;
  std::unordered_set<const Node*> targets;
  for (Node* n : fetch_nodes) {
    if (!AddNodeToTargets(n->name(), name_index, &targets)) {
      strings::StrAppend(&not_found, n->name(), " ");
    }
  }
  for (const string& s : target_nodes) {
    if (!AddNodeToTargets(s, name_index, &targets)) {
      strings::StrAppend(&not_found, s, " ");
    }
  }
  if (!not_found.empty()) {
    return errors::NotFound("PruneForTargets: Some target nodes not found: ",
                            not_found);
  }
  PruneForReverseReachability(g, std::move(targets));

  // Reconnect nodes with no outgoing edges to the sink node
  FixupSourceAndSinkEdges(g);

  return OkStatus();
}

}  // namespace

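// Adds an _Arg node that will supply the fed value for "feed_tensor", placed
// on the device described by device_info().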
Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                               Node** out_node) {
  // NOTE(mrry): We must include the index as part of the node
  // name, because _Arg is a "stateful" kernel and therefore
  // its name must uniquely identify a kernel instance across all
  // graphs in the same session.
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat("_arg_", feed_tensor.node->name(), "_",
                                  feed_tensor.index, "_", arg_index_),
                  "_Arg")
          .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
          .Attr("index", arg_index_)
          .Finalize(g, out_node, /*consume=*/true));
  (*out_node)->set_assigned_device_name(device_info().name());
  return OkStatus();
}

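// Adds a client-terminated _Recv node through which the fed value for
// "feed_tensor" will be delivered, placed on the device described by
// device_info().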
Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
                                Node** out_node) {
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_",
                                  feed_tensor.index),
                  "_Recv")
          .Attr("tensor_type",
                BaseType(feed_tensor.node->output_type(feed_tensor.index)))
          .Attr("tensor_name", endpoint_name())
          .Attr("send_device", device_info().name())
          .Attr("recv_device", device_info().name())
          .Attr("send_device_incarnation",
                static_cast<int64_t>(device_info().incarnation()))
          .Attr("client_terminated", true)
          .Finalize(g, out_node, /*consume=*/true));

  (*out_node)->set_assigned_device_name(device_info().name());
  return OkStatus();
}

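// Adds a _Retval node that consumes "fetch_tensor" so that its value can be
// returned to the caller, placed on the device described by device_info().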
Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
                                   Node** out_node) {
  // NOTE(mrry): We must include the index as part of the node
  // name, because _Retval is a "stateful" kernel and therefore
  // its name must uniquely identify a kernel instance across all
  // graphs in the same session.
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat("_retval_", fetch_tensor.node->name(), "_",
                                  fetch_tensor.index, "_", retval_index_),
                  "_Retval")
          .Input(fetch_tensor.node, fetch_tensor.index)
          .Attr("T",
                BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
          .Attr("index", retval_index_)
          .Finalize(g, out_node, /*consume=*/true));
  (*out_node)->set_assigned_device_name(device_info().name());
  return OkStatus();
}

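// Adds a client-terminated _Send node that consumes "fetch_tensor" so that
// its value can be delivered back to the client, placed on the device
// described by device_info().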
Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
                                 Node** out_node) {
  TF_RETURN_IF_ERROR(
      NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_",
                                  fetch_tensor.index),
                  "_Send")
          .Input(fetch_tensor.node, fetch_tensor.index)
          .Attr("tensor_name", endpoint_name())
          .Attr("send_device", device_info().name())
          .Attr("recv_device", device_info().name())
          .Attr("send_device_incarnation",
                static_cast<int64_t>(device_info().incarnation()))
          .Attr("client_terminated", true)
          .Finalize(g, out_node, /*consume=*/true));
  (*out_node)->set_assigned_device_name(device_info().name());
  return OkStatus();
}

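// Convenience overload that builds the standard feed and fetch rewrites from
// tensor names: _Arg/_Retval nodes when "use_function_convention" is true,
// and client-terminated _Recv/_Send nodes otherwise, all placed according to
// "device_info". The actual rewrite is performed by the overload below that
// takes PruneRewrite vectors.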
Status RewriteGraphForExecution(
    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
    const gtl::ArraySlice<string>& fetch_outputs,
    const gtl::ArraySlice<string>& target_node_names,
    const DeviceAttributes& device_info, bool use_function_convention,
    RewriteGraphMetadata* out_metadata) {
  std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites;
  feed_rewrites.reserve(fed_outputs.size());
  if (use_function_convention) {
    for (size_t i = 0; i < fed_outputs.size(); ++i) {
      feed_rewrites.emplace_back(new ArgFeedRewrite(
          &fed_outputs[i], &device_info, static_cast<int32>(i)));
    }
  } else {
    for (const string& fed_output : fed_outputs) {
      feed_rewrites.emplace_back(
          new RecvFeedRewrite(&fed_output, &device_info));
    }
  }

  std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites;
  fetch_rewrites.reserve(fetch_outputs.size());
  if (use_function_convention) {
    for (size_t i = 0; i < fetch_outputs.size(); ++i) {
      fetch_rewrites.emplace_back(new RetvalFetchRewrite(
          &fetch_outputs[i], &device_info, static_cast<int32>(i)));
    }
  } else {
    for (const string& fetch_output : fetch_outputs) {
      fetch_rewrites.emplace_back(
          new SendFetchRewrite(&fetch_output, &device_info));
    }
  }

  return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites,
                                  target_node_names, out_metadata);
}

namespace {
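// Copies the elements of any string container into a std::vector<string>.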
template <typename StringContainer>
std::vector<string> ConvertToVector(StringContainer field) {
  return std::vector<string>(field.begin(), field.end());
}
}  // namespace

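// Rewrites "*g" in place so that it can be executed: verifies that at least
// one fetch or target is given and that no endpoint is both fed and fetched,
// applies the feed and fetch rewrites (recording the fed and fetched types in
// "*out_metadata"), and finally prunes the graph to the fan-in of the fetch
// nodes and "target_node_names". For example, feeding "x:0" and fetching
// "y:0" redirects the consumers of "x:0" to a new feed node and attaches a
// fetch node to "y:0" before pruning.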
Status RewriteGraphForExecution(
    Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
    const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
    const gtl::ArraySlice<string>& target_node_names,
    RewriteGraphMetadata* out_metadata) {
  if (fetch_rewrites.empty() && target_node_names.empty()) {
    return errors::InvalidArgument(
        "Must specify at least one target to fetch or execute.");
  }

  std::unordered_set<string> endpoints;
  for (const auto& feed_rewrite : feed_rewrites) {
    auto result = endpoints.insert(feed_rewrite->endpoint_name());
    if (!result.second) {
      return errors::InvalidArgument("Endpoint \"",
                                     feed_rewrite->endpoint_name(),
                                     "\" fed more than once.");
    }
  }

  for (const auto& fetch_rewrite : fetch_rewrites) {
    if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) {
      return errors::InvalidArgument(fetch_rewrite->endpoint_name(),
                                     " is both fed and fetched.");
    }
  }

  // A separate index mapping name to Node*, for use by FeedInputs,
  // FetchOutputs, and PruneForTargets
  NameIndex name_index;
  name_index.reserve(g->num_nodes());
  for (Node* n : g->nodes()) {
    name_index[n->name()] = n;
  }

  // Add the feeds. This may replace nodes in the graph, including the nodes
  // currently listed in "fetch_rewrites". We pass "name_index" so the index is
  // kept up to date.
  if (!feed_rewrites.empty()) {
    TF_RETURN_IF_ERROR(
        FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types));
  }

  // Add the fetch nodes, also updating "name_index".
  std::vector<Node*> fetch_nodes;
  if (!fetch_rewrites.empty()) {
    TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index,
                                    &fetch_nodes, &out_metadata->fetch_types));
  }

  // Prune the graph to only compute what is needed for the fetch nodes and the
  // target nodes.
  if (!fetch_nodes.empty() || !target_node_names.empty()) {
    TF_RETURN_IF_ERROR(
        PruneForTargets(g, name_index, fetch_nodes, target_node_names));
  }

  return OkStatus();
}

}  // namespace subgraph

}  // namespace tensorflow