/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/error_payloads.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/device_profiler_session.h"
#include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"

namespace tensorflow {

Worker::Worker(WorkerEnv* env)
    : env_(env), recent_request_ids_(100000, env_->experimental_num_shards) {
  DCHECK_GT(env_->experimental_num_shards, 0);

  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

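// Reports the attributes of every device managed by this worker. `opts` and
// `fail_fast` are accepted for interface compatibility but are not consulted
// by this base implementation.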
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(OkStatus());
}

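// Creates a new WorkerSession keyed by the master-supplied session handle;
// all arguments are forwarded to SessionMgr::CreateSession.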
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

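// Tears down the WorkerSession identified by the request's session handle.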
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

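// Registers the request's graph with the GraphMgr of the resolved session and
// returns the resulting graph handle in the response. Requests from masters
// that never called CreateWorkerSession fall back to the legacy session.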
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(), session.get(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

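// Removes a previously registered graph, identified by its graph handle, from
// the session's GraphMgr.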
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

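// Aborts the rendezvous associated with `step_id`, after a short delay so
// that the root-cause error has a chance to reach the client before the
// derived Aborted error does.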
void Worker::AbortStep(int64_t step_id) {
  RemoteRendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  // Do not abort if it's a context global instance for eager op-by-op
  // execution.
  if (rendez->IsRemoteEagerContextDefault()) {
    // Release the reference acquired by Find() before returning early.
    rendez->Unref();
    return;
  }
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the root
    // cause may return first back to the client instead of this
    // cancellation generated abort error.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled. Cancelling rendezvous."));
    rendez->Unref();
  });
}

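// Gathers the feeds ("sends") from the request into `in` and seeds `out` with
// one placeholder tensor per requested fetch ("recv") key.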
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return OkStatus();
}

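// Entry point for RunGraph requests. When the client asked for errors to be
// stored in the response body, the final status is folded into the response
// and the RPC itself completes with OkStatus(); the request is then routed to
// either the partial-run or the full-run path.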
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(OkStatus());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

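// Executes one full step of a registered graph: resolves the session, stages
// feeds and fetch placeholders, optionally attaches a StepStatsCollector and
// a device profiler, wires both RPC-level and worker-level cancellation into
// a per-step CancellationManager, and hands off to GraphMgr::ExecuteAsync.
// All heap-allocated helpers are released in the completion callback.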
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
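  // Also register with the worker-wide cancellation manager: if it is
  // cancelled, this step's CancellationManager is cancelled as well.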
  CancellationToken token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, request->exec_opts(), in, session.get(),
      collector, response, cm, env_->session_mgr->GetCoordinationServiceAgent(),
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

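// Runs one round of a partial run. The first request for a step starts the
// executors; subsequent requests feed additional inputs into the running
// step; every request then waits for its requested fetches. Per-step state
// across these calls is tracked by partial_run_mgr_.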
// TODO(suharshs): Add stats collection support to partial run.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, request->exec_opts(), in, session.get(),
        /*collector=*/nullptr, /*response=*/nullptr, cm,
        env_->session_mgr->GetCoordinationServiceAgent(),
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

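// Releases per-step resources once a step completes: the step's rendezvous,
// any collective-executor state, and scoped-allocator instances on every
// local device.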
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(OkStatus());
}

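// Clears the requested named resource containers on all devices managed by
// this worker.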
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(OkStatus());
}

void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

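// Resolves a collective group: forwards the caller's device attributes to the
// param resolver and, on success, echoes back the agreed-upon group key,
// group size, device type, member devices, and communicator key.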
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of TensorFlow on all workers."));
    return;
  }
  if (env_->collective_executor_mgr) {
    auto group_params = new CollGroupParams();
    group_params->group_key = request->group_key();
    group_params->group_size = request->group_size();
    group_params->device_type = DeviceType(request->device_type());
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request->device_attributes(), group_params, &cancellation_manager_,
        [response, group_params, done = std::move(done)](const Status& s) {
          if (s.ok()) {
            response->set_group_key(group_params->group_key);
            response->set_group_size(group_params->group_size);
            response->set_device_type(group_params->device_type.type_string());
            response->set_num_tasks(group_params->num_tasks);
            for (const CollGroupMember& member : group_params->members) {
              *response->add_device_attributes() = member.device;
            }
            response->set_communicator_key(
                group_params->runtime_details.communicator_key);
          } else {
            LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
          }
          delete group_params;
          done(s);
        });
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

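// Resolves a single collective instance within an already-completed group by
// delegating to the param resolver.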
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

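// Returns the step-id sequences for the requested graph keys, as maintained
// by the CollectiveExecutorMgr.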
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figures out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Does the device have the right incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::AbortedWithPayloads(
        strings::StrCat("RecvTensor expects a different device incarnation: ",
                        parsed.src_incarnation, " vs. ",
                        (*src_dev)->attributes().incarnation(),
                        ". Your worker job (\"",
                        env_->session_mgr->LegacySession()->worker_name(),
                        "\") was probably restarted. Check your "
                        "worker job for the reason why it was restarted."),
        {{kWorkerPossiblyRestarted,
          distributed_runtime::WorkerPossiblyRestarted().SerializeAsString()}});
  }

  return OkStatus();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}
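
// Transports are expected to subclass Worker and override the data-plane
// methods above. A minimal sketch, assuming a hypothetical subclass name
// (the real gRPC implementation lives in GrpcWorker):
//
//   class MyTransportWorker : public Worker {
//    public:
//     explicit MyTransportWorker(WorkerEnv* env) : Worker(env) {}
//
//     void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
//                          TensorResponse* response,
//                          StatusCallback done) override {
//       // Locate the source device (e.g. via PrepareRecvTensor()), stream the
//       // tensor bytes over the transport, then invoke `done` with the final
//       // status.
//       ...
//     }
//   };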

}  // namespace tensorflow