1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/distributed_runtime/worker_session.h" |
16 | |
17 | #include "tensorflow/core/lib/monitoring/collection_registry.h" |
18 | #include "tensorflow/core/lib/monitoring/gauge.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | namespace { |
23 | |
// Process-wide monitoring gauge. Set to true the first time any
// WorkerSession is constructed (see the constructors below); it is never
// reset, so it records "at least one worker session was ever created".
auto* worker_session_created =
    monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created" ,
                                    "True if a worker session was created." );
27 | |
28 | // A private cache that wraps worker_cache and allows reuse of |
29 | // WorkerInterface objects. |
30 | class WorkerFreeListCache : public WorkerCacheInterface { |
31 | public: |
32 | explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w) |
33 | : wrapped_(std::move(w)) {} |
34 | |
35 | ~WorkerFreeListCache() final { |
36 | for (auto& p : workers_) { |
37 | wrapped_->ReleaseWorker(p.first, p.second.worker); |
38 | } |
39 | } |
40 | |
41 | void ListWorkers(std::vector<string>* workers) const override { |
42 | wrapped_->ListWorkers(workers); |
43 | } |
44 | |
45 | void ListWorkersInJob(const string& job_name, |
46 | std::vector<string>* workers) const override { |
47 | wrapped_->ListWorkersInJob(job_name, workers); |
48 | } |
49 | |
50 | WorkerInterface* GetOrCreateWorker(const string& target) override { |
51 | { |
52 | // Fast path if worker has been created. |
53 | tf_shared_lock l(mu_); |
54 | auto p = workers_.find(target); |
55 | if (p != workers_.end()) { |
56 | return p->second.worker; |
57 | } |
58 | } |
59 | { |
60 | // Slow path if worker hasn't been created. |
61 | mutex_lock l(mu_); |
62 | auto p = workers_.find(target); |
63 | if (p != workers_.end()) { |
64 | return p->second.worker; |
65 | } |
66 | WorkerState state; |
67 | state.worker = wrapped_->GetOrCreateWorker(target); |
68 | if (state.worker != nullptr) { |
69 | workers_.insert(std::make_pair(target, state)); |
70 | } |
71 | return state.worker; |
72 | } |
73 | } |
74 | |
75 | Status GetEagerClientCache( |
76 | std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { |
77 | return wrapped_->GetEagerClientCache(eager_client_cache); |
78 | } |
79 | |
80 | Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>* |
81 | coordination_client_cache) override { |
82 | return wrapped_->GetCoordinationClientCache(coordination_client_cache); |
83 | } |
84 | |
85 | void ReleaseWorker(const string& target, WorkerInterface* worker) override { |
86 | // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. |
87 | } |
88 | |
89 | bool GetDeviceLocalityNonBlocking(const string& device, |
90 | DeviceLocality* locality) override { |
91 | return wrapped_->GetDeviceLocalityNonBlocking(device, locality); |
92 | } |
93 | |
94 | void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, |
95 | StatusCallback done) override { |
96 | wrapped_->GetDeviceLocalityAsync(device, locality, done); |
97 | } |
98 | |
99 | void SetLogging(bool active) override { wrapped_->SetLogging(active); } |
100 | |
101 | void ClearLogs() override { wrapped_->ClearLogs(); } |
102 | |
103 | bool RetrieveLogs(int64_t step_id, StepStats* ss) override { |
104 | return wrapped_->RetrieveLogs(step_id, ss); |
105 | } |
106 | |
107 | private: |
108 | std::unique_ptr<WorkerCacheInterface> wrapped_; |
109 | |
110 | // Information kept per created WorkerInterface. |
111 | struct WorkerState { |
112 | WorkerInterface* worker; |
113 | // TODO(jeff,sanjay): Add reference count if we support eviction. |
114 | }; |
115 | |
116 | // TODO(jeff,sanjay): Eviction when the map becomes too big. |
117 | mutex mu_; |
118 | std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_); |
119 | }; |
120 | |
121 | } // namespace |
122 | |
123 | WorkerSession::WorkerSession( |
124 | const string& session_name, const string& worker_name, |
125 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
126 | std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr, |
127 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) |
128 | : session_name_(session_name), |
129 | worker_name_(worker_name), |
130 | worker_cache_(new WorkerFreeListCache(std::move(worker_cache))), |
131 | graph_mgr_(std::move(graph_mgr)), |
132 | cluster_flr_(new ClusterFunctionLibraryRuntime( |
133 | this, !session_name.empty(), |
134 | remote_device_mgr ? remote_device_mgr.get() : nullptr)), |
135 | device_mgr_(std::move(device_mgr)), |
136 | borrowed_device_mgr_(nullptr), |
137 | remote_device_mgr_(std::move(remote_device_mgr)) { |
138 | // Starts exporting metrics through a platform-specific monitoring API (if |
139 | // provided). For builds using "tensorflow/tsl/platform/default", this is |
140 | // currently a no-op. |
141 | worker_session_created->GetCell()->Set(true); |
142 | } |
143 | |
144 | Status WorkerSession::UpdateWorkerCacheAndDevices( |
145 | std::unique_ptr<WorkerCacheInterface> new_worker_cache, |
146 | std::vector<std::unique_ptr<Device>> added_remote_devices, |
147 | const std::vector<Device*>& removed_remote_devices) { |
148 | { |
149 | mutex_lock l(worker_session_state_mu_); |
150 | worker_cache_ = std::shared_ptr<WorkerCacheInterface>( |
151 | new WorkerFreeListCache(std::move(new_worker_cache))); |
152 | } |
153 | TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices)); |
154 | TF_RETURN_IF_ERROR( |
155 | remote_device_mgr_->AddDevices(std::move(added_remote_devices))); |
156 | return OkStatus(); |
157 | } |
158 | |
159 | /* static */ |
160 | std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr( |
161 | const string& session_name, const string& worker_name, |
162 | std::unique_ptr<WorkerCacheInterface> worker_cache, |
163 | DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, |
164 | std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) { |
165 | return std::shared_ptr<WorkerSession>(new WorkerSession( |
166 | session_name, worker_name, std::move(worker_cache), borrowed_device_mgr, |
167 | std::move(graph_mgr), std::move(remote_device_mgr))); |
168 | } |
169 | |
// Constructs a WorkerSession whose DeviceMgr is borrowed: device_mgr_ stays
// null and borrowed_device_mgr_ points at caller-owned storage, which must
// outlive this session. Used only via CreateWithBorrowedDeviceMgr above.
WorkerSession::WorkerSession(
    const string& session_name, const string& worker_name,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
    std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
    : session_name_(session_name),
      worker_name_(worker_name),
      worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr_(std::move(graph_mgr)),
      // Members initialize in declaration order, so this .get() is evaluated
      // before remote_device_mgr_ below is move-initialized.
      cluster_flr_(new ClusterFunctionLibraryRuntime(
          this, !session_name.empty(), remote_device_mgr.get())),
      device_mgr_(nullptr),
      borrowed_device_mgr_(borrowed_device_mgr),
      remote_device_mgr_(std::move(remote_device_mgr)) {
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/tsl/platform/default", this is
  // currently a no-op.
  worker_session_created->GetCell()->Set(true);
}
189 | |
190 | WorkerSession::~WorkerSession() { |
191 | if (graph_mgr_) { |
192 | Status s = graph_mgr_->DeregisterAll(); |
193 | if (!s.ok()) { |
194 | LOG(WARNING) << "Error during worker session deletion: " << s; |
195 | } |
196 | } |
197 | } |
198 | |
199 | } // namespace tensorflow |
200 | |