1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ |
18 | |
19 | #include "tensorflow/core/common_runtime/device.h" |
20 | #include "tensorflow/core/framework/allocator.h" |
21 | #include "tensorflow/core/framework/device_base.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/types.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | class CopyTensor { |
30 | public: |
31 | typedef void (*CopyFunction)( |
32 | DeviceContext* send_dev_context, DeviceContext* recv_dev_context, |
33 | Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, |
34 | const AllocatorAttributes dst_alloc_attr, const Tensor* input, |
35 | Tensor* output, int dev_to_dev_stream_index, StatusCallback done); |
36 | |
37 | // Copies "input" to "output" between devices accessible to the |
38 | // local process via some DMA-like method. "edge_name" is the name |
39 | // of the tensor being copied, for debugging purposes. Depending on |
40 | // the type of devices and memory in use, the copy may be performed |
41 | // synchronously or asynchronously. 'done' will be invoked only |
42 | // after the copy is actually complete. |
43 | static void ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, |
44 | DeviceContext* recv_dev_context, Device* src, Device* dst, |
45 | const AllocatorAttributes src_alloc_attr, |
46 | const AllocatorAttributes dst_alloc_attr, |
47 | const Tensor* input, Tensor* output, |
48 | int dev_to_dev_stream_index, StatusCallback done, |
49 | bool sync_dst_compute = true); |
50 | |
51 | // Object used to call Register() at static-initialization time. |
52 | // Note: This should only ever be used as a global-static object; no stack |
53 | // or heap instances. |
54 | class Registration { |
55 | public: |
56 | Registration(DeviceType sender_device_type, DeviceType receiver_device_type, |
57 | CopyFunction copy_function) { |
58 | TF_QCHECK_OK(Register(sender_device_type, receiver_device_type, |
59 | copy_function, /*is_pluggable_device=*/false)); |
60 | } |
61 | }; |
62 | |
63 | // Register a function for copying between two specific DeviceTypes. |
64 | // Note: This should only be called via the constructor of |
65 | // CopyTensor::Registration or from PluggableDevice implementation. |
66 | static Status Register(DeviceType sender_device_type, |
67 | DeviceType receiver_device_type, |
68 | CopyFunction copy_function, bool is_pluggable_device); |
69 | }; |
70 | |
71 | void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, |
72 | Allocator* out_allocator, StringPiece edge_name, |
73 | Device* src, Tensor* output, |
74 | DeviceContext* send_dev_context, StatusCallback done); |
75 | |
76 | } // namespace tensorflow |
77 | |
78 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ |
79 | |