1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/distributed_runtime/worker.h"
17
18#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19#include "tensorflow/core/common_runtime/device_mgr.h"
20#include "tensorflow/core/common_runtime/process_util.h"
21#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
22#include "tensorflow/core/common_runtime/step_stats_collector.h"
23#include "tensorflow/core/distributed_runtime/error_payloads.h"
24#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
25#include "tensorflow/core/distributed_runtime/tensor_coding.h"
26#include "tensorflow/core/distributed_runtime/worker_session.h"
27#include "tensorflow/core/framework/collective.h"
28#include "tensorflow/core/platform/tracing.h"
29#include "tensorflow/core/profiler/lib/device_profiler_session.h"
30#include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"
31
32namespace tensorflow {
33
// Constructs a Worker bound to `env`. `env` is not owned and must outlive
// this Worker. The RecentRequestIds tracker remembers up to 100000 request
// ids, sharded across `env_->experimental_num_shards` shards (presumably to
// reduce lock contention — confirm against RecentRequestIds docs).
Worker::Worker(WorkerEnv* env)
    : env_(env), recent_request_ids_(100000, env_->experimental_num_shards) {
  // A non-positive shard count would make the tracker ill-formed.
  DCHECK_GT(env_->experimental_num_shards, 0);

  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}
43
44void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
45 GetStatusResponse* response, bool fail_fast,
46 StatusCallback done) {
47 const DeviceMgr* dm = env_->device_mgr;
48 std::vector<DeviceAttributes> devices;
49 dm->ListDeviceAttributes(&devices);
50 response->mutable_device_attributes()->Reserve(devices.size());
51 for (auto& d : devices) {
52 response->add_device_attributes()->Swap(&d);
53 }
54 done(OkStatus());
55}
56
57void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
58 CreateWorkerSessionResponse* response,
59 StatusCallback done) {
60 Status s = env_->session_mgr->CreateSession(
61 request->session_handle(), request->server_def(),
62 request->cluster_device_attributes(), request->isolate_session_state(),
63 request->master_task(), request->master_incarnation());
64 done(s);
65}
66
67void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
68 const DeleteWorkerSessionRequest* request,
69 DeleteWorkerSessionResponse* response,
70 StatusCallback done) {
71 Status s = env_->session_mgr->DeleteSession(request->session_handle());
72 done(s);
73}
74
75void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
76 RegisterGraphResponse* response,
77 StatusCallback done) {
78 std::shared_ptr<WorkerSession> session;
79 Status s;
80 if (request->create_worker_session_called()) {
81 s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
82 &session);
83 } else {
84 session = env_->session_mgr->LegacySession();
85 }
86 if (s.ok()) {
87 s = session->graph_mgr()->Register(
88 request->session_handle(), request->graph_def(),
89 request->graph_options(), request->debug_options(),
90 request->config_proto(), request->collective_graph_key(), session.get(),
91 session->cluster_flr(), response->mutable_graph_handle());
92 }
93 done(s);
94}
95
96void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
97 DeregisterGraphResponse* response,
98 StatusCallback done) {
99 std::shared_ptr<WorkerSession> session;
100 Status s;
101 if (request->create_worker_session_called()) {
102 s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
103 &session);
104 } else {
105 session = env_->session_mgr->LegacySession();
106 }
107 if (s.ok()) {
108 s = session->graph_mgr()->Deregister(request->graph_handle());
109 }
110
111 done(s);
112}
113
114void Worker::AbortStep(int64_t step_id) {
115 RemoteRendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
116 // Do not abort if it's a context global instance for eager op-by-op execution
117 if (rendez->IsRemoteEagerContextDefault()) return;
118 SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
119 // Delay a bit before aborting the step. This way, the root
120 // cause may return first back to the client instead of this
121 // cancellation generated abort error.
122 rendez->StartAbort(errors::Aborted("Step ", step_id,
123 " cancelled. Cancelling rendezvous."));
124 rendez->Unref();
125 });
126}
127
128Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
129 GraphMgr::NamedTensors* in,
130 GraphMgr::NamedTensors* out) {
131 static Tensor empty_tensor(DT_FLOAT);
132 if (req->num_sends() > 0) {
133 Tensor val;
134 for (size_t i = 0; i < req->num_sends(); ++i) {
135 TF_RETURN_IF_ERROR(req->SendValue(i, &val));
136 in->insert({req->send_key(i), val});
137 }
138 }
139 for (size_t i = 0; i < req->num_recvs(); ++i) {
140 out->insert({req->recv_key(i), empty_tensor});
141 }
142 return OkStatus();
143}
144
145void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
146 MutableRunGraphResponseWrapper* response,
147 StatusCallback done) {
148 if (request->store_errors_in_response_body()) {
149 done = [response, done](const Status& status) {
150 response->set_status(status);
151 done(OkStatus());
152 };
153 }
154 if (request->is_partial()) {
155 DoPartialRunGraph(opts, request, response, std::move(done));
156 } else {
157 DoRunGraph(opts, request, response, std::move(done));
158 }
159}
160
161MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
162 return new InMemoryRunGraphRequest;
163}
164
165MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
166 return new InMemoryRunGraphResponse;
167}
168
// Executes one (non-partial) RunGraph request: resolves the worker session,
// feeds the sent tensors, runs the registered graph, then returns the recv'd
// tensors in the response. Heap objects (`out`, `collector`,
// `device_profiler_session`, `cm`) are owned across the async boundary and
// deleted either on the early-exit paths or in the completion callback.
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  // Reject duplicate request ids (e.g. RPC retries) before doing any work.
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Resolve the session: by handle if CreateWorkerSession was used,
  // otherwise the legacy process-wide session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  // `in` holds the fed tensors; `out` is seeded with one placeholder per
  // expected recv key and filled in by RecvOutputs below.
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  // Only allocate a stats collector when the request asks for some form of
  // stats (timeline, costs, or OOM tensor reports).
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  // Per-step cancellation manager. RPC-level cancellation (via `opts`)
  // cancels this step and aborts its rendezvous.
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  // Also register with the worker-wide cancellation manager so that a
  // worker-level cancel (e.g. shutdown) cancels this step too.
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    // Worker-wide cancellation already happened: release everything allocated
    // above and abort the call.
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, request->exec_opts(), in, session.get(),
      collector, response, cm, env_->session_mgr->GetCoordinationServiceAgent(),
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Pull the produced tensors out of the rendezvous into `out`.
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        // Unhook both cancellation paths before tearing down `cm`.
        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          // Best-effort: profiling failures must not fail the step.
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          // Copy the received tensors into the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}
264
// TODO(suharshs): Add stats collection support to partial run.
// Executes one partial RunGraph request. The first request for a step starts
// the executors (tracked by partial_run_mgr_); subsequent requests only feed
// new inputs. Every request then asynchronously receives the outputs it asked
// for, and the last partial request finalizes the step.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  // Reject duplicate request ids (e.g. RPC retries) before doing any work.
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Resolve the session: by handle if CreateWorkerSession was used,
  // otherwise the legacy process-wide session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  // Common completion path: clears the RPC cancel hook, frees `out`, and
  // reports the final status.
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  // `cm` is owned by partial_run_mgr_ and shared by all partial-run requests
  // for this step; `is_new_partial_run` tells us whether we created it now.
  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    // Also wire the worker-wide cancellation manager to this step's manager.
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, request->exec_opts(), in, session.get(),
        /*collector=*/nullptr, /*response=*/nullptr, cm,
        env_->session_mgr->GetCoordinationServiceAgent(),
        [this, token, step_id, session](Status s) {
          // Executor completion: unhook worker-wide cancellation and let
          // partial_run_mgr_ record the executor's final status.
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  // Asynchronously collect the outputs requested by this partial run.
  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          // Defer `finish` until the executors for this step are done too.
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}
356
357void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
358 CleanupGraphResponse* response,
359 StatusCallback done) {
360 const int64_t step_id = request->step_id();
361 env_->rendezvous_mgr->Cleanup(step_id);
362 if (env_->collective_executor_mgr) {
363 env_->collective_executor_mgr->Cleanup(step_id);
364 }
365 for (Device* d : env_->local_devices) {
366 ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
367 if (sam) {
368 sam->Cleanup(step_id);
369 }
370 }
371 done(OkStatus());
372}
373
374void Worker::CleanupAllAsync(const CleanupAllRequest* request,
375 CleanupAllResponse* response,
376 StatusCallback done) {
377 std::vector<string> containers;
378 for (const auto& c : request->container()) containers.push_back(c);
379 env_->device_mgr->ClearContainers(containers);
380 done(OkStatus());
381}
382
383void Worker::LoggingAsync(const LoggingRequest* request,
384 LoggingResponse* response, StatusCallback done) {
385 done(errors::Unimplemented("Logging"));
386}
387
388void Worker::TracingAsync(const TracingRequest* request,
389 TracingResponse* response, StatusCallback done) {
390 done(errors::Unimplemented("Tracing"));
391}
392
// Always fails with Unimplemented in the base class; see comment below.
void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}
401
402void Worker::CompleteGroupAsync(CallOptions* opts,
403 const CompleteGroupRequest* request,
404 CompleteGroupResponse* response,
405 StatusCallback done) {
406 if (!request->has_device_attributes()) {
407 done(errors::Internal(
408 "CompleteGroupRequest device_attributes is not set. Make sure you're "
409 "running the same version of Tensorflow on all workers."));
410 return;
411 }
412 if (env_->collective_executor_mgr) {
413 auto group_params = new CollGroupParams();
414 group_params->group_key = request->group_key();
415 group_params->group_size = request->group_size();
416 group_params->device_type = DeviceType(request->device_type());
417 env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
418 request->device_attributes(), group_params, &cancellation_manager_,
419 [response, group_params, done = std::move(done)](const Status& s) {
420 if (s.ok()) {
421 response->set_group_key(group_params->group_key);
422 response->set_group_size(group_params->group_size);
423 response->set_device_type(group_params->device_type.type_string());
424 response->set_num_tasks(group_params->num_tasks);
425 for (const CollGroupMember& member : group_params->members) {
426 *response->add_device_attributes() = member.device;
427 }
428 response->set_communicator_key(
429 group_params->runtime_details.communicator_key);
430 } else {
431 LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
432 }
433 delete group_params;
434 done(s);
435 });
436 } else {
437 done(
438 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
439 }
440}
441
442void Worker::CompleteInstanceAsync(CallOptions* opts,
443 const CompleteInstanceRequest* request,
444 CompleteInstanceResponse* response,
445 StatusCallback done) {
446 if (env_->collective_executor_mgr) {
447 env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
448 request, response, &cancellation_manager_, done);
449 } else {
450 done(
451 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
452 }
453}
454
455void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
456 GetStepSequenceResponse* response,
457 StatusCallback done) {
458 if (env_->collective_executor_mgr) {
459 env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
460 done);
461 } else {
462 done(
463 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
464 }
465}
466
467// Helper for RecvTensor. Validates "key" and returns the source
468// device in "*src_dev".
469Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
470 Device** src_dev) {
471 // Figures out which device the tensor is hosted on.
472 string local_name = DeviceNameUtils::LocalName(parsed.src_device);
473 TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
474
475 // Does the device have the right incarnation number we expect?
476 if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
477 return errors::AbortedWithPayloads(
478 strings::StrCat("RecvTensor expects a different device incarnation: ",
479 parsed.src_incarnation, " vs. ",
480 (*src_dev)->attributes().incarnation(),
481 ". Your worker job (\"",
482 env_->session_mgr->LegacySession()->worker_name(),
483 "\") was probably restarted. Check your "
484 "worker job for the reason why it was restarted."),
485 {{kWorkerPossiblyRestarted,
486 distributed_runtime::WorkerPossiblyRestarted().SerializeAsString()}});
487 }
488
489 return OkStatus();
490}
491
// Always fails with Unimplemented in the base class; see comment below.
void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}
501
502} // namespace tensorflow
503