1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/distributed_runtime/worker_session.h"
16
17#include "tensorflow/core/lib/monitoring/collection_registry.h"
18#include "tensorflow/core/lib/monitoring/gauge.h"
19
20namespace tensorflow {
21
22namespace {
23
// Process-wide monitoring gauge, flipped to true the first time any
// WorkerSession is constructed (see the WorkerSession constructors below).
// It is never reset to false.
auto* worker_session_created =
    monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
                                    "True if a worker session was created.");
27
28// A private cache that wraps worker_cache and allows reuse of
29// WorkerInterface objects.
30class WorkerFreeListCache : public WorkerCacheInterface {
31 public:
32 explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33 : wrapped_(std::move(w)) {}
34
35 ~WorkerFreeListCache() final {
36 for (auto& p : workers_) {
37 wrapped_->ReleaseWorker(p.first, p.second.worker);
38 }
39 }
40
41 void ListWorkers(std::vector<string>* workers) const override {
42 wrapped_->ListWorkers(workers);
43 }
44
45 void ListWorkersInJob(const string& job_name,
46 std::vector<string>* workers) const override {
47 wrapped_->ListWorkersInJob(job_name, workers);
48 }
49
50 WorkerInterface* GetOrCreateWorker(const string& target) override {
51 {
52 // Fast path if worker has been created.
53 tf_shared_lock l(mu_);
54 auto p = workers_.find(target);
55 if (p != workers_.end()) {
56 return p->second.worker;
57 }
58 }
59 {
60 // Slow path if worker hasn't been created.
61 mutex_lock l(mu_);
62 auto p = workers_.find(target);
63 if (p != workers_.end()) {
64 return p->second.worker;
65 }
66 WorkerState state;
67 state.worker = wrapped_->GetOrCreateWorker(target);
68 if (state.worker != nullptr) {
69 workers_.insert(std::make_pair(target, state));
70 }
71 return state.worker;
72 }
73 }
74
75 Status GetEagerClientCache(
76 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
77 return wrapped_->GetEagerClientCache(eager_client_cache);
78 }
79
80 Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
81 coordination_client_cache) override {
82 return wrapped_->GetCoordinationClientCache(coordination_client_cache);
83 }
84
85 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
86 // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
87 }
88
89 bool GetDeviceLocalityNonBlocking(const string& device,
90 DeviceLocality* locality) override {
91 return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
92 }
93
94 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
95 StatusCallback done) override {
96 wrapped_->GetDeviceLocalityAsync(device, locality, done);
97 }
98
99 void SetLogging(bool active) override { wrapped_->SetLogging(active); }
100
101 void ClearLogs() override { wrapped_->ClearLogs(); }
102
103 bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
104 return wrapped_->RetrieveLogs(step_id, ss);
105 }
106
107 private:
108 std::unique_ptr<WorkerCacheInterface> wrapped_;
109
110 // Information kept per created WorkerInterface.
111 struct WorkerState {
112 WorkerInterface* worker;
113 // TODO(jeff,sanjay): Add reference count if we support eviction.
114 };
115
116 // TODO(jeff,sanjay): Eviction when the map becomes too big.
117 mutex mu_;
118 std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
119};
120
121} // namespace
122
123WorkerSession::WorkerSession(
124 const string& session_name, const string& worker_name,
125 std::unique_ptr<WorkerCacheInterface> worker_cache,
126 std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
127 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
128 : session_name_(session_name),
129 worker_name_(worker_name),
130 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
131 graph_mgr_(std::move(graph_mgr)),
132 cluster_flr_(new ClusterFunctionLibraryRuntime(
133 this, !session_name.empty(),
134 remote_device_mgr ? remote_device_mgr.get() : nullptr)),
135 device_mgr_(std::move(device_mgr)),
136 borrowed_device_mgr_(nullptr),
137 remote_device_mgr_(std::move(remote_device_mgr)) {
138 // Starts exporting metrics through a platform-specific monitoring API (if
139 // provided). For builds using "tensorflow/tsl/platform/default", this is
140 // currently a no-op.
141 worker_session_created->GetCell()->Set(true);
142}
143
144Status WorkerSession::UpdateWorkerCacheAndDevices(
145 std::unique_ptr<WorkerCacheInterface> new_worker_cache,
146 std::vector<std::unique_ptr<Device>> added_remote_devices,
147 const std::vector<Device*>& removed_remote_devices) {
148 {
149 mutex_lock l(worker_session_state_mu_);
150 worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
151 new WorkerFreeListCache(std::move(new_worker_cache)));
152 }
153 TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
154 TF_RETURN_IF_ERROR(
155 remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
156 return OkStatus();
157}
158
159/* static */
160std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
161 const string& session_name, const string& worker_name,
162 std::unique_ptr<WorkerCacheInterface> worker_cache,
163 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
164 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
165 return std::shared_ptr<WorkerSession>(new WorkerSession(
166 session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
167 std::move(graph_mgr), std::move(remote_device_mgr)));
168}
169
// Private constructor used by CreateWithBorrowedDeviceMgr: the DeviceMgr is
// owned by the caller, so device_mgr_ stays null and borrowed_device_mgr_ is
// set instead. The worker_cache is wrapped in a WorkerFreeListCache so
// WorkerInterface objects are reused across requests.
WorkerSession::WorkerSession(
    const string& session_name, const string& worker_name,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
    std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
    : session_name_(session_name),
      worker_name_(worker_name),
      worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
      graph_mgr_(std::move(graph_mgr)),
      cluster_flr_(new ClusterFunctionLibraryRuntime(
          this, !session_name.empty(), remote_device_mgr.get())),
      device_mgr_(nullptr),
      borrowed_device_mgr_(borrowed_device_mgr),
      remote_device_mgr_(std::move(remote_device_mgr)) {
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/tsl/platform/default", this is
  // currently a no-op.
  worker_session_created->GetCell()->Set(true);
}
189
190WorkerSession::~WorkerSession() {
191 if (graph_mgr_) {
192 Status s = graph_mgr_->DeregisterAll();
193 if (!s.ok()) {
194 LOG(WARNING) << "Error during worker session deletion: " << s;
195 }
196 }
197}
198
199} // namespace tensorflow
200