1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
17#define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
18
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/protobuf/config.pb.h"
27
28namespace tensorflow {
29namespace subgraph {
30
31// Information about a graph rewritten by `RewriteGraphForExecution()`.
32struct RewriteGraphMetadata {
33 // The element type of each tensor fed to this subgraph. The order
34 // of types corresponds to the order of tensor names in
35 // `fed_outputs` when calling `RewriteGraphForExecution()`.
36 DataTypeVector feed_types;
37 // The element type of each tensor fetched from this subgraph. The
38 // order of types corresponds to the order of tensor names in
39 // `fetch_outputs` when calling `RewriteGraphForExecution()`.
40 DataTypeVector fetch_types;
41};
42
43// Describes the action to take on a particular tensor endpoint (described by
44// a "<node_name>:<output_index>" pair) when pruning the graph.
45//
46// The `AddNode()` method must be overridden to describe this action. The method
47// will be invoked once during `RewriteGraphForExecution()` with tensor endpoint
48// named by `endpoint_name`, and it may either create a single new node, or fail
49// with an error if the resulting graph would be invalid.
50class PruneRewrite {
51 public:
52 // `endpoint_name` and `device_info` must outlive this object.
53 PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info)
54 : endpoint_name_(endpoint_name), device_info_(device_info) {}
55 virtual ~PruneRewrite() {}
56
57 // Creates a new node whose output replaces the given `tensor` in graph `g`.
58 // The node will be assigned to the device named in `device_info`.
59 virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor,
60 Node** out_node) = 0;
61
62 // Returns the name of the tensor to which this rewrite applies.
63 const string& endpoint_name() { return *endpoint_name_; }
64
65 protected:
66 // The device on which the new node will be created.
67 const DeviceAttributes& device_info() { return *device_info_; }
68
69 private:
70 const string* const endpoint_name_; // Not owned.
71 const DeviceAttributes* const device_info_; // Not owned.
72};
73
74// Rewrite the graph structure of "*g" to deal with feeding node
75// outputs, fetching node outputs, and only running a subset of the
76// graph. "fed_outputs" and "fetch_outputs" are both lists of
77// output tensor identifiers in the form of
78// "<name>[:<optional_output_index>]", and "target_nodes_str" is a
79// lists of target node names in "*g" "g".
80//
81// In the resulting graph "*g", output edges in "fed_outputs" have
82// been redirected to special "_recv" nodes introduced into the graph.
83// If these fed nodes are not needed in order to compute the effects
84// of the nodes in "target_node_names" and "fetch_outputs", then these may
85// be omitted from the graph.
86//
87// In the resulting graph "*g", additional "_send" nodes are connected
88// to every output in "fetch_outputs". These "_send" nodes are set up
89// to execute on the device described by device_info.
90//
91// On success, returns OK, and sets "*g" to a version of "*g"
92// that represents the portions of the graph necessary for producing
93// the output of all nodes listed in "target_node_names" and fetching the
94// specific node outputs specified in "fetch_outputs".
95//
96// On failure, returns the error status. Possible errors include:
97// - fed output "node:output_index" does not exist in "*g"
98// - fetch output "node:output_index" does not exist in "*g"
99// - target node "node" does not exist in "*g"
100Status RewriteGraphForExecution(
101 Graph* g, const gtl::ArraySlice<string>& fed_outputs,
102 const gtl::ArraySlice<string>& fetch_outputs,
103 const gtl::ArraySlice<string>& target_node_names,
104 const DeviceAttributes& device_info, bool use_function_convention,
105 RewriteGraphMetadata* out_metadata);
106
107// A more general version of the above function that supports
108// customizable rewriting actions for each fed and fetched tensor.
109Status RewriteGraphForExecution(
110 Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites,
111 const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites,
112 const gtl::ArraySlice<string>& target_node_names,
113 RewriteGraphMetadata* out_metadata);
114
115/////////////////////////////////////////////////////////
116// Custom rewrite actions for fed and fetched tensors. //
117/////////////////////////////////////////////////////////
118
119// A rewrite action that adds an _Arg node for a fed tensor.
120class ArgFeedRewrite : public PruneRewrite {
121 public:
122 ArgFeedRewrite(const string* endpoint_name,
123 const DeviceAttributes* device_info, int32_t arg_index)
124 : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {}
125 Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
126 Node** out_node) override;
127
128 private:
129 const int32 arg_index_;
130};
131
132// A rewrite action that adds a client-terminated _Recv node for a fed tensor.
133class RecvFeedRewrite : public PruneRewrite {
134 public:
135 using PruneRewrite::PruneRewrite;
136 Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
137 Node** out_node) override;
138};
139
140// A rewrite action that adds a _Retval node for a fetched tensor.
141class RetvalFetchRewrite : public PruneRewrite {
142 public:
143 RetvalFetchRewrite(const string* endpoint_name,
144 const DeviceAttributes* device_info, int32_t retval_index)
145 : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {}
146 Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
147 Node** out_node) override;
148
149 private:
150 const int32 retval_index_;
151};
152
153// A rewrite action that adds a client-terminated _Send node for a
154// fetched tensor.
155class SendFetchRewrite : public PruneRewrite {
156 public:
157 using PruneRewrite::PruneRewrite;
158 Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
159 Node** out_node) override;
160};
161
162} // namespace subgraph
163} // namespace tensorflow
164
165#endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
166