1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"

#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
40 | |
41 | namespace tensorflow { |
42 | |
43 | static void StartAbortRendevous(Rendezvous* rendez, const Status& s) { |
44 | rendez->StartAbort(s); |
45 | rendez->Unref(); |
46 | } |
47 | |
48 | BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env) |
49 | : worker_env_(worker_env) {} |
50 | |
51 | BaseRendezvousMgr::~BaseRendezvousMgr() { |
52 | for (auto& p : table_) { |
53 | auto rendez = p.second; |
54 | StartAbortRendevous(rendez, errors::Aborted("Shutdown" )); |
55 | } |
56 | } |
57 | |
58 | RemoteRendezvous* BaseRendezvousMgr::Find(int64_t step_id) { |
59 | return FindOrCreate(step_id); |
60 | } |
61 | |
62 | BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64_t step_id) { |
63 | mutex_lock l(mu_); |
64 | auto iter = table_.find(step_id); |
65 | if (iter == table_.end()) { |
66 | auto rr = Create(step_id, worker_env_); |
67 | iter = table_.insert({step_id, rr}).first; |
68 | } |
69 | iter->second->Ref(); |
70 | return iter->second; |
71 | } |
72 | |
73 | void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id, |
74 | const Rendezvous::ParsedKey& parsed, |
75 | Rendezvous::DoneCallback done) { |
76 | auto rendez = FindOrCreate(step_id); |
77 | auto done_cb = [rendez, done = std::move(done)]( |
78 | const Status& s, const Rendezvous::Args& send_args, |
79 | const Rendezvous::Args& recv_args, const Tensor& v, |
80 | bool dead) { |
81 | rendez->Unref(); |
82 | done(s, send_args, recv_args, v, dead); |
83 | }; |
84 | rendez->RecvLocalAsync(parsed, std::move(done_cb)); |
85 | } |
86 | |
87 | Status BaseRendezvousMgr::RecvLocal(int64_t step_id, |
88 | const Rendezvous::ParsedKey& parsed, |
89 | Tensor* val, bool* is_dead) { |
90 | Status ret; |
91 | Notification n; |
92 | RecvLocalAsync(step_id, parsed, |
93 | [val, is_dead, &ret, &n](const Status& s, |
94 | const Rendezvous::Args& send_args, |
95 | const Rendezvous::Args& recv_args, |
96 | const Tensor& v, const bool dead) { |
97 | ret = s; |
98 | *val = v; |
99 | *is_dead = dead; |
100 | n.Notify(); |
101 | }); |
102 | n.WaitForNotification(); |
103 | return ret; |
104 | } |
105 | |
106 | void BaseRendezvousMgr::Cleanup(int64_t step_id) { |
107 | Rendezvous* rendez = nullptr; |
108 | { |
109 | mutex_lock l(mu_); |
110 | auto iter = table_.find(step_id); |
111 | if (iter != table_.end()) { |
112 | rendez = iter->second; |
113 | table_.erase(iter); |
114 | } |
115 | } |
116 | if (rendez) { |
117 | StartAbortRendevous(rendez, errors::Aborted("Cleanup " , step_id)); |
118 | } |
119 | } |
120 | |
121 | void BaseRendezvousMgr::CleanupAll() { |
122 | mutex_lock l(mu_); |
123 | for (auto iter = table_.begin(); iter != table_.end(); iter++) { |
124 | iter->second->Unref(); |
125 | } |
126 | } |
127 | |
128 | BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, |
129 | int64_t step_id) |
130 | : env_(env), |
131 | step_id_(step_id), |
132 | local_(NewLocalRendezvous()), |
133 | session_(nullptr) {} |
134 | |
135 | BaseRemoteRendezvous::~BaseRemoteRendezvous() { |
136 | { |
137 | mutex_lock l(calls_mu_); |
138 | calls_.clear(); |
139 | } |
140 | local_->Unref(); |
141 | } |
142 | |
// Returns true if "device_name" is a valid full name of a local device of the
// "worker". This helper is purely based on the worker name and device name
// and does no lookups in the worker->device_mgr.
//
// A plain prefix test is not enough: "/job:w/replica:0/task:1" is a string
// prefix of "/job:w/replica:0/task:10/device:CPU:0", which belongs to a
// different task. The worker prefix must therefore be followed by '/' (the
// start of the device component) or be the whole name.
static bool IsLocalDevice(const StringPiece worker_name,
                          const StringPiece device_name) {
  const auto prefix_len = worker_name.size();
  if (device_name.size() < prefix_len) return false;
  if (device_name.substr(0, prefix_len) != worker_name) return false;
  // Tolerate a worker name that already ends in a separator.
  if (worker_name.empty() || worker_name.back() == '/') return true;
  return device_name.size() == prefix_len || device_name[prefix_len] == '/';
}
150 | |
151 | Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { |
152 | CHECK_NE(session, nullptr) << "session must not be null!" ; |
153 | std::vector<DeferredCall> deferred_calls; |
154 | { |
155 | mutex_lock l(mu_); |
156 | if (session_ != nullptr) { |
157 | if (session_->worker_name() == session->worker_name()) { |
158 | VLOG(1) << "Skipping rendezvous re-initialization." ; |
159 | return OkStatus(); |
160 | } |
161 | Status s = errors::Internal( |
162 | "Double init! Worker names would have changed from: " , |
163 | session_->worker_name(), " -> " , session->worker_name()); |
164 | LOG(WARNING) << s; |
165 | return s; |
166 | } |
167 | session_ = session; |
168 | std::swap(deferred_calls, deferred_calls_); |
169 | } |
170 | for (auto& call : deferred_calls) { |
171 | RecvLocalAsyncInternal(call.parsed, std::move(call.done)); |
172 | } |
173 | return OkStatus(); |
174 | } |
175 | |
176 | WorkerSession* BaseRemoteRendezvous::session() { |
177 | tf_shared_lock l(mu_); |
178 | return session_; |
179 | } |
180 | |
181 | bool BaseRemoteRendezvous::is_initialized() { |
182 | tf_shared_lock l(mu_); |
183 | return is_initialized_locked(); |
184 | } |
185 | |
186 | Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, |
187 | const Rendezvous::Args& args, |
188 | const Tensor& val, const bool is_dead) { |
189 | VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey(); |
190 | WorkerSession* sess = nullptr; |
191 | { |
192 | tf_shared_lock l(mu_); |
193 | if (!status_.ok()) return status_; |
194 | DCHECK(is_initialized_locked()); |
195 | sess = session_; |
196 | } |
197 | |
198 | if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) { |
199 | return errors::InvalidArgument( |
200 | "Invalid rendezvous key (src): " , parsed.FullKey(), " @ " , |
201 | sess->worker_name()); |
202 | } |
203 | |
204 | // Buffers "val" and "device_context" in local_. |
205 | return local_->Send(parsed, args, val, is_dead); |
206 | } |
207 | |
208 | Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, |
209 | bool is_src) { |
210 | // Cache session pointer to avoid repeatedly taking & releasing the lock |
211 | // (e.g. calling session()) |
212 | WorkerSession* sess = nullptr; |
213 | { |
214 | tf_shared_lock l(mu_); |
215 | if (!status_.ok()) return status_; |
216 | if (!is_initialized_locked()) { |
217 | return errors::Internal("ValidateDevices called before initialization." ); |
218 | } |
219 | sess = session_; |
220 | } |
221 | if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) { |
222 | return errors::InvalidArgument( |
223 | "Invalid rendezvous key (src): " , parsed.FullKey(), " @ " , |
224 | sess->worker_name()); |
225 | } |
226 | if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) { |
227 | return errors::InvalidArgument( |
228 | "Invalid rendezvous key (dst): " , parsed.FullKey(), " @ " , |
229 | sess->worker_name()); |
230 | } |
231 | return OkStatus(); |
232 | } |
233 | |
234 | void BaseRemoteRendezvous::SameWorkerRecvDone( |
235 | const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args, |
236 | const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out, |
237 | StatusCallback done) { |
238 | // Do a quick copy (sharing the underlying buffer) if both tensors |
239 | // are on host memory. |
240 | const bool src_host = |
241 | (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU" ); |
242 | const bool dst_host = |
243 | (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU" ); |
244 | if (src_host && dst_host) { |
245 | *out = in; |
246 | done(OkStatus()); |
247 | return; |
248 | } |
249 | |
250 | // This copy must involve a GPU. Hence, "in" must support DMA |
251 | // (e.g., string tensors do not work on GPU). Variant copy DMA |
252 | // checks happen inside CopyTensor::ViaDMA. |
253 | if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT && |
254 | in.dtype() != DT_RESOURCE) { |
255 | done(errors::InvalidArgument( |
256 | "Non-DMA-safe " , DataTypeString(in.dtype()), |
257 | " tensor may not be copied from/to a device. Key: " , parsed.FullKey())); |
258 | return; |
259 | } |
260 | |
261 | WorkerSession* sess = session(); |
262 | Device* src_device; |
263 | Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device); |
264 | if (!s.ok()) { |
265 | done(s); |
266 | return; |
267 | } |
268 | Device* dst_device; |
269 | s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); |
270 | if (!s.ok()) { |
271 | done(s); |
272 | return; |
273 | } |
274 | |
275 | profiler::ScopedMemoryDebugAnnotation op_annotation( |
276 | "SameWorkerRecvDone" , step_id_, "dynamic" , in.dtype(), |
277 | [&in]() { return in.shape().DebugString(); }); |
278 | AllocatorAttributes attr = recv_args.alloc_attrs; |
279 | attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() || |
280 | recv_args.alloc_attrs.gpu_compatible()); |
281 | Allocator* out_allocator = dst_device->GetAllocator(attr); |
282 | AllocationAttributes allocation_attr; |
283 | uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0); |
284 | bool sync_dst_compute = (safe_alloc_frontier == 0); |
285 | std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() { |
286 | safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier); |
287 | return safe_alloc_frontier; |
288 | }; |
289 | if (!sync_dst_compute) { |
290 | allocation_attr.freed_by_func = &freed_by_func; |
291 | } |
292 | if (in.dtype() != DT_VARIANT) { |
293 | // Variants are handled by CopyTensor::ViaDMA. |
294 | Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr); |
295 | *out = copy; |
296 | } |
297 | |
298 | // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies, |
299 | // etc. |
300 | CopyTensor::ViaDMA( |
301 | parsed.edge_name, send_args.device_context, recv_args.device_context, |
302 | src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in, |
303 | out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute); |
304 | } |
305 | |
306 | bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src, |
307 | DeviceNameUtils::ParsedName dst) { |
308 | return DeviceNameUtils::IsSameAddressSpace(src, dst); |
309 | } |
310 | |
311 | void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, |
312 | const Rendezvous::Args& recv_args, |
313 | DoneCallback done) { |
314 | VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); |
315 | Status s = ValidateDevices(parsed, false /*!is_src*/); |
316 | if (!s.ok()) { |
317 | done(s, Args(), recv_args, Tensor(), false); |
318 | return; |
319 | } |
320 | |
321 | // ValidateDevices() returns an error status if the rendezvous is not |
322 | // initialized. |
323 | DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: " |
324 | << parsed.FullKey() << ")." ; |
325 | |
326 | profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync" , step_id_); |
327 | // Are src and dst in the same worker? |
328 | if (IsSameWorker(parsed.src, parsed.dst)) { |
329 | // Recv the tensor from local_. |
330 | local_->RecvAsync( |
331 | parsed, recv_args, |
332 | [this, parsed, done]( |
333 | const Status& status, const Rendezvous::Args& send_args, |
334 | const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { |
335 | VLOG(2) << "RemoteRendezvous Finished Recv " << this << " " |
336 | << parsed.FullKey(); |
337 | Tensor* out = new Tensor; |
338 | StatusCallback final_callback = [done, send_args, recv_args, out, |
339 | is_dead](const Status& s) { |
340 | done(s, send_args, recv_args, *out, is_dead); |
341 | delete out; |
342 | }; |
343 | |
344 | if (status.ok()) { |
345 | SameWorkerRecvDone(parsed, send_args, recv_args, in, out, |
346 | std::move(final_callback)); |
347 | } else { |
348 | final_callback(status); |
349 | } |
350 | }); |
351 | return; |
352 | } else { |
353 | RecvFromRemoteAsync(parsed, recv_args, std::move(done)); |
354 | } |
355 | } |
356 | |
357 | void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, |
358 | DoneCallback done) { |
359 | // Test whether the rendezvous is initialized using a shared lock, to avoid |
360 | // the need for exclusive access in the common case. |
361 | if (TF_PREDICT_FALSE(!is_initialized())) { |
362 | mutex_lock l(mu_); |
363 | if (!is_initialized_locked()) { |
364 | // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a |
365 | // remote worker) before the RunStep (or PartialRunStep) RPC from the |
366 | // master arrives. RecvLocalAsync thus buffers the arguments until after |
367 | // the RemoteRendezvous is Initialize()'d, when it completes the |
368 | // rendezvous logic. At some point after Initialize() is called, a Tensor |
369 | // is produced locally that will then be sent in response to the incoming |
370 | // RPC. |
371 | DeferredCall call(parsed, std::move(done)); |
372 | deferred_calls_.push_back(call); |
373 | return; |
374 | } |
375 | } |
376 | RecvLocalAsyncInternal(parsed, std::move(done)); |
377 | } |
378 | |
379 | void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, |
380 | DoneCallback done) { |
381 | Status s = ValidateDevices(parsed, true /* is_src */); |
382 | if (!s.ok()) { |
383 | done(s, Args(), Args(), Tensor(), false); |
384 | return; |
385 | } |
386 | local_->RecvAsync(parsed, Args(), std::move(done)); |
387 | } |
388 | |
389 | void BaseRemoteRendezvous::StartAbort(const Status& s) { |
390 | CHECK(!s.ok()); |
391 | // If the status passed in is a cancelled or aborted error, mark it as |
392 | // "derived" for the rendezvous. Derived status messages are ignored when |
393 | // aggregating errors across devices: this allows us to prefer our original |
394 | // status message over any cancellation related errors. |
395 | Status derived_status = s; |
396 | if (errors::IsCancelled(s) || errors::IsAborted(s)) { |
397 | derived_status = StatusGroup::MakeDerived(s); |
398 | } |
399 | |
400 | local_->StartAbort(derived_status); |
401 | |
402 | bool status_ok = false; |
403 | { |
404 | mutex_lock l(mu_); |
405 | status_ok = status_.ok(); |
406 | if (status_ok) { |
407 | status_ = derived_status; |
408 | } |
409 | } |
410 | |
411 | if (status_ok) { |
412 | // Aborts all active RecvTensor calls. |
413 | mutex_lock l(calls_mu_); |
414 | for (auto& cm_and_token_and_calls : calls_) { |
415 | for (auto& call : cm_and_token_and_calls.second.second) { |
416 | call->StartAbort(derived_status); |
417 | } |
418 | auto* cm = cm_and_token_and_calls.first; |
419 | calls_[cm].second.clear(); |
420 | } |
421 | calls_.clear(); |
422 | } |
423 | } |
424 | |
425 | void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call, |
426 | const Rendezvous::Args& args) { |
427 | CancellationManager* cm = args.cancellation_manager; |
428 | bool already_cancelled = false; |
429 | { |
430 | tf_shared_lock l(mu_); |
431 | if (!status_.ok()) { |
432 | call->StartAbort(status_); |
433 | return; |
434 | } |
435 | } |
436 | |
437 | CancellationToken token = CancellationManager::kInvalidToken; |
438 | if (cm != nullptr) { |
439 | mutex_lock l(calls_mu_); |
440 | auto it = calls_.find(cm); |
441 | if (it == calls_.end()) { |
442 | token = cm->get_cancellation_token(); |
443 | already_cancelled = !cm->RegisterCallback(token, [this, cm]() { |
444 | mutex_lock l(calls_mu_); |
445 | // Abort all the RecvTensor calls associated with thie cancellation |
446 | // manager. |
447 | for (const auto& call : calls_[cm].second) { |
448 | call->StartAbort( |
449 | errors::Cancelled("RecvFromRemoteAsync is cancelled." )); |
450 | } |
451 | }); |
452 | |
453 | if (!already_cancelled) { |
454 | calls_.emplace( |
455 | cm, |
456 | std::make_pair(token, absl::flat_hash_set<BaseRecvTensorCall*>{})); |
457 | } |
458 | } |
459 | } |
460 | |
461 | if (already_cancelled) { |
462 | call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled." )); |
463 | } else { |
464 | mutex_lock l(calls_mu_); |
465 | bool emplaced = calls_[cm].second.emplace(call).second; |
466 | CHECK(emplaced); // Crash OK. |
467 | } |
468 | } |
469 | |
470 | void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call, |
471 | const Rendezvous::Args& args) { |
472 | auto cm = args.cancellation_manager; |
473 | mutex_lock l(calls_mu_); |
474 | CancellationToken token = calls_[cm].first; |
475 | calls_[cm].second.erase(call); |
476 | if (calls_[cm].second.empty()) { |
477 | calls_.erase(cm); |
478 | if (cm != nullptr) { |
479 | cm->TryDeregisterCallback(token); |
480 | } |
481 | } |
482 | } |
483 | |
484 | BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed, |
485 | DoneCallback done) |
486 | : parsed(parsed), done(std::move(done)) {} |
487 | |
488 | } // end namespace tensorflow |
489 | |