/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"

#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"

namespace tensorflow {

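// Aborts `rendez` with status `s` and releases the caller's reference to it.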
static void StartAbortRendezvous(Rendezvous* rendez, const Status& s) {
  rendez->StartAbort(s);
  rendez->Unref();
}

BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
    : worker_env_(worker_env) {}

BaseRendezvousMgr::~BaseRendezvousMgr() {
  for (auto& p : table_) {
    auto rendez = p.second;
    StartAbortRendezvous(rendez, errors::Aborted("Shutdown"));
  }
}

RemoteRendezvous* BaseRendezvousMgr::Find(int64_t step_id) {
  return FindOrCreate(step_id);
}

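// Returns the rendezvous for `step_id`, creating it if it does not yet exist.
// The returned rendezvous carries an extra reference that the caller is
// responsible for releasing with Unref().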
BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64_t step_id) {
  mutex_lock l(mu_);
  auto iter = table_.find(step_id);
  if (iter == table_.end()) {
    auto rr = Create(step_id, worker_env_);
    iter = table_.insert({step_id, rr}).first;
  }
  iter->second->Ref();
  return iter->second;
}

void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
                                       const Rendezvous::ParsedKey& parsed,
                                       Rendezvous::DoneCallback done) {
  auto rendez = FindOrCreate(step_id);
  auto done_cb = [rendez, done = std::move(done)](
                     const Status& s, const Rendezvous::Args& send_args,
                     const Rendezvous::Args& recv_args, const Tensor& v,
                     bool dead) {
    rendez->Unref();
    done(s, send_args, recv_args, v, dead);
  };
  rendez->RecvLocalAsync(parsed, std::move(done_cb));
}

Status BaseRendezvousMgr::RecvLocal(int64_t step_id,
                                    const Rendezvous::ParsedKey& parsed,
                                    Tensor* val, bool* is_dead) {
  Status ret;
  Notification n;
  RecvLocalAsync(step_id, parsed,
                 [val, is_dead, &ret, &n](const Status& s,
                                          const Rendezvous::Args& send_args,
                                          const Rendezvous::Args& recv_args,
                                          const Tensor& v, const bool dead) {
                   ret = s;
                   *val = v;
                   *is_dead = dead;
                   n.Notify();
                 });
  n.WaitForNotification();
  return ret;
}
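// A minimal caller sketch (hypothetical; `mgr`, `step_id`, and `key` are
// assumptions, not part of this file). Keys are created by
// Rendezvous::CreateKey() and parsed with Rendezvous::ParseKey():
//
//   Rendezvous::ParsedKey parsed;
//   TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
//   Tensor val;
//   bool is_dead = false;
//   TF_CHECK_OK(mgr->RecvLocal(step_id, parsed, &val, &is_dead));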

void BaseRendezvousMgr::Cleanup(int64_t step_id) {
  Rendezvous* rendez = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(step_id);
    if (iter != table_.end()) {
      rendez = iter->second;
      table_.erase(iter);
    }
  }
  if (rendez) {
    StartAbortRendezvous(rendez, errors::Aborted("Cleanup ", step_id));
  }
}

void BaseRendezvousMgr::CleanupAll() {
  mutex_lock l(mu_);
  for (auto iter = table_.begin(); iter != table_.end(); iter++) {
    iter->second->Unref();
  }
}

BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
                                           int64_t step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}

BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  {
    mutex_lock l(calls_mu_);
    calls_.clear();
  }
  local_->Unref();
}

// Returns true if "device_name" is a valid full name of a local device
// of the "worker". This helper is based purely on the worker name
// and device name; it does no lookups in the worker's device_mgr.
static bool IsLocalDevice(const StringPiece worker_name,
                          const StringPiece device_name) {
  return absl::StartsWith(device_name, worker_name);
}
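// For example, the worker name "/job:worker/replica:0/task:0" is a prefix of
// the device name "/job:worker/replica:0/task:0/device:GPU:0", so that device
// is local to the worker; a device owned by any other task does not match.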
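// Binds this rendezvous to `session` and replays any RecvLocalAsync calls that
// were deferred while the rendezvous was uninitialized. Re-initialization with
// the same worker name is a no-op; a different worker name is an error.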
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
  CHECK_NE(session, nullptr) << "session must not be null!";
  std::vector<DeferredCall> deferred_calls;
  {
    mutex_lock l(mu_);
    if (session_ != nullptr) {
      if (session_->worker_name() == session->worker_name()) {
        VLOG(1) << "Skipping rendezvous re-initialization.";
        return OkStatus();
      }
      Status s = errors::Internal(
          "Double init! Worker names would have changed from: ",
          session_->worker_name(), " -> ", session->worker_name());
      LOG(WARNING) << s;
      return s;
    }
    session_ = session;
    std::swap(deferred_calls, deferred_calls_);
  }
  for (auto& call : deferred_calls) {
    RecvLocalAsyncInternal(call.parsed, std::move(call.done));
  }
  return OkStatus();
}

WorkerSession* BaseRemoteRendezvous::session() {
  tf_shared_lock l(mu_);
  return session_;
}

bool BaseRemoteRendezvous::is_initialized() {
  tf_shared_lock l(mu_);
  return is_initialized_locked();
}

Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
                                  const Rendezvous::Args& args,
                                  const Tensor& val, const bool is_dead) {
  VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
  WorkerSession* sess = nullptr;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) return status_;
    DCHECK(is_initialized_locked());
    sess = session_;
  }

  if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }

  // Buffers "val" and "device_context" in local_.
  return local_->Send(parsed, args, val, is_dead);
}

Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
                                             bool is_src) {
  // Cache the session pointer to avoid repeatedly taking and releasing the
  // lock (e.g., by calling session()).
  WorkerSession* sess = nullptr;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) return status_;
    if (!is_initialized_locked()) {
      return errors::Internal("ValidateDevices called before initialization.");
    }
    sess = session_;
  }
  if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }
  if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }
  return OkStatus();
}

void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(OkStatus());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  profiler::ScopedMemoryDebugAnnotation op_annotation(
      "SameWorkerRecvDone", step_id_, "dynamic", in.dtype(),
      [&in]() { return in.shape().DebugString(); });
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  bool sync_dst_compute = (safe_alloc_frontier == 0);
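  // A nonzero safe allocation frontier means the device allocator can reuse
  // memory freed up to that frontier, so the output buffer can be allocated
  // without synchronizing with the destination's compute stream;
  // freed_by_func below reports the advancing frontier to the allocator.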
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}

bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}

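// Receives the tensor identified by `parsed`. If src and dst are in the same
// address space, the value is taken from the local rendezvous and copied to
// the requested device; otherwise the receive is delegated to the
// subclass-provided RecvFromRemoteAsync().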
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
  Status s = ValidateDevices(parsed, false /*!is_src*/);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // ValidateDevices() returns an error status if the rendezvous is not
  // initialized.
  DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
                           << parsed.FullKey() << ").";

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
          VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
                  << parsed.FullKey();
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            final_callback(status);
          }
        });
    return;
  } else {
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}

void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
                                          DoneCallback done) {
  // Test whether the rendezvous is initialized using a shared lock, to avoid
  // the need for exclusive access in the common case.
  if (TF_PREDICT_FALSE(!is_initialized())) {
    mutex_lock l(mu_);
    if (!is_initialized_locked()) {
      // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
      // remote worker) before the RunStep (or PartialRunStep) RPC from the
      // master arrives. RecvLocalAsync thus buffers the arguments until after
      // the RemoteRendezvous is Initialize()'d, when it completes the
      // rendezvous logic. At some point after Initialize() is called, a Tensor
      // is produced locally that will then be sent in response to the incoming
      // RPC.
      DeferredCall call(parsed, std::move(done));
      deferred_calls_.push_back(call);
      return;
    }
  }
  RecvLocalAsyncInternal(parsed, std::move(done));
}

void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
                                                  DoneCallback done) {
  Status s = ValidateDevices(parsed, true /* is_src */);
  if (!s.ok()) {
    done(s, Args(), Args(), Tensor(), false);
    return;
  }
  local_->RecvAsync(parsed, Args(), std::move(done));
}

void BaseRemoteRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  // If the status passed in is a cancelled or aborted error, mark it as
  // "derived" for the rendezvous. Derived status messages are ignored when
  // aggregating errors across devices: this allows us to prefer our original
  // status message over any cancellation related errors.
  Status derived_status = s;
  if (errors::IsCancelled(s) || errors::IsAborted(s)) {
    derived_status = StatusGroup::MakeDerived(s);
  }

  local_->StartAbort(derived_status);

  bool status_ok = false;
  {
    mutex_lock l(mu_);
    status_ok = status_.ok();
    if (status_ok) {
      status_ = derived_status;
    }
  }

  if (status_ok) {
    // Aborts all active RecvTensor calls.
    mutex_lock l(calls_mu_);
    for (auto& cm_and_token_and_calls : calls_) {
      for (auto& call : cm_and_token_and_calls.second.second) {
        call->StartAbort(derived_status);
      }
      auto* cm = cm_and_token_and_calls.first;
      calls_[cm].second.clear();
    }
    calls_.clear();
  }
}

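// Tracks `call` under the cancellation manager in `args`. The first call
// registered for a given manager installs a single cancellation callback that
// aborts every call tracked for that manager; if the manager has already been
// cancelled, `call` is aborted immediately.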
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  bool already_cancelled = false;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) {
      call->StartAbort(status_);
      return;
    }
  }

  CancellationToken token = CancellationManager::kInvalidToken;
  if (cm != nullptr) {
    mutex_lock l(calls_mu_);
    auto it = calls_.find(cm);
    if (it == calls_.end()) {
      token = cm->get_cancellation_token();
      already_cancelled = !cm->RegisterCallback(token, [this, cm]() {
        mutex_lock l(calls_mu_);
        // Abort all the RecvTensor calls associated with this cancellation
        // manager.
        for (const auto& call : calls_[cm].second) {
          call->StartAbort(
              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
        }
      });

      if (!already_cancelled) {
        calls_.emplace(
            cm,
            std::make_pair(token, absl::flat_hash_set<BaseRecvTensorCall*>{}));
      }
    }
  }

  if (already_cancelled) {
    call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
  } else {
    mutex_lock l(calls_mu_);
    bool emplaced = calls_[cm].second.emplace(call).second;
    CHECK(emplaced);  // Crash OK.
  }
}

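// Removes `call` from the set tracked under its cancellation manager and, once
// that set becomes empty, deregisters the cancellation callback that
// RegisterCall installed. Each RegisterCall should be paired with a
// DeregisterCall using the same args.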
void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call,
                                          const Rendezvous::Args& args) {
  auto cm = args.cancellation_manager;
  mutex_lock l(calls_mu_);
  CancellationToken token = calls_[cm].first;
  calls_[cm].second.erase(call);
  if (calls_[cm].second.empty()) {
    calls_.erase(cm);
    if (cm != nullptr) {
      cm->TryDeregisterCallback(token);
    }
  }
}

BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
                                                 DoneCallback done)
    : parsed(parsed), done(std::move(done)) {}

}  // end namespace tensorflow