1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ |
18 | |
19 | #include <atomic> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/common_runtime/debugger_state_interface.h" |
23 | #include "tensorflow/core/common_runtime/device_set.h" |
24 | #include "tensorflow/core/common_runtime/graph_execution_state.h" |
25 | #include "tensorflow/core/common_runtime/stats_publisher_interface.h" |
26 | #include "tensorflow/core/distributed_runtime/call_options.h" |
27 | #include "tensorflow/core/distributed_runtime/master_env.h" |
28 | #include "tensorflow/core/distributed_runtime/message_wrappers.h" |
29 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
30 | #include "tensorflow/core/lib/core/status.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | #include "tensorflow/core/protobuf/master.pb.h" |
33 | #include "tensorflow/core/public/session_options.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | class Device; |
38 | struct MasterEnv; |
39 | |
40 | // A session encapsulates a graph computation (resource allocation, |
41 | // placement, execution, etc.). |
class MasterSession : public core::RefCounted {
 public:
  // This session encapsulates the graph computation for a graph.
  //
  // The session places nodes on devices in "remote_devs" and executes
  // operations on these devices.
  //
  // The caller takes ownership of all remote devices.
  MasterSession(
      const SessionOptions& options, const MasterEnv* env,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      std::unique_ptr<DeviceSet> device_set,
      std::vector<string> filtered_worker_list,
      StatsPublisherFactory stats_publisher_factory);

  // Initialize the MasterSession for "def". Must be called before Extend(),
  // Run(), or Close().
  Status Create(GraphDef&& def, const ClusterDef& cluster_def);

  // Returns the opaque session handle that clients use to identify this
  // session in subsequent requests.
  const string& handle() const { return handle_; }

  // Returns the last access time (the number of micro-seconds since
  // some fixed point in time) of this session.
  uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }

  // Attempt to extend the graph according to the given "req".
  // (See master.proto for details of valid extensions.)
  //
  // PRECONDITION: The current version of this session's graph
  // is "req->current_graph_version".
  //
  // POSTCONDITION: The current version of this session's graph
  // is "resp->new_graph_version".
  //
  // Extend() may block the caller thread for a long time.
  Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);

  // Setup a partial run call.
  Status PartialRunSetup(const PartialRunSetupRequest* req,
                         PartialRunSetupResponse* resp);

  // Run one step.
  Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
             MutableRunStepResponseWrapper* resp);

  // Reports the devices available to this session in "*resp".
  Status ListDevices(ListDevicesResponse* resp) const;

  // Registers a callable described by "req" and returns its handle in
  // "*resp" (see callables_ below).
  Status MakeCallable(const MakeCallableRequest& req,
                      MakeCallableResponse* resp);

  // Runs the callable previously created by MakeCallable() and
  // identified by the handle in "req".
  Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
                     RunCallableResponse* resp);

  // Releases the state associated with the callable handle in "req".
  Status ReleaseCallable(const ReleaseCallableRequest& req,
                         ReleaseCallableResponse* resp);

  // Close this session and delete "*this". Returns OK if all known
  // states are cleaned up successfully.
  //
  // Close() may block the caller thread for a long time.
  Status Close();

  // Close this session and release a reference on "*this".
  //
  // Note that, unlike Close(), this method does not block on the
  // completion of all work.
  void GarbageCollect();

 private:
  // The options this session was created with.
  SessionOptions session_opts_;

  // Not owned.
  const MasterEnv* env_;

  // The opaque session handle.
  const string handle_;

  // Remote devices owned by this session (see constructor comment).
  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;

  // The optional session-specific worker cluster.
  // TODO(saeta): Convert to std::optional when available.
  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
  // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
  WorkerCacheInterface* get_worker_cache() const;

  // The device set used by this session.
  std::unique_ptr<DeviceSet> devices_;

  // The (partial device) names of remote worker tasks that this
  // session will contact.
  const std::vector<string> filtered_worker_list_;

  // Creates the publisher used to export per-session statistics.
  StatsPublisherFactory stats_publisher_factory_;

  // Backing store for last_access_time_usec(); atomic so readers need
  // not hold mu_.
  std::atomic_ulong last_access_time_usec_;

  // Monotonic counter used to mint unique partial-run handles.
  std::atomic<int64_t> partial_run_handle_counter_ = {0};

  // Generates a new step id for a run of the graph identified by
  // "graph_key".
  uint64 NewStepId(int64_t graph_key);

  // Guards the mutable session state declared below.
  mutex mu_;
  std::unique_ptr<GraphExecutionState> execution_state_ TF_GUARDED_BY(mu_);
  // Current version of the session's graph; bumped by Extend().
  int64_t graph_version_;

  // We keep a map from a signature of a run request to the
  // ReffedClientGraph that can execute it. We keep up to one old copy
  // of each ReffedClientGraph around because if it gets deallocated
  // before a new substitute has been created, Variables can go out of
  // scope and lose their state.
  class ReffedClientGraph;
  typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
  RCGMap run_graphs_ TF_GUARDED_BY(mu_);
  RCGMap partial_run_graphs_ TF_GUARDED_BY(mu_);
  // Next handle to hand out from MakeCallable(); callables_ maps a
  // handle to the graph that serves RunCallable() for it.
  int64_t next_callable_handle_ TF_GUARDED_BY(mu_) = 0;
  RCGMap callables_ TF_GUARDED_BY(mu_);

  // Which statistics to collect for a single step, plus the collected
  // results.
  struct PerStepState {
    bool collect_costs = false;
    bool collect_timeline = false;
    bool collect_rpcs = false;
    bool collect_partition_graphs = false;
    bool report_tensor_allocations_upon_oom = false;
    Microseconds start_micros = Microseconds(0);
    Microseconds end_micros = Microseconds(0);
    std::vector<StepStats> step_stats;  // per partition
    StepStats rpc_stats;                // for RPC layer
    CostGraphDef cost_graph;
  };

  // State tracked for each in-flight (partial) run of a registered
  // graph.
  struct RunState {
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    ReffedClientGraph* rcg = nullptr;
    uint64 step_id;
    int64_t collective_graph_key;
    int64_t count = 0;
    PerStepState pss;
    std::unique_ptr<ProfileHandler> ph;
    bool step_started = false;

    RunState(const std::vector<string>& input_names,
             const std::vector<string>& output_names, ReffedClientGraph* rcg,
             const uint64 step_id, const int64_t count);

    // Returns true iff every pending input has been fed and every
    // pending output has been fetched.
    bool PendingDone() const;

    ~RunState();
  };
  // Active partial runs, keyed by partial-run handle.
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      TF_GUARDED_BY(mu_);

  // Active RunStep calls.
  condition_variable num_running_is_zero_;
  int32 num_running_ TF_GUARDED_BY(mu_) = 0;

  // Set once Close() / GarbageCollect() has torn the session down.
  bool closed_ TF_GUARDED_BY(mu_) = false;
  bool garbage_collected_ TF_GUARDED_BY(mu_) = false;

  // Number of times each registered subgraph has been executed, keyed
  // by the same signature as RCGMap.
  std::unordered_map<uint64, int64_t> subgraph_execution_counts_
      TF_GUARDED_BY(mu_);

  // We need to ensure that certain nodes added (e.g., send and recv
  // nodes) are unique across all sub-graphs within this session.
  int64_t next_node_id_ TF_GUARDED_BY(mu_) = 0;

  // Used to cancel running steps on Close().
  CancellationManager cancellation_manager_;

  // Private dtor. The client must call Close().
  virtual ~MasterSession();

  // Creates sessions on all workers.
  //
  // If this session is operating using the new ClusterSpec propagation behavior
  // call this method in order to propagate the cluster membership to all
  // workers.
  Status CreateWorkerSessions(const ClusterDef& cluster_def);

  // True iff DeleteWorkerSessions() must be invoked during teardown.
  bool should_delete_worker_sessions_ = false;
  // Tears down the per-session state created by CreateWorkerSessions().
  Status DeleteWorkerSessions();

  // Looks up (or builds) the client graph for "opts" and returns it in
  // "*out_rcg", with its execution count in "*out_count". "is_partial"
  // selects between run_graphs_ and partial_run_graphs_.
  Status StartStep(const BuildGraphOptions& opts, bool is_partial,
                   ReffedClientGraph** out_rcg, int64_t* out_count);
  // Moves all entries of "*rcg_map" into "*to_unref" (for the caller to
  // unref outside the lock) and clears the map.
  void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                      RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Initializes "*out_pss" and "*out_ph" for a step, based on
  // "run_options".
  void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                        const RunOptions& run_options, uint64 step_id,
                        int64_t count, PerStepState* out_pss,
                        std::unique_ptr<ProfileHandler>* out_ph);
  // Implements Run() for a full (non-partial) step.
  Status DoRunWithLocalExecution(CallOptions* opts,
                                 const RunStepRequestWrapper& req,
                                 MutableRunStepResponseWrapper* resp);
  // Implements Run() for a step previously set up by PartialRunSetup().
  Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                      MutableRunStepResponseWrapper* resp);
  // Implements RunCallable() using the registered graph "rcg".
  Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                       const RunCallableRequest& req,
                       RunCallableResponse* resp);
  // Finalizes a step: consumes the per-step stats in "*pss"/"ph" and
  // fills "*out_run_metadata", given the step's "run_status".
  Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
                        const RunOptions& run_options, PerStepState* pss,
                        const std::unique_ptr<ProfileHandler>& ph,
                        const Status& run_status,
                        RunMetadata* out_run_metadata);

  // Bookkeeping invoked when an active RunStep call finishes (see
  // num_running_ / num_running_is_zero_).
  void MarkRunCompletion();
  // Refreshes last_access_time_usec_ to the current time.
  void UpdateLastAccessTime();

  // Partitions the client graph in "rcg" and registers the partitions
  // with the workers that will execute them.
  Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

  // Creates "*debugger_state" for a run when "debug_options" requests
  // debugging.
  Status CreateDebuggerState(
      const DebugOptions& debug_options, const RunStepRequestWrapper& req,
      int64_t rcg_execution_count,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};
259 | |
260 | } // end namespace tensorflow |
261 | |
262 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ |
263 | |