1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
18
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
30
31namespace tensorflow {
32
33class WorkerCacheInterface;
34struct WorkerEnv;
35
36// SessionMgr keeps track of information related to a given session.
37//
38// SessionMgr runs on the workers.
39//
40// SessionMgr is threadsafe.
41class SessionMgr {
42 public:
43 typedef std::function<Status(const ServerDef&, WorkerCacheInterface**)>
44 WorkerCacheFactory;
45
46 explicit SessionMgr(
47 WorkerEnv* worker_env, const std::string& default_worker_name,
48 std::unique_ptr<WorkerCacheInterface> default_worker_cache,
49 WorkerCacheFactory worker_cache_factory);
50 ~SessionMgr() {}
51
52 // Allocates state for a new session.
53 Status CreateSession(
54 const std::string& session, const ServerDef& server_def,
55 bool isolate_session_state,
56 StatusCallback coordination_error_callback = [](Status s) {
57 LOG(ERROR) << "Coordination agent is set to error: " << s;
58 });
59 Status CreateSession(
60 const std::string& session, const ServerDef& server_def,
61 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
62 bool isolate_session_state);
63
64 // Create WorkerSession from the master with the given `master_task` and
65 // `master_incarnation`. We first look for existing WorkerSessions associated
66 // with the specified master task. If there are sessions created by the same
67 // master but with a different incarnation, it indicates that the remote
68 // master has restarted before deleting the sessions on worker. When it
69 // happens, old sessions associated with the master will be automatically
70 // removed before the new session is created.
71 Status CreateSession(
72 const std::string& session, const ServerDef& server_def,
73 const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
74 bool isolate_session_state, std::string master_task,
75 int64_t master_incarnation,
76 StatusCallback coordination_error_callback = [](Status s) {
77 LOG(ERROR) << "Coordination agent is set to error: " << s;
78 });
79
80 void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache);
81
82 // Updates state (worker cache, devices) of worker session identified by
83 // session name (`session`) based on a new server_def and set of devices.
84 Status UpdateSession(const std::string& session, const ServerDef& server_def,
85 const protobuf::RepeatedPtrField<DeviceAttributes>&
86 cluster_device_attributes);
87
88 // Locates the worker session for a given session handle
89 Status WorkerSessionForSession(const std::string& session_handle,
90 std::shared_ptr<WorkerSession>* out_session);
91 std::shared_ptr<WorkerSession> LegacySession();
92
93 Status DeleteSession(const std::string& session);
94
95 // Provides access to the coordination service. This method should only be
96 // called after the agent has been initialized during session creation, or an
97 // invalid nullptr is returned. Note: the agent is thread-safe and mutable.
98 CoordinationServiceAgent* GetCoordinationServiceAgent();
99
100 static std::string WorkerNameFromServerDef(const ServerDef& server_def);
101
102 void SetLogging(bool active);
103
104 void RetrieveLogs(int64_t step_id, LoggingResponse* response);
105
106 void ClearLogs();
107
108 // Agent should be torn down before service as it needs to disconnect first.
109 void TeardownCoordinationServiceAgent();
110 void TeardownCoordinationService();
111
112 private:
113 WorkerEnv* const worker_env_; // Not owned.
114
115 // A note about destruction:
116 // We must delete graph_mgr before device_mgr, due to shared
117 // ownership of OpKernels in the executors. (The graph_mgr will
118 // free all stateless OpKernels, and pass over borrowed stateful
119 // OpKernels, which are also held in their respective devices'
120 // OpSegments.)
121 //
122 // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
123 // that sessions_'s WorkerSessions are deleted (which do not own the
124 // underlying devices, but instead own RenamedDevices) before
125 // legacy_session_ is deleted. Further, we must ensure that WorkerSession's
126 // device_mgr is deleted after WorkerSession's graph_mgr.
127
128 std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
129 std::shared_ptr<WorkerSession> legacy_session_;
130 std::unique_ptr<CoordinationServiceInterface> coordination_service_;
131 std::unique_ptr<CoordinationServiceAgent> coordination_service_agent_;
132
133 bool is_logging_active_ = false;
134
135 const WorkerCacheFactory worker_cache_factory_;
136
137 Status WorkerSessionForSessionLocked(
138 const std::string& session_handle,
139 std::shared_ptr<WorkerSession>* out_session)
140 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
141
142 mutex mu_;
143 // A map from session identifier to internal session structure.
144 std::map<std::string, std::shared_ptr<WorkerSession>> sessions_
145 TF_GUARDED_BY(mu_);
146
147 // Incarnation and WorkerSession handle associated with a master task.
148 struct MasterAssociatedSession {
149 const int64_t master_incarnation;
150 const std::string session_handle;
151 };
152 // A map from master task name to its associated worker sessions.
153 std::unordered_multimap<std::string, MasterAssociatedSession>
154 master_to_associated_sessions_ TF_GUARDED_BY(mu_);
155};
156
157} // namespace tensorflow
158
159#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_
160