1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/graph/subgraph.h" |
17 | |
#include <algorithm>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
24 | |
25 | #include "tensorflow/core/framework/graph.pb.h" |
26 | #include "tensorflow/core/framework/node_def_util.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/graph/algorithm.h" |
29 | #include "tensorflow/core/graph/graph.h" |
30 | #include "tensorflow/core/graph/tensor_id.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/lib/core/status.h" |
33 | #include "tensorflow/core/platform/logging.h" |
34 | |
35 | namespace tensorflow { |
36 | namespace subgraph { |
37 | |
38 | // ---------------------------------------------------------------------------- |
39 | // Subgraph construction-related routines |
40 | // ---------------------------------------------------------------------------- |
41 | // TODO(vrv): Profile the unordered_set and unordered_map use in this file to |
42 | // see if we should use an alternative implementation. |
43 | |
44 | namespace { |
45 | |
46 | typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameIndex; |
47 | |
48 | // Rewrite graph by replacing the output tensors specified in |
49 | // "fed_outputs" with special feed nodes for each specified output |
50 | // tensor, and removing any nodes that are now disconnected from the |
51 | // part of the graph that reaches the sink node. The set of special |
52 | // feed nodes added to the graph are returned in "*feed_nodes". |
53 | // |
// Returns OkStatus() on success. On error, returns an error status
// describing the problem (and *g is left in an indeterminate
// state).
57 | Status FeedInputs( |
58 | Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites, |
59 | NameIndex* name_index, DataTypeVector* out_feed_types) { |
60 | out_feed_types->clear(); |
61 | out_feed_types->reserve(feed_rewrites.size()); |
62 | for (size_t i = 0; i < feed_rewrites.size(); ++i) { |
63 | const string& t = feed_rewrites[i]->endpoint_name(); |
64 | TensorId id(ParseTensorName(t)); |
65 | |
66 | auto iter = name_index->find(id.first); |
67 | if (iter == name_index->end()) { |
68 | return errors::NotFound("FeedInputs: unable to find feed output " , t); |
69 | } |
70 | Node* n = iter->second; |
71 | DCHECK_EQ(n->name(), id.first); |
72 | if (id.second >= n->num_outputs()) { |
73 | return errors::InvalidArgument( |
74 | "FeedInputs: " , t, " should have output index < " , n->num_outputs()); |
75 | } |
76 | |
77 | Node* feed_node; |
78 | TF_RETURN_IF_ERROR( |
79 | feed_rewrites[i]->AddNode(g, {n, id.second}, &feed_node)); |
80 | |
81 | // Update name_index |
82 | (*name_index)[feed_node->name()] = feed_node; |
83 | // Duplicate control edges aren't allowed, but feed_node was *just* created |
84 | // so there's no need to check for a duplicate. |
85 | g->AddControlEdge(g->source_node(), feed_node, true); |
86 | |
87 | // Look through edges coming out of "n" for edges whose src_output() index |
88 | // matches "output_index". If found, replace the edges with a connection |
89 | // from the special feed node. |
90 | std::vector<const Edge*> to_remove; |
91 | for (const Edge* e : n->out_edges()) { |
92 | if (e->src_output() == id.second) { |
93 | to_remove.emplace_back(e); |
94 | } else if (e->src_output() == Graph::kControlSlot && |
95 | (n->type_string() == "Placeholder" || |
96 | n->type_string() == "PlaceholderV2" )) { |
97 | // When feeding a Placeholder node, any outgoing control edges |
98 | // will be replaced with a control edge from the replacement |
99 | // feed_node. |
100 | // TODO(josh11b,mrry): Come up with a more elegant way of addressing |
101 | // the general version of this problem. |
102 | to_remove.emplace_back(e); |
103 | } |
104 | } |
105 | |
106 | for (const Edge* e : to_remove) { |
107 | if (e->src_output() == id.second) { |
108 | g->AddEdge(feed_node, 0, e->dst(), e->dst_input()); |
109 | } else { |
110 | CHECK_EQ(Graph::kControlSlot, e->src_output()); |
111 | // Duplicate control edges aren't allowed, but feed_node was *just* |
112 | // created so there's no need to check for a duplicate. |
113 | g->AddControlEdge(feed_node, e->dst(), true); |
114 | } |
115 | g->RemoveEdge(e); |
116 | } |
117 | out_feed_types->push_back(BaseType(n->output_type(id.second))); |
118 | } |
119 | return OkStatus(); |
120 | } |
121 | |
122 | Status FetchOutputs( |
123 | Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites, |
124 | NameIndex* name_index, std::vector<Node*>* out_fetch_nodes, |
125 | DataTypeVector* out_fetch_types) { |
126 | out_fetch_nodes->clear(); |
127 | out_fetch_nodes->reserve(fetch_rewrites.size()); |
128 | for (size_t i = 0; i < fetch_rewrites.size(); ++i) { |
129 | const string& t = fetch_rewrites[i]->endpoint_name(); |
130 | |
131 | // Parse t into node_name and output_index. |
132 | TensorId id(ParseTensorName(t)); |
133 | |
134 | // Find node in graph with that name. |
135 | auto iter = name_index->find(id.first); |
136 | if (iter == name_index->end()) { |
137 | return errors::NotFound("FetchOutputs node " , t, ": not found" ); |
138 | } |
139 | Node* n = iter->second; |
140 | DCHECK_EQ(n->name(), id.first); |
141 | VLOG(2) << "Found fetch node for " << t; |
142 | |
143 | // Validate output_index |
144 | if (n->num_outputs() == 0) { |
145 | return errors::InvalidArgument( |
146 | "Tried to fetch data for '" , t, |
147 | "', which produces no output. To run to a node but not fetch any " |
148 | "data, pass '" , |
149 | t, |
150 | "' as an argument to the 'target_node_names' argument of the " |
151 | "Session::Run API." ); |
152 | } else if (id.second >= n->num_outputs()) { |
153 | return errors::InvalidArgument("FetchOutputs " , t, |
154 | ": output index too large, must be < " , |
155 | n->num_outputs()); |
156 | } |
157 | |
158 | // Create the fetch Node and connect it up |
159 | Node* fetch_node; |
160 | TF_RETURN_IF_ERROR( |
161 | fetch_rewrites[i]->AddNode(g, {n, id.second}, &fetch_node)); |
162 | |
163 | // Update the index. |
164 | (*name_index)[fetch_node->name()] = fetch_node; |
165 | |
166 | // Duplicate control edges aren't allowed, but fetch_node was *just* created |
167 | // so there's no need to check for a duplicate. |
168 | g->AddControlEdge(fetch_node, g->sink_node(), true); |
169 | out_fetch_nodes->push_back(fetch_node); |
170 | out_fetch_types->push_back(BaseType(n->output_type(id.second))); |
171 | } |
172 | |
173 | return OkStatus(); |
174 | } |
175 | |
176 | bool AddNodeToTargets(const string& node_or_tensor_name, |
177 | const NameIndex& name_index, |
178 | std::unordered_set<const Node*>* targets) { |
179 | TensorId id = ParseTensorName(node_or_tensor_name); |
180 | auto iter = name_index.find(id.first); |
181 | if (iter == name_index.end()) { |
182 | return false; |
183 | } |
184 | const Node* n = iter->second; |
185 | CHECK_EQ(n->name(), id.first); |
186 | targets->insert(n); |
187 | return true; |
188 | } |
189 | |
190 | Status PruneForTargets(Graph* g, const NameIndex& name_index, |
191 | const std::vector<Node*>& fetch_nodes, |
192 | const gtl::ArraySlice<string>& target_nodes) { |
193 | string not_found; |
194 | std::unordered_set<const Node*> targets; |
195 | for (Node* n : fetch_nodes) { |
196 | if (!AddNodeToTargets(n->name(), name_index, &targets)) { |
197 | strings::StrAppend(¬_found, n->name(), " " ); |
198 | } |
199 | } |
200 | for (const string& s : target_nodes) { |
201 | if (!AddNodeToTargets(s, name_index, &targets)) { |
202 | strings::StrAppend(¬_found, s, " " ); |
203 | } |
204 | } |
205 | if (!not_found.empty()) { |
206 | return errors::NotFound("PruneForTargets: Some target nodes not found: " , |
207 | not_found); |
208 | } |
209 | PruneForReverseReachability(g, std::move(targets)); |
210 | |
211 | // Reconnect nodes with no outgoing edges to the sink node |
212 | FixupSourceAndSinkEdges(g); |
213 | |
214 | return OkStatus(); |
215 | } |
216 | |
217 | } // namespace |
218 | |
219 | Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, |
220 | Node** out_node) { |
221 | // NOTE(mrry): We must include the index as part of the node |
222 | // name, because _Arg is a "stateful" kernel and therefore |
223 | // its name must uniquely identify a kernel instance across all |
224 | // graphs in the same session. |
225 | TF_RETURN_IF_ERROR( |
226 | NodeBuilder(strings::StrCat("_arg_" , feed_tensor.node->name(), "_" , |
227 | feed_tensor.index, "_" , arg_index_), |
228 | "_Arg" ) |
229 | .Attr("T" , BaseType(feed_tensor.node->output_type(feed_tensor.index))) |
230 | .Attr("index" , arg_index_) |
231 | .Finalize(g, out_node, /*consume=*/true)); |
232 | (*out_node)->set_assigned_device_name(device_info().name()); |
233 | return OkStatus(); |
234 | } |
235 | |
236 | Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, |
237 | Node** out_node) { |
238 | TF_RETURN_IF_ERROR( |
239 | NodeBuilder(strings::StrCat("_recv_" , feed_tensor.node->name(), "_" , |
240 | feed_tensor.index), |
241 | "_Recv" ) |
242 | .Attr("tensor_type" , |
243 | BaseType(feed_tensor.node->output_type(feed_tensor.index))) |
244 | .Attr("tensor_name" , endpoint_name()) |
245 | .Attr("send_device" , device_info().name()) |
246 | .Attr("recv_device" , device_info().name()) |
247 | .Attr("send_device_incarnation" , |
248 | static_cast<int64_t>(device_info().incarnation())) |
249 | .Attr("client_terminated" , true) |
250 | .Finalize(g, out_node, /*consume=*/true)); |
251 | |
252 | (*out_node)->set_assigned_device_name(device_info().name()); |
253 | return OkStatus(); |
254 | } |
255 | |
256 | Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, |
257 | Node** out_node) { |
258 | // NOTE(mrry): We must include the index as part of the node |
259 | // name, because _Retval is a "stateful" kernel and therefore |
260 | // its name must uniquely identify a kernel instance across all |
261 | // graphs in the same session. |
262 | TF_RETURN_IF_ERROR( |
263 | NodeBuilder(strings::StrCat("_retval_" , fetch_tensor.node->name(), "_" , |
264 | fetch_tensor.index, "_" , retval_index_), |
265 | "_Retval" ) |
266 | .Input(fetch_tensor.node, fetch_tensor.index) |
267 | .Attr("T" , |
268 | BaseType(fetch_tensor.node->output_type(fetch_tensor.index))) |
269 | .Attr("index" , retval_index_) |
270 | .Finalize(g, out_node, /*consume=*/true)); |
271 | (*out_node)->set_assigned_device_name(device_info().name()); |
272 | return OkStatus(); |
273 | } |
274 | |
275 | Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, |
276 | Node** out_node) { |
277 | TF_RETURN_IF_ERROR( |
278 | NodeBuilder(strings::StrCat("_send_" , fetch_tensor.node->name(), "_" , |
279 | fetch_tensor.index), |
280 | "_Send" ) |
281 | .Input(fetch_tensor.node, fetch_tensor.index) |
282 | .Attr("tensor_name" , endpoint_name()) |
283 | .Attr("send_device" , device_info().name()) |
284 | .Attr("recv_device" , device_info().name()) |
285 | .Attr("send_device_incarnation" , |
286 | static_cast<int64_t>(device_info().incarnation())) |
287 | .Attr("client_terminated" , true) |
288 | .Finalize(g, out_node, /*consume=*/true)); |
289 | (*out_node)->set_assigned_device_name(device_info().name()); |
290 | return OkStatus(); |
291 | } |
292 | |
293 | Status RewriteGraphForExecution( |
294 | Graph* g, const gtl::ArraySlice<string>& fed_outputs, |
295 | const gtl::ArraySlice<string>& fetch_outputs, |
296 | const gtl::ArraySlice<string>& target_node_names, |
297 | const DeviceAttributes& device_info, bool use_function_convention, |
298 | RewriteGraphMetadata* out_metadata) { |
299 | std::vector<std::unique_ptr<PruneRewrite>> feed_rewrites; |
300 | feed_rewrites.reserve(fed_outputs.size()); |
301 | if (use_function_convention) { |
302 | for (size_t i = 0; i < fed_outputs.size(); ++i) { |
303 | feed_rewrites.emplace_back(new ArgFeedRewrite( |
304 | &fed_outputs[i], &device_info, static_cast<int32>(i))); |
305 | } |
306 | } else { |
307 | for (const string& fed_output : fed_outputs) { |
308 | feed_rewrites.emplace_back( |
309 | new RecvFeedRewrite(&fed_output, &device_info)); |
310 | } |
311 | } |
312 | |
313 | std::vector<std::unique_ptr<PruneRewrite>> fetch_rewrites; |
314 | fetch_rewrites.reserve(fetch_outputs.size()); |
315 | if (use_function_convention) { |
316 | for (size_t i = 0; i < fetch_outputs.size(); ++i) { |
317 | fetch_rewrites.emplace_back(new RetvalFetchRewrite( |
318 | &fetch_outputs[i], &device_info, static_cast<int32>(i))); |
319 | } |
320 | } else { |
321 | for (const string& fetch_output : fetch_outputs) { |
322 | fetch_rewrites.emplace_back( |
323 | new SendFetchRewrite(&fetch_output, &device_info)); |
324 | } |
325 | } |
326 | |
327 | return RewriteGraphForExecution(g, feed_rewrites, fetch_rewrites, |
328 | target_node_names, out_metadata); |
329 | } |
330 | |
331 | namespace { |
332 | template <typename StringContainer> |
333 | std::vector<string> ConvertToVector(StringContainer field) { |
334 | return std::vector<string>(field.begin(), field.end()); |
335 | } |
336 | } // namespace |
337 | |
338 | Status RewriteGraphForExecution( |
339 | Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites, |
340 | const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites, |
341 | const gtl::ArraySlice<string>& target_node_names, |
342 | RewriteGraphMetadata* out_metadata) { |
343 | if (fetch_rewrites.empty() && target_node_names.empty()) { |
344 | return errors::InvalidArgument( |
345 | "Must specify at least one target to fetch or execute." ); |
346 | } |
347 | |
348 | std::unordered_set<string> endpoints; |
349 | for (const auto& feed_rewrite : feed_rewrites) { |
350 | auto result = endpoints.insert(feed_rewrite->endpoint_name()); |
351 | if (!result.second) { |
352 | return errors::InvalidArgument("Endpoint \"" , |
353 | feed_rewrite->endpoint_name(), |
354 | "\" fed more than once." ); |
355 | } |
356 | } |
357 | |
358 | for (const auto& fetch_rewrite : fetch_rewrites) { |
359 | if (endpoints.count(fetch_rewrite->endpoint_name()) > 0) { |
360 | return errors::InvalidArgument(fetch_rewrite->endpoint_name(), |
361 | " is both fed and fetched." ); |
362 | } |
363 | } |
364 | |
365 | // A separate index mapping name to Node*, for use by FeedInputs, |
366 | // FetchOutputs, and PruneForTargets |
367 | NameIndex name_index; |
368 | name_index.reserve(g->num_nodes()); |
369 | for (Node* n : g->nodes()) { |
370 | name_index[n->name()] = n; |
371 | } |
372 | |
373 | // Add the feeds. This may replace nodes in the graph, including the nodes |
374 | // currently listed in "fetch_rewrites". We pass "name_index" so the index is |
375 | // kept up to date. |
376 | if (!feed_rewrites.empty()) { |
377 | TF_RETURN_IF_ERROR( |
378 | FeedInputs(g, feed_rewrites, &name_index, &out_metadata->feed_types)); |
379 | } |
380 | |
381 | // Add the fetch nodes, also updating "name_index". |
382 | std::vector<Node*> fetch_nodes; |
383 | if (!fetch_rewrites.empty()) { |
384 | TF_RETURN_IF_ERROR(FetchOutputs(g, fetch_rewrites, &name_index, |
385 | &fetch_nodes, &out_metadata->fetch_types)); |
386 | } |
387 | |
388 | // Prune the graph to only compute what is needed for the fetch nodes and the |
389 | // target nodes. |
390 | if (!fetch_nodes.empty() || !target_node_names.empty()) { |
391 | TF_RETURN_IF_ERROR( |
392 | PruneForTargets(g, name_index, fetch_nodes, target_node_names)); |
393 | } |
394 | |
395 | return OkStatus(); |
396 | } |
397 | |
398 | } // namespace subgraph |
399 | |
400 | } // namespace tensorflow |
401 | |