/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

#include <unordered_set>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"

namespace tensorflow {

namespace {
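// Completes an intra-worker Recv: makes the tensor sent as "in" available to
// the receiver as "out". If both endpoints resolve to host memory, "out"
// simply aliases "in"'s buffer; otherwise the tensor is copied across devices
// via CopyTensor::ViaDMA. "done" is invoked exactly once with the resulting
// status.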
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    if (VLOG_IS_ON(3)) {
      bool src_override =
          send_args.alloc_attrs.on_host() && !(parsed.src.type == "CPU");
      bool dst_override =
          recv_args.alloc_attrs.on_host() && !(parsed.dst.type == "CPU");
      if (src_override || dst_override) {
        VLOG(3) << "Shortcut to keep tensor on host (src_override "
                << src_override << " and dst_override " << dst_override
                << ") tensor dtype:" << DataTypeString(in.dtype()) << " "
                << parsed.FullKey();
      }
    }
    *out = in;
    done(OkStatus());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  profiler::ScopedMemoryDebugAnnotation op_annotation(
      "SameWorkerRecvDone", 0, "dynamic", in.dtype(),
      [&in]() { return in.shape().DebugString(); });
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
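  // By default, ask ViaDMA to synchronize with the destination device's
  // compute stream before copying. When a timestamped allocator is active
  // (see below), allocation ordering is handled via freed_by_func instead.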
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if ((parsed.dst.type == "GPU" ||
         DeviceFactory::IsPluggableDevice(parsed.dst.type)) &&
        safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}

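// Common RecvAsync implementation shared by both intra-process rendezvous
// classes below: receives the tensor from the LocalRendezvous and, on
// success, hands it to SameWorkerRecvDone to place it on the receiver's
// device.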
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is initialized, allocate a fresh output tensor for
        // SameWorkerRecvDone to fill in. Otherwise, copy-construct from "in"
        // to preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          final_callback(status);
        }
      });
}

}  // namespace

RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}

RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {}

Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              const Tensor& val,
                                              const bool is_dead) {
  VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                                 const Rendezvous::Args& args,
                                                 DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}

Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() {
  return local_.status();
}

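// PrivateIntraProcessRendezvous is the non-reference-counted counterpart of
// RefCountedIntraProcessRendezvous; it is owned directly by its creator
// (e.g., allocated on the stack).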
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}

PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {}

Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key,
                                           const Rendezvous::Args& args,
                                           const Tensor& val,
                                           const bool is_dead) {
  DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey();
  return local_.Send(key, args, val, is_dead);
}

void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key,
                                              const Rendezvous::Args& args,
                                              DoneCallback done) {
  DVLOG(1) << "PrivateIntraProcessRendezvous Recv " << this << " "
           << key.FullKey();
  IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done));
}

void PrivateIntraProcessRendezvous::StartAbort(const Status& s) {
  local_.StartAbort(s);
}
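
// Usage sketch (illustrative only; error handling elided). The key string
// below follows the format produced by Rendezvous::CreateKey():
//
//   PrivateIntraProcessRendezvous rendez(device_mgr);
//   Rendezvous::ParsedKey key;
//   TF_CHECK_OK(Rendezvous::ParseKey(
//       "/job:localhost/replica:0/task:0/device:CPU:0;0000000000000001;"
//       "/job:localhost/replica:0/task:0/device:CPU:0;edge_1;0:0",
//       &key));
//   TF_CHECK_OK(rendez.Send(key, Rendezvous::Args(), tensor,
//                           /*is_dead=*/false));
//   rendez.RecvAsync(key, Rendezvous::Args(),
//                    [](const Status& s, const Rendezvous::Args& send_args,
//                       const Rendezvous::Args& recv_args, const Tensor& t,
//                       bool is_dead) { /* consume t */ });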

}  // end namespace tensorflow