1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ |
17 | #define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ |
18 | |
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/protobuf/config.pb.h"
27 | |
28 | namespace tensorflow { |
29 | namespace subgraph { |
30 | |
31 | // Information about a graph rewritten by `RewriteGraphForExecution()`. |
32 | struct RewriteGraphMetadata { |
33 | // The element type of each tensor fed to this subgraph. The order |
34 | // of types corresponds to the order of tensor names in |
35 | // `fed_outputs` when calling `RewriteGraphForExecution()`. |
36 | DataTypeVector feed_types; |
37 | // The element type of each tensor fetched from this subgraph. The |
38 | // order of types corresponds to the order of tensor names in |
39 | // `fetch_outputs` when calling `RewriteGraphForExecution()`. |
40 | DataTypeVector fetch_types; |
41 | }; |
42 | |
43 | // Describes the action to take on a particular tensor endpoint (described by |
44 | // a "<node_name>:<output_index>" pair) when pruning the graph. |
45 | // |
46 | // The `AddNode()` method must be overridden to describe this action. The method |
47 | // will be invoked once during `RewriteGraphForExecution()` with tensor endpoint |
48 | // named by `endpoint_name`, and it may either create a single new node, or fail |
49 | // with an error if the resulting graph would be invalid. |
50 | class PruneRewrite { |
51 | public: |
52 | // `endpoint_name` and `device_info` must outlive this object. |
53 | PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info) |
54 | : endpoint_name_(endpoint_name), device_info_(device_info) {} |
55 | virtual ~PruneRewrite() {} |
56 | |
57 | // Creates a new node whose output replaces the given `tensor` in graph `g`. |
58 | // The node will be assigned to the device named in `device_info`. |
59 | virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor, |
60 | Node** out_node) = 0; |
61 | |
62 | // Returns the name of the tensor to which this rewrite applies. |
63 | const string& endpoint_name() { return *endpoint_name_; } |
64 | |
65 | protected: |
66 | // The device on which the new node will be created. |
67 | const DeviceAttributes& device_info() { return *device_info_; } |
68 | |
69 | private: |
70 | const string* const endpoint_name_; // Not owned. |
71 | const DeviceAttributes* const device_info_; // Not owned. |
72 | }; |
73 | |
74 | // Rewrite the graph structure of "*g" to deal with feeding node |
75 | // outputs, fetching node outputs, and only running a subset of the |
76 | // graph. "fed_outputs" and "fetch_outputs" are both lists of |
77 | // output tensor identifiers in the form of |
78 | // "<name>[:<optional_output_index>]", and "target_nodes_str" is a |
79 | // lists of target node names in "*g" "g". |
80 | // |
81 | // In the resulting graph "*g", output edges in "fed_outputs" have |
82 | // been redirected to special "_recv" nodes introduced into the graph. |
83 | // If these fed nodes are not needed in order to compute the effects |
84 | // of the nodes in "target_node_names" and "fetch_outputs", then these may |
85 | // be omitted from the graph. |
86 | // |
87 | // In the resulting graph "*g", additional "_send" nodes are connected |
88 | // to every output in "fetch_outputs". These "_send" nodes are set up |
89 | // to execute on the device described by device_info. |
90 | // |
91 | // On success, returns OK, and sets "*g" to a version of "*g" |
92 | // that represents the portions of the graph necessary for producing |
93 | // the output of all nodes listed in "target_node_names" and fetching the |
94 | // specific node outputs specified in "fetch_outputs". |
95 | // |
96 | // On failure, returns the error status. Possible errors include: |
97 | // - fed output "node:output_index" does not exist in "*g" |
98 | // - fetch output "node:output_index" does not exist in "*g" |
99 | // - target node "node" does not exist in "*g" |
100 | Status RewriteGraphForExecution( |
101 | Graph* g, const gtl::ArraySlice<string>& fed_outputs, |
102 | const gtl::ArraySlice<string>& fetch_outputs, |
103 | const gtl::ArraySlice<string>& target_node_names, |
104 | const DeviceAttributes& device_info, bool use_function_convention, |
105 | RewriteGraphMetadata* out_metadata); |
106 | |
107 | // A more general version of the above function that supports |
108 | // customizable rewriting actions for each fed and fetched tensor. |
109 | Status RewriteGraphForExecution( |
110 | Graph* g, const std::vector<std::unique_ptr<PruneRewrite>>& feed_rewrites, |
111 | const std::vector<std::unique_ptr<PruneRewrite>>& fetch_rewrites, |
112 | const gtl::ArraySlice<string>& target_node_names, |
113 | RewriteGraphMetadata* out_metadata); |
114 | |
115 | ///////////////////////////////////////////////////////// |
116 | // Custom rewrite actions for fed and fetched tensors. // |
117 | ///////////////////////////////////////////////////////// |
118 | |
119 | // A rewrite action that adds an _Arg node for a fed tensor. |
120 | class ArgFeedRewrite : public PruneRewrite { |
121 | public: |
122 | ArgFeedRewrite(const string* endpoint_name, |
123 | const DeviceAttributes* device_info, int32_t arg_index) |
124 | : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {} |
125 | Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, |
126 | Node** out_node) override; |
127 | |
128 | private: |
129 | const int32 arg_index_; |
130 | }; |
131 | |
132 | // A rewrite action that adds a client-terminated _Recv node for a fed tensor. |
133 | class RecvFeedRewrite : public PruneRewrite { |
134 | public: |
135 | using PruneRewrite::PruneRewrite; |
136 | Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, |
137 | Node** out_node) override; |
138 | }; |
139 | |
140 | // A rewrite action that adds a _Retval node for a fetched tensor. |
141 | class RetvalFetchRewrite : public PruneRewrite { |
142 | public: |
143 | RetvalFetchRewrite(const string* endpoint_name, |
144 | const DeviceAttributes* device_info, int32_t retval_index) |
145 | : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {} |
146 | Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, |
147 | Node** out_node) override; |
148 | |
149 | private: |
150 | const int32 retval_index_; |
151 | }; |
152 | |
153 | // A rewrite action that adds a client-terminated _Send node for a |
154 | // fetched tensor. |
155 | class SendFetchRewrite : public PruneRewrite { |
156 | public: |
157 | using PruneRewrite::PruneRewrite; |
158 | Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, |
159 | Node** out_node) override; |
160 | }; |
161 | |
162 | } // namespace subgraph |
163 | } // namespace tensorflow |
164 | |
165 | #endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ |
166 | |