/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"

#include <memory>

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

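// Cancellable wrapper around the RecvBuf RPC: builds a request asking a
// remote worker for the bytes of a tensor and records where they should land.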
class RecvBufCall : public CancellableCall {
 public:
  RecvBufCall(int64_t step_id, const string& peer_device,
              const string& peer_task, const string& key, Device* to_device,
              DeviceContext* to_device_ctx,
              const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
              const DeviceLocality& client_locality,
              const DeviceAttributes& server_attributes,
              CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, peer_task, wc) {
    req_.set_step_id(step_id);
    req_.set_buf_rendezvous_key(key);
    *req_.mutable_client_locality() = client_locality;
    *req_.mutable_server_locality() = server_attributes.locality();
    req_.set_num_bytes(to_tensor->TotalBytes());
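    // Address of the destination buffer. Transports that support it may write
    // the received bytes directly here, in which case the response carries no
    // transport_options (see PopulateTensorFromResponse below).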
    req_.set_buf_ptr(reinterpret_cast<int64_t>(DMAHelper::base(to_tensor)));
    req_.set_src_device(peer_device);
    req_.set_src_incarnation(server_attributes.incarnation());
    req_.set_dst_device(to_device->name());
    req_.set_request_id(GetUniqueRequestId());
  }

  ~RecvBufCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
  }

  RecvBufRequest req_;
  RecvBufResponse resp_;
};

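// Copies the chunked tensor bytes carried in `extra` into the flat buffer
// backing `cpu_tensor`. Assumes the chunk sizes have already been validated
// to sum to the tensor's byte size.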
void PopulateTensorFromExtra(const RecvBufRespExtra& extra,
                             Tensor* cpu_tensor) {
  char* head = reinterpret_cast<char*>(DMAHelper::base(cpu_tensor));
  for (const auto& tensor_content_chunk : extra.tensor_content()) {
    // Materialize the chunk as a std::string to obtain contiguous bytes
    // before copying.
    memcpy(head, std::string(tensor_content_chunk).data(),
           tensor_content_chunk.size());
    head += tensor_content_chunk.size();
  }
}

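// Fills `cpu_tensor` from `response`: a no-op when the bytes were already
// written into the request's buf_ptr, otherwise unpacks the tensor content
// chunks from transport_options after checking the total byte count.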
Status PopulateTensorFromResponse(const RecvBufResponse& response,
                                  Tensor* cpu_tensor) {
  const bool has_transport_options = response.has_transport_options();

  // If there are no transport options, then the tensor has already been
  // copied into request.buf_ptr.
  if (!has_transport_options) return OkStatus();

  const int64_t total_bytes = cpu_tensor->TotalBytes();
  int64_t num_bytes = 0;
  RecvBufRespExtra extra;
  response.transport_options().UnpackTo(&extra);
  for (const auto& chunk : extra.tensor_content()) {
    num_bytes += chunk.size();
  }

  if (num_bytes != total_bytes) {
    return errors::Internal("Tensor Size Mismatch: RecvBufResponse returned ",
                            num_bytes, " bytes, expected: ", total_bytes);
  }
  PopulateTensorFromExtra(extra, cpu_tensor);
  return OkStatus();
}

}  // namespace

void CollectiveRemoteAccessDistributed::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  if (peer_is_local) {
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
        cancellation_manager, done);
    return;
  }

  // State that needs to be threaded through a couple of async calls
  // in order to make this function completely non-blocking.
  struct State {
    DeviceAttributes server_attributes;
    std::unique_ptr<RecvBufCall> call;
    std::unique_ptr<Tensor> cpu_tensor;
  };
  State* state = new State;

  Status s = dev_resolver_->GetDeviceAttributes(peer_device,
                                                &state->server_attributes);
  if (!s.ok()) {
    delete state;
    done(s);
    return;
  }

  Tensor* dst_tensor = nullptr;
  Device* cpu_dev = nullptr;
  if (to_device->tensorflow_accelerator_device_info()) {
    // Move the bytes into a CPU tensor first, then use a tensor-to-tensor
    // copy. Use GPU-registered memory for the CPU tensor so the transfer
    // goes faster.

    Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
    if (!status.ok()) {
      delete state;
      done(status);
      return;
    }
    AllocatorAttributes cpu_attr;
    cpu_attr.set_gpu_compatible(true);
    profiler::ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveRemoteAccessDistributed::RecvFromPeer"
        "::recv_buf_callback",
        step_id_, "dynamic", to_tensor->dtype(),
        [to_tensor]() { return to_tensor->shape().DebugString(); });

    state->cpu_tensor =
        std::make_unique<Tensor>(cpu_dev->GetAllocator(cpu_attr),
                                 to_tensor->dtype(), to_tensor->shape());
    dst_tensor = state->cpu_tensor.get();
  } else {
    dst_tensor = to_tensor;
  }

  // Logic to be executed on the RecvBufAsync callback.
  auto recv_buf_callback =
      [this, state, to_device, to_alloc_attr, to_device_ctx, to_tensor, cpu_dev,
       dev_to_dev_stream_index, dst_tensor, done](const Status& s) {
        if (s.ok()) {
          // In this generic implementation the bytes come back in one of two
          // ways:
          // 1. In the transport_options field of the response proto, or
          // 2. Already copied into the buffer that RecvBufCall::req_.buf_ptr()
          //    pointed at, i.e. dst_tensor: the temporary cpu_tensor if
          //    to_device is an accelerator device, or to_tensor itself
          //    otherwise.
          //
          // PopulateTensorFromResponse handles the first case and is a no-op
          // in the second. If the final to_tensor lives on an accelerator,
          // buf_ptr points at the temporary CPU buffer, which still needs to
          // be copied over to to_tensor below.
          Status status =
              PopulateTensorFromResponse(state->call->resp_, dst_tensor);
          if (!status.ok()) {
            done(status);
            delete state;
            return;
          }

          if (to_device->tensorflow_accelerator_device_info()) {
            AllocatorAttributes cpu_attr;
            cpu_attr.set_gpu_compatible(true);
            CopyTensor::ViaDMA("",  // edge name (non-existent)
                               nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
                               to_device, cpu_attr, to_alloc_attr, dst_tensor,
                               to_tensor, dev_to_dev_stream_index,
                               [this, state, done](const Status& s) {
                                 delete state;
                                 // This callback must not block, so execute
                                 // done in another thread.
                                 work_queue_->Schedule([s, done] { done(s); });
                               });
            return;
          }
        }
        delete state;
        done(s);
      };

  state->call.reset(new RecvBufCall(
      step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
      to_alloc_attr, dst_tensor, client_locality, state->server_attributes,
      cancellation_manager, worker_cache_));
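  // Register with the abortion cancellation manager so the pending
  // RecvBufCall is cancelled if StartAbort() fires while it is in flight.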
  CancellationToken abortion_token =
      abortion_cancel_mgr_.get_cancellation_token();
  bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
      abortion_token, [state] { state->call->Cancel(); });
  if (already_aborted) {
    recv_buf_callback(errors::Cancelled("collective ops already aborted"));
  } else {
    state->call->Start([this, abortion_token,
                        done = std::move(recv_buf_callback)](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      done(s);
    });
  }
}

void CollectiveRemoteAccessDistributed::CheckPeerHealth(
    const string& peer_task, int64_t timeout_in_ms,
    const StatusCallback& done) {
  if (peer_task == task_name_) {
    // Fast path if the peer is the worker itself.
    done(OkStatus());
    return;
  }
  // We send a GetStatus RPC to check the health of a peer task. If the RPC
  // succeeds, we also verify that the peer's device incarnations match our
  // local record, if we have one. Note that DeviceResolverInterface always
  // caches the device attributes.
  WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
  if (wi == nullptr) {
    done(errors::InvalidArgument(peer_task,
                                 " not found. It's probably invalid. The "
                                 "valid form is /job:xxx/replica:0/task:N"));
    return;
  }
  auto opts = new CallOptions();
  opts->SetTimeout(timeout_in_ms);
  auto req = new GetStatusRequest();
  auto resp = new GetStatusResponse();
  // Note that fail_fast is not always respected, so we set a timeout as well.
  // We're not using CancellableCall since check health shouldn't need to be
  // cancelled.
  wi->GetStatusAsync(
      opts, req, resp, /*fail_fast*/ true,
      [this, opts, req, resp, wi, peer_task, done](Status s) {
        std::vector<DeviceAttributes> cached_attrs;
        if (s.ok()) {
          s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs);
        }
        if (s.ok()) {
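          // A cached incarnation that no longer appears in the peer's device
          // attributes means the peer has restarted since we recorded it.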
          absl::flat_hash_set<uint64> remote_incarnations;
          for (const DeviceAttributes& da : resp->device_attributes()) {
            remote_incarnations.insert(da.incarnation());
          }
          for (const DeviceAttributes& attr : cached_attrs) {
            if (!remote_incarnations.contains(attr.incarnation())) {
              s = errors::FailedPrecondition(
                  attr.name(), " with incarnation ", attr.incarnation(),
                  " is not available. This usually means ", peer_task,
                  " has restarted");
              break;
            }
          }
        } else if (errors::IsNotFound(s)) {
          // Skip validating device incarnation if we don't know what the
          // incarnation should be. The device attributes are cached after
          // the first collective.
          s = OkStatus();
        }
        delete opts;
        delete req;
        delete resp;
        worker_cache_->ReleaseWorker(peer_task, wi);
        done(s);
      });
}

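// Aborts pending collectives locally and cancels any in-flight RecvBufCalls
// registered with the abortion cancellation manager.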
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
  CollectiveRemoteAccessLocal::StartAbort(s);
  abortion_cancel_mgr_.StartCancel();
}

}  // namespace tensorflow