/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_rma_local.h"

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

namespace tensorflow {

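// Aborts all pending operations on the local BufRendezvous with status `s`.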
void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
  buf_rendezvous_.StartAbort(s);
}

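// Receives a tensor from a peer on the same worker: looks up the peer device,
// consumes the buffer the peer provides to the shared BufRendezvous under
// `key`, and copies it into `to_tensor`.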
void CollectiveRemoteAccessLocal::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
          << key;
  if (!peer_is_local) {
    done(
        errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
                         "called with peer_is_local=false"));
    return;
  }

  Device* from_device;
  Status status = dev_mgr_->LookupDevice(peer_device, &from_device);
  if (!status.ok()) {
    done(status);
    return;
  }

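  // Invoked by BufRendezvous when the producer's buffer (the hook) becomes
  // available, or with an error status if the rendezvous is aborted.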
  auto consumer_callback = [to_tensor, to_device_ctx, to_device, to_alloc_attr,
                            dev_to_dev_stream_index,
                            done](const Status& status,
                                  BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook in ConsumeBuf callback");
      }
    } else {
      if (hook != nullptr) {
        LOG(ERROR) << "Got hook " << hook << " with status " << s
                   << " from ConsumeBuf";
      }
    }

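    // On success, copy the producer's value into `to_tensor`; on failure,
    // propagate the error and release the hook if there is one.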
    if (s.ok()) {
      int64_t recv_bytes = to_tensor->TotalBytes();
      CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
      MemCpyAsync(hook->prod_ctx,    // src DeviceContext
                  to_device_ctx,     // dst DeviceContext
                  hook->prod_dev,    // src Device
                  to_device,         // dst Device
                  hook->prod_attr,   // src AllocatorAttributes
                  to_alloc_attr,     // dst AllocatorAttributes
                  hook->prod_value,  // src Tensor*
                  to_tensor,         // dst Tensor*
                  dev_to_dev_stream_index,
                  [hook, done](const Status& memcpy_status) {
                    // This callback may execute in the GPUEventMgr pool, in
                    // which case it must be short-lived and non-blocking
                    // (except e.g. for queue insertion).  It would be safer,
                    // though more expensive, to hand off to another thread
                    // here.
                    done(memcpy_status);
                    BufRendezvous::DoneWithHook(hook);
                  });
    } else {
      done(s);
      if (hook != nullptr) {
        BufRendezvous::DoneWithHook(hook);
      }
    }
  };
  buf_rendezvous_.ConsumeBuf(key, from_device->name(),
                             from_device->attributes().incarnation(),
                             consumer_callback, cancellation_manager);
}

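// Posts `from_tensor` to the shared BufRendezvous under `key` so that the
// local consumer (RecvFromPeer) can find it.  `done` is invoked once the
// consumer has released the buffer, or with an error on abort/cancellation.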
void CollectiveRemoteAccessLocal::PostToPeer(
    const string& peer_device, const string& peer_task, const string& key,
    Device* from_device, DeviceContext* from_device_ctx,
    const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
    const DeviceLocality& client_locality,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "PostToPeer " << this << " key " << key
          << " step_id_=" << step_id_;
  buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
                             from_alloc_attr, done, cancellation_manager);
}

void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
                                                  int64_t timeout_in_ms,
                                                  const StatusCallback& done) {
  // Local devices are always assumed healthy, so this should never be called
  // for local collectives.
  done(errors::Internal(
      "CheckPeerHealth is not supposed to be called for local collectives"));
}

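// Copies `src` into `dst`.  If either tensor lives on a non-CPU device the
// copy goes through CopyTensor::ViaDMA; otherwise a plain memcpy is used so
// that a real byte copy happens even when both tensors are on the same CPU.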
/*static*/
void CollectiveRemoteAccessLocal::MemCpyAsync(
    DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
    Device* dst_dev, const AllocatorAttributes& src_attr,
    const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
    int dev_to_dev_stream_index, const StatusCallback& done) {
  // We want a real copy to happen, i.e. the bytes inside of src should be
  // transferred to the buffer backing dst.  If src and dst are on different
  // devices then CopyTensor::ViaDMA will do just that.  But if they're both
  // the same CPU, then it will actually just reset dst to point to src.
  // Since this routine is used for copying between devices and within a
  // device, we need to detect and bypass the wrong-semantics case.
  const DeviceType src_device_type(
      src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
  const DeviceType dst_device_type(
      dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
  // For GPU devices, when only one compute stream is used (the default), the
  // OpKernelContext does not supply a DeviceContext.  It's assumed that all
  // nodes use the default context.
  if (src_dev_ctx == nullptr && src_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        src_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    src_dev_ctx = dev_info->default_context;
  }
  if (dst_dev_ctx == nullptr && dst_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        dst_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    dst_dev_ctx = dev_info->default_context;
  }
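  // A non-CPU endpoint requires a valid DeviceContext for the copy.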
  if (non_cpu_src) CHECK(src_dev_ctx);
  if (non_cpu_dst) CHECK(dst_dev_ctx);
  if (non_cpu_src || non_cpu_dst) {
    CopyTensor::ViaDMA("",  // edge name (non-existent)
                       src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
                       dst_attr, src, dst, dev_to_dev_stream_index, done);
  } else {
    int64_t bytes = src->TotalBytes();
    DCHECK_EQ(dst->TotalBytes(), bytes);
    memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
    done(OkStatus());
  }
}

}  // namespace tensorflow