1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ |
18 | |
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
38 | |
39 | namespace tensorflow { |
40 | |
41 | class BaseRemoteRendezvous; |
42 | class BaseRecvTensorCall; |
43 | |
44 | // RendezvousMgr keeps track of a set of local rendezvous instances. |
45 | // All tensors sent by this worker are buffered in a RendezvousMgr |
46 | // until the tensor is received. Each global unique "step_id" |
47 | // corresponds to one local rendezvous instance managed by a |
48 | // RendezvousMgr. |
49 | // |
50 | // E.g., |
51 | // Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); |
52 | // fork execution of a graph executor using "rendez" on thread 1; |
53 | // fork execution of another graph executor using "rendez" on thread 2; |
54 | // ... |
55 | // join threads 1 and 2; |
56 | // |
57 | // In the example above, execution in thread 1 and 2 communicates with |
58 | // each other by send/recv operations through `rendez`. |
59 | // |
60 | // Tensors sent and received through a rendezvous managed by this |
61 | // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). |
62 | class BaseRendezvousMgr : public RendezvousMgrInterface { |
63 | public: |
64 | explicit BaseRendezvousMgr(const WorkerEnv* worker_env); |
65 | |
66 | ~BaseRendezvousMgr() override; |
67 | |
68 | // Returns Rendezvous supporting send and recv among workers in the |
69 | // "step_id". The caller takes ownership of one reference on the |
70 | // returned Rendezvous instance. |
71 | // |
72 | // Note: the caller must guarantee to eventually call Initialize on the |
73 | // returned RemoteRendezvous |
74 | RemoteRendezvous* Find(int64_t step_id) override; |
75 | |
76 | // Finds the local rendezvous instance for the "step_id". Runs |
77 | // "done" when the tensor for "key" is produced or an error occurs. |
78 | // |
79 | // This method is used by the rpc handler of RecvTensor. |
80 | void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed, |
81 | Rendezvous::DoneCallback done) override; |
82 | |
83 | // Synchronous wrapper for RecvLocalAsync. |
84 | Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, |
85 | Tensor* val, bool* is_dead) override; |
86 | |
87 | // Removes rendezvous for "step_id". |
88 | // |
89 | // TODO(zhifengc): Have a background thread in worker that |
90 | // periodically calls CleanupAll(). |
91 | void Cleanup(int64_t step_id) override; |
92 | |
93 | // Remove all rendezvous instances owned by the rendezvous_mgr. |
94 | void CleanupAll() override; |
95 | |
96 | protected: |
97 | virtual BaseRemoteRendezvous* Create(int64_t step_id, |
98 | const WorkerEnv* worker_env) = 0; |
99 | |
100 | private: |
101 | // Maps step_id to rendezvous. |
102 | typedef absl::flat_hash_map<int64_t, BaseRemoteRendezvous*> Table; |
103 | |
104 | // Not owned. |
105 | const WorkerEnv* const worker_env_; |
106 | |
107 | mutex mu_; |
108 | Table table_ TF_GUARDED_BY(mu_); |
109 | |
110 | BaseRemoteRendezvous* FindOrCreate(int64_t step_id); |
111 | |
112 | TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr); |
113 | }; |
114 | |
115 | // RemoteRendezvous is a Rendezvous which can handle either |
116 | // the producer or consumer being in a remote process. |
117 | // |
118 | // Buffering of Tensor values is delegated to a "local" Rendezvous |
119 | // obtained from NewLocalRendezvous(). This class just adds |
120 | // functionality to coordinate with remote workers. |
121 | class BaseRemoteRendezvous : public RemoteRendezvous { |
122 | public: |
123 | BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id); |
124 | |
125 | // Upgrades the BaseRemoteRendezvous to full initialization. |
126 | Status Initialize(WorkerSession* session) override; |
127 | |
128 | void SetRemoteEagerContextDefault() override { |
129 | remote_eager_context_default_ = true; |
130 | } |
131 | bool IsRemoteEagerContextDefault() override { |
132 | return remote_eager_context_default_; |
133 | } |
134 | |
135 | // Forwards to local_, where the Tensor "val" will be buffered and |
136 | // any waiting callback stored. |
137 | Status Send(const ParsedKey& key, const Rendezvous::Args& args, |
138 | const Tensor& val, const bool is_dead) override; |
139 | |
140 | // This method is called only by the RecvOp. It tests to see |
141 | // whether the value will be produced by a local or remote device |
142 | // and handles accordingly. In the local case it forwards to |
143 | // local_, in the remote case it initiates an RPC request. |
144 | void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, |
145 | DoneCallback done) override; |
146 | |
147 | void StartAbort(const Status& status) override; |
148 | |
149 | // This method is called only by the local Worker, forwarded through |
150 | // the same method on RendezvousMgr. This occurs when the Worker |
151 | // has received a RecvTensor request, either locally or over the |
152 | // network. In either case it needs to retrieve a locally buffered |
153 | // value from local_, and give it to its caller. |
154 | // |
155 | // Runs "done" as soon as the tensor for "parsed" is available or an error |
156 | // is detected. |
157 | // |
158 | // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. |
159 | void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); |
160 | |
161 | protected: |
162 | virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, |
163 | const Rendezvous::Args& args, |
164 | DoneCallback done) = 0; |
165 | |
166 | // Returns true if "src" and "dst" are located in the same worker, |
167 | // and hence may use a local rendezvous. |
168 | virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, |
169 | DeviceNameUtils::ParsedName dst); |
170 | |
171 | // If aborted, aborts "call". Otherwise, adds "call" into calls_. |
172 | void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); |
173 | |
174 | // Removes "call" from calls_ if "call" is in calls_. |
175 | void DeregisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); |
176 | |
177 | WorkerSession* session(); |
178 | |
179 | bool is_initialized(); |
180 | |
181 | ~BaseRemoteRendezvous() override; |
182 | |
183 | const WorkerEnv* const env_; // Not owned. |
184 | const int64_t step_id_; |
185 | |
186 | private: |
187 | Rendezvous* local_; // Owns a Ref on this object. |
188 | // Indicates whether this remote rendezvous instance is used as the default |
189 | // rendezvous for remote eager op-by-op execution. Errors in eager op-by-op |
190 | // execution should not abort the rendezvous since it is a context-wide |
191 | // instance and needs to be reused; instead, the errors are propagated through |
192 | // eager executors. |
193 | bool remote_eager_context_default_ = false; |
194 | |
195 | mutable mutex mu_; |
196 | mutable mutex calls_mu_; |
197 | |
198 | // Status given by StartAbort() if any. |
199 | Status status_ TF_GUARDED_BY(mu_); |
200 | |
201 | WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned. |
202 | |
203 | // Data structures to handle calls when partially initialized. |
204 | struct DeferredCall { |
205 | const ParsedKey parsed; |
206 | DoneCallback done; |
207 | |
208 | DeferredCall(const ParsedKey& parsed, DoneCallback done); |
209 | }; |
210 | std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_); |
211 | |
212 | // "CancellationToken" is stored here so that when there's no active |
213 | // RecvTensorCalls, we can de-register the callback in the cancellation |
214 | // manager. |
215 | // |
216 | // Note: pointer to CancellationManager can be nullptr in certain use cases. |
217 | absl::flat_hash_map< |
218 | CancellationManager*, |
219 | std::pair<CancellationToken, absl::flat_hash_set<BaseRecvTensorCall*>>> |
220 | calls_ TF_GUARDED_BY(calls_mu_); |
221 | |
222 | bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) { |
223 | return session_ != nullptr; |
224 | } |
225 | |
226 | // If "is_src" is true, checks that the rendezvous key "parsed"'s |
227 | // source is in this process. If "is_src" is false, checks that the |
228 | // rendezvous key "parsed"'s destination is in this process. |
229 | Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); |
230 | |
231 | // Callback handling the case when a rendezvous has been |
232 | // accomplished in local_ and the consumer is local to this process. |
233 | // Tensor "in" will be copied into "out". The key "parsed" encodes |
234 | // the src and dst devices. |
235 | void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, |
236 | const Rendezvous::Args& in_args, |
237 | const Rendezvous::Args& out_args, const Tensor& in, |
238 | Tensor* out, StatusCallback done); |
239 | |
240 | // Must be called only if fully initialized. |
241 | void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); |
242 | |
243 | TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); |
244 | }; |
245 | |
246 | class BaseRecvTensorCall { |
247 | public: |
248 | BaseRecvTensorCall() {} |
249 | virtual ~BaseRecvTensorCall() {} |
250 | |
251 | virtual void Start(std::function<void()> recv_done) = 0; |
252 | |
253 | virtual void StartAbort(const Status& s) = 0; |
254 | |
255 | virtual Status status() const = 0; |
256 | |
257 | private: |
258 | TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall); |
259 | }; |
260 | |
261 | } // end namespace tensorflow |
262 | |
263 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ |
264 | |