1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
17 | |
18 | #include <unordered_set> |
19 | |
20 | #include "tensorflow/core/common_runtime/copy_tensor.h" |
21 | #include "tensorflow/core/common_runtime/device.h" |
22 | #include "tensorflow/core/common_runtime/device_mgr.h" |
23 | #include "tensorflow/core/framework/allocator.h" |
24 | #include "tensorflow/core/framework/device_factory.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/lib/core/notification.h" |
28 | #include "tensorflow/core/lib/strings/numbers.h" |
29 | #include "tensorflow/core/lib/strings/str_util.h" |
30 | #include "tensorflow/core/platform/logging.h" |
31 | #include "tensorflow/core/platform/mutex.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | namespace { |
// Delivers tensor "in" (produced by a Send on this worker) to "out" for a
// Recv on the same worker, then invokes "done" exactly once.
//
// Fast path: when both endpoints are host memory, "out" shares "in"'s
// underlying buffer and no data is moved. Otherwise the tensor is copied
// with CopyTensor::ViaDMA, which may complete (and call "done")
// asynchronously.
void SameWorkerRecvDone(const DeviceMgr* device_mgr,
                        const Rendezvous::ParsedKey& parsed,
                        const Rendezvous::Args& send_args,
                        const Rendezvous::Args& recv_args, const Tensor& in,
                        Tensor* out, StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    if (VLOG_IS_ON(3)) {
      // "override" = a non-CPU endpoint whose alloc_attrs nonetheless place
      // the tensor in host memory, which is what enabled this shortcut.
      bool src_override =
          send_args.alloc_attrs.on_host() && !(parsed.src.type == "CPU");
      bool dst_override =
          recv_args.alloc_attrs.on_host() && !(parsed.dst.type == "CPU");
      if (src_override || dst_override) {
        VLOG(3) << "Shortcut to keep tensor on host (src_override "
                << src_override << " and dst_override " << dst_override
                << ") tensor dtype:" << DataTypeString(in.dtype()) << " "
                << parsed.FullKey();
      }
    }
    // Buffer-sharing assignment; no copy.
    *out = in;
    done(OkStatus());
    return;
  }

  // This copy must involve a non-CPU device. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DataTypeCanUseMemcpy(in.dtype()) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both endpoint devices; fail the recv if either lookup fails.
  Device* src_device;
  Status s = device_mgr->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = device_mgr->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Attribute the destination-side allocation below to this op for memory
  // debugging/profiling.
  profiler::ScopedMemoryDebugAnnotation op_annotation(
      "SameWorkerRecvDone", 0, "dynamic", in.dtype(),
      [&in]() { return in.shape().DebugString(); });
  // The output buffer must be GPU-compatible if either side requires it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  bool sync_dst_compute = true;
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    AllocationAttributes aa;
    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
    // NOTE(review): "aa" below stores a raw pointer to this stack-local
    // std::function; this assumes the allocator only invokes it during the
    // Tensor construction in this frame — confirm against the timestamped
    // allocator implementation.
    std::function<uint64()> freed_by_func = [dst_device,
                                             &safe_alloc_frontier]() {
      safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
      return safe_alloc_frontier;
    };
    if ((parsed.dst.type == "GPU" ||
         DeviceFactory::IsPluggableDevice(parsed.dst.type)) &&
        safe_alloc_frontier > 0) {
      // There's a timestamped allocator at work, so use it instead
      // of sync_dst_compute.
      aa.freed_by_func = &freed_by_func;
      sync_dst_compute = false;
    }
    Tensor copy(out_allocator, in.dtype(), in.shape(), aa);
    *out = copy;
    // A non-empty shape with a null backing buffer means the allocation
    // failed; report OOM instead of proceeding with the DMA.
    if (in.shape().num_elements() > 0 && out->data() == nullptr) {
      done(tensorflow::errors::ResourceExhausted(
          "SameWorkerRecvDone unable to allocate output tensor. Key: ",
          parsed.FullKey()));
      return;
    }
  }

  // ViaDMA takes ownership of "done" and invokes it when the copy completes.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
131 | |
// Shared implementation behind {RefCounted,Private}IntraProcessRendezvous::
// RecvAsync: receives the tensor for "parsed" from the local rendezvous
// table "local", finishes same-worker delivery via SameWorkerRecvDone on
// success, and finally invokes "done".
void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr,
                               LocalRendezvous* local,
                               const RendezvousInterface::ParsedKey& parsed,
                               const Rendezvous::Args& recv_args,
                               RendezvousInterface::DoneCallback done) {
  VLOG(1) << "IntraProcessRendezvous Recv " << local << " " << parsed.FullKey();

  profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync");
  // Recv the tensor from local_.
  local->RecvAsync(
      parsed, recv_args,
      [device_mgr, parsed, done = std::move(done)](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& in,
          bool is_dead) mutable {
        // If "in" is an uninitialized tensor, do copy-construction to
        // preserve the uninitialized state, along with data type and shape
        // info, which is useful for debugger purposes.
        // "out" is heap-allocated because SameWorkerRecvDone may complete
        // asynchronously; ownership passes to final_callback, which deletes
        // it after handing the value to "done".
        Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in);

        auto final_callback = [send_args, recv_args, out, is_dead,
                               done = std::move(done)](const Status& s) {
          done(s, send_args, recv_args, *out, is_dead);
          delete out;
        };

        if (status.ok() && in.IsInitialized()) {
          SameWorkerRecvDone(device_mgr, parsed, send_args, recv_args, in, out,
                             std::move(final_callback));
        } else {
          // Recv failed, or the tensor is dead/uninitialized; report the
          // status (and the empty or copy-constructed tensor) directly.
          final_callback(status);
        }
      });
}
166 | |
167 | } // namespace |
168 | |
// Constructs a ref-counted intra-process rendezvous over "device_mgr"
// (not owned). "local_" receives a back-pointer to this object —
// presumably so the local table can reference its owning rendezvous;
// confirm against LocalRendezvous.
RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(this) {}
172 | |
173 | RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {} |
174 | |
175 | Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key, |
176 | const Rendezvous::Args& args, |
177 | const Tensor& val, |
178 | const bool is_dead) { |
179 | VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); |
180 | return local_.Send(key, args, val, is_dead); |
181 | } |
182 | |
183 | void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key, |
184 | const Rendezvous::Args& args, |
185 | DoneCallback done) { |
186 | VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key.FullKey(); |
187 | IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); |
188 | } |
189 | |
190 | void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) { |
191 | local_.StartAbort(s); |
192 | } |
193 | |
194 | Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() { |
195 | return local_.status(); |
196 | } |
197 | |
// Constructs a private (non-ref-counted) intra-process rendezvous over
// "device_mgr" (not owned). Unlike RefCountedIntraProcessRendezvous,
// "local_" is given no owner back-pointer (nullptr).
PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous(
    const DeviceMgr* device_mgr)
    : device_mgr_(device_mgr), local_(nullptr) {}
201 | |
202 | PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {} |
203 | |
204 | Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key, |
205 | const Rendezvous::Args& args, |
206 | const Tensor& val, |
207 | const bool is_dead) { |
208 | DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); |
209 | return local_.Send(key, args, val, is_dead); |
210 | } |
211 | |
212 | void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key, |
213 | const Rendezvous::Args& args, |
214 | DoneCallback done) { |
215 | DVLOG(1) << "StackAllocatedIntraProcessRendezvous Recv " << this << " " |
216 | << key.FullKey(); |
217 | IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); |
218 | } |
219 | |
220 | void PrivateIntraProcessRendezvous::StartAbort(const Status& s) { |
221 | local_.StartAbort(s); |
222 | } |
223 | |
224 | } // end namespace tensorflow |
225 | |