1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/distributed_runtime/session_mgr.h"
17
18#include <algorithm>
19#include <string>
20#include <utility>
21
22#include "tensorflow/core/activity_watcher/activity.h"
23#include "tensorflow/core/common_runtime/device_mgr.h"
24#include "tensorflow/core/common_runtime/renamed_device.h"
25#include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
26#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
27#include "tensorflow/core/distributed_runtime/error_payloads.h"
28#include "tensorflow/core/distributed_runtime/graph_mgr.h"
29#include "tensorflow/core/distributed_runtime/remote_device.h"
30#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
31#include "tensorflow/core/lib/strings/strcat.h"
32#include "tensorflow/core/protobuf/cluster.pb.h"
33#include "tensorflow/core/protobuf/coordination_config.pb.h"
34#include "tensorflow/core/protobuf/coordination_service.pb.h"
35#include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"
36#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
37#include "tensorflow/core/util/ptr_util.h"
38
39namespace tensorflow {
40
41SessionMgr::SessionMgr(
42 WorkerEnv* worker_env, const std::string& default_worker_name,
43 std::unique_ptr<WorkerCacheInterface> default_worker_cache,
44 WorkerCacheFactory worker_cache_factory)
45 : worker_env_(worker_env),
46 default_worker_cache_(std::move(default_worker_cache)),
47 legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
48 "", default_worker_name,
49 std::unique_ptr<WorkerCacheInterface>(
50 new WorkerCacheWrapper(default_worker_cache_.get())),
51 worker_env->device_mgr,
52 std::unique_ptr<GraphMgr>(
53 new GraphMgr(worker_env, worker_env->device_mgr)),
54 nullptr)),
55 worker_cache_factory_(std::move(worker_cache_factory)) {}
56
57/* static */
58std::string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
59 return strings::StrCat("/job:", server_def.job_name(),
60 "/replica:0/task:", server_def.task_index());
61}
62
63Status SessionMgr::CreateSession(const std::string& session,
64 const ServerDef& server_def,
65 bool isolate_session_state,
66 StatusCallback coordination_error_callback) {
67 return CreateSession(session, server_def, {}, isolate_session_state,
68 /*master_task=*/"",
69 /*master_incarnation=*/0, coordination_error_callback);
70}
71
72Status SessionMgr::CreateSession(
73 const std::string& session, const ServerDef& server_def,
74 const protobuf::RepeatedPtrField<DeviceAttributes>&
75 cluster_device_attributes,
76 bool isolate_session_state) {
77 return CreateSession(session, server_def, cluster_device_attributes,
78 isolate_session_state,
79 /*master_task=*/"",
80 /*master_incarnation=*/0);
81}
82
83Status SessionMgr::CreateSession(
84 const std::string& session, const ServerDef& server_def,
85 const protobuf::RepeatedPtrField<DeviceAttributes>&
86 cluster_device_attributes,
87 bool isolate_session_state, std::string master_task,
88 int64_t master_incarnation, StatusCallback coordination_error_callback) {
89 mutex_lock l(mu_);
90 if (session.empty()) {
91 return errors::InvalidArgument("Session must be non-empty.");
92 }
93
94 // For given master task name, check if one or more `WorkerSession`s have been
95 // created previously on this worker, and if so garbage collect the expired
96 // `WorkerSession`s. This happens when the master fails before sending
97 // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
98 if (!master_task.empty()) {
99 auto it_range = master_to_associated_sessions_.equal_range(master_task);
100 if (it_range.first != it_range.second &&
101 it_range.first->second.master_incarnation != master_incarnation) {
102 LOG(INFO) << "When creating WorkerSession for master task " << master_task
103 << ", found old WorkerSessions created by the same master task "
104 << "with a different incarnation. These sessions will "
105 << "be garbage collected. Current WorkerSession count: "
106 << sessions_.size();
107
108 auto it = it_range.first;
109 while (it != it_range.second) {
110 auto session_it = sessions_.find(it->second.session_handle);
111 if (session_it != sessions_.end()) {
112 sessions_.erase(session_it);
113 }
114 it = master_to_associated_sessions_.erase(it);
115 }
116 }
117 }
118
119 WorkerCacheInterface* worker_cache = nullptr;
120 std::string worker_name;
121 if (server_def.cluster().job().empty()) {
122 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
123 worker_name = legacy_session_->worker_name();
124 } else {
125 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
126 worker_name = WorkerNameFromServerDef(server_def);
127 }
128
129 if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
130 worker_cache->SetLogging(this->is_logging_active_);
131 }
132
133 CHECK(!worker_env_->local_devices.empty())
134 << "The WorkerEnv must have at least one device in `local_devices`.";
135
136 std::shared_ptr<WorkerSession> worker_session;
137 std::vector<std::unique_ptr<Device>> cluster_devices;
138
139 if (isolate_session_state || server_def.cluster().job_size()) {
140 if (server_def.cluster().job_size()) {
141 VLOG(1) << "ClusterSpec propagation is enabled.";
142 }
143 if (!isolate_session_state) {
144 VLOG(1) << "Session state isolation is disabled.";
145 }
146
147 // Create a private copy of the DeviceMgr for the WorkerSession.
148 std::vector<std::unique_ptr<Device>> renamed_devices;
149 for (Device* d : worker_env_->local_devices) {
150 renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
151 worker_name, d, false, isolate_session_state));
152 }
153 auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
154 LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
155 return device_mgr->LookupDevice(name, device);
156 };
157 AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
158 &cluster_devices);
159 std::unique_ptr<DynamicDeviceMgr> remote_devices;
160 if (!cluster_device_attributes.empty()) {
161 remote_devices = MakeUnique<DynamicDeviceMgr>();
162 TF_RETURN_IF_ERROR(
163 remote_devices->AddDevices(std::move(cluster_devices)));
164 }
165
166 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
167 worker_session.reset(
168 new WorkerSession(session, worker_name,
169 std::unique_ptr<WorkerCacheInterface>(worker_cache),
170 std::move(device_mgr), std::move(graph_mgr),
171 std::move(remote_devices)));
172 } else {
173 AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
174 &cluster_devices);
175 std::unique_ptr<DynamicDeviceMgr> remote_devices;
176 if (!cluster_device_attributes.empty()) {
177 remote_devices = MakeUnique<DynamicDeviceMgr>();
178 TF_RETURN_IF_ERROR(
179 remote_devices->AddDevices(std::move(cluster_devices)));
180 }
181 // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
182 // that resources using it can use its devices after the
183 // WorkerSession has been deleted.
184 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
185 worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
186 session, worker_name,
187 std::unique_ptr<WorkerCacheInterface>(worker_cache),
188 worker_env_->device_mgr, std::move(graph_mgr),
189 std::move(remote_devices));
190 }
191
192 sessions_.insert(std::make_pair(session, std::move(worker_session)));
193 if (!master_task.empty()) {
194 MasterAssociatedSession s{master_incarnation, session};
195 master_to_associated_sessions_.emplace(master_task, s);
196 }
197
198 // If configured, enable coordination service and agent in the first worker
199 // session.
200 const CoordinationServiceConfig& coordination_service_config =
201 server_def.default_session_config().experimental().coordination_config();
202 if (!coordination_service_config.service_type().empty() &&
203 coordination_service_agent_ == nullptr) {
204 std::unique_ptr<CoordinationClientCache> client_cache;
205 TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&client_cache));
206 // Note: If this worker is not the leader, no service instance will be
207 // returned. Hence, only the worker leader in the cluster would hold the
208 // coordination service instance.
209 coordination_service_ =
210 CoordinationServiceInterface::EnableCoordinationService(
211 coordination_service_config.service_type(), worker_env_->env,
212 server_def, std::move(client_cache));
213
214 std::unique_ptr<CoordinationClientCache> agent_cache;
215 TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&agent_cache));
216 coordination_service_agent_ = CreateCoordinationServiceAgent();
217 TF_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
218 worker_env_->env, server_def, std::move(agent_cache),
219 std::move(coordination_error_callback)));
220 activity_watcher::MaybeEnableMultiWorkersWatching(
221 coordination_service_agent_.get());
222 }
223 return OkStatus();
224}
225
226void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
227 default_worker_cache_.reset(worker_cache);
228}
229
230Status SessionMgr::UpdateSession(
231 const std::string& session, const ServerDef& server_def,
232 const protobuf::RepeatedPtrField<DeviceAttributes>&
233 cluster_device_attributes) {
234 mutex_lock l(mu_);
235 if (session.empty()) {
236 return errors::InvalidArgument("Session must be non-empty.");
237 }
238 auto it = sessions_.find(session);
239 if (it == sessions_.end()) {
240 return errors::InvalidArgument("Cannot update session ", session,
241 " because it does not exist.");
242 }
243 std::shared_ptr<WorkerSession> worker_session = it->second;
244
245 WorkerCacheInterface* worker_cache = nullptr;
246 if (server_def.cluster().job().empty()) {
247 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
248 } else {
249 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
250 }
251 std::vector<std::string> updated_remote_workers;
252 worker_cache->ListWorkers(&updated_remote_workers);
253
254 std::vector<std::unique_ptr<Device>> cluster_devices;
255
256 const DeviceMgr* local_device_mgr = worker_session->device_mgr();
257 DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
258 std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
259 std::vector<std::unique_ptr<Device>> added_remote_devices;
260 std::vector<Device*> removed_remote_devices;
261
262 std::vector<DeviceAttributes> added_cluster_device_attrs;
263 for (const auto& da : cluster_device_attributes) {
264 Device* device;
265 if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
266 !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
267 added_cluster_device_attrs.emplace_back(da);
268 } else if (device != nullptr &&
269 device->attributes().incarnation() != da.incarnation()) {
270 removed_remote_devices.emplace_back(device);
271 added_cluster_device_attrs.emplace_back(da);
272 }
273 }
274 for (Device* device : curr_remote_devices) {
275 std::string task_name;
276 DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
277 if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
278 task_name) == updated_remote_workers.end()) {
279 removed_remote_devices.emplace_back(device);
280 }
281 }
282 protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
283 added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
284 AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
285 &added_remote_devices);
286
287 TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
288 std::unique_ptr<WorkerCacheInterface>(worker_cache),
289 std::move(added_remote_devices), removed_remote_devices));
290 return OkStatus();
291}
292
293Status SessionMgr::DeleteSession(const std::string& session) {
294 mutex_lock l(mu_);
295 auto it = sessions_.find(session);
296 if (it != sessions_.end()) {
297 sessions_.erase(it);
298 }
299 return OkStatus();
300}
301
302Status SessionMgr::WorkerSessionForSessionLocked(
303 const std::string& session_handle,
304 std::shared_ptr<WorkerSession>* out_session) {
305 if (session_handle.empty()) {
306 *out_session = legacy_session_;
307 } else {
308 auto it = sessions_.find(session_handle);
309 if (it == sessions_.end()) {
310 return errors::AbortedWithPayloads(
311 strings::StrCat("Session handle is not found: ", session_handle,
312 ". Possibly this worker (\"",
313 legacy_session_->worker_name(),
314 "\") just restarted."),
315 {{kWorkerPossiblyRestarted,
316 distributed_runtime::WorkerPossiblyRestarted()
317 .SerializeAsString()}});
318 } else {
319 *out_session = it->second;
320 }
321 }
322 return OkStatus();
323}
324
325Status SessionMgr::WorkerSessionForSession(
326 const std::string& session_handle,
327 std::shared_ptr<WorkerSession>* out_session) {
328 mutex_lock l(mu_);
329 return WorkerSessionForSessionLocked(session_handle, out_session);
330}
331
332std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
333 return legacy_session_;
334}
335
336CoordinationServiceAgent* SessionMgr::GetCoordinationServiceAgent() {
337 return coordination_service_agent_.get();
338}
339
340void SessionMgr::SetLogging(bool active) {
341 mutex_lock l(mu_);
342 this->is_logging_active_ = active;
343 // Legacy Session
344 if (legacy_session_) {
345 auto* worker_cache = legacy_session_->worker_cache();
346 if (worker_cache) {
347 worker_cache->SetLogging(active);
348 }
349 }
350
351 for (const auto& session_kv : sessions_) {
352 auto session = session_kv.second.get();
353 if (session) {
354 auto* worker_cache = session->worker_cache();
355 if (worker_cache) {
356 worker_cache->SetLogging(active);
357 }
358 }
359 }
360}
361
362void SessionMgr::RetrieveLogs(int64_t step_id, LoggingResponse* response) {
363 mutex_lock l(mu_);
364 // Legacy Session
365 if (legacy_session_) {
366 auto* worker_cache = legacy_session_->worker_cache();
367 if (worker_cache) {
368 auto step_stats = StepStats();
369 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
370 auto* labeled_step_stats = response->add_step();
371 labeled_step_stats->set_step_id(step_id);
372 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
373 }
374 }
375 }
376 for (const auto& session_kv : sessions_) {
377 auto session = session_kv.second.get();
378 if (session) {
379 auto* worker_cache = session->worker_cache();
380 if (worker_cache) {
381 auto step_stats = StepStats();
382 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
383 auto* labeled_step_stats = response->add_step();
384 labeled_step_stats->set_step_id(step_id);
385 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
386 }
387 }
388 }
389 }
390}
391
392void SessionMgr::ClearLogs() {
393 mutex_lock l(mu_);
394 // Legacy Session
395 if (legacy_session_) {
396 auto* worker_cache = legacy_session_->worker_cache();
397 if (worker_cache) {
398 worker_cache->ClearLogs();
399 }
400 }
401
402 for (const auto& session_kv : sessions_) {
403 auto session = session_kv.second.get();
404 if (session) {
405 auto* worker_cache = session->worker_cache();
406 if (worker_cache) {
407 worker_cache->ClearLogs();
408 }
409 }
410 }
411}
412
413void SessionMgr::TeardownCoordinationService() {
414 coordination_service_ = nullptr;
415}
416
417void SessionMgr::TeardownCoordinationServiceAgent() {
418 coordination_service_agent_ = nullptr;
419}
420} // namespace tensorflow
421