1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Master implements the service MasterService. |
17 | // |
18 | // A Master maintains the state of live graph computation |
19 | // sessions, each session orchestrates both local and remote devices |
20 | // to carry out the graph computation. |
21 | // |
22 | // A Master knows ahead of time local devices available as |
23 | // client devices. |
24 | // |
25 | // A Master discovers remote devices on-demand and keeps track of |
26 | // statistics of those remote devices. |
27 | // |
28 | // Each session analyzes the graph, places nodes across available |
29 | // devices, and ultimately drives the graph computation by initiating |
30 | // RunGraph on the workers. |
31 | |
32 | #include "tensorflow/core/distributed_runtime/master.h" |
33 | |
34 | #include <unordered_set> |
35 | #include <vector> |
36 | |
37 | #include "tensorflow/core/common_runtime/device_set.h" |
38 | #include "tensorflow/core/common_runtime/process_util.h" |
39 | #include "tensorflow/core/distributed_runtime/remote_device.h" |
40 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
41 | #include "tensorflow/core/distributed_runtime/worker_interface.h" |
42 | #include "tensorflow/core/framework/graph_def_util.h" |
43 | #include "tensorflow/core/lib/core/errors.h" |
44 | #include "tensorflow/core/lib/core/notification.h" |
45 | #include "tensorflow/core/lib/gtl/array_slice.h" |
46 | #include "tensorflow/core/lib/gtl/cleanup.h" |
47 | #include "tensorflow/core/lib/gtl/map_util.h" |
48 | #include "tensorflow/core/lib/strings/str_util.h" |
49 | #include "tensorflow/core/platform/macros.h" |
50 | #include "tensorflow/core/platform/mutex.h" |
51 | #include "tensorflow/core/platform/regexp.h" |
52 | #include "tensorflow/core/platform/types.h" |
53 | #include "tensorflow/core/protobuf/cluster.pb.h" |
54 | #include "tensorflow/core/protobuf/master.pb.h" |
55 | #include "tensorflow/core/protobuf/worker.pb.h" |
56 | #include "tensorflow/core/public/session_options.h" |
57 | #include "tensorflow/core/util/device_name_utils.h" |
58 | |
59 | namespace tensorflow { |
60 | |
// Regex matching a leading gRPC scheme (e.g. "grpc://") on a session target.
// CreateSession strips it so the remaining "host:port" can be compared with
// the task addresses of a client-supplied ClusterDef. `static` gives the
// constant internal linkage, so it is private to this translation unit.
static constexpr char kGrpcPrefixRegex[] = "^grpc.*://";
64 | |
65 | Master::Master(MasterEnv* env, double session_gc_seconds) |
66 | : env_(env), |
67 | last_1000_steps_(1000), |
68 | step_count_(0), |
69 | session_gc_seconds_(session_gc_seconds), |
70 | recent_request_ids_(10000, env_->experimental_num_shards) { |
71 | // Right now, a master service must be co-located with a device. |
72 | // Otherwise, fetches do not work. |
73 | CHECK(!env->local_devices.empty()); |
74 | DCHECK_GT(env_->experimental_num_shards, 0); |
75 | |
76 | if (session_gc_seconds_ > 0.0) { |
77 | gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC" , |
78 | [this]() { GC(); }); |
79 | } else { |
80 | gc_thread_ = nullptr; |
81 | } |
82 | } |
83 | |
84 | Master::~Master() { |
85 | if (gc_thread_) { |
86 | mutex_lock l(mu_); |
87 | shutdown_ = true; |
88 | shutdown_cv_.notify_all(); |
89 | delete gc_thread_; |
90 | } |
91 | } |
92 | |
93 | void Master::GC() { |
94 | Env* env = Env::Default(); |
95 | while (true) { |
96 | mutex_lock l(mu_); |
97 | const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds. |
98 | WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds); |
99 | if (shutdown_) { |
100 | break; |
101 | } |
102 | std::vector<string> handles; |
103 | const int64_t num_micros = |
104 | static_cast<int64_t>(session_gc_seconds_ * 1000000); |
105 | for (const auto& entry : sessions_) { |
106 | int64_t lat = entry.second->last_access_time_usec(); |
107 | if (static_cast<int64_t>(env->NowMicros()) - lat > num_micros) { |
108 | handles.push_back(entry.first); |
109 | auto* sess = entry.second; |
110 | SchedClosure([this, sess]() { |
111 | LOG(WARNING) << "GC session " << sess->handle() << " after " |
112 | << session_gc_seconds_ << " seconds. " |
113 | << "Note that if you are starting multiple replicas " |
114 | << "on a staggered delay, session_gc_seconds may need " |
115 | << "to be raised." ; |
116 | sess->GarbageCollect(); |
117 | }); |
118 | } |
119 | } |
120 | for (const auto& handle : handles) sessions_.erase(handle); |
121 | } |
122 | } |
123 | |
124 | MasterSession* Master::FindMasterSession(const string& handle) { |
125 | MasterSession* session = nullptr; |
126 | { |
127 | mutex_lock l(mu_); |
128 | session = gtl::FindPtrOrNull(sessions_, handle); |
129 | if (session != nullptr) { |
130 | session->Ref(); |
131 | } |
132 | } |
133 | return session; |
134 | } |
135 | |
136 | class DeviceFinder { |
137 | public: |
138 | static Status GetRemoteDevices( |
139 | const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env, |
140 | WorkerCacheInterface* worker_cache, |
141 | std::vector<std::unique_ptr<Device>>* out_remote) { |
142 | DeviceFinder finder(device_filters, env, worker_cache); |
143 | finder.Start(); |
144 | TF_RETURN_IF_ERROR(finder.Wait()); |
145 | finder.GetRemoteDevices(env->local_devices, out_remote); |
146 | return OkStatus(); |
147 | } |
148 | |
149 | static void GetRemoteWorkers( |
150 | const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env, |
151 | WorkerCacheInterface* worker_cache, std::vector<string>* workers) { |
152 | DeviceFinder finder(device_filters, env, worker_cache); |
153 | *workers = finder.targets_; |
154 | } |
155 | |
156 | private: |
157 | explicit DeviceFinder( |
158 | const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env, |
159 | WorkerCacheInterface* worker_cache) |
160 | : env_(env), worker_cache_(worker_cache) { |
161 | CHECK(worker_cache) << "Worker cache was null!" ; |
162 | auto process_filter = [this](const string& filter) { |
163 | DeviceNameUtils::ParsedName parsed; |
164 | if (DeviceNameUtils::ParseFullName(filter, &parsed)) { |
165 | filters_.push_back(parsed); |
166 | } else { |
167 | LOG(FATAL) << "Skipping invalid filter: " << filter; |
168 | } |
169 | }; |
170 | for (const string& filter : device_filters) { |
171 | process_filter(filter); |
172 | } |
173 | // Enumerates all known workers' target. A target name is a |
174 | // prefix of a device name. E.g., /job:mnist/replica:0/task:10. |
175 | if (filters_.empty()) { |
176 | // If no filters were specified, we list all known workers in |
177 | // `worker_cache`. |
178 | std::vector<string> workers; |
179 | worker_cache->ListWorkers(&workers); |
180 | std::swap(workers, targets_); |
181 | } else { |
182 | // When applying filters, we must include the local worker, even if it |
183 | // does not match any of the filters. |
184 | CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided." ; |
185 | const string& local_device_name = env_->local_devices[0]->name(); |
186 | DeviceNameUtils::ParsedName local_parsed_name; |
187 | CHECK(DeviceNameUtils::ParseFullName(local_device_name, |
188 | &local_parsed_name)); |
189 | bool all_filters_have_job = true; |
190 | std::unordered_set<string> filter_job_names({local_parsed_name.job}); |
191 | for (const DeviceNameUtils::ParsedName& filter : filters_) { |
192 | all_filters_have_job = all_filters_have_job && filter.has_job; |
193 | if (filter.has_job) { |
194 | filter_job_names.insert(filter.job); |
195 | } |
196 | } |
197 | |
198 | std::vector<string> workers; |
199 | if (all_filters_have_job) { |
200 | // If all of the device filters have a job specified, then we only need |
201 | // to list the workers in the jobs named in the filter, because a worker |
202 | // in any other job would not match any filter. |
203 | for (const string& job_name : filter_job_names) { |
204 | VLOG(2) << "Selectively listing workers in job: " << job_name; |
205 | std::vector<string> workers_in_job; |
206 | worker_cache->ListWorkersInJob(job_name, &workers_in_job); |
207 | workers.insert(workers.end(), workers_in_job.begin(), |
208 | workers_in_job.end()); |
209 | } |
210 | } else { |
211 | // If any of the device filters does not have a job specified, then we |
212 | // must list the workers from all jobs. |
213 | VLOG(2) << "Listing workers in all jobs because some device " |
214 | << "filter has no job specified. Filters were:" ; |
215 | if (device_filters.empty()) { |
216 | VLOG(2) << "- <NO FILTERS>" ; |
217 | } else { |
218 | for (const string& filter : device_filters) { |
219 | VLOG(2) << "- " << filter; |
220 | } |
221 | } |
222 | worker_cache->ListWorkers(&workers); |
223 | } |
224 | for (const string& name : workers) { |
225 | if (MatchFilters(name) || |
226 | DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) { |
227 | targets_.push_back(name); |
228 | } |
229 | } |
230 | } |
231 | seen_targets_.assign(targets_.size(), false); |
232 | } |
233 | |
234 | ~DeviceFinder() { |
235 | for (Device* dev : found_) delete dev; |
236 | } |
237 | |
238 | void Start() { |
239 | { |
240 | mutex_lock l(mu_); |
241 | num_pending_ = targets_.size(); |
242 | if (num_pending_ == 0) { |
243 | pending_zero_.notify_all(); |
244 | } |
245 | } |
246 | // Talk to all workers to get the list of available devices. |
247 | using std::placeholders::_1; |
248 | using std::placeholders::_2; |
249 | for (size_t i = 0; i < targets_.size(); ++i) { |
250 | // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may |
251 | // never be called. |
252 | NewRemoteDevices(env_->env, worker_cache_, targets_[i], |
253 | std::bind(&ME::WhenFound, this, i, _1, _2)); |
254 | } |
255 | } |
256 | |
257 | // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting |
258 | // to hear from workers, log a list of the workers who have not |
259 | // responded. |
260 | const int32 kLoggingPeriodMs = 10 * 1000; |
261 | |
262 | Status Wait() { |
263 | mutex_lock l(mu_); |
264 | // TODO(mrry): Propagate a timeout here, since `num_pending_` may |
265 | // never become zero. |
266 | while (num_pending_ != 0) { |
267 | pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs)); |
268 | if (num_pending_ != 0) { |
269 | for (size_t i = 0; i < targets_.size(); ++i) { |
270 | if (!seen_targets_[i]) { |
271 | LOG(INFO) |
272 | << "CreateSession still waiting for response from worker: " |
273 | << targets_[i]; |
274 | } |
275 | } |
276 | } |
277 | } |
278 | return status_; |
279 | } |
280 | |
281 | // The caller takes the ownership of returned remote devices. |
282 | void GetRemoteDevices(const std::vector<Device*>& local, |
283 | std::vector<std::unique_ptr<Device>>* remote) { |
284 | std::unordered_set<string> names(local.size()); |
285 | for (Device* dev : local) names.insert(dev->name()); |
286 | mutex_lock l(mu_); |
287 | for (Device* dev : found_) { |
288 | const string& name = dev->name(); |
289 | if (names.insert(name).second && MatchFilters(name)) { |
290 | remote->push_back(std::unique_ptr<Device>(dev)); |
291 | } else { |
292 | delete dev; |
293 | } |
294 | } |
295 | found_.clear(); |
296 | } |
297 | |
298 | typedef DeviceFinder ME; |
299 | const MasterEnv* env_; |
300 | WorkerCacheInterface* worker_cache_; |
301 | std::vector<DeviceNameUtils::ParsedName> filters_; |
302 | |
303 | mutex mu_; |
304 | int num_pending_ TF_GUARDED_BY(mu_); |
305 | condition_variable pending_zero_; |
306 | std::vector<Device*> found_ TF_GUARDED_BY(mu_); |
307 | // List of targets to be contacted by this DeviceFinder. The |
308 | // respective `bool` in `seen_targets_` indicates whether we have |
309 | // heard from this target or not. |
310 | std::vector<string> targets_; |
311 | std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_); |
312 | Status status_; |
313 | |
314 | void WhenFound(int target_index, const Status& s, |
315 | std::vector<Device*>* devices) { |
316 | mutex_lock l(mu_); |
317 | seen_targets_[target_index] = true; |
318 | if (!s.ok()) { |
319 | LOG(ERROR) << "CreateSession failed because worker " |
320 | << targets_[target_index] << " returned error: " << s; |
321 | status_.Update(s); |
322 | } else { |
323 | found_.insert(found_.end(), devices->begin(), devices->end()); |
324 | devices->clear(); |
325 | } |
326 | --num_pending_; |
327 | if (num_pending_ == 0) { |
328 | pending_zero_.notify_all(); |
329 | } |
330 | } |
331 | |
332 | // Returns true iff the set of devices allowed by 'x' intersects |
333 | // with the set of devices allowed by 'y'. |
334 | bool Intersects(const DeviceNameUtils::ParsedName& x, |
335 | const DeviceNameUtils::ParsedName& y) { |
336 | return (!x.has_job || !y.has_job || x.job == y.job) && |
337 | (!x.has_replica || !y.has_replica || x.replica == y.replica) && |
338 | (!x.has_task || !y.has_task || x.task == y.task) && |
339 | (!x.has_type || !y.has_type || x.type == y.type) && |
340 | (!x.has_id || !y.has_id || x.id == y.id); |
341 | } |
342 | |
343 | // Returns true iff 'name' matches one of the filters_. |
344 | bool MatchFilters(const string& name) { |
345 | if (filters_.empty()) return true; |
346 | DeviceNameUtils::ParsedName x; |
347 | if (DeviceNameUtils::ParseFullName(name, &x)) { |
348 | for (const auto& filter : filters_) { |
349 | if (Intersects(x, filter)) return true; |
350 | } |
351 | } |
352 | return false; |
353 | } |
354 | |
355 | TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder); |
356 | }; |
357 | |
358 | void Master::CreateSession(const CreateSessionRequest* req, |
359 | CreateSessionResponse* resp, MyClosure done) { |
360 | SchedClosure([this, req, resp, done]() { |
361 | Status status; |
362 | WorkerCacheFactoryOptions worker_cache_factory_options; |
363 | string grpc_protocol("grpc" ); |
364 | worker_cache_factory_options.protocol = &grpc_protocol; |
365 | auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); |
366 | status = ValidateExternalGraphDefSyntax(req->graph_def()); |
367 | if (!status.ok()) return; |
368 | |
369 | // The following 4 variables are set differently, depending on whether this |
370 | // session uses a client-provided clusterspec or not. |
371 | WorkerCacheInterface* worker_cache = nullptr; |
372 | // Note: worker_cache_ptr will be null except if this session is using a |
373 | // client-supplied ClusterDef (ClusterSpec propagation). |
374 | std::unique_ptr<WorkerCacheInterface> worker_cache_ptr; |
375 | std::unique_ptr<DeviceSet> device_set; |
376 | // TODO(saeta): Convert to std::make_unique when available. |
377 | std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices( |
378 | new std::vector<std::unique_ptr<Device>>()); |
379 | |
380 | const ClusterDef& cluster_def = req->config().cluster_def(); |
381 | if (!cluster_def.job().empty()) { |
382 | worker_cache_factory_options.cluster_def = &cluster_def; |
383 | // If the target starts with gRPC protocol prefix, remove the prefix |
384 | string normalized_string(req->target()); |
385 | RE2::Replace(&normalized_string, kGrpcPrefixRegex, "" ); |
386 | |
387 | // Set the server_def's job_name and task_index fields. |
388 | for (auto&& job : cluster_def.job()) { |
389 | for (auto&& task : job.tasks()) { |
390 | if (task.second == normalized_string) { |
391 | if (worker_cache_factory_options.job_name != nullptr) { |
392 | status = errors::InvalidArgument( |
393 | "Found multiple matching tasks that correspond to " |
394 | "to the master. Master target: '" , |
395 | req->target(), |
396 | "'. ClusterDef: " , cluster_def.ShortDebugString()); |
397 | LOG(ERROR) << status; |
398 | return; |
399 | } |
400 | if (env_->local_devices[0]->parsed_name().job == job.name() && |
401 | env_->local_devices[0]->parsed_name().task == task.first) { |
402 | // TODO(b/37868888): Remove this limitation when resolved |
403 | status = errors::InvalidArgument( |
404 | "The ClusterSpec names the job and task index to be the same " |
405 | "names that were provided when the server booted. This is " |
406 | "currently not allowed. Job: " , |
407 | job.name(), ", task index: " , task.first); |
408 | return; |
409 | } |
410 | worker_cache_factory_options.job_name = &job.name(); |
411 | worker_cache_factory_options.task_index = task.first; |
412 | } |
413 | } |
414 | } |
415 | worker_cache_factory_options.rpc_options = &req->config().rpc_options(); |
416 | // Create the worker cache from the computed server_def. |
417 | status = env_->worker_cache_factory(worker_cache_factory_options, |
418 | &worker_cache); |
419 | if (!status.ok()) return; |
420 | worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache); |
421 | // Ping all the workers and build the list of devices that the |
422 | // session will use. |
423 | status = |
424 | DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, |
425 | worker_cache, remote_devices.get()); |
426 | if (!status.ok()) return; |
427 | device_set.reset(new DeviceSet); |
428 | for (auto&& d : *remote_devices) { |
429 | device_set->AddDevice(d.get()); |
430 | DeviceNameUtils::ParsedName name = d->parsed_name(); |
431 | if (name.job == *worker_cache_factory_options.job_name && |
432 | name.task == worker_cache_factory_options.task_index && |
433 | name.type == "CPU" && name.id == 0) { |
434 | device_set->set_client_device(d.get()); |
435 | } |
436 | } |
437 | } else { |
438 | worker_cache = env_->worker_cache; |
439 | // Ping all the workers and build the list of devices that the |
440 | // session will use. |
441 | status = |
442 | DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_, |
443 | worker_cache, remote_devices.get()); |
444 | if (!status.ok()) return; |
445 | device_set.reset(new DeviceSet); |
446 | for (auto&& d : *remote_devices) { |
447 | device_set->AddDevice(d.get()); |
448 | } |
449 | int num_local_devices = 0; |
450 | for (Device* d : env_->local_devices) { |
451 | device_set->AddDevice(d); |
452 | if (num_local_devices == 0) { |
453 | // Uses the first local device as the client device. |
454 | device_set->set_client_device(d); |
455 | } |
456 | num_local_devices++; |
457 | } |
458 | } |
459 | |
460 | CHECK(device_set->client_device()) << "No client device found. Missing " |
461 | << "CPU:0 device?" ; |
462 | |
463 | SessionOptions options; |
464 | options.target = req->target(); |
465 | options.config = req->config(); |
466 | |
467 | std::vector<string> filtered_worker_list; |
468 | DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_, |
469 | worker_cache, &filtered_worker_list); |
470 | |
471 | MasterSession* session = env_->master_session_factory( |
472 | options, env_, std::move(remote_devices), std::move(worker_cache_ptr), |
473 | std::move(device_set), std::move(filtered_worker_list)); |
474 | |
475 | GraphDef* gdef = |
476 | const_cast<CreateSessionRequest*>(req)->mutable_graph_def(); |
477 | |
478 | status = session->Create(std::move(*gdef), cluster_def); |
479 | if (!status.ok()) { |
480 | session->Close().IgnoreError(); |
481 | session->Unref(); |
482 | return; |
483 | } |
484 | resp->set_session_handle(session->handle()); |
485 | // Insert into the session map, which takes ownership of the session. |
486 | { |
487 | mutex_lock l(mu_); |
488 | CHECK(sessions_.insert({session->handle(), session}).second); |
489 | } |
490 | }); |
491 | } |
492 | |
493 | void Master::ExtendSession(const ExtendSessionRequest* req, |
494 | ExtendSessionResponse* resp, MyClosure done) { |
495 | auto session = FindMasterSession(req->session_handle()); |
496 | if (session == nullptr) { |
497 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
498 | return; |
499 | } |
500 | |
501 | SchedClosure([session, req, resp, done]() { |
502 | Status status = ValidateExternalGraphDefSyntax(req->graph_def()); |
503 | if (status.ok()) { |
504 | status = session->Extend(req, resp); |
505 | } |
506 | session->Unref(); |
507 | done(status); |
508 | }); |
509 | } |
510 | |
511 | void Master::PartialRunSetup(const PartialRunSetupRequest* req, |
512 | PartialRunSetupResponse* resp, MyClosure done) { |
513 | Status s = recent_request_ids_.TrackUnique(req->request_id(), |
514 | "PartialRunSetup (Master)" , *req); |
515 | if (!s.ok()) { |
516 | done(s); |
517 | return; |
518 | } |
519 | auto session = FindMasterSession(req->session_handle()); |
520 | if (session == nullptr) { |
521 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
522 | return; |
523 | } |
524 | |
525 | SchedClosure([session, req, resp, done]() { |
526 | Status s = session->PartialRunSetup(req, resp); |
527 | session->Unref(); |
528 | done(s); |
529 | }); |
530 | } |
531 | |
532 | void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req, |
533 | MutableRunStepResponseWrapper* resp, MyClosure done) { |
534 | Status s = recent_request_ids_.TrackUnique(req->request_id(), |
535 | "RunStep (Master)" , req); |
536 | if (!s.ok()) { |
537 | done(s); |
538 | return; |
539 | } |
540 | auto start_time = env_->env->NowMicros(); |
541 | auto session = FindMasterSession(req->session_handle()); |
542 | if (session == nullptr) { |
543 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
544 | return; |
545 | } |
546 | |
547 | SchedClosure([this, start_time, session, opts, req, resp, done]() { |
548 | Status status = session->Run(opts, *req, resp); |
549 | session->Unref(); |
550 | uint64 done_time = env_->env->NowMicros(); |
551 | done(status); |
552 | mutex_lock l(mu_); |
553 | last_1000_steps_.AddValue((done_time - start_time) / 1e9); |
554 | ++step_count_; |
555 | }); |
556 | } |
557 | |
558 | void Master::CloseSession(const CloseSessionRequest* req, |
559 | CloseSessionResponse* resp, MyClosure done) { |
560 | MasterSession* session = nullptr; |
561 | { |
562 | mu_.lock(); |
563 | auto iter = sessions_.find(req->session_handle()); |
564 | if (iter == sessions_.end()) { |
565 | mu_.unlock(); |
566 | done(errors::Aborted( |
567 | "Session " , req->session_handle(), |
568 | " is not found. Possibly, this master has restarted." )); |
569 | return; |
570 | } |
571 | // NOTE(mrry): One reference to the session is transferred from |
572 | // `sessions_[req->session_handle()]` to `session`. |
573 | session = iter->second; |
574 | sessions_.erase(iter); |
575 | mu_.unlock(); |
576 | } |
577 | |
578 | // Session Close() blocks on thread shutdown. Therefore, we need to |
579 | // delete it in non-critical thread. |
580 | SchedClosure([session, done]() { |
581 | Status s = session->Close(); |
582 | session->Unref(); |
583 | done(s); |
584 | }); |
585 | } |
586 | |
587 | void Master::ListDevices(const ListDevicesRequest* req, |
588 | ListDevicesResponse* resp, MyClosure done) { |
589 | SchedClosure([this, req, resp, done]() { |
590 | if (!req->session_handle().empty()) { |
591 | auto session = FindMasterSession(req->session_handle()); |
592 | if (session == nullptr) { |
593 | done(errors::InvalidArgument( |
594 | "Session " , req->session_handle(), |
595 | " is not found. Possibly, this master has restarted." )); |
596 | return; |
597 | } |
598 | core::ScopedUnref ref(session); |
599 | Status s = session->ListDevices(resp); |
600 | done(s); |
601 | return; |
602 | } |
603 | std::vector<std::unique_ptr<Device>> remote_devices; |
604 | Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache, |
605 | &remote_devices); |
606 | if (s.ok()) { |
607 | for (Device* dev : env_->local_devices) { |
608 | *(resp->add_local_device()) = dev->attributes(); |
609 | } |
610 | for (auto&& dev : remote_devices) { |
611 | *(resp->add_remote_device()) = dev->attributes(); |
612 | } |
613 | } |
614 | done(s); |
615 | }); |
616 | } |
617 | |
618 | void Master::CleanupWorkers(const ResetRequest& reset) { |
619 | std::vector<string> worker_names; |
620 | DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_, |
621 | env_->worker_cache, &worker_names); |
622 | if (!worker_names.empty()) { |
623 | const int num_workers = worker_names.size(); |
624 | std::vector<Notification> n(num_workers); |
625 | CleanupAllRequest req; |
626 | (*req.mutable_container()) = reset.container(); |
627 | std::vector<CleanupAllResponse> resp(num_workers); |
628 | int c = 0; |
629 | for (int i = 0; i < num_workers; ++i) { |
630 | const string& worker_name = worker_names[i]; |
631 | auto worker = env_->worker_cache->GetOrCreateWorker(worker_name); |
632 | if (worker) { |
633 | worker->CleanupAllAsync( |
634 | &req, &resp[i], [this, &n, worker_name, worker, c](Status s) { |
635 | TF_CHECK_OK(s); |
636 | env_->worker_cache->ReleaseWorker(worker_name, worker); |
637 | n[c].Notify(); |
638 | }); |
639 | } else { |
640 | n[c].Notify(); |
641 | } |
642 | ++c; |
643 | } |
644 | for (size_t i = 0; i < n.size(); ++i) { |
645 | n[i].WaitForNotification(); |
646 | } |
647 | } |
648 | } |
649 | |
650 | void Master::Reset(const ResetRequest* req, ResetResponse* resp, |
651 | MyClosure done) { |
652 | // Vector to hold the session pointers present in the sessions_ |
653 | // (string->Session*) map. |
654 | std::vector<MasterSession*> sessions_to_close; |
655 | { |
656 | mutex_lock l(mu_); |
657 | // NOTE(mrry): Transfer one reference to each session from the |
658 | // `sessions_` map to the `sessions_to_close` vector. |
659 | for (const auto& entry : sessions_) { |
660 | sessions_to_close.push_back(entry.second); |
661 | } |
662 | sessions_.clear(); |
663 | } |
664 | |
665 | CleanupWorkers(*req); |
666 | |
667 | SchedClosure([sessions_to_close, done]() { |
668 | Status s; |
669 | for (MasterSession* session : sessions_to_close) { |
670 | s.Update(session->Close()); |
671 | session->Unref(); |
672 | } |
673 | done(s); |
674 | }); |
675 | } |
676 | |
677 | void Master::MakeCallable(const MakeCallableRequest* req, |
678 | MakeCallableResponse* resp, MyClosure done) { |
679 | Status s = recent_request_ids_.TrackUnique(req->request_id(), |
680 | "MakeCallable (Master)" , *req); |
681 | if (!s.ok()) { |
682 | done(s); |
683 | return; |
684 | } |
685 | auto session = FindMasterSession(req->session_handle()); |
686 | if (session == nullptr) { |
687 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
688 | return; |
689 | } |
690 | |
691 | SchedClosure([session, req, resp, done = std::move(done)]() { |
692 | Status s = session->MakeCallable(*req, resp); |
693 | session->Unref(); |
694 | done(s); |
695 | }); |
696 | } |
697 | |
698 | void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req, |
699 | RunCallableResponse* resp, MyClosure done) { |
700 | Status s = recent_request_ids_.TrackUnique(req->request_id(), |
701 | "RunCallable (Master)" , *req); |
702 | if (!s.ok()) { |
703 | done(s); |
704 | return; |
705 | } |
706 | auto session = FindMasterSession(req->session_handle()); |
707 | if (session == nullptr) { |
708 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
709 | return; |
710 | } |
711 | |
712 | SchedClosure([session, opts, req, resp, done = std::move(done)]() { |
713 | Status s = session->RunCallable(opts, *req, resp); |
714 | session->Unref(); |
715 | done(s); |
716 | }); |
717 | } |
718 | |
719 | void Master::ReleaseCallable(const ReleaseCallableRequest* req, |
720 | ReleaseCallableResponse* resp, MyClosure done) { |
721 | auto session = FindMasterSession(req->session_handle()); |
722 | if (session == nullptr) { |
723 | done(errors::Aborted("Session " , req->session_handle(), " is not found." )); |
724 | return; |
725 | } |
726 | |
727 | SchedClosure([session, req, resp, done = std::move(done)]() { |
728 | Status s = session->ReleaseCallable(*req, resp); |
729 | session->Unref(); |
730 | done(s); |
731 | }); |
732 | } |
733 | |
734 | } // end namespace tensorflow |
735 | |