1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
18
19#include <atomic>
20#include <vector>
21
22#include "tensorflow/core/common_runtime/debugger_state_interface.h"
23#include "tensorflow/core/common_runtime/device_set.h"
24#include "tensorflow/core/common_runtime/graph_execution_state.h"
25#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
26#include "tensorflow/core/distributed_runtime/call_options.h"
27#include "tensorflow/core/distributed_runtime/master_env.h"
28#include "tensorflow/core/distributed_runtime/message_wrappers.h"
29#include "tensorflow/core/distributed_runtime/worker_cache.h"
30#include "tensorflow/core/lib/core/status.h"
31#include "tensorflow/core/platform/types.h"
32#include "tensorflow/core/protobuf/master.pb.h"
33#include "tensorflow/core/public/session_options.h"
34
35namespace tensorflow {
36
37class Device;
38struct MasterEnv;
39
// A session encapsulates a graph computation (resource allocation,
// placement, execution, etc.).
//
// MasterSession is ref-counted; clients end its life with Close() (blocking)
// or GarbageCollect() (non-blocking). The destructor is private, so the
// object can only be destroyed through the ref-count reaching zero.
class MasterSession : public core::RefCounted {
 public:
  // This session encapsulates the graph computation for a graph.
  //
  // The session places nodes on devices in "remote_devs" and executes
  // operations on these devices.
  //
  // The caller takes ownership of all remote devices.
  MasterSession(
      const SessionOptions& options, const MasterEnv* env,
      std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      std::unique_ptr<DeviceSet> device_set,
      std::vector<string> filtered_worker_list,
      StatsPublisherFactory stats_publisher_factory);

  // Initialize the MasterSession for "def". Must be called before Extend(),
  // Run(), or Close().
  Status Create(GraphDef&& def, const ClusterDef& cluster_def);

  // Returns the session handle.
  const string& handle() const { return handle_; }

  // Returns the last access time (the number of micro-seconds since
  // some fixed point in time) of this session.
  uint64 last_access_time_usec() const { return last_access_time_usec_.load(); }

  // Attempt to extend the graph according to the given "req".
  // (See master.proto for details of valid extensions.)
  //
  // PRECONDITION: The current version of this session's graph
  // is "req->current_graph_version".
  //
  // POSTCONDITION: The current version of this session's graph
  // is "resp->new_graph_version".
  //
  // Extend() may block the caller thread for a long time.
  Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp);

  // Setup a partial run call.
  Status PartialRunSetup(const PartialRunSetupRequest* req,
                         PartialRunSetupResponse* resp);

  // Run one step.
  Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
             MutableRunStepResponseWrapper* resp);

  // Fills "resp" with the devices available to this session.
  Status ListDevices(ListDevicesResponse* resp) const;

  // Creates a callable (a pre-specified feed/fetch subgraph identified by a
  // handle in "resp") that can later be executed via RunCallable().
  Status MakeCallable(const MakeCallableRequest& req,
                      MakeCallableResponse* resp);

  // Executes a callable previously created by MakeCallable().
  Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
                     RunCallableResponse* resp);

  // Releases the resources associated with the callable handle in "req".
  Status ReleaseCallable(const ReleaseCallableRequest& req,
                         ReleaseCallableResponse* resp);

  // Close this session and delete "*this". Returns OK if all known
  // states are cleaned up successfully.
  //
  // Close() may block the caller thread for a long time.
  Status Close();

  // Close this session and release a reference on "*this".
  //
  // Note that, unlike Close(), this method does not block on the
  // completion of all work.
  void GarbageCollect();

 private:
  // Options this session was created with (see SessionOptions).
  SessionOptions session_opts_;

  // Not owned.
  const MasterEnv* env_;

  // The opaque session handle.
  const string handle_;

  // Remote devices placed/executed on by this session (see constructor
  // comment: the session owns them).
  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;

  // The optional session-specific worker cluster.
  // TODO(saeta): Convert to std::optional when available.
  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
  // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
  WorkerCacheInterface* get_worker_cache() const;

  // The device set used by this session.
  std::unique_ptr<DeviceSet> devices_;

  // The (partial device) names of remote worker tasks that this
  // session will contact.
  const std::vector<string> filtered_worker_list_;

  // Factory used to create the stats publisher for this session.
  StatsPublisherFactory stats_publisher_factory_;

  // Micro-seconds timestamp of the last access; exposed through
  // last_access_time_usec() so an owner can reap idle sessions.
  std::atomic_ulong last_access_time_usec_;

  // Monotonic counter used to generate unique partial-run handles.
  std::atomic<int64_t> partial_run_handle_counter_ = {0};

  // Returns a step id for the next execution associated with "graph_key".
  uint64 NewStepId(int64_t graph_key);

  // Guards the mutable session state annotated TF_GUARDED_BY(mu_) below.
  mutex mu_;
  // Holds the session's graph and its placement/build state.
  std::unique_ptr<GraphExecutionState> execution_state_ TF_GUARDED_BY(mu_);
  // Current version of the session's graph (advanced by Extend(); see the
  // Extend() pre/postconditions above).
  int64_t graph_version_;

  // We keep a map from a signature of a run request to the
  // ReffedClientGraph that can execute it. We keep up to one old copy
  // of each ReffedClientGraph around because if it gets deallocated
  // before a new substitute has been created, Variables can go out of
  // scope and lose their state.
  class ReffedClientGraph;
  typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
  RCGMap run_graphs_ TF_GUARDED_BY(mu_);           // for full Run() steps
  RCGMap partial_run_graphs_ TF_GUARDED_BY(mu_);   // for partial runs
  // Next handle to return from MakeCallable(); callables_ maps a handle to
  // the graph that RunCallable() executes for it.
  int64_t next_callable_handle_ TF_GUARDED_BY(mu_) = 0;
  RCGMap callables_ TF_GUARDED_BY(mu_);

  // Per-step options for which stats to collect, plus the stats gathered
  // for that step.
  struct PerStepState {
    bool collect_costs = false;
    bool collect_timeline = false;
    bool collect_rpcs = false;
    bool collect_partition_graphs = false;
    bool report_tensor_allocations_upon_oom = false;
    Microseconds start_micros = Microseconds(0);
    Microseconds end_micros = Microseconds(0);
    std::vector<StepStats> step_stats;  // per partition
    StepStats rpc_stats;                // for RPC layer
    CostGraphDef cost_graph;
  };

  // State of one in-flight (partial) run: which feeds/fetches are still
  // outstanding and the graph/step executing them.
  struct RunState {
    std::unordered_map<string, bool> pending_inputs;   // true if fed
    std::unordered_map<string, bool> pending_outputs;  // true if fetched
    ReffedClientGraph* rcg = nullptr;
    uint64 step_id;
    int64_t collective_graph_key;
    int64_t count = 0;
    PerStepState pss;
    std::unique_ptr<ProfileHandler> ph;
    bool step_started = false;

    RunState(const std::vector<string>& input_names,
             const std::vector<string>& output_names, ReffedClientGraph* rcg,
             const uint64 step_id, const int64_t count);

    // Whether all pending inputs have been fed and all pending outputs
    // have been fetched (see pending_inputs/pending_outputs above).
    bool PendingDone() const;

    ~RunState();
  };
  // In-flight partial runs, keyed by partial-run handle.
  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
      TF_GUARDED_BY(mu_);

  // Active RunStep calls.
  condition_variable num_running_is_zero_;
  int32 num_running_ TF_GUARDED_BY(mu_) = 0;

  // Set once Close() has run; garbage_collected_ once GarbageCollect() has.
  bool closed_ TF_GUARDED_BY(mu_) = false;
  bool garbage_collected_ TF_GUARDED_BY(mu_) = false;

  // Execution counts per subgraph. NOTE(review): key is presumably the same
  // uint64 request signature used by RCGMap — confirm in the .cc file.
  std::unordered_map<uint64, int64_t> subgraph_execution_counts_
      TF_GUARDED_BY(mu_);

  // We need to ensure that certain nodes added (e.g., send and recv
  // nodes) are unique across all sub-graphs within this session.
  int64_t next_node_id_ TF_GUARDED_BY(mu_) = 0;

  // Used to cancel running steps on Close().
  CancellationManager cancellation_manager_;

  // Private dtor. The client must call Close().
  virtual ~MasterSession();

  // Creates sessions on all workers.
  //
  // If this session is operating using the new ClusterSpec propagation behavior
  // call this method in order to propagate the cluster membership to all
  // workers.
  Status CreateWorkerSessions(const ClusterDef& cluster_def);

  // Whether worker sessions created by CreateWorkerSessions() must be torn
  // down via DeleteWorkerSessions() when this session ends.
  bool should_delete_worker_sessions_ = false;
  // Deletes the per-session state created on workers by
  // CreateWorkerSessions().
  Status DeleteWorkerSessions();

  // Finds or builds the client graph matching "opts" ("is_partial" selects
  // the partial-run table), returning it in "*out_rcg" and its execution
  // count in "*out_count".
  Status StartStep(const BuildGraphOptions& opts, bool is_partial,
                   ReffedClientGraph** out_rcg, int64_t* out_count);
  // Empties "*rcg_map", collecting its graphs into "*to_unref" so the
  // caller can unref them after releasing mu_.
  void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                      RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Prepares "*out_pss" and "*out_ph" for a step of "rcg" according to
  // "run_options".
  void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                        const RunOptions& run_options, uint64 step_id,
                        int64_t count, PerStepState* out_pss,
                        std::unique_ptr<ProfileHandler>* out_ph);
  // Run() implementation for a regular (non-partial) step.
  Status DoRunWithLocalExecution(CallOptions* opts,
                                 const RunStepRequestWrapper& req,
                                 MutableRunStepResponseWrapper* resp);
  // Run() implementation for a step set up via PartialRunSetup().
  Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                      MutableRunStepResponseWrapper* resp);
  // RunCallable() implementation for the callable's graph "rcg".
  Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                       const RunCallableRequest& req,
                       RunCallableResponse* resp);
  // Finalizes a step: consumes the collected "pss"/"ph" stats and fills
  // "*out_run_metadata", given the step's "run_status".
  Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
                        const RunOptions& run_options, PerStepState* pss,
                        const std::unique_ptr<ProfileHandler>& ph,
                        const Status& run_status,
                        RunMetadata* out_run_metadata);

  // Records that one active run finished (bookkeeping for num_running_ /
  // num_running_is_zero_ above).
  void MarkRunCompletion();
  // Refreshes last_access_time_usec_.
  void UpdateLastAccessTime();

  // Partitions the client graph in "rcg" and registers the partitions on
  // the workers.
  Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

  // Creates the debugger state for a run with the given "debug_options".
  Status CreateDebuggerState(
      const DebugOptions& debug_options, const RunStepRequestWrapper& req,
      int64_t rcg_execution_count,
      std::unique_ptr<DebuggerStateInterface>* debugger_state);

  TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};
259
260} // end namespace tensorflow
261
262#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
263