/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

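// A cancellable, asynchronous RecvBuf RPC to the remote worker that owns the
// source buffer. The request carries the destination buffer address
// (buf_ptr), so a transport that supports remote writes may place the bytes
// directly into the destination tensor.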
class RecvBufCall : public CancellableCall {
 public:
  RecvBufCall(int64_t step_id, const string& peer_device,
              const string& peer_task, const string& key, Device* to_device,
              DeviceContext* to_device_ctx,
              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
              const DeviceLocality& client_locality,
              const DeviceAttributes& server_attributes,
              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, peer_task, wc) {
    req_.set_step_id(step_id);
    req_.set_buf_rendezvous_key(key);
    *req_.mutable_client_locality() = client_locality;
    *req_.mutable_server_locality() = server_attributes.locality();
    req_.set_num_bytes(to_tensor->TotalBytes());
    req_.set_buf_ptr(reinterpret_cast<int64_t>(DMAHelper::base(to_tensor)));
    req_.set_src_device(peer_device);
    req_.set_src_incarnation(server_attributes.incarnation());
    req_.set_dst_device(to_device->name());
    req_.set_request_id(GetUniqueRequestId());
  }

  ~RecvBufCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
  }

  RecvBufRequest req_;
  RecvBufResponse resp_;
};

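// Copies the chunks in extra.tensor_content, in order, into the contiguous
// backing buffer of cpu_tensor. Each chunk may be stored as an absl::Cord
// rather than a flat string, hence the flattening through a temporary
// std::string before the memcpy.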
void PopulateTensorFromExtra(const RecvBufRespExtra& extra,
                             Tensor* cpu_tensor) {
  char* head = reinterpret_cast<char*>(DMAHelper::base(cpu_tensor));
  for (const auto& tensor_content_chunk : extra.tensor_content()) {
    memcpy(head, std::string(tensor_content_chunk).data(),
           tensor_content_chunk.size());
    head += tensor_content_chunk.size();
  }
}

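// Copies the tensor bytes out of `response` into `cpu_tensor` when the
// transport returned them in the transport_options field; a no-op when the
// transport already wrote the bytes into the request's buf_ptr.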
Status PopulateTensorFromResponse(const RecvBufResponse& response,
                                  Tensor* cpu_tensor) {
  const bool has_transport_options = response.has_transport_options();

  // If there are no transport options, then the tensor has already been
  // copied into request.buf_ptr.
  if (!has_transport_options) return OkStatus();

  const int64_t total_bytes = cpu_tensor->TotalBytes();
  int64_t num_bytes = 0;
  RecvBufRespExtra extra;
  if (!response.transport_options().UnpackTo(&extra)) {
    return errors::Internal(
        "Failed to unpack RecvBufResponse.transport_options.");
  }
  for (const auto& chunk : extra.tensor_content()) {
    num_bytes += chunk.size();
  }

  if (num_bytes != total_bytes) {
    return errors::Internal("Tensor size mismatch: RecvBufResponse returned ",
                            num_bytes, " bytes, expected ", total_bytes);
  }
  PopulateTensorFromExtra(extra, cpu_tensor);
  return OkStatus();
}

}  // namespace

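// Receives `to_tensor` from `peer_device` on `peer_task`. The local-peer case
// is delegated to CollectiveRemoteAccessLocal; the remote case issues a
// RecvBufCall, staging through a GPU-compatible CPU tensor when the
// destination device is an accelerator.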
void CollectiveRemoteAccessDistributed::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  if (peer_is_local) {
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
        cancellation_manager, done);
    return;
  }

  // State that needs to be threaded through a couple of async calls
  // in order to make this function completely non-blocking.
  struct State {
    DeviceAttributes server_attributes;
    std::unique_ptr<RecvBufCall> call;
    std::unique_ptr<Tensor> cpu_tensor;
  };
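  // `state` is deleted exactly once on every path below, either directly or
  // by one of the async callbacks that capture it.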
  State* state = new State;

  Status s = dev_resolver_->GetDeviceAttributes(peer_device,
                                                &state->server_attributes);
  if (!s.ok()) {
    delete state;
    done(s);
    return;
  }

  Tensor* dst_tensor = nullptr;
  Device* cpu_dev = nullptr;
  if (to_device->tensorflow_accelerator_device_info()) {
    // Stage the bytes in a CPU tensor, then use a tensor-to-tensor copy to
    // reach the accelerator. Use GPU-registered memory for the CPU tensor so
    // the device copy is faster.
    Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
    if (!status.ok()) {
      delete state;
      done(status);
      return;
    }
    AllocatorAttributes cpu_attr;
    cpu_attr.set_gpu_compatible(true);
    profiler::ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveRemoteAccessDistributed::RecvFromPeer"
        "::recv_buf_callback",
        step_id_, "dynamic", to_tensor->dtype(),
        [to_tensor]() { return to_tensor->shape().DebugString(); });

    state->cpu_tensor =
        std::make_unique<Tensor>(cpu_dev->GetAllocator(cpu_attr),
                                 to_tensor->dtype(), to_tensor->shape());
    dst_tensor = state->cpu_tensor.get();
  } else {
    dst_tensor = to_tensor;
  }

  // Logic to be executed on the RecvBufAsync callback.
  auto recv_buf_callback =
      [this, state, to_device, to_alloc_attr, to_device_ctx, to_tensor, cpu_dev,
       dev_to_dev_stream_index, dst_tensor, done](const Status& s) {
        if (s.ok()) {
          // In this generic implementation the bytes come back in one of two
          // ways:
          // 1. In the response protobuf's transport_options field, or
          // 2. Already copied into the buffer that RecvBufCall::req_.buf_ptr()
          //    pointed at, i.e. dst_tensor: the temporary cpu_tensor when
          //    to_device is an accelerator, or to_tensor itself otherwise.
          //
          // PopulateTensorFromResponse handles the first case and is a no-op
          // in the second. When the final to_tensor is on an accelerator,
          // dst_tensor is the temporary CPU buffer and still has to be copied
          // over to to_tensor below.
          Status status =
              PopulateTensorFromResponse(state->call->resp_, dst_tensor);
          if (!status.ok()) {
            done(status);
            delete state;
            return;
          }

          if (to_device->tensorflow_accelerator_device_info()) {
            AllocatorAttributes cpu_attr;
            cpu_attr.set_gpu_compatible(true);
            CopyTensor::ViaDMA("",  // edge name (non-existent)
                               /*send_dev_ctx=*/nullptr, to_device_ctx, cpu_dev,
                               to_device, cpu_attr, to_alloc_attr, dst_tensor,
                               to_tensor, dev_to_dev_stream_index,
                               [this, state, done](const Status& s) {
                                 delete state;
                                 // This callback must not block, so execute
                                 // done in another thread.
                                 work_queue_->Schedule([s, done] { done(s); });
                               });
            return;
          }
        }
        delete state;
        done(s);
      };

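  // Issue the RecvBuf RPC. The call is also registered with
  // abortion_cancel_mgr_ so that StartAbort() can cancel it while it is in
  // flight.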
  state->call = std::make_unique<RecvBufCall>(
      step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
      to_alloc_attr, dst_tensor, client_locality, state->server_attributes,
      cancellation_manager, worker_cache_);
  CancellationToken abortion_token =
      abortion_cancel_mgr_.get_cancellation_token();
  bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
      abortion_token, [state] { state->call->Cancel(); });
  if (already_aborted) {
    recv_buf_callback(errors::Cancelled("collective ops already aborted"));
  } else {
    state->call->Start([this, abortion_token,
                        done = std::move(recv_buf_callback)](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      done(s);
    });
  }
}

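// Checks the health of `peer_task` with a GetStatus RPC and, when device
// attributes for that task are already cached locally, verifies that the
// remote device incarnations still match, i.e. that the peer has not
// restarted since the local record was made.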
void CollectiveRemoteAccessDistributed::CheckPeerHealth(
    const string& peer_task, int64_t timeout_in_ms,
    const StatusCallback& done) {
  if (peer_task == task_name_) {
    // Fast path if the peer is the worker itself.
    done(OkStatus());
    return;
  }
  // We send a GetStatus RPC to check the health of a peer task. If the RPC
  // succeeds, we verify whether the peer's device incarnations match the
  // local record, if we have one. Note that DeviceResolverInterface always
  // caches the device attributes.
  WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
  if (wi == nullptr) {
    done(errors::InvalidArgument(peer_task,
                                 " not found. It's probably invalid. The "
                                 "valid form is /job:xxx/replica:0/task:N"));
    return;
  }
  auto opts = new CallOptions();
  opts->SetTimeout(timeout_in_ms);
  auto req = new GetStatusRequest();
  auto resp = new GetStatusResponse();
  // Note that fail_fast is not always respected, so we set a timeout as well.
  // We're not using CancellableCall since CheckPeerHealth shouldn't need to
  // be cancelled.
  wi->GetStatusAsync(
      opts, req, resp, /*fail_fast=*/true,
      [this, opts, req, resp, wi, peer_task, done](Status s) {
        std::vector<DeviceAttributes> cached_attrs;
        if (s.ok()) {
          s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
        }
        if (s.ok()) {
          absl::flat_hash_set<uint64> remote_incarnations;
          for (const DeviceAttributes& da : resp->device_attributes()) {
            remote_incarnations.insert(da.incarnation());
          }
          for (const DeviceAttributes& attr : cached_attrs) {
            if (!remote_incarnations.contains(attr.incarnation())) {
              s = errors::FailedPrecondition(
                  attr.name(), " with incarnation ", attr.incarnation(),
                  " is not available. This usually means ", peer_task,
                  " has restarted");
              break;
            }
          }
        } else if (errors::IsNotFound(s)) {
          // Skip validating device incarnations if we don't know what they
          // should be. The device attributes are cached after the first
          // collective.
          s = OkStatus();
        }
        delete opts;
        delete req;
        delete resp;
        worker_cache_->ReleaseWorker(peer_task, wi);
        done(s);
      });
}

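// Aborts the local collective implementation, then cancels any RecvBufCalls
// still registered with abortion_cancel_mgr_.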
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
  CollectiveRemoteAccessLocal::StartAbort(s);
  abortion_cancel_mgr_.StartCancel();
}

}  // namespace tensorflow