1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ |
18 | |
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"  // for DeviceLocality
#include "tensorflow/core/lib/core/status.h"
27 | |
28 | namespace tensorflow { |
29 | typedef std::function<void(const Status&)> StatusCallback; |
30 | |
31 | class ChannelCache; |
32 | class StepStats; |
33 | |
34 | class WorkerCacheInterface { |
35 | public: |
36 | virtual ~WorkerCacheInterface() {} |
37 | |
38 | // Updates *workers with strings naming the remote worker tasks to |
39 | // which open channels have been established. |
40 | virtual void ListWorkers(std::vector<string>* workers) const = 0; |
41 | virtual void ListWorkersInJob(const string& job_name, |
42 | std::vector<string>* workers) const = 0; |
43 | |
44 | // If "target" names a remote task for which an RPC channel exists |
45 | // or can be constructed, returns a pointer to a WorkerInterface object |
46 | // wrapping that channel. The returned value must be destroyed by |
47 | // calling `this->ReleaseWorker(target, ret)` |
48 | virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0; |
49 | |
50 | // Release a worker previously returned by this->GetOrCreateWorker(target). |
51 | // |
52 | // TODO(jeff,sanjay): Consider moving target into WorkerInterface. |
53 | // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a |
54 | // per-rpc-subsystem WorkerInterface creator. |
55 | virtual void ReleaseWorker(const string& target, WorkerInterface* worker) { |
56 | // Subclasses may override to reuse worker objects. |
57 | delete worker; |
58 | } |
59 | |
60 | // Set *locality with the DeviceLocality of the specified remote device |
61 | // within its local environment. Returns true if *locality |
62 | // was set, using only locally cached data. Returns false |
63 | // if status data for that device was not available. Never blocks. |
64 | virtual bool GetDeviceLocalityNonBlocking(const string& device, |
65 | DeviceLocality* locality) = 0; |
66 | |
67 | // Set *locality with the DeviceLocality of the specified remote device |
68 | // within its local environment. Callback gets Status::OK if *locality |
69 | // was set. |
70 | virtual void GetDeviceLocalityAsync(const string& device, |
71 | DeviceLocality* locality, |
72 | StatusCallback done) = 0; |
73 | |
74 | // TODO(b/189159585): Define a general client cache maker function to |
75 | // construct client cache of different types sharing the same underling RPC |
76 | // channels, to replace the eager and coordination cache function. |
77 | // Build and return a EagerClientCache object wrapping that channel. |
78 | virtual Status GetEagerClientCache( |
79 | std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0; |
80 | |
81 | // Build and return a CoordinationClientCache object wrapping that channel. |
82 | virtual Status GetCoordinationClientCache( |
83 | std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0; |
84 | |
85 | // Start/stop logging activity. |
86 | virtual void SetLogging(bool active) {} |
87 | |
88 | // Discard any saved log data. |
89 | virtual void ClearLogs() {} |
90 | |
91 | // Return logs for the identified step in *ss. Any returned data will no |
92 | // longer be stored. |
93 | virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; } |
94 | }; |
95 | } // namespace tensorflow |
96 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ |
97 | |