1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/session_mgr.h" |
17 | |
18 | #include <algorithm> |
19 | #include <string> |
20 | #include <utility> |
21 | |
22 | #include "tensorflow/core/activity_watcher/activity.h" |
23 | #include "tensorflow/core/common_runtime/device_mgr.h" |
24 | #include "tensorflow/core/common_runtime/renamed_device.h" |
25 | #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h" |
26 | #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h" |
27 | #include "tensorflow/core/distributed_runtime/error_payloads.h" |
28 | #include "tensorflow/core/distributed_runtime/graph_mgr.h" |
29 | #include "tensorflow/core/distributed_runtime/remote_device.h" |
30 | #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" |
31 | #include "tensorflow/core/lib/strings/strcat.h" |
32 | #include "tensorflow/core/protobuf/cluster.pb.h" |
33 | #include "tensorflow/core/protobuf/coordination_config.pb.h" |
34 | #include "tensorflow/core/protobuf/coordination_service.pb.h" |
35 | #include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h" |
36 | #include "tensorflow/core/protobuf/tensorflow_server.pb.h" |
37 | #include "tensorflow/core/util/ptr_util.h" |
38 | |
39 | namespace tensorflow { |
40 | |
// Constructs the SessionMgr and eagerly creates the "legacy" session (keyed
// by the empty session handle) that serves requests carrying no handle.
// Takes ownership of `default_worker_cache`; `worker_env` is borrowed.
SessionMgr::SessionMgr(
    WorkerEnv* worker_env, const std::string& default_worker_name,
    std::unique_ptr<WorkerCacheInterface> default_worker_cache,
    WorkerCacheFactory worker_cache_factory)
    : worker_env_(worker_env),
      default_worker_cache_(std::move(default_worker_cache)),
      // The legacy session borrows the WorkerEnv's DeviceMgr and wraps (does
      // not own) the default worker cache.
      // NOTE(review): this relies on `default_worker_cache_` being declared
      // before `legacy_session_` in the header so it is initialized first —
      // TODO confirm declaration order in session_mgr.h.
      legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
          "", default_worker_name,
          std::unique_ptr<WorkerCacheInterface>(
              new WorkerCacheWrapper(default_worker_cache_.get())),
          worker_env->device_mgr,
          std::unique_ptr<GraphMgr>(
              new GraphMgr(worker_env, worker_env->device_mgr)),
          nullptr)),
      worker_cache_factory_(std::move(worker_cache_factory)) {}
56 | |
57 | /* static */ |
58 | std::string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { |
59 | return strings::StrCat("/job:" , server_def.job_name(), |
60 | "/replica:0/task:" , server_def.task_index()); |
61 | } |
62 | |
63 | Status SessionMgr::CreateSession(const std::string& session, |
64 | const ServerDef& server_def, |
65 | bool isolate_session_state, |
66 | StatusCallback coordination_error_callback) { |
67 | return CreateSession(session, server_def, {}, isolate_session_state, |
68 | /*master_task=*/"" , |
69 | /*master_incarnation=*/0, coordination_error_callback); |
70 | } |
71 | |
72 | Status SessionMgr::CreateSession( |
73 | const std::string& session, const ServerDef& server_def, |
74 | const protobuf::RepeatedPtrField<DeviceAttributes>& |
75 | cluster_device_attributes, |
76 | bool isolate_session_state) { |
77 | return CreateSession(session, server_def, cluster_device_attributes, |
78 | isolate_session_state, |
79 | /*master_task=*/"" , |
80 | /*master_incarnation=*/0); |
81 | } |
82 | |
83 | Status SessionMgr::CreateSession( |
84 | const std::string& session, const ServerDef& server_def, |
85 | const protobuf::RepeatedPtrField<DeviceAttributes>& |
86 | cluster_device_attributes, |
87 | bool isolate_session_state, std::string master_task, |
88 | int64_t master_incarnation, StatusCallback coordination_error_callback) { |
89 | mutex_lock l(mu_); |
90 | if (session.empty()) { |
91 | return errors::InvalidArgument("Session must be non-empty." ); |
92 | } |
93 | |
94 | // For given master task name, check if one or more `WorkerSession`s have been |
95 | // created previously on this worker, and if so garbage collect the expired |
96 | // `WorkerSession`s. This happens when the master fails before sending |
97 | // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked. |
98 | if (!master_task.empty()) { |
99 | auto it_range = master_to_associated_sessions_.equal_range(master_task); |
100 | if (it_range.first != it_range.second && |
101 | it_range.first->second.master_incarnation != master_incarnation) { |
102 | LOG(INFO) << "When creating WorkerSession for master task " << master_task |
103 | << ", found old WorkerSessions created by the same master task " |
104 | << "with a different incarnation. These sessions will " |
105 | << "be garbage collected. Current WorkerSession count: " |
106 | << sessions_.size(); |
107 | |
108 | auto it = it_range.first; |
109 | while (it != it_range.second) { |
110 | auto session_it = sessions_.find(it->second.session_handle); |
111 | if (session_it != sessions_.end()) { |
112 | sessions_.erase(session_it); |
113 | } |
114 | it = master_to_associated_sessions_.erase(it); |
115 | } |
116 | } |
117 | } |
118 | |
119 | WorkerCacheInterface* worker_cache = nullptr; |
120 | std::string worker_name; |
121 | if (server_def.cluster().job().empty()) { |
122 | worker_cache = new WorkerCacheWrapper(default_worker_cache_.get()); |
123 | worker_name = legacy_session_->worker_name(); |
124 | } else { |
125 | TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache)); |
126 | worker_name = WorkerNameFromServerDef(server_def); |
127 | } |
128 | |
129 | if (worker_cache != nullptr && default_worker_cache_ != nullptr) { |
130 | worker_cache->SetLogging(this->is_logging_active_); |
131 | } |
132 | |
133 | CHECK(!worker_env_->local_devices.empty()) |
134 | << "The WorkerEnv must have at least one device in `local_devices`." ; |
135 | |
136 | std::shared_ptr<WorkerSession> worker_session; |
137 | std::vector<std::unique_ptr<Device>> cluster_devices; |
138 | |
139 | if (isolate_session_state || server_def.cluster().job_size()) { |
140 | if (server_def.cluster().job_size()) { |
141 | VLOG(1) << "ClusterSpec propagation is enabled." ; |
142 | } |
143 | if (!isolate_session_state) { |
144 | VLOG(1) << "Session state isolation is disabled." ; |
145 | } |
146 | |
147 | // Create a private copy of the DeviceMgr for the WorkerSession. |
148 | std::vector<std::unique_ptr<Device>> renamed_devices; |
149 | for (Device* d : worker_env_->local_devices) { |
150 | renamed_devices.push_back(RenamedDevice::NewRenamedDevice( |
151 | worker_name, d, false, isolate_session_state)); |
152 | } |
153 | auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices)); |
154 | LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) { |
155 | return device_mgr->LookupDevice(name, device); |
156 | }; |
157 | AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb, |
158 | &cluster_devices); |
159 | std::unique_ptr<DynamicDeviceMgr> remote_devices; |
160 | if (!cluster_device_attributes.empty()) { |
161 | remote_devices = MakeUnique<DynamicDeviceMgr>(); |
162 | TF_RETURN_IF_ERROR( |
163 | remote_devices->AddDevices(std::move(cluster_devices))); |
164 | } |
165 | |
166 | auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get()); |
167 | worker_session.reset( |
168 | new WorkerSession(session, worker_name, |
169 | std::unique_ptr<WorkerCacheInterface>(worker_cache), |
170 | std::move(device_mgr), std::move(graph_mgr), |
171 | std::move(remote_devices))); |
172 | } else { |
173 | AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr, |
174 | &cluster_devices); |
175 | std::unique_ptr<DynamicDeviceMgr> remote_devices; |
176 | if (!cluster_device_attributes.empty()) { |
177 | remote_devices = MakeUnique<DynamicDeviceMgr>(); |
178 | TF_RETURN_IF_ERROR( |
179 | remote_devices->AddDevices(std::move(cluster_devices))); |
180 | } |
181 | // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so |
182 | // that resources using it can use its devices after the |
183 | // WorkerSession has been deleted. |
184 | auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr); |
185 | worker_session = WorkerSession::CreateWithBorrowedDeviceMgr( |
186 | session, worker_name, |
187 | std::unique_ptr<WorkerCacheInterface>(worker_cache), |
188 | worker_env_->device_mgr, std::move(graph_mgr), |
189 | std::move(remote_devices)); |
190 | } |
191 | |
192 | sessions_.insert(std::make_pair(session, std::move(worker_session))); |
193 | if (!master_task.empty()) { |
194 | MasterAssociatedSession s{master_incarnation, session}; |
195 | master_to_associated_sessions_.emplace(master_task, s); |
196 | } |
197 | |
198 | // If configured, enable coordination service and agent in the first worker |
199 | // session. |
200 | const CoordinationServiceConfig& coordination_service_config = |
201 | server_def.default_session_config().experimental().coordination_config(); |
202 | if (!coordination_service_config.service_type().empty() && |
203 | coordination_service_agent_ == nullptr) { |
204 | std::unique_ptr<CoordinationClientCache> client_cache; |
205 | TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&client_cache)); |
206 | // Note: If this worker is not the leader, no service instance will be |
207 | // returned. Hence, only the worker leader in the cluster would hold the |
208 | // coordination service instance. |
209 | coordination_service_ = |
210 | CoordinationServiceInterface::EnableCoordinationService( |
211 | coordination_service_config.service_type(), worker_env_->env, |
212 | server_def, std::move(client_cache)); |
213 | |
214 | std::unique_ptr<CoordinationClientCache> agent_cache; |
215 | TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&agent_cache)); |
216 | coordination_service_agent_ = CreateCoordinationServiceAgent(); |
217 | TF_RETURN_IF_ERROR(coordination_service_agent_->Initialize( |
218 | worker_env_->env, server_def, std::move(agent_cache), |
219 | std::move(coordination_error_callback))); |
220 | activity_watcher::MaybeEnableMultiWorkersWatching( |
221 | coordination_service_agent_.get()); |
222 | } |
223 | return OkStatus(); |
224 | } |
225 | |
226 | void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) { |
227 | default_worker_cache_.reset(worker_cache); |
228 | } |
229 | |
// Updates an existing WorkerSession to match a new ServerDef: rebuilds its
// worker cache and computes the delta of remote devices (added, replaced, and
// removed) from `cluster_device_attributes` and the updated worker list.
Status SessionMgr::UpdateSession(
    const std::string& session, const ServerDef& server_def,
    const protobuf::RepeatedPtrField<DeviceAttributes>&
        cluster_device_attributes) {
  mutex_lock l(mu_);
  if (session.empty()) {
    return errors::InvalidArgument("Session must be non-empty.");
  }
  auto it = sessions_.find(session);
  if (it == sessions_.end()) {
    return errors::InvalidArgument("Cannot update session ", session,
                                   " because it does not exist.");
  }
  std::shared_ptr<WorkerSession> worker_session = it->second;

  // Build a fresh worker cache for the new cluster spec (or re-wrap the
  // default cache when no cluster job is given).
  WorkerCacheInterface* worker_cache = nullptr;
  if (server_def.cluster().job().empty()) {
    worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
  } else {
    TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
  }
  std::vector<std::string> updated_remote_workers;
  worker_cache->ListWorkers(&updated_remote_workers);

  std::vector<std::unique_ptr<Device>> cluster_devices;

  const DeviceMgr* local_device_mgr = worker_session->device_mgr();
  // NOTE(review): `remote_device_mgr` is dereferenced unconditionally below;
  // this assumes the session was created with a remote DeviceMgr — confirm
  // that all callers guarantee this, otherwise this crashes on null.
  DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
  std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
  std::vector<std::unique_ptr<Device>> added_remote_devices;
  std::vector<Device*> removed_remote_devices;

  // Classify each incoming device attribute:
  //  - unknown locally and remotely            -> newly added device;
  //  - known, but with a different incarnation -> task restarted: remove the
  //    stale device and re-add it with the new attributes.
  std::vector<DeviceAttributes> added_cluster_device_attrs;
  for (const auto& da : cluster_device_attributes) {
    Device* device;
    if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
        !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
      added_cluster_device_attrs.emplace_back(da);
    } else if (device != nullptr &&
               device->attributes().incarnation() != da.incarnation()) {
      removed_remote_devices.emplace_back(device);
      added_cluster_device_attrs.emplace_back(da);
    }
  }
  // Any current remote device whose task is no longer in the updated worker
  // list has left the cluster; schedule it for removal.
  for (Device* device : curr_remote_devices) {
    std::string task_name;
    DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
    if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
                  task_name) == updated_remote_workers.end()) {
      removed_remote_devices.emplace_back(device);
    }
  }
  // Materialize Device objects for the added attributes (no local lookup
  // callback: these are all remote).
  protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
      added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
  AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
                  &added_remote_devices);

  // Atomically swap in the new cache and apply the device delta; the session
  // takes ownership of `worker_cache`.
  TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
      std::unique_ptr<WorkerCacheInterface>(worker_cache),
      std::move(added_remote_devices), removed_remote_devices));
  return OkStatus();
}
292 | |
293 | Status SessionMgr::DeleteSession(const std::string& session) { |
294 | mutex_lock l(mu_); |
295 | auto it = sessions_.find(session); |
296 | if (it != sessions_.end()) { |
297 | sessions_.erase(it); |
298 | } |
299 | return OkStatus(); |
300 | } |
301 | |
302 | Status SessionMgr::WorkerSessionForSessionLocked( |
303 | const std::string& session_handle, |
304 | std::shared_ptr<WorkerSession>* out_session) { |
305 | if (session_handle.empty()) { |
306 | *out_session = legacy_session_; |
307 | } else { |
308 | auto it = sessions_.find(session_handle); |
309 | if (it == sessions_.end()) { |
310 | return errors::AbortedWithPayloads( |
311 | strings::StrCat("Session handle is not found: " , session_handle, |
312 | ". Possibly this worker (\"" , |
313 | legacy_session_->worker_name(), |
314 | "\") just restarted." ), |
315 | {{kWorkerPossiblyRestarted, |
316 | distributed_runtime::WorkerPossiblyRestarted() |
317 | .SerializeAsString()}}); |
318 | } else { |
319 | *out_session = it->second; |
320 | } |
321 | } |
322 | return OkStatus(); |
323 | } |
324 | |
// Thread-safe wrapper: acquires `mu_` and delegates to
// WorkerSessionForSessionLocked for the actual lookup.
Status SessionMgr::WorkerSessionForSession(
    const std::string& session_handle,
    std::shared_ptr<WorkerSession>* out_session) {
  mutex_lock l(mu_);
  return WorkerSessionForSessionLocked(session_handle, out_session);
}
331 | |
// Returns the session created at construction time (the one served for an
// empty session handle; see WorkerSessionForSessionLocked).
std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
  return legacy_session_;
}
335 | |
// Returns a borrowed pointer to the coordination service agent, or null if
// no session has enabled the coordination service yet (see CreateSession).
// Ownership stays with this SessionMgr.
CoordinationServiceAgent* SessionMgr::GetCoordinationServiceAgent() {
  return coordination_service_agent_.get();
}
339 | |
340 | void SessionMgr::SetLogging(bool active) { |
341 | mutex_lock l(mu_); |
342 | this->is_logging_active_ = active; |
343 | // Legacy Session |
344 | if (legacy_session_) { |
345 | auto* worker_cache = legacy_session_->worker_cache(); |
346 | if (worker_cache) { |
347 | worker_cache->SetLogging(active); |
348 | } |
349 | } |
350 | |
351 | for (const auto& session_kv : sessions_) { |
352 | auto session = session_kv.second.get(); |
353 | if (session) { |
354 | auto* worker_cache = session->worker_cache(); |
355 | if (worker_cache) { |
356 | worker_cache->SetLogging(active); |
357 | } |
358 | } |
359 | } |
360 | } |
361 | |
362 | void SessionMgr::RetrieveLogs(int64_t step_id, LoggingResponse* response) { |
363 | mutex_lock l(mu_); |
364 | // Legacy Session |
365 | if (legacy_session_) { |
366 | auto* worker_cache = legacy_session_->worker_cache(); |
367 | if (worker_cache) { |
368 | auto step_stats = StepStats(); |
369 | if (worker_cache->RetrieveLogs(step_id, &step_stats)) { |
370 | auto* labeled_step_stats = response->add_step(); |
371 | labeled_step_stats->set_step_id(step_id); |
372 | labeled_step_stats->mutable_step_stats()->Swap(&step_stats); |
373 | } |
374 | } |
375 | } |
376 | for (const auto& session_kv : sessions_) { |
377 | auto session = session_kv.second.get(); |
378 | if (session) { |
379 | auto* worker_cache = session->worker_cache(); |
380 | if (worker_cache) { |
381 | auto step_stats = StepStats(); |
382 | if (worker_cache->RetrieveLogs(step_id, &step_stats)) { |
383 | auto* labeled_step_stats = response->add_step(); |
384 | labeled_step_stats->set_step_id(step_id); |
385 | labeled_step_stats->mutable_step_stats()->Swap(&step_stats); |
386 | } |
387 | } |
388 | } |
389 | } |
390 | } |
391 | |
392 | void SessionMgr::ClearLogs() { |
393 | mutex_lock l(mu_); |
394 | // Legacy Session |
395 | if (legacy_session_) { |
396 | auto* worker_cache = legacy_session_->worker_cache(); |
397 | if (worker_cache) { |
398 | worker_cache->ClearLogs(); |
399 | } |
400 | } |
401 | |
402 | for (const auto& session_kv : sessions_) { |
403 | auto session = session_kv.second.get(); |
404 | if (session) { |
405 | auto* worker_cache = session->worker_cache(); |
406 | if (worker_cache) { |
407 | worker_cache->ClearLogs(); |
408 | } |
409 | } |
410 | } |
411 | } |
412 | |
413 | void SessionMgr::TeardownCoordinationService() { |
414 | coordination_service_ = nullptr; |
415 | } |
416 | |
417 | void SessionMgr::TeardownCoordinationServiceAgent() { |
418 | coordination_service_agent_ = nullptr; |
419 | } |
420 | } // namespace tensorflow |
421 | |