/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/copy_tensor.h"

#include <atomic>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {
namespace {

struct RegistrationInfo {
  RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf,
                   bool is_pluggable_device)
      : sender_device_type(std::move(s)),
        receiver_device_type(std::move(r)),
        copy_function(cf),
        is_pluggable_device(is_pluggable_device) {}
  DeviceType sender_device_type;
  DeviceType receiver_device_type;
  CopyTensor::CopyFunction copy_function;
  bool is_pluggable_device;
};

// We use a vector instead of a map since we expect there to be very
// few registrations.
std::vector<RegistrationInfo>* MutableRegistry() {
  static std::vector<RegistrationInfo>* registry =
      new std::vector<RegistrationInfo>;
  return registry;
}

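// Copies `input` (a host tensor) into `output` on device `dst`.
// DT_VARIANT tensors are copied element by element through the variant
// device-copy registry, recursing for nested variants; DT_RESOURCE tensors
// are copied by handle (shallow assignment); everything else goes through
// CopyCPUTensorToDevice. `done` is invoked exactly once with the overall
// status.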
void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* dst, Tensor* output,
                      DeviceContext* recv_dev_context, StatusCallback done,
                      bool sync_dst_compute) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
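    // `status_cb` starts with one reference, released by `status_cb_unref`
    // when this function returns. Each per-element async copy takes an extra
    // reference that `wrapped_done` releases, so `done` fires with the
    // combined status only after the last element copy completes.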
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier = [dst, recv_dev_context, out_allocator, status_cb,
                   cpu_allocator, edge_name, sync_dst_compute,
                   wrapped_done = std::move(wrapped_done)](const Tensor& from,
                                                           Tensor* to) {
      if (from.dtype() == DT_VARIANT) {
        status_cb->Ref();
        CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name, dst,
                         to, recv_dev_context, wrapped_done, sync_dst_compute);
        return OkStatus();
      } else {
        if (!DMAHelper::CanUseDMA(&from)) {
          Status err = errors::InvalidArgument(
              "During Variant Host->Device Copy: "
              "non-DMA-copy attempted of tensor type: ",
              DataTypeString(from.dtype()));
          status_cb->UpdateStatus(err);
          return err;
        }
        if (status_cb->ok()) {
          status_cb->Ref();
          *to = Tensor(out_allocator, from.dtype(), from.shape());
          recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, wrapped_done,
                                                  sync_dst_compute);
          return OkStatus();
        } else {
          return status_cb->status();
        }
      }
    };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64_t i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::HOST_TO_DEVICE, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(OkStatus());
  } else {
    recv_dev_context->CopyCPUTensorToDevice(input, dst, output,
                                            std::move(done), sync_dst_compute);
  }
}

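// Copies `input` on device `src` into `output` on device `dst` using the
// registered device-to-device `copy_function`. DT_VARIANT and DT_RESOURCE
// tensors receive the same special handling as in CopyHostToDevice.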
void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
                        Allocator* cpu_allocator, Allocator* out_allocator,
                        DeviceContext* send_dev_context,
                        DeviceContext* recv_dev_context, Device* src,
                        Device* dst, const AllocatorAttributes src_alloc_attr,
                        const AllocatorAttributes dst_alloc_attr,
                        const Tensor* input, Tensor* output,
                        int dev_to_dev_stream_index, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier = [copy_function, cpu_allocator, src, dst, src_alloc_attr,
                   dst_alloc_attr, recv_dev_context, send_dev_context,
                   out_allocator, status_cb, dev_to_dev_stream_index,
                   wrapped_done = std::move(wrapped_done)](
                      // Begin unbound arguments
                      const Tensor& from, Tensor* to) {
      if (from.dtype() == DT_VARIANT) {
        status_cb->Ref();
        CopyDeviceToDevice(copy_function, cpu_allocator, out_allocator,
                           send_dev_context, recv_dev_context, src, dst,
                           src_alloc_attr, dst_alloc_attr, &from, to,
                           dev_to_dev_stream_index, wrapped_done);
        return OkStatus();
      } else {
        if (!DMAHelper::CanUseDMA(&from)) {
          Status err = errors::InvalidArgument(
              "During Variant Device->Device Copy: ", src->name(), " to ",
              dst->name(), " non-DMA-copy attempted of tensor type: ",
              DataTypeString(from.dtype()));
          status_cb->UpdateStatus(err);
          return err;
        }
        if (status_cb->ok()) {
          status_cb->Ref();
          *to = Tensor(out_allocator, from.dtype(), from.shape());
          copy_function(send_dev_context, recv_dev_context, src, dst,
                        src_alloc_attr, dst_alloc_attr, &from, to,
                        dev_to_dev_stream_index, wrapped_done);
          return OkStatus();
        } else {
          return status_cb->status();
        }
      }
    };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64_t i = 0; i < input->NumElements(); ++i) {
      s_copy_init =
          VariantDeviceCopy(VariantDeviceCopyDirection::DEVICE_TO_DEVICE, v[i],
                            &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(OkStatus());
  } else {
    copy_function(send_dev_context, recv_dev_context, src, dst, src_alloc_attr,
                  dst_alloc_attr, input, output, dev_to_dev_stream_index,
                  std::move(done));
  }
}

}  // namespace

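// Dispatches the copy to one of the helpers above based on whether the source
// and destination tensors live in host or device memory:
//   device -> device : a registered CopyFunction, or staged through the host
//                      when none is registered for the device pair,
//   device -> host   : CopyDeviceToHost,
//   host   -> device : CopyHostToDevice,
//   host   -> host   : a plain tensor assignment.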
// static
void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
                        DeviceContext* recv_dev_context, Device* src,
                        Device* dst, const AllocatorAttributes src_alloc_attr,
                        const AllocatorAttributes dst_alloc_attr,
                        const Tensor* input, Tensor* output,
                        int dev_to_dev_stream_index, StatusCallback done,
                        bool sync_dst_compute) {
  profiler::ScopedAnnotation annotation(
      [&] { return absl::StrCat("#edge_name=", edge_name, "#"); });
  VLOG(1) << "Copy " << edge_name;

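  // A tensor whose allocator attributes say it lives in host memory is routed
  // as a CPU tensor, regardless of the device it is nominally placed on.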
  const DeviceType src_device_type(
      src_alloc_attr.on_host() ? DEVICE_CPU : src->attributes().device_type());
  const DeviceType dst_device_type(
      dst_alloc_attr.on_host() ? DEVICE_CPU : dst->attributes().device_type());
  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);

  // TODO(phawkins): choose an allocator optimal for both the src and dst
  // devices, not just the src device.
  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = src->GetAllocator(host_alloc_attrs);
  Allocator* out_allocator = dst->GetAllocator(dst_alloc_attr);

  // E.g., gpu -> gpu
  if (non_cpu_src && non_cpu_dst) {
    // Device to device copy. Look through registry for an appropriate
    // CopyFunction.
    std::vector<RegistrationInfo>* registry = MutableRegistry();
    // TODO(penpornk): Revisit the lookup mechanism after PR #43611 (device
    // alias) is resolved.
    const bool src_device_is_pluggable =
        DeviceFactory::IsPluggableDevice(src_device_type.type_string());
    for (const RegistrationInfo& ri : *registry) {
      if (ri.sender_device_type == src_device_type &&
          ri.receiver_device_type == dst_device_type) {
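        // A pluggable source device must use a registration that was added
        // for a PluggableDevice; skip plain registrations for it.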
        if (src_device_is_pluggable && !ri.is_pluggable_device) continue;
        CopyDeviceToDevice(ri.copy_function, cpu_allocator, out_allocator,
                           send_dev_context, recv_dev_context, src, dst,
                           src_alloc_attr, dst_alloc_attr, input, output,
                           dev_to_dev_stream_index, std::move(done));
        return;
      }
    }

    // Fall back to copying via the host.
    VLOG(1) << "No function registered to copy from devices of type "
            << src_device_type.type() << " to devices of type "
            << dst_device_type.type()
            << ". Falling back to copying via the host.";

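    // Stage through a temporary host tensor: copy device->host first, then
    // host->device, deleting the temporary once the second copy finishes (or
    // as soon as the first copy fails).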
    Tensor* cpu_tensor =
        new Tensor(cpu_allocator, input->dtype(), input->shape());
    auto delete_and_done = [cpu_tensor,
                            done = std::move(done)](const Status& status) {
      delete cpu_tensor;
      done(status);
    };
    auto then_copy_to_other_device =
        [delete_and_done = std::move(delete_and_done), recv_dev_context,
         cpu_tensor, cpu_allocator, out_allocator, edge_name, dst, output,
         sync_dst_compute](Status status) {
          if (!status.ok()) {
            delete_and_done(status);
            return;
          }
          CopyHostToDevice(cpu_tensor, cpu_allocator, out_allocator, edge_name,
                           dst, output, recv_dev_context,
                           std::move(delete_and_done), sync_dst_compute);
        };
    CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
                     cpu_tensor, send_dev_context,
                     std::move(then_copy_to_other_device));
    return;
  }

  // E.g., gpu -> cpu
  if (non_cpu_src && !non_cpu_dst) {
    // Device to host copy.
    CopyDeviceToHost(input, cpu_allocator, out_allocator, edge_name, src,
                     output, send_dev_context, std::move(done));
    return;
  }

  // E.g., cpu -> gpu
  if (!non_cpu_src && non_cpu_dst) {
    // Host to Device copy.
    CopyHostToDevice(input, cpu_allocator, out_allocator, edge_name, dst,
                     output, recv_dev_context, std::move(done),
                     sync_dst_compute);
    return;
  }

  // cpu -> cpu
  CHECK(!non_cpu_src && !non_cpu_dst);
  *output = *input;
  done(OkStatus());
}

// static
Status CopyTensor::Register(DeviceType sender_device_type,
                            DeviceType receiver_device_type,
                            CopyFunction copy_function,
                            bool is_pluggable_device) {
  std::vector<RegistrationInfo>* registry = MutableRegistry();
  registry->emplace_back(sender_device_type, receiver_device_type,
                         copy_function, is_pluggable_device);
  return OkStatus();
}

namespace {

// The following registrations enable a DT_VARIANT tensor element that contains
// a wrapped `tensorflow::Tensor` to be copied between devices.
static Status WrappedTensorDeviceCopy(
    const Tensor& from, Tensor* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (DMAHelper::CanUseDMA(&from)) {
    TF_RETURN_IF_ERROR(copy(from, to));
  } else {
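    // Wrapped tensors that cannot be DMA-copied (e.g. DT_STRING) fall back to
    // a plain Tensor assignment, which shares the existing buffer rather than
    // moving any data across devices.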
    *to = from;
  }

  return OkStatus();
}

#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION)         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      Tensor, DIRECTION, WrappedTensorDeviceCopy)

REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

}  // namespace

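// Copies `input` on device `src` into `output`, a host tensor, mirroring
// CopyHostToDevice: DT_VARIANT element by element, DT_RESOURCE by handle,
// and everything else through CopyDeviceTensorToCPU.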
void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* src, Tensor* output,
                      DeviceContext* send_dev_context, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);

    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier = [edge_name, src, send_dev_context, out_allocator, status_cb,
                   cpu_allocator, wrapped_done = std::move(wrapped_done)](
                      const Tensor& from, Tensor* to) {
      if (from.dtype() == DT_VARIANT) {
        status_cb->Ref();
        CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name, src,
                         to, send_dev_context, wrapped_done);
        return OkStatus();
      } else {
        if (!DMAHelper::CanUseDMA(&from)) {
          Status err = errors::InvalidArgument(
              "During Variant Device->Host Copy: "
              "non-DMA-copy attempted of tensor type: ",
              DataTypeString(from.dtype()));
          status_cb->UpdateStatus(err);
          return err;
        }
        if (status_cb->ok()) {
          status_cb->Ref();
          *to = Tensor(out_allocator, from.dtype(), from.shape());
          send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
                                                  wrapped_done);
          return OkStatus();
        } else {
          return status_cb->status();
        }
      }
    };

    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64_t i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else if (input->dtype() == DT_RESOURCE) {
    *output = *input;
    done(OkStatus());
  } else {
    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
                                            std::move(done));
  }
}

}  // namespace tensorflow