1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ |
18 | |
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/worker_cache.h"
23 | |
24 | namespace tensorflow { |
25 | |
26 | class WorkerCacheWrapper : public WorkerCacheInterface { |
27 | public: |
28 | WorkerCacheWrapper(WorkerCacheInterface* wrapped) : wrapped_(wrapped) {} |
29 | |
30 | // Updates *workers with strings naming the remote worker tasks to |
31 | // which open channels have been established. |
32 | void ListWorkers(std::vector<string>* workers) const override { |
33 | return wrapped_->ListWorkers(workers); |
34 | } |
35 | void ListWorkersInJob(const string& job_name, |
36 | std::vector<string>* workers) const override { |
37 | return wrapped_->ListWorkersInJob(job_name, workers); |
38 | } |
39 | |
40 | // If "target" names a remote task for which an RPC channel exists |
41 | // or can be constructed, returns a pointer to a WorkerInterface object |
42 | // wrapping that channel. The returned value must be destroyed by |
43 | // calling `this->ReleaseWorker(target, ret)` |
44 | WorkerInterface* GetOrCreateWorker(const string& target) override { |
45 | return wrapped_->GetOrCreateWorker(target); |
46 | } |
47 | |
48 | // Release a worker previously returned by this->GetOrCreateWorker(target). |
49 | // |
50 | // TODO(jeff,sanjay): Consider moving target into WorkerInterface. |
51 | // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a |
52 | // per-rpc-subsystem WorkerInterface creator. |
53 | void ReleaseWorker(const string& target, WorkerInterface* worker) override { |
54 | return wrapped_->ReleaseWorker(target, worker); |
55 | } |
56 | |
57 | Status GetEagerClientCache( |
58 | std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { |
59 | return wrapped_->GetEagerClientCache(eager_client_cache); |
60 | } |
61 | |
62 | Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>* |
63 | coordination_client_cache) override { |
64 | return wrapped_->GetCoordinationClientCache(coordination_client_cache); |
65 | } |
66 | |
67 | // Set *locality with the DeviceLocality of the specified remote device |
68 | // within its local environment. Returns true if *locality |
69 | // was set, using only locally cached data. Returns false |
70 | // if status data for that device was not available. Never blocks. |
71 | bool GetDeviceLocalityNonBlocking(const string& device, |
72 | DeviceLocality* locality) override { |
73 | return wrapped_->GetDeviceLocalityNonBlocking(device, locality); |
74 | } |
75 | |
76 | // Set *locality with the DeviceLocality of the specified remote device |
77 | // within its local environment. Callback gets Status::OK if *locality |
78 | // was set. |
79 | void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, |
80 | StatusCallback done) override { |
81 | return wrapped_->GetDeviceLocalityAsync(device, locality, std::move(done)); |
82 | } |
83 | |
84 | // Start/stop logging activity. |
85 | void SetLogging(bool active) override { wrapped_->SetLogging(active); } |
86 | |
87 | // Discard any saved log data. |
88 | void ClearLogs() override { wrapped_->ClearLogs(); } |
89 | |
90 | // Return logs for the identified step in *ss. Any returned data will no |
91 | // longer be stored. |
92 | bool RetrieveLogs(int64_t step_id, StepStats* ss) override { |
93 | return wrapped_->RetrieveLogs(step_id, ss); |
94 | } |
95 | |
96 | private: |
97 | WorkerCacheInterface* wrapped_; // Not owned. |
98 | }; |
99 | } // namespace tensorflow |
100 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ |
101 | |