/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_rma_local.h"

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

namespace tensorflow {
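// Aborts all buffer exchanges pending in the local rendezvous; outstanding
// (and subsequent) callbacks are invoked with status `s`.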
void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
  buf_rendezvous_.StartAbort(s);
}

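// Receives a tensor from a peer on the same host by consuming the buffer
// the peer posted to buf_rendezvous_ under `key`, then copying it into
// `to_tensor`.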
void CollectiveRemoteAccessLocal::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
          << key;
  if (!peer_is_local) {
    done(
        errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
                         "called with peer_is_local=false"));
    return;
  }

  Device* from_device;
  Status status = dev_mgr_->LookupDevice(peer_device, &from_device);
  if (!status.ok()) {
    done(status);
    return;
  }

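  // Invoked once the producer's buffer for `key` becomes available (or the
  // rendezvous is aborted); on success it copies the producer's buffer into
  // `to_tensor` and then releases the hook.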
  auto consumer_callback = [to_tensor, to_device_ctx, to_device, to_alloc_attr,
                            dev_to_dev_stream_index,
                            done](const Status& status,
                                  BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook in ConsumeBuf callback");
      }
    } else {
      if (hook != nullptr) {
        LOG(ERROR) << "Got hook " << hook << " with status " << s
                   << " from ConsumeBuf";
      }
    }

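    // On success, copy the producer's value into `to_tensor`; the two
    // tensors must be exactly the same size in bytes.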
    if (s.ok()) {
      int64_t recv_bytes = to_tensor->TotalBytes();
      CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
      MemCpyAsync(hook->prod_ctx,    // src DeviceContext
                  to_device_ctx,     // dst DeviceContext
                  hook->prod_dev,    // src Device
                  to_device,         // dst Device
                  hook->prod_attr,   // src AllocatorAttributes
                  to_alloc_attr,     // dst AllocatorAttributes
                  hook->prod_value,  // src Tensor*
                  to_tensor,         // dst Tensor*
                  dev_to_dev_stream_index,
                  [hook, done](const Status& memcpy_status) {
                    // This callback may be executing in the GPUEventMgr
                    // pool, in which case it must be of very short duration
                    // and non-blocking (except e.g. for queue insertion).
                    // It would be safer, though expensive, to transfer
                    // to another thread here.
                    done(memcpy_status);
                    BufRendezvous::DoneWithHook(hook);
                  });
    } else {
      done(s);
      if (hook != nullptr) {
        BufRendezvous::DoneWithHook(hook);
      }
    }
  };
  buf_rendezvous_.ConsumeBuf(key, from_device->name(),
                             from_device->attributes().incarnation(),
                             consumer_callback, cancellation_manager);
}

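// Makes `from_tensor` available to a local peer by posting it to
// buf_rendezvous_ under `key`; `done` is invoked once the consumer has
// finished with the buffer, or the rendezvous is aborted.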
void CollectiveRemoteAccessLocal::PostToPeer(
    const string& peer_device, const string& peer_task, const string& key,
    Device* from_device, DeviceContext* from_device_ctx,
    const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
    const DeviceLocality& client_locality,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "PostToPeer " << this << " key " << key
          << " step_id_=" << step_id_;
  buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
                             from_alloc_attr, done, cancellation_manager);
}

void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
                                                  int64_t timeout_in_ms,
                                                  const StatusCallback& done) {
  // Local devices are assumed to always be healthy, so health checks are
  // never issued for local collectives; reaching this point is a bug.
  done(errors::Internal(
      "CheckPeerHealth is not supposed to be called for local collectives"));
}

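// Copies `src` on `src_dev` into `dst` on `dst_dev`, invoking `done` when
// the copy is complete.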
/*static*/
void CollectiveRemoteAccessLocal::MemCpyAsync(
    DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
    Device* dst_dev, const AllocatorAttributes& src_attr,
    const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
    int dev_to_dev_stream_index, const StatusCallback& done) {
  // We want a real copy to happen, i.e. the bytes inside of src should be
  // transferred to the buffer backing dst.  If src and dst are on different
  // devices then CopyTensor::ViaDMA will do just that.  But if they're both
  // the same CPU, then it will actually just reset dst to point to src.
  // Since this routine is used for copying between devices and within a
  // device, we need to detect and bypass the wrong-semantics case.
  const DeviceType src_device_type(
      src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
  const DeviceType dst_device_type(
      dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
  // For GPU devices when only one compute stream is used (the default)
  // the OpKernelContext does not supply a DeviceContext.  It's assumed
  // that all nodes use the default context.
  if (src_dev_ctx == nullptr && src_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        src_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    src_dev_ctx = dev_info->default_context;
  }
  if (dst_dev_ctx == nullptr && dst_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        dst_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    dst_dev_ctx = dev_info->default_context;
  }
  if (non_cpu_src) CHECK(src_dev_ctx);
  if (non_cpu_dst) CHECK(dst_dev_ctx);
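  // Any non-CPU endpoint requires the DMA path; a pure CPU-to-CPU copy can
  // use plain memcpy, which guarantees a genuine byte copy.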
  if (non_cpu_src || non_cpu_dst) {
    CopyTensor::ViaDMA("",  // edge name (non-existent)
                       src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
                       dst_attr, src, dst, dev_to_dev_stream_index, done);
  } else {
    int64_t bytes = src->TotalBytes();
    DCHECK_EQ(dst->TotalBytes(), bytes);
    memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
    done(OkStatus());
  }
}

}  // namespace tensorflow