1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/distributed_runtime/worker_cache_partial.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/process_util.h" |
19 | #include "tensorflow/core/distributed_runtime/worker_interface.h" |
20 | #include "tensorflow/core/lib/core/errors.h" |
21 | #include "tensorflow/core/lib/core/status.h" |
22 | #include "tensorflow/core/platform/logging.h" |
23 | #include "tensorflow/core/platform/mutex.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | #include "tensorflow/core/util/device_name_utils.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | bool WorkerCachePartial::GetDeviceLocalityNonBlocking( |
30 | const string& device_name, DeviceLocality* locality) { |
31 | mutex_lock lock(mu_); // could use reader lock |
32 | auto iter = device_status_cache_.find(device_name); |
33 | if (iter != device_status_cache_.end()) { |
34 | *locality = iter->second.locality(); |
35 | return true; |
36 | } |
37 | return false; |
38 | } |
39 | |
40 | void WorkerCachePartial::GetDeviceLocalityAsync(const string& device_name, |
41 | DeviceLocality* locality, |
42 | StatusCallback done) { |
43 | if (!GetDeviceLocalityNonBlocking(device_name, locality)) { |
44 | // If cache entry was empty, make one try to fill it by RPC. |
45 | SchedClosure([this, &device_name, locality, done]() { |
46 | Status s = RefreshDeviceStatus(device_name); |
47 | if (s.ok() && !GetDeviceLocalityNonBlocking(device_name, locality)) { |
48 | s = errors::Unavailable("No known remote device: " , device_name); |
49 | } |
50 | done(s); |
51 | }); |
52 | return; |
53 | } |
54 | done(OkStatus()); |
55 | } |
56 | |
57 | Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) { |
58 | string task; |
59 | string device; |
60 | Status s; |
61 | if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) { |
62 | s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: " , |
63 | device_name); |
64 | } |
65 | auto deleter = [this, &task](WorkerInterface* wi) { |
66 | ReleaseWorker(task, wi); |
67 | }; |
68 | std::unique_ptr<WorkerInterface, decltype(deleter)> rwi( |
69 | GetOrCreateWorker(task), deleter); |
70 | if (s.ok() && !rwi) { |
71 | s = errors::Internal("RefreshDeviceStatus, unknown worker task: " , task); |
72 | } |
73 | |
74 | if (s.ok()) { |
75 | GetStatusRequest req; |
76 | GetStatusResponse resp; |
77 | s = rwi->GetStatus(&req, &resp); |
78 | if (s.ok()) { |
79 | mutex_lock lock(mu_); |
80 | for (auto& dev_attr : resp.device_attributes()) { |
81 | device_status_cache_[dev_attr.name()] = dev_attr; |
82 | } |
83 | } |
84 | } |
85 | return s; |
86 | } |
87 | |
88 | void WorkerCachePartial::FlushStatusCache() { |
89 | mutex_lock lock(mu_); |
90 | device_status_cache_.clear(); |
91 | } |
92 | |
93 | } // namespace tensorflow |
94 | |