1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
18
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
34
35namespace tensorflow {
36struct SessionOptions;
37
38namespace subgraph {
39struct RewriteGraphMetadata;
40}
41
42struct GraphExecutionStateOptions {
43 const DeviceSet* device_set = nullptr;
44 const SessionOptions* session_options = nullptr;
45 // Unique session identifier. Can be empty.
46 string session_handle;
47 // A map from node name to device name, representing the unchangeable
48 // placement of stateful nodes.
49 std::unordered_map<string, string> stateful_placements;
50};
51
52// A ClientGraph is simply a sub-graph of the full graph as induced by
53// BuildGraphOptions.
54struct ClientGraph {
55 explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
56 DataTypeVector feed_types, DataTypeVector fetch_types,
57 int64_t collective_graph_key)
58 : flib_def(std::move(flib)),
59 graph(flib_def.get()),
60 feed_types(std::move(feed_types)),
61 fetch_types(std::move(fetch_types)),
62 collective_graph_key(collective_graph_key) {}
63 // Each client-graph gets its own function library since optimization passes
64 // post rewrite for execution might want to introduce new functions.
65 std::unique_ptr<FunctionLibraryDefinition> flib_def;
66 Graph graph;
67 DataTypeVector feed_types;
68 DataTypeVector fetch_types;
69 int64_t collective_graph_key;
70};
71
72// GraphExecutionState is responsible for generating an
73// executable ClientGraph from the original GraphDef that specifies
74// the complete graph and from BuildGraphOptions which specifies
75// input/output nodes.
76//
77// An executable Graph differs from a GraphDef by being Placed,
78// meaning that each Node is assigned to a single Device in the
79// available set.
80//
// When GraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef. Nodes without a
// static device assignment are currently placed in a very naive way.
// Since stateful Nodes cannot be moved after initial placement, it is
// important that stateful Nodes get sensible initial device
// assignments in the graph definition.
88//
// Subsequently, GraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph. MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
93//
94// GraphExecutionState is thread-safe.
95
96class GraphExecutionState {
97 public:
98 virtual ~GraphExecutionState();
99
100 // Creates a new `GraphExecutionState` for the given
101 // `graph_def`, which represents the entire graph for a session.
102 static Status MakeForBaseGraph(
103 GraphDef&& graph_def, const GraphExecutionStateOptions& options,
104 std::unique_ptr<GraphExecutionState>* out_state);
105
106 // Creates a new `GraphExecutionState` and `SimpleClientGraph`
107 // for the subgraph of `original_graph_def` defined by
108 // `subgraph_options`.
109 static Status MakeForPrunedGraph(
110 const GraphExecutionState& base_execution_state,
111 const GraphExecutionStateOptions& options,
112 const BuildGraphOptions& subgraph_options,
113 std::unique_ptr<GraphExecutionState>* out_state,
114 std::unique_ptr<ClientGraph>* out_client_graph);
115
116 // Creates a new GraphExecutionState representing the
117 // concatenation of this graph, and the graph defined by
118 // "extension_def". The same name may not be used to define a node
119 // in both this graph and "extension_def".
120 //
121 // If successful, returns OK and the caller takes ownership of "*out".
122 // Otherwise returns an error and does not modify "*out".
123 //
124 // After calling `old_state->Extend()`, `old_state` may no longer be
125 // used.
126 //
127 // NOTE(mrry): This method respects the placement of stateful nodes in
128 // in *this, but currently does not transfer any other placement
129 // or cost model information to the new graph.
130 Status Extend(const GraphDef& extension_def,
131 std::unique_ptr<GraphExecutionState>* out) const;
132
133 // Builds a ClientGraph (a sub-graph of the full graph as induced by
134 // the Node set specified in "options"). If successful, returns OK
135 // and the caller takes the ownership of "*out". Otherwise, returns
136 // an error.
137 Status BuildGraph(const BuildGraphOptions& options,
138 std::unique_ptr<ClientGraph>* out);
139
140 // Optimize the graph with the node set specified in `options`.
141 Status OptimizeGraph(
142 const BuildGraphOptions& options, const Graph& graph,
143 const FunctionLibraryDefinition* flib_def,
144 std::unique_ptr<Graph>* optimized_graph,
145 std::unique_ptr<FunctionLibraryDefinition>* optimized_flib);
146
147 // The graph returned by BuildGraph may contain only the pruned
148 // graph, whereas some clients may want access to the full graph.
149 const Graph* full_graph() { return graph_; }
150
151 // The original graph.
152 GraphDef* original_graph_def() { return original_graph_def_.get(); }
153
154 // The original function library of this graph.
155 const FunctionLibraryDefinition& flib_def() const { return *flib_def_; }
156
157 // Returns the node with the given name, or null if it does not exist.
158 const Node* get_node_by_name(const string& name) const {
159 NodeNameToCostIdMap::const_iterator iter =
160 node_name_to_cost_id_map_.find(name);
161 if (iter != node_name_to_cost_id_map_.end()) {
162 return graph_->FindNodeId(iter->second);
163 } else {
164 return nullptr;
165 }
166 }
167
168 // Returns the map of stateful placements as a map of
169 // node name to placement string.
170 std::unordered_map<string, string> GetStatefulPlacements() const {
171 return stateful_placements_;
172 }
173
174 private:
175 GraphExecutionState(std::unique_ptr<GraphDef>&& graph_def,
176 std::unique_ptr<FunctionLibraryDefinition>&& flib_def,
177 const GraphExecutionStateOptions& options);
178
179 Status InitBaseGraph(std::unique_ptr<Graph>&& graph);
180
181 // Map of placed stateful nodes, i.e. nodes for which is_stateful()
182 // is true, such as "params" and "queue" nodes. Once placed these
183 // nodes can not be moved to a different device. Maps node names to
184 // device names.
185 std::unordered_map<string, string> stateful_placements_; // Immutable after
186 // ctor.
187 void SaveStatefulNodes(Graph* graph);
188 void RestoreStatefulNodes(Graph* graph);
189
190 // Extract the subset of the graph that needs to be run, adding feed/fetch
191 // ops as needed.
192 Status PruneGraph(const BuildGraphOptions& options, Graph* graph,
193 subgraph::RewriteGraphMetadata* out_rewrite_metadata);
194
195 // The GraphExecutionState must store a copy of the original GraphDef if
196 // either of the following conditions holds:
197 //
198 // * `session_options_.config.graph_options().place_pruned_graph()` is true.
199 // * `session_options_.config.experimental().optimize_for_static_graph()` is
200 // false.
201 const std::unique_ptr<GraphDef> original_graph_def_;
202
203 const DeviceSet* device_set_; // Not owned
204 const SessionOptions* session_options_; // Not owned
205 // Unique session identifier. Can be empty.
206 string session_handle_;
207
208 // Map from name to Node for the full graph in placed_.
209 NodeNameToCostIdMap node_name_to_cost_id_map_;
210
211 // 'flib_def_' is initialized from the initial graph def's library,
212 // and may be updated by a graph optimization pass.
213 std::unique_ptr<FunctionLibraryDefinition> flib_def_;
214
215 // `rewrite_metadata_` is only set for GraphExecutionState
216 // objects created by `MakeForPrunedGraph()`.
217 std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
218
219 // The dataflow graph owned by this object.
220 Graph* graph_;
221
222 TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
223};
224
225} // namespace tensorflow
226
227#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
228