1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
18
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
38
39namespace tensorflow {
40
41class BaseRemoteRendezvous;
42class BaseRecvTensorCall;
43
44// RendezvousMgr keeps track of a set of local rendezvous instances.
45// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each globally unique "step_id"
47// corresponds to one local rendezvous instance managed by a
48// RendezvousMgr.
49//
50// E.g.,
51// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
52// fork execution of a graph executor using "rendez" on thread 1;
53// fork execution of another graph executor using "rendez" on thread 2;
54// ...
55// join threads 1 and 2;
56//
57// In the example above, execution in thread 1 and 2 communicates with
58// each other by send/recv operations through `rendez`.
59//
60// Tensors sent and received through a rendezvous managed by this
61// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
62class BaseRendezvousMgr : public RendezvousMgrInterface {
63 public:
64 explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
65
66 ~BaseRendezvousMgr() override;
67
68 // Returns Rendezvous supporting send and recv among workers in the
69 // "step_id". The caller takes ownership of one reference on the
70 // returned Rendezvous instance.
71 //
72 // Note: the caller must guarantee to eventually call Initialize on the
73 // returned RemoteRendezvous
74 RemoteRendezvous* Find(int64_t step_id) override;
75
76 // Finds the local rendezvous instance for the "step_id". Runs
77 // "done" when the tensor for "key" is produced or an error occurs.
78 //
79 // This method is used by the rpc handler of RecvTensor.
80 void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed,
81 Rendezvous::DoneCallback done) override;
82
83 // Synchronous wrapper for RecvLocalAsync.
84 Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed,
85 Tensor* val, bool* is_dead) override;
86
87 // Removes rendezvous for "step_id".
88 //
89 // TODO(zhifengc): Have a background thread in worker that
90 // periodically calls CleanupAll().
91 void Cleanup(int64_t step_id) override;
92
93 // Remove all rendezvous instances owned by the rendezvous_mgr.
94 void CleanupAll() override;
95
96 protected:
97 virtual BaseRemoteRendezvous* Create(int64_t step_id,
98 const WorkerEnv* worker_env) = 0;
99
100 private:
101 // Maps step_id to rendezvous.
102 typedef absl::flat_hash_map<int64_t, BaseRemoteRendezvous*> Table;
103
104 // Not owned.
105 const WorkerEnv* const worker_env_;
106
107 mutex mu_;
108 Table table_ TF_GUARDED_BY(mu_);
109
110 BaseRemoteRendezvous* FindOrCreate(int64_t step_id);
111
112 TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
113};
114
115// RemoteRendezvous is a Rendezvous which can handle either
116// the producer or consumer being in a remote process.
117//
118// Buffering of Tensor values is delegated to a "local" Rendezvous
119// obtained from NewLocalRendezvous(). This class just adds
120// functionality to coordinate with remote workers.
121class BaseRemoteRendezvous : public RemoteRendezvous {
122 public:
123 BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id);
124
125 // Upgrades the BaseRemoteRendezvous to full initialization.
126 Status Initialize(WorkerSession* session) override;
127
128 void SetRemoteEagerContextDefault() override {
129 remote_eager_context_default_ = true;
130 }
131 bool IsRemoteEagerContextDefault() override {
132 return remote_eager_context_default_;
133 }
134
135 // Forwards to local_, where the Tensor "val" will be buffered and
136 // any waiting callback stored.
137 Status Send(const ParsedKey& key, const Rendezvous::Args& args,
138 const Tensor& val, const bool is_dead) override;
139
140 // This method is called only by the RecvOp. It tests to see
141 // whether the value will be produced by a local or remote device
142 // and handles accordingly. In the local case it forwards to
143 // local_, in the remote case it initiates an RPC request.
144 void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
145 DoneCallback done) override;
146
147 void StartAbort(const Status& status) override;
148
149 // This method is called only by the local Worker, forwarded through
150 // the same method on RendezvousMgr. This occurs when the Worker
151 // has received a RecvTensor request, either locally or over the
152 // network. In either case it needs to retrieve a locally buffered
153 // value from local_, and give it to its caller.
154 //
155 // Runs "done" as soon as the tensor for "parsed" is available or an error
156 // is detected.
157 //
158 // REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
159 void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
160
161 protected:
162 virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
163 const Rendezvous::Args& args,
164 DoneCallback done) = 0;
165
166 // Returns true if "src" and "dst" are located in the same worker,
167 // and hence may use a local rendezvous.
168 virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
169 DeviceNameUtils::ParsedName dst);
170
171 // If aborted, aborts "call". Otherwise, adds "call" into calls_.
172 void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);
173
174 // Removes "call" from calls_ if "call" is in calls_.
175 void DeregisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);
176
177 WorkerSession* session();
178
179 bool is_initialized();
180
181 ~BaseRemoteRendezvous() override;
182
183 const WorkerEnv* const env_; // Not owned.
184 const int64_t step_id_;
185
186 private:
187 Rendezvous* local_; // Owns a Ref on this object.
188 // Indicates whether this remote rendezvous instance is used as the default
189 // rendezvous for remote eager op-by-op execution. Errors in eager op-by-op
190 // execution should not abort the rendezvous since it is a context-wide
191 // instance and needs to be reused; instead, the errors are propagated through
192 // eager executors.
193 bool remote_eager_context_default_ = false;
194
195 mutable mutex mu_;
196 mutable mutex calls_mu_;
197
198 // Status given by StartAbort() if any.
199 Status status_ TF_GUARDED_BY(mu_);
200
201 WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned.
202
203 // Data structures to handle calls when partially initialized.
204 struct DeferredCall {
205 const ParsedKey parsed;
206 DoneCallback done;
207
208 DeferredCall(const ParsedKey& parsed, DoneCallback done);
209 };
210 std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);
211
212 // "CancellationToken" is stored here so that when there's no active
213 // RecvTensorCalls, we can de-register the callback in the cancellation
214 // manager.
215 //
216 // Note: pointer to CancellationManager can be nullptr in certain use cases.
217 absl::flat_hash_map<
218 CancellationManager*,
219 std::pair<CancellationToken, absl::flat_hash_set<BaseRecvTensorCall*>>>
220 calls_ TF_GUARDED_BY(calls_mu_);
221
222 bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
223 return session_ != nullptr;
224 }
225
226 // If "is_src" is true, checks that the rendezvous key "parsed"'s
227 // source is in this process. If "is_src" is false, checks that the
228 // rendezvous key "parsed"'s destination is in this process.
229 Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
230
231 // Callback handling the case when a rendezvous has been
232 // accomplished in local_ and the consumer is local to this process.
233 // Tensor "in" will be copied into "out". The key "parsed" encodes
234 // the src and dst devices.
235 void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
236 const Rendezvous::Args& in_args,
237 const Rendezvous::Args& out_args, const Tensor& in,
238 Tensor* out, StatusCallback done);
239
240 // Must be called only if fully initialized.
241 void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
242
243 TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
244};
245
246class BaseRecvTensorCall {
247 public:
248 BaseRecvTensorCall() {}
249 virtual ~BaseRecvTensorCall() {}
250
251 virtual void Start(std::function<void()> recv_done) = 0;
252
253 virtual void StartAbort(const Status& s) = 0;
254
255 virtual Status status() const = 0;
256
257 private:
258 TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
259};
260
261} // end namespace tensorflow
262
263#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
264