1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ |
18 | |
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
34 | |
35 | namespace tensorflow { |
36 | struct SessionOptions; |
37 | |
38 | namespace subgraph { |
39 | struct RewriteGraphMetadata; |
40 | } |
41 | |
42 | struct GraphExecutionStateOptions { |
43 | const DeviceSet* device_set = nullptr; |
44 | const SessionOptions* session_options = nullptr; |
45 | // Unique session identifier. Can be empty. |
46 | string session_handle; |
47 | // A map from node name to device name, representing the unchangeable |
48 | // placement of stateful nodes. |
49 | std::unordered_map<string, string> stateful_placements; |
50 | }; |
51 | |
52 | // A ClientGraph is simply a sub-graph of the full graph as induced by |
53 | // BuildGraphOptions. |
54 | struct ClientGraph { |
55 | explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, |
56 | DataTypeVector feed_types, DataTypeVector fetch_types, |
57 | int64_t collective_graph_key) |
58 | : flib_def(std::move(flib)), |
59 | graph(flib_def.get()), |
60 | feed_types(std::move(feed_types)), |
61 | fetch_types(std::move(fetch_types)), |
62 | collective_graph_key(collective_graph_key) {} |
63 | // Each client-graph gets its own function library since optimization passes |
64 | // post rewrite for execution might want to introduce new functions. |
65 | std::unique_ptr<FunctionLibraryDefinition> flib_def; |
66 | Graph graph; |
67 | DataTypeVector feed_types; |
68 | DataTypeVector fetch_types; |
69 | int64_t collective_graph_key; |
70 | }; |
71 | |
72 | // GraphExecutionState is responsible for generating an |
73 | // executable ClientGraph from the original GraphDef that specifies |
74 | // the complete graph and from BuildGraphOptions which specifies |
75 | // input/output nodes. |
76 | // |
77 | // An executable Graph differs from a GraphDef by being Placed, |
78 | // meaning that each Node is assigned to a single Device in the |
79 | // available set. |
80 | // |
81 | // When GraphExecutionState is first constructed it instantiates |
82 | // a full Graph from the provided GraphDef, and places it, using only |
// the static device assignments from the GraphDef. Nodes without one are
// currently placed in a very naive way. Since stateful Nodes cannot
85 | // be moved after initial placement, it is important that stateful |
86 | // Nodes get sensible initial device assignments in the graph |
87 | // definition. |
88 | // |
// Subsequently, GraphExecutionState generates a ClientGraph on
90 | // demand, which is a sub-graph of the latest placement of the full |
91 | // Graph. MasterSession uses such a ClientGraph to execute one or |
92 | // more similar client requests. |
93 | // |
94 | // GraphExecutionState is thread-safe. |
95 | |
96 | class GraphExecutionState { |
97 | public: |
98 | virtual ~GraphExecutionState(); |
99 | |
100 | // Creates a new `GraphExecutionState` for the given |
101 | // `graph_def`, which represents the entire graph for a session. |
102 | static Status MakeForBaseGraph( |
103 | GraphDef&& graph_def, const GraphExecutionStateOptions& options, |
104 | std::unique_ptr<GraphExecutionState>* out_state); |
105 | |
106 | // Creates a new `GraphExecutionState` and `SimpleClientGraph` |
107 | // for the subgraph of `original_graph_def` defined by |
108 | // `subgraph_options`. |
109 | static Status MakeForPrunedGraph( |
110 | const GraphExecutionState& base_execution_state, |
111 | const GraphExecutionStateOptions& options, |
112 | const BuildGraphOptions& subgraph_options, |
113 | std::unique_ptr<GraphExecutionState>* out_state, |
114 | std::unique_ptr<ClientGraph>* out_client_graph); |
115 | |
116 | // Creates a new GraphExecutionState representing the |
117 | // concatenation of this graph, and the graph defined by |
118 | // "extension_def". The same name may not be used to define a node |
119 | // in both this graph and "extension_def". |
120 | // |
121 | // If successful, returns OK and the caller takes ownership of "*out". |
122 | // Otherwise returns an error and does not modify "*out". |
123 | // |
124 | // After calling `old_state->Extend()`, `old_state` may no longer be |
125 | // used. |
126 | // |
127 | // NOTE(mrry): This method respects the placement of stateful nodes in |
128 | // in *this, but currently does not transfer any other placement |
129 | // or cost model information to the new graph. |
130 | Status Extend(const GraphDef& extension_def, |
131 | std::unique_ptr<GraphExecutionState>* out) const; |
132 | |
133 | // Builds a ClientGraph (a sub-graph of the full graph as induced by |
134 | // the Node set specified in "options"). If successful, returns OK |
135 | // and the caller takes the ownership of "*out". Otherwise, returns |
136 | // an error. |
137 | Status BuildGraph(const BuildGraphOptions& options, |
138 | std::unique_ptr<ClientGraph>* out); |
139 | |
140 | // Optimize the graph with the node set specified in `options`. |
141 | Status OptimizeGraph( |
142 | const BuildGraphOptions& options, const Graph& graph, |
143 | const FunctionLibraryDefinition* flib_def, |
144 | std::unique_ptr<Graph>* optimized_graph, |
145 | std::unique_ptr<FunctionLibraryDefinition>* optimized_flib); |
146 | |
147 | // The graph returned by BuildGraph may contain only the pruned |
148 | // graph, whereas some clients may want access to the full graph. |
149 | const Graph* full_graph() { return graph_; } |
150 | |
151 | // The original graph. |
152 | GraphDef* original_graph_def() { return original_graph_def_.get(); } |
153 | |
154 | // The original function library of this graph. |
155 | const FunctionLibraryDefinition& flib_def() const { return *flib_def_; } |
156 | |
157 | // Returns the node with the given name, or null if it does not exist. |
158 | const Node* get_node_by_name(const string& name) const { |
159 | NodeNameToCostIdMap::const_iterator iter = |
160 | node_name_to_cost_id_map_.find(name); |
161 | if (iter != node_name_to_cost_id_map_.end()) { |
162 | return graph_->FindNodeId(iter->second); |
163 | } else { |
164 | return nullptr; |
165 | } |
166 | } |
167 | |
168 | // Returns the map of stateful placements as a map of |
169 | // node name to placement string. |
170 | std::unordered_map<string, string> GetStatefulPlacements() const { |
171 | return stateful_placements_; |
172 | } |
173 | |
174 | private: |
175 | GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def, |
176 | std::unique_ptr<FunctionLibraryDefinition>&& flib_def, |
177 | const GraphExecutionStateOptions& options); |
178 | |
179 | Status InitBaseGraph(std::unique_ptr<Graph>&& graph); |
180 | |
181 | // Map of placed stateful nodes, i.e. nodes for which is_stateful() |
182 | // is true, such as "params" and "queue" nodes. Once placed these |
183 | // nodes can not be moved to a different device. Maps node names to |
184 | // device names. |
185 | std::unordered_map<string, string> stateful_placements_; // Immutable after |
186 | // ctor. |
187 | void SaveStatefulNodes(Graph* graph); |
188 | void RestoreStatefulNodes(Graph* graph); |
189 | |
190 | // Extract the subset of the graph that needs to be run, adding feed/fetch |
191 | // ops as needed. |
192 | Status PruneGraph(const BuildGraphOptions& options, Graph* graph, |
193 | subgraph::RewriteGraphMetadata* out_rewrite_metadata); |
194 | |
195 | // The GraphExecutionState must store a copy of the original GraphDef if |
196 | // either of the following conditions holds: |
197 | // |
198 | // * `session_options_.config.graph_options().place_pruned_graph()` is true. |
199 | // * `session_options_.config.experimental().optimize_for_static_graph()` is |
200 | // false. |
201 | const std::unique_ptr<GraphDef> original_graph_def_; |
202 | |
203 | const DeviceSet* device_set_; // Not owned |
204 | const SessionOptions* session_options_; // Not owned |
205 | // Unique session identifier. Can be empty. |
206 | string session_handle_; |
207 | |
208 | // Map from name to Node for the full graph in placed_. |
209 | NodeNameToCostIdMap node_name_to_cost_id_map_; |
210 | |
211 | // 'flib_def_' is initialized from the initial graph def's library, |
212 | // and may be updated by a graph optimization pass. |
213 | std::unique_ptr<FunctionLibraryDefinition> flib_def_; |
214 | |
215 | // `rewrite_metadata_` is only set for GraphExecutionState |
216 | // objects created by `MakeForPrunedGraph()`. |
217 | std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_; |
218 | |
219 | // The dataflow graph owned by this object. |
220 | Graph* graph_; |
221 | |
222 | TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); |
223 | }; |
224 | |
225 | } // namespace tensorflow |
226 | |
227 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ |
228 | |